diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000000000000000000000000000000000000..fb2799bc4e97931dbf07b7be3d51af5ac1012e70 --- /dev/null +++ b/.clang-format @@ -0,0 +1,151 @@ +--- +Language: Cpp +# BasedOnStyle: Google +AccessModifierOffset: -1 +AlignAfterOpenBracket: Align +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +AlignEscapedNewlines: Left +AlignOperands: true +AlignTrailingComments: true +AllowAllParametersOfDeclarationOnNextLine: true +AllowShortBlocksOnASingleLine: false +AllowShortCaseLabelsOnASingleLine: false +AllowShortFunctionsOnASingleLine: All +AllowShortIfStatementsOnASingleLine: true +AllowShortLoopsOnASingleLine: true +AlwaysBreakAfterDefinitionReturnType: None +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: true +AlwaysBreakTemplateDeclarations: Yes +BinPackArguments: true +BinPackParameters: true +BraceWrapping: + AfterClass: false + AfterControlStatement: false + AfterEnum: false + AfterFunction: false + AfterNamespace: false + AfterObjCDeclaration: false + AfterStruct: false + AfterUnion: false + AfterExternBlock: false + BeforeCatch: false + BeforeElse: false + IndentBraces: false + SplitEmptyFunction: true + SplitEmptyRecord: true + SplitEmptyNamespace: true +BreakBeforeBinaryOperators: None +BreakBeforeBraces: Attach +BreakBeforeInheritanceComma: false +BreakInheritanceList: BeforeColon +BreakBeforeTernaryOperators: true +BreakConstructorInitializersBeforeComma: false +BreakConstructorInitializers: BeforeColon +BreakAfterJavaFieldAnnotations: false +BreakStringLiterals: true +ColumnLimit: 120 +CommentPragmas: '^ IWYU pragma:' +CompactNamespaces: false +ConstructorInitializerAllOnOneLineOrOnePerLine: true +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 2 +Cpp11BracedListStyle: true +DerivePointerAlignment: false +DisableFormat: false +ExperimentalAutoDetectBinPacking: false +FixNamespaceComments: true +ForEachMacros: +# - foreach + - Q_FOREACH + - BOOST_FOREACH +IncludeBlocks: Preserve +IncludeCategories: + - Regex: '^' + Priority: 2 + - Regex: '^<.*\.h>' + Priority: 1 + - Regex: '^<.*' + Priority: 2 + - Regex: '.*' + Priority: 3 +IncludeIsMainRegex: '([-_](test|unittest))?$' +IndentCaseLabels: true +IndentPPDirectives: None +IndentWidth: 2 +IndentWrappedFunctionNames: false +JavaScriptQuotes: Leave +JavaScriptWrapImports: true +KeepEmptyLinesAtTheStartOfBlocks: false +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +ObjCBinPackProtocolList: Never +ObjCBlockIndentWidth: 2 +ObjCSpaceAfterProperty: false +ObjCSpaceBeforeProtocolList: true +PenaltyBreakAssignment: 2 +PenaltyBreakBeforeFirstCallParameter: 1 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakString: 1000 +PenaltyBreakTemplateDeclaration: 10 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 200 +PointerAlignment: Right +RawStringFormats: + - Language: Cpp + Delimiters: + - cc + - CC + - cpp + - Cpp + - CPP + - 'c++' + - 'C++' + CanonicalDelimiter: '' + BasedOnStyle: google + - Language: TextProto + Delimiters: + - pb + - PB + - proto + - PROTO + EnclosingFunctions: + - EqualsProto + - EquivToProto + - PARSE_PARTIAL_TEXT_PROTO + - PARSE_TEST_PROTO + - PARSE_TEXT_PROTO + - ParseTextOrDie + - ParseTextProtoOrDie + CanonicalDelimiter: '' + BasedOnStyle: google +ReflowComments: true +SortUsingDeclarations: true +SpaceAfterCStyleCast: false +SpaceAfterTemplateKeyword: true +SpaceBeforeAssignmentOperators: true +SpaceBeforeCpp11BracedList: false +SpaceBeforeCtorInitializerColon: true +SpaceBeforeInheritanceColon: true +SpaceBeforeParens: ControlStatements +SpaceBeforeRangeBasedForLoopColon: true +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 2 +SpacesInAngles: false +SpacesInContainerLiterals: true +SpacesInCStyleCastParentheses: false +SpacesInParentheses: false +SpacesInSquareBrackets: false +Standard: Auto +StatementMacros: + - Q_UNUSED + - QT_REQUIRE_VERSION +TabWidth: 2 +UseTab: Never +SortIncludes: false +... + diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000000000000000000000000000000000000..94cf2f5dfbf4c6f83f05c036cd35bfa254d9da05 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,4 @@ +[submodule "mindspore"] + path = mindspore + url = https://gitee.com/mindspore/mindspore.git + shallow = true diff --git a/mindspore-lite/tools/graph_kernel/OWNERS b/.jenkins/OWNERS similarity index 55% rename from mindspore-lite/tools/graph_kernel/OWNERS rename to .jenkins/OWNERS index 940537e3d5cc20466c8be473fd4037dca129cf84..c5b51a62d074d1b51f72d0e6ad331603ed27b6da 100644 --- a/mindspore-lite/tools/graph_kernel/OWNERS +++ b/.jenkins/OWNERS @@ -1,8 +1,6 @@ approvers: -- jjfeing +- jjfeing # +- fengyixing - YeFeng_24 -- fatmouse007fatmouse007 - xu_anyue - -options: - no_parent_owners: true +- fatmouse007fatmouse007 diff --git a/.jenkins/check/config/filter_cppcheck.txt b/.jenkins/check/config/filter_cppcheck.txt new file mode 100644 index 0000000000000000000000000000000000000000..bc72526f8879ba0507c6dc42a3291d2ae6e719da --- /dev/null +++ b/.jenkins/check/config/filter_cppcheck.txt @@ -0,0 +1,42 @@ +# tools +"mindspore-lite/mindspore-lite/tools/common/flag_parser.cc" "useStlAlgorithm" +"mindspore-lite/mindspore-lite/tools/common/tensor_util.cc" "useStlAlgorithm" +"mindspore-lite/mindspore-lite/tools/converter/parser/onnx/onnx_relu_parser.cc" "useStlAlgorithm" +"mindspore-lite/mindspore-lite/tools/converter/parser/pytorch/pytorch_model_parser.cc" "variableScope" +"mindspore-lite/mindspore-lite/tools/converter/quantizer/quantize_util.cc" "useStlAlgorithm" +"mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/tbe_and_aicpu/op_proto/" "syntaxError" +"mindspore-lite/mindspore-lite/tools/optimizer/fusion/batchnorm_to_scale_fusion.cc" "nullPointerRedundantCheck" + +# src +"mindspore-lite/mindspore-lite/src/common/draw/drawer.cc" "duplicateCondition" +"mindspore-lite/mindspore-lite/src/common/ops/unsqueeze.cc" "useStlAlgorithm" +"mindspore-lite/mindspore-lite/src/litert/kernel/cpu/fp16/lstm_fp16_base.cc" "knownConditionTrueFalse" +"mindspore-lite/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_fp32.cc" "knownConditionTrueFalse" +"mindspore-lite/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_im2col_fp32.cc" "shadowVariable" +"mindspore-lite/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_fp32.cc" "knownConditionTrueFalse" +"mindspore-lite/mindspore-lite/src/litert/kernel/cpu/fp32/convolution_winograd_fp32.cc" "shadowVariable" +"mindspore-lite/mindspore-lite/src/litert/kernel/opencl/cl/" "unreadVariable" +"mindspore-lite/mindspore-lite/src/litert/kernel/opencl/kernel/" "unreadVariable" +"mindspore-lite/mindspore-lite/src/train/optimizer/fusion/gru_fusion_pass.cc" "stlFindInsert" + +# test +"mindspore-lite/mindspore-lite/test/" "syntaxError" +"mindspore-lite/mindspore-lite/test/ut/tools/converter/registry/pass_registry_test.cc" "unknownMacro" + +# MindData +"mindspore-lite/mindspore-lite/minddata/dataset/kernels/image/dvpp/utils/dvpp_image_utils.cc" "nullPointerRedundantCheck" +"mindspore-lite/mindspore-lite/minddata/dataset/kernels/image/dvpp/utils/dvpp_image_utils.cc" "unsignedLessThanZero" +"mindspore-lite/mindspore-lite/minddata/dataset/kernels/image/dvpp/utils/dvpp_image_utils.cc" "constParameter" +"mindspore-lite/mindspore-lite/minddata/dataset/kernels/image/dvpp/utils/dvpp_image_utils.cc" "constParameter" +"mindspore-lite/mindspore-lite/minddata/dataset/kernels/image/dvpp/utils/dvpp_image_utils.cc" "useStlAlgorithm" +"mindspore-lite/mindspore-lite/minddata/dataset/kernels/image/dvpp/utils/ResourceManager.cc" "unreadVariable" +"mindspore-lite/mindspore-lite/minddata/dataset/util/arena.cc" "useStlAlgorithm" + +# other +"mindspore-lite/mindspore-lite/examples/quick_start_micro/" "syntaxError" +"mindspore-lite/mindspore-lite/python/src/pybind_module.cc" "syntaxError" + +# nnacl +"mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/" "unreadVariable" +"mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_sw_avx_fp32.c" "unknownMacro" + diff --git a/.jenkins/check/config/filter_cpplint.txt b/.jenkins/check/config/filter_cpplint.txt new file mode 100644 index 0000000000000000000000000000000000000000..f55a07b30fc294e8b162a48d91bee12a471896f5 --- /dev/null +++ b/.jenkins/check/config/filter_cpplint.txt @@ -0,0 +1,95 @@ +# tools +"mindspore-lite/mindspore-lite/tools/converter/config_parser/config_file_parser.h" "runtime/references" +"mindspore-lite/mindspore-lite/tools/converter/micro/coder/wrapper/" "readability/casting" +"mindspore-lite/mindspore-lite/tools/converter/model_parser.h" "build/namespaces" +"mindspore-lite/mindspore-lite/tools/converter/legacy_optimizer/fusion/fusion_pass.h" "runtime/references" +"mindspore-lite/mindspore-lite/tools/converter/optimizer.h" "build/namespaces" +"mindspore-lite/mindspore-lite/tools/converter/parser/caffe/caffe_node_parser.cc" "readability/casting" +"mindspore-lite/mindspore-lite/tools/converter/parser/tflite/tflite_node_parser.h" "runtime/references" +"mindspore-lite/mindspore-lite/tools/converter/quantizer/quantize_util.h" "runtime/references" +"mindspore-lite/mindspore-lite/tools/benchmark/benchmark.cc" "runtime/threadsafe_fn" +"mindspore-lite/mindspore-lite/tools/benchmark/benchmark_base.cc" "runtime/threadsafe_fn" +"mindspore-lite/mindspore-lite/tools/benchmark/benchmark_unified_api.cc" "runtime/threadsafe_fn" +"mindspore-lite/mindspore-lite/tools/benchmark/run_benchmark.cc" "runtime/threadsafe_fn" +"mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/ascendc/op_host/decoder_kv_cache.cpp" "runtime/references" +"mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/ascendc/op_host/platform_ascendc.h" "runtime/references" +"mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/ascendc/op_host/prompt_kv_cache.cpp" "runtime/references" +"mindspore-lite/mindspore-lite/tools/optimizer/fusion/flash_attention_fusion.cc" "runtime/string" + +# src +"mindspore-lite/mindspore-lite/src/common/ops/ops_def.cc" "runtime/int" +"mindspore-lite/mindspore-lite/src/executor.h" "runtime/references" +"mindspore-lite/mindspore-lite/src/extendrt/cxx_api/model/model_impl.cc" "whitespace/parens" +"mindspore-lite/mindspore-lite/src/lite_kernel.h" "runtime/references" +"mindspore-lite/mindspore-lite/src/litert/delegate/coreml/coreml_executor.h" "readability/casting" +"mindspore-lite/mindspore-lite/src/litert/kernel/opencl/cl/" "legal/copyright" +"mindspore-lite/mindspore-lite/src/litert/kernel/opencl/cl/" "readability/casting" +"mindspore-lite/mindspore-lite/src/litert/kernel/opencl/cl/" "readability/fn_size" +"mindspore-lite/mindspore-lite/src/litert/opencl/opencl_executor.h" "runtime/references" +"mindspore-lite/mindspore-lite/src/litert/opencl/opencl_runtime.h" "runtime/references" +"mindspore-lite/mindspore-lite/src/litert/opencl/opencl_wrapper.h" "readability/casting" +"mindspore-lite/mindspore-lite/src/litert/thread_pool.c" "readability/casting" +"mindspore-lite/mindspore-lite/src/litert/thread_pool.c" "runtime/arrays" +"mindspore-lite/mindspore-lite/src/litert/thread_pool.c" "runtime/int" + +# MindData +"mindspore-lite/mindspore-lite/minddata/dataset/kernels/image/cutmix_batch_op.cc" "build/include_what_you_use" +"mindspore-lite/mindspore-lite/minddata/dataset/kernels/image/dvpp/utils/acl_plugin.cc" "runtime/references" +"mindspore-lite/mindspore-lite/minddata/dataset/kernels/image/dvpp/utils/AclLiteUtils.cc" "runtime/references" +"mindspore-lite/mindspore-lite/minddata/dataset/kernels/image/dvpp/utils/AclLiteUtils.cc" "runtime/string" +"mindspore-lite/mindspore-lite/minddata/dataset/kernels/image/dvpp/utils/AclLiteUtils.cc" "runtime/threadsafe_fn" +"mindspore-lite/mindspore-lite/minddata/dataset/kernels/image/dvpp/utils/AclLiteUtils.h" "runtime/references" +"mindspore-lite/mindspore-lite/minddata/dataset/kernels/image/dvpp/utils/DvppCommon.cc" "build/storage_class" +"mindspore-lite/mindspore-lite/minddata/dataset/kernels/image/dvpp/utils/DvppCommon.h" "runtime/references" +"mindspore-lite/mindspore-lite/minddata/dataset/kernels/image/dvpp/utils/ErrorCode.h" "runtime/string" +"mindspore-lite/mindspore-lite/minddata/dataset/kernels/image/dvpp/utils/MDAclProcess.h" "runtime/references" +"mindspore-lite/mindspore-lite/minddata/dataset/kernels/image/dvpp/utils/ResourceManager.h" "runtime/references" +"mindspore-lite/mindspore-lite/minddata/dataset/kernels/image/dvpp/utils/VdecHelper.cc" "build/namespaces" +"mindspore-lite/mindspore-lite/minddata/dataset/kernels/image/image_utils.cc" "runtime/int" +"mindspore-lite/mindspore-lite/minddata/dataset/kernels/image/lite_cv/image_process.cc" "runtime/references" +"mindspore-lite/mindspore-lite/minddata/dataset/kernels/image/lite_cv/image_process.h" "runtime/references" +"mindspore-lite/mindspore-lite/minddata/dataset/kernels/image/lite_cv/lite_mat.h" "runtime/explicit" +"mindspore-lite/mindspore-lite/minddata/dataset/kernels/image/resize_cubic_op.cc" "runtime/references" +"mindspore-lite/mindspore-lite/minddata/dataset/kernels/image/resize_cubic_op.h" "runtime/references" +"mindspore-lite/mindspore-lite/minddata/dataset/kernels/tensor_op.h" "runtime/references" +"mindspore-lite/mindspore-lite/minddata/dataset/util/bit.h" "runtime/references" +"mindspore-lite/mindspore-lite/minddata/dataset/util/btree.h" "build/include" + +# other +"mindspore-lite/mindspore-lite/examples/quick_start_c/main.c" "readability/casting" +"mindspore-lite/mindspore-lite/examples/quick_start_c/main.c" "runtime/threadsafe_fn" +"mindspore-lite/mindspore-lite/examples/quick_start_micro" "readability/casting" +"mindspore-lite/mindspore-lite/examples/runtime_gpu_extend/src/cl" "legal/copyright" +"mindspore-lite/mindspore-lite/examples/runtime_gpu_extend/src/cl" "readability/casting" +"mindspore-lite/mindspore-lite/examples/runtime_gpu_extend/src/cl" "readability/fn_size" +"mindspore-lite/mindspore-lite/include/lite_utils.h" "build/include_what_you_use" +"mindspore-lite/mindspore-lite/python/src/lite_infer_pybind.cc" "runtime/references" + +# ascend samples +"mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/aicpu/sample/" "build/include_subdir" +"mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/aicpu/sample/" "runtime/references" +"mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/aicpu/sample/" "whitespace/comments" +"mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/aicpu/sample/" "legal/copyright" +"mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/aicpu/sample/" "whitespace/ending_newline" +"mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/aicpu/sample/" "build/include" +"mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/ascendc/op_host/" "build/include" +"mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/ascendc/op_host/" "runtime/references" +"mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/ascendc/op_kernel/" "build/include" +"mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/ascendc/op_kernel/" "build/namespaces" +"mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/tbe_and_aicpu/cpukernel/impl/" "build/include" +"mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/tbe_and_aicpu/op_proto/" "build/include" +"mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/tbe_dsl/sample/" "build/include_subdir" +"mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/tbe_dsl/sample/" "runtime/references" +"mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/tbe_dsl/sample/" "whitespace/comments" +"mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/tbe_dsl/sample/" "legal/copyright" +"mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/tbe_dsl/sample/" "whitespace/ending_newline" +"mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/tbe_dsl/sample/" "build/include" +"mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/tbe_tik/sample/" "build/include_subdir" +"mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/tbe_tik/sample/" "runtime/references" +"mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/tbe_tik/sample/" "whitespace/comments" +"mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/tbe_tik/sample/" "legal/copyright" +"mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/tbe_tik/sample/" "whitespace/ending_newline" +"mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/tbe_tik/sample/" "build/include" + +# nnacl +"mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/" "readability/casting" diff --git a/.jenkins/check/config/filter_linklint.txt b/.jenkins/check/config/filter_linklint.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.jenkins/check/config/filter_pylint.txt b/.jenkins/check/config/filter_pylint.txt new file mode 100644 index 0000000000000000000000000000000000000000..065ccbc2feffede7c7106570c5182a44697d8d2f --- /dev/null +++ b/.jenkins/check/config/filter_pylint.txt @@ -0,0 +1,21 @@ +# MindSpore-Lite +"mindspore-lite/mindspore-lite/python/api/_checkparam.py" "len-as-condition" +"mindspore-lite/mindspore-lite/python/api/base_model.py" "len-as-condition" +"mindspore-lite/mindspore-lite/python/api/context.py" "protected-access" +"mindspore-lite/mindspore-lite/python/api/model.py" "protected-access" +"mindspore-lite/mindspore-lite/python/api/model.py" "len-as-condition" +"mindspore-lite/mindspore-lite/python/api/tensor.py" "protected-access" +"mindspore-lite/mindspore-lite/test" "missing-docstring" +"mindspore-lite/mindspore-lite/test" "unused-variable" +"mindspore-lite/mindspore-lite/test/st/python/import_ms_and_mslite/" "unused-import" +"mindspore-lite/mindspore-lite/test/st/python/test_large_model_inference.py" "redefined-outer-name" +"mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/ascendc/cmake/util/insert_simplified_keys.py" "duplicate-key" +"mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/ascendc/cmake/util/replay_codegen.py" "bad-continuation" + +# ascend samples +"mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/tbe_dsl/sample/" "wrong-import-order" +"mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/tbe_tik/sample/" "wrong-import-order" +"mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/tbe_dsl/sample/" "bad-whitespace" +"mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/tbe_tik/sample/" "bad-whitespace" +"mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/tbe_dsl/sample/" "bad-continuation" +"mindspore-lite/mindspore-lite/tools/kernel_builder/ascend/tbe_tik/sample/" diff --git a/.jenkins/check/config/whitelizard.txt b/.jenkins/check/config/whitelizard.txt new file mode 100644 index 0000000000000000000000000000000000000000..aa510e221fe05bf3cc098707616de11d621dab97 --- /dev/null +++ b/.jenkins/check/config/whitelizard.txt @@ -0,0 +1,243 @@ +# Scene1: +# function_name1, function_name2 +# Scene2: +# file_path:function_name1, function_name2 +# + +# tools +mindspore-lite/mindspore-lite/tools/converter/adapter/acl/mapper/matmul_fusion_mapper.cc:mindspore::lite::MatMulFusionMapper::Mapper +mindspore-lite/mindspore-lite/tools/converter/config_parser/config_file_parser.cc:mindspore::lite::ConfigFileParser::SetParamByConfigfile +mindspore-lite/mindspore-lite/tools/converter/graphdef_transform.cc:mindspore::lite::GraphDefTransform::Transform +mindspore-lite/mindspore-lite/tools/converter/parser/onnx/onnx_inputs_adjust.cc:mindspore::lite::OnnxInputAdjust::Adjust +mindspore-lite/mindspore-lite/tools/converter/quantizer/weight_quantizer.cc:mindspore::lite::quant::WeightQuantizer::LinearQuant +mindspore-lite/mindspore-lite/tools/optimizer/fusion/flash_attention_fusion.cc:mindspore::opt::FlashAttentionFusion::Process + +# src +mindspore-lite/mindspore-lite/src/common/ops/primitive_c.cc:mindspore::lite::PrimitiveC::Create +mindspore-lite/mindspore-lite/src/extendrt/convert/runtime_convert.cc:RuntimeConvert +mindspore-lite/mindspore-lite/src/litert/ios_reg_kernels.h:mindspore::kernel::IosRegisterKernels +mindspore-lite/mindspore-lite/src/litert/ios_reg_ops.cc:mindspore::lite::IosRegisterOps +mindspore-lite/mindspore-lite/src/litert/kernel/cpu/base/quant_dtype_cast.cc:mindspore::kernel::QuantDTypeCastCPUKernel::QuantDTypeCast +mindspore-lite/mindspore-lite/src/litert/kernel/cpu/base/quant_dtype_cast.cc:mindspore::kernel::QuantDTypeCastCPUKernel::Run +mindspore-lite/mindspore-lite/src/litert/kernel/cpu/control/tensorlist_setitem.cc:mindspore::kernel::TensorListSetItemCPUKernel::Run +mindspore-lite/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_base.cc:mindspore::kernel::MatmulFp32BaseCPUKernel::init_global_variable +mindspore-lite/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_base.cc:mindspore::kernel::MatmulFp32BaseCPUKernel::Run +mindspore-lite/mindspore-lite/src/litert/kernel/opencl/kernel/conv2d.cc:mindspore::kernel::UseWinograd4x4To6x6 +mindspore-lite/mindspore-lite/src/litert/kernel/opencl/kernel/fullconnection.cc:mindspore::kernel::FullConnectionOpenCLKernel::CheckSpecs +mindspore-lite/mindspore-lite/src/litert/scheduler.cc:mindspore::lite::Scheduler::FindBackendKernel +mindspore-lite/mindspore-lite/src/litert/thread_pool.c:GetArch +mindspore-lite/mindspore-lite/src/train/train_loop.cc:mindspore::lite::TrainLoop::Train + +# minddata +mindspore-lite/mindspore-lite/minddata/dataset/engine/datasetops/data_queue_op.cc:mindspore::dataset::DataQueueOp::SendDataToAscend +mindspore-lite/mindspore-lite/minddata/dataset/kernels/image/dvpp/utils/dvpp_image_utils.cc:mindspore::dataset::DvppConvertColor, mindspore::dataset::DvppErase, mindspore::dataset::DvppNormalize, mindspore::dataset::DvppPerspective, mindspore::dataset::DvppResizedCrop, mindspore::dataset::DvppRotate, mindspore::dataset::CreateAclTensor +mindspore-lite/mindspore-lite/minddata/dataset/kernels/image/dvpp/utils/dvpp_video.cc:DvppVideo::SaveYuvFile + +# other +mindspore-lite/mindspore-lite/providers/nnie_proposal/src/proposal.cc:mindspore::proposal::Rpn +mindspore-lite/mindspore-lite/providers/nnie/src/custom_infer.cc:mindspore::nnie::CustomInterface::Infer + +# nnacl +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/infer/strided_slice_infer.c:StridedSliceInferShape +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/infer/lstm_infer.c:CheckInputShapeValid +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/fp16/winograd_transform_fp16.c:WinogradInputTransformFp16 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/fp16/pooling_fp16.c:AvgPoolingFp16 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/fp16/pooling_fp16.c:MaxPoolingFp16 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/fp16/winograd_utils_fp16.c:OutputTransform4x2UnitFp16 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/fp16/winograd_utils_fp16.c:OutputTransform4x2ReluUnitFp16 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/fp16/winograd_utils_fp16.c:OutputTransform4x2Relu6UnitFp16 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/fp16/winograd_utils_fp16.c:OutputTransform8x6UnitFp16 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/fp16/winograd_utils_fp16.c:OutputTransform8x6ReluUnitFp16 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/fp16/winograd_utils_fp16.c:OutputTransform8x6Relu6UnitFp16 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/int8/pooling_int8.c:AvgPoolingOptInt8 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/int8/pooling_int8.c:MaxPoolingWithQuantInt8 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/int8/conv3x3_int8.c:Conv3x3Int8OutputUnit +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/int8/conv_int8.c:Conv1x1PreOptPeroc +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/infer/infer_register.c:RegisterInfer +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/gemm.c:RowMajor2Col12MajorStride +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/gemm.c:RowMajor2Col8MajorStride +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/fp16/winograd_transform_fp16.c:Conv3x3Fp16InputUnit +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/fp16/winograd_transform_fp16.c:Conv3x3Fp16FilterTransform +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/fp16/pooling_fp16.c:AvgPoolingFp16 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/fp16/pooling_fp16.c:MaxPoolingFp16 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/fp16/pack_fp16.c:PackNHWCToNCHWFp16 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/fp16/winograd_utils_fp16.c:InputTransform6x6UnitFp16 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/fp16/winograd_utils_fp16.c:InputTransform8x8UnitFp16 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/fp16/winograd_utils_fp16.c:OutputTransform4x2Relu6UnitFp16 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/fp16/winograd_utils_fp16.c:OutputTransform8x6UnitFp16 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/fp16/winograd_utils_fp16.c:OutputTransform8x6ReluUnitFp16 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/fp16/winograd_utils_fp16.c:OutputTransform8x6Relu6UnitFp16 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/int8/pooling_int8.c:AvgPoolingOptInt8 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/int8/conv3x3_int8.c:Conv3x3Int8InputUnit +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/int8/conv3x3_int8.c:Conv3x3Int8FilterTransform +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/int8/conv3x3_int8.c:Conv3x3Int8OutputUnit +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/int8/conv_int8.c:Conv1x1PreOptPeroc +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/int8/conv_int8.c:Conv1x1PreOptPert +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/int8/pack_int8.c:PackNHWCToNCHWInt8 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/fp32/pooling_fp32.c:AvgPooling +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/fp32/matmul_fp32.c:MatMul4x1Kernel, MatMul2x1Kernel +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_sw_avx_fp32.c:SWConv3x32AVXKernel, SWConv4x24AVXKernel, SWConv12x8AVXKernel, SWConv8x8AVXKernel, SWConv4x8AVXKernel, SWConv6x16AVXKernel, SWConv4x16AVXKernel +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_depthwise_fp32.c:DepthwiseSW3x32Kernel, DepthwiseSW4x24Kernel, DepthwiseSW12x8Kernel, DepthwiseSW8x8Kernel, DepthwiseSW4x8Kernel, DepthwiseSW6x16Kernel, DepthwiseSW4x16Kernel +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_1x1_avx_fp32.c:Conv1x1SW3x32AVXKernel, Conv1x1SW4x24AVXKernel, Conv1x1SW12x8AVXKernel, Conv1x1SW8x8AVXKernel, Conv1x1SW4x8AVXKernel, Conv1x1SW6x16AVXKernel, Conv1x1SW4x16AVXKernel, Conv1x1SW1x32AVXKernel, Conv1x1SW1x24AVXKernel, Conv1x1SW1x16AVXKernel, Conv1x1SW1x8AVXKernel +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/fp32/matmul_fp32.c:MatMul3x32Kernel, MatMul4x24Kernel, MatMul12x8Kernel, MatMul8x8Kernel, MatMul4x8Kernel, MatMul6x16Kernel, MatMul4x16Kernel, MatVecMul1x32Kernel, MatVecMul1x24Kernel, MatVecMul1x16Kernel, MatVecMul1x8Kernel +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/fp32/matmul_avx_fp32.c:MatMul3x32Kernel, MatMul4x24Kernel, MatMul12x8Kernel, MatMul8x8Kernel, MatMul4x8Kernel, MatMul6x16Kernel, MatMul4x16Kernel, MatVecMul1x32Kernel, MatVecMul1x24Kernel, MatVecMul1x16Kernel, MatVecMul1x8Kernel +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/sse/TiledC4MatMulFp32.c:TiledC4MatmulFp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/sse/PostFuncBiasReluC4.c:PostFuncBiasReluC4 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/sse/WinogradTrans.c:WinogradTransRight +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/sse/WinogradTrans.c:WinogradTransLeft +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/sse/WinogradPostFuncBiasReluC4.c:WinogradPostFuncBiasReluC4 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/avx/PostFuncBiasReluC8.c:PostFuncBiasReluC8 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/avx/WinogradPostFuncBiasReluC8.c:WinogradPostFuncBiasReluC8 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/fp32/deconv_winograd_fp32.c:PackDeConvWgDataFp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/sse/WinogradPostFuncBiasReluC4.c:WinogradPostFuncBiasReluC4 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/sse/PostFuncBiasReluC8.c:PostFuncBiasReluC8 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/avx/WinogradTransAvx.c:WinogradTransLeft +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/avx/WinogradTransAvx.c:WinogradTransRight +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/avx/PostFuncBiasReluC8.c:PostFuncBiasReluC8 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/avx/WinogradPostFuncBiasReluC8.c:WinogradPostFuncBiasReluC8 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/fp32/deconv_winograd_fp32.c:PackDeConvWgDataFp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/fp32/deconv_winograd_fp32.c:DeConvWgMerge +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/avx/TiledC8MatMulFp32.c:TiledC8MatmulFp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/fp16/quant_dtype_cast_fp16.c:Fp16ToInt8_arm64 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_10x16_kernel_nhwc_fp32.c:nnacl_gemm_avx512_10x16_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_11x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_11x32_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x96_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x96_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x32_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x16_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x16_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x64_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x64_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x64_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x64_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_7x32_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x64_kernel_nhwc_fp32.c:nnacl_gemm_avx512_5x64_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x48_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x48_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x32_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x48_kernel_nhwc_fp32.c:nnacl_gemm_avx512_6x48_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x96_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x96_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x32_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x80_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x80_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x48_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x48_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x80_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x80_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x80_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x80_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_9x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_9x32_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x48_kernel_nhwc_fp32.c:nnacl_gemm_avx512_8x48_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x32_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x16_kernel_nhwc_fp32.c:nnacl_gemm_avx512_6x16_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x16_kernel_nhwc_fp32.c:nnacl_gemm_avx512_8x16_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x64_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x64_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_6x32_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_5x32_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_9x16_kernel_nhwc_fp32.c:nnacl_gemm_avx512_9x16_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x48_kernel_nhwc_fp32.c:nnacl_gemm_avx512_7x48_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x96_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x96_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x80_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x80_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x16_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x16_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x64_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x64_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_12x16_kernel_nhwc_fp32.c:nnacl_gemm_avx512_12x16_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_10x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_10x32_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x16_kernel_nhwc_fp32.c:nnacl_gemm_avx512_5x16_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_11x16_kernel_nhwc_fp32.c:nnacl_gemm_avx512_11x16_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x80_kernel_nhwc_fp32.c:nnacl_gemm_avx512_5x80_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_8x32_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x16_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x16_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x16_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x16_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_12x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_12x32_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x48_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x48_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x48_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x48_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x96_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x96_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x16_kernel_nhwc_fp32.c:nnacl_gemm_avx512_7x16_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x48_kernel_nhwc_fp32.c:nnacl_gemm_avx512_5x48_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x64_kernel_nhwc_fp32.c:nnacl_gemm_avx512_6x64_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_10x16_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_10x16_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_11x32_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_11x32_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x96_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x96_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x32_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x32_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x16_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x16_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x64_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x64_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x64_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x64_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_7x32_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_7x32_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x64_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_5x64_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x48_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x48_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x32_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x32_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x48_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_6x48_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x96_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x96_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x32_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x32_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x80_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x80_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x48_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x48_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x80_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x80_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x80_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x80_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_9x32_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_9x32_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_8x48_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_8x48_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x32_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x32_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x16_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_6x16_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_8x16_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_8x16_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x64_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x64_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x32_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_6x32_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x32_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_5x32_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_9x16_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_9x16_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_7x48_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_7x48_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x96_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x96_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x80_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x80_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x16_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x16_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x64_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x64_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_12x16_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_12x16_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_10x32_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_10x32_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x16_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_5x16_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_11x16_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_11x16_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x80_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_5x80_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_8x32_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_8x32_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x16_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x16_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x16_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x16_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_12x32_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_12x32_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x48_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x48_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x48_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x48_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x96_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x96_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_7x16_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_7x16_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x48_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_5x48_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x64_mask_kernel_nhwc_fp32.c:nnacl_gemm_avx512_6x64_mask_kernel_nhwc_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/fp16/instance_norm_fp16.c:InstanceNormNC8HW8Fp16 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/sse/MatMul_Sse.c:MatmulFloatSse64Opt +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_winograd_fp32.c:ConvWinogardFp32 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_winograd_fp32.c:ConvWinogardFp32CutByBatch +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/experimental/conv_fp32_nchwx_avx512.c:conv2d_compute_fp32_nchwx_avx512 +mindspore-lite/mindspore-lite/ops/kernel/cpu/nnacl/fp32/matmul_avx512_mask_fp32.c:GemmRowxColMaskKernelFp32 \ No newline at end of file diff --git a/.jenkins/rules/codespell/codespell.allow b/.jenkins/rules/codespell/codespell.allow new file mode 100644 index 0000000000000000000000000000000000000000..90fad709573eb85592cd360c24392362303ceb87 --- /dev/null +++ b/.jenkins/rules/codespell/codespell.allow @@ -0,0 +1,19 @@ +nd +te +ans +tbe +TBE +dout +claus +crate +creat +imBED +rouge +noone +outputOf +followings +functionAble +Aline +allEdges +FillD +enque diff --git a/.jenkins/rules/codespell/sensitive.allow b/.jenkins/rules/codespell/sensitive.allow new file mode 100644 index 0000000000000000000000000000000000000000..3e817b177c13bc4eaa79b4803d3c2bf59df4ae13 --- /dev/null +++ b/.jenkins/rules/codespell/sensitive.allow @@ -0,0 +1,29 @@ +0x23910C13 +Qw1981Gv +TASSP.1980.1163351 +40198016 +64198005 +mW1980OTa +CF1980XD +1160, 1981. +n01980 +n01981 +0001980 +0001981 +0011980 +0011981 +0021980 +0021981 +0031980 +0031981 +18198144 +4681981 +40198135 +6281981 +5198169 +70198125 +IBM Corporation, 1981 +198020 +19805, +198016 +0.1980 \ No newline at end of file diff --git a/.jenkins/rules/darglint/tox.ini b/.jenkins/rules/darglint/tox.ini new file mode 100644 index 0000000000000000000000000000000000000000..e60c6c062bf9514d7d123273f59611ba6aa87fce --- /dev/null +++ b/.jenkins/rules/darglint/tox.ini @@ -0,0 +1,3 @@ +[darglint] +ignore=DAR001,DAR004,DAR006,DAR101,DAR102,DAR103,DAR104,DAR105,DAR201,DAR202,DAR203,DAR301,DAR302,DAR401,DAR501 +ignore_regex=^_(.*) diff --git a/.jenkins/rules/markdownlint/markdownlint_docs.rb b/.jenkins/rules/markdownlint/markdownlint_docs.rb new file mode 100644 index 0000000000000000000000000000000000000000..560889114b860a5e6ffcc0085c60402e3af84f55 --- /dev/null +++ b/.jenkins/rules/markdownlint/markdownlint_docs.rb @@ -0,0 +1,13 @@ +all +rule 'MD007', :indent => 4 +rule 'MD009', :br_spaces => 2 +rule 'MD029', :style => :ordered +exclude_rule 'MD013' +exclude_rule 'MD002' +exclude_rule 'MD041' +exclude_rule 'MD005' +exclude_rule 'MD024' +exclude_rule 'MD033' +exclude_rule 'MD029' +exclude_rule 'MD034' +exclude_rule 'MD036' diff --git a/.jenkins/rules/markdownlint/markdownlint_mindspore_lite.rb b/.jenkins/rules/markdownlint/markdownlint_mindspore_lite.rb new file mode 100644 index 0000000000000000000000000000000000000000..37889e7753a1b8bc7f20ae15763b1d7ef81180d9 --- /dev/null +++ b/.jenkins/rules/markdownlint/markdownlint_mindspore_lite.rb @@ -0,0 +1,14 @@ +all +rule 'MD007', :indent => 4 +rule 'MD009', :br_spaces => 2 +rule 'MD029', :style => :ordered +exclude_rule 'MD013' +exclude_rule 'MD002' +exclude_rule 'MD041' +exclude_rule 'MD005' +exclude_rule 'MD024' +exclude_rule 'MD033' +exclude_rule 'MD029' +exclude_rule 'MD025' +exclude_rule 'MD034' +exclude_rule 'MD036' diff --git a/.jenkins/rules/pylint/pylintrc b/.jenkins/rules/pylint/pylintrc new file mode 100755 index 0000000000000000000000000000000000000000..83c78e98a8d6f3400d753c7c6eb4668c313cf50c --- /dev/null +++ b/.jenkins/rules/pylint/pylintrc @@ -0,0 +1,337 @@ +[MASTER] + +# Specify a configuration file. +#rcfile= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Profiled execution. +profile=no + +# Add files or directories to the blacklist. They should be base names, not +# paths. +ignore=CVS + +# Pickle collected data for later comparisons. +persistent=yes + +# List of plugins (as comma separated values of python modules names) to load, +# usually to register additional checkers. +load-plugins= + + +[MESSAGES CONTROL] + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time. See also the "--disable" option for examples. +enable=indexing-exception,old-raise-syntax + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once).You can also use "--disable=all" to +# disable everything first and then re-enable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use"--disable=all --enable=classes +# --disable=W" +disable=design,similarities,no-self-use,attribute-defined-outside-init,locally-disabled,star-args,pointless-except,bad-option-value,global-statement,fixme,suppressed-message,useless-suppression,locally-enabled,no-member,no-name-in-module,import-error,unsubscriptable-object,unbalanced-tuple-unpacking,undefined-variable,not-context-manager + + +# Set the cache size for astng objects. +cache-size=500 + + +[REPORTS] + +# Set the output format. Available formats are text, parseable, colorized, msvs +# (visual studio) and html. You can also give a reporter class, eg +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Put messages in a separate file for each module / package specified on the +# command line instead of printing them on stdout. Reports (if any) will be +# written in a file name "pylint_global.[txt|html]". +files-output=no + +# Tells whether to display a full report or only the messages +reports=no + +# Python expression which should return a note less than 10 (10 is the highest +# note). You have access to the variables errors warning, statement which +# respectively contain the number of errors / warnings messages and the total +# number of statements analyzed. This is used by the global evaluation report +# (RP0004). +evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) + +# Add a comment according to your evaluation note. This is used by the global +# evaluation report (RP0004). +comment=no + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details +#msg-template= + + +[TYPECHECK] + +# Tells whether missing members accessed in mixin class should be ignored. A +# mixin class is detected if its name ends with "mixin" (case insensitive). +ignore-mixin-members=yes + +# List of classes names for which member attributes should not be checked +# (useful for classes with attributes dynamically set). +ignored-classes=SQLObject + +# When zope mode is activated, add a predefined set of Zope acquired attributes +# to generated-members. +zope=no + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E0201 when accessed. Python regular +# expressions are accepted. +generated-members=REQUEST,acl_users,aq_parent + +# List of decorators that create context managers from functions, such as +# contextlib.contextmanager. +contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager + + +[VARIABLES] + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# A regular expression matching the beginning of the name of dummy variables +# (i.e. not used). +dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid to define new builtins when possible. +additional-builtins= + + +[BASIC] + +# Required attributes for module, separated by a comma +required-attributes= + +# List of builtins function names that should not be used, separated by a comma +bad-functions=apply,input,reduce + + +# Disable the report(s) with the given id(s). +# All non-Google reports are disabled by default. +disable-report=R0001,R0002,R0003,R0004,R0101,R0102,R0201,R0202,R0220,R0401,R0402,R0701,R0801,R0901,R0902,R0903,R0904,R0911,R0912,R0913,R0914,R0915,R0921,R0922,R0923 + +# Regular expression which should only match correct module names +module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ + +# Regular expression which should only match correct module level names +const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ + +# Regular expression which should only match correct class names +class-rgx=^_?[A-Z][a-zA-Z0-9]*$ + +# Regular expression which should only match correct function names +function-rgx=^(?:(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$ + +# Regular expression which should only match correct method names +method-rgx=^(?:(?P__[a-z0-9_]+__|next)|(?P_{0,2}[A-Z][a-zA-Z0-9]*)|(?P_{0,2}[a-z][a-z0-9_]*))$ + +# Regular expression which should only match correct instance attribute names +attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ + +# Regular expression which should only match correct argument names +argument-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression which should only match correct variable names +variable-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression which should only match correct attribute names in class +# bodies +class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ + +# Regular expression which should only match correct list comprehension / +# generator expression variable names +inlinevar-rgx=^[a-z][a-z0-9_]*$ + +# Good variable names which should always be accepted, separated by a comma +good-names=main,_ + +# Bad variable names which should always be refused, separated by a comma +bad-names= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=(__.*__|main) + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=10 + + +[FORMAT] + +# Maximum number of characters on a single line. +# mindspore: max-line-length=120 (Default: 80) +max-line-length=120 + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=(?x) + (^\s*(import|from)\s + |\$Id:\s\/\/depot\/.+#\d+\s\$ + |^[a-zA-Z_][a-zA-Z0-9_]*\s*=\s*("[^"]\S+"|'[^']\S+') + |^\s*\#\ LINT\.ThenChange + |^[^#]*\#\ type:\ [a-zA-Z_][a-zA-Z0-9_.,[\] ]*$ + |pylint + |""" + |\# + |lambda + |(https?|ftp):) + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=y + +# List of optional constructs for which whitespace checking is disabled +no-space-check= + +# Maximum number of lines in a module +max-module-lines=99999 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +# mindspore: 4 spaces (Default: 2 spaces) +indent-string=' ' + + +[SIMILARITIES] + +# Minimum lines number of a similarity. +min-similarity-lines=4 + +# Ignore comments when computing similarities. +ignore-comments=yes + +# Ignore docstrings when computing similarities. +ignore-docstrings=yes + +# Ignore imports when computing similarities. +ignore-imports=no + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes= + + +[IMPORTS] + +# Deprecated modules which should not be used, separated by a comma +deprecated-modules=regsub,TERMIOS,Bastion,rexec,sets + +# Create a graph of every (i.e. internal and external) dependencies in the +# given file (report RP0402 must not be disabled) +import-graph= + +# Create a graph of external dependencies in the given file (report RP0402 must +# not be disabled) +ext-import-graph= + +# Create a graph of internal dependencies in the given file (report RP0402 must +# not be disabled) +int-import-graph= + + +[CLASSES] + +# List of interface methods to ignore, separated by a comma. This is used for +# instance to not check methods defines in Zope's Interface base class. +ignore-iface-methods=isImplementedBy,deferred,extends,names,namesAndDescriptions,queryDescriptionFor,getBases,getDescriptionFor,getDoc,getName,getTaggedValue,getTaggedValueTags,isEqualOrExtendedBy,setTaggedValue,isImplementedByInstancesOf,adaptWith,is_implemented_by + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__,__new__,setUp + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls,class_ + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + + +[DESIGN] + +# Maximum number of arguments for function / method +max-args=5 + +# Argument names that match this expression will be ignored. Default to name +# with leading underscore +ignored-argument-names=_.* + +# Maximum number of locals for function / method body +max-locals=15 + +# Maximum number of return / yield for function / method body +max-returns=6 + +# Maximum number of branch for function / method body +max-branches=12 + +# Maximum number of statements in function / method body +max-statements=50 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of attributes for a class (see R0902). +max-attributes=7 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when being caught. Defaults to +# "Exception" +overgeneral-exceptions=Exception,StandardError,BaseException + + +[AST] + +# Maximum line length for lambdas +short-func-length=1 + +# List of module members that should be marked as deprecated. +# All of the string functions are listed in 4.1.4 Deprecated string functions +# in the Python 2.4 docs. +deprecated-members=string.atof,string.atoi,string.atol,string.capitalize,string.expandtabs,string.find,string.rfind,string.index,string.rindex,string.count,string.lower,string.split,string.rsplit,string.splitfields,string.join,string.joinfields,string.lstrip,string.rstrip,string.strip,string.swapcase,string.translate,string.upper,string.ljust,string.rjust,string.center,string.zfill,string.replace,sys.exitfunc + + +[DOCSTRING] + +# List of exceptions that do not need to be mentioned in the Raises section of +# a docstring. +ignore-exceptions=AssertionError,NotImplementedError,StopIteration,TypeError + + + +[TOKENS] + +# Number of spaces of indent required when the last token on the preceding line +# is an open (, [, or {. +indent-after-paren=4 + + +[MINDSPORE LINES] + +# Regexp for a proper copyright notice. +copyright=Copyright \d{4} The MindSpore Authors\. +All [Rr]ights [Rr]eserved\. diff --git a/build.bat b/build.bat index 1405ccd75cde4fd0c2274d2ca8d3823ac533af53..ec9c81462bfb013b45ac1bb363acfa5c41d36bd7 100644 --- a/build.bat +++ b/build.bat @@ -34,6 +34,9 @@ set ENABLE_FFMPEG=ON set ENABLE_FFMPEG_DOWNLOAD=OFF for /f "tokens=1" %%a in (version.txt) do (set VERSION_STR=%%a) git submodule update --init --remote mindspore +cd "%BASEPATH%\mindspore" && ( + git apply "%BASEPATH%\third_party\patch\mindspore\decouple_mindspore.patch" "%BASEPATH%\third_party\patch\mindspore\move-ccsrc-ops-dependencies.patch" +) ECHO %2%|FINDSTR "^[0-9][0-9]*$" IF %errorlevel% == 0 ( SET threads=%2% @@ -65,19 +68,47 @@ IF NOT EXIST "%BUILD_PATH%/mindspore" ( ) cd %BUILD_PATH%/mindspore - -echo "======Start building MindSpore Lite %VERSION_STR%======" -rd /s /q "%BASE_PATH%\output" -(git log -1 | findstr "^commit") > %BUILD_PATH%\.commit_id -IF defined VisualStudioVersion ( - cmake -DMSLITE_MINDDATA_IMPLEMENT=off -DMSLITE_ENABLE_TRAIN=off -DVERSION_STR=%VERSION_STR% ^ - -DCMAKE_BUILD_TYPE=Release -G "Ninja" "%BASE_PATH%/mindspore-lite" +IF "%1%" == "lite" ( + echo "======Start building MindSpore Lite %VERSION_STR%======" + rd /s /q "%BASE_PATH%\output" + (git log -1 | findstr "^commit") > %BUILD_PATH%\.commit_id + IF defined VisualStudioVersion ( + cmake -DMSLITE_MINDDATA_IMPLEMENT=off -DMSLITE_ENABLE_TRAIN=off -DVERSION_STR=%VERSION_STR% ^ + -DCMAKE_BUILD_TYPE=Release -G "Ninja" "%BASE_PATH%/mindspore-lite" + ) ELSE ( + cmake -DMSLITE_MINDDATA_IMPLEMENT=off -DMSLITE_ENABLE_TRAIN=off -DVERSION_STR=%VERSION_STR% ^ + -DCMAKE_BUILD_TYPE=Release -G "CodeBlocks - MinGW Makefiles" "%BASE_PATH%/mindspore-lite" + ) ) ELSE ( - cmake -DMSLITE_MINDDATA_IMPLEMENT=off -DMSLITE_ENABLE_TRAIN=off -DVERSION_STR=%VERSION_STR% ^ - -DCMAKE_BUILD_TYPE=Release -G "CodeBlocks - MinGW Makefiles" "%BASE_PATH%/mindspore-lite" + for /f "delims=" %%i in ('powershell.exe -ExecutionPolicy Bypass -Command "Get-ChildItem HKLM:\SOFTWARE\Wow6432Node\Microsoft\Windows\CurrentVersion\Uninstall | foreach { Get-ItemProperty $_.PsPath } | where { $_.DisplayName -like '*Visual Studio*' -and $_.InstallLocation.Length -gt 0 } | sort InstallDate -Descending | foreach { Join-Path $_.InstallLocation 'VC\Auxiliary\Build'}"') do (call "%%i\vcvars64.bat") + SET CMAKE_ARGS=-DENABLE_CPU=ON -DENABLE_MINDDATA=ON -DUSE_GLOG=ON -DENABLE_GITEE=%ENABLE_GITEE% -DCMAKE_EXE_LINKER_FLAGS="/manifest:no" -DCMAKE_MODULE_LINKER_FLAGS="/manifest:no" -DCMAKE_SHARED_LINKER_FLAGS="/manifest:no" + where ccache + IF !errorlevel! == 0 ( + echo "use ccache to speed up compile" + SET CMAKE_ARGS=!CMAKE_ARGS! -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache + ) + IF "%1%" == "ms_vs_gpu" ( + echo "======Start gen VS2019 Project for MS gpu ======" + SET CMAKE_ARGS=!CMAKE_ARGS! -DCMAKE_BUILD_TYPE=Release -DENABLE_GPU=ON -DGPU_BACKEND_CUDA=ON -DMS_REQUIRE_CUDA_VERSION=11.1 + ) ELSE IF "%1%" == "ms_vs_cpu" ( + echo "======Start gen VS2019 Project for MS cpu ======" + SET CMAKE_ARGS=!CMAKE_ARGS! -DCMAKE_BUILD_TYPE=Release + ) ELSE IF "%1%" == "ms_vs_cpu_debug" ( + echo "======Start gen VS2019 Project for MS cpu debug======" + SET CMAKE_ARGS=!CMAKE_ARGS! -DCMAKE_BUILD_TYPE=Debug -DDEBUG_MODE=ON + set BUILD_TYPE=Debug + ) + IF ON == %ENABLE_FFMPEG% ( + call %BASE_PATH%\cmake\external_libs\ffmpeg.bat + IF errorlevel 1 ( + echo "cmake fail." + call :clean + EXIT /b 1 + ) + ) + cmake !CMAKE_ARGS! -G Ninja ../.. ) - IF NOT %errorlevel% == 0 ( echo "cmake fail." call :clean diff --git a/build.sh b/build.sh index 80bbf0f1f7559a30616e853daaebc6ff34f491a5..5d91ce33feb69b549fdc8a8353e10abf9f5bfb44 100755 --- a/build.sh +++ b/build.sh @@ -39,6 +39,8 @@ check_on_off() update_submodule() { git submodule update --init --remote mindspore +# cd "${BASEPATH}/mindspore" +# TODO(compile so used in mindspore-lite) } build_exit() @@ -48,8 +50,14 @@ build_exit() exit 1 } +make_clean() +{ + echo "enable make clean" + cd "${BUILD_PATH}/mindspore" + cmake --build . --target clean +} update_submodule -echo "---------------- MindSpore-Lite: build start ----------------" +echo "---------------- MindSpore: build start ----------------" init_default_options process_options "$@" @@ -57,6 +65,5 @@ export COMPILE_MINDDATA_LITE export ENABLE_VERBOSE export LITE_PLATFORM export LITE_ENABLE_AAR -source ./scripts/build/build_lite.sh - -echo "---------------- MindSpore-Lite: build end ----------------" +source mindspore-lite/build_lite.sh +echo "---------------- MindSpore: build end ----------------" diff --git a/cmake/package_lite.cmake b/cmake/package_lite.cmake index 4307b7791d28cfb9e08ed7a61e99dc60aa0fcb0f..35940e4dd5e4e9c2a9537f8c1ad24019c1f531f0 100644 --- a/cmake/package_lite.cmake +++ b/cmake/package_lite.cmake @@ -5,7 +5,7 @@ set(RUNTIME_PKG_NAME ${PKG_NAME_PREFIX}-${RUNTIME_COMPONENT_NAME}) set(CONVERTER_ROOT_DIR ${RUNTIME_PKG_NAME}/tools/converter) set(OBFUSCATOR_ROOT_DIR ${RUNTIME_PKG_NAME}/tools/obfuscator) set(CROPPER_ROOT_DIR ${RUNTIME_PKG_NAME}/tools/cropper) -set(BUILD_DIR ${TOP_DIR}/build) +set(BUILD_DIR ${TOP_DIR}/mindspore-lite/build) set(TEST_CASE_DIR ${TOP_DIR}/mindspore-lite/test/build) set(EXTENDRT_BUILD_DIR ${BUILD_DIR}/src/extendrt) set(EXECUTOR_BUILD_DIR ${BUILD_DIR}/src/extendrt/unified_executor) @@ -220,7 +220,8 @@ function(__install_ascend_ascendc) endfunction() # full mode will also package the files of lite_cv mode. -if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full") +if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full" + AND NOT(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)) # full header files install(FILES ${TOP_DIR}/mindspore-lite/minddata/dataset/include/dataset/constants.h @@ -257,10 +258,10 @@ if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full") DESTINATION ${SECUREC_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) else() if((MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE) AND MSLITE_ENABLE_ACL) - install(FILES ${TOP_DIR}/mindspore-lite/minddata/dataset/include/dataset/vision_ascend.h - DESTINATION ${MIND_DATA_INC_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) - install(FILES ${BUILD_DIR}/minddata/kernels-dvpp-image/utils/libdvpp_utils.so - DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + install(FILES ${TOP_DIR}/mindspore-lite/minddata/dataset/include/dataset/vision_ascend.h + DESTINATION ${MIND_DATA_INC_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + install(FILES ${BUILD_DIR}/minddata/kernels-dvpp-image/utils/libdvpp_utils.so + DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) endif() install(FILES ${BUILD_DIR}/minddata/libminddata-lite.so DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) @@ -279,20 +280,21 @@ if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full") DESTINATION ${MIND_DATA_INC_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") endif() -if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "wrapper") +if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "wrapper" + AND NOT(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)) install(DIRECTORY ${TOP_DIR}/mindspore-lite/minddata/dataset/include/ DESTINATION ${MIND_DATA_INC_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "vision.h" EXCLUDE) if(PLATFORM_ARM64) install(FILES ${BUILD_DIR}/minddata/libminddata-lite.so DESTINATION - ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) install(FILES ${JPEGTURBO_LIB_LIST} DESTINATION ${TURBO_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) elseif(PLATFORM_ARM32) install(FILES ${BUILD_DIR}/minddata/libminddata-lite.so DESTINATION - ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) install(FILES ${JPEGTURBO_LIB_LIST} DESTINATION ${TURBO_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) else() install(FILES ${BUILD_DIR}/minddata/libminddata-lite.so DESTINATION - ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) install(FILES ${jpeg_turbo_LIBPATH}/libjpeg.so.62.4.0 DESTINATION ${TURBO_DIR}/lib RENAME libjpeg.so.62 COMPONENT ${RUNTIME_COMPONENT_NAME}) install(FILES ${jpeg_turbo_LIBPATH}/libturbojpeg.so.0.3.0 DESTINATION ${TURBO_DIR}/lib RENAME libturbojpeg.so.0 @@ -300,26 +302,27 @@ if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "wrapper") endif() endif() -if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "lite") +if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "lite" + AND NOT(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)) install(DIRECTORY ${TOP_DIR}/mindspore-lite/minddata/dataset/include/ DESTINATION ${MIND_DATA_INC_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") if(PLATFORM_ARM64) install(FILES ${BUILD_DIR}/minddata/libminddata-lite.so DESTINATION - ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) install(FILES ${TOP_DIR}/third_party/libjpeg-turbo/lib/libjpeg.so DESTINATION ${TURBO_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) install(FILES ${TOP_DIR}/third_party/libjpeg-turbo/lib/libturbojpeg.so DESTINATION ${TURBO_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) elseif(PLATFORM_ARM32) install(FILES ${BUILD_DIR}/minddata/libminddata-lite.so DESTINATION - ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) install(FILES ${TOP_DIR}/third_party/libjpeg-turbo/lib/libjpeg.so DESTINATION ${TURBO_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) install(FILES ${TOP_DIR}/third_party/libjpeg-turbo/lib/libturbojpeg.so DESTINATION ${TURBO_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) else() install(FILES ${BUILD_DIR}/minddata/libminddata-lite.so DESTINATION - ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) install(FILES ${TOP_DIR}/third_party/libjpeg-turbo/lib/libjpeg.so.62.4.0 DESTINATION ${TURBO_DIR}/lib RENAME libjpeg.so.62 COMPONENT ${RUNTIME_COMPONENT_NAME}) install(FILES ${TOP_DIR}/third_party/libjpeg-turbo/lib/libturbojpeg.so.0.3.0 @@ -327,7 +330,8 @@ if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "lite") endif() endif() -if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "lite_cv") +if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "lite_cv" AND + NOT(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)) if(PLATFORM_ARM64) install(DIRECTORY ${TOP_DIR}/mindspore-lite/minddata/dataset/kernels/image/lite_cv DESTINATION ${MIND_DATA_INC_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") @@ -337,7 +341,7 @@ if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "lite_cv") install(DIRECTORY ${TOP_DIR}/mindspore-lite/minddata/dataset/kernels/image/lite_cv DESTINATION ${MIND_DATA_INC_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") install(FILES ${BUILD_DIR}/minddata/libminddata-lite.so DESTINATION - ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) else() install(DIRECTORY ${TOP_DIR}/mindspore-lite/minddata/dataset/kernels/image/lite_cv DESTINATION ${MIND_DATA_INC_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") @@ -409,20 +413,20 @@ if(PLATFORM_ARM64) install(FILES ${BUILD_DIR}/src/extendrt/convert/libruntime_convert_plugin.so DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) if(MSLITE_ENABLE_ACL) - install(FILES ${BUILD_DIR}/src/extendrt/kernel/ascend/libascend_kernel_plugin.so + install(FILES ${BUILD_DIR}/src/extendrt/delegate/ascend_acl/libascend_acl_plugin.so DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) if(NOT MSLITE_SIMPLEST_CLOUD_INFERENCE) install(FILES ${BUILD_DIR}/src/extendrt/delegate/ascend_ge/libascend_ge_plugin.so DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) endif() - install(FILES ${BUILD_DIR}/src/extendrt/cxx_api/llm_engine/libllm_engine_plugin.so - DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + # install(FILES ${TOP_DIR}/mindspore-lite/build/src/extendrt/cxx_api/llm_engine/libllm_engine_plugin.so + # DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) __install_ascend_tbe_and_aicpu() __install_ascend_ascendc() endif() if(MSLITE_GPU_BACKEND STREQUAL tensorrt) install(FILES ${BUILD_DIR}/src/extendrt/delegate/tensorrt/libtensorrt_plugin.so - DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) endif() else() install(FILES ${BUILD_DIR}/src/${MINDSPORE_LITE_LIB_NAME}.so DESTINATION ${RUNTIME_LIB_DIR} @@ -430,7 +434,7 @@ if(PLATFORM_ARM64) install(FILES ${BUILD_DIR}/src/${MINDSPORE_LITE_LIB_NAME}.a DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) if(MSLITE_ENABLE_ACL) - install(FILES ${BUILD_DIR}/src/litert/kernel/ascend/libascend_kernel_plugin.so + install(FILES ${BUILD_DIR}/src/extendrt/delegate/ascend_acl/libascend_acl_plugin.so DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) endif() endif() @@ -534,11 +538,11 @@ if(PLATFORM_ARM64) endif() if(MSLITE_ENABLE_ACL) set(LITE_ACL_DIR ${BUILD_DIR}/tools/converter/adapter/acl) - install(FILES ${LITE_ACL_DIR}/mslite_shared_lib/libmslite_shared_lib.so - DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) +# install(FILES ${LITE_ACL_DIR}/mslite_shared_lib/libmslite_shared_lib.so +# DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) if(MSLITE_ENABLE_RUNTIME_CONVERT) - install(FILES ${LITE_ACL_DIR}/mslite_shared_lib/libmslite_shared_lib.so - DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + # install(FILES ${LITE_ACL_DIR}/mslite_shared_lib/libmslite_shared_lib.so + # DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) install(FILES ${glog_LIBPATH}/${glog_name} DESTINATION ${RUNTIME_LIB_DIR} RENAME libmindspore_glog.so.0 COMPONENT ${RUNTIME_COMPONENT_NAME}) install(TARGETS mindspore_core mindspore_ops DESTINATION ${CONVERTER_ROOT_DIR}/lib @@ -669,20 +673,20 @@ elseif(PLATFORM_ARM32) install(FILES ${BUILD_DIR}/src/extendrt/convert/libruntime_convert_plugin.so DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) if(MSLITE_ENABLE_ACL) - install(FILES ${BUILD_DIR}/src/extendrt/kernel/ascend/libascend_kernel_plugin.so + install(FILES ${BUILD_DIR}/src/extendrt/delegate/ascend_acl/libascend_acl_plugin.so DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) if(NOT MSLITE_SIMPLEST_CLOUD_INFERENCE) install(FILES ${BUILD_DIR}/src/extendrt/delegate/ascend_ge/libascend_ge_plugin.so DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) endif() - install(FILES ${BUILD_DIR}/src/extendrt/cxx_api/llm_engine/libllm_engine_plugin.so - DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) +# install(FILES ${BUILD_DIR}/src/extendrt/cxx_api/llm_engine/libllm_engine_plugin.so +# DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) __install_ascend_tbe_and_aicpu() __install_ascend_ascendc() endif() if(MSLITE_GPU_BACKEND STREQUAL tensorrt) install(FILES ${BUILD_DIR}/src/extendrt/delegate/tensorrt/libtensorrt_plugin.so - DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) endif() else() install(FILES ${BUILD_DIR}/src/${MINDSPORE_LITE_LIB_NAME}.so DESTINATION ${RUNTIME_LIB_DIR} @@ -690,7 +694,7 @@ elseif(PLATFORM_ARM32) install(FILES ${BUILD_DIR}/src/${MINDSPORE_LITE_LIB_NAME}.a DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) if(MSLITE_ENABLE_ACL) - install(FILES ${BUILD_DIR}/src/litert/kernel/ascend/libascend_kernel_plugin.so + install(FILES ${BUILD_DIR}/src/extendrt/delegate/ascend_acl/libascend_acl_plugin.so DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) endif() endif() @@ -878,14 +882,14 @@ else() install(FILES ${BUILD_DIR}/src/extendrt/convert/libruntime_convert_plugin.so DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) if(MSLITE_ENABLE_ACL) - install(FILES ${BUILD_DIR}/src/extendrt/kernel/ascend/libascend_kernel_plugin.so + install(FILES ${BUILD_DIR}/src/extendrt/delegate/ascend_acl/libascend_acl_plugin.so DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) if(NOT MSLITE_SIMPLEST_CLOUD_INFERENCE) install(FILES ${BUILD_DIR}/src/extendrt/delegate/ascend_ge/libascend_ge_plugin.so DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) endif() - install(FILES ${BUILD_DIR}/src/extendrt/cxx_api/llm_engine/libllm_engine_plugin.so - DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + # install(FILES ${TOP_DIR}/mindspore-lite/build/src/extendrt/cxx_api/llm_engine/libllm_engine_plugin.so + # DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) __install_ascend_tbe_and_aicpu() __install_ascend_ascendc() if(MSLITE_ASCEND_TARGET) @@ -901,7 +905,7 @@ else() endif() if(MSLITE_GPU_BACKEND STREQUAL tensorrt) install(FILES ${BUILD_DIR}/src/extendrt/delegate/tensorrt/libtensorrt_plugin.so - DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) endif() else() install(FILES ${BUILD_DIR}/src/${MINDSPORE_LITE_LIB_NAME}.so DESTINATION ${RUNTIME_LIB_DIR} @@ -909,7 +913,7 @@ else() install(FILES ${BUILD_DIR}/src/${MINDSPORE_LITE_LIB_NAME}.a DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) if(MSLITE_ENABLE_ACL) - install(FILES ${BUILD_DIR}/src/litert/kernel/ascend/libascend_kernel_plugin.so + install(FILES ${BUILD_DIR}/src/extendrt/delegate/ascend_acl/libascend_acl_plugin.so DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) endif() endif() @@ -973,11 +977,11 @@ else() if(MSLITE_ENABLE_ACL) set(LITE_ACL_DIR ${BUILD_DIR}/tools/converter/adapter/acl) - install(FILES ${LITE_ACL_DIR}/mslite_shared_lib/libmslite_shared_lib.so - DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) +# install(FILES ${LITE_ACL_DIR}/mslite_shared_lib/libmslite_shared_lib.so +# DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) if(MSLITE_ENABLE_RUNTIME_CONVERT) - install(FILES ${LITE_ACL_DIR}/mslite_shared_lib/libmslite_shared_lib.so - DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) +# install(FILES ${LITE_ACL_DIR}/mslite_shared_lib/libmslite_shared_lib.so +# DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) install(FILES ${glog_LIBPATH}/${glog_name} DESTINATION ${RUNTIME_LIB_DIR} RENAME libmindspore_glog.so.0 COMPONENT ${RUNTIME_COMPONENT_NAME}) install(TARGETS mindspore_core mindspore_ops DESTINATION ${RUNTIME_LIB_DIR} @@ -1057,14 +1061,14 @@ else() if(NOT (MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)) install(TARGETS cropper RUNTIME DESTINATION ${CROPPER_ROOT_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) install(FILES ${BUILD_DIR}/tools/cropper/cropper_mapping_cpu.cfg - DESTINATION ${CROPPER_ROOT_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + DESTINATION ${CROPPER_ROOT_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) install(FILES ${BUILD_DIR}/tools/cropper/cropper_mapping_gpu.cfg - DESTINATION ${CROPPER_ROOT_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + DESTINATION ${CROPPER_ROOT_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) install(FILES ${BUILD_DIR}/tools/cropper/cropper_mapping_npu.cfg - DESTINATION ${CROPPER_ROOT_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + DESTINATION ${CROPPER_ROOT_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) if(SUPPORT_TRAIN) install(FILES ${BUILD_DIR}/tools/cropper/cropper_mapping_cpu_train.cfg - DESTINATION ${CROPPER_ROOT_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + DESTINATION ${CROPPER_ROOT_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) endif() endif() endif() diff --git a/mindspore b/mindspore new file mode 160000 index 0000000000000000000000000000000000000000..21d4121389d099a511df6caccb951ef0a7a36e32 --- /dev/null +++ b/mindspore @@ -0,0 +1 @@ +Subproject commit 21d4121389d099a511df6caccb951ef0a7a36e32 diff --git a/mindspore-lite/CMakeLists.txt b/mindspore-lite/CMakeLists.txt index 593d9c0fa68e01a9e84ad41a3b1c978250b103d2..78e3c3c9542513ae196dd8261b2ffd02bf09c1b6 100644 --- a/mindspore-lite/CMakeLists.txt +++ b/mindspore-lite/CMakeLists.txt @@ -12,9 +12,7 @@ include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/secure_option.cmake) include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/compile_link_option.cmake) # =================================== [Lite build options] =================================== - # === part 1. basic options === - # kernel option(MSLITE_ENABLE_SSE "enable SSE instruction set, only x86_64 support" off) option(MSLITE_ENABLE_AVX "enable AVX instruction set, only x86_64 support" off) @@ -42,7 +40,6 @@ option(MSLITE_ENABLE_RUNTIME_PASS "enable runtime pass" on) option(MSLITE_ENABLE_BFC_MEMORY "enable distribute BFC memory" off) option(MSLITE_ENABLE_RUNTIME_GLOG "enable runtime glog" off) -option(MSLITE_EXPORT_COMPUTE_IR "export graph structure dotfile when debug mode" off) option(MSLITE_ENABLE_MODEL_PRE_INFERENCE "enable model do pre-inference when Build interface is called" off) # tools @@ -56,9 +53,10 @@ option(MSLITE_ENABLE_MODEL_OBF "enable model obfuscation" off) # pre/post process option(MSLITE_ENABLE_OPENCV "enable opencv" on) -set(MSLITE_MINDDATA_IMPLEMENT "full" CACHE STRING "minddata mode, \ +if(NOT((MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE))) + set(MSLITE_MINDDATA_IMPLEMENT "full" CACHE STRING "minddata mode, \ currently supported : off/full/wrapper/lite/lite_cv") - +endif() # tests option(MSLITE_ENABLE_TESTCASES "enable testcase" off) option(MSLITE_ENABLE_COVERAGE "enable code coverage" off) @@ -94,10 +92,8 @@ option(MSLITE_ENABLE_ACL_QUANT_PARAM "enable ACL_QUANT_PARAM" off) # runtime option(MSLITE_ENABLE_SHARING_MODEL_WEIGHT "enable sharing model weight" off) -option(MSLITE_ENABLE_SERVER_INFERENCE "enable inference on server" off) option(MSLITE_ENABLE_DYNAMIC_THREAD_DISTRIBUTE "enable distribute thread dynamically" off) option(MSLITE_ENABLE_AUTO_PARALLEL "enable automatic parallelism" on) -option(MSLITE_ENABLE_PARALLEL_INFERENCE "enable parallel inference interface" off) option(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE "enable cloud and device fusion inference architecture" off) option(MSLITE_ENABLE_CLOUD_INFERENCE "alias to MSLITE_ENABLE_CLOUD_FUSION_INFERENCE" off) @@ -124,9 +120,6 @@ if(DEFINED ENV{MSLITE_ENABLE_CLOUD_INFERENCE}) set(MSLITE_ENABLE_CLOUD_INFERENCE $ENV{MSLITE_ENABLE_CLOUD_INFERENCE}) endif() -if(DEFINED ENV{MSLITE_GPU_BACKEND}) - set(MSLITE_GPU_BACKEND $ENV{MSLITE_GPU_BACKEND}) -endif() if(DEFINED ENV{MSLITE_REGISTRY_DEVICE}) set(MSLITE_REGISTRY_DEVICE $ENV{MSLITE_REGISTRY_DEVICE}) endif() @@ -147,20 +140,15 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE) set(MSLITE_ENABLE_TRAIN OFF) endif() -if(DEFINED ENV{MSLITE_ENABLE_SERVER_INFERENCE}) - set(MSLITE_ENABLE_SERVER_INFERENCE $ENV{MSLITE_ENABLE_SERVER_INFERENCE}) -endif() if(MSLITE_ENABLE_SERVER_INFERENCE OR MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE) set(MSLITE_ENABLE_DYNAMIC_THREAD_DISTRIBUTE on) set(MSLITE_ENABLE_BFC_MEMORY on) - set(MSLITE_ENABLE_PARALLEL_INFERENCE on) set(MSLITE_ENABLE_SHARING_MODEL_WEIGHT on) set(MSLITE_ENABLE_RUNTIME_GLOG on) set(MSLITE_ENABLE_AVX512 on) set(MSLITE_ENABLE_CAPTURE_SIGNALS off) set(MSLITE_ENABLE_PACKAGE_WHEEL on) - add_compile_definitions(SERVER_INFERENCE) endif() if(DEFINED ENV{MSLITE_ENABLE_PACKAGE_WHEEL}) @@ -244,7 +232,8 @@ endif() if(DEFINED ENV{MSLITE_ENABLE_OPENCV}) set(MSLITE_ENABLE_OPENCV $ENV{MSLITE_ENABLE_OPENCV}) endif() -if(DEFINED ENV{MSLITE_MINDDATA_IMPLEMENT}) +if(DEFINED ENV{MSLITE_MINDDATA_IMPLEMENT} AND +NOT(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)) set(MSLITE_MINDDATA_IMPLEMENT $ENV{MSLITE_MINDDATA_IMPLEMENT}) endif() if(DEFINED ENV{MSLITE_TARGET_SITEAI}) @@ -300,11 +289,6 @@ if(DEFINED ENV{MSLITE_ENABLE_GITEE_MIRROR}) set(MSLITE_ENABLE_GITEE_MIRROR $ENV{MSLITE_ENABLE_GITEE_MIRROR}) endif() -if(MSLITE_EXPORT_COMPUTE_IR) - set(MSLITE_EXPORT_COMPUTE_IR $ENV{MSLITE_EXPORT_COMPUTE_IR}) - add_compile_definitions(ENABLE_DRAW) -endif() - if(MSLITE_ENABLE_GITEE_MIRROR) set(ENABLE_GITEE ON) endif() @@ -382,7 +366,6 @@ elseif(PLATFORM_ARM32) set(MSLITE_GPU_BACKEND "off") endif() elseif(WIN32) -# set(MSLITE_GPU_BACKEND "off") else() if(${MSLITE_REGISTRY_DEVICE} STREQUAL "SD3403" AND (NOT MSLITE_ENABLE_ACL)) set(MSLITE_ENABLE_DPICO_ATC_ADAPTER on) @@ -406,7 +389,7 @@ if(PLATFORM_ARM64 OR PLATFORM_ARM32) if(NOT MACHINE_LINUX_ARM64) set(MSLITE_ENABLE_CONVERTER off) endif() - if(MSLITE_ENABLE_SERVER_INFERENCE) + if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE) set(MSLITE_ENABLE_RUNTIME_GLOG on) else() set(MSLITE_ENABLE_RUNTIME_GLOG off) @@ -446,14 +429,6 @@ if(MSLITE_ENABLE_BFC_MEMORY) add_compile_definitions(BFC_MEMORY) endif() -if(DEFINED ENV{MSLITE_ENABLE_PARALLEL_INFERENCE}) - set(MSLITE_ENABLE_PARALLEL_INFERENCE $ENV{MSLITE_ENABLE_PARALLEL_INFERENCE}) -endif() - -if(MSLITE_ENABLE_PARALLEL_INFERENCE) - add_compile_definitions(PARALLEL_INFERENCE) -endif() - if(DEFINED ENV{MSLITE_ENABLE_SHARING_MODEL_WEIGHT}) set(MSLITE_ENABLE_SHARING_MODEL_WEIGHT $ENV{MSLITE_ENABLE_SHARING_MODEL_WEIGHT}) endif() @@ -472,9 +447,6 @@ endif() if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE) set(MSLITE_ENABLE_RUNTIME_GLOG on) set(MSLITE_ENABLE_TRAIN off) - if(NOT MSLITE_MINDDATA_IMPLEMENT STREQUAL "off") - add_compile_definitions(ENABLE_MINDDATA_PYTHON) - endif() endif() if(MSLITE_ENABLE_TRAIN AND NOT MSLITE_ENABLE_WEIGHT_DECODE) @@ -592,10 +564,8 @@ message(STATUS "\tMSLITE_ENABLE_SPARSE_COMPUTE = \t${MSLITE_ENABLE message(STATUS "\tMSLITE_ENABLE_RUNTIME_CONVERT = \t${MSLITE_ENABLE_RUNTIME_CONVERT}") message(STATUS "\tMSLITE_ENABLE_RUNTIME_GLOG = \t${MSLITE_ENABLE_RUNTIME_GLOG}") message(STATUS "\tMSLITE_ENABLE_COVERAGE = \t${MSLITE_ENABLE_COVERAGE}") -message(STATUS "\tMSLITE_ENABLE_SERVER_INFERENCE = \t${MSLITE_ENABLE_SERVER_INFERENCE}") message(STATUS "\tMSLITE_ENABLE_DYNAMIC_THREAD_DISTRIBUTE = \t${MSLITE_ENABLE_DYNAMIC_THREAD_DISTRIBUTE}") message(STATUS "\tMSLITE_ENABLE_BFC_MEMORY = \t${MSLITE_ENABLE_BFC_MEMORY}") -message(STATUS "\tMSLITE_ENABLE_PARALLEL_INFERENCE = \t${MSLITE_ENABLE_PARALLEL_INFERENCE}") message(STATUS "\tMSLITE_ENABLE_SHARING_MODEL_WEIGHT = \t${MSLITE_ENABLE_SHARING_MODEL_WEIGHT}") message(STATUS "\tMSLITE_ENABLE_KERNEL_EXECUTOR = \t${MSLITE_ENABLE_KERNEL_EXECUTOR}") message(STATUS "\tMSLITE_ENABLE_CLOUD_FUSION_INFERENCE = \t${MSLITE_ENABLE_CLOUD_FUSION_INFERENCE}") @@ -603,7 +573,6 @@ message(STATUS "\tMSLITE_ENABLE_CLOUD_INFERENCE = \t${MSLITE_ENABLE message(STATUS "\tMSLITE_ENABLE_CAPTURE_SIGNALS = \t${MSLITE_ENABLE_CAPTURE_SIGNALS}") message(STATUS "\tMSLITE_ENABLE_MODEL_OBF = \t${MSLITE_ENABLE_MODEL_OBF}") message(STATUS "\tMSLITE_ENABLE_MODEL_PRE_INFERENCE = \t${MSLITE_ENABLE_MODEL_PRE_INFERENCE}") -message(STATUS "\tMSLITE_EXPORT_COMPUTE_IR = \t${MSLITE_EXPORT_COMPUTE_IR}") message(STATUS "\tMSLITE_ENABLE_PACKAGE_WHEEL = \t${MSLITE_ENABLE_PACKAGE_WHEEL}") message(STATUS "\tMSLITE_ENABLE_OPENCV = \t${MSLITE_ENABLE_OPENCV}") message(STATUS "\tMSLITE_TARGET_SITEAI = \t${MSLITE_TARGET_SITEAI}") @@ -743,11 +712,10 @@ set(CORE_DIR ${TOP_DIR}/mindspore/mindspore/core) set(CORE_INC_DIR ${TOP_DIR}/mindspore/mindspore/core/include) set(CCSRC_DIR ${TOP_DIR}/mindspore/mindspore/ccsrc) set(OPS_DIR ${TOP_DIR}/mindspore/mindspore/ops) -set(NNACL_DIR ${OPS_DIR}/kernel/cpu/nnacl) +set(NNACL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/ops/kernel/cpu/nnacl) if(PLATFORM_MCU) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wno-incompatible-pointer-types") -# set(MSLITE_DEPS_CMSIS on) add_subdirectory(${NNACL_DIR} build/nnacl) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/converter/micro/cmake/cortex-m/ build) include(${TOP_DIR}/cmake/package_lite.cmake) @@ -767,7 +735,9 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src/litert/kernel/cpu) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src/) include_directories(${TOP_DIR}/third_party) include_directories(${CMAKE_BINARY_DIR}) -include_directories(${CMAKE_CURRENT_SOURCE_DIR}/minddata/dataset) +if(NOT (MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)) + include_directories(${CMAKE_CURRENT_SOURCE_DIR}/minddata/dataset) +endif() include(${TOP_DIR}/cmake/utils.cmake) include(${TOP_DIR}/cmake/dependency_utils.cmake) include(${TOP_DIR}/cmake/dependency_securec.cmake) @@ -802,9 +772,7 @@ if(MSLITE_ENABLE_COREML) include(${TOP_DIR}/cmake/external_libs/protobuf_arm.cmake) endif() -if(MSLITE_ENABLE_CONVERTER OR MSLITE_MINDDATA_IMPLEMENT STREQUAL "full" OR MSLITE_MINDDATA_IMPLEMENT STREQUAL "wrapper" - OR MSLITE_ENABLE_TOOLS OR MSLITE_ENABLE_KERNEL_EXECUTOR) - # include(${TOP_DIR}/cmake/external_libs/json.cmake) +if(MSLITE_ENABLE_CONVERTER OR MSLITE_ENABLE_TOOLS OR MSLITE_ENABLE_KERNEL_EXECUTOR) set(MSLITE_DEPS_JSON on) endif() @@ -833,7 +801,6 @@ else() endif() if(MSLITE_GPU_BACKEND STREQUAL cuda) - # add_definitions(-DGPU_CUDA) add_compile_definitions(SUPPORT_GPU) set(SUPPORT_CUDA on) if(DEFINED ENV{CUDA_HOME}) @@ -914,13 +881,11 @@ if(MSLITE_ENABLE_MODEL_OBF) endif() if((MSLITE_ENABLE_CONVERTER OR MSLITE_ENABLE_RUNTIME_GLOG)) - # include(${TOP_DIR}/cmake/external_libs/glog.cmake) set(MSLITE_DEPS_GLOG on) endif() if(MSLITE_ENABLE_CONVERTER OR MSLITE_ENABLE_KERNEL_EXECUTOR) find_required_package(Patch) - # include(${TOP_DIR}/cmake/external_libs/protobuf.cmake) set(MSLITE_DEPS_PROTOBUF on) endif() @@ -950,11 +915,6 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE) set(MSLITE_DEPS_LIBEVENT on) set(MSLITE_DEPS_PYBIND11 on) endif() - if(SUPPORT_TENSORRT) - if(MSLITE_ENABLE_GRAPH_KERNEL) - set(MSLITE_DEPS_AKG_TENSORRT on) - endif() - endif() endif() # In core/CMakelists, core link mindspore::crypto, and crypto need the OPENSSL lib, @@ -999,12 +959,6 @@ ms_build_flatbuffers_lite(FBS_FILES ${CMAKE_CURRENT_SOURCE_DIR}/schema/ fbs_inne "inner") if(MSLITE_ENABLE_CONVERTER) - # find_required_package(Patch) - # include_directories(${PYTHON_INCLUDE_DIRS}) - # if(NOT ENABLE_CLOUD_AND_LITE) - # include(${TOP_DIR}/cmake/external_libs/opencv.cmake) - # include(${TOP_DIR}/cmake/external_libs/eigen.cmake) - # endif() add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/converter) endif() @@ -1021,12 +975,6 @@ if(MSLITE_ENABLE_FP16) endif() endif() -# if(MSLITE_ENABLE_MODEL_ENCRYPTION AND NOT ENABLE_CLOUD_AND_LITE) -# find_required_package(Patch) -# include(${TOP_DIR}/cmake/external_libs/openssl.cmake) -# add_compile_definitions(ENABLE_OPENSSL) -# endif() - if(MSLITE_ENABLE_MINDRT) add_compile_definitions(ENABLE_MINDRT) endif() @@ -1062,7 +1010,8 @@ if(NOT PLATFORM_ARM) endif() if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "lite" OR MSLITE_MINDDATA_IMPLEMENT STREQUAL "full" - OR MSLITE_MINDDATA_IMPLEMENT STREQUAL "wrapper" OR MSLITE_MINDDATA_IMPLEMENT STREQUAL "lite_cv") + OR MSLITE_MINDDATA_IMPLEMENT STREQUAL "wrapper" OR MSLITE_MINDDATA_IMPLEMENT STREQUAL "lite_cv" + AND NOT (MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)) add_compile_definitions(ENABLE_ANDROID) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/minddata) endif() @@ -1073,7 +1022,7 @@ if(ANDROID_NDK_TOOLCHAIN_INCLUDED OR TARGET_OHOS_LITE OR TARGET_HIMIX) endif() add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/src) -add_subdirectory(${OPS_DIR}/kernel/cpu/nnacl build) +add_subdirectory(${NNACL_DIR} build) if(MSLITE_ENABLE_TOOLS) if(NOT MSLITE_COMPILE_TWICE) diff --git a/mindspore-lite/OWNERS b/mindspore-lite/OWNERS index f7bfd7af770f240ed8f5730dc2cbf89bd5801c18..20f5876c770d4a517d434f067bcba3db4a5e037c 100644 --- a/mindspore-lite/OWNERS +++ b/mindspore-lite/OWNERS @@ -1,8 +1,12 @@ approvers: -- jjfeing +- zhaizhiqiang +- zhang_xue_tong # +- jpc_chenjianping +- wang_shaocong - YeFeng_24 -- fatmouse007fatmouse007 +reviewers: - xu_anyue +- fatmouse007fatmouse007 # zhuguodong options: no_parent_owners: true diff --git a/mindspore-lite/README.md b/mindspore-lite/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d06613c96f23b53cdb27e407e69e7514ce1624e1 --- /dev/null +++ b/mindspore-lite/README.md @@ -0,0 +1,68 @@ +[查看中文](./README_CN.md) + +## What Is MindSpore Lite + +MindSpore lite is a high-performance, lightweight open source reasoning framework that can be used to meet the needs of AI applications on mobile devices. MindSpore Lite focuses on how to deploy AI technology more effectively on devices. It has been integrated into HMS (Huawei Mobile Services) to provide inferences for applications such as image classification, object detection and OCR. MindSpore Lite will promote the development and enrichment of the AI software/hardware application ecosystem. + +MindSpore Lite Architecture + +For more details please check out our [MindSpore Lite Architecture Guide](https://www.mindspore.cn/lite/docs/en/master/reference/architecture_lite.html). + +### MindSpore Lite features + +1. Cooperative work with MindSpore training + - Provides training, optimization, and deployment. + - The unified IR realizes the device-cloud AI application integration. + +2. Lightweight + - Provides model compress, which could help to improve performance as well. + - Provides the ultra-lightweight reasoning solution MindSpore Micro to meet the deployment requirements in extreme environments such as smart watches and headphones. + +3. High-performance + - The built-in high-performance kernel computing library NNACL supports multiple convolution optimization algorithms such as Slide window, im2col+gemm, winograde, etc. + - Assembly code to improve performance of kernel operators. Supports CPU, GPU, and NPU. +4. Versatility + - Supports IOS, Android. + - Supports Lite OS. + - Supports mobile device, smart screen, pad, and IOT devices. + - Supports third party models such as TFLite, CAFFE and ONNX. + +## MindSpore Lite AI deployment procedure + +1. Model selection and personalized training + + Select a new model or use an existing model for incremental training using labeled data. When designing a model for mobile device, it is necessary to consider the model size, accuracy and calculation amount. + + The MindSpore team provides a series of pre-training models used for image classification, object detection. You can use these pre-trained models in your application. + + The pre-trained model provided by MindSpore: [Image Classification](https://download.mindspore.cn/model_zoo/official/lite/). More models will be provided in the feature. + + MindSpore allows you to retrain pre-trained models to perform other tasks. + +2. Model converter and optimization + + If you use MindSpore or a third-party model, you need to use [MindSpore Lite Model Converter Tool](https://www.mindspore.cn/lite/docs/en/master/converter/converter_tool.html) to convert the model into MindSpore Lite model. The MindSpore Lite model converter tool provides the converter of TensorFlow Lite, Caffe, ONNX to MindSpore Lite model, fusion and quantization could be introduced during convert procedure. + + MindSpore also provides a tool to convert models running on IoT devices . + +3. Model deployment + + This stage mainly realizes model deployment, including model management, deployment, operation and maintenance monitoring, etc. + +4. Inference + + Load the model and perform inference. [Inference](https://www.mindspore.cn/lite/docs/en/master/infer/runtime_cpp.html) is the process of running input data through the model to get output. + + MindSpore provides pre-trained model that can be deployed on mobile device [example](https://www.mindspore.cn/lite/examples/en). + +## MindSpore Lite benchmark test result + +We test a couple of networks on HUAWEI Mate40 (Hisilicon Kirin9000e) mobile phone, and get the test results below for your reference. + +| NetWork | Thread Number | Average Run Time(ms) | +| ------------------- | ------------- | -------------------- | +| basic_squeezenet | 4 | 6.415 | +| inception_v3 | 4 | 36.767 | +| mobilenet_v1_10_224 | 4 | 4.936 | +| mobilenet_v2_10_224 | 4 | 3.644 | +| resnet_v2_50 | 4 | 25.071 | diff --git a/mindspore-lite/README_CN.md b/mindspore-lite/README_CN.md new file mode 100644 index 0000000000000000000000000000000000000000..4aca59d617135a6c4aadd9f253158c54d36c118c --- /dev/null +++ b/mindspore-lite/README_CN.md @@ -0,0 +1,76 @@ + +[View English](./README.md) + +## MindSpore Lite介绍 + +MindSpore Lite是MindSpore推出的端云协同的、轻量化、高性能AI推理框架,用于满足越来越多的端测AI应用需求。MindSpore Lite聚焦AI技术在端侧设备上的部署和运行,已经在华为HMS和智能终端的图像分类、目标识别、人脸识别、文字识别等应用中广泛使用,未来MindSpore Lite将与MindSpore AI社区一起,致力于丰富AI软硬件应用生态。 + +MindSpore Lite Architecture + +欲了解更多详情,请查看我们的[MindSpore Lite 总体架构](https://www.mindspore.cn/lite/docs/zh-CN/master/reference/architecture_lite.html)。 + +## MindSpore Lite技术特点 + +1. 端云协同提供一站式训练和推理 + + - 提供模型训练、模型转换优化、部署和推理端到端流程。 + - 统一的IR实现端云AI应用一体化。 + +2. 超轻量 + + - 支持模型量化压缩,模型更小跑得更快。 + - 提供超轻量的推理解决方案MindSpore Micro,满足智能手表、耳机等极限环境下的部署要求。 + +3. 高性能 + + - 自带的高性能内核计算库NNACL,支持Sliding Windows、Im2Col+GEMM、Winograd等多种卷积优化算法。 + - 汇编级优化,支持CPU、GPU、NPU异构调度,最大化发挥硬件算力,最小化推理时延和功耗。 + +4. 广覆盖 + + - 支持iOS、Android等手机操作系统。 + - 支持LiteOS嵌入式操作系统。 + - 支持手机、大屏、平板、IoT等各种智能设备上的AI应用。 + - 支持MindSpore/TensorFlow Lite/Caffe/ONNX模型,方便用户快速部署。 + +## MindSpore Lite AI部署流程 + +1. 模型选择和个性化训练 + + 包括选择新模型或对已有模型,利用标注数据进行增量训练。面向端侧设计模型时,需要考虑模型大小、精度和计算量。 + + MindSpore团队提供了一系列预训练模型,用于解决图像分类、目标检测等场景的学习问题。可以在您的应用程序中使用这些预训练模型对应的终端模型。 + + MindSpore提供的预训练模型:[图像分类(Image Classification)](https://download.mindspore.cn/model_zoo/official/lite/)。后续MindSpore团队会增加更多的预置模型。 + + MindSpore允许您重新训练预训练模型,以执行其他任务。比如:使用预训练的图像分类模型,可以重新训练来识别新的图像类型。 + +2. 模型转换/优化 + + 如果您使用MindSpore或第三方训练的模型,需要使用[MindSpore Lite模型转换工具](https://www.mindspore.cn/lite/docs/zh-CN/master/converter/converter_tool.html)转换成MindSpore Lite模型格式。MindSpore Lite模型转换工具不仅提供了将TensorFlow Lite、Caffe、ONNX等模型格式转换为MindSpore Lite模型格式,还提供了算子融合、量化等功能。 + + MindSpore还提供了将IoT设备上运行的模型转换成.C代码的生成工具。 + + 经过上述两个部署,您已经得到端侧可以部署的模型。 + +3. 模型部署 + + 这个阶段主要实现模型部署,包括模型管理、部署和运维监控等。 + +4. 模型推理 + + 主要完成模型推理工作,即加载模型,完成模型相关的所有计算。[推理](https://www.mindspore.cn/lite/docs/zh-CN/master/infer/runtime_cpp.html)是通过模型运行输入数据,获取预测的过程。 + + MindSpore提供了预训练模型部署在智能终端的[样例](https://www.mindspore.cn/lite/examples)。 + +## MindSpore Lite性能参考数据 + +我们在HUAWEI Mate40(Hisilicon Kirin9000e)手机上,测试了一组端侧常见网络的性能数据,供您参考: + +| 网络 | 线程数 | 平均推理时间(毫秒) | +| ------------------- | ----- | --------------- | +| basic_squeezenet | 4 | 6.415 | +| inception_v3 | 4 | 36.767 | +| mobilenet_v1_10_224 | 4 | 4.936 | +| mobilenet_v2_10_224 | 4 | 3.644 | +| resnet_v2_50 | 4 | 25.071 | diff --git a/mindspore-lite/build_lite.sh b/mindspore-lite/build_lite.sh new file mode 100755 index 0000000000000000000000000000000000000000..f449d7b1cd4db49cbd05db9883e16055459d59ec --- /dev/null +++ b/mindspore-lite/build_lite.sh @@ -0,0 +1,1009 @@ +#!/bin/bash +# Copyright 2021-2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +set -e + +checkndk() { + if [ "${ANDROID_NDK}" ]; then + echo -e "\e[31mANDROID_NDK=$ANDROID_NDK \e[0m" + else + echo -e "\e[31mplease set ANDROID_NDK in environment variable for example: export ANDROID_NDK=/root/usr/android-ndk-r20b/ \e[0m" + exit 1 + fi +} + +check_Hi35xx() { + if [[ "X${HI35XX_SDK_PATH}" == "X" ]]; then + echo "error: to compile the runtime package of Hi35XX, you need to set HI35XX_SDK_PATH to declare the path of Hi35XX sdk." + exit 1 + else + cp -r ${HI35XX_SDK_PATH}/third_patry ${BASEPATH}/mindspore-lite/providers/nnie + fi +} + +get_version() { + VERSION_STR=$(cat ${BASEPATH}/version.txt) +} + +get_cpack_dir() { + local pack_dir="_CPack_Packages" + local ms_pack="mindspore-lite-${VERSION_STR}" + local cpack_dir="${INSTALL_PREFIX}/${pack_dir}" + + local linux_x86_path="${cpack_dir}/Linux/TGZ/${ms_pack}/linux-x64" + local linux_cortex_path="${cpack_dir}/Linux/TGZ/${ms_pack}/none-cortex-m7" + local linux_aarch64_path="${cpack_dir}/Linux/TGZ/${ms_pack}/linux-aarch64" + local linux_aarch32_path="${cpack_dir}/Linux/TGZ/${ms_pack}/linux-aarch32" + local android_aarch64_path="${cpack_dir}/Android/TGZ/${ms_pack}/android-aarch64" + local android_aarch32_path="${cpack_dir}/Android/TGZ/${ms_pack}/android-aarch32" + + CPACK_PACKAGE_DIR="" + if [ -d "${linux_x86_path}" ]; then + CPACK_PACKAGE_DIR=${linux_x86_path} + elif [ -d "${linux_cortex_path}" ]; then + CPACK_PACKAGE_DIR=${linux_cortex_path} + elif [ -d "${linux_aarch64_path}" ]; then + CPACK_PACKAGE_DIR=${linux_aarch64_path} + elif [ -d "${linux_aarch32_path}" ]; then + CPACK_PACKAGE_DIR=${linux_aarch32_path} + elif [ -d "${android_aarch64_path}" ]; then + CPACK_PACKAGE_DIR=${android_aarch64_path} + elif [ -d "${android_aarch32_path}" ]; then + CPACK_PACKAGE_DIR=${android_aarch32_path} + else + echo "Please check cpack path." + fi + echo "Using cpack output path: ${CPACK_PACKAGE_DIR}" +} + +write_commit_file() { + COMMIT_STR=$(git log -1 | grep commit) + echo ${COMMIT_STR} > "${BASEPATH}/mindspore-lite/build/.commit_id" +} + +build_lite_jni_and_jar() { + local JNI_CMAKE_ARGS=$1 + local local_lite_platform=$2 + echo "============ mindspore lite: start building jni and jar ${VERSION_STR} ${local_lite_platform} ============" + export MSLITE_ENABLE_RUNTIME_CONVERT=off + + local PKG_NAME + local NATIVE_PATH_ARCH + local RESOURCE_PATH_ARCH + if [[ "${local_lite_platform}" == "x86_64" ]]; then + PKG_NAME=mindspore-lite-${VERSION_STR}-linux-x64 + NATIVE_PATH_ARCH=linux_x86 + RESOURCE_PATH_ARCH=linux_x86_64 + elif [[ "${local_lite_platform}" == "aarch64" ]]; then + PKG_NAME=mindspore-lite-${VERSION_STR}-linux-aarch64 + NATIVE_PATH_ARCH=linux_aarch64 + RESOURCE_PATH_ARCH=linux_aarch64 + JNI_CMAKE_ARGS="${JNI_CMAKE_ARGS} -DMACHINE_LINUX_ARM64=on" + else + echo "platform: ${local_lite_platform} not support building jni." + exit 0 + fi + + # copy so + local is_train=on + local is_cloud_infer=off + cd ${INSTALL_PREFIX}/ + + rm -rf ${PKG_NAME} + tar -zxf ${INSTALL_PREFIX}/${PKG_NAME}.tar.gz + rm -rf ${LITE_JAVA_PATH}/java/${NATIVE_PATH_ARCH}/libs/ && mkdir -pv ${LITE_JAVA_PATH}/java/${NATIVE_PATH_ARCH}/libs/ + rm -rf ${LITE_JAVA_PATH}/native/libs/${NATIVE_PATH_ARCH}/ && mkdir -pv ${LITE_JAVA_PATH}/native/libs/${NATIVE_PATH_ARCH}/ + cp ./${PKG_NAME}/runtime/lib/*.so* ${LITE_JAVA_PATH}/java/${NATIVE_PATH_ARCH}/libs/ + cp ./${PKG_NAME}/runtime/lib/*.so* ${LITE_JAVA_PATH}/native/libs/${NATIVE_PATH_ARCH}/ + local train_so=${PKG_NAME}/runtime/lib/libmindspore-lite-train.so + if [ ! -f "$train_so" ]; then + echo "libmindspore-lite-train.so so not exist" + is_train=off + fi + if [[ "X$is_train" = "Xon" ]]; then + cp ./${PKG_NAME}/runtime/third_party/libjpeg-turbo/lib/*.so* ${LITE_JAVA_PATH}/java/${NATIVE_PATH_ARCH}/libs/ + cp ./${PKG_NAME}/runtime/third_party/libjpeg-turbo/lib/*.so* ${LITE_JAVA_PATH}/native/libs/${NATIVE_PATH_ARCH}/ + fi + # prepare + cd ${BASEPATH}/mindspore-lite/build + rm -rf java/jni && mkdir -pv java/jni + cd java/jni + # copy glog lib and headers + LIB_GLOG="libmindspore_glog.so*" + if [ -f "`echo ${INSTALL_PREFIX}/${PKG_NAME}/runtime/third_party/glog/${LIB_GLOG}`" ]; then + cp ${INSTALL_PREFIX}/${PKG_NAME}/runtime/third_party/glog/*.so* ${LITE_JAVA_PATH}/java/${NATIVE_PATH_ARCH}/libs/ + cp ${INSTALL_PREFIX}/${PKG_NAME}/runtime/third_party/glog/*.so* ${LITE_JAVA_PATH}/native/libs/${NATIVE_PATH_ARCH}/ + else + echo "no glog lib found." + fi + if [ -d "${BASEPATH}/output/tmp/${PKG_NAME}/runtime/include/third_party/glog" ]; then + rm -rf jni_include && mkdir jni_include + cp ${BASEPATH}/output/tmp/${PKG_NAME}/runtime/include/third_party/glog ./jni_include -r + else + echo "no glog hesders found." + fi + # build jni so + echo "cmake ${JNI_CMAKE_ARGS} -DSUPPORT_TRAIN=${is_train} ${LITE_JAVA_PATH}/native/" + cmake ${JNI_CMAKE_ARGS} -DSUPPORT_TRAIN=${is_train} "${LITE_JAVA_PATH}/native/" + make -j$THREAD_NUM + if [[ $? -ne 0 ]]; then + echo "---------------- mindspore lite: build jni ${local_lite_platform} failed----------------" + exit 1 + fi + rm -f ${LITE_JAVA_PATH}/src/main/resources/com/mindspore/lite/${RESOURCE_PATH_ARCH}/*.so* + cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/${NATIVE_PATH_ARCH}/libs/ + cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/${NATIVE_PATH_ARCH}/ + cp ./libmindspore-lite-jni.so ${INSTALL_PREFIX}/${PKG_NAME}/runtime/lib/ + + if [[ "$MSLITE_ENABLE_CLOUD_FUSION_INFERENCE" == "ON" || "$MSLITE_ENABLE_CLOUD_FUSION_INFERENCE" == "on" ]];then + is_cloud_infer=on + fi + if [[ "$MSLITE_ENABLE_CLOUD_INFERENCE" == "ON" || "$MSLITE_ENABLE_CLOUD_INFERENCE" == "on" ]];then + is_cloud_infer=on + fi + + RUNTIME_LIB_DIR="${BASEPATH}/output/tmp/${PKG_NAME}/runtime/lib" + if [[ -d ${RUNTIME_LIB_DIR} ]]; then + if [ "$(ls -A ${RUNTIME_LIB_DIR})" ]; then + cp ${RUNTIME_LIB_DIR}/*.so ${LITE_JAVA_PATH}/src/main/resources/com/mindspore/lite/${RESOURCE_PATH_ARCH}/ + fi + fi + CONVERTER_LIB_DIR="${BASEPATH}/output/tmp/${PKG_NAME}/tools/converter/lib" + if [[ "X$is_cloud_infer" = "Xon" && -d ${CONVERTER_LIB_DIR} ]]; then + if [ "$(ls -A ${CONVERTER_LIB_DIR})" ]; then + cp ${CONVERTER_LIB_DIR}/*.so ${LITE_JAVA_PATH}/src/main/resources/com/mindspore/lite/${RESOURCE_PATH_ARCH}/ + fi + fi + + if [ -f "`echo ${BASEPATH}/output/tmp/${PKG_NAME}/tools/converter/lib/${LIB_GLOG}`" ]; then + cp ${BASEPATH}/output/tmp/${PKG_NAME}/tools/converter/lib/libmindspore_glog.so* ${LITE_JAVA_PATH}/src/main/resources/com/mindspore/lite/${RESOURCE_PATH_ARCH}/libmindspore_glog.so + fi + LIB_OPENCV_IMGPROC="libopencv_imgproc.so*" + if [[ "X$is_cloud_infer" = "Xon" && -f "`echo ${BASEPATH}/output/tmp/${PKG_NAME}/tools/converter/lib/${LIB_OPENCV_IMGPROC}`" ]]; then + cp ${BASEPATH}/output/tmp/${PKG_NAME}/tools/converter/lib/${LIB_OPENCV_IMGPROC} ${LITE_JAVA_PATH}/src/main/resources/com/mindspore/lite/${RESOURCE_PATH_ARCH}/libopencv_imgproc.so + fi + LIB_OPENCV_CORE="libopencv_core.so*" + if [[ "X$is_cloud_infer" = "Xon" && -f "`echo ${BASEPATH}/output/tmp/${PKG_NAME}/tools/converter/lib/${LIB_OPENCV_CORE}`" ]]; then + cp ${BASEPATH}/output/tmp/${PKG_NAME}/tools/converter/lib/${LIB_OPENCV_CORE} ${LITE_JAVA_PATH}/src/main/resources/com/mindspore/lite/${RESOURCE_PATH_ARCH}/libopencv_core.so + fi + LIB_OPENCV_IMGCODECS="libopencv_imgcodecs.so*" + if [[ "X$is_cloud_infer" = "Xon" && -f "`echo ${BASEPATH}/output/tmp/${PKG_NAME}/tools/converter/lib/${LIB_OPENCV_IMGCODECS}`" ]]; then + cp ${BASEPATH}/output/tmp/${PKG_NAME}/tools/converter/lib/${LIB_OPENCV_IMGCODECS} ${LITE_JAVA_PATH}/src/main/resources/com/mindspore/lite/${RESOURCE_PATH_ARCH}/libopencv_imgcodecs.so + fi + LIB_DNNL="libdnnl.so*" + if [[ "X$is_cloud_infer" = "Xon" && -f "`echo ${BASEPATH}/output/tmp/${PKG_NAME}/runtime/third_party/dnnl/${LIB_DNNL}`" ]]; then + cp ${BASEPATH}/output/tmp/${PKG_NAME}/runtime/third_party/dnnl/${LIB_DNNL} ${LITE_JAVA_PATH}/src/main/resources/com/mindspore/lite/${RESOURCE_PATH_ARCH}/libdnnl.so + fi + + if [[ "X$is_train" = "Xon" ]]; then + cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/${NATIVE_PATH_ARCH}/libs/ + cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/${NATIVE_PATH_ARCH}/ + cp ./libmindspore-lite-train-jni.so ${INSTALL_PREFIX}/${PKG_NAME}/runtime/lib/ + cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/src/main/resources/com/mindspore/lite/${RESOURCE_PATH_ARCH}/ + fi + + chmod -R 750 ${LITE_JAVA_PATH}/src/main/resources/com/mindspore/lite + cd ${LITE_JAVA_PATH}/java + rm -rf gradle .gradle gradlew gradlew.bat + local gradle_version="" + gradle_version=`gradle --version | grep Gradle | awk '{print$2}'` + if [[ ${gradle_version} == '6.6.1' ]]; then + gradle_command=gradle + else + gradle wrapper --gradle-version 6.6.1 --distribution-type all + gradle_command=${LITE_JAVA_PATH}/java/gradlew + fi + + # build jar + ${gradle_command} clean -p ${LITE_JAVA_PATH}/ + if [[ "${ENABLE_ASAN}" == "ON" || "${ENABLE_ASAN}" == "on" ]] ; then + ${gradle_command} releaseJar -p ${LITE_JAVA_PATH}/ -x test --info + else + if [[ "${MSLITE_ENABLE_TESTCASES}" == "ON" || "${MSLITE_ENABLE_TESTCASES}" == "on" ]] && [[ "${MSLITE_ENABLE_CLOUD_INFERENCE}" != "on" ]] ; then + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${LITE_JAVA_PATH}/native/libs/${NATIVE_PATH_ARCH}/ + ${gradle_command} releaseJar -p ${LITE_JAVA_PATH}/ --info + else + ${gradle_command} releaseJar -p ${LITE_JAVA_PATH}/ -x test --info + fi + fi + cp ${LITE_JAVA_PATH}/build/lib/jar/*.jar ${INSTALL_PREFIX}/${PKG_NAME}/runtime/lib/ + + # package + cd ${INSTALL_PREFIX} + rm -rf ${PKG_NAME}.tar.gz ${PKG_NAME}.tar.gz.sha256 + tar czf ${PKG_NAME}.tar.gz ${PKG_NAME} + sha256sum ${PKG_NAME}.tar.gz > ${PKG_NAME}.tar.gz.sha256 + rm -rf ${LITE_JAVA_PATH}/java/${NATIVE_PATH_ARCH}/libs/ + rm -rf ${LITE_JAVA_PATH}/native/libs/${NATIVE_PATH_ARCH}/ + echo "---------------- mindspore lite jni and jar: build success ----------------" +} + +get_python_tag(){ + local minor_version=`python3 -V 2>&1 | awk '{print $2}' | awk -F '.' '{print $2}'` || true + local py_tags="cp${python_version}${minor_version}-cp${python_version}${minor_version}" + if [[ "${minor_version}" == "7" ]]; then + py_tags="cp37-cp37m" + fi + echo ${py_tags} +} + +build_akg() { + local python_version=`python3 -V 2>&1 | awk '{print $2}' | awk -F '.' '{print $1}'` || true + if [[ "${python_version}" == "3" ]]; then + echo "---------------- AKG: build start ----------------" + AKG_DIR=${BASEPATH}/akg + BUILD_DIR=${BASEPATH}/mindspore-lite/build + write_version() { + cd ${AKG_DIR} + if [ ! -e ${AKG_DIR}/version.txt ]; then + version=$(git branch | sed -n '/\* /s///p') + if [ -z "${version}" ]; then + version='master' + fi + echo ${version#r} >${BUILD_DIR}/akg/version.txt + else + cp ${AKG_DIR}/version.txt ${BUILD_DIR}/akg + fi + cp ${AKG_DIR}/setup.py ${BUILD_DIR}/akg/akg_setup.py + } + cd ${BUILD_DIR} + mkdir -pv akg + write_version + cd ${BUILD_DIR}/akg + mkdir -pv build + cd build + if [[ "X${DEBUG_MODE}" == "Xon" ]]; then + AKG_CMAKE_ARGS="${AKG_CMAKE_ARGS} -DCMAKE_BUILD_TYPE=Debug -DUSE_AKG_LOG=1" + fi + # check llvm version for akg + AKG_CMAKE_ARGS="${AKG_CMAKE_ARGS} -DUSE_RPC=ON " + USE_LLVM=`bash ${BASEPATH}/scripts/build/akg_find_llvm.sh` + if [[ "X$USE_LLVM" == "Xon" ]]; then + graph_kernel_cfg="-DAKG_USE_LLVM=ON ${graph_kernel_cfg}" + AKG_CMAKE_ARGS="${AKG_CMAKE_ARGS} -DUSE_LLVM=ON" + fi + if [[ ("X${MSLITE_GPU_BACKEND}" == "Xtensorrt") && $1 == "x86_64" ]]; then + AKG_CMAKE_ARGS="${AKG_CMAKE_ARGS} -DUSE_CUDA=ON" + graph_kernel_cfg="-DAKG_USE_CUDA=ON ${graph_kernel_cfg}" + fi + if [[ ("X${MSLITE_ENABLE_ACL}" == "Xon") ]]; then + AKG_CMAKE_ARGS="${AKG_CMAKE_ARGS} -DENABLE_D=ON" + graph_kernel_cfg="-DAKG_ENABLE_D=ON ${graph_kernel_cfg}" + fi + echo "cmake ${AKG_CMAKE_ARGS}" + cmake -S ${AKG_DIR} ${AKG_CMAKE_ARGS} -B . + cmake --build . -j$THREAD_NUM + cmake --install . + cd ${BUILD_DIR}/akg + [ -d dist ] && rm -rf dist + python3 ${BUILD_DIR}/akg/akg_setup.py bdist_wheel + cd dist + for file in *.whl; do + file_name=$(basename $file) + prefix=$(echo $file_name | cut -d '-' -f 1-2) + PY_TAGS=`get_python_tag` + akg_pkg_name="${prefix}-${PY_TAGS}-linux_$1.whl" + mv $file ${akg_pkg_name} + sha256sum -b "$akg_pkg_name" >"$akg_pkg_name.sha256" + akg_package_path=dist/${akg_pkg_name} + done + echo "---------------- AKG: build end ----------------" + cd ${BUILD_DIR} + else + echo -e "\e[31mPython3 not found, so AKG will not be compiled. \e[0m" + fi +} + +build_python_wheel_package() { + local python_version=`python3 -V 2>&1 | awk '{print $2}' | awk -F '.' '{print $1}'` || true + if [[ "${python_version}" == "3" ]]; then + cd ${BASEPATH}/mindspore-lite/build/ + local lite_wrapper_so=`ls python/_c_lite_wrapper*.so` || true + if [ ! -f "${lite_wrapper_so}" ]; then + echo "error: can not find python so." + return 0 + fi + mkdir -pv package/mindspore_lite/lib/ + cp ../python/api/* package/mindspore_lite/ + local pkg_name=mindspore-lite-${VERSION_STR}-linux-$1 + if [[ "$1" == "x86_64" ]]; then + local pkg_name=mindspore-lite-${VERSION_STR}-linux-x64 + fi + if [[ ("${MSLITE_ENABLE_CLOUD_FUSION_INFERENCE}" == "on") || ("${MSLITE_ENABLE_CLOUD_INFERENCE}" == "on") ]]; then + cp src/extendrt/*.so package/mindspore_lite/lib/ + find src/extendrt/delegate/graph_executor/litert/ -name "*.so" -exec cp '{}' package/mindspore_lite/lib/ \; + find src/extendrt/unified_executor/ -name "*.so" -exec cp '{}' package/mindspore_lite/lib/ \; + find src/extendrt/convert/ -name "*.so" -exec cp '{}' package/mindspore_lite/lib/ \; + if [[ "X${MSLITE_ENABLE_ACL}" == "Xon" ]]; then + # cp src/extendrt/kernel/ascend/*.so package/mindspore_lite/lib/ + local dvpp_utils=minddata/kernels-dvpp-image/utils/libdvpp_utils.so + if [ -f ${dvpp_utils} ]; then + cp ${dvpp_utils} package/mindspore_lite/lib/ + fi + fi + if [ -f "${INSTALL_PREFIX}/${pkg_name}/runtime/lib/libtransformer-shared.so" ]; then + cp ${INSTALL_PREFIX}/${pkg_name}/runtime/lib/libtransformer-shared.so package/mindspore_lite/lib/ + fi + if [ -f "${INSTALL_PREFIX}/${pkg_name}/runtime/lib/libmindspore_graph_ir.so" ]; then + cp ${INSTALL_PREFIX}/${pkg_name}/runtime/lib/libmindspore_graph_ir.so package/mindspore_lite/lib/ + fi + if [ -f "${INSTALL_PREFIX}/${pkg_name}/runtime/lib/libtensorrt_plugin.so" ]; then + cp ${INSTALL_PREFIX}/${pkg_name}/runtime/lib/libtensorrt_plugin.so package/mindspore_lite/lib/ + fi + if [ -f "${INSTALL_PREFIX}/${pkg_name}/runtime/lib/libascend_ge_plugin.so" ]; then + cp ${INSTALL_PREFIX}/${pkg_name}/runtime/lib/libascend_ge_plugin.so package/mindspore_lite/lib/ + fi + if [ -f "${INSTALL_PREFIX}/${pkg_name}/runtime/lib/libllm_engine_plugin.so" ]; then + cp ${INSTALL_PREFIX}/${pkg_name}/runtime/lib/libllm_engine_plugin.so package/mindspore_lite/lib/ + fi + else + if [[ "X${MSLITE_ENABLE_ACL}" == "Xon" ]]; then + cp src/litert/kernel/ascend/*.so package/mindspore_lite/lib/ + fi + cp src/*.so package/mindspore_lite/lib/ + fi + if [ -d "${INSTALL_PREFIX}/${pkg_name}/runtime/third_party/glog" ]; then + cp ${INSTALL_PREFIX}/${pkg_name}/runtime/third_party/glog/*.so* package/mindspore_lite/lib/ + fi + if [ -d "${INSTALL_PREFIX}/${pkg_name}/runtime/third_party/dnnl" ]; then + cp ${INSTALL_PREFIX}/${pkg_name}/runtime/third_party/dnnl/*.so* package/mindspore_lite/lib/ + fi + if [ -d "${CPACK_PACKAGE_DIR}/${pkg_name}/tools/custom_kernels" ]; then + cp -rf ${CPACK_PACKAGE_DIR}/${pkg_name}/tools/custom_kernels package/mindspore_lite/ + fi + if [ -d "${INSTALL_PREFIX}/${pkg_name}/tools/converter/lib" ]; then + cp ${INSTALL_PREFIX}/${pkg_name}/tools/converter/lib/*.so* package/mindspore_lite/lib/ + fi + cp python/*.so package/mindspore_lite/lib/ + mkdir -pv package/mindspore_lite/include/registry/ + if [ -d "${INSTALL_PREFIX}/${pkg_name}/runtime/include/api" ]; then + cp -rf ${INSTALL_PREFIX}/${pkg_name}/runtime/include/api package/mindspore_lite/include/ + fi + if [ -d "${INSTALL_PREFIX}/${pkg_name}/runtime/include/mindapi" ]; then + cp -rf ${INSTALL_PREFIX}/${pkg_name}/runtime/include/mindapi package/mindspore_lite/include/ + fi + if [ -f "${INSTALL_PREFIX}/${pkg_name}/tools/converter/include/registry/converter_context.h" ]; then + cp ${INSTALL_PREFIX}/${pkg_name}/tools/converter/include/registry/converter_context.h package/mindspore_lite/include/registry/ + fi + if [ -f "${INSTALL_PREFIX}/${pkg_name}/tools/converter/include/converter.h" ]; then + cp ${INSTALL_PREFIX}/${pkg_name}/tools/converter/include/converter.h package/mindspore_lite/include/ + fi + cp .commit_id package/mindspore_lite/ + echo "__version__ = '${VERSION_STR}'" > package/mindspore_lite/version.py + cp ../python/setup.py package/ + cd package + rm -rf dist/mindspore_lite-*.whl + python3 setup.py bdist_wheel ${BASEPATH} + py_tags=`get_python_tag` + local whl_name=mindspore_lite-${VERSION_STR}-${py_tags}-linux_$1.whl + cp dist/mindspore_lite-*.whl ${BASEPATH}/output/${whl_name} + cd ${BASEPATH}/output/ + sha256sum ${whl_name} > ${whl_name}.sha256 + else + echo -e "\e[31mPython3 not found, so Python API will not be compiled. \e[0m" + fi +} + +build_lite() { + LITE_CMAKE_ARGS=${CMAKE_ARGS} + [ -n "${BASEPATH}" ] && rm -rf ${BASEPATH}/output + echo "============ Start building MindSpore Lite ${VERSION_STR} ============" + local local_lite_platform=${LITE_PLATFORM} + if [[ "${LITE_ENABLE_AAR}" == "on" ]]; then + local_lite_platform=$1 + mkdir -pv ${BASEPATH}/mindspore-lite/build/java + cd ${BASEPATH}/mindspore-lite/build/ + [ -n "${BASEPATH}" ] && find . -maxdepth 1 | grep -v java | grep '/' | xargs -I {} rm -rf {} + else + if [[ "${INC_BUILD}" == "off" ]]; then + [ -n "${BASEPATH}" ] && rm -rf ${BASEPATH}/mindspore-lite/build + fi + mkdir -pv ${BASEPATH}/mindspore-lite/build + fi + cd ${BASEPATH}/mindspore-lite/build + write_commit_file + + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DENABLE_ASAN=${ENABLE_ASAN} -DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX}" + + if [[ "$(uname)" == "Darwin" && "${local_lite_platform}" != "x86_64" ]]; then + LITE_CMAKE_ARGS=`echo $LITE_CMAKE_ARGS | sed 's/-DCMAKE_BUILD_TYPE=Debug/-DCMAKE_BUILD_TYPE=Release/g'` + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_ENABLE_TRAIN=off -DMSLITE_GPU_BACKEND=off -DMSLITE_ENABLE_NPU=off" + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_MINDDATA_IMPLEMENT=off" + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DENABLE_BITCODE=0 -G Xcode" + CMAKE_TOOLCHAIN_FILE=${BASEPATH}/cmake/lite_ios.cmake + fi + + BRANCH_NAME=nnie_r2.6_dev + if [[ ("${MSLITE_REGISTRY_DEVICE}" == "Hi3516D" || "${TOOLCHAIN_NAME}" == "himix200") && "${local_lite_platform}" == "arm32" ]]; then + TOOLCHAIN_NAME="himix200" + MSLITE_REGISTRY_DEVICE=Hi3516D + check_Hi35xx + MSLITE_COMPILE_TWICE=ON + twice_target='benchmark micro_nnie' + elif [[ "${MSLITE_REGISTRY_DEVICE}" == "Hi3559A" && "${local_lite_platform}" == "arm64" ]]; then + TOOLCHAIN_NAME="himix100" + check_Hi35xx + MSLITE_COMPILE_TWICE=ON + twice_target='benchmark micro_nnie' + elif [[ "${MSLITE_REGISTRY_DEVICE}" == "SD3403" && "${local_lite_platform}" == "arm64" ]]; then + TOOLCHAIN_NAME="mix210" + MSLITE_COMPILE_TWICE=ON + twice_target=benchmark + elif [[ "${MSLITE_REGISTRY_DEVICE}" == "Hi3519A" && "${local_lite_platform}" == "arm32" ]]; then + TOOLCHAIN_NAME="himix200" + check_Hi35xx + MSLITE_COMPILE_TWICE=ON + twice_target='benchmark micro_nnie' + elif [[ ("${MSLITE_ENABLE_NNIE}" == "on" || "${MSLITE_REGISTRY_DEVICE}" == "Hi3516D") && "${local_lite_platform}" == "x86_64" ]]; then + MSLITE_REGISTRY_DEVICE=Hi3516D + elif [[ "${MSLITE_MICRO_PLATFORM}" == cortex-m* && "${local_lite_platform}" == "x86_64" ]]; then + TOOLCHAIN_NAME=${MSLITE_MICRO_PLATFORM} + fi + + machine=`uname -m` + echo "machine:${machine}." + if [[ "${local_lite_platform}" == "arm32" ]]; then + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DPLATFORM_ARM32=on -DENABLE_NEON=on" + if [ "$(uname)" == "Darwin" ]; then + # CPU : iOS-aarch32 + pkg_name=mindspore-lite-${VERSION_STR}-ios-aarch32 + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DARCHS=armv7;armv7s" + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_MINDDATA_IMPLEMENT=off" + elif [[ "${TOOLCHAIN_NAME}" == "ohos-lite" ]]; then + # Hi3516D : OpenHarmony-aarch32 + CMAKE_TOOLCHAIN_FILE=${BASEPATH}/mindspore-lite/cmake/ohos-lite.toolchain.cmake + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DTOOLCHAIN_NAME=ohos-lite" + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_MINDDATA_IMPLEMENT=off" + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_ENABLE_FP16=off" + elif [[ "${TOOLCHAIN_NAME}" == "himix200" ]]; then + # Hi3516D : Linux-aarch32 + CMAKE_TOOLCHAIN_FILE=${BASEPATH}/mindspore-lite/cmake/himix200.toolchain.cmake + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DTOOLCHAIN_NAME=himix200" + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_MINDDATA_IMPLEMENT=off" + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_ENABLE_FP16=off -DMSLITE_ENABLE_TRAIN=off -DMSLITE_GPU_BACKEND=off" + elif [[ "${TOOLCHAIN_NAME}" == "ohos" ]]; then + pkg_name=mindspore-lite-${VERSION_STR}-ohos-arm32 + CMAKE_TOOLCHAIN_FILE=${OHOS_NDK}/build/cmake/ohos.toolchain.cmake + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DCMAKE_OHOS_NDK=${OHOS_NDK} -DOHOS_ARCH=armeabi-v7a -DOHOS_STL=c++_static -DTOOLCHAIN_NAME=${TOOLCHAIN_NAME}" + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_MINDDATA_IMPLEMENT=off" + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_ENABLE_FP16=off -DMSLITE_ENABLE_TRAIN=off -DMSLITE_GPU_BACKEND=off" + else + # CPU : Android-aarch32 + checkndk + if [ -n "${MS_CCACHE_PATH}" ]; then + echo "use ${MS_CCACHE_PATH} to speed up compilation." + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DANDROID_CCACHE=${MS_CCACHE_PATH}" + fi + export PATH=${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64/bin:${ANDROID_NDK}/toolchains/arm-linux-androideabi-4.9/prebuilt/linux-x86_64/bin:${PATH} + CMAKE_TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_MINDDATA_IMPLEMENT=full" + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_ENABLE_FP16=on" + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DANDROID_NATIVE_API_LEVEL=19 -DANDROID_NDK=${ANDROID_NDK} -DANDROID_ABI=armeabi-v7a -DANDROID_TOOLCHAIN_NAME=clang -DANDROID_STL=${MSLITE_ANDROID_STL}" + fi + elif [[ "${local_lite_platform}" == "arm64" ]]; then + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DPLATFORM_ARM64=on -DENABLE_NEON=on" + if [ "$(uname)" == "Darwin" ]; then + # CPU : iOS-aarch64 + pkg_name=mindspore-lite-${VERSION_STR}-ios-aarch64 + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DARCHS=arm64" + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_ENABLE_FP16=on" + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_MINDDATA_IMPLEMENT=off" + elif [[ "${TOOLCHAIN_NAME}" == "himix100" ]]; then + # Hi3559A : + CMAKE_TOOLCHAIN_FILE=${BASEPATH}/mindspore-lite/cmake/himix100.toolchain.cmake + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DTOOLCHAIN_NAME=himix100" + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_MINDDATA_IMPLEMENT=off" + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_ENABLE_FP16=off -DMSLITE_ENABLE_TRAIN=off -DMSLITE_GPU_BACKEND=off" + elif [[ "${TOOLCHAIN_NAME}" == "mix210" ]]; then + # SD3403 : + CMAKE_TOOLCHAIN_FILE=${BASEPATH}/mindspore-lite/cmake/mix210.toolchain.cmake + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DTOOLCHAIN_NAME=mix210" + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_MINDDATA_IMPLEMENT=off" + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_ENABLE_FP16=on -DMSLITE_ENABLE_TRAIN=off -DMSLITE_GPU_BACKEND=off" + elif [[ "${TOOLCHAIN_NAME}" == "ohos" ]]; then + pkg_name=mindspore-lite-${VERSION_STR}-ohos-aarch64 + CMAKE_TOOLCHAIN_FILE=${OHOS_NDK}/build/cmake/ohos.toolchain.cmake + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DCMAKE_OHOS_NDK=${OHOS_NDK} -DOHOS_ARCH=arm64-v8a -DOHOS_STL=c++_static -DTOOLCHAIN_NAME=${TOOLCHAIN_NAME}" + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_MINDDATA_IMPLEMENT=off" + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_ENABLE_FP16=on -DMSLITE_ENABLE_TRAIN=off -DMSLITE_GPU_BACKEND=off" + elif [[ "${AOS_SDK}" ]]; then + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_MINDDATA_IMPLEMENT=full" + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DTARGET_AOS_ARM=on" + if [[ "${TOOLCHAIN_NAME}" == "gcc" ]]; then + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DTOOLCHAIN_NAME=gcc" + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_ENABLE_FP16=off" + fi + else + if [[ "${machine}" == "aarch64" ]]; then + # CPU : Linux-aarch64 + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMACHINE_LINUX_ARM64=on" + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_ENABLE_TRAIN=off" + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_GPU_BACKEND=off" + else + # CPU/GPU : Android-aarch64 + checkndk + if [ -n "${MS_CCACHE_PATH}" ]; then + echo "use ${MS_CCACHE_PATH} to speed up compilation." + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DANDROID_CCACHE=${MS_CCACHE_PATH}" + fi + export PATH=${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64/bin:${ANDROID_NDK}/toolchains/arm-linux-androideabi-4.9/prebuilt/linux-x86_64/bin:${PATH} + CMAKE_TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DANDROID_NATIVE_API_LEVEL=19 -DANDROID_NDK=${ANDROID_NDK} -DANDROID_ABI=arm64-v8a -DANDROID_TOOLCHAIN_NAME=aarch64-linux-android-clang -DANDROID_STL=${MSLITE_ANDROID_STL}" + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_MINDDATA_IMPLEMENT=full" + fi + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_ENABLE_FP16=on" + fi + elif [[ "$(uname)" == "Darwin" ]]; then + pkg_name=mindspore-lite-${VERSION_STR}-ios-simulator + CMAKE_TOOLCHAIN_FILE=${BASEPATH}/cmake/lite_ios.cmake + LITE_CMAKE_ARGS=`echo $LITE_CMAKE_ARGS | sed 's/-DCMAKE_BUILD_TYPE=Debug/-DCMAKE_BUILD_TYPE=Release/g'` + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DPLATFORM=SIMULATOR64 -DPLATFORM_ARM64=off -DENABLE_NEON=off -DMSLITE_ENABLE_TRAIN=off -DMSLITE_GPU_BACKEND=off -DMSLITE_ENABLE_NPU=off" + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_MINDDATA_IMPLEMENT=off" + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_ENABLE_TOOLS=off -DMSLITE_ENABLE_CONVERTER=off" + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -G Xcode .." + else + if [[ "${machine}" == "aarch64" ]]; then + echo "Use the '-I arm64' command when compiling MindSpore Lite on an aarch64 architecture system." + exit 1 + fi + if [[ "${TOOLCHAIN_NAME}" == cortex-m* ]]; then + CMAKE_TOOLCHAIN_FILE=${BASEPATH}/mindspore-lite/cmake/${TOOLCHAIN_NAME}.toolchain.cmake + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DTOOLCHAIN_NAME=${TOOLCHAIN_NAME}" + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DPLATFORM_MCU=on" + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_MINDDATA_IMPLEMENT=off -DMSLITE_ENABLE_TRAIN=off -DMSLITE_GPU_BACKEND=off -DMSLITE_ENABLE_TOOLS=off" + else + # CPU : Linux-x86_64 + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DPLATFORM_X86_64=on" + fi + fi + + if [[ "X$CMAKE_TOOLCHAIN_FILE" != "X" ]]; then + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DCMAKE_TOOLCHAIN_FILE=${CMAKE_TOOLCHAIN_FILE}" + fi + if [[ "X$MSLITE_REGISTRY_DEVICE" != "X" ]]; then + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_REGISTRY_DEVICE=${MSLITE_REGISTRY_DEVICE}" + fi + if [[ "X$MSLITE_COMPILE_TWICE" != "X" ]]; then + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} -DMSLITE_COMPILE_TWICE=${MSLITE_COMPILE_TWICE}" + fi + if [[ ("${local_lite_platform}" == "arm64" || "${local_lite_platform}" == "arm32") && "${TOOLCHAIN_NAME}" != "ohos" ]]; then + echo "default link libc++_static.a, export MSLITE_ANDROID_STL=c++_shared to link libc++_shared.so" + fi + if [[ ("X${MSLITE_ENABLE_CLOUD_FUSION_INFERENCE}" != "Xon") && ("X${MSLITE_ENABLE_CLOUD_INFERENCE}" != "Xon") ]]; then + ENABLE_AKG="off" + fi + if [[ ( ("${local_lite_platform}" == "x86_64" && "${machine}" == "x86_64") || ("${local_lite_platform}" == "arm64" && ${machine} == "aarch64") ) && ("${ENABLE_AKG}" == "on") ]]; then + akg_package_path="" + graph_kernel_cfg="" +# build_akg ${machine} + LITE_CMAKE_ARGS="${LITE_CMAKE_ARGS} ${graph_kernel_cfg} -DAKG_PKG_PATH=${akg_package_path}" + fi + echo "cmake ${LITE_CMAKE_ARGS} ${BASEPATH}/mindspore-lite/" + cmake ${LITE_CMAKE_ARGS} "${BASEPATH}/mindspore-lite/" + + if [[ "$(uname)" == "Darwin" && "${local_lite_platform}" != "x86_64" ]]; then + xcodebuild ONLY_ACTIVE_ARCH=NO -configuration Release -scheme mindspore-lite_static -target mindspore-lite_static -sdk iphoneos -quiet -UseModernBuildSystem=YES + elif [[ "$(uname)" == "Darwin" && "${local_lite_platform}" == "x86_64" ]]; then + xcodebuild ONLY_ACTIVE_ARCH=NO -configuration Release -scheme mindspore-lite_static -target mindspore-lite_static -sdk iphonesimulator -quiet -UseModernBuildSystem=YES + else + make -j$THREAD_NUM && make install + if [[ "X$MSLITE_COMPILE_TWICE" == "XON" ]]; then + if [[ "X$MSLITE_ENABLE_TOOLS" != "X" ]]; then + MSLITE_ENABLE_TOOLS=$(echo $MSLITE_ENABLE_TOOLS | tr '[a-z]' '[A-Z]') + fi + if [[ "X$MSLITE_ENABLE_TOOLS" != "XOFF" ]]; then + LITE_CMAKE_ARGS=`echo $LITE_CMAKE_ARGS | sed 's/-DMSLITE_COMPILE_TWICE=ON/-DMSLITE_COMPILE_TWICE=OFF/g'` + cp -r ${INSTALL_PREFIX}/mindspore*/runtime ${BASEPATH}/mindspore-lite/providers + PKG_PATH=${INSTALL_PREFIX}/`ls ${INSTALL_PREFIX}/` + echo "cmake ${LITE_CMAKE_ARGS} -DPKG_PATH=${PKG_PATH} ${BASEPATH}/mindspore-lite/" + cmake ${LITE_CMAKE_ARGS} -DPKG_PATH=${PKG_PATH} "${BASEPATH}/mindspore-lite/" + cmake --build "${BASEPATH}/mindspore-lite/build" --target ${twice_target} -j$THREAD_NUM + make install + fi + fi + + make package + get_cpack_dir + + local package_wheel=${MSLITE_ENABLE_PACKAGE_WHEEL} + if [[ ("${MSLITE_ENABLE_CLOUD_FUSION_INFERENCE}" == "on") || ("${MSLITE_ENABLE_CLOUD_INFERENCE}" == "on") || ("${MSLITE_ENABLE_SERVER_INFERENCE}" == "on") ]] && [[ ${MSLITE_ENABLE_PACKAGE_WHEEL} == "" ]]; then + package_wheel=on + fi + if [[ "${local_lite_platform}" == "x86_64" && "X$CMAKE_TOOLCHAIN_FILE" == "X" ]]; then + if [[ "${package_wheel}" == "on" ]]; then + build_python_wheel_package "x86_64" + fi + if [ "${JAVA_HOME}" ]; then + echo -e "\e[31mJAVA_HOME=$JAVA_HOME \e[0m" + build_lite_jni_and_jar "${CMAKE_ARGS}" "x86_64" + else + echo -e "\e[31mJAVA_HOME is not set, so jni and jar packages will not be compiled \e[0m" + echo -e "\e[31mIf you want to compile the JAR package, please set $JAVA_HOME. For example: export JAVA_HOME=/usr/lib/jvm/java-1.8.0-openjdk-amd64 \e[0m" + fi + elif [[ "${local_lite_platform}" == "arm64" ]] && [[ "${machine}" == "aarch64" ]]; then + if [[ "${package_wheel}" == "on" ]]; then + build_python_wheel_package "aarch64" + fi + if [ "${JAVA_HOME}" ]; then + echo -e "\e[31mJAVA_HOME=$JAVA_HOME \e[0m" + build_lite_jni_and_jar "${CMAKE_ARGS}" "aarch64" + else + echo -e "\e[31mJAVA_HOME is not set, so jni and jar packages will not be compiled \e[0m" + echo -e "\e[31mIf you want to compile the JAR package, please set $JAVA_HOME. For example: export JAVA_HOME=/usr/lib/jvm/java-1.8.0-openjdk-amd64 \e[0m" + fi + fi + fi + if [[ $? -ne 0 ]]; then + echo "---------------- mindspore lite: build failed ----------------" + exit 1 + else + if [ "$(uname)" == "Darwin" ]; then + mkdir -p ${BASEPATH}/output + cp -r ${BASEPATH}/mindspore-lite/build/src/Release-*/mindspore-lite.framework ${BASEPATH}/output/mindspore-lite.framework + # schema files are copied in shell script since they will not generated until xcodebuild is executed. + local schema_path=${BASEPATH}/output/mindspore-lite.framework/Headers/include/schema + mkdir -p ${schema_path} + cp ${BASEPATH}/mindspore-lite/build/schema/model_generated.h ${schema_path} + cp ${BASEPATH}/mindspore-lite/build/schema/ops_generated.h ${schema_path} + cp ${BASEPATH}/mindspore-lite/build/schema/ops_types_generated.h ${schema_path} + if [[ "${MSLITE_ENABLE_COREML}" == "ON" || "${MSLITE_ENABLE_COREML}" == "on" ]]; then + local protobuf_arm_lib=${BASEPATH}/mindspore-lite/build/_deps/protobuf_arm-src/_build/libprotobuf-lite.a + if [ ! -e "$protobuf_arm_lib" ]; then + local protobuf_arm_libpath=$(grep protobuf_arm_LIBPATH ${BASEPATH}/mindspore-lite/build/CMakeCache.txt | cut -d'=' -f2) + protobuf_arm_lib="${protobuf_arm_libpath}/libprotobuf-lite.a" + fi + if [ ! -e "$protobuf_arm_lib" ]; then + echo "failed to find libprotobuf-lite.a to build ios package" + exit 1 + fi + mkdir -p ${BASEPATH}/output/mindspore-lite.framework/third_party/protobuf + cp $protobuf_arm_lib ${BASEPATH}/output/mindspore-lite.framework/third_party/protobuf/ + fi + cd ${BASEPATH}/output + tar -zcvf ${pkg_name}.tar.gz mindspore-lite.framework/ + sha256sum ${pkg_name}.tar.gz > ${pkg_name}.tar.gz.sha256 + rm -r mindspore-lite.framework + else + mv ${INSTALL_PREFIX}/*.tar.gz* ${BASEPATH}/output/ + fi + + if [[ "${local_lite_platform}" == "x86_64" || "${local_lite_platform}" == "arm64" ]]; then + if [[ "${MSLITE_ENABLE_TESTCASES}" == "ON" || "${MSLITE_ENABLE_TESTCASES}" == "on" ]]; then + mkdir -pv ${BASEPATH}/mindspore-lite/test/do_test || true + if [[ ! "${MSLITE_ENABLE_CONVERTER}" || "${MSLITE_ENABLE_CONVERTER}" == "ON" || "${MSLITE_ENABLE_CONVERTER}" == "on" || "${MSLITE_ENABLE_CLOUD_INFERENCE}" == "on" ]]; then + cp ${INSTALL_PREFIX}/mindspore-lite*/tools/converter/lib/*.so* ${BASEPATH}/mindspore-lite/test/do_test || true + cp ${INSTALL_PREFIX}/mindspore-lite*/runtime/lib/*.so* ${BASEPATH}/mindspore-lite/test/do_test || true + fi + cp ${INSTALL_PREFIX}/mindspore-lite*/runtime/lib/*.so* ${BASEPATH}/mindspore-lite/test/do_test || true + if [ -d "${INSTALL_PREFIX}/mindspore-lite*/runtime/third_party/glog" ]; then + cp ${INSTALL_PREFIX}/mindspore-lite*/runtime/third_party/glog/*.so* ${BASEPATH}/mindspore-lite/test/do_test || true + fi + if [[ ! "${MSLITE_ENABLE_TRAIN}" || "${MSLITE_ENABLE_TRAIN}" == "ON" || "${MSLITE_ENABLE_TRAIN}" == "on" ]]; then + cp ${INSTALL_PREFIX}/mindspore-lite*/runtime/third_party/libjpeg-turbo/lib/*.so* ${BASEPATH}/mindspore-lite/test/do_test || true + fi + fi + fi + + rm -rf ${INSTALL_PREFIX:?}/ + if [[ "X$MSLITE_REGISTRY_DEVICE" != "X" ]] && [[ "${MSLITE_REGISTRY_DEVICE}" != "SD3403" ]]; then + local compile_nnie_script=${BASEPATH}/mindspore-lite/tools/providers/NNIE/Hi3516D/compile_nnie.sh + cd ${BASEPATH}/../ + if [[ "${local_lite_platform}" == "x86_64" ]]; then + bash ${compile_nnie_script} -I ${local_lite_platform} -b ${BRANCH_NAME} -j $THREAD_NUM + fi + if [[ $? -ne 0 ]]; then + echo "compile ${local_lite_platform} for nnie failed." + exit 1 + fi + fi + echo "---------------- mindspore lite: build success ----------------" + fi +} + +build_lite_arm64_and_jni() { + local ARM64_CMAKE_ARGS=${CMAKE_ARGS} + build_lite "arm64" + # copy arm64 so + local is_train=on + local pkg_name=mindspore-lite-${VERSION_STR}-android-aarch64 + cd "${BASEPATH}/mindspore-lite/build" + + rm -rf ${pkg_name} + tar -zxf ${BASEPATH}/output/${pkg_name}.tar.gz + rm -rf ${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/ && mkdir -p ${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/ + rm -rf ${LITE_JAVA_PATH}/native/libs/arm64-v8a/ && mkdir -p ${LITE_JAVA_PATH}/native/libs/arm64-v8a/ + cp ./${pkg_name}/runtime/lib/libmindspore-lite.so ${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/ + cp ./${pkg_name}/runtime/lib/libmindspore-lite.so ${LITE_JAVA_PATH}/native/libs/arm64-v8a/ + local train_so=$pkg_name/runtime/lib/libmindspore-lite-train.so + if [ ! -f "$train_so" ]; then + echo "not exist" + is_train=off + fi + if [[ "X$is_train" = "Xon" ]]; then + cp ./${pkg_name}/runtime/lib/libmindspore-lite*so* ${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/ + cp ./${pkg_name}/runtime/lib/libmindspore-lite*so* ${LITE_JAVA_PATH}/native/libs/arm64-v8a/ + fi + local minddata_so=${pkg_name}/runtime/lib/libminddata-lite.so + if [ -e "${minddata_so}" ]; then + cp ./${pkg_name}/runtime/lib/libminddata-lite.so ${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/ + cp ./${pkg_name}/runtime/lib/libminddata-lite.so ${LITE_JAVA_PATH}/native/libs/arm64-v8a/ + fi + local jpeg_turbo_dir=${pkg_name}/runtime/third_party/libjpeg-turbo/lib + if [ -e "$jpeg_turbo_dir" ]; then + cp ./${pkg_name}/runtime/third_party/libjpeg-turbo/lib/*.so* ${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/ + cp ./${pkg_name}/runtime/third_party/libjpeg-turbo/lib/*.so* ${LITE_JAVA_PATH}/native/libs/arm64-v8a/ + fi + local npu_so=${pkg_name}/runtime/third_party/hiai_ddk/lib/libhiai.so + if [ -e "${npu_so}" ]; then + cp ./${pkg_name}/runtime/third_party/hiai_ddk/lib/*.so* ${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/ + cp ./${pkg_name}/runtime/third_party/hiai_ddk/lib/*.so* ${LITE_JAVA_PATH}/native/libs/arm64-v8a/ + fi + # build jni so + [ -n "${BASEPATH}" ] && rm -rf java/jni && mkdir -pv java/jni + cd java/jni + if [ -n "${MS_CCACHE_PATH}" ]; then + echo "use ${MS_CCACHE_PATH} to speed up compilation." + ARM64_CMAKE_ARGS="$ARM64_CMAKE_ARGS -DANDROID_CCACHE=${MS_CCACHE_PATH}" + fi + cmake ${ARM64_CMAKE_ARGS} -DSUPPORT_TRAIN=${is_train} -DPLATFORM_ARM64=on \ + -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" -DANDROID_NATIVE_API_LEVEL="19" \ + -DANDROID_NDK="${ANDROID_NDK}" -DANDROID_ABI="arm64-v8a" -DANDROID_TOOLCHAIN_NAME="aarch64-linux-android-clang" \ + -DANDROID_STL=${MSLITE_ANDROID_STL} "${LITE_JAVA_PATH}/native/" + make -j$THREAD_NUM + if [[ $? -ne 0 ]]; then + echo "---------------- mindspore lite: build jni arm64 failed----------------" + exit 1 + fi + cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/ + cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/arm64-v8a/ + if [[ "X$is_train" = "Xon" ]]; then + cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/ + cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/arm64-v8a/ + fi +} + +build_lite_arm32_and_jni() { + local ARM32_CMAKE_ARGS=${CMAKE_ARGS} + build_lite "arm32" + # copy arm32 so + local is_train=on + local pkg_name=mindspore-lite-${VERSION_STR}-android-aarch32 + cd "${BASEPATH}/mindspore-lite/build" + + rm -rf ${pkg_name} + tar -zxf ${BASEPATH}/output/${pkg_name}.tar.gz + rm -rf ${LITE_JAVA_PATH}/java/app/libs/armeabi-v7a/ && mkdir -pv ${LITE_JAVA_PATH}/java/app/libs/armeabi-v7a/ + rm -rf ${LITE_JAVA_PATH}/native/libs/armeabi-v7a/ && mkdir -pv ${LITE_JAVA_PATH}/native/libs/armeabi-v7a/ + cp ./${pkg_name}/runtime/lib/libmindspore-lite.so ${LITE_JAVA_PATH}/java/app/libs/armeabi-v7a/ + cp ./${pkg_name}/runtime/lib/libmindspore-lite.so ${LITE_JAVA_PATH}/native/libs/armeabi-v7a/ + local train_so=$pkg_name/runtime/lib/libmindspore-lite-train.so + if [ ! -f "$train_so" ]; then + echo "not exist" + is_train=off + fi + if [[ "X$is_train" = "Xon" ]]; then + cp ./${pkg_name}/runtime/lib/libmindspore-lite*so* ${LITE_JAVA_PATH}/java/app/libs/armeabi-v7a/ + cp ./${pkg_name}/runtime/lib/libmindspore-lite*so* ${LITE_JAVA_PATH}/native/libs/armeabi-v7a/ + fi + local minddata_so=${pkg_name}/runtime/lib/libminddata-lite.so + if [ -e "${minddata_so}" ]; then + cp ./${pkg_name}/runtime/lib/libminddata-lite.so ${LITE_JAVA_PATH}/java/app/libs/armeabi-v7a/ + cp ./${pkg_name}/runtime/lib/libminddata-lite.so ${LITE_JAVA_PATH}/native/libs/armeabi-v7a/ + fi + local jpeg_turbo_dir=${pkg_name}/runtime/third_party/libjpeg-turbo/lib + if [ -e "$jpeg_turbo_dir" ]; then + cp ./${pkg_name}/runtime/third_party/libjpeg-turbo/lib/*.so* ${LITE_JAVA_PATH}/java/app/libs/armeabi-v7a/ + cp ./${pkg_name}/runtime/third_party/libjpeg-turbo/lib/*.so* ${LITE_JAVA_PATH}/native/libs/armeabi-v7a/ + fi + local npu_so=$pkg_name/runtime/third_party/hiai_ddk/lib/libhiai.so + if [ -e "$npu_so" ]; then + cp ./${pkg_name}/runtime/third_party/hiai_ddk/lib/*.so* ${LITE_JAVA_PATH}/java/app/libs/armeabi-v7a/ + cp ./${pkg_name}/runtime/third_party/hiai_ddk/lib/*.so* ${LITE_JAVA_PATH}/native/libs/armeabi-v7a/ + fi + # build jni so + [ -n "${BASEPATH}" ] && rm -rf java/jni && mkdir -pv java/jni + cd java/jni + if [ -n "${MS_CCACHE_PATH}" ]; then + echo "use ${MS_CCACHE_PATH} to speed up compilation." + ARM32_CMAKE_ARGS="$ARM32_CMAKE_ARGS -DANDROID_CCACHE=${MS_CCACHE_PATH}" + fi + cmake ${ARM32_CMAKE_ARGS} -DSUPPORT_TRAIN=${is_train} -DPLATFORM_ARM32=on \ + -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" -DANDROID_NATIVE_API_LEVEL="19" \ + -DANDROID_NDK="${ANDROID_NDK}" -DANDROID_ABI="armeabi-v7a" -DANDROID_TOOLCHAIN_NAME="aarch64-linux-android-clang" \ + -DANDROID_STL=${MSLITE_ANDROID_STL} "${LITE_JAVA_PATH}/native" + make -j$THREAD_NUM + if [[ $? -ne 0 ]]; then + echo "---------------- mindspore lite: build jni arm32 failed----------------" + exit 1 + fi + cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/java/app/libs/armeabi-v7a/ + cp ./libmindspore-lite-jni.so ${LITE_JAVA_PATH}/native/libs/armeabi-v7a/ + if [[ "X$is_train" = "Xon" ]]; then + cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/java/app/libs/armeabi-v7a/ + cp ./libmindspore-lite-train-jni.so ${LITE_JAVA_PATH}/native/libs/armeabi-v7a/ + fi +} + +build_aar() { + if [[ "X${INC_BUILD}" == "Xoff" ]]; then + [ -n "${BASEPATH}" ] && rm -rf ${BASEPATH}/mindspore-lite/build + fi + cd ${LITE_JAVA_PATH}/java + rm -rf gradle .gradle gradlew gradlew.bat + local gradle_version="" + gradle_version=`gradle --version | grep Gradle | awk '{print$2}'` + if [[ ${gradle_version} == '6.6.1' ]]; then + gradle_command=gradle + else + gradle wrapper --gradle-version 6.6.1 --distribution-type all + gradle_command=${LITE_JAVA_PATH}/java/gradlew + fi + # build new java api module + ${gradle_command} clean -p ${LITE_JAVA_PATH}/ + ${gradle_command} build -p ${LITE_JAVA_PATH}/ -x test + + # build aar + build_lite_arm64_and_jni + build_lite_arm32_and_jni + + cp ${LITE_JAVA_PATH}/build/libs/mindspore-lite-java.jar ${LITE_JAVA_PATH}/java/app/libs + # build aar + ${gradle_command} clean -p ${LITE_JAVA_PATH}/java/app + ${gradle_command} assembleRelease -p ${LITE_JAVA_PATH}/java/app + + cd ${LITE_JAVA_PATH}/java/app/build + [ -n "${BASEPATH}" ] && rm -rf ${BASEPATH}/output/*.tar.gz* + local minddata_so=${LITE_JAVA_PATH}/java/app/libs/arm64-v8a/libminddata-lite.so + if [ -e "${minddata_so}" ]; then + cp ${LITE_JAVA_PATH}/java/app/build/outputs/aar/mindspore-lite.aar ${BASEPATH}/output/mindspore-lite-full-${VERSION_STR}.aar + cd ${BASEPATH}/output + sha256sum mindspore-lite-full-${VERSION_STR}.aar > mindspore-lite-full-${VERSION_STR}.aar.sha256 + else + cp ${LITE_JAVA_PATH}/java/app/build/outputs/aar/mindspore-lite.aar ${BASEPATH}/output/mindspore-lite-${VERSION_STR}.aar + cd ${BASEPATH}/output + sha256sum mindspore-lite-${VERSION_STR}.aar > mindspore-lite-${VERSION_STR}.aar.sha256 + fi +} + +build_lite_x86_64_aarch64_jar() +{ + echo "build_lite_x86_64_aarch64_jar start." + if [[ ! -f ${SERVER_X86_64_PACKAGE_FILE} ]] || [[ ! -f ${SERVER_AARCH64_PACKAGE_FILE} ]]; then + echo "x86_64_package aarch64_package file not exist." + exit 1 + fi + + local x86_64_base_path="" + x86_64_base_path=${SERVER_X86_64_PACKAGE_FILE%/*} + local aarch64_base_path="" + aarch64_base_path=${SERVER_AARCH64_PACKAGE_FILE%/*} + echo "x86_64_base_path: "${x86_64_base_path} + echo "aarch64_base_path: "${aarch64_base_path} + + local x86_64_package_name="" + local aarch64_package_name="" + x86_64_package_name=`basename ${SERVER_X86_64_PACKAGE_FILE} '.tar.gz'` + aarch64_package_name=`basename ${SERVER_AARCH64_PACKAGE_FILE} '.tar.gz'` + echo "x86_64_package_name: "${x86_64_package_name} + echo "aarch64_package_name: "${aarch64_package_name} + + # unzip tar.gz, extract native libs(libmindspore-lite.so,libmindspore-lite-jni.so) + rm -rf ${x86_64_base_path}/tmp ${aarch64_base_path}/tmp + mkdir ${x86_64_base_path}/tmp ${aarch64_base_path}/tmp + tar -zxvf ${SERVER_X86_64_PACKAGE_FILE} -C ${x86_64_base_path}/tmp + tar -zxvf ${SERVER_AARCH64_PACKAGE_FILE} -C ${aarch64_base_path}/tmp + + LITE_JAVA_PATH=${LITE_BASEPATH}/java + local LITE_JAVA_NATIVE_RESOURCE_PATH=${LITE_JAVA_PATH}/src/main/resources/com/mindspore/lite + rm -f ${LITE_JAVA_NATIVE_RESOURCE_PATH}/linux_x86_64/*.so* + rm -f ${LITE_JAVA_NATIVE_RESOURCE_PATH}/linux_aarch64/*.so* + cp ${x86_64_base_path}/tmp/${x86_64_package_name}/runtime/lib/libmindspore-lite.so ${LITE_JAVA_NATIVE_RESOURCE_PATH}/linux_x86_64/ + cp ${x86_64_base_path}/tmp/${x86_64_package_name}/runtime/lib/libmindspore-lite-jni.so ${LITE_JAVA_NATIVE_RESOURCE_PATH}/linux_x86_64/ + cp ${aarch64_base_path}/tmp/${aarch64_package_name}/runtime/lib/libmindspore-lite.so ${LITE_JAVA_NATIVE_RESOURCE_PATH}/linux_aarch64/ + cp ${aarch64_base_path}/tmp/${aarch64_package_name}/runtime/lib/libmindspore-lite-jni.so ${LITE_JAVA_NATIVE_RESOURCE_PATH}/linux_aarch64/ + + if [ -f "${x86_64_base_path}/tmp/${x86_64_package_name}/runtime/third_party/glog/libmindspore_glog.so.0" ]; then + cp ${x86_64_base_path}/tmp/${x86_64_package_name}/runtime/third_party/glog/libmindspore_glog.so* ${LITE_JAVA_NATIVE_RESOURCE_PATH}/linux_x86_64/libmindspore_glog.so + fi + + if [ -f "${aarch64_base_path}/tmp/${aarch64_package_name}/runtime/third_party/glog/libmindspore_glog.so.0" ]; then + cp ${aarch64_base_path}/tmp/${aarch64_package_name}/runtime/third_party/glog/libmindspore_glog.so* ${LITE_JAVA_NATIVE_RESOURCE_PATH}/linux_aarch64/libmindspore_glog.so + fi + # compile jar package + rm -rf ${LITE_JAVA_PATH}/build + # build jar + local gradle_version="" + gradle_version=`gradle --version | grep Gradle | awk '{print$2}'` + if [[ ${gradle_version} == '6.6.1' ]]; then + gradle_command=gradle + else + gradle wrapper --gradle-version 6.6.1 --distribution-type all + gradle_command=${LITE_JAVA_PATH}/java/gradlew + fi + + ${gradle_command} clean -p ${LITE_JAVA_PATH}/ + if [[ "${ENABLE_ASAN}" == "ON" || "${ENABLE_ASAN}" == "on" ]] ; then + ${gradle_command} releaseJar -p ${LITE_JAVA_PATH}/ -x test --info + else + if [[ "${MSLITE_ENABLE_TESTCASES}" == "ON" || "${MSLITE_ENABLE_TESTCASES}" == "on" ]] && [[ "${MSLITE_ENABLE_CLOUD_INFERENCE}" != "on" ]]; then + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${LITE_JAVA_PATH}/native/libs/linux_x86/ + ${gradle_command} releaseJar -p ${LITE_JAVA_PATH}/ --info + else + ${gradle_command} releaseJar -p ${LITE_JAVA_PATH}/ -x test --info + fi + fi + echo "compile jar success." + + # update jar package, compress tar.gz, update tar.gz + cp ${LITE_JAVA_PATH}/build/lib/jar/mindspore-lite-java.jar ${x86_64_base_path}/tmp/${x86_64_package_name}/runtime/lib/ -f + cp ${LITE_JAVA_PATH}/build/lib/jar/mindspore-lite-java.jar ${aarch64_base_path}/tmp/${aarch64_package_name}/runtime/lib/ -f + + cd ${x86_64_base_path}/tmp + tar -zcvf ${x86_64_package_name}.tar.gz ${x86_64_package_name}/ + sha256sum ${x86_64_package_name}.tar.gz > ${x86_64_package_name}.tar.gz.sha256 + rm -f ../${x86_64_package_name}.tar.gz ../${x86_64_package_name}.tar.gz.sha256 + mv ${x86_64_package_name}.tar.gz ../ + mv ${x86_64_package_name}.tar.gz.sha256 ../ + + cd ${aarch64_base_path}/tmp + tar -zcvf ${aarch64_package_name}.tar.gz ${aarch64_package_name}/ + sha256sum ${aarch64_package_name}.tar.gz > ${aarch64_package_name}.tar.gz.sha256 + rm -f ../${aarch64_package_name}.tar.gz ../${aarch64_package_name}.tar.gz.sha256 + mv ${aarch64_package_name}.tar.gz ../ + mv ${aarch64_package_name}.tar.gz.sha256 ../ + + cd ${LITE_BASEPATH} + rm -rf ${x86_64_base_path}/tmp + rm -rf ${aarch64_base_path}/tmp + java -version +} + +LITE_BASEPATH=$(cd "$(dirname $0)"; pwd) +echo "LITE_BASEPATH="${LITE_BASEPATH} +if [[ -z "${BASEPATH}" ]]; then + BASEPATH=${LITE_BASEPATH}/../.. +fi + +INSTALL_PREFIX=${BASEPATH}/output/tmp +LITE_JAVA_PATH=${BASEPATH}/mindspore-lite/java + +CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_VERBOSE=${ENABLE_VERBOSE}" +if [[ "${DEBUG_MODE}" == "on" ]]; then + CMAKE_ARGS="${CMAKE_ARGS} -DCMAKE_BUILD_TYPE=Debug " +else + CMAKE_ARGS="${CMAKE_ARGS} -DCMAKE_BUILD_TYPE=Release " +fi + +if [[ "X$ENABLE_FAST_HASH_TABLE" == "Xon" ]]; then + CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_FAST_HASH_TABLE=ON" +else + CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_FAST_HASH_TABLE=OFF" +fi +cd ${BASEPATH}/mindspore && git apply ${BASEPATH}/third_party/patch/mindspore/decouple_mindspore.patch +get_version +CMAKE_ARGS="${CMAKE_ARGS} -DVERSION_STR=${VERSION_STR}" + +if [[ "X$LITE_ENABLE_AAR" = "Xon" ]]; then + build_aar +elif [[ "X$LITE_PLATFORM" != "X" ]]; then + build_lite +else + echo "Invalid parameter" +fi + +if [[ -n "${SERVER_X86_64_PACKAGE_FILE}" ]] && [[ -n "${SERVER_AARCH64_PACKAGE_FILE}" ]]; then + build_lite_x86_64_aarch64_jar +fi diff --git a/mindspore-lite/cmake/ccsrc_converter.cmake b/mindspore-lite/cmake/ccsrc_converter.cmake deleted file mode 100644 index 51815b7c9afa54420904a7b2618e4be9d07ac048..0000000000000000000000000000000000000000 --- a/mindspore-lite/cmake/ccsrc_converter.cmake +++ /dev/null @@ -1,120 +0,0 @@ -# Compile ccsrc files in converter independently -if(MSLITE_ENABLE_CONVERTER) - add_definitions(-DPRIMITIVE_WRITEABLE) - add_definitions(-DUSE_GLOG) - set(USE_GLOG on) - if(MSLITE_ENABLE_MODEL_ENCRYPTION AND MSLITE_DEPS_OPENSSL) - add_compile_definitions(ENABLE_OPENSSL) - endif() - - if(ENABLE_GPU) - add_compile_definitions(ENABLE_GPU) - endif() - - set(SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../src) - set(TOOLS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../tools) - set(CCSRC_SRC - ${CCSRC_DIR}/backend/backend_manager/backend_jit_config.cc - ${CCSRC_DIR}/backend/common/optimizer/pattern_engine.cc - ${CCSRC_DIR}/backend/common/optimizer/visitor.cc - ${CCSRC_DIR}/backend/common/optimizer/graph_optimizer.cc - ${CCSRC_DIR}/backend/operator/ops_backend_infer_function.cc - ${OPS_DIR}/kernel/common/kernel.cc - ${OPS_DIR}/kernel/common/kernel_tensor.cc - ${OPS_DIR}/kernel/common/kernel_factory.cc - ${OPS_DIR}/kernel/common/format_utils.cc - ${CCSRC_DIR}/utils/convert_utils.cc - ) - - if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE) - set(CCSRC_SRC ${CCSRC_SRC} - ${CCSRC_DIR}/ps/ps_context.cc - ${CCSRC_DIR}/common/thread_pool.cc - ${CCSRC_DIR}/debug/profiler/profiler.cc - ${CCSRC_DIR}/common/pynative/abstract_converter.cc - ${CCSRC_DIR}/plugin/device/cpu/kernel/cpu_kernel.cc - ${CCSRC_DIR}/distributed/cluster/dummy_cluster_context.cc - ${OPS_DIR}/kernel/common/kernel_utils.cc - ${OPS_DIR}/kernel/common/common_utils.cc - ${CCSRC_DIR}/kernel/framework_utils.cc - ${CCSRC_DIR}/kernel/philox_random.cc - ${CCSRC_DIR}/kernel/kash/kernel_pack.cc - ${OPS_DIR}/kernel/common/kernel_build_info.cc - ${OPS_DIR}/kernel/common/oplib/oplib.cc - ${CCSRC_DIR}/kernel/kernel_info.cc - ${CCSRC_DIR}/runtime/device/res_manager/utils/convert_tensor_utils.cc - ${CCSRC_DIR}/utils/ms_device_shape_transfer.cc - ${CCSRC_DIR}/runtime/device/kernel_runtime_manager.cc - ${CCSRC_DIR}/runtime/hardware/device_context_manager.cc - ${CCSRC_DIR}/common/runtime_conf/runtime_conf.cc - ${CCSRC_DIR}/utils/comm_manager.cc - ${CCSRC_DIR}/backend/common/session/exec_order_builder.cc - ${CCSRC_DIR}/backend/common/session/kernel_graph.cc - ${CCSRC_DIR}/backend/common/session/anf_runtime_algorithm.cc - ${CCSRC_DIR}/runtime/device/res_manager/hal_res_manager.cc - ${CCSRC_DIR}/runtime/device/res_manager/multi_stream_controller.cc - ${SRC_DIR}/extendrt/utils/tensor_utils.cc - ) - endif() - - if(NOT WIN32) - set(CCSRC_SRC ${CCSRC_SRC} - ${CCSRC_DIR}/utils/anfalgo.cc - ${CCSRC_DIR}/utils/utils.cc - ${CCSRC_DIR}/utils/parallel_context.cc - ) - endif() - - if(ENABLE_GPU) - add_compile_definitions(ENABLE_GPU) - endif() - - if(MSLITE_ENABLE_GRAPH_KERNEL) - - if(AKG_USE_LLVM) - add_compile_definitions(AKG_USE_LLVM) - message(STATUS "Converter support Graph Kernel CPU backend") - endif() - - if(AKG_ENABLE_D) - add_compile_definitions(AKG_ENABLE_D) - message(STATUS "Converter support Graph Kernel Ascend backend") - endif() - - if(AKG_USE_CUDA) - add_compile_definitions(AKG_USE_CUDA) - message(STATUS "Converter support Graph Kernel CUDA backend") - endif() - - add_compile_definitions(MSLITE_ENABLE_GRAPH_KERNEL) - file(GLOB_RECURSE GRAPH_KERNEL_SRC - ${TOOLS_DIR}/graph_kernel/common/*.cc - ${TOOLS_DIR}/graph_kernel/converter/*.cc - ${CCSRC_DIR}/backend/common/graph_kernel/core/*.cc - ${CCSRC_DIR}/backend/common/graph_kernel/expander/*.cc - ${CCSRC_DIR}/backend/common/graph_kernel/expanders/*.cc - ${CCSRC_DIR}/backend/common/graph_kernel/model/*.cc - ${CCSRC_DIR}/backend/common/graph_kernel/split_model/*.cc - ${CCSRC_DIR}/backend/common/graph_kernel/graph_kernel_flags.cc - ${CCSRC_DIR}/kernel/graph_kernel/graph_kernel_json_generator.cc - ${CCSRC_DIR}/backend/common/optimizer/optimizer.cc - ) - set_property(SOURCE ${GRAPH_KERNEL_SRC} - PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_GRAPH_KERNEL) - set(CCSRC_SRC - ${CCSRC_SRC} - ${GRAPH_KERNEL_SRC} - ) - endif() - set_property(SOURCE ${CCSRC_SRC} PROPERTY COMPILE_DEFINITIONS - LOG_HDR_FILE_REL_PATH="mindspore-lite/../mindspore/mindspore/core/include/utils/log_adapter.h" - SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) - add_library(ccsrc_src_mid OBJECT ${CCSRC_SRC}) - add_dependencies(ccsrc_src_mid fbs_src fbs_inner_src) - if(MSLITE_ENABLE_CLOUD_INFERENCE) - add_dependencies(ccsrc_src_mid mindspore-lite-proto) - endif() - target_compile_definitions(ccsrc_src_mid PRIVATE BACKEND_DLL) - target_compile_definitions(ccsrc_src_mid PRIVATE COMMON_DLL) - target_compile_definitions(ccsrc_src_mid PRIVATE OPS_KERNEL_COMMON_DLL) -endif() diff --git a/mindspore-lite/cmake/ccsrc_extendrt.cmake b/mindspore-lite/cmake/ccsrc_extendrt.cmake deleted file mode 100644 index a0fbf76aa0b404cd91e96beab3b750d8e8d8a3ef..0000000000000000000000000000000000000000 --- a/mindspore-lite/cmake/ccsrc_extendrt.cmake +++ /dev/null @@ -1,130 +0,0 @@ -# Compile ccsrc files in extendrt independently -string(REPLACE "-fvisibility-inlines-hidden" "" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") -string(REPLACE "-fvisibility=hidden" "" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") -string(REPLACE "-fvisibility-inlines-hidden" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") -string(REPLACE "-fvisibility=hidden" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") -if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE) - if(PLATFORM_ARM64) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fexceptions") - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fexceptions") - endif() - set(ENABLE_CPU on) - set(LOAD_PLUGIN_STATIC on) - string(REPLACE "-fno-rtti" "" CMAKE_C_FLAGS ${CMAKE_C_FLAGS}) - string(REPLACE "-fno-rtti" "" CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) - add_compile_definitions(ENABLE_CLOUD_FUSION_INFERENCE) - add_compile_definitions(ENABLE_CLOUD_INFERENCE) - remove_definitions(-DBUILD_LITE_INFERENCE) - - include_directories("${CCSRC_DIR}/ps/core") - file(GLOB_RECURSE COMM_PROTO_IN RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "${CCSRC_DIR}/ps/core/protos/*.proto") - ms_protobuf_generate(COMM_PROTO_SRCS COMM_PROTO_HDRS ${COMM_PROTO_IN}) - list(APPEND MSLITE_PROTO_SRC ${COMM_PROTO_SRCS}) - - if(NOT ENABLE_SECURITY) - include_directories("${CCSRC_DIR}/include/backend/debug/profiler/ascend") - file(GLOB_RECURSE PROFILER_PROTO_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} - "${CCSRC_DIR}/plugin/device/ascend/hal/profiler/memory_profiling.proto") - ms_protobuf_generate(PROFILER_MEM_PROTO_SRC PROFILER_MEM_PROTO_HDRS ${PROFILER_PROTO_LIST}) - list(APPEND MSLITE_PROTO_SRC ${PROFILER_MEM_PROTO_SRC}) - endif() - - include_directories("${CMAKE_BINARY_DIR}/runtime/graph_scheduler/actor/rpc") - file(GLOB_RECURSE RPC_PROTO RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} - "${CCSRC_DIR}/runtime/graph_scheduler/actor/rpc/protocol/rpc.proto") - ms_protobuf_generate(RPC_PROTO_SRCS RPC_PROTO_HDRS ${RPC_PROTO}) - list(APPEND MSLITE_PROTO_SRC ${RPC_PROTO_SRCS}) - - add_library(mindspore-lite-proto STATIC ${MSLITE_PROTO_SRC}) - - set(ANF_ALG_SRC - ${CCSRC_DIR}/utils/anfalgo.cc - ${CCSRC_DIR}/utils/utils.cc - ${CCSRC_DIR}/utils/parallel_context.cc - ${CCSRC_DIR}/utils/convert_utils.cc) - add_library(mindspore-infer-anfalgo OBJECT ${ANF_ALG_SRC}) - - set(KERNEL_GRAPH_SRC - ${CCSRC_DIR}/backend/common/session/kernel_graph.cc - ${CCSRC_DIR}/backend/common/session/exec_order_builder.cc - ${CCSRC_DIR}/backend/common/session/anf_runtime_algorithm.cc - ${CCSRC_DIR}/backend/common/somas/somas.cc - ${CCSRC_DIR}/backend/common/somas/somas_tensor.cc - ${CCSRC_DIR}/backend/common/somas/somas_solver_pre.cc - ${CCSRC_DIR}/backend/common/somas/somas_solver_core.cc - ${CCSRC_DIR}/backend/common/somas/somas_solver_alg.cc - ${CCSRC_DIR}/backend/operator/ops_backend_infer_function.cc - ${CCSRC_DIR}/backend/graph_compiler/graph_partition.cc - ${CMAKE_CURRENT_SOURCE_DIR}/mock/segment_runner.cc - ${CCSRC_DIR}/utils/ms_device_shape_transfer.cc - ${CCSRC_DIR}/kernel/kernel_info.cc - ${CCSRC_DIR}/runtime/device/kernel_runtime_manager.cc - ${CCSRC_DIR}/runtime/device/kernel_runtime.cc - ${CCSRC_DIR}/runtime/device/memory_scheduler.cc - ${CCSRC_DIR}/runtime/device/memory_offload_strategy.cc - ${CCSRC_DIR}/runtime/device/res_manager/memory_manager.cc - ${CCSRC_DIR}/runtime/device/res_manager/auto_mem_offload.cc - ${CCSRC_DIR}/runtime/device/gsm/mem_usage_analyzer.cc - ${CCSRC_DIR}/runtime/device/gsm/swap_strategy_builder.cc - ${CCSRC_DIR}/runtime/device/common_somas_allocator.cc - ${CCSRC_DIR}/runtime/pynative/op_runtime_info.cc - ${CCSRC_DIR}/common/runtime_conf/runtime_conf.cc - ${OPS_DIR}/kernel/common/kernel_build_info.cc - ${OPS_DIR}/kernel/common/kernel_utils.cc - ${OPS_DIR}/kernel/common/common_utils.cc - ${OPS_DIR}/kernel/common/format_utils.cc - ${CCSRC_DIR}/kernel/framework_utils.cc - ${CCSRC_DIR}/kernel/philox_random.cc - ${OPS_DIR}/kernel/common/kernel_factory.cc - ${OPS_DIR}/kernel/common/kernel.cc - ${CCSRC_DIR}/kernel/kash/kernel_pack.cc - ${OPS_DIR}/kernel/common/oplib/oplib.cc - ${CMAKE_CURRENT_SOURCE_DIR}/mock/anf_ir_dump.cc - ${CCSRC_DIR}/common/debug/common.cc - ${CCSRC_DIR}/common/debug/env_config_parser.cc - ${CCSRC_DIR}/memory/mem_pool/mem_pool_util.cc - ${CCSRC_DIR}/memory/mem_pool/dynamic_mem_pool.cc - ${CCSRC_DIR}/memory/mem_pool/abstract_dynamic_mem_pool.cc - ${CCSRC_DIR}/memory/mem_pool/mem_dynamic_allocator.cc - ${CCSRC_DIR}/memory/mem_pool/mem_tracker.cc - ${CCSRC_DIR}/memory/mem_pool/tracker_graph.cc - ${CCSRC_DIR}/memory/mem_pool/race_checker.cc - ${CCSRC_DIR}/common/thread_pool.cc - ${CCSRC_DIR}/debug/profiler/profiler.cc - ${CCSRC_DIR}/common/pynative/abstract_converter.cc - ${CCSRC_DIR}/utils/scoped_long_running.cc - ${CCSRC_DIR}/utils/cse.cc - ${CCSRC_DIR}/utils/comm_manager.cc - ${CCSRC_DIR}/utils/signal_util.cc - ${CORE_DIR}/utils/status.cc - ) - - add_library(mindspore-kernel-graph OBJECT ${KERNEL_GRAPH_SRC}) - add_dependencies(mindspore-kernel-graph fbs_src fbs_inner_src) - add_dependencies(mindspore-kernel-graph mindspore-lite-proto) - - if(NOT PLATFORM_ARM) - set(KERNEL_MOD_DEPEND_SRC - ${CCSRC_DIR}/kernel/environ_manager.cc - ${CCSRC_DIR}/utils/python_fallback_running.cc - ${CCSRC_DIR}/runtime/device/tensors_queue.cc - ${CCSRC_DIR}/runtime/device/res_manager/tensor_array.cc - ${CCSRC_DIR}/runtime/hardware/device_context_manager.cc - ${CCSRC_DIR}/plugin/device/cpu/hal/device/cpu_tensor_array.cc - ${CCSRC_DIR}/plugin/res_manager/cpu/cpu_mem_manager/cpu_memory_pool.cc - ${CCSRC_DIR}/distributed/embedding_cache/embedding_cache_utils.cc - ${CCSRC_DIR}/distributed/embedding_cache/embedding_hash_map.cc - ${CCSRC_DIR}/distributed/embedding_cache/embedding_storage/dense_embedding_storage.cc - ${CCSRC_DIR}/distributed/embedding_cache/embedding_storage/sparse_embedding_storage.cc - ${CCSRC_DIR}/distributed/embedding_cache/embedding_storage/embedding_storage.cc - ${CCSRC_DIR}/distributed/persistent/storage/local_file.cc - ${CCSRC_DIR}/distributed/persistent/storage/block.cc - ${CCSRC_DIR}/distributed/persistent/storage/json_utils.cc - ${CCSRC_DIR}/distributed/persistent/storage/file_io_utils.cc - ${CCSRC_DIR}/distributed/cluster/dummy_cluster_context.cc - ${CCSRC_DIR}/ps/ps_context.cc - ) - add_library(_mindspore_cpu_kernel_mod_depend_obj OBJECT ${KERNEL_MOD_DEPEND_SRC}) - add_dependencies(_mindspore_cpu_kernel_mod_depend_obj fbs_src fbs_inner_src) - endif() -endif() diff --git a/mindspore-lite/cmake/ccsrc_module.cmake b/mindspore-lite/cmake/ccsrc_module.cmake index 6283cf57f36c320dbef8169df72d2d7645403025..cac8fe6ab0bbb62d97689005961112d996caa545 100644 --- a/mindspore-lite/cmake/ccsrc_module.cmake +++ b/mindspore-lite/cmake/ccsrc_module.cmake @@ -32,6 +32,5 @@ if(Python3_FOUND) include_directories(${TOP_DIR}/mindspore/mindspore/core/mindrt) include_directories(${TOP_DIR}/mindspore/mindspore/core/mindrt/include) include(${TOP_DIR}/cmake/external_libs/pybind11.cmake) - endif() endif() include(${TOP_DIR}/cmake/external_libs/libevent.cmake) diff --git a/mindspore-lite/cmake/compile_link_option.cmake b/mindspore-lite/cmake/compile_link_option.cmake index 51b7eb13c7ed244890be2ce70db9443594b0cdc7..9ecfac207354f768af1b36c0b71835e27256bf3c 100644 --- a/mindspore-lite/cmake/compile_link_option.cmake +++ b/mindspore-lite/cmake/compile_link_option.cmake @@ -1,4 +1,3 @@ - if(MSVC) add_compile_definitions(_ENABLE_ATOMIC_ALIGNMENT_FIX) set(CMAKE_C_FLAGS "/O2 /EHsc /GS /Zi /utf-8") diff --git a/mindspore-lite/cmake/file_list.cmake b/mindspore-lite/cmake/file_list.cmake index 944333c396745e91de83d812b960a488a0641530..e97cf5ba93c842edb0cecf90ca63a2b436468339 100644 --- a/mindspore-lite/cmake/file_list.cmake +++ b/mindspore-lite/cmake/file_list.cmake @@ -6,7 +6,6 @@ set(API_HEADER ${TOP_DIR}/mindspore/include/api/types.h ${TOP_DIR}/mindspore/include/api/visible.h ) - set(ABSTRACT_HEADER ${CORE_INC_DIR}/abstract/abstract_value.h ${CORE_INC_DIR}/abstract/dshape.h diff --git a/mindspore-lite/cmake/lite_compile_definitions.cmake b/mindspore-lite/cmake/lite_compile_definitions.cmake index b2b39b8ce913fe6a514f02aa1991a29587de9601..32ece023c2e49e4231d36d69e66256eb5557344a 100644 --- a/mindspore-lite/cmake/lite_compile_definitions.cmake +++ b/mindspore-lite/cmake/lite_compile_definitions.cmake @@ -27,10 +27,6 @@ if(MSLITE_ENABLE_BFC_MEMORY) add_compile_definitions(BFC_MEMORY) endif() -if(MSLITE_ENABLE_PARALLEL_INFERENCE) - add_compile_definitions(PARALLEL_INFERENCE) -endif() - if(MSLITE_ENABLE_SHARING_MODEL_WEIGHT) add_compile_definitions(SHARING_MODEL_WEIGHT) endif() diff --git a/mindspore-lite/cmake/lite_options.cmake b/mindspore-lite/cmake/lite_options.cmake index 5db0c07fd184e6e427e9c71f96f8ea10394cb23c..7469529ec2536a32e1d329948997474fb29a1718 100644 --- a/mindspore-lite/cmake/lite_options.cmake +++ b/mindspore-lite/cmake/lite_options.cmake @@ -38,10 +38,8 @@ option(MSLITE_ENABLE_SPARSE_COMPUTE "enable sparse kernel" off) option(MSLITE_ENABLE_RUNTIME_CONVERT "enable runtime convert" off) option(MSLITE_ENABLE_RUNTIME_GLOG "enable runtime glog" off) option(MSLITE_ENABLE_COVERAGE "enable code coverage" off) -option(MSLITE_ENABLE_SERVER_INFERENCE "enable inference on server" off) option(MSLITE_ENABLE_DYNAMIC_THREAD_DISTRIBUTE "enable distribute thread dynamically" off) option(MSLITE_ENABLE_BFC_MEMORY "enable distribute BFC memory" off) -option(MSLITE_ENABLE_PARALLEL_INFERENCE "enable parallel inference interface" off) option(MSLITE_ENABLE_SHARING_MODEL_WEIGHT "enable sharing model weight" off) option(MSLITE_ENABLE_EXPERIMENTAL_KERNEL "enable experimental kernel" on) option(MSLITE_ENABLE_GRAPH_KERNEL "enable graph kernel" off) @@ -84,13 +82,9 @@ if(DEFINED ENV{MSLITE_ENABLE_TRAIN}) set(MSLITE_ENABLE_TRAIN $ENV{MSLITE_ENABLE_TRAIN}) endif() -if(DEFINED ENV{MSLITE_ENABLE_SERVER_INFERENCE}) - set(MSLITE_ENABLE_SERVER_INFERENCE $ENV{MSLITE_ENABLE_SERVER_INFERENCE}) -endif() -if(MSLITE_ENABLE_SERVER_INFERENCE) +if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE) set(MSLITE_ENABLE_DYNAMIC_THREAD_DISTRIBUTE on) set(MSLITE_ENABLE_BFC_MEMORY on) - set(MSLITE_ENABLE_PARALLEL_INFERENCE on) set(MSLITE_ENABLE_SHARING_MODEL_WEIGHT on) set(MSLITE_ENABLE_RUNTIME_GLOG on) set(MSLITE_ENABLE_AVX512 on) @@ -110,9 +104,9 @@ endif() if(DEFINED ENV{MSLITE_ENABLE_RUNTIME_CONVERT}) set(MSLITE_ENABLE_RUNTIME_CONVERT $ENV{MSLITE_ENABLE_RUNTIME_CONVERT}) endif() -if(DEFINED ENV{ENABLE_AKG} AND NOT MSLITE_ENABLE_RUNTIME_CONVERT) - set(MSLITE_ENABLE_GRAPH_KERNEL $ENV{ENABLE_AKG}) -endif() +#if(DEFINED ENV{ENABLE_AKG} AND NOT MSLITE_ENABLE_RUNTIME_CONVERT) +# set(MSLITE_ENABLE_GRAPH_KERNEL $ENV{ENABLE_AKG}) +#endif() if(DEFINED ENV{MSLITE_ENABLE_TOOLS}) set(MSLITE_ENABLE_TOOLS $ENV{MSLITE_ENABLE_TOOLS}) endif() @@ -158,7 +152,8 @@ endif() if(DEFINED ENV{MSLITE_ENABLE_ACL}) set(MSLITE_ENABLE_ACL $ENV{MSLITE_ENABLE_ACL}) endif() -if(DEFINED ENV{MSLITE_MINDDATA_IMPLEMENT}) +if(DEFINED ENV{MSLITE_MINDDATA_IMPLEMENT} AND NOT + (MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)) set(MSLITE_MINDDATA_IMPLEMENT $ENV{MSLITE_MINDDATA_IMPLEMENT}) endif() if(DEFINED ENV{MSLITE_ENABLE_MODEL_ENCRYPTION}) @@ -291,10 +286,6 @@ if(DEFINED ENV{MSLITE_ENABLE_BFC_MEMORY}) set(MSLITE_ENABLE_BFC_MEMORY $ENV{MSLITE_ENABLE_BFC_MEMORY}) endif() -if(DEFINED ENV{MSLITE_ENABLE_PARALLEL_INFERENCE}) - set(MSLITE_ENABLE_PARALLEL_INFERENCE $ENV{MSLITE_ENABLE_PARALLEL_INFERENCE}) -endif() - if(DEFINED ENV{MSLITE_ENABLE_SHARING_MODEL_WEIGHT}) set(MSLITE_ENABLE_SHARING_MODEL_WEIGHT $ENV{MSLITE_ENABLE_SHARING_MODEL_WEIGHT}) endif() @@ -324,7 +315,8 @@ endif() if(MSLITE_ENABLE_TRAIN) set(SUPPORT_TRAIN on) - if(NOT MSLITE_MINDDATA_IMPLEMENT STREQUAL "off" OR NOT PLATFORM_ARM) + if(NOT MSLITE_MINDDATA_IMPLEMENT STREQUAL "off" OR NOT PLATFORM_ARM AND + NOT(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)) set(MSLITE_MINDDATA_IMPLEMENT full) endif() endif() @@ -392,13 +384,11 @@ message(STATUS "\tMSLITE_ENABLE_SPARSE_COMPUTE = \t${MSLITE_ENABLE message(STATUS "\tMSLITE_ENABLE_RUNTIME_CONVERT = \t${MSLITE_ENABLE_RUNTIME_CONVERT}") message(STATUS "\tMSLITE_ENABLE_RUNTIME_GLOG = \t${MSLITE_ENABLE_RUNTIME_GLOG}") message(STATUS "\tMSLITE_ENABLE_COVERAGE = \t${MSLITE_ENABLE_COVERAGE}") -message(STATUS "\tMSLITE_ENABLE_SERVER_INFERENCE = \t${MSLITE_ENABLE_SERVER_INFERENCE}") message(STATUS "\tMSLITE_ENABLE_DYNAMIC_THREAD_DISTRIBUTE = \t${MSLITE_ENABLE_DYNAMIC_THREAD_DISTRIBUTE}") message(STATUS "\tMSLITE_ENABLE_BFC_MEMORY = \t${MSLITE_ENABLE_BFC_MEMORY}") -message(STATUS "\tMSLITE_ENABLE_PARALLEL_INFERENCE = \t${MSLITE_ENABLE_PARALLEL_INFERENCE}") message(STATUS "\tMSLITE_ENABLE_SHARING_MODEL_WEIGHT = \t${MSLITE_ENABLE_SHARING_MODEL_WEIGHT}") message(STATUS "\tMSLITE_ENABLE_EXPERIMENTAL_KERNEL = \t${MSLITE_ENABLE_EXPERIMENTAL_KERNEL}") -message(STATUS "\tMSLITE_ENABLE_GRAPH_KERNEL = \t${MSLITE_ENABLE_GRAPH_KERNEL}") +# message(STATUS "\tMSLITE_ENABLE_GRAPH_KERNEL = \t${MSLITE_ENABLE_GRAPH_KERNEL}") message(STATUS "\tMSLITE_ENABLE_KERNEL_EXECUTOR = \t${MSLITE_ENABLE_KERNEL_EXECUTOR}") message(STATUS "\tMSLITE_ENABLE_CLOUD_FUSION_INFERENCE = \t${MSLITE_ENABLE_CLOUD_FUSION_INFERENCE}") message(STATUS "\tMSLITE_ENABLE_CLOUD_INFERENCE = \t${MSLITE_ENABLE_CLOUD_INFERENCE}") diff --git a/mindspore-lite/cmake/mix210.toolchain.cmake b/mindspore-lite/cmake/mix210.toolchain.cmake index 6b2e8f22c205151ab0dd59515bb2d66be81f0242..1a087750952996cbcadabd4078564369df3e434c 100644 --- a/mindspore-lite/cmake/mix210.toolchain.cmake +++ b/mindspore-lite/cmake/mix210.toolchain.cmake @@ -24,8 +24,6 @@ set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY) set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) -#set(CMAKE_CXX_FLAGS "-march= -mfloat-abi=softfp -mfpu=neon-vfpv4 ${CMAKE_CXX_FLAGS}") - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+fp16") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+fp16") diff --git a/mindspore-lite/examples/cloud_infer/quick_start_parallel_python/quick_start_parallel_python.py b/mindspore-lite/examples/cloud_infer/quick_start_parallel_python/quick_start_parallel_python.py index d198e9801003553ad77ea4e9ef09b9e11b987d77..eaea1c2a7027592e24d9ce61e3f6d8f4296dc904 100644 --- a/mindspore-lite/examples/cloud_infer/quick_start_parallel_python/quick_start_parallel_python.py +++ b/mindspore-lite/examples/cloud_infer/quick_start_parallel_python/quick_start_parallel_python.py @@ -21,7 +21,6 @@ import mindspore_lite as mslite # Use case: serving inference. # Precondition 1: Download MindSpore Lite serving package or building MindSpore Lite serving package by -# export MSLITE_ENABLE_SERVER_INFERENCE=on. # Precondition 2: Install wheel package of MindSpore Lite built by precondition 1. # The result can be find in the tutorial of runtime_parallel_python. # the number of threads of one worker. diff --git a/mindspore-lite/examples/quick_start_ios/mindspore-lite.xcodeproj/project.pbxproj b/mindspore-lite/examples/quick_start_ios/mindspore-lite.xcodeproj/project.pbxproj deleted file mode 100644 index 3ea0947b48b4d6acb40d448a7e278edd9d27cba0..0000000000000000000000000000000000000000 --- a/mindspore-lite/examples/quick_start_ios/mindspore-lite.xcodeproj/project.pbxproj +++ /dev/null @@ -1,431 +0,0 @@ -// !$*UTF8*$! -{ - archiveVersion = 1; - classes = { - }; - objectVersion = 50; - objects = { - -/* Begin PBXBuildFile section */ - 4E3C6FB72764985E00BFE4F8 /* mindspore-lite.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 4E3C6FB62764985E00BFE4F8 /* mindspore-lite.framework */; }; - 4E3C6FBB2764989900BFE4F8 /* test_model.ms.bin in Resources */ = {isa = PBXBuildFile; fileRef = 4E3C6FB82764989900BFE4F8 /* test_model.ms.bin */; }; - 4E3C6FBC2764989900BFE4F8 /* test_model.ms.out in Resources */ = {isa = PBXBuildFile; fileRef = 4E3C6FB92764989900BFE4F8 /* test_model.ms.out */; }; - 4E3C6FBD2764989900BFE4F8 /* test_model.ms in Resources */ = {isa = PBXBuildFile; fileRef = 4E3C6FBA2764989900BFE4F8 /* test_model.ms */; }; - 4E492F322605D4DC003AA9B6 /* SceneDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = 4E492F312605D4DC003AA9B6 /* SceneDelegate.m */; }; - 4E492F352605D4DC003AA9B6 /* ViewController.m in Sources */ = {isa = PBXBuildFile; fileRef = 4E492F342605D4DC003AA9B6 /* ViewController.m */; }; - 4E492F382605D4DC003AA9B6 /* Main.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 4E492F362605D4DC003AA9B6 /* Main.storyboard */; }; - 4E492F3A2605D4DE003AA9B6 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 4E492F392605D4DE003AA9B6 /* Assets.xcassets */; }; - 4E492F3D2605D4DE003AA9B6 /* LaunchScreen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 4E492F3B2605D4DE003AA9B6 /* LaunchScreen.storyboard */; }; - 4E492F402605D4DE003AA9B6 /* main.m in Sources */ = {isa = PBXBuildFile; fileRef = 4E492F3F2605D4DE003AA9B6 /* main.m */; }; - 4E492F4B2605D58F003AA9B6 /* AppDelegate.mm in Sources */ = {isa = PBXBuildFile; fileRef = 4E492F472605D58F003AA9B6 /* AppDelegate.mm */; }; - 4E8D9A972612F60B00C7FBC1 /* Benchmark.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 4E492F482605D58F003AA9B6 /* Benchmark.cpp */; }; -/* End PBXBuildFile section */ - -/* Begin PBXBuildRule section */ - 4ED51573262D21EA00CC8DA0 /* PBXBuildRule */ = { - isa = PBXBuildRule; - compilerSpec = com.apple.compilers.proxy.script; - fileType = sourcecode.metal; - inputFiles = ( - ); - isEditable = 1; - outputFiles = ( - ); - script = "# metal\n"; - }; -/* End PBXBuildRule section */ - -/* Begin PBXFileReference section */ - 4E3C6FB62764985E00BFE4F8 /* mindspore-lite.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; path = "mindspore-lite.framework"; sourceTree = ""; }; - 4E3C6FB82764989900BFE4F8 /* test_model.ms.bin */ = {isa = PBXFileReference; lastKnownFileType = archive.macbinary; name = test_model.ms.bin; path = ../model/test_model.ms.bin; sourceTree = ""; }; - 4E3C6FB92764989900BFE4F8 /* test_model.ms.out */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; name = test_model.ms.out; path = ../model/test_model.ms.out; sourceTree = ""; }; - 4E3C6FBA2764989900BFE4F8 /* test_model.ms */ = {isa = PBXFileReference; lastKnownFileType = file; name = test_model.ms; path = ../model/test_model.ms; sourceTree = ""; }; - 4E492F2A2605D4DC003AA9B6 /* mindspore-lite.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = "mindspore-lite.app"; sourceTree = BUILT_PRODUCTS_DIR; }; - 4E492F2D2605D4DC003AA9B6 /* AppDelegate.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = AppDelegate.h; sourceTree = ""; }; - 4E492F302605D4DC003AA9B6 /* SceneDelegate.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = SceneDelegate.h; sourceTree = ""; }; - 4E492F312605D4DC003AA9B6 /* SceneDelegate.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = SceneDelegate.m; sourceTree = ""; }; - 4E492F332605D4DC003AA9B6 /* ViewController.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = ViewController.h; sourceTree = ""; }; - 4E492F342605D4DC003AA9B6 /* ViewController.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = ViewController.m; sourceTree = ""; }; - 4E492F372605D4DC003AA9B6 /* Base */ = {isa = PBXFileReference; lastKnownFileType = file.storyboard; name = Base; path = Base.lproj/Main.storyboard; sourceTree = ""; }; - 4E492F392605D4DE003AA9B6 /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = ""; }; - 4E492F3C2605D4DE003AA9B6 /* Base */ = {isa = PBXFileReference; lastKnownFileType = file.storyboard; name = Base; path = Base.lproj/LaunchScreen.storyboard; sourceTree = ""; }; - 4E492F3E2605D4DE003AA9B6 /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = ""; }; - 4E492F3F2605D4DE003AA9B6 /* main.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = main.m; sourceTree = ""; }; - 4E492F472605D58F003AA9B6 /* AppDelegate.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = AppDelegate.mm; sourceTree = ""; }; - 4E492F482605D58F003AA9B6 /* Benchmark.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = Benchmark.cpp; sourceTree = ""; }; - 4E492F492605D58F003AA9B6 /* Benchmark.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = Benchmark.hpp; sourceTree = ""; }; -/* End PBXFileReference section */ - -/* Begin PBXFrameworksBuildPhase section */ - 4E492F272605D4DC003AA9B6 /* Frameworks */ = { - isa = PBXFrameworksBuildPhase; - buildActionMask = 2147483647; - files = ( - 4E3C6FB72764985E00BFE4F8 /* mindspore-lite.framework in Frameworks */, - ); - runOnlyForDeploymentPostprocessing = 0; - }; -/* End PBXFrameworksBuildPhase section */ - -/* Begin PBXGroup section */ - 4E492F212605D4DC003AA9B6 = { - isa = PBXGroup; - children = ( - 4E3C6FBA2764989900BFE4F8 /* test_model.ms */, - 4E3C6FB82764989900BFE4F8 /* test_model.ms.bin */, - 4E3C6FB92764989900BFE4F8 /* test_model.ms.out */, - 4E492F2C2605D4DC003AA9B6 /* mindspore-lite */, - 4E492F2B2605D4DC003AA9B6 /* Products */, - ); - sourceTree = ""; - }; - 4E492F2B2605D4DC003AA9B6 /* Products */ = { - isa = PBXGroup; - children = ( - 4E492F2A2605D4DC003AA9B6 /* mindspore-lite.app */, - ); - name = Products; - sourceTree = ""; - }; - 4E492F2C2605D4DC003AA9B6 /* mindspore-lite */ = { - isa = PBXGroup; - children = ( - 4E3C6FB62764985E00BFE4F8 /* mindspore-lite.framework */, - 4E492F472605D58F003AA9B6 /* AppDelegate.mm */, - 4E492F482605D58F003AA9B6 /* Benchmark.cpp */, - 4E492F492605D58F003AA9B6 /* Benchmark.hpp */, - 4E492F2D2605D4DC003AA9B6 /* AppDelegate.h */, - 4E492F302605D4DC003AA9B6 /* SceneDelegate.h */, - 4E492F312605D4DC003AA9B6 /* SceneDelegate.m */, - 4E492F332605D4DC003AA9B6 /* ViewController.h */, - 4E492F342605D4DC003AA9B6 /* ViewController.m */, - 4E492F362605D4DC003AA9B6 /* Main.storyboard */, - 4E492F392605D4DE003AA9B6 /* Assets.xcassets */, - 4E492F3B2605D4DE003AA9B6 /* LaunchScreen.storyboard */, - 4E492F3E2605D4DE003AA9B6 /* Info.plist */, - 4E492F3F2605D4DE003AA9B6 /* main.m */, - ); - path = "mindspore-lite"; - sourceTree = ""; - }; -/* End PBXGroup section */ - -/* Begin PBXNativeTarget section */ - 4E492F292605D4DC003AA9B6 /* mindspore-lite */ = { - isa = PBXNativeTarget; - buildConfigurationList = 4E492F432605D4DE003AA9B6 /* Build configuration list for PBXNativeTarget "mindspore-lite" */; - buildPhases = ( - 4E492F262605D4DC003AA9B6 /* Sources */, - 4E492F272605D4DC003AA9B6 /* Frameworks */, - 4E492F282605D4DC003AA9B6 /* Resources */, - ); - buildRules = ( - 4ED51573262D21EA00CC8DA0 /* PBXBuildRule */, - ); - dependencies = ( - ); - name = "mindspore-lite"; - productName = "mindspore-lite"; - productReference = 4E492F2A2605D4DC003AA9B6 /* mindspore-lite.app */; - productType = "com.apple.product-type.application"; - }; -/* End PBXNativeTarget section */ - -/* Begin PBXProject section */ - 4E492F222605D4DC003AA9B6 /* Project object */ = { - isa = PBXProject; - attributes = { - LastUpgradeCheck = 1140; - ORGANIZATIONNAME = mindspore; - TargetAttributes = { - 4E492F292605D4DC003AA9B6 = { - CreatedOnToolsVersion = 11.4.1; - }; - }; - }; - buildConfigurationList = 4E492F252605D4DC003AA9B6 /* Build configuration list for PBXProject "mindspore-lite" */; - compatibilityVersion = "Xcode 9.3"; - developmentRegion = en; - hasScannedForEncodings = 0; - knownRegions = ( - en, - Base, - ); - mainGroup = 4E492F212605D4DC003AA9B6; - productRefGroup = 4E492F2B2605D4DC003AA9B6 /* Products */; - projectDirPath = ""; - projectRoot = ""; - targets = ( - 4E492F292605D4DC003AA9B6 /* mindspore-lite */, - ); - }; -/* End PBXProject section */ - -/* Begin PBXResourcesBuildPhase section */ - 4E492F282605D4DC003AA9B6 /* Resources */ = { - isa = PBXResourcesBuildPhase; - buildActionMask = 2147483647; - files = ( - 4E492F3D2605D4DE003AA9B6 /* LaunchScreen.storyboard in Resources */, - 4E492F3A2605D4DE003AA9B6 /* Assets.xcassets in Resources */, - 4E3C6FBB2764989900BFE4F8 /* test_model.ms.bin in Resources */, - 4E3C6FBC2764989900BFE4F8 /* test_model.ms.out in Resources */, - 4E492F382605D4DC003AA9B6 /* Main.storyboard in Resources */, - 4E3C6FBD2764989900BFE4F8 /* test_model.ms in Resources */, - ); - runOnlyForDeploymentPostprocessing = 0; - }; -/* End PBXResourcesBuildPhase section */ - -/* Begin PBXSourcesBuildPhase section */ - 4E492F262605D4DC003AA9B6 /* Sources */ = { - isa = PBXSourcesBuildPhase; - buildActionMask = 2147483647; - files = ( - 4E492F352605D4DC003AA9B6 /* ViewController.m in Sources */, - 4E492F402605D4DE003AA9B6 /* main.m in Sources */, - 4E492F4B2605D58F003AA9B6 /* AppDelegate.mm in Sources */, - 4E492F322605D4DC003AA9B6 /* SceneDelegate.m in Sources */, - 4E8D9A972612F60B00C7FBC1 /* Benchmark.cpp in Sources */, - ); - runOnlyForDeploymentPostprocessing = 0; - }; -/* End PBXSourcesBuildPhase section */ - -/* Begin PBXVariantGroup section */ - 4E492F362605D4DC003AA9B6 /* Main.storyboard */ = { - isa = PBXVariantGroup; - children = ( - 4E492F372605D4DC003AA9B6 /* Base */, - ); - name = Main.storyboard; - sourceTree = ""; - }; - 4E492F3B2605D4DE003AA9B6 /* LaunchScreen.storyboard */ = { - isa = PBXVariantGroup; - children = ( - 4E492F3C2605D4DE003AA9B6 /* Base */, - ); - name = LaunchScreen.storyboard; - sourceTree = ""; - }; -/* End PBXVariantGroup section */ - -/* Begin XCBuildConfiguration section */ - 4E492F412605D4DE003AA9B6 /* Debug */ = { - isa = XCBuildConfiguration; - buildSettings = { - ALWAYS_SEARCH_USER_PATHS = NO; - CLANG_ANALYZER_NONNULL = YES; - CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; - CLANG_CXX_LANGUAGE_STANDARD = "gnu++14"; - CLANG_CXX_LIBRARY = "libc++"; - CLANG_ENABLE_CPP_STATIC_DESTRUCTORS = NO; - CLANG_ENABLE_MODULES = YES; - CLANG_ENABLE_OBJC_ARC = YES; - CLANG_ENABLE_OBJC_WEAK = YES; - CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; - CLANG_WARN_BOOL_CONVERSION = YES; - CLANG_WARN_COMMA = YES; - CLANG_WARN_CONSTANT_CONVERSION = YES; - CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; - CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; - CLANG_WARN_DOCUMENTATION_COMMENTS = YES; - CLANG_WARN_EMPTY_BODY = YES; - CLANG_WARN_ENUM_CONVERSION = YES; - CLANG_WARN_INFINITE_RECURSION = YES; - CLANG_WARN_INT_CONVERSION = YES; - CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; - CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; - CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; - CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; - CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; - CLANG_WARN_STRICT_PROTOTYPES = YES; - CLANG_WARN_SUSPICIOUS_MOVE = YES; - CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; - CLANG_WARN_UNREACHABLE_CODE = YES; - CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; - COPY_PHASE_STRIP = NO; - DEBUG_INFORMATION_FORMAT = dwarf; - ENABLE_STRICT_OBJC_MSGSEND = YES; - ENABLE_TESTABILITY = YES; - GCC_C_LANGUAGE_STANDARD = gnu11; - GCC_DYNAMIC_NO_PIC = NO; - GCC_INLINES_ARE_PRIVATE_EXTERN = YES; - GCC_NO_COMMON_BLOCKS = YES; - GCC_OPTIMIZATION_LEVEL = 0; - GCC_PREPROCESSOR_DEFINITIONS = ( - "DEBUG=1", - "$(inherited)", - ); - GCC_SYMBOLS_PRIVATE_EXTERN = YES; - GCC_WARN_64_TO_32_BIT_CONVERSION = YES; - GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; - GCC_WARN_UNDECLARED_SELECTOR = YES; - GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; - GCC_WARN_UNUSED_FUNCTION = YES; - GCC_WARN_UNUSED_VARIABLE = YES; - IPHONEOS_DEPLOYMENT_TARGET = 13.4; - MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; - MTL_FAST_MATH = YES; - ONLY_ACTIVE_ARCH = YES; - SDKROOT = iphoneos; - VALID_ARCHS = "arm64 arm64e armv7 armv7s x86_64"; - }; - name = Debug; - }; - 4E492F422605D4DE003AA9B6 /* Release */ = { - isa = XCBuildConfiguration; - buildSettings = { - ALWAYS_SEARCH_USER_PATHS = NO; - CLANG_ANALYZER_NONNULL = YES; - CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; - CLANG_CXX_LANGUAGE_STANDARD = "gnu++14"; - CLANG_CXX_LIBRARY = "libc++"; - CLANG_ENABLE_CPP_STATIC_DESTRUCTORS = NO; - CLANG_ENABLE_MODULES = YES; - CLANG_ENABLE_OBJC_ARC = YES; - CLANG_ENABLE_OBJC_WEAK = YES; - CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; - CLANG_WARN_BOOL_CONVERSION = YES; - CLANG_WARN_COMMA = YES; - CLANG_WARN_CONSTANT_CONVERSION = YES; - CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; - CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; - CLANG_WARN_DOCUMENTATION_COMMENTS = YES; - CLANG_WARN_EMPTY_BODY = YES; - CLANG_WARN_ENUM_CONVERSION = YES; - CLANG_WARN_INFINITE_RECURSION = YES; - CLANG_WARN_INT_CONVERSION = YES; - CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; - CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; - CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; - CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; - CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; - CLANG_WARN_STRICT_PROTOTYPES = YES; - CLANG_WARN_SUSPICIOUS_MOVE = YES; - CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; - CLANG_WARN_UNREACHABLE_CODE = YES; - CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; - COPY_PHASE_STRIP = NO; - DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; - ENABLE_NS_ASSERTIONS = NO; - ENABLE_STRICT_OBJC_MSGSEND = YES; - GCC_C_LANGUAGE_STANDARD = gnu11; - GCC_INLINES_ARE_PRIVATE_EXTERN = YES; - GCC_NO_COMMON_BLOCKS = YES; - GCC_SYMBOLS_PRIVATE_EXTERN = YES; - GCC_WARN_64_TO_32_BIT_CONVERSION = YES; - GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; - GCC_WARN_UNDECLARED_SELECTOR = YES; - GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; - GCC_WARN_UNUSED_FUNCTION = YES; - GCC_WARN_UNUSED_VARIABLE = YES; - IPHONEOS_DEPLOYMENT_TARGET = 13.4; - MTL_ENABLE_DEBUG_INFO = NO; - MTL_FAST_MATH = YES; - SDKROOT = iphoneos; - VALIDATE_PRODUCT = YES; - VALID_ARCHS = "arm64 arm64e armv7 armv7s x86_64"; - }; - name = Release; - }; - 4E492F442605D4DE003AA9B6 /* Debug */ = { - isa = XCBuildConfiguration; - buildSettings = { - ARCHS = "$(ARCHS_STANDARD)"; - ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; - CLANG_ENABLE_CPP_STATIC_DESTRUCTORS = NO; - CODE_SIGN_STYLE = Automatic; - COPY_PHASE_STRIP = YES; - DEVELOPMENT_TEAM = B6BQP4K5VM; - ENABLE_BITCODE = NO; - FRAMEWORK_SEARCH_PATHS = ( - "$(inherited)", - "$(PROJECT_DIR)/mindspore-lite", - ); - FRAMEWORK_VERSION = C; - GCC_INCREASE_PRECOMPILED_HEADER_SHARING = NO; - INFOPLIST_EXPAND_BUILD_SETTINGS = YES; - INFOPLIST_FILE = "mindspore-lite/Info.plist"; - IPHONEOS_DEPLOYMENT_TARGET = 9.0; - LD_RUNPATH_SEARCH_PATHS = ( - "$(inherited)", - "@executable_path/Frameworks", - ); - ONLY_ACTIVE_ARCH = NO; - OTHER_LDFLAGS = "-all_load"; - PRODUCT_BUNDLE_IDENTIFIER = "mindspore.mindspore-lite"; - PRODUCT_NAME = "$(TARGET_NAME)"; - SCAN_ALL_SOURCE_FILES_FOR_INCLUDES = NO; - SUPPORTED_PLATFORMS = "iphonesimulator iphoneos"; - SYSTEM_HEADER_SEARCH_PATHS = "$(PROJECT_DIR)/mindspore-lite/mindspore-lite.framework/Headers"; - TARGETED_DEVICE_FAMILY = "1,2"; - USER_HEADER_SEARCH_PATHS = "$(PROJECT_DIR)/mindspore-lite/mindspore-lite.framework/Headers"; - VALIDATE_PRODUCT = YES; - VALID_ARCHS = arm64; - }; - name = Debug; - }; - 4E492F452605D4DE003AA9B6 /* Release */ = { - isa = XCBuildConfiguration; - buildSettings = { - ARCHS = "$(ARCHS_STANDARD)"; - ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; - CLANG_ENABLE_CPP_STATIC_DESTRUCTORS = NO; - CODE_SIGN_STYLE = Automatic; - COPY_PHASE_STRIP = YES; - DEVELOPMENT_TEAM = B6BQP4K5VM; - ENABLE_BITCODE = NO; - FRAMEWORK_SEARCH_PATHS = ( - "$(inherited)", - "$(PROJECT_DIR)/mindspore-lite", - ); - FRAMEWORK_VERSION = C; - GCC_INCREASE_PRECOMPILED_HEADER_SHARING = NO; - INFOPLIST_EXPAND_BUILD_SETTINGS = YES; - INFOPLIST_FILE = "mindspore-lite/Info.plist"; - IPHONEOS_DEPLOYMENT_TARGET = 9.0; - LD_RUNPATH_SEARCH_PATHS = ( - "$(inherited)", - "@executable_path/Frameworks", - ); - ONLY_ACTIVE_ARCH = NO; - OTHER_LDFLAGS = "-all_load"; - PRODUCT_BUNDLE_IDENTIFIER = "mindspore.mindspore-lite"; - PRODUCT_NAME = "$(TARGET_NAME)"; - SCAN_ALL_SOURCE_FILES_FOR_INCLUDES = NO; - SUPPORTED_PLATFORMS = "iphonesimulator iphoneos"; - SYSTEM_HEADER_SEARCH_PATHS = "$(PROJECT_DIR)/mindspore-lite/mindspore-lite.framework/Headers"; - TARGETED_DEVICE_FAMILY = "1,2"; - USER_HEADER_SEARCH_PATHS = "$(PROJECT_DIR)/mindspore-lite/mindspore-lite.framework/Headers"; - VALIDATE_PRODUCT = YES; - VALID_ARCHS = arm64; - }; - name = Release; - }; -/* End XCBuildConfiguration section */ - -/* Begin XCConfigurationList section */ - 4E492F252605D4DC003AA9B6 /* Build configuration list for PBXProject "mindspore-lite" */ = { - isa = XCConfigurationList; - buildConfigurations = ( - 4E492F412605D4DE003AA9B6 /* Debug */, - 4E492F422605D4DE003AA9B6 /* Release */, - ); - defaultConfigurationIsVisible = 0; - defaultConfigurationName = Release; - }; - 4E492F432605D4DE003AA9B6 /* Build configuration list for PBXNativeTarget "mindspore-lite" */ = { - isa = XCConfigurationList; - buildConfigurations = ( - 4E492F442605D4DE003AA9B6 /* Debug */, - 4E492F452605D4DE003AA9B6 /* Release */, - ); - defaultConfigurationIsVisible = 0; - defaultConfigurationName = Release; - }; -/* End XCConfigurationList section */ - }; - rootObject = 4E492F222605D4DC003AA9B6 /* Project object */; -} diff --git a/mindspore-lite/examples/quick_start_ios/mindspore-lite.xcodeproj/project.xcworkspace/contents.xcworkspacedata b/mindspore-lite/examples/quick_start_ios/mindspore-lite.xcodeproj/project.xcworkspace/contents.xcworkspacedata deleted file mode 100644 index 2424244d7c4acde0dcc6dc821f0ce492a9bfdc0c..0000000000000000000000000000000000000000 --- a/mindspore-lite/examples/quick_start_ios/mindspore-lite.xcodeproj/project.xcworkspace/contents.xcworkspacedata +++ /dev/null @@ -1,7 +0,0 @@ - - - - - diff --git a/mindspore-lite/examples/quick_start_ios/mindspore-lite.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist b/mindspore-lite/examples/quick_start_ios/mindspore-lite.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist deleted file mode 100644 index 18d981003d68d0546c4804ac2ff47dd97c6e7921..0000000000000000000000000000000000000000 --- a/mindspore-lite/examples/quick_start_ios/mindspore-lite.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +++ /dev/null @@ -1,8 +0,0 @@ - - - - - IDEDidComputeMac32BitWarning - - - diff --git a/mindspore-lite/examples/quick_start_ios/mindspore-lite.xcodeproj/project.xcworkspace/xcshareddata/WorkspaceSettings.xcsettings b/mindspore-lite/examples/quick_start_ios/mindspore-lite.xcodeproj/project.xcworkspace/xcshareddata/WorkspaceSettings.xcsettings deleted file mode 100644 index 6b30c7459cb23dc90f4bc2a2fb3228c58f9bfe02..0000000000000000000000000000000000000000 --- a/mindspore-lite/examples/quick_start_ios/mindspore-lite.xcodeproj/project.xcworkspace/xcshareddata/WorkspaceSettings.xcsettings +++ /dev/null @@ -1,10 +0,0 @@ - - - - - BuildSystemType - Original - PreviewsEnabled - - - diff --git a/mindspore-lite/examples/quick_start_ios/mindspore-lite.xcodeproj/project.xcworkspace/xcuserdata/ms.xcuserdatad/IDEFindNavigatorScopes.plist b/mindspore-lite/examples/quick_start_ios/mindspore-lite.xcodeproj/project.xcworkspace/xcuserdata/ms.xcuserdatad/IDEFindNavigatorScopes.plist deleted file mode 100644 index 5dd5da85fdbd81ad600c193382e3305209b9e392..0000000000000000000000000000000000000000 --- a/mindspore-lite/examples/quick_start_ios/mindspore-lite.xcodeproj/project.xcworkspace/xcuserdata/ms.xcuserdatad/IDEFindNavigatorScopes.plist +++ /dev/null @@ -1,5 +0,0 @@ - - - - - diff --git a/mindspore-lite/examples/quick_start_ios/mindspore-lite.xcodeproj/project.xcworkspace/xcuserdata/ms.xcuserdatad/UserInterfaceState.xcuserstate b/mindspore-lite/examples/quick_start_ios/mindspore-lite.xcodeproj/project.xcworkspace/xcuserdata/ms.xcuserdatad/UserInterfaceState.xcuserstate deleted file mode 100644 index 4c0d09407145169223d0305f9c79e97c4b668aae..0000000000000000000000000000000000000000 Binary files a/mindspore-lite/examples/quick_start_ios/mindspore-lite.xcodeproj/project.xcworkspace/xcuserdata/ms.xcuserdatad/UserInterfaceState.xcuserstate and /dev/null differ diff --git a/mindspore-lite/examples/quick_start_ios/mindspore-lite.xcodeproj/project.xcworkspace/xcuserdata/ms.xcuserdatad/WorkspaceSettings.xcsettings b/mindspore-lite/examples/quick_start_ios/mindspore-lite.xcodeproj/project.xcworkspace/xcuserdata/ms.xcuserdatad/WorkspaceSettings.xcsettings deleted file mode 100644 index 379adbed5a64c54e605e0a7b5ec7b19fc20e53a4..0000000000000000000000000000000000000000 --- a/mindspore-lite/examples/quick_start_ios/mindspore-lite.xcodeproj/project.xcworkspace/xcuserdata/ms.xcuserdatad/WorkspaceSettings.xcsettings +++ /dev/null @@ -1,18 +0,0 @@ - - - - - BuildLocationStyle - UseAppPreferences - CustomBuildLocationType - RelativeToDerivedData - DerivedDataLocationStyle - Default - IssueFilterStyle - ShowActiveSchemeOnly - LiveSourceIssuesEnabled - - ShowSharedSchemesAutomaticallyEnabled - - - diff --git a/mindspore-lite/examples/quick_start_ios/mindspore-lite.xcodeproj/xcuserdata/ms.xcuserdatad/xcdebugger/Breakpoints_v2.xcbkptlist b/mindspore-lite/examples/quick_start_ios/mindspore-lite.xcodeproj/xcuserdata/ms.xcuserdatad/xcdebugger/Breakpoints_v2.xcbkptlist deleted file mode 100644 index 69906e31fe22d385fab93cd75a6a22e63f357c9f..0000000000000000000000000000000000000000 --- a/mindspore-lite/examples/quick_start_ios/mindspore-lite.xcodeproj/xcuserdata/ms.xcuserdatad/xcdebugger/Breakpoints_v2.xcbkptlist +++ /dev/null @@ -1,40 +0,0 @@ - - - - - - - - - - - - - diff --git a/mindspore-lite/examples/quick_start_ios/mindspore-lite.xcodeproj/xcuserdata/ms.xcuserdatad/xcschemes/xcschememanagement.plist b/mindspore-lite/examples/quick_start_ios/mindspore-lite.xcodeproj/xcuserdata/ms.xcuserdatad/xcschemes/xcschememanagement.plist deleted file mode 100644 index 883ded4f382819e39a7e9fe5e0adca73e1f27af8..0000000000000000000000000000000000000000 --- a/mindspore-lite/examples/quick_start_ios/mindspore-lite.xcodeproj/xcuserdata/ms.xcuserdatad/xcschemes/xcschememanagement.plist +++ /dev/null @@ -1,14 +0,0 @@ - - - - - SchemeUserState - - mindspore-lite.xcscheme_^#shared#^_ - - orderHint - 0 - - - - diff --git a/mindspore-lite/examples/quick_start_ios/mindspore-lite/Assets.xcassets/AppIcon.appiconset/Contents.json b/mindspore-lite/examples/quick_start_ios/mindspore-lite/Assets.xcassets/AppIcon.appiconset/Contents.json deleted file mode 100644 index 9221b9bb1a35f5de270a41afa01305478221ae32..0000000000000000000000000000000000000000 --- a/mindspore-lite/examples/quick_start_ios/mindspore-lite/Assets.xcassets/AppIcon.appiconset/Contents.json +++ /dev/null @@ -1,98 +0,0 @@ -{ - "images" : [ - { - "idiom" : "iphone", - "scale" : "2x", - "size" : "20x20" - }, - { - "idiom" : "iphone", - "scale" : "3x", - "size" : "20x20" - }, - { - "idiom" : "iphone", - "scale" : "2x", - "size" : "29x29" - }, - { - "idiom" : "iphone", - "scale" : "3x", - "size" : "29x29" - }, - { - "idiom" : "iphone", - "scale" : "2x", - "size" : "40x40" - }, - { - "idiom" : "iphone", - "scale" : "3x", - "size" : "40x40" - }, - { - "idiom" : "iphone", - "scale" : "2x", - "size" : "60x60" - }, - { - "idiom" : "iphone", - "scale" : "3x", - "size" : "60x60" - }, - { - "idiom" : "ipad", - "scale" : "1x", - "size" : "20x20" - }, - { - "idiom" : "ipad", - "scale" : "2x", - "size" : "20x20" - }, - { - "idiom" : "ipad", - "scale" : "1x", - "size" : "29x29" - }, - { - "idiom" : "ipad", - "scale" : "2x", - "size" : "29x29" - }, - { - "idiom" : "ipad", - "scale" : "1x", - "size" : "40x40" - }, - { - "idiom" : "ipad", - "scale" : "2x", - "size" : "40x40" - }, - { - "idiom" : "ipad", - "scale" : "1x", - "size" : "76x76" - }, - { - "idiom" : "ipad", - "scale" : "2x", - "size" : "76x76" - }, - { - "idiom" : "ipad", - "scale" : "2x", - "size" : "83.5x83.5" - }, - { - "idiom" : "ios-marketing", - "scale" : "1x", - "size" : "1024x1024" - } - ], - "info" : { - "author" : "xcode", - "version" : 1 - } -} diff --git a/mindspore-lite/examples/quick_start_ios/mindspore-lite/Assets.xcassets/Contents.json b/mindspore-lite/examples/quick_start_ios/mindspore-lite/Assets.xcassets/Contents.json deleted file mode 100644 index 73c00596a7fca3f3d4bdd64053b69d86745f9e10..0000000000000000000000000000000000000000 --- a/mindspore-lite/examples/quick_start_ios/mindspore-lite/Assets.xcassets/Contents.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "info" : { - "author" : "xcode", - "version" : 1 - } -} diff --git a/mindspore-lite/examples/quick_start_ios/mindspore-lite/Base.lproj/LaunchScreen.storyboard b/mindspore-lite/examples/quick_start_ios/mindspore-lite/Base.lproj/LaunchScreen.storyboard deleted file mode 100644 index 865e9329f3767a7c1dd66294b8025bf81dee7d2c..0000000000000000000000000000000000000000 --- a/mindspore-lite/examples/quick_start_ios/mindspore-lite/Base.lproj/LaunchScreen.storyboard +++ /dev/null @@ -1,25 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/mindspore-lite/examples/quick_start_ios/mindspore-lite/Base.lproj/Main.storyboard b/mindspore-lite/examples/quick_start_ios/mindspore-lite/Base.lproj/Main.storyboard deleted file mode 100644 index 808a21ce779bae61839ac1803bc4e2c854578f5e..0000000000000000000000000000000000000000 --- a/mindspore-lite/examples/quick_start_ios/mindspore-lite/Base.lproj/Main.storyboard +++ /dev/null @@ -1,24 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/mindspore-lite/examples/quick_start_ios/mindspore-lite/Benchmark.hpp b/mindspore-lite/examples/quick_start_ios/mindspore-lite/Benchmark.hpp deleted file mode 100644 index 78eeb19c8569588b7eb67f4e9cc360c4f65dbe74..0000000000000000000000000000000000000000 --- a/mindspore-lite/examples/quick_start_ios/mindspore-lite/Benchmark.hpp +++ /dev/null @@ -1,276 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef BENCHMARK_H_ -#define BENCHMARK_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "mindspore-lite/include/model.h" -#include "mindspore-lite/include/lite_session.h" -#include "mindspore-lite/include/errorcode.h" - -constexpr float kNumRelativeTolerance = 1e-5; -constexpr float kNumAbsoluteTolerance = 1e-8; -constexpr float kNumMinAbsoluteError = 1e-5; -constexpr float kNumDefaultLoopCount = 10; -constexpr float kNumDefaultNumThread = 2; -constexpr float kNumDefaultWarmUpLoopCount = 3; -constexpr float kNumDefaultAccuracyThreshold = 0.5; -constexpr float kNumDefaultPrintLen = 50; -constexpr float kNumPercentage = 100; -constexpr float kNumMinMeanError = 0.0000001; - -enum InDataType { kImage = 0, kBinary = 1 }; -namespace mindspore::lite { -#ifdef ENABLE_ARM64 -struct PerfResult { - int64_t nr; - struct { - int64_t value; - int64_t id; - } values[2]; -}; -struct PerfCount { - int64_t value[2]; -}; -#endif - -struct CheckTensor { - CheckTensor(const std::vector &shape, const std::vector &data, - const std::vector &strings_data = {""}) { - this->shape = shape; - this->data = data; - this->strings_data = strings_data; - } - std::vector shape; - std::vector data; - std::vector strings_data; -}; - -struct Flag { - std::string model_file_; - std::string in_data_file_; - std::vector input_data_list_; - InDataType in_data_type_ = kBinary; - std::string in_data_type_in_ = "bin"; - int cpu_bind_mode_ = HIGHER_CPU; - // MarkPerformance - int loop_count_ = kNumDefaultLoopCount; - int num_threads_ = kNumDefaultNumThread; - bool enable_fp16_ = false; - int warm_up_loop_count_ = kNumDefaultWarmUpLoopCount; - // MarkAccuracy - std::string benchmark_data_file_; - std::string benchmark_data_type_ = "FLOAT"; - float accuracy_threshold_ = kNumDefaultAccuracyThreshold; - // Resize - std::string resize_dims_in_; - std::vector> resize_dims_; - - std::string device_ = "CPU"; - bool time_profiling_ = false; - bool perf_profiling_ = false; - std::string perf_event_ = "CYCLE"; - bool dump_profiling_ = false; -}; - -class Benchmark { - public: - explicit Benchmark(Flag *flag) : flags_(flag) {} - - virtual ~Benchmark(); - - int Init(); - int RunBenchmark(); - - private: - // call GenerateInputData or ReadInputFile to init inputTensors - int LoadInput(); - - // call GenerateRandomData to fill inputTensors - int GenerateInputData(); - - int GenerateRandomData(size_t size, void *data, TypeId data_type); - - int ReadInputFile(); - - int ReadCalibData(); - - int ReadTensorData(std::ifstream &in_file_stream, const std::string &tensor_name, const std::vector &dims); - - int CompareOutput(); - - tensor::MSTensor *GetTensorByNameOrShape(const std::string &node_or_tensor_name, const std::vector &dims); - - tensor::MSTensor *GetTensorByNodeShape(const std::vector &node_shape); - - int CompareStringData(const std::string &name, tensor::MSTensor *tensor); - - int CompareDataGetTotalBiasAndSize(const std::string &name, tensor::MSTensor *tensor, float *total_bias, - int *total_size); - - int InitCallbackParameter(); - - int InitTimeProfilingCallbackParameter(); - - int InitPerfProfilingCallbackParameter(); - - int InitDumpProfilingCallbackParameter(); - - int PrintResult(const std::vector &title, const std::map> &result); - - void InitContext(const std::shared_ptr &context); - -#ifdef ENABLE_ARM64 - int PrintPerfResult(const std::vector &title, - const std::map> &result); -#endif - - int PrintInputData(); - - // tensorData need to be converter first - template - float CompareData(const std::string &nodeName, const std::vector &msShape, const void *tensor_data) { - const T *msTensorData = static_cast(tensor_data); - auto iter = this->benchmark_data_.find(nodeName); - if (iter != this->benchmark_data_.end()) { - std::vector castedMSShape; - size_t shapeSize = 1; - for (int64_t dim : msShape) { - castedMSShape.push_back(size_t(dim)); - shapeSize *= dim; - } - - CheckTensor *calibTensor = iter->second; - if (calibTensor->shape != castedMSShape) { - std::cout << "Shape of mslite output("; - for (auto dim : castedMSShape) { - std::cout << dim << ","; - } - std::cout << ") and shape source model output("; - for (auto dim : calibTensor->shape) { - std::cout << dim << ","; - } - std::cout << ") are different"; - return RET_ERROR; - } - size_t errorCount = 0; - float meanError = 0; - std::cout << "Data of node " << nodeName << " : "; - for (size_t j = 0; j < shapeSize; j++) { - if (j < kNumDefaultPrintLen) { - std::cout << static_cast(msTensorData[j]) << " "; - } - - if (std::isnan(msTensorData[j]) || std::isinf(msTensorData[j])) { - std::cerr << "Output tensor has nan or inf data, compare fail" << std::endl; - return RET_ERROR; - } - - auto tolerance = kNumAbsoluteTolerance + kNumRelativeTolerance * fabs(calibTensor->data.at(j)); - auto absoluteError = std::fabs(msTensorData[j] - calibTensor->data.at(j)); - if (absoluteError > tolerance) { - if (fabs(calibTensor->data.at(j) - 0.0f) < FLT_EPSILON) { - if (absoluteError > kNumMinAbsoluteError) { - meanError += absoluteError; - errorCount++; - } else { - continue; - } - } else { - // just assume that atol = rtol - meanError += absoluteError / (fabs(calibTensor->data.at(j)) + FLT_MIN); - errorCount++; - } - } - } - std::cout << std::endl; - if (meanError > 0.0f) { - meanError /= errorCount; - } - - if (meanError <= kNumMinMeanError) { - std::cout << "Mean bias of node/tensor " << nodeName << " : 0%" << std::endl; - } else { - std::cout << "Mean bias of node/tensor " << nodeName << " : " << meanError * kNumPercentage << "%" << std::endl; - } - return meanError; - } else { - std::cout << "%s is not in Source Model output" << nodeName.c_str(); - return RET_ERROR; - } - } - - template - void FillInputData(int size, void *data, Distribution distribution) { - if (data == nullptr) { - std::cout << "data is nullptr."; - return; - } - int elements_num = size / sizeof(T); - (void)std::generate_n(static_cast(data), elements_num, - [&]() { return static_cast(distribution(random_engine_)); }); - } - - int MarkPerformance(); - - int MarkAccuracy(); - - private: - struct Flag *flags_; - session::LiteSession *session_{nullptr}; - std::vector ms_inputs_; - std::unordered_map> ms_outputs_; - std::unordered_map benchmark_data_; - std::unordered_map data_type_map_{{"FLOAT", TypeId::kNumberTypeFloat}, - {"INT8", TypeId::kNumberTypeInt8}, - {"INT32", TypeId::kNumberTypeInt32}, - {"UINT8", TypeId::kNumberTypeUInt8}}; - TypeId msCalibDataType = TypeId::kNumberTypeFloat; - - // callback parameters - uint64_t op_begin_ = 0; - int op_call_times_total_ = 0; - float op_cost_total_ = 0.0f; - std::map> op_times_by_type_; - std::map> op_times_by_name_; -#ifdef ENABLE_ARM64 - int perf_fd = 0; - int perf_fd2 = 0; - float op_cost2_total_ = 0.0f; - std::map> op_perf_by_type_; - std::map> op_perf_by_name_; -#endif - KernelCallBack before_call_back_ = nullptr; - KernelCallBack after_call_back_ = nullptr; - std::mt19937 random_engine_; -}; -int RunBenchmark(Flag *flags); -int main_benchmark(); -} // namespace mindspore::lite -#endif // MINNIE_BENCHMARK_BENCHMARK_H_ diff --git a/mindspore-lite/examples/quick_start_ios/mindspore-lite/Info.plist b/mindspore-lite/examples/quick_start_ios/mindspore-lite/Info.plist deleted file mode 100644 index ca4ce4bab14d5e159b8e37156420414f08f0bcfd..0000000000000000000000000000000000000000 --- a/mindspore-lite/examples/quick_start_ios/mindspore-lite/Info.plist +++ /dev/null @@ -1,60 +0,0 @@ - - - - - CFBundleDevelopmentRegion - $(DEVELOPMENT_LANGUAGE) - CFBundleExecutable - $(EXECUTABLE_NAME) - CFBundleIdentifier - $(PRODUCT_BUNDLE_IDENTIFIER) - CFBundleInfoDictionaryVersion - 6.0 - CFBundleName - $(PRODUCT_NAME) - CFBundlePackageType - $(PRODUCT_BUNDLE_PACKAGE_TYPE) - CFBundleShortVersionString - 1.0 - CFBundleVersion - 1 - LSRequiresIPhoneOS - - UIApplicationSceneManifest - - UIApplicationSupportsMultipleScenes - - UISceneConfigurations - - UIWindowSceneSessionRoleApplication - - - UISceneConfigurationName - Default Configuration - UISceneDelegateClassName - SceneDelegate - UISceneStoryboardFile - Main - - - - - UILaunchStoryboardName - LaunchScreen - UIMainStoryboardFile - Main - UIRequiredDeviceCapabilities - - armv7 - - UISupportedInterfaceOrientations - - UISupportedInterfaceOrientations~ipad - - UIInterfaceOrientationPortrait - UIInterfaceOrientationPortraitUpsideDown - UIInterfaceOrientationLandscapeLeft - UIInterfaceOrientationLandscapeRight - - - diff --git a/mindspore-lite/examples/quick_start_ios/mindspore-lite/SceneDelegate.m b/mindspore-lite/examples/quick_start_ios/mindspore-lite/SceneDelegate.m deleted file mode 100644 index ce197af4322ab580b212b5832b2485e443df875f..0000000000000000000000000000000000000000 --- a/mindspore-lite/examples/quick_start_ios/mindspore-lite/SceneDelegate.m +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#import "SceneDelegate.h" - -@interface SceneDelegate () - -@end - -@implementation SceneDelegate - - -- (void)scene:(UIScene *)scene willConnectToSession:(UISceneSession *)session options:(UISceneConnectionOptions *)connectionOptions { - // Use this method to optionally configure and attach the UIWindow `window` to the provided UIWindowScene `scene`. - // If using a storyboard, the `window` property will automatically be initialized and attached to the scene. - // This delegate does not imply the connecting scene or session are new (see `application:configurationForConnectingSceneSession` instead). -} - - -- (void)sceneDidDisconnect:(UIScene *)scene { - // Called as the scene is being released by the system. - // This occurs shortly after the scene enters the background, or when its session is discarded. - // Release any resources associated with this scene that can be re-created the next time the scene connects. - // The scene may re-connect later, as its session was not necessarily discarded (see `application:didDiscardSceneSessions` instead). -} - - -- (void)sceneDidBecomeActive:(UIScene *)scene { - // Called when the scene has moved from an inactive state to an active state. - // Use this method to restart any tasks that were paused (or not yet started) when the scene was inactive. -} - - -- (void)sceneWillResignActive:(UIScene *)scene { - // Called when the scene will move from an active state to an inactive state. - // This may occur due to temporary interruptions (ex. an incoming phone call). -} - - -- (void)sceneWillEnterForeground:(UIScene *)scene { - // Called as the scene transitions from the background to the foreground. - // Use this method to undo the changes made on entering the background. -} - - -- (void)sceneDidEnterBackground:(UIScene *)scene { - // Called as the scene transitions from the foreground to the background. - // Use this method to save data, release shared resources, and store enough scene-specific state information - // to restore the scene back to its current state. -} - - -@end diff --git a/mindspore-lite/java/native/CMakeLists.txt b/mindspore-lite/java/native/CMakeLists.txt index bb389e811117a693be8b3e41e80ae4eb39dd95f6..5fb1439818ceaeda45d72b3a3a05fc694f555c01 100644 --- a/mindspore-lite/java/native/CMakeLists.txt +++ b/mindspore-lite/java/native/CMakeLists.txt @@ -21,7 +21,7 @@ endif() if(CMAKE_BUILD_TYPE) if("${CMAKE_BUILD_TYPE}" STREQUAL "Release") - message("build release") + message("build release") set(CMAKE_CXX_VISIBILITY_PRESET hidden) set(CMAKE_C_VISIBILITY_PRESET hidden) endif() @@ -181,9 +181,9 @@ else() target_link_libraries(mindspore-lite-jni ${LITE_SO_NAME}) if(USE_GLOG) find_library(GLOG_LIB - NAMES mindspore_glog libmindspore_glog.so.0 - PATHS ${PLATFORM_DIR} - ) + NAMES mindspore_glog libmindspore_glog.so.0 + PATHS ${PLATFORM_DIR} + ) if(GLOG_LIB) message("Found glog lib :${GLOG_LIB}") else() @@ -200,7 +200,6 @@ if(SUPPORT_TRAIN) ) if(USE_GLOG) - set_property(SOURCE ${CCSRC_SRC} PROPERTY COMPILE_DEFINITIONS LOG_HDR_FILE_REL_PATH="mindspore-lite/../mindspore/mindspore/core/include/utils/log_adapter.h" SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) diff --git a/mindspore-lite/java/src/main/java/com/mindspore/config/NativeLibrary.java b/mindspore-lite/java/src/main/java/com/mindspore/config/NativeLibrary.java index 23fa8d86f8e8dd805e4f9a0926982a266bffe8b7..8045852129c7c7f0cd84652b75a495f79f85d25a 100644 --- a/mindspore-lite/java/src/main/java/com/mindspore/config/NativeLibrary.java +++ b/mindspore-lite/java/src/main/java/com/mindspore/config/NativeLibrary.java @@ -54,7 +54,7 @@ public class NativeLibrary { private static final String ASCEND_GE_PLUGIN_LIBNAME = "ascend_ge_plugin"; private static final String ASCEND_PASS_PLUGIN_LIBNAME = "ascend_pass_plugin"; private static final String TENSORRT_PLUGIN_LIBNAME = "tensorrt_plugin"; - private static final String MSLITE_SHARED_LIB_LIBNAME = "mslite_shared_lib"; +// private static final String MSLITE_SHARED_LIB_LIBNAME = "mslite_shared_lib"; private static final String TRANSFORMER_SHARED_LIB_LIBNAME = "transformer-shared"; private static final String MINDSPORE_GRAPH_IR_LIBNAME = "mindspore_graph_ir"; private static Long timestamp = null; @@ -162,7 +162,7 @@ public class NativeLibrary { extractLib(makeResourceName("lib" + TRANSFORMER_SHARED_LIB_LIBNAME + ".so"), tmpDir); extractLib(makeResourceName("lib" + TENSORRT_PLUGIN_LIBNAME + ".so"), tmpDir); } else if (("lib" + MINDSPORE_CONVERTER_LIBNAME + ".so").equals(libName)) { - extractLib(makeResourceName("lib" + MSLITE_SHARED_LIB_LIBNAME + ".so"), tmpDir); +// extractLib(makeResourceName("lib" + MSLITE_SHARED_LIB_LIBNAME + ".so"), tmpDir); extractLib(makeResourceName("lib" + ASCEND_PASS_PLUGIN_LIBNAME + ".so"), tmpDir); extractLib(makeResourceName("lib" + MINDSPORE_GRAPH_IR_LIBNAME + ".so"), tmpDir); } diff --git a/mindspore-lite/minddata/dataset/api/vision.cc b/mindspore-lite/minddata/dataset/api/vision.cc index 72cdbd83def0520b467a3507beb7458ab5da3f53..7bd5df1cab03bed6d91e8e7f421af382f789bdf7 100644 --- a/mindspore-lite/minddata/dataset/api/vision.cc +++ b/mindspore-lite/minddata/dataset/api/vision.cc @@ -93,7 +93,7 @@ #include "mindspore-lite/minddata/dataset/kernels/ir/vision/vertical_flip_ir.h" #include "mindspore-lite/minddata/dataset/util/log_adapter.h" -#if !defined(ENABLE_ANDROID) || defined(ENABLE_MINDDATA_PYTHON) +#if !defined(ENABLE_ANDROID) #include "mindspore-lite/minddata/dataset/kernels/ir/vision/pad_ir.h" #include "mindspore-lite/minddata/dataset/kernels/ir/vision/rescale_ir.h" #include "mindspore-lite/minddata/dataset/kernels/ir/vision/swap_red_blue_ir.h" @@ -668,7 +668,7 @@ Pad::Pad(const std::vector &padding, const std::vector &fill_v : data_(std::make_shared(padding, fill_value, padding_mode)) {} std::shared_ptr Pad::Parse() { -#if !defined(ENABLE_ANDROID) || defined(ENABLE_MINDDATA_PYTHON) +#if !defined(ENABLE_ANDROID) return std::make_shared(data_->padding_, data_->fill_value_, data_->padding_mode_); #else MS_LOG(ERROR) << "Unsupported Pad."; @@ -1261,7 +1261,7 @@ struct Rescale::Data { Rescale::Rescale(float rescale, float shift) : data_(std::make_shared(rescale, shift)) {} std::shared_ptr Rescale::Parse() { -#if !defined(ENABLE_ANDROID) || defined(ENABLE_MINDDATA_PYTHON) +#if !defined(ENABLE_ANDROID) return std::make_shared(data_->rescale_, data_->shift_); #else MS_LOG(ERROR) << "Unsupported Rescale."; @@ -1443,7 +1443,7 @@ std::shared_ptr Solarize::Parse() { return std::make_shared SwapRedBlue::Parse() { -#if !defined(ENABLE_ANDROID) || defined(ENABLE_MINDDATA_PYTHON) +#if !defined(ENABLE_ANDROID) return std::make_shared(); #else MS_LOG(ERROR) << "Unsupported SwapRedBlue."; diff --git a/mindspore-lite/minddata/dataset/core/data_type.cc b/mindspore-lite/minddata/dataset/core/data_type.cc index 2994b174d2355aa816ecf8837e413d46cee9f0fd..48bf54e20b71eb7658ce0607cb99f88631a607bb 100644 --- a/mindspore-lite/minddata/dataset/core/data_type.cc +++ b/mindspore-lite/minddata/dataset/core/data_type.cc @@ -14,10 +14,9 @@ * limitations under the License. */ #include "mindspore-lite/minddata/dataset/core/data_type.h" -#ifdef ENABLE_MINDDATA_PYTHON +#ifdef ENABLE_CLOUD_FUSION_INFERENCE #include "mindspore-lite/minddata/dataset/core/pybind_support.h" #endif - #include "mindspore-lite/minddata/dataset/util/log_adapter.h" namespace mindspore { @@ -30,7 +29,7 @@ uint8_t DataType::SizeInBytes() const { } } -#ifdef ENABLE_MINDDATA_PYTHON +#ifdef ENABLE_CLOUD_FUSION_INFERENCE py::dtype DataType::AsNumpyType() const { if (type_ < DataType::NUM_OF_TYPES) { return py::dtype(kTypeInfo[type_].pybindType_); @@ -40,7 +39,7 @@ py::dtype DataType::AsNumpyType() const { } #endif -#if !defined(ENABLE_ANDROID) || defined(ENABLE_MINDDATA_PYTHON) +#if !defined(ENABLE_ANDROID) || defined(ENABLE_CLOUD_FUSION_INFERENCE) uint8_t DataType::AsCVType() const { uint8_t res = kCVInvalidType; if (type_ < DataType::NUM_OF_TYPES) { @@ -117,7 +116,7 @@ DataType::DataType(const std::string &type_str) { type_ = DE_STRING; } else if (type_str == "bytes") { type_ = DE_BYTES; -#ifdef ENABLE_MINDDATA_PYTHON +#ifdef ENABLE_CLOUD_FUSION_INFERENCE } else if (type_str == "python") { type_ = DE_PYTHON; #endif @@ -134,7 +133,7 @@ std::string DataType::ToString() const { } } -#ifdef ENABLE_MINDDATA_PYTHON +#ifdef ENABLE_CLOUD_FUSION_INFERENCE DataType DataType::FromNpArray(const py::array &arr) { if (py::isinstance>(arr)) { return DataType(DataType::DE_BOOL); diff --git a/mindspore-lite/minddata/dataset/core/data_type.h b/mindspore-lite/minddata/dataset/core/data_type.h index 59501632034cf105dc3973cdd8b2f021155b9dc6..06621384a589de47c4cf76ac6c1d43d106725163 100644 --- a/mindspore-lite/minddata/dataset/core/data_type.h +++ b/mindspore-lite/minddata/dataset/core/data_type.h @@ -16,14 +16,14 @@ #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_DATA_TYPE_H_ #define MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_DATA_TYPE_H_ -#if !defined(ENABLE_ANDROID) || defined(ENABLE_MINDDATA_PYTHON) +#if !defined(ENABLE_ANDROID) || defined(ENABLE_CLOUD_FUSION_INFERENCE) #include #endif #include #include -#ifdef ENABLE_MINDDATA_PYTHON +#ifdef ENABLE_CLOUD_FUSION_INFERENCE #include "pybind11/numpy.h" #include "pybind11/pybind11.h" #include "mindspore-lite/minddata/dataset/core/pybind_support.h" @@ -67,7 +67,7 @@ class DataType { const uint8_t cvType_; // OpenCv matching type }; -#ifdef ENABLE_MINDDATA_PYTHON +#ifdef ENABLE_CLOUD_FUSION_INFERENCE static inline const TypeInfo kTypeInfo[] = { // name, sizeInBytes, pybindType, pybindFormatDescriptor, openCV {"unknown", 0, "object", "", kCVInvalidType}, // DE_UNKNOWN @@ -88,7 +88,7 @@ class DataType { {"python", 0, "object", "O", kCVInvalidType} // DE_PYTHON }; #else -#if !defined(ENABLE_ANDROID) || defined(ENABLE_MINDDATA_PYTHON) +#if !defined(ENABLE_ANDROID) || defined(ENABLE_CLOUD_FUSION_INFERENCE) static inline const TypeInfo kTypeInfo[] = { // name, sizeInBytes, pybindTypem formatDescriptor, openCV {"unknown", 0, "object", "", kCVInvalidType}, // DE_UNKNOWN @@ -165,7 +165,7 @@ class DataType { /// \return the number of bytes of the type. uint8_t SizeInBytes() const; -#if !defined(ENABLE_ANDROID) || defined(ENABLE_MINDDATA_PYTHON) +#if !defined(ENABLE_ANDROID) || defined(ENABLE_CLOUD_FUSION_INFERENCE) // Convert from DataType to OpenCV type /// \return uint8_t AsCVType() const; @@ -207,7 +207,7 @@ class DataType { template static DataType FromCType(); -#ifdef ENABLE_MINDDATA_PYTHON +#ifdef ENABLE_CLOUD_FUSION_INFERENCE // Convert from DataType to Pybind type /// \return py::dtype AsNumpyType() const; diff --git a/mindspore-lite/minddata/dataset/core/device_resource.h b/mindspore-lite/minddata/dataset/core/device_resource.h index 95e13c4f3de75bc20bd9b4503ba8cc484e2ccf72..013bdc50818296a90b98cc3c782bbd7ef97de2e4 100644 --- a/mindspore-lite/minddata/dataset/core/device_resource.h +++ b/mindspore-lite/minddata/dataset/core/device_resource.h @@ -52,4 +52,4 @@ class DeviceResource { } // namespace dataset } // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_DEVICE_RESOURCE_H_ +#endif // MINDSPORE_DEVICE_RESOURCE_H diff --git a/mindspore-lite/minddata/dataset/core/global_context.cc b/mindspore-lite/minddata/dataset/core/global_context.cc index e2f760610099c52f7b726f80e069a64b97c9863c..1e8abb6d7bcc46bcf2900ef425dd82eb23a4156b 100644 --- a/mindspore-lite/minddata/dataset/core/global_context.cc +++ b/mindspore-lite/minddata/dataset/core/global_context.cc @@ -59,7 +59,7 @@ Status GlobalContext::Init() { // Create some tensor allocators for the different types and hook them into the pool. tensor_allocator_ = std::make_unique>(mem_pool_); -#if !defined(ENABLE_ANDROID) || defined(ENABLE_MINDDATA_PYTHON) +#if !defined(ENABLE_ANDROID) cv_tensor_allocator_ = std::make_unique>(mem_pool_); #endif device_tensor_allocator_ = std::make_unique>(mem_pool_); diff --git a/mindspore-lite/minddata/dataset/kernels/image/dvpp/CMakeLists.txt b/mindspore-lite/minddata/dataset/kernels/image/dvpp/CMakeLists.txt index 61c7be606b634f3e4808effef39d63ccf60d8175..edcac64f46a19e04ff9017f31ce7e1acb8924bf2 100644 --- a/mindspore-lite/minddata/dataset/kernels/image/dvpp/CMakeLists.txt +++ b/mindspore-lite/minddata/dataset/kernels/image/dvpp/CMakeLists.txt @@ -15,37 +15,6 @@ set(DVPP_IMAGE_SOURCE acl_adapter.cc ) -if(NOT BUILD_LITE AND ENABLE_D) - include(${CMAKE_SOURCE_DIR}/cmake/graphengine_variables.cmake) -set(DVPP_IMAGE_SOURCE - ${DVPP_IMAGE_SOURCE} - # Ascend910B - ascend910b/dvpp_adjust_brightness_op.cc - ascend910b/dvpp_adjust_contrast_op.cc - ascend910b/dvpp_adjust_hue_op.cc - ascend910b/dvpp_adjust_saturation_op.cc - ascend910b/dvpp_adjust_sharpness_op.cc - ascend910b/dvpp_affine_op.cc - ascend910b/dvpp_auto_contrast_op.cc - ascend910b/dvpp_crop_op.cc - ascend910b/dvpp_convert_color_op.cc - ascend910b/dvpp_decode_op.cc - ascend910b/dvpp_equalize_op.cc - ascend910b/dvpp_erase_op.cc - ascend910b/dvpp_gaussian_blur_op.cc - ascend910b/dvpp_horizontal_flip_op.cc - ascend910b/dvpp_invert_op.cc - ascend910b/dvpp_normalize_v2_op.cc - ascend910b/dvpp_pad_op.cc - ascend910b/dvpp_perspective_op.cc - ascend910b/dvpp_posterize_op.cc - ascend910b/dvpp_resize_op.cc - ascend910b/dvpp_resized_crop_op.cc - ascend910b/dvpp_rotate_op.cc - ascend910b/dvpp_solarize_op.cc - ascend910b/dvpp_vertical_flip_op.cc - ) -endif() add_library(kernels-dvpp-image OBJECT ${DVPP_IMAGE_SOURCE}) if(ENABLE_ACL OR MSLITE_ENABLE_ACL) diff --git a/mindspore-lite/minddata/dataset/kernels/image/dvpp/utils/CMakeLists.txt b/mindspore-lite/minddata/dataset/kernels/image/dvpp/utils/CMakeLists.txt index 57190636c60c7c6a6972372835f865f1abc2af86..d234b9dd97e1c62a3dd3ef330a48520182cfbddb 100644 --- a/mindspore-lite/minddata/dataset/kernels/image/dvpp/utils/CMakeLists.txt +++ b/mindspore-lite/minddata/dataset/kernels/image/dvpp/utils/CMakeLists.txt @@ -21,27 +21,20 @@ if(NOT MSLITE_ENABLE_ACL) ) endif() -if(NOT BUILD_LITE AND ENABLE_D) -set(DVPP_UTILS_SRC - ${DVPP_UTILS_SRC} - # Ascend910B - dvpp_image_utils.cc - ) -endif() - add_library(dvpp_utils SHARED ${DVPP_UTILS_SRC}) enable_target_when_only_build_plugins(dvpp_utils) -if(MSLITE_ENABLE_ACL) - find_library(acl_dvpp libacl_dvpp.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - find_library(acl libascendcl.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - find_library(nnopbase libnnopbase.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - find_library(acl_dvpp_op libacl_dvpp_op.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - find_library(acl_dvpp_mpi libacl_dvpp_mpi.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - # find acl_env_guard in ascend_kernel_plugin - target_link_libraries(dvpp_utils PRIVATE ascend_kernel_plugin minddata-lite ${acl} ${acl_dvpp} - mindspore_core ${nnopbase} ${acl_dvpp_op} ${acl_dvpp_mpi}) -else() +if(NOT MSLITE_ENABLE_ACL) +# if(MSLITE_ENABLE_ACL) +# find_library(acl_dvpp libacl_dvpp.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) +# find_library(acl libascendcl.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) +# find_library(nnopbase libnnopbase.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) +# find_library(acl_dvpp_op libacl_dvpp_op.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) +# find_library(acl_dvpp_mpi libacl_dvpp_mpi.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) +# # find acl_env_guard in ascend_kernel_plugin +# target_link_libraries(dvpp_utils PRIVATE ascend_kernel_plugin minddata-lite ${acl} ${acl_dvpp} +# mindspore_core ${nnopbase} ${acl_dvpp_op} ${acl_dvpp_mpi}) +# else() find_library(acl_dvpp libacl_dvpp.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) find_library(acl libascendcl.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) find_library(nnopbase libnnopbase.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) diff --git a/mindspore-lite/minddata/dataset/kernels/image/dvpp/utils/DvppCommon.h b/mindspore-lite/minddata/dataset/kernels/image/dvpp/utils/DvppCommon.h index e8c99e29ba299eccbcca913e864445ef875bcae0..c90acb2fec33092492a7a0e373ce63c8f74e0478 100644 --- a/mindspore-lite/minddata/dataset/kernels/image/dvpp/utils/DvppCommon.h +++ b/mindspore-lite/minddata/dataset/kernels/image/dvpp/utils/DvppCommon.h @@ -94,7 +94,7 @@ const uint32_t MIN_CROP_WIDTH = 10; // Min width of crop area const uint32_t MIN_CROP_HEIGHT = 6; // Min height of crop area const uint8_t YUV_GREYER_VALUE = 128; // Filling value of the resized YUV image -#define CONVERT_TO_ODD(NUM) (((NUM) % MODULUS_NUM_2 != 0) ? (NUM) : ((NUM)-1)) // Convert the input to odd num +#define CONVERT_TO_ODD(NUM) (((NUM) % MODULUS_NUM_2 != 0) ? (NUM) : ((NUM)-1)) // Convert the input to odd num #define CONVERT_TO_EVEN(NUM) (((NUM) % MODULUS_NUM_2 == 0) ? (NUM) : ((NUM)-1)) // Convert the input to even num #define CHECK_ODD(num) ((num) % MODULUS_NUM_2 != 0) #define CHECK_EVEN(num) ((num) % MODULUS_NUM_2 == 0) diff --git a/mindspore-lite/minddata/dataset/kernels/image/lite_image_utils.cc b/mindspore-lite/minddata/dataset/kernels/image/lite_image_utils.cc index 013fcb6f4a312412c4725338b99f50f80719a29d..dd6b0ac17277a4fdf12df7583669a7f3d097ac6a 100644 --- a/mindspore-lite/minddata/dataset/kernels/image/lite_image_utils.cc +++ b/mindspore-lite/minddata/dataset/kernels/image/lite_image_utils.cc @@ -20,13 +20,13 @@ #include #include -#if defined(ENABLE_MINDDATA_PYTHON) +#if defined(ENABLE_CLOUD_FUSION_INFERENCE) #include #include #include #endif -#if defined(ENABLE_MINDDATA_PYTHON) +#if defined(ENABLE_CLOUD_FUSION_INFERENCE) #include "mindspore-lite/minddata/dataset/core/cv_tensor.h" #endif #include "mindspore-lite/minddata/dataset/core/tensor.h" @@ -35,7 +35,7 @@ #include "mindspore-lite/minddata/dataset/kernels/image/lite_cv/image_process.h" #include "mindspore-lite/minddata/dataset/kernels/image/lite_cv/lite_mat.h" #include "mindspore-lite/minddata/dataset/kernels/image/math_utils.h" -#if defined(ENABLE_MINDDATA_PYTHON) +#if defined(ENABLE_CLOUD_FUSION_INFERENCE) #include "mindspore-lite/minddata/dataset/kernels/image/resize_cubic_op.h" #endif #include "mindspore-lite/minddata/dataset/util/random.h" @@ -46,7 +46,7 @@ constexpr int64_t hwc_rank = 3; #define MAX_INT_PRECISION 16777216 // float int precision is 16777216 namespace mindspore { namespace dataset { -#if defined(ENABLE_MINDDATA_PYTHON) +#if defined(ENABLE_CLOUD_FUSION_INFERENCE) bool IsNonEmptyPNG(const std::shared_ptr &input) { const unsigned char kPngMagic[] = "\x89\x50\x4E\x47"; constexpr dsize_t kPngMagicLen = 4; @@ -305,7 +305,7 @@ static LDataType GetLiteCVDataType(const DataType &data_type) { } } -#if defined(ENABLE_MINDDATA_PYTHON) +#if defined(ENABLE_CLOUD_FUSION_INFERENCE) Status DecodeCv(const std::shared_ptr &input, std::shared_ptr *output) { std::shared_ptr input_cv = CVTensor::AsCVTensor(input); if (!input_cv->mat().data) { @@ -333,7 +333,7 @@ Status Decode(const std::shared_ptr &input, std::shared_ptr *out if (IsNonEmptyJPEG(input)) { return JpegCropAndDecode(input, output); } else { -#if defined(ENABLE_MINDDATA_PYTHON) +#if defined(ENABLE_CLOUD_FUSION_INFERENCE) return DecodeCv(input, output); #else RETURN_STATUS_UNEXPECTED("Decode: Decode only supports jpeg for android"); @@ -465,7 +465,7 @@ Status Normalize(const std::shared_ptr &input, std::shared_ptr * return Status::OK(); } -#if defined(ENABLE_MINDDATA_PYTHON) +#if defined(ENABLE_CLOUD_FUSION_INFERENCE) int GetCVInterpolationMode(InterpolationMode mode) { switch (mode) { case InterpolationMode::kLinear: diff --git a/mindspore-lite/minddata/dataset/kernels/image/lite_image_utils.h b/mindspore-lite/minddata/dataset/kernels/image/lite_image_utils.h index 3627d4457db53aedaa3b4d0ade0f25f94908798f..238d0c3fb23ddaf119afe25f33a0c96af64ac230 100644 --- a/mindspore-lite/minddata/dataset/kernels/image/lite_image_utils.h +++ b/mindspore-lite/minddata/dataset/kernels/image/lite_image_utils.h @@ -67,7 +67,7 @@ struct JpegErrorManagerCustom { jmp_buf setjmp_buffer; }; -#if defined(ENABLE_MINDDATA_PYTHON) +#if defined(ENABLE_CLOUD_FUSION_INFERENCE) bool IsNonEmptyPNG(const std::shared_ptr &input); /// \brief Returns Rescaled image diff --git a/mindspore-lite/minddata/dataset/kernels/ir/vision/pad_ir.cc b/mindspore-lite/minddata/dataset/kernels/ir/vision/pad_ir.cc index 78dfea7c3721b84bb7128dc8eef9c3712ce34d68..bf244c42bbc39b68d43d4bec0e7a77051ee72da9 100644 --- a/mindspore-lite/minddata/dataset/kernels/ir/vision/pad_ir.cc +++ b/mindspore-lite/minddata/dataset/kernels/ir/vision/pad_ir.cc @@ -18,7 +18,7 @@ #include -#if !defined(ENABLE_ANDROID) || defined(ENABLE_MINDDATA_PYTHON) +#if !defined(ENABLE_ANDROID) #include "mindspore-lite/minddata/dataset/kernels/image/pad_op.h" #endif #if !defined(BUILD_LITE) && defined(ENABLE_D) @@ -30,7 +30,7 @@ namespace mindspore { namespace dataset { namespace vision { -#if !defined(ENABLE_ANDROID) || defined(ENABLE_MINDDATA_PYTHON) +#if !defined(ENABLE_ANDROID) // PadOperation PadOperation::PadOperation(const std::vector &padding, const std::vector &fill_value, BorderType padding_mode, const std::string &device_target) diff --git a/mindspore-lite/minddata/dataset/kernels/ir/vision/rescale_ir.cc b/mindspore-lite/minddata/dataset/kernels/ir/vision/rescale_ir.cc index eeef36718e828b7c957883cd649f3c0b39fae1bb..515c585942d4f7812e4618bfee57616d3403e49c 100644 --- a/mindspore-lite/minddata/dataset/kernels/ir/vision/rescale_ir.cc +++ b/mindspore-lite/minddata/dataset/kernels/ir/vision/rescale_ir.cc @@ -15,7 +15,7 @@ */ #include "mindspore-lite/minddata/dataset/kernels/ir/vision/rescale_ir.h" -#if !defined(ENABLE_ANDROID) || defined(ENABLE_MINDDATA_PYTHON) +#if !defined(ENABLE_ANDROID) #include "mindspore-lite/minddata/dataset/kernels/image/rescale_op.h" #endif #include "mindspore-lite/minddata/dataset/util/validators.h" @@ -23,7 +23,7 @@ namespace mindspore { namespace dataset { namespace vision { -#if !defined(ENABLE_ANDROID) || defined(ENABLE_MINDDATA_PYTHON) +#if !defined(ENABLE_ANDROID) // RescaleOperation RescaleOperation::RescaleOperation(float rescale, float shift) : rescale_(rescale), shift_(shift) {} diff --git a/mindspore-lite/minddata/dataset/kernels/ir/vision/swap_red_blue_ir.cc b/mindspore-lite/minddata/dataset/kernels/ir/vision/swap_red_blue_ir.cc index ebc02906f783c92bb14301a3324acc032767d12a..ef5deded509e77d4b83e1cfc2cda72e37bb754ec 100644 --- a/mindspore-lite/minddata/dataset/kernels/ir/vision/swap_red_blue_ir.cc +++ b/mindspore-lite/minddata/dataset/kernels/ir/vision/swap_red_blue_ir.cc @@ -15,14 +15,14 @@ */ #include "mindspore-lite/minddata/dataset/kernels/ir/vision/swap_red_blue_ir.h" -#if !defined(ENABLE_ANDROID) || defined(ENABLE_MINDDATA_PYTHON) +#if !defined(ENABLE_ANDROID) #include "mindspore-lite/minddata/dataset/kernels/image/swap_red_blue_op.h" #endif namespace mindspore { namespace dataset { namespace vision { -#if !defined(ENABLE_ANDROID) || defined(ENABLE_MINDDATA_PYTHON) +#if !defined(ENABLE_ANDROID) // SwapRedBlueOperation. SwapRedBlueOperation::SwapRedBlueOperation() = default; diff --git a/mindspore-lite/minddata/dataset/util/log_adapter.h b/mindspore-lite/minddata/dataset/util/log_adapter.h index b22fccd9deda4c70e6f310d13081ad91c28360c6..c842c000a930d45d086117e7be55a7b048728774 100644 --- a/mindspore-lite/minddata/dataset/util/log_adapter.h +++ b/mindspore-lite/minddata/dataset/util/log_adapter.h @@ -16,7 +16,7 @@ #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_LOG_ADAPTER_H_ #define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_LOG_ADAPTER_H_ -#if !defined(ENABLE_ANDROID) || defined(ENABLE_MINDDATA_PYTHON) +#if !defined(ENABLE_ANDROID) #include "src/common/log_adapter.h" #define DATASET_SRC_FILE_NAME FILE_NAME #else diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/CMakeLists.txt b/mindspore-lite/ops/kernel/cpu/nnacl/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..a43577439cd1615d451802c87f54fbad08e33a2d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/CMakeLists.txt @@ -0,0 +1,293 @@ +project(nnacl) + +set(NNACL_DIR ${CMAKE_CURRENT_SOURCE_DIR}) +include_directories(${NNACL_DIR}/..) + +set(NNACL_SIMD_DIR ${CMAKE_BINARY_DIR}/src/nnacl) +include_directories(${NNACL_SIMD_DIR}/..) +file(GLOB SIMD_CONFIG_HEADER + ${NNACL_DIR}/base/*_simd.h.in + ${NNACL_DIR}/fp32/*_simd.h.in + ${NNACL_DIR}/fp32/online_fusion/*_simd.h.in + ${NNACL_DIR}/fp32_grad/*_simd.h.in +) +function(generate_simd_header_code) + foreach(simd_config_file ${SIMD_CONFIG_HEADER}) + string(REGEX MATCHALL "[0-9A-Za-z_]*_simd.h.in" tmp1 ${simd_config_file}) + string(REGEX REPLACE "_simd.h.in" "_${SIMD_INSTRUCTION_LOWER}.h" tmp2 ${tmp1}) + string(REGEX REPLACE "_simd.h.in" "" tmp3 ${tmp1}) + string(TOLOWER ${tmp3} OP_NAME_LOWER) + string(TOUPPER ${tmp3} OP_NAME_UPPER) + configure_file(${NNACL_DIR}/op_simd_header_file.h.in ${NNACL_SIMD_DIR}/${tmp3}_simd.h) + endforeach() +endfunction() + +function(generate_simd_code SIMD_INSTRUCTION SIMD_BLOCK_NUM SIMD_TARGET) + string(TOLOWER ${SIMD_INSTRUCTION} SIMD_INSTRUCTION_LOWER) + set(SIMD_DEFINE "#define MS_SIMD_${SIMD_INSTRUCTION}") + set(SIMD_UNDEF "#undef MS_SIMD_${SIMD_INSTRUCTION}") + set(SIMD_DEF_INSTRUCTION "#define MS_SIMD_INSTRUCTION MS_SIMD_${SIMD_INSTRUCTION}_INSTRUCTION") + set(SIMD_UNDEF_INSTRUCTION "#undef MS_SIMD_INSTRUCTION") + set(SIMD_DEF_BLOCK_NUM "#define BLOCK_NUM ${SIMD_BLOCK_NUM}") + set(SIMD_UNDEF_BLOCK_NUM "#undef BLOCK_NUM") + if(SIMD_INSTRUCTION_LOWER STREQUAL "neon") + set(SIMD_TARGET_BEGIN "") + set(SIMD_TARGET_END "") + else() + set(SIMD_TARGET_BEGIN "#pragma GCC push_options\n#pragma GCC target(${SIMD_TARGET})") + set(SIMD_TARGET_END "#pragma GCC pop_options") + endif() + + set(SIMD_INSTRUCTION_BEGIN "${SIMD_TARGET_BEGIN}\n${SIMD_DEF_INSTRUCTION}\n${SIMD_DEF_BLOCK_NUM}\n${SIMD_DEFINE}") + set(SIMD_INSTRUCTION_END "${SIMD_UNDEF_INSTRUCTION}\n${SIMD_UNDEF_BLOCK_NUM}\n${SIMD_TARGET_END}\n${SIMD_UNDEF}") + foreach(simd_config_file ${SIMD_CONFIG_HEADER}) + string(REGEX MATCHALL "[0-9A-Za-z_]*_simd.h.in" tmp1 ${simd_config_file}) + string(REGEX REPLACE "_simd.h.in" "_${SIMD_INSTRUCTION_LOWER}.h" tmp2 ${tmp1}) + configure_file(${simd_config_file} ${NNACL_SIMD_DIR}/${SIMD_INSTRUCTION_LOWER}/${tmp2}) + endforeach() +endfunction() +generate_simd_code(NEON 4 \"\") +generate_simd_code(SSE 4 \"sse4.1\") +generate_simd_code(AVX 8 "\"avx\", \"avx2\"") +generate_simd_code(AVX512 16 \"avx512f\") +generate_simd_header_code() + +if(ENABLE_CPU AND NOT MSVC) + set(CMAKE_C_FLAGS "-Wno-attributes ${CMAKE_C_FLAGS}") +endif() + +if(APPLE OR PLATFORM_ARM32 OR PLATFORM_ARM64 OR PLATFORM_MCU) + if("${CMAKE_BUILD_TYPE}" STREQUAL "Release" AND DEFINED ARCHS) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fstrict-aliasing \ + -ffunction-sections -fdata-sections -ffast-math -Wno-shorten-64-to-32") + endif() + if("${CMAKE_BUILD_TYPE}" STREQUAL "Release" AND NOT DEFINED ARCHS) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fomit-frame-pointer -fstrict-aliasing \ + -ffunction-sections -fdata-sections -ffast-math") + endif() + if(TARGET_OHOS) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wno-inline-asm") + endif() +elseif(NOT MSVC) + if("${CMAKE_BUILD_TYPE}" STREQUAL "Release") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fomit-frame-pointer -fstrict-aliasing -ffunction-sections \ + -fdata-sections") + endif() +endif() + +if(NOT MSVC) + if("${X86_64_SIMD}" STREQUAL "sse" OR "${X86_64_SIMD}" STREQUAL "avx" OR "${X86_64_SIMD}" STREQUAL "avx512") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse4.1") + endif() + if("${X86_64_SIMD}" STREQUAL "avx" OR "${X86_64_SIMD}" STREQUAL "avx512") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx -mavx2 -mfma") + endif() + if(WIN32) + if("${X86_64_SIMD}" STREQUAL "avx512") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx512f -fno-asynchronous-unwind-tables") + endif() + endif() +endif() + + +########################### files ########################### +file(GLOB KERNEL_SRC + ${NNACL_DIR}/*.c + ${NNACL_DIR}/fp32/*.c + ${NNACL_DIR}/infer/*.c + ${NNACL_DIR}/base/*.c + ${NNACL_DIR}/fp32_grad/*.c + ${NNACL_DIR}/kernel/*.c + ${NNACL_DIR}/experimental/*.c + ${NNACL_DIR}/fp32/online_fusion/*.c +) + +set(KERNEL_AVX512_FILE ${NNACL_DIR}/fp32/matmul_avx512_fp32.c + ${NNACL_DIR}/fp32/matmul_avx512_mask_fp32.c + ${NNACL_DIR}/fp32/conv_im2col_avx512_fp32.c +) +list(REMOVE_ITEM KERNEL_SRC ${KERNEL_AVX512_FILE}) + +set(KERNEL_AVX_FILE ${NNACL_DIR}/fp32/conv_sw_avx_fp32.c + ${NNACL_DIR}/fp32/conv_1x1_avx_fp32.c + ${NNACL_DIR}/fp32/matmul_avx_fp32.c + ${NNACL_DIR}/fp32/conv_depthwise_avx_fp32.c) +list(REMOVE_ITEM KERNEL_SRC ${KERNEL_AVX_FILE}) + +set(KERNEL_ARM64_FILE ${NNACL_DIR}/fp32/conv_sw_arm64_fp32.c) +list(REMOVE_ITEM KERNEL_SRC ${KERNEL_ARM64_FILE}) + +if(NOT MSLITE_ENABLE_RUNTIME_PASS) + list(REMOVE_ITEM KERNEL_SRC ${NNACL_DIR}/infer/shape_fusion_infer.c) +endif() +if((NOT DEFINED MSLITE_ENABLE_INT8) OR MSLITE_ENABLE_INT8) + file(GLOB KERNEL_SRC_INT8 + ${NNACL_DIR}/int8/*.c + ) + set(KERNEL_SRC + ${KERNEL_SRC} + ${KERNEL_SRC_INT8} + ) +else() + set(KERNEL_SRC + ${KERNEL_SRC} + ${NNACL_DIR}/int8/pack_int8.c + ${NNACL_DIR}/int8/quantize.c + ) +endif() + +if(MSLITE_ENABLE_SPARSE_COMPUTE) + file(GLOB KERNEL_SRC_SPARSE + ${NNACL_DIR}/fp32_sparse/*.c + ) + set(KERNEL_SRC + ${KERNEL_SRC} + ${KERNEL_SRC_SPARSE} + ) +endif() + +if(MSLITE_ENABLE_STRING_KERNEL) + file(GLOB KERNEL_SRC_INFER_STRING + ${NNACL_DIR}/infer/string/*.c + ) + set(KERNEL_SRC + ${KERNEL_SRC} + ${KERNEL_SRC_INFER_STRING} + ) +endif() +if(MSLITE_ENABLE_CONTROLFLOW) + file(GLOB KERNEL_SRC_INFER_CONTROL_TENSORLIST + ${NNACL_DIR}/infer/control/*.c + ) + set(KERNEL_SRC + ${KERNEL_SRC} + ${KERNEL_SRC_INFER_CONTROL_TENSORLIST} + ) +endif() +if(PLATFORM_ARM64) + file(GLOB ASSEMBLY_SRC ${NNACL_DIR}/assembly/arm64/*.S) + set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C) + set(KERNEL_SRC ${KERNEL_SRC} ${KERNEL_ARM64_FILE}) +endif() + +if(PLATFORM_ARM32) + file(GLOB ASSEMBLY_SRC ${NNACL_DIR}/assembly/arm32/*.S) + set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C) +endif() + +if("${X86_64_SIMD}" STREQUAL "sse" OR "${X86_64_SIMD}" STREQUAL "avx" OR "${X86_64_SIMD}" STREQUAL "avx512") + file(GLOB ASSEMBLY_SSE_SRC ${NNACL_DIR}/intrinsics/sse/*.c) + set_property(SOURCE ${ASSEMBLY_SSE_SRC} PROPERTY LANGUAGE C) + + set(MS_X86_SSE_SRC + ${ASSEMBLY_SSE_SRC} + ${KERNEL_SSE_FILE}) + set_source_files_properties(${MS_X86_SSE_SRC} PROPERTIES LANGUAGE C + COMPILE_FLAGS "${CMAKE_C_FLAGS} -msse4.1 -fPIC") + + set(MS_X86_SIMD_SRC ${MS_X86_SIMD_SRC} ${MS_X86_SSE_SRC}) +endif() + +if("${X86_64_SIMD}" STREQUAL "avx" OR "${X86_64_SIMD}" STREQUAL "avx512") + file(GLOB ASSEMBLY_AVX_SRC + ${NNACL_DIR}/intrinsics/avx/*.c + ${NNACL_DIR}/assembly/avx/*.S + ${NNACL_DIR}/intrinsics/ms_simd_cpu_info.c) + set_property(SOURCE ${ASSEMBLY_AVX_SRC} PROPERTY LANGUAGE C) + + set(MS_X86_AVX_SRC + ${ASSEMBLY_AVX_SRC} + ${KERNEL_AVX_FILE}) + set_source_files_properties(${MS_X86_AVX_SRC} PROPERTIES LANGUAGE C + COMPILE_FLAGS "${CMAKE_C_FLAGS} -mavx -mavx2 -mfma -fPIC") + + set(MS_X86_SIMD_SRC ${MS_X86_SIMD_SRC} ${MS_X86_AVX_SRC}) +endif() + +if("${X86_64_SIMD}" STREQUAL "avx512") + if("${CMAKE_BUILD_TYPE}" STREQUAL "Release") + file(GLOB HPC_SRC ${NNACL_DIR}/experimental/HPC-generator/gemm_avx512/*.c + ${NNACL_DIR}/experimental/HPC-generator/gemm_mask_avx512/*.c) + + set_property(SOURCE ${HPC_SRC} PROPERTY LANGUAGE C) + endif() + + file(GLOB ASSEMBLY_AVX512_SRC ${NNACL_DIR}/assembly/avx512/*.S) + set_property(SOURCE ${ASSEMBLY_AVX512_SRC} PROPERTY LANGUAGE C) + + set(MS_X86_AVX512_SRC + ${HPC_SRC} + ${ASSEMBLY_AVX512_SRC} + ${KERNEL_AVX512_FILE}) + + set_source_files_properties(${MS_X86_AVX512_SRC} PROPERTIES LANGUAGE C + COMPILE_FLAGS "${CMAKE_C_FLAGS} -mavx512f -fPIC") + + set(MS_X86_SIMD_SRC ${MS_X86_SIMD_SRC} ${MS_X86_AVX512_SRC}) +endif() + +if(APPLE) + set_source_files_properties(${ASSEMBLY_SRC} PROPERTIES COMPILE_FLAGS "-x assembler-with-cpp") +endif() + +########################### build nnacl library ######################## +if(NOT MSVC) +string(REPLACE "-fvisibility=hidden" "-fvisibility=default" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") +endif() + +if(PLATFORM_ARM) + set(NO_FAST_MATH_OPTI ${NNACL_DIR}/fp32/resize_fp32.c) + set_source_files_properties(${NO_FAST_MATH_OPTI} PROPERTIES LANGUAGE C + COMPILE_FLAGS "${CMAKE_C_FLAGS} -fno-fast-math") +endif() + +add_library(nnacl_mid OBJECT ${KERNEL_SRC} ${ASSEMBLY_SRC} ${MS_X86_SIMD_SRC}) + +if("${CMAKE_BUILD_TYPE}" STREQUAL "Debug") + target_compile_definitions(nnacl_mid PRIVATE ENABLE_DEBUG) +endif() + +if(ENABLE_CPU) + if(${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "aarch64") + target_compile_definitions(nnacl_mid PRIVATE ENABLE_ARM ENABLE_ARM64 ENABLE_NEON) + target_compile_options(nnacl_mid PRIVATE -ffast-math -flax-vector-conversions) + elseif("${X86_64_SIMD}" STREQUAL "sse") + target_compile_definitions(nnacl_mid PRIVATE ENABLE_SSE) + elseif("${X86_64_SIMD}" STREQUAL "avx") + target_compile_definitions(nnacl_mid PRIVATE ENABLE_SSE ENABLE_AVX) + elseif("${X86_64_SIMD}" STREQUAL "avx512") + target_compile_definitions(nnacl_mid PRIVATE ENABLE_SSE ENABLE_AVX ENABLE_AVX512) + endif() + if(NOT MSVC) + target_compile_options(nnacl_mid PRIVATE -fPIC -fstack-protector-all) + add_library(nnacl SHARED $) + else() + add_library(nnacl STATIC $) + endif() + if(NOT CMAKE_SYSTEM_NAME MATCHES "Windows") + if(NOT CMAKE_SYSTEM_NAME MATCHES "Darwin") + target_link_options(nnacl PRIVATE -Wl,-z,relro,-z,now,-z,noexecstack) + target_link_libraries(nnacl PRIVATE m) + endif() + if("${CMAKE_BUILD_TYPE}" STREQUAL "Release") + target_link_options(nnacl PRIVATE -s) + endif() + endif() +endif() + +set(nnacl_static_obj $) +########################### arm fp16 build optimize library ######################## +if(PLATFORM_ARM) + add_subdirectory(${NNACL_DIR}/optimize) + if(TARGET nnacl_optimize_mid) + set(nnacl_static_obj ${nnacl_static_obj} $) + endif() + if(TARGET nnacl_fp16_mid) + set(nnacl_static_obj ${nnacl_static_obj} $) + endif() +endif() +if(NOT ${CMAKE_GENERATOR} MATCHES "Ninja") + add_library(nnacl_static STATIC ${nnacl_static_obj}) + set_target_properties(nnacl_static PROPERTIES OUTPUT_NAME "nnacl") + set_target_properties(nnacl_static PROPERTIES CLEAN_DIRECT_OUTPUT 1) +endif() diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/OWNERS b/mindspore-lite/ops/kernel/cpu/nnacl/OWNERS new file mode 100644 index 0000000000000000000000000000000000000000..51cce0f0335962ae234c1561649b41814ee107f4 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/OWNERS @@ -0,0 +1,8 @@ +approvers: +- zhang_xue_tong # +- jpc_chenjianping +- xu_anyue +- fatmouse007fatmouse007 # zhuguodong + +options: + no_parent_owners: true diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/README.md b/mindspore-lite/ops/kernel/cpu/nnacl/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c756a7db075467cf782af013fd36cc0c8ed942e3 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/README.md @@ -0,0 +1 @@ +NNACL(neural network accelerated computing library) is a high performance library of neural network inference computing kernels for ARM. diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/activation_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/activation_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..ed34a3d37407cfc99d330041b2e37bbfd505a612 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/activation_parameter.h @@ -0,0 +1,29 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_ACTIVATION_PARAMETER_H_ +#define NNACL_ACTIVATION_PARAMETER_H_ + +#include "nnacl/op_base.h" +typedef struct ActivationParameter { + OpParameter op_parameter_; + int type_; + float alpha_; + float min_val_; + float max_val_; + bool approximate_; +} ActivationParameter; + +#endif // NNACL_ACTIVATION_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/affine_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/affine_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..33686e3ffc4aefa9ee3c0b7a39ca7683e2155be8 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/affine_parameter.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_AFFINE_PARAMETER_H_ +#define NNACL_AFFINE_PARAMETER_H_ +#include "nnacl/op_base.h" +#include "nnacl/matmul_parameter.h" +typedef struct AffineParameter { + OpParameter op_parameter_; + // parameters from splice op + int context_size_; + int *context_; + int output_dim_; + // parameters from activation op + int activation_type_; + // parameters from matmul op + MatMulParameter *matmul_parameter_; +} AffineParameter; +#endif // NNACL_AFFINE_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/all_gather_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/all_gather_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..fe0008a18477e4c9a1b8f725be9025d1d1bec3fa --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/all_gather_parameter.h @@ -0,0 +1,30 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_ALL_GATHER_PARAMETER_H_ +#define NNACL_ALL_GATHER_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct AllGatherParameter { + // primitive parameter + OpParameter op_parameter_; + char group_[DEFAULT_GROUP_NAME_LEN]; + + // other parameter + int rank_size_; +} AllGatherParameter; +#endif // NNACL_ALL_GATHER_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/arg_min_max_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/arg_min_max_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..1d8899c64d93ba9db635f9673af9904cc135e425 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/arg_min_max_parameter.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_ARG_MIN_MAX_PARAMETER_H_ +#define NNACL_ARG_MIN_MAX_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct ArgMinMaxParameter { + OpParameter op_parameter_; + int32_t axis_; + int32_t topk_; + bool keep_dims_; + bool out_value_; +} ArgMinMaxParameter; + +#endif // NNACL_ARG_MIN_MAX_PARAMETER_H_ diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/acl/acl_env_guard.h b/mindspore-lite/ops/kernel/cpu/nnacl/arithmetic_parameter.h similarity index 38% rename from mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/acl/acl_env_guard.h rename to mindspore-lite/ops/kernel/cpu/nnacl/arithmetic_parameter.h index fc3d098e2dc723c17cd20a17ec918e3ef034b946..5aeaa7214dd5d227e169b301df11c57ed3cd5d21 100644 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/acl/acl_env_guard.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/arithmetic_parameter.h @@ -13,41 +13,36 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_ACL_ACL_ENV_GUARD_H -#define MINDSPORE_CCSRC_CXX_API_GRAPH_ACL_ACL_ENV_GUARD_H - -#include -#include -#include "acl/acl_base.h" - -namespace mindspore { -class __attribute__((visibility("default"))) AclInitAdapter { - public: - static AclInitAdapter &GetInstance(); - aclError AclInit(const char *config_file); - aclError AclFinalize(); - aclError ForceFinalize(); - - private: - AclInitAdapter() : init_flag_(false) {} - ~AclInitAdapter() = default; - - bool init_flag_; - std::mutex flag_mutex_; -}; - -class __attribute__((visibility("default"))) AclEnvGuard { - public: - explicit AclEnvGuard(); - ~AclEnvGuard(); - aclError GetErrno() const { return errno_; } - static std::shared_ptr GetAclEnv(); - - private: - static std::shared_ptr global_acl_env_; - static std::mutex global_acl_env_mutex_; - - aclError errno_; -}; -} // namespace mindspore -#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_ACL_ACL_ENV_GUARD_H + +#ifndef NNACL_ARTITHMETIC_PARAMETER_H_ +#define NNACL_ARTITHMETIC_PARAMETER_H_ + +#include "nnacl/op_base.h" +#include "nnacl/common_func.h" +#include "nnacl/nnacl_utils.h" + +#define ARITHMETIC_SUPPORT_DIMS_NUM 10 + +typedef struct ArithmeticParameter { + OpParameter op_parameter_; + bool broadcasting_; + size_t ndim_; + int activation_type_; + int in_shape0_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int64_t in_elements_num0_; + int in_shape1_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int64_t in_elements_num1_; + + int out_shape_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int out_elements_num_; + + int in_strides0_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int in_strides1_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int out_strides_[ARITHMETIC_SUPPORT_DIMS_NUM]; + + int multiples0_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int multiples1_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int eltwise_mode_; // eltwise need +} ArithmeticParameter; + +#endif // NNACL_ARTITHMETIC_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/arithmetic_self_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/arithmetic_self_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..a5024d1853d197506e1368062fafc1c1bc7a4654 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/arithmetic_self_parameter.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_ARITHMETIC_SELF_PARAMETER_H_ +#define NNACL_ARITHMETIC_SELF_PARAMETER_H_ + +#include "nnacl/op_base.h" +#include "nnacl/errorcode.h" +#include "nnacl/int8/quantize.h" + +// For Abs, Cos, Exp, Log, Square, Sqrt, Rsqrt ops. +typedef struct ArithmeticSelfParameter { + OpParameter op_parameter_; + ArithSelfQuantArg quant_arg_; +} ArithmeticSelfParameter; + +#endif // NNACL_ARITHMETIC_SELF_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/ConvDw3x3Int8BorderPixel.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/ConvDw3x3Int8BorderPixel.S new file mode 100644 index 0000000000000000000000000000000000000000..1d9616753e228f36eaf9fabcd9e920d55ee5898f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/ConvDw3x3Int8BorderPixel.S @@ -0,0 +1,128 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void ConvDw3x3Int8BorderPixel(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t height, +// size_t width, size_t in_kh_step, size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, +// size_t out_multiplier, size_t left_shift, size_t right_shift, size_t acc_min, size_t acc_max +// size_t per_channel) { + +// todo: support per channel +// r0: dst, r1: src, r2: weight, r3: bias, r4: height, r5: width, r6: in_kh_step, r7: in_kw_step, +// r8: channel, r9: in_zp, r10: out_zp, r11: out_multiplier, r12: left_shift, r13: right_shift +// r14: acc_min, r15: acc_max +asm_function ConvDw3x3Int8BorderPixel + // at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" + // according to https://stackoverflow.com/questions/53625807 + // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway + // clang's rule seems more simple, though there are no subroutine calls here + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + + push {r4-r8, r9-r12, lr} + vpush {q4-q7} + add sp, sp, #104 + + ldr r4, [sp] + ldr r5, [sp, #4] + ldr r6, [sp, #8] + ldr r7, [sp, #12] + ldr r8, [sp, #16] + + ldrb r10, [sp, #20] // in_zp + vdup.8 d18, r10 + ldr r10, [sp, #24] // out_zp + vdup.32 q15, r10 + ldr r10, [sp, #28] // out_multiplier + vdup.32 q14, r10 + ldr r10, [sp, #32] // left_shift + vdup.32 q13, r10 + ldr r10, [sp, #36] // right_shift + vdup.32 q12, r10 + ldr r10, [sp, #40] // acc_min + vdup.32 q11, r10 + ldr r10, [sp, #44] // acc_max + vdup.32 q10, r10 + + mov r4, #2 + mul lr, r8, r4 + + LoopC: + mov r9, r1 + mov r10, r2 + ldr r4, [sp] + + vld1.32 {q3}, [r3]! + vld1.32 {q4}, [r3]! + LoopH: + mov r11, r9 + mov r12, r10 + ldr r5, [sp, #4] + LoopW: + vld1.8 {d0}, [r11], r7 + vld1.16 {d2, d3}, [r12], lr // weight + vsubl.s8 q2, d0, d18 // -zp + + vmlal.s16 q3, d4, d2 + vmlal.s16 q4, d5, d3 + + subs r5, r5, #1 + bne LoopW + subs r4, r4, #1 + add r9, r9, r6 + mov r11, #3 + mul r5, lr, r11 + add r10, r10, r5 + bne LoopH + + vshl.s32 q3, q3, q13 + vqrdmulh.s32 q3, q3, q14 + vand q5, q3, q12 + vshr.s32 q5, q5, #31 + vqadd.s32 q3, q3, q5 + vrshl.s32 q3, q3, q12 + vadd.i32 q3, q3, q15 + vmax.s32 q3, q3, q11 + vmin.s32 q3, q3, q10 + vqmovn.s32 d14, q3 + + vshl.s32 q4, q4, q13 + vqrdmulh.s32 q4, q4, q14 + vand q6, q4, q12 + vshr.s32 q6, q6, #31 + vqadd.s32 q4, q4, q6 + vrshl.s32 q4, q4, q12 + vadd.i32 q4, q4, q15 + vmax.s32 q4, q4, q11 + vmin.s32 q4, q4, q10 + vqmovn.s32 d15, q4 + vqmovn.s16 d16, q7 + + vst1.8 {d16}, [r0]! + add r1, r1, #8 + add r2, r2, #16 + + sub r8, r8, #8 + cmp r8, #8 + bge LoopC + + sub sp, sp, #104 + vpop {q4-q7} + pop {r4-r8, r9-r12, pc} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/ConvDwFp32Border.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/ConvDwFp32Border.S new file mode 100644 index 0000000000000000000000000000000000000000..b1f68f0fe193332eba539e35ddd0fdf16ea45f3c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/ConvDwFp32Border.S @@ -0,0 +1,75 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, +// size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu, size_t relu6) +// r0: dst, r1: src, r2: weight, r3: bias, r4: height, r5: width, r6: in_kh_step, r7: in_kw_step, +// r8: kernel_w, r9: relu, r10: relu6 +asm_function ConvDwFp32Border + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r4-r12, lr} + vpush {q4-q7} + add sp, sp, #104 + + ldr r4, [sp] // height + ldr r5, [sp, #4] // width + ldr r6, [sp, #8] // in_kh_step + ldr r7, [sp, #12] // in_kw_step + ldr r8, [sp, #16] // kernel_w + ldr r9, [sp, #20] // relu + ldr r10, [sp, #24] // relu6 + + vld1.32 {q0}, [r3] // bias + vmov.i32 q1, #6 // relu6 + vcvt.f32.s32 q1, q1 + veor q2, q2, q2 // relu + + LoopH: + mov r11, r1 + mov r12, r2 + mov r14, r5 + LoopW: + vld1.32 {q3}, [r11], r7 + vld1.32 {q4}, [r12]! + vmla.f32 q0, q3, q4 + subs r14, r14, #1 + bne LoopW + subs r4, r4, #1 + add r1, r1, r6 + add r2, r2, r8 + bne LoopH + + cmp r10, #0 + bne Relu6 + cmp r9, #0 + bne Relu + b Write + Relu6: + vmin.f32 q0, q0, q1 + Relu: + vmax.f32 q0, q0, q2 + Write: + vst1.32 {q0}, [r0] + + sub sp, sp, #104 + vpop {q4-q7} + pop {r4-r12, pc} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/ConvDwFp32Center.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/ConvDwFp32Center.S new file mode 100644 index 0000000000000000000000000000000000000000..8e1c6afe6f3031028bb9354925a00ce9d52e58e6 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/ConvDwFp32Center.S @@ -0,0 +1,176 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, +// size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, +// size_t in_kh_step, size_t in_kw_step, size_t relu, size_t relu6); +// r0: dst, r1: src, r2: weight, r3: bias, #0: height, #4: width, #8: kernel_h, #12: kernel_w, +// #16: out_h_step, #20: block_channel, #24: in_sh_step, #28: in_sw_step, #32: in_kh_step,#36: in_kw_step +// #40: relu, #44: relu6 +asm_function ConvDwFp32Center + // at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" + // according to https://stackoverflow.com/questions/53625807 + // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway + // clang's rule seems more simple, though there are no subroutine calls here + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r0-r8, r10, r11, lr} + vpush {q4-q7} + add sp, sp, #112 + + ldr r4, [sp] // height + + vld1.32 {q13}, [r3] + vmov.i32 q14, #6 + vcvt.f32.s32 q14, q14 + veor q15, q15, q15 + + LoopH: + ldr r1, [sp, #-44] // src_w, src_h = src + ldr r5, [sp, #4] // width + ldr r0, [sp, #-48] // dst_w, dst_h = dst + cmp r5, #4 + blt LoopW + LoopW4: + ldr r11, [sp, #28] // in_sw_step + mov r8, r1 // src_kh, src_w + ldr r2, [sp, #-40] // weight_kh, weight + ldr r6, [sp, #8] // kernel_h + vmov q0, q13 + vmov q1, q13 + vmov q2, q13 + vmov q3, q13 + LoopKh4: + ldr r7, [sp, #12] // kernel_w + mov lr, r8 // src_kw, src_kh + LoopKw4: + ldr r12, [sp, #36] //in_kw_step + mov r10, lr + vld1.32 {q12}, [r2]! + vld1.32 {q4}, [r10] + add r10, r10, r11 + vmla.f32 q0, q4, q12 + vld1.32 {q5}, [r10] + add r10, r10, r11 + vmla.f32 q1, q5, q12 + vld1.32 {q6}, [r10] + add r10, r10, r11 + vmla.f32 q2, q6, q12 + vld1.32 {q7}, [r10] + add r10, r10, r11 + vmla.f32 q3, q7, q12 + subs r7, r7, #1 + add lr, lr, r12 + bne LoopKw4 + ldr r12, [sp, #32] // in_kh_step + add r8, r8, r12 + subs r6, r6, #1 + bne LoopKh4 + ldr r12, [sp, #44] + cmp r12, #0 + bne Relu64 + ldr r12, [sp, #40] + cmp r12, #0 + bne Relu4 + b Write4 + Relu64: + vmin.f32 q0, q0, q14 + vmin.f32 q1, q1, q14 + vmin.f32 q2, q2, q14 + vmin.f32 q3, q3, q14 + Relu4: + vmax.f32 q0, q0, q15 + vmax.f32 q1, q1, q15 + vmax.f32 q2, q2, q15 + vmax.f32 q3, q3, q15 + Write4: + ldr r12, [sp, #20] // block_channel + vst1.32 {q0}, [r0] + add r0, r0, r12 + vst1.32 {q1}, [r0] + add r0, r0, r12 + vst1.32 {q2}, [r0] + add r0, r0, r12 + vst1.32 {q3}, [r0] + add r0, r0, r12 + mov r12, #4 + mul r11, r11, r12 + add r1, r1, r11 // src_w += in_sw_step + sub r5, r5, #4 + cmp r5, #0 + ble LoopWEnd + cmp r5, #4 + bge LoopW + LoopW: + mov r8, r1 // src_kh, src_w + ldr r2, [sp, #-40] // weight_kh, weight + ldr r6, [sp, #8] // kernel_h + vmov q0, q13 // bias + LoopKh: + ldr r7, [sp, #12] // kernel_w + mov r10, r8 // src_kw, src_kh + LoopKw: + ldr r12, [sp, #36] //in_kw_step + vld1.32 {q1}, [r10] + add r10, r10, r12 + vld1.32 {q12}, [r2]! + vmla.f32 q0, q1, q12 + subs r7, r7, #1 + bne LoopKw + ldr r12, [sp, #32] // in_kh_step + add r8, r8, r12 + subs r6, r6, #1 + bne LoopKh + ldr r12, [sp, #44] + cmp r12, #0 + bne Relu6 + ldr r12, [sp, #40] + cmp r12, #0 + bne Relu + b Write + Relu6: + vmin.f32 q0, q0, q14 + Relu: + vmax.f32 q0, q0, q15 + Write: + ldr r12, [sp, #20] // block_channel + vst1.32 {q0}, [r0] // dst_kw += block_channel + add r0, r0, r12 + ldr r12, [sp, #28] // in_sw_step + add r1, r1, r12 // src_w += in_sw_step + subs r5, r5, #1 + bne LoopW + ldr r3, [sp, #16] // out_h_step + ldr r12, [sp, #-48] + add r12, r12, r3 + str r12, [sp, #-48] + + ldr r3, [sp, #24] // in_sh_step + ldr r12, [sp, #-44] // src_h += in_sh_step + add r12, r12, r3 + str r12, [sp, #-44] + + subs r4, r4, #1 // height + bne LoopH +LoopWEnd: + sub sp, sp, #112 + vpop {q4-q7} + pop {r0-r8, r10, r11, pc} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/ConvDwFp32Row.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/ConvDwFp32Row.S new file mode 100644 index 0000000000000000000000000000000000000000..197509200dbf93e83ae53af884d116a11db70e58 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/ConvDwFp32Row.S @@ -0,0 +1,125 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// voidConvDwFp32Row(float* output_ptr, const float* input_ptr, const float* filter_ptr, +// size_t num_pixels, size_t input_channel, size_t input_step) +// r0: output_ptr, r1: input_ptr, r2: filter_ptr, r3: num_pixels, +// r4: input_channel, r5: input_step +asm_function ConvDwFp32Row + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + + push {r4-r6, r8, r10, r11} + vpush {q4-q7} + add sp, sp, #88 + mov r11, r0 + ldr r4, [sp] + ldr r5, [sp, #4] + mov r6, #4 + mul r5, r5, r6 + cmp r3, #0 + ble End + + LoopNumPixel: + mov r6, r1 // input_ptr + mov r8, r2 // filter_ptr + mov r10, r4 // input_channel + + LoopDepth16In: + cmp r10, #16 + blt L4 + sub r10, r10, #16 + + vld1.32 {q0, q1}, [r6]! + vld1.32 {q4, q5}, [r8]! + vld1.32 {q8, q9}, [r0]! + + cmp r10, #16 + blt LoopDepth16Out + LoopDepth16: + vmla.f32 q8, q0, q4 + vmla.f32 q9, q1, q5 + vst1.32 {q8, q9}, [r11]! + + vld1.32 {q2, q3}, [r6]! + vld1.32 {q6, q7}, [r8]! + vld1.32 {q10, q11}, [r0]! + vmla.f32 q10, q2, q6 + vmla.f32 q11, q3, q7 + vst1.32 {q10, q11}, [r11]! + + vld1.32 {q0, q1}, [r6]! + vld1.32 {q4, q5}, [r8]! + vld1.32 {q8, q9}, [r0]! + + sub r10, r10, #16 + cmp r10, #16 + bge LoopDepth16 + + LoopDepth16Out: + vmla.f32 q8, q0, q4 + vmla.f32 q9, q1, q5 + vst1.32 {q8, q9}, [r11]! + + vld1.32 {q2, q3}, [r6]! + vld1.32 {q6, q7}, [r8]! + vld1.32 {q10, q11}, [r0]! + vmla.f32 q10, q2, q6 + vmla.f32 q11, q3, q7 + vst1.32 {q10, q11}, [r11]! + + L4: + cmp r10, #4 + blt L0 + + LoopDepth4: + vld1.32 {q0}, [r6]! + vld1.32 {q4}, [r8]! + vld1.32 {q8}, [r0]! + vmla.f32 q8, q0, q4 + vst1.32 {q8}, [r11]! + sub r10, r10, #4 + cmp r10, #4 + bge LoopDepth4 + + L0: + cmp r10, #0 + beq Loop16LineEnd + + LoopDepth0: + vld1.32 d0[0], [r6]! + vld1.32 d2[0], [r8]! + vld1.32 d4[0], [r0]! + vmla.f32 s8, s0, s4 + vst1.32 d4[0], [r11]! + subs r10, r10, #1 + bne LoopDepth0 + + Loop16LineEnd: + subs r3, r3, #1 + add r1, r1, r5 + bne LoopNumPixel + + End: + sub sp, sp, #88 + vpop {q4-q7} + pop {r4-r6, r8, r10, r11} + bx lr +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/ConvDwInt8Center.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/ConvDwInt8Center.S new file mode 100644 index 0000000000000000000000000000000000000000..0f72bfef56e52cd71bfb1cc885c062e7a12e9a5a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/ConvDwInt8Center.S @@ -0,0 +1,290 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void DepthwiseCenterInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height, +// int width, int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step, +// int in_sw_step, int in_kh_step, int in_kw_step, int8_t *in_zp, int32_t *out_zp, +// int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift, int32_t *acc_min, +// int32_t *acc_max) +// #-48: dst, #-44: src, #-40: weight, #-36: bias, #0: height, #4: width, #8: kernel_h, #12: kernel_w, +// #16: out_h_step, #20: block_channel, #24: in_sh_step, #28: in_sw_step, #32: in_kh_step, #36: in_kw_step +// #40: in_zp, #44: out_zp, #48: out_multiplier, #52: left_shift, #56: right_shift, #60:acc_min, #64: acc_max +asm_function ConvDwInt8Center +// at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" +// according to https://stackoverflow.com/questions/53625807 +// even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway +// clang's rule seems more simple, though there are no subroutine calls here +// r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r0-r8, r10, r11, lr} + vpush {q4-q7} + + ldr lr, [sp, #168] + vld1.32 {q0, q1}, [lr] + vpush {q0, q1} + ldr lr, [sp, #204] + vld1.32 {q0, q1}, [lr] + vpush {q0, q1} + ldr lr, [sp, #240] + vld1.32 {q0, q1}, [lr] + vpush {q0, q1} + add sp, sp, #208 + + ldr r1, [sp, #-36] + vld1.32 {q8, q9}, [r1] + ldr r1, [sp, #44] + vld1.32 {q10, q11}, [r1] + ldr r1, [sp, #48] + vld1.32 {q12, q13}, [r1] + ldr r1, [sp, #52] + vld1.32 {q14, q15}, [r1] + + ldr r11, [sp, #28] + ldr r4, [sp] + LoopH: + ldr r1, [sp, #-44] + ldr r0, [sp, #-48] + ldr r5, [sp, #4] + LoopW2: + vmov q4, q8 + vmov q5, q9 + vmov q6, q8 + vmov q7, q9 + mov r7, r1 + ldr r3, [sp, #-40] + ldr r6, [sp, #8] + LoopKH: + mov r9, r7 + ldr r10, [sp, #12] + LoopKW: + mov r8, r9 + vld1.16 {q0}, [r3]! + ldr lr, [sp, #40] + vld1.8 {d2}, [lr] + + vld1.8 {d3}, [r8] + add r8, r8, r11 + vsubl.s8 q2, d3, d2 + vmlal.s16 q4, d4, d0 + vmlal.s16 q5, d5, d1 + + vld1.8 {d3}, [r8] + add r8, r8, r11 + vsubl.s8 q2, d3, d2 + vmlal.s16 q6, d4, d0 + vmlal.s16 q7, d5, d1 + + ldr r12, [sp, #36] + add r9, r9, r12 + subs r10, r10, #1 + bne LoopKW + ldr r12, [sp, #32] + add r7, r7, r12 + subs r6, r6, #1 + bne LoopKH + + vshl.s32 q4, q4, q14 + vshl.s32 q5, q5, q15 + vshl.s32 q6, q6, q14 + vshl.s32 q7, q7, q15 + + vqrdmulh.s32 q4, q4, q12 + vqrdmulh.s32 q5, q5, q13 + vqrdmulh.s32 q6, q6, q12 + vqrdmulh.s32 q7, q7, q13 + + sub lr, sp, #144 + vld1.32 {q0, q1}, [lr] + + vand q2, q4, q0 + vshr.s32 q2, q2, #31 + vqadd.s32 q4, q4, q2 + vrshl.s32 q4, q4, q0 + + vand q2, q5, q1 + vshr.s32 q2, q2, #31 + vqadd.s32 q5, q5, q2 + vrshl.s32 q5, q5, q1 + + vand q2, q6, q0 + vshr.s32 q2, q2, #31 + vqadd.s32 q6, q6, q2 + vrshl.s32 q6, q6, q0 + + vand q2, q7, q1 + vshr.s32 q2, q2, #31 + vqadd.s32 q7, q7, q2 + vrshl.s32 q7, q7, q1 + + vadd.i32 q4, q4, q10 + vadd.i32 q5, q5, q11 + vadd.i32 q6, q6, q10 + vadd.i32 q7, q7, q11 + + sub lr, sp, #176 + vld1.32 {q0, q1}, [lr] + vmax.s32 q4, q4, q0 + vmax.s32 q5, q5, q1 + vmax.s32 q6, q6, q0 + vmax.s32 q7, q7, q1 + + sub lr, sp, #208 + vld1.32 {q0, q1}, [lr] + vmin.s32 q4, q4, q0 + vmin.s32 q5, q5, q1 + vmin.s32 q6, q6, q0 + vmin.s32 q7, q7, q1 + + vqmovn.s32 d0, q4 + vqmovn.s32 d1, q5 + vqmovn.s32 d2, q6 + vqmovn.s32 d3, q7 + vqmovn.s16 d0, q0 + vqmovn.s16 d1, q1 + + + ldr r12, [sp, #20] + mov r8, r0 + vst1.8 {d0[0]}, [r8]! + vst1.8 {d0[1]}, [r8]! + vst1.8 {d0[2]}, [r8]! + vst1.8 {d0[3]}, [r8]! + vst1.8 {d0[4]}, [r8]! + vst1.8 {d0[5]}, [r8]! + vst1.8 {d0[6]}, [r8]! + vst1.8 {d0[7]}, [r8]! + add r0, r0, r12 + + mov r8, r0 + vst1.8 {d1[0]}, [r8]! + vst1.8 {d1[1]}, [r8]! + vst1.8 {d1[2]}, [r8]! + vst1.8 {d1[3]}, [r8]! + vst1.8 {d1[4]}, [r8]! + vst1.8 {d1[5]}, [r8]! + vst1.8 {d1[6]}, [r8]! + vst1.8 {d1[7]}, [r8]! + add r0, r0, r12 + + add r1, r1, r11 + add r1, r1, r11 + subs r5, r5, #2 + beq LoopEndW + cmp r5, #2 + bge LoopW2 + + LoopW: + vmov q4, q8 + vmov q5, q9 + mov r7, r1 + ldr r3, [sp, #-40] + ldr r6, [sp, #8] + LoopKH1: + mov r9, r7 + ldr r10, [sp, #12] + LoopKW1: + vld1.16 {q0}, [r3]! + ldr lr, [sp, #40] + vld1.8 {d2}, [lr] + + vld1.8 {d3}, [r9] + vsubl.s8 q2, d3, d2 + vmlal.s16 q4, d4, d0 + vmlal.s16 q5, d5, d1 + + ldr r12, [sp, #36] + add r9, r9, r12 + subs r10, r10, #1 + bne LoopKW1 + ldr r12, [sp, #32] + add r7, r7, r12 + subs r6, r6, #1 + bne LoopKH1 + + vshl.s32 q4, q4, q14 + vshl.s32 q5, q5, q15 + + vqrdmulh.s32 q4, q4, q12 + vqrdmulh.s32 q5, q5, q13 + + sub lr, sp, #144 + vld1.32 {q0, q1}, [lr] + vand q2, q4, q0 + vshr.s32 q2, q2, #31 + vqadd.s32 q4, q4, q2 + vrshl.s32 q4, q4, q0 + + vand q2, q5, q1 + vshr.s32 q2, q2, #31 + vqadd.s32 q5, q5, q2 + vrshl.s32 q5, q5, q1 + + vadd.i32 q4, q4, q10 + vadd.i32 q5, q5, q11 + + sub lr, sp, #176 + vld1.32 {q0, q1}, [lr] + vmax.s32 q4, q4, q0 + vmax.s32 q5, q5, q1 + + sub lr, sp, #208 + vld1.32 {q0, q1}, [lr] + vmin.s32 q4, q4, q0 + vmin.s32 q5, q5, q1 + + vqmovn.s32 d0, q4 + vqmovn.s32 d1, q5 + vqmovn.s16 d0, q0 + + mov r8, r0 + vst1.8 {d0[0]}, [r8]! + vst1.8 {d0[1]}, [r8]! + vst1.8 {d0[2]}, [r8]! + vst1.8 {d0[3]}, [r8]! + vst1.8 {d0[4]}, [r8]! + vst1.8 {d0[5]}, [r8]! + vst1.8 {d0[6]}, [r8]! + vst1.8 {d0[7]}, [r8]! + ldr r12, [sp, #20] + add r0, r0, r12 + add r1, r1, r11 + subs r5, r5, #1 + bne LoopW + + LoopEndW: + ldr r12, [sp, #16] + ldr r1, [sp, #-48] + add r1, r1, r12 + str r1, [sp, #-48] + ldr r12, [sp, #24] + ldr r1, [sp, #-44] + add r1, r1, r12 + str r1, [sp, #-44] + subs r4, r4, #1 + bne LoopH + + LoopEndH: + sub sp, sp, #208 + vpop {q0, q1} + vpop {q0, q1} + vpop {q0, q1} + vpop {q4-q7} + pop {r0-r8, r10, r11, pc} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/ConvDwInt8PostAlign4.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/ConvDwInt8PostAlign4.S new file mode 100644 index 0000000000000000000000000000000000000000..ee222b07ae25b59d0e5874a7fd83d8b977876e97 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/ConvDwInt8PostAlign4.S @@ -0,0 +1,120 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void ConvDwInt8PostAlign4(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier, +// int32_t left_shift, int32_t right_shift, int32_t acc_min, int32_t acc_max); +// r0: dst, r1: buffer, r2: num_pixels, r3: output_zp, r4: out_multiplier, +// r5: left_shift, r6: right_shift, r7: acc_min, r8: acc_max + +asm_function ConvDwInt8PostAlign4 + // at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" + // according to https://stackoverflow.com/questions/53625807 + // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway + // clang's rule seems more simple, though there are no subroutine calls here + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r4-r8, r10} + vpush {q4-q7} + add sp, sp, #88 + + vdup.32 q15, r3 // output_zp + + ldr r4, [sp] // out_multiplier + vdup.32 q14, r4 + + ldr r5, [sp, #4] // left_shift + vdup.32 q13, r5 + + ldr r6, [sp, #8] // right_shift + vdup.32 q12, r6 + + ldr r7, [sp, #12] // acc_min + vdup.32 q11, r7 + + ldr r8, [sp, #16] // acc_max + vdup.32 q10, r8 + + mov r10, r0 + + LoopDepth8: + cmp r2, #8 + blt End + vld1.32 {q0}, [r1]! + vshl.s32 q0, q0, q13 + vqrdmulh.s32 q0, q0, q14 + vand q4, q0, q12 + vshr.s32 q4, q4, #31 + vqadd.s32 q0, q0, q4 + vrshl.s32 q0, q0, q12 + vadd.i32 q0, q0, q15 + vmax.s32 q0, q0, q11 + vmin.s32 q0, q0, q10 + vqmovn.s32 d4, q0 + + vld1.32 {q1}, [r1]! + vshl.s32 q1, q1, q13 + vqrdmulh.s32 q1, q1, q14 + vand q4, q1, q12 + vshr.s32 q4, q4, #31 + vqadd.s32 q1, q1, q4 + vrshl.s32 q1, q1, q12 + vadd.i32 q1, q1, q15 + vmax.s32 q1, q1, q11 + vmin.s32 q1, q1, q10 + vqmovn.s32 d5, q1 + vqmovn.s16 d4, q2 + + vst1.8 {d4}, [r10]! + + sub r2, r2, #8 + b LoopDepth8 + + LoopDepth4: + cmp r2, #4 + blt End + vld1.32 {q0}, [r1]! + + vshl.s32 q0, q0, q13 + vqrdmulh.s32 q0, q0, q14 + vand q4, q0, q12 + vshr.s32 q4, q4, #31 + vqadd.s32 q0, q0, q4 + vrshl.s32 q0, q0, q12 + vadd.i32 q0, q0, q15 + vmax.s32 q0, q0, q11 + vmin.s32 q0, q0, q10 + + vqmovn.s32 d0, q0 + vqmovn.s16 d0, q0 + + vst1.8 {d0[0]}, [r10]! + vst1.8 {d0[1]}, [r10]! + vst1.8 {d0[2]}, [r10]! + vst1.8 {d0[3]}, [r10]! + + sub r2, r2, #4 + b LoopDepth4 + End: + sub sp, sp, #88 + vpop {q4-q7} + pop {r4-r8, r10} + bx lr + +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/ConvDwInt8PostAlign4PerChannel.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/ConvDwInt8PostAlign4PerChannel.S new file mode 100644 index 0000000000000000000000000000000000000000..dc0e7723eb9837af5c4636ec99211a4537411d20 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/ConvDwInt8PostAlign4PerChannel.S @@ -0,0 +1,123 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void ConvDwInt8PostAlign4PerChannel(int8_t *dst, int32_t *buffer, int channel4, int32_t output_zp, int32_t *out_multiplier, +// int32_t *left_shift, int32_t *right_shift, int32_t acc_min, int32_t acc_max); +// r0: dst, r1: buffer, r2: num_pixels, r3: output_zp, r4: out_multiplier, +// r5: left_shift, r6: right_shift, r7: acc_min, r8: acc_max + +asm_function ConvDwInt8PostAlign4PerChannel + // at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" + // according to https://stackoverflow.com/questions/53625807 + // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway + // clang's rule seems more simple, though there are no subroutine calls here + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r4-r8, r10} + vpush {q4-q7} + add sp, sp, #88 + + vdup.32 q15, r3 // output_zp + + ldr r4, [sp] // out_multiplier + ldr r5, [sp, #4] // left_shift + ldr r6, [sp, #8] // right_shift + + ldr r7, [sp, #12] // acc_min + vdup.32 q11, r7 + + ldr r8, [sp, #16] // acc_max + vdup.32 q10, r8 + + mov r10, r0 + + LoopDepth8: + cmp r2, #8 + blt End + vld1.32 {q0}, [r1]! + vld1.32 {q13}, [r5]! + vshl.s32 q0, q0, q13 + vld1.32 {q14}, [r4]! + vqrdmulh.s32 q0, q0, q14 + vld1.32 {q12}, [r6]! + vand q4, q0, q12 + vshr.s32 q4, q4, #31 + vqadd.s32 q0, q0, q4 + vrshl.s32 q0, q0, q12 + vadd.i32 q0, q0, q15 + vmax.s32 q0, q0, q11 + vmin.s32 q0, q0, q10 + vqmovn.s32 d4, q0 + + vld1.32 {q1}, [r1]! + vld1.32 {q13}, [r5]! + vshl.s32 q1, q1, q13 + vld1.32 {q14}, [r4]! + vqrdmulh.s32 q1, q1, q14 + vld1.32 {q12}, [r6]! + vand q4, q1, q12 + vshr.s32 q4, q4, #31 + vqadd.s32 q1, q1, q4 + vrshl.s32 q1, q1, q12 + vadd.i32 q1, q1, q15 + vmax.s32 q1, q1, q11 + vmin.s32 q1, q1, q10 + vqmovn.s32 d5, q1 + vqmovn.s16 d4, q2 + + vst1.8 {d4}, [r10]! + + sub r2, r2, #8 + b LoopDepth8 + + LoopDepth4: + cmp r2, #4 + blt End + vld1.32 {q0}, [r1]! + vld1.32 {q13}, [r5]! + vshl.s32 q0, q0, q13 + vld1.32 {q14}, [r4]! + vqrdmulh.s32 q0, q0, q14 + vld1.32 {q12}, [r6]! + vand q4, q0, q12 + vshr.s32 q4, q4, #31 + vqadd.s32 q0, q0, q4 + vrshl.s32 q0, q0, q12 + vadd.i32 q0, q0, q15 + vmax.s32 q0, q0, q11 + vmin.s32 q0, q0, q10 + + vqmovn.s32 d0, q0 + vqmovn.s16 d0, q0 + + vst1.8 {d0[0]}, [r10]! + vst1.8 {d0[1]}, [r10]! + vst1.8 {d0[2]}, [r10]! + vst1.8 {d0[3]}, [r10]! + + sub r2, r2, #4 + b LoopDepth4 + End: + sub sp, sp, #88 + vpop {q4-q7} + pop {r4-r8, r10} + bx lr + +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/ConvDwInt8Row.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/ConvDwInt8Row.S new file mode 100644 index 0000000000000000000000000000000000000000..4e5059dc69f113c5bd5d284276477bc0178fdd26 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/ConvDwInt8Row.S @@ -0,0 +1,144 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels, +// int output_channel, int input_step, int8_t input_zp) +// r0: output_ptr, r1: input_ptr, r2: weight_ptr, r3: num_pixels, +// r4: output_channel, r5: input_step, r6: input_zp, + +asm_function ConvDwInt8Row + // at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" + // according to https://stackoverflow.com/questions/53625807 + // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway + // clang's rule seems more simple, though there are no subroutine calls here + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r4-r8, r9-r12, lr} + vpush {q4-q7} + add sp, sp, #104 + + cmp r3, #0 + beq End + + ldr r4, [sp] // channel + ldr r5, [sp, #4] // input_step + ldr r6, [sp, #8] // input_zp + vdup.8 d30, r6 + + mov r7, r0 + + LoopPixel: + mov r8, r1 // input + mov r10, r2 // weight + mov r11, r4 + + LoopDepth16In: + cmp r11, #16 + blt L8 + sub r11, r11, #16 + + vld1.8 {q0}, [r8]! + vld1.16 {q1, q2}, [r10]! // weight + + vsubl.s8 q3, d0, d30 // -zp + vld1.32 {q4, q5}, [r0]! + vmlal.s16 q4, d6, d2 + vmlal.s16 q5, d7, d3 + + cmp r11, #16 + blt LoopDepth16Out + LoopDepth16: + vst1.32 {q4, q5}, [r7]! + + vsubl.s8 q6, d1, d30 + vld1.32 {q7, q8}, [r0]! + vmlal.s16 q7, d12, d4 + vmlal.s16 q8, d13, d5 + vst1.32 {q7, q8}, [r7]! + + vld1.8 {q0}, [r8]! + vld1.16 {q1, q2}, [r10]! // weight + + vsubl.s8 q3, d0, d30 // -zp + vld1.32 {q4, q5}, [r0]! + vmlal.s16 q4, d6, d2 + vmlal.s16 q5, d7, d3 + + sub r11, r11, #16 + cmp r11, #16 + bge LoopDepth16 + + LoopDepth16Out: + vst1.32 {q4, q5}, [r7]! + + vsubl.s8 q6, d1, d30 + vld1.32 {q7, q8}, [r0]! + vmlal.s16 q7, d12, d4 + vmlal.s16 q8, d13, d5 + vst1.32 {q7, q8}, [r7]! + + L8: + cmp r11, #8 + blt L0 + + LoopDepth8: + vld1.8 {d0}, [r8]! + vld1.16 {d2, d3}, [r10]! // weight + + vsubl.s8 q2, d0, d30 // -zp + + vld1.32 {q3}, [r0]! + vmlal.s16 q3, d4, d2 + vst1.32 {q3}, [r7]! + + vld1.32 {q4}, [r0]! + vmlal.s16 q4, d5, d3 + vst1.32 {q4}, [r7]! + + sub r11, r11, #8 + cmp r11, #8 + bge LoopDepth8 + + L0: + cmp r11, #0 + beq LoopDepthEnd + + LoopDepth0: + ldrsb r12, [r8], #1 + ldrsh r9, [r10], #2 + sub r12, r12, r6 + + ldr lr, [r0], #4 + smlabb r12, r12, r9, lr + str r12, [r7], #4 + + subs r11, r11, #1 + bne L0 + + LoopDepthEnd: + add r1, r1, r5 + subs r3, r3, #1 + bne LoopPixel + + End: + sub sp, sp, #104 + vpop {q4-q7} + pop {r4-r8, r9-r12, pc} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/DeconvDwFp32Center.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/DeconvDwFp32Center.S new file mode 100644 index 0000000000000000000000000000000000000000..1eff8e9c8d63f0a283ece03621d03ac23206dda1 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/DeconvDwFp32Center.S @@ -0,0 +1,79 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void DeconvDwFp32Center(float *dst, const float *src, const float *weight, size_t height, size_t width, +// size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, +// size_t in_sw_step, size_t in_kh_step, size_t in_kw_step); +// r0: dst, r1: src, r2: weight, r3: height, r4: width, #52: kernel_h, #56: kernel_w, #60: out_h_step +// #64: block_channel, #68: in_sh_step, #72: in_sw_step, #76: in_kh_step, #80: in_kw_step +asm_function DeconvDwFp32Center + // at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" + // according to https://stackoverflow.com/questions/53625807 + // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway + // clang's rule seems more simple, though there are no subroutine calls here + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r0-r8, r10, r11, lr} + + ldr r10, [sp, #80] // in_kw_step + ldr r11, [sp, #76] // in_kh_step + + LoopH: + ldr r0, [sp] // dst_w + ldr r1, [sp, #4] // src_w + ldr r4, [sp, #48] // width + LoopW: + mov r6, r0 // dst_kh + ldr r2, [sp, #8] // weight_kh + ldr r5, [sp, #52] // kernel_h + vld1.32 {q1}, [r1] + LoopKh: + mov r7, r6 // dst_kw + ldr r12, [sp, #56] // kernel_w + LoopKw: + vld1.32 {q0}, [r7] + vld1.32 {q2}, [r2]! + vmla.f32 q0, q1, q2 + vst1.32 {q0}, [r7] + add r7, r7, r10 + subs r12, r12, #1 + bne LoopKw + add r6, r6, r11 + subs r5, r5, #1 + bne LoopKh + ldr r12, [sp, #72] + add r0, r0, r12 + ldr r8, [sp, #64] + add r1, r1, r8 + subs r4, r4, #1 + bne LoopW + ldr r8, [sp, #68] + ldr r12, [sp] + add r12, r12, r8 + str r12, [sp] + ldr r8, [sp, #60] + ldr r12, [sp, #4] + add r12, r12, r8 + str r12, [sp, #4] + subs r3, r3, #1 + bne LoopH + + pop {r0-r8, r10, r11, pc} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/DeconvDwInt8Center.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/DeconvDwInt8Center.S new file mode 100644 index 0000000000000000000000000000000000000000..2a8d69022b5936a76116359233190f96c6fa1aea --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/DeconvDwInt8Center.S @@ -0,0 +1,79 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void DeconvDwInt8Center(int32_t *dst, const int16_t *src, const int16_t *weight, size_t height, size_t width, +// size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, +// size_t in_sw_step, size_t in_kh_step, size_t in_kw_step); +// r0: dst, r1: src, r2: weight, r3: height, r4: width, #52: kernel_h, #56: kernel_w, #60: out_h_step +// #64: block_channel, #68: in_sh_step, #72: in_sw_step, #76: in_kh_step, #80: in_kw_step +asm_function DeconvDwInt8Center + // at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" + // according to https://stackoverflow.com/questions/53625807 + // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway + // clang's rule seems more simple, though there are no subroutine calls here + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r0-r8, r10, r11, lr} + + ldr r10, [sp, #80] // in_kw_step + ldr r11, [sp, #76] // in_kh_step + + LoopH: + ldr r0, [sp] // dst_w + ldr r1, [sp, #4] // src_w + ldr r4, [sp, #48] // width + LoopW: + mov r6, r0 // dst_kh + ldr r2, [sp, #8] // weight_kh + ldr r5, [sp, #52] // kernel_h + vld1.16 {d2}, [r1] + LoopKh: + mov r7, r6 // dst_kw + ldr r12, [sp, #56] // kernel_w + LoopKw: + vld1.32 {q0}, [r7] + vld1.16 {d24}, [r2]! + vmlal.s16 q0, d2, d24 + vst1.32 {q0}, [r7] + add r7, r7, r10 + subs r12, r12, #1 + bne LoopKw + add r6, r6, r11 + subs r5, r5, #1 + bne LoopKh + ldr r12, [sp, #72] + add r0, r0, r12 + ldr r8, [sp, #64] + add r1, r1, r8 + subs r4, r4, #1 + bne LoopW + ldr r8, [sp, #68] + ldr r12, [sp] + add r12, r12, r8 + str r12, [sp] + ldr r8, [sp, #60] + ldr r12, [sp, #4] + add r12, r12, r8 + str r12, [sp, #4] + subs r3, r3, #1 + bne LoopH + + pop {r0-r8, r10, r11, pc} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/DeconvDwInt8Post.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/DeconvDwInt8Post.S new file mode 100644 index 0000000000000000000000000000000000000000..b5ea1cacf97628061834a603f8e01b10cf90e84b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/DeconvDwInt8Post.S @@ -0,0 +1,84 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void DeconvDwInt8Post(int8_t *dst, int32_t *output_buffer, const int32_t *bias, int block_channel, int pixel_nums, +// int out_multiplier, int left_shift, int right_shift, int32_t out_zp, int32_t acc_min, +// int32_t acc_max) +// r0: dst, r1: output_buffer, r2: bias, r3: block_channel, r4: pixel_nums, r5: out_multiplier, +// r6: left_shift, r7: right_shift, r8: out_zp, r9: acc_min, r10: acc_max + +asm_function DeconvDwInt8Post + // at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" + // according to https://stackoverflow.com/questions/53625807 + // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway + // clang's rule seems more simple, though there are no subroutine calls here + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r4-r8} + add sp, sp, #20 + + vld1.32 {q9}, [r2] + ldr r4, [sp] + ldr r5, [sp, #4] + vdup.32 q14, r5 // out_multiplier + ldr r6, [sp, #8] + vdup.32 q13, r6 // left_shift + ldr r5, [sp, #12] + vdup.32 q12, r5 // right_shift + ldr r6, [sp, #16] + vdup.32 q15, r6 // output_zp + ldr r7, [sp, #20] + vdup.32 q11, r7 // acc_min + ldr r8, [sp, #24] + vdup.32 q10, r8 // acc_max + + LoopCount: + mov r8, r0 + vld1.32 {q0}, [r1]! + vand q0, q0, q9 + + vshl.s32 q0, q0, q13 + vqrdmulh.s32 q0, q0, q14 + vand q4, q0, q12 + vshr.s32 q4, q4, #31 + vqadd.s32 q0, q0, q4 + vrshl.s32 q0, q0, q12 + vadd.i32 q0, q0, q15 + vmax.s32 q0, q0, q11 + vmin.s32 q0, q0, q10 + + vqmovn.s32 d0, q0 + vqmovn.s16 d0, q0 + + vst1.8 {d0[0]}, [r8]! + vst1.8 {d0[1]}, [r8]! + vst1.8 {d0[2]}, [r8]! + vst1.8 {d0[3]}, [r8]! + add r0, r0, r3 + + sub r4, r4, #1 + cmp r4, #1 + bge LoopCount + End: + sub sp, sp, #20 + pop {r4-r8} + bx lr + +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/IndirectGemmInt16to32_8x4.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/IndirectGemmInt16to32_8x4.S new file mode 100644 index 0000000000000000000000000000000000000000..7758ef884baf575fe5d0aa439215f084d6fd729b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/IndirectGemmInt16to32_8x4.S @@ -0,0 +1,249 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void IndirectGemmInt16to32_8x4(int *output, short *input, short *weight, size_t kszie, size_t ic8, size_t oc4, size_t offset); +// r0: output, r1: input, r2: weight, r3: kszie, r4: ic8, r5: oc4, r6: offset +asm_function IndirectGemmInt16to32_8x4 + + .macro INIT_ZERO + // we could also use "vmov.s32 q12, #0" to initialize q12 by 0 + veor q12, q12, q12 + veor q13, q13, q13 + veor q14, q14, q14 + veor q15, q15, q15 + .endm + + // at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" + // according to https://stackoverflow.com/questions/53625807 + // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway + // clang's rule seems more simple, though there are no subroutine calls here + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r4-r8, r10, lr} + + ldr r4, [sp, #28] + ldr r5, [sp, #32] + ldr r6, [sp, #36] + + vpush {q4-q7} + + LoopOc: + + mov r7, r3 + mov r8, r1 + + LoopKsize: + mov r10, r0 + INIT_ZERO + + // load input + vld1.16 {q0, q1}, [r8]! + // load weight + vld1.16 {q4}, [r2]! + vmull.s16 q8, d8, d0[0] + vmull.s16 q9, d8, d2[0] + // load weight + vld1.16 {q5}, [r2]! + vmlal.s16 q8, d9, d0[1] + vmlal.s16 q9, d9, d2[1] + // load input + vld1.16 {q2, q3}, [r8]! + vmlal.s16 q8, d10, d0[2] + vmlal.s16 q9, d10, d2[2] + vmlal.s16 q8, d11, d0[3] + vmlal.s16 q9, d11, d2[3] + // load weight + vld1.16 {q6, q7}, [r2]! + vmull.s16 q10, d8, d4[0] + vmull.s16 q11, d8, d6[0] + + subs r12, r4, #1 + beq LoopIcEnd + + LoopIc: + + vmlal.s16 q10, d9, d4[1] + vmlal.s16 q11, d9, d6[1] + vmlal.s16 q10, d10, d4[2] + vmlal.s16 q11, d10, d6[2] + vmlal.s16 q10, d11, d4[3] + vmlal.s16 q11, d11, d6[3] + + vmlal.s16 q8, d12, d1[0] + vmlal.s16 q9, d12, d3[0] + vmlal.s16 q8, d13, d1[1] + vmlal.s16 q9, d13, d3[1] + vmlal.s16 q8, d14, d1[2] + vmlal.s16 q9, d14, d3[2] + vmlal.s16 q8, d15, d1[3] + vmlal.s16 q9, d15, d3[3] + // load input + vld1.16 {q0, q1}, [r8]! + vmlal.s16 q10, d12, d5[0] + vmlal.s16 q11, d12, d7[0] + vmlal.s16 q10, d13, d5[1] + vmlal.s16 q11, d13, d7[1] + vmlal.s16 q10, d14, d5[2] + vmlal.s16 q11, d14, d7[2] + vmlal.s16 q10, d15, d5[3] + vmlal.s16 q11, d15, d7[3] + + // load input + vld1.16 {q2, q3}, [r8]! + vmlal.s16 q12, d8, d0[0] + vmlal.s16 q13, d8, d2[0] + vmlal.s16 q12, d9, d0[1] + vmlal.s16 q13, d9, d2[1] + vmlal.s16 q12, d10, d0[2] + vmlal.s16 q13, d10, d2[2] + vmlal.s16 q12, d11, d0[3] + vmlal.s16 q13, d11, d2[3] + + vmlal.s16 q14, d8, d4[0] + vmlal.s16 q15, d8, d6[0] + vmlal.s16 q14, d9, d4[1] + vmlal.s16 q15, d9, d6[1] + vmlal.s16 q14, d10, d4[2] + vmlal.s16 q15, d10, d6[2] + vmlal.s16 q14, d11, d4[3] + vmlal.s16 q15, d11, d6[3] + // load weight + vld1.16 {q4, q5}, [r2]! + vmlal.s16 q12, d12, d1[0] + vmlal.s16 q13, d12, d3[0] + vmlal.s16 q12, d13, d1[1] + vmlal.s16 q13, d13, d3[1] + vmlal.s16 q12, d14, d1[2] + vmlal.s16 q13, d14, d3[2] + vmlal.s16 q12, d15, d1[3] + vmlal.s16 q13, d15, d3[3] + // load input + vld1.16 {q0, q1}, [r8]! + vmlal.s16 q14, d12, d5[0] + vmlal.s16 q15, d12, d7[0] + vmlal.s16 q14, d13, d5[1] + vmlal.s16 q15, d13, d7[1] + vmlal.s16 q14, d14, d5[2] + vmlal.s16 q15, d14, d7[2] + vmlal.s16 q14, d15, d5[3] + vmlal.s16 q15, d15, d7[3] + // load input + vld1.16 {q2, q3}, [r8]! + vmlal.s16 q8, d8, d0[0] + vmlal.s16 q9, d8, d2[0] + vmlal.s16 q8, d9, d0[1] + vmlal.s16 q9, d9, d2[1] + // load weight + vld1.16 {q6, q7}, [r2]! + vmlal.s16 q8, d10, d0[2] + vmlal.s16 q9, d10, d2[2] + vmlal.s16 q8, d11, d0[3] + vmlal.s16 q9, d11, d2[3] + vmlal.s16 q10, d8, d4[0] + vmlal.s16 q11, d8, d6[0] + + subs r12, r12, #1 + bne LoopIc + + LoopIcEnd: + + vmlal.s16 q10, d9, d4[1] + vmlal.s16 q11, d9, d6[1] + vmlal.s16 q10, d10, d4[2] + vmlal.s16 q11, d10, d6[2] + vmlal.s16 q10, d11, d4[3] + vmlal.s16 q11, d11, d6[3] + + vmlal.s16 q8, d12, d1[0] + vmlal.s16 q9, d12, d3[0] + vmlal.s16 q8, d13, d1[1] + vmlal.s16 q9, d13, d3[1] + vmlal.s16 q8, d14, d1[2] + vmlal.s16 q9, d14, d3[2] + vmlal.s16 q8, d15, d1[3] + vmlal.s16 q9, d15, d3[3] + // load input + vld1.16 {q0, q1}, [r8]! + vmlal.s16 q10, d12, d5[0] + vmlal.s16 q11, d12, d7[0] + vmlal.s16 q10, d13, d5[1] + vst1.32 {q8}, [r10], r6 + vmlal.s16 q11, d13, d7[1] + vmlal.s16 q10, d14, d5[2] + vst1.32 {q9}, [r10], r6 + vmlal.s16 q11, d14, d7[2] + vmlal.s16 q10, d15, d5[3] + vmlal.s16 q11, d15, d7[3] + + // load input + vld1.s16 {q2, q3}, [r8]! + vmlal.s16 q12, d8, d0[0] + vmlal.s16 q13, d8, d2[0] + vmlal.s16 q12, d9, d0[1] + vst1.32 {q10}, [r10], r6 + vmlal.s16 q13, d9, d2[1] + vmlal.s16 q12, d10, d0[2] + vst1.32 {q11}, [r10], r6 + vmlal.s16 q13, d10, d2[2] + vmlal.s16 q12, d11, d0[3] + vmlal.s16 q13, d11, d2[3] + + vmlal.s16 q14, d8, d4[0] + vmlal.s16 q15, d8, d6[0] + vmlal.s16 q14, d9, d4[1] + vmlal.s16 q15, d9, d6[1] + vmlal.s16 q14, d10, d4[2] + vmlal.s16 q15, d10, d6[2] + vmlal.s16 q14, d11, d4[3] + vmlal.s16 q15, d11, d6[3] + + vmlal.s16 q12, d12, d1[0] + vmlal.s16 q13, d12, d3[0] + vmlal.s16 q12, d13, d1[1] + vmlal.s16 q13, d13, d3[1] + vmlal.s16 q12, d14, d1[2] + vmlal.s16 q13, d14, d3[2] + vmlal.s16 q12, d15, d1[3] + vmlal.s16 q13, d15, d3[3] + vst1.32 {q12}, [r10], r6 + vmlal.s16 q14, d12, d5[0] + vmlal.s16 q15, d12, d7[0] + vmlal.s16 q14, d13, d5[1] + vmlal.s16 q15, d13, d7[1] + vmlal.s16 q14, d14, d5[2] + vst1.32 {q13}, [r10], r6 + vmlal.s16 q15, d14, d7[2] + vmlal.s16 q14, d15, d5[3] + vmlal.s16 q15, d15, d7[3] + + vst1.32 {q14}, [r10], r6 + vst1.32 {q15}, [r10] + + subs r7, r7, #1 + add r0, r0, #16 + bne LoopKsize + + subs r5, r5, #1 + bne LoopOc + + vpop {q4-q7} + pop {r4-r8, r10, pc} + +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/IndirectGemmInt8_2x4.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/IndirectGemmInt8_2x4.S new file mode 100644 index 0000000000000000000000000000000000000000..cf50565c52ee54729155a719a53b479d2fb009d5 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/IndirectGemmInt8_2x4.S @@ -0,0 +1,306 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void IndirectGemmInt8_2x4(int8_t *output, int8_t *input, int8_t *weight, int32_t *bias, size_t ksize, size_t ic4, +// size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, int32_t *out_multiplier, +// int32_t *shift_before, int32_t *shift_after, size_t asymmetric, size_t per_channel, size_t per_channel_offset); +// r0: output, r1: input, r2: weight, r3: bias, r4: kSize, r5: ic4, r6: oc, r7: offset +// r8: input_sum, r10: act_min, r11: act_max, r10: out_zp, r11: out_multiplier, r10: shift_before, r11: shift_after +asm_function IndirectGemmInt8_2x4 + + .macro INIT_BIAS + veor q10, q10, q10 + veor q11, q11, q11 + veor q12, q12, q12 + veor q13, q13, q13 + veor q14, q14, q14 + veor q15, q15, q15 + .endm + + // at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" + // according to https://stackoverflow.com/questions/53625807 + // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway + // clang's rule seems more simple, though there are no subroutine calls here + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r4-r8, r10, r11, lr} + vpush {q4-q7} + add sp, sp, #96 + + ldr r4, [sp] + ldr r5, [sp, #4] + ldr r6, [sp, #8] + ldr r7, [sp, #12] + + mul r5, r4, r5 + mov r4, #1 + + LoopOc: + + mov r8, r4 + mov r12, r1 + + LoopKsize: + INIT_BIAS + mov r11, r0 + + // as some processors do not support sdot intrinsic, we use instruction word + // dp support is stilled judged dymaticly, instruction word is just used to ensure compilation + // according to https://static.docs.arm.com/ddi0596/g/ISA_A64_xml_v86A-2020-03_OPT.pdf + // the instruction word of sdot vd.4s, vn.16b, vm.4b[index] is + // 0100 1111 10Lm mmmm 1110 H0nn nnnd dddd + // mmmmm/nnnnn/ddddd is the number of neon register, HL is the high/low bit of index + + // load input for output 1-2 + vld1.8 {q0, q1}, [r12]! + // load weight for oc 1-2 + vld1.8 {q2, q3}, [r2]! + vmull.s8 q6, d0, d4 + vmull.s8 q7, d0, d6 + vmlal.s8 q6, d1, d5 + vmlal.s8 q7, d1, d7 + vpaddl.s16 q8, q6 + vpaddl.s16 q9, q7 + // load weight for oc 3-4 + vld1.8 {q4, q5}, [r2]! + vmull.s8 q6, d0, d8 + vmull.s8 q7, d0, d10 + vmlal.s8 q6, d1, d9 + vmlal.s8 q7, d1, d11 + + subs r10, r5, #1 + beq LoopIcEnd + + LoopIc: + // load input for output 1 + vld1.8 {q0}, [r12]! + vpadal.s16 q10, q6 + vpadal.s16 q11, q7 + vmull.s8 q6, d2, d4 + vmull.s8 q7, d2, d6 + vmlal.s8 q6, d3, d5 + vmlal.s8 q7, d3, d7 + vld1.8 {q2, q3}, [r2]! + vpadal.s16 q12, q6 + vpadal.s16 q13, q7 + vmull.s8 q6, d2, d8 + vmull.s8 q7, d2, d10 + vmlal.s8 q6, d3, d9 + vmlal.s8 q7, d3, d11 + vld1.8 {q4, q5}, [r2]! + vpadal.s16 q14, q6 + vpadal.s16 q15, q7 + vmull.s8 q6, d0, d4 + vmull.s8 q7, d0, d6 + vmlal.s8 q6, d1, d5 + vmlal.s8 q7, d1, d7 + vld1.8 {q1}, [r12]! + vpadal.s16 q8, q6 + vpadal.s16 q9, q7 + vmull.s8 q6, d0, d8 + vmull.s8 q7, d0, d10 + vmlal.s8 q6, d1, d9 + vmlal.s8 q7, d1, d11 + + subs r10, r10, #1 + bne LoopIc + + LoopIcEnd: + vpadal.s16 q10, q6 + vpadal.s16 q11, q7 + vmull.s8 q6, d2, d4 + vmull.s8 q7, d2, d6 + vmlal.s8 q6, d3, d5 + vmlal.s8 q7, d3, d7 + vpadal.s16 q12, q6 + vpadal.s16 q13, q7 + vmull.s8 q6, d2, d8 + vmull.s8 q7, d2, d10 + vmlal.s8 q6, d3, d9 + vmlal.s8 q7, d3, d11 + vpadal.s16 q14, q6 + vpadal.s16 q15, q7 + + // pairwise add + vpadd.i32 d16, d16, d17 + vpadd.i32 d18, d18, d19 + vpadd.i32 d20, d20, d21 + vpadd.i32 d22, d22, d23 + vpadd.i32 d24, d24, d25 + vpadd.i32 d26, d26, d27 + vpadd.i32 d28, d28, d29 + vpadd.i32 d30, d30, d31 + + vpadd.i32 d16, d16, d18 + vpadd.i32 d17, d20, d22 + vpadd.i32 d24, d24, d26 + vpadd.i32 d25, d28, d30 + + // load sum + ldr lr, [sp, #44] + cmp lr, #0 + beq NoSum + ldr r10, [sp, #16] + ldr lr, [sp, #48] + cmp lr, #0 + beq SymSum + ldr lr, [sp, #52] + vld1.32 {d0, d1}, [r10] + add r10, r10, lr + vld1.32 {d2, d3}, [r10] + b AddSum + SymSum: + vld1.32 {d0[], d1[]}, [r10]! + vld1.32 {d2[], d3[]}, [r10]! + AddSum: + vsub.i32 q8, q8, q0 + vsub.i32 q12, q12, q1 + NoSum: + cmp r3, #0 + beq NoBias + vld1.32 {d4, d5}, [r3] + vadd.i32 q8, q8, q2 + vadd.i32 q12, q12, q2 + + NoBias: + ldr lr, [sp, #48] + cmp lr, #0 + bne PerChannel + ldr lr, [sp, #36] + vld1.32 {d6[], d7[]}, [lr] + ldr lr, [sp, #32] + vld1.32 {d8[], d9[]}, [lr] + ldr lr, [sp, #40] + vld1.32 {d10[], d11[]}, [lr] + b QuantizeStart + PerChannel: + ldr lr, [sp, #36] + vld1.32 {d6, d7}, [lr] + ldr lr, [sp, #32] + vld1.32 {d8, d9}, [lr] + ldr lr, [sp, #40] + vld1.32 {d10, d11}, [lr] + QuantizeStart: + vshl.s32 q8, q8, q3 + vshl.s32 q12, q12, q3 + + vqrdmulh.s32 q8, q8, q4 + vqrdmulh.s32 q12, q12, q4 + + vand q3, q5, q8 + vshr.s32 q3, q3, #31 + vqadd.s32 q8, q8, q3 + vrshl.s32 q8, q8, q5 + vand q4, q5, q12 + vshr.s32 q4, q4, #31 + vqadd.s32 q12, q12, q4 + vrshl.s32 q12, q12, q5 + + ldr r10, [sp, #28] + vdup.32 q6, r10 + vadd.i32 q8, q8, q6 + vadd.i32 q12, q12, q6 + + ldr r10, [sp, #20] + vdup.32 q0, r10 + vmax.s32 q8, q8, q0 + vmax.s32 q12, q12, q0 + + ldr r10, [sp, #24] + vdup.32 q1, r10 + vmin.s32 q8, q8, q1 + vmin.s32 q12, q12, q1 + + vqmovn.s32 d30, q8 + vqmovn.s32 d31, q12 + vqmovn.s16 d0, q15 + + // prefetching is not preferred while writing results in spite of cache missing + // you could try prfm pstl2strm + WriteStart: + cmp r6, #1 + beq Write1 + cmp r6, #2 + beq Write2 + cmp r6, #3 + beq Write3 + b Write4 + Write1: + vst1.8 {d0[0]}, [r11], r7 + vst1.8 {d0[1]}, [r11] + add r0, r0, #1 + b WriteEnd + Write2: + vst1.16 {d0[0]}, [r11], r7 + vst1.16 {d0[1]}, [r11] + add r0, r0, #2 + b WriteEnd + Write3: + add r14, r11, #2 + vst1.16 {d0[0]}, [r11], r7 + vst1.16 {d0[1]}, [r11] + vst1.8 {d0[0]}, [r14], r7 + vst1.8 {d0[1]}, [r14] + add r0, r0, #3 + b WriteEnd + Write4: + vst1.32 {d0[0]}, [r11], r7 + vst1.32 {d0[1]}, [r11] + add r0, r0, #4 + + WriteEnd: + + subs r8, r8, #1 + bne LoopKsize + + cmp r6, #4 + ble LoopOcEnd + ldr lr, [sp, #48] + cmp lr, #0 + beq NoChannelForward + ldr lr, [sp, #44] + cmp lr, #0 + beq NoSumForward + ldr lr, [sp, #16] + add lr, lr, #16 + str lr, [sp, #16] + NoSumForward: + ldr lr, [sp, #36] + add lr, lr, #16 + str lr, [sp, #36] + ldr lr, [sp, #32] + add lr, lr, #16 + str lr, [sp, #32] + ldr lr, [sp, #40] + add lr, lr, #16 + str lr, [sp, #40] + NoChannelForward: + sub r6, r6, #4 + cmp r3, #0 + beq NoStepFowrard + add r3, r3, #16 + NoStepFowrard: + b LoopOc + +LoopOcEnd: + sub sp, sp, #96 + vpop {q4-q7} + pop {r4-r8, r10, r11, pc} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/MatVecMulFp32.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/MatVecMulFp32.S new file mode 100644 index 0000000000000000000000000000000000000000..ab009f9264e9e2017c523633fd6111e63e634a17 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/MatVecMulFp32.S @@ -0,0 +1,195 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void MatVecMulFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col) +// r0: a +// r1: b +// r2: c +// r3: bias +// r4: act_type +// r5: depth +// r6: col + +asm_function MatVecMulFp32 + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r0-r8, r9, r10, r11, lr} + add sp, sp, #52 + + ldr r4, [sp] + ldr r5, [sp, #4] + ldr r6, [sp, #8] + + mov r10, #4 + mul r10, r10, r5 // stride = depth * sizeof(float) + mov r11, #4 + mul r11, r11, r10 // stride x 4 + + cmp r6, #4 + blt Col1Loop + +Col4Loop: + mov r7, r0 // reload a(vector) ptr + mov r9, r1 // reload b(matrix) ptr + mov r8, r5 // reload depth value + + veor q9, q9, q9 + veor q10, q10, q10 + veor q11, q11, q11 + veor q12, q12, q12 + veor q15, q15, q15 + + cmp r8, #4 + blt Col4Depth1 + + Col4Depth4: + vld1.f32 {q8}, [r7]! + add lr, r9, r10 + vld1.f32 {q0}, [r9]! + vld1.f32 {q1}, [lr], r10 + vld1.f32 {q2}, [lr], r10 + vld1.f32 {q3}, [lr] + + vmla.f32 q9, q8, q0 + vmla.f32 q10, q8, q1 + vmla.f32 q11, q8, q2 + vmla.f32 q12, q8, q3 + sub r8, r8, #4 + cmp r8, #4 + bge Col4Depth4 + + vpadd.f32 d26, d18, d20 + vpadd.f32 d27, d19, d21 + vpadd.f32 d28, d22, d24 + vpadd.f32 d29, d23, d25 + vadd.f32 d30, d26, d27 + vadd.f32 d31, d28, d29 + cmp r8, #0 + beq Col4End + + Col4Depth1: + vld1.f32 {d0[0]}, [r7]! + add lr, r9, r10 + vld1.f32 {d2[0]}, [r9]! + vld1.f32 {d2[1]}, [lr], r10 + vld1.f32 {d3[0]}, [lr], r10 + vld1.f32 {d3[1]}, [lr] + + vmla.f32 q15, q1, d0[0] + subs r8, r8, #1 + bne Col4Depth1 + + Col4End: + cmp r3, #0 + beq Col4Activation + vld1.f32 {q13}, [r3]! + vadd.f32 q15, q15, q13 + + Col4Activation: + cmp r4, #3 + beq Col4Relu6 + cmp r4, #1 + beq Col4Relu + b Col4Write + + Col4Relu6: + vmov.i32 q12, #6 + vcvt.f32.s32 q12, q12 + vmin.f32 q15, q15, q12 + + Col4Relu: + veor q13, q13, q13 + vmax.f32 q15, q15, q13 + + Col4Write: + vst1.f32 {q15}, [r2]! + subs r6, r6, #4 + beq End + add r1, r1, r11 + cmp r6, #4 + bge Col4Loop + +Col1Loop: + mov r7, r0 // reload a(vector) ptr + mov r9, r1 // reload b(matrix) ptr + mov r8, r5 // reload depth value + veor q10, q10, q10 + veor q13, q13, q13 + veor q15, q15, q15 + + cmp r8, #4 + blt Col1Depth1 + + Col1Depth4: + vld1.f32 {q0}, [r7]! + vld1.f32 {q1}, [r9]! + + vmla.f32 q10, q1, q0 + sub r8, r8, #4 + cmp r8, #4 + bge Col1Depth4 + + vpadd.f32 d24, d20, d22 + vpadd.f32 d25, d21, d23 + vadd.f32 d30, d24, d25 + cmp r8, #0 + beq Col1End + + Col1Depth1: + vld1.f32 {d0[0]}, [r7]! + vld1.f32 {d2[0]}, [r9]! + + vmla.f32 d30, d2, d0[0] + subs r8, r8, #1 + bne Col1Depth1 + + Col1End: + cmp r3, #0 + beq Col1Activation + vld1.f32 {d28[0]}, [r3]! + vadd.f32 d30, d30, d28 + + Col1Activation: + cmp r4, #3 + beq Col1Relu6 + cmp r4, #1 + beq Col1Relu + b Col1Write + + Col1Relu6: + vmov.i32 d26, #6 + vcvt.f32.s32 d26, d26 + vmin.f32 d30, d30, d26 + + Col1Relu: + veor d24, d24, d24 + vmax.f32 d30, d30, d24 + + Col1Write: + vst1.f32 {d30[0]}, [r2]! + subs r6, r6, #1 + beq End + add r1, r1, r10 + b Col1Loop + +End: + sub sp, sp, #52 + pop {r0-r8, r9, r10, r11, pc} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/MatmulFp32.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/MatmulFp32.S new file mode 100644 index 0000000000000000000000000000000000000000..3bc36a7080cab1d60995435d3db010a6d2e499df --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/MatmulFp32.S @@ -0,0 +1,381 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void MatmulFloatNeon32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth +// int row, int col, size_t stride, size_t writeNhwc, size_t WriteWino) +// r0: a +// r1: b +// r2: c +// r3: bias +// r4: act_type +// r5: depth +// r6: row +// r7: col +// r8: stride +// lr: writeNhwc/writeWino + +asm_function MatmulFloatNeon32 + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r0-r8, r10, r11, lr} + add sp, sp, #48 + + ldr r5, [sp, #4] + ldr r7, [sp, #12] + ldr r8, [sp, #16] + + mov lr, #32 // sizeof(float) * 8 + mul r12, r5, lr // block stride of lhs/rhs: sizeof(float) * 8 * depth + ldr lr, [sp, #24] + cmp lr, #0 + beq NoWinoSteps + mov lr, #4 + mul r11, r7, r8 // stride * col * sizeof(float) + mul r11, r11, lr + mov lr, #32 + mul r10, r8, lr // stride * 8 * sizeof(float) +NoWinoSteps: + mov lr, #4 + mul r8, r8, lr // stride * sizeof(float) + +LoopCol: + ldr r6, [sp, #8] // reload lhs row + ldr r0, [sp, #-48] // reload lhs ptr + ldr r2, [sp, #-40] // reload dst ptr + + LoopRow: + ldr r1, [sp, #-44] // reload rhs ptr + ldr r5, [sp, #4] // reload depth + veor q8, q8, q8 + veor q9, q9, q9 + veor q10, q10, q10 + veor q11, q11, q11 + veor q12, q12, q12 + veor q13, q13, q13 + veor q14, q14, q14 + veor q15, q15, q15 + + LoopDepth: + vld1.32 {q0}, [r0]! + vld1.32 {q1, q2}, [r1]! + vmla.f32 q8, q1, d0[0] + vmla.f32 q9, q2, d0[0] + vmla.f32 q10, q1, d0[1] + vmla.f32 q11, q2, d0[1] + vmla.f32 q12, q1, d1[0] + vmla.f32 q13, q2, d1[0] + vmla.f32 q14, q1, d1[1] + vmla.f32 q15, q2, d1[1] + + subs r5, r5, #1 + bne LoopDepth + + Bias: + cmp r3, #0 + beq Activation + vld1.32 {q0}, [r3]! + vld1.32 {q1}, [r3] + sub r3, r3, #16 + vadd.f32 q8, q8, q0 + vadd.f32 q9, q9, q1 + vadd.f32 q10, q10, q0 + vadd.f32 q11, q11, q1 + vadd.f32 q12, q12, q0 + vadd.f32 q13, q13, q1 + vadd.f32 q14, q14, q0 + vadd.f32 q15, q15, q1 + + Activation: + ldr lr, [sp] + cmp lr, #3 + beq Relu6 + cmp lr, #1 + beq Relu + b Write + + Relu6: + vmov.i32 q2, #6 + vcvt.f32.s32 q2, q2 + vmin.f32 q8, q8, q2 + vmin.f32 q9, q9, q2 + vmin.f32 q10, q10, q2 + vmin.f32 q11, q11, q2 + vmin.f32 q12, q12, q2 + vmin.f32 q13, q13, q2 + vmin.f32 q14, q14, q2 + vmin.f32 q15, q15, q2 + + Relu: + veor q3, q3, q3 + vmax.f32 q8, q8, q3 + vmax.f32 q9, q9, q3 + vmax.f32 q10, q10, q3 + vmax.f32 q11, q11, q3 + vmax.f32 q12, q12, q3 + vmax.f32 q13, q13, q3 + vmax.f32 q14, q14, q3 + vmax.f32 q15, q15, q3 + + Write: + ldr lr, [sp, #24] + cmp lr, #0 + bne WriteWino + ldr lr, [sp, #20] + cmp lr, #0 + beq WriteC8 + cmp r7, #1 + beq Write1 + cmp r7, #2 + beq Write2 + cmp r7, #3 + beq Write3 + cmp r7, #4 + beq Write4 + cmp r7, #5 + beq Write5 + cmp r7, #6 + beq Write6 + cmp r7, #7 + beq Write7 + b Write8 + + Write1: + vst1.32 d16[0], [r2] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + vst1.32 d20[0], [r2] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + vst1.32 d24[0], [r2] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + vst1.32 d28[0], [r2] + add r2, r2, r8 + b WriteEnd + Write2: + vst1.32 d16, [r2] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + vst1.32 d20, [r2] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + vst1.32 d24, [r2] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + vst1.32 d28, [r2] + add r2, r2, r8 + b WriteEnd + Write3: + add r4, r2, #8 + vst1.32 d16, [r2] + vst1.32 d17[0], [r4] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d20, [r2] + vst1.32 d21[0], [r4] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d24, [r2] + vst1.32 d25[0], [r4] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d28, [r2] + vst1.32 d29[0], [r4] + add r2, r2, r8 + b WriteEnd + Write4: + vst1.32 {d16, d17}, [r2] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d20, d21}, [r2] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d24, d25}, [r2] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d28, d29}, [r2] + add r2, r2, r8 + b WriteEnd + Write5: + add r4, r2, #16 + vst1.32 {d16, d17}, [r2] + vst1.32 d18[0], [r4] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 {d20, d21}, [r2] + vst1.32 d22[0], [r4] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 {d24, d25}, [r2] + vst1.32 d26[0], [r4] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 {d28, d29}, [r2] + vst1.32 d30[0], [r4] + add r2, r2, r8 + b WriteEnd + Write6: + add r4, r2, #16 + vst1.32 {d16, d17}, [r2] + vst1.32 d18, [r4] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 {d20, d21}, [r2] + vst1.32 d22, [r4] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 {d24, d25}, [r2] + vst1.32 d26, [r4] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 {d28, d29}, [r2] + vst1.32 d30, [r4] + add r2, r2, r8 + b WriteEnd + Write7: + add lr, r2, #24 + add r4, r2, #16 + vst1.32 {d16, d17}, [r2] + vst1.32 d18, [r4] + vst1.32 d19[0], [lr] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + add lr, lr, r8 + vst1.32 {d20, d21}, [r2] + vst1.32 d22, [r4] + vst1.32 d23[0], [lr] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + add lr, lr, r8 + vst1.32 {d24, d25}, [r2] + vst1.32 d26, [r4] + vst1.32 d27[0], [lr] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + add lr, lr, r8 + vst1.32 {d28, d29}, [r2] + vst1.32 d30, [r4] + vst1.32 d31[0], [lr] + add r2, r2, r8 + b WriteEnd + WriteC8: + vst1.32 {q8, q9}, [r2]! + vst1.32 {q10, q11}, [r2]! + vst1.32 {q12, q13}, [r2]! + vst1.32 {q14, q15}, [r2]! + str r2, [sp, #-40] + b WriteEnd + WriteWino: + vst1.32 {q8, q9}, [r2] + add r2, r2, r11 + vst1.32 {q10, q11}, [r2] + add r2, r2, r11 + vst1.32 {q12, q13}, [r2] + add r2, r2, r11 + vst1.32 {q14, q15}, [r2] + add r2, r2, r11 + b WriteEnd + Write8: + vst1.32 {q8, q9}, [r2] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + vst1.32 {q10, q11}, [r2] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + vst1.32 {q12, q13}, [r2] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + vst1.32 {q14, q15}, [r2] + add r2, r2, r8 + + WriteEnd: + cmp r6, #4 + ble LoopRowEnd + sub r6, r6, #4 // lhs row - 4 + b LoopRow + + LoopRowEnd: + ldr r1, [sp, #-44] + add r1, r1, r12 // rhs ptr + stride + str r1, [sp, #-44] + cmp r3, #0 + beq NoBiasStep + add r3, r3, #32 // bias ptr + stride + NoBiasStep: + ldr lr, [sp, #24] + cmp lr, #0 + bne WinoDstStep + ldr lr, [sp, #20] + cmp lr, #0 + beq NoDstStep + ldr r2, [sp, #-40] + add r2, r2, #32 // dst ptr + stride + str r2, [sp, #-40] + b NoDstStep + WinoDstStep: + ldr r2, [sp, #-40] + add r2, r2, r10 + str r2, [sp, #-40] + NoDstStep: + cmp r7, #8 + ble LoopColEnd + sub r7, r7, #8 // rhs col - 8 + b LoopCol + +LoopColEnd: + sub sp, sp, #48 + pop {r0-r8, r10, r11, pc} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/MatmulFp32Opt.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/MatmulFp32Opt.S new file mode 100644 index 0000000000000000000000000000000000000000..01cdd2d6b0f192008bb4c15a080d3510959ce43c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/MatmulFp32Opt.S @@ -0,0 +1,422 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void MatmulFloatNeon32Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth +// int row, int col, size_t stride, size_t writeMode) +// r0: a +// r1: b +// r2: c +// r3: bias +// r4: act_type +// r5: depth +// r6: row +// r7: col +// r8: stride +// lr: writeNhwc/writeWino + +asm_function MatmulFloatNeon32Opt + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r0-r8, r10, r11, lr} + add sp, sp, #48 + + ldr r5, [sp, #4] + ldr r6, [sp, #8] + ldr r7, [sp, #12] + ldr r8, [sp, #16] + + mov lr, #16 // sizeof(float) * 4 + mul r12, r5, lr // block stride of lhs/rhs: sizeof(float) * 4 * depth + ldr lr, [sp, #20] + cmp lr, #0 + bne NoC8Steps + mov lr, #32 + mul r10, r6, lr +NoC8Steps: + cmp lr, #2 + bne NoWinoSteps + mov lr, #4 + mul r11, r7, r8 // stride * col * sizeof(float) + mul r11, r11, lr + mov lr, #32 + mul r10, r8, lr // stride * 8 * sizeof(float) +NoWinoSteps: + mov lr, #4 + mul r8, r8, lr // stride * sizeof(float) + +LoopRow: + ldr r1, [sp, #-44] // reload rhs ptr + ldr r7, [sp, #12] // reload rhs col + ldr r3, [sp, #-36] // reload bias ptr + + LoopCol: + ldr lr, [sp, #20] + cmp lr, #0 + beq NoReloadDst + ldr r2, [sp, #-40] // reload dst ptr + NoReloadDst: + ldr r0, [sp, #-48] // reload lhs ptr + ldr r5, [sp, #4] // reload depth + vld1.32 {q0}, [r0]! + vld1.32 {q1, q2}, [r1]! + vmul.f32 q8, q1, d0[0] + vmul.f32 q9, q2, d0[0] + vmul.f32 q10, q1, d0[1] + vmul.f32 q11, q2, d0[1] + vmul.f32 q12, q1, d1[0] + vmul.f32 q13, q2, d1[0] + vmul.f32 q14, q1, d1[1] + vmul.f32 q15, q2, d1[1] + + subs r5, r5, #1 + beq Bias + + LoopDepth: + vld1.32 {q0}, [r0]! + vld1.32 {q1, q2}, [r1]! + vmla.f32 q8, q1, d0[0] + vmla.f32 q9, q2, d0[0] + vmla.f32 q10, q1, d0[1] + vmla.f32 q11, q2, d0[1] + vmla.f32 q12, q1, d1[0] + vmla.f32 q13, q2, d1[0] + vmla.f32 q14, q1, d1[1] + vmla.f32 q15, q2, d1[1] + + subs r5, r5, #1 + bne LoopDepth + + Bias: + cmp r3, #0 + beq Activation + vld1.32 {q0}, [r3]! + vld1.32 {q1}, [r3]! + vadd.f32 q8, q8, q0 + vadd.f32 q9, q9, q1 + vadd.f32 q10, q10, q0 + vadd.f32 q11, q11, q1 + vadd.f32 q12, q12, q0 + vadd.f32 q13, q13, q1 + vadd.f32 q14, q14, q0 + vadd.f32 q15, q15, q1 + + Activation: + ldr lr, [sp] + cmp lr, #3 + beq Relu6 + cmp lr, #1 + beq Relu + b Write + + Relu6: + vmov.i32 q2, #6 + vcvt.f32.s32 q2, q2 + vmin.f32 q8, q8, q2 + vmin.f32 q9, q9, q2 + vmin.f32 q10, q10, q2 + vmin.f32 q11, q11, q2 + vmin.f32 q12, q12, q2 + vmin.f32 q13, q13, q2 + vmin.f32 q14, q14, q2 + vmin.f32 q15, q15, q2 + + Relu: + veor q3, q3, q3 + vmax.f32 q8, q8, q3 + vmax.f32 q9, q9, q3 + vmax.f32 q10, q10, q3 + vmax.f32 q11, q11, q3 + vmax.f32 q12, q12, q3 + vmax.f32 q13, q13, q3 + vmax.f32 q14, q14, q3 + vmax.f32 q15, q15, q3 + + Write: + ldr lr, [sp, #20] + cmp lr, #2 + beq WriteWino + cmp lr, #0 + beq WriteC8 + cmp r7, #1 + beq Write1 + cmp r7, #2 + beq Write2 + cmp r7, #3 + beq Write3 + cmp r7, #4 + beq Write4 + cmp r7, #5 + beq Write5 + cmp r7, #6 + beq Write6 + cmp r7, #7 + beq Write7 + b Write8 + + Write1: + add lr, r2, #4 + str lr, [sp, #-40] + vst1.32 d16[0], [r2] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + vst1.32 d20[0], [r2] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + vst1.32 d24[0], [r2] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + vst1.32 d28[0], [r2] + add r2, r2, r8 + add r2, r2, #4 + b WriteEnd + Write2: + add lr, r2, #8 + str lr, [sp, #-40] + vst1.32 d16, [r2] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + vst1.32 d20, [r2] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + vst1.32 d24, [r2] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + vst1.32 d28, [r2] + add r2, r2, r8 + add r2, r2, #8 + b WriteEnd + Write3: + add lr, r2, #12 + str lr, [sp, #-40] + add r4, r2, #8 + vst1.32 d16, [r2] + vst1.32 d17[0], [r4] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d20, [r2] + vst1.32 d21[0], [r4] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d24, [r2] + vst1.32 d25[0], [r4] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d28, [r2] + vst1.32 d29[0], [r4] + add r2, r2, r8 + add r2, r2, #12 + b WriteEnd + Write4: + add lr, r2, #16 + str lr, [sp, #-40] + vst1.32 {d16, d17}, [r2] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d20, d21}, [r2] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d24, d25}, [r2] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d28, d29}, [r2] + add r2, r2, r8 + add r2, r2, #16 + b WriteEnd + Write5: + add lr, r2, #20 + str lr, [sp, #-40] + add r4, r2, #16 + vst1.32 {d16, d17}, [r2] + vst1.32 d18[0], [r4] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 {d20, d21}, [r2] + vst1.32 d22[0], [r4] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 {d24, d25}, [r2] + vst1.32 d26[0], [r4] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 {d28, d29}, [r2] + vst1.32 d30[0], [r4] + add r2, r2, r8 + add r2, r2, #20 + b WriteEnd + Write6: + add lr, r2, #24 + str lr, [sp, #-40] + add r4, r2, #16 + vst1.32 {d16, d17}, [r2] + vst1.32 d18, [r4] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 {d20, d21}, [r2] + vst1.32 d22, [r4] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 {d24, d25}, [r2] + vst1.32 d26, [r4] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 {d28, d29}, [r2] + vst1.32 d30, [r4] + add r2, r2, r8 + add r2, r2, #24 + b WriteEnd + Write7: + add lr, r2, #28 + str lr, [sp, #-40] + add lr, r2, #24 + add r4, r2, #16 + vst1.32 {d16, d17}, [r2] + vst1.32 d18, [r4] + vst1.32 d19[0], [lr] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + add lr, lr, r8 + vst1.32 {d20, d21}, [r2] + vst1.32 d22, [r4] + vst1.32 d23[0], [lr] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + add lr, lr, r8 + vst1.32 {d24, d25}, [r2] + vst1.32 d26, [r4] + vst1.32 d27[0], [lr] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + add lr, lr, r8 + vst1.32 {d28, d29}, [r2] + vst1.32 d30, [r4] + vst1.32 d31[0], [lr] + add r2, r2, r8 + add r2, r2, #28 + b WriteEnd + WriteC8: + mov lr, r2 + vst1.32 {q8, q9}, [lr]! + vst1.32 {q10, q11}, [lr]! + vst1.32 {q12, q13}, [lr]! + vst1.32 {q14, q15}, [lr]! + add r2, r2, r10 + b WriteEnd + WriteWino: + add lr, r2, r10 + vst1.32 {q8, q9}, [r2] + add r2, r2, r11 + vst1.32 {q10, q11}, [r2] + add r2, r2, r11 + vst1.32 {q12, q13}, [r2] + add r2, r2, r11 + vst1.32 {q14, q15}, [r2] + str lr, [sp, #-40] + b WriteEnd + Write8: + add lr, r2, #32 + str lr, [sp, #-40] + vst1.32 {q8, q9}, [r2] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + vst1.32 {q10, q11}, [r2] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + vst1.32 {q12, q13}, [r2] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + vst1.32 {q14, q15}, [r2] + add r2, r2, r8 + add r2, r2, #32 + + WriteEnd: + cmp r7, #8 + ble LoopColEnd + sub r7, r7, #8 // rhs col - 8 + b LoopCol + + LoopColEnd: + ldr r0, [sp, #-48] + add r0, r0, r12 // rhs ptr + stride + str r0, [sp, #-48] + ldr lr, [sp, #20] + cmp lr, #0 + beq C8DstStep + cmp lr, #2 + beq WinoDstStep + mov lr, #4 + ldr r7, [sp, #12] // reload rhs col + mul lr, lr, r7 + sub r2, r2, lr + str r2, [sp, #-40] + b NoDstStep + C8DstStep: + ldr lr, [sp, #-40] + add r2, lr, #128 + str r2, [sp, #-40] + b NoDstStep + WinoDstStep: + add r2, r2, r10 + str r2, [sp, #-40] + NoDstStep: + cmp r6, #4 + ble LoopRowEnd + sub r6, r6, #4 // lhs row - 4 + b LoopRow + +LoopRowEnd: + sub sp, sp, #48 + pop {r0-r8, r10, r11, pc} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/MatmulFp32Opt12x4.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/MatmulFp32Opt12x4.S new file mode 100644 index 0000000000000000000000000000000000000000..7a2035b7d13cbbd61c1833801e3f9938c9e18559 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/MatmulFp32Opt12x4.S @@ -0,0 +1,578 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void MatmulFloatNeon32Opt12x4(const float *a, const float *b, float *c, const float *bias, int act_type, int depth +// int row, int col, size_t stride, size_t writeMode) +// r0: a +// r1: b +// r2: c +// r3: bias +// r4: act_type +// r5: depth +// r6: row +// r7: col +// r8: stride +// lr: OutType_C8 = 0, OutType_Nhwc = 1, OutType_TileC8 = 2 + +asm_function MatmulFloatNeon32Opt12x4 + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r0-r8, r10, r11, lr} + vpush {q4-q7} + add sp, sp, #112 + + ldr r5, [sp, #4] + ldr r6, [sp, #8] + ldr r7, [sp, #12] + ldr r8, [sp, #16] + + mov lr, #48 // sizeof(float) * 12 + mul r12, r5, lr // block stride of lhs: sizeof(float) * 12 * depth + mov lr, #4 + mul r8, r8, lr // stride * sizeof(float) + +LoopRowStart: + cmp r6, #4 + ble LoopRow4 + cmp r6, #8 + ble LoopRow8 + +LoopRow: + ldr r1, [sp, #-44] // reload rhs ptr + ldr r7, [sp, #12] // reload rhs col + ldr r3, [sp, #-36] // reload bias ptr + + LoopCol: + ldr r2, [sp, #-40] // reload dst ptr + ldr r0, [sp, #-48] // reload lhs ptr + ldr r5, [sp, #4] // reload depth + vld1.32 {q3}, [r1]! + vld1.32 {q0, q1}, [r0]! + vmul.f32 q4, q3, d0[0] + vmul.f32 q5, q3, d0[1] + vmul.f32 q6, q3, d1[0] + vld1.32 {q2}, [r0]! + vmul.f32 q7, q3, d1[1] + + vmul.f32 q8, q3, d2[0] + vmul.f32 q9, q3, d2[1] + vmul.f32 q10, q3, d3[0] + vmul.f32 q11, q3, d3[1] + + vmul.f32 q12, q3, d4[0] + vmul.f32 q13, q3, d4[1] + vmul.f32 q14, q3, d5[0] + vmul.f32 q15, q3, d5[1] + + subs r5, r5, #1 + beq Bias + + LoopDepth: + vld1.32 {q3}, [r1]! + vld1.32 {q0, q1}, [r0]! + vmla.f32 q4, q3, d0[0] + vmla.f32 q5, q3, d0[1] + vmla.f32 q6, q3, d1[0] + vld1.32 {q2}, [r0]! + vmla.f32 q7, q3, d1[1] + + vmla.f32 q8, q3, d2[0] + vmla.f32 q9, q3, d2[1] + vmla.f32 q10, q3, d3[0] + vmla.f32 q11, q3, d3[1] + + vmla.f32 q12, q3, d4[0] + vmla.f32 q13, q3, d4[1] + vmla.f32 q14, q3, d5[0] + vmla.f32 q15, q3, d5[1] + + subs r5, r5, #1 + bne LoopDepth + + Bias: + cmp r3, #0 + beq Activation + vld1.32 {q0}, [r3]! + vadd.f32 q4, q4, q0 + vadd.f32 q5, q5, q0 + vadd.f32 q6, q6, q0 + vadd.f32 q7, q7, q0 + vadd.f32 q8, q8, q0 + vadd.f32 q9, q9, q0 + vadd.f32 q10, q10, q0 + vadd.f32 q11, q11, q0 + vadd.f32 q12, q12, q0 + vadd.f32 q13, q13, q0 + vadd.f32 q14, q14, q0 + vadd.f32 q15, q15, q0 + + Activation: + ldr lr, [sp] + cmp lr, #3 + beq Relu6 + cmp lr, #1 + beq Relu + b Write + + Relu6: + vmov.i32 q2, #6 + vcvt.f32.s32 q2, q2 + vmin.f32 q4, q4, q2 + vmin.f32 q5, q5, q2 + vmin.f32 q6, q6, q2 + vmin.f32 q7, q7, q2 + vmin.f32 q8, q8, q2 + vmin.f32 q9, q9, q2 + vmin.f32 q10, q10, q2 + vmin.f32 q11, q11, q2 + vmin.f32 q12, q12, q2 + vmin.f32 q13, q13, q2 + vmin.f32 q14, q14, q2 + vmin.f32 q15, q15, q2 + + Relu: + veor q3, q3, q3 + vmax.f32 q4, q4, q3 + vmax.f32 q5, q5, q3 + vmax.f32 q6, q6, q3 + vmax.f32 q7, q7, q3 + vmax.f32 q8, q8, q3 + vmax.f32 q9, q9, q3 + vmax.f32 q10, q10, q3 + vmax.f32 q11, q11, q3 + vmax.f32 q12, q12, q3 + vmax.f32 q13, q13, q3 + vmax.f32 q14, q14, q3 + vmax.f32 q15, q15, q3 + b Write + +LoopRow8: + ldr r1, [sp, #-44] // reload rhs ptr + ldr r7, [sp, #12] // reload rhs col + ldr r3, [sp, #-36] // reload bias ptr + + LoopCol_R8: + ldr r2, [sp, #-40] // reload dst ptr + ldr r0, [sp, #-48] // reload lhs ptr + ldr r5, [sp, #4] // reload depth + vld1.32 {q3}, [r1]! + vld1.32 {q0, q1}, [r0]! + vmul.f32 q4, q3, d0[0] + vmul.f32 q5, q3, d0[1] + vmul.f32 q6, q3, d1[0] + vld1.32 {q2}, [r0]! + vmul.f32 q7, q3, d1[1] + + vmul.f32 q8, q3, d2[0] + vmul.f32 q9, q3, d2[1] + vmul.f32 q10, q3, d3[0] + vmul.f32 q11, q3, d3[1] + + subs r5, r5, #1 + beq Bias_R8 + + LoopDepth_R8: + vld1.32 {q3}, [r1]! + vld1.32 {q0, q1}, [r0]! + vmla.f32 q4, q3, d0[0] + vmla.f32 q5, q3, d0[1] + vmla.f32 q6, q3, d1[0] + vld1.32 {q2}, [r0]! + vmla.f32 q7, q3, d1[1] + + vmla.f32 q8, q3, d2[0] + vmla.f32 q9, q3, d2[1] + vmla.f32 q10, q3, d3[0] + vmla.f32 q11, q3, d3[1] + + subs r5, r5, #1 + bne LoopDepth_R8 + + Bias_R8: + cmp r3, #0 + beq Activation_R8 + vld1.32 {q0}, [r3]! + vadd.f32 q4, q4, q0 + vadd.f32 q5, q5, q0 + vadd.f32 q6, q6, q0 + vadd.f32 q7, q7, q0 + vadd.f32 q8, q8, q0 + vadd.f32 q9, q9, q0 + vadd.f32 q10, q10, q0 + vadd.f32 q11, q11, q0 + + Activation_R8: + ldr lr, [sp] + cmp lr, #3 + beq Relu6_R8 + cmp lr, #1 + beq Relu_R8 + b Write + + Relu6_R8: + vmov.i32 q2, #6 + vcvt.f32.s32 q2, q2 + vmin.f32 q4, q4, q2 + vmin.f32 q5, q5, q2 + vmin.f32 q6, q6, q2 + vmin.f32 q7, q7, q2 + vmin.f32 q8, q8, q2 + vmin.f32 q9, q9, q2 + vmin.f32 q10, q10, q2 + vmin.f32 q11, q11, q2 + + Relu_R8: + veor q3, q3, q3 + vmax.f32 q4, q4, q3 + vmax.f32 q5, q5, q3 + vmax.f32 q6, q6, q3 + vmax.f32 q7, q7, q3 + vmax.f32 q8, q8, q3 + vmax.f32 q9, q9, q3 + vmax.f32 q10, q10, q3 + vmax.f32 q11, q11, q3 + b Write + +LoopRow4: + ldr r1, [sp, #-44] // reload rhs ptr + ldr r7, [sp, #12] // reload rhs col + ldr r3, [sp, #-36] // reload bias ptr + + LoopCol_R4: + ldr r2, [sp, #-40] // reload dst ptr + ldr r0, [sp, #-48] // reload lhs ptr + ldr r5, [sp, #4] // reload depth + vld1.32 {q3}, [r1]! + vld1.32 {q0, q1}, [r0]! + vmul.f32 q4, q3, d0[0] + vmul.f32 q5, q3, d0[1] + vmul.f32 q6, q3, d1[0] + vld1.32 {q2}, [r0]! + vmul.f32 q7, q3, d1[1] + + subs r5, r5, #1 + beq Bias_R4 + + LoopDepth_R4: + vld1.32 {q3}, [r1]! + vld1.32 {q0, q1}, [r0]! + vmla.f32 q4, q3, d0[0] + vmla.f32 q5, q3, d0[1] + vmla.f32 q6, q3, d1[0] + vld1.32 {q2}, [r0]! + vmla.f32 q7, q3, d1[1] + + subs r5, r5, #1 + bne LoopDepth_R4 + + Bias_R4: + cmp r3, #0 + beq Activation_R4 + vld1.32 {q0}, [r3]! + vadd.f32 q4, q4, q0 + vadd.f32 q5, q5, q0 + vadd.f32 q6, q6, q0 + vadd.f32 q7, q7, q0 + + Activation_R4: + ldr lr, [sp] + cmp lr, #3 + beq Relu6_R4 + cmp lr, #1 + beq Relu_R4 + b Write + + Relu6_R4: + vmov.i32 q2, #6 + vcvt.f32.s32 q2, q2 + vmin.f32 q4, q4, q2 + vmin.f32 q5, q5, q2 + vmin.f32 q6, q6, q2 + vmin.f32 q7, q7, q2 + + Relu_R4: + veor q3, q3, q3 + vmax.f32 q4, q4, q3 + vmax.f32 q5, q5, q3 + vmax.f32 q6, q6, q3 + vmax.f32 q7, q7, q3 + + Write: + cmp r7, #1 + beq Write1 + cmp r7, #2 + beq Write2 + cmp r7, #3 + beq Write3 + b Write4 + + Write1: + add lr, r2, #4 + str lr, [sp, #-40] + vst1.32 d8[0], [r2] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + vst1.32 d10[0], [r2] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + vst1.32 d12[0], [r2] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + vst1.32 d14[0], [r2] + cmp r6, #4 + beq WriteEnd + add r2, r2, r8 + vst1.32 d16[0], [r2] + cmp r6, #5 + beq WriteEnd + add r2, r2, r8 + vst1.32 d18[0], [r2] + cmp r6, #6 + beq WriteEnd + add r2, r2, r8 + vst1.32 d20[0], [r2] + cmp r6, #7 + beq WriteEnd + add r2, r2, r8 + vst1.32 d22[0], [r2] + cmp r6, #8 + beq WriteEnd + add r2, r2, r8 + vst1.32 d24[0], [r2] + cmp r6, #9 + beq WriteEnd + add r2, r2, r8 + vst1.32 d26[0], [r2] + cmp r6, #10 + beq WriteEnd + add r2, r2, r8 + vst1.32 d28[0], [r2] + cmp r6, #11 + beq WriteEnd + add r2, r2, r8 + vst1.32 d30[0], [r2] + add r2, r2, r8 + add r2, r2, #4 + b WriteEnd + Write2: + add lr, r2, #8 + str lr, [sp, #-40] + vst1.32 d8, [r2] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + vst1.32 d10, [r2] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + vst1.32 d12, [r2] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + vst1.32 d14, [r2] + cmp r6, #4 + beq WriteEnd + add r2, r2, r8 + vst1.32 d16, [r2] + cmp r6, #5 + beq WriteEnd + add r2, r2, r8 + vst1.32 d18, [r2] + cmp r6, #6 + beq WriteEnd + add r2, r2, r8 + vst1.32 d20, [r2] + cmp r6, #7 + beq WriteEnd + add r2, r2, r8 + vst1.32 d22, [r2] + cmp r6, #8 + beq WriteEnd + add r2, r2, r8 + vst1.32 d24, [r2] + cmp r6, #9 + beq WriteEnd + add r2, r2, r8 + vst1.32 d26, [r2] + cmp r6, #10 + beq WriteEnd + add r2, r2, r8 + vst1.32 d28, [r2] + cmp r6, #11 + beq WriteEnd + add r2, r2, r8 + vst1.32 d30, [r2] + add r2, r2, r8 + add r2, r2, #8 + b WriteEnd + Write3: + add lr, r2, #12 + str lr, [sp, #-40] + add r4, r2, #8 + vst1.32 d8, [r2] + vst1.32 d9[0], [r4] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d10, [r2] + vst1.32 d11[0], [r4] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d12, [r2] + vst1.32 d13[0], [r4] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d14, [r2] + vst1.32 d15[0], [r4] + cmp r6, #4 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d16, [r2] + vst1.32 d17[0], [r4] + cmp r6, #5 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d18, [r2] + vst1.32 d19[0], [r4] + cmp r6, #6 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d20, [r2] + vst1.32 d21[0], [r4] + cmp r6, #7 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d22, [r2] + vst1.32 d23[0], [r4] + cmp r6, #8 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d24, [r2] + vst1.32 d25[0], [r4] + cmp r6, #9 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d26, [r2] + vst1.32 d27[0], [r4] + cmp r6, #10 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d28, [r2] + vst1.32 d29[0], [r4] + cmp r6, #11 + beq WriteEnd + add r2, r2, r8 + add r4, r4, r8 + vst1.32 d30, [r2] + vst1.32 d31[0], [r4] + add r2, r2, r8 + add r2, r2, #12 + b WriteEnd + Write4: + add lr, r2, #16 + str lr, [sp, #-40] + vst1.32 {d8, d9}, [r2] + cmp r6, #1 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d10, d11}, [r2] + cmp r6, #2 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d12, d13}, [r2] + cmp r6, #3 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d14, d15}, [r2] + cmp r6, #4 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d16, d17}, [r2] + cmp r6, #5 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d18, d19}, [r2] + cmp r6, #6 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d20, d21}, [r2] + cmp r6, #7 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d22, d23}, [r2] + cmp r6, #8 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d24, d25}, [r2] + cmp r6, #9 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d26, d27}, [r2] + cmp r6, #10 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d28, d29}, [r2] + cmp r6, #11 + beq WriteEnd + add r2, r2, r8 + vst1.32 {d30, d31}, [r2] + add r2, r2, r8 + add r2, r2, #16 + b WriteEnd + WriteEnd: + cmp r7, #4 + ble LoopColEnd + sub r7, r7, #4 // rhs col - 4 + b LoopCol + + LoopColEnd: + ldr r0, [sp, #-48] + add r0, r0, r12 // lhs ptr + stride + str r0, [sp, #-48] + mov lr, #4 + ldr r7, [sp, #12] // reload rhs col + mul lr, lr, r7 + sub r2, r2, lr + str r2, [sp, #-40] + cmp r6, #12 + ble LoopRowEnd + sub r6, r6, #12 // lhs row - 12 + b LoopRowStart + +LoopRowEnd: + sub sp, sp, #112 + vpop {q4-q7} + pop {r0-r8, r10, r11, pc} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/MatmulInt8.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/MatmulInt8.S new file mode 100644 index 0000000000000000000000000000000000000000..cac044136968596fc7cc5fc51d8dc91015a72881 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/MatmulInt8.S @@ -0,0 +1,298 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +//void MatmulInt8Neon32(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, +// const int *input_sums, const int *weight_bias, int act_min, int act_max, int out_zp, +// int *multiplier, int *left_shift, int *right_shift, int stride, int per_channel); +// #-52: a, #-48: b, #-44: dst, #-40: row +// #0: col, #4: deep16, #8: input_sums, #12: weight_bias, #16: act_min, #20: act_max, #24: out_zp +// #28: multiplier, #32: left_shift, #36: right_shift, #40: stride, #44: per_channel + +asm_function MatmulInt8Neon32 + push {r0-r11, lr} + vpush {q4-q7} + add sp, sp, #116 + + ldr r4, [sp] // col + ldr r7, [sp, #40] // output stride + mov r8, #0 // output channels offset + ldr r10, [sp, #44] + cmp r10, #0 + beq L1 + ldr r6, [sp, #8] // load intpu_sums ptr if per_channel +L1: + cmp r4, #0 // if at the end of col + ble End1 + + ldr r0, [sp, #-52] // reload a ptr + ldr r3, [sp, #-40] // reset row counter + ldr r10, [sp, #44] + cmp r10, #0 + bne L2 + ldr r6, [sp, #8] // reload intpu_sums ptr if per_tensor +L2: + cmp r3, #0 // if at the end of row + ble End2 + + ldr r1, [sp, #-48] // reload b ptr + ldr r5, [sp, #4] // reset deep16 + vmov.i32 q6, #0 + vmov.i32 q7, #0 + vmov.i32 q8, #0 + vmov.i32 q9, #0 + vmov.i32 q10, #0 + vmov.i32 q11, #0 + vmov.i32 q12, #0 + vmov.i32 q13, #0 +L3: + cmp r5, #0 + beq End3 + + vld1.8 {d0, d1, d2, d3}, [r0]! + vld1.8 {d8, d9, d10, d11}, [r1]! + vmull.s8 q14, d0, d8 + vmull.s8 q2, d0, d10 + vmull.s8 q15, d2, d8 + vmull.s8 q3, d2, d10 + vmlal.s8 q14, d1, d9 + vmlal.s8 q2, d1, d11 + vmlal.s8 q15, d3, d9 + vmlal.s8 q3, d3, d11 + + vpadal.s16 q6, q14 + vpadal.s16 q7, q2 + vpadal.s16 q8, q15 + vpadal.s16 q9, q3 + + vld1.8 {d0, d1, d2, d3}, [r0]! + vmull.s8 q14, d0, d8 + vmull.s8 q2, d0, d10 + vmull.s8 q15, d2, d8 + vmull.s8 q3, d2, d10 + vmlal.s8 q14, d1, d9 + vmlal.s8 q2, d1, d11 + vmlal.s8 q15, d3, d9 + vmlal.s8 q3, d3, d11 + + vpadal.s16 q10, q14 + vpadal.s16 q11, q2 + vpadal.s16 q12, q15 + vpadal.s16 q13, q3 + sub r5, r5, #16 // deep16 -= 16 + b L3 + +End3: + vpadd.i32 d0, d12, d13 + vpadd.i32 d1, d14, d15 + vpadd.i32 d2, d16, d17 + vpadd.i32 d3, d18, d19 + vpadd.i32 d4, d20, d21 + vpadd.i32 d5, d22, d23 + vpadd.i32 d6, d24, d25 + vpadd.i32 d7, d26, d27 + + vpadd.i32 d28, d0, d1 + vpadd.i32 d29, d2, d3 + vpadd.i32 d30, d4, d5 + vpadd.i32 d31, d6, d7 + + // Add weight_bias + ldr r9, [sp, #12] // reload weight_bias ptr + add r9, r9, r8 + vld1.32 {d26}, [r9]! + vadd.i32 d28, d28, d26 + vadd.i32 d29, d29, d26 + vadd.i32 d30, d30, d26 + vadd.i32 d31, d31, d26 + + ldr r10, [sp, #44] + cmp r10, #0 + bgt PerChannel + +PerTensor: + // Subtract input_sums + vld1.32 {d24, d25}, [r6]! + vdup.32 d20, d24[0] + vdup.32 d21, d24[1] + vdup.32 d22, d25[0] + vdup.32 d23, d25[1] + vsub.s32 d28, d28, d20 + vsub.s32 d29, d29, d21 + vsub.s32 d30, d30, d22 + vsub.s32 d31, d31, d23 + + // Apply left shift + ldr r10, [sp, #32] + ldr r11, [r10]! + vdup.32 q9, r11 + vshl.s32 q14, q14, q9 + vshl.s32 q15, q15, q9 + + // Apply the fixed-point part of the multiplier + ldr r10, [sp, #28] + ldr r11, [r10] + vdup.32 q8, r11 + vqrdmulh.s32 q14, q14, q8 + vqrdmulh.s32 q15, q15, q8 + + // Apply right shift + ldr r10, [sp, #36] + ldr r11, [r10] + vdup.32 q7, r11 + vand q6, q7, q14 + vshr.s32 q6, q6, #31 + vqadd.s32 q14, q14, q6 + vrshl.s32 q14, q14, q7 + vand q5, q7, q15 + vshr.s32 q5, q5, #31 + vqadd.s32 q15, q15, q5 + vrshl.s32 q15, q15, q7 + b AddDstZP + +PerChannel: + // Subtract input_sums + vld1.32 {d24, d25, d26, d27}, [r6]! + vsub.s32 d28, d28, d24 + vsub.s32 d29, d29, d25 + vsub.s32 d30, d30, d26 + vsub.s32 d31, d31, d27 + + // Apply left shift + ldr r10, [sp, #32] + add r10, r10, r8 + vld1.32 {d23}, [r10] + vshl.s32 d28, d28, d23 + vshl.s32 d29, d29, d23 + vshl.s32 d30, d30, d23 + vshl.s32 d31, d31, d23 + + // Apply the fixed-point part of the multiplier + ldr r10, [sp, #28] + add r10, r10, r8 + vld1.32 {d22}, [r10] + vqrdmulh.s32 d28, d28, d22 + vqrdmulh.s32 d29, d29, d22 + vqrdmulh.s32 d30, d30, d22 + vqrdmulh.s32 d31, d31, d22 + + // Apply right shift + ldr r10, [sp, #36] + add r10, r10, r8 + vld1.32 {d21}, [r10] + vand d20, d21, d28 + vshr.s32 d20, d20, #31 + vqadd.s32 d28, d28, d20 + vrshl.s32 d28, d28, d21 + vand d19, d21, d29 + vshr.s32 d19, d19, #31 + vqadd.s32 d29, d29, d19 + vrshl.s32 d29, d29, d21 + vand d18, d21, d30 + vshr.s32 d18, d18, #31 + vqadd.s32 d30, d30, d18 + vrshl.s32 d30, d30, d21 + vand d17, d21, d31 + vshr.s32 d17, d17, #31 + vqadd.s32 d31, d31, d17 + vrshl.s32 d31, d31, d21 + +AddDstZP: + // Add the destination zero point + ldr r10, [sp, #24] + vdup.32 q4, r10 + vadd.i32 q14, q14, q4 + vadd.i32 q15, q15, q4 + + // Apply the act_min bound + ldr r10, [sp, #16] + vdup.32 q3, r10 + vmax.s32 q14, q14, q3 + vmax.s32 q15, q15, q3 + + // Apply the act_max bound + ldr r10, [sp, #20] + vdup.32 q2, r10 + vmin.s32 q14, q14, q2 + vmin.s32 q15, q15, q2 + + // Cast-and-saturate from int32 to int16 + vqmovn.s32 d28, q14 + vqmovn.s32 d29, q15 + + // Cast-and-saturate from int16 to int8 + vqmovn.s16 d30, q14 + + // start to write + cmp r4, #2 + bge WriteCol2 + cmp r4, #1 + beq WriteCol1 + b EndWrite + +WriteCol2: + vst1.16 {d30[0]}, [r2], r7 + cmp r3, #1 + beq EndWrite + vst1.16 {d30[1]}, [r2], r7 + cmp r3, #2 + beq EndWrite + vst1.16 {d30[2]}, [r2], r7 + cmp r3, #3 + beq EndWrite + vst1.16 {d30[3]}, [r2], r7 + b EndWrite + +WriteCol1: + vst1.8 {d30[0]}, [r2], r7 + cmp r3, #1 + beq EndWrite + vst1.8 {d30[2]}, [r2], r7 + cmp r3, #2 + beq EndWrite + vst1.8 {d30[4]}, [r2], r7 + cmp r3, #3 + beq EndWrite + vst1.8 {d30[6]}, [r2], r7 + b EndWrite + +EndWrite: + sub r3, r3, #4 // a row counter -= 4 + b L2 + +End2: + sub r4, r4, #2 // b col counter -= 2 + ldr r1, [sp, #-48] // load b ptr + ldr r9, [sp, #4] + mov r10, #2 + mul r9, r9, r10 // the stride of b + add r1, r1, r9 // b ptr + stride + str r1, [sp, #-48] + ldr r2, [sp, #-44] // load dst ptr + add r2, r2, #2 // dst ptr + offset + str r2, [sp, #-44] + add r8, r8, #8 // output channels offset + 2*sizeof(int) + b L1 + +End1: + sub sp, sp, #116 + vpop {q4-q7} + pop {r0-r11, pc} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/MatmulInt8Opt.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/MatmulInt8Opt.S new file mode 100644 index 0000000000000000000000000000000000000000..02a4b1ba042c5de186c9867227c1d4888031c48b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/MatmulInt8Opt.S @@ -0,0 +1,300 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +//void MatmulInt8Neon32Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, +// const int *input_sums, const int *weight_bias, int act_min, int act_max, int out_zp, +// int *multiplier, int *left_shift, int *right_shift, int stride, int per_channel, +// int *filter_zp); +// #-48: a, #-44: b, #-40: dst, #-36: row +// #0: col, #4: deep16, #8: input_sums, #12: weight_bias, #16: act_min, #20: act_max, #24: out_zp +// #28: multiplier, #32: left_shift, #36: right_shift, #40: stride, #44: per_channel, #48: filter_zp + +asm_function MatmulInt8Opt + push {r0-r8, r10, r11, lr} + vpush {q4-q7} + add sp, sp, #112 + + ldr r5, [sp, #4] + ldr r6, [sp, #8] // reload a_sums ptr + ldr r8, [sp, #40] + mov r10, #4 + mul r10, r10, r5 // lhs step + mov r11, #4 + mul r11, r11, r8 // dst step +LoopRow: + ldr r1, [sp, #-44] //reload rhs ptr + ldr r4, [sp] // reload rhs col + ldr lr, [sp, #-40] + vmov.32 d4[0], lr // reload dst ptr + ldr lr, [sp, #32] + vmov.32 d4[1], lr // reload left shift + ldr lr, [sp, #28] + vmov.32 d5[0], lr // reload multiplier + ldr lr, [sp, #36] + vmov.32 d5[1], lr // reload right shift + ldr r7, [sp, #48] // reload filter_zp + ldr r12, [sp, #12] // reload bias ptr + + LoopCol: + vmov.32 r2, d4[0] // reload dst ptr + ldr r0, [sp, #-48] //reload lhs ptr + ldr r5, [sp, #4] // reaload depth + + vmov.i32 q6, #0 + vmov.i32 q7, #0 + vmov.i32 q8, #0 + vmov.i32 q9, #0 + vmov.i32 q10, #0 + vmov.i32 q11, #0 + vmov.i32 q12, #0 + vmov.i32 q13, #0 + + LoopDepth: + vld1.8 {q0-q1}, [r0]! + vld1.8 {q4-q5}, [r1]! + vmull.s8 q14, d0, d8 + vmull.s8 q15, d2, d8 + vmlal.s8 q14, d1, d9 + vmlal.s8 q15, d3, d9 + vpadal.s16 q6, q14 + vpadal.s16 q8, q15 + vmull.s8 q14, d0, d10 + vmull.s8 q15, d2, d10 + vmlal.s8 q14, d1, d11 + vmlal.s8 q15, d3, d11 + vld1.8 {q0-q1}, [r0]! + + vpadal.s16 q7, q14 + vpadal.s16 q9, q15 + + vmull.s8 q14, d0, d8 + vmull.s8 q15, d2, d8 + vmlal.s8 q14, d1, d9 + vmlal.s8 q15, d3, d9 + vpadal.s16 q10, q14 + vpadal.s16 q12, q15 + vmull.s8 q14, d0, d10 + vmull.s8 q15, d2, d10 + vmlal.s8 q14, d1, d11 + vmlal.s8 q15, d3, d11 + + vpadal.s16 q11, q14 + vpadal.s16 q13, q15 + + cmp r5, #16 + ble LoopDepthEnd + sub r5, r5, #16 + b LoopDepth + + LoopDepthEnd: + vpadd.i32 d12, d12, d13 + vpadd.i32 d14, d14, d15 + vpadd.i32 d16, d16, d17 + vpadd.i32 d18, d18, d19 + vpadd.i32 d20, d20, d21 + vpadd.i32 d22, d22, d23 + vpadd.i32 d24, d24, d25 + vpadd.i32 d26, d26, d27 + + vpadd.i32 d28, d12, d14 + vpadd.i32 d29, d16, d18 + vpadd.i32 d30, d20, d22 + vpadd.i32 d31, d24, d26 + + Bias: + cmp r12, #0 + beq NoBias + vld1.32 {d26}, [r12]! + vadd.i32 d28, d28, d26 + vadd.i32 d29, d29, d26 + vadd.i32 d30, d30, d26 + vadd.i32 d31, d31, d26 + + NoBias: + ldr lr, [sp, #44] + cmp lr, #0 + bne PerChannel + + PerTensor: + vld1.32 {d24, d25}, [r6] + vdup.32 d20, d24[0] + vdup.32 d21, d24[1] + vdup.32 d22, d25[0] + vdup.32 d23, d25[1] + vsub.s32 d28, d28, d20 + vsub.s32 d29, d29, d21 + vsub.s32 d30, d30, d22 + vsub.s32 d31, d31, d23 + + vmov.32 lr, d4[1] + vld1.32 {d18[], d19[]}, [lr] + vshl.s32 q14, q14, q9 + vshl.s32 q15, q15, q9 + + vmov.32 lr, d5[0] + vld1.32 {d16[], d17[]}, [lr] + vqrdmulh.s32 q14, q14, q8 + vqrdmulh.s32 q15, q15, q8 + + vmov.32 lr, d5[1] + vld1.32 {d14[], d15[]}, [lr] + vand q6, q7, q14 + vshr.s32 q6, q6, #31 + vqadd.s32 q14, q14, q6 + vrshl.s32 q14, q14, q7 + vand q5, q7, q15 + vshr.s32 q5, q5, #31 + vqadd.s32 q15, q15, q5 + vrshl.s32 q15, q15, q7 + b Quantize + + PerChannel: + vld1.32 {d24, d25}, [r6] + vdup.32 d20, d24[0] + vdup.32 d21, d24[1] + vdup.32 d22, d25[0] + vdup.32 d23, d25[1] + vld1.32 {d19}, [r7]! + vmul.s32 d24, d20, d19 + vmul.s32 d25, d21, d19 + vmul.s32 d26, d22, d19 + vmul.s32 d27, d23, d19 + vsub.s32 d28, d28, d24 + vsub.s32 d29, d29, d25 + vsub.s32 d30, d30, d26 + vsub.s32 d31, d31, d27 + + vmov.32 lr, d4[1] + vld1.32 {d23}, [lr]! + vmov.32 d4[1], lr + vshl.s32 d28, d28, d23 + vshl.s32 d29, d29, d23 + vshl.s32 d30, d30, d23 + vshl.s32 d31, d31, d23 + + vmov.32 lr, d5[0] + vld1.32 {d22}, [lr]! + vmov.32 d5[0], lr + vqrdmulh.s32 d28, d28, d22 + vqrdmulh.s32 d29, d29, d22 + vqrdmulh.s32 d30, d30, d22 + vqrdmulh.s32 d31, d31, d22 + + vmov.32 lr, d5[1] + vld1.32 {d21}, [lr]! + vmov.32 d5[1], lr + vand d20, d21, d28 + vshr.s32 d20, d20, #31 + vqadd.s32 d28, d28, d20 + vrshl.s32 d28, d28, d21 + vand d19, d21, d29 + vshr.s32 d19, d19, #31 + vqadd.s32 d29, d29, d19 + vrshl.s32 d29, d29, d21 + vand d18, d21, d30 + vshr.s32 d18, d18, #31 + vqadd.s32 d30, d30, d18 + vrshl.s32 d30, d30, d21 + vand d17, d21, d31 + vshr.s32 d17, d17, #31 + vqadd.s32 d31, d31, d17 + vrshl.s32 d31, d31, d21 + + Quantize: + ldr lr, [sp, #24] + vdup.32 q0, lr + vadd.i32 q14, q14, q0 + vadd.i32 q15, q15, q0 + + ldr lr, [sp, #16] + vdup.32 q1, lr + vmax.s32 q14, q14, q1 + vmax.s32 q15, q15, q1 + + ldr lr, [sp, #20] + vdup.32 q0, lr + vmin.s32 q14, q14, q0 + vmin.s32 q15, q15, q0 + + vqmovn.s32 d28, q14 + vqmovn.s32 d29, q15 + + vqmovn.s16 d30, q14 + + cmp r4, #1 + beq Write1 + b Write2 + + Write1: + vmov.32 lr, d4[0] + add lr, lr, #1 + vmov.32 d4[0], lr + vst1.8 {d30[0]}, [r2], r8 + cmp r3, #1 + beq WriteEnd + vst1.8 {d30[2]}, [r2], r8 + cmp r3, #2 + beq WriteEnd + vst1.8 {d30[4]}, [r2], r8 + cmp r3, #3 + beq WriteEnd + vst1.8 {d30[6]}, [r2], r8 + b WriteEnd + + Write2: + vmov.32 lr, d4[0] + add lr, lr, #2 + vmov.32 d4[0], lr + vst1.16 {d30[0]}, [r2], r8 + cmp r3, #1 + beq WriteEnd + vst1.16 {d30[1]}, [r2], r8 + cmp r3, #2 + beq WriteEnd + vst1.16 {d30[2]}, [r2], r8 + cmp r3, #3 + beq WriteEnd + vst1.16 {d30[3]}, [r2], r8 + + WriteEnd: + cmp r4, #2 + ble LoopColEnd + sub r4, r4, #2 + b LoopCol + +LoopColEnd: + cmp r3, #4 + ble LoopRowEnd + ldr lr, [sp, #-48] + add lr, lr, r10 + str lr, [sp, #-48] + ldr lr, [sp, #-40] + add lr, lr, r11 + str lr, [sp, #-40] + sub r3, r3, #4 + add r6, r6, #16 + b LoopRow + +LoopRowEnd: + sub sp, sp, #112 + vpop {q4-q7} + pop {r0-r8, r10, r11, pc} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/MatmulWinogradFp32.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/MatmulWinogradFp32.S new file mode 100644 index 0000000000000000000000000000000000000000..60a5af3f3b0d685ccee8673ae5b44ce6dc12a6b8 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/MatmulWinogradFp32.S @@ -0,0 +1,187 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// MatrixMultiplyWinograd(float *matix_a, float *matrix_b, float *matrix_c, int m, int k, int n, int in_channel, int c4_channel) + // r0: matrix_a, r1: matrix_b, r2: matrix_c, r3: m, r4: k, r5: n, r6: in_channel, r7: c4_channel * 4 + // #-56: matrix_a, #-52: matrix_b, #-48: matrix_c, #-44: m, #0: k, #4: n, #8: in_channel, #12: c4_channel * 4 +asm_function MatrixMultiplyWinograd + // at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr" + // according to https://stackoverflow.com/questions/53625807 + // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway + // clang's rule seems more simple, though there are no subroutine calls here + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r0-r12, lr} + vpush {q4-q7} + add sp, sp, #120 + + mov r0, #4 + ldr r4, [sp, #4] // n + ldr r5, [sp, #8] // in_channel + ldr r6, [sp] // k + mul r5, r5, r0 // in_channel * 4 + mul r4, r4, r0 // n * 4 + mul r6, r6, r5 // in_channel * 4 * k + + // r3 = m + // r2 = dst + LoopM: + ldr r7, [sp, #4] // n + ldr r8, [sp, #-52] // matrix_b + LoopN: + ldr r0, [sp, #4] // n + ldr r1, [sp, #-44] // m + sub r0, r0, r7 // ni + mul r0, r0, r1 // ni * m + sub r1, r1, r3 // mi + add r0, r0, r1 // ni * m + mi + ldr r1, [sp, #12] + mul r9, r0, r1 // (ni * m + mi) * c4_channel * 4 + add r11, r2, r9 // dst + offset + + ldr r10, [sp, #8] // in_channel + ldr r9, [sp, #-56] // src + cmp r10, #16 + bge LoopC16 + cmp r10, #8 + bge LoopC8 + cmp r10, #4 + bge LoopC4 + cmp r10, #1 + bge LoopC + b EndLoopC + + LoopC16: + mov r0, r8 // mat_b1 + ldr r12, [sp] // k + veor q5, q5, q5 + veor q6, q6, q6 + veor q7, q7, q7 + veor q8, q8, q8 + LoopK16: + vld1.32 {q0, q1}, [r9]! + vld1.32 {q2, q3}, [r9]! + add r9, r9, r5 + sub r9, r9, #64 + vld1.32 d8[0], [r0], r4 + vmla.f32 q5, q0, d8[0] + vmla.f32 q6, q1, d8[0] + vmla.f32 q7, q2, d8[0] + vmla.f32 q8, q3, d8[0] + subs r12, r12, #1 + bne LoopK16 + Write16: + vst1.32 {q5, q6}, [r11]! + vst1.32 {q7, q8}, [r11]! + subs r10, r10, #16 + beq EndLoopC + sub r9, r9, r6 + add r9, r9, #64 + cmp r10, #16 + bge LoopC16 + cmp r10, #8 + bge LoopC8 + cmp r10, #4 + bge LoopC4 + cmp r10, #1 + bge LoopC + + LoopC8: + veor q5, q5, q5 + veor q6, q6, q6 + mov r0, r8 // mat_b1 + ldr r12, [sp] // k + LoopK8: + vld1.32 {q0, q1}, [r9], r5 + vld1.32 d8[0], [r0], r4 + vmla.f32 q5, q0, d8[0] + vmla.f32 q6, q1, d8[0] + subs r12, r12, #1 + bne LoopK8 + Write8: + vst1.32 {q5, q6}, [r11]! + subs r10, r10, #8 + beq EndLoopC + sub r9, r9, r6 + add r9, r9, #32 + cmp r10, #8 + bge LoopC8 + cmp r10, #4 + bge LoopC4 + cmp r10, #1 + bge LoopC + + LoopC4: + veor q5, q5, q5 + mov r0, r8 // mat_b1 + ldr r12, [sp] // k + LoopK4: + vld1.32 {q0}, [r9], r5 + vld1.32 d8[0], [r0], r4 + vmla.f32 q5, q0, d8[0] + subs r12, r12, #1 + bne LoopK4 + Write4: + vst1.32 {q5}, [r11]! + subs r10, r10, #4 + beq EndLoopC + sub r9, r9, r6 + add r9, r9, #16 + cmp r10, #4 + bge LoopC4 + cmp r10, #1 + bge LoopC + + LoopC: + veor q2, q2, q2 + mov r0, r8 // mat_b1 + ldr r12, [sp] // k + LoopK: + vld1.32 d0[0], [r9], r5 + vld1.32 d2[0], [r0], r4 + vmla.f32 s8, s0, s4 + subs r12, r12, #1 + bne LoopK + Write: + vst1.32 d4[0], [r11]! + subs r10, r10, #1 + beq EndLoopC + sub r9, r9, r6 + add r9, r9, #4 + b LoopC + + EndLoopC: + subs r7, r7, #1 + beq EndLoopN + add r8, r8, #4 + b LoopN + EndLoopN: + subs r3, r3, #1 + beq EndLoopM + ldr r0, [sp, #-56] + add r0, r0, r6 + str r0, [sp, #-56] + b LoopM + EndLoopM: + sub sp, sp, #120 + vpop {q4-q7} + pop {r0-r12, pc} +#endif + diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/PostFuncBiasReluC4.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/PostFuncBiasReluC4.S new file mode 100644 index 0000000000000000000000000000000000000000..d124d37dad22be91d2e810a0d3da5be7c88cf5a1 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/PostFuncBiasReluC4.S @@ -0,0 +1,248 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +asm_function WinogradPostFuncBiasReluC4 + push {r4-r8, r10, r11, lr} + add sp, sp, #32 + + ldr r4, [sp] + ldr r5, [sp, #4] + ldr r6, [sp, #8] + ldr r7, [sp, #12] + + vmov.i32 q14, #6 + vcvt.f32.s32 q14, q14 + veor q15, q15, q15 + + mov lr, #4 + add r12, r3, r4 + mul r12, r12, lr + + mov lr, #0 + +Loop_C4: + cmp lr, r3 + beq Loop_C1 + mov r11, #4 + mul r10, lr, r11 + add r11, r0, r10 + add lr, lr, #4 + mov r8, r5 + vld1.32 {q12}, [r2]! + +Loop_4x4: + cmp r8, #4 + blt Loop_1x4 + sub r8, r8, #4 + vld1.32 {q0-q1}, [r1]! + vld1.32 {q2-q3}, [r1]! + + vadd.f32 q0, q0, q12 + vadd.f32 q1, q1, q12 + vadd.f32 q2, q2, q12 + vadd.f32 q3, q3, q12 + + cmp r7, #3 + beq Relu6_4x4 + cmp r7, #1 + beq Relu_4x4 + b Write_4x4 +Relu6_4x4: + vmin.f32 q0, q0, q14 + vmin.f32 q1, q1, q14 + vmin.f32 q2, q2, q14 + vmin.f32 q3, q3, q14 +Relu_4x4: + vmax.f32 q0, q0, q15 + vmax.f32 q1, q1, q15 + vmax.f32 q2, q2, q15 + vmax.f32 q3, q3, q15 +Write_4x4: + vst1.32 {q0}, [r11], r12 + vst1.32 {q1}, [r11], r12 + vst1.32 {q2}, [r11], r12 + vst1.32 {q3}, [r11], r12 + b Loop_4x4 + +Loop_1x4: + cmp r7, #3 + beq Relu6_1x4 + cmp r7, #1 + beq Relu_1x4 + b Write_1x4 +Relu6_1x4: + cmp r8, #0 + beq HW_Add + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vmin.f32 q0, q0, q14 + vmax.f32 q0, q0, q15 + vst1.32 {q0}, [r11], r12 + b Relu6_1x4 +Relu_1x4: + cmp r8, #0 + beq HW_Add + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vmax.f32 q0, q0, q15 + vst1.32 {q0}, [r11], r12 + b Relu_1x4 +Write_1x4: + cmp r8, #0 + beq HW_Add + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vst1.32 {q0}, [r11], r12 + b Write_1x4 + +HW_Add: + add r1, r1, r6 + b Loop_C4 + +Loop_C1: + cmp r4, #0 + beq End + mov r8, r5 + vld1.32 {q12}, [r2]! + mov r11, #4 + mul r10, lr, r11 + add r0, r0, r10 + + cmp r4, #1 + beq Loop_C1_1 + cmp r4, #2 + beq Loop_C1_2 + cmp r4, #3 + beq Loop_C1_3 + +Loop_C1_1: + cmp r7, #3 + beq Loop_C1_1_Relu6 + cmp r7, #1 + beq Loop_C1_1_Relu + b Loop_C1_1_Write +Loop_C1_1_Relu6: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vmin.f32 q0, q0, q14 + vmax.f32 q0, q0, q15 + vst1.32 {d0[0]}, [r0], r12 + b Loop_C1_1_Relu6 +Loop_C1_1_Relu: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vmax.f32 q0, q0, q15 + vst1.32 {d0[0]}, [r0], r12 + b Loop_C1_1_Relu +Loop_C1_1_Write: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vst1.32 {d0[0]}, [r0], r12 + b Loop_C1_1_Write + +Loop_C1_2: + cmp r7, #3 + beq Loop_C1_2_Relu6 + cmp r7, #1 + beq Loop_C1_2_Relu + b Loop_C1_2_Write +Loop_C1_2_Relu6: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vmin.f32 q0, q0, q14 + vmax.f32 q0, q0, q15 + vst1.32 {d0}, [r0], r12 + b Loop_C1_2_Relu6 +Loop_C1_2_Relu: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vmax.f32 q0, q0, q15 + vst1.32 {d0}, [r0], r12 + b Loop_C1_2_Relu +Loop_C1_2_Write: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vst1.32 {d0}, [r0], r12 + b Loop_C1_2_Write + +Loop_C1_3: + add r11, r0, #8 + cmp r7, #3 + beq Loop_C1_3_Relu6 + cmp r7, #1 + beq Loop_C1_3_Relu + b Loop_C1_3_Write +Loop_C1_3_Relu6: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vmin.f32 q0, q0, q14 + vmax.f32 q0, q0, q15 + vst1.32 {d0}, [r0], r12 + vst1.32 {d1[0]}, [r11], r12 + b Loop_C1_3_Relu6 +Loop_C1_3_Relu: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vmax.f32 q0, q0, q15 + vst1.32 {d0}, [r0], r12 + vst1.32 {d1[0]}, [r11], r12 + b Loop_C1_3_Relu +Loop_C1_3_Write: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0}, [r1]! + vadd.f32 q0, q0, q12 + vst1.32 {d0}, [r0], r12 + vst1.32 {d1[0]}, [r11], r12 + b Loop_C1_3_Write + +End: + sub sp, sp, #32 + pop {r4-r8, r10, r11, pc} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/PostFuncBiasReluC8.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/PostFuncBiasReluC8.S new file mode 100644 index 0000000000000000000000000000000000000000..ab704b34fe255c7098de6b22af29aa649026888a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/PostFuncBiasReluC8.S @@ -0,0 +1,450 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +//void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div,size_t oc8mod +// size_t plane_size, size_t stride, int relu_type); +// r0 dst r1 srx r2 bias +// r3 oc8div r4 oc8mod r5 plane_size +// r6 stride r7 relu_type + +// v0 ~ v15 value +// v16 v17 bias data +// r10 r11 weite loop tmp buf +// r16 relu6 #6; r17 relu #0 +// lr oc8 loop control +// r8 hw loop control + +asm_function PostFuncBiasReluC8 + push {r4-r8, r10, r11, lr} + add sp, sp, #32 + + ldr r4, [sp] + ldr r5, [sp, #4] + ldr r6, [sp, #8] + ldr r7, [sp, #12] + + vmov.i32 q14, #6 + vcvt.f32.s32 q14, q14 + veor q15, q15, q15 + mov lr, #0 + +Loop_C8: + cmp lr, r3 + beq Loop_C1 + mov r11, #4 + mul r10, lr, r11 + add r11, r0, r10 + add lr, lr, #8 + mov r8, r5 + vld1.32 {q12-q13}, [r2]! + +Loop_4x8: + cmp r8, #4 + blt Loop_1x8 + sub r8, r8, #4 + vld1.32 {q0-q1}, [r1]! + vld1.32 {q2-q3}, [r1]! + vld1.32 {q8-q9}, [r1]! + vld1.32 {q10-q11}, [r1]! + + vadd.f32 q0, q0, q12 + vadd.f32 q1, q1, q13 + vadd.f32 q2, q2, q12 + vadd.f32 q3, q3, q13 + vadd.f32 q8, q8, q12 + vadd.f32 q9, q9, q13 + vadd.f32 q10, q10, q12 + vadd.f32 q11, q11, q13 + + cmp r7, #3 + beq Relu6_4x8 + cmp r7, #1 + beq Relu_4x8 + b Write_4x8 +Relu6_4x8: + vmin.f32 q0, q0, q14 + vmin.f32 q1, q1, q14 + vmin.f32 q2, q2, q14 + vmin.f32 q3, q3, q14 + vmin.f32 q8, q8, q14 + vmin.f32 q9, q9, q14 + vmin.f32 q10, q10, q14 + vmin.f32 q11, q11, q14 +Relu_4x8: + vmax.f32 q0, q0, q15 + vmax.f32 q1, q1, q15 + vmax.f32 q2, q2, q15 + vmax.f32 q3, q3, q15 + vmax.f32 q8, q8, q15 + vmax.f32 q9, q9, q15 + vmax.f32 q10, q10, q15 + vmax.f32 q11, q11, q15 +Write_4x8: + vst1.32 {q0-q1}, [r11], r6 + vst1.32 {q2-q3}, [r11], r6 + vst1.32 {q8-q9}, [r11], r6 + vst1.32 {q10-q11}, [r11], r6 + b Loop_4x8 + +Loop_1x8: + cmp r7, #3 + beq Relu6_1x8 + cmp r7, #1 + beq Relu_1x8 + b Write_1x8 +Relu6_1x8: + cmp r8, #0 + beq Loop_C8 + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vadd.f32 q1, q1, q13 + vmin.f32 q0, q0, q14 + vmin.f32 q1, q1, q14 + vmax.f32 q0, q0, q15 + vmax.f32 q1, q1, q15 + vst1.32 {q0-q1}, [r11], r6 + b Relu6_1x8 +Relu_1x8: + cmp r8, #0 + beq Loop_C8 + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vadd.f32 q1, q1, q13 + vmax.f32 q0, q0, q15 + vmax.f32 q1, q1, q15 + vst1.32 {q0-q1}, [r11], r6 + b Relu_1x8 +Write_1x8: + cmp r8, #0 + beq Loop_C8 + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vadd.f32 q1, q1, q13 + vst1.32 {q0-q1}, [r11], r6 + b Write_1x8 + +Loop_C1: + cmp r4, #0 + beq End + mov r8, r5 + vld1.32 {q12-q13}, [r2]! + mov r11, #4 + mul r10, lr, r11 + add r0, r0, r10 + + cmp r4, #1 + beq Loop_C1_1 + cmp r4, #2 + beq Loop_C1_2 + cmp r4, #3 + beq Loop_C1_3 + cmp r4, #4 + beq Loop_C1_4 + cmp r4, #5 + beq Loop_C1_5 + cmp r4, #6 + beq Loop_C1_6 + cmp r4, #7 + beq Loop_C1_7 + +Loop_C1_1: + cmp r7, #3 + beq Loop_C1_1_Relu6 + cmp r7, #1 + beq Loop_C1_1_Relu + b Loop_C1_1_Write +Loop_C1_1_Relu6: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vmin.f32 q0, q0, q14 + vmax.f32 q0, q0, q15 + vst1.32 {d0[0]}, [r0], r6 + b Loop_C1_1_Relu6 +Loop_C1_1_Relu: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vmax.f32 q0, q0, q15 + vst1.32 {d0[0]}, [r0], r6 + b Loop_C1_1_Relu +Loop_C1_1_Write: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vst1.32 {d0[0]}, [r0], r6 + b Loop_C1_1_Write + +Loop_C1_2: + cmp r7, #3 + beq Loop_C1_2_Relu6 + cmp r7, #1 + beq Loop_C1_2_Relu + b Loop_C1_2_Write +Loop_C1_2_Relu6: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vmin.f32 q0, q0, q14 + vmax.f32 q0, q0, q15 + vst1.32 {d0}, [r0], r6 + b Loop_C1_2_Relu6 +Loop_C1_2_Relu: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vmax.f32 q0, q0, q15 + vst1.32 {d0}, [r0], r6 + b Loop_C1_2_Relu +Loop_C1_2_Write: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vst1.32 {d0}, [r0], r6 + b Loop_C1_2_Write + +Loop_C1_3: + add r11, r0, #8 + cmp r7, #3 + beq Loop_C1_3_Relu6 + cmp r7, #1 + beq Loop_C1_3_Relu + b Loop_C1_3_Write +Loop_C1_3_Relu6: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vmin.f32 q0, q0, q14 + vmax.f32 q0, q0, q15 + vst1.32 {d0}, [r0], r6 + vst1.32 {d1[0]}, [r11], r6 + b Loop_C1_3_Relu6 +Loop_C1_3_Relu: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vmax.f32 q0, q0, q15 + vst1.32 {d0}, [r0], r6 + vst1.32 {d1[0]}, [r11], r6 + b Loop_C1_3_Relu +Loop_C1_3_Write: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vst1.32 {d0}, [r0], r6 + vst1.32 {d1[0]}, [r11], r6 + b Loop_C1_3_Write + +Loop_C1_4: + cmp r7, #3 + beq Loop_C1_4_Relu6 + cmp r7, #1 + beq Loop_C1_4_Relu + b Loop_C1_4_Write +Loop_C1_4_Relu6: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vmin.f32 q0, q0, q14 + vmax.f32 q0, q0, q15 + vst1.32 {q0}, [r0], r6 + b Loop_C1_4_Relu6 +Loop_C1_4_Relu: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vmax.f32 q0, q0, q15 + vst1.32 {q0}, [r0], r6 + b Loop_C1_4_Relu6 +Loop_C1_4_Write: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vst1.32 {q0}, [r0], r6 + b Loop_C1_4_Write + +Loop_C1_5: + add r11, r0, #16 + cmp r7, #3 + beq Loop_C1_5_Relu6 + cmp r7, #1 + beq Loop_C1_5_Relu + b Loop_C1_5_Write +Loop_C1_5_Relu6: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vadd.f32 q1, q1, q13 + vmin.f32 q0, q0, q14 + vmin.f32 q1, q1, q14 + vmax.f32 q0, q0, q15 + vmax.f32 q1, q1, q15 + vst1.32 {q0}, [r0], r6 + vst1.32 {d2[0]}, [r11], r6 + b Loop_C1_5_Relu6 +Loop_C1_5_Relu: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vadd.f32 q1, q1, q13 + vmax.f32 q0, q0, q15 + vmax.f32 q1, q1, q15 + vst1.32 {q0}, [r0], r6 + vst1.32 {d2[0]}, [r11], r6 + b Loop_C1_5_Relu +Loop_C1_5_Write: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vadd.f32 q1, q1, q13 + vst1.32 {q0}, [r0], r6 + vst1.32 {d2[0]}, [r11], r6 + b Loop_C1_5_Write + +Loop_C1_6: + add r11, r0, #16 + cmp r7, #3 + beq Loop_C1_6_Relu6 + cmp r7, #1 + beq Loop_C1_6_Relu + b Loop_C1_6_Write +Loop_C1_6_Relu6: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vadd.f32 q1, q1, q13 + vmin.f32 q0, q0, q14 + vmin.f32 q1, q1, q14 + vmax.f32 q0, q0, q15 + vmax.f32 q1, q1, q15 + vst1.32 {q0}, [r0], r6 + vst1.32 {d2}, [r11], r6 + b Loop_C1_6_Relu6 +Loop_C1_6_Relu: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vadd.f32 q1, q1, q13 + vmax.f32 q0, q0, q15 + vmax.f32 q1, q1, q15 + vst1.32 {q0}, [r0], r6 + vst1.32 {d2}, [r11], r6 + b Loop_C1_6_Relu +Loop_C1_6_Write: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vadd.f32 q1, q1, q13 + vst1.32 {q0}, [r0], r6 + vst1.32 {d2}, [r11], r6 + b Loop_C1_6_Write + +Loop_C1_7: + add r11, r0, #16 + add r10, r0, #24 + cmp r7, #3 + beq Loop_C1_7_Relu6 + cmp r7, #1 + beq Loop_C1_7_Relu + b Loop_C1_7_Write +Loop_C1_7_Relu6: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vadd.f32 q1, q1, q13 + vmin.f32 q0, q0, q14 + vmin.f32 q1, q1, q14 + vmax.f32 q0, q0, q15 + vmax.f32 q1, q1, q15 + vst1.32 {q0}, [r0], r6 + vst1.32 {d2}, [r11], r6 + vst1.32 {d3[0]}, [r10], r6 + b Loop_C1_7_Relu6 +Loop_C1_7_Relu: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vadd.f32 q1, q1, q13 + vmax.f32 q0, q0, q15 + vmax.f32 q1, q1, q15 + vst1.32 {q0}, [r0], r6 + vst1.32 {d2}, [r11], r6 + vst1.32 {d3[0]}, [r10], r6 + b Loop_C1_7_Relu +Loop_C1_7_Write: + cmp r8, #0 + beq End + sub r8, r8, #1 + vld1.32 {q0-q1}, [r1]! + vadd.f32 q0, q0, q12 + vadd.f32 q1, q1, q13 + vst1.32 {q0}, [r0], r6 + vst1.32 {d2}, [r11], r6 + vst1.32 {d3[0]}, [r10], r6 + b Loop_C1_7_Write + +End: + sub sp, sp, #32 + pop {r4-r8, r10, r11, pc} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/PreSum4x16Int8Peroc.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/PreSum4x16Int8Peroc.S new file mode 100644 index 0000000000000000000000000000000000000000..3ce34fd655eb392cec190a7d7983a6cb1529e710 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/PreSum4x16Int8Peroc.S @@ -0,0 +1,143 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +//void PreSum4x16Int8Peroc(const int8_t *src, int32_t *sum, int32_t *zp, size_t hw4, size_t ic16, int32_t oc_div2, +// size_t oc_res2, size_t stride); + +// r0 src +// r1 sum +// r2 zp +// r3 hw4 +// r4 ic16 +// r5 oc_div2 +// r6 oc_res2 +// r7 stride + +asm_function PreSum4x16Int8Peroc + push {r4-r11, lr} + vpush {q4-q7} + add sp, sp, #100 + + ldr r4, [sp] + ldr r5, [sp, #4] + ldr r6, [sp, #8] + ldr r7, [sp, #12] + + mov r8, #0 + mov r10, #8 + +RowLoop: + cmp r8, r3 + beq End + add r8, r8, #4 + vmov.s32 q13, #0 + mov r9, #0 + mov r11, r2 + +Sum: + cmp r9, r4 + beq Mul + add r9, r9, #16 + + vld1.8 {q0, q1}, [r0]! + vld1.8 {q2, q3}, [r0]! + + vpaddl.s8 q4, q0 + vpaddl.s8 q5, q1 + vpaddl.s8 q6, q2 + vpaddl.s8 q7, q3 + + vpaddl.s16 q0, q4 + vpaddl.s16 q1, q5 + vpaddl.s16 q2, q6 + vpaddl.s16 q3, q7 + + vpaddl.s32 q4, q0 + vpaddl.s32 q5, q1 + vpaddl.s32 q6, q2 + vpaddl.s32 q7, q3 + + vqmovn.s64 d0, q4 + vqmovn.s64 d1, q5 + vqmovn.s64 d2, q6 + vqmovn.s64 d3, q7 + + vpaddl.s32 q4, q0 + vpaddl.s32 q5, q1 + + vqmovn.s64 d0, q4 + vqmovn.s64 d1, q5 + + vadd.i32 q13, q13, q0 + b Sum + +Mul: + mov r12, r1 + add r1, r1, #32 + mov r9, #0 + + vdup.32 d1, d26[0] + vdup.32 d2, d26[1] + vdup.32 d3, d27[0] + vdup.32 d4, d27[1] + +Write: + + cmp r9, r5 + beq OcRes + add r9, r9, #2 + vld1.32 {d9}, [r11]! + + vmul.i32 d5, d1, d9 + vmul.i32 d6, d2, d9 + vmul.i32 d7, d3, d9 + vmul.i32 d8, d4, d9 + + vst1.32 d5, [r12], r10 + vst1.32 d6, [r12], r10 + vst1.32 d7, [r12], r10 + vst1.32 d8, [r12], r10 + add r12, r12, r7 + b Write + +OcRes: + cmp r6, #0 + beq RowLoop + + vmov.s32 d9, #0 + vld1.8 {d9[0]}, [r11] + + vmul.i32 d5, d1, d9 + vmul.i32 d6, d2, d9 + vmul.i32 d7, d3, d9 + vmul.i32 d8, d4, d9 + + vst1.32 d5, [r12], r10 + vst1.32 d6, [r12], r10 + vst1.32 d7, [r12], r10 + vst1.32 d8, [r12], r10 + b RowLoop + +End: + sub sp, sp, #100 + vpop {q4-q7} + pop {r4-r11, pc} +#endif diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_serializer.h b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/PreSum4x16Int8Pert.S similarity index 34% rename from mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_serializer.h rename to mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/PreSum4x16Int8Pert.S index 7525c1659094b0227676851b00d448b714c41b2a..11da2d2f6d4918b499c64cb00bc7a0ac51ad89c9 100644 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_serializer.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/PreSum4x16Int8Pert.S @@ -13,33 +13,82 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_TENSORRT_SERIALIZER_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_TENSORRT_SERIALIZER_H_ -#include -#include -#include -#include "include/errorcode.h" -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "src/extendrt/delegate/tensorrt/tensorrt_runtime.h" - -using mindspore::lite::RET_ERROR; -using mindspore::lite::RET_OK; - -namespace mindspore::lite { -class TensorRTSerializer { - public: - explicit TensorRTSerializer(const std::string &serialize_file_path) - : serialize_file_path_(std::move(serialize_file_path)) {} - - ~TensorRTSerializer() = default; - - nvinfer1::ICudaEngine *GetSerializedEngine(); - - void SaveSerializedEngine(nvinfer1::ICudaEngine *engine); - - private: - std::string serialize_file_path_; - TensorRTLogger logger_; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_TENSORRT_SERIALIZER_H_ +#ifdef ENABLE_ARM32 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void PreSum4x16Int8Pert(const int8_t *src, int32_t *sum, size_t row4, size_t col16, int32_t filter_zp); + +// r0 src +// r1 sum +// r2 row4 +// r3 co16 +// r4 filter_zp + +asm_function PreSum4x16Int8Pert + push {r4-r8, r10, r11, lr} + vpush {q4-q7} + add sp, sp, #96 + + ldr r4, [sp] + + vdup.32 q10, r4 + mov r5, #0 + mov r7, #16 + +RowLoop: + cmp r5, r2 + beq End + add r5, r5, #4 + vmov.s32 q13, #0 + mov r6, #0 + +CalLoop: + cmp r6, r3 + beq Write + add r6, r6, #16 + + vld1.8 {q0, q1}, [r0]! + vld1.8 {q2, q3}, [r0]! + + vpaddl.s8 q4, q0 + vpaddl.s8 q5, q1 + vpaddl.s8 q6, q2 + vpaddl.s8 q7, q3 + + vpaddl.s16 q0, q4 + vpaddl.s16 q1, q5 + vpaddl.s16 q2, q6 + vpaddl.s16 q3, q7 + + vpaddl.s32 q4, q0 + vpaddl.s32 q5, q1 + vpaddl.s32 q6, q2 + vpaddl.s32 q7, q3 + + vqmovn.s64 d0, q4 + vqmovn.s64 d1, q5 + vqmovn.s64 d2, q6 + vqmovn.s64 d3, q7 + + vpaddl.s32 q4, q0 + vpaddl.s32 q5, q1 + + vqmovn.s64 d0, q4 + vqmovn.s64 d1, q5 + + vadd.i32 q13, q13, q0 + b CalLoop + +Write: + vmul.i32 q13, q13, q10 + vst1.32 {d26, d27}, [r1], r7 + beq RowLoop + +End: + sub sp, sp, #96 + vpop {q4-q7} + pop {r4-r8, r10, r11, pc} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/TiledC4MatmulFp32.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/TiledC4MatmulFp32.S new file mode 100644 index 0000000000000000000000000000000000000000..cb7235dea36d96066834b68c79f9f07427c8c1b2 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/TiledC4MatmulFp32.S @@ -0,0 +1,211 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +asm_function TiledC4MatmulFp32 +//void TiledC4MatmulFp32(float* dst, const float* src, const float* weight, size_t cal_num, size_t ic4, size_t oc4) +//x0: dst +//x1: src +//x2: weight +//x3: cal_num +//x4: ic4 +//x5: oc4 + +push {r4-r8, lr} +ldr r4, [sp, #24] +ldr r5, [sp, #28] +//step multi by sizeof(float) +mov r8, #4 +mul r3, r8, r3 + +vpush {q4-q7} + +LoopOc: + mov r6, r1 + mov r8, r0 + subs r7, r4, #1 + vld1.32 {q0, q1}, [r1]! + vld1.32 {q2, q3}, [r1]! + vld1.32 {q4, q5}, [r2]! + vld1.32 {q6, q7}, [r2]! + + vmul.f32 q8, q4, d0[0] + vmul.f32 q9, q4, d2[0] + vmul.f32 q10, q4, d4[0] + vmul.f32 q11, q4, d6[0] + + vmla.f32 q8, q5, d0[1] + vmla.f32 q9, q5, d2[1] + vmla.f32 q10, q5, d4[1] + vmla.f32 q11, q5, d6[1] + + vmla.f32 q8, q6, d1[0] + vmla.f32 q9, q6, d3[0] + vmla.f32 q10, q6, d5[0] + vmla.f32 q11, q6, d7[0] + + vmla.f32 q8, q7, d1[1] + vmla.f32 q9, q7, d3[1] + vmla.f32 q10, q7, d5[1] + vmla.f32 q11, q7, d7[1] + + vld1.32 {q0, q1}, [r1]! + vld1.32 {q2, q3}, [r1]! + + vmul.f32 q12, q4, d0[0] + vmul.f32 q13, q4, d2[0] + vmul.f32 q14, q4, d4[0] + vmul.f32 q15, q4, d6[0] + + vmla.f32 q12, q5, d0[1] + vmla.f32 q13, q5, d2[1] + vmla.f32 q14, q5, d4[1] + vmla.f32 q15, q5, d6[1] + + vmla.f32 q12, q6, d1[0] + vmla.f32 q13, q6, d3[0] + vmla.f32 q14, q6, d5[0] + vmla.f32 q15, q6, d7[0] + + vmla.f32 q12, q7, d1[1] + vmla.f32 q13, q7, d3[1] + vmla.f32 q14, q7, d5[1] + vmla.f32 q15, q7, d7[1] + beq LoopIcEnd + + subs r7, r7, #1 + + vld1.32 {q4, q5}, [r2]! + vld1.32 {q0, q1}, [r1]! + vld1.32 {q2, q3}, [r1]! + + vmla.f32 q8, q4, d0[0] + vmla.f32 q9, q4, d2[0] + beq LoopIcEndHalf + + LoopIc: + vmla.f32 q10, q4, d4[0] + vmla.f32 q11, q4, d6[0] + + vmla.f32 q8, q5, d0[1] + vmla.f32 q9, q5, d2[1] + vld1.32 {q6, q7}, [r2]! + vmla.f32 q10, q5, d4[1] + vmla.f32 q11, q5, d6[1] + + vmla.f32 q8, q6, d1[0] + vmla.f32 q9, q6, d3[0] + vmla.f32 q10, q6, d5[0] + vmla.f32 q11, q6, d7[0] + + vmla.f32 q8, q7, d1[1] + vmla.f32 q9, q7, d3[1] + vmla.f32 q10, q7, d5[1] + vld1.32 {q0, q1}, [r1]! + vmla.f32 q11, q7, d7[1] + + vld1.32 {q2, q3}, [r1]! + + vmla.f32 q12, q4, d0[0] + vmla.f32 q13, q4, d2[0] + vmla.f32 q14, q4, d4[0] + vmla.f32 q15, q4, d6[0] + + vmla.f32 q12, q5, d0[1] + vmla.f32 q13, q5, d2[1] + vmla.f32 q14, q5, d4[1] + vmla.f32 q15, q5, d6[1] + + vmla.f32 q12, q6, d1[0] + vmla.f32 q13, q6, d3[0] + vmla.f32 q14, q6, d5[0] + vld1.32 {q4, q5}, [r2]! + vmla.f32 q15, q6, d7[0] + + vmla.f32 q12, q7, d1[1] + vmla.f32 q13, q7, d3[1] + vmla.f32 q14, q7, d5[1] + vld1.32 {q0, q1}, [r1]! + vmla.f32 q15, q7, d7[1] + + vld1.32 {q2, q3}, [r1]! + + vmla.f32 q8, q4, d0[0] + vmla.f32 q9, q4, d2[0] + + subs r7, r7, #1 + bne LoopIc + LoopIcEndHalf: + vmla.f32 q10, q4, d4[0] + vmla.f32 q11, q4, d6[0] + + vmla.f32 q8, q5, d0[1] + vmla.f32 q9, q5, d2[1] + vld1.32 {q6, q7}, [r2]! + vmla.f32 q10, q5, d4[1] + vmla.f32 q11, q5, d6[1] + + vmla.f32 q8, q6, d1[0] + vmla.f32 q9, q6, d3[0] + vmla.f32 q10, q6, d5[0] + vmla.f32 q11, q6, d7[0] + + vmla.f32 q8, q7, d1[1] + vmla.f32 q9, q7, d3[1] + vmla.f32 q10, q7, d5[1] + vld1.32 {q0, q1}, [r1]! + vmla.f32 q11, q7, d7[1] + + vld1.32 {q2, q3}, [r1]! + + vmla.f32 q12, q4, d0[0] + vmla.f32 q13, q4, d2[0] + vmla.f32 q14, q4, d4[0] + vmla.f32 q15, q4, d6[0] + + vmla.f32 q12, q5, d0[1] + vmla.f32 q13, q5, d2[1] + vmla.f32 q14, q5, d4[1] + vmla.f32 q15, q5, d6[1] + + vmla.f32 q12, q6, d1[0] + vmla.f32 q13, q6, d3[0] + vmla.f32 q14, q6, d5[0] + vmla.f32 q15, q6, d7[0] + + vmla.f32 q12, q7, d1[1] + vmla.f32 q13, q7, d3[1] + vmla.f32 q14, q7, d5[1] + vmla.f32 q15, q7, d7[1] + LoopIcEnd: + vst1.32 {q8, q9}, [r0]! + vst1.32 {q10, q11}, [r0]! + vst1.32 {q12, q13}, [r0]! + vst1.32 {q14, q15}, [r0]! + mov r1, r6 + + subs r5, r5, #1 + add r0, r8, r3 + bne LoopOc + + vpop {q4-q7} + pop {r4-r8, pc} + +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/WinogradTransLeft.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/WinogradTransLeft.S new file mode 100644 index 0000000000000000000000000000000000000000..2d6e40302fe45b9a73283ed47a66a390489239eb --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/WinogradTransLeft.S @@ -0,0 +1,230 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +//void WinogradTransLeft(const float* S, const float* B, float* M, size_t w, size_t h, size_t k, size_t length); +//x0: S +//x1: B +//x2: M +//x3: w +//x4: h +//x5: k +//x6: length +asm_function WinogradTransLeft + push {r4-r11, lr} + ldr r4, [sp, #36] + ldr r5, [sp, #40] + ldr r6, [sp, #44] + + mov r8, #16 // 4 * sizeof(float) + mul r8, r6, r8 + mul r9, r3, r8 + sub r9, r9, r8 + add r7, r9, r8 // step for S + mov r10, #4 + mul r10, r4, r10 // step for B + +LoopH: + push {r0, r3} + LoopW: + push {r0, r1} + vmov.i32 q14, #0 + mov r11, r6 + InitZero: + vst1.32 {q14}, [r2]! + subs r11, r11, #1 + bne InitZero + + sub r2, r2, r8 + mov r12, r5 + + LoopKStart7: + cmp r12, #7 + blt LoopKStart4 + push {r3-r7} + LoopK7: + vld1.32 {d0[0]}, [r1], r10 + vld1.32 {d0[1]}, [r1], r10 + vld1.32 {d1[0]}, [r1], r10 + vld1.32 {d1[1]}, [r1], r10 + vld1.32 {d2[0]}, [r1], r10 + vld1.32 {d2[1]}, [r1], r10 + vld1.32 {d3[0]}, [r1], r10 + mov r11, r6 + vmov.32 d30[0], r1 + + add r1, r0, r7 + add r3, r1, r7 + add r4, r3, r7 + add r5, r4, r7 + add r6, r5, r7 + add r7, r6, r7 + + LoopLength7: + vld1.32 {q8}, [r2] + vld1.32 {q12}, [r0]! + vmla.f32 q8, q12, d0[0] + vld1.32 {q13}, [r1]! + vmul.f32 q9, q13, d0[1] + vld1.32 {q12}, [r3]! + vmla.f32 q8, q12, d1[0] + vld1.32 {q13}, [r4]! + vmla.f32 q9, q13, d1[1] + vld1.32 {q12}, [r5]! + vmla.f32 q8, q12, d2[0] + vld1.32 {q13}, [r6]! + vmla.f32 q9, q13, d2[1] + vld1.32 {q12}, [r7]! + vmla.f32 q8, q12, d3[0] + + vadd.f32 q9, q8, q9 + vst1.32 {q9}, [r2]! + subs r11, r11, #1 + bne LoopLength7 + + sub r2, r2, r8 + sub r12, r12, #7 + add r0, r7, r9 + vmov.32 r1, d30[0] + cmp r12, #7 + bge LoopK7 + + pop {r3-r7} + + LoopKStart4: + cmp r12, #4 + blt LoopKStart3 + vmov.32 d30[1], r3 + vmov.32 d31[0], r4 + LoopK4: + vld1.32 {d0[0]}, [r1], r10 + vld1.32 {d0[1]}, [r1], r10 + vld1.32 {d1[0]}, [r1], r10 + vld1.32 {d1[1]}, [r1], r10 + mov r11, r6 + vmov.32 d30[0], r1 + + add r1, r0, r7 + add r3, r1, r7 + add r4, r3, r7 + + LoopLength4: + vld1.32 {q8}, [r2] + vld1.32 {q12}, [r0]! + vmla.f32 q8, q12, d0[0] + vld1.32 {q13}, [r1]! + vmul.f32 q9, q13, d0[1] + vld1.32 {q12}, [r3]! + vmla.f32 q8, q12, d1[0] + vld1.32 {q13}, [r4]! + vmla.f32 q9, q13, d1[1] + + vadd.f32 q9, q8, q9 + vst1.32 {q9}, [r2]! + subs r11, r11, #1 + bne LoopLength4 + + sub r2, r2, r8 + sub r12, r12, #4 + add r0, r4, r9 + vmov.32 r1, d30[0] + cmp r12, #4 + bge LoopK4 + + vmov.32 r3, d30[1] + vmov.32 r4, d31[0] + + LoopKStart3: + cmp r12, #3 + blt LoopKStart + vmov.32 d30[1], r3 + vmov.32 d31[0], r4 + LoopK3: + vld1.32 {d0[0]}, [r1], r10 + vld1.32 {d0[1]}, [r1], r10 + vld1.32 {d1[0]}, [r1], r10 + mov r11, r6 + vmov.32 d30[0], r1 + + add r1, r0, r7 + add r3, r1, r7 + + LoopLength3: + vld1.32 {q8}, [r2] + vld1.32 {q12}, [r0]! + vmla.f32 q8, q12, d0[0] + vld1.32 {q13}, [r1]! + vmul.f32 q9, q13, d0[1] + vld1.32 {q12}, [r3]! + vmla.f32 q8, q12, d1[0] + + vadd.f32 q9, q8, q9 + vst1.32 {q9}, [r2]! + subs r11, r11, #1 + bne LoopLength3 + + sub r2, r2, r8 + sub r12, r12, #3 + add r0, r3, r9 + vmov.32 r1, d30[0] + cmp r12, #3 + bge LoopK3 + + vmov.32 r3, d30[1] + vmov.32 r4, d31[0] + + LoopKStart: + cmp r12, #0 + beq LoopKEnd + + LoopK: + vld1.32 {d30[0]}, [r1], r10 + + vdup.32 q15, d30[0] + mov r11, r6 + LoopLength: + vld1.32 {q0}, [r2] + vld1.32 {q1}, [r0]! + vmla.f32 q0, q1, q15 + + vst1.32 {q0}, [r2]! + subs r11, r11, #1 + bne LoopLength + subs r12, r12, #1 + + sub r2, r2, r8 + add r0, r0, r9 + bne LoopK + + LoopKEnd: + pop {r0, r1} + subs r3, r3, #1 + add r0, r0, r8 + add r2, r2, r8 + bne LoopW + + pop {r0, r3} + add r1, r1, #4 //sizeof(float) + subs r4, r4, #1 + bne LoopH + + pop {r4-r11, pc} + +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/WinogradTransRight.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/WinogradTransRight.S new file mode 100644 index 0000000000000000000000000000000000000000..6eb101b7b5426b3c23391c18c6cbe564a0e3a9a4 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm32/WinogradTransRight.S @@ -0,0 +1,220 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +//void WinogradTransRight(const float* S, const float* B, float* M, size_t w, size_t h, size_t k, size_t length); +//x0: S +//x1: B +//x2: M +//x3: w +//x4: h +//x5: k +//x6: length +asm_function WinogradTransRight + push {r4-r11, lr} + ldr r4, [sp, #36] + ldr r5, [sp, #40] + ldr r6, [sp, #44] + + mov r8, #16 // 4 * sizeof(float) + mul r8, r6, r8 + mul r9, r5, r8 // step for S + mov r10, #4 + mul r10, r4, r10 // step for B + +LoopH: + push {r1, r3} + LoopW: + push {r0, r1} + vmov.i32 q14, #0 + mov r11, r6 + InitZero: + vst1.32 {q14}, [r2]! + subs r11, r11, #1 + bne InitZero + + sub r2, r2, r8 + mov r12, r5 + LoopKStart7: + cmp r12, #7 + blt LoopKStart4 + push {r3-r7} + LoopK7: + vld1.32 {d0[0]}, [r1], r10 + vld1.32 {d0[1]}, [r1], r10 + vld1.32 {d1[0]}, [r1], r10 + vld1.32 {d1[1]}, [r1], r10 + vld1.32 {d2[0]}, [r1], r10 + vld1.32 {d2[1]}, [r1], r10 + vld1.32 {d3[0]}, [r1], r10 + mov r11, r6 + vmov.32 d30[0], r1 + + add r1, r0, r8 + add r3, r1, r8 + add r4, r3, r8 + add r5, r4, r8 + add r6, r5, r8 + add r7, r6, r8 + LoopLength7: + vld1.32 {q8}, [r2] + vld1.32 {q12}, [r0]! + vmla.f32 q8, q12, d0[0] + vld1.32 {q13}, [r1]! + vmul.f32 q9, q13, d0[1] + vld1.32 {q12}, [r3]! + vmla.f32 q8, q12, d1[0] + vld1.32 {q13}, [r4]! + vmla.f32 q9, q13, d1[1] + vld1.32 {q12}, [r5]! + vmla.f32 q8, q12, d2[0] + vld1.32 {q13}, [r6]! + vmla.f32 q9, q13, d2[1] + vld1.32 {q12}, [r7]! + vmla.f32 q8, q12, d3[0] + + vadd.f32 q9, q8, q9 + vst1.32 {q9}, [r2]! + subs r11, r11, #1 + bne LoopLength7 + + sub r2, r2, r8 + sub r12, r12, #7 + mov r0, r7 + vmov.32 r1, d30[0] + cmp r12, #7 + bge LoopK7 + + pop {r3-r7} + + LoopKStart4: + cmp r12, #4 + blt LoopKStart3 + vmov.32 d30[1], r3 + vmov.32 d31[0], r4 + LoopK4: + vld1.32 {d0[0]}, [r1], r10 + vld1.32 {d0[1]}, [r1], r10 + vld1.32 {d1[0]}, [r1], r10 + vld1.32 {d1[1]}, [r1], r10 + mov r11, r6 + vmov.32 d30[0], r1 + + add r1, r0, r8 + add r3, r1, r8 + add r4, r3, r8 + + LoopLength4: + vld1.32 {q8}, [r2] + vld1.32 {q12}, [r0]! + vmla.f32 q8, q12, d0[0] + vld1.32 {q13}, [r1]! + vmul.f32 q9, q13, d0[1] + vld1.32 {q12}, [r3]! + vmla.f32 q8, q12, d1[0] + vld1.32 {q13}, [r4]! + vmla.f32 q9, q13, d1[1] + + vadd.f32 q9, q8, q9 + vst1.32 {q9}, [r2]! + subs r11, r11, #1 + bne LoopLength4 + + sub r2, r2, r8 + sub r12, r12, #4 + mov r0, r4 + vmov.32 r1, d30[0] + cmp r12, #4 + bge LoopK4 + + vmov.32 r3, d30[1] + vmov.32 r4, d31[0] + + LoopKStart3: + cmp r12, #3 + blt LoopKStart + vmov.32 d30[1], r3 + LoopK3: + vld1.32 {d0[0]}, [r1], r10 + vld1.32 {d0[1]}, [r1], r10 + vld1.32 {d1[0]}, [r1], r10 + mov r11, r6 + vmov.32 d30[0], r1 + + add r1, r0, r8 + add r3, r1, r8 + + LoopLength3: + vld1.32 {q8}, [r2] + vld1.32 {q12}, [r0]! + vmla.f32 q8, q12, d0[0] + vld1.32 {q13}, [r1]! + vmul.f32 q9, q13, d0[1] + vld1.32 {q12}, [r3]! + vmla.f32 q8, q12, d1[0] + + vadd.f32 q9, q8, q9 + vst1.32 {q9}, [r2]! + subs r11, r11, #1 + bne LoopLength3 + + sub r2, r2, r8 + sub r12, r12, #3 + mov r0, r3 + vmov.32 r1, d30[0] + cmp r12, #3 + bge LoopK3 + + vmov.32 r3, d30[1] + + LoopKStart: + cmp r12, #0 + beq LoopKEnd + LoopK: + vld1.32 {d30[0]}, [r1], r10 + vdup.32 q15, d30[0] + mov r11, r6 + LoopLength: + vld1.32 {q0}, [r2] + vld1.32 {q1}, [r0]! + vmla.f32 q0, q1, q15 + + vst1.32 {q0}, [r2]! + subs r11, r11, #1 + bne LoopLength + + subs r12, r12, #1 + sub r2, r2, r8 + bne LoopK + LoopKEnd: + pop {r0, r1} + subs r3, r3, #1 + add r2, r2, r8 + add r1, r1, #4 //sizeof(float) + bne LoopW + + pop {r1, r3} + add r0, r0, r9 + subs r4, r4, #1 + bne LoopH + + pop {r4-r11, pc} + +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/AdderFp32.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/AdderFp32.S new file mode 100644 index 0000000000000000000000000000000000000000..9123d88c764646e9dfc606690006beed9463d857 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/AdderFp32.S @@ -0,0 +1,622 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void AdderFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth +// int row, int col, size_t stride) +// x0: a +// x1: b +// x2: c +// x3: bias +// x4: act_type +// x5: depth +// x6: row +// x7: col +// x8: stride +// x9: writeMode + +asm_function AdderFloatNeon64 + sub sp, sp, #144 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + + ldr x8, [sp, #144] + + mov x20, #48 // sizeof(float) * 12 + mul x17, x5, x20 // block stride of lhs/rhs: sizeof(float) * 12 * depth + + mov x20, #4 + mul x8, x8, x20 + +LoopRowStart: + cmp x6, #4 + ble LoopRow4 + cmp x6, #8 + blt LoopRow8 + +LoopRow: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol: + mov x11, x2 + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + LoopDepthStart: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s}, [x14], #16 + + dup v8.4s, v0.s[0] + fabd v9.4s, v3.4s, v8.4s + dup v10.4s, v0.s[1] + fabd v11.4s, v3.4s, v10.4s + dup v12.4s, v0.s[2] + fabd v13.4s, v3.4s, v12.4s + dup v14.4s, v0.s[3] + fabd v15.4s, v3.4s, v14.4s + + dup v16.4s, v1.s[0] + fabd v17.4s, v3.4s, v16.4s + dup v18.4s, v1.s[1] + fabd v19.4s, v3.4s, v18.4s + dup v20.4s, v1.s[2] + fabd v21.4s, v3.4s, v20.4s + dup v22.4s, v1.s[3] + fabd v23.4s, v3.4s, v22.4s + + dup v24.4s, v2.s[0] + fabd v25.4s, v3.4s, v24.4s + dup v26.4s, v2.s[1] + fabd v27.4s, v3.4s, v26.4s + dup v28.4s, v2.s[2] + fabd v29.4s, v3.4s, v28.4s + dup v30.4s, v2.s[3] + fabd v31.4s, v3.4s, v30.4s + + subs x19, x19, #1 + beq Bias + + LoopDepth: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s}, [x14], #16 + + dup v8.4s, v0.s[0] + fabd v8.4s, v3.4s, v8.4s + fadd v9.4s, v9.4s, v8.4s + dup v10.4s, v0.s[1] + fabd v10.4s, v3.4s, v10.4s + fadd v11.4s, v11.4s, v10.4s + dup v12.4s, v0.s[2] + fabd v12.4s, v3.4s, v12.4s + fadd v13.4s, v13.4s, v12.4s + dup v14.4s, v0.s[3] + fabd v14.4s, v3.4s, v14.4s + fadd v15.4s, v15.4s, v14.4s + + dup v16.4s, v1.s[0] + fabd v16.4s, v3.4s, v16.4s + fadd v17.4s, v17.4s, v16.4s + dup v18.4s, v1.s[1] + fabd v18.4s, v3.4s, v18.4s + fadd v19.4s, v19.4s, v18.4s + dup v20.4s, v1.s[2] + fabd v20.4s, v3.4s, v20.4s + fadd v21.4s, v21.4s, v20.4s + dup v22.4s, v1.s[3] + fabd v22.4s, v3.4s, v22.4s + fadd v23.4s, v23.4s, v22.4s + + dup v24.4s, v2.s[0] + fabd v24.4s, v3.4s, v24.4s + fadd v25.4s, v25.4s, v24.4s + dup v26.4s, v2.s[1] + fabd v26.4s, v3.4s, v26.4s + fadd v27.4s, v27.4s, v26.4s + dup v28.4s, v2.s[2] + fabd v28.4s, v3.4s, v28.4s + fadd v29.4s, v29.4s, v28.4s + dup v30.4s, v2.s[3] + fabd v30.4s, v3.4s, v30.4s + fadd v31.4s, v31.4s, v30.4s + + subs x19, x19, #1 + bgt LoopDepth + + Bias: + fneg v9.4s, v9.4s + fneg v11.4s, v11.4s + fneg v13.4s, v13.4s + fneg v15.4s, v15.4s + fneg v17.4s, v17.4s + fneg v19.4s, v19.4s + fneg v21.4s, v21.4s + fneg v23.4s, v23.4s + fneg v25.4s, v25.4s + fneg v27.4s, v27.4s + fneg v29.4s, v29.4s + fneg v31.4s, v31.4s + cbz x3, Activation + ld1 {v0.4s}, [x12], #16 + fadd v9.4s, v9.4s, v0.4s + fadd v11.4s, v11.4s, v0.4s + fadd v13.4s, v13.4s, v0.4s + fadd v15.4s, v15.4s, v0.4s + fadd v17.4s, v17.4s, v0.4s + fadd v19.4s, v19.4s, v0.4s + fadd v21.4s, v21.4s, v0.4s + fadd v23.4s, v23.4s, v0.4s + fadd v25.4s, v25.4s, v0.4s + fadd v27.4s, v27.4s, v0.4s + fadd v29.4s, v29.4s, v0.4s + fadd v31.4s, v31.4s, v0.4s + + Activation: + cmp x4, #3 + beq Relu6 + cmp x4, #1 + beq Relu + b Write + + Relu6: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + fmin v17.4s, v17.4s, v2.4s + fmin v19.4s, v19.4s, v2.4s + fmin v21.4s, v21.4s, v2.4s + fmin v23.4s, v23.4s, v2.4s + fmin v25.4s, v25.4s, v2.4s + fmin v27.4s, v27.4s, v2.4s + fmin v29.4s, v29.4s, v2.4s + fmin v31.4s, v31.4s, v2.4s + + Relu: + dup v3.4s, wzr + fmax v9.4s, v9.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + fmax v17.4s, v17.4s, v3.4s + fmax v19.4s, v19.4s, v3.4s + fmax v21.4s, v21.4s, v3.4s + fmax v23.4s, v23.4s, v3.4s + fmax v25.4s, v25.4s, v3.4s + fmax v27.4s, v27.4s, v3.4s + fmax v29.4s, v29.4s, v3.4s + fmax v31.4s, v31.4s, v3.4s + b Write + +LoopRow8: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol8: + mov x11, x2 + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + LoopDepthStart8: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s}, [x14], #16 + + dup v8.4s, v0.s[0] + fabd v9.4s, v3.4s, v8.4s + dup v10.4s, v0.s[1] + fabd v11.4s, v3.4s, v10.4s + dup v12.4s, v0.s[2] + fabd v13.4s, v3.4s, v12.4s + dup v14.4s, v0.s[3] + fabd v15.4s, v3.4s, v14.4s + + dup v16.4s, v1.s[0] + fabd v17.4s, v3.4s, v16.4s + dup v18.4s, v1.s[1] + fabd v19.4s, v3.4s, v18.4s + dup v20.4s, v1.s[2] + fabd v21.4s, v3.4s, v20.4s + dup v22.4s, v1.s[3] + fabd v23.4s, v3.4s, v22.4s + + subs x19, x19, #1 + beq Bias8 + + LoopDepth8: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s}, [x14], #16 + dup v8.4s, v0.s[0] + fabd v8.4s, v3.4s, v8.4s + fadd v9.4s, v9.4s, v8.4s + dup v10.4s, v0.s[1] + fabd v10.4s, v3.4s, v10.4s + fadd v11.4s, v11.4s, v10.4s + dup v12.4s, v0.s[2] + fabd v12.4s, v3.4s, v12.4s + fadd v13.4s, v13.4s, v12.4s + dup v14.4s, v0.s[3] + fabd v14.4s, v3.4s, v14.4s + fadd v15.4s, v15.4s, v14.4s + + dup v16.4s, v1.s[0] + fabd v16.4s, v3.4s, v16.4s + fadd v17.4s, v17.4s, v16.4s + dup v18.4s, v1.s[1] + fabd v18.4s, v3.4s, v18.4s + fadd v19.4s, v19.4s, v18.4s + dup v20.4s, v1.s[2] + fabd v20.4s, v3.4s, v20.4s + fadd v21.4s, v21.4s, v20.4s + dup v22.4s, v1.s[3] + fabd v22.4s, v3.4s, v22.4s + fadd v23.4s, v23.4s, v22.4s + + subs x19, x19, #1 + bgt LoopDepth8 + + Bias8: + fneg v9.4s, v9.4s + fneg v11.4s, v11.4s + fneg v13.4s, v13.4s + fneg v15.4s, v15.4s + fneg v17.4s, v17.4s + fneg v19.4s, v19.4s + fneg v21.4s, v21.4s + fneg v23.4s, v23.4s + cbz x3, Activation8 + ld1 {v0.4s}, [x12], #16 + fadd v9.4s, v9.4s, v0.4s + fadd v11.4s, v11.4s, v0.4s + fadd v13.4s, v13.4s, v0.4s + fadd v15.4s, v15.4s, v0.4s + fadd v17.4s, v17.4s, v0.4s + fadd v19.4s, v19.4s, v0.4s + fadd v21.4s, v21.4s, v0.4s + fadd v23.4s, v23.4s, v0.4s + + Activation8: + cmp x4, #3 + beq Relu68 + cmp x4, #1 + beq Relu8 + b Write + + Relu68: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + fmin v17.4s, v17.4s, v2.4s + fmin v19.4s, v19.4s, v2.4s + fmin v21.4s, v21.4s, v2.4s + fmin v23.4s, v23.4s, v2.4s + Relu8: + dup v3.4s, wzr + fmax v9.4s, v9.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + fmax v17.4s, v17.4s, v3.4s + fmax v19.4s, v19.4s, v3.4s + fmax v21.4s, v21.4s, v3.4s + fmax v23.4s, v23.4s, v3.4s + b Write + +LoopRow4: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol4: + mov x11, x2 + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + LoopDepthStart4: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s}, [x14], #16 + dup v8.4s, v0.s[0] + fabd v9.4s, v3.4s, v8.4s + dup v10.4s, v0.s[1] + fabd v11.4s, v3.4s, v10.4s + dup v12.4s, v0.s[2] + fabd v13.4s, v3.4s, v12.4s + dup v14.4s, v0.s[3] + fabd v15.4s, v3.4s, v14.4s + + subs x19, x19, #1 + beq Bias4 + + LoopDepth4: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s}, [x14], #16 + dup v8.4s, v0.s[0] + fabd v8.4s, v3.4s, v8.4s + fadd v9.4s, v9.4s, v8.4s + dup v10.4s, v0.s[1] + fabd v10.4s, v3.4s, v10.4s + fadd v11.4s, v11.4s, v10.4s + dup v12.4s, v0.s[2] + fabd v12.4s, v3.4s, v12.4s + fadd v13.4s, v13.4s, v12.4s + dup v14.4s, v0.s[3] + fabd v14.4s, v3.4s, v14.4s + fadd v15.4s, v15.4s, v14.4s + + subs x19, x19, #1 + bgt LoopDepth4 + + Bias4: + fneg v9.4s, v9.4s + fneg v11.4s, v11.4s + fneg v13.4s, v13.4s + fneg v15.4s, v15.4s + cbz x3, Activation4 + ld1 {v0.4s}, [x12], #16 + + fadd v9.4s, v9.4s, v0.4s + fadd v11.4s, v11.4s, v0.4s + fadd v13.4s, v13.4s, v0.4s + fadd v15.4s, v15.4s, v0.4s + + Activation4: + cmp x4, #3 + beq Relu64 + cmp x4, #1 + beq Relu4 + b Write + + Relu64: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + + Relu4: + dup v3.4s, wzr + fmax v9.4s, v9.4s, v2.4s + fmax v11.4s, v11.4s, v2.4s + fmax v13.4s, v13.4s, v2.4s + fmax v15.4s, v15.4s, v2.4s + b Write + + Write: + cmp x13, #1 + beq Write1 + cmp x13, #2 + beq Write2 + cmp x13, #3 + beq Write3 + b Write4 + + Write1: + add x2, x2, #4 + str s9, [x11] + cmp x6, #1 + beq WriteEnd + add x11, x11, x8 + str s11, [x11] + cmp x6, #2 + beq WriteEnd + add x11, x11, x8 + str s13, [x11] + cmp x6, #3 + beq WriteEnd + add x11, x11, x8 + str s15, [x11] + cmp x6, #4 + beq WriteEnd + add x11, x11, x8 + str s17, [x11] + cmp x6, #5 + beq WriteEnd + add x11, x11, x8 + str s19, [x11] + cmp x6, #6 + beq WriteEnd + add x11, x11, x8 + str s21, [x11] + cmp x6, #7 + beq WriteEnd + add x11, x11, x8 + str s23, [x11] + cmp x6, #8 + beq WriteEnd + add x11, x11, x8 + str s25, [x11] + cmp x6, #9 + beq WriteEnd + add x11, x11, x8 + str s27, [x11] + cmp x6, #10 + beq WriteEnd + add x11, x11, x8 + str s29, [x11] + cmp x6, #11 + beq WriteEnd + add x11, x11, x8 + str s31, [x11] + add x11, x11, x8 + add x11, x11, #4 + b WriteEnd + Write2: + add x2, x2, #8 + st1 {v9.2s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v11.2s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v13.2s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v15.2s}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v17.2s}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v19.2s}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v21.2s}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.2s}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v25.2s}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v27.2s}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v29.2s}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v31.2s}, [x11], x8 + add x11, x11, #8 + b WriteEnd + Write3: + add x2, x2, #12 + add x19, x11, #8 + st1 {v9.2s}, [x11], x8 + st1 {v9.s}[2], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v11.2s}, [x11], x8 + st1 {v11.s}[2], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v13.2s}, [x11], x8 + st1 {v13.s}[2], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v15.2s}, [x11], x8 + st1 {v15.s}[2], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v17.2s}, [x11], x8 + st1 {v17.s}[2], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v19.2s}, [x11], x8 + st1 {v19.s}[2], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v21.2s}, [x11], x8 + st1 {v21.s}[2], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.2s}, [x11], x8 + st1 {v23.s}[2], [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v25.2s}, [x11], x8 + st1 {v25.s}[2], [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v27.2s}, [x11], x8 + st1 {v27.s}[2], [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v29.2s}, [x11], x8 + st1 {v29.s}[2], [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v31.2s}, [x11], x8 + st1 {v31.s}[2], [x19] + add x11, x11, #12 + b WriteEnd + Write4: + add x2, x2, #16 + st1 {v9.4s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v11.4s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v13.4s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v15.4s}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v17.4s}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v19.4s}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v21.4s}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.4s}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v25.4s}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v27.4s}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v29.4s}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v31.4s}, [x11], x8 + add x11, x11, #16 + b WriteEnd + + WriteEnd: + subs x13, x13, #4 // rhs col - 4 + ble LoopColEnd + cmp x6, #4 + ble LoopCol4 + cmp x6, #8 + ble LoopCol8 + b LoopCol + +LoopColEnd: + add x0, x0, x17 + mov x20, #4 + mul x20, x20, x7 + sub x11, x11, x20 + mov x2, x11 + subs x6, x6, #12 + bgt LoopRowStart + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/BigMatmulFp32Opt.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/BigMatmulFp32Opt.S new file mode 100644 index 0000000000000000000000000000000000000000..3d04e88a4560e31732440c8b4606e280b8ec8fda --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/BigMatmulFp32Opt.S @@ -0,0 +1,2528 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void BigMatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth +// int row, int col, size_t stride) +// x0: a +// x1: b +// x2: c +// x3: bias +// x4: act_type +// x5: depth +// x6: row +// x7: col +// x8: stride + +asm_function BigMatmulFloatNeon64Opt + sub sp, sp, #224 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + stp x23, x24, [sp, #160] + stp x25, x26, [sp, #176] + stp x27, x28, [sp, #192] + stp x29, x30, [sp, #208] + + ldr x8, [sp, #224] + mov x20, #1 + mov x22, #32 + mov x23, #48 + mul x26, x5, x23 // stride for lhs + mul x24, x8, x23 // stride for out + lsl x27, x23, #9 // stride by depth for lhs + lsl x28, x22, #9 // stride by depth for rhs + lsl x22, x5, #5 // stride for rhs + lsl x8, x8, #2 + subs x5, x5, #512 + ble DepthTail +Depth512: + mov x25, x0 // restore lhs + mov x13, x2 // out + mov x10, x6 // restore row + RowStart: + mov x12, x1 // rhs + mov x14, x13 // out + mov x15, x3 // restore bias + mov x9, x7 // restore col + cmp x10, #4 + ble LoopRow4 + cmp x10, #8 + ble LoopRow8 + + LoopRow12: + mov x11, x25 // lhs + mov x23, x12 // rhs + mov x21, #512 // depth unit + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + ld1 {v0.4s}, [x11], #16 + cmp x9, #4 + ble LoopCol12x4 + + LoopCol12x8: + cbz x20, Reload12x8 + cbnz x15, InitFromBias12x8 + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + dup v12.2d, xzr + dup v13.2d, xzr + dup v14.2d, xzr + dup v15.2d, xzr + dup v16.2d, xzr + dup v17.2d, xzr + dup v18.2d, xzr + dup v19.2d, xzr + dup v20.2d, xzr + dup v21.2d, xzr + dup v22.2d, xzr + dup v23.2d, xzr + dup v24.2d, xzr + dup v25.2d, xzr + dup v26.2d, xzr + dup v27.2d, xzr + dup v28.2d, xzr + dup v29.2d, xzr + dup v30.2d, xzr + dup v31.2d, xzr + b Compute12x8Enter + InitFromBias12x8: + ld1 {v8.4s, v9.4s}, [x15] + ld1 {v10.4s, v11.4s}, [x15] + ld1 {v12.4s, v13.4s}, [x15] + ld1 {v14.4s, v15.4s}, [x15] + ld1 {v16.4s, v17.4s}, [x15] + ld1 {v18.4s, v19.4s}, [x15] + ld1 {v20.4s, v21.4s}, [x15] + ld1 {v22.4s, v23.4s}, [x15] + ld1 {v24.4s, v25.4s}, [x15] + ld1 {v26.4s, v27.4s}, [x15] + ld1 {v28.4s, v29.4s}, [x15] + ld1 {v30.4s, v31.4s}, [x15] + add x15, x15, #32 + b Compute12x8Enter + Reload12x8: + bl Reload + Compute12x8Enter: + cbz x21, Write + bl Compute12x8Unit + b Write + + LoopCol12x4: + cbz x20, Reload12x4 + cbnz x15, InitFromBias12x4 + dup v8.2d, xzr + dup v10.2d, xzr + dup v12.2d, xzr + dup v14.2d, xzr + dup v16.2d, xzr + dup v18.2d, xzr + dup v20.2d, xzr + dup v22.2d, xzr + dup v24.2d, xzr + dup v26.2d, xzr + dup v28.2d, xzr + dup v30.2d, xzr + b Compute12x4Enter + InitFromBias12x4: + ld1 {v8.4s}, [x15] + ld1 {v10.4s}, [x15] + ld1 {v12.4s}, [x15] + ld1 {v14.4s}, [x15] + ld1 {v16.4s}, [x15] + ld1 {v18.4s}, [x15] + ld1 {v20.4s}, [x15] + ld1 {v22.4s}, [x15] + ld1 {v24.4s}, [x15] + ld1 {v26.4s}, [x15] + ld1 {v28.4s}, [x15] + ld1 {v30.4s}, [x15] + b Compute12x4Enter + Reload12x4: + bl Reload + Compute12x4Enter: + cbz x21, Write + bl Compute12x4Unit + b Write + + LoopRow8: + mov x11, x25 // lhs + mov x23, x12 // rhs + mov x21, #512 // depth unit + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + ld1 {v0.4s}, [x11], #16 + cmp x9, #4 + ble LoopCol8x4 + + LoopCol8x8: + cbz x20, Reload8x8 + cbnz x15, InitFromBias8x8 + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + dup v12.2d, xzr + dup v13.2d, xzr + dup v14.2d, xzr + dup v15.2d, xzr + dup v16.2d, xzr + dup v17.2d, xzr + dup v18.2d, xzr + dup v19.2d, xzr + dup v20.2d, xzr + dup v21.2d, xzr + dup v22.2d, xzr + dup v23.2d, xzr + b Compute8x8Enter + InitFromBias8x8: + ld1 {v8.4s, v9.4s}, [x15] + ld1 {v10.4s, v11.4s}, [x15] + ld1 {v12.4s, v13.4s}, [x15] + ld1 {v14.4s, v15.4s}, [x15] + ld1 {v16.4s, v17.4s}, [x15] + ld1 {v18.4s, v19.4s}, [x15] + ld1 {v20.4s, v21.4s}, [x15] + ld1 {v22.4s, v23.4s}, [x15] + add x15, x15, #32 + b Compute8x8Enter + Reload8x8: + bl Reload + Compute8x8Enter: + cbz x21, Write + bl Compute8x8Unit + b Write + + LoopCol8x4: + cbz x20, Reload8x4 + cbnz x15, InitFromBias8x4 + dup v8.2d, xzr + dup v10.2d, xzr + dup v12.2d, xzr + dup v14.2d, xzr + dup v16.2d, xzr + dup v18.2d, xzr + dup v20.2d, xzr + dup v22.2d, xzr + b Compute8x4Enter + InitFromBias8x4: + ld1 {v8.4s}, [x15] + ld1 {v10.4s}, [x15] + ld1 {v12.4s}, [x15] + ld1 {v14.4s}, [x15] + ld1 {v16.4s}, [x15] + ld1 {v18.4s}, [x15] + ld1 {v20.4s}, [x15] + ld1 {v22.4s}, [x15] + b Compute8x4Enter + Reload8x4: + bl Reload + Compute8x4Enter: + cbz x21, Write + bl Compute8x4Unit + b Write + + LoopRow4: + mov x11, x25 // lhs + mov x23, x12 // rhs + mov x21, #512 // depth unit + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + ld1 {v0.4s}, [x11], #16 + cmp x9, #4 + ble LoopCol4x4 + + LoopCol4x8: + cbz x20, Reload4x8 + cbnz x15, InitFromBias4x8 + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + dup v12.2d, xzr + dup v13.2d, xzr + dup v14.2d, xzr + dup v15.2d, xzr + b Compute4x8Enter + InitFromBias4x8: + ld1 {v8.4s, v9.4s}, [x15] + ld1 {v10.4s, v11.4s}, [x15] + ld1 {v12.4s, v13.4s}, [x15] + ld1 {v14.4s, v15.4s}, [x15] + add x15, x15, #32 + b Compute4x8Enter + Reload4x8: + bl Reload + Compute4x8Enter: + cbz x21, Write + bl Compute4x8Unit + b Write + + LoopCol4x4: + cbz x20, Reload4x4 + cbnz x15, InitFromBias4x4 + dup v8.2d, xzr + dup v10.2d, xzr + dup v12.2d, xzr + dup v14.2d, xzr + b Compute4x4Enter + InitFromBias4x4: + ld1 {v8.4s}, [x15] + ld1 {v10.4s}, [x15] + ld1 {v12.4s}, [x15] + ld1 {v14.4s}, [x15] + b Compute4x4Enter + Reload4x4: + bl Reload + Compute4x4Enter: + cbz x21, Write + bl Compute4x4Unit + +Write: + mov x21, x14 + cmp x9, #1 + beq Write1 + cmp x9, #2 + beq Write2 + cmp x9, #3 + beq Write3 + cmp x9, #4 + beq Write4 + cmp x9, #5 + beq Write5 + cmp x9, #6 + beq Write6 + cmp x9, #7 + beq Write7 + b Write8 + + Write1: + str s8, [x21] + cmp x10, #1 + beq LoopCol + add x21, x21, x8 + str s10, [x21] + cmp x10, #2 + beq LoopCol + add x21, x21, x8 + str s12, [x21] + cmp x10, #3 + beq LoopCol + add x21, x21, x8 + str s14, [x21] + cmp x10, #4 + beq LoopCol + add x21, x21, x8 + str s16, [x21] + cmp x10, #5 + beq LoopCol + add x21, x21, x8 + str s18, [x21] + cmp x10, #6 + beq LoopCol + add x21, x21, x8 + str s20, [x21] + cmp x10, #7 + beq LoopCol + add x21, x21, x8 + str s22, [x21] + cmp x10, #8 + beq LoopCol + add x21, x21, x8 + str s24, [x21] + cmp x10, #9 + beq LoopCol + add x21, x21, x8 + str s26, [x21] + cmp x10, #10 + beq LoopCol + add x21, x21, x8 + str s28, [x21] + cmp x10, #11 + beq LoopCol + add x21, x21, x8 + str s30, [x21] + b LoopCol + Write2: + st1 {v8.2s}, [x21], x8 + cmp x10, #1 + beq LoopCol + st1 {v10.2s}, [x21], x8 + cmp x10, #2 + beq LoopCol + st1 {v12.2s}, [x21], x8 + cmp x10, #3 + beq LoopCol + st1 {v14.2s}, [x21], x8 + cmp x10, #4 + beq LoopCol + st1 {v16.2s}, [x21], x8 + cmp x10, #5 + beq LoopCol + st1 {v18.2s}, [x21], x8 + cmp x10, #6 + beq LoopCol + st1 {v20.2s}, [x21], x8 + cmp x10, #7 + beq LoopCol + st1 {v22.2s}, [x21], x8 + cmp x10, #8 + beq LoopCol + st1 {v24.2s}, [x21], x8 + cmp x10, #9 + beq LoopCol + st1 {v26.2s}, [x21], x8 + cmp x10, #10 + beq LoopCol + st1 {v28.2s}, [x21], x8 + cmp x10, #11 + beq LoopCol + st1 {v30.2s}, [x21], x8 + add x21, x21, #8 + b LoopCol + Write3: + add x11, x21, #8 + st1 {v8.2s}, [x21], x8 + st1 {v8.s}[2], [x11], x8 + cmp x10, #1 + beq LoopCol + st1 {v10.2s}, [x21], x8 + st1 {v10.s}[2], [x11], x8 + cmp x10, #2 + beq LoopCol + st1 {v12.2s}, [x21], x8 + st1 {v12.s}[2], [x11], x8 + cmp x10, #3 + beq LoopCol + st1 {v14.2s}, [x21], x8 + st1 {v14.s}[2], [x11], x8 + cmp x10, #4 + beq LoopCol + st1 {v16.2s}, [x21], x8 + st1 {v16.s}[2], [x11], x8 + cmp x10, #5 + beq LoopCol + st1 {v18.2s}, [x21], x8 + st1 {v18.s}[2], [x11], x8 + cmp x10, #6 + beq LoopCol + st1 {v20.2s}, [x21], x8 + st1 {v20.s}[2], [x11], x8 + cmp x10, #7 + beq LoopCol + st1 {v22.2s}, [x21], x8 + st1 {v22.s}[2], [x11], x8 + cmp x10, #8 + beq LoopCol + st1 {v24.2s}, [x21], x8 + st1 {v24.s}[2], [x11], x8 + cmp x10, #9 + beq LoopCol + st1 {v26.2s}, [x21], x8 + st1 {v26.s}[2], [x11], x8 + cmp x10, #10 + beq LoopCol + st1 {v28.2s}, [x21], x8 + st1 {v28.s}[2], [x11], x8 + cmp x10, #11 + beq LoopCol + st1 {v30.2s}, [x21], x8 + st1 {v30.s}[2], [x11] + add x21, x21, #12 + b LoopCol + Write4: + st1 {v8.4s}, [x21], x8 + cmp x10, #1 + beq LoopCol + st1 {v10.4s}, [x21], x8 + cmp x10, #2 + beq LoopCol + st1 {v12.4s}, [x21], x8 + cmp x10, #3 + beq LoopCol + st1 {v14.4s}, [x21], x8 + cmp x10, #4 + beq LoopCol + st1 {v16.4s}, [x21], x8 + cmp x10, #5 + beq LoopCol + st1 {v18.4s}, [x21], x8 + cmp x10, #6 + beq LoopCol + st1 {v20.4s}, [x21], x8 + cmp x10, #7 + beq LoopCol + st1 {v22.4s}, [x21], x8 + cmp x10, #8 + beq LoopCol + st1 {v24.4s}, [x21], x8 + cmp x10, #9 + beq LoopCol + st1 {v26.4s}, [x21], x8 + cmp x10, #10 + beq LoopCol + st1 {v28.4s}, [x21], x8 + cmp x10, #11 + beq LoopCol + st1 {v30.4s}, [x21], x8 + add x21, x21, #16 + b LoopCol + Write5: + add x11, x21, #16 + st1 {v8.4s}, [x21], x8 + str s9, [x11] + cmp x10, #1 + beq LoopCol + add x11, x11, x8 + st1 {v10.4s}, [x21], x8 + str s11, [x11] + cmp x10, #2 + beq LoopCol + add x11, x11, x8 + st1 {v12.4s}, [x21], x8 + str s13, [x11] + cmp x10, #3 + beq LoopCol + add x11, x11, x8 + st1 {v14.4s}, [x21], x8 + str s15, [x11] + cmp x10, #4 + beq LoopCol + add x11, x11, x8 + st1 {v16.4s}, [x21], x8 + str s17, [x11] + cmp x10, #5 + beq LoopCol + add x11, x11, x8 + st1 {v18.4s}, [x21], x8 + str s19, [x11] + cmp x10, #6 + beq LoopCol + add x11, x11, x8 + st1 {v20.4s}, [x21], x8 + str s21, [x11] + cmp x10, #7 + beq LoopCol + add x11, x11, x8 + st1 {v22.4s}, [x21], x8 + str s23, [x11] + cmp x10, #8 + beq LoopCol + add x11, x11, x8 + st1 {v24.4s}, [x21], x8 + str s25, [x11] + cmp x10, #9 + beq LoopCol + add x11, x11, x8 + st1 {v26.4s}, [x21], x8 + str s27, [x11] + cmp x10, #10 + beq LoopCol + add x11, x11, x8 + st1 {v28.4s}, [x21], x8 + str s29, [x11] + cmp x10, #11 + beq LoopCol + add x11, x11, x8 + st1 {v30.4s}, [x21], x8 + str s31, [x11] + add x21, x21, #20 + b LoopCol + Write6: + add x11, x21, #16 + st1 {v8.4s}, [x21], x8 + st1 {v9.2s}, [x11], x8 + cmp x10, #1 + beq LoopCol + st1 {v10.4s}, [x21], x8 + st1 {v11.2s}, [x11], x8 + cmp x10, #2 + beq LoopCol + st1 {v12.4s}, [x21], x8 + st1 {v13.2s}, [x11], x8 + cmp x10, #3 + beq LoopCol + st1 {v14.4s}, [x21], x8 + st1 {v15.2s}, [x11], x8 + cmp x10, #4 + beq LoopCol + st1 {v16.4s}, [x21], x8 + st1 {v17.2s}, [x11], x8 + cmp x10, #5 + beq LoopCol + st1 {v18.4s}, [x21], x8 + st1 {v19.2s}, [x11], x8 + cmp x10, #6 + beq LoopCol + st1 {v20.4s}, [x21], x8 + st1 {v21.2s}, [x11], x8 + cmp x10, #7 + beq LoopCol + st1 {v22.4s}, [x21], x8 + st1 {v23.2s}, [x11], x8 + cmp x10, #8 + beq LoopCol + st1 {v24.4s}, [x21], x8 + st1 {v25.2s}, [x11], x8 + cmp x10, #9 + beq LoopCol + st1 {v26.4s}, [x21], x8 + st1 {v27.2s}, [x11], x8 + cmp x10, #10 + beq LoopCol + st1 {v28.4s}, [x21], x8 + st1 {v29.2s}, [x11], x8 + cmp x10, #11 + beq LoopCol + st1 {v30.4s}, [x21], x8 + st1 {v31.2s}, [x11] + add x21, x21, #24 + b LoopCol + Write7: + add x11, x21, #16 + add x23, x21, #24 + st1 {v8.4s}, [x21], x8 + st1 {v9.2s}, [x11], x8 + st1 {v9.s}[2], [x23], x8 + cmp x10, #1 + beq LoopCol + st1 {v10.4s}, [x21], x8 + st1 {v11.2s}, [x11], x8 + st1 {v11.s}[2], [x23], x8 + cmp x10, #2 + beq LoopCol + st1 {v12.4s}, [x21], x8 + st1 {v13.2s}, [x11], x8 + st1 {v13.s}[2], [x23], x8 + cmp x10, #3 + beq LoopCol + st1 {v14.4s}, [x21], x8 + st1 {v15.2s}, [x11], x8 + st1 {v15.s}[2], [x23], x8 + cmp x10, #4 + beq LoopCol + st1 {v16.4s}, [x21], x8 + st1 {v17.2s}, [x11], x8 + st1 {v17.s}[2], [x23], x8 + cmp x10, #5 + beq LoopCol + st1 {v18.4s}, [x21], x8 + st1 {v19.2s}, [x11], x8 + st1 {v19.s}[2], [x23], x8 + cmp x10, #6 + beq LoopCol + st1 {v20.4s}, [x21], x8 + st1 {v21.2s}, [x11], x8 + st1 {v21.s}[2], [x23], x8 + cmp x10, #7 + beq LoopCol + st1 {v22.4s}, [x21], x8 + st1 {v23.2s}, [x11], x8 + st1 {v23.s}[2], [x23], x8 + cmp x10, #8 + beq LoopCol + st1 {v24.4s}, [x21], x8 + st1 {v25.2s}, [x11], x8 + st1 {v25.s}[2], [x23], x8 + cmp x10, #9 + beq LoopCol + st1 {v26.4s}, [x21], x8 + st1 {v27.2s}, [x11], x8 + st1 {v27.s}[2], [x23], x8 + cmp x10, #10 + beq LoopCol + st1 {v28.4s}, [x21], x8 + st1 {v29.2s}, [x11], x8 + st1 {v29.s}[2], [x23], x8 + cmp x10, #11 + beq LoopCol + st1 {v30.4s}, [x21], x8 + st1 {v31.2s}, [x11] + st1 {v31.s}[2], [x23] + add x21, x21, #28 + b LoopCol + + Write8: + st1 {v8.4s, v9.4s}, [x21], x8 + cmp x10, #1 + beq LoopCol + st1 {v10.4s, v11.4s}, [x21], x8 + cmp x10, #2 + beq LoopCol + st1 {v12.4s, v13.4s}, [x21], x8 + cmp x10, #3 + beq LoopCol + st1 {v14.4s, v15.4s}, [x21], x8 + cmp x10, #4 + beq LoopCol + st1 {v16.4s, v17.4s}, [x21], x8 + cmp x10, #5 + beq LoopCol + st1 {v18.4s, v19.4s}, [x21], x8 + cmp x10, #6 + beq LoopCol + st1 {v20.4s, v21.4s}, [x21], x8 + cmp x10, #7 + beq LoopCol + st1 {v22.4s, v23.4s}, [x21], x8 + cmp x10, #8 + beq LoopCol + st1 {v24.4s, v25.4s}, [x21], x8 + cmp x10, #9 + beq LoopCol + st1 {v26.4s, v27.4s}, [x21], x8 + cmp x10, #10 + beq LoopCol + st1 {v28.4s, v29.4s}, [x21], x8 + cmp x10, #11 + beq LoopCol + st1 {v30.4s, v31.4s}, [x21], x8 + add x21, x21, #32 + b LoopCol + +LoopCol: + subs x9, x9, #8 + ble LoopColEnd + add x12, x12, x22 // update rhs + add x14, x14, #32 // update out + cmp x10, #4 + ble LoopRow4 + cmp x10, #8 + ble LoopRow8 + b LoopRow12 + +LoopColEnd: + add x25, x25, x26 // update lhs + add x13, x13, x24 // update out + subs x10, x10, #12 // update row + bgt RowStart + mov x20, #0 + add x0, x0, x27 // update lhs by depth + add x1, x1, x28 // update rhs by depth + subs x5, x5, #512 + bgt Depth512 + +/////////////////////////////////////////////////////// + +DepthTail: + add x5, x5, #512 + mov x13, x2 // out + mov x10, x6 + TailRowStart: + mov x12, x1 // rhs + mov x14, x13 // out + mov x15, x3 // restore bias + mov x9, x7 // restore col + cmp x10, #4 + ble LoopTailRow4 + cmp x10, #8 + ble LoopTailRow8 + + LoopTailRow12: + mov x11, x0 // lhs + mov x23, x12 // rhs + mov x21, x5 // depth unit + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + ld1 {v0.4s}, [x11], #16 + cmp x9, #4 + ble LoopTailCol12x4 + + LoopTailCol12x8: + cbz x20, ReloadTail12x8 + cbnz x15, InitTailFromBias12x8 + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + dup v12.2d, xzr + dup v13.2d, xzr + dup v14.2d, xzr + dup v15.2d, xzr + dup v16.2d, xzr + dup v17.2d, xzr + dup v18.2d, xzr + dup v19.2d, xzr + dup v20.2d, xzr + dup v21.2d, xzr + dup v22.2d, xzr + dup v23.2d, xzr + dup v24.2d, xzr + dup v25.2d, xzr + dup v26.2d, xzr + dup v27.2d, xzr + dup v28.2d, xzr + dup v29.2d, xzr + dup v30.2d, xzr + dup v31.2d, xzr + b ComputeTail12x8Enter + InitTailFromBias12x8: + ld1 {v8.4s, v9.4s}, [x15] + ld1 {v10.4s, v11.4s}, [x15] + ld1 {v12.4s, v13.4s}, [x15] + ld1 {v14.4s, v15.4s}, [x15] + ld1 {v16.4s, v17.4s}, [x15] + ld1 {v18.4s, v19.4s}, [x15] + ld1 {v20.4s, v21.4s}, [x15] + ld1 {v22.4s, v23.4s}, [x15] + ld1 {v24.4s, v25.4s}, [x15] + ld1 {v26.4s, v27.4s}, [x15] + ld1 {v28.4s, v29.4s}, [x15] + ld1 {v30.4s, v31.4s}, [x15] + add x15, x15, #32 + b ComputeTail12x8Enter + ReloadTail12x8: + bl Reload + ComputeTail12x8Enter: + cbz x21, Activation12x8 + bl Compute12x8Unit + Activation12x8: + cmp x4, #3 + beq Relu612x8 + cmp x4, #1 + beq Relu12x8 + b WriteTail + + Relu612x8: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v17.4s, v17.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v19.4s, v19.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v21.4s, v21.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + fmin v23.4s, v23.4s, v2.4s + fmin v24.4s, v24.4s, v2.4s + fmin v25.4s, v25.4s, v2.4s + fmin v26.4s, v26.4s, v2.4s + fmin v27.4s, v27.4s, v2.4s + fmin v28.4s, v28.4s, v2.4s + fmin v29.4s, v29.4s, v2.4s + fmin v30.4s, v30.4s, v2.4s + fmin v31.4s, v31.4s, v2.4s + + Relu12x8: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v9.4s, v9.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v17.4s, v17.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v19.4s, v19.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v21.4s, v21.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + fmax v23.4s, v23.4s, v3.4s + fmax v24.4s, v24.4s, v3.4s + fmax v25.4s, v25.4s, v3.4s + fmax v26.4s, v26.4s, v3.4s + fmax v27.4s, v27.4s, v3.4s + fmax v28.4s, v28.4s, v3.4s + fmax v29.4s, v29.4s, v3.4s + fmax v30.4s, v30.4s, v3.4s + fmax v31.4s, v31.4s, v3.4s + b WriteTail + + LoopTailCol12x4: + cbz x20, ReloadTail12x4 + cbnz x15, InitTailFromBias12x4 + dup v8.2d, xzr + dup v10.2d, xzr + dup v12.2d, xzr + dup v14.2d, xzr + dup v16.2d, xzr + dup v18.2d, xzr + dup v20.2d, xzr + dup v22.2d, xzr + dup v24.2d, xzr + dup v26.2d, xzr + dup v28.2d, xzr + dup v30.2d, xzr + b ComputeTail12x4Enter + InitTailFromBias12x4: + ld1 {v8.4s}, [x15] + ld1 {v10.4s}, [x15] + ld1 {v12.4s}, [x15] + ld1 {v14.4s}, [x15] + ld1 {v16.4s}, [x15] + ld1 {v18.4s}, [x15] + ld1 {v20.4s}, [x15] + ld1 {v22.4s}, [x15] + ld1 {v24.4s}, [x15] + ld1 {v26.4s}, [x15] + ld1 {v28.4s}, [x15] + ld1 {v30.4s}, [x15] + b ComputeTail12x4Enter + ReloadTail12x4: + bl Reload + ComputeTail12x4Enter: + cbz x21, Activation12x4 + bl Compute12x4Unit + Activation12x4: + cmp x4, #3 + beq Relu612x4 + cmp x4, #1 + beq Relu12x4 + b WriteTail + + Relu612x4: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + fmin v24.4s, v24.4s, v2.4s + fmin v26.4s, v26.4s, v2.4s + fmin v28.4s, v28.4s, v2.4s + fmin v30.4s, v30.4s, v2.4s + + Relu12x4: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + fmax v24.4s, v24.4s, v3.4s + fmax v26.4s, v26.4s, v3.4s + fmax v28.4s, v28.4s, v3.4s + fmax v30.4s, v30.4s, v3.4s + b WriteTail + + LoopTailRow8: + mov x11, x0 // lhs + mov x23, x12 // rhs + mov x21, x5 // depth unit + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + ld1 {v0.4s}, [x11], #16 + cmp x9, #4 + ble LoopTailCol8x4 + + LoopTailCol8x8: + cbz x20, ReloadTail8x8 + cbnz x15, InitTailFromBias8x8 + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + dup v12.2d, xzr + dup v13.2d, xzr + dup v14.2d, xzr + dup v15.2d, xzr + dup v16.2d, xzr + dup v17.2d, xzr + dup v18.2d, xzr + dup v19.2d, xzr + dup v20.2d, xzr + dup v21.2d, xzr + dup v22.2d, xzr + dup v23.2d, xzr + b ComputeTail8x8Enter + InitTailFromBias8x8: + ld1 {v8.4s, v9.4s}, [x15] + ld1 {v10.4s, v11.4s}, [x15] + ld1 {v12.4s, v13.4s}, [x15] + ld1 {v14.4s, v15.4s}, [x15] + ld1 {v16.4s, v17.4s}, [x15] + ld1 {v18.4s, v19.4s}, [x15] + ld1 {v20.4s, v21.4s}, [x15] + ld1 {v22.4s, v23.4s}, [x15] + add x15, x15, #32 + b ComputeTail8x8Enter + ReloadTail8x8: + bl Reload + ComputeTail8x8Enter: + cbz x21, Activation8x8 + bl Compute8x8Unit + Activation8x8: + cmp x4, #3 + beq Relu68x8 + cmp x4, #1 + beq Relu8x8 + b WriteTail + + Relu68x8: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v17.4s, v17.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v19.4s, v19.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v21.4s, v21.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + fmin v23.4s, v23.4s, v2.4s + + Relu8x8: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v9.4s, v9.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v17.4s, v17.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v19.4s, v19.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v21.4s, v21.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + fmax v23.4s, v23.4s, v3.4s + b WriteTail + + LoopTailCol8x4: + cbz x20, ReloadTail8x4 + cbnz x15, InitTailFromBias8x4 + dup v8.2d, xzr + dup v10.2d, xzr + dup v12.2d, xzr + dup v14.2d, xzr + dup v16.2d, xzr + dup v18.2d, xzr + dup v20.2d, xzr + dup v22.2d, xzr + b ComputeTail8x4Enter + InitTailFromBias8x4: + ld1 {v8.4s}, [x15] + ld1 {v10.4s}, [x15] + ld1 {v12.4s}, [x15] + ld1 {v14.4s}, [x15] + ld1 {v16.4s}, [x15] + ld1 {v18.4s}, [x15] + ld1 {v20.4s}, [x15] + ld1 {v22.4s}, [x15] + b ComputeTail8x4Enter + ReloadTail8x4: + bl Reload + ComputeTail8x4Enter: + cbz x21, Activation8x4 + bl Compute8x4Unit + Activation8x4: + cmp x4, #3 + beq Relu68x4 + cmp x4, #1 + beq Relu8x4 + b WriteTail + + Relu68x4: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + + Relu8x4: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + b WriteTail + + LoopTailRow4: + mov x11, x0 // lhs + mov x23, x12 // rhs + mov x21, x5 // depth unit + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + ld1 {v0.4s}, [x11], #16 + cmp x9, #4 + ble LoopTailCol4x4 + + LoopTailCol4x8: + cbz x20, ReloadTail4x8 + cbnz x15, InitTailFromBias4x8 + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + dup v12.2d, xzr + dup v13.2d, xzr + dup v14.2d, xzr + dup v15.2d, xzr + b ComputeTail4x8Enter + InitTailFromBias4x8: + ld1 {v8.4s, v9.4s}, [x15] + ld1 {v10.4s, v11.4s}, [x15] + ld1 {v12.4s, v13.4s}, [x15] + ld1 {v14.4s, v15.4s}, [x15] + add x15, x15, #32 + b ComputeTail4x8Enter + ReloadTail4x8: + bl Reload + ComputeTail4x8Enter: + cbz x21, Activation4x8 + bl Compute4x8Unit + Activation4x8: + cmp x4, #3 + beq Relu64x8 + cmp x4, #1 + beq Relu4x8 + b WriteTail + + Relu64x8: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + + Relu4x8: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v9.4s, v9.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + b WriteTail + + LoopTailCol4x4: + cbz x20, ReloadTail4x4 + cbnz x15, InitTailFromBias4x4 + dup v8.2d, xzr + dup v10.2d, xzr + dup v12.2d, xzr + dup v14.2d, xzr + b ComputeTail4x4Enter + InitTailFromBias4x4: + ld1 {v8.4s}, [x15] + ld1 {v10.4s}, [x15] + ld1 {v12.4s}, [x15] + ld1 {v14.4s}, [x15] + b ComputeTail4x4Enter + ReloadTail4x4: + bl Reload + ComputeTail4x4Enter: + cbz x21, Activation4x4 + bl Compute4x4Unit + Activation4x4: + cmp x4, #3 + beq Relu64x4 + cmp x4, #1 + beq Relu4x4 + b WriteTail + + Relu64x4: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + + Relu4x4: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + +WriteTail: + mov x21, x14 + cmp x9, #1 + beq WriteTail1 + cmp x9, #2 + beq WriteTail2 + cmp x9, #3 + beq WriteTail3 + cmp x9, #4 + beq WriteTail4 + cmp x9, #5 + beq WriteTail5 + cmp x9, #6 + beq WriteTail6 + cmp x9, #7 + beq WriteTail7 + b WriteTail8 + + WriteTail1: + str s8, [x21] + cmp x10, #1 + beq LoopTailCol + add x21, x21, x8 + str s10, [x21] + cmp x10, #2 + beq LoopTailCol + add x21, x21, x8 + str s12, [x21] + cmp x10, #3 + beq LoopTailCol + add x21, x21, x8 + str s14, [x21] + cmp x10, #4 + beq LoopTailCol + add x21, x21, x8 + str s16, [x21] + cmp x10, #5 + beq LoopTailCol + add x21, x21, x8 + str s18, [x21] + cmp x10, #6 + beq LoopTailCol + add x21, x21, x8 + str s20, [x21] + cmp x10, #7 + beq LoopTailCol + add x21, x21, x8 + str s22, [x21] + cmp x10, #8 + beq LoopTailCol + add x21, x21, x8 + str s24, [x21] + cmp x10, #9 + beq LoopTailCol + add x21, x21, x8 + str s26, [x21] + cmp x10, #10 + beq LoopTailCol + add x21, x21, x8 + str s28, [x21] + cmp x10, #11 + beq LoopTailCol + add x21, x21, x8 + str s30, [x21] + b LoopTailCol + WriteTail2: + st1 {v8.2s}, [x21], x8 + cmp x10, #1 + beq LoopTailCol + st1 {v10.2s}, [x21], x8 + cmp x10, #2 + beq LoopTailCol + st1 {v12.2s}, [x21], x8 + cmp x10, #3 + beq LoopTailCol + st1 {v14.2s}, [x21], x8 + cmp x10, #4 + beq LoopTailCol + st1 {v16.2s}, [x21], x8 + cmp x10, #5 + beq LoopTailCol + st1 {v18.2s}, [x21], x8 + cmp x10, #6 + beq LoopTailCol + st1 {v20.2s}, [x21], x8 + cmp x10, #7 + beq LoopTailCol + st1 {v22.2s}, [x21], x8 + cmp x10, #8 + beq LoopTailCol + st1 {v24.2s}, [x21], x8 + cmp x10, #9 + beq LoopTailCol + st1 {v26.2s}, [x21], x8 + cmp x10, #10 + beq LoopTailCol + st1 {v28.2s}, [x21], x8 + cmp x10, #11 + beq LoopTailCol + st1 {v30.2s}, [x21], x8 + add x21, x21, #8 + b LoopTailCol + WriteTail3: + add x11, x21, #8 + st1 {v8.2s}, [x21], x8 + st1 {v8.s}[2], [x11], x8 + cmp x10, #1 + beq LoopTailCol + st1 {v10.2s}, [x21], x8 + st1 {v10.s}[2], [x11], x8 + cmp x10, #2 + beq LoopTailCol + st1 {v12.2s}, [x21], x8 + st1 {v12.s}[2], [x11], x8 + cmp x10, #3 + beq LoopTailCol + st1 {v14.2s}, [x21], x8 + st1 {v14.s}[2], [x11], x8 + cmp x10, #4 + beq LoopTailCol + st1 {v16.2s}, [x21], x8 + st1 {v16.s}[2], [x11], x8 + cmp x10, #5 + beq LoopTailCol + st1 {v18.2s}, [x21], x8 + st1 {v18.s}[2], [x11], x8 + cmp x10, #6 + beq LoopTailCol + st1 {v20.2s}, [x21], x8 + st1 {v20.s}[2], [x11], x8 + cmp x10, #7 + beq LoopTailCol + st1 {v22.2s}, [x21], x8 + st1 {v22.s}[2], [x11], x8 + cmp x10, #8 + beq LoopTailCol + st1 {v24.2s}, [x21], x8 + st1 {v24.s}[2], [x11], x8 + cmp x10, #9 + beq LoopTailCol + st1 {v26.2s}, [x21], x8 + st1 {v26.s}[2], [x11], x8 + cmp x10, #10 + beq LoopTailCol + st1 {v28.2s}, [x21], x8 + st1 {v28.s}[2], [x11], x8 + cmp x10, #11 + beq LoopTailCol + st1 {v30.2s}, [x21], x8 + st1 {v30.s}[2], [x11] + add x21, x21, #12 + b LoopTailCol + WriteTail4: + st1 {v8.4s}, [x21], x8 + cmp x10, #1 + beq LoopTailCol + st1 {v10.4s}, [x21], x8 + cmp x10, #2 + beq LoopTailCol + st1 {v12.4s}, [x21], x8 + cmp x10, #3 + beq LoopTailCol + st1 {v14.4s}, [x21], x8 + cmp x10, #4 + beq LoopTailCol + st1 {v16.4s}, [x21], x8 + cmp x10, #5 + beq LoopTailCol + st1 {v18.4s}, [x21], x8 + cmp x10, #6 + beq LoopTailCol + st1 {v20.4s}, [x21], x8 + cmp x10, #7 + beq LoopTailCol + st1 {v22.4s}, [x21], x8 + cmp x10, #8 + beq LoopTailCol + st1 {v24.4s}, [x21], x8 + cmp x10, #9 + beq LoopTailCol + st1 {v26.4s}, [x21], x8 + cmp x10, #10 + beq LoopTailCol + st1 {v28.4s}, [x21], x8 + cmp x10, #11 + beq LoopTailCol + st1 {v30.4s}, [x21], x8 + add x21, x21, #16 + b LoopTailCol + WriteTail5: + add x11, x21, #16 + st1 {v8.4s}, [x21], x8 + str s9, [x11] + cmp x10, #1 + beq LoopTailCol + add x11, x11, x8 + st1 {v10.4s}, [x21], x8 + str s11, [x11] + cmp x10, #2 + beq LoopTailCol + add x11, x11, x8 + st1 {v12.4s}, [x21], x8 + str s13, [x11] + cmp x10, #3 + beq LoopTailCol + add x11, x11, x8 + st1 {v14.4s}, [x21], x8 + str s15, [x11] + cmp x10, #4 + beq LoopTailCol + add x11, x11, x8 + st1 {v16.4s}, [x21], x8 + str s17, [x11] + cmp x10, #5 + beq LoopTailCol + add x11, x11, x8 + st1 {v18.4s}, [x21], x8 + str s19, [x11] + cmp x10, #6 + beq LoopTailCol + add x11, x11, x8 + st1 {v20.4s}, [x21], x8 + str s21, [x11] + cmp x10, #7 + beq LoopTailCol + add x11, x11, x8 + st1 {v22.4s}, [x21], x8 + str s23, [x11] + cmp x10, #8 + beq LoopTailCol + add x11, x11, x8 + st1 {v24.4s}, [x21], x8 + str s25, [x11] + cmp x10, #9 + beq LoopTailCol + add x11, x11, x8 + st1 {v26.4s}, [x21], x8 + str s27, [x11] + cmp x10, #10 + beq LoopTailCol + add x11, x11, x8 + st1 {v28.4s}, [x21], x8 + str s29, [x11] + cmp x10, #11 + beq LoopTailCol + add x11, x11, x8 + st1 {v30.4s}, [x21], x8 + str s31, [x11] + add x21, x21, #20 + b LoopTailCol + WriteTail6: + add x11, x21, #16 + st1 {v8.4s}, [x21], x8 + st1 {v9.2s}, [x11], x8 + cmp x10, #1 + beq LoopTailCol + st1 {v10.4s}, [x21], x8 + st1 {v11.2s}, [x11], x8 + cmp x10, #2 + beq LoopTailCol + st1 {v12.4s}, [x21], x8 + st1 {v13.2s}, [x11], x8 + cmp x10, #3 + beq LoopTailCol + st1 {v14.4s}, [x21], x8 + st1 {v15.2s}, [x11], x8 + cmp x10, #4 + beq LoopTailCol + st1 {v16.4s}, [x21], x8 + st1 {v17.2s}, [x11], x8 + cmp x10, #5 + beq LoopTailCol + st1 {v18.4s}, [x21], x8 + st1 {v19.2s}, [x11], x8 + cmp x10, #6 + beq LoopTailCol + st1 {v20.4s}, [x21], x8 + st1 {v21.2s}, [x11], x8 + cmp x10, #7 + beq LoopTailCol + st1 {v22.4s}, [x21], x8 + st1 {v23.2s}, [x11], x8 + cmp x10, #8 + beq LoopTailCol + st1 {v24.4s}, [x21], x8 + st1 {v25.2s}, [x11], x8 + cmp x10, #9 + beq LoopTailCol + st1 {v26.4s}, [x21], x8 + st1 {v27.2s}, [x11], x8 + cmp x10, #10 + beq LoopTailCol + st1 {v28.4s}, [x21], x8 + st1 {v29.2s}, [x11], x8 + cmp x10, #11 + beq LoopTailCol + st1 {v30.4s}, [x21], x8 + st1 {v31.2s}, [x11] + add x21, x21, #24 + b LoopTailCol + WriteTail7: + add x11, x21, #16 + add x23, x21, #24 + st1 {v8.4s}, [x21], x8 + st1 {v9.2s}, [x11], x8 + st1 {v9.s}[2], [x23], x8 + cmp x10, #1 + beq LoopTailCol + st1 {v10.4s}, [x21], x8 + st1 {v11.2s}, [x11], x8 + st1 {v11.s}[2], [x23], x8 + cmp x10, #2 + beq LoopTailCol + st1 {v12.4s}, [x21], x8 + st1 {v13.2s}, [x11], x8 + st1 {v13.s}[2], [x23], x8 + cmp x10, #3 + beq LoopTailCol + st1 {v14.4s}, [x21], x8 + st1 {v15.2s}, [x11], x8 + st1 {v15.s}[2], [x23], x8 + cmp x10, #4 + beq LoopTailCol + st1 {v16.4s}, [x21], x8 + st1 {v17.2s}, [x11], x8 + st1 {v17.s}[2], [x23], x8 + cmp x10, #5 + beq LoopTailCol + st1 {v18.4s}, [x21], x8 + st1 {v19.2s}, [x11], x8 + st1 {v19.s}[2], [x23], x8 + cmp x10, #6 + beq LoopTailCol + st1 {v20.4s}, [x21], x8 + st1 {v21.2s}, [x11], x8 + st1 {v21.s}[2], [x23], x8 + cmp x10, #7 + beq LoopTailCol + st1 {v22.4s}, [x21], x8 + st1 {v23.2s}, [x11], x8 + st1 {v23.s}[2], [x23], x8 + cmp x10, #8 + beq LoopTailCol + st1 {v24.4s}, [x21], x8 + st1 {v25.2s}, [x11], x8 + st1 {v25.s}[2], [x23], x8 + cmp x10, #9 + beq LoopTailCol + st1 {v26.4s}, [x21], x8 + st1 {v27.2s}, [x11], x8 + st1 {v27.s}[2], [x23], x8 + cmp x10, #10 + beq LoopTailCol + st1 {v28.4s}, [x21], x8 + st1 {v29.2s}, [x11], x8 + st1 {v29.s}[2], [x23], x8 + cmp x10, #11 + beq LoopTailCol + st1 {v30.4s}, [x21], x8 + st1 {v31.2s}, [x11] + st1 {v31.s}[2], [x23] + add x21, x21, #28 + b LoopTailCol + + WriteTail8: + st1 {v8.4s, v9.4s}, [x21], x8 + cmp x10, #1 + beq LoopTailCol + st1 {v10.4s, v11.4s}, [x21], x8 + cmp x10, #2 + beq LoopTailCol + st1 {v12.4s, v13.4s}, [x21], x8 + cmp x10, #3 + beq LoopTailCol + st1 {v14.4s, v15.4s}, [x21], x8 + cmp x10, #4 + beq LoopTailCol + st1 {v16.4s, v17.4s}, [x21], x8 + cmp x10, #5 + beq LoopTailCol + st1 {v18.4s, v19.4s}, [x21], x8 + cmp x10, #6 + beq LoopTailCol + st1 {v20.4s, v21.4s}, [x21], x8 + cmp x10, #7 + beq LoopTailCol + st1 {v22.4s, v23.4s}, [x21], x8 + cmp x10, #8 + beq LoopTailCol + st1 {v24.4s, v25.4s}, [x21], x8 + cmp x10, #9 + beq LoopTailCol + st1 {v26.4s, v27.4s}, [x21], x8 + cmp x10, #10 + beq LoopTailCol + st1 {v28.4s, v29.4s}, [x21], x8 + cmp x10, #11 + beq LoopTailCol + st1 {v30.4s, v31.4s}, [x21], x8 + add x21, x21, #32 + b LoopTailCol + +LoopTailCol: + subs x9, x9, #8 + ble LoopTailEnd + add x12, x12, x22 // update rhs + add x14, x14, #32 + cmp x10, #4 + ble LoopTailRow4 + cmp x10, #8 + ble LoopTailRow8 + b LoopTailRow12 + +LoopTailEnd: + add x0, x0, x26 // update lhs + add x13, x13, x24 // update out + subs x10, x10, #12 // update row + bgt TailRowStart + b End + +Reload: + mov x15, x14 + cmp x9, #1 + beq Reload1 + cmp x9, #2 + beq Reload2 + cmp x9, #3 + beq Reload3 + cmp x9, #4 + beq Reload4 + cmp x9, #5 + beq Reload5 + cmp x9, #6 + beq Reload6 + cmp x9, #7 + beq Reload7 + b Reload8 + + Reload1: + ldr s8, [x15] + cmp x10, #1 + beq ReloadEnd + add x15, x15, x8 + ldr s10, [x15] + cmp x10, #2 + beq ReloadEnd + add x15, x15, x8 + ldr s12, [x15] + cmp x10, #3 + beq ReloadEnd + add x15, x15, x8 + ldr s14, [x15] + cmp x10, #4 + beq ReloadEnd + add x15, x15, x8 + ldr s16, [x15] + cmp x10, #5 + beq ReloadEnd + add x15, x15, x8 + ldr s18, [x15] + cmp x10, #6 + beq ReloadEnd + add x15, x15, x8 + ldr s20, [x15] + cmp x10, #7 + beq ReloadEnd + add x15, x15, x8 + ldr s22, [x15] + cmp x10, #8 + beq ReloadEnd + add x15, x15, x8 + ldr s24, [x15] + cmp x10, #9 + beq ReloadEnd + add x15, x15, x8 + ldr s26, [x15] + cmp x10, #10 + beq ReloadEnd + add x15, x15, x8 + ldr s28, [x15] + cmp x10, #11 + beq ReloadEnd + add x15, x15, x8 + ldr s30, [x15] + b ReloadEnd + Reload2: + ld1 {v8.2s}, [x15], x8 + cmp x10, #1 + beq ReloadEnd + ld1 {v10.2s}, [x15], x8 + cmp x10, #2 + beq ReloadEnd + ld1 {v12.2s}, [x15], x8 + cmp x10, #3 + beq ReloadEnd + ld1 {v14.2s}, [x15], x8 + cmp x10, #4 + beq ReloadEnd + ld1 {v16.2s}, [x15], x8 + cmp x10, #5 + beq ReloadEnd + ld1 {v18.2s}, [x15], x8 + cmp x10, #6 + beq ReloadEnd + ld1 {v20.2s}, [x15], x8 + cmp x10, #7 + beq ReloadEnd + ld1 {v22.2s}, [x15], x8 + cmp x10, #8 + beq ReloadEnd + ld1 {v24.2s}, [x15], x8 + cmp x10, #9 + beq ReloadEnd + ld1 {v26.2s}, [x15], x8 + cmp x10, #10 + beq ReloadEnd + ld1 {v28.2s}, [x15], x8 + cmp x10, #11 + beq ReloadEnd + ld1 {v30.2s}, [x15], x8 + add x15, x15, #8 + b ReloadEnd + Reload3: + add x19, x15, #8 + ld1 {v8.2s}, [x15], x8 + ld1 {v8.s}[2], [x19], x8 + cmp x10, #1 + beq ReloadEnd + ld1 {v10.2s}, [x15], x8 + ld1 {v10.s}[2], [x19], x8 + cmp x10, #2 + beq ReloadEnd + ld1 {v12.2s}, [x15], x8 + ld1 {v12.s}[2], [x19], x8 + cmp x10, #3 + beq ReloadEnd + ld1 {v14.2s}, [x15], x8 + ld1 {v14.s}[2], [x19], x8 + cmp x10, #4 + beq ReloadEnd + ld1 {v16.2s}, [x15], x8 + ld1 {v16.s}[2], [x19], x8 + cmp x10, #5 + beq ReloadEnd + ld1 {v18.2s}, [x15], x8 + ld1 {v18.s}[2], [x19], x8 + cmp x10, #6 + beq ReloadEnd + ld1 {v20.2s}, [x15], x8 + ld1 {v20.s}[2], [x19], x8 + cmp x10, #7 + beq ReloadEnd + ld1 {v22.2s}, [x15], x8 + ld1 {v22.s}[2], [x19], x8 + cmp x10, #8 + beq ReloadEnd + ld1 {v24.2s}, [x15], x8 + ld1 {v24.s}[2], [x19], x8 + cmp x10, #9 + beq ReloadEnd + ld1 {v26.2s}, [x15], x8 + ld1 {v26.s}[2], [x19], x8 + cmp x10, #10 + beq ReloadEnd + ld1 {v28.2s}, [x15], x8 + ld1 {v28.s}[2], [x19], x8 + cmp x10, #11 + beq ReloadEnd + ld1 {v30.2s}, [x15], x8 + ld1 {v30.s}[2], [x19] + add x15, x15, #12 + b ReloadEnd + Reload4: + ld1 {v8.4s}, [x15], x8 + cmp x10, #1 + beq ReloadEnd + ld1 {v10.4s}, [x15], x8 + cmp x10, #2 + beq ReloadEnd + ld1 {v12.4s}, [x15], x8 + cmp x10, #3 + beq ReloadEnd + ld1 {v14.4s}, [x15], x8 + cmp x10, #4 + beq ReloadEnd + ld1 {v16.4s}, [x15], x8 + cmp x10, #5 + beq ReloadEnd + ld1 {v18.4s}, [x15], x8 + cmp x10, #6 + beq ReloadEnd + ld1 {v20.4s}, [x15], x8 + cmp x10, #7 + beq ReloadEnd + ld1 {v22.4s}, [x15], x8 + cmp x10, #8 + beq ReloadEnd + ld1 {v24.4s}, [x15], x8 + cmp x10, #9 + beq ReloadEnd + ld1 {v26.4s}, [x15], x8 + cmp x10, #10 + beq ReloadEnd + ld1 {v28.4s}, [x15], x8 + cmp x10, #11 + beq ReloadEnd + ld1 {v30.4s}, [x15], x8 + add x15, x15, #16 + b ReloadEnd + Reload5: + add x19, x15, #16 + ld1 {v8.4s}, [x15], x8 + ldr s9, [x19] + cmp x10, #1 + beq ReloadEnd + add x19, x19, x8 + ld1 {v10.4s}, [x15], x8 + ldr s11, [x19] + cmp x10, #2 + beq ReloadEnd + add x19, x19, x8 + ld1 {v12.4s}, [x15], x8 + ldr s13, [x19] + cmp x10, #3 + beq ReloadEnd + add x19, x19, x8 + ld1 {v14.4s}, [x15], x8 + ldr s15, [x19] + cmp x10, #4 + beq ReloadEnd + add x19, x19, x8 + ld1 {v16.4s}, [x15], x8 + ldr s17, [x19] + cmp x10, #5 + beq ReloadEnd + add x19, x19, x8 + ld1 {v18.4s}, [x15], x8 + ldr s19, [x19] + cmp x10, #6 + beq ReloadEnd + add x19, x19, x8 + ld1 {v20.4s}, [x15], x8 + ldr s21, [x19] + cmp x10, #7 + beq ReloadEnd + add x19, x19, x8 + ld1 {v22.4s}, [x15], x8 + ldr s23, [x19] + cmp x10, #8 + beq ReloadEnd + add x19, x19, x8 + ld1 {v24.4s}, [x15], x8 + ldr s25, [x19] + cmp x10, #9 + beq ReloadEnd + add x19, x19, x8 + ld1 {v26.4s}, [x15], x8 + ldr s27, [x19] + cmp x10, #10 + beq ReloadEnd + add x19, x19, x8 + ld1 {v28.4s}, [x15], x8 + ldr s29, [x19] + cmp x10, #11 + beq ReloadEnd + add x19, x19, x8 + ld1 {v30.4s}, [x15], x8 + ldr s31, [x19] + add x15, x15, #20 + b ReloadEnd + Reload6: + add x19, x15, #16 + ld1 {v8.4s}, [x15], x8 + ld1 {v9.2s}, [x19], x8 + cmp x10, #1 + beq ReloadEnd + ld1 {v10.4s}, [x15], x8 + ld1 {v11.2s}, [x19], x8 + cmp x10, #2 + beq ReloadEnd + ld1 {v12.4s}, [x15], x8 + ld1 {v13.2s}, [x19], x8 + cmp x10, #3 + beq ReloadEnd + ld1 {v14.4s}, [x15], x8 + ld1 {v15.2s}, [x19], x8 + cmp x10, #4 + beq ReloadEnd + ld1 {v16.4s}, [x15], x8 + ld1 {v17.2s}, [x19], x8 + cmp x10, #5 + beq ReloadEnd + ld1 {v18.4s}, [x15], x8 + ld1 {v19.2s}, [x19], x8 + cmp x10, #6 + beq ReloadEnd + ld1 {v20.4s}, [x15], x8 + ld1 {v21.2s}, [x19], x8 + cmp x10, #7 + beq ReloadEnd + ld1 {v22.4s}, [x15], x8 + ld1 {v23.2s}, [x19], x8 + cmp x10, #8 + beq ReloadEnd + ld1 {v24.4s}, [x15], x8 + ld1 {v25.2s}, [x19], x8 + cmp x10, #9 + beq ReloadEnd + ld1 {v26.4s}, [x15], x8 + ld1 {v27.2s}, [x19], x8 + cmp x10, #10 + beq ReloadEnd + ld1 {v28.4s}, [x15], x8 + ld1 {v29.2s}, [x19], x8 + cmp x10, #11 + beq ReloadEnd + ld1 {v30.4s}, [x15], x8 + ld1 {v31.2s}, [x19] + add x15, x15, #24 + b ReloadEnd + Reload7: + add x19, x15, #16 + add x16, x15, #24 + ld1 {v8.4s}, [x15], x8 + ld1 {v9.2s}, [x19], x8 + ld1 {v9.s}[2], [x16], x8 + cmp x10, #1 + beq ReloadEnd + ld1 {v10.4s}, [x15], x8 + ld1 {v11.2s}, [x19], x8 + ld1 {v11.s}[2], [x16], x8 + cmp x10, #2 + beq ReloadEnd + ld1 {v12.4s}, [x15], x8 + ld1 {v13.2s}, [x19], x8 + ld1 {v13.s}[2], [x16], x8 + cmp x10, #3 + beq ReloadEnd + ld1 {v14.4s}, [x15], x8 + ld1 {v15.2s}, [x19], x8 + ld1 {v15.s}[2], [x16], x8 + cmp x10, #4 + beq ReloadEnd + ld1 {v16.4s}, [x15], x8 + ld1 {v17.2s}, [x19], x8 + ld1 {v17.s}[2], [x16], x8 + cmp x10, #5 + beq ReloadEnd + ld1 {v18.4s}, [x15], x8 + ld1 {v19.2s}, [x19], x8 + ld1 {v19.s}[2], [x16], x8 + cmp x10, #6 + beq ReloadEnd + ld1 {v20.4s}, [x15], x8 + ld1 {v21.2s}, [x19], x8 + ld1 {v21.s}[2], [x16], x8 + cmp x10, #7 + beq ReloadEnd + ld1 {v22.4s}, [x15], x8 + ld1 {v23.2s}, [x19], x8 + ld1 {v23.s}[2], [x16], x8 + cmp x10, #8 + beq ReloadEnd + ld1 {v24.4s}, [x15], x8 + ld1 {v25.2s}, [x19], x8 + ld1 {v25.s}[2], [x16], x8 + cmp x10, #9 + beq ReloadEnd + ld1 {v26.4s}, [x15], x8 + ld1 {v27.2s}, [x19], x8 + ld1 {v27.s}[2], [x16], x8 + cmp x10, #10 + beq ReloadEnd + ld1 {v28.4s}, [x15], x8 + ld1 {v29.2s}, [x19], x8 + ld1 {v29.s}[2], [x16], x8 + cmp x10, #11 + beq ReloadEnd + ld1 {v30.4s}, [x15], x8 + ld1 {v31.2s}, [x19] + ld1 {v31.s}[2], [x16] + add x15, x15, #28 + b ReloadEnd + + Reload8: + ld1 {v8.4s, v9.4s}, [x15], x8 + cmp x10, #1 + beq ReloadEnd + ld1 {v10.4s, v11.4s}, [x15], x8 + cmp x10, #2 + beq ReloadEnd + ld1 {v12.4s, v13.4s}, [x15], x8 + cmp x10, #3 + beq ReloadEnd + ld1 {v14.4s, v15.4s}, [x15], x8 + cmp x10, #4 + beq ReloadEnd + ld1 {v16.4s, v17.4s}, [x15], x8 + cmp x10, #5 + beq ReloadEnd + ld1 {v18.4s, v19.4s}, [x15], x8 + cmp x10, #6 + beq ReloadEnd + ld1 {v20.4s, v21.4s}, [x15], x8 + cmp x10, #7 + beq ReloadEnd + ld1 {v22.4s, v23.4s}, [x15], x8 + cmp x10, #8 + beq ReloadEnd + ld1 {v24.4s, v25.4s}, [x15], x8 + cmp x10, #9 + beq ReloadEnd + ld1 {v26.4s, v27.4s}, [x15], x8 + cmp x10, #10 + beq ReloadEnd + ld1 {v28.4s, v29.4s}, [x15], x8 + cmp x10, #11 + beq ReloadEnd + ld1 {v30.4s, v31.4s}, [x15], x8 + add x15, x15, #32 + b ReloadEnd + +ReloadEnd: + ret + +Compute12x8Unit: + subs x21, x21, #2 + ble Compute12x8End + Compute12x8: + prfm pldl1keep, [x11, #632] + ld1 {v1.4s, v2.4s}, [x11], #32 + ld1 {v4.4s}, [x23], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + fmla v25.4s, v4.4s, v2.s[0] + fmla v27.4s, v4.4s, v2.s[1] + fmla v29.4s, v4.4s, v2.s[2] + fmla v31.4s, v4.4s, v2.s[3] + + prfm pldl1keep, [x11, #632] + ld1 {v1.4s, v2.4s}, [x11], #32 + ld1 {v4.4s}, [x23], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + fmla v25.4s, v4.4s, v2.s[0] + fmla v27.4s, v4.4s, v2.s[1] + fmla v29.4s, v4.4s, v2.s[2] + fmla v31.4s, v4.4s, v2.s[3] + + subs x21, x21, #2 + bgt Compute12x8 + + Compute12x8End: + cbnz x21, Compute12x8End1 + prfm pldl1keep, [x11, #632] + ld1 {v1.4s, v2.4s}, [x11], #32 + ld1 {v4.4s}, [x23], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + fmla v25.4s, v4.4s, v2.s[0] + fmla v27.4s, v4.4s, v2.s[1] + fmla v29.4s, v4.4s, v2.s[2] + fmla v31.4s, v4.4s, v2.s[3] + Compute12x8End1: + ld1 {v1.4s, v2.4s}, [x11] + ld1 {v4.4s}, [x23] + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + fmla v25.4s, v4.4s, v2.s[0] + fmla v27.4s, v4.4s, v2.s[1] + fmla v29.4s, v4.4s, v2.s[2] + fmla v31.4s, v4.4s, v2.s[3] + ret + +Compute12x4Unit: + subs x21, x21, #2 + ble Compute12x4End + Compute12x4: + prfm pldl1keep, [x11, #632] + ld1 {v1.4s, v2.4s}, [x11], #32 + add x23, x23, #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + + prfm pldl1keep, [x11, #632] + ld1 {v1.4s, v2.4s}, [x11], #32 + add x23, x23, #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + + subs x21, x21, #2 + bgt Compute12x4 + + Compute12x4End: + cbnz x21, Compute12x4End1 + prfm pldl1keep, [x11, #632] + ld1 {v1.4s, v2.4s}, [x11], #32 + add x23, x23, #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + Compute12x4End1: + ld1 {v1.4s, v2.4s}, [x11] + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + ret + +Compute8x8Unit: + subs x21, x21, #2 + ble Compute8x8End + Compute8x8: + prfm pldl1keep, [x11, #632] + ld1 {v1.4s}, [x11] + add x11, x11, #32 + ld1 {v4.4s}, [x23], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + + prfm pldl1keep, [x11, #632] + ld1 {v1.4s}, [x11] + add x11, x11, #32 + ld1 {v4.4s}, [x23], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + + subs x21, x21, #2 + bgt Compute8x8 + + Compute8x8End: + cbnz x21, Compute8x8End1 + prfm pldl1keep, [x11, #632] + ld1 {v1.4s}, [x11] + add x11, x11, #32 + ld1 {v4.4s}, [x23], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + Compute8x8End1: + ld1 {v1.4s}, [x11] + ld1 {v4.4s}, [x23] + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + ret + +Compute8x4Unit: + subs x21, x21, #2 + ble Compute8x4End + Compute8x4: + prfm pldl1keep, [x11, #632] + ld1 {v1.4s}, [x11] + add x23, x23, #16 + add x11, x11, #32 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + + prfm pldl1keep, [x11, #632] + ld1 {v1.4s}, [x11] + add x23, x23, #16 + add x11, x11, #32 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + + subs x21, x21, #2 + bgt Compute8x4 + + Compute8x4End: + cbnz x21, Compute8x4End1 + prfm pldl1keep, [x11, #632] + ld1 {v1.4s}, [x11] + add x23, x23, #16 + add x11, x11, #32 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + Compute8x4End1: + ld1 {v1.4s}, [x11] + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + ret + +Compute4x8Unit: + subs x21, x21, #2 + ble Compute4x8End + Compute4x8: + prfm pldl1keep, [x11, #632] + add x11, x11, #32 + ld1 {v4.4s}, [x23], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + + prfm pldl1keep, [x11, #632] + add x11, x11, #32 + ld1 {v4.4s}, [x23], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + + subs x21, x21, #2 + bgt Compute4x8 + + Compute4x8End: + cbnz x21, Compute4x8End1 + prfm pldl1keep, [x11, #632] + add x11, x11, #32 + ld1 {v4.4s}, [x23], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + ld1 {v0.4s}, [x11], #16 + Compute4x8End1: + ld1 {v4.4s}, [x23] + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + ret + +Compute4x4Unit: + subs x21, x21, #2 + ble Compute4x4End + Compute4x4: + prfm pldl1keep, [x11, #632] + add x23, x23, #16 + add x11, x11, #32 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + ld1 {v0.4s}, [x11], #16 + + prfm pldl1keep, [x11, #632] + add x23, x23, #16 + add x11, x11, #32 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + ld1 {v0.4s}, [x11], #16 + + subs x21, x21, #2 + bgt Compute4x4 + + Compute4x4End: + cbnz x21, Compute4x4End1 + prfm pldl1keep, [x11, #632] + add x23, x23, #16 + add x11, x11, #32 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + prfm pldl1strm, [x23, #632] + ld1 {v3.4s}, [x23], #16 + ld1 {v0.4s}, [x11], #16 + Compute4x4End1: + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + ret + +End: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ldp x27, x28, [sp], #16 + ldp x29, x30, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDw3x3Fp32Corner.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDw3x3Fp32Corner.S new file mode 100644 index 0000000000000000000000000000000000000000..8dc87a7fb7512942ca0bf826b189c94db133bfc8 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDw3x3Fp32Corner.S @@ -0,0 +1,114 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void ConvDw3x3Corner(float *dst, const float *src, const float *weight, const float *bias, int in_kh_step, +// int in_kw_step, int channel, size_t relu, size_t relu6) + +// x0: dst, x1: src, x2: weight, x3: bias, x4: in_kh_step, x5: in_kw_step, x6: channel, x7: relu, x8: relu6 + +asm_function ConvDw3x3Corner + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + ldr x8, [sp] + + mov x9, #4 + mul x13, x6, x9 // x6 * 4 + mul x4, x4, x9 + mul x5, x5, x9 + mov x9, #3 + mul x14, x13, x9 // x6 * 3 * 4 + + movi v26.4s, #6 + scvtf v26.4s, v26.4s + dup v27.4s, wzr + + ld1 {v23.4s}, [x3], #16 + mov x9, x1 + mov x10, x2 + + ld1 {v0.4s}, [x9], x5 + add x11, x1, x4 + ld1 {v4.4s}, [x10], x13 // weight + add x12, x2, x14 + ld1 {v1.4s}, [x9], x5 + ld1 {v5.4s}, [x10], x13 + ld1 {v2.4s}, [x11], x5 + ld1 {v6.4s}, [x12], x13 + ld1 {v3.4s}, [x11], x5 + ld1 {v7.4s}, [x12], x13 + + cmp x6, #4 + ble LoopC4Post + + LoopC4: + add x1, x1, #16 + add x2, x2, #16 + fmla v23.4s, v0.4s, v4.4s + mov x9, x1 + mov x10, x2 + ld1 {v0.4s}, [x9], x5 + ld1 {v4.4s}, [x10], x13 + add x11, x1, x4 + fmla v23.4s, v1.4s, v5.4s + add x12, x2, x14 + ld1 {v1.4s}, [x9], x5 + fmla v23.4s, v2.4s, v6.4s + ld1 {v5.4s}, [x10], x13 + ld1 {v2.4s}, [x11], x5 + fmla v23.4s, v3.4s, v7.4s + ld1 {v6.4s}, [x12], x13 + ld1 {v3.4s}, [x11], x5 + ld1 {v7.4s}, [x12], x13 + + cbnz x8, C4_RELU6 + cbnz x7, C4_RELU + b C4_WRITE + C4_RELU6: + fmin v23.4s, v23.4s, v26.4s + C4_RELU: + fmax v23.4s, v23.4s, v27.4s + C4_WRITE: + st1 {v23.4s}, [x0], #16 + ld1 {v23.4s}, [x3], #16 + + sub x6, x6, #4 + cmp x6, #4 + bgt LoopC4 + + LoopC4Post: + fmla v23.4s, v0.4s, v4.4s + fmla v23.4s, v1.4s, v5.4s + fmla v23.4s, v2.4s, v6.4s + fmla v23.4s, v3.4s, v7.4s + + cbnz x8, RELU6 + cbnz x7, RELU + b WRITE + RELU6: + fmin v23.4s, v23.4s, v26.4s + RELU: + fmax v23.4s, v23.4s, v27.4s + WRITE: + st1 {v23.4s}, [x0], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDw3x3Fp32Horizontal.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDw3x3Fp32Horizontal.S new file mode 100644 index 0000000000000000000000000000000000000000..6ffdc16542ce3b31d0a66b7d0b72c69d0ce35c73 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDw3x3Fp32Horizontal.S @@ -0,0 +1,130 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void ConvDw3x3Horizontal(float *dst, const float *src, const float *weight, const float *bias, int in_kh_step, +// int in_kw_step, int channel, size_t relu, size_t relu6) + +// x0: dst, x1: src, x2: weight, x3: bias, x4: in_kh_step, x5: in_kw_step, x6: channel, x7: relu, x8: relu6 + +asm_function ConvDw3x3Horizontal + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + ldr x8, [sp] + + mov x9, #4 + mul x13, x6, x9 // x6 * 4 + mul x4, x4, x9 + mul x5, x5, x9 + mov x9, #3 + mul x14, x13, x9 // x6 * 3 * 4 + + movi v26.4s, #6 + scvtf v26.4s, v26.4s + dup v27.4s, wzr + + ld1 {v23.4s}, [x3], #16 + mov x9, x1 + mov x10, x2 + + ld1 {v0.4s}, [x9], x5 + add x11, x1, x4 + ld1 {v4.4s}, [x10], x13 + add x12, x2, x14 + ld1 {v1.4s}, [x9], x5 + ld1 {v5.4s}, [x10], x13 + add x15, x11, x4 + ld1 {v2.4s}, [x11], x5 + add x16, x12, x14 + ld1 {v6.4s}, [x12], x13 + ld1 {v3.4s}, [x11], x5 + ld1 {v7.4s}, [x12], x13 + ld1 {v16.4s}, [x15], x5 + ld1 {v18.4s}, [x16], x13 + ld1 {v17.4s}, [x15], x5 + ld1 {v19.4s}, [x16], x13 + + cmp x6, #4 + ble LoopC4Post + + LoopC4: + add x1, x1, #16 + add x2, x2, #16 + fmla v23.4s, v0.4s, v4.4s + mov x9, x1 + mov x10, x2 + ld1 {v0.4s}, [x9], x5 + ld1 {v4.4s}, [x10], x13 + add x11, x1, x4 + fmla v23.4s, v1.4s, v5.4s + add x12, x2, x14 + ld1 {v1.4s}, [x9], x5 + fmla v23.4s, v2.4s, v6.4s + add x15, x11, x4 + ld1 {v5.4s}, [x10], x13 + ld1 {v2.4s}, [x11], x5 + fmla v23.4s, v3.4s, v7.4s + add x16, x12, x14 + ld1 {v6.4s}, [x12], x13 + ld1 {v3.4s}, [x11], x5 + fmla v23.4s, v16.4s, v18.4s + ld1 {v7.4s}, [x12], x13 + ld1 {v16.4s}, [x15], x5 + fmla v23.4s, v17.4s, v19.4s + ld1 {v18.4s}, [x16], x13 + ld1 {v17.4s}, [x15], x5 + ld1 {v19.4s}, [x16], x13 + + cbnz x8, C4_RELU6 + cbnz x7, C4_RELU + b C4_WRITE + C4_RELU6: + fmin v23.4s, v23.4s, v26.4s + C4_RELU: + fmax v23.4s, v23.4s, v27.4s + C4_WRITE: + st1 {v23.4s}, [x0], #16 + ld1 {v23.4s}, [x3], #16 + + sub x6, x6, #4 + cmp x6, #4 + bgt LoopC4 + + LoopC4Post: + fmla v23.4s, v0.4s, v4.4s + fmla v23.4s, v1.4s, v5.4s + fmla v23.4s, v2.4s, v6.4s + fmla v23.4s, v3.4s, v7.4s + fmla v23.4s, v16.4s, v18.4s + fmla v23.4s, v17.4s, v19.4s + + cbnz x8, RELU6 + cbnz x7, RELU + b WRITE + RELU6: + fmin v23.4s, v23.4s, v26.4s + RELU: + fmax v23.4s, v23.4s, v27.4s + WRITE: + st1 {v23.4s}, [x0], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDw3x3Fp32Stride1.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDw3x3Fp32Stride1.S new file mode 100644 index 0000000000000000000000000000000000000000..b96efd6496bcd8c2d59f569a41be700bf53f3b09 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDw3x3Fp32Stride1.S @@ -0,0 +1,210 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void ConvDw3x3Stride1(float *output, const float *buffer, const float *weight, const float *bias, int col_size, +// int row_size, int channel, int output_h, int output_w, size_t relu, size_t relu6) +// +// x0: output +// x1: input +// x2: weight +// x3: bias +// w4: col_size +// w5: row_size +// w6: channel +// w7: output_h +// w8: output_w +// w9: relu +// w10: relu6 + +asm_function ConvDw3x3Stride1 + sub sp, sp, #128 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + + ldr w8, [sp, #128] + ldr w9, [sp, #136] + ldr w10, [sp, #144] + + mov w11, #4 + mul w15, w4, w11 // col_size * 4 + mul w16, w6, w11 // channel * 4 + mul w17, w5, w11 // row_size * 4 + mov w11, #2 + mul w14, w11, w15 // col_size * 2 * 4 + + movi v23.4s, #6 + scvtf v23.4s, v23.4s + dup v24.4s, wzr + + // Load weights + ld1 {v0.4s}, [x2], x16 + ld1 {v1.4s}, [x2], x16 + ld1 {v2.4s}, [x2], x16 + ld1 {v3.4s}, [x2], x16 + ld1 {v4.4s}, [x2], x16 + ld1 {v5.4s}, [x2], x16 + ld1 {v6.4s}, [x2], x16 + ld1 {v7.4s}, [x2], x16 + ld1 {v8.4s}, [x2], x16 + + mov x11, x1 + add x12, x11, x17 + add x13, x12, x17 + ld1 {v9.4s}, [x11], x15 + ld1 {v10.4s}, [x11], x15 + ld1 {v11.4s}, [x11], x15 + ld1 {v13.4s}, [x12], x15 + ld1 {v14.4s}, [x12], x15 + ld1 {v15.4s}, [x12], x15 + ld1 {v17.4s}, [x13], x15 + ld1 {v18.4s}, [x13], x15 + ld1 {v19.4s}, [x13], x15 + + ld1 {v21.4s}, [x3] + ld1 {v22.4s}, [x3] + + cmp w8, #2 + beq WIDTH2_LEFT + cmp w8, #1 + beq WIDTH1_LEFT + +WIDTH2_LOOP: + fmla v21.4s, v0.4s, v9.4s + ld1 {v12.4s}, [x11] + ld1 {v16.4s}, [x12] + fmla v22.4s, v0.4s, v10.4s + ld1 {v20.4s}, [x13] + add x1, x1, x14 + fmla v21.4s, v1.4s, v10.4s + mov x11, x1 + add x12, x11, x17 + add x13, x12, x17 + ld1 {v9.4s}, [x11], x15 + fmla v22.4s, v1.4s, v11.4s + ld1 {v10.4s}, [x11], x15 + fmla v21.4s, v2.4s, v11.4s + fmla v22.4s, v2.4s, v12.4s + fmla v21.4s, v3.4s, v13.4s + ld1 {v11.4s}, [x11], x15 + fmla v22.4s, v3.4s, v14.4s + fmla v21.4s, v4.4s, v14.4s + ld1 {v13.4s}, [x12], x15 + fmla v22.4s, v4.4s, v15.4s + fmla v21.4s, v5.4s, v15.4s + ld1 {v14.4s}, [x12], x15 + fmla v22.4s, v5.4s, v16.4s + fmla v21.4s, v6.4s, v17.4s + ld1 {v15.4s}, [x12], x15 + fmla v22.4s, v6.4s, v18.4s + fmla v21.4s, v7.4s, v18.4s + ld1 {v17.4s}, [x13], x15 + fmla v22.4s, v7.4s, v19.4s + fmla v21.4s, v8.4s, v19.4s + ld1 {v18.4s}, [x13], x15 + fmla v22.4s, v8.4s, v20.4s + ld1 {v19.4s}, [x13], x15 + + cbnz x10, WIDTH2_RELU6 + cbnz x9, WIDTH2_RELU + b WIDTH2_WRITE + WIDTH2_RELU6: + fmin v21.4s, v21.4s, v23.4s + fmin v22.4s, v22.4s, v23.4s + WIDTH2_RELU: + fmax v21.4s, v21.4s, v24.4s + fmax v22.4s, v22.4s, v24.4s + WIDTH2_WRITE: + st1 {v21.4s}, [x0], x16 + ld1 {v21.4s}, [x3] + st1 {v22.4s}, [x0], x16 + ld1 {v22.4s}, [x3] + + sub w8, w8, #2 + cmp w8, #2 + bgt WIDTH2_LOOP + + cmp w8, #2 + blt WIDTH1_LEFT + +WIDTH2_LEFT: + fmla v21.4s, v0.4s, v9.4s + ld1 {v12.4s}, [x11] + fmla v22.4s, v0.4s, v10.4s + fmla v21.4s, v1.4s, v10.4s + ld1 {v16.4s}, [x12] + fmla v22.4s, v1.4s, v11.4s + fmla v21.4s, v2.4s, v11.4s + ld1 {v20.4s}, [x13] + fmla v22.4s, v2.4s, v12.4s + fmla v21.4s, v3.4s, v13.4s + fmla v22.4s, v3.4s, v14.4s + fmla v21.4s, v4.4s, v14.4s + fmla v22.4s, v4.4s, v15.4s + fmla v21.4s, v5.4s, v15.4s + fmla v22.4s, v5.4s, v16.4s + fmla v21.4s, v6.4s, v17.4s + fmla v22.4s, v6.4s, v18.4s + fmla v21.4s, v7.4s, v18.4s + fmla v22.4s, v7.4s, v19.4s + fmla v21.4s, v8.4s, v19.4s + fmla v22.4s, v8.4s, v20.4s + + cbnz x10, WIDTH2_LEFT_RELU6 + cbnz x9, WIDTH2_LEFT_RELU + b WIDTH2_LEFT_WRITE + WIDTH2_LEFT_RELU6: + fmin v21.4s, v21.4s, v23.4s + fmin v22.4s, v22.4s, v23.4s + WIDTH2_LEFT_RELU: + fmax v21.4s, v21.4s, v24.4s + fmax v22.4s, v22.4s, v24.4s + WIDTH2_LEFT_WRITE: + st1 {v21.4s}, [x0], x16 + st1 {v22.4s}, [x0], x16 + b End + +WIDTH1_LEFT: + fmla v21.4s, v0.4s, v9.4s + fmla v21.4s, v1.4s, v10.4s + fmla v21.4s, v2.4s, v11.4s + fmla v21.4s, v3.4s, v13.4s + fmla v21.4s, v4.4s, v14.4s + fmla v21.4s, v5.4s, v15.4s + fmla v21.4s, v6.4s, v17.4s + fmla v21.4s, v7.4s, v18.4s + fmla v21.4s, v8.4s, v19.4s + + cbnz x10, WIDTH1_RELU6 + cbnz x9, WIDTH1_RELU + b WIDTH1_WRITE + WIDTH1_RELU6: + fmin v21.4s, v21.4s, v23.4s + WIDTH1_RELU: + fmax v21.4s, v21.4s, v24.4s + WIDTH1_WRITE: + st1 {v21.4s}, [x0] + +End: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDw3x3Fp32Stride2.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDw3x3Fp32Stride2.S new file mode 100644 index 0000000000000000000000000000000000000000..7632d48e135cfef9ea526c031d28fa20d447d1af --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDw3x3Fp32Stride2.S @@ -0,0 +1,212 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void ConvDw3x3Stride2(float *output, const float *buffer, const float *weight, const float *bias, int col_size, +// int row_size, int channel, int output_h, int output_w, size_t relu, size_t relu6) +// +// x0: output +// x1: input +// x2: weight +// x3: bias +// w4: col_size +// w5: row_size +// w6: channel +// w7: output_h +// w8: output_w +// w9: relu +// w10: relu6 + +asm_function ConvDw3x3Stride2 + sub sp, sp, #128 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + + ldr w8, [sp, #128] + ldr w9, [sp, #136] + ldr w10, [sp, #144] + + mov w11, #4 + mul w15, w4, w11 // col_size * 4 + mul w16, w6, w11 // channel * 4 + mul w17, w5, w11 // row_size * 4 + mov w11, #2 + mul w14, w11, w15 // col_size * 2 * 4 + + movi v26.4s, #6 + scvtf v26.4s, v26.4s + dup v27.4s, wzr + + // Load weights + ld1 {v0.4s}, [x2], x16 + ld1 {v1.4s}, [x2], x16 + ld1 {v2.4s}, [x2], x16 + ld1 {v3.4s}, [x2], x16 + ld1 {v4.4s}, [x2], x16 + ld1 {v5.4s}, [x2], x16 + ld1 {v6.4s}, [x2], x16 + ld1 {v7.4s}, [x2], x16 + ld1 {v8.4s}, [x2], x16 + + mov x11, x1 + add x12, x11, x17 + add x13, x12, x17 + ld1 {v9.4s}, [x11], x15 + ld1 {v10.4s}, [x11], x15 + ld1 {v11.4s}, [x11], x15 + ld1 {v14.4s}, [x12], x15 + ld1 {v15.4s}, [x12], x15 + ld1 {v16.4s}, [x12], x15 + ld1 {v19.4s}, [x13], x15 + ld1 {v20.4s}, [x13], x15 + ld1 {v21.4s}, [x13], x15 + + ld1 {v24.4s}, [x3] + ld1 {v25.4s}, [x3] + + cmp w8, #2 + beq WIDTH2_LEFT + cmp w8, #1 + beq WIDTH1_LEFT + +WIDTH2_LOOP: + fmla v24.4s, v0.4s, v9.4s + ld1 {v12.4s}, [x11], x15 + fmla v25.4s, v0.4s, v11.4s + ld1 {v17.4s}, [x12], x15 + fmla v24.4s, v1.4s, v10.4s + ld1 {v22.4s}, [x13], x15 + fmla v25.4s, v1.4s, v12.4s + ld1 {v13.4s}, [x11], x15 + fmla v24.4s, v2.4s, v11.4s + ld1 {v18.4s}, [x12], x15 + fmla v25.4s, v2.4s, v13.4s + ld1 {v23.4s}, [x13], x15 + fmla v24.4s, v3.4s, v14.4s + mov v9.16b, v13.16b + fmla v25.4s, v3.4s, v16.4s + mov v14.16b, v18.16b + fmla v24.4s, v4.4s, v15.4s + fmla v25.4s, v4.4s, v17.4s + ld1 {v10.4s}, [x11], x15 + fmla v24.4s, v5.4s, v16.4s + ld1 {v11.4s}, [x11], x15 + fmla v25.4s, v5.4s, v18.4s + ld1 {v15.4s}, [x12], x15 + fmla v24.4s, v6.4s, v19.4s + ld1 {v16.4s}, [x12], x15 + fmla v25.4s, v6.4s, v21.4s + mov v19.16b, v23.16b + fmla v24.4s, v7.4s, v20.4s + fmla v25.4s, v7.4s, v22.4s + ld1 {v20.4s}, [x13], x15 + fmla v24.4s, v8.4s, v21.4s + fmla v25.4s, v8.4s, v23.4s + ld1 {v21.4s}, [x13], x15 + + cbnz x10, WIDTH2_RELU6 + cbnz x9, WIDTH2_RELU + b WIDTH2_WRITE + WIDTH2_RELU6: + fmin v24.4s, v24.4s, v26.4s + fmin v25.4s, v25.4s, v26.4s + WIDTH2_RELU: + fmax v24.4s, v24.4s, v27.4s + fmax v25.4s, v25.4s, v27.4s + WIDTH2_WRITE: + st1 {v24.4s}, [x0], x16 + ld1 {v24.4s}, [x3] + st1 {v25.4s}, [x0], x16 + ld1 {v25.4s}, [x3] + + sub w8, w8, #2 + cmp w8, #2 + bgt WIDTH2_LOOP + + cmp w8, #2 + blt WIDTH1_LEFT + +WIDTH2_LEFT: + fmla v24.4s, v0.4s, v9.4s + ld1 {v12.4s}, [x11], x15 + fmla v25.4s, v0.4s, v11.4s + ld1 {v17.4s}, [x12], x15 + fmla v24.4s, v1.4s, v10.4s + ld1 {v22.4s}, [x13], x15 + fmla v25.4s, v1.4s, v12.4s + ld1 {v13.4s}, [x11], x15 + fmla v24.4s, v2.4s, v11.4s + ld1 {v18.4s}, [x12], x15 + fmla v25.4s, v2.4s, v13.4s + ld1 {v23.4s}, [x13], x15 + fmla v24.4s, v3.4s, v14.4s + fmla v25.4s, v3.4s, v16.4s + fmla v24.4s, v4.4s, v15.4s + fmla v25.4s, v4.4s, v17.4s + fmla v24.4s, v5.4s, v16.4s + fmla v25.4s, v5.4s, v18.4s + fmla v24.4s, v6.4s, v19.4s + fmla v25.4s, v6.4s, v21.4s + fmla v24.4s, v7.4s, v20.4s + fmla v25.4s, v7.4s, v22.4s + fmla v24.4s, v8.4s, v21.4s + fmla v25.4s, v8.4s, v23.4s + + cbnz x10, WIDTH2_LEFT_RELU6 + cbnz x9, WIDTH2_LEFT_RELU + b WIDTH2_LEFT_WRITE + WIDTH2_LEFT_RELU6: + fmin v24.4s, v24.4s, v26.4s + fmin v25.4s, v25.4s, v26.4s + WIDTH2_LEFT_RELU: + fmax v24.4s, v24.4s, v27.4s + fmax v25.4s, v25.4s, v27.4s + WIDTH2_LEFT_WRITE: + st1 {v24.4s}, [x0], x16 + st1 {v25.4s}, [x0], x16 + b End + +WIDTH1_LEFT: + fmla v24.4s, v0.4s, v9.4s + fmla v24.4s, v1.4s, v10.4s + fmla v24.4s, v2.4s, v11.4s + fmla v24.4s, v3.4s, v14.4s + fmla v24.4s, v4.4s, v15.4s + fmla v24.4s, v5.4s, v16.4s + fmla v24.4s, v6.4s, v19.4s + fmla v24.4s, v7.4s, v20.4s + fmla v24.4s, v8.4s, v21.4s + + cbnz x10, WIDTH1_RELU6 + cbnz x9, WIDTH1_RELU + b WIDTH1_WRITE + WIDTH1_RELU6: + fmin v24.4s, v24.4s, v26.4s + WIDTH1_RELU: + fmax v24.4s, v24.4s, v27.4s + WIDTH1_WRITE: + st1 {v24.4s}, [x0] + +End: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDw3x3Fp32Vertical.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDw3x3Fp32Vertical.S new file mode 100644 index 0000000000000000000000000000000000000000..feef03676ff16f05ce9e323f926c28c761f1f7dd --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDw3x3Fp32Vertical.S @@ -0,0 +1,126 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void ConvDw3x3Vertical(float *dst, const float *src, const float *weight, const float *bias, int in_kh_step, +// int in_kw_step, int channel, size_t relu, size_t relu6) + +// x0: dst, x1: src, x2: weight, x3: bias, x4: in_kh_step, x5: in_kw_step, x6: channel, x7: relu, x8: relu6 + +asm_function ConvDw3x3Vertical + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + ldr x8, [sp] + + mov x9, #4 + mul x13, x6, x9 // x6 * 4 + mul x4, x4, x9 + mul x5, x5, x9 + mov x9, #3 + mul x14, x13, x9 // x6 * 3 * 4 + + movi v26.4s, #6 + scvtf v26.4s, v26.4s + dup v27.4s, wzr + + ld1 {v23.4s}, [x3], #16 + mov x9, x1 + mov x10, x2 + + ld1 {v0.4s}, [x9], x5 + add x11, x1, x4 + ld1 {v4.4s}, [x10], x13 + ld1 {v1.4s}, [x9], x5 + add x12, x2, x14 + ld1 {v5.4s}, [x10], x13 + ld1 {v2.4s}, [x11], x5 + ld1 {v6.4s}, [x12], x13 + ld1 {v3.4s}, [x11], x5 + ld1 {v7.4s}, [x12], x13 + ld1 {v16.4s}, [x9], x5 + ld1 {v18.4s}, [x10], x13 + ld1 {v17.4s}, [x11], x5 + ld1 {v19.4s}, [x12], x13 + + cmp x6, #4 + ble LoopC4Post + + LoopC4: + add x1, x1, #16 + add x2, x2, #16 + fmla v23.4s, v0.4s, v4.4s + mov x9, x1 + mov x10, x2 + ld1 {v0.4s}, [x9], x5 + ld1 {v4.4s}, [x10], x13 + add x11, x1, x4 + fmla v23.4s, v1.4s, v5.4s + add x12, x2, x14 + ld1 {v1.4s}, [x9], x5 + fmla v23.4s, v2.4s, v6.4s + ld1 {v5.4s}, [x10], x13 + ld1 {v2.4s}, [x11], x5 + fmla v23.4s, v3.4s, v7.4s + ld1 {v6.4s}, [x12], x13 + ld1 {v3.4s}, [x11], x5 + fmla v23.4s, v16.4s, v18.4s + ld1 {v7.4s}, [x12], x13 + ld1 {v16.4s}, [x9], x5 + fmla v23.4s, v17.4s, v19.4s + ld1 {v18.4s}, [x10], x13 + ld1 {v17.4s}, [x11], x5 + ld1 {v19.4s}, [x12], x13 + + cbnz x8, C4_RELU6 + cbnz x7, C4_RELU + b C4_WRITE + C4_RELU6: + fmin v23.4s, v23.4s, v26.4s + C4_RELU: + fmax v23.4s, v23.4s, v27.4s + C4_WRITE: + st1 {v23.4s}, [x0], #16 + ld1 {v23.4s}, [x3], #16 + + sub x6, x6, #4 + cmp x6, #4 + bgt LoopC4 + + LoopC4Post: + fmla v23.4s, v0.4s, v4.4s + fmla v23.4s, v1.4s, v5.4s + fmla v23.4s, v2.4s, v6.4s + fmla v23.4s, v3.4s, v7.4s + fmla v23.4s, v16.4s, v18.4s + fmla v23.4s, v17.4s, v19.4s + + cbnz x8, RELU6 + cbnz x7, RELU + b WRITE + RELU6: + fmin v23.4s, v23.4s, v26.4s + RELU: + fmax v23.4s, v23.4s, v27.4s + WRITE: + st1 {v23.4s}, [x0], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDw3x3Int8.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDw3x3Int8.S new file mode 100644 index 0000000000000000000000000000000000000000..1147c3e02e10d02ccf5474cdaa2cd82ca82f9eb8 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDw3x3Int8.S @@ -0,0 +1,500 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void ConvDw3x3Int8Neon64(int8_t *output, const int8_t *input, const int16_t *weight, const int32_t *bias, int input_col_size, +// int input_row_size, int channel, int output_h, int output_w, int8_t in_zp, int32_t out_zp, +// int *out_multiplier, int *left_shift, int *right_shift, int32_t acc_min, int32_t acc_max, +// size_t per_channel) +// +// x0: output +// x1: input +// x2: weight +// x3: bias +// w4: col_size +// w5: row_size +// w6: channel +// w7: output_h +// w8: output_w +// w9: in_zp +// w10: out_zp +// w11: out_multiplier +// w12: left_shift +// w13: right_shift +// w14: acc_min +// w15: acc_max +// w16: per_channel + +asm_function ConvDw3x3Int8Neon64 + sub sp, sp, #192 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + stp x23, x24, [sp, #160] + stp x25, x26, [sp, #176] + + ldr x8, [sp, #192] + ldr x9, [sp, #200] + ldr x10, [sp, #208] + ldr x11, [sp, #216] + ldr x12, [sp, #224] + ldr x13, [sp, #232] + ldr x14, [sp, #240] + ldr x15, [sp, #248] + ldr x23, [sp, #256] // per_channel + + add x19, x3, #16 + add w20, w6, w6 // channel * 2 + add w21, w4, w4 // col_size * 2 + dup v25.8b, w9 + + cbnz w23, PER_CHANNEL_DUMP + PER_LAYER_DUMP: + ld1r {v27.4s}, [x11] // out_multiplier + ld1r {v26.4s}, [x12] // left_shift + ld1r {v28.4s}, [x13] // right_shift + b MAIN_FUC + PER_CHANNEL_DUMP: + ld1 {v27.4s}, [x11] + ld1 {v26.4s}, [x12] + ld1 {v28.4s}, [x13] + MAIN_FUC: + dup v29.4s, w10 + dup v30.4s, w14 + dup v31.4s, w15 + ldr w24, [x12] + + // Load weights + ld1 {v0.8h}, [x2], x20 + ld1 {v1.8h}, [x2], x20 + ld1 {v2.8h}, [x2], x20 + ld1 {v3.8h}, [x2], x20 + ld1 {v4.8h}, [x2], x20 + ld1 {v5.8h}, [x2], x20 + ld1 {v6.8h}, [x2], x20 + ld1 {v7.8h}, [x2], x20 + ld1 {v8.8h}, [x2], x20 + + mov x16, x1 + add x17, x16, x5 + add x25, x17, x5 + ld1 {v9.8b}, [x16], x4 + ld1 {v10.8b}, [x16], x4 + ld1 {v11.8b}, [x16], x4 + ld1 {v13.8b}, [x17], x4 + ld1 {v14.8b}, [x17], x4 + ld1 {v15.8b}, [x17], x4 + ld1 {v17.8b}, [x25], x4 + ld1 {v18.8b}, [x25], x4 + ld1 {v19.8b}, [x25], x4 + + ld1 {v21.4s}, [x3] + ld1 {v22.4s}, [x19] + ld1 {v23.4s}, [x3] + ld1 {v24.4s}, [x19] + + // subtract input zp + ssubl v9.8h, v9.8b, v25.8b + ssubl v10.8h, v10.8b, v25.8b + ssubl v11.8h, v11.8b, v25.8b + ssubl v13.8h, v13.8b, v25.8b + ssubl v14.8h, v14.8b, v25.8b + ssubl v15.8h, v15.8b, v25.8b + ssubl v17.8h, v17.8b, v25.8b + ssubl v18.8h, v18.8b, v25.8b + ssubl v19.8h, v19.8b, v25.8b + + cmp w8, #2 + beq WIDTH2_LEFT + cmp w8, #1 + beq WIDTH1_LEFT + +HEIGHT1_LOOP: + smlal v21.4s, v0.4h, v9.4h + ld1 {v12.8b}, [x16] + smlal2 v22.4s, v0.8h, v9.8h + ld1 {v16.8b}, [x17] + smlal v23.4s, v0.4h, v10.4h + smlal2 v24.4s, v0.8h, v10.8h + ld1 {v20.8b}, [x25] + add x1, x1, x21 + ssubl v12.8h, v12.8b, v25.8b + smlal v21.4s, v1.4h, v10.4h + mov x16, x1 + add x17, x16, x5 + add x25, x17, x5 + smlal2 v22.4s, v1.8h, v10.8h + ld1 {v9.8b}, [x16], x4 + ssubl v16.8h, v16.8b, v25.8b + smlal v23.4s, v1.4h, v11.4h + ld1 {v10.8b}, [x16], x4 + ssubl v20.8h, v20.8b, v25.8b + smlal2 v24.4s, v1.8h, v11.8h + smlal v21.4s, v2.4h, v11.4h + smlal2 v22.4s, v2.8h, v11.8h + ld1 {v11.8b}, [x16], x4 + smlal v23.4s, v2.4h, v12.4h + smlal2 v24.4s, v2.8h, v12.8h + smlal v21.4s, v3.4h, v13.4h + smlal2 v22.4s, v3.8h, v13.8h + ld1 {v13.8b}, [x17], x4 + smlal v23.4s, v3.4h, v14.4h + smlal2 v24.4s, v3.8h, v14.8h + smlal v21.4s, v4.4h, v14.4h + smlal2 v22.4s, v4.8h, v14.8h + ld1 {v14.8b}, [x17], x4 + smlal v23.4s, v4.4h, v15.4h + smlal2 v24.4s, v4.8h, v15.8h + smlal v21.4s, v5.4h, v15.4h + smlal2 v22.4s, v5.8h, v15.8h + ld1 {v15.8b}, [x17], x4 + smlal v23.4s, v5.4h, v16.4h + smlal2 v24.4s, v5.8h, v16.8h + smlal v21.4s, v6.4h, v17.4h + smlal2 v22.4s, v6.8h, v17.8h + ld1 {v17.8b}, [x25], x4 + smlal v23.4s, v6.4h, v18.4h + smlal2 v24.4s, v6.8h, v18.8h + smlal v21.4s, v7.4h, v18.4h + smlal2 v22.4s, v7.8h, v18.8h + ld1 {v18.8b}, [x25], x4 + smlal v23.4s, v7.4h, v19.4h + smlal2 v24.4s, v7.8h, v19.8h + smlal v21.4s, v8.4h, v19.4h + smlal2 v22.4s, v8.8h, v19.8h + ld1 {v19.8b}, [x25], x4 + smlal v23.4s, v8.4h, v20.4h + smlal2 v24.4s, v8.8h, v20.8h + + cbnz w23, PER_CHANNEL_POST1 + cbz w24, SKIP_LEFTSHIFT1 + sqshl v21.4s, v21.4s, v26.4s + sqshl v22.4s, v22.4s, v26.4s + sqshl v23.4s, v23.4s, v26.4s + sqshl v24.4s, v24.4s, v26.4s + sqrdmulh v21.4s, v21.4s, v27.4s + sqrdmulh v22.4s, v22.4s, v27.4s + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + b OUTZP1 + +SKIP_LEFTSHIFT1: + sqrdmulh v21.4s, v21.4s, v27.4s + sqrdmulh v22.4s, v22.4s, v27.4s + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + + and v12.16b, v21.16b, v28.16b + sshr v12.4s, v12.4s, #31 + sqadd v21.4s, v21.4s, v12.4s + sqrshl v21.4s, v21.4s, v28.4s + + and v12.16b, v22.16b, v28.16b + sshr v12.4s, v12.4s, #31 + sqadd v22.4s, v22.4s, v12.4s + sqrshl v22.4s, v22.4s, v28.4s + + and v12.16b, v23.16b, v28.16b + sshr v12.4s, v12.4s, #31 + sqadd v23.4s, v23.4s, v12.4s + sqrshl v23.4s, v23.4s, v28.4s + + and v12.16b, v24.16b, v28.16b + sshr v12.4s, v12.4s, #31 + sqadd v24.4s, v24.4s, v12.4s + sqrshl v24.4s, v24.4s, v28.4s + b OUTZP1 + +PER_CHANNEL_POST1: + sqshl v21.4s, v21.4s, v26.4s + sqshl v23.4s, v23.4s, v26.4s + sqrdmulh v21.4s, v21.4s, v27.4s + sqrdmulh v23.4s, v23.4s, v27.4s + ldr q26, [x12, #16] + + and v12.16b, v21.16b, v28.16b + sshr v12.4s, v12.4s, #31 + sqadd v21.4s, v21.4s, v12.4s + sqrshl v21.4s, v21.4s, v28.4s + + and v12.16b, v23.16b, v28.16b + sshr v12.4s, v12.4s, #31 + sqadd v23.4s, v23.4s, v12.4s + sqrshl v23.4s, v23.4s, v28.4s + + ldr q27, [x11, #16] + sqshl v22.4s, v22.4s, v26.4s + sqshl v24.4s, v24.4s, v26.4s + ldr q28, [x13, #16] + sqrdmulh v22.4s, v22.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + ld1 {v26.4s}, [x12] + + and v12.16b, v22.16b, v28.16b + sshr v12.4s, v12.4s, #31 + sqadd v22.4s, v22.4s, v12.4s + sqrshl v22.4s, v22.4s, v28.4s + + and v12.16b, v24.16b, v28.16b + sshr v12.4s, v12.4s, #31 + sqadd v24.4s, v24.4s, v12.4s + sqrshl v24.4s, v24.4s, v28.4s + + ld1 {v27.4s}, [x11] + ld1 {v28.4s}, [x13] + +OUTZP1: + // Add output zero point + sqadd v21.4s, v21.4s, v29.4s + sqadd v22.4s, v22.4s, v29.4s + sqadd v23.4s, v23.4s, v29.4s + sqadd v24.4s, v24.4s, v29.4s + + // Apply min bound + smax v21.4s, v21.4s, v30.4s + smax v22.4s, v22.4s, v30.4s + smax v23.4s, v23.4s, v30.4s + smax v24.4s, v24.4s, v30.4s + + // Apply max bound + smin v21.4s, v21.4s, v31.4s + smin v22.4s, v22.4s, v31.4s + smin v23.4s, v23.4s, v31.4s + smin v24.4s, v24.4s, v31.4s + + sqxtn v21.4h, v21.4s + sqxtn2 v21.8h, v22.4s + ld1 {v22.4s}, [x19] + ssubl v9.8h, v9.8b, v25.8b + ssubl v10.8h, v10.8b, v25.8b + sqxtn v23.4h, v23.4s + sqxtn2 v23.8h, v24.4s + ld1 {v24.4s}, [x19] + sqxtn v21.8b, v21.8h + sqxtn2 v21.16b, v23.8h + st1 {v21.8b}, [x0], x6 + mov v23.d[0], v21.d[1] + ld1 {v21.4s}, [x3] + st1 {v23.8b}, [x0], x6 + ssubl v11.8h, v11.8b, v25.8b + ssubl v13.8h, v13.8b, v25.8b + ld1 {v23.4s}, [x3] + ssubl v14.8h, v14.8b, v25.8b + ssubl v15.8h, v15.8b, v25.8b + ssubl v17.8h, v17.8b, v25.8b + ssubl v18.8h, v18.8b, v25.8b + ssubl v19.8h, v19.8b, v25.8b + sub w8, w8, #2 + cmp w8, #2 + bgt HEIGHT1_LOOP + + cmp w8, #2 + blt WIDTH1_LEFT + +WIDTH2_LEFT: + smlal v21.4s, v0.4h, v9.4h + smlal2 v22.4s, v0.8h, v9.8h + ld1 {v12.8b}, [x16] + ssubl v12.8h, v12.8b, v25.8b + smlal v23.4s, v0.4h, v10.4h + smlal2 v24.4s, v0.8h, v10.8h + smlal v21.4s, v1.4h, v10.4h + smlal2 v22.4s, v1.8h, v10.8h + ld1 {v16.8b}, [x17] + smlal v23.4s, v1.4h, v11.4h + smlal2 v24.4s, v1.8h, v11.8h + smlal v21.4s, v2.4h, v11.4h + smlal2 v22.4s, v2.8h, v11.8h + ld1 {v20.8b}, [x25] + smlal v23.4s, v2.4h, v12.4h + smlal2 v24.4s, v2.8h, v12.8h + smlal v21.4s, v3.4h, v13.4h + smlal2 v22.4s, v3.8h, v13.8h + smlal v23.4s, v3.4h, v14.4h + smlal2 v24.4s, v3.8h, v14.8h + smlal v21.4s, v4.4h, v14.4h + smlal2 v22.4s, v4.8h, v14.8h + ssubl v16.8h, v16.8b, v25.8b + smlal v23.4s, v4.4h, v15.4h + smlal2 v24.4s, v4.8h, v15.8h + smlal v21.4s, v5.4h, v15.4h + smlal2 v22.4s, v5.8h, v15.8h + ssubl v20.8h, v20.8b, v25.8b + smlal v23.4s, v5.4h, v16.4h + smlal2 v24.4s, v5.8h, v16.8h + smlal v21.4s, v6.4h, v17.4h + smlal2 v22.4s, v6.8h, v17.8h + smlal v23.4s, v6.4h, v18.4h + smlal2 v24.4s, v6.8h, v18.8h + smlal v21.4s, v7.4h, v18.4h + smlal2 v22.4s, v7.8h, v18.8h + smlal v23.4s, v7.4h, v19.4h + smlal2 v24.4s, v7.8h, v19.8h + smlal v21.4s, v8.4h, v19.4h + smlal2 v22.4s, v8.8h, v19.8h + smlal v23.4s, v8.4h, v20.4h + smlal2 v24.4s, v8.8h, v20.8h + + cbnz w23, PER_CHANNEL_POST2 + cbz w24, SKIP_LEFTSHIFT2 + sqshl v21.4s, v21.4s, v26.4s + sqshl v22.4s, v22.4s, v26.4s + sqshl v23.4s, v23.4s, v26.4s + sqshl v24.4s, v24.4s, v26.4s + sqrdmulh v21.4s, v21.4s, v27.4s + sqrdmulh v22.4s, v22.4s, v27.4s + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + b OUTZP2 + +SKIP_LEFTSHIFT2: + sqrdmulh v21.4s, v21.4s, v27.4s + sqrdmulh v22.4s, v22.4s, v27.4s + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + sqrshl v21.4s, v21.4s, v28.4s + sqrshl v22.4s, v22.4s, v28.4s + sqrshl v23.4s, v23.4s, v28.4s + sqrshl v24.4s, v24.4s, v28.4s + b OUTZP2 + +PER_CHANNEL_POST2: + sqshl v21.4s, v21.4s, v26.4s + sqshl v23.4s, v23.4s, v26.4s + sqrdmulh v21.4s, v21.4s, v27.4s + sqrdmulh v23.4s, v23.4s, v27.4s + ldr q26, [x12, #16] + sqrshl v21.4s, v21.4s, v28.4s + sqrshl v23.4s, v23.4s, v28.4s + ldr q27, [x11, #16] + sqshl v22.4s, v22.4s, v26.4s + sqshl v24.4s, v24.4s, v26.4s + ldr q28, [x13, #16] + sqrdmulh v22.4s, v22.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + sqrshl v22.4s, v22.4s, v28.4s + sqrshl v24.4s, v24.4s, v28.4s + +OUTZP2: + // Add output zero point + sqadd v21.4s, v21.4s, v29.4s + sqadd v22.4s, v22.4s, v29.4s + sqadd v23.4s, v23.4s, v29.4s + sqadd v24.4s, v24.4s, v29.4s + + // Apply min bound + smax v21.4s, v21.4s, v30.4s + smax v22.4s, v22.4s, v30.4s + smax v23.4s, v23.4s, v30.4s + smax v24.4s, v24.4s, v30.4s + + // Apply max bound + smin v21.4s, v21.4s, v31.4s + smin v22.4s, v22.4s, v31.4s + smin v23.4s, v23.4s, v31.4s + smin v24.4s, v24.4s, v31.4s + + sqxtn v21.4h, v21.4s + sqxtn2 v21.8h, v22.4s + sqxtn v23.4h, v23.4s + sqxtn2 v23.8h, v24.4s + sqxtn v21.8b, v21.8h + sqxtn2 v21.16b, v23.8h + st1 {v21.8b}, [x0], x6 + mov v23.d[0], v21.d[1] + st1 {v23.8b}, [x0], x6 + b End + +WIDTH1_LEFT: + smlal v21.4s, v0.4h, v9.4h + smlal2 v22.4s, v0.8h, v9.8h + smlal v21.4s, v1.4h, v10.4h + smlal2 v22.4s, v1.8h, v10.8h + smlal v21.4s, v2.4h, v11.4h + smlal2 v22.4s, v2.8h, v11.8h + smlal v21.4s, v3.4h, v13.4h + smlal2 v22.4s, v3.8h, v13.8h + smlal v21.4s, v4.4h, v14.4h + smlal2 v22.4s, v4.8h, v14.8h + smlal v21.4s, v5.4h, v15.4h + smlal2 v22.4s, v5.8h, v15.8h + smlal v21.4s, v6.4h, v17.4h + smlal2 v22.4s, v6.8h, v17.8h + smlal v21.4s, v7.4h, v18.4h + smlal2 v22.4s, v7.8h, v18.8h + smlal v21.4s, v8.4h, v19.4h + smlal2 v22.4s, v8.8h, v19.8h + + cbnz w23, PER_CHANNEL_POST3 + cbz w24, SKIP_LEFTSHIFT3 + sqshl v21.4s, v21.4s, v26.4s + sqshl v22.4s, v22.4s, v26.4s + sqrdmulh v21.4s, v21.4s, v27.4s + sqrdmulh v22.4s, v22.4s, v27.4s + b OUTZP3 + +SKIP_LEFTSHIFT3: + sqrdmulh v21.4s, v21.4s, v27.4s + sqrdmulh v22.4s, v22.4s, v27.4s + sqrshl v21.4s, v21.4s, v28.4s + sqrshl v22.4s, v22.4s, v28.4s + b OUTZP3 + +PER_CHANNEL_POST3: + sqshl v21.4s, v21.4s, v26.4s + sqrdmulh v21.4s, v21.4s, v27.4s + ldr q26, [x12, #16] + sqrshl v21.4s, v21.4s, v28.4s + ldr q27, [x11, #16] + sqshl v22.4s, v22.4s, v26.4s + ldr q28, [x13, #16] + sqrdmulh v22.4s, v22.4s, v27.4s + sqrshl v22.4s, v22.4s, v28.4s + +OUTZP3: + // Add output zero point + sqadd v21.4s, v21.4s, v29.4s + sqadd v22.4s, v22.4s, v29.4s + + // Apply min bound + smax v21.4s, v21.4s, v30.4s + smax v22.4s, v22.4s, v30.4s + + // Apply max bound + smin v21.4s, v21.4s, v31.4s + smin v22.4s, v22.4s, v31.4s + + sqxtn v21.4h, v21.4s + sqxtn2 v21.8h, v22.4s + sqxtn v21.8b, v21.8h + st1 {v21.8b}, [x0], x6 + +End: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ret + +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDw3x3Int8Corner.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDw3x3Int8Corner.S new file mode 100644 index 0000000000000000000000000000000000000000..416e1a3a6779f043702a456b7db8e1caedeb7b8c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDw3x3Int8Corner.S @@ -0,0 +1,222 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void ConvDw3x3Int8Corner(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t in_kh_step, +// size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, int32_t *out_multiplier, +// int32_t *left_shift, int32_t *right_shift, size_t acc_min, size_t acc_max, size_t per_channel) + +// x0: dst, x1: src, x2: weight, x3: bias, x4: in_kh_step, x5: in_kw_step, +// x6: channel, x7: in_zp, x8: out_zp, x9: out_multiplier, x10: left_shift, x11: right_shift +// x12: acc_min, x13: acc_max, x14: per_channel +asm_function ConvDw3x3Int8Corner + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #32 + stp x19, x20, [sp] + stp x21, x22, [sp, #16] + + dup v25.8b, w7 // in_zp + ldr x8, [sp, #32] + dup v26.4s, w8 // out_zp + ldr x9, [sp, #40] // out_multiplier + ldr x10, [sp, #48] // left_shift + ldr x11, [sp, #56] // right_shift + ldr x12, [sp, #64] + dup v30.4s, w12 // acc_min + ldr x13, [sp, #72] + dup v31.4s, w13 // acc_max + ldr x14, [sp, #80] // per_channel + cbnz x14, PerChannelDump + PerLayerDump: + ld1r {v27.4s}, [x9] + ld1r {v28.4s}, [x10] + ld1r {v29.4s}, [x11] + b ContinueFunc + PerChannelDump: + ld1 {v27.4s}, [x9], #16 + ld1 {v28.4s}, [x10], #16 + ld1 {v29.4s}, [x11], #16 + ContinueFunc: + + mov x12, #2 + mul x21, x6, x12 // x6 * 2 + mov x12, #3 + mul x22, x21, x12 // x6 * 3 * 2 + + ld1 {v23.4s}, [x3], #16 + ld1 {v24.4s}, [x3], #16 + mov x12, x1 + mov x13, x2 + + ld1 {v0.8b}, [x12], x5 + ssubl v0.8h, v0.8b, v25.8b + add x19, x1, x4 + ld1 {v4.8h}, [x13], x21 // weight + add x20, x2, x22 + ld1 {v1.8b}, [x12], x5 + ssubl v1.8h, v1.8b, v25.8b + ld1 {v5.8h}, [x13], x21 + ld1 {v2.8b}, [x19], x5 + ssubl v2.8h, v2.8b, v25.8b + ld1 {v6.8h}, [x20], x21 + ld1 {v3.8b}, [x19], x5 + ssubl v3.8h, v3.8b, v25.8b + ld1 {v7.8h}, [x20], x21 + + cmp x6, #8 + ble LoopC8Post + + LoopC8: + add x1, x1, #8 + add x2, x2, #16 + smlal v23.4s, v0.4h, v4.4h + smlal2 v24.4s, v0.8h, v4.8h + mov x12, x1 + mov x13, x2 + ld1 {v0.8b}, [x12], x5 + ssubl v0.8h, v0.8b, v25.8b + ld1 {v4.8h}, [x13], x21 // weight + add x19, x1, x4 + smlal v23.4s, v1.4h, v5.4h + smlal2 v24.4s, v1.8h, v5.8h + add x20, x2, x22 + ld1 {v1.8b}, [x12], x5 + ssubl v1.8h, v1.8b, v25.8b + smlal v23.4s, v2.4h, v6.4h + ld1 {v5.8h}, [x13], x21 + smlal2 v24.4s, v2.8h, v6.8h + ld1 {v2.8b}, [x19], x5 + ssubl v2.8h, v2.8b, v25.8b + smlal v23.4s, v3.4h, v7.4h + ld1 {v6.8h}, [x20], x21 + smlal2 v24.4s, v3.8h, v7.8h + ld1 {v3.8b}, [x19], x5 + ssubl v3.8h, v3.8b, v25.8b + ld1 {v7.8h}, [x20], x21 + + cbnz x14, PerChannelPostLoop + ldr w8, [x10] + cbz w8, RightShiftLoop + sqshl v23.4s, v23.4s, v28.4s + sqshl v24.4s, v24.4s, v28.4s + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + b AddZpLoop + + RightShiftLoop: + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + sqrshl v23.4s, v23.4s, v29.4s + sqrshl v24.4s, v24.4s, v29.4s + b AddZpLoop + PerChannelPostLoop: + sqshl v23.4s, v23.4s, v28.4s + ld1 {v28.4s}, [x10], #16 + sqrdmulh v23.4s, v23.4s, v27.4s + ld1 {v27.4s}, [x9], #16 + sqrshl v23.4s, v23.4s, v29.4s + ld1 {v29.4s}, [x11], #16 + sqshl v24.4s, v24.4s, v28.4s + ld1 {v28.4s}, [x10], #16 + sqrdmulh v24.4s, v24.4s, v27.4s + ld1 {v27.4s}, [x9], #16 + sqrshl v24.4s, v24.4s, v29.4s + ld1 {v29.4s}, [x11], #16 + + AddZpLoop: + add v23.4s, v23.4s, v26.4s + add v24.4s, v24.4s, v26.4s + smax v23.4s, v23.4s, v30.4s + smax v24.4s, v24.4s, v30.4s + smin v23.4s, v23.4s, v31.4s + smin v24.4s, v24.4s, v31.4s + + sqxtn v23.4h, v23.4s + sqxtn v24.4h, v24.4s + sqxtn v23.8b, v23.8h + sqxtn v24.8b, v24.8h + + st1 {v23.s}[0], [x0], #4 + st1 {v24.s}[0], [x0], #4 + ld1 {v23.4s}, [x3], #16 + ld1 {v24.4s}, [x3], #16 + sub x6, x6, #8 + cmp x6, #8 + bgt LoopC8 + + LoopC8Post: + smlal v23.4s, v0.4h, v4.4h + smlal2 v24.4s, v0.8h, v4.8h + smlal v23.4s, v1.4h, v5.4h + smlal2 v24.4s, v1.8h, v5.8h + smlal v23.4s, v2.4h, v6.4h + smlal2 v24.4s, v2.8h, v6.8h + smlal v23.4s, v3.4h, v7.4h + smlal2 v24.4s, v3.8h, v7.8h + + cbnz x14, PerChannelPost + ldr w8, [x10] + cbz w8, RightShift + sqshl v23.4s, v23.4s, v28.4s + sqshl v24.4s, v24.4s, v28.4s + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + b AddZp + + RightShift: + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + sqrshl v23.4s, v23.4s, v29.4s + sqrshl v24.4s, v24.4s, v29.4s + b AddZp + PerChannelPost: + sqshl v23.4s, v23.4s, v28.4s + ld1 {v28.4s}, [x10], #16 + sqrdmulh v23.4s, v23.4s, v27.4s + ld1 {v27.4s}, [x9], #16 + sqrshl v23.4s, v23.4s, v29.4s + ld1 {v29.4s}, [x11], #16 + sqshl v24.4s, v24.4s, v28.4s + sqrdmulh v24.4s, v24.4s, v27.4s + sqrshl v24.4s, v24.4s, v29.4s + + AddZp: + add v23.4s, v23.4s, v26.4s + add v24.4s, v24.4s, v26.4s + smax v23.4s, v23.4s, v30.4s + smax v24.4s, v24.4s, v30.4s + smin v23.4s, v23.4s, v31.4s + smin v24.4s, v24.4s, v31.4s + + sqxtn v23.4h, v23.4s + sqxtn v24.4h, v24.4s + sqxtn v23.8b, v23.8h + sqxtn v24.8b, v24.8h + + st1 {v23.s}[0], [x0], #4 + st1 {v24.s}[0], [x0], #4 + + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDw3x3Int8Horizontal.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDw3x3Int8Horizontal.S new file mode 100644 index 0000000000000000000000000000000000000000..379154e68c62ff751e94dfabbb0113589a8875c0 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDw3x3Int8Horizontal.S @@ -0,0 +1,255 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void ConvDw3x3Int8Horizontal(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t in_kh_step, +// size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, int32_t *out_multiplier, +// int32_t *left_shift, int32_t *right_shift, size_t acc_min, size_t acc_max, size_t per_channel) + +// x0: dst, x1: src, x2: weight, x3: bias, x4: in_kh_step, x5: in_kw_step, +// x6: channel, x7: in_zp, x8: out_zp, x9: out_multiplier, x10: left_shift, x11: right_shift +// x12: acc_min, x13: acc_max, x14: per_channel +asm_function ConvDw3x3Int8Horizontal + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #48 + stp x19, x20, [sp] + stp x21, x22, [sp, #16] + stp x23, x24, [sp, #32] + + dup v25.8b, w7 // in_zp + ldr x8, [sp, #48] + dup v26.4s, w8 // out_zp + ldr x9, [sp, #56] // out_multiplier + ldr x10, [sp, #64] // left_shift + ldr x11, [sp, #72] // right_shift + ldr x12, [sp, #80] + dup v30.4s, w12 // acc_min + ldr x13, [sp, #88] + dup v31.4s, w13 // acc_max + ldr x14, [sp, #96] // per_channel + cbnz x14, PerChannelDump + PerLayerDump: + ld1r {v27.4s}, [x9] + ld1r {v28.4s}, [x10] + ld1r {v29.4s}, [x11] + b ContinueFunc + PerChannelDump: + ld1 {v27.4s}, [x9], #16 + ld1 {v28.4s}, [x10], #16 + ld1 {v29.4s}, [x11], #16 + ContinueFunc: + ldr x12, [sp, #80] + dup v30.4s, w12 // acc_min + ldr x13, [sp, #88] + dup v31.4s, w13 // acc_max + + mov x12, #2 + mul x23, x6, x12 // x6 * 2 + mov x12, #3 + mul x24, x23, x12 // x6 * 3 * 2 + + ld1 {v23.4s}, [x3], #16 + ld1 {v24.4s}, [x3], #16 + mov x12, x1 + mov x13, x2 + + ld1 {v0.8b}, [x12], x5 + ssubl v0.8h, v0.8b, v25.8b + add x19, x1, x4 + ld1 {v4.8h}, [x13], x23 // weight + add x20, x2, x24 + ld1 {v1.8b}, [x12], x5 + ssubl v1.8h, v1.8b, v25.8b + ld1 {v5.8h}, [x13], x23 + add x21, x19, x4 + ld1 {v2.8b}, [x19], x5 + ssubl v2.8h, v2.8b, v25.8b + add x22, x20, x24 + ld1 {v6.8h}, [x20], x23 + ld1 {v3.8b}, [x19], x5 + ssubl v3.8h, v3.8b, v25.8b + ld1 {v7.8h}, [x20], x23 + ld1 {v16.8b}, [x21], x5 + ssubl v16.8h, v16.8b, v25.8b + ld1 {v18.8h}, [x22], x23 + ld1 {v17.8b}, [x21], x5 + ssubl v17.8h, v17.8b, v25.8b + ld1 {v19.8h}, [x22], x23 + + cmp x6, #8 + ble LoopC8Post + + LoopC8: + add x1, x1, #8 + add x2, x2, #16 + smlal v23.4s, v0.4h, v4.4h + smlal2 v24.4s, v0.8h, v4.8h + mov x12, x1 + mov x13, x2 + ld1 {v0.8b}, [x12], x5 + ssubl v0.8h, v0.8b, v25.8b + ld1 {v4.8h}, [x13], x23 // weight + add x19, x1, x4 + smlal v23.4s, v1.4h, v5.4h + smlal2 v24.4s, v1.8h, v5.8h + + add x20, x2, x24 + ld1 {v1.8b}, [x12], x5 + ssubl v1.8h, v1.8b, v25.8b + smlal v23.4s, v2.4h, v6.4h + ld1 {v5.8h}, [x13], x23 + smlal2 v24.4s, v2.8h, v6.8h + + add x21, x19, x4 + add x22, x20, x24 + ld1 {v2.8b}, [x19], x5 + ssubl v2.8h, v2.8b, v25.8b + smlal v23.4s, v3.4h, v7.4h + ld1 {v6.8h}, [x20], x23 + smlal2 v24.4s, v3.8h, v7.8h + + ld1 {v3.8b}, [x19], x5 + ssubl v3.8h, v3.8b, v25.8b + smlal v23.4s, v16.4h, v18.4h + ld1 {v7.8h}, [x20], x23 + smlal2 v24.4s, v16.8h, v18.8h + + ld1 {v16.8b}, [x21], x5 + ssubl v16.8h, v16.8b, v25.8b + smlal v23.4s, v17.4h, v19.4h + ld1 {v18.8h}, [x22], x23 + smlal2 v24.4s, v17.8h, v19.8h + ld1 {v17.8b}, [x21], x5 + ssubl v17.8h, v17.8b, v25.8b + ld1 {v19.8h}, [x22], x23 + + cbnz x14, PerChannelPostLoop + ldr w8, [x10] + cbz w8, RightShiftLoop + sqshl v23.4s, v23.4s, v28.4s + sqshl v24.4s, v24.4s, v28.4s + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + b AddZpLoop + + RightShiftLoop: + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + sqrshl v23.4s, v23.4s, v29.4s + sqrshl v24.4s, v24.4s, v29.4s + b AddZpLoop + PerChannelPostLoop: + sqshl v23.4s, v23.4s, v28.4s + ld1 {v28.4s}, [x10], #16 + sqrdmulh v23.4s, v23.4s, v27.4s + ld1 {v27.4s}, [x9], #16 + sqrshl v23.4s, v23.4s, v29.4s + ld1 {v29.4s}, [x11], #16 + sqshl v24.4s, v24.4s, v28.4s + ld1 {v28.4s}, [x10], #16 + sqrdmulh v24.4s, v24.4s, v27.4s + ld1 {v27.4s}, [x9], #16 + sqrshl v24.4s, v24.4s, v29.4s + ld1 {v29.4s}, [x11], #16 + + AddZpLoop: + add v23.4s, v23.4s, v26.4s + add v24.4s, v24.4s, v26.4s + smax v23.4s, v23.4s, v30.4s + smax v24.4s, v24.4s, v30.4s + smin v23.4s, v23.4s, v31.4s + smin v24.4s, v24.4s, v31.4s + + sqxtn v23.4h, v23.4s + sqxtn v24.4h, v24.4s + sqxtn v23.8b, v23.8h + sqxtn v24.8b, v24.8h + + st1 {v23.s}[0], [x0], #4 + st1 {v24.s}[0], [x0], #4 + ld1 {v23.4s}, [x3], #16 + ld1 {v24.4s}, [x3], #16 + sub x6, x6, #8 + cmp x6, #8 + bgt LoopC8 + + LoopC8Post: + smlal v23.4s, v0.4h, v4.4h + smlal2 v24.4s, v0.8h, v4.8h + smlal v23.4s, v1.4h, v5.4h + smlal2 v24.4s, v1.8h, v5.8h + smlal v23.4s, v2.4h, v6.4h + smlal2 v24.4s, v2.8h, v6.8h + smlal v23.4s, v3.4h, v7.4h + smlal2 v24.4s, v3.8h, v7.8h + smlal v23.4s, v16.4h, v18.4h + smlal2 v24.4s, v16.8h, v18.8h + smlal v23.4s, v17.4h, v19.4h + smlal2 v24.4s, v17.8h, v19.8h + + cbnz x14, PerChannelPost + ldr w8, [x10] + cbz w8, RightShift + sqshl v23.4s, v23.4s, v28.4s + sqshl v24.4s, v24.4s, v28.4s + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + b AddZp + + RightShift: + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + sqrshl v23.4s, v23.4s, v29.4s + sqrshl v24.4s, v24.4s, v29.4s + b AddZp + PerChannelPost: + sqshl v23.4s, v23.4s, v28.4s + ld1 {v28.4s}, [x10], #16 + sqrdmulh v23.4s, v23.4s, v27.4s + ld1 {v27.4s}, [x9], #16 + sqrshl v23.4s, v23.4s, v29.4s + ld1 {v29.4s}, [x11], #16 + sqshl v24.4s, v24.4s, v28.4s + sqrdmulh v24.4s, v24.4s, v27.4s + sqrshl v24.4s, v24.4s, v29.4s + + AddZp: + add v23.4s, v23.4s, v26.4s + add v24.4s, v24.4s, v26.4s + smax v23.4s, v23.4s, v30.4s + smax v24.4s, v24.4s, v30.4s + smin v23.4s, v23.4s, v31.4s + smin v24.4s, v24.4s, v31.4s + + sqxtn v23.4h, v23.4s + sqxtn v24.4h, v24.4s + sqxtn v23.8b, v23.8h + sqxtn v24.8b, v24.8h + + st1 {v23.s}[0], [x0], #4 + st1 {v24.s}[0], [x0], #4 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDw3x3Int8Stride2.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDw3x3Int8Stride2.S new file mode 100644 index 0000000000000000000000000000000000000000..e3c12602ecad3abef9cd5125043ae4dd6056a9c0 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDw3x3Int8Stride2.S @@ -0,0 +1,474 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void ConvDw3x3Int8Stride2(int8_t *output, const int8_t *input, const int16_t *weight, const int32_t *bias, int input_col_size, +// int input_row_size, int channel, int output_h, int output_w, int8_t in_zp, int32_t out_zp, +// int *out_multiplier, int *left_shift, int *right_shift, int32_t acc_min, int32_t acc_max +// size_t per_channel) +// +// x0: output +// x1: input +// x2: weight +// x3: bias +// w4: col_size +// w5: row_size +// w6: channel +// w7: output_h +// w8: output_w +// w9: in_zp +// w10: out_zp +// w11: out_multiplier +// w12: left_shift +// w13: right_shift +// w14: acc_min +// w15: acc_max +// w16: per_channel + +asm_function ConvDw3x3Int8Stride2 + sub sp, sp, #192 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + stp x23, x24, [sp, #160] + stp x25, x26, [sp, #176] + + ldr x8, [sp, #192] + ldr x9, [sp, #200] + ldr x10, [sp, #208] + ldr x11, [sp, #216] + ldr x12, [sp, #224] + ldr x13, [sp, #232] + ldr x14, [sp, #240] + ldr x15, [sp, #248] + ldr x23, [sp, #256] // per_channel + + add x19, x3, #16 + add w20, w6, w6 // channel * 2 + dup v28.8b, w9 // in_zp + ldr w24, [x12] + + dup v29.4s, w10 + dup v30.4s, w14 + dup v31.4s, w15 + // Load weights + ld1 {v0.8h}, [x2], x20 + ld1 {v1.8h}, [x2], x20 + ld1 {v2.8h}, [x2], x20 + ld1 {v3.8h}, [x2], x20 + ld1 {v4.8h}, [x2], x20 + ld1 {v5.8h}, [x2], x20 + ld1 {v6.8h}, [x2], x20 + ld1 {v7.8h}, [x2], x20 + ld1 {v8.8h}, [x2], x20 + + mov x16, x1 + add x17, x16, x5 + add x25, x17, x5 + ld1 {v9.8b}, [x16], x4 + ld1 {v10.8b}, [x16], x4 + ssubl v9.8h, v9.8b, v28.8b + ld1 {v11.8b}, [x16], x4 + ssubl v10.8h, v10.8b, v28.8b + ld1 {v14.8b}, [x17], x4 + ssubl v11.8h, v11.8b, v28.8b + ld1 {v15.8b}, [x17], x4 + ssubl v14.8h, v14.8b, v28.8b + ld1 {v16.8b}, [x17], x4 + ssubl v15.8h, v15.8b, v28.8b + ld1 {v19.8b}, [x25], x4 + ssubl v16.8h, v16.8b, v28.8b + ld1 {v20.8b}, [x25], x4 + ssubl v19.8h, v19.8b, v28.8b + ld1 {v21.8b}, [x25], x4 + ssubl v20.8h, v20.8b, v28.8b + ssubl v21.8h, v21.8b, v28.8b + + ld1 {v24.4s}, [x3] + ld1 {v25.4s}, [x19] + ld1 {v26.4s}, [x3] + ld1 {v27.4s}, [x19] + + cmp w8, #2 + beq WIDTH2_LEFT + cmp w8, #1 + beq WIDTH1_LEFT + +HEIGHT1_LOOP: + smlal v24.4s, v0.4h, v9.4h + ld1 {v12.8b}, [x16], x4 + smlal2 v25.4s, v0.8h, v9.8h + ld1 {v17.8b}, [x17], x4 + ssubl v12.8h, v12.8b, v28.8b + smlal v26.4s, v0.4h, v11.4h + ld1 {v22.8b}, [x25], x4 + ssubl v17.8h, v17.8b, v28.8b + smlal2 v27.4s, v0.8h, v11.8h + ld1 {v13.8b}, [x16], x4 + ssubl v22.8h, v22.8b, v28.8b + smlal v24.4s, v1.4h, v10.4h + ld1 {v18.8b}, [x17], x4 + ssubl v13.8h, v13.8b, v28.8b + smlal2 v25.4s, v1.8h, v10.8h + ld1 {v23.8b}, [x25], x4 + ssubl v18.8h, v18.8b, v28.8b + smlal v26.4s, v1.4h, v12.4h + mov v9.16b, v13.16b + ssubl v23.8h, v23.8b, v28.8b + smlal2 v27.4s, v1.8h, v12.8h + ld1 {v10.8b}, [x16], x4 + smlal v24.4s, v2.4h, v11.4h + smlal2 v25.4s, v2.8h, v11.8h + ld1 {v11.8b}, [x16], x4 + ssubl v10.8h, v10.8b, v28.8b + smlal v26.4s, v2.4h, v13.4h + ssubl v11.8h, v11.8b, v28.8b + smlal2 v27.4s, v2.8h, v13.8h + + smlal v24.4s, v3.4h, v14.4h + smlal2 v25.4s, v3.8h, v14.8h + mov v14.16b, v18.16b + smlal v26.4s, v3.4h, v16.4h + smlal2 v27.4s, v3.8h, v16.8h + smlal v24.4s, v4.4h, v15.4h + smlal2 v25.4s, v4.8h, v15.8h + ld1 {v15.8b}, [x17], x4 + smlal v26.4s, v4.4h, v17.4h + smlal2 v27.4s, v4.8h, v17.8h + smlal v24.4s, v5.4h, v16.4h + smlal2 v25.4s, v5.8h, v16.8h + ld1 {v16.8b}, [x17], x4 + ssubl v15.8h, v15.8b, v28.8b + smlal v26.4s, v5.4h, v18.4h + ssubl v16.8h, v16.8b, v28.8b + smlal2 v27.4s, v5.8h, v18.8h + + smlal v24.4s, v6.4h, v19.4h + smlal2 v25.4s, v6.8h, v19.8h + mov v19.16b, v23.16b + smlal v26.4s, v6.4h, v21.4h + smlal2 v27.4s, v6.8h, v21.8h + smlal v24.4s, v7.4h, v20.4h + smlal2 v25.4s, v7.8h, v20.8h + ld1 {v20.8b}, [x25], x4 + smlal v26.4s, v7.4h, v22.4h + smlal2 v27.4s, v7.8h, v22.8h + smlal v24.4s, v8.4h, v21.4h + smlal2 v25.4s, v8.8h, v21.8h + ld1 {v21.8b}, [x25], x4 + ssubl v20.8h, v20.8b, v28.8b + smlal v26.4s, v8.4h, v23.4h + ssubl v21.8h, v21.8b, v28.8b + smlal2 v27.4s, v8.8h, v23.8h + + cbnz w23, PER_CHANNEL_POST1 + ld1r {v17.4s}, [x11] + cbz w24, SKIP_LEFTSHIFT1 + ld1r {v12.4s}, [x12] + sqshl v24.4s, v24.4s, v12.4s + sqshl v25.4s, v25.4s, v12.4s + sqshl v26.4s, v26.4s, v12.4s + sqshl v27.4s, v27.4s, v12.4s + sqrdmulh v24.4s, v24.4s, v17.4s + sqrdmulh v25.4s, v25.4s, v17.4s + sqrdmulh v26.4s, v26.4s, v17.4s + sqrdmulh v27.4s, v27.4s, v17.4s + b OUTZP1 + +SKIP_LEFTSHIFT1: + ld1r {v22.4s}, [x13] + sqrdmulh v24.4s, v24.4s, v17.4s + sqrdmulh v25.4s, v25.4s, v17.4s + sqrdmulh v26.4s, v26.4s, v17.4s + sqrdmulh v27.4s, v27.4s, v17.4s + sqrshl v24.4s, v24.4s, v22.4s + sqrshl v25.4s, v25.4s, v22.4s + sqrshl v26.4s, v26.4s, v22.4s + sqrshl v27.4s, v27.4s, v22.4s + b OUTZP1 + +PER_CHANNEL_POST1: + ld1 {v12.4s}, [x12] + sqshl v24.4s, v24.4s, v12.4s + ldr q13, [x12, #16] + sqshl v25.4s, v25.4s, v13.4s + ld1 {v17.4s}, [x11] + sqshl v26.4s, v26.4s, v12.4s + sqshl v27.4s, v27.4s, v13.4s + ldr q18, [x11, #16] + sqrdmulh v24.4s, v24.4s, v17.4s + sqrdmulh v25.4s, v25.4s, v18.4s + ld1 {v22.4s}, [x13] + sqrdmulh v26.4s, v26.4s, v17.4s + sqrdmulh v27.4s, v27.4s, v18.4s + ldr q23, [x13, #16] + sqrshl v24.4s, v24.4s, v22.4s + sqrshl v25.4s, v25.4s, v23.4s + sqrshl v26.4s, v26.4s, v22.4s + sqrshl v27.4s, v27.4s, v23.4s + +OUTZP1: + // Add output zero point + sqadd v24.4s, v24.4s, v29.4s + sqadd v25.4s, v25.4s, v29.4s + sqadd v26.4s, v26.4s, v29.4s + sqadd v27.4s, v27.4s, v29.4s + + // Apply min bound + smax v24.4s, v24.4s, v30.4s + smax v25.4s, v25.4s, v30.4s + smax v26.4s, v26.4s, v30.4s + smax v27.4s, v27.4s, v30.4s + + // Apply max bound + smin v24.4s, v24.4s, v31.4s + smin v25.4s, v25.4s, v31.4s + smin v26.4s, v26.4s, v31.4s + smin v27.4s, v27.4s, v31.4s + + sqxtn v24.4h, v24.4s + sqxtn2 v24.8h, v25.4s + ld1 {v25.4s}, [x19] + sqxtn v26.4h, v26.4s + sqxtn2 v26.8h, v27.4s + ld1 {v27.4s}, [x19] + sqxtn v24.8b, v24.8h + sqxtn2 v24.16b, v26.8h + st1 {v24.8b}, [x0], x6 + mov v26.d[0], v24.d[1] + ld1 {v24.4s}, [x3] + st1 {v26.8b}, [x0], x6 + ld1 {v26.4s}, [x3] + sub w8, w8, #2 + cmp w8, #2 + bgt HEIGHT1_LOOP + + cmp w8, #2 + blt WIDTH1_LEFT + +WIDTH2_LEFT: + smlal v24.4s, v0.4h, v9.4h + ld1 {v12.8b}, [x16], x4 + smlal2 v25.4s, v0.8h, v9.8h + ld1 {v17.8b}, [x17], x4 + ssubl v12.8h, v12.8b, v28.8b + smlal v26.4s, v0.4h, v11.4h + ld1 {v22.8b}, [x25], x4 + ssubl v17.8h, v17.8b, v28.8b + smlal2 v27.4s, v0.8h, v11.8h + ld1 {v13.8b}, [x16], x4 + ssubl v22.8h, v22.8b, v28.8b + smlal v24.4s, v1.4h, v10.4h + ld1 {v18.8b}, [x17], x4 + ssubl v13.8h, v13.8b, v28.8b + smlal2 v25.4s, v1.8h, v10.8h + ld1 {v23.8b}, [x25], x4 + ssubl v18.8h, v18.8b, v28.8b + smlal v26.4s, v1.4h, v12.4h + ssubl v23.8h, v23.8b, v28.8b + smlal2 v27.4s, v1.8h, v12.8h + smlal v24.4s, v2.4h, v11.4h + smlal2 v25.4s, v2.8h, v11.8h + smlal v26.4s, v2.4h, v13.4h + smlal2 v27.4s, v2.8h, v13.8h + + smlal v24.4s, v3.4h, v14.4h + smlal2 v25.4s, v3.8h, v14.8h + smlal v26.4s, v3.4h, v16.4h + smlal2 v27.4s, v3.8h, v16.8h + smlal v24.4s, v4.4h, v15.4h + smlal2 v25.4s, v4.8h, v15.8h + smlal v26.4s, v4.4h, v17.4h + smlal2 v27.4s, v4.8h, v17.8h + smlal v24.4s, v5.4h, v16.4h + smlal2 v25.4s, v5.8h, v16.8h + smlal v26.4s, v5.4h, v18.4h + smlal2 v27.4s, v5.8h, v18.8h + + smlal v24.4s, v6.4h, v19.4h + smlal2 v25.4s, v6.8h, v19.8h + smlal v26.4s, v6.4h, v21.4h + smlal2 v27.4s, v6.8h, v21.8h + smlal v24.4s, v7.4h, v20.4h + smlal2 v25.4s, v7.8h, v20.8h + smlal v26.4s, v7.4h, v22.4h + smlal2 v27.4s, v7.8h, v22.8h + smlal v24.4s, v8.4h, v21.4h + smlal2 v25.4s, v8.8h, v21.8h + smlal v26.4s, v8.4h, v23.4h + smlal2 v27.4s, v8.8h, v23.8h + + cbnz w23, PER_CHANNEL_POST2 + ld1r {v17.4s}, [x11] + cbz w24, SKIP_LEFTSHIFT2 + ld1r {v12.4s}, [x12] + sqshl v24.4s, v24.4s, v12.4s + sqshl v25.4s, v25.4s, v12.4s + sqshl v26.4s, v26.4s, v12.4s + sqshl v27.4s, v27.4s, v12.4s + sqrdmulh v24.4s, v24.4s, v17.4s + sqrdmulh v25.4s, v25.4s, v17.4s + sqrdmulh v26.4s, v26.4s, v17.4s + sqrdmulh v27.4s, v27.4s, v17.4s + b OUTZP2 + +SKIP_LEFTSHIFT2: + ld1r {v22.4s}, [x13] + sqrdmulh v24.4s, v24.4s, v17.4s + sqrdmulh v25.4s, v25.4s, v17.4s + sqrdmulh v26.4s, v26.4s, v17.4s + sqrdmulh v27.4s, v27.4s, v17.4s + sqrshl v24.4s, v24.4s, v22.4s + sqrshl v25.4s, v25.4s, v22.4s + sqrshl v26.4s, v26.4s, v22.4s + sqrshl v27.4s, v27.4s, v22.4s + b OUTZP2 + +PER_CHANNEL_POST2: + ld1 {v12.4s}, [x12] + sqshl v24.4s, v24.4s, v12.4s + ldr q13, [x12, #16] + sqshl v25.4s, v25.4s, v13.4s + ld1 {v17.4s}, [x11] + sqshl v26.4s, v26.4s, v12.4s + sqshl v27.4s, v27.4s, v13.4s + ldr q18, [x11, #16] + sqrdmulh v24.4s, v24.4s, v17.4s + sqrdmulh v25.4s, v25.4s, v18.4s + ld1 {v22.4s}, [x13] + sqrdmulh v26.4s, v26.4s, v17.4s + sqrdmulh v27.4s, v27.4s, v18.4s + ldr q23, [x13, #16] + sqrshl v24.4s, v24.4s, v22.4s + sqrshl v25.4s, v25.4s, v23.4s + sqrshl v26.4s, v26.4s, v22.4s + sqrshl v27.4s, v27.4s, v23.4s + +OUTZP2: + // Add output zero point + sqadd v24.4s, v24.4s, v29.4s + sqadd v25.4s, v25.4s, v29.4s + sqadd v26.4s, v26.4s, v29.4s + sqadd v27.4s, v27.4s, v29.4s + + // Apply min bound + smax v24.4s, v24.4s, v30.4s + smax v25.4s, v25.4s, v30.4s + smax v26.4s, v26.4s, v30.4s + smax v27.4s, v27.4s, v30.4s + + // Apply max bound + smin v24.4s, v24.4s, v31.4s + smin v25.4s, v25.4s, v31.4s + smin v26.4s, v26.4s, v31.4s + smin v27.4s, v27.4s, v31.4s + + sqxtn v24.4h, v24.4s + sqxtn2 v24.8h, v25.4s + sqxtn v26.4h, v26.4s + sqxtn2 v26.8h, v27.4s + sqxtn v24.8b, v24.8h + sqxtn2 v24.16b, v26.8h + st1 {v24.8b}, [x0], x6 + mov v26.d[0], v24.d[1] + st1 {v26.8b}, [x0], x6 + b End + +WIDTH1_LEFT: + smlal v24.4s, v0.4h, v9.4h + smlal2 v25.4s, v0.8h, v9.8h + smlal v24.4s, v1.4h, v10.4h + smlal2 v25.4s, v1.8h, v10.8h + smlal v24.4s, v2.4h, v11.4h + smlal2 v25.4s, v2.8h, v11.8h + smlal v24.4s, v3.4h, v14.4h + smlal2 v25.4s, v3.8h, v14.8h + smlal v24.4s, v4.4h, v15.4h + smlal2 v25.4s, v4.8h, v15.8h + smlal v24.4s, v5.4h, v16.4h + smlal2 v25.4s, v5.8h, v16.8h + smlal v24.4s, v6.4h, v19.4h + smlal2 v25.4s, v6.8h, v19.8h + smlal v24.4s, v7.4h, v20.4h + smlal2 v25.4s, v7.8h, v20.8h + smlal v24.4s, v8.4h, v21.4h + smlal2 v25.4s, v8.8h, v21.8h + + cbnz w23, PER_CHANNEL_POST3 + ld1r {v17.4s}, [x11] + cbz w24, SKIP_LEFTSHIFT3 + ld1r {v12.4s}, [x12] + sqshl v24.4s, v24.4s, v12.4s + sqshl v25.4s, v25.4s, v12.4s + sqrdmulh v24.4s, v24.4s, v17.4s + sqrdmulh v25.4s, v25.4s, v17.4s + b OUTZP3 + +SKIP_LEFTSHIFT3: + ld1r {v22.4s}, [x13] + sqrdmulh v24.4s, v24.4s, v17.4s + sqrdmulh v25.4s, v25.4s, v17.4s + sqrshl v24.4s, v24.4s, v22.4s + sqrshl v25.4s, v25.4s, v22.4s + b OUTZP3 + +PER_CHANNEL_POST3: + ld1 {v12.4s}, [x12] + sqshl v24.4s, v24.4s, v12.4s + ldr q13, [x12, #16] + sqshl v25.4s, v25.4s, v13.4s + ld1 {v17.4s}, [x11] + ldr q18, [x11, #16] + sqrdmulh v24.4s, v24.4s, v17.4s + sqrdmulh v25.4s, v25.4s, v18.4s + ld1 {v22.4s}, [x13] + ldr q23, [x13, #16] + sqrshl v24.4s, v24.4s, v22.4s + sqrshl v25.4s, v25.4s, v23.4s + +OUTZP3: + // Add output zero point + sqadd v24.4s, v24.4s, v29.4s + sqadd v25.4s, v25.4s, v29.4s + + // Apply min bound + smax v24.4s, v24.4s, v30.4s + smax v25.4s, v25.4s, v30.4s + + // Apply max bound + smin v24.4s, v24.4s, v31.4s + smin v25.4s, v25.4s, v31.4s + + sqxtn v24.4h, v24.4s + sqxtn2 v24.8h, v25.4s + sqxtn v24.8b, v24.8h + st1 {v24.8b}, [x0], x6 + +End: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDw3x3Int8Vertical.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDw3x3Int8Vertical.S new file mode 100644 index 0000000000000000000000000000000000000000..706bc9fe811fe406c58d35209858639bd11cdcbc --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDw3x3Int8Vertical.S @@ -0,0 +1,245 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void ConvDw3x3Int8Vertical(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t in_kh_step, +// size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, int32_t *out_multiplier, +// int32_t *left_shift, int32_t *right_shift, size_t acc_min, size_t acc_max, size_t per_channel) + +// x0: dst, x1: src, x2: weight, x3: bias, x4: in_kh_step, x5: in_kw_step, +// x6: channel, x7: in_zp, x8: out_zp, x9: out_multiplier, x10: left_shift, x11: right_shift +// x12: acc_min, x13: acc_max, x14: per_channel +asm_function ConvDw3x3Int8Vertical + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #32 + stp x19, x20, [sp] + stp x21, x22, [sp, #16] + + dup v25.8b, w7 // in_zp + ldr x8, [sp, #32] + dup v26.4s, w8 // out_zp + ldr x9, [sp, #40] // out_multiplier + ldr x10, [sp, #48] // left_shift + ldr x11, [sp, #56] // right_shift + ldr x12, [sp, #64] + dup v30.4s, w12 // acc_min + ldr x13, [sp, #72] + dup v31.4s, w13 // acc_max + ldr x14, [sp, #80] // per_channel + cbnz x14, PerChannelDump + PerLayerDump: + ld1r {v27.4s}, [x9] + ld1r {v28.4s}, [x10] + ld1r {v29.4s}, [x11] + b ContinueFunc + PerChannelDump: + ld1 {v27.4s}, [x9], #16 + ld1 {v28.4s}, [x10], #16 + ld1 {v29.4s}, [x11], #16 + ContinueFunc: + + mov x12, #2 + mul x21, x6, x12 // x6 * 2 + mov x12, #3 + mul x22, x21, x12 // x6 * 3 * 2 + + ld1 {v23.4s}, [x3], #16 + ld1 {v24.4s}, [x3], #16 + mov x12, x1 + mov x13, x2 + + ld1 {v0.8b}, [x12], x5 + ssubl v0.8h, v0.8b, v25.8b + add x19, x1, x4 + ld1 {v4.8h}, [x13], x21 // weight + add x20, x2, x22 + ld1 {v1.8b}, [x12], x5 + ssubl v1.8h, v1.8b, v25.8b + ld1 {v5.8h}, [x13], x21 + ld1 {v2.8b}, [x19], x5 + ssubl v2.8h, v2.8b, v25.8b + ld1 {v6.8h}, [x20], x21 + ld1 {v3.8b}, [x19], x5 + ssubl v3.8h, v3.8b, v25.8b + ld1 {v7.8h}, [x20], x21 + ld1 {v16.8b}, [x12], x5 + ssubl v16.8h, v16.8b, v25.8b + ld1 {v18.8h}, [x13], x21 + ld1 {v17.8b}, [x19], x5 + ssubl v17.8h, v17.8b, v25.8b + ld1 {v19.8h}, [x20], x21 + + cmp x6, #8 + ble LoopC8Post + + LoopC8: + add x1, x1, #8 + add x2, x2, #16 + smlal v23.4s, v0.4h, v4.4h + smlal2 v24.4s, v0.8h, v4.8h + mov x12, x1 + mov x13, x2 + ld1 {v0.8b}, [x12], x5 + ssubl v0.8h, v0.8b, v25.8b + ld1 {v4.8h}, [x13], x21 // weight + add x19, x1, x4 + smlal v23.4s, v1.4h, v5.4h + smlal2 v24.4s, v1.8h, v5.8h + + add x20, x2, x22 + ld1 {v1.8b}, [x12], x5 + ssubl v1.8h, v1.8b, v25.8b + smlal v23.4s, v2.4h, v6.4h + ld1 {v5.8h}, [x13], x21 + smlal2 v24.4s, v2.8h, v6.8h + + ld1 {v2.8b}, [x19], x5 + ssubl v2.8h, v2.8b, v25.8b + smlal v23.4s, v3.4h, v7.4h + ld1 {v6.8h}, [x20], x21 + smlal2 v24.4s, v3.8h, v7.8h + + ld1 {v3.8b}, [x19], x5 + ssubl v3.8h, v3.8b, v25.8b + smlal v23.4s, v16.4h, v18.4h + ld1 {v7.8h}, [x20], x21 + smlal2 v24.4s, v16.8h, v18.8h + + ld1 {v16.8b}, [x12], x5 + ssubl v16.8h, v16.8b, v25.8b + smlal v23.4s, v17.4h, v19.4h + ld1 {v18.8h}, [x13], x21 + smlal2 v24.4s, v17.8h, v19.8h + ld1 {v17.8b}, [x19], x5 + ssubl v17.8h, v17.8b, v25.8b + ld1 {v19.8h}, [x20], x21 + + cbnz x14, PerChannelPostLoop + ldr w8, [x10] + cbz w8, RightShiftLoop + sqshl v23.4s, v23.4s, v28.4s + sqshl v24.4s, v24.4s, v28.4s + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + b AddZpLoop + + RightShiftLoop: + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + sqrshl v23.4s, v23.4s, v29.4s + sqrshl v24.4s, v24.4s, v29.4s + b AddZpLoop + PerChannelPostLoop: + sqshl v23.4s, v23.4s, v28.4s + ld1 {v28.4s}, [x10], #16 + sqrdmulh v23.4s, v23.4s, v27.4s + ld1 {v27.4s}, [x9], #16 + sqrshl v23.4s, v23.4s, v29.4s + ld1 {v29.4s}, [x11], #16 + sqshl v24.4s, v24.4s, v28.4s + ld1 {v28.4s}, [x10], #16 + sqrdmulh v24.4s, v24.4s, v27.4s + ld1 {v27.4s}, [x9], #16 + sqrshl v24.4s, v24.4s, v29.4s + ld1 {v29.4s}, [x11], #16 + + AddZpLoop: + add v23.4s, v23.4s, v26.4s + add v24.4s, v24.4s, v26.4s + smax v23.4s, v23.4s, v30.4s + smax v24.4s, v24.4s, v30.4s + smin v23.4s, v23.4s, v31.4s + smin v24.4s, v24.4s, v31.4s + + sqxtn v23.4h, v23.4s + sqxtn v24.4h, v24.4s + sqxtn v23.8b, v23.8h + sqxtn v24.8b, v24.8h + + st1 {v23.s}[0], [x0], #4 + st1 {v24.s}[0], [x0], #4 + ld1 {v23.4s}, [x3], #16 + ld1 {v24.4s}, [x3], #16 + sub x6, x6, #8 + cmp x6, #8 + bgt LoopC8 + + LoopC8Post: + smlal v23.4s, v0.4h, v4.4h + smlal2 v24.4s, v0.8h, v4.8h + smlal v23.4s, v1.4h, v5.4h + smlal2 v24.4s, v1.8h, v5.8h + smlal v23.4s, v2.4h, v6.4h + smlal2 v24.4s, v2.8h, v6.8h + smlal v23.4s, v3.4h, v7.4h + smlal2 v24.4s, v3.8h, v7.8h + smlal v23.4s, v16.4h, v18.4h + smlal2 v24.4s, v16.8h, v18.8h + smlal v23.4s, v17.4h, v19.4h + smlal2 v24.4s, v17.8h, v19.8h + + cbnz x14, PerChannelPost + ldr w8, [x10] + cbz w8, RightShift + sqshl v23.4s, v23.4s, v28.4s + sqshl v24.4s, v24.4s, v28.4s + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + b AddZp + + RightShift: + sqrdmulh v23.4s, v23.4s, v27.4s + sqrdmulh v24.4s, v24.4s, v27.4s + sqrshl v23.4s, v23.4s, v29.4s + sqrshl v24.4s, v24.4s, v29.4s + b AddZp + PerChannelPost: + sqshl v23.4s, v23.4s, v28.4s + ld1 {v28.4s}, [x10], #16 + sqrdmulh v23.4s, v23.4s, v27.4s + ld1 {v27.4s}, [x9], #16 + sqrshl v23.4s, v23.4s, v29.4s + ld1 {v29.4s}, [x11], #16 + sqshl v24.4s, v24.4s, v28.4s + sqrdmulh v24.4s, v24.4s, v27.4s + sqrshl v24.4s, v24.4s, v29.4s + + AddZp: + add v23.4s, v23.4s, v26.4s + add v24.4s, v24.4s, v26.4s + smax v23.4s, v23.4s, v30.4s + smax v24.4s, v24.4s, v30.4s + smin v23.4s, v23.4s, v31.4s + smin v24.4s, v24.4s, v31.4s + + sqxtn v23.4h, v23.4s + sqxtn v24.4h, v24.4s + sqxtn v23.8b, v23.8h + sqxtn v24.8b, v24.8h + + st1 {v23.s}[0], [x0], #4 + st1 {v24.s}[0], [x0], #4 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDw3x3Line.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDw3x3Line.S new file mode 100644 index 0000000000000000000000000000000000000000..4e2c5bca4a1199cf60073b48c2dd9d900971a59b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDw3x3Line.S @@ -0,0 +1,203 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void ConvDw3x3Line(float *dst, float **lines, const float *weight, const float *bias_data, int width, int ori_channel, +// bool relu, bool relu6) + +// x0: dst, x1: lines, x2: weight, x3: bias, x4: width, x5: ori_channel, x6: relu, x7: relu6 +asm_function ConvDw3x3Line + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #128 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + + ldr x8, [x1] + ldr x9, [x1, #8] + ldr x10, [x1, #16] + mov x11, x5 + mov x16, #4 + mul x16, x5, x16 + + mov w14, #6 + dup v30.4s, w14 + scvtf v30.4s, v30.4s + + LoopC4: + cbz x3, NoBias + ld1 {v31.4s}, [x3], #16 + NoBias: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x2], #64 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x2], #64 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64 + mov x12, x0 + mov x13, x4 + + cmp x13, #2 + blt LoopOwRemain + LoopOw2: + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x8], #64 + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x9], #64 + ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x10], #64 + fmul v24.4s, v12.4s, v0.4s + fmul v25.4s, v13.4s, v1.4s + fmul v26.4s, v14.4s, v2.4s + fmul v27.4s, v15.4s, v3.4s + fmla v24.4s, v16.4s, v4.4s + fmla v25.4s, v17.4s, v5.4s + fmla v26.4s, v18.4s, v6.4s + fmla v27.4s, v19.4s, v7.4s + fmla v24.4s, v20.4s, v8.4s + fmla v25.4s, v21.4s, v9.4s + fmla v26.4s, v22.4s, v10.4s + fmla v27.4s, v23.4s, v11.4s + + fadd v28.4s, v25.4s, v26.4s + fadd v28.4s, v28.4s, v24.4s + fsub v29.4s, v27.4s, v26.4s + fadd v29.4s, v29.4s, v25.4s + + cbz x3, Activation + Bias: + fadd v28.4s, v28.4s, v31.4s + fadd v29.4s, v29.4s, v31.4s + + Activation: + cbnz x7, Relu6 + cbnz x6, Relu + b Write + Relu6: + fmin v28.4s, v28.4s, v30.4s + fmin v29.4s, v29.4s, v30.4s + Relu: + movi v27.16b, #0 + fmax v28.4s, v28.4s, v27.4s + fmax v29.4s, v29.4s, v27.4s + Write: + add x15, x12, x16 + cmp x11, #4 + bge Write4 + cmp x11, #3 + beq Write3 + cmp x11, #2 + beq Write2 + cmp x11, #1 + beq Write1 + + Write1: + str s28, [x12] + str s29, [x15] + b WriteEnd + Write2: + st1 {v28.2s}, [x12] + st1 {v29.2s}, [x15] + b WriteEnd + Write3: + st1 {v28.2s}, [x12] + add x17, x12, #8 + st1 {v28.s}[2], [x17] + st1 {v29.2s}, [x15] + add x18, x15, #8 + st1 {v29.s}[2], [x18] + b WriteEnd + Write4: + st1 {v28.4s}, [x12] + st1 {v29.4s}, [x15] + + WriteEnd: + add x12, x15, x16 + sub x13, x13, #2 + cmp x13, #2 + bge LoopOw2 + cmp x13, #0 + beq LoopOwEnd + + LoopOwRemain: + ld1 {v12.4s, v13.4s, v14.4s}, [x8] + add x8, x8, #64 + ld1 {v16.4s, v17.4s, v18.4s}, [x9] + add x9, x9, #64 + ld1 {v20.4s, v21.4s, v22.4s}, [x10] + add x10, x10, #64 + fmul v24.4s, v12.4s, v0.4s + fmul v25.4s, v13.4s, v1.4s + fmul v26.4s, v14.4s, v2.4s + + fmla v24.4s, v16.4s, v4.4s + fmla v25.4s, v17.4s, v5.4s + fmla v26.4s, v18.4s, v6.4s + + fmla v24.4s, v20.4s, v8.4s + fmla v25.4s, v21.4s, v9.4s + fmla v26.4s, v22.4s, v10.4s + + fadd v28.4s, v25.4s, v26.4s + fadd v28.4s, v28.4s, v24.4s + + cbz x3, ActivationRemain + BiasRemain: + fadd v28.4s, v28.4s, v31.4s + + ActivationRemain: + cbnz x7, Relu6Remain + cbnz x6, ReluRemain + b WriteRemain + Relu6Remain: + fmin v28.4s, v28.4s, v30.4s + ReluRemain: + movi v27.16b, #0 + fmax v28.4s, v28.4s, v27.4s + WriteRemain: + cmp x11, #4 + bge Write4Remain + cmp x11, #3 + beq Write3Remain + cmp x11, #2 + beq Write2Remain + cmp x11, #1 + beq Write1Remain + + Write1Remain: + str s28, [x12] + b LoopOwEnd + Write2Remain: + st1 {v28.2s}, [x12] + b LoopOwEnd + Write3Remain: + st1 {v28.2s}, [x12] + add x17, x12, #8 + st1 {v28.s}[2], [x17] + b LoopOwEnd + Write4Remain: + st1 {v28.4s}, [x12] + + LoopOwEnd: + subs x11, x11, #4 + add x0, x0, #16 + bgt LoopC4 + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDwFp32Border.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDwFp32Border.S new file mode 100644 index 0000000000000000000000000000000000000000..356dea125f5d4306d138a9069aee99d51f8d6379 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDwFp32Border.S @@ -0,0 +1,68 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, +// size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu, size_t relu6) + +// x0: dst, x1: src, x2: weight, x3: bias, x4: height, x5: width, x6: in_kh_step, x7: in_kw_step, +// x8: kernel_w, x9: relu, x10: relu6 +asm_function ConvDwFp32Border + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + ldr x8, [sp] + ldr x9, [sp, #8] + ldr x10, [sp, #16] + + ld1 {v0.4s}, [x3] // bias + movi v1.4s, #6 // relu 6 + scvtf v1.4s, v1.4s + dup v2.4s, wzr // relu + + mov x13, x1 + mov x14, x2 + LoopH: + mov x15, x13 + mov x16, x14 + mov x17, x5 + LoopW: + ld1 {v3.4s}, [x15], x7 + ld1 {v4.4s}, [x16], #16 + fmla v0.4s, v3.4s, v4.4s + subs x17, x17, #1 + bne LoopW + subs x4, x4, #1 + add x13, x13, x6 + add x14, x14, x8 + bne LoopH + cbnz x10, Relu6 + cbnz x9, Relu + b Write + Relu6: + fmin v0.4s, v0.4s, v1.4s + Relu: + fmax v0.4s, v0.4s, v2.4s + Write: + st1 {v0.4s}, [x0] + + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDwFp32Center.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDwFp32Center.S new file mode 100644 index 0000000000000000000000000000000000000000..6f30c3ac0392ad507d618224927b29ee25ddd806 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDwFp32Center.S @@ -0,0 +1,313 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, +// size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, +// size_t in_kh_step, size_t in_kw_step, size_t relu, size_t relu6); +// x0: dst, x1: src, x2: weight, x3: bias, x4: height, x5: width, x6: kernel_h, x7: kernel_w, +// x8: out_h_step, x9: block_channel, x10: in_sh_step, x11: in_sw_step, x12: in_kh_step, x13: in_kw_step +// x14: relu, x15: relu6 +asm_function ConvDwFp32Center + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #192 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + stp x23, x24, [sp, #160] + stp x25, x26, [sp, #176] + + ldr x8, [sp, #192] + ldr x9, [sp, #200] + ldr x10, [sp, #208] + ldr x11, [sp, #216] + ldr x12, [sp, #224] + ldr x13, [sp, #232] + ldr x14, [sp, #240] + ldr x15, [sp, #248] + + ld1 {v24.4s}, [x3] + movi v26.4s, #6 + scvtf v26.4s, v26.4s + dup v27.4s, wzr + + LoopH: + mov x23, x1 + mov x24, x5 + mov x3, x0 + cmp x24, #8 + blt LoopW + cmp x24, #16 + blt LoopW8 + + LoopW16: + mov x19, #16 + mul x19, x19, x11 + mov x16, x23 + mov x17, x2 + mov x20, x6 + mov v0.16b, v24.16b + mov v1.16b, v24.16b + mov v2.16b, v24.16b + mov v3.16b, v24.16b + mov v4.16b, v24.16b + mov v5.16b, v24.16b + mov v6.16b, v24.16b + mov v7.16b, v24.16b + mov v8.16b, v24.16b + mov v9.16b, v24.16b + mov v10.16b, v24.16b + mov v11.16b, v24.16b + mov v12.16b, v24.16b + mov v13.16b, v24.16b + mov v14.16b, v24.16b + mov v15.16b, v24.16b + LoopKh16: + mov x25, x7 + mov x21, x16 + LoopKw16: + mov x22, x21 + ld1 {v25.4s}, [x17], #16 + ld1 {v16.4s}, [x22], x11 + ld1 {v17.4s}, [x22], x11 + fmla v0.4s, v16.4s, v25.4s + fmla v1.4s, v17.4s, v25.4s + ld1 {v18.4s}, [x22], x11 + ld1 {v19.4s}, [x22], x11 + fmla v2.4s, v18.4s, v25.4s + fmla v3.4s, v19.4s, v25.4s + ld1 {v20.4s}, [x22], x11 + ld1 {v21.4s}, [x22], x11 + fmla v4.4s, v20.4s, v25.4s + fmla v5.4s, v21.4s, v25.4s + ld1 {v22.4s}, [x22], x11 + ld1 {v23.4s}, [x22], x11 + fmla v6.4s, v22.4s, v25.4s + fmla v7.4s, v23.4s, v25.4s + ld1 {v16.4s}, [x22], x11 + ld1 {v17.4s}, [x22], x11 + fmla v8.4s, v16.4s, v25.4s + fmla v9.4s, v17.4s, v25.4s + ld1 {v18.4s}, [x22], x11 + ld1 {v19.4s}, [x22], x11 + fmla v10.4s, v18.4s, v25.4s + fmla v11.4s, v19.4s, v25.4s + ld1 {v20.4s}, [x22], x11 + ld1 {v21.4s}, [x22], x11 + fmla v12.4s, v20.4s, v25.4s + fmla v13.4s, v21.4s, v25.4s + ld1 {v22.4s}, [x22], x11 + ld1 {v23.4s}, [x22], x11 + fmla v14.4s, v22.4s, v25.4s + fmla v15.4s, v23.4s, v25.4s + subs x25, x25, #1 + add x21, x21, x13 + bne LoopKw16 + add x16, x16, x12 + subs x20, x20, #1 + bne LoopKh16 + cbnz x15, Relu616 + cbnz x14, Relu16 + b Write16 + Relu616: + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmin v2.4s, v2.4s, v26.4s + fmin v3.4s, v3.4s, v26.4s + fmin v4.4s, v4.4s, v26.4s + fmin v5.4s, v5.4s, v26.4s + fmin v6.4s, v6.4s, v26.4s + fmin v7.4s, v7.4s, v26.4s + fmin v8.4s, v8.4s, v26.4s + fmin v9.4s, v9.4s, v26.4s + fmin v10.4s, v10.4s, v26.4s + fmin v11.4s, v11.4s, v26.4s + fmin v12.4s, v12.4s, v26.4s + fmin v13.4s, v13.4s, v26.4s + fmin v14.4s, v14.4s, v26.4s + fmin v15.4s, v15.4s, v26.4s + Relu16: + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + fmax v2.4s, v2.4s, v27.4s + fmax v3.4s, v3.4s, v27.4s + fmax v4.4s, v4.4s, v27.4s + fmax v5.4s, v5.4s, v27.4s + fmax v6.4s, v6.4s, v27.4s + fmax v7.4s, v7.4s, v27.4s + fmax v8.4s, v8.4s, v27.4s + fmax v9.4s, v9.4s, v27.4s + fmax v10.4s, v10.4s, v27.4s + fmax v11.4s, v11.4s, v27.4s + fmax v12.4s, v12.4s, v27.4s + fmax v13.4s, v13.4s, v27.4s + fmax v14.4s, v14.4s, v27.4s + fmax v15.4s, v15.4s, v27.4s + Write16: + st1 {v0.4s}, [x3], x9 + st1 {v1.4s}, [x3], x9 + st1 {v2.4s}, [x3], x9 + st1 {v3.4s}, [x3], x9 + st1 {v4.4s}, [x3], x9 + st1 {v5.4s}, [x3], x9 + st1 {v6.4s}, [x3], x9 + st1 {v7.4s}, [x3], x9 + st1 {v8.4s}, [x3], x9 + st1 {v9.4s}, [x3], x9 + st1 {v10.4s}, [x3], x9 + st1 {v11.4s}, [x3], x9 + st1 {v12.4s}, [x3], x9 + st1 {v13.4s}, [x3], x9 + st1 {v14.4s}, [x3], x9 + st1 {v15.4s}, [x3], x9 + add x23, x23, x19 + sub x24, x24, #16 + cmp x24, #0 + ble LoopWEnd + cmp x24, #8 + blt LoopW + cmp x24, #16 + bge LoopW16 + LoopW8: + mov x19, #8 + mul x19, x19, x11 + mov x16, x23 + mov x17, x2 + mov x20, x6 + mov v0.16b, v24.16b + mov v1.16b, v24.16b + mov v2.16b, v24.16b + mov v3.16b, v24.16b + mov v4.16b, v24.16b + mov v5.16b, v24.16b + mov v6.16b, v24.16b + mov v7.16b, v24.16b + LoopKh8: + mov x25, x7 + mov x21, x16 + LoopKw8: + mov x22, x21 + ld1 {v25.4s}, [x17], #16 + ld1 {v16.4s}, [x22], x11 + ld1 {v17.4s}, [x22], x11 + fmla v0.4s, v16.4s, v25.4s + fmla v1.4s, v17.4s, v25.4s + ld1 {v18.4s}, [x22], x11 + ld1 {v19.4s}, [x22], x11 + fmla v2.4s, v18.4s, v25.4s + fmla v3.4s, v19.4s, v25.4s + ld1 {v20.4s}, [x22], x11 + ld1 {v21.4s}, [x22], x11 + fmla v4.4s, v20.4s, v25.4s + fmla v5.4s, v21.4s, v25.4s + ld1 {v22.4s}, [x22], x11 + ld1 {v23.4s}, [x22], x11 + fmla v6.4s, v22.4s, v25.4s + fmla v7.4s, v23.4s, v25.4s + subs x25, x25, #1 + add x21, x21, x13 + bne LoopKw8 + add x16, x16, x12 + subs x20, x20, #1 + bne LoopKh8 + cbnz x15, Relu68 + cbnz x14, Relu8 + b Write8 + Relu68: + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmin v2.4s, v2.4s, v26.4s + fmin v3.4s, v3.4s, v26.4s + fmin v4.4s, v4.4s, v26.4s + fmin v5.4s, v5.4s, v26.4s + fmin v6.4s, v6.4s, v26.4s + fmin v7.4s, v7.4s, v26.4s + Relu8: + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + fmax v2.4s, v2.4s, v27.4s + fmax v3.4s, v3.4s, v27.4s + fmax v4.4s, v4.4s, v27.4s + fmax v5.4s, v5.4s, v27.4s + fmax v6.4s, v6.4s, v27.4s + fmax v7.4s, v7.4s, v27.4s + Write8: + st1 {v0.4s}, [x3], x9 + st1 {v1.4s}, [x3], x9 + st1 {v2.4s}, [x3], x9 + st1 {v3.4s}, [x3], x9 + st1 {v4.4s}, [x3], x9 + st1 {v5.4s}, [x3], x9 + st1 {v6.4s}, [x3], x9 + st1 {v7.4s}, [x3], x9 + add x23, x23, x19 + sub x24, x24, #8 + cmp x24, #0 + ble LoopWEnd + cmp x24, #8 + bge LoopW8 + LoopW: + mov x16, x23 + mov x17, x2 + mov x20, x6 + mov v0.16b, v24.16b + LoopKh: + mov x25, x7 + mov x22, x16 + LoopKw: + ld1 {v16.4s}, [x22], x13 + ld1 {v25.4s}, [x17], #16 + fmla v0.4s, v16.4s, v25.4s + subs x25, x25, #1 + bne LoopKw + add x16, x16, x12 + subs x20, x20, #1 + bne LoopKh + cbnz x15, Relu6 + cbnz x14, Relu + b Write + Relu6: + fmin v0.4s, v0.4s, v26.4s + Relu: + fmax v0.4s, v0.4s, v27.4s + Write: + st1 {v0.4s}, [x3], x9 + add x23, x23, x11 + subs x24, x24, #1 + bne LoopW + LoopWEnd: + add x0, x0, x8 + add x1, x1, x10 + subs x4, x4, #1 + bne LoopH + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDwFp32Indirect3x3.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDwFp32Indirect3x3.S new file mode 100644 index 0000000000000000000000000000000000000000..ca93dc7d9bc972d8ea4b4dcacf11758c96ddaca9 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDwFp32Indirect3x3.S @@ -0,0 +1,159 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void ConvDwFp32Indirect3x3(float *output, float **input, const float *weights, const float *bias, int channels, int output_width, +// size_t input_stride, size_t relu, size_t relu6) +// x0: output, x1: input, x2: weights, x3: bias, x4: channels, x5: output_width, x6: input_stride, x7: relu, x8: relu6 + +asm_function ConvDwFp32Indirect3x3 + sub sp, sp, #32 + stp x19, x20, [sp] + stp x21, x22, [sp, #16] + + movi v31.4s, #6 + scvtf v31.4s, v31.4s + dup v30.4s, wzr + + ldr x8, [sp, #32] + cmp x5, #0 + beq End + + LoopPixel: + ldp x12, x13, [x1] + ldp x14, x15, [x1, #16] + ldp x16, x17, [x1, #32] + ldp x21, x19, [x1, #48] + ldr x20, [x1, #64] + mov x9, x2 + mov x10, x3 + mov x11, x4 + + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x13], #16 + ld1 {v2.4s}, [x14], #16 + + ld1 {v17.4s}, [x9], #16 + ld1 {v18.4s}, [x9], #16 + ld1 {v19.4s}, [x9], #16 + + ld1 {v29.4s}, [x10], #16 + cmp x11, #4 + ble LeftLoop + LoopC4: + fmla v29.4s, v0.4s, v17.4s + ld1 {v3.4s}, [x15], #16 + ld1 {v20.4s}, [x9], #16 + fmla v29.4s, v1.4s, v18.4s + ld1 {v4.4s}, [x16], #16 + ld1 {v21.4s}, [x9], #16 + fmla v29.4s, v2.4s, v19.4s + ld1 {v5.4s}, [x17], #16 + ld1 {v22.4s}, [x9], #16 + fmla v29.4s, v3.4s, v20.4s + ld1 {v6.4s}, [x21], #16 + ld1 {v23.4s}, [x9], #16 + fmla v29.4s, v4.4s, v21.4s + ld1 {v7.4s}, [x19], #16 + ld1 {v24.4s}, [x9], #16 + fmla v29.4s, v5.4s, v22.4s + ld1 {v16.4s}, [x20], #16 + ld1 {v25.4s}, [x9], #16 + fmla v29.4s, v6.4s, v23.4s + ld1 {v0.4s}, [x12], #16 + ld1 {v17.4s}, [x9], #16 + fmla v29.4s, v7.4s, v24.4s + ld1 {v1.4s}, [x13], #16 + ld1 {v18.4s}, [x9], #16 + fmla v29.4s, v16.4s, v25.4s + ld1 {v2.4s}, [x14], #16 + ld1 {v19.4s}, [x9], #16 + + cbnz x8, Relu6 + cbnz x7, Relu + b Write + Relu6: + fmin v29.4s, v29.4s, v31.4s + Relu: + fmax v29.4s, v29.4s, v30.4s + Write: + st1 {v29.4s}, [x0], #16 + + ld1 {v29.4s}, [x10], #16 + sub x11, x11, #4 + cmp x11, #4 + bgt LoopC4 + + LeftLoop: + fmla v29.4s, v0.4s, v17.4s + ld1 {v3.4s}, [x15], #16 + ld1 {v20.4s}, [x9], #16 + fmla v29.4s, v1.4s, v18.4s + ld1 {v4.4s}, [x16], #16 + ld1 {v21.4s}, [x9], #16 + fmla v29.4s, v2.4s, v19.4s + ld1 {v5.4s}, [x17], #16 + ld1 {v22.4s}, [x9], #16 + fmla v29.4s, v3.4s, v20.4s + ld1 {v6.4s}, [x21], #16 + ld1 {v23.4s}, [x9], #16 + fmla v29.4s, v4.4s, v21.4s + ld1 {v7.4s}, [x19], #16 + ld1 {v24.4s}, [x9], #16 + fmla v29.4s, v5.4s, v22.4s + ld1 {v16.4s}, [x20], #16 + ld1 {v25.4s}, [x9], #16 + fmla v29.4s, v6.4s, v23.4s + fmla v29.4s, v7.4s, v24.4s + fmla v29.4s, v16.4s, v25.4s + + cbnz x8, LeftRelu6 + cbnz x7, LeftRelu + b LeftWrite + LeftRelu6: + fmin v29.4s, v29.4s, v31.4s + LeftRelu: + fmax v29.4s, v29.4s, v30.4s + LeftWrite: + cmp x11, #4 + bne Write3 + st1 {v29.4s}, [x0], #16 + b NextPixel + Write3: + sxtw x11, w11 + tbnz w11, #1, Write2 + tbnz w11, #0, Write1 + Write2: + st1 {v29.2s}, [x0], #8 + ext v29.16b, v29.16b, v29.16b, #8 + tbz w11, #0, NextPixel + Write1: + str s29, [x0], #4 + + NextPixel: + add x1, x1, x6 + sub x5, x5, #1 + cmp x5, #0 + bgt LoopPixel +End: + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 +ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDwFp32Indirect5x5.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDwFp32Indirect5x5.S new file mode 100644 index 0000000000000000000000000000000000000000..a84e6c79303ad1abfbe34221857376d5344a433f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDwFp32Indirect5x5.S @@ -0,0 +1,304 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void ConvDwFp32Indirect5x5(float *output, float **input, const float *weights, const float *bias, int channels, int output_width, +// size_t input_stride, size_t relu, size_t relu6) +// x0: output, x1: input, x2: weights, x3: bias, x4: channels, x5: output_width, x6: input_stride, x7: relu, x8: relu6 + +asm_function ConvDwFp32Indirect5x5 + sub sp, sp, #176 + stp x19, x20, [sp, #64] + stp x21, x22, [sp, #80] + stp x23, x24, [sp, #96] + stp x25, x26, [sp, #112] + stp x27, x28, [sp, #128] + stp x29, x30, [sp, #144] + ldrb w8, [sp, #176] + stp x2, x3, [sp] + stp x4, x6, [sp, #16] + stp x7, x8, [sp, #32] + stp x0, x1, [sp, #160] + + movi v31.4s, #6 + scvtf v31.4s, v31.4s + dup v30.4s, wzr + + mov x3, x5 + cmp x3, #0 + beq End + + LoopPixel: + ldp x5, x4, [sp] // weight, bias + ld1 {v29.4s}, [x4], #16 + ldr x2, [sp, #16] // channel + + ldp x6, x7, [x1] + ldp x8, x9, [x1, #16] + ldp x10, x11, [x1, #32] + ldp x12, x13, [x1, #48] + ldp x14, x15, [x1, #64] + ldp x16, x17, [x1, #80] + ldp x0, x19, [x1, #96] + ldp x20, x21, [x1, #112] + ldp x22, x23, [x1, #128] + ldp x24, x25, [x1, #144] + ldp x26, x27, [x1, #160] + ldp x28, x29, [x1, #176] + ldr x30, [x1, #192] + + ld1 {v0.4s}, [x6], #16 + ld1 {v1.4s}, [x7], #16 + ld1 {v2.4s}, [x8], #16 + ld1 {v3.4s}, [x9], #16 + ld1 {v4.4s}, [x10], #16 + + ld1 {v18.4s}, [x5], #16 + ld1 {v19.4s}, [x5], #16 + ld1 {v20.4s}, [x5], #16 + ld1 {v21.4s}, [x5], #16 + ld1 {v22.4s}, [x5], #16 + stp x5, x4, [sp, #48] + + cmp x2, #4 + ble LeftLoop + LoopC4: + ldr x5, [sp, #48] + // column 0 + fmla v29.4s, v0.4s, v18.4s + ld1 {v5.4s}, [x11], #16 + ld1 {v23.4s}, [x5], #16 + fmla v29.4s, v1.4s, v19.4s + ld1 {v6.4s}, [x12], #16 + ld1 {v24.4s}, [x5], #16 + fmla v29.4s, v2.4s, v20.4s + ld1 {v7.4s}, [x13], #16 + ld1 {v25.4s}, [x5], #16 + fmla v29.4s, v3.4s, v21.4s + ld1 {v16.4s}, [x14], #16 + ld1 {v26.4s}, [x5], #16 + fmla v29.4s, v4.4s, v22.4s + ld1 {v17.4s}, [x15], #16 + ld1 {v27.4s}, [x5], #16 + // column 1 + fmla v29.4s, v5.4s, v23.4s + ld1 {v0.4s}, [x16], #16 + ld1 {v18.4s}, [x5], #16 + fmla v29.4s, v6.4s, v24.4s + ld1 {v1.4s}, [x17], #16 + ld1 {v19.4s}, [x5], #16 + fmla v29.4s, v7.4s, v25.4s + ld1 {v2.4s}, [x0], #16 + ld1 {v20.4s}, [x5], #16 + fmla v29.4s, v16.4s, v26.4s + ld1 {v3.4s}, [x19], #16 + ld1 {v21.4s}, [x5], #16 + fmla v29.4s, v17.4s, v27.4s + ld1 {v4.4s}, [x20], #16 + ld1 {v22.4s}, [x5], #16 + // column 2 + fmla v29.4s, v0.4s, v18.4s + ld1 {v5.4s}, [x21], #16 + ld1 {v23.4s}, [x5], #16 + fmla v29.4s, v1.4s, v19.4s + ld1 {v6.4s}, [x22], #16 + ld1 {v24.4s}, [x5], #16 + fmla v29.4s, v2.4s, v20.4s + ld1 {v7.4s}, [x23], #16 + ld1 {v25.4s}, [x5], #16 + fmla v29.4s, v3.4s, v21.4s + ld1 {v16.4s}, [x24], #16 + ld1 {v26.4s}, [x5], #16 + fmla v29.4s, v4.4s, v22.4s + ld1 {v17.4s}, [x25], #16 + ld1 {v27.4s}, [x5], #16 + // column 3 + fmla v29.4s, v5.4s, v23.4s + ld1 {v0.4s}, [x26], #16 + ld1 {v18.4s}, [x5], #16 + fmla v29.4s, v6.4s, v24.4s + ld1 {v1.4s}, [x27], #16 + ld1 {v19.4s}, [x5], #16 + fmla v29.4s, v7.4s, v25.4s + ld1 {v2.4s}, [x28], #16 + ld1 {v20.4s}, [x5], #16 + fmla v29.4s, v16.4s, v26.4s + ld1 {v3.4s}, [x29], #16 + ld1 {v21.4s}, [x5], #16 + fmla v29.4s, v17.4s, v27.4s + ld1 {v4.4s}, [x30], #16 + ld1 {v22.4s}, [x5], #16 + // column 4 + fmla v29.4s, v0.4s, v18.4s + fmla v29.4s, v1.4s, v19.4s + ld1 {v0.4s}, [x6], #16 + ld1 {v18.4s}, [x5], #16 + fmla v29.4s, v2.4s, v20.4s + ld1 {v1.4s}, [x7], #16 + ld1 {v19.4s}, [x5], #16 + fmla v29.4s, v3.4s, v21.4s + ld1 {v2.4s}, [x8], #16 + ld1 {v20.4s}, [x5], #16 + fmla v29.4s, v4.4s, v22.4s + ld1 {v3.4s}, [x9], #16 + ld1 {v21.4s}, [x5], #16 + ld1 {v4.4s}, [x10], #16 + ld1 {v22.4s}, [x5], #16 + str x5, [sp, #48] + + ldp x4, x5, [sp, #32] + cbnz x5, RELU6 + cbnz x4, RELU + b WRITE + RELU6: + fmin v29.4s, v29.4s, v31.4s + RELU: + fmax v29.4s, v29.4s, v30.4s + WRITE: + ldr x4, [sp, #160] + st1 {v29.4s}, [x4], #16 + str x4, [sp, #160] + + ldr x4, [sp, #56] + ld1 {v29.4s}, [x4], #16 + str x4, [sp, #56] + sub x2, x2, #4 + cmp x2, #4 + bgt LoopC4 + + LeftLoop: + // column 0 + ldr x5, [sp, #48] + fmla v29.4s, v0.4s, v18.4s + ld1 {v5.4s}, [x11], #16 + ld1 {v23.4s}, [x5], #16 + fmla v29.4s, v1.4s, v19.4s + ld1 {v6.4s}, [x12], #16 + ld1 {v24.4s}, [x5], #16 + fmla v29.4s, v2.4s, v20.4s + ld1 {v7.4s}, [x13], #16 + ld1 {v25.4s}, [x5], #16 + fmla v29.4s, v3.4s, v21.4s + ld1 {v16.4s}, [x14], #16 + ld1 {v26.4s}, [x5], #16 + fmla v29.4s, v4.4s, v22.4s + ld1 {v17.4s}, [x15], #16 + ld1 {v27.4s}, [x5], #16 + // column 1 + fmla v29.4s, v5.4s, v23.4s + ld1 {v0.4s}, [x16], #16 + ld1 {v18.4s}, [x5], #16 + fmla v29.4s, v6.4s, v24.4s + ld1 {v1.4s}, [x17], #16 + ld1 {v19.4s}, [x5], #16 + fmla v29.4s, v7.4s, v25.4s + ld1 {v2.4s}, [x0], #16 + ld1 {v20.4s}, [x5], #16 + fmla v29.4s, v16.4s, v26.4s + ld1 {v3.4s}, [x19], #16 + ld1 {v21.4s}, [x5], #16 + fmla v29.4s, v17.4s, v27.4s + ld1 {v4.4s}, [x20], #16 + ld1 {v22.4s}, [x5], #16 + // column 2 + fmla v29.4s, v0.4s, v18.4s + ld1 {v5.4s}, [x21], #16 + ld1 {v23.4s}, [x5], #16 + fmla v29.4s, v1.4s, v19.4s + ld1 {v6.4s}, [x22], #16 + ld1 {v24.4s}, [x5], #16 + fmla v29.4s, v2.4s, v20.4s + ld1 {v7.4s}, [x23], #16 + ld1 {v25.4s}, [x5], #16 + fmla v29.4s, v3.4s, v21.4s + ld1 {v16.4s}, [x24], #16 + ld1 {v26.4s}, [x5], #16 + fmla v29.4s, v4.4s, v22.4s + ld1 {v17.4s}, [x25], #16 + ld1 {v27.4s}, [x5], #16 + // column 3 + fmla v29.4s, v5.4s, v23.4s + ld1 {v0.4s}, [x26], #16 + ld1 {v18.4s}, [x5], #16 + fmla v29.4s, v6.4s, v24.4s + ld1 {v1.4s}, [x27], #16 + ld1 {v19.4s}, [x5], #16 + fmla v29.4s, v7.4s, v25.4s + ld1 {v2.4s}, [x28], #16 + ld1 {v20.4s}, [x5], #16 + fmla v29.4s, v16.4s, v26.4s + ld1 {v3.4s}, [x29], #16 + ld1 {v21.4s}, [x5], #16 + fmla v29.4s, v17.4s, v27.4s + ld1 {v4.4s}, [x30], #16 + ld1 {v22.4s}, [x5], #16 + // column 4 + fmla v29.4s, v0.4s, v18.4s + fmla v29.4s, v1.4s, v19.4s + fmla v29.4s, v2.4s, v20.4s + fmla v29.4s, v3.4s, v21.4s + fmla v29.4s, v4.4s, v22.4s + + ldp x4, x5, [sp, #32] + cbnz x5, LeftRelu6 + cbnz x4, LeftRelu + b LeftWrite + LeftRelu6: + fmin v29.4s, v29.4s, v31.4s + LeftRelu: + fmax v29.4s, v29.4s, v30.4s + LeftWrite: + cmp x2, #4 + bne Write3 + ldr x4, [sp, #160] + st1 {v29.4s}, [x4], #16 + str x4, [sp, #160] + b NextPixel + Write3: + sxtw x2, w2 + tbnz w2, #1, Write2 + tbnz w2, #0, Write1 + Write2: + ldr x4, [sp, #160] + st1 {v29.2s}, [x4], #8 + str x4, [sp, #160] + ext v29.16b, v29.16b, v29.16b, #8 + tbz w2, #0, NextPixel + Write1: + ldr x4, [sp, #160] + str s29, [x4], #4 + str x4, [sp, #160] + + NextPixel: + ldr x2, [sp, #24] + add x1, x1, x2 + sub x3, x3, #1 + cmp x3, #0 + bgt LoopPixel +End: + ldp x19, x20, [sp, #64] + ldp x21, x22, [sp, #80] + ldp x23, x24, [sp, #96] + ldp x25, x26, [sp, #112] + ldp x27, x28, [sp, #128] + ldp x29, x30, [sp, #144] + add sp, sp, #176 +ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDwFp32Row.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDwFp32Row.S new file mode 100644 index 0000000000000000000000000000000000000000..08f3ff534d002bbdfb181cc57fc42af743aa986f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDwFp32Row.S @@ -0,0 +1,129 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void ConvDwFp32Row(float* output_ptr, const float* input_ptr,const float* filter_ptr, +// size_t num_pixels, size_t input_channel, size_t input_step) +// x0: output_ptr, x1: input_ptr, x2: filter_ptr, x3: num_pixels, +// x4: input_channel, x5: input_step +// +asm_function ConvDwFp32Row + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters +cmp x3, #0 +ble End + +mov x9, x0 +mov x12, #4 +mul x5, x5, x12 + +LoopOutPixel: +mov x6, x1 +mov x7, x2 +mov x8, x4 + + LoopDepth16In: + cmp x8, #16 + blt L4 + sub x8, x8, #16 + + ld1 {v0.4s, v1.4s}, [x6], #32 + ld1 {v2.4s, v3.4s}, [x7], #32 + ld1 {v16.4s, v17.4s}, [x0], #32 + + cmp x8, #16 + blt LoopDepth16Out + LoopDepth16: + fmla v16.4s, v0.4s, v2.4s + fmla v17.4s, v1.4s, v3.4s + + st1 {v16.4s, v17.4s}, [x9], #32 + + ld1 {v4.4s, v5.4s}, [x6], #32 + ld1 {v6.4s, v7.4s}, [x7], #32 + ld1 {v18.4s, v19.4s}, [x0], #32 + + fmla v18.4s, v4.4s, v6.4s + fmla v19.4s, v5.4s, v7.4s + + st1 {v18.4s, v19.4s}, [x9], #32 + + ld1 {v0.4s, v1.4s}, [x6], #32 + ld1 {v2.4s, v3.4s}, [x7], #32 + ld1 {v16.4s, v17.4s}, [x0], #32 + + sub x8, x8, #16 + cmp x8, #16 + bge LoopDepth16 + + LoopDepth16Out: + fmla v16.4s, v0.4s, v2.4s + fmla v17.4s, v1.4s, v3.4s + st1 {v16.4s, v17.4s}, [x9], #32 + + ld1 {v4.4s, v5.4s}, [x6], #32 + ld1 {v6.4s, v7.4s}, [x7], #32 + ld1 {v18.4s, v19.4s}, [x0], #32 + + fmla v18.4s, v4.4s, v6.4s + fmla v19.4s, v5.4s, v7.4s + + st1 {v18.4s, v19.4s}, [x9], #32 + + L4: + cmp x8, #4 + blt L0 + + LoopDepth4: + ld1 {v0.4s}, [x6], #16 + ld1 {v2.4s}, [x7], #16 + ld1 {v16.4s}, [x0], #16 + fmla v16.4s, v0.4s, v2.4s + st1 {v16.4s}, [x9], #16 + sub x8, x8, #4 + cmp x8, #4 + bge LoopDepth4 + + L0: + cmp x8, #0 + beq Loop16LineEnd + + LoopDepth0: + ldr s0, [x6], #4 + ldr s1, [x7], #4 + ldr s2, [x0], #4 + fmul s0, s0, s1 + fadd s2, s2, s0 + str s2, [x9], #4 + subs x8, x8, #1 + bne LoopDepth0 + + Loop16LineEnd: + +subs x3, x3, #1 +add x1, x1, x5 +bne LoopOutPixel + +End: +ret + +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDwInt8Center.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDwInt8Center.S new file mode 100644 index 0000000000000000000000000000000000000000..01cd9db6109ac5432739c748eb8b70a9dd39acae --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDwInt8Center.S @@ -0,0 +1,294 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void ConvDwInt8Center(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t height, +// size_t width, size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, +// size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, int8_t *in_zp, +// int32_t *out_zp, int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift, +// int32_t *acc_min, int32_t *acc_max) + +// x0: dst, x1: src, x2: weight, x3: bias, x4: height, x5: weight, x6: kernel_h, x7: kernel_w, +// x8: out_h_step, x9: block_channel, x10: in_sh_step, x11: in_sw_step, x12: in_kh_step, x13: in_kw_step +// x14: in_zp, #56: out_zp, #64: out_multiplier, #72:left_shift, #80: right_shift, #88: acc_min, #96: acc_max +asm_function ConvDwInt8Center + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #192 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + stp x23, x24, [sp, #160] + stp x25, x26, [sp, #176] + + ldr x8, [sp, #192] + ldr x9, [sp, #200] + ldr x10, [sp, #208] + ldr x11, [sp, #216] + ldr x12, [sp, #224] + ldr x13, [sp, #232] + + ldr x14, [sp, #240] // input_zp + ld1 {v19.8b}, [x14], #8 + + ldr x15, [sp, #248] // output_zp + ld1 {v20.4s}, [x15], #16 + ld1 {v21.4s}, [x15], #16 + + ldr x16, [sp, #256] // out_multiplier + ld1 {v22.4s}, [x16], #16 + ld1 {v23.4s}, [x16], #16 + + ldr x17, [sp, #264] // left_shift + ld1 {v24.4s}, [x17], #16 + ld1 {v25.4s}, [x17], #16 + + ldr x25, [sp, #272] // right shift + ld1 {v26.4s}, [x25], #16 + ld1 {v27.4s}, [x25], #16 + + ldr x19, [sp, #280] // acc_min + ld1 {v28.4s}, [x19], #16 + ld1 {v29.4s}, [x19], #16 + + ldr x20, [sp, #288] // acc_max + ld1 {v30.4s}, [x20], #16 + ld1 {v31.4s}, [x20], #16 + + ld1 {v17.4s}, [x3], #16 + ld1 {v18.4s}, [x3], #16 + + LoopH: + mov x23, x1 + mov x24, x5 + mov x3, x0 + + LoopW4: + mov x19, #4 + mul x19, x19, x11 + mov x25, #4 + mul x25, x25, x9 + + mov x16, x23 + mov x17, x2 + mov x20, x6 + + mov v0.16b, v17.16b + mov v1.16b, v18.16b + mov v2.16b, v17.16b + mov v3.16b, v18.16b + mov v4.16b, v17.16b + mov v5.16b, v18.16b + mov v6.16b, v17.16b + mov v7.16b, v18.16b + LoopKh4: + mov x25, x7 + mov x21, x16 + LoopKw4: + mov x22, x21 + ld1 {v16.8h}, [x17], #16 + + ld1 {v15.8b}, [x22], x11 + ssubl v14.8h, v15.8b, v19.8b + smlal v0.4s, v14.4h, v16.4h + smlal2 v1.4s, v14.8h, v16.8h + + ld1 {v13.8b}, [x22], x11 + ssubl v12.8h, v13.8b, v19.8b + smlal v2.4s, v12.4h, v16.4h + smlal2 v3.4s, v12.8h, v16.8h + + ld1 {v11.8b}, [x22], x11 + ssubl v10.8h, v11.8b, v19.8b + smlal v4.4s, v10.4h, v16.4h + smlal2 v5.4s, v10.8h, v16.8h + + ld1 {v9.8b}, [x22], x11 + ssubl v8.8h, v9.8b, v19.8b + smlal v6.4s, v8.4h, v16.4h + smlal2 v7.4s, v8.8h, v16.8h + + subs x25, x25, #1 + add x21, x21, x13 + bne LoopKw4 + add x16, x16, x12 + subs x20, x20, #1 + bne LoopKh4 + + sqshl v0.4s, v0.4s, v24.4s + sqshl v1.4s, v1.4s, v25.4s + sqshl v2.4s, v2.4s, v24.4s + sqshl v3.4s, v3.4s, v25.4s + sqshl v4.4s, v4.4s, v24.4s + sqshl v5.4s, v5.4s, v25.4s + sqshl v6.4s, v6.4s, v24.4s + sqshl v7.4s, v7.4s, v25.4s + + sqrdmulh v0.4s, v0.4s, v22.4s + sqrdmulh v1.4s, v1.4s, v23.4s + sqrdmulh v2.4s, v2.4s, v22.4s + sqrdmulh v3.4s, v3.4s, v23.4s + sqrdmulh v4.4s, v4.4s, v22.4s + sqrdmulh v5.4s, v5.4s, v23.4s + sqrdmulh v6.4s, v6.4s, v22.4s + sqrdmulh v7.4s, v7.4s, v23.4s + + sqrshl v0.4s, v0.4s, v26.4s + sqrshl v1.4s, v1.4s, v27.4s + sqrshl v2.4s, v2.4s, v26.4s + sqrshl v3.4s, v3.4s, v27.4s + sqrshl v4.4s, v4.4s, v26.4s + sqrshl v5.4s, v5.4s, v27.4s + sqrshl v6.4s, v6.4s, v26.4s + sqrshl v7.4s, v7.4s, v27.4s + + add v0.4s, v0.4s, v20.4s + add v1.4s, v1.4s, v21.4s + add v2.4s, v2.4s, v20.4s + add v3.4s, v3.4s, v21.4s + add v4.4s, v4.4s, v20.4s + add v5.4s, v5.4s, v21.4s + add v6.4s, v6.4s, v20.4s + add v7.4s, v7.4s, v21.4s + smax v0.4s, v0.4s, v28.4s + smax v1.4s, v1.4s, v29.4s + smax v2.4s, v2.4s, v28.4s + smax v3.4s, v3.4s, v29.4s + smax v4.4s, v4.4s, v28.4s + smax v5.4s, v5.4s, v29.4s + smax v6.4s, v6.4s, v28.4s + smax v7.4s, v7.4s, v29.4s + smin v0.4s, v0.4s, v30.4s + smin v1.4s, v1.4s, v31.4s + smin v2.4s, v2.4s, v30.4s + smin v3.4s, v3.4s, v31.4s + smin v4.4s, v4.4s, v30.4s + smin v5.4s, v5.4s, v31.4s + smin v6.4s, v6.4s, v30.4s + smin v7.4s, v7.4s, v31.4s + + sqxtn v0.4h, v0.4s + sqxtn v1.4h, v1.4s + sqxtn v2.4h, v2.4s + sqxtn v3.4h, v3.4s + sqxtn v4.4h, v4.4s + sqxtn v5.4h, v5.4s + sqxtn v6.4h, v6.4s + sqxtn v7.4h, v7.4s + sqxtn v0.8b, v0.8h + sqxtn v1.8b, v1.8h + sqxtn v2.8b, v2.8h + sqxtn v3.8b, v3.8h + sqxtn v4.8b, v4.8h + sqxtn v5.8b, v5.8h + sqxtn v6.8b, v6.8h + sqxtn v7.8b, v7.8h + + mov x16, x3 + add x17, x16, x9 + add x25, x17, x9 + add x21, x25, x9 + + st1 {v0.s}[0], [x16], #4 + st1 {v1.s}[0], [x16], #4 + st1 {v2.s}[0], [x17], #4 + st1 {v3.s}[0], [x17], #4 + st1 {v4.s}[0], [x25], #4 + st1 {v5.s}[0], [x25], #4 + st1 {v6.s}[0], [x21], #4 + st1 {v7.s}[0], [x21], #4 + + add x3, x3, x25 + add x23, x23, x19 + sub x24, x24, #4 + cmp x24, #0 + ble LoopWEnd + cmp x24, #4 + bge LoopW4 + + LoopW: + mov x16, x23 + mov x17, x2 + mov x20, x6 + mov v0.16b, v17.16b + mov v1.16b, v18.16b + LoopKh: + mov x25, x7 + mov x22, x16 + LoopKw: + ld1 {v15.8b}, [x22], x13 + ssubl v14.8h, v15.8b, v19.8b + ld1 {v16.8h}, [x17], #16 + smlal v0.4s, v14.4h, v16.4h + smlal2 v1.4s, v14.8h, v16.8h + subs x25, x25, #1 + bne LoopKw + add x16, x16, x12 + subs x20, x20, #1 + bne LoopKh + + sqshl v0.4s, v0.4s, v24.4s + sqrdmulh v0.4s, v0.4s, v22.4s + sqshl v1.4s, v1.4s, v25.4s + sqrdmulh v1.4s, v1.4s, v23.4s + + sqrshl v0.4s, v0.4s, v26.4s + sqrshl v1.4s, v1.4s, v27.4s + + add v0.4s, v0.4s, v20.4s + smax v0.4s, v0.4s, v28.4s + smin v0.4s, v0.4s, v30.4s + + sqxtn v0.4h, v0.4s + sqxtn v0.8b, v0.8h + + add v1.4s, v1.4s, v21.4s + smax v1.4s, v1.4s, v29.4s + smin v1.4s, v1.4s, v31.4s + + sqxtn v1.4h, v1.4s + sqxtn v1.8b, v1.8h + + mov x17, x3 + st1 {v0.s}[0], [x17], #4 + st1 {v1.s}[0], [x17], #4 + add x3, x3, x9 + + add x23, x23, x11 + subs x24, x24, #1 + bne LoopW + LoopWEnd: + add x0, x0, x8 + add x1, x1, x10 + subs x4, x4, #1 + bne LoopH + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDwInt8PostAlign4.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDwInt8PostAlign4.S new file mode 100644 index 0000000000000000000000000000000000000000..9d3911f6a3c311127d02def968d8247835927a41 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDwInt8PostAlign4.S @@ -0,0 +1,191 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void ConvDwInt8PostAlign4(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier, +// int32_t left_shift, int32_t right_shift, int32_t acc_min, int32_t acc_max); +// x0: dst, x1: buffer, x2: num_pixels, x3: output_zp, x4: out_multiplier, +// x5: left_shift, x6: right_shift, x7: acc_min, x8: acc_max + +asm_function ConvDwInt8PostAlign4 + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + ldr x8, [sp] + + dup v26.4s, w5 + dup v27.4s, w4 + dup v28.4s, w6 + + dup v29.4s, w3 + dup v30.4s, w7 + dup v31.4s, w8 + + cmp x2, #16 + blt LoopDepth8 + + LoopDepth16: + ld1 {v0.4s}, [x1], #16 + ld1 {v1.4s}, [x1], #16 + ld1 {v2.4s}, [x1], #16 + ld1 {v3.4s}, [x1], #16 + + cbz w5, RightShiftDepth16 + sqshl v0.4s, v0.4s, v26.4s + sqshl v1.4s, v1.4s, v26.4s + sqshl v2.4s, v2.4s, v26.4s + sqshl v3.4s, v3.4s, v26.4s + sqrdmulh v0.4s, v0.4s, v27.4s + sqrdmulh v1.4s, v1.4s, v27.4s + sqrdmulh v2.4s, v2.4s, v27.4s + sqrdmulh v3.4s, v3.4s, v27.4s + b AddZpDepth16 + + RightShiftDepth16: + sqrdmulh v0.4s, v0.4s, v27.4s + sqrdmulh v1.4s, v1.4s, v27.4s + sqrdmulh v2.4s, v2.4s, v27.4s + sqrdmulh v3.4s, v3.4s, v27.4s + + and v4.16b, v0.16b, v28.16b + sshr v4.4s, v4.4s, #31 + sqadd v0.4s, v0.4s, v4.4s + srshl v0.4s, v0.4s, v28.4s + and v5.16b, v1.16b, v28.16b + sshr v5.4s, v5.4s, #31 + sqadd v1.4s, v1.4s, v5.4s + srshl v1.4s, v1.4s, v28.4s + and v6.16b, v2.16b, v28.16b + sshr v6.4s, v6.4s, #31 + sqadd v2.4s, v2.4s, v6.4s + srshl v2.4s, v2.4s, v28.4s + and v7.16b, v3.16b, v28.16b + sshr v7.4s, v7.4s, #31 + sqadd v3.4s, v3.4s, v7.4s + srshl v3.4s, v3.4s, v28.4s + + AddZpDepth16: + add v0.4s, v0.4s, v29.4s + add v1.4s, v1.4s, v29.4s + add v2.4s, v2.4s, v29.4s + add v3.4s, v3.4s, v29.4s + + smax v0.4s, v0.4s, v30.4s + smax v1.4s, v1.4s, v30.4s + smax v2.4s, v2.4s, v30.4s + smax v3.4s, v3.4s, v30.4s + + smin v0.4s, v0.4s, v31.4s + smin v1.4s, v1.4s, v31.4s + smin v2.4s, v2.4s, v31.4s + smin v3.4s, v3.4s, v31.4s + + sqxtn v0.4h, v0.4s + sqxtn v1.4h, v1.4s + sqxtn v2.4h, v2.4s + sqxtn v3.4h, v3.4s + + sqxtn v0.8b, v0.8h + sqxtn v1.8b, v1.8h + sqxtn v2.8b, v2.8h + sqxtn v3.8b, v3.8h + + st1 {v0.s}[0], [x0], #4 + st1 {v1.s}[0], [x0], #4 + st1 {v2.s}[0], [x0], #4 + st1 {v3.s}[0], [x0], #4 + + sub x2, x2, #16 + cmp x2, #16 + bge LoopDepth16 + + LoopDepth8: + cmp x2, #8 + blt LoopDepth4 + ld1 {v0.4s}, [x1], #16 + ld1 {v1.4s}, [x1], #16 + + cbz w5, RightShiftDepth8 + sqshl v0.4s, v0.4s, v26.4s + sqshl v1.4s, v1.4s, v26.4s + sqrdmulh v0.4s, v0.4s, v27.4s + sqrdmulh v1.4s, v1.4s, v27.4s + b AddZpDepth8 + + RightShiftDepth8: + sqrdmulh v0.4s, v0.4s, v27.4s + sqrdmulh v1.4s, v1.4s, v27.4s + and v4.16b, v0.16b, v28.16b + sshr v4.4s, v4.4s, #31 + sqadd v0.4s, v0.4s, v4.4s + srshl v0.4s, v0.4s, v28.4s + and v5.16b, v1.16b, v28.16b + sshr v5.4s, v5.4s, #31 + sqadd v1.4s, v1.4s, v5.4s + srshl v1.4s, v1.4s, v28.4s + + AddZpDepth8: + add v0.4s, v0.4s, v29.4s + add v1.4s, v1.4s, v29.4s + smax v0.4s, v0.4s, v30.4s + smax v1.4s, v1.4s, v30.4s + smin v0.4s, v0.4s, v31.4s + smin v1.4s, v1.4s, v31.4s + + sqxtn v0.4h, v0.4s + sqxtn v1.4h, v1.4s + + sqxtn v0.8b, v0.8h + sqxtn v1.8b, v1.8h + + st1 {v0.s}[0], [x0], #4 + st1 {v1.s}[0], [x0], #4 + + sub x2, x2, #8 + cmp x2, #8 + bge LoopDepth8 + + LoopDepth4: + cmp x2, #4 + blt End + ld1 {v0.4s}, [x1], #16 + + sqshl v0.4s, v0.4s, v26.4s + sqrdmulh v0.4s, v0.4s, v27.4s + and v4.16b, v0.16b, v28.16b + sshr v4.4s, v4.4s, #31 + sqadd v0.4s, v0.4s, v4.4s + srshl v0.4s, v0.4s, v28.4s + + add v0.4s, v0.4s, v29.4s + smax v0.4s, v0.4s, v30.4s + smin v0.4s, v0.4s, v31.4s + + sqxtn v0.4h, v0.4s + sqxtn v0.8b, v0.8h + + st1 {v0.s}[0], [x0], #4 + + sub x2, x2, #4 + bge LoopDepth4 + End: + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDwInt8PostAlign4PerChannel.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDwInt8PostAlign4PerChannel.S new file mode 100644 index 0000000000000000000000000000000000000000..bad15a971413b2114061a7bd33295d6d15829504 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDwInt8PostAlign4PerChannel.S @@ -0,0 +1,119 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void ConvDwInt8PostAlign4PerChannel(int8_t *dst, int32_t *buffer, int channel4, int32_t output_zp, int32_t *out_multiplier, +// int32_t *left_shift, int32_t *right_shift, int32_t acc_min, int32_t acc_max); +// x0: dst, x1: buffer, x2: num_pixels, x3: output_zp, x4: out_multiplier, +// x5: left_shift, x6: right_shift, x7: acc_min, x8: acc_max + +asm_function ConvDwInt8PostAlign4PerChannel + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + ldr x8, [sp] + + dup v29.4s, w3 + dup v30.4s, w7 + dup v31.4s, w8 + + LoopDepth8: + cmp x2, #8 + blt LoopDepth4 + ld1 {v0.4s}, [x1], #16 + ld1 {v1.4s}, [x1], #16 + + ld1 {v2.4s}, [x5], #16 + ld1 {v3.4s}, [x5], #16 + + ld1 {v4.4s}, [x4], #16 + ld1 {v5.4s}, [x4], #16 + + sqshl v0.4s, v0.4s, v2.4s + sqshl v1.4s, v1.4s, v3.4s + + ld1 {v6.4s}, [x6], #16 + ld1 {v7.4s}, [x6], #16 + + sqrdmulh v0.4s, v0.4s, v4.4s + sqrdmulh v1.4s, v1.4s, v5.4s + and v16.16b, v0.16b, v6.16b + sshr v16.4s, v16.4s, #31 + sqadd v0.4s, v0.4s, v16.4s + srshl v0.4s, v0.4s, v6.4s + and v17.16b, v1.16b, v7.16b + sshr v17.4s, v17.4s, #31 + sqadd v1.4s, v1.4s, v17.4s + srshl v1.4s, v1.4s, v7.4s + + add v0.4s, v0.4s, v29.4s + add v1.4s, v1.4s, v29.4s + + smax v0.4s, v0.4s, v30.4s + smax v1.4s, v1.4s, v30.4s + + smin v0.4s, v0.4s, v31.4s + smin v1.4s, v1.4s, v31.4s + + sqxtn v0.4h, v0.4s + sqxtn v1.4h, v1.4s + + sqxtn v0.8b, v0.8h + sqxtn v1.8b, v1.8h + + st1 {v0.s}[0], [x0], #4 + st1 {v1.s}[0], [x0], #4 + + sub x2, x2, #8 + cmp x2, #8 + bge LoopDepth8 + + LoopDepth4: + cmp x2, #4 + blt End + ld1 {v0.4s}, [x1], #16 + ld1 {v2.4s}, [x5], #16 + + sqshl v0.4s, v0.4s, v2.4s + + ld1 {v4.4s}, [x4], #16 + sqrdmulh v0.4s, v0.4s, v4.4s + + ld1 {v6.4s}, [x6], #16 + and v16.16b, v0.16b, v6.16b + sshr v16.4s, v16.4s, #31 + sqadd v0.4s, v0.4s, v16.4s + srshl v0.4s, v0.4s, v6.4s + + add v0.4s, v0.4s, v29.4s + smax v0.4s, v0.4s, v30.4s + smin v0.4s, v0.4s, v31.4s + + sqxtn v0.4h, v0.4s + sqxtn v0.8b, v0.8h + + st1 {v0.s}[0], [x0], #4 + + sub x2, x2, #4 + bge LoopDepth4 + End: + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDwInt8Row.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDwInt8Row.S new file mode 100644 index 0000000000000000000000000000000000000000..a69f35ed2a56be67a50d497c99b0d587e6aceb17 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvDwInt8Row.S @@ -0,0 +1,134 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels, +// int output_channel, int input_step, int8_t input_zp) +// x0: output_ptr, x1: input_ptr, x2: weight_ptr, x3: num_pixels, +// x4: output_channel, x5: input_step, x6: input_zp +// +asm_function ConvDwInt8Row + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters +cmp x3, #0 +beq End + +mov x10, x0 + +dup v31.8b, w6 + +LoopOutPixel: +mov x7, x1 +mov x8, x2 +mov x9, x4 + + LoopDepth16In: + cmp x9, #16 + blt L8 + sub x9, x9, #16 + + ld1 {v0.8b, v1.8b}, [x7], #16 + ld1 {v2.8h, v3.8h}, [x8], #32 + ld1 {v16.4s, v17.4s}, [x0], #32 + + ssubl v20.8h, v0.8b, v31.8b + smlal v16.4s, v20.4h, v2.4h + smlal2 v17.4s, v20.8h, v2.8h + + + cmp x9, #16 + blt LoopDepth16Out + LoopDepth16: + + st1 {v16.4s, v17.4s}, [x10], #32 + ld1 {v18.4s, v19.4s}, [x0], #32 + ssubl v21.8h, v1.8b, v31.8b + smlal v18.4s, v21.4h, v3.4h + smlal2 v19.4s, v21.8h, v3.8h + st1 {v18.4s, v19.4s}, [x10], #32 + + ld1 {v0.8b, v1.8b}, [x7], #16 + ld1 {v2.8h, v3.8h}, [x8], #32 + ld1 {v16.4s, v17.4s}, [x0], #32 + + ssubl v20.8h, v0.8b, v31.8b + smlal v16.4s, v20.4h, v2.4h + smlal2 v17.4s, v20.8h, v2.8h + + sub x9, x9, #16 + cmp x9, #16 + bge LoopDepth16 + + LoopDepth16Out: + + st1 {v16.4s, v17.4s}, [x10], #32 + ld1 {v18.4s, v19.4s}, [x0], #32 + ssubl v21.8h, v1.8b, v31.8b + smlal v18.4s, v21.4h, v3.4h + smlal2 v19.4s, v21.8h, v3.8h + st1 {v18.4s, v19.4s}, [x10], #32 + + L8: + cmp x9, #8 + blt L0 + + LoopDepth8: + ld1 {v0.8b}, [x7], #8 + ld1 {v2.8h}, [x8], #16 + ld1 {v16.4s, v17.4s}, [x0], #32 + + ssubl v20.8h, v0.8b, v31.8b + smlal v16.4s, v20.4h, v2.4h + smlal2 v17.4s, v20.8h, v2.8h + st1 {v16.4s, v17.4s}, [x10], #32 + + sub x9, x9, #8 + cmp x9, #8 + bge LoopDepth8 + + L0: + cmp x9, #0 + beq Loop16LineEnd + + LoopDepth0: + ldrsb w14, [x7], #1 + ldrsh w15, [x8], #2 + ldr w16, [x0], #4 + sub w14, w14, w6 + + sxth w14, w14 + madd w14, w14, w15, w16 + str w14, [x10], #4 + + subs x9, x9, #1 + bne LoopDepth0 + + Loop16LineEnd: + +subs x3, x3, #1 +add x1, x1, x5 +bne LoopOutPixel + +End: +ret + +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvFp32Center.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvFp32Center.S new file mode 100644 index 0000000000000000000000000000000000000000..0a9d32657609dde01c8c4294d70c1f4b50922f40 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvFp32Center.S @@ -0,0 +1,458 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void ConvSwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, +// size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t ic4, size_t in_sh_step, +// size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, size_t relu, size_t relu6); +// x0: dst, x1: src, x2: weight, x3: bias, x4: height, x5: width, x6: kernel_h, x7: kernel_w, +// x8: out_h_step, x9: block_channel, x10: ic4, x11: in_sh_step, x12: in_sw_step, x13: in_kh_step, x14: in_kw_step +// x26: relu, x16: relu6 +asm_function ConvSwFp32Center + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #208 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + stp x23, x24, [sp, #160] + stp x25, x26, [sp, #176] + stp x27, x28, [sp, #192] + + ldr x8, [sp, #208] + ldr x9, [sp, #216] + ldr x10, [sp, #224] + ldr x11, [sp, #232] + ldr x12, [sp, #240] + ldr x13, [sp, #248] + ldr x14, [sp, #256] + mul x15, x6, x7 + mul x15, x10, x15 + mov x16, #16 + mul x15, x15, x16 + + ld1 {v25.4s}, [x3] + movi v26.4s, #6 + scvtf v26.4s, v26.4s + dup v27.4s, wzr + + LoopH: + mov x17, x1 + mov x28, x5 + mov x3, x0 + cmp x28, #8 + blt LoopW + cmp x28, #16 + blt LoopW8 + + LoopW16: + mov x19, #16 + mul x19, x19, x12 + mov x20, x17 + mov x21, x2 + mov x22, x6 + mov v0.16b, v25.16b + mov v1.16b, v25.16b + mov v2.16b, v25.16b + mov v3.16b, v25.16b + mov v4.16b, v25.16b + mov v5.16b, v25.16b + mov v6.16b, v25.16b + mov v7.16b, v25.16b + mov v8.16b, v25.16b + mov v9.16b, v25.16b + mov v10.16b, v25.16b + mov v11.16b, v25.16b + mov v12.16b, v25.16b + mov v13.16b, v25.16b + mov v14.16b, v25.16b + mov v15.16b, v25.16b + LoopKh16: + mov x23, x7 + mov x24, x20 + LoopKw16: + mov x25, x24 + mov x27, x10 + LoopIc16: + mov x26, x25 + mov x16, x21 + ld1 {v28.4s}, [x16], x15 + ld1 {v29.4s}, [x16], x15 + ld1 {v30.4s}, [x16], x15 + ld1 {v31.4s}, [x16], x15 + zip1 v20.4s, v28.4s, v29.4s + zip2 v21.4s, v28.4s, v29.4s + zip1 v22.4s, v30.4s, v31.4s + zip2 v23.4s, v30.4s, v31.4s + ld1 {v16.4s}, [x26], x12 + ld1 {v17.4s}, [x26], x12 + trn1 v28.2d, v20.2d, v22.2d + trn2 v29.2d, v20.2d, v22.2d + trn1 v30.2d, v21.2d, v23.2d + trn2 v31.2d, v21.2d, v23.2d + ld1 {v18.4s}, [x26], x12 + ld1 {v19.4s}, [x26], x12 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v28.4s, v17.s[0] + fmla v0.4s, v29.4s, v16.s[1] + fmla v1.4s, v29.4s, v17.s[1] + fmla v0.4s, v30.4s, v16.s[2] + fmla v1.4s, v30.4s, v17.s[2] + fmla v0.4s, v31.4s, v16.s[3] + fmla v1.4s, v31.4s, v17.s[3] + ld1 {v20.4s}, [x26], x12 + ld1 {v21.4s}, [x26], x12 + fmla v2.4s, v28.4s, v18.s[0] + fmla v3.4s, v28.4s, v19.s[0] + fmla v2.4s, v29.4s, v18.s[1] + fmla v3.4s, v29.4s, v19.s[1] + fmla v2.4s, v30.4s, v18.s[2] + fmla v3.4s, v30.4s, v19.s[2] + fmla v2.4s, v31.4s, v18.s[3] + fmla v3.4s, v31.4s, v19.s[3] + ld1 {v22.4s}, [x26], x12 + ld1 {v23.4s}, [x26], x12 + fmla v4.4s, v28.4s, v20.s[0] + fmla v5.4s, v28.4s, v21.s[0] + fmla v4.4s, v29.4s, v20.s[1] + fmla v5.4s, v29.4s, v21.s[1] + fmla v4.4s, v30.4s, v20.s[2] + fmla v5.4s, v30.4s, v21.s[2] + fmla v4.4s, v31.4s, v20.s[3] + fmla v5.4s, v31.4s, v21.s[3] + ld1 {v16.4s}, [x26], x12 + ld1 {v17.4s}, [x26], x12 + fmla v6.4s, v28.4s, v22.s[0] + fmla v7.4s, v28.4s, v23.s[0] + fmla v6.4s, v29.4s, v22.s[1] + fmla v7.4s, v29.4s, v23.s[1] + fmla v6.4s, v30.4s, v22.s[2] + fmla v7.4s, v30.4s, v23.s[2] + fmla v6.4s, v31.4s, v22.s[3] + fmla v7.4s, v31.4s, v23.s[3] + ld1 {v18.4s}, [x26], x12 + ld1 {v19.4s}, [x26], x12 + fmla v8.4s, v28.4s, v16.s[0] + fmla v9.4s, v28.4s, v17.s[0] + fmla v8.4s, v29.4s, v16.s[1] + fmla v9.4s, v29.4s, v17.s[1] + fmla v8.4s, v30.4s, v16.s[2] + fmla v9.4s, v30.4s, v17.s[2] + fmla v8.4s, v31.4s, v16.s[3] + fmla v9.4s, v31.4s, v17.s[3] + ld1 {v20.4s}, [x26], x12 + ld1 {v21.4s}, [x26], x12 + fmla v10.4s, v28.4s, v18.s[0] + fmla v11.4s, v28.4s, v19.s[0] + fmla v10.4s, v29.4s, v18.s[1] + fmla v11.4s, v29.4s, v19.s[1] + fmla v10.4s, v30.4s, v18.s[2] + fmla v11.4s, v30.4s, v19.s[2] + fmla v10.4s, v31.4s, v18.s[3] + fmla v11.4s, v31.4s, v19.s[3] + ld1 {v22.4s}, [x26], x12 + ld1 {v23.4s}, [x26], x12 + fmla v12.4s, v28.4s, v20.s[0] + fmla v13.4s, v28.4s, v21.s[0] + fmla v12.4s, v29.4s, v20.s[1] + fmla v13.4s, v29.4s, v21.s[1] + fmla v12.4s, v30.4s, v20.s[2] + fmla v13.4s, v30.4s, v21.s[2] + fmla v12.4s, v31.4s, v20.s[3] + fmla v13.4s, v31.4s, v21.s[3] + fmla v14.4s, v28.4s, v22.s[0] + fmla v15.4s, v28.4s, v23.s[0] + fmla v14.4s, v29.4s, v22.s[1] + fmla v15.4s, v29.4s, v23.s[1] + fmla v14.4s, v30.4s, v22.s[2] + fmla v15.4s, v30.4s, v23.s[2] + fmla v14.4s, v31.4s, v22.s[3] + fmla v15.4s, v31.4s, v23.s[3] + add x21, x21, #16 + add x25, x25, #16 + subs x27, x27, #1 + bgt LoopIc16 + subs x23, x23, #1 + add x24, x24, x14 + bne LoopKw16 + add x20, x20, x13 + subs x22, x22, #1 + bne LoopKh16 + ldr x16, [sp, #272] + cbnz x16, Relu616 + ldr x26, [sp, #264] + cbnz x26, Relu16 + b Write16 + Relu616: + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmin v2.4s, v2.4s, v26.4s + fmin v3.4s, v3.4s, v26.4s + fmin v4.4s, v4.4s, v26.4s + fmin v5.4s, v5.4s, v26.4s + fmin v6.4s, v6.4s, v26.4s + fmin v7.4s, v7.4s, v26.4s + fmin v8.4s, v8.4s, v26.4s + fmin v9.4s, v9.4s, v26.4s + fmin v10.4s, v10.4s, v26.4s + fmin v11.4s, v11.4s, v26.4s + fmin v12.4s, v12.4s, v26.4s + fmin v13.4s, v13.4s, v26.4s + fmin v14.4s, v14.4s, v26.4s + fmin v15.4s, v15.4s, v26.4s + Relu16: + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + fmax v2.4s, v2.4s, v27.4s + fmax v3.4s, v3.4s, v27.4s + fmax v4.4s, v4.4s, v27.4s + fmax v5.4s, v5.4s, v27.4s + fmax v6.4s, v6.4s, v27.4s + fmax v7.4s, v7.4s, v27.4s + fmax v8.4s, v8.4s, v27.4s + fmax v9.4s, v9.4s, v27.4s + fmax v10.4s, v10.4s, v27.4s + fmax v11.4s, v11.4s, v27.4s + fmax v12.4s, v12.4s, v27.4s + fmax v13.4s, v13.4s, v27.4s + fmax v14.4s, v14.4s, v27.4s + fmax v15.4s, v15.4s, v27.4s + Write16: + st1 {v0.4s}, [x3], x9 + st1 {v1.4s}, [x3], x9 + st1 {v2.4s}, [x3], x9 + st1 {v3.4s}, [x3], x9 + st1 {v4.4s}, [x3], x9 + st1 {v5.4s}, [x3], x9 + st1 {v6.4s}, [x3], x9 + st1 {v7.4s}, [x3], x9 + st1 {v8.4s}, [x3], x9 + st1 {v9.4s}, [x3], x9 + st1 {v10.4s}, [x3], x9 + st1 {v11.4s}, [x3], x9 + st1 {v12.4s}, [x3], x9 + st1 {v13.4s}, [x3], x9 + st1 {v14.4s}, [x3], x9 + st1 {v15.4s}, [x3], x9 + add x17, x17, x19 + sub x28, x28, #16 + cmp x28, #0 + ble LoopWEnd + cmp x28, #8 + blt LoopW + cmp x28, #16 + bge LoopW16 + LoopW8: + mov x19, #8 + mul x19, x19, x12 + mov x20, x17 + mov x21, x2 + mov x22, x6 + mov v0.16b, v25.16b + mov v1.16b, v25.16b + mov v2.16b, v25.16b + mov v3.16b, v25.16b + mov v4.16b, v25.16b + mov v5.16b, v25.16b + mov v6.16b, v25.16b + mov v7.16b, v25.16b + LoopKh8: + mov x23, x7 + mov x24, x20 + LoopKw8: + mov x25, x24 + mov x27, x10 + LoopIc8: + mov x26, x25 + mov x16, x21 + ld1 {v28.4s}, [x16], x15 + ld1 {v29.4s}, [x16], x15 + ld1 {v30.4s}, [x16], x15 + ld1 {v31.4s}, [x16], x15 + zip1 v20.4s, v28.4s, v29.4s + zip2 v21.4s, v28.4s, v29.4s + zip1 v22.4s, v30.4s, v31.4s + zip2 v23.4s, v30.4s, v31.4s + ld1 {v16.4s}, [x26], x12 + ld1 {v17.4s}, [x26], x12 + trn1 v28.2d, v20.2d, v22.2d + trn2 v29.2d, v20.2d, v22.2d + trn1 v30.2d, v21.2d, v23.2d + trn2 v31.2d, v21.2d, v23.2d + ld1 {v18.4s}, [x26], x12 + ld1 {v19.4s}, [x26], x12 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v28.4s, v17.s[0] + fmla v0.4s, v29.4s, v16.s[1] + fmla v1.4s, v29.4s, v17.s[1] + fmla v0.4s, v30.4s, v16.s[2] + fmla v1.4s, v30.4s, v17.s[2] + fmla v0.4s, v31.4s, v16.s[3] + fmla v1.4s, v31.4s, v17.s[3] + ld1 {v20.4s}, [x26], x12 + ld1 {v21.4s}, [x26], x12 + fmla v2.4s, v28.4s, v18.s[0] + fmla v3.4s, v28.4s, v19.s[0] + fmla v2.4s, v29.4s, v18.s[1] + fmla v3.4s, v29.4s, v19.s[1] + fmla v2.4s, v30.4s, v18.s[2] + fmla v3.4s, v30.4s, v19.s[2] + fmla v2.4s, v31.4s, v18.s[3] + fmla v3.4s, v31.4s, v19.s[3] + ld1 {v22.4s}, [x26], x12 + ld1 {v23.4s}, [x26], x12 + fmla v4.4s, v28.4s, v20.s[0] + fmla v5.4s, v28.4s, v21.s[0] + fmla v4.4s, v29.4s, v20.s[1] + fmla v5.4s, v29.4s, v21.s[1] + fmla v4.4s, v30.4s, v20.s[2] + fmla v5.4s, v30.4s, v21.s[2] + fmla v4.4s, v31.4s, v20.s[3] + fmla v5.4s, v31.4s, v21.s[3] + fmla v6.4s, v28.4s, v22.s[0] + fmla v7.4s, v28.4s, v23.s[0] + fmla v6.4s, v29.4s, v22.s[1] + fmla v7.4s, v29.4s, v23.s[1] + fmla v6.4s, v30.4s, v22.s[2] + fmla v7.4s, v30.4s, v23.s[2] + fmla v6.4s, v31.4s, v22.s[3] + fmla v7.4s, v31.4s, v23.s[3] + add x21, x21, #16 + add x25, x25, #16 + subs x27, x27, #1 + bgt LoopIc8 + subs x23, x23, #1 + add x24, x24, x14 + bne LoopKw8 + add x20, x20, x13 + subs x22, x22, #1 + bne LoopKh8 + ldr x16, [sp, #272] + cbnz x16, Relu68 + ldr x26, [sp, #264] + cbnz x26, Relu8 + b Write8 + Relu68: + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmin v2.4s, v2.4s, v26.4s + fmin v3.4s, v3.4s, v26.4s + fmin v4.4s, v4.4s, v26.4s + fmin v5.4s, v5.4s, v26.4s + fmin v6.4s, v6.4s, v26.4s + fmin v7.4s, v7.4s, v26.4s + Relu8: + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + fmax v2.4s, v2.4s, v27.4s + fmax v3.4s, v3.4s, v27.4s + fmax v4.4s, v4.4s, v27.4s + fmax v5.4s, v5.4s, v27.4s + fmax v6.4s, v6.4s, v27.4s + fmax v7.4s, v7.4s, v27.4s + Write8: + st1 {v0.4s}, [x3], x9 + st1 {v1.4s}, [x3], x9 + st1 {v2.4s}, [x3], x9 + st1 {v3.4s}, [x3], x9 + st1 {v4.4s}, [x3], x9 + st1 {v5.4s}, [x3], x9 + st1 {v6.4s}, [x3], x9 + st1 {v7.4s}, [x3], x9 + add x17, x17, x19 + sub x28, x28, #8 + cmp x28, #0 + ble LoopWEnd + cmp x28, #8 + bge LoopW8 + LoopW: + mov x20, x17 + mov x21, x2 + mov x22, x6 + mov v0.16b, v25.16b + LoopKh: + mov x23, x7 + mov x24, x20 + LoopKw: + mov x25, x24 + mov x27, x10 + LoopIc: + mov x26, x25 + mov x16, x21 + ld1 {v28.4s}, [x16], x15 + ld1 {v29.4s}, [x16], x15 + ld1 {v30.4s}, [x16], x15 + ld1 {v31.4s}, [x16], x15 + zip1 v20.4s, v28.4s, v29.4s + zip2 v21.4s, v28.4s, v29.4s + zip1 v22.4s, v30.4s, v31.4s + zip2 v23.4s, v30.4s, v31.4s + ld1 {v16.4s}, [x26], x12 + trn1 v28.2d, v20.2d, v22.2d + trn2 v29.2d, v20.2d, v22.2d + trn1 v30.2d, v21.2d, v23.2d + trn2 v31.2d, v21.2d, v23.2d + fmla v0.4s, v28.4s, v16.s[0] + fmla v0.4s, v29.4s, v16.s[1] + fmla v0.4s, v30.4s, v16.s[2] + fmla v0.4s, v31.4s, v16.s[3] + add x21, x21, #16 + add x25, x25, #16 + subs x27, x27, #1 + bgt LoopIc + subs x23, x23, #1 + add x24, x24, x14 + bne LoopKw + add x20, x20, x13 + subs x22, x22, #1 + bne LoopKh + ldr x16, [sp, #272] + cbnz x16, Relu6 + ldr x26, [sp, #264] + cbnz x26, Relu + b Write + Relu6: + fmin v0.4s, v0.4s, v26.4s + Relu: + fmax v0.4s, v0.4s, v27.4s + Write: + st1 {v0.4s}, [x3], x9 + add x17, x17, x12 + subs x28, x28, #1 + bne LoopW + LoopWEnd: + add x0, x0, x8 + add x1, x1, x11 + subs x4, x4, #1 + bne LoopH + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ldp x27, x28, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvSW1x16Kernel.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvSW1x16Kernel.S new file mode 100644 index 0000000000000000000000000000000000000000..3b436c17e728544b50932291feb3845eab50adf8 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvSW1x16Kernel.S @@ -0,0 +1,421 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void SWConv1x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, +// size_t kernel_w, size_t act_flag, size_t out_step, size_t ic_algin, size_t in_kw_step, +// size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode) +// x0: dst, x1: src, x2: weight, x3: bias, x4: kernel_h, x5: kernel_w, x6: act_flag, x7: out_step, +// x10: ic_algin, x11: in_kw_step, x12: in_kh_step, x13: in_sw_step, x14: kw_remainder, x15: write_mode +asm_function SWConv1x16Kernel + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #64 + stp x19, x20, [sp] + stp x21, x22, [sp, #16] + stp x23, x24, [sp, #32] + stp x25, x26, [sp, #48] + + ldr x10, [sp, #64] + ldr x11, [sp, #72] + ldr x12, [sp, #80] + ldr x13, [sp, #88] + ldr x14, [sp, #96] + ldr x15, [sp, #104] + lsl x7, x7, #2 + lsl x11, x11, #2 + lsl x12, x12, #2 + lsl x13, x13, #2 + lsl x14, x14, #2 + add x20, x0, x7 + + cbz x3, InitNoBias + InitWithBias: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x3] + b LoopH + InitNoBias: + dup v0.2d, xzr + dup v1.2d, xzr + dup v2.2d, xzr + dup v3.2d, xzr + + LoopH: + mov x22, x5 + mov x23, x1 + LoopW: + prfm pldl1keep, [x23] + mov x24, x23 + mov x25, x10 + subs x25, x25, #16 + blt LoopC12 + LoopC16: + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x24], #64 + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[0] + fmla v1.4s, v17.4s, v4.s[0] + fmla v2.4s, v18.4s, v4.s[0] + fmla v3.4s, v19.4s, v4.s[0] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[1] + fmla v1.4s, v17.4s, v4.s[1] + fmla v2.4s, v18.4s, v4.s[1] + fmla v3.4s, v19.4s, v4.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[2] + fmla v1.4s, v17.4s, v4.s[2] + fmla v2.4s, v18.4s, v4.s[2] + fmla v3.4s, v19.4s, v4.s[2] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[3] + fmla v1.4s, v17.4s, v4.s[3] + fmla v2.4s, v18.4s, v4.s[3] + fmla v3.4s, v19.4s, v4.s[3] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[0] + fmla v1.4s, v17.4s, v5.s[0] + fmla v2.4s, v18.4s, v5.s[0] + fmla v3.4s, v19.4s, v5.s[0] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[1] + fmla v1.4s, v17.4s, v5.s[1] + fmla v2.4s, v18.4s, v5.s[1] + fmla v3.4s, v19.4s, v5.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[2] + fmla v1.4s, v17.4s, v5.s[2] + fmla v2.4s, v18.4s, v5.s[2] + fmla v3.4s, v19.4s, v5.s[2] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[3] + fmla v1.4s, v17.4s, v5.s[3] + fmla v2.4s, v18.4s, v5.s[3] + fmla v3.4s, v19.4s, v5.s[3] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v6.s[0] + fmla v1.4s, v17.4s, v6.s[0] + fmla v2.4s, v18.4s, v6.s[0] + fmla v3.4s, v19.4s, v6.s[0] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v6.s[1] + fmla v1.4s, v17.4s, v6.s[1] + fmla v2.4s, v18.4s, v6.s[1] + fmla v3.4s, v19.4s, v6.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v6.s[2] + fmla v1.4s, v17.4s, v6.s[2] + fmla v2.4s, v18.4s, v6.s[2] + fmla v3.4s, v19.4s, v6.s[2] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v6.s[3] + fmla v1.4s, v17.4s, v6.s[3] + fmla v2.4s, v18.4s, v6.s[3] + fmla v3.4s, v19.4s, v6.s[3] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v7.s[0] + fmla v1.4s, v17.4s, v7.s[0] + fmla v2.4s, v18.4s, v7.s[0] + fmla v3.4s, v19.4s, v7.s[0] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v7.s[1] + fmla v1.4s, v17.4s, v7.s[1] + fmla v2.4s, v18.4s, v7.s[1] + fmla v3.4s, v19.4s, v7.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v7.s[2] + fmla v1.4s, v17.4s, v7.s[2] + fmla v2.4s, v18.4s, v7.s[2] + fmla v3.4s, v19.4s, v7.s[2] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v7.s[3] + fmla v1.4s, v17.4s, v7.s[3] + fmla v2.4s, v18.4s, v7.s[3] + fmla v3.4s, v19.4s, v7.s[3] + subs x25, x25, #16 + bge LoopC16 + LoopC12: + adds x25, x25, #16 + cbz x25, LoopCEnd + cmp x25, #12 + blt LoopC8 + ld1 {v4.4s, v5.4s, v6.4s}, [x24], #48 + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[0] + fmla v1.4s, v17.4s, v4.s[0] + fmla v2.4s, v18.4s, v4.s[0] + fmla v3.4s, v19.4s, v4.s[0] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[1] + fmla v1.4s, v17.4s, v4.s[1] + fmla v2.4s, v18.4s, v4.s[1] + fmla v3.4s, v19.4s, v4.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[2] + fmla v1.4s, v17.4s, v4.s[2] + fmla v2.4s, v18.4s, v4.s[2] + fmla v3.4s, v19.4s, v4.s[2] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[3] + fmla v1.4s, v17.4s, v4.s[3] + fmla v2.4s, v18.4s, v4.s[3] + fmla v3.4s, v19.4s, v4.s[3] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[0] + fmla v1.4s, v17.4s, v5.s[0] + fmla v2.4s, v18.4s, v5.s[0] + fmla v3.4s, v19.4s, v5.s[0] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[1] + fmla v1.4s, v17.4s, v5.s[1] + fmla v2.4s, v18.4s, v5.s[1] + fmla v3.4s, v19.4s, v5.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[2] + fmla v1.4s, v17.4s, v5.s[2] + fmla v2.4s, v18.4s, v5.s[2] + fmla v3.4s, v19.4s, v5.s[2] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[3] + fmla v1.4s, v17.4s, v5.s[3] + fmla v2.4s, v18.4s, v5.s[3] + fmla v3.4s, v19.4s, v5.s[3] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v6.s[0] + fmla v1.4s, v17.4s, v6.s[0] + fmla v2.4s, v18.4s, v6.s[0] + fmla v3.4s, v19.4s, v6.s[0] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v6.s[1] + fmla v1.4s, v17.4s, v6.s[1] + fmla v2.4s, v18.4s, v6.s[1] + fmla v3.4s, v19.4s, v6.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v6.s[2] + fmla v1.4s, v17.4s, v6.s[2] + fmla v2.4s, v18.4s, v6.s[2] + fmla v3.4s, v19.4s, v6.s[2] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v6.s[3] + fmla v1.4s, v17.4s, v6.s[3] + fmla v2.4s, v18.4s, v6.s[3] + fmla v3.4s, v19.4s, v6.s[3] + sub x25, x25, #12 + b LoopCTail + LoopC8: + cmp x25, #8 + blt LoopC4 + ld1 {v4.4s, v5.4s}, [x24], #32 + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[0] + fmla v1.4s, v17.4s, v4.s[0] + fmla v2.4s, v18.4s, v4.s[0] + fmla v3.4s, v19.4s, v4.s[0] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[1] + fmla v1.4s, v17.4s, v4.s[1] + fmla v2.4s, v18.4s, v4.s[1] + fmla v3.4s, v19.4s, v4.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[2] + fmla v1.4s, v17.4s, v4.s[2] + fmla v2.4s, v18.4s, v4.s[2] + fmla v3.4s, v19.4s, v4.s[2] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[3] + fmla v1.4s, v17.4s, v4.s[3] + fmla v2.4s, v18.4s, v4.s[3] + fmla v3.4s, v19.4s, v4.s[3] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[0] + fmla v1.4s, v17.4s, v5.s[0] + fmla v2.4s, v18.4s, v5.s[0] + fmla v3.4s, v19.4s, v5.s[0] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[1] + fmla v1.4s, v17.4s, v5.s[1] + fmla v2.4s, v18.4s, v5.s[1] + fmla v3.4s, v19.4s, v5.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[2] + fmla v1.4s, v17.4s, v5.s[2] + fmla v2.4s, v18.4s, v5.s[2] + fmla v3.4s, v19.4s, v5.s[2] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[3] + fmla v1.4s, v17.4s, v5.s[3] + fmla v2.4s, v18.4s, v5.s[3] + fmla v3.4s, v19.4s, v5.s[3] + sub x25, x25, #8 + b LoopCTail + LoopC4: + cmp x25, #4 + blt LoopCTail + ld1 {v4.4s}, [x24], #16 + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[0] + fmla v1.4s, v17.4s, v4.s[0] + fmla v2.4s, v18.4s, v4.s[0] + fmla v3.4s, v19.4s, v4.s[0] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[1] + fmla v1.4s, v17.4s, v4.s[1] + fmla v2.4s, v18.4s, v4.s[1] + fmla v3.4s, v19.4s, v4.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[2] + fmla v1.4s, v17.4s, v4.s[2] + fmla v2.4s, v18.4s, v4.s[2] + fmla v3.4s, v19.4s, v4.s[2] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[3] + fmla v1.4s, v17.4s, v4.s[3] + fmla v2.4s, v18.4s, v4.s[3] + fmla v3.4s, v19.4s, v4.s[3] + sub x25, x25, #4 + LoopCTail: + cbz x25, LoopCEnd + cmp x25, #2 + beq LoopC2 + cmp x25, #1 + beq LoopC1 + // LoopC3 + ld3r {v4.4s, v5.4s, v6.4s}, [x24] + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.4s + fmla v1.4s, v17.4s, v4.4s + fmla v2.4s, v18.4s, v4.4s + fmla v3.4s, v19.4s, v4.4s + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.4s + fmla v1.4s, v17.4s, v5.4s + fmla v2.4s, v18.4s, v5.4s + fmla v3.4s, v19.4s, v5.4s + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v6.4s + fmla v1.4s, v17.4s, v6.4s + fmla v2.4s, v18.4s, v6.4s + fmla v3.4s, v19.4s, v6.4s + b LoopCEnd + LoopC2: + ld1 {v4.d}[0], [x24] + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[0] + fmla v1.4s, v17.4s, v4.s[0] + fmla v2.4s, v18.4s, v4.s[0] + fmla v3.4s, v19.4s, v4.s[0] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[1] + fmla v1.4s, v17.4s, v4.s[1] + fmla v2.4s, v18.4s, v4.s[1] + fmla v3.4s, v19.4s, v4.s[1] + b LoopCEnd + LoopC1: + ld1r {v4.4s}, [x24] + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.4s + fmla v1.4s, v17.4s, v4.4s + fmla v2.4s, v18.4s, v4.4s + fmla v3.4s, v19.4s, v4.4s + LoopCEnd: + add x23, x23, x11 + subs x22, x22, #1 + bgt LoopW + add x1, x1, x12 + add x2, x2, x14 + subs x4, x4, #1 + bgt LoopH + + ands x6, x6, #3 + beq WriteBack + dup v4.2d, xzr // relu + fmax v0.4s, v0.4s, v4.4s + fmax v1.4s, v1.4s, v4.4s + fmax v2.4s, v2.4s, v4.4s + fmax v3.4s, v3.4s, v4.4s + + ands x6, x6, #1 + beq WriteBack + movi v4.4s, #6 // relu6 + scvtf v4.4s, v4.4s + fmin v0.4s, v0.4s, v4.4s + fmin v1.4s, v1.4s, v4.4s + fmin v2.4s, v2.4s, v4.4s + fmin v3.4s, v3.4s, v4.4s + fmin v4.4s, v4.4s, v4.4s + + WriteBack: + cmp x15, #13 + beq NC4HW4 + st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0] + b End + NC4HW4: + add x21, x0, x7, LSL #1 + add x22, x20, x7, LSL #1 + st1 {v0.4s}, [x0] + st1 {v1.4s}, [x20] + st1 {v2.4s}, [x21] + st1 {v3.4s}, [x22] + End: + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvSW1x8Kernel.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvSW1x8Kernel.S new file mode 100644 index 0000000000000000000000000000000000000000..6a29e95e3c8fd3e4c3b0b730be277aa738508fc8 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvSW1x8Kernel.S @@ -0,0 +1,278 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void SWConv1x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, +// size_t kernel_w, size_t act_flag, size_t out_step, size_t ic_algin, size_t in_kw_step, +// size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode) +// x0: dst, x1: src, x2: weight, x3: bias, x4: kernel_h, x5: kernel_w, x6: act_flag, x7: out_step, +// x10: ic_algin, x11: in_kw_step, x12: in_kh_step, x13: in_sw_step, x14: kw_remainder, x15: write_mode +asm_function SWConv1x8Kernel + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #64 + stp x19, x20, [sp] + stp x21, x22, [sp, #16] + stp x23, x24, [sp, #32] + stp x25, x26, [sp, #48] + + ldr x10, [sp, #64] + ldr x11, [sp, #72] + ldr x12, [sp, #80] + ldr x13, [sp, #88] + ldr x14, [sp, #96] + ldr x15, [sp, #104] + lsl x7, x7, #2 + lsl x11, x11, #2 + lsl x12, x12, #2 + lsl x13, x13, #2 + lsl x14, x14, #2 + add x20, x0, x7 + + cbz x3, InitNoBias + InitWithBias: + ld1 {v0.4s, v1.4s}, [x3] + b LoopH + InitNoBias: + dup v0.2d, xzr + dup v1.2d, xzr + + LoopH: + mov x22, x5 + mov x23, x1 + LoopW: + prfm pldl1keep, [x23] + mov x24, x23 + mov x25, x10 + subs x25, x25, #16 + blt LoopC12 + LoopC16: + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x24], #64 + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[0] + fmla v1.4s, v17.4s, v4.s[0] + fmla v0.4s, v18.4s, v4.s[1] + fmla v1.4s, v19.4s, v4.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[2] + fmla v1.4s, v17.4s, v4.s[2] + fmla v0.4s, v18.4s, v4.s[3] + fmla v1.4s, v19.4s, v4.s[3] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[0] + fmla v1.4s, v17.4s, v5.s[0] + fmla v0.4s, v18.4s, v5.s[1] + fmla v1.4s, v19.4s, v5.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[2] + fmla v1.4s, v17.4s, v5.s[2] + fmla v0.4s, v18.4s, v5.s[3] + fmla v1.4s, v19.4s, v5.s[3] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v6.s[0] + fmla v1.4s, v17.4s, v6.s[0] + fmla v0.4s, v18.4s, v6.s[1] + fmla v1.4s, v19.4s, v6.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v6.s[2] + fmla v1.4s, v17.4s, v6.s[2] + fmla v0.4s, v18.4s, v6.s[3] + fmla v1.4s, v19.4s, v6.s[3] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v7.s[0] + fmla v1.4s, v17.4s, v7.s[0] + fmla v0.4s, v18.4s, v7.s[1] + fmla v1.4s, v19.4s, v7.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v7.s[2] + fmla v1.4s, v17.4s, v7.s[2] + fmla v0.4s, v18.4s, v7.s[3] + fmla v1.4s, v19.4s, v7.s[3] + subs x25, x25, #16 + bge LoopC16 + LoopC12: + adds x25, x25, #16 + cbz x25, LoopCEnd + cmp x25, #12 + blt LoopC8 + ld1 {v4.4s, v5.4s, v6.4s}, [x24], #48 + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[0] + fmla v1.4s, v17.4s, v4.s[0] + fmla v0.4s, v18.4s, v4.s[1] + fmla v1.4s, v19.4s, v4.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[2] + fmla v1.4s, v17.4s, v4.s[2] + fmla v0.4s, v18.4s, v4.s[3] + fmla v1.4s, v19.4s, v4.s[3] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[0] + fmla v1.4s, v17.4s, v5.s[0] + fmla v0.4s, v18.4s, v5.s[1] + fmla v1.4s, v19.4s, v5.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[2] + fmla v1.4s, v17.4s, v5.s[2] + fmla v0.4s, v18.4s, v5.s[3] + fmla v1.4s, v19.4s, v5.s[3] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v6.s[0] + fmla v1.4s, v17.4s, v6.s[0] + fmla v0.4s, v18.4s, v6.s[1] + fmla v1.4s, v19.4s, v6.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v6.s[2] + fmla v1.4s, v17.4s, v6.s[2] + fmla v0.4s, v18.4s, v6.s[3] + fmla v1.4s, v19.4s, v6.s[3] + sub x25, x25, #12 + b LoopCTail + LoopC8: + cmp x25, #8 + blt LoopC4 + ld1 {v4.4s, v5.4s}, [x24], #32 + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[0] + fmla v1.4s, v17.4s, v4.s[0] + fmla v0.4s, v18.4s, v4.s[1] + fmla v1.4s, v19.4s, v4.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[2] + fmla v1.4s, v17.4s, v4.s[2] + fmla v0.4s, v18.4s, v4.s[3] + fmla v1.4s, v19.4s, v4.s[3] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[0] + fmla v1.4s, v17.4s, v5.s[0] + fmla v0.4s, v18.4s, v5.s[1] + fmla v1.4s, v19.4s, v5.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v5.s[2] + fmla v1.4s, v17.4s, v5.s[2] + fmla v0.4s, v18.4s, v5.s[3] + fmla v1.4s, v19.4s, v5.s[3] + sub x25, x25, #8 + b LoopCTail + LoopC4: + cmp x25, #4 + blt LoopCTail + ld1 {v4.4s}, [x24], #16 + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[0] + fmla v1.4s, v17.4s, v4.s[0] + fmla v0.4s, v18.4s, v4.s[1] + fmla v1.4s, v19.4s, v4.s[1] + + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[2] + fmla v1.4s, v17.4s, v4.s[2] + fmla v0.4s, v18.4s, v4.s[3] + fmla v1.4s, v19.4s, v4.s[3] + sub x25, x25, #4 + LoopCTail: + cbz x25, LoopCEnd + cmp x25, #2 + beq LoopC2 + cmp x25, #1 + beq LoopC1 + // LoopC3 + ld3r {v4.4s, v5.4s, v6.4s}, [x24] + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + ld1 {v20.4s, v21.4s}, [x2], #32 + fmla v0.4s, v16.4s, v4.4s + fmla v1.4s, v17.4s, v4.4s + fmla v0.4s, v18.4s, v5.4s + fmla v1.4s, v19.4s, v5.4s + fmla v0.4s, v20.4s, v6.4s + fmla v1.4s, v21.4s, v6.4s + b LoopCEnd + LoopC2: + ld1 {v4.2s}, [x24] + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + fmla v0.4s, v16.4s, v4.s[0] + fmla v1.4s, v17.4s, v4.s[0] + fmla v0.4s, v18.4s, v4.s[1] + fmla v1.4s, v19.4s, v4.s[1] + b LoopCEnd + LoopC1: + ld1r {v4.4s}, [x24] + ld1 {v16.4s, v17.4s}, [x2], #32 + fmla v0.4s, v16.4s, v4.4s + fmla v1.4s, v17.4s, v4.4s + LoopCEnd: + add x23, x23, x11 + subs x22, x22, #1 + bgt LoopW + add x1, x1, x12 + add x2, x2, x14 + subs x4, x4, #1 + bgt LoopH + + ands x6, x6, #3 + beq WriteBack + dup v4.2d, xzr // relu + fmax v0.4s, v0.4s, v4.4s + fmax v1.4s, v1.4s, v4.4s + fmax v2.4s, v2.4s, v4.4s + fmax v3.4s, v3.4s, v4.4s + + ands x6, x6, #1 + beq WriteBack + movi v4.4s, #6 // relu6 + scvtf v4.4s, v4.4s + fmin v0.4s, v0.4s, v4.4s + fmin v1.4s, v1.4s, v4.4s + fmin v2.4s, v2.4s, v4.4s + fmin v3.4s, v3.4s, v4.4s + fmin v4.4s, v4.4s, v4.4s + + WriteBack: + cmp x15, #13 + beq NC4HW4 + st1 {v0.4s, v1.4s}, [x0] + b End + NC4HW4: + st1 {v0.4s}, [x0] + st1 {v1.4s}, [x20] + End: + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvSW2x16Kernel.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvSW2x16Kernel.S new file mode 100644 index 0000000000000000000000000000000000000000..8a5dd83a0736ca763a967e922ed26ec85e1cc074 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvSW2x16Kernel.S @@ -0,0 +1,407 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void SWConv2x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, +// size_t kernel_w, size_t act_flag, size_t out_step, size_t ic_algin, size_t in_kw_step, +// size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode) +// x0: dst, x1: src, x2: weight, x3: bias, x4: kernel_h, x5: kernel_w, x6: act_flag, x7: out_step, +// x10: ic_algin, x11: in_kw_step, x12: in_kh_step, x13: in_sw_step, x14: kw_remainder, x15: write_mode +asm_function SWConv2x16Kernel + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #64 + stp x19, x20, [sp] + stp x21, x22, [sp, #16] + stp x23, x24, [sp, #32] + stp x25, x26, [sp, #48] + + ldr x10, [sp, #64] + ldr x11, [sp, #72] + ldr x12, [sp, #80] + ldr x13, [sp, #88] + ldr x14, [sp, #96] + ldr x15, [sp, #104] + lsl x7, x7, #2 + lsl x11, x11, #2 + lsl x12, x12, #2 + lsl x13, x13, #2 + lsl x14, x14, #2 + add x20, x0, x7 + + cbz x3, InitNoBias + InitWithBias: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x3] + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x3] + b LoopH + InitNoBias: + dup v0.2d, xzr + dup v1.2d, xzr + dup v2.2d, xzr + dup v3.2d, xzr + dup v4.2d, xzr + dup v5.2d, xzr + dup v6.2d, xzr + dup v7.2d, xzr + + LoopH: + mov x22, x5 + mov x23, x1 + LoopW: + mov x25, x10 + prfm pldl1keep, [x23] + mov x24, x23 + add x26, x23, x13 + prfm pldl1keep, [x26] + subs x25, x25, #12 + blt LoopC8 + LoopC12: + ld1 {v16.4s, v17.4s, v18.4s}, [x24], #48 + ld1 {v19.4s, v20.4s, v21.4s}, [x26], #48 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v30.4s, v16.s[0] + fmla v3.4s, v31.4s, v16.s[0] + fmla v4.4s, v28.4s, v19.s[0] + fmla v5.4s, v29.4s, v19.s[0] + fmla v6.4s, v30.4s, v19.s[0] + fmla v7.4s, v31.4s, v19.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[1] + fmla v1.4s, v29.4s, v16.s[1] + fmla v2.4s, v30.4s, v16.s[1] + fmla v3.4s, v31.4s, v16.s[1] + fmla v4.4s, v28.4s, v19.s[1] + fmla v5.4s, v29.4s, v19.s[1] + fmla v6.4s, v30.4s, v19.s[1] + fmla v7.4s, v31.4s, v19.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v30.4s, v16.s[2] + fmla v3.4s, v31.4s, v16.s[2] + fmla v4.4s, v28.4s, v19.s[2] + fmla v5.4s, v29.4s, v19.s[2] + fmla v6.4s, v30.4s, v19.s[2] + fmla v7.4s, v31.4s, v19.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[3] + fmla v1.4s, v29.4s, v16.s[3] + fmla v2.4s, v30.4s, v16.s[3] + fmla v3.4s, v31.4s, v16.s[3] + fmla v4.4s, v28.4s, v19.s[3] + fmla v5.4s, v29.4s, v19.s[3] + fmla v6.4s, v30.4s, v19.s[3] + fmla v7.4s, v31.4s, v19.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[0] + fmla v1.4s, v29.4s, v17.s[0] + fmla v2.4s, v30.4s, v17.s[0] + fmla v3.4s, v31.4s, v17.s[0] + fmla v4.4s, v28.4s, v20.s[0] + fmla v5.4s, v29.4s, v20.s[0] + fmla v6.4s, v30.4s, v20.s[0] + fmla v7.4s, v31.4s, v20.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[1] + fmla v1.4s, v29.4s, v17.s[1] + fmla v2.4s, v30.4s, v17.s[1] + fmla v3.4s, v31.4s, v17.s[1] + fmla v4.4s, v28.4s, v20.s[1] + fmla v5.4s, v29.4s, v20.s[1] + fmla v6.4s, v30.4s, v20.s[1] + fmla v7.4s, v31.4s, v20.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[2] + fmla v1.4s, v29.4s, v17.s[2] + fmla v2.4s, v30.4s, v17.s[2] + fmla v3.4s, v31.4s, v17.s[2] + fmla v4.4s, v28.4s, v20.s[2] + fmla v5.4s, v29.4s, v20.s[2] + fmla v6.4s, v30.4s, v20.s[2] + fmla v7.4s, v31.4s, v20.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[3] + fmla v1.4s, v29.4s, v17.s[3] + fmla v2.4s, v30.4s, v17.s[3] + fmla v3.4s, v31.4s, v17.s[3] + fmla v4.4s, v28.4s, v20.s[3] + fmla v5.4s, v29.4s, v20.s[3] + fmla v6.4s, v30.4s, v20.s[3] + fmla v7.4s, v31.4s, v20.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[0] + fmla v1.4s, v29.4s, v18.s[0] + fmla v2.4s, v30.4s, v18.s[0] + fmla v3.4s, v31.4s, v18.s[0] + fmla v4.4s, v28.4s, v21.s[0] + fmla v5.4s, v29.4s, v21.s[0] + fmla v6.4s, v30.4s, v21.s[0] + fmla v7.4s, v31.4s, v21.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[1] + fmla v1.4s, v29.4s, v18.s[1] + fmla v2.4s, v30.4s, v18.s[1] + fmla v3.4s, v31.4s, v18.s[1] + fmla v4.4s, v28.4s, v21.s[1] + fmla v5.4s, v29.4s, v21.s[1] + fmla v6.4s, v30.4s, v21.s[1] + fmla v7.4s, v31.4s, v21.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[2] + fmla v1.4s, v29.4s, v18.s[2] + fmla v2.4s, v30.4s, v18.s[2] + fmla v3.4s, v31.4s, v18.s[2] + fmla v4.4s, v28.4s, v21.s[2] + fmla v5.4s, v29.4s, v21.s[2] + fmla v6.4s, v30.4s, v21.s[2] + fmla v7.4s, v31.4s, v21.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[3] + fmla v1.4s, v29.4s, v18.s[3] + fmla v2.4s, v30.4s, v18.s[3] + fmla v3.4s, v31.4s, v18.s[3] + fmla v4.4s, v28.4s, v21.s[3] + fmla v5.4s, v29.4s, v21.s[3] + fmla v6.4s, v30.4s, v21.s[3] + fmla v7.4s, v31.4s, v21.s[3] + subs x25, x25, #12 + bge LoopC12 + LoopC8: + adds x25, x25, #12 + cbz x25, LoopCEnd + cmp x25, #8 + blt LoopC4 + ld1 {v16.4s, v17.4s}, [x24], #32 + ld1 {v19.4s, v20.4s}, [x26], #32 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v30.4s, v16.s[0] + fmla v3.4s, v31.4s, v16.s[0] + fmla v4.4s, v28.4s, v19.s[0] + fmla v5.4s, v29.4s, v19.s[0] + fmla v6.4s, v30.4s, v19.s[0] + fmla v7.4s, v31.4s, v19.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[1] + fmla v1.4s, v29.4s, v16.s[1] + fmla v2.4s, v30.4s, v16.s[1] + fmla v3.4s, v31.4s, v16.s[1] + fmla v4.4s, v28.4s, v19.s[1] + fmla v5.4s, v29.4s, v19.s[1] + fmla v6.4s, v30.4s, v19.s[1] + fmla v7.4s, v31.4s, v19.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v30.4s, v16.s[2] + fmla v3.4s, v31.4s, v16.s[2] + fmla v4.4s, v28.4s, v19.s[2] + fmla v5.4s, v29.4s, v19.s[2] + fmla v6.4s, v30.4s, v19.s[2] + fmla v7.4s, v31.4s, v19.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[3] + fmla v1.4s, v29.4s, v16.s[3] + fmla v2.4s, v30.4s, v16.s[3] + fmla v3.4s, v31.4s, v16.s[3] + fmla v4.4s, v28.4s, v19.s[3] + fmla v5.4s, v29.4s, v19.s[3] + fmla v6.4s, v30.4s, v19.s[3] + fmla v7.4s, v31.4s, v19.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[0] + fmla v1.4s, v29.4s, v17.s[0] + fmla v2.4s, v30.4s, v17.s[0] + fmla v3.4s, v31.4s, v17.s[0] + fmla v4.4s, v28.4s, v20.s[0] + fmla v5.4s, v29.4s, v20.s[0] + fmla v6.4s, v30.4s, v20.s[0] + fmla v7.4s, v31.4s, v20.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[1] + fmla v1.4s, v29.4s, v17.s[1] + fmla v2.4s, v30.4s, v17.s[1] + fmla v3.4s, v31.4s, v17.s[1] + fmla v4.4s, v28.4s, v20.s[1] + fmla v5.4s, v29.4s, v20.s[1] + fmla v6.4s, v30.4s, v20.s[1] + fmla v7.4s, v31.4s, v20.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[2] + fmla v1.4s, v29.4s, v17.s[2] + fmla v2.4s, v30.4s, v17.s[2] + fmla v3.4s, v31.4s, v17.s[2] + fmla v4.4s, v28.4s, v20.s[2] + fmla v5.4s, v29.4s, v20.s[2] + fmla v6.4s, v30.4s, v20.s[2] + fmla v7.4s, v31.4s, v20.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[3] + fmla v1.4s, v29.4s, v17.s[3] + fmla v2.4s, v30.4s, v17.s[3] + fmla v3.4s, v31.4s, v17.s[3] + fmla v4.4s, v28.4s, v20.s[3] + fmla v5.4s, v29.4s, v20.s[3] + fmla v6.4s, v30.4s, v20.s[3] + fmla v7.4s, v31.4s, v20.s[3] + sub x25, x25, #8 + b LoopCTail + LoopC4: + cmp x25, #4 + blt LoopCTail + ld1 {v16.4s}, [x24], #16 + ld1 {v19.4s}, [x26], #16 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v30.4s, v16.s[0] + fmla v3.4s, v31.4s, v16.s[0] + fmla v4.4s, v28.4s, v19.s[0] + fmla v5.4s, v29.4s, v19.s[0] + fmla v6.4s, v30.4s, v19.s[0] + fmla v7.4s, v31.4s, v19.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[1] + fmla v1.4s, v29.4s, v16.s[1] + fmla v2.4s, v30.4s, v16.s[1] + fmla v3.4s, v31.4s, v16.s[1] + fmla v4.4s, v28.4s, v19.s[1] + fmla v5.4s, v29.4s, v19.s[1] + fmla v6.4s, v30.4s, v19.s[1] + fmla v7.4s, v31.4s, v19.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v30.4s, v16.s[2] + fmla v3.4s, v31.4s, v16.s[2] + fmla v4.4s, v28.4s, v19.s[2] + fmla v5.4s, v29.4s, v19.s[2] + fmla v6.4s, v30.4s, v19.s[2] + fmla v7.4s, v31.4s, v19.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[3] + fmla v1.4s, v29.4s, v16.s[3] + fmla v2.4s, v30.4s, v16.s[3] + fmla v3.4s, v31.4s, v16.s[3] + fmla v4.4s, v28.4s, v19.s[3] + fmla v5.4s, v29.4s, v19.s[3] + fmla v6.4s, v30.4s, v19.s[3] + fmla v7.4s, v31.4s, v19.s[3] + sub x25, x25, #4 + LoopCTail: + cbz x25, LoopCEnd + LoopCTailCycle: + ld1 {v16.s}[0], [x24], #4 + ld1 {v19.s}[0], [x26], #4 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v30.4s, v16.s[0] + fmla v3.4s, v31.4s, v16.s[0] + fmla v4.4s, v28.4s, v19.s[0] + fmla v5.4s, v29.4s, v19.s[0] + fmla v6.4s, v30.4s, v19.s[0] + fmla v7.4s, v31.4s, v19.s[0] + subs x25, x25, #1 + bgt LoopCTailCycle + LoopCEnd: + add x23, x23, x11 + subs x22, x22, #1 + bgt LoopW + add x1, x1, x12 + add x2, x2, x14 + subs x4, x4, #1 + bgt LoopH + + ands x6, x6, #3 + beq WriteBack + dup v24.2d, xzr // relu + fmax v0.4s, v0.4s, v24.4s + fmax v1.4s, v1.4s, v24.4s + fmax v2.4s, v2.4s, v24.4s + fmax v3.4s, v3.4s, v24.4s + fmax v4.4s, v4.4s, v24.4s + fmax v5.4s, v5.4s, v24.4s + fmax v6.4s, v6.4s, v24.4s + fmax v7.4s, v7.4s, v24.4s + + ands x6, x6, #1 + beq WriteBack + movi v24.4s, #6 // relu6 + scvtf v24.4s, v24.4s + fmin v0.4s, v0.4s, v24.4s + fmin v1.4s, v1.4s, v24.4s + fmin v2.4s, v2.4s, v24.4s + fmin v3.4s, v3.4s, v24.4s + fmin v4.4s, v4.4s, v24.4s + fmin v5.4s, v5.4s, v24.4s + fmin v6.4s, v6.4s, v24.4s + fmin v7.4s, v7.4s, v24.4s + + WriteBack: + cmp x15, #13 + beq NC4HW4 + st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0] + st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x20] + b End + NC4HW4: + add x21, x0, x7, LSL #1 + add x22, x21, x7 + st1 {v0.4s}, [x0], #16 + st1 {v4.4s}, [x0] + st1 {v1.4s}, [x20], #16 + st1 {v5.4s}, [x20] + st1 {v2.4s}, [x21], #16 + st1 {v6.4s}, [x21] + st1 {v3.4s}, [x22], #16 + st1 {v7.4s}, [x22] + End: + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvSW2x8Kernel.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvSW2x8Kernel.S new file mode 100644 index 0000000000000000000000000000000000000000..6efd21d0e73639a0ec8ecc71e9bb93333b44d594 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvSW2x8Kernel.S @@ -0,0 +1,265 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void SWConv2x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, +// size_t kernel_w, size_t act_flag, size_t out_step, size_t ic_algin, size_t in_kw_step, +// size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode) +// x0: dst, x1: src, x2: weight, x3: bias, x4: kernel_h, x5: kernel_w, x6: act_flag, x7: out_step, +// x10: ic_algin, x11: in_kw_step, x12: in_kh_step, x13: in_sw_step, x14: kw_remainder, x15: write_mode +asm_function SWConv2x8Kernel + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #64 + stp x19, x20, [sp] + stp x21, x22, [sp, #16] + stp x23, x24, [sp, #32] + stp x25, x26, [sp, #48] + + ldr x10, [sp, #64] + ldr x11, [sp, #72] + ldr x12, [sp, #80] + ldr x13, [sp, #88] + ldr x14, [sp, #96] + ldr x15, [sp, #104] + lsl x7, x7, #2 + lsl x11, x11, #2 + lsl x12, x12, #2 + lsl x13, x13, #2 + lsl x14, x14, #2 + add x20, x0, x7 + + cbz x3, InitNoBias + InitWithBias: + ld1 {v0.4s, v1.4s}, [x3] + ld1 {v2.4s, v3.4s}, [x3] + b LoopH + InitNoBias: + dup v0.2d, xzr + dup v1.2d, xzr + dup v2.2d, xzr + dup v3.2d, xzr + + LoopH: + mov x22, x5 + mov x23, x1 + LoopW: + mov x25, x10 + prfm pldl1keep, [x23] + mov x24, x23 + add x26, x23, x13 + prfm pldl1keep, [x26] + subs x25, x25, #12 + blt LoopC8 + LoopC12: + ld1 {v16.4s, v17.4s, v18.4s}, [x24], #48 + ld1 {v19.4s, v20.4s, v21.4s}, [x26], #48 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v28.4s, v19.s[0] + fmla v3.4s, v29.4s, v19.s[0] + fmla v0.4s, v30.4s, v16.s[1] + fmla v1.4s, v31.4s, v16.s[1] + fmla v2.4s, v30.4s, v19.s[1] + fmla v3.4s, v31.4s, v19.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v28.4s, v19.s[2] + fmla v3.4s, v29.4s, v19.s[2] + fmla v0.4s, v30.4s, v16.s[3] + fmla v1.4s, v31.4s, v16.s[3] + fmla v2.4s, v30.4s, v19.s[3] + fmla v3.4s, v31.4s, v19.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[0] + fmla v1.4s, v29.4s, v17.s[0] + fmla v2.4s, v28.4s, v20.s[0] + fmla v3.4s, v29.4s, v20.s[0] + fmla v0.4s, v30.4s, v17.s[1] + fmla v1.4s, v31.4s, v17.s[1] + fmla v2.4s, v30.4s, v20.s[1] + fmla v3.4s, v31.4s, v20.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[2] + fmla v1.4s, v29.4s, v17.s[2] + fmla v2.4s, v28.4s, v20.s[2] + fmla v3.4s, v29.4s, v20.s[2] + fmla v0.4s, v30.4s, v17.s[3] + fmla v1.4s, v31.4s, v17.s[3] + fmla v2.4s, v30.4s, v20.s[3] + fmla v3.4s, v31.4s, v20.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[0] + fmla v1.4s, v29.4s, v18.s[0] + fmla v2.4s, v28.4s, v21.s[0] + fmla v3.4s, v29.4s, v21.s[0] + fmla v0.4s, v30.4s, v18.s[1] + fmla v1.4s, v31.4s, v18.s[1] + fmla v2.4s, v30.4s, v21.s[1] + fmla v3.4s, v31.4s, v21.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[2] + fmla v1.4s, v29.4s, v18.s[2] + fmla v2.4s, v28.4s, v21.s[2] + fmla v3.4s, v29.4s, v21.s[2] + fmla v0.4s, v30.4s, v18.s[3] + fmla v1.4s, v31.4s, v18.s[3] + fmla v2.4s, v30.4s, v21.s[3] + fmla v3.4s, v31.4s, v21.s[3] + subs x25, x25, #12 + bge LoopC12 + LoopC8: + adds x25, x25, #12 + cbz x25, LoopCEnd + cmp x25, #8 + blt LoopC4 + ld1 {v16.4s, v17.4s}, [x24], #32 + ld1 {v19.4s, v20.4s}, [x26], #32 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v28.4s, v19.s[0] + fmla v3.4s, v29.4s, v19.s[0] + fmla v0.4s, v30.4s, v16.s[1] + fmla v1.4s, v31.4s, v16.s[1] + fmla v2.4s, v30.4s, v19.s[1] + fmla v3.4s, v31.4s, v19.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v28.4s, v19.s[2] + fmla v3.4s, v29.4s, v19.s[2] + fmla v0.4s, v30.4s, v16.s[3] + fmla v1.4s, v31.4s, v16.s[3] + fmla v2.4s, v30.4s, v19.s[3] + fmla v3.4s, v31.4s, v19.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[0] + fmla v1.4s, v29.4s, v17.s[0] + fmla v2.4s, v28.4s, v20.s[0] + fmla v3.4s, v29.4s, v20.s[0] + fmla v0.4s, v30.4s, v17.s[1] + fmla v1.4s, v31.4s, v17.s[1] + fmla v2.4s, v30.4s, v20.s[1] + fmla v3.4s, v31.4s, v20.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[2] + fmla v1.4s, v29.4s, v17.s[2] + fmla v2.4s, v28.4s, v20.s[2] + fmla v3.4s, v29.4s, v20.s[2] + fmla v0.4s, v30.4s, v17.s[3] + fmla v1.4s, v31.4s, v17.s[3] + fmla v2.4s, v30.4s, v20.s[3] + fmla v3.4s, v31.4s, v20.s[3] + sub x25, x25, #8 + b LoopCTail + LoopC4: + cmp x25, #4 + blt LoopCTail + ld1 {v16.4s}, [x24], #16 + ld1 {v19.4s}, [x26], #16 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v28.4s, v19.s[0] + fmla v3.4s, v29.4s, v19.s[0] + fmla v0.4s, v30.4s, v16.s[1] + fmla v1.4s, v31.4s, v16.s[1] + fmla v2.4s, v30.4s, v19.s[1] + fmla v3.4s, v31.4s, v19.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v28.4s, v19.s[2] + fmla v3.4s, v29.4s, v19.s[2] + fmla v0.4s, v30.4s, v16.s[3] + fmla v1.4s, v31.4s, v16.s[3] + fmla v2.4s, v30.4s, v19.s[3] + fmla v3.4s, v31.4s, v19.s[3] + sub x25, x25, #4 + LoopCTail: + cbz x25, LoopCEnd + LoopCTailCycle: + ld1 {v16.s}[0], [x24], #4 + ld1 {v19.s}[0], [x26], #4 + ld1 {v28.4s, v29.4s}, [x2], #32 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v28.4s, v19.s[0] + fmla v3.4s, v29.4s, v19.s[0] + subs x25, x25, #1 + bgt LoopCTailCycle + LoopCEnd: + add x23, x23, x11 + subs x22, x22, #1 + bgt LoopW + add x1, x1, x12 + add x2, x2, x14 + subs x4, x4, #1 + bgt LoopH + + ands x6, x6, #3 + beq WriteBack + dup v24.2d, xzr // relu + fmax v0.4s, v0.4s, v24.4s + fmax v1.4s, v1.4s, v24.4s + fmax v2.4s, v2.4s, v24.4s + fmax v3.4s, v3.4s, v24.4s + + ands x6, x6, #1 + beq WriteBack + movi v24.4s, #6 // relu6 + scvtf v24.4s, v24.4s + fmin v0.4s, v0.4s, v24.4s + fmin v1.4s, v1.4s, v24.4s + fmin v2.4s, v2.4s, v24.4s + fmin v3.4s, v3.4s, v24.4s + + WriteBack: + cmp x15, #13 + beq NC4HW4 + st1 {v0.4s, v1.4s}, [x0] + st1 {v2.4s, v3.4s}, [x20] + b End + NC4HW4: + st1 {v0.4s}, [x0], #16 + st1 {v2.4s}, [x0] + st1 {v1.4s}, [x20], #16 + st1 {v3.4s}, [x20] + End: + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvSW3x16Kernel.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvSW3x16Kernel.S new file mode 100644 index 0000000000000000000000000000000000000000..428dea698e13a0860db083c4420040a6cf65e111 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvSW3x16Kernel.S @@ -0,0 +1,533 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void SWConv3x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, +// size_t kernel_w, size_t act_flag, size_t out_step, size_t ic_algin, size_t in_kw_step, +// size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode) +// x0: dst, x1: src, x2: weight, x3: bias, x4: kernel_h, x5: kernel_w, x6: act_flag, x7: out_step, +// x10: ic_algin, x11: in_kw_step, x12: in_kh_step, x13: in_sw_step, x14: kw_remainder, x15: write_mode +asm_function SWConv3x16Kernel + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #128 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + stp x19, x20, [sp, #64] + stp x21, x22, [sp, #80] + stp x23, x24, [sp, #96] + stp x25, x26, [sp, #112] + + ldr x10, [sp, #128] + ldr x11, [sp, #136] + ldr x12, [sp, #144] + ldr x13, [sp, #152] + ldr x14, [sp, #160] + ldr x15, [sp, #168] + lsl x7, x7, #2 + lsl x11, x11, #2 + lsl x12, x12, #2 + lsl x13, x13, #2 + lsl x14, x14, #2 + add x20, x0, x7 + + cbz x3, InitNoBias + InitWithBias: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x3] + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x3] + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x3] + b LoopH + InitNoBias: + dup v0.2d, xzr + dup v1.2d, xzr + dup v2.2d, xzr + dup v3.2d, xzr + dup v4.2d, xzr + dup v5.2d, xzr + dup v6.2d, xzr + dup v7.2d, xzr + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + + LoopH: + mov x22, x5 + mov x23, x1 + LoopW: + mov x25, x10 + prfm pldl1keep, [x23] + mov x24, x23 + add x26, x23, x13 + prfm pldl1keep, [x26] + add x19, x23, x13, lsl #1 + prfm pldl1keep, [x19] + subs x25, x25, #12 + blt LoopC8 + LoopC12: + ld1 {v16.4s, v17.4s, v18.4s}, [x24], #48 + ld1 {v19.4s, v20.4s, v21.4s}, [x26], #48 + ld1 {v22.4s, v23.4s, v24.4s}, [x19], #48 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v30.4s, v16.s[0] + fmla v3.4s, v31.4s, v16.s[0] + fmla v4.4s, v28.4s, v19.s[0] + fmla v5.4s, v29.4s, v19.s[0] + fmla v6.4s, v30.4s, v19.s[0] + fmla v7.4s, v31.4s, v19.s[0] + fmla v8.4s, v28.4s, v22.s[0] + fmla v9.4s, v29.4s, v22.s[0] + fmla v10.4s, v30.4s, v22.s[0] + fmla v11.4s, v31.4s, v22.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[1] + fmla v1.4s, v29.4s, v16.s[1] + fmla v2.4s, v30.4s, v16.s[1] + fmla v3.4s, v31.4s, v16.s[1] + fmla v4.4s, v28.4s, v19.s[1] + fmla v5.4s, v29.4s, v19.s[1] + fmla v6.4s, v30.4s, v19.s[1] + fmla v7.4s, v31.4s, v19.s[1] + fmla v8.4s, v28.4s, v22.s[1] + fmla v9.4s, v29.4s, v22.s[1] + fmla v10.4s, v30.4s, v22.s[1] + fmla v11.4s, v31.4s, v22.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v30.4s, v16.s[2] + fmla v3.4s, v31.4s, v16.s[2] + fmla v4.4s, v28.4s, v19.s[2] + fmla v5.4s, v29.4s, v19.s[2] + fmla v6.4s, v30.4s, v19.s[2] + fmla v7.4s, v31.4s, v19.s[2] + fmla v8.4s, v28.4s, v22.s[2] + fmla v9.4s, v29.4s, v22.s[2] + fmla v10.4s, v30.4s, v22.s[2] + fmla v11.4s, v31.4s, v22.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[3] + fmla v1.4s, v29.4s, v16.s[3] + fmla v2.4s, v30.4s, v16.s[3] + fmla v3.4s, v31.4s, v16.s[3] + fmla v4.4s, v28.4s, v19.s[3] + fmla v5.4s, v29.4s, v19.s[3] + fmla v6.4s, v30.4s, v19.s[3] + fmla v7.4s, v31.4s, v19.s[3] + fmla v8.4s, v28.4s, v22.s[3] + fmla v9.4s, v29.4s, v22.s[3] + fmla v10.4s, v30.4s, v22.s[3] + fmla v11.4s, v31.4s, v22.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[0] + fmla v1.4s, v29.4s, v17.s[0] + fmla v2.4s, v30.4s, v17.s[0] + fmla v3.4s, v31.4s, v17.s[0] + fmla v4.4s, v28.4s, v20.s[0] + fmla v5.4s, v29.4s, v20.s[0] + fmla v6.4s, v30.4s, v20.s[0] + fmla v7.4s, v31.4s, v20.s[0] + fmla v8.4s, v28.4s, v23.s[0] + fmla v9.4s, v29.4s, v23.s[0] + fmla v10.4s, v30.4s, v23.s[0] + fmla v11.4s, v31.4s, v23.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[1] + fmla v1.4s, v29.4s, v17.s[1] + fmla v2.4s, v30.4s, v17.s[1] + fmla v3.4s, v31.4s, v17.s[1] + fmla v4.4s, v28.4s, v20.s[1] + fmla v5.4s, v29.4s, v20.s[1] + fmla v6.4s, v30.4s, v20.s[1] + fmla v7.4s, v31.4s, v20.s[1] + fmla v8.4s, v28.4s, v23.s[1] + fmla v9.4s, v29.4s, v23.s[1] + fmla v10.4s, v30.4s, v23.s[1] + fmla v11.4s, v31.4s, v23.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[2] + fmla v1.4s, v29.4s, v17.s[2] + fmla v2.4s, v30.4s, v17.s[2] + fmla v3.4s, v31.4s, v17.s[2] + fmla v4.4s, v28.4s, v20.s[2] + fmla v5.4s, v29.4s, v20.s[2] + fmla v6.4s, v30.4s, v20.s[2] + fmla v7.4s, v31.4s, v20.s[2] + fmla v8.4s, v28.4s, v23.s[2] + fmla v9.4s, v29.4s, v23.s[2] + fmla v10.4s, v30.4s, v23.s[2] + fmla v11.4s, v31.4s, v23.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[3] + fmla v1.4s, v29.4s, v17.s[3] + fmla v2.4s, v30.4s, v17.s[3] + fmla v3.4s, v31.4s, v17.s[3] + fmla v4.4s, v28.4s, v20.s[3] + fmla v5.4s, v29.4s, v20.s[3] + fmla v6.4s, v30.4s, v20.s[3] + fmla v7.4s, v31.4s, v20.s[3] + fmla v8.4s, v28.4s, v23.s[3] + fmla v9.4s, v29.4s, v23.s[3] + fmla v10.4s, v30.4s, v23.s[3] + fmla v11.4s, v31.4s, v23.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[0] + fmla v1.4s, v29.4s, v18.s[0] + fmla v2.4s, v30.4s, v18.s[0] + fmla v3.4s, v31.4s, v18.s[0] + fmla v4.4s, v28.4s, v21.s[0] + fmla v5.4s, v29.4s, v21.s[0] + fmla v6.4s, v30.4s, v21.s[0] + fmla v7.4s, v31.4s, v21.s[0] + fmla v8.4s, v28.4s, v24.s[0] + fmla v9.4s, v29.4s, v24.s[0] + fmla v10.4s, v30.4s, v24.s[0] + fmla v11.4s, v31.4s, v24.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[1] + fmla v1.4s, v29.4s, v18.s[1] + fmla v2.4s, v30.4s, v18.s[1] + fmla v3.4s, v31.4s, v18.s[1] + fmla v4.4s, v28.4s, v21.s[1] + fmla v5.4s, v29.4s, v21.s[1] + fmla v6.4s, v30.4s, v21.s[1] + fmla v7.4s, v31.4s, v21.s[1] + fmla v8.4s, v28.4s, v24.s[1] + fmla v9.4s, v29.4s, v24.s[1] + fmla v10.4s, v30.4s, v24.s[1] + fmla v11.4s, v31.4s, v24.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[2] + fmla v1.4s, v29.4s, v18.s[2] + fmla v2.4s, v30.4s, v18.s[2] + fmla v3.4s, v31.4s, v18.s[2] + fmla v4.4s, v28.4s, v21.s[2] + fmla v5.4s, v29.4s, v21.s[2] + fmla v6.4s, v30.4s, v21.s[2] + fmla v7.4s, v31.4s, v21.s[2] + fmla v8.4s, v28.4s, v24.s[2] + fmla v9.4s, v29.4s, v24.s[2] + fmla v10.4s, v30.4s, v24.s[2] + fmla v11.4s, v31.4s, v24.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[3] + fmla v1.4s, v29.4s, v18.s[3] + fmla v2.4s, v30.4s, v18.s[3] + fmla v3.4s, v31.4s, v18.s[3] + fmla v4.4s, v28.4s, v21.s[3] + fmla v5.4s, v29.4s, v21.s[3] + fmla v6.4s, v30.4s, v21.s[3] + fmla v7.4s, v31.4s, v21.s[3] + fmla v8.4s, v28.4s, v24.s[3] + fmla v9.4s, v29.4s, v24.s[3] + fmla v10.4s, v30.4s, v24.s[3] + fmla v11.4s, v31.4s, v24.s[3] + subs x25, x25, #12 + bge LoopC12 + LoopC8: + adds x25, x25, #12 + cbz x25, LoopCEnd + cmp x25, #8 + blt LoopC4 + ld1 {v16.4s, v17.4s}, [x24], #32 + ld1 {v19.4s, v20.4s}, [x26], #32 + ld1 {v22.4s, v23.4s}, [x19], #32 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v30.4s, v16.s[0] + fmla v3.4s, v31.4s, v16.s[0] + fmla v4.4s, v28.4s, v19.s[0] + fmla v5.4s, v29.4s, v19.s[0] + fmla v6.4s, v30.4s, v19.s[0] + fmla v7.4s, v31.4s, v19.s[0] + fmla v8.4s, v28.4s, v22.s[0] + fmla v9.4s, v29.4s, v22.s[0] + fmla v10.4s, v30.4s, v22.s[0] + fmla v11.4s, v31.4s, v22.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[1] + fmla v1.4s, v29.4s, v16.s[1] + fmla v2.4s, v30.4s, v16.s[1] + fmla v3.4s, v31.4s, v16.s[1] + fmla v4.4s, v28.4s, v19.s[1] + fmla v5.4s, v29.4s, v19.s[1] + fmla v6.4s, v30.4s, v19.s[1] + fmla v7.4s, v31.4s, v19.s[1] + fmla v8.4s, v28.4s, v22.s[1] + fmla v9.4s, v29.4s, v22.s[1] + fmla v10.4s, v30.4s, v22.s[1] + fmla v11.4s, v31.4s, v22.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v30.4s, v16.s[2] + fmla v3.4s, v31.4s, v16.s[2] + fmla v4.4s, v28.4s, v19.s[2] + fmla v5.4s, v29.4s, v19.s[2] + fmla v6.4s, v30.4s, v19.s[2] + fmla v7.4s, v31.4s, v19.s[2] + fmla v8.4s, v28.4s, v22.s[2] + fmla v9.4s, v29.4s, v22.s[2] + fmla v10.4s, v30.4s, v22.s[2] + fmla v11.4s, v31.4s, v22.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[3] + fmla v1.4s, v29.4s, v16.s[3] + fmla v2.4s, v30.4s, v16.s[3] + fmla v3.4s, v31.4s, v16.s[3] + fmla v4.4s, v28.4s, v19.s[3] + fmla v5.4s, v29.4s, v19.s[3] + fmla v6.4s, v30.4s, v19.s[3] + fmla v7.4s, v31.4s, v19.s[3] + fmla v8.4s, v28.4s, v22.s[3] + fmla v9.4s, v29.4s, v22.s[3] + fmla v10.4s, v30.4s, v22.s[3] + fmla v11.4s, v31.4s, v22.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[0] + fmla v1.4s, v29.4s, v17.s[0] + fmla v2.4s, v30.4s, v17.s[0] + fmla v3.4s, v31.4s, v17.s[0] + fmla v4.4s, v28.4s, v20.s[0] + fmla v5.4s, v29.4s, v20.s[0] + fmla v6.4s, v30.4s, v20.s[0] + fmla v7.4s, v31.4s, v20.s[0] + fmla v8.4s, v28.4s, v23.s[0] + fmla v9.4s, v29.4s, v23.s[0] + fmla v10.4s, v30.4s, v23.s[0] + fmla v11.4s, v31.4s, v23.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[1] + fmla v1.4s, v29.4s, v17.s[1] + fmla v2.4s, v30.4s, v17.s[1] + fmla v3.4s, v31.4s, v17.s[1] + fmla v4.4s, v28.4s, v20.s[1] + fmla v5.4s, v29.4s, v20.s[1] + fmla v6.4s, v30.4s, v20.s[1] + fmla v7.4s, v31.4s, v20.s[1] + fmla v8.4s, v28.4s, v23.s[1] + fmla v9.4s, v29.4s, v23.s[1] + fmla v10.4s, v30.4s, v23.s[1] + fmla v11.4s, v31.4s, v23.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[2] + fmla v1.4s, v29.4s, v17.s[2] + fmla v2.4s, v30.4s, v17.s[2] + fmla v3.4s, v31.4s, v17.s[2] + fmla v4.4s, v28.4s, v20.s[2] + fmla v5.4s, v29.4s, v20.s[2] + fmla v6.4s, v30.4s, v20.s[2] + fmla v7.4s, v31.4s, v20.s[2] + fmla v8.4s, v28.4s, v23.s[2] + fmla v9.4s, v29.4s, v23.s[2] + fmla v10.4s, v30.4s, v23.s[2] + fmla v11.4s, v31.4s, v23.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[3] + fmla v1.4s, v29.4s, v17.s[3] + fmla v2.4s, v30.4s, v17.s[3] + fmla v3.4s, v31.4s, v17.s[3] + fmla v4.4s, v28.4s, v20.s[3] + fmla v5.4s, v29.4s, v20.s[3] + fmla v6.4s, v30.4s, v20.s[3] + fmla v7.4s, v31.4s, v20.s[3] + fmla v8.4s, v28.4s, v23.s[3] + fmla v9.4s, v29.4s, v23.s[3] + fmla v10.4s, v30.4s, v23.s[3] + fmla v11.4s, v31.4s, v23.s[3] + sub x25, x25, #8 + b LoopCTail + LoopC4: + cmp x25, #4 + blt LoopCTail + ld1 {v16.4s}, [x24], #16 + ld1 {v19.4s}, [x26], #16 + ld1 {v22.4s}, [x19], #16 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v30.4s, v16.s[0] + fmla v3.4s, v31.4s, v16.s[0] + fmla v4.4s, v28.4s, v19.s[0] + fmla v5.4s, v29.4s, v19.s[0] + fmla v6.4s, v30.4s, v19.s[0] + fmla v7.4s, v31.4s, v19.s[0] + fmla v8.4s, v28.4s, v22.s[0] + fmla v9.4s, v29.4s, v22.s[0] + fmla v10.4s, v30.4s, v22.s[0] + fmla v11.4s, v31.4s, v22.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[1] + fmla v1.4s, v29.4s, v16.s[1] + fmla v2.4s, v30.4s, v16.s[1] + fmla v3.4s, v31.4s, v16.s[1] + fmla v4.4s, v28.4s, v19.s[1] + fmla v5.4s, v29.4s, v19.s[1] + fmla v6.4s, v30.4s, v19.s[1] + fmla v7.4s, v31.4s, v19.s[1] + fmla v8.4s, v28.4s, v22.s[1] + fmla v9.4s, v29.4s, v22.s[1] + fmla v10.4s, v30.4s, v22.s[1] + fmla v11.4s, v31.4s, v22.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v30.4s, v16.s[2] + fmla v3.4s, v31.4s, v16.s[2] + fmla v4.4s, v28.4s, v19.s[2] + fmla v5.4s, v29.4s, v19.s[2] + fmla v6.4s, v30.4s, v19.s[2] + fmla v7.4s, v31.4s, v19.s[2] + fmla v8.4s, v28.4s, v22.s[2] + fmla v9.4s, v29.4s, v22.s[2] + fmla v10.4s, v30.4s, v22.s[2] + fmla v11.4s, v31.4s, v22.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[3] + fmla v1.4s, v29.4s, v16.s[3] + fmla v2.4s, v30.4s, v16.s[3] + fmla v3.4s, v31.4s, v16.s[3] + fmla v4.4s, v28.4s, v19.s[3] + fmla v5.4s, v29.4s, v19.s[3] + fmla v6.4s, v30.4s, v19.s[3] + fmla v7.4s, v31.4s, v19.s[3] + fmla v8.4s, v28.4s, v22.s[3] + fmla v9.4s, v29.4s, v22.s[3] + fmla v10.4s, v30.4s, v22.s[3] + fmla v11.4s, v31.4s, v22.s[3] + sub x25, x25, #4 + LoopCTail: + cbz x25, LoopCEnd + LoopCTailCycle: + ld1 {v16.s}[0], [x24], #4 + ld1 {v19.s}[0], [x26], #4 + ld1 {v22.s}[0], [x19], #4 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v30.4s, v16.s[0] + fmla v3.4s, v31.4s, v16.s[0] + fmla v4.4s, v28.4s, v19.s[0] + fmla v5.4s, v29.4s, v19.s[0] + fmla v6.4s, v30.4s, v19.s[0] + fmla v7.4s, v31.4s, v19.s[0] + fmla v8.4s, v28.4s, v22.s[0] + fmla v9.4s, v29.4s, v22.s[0] + fmla v10.4s, v30.4s, v22.s[0] + fmla v11.4s, v31.4s, v22.s[0] + subs x25, x25, #1 + bgt LoopCTailCycle + LoopCEnd: + add x23, x23, x11 + subs x22, x22, #1 + bgt LoopW + add x1, x1, x12 + add x2, x2, x14 + subs x4, x4, #1 + bgt LoopH + + ands x6, x6, #3 + beq WriteBack + dup v24.2d, xzr // relu + fmax v0.4s, v0.4s, v24.4s + fmax v1.4s, v1.4s, v24.4s + fmax v2.4s, v2.4s, v24.4s + fmax v3.4s, v3.4s, v24.4s + fmax v4.4s, v4.4s, v24.4s + fmax v5.4s, v5.4s, v24.4s + fmax v6.4s, v6.4s, v24.4s + fmax v7.4s, v7.4s, v24.4s + fmax v8.4s, v8.4s, v24.4s + fmax v9.4s, v9.4s, v24.4s + fmax v10.4s, v10.4s, v24.4s + fmax v11.4s, v11.4s, v24.4s + + ands x6, x6, #1 + beq WriteBack + movi v24.4s, #6 // relu6 + scvtf v24.4s, v24.4s + fmin v0.4s, v0.4s, v24.4s + fmin v1.4s, v1.4s, v24.4s + fmin v2.4s, v2.4s, v24.4s + fmin v3.4s, v3.4s, v24.4s + fmin v4.4s, v4.4s, v24.4s + fmin v5.4s, v5.4s, v24.4s + fmin v6.4s, v6.4s, v24.4s + fmin v7.4s, v7.4s, v24.4s + fmin v8.4s, v8.4s, v24.4s + fmin v9.4s, v9.4s, v24.4s + fmin v10.4s, v10.4s, v24.4s + fmin v11.4s, v11.4s, v24.4s + + WriteBack: + add x21, x0, x7, LSL #1 + add x22, x21, x7 + cmp x15, #13 + beq NC4HW4 + st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0] + st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x20] + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x21] + b End + NC4HW4: + st1 {v0.4s}, [x0], #16 + st1 {v4.4s}, [x0], #16 + st1 {v8.4s}, [x0] + st1 {v1.4s}, [x20], #16 + st1 {v5.4s}, [x20], #16 + st1 {v9.4s}, [x20] + st1 {v2.4s}, [x21], #16 + st1 {v6.4s}, [x21], #16 + st1 {v10.4s}, [x21] + st1 {v3.4s}, [x22], #16 + st1 {v7.4s}, [x22], #16 + st1 {v11.4s}, [x22] + End: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvSW3x8Kernel.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvSW3x8Kernel.S new file mode 100644 index 0000000000000000000000000000000000000000..472e50b9053300d86fd2dbc5429f509b2f862170 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvSW3x8Kernel.S @@ -0,0 +1,332 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void SWConv3x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, +// size_t kernel_w, size_t act_flag, size_t out_step, size_t ic_algin, size_t in_kw_step, +// size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode) +// x0: dst, x1: src, x2: weight, x3: bias, x4: kernel_h, x5: kernel_w, x6: act_flag, x7: out_step, +// x10: ic_algin, x11: in_kw_step, x12: in_kh_step, x13: in_sw_step, x14: kw_remainder, x15: write_mode +asm_function SWConv3x8Kernel + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #64 + stp x19, x20, [sp] + stp x21, x22, [sp, #16] + stp x23, x24, [sp, #32] + stp x25, x26, [sp, #48] + + ldr x10, [sp, #64] + ldr x11, [sp, #72] + ldr x12, [sp, #80] + ldr x13, [sp, #88] + ldr x14, [sp, #96] + ldr x15, [sp, #104] + lsl x7, x7, #2 + lsl x11, x11, #2 + lsl x12, x12, #2 + lsl x13, x13, #2 + lsl x14, x14, #2 + add x20, x0, x7 + + cbz x3, InitNoBias + InitWithBias: + ld1 {v0.4s, v1.4s}, [x3] + ld1 {v2.4s, v3.4s}, [x3] + ld1 {v4.4s, v5.4s}, [x3] + b LoopH + InitNoBias: + dup v0.2d, xzr + dup v1.2d, xzr + dup v2.2d, xzr + dup v3.2d, xzr + dup v4.2d, xzr + dup v5.2d, xzr + + LoopH: + mov x22, x5 + mov x23, x1 + LoopW: + mov x25, x10 + prfm pldl1keep, [x23] + mov x24, x23 + add x26, x23, x13 + prfm pldl1keep, [x26] + add x19, x23, x13, lsl #1 + prfm pldl1keep, [x19] + subs x25, x25, #12 + blt LoopC8 + LoopC12: + ld1 {v16.4s, v17.4s, v18.4s}, [x24], #48 + ld1 {v19.4s, v20.4s, v21.4s}, [x26], #48 + ld1 {v22.4s, v23.4s, v24.4s}, [x19], #48 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v28.4s, v19.s[0] + fmla v3.4s, v29.4s, v19.s[0] + fmla v4.4s, v28.4s, v22.s[0] + fmla v5.4s, v29.4s, v22.s[0] + fmla v0.4s, v30.4s, v16.s[1] + fmla v1.4s, v31.4s, v16.s[1] + fmla v2.4s, v30.4s, v19.s[1] + fmla v3.4s, v31.4s, v19.s[1] + fmla v4.4s, v30.4s, v22.s[1] + fmla v5.4s, v31.4s, v22.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v28.4s, v19.s[2] + fmla v3.4s, v29.4s, v19.s[2] + fmla v4.4s, v28.4s, v22.s[2] + fmla v5.4s, v29.4s, v22.s[2] + fmla v0.4s, v30.4s, v16.s[3] + fmla v1.4s, v31.4s, v16.s[3] + fmla v2.4s, v30.4s, v19.s[3] + fmla v3.4s, v31.4s, v19.s[3] + fmla v4.4s, v30.4s, v22.s[3] + fmla v5.4s, v31.4s, v22.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[0] + fmla v1.4s, v29.4s, v17.s[0] + fmla v2.4s, v28.4s, v20.s[0] + fmla v3.4s, v29.4s, v20.s[0] + fmla v4.4s, v28.4s, v23.s[0] + fmla v5.4s, v29.4s, v23.s[0] + fmla v0.4s, v30.4s, v17.s[1] + fmla v1.4s, v31.4s, v17.s[1] + fmla v2.4s, v30.4s, v20.s[1] + fmla v3.4s, v31.4s, v20.s[1] + fmla v4.4s, v30.4s, v23.s[1] + fmla v5.4s, v31.4s, v23.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[2] + fmla v1.4s, v29.4s, v17.s[2] + fmla v2.4s, v28.4s, v20.s[2] + fmla v3.4s, v29.4s, v20.s[2] + fmla v4.4s, v28.4s, v23.s[2] + fmla v5.4s, v29.4s, v23.s[2] + fmla v0.4s, v30.4s, v17.s[3] + fmla v1.4s, v31.4s, v17.s[3] + fmla v2.4s, v30.4s, v20.s[3] + fmla v3.4s, v31.4s, v20.s[3] + fmla v4.4s, v30.4s, v23.s[3] + fmla v5.4s, v31.4s, v23.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[0] + fmla v1.4s, v29.4s, v18.s[0] + fmla v2.4s, v28.4s, v21.s[0] + fmla v3.4s, v29.4s, v21.s[0] + fmla v4.4s, v28.4s, v24.s[0] + fmla v5.4s, v29.4s, v24.s[0] + fmla v0.4s, v30.4s, v18.s[1] + fmla v1.4s, v31.4s, v18.s[1] + fmla v2.4s, v30.4s, v21.s[1] + fmla v3.4s, v31.4s, v21.s[1] + fmla v4.4s, v30.4s, v24.s[1] + fmla v5.4s, v31.4s, v24.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[2] + fmla v1.4s, v29.4s, v18.s[2] + fmla v2.4s, v28.4s, v21.s[2] + fmla v3.4s, v29.4s, v21.s[2] + fmla v4.4s, v28.4s, v24.s[2] + fmla v5.4s, v29.4s, v24.s[2] + fmla v0.4s, v30.4s, v18.s[3] + fmla v1.4s, v31.4s, v18.s[3] + fmla v2.4s, v30.4s, v21.s[3] + fmla v3.4s, v31.4s, v21.s[3] + fmla v4.4s, v30.4s, v24.s[3] + fmla v5.4s, v31.4s, v24.s[3] + subs x25, x25, #12 + bge LoopC12 + LoopC8: + adds x25, x25, #12 + cbz x25, LoopCEnd + cmp x25, #8 + blt LoopC4 + ld1 {v16.4s, v17.4s}, [x24], #32 + ld1 {v19.4s, v20.4s}, [x26], #32 + ld1 {v22.4s, v23.4s}, [x19], #32 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v28.4s, v19.s[0] + fmla v3.4s, v29.4s, v19.s[0] + fmla v4.4s, v28.4s, v22.s[0] + fmla v5.4s, v29.4s, v22.s[0] + fmla v0.4s, v30.4s, v16.s[1] + fmla v1.4s, v31.4s, v16.s[1] + fmla v2.4s, v30.4s, v19.s[1] + fmla v3.4s, v31.4s, v19.s[1] + fmla v4.4s, v30.4s, v22.s[1] + fmla v5.4s, v31.4s, v22.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v28.4s, v19.s[2] + fmla v3.4s, v29.4s, v19.s[2] + fmla v4.4s, v28.4s, v22.s[2] + fmla v5.4s, v29.4s, v22.s[2] + fmla v0.4s, v30.4s, v16.s[3] + fmla v1.4s, v31.4s, v16.s[3] + fmla v2.4s, v30.4s, v19.s[3] + fmla v3.4s, v31.4s, v19.s[3] + fmla v4.4s, v30.4s, v22.s[3] + fmla v5.4s, v31.4s, v22.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[0] + fmla v1.4s, v29.4s, v17.s[0] + fmla v2.4s, v28.4s, v20.s[0] + fmla v3.4s, v29.4s, v20.s[0] + fmla v4.4s, v28.4s, v23.s[0] + fmla v5.4s, v29.4s, v23.s[0] + fmla v0.4s, v30.4s, v17.s[1] + fmla v1.4s, v31.4s, v17.s[1] + fmla v2.4s, v30.4s, v20.s[1] + fmla v3.4s, v31.4s, v20.s[1] + fmla v4.4s, v30.4s, v23.s[1] + fmla v5.4s, v31.4s, v23.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[2] + fmla v1.4s, v29.4s, v17.s[2] + fmla v2.4s, v28.4s, v20.s[2] + fmla v3.4s, v29.4s, v20.s[2] + fmla v4.4s, v28.4s, v23.s[2] + fmla v5.4s, v29.4s, v23.s[2] + fmla v0.4s, v30.4s, v17.s[3] + fmla v1.4s, v31.4s, v17.s[3] + fmla v2.4s, v30.4s, v20.s[3] + fmla v3.4s, v31.4s, v20.s[3] + fmla v4.4s, v30.4s, v23.s[3] + fmla v5.4s, v31.4s, v23.s[3] + sub x25, x25, #8 + b LoopCTail + LoopC4: + cmp x25, #4 + blt LoopCTail + ld1 {v16.4s}, [x24], #16 + ld1 {v19.4s}, [x26], #16 + ld1 {v22.4s}, [x19], #16 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v28.4s, v19.s[0] + fmla v3.4s, v29.4s, v19.s[0] + fmla v4.4s, v28.4s, v22.s[0] + fmla v5.4s, v29.4s, v22.s[0] + fmla v0.4s, v30.4s, v16.s[1] + fmla v1.4s, v31.4s, v16.s[1] + fmla v2.4s, v30.4s, v19.s[1] + fmla v3.4s, v31.4s, v19.s[1] + fmla v4.4s, v30.4s, v22.s[1] + fmla v5.4s, v31.4s, v22.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v28.4s, v19.s[2] + fmla v3.4s, v29.4s, v19.s[2] + fmla v4.4s, v28.4s, v22.s[2] + fmla v5.4s, v29.4s, v22.s[2] + fmla v0.4s, v30.4s, v16.s[3] + fmla v1.4s, v31.4s, v16.s[3] + fmla v2.4s, v30.4s, v19.s[3] + fmla v3.4s, v31.4s, v19.s[3] + fmla v4.4s, v30.4s, v22.s[3] + fmla v5.4s, v31.4s, v22.s[3] + sub x25, x25, #4 + LoopCTail: + cbz x25, LoopCEnd + LoopCTailCycle: + ld1 {v16.s}[0], [x24], #4 + ld1 {v19.s}[0], [x26], #4 + ld1 {v22.s}[0], [x19], #4 + ld1 {v28.4s, v29.4s}, [x2], #32 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v28.4s, v19.s[0] + fmla v3.4s, v29.4s, v19.s[0] + fmla v4.4s, v28.4s, v22.s[0] + fmla v5.4s, v29.4s, v22.s[0] + subs x25, x25, #1 + bgt LoopCTailCycle + LoopCEnd: + add x23, x23, x11 + subs x22, x22, #1 + bgt LoopW + add x1, x1, x12 + add x2, x2, x14 + subs x4, x4, #1 + bgt LoopH + + ands x6, x6, #3 + beq WriteBack + dup v24.2d, xzr // relu + fmax v0.4s, v0.4s, v24.4s + fmax v1.4s, v1.4s, v24.4s + fmax v2.4s, v2.4s, v24.4s + fmax v3.4s, v3.4s, v24.4s + fmax v4.4s, v4.4s, v24.4s + fmax v5.4s, v5.4s, v24.4s + + ands x6, x6, #1 + beq WriteBack + movi v24.4s, #6 // relu6 + scvtf v24.4s, v24.4s + fmin v0.4s, v0.4s, v24.4s + fmin v1.4s, v1.4s, v24.4s + fmin v2.4s, v2.4s, v24.4s + fmin v3.4s, v3.4s, v24.4s + fmin v4.4s, v4.4s, v24.4s + fmin v5.4s, v5.4s, v24.4s + + WriteBack: + add x21, x0, x7, LSL #1 + cmp x15, #13 + beq NC4HW4 + st1 {v0.4s, v1.4s}, [x0] + st1 {v2.4s, v3.4s}, [x20] + st1 {v4.4s, v5.4s}, [x21] + b End + NC4HW4: + st1 {v0.4s}, [x0], #16 + st1 {v2.4s}, [x0], #16 + st1 {v4.4s}, [x0] + st1 {v1.4s}, [x20], #16 + st1 {v3.4s}, [x20], #16 + st1 {v5.4s}, [x20] + End: + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvSW4x16Kernel.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvSW4x16Kernel.S new file mode 100644 index 0000000000000000000000000000000000000000..076724a7117c27c0cbf09eb4933bc5f44dc618b8 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvSW4x16Kernel.S @@ -0,0 +1,662 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void SWConv4x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, +// size_t kernel_w, size_t act_flag, size_t out_step, size_t ic_algin, size_t in_kw_step, +// size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode) +// x0: dst, x1: src, x2: weight, x3: bias, x4: kernel_h, x5: kernel_w, x6: act_flag, x7: out_step, +// x10: ic_algin, x11: in_kw_step, x12: in_kh_step, x13: in_sw_step, x14: kw_remainder, x15: write_mode +asm_function SWConv4x16Kernel + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #208 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + stp x23, x24, [sp, #160] + stp x25, x26, [sp, #176] + stp x27, x28, [sp, #192] + + ldr x10, [sp, #208] + ldr x11, [sp, #216] + ldr x12, [sp, #224] + ldr x13, [sp, #232] + ldr x14, [sp, #240] + ldr x15, [sp, #248] + lsl x7, x7, #2 + lsl x11, x11, #2 + lsl x12, x12, #2 + lsl x13, x13, #2 + lsl x14, x14, #2 + add x20, x0, x7 + + cbz x3, InitNoBias + InitWithBias: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x3] + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x3] + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x3] + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x3] + b LoopH + InitNoBias: + dup v0.2d, xzr + dup v1.2d, xzr + dup v2.2d, xzr + dup v3.2d, xzr + dup v4.2d, xzr + dup v5.2d, xzr + dup v6.2d, xzr + dup v7.2d, xzr + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + dup v12.2d, xzr + dup v13.2d, xzr + dup v14.2d, xzr + dup v15.2d, xzr + + LoopH: + mov x22, x5 + mov x23, x1 + LoopW: + mov x25, x10 + prfm pldl1keep, [x23] + mov x24, x23 + add x26, x23, x13 + prfm pldl1keep, [x26] + add x27, x23, x13, lsl #1 + prfm pldl1keep, [x27] + add x28, x27, x13 + prfm pldl1keep, [x28] + subs x25, x25, #12 + blt LoopC8 + LoopC12: + ld1 {v16.4s, v17.4s, v18.4s}, [x24], #48 + ld1 {v19.4s, v20.4s, v21.4s}, [x26], #48 + ld1 {v22.4s, v23.4s, v24.4s}, [x27], #48 + ld1 {v25.4s, v26.4s, v27.4s}, [x28], #48 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v30.4s, v16.s[0] + fmla v3.4s, v31.4s, v16.s[0] + fmla v4.4s, v28.4s, v19.s[0] + fmla v5.4s, v29.4s, v19.s[0] + fmla v6.4s, v30.4s, v19.s[0] + fmla v7.4s, v31.4s, v19.s[0] + fmla v8.4s, v28.4s, v22.s[0] + fmla v9.4s, v29.4s, v22.s[0] + fmla v10.4s, v30.4s, v22.s[0] + fmla v11.4s, v31.4s, v22.s[0] + fmla v12.4s, v28.4s, v25.s[0] + fmla v13.4s, v29.4s, v25.s[0] + fmla v14.4s, v30.4s, v25.s[0] + fmla v15.4s, v31.4s, v25.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[1] + fmla v1.4s, v29.4s, v16.s[1] + fmla v2.4s, v30.4s, v16.s[1] + fmla v3.4s, v31.4s, v16.s[1] + fmla v4.4s, v28.4s, v19.s[1] + fmla v5.4s, v29.4s, v19.s[1] + fmla v6.4s, v30.4s, v19.s[1] + fmla v7.4s, v31.4s, v19.s[1] + fmla v8.4s, v28.4s, v22.s[1] + fmla v9.4s, v29.4s, v22.s[1] + fmla v10.4s, v30.4s, v22.s[1] + fmla v11.4s, v31.4s, v22.s[1] + fmla v12.4s, v28.4s, v25.s[1] + fmla v13.4s, v29.4s, v25.s[1] + fmla v14.4s, v30.4s, v25.s[1] + fmla v15.4s, v31.4s, v25.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v30.4s, v16.s[2] + fmla v3.4s, v31.4s, v16.s[2] + fmla v4.4s, v28.4s, v19.s[2] + fmla v5.4s, v29.4s, v19.s[2] + fmla v6.4s, v30.4s, v19.s[2] + fmla v7.4s, v31.4s, v19.s[2] + fmla v8.4s, v28.4s, v22.s[2] + fmla v9.4s, v29.4s, v22.s[2] + fmla v10.4s, v30.4s, v22.s[2] + fmla v11.4s, v31.4s, v22.s[2] + fmla v12.4s, v28.4s, v25.s[2] + fmla v13.4s, v29.4s, v25.s[2] + fmla v14.4s, v30.4s, v25.s[2] + fmla v15.4s, v31.4s, v25.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[3] + fmla v1.4s, v29.4s, v16.s[3] + fmla v2.4s, v30.4s, v16.s[3] + fmla v3.4s, v31.4s, v16.s[3] + fmla v4.4s, v28.4s, v19.s[3] + fmla v5.4s, v29.4s, v19.s[3] + fmla v6.4s, v30.4s, v19.s[3] + fmla v7.4s, v31.4s, v19.s[3] + fmla v8.4s, v28.4s, v22.s[3] + fmla v9.4s, v29.4s, v22.s[3] + fmla v10.4s, v30.4s, v22.s[3] + fmla v11.4s, v31.4s, v22.s[3] + fmla v12.4s, v28.4s, v25.s[3] + fmla v13.4s, v29.4s, v25.s[3] + fmla v14.4s, v30.4s, v25.s[3] + fmla v15.4s, v31.4s, v25.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[0] + fmla v1.4s, v29.4s, v17.s[0] + fmla v2.4s, v30.4s, v17.s[0] + fmla v3.4s, v31.4s, v17.s[0] + fmla v4.4s, v28.4s, v20.s[0] + fmla v5.4s, v29.4s, v20.s[0] + fmla v6.4s, v30.4s, v20.s[0] + fmla v7.4s, v31.4s, v20.s[0] + fmla v8.4s, v28.4s, v23.s[0] + fmla v9.4s, v29.4s, v23.s[0] + fmla v10.4s, v30.4s, v23.s[0] + fmla v11.4s, v31.4s, v23.s[0] + fmla v12.4s, v28.4s, v26.s[0] + fmla v13.4s, v29.4s, v26.s[0] + fmla v14.4s, v30.4s, v26.s[0] + fmla v15.4s, v31.4s, v26.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[1] + fmla v1.4s, v29.4s, v17.s[1] + fmla v2.4s, v30.4s, v17.s[1] + fmla v3.4s, v31.4s, v17.s[1] + fmla v4.4s, v28.4s, v20.s[1] + fmla v5.4s, v29.4s, v20.s[1] + fmla v6.4s, v30.4s, v20.s[1] + fmla v7.4s, v31.4s, v20.s[1] + fmla v8.4s, v28.4s, v23.s[1] + fmla v9.4s, v29.4s, v23.s[1] + fmla v10.4s, v30.4s, v23.s[1] + fmla v11.4s, v31.4s, v23.s[1] + fmla v12.4s, v28.4s, v26.s[1] + fmla v13.4s, v29.4s, v26.s[1] + fmla v14.4s, v30.4s, v26.s[1] + fmla v15.4s, v31.4s, v26.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[2] + fmla v1.4s, v29.4s, v17.s[2] + fmla v2.4s, v30.4s, v17.s[2] + fmla v3.4s, v31.4s, v17.s[2] + fmla v4.4s, v28.4s, v20.s[2] + fmla v5.4s, v29.4s, v20.s[2] + fmla v6.4s, v30.4s, v20.s[2] + fmla v7.4s, v31.4s, v20.s[2] + fmla v8.4s, v28.4s, v23.s[2] + fmla v9.4s, v29.4s, v23.s[2] + fmla v10.4s, v30.4s, v23.s[2] + fmla v11.4s, v31.4s, v23.s[2] + fmla v12.4s, v28.4s, v26.s[2] + fmla v13.4s, v29.4s, v26.s[2] + fmla v14.4s, v30.4s, v26.s[2] + fmla v15.4s, v31.4s, v26.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[3] + fmla v1.4s, v29.4s, v17.s[3] + fmla v2.4s, v30.4s, v17.s[3] + fmla v3.4s, v31.4s, v17.s[3] + fmla v4.4s, v28.4s, v20.s[3] + fmla v5.4s, v29.4s, v20.s[3] + fmla v6.4s, v30.4s, v20.s[3] + fmla v7.4s, v31.4s, v20.s[3] + fmla v8.4s, v28.4s, v23.s[3] + fmla v9.4s, v29.4s, v23.s[3] + fmla v10.4s, v30.4s, v23.s[3] + fmla v11.4s, v31.4s, v23.s[3] + fmla v12.4s, v28.4s, v26.s[3] + fmla v13.4s, v29.4s, v26.s[3] + fmla v14.4s, v30.4s, v26.s[3] + fmla v15.4s, v31.4s, v26.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[0] + fmla v1.4s, v29.4s, v18.s[0] + fmla v2.4s, v30.4s, v18.s[0] + fmla v3.4s, v31.4s, v18.s[0] + fmla v4.4s, v28.4s, v21.s[0] + fmla v5.4s, v29.4s, v21.s[0] + fmla v6.4s, v30.4s, v21.s[0] + fmla v7.4s, v31.4s, v21.s[0] + fmla v8.4s, v28.4s, v24.s[0] + fmla v9.4s, v29.4s, v24.s[0] + fmla v10.4s, v30.4s, v24.s[0] + fmla v11.4s, v31.4s, v24.s[0] + fmla v12.4s, v28.4s, v27.s[0] + fmla v13.4s, v29.4s, v27.s[0] + fmla v14.4s, v30.4s, v27.s[0] + fmla v15.4s, v31.4s, v27.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[1] + fmla v1.4s, v29.4s, v18.s[1] + fmla v2.4s, v30.4s, v18.s[1] + fmla v3.4s, v31.4s, v18.s[1] + fmla v4.4s, v28.4s, v21.s[1] + fmla v5.4s, v29.4s, v21.s[1] + fmla v6.4s, v30.4s, v21.s[1] + fmla v7.4s, v31.4s, v21.s[1] + fmla v8.4s, v28.4s, v24.s[1] + fmla v9.4s, v29.4s, v24.s[1] + fmla v10.4s, v30.4s, v24.s[1] + fmla v11.4s, v31.4s, v24.s[1] + fmla v12.4s, v28.4s, v27.s[1] + fmla v13.4s, v29.4s, v27.s[1] + fmla v14.4s, v30.4s, v27.s[1] + fmla v15.4s, v31.4s, v27.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[2] + fmla v1.4s, v29.4s, v18.s[2] + fmla v2.4s, v30.4s, v18.s[2] + fmla v3.4s, v31.4s, v18.s[2] + fmla v4.4s, v28.4s, v21.s[2] + fmla v5.4s, v29.4s, v21.s[2] + fmla v6.4s, v30.4s, v21.s[2] + fmla v7.4s, v31.4s, v21.s[2] + fmla v8.4s, v28.4s, v24.s[2] + fmla v9.4s, v29.4s, v24.s[2] + fmla v10.4s, v30.4s, v24.s[2] + fmla v11.4s, v31.4s, v24.s[2] + fmla v12.4s, v28.4s, v27.s[2] + fmla v13.4s, v29.4s, v27.s[2] + fmla v14.4s, v30.4s, v27.s[2] + fmla v15.4s, v31.4s, v27.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[3] + fmla v1.4s, v29.4s, v18.s[3] + fmla v2.4s, v30.4s, v18.s[3] + fmla v3.4s, v31.4s, v18.s[3] + fmla v4.4s, v28.4s, v21.s[3] + fmla v5.4s, v29.4s, v21.s[3] + fmla v6.4s, v30.4s, v21.s[3] + fmla v7.4s, v31.4s, v21.s[3] + fmla v8.4s, v28.4s, v24.s[3] + fmla v9.4s, v29.4s, v24.s[3] + fmla v10.4s, v30.4s, v24.s[3] + fmla v11.4s, v31.4s, v24.s[3] + fmla v12.4s, v28.4s, v27.s[3] + fmla v13.4s, v29.4s, v27.s[3] + fmla v14.4s, v30.4s, v27.s[3] + fmla v15.4s, v31.4s, v27.s[3] + subs x25, x25, #12 + bge LoopC12 + LoopC8: + adds x25, x25, #12 + cbz x25, LoopCEnd + cmp x25, #8 + blt LoopC4 + ld1 {v16.4s, v17.4s}, [x24], #32 + ld1 {v19.4s, v20.4s}, [x26], #32 + ld1 {v22.4s, v23.4s}, [x27], #32 + ld1 {v25.4s, v26.4s}, [x28], #32 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v30.4s, v16.s[0] + fmla v3.4s, v31.4s, v16.s[0] + fmla v4.4s, v28.4s, v19.s[0] + fmla v5.4s, v29.4s, v19.s[0] + fmla v6.4s, v30.4s, v19.s[0] + fmla v7.4s, v31.4s, v19.s[0] + fmla v8.4s, v28.4s, v22.s[0] + fmla v9.4s, v29.4s, v22.s[0] + fmla v10.4s, v30.4s, v22.s[0] + fmla v11.4s, v31.4s, v22.s[0] + fmla v12.4s, v28.4s, v25.s[0] + fmla v13.4s, v29.4s, v25.s[0] + fmla v14.4s, v30.4s, v25.s[0] + fmla v15.4s, v31.4s, v25.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[1] + fmla v1.4s, v29.4s, v16.s[1] + fmla v2.4s, v30.4s, v16.s[1] + fmla v3.4s, v31.4s, v16.s[1] + fmla v4.4s, v28.4s, v19.s[1] + fmla v5.4s, v29.4s, v19.s[1] + fmla v6.4s, v30.4s, v19.s[1] + fmla v7.4s, v31.4s, v19.s[1] + fmla v8.4s, v28.4s, v22.s[1] + fmla v9.4s, v29.4s, v22.s[1] + fmla v10.4s, v30.4s, v22.s[1] + fmla v11.4s, v31.4s, v22.s[1] + fmla v12.4s, v28.4s, v25.s[1] + fmla v13.4s, v29.4s, v25.s[1] + fmla v14.4s, v30.4s, v25.s[1] + fmla v15.4s, v31.4s, v25.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v30.4s, v16.s[2] + fmla v3.4s, v31.4s, v16.s[2] + fmla v4.4s, v28.4s, v19.s[2] + fmla v5.4s, v29.4s, v19.s[2] + fmla v6.4s, v30.4s, v19.s[2] + fmla v7.4s, v31.4s, v19.s[2] + fmla v8.4s, v28.4s, v22.s[2] + fmla v9.4s, v29.4s, v22.s[2] + fmla v10.4s, v30.4s, v22.s[2] + fmla v11.4s, v31.4s, v22.s[2] + fmla v12.4s, v28.4s, v25.s[2] + fmla v13.4s, v29.4s, v25.s[2] + fmla v14.4s, v30.4s, v25.s[2] + fmla v15.4s, v31.4s, v25.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[3] + fmla v1.4s, v29.4s, v16.s[3] + fmla v2.4s, v30.4s, v16.s[3] + fmla v3.4s, v31.4s, v16.s[3] + fmla v4.4s, v28.4s, v19.s[3] + fmla v5.4s, v29.4s, v19.s[3] + fmla v6.4s, v30.4s, v19.s[3] + fmla v7.4s, v31.4s, v19.s[3] + fmla v8.4s, v28.4s, v22.s[3] + fmla v9.4s, v29.4s, v22.s[3] + fmla v10.4s, v30.4s, v22.s[3] + fmla v11.4s, v31.4s, v22.s[3] + fmla v12.4s, v28.4s, v25.s[3] + fmla v13.4s, v29.4s, v25.s[3] + fmla v14.4s, v30.4s, v25.s[3] + fmla v15.4s, v31.4s, v25.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[0] + fmla v1.4s, v29.4s, v17.s[0] + fmla v2.4s, v30.4s, v17.s[0] + fmla v3.4s, v31.4s, v17.s[0] + fmla v4.4s, v28.4s, v20.s[0] + fmla v5.4s, v29.4s, v20.s[0] + fmla v6.4s, v30.4s, v20.s[0] + fmla v7.4s, v31.4s, v20.s[0] + fmla v8.4s, v28.4s, v23.s[0] + fmla v9.4s, v29.4s, v23.s[0] + fmla v10.4s, v30.4s, v23.s[0] + fmla v11.4s, v31.4s, v23.s[0] + fmla v12.4s, v28.4s, v26.s[0] + fmla v13.4s, v29.4s, v26.s[0] + fmla v14.4s, v30.4s, v26.s[0] + fmla v15.4s, v31.4s, v26.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[1] + fmla v1.4s, v29.4s, v17.s[1] + fmla v2.4s, v30.4s, v17.s[1] + fmla v3.4s, v31.4s, v17.s[1] + fmla v4.4s, v28.4s, v20.s[1] + fmla v5.4s, v29.4s, v20.s[1] + fmla v6.4s, v30.4s, v20.s[1] + fmla v7.4s, v31.4s, v20.s[1] + fmla v8.4s, v28.4s, v23.s[1] + fmla v9.4s, v29.4s, v23.s[1] + fmla v10.4s, v30.4s, v23.s[1] + fmla v11.4s, v31.4s, v23.s[1] + fmla v12.4s, v28.4s, v26.s[1] + fmla v13.4s, v29.4s, v26.s[1] + fmla v14.4s, v30.4s, v26.s[1] + fmla v15.4s, v31.4s, v26.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[2] + fmla v1.4s, v29.4s, v17.s[2] + fmla v2.4s, v30.4s, v17.s[2] + fmla v3.4s, v31.4s, v17.s[2] + fmla v4.4s, v28.4s, v20.s[2] + fmla v5.4s, v29.4s, v20.s[2] + fmla v6.4s, v30.4s, v20.s[2] + fmla v7.4s, v31.4s, v20.s[2] + fmla v8.4s, v28.4s, v23.s[2] + fmla v9.4s, v29.4s, v23.s[2] + fmla v10.4s, v30.4s, v23.s[2] + fmla v11.4s, v31.4s, v23.s[2] + fmla v12.4s, v28.4s, v26.s[2] + fmla v13.4s, v29.4s, v26.s[2] + fmla v14.4s, v30.4s, v26.s[2] + fmla v15.4s, v31.4s, v26.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[3] + fmla v1.4s, v29.4s, v17.s[3] + fmla v2.4s, v30.4s, v17.s[3] + fmla v3.4s, v31.4s, v17.s[3] + fmla v4.4s, v28.4s, v20.s[3] + fmla v5.4s, v29.4s, v20.s[3] + fmla v6.4s, v30.4s, v20.s[3] + fmla v7.4s, v31.4s, v20.s[3] + fmla v8.4s, v28.4s, v23.s[3] + fmla v9.4s, v29.4s, v23.s[3] + fmla v10.4s, v30.4s, v23.s[3] + fmla v11.4s, v31.4s, v23.s[3] + fmla v12.4s, v28.4s, v26.s[3] + fmla v13.4s, v29.4s, v26.s[3] + fmla v14.4s, v30.4s, v26.s[3] + fmla v15.4s, v31.4s, v26.s[3] + sub x25, x25, #8 + b LoopCTail + LoopC4: + cmp x25, #4 + blt LoopCTail + ld1 {v16.4s}, [x24], #16 + ld1 {v19.4s}, [x26], #16 + ld1 {v22.4s}, [x27], #16 + ld1 {v25.4s}, [x28], #16 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v30.4s, v16.s[0] + fmla v3.4s, v31.4s, v16.s[0] + fmla v4.4s, v28.4s, v19.s[0] + fmla v5.4s, v29.4s, v19.s[0] + fmla v6.4s, v30.4s, v19.s[0] + fmla v7.4s, v31.4s, v19.s[0] + fmla v8.4s, v28.4s, v22.s[0] + fmla v9.4s, v29.4s, v22.s[0] + fmla v10.4s, v30.4s, v22.s[0] + fmla v11.4s, v31.4s, v22.s[0] + fmla v12.4s, v28.4s, v25.s[0] + fmla v13.4s, v29.4s, v25.s[0] + fmla v14.4s, v30.4s, v25.s[0] + fmla v15.4s, v31.4s, v25.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[1] + fmla v1.4s, v29.4s, v16.s[1] + fmla v2.4s, v30.4s, v16.s[1] + fmla v3.4s, v31.4s, v16.s[1] + fmla v4.4s, v28.4s, v19.s[1] + fmla v5.4s, v29.4s, v19.s[1] + fmla v6.4s, v30.4s, v19.s[1] + fmla v7.4s, v31.4s, v19.s[1] + fmla v8.4s, v28.4s, v22.s[1] + fmla v9.4s, v29.4s, v22.s[1] + fmla v10.4s, v30.4s, v22.s[1] + fmla v11.4s, v31.4s, v22.s[1] + fmla v12.4s, v28.4s, v25.s[1] + fmla v13.4s, v29.4s, v25.s[1] + fmla v14.4s, v30.4s, v25.s[1] + fmla v15.4s, v31.4s, v25.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v30.4s, v16.s[2] + fmla v3.4s, v31.4s, v16.s[2] + fmla v4.4s, v28.4s, v19.s[2] + fmla v5.4s, v29.4s, v19.s[2] + fmla v6.4s, v30.4s, v19.s[2] + fmla v7.4s, v31.4s, v19.s[2] + fmla v8.4s, v28.4s, v22.s[2] + fmla v9.4s, v29.4s, v22.s[2] + fmla v10.4s, v30.4s, v22.s[2] + fmla v11.4s, v31.4s, v22.s[2] + fmla v12.4s, v28.4s, v25.s[2] + fmla v13.4s, v29.4s, v25.s[2] + fmla v14.4s, v30.4s, v25.s[2] + fmla v15.4s, v31.4s, v25.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[3] + fmla v1.4s, v29.4s, v16.s[3] + fmla v2.4s, v30.4s, v16.s[3] + fmla v3.4s, v31.4s, v16.s[3] + fmla v4.4s, v28.4s, v19.s[3] + fmla v5.4s, v29.4s, v19.s[3] + fmla v6.4s, v30.4s, v19.s[3] + fmla v7.4s, v31.4s, v19.s[3] + fmla v8.4s, v28.4s, v22.s[3] + fmla v9.4s, v29.4s, v22.s[3] + fmla v10.4s, v30.4s, v22.s[3] + fmla v11.4s, v31.4s, v22.s[3] + fmla v12.4s, v28.4s, v25.s[3] + fmla v13.4s, v29.4s, v25.s[3] + fmla v14.4s, v30.4s, v25.s[3] + fmla v15.4s, v31.4s, v25.s[3] + sub x25, x25, #4 + LoopCTail: + cbz x25, LoopCEnd + LoopCTailCycle: + ld1 {v16.s}[0], [x24], #4 + ld1 {v19.s}[0], [x26], #4 + ld1 {v22.s}[0], [x27], #4 + ld1 {v25.s}[0], [x28], #4 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v30.4s, v16.s[0] + fmla v3.4s, v31.4s, v16.s[0] + fmla v4.4s, v28.4s, v19.s[0] + fmla v5.4s, v29.4s, v19.s[0] + fmla v6.4s, v30.4s, v19.s[0] + fmla v7.4s, v31.4s, v19.s[0] + fmla v8.4s, v28.4s, v22.s[0] + fmla v9.4s, v29.4s, v22.s[0] + fmla v10.4s, v30.4s, v22.s[0] + fmla v11.4s, v31.4s, v22.s[0] + fmla v12.4s, v28.4s, v25.s[0] + fmla v13.4s, v29.4s, v25.s[0] + fmla v14.4s, v30.4s, v25.s[0] + fmla v15.4s, v31.4s, v25.s[0] + subs x25, x25, #1 + bgt LoopCTailCycle + LoopCEnd: + add x23, x23, x11 + subs x22, x22, #1 + bgt LoopW + add x1, x1, x12 + add x2, x2, x14 + subs x4, x4, #1 + bgt LoopH + + ands x6, x6, #3 + beq WriteBack + dup v24.2d, xzr // relu + fmax v0.4s, v0.4s, v24.4s + fmax v1.4s, v1.4s, v24.4s + fmax v2.4s, v2.4s, v24.4s + fmax v3.4s, v3.4s, v24.4s + fmax v4.4s, v4.4s, v24.4s + fmax v5.4s, v5.4s, v24.4s + fmax v6.4s, v6.4s, v24.4s + fmax v7.4s, v7.4s, v24.4s + fmax v8.4s, v8.4s, v24.4s + fmax v9.4s, v9.4s, v24.4s + fmax v10.4s, v10.4s, v24.4s + fmax v11.4s, v11.4s, v24.4s + fmax v12.4s, v12.4s, v24.4s + fmax v13.4s, v13.4s, v24.4s + fmax v14.4s, v14.4s, v24.4s + fmax v15.4s, v15.4s, v24.4s + + ands x6, x6, #1 + beq WriteBack + movi v24.4s, #6 // relu6 + scvtf v24.4s, v24.4s + fmin v0.4s, v0.4s, v24.4s + fmin v1.4s, v1.4s, v24.4s + fmin v2.4s, v2.4s, v24.4s + fmin v3.4s, v3.4s, v24.4s + fmin v4.4s, v4.4s, v24.4s + fmin v5.4s, v5.4s, v24.4s + fmin v6.4s, v6.4s, v24.4s + fmin v7.4s, v7.4s, v24.4s + fmin v8.4s, v8.4s, v24.4s + fmin v9.4s, v9.4s, v24.4s + fmin v10.4s, v10.4s, v24.4s + fmin v11.4s, v11.4s, v24.4s + fmin v12.4s, v12.4s, v24.4s + fmin v13.4s, v13.4s, v24.4s + fmin v14.4s, v14.4s, v24.4s + fmin v15.4s, v15.4s, v24.4s + + WriteBack: + add x21, x0, x7, LSL #1 + add x22, x21, x7 + cmp x15, #13 + beq NC4HW4 + st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0] + st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x20] + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x21] + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x22] + b End + NC4HW4: + st1 {v0.4s}, [x0], #16 + st1 {v4.4s}, [x0], #16 + st1 {v8.4s}, [x0], #16 + st1 {v12.4s}, [x0] + st1 {v1.4s}, [x20], #16 + st1 {v5.4s}, [x20], #16 + st1 {v9.4s}, [x20], #16 + st1 {v13.4s}, [x20] + st1 {v2.4s}, [x21], #16 + st1 {v6.4s}, [x21], #16 + st1 {v10.4s}, [x21], #16 + st1 {v14.4s}, [x21] + st1 {v3.4s}, [x22], #16 + st1 {v7.4s}, [x22], #16 + st1 {v11.4s}, [x22], #16 + st1 {v15.4s}, [x22] + End: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ldp x27, x28, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvSW4x8Kernel.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvSW4x8Kernel.S new file mode 100644 index 0000000000000000000000000000000000000000..6b24de97b2efacbf85896283b871ea101bd9106b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvSW4x8Kernel.S @@ -0,0 +1,406 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void SWConv4x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, +// size_t kernel_w, size_t act_flag, size_t out_step, size_t ic_algin, size_t in_kw_step, +// size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode) +// x0: dst, x1: src, x2: weight, x3: bias, x4: kernel_h, x5: kernel_w, x6: act_flag, x7: out_step, +// x10: ic_algin, x11: in_kw_step, x12: in_kh_step, x13: in_sw_step, x14: kw_remainder, x15: write_mode +asm_function SWConv4x8Kernel + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #208 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + stp x23, x24, [sp, #160] + stp x25, x26, [sp, #176] + stp x27, x28, [sp, #192] + + ldr x10, [sp, #208] + ldr x11, [sp, #216] + ldr x12, [sp, #224] + ldr x13, [sp, #232] + ldr x14, [sp, #240] + ldr x15, [sp, #248] + lsl x7, x7, #2 + lsl x11, x11, #2 + lsl x12, x12, #2 + lsl x13, x13, #2 + lsl x14, x14, #2 + add x20, x0, x7 + + cbz x3, InitNoBias + InitWithBias: + ld1 {v0.4s, v1.4s}, [x3] + ld1 {v2.4s, v3.4s}, [x3] + ld1 {v4.4s, v5.4s}, [x3] + ld1 {v6.4s, v7.4s}, [x3] + b LoopH + InitNoBias: + dup v0.2d, xzr + dup v1.2d, xzr + dup v2.2d, xzr + dup v3.2d, xzr + dup v4.2d, xzr + dup v5.2d, xzr + dup v6.2d, xzr + dup v7.2d, xzr + + LoopH: + mov x22, x5 + mov x23, x1 + LoopW: + mov x25, x10 + prfm pldl1keep, [x23] + mov x24, x23 + add x26, x23, x13 + prfm pldl1keep, [x26] + add x27, x23, x13, lsl #1 + prfm pldl1keep, [x27] + add x28, x27, x13 + prfm pldl1keep, [x28] + subs x25, x25, #12 + blt LoopC8 + LoopC12: + ld1 {v16.4s, v17.4s, v18.4s}, [x24], #48 + ld1 {v19.4s, v20.4s, v21.4s}, [x26], #48 + ld1 {v22.4s, v23.4s, v24.4s}, [x27], #48 + ld1 {v25.4s, v26.4s, v27.4s}, [x28], #48 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v28.4s, v19.s[0] + fmla v3.4s, v29.4s, v19.s[0] + fmla v4.4s, v28.4s, v22.s[0] + fmla v5.4s, v29.4s, v22.s[0] + fmla v6.4s, v28.4s, v25.s[0] + fmla v7.4s, v29.4s, v25.s[0] + fmla v0.4s, v30.4s, v16.s[1] + fmla v1.4s, v31.4s, v16.s[1] + fmla v2.4s, v30.4s, v19.s[1] + fmla v3.4s, v31.4s, v19.s[1] + fmla v4.4s, v30.4s, v22.s[1] + fmla v5.4s, v31.4s, v22.s[1] + fmla v6.4s, v30.4s, v25.s[1] + fmla v7.4s, v31.4s, v25.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v28.4s, v19.s[2] + fmla v3.4s, v29.4s, v19.s[2] + fmla v4.4s, v28.4s, v22.s[2] + fmla v5.4s, v29.4s, v22.s[2] + fmla v6.4s, v28.4s, v25.s[2] + fmla v7.4s, v29.4s, v25.s[2] + fmla v0.4s, v30.4s, v16.s[3] + fmla v1.4s, v31.4s, v16.s[3] + fmla v2.4s, v30.4s, v19.s[3] + fmla v3.4s, v31.4s, v19.s[3] + fmla v4.4s, v30.4s, v22.s[3] + fmla v5.4s, v31.4s, v22.s[3] + fmla v6.4s, v30.4s, v25.s[3] + fmla v7.4s, v31.4s, v25.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[0] + fmla v1.4s, v29.4s, v17.s[0] + fmla v2.4s, v28.4s, v20.s[0] + fmla v3.4s, v29.4s, v20.s[0] + fmla v4.4s, v28.4s, v23.s[0] + fmla v5.4s, v29.4s, v23.s[0] + fmla v6.4s, v28.4s, v26.s[0] + fmla v7.4s, v29.4s, v26.s[0] + fmla v0.4s, v30.4s, v17.s[1] + fmla v1.4s, v31.4s, v17.s[1] + fmla v2.4s, v30.4s, v20.s[1] + fmla v3.4s, v31.4s, v20.s[1] + fmla v4.4s, v30.4s, v23.s[1] + fmla v5.4s, v31.4s, v23.s[1] + fmla v6.4s, v30.4s, v26.s[1] + fmla v7.4s, v31.4s, v26.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[2] + fmla v1.4s, v29.4s, v17.s[2] + fmla v2.4s, v28.4s, v20.s[2] + fmla v3.4s, v29.4s, v20.s[2] + fmla v4.4s, v28.4s, v23.s[2] + fmla v5.4s, v29.4s, v23.s[2] + fmla v6.4s, v28.4s, v26.s[2] + fmla v7.4s, v29.4s, v26.s[2] + fmla v0.4s, v30.4s, v17.s[3] + fmla v1.4s, v31.4s, v17.s[3] + fmla v2.4s, v30.4s, v20.s[3] + fmla v3.4s, v31.4s, v20.s[3] + fmla v4.4s, v30.4s, v23.s[3] + fmla v5.4s, v31.4s, v23.s[3] + fmla v6.4s, v30.4s, v26.s[3] + fmla v7.4s, v31.4s, v26.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[0] + fmla v1.4s, v29.4s, v18.s[0] + fmla v2.4s, v28.4s, v21.s[0] + fmla v3.4s, v29.4s, v21.s[0] + fmla v4.4s, v28.4s, v24.s[0] + fmla v5.4s, v29.4s, v24.s[0] + fmla v6.4s, v28.4s, v27.s[0] + fmla v7.4s, v29.4s, v27.s[0] + fmla v0.4s, v30.4s, v18.s[1] + fmla v1.4s, v31.4s, v18.s[1] + fmla v2.4s, v30.4s, v21.s[1] + fmla v3.4s, v31.4s, v21.s[1] + fmla v4.4s, v30.4s, v24.s[1] + fmla v5.4s, v31.4s, v24.s[1] + fmla v6.4s, v30.4s, v27.s[1] + fmla v7.4s, v31.4s, v27.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v18.s[2] + fmla v1.4s, v29.4s, v18.s[2] + fmla v2.4s, v28.4s, v21.s[2] + fmla v3.4s, v29.4s, v21.s[2] + fmla v4.4s, v28.4s, v24.s[2] + fmla v5.4s, v29.4s, v24.s[2] + fmla v6.4s, v28.4s, v27.s[2] + fmla v7.4s, v29.4s, v27.s[2] + fmla v0.4s, v30.4s, v18.s[3] + fmla v1.4s, v31.4s, v18.s[3] + fmla v2.4s, v30.4s, v21.s[3] + fmla v3.4s, v31.4s, v21.s[3] + fmla v4.4s, v30.4s, v24.s[3] + fmla v5.4s, v31.4s, v24.s[3] + fmla v6.4s, v30.4s, v27.s[3] + fmla v7.4s, v31.4s, v27.s[3] + subs x25, x25, #12 + bge LoopC12 + LoopC8: + adds x25, x25, #12 + cbz x25, LoopCEnd + cmp x25, #8 + blt LoopC4 + ld1 {v16.4s, v17.4s}, [x24], #32 + ld1 {v19.4s, v20.4s}, [x26], #32 + ld1 {v22.4s, v23.4s}, [x27], #32 + ld1 {v25.4s, v26.4s}, [x28], #32 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v28.4s, v19.s[0] + fmla v3.4s, v29.4s, v19.s[0] + fmla v4.4s, v28.4s, v22.s[0] + fmla v5.4s, v29.4s, v22.s[0] + fmla v6.4s, v28.4s, v25.s[0] + fmla v7.4s, v29.4s, v25.s[0] + fmla v0.4s, v30.4s, v16.s[1] + fmla v1.4s, v31.4s, v16.s[1] + fmla v2.4s, v30.4s, v19.s[1] + fmla v3.4s, v31.4s, v19.s[1] + fmla v4.4s, v30.4s, v22.s[1] + fmla v5.4s, v31.4s, v22.s[1] + fmla v6.4s, v30.4s, v25.s[1] + fmla v7.4s, v31.4s, v25.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v28.4s, v19.s[2] + fmla v3.4s, v29.4s, v19.s[2] + fmla v4.4s, v28.4s, v22.s[2] + fmla v5.4s, v29.4s, v22.s[2] + fmla v6.4s, v28.4s, v25.s[2] + fmla v7.4s, v29.4s, v25.s[2] + fmla v0.4s, v30.4s, v16.s[3] + fmla v1.4s, v31.4s, v16.s[3] + fmla v2.4s, v30.4s, v19.s[3] + fmla v3.4s, v31.4s, v19.s[3] + fmla v4.4s, v30.4s, v22.s[3] + fmla v5.4s, v31.4s, v22.s[3] + fmla v6.4s, v30.4s, v25.s[3] + fmla v7.4s, v31.4s, v25.s[3] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[0] + fmla v1.4s, v29.4s, v17.s[0] + fmla v2.4s, v28.4s, v20.s[0] + fmla v3.4s, v29.4s, v20.s[0] + fmla v4.4s, v28.4s, v23.s[0] + fmla v5.4s, v29.4s, v23.s[0] + fmla v6.4s, v28.4s, v26.s[0] + fmla v7.4s, v29.4s, v26.s[0] + fmla v0.4s, v30.4s, v17.s[1] + fmla v1.4s, v31.4s, v17.s[1] + fmla v2.4s, v30.4s, v20.s[1] + fmla v3.4s, v31.4s, v20.s[1] + fmla v4.4s, v30.4s, v23.s[1] + fmla v5.4s, v31.4s, v23.s[1] + fmla v6.4s, v30.4s, v26.s[1] + fmla v7.4s, v31.4s, v26.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v17.s[2] + fmla v1.4s, v29.4s, v17.s[2] + fmla v2.4s, v28.4s, v20.s[2] + fmla v3.4s, v29.4s, v20.s[2] + fmla v4.4s, v28.4s, v23.s[2] + fmla v5.4s, v29.4s, v23.s[2] + fmla v6.4s, v28.4s, v26.s[2] + fmla v7.4s, v29.4s, v26.s[2] + fmla v0.4s, v30.4s, v17.s[3] + fmla v1.4s, v31.4s, v17.s[3] + fmla v2.4s, v30.4s, v20.s[3] + fmla v3.4s, v31.4s, v20.s[3] + fmla v4.4s, v30.4s, v23.s[3] + fmla v5.4s, v31.4s, v23.s[3] + fmla v6.4s, v30.4s, v26.s[3] + fmla v7.4s, v31.4s, v26.s[3] + sub x25, x25, #8 + b LoopCTail + LoopC4: + cmp x25, #4 + blt LoopCTail + ld1 {v16.4s}, [x24], #16 + ld1 {v19.4s}, [x26], #16 + ld1 {v22.4s}, [x27], #16 + ld1 {v25.4s}, [x28], #16 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v28.4s, v19.s[0] + fmla v3.4s, v29.4s, v19.s[0] + fmla v4.4s, v28.4s, v22.s[0] + fmla v5.4s, v29.4s, v22.s[0] + fmla v6.4s, v28.4s, v25.s[0] + fmla v7.4s, v29.4s, v25.s[0] + fmla v0.4s, v30.4s, v16.s[1] + fmla v1.4s, v31.4s, v16.s[1] + fmla v2.4s, v30.4s, v19.s[1] + fmla v3.4s, v31.4s, v19.s[1] + fmla v4.4s, v30.4s, v22.s[1] + fmla v5.4s, v31.4s, v22.s[1] + fmla v6.4s, v30.4s, v25.s[1] + fmla v7.4s, v31.4s, v25.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v16.s[2] + fmla v1.4s, v29.4s, v16.s[2] + fmla v2.4s, v28.4s, v19.s[2] + fmla v3.4s, v29.4s, v19.s[2] + fmla v4.4s, v28.4s, v22.s[2] + fmla v5.4s, v29.4s, v22.s[2] + fmla v6.4s, v28.4s, v25.s[2] + fmla v7.4s, v29.4s, v25.s[2] + fmla v0.4s, v30.4s, v16.s[3] + fmla v1.4s, v31.4s, v16.s[3] + fmla v2.4s, v30.4s, v19.s[3] + fmla v3.4s, v31.4s, v19.s[3] + fmla v4.4s, v30.4s, v22.s[3] + fmla v5.4s, v31.4s, v22.s[3] + fmla v6.4s, v30.4s, v25.s[3] + fmla v7.4s, v31.4s, v25.s[3] + sub x25, x25, #4 + LoopCTail: + cbz x25, LoopCEnd + LoopCTailCycle: + ld1 {v16.s}[0], [x24], #4 + ld1 {v19.s}[0], [x26], #4 + ld1 {v22.s}[0], [x27], #4 + ld1 {v25.s}[0], [x28], #4 + ld1 {v28.4s, v29.4s}, [x2], #32 + fmla v0.4s, v28.4s, v16.s[0] + fmla v1.4s, v29.4s, v16.s[0] + fmla v2.4s, v28.4s, v19.s[0] + fmla v3.4s, v29.4s, v19.s[0] + fmla v4.4s, v28.4s, v22.s[0] + fmla v5.4s, v29.4s, v22.s[0] + fmla v6.4s, v28.4s, v25.s[0] + fmla v7.4s, v29.4s, v25.s[0] + subs x25, x25, #1 + bgt LoopCTailCycle + LoopCEnd: + add x23, x23, x11 + subs x22, x22, #1 + bgt LoopW + add x1, x1, x12 + add x2, x2, x14 + subs x4, x4, #1 + bgt LoopH + + ands x6, x6, #3 + beq WriteBack + dup v24.2d, xzr // relu + fmax v0.4s, v0.4s, v24.4s + fmax v1.4s, v1.4s, v24.4s + fmax v2.4s, v2.4s, v24.4s + fmax v3.4s, v3.4s, v24.4s + fmax v4.4s, v4.4s, v24.4s + fmax v5.4s, v5.4s, v24.4s + fmax v6.4s, v6.4s, v24.4s + fmax v7.4s, v7.4s, v24.4s + + ands x6, x6, #1 + beq WriteBack + movi v24.4s, #6 // relu6 + scvtf v24.4s, v24.4s + fmin v0.4s, v0.4s, v24.4s + fmin v1.4s, v1.4s, v24.4s + fmin v2.4s, v2.4s, v24.4s + fmin v3.4s, v3.4s, v24.4s + fmin v4.4s, v4.4s, v24.4s + fmin v5.4s, v5.4s, v24.4s + fmin v6.4s, v6.4s, v24.4s + fmin v7.4s, v7.4s, v24.4s + + WriteBack: + add x21, x0, x7, LSL #1 + cmp x15, #13 + beq NC4HW4 + add x22, x21, x7 + st1 {v0.4s, v1.4s}, [x0] + st1 {v2.4s, v3.4s}, [x20] + st1 {v4.4s, v5.4s}, [x21] + st1 {v6.4s, v7.4s}, [x22] + b End + NC4HW4: + st1 {v0.4s}, [x0], #16 + st1 {v2.4s}, [x0], #16 + st1 {v4.4s}, [x0], #16 + st1 {v6.4s}, [x0] + st1 {v1.4s}, [x20], #16 + st1 {v3.4s}, [x20], #16 + st1 {v5.4s}, [x20], #16 + st1 {v7.4s}, [x20] + End: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ldp x27, x28, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvSW5x16Kernel.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvSW5x16Kernel.S new file mode 100644 index 0000000000000000000000000000000000000000..a2b7ea2c7fd6ebe96d9651bd53627362244f95a2 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvSW5x16Kernel.S @@ -0,0 +1,457 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void SWConv5x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, +// size_t kernel_w, size_t act_flag, size_t out_step, size_t ic_algin, size_t in_kw_step, +// size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode) +// x0: dst, x1: src, x2: weight, x3: bias, x4: kernel_h, x5: kernel_w, x6: act_flag, x7: out_step, +// x10: ic_algin, x11: in_kw_step, x12: in_kh_step, x13: in_sw_step, x14: kw_remainder, x15: write_mode +asm_function SWConv5x16Kernel + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #208 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + stp x23, x24, [sp, #160] + stp x25, x26, [sp, #176] + stp x27, x28, [sp, #192] + + ldr x10, [sp, #208] + ldr x11, [sp, #216] + ldr x12, [sp, #224] + ldr x13, [sp, #232] + ldr x14, [sp, #240] + ldr x15, [sp, #248] + lsl x7, x7, #2 + lsl x11, x11, #2 + lsl x12, x12, #2 + lsl x13, x13, #2 + lsl x14, x14, #2 + + cbz x3, InitNoBias + InitWithBias: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x3] + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x3] + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x3] + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x3] + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x3] + b LoopH + InitNoBias: + dup v0.2d, xzr + dup v1.2d, xzr + dup v2.2d, xzr + dup v3.2d, xzr + dup v4.2d, xzr + dup v5.2d, xzr + dup v6.2d, xzr + dup v7.2d, xzr + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + dup v12.2d, xzr + dup v13.2d, xzr + dup v14.2d, xzr + dup v15.2d, xzr + dup v16.2d, xzr + dup v17.2d, xzr + dup v18.2d, xzr + dup v19.2d, xzr + + LoopH: + mov x22, x5 + mov x23, x1 + LoopW: + mov x25, x10 + prfm pldl1keep, [x23] + mov x24, x23 + add x26, x23, x13 + prfm pldl1keep, [x26] + add x27, x23, x13, lsl #1 + prfm pldl1keep, [x27] + add x28, x27, x13 + prfm pldl1keep, [x28] + add x20, x23, x13, lsl #2 + prfm pldl1keep, [x20] + subs x25, x25, #4 + blt LoopCTail + LoopC4: + ld1 {v20.4s}, [x24], #16 + ld1 {v21.4s}, [x26], #16 + ld1 {v22.4s}, [x27], #16 + ld1 {v23.4s}, [x28], #16 + ld1 {v24.4s}, [x20], #16 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v20.s[0] + fmla v1.4s, v29.4s, v20.s[0] + fmla v2.4s, v30.4s, v20.s[0] + fmla v3.4s, v31.4s, v20.s[0] + fmla v4.4s, v28.4s, v21.s[0] + fmla v5.4s, v29.4s, v21.s[0] + fmla v6.4s, v30.4s, v21.s[0] + fmla v7.4s, v31.4s, v21.s[0] + fmla v8.4s, v28.4s, v22.s[0] + fmla v9.4s, v29.4s, v22.s[0] + fmla v10.4s, v30.4s, v22.s[0] + fmla v11.4s, v31.4s, v22.s[0] + fmla v12.4s, v28.4s, v23.s[0] + fmla v13.4s, v29.4s, v23.s[0] + fmla v14.4s, v30.4s, v23.s[0] + fmla v15.4s, v31.4s, v23.s[0] + fmla v16.4s, v28.4s, v24.s[0] + fmla v17.4s, v29.4s, v24.s[0] + fmla v18.4s, v30.4s, v24.s[0] + fmla v19.4s, v31.4s, v24.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v20.s[1] + fmla v1.4s, v29.4s, v20.s[1] + fmla v2.4s, v30.4s, v20.s[1] + fmla v3.4s, v31.4s, v20.s[1] + fmla v4.4s, v28.4s, v21.s[1] + fmla v5.4s, v29.4s, v21.s[1] + fmla v6.4s, v30.4s, v21.s[1] + fmla v7.4s, v31.4s, v21.s[1] + fmla v8.4s, v28.4s, v22.s[1] + fmla v9.4s, v29.4s, v22.s[1] + fmla v10.4s, v30.4s, v22.s[1] + fmla v11.4s, v31.4s, v22.s[1] + fmla v12.4s, v28.4s, v23.s[1] + fmla v13.4s, v29.4s, v23.s[1] + fmla v14.4s, v30.4s, v23.s[1] + fmla v15.4s, v31.4s, v23.s[1] + fmla v16.4s, v28.4s, v24.s[1] + fmla v17.4s, v29.4s, v24.s[1] + fmla v18.4s, v30.4s, v24.s[1] + fmla v19.4s, v31.4s, v24.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v20.s[2] + fmla v1.4s, v29.4s, v20.s[2] + fmla v2.4s, v30.4s, v20.s[2] + fmla v3.4s, v31.4s, v20.s[2] + fmla v4.4s, v28.4s, v21.s[2] + fmla v5.4s, v29.4s, v21.s[2] + fmla v6.4s, v30.4s, v21.s[2] + fmla v7.4s, v31.4s, v21.s[2] + fmla v8.4s, v28.4s, v22.s[2] + fmla v9.4s, v29.4s, v22.s[2] + fmla v10.4s, v30.4s, v22.s[2] + fmla v11.4s, v31.4s, v22.s[2] + fmla v12.4s, v28.4s, v23.s[2] + fmla v13.4s, v29.4s, v23.s[2] + fmla v14.4s, v30.4s, v23.s[2] + fmla v15.4s, v31.4s, v23.s[2] + fmla v16.4s, v28.4s, v24.s[2] + fmla v17.4s, v29.4s, v24.s[2] + fmla v18.4s, v30.4s, v24.s[2] + fmla v19.4s, v31.4s, v24.s[2] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v20.s[3] + fmla v1.4s, v29.4s, v20.s[3] + fmla v2.4s, v30.4s, v20.s[3] + fmla v3.4s, v31.4s, v20.s[3] + fmla v4.4s, v28.4s, v21.s[3] + fmla v5.4s, v29.4s, v21.s[3] + fmla v6.4s, v30.4s, v21.s[3] + fmla v7.4s, v31.4s, v21.s[3] + fmla v8.4s, v28.4s, v22.s[3] + fmla v9.4s, v29.4s, v22.s[3] + fmla v10.4s, v30.4s, v22.s[3] + fmla v11.4s, v31.4s, v22.s[3] + fmla v12.4s, v28.4s, v23.s[3] + fmla v13.4s, v29.4s, v23.s[3] + fmla v14.4s, v30.4s, v23.s[3] + fmla v15.4s, v31.4s, v23.s[3] + fmla v16.4s, v28.4s, v24.s[3] + fmla v17.4s, v29.4s, v24.s[3] + fmla v18.4s, v30.4s, v24.s[3] + fmla v19.4s, v31.4s, v24.s[3] + subs x25, x25, #4 + bge LoopC4 + LoopCTail: + add x25, x25, #4 + cbz x25, LoopCEnd + cmp x25, #3 + beq LoopCTail3 + cmp x25, #2 + beq LoopCTail2 + ld1 {v20.s}[0], [x24], #4 + ld1 {v21.s}[0], [x26], #4 + ld1 {v22.s}[0], [x27], #4 + ld1 {v23.s}[0], [x28], #4 + ld1 {v24.s}[0], [x20], #4 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v20.s[0] + fmla v1.4s, v29.4s, v20.s[0] + fmla v2.4s, v30.4s, v20.s[0] + fmla v3.4s, v31.4s, v20.s[0] + fmla v4.4s, v28.4s, v21.s[0] + fmla v5.4s, v29.4s, v21.s[0] + fmla v6.4s, v30.4s, v21.s[0] + fmla v7.4s, v31.4s, v21.s[0] + fmla v8.4s, v28.4s, v22.s[0] + fmla v9.4s, v29.4s, v22.s[0] + fmla v10.4s, v30.4s, v22.s[0] + fmla v11.4s, v31.4s, v22.s[0] + fmla v12.4s, v28.4s, v23.s[0] + fmla v13.4s, v29.4s, v23.s[0] + fmla v14.4s, v30.4s, v23.s[0] + fmla v15.4s, v31.4s, v23.s[0] + fmla v16.4s, v28.4s, v24.s[0] + fmla v17.4s, v29.4s, v24.s[0] + fmla v18.4s, v30.4s, v24.s[0] + fmla v19.4s, v31.4s, v24.s[0] + b LoopCEnd + LoopCTail2: + ld1 {v20.2s}, [x24], #8 + ld1 {v21.2s}, [x26], #8 + ld1 {v22.2s}, [x27], #8 + ld1 {v23.2s}, [x28], #8 + ld1 {v24.2s}, [x20], #8 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v20.s[0] + fmla v1.4s, v29.4s, v20.s[0] + fmla v2.4s, v30.4s, v20.s[0] + fmla v3.4s, v31.4s, v20.s[0] + fmla v4.4s, v28.4s, v21.s[0] + fmla v5.4s, v29.4s, v21.s[0] + fmla v6.4s, v30.4s, v21.s[0] + fmla v7.4s, v31.4s, v21.s[0] + fmla v8.4s, v28.4s, v22.s[0] + fmla v9.4s, v29.4s, v22.s[0] + fmla v10.4s, v30.4s, v22.s[0] + fmla v11.4s, v31.4s, v22.s[0] + fmla v12.4s, v28.4s, v23.s[0] + fmla v13.4s, v29.4s, v23.s[0] + fmla v14.4s, v30.4s, v23.s[0] + fmla v15.4s, v31.4s, v23.s[0] + fmla v16.4s, v28.4s, v24.s[0] + fmla v17.4s, v29.4s, v24.s[0] + fmla v18.4s, v30.4s, v24.s[0] + fmla v19.4s, v31.4s, v24.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v20.s[1] + fmla v1.4s, v29.4s, v20.s[1] + fmla v2.4s, v30.4s, v20.s[1] + fmla v3.4s, v31.4s, v20.s[1] + fmla v4.4s, v28.4s, v21.s[1] + fmla v5.4s, v29.4s, v21.s[1] + fmla v6.4s, v30.4s, v21.s[1] + fmla v7.4s, v31.4s, v21.s[1] + fmla v8.4s, v28.4s, v22.s[1] + fmla v9.4s, v29.4s, v22.s[1] + fmla v10.4s, v30.4s, v22.s[1] + fmla v11.4s, v31.4s, v22.s[1] + fmla v12.4s, v28.4s, v23.s[1] + fmla v13.4s, v29.4s, v23.s[1] + fmla v14.4s, v30.4s, v23.s[1] + fmla v15.4s, v31.4s, v23.s[1] + fmla v16.4s, v28.4s, v24.s[1] + fmla v17.4s, v29.4s, v24.s[1] + fmla v18.4s, v30.4s, v24.s[1] + fmla v19.4s, v31.4s, v24.s[1] + b LoopCEnd + LoopCTail3: + ld1 {v20.2s}, [x24], #8 + ld1 {v21.2s}, [x26], #8 + ld1 {v22.2s}, [x27], #8 + ld1 {v23.2s}, [x28], #8 + ld1 {v24.2s}, [x20], #8 + ld1 {v20.s}[2], [x24], #4 + ld1 {v21.s}[2], [x26], #4 + ld1 {v22.s}[2], [x27], #4 + ld1 {v23.s}[2], [x28], #4 + ld1 {v24.s}[2], [x20], #4 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v20.s[0] + fmla v1.4s, v29.4s, v20.s[0] + fmla v2.4s, v30.4s, v20.s[0] + fmla v3.4s, v31.4s, v20.s[0] + fmla v4.4s, v28.4s, v21.s[0] + fmla v5.4s, v29.4s, v21.s[0] + fmla v6.4s, v30.4s, v21.s[0] + fmla v7.4s, v31.4s, v21.s[0] + fmla v8.4s, v28.4s, v22.s[0] + fmla v9.4s, v29.4s, v22.s[0] + fmla v10.4s, v30.4s, v22.s[0] + fmla v11.4s, v31.4s, v22.s[0] + fmla v12.4s, v28.4s, v23.s[0] + fmla v13.4s, v29.4s, v23.s[0] + fmla v14.4s, v30.4s, v23.s[0] + fmla v15.4s, v31.4s, v23.s[0] + fmla v16.4s, v28.4s, v24.s[0] + fmla v17.4s, v29.4s, v24.s[0] + fmla v18.4s, v30.4s, v24.s[0] + fmla v19.4s, v31.4s, v24.s[0] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v20.s[1] + fmla v1.4s, v29.4s, v20.s[1] + fmla v2.4s, v30.4s, v20.s[1] + fmla v3.4s, v31.4s, v20.s[1] + fmla v4.4s, v28.4s, v21.s[1] + fmla v5.4s, v29.4s, v21.s[1] + fmla v6.4s, v30.4s, v21.s[1] + fmla v7.4s, v31.4s, v21.s[1] + fmla v8.4s, v28.4s, v22.s[1] + fmla v9.4s, v29.4s, v22.s[1] + fmla v10.4s, v30.4s, v22.s[1] + fmla v11.4s, v31.4s, v22.s[1] + fmla v12.4s, v28.4s, v23.s[1] + fmla v13.4s, v29.4s, v23.s[1] + fmla v14.4s, v30.4s, v23.s[1] + fmla v15.4s, v31.4s, v23.s[1] + fmla v16.4s, v28.4s, v24.s[1] + fmla v17.4s, v29.4s, v24.s[1] + fmla v18.4s, v30.4s, v24.s[1] + fmla v19.4s, v31.4s, v24.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v20.s[2] + fmla v1.4s, v29.4s, v20.s[2] + fmla v2.4s, v30.4s, v20.s[2] + fmla v3.4s, v31.4s, v20.s[2] + fmla v4.4s, v28.4s, v21.s[2] + fmla v5.4s, v29.4s, v21.s[2] + fmla v6.4s, v30.4s, v21.s[2] + fmla v7.4s, v31.4s, v21.s[2] + fmla v8.4s, v28.4s, v22.s[2] + fmla v9.4s, v29.4s, v22.s[2] + fmla v10.4s, v30.4s, v22.s[2] + fmla v11.4s, v31.4s, v22.s[2] + fmla v12.4s, v28.4s, v23.s[2] + fmla v13.4s, v29.4s, v23.s[2] + fmla v14.4s, v30.4s, v23.s[2] + fmla v15.4s, v31.4s, v23.s[2] + fmla v16.4s, v28.4s, v24.s[2] + fmla v17.4s, v29.4s, v24.s[2] + fmla v18.4s, v30.4s, v24.s[2] + fmla v19.4s, v31.4s, v24.s[2] + LoopCEnd: + add x23, x23, x11 + subs x22, x22, #1 + bgt LoopW + add x1, x1, x12 + add x2, x2, x14 + subs x4, x4, #1 + bgt LoopH + + ands x6, x6, #3 + beq WriteBack + dup v24.2d, xzr // relu + fmax v0.4s, v0.4s, v24.4s + fmax v1.4s, v1.4s, v24.4s + fmax v2.4s, v2.4s, v24.4s + fmax v3.4s, v3.4s, v24.4s + fmax v4.4s, v4.4s, v24.4s + fmax v5.4s, v5.4s, v24.4s + fmax v6.4s, v6.4s, v24.4s + fmax v7.4s, v7.4s, v24.4s + fmax v8.4s, v8.4s, v24.4s + fmax v9.4s, v9.4s, v24.4s + fmax v10.4s, v10.4s, v24.4s + fmax v11.4s, v11.4s, v24.4s + fmax v12.4s, v12.4s, v24.4s + fmax v13.4s, v13.4s, v24.4s + fmax v14.4s, v14.4s, v24.4s + fmax v15.4s, v15.4s, v24.4s + fmax v16.4s, v16.4s, v24.4s + fmax v17.4s, v17.4s, v24.4s + fmax v18.4s, v18.4s, v24.4s + fmax v19.4s, v19.4s, v24.4s + + ands x6, x6, #1 + beq WriteBack + movi v24.4s, #6 // relu6 + scvtf v24.4s, v24.4s + fmin v0.4s, v0.4s, v24.4s + fmin v1.4s, v1.4s, v24.4s + fmin v2.4s, v2.4s, v24.4s + fmin v3.4s, v3.4s, v24.4s + fmin v4.4s, v4.4s, v24.4s + fmin v5.4s, v5.4s, v24.4s + fmin v6.4s, v6.4s, v24.4s + fmin v7.4s, v7.4s, v24.4s + fmin v8.4s, v8.4s, v24.4s + fmin v9.4s, v9.4s, v24.4s + fmin v10.4s, v10.4s, v24.4s + fmin v11.4s, v11.4s, v24.4s + fmin v12.4s, v12.4s, v24.4s + fmin v13.4s, v13.4s, v24.4s + fmin v14.4s, v14.4s, v24.4s + fmin v15.4s, v15.4s, v24.4s + fmin v16.4s, v16.4s, v24.4s + fmin v17.4s, v17.4s, v24.4s + fmin v18.4s, v18.4s, v24.4s + fmin v19.4s, v19.4s, v24.4s + + WriteBack: + add x20, x0, x7 + add x21, x0, x7, LSL #1 + add x23, x0, x7, LSL #2 + add x22, x20, x7, LSL #1 + cmp x15, #13 + beq NC4HW4 + st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0] + st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x20] + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x21] + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x22] + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x23] + b End + NC4HW4: + st1 {v0.4s}, [x0], #16 + st1 {v4.4s}, [x0], #16 + st1 {v8.4s}, [x0], #16 + st1 {v12.4s}, [x0], #16 + st1 {v16.4s}, [x0] + st1 {v1.4s}, [x20], #16 + st1 {v5.4s}, [x20], #16 + st1 {v9.4s}, [x20], #16 + st1 {v13.4s}, [x20], #16 + st1 {v17.4s}, [x20] + st1 {v2.4s}, [x21], #16 + st1 {v6.4s}, [x21], #16 + st1 {v10.4s}, [x21], #16 + st1 {v14.4s}, [x21], #16 + st1 {v18.4s}, [x21] + st1 {v3.4s}, [x22], #16 + st1 {v7.4s}, [x22], #16 + st1 {v11.4s}, [x22], #16 + st1 {v15.4s}, [x22], #16 + st1 {v19.4s}, [x22] + End: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ldp x27, x28, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvSW5x8Kernel.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvSW5x8Kernel.S new file mode 100644 index 0000000000000000000000000000000000000000..b7e48480ab2e9749f28b5fba3a86bb99f14b0c3c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/ConvSW5x8Kernel.S @@ -0,0 +1,308 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void SWConv5x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, +// size_t kernel_w, size_t act_flag, size_t out_step, size_t ic_algin, size_t in_kw_step, +// size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode) +// x0: dst, x1: src, x2: weight, x3: bias, x4: kernel_h, x5: kernel_w, x6: act_flag, x7: out_step, +// x10: ic_algin, x11: in_kw_step, x12: in_kh_step, x13: in_sw_step, x14: kw_remainder, x15: write_mode +asm_function SWConv5x8Kernel + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #208 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + stp x23, x24, [sp, #160] + stp x25, x26, [sp, #176] + stp x27, x28, [sp, #192] + + ldr x10, [sp, #208] + ldr x11, [sp, #216] + ldr x12, [sp, #224] + ldr x13, [sp, #232] + ldr x14, [sp, #240] + ldr x15, [sp, #248] + lsl x7, x7, #2 + lsl x11, x11, #2 + lsl x12, x12, #2 + lsl x13, x13, #2 + lsl x14, x14, #2 + + cbz x3, InitNoBias + InitWithBias: + ld1 {v0.4s, v1.4s}, [x3] + ld1 {v2.4s, v3.4s}, [x3] + ld1 {v4.4s, v5.4s}, [x3] + ld1 {v6.4s, v7.4s}, [x3] + ld1 {v8.4s, v9.4s}, [x3] + b LoopH + InitNoBias: + dup v0.2d, xzr + dup v1.2d, xzr + dup v2.2d, xzr + dup v3.2d, xzr + dup v4.2d, xzr + dup v5.2d, xzr + dup v6.2d, xzr + dup v7.2d, xzr + dup v8.2d, xzr + dup v9.2d, xzr + + LoopH: + mov x22, x5 + mov x23, x1 + LoopW: + mov x25, x10 + prfm pldl1keep, [x23] + mov x24, x23 + add x26, x23, x13 + prfm pldl1keep, [x26] + add x27, x23, x13, lsl #1 + prfm pldl1keep, [x27] + add x28, x27, x13 + prfm pldl1keep, [x28] + add x20, x23, x13, lsl #2 + prfm pldl1keep, [x20] + subs x25, x25, #4 + blt LoopCTail + LoopC4: + ld1 {v20.4s}, [x24], #16 + ld1 {v21.4s}, [x26], #16 + ld1 {v22.4s}, [x27], #16 + ld1 {v23.4s}, [x28], #16 + ld1 {v24.4s}, [x20], #16 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v20.s[0] + fmla v1.4s, v29.4s, v20.s[0] + fmla v2.4s, v28.4s, v21.s[0] + fmla v3.4s, v29.4s, v21.s[0] + fmla v4.4s, v28.4s, v22.s[0] + fmla v5.4s, v29.4s, v22.s[0] + fmla v6.4s, v28.4s, v23.s[0] + fmla v7.4s, v29.4s, v23.s[0] + fmla v8.4s, v28.4s, v24.s[0] + fmla v9.4s, v29.4s, v24.s[0] + fmla v0.4s, v30.4s, v20.s[1] + fmla v1.4s, v31.4s, v20.s[1] + fmla v2.4s, v30.4s, v21.s[1] + fmla v3.4s, v31.4s, v21.s[1] + fmla v4.4s, v30.4s, v22.s[1] + fmla v5.4s, v31.4s, v22.s[1] + fmla v6.4s, v30.4s, v23.s[1] + fmla v7.4s, v31.4s, v23.s[1] + fmla v8.4s, v30.4s, v24.s[1] + fmla v9.4s, v31.4s, v24.s[1] + + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v20.s[2] + fmla v1.4s, v29.4s, v20.s[2] + fmla v2.4s, v28.4s, v21.s[2] + fmla v3.4s, v29.4s, v21.s[2] + fmla v4.4s, v28.4s, v22.s[2] + fmla v5.4s, v29.4s, v22.s[2] + fmla v6.4s, v28.4s, v23.s[2] + fmla v7.4s, v29.4s, v23.s[2] + fmla v8.4s, v28.4s, v24.s[2] + fmla v9.4s, v29.4s, v24.s[2] + fmla v0.4s, v30.4s, v20.s[3] + fmla v1.4s, v31.4s, v20.s[3] + fmla v2.4s, v30.4s, v21.s[3] + fmla v3.4s, v31.4s, v21.s[3] + fmla v4.4s, v30.4s, v22.s[3] + fmla v5.4s, v31.4s, v22.s[3] + fmla v6.4s, v30.4s, v23.s[3] + fmla v7.4s, v31.4s, v23.s[3] + fmla v8.4s, v30.4s, v24.s[3] + fmla v9.4s, v31.4s, v24.s[3] + subs x25, x25, #4 + bge LoopC4 + LoopCTail: + add x25, x25, #4 + cbz x25, LoopCEnd + cmp x25, #3 + beq LoopCTail3 + cmp x25, #2 + beq LoopCTail2 + ld1 {v20.s}[0], [x24], #4 + ld1 {v21.s}[0], [x26], #4 + ld1 {v22.s}[0], [x27], #4 + ld1 {v23.s}[0], [x28], #4 + ld1 {v24.s}[0], [x20], #4 + ld1 {v28.4s, v29.4s}, [x2], #32 + fmla v0.4s, v28.4s, v20.s[0] + fmla v1.4s, v29.4s, v20.s[0] + fmla v2.4s, v28.4s, v21.s[0] + fmla v3.4s, v29.4s, v21.s[0] + fmla v4.4s, v28.4s, v22.s[0] + fmla v5.4s, v29.4s, v22.s[0] + fmla v6.4s, v28.4s, v23.s[0] + fmla v7.4s, v29.4s, v23.s[0] + fmla v8.4s, v28.4s, v24.s[0] + fmla v9.4s, v29.4s, v24.s[0] + b LoopCEnd + LoopCTail2: + ld1 {v20.2s}, [x24], #8 + ld1 {v21.2s}, [x26], #8 + ld1 {v22.2s}, [x27], #8 + ld1 {v23.2s}, [x28], #8 + ld1 {v24.2s}, [x20], #8 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v20.s[0] + fmla v1.4s, v29.4s, v20.s[0] + fmla v2.4s, v28.4s, v21.s[0] + fmla v3.4s, v29.4s, v21.s[0] + fmla v4.4s, v28.4s, v22.s[0] + fmla v5.4s, v29.4s, v22.s[0] + fmla v6.4s, v28.4s, v23.s[0] + fmla v7.4s, v29.4s, v23.s[0] + fmla v8.4s, v28.4s, v24.s[0] + fmla v9.4s, v29.4s, v24.s[0] + fmla v0.4s, v30.4s, v20.s[1] + fmla v1.4s, v31.4s, v20.s[1] + fmla v2.4s, v30.4s, v21.s[1] + fmla v3.4s, v31.4s, v21.s[1] + fmla v4.4s, v30.4s, v22.s[1] + fmla v5.4s, v31.4s, v22.s[1] + fmla v6.4s, v30.4s, v23.s[1] + fmla v7.4s, v31.4s, v23.s[1] + fmla v8.4s, v30.4s, v24.s[1] + fmla v9.4s, v31.4s, v24.s[1] + b LoopCEnd + LoopCTail3: + ld1 {v20.2s}, [x24], #8 + ld1 {v21.2s}, [x26], #8 + ld1 {v22.2s}, [x27], #8 + ld1 {v23.2s}, [x28], #8 + ld1 {v24.2s}, [x20], #8 + ld1 {v20.s}[2], [x24], #4 + ld1 {v21.s}[2], [x26], #4 + ld1 {v22.s}[2], [x27], #4 + ld1 {v23.s}[2], [x28], #4 + ld1 {v24.s}[2], [x20], #4 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + fmla v0.4s, v28.4s, v20.s[0] + fmla v1.4s, v29.4s, v20.s[0] + fmla v2.4s, v28.4s, v21.s[0] + fmla v3.4s, v29.4s, v21.s[0] + fmla v4.4s, v28.4s, v22.s[0] + fmla v5.4s, v29.4s, v22.s[0] + fmla v6.4s, v28.4s, v23.s[0] + fmla v7.4s, v29.4s, v23.s[0] + fmla v8.4s, v28.4s, v24.s[0] + fmla v9.4s, v29.4s, v24.s[0] + ld1 {v26.4s, v27.4s}, [x2], #32 + fmla v0.4s, v30.4s, v20.s[1] + fmla v1.4s, v31.4s, v20.s[1] + fmla v2.4s, v30.4s, v21.s[1] + fmla v3.4s, v31.4s, v21.s[1] + fmla v4.4s, v30.4s, v22.s[1] + fmla v5.4s, v31.4s, v22.s[1] + fmla v6.4s, v30.4s, v23.s[1] + fmla v7.4s, v31.4s, v23.s[1] + fmla v8.4s, v30.4s, v24.s[1] + fmla v9.4s, v31.4s, v24.s[1] + fmla v0.4s, v26.4s, v20.s[2] + fmla v1.4s, v27.4s, v20.s[2] + fmla v2.4s, v26.4s, v21.s[2] + fmla v3.4s, v27.4s, v21.s[2] + fmla v4.4s, v26.4s, v22.s[2] + fmla v5.4s, v27.4s, v22.s[2] + fmla v6.4s, v26.4s, v23.s[2] + fmla v7.4s, v27.4s, v23.s[2] + fmla v8.4s, v26.4s, v24.s[2] + fmla v9.4s, v27.4s, v24.s[2] + LoopCEnd: + add x23, x23, x11 + subs x22, x22, #1 + bgt LoopW + add x1, x1, x12 + add x2, x2, x14 + subs x4, x4, #1 + bgt LoopH + + ands x6, x6, #3 + beq WriteBack + dup v24.2d, xzr // relu + fmax v0.4s, v0.4s, v24.4s + fmax v1.4s, v1.4s, v24.4s + fmax v2.4s, v2.4s, v24.4s + fmax v3.4s, v3.4s, v24.4s + fmax v4.4s, v4.4s, v24.4s + fmax v5.4s, v5.4s, v24.4s + fmax v6.4s, v6.4s, v24.4s + fmax v7.4s, v7.4s, v24.4s + fmax v8.4s, v8.4s, v24.4s + fmax v9.4s, v9.4s, v24.4s + + ands x6, x6, #1 + beq WriteBack + movi v24.4s, #6 // relu6 + scvtf v24.4s, v24.4s + fmin v0.4s, v0.4s, v24.4s + fmin v1.4s, v1.4s, v24.4s + fmin v2.4s, v2.4s, v24.4s + fmin v3.4s, v3.4s, v24.4s + fmin v4.4s, v4.4s, v24.4s + fmin v5.4s, v5.4s, v24.4s + fmin v6.4s, v6.4s, v24.4s + fmin v7.4s, v7.4s, v24.4s + fmin v8.4s, v8.4s, v24.4s + fmin v9.4s, v9.4s, v24.4s + + WriteBack: + add x20, x0, x7 + cmp x15, #13 + beq NC4HW4 + add x21, x0, x7, LSL #1 + add x23, x0, x7, LSL #2 + add x22, x20, x7, LSL #1 + st1 {v0.4s, v1.4s}, [x0] + st1 {v2.4s, v3.4s}, [x20] + st1 {v4.4s, v5.4s}, [x21] + st1 {v6.4s, v7.4s}, [x22] + st1 {v8.4s, v9.4s}, [x23] + b End + NC4HW4: + st1 {v0.4s}, [x0], #16 + st1 {v2.4s}, [x0], #16 + st1 {v4.4s}, [x0], #16 + st1 {v6.4s}, [x0], #16 + st1 {v8.4s}, [x0] + st1 {v1.4s}, [x20], #16 + st1 {v3.4s}, [x20], #16 + st1 {v5.4s}, [x20], #16 + st1 {v7.4s}, [x20], #16 + st1 {v9.4s}, [x20] + End: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ldp x27, x28, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/DeconvDwFp32Border.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/DeconvDwFp32Border.S new file mode 100644 index 0000000000000000000000000000000000000000..ead01b8ef9f0069cb95939fdf7d5fb745af3f575 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/DeconvDwFp32Border.S @@ -0,0 +1,56 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void DeconvDwFp32Border(float *dst, const float *src, const float *weight, size_t height, size_t width, +// size_t in_kh_step, size_t in_kw_step, size_t kernel_w) + +// x0: dst, x1: src, x2: weight, x3: height, x4: width, x5: in_kh_step, x6: in_kw_step, x7: kernel_w +asm_function DeconvDwFp32Border + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + cmp x3, #0 + beq End + cmp x4, #0 + beq End + ld1 {v1.4s}, [x1] + + mov x13, x0 + mov x14, x2 + LoopH: + mov x15, x13 + mov x16, x14 + mov x17, x4 + LoopW: + ld1 {v0.4s}, [x15] + ld1 {v2.4s}, [x16], #16 + fmla v0.4s, v1.4s, v2.4s + st1 {v0.4s}, [x15], x6 + subs x17, x17, #1 + bne LoopW + subs x3, x3, #1 + add x13, x13, x5 + add x14, x14, x7 + bne LoopH + End: + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/DeconvDwFp32Center.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/DeconvDwFp32Center.S new file mode 100644 index 0000000000000000000000000000000000000000..11722e713dd9f3a84da8c4eb3ce3ab34c93beec1 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/DeconvDwFp32Center.S @@ -0,0 +1,75 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void DeconvDwFp32Center(float *dst, const float *src, const float *weight, size_t height, size_t width, +// size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, +// size_t in_kh_step, size_t in_kw_step); +// x0: dst, x1: src, x2: weight, x3: height, x4: weight, x5: kernel_h, x6: kernel_w, x7: out_h_step +// x8: block_channel, x9: in_sh_step, x10: in_sw_step, x11: in_kh_step, x12: in_kw_step +asm_function DeconvDwFp32Center + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #32 + stp x19, x20, [sp] + stp x21, x22, [sp, #16] + + ldr x8, [sp, #32] + ldr x9, [sp, #40] + ldr x10, [sp, #48] + ldr x11, [sp, #56] + ldr x12, [sp, #64] + + LoopH: + mov x15, x0 + mov x16, x1 + mov x17, x4 + LoopW: + mov x22, x15 + mov x19, x2 + mov x20, x5 + ld1 {v1.4s}, [x16], x8 + LoopKh: + mov x21, x22 + mov x13, x6 + LoopKw: + ld1 {v0.4s}, [x21] + ld1 {v2.4s}, [x19], #16 + fmla v0.4s, v1.4s, v2.4s + st1 {v0.4s}, [x21], x12 + subs x13, x13, #1 + bne LoopKw + add x22, x22, x11 + subs x20, x20, #1 + bne LoopKh + add x15, x15, x10 + subs x17, x17, #1 + bne LoopW + add x0, x0, x9 + add x1, x1, x7 + subs x3, x3, #1 + bne LoopH + + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/DeconvDwInt8Center.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/DeconvDwInt8Center.S new file mode 100644 index 0000000000000000000000000000000000000000..1c3723fa8b45d1142d6a4d36a9701676c844d16f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/DeconvDwInt8Center.S @@ -0,0 +1,75 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void DeconvDwInt8Center(int32_t *dst, const int16_t *src, const int16_t *weight, size_t height, size_t width, +// size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, +// size_t in_kh_step, size_t in_kw_step); +// x0: dst, x1: src, x2: weight, x3: height, x4: weight, x5: kernel_h, x6: kernel_w, x7: out_h_step +// x8: block_channel, x9: in_sh_step, x10: in_sw_step, x11: in_kh_step, x12: in_kw_step +asm_function DeconvDwInt8Center + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #32 + stp x19, x20, [sp] + stp x21, x22, [sp, #16] + + ldr x8, [sp, #32] + ldr x9, [sp, #40] + ldr x10, [sp, #48] + ldr x11, [sp, #56] + ldr x12, [sp, #64] + + LoopH: + mov x15, x0 + mov x16, x1 + mov x17, x4 + LoopW: + mov x18, x15 + mov x19, x2 + mov x20, x5 + ld1 {v1.4h}, [x16], x8 + LoopKh: + mov x21, x18 + mov x13, x6 + LoopKw: + ld1 {v0.4s}, [x21] + ld1 {v2.4h}, [x19], #8 + smlal v0.4s, v1.4h, v2.4h + st1 {v0.4s}, [x21], x12 + subs x13, x13, #1 + bne LoopKw + add x18, x18, x11 + subs x20, x20, #1 + bne LoopKh + add x15, x15, x10 + subs x17, x17, #1 + bne LoopW + add x0, x0, x9 + add x1, x1, x7 + subs x3, x3, #1 + bne LoopH + + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/DeconvDwInt8Post.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/DeconvDwInt8Post.S new file mode 100644 index 0000000000000000000000000000000000000000..f9909a81f3d6ea082fb3be11042b8d7f207645b1 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/DeconvDwInt8Post.S @@ -0,0 +1,66 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void DeconvDwInt8Post(int8_t *dst, int32_t *output_buffer, const int32_t *bias, int block_channel, int pixel_nums, +// int out_multiplier, int left_shift, int right_shift, int32_t out_zp, int32_t acc_min, +// int32_t acc_max) +// x0: dst, x1: output_buffer, x2: bias, x3: block_channel, x4: pixel_nums, x5: out_multiplier +// x6: left_shift, x7: right_shift, x8: out_zp, x9: acc_min, x10: acc_max + +asm_function DeconvDwInt8Post + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + ld1 {v25.4s}, [x2] + + dup v26.4s, w6 // left_shift + dup v27.4s, w5 // out_multiplier + dup v28.4s, w7 // right_shift + + ldr w8, [sp] + dup v29.4s, w8 // out_zp + ldr w9, [sp, #8] + dup v30.4s, w9 // acc_min + ldr w10, [sp, #16] + dup v31.4s, w10 // acc_max + + LoopCount: + ld1 {v0.4s}, [x1], #16 + add v0.4s, v0.4s, v25.4s + sqshl v0.4s, v0.4s, v26.4s + sqrdmulh v0.4s, v0.4s, v27.4s + sqrshl v0.4s, v0.4s, v28.4s + + add v0.4s, v0.4s, v29.4s + smax v0.4s, v0.4s, v30.4s + smin v0.4s, v0.4s, v31.4s + + sqxtn v0.4h, v0.4s + sqxtn v0.8b, v0.8h + + st1 {v0.s}[0], [x0], x3 + + sub x4, x4, #1 + cmp x4, #1 + bge LoopCount + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/DynamicGatherArm64.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/DynamicGatherArm64.S new file mode 100644 index 0000000000000000000000000000000000000000..e442d8706e47df37d1b21635fb816ad217a102fd --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/DynamicGatherArm64.S @@ -0,0 +1,48 @@ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" +.text +.align 5 + +// void DynamicGatherArm64(const int8_t *src, float *output, int count_16, int zp, float scale); +// x0: src(left matrix ptr) +// x1: output(right matrix ptr) +// w2: count_16 +// w3: zp +// w4: scale + +asm_function DynamicGatherArm64 + mov x5, x0 // reload src + mov x6, x1 // reload out + mov w7, w2 // reload count_16 + dup v1.4s, w3 // zp + dup v2.4s, v0.s[0] // scale + + LoopCount: + ld1 {v0.16b}, [x5], #16 + + sxtl v3.8h, v0.8b + sxtl2 v4.8h, v0.16b + + sxtl v16.4s, v3.4h + sxtl2 v17.4s, v3.8h + sxtl v18.4s, v4.4h + sxtl2 v19.4s, v4.8h + + sub v16.4s, v16.4s, v1.4s + scvtf v16.4s,v16.4s + fmul v16.4s, v16.4s, v2.4s + sub v17.4s, v17.4s, v1.4s + scvtf v17.4s,v17.4s + fmul v17.4s, v17.4s, v2.4s + sub v18.4s, v18.4s, v1.4s + scvtf v18.4s,v18.4s + fmul v18.4s, v18.4s, v2.4s + sub v19.4s, v19.4s, v1.4s + scvtf v19.4s,v19.4s + fmul v19.4s, v19.4s, v2.4s + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x6], #64 + subs w7, w7, #16 + bgt LoopCount +ret + +#endif \ No newline at end of file diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/IndirectGemmInt16to32_8x4.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/IndirectGemmInt16to32_8x4.S new file mode 100644 index 0000000000000000000000000000000000000000..8ca4b733678ed0478090f0785ab07a5f1419b9f1 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/IndirectGemmInt16to32_8x4.S @@ -0,0 +1,233 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void IndirectGemmInt16to32_8x4(int *output, short *input, short *weight, size_t ksize, size_t ic8, size_t oc4, size_t offset); +// x0: output, x1: input, x2: weight, x3: ksize, x4: ic8, x5: oc4, x6: offset +asm_function IndirectGemmInt16to32_8x4 + + .macro INIT_ZERO + dup v28.4s, wzr + mov v29.16b, v28.16b + mov v30.16b, v28.16b + mov v31.16b, v28.16b + .endm + + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + LoopOc: + mov x7, x3 + mov x8, x1 + + LoopKsize: + mov x9, x0 + INIT_ZERO + + // load input + ld1 {v0.8h, v1.8h}, [x8], #32 + // load weight + ld1 {v16.8h}, [x2], #16 + smull v24.4s, v16.4h, v0.h[0] + smull v25.4s, v16.4h, v1.h[0] + // load weight + ld1 {v17.8h}, [x2], #16 + smlal2 v24.4s, v16.8h, v0.h[1] + smlal2 v25.4s, v16.8h, v1.h[1] + // load input + ld1 {v2.8h, v3.8h}, [x8], #32 + smlal v24.4s, v17.4h, v0.h[2] + smlal v25.4s, v17.4h, v1.h[2] + smlal2 v24.4s, v17.8h, v0.h[3] + smlal2 v25.4s, v17.8h, v1.h[3] + // load weight + ld1 {v18.8h, v19.8h}, [x2], #32 + smull v26.4s, v16.4h, v2.h[0] + smull v27.4s, v16.4h, v3.h[0] + + subs x10, x4, #1 + beq LoopIcEnd + + LoopIc: + + smlal2 v26.4s, v16.8h, v2.h[1] + smlal2 v27.4s, v16.8h, v3.h[1] + smlal v26.4s, v17.4h, v2.h[2] + smlal v27.4s, v17.4h, v3.h[2] + smlal2 v26.4s, v17.8h, v2.h[3] + smlal2 v27.4s, v17.8h, v3.h[3] + + smlal v24.4s, v18.4h, v0.h[4] + smlal v25.4s, v18.4h, v1.h[4] + smlal2 v24.4s, v18.8h, v0.h[5] + smlal2 v25.4s, v18.8h, v1.h[5] + smlal v24.4s, v19.4h, v0.h[6] + smlal v25.4s, v19.4h, v1.h[6] + smlal2 v24.4s, v19.8h, v0.h[7] + smlal2 v25.4s, v19.8h, v1.h[7] + // load input + ld1 {v4.8h, v5.8h}, [x8], #32 + smlal v26.4s, v18.4h, v2.h[4] + smlal v27.4s, v18.4h, v3.h[4] + smlal2 v26.4s, v18.8h, v2.h[5] + smlal2 v27.4s, v18.8h, v3.h[5] + smlal v26.4s, v19.4h, v2.h[6] + smlal v27.4s, v19.4h, v3.h[6] + smlal2 v26.4s, v19.8h, v2.h[7] + smlal2 v27.4s, v19.8h, v3.h[7] + + // load input + ld1 {v6.8h, v7.8h}, [x8], #32 + smlal v28.4s, v16.4h, v4.h[0] + smlal v29.4s, v16.4h, v5.h[0] + smlal2 v28.4s, v16.8h, v4.h[1] + smlal2 v29.4s, v16.8h, v5.h[1] + smlal v28.4s, v17.4h, v4.h[2] + smlal v29.4s, v17.4h, v5.h[2] + smlal2 v28.4s, v17.8h, v4.h[3] + smlal2 v29.4s, v17.8h, v5.h[3] + + smlal v30.4s, v16.4h, v6.h[0] + smlal v31.4s, v16.4h, v7.h[0] + smlal2 v30.4s, v16.8h, v6.h[1] + smlal2 v31.4s, v16.8h, v7.h[1] + smlal v30.4s, v17.4h, v6.h[2] + smlal v31.4s, v17.4h, v7.h[2] + smlal2 v30.4s, v17.8h, v6.h[3] + smlal2 v31.4s, v17.8h, v7.h[3] + // load weight + ld1 {v16.8h, v17.8h}, [x2], #32 + smlal v28.4s, v18.4h, v4.h[4] + smlal v29.4s, v18.4h, v5.h[4] + smlal2 v28.4s, v18.8h, v4.h[5] + smlal2 v29.4s, v18.8h, v5.h[5] + smlal v28.4s, v19.4h, v4.h[6] + smlal v29.4s, v19.4h, v5.h[6] + smlal2 v28.4s, v19.8h, v4.h[7] + smlal2 v29.4s, v19.8h, v5.h[7] + // load input + ld1 {v0.8h, v1.8h}, [x8], #32 + smlal v30.4s, v18.4h, v6.h[4] + smlal v31.4s, v18.4h, v7.h[4] + smlal2 v30.4s, v18.8h, v6.h[5] + smlal2 v31.4s, v18.8h, v7.h[5] + smlal v30.4s, v19.4h, v6.h[6] + smlal v31.4s, v19.4h, v7.h[6] + smlal2 v30.4s, v19.8h, v6.h[7] + smlal2 v31.4s, v19.8h, v7.h[7] + // load input + ld1 {v2.8h, v3.8h}, [x8], #32 + smlal v24.4s, v16.4h, v0.h[0] + smlal v25.4s, v16.4h, v1.h[0] + smlal2 v24.4s, v16.8h, v0.h[1] + smlal2 v25.4s, v16.8h, v1.h[1] + // load weight + ld1 {v18.8h, v19.8h}, [x2], #32 + smlal v24.4s, v17.4h, v0.h[2] + smlal v25.4s, v17.4h, v1.h[2] + smlal2 v24.4s, v17.8h, v0.h[3] + smlal2 v25.4s, v17.8h, v1.h[3] + smlal v26.4s, v16.4h, v2.h[0] + smlal v27.4s, v16.4h, v3.h[0] + + subs x10, x10, #1 + bne LoopIc + + LoopIcEnd: + smlal2 v26.4s, v16.8h, v2.h[1] + smlal2 v27.4s, v16.8h, v3.h[1] + smlal v26.4s, v17.4h, v2.h[2] + smlal v27.4s, v17.4h, v3.h[2] + smlal2 v26.4s, v17.8h, v2.h[3] + smlal2 v27.4s, v17.8h, v3.h[3] + + smlal v24.4s, v18.4h, v0.h[4] + smlal v25.4s, v18.4h, v1.h[4] + smlal2 v24.4s, v18.8h, v0.h[5] + smlal2 v25.4s, v18.8h, v1.h[5] + smlal v24.4s, v19.4h, v0.h[6] + smlal v25.4s, v19.4h, v1.h[6] + smlal2 v24.4s, v19.8h, v0.h[7] + smlal2 v25.4s, v19.8h, v1.h[7] + // load input + ld1 {v4.8h, v5.8h}, [x8], #32 + smlal v26.4s, v18.4h, v2.h[4] + smlal v27.4s, v18.4h, v3.h[4] + smlal2 v26.4s, v18.8h, v2.h[5] + st1 {v24.4s}, [x9], x6 + smlal2 v27.4s, v18.8h, v3.h[5] + smlal v26.4s, v19.4h, v2.h[6] + st1 {v25.4s}, [x9], x6 + smlal v27.4s, v19.4h, v3.h[6] + smlal2 v26.4s, v19.8h, v2.h[7] + smlal2 v27.4s, v19.8h, v3.h[7] + + // load input + ld1 {v6.8h, v7.8h}, [x8], #32 + smlal v28.4s, v16.4h, v4.h[0] + smlal v29.4s, v16.4h, v5.h[0] + smlal2 v28.4s, v16.8h, v4.h[1] + smlal2 v29.4s, v16.8h, v5.h[1] + smlal v28.4s, v17.4h, v4.h[2] + st1 {v26.4s}, [x9], x6 + smlal v29.4s, v17.4h, v5.h[2] + smlal2 v28.4s, v17.8h, v4.h[3] + smlal2 v29.4s, v17.8h, v5.h[3] + st1 {v27.4s}, [x9], x6 + smlal v30.4s, v16.4h, v6.h[0] + smlal v31.4s, v16.4h, v7.h[0] + smlal2 v30.4s, v16.8h, v6.h[1] + smlal2 v31.4s, v16.8h, v7.h[1] + smlal v30.4s, v17.4h, v6.h[2] + smlal v31.4s, v17.4h, v7.h[2] + smlal2 v30.4s, v17.8h, v6.h[3] + smlal2 v31.4s, v17.8h, v7.h[3] + smlal v28.4s, v18.4h, v4.h[4] + smlal v29.4s, v18.4h, v5.h[4] + smlal2 v28.4s, v18.8h, v4.h[5] + smlal2 v29.4s, v18.8h, v5.h[5] + smlal v28.4s, v19.4h, v4.h[6] + smlal v29.4s, v19.4h, v5.h[6] + smlal2 v28.4s, v19.8h, v4.h[7] + smlal2 v29.4s, v19.8h, v5.h[7] + smlal v30.4s, v18.4h, v6.h[4] + smlal v31.4s, v18.4h, v7.h[4] + st1 {v28.4s}, [x9], x6 + smlal2 v30.4s, v18.8h, v6.h[5] + smlal2 v31.4s, v18.8h, v7.h[5] + smlal v30.4s, v19.4h, v6.h[6] + st1 {v29.4s}, [x9], x6 + smlal v31.4s, v19.4h, v7.h[6] + smlal2 v30.4s, v19.8h, v6.h[7] + smlal2 v31.4s, v19.8h, v7.h[7] + + st1 {v30.4s}, [x9], x6 + st1 {v31.4s}, [x9] + + subs x7, x7, #1 + add x0, x0, #16 + bne LoopKsize + + subs x5, x5, #1 + bne LoopOc + + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/MatVecMulFp32.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/MatVecMulFp32.S new file mode 100644 index 0000000000000000000000000000000000000000..36c8d8ecfa36f676bec907e6c72963f3d57ebc8a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/MatVecMulFp32.S @@ -0,0 +1,252 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void MatVecMulFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col) +// x0: a +// x1: b +// x2: c +// x3: bias +// w4: act_type +// w5: depth +// w6: col + +asm_default_function MatVecMulFp32 + sub sp, sp, #128 + st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp] + add x9, sp, #64 + st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x9] + + mov w14, #4 // sizeof(float) + mul w8, w14, w5 // rhs depthx1 block stride + mov w14, #4 + mul w13, w8, w14 // rhs depthx4 block stride + +Loop: + mov x15, x0 // reload a ptr + mov x7, x1 // reload b ptr + mov w9, w5 // reload depth + cmp w6, #4 + blt Loop1x1 + +Loop1x4: + dup v10.8h, wzr + dup v11.8h, wzr + dup v12.8h, wzr + dup v13.8h, wzr + dup v14.8h, wzr + + add x10, x7, x8 + add x11, x10, x8 + add x12, x11, x8 + +Depth8_1x4: + cmp w9, #8 + blt Depth4_1x4 + sub w9, w9, #8 + ld1 {v0.4s, v1.4s}, [x15], #32 + ld1 {v2.4s, v3.4s}, [x7], #32 + ld1 {v4.4s, v5.4s}, [x10], #32 + cmp w9, #8 + blt Depth8_1x4_Loop_End + +Depth8_1x4_Loop: + fmla v10.4s, v0.4s, v2.4s + fmla v10.4s, v1.4s, v3.4s + ld1 {v6.4s, v7.4s}, [x11], #32 + fmla v11.4s, v0.4s, v4.4s + fmla v11.4s, v1.4s, v5.4s + ld1 {v8.4s, v9.4s}, [x12], #32 + fmla v12.4s, v0.4s, v6.4s + fmla v12.4s, v1.4s, v7.4s + ld1 {v2.4s, v3.4s}, [x7], #32 + fmla v13.4s, v0.4s, v8.4s + fmla v13.4s, v1.4s, v9.4s + ld1 {v0.4s, v1.4s}, [x15], #32 + ld1 {v4.4s, v5.4s}, [x10], #32 + sub w9, w9, #8 + cmp w9, #8 + bge Depth8_1x4_Loop + +Depth8_1x4_Loop_End: + fmla v10.4s, v0.4s, v2.4s + fmla v10.4s, v1.4s, v3.4s + ld1 {v6.4s, v7.4s}, [x11], #32 + fmla v11.4s, v0.4s, v4.4s + fmla v11.4s, v1.4s, v5.4s + ld1 {v8.4s, v9.4s}, [x12], #32 + fmla v12.4s, v0.4s, v6.4s + fmla v12.4s, v1.4s, v7.4s + fmla v13.4s, v0.4s, v8.4s + fmla v13.4s, v1.4s, v9.4s + +Depth4_1x4: + cmp w9, #4 + blt Depth1_1x4 + sub w9, w9, #4 + ld1 {v0.4s}, [x15], #16 + ld1 {v1.4s}, [x7], #16 + ld1 {v2.4s}, [x10], #16 + cmp w9, #4 + blt Depth4_1x4_Loop_End + +Depth4_1x4_Loop: + fmla v10.4s, v1.4s, v0.4s + ld1 {v3.4s}, [x11], #16 + fmla v11.4s, v2.4s, v0.4s + ld1 {v4.4s}, [x12], #16 + fmla v12.4s, v3.4s, v0.4s + ld1 {v1.4s}, [x7], #16 + fmla v13.4s, v4.4s, v0.4s + ld1 {v0.4s}, [x15], #16 + ld1 {v2.4s}, [x10], #16 + sub w9, w9, #4 + cmp w9, #4 + bge Depth4_1x4_Loop + +Depth4_1x4_Loop_End: + fmla v10.4s, v1.4s, v0.4s + ld1 {v3.4s}, [x11], #16 + fmla v11.4s, v2.4s, v0.4s + ld1 {v4.4s}, [x12], #16 + fmla v12.4s, v3.4s, v0.4s + fmla v13.4s, v4.4s, v0.4s + +Depth1_1x4: + cmp w9, #0 + beq End1x4 + ld1 {v0.s}[0], [x15], #4 + ld1 {v1.s}[0], [x7], #4 + ld1 {v1.s}[1], [x10], #4 + ld1 {v1.s}[2], [x11], #4 + ld1 {v1.s}[3], [x12], #4 + + fmla v14.4s, v1.4s, v0.s[0] + sub w9, w9, #1 + cbz w9, End1x4 + b Depth1_1x4 + +End1x4: + faddp v15.4s, v10.4s, v11.4s + faddp v16.4s, v12.4s, v13.4s + faddp v17.4s, v15.4s, v16.4s + fadd v14.4s, v14.4s, v17.4s + + cbz x3, Act1x4 + ld1 {v15.4s}, [x3], #16 + fadd v14.4s, v14.4s, v15.4s // add bias + +Act1x4: + cmp w4, #3 + beq Relu6_1x4 + cmp w4, #1 + beq Relu1x4 + b Write1x4 + +Relu6_1x4: + movi v15.4s, #0x46, lsl #8 + fmin v14.4s, v14.4s, v15.4s + +Relu1x4: + dup v15.4s, wzr + fmax v14.4s, v14.4s, v15.4s + +Write1x4: + st1 {v14.4s}, [x2], #16 + sub w6, w6, #4 + cbz w6, End + add x1, x1, x13 + b Loop + + +Loop1x1: + dup v4.4s, wzr + dup v5.4s, wzr + +Depth8_1x1: + cmp w9, #8 + blt Depth4_1x1 + + ld1 {v0.4s, v1.4s}, [x15], #32 + ld1 {v2.4s, v3.4s}, [x7], #32 + + fmla v4.4s, v2.4s, v0.4s + fmla v4.4s, v3.4s, v1.4s + sub w9, w9, #8 + cbz w9, End1x1 + b Depth8_1x1 + +Depth4_1x1: + cmp w9, #4 + blt Depth1_1x1 + + ld1 {v0.4s}, [x15], #16 + ld1 {v1.4s}, [x7], #16 + + fmla v4.4s, v1.4s, v0.4s + sub w9, w9, #4 + cbz w9, End1x1 + b Depth8_1x1 + +Depth1_1x1: + ld1 {v0.s}[0], [x15], #4 + ld1 {v1.s}[0], [x7], #4 + + fmla v5.4s, v1.4s, v0.s[0] + sub w9, w9, #1 + cbz w9, End1x1 + b Depth1_1x1 + +End1x1: + faddp v6.4s, v4.4s, v4.4s + faddp v7.4s, v6.4s, v6.4s + fadd v7.4s, v7.4s, v5.4s + + cbz x3, Act1x1 + ld1 {v8.s}[0], [x3], #4 + fadd v7.4s, v7.4s, v8.4s // add bias + +Act1x1: + cmp w4, #3 + beq Relu6_1x1 + cmp w4, #1 + beq Relu1x1 + b Write1x1 + +Relu6_1x1: + movi v8.4s, #0x46, lsl #8 + fmin v7.4s, v7.4s, v8.4s + +Relu1x1: + dup v8.4s, wzr + fmax v7.4s, v7.4s, v8.4s + +Write1x1: + st1 {v7.s}[0], [x2], #4 + sub w6, w6, #1 + cbz w6, End + add x1, x1, x8 + b Loop + +End: + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp], #64 + ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [sp], #64 + ret +#endif \ No newline at end of file diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/MatVecMulPackFp32.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/MatVecMulPackFp32.S new file mode 100644 index 0000000000000000000000000000000000000000..b013f48ab176ff5dcdf1b5a50129fbf779092058 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/MatVecMulPackFp32.S @@ -0,0 +1,198 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void MatVecMulPackFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col) +// x0: a +// x1: b +// x2: c +// x3: bias +// w4: act_type +// w5: depth +// w6: col + +asm_default_function MatVecMulPackFp32 + sub sp, sp, #16 + stp x29, x30, [sp] + + dup v1.2d, xzr + mov w7, #6 + dup v2.4s, w7 + scvtf v2.4s, v2.4s + subs w6, w6, #8 + blt Loop1xNStart + Loop1x8Start: + bl Compute1x8Unit + st1 {v24.4s, v25.4s}, [x2], #32 + subs w6, w6, #8 + bge Loop1x8Start + + Loop1xNStart: + add w6, w6, #8 + cbz w6, End + subs w6, w6, #4 + ble Loop1x4Start + bl Compute1x8Unit + st1 {v24.4s}, [x2], #16 + st1 {v25.s}[0], [x2], #4 + cmp w6, #1 + beq End + st1 {v25.s}[1], [x2], #4 + cmp w6, #2 + beq End + st1 {v25.s}[2], [x2] + b End + + Loop1x4Start: + add w6, w6, #4 + cbz w6, End + bl Compute1x4Unit + st1 {v24.s}[0], [x2], #4 + cmp w6, #1 + beq End + st1 {v24.s}[1], [x2], #4 + cmp w6, #2 + beq End + st1 {v24.s}[2], [x2], #4 + cmp w6, #3 + beq End + st1 {v24.s}[3], [x2], #4 + b End + + Compute1x8Unit: + mov x7, x0 // reload a-ptr + mov w8, w5 // reset depth + dup v24.2d, xzr + dup v25.2d, xzr + dup v26.2d, xzr + dup v27.2d, xzr + dup v28.2d, xzr + dup v29.2d, xzr + dup v30.2d, xzr + dup v31.2d, xzr + cbz x3, Compute1x8Enter + ld1 {v24.4s, v25.4s}, [x3], #32 + Compute1x8Enter: + subs w8, w8, #4 + blt Compute1x8Tail + Compute1x8: + ld1 {v0.4s}, [x7], #16 + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x1], #64 + ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x1], #64 + fmla v24.4s, v16.4s, v0.s[0] + fmla v25.4s, v17.4s, v0.s[0] + fmla v26.4s, v18.4s, v0.s[1] + fmla v27.4s, v19.4s, v0.s[1] + fmla v28.4s, v20.4s, v0.s[2] + fmla v29.4s, v21.4s, v0.s[2] + fmla v30.4s, v22.4s, v0.s[3] + fmla v31.4s, v23.4s, v0.s[3] + subs w8, w8, #4 + bge Compute1x8 + Compute1x8Tail: + add w8, w8, #4 + cbz w8, Compute1x8UnionTail + Compute1x8DepthTail: + ld1 {v0.s}[0], [x7], #4 + ld1 {v16.4s, v17.4s}, [x1], #32 + fmla v24.4s, v16.4s, v0.s[0] + fmla v25.4s, v17.4s, v0.s[0] + subs w8, w8, #1 + bgt Compute1x8DepthTail + Compute1x8UnionTail: + fadd v24.4s, v24.4s, v26.4s + fadd v25.4s, v25.4s, v27.4s + fadd v28.4s, v28.4s, v30.4s + fadd v29.4s, v29.4s, v31.4s + fadd v24.4s, v24.4s, v28.4s + fadd v25.4s, v25.4s, v29.4s + Act1x8: + cmp x4, #3 + beq Relu61x8 + cmp x4, #1 + beq Relu1x8 + b Return1x8 + Relu61x8: + fmin v24.4s, v24.4s, v2.4s + fmin v25.4s, v25.4s, v2.4s + fmax v24.4s, v24.4s, v1.4s + fmax v25.4s, v25.4s, v1.4s + b Return1x8 + Relu1x8: + fmax v24.4s, v24.4s, v1.4s + fmax v25.4s, v25.4s, v1.4s + Return1x8: + ret + + Compute1x4Unit: + mov x7, x0 // reload a-ptr + mov w8, w5 // reset depth + dup v24.2d, xzr + dup v26.2d, xzr + dup v28.2d, xzr + dup v30.2d, xzr + cbz x3, Compute1x4Enter + ld1 {v24.4s}, [x3] + Compute1x4Enter: + subs w8, w8, #4 + blt Compute1x4Tail + Compute1x4: + ld1 {v0.4s}, [x7], #16 + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x1], #64 + ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x1], #64 + fmla v24.4s, v16.4s, v0.s[0] + fmla v26.4s, v18.4s, v0.s[1] + fmla v28.4s, v20.4s, v0.s[2] + fmla v30.4s, v22.4s, v0.s[3] + subs w8, w8, #4 + bge Compute1x4 + Compute1x4Tail: + add w8, w8, #4 + cbz w8, Compute1x4UnionTail + Compute1x4DepthTail: + ld1 {v0.s}[0], [x7], #4 + ld1 {v16.4s}, [x1] + add x1, x1, #32 + fmla v24.4s, v16.4s, v0.s[0] + subs w8, w8, #1 + bgt Compute1x4DepthTail + Compute1x4UnionTail: + fadd v24.4s, v24.4s, v26.4s + fadd v28.4s, v28.4s, v30.4s + fadd v24.4s, v24.4s, v28.4s + Act1x4: + cmp x4, #3 + beq Relu61x4 + cmp x4, #1 + beq Relu1x4 + b Return1x4 + Relu61x4: + fmin v24.4s, v24.4s, v2.4s + fmax v24.4s, v24.4s, v1.4s + b Return1x8 + Relu1x4: + fmax v24.4s, v24.4s, v1.4s + Return1x4: + ret + + End: + ldp x29, x30, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/MatmulFp32.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/MatmulFp32.S new file mode 100644 index 0000000000000000000000000000000000000000..2dedccd0cd0a2b0c2374dd8a4589c69b2c6c03bb --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/MatmulFp32.S @@ -0,0 +1,787 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth +// int row, int col, size_t stride, size_t writeNhwc, size_t WriteWino) +// x0: a +// x1: b +// x2: c +// x3: bias +// w4: act_type +// w5: depth +// w6: row +// w7: col +// w17: stride +// w13: c8_nhwc_c4 + +asm_function MatmulFloatNeon64 + sub sp, sp, #144 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + + ldr x9, [sp, #152] + ldr x14, [sp, #160] + + mov w19, #32 // sizeof(float) * 8 + mul w15, w5, w19 // block stride of lhs/rhs: sizeof(float) * 8 * depth + mov x19, #4 + ldr x17, [sp, #144] + cbz x14, NoWinoSteps + mul x8, x7, x17 + mov x11, #8 + mul x11, x11, x17 + mul x8, x8, x19 + mul x11, x11, x19 +NoWinoSteps: + mul x17, x17, x19 + +L1: + mov w10, w6 // reload lhs row + mov x12, x0 // reload lhs ptr + mov x19, x2 // reload dst ptr + +L2: + mov x16, x1 // reload rhs ptr + mov w13, w5 // reload depth + dup v8.4s, wzr + dup v9.4s, wzr + dup v10.4s, wzr + dup v11.4s, wzr + dup v12.4s, wzr + dup v13.4s, wzr + dup v14.4s, wzr + dup v15.4s, wzr + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr + +LoopStart: + ld1 {v0.4s, v1.4s, v2.4s}, [x12], #48 + ld1 {v3.4s, v4.4s}, [x16], #32 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + + subs w13, w13, #1 + beq LoopEnd + +Loop: + ld1 {v0.4s}, [x12], #16 + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + ld1 {v1.4s}, [x12], #16 + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + ld1 {v3.4s}, [x16], #16 + fmla v25.4s, v4.4s, v2.s[0] + fmla v27.4s, v4.4s, v2.s[1] + fmla v29.4s, v4.4s, v2.s[2] + fmla v31.4s, v4.4s, v2.s[3] + ld1 {v4.4s}, [x16], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + ld1 {v2.4s}, [x12], #16 + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + + subs w13, w13, #1 + bgt Loop + +LoopEnd: + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + fmla v25.4s, v4.4s, v2.s[0] + fmla v27.4s, v4.4s, v2.s[1] + fmla v29.4s, v4.4s, v2.s[2] + fmla v31.4s, v4.4s, v2.s[3] + +Bias: + cbz x3, Activation + ld1 {v0.4s}, [x3], #16 + ld1 {v1.4s}, [x3] + sub x3, x3, #16 + fadd v8.4s, v8.4s, v0.4s + fadd v9.4s, v9.4s, v1.4s + fadd v10.4s, v10.4s, v0.4s + fadd v11.4s, v11.4s, v1.4s + fadd v12.4s, v12.4s, v0.4s + fadd v13.4s, v13.4s, v1.4s + fadd v14.4s, v14.4s, v0.4s + fadd v15.4s, v15.4s, v1.4s + fadd v16.4s, v16.4s, v0.4s + fadd v17.4s, v17.4s, v1.4s + fadd v18.4s, v18.4s, v0.4s + fadd v19.4s, v19.4s, v1.4s + fadd v20.4s, v20.4s, v0.4s + fadd v21.4s, v21.4s, v1.4s + fadd v22.4s, v22.4s, v0.4s + fadd v23.4s, v23.4s, v1.4s + fadd v24.4s, v24.4s, v0.4s + fadd v25.4s, v25.4s, v1.4s + fadd v26.4s, v26.4s, v0.4s + fadd v27.4s, v27.4s, v1.4s + fadd v28.4s, v28.4s, v0.4s + fadd v29.4s, v29.4s, v1.4s + fadd v30.4s, v30.4s, v0.4s + fadd v31.4s, v31.4s, v1.4s + +Activation: + cmp w4, #3 + beq Relu6 + cmp w4, #1 + beq Relu + b Write + +Relu6: + mov w13, #6 + dup v2.4s, w13 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v17.4s, v17.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v19.4s, v19.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v21.4s, v21.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + fmin v23.4s, v23.4s, v2.4s + fmin v24.4s, v24.4s, v2.4s + fmin v25.4s, v25.4s, v2.4s + fmin v26.4s, v26.4s, v2.4s + fmin v27.4s, v27.4s, v2.4s + fmin v28.4s, v28.4s, v2.4s + fmin v29.4s, v29.4s, v2.4s + fmin v30.4s, v30.4s, v2.4s + fmin v31.4s, v31.4s, v2.4s + +Relu: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v9.4s, v9.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v17.4s, v17.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v19.4s, v19.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v21.4s, v21.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + fmax v23.4s, v23.4s, v3.4s + fmax v24.4s, v24.4s, v3.4s + fmax v25.4s, v25.4s, v3.4s + fmax v26.4s, v26.4s, v3.4s + fmax v27.4s, v27.4s, v3.4s + fmax v28.4s, v28.4s, v3.4s + fmax v29.4s, v29.4s, v3.4s + fmax v30.4s, v30.4s, v3.4s + fmax v31.4s, v31.4s, v3.4s + +Write: + cbnz x14, WriteWino + cbz x9, WriteC8 + cmp w7, #1 + beq Write1 + cmp w7, #2 + beq Write2 + cmp w7, #3 + beq Write3 + cmp w7, #4 + beq Write4 + cmp w7, #5 + beq Write5 + cmp w7, #6 + beq Write6 + cmp w7, #7 + beq Write7 + b Write8 + +Write1: + str s8, [x19] + cmp w10, #1 + beq WriteEnd + add x19, x19, x17 + str s10, [x19] + cmp w10, #2 + beq WriteEnd + add x19, x19, x17 + str s12, [x19] + cmp w10, #3 + beq WriteEnd + add x19, x19, x17 + str s14, [x19] + cmp w10, #4 + beq WriteEnd + add x19, x19, x17 + str s16, [x19] + cmp w10, #5 + beq WriteEnd + add x19, x19, x17 + str s18, [x19] + cmp w10, #6 + beq WriteEnd + add x19, x19, x17 + str s20, [x19] + cmp w10, #7 + beq WriteEnd + add x19, x19, x17 + str s22, [x19] + cmp w10, #8 + beq WriteEnd + add x19, x19, x17 + str s24, [x19] + cmp w10, #9 + beq WriteEnd + add x19, x19, x17 + str s26, [x19] + cmp w10, #10 + beq WriteEnd + add x19, x19, x17 + str s28, [x19] + cmp w10, #11 + beq WriteEnd + add x19, x19, x17 + str s30, [x19] + add x19, x19, x17 + b WriteEnd +Write2: + dup s9, v8.s[1] + stp s8, s9, [x19] + cmp w10, #1 + beq WriteEnd + add x19, x19, x17 + dup s11, v10.s[1] + stp s10, s11, [x19] + cmp w10, #2 + beq WriteEnd + add x19, x19, x17 + dup s13, v12.s[1] + stp s12, s13, [x19] + cmp w10, #3 + beq WriteEnd + add x19, x19, x17 + dup s15, v14.s[1] + stp s14, s15, [x19] + cmp w10, #4 + beq WriteEnd + add x19, x19, x17 + dup s17, v16.s[1] + stp s16, s17, [x19] + cmp w10, #5 + beq WriteEnd + add x19, x19, x17 + dup s19, v18.s[1] + stp s18, s19, [x19] + cmp w10, #6 + beq WriteEnd + add x19, x19, x17 + dup s21, v20.s[1] + stp s20, s21, [x19] + cmp w10, #7 + beq WriteEnd + add x19, x19, x17 + dup s23, v22.s[1] + stp s22, s23, [x19] + cmp w10, #8 + beq WriteEnd + add x19, x19, x17 + dup s25, v24.s[1] + stp s24, s25, [x19] + cmp w10, #9 + beq WriteEnd + add x19, x19, x17 + dup s27, v26.s[1] + stp s26, s27, [x19] + cmp w10, #10 + beq WriteEnd + add x19, x19, x17 + dup s29, v28.s[1] + stp s28, s29, [x19] + cmp w10, #11 + beq WriteEnd + add x19, x19, x17 + dup s31, v30.s[1] + stp s30, s31, [x19] + add x19, x19, x17 + b WriteEnd +Write3: + add x13, x19, #8 + dup s9, v8.s[1] + stp s8, s9, [x19] + add x19, x19, x17 + st1 {v8.s}[2], [x13], x17 + cmp w10, #1 + beq WriteEnd + dup s11, v10.s[1] + stp s10, s11, [x19] + add x19, x19, x17 + st1 {v10.s}[2], [x13], x17 + cmp w10, #2 + beq WriteEnd + dup s13, v12.s[1] + stp s12, s13, [x19] + add x19, x19, x17 + st1 {v12.s}[2], [x13], x17 + cmp w10, #3 + beq WriteEnd + dup s15, v14.s[1] + stp s14, s15, [x19] + add x19, x19, x17 + st1 {v14.s}[2], [x13], x17 + cmp w10, #4 + beq WriteEnd + dup s17, v16.s[1] + stp s16, s17, [x19] + add x19, x19, x17 + st1 {v16.s}[2], [x13], x17 + cmp w10, #5 + beq WriteEnd + dup s19, v18.s[1] + stp s18, s19, [x19] + add x19, x19, x17 + st1 {v18.s}[2], [x13], x17 + cmp w10, #6 + beq WriteEnd + dup s21, v20.s[1] + stp s20, s21, [x19] + add x19, x19, x17 + st1 {v20.s}[2], [x13], x17 + cmp w10, #7 + beq WriteEnd + dup s23, v22.s[1] + stp s22, s23, [x19] + add x19, x19, x17 + st1 {v22.s}[2], [x13], x17 + cmp w10, #8 + beq WriteEnd + dup s25, v24.s[1] + stp s24, s25, [x19] + add x19, x19, x17 + st1 {v24.s}[2], [x13], x17 + cmp w10, #9 + beq WriteEnd + dup s27, v26.s[1] + stp s26, s27, [x19] + add x19, x19, x17 + st1 {v26.s}[2], [x13], x17 + cmp w10, #10 + beq WriteEnd + dup s29, v28.s[1] + stp s28, s29, [x19] + add x19, x19, x17 + st1 {v28.s}[2], [x13], x17 + cmp w10, #11 + beq WriteEnd + dup s31, v30.s[1] + stp s30, s31, [x19] + add x19, x19, x17 + st1 {v30.s}[2], [x13] + b WriteEnd +Write4: + st1 {v8.4s}, [x19], x17 + cmp w10, #1 + beq WriteEnd + st1 {v10.4s}, [x19], x17 + cmp w10, #2 + beq WriteEnd + st1 {v12.4s}, [x19], x17 + cmp w10, #3 + beq WriteEnd + st1 {v14.4s}, [x19], x17 + cmp w10, #4 + beq WriteEnd + st1 {v16.4s}, [x19], x17 + cmp w10, #5 + beq WriteEnd + st1 {v18.4s}, [x19], x17 + cmp w10, #6 + beq WriteEnd + st1 {v20.4s}, [x19], x17 + cmp w10, #7 + beq WriteEnd + st1 {v22.4s}, [x19], x17 + cmp w10, #8 + beq WriteEnd + st1 {v24.4s}, [x19], x17 + cmp w10, #9 + beq WriteEnd + st1 {v26.4s}, [x19], x17 + cmp w10, #10 + beq WriteEnd + st1 {v28.4s}, [x19], x17 + cmp w10, #11 + beq WriteEnd + st1 {v30.4s}, [x19], x17 + b WriteEnd +Write5: + add x13, x19, #16 + st1 {v8.4s}, [x19], x17 + str s9, [x13] + cmp w10, #1 + beq WriteEnd + add x13, x13, x17 + st1 {v10.4s}, [x19], x17 + str s11, [x13] + cmp w10, #2 + beq WriteEnd + add x13, x13, x17 + st1 {v12.4s}, [x19], x17 + str s13, [x13] + cmp w10, #3 + beq WriteEnd + add x13, x13, x17 + st1 {v14.4s}, [x19], x17 + str s15, [x13] + cmp w10, #4 + beq WriteEnd + add x13, x13, x17 + st1 {v16.4s}, [x19], x17 + str s17, [x13] + cmp w10, #5 + beq WriteEnd + add x13, x13, x17 + st1 {v18.4s}, [x19], x17 + str s19, [x13] + cmp w10, #6 + beq WriteEnd + add x13, x13, x17 + st1 {v20.4s}, [x19], x17 + str s21, [x13] + cmp w10, #7 + beq WriteEnd + add x13, x13, x17 + st1 {v22.4s}, [x19], x17 + str s23, [x13] + cmp w10, #8 + beq WriteEnd + add x13, x13, x17 + st1 {v24.4s}, [x19], x17 + str s25, [x13] + cmp w10, #9 + beq WriteEnd + add x13, x13, x17 + st1 {v26.4s}, [x19], x17 + str s27, [x13] + cmp w10, #10 + beq WriteEnd + add x13, x13, x17 + st1 {v28.4s}, [x19], x17 + str s29, [x13] + cmp w10, #11 + beq WriteEnd + add x13, x13, x17 + st1 {v30.4s}, [x19], x17 + str s31, [x13] + b WriteEnd +Write6: + add x13, x19, #16 + st1 {v8.4s}, [x19], x17 + dup s8, v9.s[1] + stp s9, s8, [x13] + cmp w10, #1 + beq WriteEnd + add x13, x13, x17 + st1 {v10.4s}, [x19], x17 + dup s10, v11.s[1] + stp s11, s10, [x13] + cmp w10, #2 + beq WriteEnd + add x13, x13, x17 + st1 {v12.4s}, [x19], x17 + dup s12, v13.s[1] + stp s13, s12, [x13] + cmp w10, #3 + beq WriteEnd + add x13, x13, x17 + st1 {v14.4s}, [x19], x17 + dup s14, v15.s[1] + stp s15, s14, [x13] + cmp w10, #4 + beq WriteEnd + add x13, x13, x17 + st1 {v16.4s}, [x19], x17 + dup s16, v17.s[1] + stp s17, s16, [x13] + cmp w10, #5 + beq WriteEnd + add x13, x13, x17 + st1 {v18.4s}, [x19], x17 + dup s18, v19.s[1] + stp s19, s18, [x13] + cmp w10, #6 + beq WriteEnd + add x13, x13, x17 + st1 {v20.4s}, [x19], x17 + dup s20, v21.s[1] + stp s21, s20, [x13] + cmp w10, #7 + beq WriteEnd + add x13, x13, x17 + st1 {v22.4s}, [x19], x17 + dup s22, v23.s[1] + stp s23, s22, [x13] + cmp w10, #8 + beq WriteEnd + add x13, x13, x17 + st1 {v24.4s}, [x19], x17 + dup s24, v25.s[1] + stp s25, s24, [x13] + cmp w10, #9 + beq WriteEnd + add x13, x13, x17 + st1 {v26.4s}, [x19], x17 + dup s26, v27.s[1] + stp s27, s26, [x13] + cmp w10, #10 + beq WriteEnd + add x13, x13, x17 + st1 {v28.4s}, [x19], x17 + dup s28, v29.s[1] + stp s29, s28, [x13] + cmp w10, #11 + beq WriteEnd + add x13, x13, x17 + st1 {v30.4s}, [x19], x17 + dup s30, v31.s[1] + stp s31, s30, [x13] + b WriteEnd +Write7: + add x13, x19, #16 + add x16, x19, #24 + st1 {v8.4s}, [x19], x17 + dup s8, v9.s[1] + stp s9, s8, [x13] + add x13, x13, x17 + st1 {v9.s}[2], [x16], x17 + cmp w10, #1 + beq WriteEnd + st1 {v10.4s}, [x19], x17 + dup s10, v11.s[1] + stp s11, s10, [x13] + add x13, x13, x17 + st1 {v11.s}[2], [x16], x17 + cmp w10, #2 + beq WriteEnd + st1 {v12.4s}, [x19], x17 + dup s12, v13.s[1] + stp s13, s12, [x13] + add x13, x13, x17 + st1 {v13.s}[2], [x16], x17 + cmp w10, #3 + beq WriteEnd + st1 {v14.4s}, [x19], x17 + dup s14, v15.s[1] + stp s15, s14, [x13] + add x13, x13, x17 + st1 {v15.s}[2], [x16], x17 + cmp w10, #4 + beq WriteEnd + st1 {v16.4s}, [x19], x17 + dup s16, v17.s[1] + stp s17, s16, [x13] + add x13, x13, x17 + st1 {v17.s}[2], [x16], x17 + cmp w10, #5 + beq WriteEnd + st1 {v18.4s}, [x19], x17 + dup s18, v19.s[1] + stp s19, s18, [x13] + add x13, x13, x17 + st1 {v19.s}[2], [x16], x17 + cmp w10, #6 + beq WriteEnd + st1 {v20.4s}, [x19], x17 + dup s20, v21.s[1] + stp s21, s20, [x13] + add x13, x13, x17 + st1 {v21.s}[2], [x16], x17 + cmp w10, #7 + beq WriteEnd + st1 {v22.4s}, [x19], x17 + dup s22, v23.s[1] + stp s23, s22, [x13] + add x13, x13, x17 + st1 {v23.s}[2], [x16], x17 + cmp w10, #8 + beq WriteEnd + st1 {v24.4s}, [x19], x17 + dup s24, v25.s[1] + stp s25, s24, [x13] + add x13, x13, x17 + st1 {v25.s}[2], [x16], x17 + cmp w10, #9 + beq WriteEnd + st1 {v26.4s}, [x19], x17 + dup s26, v27.s[1] + stp s27, s26, [x13] + add x13, x13, x17 + st1 {v27.s}[2], [x16], x17 + cmp w10, #10 + beq WriteEnd + st1 {v28.4s}, [x19], x17 + dup s28, v29.s[1] + stp s29, s28, [x13] + add x13, x13, x17 + st1 {v29.s}[2], [x16], x17 + cmp w10, #11 + beq WriteEnd + st1 {v30.4s}, [x19], x17 + dup s30, v31.s[1] + stp s31, s30, [x13] + add x13, x13, x17 + st1 {v31.s}[2], [x16], x17 + b WriteEnd +WriteC8: + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x2], #64 + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x2], #64 + st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x2], #64 + st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 + b WriteEnd +WriteWino: + st1 {v8.4s, v9.4s}, [x19], x8 + st1 {v10.4s, v11.4s}, [x19], x8 + st1 {v12.4s, v13.4s}, [x19], x8 + st1 {v14.4s, v15.4s}, [x19], x8 + st1 {v16.4s, v17.4s}, [x19], x8 + st1 {v18.4s, v19.4s}, [x19], x8 + st1 {v20.4s, v21.4s}, [x19], x8 + st1 {v22.4s, v23.4s}, [x19], x8 + st1 {v24.4s, v25.4s}, [x19], x8 + st1 {v26.4s, v27.4s}, [x19], x8 + st1 {v28.4s, v29.4s}, [x19], x8 + st1 {v30.4s, v31.4s}, [x19], x8 + b WriteEnd +Write8: + st1 {v8.4s, v9.4s}, [x19], x17 + cmp w10, #1 + beq WriteEnd + st1 {v10.4s, v11.4s}, [x19], x17 + cmp w10, #2 + beq WriteEnd + st1 {v12.4s, v13.4s}, [x19], x17 + cmp w10, #3 + beq WriteEnd + st1 {v14.4s, v15.4s}, [x19], x17 + cmp w10, #4 + beq WriteEnd + st1 {v16.4s, v17.4s}, [x19], x17 + cmp w10, #5 + beq WriteEnd + st1 {v18.4s, v19.4s}, [x19], x17 + cmp w10, #6 + beq WriteEnd + st1 {v20.4s, v21.4s}, [x19], x17 + cmp w10, #7 + beq WriteEnd + st1 {v22.4s, v23.4s}, [x19], x17 + cmp w10, #8 + beq WriteEnd + st1 {v24.4s, v25.4s}, [x19], x17 + cmp w10, #9 + beq WriteEnd + st1 {v26.4s, v27.4s}, [x19], x17 + cmp w10, #10 + beq WriteEnd + st1 {v28.4s, v29.4s}, [x19], x17 + cmp w10, #11 + beq WriteEnd + st1 {v30.4s, v31.4s}, [x19], x17 + +WriteEnd: + subs w10, w10, #12 // lhs row - 12 + bgt L2 + +End2: + subs w7, w7, #8 // rhs col - 8 + add x1, x1, x15 // rhs ptr + stride + cbz x3, NoBiasStep + add x3, x3, #32 // bias ptr + stride +NoBiasStep: + cbnz x14, WinoDstStep + cbz x9, NoDstStep + add x2, x2, #32 // dst ptr + stride + b NoDstStep +WinoDstStep: + add x2, x2, x11 +NoDstStep: + bgt L1 + +End1: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/MatmulFp32Opt.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/MatmulFp32Opt.S new file mode 100644 index 0000000000000000000000000000000000000000..5fc3d8099007e1e15d1b3b49d1949b64a24ed7d4 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/MatmulFp32Opt.S @@ -0,0 +1,1669 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void MatmulFloatNeon64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth +// int row, int col, size_t stride, size_t writeMode) +// x0: a +// x1: b +// x2: c +// x3: bias +// x4: act_type +// x5: depth +// x6: row +// x7: col +// x8: stride +// x9: writeMode + +asm_function MatmulFloatNeon64Opt + sub sp, sp, #160 + + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + + ldr x8, [sp, #160] + ldr x9, [sp, #168] + + mov x21, #48 // sizeof(float) * 12 + mul x17, x5, x21 // block stride of lhs/rhs: sizeof(float) * 12 * depth + cmp x9, #3 // c4 + beq C4Stride + cbnz x9, NoC8Steps + mov x11, x2 + mov x21, #32 + mul x16, x6, x21 // row * 8 * sizeof(float) + b NoC8Steps +C4Stride: + mov x18, #48 // 12 * sizeof(float) + mov x22, #4 + mul x8, x8, x22 // stride * sizeof(float), in c4 stride == row + mul x8, x8, x22 // col stride + // col >= 4 , block stride 192, otherwise 12 * 4 * col + cmp x7, #4 + bge C4StrideCommon + mul x18, x18, x7 // block stride + b LoopRowStart +C4StrideCommon: + mov x18, #192 // block stride + + b LoopRowStart + +NoC8Steps: + cmp x9, #2 + bne NoWinoSteps + mov x21, #4 + mul x15, x7, x8 + mul x15, x15, x21 // kernel_size * col *sizeof(float) + mov x21, #32 + mul x16, x8, x21 // kernel_size * 8 * sizeof(float) +NoWinoSteps: + mov x21, #4 + mul x8, x8, x21 + +LoopRowStart: + cmp x9, #3 + bne RowStart + mov x20, x2 +RowStart: + cmp x6, #4 + ble LoopRow4 + cmp x6, #8 + ble LoopRow8 + +LoopRow: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol: + cbz x9, NoReloadDst + cmp x9, #3 + beq C4ReloadDst + mov x11, x2 + b NoReloadDst + C4ReloadDst: + mov x11, x20 + NoReloadDst: + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + cmp x13, #4 + ble LoopDepthStartHalf + + LoopDepthStart: + prfm pldl1strm, [x14, #632] + ld1 {v3.4s}, [x14], #16 + ld1 {v0.4s}, [x10], #16 + prfm pldl1keep, [x10, #632] + ld1 {v1.4s, v2.4s}, [x10], #32 + ld1 {v4.4s}, [x14], #16 + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + fmul v9.4s, v4.4s, v0.s[0] + fmul v11.4s, v4.4s, v0.s[1] + fmul v13.4s, v4.4s, v0.s[2] + fmul v15.4s, v4.4s, v0.s[3] + fmul v16.4s, v3.4s, v1.s[0] + fmul v18.4s, v3.4s, v1.s[1] + fmul v20.4s, v3.4s, v1.s[2] + fmul v22.4s, v3.4s, v1.s[3] + fmul v17.4s, v4.4s, v1.s[0] + fmul v19.4s, v4.4s, v1.s[1] + fmul v21.4s, v4.4s, v1.s[2] + fmul v23.4s, v4.4s, v1.s[3] + fmul v24.4s, v3.4s, v2.s[0] + fmul v26.4s, v3.4s, v2.s[1] + fmul v28.4s, v3.4s, v2.s[2] + fmul v30.4s, v3.4s, v2.s[3] + fmul v25.4s, v4.4s, v2.s[0] + fmul v27.4s, v4.4s, v2.s[1] + fmul v29.4s, v4.4s, v2.s[2] + fmul v31.4s, v4.4s, v2.s[3] + + subs x19, x19, #1 + beq Bias + + LoopDepth: + prfm pldl1strm, [x14, #632] + ld1 {v3.4s}, [x14], #16 + ld1 {v0.4s}, [x10], #16 + prfm pldl1keep, [x10, #632] + ld1 {v1.4s, v2.4s}, [x10], #32 + ld1 {v4.4s}, [x14], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + fmla v25.4s, v4.4s, v2.s[0] + fmla v27.4s, v4.4s, v2.s[1] + fmla v29.4s, v4.4s, v2.s[2] + fmla v31.4s, v4.4s, v2.s[3] + + subs x19, x19, #1 + bgt LoopDepth + + Bias: + cbz x3, Activation + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v9.4s, v9.4s, v1.4s + fadd v10.4s, v10.4s, v0.4s + fadd v11.4s, v11.4s, v1.4s + fadd v12.4s, v12.4s, v0.4s + fadd v13.4s, v13.4s, v1.4s + fadd v14.4s, v14.4s, v0.4s + fadd v15.4s, v15.4s, v1.4s + fadd v16.4s, v16.4s, v0.4s + fadd v17.4s, v17.4s, v1.4s + fadd v18.4s, v18.4s, v0.4s + fadd v19.4s, v19.4s, v1.4s + fadd v20.4s, v20.4s, v0.4s + fadd v21.4s, v21.4s, v1.4s + fadd v22.4s, v22.4s, v0.4s + fadd v23.4s, v23.4s, v1.4s + fadd v24.4s, v24.4s, v0.4s + fadd v25.4s, v25.4s, v1.4s + fadd v26.4s, v26.4s, v0.4s + fadd v27.4s, v27.4s, v1.4s + fadd v28.4s, v28.4s, v0.4s + fadd v29.4s, v29.4s, v1.4s + fadd v30.4s, v30.4s, v0.4s + fadd v31.4s, v31.4s, v1.4s + + Activation: + cmp x4, #3 + beq Relu6 + cmp x4, #1 + beq Relu + b Write + + Relu6: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v17.4s, v17.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v19.4s, v19.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v21.4s, v21.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + fmin v23.4s, v23.4s, v2.4s + fmin v24.4s, v24.4s, v2.4s + fmin v25.4s, v25.4s, v2.4s + fmin v26.4s, v26.4s, v2.4s + fmin v27.4s, v27.4s, v2.4s + fmin v28.4s, v28.4s, v2.4s + fmin v29.4s, v29.4s, v2.4s + fmin v30.4s, v30.4s, v2.4s + fmin v31.4s, v31.4s, v2.4s + + Relu: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v9.4s, v9.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v17.4s, v17.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v19.4s, v19.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v21.4s, v21.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + fmax v23.4s, v23.4s, v3.4s + fmax v24.4s, v24.4s, v3.4s + fmax v25.4s, v25.4s, v3.4s + fmax v26.4s, v26.4s, v3.4s + fmax v27.4s, v27.4s, v3.4s + fmax v28.4s, v28.4s, v3.4s + fmax v29.4s, v29.4s, v3.4s + fmax v30.4s, v30.4s, v3.4s + fmax v31.4s, v31.4s, v3.4s + b Write + + LoopDepthStartHalf: + prfm pldl1strm, [x14, #632] + ld1 {v3.4s}, [x14], #16 + ld1 {v0.4s}, [x10], #16 + prfm pldl1keep, [x10, #632] + ld1 {v1.4s, v2.4s}, [x10], #32 + add x14, x14, #16 + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + fmul v16.4s, v3.4s, v1.s[0] + fmul v18.4s, v3.4s, v1.s[1] + fmul v20.4s, v3.4s, v1.s[2] + fmul v22.4s, v3.4s, v1.s[3] + fmul v24.4s, v3.4s, v2.s[0] + fmul v26.4s, v3.4s, v2.s[1] + fmul v28.4s, v3.4s, v2.s[2] + fmul v30.4s, v3.4s, v2.s[3] + + subs x19, x19, #1 + beq BiasHalf + + LoopDepthHalf: + prfm pldl1strm, [x14, #632] + ld1 {v3.4s}, [x14], #16 + ld1 {v0.4s}, [x10], #16 + prfm pldl1keep, [x10, #632] + ld1 {v1.4s, v2.4s}, [x10], #32 + add x14, x14, #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + + subs x19, x19, #1 + bgt LoopDepthHalf + + BiasHalf: + cbz x3, ActivationHalf + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v10.4s, v10.4s, v0.4s + fadd v12.4s, v12.4s, v0.4s + fadd v14.4s, v14.4s, v0.4s + fadd v16.4s, v16.4s, v0.4s + fadd v18.4s, v18.4s, v0.4s + fadd v20.4s, v20.4s, v0.4s + fadd v22.4s, v22.4s, v0.4s + fadd v24.4s, v24.4s, v0.4s + fadd v26.4s, v26.4s, v0.4s + fadd v28.4s, v28.4s, v0.4s + fadd v30.4s, v30.4s, v0.4s + + ActivationHalf: + cmp x4, #3 + beq Relu6Half + cmp x4, #1 + beq ReluHalf + b Write + + Relu6Half: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + fmin v24.4s, v24.4s, v2.4s + fmin v26.4s, v26.4s, v2.4s + fmin v28.4s, v28.4s, v2.4s + fmin v30.4s, v30.4s, v2.4s + + ReluHalf: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + fmax v24.4s, v24.4s, v3.4s + fmax v26.4s, v26.4s, v3.4s + fmax v28.4s, v28.4s, v3.4s + fmax v30.4s, v30.4s, v3.4s + b Write + + +LoopRow8: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol8: + cbz x9, NoReloadDst8 + cmp x9, #3 + beq C4ReloadDst8 + mov x11, x2 + b NoReloadDst8 + C4ReloadDst8: + mov x11, x20 + NoReloadDst8: + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + cmp x13, #4 + ble LoopDepthStartHalf8 + + LoopDepthStart8: + prfm pldl1strm, [x14, #632] + ld1 {v3.4s}, [x14], #16 + ld1 {v0.4s}, [x10], #16 + prfm pldl1keep, [x10, #632] + ld1 {v1.4s, v2.4s}, [x10], #32 + ld1 {v4.4s}, [x14], #16 + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + fmul v9.4s, v4.4s, v0.s[0] + fmul v11.4s, v4.4s, v0.s[1] + fmul v13.4s, v4.4s, v0.s[2] + fmul v15.4s, v4.4s, v0.s[3] + fmul v16.4s, v3.4s, v1.s[0] + fmul v18.4s, v3.4s, v1.s[1] + fmul v20.4s, v3.4s, v1.s[2] + fmul v22.4s, v3.4s, v1.s[3] + fmul v17.4s, v4.4s, v1.s[0] + fmul v19.4s, v4.4s, v1.s[1] + fmul v21.4s, v4.4s, v1.s[2] + fmul v23.4s, v4.4s, v1.s[3] + + subs x19, x19, #1 + beq Bias8 + + LoopDepth8: + prfm pldl1strm, [x14, #632] + ld1 {v3.4s}, [x14], #16 + ld1 {v0.4s}, [x10], #16 + prfm pldl1keep, [x10, #632] + ld1 {v1.4s, v2.4s}, [x10], #32 + ld1 {v4.4s}, [x14], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + + subs x19, x19, #1 + bgt LoopDepth8 + + Bias8: + cbz x3, Activation8 + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v9.4s, v9.4s, v1.4s + fadd v10.4s, v10.4s, v0.4s + fadd v11.4s, v11.4s, v1.4s + fadd v12.4s, v12.4s, v0.4s + fadd v13.4s, v13.4s, v1.4s + fadd v14.4s, v14.4s, v0.4s + fadd v15.4s, v15.4s, v1.4s + fadd v16.4s, v16.4s, v0.4s + fadd v17.4s, v17.4s, v1.4s + fadd v18.4s, v18.4s, v0.4s + fadd v19.4s, v19.4s, v1.4s + fadd v20.4s, v20.4s, v0.4s + fadd v21.4s, v21.4s, v1.4s + fadd v22.4s, v22.4s, v0.4s + fadd v23.4s, v23.4s, v1.4s + + Activation8: + cmp x4, #3 + beq Relu68 + cmp x4, #1 + beq Relu8 + b Write + + Relu68: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v17.4s, v17.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v19.4s, v19.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v21.4s, v21.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + fmin v23.4s, v23.4s, v2.4s + + Relu8: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v9.4s, v9.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v17.4s, v17.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v19.4s, v19.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v21.4s, v21.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + fmax v23.4s, v23.4s, v3.4s + b Write + + LoopDepthStartHalf8: + prfm pldl1strm, [x14, #632] + ld1 {v3.4s}, [x14], #16 + ld1 {v0.4s}, [x10], #16 + prfm pldl1keep, [x10, #632] + ld1 {v1.4s, v2.4s}, [x10], #32 + add x14, x14, #16 + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + fmul v16.4s, v3.4s, v1.s[0] + fmul v18.4s, v3.4s, v1.s[1] + fmul v20.4s, v3.4s, v1.s[2] + fmul v22.4s, v3.4s, v1.s[3] + + subs x19, x19, #1 + beq BiasHalf8 + + LoopDepthHalf8: + prfm pldl1strm, [x14, #632] + ld1 {v3.4s}, [x14], #16 + ld1 {v0.4s}, [x10], #16 + prfm pldl1keep, [x10, #632] + ld1 {v1.4s, v2.4s}, [x10], #32 + add x14, x14, #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + + subs x19, x19, #1 + bgt LoopDepthHalf8 + + BiasHalf8: + cbz x3, ActivationHalf8 + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v10.4s, v10.4s, v0.4s + fadd v12.4s, v12.4s, v0.4s + fadd v14.4s, v14.4s, v0.4s + fadd v16.4s, v16.4s, v0.4s + fadd v18.4s, v18.4s, v0.4s + fadd v20.4s, v20.4s, v0.4s + fadd v22.4s, v22.4s, v0.4s + + ActivationHalf8: + cmp x4, #3 + beq Relu6Half8 + cmp x4, #1 + beq ReluHalf8 + b Write + + Relu6Half8: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + + ReluHalf8: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + b Write + +LoopRow4: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol4: + cbz x9, NoReloadDst4 + cmp x9, #3 + beq C4ReloadDst4 + mov x11, x2 + b NoReloadDst4 + C4ReloadDst4: + mov x11, x20 + NoReloadDst4: + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + cmp x13, #4 + ble LoopDepthStartHalf4 + + LoopDepthStart4: + prfm pldl1strm, [x14, #632] + ld1 {v3.4s}, [x14], #16 + ld1 {v0.4s}, [x10], #16 + prfm pldl1keep, [x10, #632] + add x10, x10, #32 + ld1 {v4.4s}, [x14], #16 + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + fmul v9.4s, v4.4s, v0.s[0] + fmul v11.4s, v4.4s, v0.s[1] + fmul v13.4s, v4.4s, v0.s[2] + fmul v15.4s, v4.4s, v0.s[3] + + subs x19, x19, #1 + beq Bias4 + + LoopDepth4: + prfm pldl1strm, [x14, #632] + ld1 {v3.4s}, [x14], #16 + ld1 {v0.4s}, [x10], #16 + prfm pldl1keep, [x10, #632] + add x10, x10, #32 + ld1 {v4.4s}, [x14], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + + subs x19, x19, #1 + bgt LoopDepth4 + + Bias4: + cbz x3, Activation4 + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v9.4s, v9.4s, v1.4s + fadd v10.4s, v10.4s, v0.4s + fadd v11.4s, v11.4s, v1.4s + fadd v12.4s, v12.4s, v0.4s + fadd v13.4s, v13.4s, v1.4s + fadd v14.4s, v14.4s, v0.4s + fadd v15.4s, v15.4s, v1.4s + + Activation4: + cmp x4, #3 + beq Relu64 + cmp x4, #1 + beq Relu4 + b Write + + Relu64: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + + Relu4: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v9.4s, v9.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + b Write + + LoopDepthStartHalf4: + prfm pldl1strm, [x14, #632] + ld1 {v3.4s}, [x14], #16 + ld1 {v0.4s}, [x10], #16 + prfm pldl1keep, [x10, #632] + add x10, x10, #32 + add x14, x14, #16 + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + + subs x19, x19, #1 + beq BiasHalf4 + + LoopDepthHalf4: + prfm pldl1strm, [x14, #632] + ld1 {v3.4s}, [x14], #16 + ld1 {v0.4s}, [x10], #16 + prfm pldl1keep, [x10, #632] + add x10, x10, #32 + add x14, x14, #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + + subs x19, x19, #1 + bgt LoopDepthHalf4 + + BiasHalf4: + cbz x3, ActivationHalf4 + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v10.4s, v10.4s, v0.4s + fadd v12.4s, v12.4s, v0.4s + fadd v14.4s, v14.4s, v0.4s + + ActivationHalf4: + cmp x4, #3 + beq Relu6Half4 + cmp x4, #1 + beq ReluHalf4 + b Write + + Relu6Half4: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + + ReluHalf4: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + + Write: + cmp x9, #2 + beq WriteWino + cmp x9, #3 + beq WriteC4 + cbz x9, WriteC8 + cmp x13, #1 + beq Write1 + cmp x13, #2 + beq Write2 + cmp x13, #3 + beq Write3 + cmp x13, #4 + beq Write4 + cmp x13, #5 + beq Write5 + cmp x13, #6 + beq Write6 + cmp x13, #7 + beq Write7 + b Write8 + + Write1: + add x2, x2, #4 + str s8, [x11] + cmp x6, #1 + beq WriteEnd + add x11, x11, x8 + str s10, [x11] + cmp x6, #2 + beq WriteEnd + add x11, x11, x8 + str s12, [x11] + cmp x6, #3 + beq WriteEnd + add x11, x11, x8 + str s14, [x11] + cmp x6, #4 + beq WriteEnd + add x11, x11, x8 + str s16, [x11] + cmp x6, #5 + beq WriteEnd + add x11, x11, x8 + str s18, [x11] + cmp x6, #6 + beq WriteEnd + add x11, x11, x8 + str s20, [x11] + cmp x6, #7 + beq WriteEnd + add x11, x11, x8 + str s22, [x11] + cmp x6, #8 + beq WriteEnd + add x11, x11, x8 + str s24, [x11] + cmp x6, #9 + beq WriteEnd + add x11, x11, x8 + str s26, [x11] + cmp x6, #10 + beq WriteEnd + add x11, x11, x8 + str s28, [x11] + cmp x6, #11 + beq WriteEnd + add x11, x11, x8 + str s30, [x11] + add x11, x11, x8 + add x11, x11, #4 + b WriteEnd + Write2: + add x2, x2, #8 + st1 {v8.2s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.2s}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.2s}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.2s}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.2s}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.2s}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.2s}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.2s}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.2s}, [x11], x8 + add x11, x11, #8 + b WriteEnd + Write3: + add x2, x2, #12 + add x19, x11, #8 + st1 {v8.2s}, [x11], x8 + st1 {v8.s}[2], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11], x8 + st1 {v10.s}[2], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11], x8 + st1 {v12.s}[2], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11], x8 + st1 {v14.s}[2], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.2s}, [x11], x8 + st1 {v16.s}[2], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.2s}, [x11], x8 + st1 {v18.s}[2], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.2s}, [x11], x8 + st1 {v20.s}[2], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.2s}, [x11], x8 + st1 {v22.s}[2], [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.2s}, [x11], x8 + st1 {v24.s}[2], [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.2s}, [x11], x8 + st1 {v26.s}[2], [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.2s}, [x11], x8 + st1 {v28.s}[2], [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.2s}, [x11], x8 + st1 {v30.s}[2], [x19] + add x11, x11, #12 + b WriteEnd + Write4: + add x2, x2, #16 + st1 {v8.4s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11], x8 + add x11, x11, #16 + b WriteEnd + Write5: + add x2, x2, #20 + add x19, x11, #16 + st1 {v8.4s}, [x11], x8 + str s9, [x19] + cmp x6, #1 + beq WriteEnd + add x19, x19, x8 + st1 {v10.4s}, [x11], x8 + str s11, [x19] + cmp x6, #2 + beq WriteEnd + add x19, x19, x8 + st1 {v12.4s}, [x11], x8 + str s13, [x19] + cmp x6, #3 + beq WriteEnd + add x19, x19, x8 + st1 {v14.4s}, [x11], x8 + str s15, [x19] + cmp x6, #4 + beq WriteEnd + add x19, x19, x8 + st1 {v16.4s}, [x11], x8 + str s17, [x19] + cmp x6, #5 + beq WriteEnd + add x19, x19, x8 + st1 {v18.4s}, [x11], x8 + str s19, [x19] + cmp x6, #6 + beq WriteEnd + add x19, x19, x8 + st1 {v20.4s}, [x11], x8 + str s21, [x19] + cmp x6, #7 + beq WriteEnd + add x19, x19, x8 + st1 {v22.4s}, [x11], x8 + str s23, [x19] + cmp x6, #8 + beq WriteEnd + add x19, x19, x8 + st1 {v24.4s}, [x11], x8 + str s25, [x19] + cmp x6, #9 + beq WriteEnd + add x19, x19, x8 + st1 {v26.4s}, [x11], x8 + str s27, [x19] + cmp x6, #10 + beq WriteEnd + add x19, x19, x8 + st1 {v28.4s}, [x11], x8 + str s29, [x19] + cmp x6, #11 + beq WriteEnd + add x19, x19, x8 + st1 {v30.4s}, [x11], x8 + str s31, [x19] + add x11, x11, #20 + b WriteEnd + Write6: + add x2, x2, #24 + add x19, x11, #16 + st1 {v8.4s}, [x11], x8 + st1 {v9.2s}, [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], x8 + st1 {v11.2s}, [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], x8 + st1 {v13.2s}, [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], x8 + st1 {v15.2s}, [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], x8 + st1 {v17.2s}, [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], x8 + st1 {v19.2s}, [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], x8 + st1 {v21.2s}, [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], x8 + st1 {v23.2s}, [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], x8 + st1 {v25.2s}, [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], x8 + st1 {v27.2s}, [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], x8 + st1 {v29.2s}, [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11], x8 + st1 {v31.2s}, [x19] + add x11, x11, #24 + b WriteEnd + Write7: + add x2, x2, #28 + add x19, x11, #16 + add x20, x11, #24 + st1 {v8.4s}, [x11], x8 + st1 {v9.2s}, [x19], x8 + st1 {v9.s}[2], [x20], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], x8 + st1 {v11.2s}, [x19], x8 + st1 {v11.s}[2], [x20], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], x8 + st1 {v13.2s}, [x19], x8 + st1 {v13.s}[2], [x20], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], x8 + st1 {v15.2s}, [x19], x8 + st1 {v15.s}[2], [x20], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], x8 + st1 {v17.2s}, [x19], x8 + st1 {v17.s}[2], [x20], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], x8 + st1 {v19.2s}, [x19], x8 + st1 {v19.s}[2], [x20], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], x8 + st1 {v21.2s}, [x19], x8 + st1 {v21.s}[2], [x20], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], x8 + st1 {v23.2s}, [x19], x8 + st1 {v23.s}[2], [x20], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], x8 + st1 {v25.2s}, [x19], x8 + st1 {v25.s}[2], [x20], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], x8 + st1 {v27.2s}, [x19], x8 + st1 {v27.s}[2], [x20], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], x8 + st1 {v29.2s}, [x19], x8 + st1 {v29.s}[2], [x20], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11], x8 + st1 {v31.2s}, [x19] + st1 {v31.s}[2], [x20] + add x11, x11, #28 + b WriteEnd + WriteC8: + mov x19, x11 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x19], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x19], #64 + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x19], #64 + st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x19], #64 + st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x19], #64 + st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x19], #64 + add x11, x11, x16 + b WriteEnd + WriteWino: + add x2, x11, x16 + st1 {v8.4s, v9.4s}, [x11], x15 + st1 {v10.4s, v11.4s}, [x11], x15 + st1 {v12.4s, v13.4s}, [x11], x15 + st1 {v14.4s, v15.4s}, [x11], x15 + st1 {v16.4s, v17.4s}, [x11], x15 + st1 {v18.4s, v19.4s}, [x11], x15 + st1 {v20.4s, v21.4s}, [x11], x15 + st1 {v22.4s, v23.4s}, [x11], x15 + st1 {v24.4s, v25.4s}, [x11], x15 + st1 {v26.4s, v27.4s}, [x11], x15 + st1 {v28.4s, v29.4s}, [x11], x15 + st1 {v30.4s, v31.4s}, [x11], x15 + b WriteEnd + Write8: + add x2, x2, #32 + st1 {v8.4s, v9.4s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s, v11.4s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s, v13.4s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s, v15.4s}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s, v17.4s}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s, v19.4s}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s, v21.4s}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s, v23.4s}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s, v25.4s}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s, v27.4s}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s, v29.4s}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s, v31.4s}, [x11], x8 + add x11, x11, #32 + b WriteEnd + WriteC4: + cmp x13, #1 + beq C4Write1 + cmp x13, #2 + beq C4Write2 + cmp x13, #3 + beq C4Write3 + cmp x13, #4 + beq C4Write4 + cmp x13, #5 + beq C4Write5 + cmp x13, #6 + beq C4Write6 + cmp x13, #7 + beq C4Write7 + b C4Write8 + C4Write1: + // add x20, x11, x8 + str s8, [x11], #4 + cmp x6, #1 + beq WriteEnd + str s10, [x11], #4 + cmp x6, #2 + beq WriteEnd + str s12, [x11], #4 + cmp x6, #3 + beq WriteEnd + str s14, [x11], #4 + cmp x6, #4 + beq WriteEnd + str s16, [x11], #4 + cmp x6, #5 + beq WriteEnd + str s18, [x11], #4 + cmp x6, #6 + beq WriteEnd + str s20, [x11], #4 + cmp x6, #7 + beq WriteEnd + str s22, [x11], #4 + cmp x6, #8 + beq WriteEnd + str s24, [x11], #4 + cmp x6, #9 + beq WriteEnd + str s26, [x11], #4 + cmp x6, #10 + beq WriteEnd + str s28, [x11], #4 + cmp x6, #11 + beq WriteEnd + str s30, [x11], #4 + b WriteEnd + C4Write2: + // add x20, x11, x8 + st1 {v8.2s}, [x11], #8 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11], #8 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11], #8 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11], #8 + cmp x6, #4 + beq WriteEnd + st1 {v16.2s}, [x11], #8 + cmp x6, #5 + beq WriteEnd + st1 {v18.2s}, [x11], #8 + cmp x6, #6 + beq WriteEnd + st1 {v20.2s}, [x11], #8 + cmp x6, #7 + beq WriteEnd + st1 {v22.2s}, [x11], #8 + cmp x6, #8 + beq WriteEnd + st1 {v24.2s}, [x11], #8 + cmp x6, #9 + beq WriteEnd + st1 {v26.2s}, [x11], #8 + cmp x6, #10 + beq WriteEnd + st1 {v28.2s}, [x11], #8 + cmp x6, #11 + beq WriteEnd + st1 {v30.2s}, [x11], #8 + b WriteEnd + C4Write3: + // add x20, x11, x8 + add x19, x11, #8 + st1 {v8.2s}, [x11] + add x11, x11, #12 + st1 {v8.s}[2], [x19] + add x19, x19, #12 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11] + add x11, x11, #12 + st1 {v10.s}[2], [x19] + add x19, x19, #12 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11] + add x11, x11, #12 + st1 {v12.s}[2], [x19] + add x19, x19, #12 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11] + add x11, x11, #12 + st1 {v14.s}[2], [x19] + add x19, x19, #12 + cmp x6, #4 + beq WriteEnd + st1 {v16.2s}, [x11] + add x11, x11, #12 + st1 {v16.s}[2], [x19] + add x19, x19, #12 + cmp x6, #5 + beq WriteEnd + st1 {v18.2s}, [x11] + add x11, x11, #12 + st1 {v18.s}[2], [x19] + add x19, x19, #12 + cmp x6, #6 + beq WriteEnd + st1 {v20.2s}, [x11] + add x11, x11, #12 + st1 {v20.s}[2], [x19] + add x19, x19, #12 + cmp x6, #7 + beq WriteEnd + st1 {v22.2s}, [x11] + add x11, x11, #12 + st1 {v22.s}[2], [x19] + add x19, x19, #12 + cmp x6, #8 + beq WriteEnd + st1 {v24.2s}, [x11] + add x11, x11, #12 + st1 {v24.s}[2], [x19] + add x19, x19, #12 + cmp x6, #9 + beq WriteEnd + st1 {v26.2s}, [x11] + add x11, x11, #12 + st1 {v26.s}[2], [x19] + add x19, x19, #12 + cmp x6, #10 + beq WriteEnd + st1 {v28.2s}, [x11] + add x11, x11, #12 + st1 {v28.s}[2], [x19] + add x19, x19, #12 + cmp x6, #11 + beq WriteEnd + st1 {v30.2s}, [x11] + add x11, x11, #12 + st1 {v30.s}[2], [x19] + add x19, x19, #12 + b WriteEnd + + C4Write4: + add x20, x11, x8 + st1 {v8.4s}, [x11], #16 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], #16 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], #16 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], #16 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], #16 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], #16 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], #16 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11], #16 + b WriteEnd + C4Write5: + add x19, x11, x8 + st1 {v8.4s}, [x11], #16 + str s9, [x19], #4 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + str s11, [x19], #4 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + str s13, [x19], #4 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + str s15, [x19], #4 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + str s17, [x19], #4 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], #16 + str s19, [x19], #4 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], #16 + str s21, [x19], #4 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], #16 + str s23, [x19], #4 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], #16 + str s25, [x19], #4 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], #16 + str s27, [x19], #4 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], #16 + str s29, [x19], #4 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11], #16 + str s31, [x19], #4 + b WriteEnd + C4Write6: + add x19, x11, x8 + st1 {v8.4s}, [x11], #16 + st1 {v9.2s}, [x19], #8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + st1 {v11.2s}, [x19], #8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + st1 {v13.2s}, [x19], #8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + st1 {v15.2s}, [x19], #8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + st1 {v17.2s}, [x19], #8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], #16 + st1 {v19.2s}, [x19], #8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], #16 + st1 {v21.2s}, [x19], #8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], #16 + st1 {v23.2s}, [x19], #8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], #16 + st1 {v25.2s}, [x19], #8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], #16 + st1 {v27.2s}, [x19], #8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], #16 + st1 {v29.2s}, [x19], #8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11], #16 + st1 {v31.2s}, [x19], #8 + b WriteEnd + C4Write7: + add x19, x11, x8 + add x16, x19, #8 + mov x15, #12 + st1 {v8.4s}, [x11], #16 + st1 {v9.2s}, [x19], x15 + st1 {v9.s}[2], [x16], x15 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + st1 {v11.2s}, [x19], x15 + st1 {v11.s}[2], [x16], x15 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + st1 {v13.2s}, [x19], x15 + st1 {v13.s}[2], [x16], x15 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + st1 {v15.2s}, [x19], x15 + st1 {v15.s}[2], [x16], x15 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + st1 {v17.2s}, [x19], x15 + st1 {v17.s}[2], [x16], x15 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], #16 + st1 {v19.2s}, [x19], x15 + st1 {v19.s}[2], [x16], x15 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], #16 + st1 {v21.2s}, [x19], x15 + st1 {v21.s}[2], [x16], x15 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], #16 + st1 {v23.2s}, [x19], x15 + st1 {v23.s}[2], [x16], x15 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], #16 + st1 {v25.2s}, [x19], x15 + st1 {v25.s}[2], [x16], x15 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], #16 + st1 {v27.2s}, [x19], x15 + st1 {v27.s}[2], [x16], x15 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], #16 + st1 {v29.2s}, [x19], x15 + st1 {v29.s}[2], [x16], x15 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11] + st1 {v31.2s}, [x19] + st1 {v31.s}[2], [x16] + b WriteEnd + C4Write8: + add x19, x11, x8 + add x20, x19, x8 + st1 {v8.4s}, [x11], #16 + st1 {v9.4s}, [x19], #16 + cmp x6, #1 + beq WriteEnd + + st1 {v10.4s}, [x11], #16 + st1 {v11.4s}, [x19], #16 + cmp x6, #2 + beq WriteEnd + + st1 {v12.4s}, [x11], #16 + st1 {v13.4s}, [x19], #16 + cmp x6, #3 + beq WriteEnd + + st1 {v14.4s}, [x11], #16 + st1 {v15.4s}, [x19], #16 + cmp x6, #4 + beq WriteEnd + + st1 {v16.4s}, [x11], #16 + st1 {v17.4s}, [x19], #16 + cmp x6, #5 + beq WriteEnd + + st1 {v18.4s}, [x11], #16 + st1 {v19.4s}, [x19], #16 + cmp x6, #6 + beq WriteEnd + + st1 {v20.4s}, [x11], #16 + st1 {v21.4s}, [x19], #16 + cmp x6, #7 + beq WriteEnd + + st1 {v22.4s}, [x11], #16 + st1 {v23.4s}, [x19], #16 + cmp x6, #8 + beq WriteEnd + + st1 {v24.4s}, [x11], #16 + st1 {v25.4s}, [x19], #16 + cmp x6, #9 + beq WriteEnd + + st1 {v26.4s}, [x11], #16 + st1 {v27.4s}, [x19], #16 + cmp x6, #10 + beq WriteEnd + + st1 {v28.4s}, [x11], #16 + st1 {v29.4s}, [x19], #16 + cmp x6, #11 + beq WriteEnd + + st1 {v30.4s}, [x11] + st1 {v31.4s}, [x19] + b WriteEnd + + WriteEnd: + subs x13, x13, #8 // rhs col - 8 + ble LoopColEnd + cmp x6, #4 + ble LoopCol4 + cmp x6, #8 + ble LoopCol8 + b LoopCol + +LoopColEnd: + add x0, x0, x17 + cbz x9, C8DstStep + cmp x9, #3 + beq C4DstStep + mov x21, #4 + mul x21, x21, x7 + sub x11, x11, x21 + mov x2, x11 + b NoDstStep + C4DstStep: + add x2, x2, x18 + b NoDstStep + C8DstStep: + add x2, x2, #384 + mov x11, x2 + NoDstStep: + subs x6, x6, #12 + bgt LoopRowStart + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/MatmulFp32OptRow12.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/MatmulFp32OptRow12.S new file mode 100644 index 0000000000000000000000000000000000000000..05465bd166e8b627624cab5861dcb119a03fc853 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/MatmulFp32OptRow12.S @@ -0,0 +1,1229 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth +// int row, int col, size_t stride, size_t writeMode) +// x0: a +// x1: b +// x2: c +// x3: bias +// x4: act_type +// x5: depth +// x6: row +// x7: col +// x8: stride +// x9: writeMode + +asm_function MatmulFloatNeon64OptRow12 + sub sp, sp, #160 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + + ldr x8, [sp, #160] + ldr x9, [sp, #168] + + mov x21, #48 // sizeof(float) * 12 + mul x17, x5, x21 // block stride of lhs/rhs: sizeof(float) * 12 * depth + cmp x9, #3 // c4 + beq C4Stride + cbnz x9, NoC8Steps + mov x11, x2 + mov x21, #32 + mul x16, x6, x21 // row * 8 * sizeof(float) + b NoC8Steps +C4Stride: + mov x18, #48 // 12 * sizeof(float) + mov x22, #4 + mul x8, x8, x22 // stride * sizeof(float), in c4 stride == row + mul x8, x8, x22 // col stride + // col >= 4 , block stride 192, otherwise 12 * 4 * col + cmp x7, #4 + bge C4StrideCommon + mul x18, x18, x7 // block stride + b LoopRowStart +C4StrideCommon: + mov x18, #192 // block stride + b LoopRowStart + +NoC8Steps: + cmp x9, #2 + bne NoWinoSteps + mov x21, #4 + mul x15, x7, x8 + mul x15, x15, x21 // kernel_size * col *sizeof(float) + mov x21, #32 + mul x16, x8, x21 // kernel_size * 8 * sizeof(float) +NoWinoSteps: + mov x21, #4 + mul x8, x8, x21 + +LoopRowStart: + cmp x9, #3 + bne LoopRow + mov x20, x2 +LoopRow: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol: + cbz x9, NoReloadDst + cmp x9, #3 + beq C4ReloadDst + mov x11, x2 + b NoReloadDst + C4ReloadDst: + mov x11, x20 + NoReloadDst: + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + cmp x13, #4 + ble LoopDepthStartHalf + + LoopDepthStart: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s, v4.4s}, [x14], #32 + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + fmul v9.4s, v4.4s, v0.s[0] + fmul v11.4s, v4.4s, v0.s[1] + fmul v13.4s, v4.4s, v0.s[2] + fmul v15.4s, v4.4s, v0.s[3] + fmul v16.4s, v3.4s, v1.s[0] + fmul v18.4s, v3.4s, v1.s[1] + fmul v20.4s, v3.4s, v1.s[2] + fmul v22.4s, v3.4s, v1.s[3] + fmul v17.4s, v4.4s, v1.s[0] + fmul v19.4s, v4.4s, v1.s[1] + fmul v21.4s, v4.4s, v1.s[2] + fmul v23.4s, v4.4s, v1.s[3] + fmul v24.4s, v3.4s, v2.s[0] + fmul v26.4s, v3.4s, v2.s[1] + fmul v28.4s, v3.4s, v2.s[2] + fmul v30.4s, v3.4s, v2.s[3] + fmul v25.4s, v4.4s, v2.s[0] + fmul v27.4s, v4.4s, v2.s[1] + fmul v29.4s, v4.4s, v2.s[2] + fmul v31.4s, v4.4s, v2.s[3] + + subs x19, x19, #1 + beq Bias + + LoopDepth: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s, v4.4s}, [x14], #32 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + fmla v25.4s, v4.4s, v2.s[0] + fmla v27.4s, v4.4s, v2.s[1] + fmla v29.4s, v4.4s, v2.s[2] + fmla v31.4s, v4.4s, v2.s[3] + + subs x19, x19, #1 + bgt LoopDepth + + Bias: + cbz x3, Activation + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v9.4s, v9.4s, v1.4s + fadd v10.4s, v10.4s, v0.4s + fadd v11.4s, v11.4s, v1.4s + fadd v12.4s, v12.4s, v0.4s + fadd v13.4s, v13.4s, v1.4s + fadd v14.4s, v14.4s, v0.4s + fadd v15.4s, v15.4s, v1.4s + fadd v16.4s, v16.4s, v0.4s + fadd v17.4s, v17.4s, v1.4s + fadd v18.4s, v18.4s, v0.4s + fadd v19.4s, v19.4s, v1.4s + fadd v20.4s, v20.4s, v0.4s + fadd v21.4s, v21.4s, v1.4s + fadd v22.4s, v22.4s, v0.4s + fadd v23.4s, v23.4s, v1.4s + fadd v24.4s, v24.4s, v0.4s + fadd v25.4s, v25.4s, v1.4s + fadd v26.4s, v26.4s, v0.4s + fadd v27.4s, v27.4s, v1.4s + fadd v28.4s, v28.4s, v0.4s + fadd v29.4s, v29.4s, v1.4s + fadd v30.4s, v30.4s, v0.4s + fadd v31.4s, v31.4s, v1.4s + + Activation: + cmp x4, #3 + beq Relu6 + cmp x4, #1 + beq Relu + b Write + + Relu6: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v17.4s, v17.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v19.4s, v19.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v21.4s, v21.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + fmin v23.4s, v23.4s, v2.4s + fmin v24.4s, v24.4s, v2.4s + fmin v25.4s, v25.4s, v2.4s + fmin v26.4s, v26.4s, v2.4s + fmin v27.4s, v27.4s, v2.4s + fmin v28.4s, v28.4s, v2.4s + fmin v29.4s, v29.4s, v2.4s + fmin v30.4s, v30.4s, v2.4s + fmin v31.4s, v31.4s, v2.4s + + Relu: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v9.4s, v9.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v17.4s, v17.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v19.4s, v19.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v21.4s, v21.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + fmax v23.4s, v23.4s, v3.4s + fmax v24.4s, v24.4s, v3.4s + fmax v25.4s, v25.4s, v3.4s + fmax v26.4s, v26.4s, v3.4s + fmax v27.4s, v27.4s, v3.4s + fmax v28.4s, v28.4s, v3.4s + fmax v29.4s, v29.4s, v3.4s + fmax v30.4s, v30.4s, v3.4s + fmax v31.4s, v31.4s, v3.4s + b Write + + LoopDepthStartHalf: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s, v4.4s}, [x14], #32 + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + fmul v16.4s, v3.4s, v1.s[0] + fmul v18.4s, v3.4s, v1.s[1] + fmul v20.4s, v3.4s, v1.s[2] + fmul v22.4s, v3.4s, v1.s[3] + fmul v24.4s, v3.4s, v2.s[0] + fmul v26.4s, v3.4s, v2.s[1] + fmul v28.4s, v3.4s, v2.s[2] + fmul v30.4s, v3.4s, v2.s[3] + + subs x19, x19, #1 + beq BiasHalf + + LoopDepthHalf: + ld1 {v0.4s, v1.4s, v2.4s}, [x10], #48 + ld1 {v3.4s, v4.4s}, [x14], #32 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v24.4s, v3.4s, v2.s[0] + fmla v26.4s, v3.4s, v2.s[1] + fmla v28.4s, v3.4s, v2.s[2] + fmla v30.4s, v3.4s, v2.s[3] + + subs x19, x19, #1 + bgt LoopDepthHalf + + BiasHalf: + cbz x3, ActivationHalf + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v10.4s, v10.4s, v0.4s + fadd v12.4s, v12.4s, v0.4s + fadd v14.4s, v14.4s, v0.4s + fadd v16.4s, v16.4s, v0.4s + fadd v18.4s, v18.4s, v0.4s + fadd v20.4s, v20.4s, v0.4s + fadd v22.4s, v22.4s, v0.4s + fadd v24.4s, v24.4s, v0.4s + fadd v26.4s, v26.4s, v0.4s + fadd v28.4s, v28.4s, v0.4s + fadd v30.4s, v30.4s, v0.4s + + ActivationHalf: + cmp x4, #3 + beq Relu6Half + cmp x4, #1 + beq ReluHalf + b Write + + Relu6Half: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + fmin v24.4s, v24.4s, v2.4s + fmin v26.4s, v26.4s, v2.4s + fmin v28.4s, v28.4s, v2.4s + fmin v30.4s, v30.4s, v2.4s + + ReluHalf: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + fmax v24.4s, v24.4s, v3.4s + fmax v26.4s, v26.4s, v3.4s + fmax v28.4s, v28.4s, v3.4s + fmax v30.4s, v30.4s, v3.4s + + Write: + cmp x9, #2 + beq WriteWino + cmp x9, #3 + beq WriteC4 + cbz x9, WriteC8 + cmp x13, #1 + beq Write1 + cmp x13, #2 + beq Write2 + cmp x13, #3 + beq Write3 + cmp x13, #4 + beq Write4 + cmp x13, #5 + beq Write5 + cmp x13, #6 + beq Write6 + cmp x13, #7 + beq Write7 + b Write8 + + Write1: + add x2, x2, #4 + str s8, [x11] + cmp x6, #1 + beq WriteEnd + add x11, x11, x8 + str s10, [x11] + cmp x6, #2 + beq WriteEnd + add x11, x11, x8 + str s12, [x11] + cmp x6, #3 + beq WriteEnd + add x11, x11, x8 + str s14, [x11] + cmp x6, #4 + beq WriteEnd + add x11, x11, x8 + str s16, [x11] + cmp x6, #5 + beq WriteEnd + add x11, x11, x8 + str s18, [x11] + cmp x6, #6 + beq WriteEnd + add x11, x11, x8 + str s20, [x11] + cmp x6, #7 + beq WriteEnd + add x11, x11, x8 + str s22, [x11] + cmp x6, #8 + beq WriteEnd + add x11, x11, x8 + str s24, [x11] + cmp x6, #9 + beq WriteEnd + add x11, x11, x8 + str s26, [x11] + cmp x6, #10 + beq WriteEnd + add x11, x11, x8 + str s28, [x11] + cmp x6, #11 + beq WriteEnd + add x11, x11, x8 + str s30, [x11] + add x11, x11, x8 + add x11, x11, #4 + b WriteEnd + Write2: + add x2, x2, #8 + st1 {v8.2s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.2s}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.2s}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.2s}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.2s}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.2s}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.2s}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.2s}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.2s}, [x11], x8 + add x11, x11, #8 + b WriteEnd + Write3: + add x2, x2, #12 + add x19, x11, #8 + st1 {v8.2s}, [x11], x8 + st1 {v8.s}[2], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11], x8 + st1 {v10.s}[2], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11], x8 + st1 {v12.s}[2], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11], x8 + st1 {v14.s}[2], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.2s}, [x11], x8 + st1 {v16.s}[2], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.2s}, [x11], x8 + st1 {v18.s}[2], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.2s}, [x11], x8 + st1 {v20.s}[2], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.2s}, [x11], x8 + st1 {v22.s}[2], [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.2s}, [x11], x8 + st1 {v24.s}[2], [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.2s}, [x11], x8 + st1 {v26.s}[2], [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.2s}, [x11], x8 + st1 {v28.s}[2], [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.2s}, [x11], x8 + st1 {v30.s}[2], [x19] + add x11, x11, #12 + b WriteEnd + Write4: + add x2, x2, #16 + st1 {v8.4s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11], x8 + add x11, x11, #16 + b WriteEnd + Write5: + add x2, x2, #20 + add x19, x11, #16 + st1 {v8.4s}, [x11], x8 + str s9, [x19] + cmp x6, #1 + beq WriteEnd + add x19, x19, x8 + st1 {v10.4s}, [x11], x8 + str s11, [x19] + cmp x6, #2 + beq WriteEnd + add x19, x19, x8 + st1 {v12.4s}, [x11], x8 + str s13, [x19] + cmp x6, #3 + beq WriteEnd + add x19, x19, x8 + st1 {v14.4s}, [x11], x8 + str s15, [x19] + cmp x6, #4 + beq WriteEnd + add x19, x19, x8 + st1 {v16.4s}, [x11], x8 + str s17, [x19] + cmp x6, #5 + beq WriteEnd + add x19, x19, x8 + st1 {v18.4s}, [x11], x8 + str s19, [x19] + cmp x6, #6 + beq WriteEnd + add x19, x19, x8 + st1 {v20.4s}, [x11], x8 + str s21, [x19] + cmp x6, #7 + beq WriteEnd + add x19, x19, x8 + st1 {v22.4s}, [x11], x8 + str s23, [x19] + cmp x6, #8 + beq WriteEnd + add x19, x19, x8 + st1 {v24.4s}, [x11], x8 + str s25, [x19] + cmp x6, #9 + beq WriteEnd + add x19, x19, x8 + st1 {v26.4s}, [x11], x8 + str s27, [x19] + cmp x6, #10 + beq WriteEnd + add x19, x19, x8 + st1 {v28.4s}, [x11], x8 + str s29, [x19] + cmp x6, #11 + beq WriteEnd + add x19, x19, x8 + st1 {v30.4s}, [x11], x8 + str s31, [x19] + add x11, x11, #20 + b WriteEnd + Write6: + add x2, x2, #24 + add x19, x11, #16 + st1 {v8.4s}, [x11], x8 + st1 {v9.2s}, [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], x8 + st1 {v11.2s}, [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], x8 + st1 {v13.2s}, [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], x8 + st1 {v15.2s}, [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], x8 + st1 {v17.2s}, [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], x8 + st1 {v19.2s}, [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], x8 + st1 {v21.2s}, [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], x8 + st1 {v23.2s}, [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], x8 + st1 {v25.2s}, [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], x8 + st1 {v27.2s}, [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], x8 + st1 {v29.2s}, [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11], x8 + st1 {v31.2s}, [x19] + add x11, x11, #24 + b WriteEnd + Write7: + add x2, x2, #28 + add x19, x11, #16 + add x20, x11, #24 + st1 {v8.4s}, [x11], x8 + st1 {v9.2s}, [x19], x8 + st1 {v9.s}[2], [x20], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], x8 + st1 {v11.2s}, [x19], x8 + st1 {v11.s}[2], [x20], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], x8 + st1 {v13.2s}, [x19], x8 + st1 {v13.s}[2], [x20], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], x8 + st1 {v15.2s}, [x19], x8 + st1 {v15.s}[2], [x20], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], x8 + st1 {v17.2s}, [x19], x8 + st1 {v17.s}[2], [x20], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], x8 + st1 {v19.2s}, [x19], x8 + st1 {v19.s}[2], [x20], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], x8 + st1 {v21.2s}, [x19], x8 + st1 {v21.s}[2], [x20], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], x8 + st1 {v23.2s}, [x19], x8 + st1 {v23.s}[2], [x20], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], x8 + st1 {v25.2s}, [x19], x8 + st1 {v25.s}[2], [x20], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], x8 + st1 {v27.2s}, [x19], x8 + st1 {v27.s}[2], [x20], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], x8 + st1 {v29.2s}, [x19], x8 + st1 {v29.s}[2], [x20], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11], x8 + st1 {v31.2s}, [x19] + st1 {v31.s}[2], [x20] + add x11, x11, #28 + b WriteEnd + WriteC8: + mov x19, x11 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x19], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x19], #64 + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x19], #64 + st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x19], #64 + st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x19], #64 + st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x19], #64 + add x11, x11, x16 + b WriteEnd + WriteWino: + add x2, x11, x16 + st1 {v8.4s, v9.4s}, [x11], x15 + st1 {v10.4s, v11.4s}, [x11], x15 + st1 {v12.4s, v13.4s}, [x11], x15 + st1 {v14.4s, v15.4s}, [x11], x15 + st1 {v16.4s, v17.4s}, [x11], x15 + st1 {v18.4s, v19.4s}, [x11], x15 + st1 {v20.4s, v21.4s}, [x11], x15 + st1 {v22.4s, v23.4s}, [x11], x15 + st1 {v24.4s, v25.4s}, [x11], x15 + st1 {v26.4s, v27.4s}, [x11], x15 + st1 {v28.4s, v29.4s}, [x11], x15 + st1 {v30.4s, v31.4s}, [x11], x15 + b WriteEnd + Write8: + add x2, x2, #32 + st1 {v8.4s, v9.4s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s, v11.4s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s, v13.4s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s, v15.4s}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s, v17.4s}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s, v19.4s}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s, v21.4s}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s, v23.4s}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s, v25.4s}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s, v27.4s}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s, v29.4s}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s, v31.4s}, [x11], x8 + add x11, x11, #32 + b WriteEnd + WriteC4: + cmp x13, #1 + beq C4Write1 + cmp x13, #2 + beq C4Write2 + cmp x13, #3 + beq C4Write3 + cmp x13, #4 + beq C4Write4 + cmp x13, #5 + beq C4Write5 + cmp x13, #6 + beq C4Write6 + cmp x13, #7 + beq C4Write7 + b C4Write8 + C4Write1: + str s8, [x11], #4 + cmp x6, #1 + beq WriteEnd + str s10, [x11], #4 + cmp x6, #2 + beq WriteEnd + str s12, [x11], #4 + cmp x6, #3 + beq WriteEnd + str s14, [x11], #4 + cmp x6, #4 + beq WriteEnd + str s16, [x11], #4 + cmp x6, #5 + beq WriteEnd + str s18, [x11], #4 + cmp x6, #6 + beq WriteEnd + str s20, [x11], #4 + cmp x6, #7 + beq WriteEnd + str s22, [x11], #4 + cmp x6, #8 + beq WriteEnd + str s24, [x11], #4 + cmp x6, #9 + beq WriteEnd + str s26, [x11], #4 + cmp x6, #10 + beq WriteEnd + str s28, [x11], #4 + cmp x6, #11 + beq WriteEnd + str s30, [x11], #4 + b WriteEnd + C4Write2: + st1 {v8.2s}, [x11], #8 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11], #8 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11], #8 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11], #8 + cmp x6, #4 + beq WriteEnd + st1 {v16.2s}, [x11], #8 + cmp x6, #5 + beq WriteEnd + st1 {v18.2s}, [x11], #8 + cmp x6, #6 + beq WriteEnd + st1 {v20.2s}, [x11], #8 + cmp x6, #7 + beq WriteEnd + st1 {v22.2s}, [x11], #8 + cmp x6, #8 + beq WriteEnd + st1 {v24.2s}, [x11], #8 + cmp x6, #9 + beq WriteEnd + st1 {v26.2s}, [x11], #8 + cmp x6, #10 + beq WriteEnd + st1 {v28.2s}, [x11], #8 + cmp x6, #11 + beq WriteEnd + st1 {v30.2s}, [x11], #8 + b WriteEnd + C4Write3: + add x19, x11, #8 + st1 {v8.2s}, [x11] + add x11, x11, #12 + st1 {v8.s}[2], [x19] + add x19, x19, #12 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11] + add x11, x11, #12 + st1 {v10.s}[2], [x19] + add x19, x19, #12 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11] + add x11, x11, #12 + st1 {v12.s}[2], [x19] + add x19, x19, #12 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11] + add x11, x11, #12 + st1 {v14.s}[2], [x19] + add x19, x19, #12 + cmp x6, #4 + beq WriteEnd + st1 {v16.2s}, [x11] + add x11, x11, #12 + st1 {v16.s}[2], [x19] + add x19, x19, #12 + cmp x6, #5 + beq WriteEnd + st1 {v18.2s}, [x11] + add x11, x11, #12 + st1 {v18.s}[2], [x19] + add x19, x19, #12 + cmp x6, #6 + beq WriteEnd + st1 {v20.2s}, [x11] + add x11, x11, #12 + st1 {v20.s}[2], [x19] + add x19, x19, #12 + cmp x6, #7 + beq WriteEnd + st1 {v22.2s}, [x11] + add x11, x11, #12 + st1 {v22.s}[2], [x19] + add x19, x19, #12 + cmp x6, #8 + beq WriteEnd + st1 {v24.2s}, [x11] + add x11, x11, #12 + st1 {v24.s}[2], [x19] + add x19, x19, #12 + cmp x6, #9 + beq WriteEnd + st1 {v26.2s}, [x11] + add x11, x11, #12 + st1 {v26.s}[2], [x19] + add x19, x19, #12 + cmp x6, #10 + beq WriteEnd + st1 {v28.2s}, [x11] + add x11, x11, #12 + st1 {v28.s}[2], [x19] + add x19, x19, #12 + cmp x6, #11 + beq WriteEnd + st1 {v30.2s}, [x11] + add x11, x11, #12 + st1 {v30.s}[2], [x19] + add x19, x19, #12 + b WriteEnd + C4Write4: + st1 {v8.4s}, [x11], #16 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], #16 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], #16 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], #16 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], #16 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], #16 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], #16 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11], #16 + b WriteEnd + C4Write5: + add x19, x11, x8 + st1 {v8.4s}, [x11], #16 + str s9, [x19], #4 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + str s11, [x19], #4 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + str s13, [x19], #4 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + str s15, [x19], #4 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + str s17, [x19], #4 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], #16 + str s19, [x19], #4 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], #16 + str s21, [x19], #4 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], #16 + str s23, [x19], #4 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], #16 + str s25, [x19], #4 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], #16 + str s27, [x19], #4 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], #16 + str s29, [x19], #4 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11], #16 + str s31, [x19], #4 + b WriteEnd + C4Write6: + add x19, x11, x8 + st1 {v8.4s}, [x11], #16 + st1 {v9.2s}, [x19], #8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + st1 {v11.2s}, [x19], #8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + st1 {v13.2s}, [x19], #8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + st1 {v15.2s}, [x19], #8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + st1 {v17.2s}, [x19], #8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], #16 + st1 {v19.2s}, [x19], #8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], #16 + st1 {v21.2s}, [x19], #8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], #16 + st1 {v23.2s}, [x19], #8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], #16 + st1 {v25.2s}, [x19], #8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], #16 + st1 {v27.2s}, [x19], #8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], #16 + st1 {v29.2s}, [x19], #8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11], #16 + st1 {v31.2s}, [x19], #8 + b WriteEnd + C4Write7: + add x19, x11, x8 + add x16, x19, #8 + mov x15, #12 + st1 {v8.4s}, [x11], #16 + st1 {v9.2s}, [x19], x15 + st1 {v9.s}[2], [x16], x15 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + st1 {v11.2s}, [x19], x15 + st1 {v11.s}[2], [x16], x15 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + st1 {v13.2s}, [x19], x15 + st1 {v13.s}[2], [x16], x15 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + st1 {v15.2s}, [x19], x15 + st1 {v15.s}[2], [x16], x15 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + st1 {v17.2s}, [x19], x15 + st1 {v17.s}[2], [x16], x15 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], #16 + st1 {v19.2s}, [x19], x15 + st1 {v19.s}[2], [x16], x15 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], #16 + st1 {v21.2s}, [x19], x15 + st1 {v21.s}[2], [x16], x15 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], #16 + st1 {v23.2s}, [x19], x15 + st1 {v23.s}[2], [x16], x15 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], #16 + st1 {v25.2s}, [x19], x15 + st1 {v25.s}[2], [x16], x15 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], #16 + st1 {v27.2s}, [x19], x15 + st1 {v27.s}[2], [x16], x15 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], #16 + st1 {v29.2s}, [x19], x15 + st1 {v29.s}[2], [x16], x15 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11] + st1 {v31.2s}, [x19] + st1 {v31.s}[2], [x16] + b WriteEnd + C4Write8: + add x19, x11, x8 + add x20, x19, x8 + st1 {v8.4s}, [x11], #16 + st1 {v9.4s}, [x19], #16 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + st1 {v11.4s}, [x19], #16 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + st1 {v13.4s}, [x19], #16 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + st1 {v15.4s}, [x19], #16 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + st1 {v17.4s}, [x19], #16 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], #16 + st1 {v19.4s}, [x19], #16 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], #16 + st1 {v21.4s}, [x19], #16 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], #16 + st1 {v23.4s}, [x19], #16 + cmp x6, #8 + beq WriteEnd + st1 {v24.4s}, [x11], #16 + st1 {v25.4s}, [x19], #16 + cmp x6, #9 + beq WriteEnd + st1 {v26.4s}, [x11], #16 + st1 {v27.4s}, [x19], #16 + cmp x6, #10 + beq WriteEnd + st1 {v28.4s}, [x11], #16 + st1 {v29.4s}, [x19], #16 + cmp x6, #11 + beq WriteEnd + st1 {v30.4s}, [x11] + st1 {v31.4s}, [x19] + WriteEnd: + subs x13, x13, #8 // rhs col - 8 + bgt LoopCol + +LoopColEnd: + add x0, x0, x17 + cbz x9, C8DstStep + cmp x9, #3 + beq C4DstStep + mov x21, #4 + mul x21, x21, x7 + sub x11, x11, x21 + mov x2, x11 + b NoDstStep + C4DstStep: + add x2, x2, x18 + b NoDstStep + C8DstStep: + add x2, x2, #384 + mov x11, x2 + NoDstStep: + subs x6, x6, #12 + bgt LoopRow + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/MatmulFp32OptRow4.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/MatmulFp32OptRow4.S new file mode 100644 index 0000000000000000000000000000000000000000..b984c4940040d729753373593ab442102773fc18 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/MatmulFp32OptRow4.S @@ -0,0 +1,597 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void MatmulFloatNeon64OptRow4(const float *a, const float *b, float *c, const float *bias, int act_type, int depth +// int row, int col, size_t stride, size_t writeMode) +// x0: a +// x1: b +// x2: c +// x3: bias +// x4: act_type +// x5: depth +// x6: row +// x7: col +// x8: stride +// x9: writeMode + +asm_function MatmulFloatNeon64OptRow4 + sub sp, sp, #160 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + + ldr x8, [sp, #160] + ldr x9, [sp, #168] + + mov x21, #48 // sizeof(float) * 12 + + mul x17, x5, x21 // block stride of lhs/rhs: sizeof(float) * 12 * depth + cmp x9, #3 // c4 + beq C4Stride + cbnz x9, NoC8Steps + mov x11, x2 + mov x21, #32 + mul x16, x6, x21 // row * 8 * sizeof(float) + b NoC8Steps +C4Stride: + mov x18, #16 // 4 * sizeof(float) + mov x22, #4 + mul x8, x8, x22 // stride * sizeof(float), in c4 stride == row + mul x8, x8, x22 // col stride + // col >= 4 , block stride 64, otherwise 4 * 4 * col + cmp x7, #4 + bge C4StrideCommon + mul x18, x18, x7 // block stride + b LoopRowStart +C4StrideCommon: + mov x18, #64 // block stride + b LoopRowStart + +NoC8Steps: + cmp x9, #2 + bne NoWinoSteps + mov x21, #4 + mul x15, x7, x8 + mul x15, x15, x21 // kernel_size * col *sizeof(float) + mov x21, #32 + mul x16, x8, x21 // kernel_size * 8 * sizeof(float) +NoWinoSteps: + mov x21, #4 + mul x8, x8, x21 + +LoopRowStart: + cmp x9, #3 + bne LoopRow4 + mov x20, x2 +LoopRow4: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol4: + cbz x9, NoReloadDst4 + cmp x9, #3 + beq C4ReloadDst4 + mov x11, x2 + b NoReloadDst4 + C4ReloadDst4: + mov x11, x20 + NoReloadDst4: + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + cmp x13, #4 + ble LoopDepthStartHalf4 + + LoopDepthStart4: + ld1 {v0.4s}, [x10], #16 + ld1 {v3.4s, v4.4s}, [x14], #32 + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + fmul v9.4s, v4.4s, v0.s[0] + fmul v11.4s, v4.4s, v0.s[1] + fmul v13.4s, v4.4s, v0.s[2] + fmul v15.4s, v4.4s, v0.s[3] + + subs x19, x19, #1 + beq Bias4 + + LoopDepth4: + ld1 {v0.4s}, [x10], #16 + ld1 {v3.4s, v4.4s}, [x14], #32 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + + subs x19, x19, #1 + bgt LoopDepth4 + + Bias4: + cbz x3, Activation4 + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v9.4s, v9.4s, v1.4s + fadd v10.4s, v10.4s, v0.4s + fadd v11.4s, v11.4s, v1.4s + fadd v12.4s, v12.4s, v0.4s + fadd v13.4s, v13.4s, v1.4s + fadd v14.4s, v14.4s, v0.4s + fadd v15.4s, v15.4s, v1.4s + + Activation4: + cmp x4, #3 + beq Relu64 + cmp x4, #1 + beq Relu4 + b Write + + Relu64: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + + Relu4: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v9.4s, v9.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + b Write + + LoopDepthStartHalf4: + ld1 {v0.4s}, [x10], #16 + ld1 {v3.4s}, [x14], #16 + ld1 {v4.4s}, [x14], #16 + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + + subs x19, x19, #1 + beq BiasHalf4 + + LoopDepthHalf4: + ld1 {v0.4s}, [x10], #16 + ld1 {v3.4s}, [x14], #16 + ld1 {v4.4s}, [x14], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + + subs x19, x19, #1 + bgt LoopDepthHalf4 + + BiasHalf4: + cbz x3, ActivationHalf4 + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v10.4s, v10.4s, v0.4s + fadd v12.4s, v12.4s, v0.4s + fadd v14.4s, v14.4s, v0.4s + + ActivationHalf4: + cmp x4, #3 + beq Relu6Half4 + cmp x4, #1 + beq ReluHalf4 + b Write + + Relu6Half4: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + + ReluHalf4: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + + Write: + cmp x9, #2 + beq WriteWino + cmp x9, #3 + beq WriteC4 + cbz x9, WriteC8 + cmp x13, #1 + beq Write1 + cmp x13, #2 + beq Write2 + cmp x13, #3 + beq Write3 + cmp x13, #4 + beq Write4 + cmp x13, #5 + beq Write5 + cmp x13, #6 + beq Write6 + cmp x13, #7 + beq Write7 + b Write8 + + Write1: + add x2, x2, #4 + str s8, [x11] + cmp x6, #1 + beq WriteEnd + add x11, x11, x8 + str s10, [x11] + cmp x6, #2 + beq WriteEnd + add x11, x11, x8 + str s12, [x11] + cmp x6, #3 + beq WriteEnd + add x11, x11, x8 + str s14, [x11] + add x11, x11, x8 + add x11, x11, #4 + b WriteEnd + Write2: + add x2, x2, #8 + st1 {v8.2s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11], x8 + add x11, x11, #8 + b WriteEnd + Write3: + add x2, x2, #12 + add x19, x11, #8 + st1 {v8.2s}, [x11], x8 + st1 {v8.s}[2], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11], x8 + st1 {v10.s}[2], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11], x8 + st1 {v12.s}[2], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11], x8 + st1 {v14.s}[2], [x19] + add x11, x11, #12 + b WriteEnd + Write4: + add x2, x2, #16 + st1 {v8.4s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], x8 + add x11, x11, #16 + b WriteEnd + Write5: + add x2, x2, #20 + add x19, x11, #16 + st1 {v8.4s}, [x11], x8 + str s9, [x19] + cmp x6, #1 + beq WriteEnd + add x19, x19, x8 + st1 {v10.4s}, [x11], x8 + str s11, [x19] + cmp x6, #2 + beq WriteEnd + add x19, x19, x8 + st1 {v12.4s}, [x11], x8 + str s13, [x19] + cmp x6, #3 + beq WriteEnd + add x19, x19, x8 + st1 {v14.4s}, [x11], x8 + str s15, [x19] + add x11, x11, #20 + b WriteEnd + Write6: + add x2, x2, #24 + add x19, x11, #16 + st1 {v8.4s}, [x11], x8 + st1 {v9.2s}, [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], x8 + st1 {v11.2s}, [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], x8 + st1 {v13.2s}, [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], x8 + st1 {v15.2s}, [x19], x8 + add x11, x11, #24 + b WriteEnd + Write7: + add x2, x2, #28 + add x19, x11, #16 + add x20, x11, #24 + st1 {v8.4s}, [x11], x8 + st1 {v9.2s}, [x19], x8 + st1 {v9.s}[2], [x20], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], x8 + st1 {v11.2s}, [x19], x8 + st1 {v11.s}[2], [x20], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], x8 + st1 {v13.2s}, [x19], x8 + st1 {v13.s}[2], [x20], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], x8 + st1 {v15.2s}, [x19], x8 + st1 {v15.s}[2], [x20], x8 + add x11, x11, #28 + b WriteEnd + WriteC8: + mov x19, x11 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x19], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x19], #64 + + add x11, x11, x16 + b WriteEnd + WriteWino: + add x2, x11, x16 + st1 {v8.4s, v9.4s}, [x11], x15 + st1 {v10.4s, v11.4s}, [x11], x15 + st1 {v12.4s, v13.4s}, [x11], x15 + st1 {v14.4s, v15.4s}, [x11], x15 + + b WriteEnd + Write8: + add x2, x2, #32 + st1 {v8.4s, v9.4s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s, v11.4s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s, v13.4s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s, v15.4s}, [x11], x8 + add x11, x11, #32 + b WriteEnd + WriteC4: + cmp x13, #1 + beq C4Write1 + cmp x13, #2 + beq C4Write2 + cmp x13, #3 + beq C4Write3 + cmp x13, #4 + beq C4Write4 + cmp x13, #5 + beq C4Write5 + cmp x13, #6 + beq C4Write6 + cmp x13, #7 + beq C4Write7 + b C4Write8 + C4Write1: + str s8, [x11], #4 + cmp x6, #1 + beq WriteEnd + str s10, [x11], #4 + cmp x6, #2 + beq WriteEnd + str s12, [x11], #4 + cmp x6, #3 + beq WriteEnd + str s14, [x11], #4 + b WriteEnd + C4Write2: + st1 {v8.2s}, [x11], #8 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11], #8 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11], #8 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11], #8 + b WriteEnd + C4Write3: + add x19, x11, #8 + st1 {v8.2s}, [x11] + add x11, x11, #12 + st1 {v8.s}[2], [x19] + add x19, x19, #12 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11] + add x11, x11, #12 + st1 {v10.s}[2], [x19] + add x19, x19, #12 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11] + add x11, x11, #12 + st1 {v12.s}[2], [x19] + add x19, x19, #12 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11] + st1 {v14.s}[2], [x19] + b WriteEnd + C4Write4: + st1 {v8.4s}, [x11], #16 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + b WriteEnd + C4Write5: + add x19, x11, x8 + st1 {v8.4s}, [x11], #16 + str s9, [x19], #4 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + str s11, [x19], #4 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + str s13, [x19], #4 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + str s15, [x19], #4 + b WriteEnd + C4Write6: + add x19, x11, x8 + st1 {v8.4s}, [x11], #16 + st1 {v9.2s}, [x19], #8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + st1 {v11.2s}, [x19], #8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + st1 {v13.2s}, [x19], #8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + st1 {v15.2s}, [x19], #8 + b WriteEnd + C4Write7: + add x19, x11, x8 + add x16, x19, #8 + mov x15, #12 + st1 {v8.4s}, [x11], #16 + st1 {v9.2s}, [x19], x15 + st1 {v9.s}[2], [x16], x15 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + st1 {v11.2s}, [x19], x15 + st1 {v11.s}[2], [x16], x15 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + st1 {v13.2s}, [x19], x15 + st1 {v13.s}[2], [x16], x15 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + st1 {v15.2s}, [x19], x15 + st1 {v15.s}[2], [x16], x15 + b WriteEnd + C4Write8: + add x19, x11, x8 + add x20, x19, x8 + st1 {v8.4s}, [x11], #16 + st1 {v9.4s}, [x19], #16 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + st1 {v11.4s}, [x19], #16 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + st1 {v13.4s}, [x19], #16 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + st1 {v15.4s}, [x19], #16 + WriteEnd: + subs x13, x13, #8 // rhs col - 8 + bgt LoopCol4 + + +LoopColEnd: + add x0, x0, x17 + cbz x9, C8DstStep + cmp x9, #3 + beq C4DstStep + mov x21, #4 + mul x21, x21, x7 + sub x11, x11, x21 + mov x2, x11 + b NoDstStep + C4DstStep: + add x2, x2, x18 + b NoDstStep + C8DstStep: + add x2, x2, #384 + mov x11, x2 + NoDstStep: + subs x6, x6, #12 + bgt LoopRow4 + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/MatmulFp32OptRow8.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/MatmulFp32OptRow8.S new file mode 100644 index 0000000000000000000000000000000000000000..c5b260c09ccfbf5120abc7374e11335d5f934299 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/MatmulFp32OptRow8.S @@ -0,0 +1,911 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth +// int row, int col, size_t stride, size_t writeMode) +// x0: a +// x1: b +// x2: c +// x3: bias +// x4: act_type +// x5: depth +// x6: row +// x7: col +// x8: stride +// x9: writeMode + +asm_function MatmulFloatNeon64OptRow8 + sub sp, sp, #160 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + + ldr x8, [sp, #160] + ldr x9, [sp, #168] + + mov x21, #48 // sizeof(float) * 12 + mul x17, x5, x21 // block stride of lhs/rhs: sizeof(float) * 12 * depth + cmp x9, #3 // c4 + beq C4Stride + cbnz x9, NoC8Steps + mov x11, x2 + mov x21, #32 + mul x16, x6, x21 // row * 8 * sizeof(float) + b NoC8Steps +C4Stride: + mov x18, #32 // 8 * sizeof(float) + mov x22, #4 + mul x8, x8, x22 // stride * sizeof(float), in c4 stride == row + mul x8, x8, x22 // col stride + // col >= 4 , block stride 128, otherwise 8 * 4 * col + cmp x7, #4 + bge C4StrideCommon + mul x18, x18, x7 // block stride + b LoopRowStart +C4StrideCommon: + mov x18, #128 // block stride + b LoopRowStart + +NoC8Steps: + cmp x9, #2 + bne NoWinoSteps + mov x21, #4 + mul x15, x7, x8 + mul x15, x15, x21 // kernel_size * col *sizeof(float) + mov x21, #32 + mul x16, x8, x21 // kernel_size * 8 * sizeof(float) +NoWinoSteps: + mov x21, #4 + mul x8, x8, x21 + +LoopRowStart: + cmp x9, #3 + bne LoopRow8 + mov x20, x2 +LoopRow8: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol8: + cbz x9, NoReloadDst8 + cmp x9, #3 + beq C4ReloadDst8 + mov x11, x2 + b NoReloadDst8 + C4ReloadDst8: + mov x11, x20 + NoReloadDst8: + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + cmp x13, #4 + ble LoopDepthStartHalf8 + + LoopDepthStart8: + ld1 {v0.4s, v1.4s}, [x10], #32 + ld1 {v3.4s, v4.4s}, [x14], #32 + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + fmul v9.4s, v4.4s, v0.s[0] + fmul v11.4s, v4.4s, v0.s[1] + fmul v13.4s, v4.4s, v0.s[2] + fmul v15.4s, v4.4s, v0.s[3] + fmul v16.4s, v3.4s, v1.s[0] + fmul v18.4s, v3.4s, v1.s[1] + fmul v20.4s, v3.4s, v1.s[2] + fmul v22.4s, v3.4s, v1.s[3] + fmul v17.4s, v4.4s, v1.s[0] + fmul v19.4s, v4.4s, v1.s[1] + fmul v21.4s, v4.4s, v1.s[2] + fmul v23.4s, v4.4s, v1.s[3] + + subs x19, x19, #1 + beq Bias8 + + LoopDepth8: + ld1 {v0.4s, v1.4s}, [x10], #32 + ld1 {v3.4s, v4.4s}, [x14], #32 + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v9.4s, v4.4s, v0.s[0] + fmla v11.4s, v4.4s, v0.s[1] + fmla v13.4s, v4.4s, v0.s[2] + fmla v15.4s, v4.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + fmla v17.4s, v4.4s, v1.s[0] + fmla v19.4s, v4.4s, v1.s[1] + fmla v21.4s, v4.4s, v1.s[2] + fmla v23.4s, v4.4s, v1.s[3] + + subs x19, x19, #1 + bgt LoopDepth8 + + Bias8: + cbz x3, Activation8 + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v9.4s, v9.4s, v1.4s + fadd v10.4s, v10.4s, v0.4s + fadd v11.4s, v11.4s, v1.4s + fadd v12.4s, v12.4s, v0.4s + fadd v13.4s, v13.4s, v1.4s + fadd v14.4s, v14.4s, v0.4s + fadd v15.4s, v15.4s, v1.4s + fadd v16.4s, v16.4s, v0.4s + fadd v17.4s, v17.4s, v1.4s + fadd v18.4s, v18.4s, v0.4s + fadd v19.4s, v19.4s, v1.4s + fadd v20.4s, v20.4s, v0.4s + fadd v21.4s, v21.4s, v1.4s + fadd v22.4s, v22.4s, v0.4s + fadd v23.4s, v23.4s, v1.4s + + Activation8: + cmp x4, #3 + beq Relu68 + cmp x4, #1 + beq Relu8 + b Write + + Relu68: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v9.4s, v9.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v11.4s, v11.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v13.4s, v13.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v15.4s, v15.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v17.4s, v17.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v19.4s, v19.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v21.4s, v21.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + fmin v23.4s, v23.4s, v2.4s + + Relu8: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v9.4s, v9.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v11.4s, v11.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v13.4s, v13.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v15.4s, v15.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v17.4s, v17.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v19.4s, v19.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v21.4s, v21.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + fmax v23.4s, v23.4s, v3.4s + b Write + + LoopDepthStartHalf8: + ld1 {v0.4s, v1.4s}, [x10], #32 + ld1 {v3.4s}, [x14], #16 + ld1 {v4.4s}, [x14], #16 // weight packed 8, only hold place + fmul v8.4s, v3.4s, v0.s[0] + fmul v10.4s, v3.4s, v0.s[1] + fmul v12.4s, v3.4s, v0.s[2] + fmul v14.4s, v3.4s, v0.s[3] + fmul v16.4s, v3.4s, v1.s[0] + fmul v18.4s, v3.4s, v1.s[1] + fmul v20.4s, v3.4s, v1.s[2] + fmul v22.4s, v3.4s, v1.s[3] + + subs x19, x19, #1 + beq BiasHalf8 + + LoopDepthHalf8: + ld1 {v0.4s, v1.4s}, [x10], #32 + ld1 {v3.4s}, [x14], #16 + ld1 {v4.4s}, [x14], #16 // only hold place + fmla v8.4s, v3.4s, v0.s[0] + fmla v10.4s, v3.4s, v0.s[1] + fmla v12.4s, v3.4s, v0.s[2] + fmla v14.4s, v3.4s, v0.s[3] + fmla v16.4s, v3.4s, v1.s[0] + fmla v18.4s, v3.4s, v1.s[1] + fmla v20.4s, v3.4s, v1.s[2] + fmla v22.4s, v3.4s, v1.s[3] + + subs x19, x19, #1 + bgt LoopDepthHalf8 + + BiasHalf8: + cbz x3, ActivationHalf8 + ld1 {v0.4s}, [x12], #16 + ld1 {v1.4s}, [x12], #16 + fadd v8.4s, v8.4s, v0.4s + fadd v10.4s, v10.4s, v0.4s + fadd v12.4s, v12.4s, v0.4s + fadd v14.4s, v14.4s, v0.4s + fadd v16.4s, v16.4s, v0.4s + fadd v18.4s, v18.4s, v0.4s + fadd v20.4s, v20.4s, v0.4s + fadd v22.4s, v22.4s, v0.4s + + ActivationHalf8: + cmp x4, #3 + beq Relu6Half8 + cmp x4, #1 + beq ReluHalf8 + b Write + + Relu6Half8: + mov w19, #6 + dup v2.4s, w19 + scvtf v2.4s, v2.4s + fmin v8.4s, v8.4s, v2.4s + fmin v10.4s, v10.4s, v2.4s + fmin v12.4s, v12.4s, v2.4s + fmin v14.4s, v14.4s, v2.4s + fmin v16.4s, v16.4s, v2.4s + fmin v18.4s, v18.4s, v2.4s + fmin v20.4s, v20.4s, v2.4s + fmin v22.4s, v22.4s, v2.4s + + ReluHalf8: + dup v3.4s, wzr + fmax v8.4s, v8.4s, v3.4s + fmax v10.4s, v10.4s, v3.4s + fmax v12.4s, v12.4s, v3.4s + fmax v14.4s, v14.4s, v3.4s + fmax v16.4s, v16.4s, v3.4s + fmax v18.4s, v18.4s, v3.4s + fmax v20.4s, v20.4s, v3.4s + fmax v22.4s, v22.4s, v3.4s + + Write: + cmp x9, #2 + beq WriteWino + cmp x9, #3 + beq WriteC4 + cbz x9, WriteC8 + cmp x13, #1 + beq Write1 + cmp x13, #2 + beq Write2 + cmp x13, #3 + beq Write3 + cmp x13, #4 + beq Write4 + cmp x13, #5 + beq Write5 + cmp x13, #6 + beq Write6 + cmp x13, #7 + beq Write7 + b Write8 + + Write1: + add x2, x2, #4 + str s8, [x11] + cmp x6, #1 + beq WriteEnd + add x11, x11, x8 + str s10, [x11] + cmp x6, #2 + beq WriteEnd + add x11, x11, x8 + str s12, [x11] + cmp x6, #3 + beq WriteEnd + add x11, x11, x8 + str s14, [x11] + cmp x6, #4 + beq WriteEnd + add x11, x11, x8 + str s16, [x11] + cmp x6, #5 + beq WriteEnd + add x11, x11, x8 + str s18, [x11] + cmp x6, #6 + beq WriteEnd + add x11, x11, x8 + str s20, [x11] + cmp x6, #7 + beq WriteEnd + add x11, x11, x8 + str s22, [x11] + add x11, x11, x8 + add x11, x11, #4 + b WriteEnd + Write2: + add x2, x2, #8 + st1 {v8.2s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.2s}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.2s}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.2s}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.2s}, [x11], x8 + add x11, x11, #8 + b WriteEnd + Write3: + add x2, x2, #12 + add x19, x11, #8 + st1 {v8.2s}, [x11], x8 + st1 {v8.s}[2], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11], x8 + st1 {v10.s}[2], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11], x8 + st1 {v12.s}[2], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11], x8 + st1 {v14.s}[2], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.2s}, [x11], x8 + st1 {v16.s}[2], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.2s}, [x11], x8 + st1 {v18.s}[2], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.2s}, [x11], x8 + st1 {v20.s}[2], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.2s}, [x11], x8 + st1 {v22.s}[2], [x19], x8 + add x11, x11, #12 + b WriteEnd + Write4: + add x2, x2, #16 + st1 {v8.4s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], x8 + add x11, x11, #16 + b WriteEnd + Write5: + add x2, x2, #20 + add x19, x11, #16 + st1 {v8.4s}, [x11], x8 + str s9, [x19] + cmp x6, #1 + beq WriteEnd + add x19, x19, x8 + st1 {v10.4s}, [x11], x8 + str s11, [x19] + cmp x6, #2 + beq WriteEnd + add x19, x19, x8 + st1 {v12.4s}, [x11], x8 + str s13, [x19] + cmp x6, #3 + beq WriteEnd + add x19, x19, x8 + st1 {v14.4s}, [x11], x8 + str s15, [x19] + cmp x6, #4 + beq WriteEnd + add x19, x19, x8 + st1 {v16.4s}, [x11], x8 + str s17, [x19] + cmp x6, #5 + beq WriteEnd + add x19, x19, x8 + st1 {v18.4s}, [x11], x8 + str s19, [x19] + cmp x6, #6 + beq WriteEnd + add x19, x19, x8 + st1 {v20.4s}, [x11], x8 + str s21, [x19] + cmp x6, #7 + beq WriteEnd + add x19, x19, x8 + st1 {v22.4s}, [x11], x8 + str s23, [x19] + add x11, x11, #20 + b WriteEnd + Write6: + add x2, x2, #24 + add x19, x11, #16 + st1 {v8.4s}, [x11], x8 + st1 {v9.2s}, [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], x8 + st1 {v11.2s}, [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], x8 + st1 {v13.2s}, [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], x8 + st1 {v15.2s}, [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], x8 + st1 {v17.2s}, [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], x8 + st1 {v19.2s}, [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], x8 + st1 {v21.2s}, [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], x8 + st1 {v23.2s}, [x19], x8 + add x11, x11, #24 + b WriteEnd + Write7: + add x2, x2, #28 + add x19, x11, #16 + add x20, x11, #24 + st1 {v8.4s}, [x11], x8 + st1 {v9.2s}, [x19], x8 + st1 {v9.s}[2], [x20], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], x8 + st1 {v11.2s}, [x19], x8 + st1 {v11.s}[2], [x20], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], x8 + st1 {v13.2s}, [x19], x8 + st1 {v13.s}[2], [x20], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], x8 + st1 {v15.2s}, [x19], x8 + st1 {v15.s}[2], [x20], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], x8 + st1 {v17.2s}, [x19], x8 + st1 {v17.s}[2], [x20], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], x8 + st1 {v19.2s}, [x19], x8 + st1 {v19.s}[2], [x20], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], x8 + st1 {v21.2s}, [x19], x8 + st1 {v21.s}[2], [x20], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], x8 + st1 {v23.2s}, [x19], x8 + st1 {v23.s}[2], [x20], x8 + add x11, x11, #28 + b WriteEnd + WriteC8: + mov x19, x11 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x19], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x19], #64 + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x19], #64 + st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x19], #64 + add x11, x11, x16 + b WriteEnd + WriteWino: + add x2, x11, x16 + st1 {v8.4s, v9.4s}, [x11], x15 + st1 {v10.4s, v11.4s}, [x11], x15 + st1 {v12.4s, v13.4s}, [x11], x15 + st1 {v14.4s, v15.4s}, [x11], x15 + st1 {v16.4s, v17.4s}, [x11], x15 + st1 {v18.4s, v19.4s}, [x11], x15 + st1 {v20.4s, v21.4s}, [x11], x15 + st1 {v22.4s, v23.4s}, [x11], x15 + b WriteEnd + Write8: + add x2, x2, #32 + st1 {v8.4s, v9.4s}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s, v11.4s}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s, v13.4s}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s, v15.4s}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s, v17.4s}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s, v19.4s}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s, v21.4s}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s, v23.4s}, [x11], x8 + add x11, x11, #32 + b WriteEnd + WriteC4: + cmp x13, #1 + beq C4Write1 + cmp x13, #2 + beq C4Write2 + cmp x13, #3 + beq C4Write3 + cmp x13, #4 + beq C4Write4 + cmp x13, #5 + beq C4Write5 + cmp x13, #6 + beq C4Write6 + cmp x13, #7 + beq C4Write7 + b C4Write8 + C4Write1: + str s8, [x11], #4 + cmp x6, #1 + beq WriteEnd + str s10, [x11], #4 + cmp x6, #2 + beq WriteEnd + str s12, [x11], #4 + cmp x6, #3 + beq WriteEnd + str s14, [x11], #4 + cmp x6, #4 + beq WriteEnd + str s16, [x11], #4 + cmp x6, #5 + beq WriteEnd + str s18, [x11], #4 + cmp x6, #6 + beq WriteEnd + str s20, [x11], #4 + cmp x6, #7 + beq WriteEnd + str s22, [x11], #4 + b WriteEnd + C4Write2: + st1 {v8.2s}, [x11], #8 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11], #8 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11], #8 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11], #8 + cmp x6, #4 + beq WriteEnd + st1 {v16.2s}, [x11], #8 + cmp x6, #5 + beq WriteEnd + st1 {v18.2s}, [x11], #8 + cmp x6, #6 + beq WriteEnd + st1 {v20.2s}, [x11], #8 + cmp x6, #7 + beq WriteEnd + st1 {v22.2s}, [x11], #8 + b WriteEnd + C4Write3: + add x19, x11, #8 + st1 {v8.2s}, [x11] + add x11, x11, #12 + st1 {v8.s}[2], [x19] + add x19, x19, #12 + cmp x6, #1 + beq WriteEnd + st1 {v10.2s}, [x11] + add x11, x11, #12 + st1 {v10.s}[2], [x19] + add x19, x19, #12 + cmp x6, #2 + beq WriteEnd + st1 {v12.2s}, [x11] + add x11, x11, #12 + st1 {v12.s}[2], [x19] + add x19, x19, #12 + cmp x6, #3 + beq WriteEnd + st1 {v14.2s}, [x11] + add x11, x11, #12 + st1 {v14.s}[2], [x19] + add x19, x19, #12 + cmp x6, #4 + beq WriteEnd + st1 {v16.2s}, [x11] + add x11, x11, #12 + st1 {v16.s}[2], [x19] + add x19, x19, #12 + cmp x6, #5 + beq WriteEnd + st1 {v18.2s}, [x11] + add x11, x11, #12 + st1 {v18.s}[2], [x19] + add x19, x19, #12 + cmp x6, #6 + beq WriteEnd + st1 {v20.2s}, [x11] + add x11, x11, #12 + st1 {v20.s}[2], [x19] + add x19, x19, #12 + cmp x6, #7 + beq WriteEnd + st1 {v22.2s}, [x11] + st1 {v22.s}[2], [x19] + b WriteEnd + C4Write4: + st1 {v8.4s}, [x11], #16 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], #16 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], #16 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], #16 + b WriteEnd + C4Write5: + add x19, x11, x8 + st1 {v8.4s}, [x11], #16 + str s9, [x19], #4 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + str s11, [x19], #4 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + str s13, [x19], #4 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + str s15, [x19], #4 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + str s17, [x19], #4 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], #16 + str s19, [x19], #4 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], #16 + str s21, [x19], #4 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], #16 + str s23, [x19], #4 + b WriteEnd + C4Write6: + add x19, x11, x8 + st1 {v8.4s}, [x11], #16 + st1 {v9.2s}, [x19], #8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + st1 {v11.2s}, [x19], #8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + st1 {v13.2s}, [x19], #8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + st1 {v15.2s}, [x19], #8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + st1 {v17.2s}, [x19], #8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], #16 + st1 {v19.2s}, [x19], #8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], #16 + st1 {v21.2s}, [x19], #8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], #16 + st1 {v23.2s}, [x19], #8 + b WriteEnd + C4Write7: + add x19, x11, x8 + add x16, x19, #8 + mov x15, #12 + st1 {v8.4s}, [x11], #16 + st1 {v9.2s}, [x19], x15 + st1 {v9.s}[2], [x16], x15 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + st1 {v11.2s}, [x19], x15 + st1 {v11.s}[2], [x16], x15 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + st1 {v13.2s}, [x19], x15 + st1 {v13.s}[2], [x16], x15 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + st1 {v15.2s}, [x19], x15 + st1 {v15.s}[2], [x16], x15 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + st1 {v17.2s}, [x19], x15 + st1 {v17.s}[2], [x16], x15 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], #16 + st1 {v19.2s}, [x19], x15 + st1 {v19.s}[2], [x16], x15 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], #16 + st1 {v21.2s}, [x19], x15 + st1 {v21.s}[2], [x16], x15 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], #16 + st1 {v23.2s}, [x19], x15 + st1 {v23.s}[2], [x16], x15 + b WriteEnd + C4Write8: + add x19, x11, x8 + add x20, x19, x8 + st1 {v8.4s}, [x11], #16 + st1 {v9.4s}, [x19], #16 + cmp x6, #1 + beq WriteEnd + st1 {v10.4s}, [x11], #16 + st1 {v11.4s}, [x19], #16 + cmp x6, #2 + beq WriteEnd + st1 {v12.4s}, [x11], #16 + st1 {v13.4s}, [x19], #16 + cmp x6, #3 + beq WriteEnd + st1 {v14.4s}, [x11], #16 + st1 {v15.4s}, [x19], #16 + cmp x6, #4 + beq WriteEnd + st1 {v16.4s}, [x11], #16 + st1 {v17.4s}, [x19], #16 + cmp x6, #5 + beq WriteEnd + st1 {v18.4s}, [x11], #16 + st1 {v19.4s}, [x19], #16 + cmp x6, #6 + beq WriteEnd + st1 {v20.4s}, [x11], #16 + st1 {v21.4s}, [x19], #16 + cmp x6, #7 + beq WriteEnd + st1 {v22.4s}, [x11], #16 + st1 {v23.4s}, [x19], #16 + WriteEnd: + subs x13, x13, #8 // rhs col - 8 + bgt LoopCol8 + +LoopColEnd: + add x0, x0, x17 + cbz x9, C8DstStep + cmp x9, #3 + beq C4DstStep + mov x21, #4 + mul x21, x21, x7 + sub x11, x11, x21 + mov x2, x11 + b NoDstStep + C4DstStep: + add x2, x2, x18 + b NoDstStep + C8DstStep: + add x2, x2, #384 + mov x11, x2 + NoDstStep: + subs x6, x6, #12 + bgt LoopCol8 + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/MatmulInt8.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/MatmulInt8.S new file mode 100644 index 0000000000000000000000000000000000000000..731bac4bc4ac302f133a22c8263c066d6a0204da --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/MatmulInt8.S @@ -0,0 +1,420 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +//void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, const int *a_sums, +// const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier, int32_t *left_shift, +// int32_t *right_shift, int row, int col, int stride, int filter_peroc) + +// x0: a(left matrix ptr) +// x1: b(right matrix ptr) +// x2: out ptr +// w3: row4 +// w4: col4 +// w5: deep16 +// x6: a_sums +// x7: bias +// w8: act_min +// w9: act_max +// w10: out_zp +// x11: multiplier +// x12: left_shift +// x13: right_shift +// w14: row +// w15: col +// w24: stride +// w27: filter_peroc + +asm_function MatmulInt8Neon64 + sub sp, sp, #208 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + stp x23, x24, [sp, #160] + stp x25, x26, [sp, #176] + stp x27, x28, [sp, #192] + + ldr w8, [sp, #208] + ldr w9, [sp, #216] + ldr w10, [sp, #224] + ldr x11, [sp, #232] + ldr x12, [sp, #240] + ldr x13, [sp, #248] + ldr w14, [sp, #256] + ldr w15, [sp, #264] + ldr w24, [sp, #272] + ldr w27, [sp, #280] + + mov w17, #4 // sizeof(int8)*4 + mul w21, w5, w17 // the stride of a/b: sizeof(int8)*4*deep16 + mov w17, #1 + mov x25, x2 +L1: + cmp w4, #0 // if at the end of col4 + beq End1 + + mov w16, w3 // reset a row4 counter + mov w23, w14 // reset a row counter + mov x17, x0 // reload a ptr + mov x22, x6 // reload a_sums ptr +L2: + cmp w16, #0 + beq End2 + + mov x28, x1 // reload b ptr + mov x19, x7 // reload bias ptr + mov w20, w5 // reload depth + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr +L3: + cmp w20, #0 + beq End3 + + ld1 {v0.16b}, [x17], #16 + ld1 {v1.16b}, [x17], #16 + ld1 {v2.16b}, [x17], #16 + ld1 {v3.16b}, [x17], #16 + ld1 {v4.16b}, [x28], #16 + ld1 {v5.16b}, [x28], #16 + ld1 {v6.16b}, [x28], #16 + ld1 {v7.16b}, [x28], #16 + + smull v8.8h, v4.8b, v0.8b + smull v9.8h, v5.8b, v0.8b + smull v10.8h, v6.8b, v0.8b + smull v11.8h, v7.8b, v0.8b + smull v12.8h, v4.8b, v1.8b + smull v13.8h, v5.8b, v1.8b + smull v14.8h, v6.8b, v1.8b + smull v15.8h, v7.8b, v1.8b + + smlal2 v8.8h, v4.16b, v0.16b + smlal2 v9.8h, v5.16b, v0.16b + smlal2 v10.8h, v6.16b, v0.16b + smlal2 v11.8h, v7.16b, v0.16b + smlal2 v12.8h, v4.16b, v1.16b + smlal2 v13.8h, v5.16b, v1.16b + smlal2 v14.8h, v6.16b, v1.16b + smlal2 v15.8h, v7.16b, v1.16b + + sadalp v16.4s, v8.8h + sadalp v17.4s, v9.8h + sadalp v18.4s, v10.8h + sadalp v19.4s, v11.8h + sadalp v20.4s, v12.8h + sadalp v21.4s, v13.8h + sadalp v22.4s, v14.8h + sadalp v23.4s, v15.8h + + smull v8.8h, v4.8b, v2.8b + smull v9.8h, v5.8b, v2.8b + smull v10.8h, v6.8b, v2.8b + smull v11.8h, v7.8b, v2.8b + smull v12.8h, v4.8b, v3.8b + smull v13.8h, v5.8b, v3.8b + smull v14.8h, v6.8b, v3.8b + smull v15.8h, v7.8b, v3.8b + + smlal2 v8.8h, v4.16b, v2.16b + smlal2 v9.8h, v5.16b, v2.16b + smlal2 v10.8h, v6.16b, v2.16b + smlal2 v11.8h, v7.16b, v2.16b + smlal2 v12.8h, v4.16b, v3.16b + smlal2 v13.8h, v5.16b, v3.16b + smlal2 v14.8h, v6.16b, v3.16b + smlal2 v15.8h, v7.16b, v3.16b + + sadalp v24.4s, v8.8h + sadalp v25.4s, v9.8h + sadalp v26.4s, v10.8h + sadalp v27.4s, v11.8h + sadalp v28.4s, v12.8h + sadalp v29.4s, v13.8h + sadalp v30.4s, v14.8h + sadalp v31.4s, v15.8h + subs w20, w20, #16 // depth + 16 + b L3 + +End3: + addp v16.4s, v16.4s, v17.4s + addp v18.4s, v18.4s, v19.4s + addp v20.4s, v20.4s, v21.4s + addp v22.4s, v22.4s, v23.4s + addp v24.4s, v24.4s, v25.4s + addp v26.4s, v26.4s, v27.4s + addp v28.4s, v28.4s, v29.4s + addp v30.4s, v30.4s, v31.4s + + addp v16.4s, v16.4s, v18.4s + addp v17.4s, v20.4s, v22.4s + addp v18.4s, v24.4s, v26.4s + addp v19.4s, v28.4s, v30.4s + + // Add (Bias+Depth*Za*Zb-Za*Bsums) + ld1 {v15.4s}, [x19], #16 + add v16.4s, v16.4s, v15.4s + add v17.4s, v17.4s, v15.4s + add v18.4s, v18.4s, v15.4s + add v19.4s, v19.4s, v15.4s + + cmp w27, #0 + beq PerTLoad +PerCLoad: + ld1 {v20.4s}, [x6], #16 + ld1 {v21.4s}, [x6], #16 + ld1 {v22.4s}, [x6], #16 + ld1 {v23.4s}, [x6], #16 + + ld1 {v13.4s}, [x12] + ld1 {v12.4s}, [x11] + ld1 {v11.4s}, [x13] + b Apply + +PerTLoad: + ld1 {v14.4s}, [x22], #16 + dup v20.4s, v14.s[0] + dup v21.4s, v14.s[1] + dup v22.4s, v14.s[2] + dup v23.4s, v14.s[3] + + ld1 {v14.s}[0], [x12] + dup v13.4s, v14.s[0] + ld1 {v14.s}[0], [x11] + dup v12.4s, v14.s[0] + ld1 {v14.s}[0], [x13] + dup v11.4s, v14.s[0] + b Apply + +Apply: + // Subtract (Asums*Zb) + sub v16.4s, v16.4s, v20.4s + sub v17.4s, v17.4s, v21.4s + sub v18.4s, v18.4s, v22.4s + sub v19.4s, v19.4s, v23.4s + + // Apply left shift + sqshl v16.4s, v16.4s, v13.4s + sqshl v17.4s, v17.4s, v13.4s + sqshl v18.4s, v18.4s, v13.4s + sqshl v19.4s, v19.4s, v13.4s + + // Apply the fixed-point part of the multiplier. + sqrdmulh v16.4s, v16.4s, v12.4s + sqrdmulh v17.4s, v17.4s, v12.4s + sqrdmulh v18.4s, v18.4s, v12.4s + sqrdmulh v19.4s, v19.4s, v12.4s + + // Apply right shift + and v20.16b, v11.16b, v16.16b + sshr v20.4s, v20.4s, #31 + sqadd v16.4s, v16.4s, v20.4s + srshl v16.4s, v16.4s, v11.4s + and v21.16b, v11.16b, v17.16b + sshr v21.4s, v21.4s, #31 + sqadd v17.4s, v17.4s, v21.4s + srshl v17.4s, v17.4s, v11.4s + and v22.16b, v11.16b, v18.16b + sshr v22.4s, v22.4s, #31 + sqadd v18.4s, v18.4s, v22.4s + srshl v18.4s, v18.4s, v11.4s + and v23.16b, v11.16b, v19.16b + sshr v23.4s, v23.4s, #31 + sqadd v19.4s, v19.4s, v23.4s + srshl v19.4s, v19.4s, v11.4s + + // Add the destination zero point + dup v10.4s, w10 + add v16.4s, v16.4s, v10.4s + add v17.4s, v17.4s, v10.4s + add v18.4s, v18.4s, v10.4s + add v19.4s, v19.4s, v10.4s + + // Apply the act_min bound + dup v9.4s, w8 + smax v16.4s, v16.4s, v9.4s + smax v17.4s, v17.4s, v9.4s + smax v18.4s, v18.4s, v9.4s + smax v19.4s, v19.4s, v9.4s + + // Apply the act_min bound + dup v8.4s, w9 + smin v16.4s, v16.4s, v8.4s + smin v17.4s, v17.4s, v8.4s + smin v18.4s, v18.4s, v8.4s + smin v19.4s, v19.4s, v8.4s + + // int32 -> int16 + sqxtn v13.4h, v16.4s + sqxtn2 v13.8h, v17.4s + sqxtn v14.4h, v18.4s + sqxtn2 v14.8h, v19.4s + + // int16 -> int8 + sqxtn v15.8b, v13.8h + sqxtn2 v15.16b, v14.8h + + cmp w23, #4 + blt Write // if rows < 4 + cmp w15, #4 + blt Write // if cols < 4 + + st1 {v15.s}[0], [x2], x24 + st1 {v15.s}[1], [x2], x24 + st1 {v15.s}[2], [x2], x24 + st1 {v15.s}[3], [x2], x24 + b Endwrite + +Write: + cmp w15, #4 + beq WriteCol4 + cmp w15, #3 + beq WriteCol3 + cmp w15, #2 + beq WriteCol2 + cmp w15, #1 + beq WriteCol1 + +WriteCol4: + st1 {v15.s}[0], [x2], x24 + cmp w23, #1 + beq Endwrite + st1 {v15.s}[1], [x2], x24 + cmp w23, #2 + beq Endwrite + st1 {v15.s}[2], [x2], x24 + cmp w23, #3 + beq Endwrite + st1 {v15.s}[3], [x2], x24 + b Endwrite + +WriteCol3: + mov x26, x2 + st1 {v15.b}[0], [x26], #1 + st1 {v15.b}[1], [x26], #1 + st1 {v15.b}[2], [x26], #1 + add x2, x2, x24 + cmp w23, #1 + beq Endwrite + mov x26, x2 + st1 {v15.b}[4], [x26], #1 + st1 {v15.b}[5], [x26], #1 + st1 {v15.b}[6], [x26], #1 + add x2, x2, x24 + cmp w23, #2 + beq Endwrite + mov x26, x2 + st1 {v15.b}[8], [x26], #1 + st1 {v15.b}[9], [x26], #1 + st1 {v15.b}[10], [x26], #1 + add x2, x2, x24 + cmp w23, #3 + beq Endwrite + mov x26, x2 + st1 {v15.b}[12], [x26], #1 + st1 {v15.b}[13], [x26], #1 + st1 {v15.b}[14], [x26], #1 + add x2, x2, x24 + b Endwrite + +WriteCol2: + mov x26, x2 + st1 {v15.b}[0], [x26], #1 + st1 {v15.b}[1], [x26], #1 + add x2, x2, x24 + cmp w23, #1 + beq Endwrite + mov x26, x2 + st1 {v15.b}[4], [x26], #1 + st1 {v15.b}[5], [x26], #1 + add x2, x2, x24 + cmp w23, #2 + beq Endwrite + mov x26, x2 + st1 {v15.b}[8], [x26], #1 + st1 {v15.b}[9], [x26], #1 + add x2, x2, x24 + cmp w23, #3 + beq Endwrite + mov x26, x2 + st1 {v15.b}[12], [x26], #1 + st1 {v15.b}[13], [x26], #1 + add x2, x2, x24 + b Endwrite + +WriteCol1: + st1 {v15.b}[0], [x2], x24 + cmp w23, #1 + beq Endwrite + st1 {v15.b}[4], [x2], x24 + cmp w23, #2 + beq Endwrite + st1 {v15.b}[8], [x2], x24 + cmp w23, #3 + beq Endwrite + st1 {v15.b}[12], [x2], x24 + b Endwrite + +Endwrite: + sub w16, w16, #4 // a row4 counter - 4 + sub w23, w23, #4 // a row counter - 4 + b L2 + +End2: + sub w4, w4, #4 // b col4 counter - 4 + sub w15, w15, #4 // b col counter - 4 + add x1, x1, x21 // b ptr + stride + add x7, x7, #16 // bias ptr + stride + add x25, x25, #4 // output + stride(4 * sizeof(int8)) + mov x2, x25 + + cmp w27, #0 + beq PerTEnd2 + add x12, x12, #16 + add x11, x11, #16 + add x13, x13, #16 +PerTEnd2: + b L1 + +End1: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ldp x27, x28, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/MatmulInt8Opt.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/MatmulInt8Opt.S new file mode 100644 index 0000000000000000000000000000000000000000..b8059324518112c0e11708438016ee739d7c3ea7 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/MatmulInt8Opt.S @@ -0,0 +1,356 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +//void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int *a_sums, +// const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier, int32_t *left_shift, +// int32_t *right_shift, int stride, int filter_peroc, int32_t *filter_zp) + +// x0: a(left matrix ptr) +// x1: b(right matrix ptr) +// x2: out ptr +// x3: row4 +// x4: col4 +// x5: deep16 +// x6: a_sums +// x7: bias +// w8: act_min +// w9: act_max +// w10: out_zp +// x11: multiplier +// x12: left_shift +// x13: right_shift +// x14: stride +// x15: filter_peroc +// x28: filter_zp + +asm_function MatmulInt8Opt + sub sp, sp, #224 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + stp x23, x24, [sp, #160] + stp x25, x26, [sp, #176] + stp x27, x28, [sp, #192] + stp x29, x30, [sp, #208] + + ldr w8, [sp, #224] + ldr w9, [sp, #232] + ldr w10, [sp, #240] + ldr x11, [sp, #248] + ldr x12, [sp, #256] + ldr x13, [sp, #264] + ldr x14, [sp, #272] + ldr x15, [sp, #280] + + mov x23, #4 + mul x23, x23, x5 // lhs step + mov x24, #4 + mul x24, x24, x14 // dst step +LoopRow: + mov x16, x1 // reload rhs ptr + mov x17, x4 // reload rhs col + mov x29, x7 // reload bias ptr + mov x27, x2 // reload dst ptr + ldr x28, [sp, #288] // reload filter_zp + + LoopCol: + mov x25, x6 // reload a_sums ptr + mov x19, x27 // reload dst ptr + mov x20, x0 // reload lhs ptr + mov x21, x5 // reload depth + + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr + + LoopDepth: + ld1 {v0.16b, v1.16b}, [x20], #32 + ld1 {v4.16b, v5.16b}, [x16], #32 + smull v8.8h, v4.8b, v0.8b + smull v9.8h, v5.8b, v0.8b + smull v12.8h, v4.8b, v1.8b + smull v13.8h, v5.8b, v1.8b + ld1 {v6.16b, v7.16b}, [x16], #32 + smlal2 v8.8h, v4.16b, v0.16b + smlal2 v9.8h, v5.16b, v0.16b + smlal2 v12.8h, v4.16b, v1.16b + smlal2 v13.8h, v5.16b, v1.16b + ld1 {v2.16b, v3.16b}, [x20], #32 + smull v10.8h, v6.8b, v0.8b + smull v11.8h, v7.8b, v0.8b + smull v14.8h, v6.8b, v1.8b + smull v15.8h, v7.8b, v1.8b + smlal2 v10.8h, v6.16b, v0.16b + smlal2 v11.8h, v7.16b, v0.16b + smlal2 v14.8h, v6.16b, v1.16b + smlal2 v15.8h, v7.16b, v1.16b + + sadalp v16.4s, v8.8h + sadalp v17.4s, v9.8h + sadalp v18.4s, v10.8h + sadalp v19.4s, v11.8h + sadalp v20.4s, v12.8h + sadalp v21.4s, v13.8h + sadalp v22.4s, v14.8h + sadalp v23.4s, v15.8h + + smull v8.8h, v4.8b, v2.8b + smull v9.8h, v5.8b, v2.8b + smull v10.8h, v6.8b, v2.8b + smull v11.8h, v7.8b, v2.8b + smull v12.8h, v4.8b, v3.8b + smull v13.8h, v5.8b, v3.8b + smull v14.8h, v6.8b, v3.8b + smull v15.8h, v7.8b, v3.8b + + smlal2 v8.8h, v4.16b, v2.16b + smlal2 v9.8h, v5.16b, v2.16b + smlal2 v10.8h, v6.16b, v2.16b + smlal2 v11.8h, v7.16b, v2.16b + smlal2 v12.8h, v4.16b, v3.16b + smlal2 v13.8h, v5.16b, v3.16b + smlal2 v14.8h, v6.16b, v3.16b + smlal2 v15.8h, v7.16b, v3.16b + + sadalp v24.4s, v8.8h + sadalp v25.4s, v9.8h + sadalp v26.4s, v10.8h + sadalp v27.4s, v11.8h + sadalp v28.4s, v12.8h + sadalp v29.4s, v13.8h + sadalp v30.4s, v14.8h + sadalp v31.4s, v15.8h + subs x21, x21, #16 // depth - 16 + bgt LoopDepth + + addp v16.4s, v16.4s, v17.4s + addp v18.4s, v18.4s, v19.4s + addp v20.4s, v20.4s, v21.4s + addp v22.4s, v22.4s, v23.4s + addp v24.4s, v24.4s, v25.4s + addp v26.4s, v26.4s, v27.4s + addp v28.4s, v28.4s, v29.4s + addp v30.4s, v30.4s, v31.4s + + addp v16.4s, v16.4s, v18.4s + addp v17.4s, v20.4s, v22.4s + addp v18.4s, v24.4s, v26.4s + addp v19.4s, v28.4s, v30.4s + + Bias: + cbz x7, NoBias + ld1 {v15.4s}, [x29], #16 + add v16.4s, v16.4s, v15.4s + add v17.4s, v17.4s, v15.4s + add v18.4s, v18.4s, v15.4s + add v19.4s, v19.4s, v15.4s + + NoBias: + ld1r {v20.4s}, [x25], #4 + ld1r {v21.4s}, [x25], #4 + ld1r {v22.4s}, [x25], #4 + ld1r {v23.4s}, [x25], #4 + cbz x15, ApplySum + + ld1 {v14.4s}, [x28], #16 + mul v20.4s, v20.4s, v14.4s + mul v21.4s, v21.4s, v14.4s + mul v22.4s, v22.4s, v14.4s + mul v23.4s, v23.4s, v14.4s + + ApplySum: + sub v16.4s, v16.4s, v20.4s + sub v17.4s, v17.4s, v21.4s + sub v18.4s, v18.4s, v22.4s + sub v19.4s, v19.4s, v23.4s + + cbnz x15, PerCLoad + + ld1r {v13.4s}, [x12] + ld1r {v12.4s}, [x11] + ld1r {v11.4s}, [x13] + b Quantize + + PerCLoad: + ld1 {v13.4s}, [x12], #16 + ld1 {v12.4s}, [x11], #16 + ld1 {v11.4s}, [x13], #16 + + Quantize: + sqshl v16.4s, v16.4s, v13.4s + sqshl v17.4s, v17.4s, v13.4s + sqshl v18.4s, v18.4s, v13.4s + sqshl v19.4s, v19.4s, v13.4s + + sqrdmulh v16.4s, v16.4s, v12.4s + sqrdmulh v17.4s, v17.4s, v12.4s + sqrdmulh v18.4s, v18.4s, v12.4s + sqrdmulh v19.4s, v19.4s, v12.4s + + and v20.16b, v11.16b, v16.16b + sshr v20.4s, v20.4s, #31 + sqadd v16.4s, v16.4s, v20.4s + srshl v16.4s, v16.4s, v11.4s + and v21.16b, v11.16b, v17.16b + sshr v21.4s, v21.4s, #31 + sqadd v17.4s, v17.4s, v21.4s + srshl v17.4s, v17.4s, v11.4s + and v22.16b, v11.16b, v18.16b + sshr v22.4s, v22.4s, #31 + sqadd v18.4s, v18.4s, v22.4s + srshl v18.4s, v18.4s, v11.4s + and v23.16b, v11.16b, v19.16b + sshr v23.4s, v23.4s, #31 + sqadd v19.4s, v19.4s, v23.4s + srshl v19.4s, v19.4s, v11.4s + + dup v10.4s, w10 + add v16.4s, v16.4s, v10.4s + add v17.4s, v17.4s, v10.4s + add v18.4s, v18.4s, v10.4s + add v19.4s, v19.4s, v10.4s + + dup v9.4s, w8 + smax v16.4s, v16.4s, v9.4s + smax v17.4s, v17.4s, v9.4s + smax v18.4s, v18.4s, v9.4s + smax v19.4s, v19.4s, v9.4s + + dup v8.4s, w9 + smin v16.4s, v16.4s, v8.4s + smin v17.4s, v17.4s, v8.4s + smin v18.4s, v18.4s, v8.4s + smin v19.4s, v19.4s, v8.4s + + sqxtn v13.4h, v16.4s + sqxtn2 v13.8h, v17.4s + sqxtn v14.4h, v18.4s + sqxtn2 v14.8h, v19.4s + + sqxtn v15.8b, v13.8h + sqxtn2 v15.16b, v14.8h + + cmp x17, #1 + beq Write1 + cmp x17, #2 + beq Write2 + cmp x17, #3 + beq Write3 + b Write4 + + Write1: + add x27, x27, #1 + st1 {v15.b}[0], [x19], x14 + cmp x3, #1 + beq WriteEnd + st1 {v15.b}[4], [x19], x14 + cmp x3, #2 + beq WriteEnd + st1 {v15.b}[8], [x19], x14 + cmp x3, #3 + beq WriteEnd + st1 {v15.b}[12], [x19], x14 + b WriteEnd + Write2: + add x27, x27, #2 + st1 {v15.h}[0], [x19], x14 + cmp x3, #1 + beq WriteEnd + st1 {v15.h}[2], [x19], x14 + cmp x3, #2 + beq WriteEnd + st1 {v15.h}[4], [x19], x14 + cmp x3, #3 + beq WriteEnd + st1 {v15.h}[6], [x19], x14 + b WriteEnd + Write3: + add x27, x27, #3 + add x22, x19, #2 + st1 {v15.h}[0], [x19], x14 + st1 {v15.b}[2], [x22], x14 + cmp x3, #1 + beq WriteEnd + st1 {v15.h}[2], [x19], x14 + st1 {v15.b}[6], [x22], x14 + cmp x3, #2 + beq WriteEnd + st1 {v15.h}[4], [x19], x14 + st1 {v15.b}[10], [x22], x14 + cmp x3, #3 + beq WriteEnd + st1 {v15.h}[6], [x19], x14 + st1 {v15.b}[14], [x22], x14 + b WriteEnd + Write4: + add x27, x27, #4 + st1 {v15.s}[0], [x19], x14 + cmp x3, #1 + beq WriteEnd + st1 {v15.s}[1], [x19], x14 + cmp x3, #2 + beq WriteEnd + st1 {v15.s}[2], [x19], x14 + cmp x3, #3 + beq WriteEnd + st1 {v15.s}[3], [x19], x14 + + WriteEnd: + subs x17, x17, #4 + bgt LoopCol + +LoopColEnd: + subs x3, x3, #4 + ble LoopRowEnd + ldr x11, [sp, #248] + ldr x12, [sp, #256] + ldr x13, [sp, #264] + add x6, x6, #16 + add x0, x0, x23 + add x2, x2, x24 + b LoopRow + +LoopRowEnd: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ldp x27, x28, [sp], #16 + ldp x29, x30, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/MatmulR4Int8.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/MatmulR4Int8.S new file mode 100644 index 0000000000000000000000000000000000000000..1fb0a21d1bff242f46d6a85347047e08e0bd6584 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/MatmulR4Int8.S @@ -0,0 +1,193 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +//void MatMulR4Int8Neon64(const int8_t *a, const int8_t *b, int32_t *dst, int row4, int col4, int deep16, +// const int *input_sum, const int *bias) + +// x0: a(left matrix ptr) +// x1: b(right matrix ptr) +// x2: out ptr +// w3: row4 +// w4: col4 +// w5: deep16 +// x6: a_sums +// x7: bias + +asm_function MatMulR4Int8Neon64 + sub sp, sp, #144 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + + mov w15, #0 // b col index + mov w16, #0 // a row index + mov w17, #4 // sizeof(int8)*4 + mul w12, w5, w17 // the stride of a/b: sizeof(int8)*4*deep16 + +L1: + cmp w15, w4 + beq End1 + + mov w16, #0 // reset a row index + mov x17, x0 // reload a ptr + mov x13, x6 // reload a_sums ptr +L2: + cmp w16, w3 + beq End2 + + mov x19, x1 // reload b ptr + mov x10, x7 // reload bias ptr + mov w11, w5 // reload depth + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr +L3: + cmp w11, #0 + beq End3 + + ld1 {v0.16b}, [x17], #16 + ld1 {v1.16b}, [x17], #16 + ld1 {v2.16b}, [x17], #16 + ld1 {v3.16b}, [x17], #16 + ld1 {v4.16b}, [x19], #16 + ld1 {v5.16b}, [x19], #16 + ld1 {v6.16b}, [x19], #16 + ld1 {v7.16b}, [x19], #16 + + smull v8.8h, v4.8b, v0.8b + smull v9.8h, v5.8b, v0.8b + smull v10.8h, v6.8b, v0.8b + smull v11.8h, v7.8b, v0.8b + smull v12.8h, v4.8b, v1.8b + smull v13.8h, v5.8b, v1.8b + smull v14.8h, v6.8b, v1.8b + smull v15.8h, v7.8b, v1.8b + + smlal2 v8.8h, v4.16b, v0.16b + smlal2 v9.8h, v5.16b, v0.16b + smlal2 v10.8h, v6.16b, v0.16b + smlal2 v11.8h, v7.16b, v0.16b + smlal2 v12.8h, v4.16b, v1.16b + smlal2 v13.8h, v5.16b, v1.16b + smlal2 v14.8h, v6.16b, v1.16b + smlal2 v15.8h, v7.16b, v1.16b + + sadalp v16.4s, v8.8h + sadalp v17.4s, v9.8h + sadalp v18.4s, v10.8h + sadalp v19.4s, v11.8h + sadalp v20.4s, v12.8h + sadalp v21.4s, v13.8h + sadalp v22.4s, v14.8h + sadalp v23.4s, v15.8h + + smull v8.8h, v4.8b, v2.8b + smull v9.8h, v5.8b, v2.8b + smull v10.8h, v6.8b, v2.8b + smull v11.8h, v7.8b, v2.8b + smull v12.8h, v4.8b, v3.8b + smull v13.8h, v5.8b, v3.8b + smull v14.8h, v6.8b, v3.8b + smull v15.8h, v7.8b, v3.8b + + smlal2 v8.8h, v4.16b, v2.16b + smlal2 v9.8h, v5.16b, v2.16b + smlal2 v10.8h, v6.16b, v2.16b + smlal2 v11.8h, v7.16b, v2.16b + smlal2 v12.8h, v4.16b, v3.16b + smlal2 v13.8h, v5.16b, v3.16b + smlal2 v14.8h, v6.16b, v3.16b + smlal2 v15.8h, v7.16b, v3.16b + + sadalp v24.4s, v8.8h + sadalp v25.4s, v9.8h + sadalp v26.4s, v10.8h + sadalp v27.4s, v11.8h + sadalp v28.4s, v12.8h + sadalp v29.4s, v13.8h + sadalp v30.4s, v14.8h + sadalp v31.4s, v15.8h + subs w11, w11, #16 // depth + 16 + b L3 + +End3: + addp v16.4s, v16.4s, v17.4s + addp v18.4s, v18.4s, v19.4s + addp v20.4s, v20.4s, v21.4s + addp v22.4s, v22.4s, v23.4s + addp v24.4s, v24.4s, v25.4s + addp v26.4s, v26.4s, v27.4s + addp v28.4s, v28.4s, v29.4s + addp v30.4s, v30.4s, v31.4s + + addp v16.4s, v16.4s, v18.4s + addp v17.4s, v20.4s, v22.4s + addp v18.4s, v24.4s, v26.4s + addp v19.4s, v28.4s, v30.4s + + // Add (Bias+Depth*Za*Zb-Za*Bsums) + ld1 {v15.4s}, [x10], #16 + add v16.4s, v16.4s, v15.4s + add v17.4s, v17.4s, v15.4s + add v18.4s, v18.4s, v15.4s + add v19.4s, v19.4s, v15.4s + + // Subtract (Asums*Zb) + ld1 {v14.4s}, [x13], #16 + dup v20.4s, v14.s[0] + dup v21.4s, v14.s[1] + dup v22.4s, v14.s[2] + dup v23.4s, v14.s[3] + sub v16.4s, v16.4s, v20.4s + sub v17.4s, v17.4s, v21.4s + sub v18.4s, v18.4s, v22.4s + sub v19.4s, v19.4s, v23.4s + + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + add w16, w16, #4 // a row index + 4 + b L2 + +End2: + add w15, w15, #4 // b col index + 4 + add x1, x1, x12 // b ptr + stride + add x7, x7, #16 // bias ptr + stride + b L1 + +End1: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/MatmulWinogradFp32.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/MatmulWinogradFp32.S new file mode 100644 index 0000000000000000000000000000000000000000..23032ab964340e946a59179fe120ce76371ca105 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/MatmulWinogradFp32.S @@ -0,0 +1,183 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// MatrixMultiplyWinograd(float *matix_a, float *matrix_b, float *matrix_c, int m, int k, int n, int in_channel, int c4_channel) + // x0: matrix_a, x1: matrix_b, x2: matrix_c, x3: m, x4: k, x5: n, x6: in_channel, x7: c4_channel +asm_function MatrixMultiplyWinograd + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #48 + st1 {v8.4s}, [sp] + stp x19, x20, [sp, #16] + stp x21, x22, [sp, #32] + mov x8, #4 + mul x10, x5, x8 + mov x17, x3 // m + mul x13, x6, x8 // in_channel * 4 + mul x21, x13, x4 // in_channel * k + + LoopM: + mov x15, x5 // n + mov x14, x1 // mat_b + LoopN: + mov x16, x0 // mat_a_m + sub x22, x5, x15 // ni + sub x19, x17, x3 // mi + mul x22, x22, x17 // ni * m + mov x11, x6 // in_channel + add x22, x22, x19 // (ni * m) + mi + mul x22, x22, x7 // x22 * c4_channel + add x20, x2, x22 // dst + offset + cmp x11, #16 + bge LoopC16 + cmp x11, #8 + bge LoopC8 + cmp x11, #4 + bge LoopC4 + cmp x11, #1 + bge LoopC + b EndLoopC + LoopC16: + mov x12, x14 + mov x9, x4 // new_k + dup v5.4s, wzr + dup v6.4s, wzr + dup v7.4s, wzr + dup v8.4s, wzr + LoopK16: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x16], x13 + ldr s4, [x12] + add x12, x12, x10 + fmla v5.4s, v0.4s, v4.s[0] + fmla v6.4s, v1.4s, v4.s[0] + fmla v7.4s, v2.4s, v4.s[0] + fmla v8.4s, v3.4s, v4.s[0] + subs x9, x9, #1 + bne LoopK16 + Write16: + st1 {v5.4s}, [x20], #16 + st1 {v6.4s}, [x20], #16 + st1 {v7.4s}, [x20], #16 + st1 {v8.4s}, [x20], #16 + + sub x16, x16, x21 // back x13 * k + add x16, x16, #64 // add 64B + subs x11, x11, #16 + beq EndLoopC + cmp x11, #16 + bge LoopC16 + cmp x11, #8 + bge LoopC8 + cmp x11, #4 + bge LoopC4 + cmp x11, #1 + bge LoopC + + LoopC8: + dup v5.4s, wzr + dup v6.4s, wzr + mov x9, x4 // new_k + mov x12, x14 + LoopK8: + ld1 {v0.4s, v1.4s}, [x16], x13 + ldr s4, [x12] + add x12, x12, x10 + fmla v5.4s, v0.4s, v4.s[0] + fmla v6.4s, v1.4s, v4.s[0] + subs x9, x9, #1 + bne LoopK8 + Write8: + st1 {v5.4s}, [x20], #16 + st1 {v6.4s}, [x20], #16 + + sub x16, x16, x21 // back x13 * k + add x16, x16, #32 // add 64B + subs x11, x11, #8 + beq EndLoopC + cmp x11, #8 + bge LoopC8 + cmp x11, #4 + bge LoopC4 + cmp x11, #1 + bge LoopC + + LoopC4: + dup v5.4s, wzr + mov x9, x4 // new_k + mov x12, x14 + LoopK4: + ld1 {v0.4s}, [x16], x13 + ldr s4, [x12] + add x12, x12, x10 + fmla v5.4s, v0.4s, v4.s[0] + subs x9, x9, #1 + bne LoopK4 + Write4: + st1 {v5.4s}, [x20], #16 + + sub x16, x16, x21 // ptr back x13 * k + add x16, x16, #16 // add 16B + subs x11, x11, #4 + beq EndLoopC + cmp x11, #4 + bge LoopC4 + cmp x11, #1 + bge LoopC + + LoopC: + dup v5.4s, wzr + mov x9, x4 // new_k + mov x12, x14 + LoopK: + ldr s0, [x16] + add x16, x16, x13 + ldr s4, [x12] + add x12, x12, x10 + fmla v5.4s, v0.4s, v4.s[0] + subs x9, x9, #1 + bne LoopK + Write1: + str s5, [x20], #4 + + sub x16, x16, x21 // ptr back x13 * k + add x16, x16, #4 // ptr add 4B + subs x11, x11, #1 + beq EndLoopC + b LoopC + + EndLoopC: + add x14, x14, #4 + subs x15, x15, #1 + beq EndLoopN + b LoopN + EndLoopN: + subs x3, x3, #1 + beq EndLoopM + add x0, x0, x21 + b LoopM + EndLoopM: + ld1 {v8.4s}, [sp], #16 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/PostFuncBiasReluC4.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/PostFuncBiasReluC4.S new file mode 100644 index 0000000000000000000000000000000000000000..8f51cc89c0f51647b37105997ec86df11d3ae4cb --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/PostFuncBiasReluC4.S @@ -0,0 +1,316 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +//void PostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t oc4div, size_t oc4mod, +// size_t plane_size, size_t plane_stride, size_t relu_type); +// x0 dst x1 srx x2 bias +// w3 oc4div w4 oc4mod w5 plane_size +// x6 plane_stride x7 relu_type + +// v0 ~ v7 value +// v16 bias data +// x12 oc_stride +// x14 x15 write loop tmp buf +// v26 relu6 #6; v27 relu #0 +// w10 oc4 loop control +// w13 hw loop control + + +asm_function WinogradPostFuncBiasReluC4 + + movi v26.4s, #6 + scvtf v26.4s, v26.4s + dup v27.4s, wzr + + mov x10, #4 + add x12, x3, x4 + mul x12, x12, x10 + + mov w10, #0 + +Loop_C4: + cmp w10, w3 + beq Loop_C1 + mov x15, #4 + mul x14, x10, x15 + add x15, x0, x14 + add w10, w10, #4 + mov w13, w5 + ld1 {v16.4s}, [x2], #16 + +Loop_8x4: + cmp w13, #8 + blt Loop_4x4 + sub w13, w13, #8 + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64 + + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v16.4s + fadd v2.4s, v2.4s, v16.4s + fadd v3.4s, v3.4s, v16.4s + fadd v4.4s, v4.4s, v16.4s + fadd v5.4s, v5.4s, v16.4s + fadd v6.4s, v6.4s, v16.4s + fadd v7.4s, v7.4s, v16.4s + + cmp x7, #3 + beq Relu6_8x4 + cmp x7, #1 + beq Relu_8x4 + b Write_8x4 +Relu6_8x4: + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmin v2.4s, v2.4s, v26.4s + fmin v3.4s, v3.4s, v26.4s + fmin v4.4s, v4.4s, v26.4s + fmin v5.4s, v5.4s, v26.4s + fmin v6.4s, v6.4s, v26.4s + fmin v7.4s, v7.4s, v26.4s +Relu_8x4: + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + fmax v2.4s, v2.4s, v27.4s + fmax v3.4s, v3.4s, v27.4s + fmax v4.4s, v4.4s, v27.4s + fmax v5.4s, v5.4s, v27.4s + fmax v6.4s, v6.4s, v27.4s + fmax v7.4s, v7.4s, v27.4s +Write_8x4: + st1 {v0.4s}, [x15], x12 + st1 {v1.4s}, [x15], x12 + st1 {v2.4s}, [x15], x12 + st1 {v3.4s}, [x15], x12 + st1 {v4.4s}, [x15], x12 + st1 {v5.4s}, [x15], x12 + st1 {v6.4s}, [x15], x12 + st1 {v7.4s}, [x15], x12 + b Loop_8x4 + +Loop_4x4: + cmp w13, #4 + blt Loop_1x4 + sub w13, w13, #4 + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v16.4s + fadd v2.4s, v2.4s, v16.4s + fadd v3.4s, v3.4s, v16.4s + cmp x7, #3 + beq Relu6_4x4 + cmp x7, #1 + beq Relu_4x4 + b Write_4x4 +Relu6_4x4: + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmin v2.4s, v2.4s, v26.4s + fmin v3.4s, v3.4s, v26.4s +Relu_4x4: + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + fmax v2.4s, v2.4s, v27.4s + fmax v3.4s, v3.4s, v27.4s +Write_4x4: + st1 {v0.4s}, [x15], x12 + st1 {v1.4s}, [x15], x12 + st1 {v2.4s}, [x15], x12 + st1 {v3.4s}, [x15], x12 + +Loop_1x4: + cmp x7, #3 + beq Relu6_1x4 + cmp x7, #1 + beq Relu_1x4 + b Write_1x4 +Relu6_1x4: + cmp w13, #0 + beq HW_Add + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + fmin v0.4s, v0.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + st1 {v0.4s}, [x15], x12 + b Relu6_1x4 +Relu_1x4: + cmp w13, #0 + beq HW_Add + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + fmax v0.4s, v0.4s, v27.4s + st1 {v0.4s}, [x15], x12 + b Relu_1x4 +Write_1x4: + cmp w13, #0 + beq HW_Add + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + st1 {v0.4s}, [x15], x12 + b Write_1x4 + +HW_Add: + add x1, x1, x6 + b Loop_C4 + +Loop_C1: + cmp x4, #0 + beq End + mov w13, w5 + ld1 {v16.4s}, [x2], #16 + mov x15, #4 + mul x14, x10, x15 + add x0, x0, x14 + + cmp x4, #1 + beq Loop_C1_1 + cmp x4, #2 + beq Loop_C1_2 + cmp x4, #3 + beq Loop_C1_3 + +Loop_C1_1: + cmp x7, #3 + beq Loop_C1_1_Relu6 + cmp x7, #1 + beq Loop_C1_1_Relu + b Loop_C1_1_Write +Loop_C1_1_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + fmin v0.4s, v0.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + str s0, [x0] + add x0, x0, x12 + b Loop_C1_1_Relu6 +Loop_C1_1_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + fmax v0.4s, v0.4s, v27.4s + str s0, [x0] + add x0, x0, x12 + b Loop_C1_1_Relu +Loop_C1_1_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + str s0, [x0] + add x0, x0, x12 + b Loop_C1_1_Write + +Loop_C1_2: + cmp x7, #3 + beq Loop_C1_2_Relu6 + cmp x7, #1 + beq Loop_C1_2_Relu + b Loop_C1_2_Write +Loop_C1_2_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + fmin v0.4s, v0.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x12 + b Loop_C1_2_Relu6 +Loop_C1_2_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + fmax v0.4s, v0.4s, v27.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x12 + b Loop_C1_2_Relu +Loop_C1_2_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x12 + b Loop_C1_2_Write + +Loop_C1_3: + add x15, x0, #8 + cmp x7, #3 + beq Loop_C1_3_Relu6 + cmp x7, #1 + beq Loop_C1_3_Relu + b Loop_C1_3_Write +Loop_C1_3_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + fmin v0.4s, v0.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x12 + st1 {v0.s}[2], [x15], x12 + b Loop_C1_3_Relu6 +Loop_C1_3_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + fmax v0.4s, v0.4s, v27.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x12 + st1 {v0.s}[2], [x15], x12 + b Loop_C1_3_Relu +Loop_C1_3_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s}, [x1], #16 + fadd v0.4s, v0.4s, v16.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x12 + st1 {v0.s}[2], [x15], x12 + b Loop_C1_3_Write + +End: + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/PostFuncBiasReluC8.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/PostFuncBiasReluC8.S new file mode 100644 index 0000000000000000000000000000000000000000..1392ab4a5449daa0a339127125afa214f32726ad --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/PostFuncBiasReluC8.S @@ -0,0 +1,553 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +//void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div,size_t oc8mod +// size_t plane_size, size_t stride, int relu_type); +// x0 dst x1 srx x2 bias +// x3 oc8div x4 oc8mod x5 plane_size +// x6 stride x7 relu_type + +// v0 ~ v15 value +// v16 v17 bias data +// x14 x15 weite loop tmp buf +// x16 relu6 #6; x17 relu #0 +// w10 oc8 loop control +// w13 hw loop control + +asm_function PostFuncBiasReluC8 + sub sp, sp, #128 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + + movi v26.4s, #6 + scvtf v26.4s, v26.4s + dup v27.4s, wzr + mov w10, #0 + +Loop_C8: + cmp w10, w3 + beq Loop_C1 + mov x15, #4 + mul x14, x10, x15 + add x15, x0, x14 + add w10, w10, #8 + mov w13, w5 + ld1 {v16.4s, v17.4s}, [x2], #32 + +Loop_8x8: + cmp w13, #8 + blt Loop_4x8 + sub w13, w13, #8 + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x1], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x1], #64 + + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fadd v2.4s, v2.4s, v16.4s + fadd v3.4s, v3.4s, v17.4s + fadd v4.4s, v4.4s, v16.4s + fadd v5.4s, v5.4s, v17.4s + fadd v6.4s, v6.4s, v16.4s + fadd v7.4s, v7.4s, v17.4s + fadd v8.4s, v8.4s, v16.4s + fadd v9.4s, v9.4s, v17.4s + fadd v10.4s, v10.4s, v16.4s + fadd v11.4s, v11.4s, v17.4s + fadd v12.4s, v12.4s, v16.4s + fadd v13.4s, v13.4s, v17.4s + fadd v14.4s, v14.4s, v16.4s + fadd v15.4s, v15.4s, v17.4s + + cmp x7, #3 + beq Relu6_8x8 + cmp x7, #1 + beq Relu_8x8 + b Write_8x8 +Relu6_8x8: + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmin v2.4s, v2.4s, v26.4s + fmin v3.4s, v3.4s, v26.4s + fmin v4.4s, v4.4s, v26.4s + fmin v5.4s, v5.4s, v26.4s + fmin v6.4s, v6.4s, v26.4s + fmin v7.4s, v7.4s, v26.4s + fmin v8.4s, v8.4s, v26.4s + fmin v9.4s, v9.4s, v26.4s + fmin v10.4s, v10.4s, v26.4s + fmin v11.4s, v11.4s, v26.4s + fmin v12.4s, v12.4s, v26.4s + fmin v13.4s, v13.4s, v26.4s + fmin v14.4s, v14.4s, v26.4s + fmin v15.4s, v15.4s, v26.4s +Relu_8x8: + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + fmax v2.4s, v2.4s, v27.4s + fmax v3.4s, v3.4s, v27.4s + fmax v4.4s, v4.4s, v27.4s + fmax v5.4s, v5.4s, v27.4s + fmax v6.4s, v6.4s, v27.4s + fmax v7.4s, v7.4s, v27.4s + fmax v8.4s, v8.4s, v27.4s + fmax v9.4s, v9.4s, v27.4s + fmax v10.4s, v10.4s, v27.4s + fmax v11.4s, v11.4s, v27.4s + fmax v12.4s, v12.4s, v27.4s + fmax v13.4s, v13.4s, v27.4s + fmax v14.4s, v14.4s, v27.4s + fmax v15.4s, v15.4s, v27.4s +Write_8x8: + st1 {v0.4s, v1.4s}, [x15], x6 + st1 {v2.4s, v3.4s}, [x15], x6 + st1 {v4.4s, v5.4s}, [x15], x6 + st1 {v6.4s, v7.4s}, [x15], x6 + st1 {v8.4s, v9.4s}, [x15], x6 + st1 {v10.4s, v11.4s}, [x15], x6 + st1 {v12.4s, v13.4s}, [x15], x6 + st1 {v14.4s, v15.4s}, [x15], x6 + b Loop_8x8 + +Loop_4x8: + cmp w13, #4 + blt Loop_1x8 + sub w13, w13, #4 + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64 + + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fadd v2.4s, v2.4s, v16.4s + fadd v3.4s, v3.4s, v17.4s + fadd v4.4s, v4.4s, v16.4s + fadd v5.4s, v5.4s, v17.4s + fadd v6.4s, v6.4s, v16.4s + fadd v7.4s, v7.4s, v17.4s + + cmp x7, #3 + beq Relu6_4x8 + cmp x7, #1 + beq Relu_4x8 + b Write_4x8 +Relu6_4x8: + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmin v2.4s, v2.4s, v26.4s + fmin v3.4s, v3.4s, v26.4s + fmin v4.4s, v4.4s, v26.4s + fmin v5.4s, v5.4s, v26.4s + fmin v6.4s, v6.4s, v26.4s + fmin v7.4s, v7.4s, v26.4s +Relu_4x8: + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + fmax v2.4s, v2.4s, v27.4s + fmax v3.4s, v3.4s, v27.4s + fmax v4.4s, v4.4s, v27.4s + fmax v5.4s, v5.4s, v27.4s + fmax v6.4s, v6.4s, v27.4s + fmax v7.4s, v7.4s, v27.4s +Write_4x8: + st1 {v0.4s, v1.4s}, [x15], x6 + st1 {v2.4s, v3.4s}, [x15], x6 + st1 {v4.4s, v5.4s}, [x15], x6 + st1 {v6.4s, v7.4s}, [x15], x6 + +Loop_1x8: + cmp x7, #3 + beq Relu6_1x8 + cmp x7, #1 + beq Relu_1x8 + b Write_1x8 +Relu6_1x8: + cmp w13, #0 + beq Loop_C8 + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + st1 {v0.4s, v1.4s}, [x15], x6 + b Relu6_1x8 +Relu_1x8: + cmp w13, #0 + beq Loop_C8 + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + st1 {v0.4s, v1.4s}, [x15], x6 + b Relu_1x8 +Write_1x8: + cmp w13, #0 + beq Loop_C8 + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + st1 {v0.4s, v1.4s}, [x15], x6 + b Write_1x8 + + +Loop_C1: + cmp x4, #0 + beq End + mov w13, w5 + ld1 {v16.4s, v17.4s}, [x2], #32 + mov x15, #4 + mul x14, x10, x15 + add x0, x0, x14 + + cmp x4, #1 + beq Loop_C1_1 + cmp x4, #2 + beq Loop_C1_2 + cmp x4, #3 + beq Loop_C1_3 + cmp x4, #4 + beq Loop_C1_4 + cmp x4, #5 + beq Loop_C1_5 + cmp x4, #6 + beq Loop_C1_6 + cmp x4, #7 + beq Loop_C1_7 + +Loop_C1_1: + cmp x7, #3 + beq Loop_C1_1_Relu6 + cmp x7, #1 + beq Loop_C1_1_Relu + b Loop_C1_1_Write +Loop_C1_1_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fmin v0.4s, v0.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + str s0, [x0] + add x0, x0, x6 + b Loop_C1_1_Relu6 +Loop_C1_1_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fmax v0.4s, v0.4s, v27.4s + str s0, [x0] + add x0, x0, x6 + b Loop_C1_1_Relu +Loop_C1_1_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + str s0, [x0] + add x0, x0, x6 + b Loop_C1_1_Write + +Loop_C1_2: + cmp x7, #3 + beq Loop_C1_2_Relu6 + cmp x7, #1 + beq Loop_C1_2_Relu + b Loop_C1_2_Write +Loop_C1_2_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fmin v0.4s, v0.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x6 + b Loop_C1_2_Relu6 +Loop_C1_2_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fmax v0.4s, v0.4s, v27.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x6 + b Loop_C1_2_Relu +Loop_C1_2_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x6 + b Loop_C1_2_Write + + +Loop_C1_3: + add x15, x0, #8 + cmp x7, #3 + beq Loop_C1_3_Relu6 + cmp x7, #1 + beq Loop_C1_3_Relu + b Loop_C1_3_Write +Loop_C1_3_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fmin v0.4s, v0.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x6 + st1 {v0.s}[2], [x15], x6 + b Loop_C1_3_Relu6 +Loop_C1_3_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fmax v0.4s, v0.4s, v27.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x6 + st1 {v0.s}[2], [x15], x6 + b Loop_C1_3_Relu +Loop_C1_3_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + dup s1, v0.s[1] + stp s0, s1, [x0] + add x0, x0, x6 + st1 {v0.s}[2], [x15], x6 + b Loop_C1_3_Write + +Loop_C1_4: + cmp x7, #3 + beq Loop_C1_4_Relu6 + cmp x7, #1 + beq Loop_C1_4_Relu + b Loop_C1_4_Write +Loop_C1_4_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fmin v0.4s, v0.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + st1 {v0.4s}, [x0], x6 + b Loop_C1_4_Relu6 +Loop_C1_4_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fmax v0.4s, v0.4s, v27.4s + st1 {v0.4s}, [x0], x6 + b Loop_C1_4_Relu6 +Loop_C1_4_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + st1 {v0.4s}, [x0], x6 + b Loop_C1_4_Write + +Loop_C1_5: + add x15, x0, #16 + cmp x7, #3 + beq Loop_C1_5_Relu6 + cmp x7, #1 + beq Loop_C1_5_Relu + b Loop_C1_5_Write +Loop_C1_5_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + st1 {v0.4s}, [x0], x6 + str s1, [x15] + add x15, x15, x6 + b Loop_C1_5_Relu6 +Loop_C1_5_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + st1 {v0.4s}, [x0], x6 + str s1, [x15] + add x15, x15, x6 + b Loop_C1_5_Relu +Loop_C1_5_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + st1 {v0.4s}, [x0], x6 + str s1, [x15] + add x15, x15, x6 + b Loop_C1_5_Write + +Loop_C1_6: + add x15, x0, #16 + cmp x7, #3 + beq Loop_C1_6_Relu6 + cmp x7, #1 + beq Loop_C1_6_Relu + b Loop_C1_6_Write +Loop_C1_6_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + st1 {v0.4s}, [x0], x6 + dup s0, v1.s[1] + stp s1, s0, [x15] + add x15, x15, x6 + b Loop_C1_6_Relu6 +Loop_C1_6_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + st1 {v0.4s}, [x0], x6 + dup s0, v1.s[1] + stp s1, s0, [x15] + add x15, x15, x6 + b Loop_C1_6_Relu +Loop_C1_6_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + st1 {v0.4s}, [x0], x6 + dup s0, v1.s[1] + stp s1, s0, [x15] + add x15, x15, x6 + b Loop_C1_6_Write + +Loop_C1_7: + add x15, x0, #16 + add x14, x0, #24 + cmp x7, #3 + beq Loop_C1_7_Relu6 + cmp x7, #1 + beq Loop_C1_7_Relu + b Loop_C1_7_Write +Loop_C1_7_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fmin v0.4s, v0.4s, v26.4s + fmin v1.4s, v1.4s, v26.4s + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + st1 {v0.4s}, [x0], x6 + dup s0, v1.s[1] + stp s1, s0, [x15] + add x15, x15, x6 + st1 {v1.s}[2], [x14], x6 + b Loop_C1_7_Relu6 +Loop_C1_7_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fmax v0.4s, v0.4s, v27.4s + fmax v1.4s, v1.4s, v27.4s + st1 {v0.4s}, [x0], x6 + dup s0, v1.s[1] + stp s1, s0, [x15] + add x15, x15, x6 + st1 {v1.s}[2], [x14], x6 + b Loop_C1_7_Relu +Loop_C1_7_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4s, v1.4s}, [x1], #32 + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + st1 {v0.4s}, [x0], x6 + dup s0, v1.s[1] + stp s1, s0, [x15] + add x15, x15, x6 + st1 {v1.s}[2], [x14], x6 + b Loop_C1_7_Write + +End: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/PostFuncInt8C4Neon64.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/PostFuncInt8C4Neon64.S new file mode 100644 index 0000000000000000000000000000000000000000..a240b64da72acafbb4bc6dda6e69a0c70c456a4c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/PostFuncInt8C4Neon64.S @@ -0,0 +1,259 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +//void PostFuncInt8C4Neon64(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc4div, size_t oc4res, +// size_t plane, size_t stride, int32_t multiplier, int32_t left_shift, int32_t right_shift, +// int32_t zp, int32_t mini, int32_t maxi); +// x0 in +// x1 bias +// x2 out +// x3 oc4div +// x4 oc4res +// x5 plane +// x6 stride +// x7 multiplier +// x8 left_shift +// x9 right_shift +// x10 zp +// x11 mini +// x12 maxi + +// v0 ~ v15 value +// x24 x25 write loop tmp buf + + +// v16 bias data + +// v26 multiplier +// v27 left_shift +// v28 right_shift +// v29 zp +// v30 min +// v31 max + +// w15 oc4 loop control +// w16 hw loop control + +asm_function PostFuncInt8C4Neon64 + sub sp, sp, #16 + stp x24, x25, [sp] + + ldr w8, [sp, #16] + ldr w9, [sp, #24] + ldr w10, [sp, #32] + ldr w11, [sp, #40] + ldr w12, [sp, #48] + ldr w13, [sp, #56] + + dup v26.4s, w7 + dup v27.4s, w8 + dup v28.4s, w9 + dup v29.4s, w10 + dup v30.4s, w11 + dup v31.4s, w12 + + mov x15, #0 + +Loop_C4: + cmp x15, x3 + beq Loop_C1 + mov x25, #4 + mul x24, x15, x25 + add x25, x2, x24 + add w15, w15, #4 + mov w16, w5 + ld1 {v16.4s}, [x1], #16 + +Loop_4x4: + cmp x16, #4 + blt Loop_1x4 + sub x16, x16, #4 + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64 + + add v0.4s, v0.4s, v16.4s + add v1.4s, v1.4s, v16.4s + add v2.4s, v2.4s, v16.4s + add v3.4s, v3.4s, v16.4s + sqshl v0.4s, v0.4s, v27.4s + sqshl v1.4s, v1.4s, v27.4s + sqshl v2.4s, v2.4s, v27.4s + sqshl v3.4s, v3.4s, v27.4s + sqrdmulh v0.4s, v0.4s, v26.4s + sqrdmulh v1.4s, v1.4s, v26.4s + sqrdmulh v2.4s, v2.4s, v26.4s + sqrdmulh v3.4s, v3.4s, v26.4s + and v4.16b, v28.16b, v0.16b + and v5.16b, v28.16b, v1.16b + and v6.16b, v28.16b, v2.16b + and v7.16b, v28.16b, v3.16b + sshr v4.4s, v4.4s, #31 + sshr v5.4s, v5.4s, #31 + sshr v6.4s, v6.4s, #31 + sshr v7.4s, v7.4s, #31 + sqadd v0.4s, v0.4s, v4.4s + sqadd v1.4s, v1.4s, v5.4s + sqadd v2.4s, v2.4s, v6.4s + sqadd v3.4s, v3.4s, v7.4s + srshl v0.4s, v0.4s, v28.4s + srshl v1.4s, v1.4s, v28.4s + srshl v2.4s, v2.4s, v28.4s + srshl v3.4s, v3.4s, v28.4s + add v0.4s, v0.4s, v29.4s + add v1.4s, v1.4s, v29.4s + add v2.4s, v2.4s, v29.4s + add v3.4s, v3.4s, v29.4s + smax v0.4s, v0.4s, v30.4s + smax v1.4s, v1.4s, v30.4s + smax v2.4s, v2.4s, v30.4s + smax v3.4s, v3.4s, v30.4s + smin v0.4s, v0.4s, v31.4s + smin v1.4s, v1.4s, v31.4s + smin v2.4s, v2.4s, v31.4s + smin v3.4s, v3.4s, v31.4s + sqxtn v4.4h, v0.4s + sqxtn v5.4h, v1.4s + sqxtn v6.4h, v2.4s + sqxtn v7.4h, v3.4s + sqxtn v0.8b, v4.8h + sqxtn v1.8b, v5.8h + sqxtn v2.8b, v6.8h + sqxtn v3.8b, v7.8h + + st1 {v0.s}[0], [x25], x6 + st1 {v1.s}[0], [x25], x6 + st1 {v2.s}[0], [x25], x6 + st1 {v3.s}[0], [x25], x6 + b Loop_4x4 + + +Loop_1x4: + cmp x16, #0 + beq Loop_C4 + sub x16, x16, #1 + ld1 {v0.4s}, [x0], #16 + + add v0.4s, v0.4s, v16.4s + sqshl v0.4s, v0.4s, v27.4s + sqrdmulh v0.4s, v0.4s, v26.4s + and v2.16b, v28.16b, v0.16b + sshr v2.4s, v2.4s, #31 + sqadd v0.4s, v0.4s, v2.4s + srshl v0.4s, v0.4s, v28.4s + add v0.4s, v0.4s, v29.4s + smax v0.4s, v0.4s, v30.4s + smin v0.4s, v0.4s, v31.4s + sqxtn v1.4h, v0.4s + sqxtn v0.8b, v1.8h + + st1 {v0.s}[0], [x25], x6 + b Loop_1x4 + +Loop_C1: + cmp x4, #0 + beq End + mov x16, x5 + ld1 {v16.4s}, [x1], #16 + mov x25, #4 + mul x24, x15, x25 + add x25, x2, x24 + add x24, x25, #2 + + cmp x4, #1 + beq Loop_C1_1 + cmp x4, #2 + beq Loop_C1_2 + cmp x4, #3 + beq Loop_C1_3 + +Loop_C1_1: + cmp x16, #0 + beq End + sub x16, x16, #1 + ld1 {v0.4s}, [x0], #16 + + add v0.4s, v0.4s, v16.4s + sqshl v0.4s, v0.4s, v27.4s + sqrdmulh v0.4s, v0.4s, v26.4s + and v2.16b, v28.16b, v0.16b + sshr v2.4s, v2.4s, #31 + sqadd v0.4s, v0.4s, v2.4s + srshl v0.4s, v0.4s, v28.4s + add v0.4s, v0.4s, v29.4s + smax v0.4s, v0.4s, v30.4s + smin v0.4s, v0.4s, v31.4s + sqxtn v1.4h, v0.4s + sqxtn v0.8b, v1.8h + + st1 {v0.b}[0], [x25], x6 + b Loop_C1_1 + + +Loop_C1_2: + cmp x16, #0 + beq End + sub x16, x16, #1 + ld1 {v0.4s}, [x0], #16 + + add v0.4s, v0.4s, v16.4s + sqshl v0.4s, v0.4s, v27.4s + sqrdmulh v0.4s, v0.4s, v26.4s + and v2.16b, v28.16b, v0.16b + sshr v2.4s, v2.4s, #31 + sqadd v0.4s, v0.4s, v2.4s + srshl v0.4s, v0.4s, v28.4s + add v0.4s, v0.4s, v29.4s + smax v0.4s, v0.4s, v30.4s + smin v0.4s, v0.4s, v31.4s + sqxtn v1.4h, v0.4s + sqxtn v0.8b, v1.8h + + st1 {v0.h}[0], [x25], x6 + b Loop_C1_2 + + +Loop_C1_3: + cmp x16, #0 + beq End + sub x16, x16, #1 + ld1 {v0.4s}, [x0], #16 + + add v0.4s, v0.4s, v16.4s + sqshl v0.4s, v0.4s, v27.4s + sqrdmulh v0.4s, v0.4s, v26.4s + and v2.16b, v28.16b, v0.16b + sshr v2.4s, v2.4s, #31 + sqadd v0.4s, v0.4s, v2.4s + srshl v0.4s, v0.4s, v28.4s + add v0.4s, v0.4s, v29.4s + smax v0.4s, v0.4s, v30.4s + smin v0.4s, v0.4s, v31.4s + sqxtn v1.4h, v0.4s + sqxtn v0.8b, v1.8h + + st1 {v0.h}[0], [x25], x6 + st1 {v0.b}[2], [x24], x6 + b Loop_C1_3 + + +End: + ldp x24, x25, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/PreSum4x16Int8Peroc.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/PreSum4x16Int8Peroc.S new file mode 100644 index 0000000000000000000000000000000000000000..e53b408e3580090d7dc04a2aa2d15f028c77a350 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/PreSum4x16Int8Peroc.S @@ -0,0 +1,140 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +//void PreSum4x16Int8Peroc(const int8_t *src, int32_t *sum, int32_t *zp, size_t hw4, size_t ic16, int32_t oc_div4, +// size_t oc_res4, size_t stride); + +// x0 src +// x1 sum +// x2 zp +// w3 hw4 +// w4 ic16 +// w5 oc_div4 +// w6 oc_res4 +// w7 stride + +asm_function PreSum4x16Int8Peroc + mov w8, #0 + +RowLoop: + cmp w8, w3 + beq End + add w8, w8, #4 + dup v16.4s, wzr + mov w9, #0 + mov x16, x2 + +Sum: + cmp w9, w4 + beq Mul + add w9, w9, #16 + + ld1 {v0.16b}, [x0], #16 + ld1 {v1.16b}, [x0], #16 + ld1 {v2.16b}, [x0], #16 + ld1 {v3.16b}, [x0], #16 + + saddlp v4.8h, v0.16b + saddlp v5.8h, v1.16b + saddlp v6.8h, v2.16b + saddlp v7.8h, v3.16b + saddlp v0.4S, v4.8h + saddlp v1.4S, v5.8h + saddlp v2.4S, v6.8h + saddlp v3.4S, v7.8h + addv s4, v0.4S + addv s5, v1.4S + addv s6, v2.4S + addv s7, v3.4S + mov v0.s[0], v4.s[0] + mov v0.s[1], v5.s[0] + mov v0.s[2], v6.s[0] + mov v0.s[3], v7.s[0] + + add v16.4s, v16.4s, v0.4s + b Sum + +Mul: + mov x12, x1 + add x1, x1, #64 + mov w9, #0 + + dup v1.4s, v16.s[0] + dup v2.4s, v16.s[1] + dup v3.4s, v16.s[2] + dup v4.4s, v16.s[3] + +WriteOc4: + cmp w9, w5 + beq OcRes4 + add w9, w9, #4 + ld1 {v5.4s}, [x16], #16 + + mul v16.4s, v5.4s, v1.4s + mul v17.4s, v5.4s, v2.4s + mul v18.4s, v5.4s, v3.4s + mul v19.4s, v5.4s, v4.4s + st1 {v16.4s}, [x12], #16 + st1 {v17.4s}, [x12], #16 + st1 {v18.4s}, [x12], #16 + st1 {v19.4s}, [x12], #16 + add x12, x12, x7 + b WriteOc4 + +OcRes4: + cmp w6, #0 + beq RowLoop + dup v15.4s, wzr + cmp w6, #1 + beq OcRes4_1 + cmp w6, #2 + beq OcRes4_2 + cmp w6, #3 + beq OcRes4_3 + +OcRes4_1: + ld1 {v15.s}[0], [x16] + b OcRes4End + +OcRes4_2: + ld1 {v15.d}[0], [x16] + b OcRes4End + +OcRes4_3: + ld1 {v15.d}[0], [x16] + add x16, x16, #8 + ld1 {v15.s}[2], [x16] + b OcRes4End + +OcRes4End: + mul v16.4s, v15.4s, v1.4s + mul v17.4s, v15.4s, v2.4s + mul v18.4s, v15.4s, v3.4s + mul v19.4s, v15.4s, v4.4s + st1 {v16.4s}, [x12], #16 + st1 {v17.4s}, [x12], #16 + st1 {v18.4s}, [x12], #16 + st1 {v19.4s}, [x12], #16 + b RowLoop + +End: + ret +#endif \ No newline at end of file diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_model_multi.h b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/PreSum4x16Int8Pert.S similarity index 37% rename from mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_model_multi.h rename to mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/PreSum4x16Int8Pert.S index 7995aa579263f184e8ed8146c58dc38fef19e2c9..1590b007407fae61142bb98d005476eac8e7a1f4 100644 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_model_multi.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/PreSum4x16Int8Pert.S @@ -13,41 +13,69 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_CXX_API_ACL_MODEL_MULTI_H -#define MINDSPORE_CCSRC_CXX_API_ACL_MODEL_MULTI_H - -#include "cxx_api/model/acl/acl_model.h" -#include -#include -#include -#include - -namespace mindspore { -namespace compile { -class MsBackend; -class FinalVM; -} // namespace compile - -class AclModelMulti : public AclModel { - public: - AclModelMulti() : AclModel(), is_multi_graph_(std::nullopt) {} - ~AclModelMulti() = default; - - Status Build() override; - Status Predict(const std::vector &inputs, std::vector *outputs) override; - - std::vector GetInputs() override; - std::vector GetOutputs() override; - - private: - void SetInputs(); - void SetOutput(); - - std::optional is_multi_graph_; - std::shared_ptr backend_; - std::shared_ptr vm_; - std::vector inputs_ = {}; - std::vector outputs_ = {}; -}; -} // namespace mindspore -#endif // MINDSPORE_CCSRC_CXX_API_ACL_MODEL_MULTI_H +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void PreSum4x16Int8Pert(const int8_t *src, int32_t *dst, size_t row4, size_t col16, int32_t filter_zp); + +// x0 src +// x1 dst +// w2 row4 +// w3 co16 +// w4 filter_zp + +asm_function PreSum4x16Int8Pert + dup v17.4s, w4 + mov w5, #0 + +RowLoop: + cmp w5, w2 + beq End + add w5, w5, #4 + dup v16.4s, wzr + mov w6, #0 + +CalLoop: + cmp w6, w3 + beq Write + add w6, w6, #16 + + ld1 {v0.16b}, [x0], #16 + ld1 {v1.16b}, [x0], #16 + ld1 {v2.16b}, [x0], #16 + ld1 {v3.16b}, [x0], #16 + + saddlp v4.8h, v0.16b + saddlp v5.8h, v1.16b + saddlp v6.8h, v2.16b + saddlp v7.8h, v3.16b + + saddlp v0.4S, v4.8h + saddlp v1.4S, v5.8h + saddlp v2.4S, v6.8h + saddlp v3.4S, v7.8h + + addv s4, v0.4S + addv s5, v1.4S + addv s6, v2.4S + addv s7, v3.4S + + mov v0.s[0], v4.s[0] + mov v0.s[1], v5.s[0] + mov v0.s[2], v6.s[0] + mov v0.s[3], v7.s[0] + + add v16.4s, v16.4s, v0.4s + b CalLoop + +Write: + mul v16.4s, v16.4s, v17.4s + st1 {v16.4s}, [x1], #16 + beq RowLoop + +End: + ret +#endif \ No newline at end of file diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/SPMM8x8Fp32.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/SPMM8x8Fp32.S new file mode 100644 index 0000000000000000000000000000000000000000..614d83f8fcc363f8b66f254836fd0499678d06a3 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/SPMM8x8Fp32.S @@ -0,0 +1,294 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void SPMM8x8Fp32(const float *a, const float *b, const uint32_t *nnz, const size_t *dmap, float *c, +// const float *bias, ActType act_type, size_t out_stride); +// x0: a +// x1: b +// x2: nnz +// x3: dmap +// x4: c +// x5: bias +// w6: act_type +// x7: out_stride + +// wdata tmp w8 +// loop_oc_count w9 +// loop_nnz_count w10 +// dmap tmp w11 +// a_ptr +// 8 x 1 fp32 A v0-v1 +// fp32 B-value v2 +// uint32 B-NNZ x9 +// uint32 B-INDEX x10 +// 4 MIN v3 +// 4 MAX v4 +// 2 vacc v5-v6 +// 8 x 8 fp32 C v16-v31 + +// v16[0] v18[0] v20[0] v22[0] v24[0] v26[0] v28[0] v30[0] +// v16[1] v18[1] v20[1] v22[1] v24[1] v26[1] v28[1] v30[1] +// v16[2] v18[2] v20[2] v22[2] v24[2] v26[2] v28[2] v30[2] +// v16[3] v18[3] v20[3] v22[3] v24[3] v26[3] v28[3] v30[3] +// v17[0] v19[0] v21[0] v23[0] v25[0] v27[0] v29[0] v31[0] +// v17[1] v19[1] v21[1] v23[1] v25[1] v27[1] v29[1] v31[1] +// v17[2] v19[2] v21[2] v23[2] v25[2] v27[2] v29[2] v31[2] +// v17[3] v19[3] v21[3] v23[3] v25[3] v27[3] v29[3] v31[3] + +asm_function SPMM8x8Fp32 + sub sp, sp, #144 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + + // init output with bias + ldr w8, [x5], #4 + dup v16.4s, w8 + dup v17.4s, w8 + ldr w8, [x5], #4 + dup v18.4s, w8 + dup v19.4s, w8 + ldr w8, [x5], #4 + dup v20.4s, w8 + dup v21.4s, w8 + ldr w8, [x5], #4 + dup v22.4s, w8 + dup v23.4s, w8 + ldr w8, [x5], #4 + dup v24.4s, w8 + dup v25.4s, w8 + ldr w8, [x5], #4 + dup v26.4s, w8 + dup v27.4s, w8 + ldr w8, [x5], #4 + dup v28.4s, w8 + dup v29.4s, w8 + ldr w8, [x5] + dup v30.4s, w8 + dup v31.4s, w8 + + // OC 0 + ldr w10, [x2], #4 // load nnz + cmp w10, #0 + beq OC_1 +LOOP_NNZ0: + ldr x11, [x3], #8 // load dmap + add x8, x0, x11 + ld1 {v0.4s, v1.4s}, [x8] // load inputs + ldr w8, [x1], #4 // load weight + dup v2.4s, w8 + // matmul + fmla v16.4s, v0.4s, v2.4s + fmla v17.4s, v1.4s, v2.4s + // loop nnz condition + subs w10, w10, #1 + bgt LOOP_NNZ0 + +OC_1: + ldr w10, [x2], #4 + cmp w10, #0 + beq OC_2 +LOOP_NNZ1: + ldr x11, [x3], #8 + add x8, x0, x11 + ld1 {v0.4s, v1.4s}, [x8] + ldr w8, [x1], #4 + dup v2.4s, w8 + // matmul + fmla v18.4s, v0.4s, v2.4s + fmla v19.4s, v1.4s, v2.4s + // loop nnz condition + subs w10, w10, #1 + bgt LOOP_NNZ1 + +OC_2: + ldr w10, [x2], #4 + cmp w10, #0 + beq OC_3 +LOOP_NNZ2: + ldr x11, [x3], #8 + add x8, x0, x11 + ld1 {v0.4s, v1.4s}, [x8] + ldr w8, [x1], #4 + dup v2.4s, w8 + // matmul + fmla v20.4s, v0.4s, v2.4s + fmla v21.4s, v1.4s, v2.4s + // loop nnz condition + subs w10, w10, #1 + bgt LOOP_NNZ2 + +OC_3: + ldr w10, [x2], #4 + cmp w10, #0 + beq OC_4 +LOOP_NNZ3: + ldr x11, [x3], #8 + add x8, x0, x11 + ld1 {v0.4s, v1.4s}, [x8] + ldr w8, [x1], #4 + dup v2.4s, w8 + // matmul + fmla v22.4s, v0.4s, v2.4s + fmla v23.4s, v1.4s, v2.4s + // loop nnz condition + subs w10, w10, #1 + bgt LOOP_NNZ3 + +OC_4: + ldr w10, [x2], #4 + cmp w10, #0 + beq OC_5 +LOOP_NNZ4: + ldr x11, [x3], #8 + add x8, x0, x11 + ld1 {v0.4s, v1.4s}, [x8] + ldr w8, [x1], #4 + dup v2.4s, w8 + // matmul + fmla v24.4s, v0.4s, v2.4s + fmla v25.4s, v1.4s, v2.4s + // loop nnz condition + subs w10, w10, #1 + bgt LOOP_NNZ4 + +OC_5: + ldr w10, [x2], #4 + cmp w10, #0 + beq OC_6 +LOOP_NNZ5: + ldr x11, [x3], #8 + add x8, x0, x11 + ld1 {v0.4s, v1.4s}, [x8] + ldr w8, [x1], #4 + dup v2.4s, w8 + // matmul + fmla v26.4s, v0.4s, v2.4s + fmla v27.4s, v1.4s, v2.4s + // loop nnz condition + subs w10, w10, #1 + bgt LOOP_NNZ5 + +OC_6: + ldr w10, [x2], #4 + cmp w10, #0 + beq OC_7 +LOOP_NNZ6: + ldr x11, [x3], #8 + add x8, x0, x11 + ld1 {v0.4s, v1.4s}, [x8] + ldr w8, [x1], #4 + dup v2.4s, w8 + // matmul + fmla v28.4s, v0.4s, v2.4s + fmla v29.4s, v1.4s, v2.4s + // loop nnz condition + subs w10, w10, #1 + bgt LOOP_NNZ6 + +OC_7: + ldr w10, [x2], #4 + cmp w10, #0 + beq REORDER_OUT +LOOP_NNZ7: + ldr x11, [x3], #8 + add x8, x0, x11 + ld1 {v0.4s, v1.4s}, [x8] + ldr w8, [x1], #4 + dup v2.4s, w8 + // matmul + fmla v30.4s, v0.4s, v2.4s + fmla v31.4s, v1.4s, v2.4s + // loop nnz condition + subs w10, w10, #1 + bgt LOOP_NNZ7 + + // reorder output +// v16[0] v18[0] v20[0] v22[0] v24[0] v26[0] v28[0] v30[0] +// v16[1] v18[1] v20[1] v22[1] v24[1] v26[1] v28[1] v30[1] +// v16[2] v18[2] v20[2] v22[2] v24[2] v26[2] v28[2] v30[2] +// v16[3] v18[3] v20[3] v22[3] v24[3] v26[3] v28[3] v30[3] +// v17[0] v19[0] v21[0] v23[0] v25[0] v27[0] v29[0] v31[0] +// v17[1] v19[1] v21[1] v23[1] v25[1] v27[1] v29[1] v31[1] +// v17[2] v19[2] v21[2] v23[2] v25[2] v27[2] v29[2] v31[2] +// v17[3] v19[3] v21[3] v23[3] v25[3] v27[3] v29[3] v31[3] + +// v0[0] v0[1] v0[2] v0[3] v1[0] v1[1] v1[2] v1[3] +// v2[0] v2[1] v2[2] v2[3] v3[0] v3[1] v3[2] v3[3] +// v4[0] v4[1] v4[2] v4[3] v5[0] v5[1] v5[2] v5[3] +// v6[0] v6[1] v6[2] v6[3] v7[0] v7[1] v7[2] v7[3] +// v8[0] v8[1] v8[2] v8[3] v9[0] v9[1] v9[2] v9[3] +// v10[0] v10[1] v10[2] v10[3] v11[0] v11[1] v11[2] v11[3] +// v12[0] v12[1] v12[2] v12[3] v13[0] v13[1] v13[2] v13[3] +// v14[0] v14[1] v14[2] v14[3] v15[0] v15[1] v15[2] v15[3] + +REORDER_OUT: + zip1 v1.4s, v16.4s, v18.4s + zip2 v3.4s, v16.4s, v18.4s + zip1 v9.4s, v17.4s, v19.4s + zip2 v11.4s, v17.4s, v19.4s + zip1 v5.4s, v20.4s, v22.4s + zip2 v7.4s, v20.4s, v22.4s + zip1 v13.4s, v21.4s, v23.4s + zip2 v15.4s, v21.4s, v23.4s + trn1 v0.2d, v1.2d, v5.2d + trn2 v2.2d, v1.2d, v5.2d + trn1 v4.2d, v3.2d, v7.2d + trn2 v6.2d, v3.2d, v7.2d + trn1 v8.2d, v9.2d, v13.2d + trn2 v10.2d, v9.2d, v13.2d + trn1 v12.2d, v11.2d, v15.2d + trn2 v14.2d, v11.2d, v15.2d + + zip1 v16.4s, v24.4s, v26.4s + zip2 v17.4s, v24.4s, v26.4s + zip1 v20.4s, v25.4s, v27.4s + zip2 v21.4s, v25.4s, v27.4s + zip1 v18.4s, v28.4s, v30.4s + zip2 v19.4s, v28.4s, v30.4s + zip1 v22.4s, v29.4s, v31.4s + zip2 v23.4s, v29.4s, v31.4s + trn1 v1.2d, v16.2d, v18.2d + trn2 v3.2d, v16.2d, v18.2d + trn1 v5.2d, v17.2d, v19.2d + trn2 v7.2d, v17.2d, v19.2d + trn1 v9.2d, v20.2d, v22.2d + trn2 v11.2d, v20.2d, v22.2d + trn1 v13.2d, v21.2d, v23.2d + trn2 v15.2d, v21.2d, v23.2d + +WRITE_OUT: + st1 {v0.4s, v1.4s}, [x4], x7 + st1 {v2.4s, v3.4s}, [x4], x7 + st1 {v4.4s, v5.4s}, [x4], x7 + st1 {v6.4s, v7.4s}, [x4], x7 + st1 {v8.4s, v9.4s}, [x4], x7 + st1 {v10.4s, v11.4s}, [x4], x7 + st1 {v12.4s, v13.4s}, [x4], x7 + st1 {v14.4s, v15.4s}, [x4] + +End: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/TiledC4MatmulFp32.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/TiledC4MatmulFp32.S new file mode 100644 index 0000000000000000000000000000000000000000..e0efc7b2e9b1eec364ac4f44f6ce3c88ffc3cd63 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/TiledC4MatmulFp32.S @@ -0,0 +1,279 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +asm_function TiledC4MatmulFp32 +//void TiledC4MatmulFp32(float* dst, const float* src, const float* weight, size_t ic4, size_t cal_num, size_t oc4) +//x0: dst +//x1: src +//x2: weight +//x3: cal_num +//x4: ic4 +//x5: oc4 + +sub sp, sp, #128 +st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] +add x9, sp, #64 +st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + +mov x7, #4 //sizeof(float) +mul x3, x3, x7 +mov x7, #64 +mul x10, x4, x7 + +cmp x5, #2 +blt LoopOcHalf +LoopOc: + mov x8, x1 + subs x9, x4, #1 + + add x6, x2, x10 + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x8], #64 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64 + fmul v16.4s, v8.4s, v0.s[0] + fmul v17.4s, v8.4s, v1.s[0] + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64 + fmul v18.4s, v8.4s, v2.s[0] + fmul v19.4s, v8.4s, v3.s[0] + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x6], #64 + fmul v20.4s, v8.4s, v4.s[0] + fmul v21.4s, v8.4s, v5.s[0] + fmul v22.4s, v8.4s, v6.s[0] + fmul v23.4s, v8.4s, v7.s[0] + fmul v24.4s, v12.4s, v0.s[0] + fmul v25.4s, v12.4s, v1.s[0] + fmul v26.4s, v12.4s, v2.s[0] + fmul v27.4s, v12.4s, v3.s[0] + fmul v28.4s, v12.4s, v4.s[0] + fmul v29.4s, v12.4s, v5.s[0] + fmul v30.4s, v12.4s, v6.s[0] + fmul v31.4s, v12.4s, v7.s[0] + + beq LoopIcEnd + LoopIc: + add x2, x2, #128 + prfm pldl1keep, [x2] + prfm pldl1keep, [x2, x10] + sub x2, x2, #128 + prfm pldl1keep, [x8, #128] + prfm pldl1keep, [x8, #192] + + fmla v16.4s, v9.4s, v0.s[1] + fmla v17.4s, v9.4s, v1.s[1] + fmla v18.4s, v9.4s, v2.s[1] + fmla v19.4s, v9.4s, v3.s[1] + fmla v20.4s, v9.4s, v4.s[1] + fmla v21.4s, v9.4s, v5.s[1] + fmla v22.4s, v9.4s, v6.s[1] + fmla v23.4s, v9.4s, v7.s[1] + fmla v24.4s, v13.4s, v0.s[1] + fmla v25.4s, v13.4s, v1.s[1] + fmla v26.4s, v13.4s, v2.s[1] + fmla v27.4s, v13.4s, v3.s[1] + fmla v28.4s, v13.4s, v4.s[1] + fmla v29.4s, v13.4s, v5.s[1] + fmla v30.4s, v13.4s, v6.s[1] + fmla v31.4s, v13.4s, v7.s[1] + + fmla v16.4s, v10.4s, v0.s[2] + fmla v17.4s, v10.4s, v1.s[2] + fmla v18.4s, v10.4s, v2.s[2] + fmla v19.4s, v10.4s, v3.s[2] + fmla v20.4s, v10.4s, v4.s[2] + fmla v21.4s, v10.4s, v5.s[2] + fmla v22.4s, v10.4s, v6.s[2] + fmla v23.4s, v10.4s, v7.s[2] + fmla v24.4s, v14.4s, v0.s[2] + fmla v25.4s, v14.4s, v1.s[2] + fmla v26.4s, v14.4s, v2.s[2] + fmla v27.4s, v14.4s, v3.s[2] + fmla v28.4s, v14.4s, v4.s[2] + fmla v29.4s, v14.4s, v5.s[2] + fmla v30.4s, v14.4s, v6.s[2] + fmla v31.4s, v14.4s, v7.s[2] + + fmla v16.4s, v11.4s, v0.s[3] + fmla v17.4s, v11.4s, v1.s[3] + fmla v18.4s, v11.4s, v2.s[3] + fmla v19.4s, v11.4s, v3.s[3] + fmla v20.4s, v11.4s, v4.s[3] + fmla v21.4s, v11.4s, v5.s[3] + fmla v22.4s, v11.4s, v6.s[3] + fmla v23.4s, v11.4s, v7.s[3] + fmla v24.4s, v15.4s, v0.s[3] + fmla v25.4s, v15.4s, v1.s[3] + fmla v26.4s, v15.4s, v2.s[3] + fmla v27.4s, v15.4s, v3.s[3] + fmla v28.4s, v15.4s, v4.s[3] + fmla v29.4s, v15.4s, v5.s[3] + fmla v30.4s, v15.4s, v6.s[3] + fmla v31.4s, v15.4s, v7.s[3] + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64 + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x8], #64 + fmla v16.4s, v8.4s, v0.s[0] + fmla v17.4s, v8.4s, v1.s[0] + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64 + fmla v18.4s, v8.4s, v2.s[0] + fmla v19.4s, v8.4s, v3.s[0] + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x6], #64 + fmla v20.4s, v8.4s, v4.s[0] + fmla v21.4s, v8.4s, v5.s[0] + fmla v22.4s, v8.4s, v6.s[0] + fmla v23.4s, v8.4s, v7.s[0] + fmla v24.4s, v12.4s, v0.s[0] + fmla v25.4s, v12.4s, v1.s[0] + fmla v26.4s, v12.4s, v2.s[0] + fmla v27.4s, v12.4s, v3.s[0] + fmla v28.4s, v12.4s, v4.s[0] + fmla v29.4s, v12.4s, v5.s[0] + fmla v30.4s, v12.4s, v6.s[0] + fmla v31.4s, v12.4s, v7.s[0] + + subs x9, x9, #1 + bne LoopIc + + LoopIcEnd: + fmla v16.4s, v9.4s, v0.s[1] + fmla v17.4s, v9.4s, v1.s[1] + fmla v18.4s, v9.4s, v2.s[1] + fmla v19.4s, v9.4s, v3.s[1] + fmla v20.4s, v9.4s, v4.s[1] + fmla v21.4s, v9.4s, v5.s[1] + fmla v22.4s, v9.4s, v6.s[1] + fmla v23.4s, v9.4s, v7.s[1] + fmla v24.4s, v13.4s, v0.s[1] + fmla v25.4s, v13.4s, v1.s[1] + fmla v26.4s, v13.4s, v2.s[1] + fmla v27.4s, v13.4s, v3.s[1] + fmla v28.4s, v13.4s, v4.s[1] + fmla v29.4s, v13.4s, v5.s[1] + fmla v30.4s, v13.4s, v6.s[1] + fmla v31.4s, v13.4s, v7.s[1] + + fmla v16.4s, v10.4s, v0.s[2] + fmla v17.4s, v10.4s, v1.s[2] + fmla v18.4s, v10.4s, v2.s[2] + fmla v19.4s, v10.4s, v3.s[2] + fmla v20.4s, v10.4s, v4.s[2] + fmla v21.4s, v10.4s, v5.s[2] + fmla v22.4s, v10.4s, v6.s[2] + fmla v23.4s, v10.4s, v7.s[2] + fmla v24.4s, v14.4s, v0.s[2] + fmla v25.4s, v14.4s, v1.s[2] + fmla v26.4s, v14.4s, v2.s[2] + fmla v27.4s, v14.4s, v3.s[2] + fmla v28.4s, v14.4s, v4.s[2] + fmla v29.4s, v14.4s, v5.s[2] + fmla v30.4s, v14.4s, v6.s[2] + fmla v31.4s, v14.4s, v7.s[2] + + add x7, x0, #64 + + fmla v16.4s, v11.4s, v0.s[3] + fmla v17.4s, v11.4s, v1.s[3] + fmla v18.4s, v11.4s, v2.s[3] + fmla v19.4s, v11.4s, v3.s[3] + fmla v20.4s, v11.4s, v4.s[3] + fmla v21.4s, v11.4s, v5.s[3] + fmla v22.4s, v11.4s, v6.s[3] + fmla v23.4s, v11.4s, v7.s[3] + fmla v24.4s, v15.4s, v0.s[3] + fmla v25.4s, v15.4s, v1.s[3] + fmla v26.4s, v15.4s, v2.s[3] + fmla v27.4s, v15.4s, v3.s[3] + fmla v28.4s, v15.4s, v4.s[3] + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], x3 + fmla v29.4s, v15.4s, v5.s[3] + st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x7], x3 + fmla v30.4s, v15.4s, v6.s[3] + st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], x3 + mov x2, x6 + fmla v31.4s, v15.4s, v7.s[3] + st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x7] + + subs x5, x5, #2 + beq LoopOcEnd + cmp x5, #2 + bge LoopOc + +LoopOcHalf: + mov x8, x1 + mov x9, x4 + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + + LoopIcHalf: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64 + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x8], #64 + fmla v16.4s, v8.4s, v0.s[0] + fmla v17.4s, v8.4s, v1.s[0] + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64 + fmla v18.4s, v8.4s, v2.s[0] + fmla v19.4s, v8.4s, v3.s[0] + fmla v20.4s, v8.4s, v4.s[0] + fmla v21.4s, v8.4s, v5.s[0] + fmla v22.4s, v8.4s, v6.s[0] + fmla v23.4s, v8.4s, v7.s[0] + + fmla v16.4s, v9.4s, v0.s[1] + fmla v17.4s, v9.4s, v1.s[1] + fmla v18.4s, v9.4s, v2.s[1] + fmla v19.4s, v9.4s, v3.s[1] + fmla v20.4s, v9.4s, v4.s[1] + fmla v21.4s, v9.4s, v5.s[1] + fmla v22.4s, v9.4s, v6.s[1] + fmla v23.4s, v9.4s, v7.s[1] + + fmla v16.4s, v10.4s, v0.s[2] + fmla v17.4s, v10.4s, v1.s[2] + fmla v18.4s, v10.4s, v2.s[2] + fmla v19.4s, v10.4s, v3.s[2] + fmla v20.4s, v10.4s, v4.s[2] + fmla v21.4s, v10.4s, v5.s[2] + fmla v22.4s, v10.4s, v6.s[2] + fmla v23.4s, v10.4s, v7.s[2] + + fmla v16.4s, v11.4s, v0.s[3] + fmla v17.4s, v11.4s, v1.s[3] + fmla v18.4s, v11.4s, v2.s[3] + fmla v19.4s, v11.4s, v3.s[3] + fmla v20.4s, v11.4s, v4.s[3] + fmla v21.4s, v11.4s, v5.s[3] + fmla v22.4s, v11.4s, v6.s[3] + fmla v23.4s, v11.4s, v7.s[3] + + subs x9, x9, #1 + bne LoopIcHalf + + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], #64 + st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64 + +LoopOcEnd: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/WinogradTransLeft.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/WinogradTransLeft.S new file mode 100644 index 0000000000000000000000000000000000000000..243b19de31aa4518009594041efff5570a6c7fb0 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/WinogradTransLeft.S @@ -0,0 +1,158 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +asm_function WinogradTransLeft +//void WinogradTransLeft(const float* S, const float* B, float* M, size_t w, size_t h, size_t k, size_t length); +//x0: S +//x1: B +//x2: M +//x3: w +//x4: h +//x5: k +//x6:length + +sub sp, sp, #32 +stp x19, x20, [sp] + +mov x8, #16 // 4 * sizeof(float) +mul x8, x6, x8 +mul x9, x3, x8 +sub x9, x9, x8 +add x7, x9, x8 // step for S +mov x10, #4 +mul x10, x4, x10 // step for B + +LoopH: + mov x13, x0 + mov x15, x3 + LoopW: + mov x14, x13 + mov x17, x1 + dup v30.4s, wzr + mov x11, x6 + InitZero: + st1 {v30.4s}, [x2], #16 + subs x11, x11, #1 + bne InitZero + + sub x2, x2, x8 + mov x12, x5 + LoopKStart4: + cmp x12, #4 + blt LoopKStart3 + mov x16, x15 + mov x19, x4 + LoopK4: + ld1 {v0.s}[0], [x17], x10 + ld1 {v0.s}[1], [x17], x10 + ld1 {v0.s}[2], [x17], x10 + ld1 {v0.s}[3], [x17], x10 + mov x11, x6 + mov x20, x17 + add x20, x14, x7 + add x16, x20, x7 + add x19, x16, x7 + + LoopLength4: + ld1 {v16.4s}, [x2] + ld1 {v20.4s}, [x14], #16 + fmla v16.4s, v20.4s, v0.s[0] + ld1 {v21.4s}, [x20], #16 + fmul v17.4s, v21.4s, v0.s[1] + ld1 {v20.4s}, [x16], #16 + fmla v16.4s, v20.4s, v0.s[2] + ld1 {v21.4s}, [x19], #16 + fmla v17.4s, v21.4s, v0.s[3] + fadd v17.4s, v16.4s, v17.4s + st1 {v17.4s}, [x2], #16 + subs x11, x11, #1 + bne LoopLength4 + + sub x2, x2, x8 + sub x12, x12, #4 + add x14, x19, x9 + cmp x12, #4 + bge LoopK4 + + LoopKStart3: + cmp x12, #3 + blt LoopKStart + mov x16, x15 + LoopK3: + ld1 {v0.s}[0], [x17], x10 + ld1 {v0.s}[1], [x17], x10 + ld1 {v0.s}[2], [x17], x10 + mov x11, x6 + mov x20, x17 + add x20, x14, x7 + add x16, x20, x7 + LoopLength3: + ld1 {v16.4s}, [x2] + ld1 {v20.4s}, [x14], #16 + fmla v16.4s, v20.4s, v0.s[0] + ld1 {v21.4s}, [x20], #16 + fmul v17.4s, v21.4s, v0.s[1] + ld1 {v20.4s}, [x16], #16 + fmla v16.4s, v20.4s, v0.s[2] + fadd v17.4s, v16.4s, v17.4s + st1 {v17.4s}, [x2], #16 + subs x11, x11, #1 + bne LoopLength3 + + sub x2, x2, x8 + sub x12, x12, #3 + add x14, x16, x9 + cmp x12, #3 + bge LoopK3 + + LoopKStart: + cmp x12, #0 + beq LKEnd + LoopK: + ld1r {v31.4s}, [x17], x10 + mov x11, x6 + LoopLength: + ld1 {v0.4s}, [x2] + ld1 {v1.4s}, [x14], #16 + fmla v0.4s, v1.4s, v31.4s + st1 {v0.4s}, [x2], #16 + subs x11, x11, #1 + bne LoopLength + + subs x12, x12, #1 + sub x2, x2, x8 + add x14, x14, x9 + bne LoopK + + LKEnd: + subs x15, x15, #1 + add x13, x13, x8 + add x2, x2, x8 + bne LoopW + + add x1, x1, #4 //sizeof(float) + subs x4, x4, #1 + bne LoopH + + ldp x19, x20, [sp], #32 + ret + +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/WinogradTransRight.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/WinogradTransRight.S new file mode 100644 index 0000000000000000000000000000000000000000..95ee50a53c9de32e8e80b14ab45250313fb18827 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm64/WinogradTransRight.S @@ -0,0 +1,160 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +asm_function WinogradTransRight +//void WinogradTransRight(const float* S, const float* B, float* M, size_t w, size_t h, size_t k, size_t length); +//x0: S +//x1: B +//x2: M +//x3: w +//x4: h +//x5: k +//x6: length + +sub sp, sp, #16 +stp x19, x20, [sp] + +mov x8, #16 // 4 * sizeof(float) +mul x8, x6, x8 +mul x9, x5, x8 // step for S +mov x10, #4 +mul x10, x4, x10 // step for B + +LoopH: + mov x7, x1 + mov x15, x3 + LoopW: + mov x17, x0 + mov x13, x7 + dup v30.4s, wzr + mov x11, x6 + InitZero: + st1 {v30.4s}, [x2], #16 + subs x11, x11, #1 + bne InitZero + sub x2, x2, x8 + mov x12, x5 + + LoopKStart4: + cmp x12, #4 + blt LoopKStart3 + mov x16, x15 + mov x19, x4 + LoopK4: + ld1 {v0.s}[0], [x13], x10 + ld1 {v0.s}[1], [x13], x10 + ld1 {v0.s}[2], [x13], x10 + ld1 {v0.s}[3], [x13], x10 + mov x11, x6 + mov x14, x13 + + add x14, x17, x8 + add x16, x14, x8 + add x19, x16, x8 + + LoopLength4: + ld1 {v16.4s}, [x2] + ld1 {v20.4s}, [x17], #16 + fmla v16.4s, v20.4s, v0.s[0] + ld1 {v21.4s}, [x14], #16 + fmul v17.4s, v21.4s, v0.s[1] + ld1 {v20.4s}, [x16], #16 + fmla v16.4s, v20.4s, v0.s[2] + ld1 {v21.4s}, [x19], #16 + fmla v17.4s, v21.4s, v0.s[3] + + fadd v17.4s, v16.4s, v17.4s + st1 {v17.4s}, [x2], #16 + subs x11, x11, #1 + bne LoopLength4 + sub x2, x2, x8 + sub x12, x12, #4 + mov x17, x19 + + cmp x12, #4 + bge LoopK4 + + LoopKStart3: + cmp x12, #3 + blt LoopKStart + mov x16, x15 + LoopK3: + ld1 {v0.s}[0], [x13], x10 + ld1 {v0.s}[1], [x13], x10 + ld1 {v0.s}[2], [x13], x10 + mov x11, x6 + mov x14, x13 + + add x14, x17, x8 + add x16, x14, x8 + + LoopLength3: + ld1 {v16.4s}, [x2] + ld1 {v20.4s}, [x17], #16 + fmla v16.4s, v20.4s, v0.s[0] + ld1 {v21.4s}, [x14], #16 + fmul v17.4s, v21.4s, v0.s[1] + ld1 {v20.4s}, [x16], #16 + fmla v16.4s, v20.4s, v0.s[2] + + fadd v17.4s, v16.4s, v17.4s + st1 {v17.4s}, [x2], #16 + subs x11, x11, #1 + bne LoopLength3 + sub x2, x2, x8 + sub x12, x12, #3 + mov x17, x19 + cmp x12, #3 + bge LoopK3 + + LoopKStart: + cmp x12, #0 + beq LoopKEnd + + LoopK: + ld1r {v31.4s}, [x13], x10 + + mov x11, x6 + LoopLength: + ld1 {v0.4s}, [x2] + ld1 {v1.4s}, [x17], #16 + fmla v0.4s, v1.4s, v31.4s + + st1 {v0.4s}, [x2], #16 + subs x11, x11, #1 + bne LoopLength + subs x12, x12, #1 + + sub x2, x2, x8 + bne LoopK + LoopKEnd: + subs x15, x15, #1 + add x2, x2, x8 + add x7, x7, #4 //sizeof(float) + bne LoopW + + add x0, x0, x9 + subs x4, x4, #1 + bne LoopH + + ldp x19, x20, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm82_aarch32_fp16/Float16Tofloat32.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm82_aarch32_fp16/Float16Tofloat32.S new file mode 100644 index 0000000000000000000000000000000000000000..c65523b02a1e4034e13475130ae4a08280106b02 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm82_aarch32_fp16/Float16Tofloat32.S @@ -0,0 +1,70 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl/assembly_global.h" + .text + .align 5 + .global Float16ToFloat32 +#ifndef __APPLE__ + .type Float16ToFloat32, %function +#endif + +// void Float16ToFloat32(const float16_t *input, float *output, int number); +// r0: input, r1: output, r2: number +Float16ToFloat32: + cmp r2, #0 + beq LoopEnd + cmp r2, #16 + bge Loop16 + cmp r2, #8 + bge Loop8 + b Loop + Loop16: + vld1.16 {q0, q1}, [r0]! + vcvt.f32.f16 q3, d0 + vcvt.f32.f16 q4, d1 + vcvt.f32.f16 q5, d2 + vst1.32 {q3, q4}, [r1]! + vcvt.f32.f16 q6, d3 + subs r2, r2, #16 + vst1.32 {q5, q6}, [r1]! + beq LoopEnd + cmp r2, #16 + bge Loop16 + cmp r2, #8 + bge Loop8 + b Loop + Loop8: + vld1.16 {q0}, [r0]! + vcvt.f32.f16 q1, d0 + vcvt.f32.f16 q2, d1 + vst1.32 {q1, q2}, [r1]! + subs r2, r2, #8 + beq LoopEnd + cmp r2, #8 + bge Loop8 + b Loop + Loop: + vldr.16 s0, [r0] + vcvtb.f32.f16 s0, s0 + vstr.32 s0, [r1] + add r0, r0, #2 + add r1, r1, #4 + subs r2, r2, #1 + bgt Loop + LoopEnd: + mov pc, lr +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm82_aarch32_fp16/Float32ToFloat16.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm82_aarch32_fp16/Float32ToFloat16.S new file mode 100644 index 0000000000000000000000000000000000000000..d28a6f64291e3310198f6fd122e4eb1893d3f2eb --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm82_aarch32_fp16/Float32ToFloat16.S @@ -0,0 +1,70 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl/assembly_global.h" + .text + .align 5 + .global Float32ToFloat16 +#ifndef __APPLE__ + .type Float32ToFloat16, %function +#endif + +// void Float32ToFloat16(const float *input, float16_t *output, int number); +// r0: input, r1: output, r2: number +Float32ToFloat16: + cmp r2, #0 + beq LoopEnd + cmp r2, #16 + bge Loop16 + cmp r2, #8 + bge Loop8 + b Loop + Loop16: + vld1.32 {q0, q1}, [r0]! + vcvt.f16.f32 d0, q0 + vcvt.f16.f32 d1, q1 + vld1.32 {q2, q3}, [r0]! + vcvt.f16.f32 d2, q2 + vcvt.f16.f32 d3, q3 + vst1.16 {q0, q1}, [r1]! + subs r2, r2, #16 + beq LoopEnd + cmp r2, #16 + bge Loop16 + cmp r2, #8 + bge Loop8 + b Loop + Loop8: + vld1.32 {q0, q1}, [r0]! + vcvt.f16.f32 d0, q0 + vcvt.f16.f32 d1, q1 + vst1.16 {q0}, [r1]! + subs r2, r2, #8 + beq LoopEnd + cmp r2, #8 + bge Loop8 + b Loop + Loop: + vldr s0, [r0] + vcvtb.f16.f32 s0, s0 + vstr.16 s0, [r1] + add r0, r0, #4 + add r1, r1, #2 + subs r2, r2, #1 + bgt Loop + LoopEnd: + mov pc, lr +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm82_aarch32_fp16/MatVecMulFp16.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm82_aarch32_fp16/MatVecMulFp16.S new file mode 100644 index 0000000000000000000000000000000000000000..f0c3bab858a2188b9e62760ec429b8626309e6af --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm82_aarch32_fp16/MatVecMulFp16.S @@ -0,0 +1,237 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void MatVecMulA32NeonFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, +// int depth, int col) { +// r0: a +// r1: b +// r2: c +// r3: bias +// r4: act_type +// r5: depth +// r6: col + +asm_function MatVecMulA32NeonFp16 + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r0-r8, r9, r10, r11, lr} + add sp, sp, #52 + + ldr r4, [sp] + ldr r5, [sp, #4] + ldr r6, [sp, #8] + + add r10, r5, r5 // stride = depth * sizeof(float16_t) + mov lr, #4 + mul r11, r10, lr // stride x 4 + + cmp r6, #4 + blt Col1Loop + +Col4Loop: + mov r7, r0 // reload a(vector) ptr + mov r9, r1 // reload b(matrix) ptr + mov r8, r5 // reload depth value + + veor q9, q9, q9 + veor q10, q10, q10 + veor q11, q11, q11 + veor q12, q12, q12 + veor q15, q15, q15 + + cmp r8, #8 + bge Col4Depth8 + cmp r8, #4 + bge Col4Depth4 + cmp r8, #1 + bge Col4Depth1 + b Col4End + + Col4Depth8: + vld1.16 {q8}, [r7]! + add lr, r9, r10 + vld1.16 {q0}, [r9]! + vld1.16 {q1}, [lr], r10 + vld1.16 {q2}, [lr], r10 + vld1.16 {q3}, [lr] + + vmla.f16 q9, q8, q0 + vmla.f16 q10, q8, q1 + vmla.f16 q11, q8, q2 + vmla.f16 q12, q8, q3 + sub r8, r8, #8 + cmp r8, #8 + bge Col4Depth8 + cmp r8, #4 + bge Col4Depth4 + b AddC4 + + Col4Depth4: + vld1.16 {d16}, [r7]! + add lr, r9, r10 + vld1.16 {d0}, [r9]! + vld1.16 {d2}, [lr], r10 + vld1.16 {d4}, [lr], r10 + vld1.16 {d6}, [lr] + + vmla.f16 d18, d16, d0 + vmla.f16 d20, d16, d2 + vmla.f16 d22, d16, d4 + vmla.f16 d24, d16, d6 + sub r8, r8, #4 + cmp r8, #4 + bge Col4Depth4 + + AddC4: + vpadd.f16 d0, d18, d19 + vpadd.f16 d1, d20, d21 + vpadd.f16 d2, d22, d23 + vpadd.f16 d4, d24, d25 + vpadd.f16 d30, d0, d1 + vpadd.f16 d31, d2, d4 + vpadd.f16 d30, d30, d31 + cmp r8, #1 + bge Col4Depth1 + b Col4End + + Col4Depth1: + vld1.16 {d0[0]}, [r7]! + add lr, r9, r10 + vld1.16 {d2[0]}, [r9]! + vld1.16 {d2[1]}, [lr], r10 + vld1.16 {d2[2]}, [lr], r10 + vld1.16 {d2[3]}, [lr] + + vmla.f16 d30, d2, d0[0] + subs r8, r8, #1 + bne Col4Depth1 + + Col4End: + cmp r3, #0 + beq Col4Activation + vld1.16 {d26}, [r3]! + vadd.f16 d30, d30, d26 + + Col4Activation: + cmp r4, #3 + beq Col4Relu6 + cmp r4, #1 + beq Col4Relu + b Col4Write + + Col4Relu6: + vmov.i16 q12, #6 + vcvt.f16.s16 q12, q12 + vmin.f16 d30, d30, d24 + + Col4Relu: + veor q13, q13, q13 + vmax.f16 d30, d30, d26 + + Col4Write: + vst1.16 {d30}, [r2]! + subs r6, r6, #4 + beq End + add r1, r1, r11 + cmp r6, #4 + bge Col4Loop + +Col1Loop: + mov r7, r0 // reload a(vector) ptr + mov r9, r1 // reload b(matrix) ptr + mov r8, r5 // reload depth value + veor q10, q10, q10 + veor q15, q15, q15 + + cmp r8, #8 + bge Col1Depth8 + cmp r8, #4 + bge Col1Depth4 + cmp r8, #1 + bge Col1Depth1 + b Col1End + + Col1Depth8: + vld1.16 {q0}, [r7]! + vld1.16 {q1}, [r9]! + vmla.f16 q10, q1, q0 + sub r8, r8, #8 + cmp r8, #8 + bge Col1Depth8 + cmp r8, #4 + bge Col1Depth4 + b AddC1 + + Col1Depth4: + vld1.16 {d0}, [r7]! + vld1.16 {d2}, [r9]! + vmla.f16 d20, d2, d0 + sub r8, r8, #4 + cmp r8, #4 + bge Col1Depth4 + + AddC1: + vpadd.f16 d30, d20, d21 + vpadd.f16 d30, d30, d20 + vpadd.f16 d30, d30, d20 + cmp r8, #1 + bge Col1Depth1 + b Col1End + + Col1Depth1: + vld1.16 {d0[0]}, [r7]! + vld1.16 {d2[0]}, [r9]! + vmla.f16 d30, d2, d0[0] + subs r8, r8, #1 + bne Col1Depth1 + + Col1End: + cmp r3, #0 + beq Col1Activation + vld1.16 {d28[0]}, [r3]! + vadd.f16 d30, d30, d28 + + Col1Activation: + cmp r4, #3 + beq Col1Relu6 + cmp r4, #1 + beq Col1Relu + b Col1Write + + Col1Relu6: + vmov.i16 d26, #6 + vcvt.f16.s16 d26, d26 + vmin.f16 d30, d30, d26 + + Col1Relu: + veor d24, d24, d24 + vmax.f16 d30, d30, d24 + + Col1Write: + vst1.16 {d30[0]}, [r2]! + subs r6, r6, #1 + beq End + add r1, r1, r10 + b Col1Loop + +End: + sub sp, sp, #52 + pop {r0-r8, r9, r10, r11, pc} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm82_aarch32_fp16/Matmul12x8Fp16.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm82_aarch32_fp16/Matmul12x8Fp16.S new file mode 100644 index 0000000000000000000000000000000000000000..328d6a049271abec371033bbe298d022cdc0e759 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm82_aarch32_fp16/Matmul12x8Fp16.S @@ -0,0 +1,617 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl/assembly_global.h" + .text + .align 5 + .global MatMul12x8A32Fp16 +#ifndef __APPLE__ + .type MatMul12x8A32Fp16, %function +#endif + +// void MatMul12x8A32Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, +// int deep, int row, int col, int stride, bool write_mode); +// r0: a +// r1: b +// r2: dst +// r3: bias +// #4: depth +// #8: row +// #12: col +// #16: stride +// #20: writeNhwc/writeWino + +asm_function MatMul12x8A32Fp16 + // r13(sp) and r15(pc) can not be used!! + // r9 r4 is tmp register + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r3-r11, lr} + vpush {q4-q7} + add sp, sp, #104 + + ldr r5, [sp, #4] + ldr r6, [sp, #8] + ldr r7, [sp, #12] + ldr r8, [sp, #16] + ldr lr, [sp, #20] + + mov r10, r1 // b + mov r11, r0 // a + mov r12, r2 // dst + + cmp lr, #2 + bne NoWinograd + mul r4, r8, r7 // stride * col + add r4, r4, r4 // r4 * sizeof(float16_t) + mov r9, #16 + mul r9, r8, r9 // stride * 8 * sizeof(float16_t) +NoWinograd: + add r8, r8, r8 // stride * sizeof(float16_t) + +a .req r0 +weight .req r1 +dst .req r2 +bias .req r3 +depth .req r5 +row .req r6 +col .req r7 +stride .req r8 +b_tmp .req r10 +a_tmp .req r11 +dst_tmp .req r12 + +.macro STORE_12x8 p1 + vst1.16 {\p1}, [dst] + add dst, dst, stride +.endm + +.macro STORE_12x7 p1, p2, p3 + add r4, dst, #8 + add r9, dst, #12 + vst1.16 {\p1}, [dst] + vst1.32 {\p2}, [r4] + vst1.16 {\p3}, [r9] + add dst, dst, stride +.endm + +.macro STORE_12x6 p1, p2 + add r4, dst, #8 + vst1.16 {\p1}, [dst] + vst1.32 {\p2}, [r4] + add dst, dst, stride +.endm + +.macro STORE_12x5 p1, p2 + add r4, dst, #8 + vst1.16 {\p1}, [dst] + vst1.16 {\p2}, [r4] + add dst, dst, stride +.endm + +.macro STORE_12x4 p1 + vst1.16 {\p1}, [dst] + add dst, dst, stride +.endm + +.macro STORE_12x3 p1, p2 + add r4, dst, #4 + vst1.32 {\p1}, [dst] + vst1.16 {\p2}, [r4] + add dst, dst, stride +.endm + +.macro STORE_12x2 p1 + vst1.32 {\p1}, [dst] + add dst, dst, stride +.endm + +.macro STORE_12x1 p1 + vst1.16 {\p1}, [dst] + add dst, dst, stride +.endm + +.macro STORE_C8 p1, p2 + vst1.16 {\p1}, [dst] + cmp row, \p2 + add dst, dst, stride + beq WriteEnd +.endm + +.macro STORE_C7 p1, p2, p3, p4 + add r4, dst, #8 + add r9, dst, #12 + vst1.16 {\p1}, [dst] + vst1.32 {\p2}, [r4] + vst1.16 {\p3}, [r9] + add dst, dst, stride + cmp row, \p4 + beq WriteEnd +.endm + +.macro STORE_C6 p1, p2, p3 + add r4, dst, #8 + vst1.16 {\p1}, [dst] + vst1.32 {\p2}, [r4] + add dst, dst, stride + cmp row, \p3 + beq WriteEnd +.endm + +.macro STORE_C5 p1, p2, p3 + add r4, dst, #8 + vst1.16 {\p1}, [dst] + vst1.16 {\p2}, [r4] + add dst, dst, stride + cmp row, \p3 + beq WriteEnd +.endm + +.macro STORE_C4 p1, p2 + vst1.16 {\p1}, [dst] + cmp row, \p2 + add dst, dst, stride + beq WriteEnd +.endm + +.macro STORE_C3 p1, p2, p3 + add r4, dst, #4 + vst1.32 {\p1}, [dst] + vst1.16 {\p2}, [r4] + add dst, dst, stride + cmp row, \p3 + beq WriteEnd +.endm + +.macro STORE_C2 p1, p2 + vst1.32 {\p1}, [dst] + add dst, dst, stride + cmp row, \p2 + beq WriteEnd +.endm + +.macro STORE_C1 p1, p2 + vst1.16 {\p1}, [dst] + add dst, dst, stride + cmp row, \p2 + beq WriteEnd +.endm + +LoopRow12: + ldr bias, [sp, #-40] + LoopCol8: + mov dst, dst_tmp + mov a, a_tmp + ldr depth, [sp, #4] + veor q4, q4, q4 + veor q5, q5, q5 + veor q6, q6, q6 + veor q7, q7, q7 + veor q8, q8, q8 + veor q9, q9, q9 + veor q10, q10, q10 + veor q11, q11, q11 + veor q12, q12, q12 + veor q13, q13, q13 + veor q14, q14, q14 + veor q15, q15, q15 + LoopDepth: + vld1.16 {q0, d2}, [a]! + vld1.16 {q2}, [weight]! + vmla.f16 q4, q2, d0[0] + vmla.f16 q5, q2, d0[1] + vmla.f16 q6, q2, d0[2] + vmla.f16 q7, q2, d0[3] + vmla.f16 q8, q2, d1[0] + vmla.f16 q9, q2, d1[1] + vmla.f16 q10, q2, d1[2] + vmla.f16 q11, q2, d1[3] + vmla.f16 q12, q2, d2[0] + vmla.f16 q13, q2, d2[1] + vmla.f16 q14, q2, d2[2] + vmla.f16 q15, q2, d2[3] + + subs depth, depth, #1 + bne LoopDepth + + Bias: + cmp bias, #0 + beq Activation + vld1.16 {q0}, [bias]! + vadd.f16 q4, q4, q0 + vadd.f16 q5, q5, q0 + vadd.f16 q6, q6, q0 + vadd.f16 q7, q7, q0 + vadd.f16 q8, q8, q0 + vadd.f16 q9, q9, q0 + vadd.f16 q10, q10, q0 + vadd.f16 q11, q11, q0 + vadd.f16 q12, q12, q0 + vadd.f16 q13, q13, q0 + vadd.f16 q14, q14, q0 + vadd.f16 q15, q15, q0 + + Activation: + ldr lr, [sp] + cmp lr, #3 + beq Relu6 + cmp lr, #1 + beq Relu + b Write + + Relu6: + vmov.i16 q2, #0x4600 + vadd.f16 q4, q4, q2 + vadd.f16 q5, q5, q2 + vadd.f16 q6, q6, q2 + vadd.f16 q7, q7, q2 + vmin.f16 q8, q8, q2 + vmin.f16 q9, q9, q2 + vmin.f16 q10, q10, q2 + vmin.f16 q11, q11, q2 + vmin.f16 q12, q12, q2 + vmin.f16 q13, q13, q2 + vmin.f16 q14, q14, q2 + vmin.f16 q15, q15, q2 + + Relu: + veor q3, q3, q3 + vmax.f16 q4, q4, q3 + vmax.f16 q5, q5, q3 + vmax.f16 q6, q6, q3 + vmax.f16 q7, q7, q3 + vmax.f16 q8, q8, q3 + vmax.f16 q9, q9, q3 + vmax.f16 q10, q10, q3 + vmax.f16 q11, q11, q3 + vmax.f16 q12, q12, q3 + vmax.f16 q13, q13, q3 + vmax.f16 q14, q14, q3 + vmax.f16 q15, q15, q3 + + Write: + ldr lr, [sp, #20] + cmp lr, #2 + beq WriteWinograd + cmp row, #12 + bge Write12xCol + b WriteRowxCol + + WriteWinograd: + vst1.16 {q4}, [dst] + add dst, dst, r4 + vst1.16 {q5}, [dst] + add dst, dst, r4 + vst1.16 {q6}, [dst] + add dst, dst, r4 + vst1.16 {q7}, [dst] + add dst, dst, r4 + vst1.16 {q8}, [dst] + add dst, dst, r4 + vst1.16 {q9}, [dst] + add dst, dst, r4 + vst1.16 {q10}, [dst] + add dst, dst, r4 + vst1.16 {q11}, [dst] + add dst, dst, r4 + vst1.16 {q12}, [dst] + add dst, dst, r4 + vst1.16 {q13}, [dst] + add dst, dst, r4 + vst1.16 {q14}, [dst] + add dst, dst, r4 + vst1.16 {q15}, [dst] + add dst_tmp, dst_tmp, r9 + b WriteEnd + Write12xCol: + cmp col, #8 + bge Write12x8 + cmp col, #1 + beq Write12x1 + cmp col, #2 + beq Write12x2 + cmp col, #3 + beq Write12x3 + cmp col, #4 + beq Write12x4 + cmp col, #5 + beq Write12x5 + cmp col, #6 + beq Write12x6 + b Write12x7 + + WriteRowxCol: + cmp col, #8 + bge WriteRowx8 + cmp col, #1 + beq WriteRowx1 + cmp col, #2 + beq WriteRowx2 + cmp col, #3 + beq WriteRowx3 + cmp col, #4 + beq WriteRowx4 + cmp col, #5 + beq WriteRowx5 + cmp col, #6 + beq WriteRowx6 + b WriteRowx7 + + Write12x8: + STORE_12x8 q4 + STORE_12x8 q5 + STORE_12x8 q6 + STORE_12x8 q7 + STORE_12x8 q8 + STORE_12x8 q9 + STORE_12x8 q10 + STORE_12x8 q11 + STORE_12x8 q12 + STORE_12x8 q13 + STORE_12x8 q14 + STORE_12x8 q15 + b WriteEnd + WriteRowx8: + STORE_C8 q4, #1 + STORE_C8 q5, #2 + STORE_C8 q6, #3 + STORE_C8 q7, #4 + STORE_C8 q8, #5 + STORE_C8 q9, #6 + STORE_C8 q10, #7 + STORE_C8 q11, #8 + STORE_C8 q12, #9 + STORE_C8 q13, #10 + STORE_C8 q14, #11 + STORE_C8 q15, #12 + b WriteEnd + + Write12x1: + STORE_12x1 d8[0] + STORE_12x1 d10[0] + STORE_12x1 d12[0] + STORE_12x1 d14[0] + STORE_12x1 d16[0] + STORE_12x1 d18[0] + STORE_12x1 d20[0] + STORE_12x1 d22[0] + STORE_12x1 d24[0] + STORE_12x1 d26[0] + STORE_12x1 d28[0] + STORE_12x1 d30[0] + b WriteEnd + WriteRowx1: + STORE_C1 d8[0], #1 + STORE_C1 d10[0], #2 + STORE_C1 d12[0], #3 + STORE_C1 d14[0], #4 + STORE_C1 d16[0], #5 + STORE_C1 d18[0], #6 + STORE_C1 d20[0], #7 + STORE_C1 d22[0], #8 + STORE_C1 d24[0], #9 + STORE_C1 d26[0], #10 + STORE_C1 d28[0], #11 + STORE_C1 d30[0], #12 + b WriteEnd + + Write12x2: + STORE_12x2 d8[0] + STORE_12x2 d10[0] + STORE_12x2 d12[0] + STORE_12x2 d14[0] + STORE_12x2 d16[0] + STORE_12x2 d18[0] + STORE_12x2 d20[0] + STORE_12x2 d22[0] + STORE_12x2 d24[0] + STORE_12x2 d26[0] + STORE_12x2 d28[0] + STORE_12x2 d30[0] + b WriteEnd + WriteRowx2: + STORE_C2 d8[0], #1 + STORE_C2 d10[0], #2 + STORE_C2 d12[0], #3 + STORE_C2 d14[0], #4 + STORE_C2 d16[0], #5 + STORE_C2 d18[0], #6 + STORE_C2 d20[0], #7 + STORE_C2 d22[0], #8 + STORE_C2 d24[0], #9 + STORE_C2 d26[0], #10 + STORE_C2 d28[0], #11 + STORE_C2 d30[0], #12 + b WriteEnd + + Write12x3: + STORE_12x3 d8[0], d8[2] + STORE_12x3 d10[0], d10[2] + STORE_12x3 d12[0], d12[2] + STORE_12x3 d14[0], d14[2] + STORE_12x3 d16[0], d16[2] + STORE_12x3 d18[0], d18[2] + STORE_12x3 d20[0], d20[2] + STORE_12x3 d22[0], d22[2] + STORE_12x3 d24[0], d24[2] + STORE_12x3 d26[0], d26[2] + STORE_12x3 d28[0], d28[2] + STORE_12x3 d30[0], d30[2] + b WriteEnd + WriteRowx3: + STORE_C3 d8[0], d8[2], #1 + STORE_C3 d10[0], d10[2], #2 + STORE_C3 d12[0], d12[2], #3 + STORE_C3 d14[0], d14[2], #4 + STORE_C3 d16[0], d16[2], #5 + STORE_C3 d18[0], d18[2], #6 + STORE_C3 d20[0], d20[2], #7 + STORE_C3 d22[0], d22[2], #8 + STORE_C3 d24[0], d24[2], #9 + STORE_C3 d26[0], d26[2], #10 + STORE_C3 d28[0], d28[2], #11 + STORE_C3 d30[0], d30[2], #12 + b WriteEnd + + Write12x4: + STORE_12x4 d8 + STORE_12x4 d10 + STORE_12x4 d12 + STORE_12x4 d14 + STORE_12x4 d16 + STORE_12x4 d18 + STORE_12x4 d20 + STORE_12x4 d22 + STORE_12x4 d24 + STORE_12x4 d26 + STORE_12x4 d28 + STORE_12x4 d30 + b WriteEnd + WriteRowx4: + STORE_C4 d8, #1 + STORE_C4 d10, #2 + STORE_C4 d12, #3 + STORE_C4 d14, #4 + STORE_C4 d16, #5 + STORE_C4 d18, #6 + STORE_C4 d20, #7 + STORE_C4 d22, #8 + STORE_C4 d24, #9 + STORE_C4 d26, #10 + STORE_C4 d28, #11 + STORE_C4 d30, #12 + b WriteEnd + + Write12x5: + STORE_12x5 d8, d9[0] + STORE_12x5 d10, d11[0] + STORE_12x5 d12, d13[0] + STORE_12x5 d14, d15[0] + STORE_12x5 d16, d17[0] + STORE_12x5 d18, d19[0] + STORE_12x5 d20, d21[0] + STORE_12x5 d22, d23[0] + STORE_12x5 d24, d25[0] + STORE_12x5 d26, d27[0] + STORE_12x5 d28, d29[0] + STORE_12x5 d30, d31[0] + b WriteEnd + WriteRowx5: + STORE_C5 d8, d9[0], #1 + STORE_C5 d10, d11[0], #2 + STORE_C5 d12, d13[0], #3 + STORE_C5 d14, d15[0], #4 + STORE_C5 d16, d17[0], #5 + STORE_C5 d18, d19[0], #6 + STORE_C5 d20, d21[0], #7 + STORE_C5 d22, d23[0], #8 + STORE_C5 d24, d25[0], #9 + STORE_C5 d26, d27[0], #10 + STORE_C5 d28, d29[0], #11 + STORE_C5 d30, d31[0], #12 + b WriteEnd + + Write12x6: + STORE_12x6 d8, d9[0] + STORE_12x6 d10, d11[0] + STORE_12x6 d12, d13[0] + STORE_12x6 d14, d15[0] + STORE_12x6 d16, d17[0] + STORE_12x6 d18, d19[0] + STORE_12x6 d20, d21[0] + STORE_12x6 d22, d23[0] + STORE_12x6 d24, d25[0] + STORE_12x6 d26, d27[0] + STORE_12x6 d28, d29[0] + STORE_12x6 d30, d31[0] + b WriteEnd + WriteRowx6: + STORE_C6 d8, d9[0], #1 + STORE_C6 d10, d11[0], #2 + STORE_C6 d12, d13[0], #3 + STORE_C6 d14, d15[0], #4 + STORE_C6 d16, d17[0], #5 + STORE_C6 d18, d19[0], #6 + STORE_C6 d20, d21[0], #7 + STORE_C6 d22, d23[0], #8 + STORE_C6 d24, d25[0], #9 + STORE_C6 d26, d27[0], #10 + STORE_C6 d28, d29[0], #11 + STORE_C6 d30, d31[0], #12 + b WriteEnd + + Write12x7: + STORE_12x7 d8, d9[0], d9[2] + STORE_12x7 d10, d11[0], d11[2] + STORE_12x7 d12, d13[0], d13[2] + STORE_12x7 d14, d15[0], d15[2] + STORE_12x7 d16, d17[0], d17[2] + STORE_12x7 d18, d19[0], d19[2] + STORE_12x7 d20, d21[0], d21[2] + STORE_12x7 d22, d23[0], d23[2] + STORE_12x7 d24, d25[0], d25[2] + STORE_12x7 d26, d27[0], d27[2] + STORE_12x7 d28, d29[0], d29[2] + STORE_12x7 d30, d31[0], d31[2] + b WriteEnd + WriteRowx7: + STORE_C7 d8, d9[0], d9[2], #1 + STORE_C7 d10, d11[0], d11[2], #2 + STORE_C7 d12, d13[0], d13[2], #3 + STORE_C7 d14, d15[0], d15[2], #4 + STORE_C7 d16, d17[0], d17[2], #5 + STORE_C7 d18, d19[0], d19[2], #6 + STORE_C7 d20, d21[0], d21[2], #7 + STORE_C7 d22, d23[0], d23[2], #8 + STORE_C7 d24, d25[0], d25[2], #9 + STORE_C7 d26, d27[0], d27[2], #10 + STORE_C7 d28, d29[0], d29[2], #11 + STORE_C7 d30, d31[0], d31[2], #12 + b WriteEnd + + WriteEnd: + cmp col, #8 + ble LoopColEnd + sub col, col, #8 + ldr lr, [sp, #20] + cmp lr, #2 + beq LoopCol8 + add dst_tmp, dst_tmp, #16 + b LoopCol8 + LoopColEnd: + cmp row, #12 + ble LoopRowEnd + sub row, row, #12 + mov a_tmp, a + mov weight, b_tmp + ldr lr, [sp, #20] + cmp lr, #2 + beq WinogradDst + ldr lr, [sp, #12] + sub lr, lr, col + add lr, lr, lr // col *= 2 + sub dst_tmp, dst, lr + b LoopRow + WinogradDst: + add dst_tmp, dst, r9 + LoopRow: + mov dst, dst_tmp + ldr col, [sp, #12] + b LoopRow12 +LoopRowEnd: + sub sp, sp, #104 + vpop {q4-q7} + pop {r3-r11, pc} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm82_aarch32_fp16/TiledC4MatmulFp16.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm82_aarch32_fp16/TiledC4MatmulFp16.S new file mode 100644 index 0000000000000000000000000000000000000000..6f3b6434210d056a2b2746359ce16c17c6a74962 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm82_aarch32_fp16/TiledC4MatmulFp16.S @@ -0,0 +1,108 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +asm_function TiledC4MatmulFp16 +// void TiledC4MatmulFp16(float16_t *dst, const float16_t *src, const float16_t *weight, size_t cal_num, size_t ic4, +// size_t oc4); +// r0: dst +// r1: src +// r2: weight +// r3: cal_num +// r4(sp): ic4 +// r5(sp + #4): oc4 +push {r4-r11, lr} +vpush {q4-q7} +add sp, sp, #100 +ldr r4, [sp] +ldr r5, [sp, #4] // oc4 +add r3, r3, r3 +mov r7, r1 + +cmp r5, #1 +blt LoopOCEnd +cmp r4, #1 +blt LoopICEnd +LoopOC: + ldr r4, [sp] + veor q15, q15, q15 + veor q14, q14, q14 + veor q13, q13, q13 + veor q12, q12, q12 + LoopIC: + vld1.16 {q4, q5}, [r2]! // weight + vld1.16 {q2, q3}, [r1]! // 16 number src + vmla.f16 d24, d8, d4[0] + vmla.f16 d24, d9, d4[1] + vmla.f16 d24, d10, d4[2] + vmla.f16 d24, d11, d4[3] + + vmla.f16 d25, d8, d5[0] + vmla.f16 d25, d9, d5[1] + vmla.f16 d25, d10, d5[2] + vmla.f16 d25, d11, d5[3] + + vmla.f16 d26, d8, d6[0] + vmla.f16 d26, d9, d6[1] + vmla.f16 d26, d10, d6[2] + vmla.f16 d26, d11, d6[3] + + vmla.f16 d27, d8, d7[0] + vmla.f16 d27, d9, d7[1] + vmla.f16 d27, d10, d7[2] + vmla.f16 d27, d11, d7[3] + + vld1.16 {q0, q1}, [r1]! // 16 number src + vmla.f16 d28, d8, d0[0] + vmla.f16 d28, d9, d0[1] + vmla.f16 d28, d10, d0[2] + vmla.f16 d28, d11, d0[3] + + vmla.f16 d29, d8, d1[0] + vmla.f16 d29, d9, d1[1] + vmla.f16 d29, d10, d1[2] + vmla.f16 d29, d11, d1[3] + + vmla.f16 d30, d8, d2[0] + vmla.f16 d30, d9, d2[1] + vmla.f16 d30, d10, d2[2] + vmla.f16 d30, d11, d2[3] + + vmla.f16 d31, d8, d3[0] + vmla.f16 d31, d9, d3[1] + vmla.f16 d31, d10, d3[2] + vmla.f16 d31, d11, d3[3] + + subs r4, r4, #1 + bne LoopIC + b LoopICEnd + LoopICEnd: + mov lr, r0 + vst1.16 {q12, q13}, [lr]! + vst1.16 {q14, q15}, [lr]! + add r0, r0, r3 // dst += cal_num + mov r1, r7 + subs r5, r5, #1 + bne LoopOC +LoopOCEnd: + sub sp, sp, #100 + vpop {q4-q7} + pop {r4-r11, pc} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm82_aarch32_fp16/WinogradTransLeft.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm82_aarch32_fp16/WinogradTransLeft.S new file mode 100644 index 0000000000000000000000000000000000000000..5c87be2845624875355dd7703bae879296aa7461 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm82_aarch32_fp16/WinogradTransLeft.S @@ -0,0 +1,165 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void WinogradTransLeftFp16(const float16_t *S, const float16_t *B, float16_t *M, size_t w, size_t h, size_t k, +// size_t length); +//r0: S +//r1: B +//r2: M +//r3: w +//r4: h +//r5: k +//r6: length +asm_function WinogradTransLeftFp16 + push {r0, r3, r4-r11, lr} + vpush {q4-q7} + add sp, sp, #108 + ldr r4, [sp] + ldr r6, [sp, #8] + + mov r8, #8 // 4 * sizeof(float16_t) + mul r8, r6, r8 // length * 4 * 2 + mul r7, r3, r8 // step for S + add r10, r4, r4 // step for B + +cmp r4, #1 +blt LoopHEnd +cmp r3, #1 +blt LoopHEnd +LoopH: + ldr r3, [sp, #-40] // w + ldr r0, [sp, #-44] + LoopW: + mov r11, r0 // S + mov lr, r1 // B_src + veor q6, q6, q6 + ldr r6, [sp, #8] + InitZero: + vst1.16 {d12}, [r2]! + subs r6, r6, #1 + bne InitZero + sub r2, r2, r8 + + ldr r5, [sp, #4] + cmp r5, #4 + bge LoopK4 + cmp r5, #3 + bge LoopK3 + cmp r5, #1 + bge LoopK1 + b LoopKEnd + + LoopK4: + ldr r6, [sp, #8] + vld1.16 {d1[0]}, [lr], r10 + vld1.16 {d3[0]}, [lr], r10 + vld1.16 {d5[0]}, [lr], r10 + vld1.16 {d7[0]}, [lr], r10 + + add r12, r11, r7 + add r14, r12, r7 + add r9, r14, r7 + LoopK4L4: + vld1.16 {d12}, [r2] + vld1.16 {d0}, [r11]! + vld1.16 {d2}, [r12]! + vmla.f16 d12, d0, d1[0] + vld1.16 {d4}, [r14]! + vmla.f16 d12, d2, d3[0] + vld1.16 {d6}, [r9]! + vmla.f16 d12, d4, d5[0] + vmla.f16 d12, d6, d7[0] + vst1.16 {d12}, [r2]! // dst + subs r6, r6, #1 // length + bne LoopK4L4 + + subs r5, r5, #4 // k + beq LoopKEnd + sub r2, r2, r8 // dst - step + sub r9, r9, r8 + add r11, r9, r7 + cmp r5, #4 + bge LoopK4 + cmp r5, #3 + bge LoopK3 + b LoopK1 + + LoopK3: + ldr r6, [sp, #8] + vld1.16 {d1[0]}, [lr], r10 + vld1.16 {d3[0]}, [lr], r10 + vld1.16 {d5[0]}, [lr], r10 + + add r12, r11, r7 + add r9, r12, r7 + LoopK3L4: + vld1.16 {d12}, [r2] + vld1.16 {d0}, [r11]! + vld1.16 {d2}, [r12]! + vmla.f16 d12, d0, d1[0] + vld1.16 {d4}, [r9]! + vmla.f16 d12, d2, d3[0] + vmla.f16 d12, d4, d5[0] + vst1.16 {d12}, [r2]! // dst + subs r6, r6, #1 // length + bne LoopK3L4 + + subs r5, r5, #3 // k + beq LoopKEnd + sub r2, r2, r8 // dst - step + sub r9, r9, r8 + add r11, r9, r7 + cmp r5, #3 + bge LoopK3 + b LoopK1 + + LoopK1: + ldr r6, [sp, #8] + vld1.16 {d1[0]}, [lr], r10 + + LoopK1L4: + vld1.16 {d12}, [r2] + vld1.16 {d0}, [r11]! + vmla.f16 d12, d0, d1[0] + vst1.16 {d12}, [r2]! // dst + subs r6, r6, #1 // length + bne LoopK1L4 + + subs r5, r5, #1 // k + beq LoopKEnd + sub r2, r2, r8 // dst - step + sub r11, r11, r8 + add r11, r11, r7 + b LoopK1 + LoopKEnd: + add r0, r0, r8 // S += unitstep + subs r3, r3, #1 + bne LoopW + LoopWEnd: + subs r4, r4, #1 + beq LoopHEnd + add r1, r1, #2 // B += 1 + b LoopH +LoopHEnd: + sub sp, sp, #108 + vpop {q4-q7} + pop {r0, r3, r4-r11, pc} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm82_aarch32_fp16/WinogradTransRight.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm82_aarch32_fp16/WinogradTransRight.S new file mode 100644 index 0000000000000000000000000000000000000000..ec04dc88e01b4da10516b96bb36253114c63814c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/arm82_aarch32_fp16/WinogradTransRight.S @@ -0,0 +1,163 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM32 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void WinogradTransRightFp16(const float16_t *S, const float16_t *B, float16_t *M, size_t w, size_t h, size_t k, +// size_t length); +//r0: S +//r1: B +//r2: M +//r3: w +//r4: h +//r5: k +//r6: length +asm_function WinogradTransRightFp16 + push {r1, r3, r4-r11, lr} + vpush {q4-q7} + add sp, sp, #108 + ldr r4, [sp] + ldr r5, [sp, #4] + ldr r6, [sp, #8] + + mov r8, #8 // 4 * sizeof(float16_t) + mul r8, r6, r8 // length * 4 * 2 + mul r7, r5, r8 // step for S = k * unitStep * 4 + add r10, r4, r4 // step for B = 2 * h + +cmp r4, #1 +blt LoopHEnd +cmp r3, #1 +blt LoopHEnd +LoopH: + ldr r3, [sp, #-40] // w + ldr r1, [sp, #-44] + LoopW: + mov r11, r0 // S + mov lr, r1 // B_src + veor q6, q6, q6 + ldr r6, [sp, #8] + InitZero: + vst1.16 {d12}, [r2]! + subs r6, r6, #1 + bne InitZero + sub r2, r2, r8 + + ldr r5, [sp, #4] + cmp r5, #4 + bge LoopK4 + cmp r5, #3 + bge LoopK3 + cmp r5, #1 + bge LoopK1 + b LoopKEnd + + LoopK4: + ldr r6, [sp, #8] + vld1.16 {d1[0]}, [lr], r10 + vld1.16 {d3[0]}, [lr], r10 + vld1.16 {d5[0]}, [lr], r10 + vld1.16 {d7[0]}, [lr], r10 + + add r12, r11, r8 + add r14, r12, r8 + add r9, r14, r8 + LoopK4L4: + vld1.16 {d12}, [r2] + vld1.16 {d0}, [r11]! + vld1.16 {d2}, [r12]! + vmla.f16 d12, d0, d1[0] + vld1.16 {d4}, [r14]! + vmla.f16 d12, d2, d3[0] + vld1.16 {d6}, [r9]! + vmla.f16 d12, d4, d5[0] + vmla.f16 d12, d6, d7[0] + vst1.16 {d12}, [r2]! // dst + subs r6, r6, #1 // length + bne LoopK4L4 + + subs r5, r5, #4 // k + beq LoopKEnd + sub r2, r2, r8 // dst - step + mov r11, r9 + cmp r5, #4 + bge LoopK4 + cmp r5, #3 + bge LoopK3 + b LoopK1 + + LoopK3: + ldr r6, [sp, #8] + vld1.16 {d1[0]}, [lr], r10 + vld1.16 {d3[0]}, [lr], r10 + vld1.16 {d5[0]}, [lr], r10 + + add r12, r11, r8 + add r9, r12, r8 + LoopK3L4: + vld1.16 {d12}, [r2] + vld1.16 {d0}, [r11]! + vld1.16 {d2}, [r12]! + vmla.f16 d12, d0, d1[0] + vld1.16 {d4}, [r9]! + vmla.f16 d12, d2, d3[0] + vmla.f16 d12, d4, d5[0] + vst1.16 {d12}, [r2]! // dst + subs r6, r6, #1 // length + bne LoopK3L4 + + subs r5, r5, #3 // k + beq LoopKEnd + sub r2, r2, r8 // dst - step + mov r11, r9 + cmp r5, #3 + bge LoopK3 + b LoopK1 + + LoopK1: + ldr r6, [sp, #8] + vld1.16 {d1[0]}, [lr], r10 + + LoopK1L4: + vld1.16 {d12}, [r2] + vld1.16 {d0}, [r11]! + vmla.f16 d12, d0, d1[0] + vst1.16 {d12}, [r2]! // dst + subs r6, r6, #1 // length + bne LoopK1L4 + + subs r5, r5, #1 // k + beq LoopKEnd + sub r2, r2, r8 // dst - step + b LoopK1 + + LoopKEnd: + add r1, r1, #2 // B[x] + subs r3, r3, #1 + bne LoopW + LoopWEnd: + add r0, r0, r7 + subs r4, r4, #1 + beq LoopHEnd + b LoopH +LoopHEnd: + sub sp, sp, #108 + vpop {q4-q7} + pop {r1, r3, r4-r11, pc} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/avx/ConvDwFp32Avx3x3.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/avx/ConvDwFp32Avx3x3.S new file mode 100644 index 0000000000000000000000000000000000000000..0f993934ba1fae8697e9e960ab19f673d50c611d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/avx/ConvDwFp32Avx3x3.S @@ -0,0 +1,313 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_AVX +#include "nnacl/assembly_global.h" +.text +.align 4 + +// void ConvDwFp32Avx3x3(float *output, float **input, const float *weights, const float *bias, size_t channels, size_t output_width, +// size_t input_stride, size_t relum, szie_t relu6) +// in linux x64 platform: +// rdi: output +// rsi: input +// rdx: weights +// rcx: bias +// r8: channels +// r9: output_width +// 8: input_stride +// 16: relu +// 24: relu6 + +// in win x64 platform: "shadow space" needs to be opened up for first four parameters ==> 32 bites +// rcx: output +// rdx: input +// r8: weights +// r9: bias +// 40: channels +// 48: output_width +// 56: input_stride +// 64: relu +// 72: relu6 +asm_function ConvDwFp32Avx3x3 + pushq %r15 + pushq %r14 + pushq %r13 + pushq %r12 + pushq %rbx + pushq %rbp + pushq %r9 // -56 + pushq %r8 // -64 + pushq %rcx // -72 + pushq %rdx // -80 + pushq %rsi // -88 + pushq %rdi // -96 + addq $96, %rsp + +#ifdef _WIN32 + movq %rcx, %rdi + movq %rdx, %rsi + movq %r8, %rdx + movq %r9, %rcx + movq 40(%rsp), %r8 // channels + movq 48(%rsp), %r9 // output_width + + mov %rdx, -80(%rsp) + mov %rcx, -72(%rsp) + mov %r9, -56(%rsp) + mov %r8, -64(%rsp) + movq 56(%rsp), %rbp // input_stride + movq %rbp, 8(%rsp) + movq 64(%rsp), %rbp // relu + movq %rbp, 16(%rsp) + movq 72(%rsp), %rbp // relu6 + movq %rbp, 24(%rsp) +#endif + + movq $6, %rax + vcvtsi2ss %rax, %xmm15, %xmm15 + vshufps $0, %xmm15, %xmm15, %xmm15 + vinsertf128 $1, %xmm15, %ymm15, %ymm15 + vxorps %ymm14, %ymm14, %ymm14 + + LoopPixel: + movq -80(%rsp), %rdx + movq -72(%rsp), %rcx + movq -64(%rsp), %r8 + movq (%rsi), %r9 + movq 8(%rsi), %r10 + movq 16(%rsi), %r11 + movq 24(%rsi), %r12 + movq 32(%rsi), %r13 + movq 40(%rsi), %r14 + movq 48(%rsi), %r15 + movq 56(%rsi), %rbp + movq 64(%rsi), %rbx + + vmovups (%r9), %ymm0 + addq $32, %r9 + vmovups (%r10), %ymm1 + addq $32, %r10 + vmovups (%r11), %ymm2 + addq $32, %r11 + + vmovups (%rdx), %ymm11 + addq $32, %rdx + vmovups (%rdx), %ymm12 + addq $32, %rdx + vmovups (%rdx), %ymm13 + addq $32, %rdx + + vmovups (%rcx), %ymm10 + addq $32, %rcx + + cmpq $8, %r8 + jbe LeftLoop + LoopC8: + vfmadd231ps %ymm11, %ymm0, %ymm10 + vmovups (%r12), %ymm3 + addq $32, %r12 + vmovups (%rdx), %ymm11 + addq $32, %rdx + vfmadd231ps %ymm12, %ymm1, %ymm10 + vmovups (%r13), %ymm4 + addq $32, %r13 + vmovups (%rdx), %ymm12 + addq $32, %rdx + vfmadd231ps %ymm13, %ymm2, %ymm10 + vmovups (%r14), %ymm5 + addq $32, %r14 + vmovups (%rdx), %ymm13 + addq $32, %rdx + vfmadd231ps %ymm11, %ymm3, %ymm10 + vmovups (%r15), %ymm6 + addq $32, %r15 + vmovups (%rdx), %ymm11 + addq $32, %rdx + vfmadd231ps %ymm12, %ymm4, %ymm10 + vmovups (%rbp), %ymm7 + addq $32, %rbp + vmovups (%rdx), %ymm12 + addq $32, %rdx + vfmadd231ps %ymm13, %ymm5, %ymm10 + vmovups (%rbx), %ymm8 + addq $32, %rbx + vmovups (%rdx), %ymm13 + addq $32, %rdx + vfmadd231ps %ymm11, %ymm6, %ymm10 + vmovups (%r9), %ymm0 + addq $32, %r9 + vmovups (%rdx), %ymm11 + addq $32, %rdx + vfmadd231ps %ymm12, %ymm7, %ymm10 + vmovups (%r10), %ymm1 + addq $32, %r10 + vmovups (%rdx), %ymm12 + addq $32, %rdx + vfmadd231ps %ymm13, %ymm8, %ymm10 + vmovups (%r11), %ymm2 + addq $32, %r11 + vmovups (%rdx), %ymm13 + addq $32, %rdx + + movq 24(%rsp), %rax + cmpq $0, %rax + jne Relu6 + movq 16(%rsp), %rax + cmpq $0, %rax + jne Relu + jmp Write + Relu6: + vminps %ymm15, %ymm10, %ymm10 + Relu: + vmaxps %ymm14, %ymm10, %ymm10 + Write: + vmovups %ymm10, (%rdi) + addq $32, %rdi + + vmovups (%rcx), %ymm10 + addq $32, %rcx + subq $8, %r8 + cmpq $8, %r8 + ja LoopC8 + + LeftLoop: + vfmadd231ps %ymm11, %ymm0, %ymm10 + vmovups (%r12), %ymm3 + addq $32, %r12 + vmovups (%rdx), %ymm11 + addq $32, %rdx + vfmadd231ps %ymm12, %ymm1, %ymm10 + vmovups (%r13), %ymm4 + addq $32, %r13 + vmovups (%rdx), %ymm12 + addq $32, %rdx + vfmadd231ps %ymm13, %ymm2, %ymm10 + vmovups (%r14), %ymm5 + addq $32, %r14 + vmovups (%rdx), %ymm13 + addq $32, %rdx + vfmadd231ps %ymm11, %ymm3, %ymm10 + vmovups (%r15), %ymm6 + addq $32, %r15 + vmovups (%rdx), %ymm11 + addq $32, %rdx + vfmadd231ps %ymm12, %ymm4, %ymm10 + vmovups (%rbp), %ymm7 + addq $32, %rbp + vmovups (%rdx), %ymm12 + addq $32, %rdx + vfmadd231ps %ymm13, %ymm5, %ymm10 + vmovups (%rbx), %ymm8 + addq $32, %rbx + vmovups (%rdx), %ymm13 + addq $32, %rdx + vfmadd231ps %ymm11, %ymm6, %ymm10 + vfmadd231ps %ymm12, %ymm7, %ymm10 + vfmadd231ps %ymm13, %ymm8, %ymm10 + + movq 24(%rsp), %rax + cmpq $0, %rax + jne LeftRelu6 + movq 16(%rsp), %rax + cmpq $0, %rax + jne LeftRelu + jmp LeftWrite + LeftRelu6: + vminps %ymm15, %ymm10, %ymm10 + LeftRelu: + vmaxps %ymm14, %ymm10, %ymm10 + LeftWrite: + cmpq $1, %r8 + je Write1 + cmpq $2, %r8 + je Write2 + cmpq $3, %r8 + je Write3 + cmpq $4, %r8 + je Write4 + cmpq $5, %r8 + je Write5 + cmpq $6, %r8 + je Write6 + cmpq $7, %r8 + je Write7 + jmp Write8 + Write1: + vmovss %xmm10, (%rdi) + addq $4, %rdi + jmp NextPixel + Write2: + vmovsd %xmm10, (%rdi) + addq $8, %rdi + jmp NextPixel + Write3: + vmovsd %xmm10, (%rdi) + movhlps %xmm10, %xmm10 + vmovss %xmm10, 8(%rdi) + addq $12, %rdi + jmp NextPixel + Write4: + vmovups %xmm10, (%rdi) + addq $16, %rdi + jmp NextPixel + Write5: + vmovups %xmm10, (%rdi) + vextractf128 $1, %ymm10, %xmm9 + vmovss %xmm9, 16(%rdi) + addq $20, %rdi + jmp NextPixel + Write6: + vmovups %xmm10, (%rdi) + vextractf128 $1, %ymm10, %xmm9 + vmovsd %xmm9, 16(%rdi) + addq $24, %rdi + jmp NextPixel + Write7: + vmovups %xmm10, (%rdi) + vextractf128 $1, %ymm10, %xmm9 + vmovsd %xmm9, 16(%rdi) + movhlps %xmm9, %xmm9 + vmovss %xmm9, 24(%rdi) + addq $28, %rdi + jmp NextPixel + Write8: + vmovups %ymm10, (%rdi) + add $32, %rdi + + NextPixel: + movq 8(%rsp), %rbp + addq %rbp, %rsi + movq -56(%rsp), %rax + subq $1, %rax + movq %rax, -56(%rsp) + cmpq $0, %rax + ja LoopPixel +End: + subq $96, %rsp + popq %rdi + popq %rsi + popq %rdx + popq %rcx + popq %r8 + popq %r9 + popq %rbp + popq %rbx + popq %r12 + popq %r13 + popq %r14 + popq %r15 + retq +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/avx/ConvDwFp32BorderAvx.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/avx/ConvDwFp32BorderAvx.S new file mode 100644 index 0000000000000000000000000000000000000000..9a4eab7c31bf9a6b5d10ff2dea06e5dd8aa78413 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/avx/ConvDwFp32BorderAvx.S @@ -0,0 +1,188 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_AVX +#include "nnacl/assembly_global.h" + +.text +.align 4 + +// void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, +// size_t width, size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu, +// size_t relu6); + +asm_function ConvDwFp32Border + pushq %r15 + pushq %r14 + pushq %r13 + pushq %r12 + pushq %rbx + pushq %rbp + pushq %r9 + pushq %r8 // -64 + pushq %rcx // -72 + pushq %rdx // -80 + pushq %rsi + pushq %rdi + addq $96, %rsp + + movq %rdi, %rdx +#ifdef _WIN32 + movq %rcx, %rdx +#endif + movq 8(%rdx), %r12 // src + movq 16(%rdx), %r13 // weight + movq 24(%rdx), %rbp // bias + movq 32(%rdx), %r11 // height + movq 40(%rdx), %r10 + movq %r10, -72(%rsp) // width + movq 48(%rdx), %r10 + movq %r10, -80(%rsp) // in_kh_step + movq 56(%rdx), %r10 // in_kw_step + movq 64(%rdx), %rax // kernel_w + movq 72(%rdx), %rcx // relu + movq 80(%rdx), %rbx // reul6 + movq $6, -64(%rsp) + movq (%rdx), %rdx + cmpq $0, %r11 + je End + + xorps %xmm8, %xmm8 + LoopHeight: + movq %r12, %rsi // src_kh, src_kw + movq %r13, %rdi // weight_kh, weight_kw + movq -72(%rsp), %r8 // width + + cmpq $6, %r8 + jae LoopWidth6 + cmpq $4, %r8 + jae LoopWidth4 + cmpq $1, %r8 + jae LoopWidth1 + jmp LoopWidthEnd + + LoopWidth6: + xorps %xmm6, %xmm6 + xorps %xmm7, %xmm7 + imul $3, %r10, %r9 + addq %rsi, %r9 + vmovups (%rsi), %xmm0 // src_kw + vmovups (%rsi, %r10), %xmm1 + vmovups (%rsi, %r10, 2), %xmm2 + vmovups (%r9), %xmm3 + vmovups (%rsi, %r10, 4), %xmm4 + vmovups (%r9, %r10, 2), %xmm5 + + vfmadd231ps (%rdi), %xmm0, %xmm6 + vfmadd231ps 16(%rdi), %xmm1, %xmm7 + vfmadd231ps 32(%rdi), %xmm2, %xmm8 + vfmadd231ps 48(%rdi), %xmm3, %xmm6 + vfmadd231ps 64(%rdi), %xmm4, %xmm7 + vfmadd231ps 80(%rdi), %xmm5, %xmm8 + + addps %xmm6, %xmm7 + imul $6, %r10, %r15 + addq $96, %rdi + addps %xmm7, %xmm8 + addq %r15, %rsi + + subq $6, %r8 + cmpq $6, %r8 + jae LoopWidth6 + cmpq $4, %r8 + jae LoopWidth4 + cmpq $0, %r8 + je LoopWidthEnd + jmp LoopWidth1 + + LoopWidth4: + xorps %xmm6, %xmm6 + xorps %xmm7, %xmm7 + imul $3, %r10, %r9 + addq %rsi, %r9 + vmovups (%rsi), %xmm0 // src_kw + vmovups (%rsi, %r10, 1), %xmm1 + vmovups (%rsi, %r10, 2), %xmm2 + vmovups (%r9), %xmm3 + + vfmadd231ps (%rdi), %xmm0, %xmm6 + vfmadd231ps 16(%rdi), %xmm1, %xmm7 + vfmadd231ps 32(%rdi), %xmm2, %xmm8 + vfmadd231ps 48(%rdi), %xmm3, %xmm6 + + addps %xmm6, %xmm7 + imul $4, %r10, %r15 + addq $64, %rdi + addps %xmm7, %xmm8 + addq %r15, %rsi + + subq $4, %r8 + cmpq $4, %r8 + jae LoopWidth4 + cmpq $0, %r8 + je LoopWidthEnd + jmp LoopWidth1 + + LoopWidth1: + vmovups (%rsi), %xmm0 // input_tmp + addq %r10, %rsi + vfmadd231ps (%rdi), %xmm0, %xmm8 + addq $16, %rdi + + subq $1, %r8 + cmpq $0, %r8 + ja LoopWidth1 + jmp LoopWidthEnd + + LoopWidthEnd: + subq $1, %r11 + cmpq $0, %r11 + je LoopHeightEnd + addq -80(%rsp), %r12 // in_kh_step + addq %rax, %r13 // kernel_w_step + jmp LoopHeight + + LoopHeightEnd: + xorps %xmm10, %xmm10 + vbroadcastss -64(%rsp), %xmm9 + + addps (%rbp), %xmm8 + cmpq $1, %rbx + je Relu6 + cmpq $1, %rcx + je Relu + jmp Write + Relu6: + minps %xmm9, %xmm8 + Relu: + maxps %xmm10, %xmm8 + Write: + movups %xmm8, (%rdx) +End: + subq $96, %rsp + popq %rdi + popq %rsi + popq %rdx + popq %rcx + popq %r8 + popq %r9 + popq %rbp + popq %rbx + popq %r12 + popq %r13 + popq %r14 + popq %r15 + retq +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/avx/ConvDwFp32RowAvx.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/avx/ConvDwFp32RowAvx.S new file mode 100644 index 0000000000000000000000000000000000000000..29dc965d97fdb14b55b1004f139645cf2561f37d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/avx/ConvDwFp32RowAvx.S @@ -0,0 +1,189 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_AVX +#include "nnacl/assembly_global.h" + +.text +.align 4 + +// void ConvDwFp32Row(float *output_ptr, const float *input_tmp, const float *weight_ptr, size_t num_pixels, +// size_t output_channel, size_t input_step); +// in linux x64 platform: +// rdi: output_ptr +// rsi: input_ptr +// rdx: weight_ptr +// rcx: num_pixels +// r8: output_channel +// r9: input_step + +// in win x64 platform: "shadow space" needs to be opened up for first four parameters ==> 32 bites +// rcx: output_ptr +// rdx: input_ptr +// r8: weight_ptr +// r9: num_pixels +// 40: output_channel +// 48: input_step + +asm_function ConvDwFp32Row + pushq %r15 + pushq %r14 + pushq %r13 + pushq %r12 + pushq %rsi + pushq %rdi + addq $48, %rsp + +#ifdef _WIN32 + movq %rcx, %rdi // output_ptr + movq %rdx, %rsi // input_ptr + movq %r8, %rdx // weight_ptr + movq %r9, %rcx // num_pixels + movq 40(%rsp), %r8 // output_channel + movq 48(%rsp), %r9 // input_step +#endif + + movq $4, %r13 + imul %r13, %r9 + movq %rsi, %r13 // input_ptr + movq %rdx, %r14 // weight_ptr + movq %r8, %r15 // output_channel + cmpq $0, %rcx + je End + + LoopPixel: + movq %r13, %rsi // input_tmp + movq %r14, %rdx // weight_tmp + movq %r15, %r8 // channel_tmp + + cmpq $32, %r8 + jae LoopC32 + cmpq $16, %r8 + jae LoopC16 + cmpq $8, %r8 + jae LoopC8 + cmpq $0, %r8 + ja LoopC + jmp LoopCEnd + + LoopC32: + vmovups (%rsi), %ymm0 // input_tmp + vmovups 32(%rsi), %ymm1 + vmovups 64(%rsi), %ymm2 + vmovups 96(%rsi), %ymm3 + + vmovups (%rdi), %ymm8 // output_tmp + vmovups 32(%rdi), %ymm9 + vmovups 64(%rdi), %ymm10 + vmovups 96(%rdi), %ymm11 + + addq $128, %rsi + vfmadd231ps (%rdx), %ymm0, %ymm8 + vfmadd231ps 32(%rdx), %ymm1, %ymm9 + vfmadd231ps 64(%rdx), %ymm2, %ymm10 + vfmadd231ps 96(%rdx), %ymm3, %ymm11 + + vmovups %ymm8, (%rdi) // output_ptr + vmovups %ymm9, 32(%rdi) + vmovups %ymm10, 64(%rdi) + vmovups %ymm11, 96(%rdi) + addq $128, %rdi + addq $128, %rdx + + subq $32, %r8 + cmpq $32, %r8 + jae LoopC32 + cmpq $16, %r8 + jae LoopC16 + cmpq $8, %r8 + jae LoopC8 + cmpq $0, %r8 + ja LoopC + jmp LoopCEnd + + LoopC16: + vmovups (%rsi), %ymm0 // input_tmp + vmovups (%rdi), %ymm8 // output_tmp + vmovups 32(%rsi), %ymm1 + vmovups 32(%rdi), %ymm9 + addq $64, %rsi + + vfmadd231ps (%rdx), %ymm0, %ymm8 + vfmadd231ps 32(%rdx), %ymm1, %ymm9 + + vmovups %ymm8, (%rdi) // output_ptr + addq $64, %rdx + vmovups %ymm9, 32(%rdi) + addq $64, %rdi + + subq $16, %r8 + cmpq $16, %r8 + jae LoopC16 + cmpq $8, %r8 + jae LoopC8 + cmpq $0, %r8 + ja LoopC + jmp LoopCEnd + + LoopC8: + vmovups (%rsi), %ymm0 // input_tmp + vmovups (%rdi), %ymm8 // output_tmp + addq $32, %rsi + + vfmadd231ps (%rdx), %ymm0, %ymm8 + + addq $32, %rdx + vmovups %ymm8, (%rdi) + addq $32, %rdi + + subq $8, %r8 + cmpq $8, %r8 + jae LoopC8 + cmpq $0, %r8 + ja LoopC + jmp LoopCEnd + + LoopC: + vmovss (%rsi), %xmm0 // input_tmp + vmovss (%rdi), %xmm8 // output_ptr + + vfmadd231ss (%rdx), %xmm0, %xmm8 + + addq $4, %rsi + addq $4, %rdx + vmovss %xmm8, (%rdi) + addq $4, %rdi + + subq $1, %r8 + cmpq $0, %r8 + ja LoopC + jmp LoopCEnd + + LoopCEnd: + subq $1, %rcx // num_pixel -= 1 + cmpq $0, %rcx + je End + addq %r9, %r13 + jmp LoopPixel +End: + subq $48, %rsp + popq %rdi + popq %rsi + popq %r12 + popq %r13 + popq %r14 + popq %r15 + retq +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/avx/ConvDwFp32RowOptAVX.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/avx/ConvDwFp32RowOptAVX.S new file mode 100644 index 0000000000000000000000000000000000000000..f7c2d5e4062042ac11e0c4f2b80e9dda97fc7693 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/avx/ConvDwFp32RowOptAVX.S @@ -0,0 +1,382 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_AVX +#include "nnacl/assembly_global.h" + +.text +.align 4 + +// void ConvDwAVXFp32Row(float *output_ptr, const float *input_tmp, const float *weight_ptr, size_t num_pixels, +// size_t output_channel, size_t input_step); +// in linux x64 platform: +// rdi: output_ptr +// rsi: input_ptr +// rdx: weight_ptr +// rcx: num_pixels +// r8: output_channel +// r9: input_step + +// in win x64 platform: "shadow space" needs to be opened up for first four parameters ==> 32 bites +// rcx: output_ptr +// rdx: input_ptr +// r8: weight_ptr +// r9: num_pixels +// 40: output_channel +// 48: input_step + +asm_function ConvDwAVXFp32Row + pushq %r15 + pushq %r14 + pushq %r13 + pushq %r12 + pushq %rsi + pushq %rdi + addq $48, %rsp + +#ifdef _WIN32 + movq %rcx, %rdi // output_ptr + movq %rdx, %rsi // input_ptr + movq %r8, %rdx // weight_ptr + movq %r9, %rcx // num_pixels + movq 40(%rsp), %r8 // output_channel + movq 48(%rsp), %r9 // input_step + movq 56(%rsp), %r11 // first_calc_flag + movq 64(%rsp), %r10 // bias +#else + movq 8(%rsp), %r11 // first_calc_flag + movq 16(%rsp), %r10 // bias +#endif + + + movq $4, %r13 + imul %r13, %r9 + movq %r8, %r12 + imul %r13, %r12 + movq %rsi, %r13 // input_ptr + movq %rdx, %r14 // weight_ptr + movq %r8, %r15 // output_channel + + cmpq $1, %r11 + je OutputInitByBias + jmp OutputInitBySelf + +OutputInitByBias: + cmpq $3, %rcx + jae BiasLoopPixelNum4 + cmpq $0, %rcx + ja BiasLoopPixel + je End + + BiasLoopPixelNum4: + movq %r13, %rsi // input_tmp + movq %r14, %rdx // weight_tmp + movq %r15, %r8 // channel_tmp + movq 16(%rsp), %r10 // bias_tmp + + cmpq $8, %r8 + jae BiasLoopC8Num4 + cmpq $0, %r8 + ja BiasLoopCNum4 + jmp BiasLoopCEndNum4 + + BiasLoopC8Num4: + vmovups (%rsi), %ymm0 // input_tmp + vmovups (%rsi, %r9), %ymm1 + vmovups (%rsi, %r9, 2), %ymm2 + // vmovups (%rsi, %r9, 3), %ymm3 + + vmovups (%r10), %ymm8 // output_tmp + vmovups (%r10), %ymm9 // output_tmp + vmovups (%r10), %ymm10 // output_tmp + // vmovups (%r10), %ymm11 // output_tmp + addq $32, %rsi + addq $32, %r10 + + vmovups (%rdx), %ymm15 // weight_tmp + vfmadd231ps %ymm15, %ymm0, %ymm8 + vfmadd231ps %ymm15, %ymm1, %ymm9 + vfmadd231ps %ymm15, %ymm2, %ymm10 + // vfmadd231ps %ymm15, %ymm3, %ymm11 + + addq $32, %rdx + vmovups %ymm8, (%rdi) + vmovups %ymm9, (%rdi, %r12) + vmovups %ymm10, (%rdi, %r12, 2) + // vmovups %ymm11, (%rdi, %r12, 3) + addq $32, %rdi + + subq $8, %r8 + cmpq $8, %r8 + jae BiasLoopC8Num4 + cmpq $0, %r8 + ja BiasLoopCNum4 + jmp BiasLoopCEndNum4 + + BiasLoopCNum4: + vmovss (%rsi), %xmm0 // input_tmp + vmovss (%rsi, %r9), %xmm1 + vmovss (%rsi, %r9, 2), %xmm2 + // vmovss (%rsi, %r9, 3), %xmm3 + + vmovss (%r10), %xmm8 // output_ptr + vmovss (%r10), %xmm9 // output_tmp + vmovss (%r10), %xmm10 // output_tmp + // vmovss (%r10), %xmm11 // output_tmp + addq $4, %r10 + + vmovss (%rdx), %xmm15 // weight_tmp + vfmadd231ss %xmm15, %xmm0, %xmm8 + vfmadd231ss %xmm15, %xmm1, %xmm9 + vfmadd231ss %xmm15, %xmm2, %xmm10 + // vfmadd231ss %xmm15, %xmm3, %xmm11 + + addq $4, %rsi + addq $4, %rdx + vmovss %xmm8, (%rdi) + vmovss %xmm9, (%rdi, %r12) + vmovss %xmm10, (%rdi, %r12, 2) + // vmovss %xmm11, (%rdi, %r12, 3) + addq $4, %rdi + + subq $1, %r8 + cmpq $0, %r8 + ja BiasLoopCNum4 + jmp BiasLoopCEndNum4 + + BiasLoopCEndNum4: + subq $3, %rcx // num_pixel -= 3 + addq %r12, %rdi + addq %r12, %rdi + + addq %r9, %r13 + addq %r9, %r13 + addq %r9, %r13 + cmpq $3, %rcx + jae BiasLoopPixelNum4 + cmpq $0, %rcx + ja BiasLoopPixel + je End + + BiasLoopPixel: + movq %r13, %rsi // input_tmp + movq %r14, %rdx // weight_tmp + movq %r15, %r8 // channel_tmp + movq 16(%rsp), %r10 // bias_tmp + + cmpq $8, %r8 + jae BiasLoopC8 + cmpq $0, %r8 + ja BiasLoopC + jmp BiasLoopCEnd + + BiasLoopC8: + vmovups (%rsi), %ymm0 // input_tmp + vmovups (%r10), %ymm8 // output_tmp + addq $32, %rsi + addq $32, %r10 + + vfmadd231ps (%rdx), %ymm0, %ymm8 + + addq $32, %rdx + vmovups %ymm8, (%rdi) + addq $32, %rdi + + subq $8, %r8 + cmpq $8, %r8 + jae BiasLoopC8 + cmpq $0, %r8 + ja BiasLoopC + jmp BiasLoopCEnd + + BiasLoopC: + vmovss (%rsi), %xmm0 // input_tmp + vmovss (%r10), %xmm8 // output_ptr + addq $4, %r10 + + vfmadd231ss (%rdx), %xmm0, %xmm8 + + addq $4, %rsi + addq $4, %rdx + vmovss %xmm8, (%rdi) + addq $4, %rdi + + subq $1, %r8 + cmpq $0, %r8 + ja BiasLoopC + jmp BiasLoopCEnd + + BiasLoopCEnd: + subq $1, %rcx // num_pixel -= 1 + cmpq $0, %rcx + je End + addq %r9, %r13 + jmp BiasLoopPixel + +OutputInitBySelf: + cmpq $3, %rcx + jae LoopPixelNum4 + cmpq $0, %rcx + ja LoopPixel + je End + + LoopPixelNum4: + movq %r13, %rsi // input_tmp + movq %r14, %rdx // weight_tmp + movq %r15, %r8 // channel_tmp + + cmpq $8, %r8 + jae LoopC8Num4 + cmpq $0, %r8 + ja LoopCNum4 + jmp LoopCEndNum4 + + LoopC8Num4: + vmovups (%rsi), %ymm0 // input_tmp + vmovups (%rsi, %r9), %ymm1 + vmovups (%rsi, %r9, 2), %ymm2 + // vmovups (%rsi, %r9, 3), %ymm3 + + vmovups (%rdi), %ymm8 // output_tmp + vmovups (%rdi, %r12), %ymm9 // output_tmp + vmovups (%rdi, %r12, 2), %ymm10 // output_tmp + // vmovups (%rdi, %r12, 3), %ymm11 // output_tmp + addq $32, %rsi + + vmovups (%rdx), %ymm15 // weight_tmp + vfmadd231ps %ymm15, %ymm0, %ymm8 + vfmadd231ps %ymm15, %ymm1, %ymm9 + vfmadd231ps %ymm15, %ymm2, %ymm10 + // vfmadd231ps %ymm15, %ymm3, %ymm11 + + addq $32, %rdx + vmovups %ymm8, (%rdi) + vmovups %ymm9, (%rdi, %r12) + vmovups %ymm10, (%rdi, %r12, 2) + // vmovups %ymm11, (%rdi, %r12, 3) + addq $32, %rdi + + subq $8, %r8 + cmpq $8, %r8 + jae LoopC8Num4 + cmpq $0, %r8 + ja LoopCNum4 + jmp LoopCEndNum4 + + LoopCNum4: + vmovss (%rsi), %xmm0 // input_tmp + vmovss (%rsi, %r9), %xmm1 + vmovss (%rsi, %r9, 2), %xmm2 + // vmovss (%rsi, %r9, 3), %xmm3 + + vmovss (%rdi), %xmm8 // output_ptr + vmovss (%rdi, %r12), %xmm9 // output_tmp + vmovss (%rdi, %r12, 2), %xmm10 // output_tmp + // vmovss (%rdi, %r12, 3), %xmm11 // output_tmp + + vmovss (%rdx), %xmm15 // weight_tmp + vfmadd231ss %xmm15, %xmm0, %xmm8 + vfmadd231ss %xmm15, %xmm1, %xmm9 + vfmadd231ss %xmm15, %xmm2, %xmm10 + // vfmadd231ss %xmm15, %xmm3, %xmm11 + + addq $4, %rsi + addq $4, %rdx + vmovss %xmm8, (%rdi) + vmovss %xmm9, (%rdi, %r12) + vmovss %xmm10, (%rdi, %r12, 2) + // vmovss %xmm11, (%rdi, %r12, 3) + addq $4, %rdi + + subq $1, %r8 + cmpq $0, %r8 + ja LoopCNum4 + jmp LoopCEndNum4 + + LoopCEndNum4: + subq $3, %rcx // num_pixel -= 3 + addq %r12, %rdi + addq %r12, %rdi + + addq %r9, %r13 + addq %r9, %r13 + addq %r9, %r13 + cmpq $3, %rcx + jae LoopPixelNum4 + cmpq $0, %rcx + ja LoopPixel + je End + + LoopPixel: + movq %r13, %rsi // input_tmp + movq %r14, %rdx // weight_tmp + movq %r15, %r8 // channel_tmp + + cmpq $8, %r8 + jae LoopC8 + cmpq $0, %r8 + ja LoopC + jmp LoopCEnd + + LoopC8: + vmovups (%rsi), %ymm0 // input_tmp + vmovups (%rdi), %ymm8 // output_tmp + addq $32, %rsi + + vfmadd231ps (%rdx), %ymm0, %ymm8 + + addq $32, %rdx + vmovups %ymm8, (%rdi) + addq $32, %rdi + + subq $8, %r8 + cmpq $8, %r8 + jae LoopC8 + cmpq $0, %r8 + ja LoopC + jmp LoopCEnd + + LoopC: + vmovss (%rsi), %xmm0 // input_tmp + vmovss (%rdi), %xmm8 // output_ptr + + vfmadd231ss (%rdx), %xmm0, %xmm8 + + addq $4, %rsi + addq $4, %rdx + vmovss %xmm8, (%rdi) + addq $4, %rdi + + subq $1, %r8 + cmpq $0, %r8 + ja LoopC + jmp LoopCEnd + + LoopCEnd: + subq $1, %rcx // num_pixel -= 1 + cmpq $0, %rcx + je End + addq %r9, %r13 + jmp LoopPixel +End: + subq $48, %rsp + popq %rdi + popq %rsi + popq %r12 + popq %r13 + popq %r14 + popq %r15 + retq +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/avx/MatmulAvx.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/avx/MatmulAvx.S new file mode 100644 index 0000000000000000000000000000000000000000..aef2c3addeb19e710e705ff48da091d5a622d3bf --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/avx/MatmulAvx.S @@ -0,0 +1,993 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_AVX +#include "nnacl/assembly_global.h" + +.text +.align 4 + +// void MatmulFloatAvxOpt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth +// int row, int col, size_t stride, size_t writeMode) +// parameters pass in Linux x86 platform: +// rdi: a +// rsi: b +// rdx: c +// rcx: bias +// r8: act_type +// r9: depth +// 8: row +// 16: col +// 24: stride +// 32: writeNhwc/writeWino + +// parameters pass in win x64 platform: "shadow space" needs to be opened up for first four parameters ==> 32 bites +// rcx: a +// rdx: b +// r8: c +// r9: bias +// 40: act_type +// 48: depth +// 56: row +// 64: col +// 72: stride +// 80: writeMode + +asm_function MatmulFloatAvxOpt + // rbx, rsp, rbp, r12-r15 must be saved according to x86 calling convention + pushq %r15 + pushq %r14 + pushq %r13 + pushq %r12 + pushq %rbx + pushq %rbp + pushq %r9 // -56 + pushq %r8 // -64 + pushq %rcx // -72 + pushq %rdx // -80 + pushq %rsi // -88 + pushq %rdi // -96 + pushq %rsi // -104 rsi + pushq %rdi // -112 rdi + addq $112, %rsp +#ifdef _WIN32 + movq %rcx, %rdi + movq %rdx, %rsi + movq %r8, %rdx + movq %r9, %rcx + movq 40(%rsp), %r8 // act_type + movq 48(%rsp), %r9 // depth + movq %r9, -56(%rsp) // r9 + movq %rcx, -72(%rsp) // rcx + movq %rdx, -80(%rsp) // rdx + movq %rsi, -88(%rsp) // rsi + movq %rdi, -96(%rsp) // rdi + + movq 56(%rsp), %rbp // row + movq %rbp, 8(%rsp) + movq 64(%rsp), %rbp // col + movq %rbp, 16(%rsp) + movq 72(%rsp), %rbp // stride + movq %rbp, 24(%rsp) + movq 80(%rsp), %rbp // weiteMode + movq %rbp, 32(%rsp) +#endif + movq 8(%rsp), %rbp + movq 16(%rsp), %rbx + movq 24(%rsp), %r10 + movq 32(%rsp), %r14 + + movq $24, %r11 + imul %r9, %r11 + cmpq $0, %r14 + jne NoC8Steps + movq $48, %r13 + imul %rbp, %r13 +NoC8Steps: + cmpq $2, %r14 + jne NoWinoSteps + movq $4, %r12 + imul %r10, %r12 + imul %rbx, %r12 + movq $32, %r13 + imul %r10, %r13 +NoWinoSteps: + movq $4, %rax + imul %rax, %r10 + +LoopRow: + movq -88(%rsp), %rsi + movq 16(%rsp), %rbx + movq -72(%rsp), %rcx + + LoopCol: + cmpq $0, %r14 + je NoReloadDst + movq -80(%rsp), %rdx + NoReloadDst: + movq -96(%rsp), %rdi + movq -56(%rsp), %r9 + + vmovups (%rsi), %ymm0 + vmovups 32(%rsi), %ymm1 + vbroadcastss (%rdi), %ymm10 + vbroadcastss 4(%rdi), %ymm11 + vbroadcastss 8(%rdi), %ymm12 + vbroadcastss 12(%rdi), %ymm13 + vbroadcastss 16(%rdi), %ymm2 + vbroadcastss 20(%rdi), %ymm3 + addq $64, %rsi + vmulps %ymm0, %ymm10, %ymm4 + vmulps %ymm1, %ymm10, %ymm5 + vmulps %ymm0, %ymm11, %ymm6 + vmulps %ymm1, %ymm11, %ymm7 + vmulps %ymm0, %ymm12, %ymm8 + vmulps %ymm1, %ymm12, %ymm9 + vmulps %ymm0, %ymm13, %ymm10 + vmulps %ymm1, %ymm13, %ymm11 + add $24, %rdi + vmulps %ymm0, %ymm2, %ymm12 + vmulps %ymm1, %ymm2, %ymm13 + vmulps %ymm0, %ymm3, %ymm14 + vmulps %ymm1, %ymm3, %ymm15 + + subq $1, %r9 + cmpq $0, %r9 + je Bias + + LoopDepth: + vmovups (%rsi), %ymm0 + vmovups 32(%rsi), %ymm1 + vbroadcastss (%rdi), %ymm2 + vbroadcastss 4(%rdi), %ymm3 + vfmadd231ps %ymm0, %ymm2, %ymm4 + addq $64, %rsi + vfmadd231ps %ymm1, %ymm2, %ymm5 + vbroadcastss 8(%rdi), %ymm2 + vfmadd231ps %ymm0, %ymm3, %ymm6 + vfmadd231ps %ymm1, %ymm3, %ymm7 + vbroadcastss 12(%rdi), %ymm3 + vfmadd231ps %ymm0, %ymm2, %ymm8 + prefetcht0 384(%rsi) + vfmadd231ps %ymm1, %ymm2, %ymm9 + vbroadcastss 16(%rdi), %ymm2 + vfmadd231ps %ymm0, %ymm3, %ymm10 + vfmadd231ps %ymm1, %ymm3, %ymm11 + vbroadcastss 20(%rdi), %ymm3 + vfmadd231ps %ymm0, %ymm2, %ymm12 + vfmadd231ps %ymm1, %ymm2, %ymm13 + addq $24, %rdi + vfmadd231ps %ymm0, %ymm3, %ymm14 + vfmadd231ps %ymm1, %ymm3, %ymm15 + + subq $1, %r9 + cmpq $0, %r9 + ja LoopDepth + + Bias: + cmpq $0, %rcx + je Activation + vmovups (%rcx), %ymm0 + vmovups 32(%rcx), %ymm1 + add $64, %rcx + vaddps %ymm0, %ymm4, %ymm4 + vaddps %ymm1, %ymm5, %ymm5 + vaddps %ymm0, %ymm6, %ymm6 + vaddps %ymm1, %ymm7, %ymm7 + vaddps %ymm0, %ymm8, %ymm8 + vaddps %ymm1, %ymm9, %ymm9 + vaddps %ymm0, %ymm10, %ymm10 + vaddps %ymm1, %ymm11, %ymm11 + vaddps %ymm0, %ymm12, %ymm12 + vaddps %ymm1, %ymm13, %ymm13 + vaddps %ymm0, %ymm14, %ymm14 + vaddps %ymm1, %ymm15, %ymm15 + + Activation: + cmpq $3, %r8 + je Relu6 + cmpq $1, %r8 + je Relu + jmp Write + + Relu6: + movq $6, %rax + vcvtsi2ss %rax, %xmm0, %xmm0 + vshufps $0, %xmm0, %xmm0, %xmm0 + vinsertf128 $1, %xmm0, %ymm0, %ymm0 + vminps %ymm0, %ymm4, %ymm4 + vminps %ymm0, %ymm5, %ymm5 + vminps %ymm0, %ymm6, %ymm6 + vminps %ymm0, %ymm7, %ymm7 + vminps %ymm0, %ymm8, %ymm8 + vminps %ymm0, %ymm9, %ymm9 + vminps %ymm0, %ymm10, %ymm10 + vminps %ymm0, %ymm11, %ymm11 + vminps %ymm0, %ymm12, %ymm12 + vminps %ymm0, %ymm13, %ymm13 + vminps %ymm0, %ymm14, %ymm14 + vminps %ymm0, %ymm15, %ymm15 + + Relu: + vxorps %ymm1, %ymm1, %ymm1 + vmaxps %ymm1, %ymm4, %ymm4 + vmaxps %ymm1, %ymm5, %ymm5 + vmaxps %ymm1, %ymm6, %ymm6 + vmaxps %ymm1, %ymm7, %ymm7 + vmaxps %ymm1, %ymm8, %ymm8 + vmaxps %ymm1, %ymm9, %ymm9 + vmaxps %ymm1, %ymm10, %ymm10 + vmaxps %ymm1, %ymm11, %ymm11 + vmaxps %ymm1, %ymm12, %ymm12 + vmaxps %ymm1, %ymm13, %ymm13 + vmaxps %ymm1, %ymm14, %ymm14 + vmaxps %ymm1, %ymm15, %ymm15 + + Write: + cmpq $2, %r14 + je WriteWino + cmpq $0, %r14 + je WriteC8 + cmpq $1, %rbx + je Write1 + cmpq $2, %rbx + je Write2 + cmpq $3, %rbx + je Write3 + cmpq $4, %rbx + je Write4 + cmpq $5, %rbx + je Write5 + cmpq $6, %rbx + je Write6 + cmpq $7, %rbx + je Write7 + cmpq $8, %rbx + je Write8 + cmpq $9, %rbx + je Write9 + cmpq $10, %rbx + je Write10 + cmpq $11, %rbx + je Write11 + cmpq $12, %rbx + je Write12 + cmpq $13, %rbx + je Write13 + cmpq $14, %rbx + je Write14 + cmpq $15, %rbx + je Write15 + jmp Write16 + + Write1: + movq %rdx, %rax + addq $4, %rax + movq %rax, -80(%rsp) + vmovss %xmm4, (%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovss %xmm6, (%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovss %xmm8, (%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovss %xmm10, (%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovss %xmm12, (%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovss %xmm14, (%rdx) + addq %r10, %rdx + addq $4, %rdx + jmp WriteEnd + Write2: + movq %rdx, %rax + addq $8, %rax + movq %rax, -80(%rsp) + vmovsd %xmm4, (%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovsd %xmm6, (%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovsd %xmm8, (%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovsd %xmm10, (%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovsd %xmm12, (%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovsd %xmm14, (%rdx) + addq %r10, %rdx + addq $8, %rdx + jmp WriteEnd + Write3: + movq %rdx, %rax + addq $12, %rax + movq %rax, -80(%rsp) + vmovsd %xmm4, (%rdx) + movhlps %xmm4, %xmm4 + vmovss %xmm4, 8(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovsd %xmm6, (%rdx) + movhlps %xmm6, %xmm6 + vmovss %xmm6, 8(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovsd %xmm8, (%rdx) + movhlps %xmm8, %xmm8 + vmovss %xmm8, 8(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovsd %xmm10, (%rdx) + movhlps %xmm10, %xmm10 + vmovss %xmm10, 8(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovsd %xmm12, (%rdx) + movhlps %xmm12, %xmm12 + vmovss %xmm12, 8(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovsd %xmm14, (%rdx) + movhlps %xmm14, %xmm14 + vmovss %xmm14, 8(%rdx) + addq %r10, %rdx + addq $12, %rdx + jmp WriteEnd + Write4: + movq %rdx, %rax + addq $16, %rax + movq %rax, -80(%rsp) + vmovups %xmm4, (%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm6, (%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm8, (%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm10, (%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm12, (%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm14, (%rdx) + addq %r10, %rdx + addq $16, %rdx + jmp WriteEnd + Write5: + movq %rdx, %rax + addq $20, %rax + movq %rax, -80(%rsp) + vmovups %xmm4, (%rdx) + vextractf128 $1, %ymm4, %xmm4 + vmovss %xmm4, 16(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm6, (%rdx) + vextractf128 $1, %ymm6, %xmm6 + vmovss %xmm6, 16(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm8, (%rdx) + vextractf128 $1, %ymm8, %xmm8 + vmovss %xmm8, 16(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm10, (%rdx) + vextractf128 $1, %ymm10, %xmm10 + vmovss %xmm10, 16(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm12, (%rdx) + vextractf128 $1, %ymm12, %xmm12 + vmovss %xmm12, 16(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm14, (%rdx) + vextractf128 $1, %ymm14, %xmm14 + vmovss %xmm14, 16(%rdx) + addq %r10, %rdx + addq $20, %rdx + jmp WriteEnd + Write6: + movq %rdx, %rax + addq $24, %rax + movq %rax, -80(%rsp) + vmovups %xmm4, (%rdx) + vextractf128 $1, %ymm4, %xmm4 + vmovsd %xmm4, 16(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm6, (%rdx) + vextractf128 $1, %ymm6, %xmm6 + vmovsd %xmm6, 16(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm8, (%rdx) + vextractf128 $1, %ymm8, %xmm8 + vmovsd %xmm8, 16(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm10, (%rdx) + vextractf128 $1, %ymm10, %xmm10 + vmovsd %xmm10, 16(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm12, (%rdx) + vextractf128 $1, %ymm12, %xmm12 + vmovsd %xmm12, 16(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm14, (%rdx) + vextractf128 $1, %ymm14, %xmm14 + vmovsd %xmm14, 16(%rdx) + addq %r10, %rdx + addq $24, %rdx + jmp WriteEnd + Write7: + movq %rdx, %rax + addq $28, %rax + movq %rax, -80(%rsp) + vmovups %xmm4, (%rdx) + vextractf128 $1, %ymm4, %xmm4 + vmovsd %xmm4, 16(%rdx) + movhlps %xmm4, %xmm4 + vmovss %xmm4, 24(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm6, (%rdx) + vextractf128 $1, %ymm6, %xmm6 + vmovsd %xmm6, 16(%rdx) + movhlps %xmm6, %xmm6 + vmovss %xmm6, 24(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm8, (%rdx) + vextractf128 $1, %ymm8, %xmm8 + vmovsd %xmm8, 16(%rdx) + movhlps %xmm8, %xmm8 + vmovss %xmm8, 24(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm10, (%rdx) + vextractf128 $1, %ymm10, %xmm10 + vmovsd %xmm10, 16(%rdx) + movhlps %xmm10, %xmm10 + vmovss %xmm10, 24(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm12, (%rdx) + vextractf128 $1, %ymm12, %xmm12 + vmovsd %xmm12, 16(%rdx) + movhlps %xmm12, %xmm12 + vmovss %xmm12, 24(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %xmm14, (%rdx) + vextractf128 $1, %ymm14, %xmm14 + vmovsd %xmm14, 16(%rdx) + movhlps %xmm14, %xmm14 + vmovss %xmm14, 24(%rdx) + addq %r10, %rdx + addq $28, %rdx + jmp WriteEnd + Write8: + movq %rdx, %rax + addq $32, %rax + movq %rax, -80(%rsp) + vmovups %ymm4, (%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm6, (%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm8, (%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm10, (%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm12, (%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm14, (%rdx) + addq %r10, %rdx + addq $32, %rdx + jmp WriteEnd + Write9: + movq %rdx, %rax + addq $36, %rax + movq %rax, -80(%rsp) + vmovups %ymm4, (%rdx) + vmovss %xmm5, 32(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm6, (%rdx) + vmovss %xmm7, 32(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm8, (%rdx) + vmovss %xmm9, 32(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm10, (%rdx) + vmovss %xmm11, 32(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm12, (%rdx) + vmovss %xmm13, 32(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm14, (%rdx) + vmovss %xmm15, 32(%rdx) + addq %r10, %rdx + addq $36, %rdx + jmp WriteEnd + Write10: + movq %rdx, %rax + addq $40, %rax + movq %rax, -80(%rsp) + vmovups %ymm4, (%rdx) + vmovsd %xmm5, 32(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm6, (%rdx) + vmovsd %xmm7, 32(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm8, (%rdx) + vmovsd %xmm9, 32(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm10, (%rdx) + vmovsd %xmm11, 32(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm12, (%rdx) + vmovsd %xmm13, 32(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm14, (%rdx) + vmovsd %xmm15, 32(%rdx) + addq %r10, %rdx + addq $40, %rdx + jmp WriteEnd + Write11: + movq %rdx, %rax + addq $44, %rax + movq %rax, -80(%rsp) + vmovups %ymm4, (%rdx) + vmovsd %xmm5, 32(%rdx) + movhlps %xmm5, %xmm5 + vmovss %xmm5, 40(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm6, (%rdx) + vmovsd %xmm7, 32(%rdx) + movhlps %xmm7, %xmm7 + vmovss %xmm7, 40(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm8, (%rdx) + vmovsd %xmm9, 32(%rdx) + movhlps %xmm9, %xmm9 + vmovss %xmm9, 40(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm10, (%rdx) + vmovsd %xmm11, 32(%rdx) + movhlps %xmm11, %xmm11 + vmovss %xmm11, 40(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm12, (%rdx) + vmovsd %xmm13, 32(%rdx) + movhlps %xmm13, %xmm13 + vmovss %xmm13, 40(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm14, (%rdx) + vmovsd %xmm15, 32(%rdx) + movhlps %xmm15, %xmm15 + vmovss %xmm15, 40(%rdx) + addq %r10, %rdx + addq $44, %rdx + jmp WriteEnd + Write12: + movq %rdx, %rax + addq $48, %rax + movq %rax, -80(%rsp) + vmovups %ymm4, (%rdx) + vmovups %xmm5, 32(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm6, (%rdx) + vmovups %xmm7, 32(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm8, (%rdx) + vmovups %xmm9, 32(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm10, (%rdx) + vmovups %xmm11, 32(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm12, (%rdx) + vmovups %xmm13, 32(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm14, (%rdx) + vmovups %xmm15, 32(%rdx) + addq %r10, %rdx + addq $48, %rdx + jmp WriteEnd + Write13: + movq %rdx, %rax + addq $52, %rax + movq %rax, -80(%rsp) + vmovups %ymm4, (%rdx) + vmovups %xmm5, 32(%rdx) + vextractf128 $1, %ymm5, %xmm5 + vmovss %xmm5, 48(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm6, (%rdx) + vmovups %xmm7, 32(%rdx) + vextractf128 $1, %ymm7, %xmm7 + vmovss %xmm7, 48(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm8, (%rdx) + vmovups %xmm9, 32(%rdx) + vextractf128 $1, %ymm9, %xmm9 + vmovss %xmm9, 48(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm10, (%rdx) + vmovups %xmm11, 32(%rdx) + vextractf128 $1, %ymm11, %xmm11 + vmovss %xmm11, 48(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm12, (%rdx) + vmovups %xmm13, 32(%rdx) + vextractf128 $1, %ymm13, %xmm13 + vmovss %xmm13, 48(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm14, (%rdx) + vmovups %xmm15, 32(%rdx) + vextractf128 $1, %ymm15, %xmm15 + vmovss %xmm15, 48(%rdx) + addq %r10, %rdx + addq $52, %rdx + jmp WriteEnd + Write14: + movq %rdx, %rax + addq $56, %rax + movq %rax, -80(%rsp) + vmovups %ymm4, (%rdx) + vmovups %xmm5, 32(%rdx) + vextractf128 $1, %ymm5, %xmm5 + vmovsd %xmm5, 48(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm6, (%rdx) + vmovups %xmm7, 32(%rdx) + vextractf128 $1, %ymm7, %xmm7 + vmovsd %xmm7, 48(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm8, (%rdx) + vmovups %xmm9, 32(%rdx) + vextractf128 $1, %ymm9, %xmm9 + vmovsd %xmm9, 48(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm10, (%rdx) + vmovups %xmm11, 32(%rdx) + vextractf128 $1, %ymm11, %xmm11 + vmovsd %xmm11, 48(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm12, (%rdx) + vmovups %xmm13, 32(%rdx) + vextractf128 $1, %ymm13, %xmm13 + vmovsd %xmm13, 48(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm14, (%rdx) + vmovups %xmm15, 32(%rdx) + vextractf128 $1, %ymm15, %xmm15 + vmovsd %xmm15, 48(%rdx) + addq %r10, %rdx + addq $56, %rdx + jmp WriteEnd + Write15: + movq %rdx, %rax + addq $60, %rax + movq %rax, -80(%rsp) + vmovups %ymm4, (%rdx) + vmovups %xmm5, 32(%rdx) + vextractf128 $1, %ymm5, %xmm5 + vmovsd %xmm5, 48(%rdx) + movhlps %xmm5, %xmm5 + vmovss %xmm5, 56(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm6, (%rdx) + vmovups %xmm7, 32(%rdx) + vextractf128 $1, %ymm7, %xmm7 + vmovsd %xmm7, 48(%rdx) + movhlps %xmm7, %xmm7 + vmovss %xmm7, 56(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm8, (%rdx) + vmovups %xmm9, 32(%rdx) + vextractf128 $1, %ymm9, %xmm9 + vmovsd %xmm9, 48(%rdx) + movhlps %xmm9, %xmm9 + vmovss %xmm9, 56(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm10, (%rdx) + vmovups %xmm11, 32(%rdx) + vextractf128 $1, %ymm11, %xmm11 + vmovsd %xmm11, 48(%rdx) + movhlps %xmm11, %xmm11 + vmovss %xmm11, 56(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm12, (%rdx) + vmovups %xmm13, 32(%rdx) + vextractf128 $1, %ymm13, %xmm13 + vmovsd %xmm13, 48(%rdx) + movhlps %xmm13, %xmm13 + vmovss %xmm13, 56(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm14, (%rdx) + vmovups %xmm15, 32(%rdx) + vextractf128 $1, %ymm15, %xmm15 + vmovsd %xmm15, 48(%rdx) + movhlps %xmm15, %xmm15 + vmovss %xmm15, 56(%rdx) + addq %r10, %rdx + addq $60, %rdx + jmp WriteEnd + WriteC8: + movq %rdx, %rax + addq %r11, %rdx + movq %rdx, %r15 + addq %r11, %rdx + movq %rdx, -80(%rsp) + vmovups %ymm4, (%rax) + vmovups %ymm6, 32(%rax) + vmovups %ymm8, 64(%rax) + vmovups %ymm10, 96(%rax) + vmovups %ymm12, 128(%rax) + vmovups %ymm14, 160(%rax) + vmovups %ymm5, (%r15) + vmovups %ymm7, 32(%r15) + vmovups %ymm9, 64(%r15) + vmovups %ymm11, 96(%r15) + vmovups %ymm13, 128(%r15) + vmovups %ymm15, 160(%r15) + jmp WriteEnd + WriteWino: + movq %rdx, %rax + addq %r13, %rax + movq %rax, -80(%rsp) + vmovups %ymm4, (%rdx) + addq %r12, %rdx + vmovups %ymm6, (%rdx) + addq %r12, %rdx + vmovups %ymm8, (%rdx) + addq %r12, %rdx + vmovups %ymm10, (%rdx) + addq %r12, %rdx + vmovups %ymm12, (%rdx) + addq %r12, %rdx + vmovups %ymm14, (%rdx) + cmpq $8, %rbx + je WriteEnd + movq %rax, %rdx + addq %r13, %rax + movq %rax, -80(%rsp) + vmovups %ymm5, (%rdx) + addq %r12, %rdx + vmovups %ymm7, (%rdx) + addq %r12, %rdx + vmovups %ymm9, (%rdx) + addq %r12, %rdx + vmovups %ymm11, (%rdx) + addq %r12, %rdx + vmovups %ymm13, (%rdx) + addq %r12, %rdx + vmovups %ymm15, (%rdx) + jmp WriteEnd + Write16: + movq %rdx, %rax + addq $64, %rax + movq %rax, -80(%rsp) + vmovups %ymm4, (%rdx) + vmovups %ymm5, 32(%rdx) + cmpq $1, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm6, (%rdx) + vmovups %ymm7, 32(%rdx) + cmpq $2, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm8, (%rdx) + vmovups %ymm9, 32(%rdx) + cmpq $3, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm10, (%rdx) + vmovups %ymm11, 32(%rdx) + cmpq $4, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm12, (%rdx) + vmovups %ymm13, 32(%rdx) + cmpq $5, %rbp + je WriteEnd + addq %r10, %rdx + vmovups %ymm14, (%rdx) + vmovups %ymm15, 32(%rdx) + addq %r10, %rdx + addq $64, %rdx + + WriteEnd: + cmpq $16, %rbx + jbe LoopColEnd + subq $16, %rbx + jmp LoopCol + + LoopColEnd: + movq -96(%rsp), %rdi + addq %r11, %rdi + movq %rdi, -96(%rsp) + cmpq $0, %r14 + je C8DstStep + cmpq $2, %r14 + je WinoDstStep + movq $4, %rax + movq 16(%rsp), %rbx + imul %rbx, %rax + subq %rax, %rdx + movq %rdx, -80(%rsp) + jmp NoDstStep + C8DstStep: + movq -80(%rsp), %rax + addq $384, %rax + movq %rax, -80(%rsp) + jmp NoDstStep + WinoDstStep: + addq %r13, %rdx + movq %rdx, -80(%rsp) + NoDstStep: + cmpq $6, %rbp + jbe LoopRowEnd + subq $6, %rbp + jmp LoopRow + +LoopRowEnd: + subq $112, %rsp + popq %rdi + popq %rsi + popq %rdx + popq %rdx + popq %rdx + popq %rcx + popq %r8 + popq %r9 + popq %rbp + popq %rbx + popq %r12 + popq %r13 + popq %r14 + popq %r15 + retq +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/avx512/ConvDwFp32RowAVX512.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/avx512/ConvDwFp32RowAVX512.S new file mode 100644 index 0000000000000000000000000000000000000000..2048bce7d8b6fef0357999ee8545e35093421708 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/avx512/ConvDwFp32RowAVX512.S @@ -0,0 +1,499 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_AVX512 +#include "nnacl/assembly_global.h" + +.text +.align 4 + +// void ConvDwAVX512Fp32Row(float *output_ptr, const float *input_tmp, const float *weight_ptr, size_t num_pixels, +// size_t output_channel, size_t input_step); +// in linux x64 platform: +// rdi: output_ptr +// rsi: input_ptr +// rdx: weight_ptr +// rcx: num_pixels +// r8: output_channel +// r9: input_step + +// in win x64 platform: "shadow space" needs to be opened up for first four parameters ==> 32 bites +// rcx: output_ptr +// rdx: input_ptr +// r8: weight_ptr +// r9: num_pixels +// 40: output_channel +// 48: input_step + +asm_function ConvDwAVX512Fp32Row + pushq %r15 + pushq %r14 + pushq %r13 + pushq %r12 + pushq %rsi + pushq %rdi + addq $48, %rsp + +#ifdef _WIN32 + movq %rcx, %rdi // output_ptr + movq %rdx, %rsi // input_ptr + movq %r8, %rdx // weight_ptr + movq %r9, %rcx // num_pixels + movq 40(%rsp), %r8 // output_channel + movq 48(%rsp), %r9 // input_step + movq 56(%rsp), %r11 // first_calc_flag + movq 64(%rsp), %r10 // bias +#else + movq 8(%rsp), %r11 // first_calc_flag + movq 16(%rsp), %r10 // bias +#endif + + movq $4, %r13 + imul %r13, %r9 + movq %r8, %r12 + imul %r13, %r12 + movq %rsi, %r13 // input_ptr + movq %rdx, %r14 // weight_ptr + movq %r8, %r15 // output_channel + + cmpq $1, %r11 + je OutputInitByBias + jmp OutputInitBySelf + +OutputInitByBias: + cmpq $3, %rcx + jae BiasLoopPixelNum4 + cmpq $0, %rcx + ja BiasLoopPixel + je End + + BiasLoopPixelNum4: + movq %r13, %rsi // input_tmp + movq %r14, %rdx // weight_tmp + movq %r15, %r8 // channel_tmp + movq 16(%rsp), %r10 // bias_tmp + + cmpq $16, %r8 + jae BiasLoopC16Num4 + cmpq $8, %r8 + jae BiasLoopC8Num4 + cmpq $0, %r8 + ja BiasLoopCNum4 + jmp BiasLoopCEndNum4 + + BiasLoopC16Num4: + vmovups (%rsi), %zmm0 // input_tmp + vmovups (%rsi, %r9), %zmm1 + vmovups (%rsi, %r9, 2), %zmm2 + // vmovups (%rsi, %r9, 3), %zmm3 + + vmovups (%r10), %zmm8 // output_tmp + vmovups (%r10), %zmm9 // output_tmp + vmovups (%r10), %zmm10 // output_tmp + // vmovups (%r10), %zmm11 // output_tmp + addq $64, %rsi + addq $64, %r10 + + vmovups (%rdx), %zmm15 // weight_tmp + vfmadd231ps %zmm15, %zmm0, %zmm8 + vfmadd231ps %zmm15, %zmm1, %zmm9 + vfmadd231ps %zmm15, %zmm2, %zmm10 + // vfmadd231ps %zmm15, %zmm3, %zmm11 + + addq $64, %rdx + vmovups %zmm8, (%rdi) + vmovups %zmm9, (%rdi, %r12) + vmovups %zmm10, (%rdi, %r12, 2) + // vmovups %zmm11, (%rdi, %r12, 3) + addq $64, %rdi + + subq $16, %r8 + cmpq $16, %r8 + jae BiasLoopC16Num4 + cmpq $8, %r8 + jae BiasLoopC8Num4 + cmpq $0, %r8 + ja BiasLoopCNum4 + jmp BiasLoopCEndNum4 + + BiasLoopC8Num4: + vmovups (%rsi), %ymm0 // input_tmp + vmovups (%rsi, %r9), %ymm1 + vmovups (%rsi, %r9, 2), %ymm2 + // vmovups (%rsi, %r9, 3), %ymm3 + + vmovups (%r10), %ymm8 // output_tmp + vmovups (%r10), %ymm9 // output_tmp + vmovups (%r10), %ymm10 // output_tmp + // vmovups (%r10), %ymm11 // output_tmp + addq $32, %rsi + addq $32, %r10 + + vmovups (%rdx), %ymm15 // weight_tmp + vfmadd231ps %ymm15, %ymm0, %ymm8 + vfmadd231ps %ymm15, %ymm1, %ymm9 + vfmadd231ps %ymm15, %ymm2, %ymm10 + // vfmadd231ps %ymm15, %ymm3, %ymm11 + + addq $32, %rdx + vmovups %ymm8, (%rdi) + vmovups %ymm9, (%rdi, %r12) + vmovups %ymm10, (%rdi, %r12, 2) + // vmovups %ymm11, (%rdi, %r12, 3) + addq $32, %rdi + + subq $8, %r8 + cmpq $8, %r8 + jae BiasLoopC8Num4 + cmpq $0, %r8 + ja BiasLoopCNum4 + jmp BiasLoopCEndNum4 + + BiasLoopCNum4: + vmovss (%rsi), %xmm0 // input_tmp + vmovss (%rsi, %r9), %xmm1 + vmovss (%rsi, %r9, 2), %xmm2 + // vmovss (%rsi, %r9, 3), %xmm3 + + vmovss (%r10), %xmm8 // output_ptr + vmovss (%r10), %xmm9 // output_tmp + vmovss (%r10), %xmm10 // output_tmp + // vmovss (%r10), %xmm11 // output_tmp + addq $4, %r10 + + vmovss (%rdx), %xmm15 // weight_tmp + vfmadd231ss %xmm15, %xmm0, %xmm8 + vfmadd231ss %xmm15, %xmm1, %xmm9 + vfmadd231ss %xmm15, %xmm2, %xmm10 + // vfmadd231ss %xmm15, %xmm3, %xmm11 + + addq $4, %rsi + addq $4, %rdx + vmovss %xmm8, (%rdi) + vmovss %xmm9, (%rdi, %r12) + vmovss %xmm10, (%rdi, %r12, 2) + // vmovss %xmm11, (%rdi, %r12, 3) + addq $4, %rdi + + subq $1, %r8 + cmpq $0, %r8 + ja BiasLoopCNum4 + jmp BiasLoopCEndNum4 + + BiasLoopCEndNum4: + subq $3, %rcx // num_pixel -= 3 + addq %r12, %rdi + addq %r12, %rdi + + addq %r9, %r13 + addq %r9, %r13 + addq %r9, %r13 + cmpq $3, %rcx + jae BiasLoopPixelNum4 + cmpq $0, %rcx + ja BiasLoopPixel + je End + + BiasLoopPixel: + movq %r13, %rsi // input_tmp + movq %r14, %rdx // weight_tmp + movq %r15, %r8 // channel_tmp + movq 16(%rsp), %r10 // bias_tmp + + cmpq $16, %r8 + jae BiasLoopC16 + cmpq $8, %r8 + jae BiasLoopC8 + cmpq $0, %r8 + ja BiasLoopC + jmp BiasLoopCEnd + + BiasLoopC16: + vmovups (%rsi), %zmm0 // input_tmp + vmovups (%r10), %zmm8 // output_tmp + addq $64, %rsi + addq $64, %r10 + + vfmadd231ps (%rdx), %zmm0, %zmm8 + + addq $64, %rdx + vmovups %zmm8, (%rdi) + addq $64, %rdi + + subq $16, %r8 + cmpq $16, %r8 + jae BiasLoopC16 + cmpq $8, %r8 + jae BiasLoopC8 + cmpq $0, %r8 + ja BiasLoopC + jmp BiasLoopCEnd + + BiasLoopC8: + vmovups (%rsi), %ymm0 // input_tmp + vmovups (%r10), %ymm8 // output_tmp + addq $32, %rsi + addq $32, %r10 + + vfmadd231ps (%rdx), %ymm0, %ymm8 + + addq $32, %rdx + vmovups %ymm8, (%rdi) + addq $32, %rdi + + subq $8, %r8 + cmpq $8, %r8 + jae BiasLoopC8 + cmpq $0, %r8 + ja BiasLoopC + jmp BiasLoopCEnd + + BiasLoopC: + vmovss (%rsi), %xmm0 // input_tmp + vmovss (%r10), %xmm8 // output_ptr + addq $4, %r10 + + vfmadd231ss (%rdx), %xmm0, %xmm8 + + addq $4, %rsi + addq $4, %rdx + vmovss %xmm8, (%rdi) + addq $4, %rdi + + subq $1, %r8 + cmpq $0, %r8 + ja BiasLoopC + jmp BiasLoopCEnd + + BiasLoopCEnd: + subq $1, %rcx // num_pixel -= 1 + cmpq $0, %rcx + je End + addq %r9, %r13 + jmp BiasLoopPixel + +OutputInitBySelf: + cmpq $3, %rcx + jae LoopPixelNum4 + cmpq $0, %rcx + ja LoopPixel + je End + + LoopPixelNum4: + movq %r13, %rsi // input_tmp + movq %r14, %rdx // weight_tmp + movq %r15, %r8 // channel_tmp + + cmpq $16, %r8 + jae LoopC16Num4 + cmpq $8, %r8 + jae LoopC8Num4 + cmpq $0, %r8 + ja LoopCNum4 + jmp LoopCEndNum4 + + LoopC16Num4: + vmovups (%rsi), %zmm0 // input_tmp + vmovups (%rsi, %r9), %zmm1 + vmovups (%rsi, %r9, 2), %zmm2 + // vmovups (%rsi, %r9, 3), %zmm3 + + vmovups (%rdi), %zmm8 // output_tmp + vmovups (%rdi, %r12), %zmm9 // output_tmp + vmovups (%rdi, %r12, 2), %zmm10 // output_tmp + // vmovups (%rdi, %r12, 3), %zmm11 // output_tmp + addq $64, %rsi + + vmovups (%rdx), %zmm15 // weight_tmp + vfmadd231ps %zmm15, %zmm0, %zmm8 + vfmadd231ps %zmm15, %zmm1, %zmm9 + vfmadd231ps %zmm15, %zmm2, %zmm10 + // vfmadd231ps %zmm15, %zmm3, %zmm11 + + addq $64, %rdx + vmovups %zmm8, (%rdi) + vmovups %zmm9, (%rdi, %r12) + vmovups %zmm10, (%rdi, %r12, 2) + // vmovups %zmm11, (%rdi, %r12, 3) + addq $64, %rdi + + subq $16, %r8 + cmpq $16, %r8 + jae LoopC16Num4 + cmpq $8, %r8 + jae LoopC8Num4 + cmpq $0, %r8 + ja LoopCNum4 + jmp LoopCEndNum4 + + LoopC8Num4: + vmovups (%rsi), %ymm0 // input_tmp + vmovups (%rsi, %r9), %ymm1 + vmovups (%rsi, %r9, 2), %ymm2 + // vmovups (%rsi, %r9, 3), %ymm3 + + vmovups (%rdi), %ymm8 // output_tmp + vmovups (%rdi, %r12), %ymm9 // output_tmp + vmovups (%rdi, %r12, 2), %ymm10 // output_tmp + // vmovups (%rdi, %r12, 3), %ymm11 // output_tmp + addq $32, %rsi + + vmovups (%rdx), %ymm15 // weight_tmp + vfmadd231ps %ymm15, %ymm0, %ymm8 + vfmadd231ps %ymm15, %ymm1, %ymm9 + vfmadd231ps %ymm15, %ymm2, %ymm10 + // vfmadd231ps %ymm15, %ymm3, %ymm11 + + addq $32, %rdx + vmovups %ymm8, (%rdi) + vmovups %ymm9, (%rdi, %r12) + vmovups %ymm10, (%rdi, %r12, 2) + // vmovups %ymm11, (%rdi, %r12, 3) + addq $32, %rdi + + subq $8, %r8 + cmpq $8, %r8 + jae LoopC8Num4 + cmpq $0, %r8 + ja LoopCNum4 + jmp LoopCEndNum4 + + LoopCNum4: + vmovss (%rsi), %xmm0 // input_tmp + vmovss (%rsi, %r9), %xmm1 + vmovss (%rsi, %r9, 2), %xmm2 + // vmovss (%rsi, %r9, 3), %xmm3 + + vmovss (%rdi), %xmm8 // output_ptr + vmovss (%rdi, %r12), %xmm9 // output_tmp + vmovss (%rdi, %r12, 2), %xmm10 // output_tmp + // vmovss (%rdi, %r12, 3), %xmm11 // output_tmp + + vmovss (%rdx), %xmm15 // weight_tmp + vfmadd231ss %xmm15, %xmm0, %xmm8 + vfmadd231ss %xmm15, %xmm1, %xmm9 + vfmadd231ss %xmm15, %xmm2, %xmm10 + // vfmadd231ss %xmm15, %xmm3, %xmm11 + + addq $4, %rsi + addq $4, %rdx + vmovss %xmm8, (%rdi) + vmovss %xmm9, (%rdi, %r12) + vmovss %xmm10, (%rdi, %r12, 2) + // vmovss %xmm11, (%rdi, %r12, 3) + addq $4, %rdi + + subq $1, %r8 + cmpq $0, %r8 + ja LoopCNum4 + jmp LoopCEndNum4 + + LoopCEndNum4: + subq $3, %rcx // num_pixel -= 3 + addq %r12, %rdi + addq %r12, %rdi + + addq %r9, %r13 + addq %r9, %r13 + addq %r9, %r13 + cmpq $3, %rcx + jae LoopPixelNum4 + cmpq $0, %rcx + ja LoopPixel + je End + + LoopPixel: + movq %r13, %rsi // input_tmp + movq %r14, %rdx // weight_tmp + movq %r15, %r8 // channel_tmp + + cmpq $16, %r8 + jae LoopC16 + cmpq $8, %r8 + jae LoopC8 + cmpq $0, %r8 + ja LoopC + jmp LoopCEnd + + LoopC16: + vmovups (%rsi), %zmm0 // input_tmp + vmovups (%rdi), %zmm8 // output_tmp + addq $64, %rsi + + vfmadd231ps (%rdx), %zmm0, %zmm8 + + addq $64, %rdx + vmovups %zmm8, (%rdi) + addq $64, %rdi + + subq $16, %r8 + cmpq $16, %r8 + jae LoopC16 + cmpq $8, %r8 + jae LoopC8 + cmpq $0, %r8 + ja LoopC + jmp LoopCEnd + + LoopC8: + vmovups (%rsi), %ymm0 // input_tmp + vmovups (%rdi), %ymm8 // output_tmp + addq $32, %rsi + + vfmadd231ps (%rdx), %ymm0, %ymm8 + + addq $32, %rdx + vmovups %ymm8, (%rdi) + addq $32, %rdi + + subq $8, %r8 + cmpq $8, %r8 + jae LoopC8 + cmpq $0, %r8 + ja LoopC + jmp LoopCEnd + + LoopC: + vmovss (%rsi), %xmm0 // input_tmp + vmovss (%rdi), %xmm8 // output_ptr + + vfmadd231ss (%rdx), %xmm0, %xmm8 + + addq $4, %rsi + addq $4, %rdx + vmovss %xmm8, (%rdi) + addq $4, %rdi + + subq $1, %r8 + cmpq $0, %r8 + ja LoopC + jmp LoopCEnd + + LoopCEnd: + subq $1, %rcx // num_pixel -= 1 + cmpq $0, %rcx + je End + addq %r9, %r13 + jmp LoopPixel +End: + subq $48, %rsp + popq %rdi + popq %rsi + popq %r12 + popq %r13 + popq %r14 + popq %r15 + retq +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/CalculateMinMaxFp16Count8.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/CalculateMinMaxFp16Count8.S new file mode 100644 index 0000000000000000000000000000000000000000..5799e5c82cf08296e20b05974660ecbb908fb986 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/CalculateMinMaxFp16Count8.S @@ -0,0 +1,56 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void CalculateMinMaxCount8Fp16(const float16_t *data, int count_8, float16_t *real_min, float16_t *real_max); +// x0: data +// w1: count_8 +// x2: real_min +// x3: real_max + +asm_function CalculateMinMaxCount8Fp16 + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + + mov x4, x0 // reload data + mov w5, w1 // reload count + ld1 {v31.8h}, [x4] + ld1 {v30.8h}, [x4], #16 + subs w5, w5, #8 + ble Write + + Loop: + ld1 {v0.8h}, [x4], #16 + fmin v31.8h, v31.8h, v0.8h + fmax v30.8h, v30.8h, v0.8h + subs w5, w5, #8 + bgt Loop + + Write: + fminv h6, v31.8h + fmaxv h7, v30.8h + + str h6, [x2] + str h7, [x3] + + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/ConvDwFp16Border.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/ConvDwFp16Border.S new file mode 100644 index 0000000000000000000000000000000000000000..3dff798806e98f5d6f87a68e4381c2011978463c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/ConvDwFp16Border.S @@ -0,0 +1,68 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void ConvDwFp16Border(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, +// size_t height, size_t width, size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu, +// size_t relu6) + +// x0: dst, x1: src, x2: weight, x3: bias, x4: height, x5: width, x6: in_kh_step, x7: in_kw_step, +// x8: kernel_w, x9: relu, x10: relu6 +asm_function ConvDwFp16Border + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + ldr x8, [sp] + ldr x9, [sp, #8] + ldr x10, [sp, #16] + + ld1 {v0.8h}, [x3] // bias + movi v1.8h, #0x46, lsl #8 // relu 6 + dup v2.4s, wzr // relu + + mov x13, x1 + mov x14, x2 + LoopH: + mov x15, x13 + mov x16, x14 + mov x17, x5 + LoopW: + ld1 {v3.8h}, [x15], x7 + ld1 {v4.8h}, [x16], #16 + fmla v0.8h, v3.8h, v4.8h + subs x17, x17, #1 + bne LoopW + subs x4, x4, #1 + add x13, x13, x6 + add x14, x14, x8 + bne LoopH + cbnz x10, Relu6 + cbnz x9, Relu + b Write + Relu6: + fmin v0.8h, v0.8h, v1.8h + Relu: + fmax v0.8h, v0.8h, v2.8h + Write: + st1 {v0.8h}, [x0] + + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/ConvDwFp16Center.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/ConvDwFp16Center.S new file mode 100644 index 0000000000000000000000000000000000000000..f70981480ec43d24df1e6548f50db4e1d309279f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/ConvDwFp16Center.S @@ -0,0 +1,312 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void ConvDwFp16Center(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, size_t height, size_t width, +// size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, +// size_t in_kh_step, size_t in_kw_step, size_t relu, size_t relu6); +// x0: dst, x1: src, x2: weight, x3: bias, x4: height, x5: weight, x6: kernel_h, x7: kernel_w, +// x8: out_h_step, x9: block_channel, x10: in_sh_step, x11: in_sw_step, x12: in_kh_step, x13: in_kw_step +// x14: relu, x15: relu6 +asm_function ConvDwFp16Center + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #192 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] + add x9, sp, #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + stp x23, x24, [sp, #160] + stp x25, x26, [sp, #176] + + ldr x8, [sp, #192] + ldr x9, [sp, #200] + ldr x10, [sp, #208] + ldr x11, [sp, #216] + ldr x12, [sp, #224] + ldr x13, [sp, #232] + ldr x14, [sp, #240] + ldr x15, [sp, #248] + + ld1 {v24.8h}, [x3] + movi v26.8h, #0x46, lsl #8 + dup v27.4s, wzr + + LoopH: + mov x23, x1 + mov x24, x5 + mov x3, x0 + cmp x24, #8 + blt LoopW + cmp x24, #16 + blt LoopW8 + + LoopW16: + mov x19, #16 + mul x19, x19, x11 + mov x16, x23 + mov x17, x2 + mov x20, x6 + mov v0.16b, v24.16b + mov v1.16b, v24.16b + mov v2.16b, v24.16b + mov v3.16b, v24.16b + mov v4.16b, v24.16b + mov v5.16b, v24.16b + mov v6.16b, v24.16b + mov v7.16b, v24.16b + mov v8.16b, v24.16b + mov v9.16b, v24.16b + mov v10.16b, v24.16b + mov v11.16b, v24.16b + mov v12.16b, v24.16b + mov v13.16b, v24.16b + mov v14.16b, v24.16b + mov v15.16b, v24.16b + LoopKh16: + mov x25, x7 + mov x21, x16 + LoopKw16: + mov x22, x21 + ld1 {v25.8h}, [x17], #16 + ld1 {v16.8h}, [x22], x11 + ld1 {v17.8h}, [x22], x11 + fmla v0.8h, v16.8h, v25.8h + fmla v1.8h, v17.8h, v25.8h + ld1 {v18.8h}, [x22], x11 + ld1 {v19.8h}, [x22], x11 + fmla v2.8h, v18.8h, v25.8h + fmla v3.8h, v19.8h, v25.8h + ld1 {v20.8h}, [x22], x11 + ld1 {v21.8h}, [x22], x11 + fmla v4.8h, v20.8h, v25.8h + fmla v5.8h, v21.8h, v25.8h + ld1 {v22.8h}, [x22], x11 + ld1 {v23.8h}, [x22], x11 + fmla v6.8h, v22.8h, v25.8h + fmla v7.8h, v23.8h, v25.8h + ld1 {v16.8h}, [x22], x11 + ld1 {v17.8h}, [x22], x11 + fmla v8.8h, v16.8h, v25.8h + fmla v9.8h, v17.8h, v25.8h + ld1 {v18.8h}, [x22], x11 + ld1 {v19.8h}, [x22], x11 + fmla v10.8h, v18.8h, v25.8h + fmla v11.8h, v19.8h, v25.8h + ld1 {v20.8h}, [x22], x11 + ld1 {v21.8h}, [x22], x11 + fmla v12.8h, v20.8h, v25.8h + fmla v13.8h, v21.8h, v25.8h + ld1 {v22.8h}, [x22], x11 + ld1 {v23.8h}, [x22], x11 + fmla v14.8h, v22.8h, v25.8h + fmla v15.8h, v23.8h, v25.8h + subs x25, x25, #1 + add x21, x21, x13 + bne LoopKw16 + add x16, x16, x12 + subs x20, x20, #1 + bne LoopKh16 + cbnz x15, Relu616 + cbnz x14, Relu16 + b Write16 + Relu616: + fmin v0.8h, v0.8h, v26.8h + fmin v1.8h, v1.8h, v26.8h + fmin v2.8h, v2.8h, v26.8h + fmin v3.8h, v3.8h, v26.8h + fmin v4.8h, v4.8h, v26.8h + fmin v5.8h, v5.8h, v26.8h + fmin v6.8h, v6.8h, v26.8h + fmin v7.8h, v7.8h, v26.8h + fmin v8.8h, v8.8h, v26.8h + fmin v9.8h, v9.8h, v26.8h + fmin v10.8h, v10.8h, v26.8h + fmin v11.8h, v11.8h, v26.8h + fmin v12.8h, v12.8h, v26.8h + fmin v13.8h, v13.8h, v26.8h + fmin v14.8h, v14.8h, v26.8h + fmin v15.8h, v15.8h, v26.8h + Relu16: + fmax v0.8h, v0.8h, v27.8h + fmax v1.8h, v1.8h, v27.8h + fmax v2.8h, v2.8h, v27.8h + fmax v3.8h, v3.8h, v27.8h + fmax v4.8h, v4.8h, v27.8h + fmax v5.8h, v5.8h, v27.8h + fmax v6.8h, v6.8h, v27.8h + fmax v7.8h, v7.8h, v27.8h + fmax v8.8h, v8.8h, v27.8h + fmax v9.8h, v9.8h, v27.8h + fmax v10.8h, v10.8h, v27.8h + fmax v11.8h, v11.8h, v27.8h + fmax v12.8h, v12.8h, v27.8h + fmax v13.8h, v13.8h, v27.8h + fmax v14.8h, v14.8h, v27.8h + fmax v15.8h, v15.8h, v27.8h + Write16: + st1 {v0.8h}, [x3], x9 + st1 {v1.8h}, [x3], x9 + st1 {v2.8h}, [x3], x9 + st1 {v3.8h}, [x3], x9 + st1 {v4.8h}, [x3], x9 + st1 {v5.8h}, [x3], x9 + st1 {v6.8h}, [x3], x9 + st1 {v7.8h}, [x3], x9 + st1 {v8.8h}, [x3], x9 + st1 {v9.8h}, [x3], x9 + st1 {v10.8h}, [x3], x9 + st1 {v11.8h}, [x3], x9 + st1 {v12.8h}, [x3], x9 + st1 {v13.8h}, [x3], x9 + st1 {v14.8h}, [x3], x9 + st1 {v15.8h}, [x3], x9 + add x23, x23, x19 + sub x24, x24, #16 + cmp x24, #0 + ble LoopWEnd + cmp x24, #8 + blt LoopW + cmp x24, #16 + bge LoopW16 + LoopW8: + mov x19, #8 + mul x19, x19, x11 + mov x16, x23 + mov x17, x2 + mov x20, x6 + mov v0.16b, v24.16b + mov v1.16b, v24.16b + mov v2.16b, v24.16b + mov v3.16b, v24.16b + mov v4.16b, v24.16b + mov v5.16b, v24.16b + mov v6.16b, v24.16b + mov v7.16b, v24.16b + LoopKh8: + mov x25, x7 + mov x21, x16 + LoopKw8: + mov x22, x21 + ld1 {v25.8h}, [x17], #16 + ld1 {v16.8h}, [x22], x11 + ld1 {v17.8h}, [x22], x11 + fmla v0.8h, v16.8h, v25.8h + fmla v1.8h, v17.8h, v25.8h + ld1 {v18.8h}, [x22], x11 + ld1 {v19.8h}, [x22], x11 + fmla v2.8h, v18.8h, v25.8h + fmla v3.8h, v19.8h, v25.8h + ld1 {v20.8h}, [x22], x11 + ld1 {v21.8h}, [x22], x11 + fmla v4.8h, v20.8h, v25.8h + fmla v5.8h, v21.8h, v25.8h + ld1 {v22.8h}, [x22], x11 + ld1 {v23.8h}, [x22], x11 + fmla v6.8h, v22.8h, v25.8h + fmla v7.8h, v23.8h, v25.8h + subs x25, x25, #1 + add x21, x21, x13 + bne LoopKw8 + add x16, x16, x12 + subs x20, x20, #1 + bne LoopKh8 + cbnz x15, Relu68 + cbnz x14, Relu8 + b Write8 + Relu68: + fmin v0.8h, v0.8h, v26.8h + fmin v1.8h, v1.8h, v26.8h + fmin v2.8h, v2.8h, v26.8h + fmin v3.8h, v3.8h, v26.8h + fmin v4.8h, v4.8h, v26.8h + fmin v5.8h, v5.8h, v26.8h + fmin v6.8h, v6.8h, v26.8h + fmin v7.8h, v7.8h, v26.8h + Relu8: + fmax v0.8h, v0.8h, v27.8h + fmax v1.8h, v1.8h, v27.8h + fmax v2.8h, v2.8h, v27.8h + fmax v3.8h, v3.8h, v27.8h + fmax v4.8h, v4.8h, v27.8h + fmax v5.8h, v5.8h, v27.8h + fmax v6.8h, v6.8h, v27.8h + fmax v7.8h, v7.8h, v27.8h + Write8: + st1 {v0.8h}, [x3], x9 + st1 {v1.8h}, [x3], x9 + st1 {v2.8h}, [x3], x9 + st1 {v3.8h}, [x3], x9 + st1 {v4.8h}, [x3], x9 + st1 {v5.8h}, [x3], x9 + st1 {v6.8h}, [x3], x9 + st1 {v7.8h}, [x3], x9 + add x23, x23, x19 + sub x24, x24, #8 + cmp x24, #0 + ble LoopWEnd + cmp x24, #8 + bge LoopW8 + LoopW: + mov x16, x23 + mov x17, x2 + mov x20, x6 + mov v0.16b, v24.16b + LoopKh: + mov x25, x7 + mov x22, x16 + LoopKw: + ld1 {v16.8h}, [x22], x13 + ld1 {v25.8h}, [x17], #16 + fmla v0.8h, v16.8h, v25.8h + subs x25, x25, #1 + bne LoopKw + add x16, x16, x12 + subs x20, x20, #1 + bne LoopKh + cbnz x15, Relu6 + cbnz x14, Relu + b Write + Relu6: + fmin v0.8h, v0.8h, v26.8h + Relu: + fmax v0.8h, v0.8h, v27.8h + Write: + st1 {v0.8h}, [x3], x9 + add x23, x23, x11 + subs x24, x24, #1 + bne LoopW + LoopWEnd: + add x0, x0, x8 + add x1, x1, x10 + subs x4, x4, #1 + bne LoopH + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/ConvDwFp16Row.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/ConvDwFp16Row.S new file mode 100644 index 0000000000000000000000000000000000000000..e8bc38ee2c9e9524725afe0ff10bfc00d647e9c2 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/ConvDwFp16Row.S @@ -0,0 +1,129 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void ConvDwFp16Row(float16_t* output_ptr, const float16_t* input_ptr,const float16_t* filter_ptr, +// size_t num_pixels, size_t input_channel, size_t input_step) +// x0: output_ptr, x1: input_ptr, x2: filter_ptr, x3: num_pixels, +// x4: input_channel, x5: input_step +// +asm_function ConvDwFp16Row + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters +cmp x3, #0 +beq End + +mov x9, x0 +mov x12, #2 // sizeof(float16_t) +mul x5, x5, x12 + +LoopOutPixel: +mov x6, x1 +mov x7, x2 +mov x8, x4 + +LoopInputDepth32In: + cmp x8, #32 + blt Loop8 + sub x8, x8, #32 + + ld1 {v0.8h, v1.8h}, [x6], #32 + ld1 {v2.8h, v3.8h}, [x7], #32 + ld1 {v16.8h, v17.8h}, [x0], #32 + + cmp x8, #32 + blt LoopInputDepth32Out + LoopInputDepth32: + fmla v16.8h, v0.8h, v2.8h + fmla v17.8h, v1.8h, v3.8h + + st1 {v16.8h, v17.8h}, [x9], #32 + + ld1 {v4.8h, v5.8h}, [x6], #32 + ld1 {v6.8h, v7.8h}, [x7], #32 + ld1 {v18.8h, v19.8h}, [x0], #32 + + fmla v18.8h, v4.8h, v6.8h + fmla v19.8h, v5.8h, v7.8h + + st1 {v18.8h, v19.8h}, [x9], #32 + + ld1 {v0.8h, v1.8h}, [x6], #32 + ld1 {v2.8h, v3.8h}, [x7], #32 + ld1 {v16.8h, v17.8h}, [x0], #32 + + sub x8, x8, #32 + cmp x8, #32 + bge LoopInputDepth32 + + LoopInputDepth32Out: + fmla v16.8h, v0.8h, v2.8h + fmla v17.8h, v1.8h, v3.8h + st1 {v16.8h, v17.8h}, [x9], #32 + + ld1 {v4.8h, v5.8h}, [x6], #32 + ld1 {v6.8h, v7.8h}, [x7], #32 + ld1 {v18.8h, v19.8h}, [x0], #32 + + fmla v18.8h, v4.8h, v6.8h + fmla v19.8h, v5.8h, v7.8h + + st1 {v18.8h, v19.8h}, [x9], #32 + + Loop8: + cmp x8, #8 + blt L0 + + LoopInputDepth8: + ld1 {v0.8h}, [x6], #16 + ld1 {v2.8h}, [x7], #16 + ld1 {v16.8h}, [x0], #16 + fmla v16.8h, v0.8h, v2.8h + st1 {v16.8h}, [x9], #16 + sub x8, x8, #8 + cmp x8, #8 + bge LoopInputDepth8 + + L0: + cmp x8, #0 + beq Loop8LineEnd + + LoopInputDepth0: + ldr h0, [x6], #2 + ldr h1, [x7], #2 + ldr h2, [x0], #2 + fmul h0, h0, h1 + fadd h2, h2, h0 + str h2, [x9], #2 + subs x8, x8, #1 + bne LoopInputDepth0 + + Loop8LineEnd: + +subs x3, x3, #1 +add x1, x1, x5 +bne LoopOutPixel + +End: +ret + +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/DeconvDwFp16Border.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/DeconvDwFp16Border.S new file mode 100644 index 0000000000000000000000000000000000000000..a79410abd34bb712a130128ba4d0937a3feb537d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/DeconvDwFp16Border.S @@ -0,0 +1,51 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void DeconvDwFp16Border(float *dst, const float *src, const float *weight, size_t height, size_t width, +// size_t in_kh_step, size_t in_kw_step, size_t kernel_w) + +// x0: dst, x1: src, x2: weight, x3: height, x4: width, x5: in_kh_step, x6: in_kw_step, x7: kernel_w +asm_function DeconvDwFp16Border + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + ld1 {v1.8h}, [x1] + + mov x13, x0 + mov x14, x2 + LoopH: + mov x15, x13 + mov x16, x14 + mov x17, x4 + LoopW: + ld1 {v0.8h}, [x15] + ld1 {v2.8h}, [x16], #16 + fmla v0.8h, v1.8h, v2.8h + st1 {v0.8h}, [x15], x6 + subs x17, x17, #1 + bne LoopW + subs x3, x3, #1 + add x13, x13, x5 + add x14, x14, x7 + bne LoopH + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/DeconvDwFp16Center.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/DeconvDwFp16Center.S new file mode 100644 index 0000000000000000000000000000000000000000..bb37a913a86f151fa9848a9abce5f0ec8fca9fd8 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/DeconvDwFp16Center.S @@ -0,0 +1,75 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void DeconvDwFp16Center(float16_t *dst, const float16_t *src, const float16_t *weight, size_t height, size_t width, +// size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, +// size_t in_kh_step, size_t in_kw_step); +// x0: dst, x1: src, x2: weight, x3: height, x4: weight, x5: kernel_h, x6: kernel_w, x7: out_h_step +// x8: block_channel, x9: in_sh_step, x10: in_sw_step, x11: in_kh_step, x12: in_kw_step +asm_function DeconvDwFp16Center + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #32 + stp x19, x20, [sp] + stp x21, x22, [sp, #16] + + ldr x8, [sp, #32] + ldr x9, [sp, #40] + ldr x10, [sp, #48] + ldr x11, [sp, #56] + ldr x12, [sp, #64] + + LoopH: + mov x15, x0 + mov x16, x1 + mov x17, x4 + LoopW: + mov x22, x15 + mov x19, x2 + mov x20, x5 + ld1 {v1.8h}, [x16], x8 + LoopKh: + mov x21, x22 + mov x13, x6 + LoopKw: + ld1 {v0.8h}, [x21] + ld1 {v2.8h}, [x19], #16 + fmla v0.8h, v1.8h, v2.8h + st1 {v0.8h}, [x21], x12 + subs x13, x13, #1 + bne LoopKw + add x22, x22, x11 + subs x20, x20, #1 + bne LoopKh + add x15, x15, x10 + subs x17, x17, #1 + bne LoopW + add x0, x0, x9 + add x1, x1, x7 + subs x3, x3, #1 + bne LoopH + + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/DynamicGatherArm64ForFp16.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/DynamicGatherArm64ForFp16.S new file mode 100644 index 0000000000000000000000000000000000000000..71eb2fe6e715be31a7103521ea4772b4a292ff55 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/DynamicGatherArm64ForFp16.S @@ -0,0 +1,54 @@ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" +.text +.align 5 + +// void DynamicGatherArm64ForFp16(const int8_t *src, float16_t *output, int count_16, int zp, float16_t scale); +// x0: src(left matrix ptr) +// x1: output(right matrix ptr) +// w2: count_16 +// w3: zp +// w4: scale + +asm_function DynamicGatherArm64ForFp16 + mov x5, x0 // reload src + mov x6, x1 // reload out + mov w7, w2 // reload count_16 + dup v1.4s, w3 // zp + dup v2.4s, v0.s[0] // scale + + LoopCount: + ld1 {v0.16b}, [x5], #16 + + sxtl v3.8h, v0.8b + sxtl2 v4.8h, v0.16b + + sxtl v16.4s, v3.4h + sxtl2 v17.4s, v3.8h + sxtl v18.4s, v4.4h + sxtl2 v19.4s, v4.8h + + sub v16.4s, v16.4s, v1.4s + scvtf v16.4s,v16.4s + fmul v16.4s, v16.4s, v2.4s + sub v17.4s, v17.4s, v1.4s + scvtf v17.4s,v17.4s + fmul v17.4s, v17.4s, v2.4s + sub v18.4s, v18.4s, v1.4s + scvtf v18.4s,v18.4s + fmul v18.4s, v18.4s, v2.4s + sub v19.4s, v19.4s, v1.4s + scvtf v19.4s,v19.4s + fmul v19.4s, v19.4s, v2.4s + + fcvtn v16.4h, v16.4s + fcvtn v17.4h, v17.4s + fcvtn v18.4h, v18.4s + fcvtn v19.4h, v19.4s + + st1 {v16.4h, v17.4h, v18.4h, v19.4h}, [x6], #32 + subs w7, w7, #16 + bgt LoopCount +ret + +#endif \ No newline at end of file diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/Float16ToFloat32.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/Float16ToFloat32.S new file mode 100644 index 0000000000000000000000000000000000000000..a0cb05cba0348589c956671fe8bf1ab2cc30d0d7 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/Float16ToFloat32.S @@ -0,0 +1,68 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void Float16ToFloat32(const float16_t *input, float *output, int number); +// x0: input, x1: output, x2: number +asm_function Float16ToFloat32 + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + cmp x2, #0 + beq LoopEnd + cmp x2, #64 + blt Loop + Loop64: + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x0], #64 + ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x0], #64 + fcvtl v16.4s, v0.4h + fcvtl2 v17.4s, v0.8h + fcvtl v18.4s, v1.4h + fcvtl2 v19.4s, v1.8h + fcvtl v20.4s, v2.4h + fcvtl2 v21.4s, v2.8h + fcvtl v22.4s, v3.4h + fcvtl2 v23.4s, v3.8h + fcvtl v24.4s, v4.4h + fcvtl2 v25.4s, v4.8h + fcvtl v26.4s, v5.4h + fcvtl2 v27.4s, v5.8h + fcvtl v28.4s, v6.4h + fcvtl2 v29.4s, v6.8h + fcvtl v30.4s, v7.4h + fcvtl2 v31.4s, v7.8h + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x1], #64 + st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x1], #64 + st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x1], #64 + st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x1], #64 + subs x2, x2, #64 + ble LoopEnd + cmp x2, #64 + bge Loop64 + Loop: + ldr h0, [x0], #2 + fcvt s0, h0 + str s0, [x1], #4 + subs x2, x2, #1 + bgt Loop + LoopEnd: + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/Float32ToFloat16.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/Float32ToFloat16.S new file mode 100644 index 0000000000000000000000000000000000000000..4066d33725a9a0b436500421f294c60b4e1c97a5 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/Float32ToFloat16.S @@ -0,0 +1,68 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void Float32ToFloat16(const float *input, float16_t output, int number); +// x0: input, x1: output, x2: number +asm_function Float32ToFloat16 + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + cmp x2, #0 + beq LoopEnd + cmp x2, #64 + blt Loop + Loop64: + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], #64 + ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64 + ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x0], #64 + fcvtn v0.4h, v16.4s + fcvtn2 v0.8h, v17.4s + fcvtn v1.4h, v18.4s + fcvtn2 v1.8h, v19.4s + fcvtn v2.4h, v20.4s + fcvtn2 v2.8h, v21.4s + fcvtn v3.4h, v22.4s + fcvtn2 v3.8h, v23.4s + fcvtn v4.4h, v24.4s + fcvtn2 v4.8h, v25.4s + fcvtn v5.4h, v26.4s + fcvtn2 v5.8h, v27.4s + fcvtn v6.4h, v28.4s + fcvtn2 v6.8h, v29.4s + fcvtn v7.4h, v30.4s + fcvtn2 v7.8h, v31.4s + st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x1], #64 + st1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x1], #64 + subs x2, x2, #64 + ble LoopEnd + cmp x2, #64 + bge Loop64 + Loop: + ldr s0, [x0], #4 + fcvt h0, s0 + str h0, [x1], #2 + subs x2, x2, #1 + bgt Loop + LoopEnd: + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/MatVecMulFp16.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/MatVecMulFp16.S new file mode 100644 index 0000000000000000000000000000000000000000..e09435154552bd944b4bcd57fe45027ebaec3ad4 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/MatVecMulFp16.S @@ -0,0 +1,191 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void MatVecMulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, int depth, int col) +// x0: a +// x1: b +// x2: c +// x3: bias +// w4: act_type +// w5: depth +// w6: col + +asm_function MatVecMulFp16Neon64 + sub sp, sp, #128 + st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp] + add x9, sp, #64 + st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x9] + + mov w14, #2 // sizeof(float16) + mul w8, w14, w5 // rhs depthx1 block stride + mov w14, #4 + mul w13, w8, w14 // rhs depthx4 block stride + +Loop: + mov x15, x0 // reload a ptr + mov x7, x1 // reload b ptr + mov w9, w5 // reload depth + cmp w6, #4 + blt Loop1x1 + +Loop1x4: + dup v5.8h, wzr + dup v6.8h, wzr + dup v7.8h, wzr + dup v8.8h, wzr + dup v9.8h, wzr + dup v10.8h, wzr + dup v11.8h, wzr + dup v12.8h, wzr + dup v13.8h, wzr + + add x10, x7, x8 + add x11, x10, x8 + add x12, x11, x8 + +Depth8_1x4: + cmp w9, #8 + blt Depth1_1x4 + + ld1 {v0.8h}, [x15], #16 + ld1 {v1.8h}, [x7], #16 + ld1 {v2.8h}, [x10], #16 + ld1 {v3.8h}, [x11], #16 + ld1 {v4.8h}, [x12], #16 + + fmla v5.8h, v1.8h, v0.8h + fmla v6.8h, v2.8h, v0.8h + fmla v7.8h, v3.8h, v0.8h + fmla v8.8h, v4.8h, v0.8h + sub w9, w9, #8 + cbz w9, End1x4 + b Depth8_1x4 + +Depth1_1x4: + ld1 {v0.h}[0], [x15], #2 + ld1 {v1.h}[0], [x7], #2 + ld1 {v1.h}[1], [x10], #2 + ld1 {v1.h}[2], [x11], #2 + ld1 {v1.h}[3], [x12], #2 + + fmla v9.8h, v1.8h, v0.h[0] + sub w9, w9, #1 + cbz w9, End1x4 + b Depth1_1x4 + +End1x4: + faddp v10.8h, v5.8h, v6.8h + faddp v11.8h, v7.8h, v8.8h + faddp v12.8h, v10.8h, v11.8h + faddp v13.8h, v12.8h, v12.8h + fadd v13.8h, v13.8h, v9.8h + + cbz x3, Act1x4 + ld1 {v14.4h}, [x3], #8 + fadd v13.8h, v13.8h, v14.8h + +Act1x4: + cmp w4, #3 + beq Relu6_1x4 + cmp w4, #1 + beq Relu1x4 + b Write1x4 + +Relu6_1x4: + movi v14.8h, #0x46, lsl #8 + fmin v13.8h, v13.8h, v14.8h + +Relu1x4: + dup v14.8h, wzr + fmax v13.8h, v13.8h, v14.8h + +Write1x4: + st1 {v13.4h}, [x2], #8 + sub w6, w6, #4 + cbz w6, End + add x1, x1, x13 + b Loop + +Loop1x1: + dup v2.8h, wzr + dup v3.8h, wzr + dup v4.8h, wzr + dup v5.8h, wzr + dup v6.8h, wzr + +Depth8_1x1: + cmp w9, #8 + blt Depth1_1x1 + + ld1 {v0.8h}, [x15], #16 + ld1 {v1.8h}, [x7], #16 + + fmla v2.8h, v1.8h, v0.8h + sub w9, w9, #8 + cbz w9, End1x1 + b Depth8_1x1 + +Depth1_1x1: + ld1 {v0.h}[0], [x15], #2 + ld1 {v1.h}[0], [x7], #2 + + fmla v3.8h, v1.8h, v0.h[0] + sub w9, w9, #1 + cbz w9, End1x1 + b Depth1_1x1 + +End1x1: + faddp v4.8h, v2.8h, v2.8h + faddp v5.8h, v4.8h, v4.8h + faddp v6.8h, v5.8h, v5.8h + fadd v6.8h, v6.8h, v3.8h + + cbz x3, Act1x1 + ld1 {v7.h}[0], [x3], #2 + fadd v6.8h, v6.8h, v7.8h + +Act1x1: + cmp w4, #3 + beq Relu6_1x1 + cmp w4, #1 + beq Relu1x1 + b Write1x1 + +Relu6_1x1: + movi v7.8h, #0x46, lsl #8 + fmin v6.8h, v6.8h, v7.8h + +Relu1x1: + dup v7.8h, wzr + fmax v6.8h, v6.8h, v7.8h + +Write1x1: + st1 {v6.h}[0], [x2], #2 + sub w6, w6, #1 + cbz w6, End + add x1, x1, x8 + b Loop + +End: + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp], #64 + ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [sp], #64 + ret +#endif \ No newline at end of file diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/Matmul12X16Fp16.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/Matmul12X16Fp16.S new file mode 100644 index 0000000000000000000000000000000000000000..010bd43c3927c9a81ac458c6a0b91a967b85c8c7 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/Matmul12X16Fp16.S @@ -0,0 +1,1703 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void MatMul12x16Fp16Opt(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, +// int depth, int row, int col, size_t stride, size_t writeMode) +// x0: a +// x1: b +// x2: c +// x3: bias +// x4: act_type : ActType_No, ActType_Relu, ActType_Sigmod, ActType_Relu6, ActType_Prelu +// x5: depth : Ic +// x6: row : remain_row +// x7: col +// x8: stride : output_stride x8 = x8 * 2 +// x9: writeMode : OutType_C8 = 0, OutType_Nhwc = 1, OutType_TileC8 = 2 + +// x17 : input_stride + +asm_function MatMul12x16Fp16Opt + sub sp, sp, #160 + st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp] + add x9, sp, #64 + st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + + ldr x8, [sp, #160] + ldr x9, [sp, #168] + +.macro CLEAR_OUTPUT_V8_V9 + dup v8.4s, wzr + dup v9.4s, wzr +.endm + +.macro CLEAR_OUTPUT_V8_V11 + dup v8.4s, wzr + dup v9.4s, wzr + dup v10.4s, wzr + dup v11.4s, wzr +.endm + +.macro CLEAR_OUTPUT_V8_V15 + CLEAR_OUTPUT_V8_V11 + dup v12.4s, wzr + dup v13.4s, wzr + dup v14.4s, wzr + dup v15.4s, wzr +.endm + +.macro CLEAR_OUTPUT_V8_V23 + CLEAR_OUTPUT_V8_V15 + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr +.endm + +.macro CLEAR_OUTPUT_V8_V31 + CLEAR_OUTPUT_V8_V23 + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr +.endm + + mov x21, #24 + mul x17, x5, x21 // input_stride : 12 * Ic * sizeof(float16_t) + mov x21, #2 + mul x8, x8, x21 // output_stride + +LoopRowStart: + cmp x6, #1 + ble LoopRow1 + cmp x6, #2 + ble LoopRow2 + cmp x6, #4 + ble LoopRow4 + cmp x6, #8 + ble LoopRow8 + +LoopRow12: + mov x14, x1 // cur_weight + mov x13, x7 // reload_col + mov x12, x3 // reload_bias + + LoopCol12: + mov x11, x2 // cur_output + mov x10, x0 // cur_input + mov x19, x5 // reload_depth + CLEAR_OUTPUT_V8_V31 + cmp x19, #2 + blt LoopDepth12One + + LoopDepth12: + ld1 {v0.8h}, [x10], #16 // cur_input + ld1 {v1.4h}, [x10], #8 + ld1 {v4.8h, v5.8h}, [x14], #32 // cur_weight + fmla v8.8h, v4.8h, v0.h[0] + fmla v9.8h, v5.8h, v0.h[0] + fmla v10.8h, v4.8h, v0.h[1] + fmla v11.8h, v5.8h, v0.h[1] + fmla v12.8h, v4.8h, v0.h[2] + fmla v13.8h, v5.8h, v0.h[2] + fmla v14.8h, v4.8h, v0.h[3] + fmla v15.8h, v5.8h, v0.h[3] + fmla v16.8h, v4.8h, v0.h[4] + fmla v17.8h, v5.8h, v0.h[4] + fmla v18.8h, v4.8h, v0.h[5] + fmla v19.8h, v5.8h, v0.h[5] + ld1 {v2.8h}, [x10], #16 // cur_input + ld1 {v3.4h}, [x10], #8 + ld1 {v6.8h, v7.8h}, [x14], #32 // cur_weight + fmla v20.8h, v4.8h, v0.h[6] + fmla v21.8h, v5.8h, v0.h[6] + fmla v22.8h, v4.8h, v0.h[7] + fmla v23.8h, v5.8h, v0.h[7] + fmla v24.8h, v4.8h, v1.h[0] + fmla v25.8h, v5.8h, v1.h[0] + fmla v26.8h, v4.8h, v1.h[1] + fmla v27.8h, v5.8h, v1.h[1] + fmla v28.8h, v4.8h, v1.h[2] + fmla v29.8h, v5.8h, v1.h[2] + fmla v30.8h, v4.8h, v1.h[3] + fmla v31.8h, v5.8h, v1.h[3] + + fmla v8.8h, v6.8h, v2.h[0] + fmla v9.8h, v7.8h, v2.h[0] + fmla v10.8h, v6.8h, v2.h[1] + fmla v11.8h, v7.8h, v2.h[1] + fmla v12.8h, v6.8h, v2.h[2] + fmla v13.8h, v7.8h, v2.h[2] + fmla v14.8h, v6.8h, v2.h[3] + fmla v15.8h, v7.8h, v2.h[3] + fmla v16.8h, v6.8h, v2.h[4] + fmla v17.8h, v7.8h, v2.h[4] + fmla v18.8h, v6.8h, v2.h[5] + fmla v19.8h, v7.8h, v2.h[5] + fmla v20.8h, v6.8h, v2.h[6] + fmla v21.8h, v7.8h, v2.h[6] + fmla v22.8h, v6.8h, v2.h[7] + fmla v23.8h, v7.8h, v2.h[7] + fmla v24.8h, v6.8h, v3.h[0] + fmla v25.8h, v7.8h, v3.h[0] + fmla v26.8h, v6.8h, v3.h[1] + fmla v27.8h, v7.8h, v3.h[1] + fmla v28.8h, v6.8h, v3.h[2] + fmla v29.8h, v7.8h, v3.h[2] + fmla v30.8h, v6.8h, v3.h[3] + fmla v31.8h, v7.8h, v3.h[3] + subs x19, x19, #2 + beq Bias12 + cmp x19, #2 + bge LoopDepth12 + + LoopDepth12One: + ld1 {v0.4h, v1.4h, v2.4h}, [x10], #24 // cur_input + ld1 {v3.8h, v4.8h}, [x14], #32 // cur_weight + fmla v8.8h, v3.8h, v0.h[0] + fmla v9.8h, v4.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v11.8h, v4.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v13.8h, v4.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v15.8h, v4.8h, v0.h[3] + fmla v16.8h, v3.8h, v1.h[0] + fmla v17.8h, v4.8h, v1.h[0] + fmla v18.8h, v3.8h, v1.h[1] + fmla v19.8h, v4.8h, v1.h[1] + fmla v20.8h, v3.8h, v1.h[2] + fmla v21.8h, v4.8h, v1.h[2] + fmla v22.8h, v3.8h, v1.h[3] + fmla v23.8h, v4.8h, v1.h[3] + fmla v24.8h, v3.8h, v2.h[0] + fmla v25.8h, v4.8h, v2.h[0] + fmla v26.8h, v3.8h, v2.h[1] + fmla v27.8h, v4.8h, v2.h[1] + fmla v28.8h, v3.8h, v2.h[2] + fmla v29.8h, v4.8h, v2.h[2] + fmla v30.8h, v3.8h, v2.h[3] + fmla v31.8h, v4.8h, v2.h[3] + subs x19, x19, #1 + bgt LoopDepth12One + + Bias12: + cbz x3, Activation12 + ld1 {v0.8h, v1.8h}, [x12], #32 + fadd v8.8h, v8.8h, v0.8h + fadd v9.8h, v9.8h, v1.8h + fadd v10.8h, v10.8h, v0.8h + fadd v11.8h, v11.8h, v1.8h + fadd v12.8h, v12.8h, v0.8h + fadd v13.8h, v13.8h, v1.8h + fadd v14.8h, v14.8h, v0.8h + fadd v15.8h, v15.8h, v1.8h + fadd v16.8h, v16.8h, v0.8h + fadd v17.8h, v17.8h, v1.8h + fadd v18.8h, v18.8h, v0.8h + fadd v19.8h, v19.8h, v1.8h + fadd v20.8h, v20.8h, v0.8h + fadd v21.8h, v21.8h, v1.8h + fadd v22.8h, v22.8h, v0.8h + fadd v23.8h, v23.8h, v1.8h + fadd v24.8h, v24.8h, v0.8h + fadd v25.8h, v25.8h, v1.8h + fadd v26.8h, v26.8h, v0.8h + fadd v27.8h, v27.8h, v1.8h + fadd v28.8h, v28.8h, v0.8h + fadd v29.8h, v29.8h, v1.8h + fadd v30.8h, v30.8h, v0.8h + fadd v31.8h, v31.8h, v1.8h + + Activation12: + cmp x4, #3 + beq Relu612 + cmp x4, #1 + beq Relu12 + b Write + + Relu612: + movi v2.8h, #0x46, lsl #8 + fmin v8.8h, v8.8h, v2.8h + fmin v9.8h, v9.8h, v2.8h + fmin v10.8h, v10.8h, v2.8h + fmin v11.8h, v11.8h, v2.8h + fmin v12.8h, v12.8h, v2.8h + fmin v13.8h, v13.8h, v2.8h + fmin v14.8h, v14.8h, v2.8h + fmin v15.8h, v15.8h, v2.8h + fmin v16.8h, v16.8h, v2.8h + fmin v17.8h, v17.8h, v2.8h + fmin v18.8h, v18.8h, v2.8h + fmin v19.8h, v19.8h, v2.8h + fmin v20.8h, v20.8h, v2.8h + fmin v21.8h, v21.8h, v2.8h + fmin v22.8h, v22.8h, v2.8h + fmin v23.8h, v23.8h, v2.8h + fmin v24.8h, v24.8h, v2.8h + fmin v25.8h, v25.8h, v2.8h + fmin v26.8h, v26.8h, v2.8h + fmin v27.8h, v27.8h, v2.8h + fmin v28.8h, v28.8h, v2.8h + fmin v29.8h, v29.8h, v2.8h + fmin v30.8h, v30.8h, v2.8h + fmin v31.8h, v31.8h, v2.8h + + Relu12: + dup v2.8h, wzr + fmax v8.8h, v8.8h, v2.8h + fmax v9.8h, v9.8h, v2.8h + fmax v10.8h, v10.8h, v2.8h + fmax v11.8h, v11.8h, v2.8h + fmax v12.8h, v12.8h, v2.8h + fmax v13.8h, v13.8h, v2.8h + fmax v14.8h, v14.8h, v2.8h + fmax v15.8h, v15.8h, v2.8h + fmax v16.8h, v16.8h, v2.8h + fmax v17.8h, v17.8h, v2.8h + fmax v18.8h, v18.8h, v2.8h + fmax v19.8h, v19.8h, v2.8h + fmax v20.8h, v20.8h, v2.8h + fmax v21.8h, v21.8h, v2.8h + fmax v22.8h, v22.8h, v2.8h + fmax v23.8h, v23.8h, v2.8h + fmax v24.8h, v24.8h, v2.8h + fmax v25.8h, v25.8h, v2.8h + fmax v26.8h, v26.8h, v2.8h + fmax v27.8h, v27.8h, v2.8h + fmax v28.8h, v28.8h, v2.8h + fmax v29.8h, v29.8h, v2.8h + fmax v30.8h, v30.8h, v2.8h + fmax v31.8h, v31.8h, v2.8h + b Write + +LoopRow8: + mov x14, x1 // cur_weight + mov x13, x7 // reload_col + mov x12, x3 // reload_bias + + LoopCol8: + mov x11, x2 // cur_output + mov x10, x0 // cur_input + mov x19, x5 // reload_depth + CLEAR_OUTPUT_V8_V23 + cmp x19, #2 + blt LoopDepth8One + + LoopDepth8: + ld1 {v0.8h}, [x10], #16 // cur_input + ld1 {v1.4h}, [x10], #8 + ld1 {v4.8h, v5.8h}, [x14], #32 // cur_weight + fmla v8.8h, v4.8h, v0.h[0] + fmla v9.8h, v5.8h, v0.h[0] + fmla v10.8h, v4.8h, v0.h[1] + fmla v11.8h, v5.8h, v0.h[1] + fmla v12.8h, v4.8h, v0.h[2] + fmla v13.8h, v5.8h, v0.h[2] + fmla v14.8h, v4.8h, v0.h[3] + fmla v15.8h, v5.8h, v0.h[3] + fmla v16.8h, v4.8h, v0.h[4] + fmla v17.8h, v5.8h, v0.h[4] + fmla v18.8h, v4.8h, v0.h[5] + fmla v19.8h, v5.8h, v0.h[5] + ld1 {v2.8h}, [x10], #16 // cur_input + ld1 {v3.4h}, [x10], #8 + ld1 {v6.8h, v7.8h}, [x14], #32 // cur_weight + fmla v20.8h, v4.8h, v0.h[6] + fmla v21.8h, v5.8h, v0.h[6] + fmla v22.8h, v4.8h, v0.h[7] + fmla v23.8h, v5.8h, v0.h[7] + + fmla v8.8h, v6.8h, v2.h[0] + fmla v9.8h, v7.8h, v2.h[0] + fmla v10.8h, v6.8h, v2.h[1] + fmla v11.8h, v7.8h, v2.h[1] + fmla v12.8h, v6.8h, v2.h[2] + fmla v13.8h, v7.8h, v2.h[2] + fmla v14.8h, v6.8h, v2.h[3] + fmla v15.8h, v7.8h, v2.h[3] + fmla v16.8h, v6.8h, v2.h[4] + fmla v17.8h, v7.8h, v2.h[4] + fmla v18.8h, v6.8h, v2.h[5] + fmla v19.8h, v7.8h, v2.h[5] + fmla v20.8h, v6.8h, v2.h[6] + fmla v21.8h, v7.8h, v2.h[6] + fmla v22.8h, v6.8h, v2.h[7] + fmla v23.8h, v7.8h, v2.h[7] + subs x19, x19, #2 + beq Bias8 + cmp x19, #2 + bge LoopDepth8 + + LoopDepth8One: + ld1 {v0.4h, v1.4h, v2.4h}, [x10], #24 // cur_input + ld1 {v3.8h, v4.8h}, [x14], #32 // cur_weight + fmla v8.8h, v3.8h, v0.h[0] + fmla v9.8h, v4.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v11.8h, v4.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v13.8h, v4.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v15.8h, v4.8h, v0.h[3] + fmla v16.8h, v3.8h, v1.h[0] + fmla v17.8h, v4.8h, v1.h[0] + fmla v18.8h, v3.8h, v1.h[1] + fmla v19.8h, v4.8h, v1.h[1] + fmla v20.8h, v3.8h, v1.h[2] + fmla v21.8h, v4.8h, v1.h[2] + fmla v22.8h, v3.8h, v1.h[3] + fmla v23.8h, v4.8h, v1.h[3] + subs x19, x19, #1 + bgt LoopDepth8One + + Bias8: + cbz x3, Activation8 + ld1 {v0.8h, v1.8h}, [x12], #32 + fadd v8.8h, v8.8h, v0.8h + fadd v9.8h, v9.8h, v1.8h + fadd v10.8h, v10.8h, v0.8h + fadd v11.8h, v11.8h, v1.8h + fadd v12.8h, v12.8h, v0.8h + fadd v13.8h, v13.8h, v1.8h + fadd v14.8h, v14.8h, v0.8h + fadd v15.8h, v15.8h, v1.8h + fadd v16.8h, v16.8h, v0.8h + fadd v17.8h, v17.8h, v1.8h + fadd v18.8h, v18.8h, v0.8h + fadd v19.8h, v19.8h, v1.8h + fadd v20.8h, v20.8h, v0.8h + fadd v21.8h, v21.8h, v1.8h + fadd v22.8h, v22.8h, v0.8h + fadd v23.8h, v23.8h, v1.8h + + Activation8: + cmp x4, #3 + beq Relu68 + cmp x4, #1 + beq Relu8 + b Write + + Relu68: + movi v2.8h, #0x46, lsl #8 + fmin v8.8h, v8.8h, v2.8h + fmin v9.8h, v9.8h, v2.8h + fmin v10.8h, v10.8h, v2.8h + fmin v11.8h, v11.8h, v2.8h + fmin v12.8h, v12.8h, v2.8h + fmin v13.8h, v13.8h, v2.8h + fmin v14.8h, v14.8h, v2.8h + fmin v15.8h, v15.8h, v2.8h + fmin v16.8h, v16.8h, v2.8h + fmin v17.8h, v17.8h, v2.8h + fmin v18.8h, v18.8h, v2.8h + fmin v19.8h, v19.8h, v2.8h + fmin v20.8h, v20.8h, v2.8h + fmin v21.8h, v21.8h, v2.8h + fmin v22.8h, v22.8h, v2.8h + fmin v23.8h, v23.8h, v2.8h + + Relu8: + dup v2.8h, wzr + fmax v8.8h, v8.8h, v2.8h + fmax v9.8h, v9.8h, v2.8h + fmax v10.8h, v10.8h, v2.8h + fmax v11.8h, v11.8h, v2.8h + fmax v12.8h, v12.8h, v2.8h + fmax v13.8h, v13.8h, v2.8h + fmax v14.8h, v14.8h, v2.8h + fmax v15.8h, v15.8h, v2.8h + fmax v16.8h, v16.8h, v2.8h + fmax v17.8h, v17.8h, v2.8h + fmax v18.8h, v18.8h, v2.8h + fmax v19.8h, v19.8h, v2.8h + fmax v20.8h, v20.8h, v2.8h + fmax v21.8h, v21.8h, v2.8h + fmax v22.8h, v22.8h, v2.8h + fmax v23.8h, v23.8h, v2.8h + b Write + +LoopRow4: + mov x14, x1 // cur_weight + mov x13, x7 // reload_col + mov x12, x3 // reload_bias + + LoopCol4: + mov x11, x2 // cur_output + mov x10, x0 // cur_input + mov x19, x5 // reload_depth + CLEAR_OUTPUT_V8_V15 + cmp x19, #2 + blt LoopDepth4One + + LoopDepth4: + ld1 {v0.8h}, [x10], #16 // cur_input + ld1 {v1.4h}, [x10], #8 + ld1 {v4.8h, v5.8h}, [x14], #32 // cur_weight + fmla v8.8h, v4.8h, v0.h[0] + fmla v9.8h, v5.8h, v0.h[0] + fmla v10.8h, v4.8h, v0.h[1] + fmla v11.8h, v5.8h, v0.h[1] + fmla v12.8h, v4.8h, v0.h[2] + fmla v13.8h, v5.8h, v0.h[2] + fmla v14.8h, v4.8h, v0.h[3] + fmla v15.8h, v5.8h, v0.h[3] + ld1 {v2.8h}, [x10], #16 // cur_input + ld1 {v3.4h}, [x10], #8 + ld1 {v6.8h, v7.8h}, [x14], #32 // cur_weight + + fmla v8.8h, v6.8h, v2.h[0] + fmla v9.8h, v7.8h, v2.h[0] + fmla v10.8h, v6.8h, v2.h[1] + fmla v11.8h, v7.8h, v2.h[1] + fmla v12.8h, v6.8h, v2.h[2] + fmla v13.8h, v7.8h, v2.h[2] + fmla v14.8h, v6.8h, v2.h[3] + fmla v15.8h, v7.8h, v2.h[3] + subs x19, x19, #2 + beq Bias4 + cmp x19, #2 + bge LoopDepth4 + + LoopDepth4One: + ld1 {v0.4h, v1.4h, v2.4h}, [x10], #24 // cur_input + ld1 {v3.8h, v4.8h}, [x14], #32 // cur_weight + fmla v8.8h, v3.8h, v0.h[0] + fmla v9.8h, v4.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v11.8h, v4.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v13.8h, v4.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v15.8h, v4.8h, v0.h[3] + subs x19, x19, #1 + bgt LoopDepth4One + + Bias4: + cbz x3, Activation4 + ld1 {v0.8h, v1.8h}, [x12], #32 + fadd v8.8h, v8.8h, v0.8h + fadd v9.8h, v9.8h, v1.8h + fadd v10.8h, v10.8h, v0.8h + fadd v11.8h, v11.8h, v1.8h + fadd v12.8h, v12.8h, v0.8h + fadd v13.8h, v13.8h, v1.8h + fadd v14.8h, v14.8h, v0.8h + fadd v15.8h, v15.8h, v1.8h + + Activation4: + cmp x4, #3 + beq Relu64 + cmp x4, #1 + beq Relu4 + b Write + + Relu64: + movi v2.8h, #0x46, lsl #8 + fmin v8.8h, v8.8h, v2.8h + fmin v9.8h, v9.8h, v2.8h + fmin v10.8h, v10.8h, v2.8h + fmin v11.8h, v11.8h, v2.8h + fmin v12.8h, v12.8h, v2.8h + fmin v13.8h, v13.8h, v2.8h + fmin v14.8h, v14.8h, v2.8h + fmin v15.8h, v15.8h, v2.8h + + Relu4: + dup v2.8h, wzr + fmax v8.8h, v8.8h, v2.8h + fmax v9.8h, v9.8h, v2.8h + fmax v10.8h, v10.8h, v2.8h + fmax v11.8h, v11.8h, v2.8h + fmax v12.8h, v12.8h, v2.8h + fmax v13.8h, v13.8h, v2.8h + fmax v14.8h, v14.8h, v2.8h + fmax v15.8h, v15.8h, v2.8h + b Write + +LoopRow2: + mov x14, x1 // cur_weight + mov x13, x7 // reload_col + mov x12, x3 // reload_bias + + LoopCol2: + mov x11, x2 // cur_output + mov x10, x0 // cur_input + mov x19, x5 // reload_depth + CLEAR_OUTPUT_V8_V11 + cmp x19, #2 + blt LoopDepth2One + + LoopDepth2: + ld1 {v0.8h}, [x10], #16 // cur_input + ld1 {v1.4h}, [x10], #8 + ld1 {v4.8h, v5.8h}, [x14], #32 // cur_weight + fmla v8.8h, v4.8h, v0.h[0] + fmla v9.8h, v5.8h, v0.h[0] + fmla v10.8h, v4.8h, v0.h[1] + fmla v11.8h, v5.8h, v0.h[1] + ld1 {v2.8h}, [x10], #16 // cur_input + ld1 {v3.4h}, [x10], #8 + ld1 {v6.8h, v7.8h}, [x14], #32 // cur_weight + + fmla v8.8h, v6.8h, v2.h[0] + fmla v9.8h, v7.8h, v2.h[0] + fmla v10.8h, v6.8h, v2.h[1] + fmla v11.8h, v7.8h, v2.h[1] + subs x19, x19, #2 + beq Bias2 + cmp x19, #2 + bge LoopDepth2 + + LoopDepth2One: + ld1 {v0.4h, v1.4h, v2.4h}, [x10], #24 // cur_input + ld1 {v3.8h, v4.8h}, [x14], #32 // cur_weight + fmla v8.8h, v3.8h, v0.h[0] + fmla v9.8h, v4.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v11.8h, v4.8h, v0.h[1] + subs x19, x19, #1 + bgt LoopDepth2One + + Bias2: + cbz x3, Activation2 + ld1 {v0.8h, v1.8h}, [x12], #32 + fadd v8.8h, v8.8h, v0.8h + fadd v9.8h, v9.8h, v1.8h + fadd v10.8h, v10.8h, v0.8h + fadd v11.8h, v11.8h, v1.8h + + Activation2: + cmp x4, #3 + beq Relu62 + cmp x4, #1 + beq Relu2 + b Write + + Relu62: + movi v2.8h, #0x46, lsl #8 + fmin v8.8h, v8.8h, v2.8h + fmin v9.8h, v9.8h, v2.8h + fmin v10.8h, v10.8h, v2.8h + fmin v11.8h, v11.8h, v2.8h + + Relu2: + dup v2.8h, wzr + fmax v8.8h, v8.8h, v2.8h + fmax v9.8h, v9.8h, v2.8h + fmax v10.8h, v10.8h, v2.8h + fmax v11.8h, v11.8h, v2.8h + b Write + +LoopRow1: + mov x14, x1 // cur_weight + mov x13, x7 // reload_col + mov x12, x3 // reload_bias + + LoopCol1: + mov x11, x2 // cur_output + mov x10, x0 // cur_input + mov x19, x5 // reload_depth + CLEAR_OUTPUT_V8_V9 + cmp x19, #2 + blt LoopDepth1One + + LoopDepth1: + ld1 {v0.8h}, [x10], #16 // cur_input + ld1 {v1.4h}, [x10], #8 + ld1 {v4.8h, v5.8h}, [x14], #32 // cur_weight + fmla v8.8h, v4.8h, v0.h[0] + fmla v9.8h, v5.8h, v0.h[0] + ld1 {v2.8h}, [x10], #16 // cur_input + ld1 {v3.4h}, [x10], #8 + ld1 {v6.8h, v7.8h}, [x14], #32 // cur_weight + + fmla v8.8h, v6.8h, v2.h[0] + fmla v9.8h, v7.8h, v2.h[0] + subs x19, x19, #2 + beq Bias1 + cmp x19, #2 + bge LoopDepth1 + + LoopDepth1One: + ld1 {v0.4h, v1.4h, v2.4h}, [x10], #24 // cur_input + ld1 {v3.8h, v4.8h}, [x14], #32 // cur_weight + fmla v8.8h, v3.8h, v0.h[0] + fmla v9.8h, v4.8h, v0.h[0] + subs x19, x19, #1 + bgt LoopDepth1One + + Bias1: + cbz x3, Activation1 + ld1 {v0.8h, v1.8h}, [x12], #32 + fadd v8.8h, v8.8h, v0.8h + fadd v9.8h, v9.8h, v1.8h + + Activation1: + cmp x4, #3 + beq Relu61 + cmp x4, #1 + beq Relu1 + b Write + + Relu61: + movi v2.8h, #0x46, lsl #8 + fmin v8.8h, v8.8h, v2.8h + fmin v9.8h, v9.8h, v2.8h + + Relu1: + dup v2.8h, wzr + fmax v8.8h, v8.8h, v2.8h + fmax v9.8h, v9.8h, v2.8h + b Write + + Write: + cmp x13, #1 + beq Write1 + cmp x13, #2 + beq Write2 + cmp x13, #3 + beq Write3 + cmp x13, #4 + beq Write4 + cmp x13, #5 + beq Write5 + cmp x13, #6 + beq Write6 + cmp x13, #7 + beq Write7 + cmp x13, #8 + beq Write8 + cmp x13, #9 + beq Write9 + cmp x13, #10 + beq Write10 + cmp x13, #11 + beq Write11 + cmp x13, #12 + beq Write12 + cmp x13, #13 + beq Write13 + cmp x13, #14 + beq Write14 + cmp x13, #15 + beq Write15 + b Write16 + + Write1: + add x2, x2, #2 + str h8, [x11] + cmp x6, #1 + beq WriteEnd + add x11, x11, x8 + str h10, [x11] + cmp x6, #2 + beq WriteEnd + add x11, x11, x8 + str h12, [x11] + cmp x6, #3 + beq WriteEnd + add x11, x11, x8 + str h14, [x11] + cmp x6, #4 + beq WriteEnd + add x11, x11, x8 + str h16, [x11] + cmp x6, #5 + beq WriteEnd + add x11, x11, x8 + str h18, [x11] + cmp x6, #6 + beq WriteEnd + add x11, x11, x8 + str h20, [x11] + cmp x6, #7 + beq WriteEnd + add x11, x11, x8 + str h22, [x11] + cmp x6, #8 + beq WriteEnd + add x11, x11, x8 + str h24, [x11] + cmp x6, #9 + beq WriteEnd + add x11, x11, x8 + str h26, [x11] + cmp x6, #10 + beq WriteEnd + add x11, x11, x8 + str h28, [x11] + cmp x6, #11 + beq WriteEnd + add x11, x11, x8 + str h30, [x11] + add x11, x11, x8 + add x11, x11, #2 + b WriteEnd + + Write2: + add x2, x2, #4 + add x19, x11, #2 + st1 {v8.h}[0], [x11], x8 + st1 {v8.h}[1], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.h}[0], [x11], x8 + st1 {v10.h}[1], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.h}[0], [x11], x8 + st1 {v12.h}[1], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.h}[0], [x11], x8 + st1 {v14.h}[1], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.h}[0], [x11], x8 + st1 {v16.h}[1], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.h}[0], [x11], x8 + st1 {v18.h}[1], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.h}[0], [x11], x8 + st1 {v20.h}[1], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.h}[0], [x11], x8 + st1 {v22.h}[1], [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.h}[0], [x11], x8 + st1 {v24.h}[1], [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.h}[0], [x11], x8 + st1 {v26.h}[1], [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.h}[0], [x11], x8 + st1 {v28.h}[1], [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.h}[0], [x11], x8 + st1 {v30.h}[1], [x19], x8 + add x11, x11, #4 + b WriteEnd + + Write3: + add x2, x2, #6 + add x19, x11, #4 + add x20, x11, #2 + st1 {v8.h}[0], [x11], x8 + st1 {v8.h}[1], [x20], x8 + st1 {v8.h}[2], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.h}[0], [x11], x8 + st1 {v10.h}[1], [x20], x8 + st1 {v10.h}[2], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.h}[0], [x11], x8 + st1 {v12.h}[1], [x20], x8 + st1 {v12.h}[2], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.h}[0], [x11], x8 + st1 {v14.h}[1], [x20], x8 + st1 {v14.h}[2], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.h}[0], [x11], x8 + st1 {v16.h}[1], [x20], x8 + st1 {v16.h}[2], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.h}[0], [x11], x8 + st1 {v18.h}[1], [x20], x8 + st1 {v18.h}[2], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.h}[0], [x11], x8 + st1 {v20.h}[1], [x20], x8 + st1 {v20.h}[2], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.h}[0], [x11], x8 + st1 {v22.h}[1], [x20], x8 + st1 {v22.h}[2], [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.h}[0], [x11], x8 + st1 {v24.h}[1], [x20], x8 + st1 {v24.h}[2], [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.h}[0], [x11], x8 + st1 {v26.h}[1], [x20], x8 + st1 {v26.h}[2], [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.h}[0], [x11], x8 + st1 {v28.h}[1], [x20], x8 + st1 {v28.h}[2], [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.h}[0], [x11], x8 + st1 {v30.h}[1], [x20], x8 + st1 {v30.h}[2], [x19], x8 + add x11, x11, #6 + b WriteEnd + + Write4: + add x2, x2, #8 + st1 {v8.4h}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4h}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4h}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4h}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4h}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4h}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4h}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4h}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4h}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4h}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4h}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4h}, [x11], x8 + add x11, x11, #8 + b WriteEnd + Write5: + add x2, x2, #10 + add x19, x11, #8 + st1 {v8.4h}, [x11], x8 + st1 {v8.h}[4], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4h}, [x11], x8 + st1 {v10.h}[4], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4h}, [x11], x8 + st1 {v12.h}[4], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4h}, [x11], x8 + st1 {v14.h}[4], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4h}, [x11], x8 + st1 {v16.h}[4], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4h}, [x11], x8 + st1 {v18.h}[4], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4h}, [x11], x8 + st1 {v20.h}[4], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4h}, [x11], x8 + st1 {v22.h}[4], [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4h}, [x11], x8 + st1 {v24.h}[4], [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4h}, [x11], x8 + st1 {v26.h}[4], [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4h}, [x11], x8 + st1 {v28.h}[4], [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4h}, [x11], x8 + st1 {v30.h}[4], [x19], x8 + add x11, x11, #10 + b WriteEnd + Write6: + add x2, x2, #12 + add x19, x11, #8 + add x20, x11, #10 + st1 {v8.4h}, [x11], x8 + st1 {v8.h}[4], [x19], x8 + st1 {v8.h}[5], [x20], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4h}, [x11], x8 + st1 {v10.h}[4], [x19], x8 + st1 {v10.h}[5], [x20], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4h}, [x11], x8 + st1 {v12.h}[4], [x19], x8 + st1 {v12.h}[5], [x20], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4h}, [x11], x8 + st1 {v14.h}[4], [x19], x8 + st1 {v14.h}[5], [x20], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4h}, [x11], x8 + st1 {v16.h}[4], [x19], x8 + st1 {v16.h}[5], [x20], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4h}, [x11], x8 + st1 {v18.h}[4], [x19], x8 + st1 {v18.h}[5], [x20], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4h}, [x11], x8 + st1 {v20.h}[4], [x19], x8 + st1 {v20.h}[5], [x20], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4h}, [x11], x8 + st1 {v22.h}[4], [x19], x8 + st1 {v22.h}[5], [x20], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4h}, [x11], x8 + st1 {v24.h}[4], [x19], x8 + st1 {v24.h}[5], [x20], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4h}, [x11], x8 + st1 {v26.h}[4], [x19], x8 + st1 {v26.h}[5], [x20], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4h}, [x11], x8 + st1 {v28.h}[4], [x19], x8 + st1 {v28.h}[5], [x20], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4h}, [x11], x8 + st1 {v30.h}[4], [x19], x8 + st1 {v30.h}[5], [x20], x8 + add x11, x11, #12 + b WriteEnd + Write7: + add x2, x2, #14 + add x19, x11, #8 + add x20, x11, #10 + add x10, x11, #12 + st1 {v8.4h}, [x11], x8 + st1 {v8.h}[4], [x19], x8 + st1 {v8.h}[5], [x20], x8 + st1 {v8.h}[6], [x10], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.4h}, [x11], x8 + st1 {v10.h}[4], [x19], x8 + st1 {v10.h}[5], [x20], x8 + st1 {v10.h}[6], [x10], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.4h}, [x11], x8 + st1 {v12.h}[4], [x19], x8 + st1 {v12.h}[5], [x20], x8 + st1 {v12.h}[6], [x10], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.4h}, [x11], x8 + st1 {v14.h}[4], [x19], x8 + st1 {v14.h}[5], [x20], x8 + st1 {v14.h}[6], [x10], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.4h}, [x11], x8 + st1 {v16.h}[4], [x19], x8 + st1 {v16.h}[5], [x20], x8 + st1 {v16.h}[6], [x10], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.4h}, [x11], x8 + st1 {v18.h}[4], [x19], x8 + st1 {v18.h}[5], [x20], x8 + st1 {v18.h}[6], [x10], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.4h}, [x11], x8 + st1 {v20.h}[4], [x19], x8 + st1 {v20.h}[5], [x20], x8 + st1 {v20.h}[6], [x10], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.4h}, [x11], x8 + st1 {v22.h}[4], [x19], x8 + st1 {v22.h}[5], [x20], x8 + st1 {v22.h}[6], [x10], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4h}, [x11], x8 + st1 {v24.h}[4], [x19], x8 + st1 {v24.h}[5], [x20], x8 + st1 {v24.h}[6], [x10], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.4h}, [x11], x8 + st1 {v26.h}[4], [x19], x8 + st1 {v26.h}[5], [x20], x8 + st1 {v26.h}[6], [x10], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.4h}, [x11], x8 + st1 {v28.h}[4], [x19], x8 + st1 {v28.h}[5], [x20], x8 + st1 {v28.h}[6], [x10], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.4h}, [x11], x8 + st1 {v30.h}[4], [x19], x8 + st1 {v30.h}[5], [x20], x8 + st1 {v30.h}[6], [x10], x8 + add x11, x11, #14 + b WriteEnd + Write8: + add x2, x2, #16 + st1 {v8.8h}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.8h}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.8h}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.8h}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.8h}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.8h}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.8h}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.8h}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.8h}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.8h}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.8h}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.8h}, [x11], x8 + add x11, x11, #16 + b WriteEnd + Write9: + add x2, x2, #18 + add x19, x11, #16 + st1 {v8.8h}, [x11], x8 + st1 {v9.h}[0], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.8h}, [x11], x8 + st1 {v11.h}[0], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.8h}, [x11], x8 + st1 {v13.h}[0], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.8h}, [x11], x8 + st1 {v15.h}[0], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.8h}, [x11], x8 + st1 {v17.h}[0], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.8h}, [x11], x8 + st1 {v19.h}[0], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.8h}, [x11], x8 + st1 {v21.h}[0], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.8h}, [x11], x8 + st1 {v23.h}[0], [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.8h}, [x11], x8 + st1 {v25.h}[0], [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.8h}, [x11], x8 + st1 {v27.h}[0], [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.8h}, [x11], x8 + st1 {v29.h}[0], [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.8h}, [x11], x8 + st1 {v31.h}[0], [x19], x8 + add x11, x11, #18 + b WriteEnd + Write10: + add x2, x2, #20 + add x19, x11, #16 + add x20, x11, #18 + st1 {v8.8h}, [x11], x8 + st1 {v9.h}[0], [x19], x8 + st1 {v9.h}[1], [x20], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.8h}, [x11], x8 + st1 {v11.h}[0], [x19], x8 + st1 {v11.h}[1], [x20], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.8h}, [x11], x8 + st1 {v13.h}[0], [x19], x8 + st1 {v13.h}[1], [x20], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.8h}, [x11], x8 + st1 {v15.h}[0], [x19], x8 + st1 {v15.h}[1], [x20], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.8h}, [x11], x8 + st1 {v17.h}[0], [x19], x8 + st1 {v17.h}[1], [x20], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.8h}, [x11], x8 + st1 {v19.h}[0], [x19], x8 + st1 {v19.h}[1], [x20], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.8h}, [x11], x8 + st1 {v21.h}[0], [x19], x8 + st1 {v21.h}[1], [x20], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.8h}, [x11], x8 + st1 {v23.h}[0], [x19], x8 + st1 {v23.h}[1], [x20], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.8h}, [x11], x8 + st1 {v25.h}[0], [x19], x8 + st1 {v25.h}[1], [x20], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.8h}, [x11], x8 + st1 {v27.h}[0], [x19], x8 + st1 {v27.h}[1], [x20], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.8h}, [x11], x8 + st1 {v29.h}[0], [x19], x8 + st1 {v29.h}[1], [x20], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.8h}, [x11], x8 + st1 {v31.h}[0], [x19], x8 + st1 {v31.h}[1], [x20], x8 + add x11, x11, #20 + b WriteEnd + Write11: + add x2, x2, #22 + add x19, x11, #16 + add x20, x11, #18 + add x10, x11, #20 + st1 {v8.8h}, [x11], x8 + st1 {v9.h}[0], [x19], x8 + st1 {v9.h}[1], [x20], x8 + st1 {v9.h}[2], [x10], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.8h}, [x11], x8 + st1 {v11.h}[0], [x19], x8 + st1 {v11.h}[1], [x20], x8 + st1 {v11.h}[2], [x10], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.8h}, [x11], x8 + st1 {v13.h}[0], [x19], x8 + st1 {v13.h}[1], [x20], x8 + st1 {v13.h}[2], [x10], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.8h}, [x11], x8 + st1 {v15.h}[0], [x19], x8 + st1 {v15.h}[1], [x20], x8 + st1 {v15.h}[2], [x10], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.8h}, [x11], x8 + st1 {v17.h}[0], [x19], x8 + st1 {v17.h}[1], [x20], x8 + st1 {v17.h}[2], [x10], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.8h}, [x11], x8 + st1 {v19.h}[0], [x19], x8 + st1 {v19.h}[1], [x20], x8 + st1 {v19.h}[2], [x10], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.8h}, [x11], x8 + st1 {v21.h}[0], [x19], x8 + st1 {v21.h}[1], [x20], x8 + st1 {v21.h}[2], [x10], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.8h}, [x11], x8 + st1 {v23.h}[0], [x19], x8 + st1 {v23.h}[1], [x20], x8 + st1 {v23.h}[2], [x10], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.8h}, [x11], x8 + st1 {v25.h}[0], [x19], x8 + st1 {v25.h}[1], [x20], x8 + st1 {v25.h}[2], [x10], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.8h}, [x11], x8 + st1 {v27.h}[0], [x19], x8 + st1 {v27.h}[1], [x20], x8 + st1 {v27.h}[2], [x10], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.8h}, [x11], x8 + st1 {v29.h}[0], [x19], x8 + st1 {v29.h}[1], [x20], x8 + st1 {v29.h}[2], [x10], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.8h}, [x11], x8 + st1 {v31.h}[0], [x19], x8 + st1 {v31.h}[1], [x20], x8 + st1 {v31.h}[2], [x10], x8 + add x11, x11, #22 + b WriteEnd + Write12: + add x2, x2, #24 + add x19, x11, #16 + st1 {v8.8h}, [x11], x8 + st1 {v9.4h}, [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.8h}, [x11], x8 + st1 {v11.4h}, [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.8h}, [x11], x8 + st1 {v13.4h}, [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.8h}, [x11], x8 + st1 {v15.4h}, [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.8h}, [x11], x8 + st1 {v17.4h}, [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.8h}, [x11], x8 + st1 {v19.4h}, [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.8h}, [x11], x8 + st1 {v21.4h}, [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.8h}, [x11], x8 + st1 {v23.4h}, [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.8h}, [x11], x8 + st1 {v25.4h}, [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.8h}, [x11], x8 + st1 {v27.4h}, [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.8h}, [x11], x8 + st1 {v29.4h}, [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.8h}, [x11], x8 + st1 {v31.4h}, [x19], x8 + add x11, x11, #24 + b WriteEnd + Write13: + add x2, x2, #26 + add x19, x11, #16 + add x20, x11, #24 + st1 {v8.8h}, [x11], x8 + st1 {v9.4h}, [x19], x8 + st1 {v9.h}[4], [x20], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.8h}, [x11], x8 + st1 {v11.4h}, [x19], x8 + st1 {v11.h}[4], [x20], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.8h}, [x11], x8 + st1 {v13.4h}, [x19], x8 + st1 {v13.h}[4], [x20], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.8h}, [x11], x8 + st1 {v15.4h}, [x19], x8 + st1 {v15.h}[4], [x20], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.8h}, [x11], x8 + st1 {v17.4h}, [x19], x8 + st1 {v17.h}[4], [x20], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.8h}, [x11], x8 + st1 {v19.4h}, [x19], x8 + st1 {v19.h}[4], [x20], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.8h}, [x11], x8 + st1 {v21.4h}, [x19], x8 + st1 {v21.h}[4], [x20], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.8h}, [x11], x8 + st1 {v23.4h}, [x19], x8 + st1 {v23.h}[4], [x20], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.8h}, [x11], x8 + st1 {v25.4h}, [x19], x8 + st1 {v25.h}[4], [x20], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.8h}, [x11], x8 + st1 {v27.4h}, [x19], x8 + st1 {v27.h}[4], [x20], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.8h}, [x11], x8 + st1 {v29.4h}, [x19], x8 + st1 {v29.h}[4], [x20], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.8h}, [x11], x8 + st1 {v31.4h}, [x19], x8 + st1 {v31.h}[4], [x20], x8 + add x11, x11, #26 + b WriteEnd + Write14: + add x2, x2, #28 + add x19, x11, #16 + add x20, x11, #24 + add x10, x11, #26 + st1 {v8.8h}, [x11], x8 + st1 {v9.4h}, [x19], x8 + st1 {v9.h}[4], [x20], x8 + st1 {v9.h}[5], [x10], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.8h}, [x11], x8 + st1 {v11.4h}, [x19], x8 + st1 {v11.h}[4], [x20], x8 + st1 {v11.h}[5], [x10], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.8h}, [x11], x8 + st1 {v13.4h}, [x19], x8 + st1 {v13.h}[4], [x20], x8 + st1 {v13.h}[5], [x10], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.8h}, [x11], x8 + st1 {v15.4h}, [x19], x8 + st1 {v15.h}[4], [x20], x8 + st1 {v15.h}[5], [x10], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.8h}, [x11], x8 + st1 {v17.4h}, [x19], x8 + st1 {v17.h}[4], [x20], x8 + st1 {v17.h}[5], [x10], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.8h}, [x11], x8 + st1 {v19.4h}, [x19], x8 + st1 {v19.h}[4], [x20], x8 + st1 {v19.h}[5], [x10], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.8h}, [x11], x8 + st1 {v21.4h}, [x19], x8 + st1 {v21.h}[4], [x20], x8 + st1 {v21.h}[5], [x10], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.8h}, [x11], x8 + st1 {v23.4h}, [x19], x8 + st1 {v23.h}[4], [x20], x8 + st1 {v23.h}[5], [x10], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.8h}, [x11], x8 + st1 {v25.4h}, [x19], x8 + st1 {v25.h}[4], [x20], x8 + st1 {v25.h}[5], [x10], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.8h}, [x11], x8 + st1 {v27.4h}, [x19], x8 + st1 {v27.h}[4], [x20], x8 + st1 {v27.h}[5], [x10], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.8h}, [x11], x8 + st1 {v29.4h}, [x19], x8 + st1 {v29.h}[4], [x20], x8 + st1 {v29.h}[5], [x10], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.8h}, [x11], x8 + st1 {v31.4h}, [x19], x8 + st1 {v31.h}[4], [x20], x8 + st1 {v31.h}[5], [x10], x8 + add x11, x11, #28 + b WriteEnd + Write15: + add x2, x2, #30 + add x19, x11, #16 + add x20, x11, #24 + add x10, x11, #26 + add x15, x11, #28 + st1 {v8.8h}, [x11], x8 + st1 {v9.4h}, [x19], x8 + st1 {v9.h}[4], [x20], x8 + st1 {v9.h}[5], [x10], x8 + st1 {v9.h}[6], [x15], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.8h}, [x11], x8 + st1 {v11.4h}, [x19], x8 + st1 {v11.h}[4], [x20], x8 + st1 {v11.h}[5], [x10], x8 + st1 {v11.h}[6], [x15], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.8h}, [x11], x8 + st1 {v13.4h}, [x19], x8 + st1 {v13.h}[4], [x20], x8 + st1 {v13.h}[5], [x10], x8 + st1 {v13.h}[6], [x15], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.8h}, [x11], x8 + st1 {v15.4h}, [x19], x8 + st1 {v15.h}[4], [x20], x8 + st1 {v15.h}[5], [x10], x8 + st1 {v15.h}[6], [x15], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.8h}, [x11], x8 + st1 {v17.4h}, [x19], x8 + st1 {v17.h}[4], [x20], x8 + st1 {v17.h}[5], [x10], x8 + st1 {v17.h}[6], [x15], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.8h}, [x11], x8 + st1 {v19.4h}, [x19], x8 + st1 {v19.h}[4], [x20], x8 + st1 {v19.h}[5], [x10], x8 + st1 {v19.h}[6], [x15], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.8h}, [x11], x8 + st1 {v21.4h}, [x19], x8 + st1 {v21.h}[4], [x20], x8 + st1 {v21.h}[5], [x10], x8 + st1 {v21.h}[6], [x15], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.8h}, [x11], x8 + st1 {v23.4h}, [x19], x8 + st1 {v23.h}[4], [x20], x8 + st1 {v23.h}[5], [x10], x8 + st1 {v23.h}[6], [x15], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.8h}, [x11], x8 + st1 {v25.4h}, [x19], x8 + st1 {v25.h}[4], [x20], x8 + st1 {v25.h}[5], [x10], x8 + st1 {v25.h}[6], [x15], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.8h}, [x11], x8 + st1 {v27.4h}, [x19], x8 + st1 {v27.h}[4], [x20], x8 + st1 {v27.h}[5], [x10], x8 + st1 {v27.h}[6], [x15], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.8h}, [x11], x8 + st1 {v29.4h}, [x19], x8 + st1 {v29.h}[4], [x20], x8 + st1 {v29.h}[5], [x10], x8 + st1 {v29.h}[6], [x15], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.8h}, [x11], x8 + st1 {v31.4h}, [x19], x8 + st1 {v31.h}[4], [x20], x8 + st1 {v31.h}[5], [x10], x8 + st1 {v31.h}[6], [x15], x8 + add x11, x11, #30 + b WriteEnd + Write16: + add x2, x2, #32 + add x19, x11, #16 + st1 {v8.8h}, [x11], x8 + st1 {v9.8h}, [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v10.8h}, [x11], x8 + st1 {v11.8h}, [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v12.8h}, [x11], x8 + st1 {v13.8h}, [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v14.8h}, [x11], x8 + st1 {v15.8h}, [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v16.8h}, [x11], x8 + st1 {v17.8h}, [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v18.8h}, [x11], x8 + st1 {v19.8h}, [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v20.8h}, [x11], x8 + st1 {v21.8h}, [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v22.8h}, [x11], x8 + st1 {v23.8h}, [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.8h}, [x11], x8 + st1 {v25.8h}, [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v26.8h}, [x11], x8 + st1 {v27.8h}, [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v28.8h}, [x11], x8 + st1 {v29.8h}, [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v30.8h}, [x11], x8 + st1 {v31.8h}, [x19], x8 + add x11, x11, #32 + b WriteEnd + + WriteEnd: + subs x13, x13, #16 // col - 16 + ble LoopColEnd + cmp x6, #1 + ble LoopCol1 + cmp x6, #2 + ble LoopCol2 + cmp x6, #4 + ble LoopCol4 + cmp x6, #8 + ble LoopCol8 + b LoopCol12 + +LoopColEnd: + add x0, x0, x17 + mov x21, #2 + mul x21, x21, x7 + sub x11, x11, x21 + mov x2, x11 + subs x6, x6, #12 + bgt LoopRowStart + + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp], #64 + ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/MatmulBaseFp16Neon.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/MatmulBaseFp16Neon.S new file mode 100644 index 0000000000000000000000000000000000000000..31f1adbd024eeaeacdafdb4852326ace5b92c95d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/MatmulBaseFp16Neon.S @@ -0,0 +1,960 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void MatmulBaseFp16Neon(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, +// int depth, int row, int col, size_t stride, size_t writeMode) +// x0: a +// x1: b +// x2: c +// x3: bias +// x4: act_type +// x5: depth +// x6: row +// x7: col +// x8: stride +// x9: writeMode + +asm_function MatmulBaseFp16Neon + sub sp, sp, #160 + st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp] + add x9, sp, #64 + st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + + ldr x8, [sp, #160] + ldr x9, [sp, #168] // act + add x8, x8, x8 // stride * sizeof(float16_t) + + add x16, x7, x7 // col * sizeof(float16_t) + add x17, x5, x5 // depth * zieof(float16_t) + mov x11, x2 + dup v12.8h, wzr + movi v13.8h, #0x46, lsl #8 +LoopRowStart: + cmp x6, #16 + bge LoopRow16 + cmp x6, #8 + bge LoopRow8 + b LoopRow4 + +LoopRow16: + mov x15, #16 + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol16: + mov x11, x2 + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + ld1 {v16.8h}, [x12], #16 + mov v17.16b, v16.16b + mov v18.16b, v16.16b + mov v19.16b, v16.16b + mov v20.16b, v16.16b + mov v21.16b, v16.16b + mov v22.16b, v16.16b + mov v23.16b, v16.16b + mov v24.16b, v16.16b + mov v25.16b, v16.16b + mov v26.16b, v16.16b + mov v27.16b, v16.16b + mov v28.16b, v16.16b + mov v29.16b, v16.16b + mov v30.16b, v16.16b + mov v31.16b, v16.16b + + cmp x19, #4 + blt LoopDepth16One + + LoopDepth16: + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], #64 + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x14], #64 + fmla v16.8h, v8.8h, v0.h[0] + fmla v17.8h, v8.8h, v0.h[1] + fmla v18.8h, v8.8h, v0.h[2] + fmla v19.8h, v8.8h, v0.h[3] + fmla v20.8h, v8.8h, v0.h[4] + fmla v21.8h, v8.8h, v0.h[5] + fmla v22.8h, v8.8h, v0.h[6] + fmla v23.8h, v8.8h, v0.h[7] + fmla v24.8h, v8.8h, v1.h[0] + fmla v25.8h, v8.8h, v1.h[1] + fmla v26.8h, v8.8h, v1.h[2] + fmla v27.8h, v8.8h, v1.h[3] + fmla v28.8h, v8.8h, v1.h[4] + fmla v29.8h, v8.8h, v1.h[5] + fmla v30.8h, v8.8h, v1.h[6] + fmla v31.8h, v8.8h, v1.h[7] + ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x10], #64 + fmla v16.8h, v9.8h, v2.h[0] + fmla v17.8h, v9.8h, v2.h[1] + fmla v18.8h, v9.8h, v2.h[2] + fmla v19.8h, v9.8h, v2.h[3] + fmla v20.8h, v9.8h, v2.h[4] + fmla v21.8h, v9.8h, v2.h[5] + fmla v22.8h, v9.8h, v2.h[6] + fmla v23.8h, v9.8h, v2.h[7] + fmla v24.8h, v9.8h, v3.h[0] + fmla v25.8h, v9.8h, v3.h[1] + fmla v26.8h, v9.8h, v3.h[2] + fmla v27.8h, v9.8h, v3.h[3] + fmla v28.8h, v9.8h, v3.h[4] + fmla v29.8h, v9.8h, v3.h[5] + fmla v30.8h, v9.8h, v3.h[6] + fmla v31.8h, v9.8h, v3.h[7] + fmla v16.8h, v10.8h, v4.h[0] + fmla v17.8h, v10.8h, v4.h[1] + fmla v18.8h, v10.8h, v4.h[2] + fmla v19.8h, v10.8h, v4.h[3] + fmla v20.8h, v10.8h, v4.h[4] + fmla v21.8h, v10.8h, v4.h[5] + fmla v22.8h, v10.8h, v4.h[6] + fmla v23.8h, v10.8h, v4.h[7] + fmla v24.8h, v10.8h, v5.h[0] + fmla v25.8h, v10.8h, v5.h[1] + fmla v26.8h, v10.8h, v5.h[2] + fmla v27.8h, v10.8h, v5.h[3] + fmla v28.8h, v10.8h, v5.h[4] + fmla v29.8h, v10.8h, v5.h[5] + fmla v30.8h, v10.8h, v5.h[6] + fmla v31.8h, v10.8h, v5.h[7] + fmla v16.8h, v11.8h, v6.h[0] + fmla v17.8h, v11.8h, v6.h[1] + fmla v18.8h, v11.8h, v6.h[2] + fmla v19.8h, v11.8h, v6.h[3] + fmla v20.8h, v11.8h, v6.h[4] + fmla v21.8h, v11.8h, v6.h[5] + fmla v22.8h, v11.8h, v6.h[6] + fmla v23.8h, v11.8h, v6.h[7] + fmla v24.8h, v11.8h, v7.h[0] + fmla v25.8h, v11.8h, v7.h[1] + fmla v26.8h, v11.8h, v7.h[2] + fmla v27.8h, v11.8h, v7.h[3] + fmla v28.8h, v11.8h, v7.h[4] + fmla v29.8h, v11.8h, v7.h[5] + fmla v30.8h, v11.8h, v7.h[6] + fmla v31.8h, v11.8h, v7.h[7] + + subs x19, x19, #4 + beq Activation16 + cmp x19, #4 + bge LoopDepth16 + + LoopDepth16One: + ld1 {v0.8h, v1.8h}, [x10], #32 + ld1 {v2.8h}, [x14], #16 + fmla v16.8h, v2.8h, v0.h[0] + fmla v17.8h, v2.8h, v0.h[1] + fmla v18.8h, v2.8h, v0.h[2] + fmla v19.8h, v2.8h, v0.h[3] + fmla v20.8h, v2.8h, v0.h[4] + fmla v21.8h, v2.8h, v0.h[5] + fmla v22.8h, v2.8h, v0.h[6] + fmla v23.8h, v2.8h, v0.h[7] + fmla v24.8h, v2.8h, v1.h[0] + fmla v25.8h, v2.8h, v1.h[1] + fmla v26.8h, v2.8h, v1.h[2] + fmla v27.8h, v2.8h, v1.h[3] + fmla v28.8h, v2.8h, v1.h[4] + fmla v29.8h, v2.8h, v1.h[5] + fmla v30.8h, v2.8h, v1.h[6] + fmla v31.8h, v2.8h, v1.h[7] + subs x19, x19, #1 + bgt LoopDepth16One + + Activation16: + cmp x4, #3 + beq Relu616 + cmp x4, #1 + beq Relu16 + b Write16 + Relu616: + fmin v16.8h, v16.8h, v13.8h + fmin v17.8h, v17.8h, v13.8h + fmin v18.8h, v18.8h, v13.8h + fmin v19.8h, v19.8h, v13.8h + fmin v20.8h, v20.8h, v13.8h + fmin v21.8h, v21.8h, v13.8h + fmin v22.8h, v22.8h, v13.8h + fmin v23.8h, v23.8h, v13.8h + fmin v24.8h, v24.8h, v13.8h + fmin v25.8h, v25.8h, v13.8h + fmin v26.8h, v26.8h, v13.8h + fmin v27.8h, v27.8h, v13.8h + fmin v28.8h, v28.8h, v13.8h + fmin v29.8h, v29.8h, v13.8h + fmin v30.8h, v30.8h, v13.8h + fmin v31.8h, v31.8h, v13.8h + Relu16: + fmax v16.8h, v16.8h, v12.8h + fmax v17.8h, v17.8h, v12.8h + fmax v18.8h, v18.8h, v12.8h + fmax v19.8h, v19.8h, v12.8h + fmax v20.8h, v20.8h, v12.8h + fmax v21.8h, v21.8h, v12.8h + fmax v22.8h, v22.8h, v12.8h + fmax v23.8h, v23.8h, v12.8h + fmax v24.8h, v24.8h, v12.8h + fmax v25.8h, v25.8h, v12.8h + fmax v26.8h, v26.8h, v12.8h + fmax v27.8h, v27.8h, v12.8h + fmax v28.8h, v28.8h, v12.8h + fmax v29.8h, v29.8h, v12.8h + fmax v30.8h, v30.8h, v12.8h + fmax v31.8h, v31.8h, v12.8h + Write16: + cmp x13, #8 + bge Write16x8 + b Write + Write16x8: + add x2, x2, #16 + st1 {v16.8h}, [x11], x8 + st1 {v17.8h}, [x11], x8 + st1 {v18.8h}, [x11], x8 + st1 {v19.8h}, [x11], x8 + st1 {v20.8h}, [x11], x8 + st1 {v21.8h}, [x11], x8 + st1 {v22.8h}, [x11], x8 + st1 {v23.8h}, [x11], x8 + st1 {v24.8h}, [x11], x8 + st1 {v25.8h}, [x11], x8 + st1 {v26.8h}, [x11], x8 + st1 {v27.8h}, [x11], x8 + st1 {v28.8h}, [x11], x8 + st1 {v29.8h}, [x11], x8 + st1 {v30.8h}, [x11], x8 + st1 {v31.8h}, [x11], x8 + b WriteEnd + +LoopRow8: + mov x15, #8 + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol8: + mov x11, x2 + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + ld1 {v16.8h}, [x12], #16 + mov v17.16b, v16.16b + mov v18.16b, v16.16b + mov v19.16b, v16.16b + mov v20.16b, v16.16b + mov v21.16b, v16.16b + mov v22.16b, v16.16b + mov v23.16b, v16.16b + + cmp x19, #4 + blt LoopDepth8One + + LoopDepth8: + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], #64 + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x14], #64 + fmla v16.8h, v8.8h, v0.h[0] + fmla v17.8h, v8.8h, v0.h[1] + fmla v18.8h, v8.8h, v0.h[2] + fmla v19.8h, v8.8h, v0.h[3] + fmla v20.8h, v8.8h, v0.h[4] + fmla v21.8h, v8.8h, v0.h[5] + fmla v22.8h, v8.8h, v0.h[6] + fmla v23.8h, v8.8h, v0.h[7] + fmla v16.8h, v9.8h, v1.h[0] + fmla v17.8h, v9.8h, v1.h[1] + fmla v18.8h, v9.8h, v1.h[2] + fmla v19.8h, v9.8h, v1.h[3] + fmla v20.8h, v9.8h, v1.h[4] + fmla v21.8h, v9.8h, v1.h[5] + fmla v22.8h, v9.8h, v1.h[6] + fmla v23.8h, v9.8h, v1.h[7] + fmla v16.8h, v10.8h, v2.h[0] + fmla v17.8h, v10.8h, v2.h[1] + fmla v18.8h, v10.8h, v2.h[2] + fmla v19.8h, v10.8h, v2.h[3] + fmla v20.8h, v10.8h, v2.h[4] + fmla v21.8h, v10.8h, v2.h[5] + fmla v22.8h, v10.8h, v2.h[6] + fmla v23.8h, v10.8h, v2.h[7] + fmla v16.8h, v11.8h, v3.h[0] + fmla v17.8h, v11.8h, v3.h[1] + fmla v18.8h, v11.8h, v3.h[2] + fmla v19.8h, v11.8h, v3.h[3] + fmla v20.8h, v11.8h, v3.h[4] + fmla v21.8h, v11.8h, v3.h[5] + fmla v22.8h, v11.8h, v3.h[6] + fmla v23.8h, v11.8h, v3.h[7] + subs x19, x19, #4 + beq Activation8 + cmp x19, #4 + bge LoopDepth8 + LoopDepth8One: + ld1 {v0.8h}, [x10], #16 + ld1 {v2.8h}, [x14], #16 + fmla v16.8h, v2.8h, v0.h[0] + fmla v17.8h, v2.8h, v0.h[1] + fmla v18.8h, v2.8h, v0.h[2] + fmla v19.8h, v2.8h, v0.h[3] + fmla v20.8h, v2.8h, v0.h[4] + fmla v21.8h, v2.8h, v0.h[5] + fmla v22.8h, v2.8h, v0.h[6] + fmla v23.8h, v2.8h, v0.h[7] + subs x19, x19, #1 + bgt LoopDepth8One + Activation8: + cmp x4, #3 + beq Relu68 + cmp x4, #1 + beq Relu8 + b Write8_Row + Relu68: + fmin v16.8h, v16.8h, v13.8h + fmin v17.8h, v17.8h, v13.8h + fmin v18.8h, v18.8h, v13.8h + fmin v19.8h, v19.8h, v13.8h + fmin v20.8h, v20.8h, v13.8h + fmin v21.8h, v21.8h, v13.8h + fmin v22.8h, v22.8h, v13.8h + fmin v23.8h, v23.8h, v13.8h + Relu8: + fmax v16.8h, v16.8h, v12.8h + fmax v17.8h, v17.8h, v12.8h + fmax v18.8h, v18.8h, v12.8h + fmax v19.8h, v19.8h, v12.8h + fmax v20.8h, v20.8h, v12.8h + fmax v21.8h, v21.8h, v12.8h + fmax v22.8h, v22.8h, v12.8h + fmax v23.8h, v23.8h, v12.8h + Write8_Row: + cmp x13, #8 // row + bge Write8x8 + b Write + Write8x8: + add x2, x2, #16 + st1 {v16.8h}, [x11], x8 + st1 {v17.8h}, [x11], x8 + st1 {v18.8h}, [x11], x8 + st1 {v19.8h}, [x11], x8 + st1 {v20.8h}, [x11], x8 + st1 {v21.8h}, [x11], x8 + st1 {v22.8h}, [x11], x8 + st1 {v23.8h}, [x11], x8 + b WriteEnd + +LoopRow4: + mov x15, #4 + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol4: + mov x11, x2 + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + ld1 {v16.8h}, [x12], #16 + mov v17.16b, v16.16b + mov v18.16b, v16.16b + mov v19.16b, v16.16b + cmp x19, #4 + blt LoopDepth4One + LoopDepth4: + ld1 {v0.8h, v1.8h}, [x10], #32 + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x14], #64 + fmla v16.8h, v8.8h, v0.h[0] + fmla v17.8h, v8.8h, v0.h[1] + fmla v18.8h, v8.8h, v0.h[2] + fmla v19.8h, v8.8h, v0.h[3] + fmla v16.8h, v9.8h, v0.h[4] + fmla v17.8h, v9.8h, v0.h[5] + fmla v18.8h, v9.8h, v0.h[6] + fmla v19.8h, v9.8h, v0.h[7] + fmla v16.8h, v10.8h, v1.h[0] + fmla v17.8h, v10.8h, v1.h[1] + fmla v18.8h, v10.8h, v1.h[2] + fmla v19.8h, v10.8h, v1.h[3] + fmla v16.8h, v11.8h, v1.h[4] + fmla v17.8h, v11.8h, v1.h[5] + fmla v18.8h, v11.8h, v1.h[6] + fmla v19.8h, v11.8h, v1.h[7] + subs x19, x19, #4 + beq Activation4 + cmp x19, #4 + bge LoopDepth4 + LoopDepth4One: + ld1 {v0.4h}, [x10], #8 + ld1 {v2.8h}, [x14], #16 + fmla v16.8h, v2.8h, v0.h[0] + fmla v17.8h, v2.8h, v0.h[1] + fmla v18.8h, v2.8h, v0.h[2] + fmla v19.8h, v2.8h, v0.h[3] + subs x19, x19, #1 + bgt LoopDepth4One + Activation4: + cmp x4, #3 + beq Relu64 + cmp x4, #1 + beq Relu4 + b Write4_Row + Relu64: + fmin v16.8h, v16.8h, v13.8h + fmin v17.8h, v17.8h, v13.8h + fmin v18.8h, v18.8h, v13.8h + fmin v19.8h, v19.8h, v13.8h + Relu4: + fmax v16.8h, v16.8h, v12.8h + fmax v17.8h, v17.8h, v12.8h + fmax v18.8h, v18.8h, v12.8h + fmax v19.8h, v19.8h, v12.8h + Write4_Row: + cmp x6, #4 + bge Write4x8 + b Write + Write4x8: + cmp x13, #8 + blt Write + add x2, x2, #16 + st1 {v16.8h}, [x11], x8 + st1 {v17.8h}, [x11], x8 + st1 {v18.8h}, [x11], x8 + st1 {v19.8h}, [x11], x8 + b WriteEnd + + Write: + cmp x13, #1 + beq Write1 + cmp x13, #2 + beq Write2 + cmp x13, #3 + beq Write3 + cmp x13, #4 + beq Write4 + cmp x13, #5 + beq Write5 + cmp x13, #6 + beq Write6 + cmp x13, #7 + beq Write7 + b Write8 + + Write1: + add x2, x2, #2 + st1 {v16.h}[0], [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.h}[0], [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.h}[0], [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.h}[0], [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.h}[0], [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.h}[0], [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.h}[0], [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.h}[0], [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.h}[0], [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.h}[0], [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.h}[0], [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.h}[0], [x11], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.h}[0], [x11], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.h}[0], [x11], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.h}[0], [x11], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.h}[0], [x11], x8 + b WriteEnd + Write2: + add x2, x2, #4 + st1 {v16.s}[0], [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.s}[0], [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.s}[0], [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.s}[0], [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.s}[0], [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.s}[0], [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.s}[0], [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.s}[0], [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.s}[0], [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.s}[0], [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.s}[0], [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.s}[0], [x11], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.s}[0], [x11], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.s}[0], [x11], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.s}[0], [x11], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.s}[0], [x11], x8 + b WriteEnd + Write3: + add x2, x2, #6 + add x19, x11, #4 + st1 {v16.s}[0], [x11], x8 + st1 {v16.h}[2], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.s}[0], [x11], x8 + st1 {v17.h}[2], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.s}[0], [x11], x8 + st1 {v18.h}[2], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.s}[0], [x11], x8 + st1 {v19.h}[2], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.s}[0], [x11], x8 + st1 {v20.h}[2], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.s}[0], [x11], x8 + st1 {v21.h}[2], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.s}[0], [x11], x8 + st1 {v22.h}[2], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.s}[0], [x11], x8 + st1 {v23.h}[2], [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.s}[0], [x11], x8 + st1 {v24.h}[2], [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.s}[0], [x11], x8 + st1 {v25.h}[2], [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.s}[0], [x11], x8 + st1 {v26.h}[2], [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.s}[0], [x11], x8 + st1 {v27.h}[2], [x19], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.s}[0], [x11], x8 + st1 {v28.h}[2], [x19], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.s}[0], [x11], x8 + st1 {v29.h}[2], [x19], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.s}[0], [x11], x8 + st1 {v30.h}[2], [x19], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.s}[0], [x11], x8 + st1 {v31.h}[2], [x19] + b WriteEnd + Write4: + add x2, x2, #8 + st1 {v16.4h}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.4h}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.4h}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.4h}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.4h}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.4h}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.4h}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.4h}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4h}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.4h}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.4h}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.4h}, [x11], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.4h}, [x11], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.4h}, [x11], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.4h}, [x11], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.4h}, [x11], x8 + b WriteEnd + Write5: + add x2, x2, #10 + add x19, x11, #8 + st1 {v16.4h}, [x11], x8 + st1 {v16.h}[4], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.4h}, [x11], x8 + st1 {v17.h}[4], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.4h}, [x11], x8 + st1 {v18.h}[4], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.4h}, [x11], x8 + st1 {v19.h}[4], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.4h}, [x11], x8 + st1 {v20.h}[4], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.4h}, [x11], x8 + st1 {v21.h}[4], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.4h}, [x11], x8 + st1 {v22.h}[4], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.4h}, [x11], x8 + st1 {v23.h}[4], [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4h}, [x11], x8 + st1 {v24.h}[4], [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.4h}, [x11], x8 + st1 {v25.h}[4], [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.4h}, [x11], x8 + st1 {v26.h}[4], [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.4h}, [x11], x8 + st1 {v27.h}[4], [x19], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.4h}, [x11], x8 + st1 {v28.h}[4], [x19], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.4h}, [x11], x8 + st1 {v29.h}[4], [x19], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.4h}, [x11], x8 + st1 {v30.h}[4], [x19], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.4h}, [x11], x8 + st1 {v31.h}[4], [x19] + b WriteEnd + Write6: + add x2, x2, #12 + add x19, x11, #8 + st1 {v16.4h}, [x11], x8 + st1 {v16.s}[2], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.4h}, [x11], x8 + st1 {v17.s}[2], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.4h}, [x11], x8 + st1 {v18.s}[2], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.4h}, [x11], x8 + st1 {v19.s}[2], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.4h}, [x11], x8 + st1 {v20.s}[2], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.4h}, [x11], x8 + st1 {v21.s}[2], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.4h}, [x11], x8 + st1 {v22.s}[2], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.4h}, [x11], x8 + st1 {v23.s}[2], [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4h}, [x11], x8 + st1 {v24.s}[2], [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.4h}, [x11], x8 + st1 {v25.s}[2], [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.4h}, [x11], x8 + st1 {v26.s}[2], [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.4h}, [x11], x8 + st1 {v27.s}[2], [x19], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.4h}, [x11], x8 + st1 {v28.s}[2], [x19], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.4h}, [x11], x8 + st1 {v29.s}[2], [x19], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.4h}, [x11], x8 + st1 {v30.s}[2], [x19], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.4h}, [x11], x8 + st1 {v31.s}[2], [x19] + b WriteEnd + Write7: + add x2, x2, #14 + add x19, x11, #8 + add x10, x11, #12 + st1 {v16.4h}, [x11], x8 + st1 {v16.s}[2], [x19], x8 + st1 {v16.h}[6], [x10], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.4h}, [x11], x8 + st1 {v17.s}[2], [x19], x8 + st1 {v17.h}[6], [x10], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.4h}, [x11], x8 + st1 {v18.s}[2], [x19], x8 + st1 {v18.h}[6], [x10], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.4h}, [x11], x8 + st1 {v19.s}[2], [x19], x8 + st1 {v19.h}[6], [x10], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.4h}, [x11], x8 + st1 {v20.s}[2], [x19], x8 + st1 {v20.h}[6], [x10], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.4h}, [x11], x8 + st1 {v21.s}[2], [x19], x8 + st1 {v21.h}[6], [x10], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.4h}, [x11], x8 + st1 {v22.s}[2], [x19], x8 + st1 {v22.h}[6], [x10], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.4h}, [x11], x8 + st1 {v23.s}[2], [x19], x8 + st1 {v23.h}[6], [x10], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4h}, [x11], x8 + st1 {v24.s}[2], [x19], x8 + st1 {v24.h}[6], [x10], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.4h}, [x11], x8 + st1 {v25.s}[2], [x19], x8 + st1 {v25.h}[6], [x10], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.4h}, [x11], x8 + st1 {v26.s}[2], [x19], x8 + st1 {v26.h}[6], [x10], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.4h}, [x11], x8 + st1 {v27.s}[2], [x19], x8 + st1 {v27.h}[6], [x10], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.4h}, [x11], x8 + st1 {v28.s}[2], [x19], x8 + st1 {v28.h}[6], [x10], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.4h}, [x11], x8 + st1 {v29.s}[2], [x19], x8 + st1 {v29.h}[6], [x10], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.4h}, [x11], x8 + st1 {v30.s}[2], [x19], x8 + st1 {v30.h}[6], [x10], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.4h}, [x11], x8 + st1 {v31.s}[2], [x19] + st1 {v31.h}[6], [x10] + b WriteEnd + Write8: + add x2, x2, #16 + st1 {v16.8h}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.8h}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.8h}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.8h}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.8h}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.8h}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.8h}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.8h}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.8h}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.8h}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.8h}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.8h}, [x11], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.8h}, [x11], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.8h}, [x11], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.8h}, [x11], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.8h}, [x11], x8 + + WriteEnd: + subs x13, x13, #8 // rhs col - 8 + ble LoopColEnd + cmp x6, #16 + bge LoopCol16 + cmp x6, #8 + bge LoopCol8 + b LoopCol4 + +LoopColEnd: + sub x2, x2, x16 // dst - col * 2 + mul x21, x8, x15 // row_block * col * 2 + add x2, x2, x21 + subs x6, x6, x15 + mul x15, x15, x17 + add x0, x0, x15 + bgt LoopRowStart + + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp], #64 + ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/MatmulFp16.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/MatmulFp16.S new file mode 100644 index 0000000000000000000000000000000000000000..1d6b69a60171e8efc1a14a803247b0e9a3266794 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/MatmulFp16.S @@ -0,0 +1,892 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void MatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, +// int depth, int row, int col, int stride, bool write_nhwc) +// x0: a +// x1: b +// x2: c +// x3: bias +// w4: act_type +// w5: depth +// w6: row +// w7: col +// w17: stride +// w13: writeC8 + +asm_function MatmulFp16Neon64 + sub sp, sp, #144 + st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp] + add x9, sp, #64 + st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x9] + stp x19, x20, [sp, #128] + + mov w18, #16 // sizeof(float16) * 8 + mul w15, w5, w18 // block stride of lhs/rhs: sizeof(float16) * 8 * depth + mov x11, x3 // bias flag + mov x19, #2 + ldr x17, [sp, #144] + mul x17, x17, x19 + +L1: + mov w10, w6 // reload lhs row + mov x12, x0 // reload lhs ptr + mov x19, x2 // reload dst ptr + +L2: + mov x16, x1 // reload rhs ptr + mov w13, w5 // reload depth + mov x14, x3 // reload bias ptr + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr + + cmp w13, #8 + blt CommLoopMul + +OptLoopMul8: + ld1 {v0.8h, v1.8h}, [x12], #32 + ld1 {v8.8h, v9.8h}, [x16], #32 + fmla v16.8h, v8.8h, v0.h[0] + fmla v17.8h, v8.8h, v0.h[1] + fmla v18.8h, v8.8h, v0.h[2] + fmla v19.8h, v8.8h, v0.h[3] + fmla v20.8h, v8.8h, v0.h[4] + fmla v21.8h, v8.8h, v0.h[5] + fmla v22.8h, v8.8h, v0.h[6] + fmla v23.8h, v8.8h, v0.h[7] + ld1 {v2.8h, v3.8h}, [x12], #32 + fmla v24.8h, v8.8h, v1.h[0] + fmla v25.8h, v8.8h, v1.h[1] + fmla v26.8h, v8.8h, v1.h[2] + fmla v27.8h, v8.8h, v1.h[3] + fmla v28.8h, v8.8h, v1.h[4] + fmla v29.8h, v8.8h, v1.h[5] + fmla v30.8h, v8.8h, v1.h[6] + fmla v31.8h, v8.8h, v1.h[7] + ld1 {v10.8h, v11.8h}, [x16], #32 + fmla v16.8h, v9.8h, v2.h[0] + fmla v17.8h, v9.8h, v2.h[1] + fmla v18.8h, v9.8h, v2.h[2] + fmla v19.8h, v9.8h, v2.h[3] + fmla v20.8h, v9.8h, v2.h[4] + fmla v21.8h, v9.8h, v2.h[5] + fmla v22.8h, v9.8h, v2.h[6] + fmla v23.8h, v9.8h, v2.h[7] + ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x12], #64 + fmla v24.8h, v9.8h, v3.h[0] + fmla v25.8h, v9.8h, v3.h[1] + fmla v26.8h, v9.8h, v3.h[2] + fmla v27.8h, v9.8h, v3.h[3] + fmla v28.8h, v9.8h, v3.h[4] + fmla v29.8h, v9.8h, v3.h[5] + fmla v30.8h, v9.8h, v3.h[6] + fmla v31.8h, v9.8h, v3.h[7] + ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x16], #64 + fmla v16.8h, v10.8h, v4.h[0] + fmla v17.8h, v10.8h, v4.h[1] + fmla v18.8h, v10.8h, v4.h[2] + fmla v19.8h, v10.8h, v4.h[3] + fmla v20.8h, v10.8h, v4.h[4] + fmla v21.8h, v10.8h, v4.h[5] + fmla v22.8h, v10.8h, v4.h[6] + fmla v23.8h, v10.8h, v4.h[7] + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x12], #64 + fmla v24.8h, v10.8h, v5.h[0] + fmla v25.8h, v10.8h, v5.h[1] + fmla v26.8h, v10.8h, v5.h[2] + fmla v27.8h, v10.8h, v5.h[3] + fmla v28.8h, v10.8h, v5.h[4] + fmla v29.8h, v10.8h, v5.h[5] + fmla v30.8h, v10.8h, v5.h[6] + fmla v31.8h, v10.8h, v5.h[7] + ld1 {v4.8h, v5.8h}, [x12], #32 + fmla v16.8h, v11.8h, v6.h[0] + fmla v17.8h, v11.8h, v6.h[1] + fmla v18.8h, v11.8h, v6.h[2] + fmla v19.8h, v11.8h, v6.h[3] + fmla v20.8h, v11.8h, v6.h[4] + fmla v21.8h, v11.8h, v6.h[5] + fmla v22.8h, v11.8h, v6.h[6] + fmla v23.8h, v11.8h, v6.h[7] + fmla v24.8h, v11.8h, v7.h[0] + fmla v25.8h, v11.8h, v7.h[1] + fmla v26.8h, v11.8h, v7.h[2] + fmla v27.8h, v11.8h, v7.h[3] + fmla v28.8h, v11.8h, v7.h[4] + fmla v29.8h, v11.8h, v7.h[5] + fmla v30.8h, v11.8h, v7.h[6] + fmla v31.8h, v11.8h, v7.h[7] + ld1 {v6.8h, v7.8h}, [x12], #32 + fmla v16.8h, v12.8h, v0.h[0] + fmla v17.8h, v12.8h, v0.h[1] + fmla v18.8h, v12.8h, v0.h[2] + fmla v19.8h, v12.8h, v0.h[3] + fmla v20.8h, v12.8h, v0.h[4] + fmla v21.8h, v12.8h, v0.h[5] + fmla v22.8h, v12.8h, v0.h[6] + fmla v23.8h, v12.8h, v0.h[7] + fmla v24.8h, v12.8h, v1.h[0] + fmla v25.8h, v12.8h, v1.h[1] + fmla v26.8h, v12.8h, v1.h[2] + fmla v27.8h, v12.8h, v1.h[3] + fmla v28.8h, v12.8h, v1.h[4] + fmla v29.8h, v12.8h, v1.h[5] + fmla v30.8h, v12.8h, v1.h[6] + fmla v31.8h, v12.8h, v1.h[7] + fmla v16.8h, v13.8h, v2.h[0] + fmla v17.8h, v13.8h, v2.h[1] + fmla v18.8h, v13.8h, v2.h[2] + fmla v19.8h, v13.8h, v2.h[3] + fmla v20.8h, v13.8h, v2.h[4] + fmla v21.8h, v13.8h, v2.h[5] + fmla v22.8h, v13.8h, v2.h[6] + fmla v23.8h, v13.8h, v2.h[7] + fmla v24.8h, v13.8h, v3.h[0] + fmla v25.8h, v13.8h, v3.h[1] + fmla v26.8h, v13.8h, v3.h[2] + fmla v27.8h, v13.8h, v3.h[3] + fmla v28.8h, v13.8h, v3.h[4] + fmla v29.8h, v13.8h, v3.h[5] + fmla v30.8h, v13.8h, v3.h[6] + fmla v31.8h, v13.8h, v3.h[7] + fmla v16.8h, v14.8h, v4.h[0] + fmla v17.8h, v14.8h, v4.h[1] + fmla v18.8h, v14.8h, v4.h[2] + fmla v19.8h, v14.8h, v4.h[3] + fmla v20.8h, v14.8h, v4.h[4] + fmla v21.8h, v14.8h, v4.h[5] + fmla v22.8h, v14.8h, v4.h[6] + fmla v23.8h, v14.8h, v4.h[7] + fmla v24.8h, v14.8h, v5.h[0] + fmla v25.8h, v14.8h, v5.h[1] + fmla v26.8h, v14.8h, v5.h[2] + fmla v27.8h, v14.8h, v5.h[3] + fmla v28.8h, v14.8h, v5.h[4] + fmla v29.8h, v14.8h, v5.h[5] + fmla v30.8h, v14.8h, v5.h[6] + fmla v31.8h, v14.8h, v5.h[7] + fmla v16.8h, v15.8h, v6.h[0] + fmla v17.8h, v15.8h, v6.h[1] + fmla v18.8h, v15.8h, v6.h[2] + fmla v19.8h, v15.8h, v6.h[3] + fmla v20.8h, v15.8h, v6.h[4] + fmla v21.8h, v15.8h, v6.h[5] + fmla v22.8h, v15.8h, v6.h[6] + fmla v23.8h, v15.8h, v6.h[7] + fmla v24.8h, v15.8h, v7.h[0] + fmla v25.8h, v15.8h, v7.h[1] + fmla v26.8h, v15.8h, v7.h[2] + fmla v27.8h, v15.8h, v7.h[3] + fmla v28.8h, v15.8h, v7.h[4] + fmla v29.8h, v15.8h, v7.h[5] + fmla v30.8h, v15.8h, v7.h[6] + fmla v31.8h, v15.8h, v7.h[7] + + sub w13, w13, #8 + cmp w13, #0 + ble Bias + cmp w13, #8 + bge OptLoopMul8 + +CommLoopMul: + ld1 {v0.8h, v1.8h}, [x12], #32 + ld1 {v8.8h}, [x16], #16 + fmla v16.8h, v8.8h, v0.h[0] + fmla v17.8h, v8.8h, v0.h[1] + fmla v18.8h, v8.8h, v0.h[2] + fmla v19.8h, v8.8h, v0.h[3] + fmla v20.8h, v8.8h, v0.h[4] + fmla v21.8h, v8.8h, v0.h[5] + fmla v22.8h, v8.8h, v0.h[6] + fmla v23.8h, v8.8h, v0.h[7] + fmla v24.8h, v8.8h, v1.h[0] + fmla v25.8h, v8.8h, v1.h[1] + fmla v26.8h, v8.8h, v1.h[2] + fmla v27.8h, v8.8h, v1.h[3] + fmla v28.8h, v8.8h, v1.h[4] + fmla v29.8h, v8.8h, v1.h[5] + fmla v30.8h, v8.8h, v1.h[6] + fmla v31.8h, v8.8h, v1.h[7] + + subs w13, w13, #1 + bgt CommLoopMul + +Bias: + cbz x11, Activation + ld1 {v0.8h}, [x14], #16 + fadd v16.8h, v16.8h, v0.8h + fadd v17.8h, v17.8h, v0.8h + fadd v18.8h, v18.8h, v0.8h + fadd v19.8h, v19.8h, v0.8h + fadd v20.8h, v20.8h, v0.8h + fadd v21.8h, v21.8h, v0.8h + fadd v22.8h, v22.8h, v0.8h + fadd v23.8h, v23.8h, v0.8h + fadd v24.8h, v24.8h, v0.8h + fadd v25.8h, v25.8h, v0.8h + fadd v26.8h, v26.8h, v0.8h + fadd v27.8h, v27.8h, v0.8h + fadd v28.8h, v28.8h, v0.8h + fadd v29.8h, v29.8h, v0.8h + fadd v30.8h, v30.8h, v0.8h + fadd v31.8h, v31.8h, v0.8h + +Activation: + cmp w4, #3 + beq Relu6 + cmp w4, #1 + beq Relu + b Write + +Relu6: + movi v15.8h, #0x46, lsl #8 + fmin v16.8h, v16.8h, v15.8h + fmin v17.8h, v17.8h, v15.8h + fmin v18.8h, v18.8h, v15.8h + fmin v19.8h, v19.8h, v15.8h + fmin v20.8h, v20.8h, v15.8h + fmin v21.8h, v21.8h, v15.8h + fmin v22.8h, v22.8h, v15.8h + fmin v23.8h, v23.8h, v15.8h + fmin v24.8h, v24.8h, v15.8h + fmin v25.8h, v25.8h, v15.8h + fmin v26.8h, v26.8h, v15.8h + fmin v27.8h, v27.8h, v15.8h + fmin v28.8h, v28.8h, v15.8h + fmin v29.8h, v29.8h, v15.8h + fmin v30.8h, v30.8h, v15.8h + fmin v31.8h, v31.8h, v15.8h + +Relu: + dup v14.4s, wzr + fmax v16.8h, v16.8h, v14.8h + fmax v17.8h, v17.8h, v14.8h + fmax v18.8h, v18.8h, v14.8h + fmax v19.8h, v19.8h, v14.8h + fmax v20.8h, v20.8h, v14.8h + fmax v21.8h, v21.8h, v14.8h + fmax v22.8h, v22.8h, v14.8h + fmax v23.8h, v23.8h, v14.8h + fmax v24.8h, v24.8h, v14.8h + fmax v25.8h, v25.8h, v14.8h + fmax v26.8h, v26.8h, v14.8h + fmax v27.8h, v27.8h, v14.8h + fmax v28.8h, v28.8h, v14.8h + fmax v29.8h, v29.8h, v14.8h + fmax v30.8h, v30.8h, v14.8h + fmax v31.8h, v31.8h, v14.8h + +Write: + ldrb w13, [sp, #152] + cbz w13, WriteC8 + cmp w7, #1 + beq Write1 + cmp w7, #2 + beq Write2 + cmp w7, #3 + beq Write3 + cmp w7, #4 + beq Write4 + cmp w7, #5 + beq Write5 + cmp w7, #6 + beq Write6 + cmp w7, #7 + beq Write7 + b Write8 + +Write1: + st1 {v16.h}[0], [x19], x17 + cmp w10, #1 + beq WriteEnd + st1 {v17.h}[0], [x19], x17 + cmp w10, #2 + beq WriteEnd + st1 {v18.h}[0], [x19], x17 + cmp w10, #3 + beq WriteEnd + st1 {v19.h}[0], [x19], x17 + cmp w10, #4 + beq WriteEnd + st1 {v20.h}[0], [x19], x17 + cmp w10, #5 + beq WriteEnd + st1 {v21.h}[0], [x19], x17 + cmp w10, #6 + beq WriteEnd + st1 {v22.h}[0], [x19], x17 + cmp w10, #7 + beq WriteEnd + st1 {v23.h}[0], [x19], x17 + cmp w10, #8 + beq WriteEnd + st1 {v24.h}[0], [x19], x17 + cmp w10, #9 + beq WriteEnd + st1 {v25.h}[0], [x19], x17 + cmp w10, #10 + beq WriteEnd + st1 {v26.h}[0], [x19], x17 + cmp w10, #11 + beq WriteEnd + st1 {v27.h}[0], [x19], x17 + cmp w10, #12 + beq WriteEnd + st1 {v28.h}[0], [x19], x17 + cmp w10, #13 + beq WriteEnd + st1 {v29.h}[0], [x19], x17 + cmp w10, #14 + beq WriteEnd + st1 {v30.h}[0], [x19], x17 + cmp w10, #15 + beq WriteEnd + st1 {v31.h}[0], [x19], x17 + b WriteEnd +Write2: + add x13, x19, #2 + st1 {v16.h}[0], [x19], x17 + st1 {v16.h}[1], [x13], x17 + cmp w10, #1 + beq WriteEnd + st1 {v17.h}[0], [x19], x17 + st1 {v17.h}[1], [x13], x17 + cmp w10, #2 + beq WriteEnd + st1 {v18.h}[0], [x19], x17 + st1 {v18.h}[1], [x13], x17 + cmp w10, #3 + beq WriteEnd + st1 {v19.h}[0], [x19], x17 + st1 {v19.h}[1], [x13], x17 + cmp w10, #4 + beq WriteEnd + st1 {v20.h}[0], [x19], x17 + st1 {v20.h}[1], [x13], x17 + cmp w10, #5 + beq WriteEnd + st1 {v21.h}[0], [x19], x17 + st1 {v21.h}[1], [x13], x17 + cmp w10, #6 + beq WriteEnd + st1 {v22.h}[0], [x19], x17 + st1 {v22.h}[1], [x13], x17 + cmp w10, #7 + beq WriteEnd + st1 {v23.h}[0], [x19], x17 + st1 {v23.h}[1], [x13], x17 + cmp w10, #8 + beq WriteEnd + st1 {v24.h}[0], [x19], x17 + st1 {v24.h}[1], [x13], x17 + cmp w10, #9 + beq WriteEnd + st1 {v25.h}[0], [x19], x17 + st1 {v25.h}[1], [x13], x17 + cmp w10, #10 + beq WriteEnd + st1 {v26.h}[0], [x19], x17 + st1 {v26.h}[1], [x13], x17 + cmp w10, #11 + beq WriteEnd + st1 {v27.h}[0], [x19], x17 + st1 {v27.h}[1], [x13], x17 + cmp w10, #12 + beq WriteEnd + st1 {v28.h}[0], [x19], x17 + st1 {v28.h}[1], [x13], x17 + cmp w10, #13 + beq WriteEnd + st1 {v29.h}[0], [x19], x17 + st1 {v29.h}[1], [x13], x17 + cmp w10, #14 + beq WriteEnd + st1 {v30.h}[0], [x19], x17 + st1 {v30.h}[1], [x13], x17 + cmp w10, #15 + beq WriteEnd + st1 {v31.h}[0], [x19], x17 + st1 {v31.h}[1], [x13], x17 + b WriteEnd +Write3: + add x13, x19, #2 + add x14, x19, #4 + st1 {v16.h}[0], [x19], x17 + st1 {v16.h}[1], [x13], x17 + st1 {v16.h}[2], [x14], x17 + cmp w10, #1 + beq WriteEnd + st1 {v17.h}[0], [x19], x17 + st1 {v17.h}[1], [x13], x17 + st1 {v17.h}[2], [x14], x17 + cmp w10, #2 + beq WriteEnd + st1 {v18.h}[0], [x19], x17 + st1 {v18.h}[1], [x13], x17 + st1 {v18.h}[2], [x14], x17 + cmp w10, #3 + beq WriteEnd + st1 {v19.h}[0], [x19], x17 + st1 {v19.h}[1], [x13], x17 + st1 {v19.h}[2], [x14], x17 + cmp w10, #4 + beq WriteEnd + st1 {v20.h}[0], [x19], x17 + st1 {v20.h}[1], [x13], x17 + st1 {v20.h}[2], [x14], x17 + cmp w10, #5 + beq WriteEnd + st1 {v21.h}[0], [x19], x17 + st1 {v21.h}[1], [x13], x17 + st1 {v21.h}[2], [x14], x17 + cmp w10, #6 + beq WriteEnd + st1 {v22.h}[0], [x19], x17 + st1 {v22.h}[1], [x13], x17 + st1 {v22.h}[2], [x14], x17 + cmp w10, #7 + beq WriteEnd + st1 {v23.h}[0], [x19], x17 + st1 {v23.h}[1], [x13], x17 + st1 {v23.h}[2], [x14], x17 + cmp w10, #8 + beq WriteEnd + st1 {v24.h}[0], [x19], x17 + st1 {v24.h}[1], [x13], x17 + st1 {v24.h}[2], [x14], x17 + cmp w10, #9 + beq WriteEnd + st1 {v25.h}[0], [x19], x17 + st1 {v25.h}[1], [x13], x17 + st1 {v25.h}[2], [x14], x17 + cmp w10, #10 + beq WriteEnd + st1 {v26.h}[0], [x19], x17 + st1 {v26.h}[1], [x13], x17 + st1 {v26.h}[2], [x14], x17 + cmp w10, #11 + beq WriteEnd + st1 {v27.h}[0], [x19], x17 + st1 {v27.h}[1], [x13], x17 + st1 {v27.h}[2], [x14], x17 + cmp w10, #12 + beq WriteEnd + st1 {v28.h}[0], [x19], x17 + st1 {v28.h}[1], [x13], x17 + st1 {v28.h}[2], [x14], x17 + cmp w10, #13 + beq WriteEnd + st1 {v29.h}[0], [x19], x17 + st1 {v29.h}[1], [x13], x17 + st1 {v29.h}[2], [x14], x17 + cmp w10, #14 + beq WriteEnd + st1 {v30.h}[0], [x19], x17 + st1 {v30.h}[1], [x13], x17 + st1 {v30.h}[2], [x14], x17 + cmp w10, #15 + beq WriteEnd + st1 {v31.h}[0], [x19], x17 + st1 {v31.h}[1], [x13], x17 + st1 {v31.h}[2], [x14], x17 + b WriteEnd +Write4: + st1 {v16.4h}, [x19], x17 + cmp w10, #1 + beq WriteEnd + st1 {v17.4h}, [x19], x17 + cmp w10, #2 + beq WriteEnd + st1 {v18.4h}, [x19], x17 + cmp w10, #3 + beq WriteEnd + st1 {v19.4h}, [x19], x17 + cmp w10, #4 + beq WriteEnd + st1 {v20.4h}, [x19], x17 + cmp w10, #5 + beq WriteEnd + st1 {v21.4h}, [x19], x17 + cmp w10, #6 + beq WriteEnd + st1 {v22.4h}, [x19], x17 + cmp w10, #7 + beq WriteEnd + st1 {v23.4h}, [x19], x17 + cmp w10, #8 + beq WriteEnd + st1 {v24.4h}, [x19], x17 + cmp w10, #9 + beq WriteEnd + st1 {v25.4h}, [x19], x17 + cmp w10, #10 + beq WriteEnd + st1 {v26.4h}, [x19], x17 + cmp w10, #11 + beq WriteEnd + st1 {v27.4h}, [x19], x17 + cmp w10, #12 + beq WriteEnd + st1 {v28.4h}, [x19], x17 + cmp w10, #13 + beq WriteEnd + st1 {v29.4h}, [x19], x17 + cmp w10, #14 + beq WriteEnd + st1 {v30.4h}, [x19], x17 + cmp w10, #15 + beq WriteEnd + st1 {v31.4h}, [x19], x17 + b WriteEnd +Write5: + add x13, x19, #8 + st1 {v16.4h}, [x19], x17 + st1 {v16.h}[4], [x13], x17 + cmp w10, #1 + beq WriteEnd + st1 {v17.4h}, [x19], x17 + st1 {v17.h}[4], [x13], x17 + cmp w10, #2 + beq WriteEnd + st1 {v18.4h}, [x19], x17 + st1 {v18.h}[4], [x13], x17 + cmp w10, #3 + beq WriteEnd + st1 {v19.4h}, [x19], x17 + st1 {v19.h}[4], [x13], x17 + cmp w10, #4 + beq WriteEnd + st1 {v20.4h}, [x19], x17 + st1 {v20.h}[4], [x13], x17 + cmp w10, #5 + beq WriteEnd + st1 {v21.4h}, [x19], x17 + st1 {v21.h}[4], [x13], x17 + cmp w10, #6 + beq WriteEnd + st1 {v22.4h}, [x19], x17 + st1 {v22.h}[4], [x13], x17 + cmp w10, #7 + beq WriteEnd + st1 {v23.4h}, [x19], x17 + st1 {v23.h}[4], [x13], x17 + cmp w10, #8 + beq WriteEnd + st1 {v24.4h}, [x19], x17 + st1 {v24.h}[4], [x13], x17 + cmp w10, #9 + beq WriteEnd + st1 {v25.4h}, [x19], x17 + st1 {v25.h}[4], [x13], x17 + cmp w10, #10 + beq WriteEnd + st1 {v26.4h}, [x19], x17 + st1 {v26.h}[4], [x13], x17 + cmp w10, #11 + beq WriteEnd + st1 {v27.4h}, [x19], x17 + st1 {v27.h}[4], [x13], x17 + cmp w10, #12 + beq WriteEnd + st1 {v28.4h}, [x19], x17 + st1 {v28.h}[4], [x13], x17 + cmp w10, #13 + beq WriteEnd + st1 {v29.4h}, [x19], x17 + st1 {v29.h}[4], [x13], x17 + cmp w10, #14 + beq WriteEnd + st1 {v30.4h}, [x19], x17 + st1 {v30.h}[4], [x13], x17 + cmp w10, #15 + beq WriteEnd + st1 {v31.4h}, [x19], x17 + st1 {v31.h}[4], [x13], x17 + b WriteEnd +Write6: + add x13, x19, #8 + add x14, x19, #10 + st1 {v16.4h}, [x19], x17 + st1 {v16.h}[4], [x13], x17 + st1 {v16.h}[5], [x14], x17 + cmp w10, #1 + beq WriteEnd + st1 {v17.4h}, [x19], x17 + st1 {v17.h}[4], [x13], x17 + st1 {v17.h}[5], [x14], x17 + cmp w10, #2 + beq WriteEnd + st1 {v18.4h}, [x19], x17 + st1 {v18.h}[4], [x13], x17 + st1 {v18.h}[5], [x14], x17 + cmp w10, #3 + beq WriteEnd + st1 {v19.4h}, [x19], x17 + st1 {v19.h}[4], [x13], x17 + st1 {v19.h}[5], [x14], x17 + cmp w10, #4 + beq WriteEnd + st1 {v20.4h}, [x19], x17 + st1 {v20.h}[4], [x13], x17 + st1 {v20.h}[5], [x14], x17 + cmp w10, #5 + beq WriteEnd + st1 {v21.4h}, [x19], x17 + st1 {v21.h}[4], [x13], x17 + st1 {v21.h}[5], [x14], x17 + cmp w10, #6 + beq WriteEnd + st1 {v22.4h}, [x19], x17 + st1 {v22.h}[4], [x13], x17 + st1 {v22.h}[5], [x14], x17 + cmp w10, #7 + beq WriteEnd + st1 {v23.4h}, [x19], x17 + st1 {v23.h}[4], [x13], x17 + st1 {v23.h}[5], [x14], x17 + cmp w10, #8 + beq WriteEnd + st1 {v24.4h}, [x19], x17 + st1 {v24.h}[4], [x13], x17 + st1 {v24.h}[5], [x14], x17 + cmp w10, #9 + beq WriteEnd + st1 {v25.4h}, [x19], x17 + st1 {v25.h}[4], [x13], x17 + st1 {v25.h}[5], [x14], x17 + cmp w10, #10 + beq WriteEnd + st1 {v26.4h}, [x19], x17 + st1 {v26.h}[4], [x13], x17 + st1 {v26.h}[5], [x14], x17 + cmp w10, #11 + beq WriteEnd + st1 {v27.4h}, [x19], x17 + st1 {v27.h}[4], [x13], x17 + st1 {v27.h}[5], [x14], x17 + cmp w10, #12 + beq WriteEnd + st1 {v28.4h}, [x19], x17 + st1 {v28.h}[4], [x13], x17 + st1 {v28.h}[5], [x14], x17 + cmp w10, #13 + beq WriteEnd + st1 {v29.4h}, [x19], x17 + st1 {v29.h}[4], [x13], x17 + st1 {v29.h}[5], [x14], x17 + cmp w10, #14 + beq WriteEnd + st1 {v30.4h}, [x19], x17 + st1 {v30.h}[4], [x13], x17 + st1 {v30.h}[5], [x14], x17 + cmp w10, #15 + beq WriteEnd + st1 {v31.4h}, [x19], x17 + st1 {v31.h}[4], [x13], x17 + st1 {v31.h}[5], [x14], x17 + b WriteEnd +Write7: + add x13, x19, #8 + add x14, x19, #10 + add x16, x19, #12 + st1 {v16.4h}, [x19], x17 + st1 {v16.h}[4], [x13], x17 + st1 {v16.h}[5], [x14], x17 + st1 {v16.h}[6], [x16], x17 + cmp w10, #1 + beq WriteEnd + st1 {v17.4h}, [x19], x17 + st1 {v17.h}[4], [x13], x17 + st1 {v17.h}[5], [x14], x17 + st1 {v17.h}[6], [x16], x17 + cmp w10, #2 + beq WriteEnd + st1 {v18.4h}, [x19], x17 + st1 {v18.h}[4], [x13], x17 + st1 {v18.h}[5], [x14], x17 + st1 {v18.h}[6], [x16], x17 + cmp w10, #3 + beq WriteEnd + st1 {v19.4h}, [x19], x17 + st1 {v19.h}[4], [x13], x17 + st1 {v19.h}[5], [x14], x17 + st1 {v19.h}[6], [x16], x17 + cmp w10, #4 + beq WriteEnd + st1 {v20.4h}, [x19], x17 + st1 {v20.h}[4], [x13], x17 + st1 {v20.h}[5], [x14], x17 + st1 {v20.h}[6], [x16], x17 + cmp w10, #5 + beq WriteEnd + st1 {v21.4h}, [x19], x17 + st1 {v21.h}[4], [x13], x17 + st1 {v21.h}[5], [x14], x17 + st1 {v21.h}[6], [x16], x17 + cmp w10, #6 + beq WriteEnd + st1 {v22.4h}, [x19], x17 + st1 {v22.h}[4], [x13], x17 + st1 {v22.h}[5], [x14], x17 + st1 {v22.h}[6], [x16], x17 + cmp w10, #7 + beq WriteEnd + st1 {v23.4h}, [x19], x17 + st1 {v23.h}[4], [x13], x17 + st1 {v23.h}[5], [x14], x17 + st1 {v23.h}[6], [x16], x17 + cmp w10, #8 + beq WriteEnd + st1 {v24.4h}, [x19], x17 + st1 {v24.h}[4], [x13], x17 + st1 {v24.h}[5], [x14], x17 + st1 {v24.h}[6], [x16], x17 + cmp w10, #9 + beq WriteEnd + st1 {v25.4h}, [x19], x17 + st1 {v25.h}[4], [x13], x17 + st1 {v25.h}[5], [x14], x17 + st1 {v25.h}[6], [x16], x17 + cmp w10, #10 + beq WriteEnd + st1 {v26.4h}, [x19], x17 + st1 {v26.h}[4], [x13], x17 + st1 {v26.h}[5], [x14], x17 + st1 {v26.h}[6], [x16], x17 + cmp w10, #11 + beq WriteEnd + st1 {v27.4h}, [x19], x17 + st1 {v27.h}[4], [x13], x17 + st1 {v27.h}[5], [x14], x17 + st1 {v27.h}[6], [x16], x17 + cmp w10, #12 + beq WriteEnd + st1 {v28.4h}, [x19], x17 + st1 {v28.h}[4], [x13], x17 + st1 {v28.h}[5], [x14], x17 + st1 {v28.h}[6], [x16], x17 + cmp w10, #13 + beq WriteEnd + st1 {v29.4h}, [x19], x17 + st1 {v29.h}[4], [x13], x17 + st1 {v29.h}[5], [x14], x17 + st1 {v29.h}[6], [x16], x17 + cmp w10, #14 + beq WriteEnd + st1 {v30.4h}, [x19], x17 + st1 {v30.h}[4], [x13], x17 + st1 {v30.h}[5], [x14], x17 + st1 {v30.h}[6], [x16], x17 + cmp w10, #15 + beq WriteEnd + st1 {v31.4h}, [x19], x17 + st1 {v31.h}[4], [x13], x17 + st1 {v31.h}[5], [x14], x17 + st1 {v31.h}[6], [x16], x17 + b WriteEnd +WriteC8: + st1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x2], #64 + st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x2], #64 + st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x2], #64 + st1 {v28.8h, v29.8h, v30.8h, v31.8h}, [x2], #64 + b WriteEnd +Write8: + st1 {v16.8h}, [x19], x17 + cmp w10, #1 + beq WriteEnd + st1 {v17.8h}, [x19], x17 + cmp w10, #2 + beq WriteEnd + st1 {v18.8h}, [x19], x17 + cmp w10, #3 + beq WriteEnd + st1 {v19.8h}, [x19], x17 + cmp w10, #4 + beq WriteEnd + st1 {v20.8h}, [x19], x17 + cmp w10, #5 + beq WriteEnd + st1 {v21.8h}, [x19], x17 + cmp w10, #6 + beq WriteEnd + st1 {v22.8h}, [x19], x17 + cmp w10, #7 + beq WriteEnd + st1 {v23.8h}, [x19], x17 + cmp w10, #8 + beq WriteEnd + st1 {v24.8h}, [x19], x17 + cmp w10, #9 + beq WriteEnd + st1 {v25.8h}, [x19], x17 + cmp w10, #10 + beq WriteEnd + st1 {v26.8h}, [x19], x17 + cmp w10, #11 + beq WriteEnd + st1 {v27.8h}, [x19], x17 + cmp w10, #12 + beq WriteEnd + st1 {v28.8h}, [x19], x17 + cmp w10, #13 + beq WriteEnd + st1 {v29.8h}, [x19], x17 + cmp w10, #14 + beq WriteEnd + st1 {v30.8h}, [x19], x17 + cmp w10, #15 + beq WriteEnd + st1 {v31.8h}, [x19], x17 + +WriteEnd: + subs w10, w10, #16 // lhs row - 8 + bgt L2 + +End2: + subs w7, w7, #8 // rhs col - 8 + add x1, x1, x15 // rhs ptr + stride + add x3, x3, #16 // bias ptr + stride + ldrb w13, [sp, #152] + cbz w13, NoDstStep + add x2, x2, #16 // dst ptr + stride +NoDstStep: + bgt L1 + +End1: + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp], #64 + ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [sp], #64 + ldp x19, x20, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/MatmulFp16Opt.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/MatmulFp16Opt.S new file mode 100644 index 0000000000000000000000000000000000000000..21348f80a2aa6927598d41a18afbacc76eeb7bc7 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/MatmulFp16Opt.S @@ -0,0 +1,1185 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void MatmulFp16Neon64Opt(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, +// int depth, int row, int col, size_t stride, size_t writeMode) +// x0: a +// x1: b +// x2: c +// x3: bias +// x4: act_type +// x5: depth +// x6: row +// x7: col +// x8: stride +// x9: writeMode + +asm_function MatmulFp16Neon64Opt + sub sp, sp, #96 + st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp] + stp x19, x20, [sp, #64] + stp x21, x22, [sp, #80] + + ldr x8, [sp, #96] + ldr x9, [sp, #104] + + mov x21, #32 // sizeof(float16_t) * 16 + mul x17, x5, x21 // block stride of lhs/rhs: sizeof(float16_t) * 16 * depth + cbnz x9, NoC8Steps + mov x11, x2 + mov x21, #16 + mul x16, x6, x21 // row * 8 * sizeof(float16_t) +NoC8Steps: + cmp x9, #2 + bne NoWinoSteps + mov x21, #2 + mul x15, x7, x8 + mul x15, x15, x21 // kernel_size * col *sizeof(float16_t) + mov x21, #16 + mul x16, x8, x21 // kernel_size * 8 * sizeof(float16_t) +NoWinoSteps: + mov x21, #2 + mul x8, x8, x21 + +LoopRowStart: + cmp x6, #1 + ble LoopRow + cmp x6, #2 + ble LoopRow2 + cmp x6, #4 + ble LoopRow4 + cmp x6, #8 + ble LoopRow8 + +LoopRow16: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol16: + cbz x9, NoReloadDst16 + mov x11, x2 + NoReloadDst16: + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr + + cmp x19, #4 + blt LoopDepth16One + + LoopDepth16: + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], #64 + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x14], #64 + fmla v16.8h, v8.8h, v0.h[0] + fmla v17.8h, v8.8h, v0.h[1] + fmla v18.8h, v8.8h, v0.h[2] + fmla v19.8h, v8.8h, v0.h[3] + fmla v20.8h, v8.8h, v0.h[4] + fmla v21.8h, v8.8h, v0.h[5] + fmla v22.8h, v8.8h, v0.h[6] + fmla v23.8h, v8.8h, v0.h[7] + fmla v24.8h, v8.8h, v1.h[0] + fmla v25.8h, v8.8h, v1.h[1] + fmla v26.8h, v8.8h, v1.h[2] + fmla v27.8h, v8.8h, v1.h[3] + fmla v28.8h, v8.8h, v1.h[4] + fmla v29.8h, v8.8h, v1.h[5] + fmla v30.8h, v8.8h, v1.h[6] + fmla v31.8h, v8.8h, v1.h[7] + ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x10], #64 + fmla v16.8h, v9.8h, v2.h[0] + fmla v17.8h, v9.8h, v2.h[1] + fmla v18.8h, v9.8h, v2.h[2] + fmla v19.8h, v9.8h, v2.h[3] + fmla v20.8h, v9.8h, v2.h[4] + fmla v21.8h, v9.8h, v2.h[5] + fmla v22.8h, v9.8h, v2.h[6] + fmla v23.8h, v9.8h, v2.h[7] + fmla v24.8h, v9.8h, v3.h[0] + fmla v25.8h, v9.8h, v3.h[1] + fmla v26.8h, v9.8h, v3.h[2] + fmla v27.8h, v9.8h, v3.h[3] + fmla v28.8h, v9.8h, v3.h[4] + fmla v29.8h, v9.8h, v3.h[5] + fmla v30.8h, v9.8h, v3.h[6] + fmla v31.8h, v9.8h, v3.h[7] + fmla v16.8h, v10.8h, v4.h[0] + fmla v17.8h, v10.8h, v4.h[1] + fmla v18.8h, v10.8h, v4.h[2] + fmla v19.8h, v10.8h, v4.h[3] + fmla v20.8h, v10.8h, v4.h[4] + fmla v21.8h, v10.8h, v4.h[5] + fmla v22.8h, v10.8h, v4.h[6] + fmla v23.8h, v10.8h, v4.h[7] + fmla v24.8h, v10.8h, v5.h[0] + fmla v25.8h, v10.8h, v5.h[1] + fmla v26.8h, v10.8h, v5.h[2] + fmla v27.8h, v10.8h, v5.h[3] + fmla v28.8h, v10.8h, v5.h[4] + fmla v29.8h, v10.8h, v5.h[5] + fmla v30.8h, v10.8h, v5.h[6] + fmla v31.8h, v10.8h, v5.h[7] + fmla v16.8h, v11.8h, v6.h[0] + fmla v17.8h, v11.8h, v6.h[1] + fmla v18.8h, v11.8h, v6.h[2] + fmla v19.8h, v11.8h, v6.h[3] + fmla v20.8h, v11.8h, v6.h[4] + fmla v21.8h, v11.8h, v6.h[5] + fmla v22.8h, v11.8h, v6.h[6] + fmla v23.8h, v11.8h, v6.h[7] + fmla v24.8h, v11.8h, v7.h[0] + fmla v25.8h, v11.8h, v7.h[1] + fmla v26.8h, v11.8h, v7.h[2] + fmla v27.8h, v11.8h, v7.h[3] + fmla v28.8h, v11.8h, v7.h[4] + fmla v29.8h, v11.8h, v7.h[5] + fmla v30.8h, v11.8h, v7.h[6] + fmla v31.8h, v11.8h, v7.h[7] + + subs x19, x19, #4 + beq Bias16 + cmp x19, #4 + bge LoopDepth16 + + LoopDepth16One: + ld1 {v0.8h, v1.8h}, [x10], #32 + ld1 {v2.8h}, [x14], #16 + fmla v16.8h, v2.8h, v0.h[0] + fmla v17.8h, v2.8h, v0.h[1] + fmla v18.8h, v2.8h, v0.h[2] + fmla v19.8h, v2.8h, v0.h[3] + fmla v20.8h, v2.8h, v0.h[4] + fmla v21.8h, v2.8h, v0.h[5] + fmla v22.8h, v2.8h, v0.h[6] + fmla v23.8h, v2.8h, v0.h[7] + fmla v24.8h, v2.8h, v1.h[0] + fmla v25.8h, v2.8h, v1.h[1] + fmla v26.8h, v2.8h, v1.h[2] + fmla v27.8h, v2.8h, v1.h[3] + fmla v28.8h, v2.8h, v1.h[4] + fmla v29.8h, v2.8h, v1.h[5] + fmla v30.8h, v2.8h, v1.h[6] + fmla v31.8h, v2.8h, v1.h[7] + + subs x19, x19, #1 + bgt LoopDepth16One + + Bias16: + cbz x3, Activation16 + ld1 {v0.8h}, [x12], #16 + fadd v16.8h, v16.8h, v0.8h + fadd v17.8h, v17.8h, v0.8h + fadd v18.8h, v18.8h, v0.8h + fadd v19.8h, v19.8h, v0.8h + fadd v20.8h, v20.8h, v0.8h + fadd v21.8h, v21.8h, v0.8h + fadd v22.8h, v22.8h, v0.8h + fadd v23.8h, v23.8h, v0.8h + fadd v24.8h, v24.8h, v0.8h + fadd v25.8h, v25.8h, v0.8h + fadd v26.8h, v26.8h, v0.8h + fadd v27.8h, v27.8h, v0.8h + fadd v28.8h, v28.8h, v0.8h + fadd v29.8h, v29.8h, v0.8h + fadd v30.8h, v30.8h, v0.8h + fadd v31.8h, v31.8h, v0.8h + + Activation16: + cmp x4, #3 + beq Relu616 + cmp x4, #1 + beq Relu16 + b Write + + Relu616: + movi v2.8h, #0x46, lsl #8 + fmin v16.8h, v16.8h, v2.8h + fmin v17.8h, v17.8h, v2.8h + fmin v18.8h, v18.8h, v2.8h + fmin v19.8h, v19.8h, v2.8h + fmin v20.8h, v20.8h, v2.8h + fmin v21.8h, v21.8h, v2.8h + fmin v22.8h, v22.8h, v2.8h + fmin v23.8h, v23.8h, v2.8h + fmin v24.8h, v24.8h, v2.8h + fmin v25.8h, v25.8h, v2.8h + fmin v26.8h, v26.8h, v2.8h + fmin v27.8h, v27.8h, v2.8h + fmin v28.8h, v28.8h, v2.8h + fmin v29.8h, v29.8h, v2.8h + fmin v30.8h, v30.8h, v2.8h + fmin v31.8h, v31.8h, v2.8h + + Relu16: + dup v2.8h, wzr + fmax v16.8h, v16.8h, v2.8h + fmax v17.8h, v17.8h, v2.8h + fmax v18.8h, v18.8h, v2.8h + fmax v19.8h, v19.8h, v2.8h + fmax v20.8h, v20.8h, v2.8h + fmax v21.8h, v21.8h, v2.8h + fmax v22.8h, v22.8h, v2.8h + fmax v23.8h, v23.8h, v2.8h + fmax v24.8h, v24.8h, v2.8h + fmax v25.8h, v25.8h, v2.8h + fmax v26.8h, v26.8h, v2.8h + fmax v27.8h, v27.8h, v2.8h + fmax v28.8h, v28.8h, v2.8h + fmax v29.8h, v29.8h, v2.8h + fmax v30.8h, v30.8h, v2.8h + fmax v31.8h, v31.8h, v2.8h + b Write + +LoopRow8: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol8: + cbz x9, NoReloadDst8 + mov x11, x2 + NoReloadDst8: + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + + cmp x19, #4 + blt LoopDepth8One + + LoopDepth8: + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], #64 + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x14], #64 + fmla v16.8h, v8.8h, v0.h[0] + fmla v17.8h, v8.8h, v0.h[1] + fmla v18.8h, v8.8h, v0.h[2] + fmla v19.8h, v8.8h, v0.h[3] + fmla v20.8h, v8.8h, v0.h[4] + fmla v21.8h, v8.8h, v0.h[5] + fmla v22.8h, v8.8h, v0.h[6] + fmla v23.8h, v8.8h, v0.h[7] + ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x10], #64 + fmla v16.8h, v9.8h, v2.h[0] + fmla v17.8h, v9.8h, v2.h[1] + fmla v18.8h, v9.8h, v2.h[2] + fmla v19.8h, v9.8h, v2.h[3] + fmla v20.8h, v9.8h, v2.h[4] + fmla v21.8h, v9.8h, v2.h[5] + fmla v22.8h, v9.8h, v2.h[6] + fmla v23.8h, v9.8h, v2.h[7] + fmla v16.8h, v10.8h, v4.h[0] + fmla v17.8h, v10.8h, v4.h[1] + fmla v18.8h, v10.8h, v4.h[2] + fmla v19.8h, v10.8h, v4.h[3] + fmla v20.8h, v10.8h, v4.h[4] + fmla v21.8h, v10.8h, v4.h[5] + fmla v22.8h, v10.8h, v4.h[6] + fmla v23.8h, v10.8h, v4.h[7] + fmla v16.8h, v11.8h, v6.h[0] + fmla v17.8h, v11.8h, v6.h[1] + fmla v18.8h, v11.8h, v6.h[2] + fmla v19.8h, v11.8h, v6.h[3] + fmla v20.8h, v11.8h, v6.h[4] + fmla v21.8h, v11.8h, v6.h[5] + fmla v22.8h, v11.8h, v6.h[6] + fmla v23.8h, v11.8h, v6.h[7] + + subs x19, x19, #4 + beq Bias8 + cmp x19, #4 + bge LoopDepth8 + + LoopDepth8One: + ld1 {v0.8h, v1.8h}, [x10], #32 + ld1 {v2.8h}, [x14], #16 + fmla v16.8h, v2.8h, v0.h[0] + fmla v17.8h, v2.8h, v0.h[1] + fmla v18.8h, v2.8h, v0.h[2] + fmla v19.8h, v2.8h, v0.h[3] + fmla v20.8h, v2.8h, v0.h[4] + fmla v21.8h, v2.8h, v0.h[5] + fmla v22.8h, v2.8h, v0.h[6] + fmla v23.8h, v2.8h, v0.h[7] + + subs x19, x19, #1 + bgt LoopDepth8One + + Bias8: + cbz x3, Activation8 + ld1 {v0.8h}, [x12], #16 + fadd v16.8h, v16.8h, v0.8h + fadd v17.8h, v17.8h, v0.8h + fadd v18.8h, v18.8h, v0.8h + fadd v19.8h, v19.8h, v0.8h + fadd v20.8h, v20.8h, v0.8h + fadd v21.8h, v21.8h, v0.8h + fadd v22.8h, v22.8h, v0.8h + fadd v23.8h, v23.8h, v0.8h + + Activation8: + cmp x4, #3 + beq Relu68 + cmp x4, #1 + beq Relu8 + b Write + + Relu68: + movi v2.8h, #0x46, lsl #8 + fmin v16.8h, v16.8h, v2.8h + fmin v17.8h, v17.8h, v2.8h + fmin v18.8h, v18.8h, v2.8h + fmin v19.8h, v19.8h, v2.8h + fmin v20.8h, v20.8h, v2.8h + fmin v21.8h, v21.8h, v2.8h + fmin v22.8h, v22.8h, v2.8h + fmin v23.8h, v23.8h, v2.8h + + Relu8: + dup v2.8h, wzr + fmax v16.8h, v16.8h, v2.8h + fmax v17.8h, v17.8h, v2.8h + fmax v18.8h, v18.8h, v2.8h + fmax v19.8h, v19.8h, v2.8h + fmax v20.8h, v20.8h, v2.8h + fmax v21.8h, v21.8h, v2.8h + fmax v22.8h, v22.8h, v2.8h + fmax v23.8h, v23.8h, v2.8h + b Write + +LoopRow4: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol4: + cbz x9, NoReloadDst4 + mov x11, x2 + NoReloadDst4: + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + + cmp x19, #4 + blt LoopDepth4One + + LoopDepth4: + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], #64 + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x14], #64 + fmla v16.8h, v8.8h, v0.h[0] + fmla v17.8h, v8.8h, v0.h[1] + fmla v18.8h, v8.8h, v0.h[2] + fmla v19.8h, v8.8h, v0.h[3] + ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x10], #64 + fmla v16.8h, v9.8h, v2.h[0] + fmla v17.8h, v9.8h, v2.h[1] + fmla v18.8h, v9.8h, v2.h[2] + fmla v19.8h, v9.8h, v2.h[3] + fmla v16.8h, v10.8h, v4.h[0] + fmla v17.8h, v10.8h, v4.h[1] + fmla v18.8h, v10.8h, v4.h[2] + fmla v19.8h, v10.8h, v4.h[3] + fmla v16.8h, v11.8h, v6.h[0] + fmla v17.8h, v11.8h, v6.h[1] + fmla v18.8h, v11.8h, v6.h[2] + fmla v19.8h, v11.8h, v6.h[3] + + subs x19, x19, #4 + beq Bias4 + cmp x19, #4 + bge LoopDepth4 + + LoopDepth4One: + ld1 {v0.8h, v1.8h}, [x10], #32 + ld1 {v2.8h}, [x14], #16 + fmla v16.8h, v2.8h, v0.h[0] + fmla v17.8h, v2.8h, v0.h[1] + fmla v18.8h, v2.8h, v0.h[2] + fmla v19.8h, v2.8h, v0.h[3] + + subs x19, x19, #1 + bgt LoopDepth4One + + Bias4: + cbz x3, Activation4 + ld1 {v0.8h}, [x12], #16 + fadd v16.8h, v16.8h, v0.8h + fadd v17.8h, v17.8h, v0.8h + fadd v18.8h, v18.8h, v0.8h + fadd v19.8h, v19.8h, v0.8h + + Activation4: + cmp x4, #3 + beq Relu64 + cmp x4, #1 + beq Relu4 + b Write + + Relu64: + movi v2.8h, #0x46, lsl #8 + fmin v16.8h, v16.8h, v2.8h + fmin v17.8h, v17.8h, v2.8h + fmin v18.8h, v18.8h, v2.8h + fmin v19.8h, v19.8h, v2.8h + + Relu4: + dup v2.8h, wzr + fmax v16.8h, v16.8h, v2.8h + fmax v17.8h, v17.8h, v2.8h + fmax v18.8h, v18.8h, v2.8h + fmax v19.8h, v19.8h, v2.8h + b Write + +LoopRow2: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol2: + cbz x9, NoReloadDst2 + mov x11, x2 + NoReloadDst2: + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + dup v16.4s, wzr + dup v17.4s, wzr + + cmp x19, #4 + blt LoopDepth2One + + LoopDepth2: + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], #64 + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x14], #64 + fmla v16.8h, v8.8h, v0.h[0] + fmla v17.8h, v8.8h, v0.h[1] + ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x10], #64 + fmla v16.8h, v9.8h, v2.h[0] + fmla v17.8h, v9.8h, v2.h[1] + fmla v16.8h, v10.8h, v4.h[0] + fmla v17.8h, v10.8h, v4.h[1] + fmla v16.8h, v11.8h, v6.h[0] + fmla v17.8h, v11.8h, v6.h[1] + + subs x19, x19, #4 + beq Bias2 + cmp x19, #4 + bge LoopDepth2 + + LoopDepth2One: + ld1 {v0.8h, v1.8h}, [x10], #32 + ld1 {v2.8h}, [x14], #16 + fmla v16.8h, v2.8h, v0.h[0] + fmla v17.8h, v2.8h, v0.h[1] + + subs x19, x19, #1 + bgt LoopDepth2One + + Bias2: + cbz x3, Activation2 + ld1 {v0.8h}, [x12], #16 + fadd v16.8h, v16.8h, v0.8h + fadd v17.8h, v17.8h, v0.8h + + Activation2: + cmp x4, #3 + beq Relu62 + cmp x4, #1 + beq Relu2 + b Write + + Relu62: + movi v2.8h, #0x46, lsl #8 + fmin v16.8h, v16.8h, v2.8h + fmin v17.8h, v17.8h, v2.8h + + Relu2: + dup v2.8h, wzr + fmax v16.8h, v16.8h, v2.8h + fmax v17.8h, v17.8h, v2.8h + b Write + +LoopRow: + mov x14, x1 // reload rhs ptr + mov x13, x7 // reload rhs col + mov x12, x3 // reload bias + + LoopCol: + cbz x9, NoReloadDst + mov x11, x2 + NoReloadDst: + mov x10, x0 // reload lhs ptr + mov x19, x5 // reload depth + + dup v16.4s, wzr + + cmp x19, #4 + blt LoopDepthOne + + LoopDepth: + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], #64 + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x14], #64 + fmla v16.8h, v8.8h, v0.h[0] + ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x10], #64 + fmla v16.8h, v9.8h, v2.h[0] + fmla v16.8h, v10.8h, v4.h[0] + fmla v16.8h, v11.8h, v6.h[0] + + subs x19, x19, #4 + beq Bias + cmp x19, #4 + bge LoopDepth + + LoopDepthOne: + ld1 {v0.8h, v1.8h}, [x10], #32 + ld1 {v2.8h}, [x14], #16 + fmla v16.8h, v2.8h, v0.h[0] + + subs x19, x19, #1 + bgt LoopDepthOne + + Bias: + cbz x3, Activation + ld1 {v0.8h}, [x12], #16 + fadd v16.8h, v16.8h, v0.8h + + Activation: + cmp x4, #3 + beq Relu6 + cmp x4, #1 + beq Relu + b Write + + Relu6: + movi v2.8h, #0x46, lsl #8 + fmin v16.8h, v16.8h, v2.8h + + Relu: + dup v2.8h, wzr + fmax v16.8h, v16.8h, v2.8h + + Write: + cmp x9, #2 + beq WriteWino + cbz x9, WriteC8 + cmp x13, #1 + beq Write1 + cmp x13, #2 + beq Write2 + cmp x13, #3 + beq Write3 + cmp x13, #4 + beq Write4 + cmp x13, #5 + beq Write5 + cmp x13, #6 + beq Write6 + cmp x13, #7 + beq Write7 + b Write8 + + Write1: + add x2, x2, #2 + str h16, [x11] + cmp x6, #1 + beq WriteEnd + add x11, x11, x8 + str h17, [x11] + cmp x6, #2 + beq WriteEnd + add x11, x11, x8 + str h18, [x11] + cmp x6, #3 + beq WriteEnd + add x11, x11, x8 + str h19, [x11] + cmp x6, #4 + beq WriteEnd + add x11, x11, x8 + str h20, [x11] + cmp x6, #5 + beq WriteEnd + add x11, x11, x8 + str h21, [x11] + cmp x6, #6 + beq WriteEnd + add x11, x11, x8 + str h22, [x11] + cmp x6, #7 + beq WriteEnd + add x11, x11, x8 + str h23, [x11] + cmp x6, #8 + beq WriteEnd + add x11, x11, x8 + str h24, [x11] + cmp x6, #9 + beq WriteEnd + add x11, x11, x8 + str h25, [x11] + cmp x6, #10 + beq WriteEnd + add x11, x11, x8 + str h26, [x11] + cmp x6, #11 + beq WriteEnd + add x11, x11, x8 + str h27, [x11] + cmp x6, #12 + beq WriteEnd + add x11, x11, x8 + str h28, [x11] + cmp x6, #13 + beq WriteEnd + add x11, x11, x8 + str h29, [x11] + cmp x6, #14 + beq WriteEnd + add x11, x11, x8 + str h30, [x11] + cmp x6, #15 + beq WriteEnd + add x11, x11, x8 + str h31, [x11] + add x11, x11, x8 + add x11, x11, #2 + b WriteEnd + Write2: + add x2, x2, #4 + st1 {v16.s}[0], [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.s}[0], [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.s}[0], [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.s}[0], [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.s}[0], [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.s}[0], [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.s}[0], [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.s}[0], [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.s}[0], [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.s}[0], [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.s}[0], [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.s}[0], [x11], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.s}[0], [x11], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.s}[0], [x11], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.s}[0], [x11], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.s}[0], [x11], x8 + add x11, x11, #4 + b WriteEnd + Write3: + add x2, x2, #6 + add x19, x11, #4 + st1 {v16.s}[0], [x11], x8 + st1 {v16.h}[2], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.s}[0], [x11], x8 + st1 {v17.h}[2], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.s}[0], [x11], x8 + st1 {v18.h}[2], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.s}[0], [x11], x8 + st1 {v19.h}[2], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.s}[0], [x11], x8 + st1 {v20.h}[2], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.s}[0], [x11], x8 + st1 {v21.h}[2], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.s}[0], [x11], x8 + st1 {v22.h}[2], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.s}[0], [x11], x8 + st1 {v23.h}[2], [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.s}[0], [x11], x8 + st1 {v24.h}[2], [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.s}[0], [x11], x8 + st1 {v25.h}[2], [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.s}[0], [x11], x8 + st1 {v26.h}[2], [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.s}[0], [x11], x8 + st1 {v27.h}[2], [x19], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.s}[0], [x11], x8 + st1 {v28.h}[2], [x19], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.s}[0], [x11], x8 + st1 {v29.h}[2], [x19], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.s}[0], [x11], x8 + st1 {v30.h}[2], [x19], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.s}[0], [x11], x8 + st1 {v31.h}[2], [x19] + add x11, x11, #6 + b WriteEnd + Write4: + add x2, x2, #8 + st1 {v16.4h}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.4h}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.4h}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.4h}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.4h}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.4h}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.4h}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.4h}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4h}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.4h}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.4h}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.4h}, [x11], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.4h}, [x11], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.4h}, [x11], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.4h}, [x11], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.4h}, [x11], x8 + add x11, x11, #8 + b WriteEnd + Write5: + add x2, x2, #10 + add x19, x11, #8 + st1 {v16.4h}, [x11], x8 + st1 {v16.h}[4], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.4h}, [x11], x8 + st1 {v17.h}[4], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.4h}, [x11], x8 + st1 {v18.h}[4], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.4h}, [x11], x8 + st1 {v19.h}[4], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.4h}, [x11], x8 + st1 {v20.h}[4], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.4h}, [x11], x8 + st1 {v21.h}[4], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.4h}, [x11], x8 + st1 {v22.h}[4], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.4h}, [x11], x8 + st1 {v23.h}[4], [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4h}, [x11], x8 + st1 {v24.h}[4], [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.4h}, [x11], x8 + st1 {v25.h}[4], [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.4h}, [x11], x8 + st1 {v26.h}[4], [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.4h}, [x11], x8 + st1 {v27.h}[4], [x19], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.4h}, [x11], x8 + st1 {v28.h}[4], [x19], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.4h}, [x11], x8 + st1 {v29.h}[4], [x19], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.4h}, [x11], x8 + st1 {v30.h}[4], [x19], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.4h}, [x11], x8 + st1 {v31.h}[4], [x19] + add x11, x11, #10 + b WriteEnd + Write6: + add x2, x2, #12 + add x19, x11, #8 + st1 {v16.4h}, [x11], x8 + st1 {v16.s}[2], [x19], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.4h}, [x11], x8 + st1 {v17.s}[2], [x19], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.4h}, [x11], x8 + st1 {v18.s}[2], [x19], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.4h}, [x11], x8 + st1 {v19.s}[2], [x19], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.4h}, [x11], x8 + st1 {v20.s}[2], [x19], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.4h}, [x11], x8 + st1 {v21.s}[2], [x19], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.4h}, [x11], x8 + st1 {v22.s}[2], [x19], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.4h}, [x11], x8 + st1 {v23.s}[2], [x19], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4h}, [x11], x8 + st1 {v24.s}[2], [x19], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.4h}, [x11], x8 + st1 {v25.s}[2], [x19], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.4h}, [x11], x8 + st1 {v26.s}[2], [x19], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.4h}, [x11], x8 + st1 {v27.s}[2], [x19], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.4h}, [x11], x8 + st1 {v28.s}[2], [x19], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.4h}, [x11], x8 + st1 {v29.s}[2], [x19], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.4h}, [x11], x8 + st1 {v30.s}[2], [x19], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.4h}, [x11], x8 + st1 {v31.s}[2], [x19] + add x11, x11, #12 + b WriteEnd + Write7: + add x2, x2, #14 + add x19, x11, #8 + add x10, x11, #12 + st1 {v16.4h}, [x11], x8 + st1 {v16.s}[2], [x19], x8 + st1 {v16.h}[6], [x10], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.4h}, [x11], x8 + st1 {v17.s}[2], [x19], x8 + st1 {v17.h}[6], [x10], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.4h}, [x11], x8 + st1 {v18.s}[2], [x19], x8 + st1 {v18.h}[6], [x10], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.4h}, [x11], x8 + st1 {v19.s}[2], [x19], x8 + st1 {v19.h}[6], [x10], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.4h}, [x11], x8 + st1 {v20.s}[2], [x19], x8 + st1 {v20.h}[6], [x10], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.4h}, [x11], x8 + st1 {v21.s}[2], [x19], x8 + st1 {v21.h}[6], [x10], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.4h}, [x11], x8 + st1 {v22.s}[2], [x19], x8 + st1 {v22.h}[6], [x10], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.4h}, [x11], x8 + st1 {v23.s}[2], [x19], x8 + st1 {v23.h}[6], [x10], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.4h}, [x11], x8 + st1 {v24.s}[2], [x19], x8 + st1 {v24.h}[6], [x10], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.4h}, [x11], x8 + st1 {v25.s}[2], [x19], x8 + st1 {v25.h}[6], [x10], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.4h}, [x11], x8 + st1 {v26.s}[2], [x19], x8 + st1 {v26.h}[6], [x10], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.4h}, [x11], x8 + st1 {v27.s}[2], [x19], x8 + st1 {v27.h}[6], [x10], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.4h}, [x11], x8 + st1 {v28.s}[2], [x19], x8 + st1 {v28.h}[6], [x10], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.4h}, [x11], x8 + st1 {v29.s}[2], [x19], x8 + st1 {v29.h}[6], [x10], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.4h}, [x11], x8 + st1 {v30.s}[2], [x19], x8 + st1 {v30.h}[6], [x10], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.4h}, [x11], x8 + st1 {v31.s}[2], [x19] + st1 {v31.h}[6], [x10] + add x11, x11, #14 + b WriteEnd + WriteC8: + mov x19, x11 + st1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x19], #64 + st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x19], #64 + st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x19], #64 + st1 {v28.8h, v29.8h, v30.8h, v31.8h}, [x19], #64 + add x11, x11, x16 + b WriteEnd + WriteWino: + add x2, x11, x16 + st1 {v16.8h}, [x11], x15 + st1 {v17.8h}, [x11], x15 + st1 {v18.8h}, [x11], x15 + st1 {v19.8h}, [x11], x15 + st1 {v20.8h}, [x11], x15 + st1 {v21.8h}, [x11], x15 + st1 {v22.8h}, [x11], x15 + st1 {v23.8h}, [x11], x15 + st1 {v24.8h}, [x11], x15 + st1 {v25.8h}, [x11], x15 + st1 {v26.8h}, [x11], x15 + st1 {v27.8h}, [x11], x15 + st1 {v28.8h}, [x11], x15 + st1 {v29.8h}, [x11], x15 + st1 {v30.8h}, [x11], x15 + st1 {v31.8h}, [x11], x15 + b WriteEnd + Write8: + add x2, x2, #16 + st1 {v16.8h}, [x11], x8 + cmp x6, #1 + beq WriteEnd + st1 {v17.8h}, [x11], x8 + cmp x6, #2 + beq WriteEnd + st1 {v18.8h}, [x11], x8 + cmp x6, #3 + beq WriteEnd + st1 {v19.8h}, [x11], x8 + cmp x6, #4 + beq WriteEnd + st1 {v20.8h}, [x11], x8 + cmp x6, #5 + beq WriteEnd + st1 {v21.8h}, [x11], x8 + cmp x6, #6 + beq WriteEnd + st1 {v22.8h}, [x11], x8 + cmp x6, #7 + beq WriteEnd + st1 {v23.8h}, [x11], x8 + cmp x6, #8 + beq WriteEnd + st1 {v24.8h}, [x11], x8 + cmp x6, #9 + beq WriteEnd + st1 {v25.8h}, [x11], x8 + cmp x6, #10 + beq WriteEnd + st1 {v26.8h}, [x11], x8 + cmp x6, #11 + beq WriteEnd + st1 {v27.8h}, [x11], x8 + cmp x6, #12 + beq WriteEnd + st1 {v28.8h}, [x11], x8 + cmp x6, #13 + beq WriteEnd + st1 {v29.8h}, [x11], x8 + cmp x6, #14 + beq WriteEnd + st1 {v30.8h}, [x11], x8 + cmp x6, #15 + beq WriteEnd + st1 {v31.8h}, [x11], x8 + add x11, x11, #16 + + WriteEnd: + subs x13, x13, #8 // rhs col - 8 + ble LoopColEnd + cmp x6, #1 + ble LoopCol + cmp x6, #2 + ble LoopCol2 + cmp x6, #4 + ble LoopCol4 + cmp x6, #8 + ble LoopCol8 + b LoopCol16 + +LoopColEnd: + add x0, x0, x17 + cbz x9, C8DstStep + mov x21, #2 + mul x21, x21, x7 + sub x11, x11, x21 + mov x2, x11 + b NoDstStep + C8DstStep: + add x2, x2, #256 + mov x11, x2 + NoDstStep: + subs x6, x6, #16 + bgt LoopRowStart + + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/MatmulFp16OptV2.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/MatmulFp16OptV2.S new file mode 100644 index 0000000000000000000000000000000000000000..1c340a228b6a89bef71d96116b3adde1f5ddc4e1 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/MatmulFp16OptV2.S @@ -0,0 +1,2966 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void MatmulFp16OptV2(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, +// size_t depth, size_t row, size_t col, size_t stride, size_t writeMode) +// x0: a +// x1: b +// x2: c +// x3: bias +// x4: act_type +// x5: depth +// x6: row +// x7: col +// x8: stride +// x9: writeMode + +asm_function MatmulFp16OptV2 + sub sp, sp, #192 + st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp] + add x9, sp, #64 + st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x9] + stp x19, x20, [sp, #128] + stp x21, x22, [sp, #144] + stp x23, x24, [sp, #160] + stp x29, x30, [sp, #176] + + ldr x8, [sp, #192] + ldr x9, [sp, #200] // writeMode + lsl x8, x8, #1 // stride * sizeof(float16_t) + + lsl x15, x7, #1 // col * sizeof(float16_t) + lsl x16, x5, #1 // depth * sizeof(float16_t) + mov x11, x2 + movi v7.8h, #0x46, lsl #8 + subs x6, x6, #12 + blt LoopRow8 +LoopRow12: + mov x11, x1 // reload matrixB + mov x12, x3 // reload bias + mov x13, x7 // reload col + mov x21, x2 // relocate output + subs x13, x13, #16 + blt LoopCol12x8 + LoopCol12x16: + mov x10, x0 // update matrixA + ld1 {v0.8h}, [x10], #16 + mov x14, x5 // reload depth + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + cbnz x12, InitFromBias12x16 + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + dup v12.2d, xzr + dup v13.2d, xzr + dup v14.2d, xzr + dup v15.2d, xzr + dup v16.2d, xzr + dup v17.2d, xzr + dup v18.2d, xzr + dup v19.2d, xzr + dup v20.2d, xzr + dup v21.2d, xzr + dup v22.2d, xzr + dup v23.2d, xzr + dup v24.2d, xzr + dup v25.2d, xzr + dup v26.2d, xzr + dup v27.2d, xzr + dup v28.2d, xzr + dup v29.2d, xzr + dup v30.2d, xzr + dup v31.2d, xzr + b Compute12x16Enter + InitFromBias12x16: + ld1 {v8.8h, v9.8h}, [x12] + ld1 {v10.8h, v11.8h}, [x12] + ld1 {v12.8h, v13.8h}, [x12] + ld1 {v14.8h, v15.8h}, [x12] + ld1 {v16.8h, v17.8h}, [x12] + ld1 {v18.8h, v19.8h}, [x12] + ld1 {v20.8h, v21.8h}, [x12] + ld1 {v22.8h, v23.8h}, [x12] + ld1 {v24.8h, v25.8h}, [x12] + ld1 {v26.8h, v27.8h}, [x12] + ld1 {v28.8h, v29.8h}, [x12] + ld1 {v30.8h, v31.8h}, [x12] + add x12, x12, #32 + Compute12x16Enter: + bl Compute12x16Unit + Activation12x16: + cmp x4, #3 + beq Relu612x16 + cmp x4, #1 + beq Relu12x16 + b Write12x16 + + Relu612x16: + fmin v8.8h, v8.8h, v7.8h + fmin v9.8h, v9.8h, v7.8h + fmin v10.8h, v10.8h, v7.8h + fmin v11.8h, v11.8h, v7.8h + fmin v12.8h, v12.8h, v7.8h + fmin v13.8h, v13.8h, v7.8h + fmin v14.8h, v14.8h, v7.8h + fmin v15.8h, v15.8h, v7.8h + fmin v16.8h, v16.8h, v7.8h + fmin v17.8h, v17.8h, v7.8h + fmin v18.8h, v18.8h, v7.8h + fmin v19.8h, v19.8h, v7.8h + fmin v20.8h, v20.8h, v7.8h + fmin v21.8h, v21.8h, v7.8h + fmin v22.8h, v22.8h, v7.8h + fmin v23.8h, v23.8h, v7.8h + fmin v24.8h, v24.8h, v7.8h + fmin v25.8h, v25.8h, v7.8h + fmin v26.8h, v26.8h, v7.8h + fmin v27.8h, v27.8h, v7.8h + fmin v28.8h, v28.8h, v7.8h + fmin v29.8h, v29.8h, v7.8h + fmin v30.8h, v30.8h, v7.8h + fmin v31.8h, v31.8h, v7.8h + + Relu12x16: + dup v6.8h, wzr + fmax v8.8h, v8.8h, v6.8h + fmax v9.8h, v9.8h, v6.8h + fmax v10.8h, v10.8h, v6.8h + fmax v11.8h, v11.8h, v6.8h + fmax v12.8h, v12.8h, v6.8h + fmax v13.8h, v13.8h, v6.8h + fmax v14.8h, v14.8h, v6.8h + fmax v15.8h, v15.8h, v6.8h + fmax v16.8h, v16.8h, v6.8h + fmax v17.8h, v17.8h, v6.8h + fmax v18.8h, v18.8h, v6.8h + fmax v19.8h, v19.8h, v6.8h + fmax v20.8h, v20.8h, v6.8h + fmax v21.8h, v21.8h, v6.8h + fmax v22.8h, v22.8h, v6.8h + fmax v23.8h, v23.8h, v6.8h + fmax v24.8h, v24.8h, v6.8h + fmax v25.8h, v25.8h, v6.8h + fmax v26.8h, v26.8h, v6.8h + fmax v27.8h, v27.8h, v6.8h + fmax v28.8h, v28.8h, v6.8h + fmax v29.8h, v29.8h, v6.8h + fmax v30.8h, v30.8h, v6.8h + fmax v31.8h, v31.8h, v6.8h + Write12x16: + mov x22, x21 + add x23, x21, x8, lsl #2 + add x24, x21, x8, lsl #3 + st1 {v8.8h, v9.8h}, [x22], x8 + st1 {v10.8h, v11.8h}, [x22], x8 + st1 {v12.8h, v13.8h}, [x22], x8 + st1 {v14.8h, v15.8h}, [x22] + st1 {v16.8h, v17.8h}, [x23], x8 + st1 {v18.8h, v19.8h}, [x23], x8 + st1 {v20.8h, v21.8h}, [x23], x8 + st1 {v22.8h, v23.8h}, [x23] + st1 {v24.8h, v25.8h}, [x24], x8 + st1 {v26.8h, v27.8h}, [x24], x8 + st1 {v28.8h, v29.8h}, [x24], x8 + st1 {v30.8h, v31.8h}, [x24] + add x21, x21, #32 + subs x13, x13, #16 + bge LoopCol12x16 + + LoopCol12x8: + adds x13, x13, #16 + cbz x13, LoopRow12End + subs x13, x13, #8 + blt LoopCol12x4 + mov x10, x0 // update matrixA + ld1 {v0.8h}, [x10], #16 + mov x14, x5 // reload depth + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + cbnz x12, InitFromBias12x8 + dup v8.2d, xzr + dup v10.2d, xzr + dup v12.2d, xzr + dup v14.2d, xzr + dup v16.2d, xzr + dup v18.2d, xzr + dup v20.2d, xzr + dup v22.2d, xzr + dup v24.2d, xzr + dup v26.2d, xzr + dup v28.2d, xzr + dup v30.2d, xzr + b Compute12x8Enter + InitFromBias12x8: + ld1 {v8.8h}, [x12] + ld1 {v10.8h}, [x12] + ld1 {v12.8h}, [x12] + ld1 {v14.8h}, [x12] + ld1 {v16.8h}, [x12] + ld1 {v18.8h}, [x12] + ld1 {v20.8h}, [x12] + ld1 {v22.8h}, [x12] + ld1 {v24.8h}, [x12] + ld1 {v26.8h}, [x12] + ld1 {v28.8h}, [x12] + ld1 {v30.8h}, [x12] + add x12, x12, #16 + Compute12x8Enter: + bl Compute12x8Unit + Activation12x8: + cmp x4, #3 + beq Relu612x8 + cmp x4, #1 + beq Relu12x8 + b Write12x8 + + Relu612x8: + fmin v8.8h, v8.8h, v7.8h + fmin v10.8h, v10.8h, v7.8h + fmin v12.8h, v12.8h, v7.8h + fmin v14.8h, v14.8h, v7.8h + fmin v16.8h, v16.8h, v7.8h + fmin v18.8h, v18.8h, v7.8h + fmin v20.8h, v20.8h, v7.8h + fmin v22.8h, v22.8h, v7.8h + fmin v24.8h, v24.8h, v7.8h + fmin v26.8h, v26.8h, v7.8h + fmin v28.8h, v28.8h, v7.8h + fmin v30.8h, v30.8h, v7.8h + + Relu12x8: + dup v6.8h, wzr + fmax v8.8h, v8.8h, v6.8h + fmax v10.8h, v10.8h, v6.8h + fmax v12.8h, v12.8h, v6.8h + fmax v14.8h, v14.8h, v6.8h + fmax v16.8h, v16.8h, v6.8h + fmax v18.8h, v18.8h, v6.8h + fmax v20.8h, v20.8h, v6.8h + fmax v22.8h, v22.8h, v6.8h + fmax v24.8h, v24.8h, v6.8h + fmax v26.8h, v26.8h, v6.8h + fmax v28.8h, v28.8h, v6.8h + fmax v30.8h, v30.8h, v6.8h + Write12x8: + mov x22, x21 + add x23, x21, x8, lsl #2 + add x24, x21, x8, lsl #3 + st1 {v8.8h}, [x22], x8 + st1 {v10.8h}, [x22], x8 + st1 {v12.8h}, [x22], x8 + st1 {v14.8h}, [x22] + st1 {v16.8h}, [x23], x8 + st1 {v18.8h}, [x23], x8 + st1 {v20.8h}, [x23], x8 + st1 {v22.8h}, [x23] + st1 {v24.8h}, [x24], x8 + st1 {v26.8h}, [x24], x8 + st1 {v28.8h}, [x24], x8 + st1 {v30.8h}, [x24] + add x21, x21, #16 + subs x13, x13, #8 + + LoopCol12x4: + adds x13, x13, #8 + cbz x13, LoopRow12End + LoopCol12x4Core: + mov x10, x0 // update matrixA + ld1 {v0.8h}, [x10], #16 + mov x14, x5 // reload depth + prfm pldl1strm, [x11, #632] + ld1 {v3.4h}, [x11], #8 + cbnz x12, InitFromBias12x4 + dup v8.2s, wzr + dup v10.2s, wzr + dup v12.2s, wzr + dup v14.2s, wzr + dup v16.2s, wzr + dup v18.2s, wzr + dup v20.2s, wzr + dup v22.2s, wzr + dup v24.2s, wzr + dup v26.2s, wzr + dup v28.2s, wzr + dup v30.2s, wzr + b Compute12x4Enter + InitFromBias12x4: + ld1 {v8.4h}, [x12] + ld1 {v10.4h}, [x12] + ld1 {v12.4h}, [x12] + ld1 {v14.4h}, [x12] + ld1 {v16.4h}, [x12] + ld1 {v18.4h}, [x12] + ld1 {v20.4h}, [x12] + ld1 {v22.4h}, [x12] + ld1 {v24.4h}, [x12] + ld1 {v26.4h}, [x12] + ld1 {v28.4h}, [x12] + ld1 {v30.4h}, [x12] + add x12, x12, #8 + Compute12x4Enter: + bl Compute12x4Unit + Activation12x4: + cmp x4, #3 + beq Relu612x4 + cmp x4, #1 + beq Relu12x4 + b Write12x4 + + Relu612x4: + fmin v8.4h, v8.4h, v7.4h + fmin v10.4h, v10.4h, v7.4h + fmin v12.4h, v12.4h, v7.4h + fmin v14.4h, v14.4h, v7.4h + fmin v16.4h, v16.4h, v7.4h + fmin v18.4h, v18.4h, v7.4h + fmin v20.4h, v20.4h, v7.4h + fmin v22.4h, v22.4h, v7.4h + fmin v24.4h, v24.4h, v7.4h + fmin v26.4h, v26.4h, v7.4h + fmin v28.4h, v28.4h, v7.4h + fmin v30.4h, v30.4h, v7.4h + + Relu12x4: + dup v6.4h, wzr + fmax v8.4h, v8.4h, v6.4h + fmax v10.4h, v10.4h, v6.4h + fmax v12.4h, v12.4h, v6.4h + fmax v14.4h, v14.4h, v6.4h + fmax v16.4h, v16.4h, v6.4h + fmax v18.4h, v18.4h, v6.4h + fmax v20.4h, v20.4h, v6.4h + fmax v22.4h, v22.4h, v6.4h + fmax v24.4h, v24.4h, v6.4h + fmax v26.4h, v26.4h, v6.4h + fmax v28.4h, v28.4h, v6.4h + fmax v30.4h, v30.4h, v6.4h + Write12x4: + mov x22, x21 + add x23, x21, x8, lsl #2 + add x24, x21, x8, lsl #3 + cmp x13, #1 + beq Write12x1 + cmp x13, #2 + beq Write12x2 + cmp x13, #3 + beq Write12x3 + st1 {v8.4h}, [x22], x8 + st1 {v10.4h}, [x22], x8 + st1 {v12.4h}, [x22], x8 + st1 {v14.4h}, [x22] + st1 {v16.4h}, [x23], x8 + st1 {v18.4h}, [x23], x8 + st1 {v20.4h}, [x23], x8 + st1 {v22.4h}, [x23] + st1 {v24.4h}, [x24], x8 + st1 {v26.4h}, [x24], x8 + st1 {v28.4h}, [x24], x8 + st1 {v30.4h}, [x24] + add x21, x21, #8 + subs x13, x13, #4 + bgt LoopCol12x4Core + b LoopRow12End + Write12x1: + st1 {v8.h}[0], [x22], x8 + st1 {v10.h}[0], [x22], x8 + st1 {v12.h}[0], [x22], x8 + st1 {v14.h}[0], [x22] + st1 {v16.h}[0], [x23], x8 + st1 {v18.h}[0], [x23], x8 + st1 {v20.h}[0], [x23], x8 + st1 {v22.h}[0], [x23] + st1 {v24.h}[0], [x24], x8 + st1 {v26.h}[0], [x24], x8 + st1 {v28.h}[0], [x24], x8 + st1 {v30.h}[0], [x24] + b LoopRow12End + Write12x2: + st1 {v8.s}[0], [x22], x8 + st1 {v10.s}[0], [x22], x8 + st1 {v12.s}[0], [x22], x8 + st1 {v14.s}[0], [x22] + st1 {v16.s}[0], [x23], x8 + st1 {v18.s}[0], [x23], x8 + st1 {v20.s}[0], [x23], x8 + st1 {v22.s}[0], [x23] + st1 {v24.s}[0], [x24], x8 + st1 {v26.s}[0], [x24], x8 + st1 {v28.s}[0], [x24], x8 + st1 {v30.s}[0], [x24] + b LoopRow12End + Write12x3: + add x23, x22, #4 + st1 {v8.s}[0], [x22], x8 + st1 {v8.h}[2], [x23], x8 + st1 {v10.s}[0], [x22], x8 + st1 {v10.h}[2], [x23], x8 + st1 {v12.s}[0], [x22], x8 + st1 {v12.h}[2], [x23], x8 + st1 {v14.s}[0], [x22], x8 + st1 {v14.h}[2], [x23], x8 + st1 {v16.s}[0], [x22], x8 + st1 {v16.h}[2], [x23], x8 + st1 {v18.s}[0], [x22], x8 + st1 {v18.h}[2], [x23], x8 + st1 {v20.s}[0], [x22], x8 + st1 {v20.h}[2], [x23], x8 + st1 {v22.s}[0], [x22], x8 + st1 {v22.h}[2], [x23], x8 + st1 {v24.s}[0], [x22], x8 + st1 {v24.h}[2], [x23], x8 + st1 {v26.s}[0], [x22], x8 + st1 {v26.h}[2], [x23], x8 + st1 {v28.s}[0], [x22], x8 + st1 {v28.h}[2], [x23], x8 + st1 {v30.s}[0], [x22] + st1 {v30.h}[2], [x23] + LoopRow12End: + add x0, x0, x16, lsl #3 + add x0, x0, x16, lsl #2 + add x2, x2, x8, lsl #3 + add x2, x2, x8, lsl #2 + subs x6, x6, #12 + bge LoopRow12 + +LoopRow8: + adds x6, x6,#12 + cbz x6, End + subs x6, x6, #8 + blt LoopRow4 + mov x11, x1 // reload matrixB + mov x12, x3 // reload bias + mov x13, x7 // reload col + mov x21, x2 // relocate output + subs x13, x13, #16 + blt LoopCol8x8 + LoopCol8x16: + mov x10, x0 // update matrixA + ld1 {v0.8h}, [x10], #16 + mov x14, x5 // reload depth + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + cbnz x12, InitFromBias8x16 + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + dup v12.2d, xzr + dup v13.2d, xzr + dup v14.2d, xzr + dup v15.2d, xzr + dup v16.2d, xzr + dup v17.2d, xzr + dup v18.2d, xzr + dup v19.2d, xzr + dup v20.2d, xzr + dup v21.2d, xzr + dup v22.2d, xzr + dup v23.2d, xzr + b Compute8x16Enter + InitFromBias8x16: + ld1 {v8.8h, v9.8h}, [x12] + ld1 {v10.8h, v11.8h}, [x12] + ld1 {v12.8h, v13.8h}, [x12] + ld1 {v14.8h, v15.8h}, [x12] + ld1 {v16.8h, v17.8h}, [x12] + ld1 {v18.8h, v19.8h}, [x12] + ld1 {v20.8h, v21.8h}, [x12] + ld1 {v22.8h, v23.8h}, [x12] + add x12, x12, #32 + Compute8x16Enter: + bl Compute8x16Unit + Activation8x16: + cmp x4, #3 + beq Relu68x16 + cmp x4, #1 + beq Relu8x16 + b Write8x16 + + Relu68x16: + fmin v8.8h, v8.8h, v7.8h + fmin v9.8h, v9.8h, v7.8h + fmin v10.8h, v10.8h, v7.8h + fmin v11.8h, v11.8h, v7.8h + fmin v12.8h, v12.8h, v7.8h + fmin v13.8h, v13.8h, v7.8h + fmin v14.8h, v14.8h, v7.8h + fmin v15.8h, v15.8h, v7.8h + fmin v16.8h, v16.8h, v7.8h + fmin v17.8h, v17.8h, v7.8h + fmin v18.8h, v18.8h, v7.8h + fmin v19.8h, v19.8h, v7.8h + fmin v20.8h, v20.8h, v7.8h + fmin v21.8h, v21.8h, v7.8h + fmin v22.8h, v22.8h, v7.8h + fmin v23.8h, v23.8h, v7.8h + + Relu8x16: + dup v6.8h, wzr + fmax v8.8h, v8.8h, v6.8h + fmax v9.8h, v9.8h, v6.8h + fmax v10.8h, v10.8h, v6.8h + fmax v11.8h, v11.8h, v6.8h + fmax v12.8h, v12.8h, v6.8h + fmax v13.8h, v13.8h, v6.8h + fmax v14.8h, v14.8h, v6.8h + fmax v15.8h, v15.8h, v6.8h + fmax v16.8h, v16.8h, v6.8h + fmax v17.8h, v17.8h, v6.8h + fmax v18.8h, v18.8h, v6.8h + fmax v19.8h, v19.8h, v6.8h + fmax v20.8h, v20.8h, v6.8h + fmax v21.8h, v21.8h, v6.8h + fmax v22.8h, v22.8h, v6.8h + fmax v23.8h, v23.8h, v6.8h + Write8x16: + mov x22, x21 + add x23, x21, x8, lsl #2 + st1 {v8.8h, v9.8h}, [x22], x8 + st1 {v10.8h, v11.8h}, [x22], x8 + st1 {v12.8h, v13.8h}, [x22], x8 + st1 {v14.8h, v15.8h}, [x22] + st1 {v16.8h, v17.8h}, [x23], x8 + st1 {v18.8h, v19.8h}, [x23], x8 + st1 {v20.8h, v21.8h}, [x23], x8 + st1 {v22.8h, v23.8h}, [x23] + add x21, x21, #32 + subs x13, x13, #16 + bge LoopCol8x16 + + LoopCol8x8: + adds x13, x13, #16 + cbz x13, LoopRow8End + subs x13, x13, #8 + blt LoopCol8x4 + mov x10, x0 // update matrixA + ld1 {v0.8h}, [x10], #16 + mov x14, x5 // reload depth + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + cbnz x12, InitFromBias8x8 + dup v8.2d, xzr + dup v10.2d, xzr + dup v12.2d, xzr + dup v14.2d, xzr + dup v16.2d, xzr + dup v18.2d, xzr + dup v20.2d, xzr + dup v22.2d, xzr + b Compute8x8Enter + InitFromBias8x8: + ld1 {v8.8h}, [x12] + ld1 {v10.8h}, [x12] + ld1 {v12.8h}, [x12] + ld1 {v14.8h}, [x12] + ld1 {v16.8h}, [x12] + ld1 {v18.8h}, [x12] + ld1 {v20.8h}, [x12] + ld1 {v22.8h}, [x12] + add x12, x12, #16 + Compute8x8Enter: + bl Compute8x8Unit + Activation8x8: + cmp x4, #3 + beq Relu68x8 + cmp x4, #1 + beq Relu8x8 + b Write8x8 + + Relu68x8: + fmin v8.8h, v8.8h, v7.8h + fmin v10.8h, v10.8h, v7.8h + fmin v12.8h, v12.8h, v7.8h + fmin v14.8h, v14.8h, v7.8h + fmin v16.8h, v16.8h, v7.8h + fmin v18.8h, v18.8h, v7.8h + fmin v20.8h, v20.8h, v7.8h + fmin v22.8h, v22.8h, v7.8h + + Relu8x8: + dup v6.8h, wzr + fmax v8.8h, v8.8h, v6.8h + fmax v10.8h, v10.8h, v6.8h + fmax v12.8h, v12.8h, v6.8h + fmax v14.8h, v14.8h, v6.8h + fmax v16.8h, v16.8h, v6.8h + fmax v18.8h, v18.8h, v6.8h + fmax v20.8h, v20.8h, v6.8h + fmax v22.8h, v22.8h, v6.8h + Write8x8: + mov x22, x21 + add x23, x21, x8, lsl #2 + st1 {v8.8h}, [x22], x8 + st1 {v10.8h}, [x22], x8 + st1 {v12.8h}, [x22], x8 + st1 {v14.8h}, [x22] + st1 {v16.8h}, [x23], x8 + st1 {v18.8h}, [x23], x8 + st1 {v20.8h}, [x23], x8 + st1 {v22.8h}, [x23] + add x21, x21, #16 + subs x13, x13, #8 + + LoopCol8x4: + adds x13, x13, #8 + cbz x13, LoopRow8End + LoopCol8x4Core: + mov x10, x0 // update matrixA + ld1 {v0.8h}, [x10], #16 + mov x14, x5 // reload depth + prfm pldl1strm, [x11, #632] + ld1 {v3.4h}, [x11], #8 + cbnz x12, InitFromBias8x4 + dup v8.2s, wzr + dup v10.2s, wzr + dup v12.2s, wzr + dup v14.2s, wzr + dup v16.2s, wzr + dup v18.2s, wzr + dup v20.2s, wzr + dup v22.2s, wzr + b Compute8x4Enter + InitFromBias8x4: + ld1 {v8.4h}, [x12] + ld1 {v10.4h}, [x12] + ld1 {v12.4h}, [x12] + ld1 {v14.4h}, [x12] + ld1 {v16.4h}, [x12] + ld1 {v18.4h}, [x12] + ld1 {v20.4h}, [x12] + ld1 {v22.4h}, [x12] + add x12, x12, #8 + Compute8x4Enter: + bl Compute8x4Unit + Activation8x4: + cmp x4, #3 + beq Relu68x4 + cmp x4, #1 + beq Relu8x4 + b Write8x4 + + Relu68x4: + fmin v8.4h, v8.4h, v7.4h + fmin v10.4h, v10.4h, v7.4h + fmin v12.4h, v12.4h, v7.4h + fmin v14.4h, v14.4h, v7.4h + fmin v16.4h, v16.4h, v7.4h + fmin v18.4h, v18.4h, v7.4h + fmin v20.4h, v20.4h, v7.4h + fmin v22.4h, v22.4h, v7.4h + + Relu8x4: + dup v6.4h, wzr + fmax v8.4h, v8.4h, v6.4h + fmax v10.4h, v10.4h, v6.4h + fmax v12.4h, v12.4h, v6.4h + fmax v14.4h, v14.4h, v6.4h + fmax v16.4h, v16.4h, v6.4h + fmax v18.4h, v18.4h, v6.4h + fmax v20.4h, v20.4h, v6.4h + fmax v22.4h, v22.4h, v6.4h + Write8x4: + mov x22, x21 + add x23, x21, x8, lsl #2 + cmp x13, #1 + beq Write8x1 + cmp x13, #2 + beq Write8x2 + cmp x13, #3 + beq Write8x3 + st1 {v8.4h}, [x22], x8 + st1 {v10.4h}, [x22], x8 + st1 {v12.4h}, [x22], x8 + st1 {v14.4h}, [x22] + st1 {v16.4h}, [x23], x8 + st1 {v18.4h}, [x23], x8 + st1 {v20.4h}, [x23], x8 + st1 {v22.4h}, [x23] + add x21, x21, #8 + subs x13, x13, #4 + bgt LoopCol8x4Core + b LoopRow8End + Write8x1: + st1 {v8.h}[0], [x22], x8 + st1 {v10.h}[0], [x22], x8 + st1 {v12.h}[0], [x22], x8 + st1 {v14.h}[0], [x22] + st1 {v16.h}[0], [x23], x8 + st1 {v18.h}[0], [x23], x8 + st1 {v20.h}[0], [x23], x8 + st1 {v22.h}[0], [x23] + b LoopRow8End + Write8x2: + st1 {v8.s}[0], [x22], x8 + st1 {v10.s}[0], [x22], x8 + st1 {v12.s}[0], [x22], x8 + st1 {v14.s}[0], [x22] + st1 {v16.s}[0], [x23], x8 + st1 {v18.s}[0], [x23], x8 + st1 {v20.s}[0], [x23], x8 + st1 {v22.s}[0], [x23] + b LoopRow8End + Write8x3: + add x23, x22, #4 + st1 {v8.s}[0], [x22], x8 + st1 {v8.h}[2], [x23], x8 + st1 {v10.s}[0], [x22], x8 + st1 {v10.h}[2], [x23], x8 + st1 {v12.s}[0], [x22], x8 + st1 {v12.h}[2], [x23], x8 + st1 {v14.s}[0], [x22], x8 + st1 {v14.h}[2], [x23], x8 + st1 {v16.s}[0], [x22], x8 + st1 {v16.h}[2], [x23], x8 + st1 {v18.s}[0], [x22], x8 + st1 {v18.h}[2], [x23], x8 + st1 {v20.s}[0], [x22], x8 + st1 {v20.h}[2], [x23], x8 + st1 {v22.s}[0], [x22], x8 + st1 {v22.h}[2], [x23], x8 + LoopRow8End: + add x0, x0, x16, lsl #3 + add x2, x2, x8, lsl #3 + subs x6, x6, #8 + +LoopRow4: + adds x6, x6, #8 + cbz x6, End + subs x6, x6, #4 + blt LoopRowTail + mov x11, x1 // reload matrixB + mov x12, x3 // reload bias + mov x13, x7 // reload col + mov x21, x2 // relocate output + subs x13, x13, #16 + blt LoopCol4x8 + LoopCol4x16: + mov x10, x0 // update matrixA + ld1 {v0.4h}, [x10], #8 + mov x14, x5 // reload depth + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + cbnz x12, InitFromBias4x16 + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + dup v12.2d, xzr + dup v13.2d, xzr + dup v14.2d, xzr + dup v15.2d, xzr + b Compute4x16Enter + InitFromBias4x16: + ld1 {v8.8h, v9.8h}, [x12] + ld1 {v10.8h, v11.8h}, [x12] + ld1 {v12.8h, v13.8h}, [x12] + ld1 {v14.8h, v15.8h}, [x12] + add x12, x12, #32 + Compute4x16Enter: + bl Compute4x16Unit + Activation4x16: + cmp x4, #3 + beq Relu64x16 + cmp x4, #1 + beq Relu4x16 + b Write4x16 + + Relu64x16: + fmin v8.8h, v8.8h, v7.8h + fmin v9.8h, v9.8h, v7.8h + fmin v10.8h, v10.8h, v7.8h + fmin v11.8h, v11.8h, v7.8h + fmin v12.8h, v12.8h, v7.8h + fmin v13.8h, v13.8h, v7.8h + fmin v14.8h, v14.8h, v7.8h + fmin v15.8h, v15.8h, v7.8h + + Relu4x16: + dup v6.8h, wzr + fmax v8.8h, v8.8h, v6.8h + fmax v9.8h, v9.8h, v6.8h + fmax v10.8h, v10.8h, v6.8h + fmax v11.8h, v11.8h, v6.8h + fmax v12.8h, v12.8h, v6.8h + fmax v13.8h, v13.8h, v6.8h + fmax v14.8h, v14.8h, v6.8h + fmax v15.8h, v15.8h, v6.8h + Write4x16: + mov x22, x21 + st1 {v8.8h, v9.8h}, [x22], x8 + st1 {v10.8h, v11.8h}, [x22], x8 + st1 {v12.8h, v13.8h}, [x22], x8 + st1 {v14.8h, v15.8h}, [x22] + add x21, x21, #32 + subs x13, x13, #16 + bge LoopCol4x16 + + LoopCol4x8: + adds x13, x13, #16 + cbz x13, LoopRow4End + subs x13, x13, #8 + blt LoopCol4x4 + mov x10, x0 // update matrixA + ld1 {v0.4h}, [x10], #8 + mov x14, x5 // reload depth + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + cbnz x12, InitFromBias4x8 + dup v8.2d, xzr + dup v10.2d, xzr + dup v12.2d, xzr + dup v14.2d, xzr + b Compute4x8Enter + InitFromBias4x8: + ld1 {v8.8h}, [x12] + ld1 {v10.8h}, [x12] + ld1 {v12.8h}, [x12] + ld1 {v14.8h}, [x12] + add x12, x12, #16 + Compute4x8Enter: + bl Compute4x8Unit + Activation4x8: + cmp x4, #3 + beq Relu64x8 + cmp x4, #1 + beq Relu4x8 + b Write4x8 + + Relu64x8: + fmin v8.8h, v8.8h, v7.8h + fmin v10.8h, v10.8h, v7.8h + fmin v12.8h, v12.8h, v7.8h + fmin v14.8h, v14.8h, v7.8h + + Relu4x8: + dup v6.8h, wzr + fmax v8.8h, v8.8h, v6.8h + fmax v10.8h, v10.8h, v6.8h + fmax v12.8h, v12.8h, v6.8h + fmax v14.8h, v14.8h, v6.8h + Write4x8: + mov x22, x21 + st1 {v8.8h}, [x22], x8 + st1 {v10.8h}, [x22], x8 + st1 {v12.8h}, [x22], x8 + st1 {v14.8h}, [x22] + add x21, x21, #16 + subs x13, x13, #8 + + LoopCol4x4: + adds x13, x13, #8 + cbz x13, LoopRow4End + LoopCol4x4Core: + mov x10, x0 // update matrixA + ld1 {v0.4h}, [x10], #8 + mov x14, x5 // reload depth + prfm pldl1strm, [x11, #632] + ld1 {v3.4h}, [x11], #8 + cbnz x12, InitFromBias4x4 + dup v8.2s, wzr + dup v10.2s, wzr + dup v12.2s, wzr + dup v14.2s, wzr + b Compute4x4Enter + InitFromBias4x4: + ld1 {v8.4h}, [x12] + ld1 {v10.4h}, [x12] + ld1 {v12.4h}, [x12] + ld1 {v14.4h}, [x12] + add x12, x12, #8 + Compute4x4Enter: + bl Compute4x4Unit + Activation4x4: + cmp x4, #3 + beq Relu64x4 + cmp x4, #1 + beq Relu4x4 + b Write4x4 + + Relu64x4: + fmin v8.4h, v8.4h, v7.4h + fmin v10.4h, v10.4h, v7.4h + fmin v12.4h, v12.4h, v7.4h + fmin v14.4h, v14.4h, v7.4h + + Relu4x4: + dup v6.4h, wzr + fmax v8.4h, v8.4h, v6.4h + fmax v10.4h, v10.4h, v6.4h + fmax v12.4h, v12.4h, v6.4h + fmax v14.4h, v14.4h, v6.4h + Write4x4: + mov x22, x21 + cmp x13, #1 + beq Write4x1 + cmp x13, #2 + beq Write4x2 + cmp x13, #3 + beq Write4x3 + st1 {v8.4h}, [x22], x8 + st1 {v10.4h}, [x22], x8 + st1 {v12.4h}, [x22], x8 + st1 {v14.4h}, [x22] + add x21, x21, #8 + subs x13, x13, #4 + bgt LoopCol4x4Core + b LoopRow4End + Write4x1: + st1 {v8.h}[0], [x22], x8 + st1 {v10.h}[0], [x22], x8 + st1 {v12.h}[0], [x22], x8 + st1 {v14.h}[0], [x22] + b LoopRow4End + Write4x2: + st1 {v8.s}[0], [x22], x8 + st1 {v10.s}[0], [x22], x8 + st1 {v12.s}[0], [x22], x8 + st1 {v14.s}[0], [x22] + b LoopRow4End + Write4x3: + add x23, x22, #4 + st1 {v8.s}[0], [x22], x8 + st1 {v8.h}[2], [x23], x8 + st1 {v10.s}[0], [x22], x8 + st1 {v10.h}[2], [x23], x8 + st1 {v12.s}[0], [x22], x8 + st1 {v12.h}[2], [x23], x8 + st1 {v14.s}[0], [x22], x8 + st1 {v14.h}[2], [x23], x8 + LoopRow4End: + add x0, x0, x16, lsl #2 + add x2, x2, x8, lsl #2 + subs x6, x6, #4 + +LoopRowTail: + adds x6, x6, #4 + cbz x6, End + cmp x6, #1 + beq LoopRow1 + cmp x6, #2 + beq LoopRow2 + // LoopRow3 + mov x11, x1 // reload matrixB + mov x12, x3 // reload bias + mov x13, x7 // reload col + mov x21, x2 // relocate output + subs x13, x13, #16 + blt LoopCol3x8 + LoopCol3x16: + mov x10, x0 // update matrixA + mov x14, x5 // reload depth + cbnz x12, InitFromBias3x16 + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + dup v12.2d, xzr + dup v13.2d, xzr + b Compute3x16Enter + InitFromBias3x16: + ld1 {v8.8h, v9.8h}, [x12] + ld1 {v10.8h, v11.8h}, [x12] + ld1 {v12.8h, v13.8h}, [x12] + add x12, x12, #32 + Compute3x16Enter: + bl Compute3x16Unit + Activation3x16: + cmp x4, #3 + beq Relu63x16 + cmp x4, #1 + beq Relu3x16 + b Write3x16 + + Relu63x16: + fmin v8.8h, v8.8h, v7.8h + fmin v9.8h, v9.8h, v7.8h + fmin v10.8h, v10.8h, v7.8h + fmin v11.8h, v11.8h, v7.8h + fmin v12.8h, v12.8h, v7.8h + fmin v13.8h, v13.8h, v7.8h + + Relu3x16: + dup v6.8h, wzr + fmax v8.8h, v8.8h, v6.8h + fmax v9.8h, v9.8h, v6.8h + fmax v10.8h, v10.8h, v6.8h + fmax v11.8h, v11.8h, v6.8h + fmax v12.8h, v12.8h, v6.8h + fmax v13.8h, v13.8h, v6.8h + Write3x16: + mov x22, x21 + st1 {v8.8h, v9.8h}, [x22], x8 + st1 {v10.8h, v11.8h}, [x22], x8 + st1 {v12.8h, v13.8h}, [x22] + add x21, x21, #32 + subs x13, x13, #16 + bge LoopCol3x16 + + LoopCol3x8: + adds x13, x13, #16 + cbz x13, End + subs x13, x13, #8 + blt LoopCol3x4 + mov x10, x0 // update matrixA + mov x14, x5 // reload depth + cbnz x12, InitFromBias3x8 + dup v8.2d, xzr + dup v10.2d, xzr + dup v12.2d, xzr + b Compute3x8Enter + InitFromBias3x8: + ld1 {v8.8h}, [x12] + ld1 {v10.8h}, [x12] + ld1 {v12.8h}, [x12] + add x12, x12, #16 + Compute3x8Enter: + bl Compute3x8Unit + Activation3x8: + cmp x4, #3 + beq Relu63x8 + cmp x4, #1 + beq Relu3x8 + b Write3x8 + + Relu63x8: + fmin v8.8h, v8.8h, v7.8h + fmin v10.8h, v10.8h, v7.8h + fmin v12.8h, v12.8h, v7.8h + + Relu3x8: + dup v6.8h, wzr + fmax v8.8h, v8.8h, v6.8h + fmax v10.8h, v10.8h, v6.8h + fmax v12.8h, v12.8h, v6.8h + Write3x8: + mov x22, x21 + st1 {v8.8h}, [x22], x8 + st1 {v10.8h}, [x22], x8 + st1 {v12.8h}, [x22] + add x21, x21, #16 + subs x13, x13, #8 + + LoopCol3x4: + adds x13, x13, #8 + cbz x13, End + LoopCol3x4Core: + mov x10, x0 // update matrixA + mov x14, x5 // reload depth + cbnz x12, InitFromBias3x4 + dup v8.2s, wzr + dup v10.2s, wzr + dup v12.2s, wzr + b Compute3x4Enter + InitFromBias3x4: + ld1 {v8.4h}, [x12] + ld1 {v10.4h}, [x12] + ld1 {v12.4h}, [x12] + add x12, x12, #8 + Compute3x4Enter: + bl Compute3x4Unit + Activation3x4: + cmp x4, #3 + beq Relu63x4 + cmp x4, #1 + beq Relu3x4 + b Write3x4 + + Relu63x4: + fmin v8.4h, v8.4h, v7.4h + fmin v10.4h, v10.4h, v7.4h + fmin v12.4h, v12.4h, v7.4h + + Relu3x4: + dup v6.4h, wzr + fmax v8.4h, v8.4h, v6.4h + fmax v10.4h, v10.4h, v6.4h + fmax v12.4h, v12.4h, v6.4h + Write3x4: + mov x22, x21 + cmp x13, #1 + beq Write3x1 + cmp x13, #2 + beq Write3x2 + cmp x13, #3 + beq Write3x3 + st1 {v8.4h}, [x22], x8 + st1 {v10.4h}, [x22], x8 + st1 {v12.4h}, [x22] + add x21, x21, #8 + subs x13, x13, #4 + bgt LoopCol3x4Core + b End + Write3x1: + st1 {v8.h}[0], [x22], x8 + st1 {v10.h}[0], [x22], x8 + st1 {v12.h}[0], [x22] + b End + Write3x2: + st1 {v8.s}[0], [x22], x8 + st1 {v10.s}[0], [x22], x8 + st1 {v12.s}[0], [x22] + b End + Write3x3: + add x23, x22, #4 + st1 {v8.s}[0], [x22], x8 + st1 {v8.h}[2], [x23], x8 + st1 {v10.s}[0], [x22], x8 + st1 {v10.h}[2], [x23], x8 + st1 {v12.s}[0], [x22], x8 + st1 {v12.h}[2], [x23], x8 + b End + +LoopRow2: + mov x11, x1 // reload matrixB + mov x12, x3 // reload bias + mov x13, x7 // reload col + mov x21, x2 // relocate output + subs x13, x13, #16 + blt LoopCol2x8 + LoopCol2x16: + mov x10, x0 // update matrixA + mov x14, x5 // reload depth + cbnz x12, InitFromBias2x16 + dup v8.2d, xzr + dup v9.2d, xzr + dup v10.2d, xzr + dup v11.2d, xzr + b Compute2x16Enter + InitFromBias2x16: + ld1 {v8.8h, v9.8h}, [x12] + ld1 {v10.8h, v11.8h}, [x12] + add x12, x12, #32 + Compute2x16Enter: + bl Compute2x16Unit + Activation2x16: + cmp x4, #3 + beq Relu62x16 + cmp x4, #1 + beq Relu2x16 + b Write2x16 + + Relu62x16: + fmin v8.8h, v8.8h, v7.8h + fmin v9.8h, v9.8h, v7.8h + fmin v10.8h, v10.8h, v7.8h + fmin v11.8h, v11.8h, v7.8h + + Relu2x16: + dup v6.8h, wzr + fmax v8.8h, v8.8h, v6.8h + fmax v9.8h, v9.8h, v6.8h + fmax v10.8h, v10.8h, v6.8h + fmax v11.8h, v11.8h, v6.8h + Write2x16: + mov x22, x21 + st1 {v8.8h, v9.8h}, [x22], x8 + st1 {v10.8h, v11.8h}, [x22] + add x21, x21, #32 + subs x13, x13, #16 + bge LoopCol2x16 + + LoopCol2x8: + adds x13, x13, #16 + cbz x13, End + subs x13, x13, #8 + blt LoopCol2x4 + mov x10, x0 // update matrixA + mov x14, x5 // reload depth + cbnz x12, InitFromBias2x8 + dup v8.2d, xzr + dup v10.2d, xzr + b Compute2x8Enter + InitFromBias2x8: + ld1 {v8.8h}, [x12] + ld1 {v10.8h}, [x12] + add x12, x12, #16 + Compute2x8Enter: + bl Compute2x8Unit + Activation2x8: + cmp x4, #3 + beq Relu62x8 + cmp x4, #1 + beq Relu2x8 + b Write2x8 + + Relu62x8: + fmin v8.8h, v8.8h, v7.8h + fmin v10.8h, v10.8h, v7.8h + + Relu2x8: + dup v6.8h, wzr + fmax v8.8h, v8.8h, v6.8h + fmax v10.8h, v10.8h, v6.8h + Write2x8: + mov x22, x21 + st1 {v8.8h}, [x22], x8 + st1 {v10.8h}, [x22] + add x21, x21, #16 + subs x13, x13, #8 + + LoopCol2x4: + adds x13, x13, #8 + cbz x13, End + LoopCol2x4Core: + mov x10, x0 // update matrixA + mov x14, x5 // reload depth + cbnz x12, InitFromBias2x4 + dup v8.2s, wzr + dup v10.2s, wzr + b Compute2x4Enter + InitFromBias2x4: + ld1 {v8.4h}, [x12] + ld1 {v10.4h}, [x12] + add x12, x12, #8 + Compute2x4Enter: + bl Compute2x4Unit + Activation2x4: + cmp x4, #3 + beq Relu62x4 + cmp x4, #1 + beq Relu2x4 + b Write2x4 + + Relu62x4: + fmin v8.4h, v8.4h, v7.4h + fmin v10.4h, v10.4h, v7.4h + Relu2x4: + dup v6.4h, wzr + fmax v8.4h, v8.4h, v6.4h + fmax v10.4h, v10.4h, v6.4h + Write2x4: + mov x22, x21 + cmp x13, #1 + beq Write2x1 + cmp x13, #2 + beq Write2x2 + cmp x13, #3 + beq Write2x3 + st1 {v8.4h}, [x22], x8 + st1 {v10.4h}, [x22] + add x21, x21, #8 + subs x13, x13, #4 + bgt LoopCol2x4Core + b End + Write2x1: + st1 {v8.h}[0], [x22], x8 + st1 {v10.h}[0], [x22] + b End + Write2x2: + st1 {v8.s}[0], [x22], x8 + st1 {v10.s}[0], [x22] + b End + Write2x3: + add x23, x22, #4 + st1 {v8.s}[0], [x22], x8 + st1 {v8.h}[2], [x23], x8 + st1 {v10.s}[0], [x22], x8 + st1 {v10.h}[2], [x23], x8 + b End + +LoopRow1: + mov x11, x1 // reload matrixB + mov x12, x3 // reload bias + mov x13, x7 // reload col + mov x21, x2 // relocate output + subs x13, x13, #16 + blt LoopCol1x8 + LoopCol1x16: + mov x10, x0 // update matrixA + mov x14, x5 // reload depth + cbnz x12, InitFromBias1x16 + dup v8.2d, xzr + dup v9.2d, xzr + b Compute1x16Enter + InitFromBias1x16: + ld1 {v8.8h, v9.8h}, [x12], #32 + Compute1x16Enter: + bl Compute1x16Unit + Activation1x16: + cmp x4, #3 + beq Relu61x16 + cmp x4, #1 + beq Relu1x16 + b Write1x16 + + Relu61x16: + fmin v8.8h, v8.8h, v7.8h + fmin v9.8h, v9.8h, v7.8h + + Relu1x16: + dup v6.8h, wzr + fmax v8.8h, v8.8h, v6.8h + fmax v9.8h, v9.8h, v6.8h + Write1x16: + st1 {v8.8h, v9.8h}, [x21], #32 + subs x13, x13, #16 + bge LoopCol1x16 + + LoopCol1x8: + adds x13, x13, #16 + cbz x13, End + subs x13, x13, #8 + blt LoopCol1x4 + mov x10, x0 // update matrixA + mov x14, x5 // reload depth + cbnz x12, InitFromBias1x8 + dup v8.2d, xzr + b Compute1x8Enter + InitFromBias1x8: + ld1 {v8.8h}, [x12], #16 + Compute1x8Enter: + bl Compute1x8Unit + Activation1x8: + cmp x4, #3 + beq Relu61x8 + cmp x4, #1 + beq Relu1x8 + b Write1x8 + + Relu61x8: + fmin v8.8h, v8.8h, v7.8h + + Relu1x8: + dup v6.8h, wzr + fmax v8.8h, v8.8h, v6.8h + Write1x8: + st1 {v8.8h}, [x21], #16 + subs x13, x13, #8 + + LoopCol1x4: + adds x13, x13, #8 + cbz x13, End + LoopCol1x4Core: + mov x10, x0 // update matrixA + mov x14, x5 // reload depth + cbnz x12, InitFromBias1x4 + dup v8.2s, wzr + b Compute1x4Enter + InitFromBias1x4: + ld1 {v8.4h}, [x12], #8 + Compute1x4Enter: + bl Compute1x4Unit + Activation1x4: + cmp x4, #3 + beq Relu61x4 + cmp x4, #1 + beq Relu1x4 + b Write1x4 + + Relu61x4: + fmin v8.4h, v8.4h, v7.4h + Relu1x4: + dup v6.4h, wzr + fmax v8.4h, v8.4h, v6.4h + Write1x4: + cmp x13, #1 + beq Write1x1 + cmp x13, #2 + beq Write1x2 + cmp x13, #3 + beq Write1x3 + st1 {v8.4h}, [x21], #8 + subs x13, x13, #4 + bgt LoopCol1x4Core + b End + Write1x1: + st1 {v8.h}[0], [x21] + b End + Write1x2: + st1 {v8.s}[0], [x21] + b End + Write1x3: + add x22, x21, #4 + st1 {v8.s}[0], [x21] + st1 {v8.h}[2], [x22] + b End + +Compute12x16Unit: + subs x14, x14, #2 + ble Compute12x16End + Compute12x16: + prfm pldl1keep, [x10, #632] + ld1 {v1.8h, v2.8h}, [x10], #32 + ld1 {v4.8h, v5.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v16.8h, v3.8h, v0.h[4] + fmla v18.8h, v3.8h, v0.h[5] + fmla v20.8h, v3.8h, v0.h[6] + fmla v22.8h, v3.8h, v0.h[7] + fmla v24.8h, v3.8h, v1.h[0] + fmla v26.8h, v3.8h, v1.h[1] + fmla v28.8h, v3.8h, v1.h[2] + fmla v30.8h, v3.8h, v1.h[3] + prfm pldl1strm, [x11, #632] + ld1 {v6.8h}, [x11], #16 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v0.h[1] + fmla v13.8h, v4.8h, v0.h[2] + fmla v15.8h, v4.8h, v0.h[3] + fmla v17.8h, v4.8h, v0.h[4] + fmla v19.8h, v4.8h, v0.h[5] + fmla v21.8h, v4.8h, v0.h[6] + fmla v23.8h, v4.8h, v0.h[7] + fmla v25.8h, v4.8h, v1.h[0] + fmla v27.8h, v4.8h, v1.h[1] + fmla v29.8h, v4.8h, v1.h[2] + fmla v31.8h, v4.8h, v1.h[3] + + fmla v8.8h, v5.8h, v1.h[4] + fmla v10.8h, v5.8h, v1.h[5] + fmla v12.8h, v5.8h, v1.h[6] + fmla v14.8h, v5.8h, v1.h[7] + fmla v16.8h, v5.8h, v2.h[0] + fmla v18.8h, v5.8h, v2.h[1] + fmla v20.8h, v5.8h, v2.h[2] + fmla v22.8h, v5.8h, v2.h[3] + fmla v24.8h, v5.8h, v2.h[4] + fmla v26.8h, v5.8h, v2.h[5] + fmla v28.8h, v5.8h, v2.h[6] + fmla v30.8h, v5.8h, v2.h[7] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + fmla v9.8h, v6.8h, v1.h[4] + fmla v11.8h, v6.8h, v1.h[5] + fmla v13.8h, v6.8h, v1.h[6] + fmla v15.8h, v6.8h, v1.h[7] + prfm pldl1keep, [x10, #632] + ld1 {v0.8h}, [x10], #16 + fmla v17.8h, v6.8h, v2.h[0] + fmla v19.8h, v6.8h, v2.h[1] + fmla v21.8h, v6.8h, v2.h[2] + fmla v23.8h, v6.8h, v2.h[3] + fmla v25.8h, v6.8h, v2.h[4] + fmla v27.8h, v6.8h, v2.h[5] + fmla v29.8h, v6.8h, v2.h[6] + fmla v31.8h, v6.8h, v2.h[7] + + subs x14, x14, #2 + bgt Compute12x16 + Compute12x16End: + cbnz x14, Compute12x16End1 + prfm pldl1keep, [x10, #632] + ld1 {v1.4h}, [x10], #8 + ld1 {v4.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v16.8h, v3.8h, v0.h[4] + fmla v18.8h, v3.8h, v0.h[5] + fmla v20.8h, v3.8h, v0.h[6] + fmla v22.8h, v3.8h, v0.h[7] + fmla v24.8h, v3.8h, v1.h[0] + fmla v26.8h, v3.8h, v1.h[1] + fmla v28.8h, v3.8h, v1.h[2] + fmla v30.8h, v3.8h, v1.h[3] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v0.h[1] + fmla v13.8h, v4.8h, v0.h[2] + fmla v15.8h, v4.8h, v0.h[3] + ld1 {v2.8h}, [x10], #16 + fmla v17.8h, v4.8h, v0.h[4] + fmla v19.8h, v4.8h, v0.h[5] + fmla v21.8h, v4.8h, v0.h[6] + fmla v23.8h, v4.8h, v0.h[7] + fmla v25.8h, v4.8h, v1.h[0] + fmla v27.8h, v4.8h, v1.h[1] + fmla v29.8h, v4.8h, v1.h[2] + fmla v31.8h, v4.8h, v1.h[3] + mov v0.16b, v2.16b + Compute12x16End1: + ld1 {v1.4h}, [x10] + ld1 {v4.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v16.8h, v3.8h, v0.h[4] + fmla v18.8h, v3.8h, v0.h[5] + fmla v20.8h, v3.8h, v0.h[6] + fmla v22.8h, v3.8h, v0.h[7] + fmla v24.8h, v3.8h, v1.h[0] + fmla v26.8h, v3.8h, v1.h[1] + fmla v28.8h, v3.8h, v1.h[2] + fmla v30.8h, v3.8h, v1.h[3] + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v0.h[1] + fmla v13.8h, v4.8h, v0.h[2] + fmla v15.8h, v4.8h, v0.h[3] + fmla v17.8h, v4.8h, v0.h[4] + fmla v19.8h, v4.8h, v0.h[5] + fmla v21.8h, v4.8h, v0.h[6] + fmla v23.8h, v4.8h, v0.h[7] + fmla v25.8h, v4.8h, v1.h[0] + fmla v27.8h, v4.8h, v1.h[1] + fmla v29.8h, v4.8h, v1.h[2] + fmla v31.8h, v4.8h, v1.h[3] + ret + +Compute12x8Unit: + subs x14, x14, #2 + ble Compute12x8End + Compute12x8: + prfm pldl1keep, [x10, #632] + ld1 {v1.8h, v2.8h}, [x10], #32 + ld1 {v4.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v16.8h, v3.8h, v0.h[4] + fmla v18.8h, v3.8h, v0.h[5] + fmla v20.8h, v3.8h, v0.h[6] + fmla v22.8h, v3.8h, v0.h[7] + fmla v24.8h, v3.8h, v1.h[0] + fmla v26.8h, v3.8h, v1.h[1] + fmla v28.8h, v3.8h, v1.h[2] + fmla v30.8h, v3.8h, v1.h[3] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + fmla v8.8h, v4.8h, v1.h[4] + fmla v10.8h, v4.8h, v1.h[5] + fmla v12.8h, v4.8h, v1.h[6] + fmla v14.8h, v4.8h, v1.h[7] + ld1 {v0.8h}, [x10], #16 + fmla v16.8h, v4.8h, v2.h[0] + fmla v18.8h, v4.8h, v2.h[1] + fmla v20.8h, v4.8h, v2.h[2] + fmla v22.8h, v4.8h, v2.h[3] + fmla v24.8h, v4.8h, v2.h[4] + fmla v26.8h, v4.8h, v2.h[5] + fmla v28.8h, v4.8h, v2.h[6] + fmla v30.8h, v4.8h, v2.h[7] + + subs x14, x14, #2 + bgt Compute12x8 + Compute12x8End: + cbnz x14, Compute12x8End1 + prfm pldl1keep, [x10, #632] + ld1 {v1.4h}, [x10], #8 + ld1 {v4.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v16.8h, v3.8h, v0.h[4] + fmla v18.8h, v3.8h, v0.h[5] + fmla v20.8h, v3.8h, v0.h[6] + fmla v22.8h, v3.8h, v0.h[7] + fmla v24.8h, v3.8h, v1.h[0] + fmla v26.8h, v3.8h, v1.h[1] + fmla v28.8h, v3.8h, v1.h[2] + fmla v30.8h, v3.8h, v1.h[3] + ld1 {v0.8h}, [x10], #16 + mov v3.16b, v4.16b + Compute12x8End1: + ld1 {v1.4h}, [x10] + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v16.8h, v3.8h, v0.h[4] + fmla v18.8h, v3.8h, v0.h[5] + fmla v20.8h, v3.8h, v0.h[6] + fmla v22.8h, v3.8h, v0.h[7] + fmla v24.8h, v3.8h, v1.h[0] + fmla v26.8h, v3.8h, v1.h[1] + fmla v28.8h, v3.8h, v1.h[2] + fmla v30.8h, v3.8h, v1.h[3] + ret + +Compute12x4Unit: + subs x14, x14, #2 + ble Compute12x4End + Compute12x4: + prfm pldl1keep, [x10, #632] + ld1 {v1.8h, v2.8h}, [x10], #32 + ld1 {v4.4h}, [x11], #8 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v0.h[1] + fmla v12.4h, v3.4h, v0.h[2] + fmla v14.4h, v3.4h, v0.h[3] + fmla v16.4h, v3.4h, v0.h[4] + fmla v18.4h, v3.4h, v0.h[5] + fmla v20.4h, v3.4h, v0.h[6] + fmla v22.4h, v3.4h, v0.h[7] + fmla v24.4h, v3.4h, v1.h[0] + fmla v26.4h, v3.4h, v1.h[1] + fmla v28.4h, v3.4h, v1.h[2] + fmla v30.4h, v3.4h, v1.h[3] + prfm pldl1strm, [x11, #632] + ld1 {v3.4h}, [x11], #8 + fmla v8.4h, v4.4h, v1.h[4] + fmla v10.4h, v4.4h, v1.h[5] + fmla v12.4h, v4.4h, v1.h[6] + fmla v14.4h, v4.4h, v1.h[7] + ld1 {v0.8h}, [x10], #16 + fmla v16.4h, v4.4h, v2.h[0] + fmla v18.4h, v4.4h, v2.h[1] + fmla v20.4h, v4.4h, v2.h[2] + fmla v22.4h, v4.4h, v2.h[3] + fmla v24.4h, v4.4h, v2.h[4] + fmla v26.4h, v4.4h, v2.h[5] + fmla v28.4h, v4.4h, v2.h[6] + fmla v30.4h, v4.4h, v2.h[7] + + subs x14, x14, #2 + bgt Compute12x4 + Compute12x4End: + cbnz x14, Compute12x4End1 + prfm pldl1keep, [x10, #632] + ld1 {v1.4h}, [x10], #8 + ld1 {v4.4h}, [x11], #8 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v0.h[1] + fmla v12.4h, v3.4h, v0.h[2] + fmla v14.4h, v3.4h, v0.h[3] + fmla v16.4h, v3.4h, v0.h[4] + fmla v18.4h, v3.4h, v0.h[5] + fmla v20.4h, v3.4h, v0.h[6] + fmla v22.4h, v3.4h, v0.h[7] + fmla v24.4h, v3.4h, v1.h[0] + fmla v26.4h, v3.4h, v1.h[1] + fmla v28.4h, v3.4h, v1.h[2] + fmla v30.4h, v3.4h, v1.h[3] + ld1 {v0.8h}, [x10], #16 + mov v3.8b, v4.8b + Compute12x4End1: + ld1 {v1.4h}, [x10] + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v0.h[1] + fmla v12.4h, v3.4h, v0.h[2] + fmla v14.4h, v3.4h, v0.h[3] + fmla v16.4h, v3.4h, v0.h[4] + fmla v18.4h, v3.4h, v0.h[5] + fmla v20.4h, v3.4h, v0.h[6] + fmla v22.4h, v3.4h, v0.h[7] + fmla v24.4h, v3.4h, v1.h[0] + fmla v26.4h, v3.4h, v1.h[1] + fmla v28.4h, v3.4h, v1.h[2] + fmla v30.4h, v3.4h, v1.h[3] + ret + +Compute8x16Unit: + subs x14, x14, #2 + ble Compute8x16End + Compute8x16: + prfm pldl1keep, [x10, #632] + ld1 {v1.8h}, [x10], #16 + ld1 {v4.8h, v5.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v16.8h, v3.8h, v0.h[4] + fmla v18.8h, v3.8h, v0.h[5] + fmla v20.8h, v3.8h, v0.h[6] + fmla v22.8h, v3.8h, v0.h[7] + prfm pldl1strm, [x11, #632] + ld1 {v6.8h}, [x11], #16 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v0.h[1] + fmla v13.8h, v4.8h, v0.h[2] + fmla v15.8h, v4.8h, v0.h[3] + fmla v17.8h, v4.8h, v0.h[4] + fmla v19.8h, v4.8h, v0.h[5] + fmla v21.8h, v4.8h, v0.h[6] + fmla v23.8h, v4.8h, v0.h[7] + + fmla v8.8h, v5.8h, v1.h[0] + fmla v10.8h, v5.8h, v1.h[1] + fmla v12.8h, v5.8h, v1.h[2] + fmla v14.8h, v5.8h, v1.h[3] + fmla v16.8h, v5.8h, v1.h[4] + fmla v18.8h, v5.8h, v1.h[5] + fmla v20.8h, v5.8h, v1.h[6] + fmla v22.8h, v5.8h, v1.h[7] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + fmla v9.8h, v6.8h, v1.h[0] + fmla v11.8h, v6.8h, v1.h[1] + fmla v13.8h, v6.8h, v1.h[2] + fmla v15.8h, v6.8h, v1.h[3] + prfm pldl1keep, [x10, #632] + ld1 {v0.8h}, [x10], #16 + fmla v17.8h, v6.8h, v1.h[4] + fmla v19.8h, v6.8h, v1.h[5] + fmla v21.8h, v6.8h, v1.h[6] + fmla v23.8h, v6.8h, v1.h[7] + + subs x14, x14, #2 + bgt Compute8x16 + Compute8x16End: + cbnz x14, Compute8x16End1 + prfm pldl1keep, [x10, #632] + ld1 {v1.8h}, [x10] + ld1 {v4.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v16.8h, v3.8h, v0.h[4] + fmla v18.8h, v3.8h, v0.h[5] + fmla v20.8h, v3.8h, v0.h[6] + fmla v22.8h, v3.8h, v0.h[7] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v0.h[1] + fmla v13.8h, v4.8h, v0.h[2] + fmla v15.8h, v4.8h, v0.h[3] + fmla v17.8h, v4.8h, v0.h[4] + fmla v19.8h, v4.8h, v0.h[5] + fmla v21.8h, v4.8h, v0.h[6] + fmla v23.8h, v4.8h, v0.h[7] + mov v0.16b, v1.16b + Compute8x16End1: + ld1 {v4.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v16.8h, v3.8h, v0.h[4] + fmla v18.8h, v3.8h, v0.h[5] + fmla v20.8h, v3.8h, v0.h[6] + fmla v22.8h, v3.8h, v0.h[7] + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v0.h[1] + fmla v13.8h, v4.8h, v0.h[2] + fmla v15.8h, v4.8h, v0.h[3] + fmla v17.8h, v4.8h, v0.h[4] + fmla v19.8h, v4.8h, v0.h[5] + fmla v21.8h, v4.8h, v0.h[6] + fmla v23.8h, v4.8h, v0.h[7] + ret + +Compute8x8Unit: + subs x14, x14, #2 + ble Compute8x8End + Compute8x8: + prfm pldl1keep, [x10, #632] + ld1 {v1.8h}, [x10], #16 + ld1 {v4.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v16.8h, v3.8h, v0.h[4] + fmla v18.8h, v3.8h, v0.h[5] + fmla v20.8h, v3.8h, v0.h[6] + fmla v22.8h, v3.8h, v0.h[7] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + fmla v8.8h, v4.8h, v1.h[0] + fmla v10.8h, v4.8h, v1.h[1] + fmla v12.8h, v4.8h, v1.h[2] + fmla v14.8h, v4.8h, v1.h[3] + ld1 {v0.8h}, [x10], #16 + fmla v16.8h, v4.8h, v1.h[4] + fmla v18.8h, v4.8h, v1.h[5] + fmla v20.8h, v4.8h, v1.h[6] + fmla v22.8h, v4.8h, v1.h[7] + + subs x14, x14, #2 + bgt Compute8x8 + Compute8x8End: + cbnz x14, Compute8x8End1 + prfm pldl1keep, [x10, #632] + ld1 {v1.8h}, [x10] + ld1 {v4.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v16.8h, v3.8h, v0.h[4] + fmla v18.8h, v3.8h, v0.h[5] + fmla v20.8h, v3.8h, v0.h[6] + fmla v22.8h, v3.8h, v0.h[7] + mov v0.16b, v1.16b + mov v3.16b, v4.16b + Compute8x8End1: + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v16.8h, v3.8h, v0.h[4] + fmla v18.8h, v3.8h, v0.h[5] + fmla v20.8h, v3.8h, v0.h[6] + fmla v22.8h, v3.8h, v0.h[7] + ret + +Compute8x4Unit: + subs x14, x14, #2 + ble Compute8x4End + Compute8x4: + prfm pldl1keep, [x10, #632] + ld1 {v1.8h}, [x10], #16 + ld1 {v4.4h}, [x11], #8 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v0.h[1] + fmla v12.4h, v3.4h, v0.h[2] + fmla v14.4h, v3.4h, v0.h[3] + fmla v16.4h, v3.4h, v0.h[4] + fmla v18.4h, v3.4h, v0.h[5] + fmla v20.4h, v3.4h, v0.h[6] + fmla v22.4h, v3.4h, v0.h[7] + prfm pldl1strm, [x11, #632] + ld1 {v3.4h}, [x11], #8 + fmla v8.4h, v4.4h, v1.h[0] + fmla v10.4h, v4.4h, v1.h[1] + fmla v12.4h, v4.4h, v1.h[2] + fmla v14.4h, v4.4h, v1.h[3] + ld1 {v0.8h}, [x10], #16 + fmla v16.4h, v4.4h, v1.h[4] + fmla v18.4h, v4.4h, v1.h[5] + fmla v20.4h, v4.4h, v1.h[6] + fmla v22.4h, v4.4h, v1.h[7] + + subs x14, x14, #2 + bgt Compute8x4 + Compute8x4End: + cbnz x14, Compute8x4End1 + prfm pldl1keep, [x10, #632] + ld1 {v1.8h}, [x10] + ld1 {v4.4h}, [x11], #8 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v0.h[1] + fmla v12.4h, v3.4h, v0.h[2] + fmla v14.4h, v3.4h, v0.h[3] + fmla v16.4h, v3.4h, v0.h[4] + fmla v18.4h, v3.4h, v0.h[5] + fmla v20.4h, v3.4h, v0.h[6] + fmla v22.4h, v3.4h, v0.h[7] + mov v0.16b, v1.16b + mov v3.8b, v4.8b + Compute8x4End1: + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v0.h[1] + fmla v12.4h, v3.4h, v0.h[2] + fmla v14.4h, v3.4h, v0.h[3] + fmla v16.4h, v3.4h, v0.h[4] + fmla v18.4h, v3.4h, v0.h[5] + fmla v20.4h, v3.4h, v0.h[6] + fmla v22.4h, v3.4h, v0.h[7] + ret + +Compute4x16Unit: + subs x14, x14, #2 + ble Compute4x16End + Compute4x16: + prfm pldl1keep, [x10, #632] + ld1 {v1.4h}, [x10], #8 + ld1 {v4.8h, v5.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + prfm pldl1strm, [x11, #632] + ld1 {v6.8h}, [x11], #16 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v0.h[1] + fmla v13.8h, v4.8h, v0.h[2] + fmla v15.8h, v4.8h, v0.h[3] + + fmla v8.8h, v5.8h, v1.h[0] + fmla v10.8h, v5.8h, v1.h[1] + fmla v12.8h, v5.8h, v1.h[2] + fmla v14.8h, v5.8h, v1.h[3] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + fmla v9.8h, v6.8h, v1.h[0] + fmla v11.8h, v6.8h, v1.h[1] + fmla v13.8h, v6.8h, v1.h[2] + fmla v15.8h, v6.8h, v1.h[3] + ld1 {v0.4h}, [x10], #8 + + subs x14, x14, #2 + bgt Compute4x16 + Compute4x16End: + cbnz x14, Compute4x16End1 + prfm pldl1keep, [x10, #632] + ld1 {v1.4h}, [x10] + ld1 {v4.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v0.h[1] + fmla v13.8h, v4.8h, v0.h[2] + fmla v15.8h, v4.8h, v0.h[3] + mov v0.8b, v1.8b + Compute4x16End1: + ld1 {v4.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v0.h[1] + fmla v13.8h, v4.8h, v0.h[2] + fmla v15.8h, v4.8h, v0.h[3] + ret + +Compute4x8Unit: + subs x14, x14, #2 + ble Compute4x8End + Compute4x8: + prfm pldl1keep, [x10, #632] + ld1 {v1.4h}, [x10], #8 + ld1 {v4.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + fmla v8.8h, v4.8h, v1.h[0] + fmla v10.8h, v4.8h, v1.h[1] + fmla v12.8h, v4.8h, v1.h[2] + fmla v14.8h, v4.8h, v1.h[3] + ld1 {v0.4h}, [x10], #8 + + subs x14, x14, #2 + bgt Compute4x8 + Compute4x8End: + cbnz x14, Compute4x8End1 + prfm pldl1keep, [x10, #632] + ld1 {v1.4h}, [x10] + ld1 {v4.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + mov v0.8b, v1.8b + mov v3.16b, v4.16b + Compute4x8End1: + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v0.h[1] + fmla v12.8h, v3.8h, v0.h[2] + fmla v14.8h, v3.8h, v0.h[3] + ret + +Compute4x4Unit: + subs x14, x14, #2 + ble Compute4x4End + Compute4x4: + prfm pldl1keep, [x10, #632] + ld1 {v1.4h}, [x10], #8 + ld1 {v4.4h}, [x11], #8 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v0.h[1] + fmla v12.4h, v3.4h, v0.h[2] + fmla v14.4h, v3.4h, v0.h[3] + prfm pldl1strm, [x11, #632] + ld1 {v3.4h}, [x11], #8 + fmla v8.4h, v4.4h, v1.h[0] + fmla v10.4h, v4.4h, v1.h[1] + fmla v12.4h, v4.4h, v1.h[2] + fmla v14.4h, v4.4h, v1.h[3] + ld1 {v0.4h}, [x10], #8 + + subs x14, x14, #2 + bgt Compute4x4 + Compute4x4End: + cbnz x14, Compute4x4End1 + prfm pldl1keep, [x10, #632] + ld1 {v1.4h}, [x10] + ld1 {v4.4h}, [x11], #8 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v0.h[1] + fmla v12.4h, v3.4h, v0.h[2] + fmla v14.4h, v3.4h, v0.h[3] + mov v0.8b, v1.8b + mov v3.8b, v4.8b + Compute4x4End1: + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v0.h[1] + fmla v12.4h, v3.4h, v0.h[2] + fmla v14.4h, v3.4h, v0.h[3] + ret + +Compute3x16Unit: + add x19, x10, x16 + add x20, x10, x16, lsl #1 + subs x14, x14, #8 + blt Compute3x16End4 + Compute3x16: + ld1 {v0.8h}, [x10], #16 + ld1 {v1.8h}, [x19], #16 + ld1 {v2.8h}, [x20], #16 + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + fmla v12.8h, v3.8h, v2.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v1.h[0] + fmla v13.8h, v4.8h, v2.h[0] + fmla v8.8h, v5.8h, v0.h[1] + fmla v10.8h, v5.8h, v1.h[1] + fmla v12.8h, v5.8h, v2.h[1] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[1] + fmla v11.8h, v6.8h, v1.h[1] + fmla v13.8h, v6.8h, v2.h[1] + fmla v8.8h, v3.8h, v0.h[2] + fmla v10.8h, v3.8h, v1.h[2] + fmla v12.8h, v3.8h, v2.h[2] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[2] + fmla v11.8h, v4.8h, v1.h[2] + fmla v13.8h, v4.8h, v2.h[2] + fmla v8.8h, v5.8h, v0.h[3] + fmla v10.8h, v5.8h, v1.h[3] + fmla v12.8h, v5.8h, v2.h[3] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[3] + fmla v11.8h, v6.8h, v1.h[3] + fmla v13.8h, v6.8h, v2.h[3] + + fmla v8.8h, v3.8h, v0.h[4] + fmla v10.8h, v3.8h, v1.h[4] + fmla v12.8h, v3.8h, v2.h[4] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[4] + fmla v11.8h, v4.8h, v1.h[4] + fmla v13.8h, v4.8h, v2.h[4] + fmla v8.8h, v5.8h, v0.h[5] + fmla v10.8h, v5.8h, v1.h[5] + fmla v12.8h, v5.8h, v2.h[5] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[5] + fmla v11.8h, v6.8h, v1.h[5] + fmla v13.8h, v6.8h, v2.h[5] + fmla v8.8h, v3.8h, v0.h[6] + fmla v10.8h, v3.8h, v1.h[6] + fmla v12.8h, v3.8h, v2.h[6] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[6] + fmla v11.8h, v4.8h, v1.h[6] + fmla v13.8h, v4.8h, v2.h[6] + fmla v8.8h, v5.8h, v0.h[7] + fmla v10.8h, v5.8h, v1.h[7] + fmla v12.8h, v5.8h, v2.h[7] + fmla v9.8h, v6.8h, v0.h[7] + fmla v11.8h, v6.8h, v1.h[7] + fmla v13.8h, v6.8h, v2.h[7] + + subs x14, x14, #8 + bge Compute3x16 + Compute3x16End4: + adds x14, x14, #8 + cbz x14, Compute3x16Return + subs x14, x14, #4 + blt Compute3x16EndTail + ld1 {v0.4h}, [x10], #8 + ld1 {v1.4h}, [x19], #8 + ld1 {v2.4h}, [x20], #8 + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + fmla v12.8h, v3.8h, v2.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v1.h[0] + fmla v13.8h, v4.8h, v2.h[0] + fmla v8.8h, v5.8h, v0.h[1] + fmla v10.8h, v5.8h, v1.h[1] + fmla v12.8h, v5.8h, v2.h[1] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[1] + fmla v11.8h, v6.8h, v1.h[1] + fmla v13.8h, v6.8h, v2.h[1] + fmla v8.8h, v3.8h, v0.h[2] + fmla v10.8h, v3.8h, v1.h[2] + fmla v12.8h, v3.8h, v2.h[2] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[2] + fmla v11.8h, v4.8h, v1.h[2] + fmla v13.8h, v4.8h, v2.h[2] + fmla v8.8h, v5.8h, v0.h[3] + fmla v10.8h, v5.8h, v1.h[3] + fmla v12.8h, v5.8h, v2.h[3] + fmla v9.8h, v6.8h, v0.h[3] + fmla v11.8h, v6.8h, v1.h[3] + fmla v13.8h, v6.8h, v2.h[3] + subs x14, x14, #4 + Compute3x16EndTail: + adds x14, x14, #4 + cbz x14, Compute3x16Return + cmp x14, #1 + beq Compute3x16EndTail1 + cmp x14, #2 + beq Compute3x16EndTail2 + ld1 {v0.4h}, [x10] + ld1 {v1.4h}, [x19] + ld1 {v2.s}[0], [x20], #4 + ld1 {v2.h}[2], [x20] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + fmla v12.8h, v3.8h, v2.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v1.h[0] + fmla v13.8h, v4.8h, v2.h[0] + fmla v8.8h, v5.8h, v0.h[1] + fmla v10.8h, v5.8h, v1.h[1] + fmla v12.8h, v5.8h, v2.h[1] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[1] + fmla v11.8h, v6.8h, v1.h[1] + fmla v13.8h, v6.8h, v2.h[1] + fmla v8.8h, v3.8h, v0.h[2] + fmla v10.8h, v3.8h, v1.h[2] + fmla v12.8h, v3.8h, v2.h[2] + fmla v9.8h, v4.8h, v0.h[2] + fmla v11.8h, v4.8h, v1.h[2] + fmla v13.8h, v4.8h, v2.h[2] + b Compute3x16Return + Compute3x16EndTail2: + ld1 {v0.4h}, [x10] + ld1 {v1.4h}, [x19] + ld1 {v2.s}[0], [x20] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + fmla v12.8h, v3.8h, v2.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v1.h[0] + fmla v13.8h, v4.8h, v2.h[0] + fmla v8.8h, v5.8h, v0.h[1] + fmla v10.8h, v5.8h, v1.h[1] + fmla v12.8h, v5.8h, v2.h[1] + fmla v9.8h, v6.8h, v0.h[1] + fmla v11.8h, v6.8h, v1.h[1] + fmla v13.8h, v6.8h, v2.h[1] + b Compute3x16Return + Compute3x16EndTail1: + ld1 {v0.h}[0], [x10] + ld1 {v1.h}[0], [x19] + ld1 {v2.h}[0], [x20] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + fmla v12.8h, v3.8h, v2.h[0] + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v1.h[0] + fmla v13.8h, v4.8h, v2.h[0] + Compute3x16Return: + ret + +Compute3x8Unit: + add x19, x10, x16 + add x20, x10, x16, lsl #1 + subs x14, x14, #8 + blt Compute3x8End4 + Compute3x8: + ld1 {v0.8h}, [x10], #16 + ld1 {v1.8h}, [x19], #16 + ld1 {v2.8h}, [x20], #16 + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + fmla v12.8h, v3.8h, v2.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v8.8h, v4.8h, v0.h[1] + fmla v10.8h, v4.8h, v1.h[1] + fmla v12.8h, v4.8h, v2.h[1] + fmla v8.8h, v5.8h, v0.h[2] + fmla v10.8h, v5.8h, v1.h[2] + fmla v12.8h, v5.8h, v2.h[2] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v6.8h, v0.h[3] + fmla v10.8h, v6.8h, v1.h[3] + fmla v12.8h, v6.8h, v2.h[3] + fmla v8.8h, v3.8h, v0.h[4] + fmla v10.8h, v3.8h, v1.h[4] + fmla v12.8h, v3.8h, v2.h[4] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v8.8h, v4.8h, v0.h[5] + fmla v10.8h, v4.8h, v1.h[5] + fmla v12.8h, v4.8h, v2.h[5] + fmla v8.8h, v5.8h, v0.h[6] + fmla v10.8h, v5.8h, v1.h[6] + fmla v12.8h, v5.8h, v2.h[6] + fmla v8.8h, v6.8h, v0.h[7] + fmla v10.8h, v6.8h, v1.h[7] + fmla v12.8h, v6.8h, v2.h[7] + + subs x14, x14, #8 + bge Compute3x8 + Compute3x8End4: + adds x14, x14, #8 + cbz x14, Compute3x8Return + subs x14, x14, #4 + blt Compute3x8EndTail + ld1 {v0.4h}, [x10], #8 + ld1 {v1.4h}, [x19], #8 + ld1 {v2.4h}, [x20], #8 + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + fmla v12.8h, v3.8h, v2.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v8.8h, v4.8h, v0.h[1] + fmla v10.8h, v4.8h, v1.h[1] + fmla v12.8h, v4.8h, v2.h[1] + fmla v8.8h, v5.8h, v0.h[2] + fmla v10.8h, v5.8h, v1.h[2] + fmla v12.8h, v5.8h, v2.h[2] + fmla v8.8h, v6.8h, v0.h[3] + fmla v10.8h, v6.8h, v1.h[3] + fmla v12.8h, v6.8h, v2.h[3] + subs x14, x14, #4 + Compute3x8EndTail: + adds x14, x14, #4 + cbz x14, Compute3x8Return + cmp x14, #1 + beq Compute3x8EndTail1 + cmp x14, #2 + beq Compute3x8EndTail2 + ld1 {v0.4h}, [x10] + ld1 {v1.4h}, [x19] + ld1 {v2.s}[0], [x20], #4 + ld1 {v2.h}[2], [x20] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + fmla v12.8h, v3.8h, v2.h[0] + ld1 {v5.8h}, [x11], #16 + fmla v8.8h, v4.8h, v0.h[1] + fmla v10.8h, v4.8h, v1.h[1] + fmla v12.8h, v4.8h, v2.h[1] + fmla v8.8h, v5.8h, v0.h[2] + fmla v10.8h, v5.8h, v1.h[2] + fmla v12.8h, v5.8h, v2.h[2] + b Compute3x8Return + Compute3x8EndTail2: + ld1 {v0.4h}, [x10] + ld1 {v1.4h}, [x19] + ld2 {v2.h, v3.h}[0], [x20] + prfm pldl1strm, [x11, #632] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v8.8h, v5.8h, v0.h[0] + fmla v10.8h, v5.8h, v1.h[0] + fmla v12.8h, v5.8h, v2.h[0] + fmla v8.8h, v6.8h, v0.h[1] + fmla v10.8h, v6.8h, v1.h[1] + fmla v12.8h, v6.8h, v3.h[0] + b Compute3x8Return + Compute3x8EndTail1: + ld1 {v0.h}[0], [x10] + ld1 {v1.h}[0], [x19] + ld1 {v2.h}[0], [x20] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + fmla v12.8h, v3.8h, v2.h[0] + Compute3x8Return: + ret + +Compute3x4Unit: + add x19, x10, x16 + add x20, x10, x16, lsl #1 + subs x14, x14, #8 + blt Compute3x4End4 + Compute3x4: + ld1 {v0.8h}, [x10], #16 + ld1 {v1.8h}, [x19], #16 + ld1 {v2.8h}, [x20], #16 + prfm pldl1strm, [x11, #632] + ld1 {v3.4h, v4.4h}, [x11], #16 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v1.h[0] + fmla v12.4h, v3.4h, v2.h[0] + ld1 {v5.4h, v6.4h}, [x11], #16 + fmla v8.4h, v4.4h, v0.h[1] + fmla v10.4h, v4.4h, v1.h[1] + fmla v12.4h, v4.4h, v2.h[1] + fmla v8.4h, v5.4h, v0.h[2] + fmla v10.4h, v5.4h, v1.h[2] + fmla v12.4h, v5.4h, v2.h[2] + prfm pldl1strm, [x11, #632] + ld1 {v3.4h, v4.4h}, [x11], #16 + fmla v8.4h, v6.4h, v0.h[3] + fmla v10.4h, v6.4h, v1.h[3] + fmla v12.4h, v6.4h, v2.h[3] + fmla v8.4h, v3.4h, v0.h[4] + fmla v10.4h, v3.4h, v1.h[4] + fmla v12.4h, v3.4h, v2.h[4] + ld1 {v5.4h, v6.4h}, [x11], #16 + fmla v8.4h, v4.4h, v0.h[5] + fmla v10.4h, v4.4h, v1.h[5] + fmla v12.4h, v4.4h, v2.h[5] + fmla v8.4h, v5.4h, v0.h[6] + fmla v10.4h, v5.4h, v1.h[6] + fmla v12.4h, v5.4h, v2.h[6] + fmla v8.4h, v6.4h, v0.h[7] + fmla v10.4h, v6.4h, v1.h[7] + fmla v12.4h, v6.4h, v2.h[7] + + subs x14, x14, #8 + bge Compute3x4 + Compute3x4End4: + adds x14, x14, #8 + cbz x14, Compute3x4Return + subs x14, x14, #4 + blt Compute3x4EndTail + ld1 {v0.4h}, [x10], #8 + ld1 {v1.4h}, [x19], #8 + ld1 {v2.4h}, [x20], #8 + prfm pldl1strm, [x11, #632] + ld1 {v3.4h, v4.4h}, [x11], #16 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v1.h[0] + fmla v12.4h, v3.4h, v2.h[0] + ld1 {v5.4h, v6.4h}, [x11], #16 + fmla v8.4h, v4.4h, v0.h[1] + fmla v10.4h, v4.4h, v1.h[1] + fmla v12.4h, v4.4h, v2.h[1] + fmla v8.4h, v5.4h, v0.h[2] + fmla v10.4h, v5.4h, v1.h[2] + fmla v12.4h, v5.4h, v2.h[2] + fmla v8.4h, v6.4h, v0.h[3] + fmla v10.4h, v6.4h, v1.h[3] + fmla v12.4h, v6.4h, v2.h[3] + subs x14, x14, #4 + Compute3x4EndTail: + adds x14, x14, #4 + cbz x14, Compute3x4Return + cmp x14, #1 + beq Compute3x4EndTail1 + cmp x14, #2 + beq Compute3x4EndTail2 + ld1 {v0.4h}, [x10] + ld1 {v1.4h}, [x19] + ld1 {v2.s}[0], [x20], #4 + ld1 {v2.h}[2], [x20] + prfm pldl1strm, [x11, #632] + ld1 {v3.4h, v4.4h}, [x11], #16 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v1.h[0] + fmla v12.4h, v3.4h, v2.h[0] + ld1 {v5.4h}, [x11], #8 + fmla v8.4h, v4.4h, v0.h[1] + fmla v10.4h, v4.4h, v1.h[1] + fmla v12.4h, v4.4h, v2.h[1] + fmla v8.4h, v5.4h, v0.h[2] + fmla v10.4h, v5.4h, v1.h[2] + fmla v12.4h, v5.4h, v2.h[2] + b Compute3x4Return + Compute3x4EndTail2: + ld1 {v0.4h}, [x10] + ld1 {v1.4h}, [x19] + ld2 {v2.h, v3.h}[0], [x20] + prfm pldl1strm, [x11, #632] + ld1 {v5.4h, v6.4h}, [x11], #16 + fmla v8.4h, v5.4h, v0.h[0] + fmla v10.4h, v5.4h, v1.h[0] + fmla v12.4h, v5.4h, v2.h[0] + fmla v8.4h, v6.4h, v0.h[1] + fmla v10.4h, v6.4h, v1.h[1] + fmla v12.4h, v6.4h, v3.h[0] + b Compute3x4Return + Compute3x4EndTail1: + ld1 {v0.h}[0], [x10] + ld1 {v1.h}[0], [x19] + ld1 {v2.h}[0], [x20] + prfm pldl1strm, [x11, #632] + ld1 {v3.4h}, [x11], #8 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v1.h[0] + fmla v12.4h, v3.4h, v2.h[0] + Compute3x4Return: + ret + +Compute2x16Unit: + add x19, x10, x16 + subs x14, x14, #8 + blt Compute2x16End4 + Compute2x16: + ld1 {v0.8h}, [x10], #16 + ld1 {v1.8h}, [x19], #16 + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v1.h[0] + fmla v8.8h, v5.8h, v0.h[1] + fmla v10.8h, v5.8h, v1.h[1] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[1] + fmla v11.8h, v6.8h, v1.h[1] + fmla v8.8h, v3.8h, v0.h[2] + fmla v10.8h, v3.8h, v1.h[2] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[2] + fmla v11.8h, v4.8h, v1.h[2] + fmla v8.8h, v5.8h, v0.h[3] + fmla v10.8h, v5.8h, v1.h[3] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[3] + fmla v11.8h, v6.8h, v1.h[3] + + fmla v8.8h, v3.8h, v0.h[4] + fmla v10.8h, v3.8h, v1.h[4] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[4] + fmla v11.8h, v4.8h, v1.h[4] + fmla v8.8h, v5.8h, v0.h[5] + fmla v10.8h, v5.8h, v1.h[5] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[5] + fmla v11.8h, v6.8h, v1.h[5] + fmla v8.8h, v3.8h, v0.h[6] + fmla v10.8h, v3.8h, v1.h[6] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[6] + fmla v11.8h, v4.8h, v1.h[6] + fmla v8.8h, v5.8h, v0.h[7] + fmla v10.8h, v5.8h, v1.h[7] + fmla v9.8h, v6.8h, v0.h[7] + fmla v11.8h, v6.8h, v1.h[7] + + subs x14, x14, #8 + bge Compute2x16 + Compute2x16End4: + adds x14, x14, #8 + cbz x14, Compute2x16Return + subs x14, x14, #4 + blt Compute2x16EndTail + ld1 {v0.4h}, [x10], #8 + ld1 {v1.4h}, [x19], #8 + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v1.h[0] + fmla v8.8h, v5.8h, v0.h[1] + fmla v10.8h, v5.8h, v1.h[1] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[1] + fmla v11.8h, v6.8h, v1.h[1] + fmla v8.8h, v3.8h, v0.h[2] + fmla v10.8h, v3.8h, v1.h[2] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[2] + fmla v11.8h, v4.8h, v1.h[2] + fmla v8.8h, v5.8h, v0.h[3] + fmla v10.8h, v5.8h, v1.h[3] + fmla v9.8h, v6.8h, v0.h[3] + fmla v11.8h, v6.8h, v1.h[3] + subs x14, x14, #4 + Compute2x16EndTail: + adds x14, x14, #4 + cbz x14, Compute2x16Return + cmp x14, #1 + beq Compute2x16EndTail1 + cmp x14, #2 + beq Compute2x16EndTail2 + ld1 {v0.4h}, [x10] + ld1 {v1.s}[0], [x19], #4 + ld1 {v1.h}[2], [x19] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v1.h[0] + fmla v8.8h, v5.8h, v0.h[1] + fmla v10.8h, v5.8h, v1.h[1] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[1] + fmla v11.8h, v6.8h, v1.h[1] + fmla v8.8h, v3.8h, v0.h[2] + fmla v10.8h, v3.8h, v1.h[2] + fmla v9.8h, v4.8h, v0.h[2] + fmla v11.8h, v4.8h, v1.h[2] + b Compute2x16Return + Compute2x16EndTail2: + ld1 {v0.4h}, [x10] + ld2 {v1.h, v2.h}[0], [x19] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v1.h[0] + fmla v8.8h, v5.8h, v0.h[1] + fmla v10.8h, v5.8h, v2.h[0] + fmla v9.8h, v6.8h, v0.h[1] + fmla v11.8h, v6.8h, v2.h[0] + b Compute2x16Return + Compute2x16EndTail1: + ld1 {v0.h}[0], [x10] + ld1 {v1.h}[0], [x19] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + fmla v9.8h, v4.8h, v0.h[0] + fmla v11.8h, v4.8h, v1.h[0] + Compute2x16Return: + ret + +Compute2x8Unit: + add x19, x10, x16 + subs x14, x14, #8 + blt Compute2x8End4 + Compute2x8: + ld1 {v0.8h}, [x10], #16 + ld1 {v1.8h}, [x19], #16 + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v8.8h, v4.8h, v0.h[1] + fmla v10.8h, v4.8h, v1.h[1] + fmla v8.8h, v5.8h, v0.h[2] + fmla v10.8h, v5.8h, v1.h[2] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v6.8h, v0.h[3] + fmla v10.8h, v6.8h, v1.h[3] + fmla v8.8h, v3.8h, v0.h[4] + fmla v10.8h, v3.8h, v1.h[4] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v8.8h, v4.8h, v0.h[5] + fmla v10.8h, v4.8h, v1.h[5] + fmla v8.8h, v5.8h, v0.h[6] + fmla v10.8h, v5.8h, v1.h[6] + fmla v8.8h, v6.8h, v0.h[7] + fmla v10.8h, v6.8h, v1.h[7] + + subs x14, x14, #8 + bge Compute2x8 + Compute2x8End4: + adds x14, x14, #8 + cbz x14, Compute2x8Return + subs x14, x14, #4 + blt Compute2x8EndTail + ld1 {v0.4h}, [x10], #8 + ld1 {v1.4h}, [x19], #8 + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v8.8h, v4.8h, v0.h[1] + fmla v10.8h, v4.8h, v1.h[1] + fmla v8.8h, v5.8h, v0.h[2] + fmla v10.8h, v5.8h, v1.h[2] + fmla v8.8h, v6.8h, v0.h[3] + fmla v10.8h, v6.8h, v1.h[3] + subs x14, x14, #4 + Compute2x8EndTail: + adds x14, x14, #4 + cbz x14, Compute2x8Return + cmp x14, #1 + beq Compute2x8EndTail1 + cmp x14, #2 + beq Compute2x8EndTail2 + ld1 {v0.4h}, [x10] + ld3 {v1.h, v2.h, v3.h}[0], [x19] + prfm pldl1strm, [x11, #632] + ld1 {v4.8h, v5.8h}, [x11], #32 + fmla v8.8h, v4.8h, v0.h[0] + fmla v10.8h, v4.8h, v1.h[0] + ld1 {v6.8h}, [x11], #16 + fmla v8.8h, v5.8h, v0.h[1] + fmla v10.8h, v5.8h, v2.h[0] + fmla v8.8h, v6.8h, v0.h[2] + fmla v10.8h, v6.8h, v3.h[0] + b Compute2x8Return + Compute2x8EndTail2: + ld1 {v0.4h}, [x10] + ld2 {v1.h, v2.h}[0], [x19] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + fmla v8.8h, v4.8h, v0.h[1] + fmla v10.8h, v4.8h, v2.h[0] + b Compute2x8Return + Compute2x8EndTail1: + ld1 {v0.h}[0], [x10] + ld1 {v1.h}[0], [x19] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + fmla v10.8h, v3.8h, v1.h[0] + Compute2x8Return: + ret + +Compute2x4Unit: + add x19, x10, x16 + subs x14, x14, #8 + blt Compute2x4End4 + Compute2x4: + ld1 {v0.8h}, [x10], #16 + ld1 {v1.8h}, [x19], #16 + prfm pldl1strm, [x11, #632] + ld1 {v3.4h, v4.4h}, [x11], #16 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v1.h[0] + ld1 {v5.4h, v6.4h}, [x11], #16 + fmla v8.4h, v4.4h, v0.h[1] + fmla v10.4h, v4.4h, v1.h[1] + fmla v8.4h, v5.4h, v0.h[2] + fmla v10.4h, v5.4h, v1.h[2] + prfm pldl1strm, [x11, #632] + ld1 {v3.4h, v4.4h}, [x11], #16 + fmla v8.4h, v6.4h, v0.h[3] + fmla v10.4h, v6.4h, v1.h[3] + fmla v8.4h, v3.4h, v0.h[4] + fmla v10.4h, v3.4h, v1.h[4] + ld1 {v5.4h, v6.4h}, [x11], #16 + fmla v8.4h, v4.4h, v0.h[5] + fmla v10.4h, v4.4h, v1.h[5] + fmla v8.4h, v5.4h, v0.h[6] + fmla v10.4h, v5.4h, v1.h[6] + fmla v8.4h, v6.4h, v0.h[7] + fmla v10.4h, v6.4h, v1.h[7] + + subs x14, x14, #8 + bge Compute2x4 + Compute2x4End4: + adds x14, x14, #8 + cbz x14, Compute2x4Return + subs x14, x14, #4 + blt Compute2x4EndTail + ld1 {v0.4h}, [x10], #8 + ld1 {v1.4h}, [x19], #8 + prfm pldl1strm, [x11, #632] + ld1 {v3.4h, v4.4h}, [x11], #16 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v1.h[0] + ld1 {v5.4h, v6.4h}, [x11], #16 + fmla v8.4h, v4.4h, v0.h[1] + fmla v10.4h, v4.4h, v1.h[1] + fmla v8.4h, v5.4h, v0.h[2] + fmla v10.4h, v5.4h, v1.h[2] + fmla v8.4h, v6.4h, v0.h[3] + fmla v10.4h, v6.4h, v1.h[3] + subs x14, x14, #4 + Compute2x4EndTail: + adds x14, x14, #4 + cbz x14, Compute2x4Return + cmp x14, #1 + beq Compute2x4EndTail1 + cmp x14, #2 + beq Compute2x4EndTail2 + ld1 {v0.4h}, [x10] + ld3 {v1.h, v2.h, v3.h}[0], [x19] + prfm pldl1strm, [x11, #632] + ld1 {v4.4h, v5.4h}, [x11], #16 + fmla v8.4h, v4.4h, v0.h[0] + fmla v10.4h, v4.4h, v1.h[0] + ld1 {v6.4h}, [x11], #8 + fmla v8.4h, v5.4h, v0.h[1] + fmla v10.4h, v5.4h, v2.h[0] + fmla v8.4h, v6.4h, v0.h[2] + fmla v10.4h, v6.4h, v3.h[0] + b Compute2x4Return + Compute2x4EndTail2: + ld1 {v0.4h}, [x10] + ld2 {v1.h, v2.h}[0], [x19] + prfm pldl1strm, [x11, #632] + ld1 {v3.4h, v4.4h}, [x11], #16 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v1.h[0] + fmla v8.4h, v4.4h, v0.h[1] + fmla v10.4h, v4.4h, v2.h[0] + b Compute2x4Return + Compute2x4EndTail1: + ld1 {v0.h}[0], [x10] + ld1 {v1.h}[0], [x19] + prfm pldl1strm, [x11, #632] + ld1 {v3.4h}, [x11], #8 + fmla v8.4h, v3.4h, v0.h[0] + fmla v10.4h, v3.4h, v1.h[0] + Compute2x4Return: + ret + +Compute1x16Unit: + subs x14, x14, #8 + blt Compute1x16End4 + Compute1x16: + ld1 {v0.8h}, [x10], #16 + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[0] + fmla v8.8h, v5.8h, v0.h[1] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[1] + fmla v8.8h, v3.8h, v0.h[2] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[2] + fmla v8.8h, v5.8h, v0.h[3] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[3] + + fmla v8.8h, v3.8h, v0.h[4] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[4] + fmla v8.8h, v5.8h, v0.h[5] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[5] + fmla v8.8h, v3.8h, v0.h[6] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[6] + fmla v8.8h, v5.8h, v0.h[7] + fmla v9.8h, v6.8h, v0.h[7] + + subs x14, x14, #8 + bge Compute1x16 + Compute1x16End4: + adds x14, x14, #8 + cbz x14, Compute1x16Return + subs x14, x14, #4 + blt Compute1x16EndTail + ld1 {v0.4h}, [x10], #8 + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[0] + fmla v8.8h, v5.8h, v0.h[1] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v0.h[1] + fmla v8.8h, v3.8h, v0.h[2] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[2] + fmla v8.8h, v5.8h, v0.h[3] + fmla v9.8h, v6.8h, v0.h[3] + subs x14, x14, #4 + Compute1x16EndTail: + adds x14, x14, #4 + cbz x14, Compute1x16Return + cmp x14, #1 + beq Compute1x16EndTail1 + cmp x14, #2 + beq Compute1x16EndTail2 + ld3 {v0.h, v1.h, v2.h}[0], [x10] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[0] + fmla v8.8h, v5.8h, v1.h[0] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v9.8h, v6.8h, v1.h[0] + fmla v8.8h, v3.8h, v2.h[0] + fmla v9.8h, v4.8h, v2.h[0] + b Compute1x16Return + Compute1x16EndTail2: + ld2 {v0.h, v1.h}[0], [x10] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v9.8h, v4.8h, v0.h[0] + fmla v8.8h, v5.8h, v1.h[0] + fmla v9.8h, v6.8h, v1.h[0] + b Compute1x16Return + Compute1x16EndTail1: + ld1 {v0.h}[0], [x10] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v9.8h, v4.8h, v0.h[0] + Compute1x16Return: + ret + +Compute1x8Unit: + subs x14, x14, #8 + blt Compute1x8End4 + Compute1x8: + ld1 {v0.8h}, [x10], #16 + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v8.8h, v4.8h, v0.h[1] + fmla v8.8h, v5.8h, v0.h[2] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v6.8h, v0.h[3] + fmla v8.8h, v3.8h, v0.h[4] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v8.8h, v4.8h, v0.h[5] + fmla v8.8h, v5.8h, v0.h[6] + fmla v8.8h, v6.8h, v0.h[7] + + subs x14, x14, #8 + bge Compute1x8 + Compute1x8End4: + adds x14, x14, #8 + cbz x14, Compute1x8Return + subs x14, x14, #4 + blt Compute1x8EndTail + ld1 {v0.4h}, [x10], #8 + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + ld1 {v5.8h, v6.8h}, [x11], #32 + fmla v8.8h, v4.8h, v0.h[1] + fmla v8.8h, v5.8h, v0.h[2] + fmla v8.8h, v6.8h, v0.h[3] + subs x14, x14, #4 + Compute1x8EndTail: + adds x14, x14, #4 + cbz x14, Compute1x8Return + cmp x14, #1 + beq Compute1x8EndTail1 + cmp x14, #2 + beq Compute1x8EndTail2 + ld3 {v0.h, v1.h, v2.h}[0], [x10] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + ld1 {v5.8h}, [x11], #16 + fmla v8.8h, v4.8h, v1.h[0] + fmla v8.8h, v5.8h, v2.h[0] + b Compute1x8Return + Compute1x8EndTail2: + ld2 {v0.h, v1.h}[0], [x10] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h, v4.8h}, [x11], #32 + fmla v8.8h, v3.8h, v0.h[0] + fmla v8.8h, v4.8h, v1.h[0] + b Compute1x8Return + Compute1x8EndTail1: + ld1 {v0.h}[0], [x10] + prfm pldl1strm, [x11, #632] + ld1 {v3.8h}, [x11], #16 + fmla v8.8h, v3.8h, v0.h[0] + Compute1x8Return: + ret + +Compute1x4Unit: + subs x14, x14, #8 + blt Compute1x4End4 + Compute1x4: + ld1 {v0.8h}, [x10], #16 + prfm pldl1strm, [x11, #632] + ld1 {v3.4h, v4.4h}, [x11], #16 + fmla v8.4h, v3.4h, v0.h[0] + ld1 {v5.4h, v6.4h}, [x11], #16 + fmla v8.4h, v4.4h, v0.h[1] + fmla v8.4h, v5.4h, v0.h[2] + prfm pldl1strm, [x11, #632] + ld1 {v3.4h, v4.4h}, [x11], #16 + fmla v8.4h, v6.4h, v0.h[3] + fmla v8.4h, v3.4h, v0.h[4] + ld1 {v5.4h, v6.4h}, [x11], #16 + fmla v8.4h, v4.4h, v0.h[5] + fmla v8.4h, v5.4h, v0.h[6] + fmla v8.4h, v6.4h, v0.h[7] + + subs x14, x14, #8 + bge Compute1x4 + Compute1x4End4: + adds x14, x14, #8 + cbz x14, Compute1x4Return + subs x14, x14, #4 + blt Compute1x4EndTail + ld1 {v0.4h}, [x10], #8 + prfm pldl1strm, [x11, #632] + ld1 {v3.4h, v4.4h}, [x11], #16 + fmla v8.4h, v3.4h, v0.h[0] + ld1 {v5.4h, v6.4h}, [x11], #16 + fmla v8.4h, v4.4h, v0.h[1] + fmla v8.4h, v5.4h, v0.h[2] + fmla v8.4h, v6.4h, v0.h[3] + subs x14, x14, #4 + Compute1x4EndTail: + adds x14, x14, #4 + cbz x14, Compute1x4Return + cmp x14, #1 + beq Compute1x4EndTail1 + cmp x14, #2 + beq Compute1x4EndTail2 + ld3 {v0.h, v1.h, v2.h}[0], [x10] + prfm pldl1strm, [x11, #632] + ld1 {v3.4h, v4.4h}, [x11], #16 + fmla v8.4h, v3.4h, v0.h[0] + ld1 {v5.4h}, [x11], #8 + fmla v8.4h, v4.4h, v1.h[0] + fmla v8.4h, v5.4h, v2.h[0] + b Compute1x4Return + Compute1x4EndTail2: + ld2 {v0.h, v1.h}[0], [x10] + prfm pldl1strm, [x11, #632] + ld1 {v3.4h, v4.4h}, [x11], #16 + fmla v8.4h, v3.4h, v0.h[0] + fmla v8.4h, v4.4h, v1.h[0] + b Compute1x4Return + Compute1x4EndTail1: + ld1 {v0.h}[0], [x10] + prfm pldl1strm, [x11, #632] + ld1 {v3.4h}, [x11], #8 + fmla v8.4h, v3.4h, v0.h[0] + Compute1x4Return: + ret + +End: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x29, x30, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/MatmulWinogradFp16.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/MatmulWinogradFp16.S new file mode 100644 index 0000000000000000000000000000000000000000..ca0542da187040ce17bcdb9139dee89e793d8a01 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/MatmulWinogradFp16.S @@ -0,0 +1,217 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// MatrixMultiplyWinogradFp16(float16_t *matix_a, float16_t *matrix_b, float16_t *matrix_c, int m, int k, int n, int in_channel) + // x0: matrix_a, x1: matrix_b, x2: matrix_c, x3: m, x4: k, x5: n, x6: in_channel +asm_function MatrixMultiplyWinogradFp16 + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #48 + st1 {v8.8h}, [sp] + stp x19, x20, [sp, #16] + stp x21, x22, [sp, #32] + + mov x8, #2 + mul x10, x5, x8 // n * 2 + mov x17, x3 // m + mul x13, x6, x8 // in_channel * 2 + mul x21, x13, x4 // in_channel * k * 2 + + LoopM: + mov x15, x5 // n + mov x14, x1 // mat_b + LoopN: + mov x16, x0 // mat_a_m + sub x22, x5, x15 // ni + sub x19, x17, x3 // mi + mul x22, x22, x17 // ni * m + mov x11, x6 // in_channel + add x22, x22, x19 // (ni * m) + mi + mul x22, x22, x13 // x22 * channel_in * 2 + add x20, x2, x22 // dst + offset + cmp x11, #32 + bge LoopC32 + cmp x11, #16 + bge LoopC16 + cmp x11, #8 + bge LoopC8 + cmp x11, #4 + bge LoopC4 + cmp x11, #1 + bge LoopC + b EndLoopC + LoopC32: + mov x12, x14 + mov x9, x4 // new_k + dup v5.8h, wzr + dup v6.8h, wzr + dup v7.8h, wzr + dup v8.8h, wzr + LoopK32: + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x16], x13 + ldr h4, [x12] + add x12, x12, x10 + fmla v5.8h, v0.8h, v4.h[0] + fmla v6.8h, v1.8h, v4.h[0] + fmla v7.8h, v2.8h, v4.h[0] + fmla v8.8h, v3.8h, v4.h[0] + subs x9, x9, #1 + bne LoopK32 + Write32: + st1 {v5.8h}, [x20], #16 + st1 {v6.8h}, [x20], #16 + st1 {v7.8h}, [x20], #16 + st1 {v8.8h}, [x20], #16 + + sub x16, x16, x21 // back x13 * k + add x16, x16, #64 // add 64B + subs x11, x11, #32 + beq EndLoopC + cmp x11, #32 + bge LoopC32 + cmp x11, #16 + bge LoopC16 + cmp x11, #8 + bge LoopC8 + cmp x11, #4 + bge LoopC4 + cmp x11, #1 + bge LoopC + + LoopC16: + dup v5.8h, wzr + dup v6.8h, wzr + mov x9, x4 // new_k + mov x12, x14 + LoopK16: + ld1 {v0.8h, v1.8h}, [x16], x13 + ldr h4, [x12] + add x12, x12, x10 + fmla v5.8h, v0.8h, v4.h[0] + fmla v6.8h, v1.8h, v4.h[0] + subs x9, x9, #1 + bne LoopK16 + Write16: + st1 {v5.8h}, [x20], #16 + st1 {v6.8h}, [x20], #16 + + sub x16, x16, x21 // back x13 * k + add x16, x16, #32 // add 32B + subs x11, x11, #16 + beq EndLoopC + cmp x11, #16 + bge LoopC16 + cmp x11, #8 + bge LoopC8 + cmp x11, #4 + bge LoopC4 + cmp x11, #1 + bge LoopC + + LoopC8: + dup v5.8h, wzr + mov x9, x4 // new_k + mov x12, x14 + LoopK8: + ld1 {v0.8h}, [x16], x13 + ldr h4, [x12] + add x12, x12, x10 + fmla v5.8h, v0.8h, v4.h[0] + subs x9, x9, #1 + bne LoopK8 + Write8: + st1 {v5.8h}, [x20], #16 + + sub x16, x16, x21 // ptr back x13 * k + add x16, x16, #16 // add 16B + subs x11, x11, #8 + beq EndLoopC + cmp x11, #8 + bge LoopC8 + cmp x11, #4 + bge LoopC4 + cmp x11, #1 + bge LoopC + + LoopC4: + dup v5.4h, wzr + mov x9, x4 // new_k + mov x12, x14 + LoopK4: + ld1 {v0.4h}, [x16], x13 + ldr h4, [x12] + add x12, x12, x10 + fmla v5.4h, v0.4h, v4.h[0] + subs x9, x9, #1 + bne LoopK4 + Write4: + st1 {v5.4h}, [x20], #8 + + sub x16, x16, x21 // ptr back x13 * k + add x16, x16, #8 // add 8B + subs x11, x11, #4 + beq EndLoopC + cmp x11, #4 + bge LoopC4 + cmp x11, #1 + bge LoopC + + LoopC: + dup v5.8h, wzr + mov x9, x4 // new_k + mov x12, x14 + LoopK: + ldr h0, [x16] + add x16, x16, x13 + ldr h4, [x12] + add x12, x12, x10 + fmul h0, h0, h4 + fadd h5, h5, h0 + subs x9, x9, #1 + bne LoopK + Write: + str h5, [x20], #2 + + sub x16, x16, x21 // ptr back x13 * k + add x16, x16, #2 // ptr add 2B + subs x11, x11, #1 + beq EndLoopC + b LoopC + + EndLoopC: + add x14, x14, #2 + subs x15, x15, #1 + beq EndLoopN + b LoopN + EndLoopN: + subs x3, x3, #1 + beq EndLoopM + add x0, x0, x21 + b LoopM + + EndLoopM: + ld1 {v8.8h}, [sp], #16 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/PostFuncBiasReluC4Fp16.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/PostFuncBiasReluC4Fp16.S new file mode 100644 index 0000000000000000000000000000000000000000..3a72b877ab4984c95e8bf312d2b4b22375633e71 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/PostFuncBiasReluC4Fp16.S @@ -0,0 +1,293 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +//void PostFuncBiasReluC4Fp16(float16_t *dst, const float16_t *src, const float16_t *bias, size_t oc4div, size_t oc4mod, +// size_t plane_size, size_t plane_stride, size_t relu_type); +// x0 dst x1 srx x2 bias +// w3 oc4div w4 oc4mod w5 plane_size +// x6 plane_stride x7 relu_type + +asm_function PostFuncBiasReluC4Fp16 + + movi v26.4h, #6 + scvtf v26.4h, v26.4h + dup v27.4h, wzr + + mov x10, #2 + add x12, x3, x4 + mul x12, x12, x10 + + mov w10, #0 + +Loop_C4: + cmp w10, w3 + beq Loop_C1 + mov x15, #2 + mul x14, x10, x15 + add x15, x0, x14 + add w10, w10, #4 + mov w13, w5 + ld1 {v16.4h}, [x2], #8 + +Loop_8x4: + cmp w13, #8 + blt Loop_4x4 + sub w13, w13, #8 + ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [x1], #32 + ld1 {v4.4h, v5.4h, v6.4h, v7.4h}, [x1], #32 + + fadd v0.4h, v0.4h, v16.4h + fadd v1.4h, v1.4h, v16.4h + fadd v2.4h, v2.4h, v16.4h + fadd v3.4h, v3.4h, v16.4h + fadd v4.4h, v4.4h, v16.4h + fadd v5.4h, v5.4h, v16.4h + fadd v6.4h, v6.4h, v16.4h + fadd v7.4h, v7.4h, v16.4h + + cmp x7, #3 + beq Relu6_8x4 + cmp x7, #1 + beq Relu_8x4 + b Write_8x4 +Relu6_8x4: + fmin v0.4h, v0.4h, v26.4h + fmin v1.4h, v1.4h, v26.4h + fmin v2.4h, v2.4h, v26.4h + fmin v3.4h, v3.4h, v26.4h + fmin v4.4h, v4.4h, v26.4h + fmin v5.4h, v5.4h, v26.4h + fmin v6.4h, v6.4h, v26.4h + fmin v7.4h, v7.4h, v26.4h +Relu_8x4: + fmax v0.4h, v0.4h, v27.4h + fmax v1.4h, v1.4h, v27.4h + fmax v2.4h, v2.4h, v27.4h + fmax v3.4h, v3.4h, v27.4h + fmax v4.4h, v4.4h, v27.4h + fmax v5.4h, v5.4h, v27.4h + fmax v6.4h, v6.4h, v27.4h + fmax v7.4h, v7.4h, v27.4h +Write_8x4: + st1 {v0.4h}, [x15], x12 + st1 {v1.4h}, [x15], x12 + st1 {v2.4h}, [x15], x12 + st1 {v3.4h}, [x15], x12 + st1 {v4.4h}, [x15], x12 + st1 {v5.4h}, [x15], x12 + st1 {v6.4h}, [x15], x12 + st1 {v7.4h}, [x15], x12 + b Loop_8x4 + +Loop_4x4: + cmp w13, #4 + blt Loop_1x4 + sub w13, w13, #4 + ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [x1], #32 + fadd v0.4h, v0.4h, v16.4h + fadd v1.4h, v1.4h, v16.4h + fadd v2.4h, v2.4h, v16.4h + fadd v3.4h, v3.4h, v16.4h + cmp x7, #3 + beq Relu6_4x4 + cmp x7, #1 + beq Relu_4x4 + b Write_4x4 +Relu6_4x4: + fmin v0.4h, v0.4h, v26.4h + fmin v1.4h, v1.4h, v26.4h + fmin v2.4h, v2.4h, v26.4h + fmin v3.4h, v3.4h, v26.4h +Relu_4x4: + fmax v0.4h, v0.4h, v27.4h + fmax v1.4h, v1.4h, v27.4h + fmax v2.4h, v2.4h, v27.4h + fmax v3.4h, v3.4h, v27.4h +Write_4x4: + st1 {v0.4h}, [x15], x12 + st1 {v1.4h}, [x15], x12 + st1 {v2.4h}, [x15], x12 + st1 {v3.4h}, [x15], x12 + +Loop_1x4: + cmp x7, #3 + beq Relu6_1x4 + cmp x7, #1 + beq Relu_1x4 + b Write_1x4 +Relu6_1x4: + cmp w13, #0 + beq HW_Add + sub w13, w13, #1 + ld1 {v0.4h}, [x1], #8 + fadd v0.4h, v0.4h, v16.4h + fmin v0.4h, v0.4h, v26.4h + fmax v0.4h, v0.4h, v27.4h + st1 {v0.4h}, [x15], x12 + b Relu6_1x4 +Relu_1x4: + cmp w13, #0 + beq HW_Add + sub w13, w13, #1 + ld1 {v0.4h}, [x1], #8 + fadd v0.4h, v0.4h, v16.4h + fmax v0.4h, v0.4h, v27.4h + st1 {v0.4h}, [x15], x12 + b Relu_1x4 +Write_1x4: + cmp w13, #0 + beq HW_Add + sub w13, w13, #1 + ld1 {v0.4h}, [x1], #8 + fadd v0.4h, v0.4h, v16.4h + st1 {v0.4h}, [x15], x12 + b Write_1x4 + +HW_Add: + add x1, x1, x6 + b Loop_C4 + +Loop_C1: + cmp w4, #0 + beq End + mov w13, w5 + ld1 {v16.4h}, [x2], #8 + mov x15, #2 + mul x14, x10, x15 + add x0, x0, x14 + + cmp w4, #1 + beq Loop_C1_1 + cmp w4, #2 + beq Loop_C1_2 + cmp w4, #3 + beq Loop_C1_3 + +Loop_C1_1: + cmp x7, #3 + beq Loop_C1_1_Relu6 + cmp x7, #1 + beq Loop_C1_1_Relu + b Loop_C1_1_Write +Loop_C1_1_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4h}, [x1], #8 + fadd v0.4h, v0.4h, v16.4h + fmin v0.4h, v0.4h, v26.4h + fmax v0.4h, v0.4h, v27.4h + st1 {v0.h}[0], [x0], x12 + b Loop_C1_1_Relu6 +Loop_C1_1_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4h}, [x1], #8 + fadd v0.4h, v0.4h, v16.4h + fmax v0.4h, v0.4h, v27.4h + st1 {v0.h}[0], [x0], x12 + b Loop_C1_1_Relu +Loop_C1_1_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4h}, [x1], #8 + fadd v0.4h, v0.4h, v16.4h + st1 {v0.h}[0], [x0], x12 + b Loop_C1_1_Write + +Loop_C1_2: + cmp x7, #3 + beq Loop_C1_2_Relu6 + cmp x7, #1 + beq Loop_C1_2_Relu + b Loop_C1_2_Write +Loop_C1_2_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4h}, [x1], #8 + fadd v0.4h, v0.4h, v16.4h + fmin v0.4h, v0.4h, v26.4h + fmax v0.4h, v0.4h, v27.4h + st1 {v0.s}[0], [x0], x12 + b Loop_C1_2_Relu6 +Loop_C1_2_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4h}, [x1], #8 + fadd v0.4h, v0.4h, v16.4h + fmax v0.4h, v0.4h, v27.4h + st1 {v0.s}[0], [x0], x12 + b Loop_C1_2_Relu +Loop_C1_2_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4h}, [x1], #8 + fadd v0.4h, v0.4h, v16.4h + st1 {v0.s}[0], [x0], x12 + b Loop_C1_2_Write + +Loop_C1_3: + add x15, x0, #4 + cmp x7, #3 + beq Loop_C1_3_Relu6 + cmp x7, #1 + beq Loop_C1_3_Relu + b Loop_C1_3_Write +Loop_C1_3_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4h}, [x1], #8 + fadd v0.4h, v0.4h, v16.4h + fmin v0.4h, v0.4h, v26.4h + fmax v0.4h, v0.4h, v27.4h + st1 {v0.s}[0], [x0], x12 + st1 {v0.h}[2], [x15], x12 + b Loop_C1_3_Relu6 +Loop_C1_3_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4h}, [x1], #8 + fadd v0.4h, v0.4h, v16.4h + fmax v0.4h, v0.4h, v27.4h + st1 {v0.s}[0], [x0], x12 + st1 {v0.h}[2], [x15], x12 + b Loop_C1_3_Relu +Loop_C1_3_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.4h}, [x1], #8 + fadd v0.4h, v0.4h, v16.4h + st1 {v0.s}[0], [x0], x12 + st1 {v0.h}[2], [x15], x12 + b Loop_C1_3_Write + +End: + ret + +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/PostFuncBiasReluC8Fp16.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/PostFuncBiasReluC8Fp16.S new file mode 100644 index 0000000000000000000000000000000000000000..367b42705c1e1016162a63e14befedd451f0ea1a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/PostFuncBiasReluC8Fp16.S @@ -0,0 +1,469 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +//void PostFuncBiasReluC8Fp16(float *dst, const float *src, const float *bias, size_t oc8div,size_t oc8mod +// size_t plane_size, size_t stride, int relu_type); +// x0 dst x1 srx x2 bias +// x3 oc8div x4 oc8mod x5 plane_size +// x6 stride x7 relu_type + +// v0 ~ v7 value +// v16 bias data +// x22 x23 x24 x25 write loop tmp buf +// x26 relu6 #6; x27 relu #0 +// w10 oc8 loop control +// w13 hw loop control + +asm_function PostFuncBiasReluC8Fp16 + movi v26.8h, #0x46, lsl #8 + dup v27.8h, wzr + mov w10, #0 + +Loop_C8: + cmp w10, w3 + beq Loop_C1 + mov x25, #2 + mul x24, x10, x25 + add x25, x0, x24 + add w10, w10, #8 + mov w13, w5 + ld1 {v16.8h}, [x2], #16 + +Loop8x8: + cmp w13, #8 + blt Loop_4x8 + sub w13, w13, #8 + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x1], #64 + ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x1], #64 + + fadd v0.8h, v0.8h, v16.8h + fadd v1.8h, v1.8h, v16.8h + fadd v2.8h, v2.8h, v16.8h + fadd v3.8h, v3.8h, v16.8h + fadd v4.8h, v4.8h, v16.8h + fadd v5.8h, v5.8h, v16.8h + fadd v6.8h, v6.8h, v16.8h + fadd v7.8h, v7.8h, v16.8h + + cmp w7, #2 + beq Relu6_8x8 + cmp w7, #1 + beq Relu_8x8 + b Write_8x8 +Relu6_8x8: + fmin v0.8h, v0.8h, v26.8h + fmin v1.8h, v1.8h, v26.8h + fmin v2.8h, v2.8h, v26.8h + fmin v3.8h, v3.8h, v26.8h + fmin v4.8h, v4.8h, v26.8h + fmin v5.8h, v5.8h, v26.8h + fmin v6.8h, v6.8h, v26.8h + fmin v7.8h, v7.8h, v26.8h +Relu_8x8: + fmax v0.8h, v0.8h, v27.8h + fmax v1.8h, v1.8h, v27.8h + fmax v2.8h, v2.8h, v27.8h + fmax v3.8h, v3.8h, v27.8h + fmax v4.8h, v4.8h, v27.8h + fmax v5.8h, v5.8h, v27.8h + fmax v6.8h, v6.8h, v27.8h + fmax v7.8h, v7.8h, v27.8h +Write_8x8: + st1 {v0.8h}, [x25], x6 + st1 {v1.8h}, [x25], x6 + st1 {v2.8h}, [x25], x6 + st1 {v3.8h}, [x25], x6 + st1 {v4.8h}, [x25], x6 + st1 {v5.8h}, [x25], x6 + st1 {v6.8h}, [x25], x6 + st1 {v7.8h}, [x25], x6 + b Loop8x8 + +Loop_4x8: + cmp w13, #4 + blt Loop_1x8 + sub w13, w13, #4 + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x1], #64 + + fadd v0.8h, v0.8h, v16.8h + fadd v1.8h, v1.8h, v16.8h + fadd v2.8h, v2.8h, v16.8h + fadd v3.8h, v3.8h, v16.8h + + cmp w7, #2 + beq Relu6_4x8 + cmp w7, #1 + beq Relu_4x8 + b Write_4x8 +Relu6_4x8: + fmin v0.8h, v0.8h, v26.8h + fmin v1.8h, v1.8h, v26.8h + fmin v2.8h, v2.8h, v26.8h + fmin v3.8h, v3.8h, v26.8h +Relu_4x8: + fmax v0.8h, v0.8h, v27.8h + fmax v1.8h, v1.8h, v27.8h + fmax v2.8h, v2.8h, v27.8h + fmax v3.8h, v3.8h, v27.8h +Write_4x8: + st1 {v0.8h}, [x25], x6 + st1 {v1.8h}, [x25], x6 + st1 {v2.8h}, [x25], x6 + st1 {v3.8h}, [x25], x6 + b Loop_4x8 + +Loop_1x8: + cmp w7, #2 + beq Relu6_1x8 + cmp w7, #1 + beq Relu_1x8 + b Write_1x8 +Relu6_1x8: + cmp w13, #0 + beq Loop_C8 + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmin v0.8h, v0.8h, v26.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.8h}, [x25], x6 + b Relu6_1x8 +Relu_1x8: + cmp w13, #0 + beq Loop_C8 + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.8h}, [x25], x6 + b Relu_1x8 +Write_1x8: + cmp w13, #0 + beq Loop_C8 + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + st1 {v0.8h}, [x25], x6 + b Write_1x8 + + +Loop_C1: + cmp x4, #0 + beq End + mov w13, w5 + ld1 {v16.8h}, [x2], #16 + mov x25, #2 + mul x24, x10, x25 + add x22, x0, x24 + + cmp x4, #1 + beq Loop_C1_1 + cmp x4, #2 + beq Loop_C1_2 + cmp x4, #3 + beq Loop_C1_3 + cmp x4, #4 + beq Loop_C1_4 + cmp x4, #5 + beq Loop_C1_5 + cmp x4, #6 + beq Loop_C1_6 + cmp x4, #7 + beq Loop_C1_7 + +Loop_C1_1: + cmp w7, #2 + beq Loop_C1_1_Relu6 + cmp w7, #1 + beq Loop_C1_1_Relu + b Loop_C1_1_Write +Loop_C1_1_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmin v0.8h, v0.8h, v26.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.h}[0], [x22], x6 + b Loop_C1_1_Relu6 +Loop_C1_1_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.h}[0], [x22], x6 + b Loop_C1_1_Relu +Loop_C1_1_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + st1 {v0.h}[0], [x22], x6 + b Loop_C1_1_Write + +Loop_C1_2: + add x24, x0, #2 + cmp w7, #2 + beq Loop_C1_2_Relu6 + cmp w7, #1 + beq Loop_C1_2_Relu + b Loop_C1_2_Write +Loop_C1_2_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmin v0.8h, v0.8h, v26.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.h}[0], [x22], x6 + st1 {v0.h}[1], [x24], x6 + b Loop_C1_2_Relu6 +Loop_C1_2_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.h}[0], [x22], x6 + st1 {v0.h}[1], [x24], x6 + b Loop_C1_2_Relu +Loop_C1_2_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + st1 {v0.h}[0], [x22], x6 + st1 {v0.h}[1], [x24], x6 + b Loop_C1_2_Write + + +Loop_C1_3: + add x24, x22, #2 + add x25, x22, #4 + cmp w7, #2 + beq Loop_C1_3_Relu6 + cmp w7, #1 + beq Loop_C1_3_Relu + b Loop_C1_3_Write +Loop_C1_3_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmin v0.8h, v0.8h, v26.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.h}[0], [x22], x6 + st1 {v0.h}[1], [x24], x6 + st1 {v0.h}[2], [x25], x6 + b Loop_C1_3_Relu6 +Loop_C1_3_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.h}[0], [x22], x6 + st1 {v0.h}[1], [x24], x6 + st1 {v0.h}[2], [x25], x6 + b Loop_C1_3_Relu +Loop_C1_3_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + st1 {v0.h}[0], [x22], x6 + st1 {v0.h}[1], [x24], x6 + st1 {v0.h}[2], [x25], x6 + b Loop_C1_3_Write + +Loop_C1_4: + cmp w7, #2 + beq Loop_C1_4_Relu6 + cmp w7, #1 + beq Loop_C1_4_Relu + b Loop_C1_4_Write +Loop_C1_4_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmin v0.8h, v0.8h, v26.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.4h}, [x22], x6 + b Loop_C1_4_Relu6 +Loop_C1_4_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.4h}, [x22], x6 + b Loop_C1_4_Relu6 +Loop_C1_4_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + st1 {v0.4h}, [x22], x6 + b Loop_C1_4_Write + +Loop_C1_5: + add x25, x22, #8 + cmp w7, #2 + beq Loop_C1_5_Relu6 + cmp w7, #1 + beq Loop_C1_5_Relu + b Loop_C1_5_Write +Loop_C1_5_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmin v0.8h, v0.8h, v26.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.4h}, [x22], x6 + st1 {v0.h}[4], [x25], x6 + b Loop_C1_5_Relu6 +Loop_C1_5_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.4h}, [x22], x6 + st1 {v0.h}[4], [x25], x6 + b Loop_C1_5_Relu +Loop_C1_5_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + st1 {v0.4h}, [x22], x6 + st1 {v0.h}[4], [x25], x6 + b Loop_C1_5_Write + +Loop_C1_6: + add x23, x22, #8 + add x24, x22, #10 + cmp w7, #2 + beq Loop_C1_6_Relu6 + cmp w7, #1 + beq Loop_C1_6_Relu + b Loop_C1_6_Write +Loop_C1_6_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmin v0.8h, v0.8h, v26.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.4h}, [x22], x6 + st1 {v0.h}[4], [x23], x6 + st1 {v0.h}[5], [x24], x6 + b Loop_C1_6_Relu6 +Loop_C1_6_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.4h}, [x22], x6 + st1 {v0.h}[4], [x23], x6 + st1 {v0.h}[5], [x24], x6 + b Loop_C1_6_Relu +Loop_C1_6_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + st1 {v0.4h}, [x22], x6 + st1 {v0.h}[4], [x23], x6 + st1 {v0.h}[5], [x24], x6 + b Loop_C1_6_Write + +Loop_C1_7: + add x23, x22, #8 + add x24, x22, #10 + add x25, x22, #12 + cmp w7, #2 + beq Loop_C1_7_Relu6 + cmp w7, #1 + beq Loop_C1_7_Relu + b Loop_C1_7_Write +Loop_C1_7_Relu6: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmin v0.8h, v0.8h, v26.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.4h}, [x22], x6 + st1 {v0.h}[4], [x23], x6 + st1 {v0.h}[5], [x24], x6 + st1 {v0.h}[6], [x25], x6 + b Loop_C1_7_Relu6 +Loop_C1_7_Relu: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + fmax v0.8h, v0.8h, v27.8h + st1 {v0.4h}, [x22], x6 + st1 {v0.h}[4], [x23], x6 + st1 {v0.h}[5], [x24], x6 + st1 {v0.h}[6], [x25], x6 + b Loop_C1_7_Relu +Loop_C1_7_Write: + cmp w13, #0 + beq End + sub w13, w13, #1 + ld1 {v0.8h}, [x1], #16 + fadd v0.8h, v0.8h, v16.8h + st1 {v0.4h}, [x22], x6 + st1 {v0.h}[4], [x23], x6 + st1 {v0.h}[5], [x24], x6 + st1 {v0.h}[6], [x25], x6 + b Loop_C1_7_Write + +End: + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/TiledC4MatmulFp16.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/TiledC4MatmulFp16.S new file mode 100644 index 0000000000000000000000000000000000000000..5b616ae75b2b5032e712048bed138771cda2190c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/TiledC4MatmulFp16.S @@ -0,0 +1,273 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +asm_function TiledC4MatmulFp16 + +sub sp, sp, #128 +st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp] +add x9, sp, #64 +st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9] + +mov x7, #2 //sizeof(float) +mul x3, x3, x7 +mov x7, #32 +mul x10, x4, x7 + +cmp x5, #2 +blt LoopOcHalf +LoopOc: + mov x8, x1 + subs x9, x4, #1 + + add x6, x2, x10 + ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [x8], #32 + ld1 {v8.4h, v9.4h, v10.4h, v11.4h}, [x2], #32 + fmul v16.4h, v8.4h, v0.h[0] + fmul v17.4h, v8.4h, v1.h[0] + ld1 {v4.4h, v5.4h, v6.4h, v7.4h}, [x8], #32 + fmul v18.4h, v8.4h, v2.h[0] + fmul v19.4h, v8.4h, v3.h[0] + ld1 {v12.4h, v13.4h, v14.4h, v15.4h}, [x6], #32 + fmul v20.4h, v8.4h, v4.h[0] + fmul v21.4h, v8.4h, v5.h[0] + fmul v22.4h, v8.4h, v6.h[0] + fmul v23.4h, v8.4h, v7.h[0] + fmul v24.4h, v12.4h, v0.h[0] + fmul v25.4h, v12.4h, v1.h[0] + fmul v26.4h, v12.4h, v2.h[0] + fmul v27.4h, v12.4h, v3.h[0] + fmul v28.4h, v12.4h, v4.h[0] + fmul v29.4h, v12.4h, v5.h[0] + fmul v30.4h, v12.4h, v6.h[0] + fmul v31.4h, v12.4h, v7.h[0] + + beq LoopIcEnd + LoopIc: + add x2, x2, #64 + prfm pldl1keep, [x2] + prfm pldl1keep, [x2, x10] + sub x2, x2, #64 + prfm pldl1keep, [x8, #64] + prfm pldl1keep, [x8, #96] + + fmla v16.4h, v9.4h, v0.h[1] + fmla v17.4h, v9.4h, v1.h[1] + fmla v18.4h, v9.4h, v2.h[1] + fmla v19.4h, v9.4h, v3.h[1] + fmla v20.4h, v9.4h, v4.h[1] + fmla v21.4h, v9.4h, v5.h[1] + fmla v22.4h, v9.4h, v6.h[1] + fmla v23.4h, v9.4h, v7.h[1] + fmla v24.4h, v13.4h, v0.h[1] + fmla v25.4h, v13.4h, v1.h[1] + fmla v26.4h, v13.4h, v2.h[1] + fmla v27.4h, v13.4h, v3.h[1] + fmla v28.4h, v13.4h, v4.h[1] + fmla v29.4h, v13.4h, v5.h[1] + fmla v30.4h, v13.4h, v6.h[1] + fmla v31.4h, v13.4h, v7.h[1] + + fmla v16.4h, v10.4h, v0.h[2] + fmla v17.4h, v10.4h, v1.h[2] + fmla v18.4h, v10.4h, v2.h[2] + fmla v19.4h, v10.4h, v3.h[2] + fmla v20.4h, v10.4h, v4.h[2] + fmla v21.4h, v10.4h, v5.h[2] + fmla v22.4h, v10.4h, v6.h[2] + fmla v23.4h, v10.4h, v7.h[2] + fmla v24.4h, v14.4h, v0.h[2] + fmla v25.4h, v14.4h, v1.h[2] + fmla v26.4h, v14.4h, v2.h[2] + fmla v27.4h, v14.4h, v3.h[2] + fmla v28.4h, v14.4h, v4.h[2] + fmla v29.4h, v14.4h, v5.h[2] + fmla v30.4h, v14.4h, v6.h[2] + fmla v31.4h, v14.4h, v7.h[2] + + fmla v16.4h, v11.4h, v0.h[3] + fmla v17.4h, v11.4h, v1.h[3] + fmla v18.4h, v11.4h, v2.h[3] + fmla v19.4h, v11.4h, v3.h[3] + fmla v20.4h, v11.4h, v4.h[3] + fmla v21.4h, v11.4h, v5.h[3] + fmla v22.4h, v11.4h, v6.h[3] + fmla v23.4h, v11.4h, v7.h[3] + fmla v24.4h, v15.4h, v0.h[3] + fmla v25.4h, v15.4h, v1.h[3] + fmla v26.4h, v15.4h, v2.h[3] + fmla v27.4h, v15.4h, v3.h[3] + fmla v28.4h, v15.4h, v4.h[3] + fmla v29.4h, v15.4h, v5.h[3] + fmla v30.4h, v15.4h, v6.h[3] + fmla v31.4h, v15.4h, v7.h[3] + + ld1 {v8.4h, v9.4h, v10.4h, v11.4h}, [x2], #32 + ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [x8], #32 + fmla v16.4h, v8.4h, v0.h[0] + fmla v17.4h, v8.4h, v1.h[0] + ld1 {v4.4h, v5.4h, v6.4h, v7.4h}, [x8], #32 + fmla v18.4h, v8.4h, v2.h[0] + fmla v19.4h, v8.4h, v3.h[0] + ld1 {v12.4h, v13.4h, v14.4h, v15.4h}, [x6], #32 + fmla v20.4h, v8.4h, v4.h[0] + fmla v21.4h, v8.4h, v5.h[0] + fmla v22.4h, v8.4h, v6.h[0] + fmla v23.4h, v8.4h, v7.h[0] + fmla v24.4h, v12.4h, v0.h[0] + fmla v25.4h, v12.4h, v1.h[0] + fmla v26.4h, v12.4h, v2.h[0] + fmla v27.4h, v12.4h, v3.h[0] + fmla v28.4h, v12.4h, v4.h[0] + fmla v29.4h, v12.4h, v5.h[0] + fmla v30.4h, v12.4h, v6.h[0] + fmla v31.4h, v12.4h, v7.h[0] + + subs x9, x9, #1 + bne LoopIc + + LoopIcEnd: + fmla v16.4h, v9.4h, v0.h[1] + fmla v17.4h, v9.4h, v1.h[1] + fmla v18.4h, v9.4h, v2.h[1] + fmla v19.4h, v9.4h, v3.h[1] + fmla v20.4h, v9.4h, v4.h[1] + fmla v21.4h, v9.4h, v5.h[1] + fmla v22.4h, v9.4h, v6.h[1] + fmla v23.4h, v9.4h, v7.h[1] + fmla v24.4h, v13.4h, v0.h[1] + fmla v25.4h, v13.4h, v1.h[1] + fmla v26.4h, v13.4h, v2.h[1] + fmla v27.4h, v13.4h, v3.h[1] + fmla v28.4h, v13.4h, v4.h[1] + fmla v29.4h, v13.4h, v5.h[1] + fmla v30.4h, v13.4h, v6.h[1] + fmla v31.4h, v13.4h, v7.h[1] + + fmla v16.4h, v10.4h, v0.h[2] + fmla v17.4h, v10.4h, v1.h[2] + fmla v18.4h, v10.4h, v2.h[2] + fmla v19.4h, v10.4h, v3.h[2] + fmla v20.4h, v10.4h, v4.h[2] + fmla v21.4h, v10.4h, v5.h[2] + fmla v22.4h, v10.4h, v6.h[2] + fmla v23.4h, v10.4h, v7.h[2] + fmla v24.4h, v14.4h, v0.h[2] + fmla v25.4h, v14.4h, v1.h[2] + fmla v26.4h, v14.4h, v2.h[2] + fmla v27.4h, v14.4h, v3.h[2] + fmla v28.4h, v14.4h, v4.h[2] + fmla v29.4h, v14.4h, v5.h[2] + fmla v30.4h, v14.4h, v6.h[2] + fmla v31.4h, v14.4h, v7.h[2] + + add x7, x0, #32 + + fmla v16.4h, v11.4h, v0.h[3] + fmla v17.4h, v11.4h, v1.h[3] + fmla v18.4h, v11.4h, v2.h[3] + fmla v19.4h, v11.4h, v3.h[3] + fmla v20.4h, v11.4h, v4.h[3] + fmla v21.4h, v11.4h, v5.h[3] + fmla v22.4h, v11.4h, v6.h[3] + fmla v23.4h, v11.4h, v7.h[3] + fmla v24.4h, v15.4h, v0.h[3] + fmla v25.4h, v15.4h, v1.h[3] + fmla v26.4h, v15.4h, v2.h[3] + fmla v27.4h, v15.4h, v3.h[3] + fmla v28.4h, v15.4h, v4.h[3] + st1 {v16.4h, v17.4h, v18.4h, v19.4h}, [x0], x3 + fmla v29.4h, v15.4h, v5.h[3] + st1 {v20.4h, v21.4h, v22.4h, v23.4h}, [x7], x3 + fmla v30.4h, v15.4h, v6.h[3] + st1 {v24.4h, v25.4h, v26.4h, v27.4h}, [x0], x3 + mov x2, x6 + fmla v31.4h, v15.4h, v7.h[3] + st1 {v28.4h, v29.4h, v30.4h, v31.4h}, [x7] + + subs x5, x5, #2 + beq LoopOcEnd + cmp x5, #2 + bge LoopOc + +LoopOcHalf: + mov x8, x1 + mov x9, x4 + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + + LoopIcHalf: + ld1 {v8.4h, v9.4h, v10.4h, v11.4h}, [x2], #32 + ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [x8], #32 + fmla v16.4h, v8.4h, v0.h[0] + fmla v17.4h, v8.4h, v1.h[0] + ld1 {v4.4h, v5.4h, v6.4h, v7.4h}, [x8], #32 + fmla v18.4h, v8.4h, v2.h[0] + fmla v19.4h, v8.4h, v3.h[0] + fmla v20.4h, v8.4h, v4.h[0] + fmla v21.4h, v8.4h, v5.h[0] + fmla v22.4h, v8.4h, v6.h[0] + fmla v23.4h, v8.4h, v7.h[0] + + fmla v16.4h, v9.4h, v0.h[1] + fmla v17.4h, v9.4h, v1.h[1] + fmla v18.4h, v9.4h, v2.h[1] + fmla v19.4h, v9.4h, v3.h[1] + fmla v20.4h, v9.4h, v4.h[1] + fmla v21.4h, v9.4h, v5.h[1] + fmla v22.4h, v9.4h, v6.h[1] + fmla v23.4h, v9.4h, v7.h[1] + + fmla v16.4h, v10.4h, v0.h[2] + fmla v17.4h, v10.4h, v1.h[2] + fmla v18.4h, v10.4h, v2.h[2] + fmla v19.4h, v10.4h, v3.h[2] + fmla v20.4h, v10.4h, v4.h[2] + fmla v21.4h, v10.4h, v5.h[2] + fmla v22.4h, v10.4h, v6.h[2] + fmla v23.4h, v10.4h, v7.h[2] + + fmla v16.4h, v11.4h, v0.h[3] + fmla v17.4h, v11.4h, v1.h[3] + fmla v18.4h, v11.4h, v2.h[3] + fmla v19.4h, v11.4h, v3.h[3] + fmla v20.4h, v11.4h, v4.h[3] + fmla v21.4h, v11.4h, v5.h[3] + fmla v22.4h, v11.4h, v6.h[3] + fmla v23.4h, v11.4h, v7.h[3] + + subs x9, x9, #1 + bne LoopIcHalf + + st1 {v16.4h, v17.4h, v18.4h, v19.4h}, [x0], #32 + st1 {v20.4h, v21.4h, v22.4h, v23.4h}, [x0], #32 + +LoopOcEnd: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ret + +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/VecMatmulFp16.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/VecMatmulFp16.S new file mode 100644 index 0000000000000000000000000000000000000000..0df891d3141abf9f1c455234b269d359e8cff934 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/VecMatmulFp16.S @@ -0,0 +1,181 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +// void VecMatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, +// int depth, int col) +// x0: a +// x1: b +// x2: c +// x3: bias +// w4: act_type +// w5: depth +// w6: col + +asm_function VecMatmulFp16Neon64_2 + sub sp, sp, #128 + st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp] + add x9, sp, #64 + st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x9] + +LoopCol: + mov x15, x0 // reload a ptr + ld1 {v0.8h}, [x3], #16 // acc0 + ld1 {v1.8h}, [x3], #16 // acc1 + mov w9, #0 // tmp depth + +Loop2x8Inner: + sub w18, w5, w9 + cmp w18, #8 + blt DepthRemain + + ld1 {v2.8h}, [x15], #16 // a + ld1 {v3.8h, v4.8h, v5.8h, v6.8h}, [x1], #64 + ld1 {v7.8h, v8.8h, v9.8h, v10.8h}, [x1], #64 + ld1 {v11.8h, v12.8h, v13.8h, v14.8h}, [x1], #64 + ld1 {v15.8h, v16.8h, v17.8h, v18.8h}, [x1], #64 + + fmla v0.8h, v3.8h, v2.h[0] + fmla v0.8h, v5.8h, v2.h[1] + fmla v0.8h, v7.8h, v2.h[2] + fmla v0.8h, v9.8h, v2.h[3] + fmla v0.8h, v11.8h, v2.h[4] + fmla v0.8h, v13.8h, v2.h[5] + fmla v0.8h, v15.8h, v2.h[6] + fmla v0.8h, v17.8h, v2.h[7] + fmla v1.8h, v4.8h, v2.h[0] + fmla v1.8h, v6.8h, v2.h[1] + fmla v1.8h, v8.8h, v2.h[2] + fmla v1.8h, v10.8h, v2.h[3] + fmla v1.8h, v12.8h, v2.h[4] + fmla v1.8h, v14.8h, v2.h[5] + fmla v1.8h, v16.8h, v2.h[6] + fmla v1.8h, v18.8h, v2.h[7] + + add w9, w9, #8 + b Loop2x8Inner + +DepthRemain: // last depth [0, 8) + cmp w18, #0 + ble Act + ld1 {v2.h}[0], [x15], #2 + ld1 {v3.8h}, [x1], #16 + ld1 {v4.8h}, [x1], #16 + fmla v0.8h, v3.8h, v2.h[0] + fmla v1.8h, v4.8h, v2.h[0] + sub w18, w18, #1 + b DepthRemain + +Act: + cmp w4, #3 + beq Relu6 + cmp w4, #1 + beq Relu + b Write + +Relu6: + movi v19.8h, #0x46, lsl #8 + fmin v0.8h, v0.8h, v19.8h + fmin v1.8h, v1.8h, v19.8h + +Relu: + dup v20.8h, wzr + fmax v0.8h, v0.8h, v20.8h + fmax v1.8h, v1.8h, v20.8h + +Write: + cmp w6, #8 + blt WriteMod8 + st1 {v0.8h}, [x2], #16 + sub w6, w6, #8 + mov v0.16b, v1.16b + cmp w6, #8 + blt WriteMod8 + st1 {v1.8h}, [x2], #16 + sub w6, w6, #8 + cbz w6, End + b LoopCol + +WriteMod8: + cmp w6, #0 + ble End + cmp w6, #1 + beq Write1 + cmp w6, #2 + beq Write2 + cmp w6, #3 + beq Write3 + cmp w6, #4 + beq Write4 + cmp w6, #5 + beq Write5 + cmp w6, #6 + beq Write6 + cmp w6, #7 + beq Write7 + +Write1: + st1 {v0.h}[0], [x2], #2 + b End +Write2: + st1 {v0.h}[0], [x2], #2 + st1 {v0.h}[1], [x2], #2 + b End +Write3: + st1 {v0.h}[0], [x2], #2 + st1 {v0.h}[1], [x2], #2 + st1 {v0.h}[2], [x2], #2 + b End +Write4: + st1 {v0.h}[0], [x2], #2 + st1 {v0.h}[1], [x2], #2 + st1 {v0.h}[2], [x2], #2 + st1 {v0.h}[3], [x2], #2 + b End +Write5: + st1 {v0.h}[0], [x2], #2 + st1 {v0.h}[1], [x2], #2 + st1 {v0.h}[2], [x2], #2 + st1 {v0.h}[3], [x2], #2 + st1 {v0.h}[4], [x2], #2 + b End +Write6: + st1 {v0.h}[0], [x2], #2 + st1 {v0.h}[1], [x2], #2 + st1 {v0.h}[2], [x2], #2 + st1 {v0.h}[3], [x2], #2 + st1 {v0.h}[4], [x2], #2 + st1 {v0.h}[5], [x2], #2 + b End +Write7: + st1 {v0.h}[0], [x2], #2 + st1 {v0.h}[1], [x2], #2 + st1 {v0.h}[2], [x2], #2 + st1 {v0.h}[3], [x2], #2 + st1 {v0.h}[4], [x2], #2 + st1 {v0.h}[5], [x2], #2 + st1 {v0.h}[6], [x2], #2 + b End + +End: + ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp], #64 + ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [sp], #64 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/WinogradTransLeftFp16.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/WinogradTransLeftFp16.S new file mode 100644 index 0000000000000000000000000000000000000000..c9b4104e3810d9a0d13ba61e4a3813b01f0f99a5 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/WinogradTransLeftFp16.S @@ -0,0 +1,150 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +asm_function WinogradTransLeftFp16 + +sub sp, sp, #16 +stp x19, x20, [sp] + +mov x8, #8 // 4 * sizeof(float16) +mul x8, x6, x8 +mul x9, x3, x8 +sub x9, x9, x8 +add x7, x9, x8 // step for S +mov x10, #2 +mul x10, x4, x10 // step for B + +LoopH: + mov x13, x0 + mov x15, x3 + LoopW: + mov x14, x13 + mov x17, x1 + dup v30.4h, wzr + mov x11, x6 + InitZero: + st1 {v30.4h}, [x2], #8 + subs x11, x11, #1 + bne InitZero + + sub x2, x2, x8 + mov x12, x5 + LoopKStart4: + cmp x12, #4 + blt LoopKStart3 + mov x16, x15 + mov x19, x4 + LoopK4: + ld1 {v0.h}[0], [x17], x10 + ld1 {v0.h}[1], [x17], x10 + ld1 {v0.h}[2], [x17], x10 + ld1 {v0.h}[3], [x17], x10 + mov x11, x6 + mov x20, x17 + add x20, x14, x7 + add x16, x20, x7 + add x19, x16, x7 + + LoopLength4: + ld1 {v16.4h}, [x2] + ld1 {v20.4h}, [x14], #8 + fmla v16.4h, v20.4h, v0.h[0] + ld1 {v21.4h}, [x20], #8 + fmul v17.4h, v21.4h, v0.h[1] + ld1 {v20.4h}, [x16], #8 + fmla v16.4h, v20.4h, v0.h[2] + ld1 {v21.4h}, [x19], #8 + fmla v17.4h, v21.4h, v0.h[3] + fadd v17.4h, v16.4h, v17.4h + st1 {v17.4h}, [x2], #8 + subs x11, x11, #1 + bne LoopLength4 + + sub x2, x2, x8 + sub x12, x12, #4 + add x14, x19, x9 + cmp x12, #4 + bge LoopK4 + + LoopKStart3: + cmp x12, #3 + blt LoopKStart + mov x16, x15 + LoopK3: + ld1 {v0.h}[0], [x17], x10 + ld1 {v0.h}[1], [x17], x10 + ld1 {v0.h}[2], [x17], x10 + mov x11, x6 + mov x20, x17 + add x20, x14, x7 + add x16, x20, x7 + LoopLength3: + ld1 {v16.4h}, [x2] + ld1 {v20.4h}, [x14], #8 + fmla v16.4h, v20.4h, v0.h[0] + ld1 {v21.4h}, [x20], #8 + fmul v17.4h, v21.4h, v0.h[1] + ld1 {v20.4h}, [x16], #8 + fmla v16.4h, v20.4h, v0.h[2] + fadd v17.4h, v16.4h, v17.4h + st1 {v17.4h}, [x2], #8 + subs x11, x11, #1 + bne LoopLength3 + + sub x2, x2, x8 + sub x12, x12, #3 + add x14, x16, x9 + cmp x12, #3 + bge LoopK3 + + LoopKStart: + cmp x12, #0 + beq LKEnd + LoopK: + ld1r {v31.4h}, [x17], x10 + mov x11, x6 + LoopLength: + ld1 {v0.4h}, [x2] + ld1 {v1.4h}, [x14], #8 + fmla v0.4h, v1.4h, v31.4h + st1 {v0.4h}, [x2], #8 + subs x11, x11, #1 + bne LoopLength + + subs x12, x12, #1 + sub x2, x2, x8 + add x14, x14, x9 + bne LoopK + + LKEnd: + subs x15, x15, #1 + add x13, x13, x8 + add x2, x2, x8 + bne LoopW + + add x1, x1, #2 //sizeof(float) + subs x4, x4, #1 + bne LoopH + + ldp x19, x20, [sp], #16 + ret + +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/WinogradTransRightFp16.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/WinogradTransRightFp16.S new file mode 100644 index 0000000000000000000000000000000000000000..46c3cd84e96691fea27666cd348a5578c0a4b0ae --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/fp16/WinogradTransRightFp16.S @@ -0,0 +1,154 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" + +.text +.align 5 + +asm_function WinogradTransRightFp16 + +sub sp, sp, #16 +stp x19, x20, [sp] + +mov x8, #8 // 4 * sizeof(float16) +mul x8, x6, x8 +mul x9, x5, x8 // step for S +mov x10, #2 +mul x10, x4, x10 // step for B + +LoopH: + mov x7, x1 + mov x15, x3 + LoopW: + mov x17, x0 + mov x13, x7 + dup v30.4h, wzr + mov x11, x6 + InitZero: + st1 {v30.4h}, [x2], #8 + subs x11, x11, #1 + bne InitZero + sub x2, x2, x8 + mov x12, x5 + + LoopKStart4: + cmp x12, #4 + blt LoopKStart3 + mov x16, x15 + mov x19, x4 + LoopK4: + ld1 {v0.h}[0], [x13], x10 + ld1 {v0.h}[1], [x13], x10 + ld1 {v0.h}[2], [x13], x10 + ld1 {v0.h}[3], [x13], x10 + mov x11, x6 + mov x14, x13 + + add x14, x17, x8 + add x16, x14, x8 + add x19, x16, x8 + + LoopLength4: + ld1 {v16.4h}, [x2] + ld1 {v20.4h}, [x17], #8 + fmla v16.4h, v20.4h, v0.h[0] + ld1 {v21.4h}, [x14], #8 + fmul v17.4h, v21.4h, v0.h[1] + ld1 {v20.4h}, [x16], #8 + fmla v16.4h, v20.4h, v0.h[2] + ld1 {v21.4h}, [x19], #8 + fmla v17.4h, v21.4h, v0.h[3] + + fadd v17.4h, v16.4h, v17.4h + st1 {v17.4h}, [x2], #8 + subs x11, x11, #1 + bne LoopLength4 + sub x2, x2, x8 + sub x12, x12, #4 + mov x17, x19 + + cmp x12, #4 + bge LoopK4 + + LoopKStart3: + cmp x12, #3 + blt LoopKStart + mov x16, x15 + LoopK3: + ld1 {v0.h}[0], [x13], x10 + ld1 {v0.h}[1], [x13], x10 + ld1 {v0.h}[2], [x13], x10 + mov x11, x6 + mov x14, x13 + + add x14, x17, x8 + add x16, x14, x8 + + LoopLength3: + ld1 {v16.4h}, [x2] + ld1 {v20.4h}, [x17], #8 + fmla v16.4h, v20.4h, v0.h[0] + ld1 {v21.4h}, [x14], #8 + fmul v17.4h, v21.4h, v0.h[1] + ld1 {v20.4h}, [x16], #8 + fmla v16.4h, v20.4h, v0.h[2] + + fadd v17.4h, v16.4h, v17.4h + st1 {v17.4h}, [x2], #8 + subs x11, x11, #1 + bne LoopLength3 + sub x2, x2, x8 + sub x12, x12, #3 + mov x17, x19 + cmp x12, #3 + bge LoopK3 + + LoopKStart: + cmp x12, #0 + beq LoopKEnd + + LoopK: + ld1r {v31.4h}, [x13], x10 + + mov x11, x6 + LoopLength: + ld1 {v0.4h}, [x2] + ld1 {v1.4h}, [x17], #8 + fmla v0.4h, v1.4h, v31.4h + + st1 {v0.4h}, [x2], #8 + subs x11, x11, #1 + bne LoopLength + subs x12, x12, #1 + + sub x2, x2, x8 + bne LoopK + LoopKEnd: + subs x15, x15, #1 + add x2, x2, x8 + add x7, x7, #2 + bne LoopW + + add x0, x0, x9 + subs x4, x4, #1 + bne LoopH + + ldp x19, x20, [sp], #16 + + ret + +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/opt/DynamicMatmulSdot4x4x16AIWI.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/opt/DynamicMatmulSdot4x4x16AIWI.S new file mode 100644 index 0000000000000000000000000000000000000000..4319a9279dfdf0d5f03a50aa03e5a23afc2bd948 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/opt/DynamicMatmulSdot4x4x16AIWI.S @@ -0,0 +1,764 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" +.text +.align 5 + +// void DynamicMatmulSdot4x4x16AIWI(const int8_t *a, const int8_t *b, float *out, size_t deep4, float *multi_scales, +// float *bias, size_t row, size_t col, size_t stride, const int *a_sums, +// const int *b_sums, int64_t a_zp, int64_t b_zp_sum, int64_t act_type, int64_t mode); +// x0: a(left matrix ptr) +// x1: b(right matrix ptr) +// x2: out ptr +// x3: deep +// x4: multi_scales +// x5: bias +// x6: row +// x7: col +// x8: stride +// x9: a_sums +// x10: b_sums +// x19/w19: a_zp +// x19/w20: b_zp_sum +// x21: act_type -> 0: none, 1:Relu, 3:Relu6 + +asm_function DynamicMatmulSdot4x4x16AIWI + sub sp, sp, #160 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + stp x19, x20, [sp], #16 + stp x21, x22, [sp], #16 + + ldr x8, [sp] + ldr x9, [sp, #8] + ldr x10, [sp, #16] + ldr x19, [sp, #24] + ldr x20, [sp, #32] + ldr x21, [sp, #40] + ldr x22, [sp, #48] + + dup v16.4s, wzr // dup:Duplicate general-purpose register to vector. + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr + + mov x11, x1 // reload rhs ptr + mov x17, x0 // reload lhs ptr + mov x16, x3 // reload depth + + cmp x7, #4 + ble LoopDepthQuarter + cmp x7, #8 + ble LoopDepthHalf + +LoopDepth: + ld1 {v0.16b}, [x17], #16 + ld1 {v1.16b, v2.16b, v3.16b, v4.16b}, [x11], #64 + + sdot v16.4s, v1.16b, v0.4b[0] + sdot v17.4s, v2.16b, v0.4b[0] + sdot v18.4s, v3.16b, v0.4b[0] + sdot v19.4s, v4.16b, v0.4b[0] + sdot v20.4s, v1.16b, v0.4b[1] + sdot v21.4s, v2.16b, v0.4b[1] + sdot v22.4s, v3.16b, v0.4b[1] + sdot v23.4s, v4.16b, v0.4b[1] + sdot v24.4s, v1.16b, v0.4b[2] + sdot v25.4s, v2.16b, v0.4b[2] + sdot v26.4s, v3.16b, v0.4b[2] + sdot v27.4s, v4.16b, v0.4b[2] + sdot v28.4s, v1.16b, v0.4b[3] + sdot v29.4s, v2.16b, v0.4b[3] + sdot v30.4s, v3.16b, v0.4b[3] + sdot v31.4s, v4.16b, v0.4b[3] + + subs x16, x16, #4 + bgt LoopDepth + b AddInputSum + +LoopDepthHalf: + ld1 {v0.16b}, [x17], #16 + ld1 {v1.16b, v2.16b}, [x11] + add x11, x11, #64 + sdot v16.4s, v1.16b, v0.4b[0] + sdot v17.4s, v2.16b, v0.4b[0] + sdot v20.4s, v1.16b, v0.4b[1] + sdot v21.4s, v2.16b, v0.4b[1] + sdot v24.4s, v1.16b, v0.4b[2] + sdot v25.4s, v2.16b, v0.4b[2] + sdot v28.4s, v1.16b, v0.4b[3] + sdot v29.4s, v2.16b, v0.4b[3] + + subs x16, x16, #4 + bgt LoopDepthHalf + b AddInputSum + +LoopDepthQuarter: + ld1 {v0.16b}, [x17], #16 + ld1 {v1.16b}, [x11] + add x11, x11, #64 + sdot v16.4s, v1.16b, v0.4b[0] + sdot v20.4s, v1.16b, v0.4b[1] + sdot v24.4s, v1.16b, v0.4b[2] + sdot v28.4s, v1.16b, v0.4b[3] + + subs x16, x16, #4 + bgt LoopDepthQuarter + b AddInputSum + +AddInputSum: + cmp w20, #0 + beq AddInputSumEnd + ld1 {v5.4s}, [x9], #16 + dup v6.4s, v5.s[0] + dup v7.4s, v5.s[1] + dup v8.4s, v5.s[2] + dup v9.4s, v5.s[3] + + sub v16.4s, v16.4s, v6.4s + sub v17.4s, v17.4s, v6.4s + sub v18.4s, v18.4s, v6.4s + sub v19.4s, v19.4s, v6.4s + sub v20.4s, v20.4s, v7.4s + sub v21.4s, v21.4s, v7.4s + sub v22.4s, v22.4s, v7.4s + sub v23.4s, v23.4s, v7.4s + sub v24.4s, v24.4s, v8.4s + sub v25.4s, v25.4s, v8.4s + sub v26.4s, v26.4s, v8.4s + sub v27.4s, v27.4s, v8.4s + sub v28.4s, v28.4s, v9.4s + sub v29.4s, v29.4s, v9.4s + sub v30.4s, v30.4s, v9.4s + sub v31.4s, v31.4s, v9.4s +AddInputSumEnd: + +AddWeightSum: + ld1 {v9.4s}, [x10], #16 + ld1 {v10.4s}, [x10], #16 + ld1 {v11.4s}, [x10], #16 + ld1 {v12.4s}, [x10], #16 + dup v13.4s, w19 + mul v9.4s, v9.4s, v13.4s + mul v10.4s, v10.4s, v13.4s + mul v11.4s, v11.4s, v13.4s + mul v12.4s, v12.4s, v13.4s + sub v16.4s, v16.4s, v9.4s + sub v17.4s, v17.4s, v10.4s + sub v18.4s, v18.4s, v11.4s + sub v19.4s, v19.4s, v12.4s + sub v20.4s, v20.4s, v9.4s + sub v21.4s, v21.4s, v10.4s + sub v22.4s, v22.4s, v11.4s + sub v23.4s, v23.4s, v12.4s + sub v24.4s, v24.4s, v9.4s + sub v25.4s, v25.4s, v10.4s + sub v26.4s, v26.4s, v11.4s + sub v27.4s, v27.4s, v12.4s + sub v28.4s, v28.4s, v9.4s + sub v29.4s, v29.4s, v10.4s + sub v30.4s, v30.4s, v11.4s + sub v31.4s, v31.4s, v12.4s + +AddZpSum: + mul w15, w19, w20 + cmp w15, #0 + beq AddZpSumEnd + dup v14.4s, w15 + add v16.4s, v16.4s, v14.4s + add v17.4s, v17.4s, v14.4s + add v18.4s, v18.4s, v14.4s + add v19.4s, v19.4s, v14.4s + add v20.4s, v20.4s, v14.4s + add v21.4s, v21.4s, v14.4s + add v22.4s, v22.4s, v14.4s + add v23.4s, v23.4s, v14.4s + add v24.4s, v24.4s, v14.4s + add v25.4s, v25.4s, v14.4s + add v26.4s, v26.4s, v14.4s + add v27.4s, v27.4s, v14.4s + add v28.4s, v28.4s, v14.4s + add v29.4s, v29.4s, v14.4s + add v30.4s, v30.4s, v14.4s + add v31.4s, v31.4s, v14.4s +AddZpSumEnd: + +Convert2Float: + scvtf v16.4s, v16.4s + scvtf v17.4s, v17.4s + scvtf v18.4s, v18.4s + scvtf v19.4s, v19.4s + scvtf v20.4s, v20.4s + scvtf v21.4s, v21.4s + scvtf v22.4s, v22.4s + scvtf v23.4s, v23.4s + scvtf v24.4s, v24.4s + scvtf v25.4s, v25.4s + scvtf v26.4s, v26.4s + scvtf v27.4s, v27.4s + scvtf v28.4s, v28.4s + scvtf v29.4s, v29.4s + scvtf v30.4s, v30.4s + scvtf v31.4s, v31.4s + +MultiplyScale: + // multi_scale * input_matrix + cbz x22, TensorXTensor + cmp x22, #1 + beq TensorXChannel + cmp x22, #2 + beq ChannelXTensor + ChannelXChannel: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x4], #64 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x4], #64 + + fmul v16.4s,v16.4s,v0.4s + fmul v17.4s,v17.4s,v1.4s + fmul v18.4s,v18.4s,v2.4s + fmul v19.4s,v19.4s,v3.4s + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x4], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x4] + + fmul v20.4s,v20.4s,v4.4s + fmul v21.4s,v21.4s,v5.4s + fmul v22.4s,v22.4s,v6.4s + fmul v23.4s,v23.4s,v7.4s + + fmul v24.4s,v24.4s,v8.4s + fmul v25.4s,v25.4s,v9.4s + fmul v26.4s,v26.4s,v10.4s + fmul v27.4s,v27.4s,v11.4s + + fmul v28.4s,v28.4s,v12.4s + fmul v29.4s,v29.4s,v13.4s + fmul v30.4s,v30.4s,v14.4s + fmul v31.4s,v31.4s,v15.4s + b AddBias + + TensorXTensor: + ld1 {v0.s}[0], [x4] + + fmul v16.4s,v16.4s,v0.s[0] + fmul v17.4s,v17.4s,v0.s[0] + fmul v18.4s,v18.4s,v0.s[0] + fmul v19.4s,v19.4s,v0.s[0] + + fmul v20.4s,v20.4s,v0.s[0] + fmul v21.4s,v21.4s,v0.s[0] + fmul v22.4s,v22.4s,v0.s[0] + fmul v23.4s,v23.4s,v0.s[0] + + fmul v24.4s,v24.4s,v0.s[0] + fmul v25.4s,v25.4s,v0.s[0] + fmul v26.4s,v26.4s,v0.s[0] + fmul v27.4s,v27.4s,v0.s[0] + + fmul v28.4s,v28.4s,v0.s[0] + fmul v29.4s,v29.4s,v0.s[0] + fmul v30.4s,v30.4s,v0.s[0] + fmul v31.4s,v31.4s,v0.s[0] + b AddBias + + TensorXChannel: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x4] + + fmul v16.4s,v16.4s,v0.4s + fmul v17.4s,v17.4s,v1.4s + fmul v18.4s,v18.4s,v2.4s + fmul v19.4s,v19.4s,v3.4s + + fmul v20.4s,v20.4s,v0.4s + fmul v21.4s,v21.4s,v1.4s + fmul v22.4s,v22.4s,v2.4s + fmul v23.4s,v23.4s,v3.4s + + fmul v24.4s,v24.4s,v0.4s + fmul v25.4s,v25.4s,v1.4s + fmul v26.4s,v26.4s,v2.4s + fmul v27.4s,v27.4s,v3.4s + + fmul v28.4s,v28.4s,v0.4s + fmul v29.4s,v29.4s,v1.4s + fmul v30.4s,v30.4s,v2.4s + fmul v31.4s,v31.4s,v3.4s + b AddBias + + ChannelXTensor: + ld1 {v0.4s}, [x4] + fmul v16.4s,v16.4s,v0.s[0] + fmul v17.4s,v17.4s,v0.s[0] + fmul v18.4s,v18.4s,v0.s[0] + fmul v19.4s,v19.4s,v0.s[0] + + fmul v20.4s,v20.4s,v0.s[1] + fmul v21.4s,v21.4s,v0.s[1] + fmul v22.4s,v22.4s,v0.s[1] + fmul v23.4s,v23.4s,v0.s[1] + + fmul v24.4s,v24.4s,v0.s[2] + fmul v25.4s,v25.4s,v0.s[2] + fmul v26.4s,v26.4s,v0.s[2] + fmul v27.4s,v27.4s,v0.s[2] + + fmul v28.4s,v28.4s,v0.s[3] + fmul v29.4s,v29.4s,v0.s[3] + fmul v30.4s,v30.4s,v0.s[3] + fmul v31.4s,v31.4s,v0.s[3] +AddBias: + // +bias + cbz x5, StoreData + ld1 {v1.4s, v2.4s, v3.4s, v4.4s}, [x5] + + fadd v16.4s,v16.4s,v1.4s + fadd v17.4s,v17.4s,v2.4s + fadd v18.4s,v18.4s,v3.4s + fadd v19.4s,v19.4s,v4.4s + + fadd v20.4s,v20.4s,v1.4s + fadd v21.4s,v21.4s,v2.4s + fadd v22.4s,v22.4s,v3.4s + fadd v23.4s,v23.4s,v4.4s + + fadd v24.4s,v24.4s,v1.4s + fadd v25.4s,v25.4s,v2.4s + fadd v26.4s,v26.4s,v3.4s + fadd v27.4s,v27.4s,v4.4s + + fadd v28.4s,v28.4s,v1.4s + fadd v29.4s,v29.4s,v2.4s + fadd v30.4s,v30.4s,v3.4s + fadd v31.4s,v31.4s,v4.4s + +Activate: + cmp x21, #1 + beq Relu + cmp x21, #3 + beq Relu6 + b StoreData + +Relu: + dup v1.4s, wzr + + smax v16.4s,v16.4s,v1.4s + smax v17.4s,v17.4s,v1.4s + smax v18.4s,v18.4s,v1.4s + smax v19.4s,v19.4s,v1.4s + + smax v20.4s,v20.4s,v1.4s + smax v21.4s,v21.4s,v1.4s + smax v22.4s,v22.4s,v1.4s + smax v23.4s,v23.4s,v1.4s + + smax v24.4s,v24.4s,v1.4s + smax v25.4s,v25.4s,v1.4s + smax v26.4s,v26.4s,v1.4s + smax v27.4s,v27.4s,v1.4s + + smax v28.4s,v28.4s,v1.4s + smax v29.4s,v29.4s,v1.4s + smax v30.4s,v30.4s,v1.4s + smax v31.4s,v31.4s,v1.4s + + b StoreData + +Relu6: + dup v1.4s, wzr + movi v2.4s, #6 + scvtf v2.4s, v2.4s + + // max (out, 0) + smax v16.4s,v16.4s,v1.4s + smax v17.4s,v17.4s,v1.4s + smax v18.4s,v18.4s,v1.4s + smax v19.4s,v19.4s,v1.4s + + smax v20.4s,v20.4s,v1.4s + smax v21.4s,v21.4s,v1.4s + smax v22.4s,v22.4s,v1.4s + smax v23.4s,v23.4s,v1.4s + + smax v24.4s,v24.4s,v1.4s + smax v25.4s,v25.4s,v1.4s + smax v26.4s,v26.4s,v1.4s + smax v27.4s,v27.4s,v1.4s + + smax v28.4s,v28.4s,v1.4s + smax v29.4s,v29.4s,v1.4s + smax v30.4s,v30.4s,v1.4s + smax v31.4s,v31.4s,v1.4s + + // min (out, 6) + + smin v16.4s,v16.4s,v2.4s + smin v17.4s,v17.4s,v2.4s + smin v18.4s,v18.4s,v2.4s + smin v19.4s,v19.4s,v2.4s + + smin v20.4s,v20.4s,v2.4s + smin v21.4s,v21.4s,v2.4s + smin v22.4s,v22.4s,v2.4s + smin v23.4s,v23.4s,v2.4s + + smin v24.4s,v24.4s,v2.4s + smin v25.4s,v25.4s,v2.4s + smin v26.4s,v26.4s,v2.4s + smin v27.4s,v27.4s,v2.4s + + smin v28.4s,v28.4s,v2.4s + smin v29.4s,v29.4s,v2.4s + smin v30.4s,v30.4s,v2.4s + smin v31.4s,v31.4s,v2.4s + + b StoreData + +StoreData: + cmp x7, #16 + beq Write16 + + mov x15, x2 // reload out ptr + add x14, x15, x8 + add x13, x14, x8 + add x12, x13, x8 + + cmp x7, #15 + beq Write15 + cmp x7, #14 + beq Write14 + cmp x7, #13 + beq Write13 + cmp x7, #12 + beq Write12 + cmp x7, #11 + beq Write11 + cmp x7, #10 + beq Write10 + cmp x7, #9 + beq Write9 + cmp x7, #8 + beq Write8 + cmp x7, #7 + beq Write7 + cmp x7, #6 + beq Write6 + cmp x7, #5 + beq Write5 + cmp x7, #4 + beq Write4 + cmp x7, #3 + beq Write3 + cmp x7, #2 + beq Write2 + cmp x7, #1 + beq Write1 + b StoreDataEnd + +Write16: + cmp x6, #4 + beq Write16Row4 + cmp x6, #3 + beq Write16Row3 + cmp x6, #2 + beq Write16Row2 + cmp x6, #1 + beq Write16Row1 + + Write16Row4: + st1 {v16.4s,v17.4s,v18.4s,v19.4s}, [x2], x8 + st1 {v20.4s,v21.4s,v22.4s,v23.4s}, [x2], x8 + st1 {v24.4s,v25.4s,v26.4s,v27.4s}, [x2], x8 + st1 {v28.4s,v29.4s,v30.4s,v31.4s}, [x2] + b StoreDataEnd + Write16Row3: + st1 {v16.4s,v17.4s,v18.4s,v19.4s}, [x2], x8 + st1 {v20.4s,v21.4s,v22.4s,v23.4s}, [x2], x8 + st1 {v24.4s,v25.4s,v26.4s,v27.4s}, [x2] + b StoreDataEnd + Write16Row2: + st1 {v16.4s,v17.4s,v18.4s,v19.4s}, [x2], x8 + st1 {v20.4s,v21.4s,v22.4s,v23.4s}, [x2] + b StoreDataEnd + Write16Row1: + st1 {v16.4s,v17.4s,v18.4s,v19.4s}, [x2] + b StoreDataEnd + +Write15: + st1 {v16.4s,v17.4s,v18.4s}, [x15], #48 + st1 {v19.1d}, [x15], #8 + st1 {v19.s}[2], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4s,v21.4s,v22.4s}, [x14], #48 + st1 {v23.1d}, [x14], #8 + st1 {v23.s}[2], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4s,v25.4s,v26.4s}, [x13], #48 + st1 {v27.1d}, [x13], #8 + st1 {v27.s}[2], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4s,v29.4s,v30.4s}, [x12], #48 + st1 {v31.1d}, [x12], #8 + st1 {v31.s}[2], [x12] + b StoreDataEnd + +Write14: + st1 {v16.4s,v17.4s,v18.4s}, [x15], #48 + st1 {v19.1d}, [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4s,v21.4s,v22.4s}, [x14], #48 + st1 {v23.1d}, [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4s,v25.4s,v26.4s}, [x13], #48 + st1 {v27.1d}, [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4s,v29.4s,v30.4s}, [x12], #48 + st1 {v31.1d}, [x12] + b StoreDataEnd + +Write13: + st1 {v16.4s,v17.4s,v18.4s}, [x15], #48 + st1 {v19.s}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4s,v21.4s,v22.4s}, [x14], #48 + st1 {v23.s}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4s,v25.4s,v26.4s}, [x13], #48 + st1 {v27.s}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4s,v29.4s,v30.4s}, [x12], #48 + st1 {v31.s}[0], [x12] + b StoreDataEnd + +Write12: + st1 {v16.4s,v17.4s,v18.4s}, [x15], #48 + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4s,v21.4s,v22.4s}, [x14], #48 + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4s,v25.4s,v26.4s}, [x13], #48 + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4s,v29.4s,v30.4s}, [x12], #48 + b StoreDataEnd + +Write11: + st1 {v16.4s,v17.4s}, [x15], #32 + st1 {v18.1d}, [x15], #8 + st1 {v18.s}[2], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4s,v21.4s}, [x14], #32 + st1 {v22.1d}, [x14], #8 + st1 {v22.s}[2], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4s,v25.4s}, [x13], #32 + st1 {v26.1d}, [x13], #8 + st1 {v26.s}[2], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4s,v29.4s}, [x12], #32 + st1 {v30.1d}, [x12], #8 + st1 {v30.s}[2], [x12] + b StoreDataEnd + +Write10: + st1 {v16.4s,v17.4s}, [x15], #32 + st1 {v18.1d}, [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4s,v21.4s}, [x14], #32 + st1 {v22.1d}, [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4s,v25.4s}, [x13], #32 + st1 {v26.1d}, [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4s,v29.4s}, [x12], #32 + st1 {v30.1d}, [x12] + b StoreDataEnd + +Write9: + st1 {v16.4s,v17.4s}, [x15], #32 + st1 {v18.s}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4s,v21.4s}, [x14], #32 + st1 {v22.s}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4s,v25.4s}, [x13], #32 + st1 {v26.s}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4s,v29.4s}, [x12], #32 + st1 {v30.s}[0], [x12] + b StoreDataEnd + +Write8: + st1 {v16.4s,v17.4s}, [x15], #32 + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4s,v21.4s}, [x14], #32 + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4s,v25.4s}, [x13], #32 + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4s,v29.4s}, [x12], #32 + b StoreDataEnd + +Write7: + st1 {v16.4s}, [x15], #16 + st1 {v17.1d}, [x15], #8 + st1 {v17.s}[2], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4s}, [x14], #16 + st1 {v21.1d}, [x14], #8 + st1 {v21.s}[2], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4s}, [x13], #16 + st1 {v25.1d}, [x13], #8 + st1 {v25.s}[2], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4s}, [x12], #16 + st1 {v29.1d}, [x12], #8 + st1 {v29.s}[2], [x12] + b StoreDataEnd + +Write6: + st1 {v16.4s}, [x15], #16 + st1 {v17.1d}, [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4s}, [x14], #16 + st1 {v21.1d}, [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4s}, [x13], #16 + st1 {v25.1d}, [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4s}, [x12], #16 + st1 {v29.1d}, [x12] + b StoreDataEnd + +Write5: + st1 {v16.4s}, [x15], #16 + st1 {v17.s}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4s}, [x14], #16 + st1 {v21.s}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4s}, [x13], #16 + st1 {v25.s}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4s}, [x12], #16 + st1 {v29.s}[0], [x12] + b StoreDataEnd + +Write4: + st1 {v16.4s}, [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4s}, [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4s}, [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4s}, [x12] + b StoreDataEnd + +Write3: + st1 {v16.1d}, [x15], #8 + st1 {v16.s}[2], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.1d}, [x14], #8 + st1 {v20.s}[2], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.1d}, [x13], #8 + st1 {v24.s}[2], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.1d}, [x12], #8 + st1 {v28.s}[2], [x12] + b StoreDataEnd + +Write2: + st1 {v16.1d}, [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.1d}, [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.1d}, [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.1d}, [x12] + b StoreDataEnd + +Write1: + st1 {v16.s}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.s}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.s}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.s}[0], [x12] + b StoreDataEnd +StoreDataEnd: + sub sp, sp, #160 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/opt/DynamicMatmulSdot4x4x16AIWIForFp16.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/opt/DynamicMatmulSdot4x4x16AIWIForFp16.S new file mode 100644 index 0000000000000000000000000000000000000000..a5650b3d20ac3c9174b85d01b38e163b3066618f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/opt/DynamicMatmulSdot4x4x16AIWIForFp16.S @@ -0,0 +1,788 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" +.text +.align 5 + +// void DynamicMatmulSdot4x4x16AIWIForFp16(const int8_t *a, const int8_t *b, float16_t *out, size_t deep4, +// float16_t *multi_scales, float16_t *bias, size_t row, size_t col, size_t stride, +// const int32_t *a_sums, const int32_t *b_sums, int64_t a_zp, int64_t b_zp_sum, +// int64_t act_type, int64_t mode); +// x0: a(left matrix ptr) +// x1: b(right matrix ptr) +// x2: out ptr +// x3: deep +// x4: multi_scales +// x5: bias +// x6: row +// x7: col +// x8: stride +// x9: a_sums +// x10: b_sums +// x19/w19: a_zp +// x19/w20: b_zp_sum +// x21: act_type -> 0: none, 1:Relu, 3:Relu6 + +asm_function DynamicMatmulSdot4x4x16AIWIForFp16 + sub sp, sp, #160 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + stp x19, x20, [sp], #16 + stp x21, x22, [sp], #16 + + ldr x8, [sp] + ldr x9, [sp, #8] + ldr x10, [sp, #16] + ldr x19, [sp, #24] + ldr x20, [sp, #32] + ldr x21, [sp, #40] + ldr x22, [sp, #48] + + dup v16.4s, wzr // dup:Duplicate general-purpose register to vector. + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr + + mov x11, x1 // reload rhs ptr + mov x17, x0 // reload lhs ptr + mov x16, x3 // reload depth + + cmp x7, #4 + ble LoopDepthQuarter + cmp x7, #8 + ble LoopDepthHalf + +LoopDepth: + ld1 {v0.16b}, [x17], #16 + ld1 {v1.16b, v2.16b, v3.16b, v4.16b}, [x11], #64 + + sdot v16.4s, v1.16b, v0.4b[0] + sdot v17.4s, v2.16b, v0.4b[0] + sdot v18.4s, v3.16b, v0.4b[0] + sdot v19.4s, v4.16b, v0.4b[0] + sdot v20.4s, v1.16b, v0.4b[1] + sdot v21.4s, v2.16b, v0.4b[1] + sdot v22.4s, v3.16b, v0.4b[1] + sdot v23.4s, v4.16b, v0.4b[1] + sdot v24.4s, v1.16b, v0.4b[2] + sdot v25.4s, v2.16b, v0.4b[2] + sdot v26.4s, v3.16b, v0.4b[2] + sdot v27.4s, v4.16b, v0.4b[2] + sdot v28.4s, v1.16b, v0.4b[3] + sdot v29.4s, v2.16b, v0.4b[3] + sdot v30.4s, v3.16b, v0.4b[3] + sdot v31.4s, v4.16b, v0.4b[3] + + subs x16, x16, #4 + bgt LoopDepth + b AddInputSum + +LoopDepthHalf: + ld1 {v0.16b}, [x17], #16 + ld1 {v1.16b, v2.16b}, [x11] + add x11, x11, #64 + sdot v16.4s, v1.16b, v0.4b[0] + sdot v17.4s, v2.16b, v0.4b[0] + sdot v20.4s, v1.16b, v0.4b[1] + sdot v21.4s, v2.16b, v0.4b[1] + sdot v24.4s, v1.16b, v0.4b[2] + sdot v25.4s, v2.16b, v0.4b[2] + sdot v28.4s, v1.16b, v0.4b[3] + sdot v29.4s, v2.16b, v0.4b[3] + + subs x16, x16, #4 + bgt LoopDepthHalf + b AddInputSum + +LoopDepthQuarter: + ld1 {v0.16b}, [x17], #16 + ld1 {v1.16b}, [x11] + add x11, x11, #64 + sdot v16.4s, v1.16b, v0.4b[0] + sdot v20.4s, v1.16b, v0.4b[1] + sdot v24.4s, v1.16b, v0.4b[2] + sdot v28.4s, v1.16b, v0.4b[3] + + subs x16, x16, #4 + bgt LoopDepthQuarter + b AddInputSum + +AddInputSum: + cmp w20, #0 + beq AddInputSumEnd + ld1 {v5.4s}, [x9], #16 + dup v6.4s, v5.s[0] + dup v7.4s, v5.s[1] + dup v8.4s, v5.s[2] + dup v9.4s, v5.s[3] + + sub v16.4s, v16.4s, v6.4s + sub v17.4s, v17.4s, v6.4s + sub v18.4s, v18.4s, v6.4s + sub v19.4s, v19.4s, v6.4s + sub v20.4s, v20.4s, v7.4s + sub v21.4s, v21.4s, v7.4s + sub v22.4s, v22.4s, v7.4s + sub v23.4s, v23.4s, v7.4s + sub v24.4s, v24.4s, v8.4s + sub v25.4s, v25.4s, v8.4s + sub v26.4s, v26.4s, v8.4s + sub v27.4s, v27.4s, v8.4s + sub v28.4s, v28.4s, v9.4s + sub v29.4s, v29.4s, v9.4s + sub v30.4s, v30.4s, v9.4s + sub v31.4s, v31.4s, v9.4s +AddInputSumEnd: + +AddWeightSum: + ld1 {v9.4s}, [x10], #16 + ld1 {v10.4s}, [x10], #16 + ld1 {v11.4s}, [x10], #16 + ld1 {v12.4s}, [x10], #16 + dup v13.4s, w19 + mul v9.4s, v9.4s, v13.4s + mul v10.4s, v10.4s, v13.4s + mul v11.4s, v11.4s, v13.4s + mul v12.4s, v12.4s, v13.4s + sub v16.4s, v16.4s, v9.4s + sub v17.4s, v17.4s, v10.4s + sub v18.4s, v18.4s, v11.4s + sub v19.4s, v19.4s, v12.4s + sub v20.4s, v20.4s, v9.4s + sub v21.4s, v21.4s, v10.4s + sub v22.4s, v22.4s, v11.4s + sub v23.4s, v23.4s, v12.4s + sub v24.4s, v24.4s, v9.4s + sub v25.4s, v25.4s, v10.4s + sub v26.4s, v26.4s, v11.4s + sub v27.4s, v27.4s, v12.4s + sub v28.4s, v28.4s, v9.4s + sub v29.4s, v29.4s, v10.4s + sub v30.4s, v30.4s, v11.4s + sub v31.4s, v31.4s, v12.4s + +AddZpSum: + mul w15, w19, w20 + cmp w15, #0 + beq AddZpSumEnd + dup v14.4s, w15 + add v16.4s, v16.4s, v14.4s + add v17.4s, v17.4s, v14.4s + add v18.4s, v18.4s, v14.4s + add v19.4s, v19.4s, v14.4s + add v20.4s, v20.4s, v14.4s + add v21.4s, v21.4s, v14.4s + add v22.4s, v22.4s, v14.4s + add v23.4s, v23.4s, v14.4s + add v24.4s, v24.4s, v14.4s + add v25.4s, v25.4s, v14.4s + add v26.4s, v26.4s, v14.4s + add v27.4s, v27.4s, v14.4s + add v28.4s, v28.4s, v14.4s + add v29.4s, v29.4s, v14.4s + add v30.4s, v30.4s, v14.4s + add v31.4s, v31.4s, v14.4s +AddZpSumEnd: + +Convert2Float: + scvtf v16.4s, v16.4s + scvtf v17.4s, v17.4s + scvtf v18.4s, v18.4s + scvtf v19.4s, v19.4s + scvtf v20.4s, v20.4s + scvtf v21.4s, v21.4s + scvtf v22.4s, v22.4s + scvtf v23.4s, v23.4s + scvtf v24.4s, v24.4s + scvtf v25.4s, v25.4s + scvtf v26.4s, v26.4s + scvtf v27.4s, v27.4s + scvtf v28.4s, v28.4s + scvtf v29.4s, v29.4s + scvtf v30.4s, v30.4s + scvtf v31.4s, v31.4s + +MultiplyScale: + // multi_scale * input_matrix + cbz x22, TensorXTensor + cmp x22, #1 + beq TensorXChannel + cmp x22, #2 + beq ChannelXTensor + ChannelXChannel: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x4], #64 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x4], #64 + + fmul v16.4s,v16.4s,v0.4s + fmul v17.4s,v17.4s,v1.4s + fmul v18.4s,v18.4s,v2.4s + fmul v19.4s,v19.4s,v3.4s + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x4], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x4] + + fmul v20.4s,v20.4s,v4.4s + fmul v21.4s,v21.4s,v5.4s + fmul v22.4s,v22.4s,v6.4s + fmul v23.4s,v23.4s,v7.4s + + fmul v24.4s,v24.4s,v8.4s + fmul v25.4s,v25.4s,v9.4s + fmul v26.4s,v26.4s,v10.4s + fmul v27.4s,v27.4s,v11.4s + + fmul v28.4s,v28.4s,v12.4s + fmul v29.4s,v29.4s,v13.4s + fmul v30.4s,v30.4s,v14.4s + fmul v31.4s,v31.4s,v15.4s + b ConvertHalfPrecision + + TensorXTensor: + ld1 {v0.s}[0], [x4] + + fmul v16.4s,v16.4s,v0.s[0] + fmul v17.4s,v17.4s,v0.s[0] + fmul v18.4s,v18.4s,v0.s[0] + fmul v19.4s,v19.4s,v0.s[0] + + fmul v20.4s,v20.4s,v0.s[0] + fmul v21.4s,v21.4s,v0.s[0] + fmul v22.4s,v22.4s,v0.s[0] + fmul v23.4s,v23.4s,v0.s[0] + + fmul v24.4s,v24.4s,v0.s[0] + fmul v25.4s,v25.4s,v0.s[0] + fmul v26.4s,v26.4s,v0.s[0] + fmul v27.4s,v27.4s,v0.s[0] + + fmul v28.4s,v28.4s,v0.s[0] + fmul v29.4s,v29.4s,v0.s[0] + fmul v30.4s,v30.4s,v0.s[0] + fmul v31.4s,v31.4s,v0.s[0] + b ConvertHalfPrecision + + TensorXChannel: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x4] + + fmul v16.4s,v16.4s,v0.4s + fmul v17.4s,v17.4s,v1.4s + fmul v18.4s,v18.4s,v2.4s + fmul v19.4s,v19.4s,v3.4s + + fmul v20.4s,v20.4s,v0.4s + fmul v21.4s,v21.4s,v1.4s + fmul v22.4s,v22.4s,v2.4s + fmul v23.4s,v23.4s,v3.4s + + fmul v24.4s,v24.4s,v0.4s + fmul v25.4s,v25.4s,v1.4s + fmul v26.4s,v26.4s,v2.4s + fmul v27.4s,v27.4s,v3.4s + + fmul v28.4s,v28.4s,v0.4s + fmul v29.4s,v29.4s,v1.4s + fmul v30.4s,v30.4s,v2.4s + fmul v31.4s,v31.4s,v3.4s + b ConvertHalfPrecision + + ChannelXTensor: + ld1 {v0.4s}, [x4] + fmul v16.4s,v16.4s,v0.s[0] + fmul v17.4s,v17.4s,v0.s[0] + fmul v18.4s,v18.4s,v0.s[0] + fmul v19.4s,v19.4s,v0.s[0] + + fmul v20.4s,v20.4s,v0.s[1] + fmul v21.4s,v21.4s,v0.s[1] + fmul v22.4s,v22.4s,v0.s[1] + fmul v23.4s,v23.4s,v0.s[1] + + fmul v24.4s,v24.4s,v0.s[2] + fmul v25.4s,v25.4s,v0.s[2] + fmul v26.4s,v26.4s,v0.s[2] + fmul v27.4s,v27.4s,v0.s[2] + + fmul v28.4s,v28.4s,v0.s[3] + fmul v29.4s,v29.4s,v0.s[3] + fmul v30.4s,v30.4s,v0.s[3] + fmul v31.4s,v31.4s,v0.s[3] + +ConvertHalfPrecision: +// from single-precision convert to half-precision + fcvtn v16.4h,v16.4s + fcvtn v17.4h,v17.4s + fcvtn v18.4h,v18.4s + fcvtn v19.4h,v19.4s + + fcvtn v20.4h,v20.4s + fcvtn v21.4h,v21.4s + fcvtn v22.4h,v22.4s + fcvtn v23.4h,v23.4s + + fcvtn v24.4h,v24.4s + fcvtn v25.4h,v25.4s + fcvtn v26.4h,v26.4s + fcvtn v27.4h,v27.4s + + fcvtn v28.4h,v28.4s + fcvtn v29.4h,v29.4s + fcvtn v30.4h,v30.4s + fcvtn v31.4h,v31.4s + +AddBias: + // +bias + cbz x5, StoreData + ld1 {v1.4h, v2.4h, v3.4h, v4.4h}, [x5] + + fadd v16.4h,v16.4h,v1.4h + fadd v17.4h,v17.4h,v2.4h + fadd v18.4h,v18.4h,v3.4h + fadd v19.4h,v19.4h,v4.4h + + fadd v20.4h,v20.4h,v1.4h + fadd v21.4h,v21.4h,v2.4h + fadd v22.4h,v22.4h,v3.4h + fadd v23.4h,v23.4h,v4.4h + + fadd v24.4h,v24.4h,v1.4h + fadd v25.4h,v25.4h,v2.4h + fadd v26.4h,v26.4h,v3.4h + fadd v27.4h,v27.4h,v4.4h + + fadd v28.4h,v28.4h,v1.4h + fadd v29.4h,v29.4h,v2.4h + fadd v30.4h,v30.4h,v3.4h + fadd v31.4h,v31.4h,v4.4h + +Activate: + cmp x21, #1 + beq Relu + cmp x21, #3 + beq Relu6 + b StoreData + +Relu: + dup v1.4h, wzr + + smax v16.4h,v16.4h,v1.4h + smax v17.4h,v17.4h,v1.4h + smax v18.4h,v18.4h,v1.4h + smax v19.4h,v19.4h,v1.4h + + smax v20.4h,v20.4h,v1.4h + smax v21.4h,v21.4h,v1.4h + smax v22.4h,v22.4h,v1.4h + smax v23.4h,v23.4h,v1.4h + + smax v24.4h,v24.4h,v1.4h + smax v25.4h,v25.4h,v1.4h + smax v26.4h,v26.4h,v1.4h + smax v27.4h,v27.4h,v1.4h + + smax v28.4h,v28.4h,v1.4h + smax v29.4h,v29.4h,v1.4h + smax v30.4h,v30.4h,v1.4h + smax v31.4h,v31.4h,v1.4h + + b StoreData + +Relu6: + dup v1.4h, wzr + movi v2.4h, #6 + scvtf v2.4h, v2.4h + + // max (out, 0) + smax v16.4h,v16.4h,v1.4h + smax v17.4h,v17.4h,v1.4h + smax v18.4h,v18.4h,v1.4h + smax v19.4h,v19.4h,v1.4h + + smax v20.4h,v20.4h,v1.4h + smax v21.4h,v21.4h,v1.4h + smax v22.4h,v22.4h,v1.4h + smax v23.4h,v23.4h,v1.4h + + smax v24.4h,v24.4h,v1.4h + smax v25.4h,v25.4h,v1.4h + smax v26.4h,v26.4h,v1.4h + smax v27.4h,v27.4h,v1.4h + + smax v28.4h,v28.4h,v1.4h + smax v29.4h,v29.4h,v1.4h + smax v30.4h,v30.4h,v1.4h + smax v31.4h,v31.4h,v1.4h + + // min (out, 6) + + smin v16.4h,v16.4h,v2.4h + smin v17.4h,v17.4h,v2.4h + smin v18.4h,v18.4h,v2.4h + smin v19.4h,v19.4h,v2.4h + + smin v20.4h,v20.4h,v2.4h + smin v21.4h,v21.4h,v2.4h + smin v22.4h,v22.4h,v2.4h + smin v23.4h,v23.4h,v2.4h + + smin v24.4h,v24.4h,v2.4h + smin v25.4h,v25.4h,v2.4h + smin v26.4h,v26.4h,v2.4h + smin v27.4h,v27.4h,v2.4h + + smin v28.4h,v28.4h,v2.4h + smin v29.4h,v29.4h,v2.4h + smin v30.4h,v30.4h,v2.4h + smin v31.4h,v31.4h,v2.4h + + b StoreData + +StoreData: + cmp x7, #16 + beq Write16 + + mov x15, x2 // reload out ptr + add x14, x15, x8 + add x13, x14, x8 + add x12, x13, x8 + + cmp x7, #15 + beq Write15 + cmp x7, #14 + beq Write14 + cmp x7, #13 + beq Write13 + cmp x7, #12 + beq Write12 + cmp x7, #11 + beq Write11 + cmp x7, #10 + beq Write10 + cmp x7, #9 + beq Write9 + cmp x7, #8 + beq Write8 + cmp x7, #7 + beq Write7 + cmp x7, #6 + beq Write6 + cmp x7, #5 + beq Write5 + cmp x7, #4 + beq Write4 + cmp x7, #3 + beq Write3 + cmp x7, #2 + beq Write2 + cmp x7, #1 + beq Write1 + b StoreDataEnd + +Write16: + cmp x6, #4 + beq Write16Row4 + cmp x6, #3 + beq Write16Row3 + cmp x6, #2 + beq Write16Row2 + cmp x6, #1 + beq Write16Row1 + + Write16Row4: + st1 {v16.4h,v17.4h,v18.4h,v19.4h}, [x2], x8 + st1 {v20.4h,v21.4h,v22.4h,v23.4h}, [x2], x8 + st1 {v24.4h,v25.4h,v26.4h,v27.4h}, [x2], x8 + st1 {v28.4h,v29.4h,v30.4h,v31.4h}, [x2] + b StoreDataEnd + Write16Row3: + st1 {v16.4h,v17.4h,v18.4h,v19.4h}, [x2], x8 + st1 {v20.4h,v21.4h,v22.4h,v23.4h}, [x2], x8 + st1 {v24.4h,v25.4h,v26.4h,v27.4h}, [x2] + b StoreDataEnd + Write16Row2: + st1 {v16.4h,v17.4h,v18.4h,v19.4h}, [x2], x8 + st1 {v20.4h,v21.4h,v22.4h,v23.4h}, [x2] + b StoreDataEnd + Write16Row1: + st1 {v16.4h,v17.4h,v18.4h,v19.4h}, [x2] + b StoreDataEnd + +Write15: + st1 {v16.4h,v17.4h,v18.4h}, [x15], #24 + st1 {v19.s}[0], [x15], #4 + st1 {v19.h}[2], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h,v21.4h,v22.4h}, [x14], #24 + st1 {v23.s}[0], [x14], #4 + st1 {v23.h}[2], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h,v25.4h,v26.4h}, [x13], #24 + st1 {v27.s}[0], [x13], #4 + st1 {v27.h}[2], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h,v29.4h,v30.4h}, [x12], #24 + st1 {v31.s}[0], [x12], #4 + st1 {v31.h}[2], [x12] + b StoreDataEnd + +Write14: + st1 {v16.4h,v17.4h,v18.4h}, [x15], #24 + st1 {v19.s}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h,v21.4h,v22.4h}, [x14], #24 + st1 {v23.s}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h,v25.4h,v26.4h}, [x13], #24 + st1 {v27.s}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h,v29.4h,v30.4h}, [x12], #24 + st1 {v31.s}[0], [x12] + b StoreDataEnd + +Write13: + st1 {v16.4h,v17.4h,v18.4h}, [x15], #24 + st1 {v19.h}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h,v21.4h,v22.4h}, [x14], #24 + st1 {v23.h}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h,v25.4h,v26.4h}, [x13], #24 + st1 {v27.h}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h,v29.4h,v30.4h}, [x12], #24 + st1 {v31.h}[0], [x12] + b StoreDataEnd + +Write12: + st1 {v16.4h,v17.4h,v18.4h}, [x15], #24 + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h,v21.4h,v22.4h}, [x14], #24 + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h,v25.4h,v26.4h}, [x13], #24 + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h,v29.4h,v30.4h}, [x12], #24 + b StoreDataEnd + +Write11: + st1 {v16.4h,v17.4h}, [x15], #16 + st1 {v18.s}[0], [x15], #4 + st1 {v18.h}[2], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h,v21.4h}, [x14], #16 + st1 {v22.s}[0], [x14], #4 + st1 {v22.h}[2], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h,v25.4h}, [x13], #16 + st1 {v26.s}[0], [x13], #4 + st1 {v26.h}[2], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h,v29.4h}, [x12], #16 + st1 {v30.s}[0], [x12], #4 + st1 {v30.h}[2], [x12] + b StoreDataEnd + +Write10: + st1 {v16.4h,v17.4h}, [x15], #16 + st1 {v18.s}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h,v21.4h}, [x14], #16 + st1 {v22.s}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h,v25.4h}, [x13], #16 + st1 {v26.s}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h,v29.4h}, [x12], #16 + st1 {v30.s}[0], [x12] + b StoreDataEnd + +Write9: + st1 {v16.4h,v17.4h}, [x15], #16 + st1 {v18.h}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h,v21.4h}, [x14], #16 + st1 {v22.h}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h,v25.4h}, [x13], #16 + st1 {v26.h}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h,v29.4h}, [x12], #16 + st1 {v30.h}[0], [x12] + b StoreDataEnd + +Write8: + st1 {v16.4h,v17.4h}, [x15], #16 + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h,v21.4h}, [x14], #16 + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h,v25.4h}, [x13], #16 + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h,v29.4h}, [x12], #16 + b StoreDataEnd + +Write7: + st1 {v16.4h}, [x15], #8 + st1 {v17.s}[0], [x15], #4 + st1 {v17.h}[2], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h}, [x14], #8 + st1 {v21.s}[0], [x14], #4 + st1 {v21.h}[2], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h}, [x13], #8 + st1 {v25.s}[0], [x13], #4 + st1 {v25.h}[2], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h}, [x12], #8 + st1 {v29.s}[0], [x12], #4 + st1 {v29.h}[2], [x12] + b StoreDataEnd + +Write6: + st1 {v16.4h}, [x15], #8 + st1 {v17.s}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h}, [x14], #8 + st1 {v21.s}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h}, [x13], #8 + st1 {v25.s}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h}, [x12], #8 + st1 {v29.s}[0], [x12] + b StoreDataEnd + +Write5: + st1 {v16.4h}, [x15], #8 + st1 {v17.h}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h}, [x14], #8 + st1 {v21.h}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h}, [x13], #8 + st1 {v25.h}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h}, [x12], #8 + st1 {v29.h}[0], [x12] + b StoreDataEnd + +Write4: + st1 {v16.4h}, [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.4h}, [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.4h}, [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.4h}, [x12] + b StoreDataEnd + +Write3: + st1 {v16.s}[0], [x15], #4 + st1 {v16.h}[2], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.s}[0], [x14], #4 + st1 {v20.h}[2], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.s}[0], [x13], #4 + st1 {v24.h}[2], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.s}[0], [x12], #4 + st1 {v28.h}[2], [x12] + b StoreDataEnd + +Write2: + st1 {v16.s}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.s}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.s}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.s}[0], [x12] + b StoreDataEnd + +Write1: + st1 {v16.h}[0], [x15] + cmp x6, #1 + beq StoreDataEnd + st1 {v20.h}[0], [x14] + cmp x6, #2 + beq StoreDataEnd + st1 {v24.h}[0], [x13] + cmp x6, #3 + beq StoreDataEnd + st1 {v28.h}[0], [x12] + b StoreDataEnd +StoreDataEnd: + sub sp, sp, #160 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/opt/MatmulDpInt8.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/opt/MatmulDpInt8.S new file mode 100644 index 0000000000000000000000000000000000000000..3bd9b84be624f0ad9cce53e55dfd816cc065251f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/opt/MatmulDpInt8.S @@ -0,0 +1,864 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" +.text +.align 5 + +//void MatmulInt8DpNeon64(const int8_t *a, const int8_t *b, int8_t *dst, int row8, int col8, int deep4, +// const int *a_sums, const int *bias, int act_min, int act_max, int out_zp, +// const int *multiplier, const int *left_shift, const int *right_shift, int row, +// int col, int stride, int peroc); + +// x0: a(left matrix ptr) +// x1: b(right matrix ptr) +// x2: out ptr +// w3: row8 +// w4: col8 +// w5: deep4 +// x6: a_sums +// x7: bias +// w8: act_min +// w9: act_max +// w10: out_zp +// x11: multiplier +// x12: left_shift +// x13: right_shift +// w14: row +// w15: col +// w24: stride +// w27: filter_peroc + +asm_function MatmulInt8DpNeon64 + sub sp, sp, #208 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + stp x19, x20, [sp], #16 + stp x21, x22, [sp], #16 + stp x23, x24, [sp], #16 + stp x25, x26, [sp], #16 + stp x27, x28, [sp], #16 + + ldr w8, [sp] + ldr w9, [sp, #8] + ldr w10, [sp, #16] + ldr x11, [sp, #24] + ldr x12, [sp, #32] + ldr x13, [sp, #40] + ldr w14, [sp, #48] + ldr w15, [sp, #56] + ldr w24, [sp, #64] + ldr w27, [sp, #72] + + mov w17, #8 // sizeof(int8)*8 + mul w21, w5, w17 // the stride of a/b: sizeof(int8)*8*deep4 + mov x25, x2 +L1: + cmp w4, #0 // if at the end of col8 + beq End1 + + mov w16, w3 // reset a row8 counter + mov w23, w14 // reset a row counter + mov x17, x0 // reload a ptr + mov x22, x6 // reload a_sums ptr +L2: + cmp w16, #0 + beq End2 + + mov x28, x1 // reload b ptr + mov x19, x7 // reload bias ptr + mov w20, w5 // reload depth + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr +L3: + cmp w20, #16 + blt LoopD4 + +LoopD16: + ld1 {v0.16b, v1.16b}, [x17], #32 + ld1 {v2.16b, v3.16b}, [x28], #32 + + sdot v16.4s, v2.16b, v0.4b[0] + sdot v18.4s, v2.16b, v0.4b[1] + sdot v20.4s, v2.16b, v0.4b[2] + sdot v22.4s, v2.16b, v0.4b[3] + + ld1 {v4.16b, v5.16b}, [x17], #32 + sdot v24.4s, v2.16b, v1.4b[0] + sdot v26.4s, v2.16b, v1.4b[1] + sdot v28.4s, v2.16b, v1.4b[2] + sdot v30.4s, v2.16b, v1.4b[3] + + ld1 {v6.16b, v7.16b}, [x28], #32 + sdot v17.4s, v3.16b, v0.4b[0] + sdot v19.4s, v3.16b, v0.4b[1] + sdot v21.4s, v3.16b, v0.4b[2] + sdot v23.4s, v3.16b, v0.4b[3] + + sdot v25.4s, v3.16b, v1.4b[0] + sdot v27.4s, v3.16b, v1.4b[1] + sdot v29.4s, v3.16b, v1.4b[2] + sdot v31.4s, v3.16b, v1.4b[3] + + ld1 {v8.16b, v9.16b}, [x17], #32 + sdot v16.4s, v6.16b, v4.4b[0] + sdot v18.4s, v6.16b, v4.4b[1] + sdot v20.4s, v6.16b, v4.4b[2] + sdot v22.4s, v6.16b, v4.4b[3] + + sdot v24.4s, v6.16b, v5.4b[0] + sdot v26.4s, v6.16b, v5.4b[1] + sdot v28.4s, v6.16b, v5.4b[2] + sdot v30.4s, v6.16b, v5.4b[3] + + ld1 {v10.16b, v11.16b}, [x28], #32 + sdot v17.4s, v7.16b, v4.4b[0] + sdot v19.4s, v7.16b, v4.4b[1] + sdot v21.4s, v7.16b, v4.4b[2] + sdot v23.4s, v7.16b, v4.4b[3] + + sdot v25.4s, v7.16b, v5.4b[0] + sdot v27.4s, v7.16b, v5.4b[1] + sdot v29.4s, v7.16b, v5.4b[2] + sdot v31.4s, v7.16b, v5.4b[3] + + ld1 {v12.16b, v13.16b}, [x17], #32 + sdot v16.4s, v10.16b, v8.4b[0] + sdot v18.4s, v10.16b, v8.4b[1] + sdot v20.4s, v10.16b, v8.4b[2] + sdot v22.4s, v10.16b, v8.4b[3] + + sdot v24.4s, v10.16b, v9.4b[0] + sdot v26.4s, v10.16b, v9.4b[1] + sdot v28.4s, v10.16b, v9.4b[2] + sdot v30.4s, v10.16b, v9.4b[3] + + ld1 {v14.16b, v15.16b}, [x28], #32 + sdot v17.4s, v11.16b, v8.4b[0] + sdot v19.4s, v11.16b, v8.4b[1] + sdot v21.4s, v11.16b, v8.4b[2] + sdot v23.4s, v11.16b, v8.4b[3] + + sdot v25.4s, v11.16b, v9.4b[0] + sdot v27.4s, v11.16b, v9.4b[1] + sdot v29.4s, v11.16b, v9.4b[2] + sdot v31.4s, v11.16b, v9.4b[3] + + sdot v16.4s, v14.16b, v12.4b[0] + sdot v18.4s, v14.16b, v12.4b[1] + sdot v20.4s, v14.16b, v12.4b[2] + sdot v22.4s, v14.16b, v12.4b[3] + + sdot v24.4s, v14.16b, v13.4b[0] + sdot v26.4s, v14.16b, v13.4b[1] + sdot v28.4s, v14.16b, v13.4b[2] + sdot v30.4s, v14.16b, v13.4b[3] + + sdot v17.4s, v15.16b, v12.4b[0] + sdot v19.4s, v15.16b, v12.4b[1] + sdot v21.4s, v15.16b, v12.4b[2] + sdot v23.4s, v15.16b, v12.4b[3] + + sdot v25.4s, v15.16b, v13.4b[0] + sdot v27.4s, v15.16b, v13.4b[1] + sdot v29.4s, v15.16b, v13.4b[2] + sdot v31.4s, v15.16b, v13.4b[3] + + subs w20, w20, #16 // depth - 16 + b L3 + +LoopD4: + cmp w20, #0 + beq End3 + + ld1 {v0.16b, v1.16b}, [x17], #32 + ld1 {v2.16b, v3.16b}, [x28], #32 + + sdot v16.4s, v2.16b, v0.4b[0] + sdot v18.4s, v2.16b, v0.4b[1] + sdot v20.4s, v2.16b, v0.4b[2] + sdot v22.4s, v2.16b, v0.4b[3] + sdot v24.4s, v2.16b, v1.4b[0] + sdot v26.4s, v2.16b, v1.4b[1] + sdot v28.4s, v2.16b, v1.4b[2] + sdot v30.4s, v2.16b, v1.4b[3] + sdot v17.4s, v3.16b, v0.4b[0] + sdot v19.4s, v3.16b, v0.4b[1] + sdot v21.4s, v3.16b, v0.4b[2] + sdot v23.4s, v3.16b, v0.4b[3] + sdot v25.4s, v3.16b, v1.4b[0] + sdot v27.4s, v3.16b, v1.4b[1] + sdot v29.4s, v3.16b, v1.4b[2] + sdot v31.4s, v3.16b, v1.4b[3] + + subs w20, w20, #4 // depth - 4 + b LoopD4 + +End3: + // Add (Bias+Depth*Za*Zb-Za*Bsums) + ld1 {v15.4s}, [x19], #16 + ld1 {v14.4s}, [x19], #16 + add v16.4s, v16.4s, v15.4s + add v18.4s, v18.4s, v15.4s + add v20.4s, v20.4s, v15.4s + add v22.4s, v22.4s, v15.4s + add v24.4s, v24.4s, v15.4s + add v26.4s, v26.4s, v15.4s + add v28.4s, v28.4s, v15.4s + add v30.4s, v30.4s, v15.4s + add v17.4s, v17.4s, v14.4s + add v19.4s, v19.4s, v14.4s + add v21.4s, v21.4s, v14.4s + add v23.4s, v23.4s, v14.4s + add v25.4s, v25.4s, v14.4s + add v27.4s, v27.4s, v14.4s + add v29.4s, v29.4s, v14.4s + add v31.4s, v31.4s, v14.4s + + cmp w27, #0 + beq PerTSumLoad +PerCSumLoad: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x6], #64 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x6], #64 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x6], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x6], #64 + b ApplySum +PerTSumLoad: + ld1 {v14.4s}, [x22], #16 + ld1 {v15.4s}, [x22], #16 + dup v0.4s, v14.s[0] + dup v1.4s, v14.s[0] + dup v2.4s, v14.s[1] + dup v3.4s, v14.s[1] + dup v4.4s, v14.s[2] + dup v5.4s, v14.s[2] + dup v6.4s, v14.s[3] + dup v7.4s, v14.s[3] + dup v8.4s, v15.s[0] + dup v9.4s, v15.s[0] + dup v10.4s, v15.s[1] + dup v11.4s, v15.s[1] + dup v12.4s, v15.s[2] + dup v13.4s, v15.s[2] + dup v14.4s, v15.s[3] + dup v15.4s, v14.s[0] +ApplySum: + // Subtract (Asums*Zb) + sub v16.4s, v16.4s, v0.4s + sub v17.4s, v17.4s, v1.4s + sub v18.4s, v18.4s, v2.4s + sub v19.4s, v19.4s, v3.4s + sub v20.4s, v20.4s, v4.4s + sub v21.4s, v21.4s, v5.4s + sub v22.4s, v22.4s, v6.4s + sub v23.4s, v23.4s, v7.4s + sub v24.4s, v24.4s, v8.4s + sub v25.4s, v25.4s, v9.4s + sub v26.4s, v26.4s, v10.4s + sub v27.4s, v27.4s, v11.4s + sub v28.4s, v28.4s, v12.4s + sub v29.4s, v29.4s, v13.4s + sub v30.4s, v30.4s, v14.4s + sub v31.4s, v31.4s, v15.4s + + cmp w27, #0 + beq PerTRoundLoad +PerCRoundLoad: + ld1 {v8.4s, v9.4s}, [x12] + ld1 {v10.4s, v11.4s}, [x11] + ld1 {v12.4s, v13.4s}, [x13] + b ApplyRound +PerTRoundLoad: + ld1 {v14.s}[0], [x12] + dup v8.4s, v14.s[0] + dup v9.4s, v14.s[0] + ld1 {v14.s}[0], [x11] + dup v10.4s, v14.s[0] + dup v11.4s, v14.s[0] + ld1 {v14.s}[0], [x13] + dup v12.4s, v14.s[0] + dup v13.4s, v14.s[0] +ApplyRound: + // Apply left shift + sqshl v16.4s, v16.4s, v8.4s + sqshl v17.4s, v17.4s, v9.4s + sqshl v18.4s, v18.4s, v8.4s + sqshl v19.4s, v19.4s, v9.4s + sqshl v20.4s, v20.4s, v8.4s + sqshl v21.4s, v21.4s, v9.4s + sqshl v22.4s, v22.4s, v8.4s + sqshl v23.4s, v23.4s, v9.4s + sqshl v24.4s, v24.4s, v8.4s + sqshl v25.4s, v25.4s, v9.4s + sqshl v26.4s, v26.4s, v8.4s + sqshl v27.4s, v27.4s, v9.4s + sqshl v28.4s, v28.4s, v8.4s + sqshl v29.4s, v29.4s, v9.4s + sqshl v30.4s, v30.4s, v8.4s + sqshl v31.4s, v31.4s, v9.4s + + // Apply the fixed-point part of the multiplier. + sqrdmulh v16.4s, v16.4s, v10.4s + sqrdmulh v17.4s, v17.4s, v11.4s + sqrdmulh v18.4s, v18.4s, v10.4s + sqrdmulh v19.4s, v19.4s, v11.4s + sqrdmulh v20.4s, v20.4s, v10.4s + sqrdmulh v21.4s, v21.4s, v11.4s + sqrdmulh v22.4s, v22.4s, v10.4s + sqrdmulh v23.4s, v23.4s, v11.4s + sqrdmulh v24.4s, v24.4s, v10.4s + sqrdmulh v25.4s, v25.4s, v11.4s + sqrdmulh v26.4s, v26.4s, v10.4s + sqrdmulh v27.4s, v27.4s, v11.4s + sqrdmulh v28.4s, v28.4s, v10.4s + sqrdmulh v29.4s, v29.4s, v11.4s + sqrdmulh v30.4s, v30.4s, v10.4s + sqrdmulh v31.4s, v31.4s, v11.4s + + // Apply right shift + and v0.16b, v12.16b, v16.16b + sshr v0.4s, v0.4s, #31 + sqadd v16.4s, v16.4s, v0.4s + srshl v16.4s, v16.4s, v12.4s + and v1.16b, v13.16b, v17.16b + sshr v1.4s, v1.4s, #31 + sqadd v17.4s, v17.4s, v1.4s + srshl v17.4s, v17.4s, v13.4s + and v2.16b, v12.16b, v18.16b + sshr v2.4s, v2.4s, #31 + sqadd v18.4s, v18.4s, v2.4s + srshl v18.4s, v18.4s, v12.4s + and v3.16b, v13.16b, v19.16b + sshr v3.4s, v3.4s, #31 + sqadd v19.4s, v19.4s, v3.4s + srshl v19.4s, v19.4s, v13.4s + and v0.16b, v12.16b, v20.16b + sshr v0.4s, v0.4s, #31 + sqadd v20.4s, v20.4s, v0.4s + srshl v20.4s, v20.4s, v12.4s + and v1.16b, v13.16b, v21.16b + sshr v1.4s, v1.4s, #31 + sqadd v21.4s, v21.4s, v1.4s + srshl v21.4s, v21.4s, v13.4s + and v2.16b, v12.16b, v22.16b + sshr v2.4s, v2.4s, #31 + sqadd v22.4s, v22.4s, v2.4s + srshl v22.4s, v22.4s, v12.4s + and v3.16b, v13.16b, v23.16b + sshr v3.4s, v3.4s, #31 + sqadd v23.4s, v23.4s, v3.4s + srshl v23.4s, v23.4s, v13.4s + and v0.16b, v12.16b, v24.16b + sshr v0.4s, v0.4s, #31 + sqadd v24.4s, v24.4s, v0.4s + srshl v24.4s, v24.4s, v12.4s + and v1.16b, v13.16b, v25.16b + sshr v1.4s, v1.4s, #31 + sqadd v25.4s, v25.4s, v1.4s + srshl v25.4s, v25.4s, v13.4s + and v2.16b, v12.16b, v26.16b + sshr v2.4s, v2.4s, #31 + sqadd v26.4s, v26.4s, v2.4s + srshl v26.4s, v26.4s, v12.4s + and v3.16b, v13.16b, v27.16b + sshr v3.4s, v3.4s, #31 + sqadd v27.4s, v27.4s, v3.4s + srshl v27.4s, v27.4s, v13.4s + and v0.16b, v12.16b, v28.16b + sshr v0.4s, v0.4s, #31 + sqadd v28.4s, v28.4s, v0.4s + srshl v28.4s, v28.4s, v12.4s + and v1.16b, v13.16b, v29.16b + sshr v1.4s, v1.4s, #31 + sqadd v29.4s, v29.4s, v1.4s + srshl v29.4s, v29.4s, v13.4s + and v2.16b, v12.16b, v30.16b + sshr v2.4s, v2.4s, #31 + sqadd v30.4s, v30.4s, v2.4s + srshl v30.4s, v30.4s, v12.4s + and v3.16b, v13.16b, v31.16b + sshr v3.4s, v3.4s, #31 + sqadd v31.4s, v31.4s, v3.4s + srshl v31.4s, v31.4s, v13.4s + + // Add the destination zero point + dup v8.4s, w10 + add v16.4s, v16.4s, v8.4s + add v17.4s, v17.4s, v8.4s + add v18.4s, v18.4s, v8.4s + add v19.4s, v19.4s, v8.4s + add v20.4s, v20.4s, v8.4s + add v21.4s, v21.4s, v8.4s + add v22.4s, v22.4s, v8.4s + add v23.4s, v23.4s, v8.4s + add v24.4s, v24.4s, v8.4s + add v25.4s, v25.4s, v8.4s + add v26.4s, v26.4s, v8.4s + add v27.4s, v27.4s, v8.4s + add v28.4s, v28.4s, v8.4s + add v29.4s, v29.4s, v8.4s + add v30.4s, v30.4s, v8.4s + add v31.4s, v31.4s, v8.4s + + // Apply the act_min bound + dup v7.4s, w8 + smax v16.4s, v16.4s, v7.4s + smax v17.4s, v17.4s, v7.4s + smax v18.4s, v18.4s, v7.4s + smax v19.4s, v19.4s, v7.4s + smax v20.4s, v20.4s, v7.4s + smax v21.4s, v21.4s, v7.4s + smax v22.4s, v22.4s, v7.4s + smax v23.4s, v23.4s, v7.4s + smax v24.4s, v24.4s, v7.4s + smax v25.4s, v25.4s, v7.4s + smax v26.4s, v26.4s, v7.4s + smax v27.4s, v27.4s, v7.4s + smax v28.4s, v28.4s, v7.4s + smax v29.4s, v29.4s, v7.4s + smax v30.4s, v30.4s, v7.4s + smax v31.4s, v31.4s, v7.4s + + // Apply the act_max bound + dup v6.4s, w9 + smin v16.4s, v16.4s, v6.4s + smin v17.4s, v17.4s, v6.4s + smin v18.4s, v18.4s, v6.4s + smin v19.4s, v19.4s, v6.4s + smin v20.4s, v20.4s, v6.4s + smin v21.4s, v21.4s, v6.4s + smin v22.4s, v22.4s, v6.4s + smin v23.4s, v23.4s, v6.4s + smin v24.4s, v24.4s, v6.4s + smin v25.4s, v25.4s, v6.4s + smin v26.4s, v26.4s, v6.4s + smin v27.4s, v27.4s, v6.4s + smin v28.4s, v28.4s, v6.4s + smin v29.4s, v29.4s, v6.4s + smin v30.4s, v30.4s, v6.4s + smin v31.4s, v31.4s, v6.4s + + // int32 -> int16 + sqxtn v0.4h, v16.4s + sqxtn2 v0.8h, v17.4s + sqxtn v1.4h, v18.4s + sqxtn2 v1.8h, v19.4s + sqxtn v2.4h, v20.4s + sqxtn2 v2.8h, v21.4s + sqxtn v3.4h, v22.4s + sqxtn2 v3.8h, v23.4s + sqxtn v4.4h, v24.4s + sqxtn2 v4.8h, v25.4s + sqxtn v5.4h, v26.4s + sqxtn2 v5.8h, v27.4s + sqxtn v6.4h, v28.4s + sqxtn2 v6.8h, v29.4s + sqxtn v7.4h, v30.4s + sqxtn2 v7.8h, v31.4s + + // int16 -> int8 + sqxtn v8.8b, v0.8h + sqxtn2 v8.16b, v1.8h + sqxtn v9.8b, v2.8h + sqxtn2 v9.16b, v3.8h + sqxtn v10.8b, v4.8h + sqxtn2 v10.16b, v5.8h + sqxtn v11.8b, v6.8h + sqxtn2 v11.16b, v7.8h + + cmp w23, #8 + blt Write // if rows < 8 + cmp w15, #8 + blt Write // if cols < 8 + + st1 {v8.d}[0], [x2], x24 + st1 {v8.d}[1], [x2], x24 + st1 {v9.d}[0], [x2], x24 + st1 {v9.d}[1], [x2], x24 + st1 {v10.d}[0], [x2], x24 + st1 {v10.d}[1], [x2], x24 + st1 {v11.d}[0], [x2], x24 + st1 {v11.d}[1], [x2], x24 + b Endwrite + +Write: + cmp w15, #8 + bge WriteCol8 + cmp w15, #7 + beq WriteCol7 + cmp w15, #6 + beq WriteCol6 + cmp w15, #5 + beq WriteCol5 + cmp w15, #4 + beq WriteCol4 + cmp w15, #3 + beq WriteCol3 + cmp w15, #2 + beq WriteCol2 + cmp w15, #1 + beq WriteCol1 + +WriteCol8: + st1 {v8.d}[0], [x2], x24 + cmp w23, #1 + beq Endwrite + st1 {v8.d}[1], [x2], x24 + cmp w23, #2 + beq Endwrite + st1 {v9.d}[0], [x2], x24 + cmp w23, #3 + beq Endwrite + st1 {v9.d}[1], [x2], x24 + cmp w23, #4 + beq Endwrite + st1 {v10.d}[0], [x2], x24 + cmp w23, #5 + beq Endwrite + st1 {v10.d}[1], [x2], x24 + cmp w23, #6 + beq Endwrite + st1 {v11.d}[0], [x2], x24 + cmp w23, #7 + beq Endwrite + st1 {v11.d}[1], [x2], x24 + b Endwrite + +WriteCol7: + mov x26, x2 + st1 {v8.s}[0], [x26], #4 + st1 {v8.h}[2], [x26], #2 + st1 {v8.b}[6], [x26], #1 + add x2, x2, x24 + cmp w23, #1 + beq Endwrite + mov x26, x2 + st1 {v8.s}[2], [x26], #4 + st1 {v8.h}[6], [x26], #2 + st1 {v8.b}[14], [x26], #1 + add x2, x2, x24 + cmp w23, #2 + beq Endwrite + mov x26, x2 + st1 {v9.s}[0], [x26], #4 + st1 {v9.h}[2], [x26], #2 + st1 {v9.b}[6], [x26], #1 + add x2, x2, x24 + cmp w23, #3 + beq Endwrite + mov x26, x2 + st1 {v9.s}[2], [x26], #4 + st1 {v9.h}[6], [x26], #2 + st1 {v9.b}[14], [x26], #1 + add x2, x2, x24 + cmp w23, #4 + beq Endwrite + mov x26, x2 + st1 {v10.s}[0], [x26], #4 + st1 {v10.h}[2], [x26], #2 + st1 {v10.b}[6], [x26], #1 + add x2, x2, x24 + cmp w23, #5 + beq Endwrite + mov x26, x2 + st1 {v10.s}[2], [x26], #4 + st1 {v10.h}[6], [x26], #2 + st1 {v10.b}[14], [x26], #1 + add x2, x2, x24 + cmp w23, #6 + beq Endwrite + mov x26, x2 + st1 {v11.s}[0], [x26], #4 + st1 {v11.h}[2], [x26], #2 + st1 {v11.b}[6], [x26], #1 + add x2, x2, x24 + cmp w23, #7 + beq Endwrite + mov x26, x2 + st1 {v11.s}[2], [x26], #4 + st1 {v11.h}[6], [x26], #2 + st1 {v11.b}[14], [x26], #1 + add x2, x2, x24 + b Endwrite + +WriteCol6: + mov x26, x2 + st1 {v8.s}[0], [x26], #4 + st1 {v8.h}[2], [x26], #2 + add x2, x2, x24 + cmp w23, #1 + beq Endwrite + mov x26, x2 + st1 {v8.s}[2], [x26], #4 + st1 {v8.h}[6], [x26], #2 + add x2, x2, x24 + cmp w23, #2 + beq Endwrite + mov x26, x2 + st1 {v9.s}[0], [x26], #4 + st1 {v9.h}[2], [x26], #2 + add x2, x2, x24 + cmp w23, #3 + beq Endwrite + mov x26, x2 + st1 {v9.s}[2], [x26], #4 + st1 {v9.h}[6], [x26], #2 + add x2, x2, x24 + cmp w23, #4 + beq Endwrite + mov x26, x2 + st1 {v10.s}[0], [x26], #4 + st1 {v10.h}[2], [x26], #2 + add x2, x2, x24 + cmp w23, #5 + beq Endwrite + mov x26, x2 + st1 {v10.s}[2], [x26], #4 + st1 {v10.h}[6], [x26], #2 + add x2, x2, x24 + cmp w23, #6 + beq Endwrite + mov x26, x2 + st1 {v11.s}[0], [x26], #4 + st1 {v11.h}[2], [x26], #2 + add x2, x2, x24 + cmp w23, #7 + beq Endwrite + mov x26, x2 + st1 {v11.s}[2], [x26], #4 + st1 {v11.h}[6], [x26], #2 + add x2, x2, x24 + b Endwrite + +WriteCol5: + mov x26, x2 + st1 {v8.s}[0], [x26], #4 + st1 {v8.b}[4], [x26], #1 + add x2, x2, x24 + cmp w23, #1 + beq Endwrite + mov x26, x2 + st1 {v8.s}[2], [x26], #4 + st1 {v8.b}[12], [x26], #1 + add x2, x2, x24 + cmp w23, #2 + beq Endwrite + mov x26, x2 + st1 {v9.s}[0], [x26], #4 + st1 {v9.b}[4], [x26], #1 + add x2, x2, x24 + cmp w23, #3 + beq Endwrite + mov x26, x2 + st1 {v9.s}[2], [x26], #4 + st1 {v9.b}[12], [x26], #1 + add x2, x2, x24 + cmp w23, #4 + beq Endwrite + mov x26, x2 + st1 {v10.s}[0], [x26], #4 + st1 {v10.b}[4], [x26], #1 + add x2, x2, x24 + cmp w23, #5 + beq Endwrite + mov x26, x2 + st1 {v10.s}[2], [x26], #4 + st1 {v10.b}[12], [x26], #1 + add x2, x2, x24 + cmp w23, #6 + beq Endwrite + mov x26, x2 + st1 {v11.s}[0], [x26], #4 + st1 {v11.b}[4], [x26], #1 + add x2, x2, x24 + cmp w23, #7 + beq Endwrite + mov x26, x2 + st1 {v11.s}[2], [x26], #4 + st1 {v11.b}[12], [x26], #1 + add x2, x2, x24 + b Endwrite + +WriteCol4: + st1 {v8.s}[0], [x2], x24 + cmp w23, #1 + beq Endwrite + st1 {v8.s}[2], [x2], x24 + cmp w23, #2 + beq Endwrite + st1 {v9.s}[0], [x2], x24 + cmp w23, #3 + beq Endwrite + st1 {v9.s}[2], [x2], x24 + cmp w23, #4 + beq Endwrite + st1 {v10.s}[0], [x2], x24 + cmp w23, #5 + beq Endwrite + st1 {v10.s}[2], [x2], x24 + cmp w23, #6 + beq Endwrite + st1 {v11.s}[0], [x2], x24 + cmp w23, #7 + beq Endwrite + st1 {v11.s}[2], [x2], x24 + b Endwrite + +WriteCol3: + mov x26, x2 + st1 {v8.h}[0], [x26], #2 + st1 {v8.b}[2], [x26], #1 + add x2, x2, x24 + cmp w23, #1 + beq Endwrite + mov x26, x2 + st1 {v8.h}[4], [x26], #2 + st1 {v8.b}[10], [x26], #1 + add x2, x2, x24 + cmp w23, #2 + beq Endwrite + mov x26, x2 + st1 {v9.h}[0], [x26], #2 + st1 {v9.b}[2], [x26], #1 + add x2, x2, x24 + cmp w23, #3 + beq Endwrite + mov x26, x2 + st1 {v9.h}[4], [x26], #2 + st1 {v9.b}[10], [x26], #1 + add x2, x2, x24 + cmp w23, #4 + beq Endwrite + mov x26, x2 + st1 {v10.h}[0], [x26], #2 + st1 {v10.b}[2], [x26], #1 + add x2, x2, x24 + cmp w23, #5 + beq Endwrite + mov x26, x2 + st1 {v10.h}[4], [x26], #2 + st1 {v10.b}[10], [x26], #1 + add x2, x2, x24 + cmp w23, #6 + beq Endwrite + mov x26, x2 + st1 {v11.h}[0], [x26], #2 + st1 {v11.b}[2], [x26], #1 + add x2, x2, x24 + cmp w23, #7 + beq Endwrite + mov x26, x2 + st1 {v11.h}[4], [x26], #2 + st1 {v11.b}[10], [x26], #1 + add x2, x2, x24 + b Endwrite + +WriteCol2: + st1 {v8.h}[0], [x2], x24 + cmp w23, #1 + beq Endwrite + st1 {v8.h}[4], [x2], x24 + cmp w23, #2 + beq Endwrite + st1 {v9.h}[0], [x2], x24 + cmp w23, #3 + beq Endwrite + st1 {v9.h}[4], [x2], x24 + cmp w23, #4 + beq Endwrite + st1 {v10.h}[0], [x2], x24 + cmp w23, #5 + beq Endwrite + st1 {v10.h}[4], [x2], x24 + cmp w23, #6 + beq Endwrite + st1 {v11.h}[0], [x2], x24 + cmp w23, #7 + beq Endwrite + st1 {v11.h}[4], [x2], x24 + b Endwrite + +WriteCol1: + st1 {v8.b}[0], [x2], x24 + cmp w23, #1 + beq Endwrite + st1 {v8.b}[8], [x2], x24 + cmp w23, #2 + beq Endwrite + st1 {v9.b}[0], [x2], x24 + cmp w23, #3 + beq Endwrite + st1 {v9.b}[8], [x2], x24 + cmp w23, #4 + beq Endwrite + st1 {v10.b}[0], [x2], x24 + cmp w23, #5 + beq Endwrite + st1 {v10.b}[8], [x2], x24 + cmp w23, #6 + beq Endwrite + st1 {v11.b}[0], [x2], x24 + cmp w23, #7 + beq Endwrite + st1 {v11.b}[8], [x2], x24 + b Endwrite + +Endwrite: + sub w16, w16, #8 // a row8 counter - 8 + sub w23, w23, #8 // a row counter - 8 + b L2 + +End2: + sub w4, w4, #8 // b col8 counter - 8 + sub w15, w15, #8 // b col counter - 8 + add x1, x1, x21 // b ptr + stride + add x7, x7, #32 // bias ptr + stride + add x25, x25, #8 // output + stride(8 * sizeof(int8)) + mov x2, x25 + + cmp w27, #0 + beq PerTEnd2 + add x12, x12, #32 + add x11, x11, #32 + add x13, x13, #32 +PerTEnd2: + b L1 + +End1: + sub sp, sp, #208 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ldp x27, x28, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/opt/MatmulDpInt8Opt.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/opt/MatmulDpInt8Opt.S new file mode 100644 index 0000000000000000000000000000000000000000..ef0e11d6a7029e1898c056953cf201f785fe6978 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/opt/MatmulDpInt8Opt.S @@ -0,0 +1,1098 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" +.text +.align 5 + +//void MatmulInt8DpOpt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep4, const int *a_sums, +// const int *bias, int act_min, int act_max, int out_zp, const int32_t *multiplier, +// const int32_t *left_shift, const int32_t *right_shift, size_t stride, size_t filter_peroc, +// const int32_t *filter_zp) + +// x0: a(left matrix ptr) +// x1: b(right matrix ptr) +// x2: out ptr +// x3: row +// x4: col +// x5: deep4 +// x6: a_sums +// x7: bias +// w8: act_min +// w9: act_max +// w10: out_zp +// x11: multiplier +// x12: left_shift +// x13: right_shift +// x14: stride +// x15: filter_peroc +// x28: filter_zp + +asm_function MatmulInt8DpOpt + sub sp, sp, #224 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + stp x19, x20, [sp], #16 + stp x21, x22, [sp], #16 + stp x23, x24, [sp], #16 + stp x25, x26, [sp], #16 + stp x27, x28, [sp], #16 + stp x29, x30, [sp], #16 + + ldr w8, [sp] + ldr w9, [sp, #8] + ldr w10, [sp, #16] + ldr x11, [sp, #24] + ldr x12, [sp, #32] + ldr x13, [sp, #40] + ldr x14, [sp, #48] + ldr x15, [sp, #56] + + mov x23, #4 + mul x23, x23, x5 // lhs step + mov x24, #4 + mul x24, x24, x14 // dst step + +LoopRow: + mov x16, x1 // reload rhs ptr + mov x17, x4 // reload rhs col + mov x29, x7 // reload bias ptr + mov x25, x6 // reload input_sum ptr + mov x27, x2 // reload dst ptr + ldr x28, [sp, #64] // reload filter_zp + + LoopCol: + mov x19, x27 // reload dst ptr + mov x20, x0 // reload lhs ptr + mov x21, x5 // reload depth + + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr + + cmp x17, #4 + ble LoopDepthQuarter + cmp x17, #8 + ble LoopDepthHalf + + LoopDepth: + ld1 {v0.16b}, [x20], #16 + ld1 {v1.16b, v2.16b, v3.16b, v4.16b}, [x16], #64 + sdot v16.4s, v1.16b, v0.4b[0] + sdot v17.4s, v2.16b, v0.4b[0] + sdot v18.4s, v3.16b, v0.4b[0] + sdot v19.4s, v4.16b, v0.4b[0] + sdot v20.4s, v1.16b, v0.4b[1] + sdot v21.4s, v2.16b, v0.4b[1] + sdot v22.4s, v3.16b, v0.4b[1] + sdot v23.4s, v4.16b, v0.4b[1] + sdot v24.4s, v1.16b, v0.4b[2] + sdot v25.4s, v2.16b, v0.4b[2] + sdot v26.4s, v3.16b, v0.4b[2] + sdot v27.4s, v4.16b, v0.4b[2] + sdot v28.4s, v1.16b, v0.4b[3] + sdot v29.4s, v2.16b, v0.4b[3] + sdot v30.4s, v3.16b, v0.4b[3] + sdot v31.4s, v4.16b, v0.4b[3] + + subs x21, x21, #4 + bgt LoopDepth + + Bias: + cbz x7, NoReadBias + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x29], #64 + add v16.4s, v16.4s, v0.4s + add v17.4s, v17.4s, v1.4s + add v18.4s, v18.4s, v2.4s + add v19.4s, v19.4s, v3.4s + add v20.4s, v20.4s, v0.4s + add v21.4s, v21.4s, v1.4s + add v22.4s, v22.4s, v2.4s + add v23.4s, v23.4s, v3.4s + add v24.4s, v24.4s, v0.4s + add v25.4s, v25.4s, v1.4s + add v26.4s, v26.4s, v2.4s + add v27.4s, v27.4s, v3.4s + add v28.4s, v28.4s, v0.4s + add v29.4s, v29.4s, v1.4s + add v30.4s, v30.4s, v2.4s + add v31.4s, v31.4s, v3.4s + + NoReadBias: + ld1r {v12.4s}, [x25], #4 + ld1r {v13.4s}, [x25], #4 + ld1r {v14.4s}, [x25], #4 + ld1r {v15.4s}, [x25], #4 + cbnz x15, PerChannelSum + + PerTensorSum: + sub v16.4s, v16.4s, v12.4s + sub v17.4s, v17.4s, v12.4s + sub v18.4s, v18.4s, v12.4s + sub v19.4s, v19.4s, v12.4s + sub v20.4s, v20.4s, v13.4s + sub v21.4s, v21.4s, v13.4s + sub v22.4s, v22.4s, v13.4s + sub v23.4s, v23.4s, v13.4s + sub v24.4s, v24.4s, v14.4s + sub v25.4s, v25.4s, v14.4s + sub v26.4s, v26.4s, v14.4s + sub v27.4s, v27.4s, v14.4s + sub v28.4s, v28.4s, v15.4s + sub v29.4s, v29.4s, v15.4s + sub v30.4s, v30.4s, v15.4s + sub v31.4s, v31.4s, v15.4s + + b PerTensor + + PerChannelSum: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x28], #64 + mul v0.4s, v8.4s, v12.4s + mul v1.4s, v9.4s, v12.4s + mul v2.4s, v10.4s, v12.4s + mul v3.4s, v11.4s, v12.4s + mul v4.4s, v8.4s, v13.4s + mul v5.4s, v9.4s, v13.4s + mul v6.4s, v10.4s, v13.4s + mul v7.4s, v11.4s, v13.4s + sub v16.4s, v16.4s, v0.4s + sub v17.4s, v17.4s, v1.4s + sub v18.4s, v18.4s, v2.4s + sub v19.4s, v19.4s, v3.4s + sub v20.4s, v20.4s, v4.4s + sub v21.4s, v21.4s, v5.4s + sub v22.4s, v22.4s, v6.4s + sub v23.4s, v23.4s, v7.4s + mul v0.4s, v8.4s, v14.4s + mul v1.4s, v9.4s, v14.4s + mul v2.4s, v10.4s, v14.4s + mul v3.4s, v11.4s, v14.4s + mul v4.4s, v8.4s, v15.4s + mul v5.4s, v9.4s, v15.4s + mul v6.4s, v10.4s, v15.4s + mul v7.4s, v11.4s, v15.4s + sub v24.4s, v24.4s, v0.4s + sub v25.4s, v25.4s, v1.4s + sub v26.4s, v26.4s, v2.4s + sub v27.4s, v27.4s, v3.4s + sub v28.4s, v28.4s, v4.4s + sub v29.4s, v29.4s, v5.4s + sub v30.4s, v30.4s, v6.4s + sub v31.4s, v31.4s, v7.4s + + PerTensor: + cbnz x15, PerChannel + ld1r {v0.4s}, [x12] + mov v1.16b, v0.16b + mov v2.16b, v0.16b + mov v3.16b, v0.16b + ld1r {v4.4s}, [x11] + mov v5.16b, v4.16b + mov v6.16b, v4.16b + mov v7.16b, v4.16b + ld1r {v8.4s}, [x13] + mov v9.16b, v8.16b + mov v10.16b, v8.16b + mov v11.16b, v8.16b + + b Quantization + + PerChannel: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x12], #64 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x11], #64 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x13], #64 + + Quantization: + sqshl v16.4s, v16.4s, v0.4s + sqshl v17.4s, v17.4s, v1.4s + sqshl v18.4s, v18.4s, v2.4s + sqshl v19.4s, v19.4s, v3.4s + sqshl v20.4s, v20.4s, v0.4s + sqshl v21.4s, v21.4s, v1.4s + sqshl v22.4s, v22.4s, v2.4s + sqshl v23.4s, v23.4s, v3.4s + sqshl v24.4s, v24.4s, v0.4s + sqshl v25.4s, v25.4s, v1.4s + sqshl v26.4s, v26.4s, v2.4s + sqshl v27.4s, v27.4s, v3.4s + sqshl v28.4s, v28.4s, v0.4s + sqshl v29.4s, v29.4s, v1.4s + sqshl v30.4s, v30.4s, v2.4s + sqshl v31.4s, v31.4s, v3.4s + + sqrdmulh v16.4s, v16.4s, v4.4s + sqrdmulh v17.4s, v17.4s, v5.4s + sqrdmulh v18.4s, v18.4s, v6.4s + sqrdmulh v19.4s, v19.4s, v7.4s + sqrdmulh v20.4s, v20.4s, v4.4s + sqrdmulh v21.4s, v21.4s, v5.4s + sqrdmulh v22.4s, v22.4s, v6.4s + sqrdmulh v23.4s, v23.4s, v7.4s + sqrdmulh v24.4s, v24.4s, v4.4s + sqrdmulh v25.4s, v25.4s, v5.4s + sqrdmulh v26.4s, v26.4s, v6.4s + sqrdmulh v27.4s, v27.4s, v7.4s + sqrdmulh v28.4s, v28.4s, v4.4s + sqrdmulh v29.4s, v29.4s, v5.4s + sqrdmulh v30.4s, v30.4s, v6.4s + sqrdmulh v31.4s, v31.4s, v7.4s + + and v0.16b, v8.16b, v16.16b + sshr v0.4s, v0.4s, #31 + sqadd v16.4s, v16.4s, v0.4s + srshl v16.4s, v16.4s, v8.4s + and v1.16b, v9.16b, v17.16b + sshr v1.4s, v1.4s, #31 + sqadd v17.4s, v17.4s, v1.4s + srshl v17.4s, v17.4s, v9.4s + and v2.16b, v10.16b, v18.16b + sshr v2.4s, v2.4s, #31 + sqadd v18.4s, v18.4s, v2.4s + srshl v18.4s, v18.4s, v10.4s + and v3.16b, v11.16b, v19.16b + sshr v3.4s, v3.4s, #31 + sqadd v19.4s, v19.4s, v3.4s + srshl v19.4s, v19.4s, v11.4s + + and v0.16b, v8.16b, v20.16b + sshr v0.4s, v0.4s, #31 + sqadd v20.4s, v20.4s, v0.4s + srshl v20.4s, v20.4s, v8.4s + and v1.16b, v9.16b, v21.16b + sshr v1.4s, v1.4s, #31 + sqadd v21.4s, v21.4s, v1.4s + srshl v21.4s, v21.4s, v9.4s + and v2.16b, v10.16b, v22.16b + sshr v2.4s, v2.4s, #31 + sqadd v22.4s, v22.4s, v2.4s + srshl v22.4s, v22.4s, v10.4s + and v3.16b, v11.16b, v23.16b + sshr v3.4s, v3.4s, #31 + sqadd v23.4s, v23.4s, v3.4s + srshl v23.4s, v23.4s, v11.4s + + and v0.16b, v8.16b, v24.16b + sshr v0.4s, v0.4s, #31 + sqadd v24.4s, v24.4s, v0.4s + srshl v24.4s, v24.4s, v8.4s + and v1.16b, v9.16b, v25.16b + sshr v1.4s, v1.4s, #31 + sqadd v25.4s, v25.4s, v1.4s + srshl v25.4s, v25.4s, v9.4s + and v2.16b, v10.16b, v26.16b + sshr v2.4s, v2.4s, #31 + sqadd v26.4s, v26.4s, v2.4s + srshl v26.4s, v26.4s, v10.4s + and v3.16b, v11.16b, v27.16b + sshr v3.4s, v3.4s, #31 + sqadd v27.4s, v27.4s, v3.4s + srshl v27.4s, v27.4s, v11.4s + + and v0.16b, v8.16b, v28.16b + sshr v0.4s, v0.4s, #31 + sqadd v28.4s, v28.4s, v0.4s + srshl v28.4s, v28.4s, v8.4s + and v1.16b, v9.16b, v29.16b + sshr v1.4s, v1.4s, #31 + sqadd v29.4s, v29.4s, v1.4s + srshl v29.4s, v29.4s, v9.4s + and v2.16b, v10.16b, v30.16b + sshr v2.4s, v2.4s, #31 + sqadd v30.4s, v30.4s, v2.4s + srshl v30.4s, v30.4s, v10.4s + and v3.16b, v11.16b, v31.16b + sshr v3.4s, v3.4s, #31 + sqadd v31.4s, v31.4s, v3.4s + srshl v31.4s, v31.4s, v11.4s + + // zp + dup v6.4s, w10 + add v16.4s, v16.4s, v6.4s + add v17.4s, v17.4s, v6.4s + add v18.4s, v18.4s, v6.4s + add v19.4s, v19.4s, v6.4s + add v20.4s, v20.4s, v6.4s + add v21.4s, v21.4s, v6.4s + add v22.4s, v22.4s, v6.4s + add v23.4s, v23.4s, v6.4s + add v24.4s, v24.4s, v6.4s + add v25.4s, v25.4s, v6.4s + add v26.4s, v26.4s, v6.4s + add v27.4s, v27.4s, v6.4s + add v28.4s, v28.4s, v6.4s + add v29.4s, v29.4s, v6.4s + add v30.4s, v30.4s, v6.4s + add v31.4s, v31.4s, v6.4s + + // min + dup v0.4s, w8 + smax v16.4s, v16.4s, v0.4s + smax v17.4s, v17.4s, v0.4s + smax v18.4s, v18.4s, v0.4s + smax v19.4s, v19.4s, v0.4s + smax v20.4s, v20.4s, v0.4s + smax v21.4s, v21.4s, v0.4s + smax v22.4s, v22.4s, v0.4s + smax v23.4s, v23.4s, v0.4s + smax v24.4s, v24.4s, v0.4s + smax v25.4s, v25.4s, v0.4s + smax v26.4s, v26.4s, v0.4s + smax v27.4s, v27.4s, v0.4s + smax v28.4s, v28.4s, v0.4s + smax v29.4s, v29.4s, v0.4s + smax v30.4s, v30.4s, v0.4s + smax v31.4s, v31.4s, v0.4s + + // max + dup v1.4s, w9 + smin v16.4s, v16.4s, v1.4s + smin v17.4s, v17.4s, v1.4s + smin v18.4s, v18.4s, v1.4s + smin v19.4s, v19.4s, v1.4s + smin v20.4s, v20.4s, v1.4s + smin v21.4s, v21.4s, v1.4s + smin v22.4s, v22.4s, v1.4s + smin v23.4s, v23.4s, v1.4s + smin v24.4s, v24.4s, v1.4s + smin v25.4s, v25.4s, v1.4s + smin v26.4s, v26.4s, v1.4s + smin v27.4s, v27.4s, v1.4s + smin v28.4s, v28.4s, v1.4s + smin v29.4s, v29.4s, v1.4s + smin v30.4s, v30.4s, v1.4s + smin v31.4s, v31.4s, v1.4s + + sqxtn v16.4h, v16.4s + sqxtn2 v16.8h, v17.4s + sqxtn v0.8b, v16.8h + sqxtn v18.4h, v18.4s + sqxtn2 v18.8h, v19.4s + sqxtn2 v0.16b, v18.8h + + sqxtn v20.4h, v20.4s + sqxtn2 v20.8h, v21.4s + sqxtn v1.8b, v20.8h + sqxtn v22.4h, v22.4s + sqxtn2 v22.8h, v23.4s + sqxtn2 v1.16b, v22.8h + + sqxtn v24.4h, v24.4s + sqxtn2 v24.8h, v25.4s + sqxtn v2.8b, v24.8h + sqxtn v26.4h, v26.4s + sqxtn2 v26.8h, v27.4s + sqxtn2 v2.16b, v26.8h + + sqxtn v28.4h, v28.4s + sqxtn2 v28.8h, v29.4s + sqxtn v3.8b, v28.8h + sqxtn v30.4h, v30.4s + sqxtn2 v30.8h, v31.4s + sqxtn2 v3.16b, v30.8h + + b WriteStart + + LoopDepthHalf: + ld1 {v0.16b}, [x20], #16 + ld1 {v1.16b, v2.16b}, [x16] + add x16, x16, #64 + sdot v16.4s, v1.16b, v0.4b[0] + sdot v17.4s, v2.16b, v0.4b[0] + sdot v20.4s, v1.16b, v0.4b[1] + sdot v21.4s, v2.16b, v0.4b[1] + sdot v24.4s, v1.16b, v0.4b[2] + sdot v25.4s, v2.16b, v0.4b[2] + sdot v28.4s, v1.16b, v0.4b[3] + sdot v29.4s, v2.16b, v0.4b[3] + + subs x21, x21, #4 + bgt LoopDepthHalf + + BiasHalf: + cbz x7, NoReadBiasHalf + ld1 {v0.4s, v1.4s}, [x29] + add x29, x29, #64 + add v16.4s, v16.4s, v0.4s + add v17.4s, v17.4s, v1.4s + add v20.4s, v20.4s, v0.4s + add v21.4s, v21.4s, v1.4s + add v24.4s, v24.4s, v0.4s + add v25.4s, v25.4s, v1.4s + add v28.4s, v28.4s, v0.4s + add v29.4s, v29.4s, v1.4s + + NoReadBiasHalf: + ld1r {v12.4s}, [x25], #4 + ld1r {v13.4s}, [x25], #4 + ld1r {v14.4s}, [x25], #4 + ld1r {v15.4s}, [x25], #4 + cbnz x15, PerChannelSumHalf + + PerTensorSumHalf: + sub v16.4s, v16.4s, v12.4s + sub v17.4s, v17.4s, v12.4s + sub v20.4s, v20.4s, v13.4s + sub v21.4s, v21.4s, v13.4s + sub v24.4s, v24.4s, v14.4s + sub v25.4s, v25.4s, v14.4s + sub v28.4s, v28.4s, v15.4s + sub v29.4s, v29.4s, v15.4s + + b PerTensorHalf + + PerChannelSumHalf: + ld1 {v8.4s, v9.4s}, [x28] + add x28, x28, #64 + mul v0.4s, v8.4s, v12.4s + mul v1.4s, v9.4s, v12.4s + mul v4.4s, v8.4s, v13.4s + mul v5.4s, v9.4s, v13.4s + sub v16.4s, v16.4s, v0.4s + sub v17.4s, v17.4s, v1.4s + sub v20.4s, v20.4s, v4.4s + sub v21.4s, v21.4s, v5.4s + mul v2.4s, v8.4s, v14.4s + mul v3.4s, v9.4s, v14.4s + mul v6.4s, v8.4s, v15.4s + mul v7.4s, v9.4s, v15.4s + sub v24.4s, v24.4s, v2.4s + sub v25.4s, v25.4s, v3.4s + sub v28.4s, v28.4s, v6.4s + sub v29.4s, v29.4s, v7.4s + + PerTensorHalf: + cbnz x15, PerChannelHalf + ld1r {v0.4s}, [x12] + mov v1.16b, v0.16b + ld1r {v4.4s}, [x11] + mov v5.16b, v4.16b + ld1r {v8.4s}, [x13] + mov v9.16b, v8.16b + + b QuantizationHalf + + PerChannelHalf: + ld1 {v0.4s, v1.4s}, [x12] + add x12, x12, #64 + ld1 {v4.4s, v5.4s}, [x11] + add x11, x11, #64 + ld1 {v8.4s, v9.4s}, [x13] + add x13, x13, #64 + + QuantizationHalf: + sqshl v16.4s, v16.4s, v0.4s + sqshl v17.4s, v17.4s, v1.4s + sqshl v20.4s, v20.4s, v0.4s + sqshl v21.4s, v21.4s, v1.4s + sqshl v24.4s, v24.4s, v0.4s + sqshl v25.4s, v25.4s, v1.4s + sqshl v28.4s, v28.4s, v0.4s + sqshl v29.4s, v29.4s, v1.4s + + sqrdmulh v16.4s, v16.4s, v4.4s + sqrdmulh v17.4s, v17.4s, v5.4s + sqrdmulh v20.4s, v20.4s, v4.4s + sqrdmulh v21.4s, v21.4s, v5.4s + sqrdmulh v24.4s, v24.4s, v4.4s + sqrdmulh v25.4s, v25.4s, v5.4s + sqrdmulh v28.4s, v28.4s, v4.4s + sqrdmulh v29.4s, v29.4s, v5.4s + + and v0.16b, v8.16b, v16.16b + sshr v0.4s, v0.4s, #31 + sqadd v16.4s, v16.4s, v0.4s + srshl v16.4s, v16.4s, v8.4s + and v1.16b, v9.16b, v17.16b + sshr v1.4s, v1.4s, #31 + sqadd v17.4s, v17.4s, v1.4s + srshl v17.4s, v17.4s, v9.4s + + and v0.16b, v8.16b, v20.16b + sshr v0.4s, v0.4s, #31 + sqadd v20.4s, v20.4s, v0.4s + srshl v20.4s, v20.4s, v8.4s + and v1.16b, v9.16b, v21.16b + sshr v1.4s, v1.4s, #31 + sqadd v21.4s, v21.4s, v1.4s + srshl v21.4s, v21.4s, v9.4s + + and v0.16b, v8.16b, v24.16b + sshr v0.4s, v0.4s, #31 + sqadd v24.4s, v24.4s, v0.4s + srshl v24.4s, v24.4s, v8.4s + and v1.16b, v9.16b, v25.16b + sshr v1.4s, v1.4s, #31 + sqadd v25.4s, v25.4s, v1.4s + srshl v25.4s, v25.4s, v9.4s + + and v0.16b, v8.16b, v28.16b + sshr v0.4s, v0.4s, #31 + sqadd v28.4s, v28.4s, v0.4s + srshl v28.4s, v28.4s, v8.4s + and v1.16b, v9.16b, v29.16b + sshr v1.4s, v1.4s, #31 + sqadd v29.4s, v29.4s, v1.4s + srshl v29.4s, v29.4s, v9.4s + + // zp + dup v6.4s, w10 + add v16.4s, v16.4s, v6.4s + add v17.4s, v17.4s, v6.4s + add v20.4s, v20.4s, v6.4s + add v21.4s, v21.4s, v6.4s + add v24.4s, v24.4s, v6.4s + add v25.4s, v25.4s, v6.4s + add v28.4s, v28.4s, v6.4s + add v29.4s, v29.4s, v6.4s + + // min + dup v0.4s, w8 + smax v16.4s, v16.4s, v0.4s + smax v17.4s, v17.4s, v0.4s + smax v20.4s, v20.4s, v0.4s + smax v21.4s, v21.4s, v0.4s + smax v24.4s, v24.4s, v0.4s + smax v25.4s, v25.4s, v0.4s + smax v28.4s, v28.4s, v0.4s + smax v29.4s, v29.4s, v0.4s + + // max + dup v1.4s, w9 + smin v16.4s, v16.4s, v1.4s + smin v17.4s, v17.4s, v1.4s + smin v20.4s, v20.4s, v1.4s + smin v21.4s, v21.4s, v1.4s + smin v24.4s, v24.4s, v1.4s + smin v25.4s, v25.4s, v1.4s + smin v28.4s, v28.4s, v1.4s + smin v29.4s, v29.4s, v1.4s + + sqxtn v16.4h, v16.4s + sqxtn2 v16.8h, v17.4s + sqxtn v0.8b, v16.8h + + sqxtn v20.4h, v20.4s + sqxtn2 v20.8h, v21.4s + sqxtn v1.8b, v20.8h + + sqxtn v24.4h, v24.4s + sqxtn2 v24.8h, v25.4s + sqxtn v2.8b, v24.8h + + sqxtn v28.4h, v28.4s + sqxtn2 v28.8h, v29.4s + sqxtn v3.8b, v28.8h + + b WriteStart + + LoopDepthQuarter: + ld1 {v0.16b}, [x20], #16 + ld1 {v1.16b}, [x16] + add x16, x16, #64 + sdot v16.4s, v1.16b, v0.4b[0] + sdot v20.4s, v1.16b, v0.4b[1] + sdot v24.4s, v1.16b, v0.4b[2] + sdot v28.4s, v1.16b, v0.4b[3] + + subs x21, x21, #4 + bgt LoopDepthQuarter + + BiasQuarter: + cbz x7, NoReadBiasQuarter + ld1 {v0.4s}, [x29] + add x29, x29, #64 + add v16.4s, v16.4s, v0.4s + add v20.4s, v20.4s, v0.4s + add v24.4s, v24.4s, v0.4s + add v28.4s, v28.4s, v0.4s + + NoReadBiasQuarter: + ld1r {v12.4s}, [x25], #4 + ld1r {v13.4s}, [x25], #4 + ld1r {v14.4s}, [x25], #4 + ld1r {v15.4s}, [x25], #4 + cbnz x15, PerChannelSumQuarter + + PerTensorSumQuarter: + sub v16.4s, v16.4s, v12.4s + sub v20.4s, v20.4s, v13.4s + sub v24.4s, v24.4s, v14.4s + sub v28.4s, v28.4s, v15.4s + + b PerTensorQuarter + + PerChannelSumQuarter: + ld1 {v8.4s}, [x28] + add x28, x28, #64 + mul v0.4s, v8.4s, v12.4s + mul v4.4s, v8.4s, v13.4s + sub v16.4s, v16.4s, v0.4s + sub v20.4s, v20.4s, v4.4s + mul v2.4s, v8.4s, v14.4s + mul v6.4s, v8.4s, v15.4s + sub v24.4s, v24.4s, v2.4s + sub v28.4s, v28.4s, v6.4s + + PerTensorQuarter: + cbnz x15, PerChannelQuarter + ld1r {v0.4s}, [x12] + ld1r {v4.4s}, [x11] + ld1r {v8.4s}, [x13] + + b QuantizationHalf + + PerChannelQuarter: + ld1 {v0.4s}, [x12] + add x12, x12, #64 + ld1 {v4.4s}, [x11] + add x11, x11, #64 + ld1 {v8.4s}, [x13] + add x13, x13, #64 + + QuantizationQuarter: + sqshl v16.4s, v16.4s, v0.4s + sqshl v20.4s, v20.4s, v0.4s + sqshl v24.4s, v24.4s, v0.4s + sqshl v28.4s, v28.4s, v0.4s + + sqrdmulh v16.4s, v16.4s, v4.4s + sqrdmulh v20.4s, v20.4s, v4.4s + sqrdmulh v24.4s, v24.4s, v4.4s + sqrdmulh v28.4s, v28.4s, v4.4s + + and v0.16b, v8.16b, v16.16b + sshr v0.4s, v0.4s, #31 + sqadd v16.4s, v16.4s, v0.4s + srshl v16.4s, v16.4s, v8.4s + + and v0.16b, v8.16b, v20.16b + sshr v0.4s, v0.4s, #31 + sqadd v20.4s, v20.4s, v0.4s + srshl v20.4s, v20.4s, v8.4s + + and v0.16b, v8.16b, v24.16b + sshr v0.4s, v0.4s, #31 + sqadd v24.4s, v24.4s, v0.4s + srshl v24.4s, v24.4s, v8.4s + + and v0.16b, v8.16b, v28.16b + sshr v0.4s, v0.4s, #31 + sqadd v28.4s, v28.4s, v0.4s + srshl v28.4s, v28.4s, v8.4s + + // zp + dup v6.4s, w10 + add v16.4s, v16.4s, v6.4s + add v20.4s, v20.4s, v6.4s + add v24.4s, v24.4s, v6.4s + add v28.4s, v28.4s, v6.4s + + // min + dup v0.4s, w8 + smax v16.4s, v16.4s, v0.4s + smax v20.4s, v20.4s, v0.4s + smax v24.4s, v24.4s, v0.4s + smax v28.4s, v28.4s, v0.4s + + // max + dup v1.4s, w9 + smin v16.4s, v16.4s, v1.4s + smin v20.4s, v20.4s, v1.4s + smin v24.4s, v24.4s, v1.4s + smin v28.4s, v28.4s, v1.4s + + sqxtn v16.4h, v16.4s + sqxtn v0.8b, v16.8h + + sqxtn v20.4h, v20.4s + sqxtn v1.8b, v20.8h + + sqxtn v24.4h, v24.4s + sqxtn v2.8b, v24.8h + + sqxtn v28.4h, v28.4s + sqxtn v3.8b, v28.8h + + b WriteStart + + WriteStart: + cmp x17, #1 + beq Write1 + cmp x17, #2 + beq Write2 + cmp x17, #3 + beq Write3 + cmp x17, #4 + beq Write4 + cmp x17, #5 + beq Write5 + cmp x17, #6 + beq Write6 + cmp x17, #7 + beq Write7 + cmp x17, #8 + beq Write8 + cmp x17, #9 + beq Write9 + cmp x17, #10 + beq Write10 + cmp x17, #11 + beq Write11 + cmp x17, #12 + beq Write12 + cmp x17, #13 + beq Write13 + cmp x17, #14 + beq Write14 + cmp x17, #15 + beq Write15 + b Write16 + + Write1: + add x27, x27, #1 + st1 {v0.b}[0], [x19], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.b}[0], [x19], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.b}[0], [x19], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.b}[0], [x19], x14 + b WriteEnd + Write2: + add x27, x27, #2 + st1 {v0.h}[0], [x19], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.h}[0], [x19], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.h}[0], [x19], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.h}[0], [x19], x14 + b WriteEnd + Write3: + add x27, x27, #3 + add x22, x19, #2 + st1 {v0.h}[0], [x19], x14 + st1 {v0.b}[2], [x22], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.h}[0], [x19], x14 + st1 {v1.b}[2], [x22], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.h}[0], [x19], x14 + st1 {v2.b}[2], [x22], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.h}[0], [x19], x14 + st1 {v3.b}[2], [x22], x14 + b WriteEnd + Write4: + add x27, x27, #4 + st1 {v0.s}[0], [x19], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.s}[0], [x19], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.s}[0], [x19], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.s}[0], [x19], x14 + b WriteEnd + Write5: + add x27, x27, #5 + add x22, x19, #4 + st1 {v0.s}[0], [x19], x14 + st1 {v0.b}[4], [x22], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.s}[0], [x19], x14 + st1 {v1.b}[4], [x22], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.s}[0], [x19], x14 + st1 {v2.b}[4], [x22], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.s}[0], [x19], x14 + st1 {v3.b}[4], [x22], x14 + b WriteEnd + Write6: + add x27, x27, #6 + add x22, x19, #4 + st1 {v0.s}[0], [x19], x14 + st1 {v0.h}[2], [x22], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.s}[0], [x19], x14 + st1 {v1.h}[2], [x22], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.s}[0], [x19], x14 + st1 {v2.h}[2], [x22], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.s}[0], [x19], x14 + st1 {v3.h}[2], [x22], x14 + b WriteEnd + Write7: + add x27, x27, #7 + add x22, x19, #4 + add x26, x19, #6 + st1 {v0.s}[0], [x19], x14 + st1 {v0.h}[2], [x22], x14 + st1 {v0.b}[6], [x26], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.s}[0], [x19], x14 + st1 {v1.h}[2], [x22], x14 + st1 {v1.b}[6], [x26], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.s}[0], [x19], x14 + st1 {v2.h}[2], [x22], x14 + st1 {v2.b}[6], [x26], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.s}[0], [x19], x14 + st1 {v3.h}[2], [x22], x14 + st1 {v3.b}[6], [x26], x14 + b WriteEnd + Write8: + add x27, x27, #8 + st1 {v0.8b}, [x19], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.8b}, [x19], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.8b}, [x19], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.8b}, [x19], x14 + b WriteEnd + Write9: + add x27, x27, #9 + add x22, x19, #8 + st1 {v0.8b}, [x19], x14 + st1 {v0.b}[8], [x22], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.8b}, [x19], x14 + st1 {v1.b}[8], [x22], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.8b}, [x19], x14 + st1 {v2.b}[8], [x22], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.8b}, [x19], x14 + st1 {v3.b}[8], [x22], x14 + b WriteEnd + Write10: + add x27, x27, #10 + add x22, x19, #8 + st1 {v0.8b}, [x19], x14 + st1 {v0.h}[4], [x22], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.8b}, [x19], x14 + st1 {v1.h}[4], [x22], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.8b}, [x19], x14 + st1 {v2.h}[4], [x22], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.8b}, [x19], x14 + st1 {v3.h}[4], [x22], x14 + b WriteEnd + Write11: + add x27, x27, #11 + add x22, x19, #8 + add x26, x19, #10 + st1 {v0.8b}, [x19], x14 + st1 {v0.h}[4], [x22], x14 + st1 {v0.b}[10], [x26], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.8b}, [x19], x14 + st1 {v1.h}[4], [x22], x14 + st1 {v1.b}[10], [x26], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.8b}, [x19], x14 + st1 {v2.h}[4], [x22], x14 + st1 {v2.b}[10], [x26], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.8b}, [x19], x14 + st1 {v3.h}[4], [x22], x14 + st1 {v3.b}[10], [x26], x14 + b WriteEnd + Write12: + add x27, x27, #12 + add x22, x19, #8 + st1 {v0.8b}, [x19], x14 + st1 {v0.s}[2], [x22], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.8b}, [x19], x14 + st1 {v1.s}[2], [x22], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.8b}, [x19], x14 + st1 {v2.s}[2], [x22], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.8b}, [x19], x14 + st1 {v3.s}[2], [x22], x14 + b WriteEnd + Write13: + add x27, x27, #13 + add x22, x19, #8 + add x26, x19, #12 + st1 {v0.8b}, [x19], x14 + st1 {v0.s}[2], [x22], x14 + st1 {v0.b}[12], [x26], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.8b}, [x19], x14 + st1 {v1.s}[2], [x22], x14 + st1 {v1.b}[12], [x26], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.8b}, [x19], x14 + st1 {v2.s}[2], [x22], x14 + st1 {v2.b}[12], [x26], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.8b}, [x19], x14 + st1 {v3.s}[2], [x22], x14 + st1 {v3.b}[12], [x26], x14 + b WriteEnd + Write14: + add x27, x27, #14 + add x22, x19, #8 + add x26, x19, #12 + st1 {v0.8b}, [x19], x14 + st1 {v0.s}[2], [x22], x14 + st1 {v0.h}[6], [x26], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.8b}, [x19], x14 + st1 {v1.s}[2], [x22], x14 + st1 {v1.h}[6], [x26], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.8b}, [x19], x14 + st1 {v2.s}[2], [x22], x14 + st1 {v2.h}[6], [x26], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.8b}, [x19], x14 + st1 {v3.s}[2], [x22], x14 + st1 {v3.h}[6], [x26], x14 + b WriteEnd + Write15: + add x27, x27, #15 + add x22, x19, #8 + add x26, x19, #12 + add x21, x19, #14 + st1 {v0.8b}, [x19], x14 + st1 {v0.s}[2], [x22], x14 + st1 {v0.h}[6], [x26], x14 + st1 {v0.b}[14], [x21], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.8b}, [x19], x14 + st1 {v1.s}[2], [x22], x14 + st1 {v1.h}[6], [x26], x14 + st1 {v1.b}[14], [x21], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.8b}, [x19], x14 + st1 {v2.s}[2], [x22], x14 + st1 {v2.h}[6], [x26], x14 + st1 {v2.b}[14], [x21], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.8b}, [x19], x14 + st1 {v3.s}[2], [x22], x14 + st1 {v3.h}[6], [x26], x14 + st1 {v3.b}[14], [x21], x14 + b WriteEnd + Write16: + add x27, x27, #16 + st1 {v0.16b}, [x19], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.16b}, [x19], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.16b}, [x19], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.16b}, [x19], x14 + + WriteEnd: + subs x17, x17, #16 + ble LoopColEnd + mov x25, x6 + b LoopCol + +LoopColEnd: + subs x3, x3, #4 + ble LoopRowEnd + ldr x11, [sp, #24] + ldr x12, [sp, #32] + ldr x13, [sp, #40] + add x6, x6, #16 + add x0, x0, x23 + add x2, x2, x24 + b LoopRow + +LoopRowEnd: + sub sp, sp, #224 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ldp x27, x28, [sp], #16 + ldp x29, x30, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly/opt/MatmulOptR4Int8.S b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/opt/MatmulOptR4Int8.S new file mode 100644 index 0000000000000000000000000000000000000000..e746d7368bacba757ac5f9f59c49a04f83710d82 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly/opt/MatmulOptR4Int8.S @@ -0,0 +1,155 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_ARM64 +#include "nnacl/assembly_global.h" +.text +.align 5 + +//void MatMulOptR4Int8Neon64(const int8_t *a, const int8_t *b, int *dst, int row4, int col4, int deep16, +// const int *input_sum, const int *bias) + +// x0: a(left matrix ptr) +// x1: b(right matrix ptr) +// x2: out ptr +// w3: row4 +// w4: col4 +// w5: deep16 +// x6: a_sums +// x7: bias + +asm_function MatMulOptR4Int8Neon64 + sub sp, sp, #144 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + stp x19, x20, [sp], #16 + + mov w15, #0 // b col index + mov w16, #0 // a row index + mov w17, #4 // sizeof(int8)*4 + mul w12, w5, w17 // the stride of a/b: sizeof(int8)*4*deep16 + +L1: + cmp w15, w4 + beq End1 + + mov w16, #0 // reset a row index + mov x17, x0 // reload a ptr + mov x13, x6 // reload a_sums ptr +L2: + cmp w16, w3 + beq End2 + + mov x19, x1 // reload b ptr + mov x10, x7 // reload bias ptr + mov w11, w5 // reload depth + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr +L3: + cmp w11, #0 + beq End3 + + ld1 {v0.16b}, [x17], #16 + ld1 {v1.16b}, [x17], #16 + ld1 {v2.16b}, [x17], #16 + ld1 {v3.16b}, [x17], #16 + ld1 {v4.16b}, [x19], #16 + ld1 {v5.16b}, [x19], #16 + ld1 {v6.16b}, [x19], #16 + ld1 {v7.16b}, [x19], #16 + + sdot v16.4s, v4.16b, v0.16b + sdot v17.4s, v5.16b, v0.16b + sdot v18.4s, v6.16b, v0.16b + sdot v19.4s, v7.16b, v0.16b + sdot v20.4s, v4.16b, v1.16b + sdot v21.4s, v5.16b, v1.16b + sdot v22.4s, v6.16b, v1.16b + sdot v23.4s, v7.16b, v1.16b + sdot v24.4s, v4.16b, v2.16b + sdot v25.4s, v5.16b, v2.16b + sdot v26.4s, v6.16b, v2.16b + sdot v27.4s, v7.16b, v2.16b + sdot v28.4s, v4.16b, v3.16b + sdot v29.4s, v5.16b, v3.16b + sdot v30.4s, v6.16b, v3.16b + sdot v31.4s, v7.16b, v3.16b + subs w11, w11, #16 // depth + 16 + b L3 + +End3: + addp v16.4s, v16.4s, v17.4s + addp v18.4s, v18.4s, v19.4s + addp v20.4s, v20.4s, v21.4s + addp v22.4s, v22.4s, v23.4s + addp v24.4s, v24.4s, v25.4s + addp v26.4s, v26.4s, v27.4s + addp v28.4s, v28.4s, v29.4s + addp v30.4s, v30.4s, v31.4s + + addp v16.4s, v16.4s, v18.4s + addp v17.4s, v20.4s, v22.4s + addp v18.4s, v24.4s, v26.4s + addp v19.4s, v28.4s, v30.4s + + // Add (Bias+Depth*Za*Zb-Za*Bsums) + ld1 {v15.4s}, [x10], #16 + add v16.4s, v16.4s, v15.4s + add v17.4s, v17.4s, v15.4s + add v18.4s, v18.4s, v15.4s + add v19.4s, v19.4s, v15.4s + + // Subtract (Asums*Zb) + ld1 {v14.4s}, [x13], #16 + dup v20.4s, v14.s[0] + dup v21.4s, v14.s[1] + dup v22.4s, v14.s[2] + dup v23.4s, v14.s[3] + sub v16.4s, v16.4s, v20.4s + sub v17.4s, v17.4s, v21.4s + sub v18.4s, v18.4s, v22.4s + sub v19.4s, v19.4s, v23.4s + + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 + add w16, w16, #4 // a row index + 4 + b L2 + +End2: + add w15, w15, #4 // b col index + 4 + add x1, x1, x12 // b ptr + stride + add x7, x7, #16 // bias ptr + stride + b L1 + +End1: + sub sp, sp, #144 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ret +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/assembly_global.h b/mindspore-lite/ops/kernel/cpu/nnacl/assembly_global.h new file mode 100644 index 0000000000000000000000000000000000000000..d1f5ca8bd6024efc1beaf2c70fc9a6f24ff24a1c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/assembly_global.h @@ -0,0 +1,50 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_ASSEMBLY_GLOBAL_H +#define NNACL_ASSEMBLY_GLOBAL_H + +// clang-format off +.macro asm_function fname +#ifdef __APPLE__ +.globl _\fname +_\fname: +#else +.global \fname +#ifdef __ELF__ +.hidden \fname +.type \fname, %function +#endif +\fname: +#endif +.endm + +// clang-format off +.macro asm_default_function fname +#ifdef __APPLE__ +.globl _\fname +_\fname: +#else +.global \fname +#ifdef __ELF__ +.type \fname, %function +#endif +\fname: +#endif +.endm + +// clang-format on + +#endif // NNACL_ASSEMBLY_GLOBAL_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/attention_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/attention_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..4b645a33785602a01d2b07137f10f71765f0f02a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/attention_parameter.h @@ -0,0 +1,47 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_ATTENTION_PARAMETER_H_ +#define NNACL_ATTENTION_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct AttentionParameter { + OpParameter op_parameter_; + int head_num_; + int head_size_; + bool cross_; +} AttentionParameter; + +typedef struct RelativePositionAttentionParameter { + // Primitive parameter + OpParameter op_parameter_; + // multi-head-attention args + int num_heads_; // number of heads of multi-head-attention + int k_seq_; // length of sequence of key of attention + int v_seq_; // length of sequence of value of attention + bool use_bias_; // if matmul in attention has bias + // relative-position-attention args + int p_seq_; // length of sequence of position of attention + // args for compute + int batch_; // batch of query/key/value/position + int d_model_; // d_model of multi-head-attention + int q_seq_; // length of sequence of query of attention + int row_tile_; // row tile for matrix pack + int col_tile_; // col tile for matrix pack + int bias_tile_; // tile for bias pack +} RelativePositionAttentionParameter; + +#endif // NNACL_ATTENTION_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/arithmetic_base.c b/mindspore-lite/ops/kernel/cpu/nnacl/base/arithmetic_base.c new file mode 100644 index 0000000000000000000000000000000000000000..b6e12ca3fb66260b6f32ab48e4c0af021accb015 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/arithmetic_base.c @@ -0,0 +1,48 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/base/arithmetic_base.h" +#include "nnacl/kernel/arithmetic.h" + +void CalcMultiplesAndStrides(ArithmeticParameter *param) { + for (size_t i = 0; i < param->ndim_; i++) { + if (param->in_shape0_[i] != 0) { + param->multiples0_[i] = param->out_shape_[i] / param->in_shape0_[i]; + } + if (param->in_shape1_[i] != 0) { + param->multiples1_[i] = param->out_shape_[i] / param->in_shape1_[i]; + } + } + // cal strides + ComputeStrides(param->in_shape0_, param->in_strides0_, param->ndim_); + ComputeStrides(param->in_shape1_, param->in_strides1_, param->ndim_); + ComputeStrides(param->out_shape_, param->out_strides_, param->ndim_); +} + +void CalcStructMultiplesAndStrides(ArithmeticStruct *arithmetic) { + for (size_t i = 0; i < arithmetic->ndim_; i++) { + if (arithmetic->in_shape0_[i] != 0) { + arithmetic->multiples0_[i] = arithmetic->out_shape_[i] / arithmetic->in_shape0_[i]; + } + if (arithmetic->in_shape1_[i] != 0) { + arithmetic->multiples1_[i] = arithmetic->out_shape_[i] / arithmetic->in_shape1_[i]; + } + } + // cal strides + ComputeStrides(arithmetic->in_shape0_, arithmetic->in_strides0_, arithmetic->ndim_); + ComputeStrides(arithmetic->in_shape1_, arithmetic->in_strides1_, arithmetic->ndim_); + ComputeStrides(arithmetic->out_shape_, arithmetic->out_strides_, arithmetic->ndim_); +} diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/model_converter_utils/shared_memory.h b/mindspore-lite/ops/kernel/cpu/nnacl/base/arithmetic_base.h similarity index 53% rename from mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/model_converter_utils/shared_memory.h rename to mindspore-lite/ops/kernel/cpu/nnacl/base/arithmetic_base.h index f4ec14a776014c5598c9d0d9fe2b4ba1b7f6e40d..67c90812a5c3d14dcbd3317fa48408c29c96d96b 100644 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/model_converter_utils/shared_memory.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/arithmetic_base.h @@ -14,25 +14,23 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_CXXAPI_SHARED_MEMORY_H -#define MINDSPORE_CCSRC_CXXAPI_SHARED_MEMORY_H -#include -#include "include/api/status.h" +#ifndef NNACL_BASE_ARITHMETIC_BASE_H_ +#define NNACL_BASE_ARITHMETIC_BASE_H_ -namespace mindspore { -class SharedMemory { - public: - Status Create(uint64_t memory_size); - Status Attach(); - void Detach(); - void Destroy() const; +#include "nnacl/arithmetic_parameter.h" +#include "nnacl/nnacl_utils.h" +#include "nnacl/nnacl_common.h" +#include "nnacl/kernel/arithmetic.h" - private: - friend class MultiProcess; - uint8_t *GetSharedMemoryAddr() { return shmat_addr_; } +#ifdef __cplusplus +extern "C" { +#endif - int shm_id_ = -1; - uint8_t *shmat_addr_ = nullptr; -}; -} // namespace mindspore -#endif // MINDSPORE_CCSRC_CXXAPI_SHARED_MEMORY_H +void CalcMultiplesAndStrides(ArithmeticParameter *param); +void CalcStructMultiplesAndStrides(ArithmeticStruct *arithmetic); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_BASE_ARITHMETIC_BASE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/batch_to_space_base.c b/mindspore-lite/ops/kernel/cpu/nnacl/base/batch_to_space_base.c new file mode 100644 index 0000000000000000000000000000000000000000..d8900df0b44df1e17daf0aea8852e1f409e2f563 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/batch_to_space_base.c @@ -0,0 +1,95 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/base/batch_to_space_base.h" + +void BatchToSpaceNoCropForNHWC(const void *input, void *output, const int *in_shape, int out_n, const int *block, + int data_size) { + int block_h = block[0]; + int block_w = block[1]; + int in_h = in_shape[1]; + int in_w = in_shape[2]; + int in_c = in_shape[3]; + int stride_h = block_w * out_n; + int output_offset = 0; + int copy_size = in_c * data_size; + int in_stride_h = in_w * in_c; + int in_stride_n = in_stride_h * in_h; + for (int n = 0; n < out_n; ++n) { + for (int h = 0; h < in_h; ++h) { + int h_offset = h * in_stride_h; + for (int bh = 0; bh < block_h; ++bh) { + for (int w = 0; w < in_w; ++w) { + int w_offset = w * in_c; + for (int bw = 0; bw < block_w; ++bw) { + int in_offset = in_stride_n * (bh * stride_h + bw * out_n + n) + w_offset + h_offset; + memcpy((int8_t *)output + output_offset, (int8_t *)input + in_offset * data_size, copy_size); + output_offset += copy_size; + } + } + } + } + } +} + +void BatchToSpaceForNHWC(const void *input, void *output, const int *in_shape, int out_n, const int *block, + const int *crops, int data_size) { + int block_h = block[0]; + int block_w = block[1]; + if (block_h == 0 || block_w == 0) { + return; + } + int in_h = in_shape[1]; + int in_w = in_shape[2]; + int in_c = in_shape[3]; + int h_start = crops[0] / block_h; + int h_valid_begin = crops[0]; + int h_end = MSMIN((in_h * block_h - crops[1]) / block_h + 1, in_h); + int h_valid_end = in_h * block_h - crops[1] - 1; + int w_start = crops[2] / block_w; + int w_valid_begin = crops[2]; + int w_end = MSMIN((in_w * block_w - crops[3]) / block_w + 1, in_w); + int w_valid_end = in_w * block_w - crops[3] - 1; + + int stride_h = block_w * out_n; + int output_offset = 0; + int copy_size = in_c * data_size; + int in_stride_h = in_w * in_c; + int in_stride_n = in_stride_h * in_h; + for (int n = 0; n < out_n; ++n) { + for (int h = h_start; h < h_end; ++h) { + int h_offset = h * in_stride_h; + for (int bh = 0; bh < block_h; ++bh) { + int h_index = h * block_h + bh; + if (h_index < h_valid_begin || h_index > h_valid_end) { + continue; + } + for (int w = w_start; w < w_end; ++w) { + int w_offset = w * in_c; + for (int bw = 0; bw < block_w; ++bw) { + int w_index = w * block_w + bw; + if (w_index < w_valid_begin || w_index > w_valid_end) { + continue; + } + int in_offset = in_stride_n * (bh * stride_h + bw * out_n + n) + w_offset + h_offset; + memcpy((int8_t *)output + output_offset, (int8_t *)input + in_offset * data_size, copy_size); + output_offset += copy_size; + } + } + } + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/batch_to_space_base.h b/mindspore-lite/ops/kernel/cpu/nnacl/base/batch_to_space_base.h new file mode 100644 index 0000000000000000000000000000000000000000..bf4553ce1fe7e4c867f88b25e22e3898b9029e0a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/batch_to_space_base.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_BATCH_TO_SPACE_BASE_H_ +#define NNACL_BATCH_TO_SPACE_BASE_H_ + +#include +#include "nnacl/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +void BatchToSpaceNoCropForNHWC(const void *input, void *output, const int *in_shape, int out_n, const int *block, + int data_size); +void BatchToSpaceForNHWC(const void *input, void *output, const int *in_shape, int out_n, const int *block, + const int *crops, int data_size); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_BATCH_TO_SPACE_BASE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/broadcast_to.c b/mindspore-lite/ops/kernel/cpu/nnacl/base/broadcast_to.c new file mode 100644 index 0000000000000000000000000000000000000000..048bcb78436d3ccba98ce668c140c2d76dd09c5c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/broadcast_to.c @@ -0,0 +1,106 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/base/broadcast_to.h" +#include +#include "nnacl/op_base.h" +#include "nnacl/errorcode.h" + +size_t accumulate(const int *shape, int start, int end) { + size_t product = 1; + for (int i = start; i <= end; ++i) { + product *= (size_t)shape[i]; + } + return product; +} + +void pad_input_shape(int *input_shape, int input_shape_len, int output_shape_len) { + if (input_shape_len < output_shape_len) { + const int shape_gap = output_shape_len - input_shape_len; + for (int i = input_shape_len - 1; i >= 0; --i) { + input_shape[i + shape_gap] = input_shape[i]; + } + for (int i = 0; i < shape_gap; ++i) { + input_shape[i] = 1; + } + } +} + +#define BROADCAST_TO_SIZE_IMPL(data_size) \ + int BroadcastToSize##data_size(const void *input, BroadcastShapeInfo *shape_info, void *output) { \ + if (input == NULL || output == NULL) { \ + return NNACL_NULL_PTR; \ + } \ + if (shape_info->output_shape_size_ > MAX_SHAPE_SIZE) { \ + return NNACL_ERR; \ + } \ + int *input_shape = shape_info->input_shape_; \ + const int *output_shape = shape_info->output_shape_; \ + const int dim_max = shape_info->output_shape_size_ - 1; \ + const size_t temp_length = accumulate(output_shape, 0, dim_max); \ + const size_t data_len = data_size / BYTE_SIZE; \ + if (temp_length * data_len == 0) { \ + return NNACL_ERR; \ + } \ + int8_t *data_temp = (int8_t *)malloc(temp_length * data_len); \ + if (data_temp == NULL) { \ + return NNACL_ERR; \ + } \ + pad_input_shape(input_shape, shape_info->input_shape_size_, dim_max + 1); \ + shape_info->input_shape_size_ = dim_max + 1; \ + \ + size_t before_dim_elements_num = accumulate(input_shape, 0, dim_max - 1); \ + size_t after_dim_elements_num = (size_t)(input_shape[dim_max]); \ + size_t dim_broadcast_rate = (size_t)(output_shape[dim_max] / input_shape[dim_max]); \ + for (size_t i = 0; i < before_dim_elements_num; ++i) { \ + const int8_t *in_ptr = (const int8_t *)input + i * after_dim_elements_num * data_len; \ + for (size_t j = 0; j < dim_broadcast_rate; ++j) { \ + int8_t *out_ptr = (int8_t *)output + (i * dim_broadcast_rate + j) * after_dim_elements_num * data_len; \ + memcpy(out_ptr, in_ptr, after_dim_elements_num *data_len); \ + } \ + } \ + \ + int dim_index = dim_max - 1; \ + while (dim_index >= 0) { \ + if (input_shape[dim_index] == 0) { \ + free(data_temp); \ + return NNACL_ERR; \ + } \ + dim_broadcast_rate = (size_t)(output_shape[dim_index] / input_shape[dim_index]); \ + if (dim_broadcast_rate > 1) { \ + before_dim_elements_num = accumulate(input_shape, 0, dim_index - 1); \ + after_dim_elements_num = accumulate(output_shape, dim_index + 1, dim_max); \ + for (size_t i = 0; i < before_dim_elements_num; ++i) { \ + int8_t *in_ptr = (int8_t *)output + i * after_dim_elements_num * data_len; \ + for (size_t j = 0; j < dim_broadcast_rate; ++j) { \ + int8_t *out_ptr = data_temp + (i * dim_broadcast_rate + j) * after_dim_elements_num * data_len; \ + memcpy(out_ptr, in_ptr, after_dim_elements_num *data_len); \ + } \ + } \ + size_t elements_total = before_dim_elements_num * dim_broadcast_rate * after_dim_elements_num; \ + memcpy(output, data_temp, elements_total *data_len); \ + } \ + --dim_index; \ + } \ + free(data_temp); \ + return NNACL_OK; \ + } + +BROADCAST_TO_SIZE_IMPL(8) +BROADCAST_TO_SIZE_IMPL(16) +BROADCAST_TO_SIZE_IMPL(32) +BROADCAST_TO_SIZE_IMPL(64) +BROADCAST_TO_SIZE_IMPL(128) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/broadcast_to.h b/mindspore-lite/ops/kernel/cpu/nnacl/base/broadcast_to.h new file mode 100644 index 0000000000000000000000000000000000000000..7cb2353bf15a41ed22b718dcab88f7a6f4d6ccc6 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/broadcast_to.h @@ -0,0 +1,34 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP32_BROADCAST_TO_H_ +#define NNACL_FP32_BROADCAST_TO_H_ + +#include "nnacl/broadcast_to_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +#define BYTE_SIZE 8 +int BroadcastToSize8(const void *input, BroadcastShapeInfo *shape_info, void *output); +int BroadcastToSize16(const void *input, BroadcastShapeInfo *shape_info, void *output); +int BroadcastToSize32(const void *input, BroadcastShapeInfo *shape_info, void *output); +int BroadcastToSize64(const void *input, BroadcastShapeInfo *shape_info, void *output); +int BroadcastToSize128(const void *input, BroadcastShapeInfo *shape_info, void *output); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_BROADCAST_TO_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/cast_base.c b/mindspore-lite/ops/kernel/cpu/nnacl/base/cast_base.c new file mode 100644 index 0000000000000000000000000000000000000000..b9327cc1d4ae8d8a7749c8d283bf5b0656535536 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/cast_base.c @@ -0,0 +1,199 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/base/cast_base.h" +#include "nnacl/cast_base_simd.h" + +typedef union float32_bits { + unsigned int u; + float f; +} float32_bits; + +uint16_t Float32ToFloat16_(float f) { + float32_bits hbit; + hbit.f = f; + uint16_t hbits = 0; + // Extract the sign bit + uint16_t sign = (hbit.u >> FP16_BIT_SIZE) & 0x8000; // Get the sign (1 bit) ox8000 + // Extract the exponent + uint32_t exponent = (hbit.u >> FP32_SIGNIFICAND) & 0xFF; // Extract the exponent (8 bits) 0xFF + // Handle special cases (NaN, Inf, 0) + if (exponent == 0xFF) { // NaN or Infinity 0xFF + hbits |= sign | 0x7FFF; // Set to max float16 value (Infinity) + return hbits; + } else if (exponent == 0) { // Zero or denormalized number + // In float16, we treat zero the same way + hbits |= sign; // Preserve sign for zero + return hbits; + } + // Adjust the exponent to fit float16 + exponent -= FP32_EXPONENT_BIAS; // Remove float32 bias + exponent += FP16_EXPONENT_BIAS; // Add float16 bias + // Check for overflow + if (exponent >= 0x1F) { // 0X1F + hbits |= sign | 0x7FFF; // Set to max float16 value (Infinity) 0x7FFF + return hbits; + } + if (exponent == 0) { + // Handle underflow (too small to represent) + return sign; // Return zero with the correct sign + } + // Shift the mantissa: + // Extract the mantissa (23 bits), shift right by 13 (10-exp) + uint32_t mantissa = (hbit.u & 0x7FFFFF) >> FP16_SHIFT; // 0x7FFFFF + // Combine sign, exponent, and mantissa into hbits + hbits |= + sign | ((uint16_t)exponent << FP16_SIGNIFICAND) | (mantissa & 0x3FF); // combine sign exponent and mantissa 0x3FF + return hbits; +} + +void Int32ToFloat32(const int32_t *input, float *output, int number) { + int index = 0; + + SIMD_RUN_NO_SCALAR(Int32ToFloat32, index, input, output, number); + + for (; index < number; ++index) { + output[index] = (float)input[index]; + } +} + +void Float32ToInt32(const float *input, int32_t *output, int number) { + int index = 0; + + SIMD_RUN_X86_NO_SCALAR(Float32ToInt32, index, input, output, number); + + for (; index < number; ++index) { + output[index] = (int32_t)input[index]; + } +} + +void BoolToFloat32(const bool *input, float *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float)input[i]; + } +} + +void Uint8ToFloat32(const uint8_t *input, float *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float)input[i]; + } +} + +void Int32ToFloat32(const int32_t *input, float *output, int number); + +void Int64ToFloat32(const int64_t *input, float *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float)input[i]; + } +} + +#ifdef ENABLE_FP16 +void Int64ToFp16(const int64_t *input, float16_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float16_t)input[i]; + } +} + +void Int32ToFp16(const int32_t *input, float16_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float16_t)input[i]; + } +} + +void BoolToFp16(const bool *input, float16_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float16_t)input[i]; + } +} + +void Uint8ToFp16(const uint8_t *input, float16_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float16_t)input[i]; + } +} + +void Float32ToFp16(const float *input, float16_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float16_t)(input[i]); + } +} + +void Fp16ToFloat32(const float16_t *input, float *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float)(input[i]); + } +} +#else +void Fp16ToFloat32(const uint16_t *input, float *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = ShortToFloat32(input[i]); + } +} + +void Float32ToFp16(const float *input, uint16_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = Float32ToFloat16_(input[i]); + } +} +#endif + +void Float32ToInt32(const float *input, int32_t *output, int number); + +void Float32ToInt64(const float *input, int64_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (int64_t)input[i]; + } +} + +void Int32ToInt64(const int32_t *input, int64_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (int64_t)input[i]; + } +} + +void Int64ToInt32(const int64_t *input, int32_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (int32_t)input[i]; + } +} + +void Float32ToInt16(const float *input, int16_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (int16_t)input[i]; + } +} + +void BoolToInt32(const bool *input, int32_t *output, int number) { + for (int i = 0; i < number; ++i) { + if (input[i]) { + output[i] = 1; + } else { + output[i] = 0; + } + } +} + +void Float32ToBool(const float *input, bool *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (bool)input[i]; + } +} + +void Float32ToUint8(const float *input, uint8_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (uint8_t)input[i]; + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/cast_base.h b/mindspore-lite/ops/kernel/cpu/nnacl/base/cast_base.h new file mode 100644 index 0000000000000000000000000000000000000000..6ebaf1a1b3419f5148f4e01c51740a275e633d6a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/cast_base.h @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_BASE_CAST_BASE_H_ +#define NNACL_BASE_CAST_BASE_H_ + +#include "nnacl/op_base.h" +#include "nnacl/nnacl_common.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void BoolToFloat32(const bool *input, float *output, int number); + +void Uint8ToFloat32(const uint8_t *input, float *output, int number); + +void Int32ToFloat32(const int32_t *input, float *output, int number); + +void Int64ToFloat32(const int64_t *input, float *output, int number); + +#ifdef ENABLE_FP16 +void Int64ToFp16(const int64_t *input, float16_t *output, int number); + +void Int32ToFp16(const int32_t *input, float16_t *output, int number); + +void BoolToFp16(const bool *input, float16_t *output, int number); + +void Uint8ToFp16(const uint8_t *input, float16_t *output, int number); + +void Float32ToFp16(const float *input, float16_t *output, int number); + +void Fp16ToFloat32(const float16_t *input, float *output, int number); +#else +void Fp16ToFloat32(const uint16_t *input, float *output, int number); + +void Float32ToFp16(const float *input, uint16_t *output, int number); +#endif + +uint16_t Float32ToFloat16_(float f); + +void Float32ToInt32(const float *input, int32_t *output, int number); + +void Float32ToInt64(const float *input, int64_t *output, int number); + +void Int32ToInt64(const int32_t *input, int64_t *output, int number); + +void Int64ToInt32(const int64_t *input, int32_t *output, int number); + +void Float32ToInt16(const float *input, int16_t *output, int number); + +void BoolToInt32(const bool *input, int32_t *output, int number); + +void Float32ToBool(const float *input, bool *output, int number); + +void Float32ToUint8(const float *input, uint8_t *output, int number); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_BASE_CAST_BASE_H_ diff --git a/mindspore-lite/src/extendrt/kernel/cuda/unique.cc b/mindspore-lite/ops/kernel/cpu/nnacl/base/cast_base_simd.h.in similarity index 37% rename from mindspore-lite/src/extendrt/kernel/cuda/unique.cc rename to mindspore-lite/ops/kernel/cpu/nnacl/base/cast_base_simd.h.in index 742c7490b09996b336cabc3e4af56e09ab50476b..427bd3f1aab2a83be0e8fc20726ccc571c7c805a 100644 --- a/mindspore-lite/src/extendrt/kernel/cuda/unique.cc +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/cast_base_simd.h.in @@ -13,35 +13,37 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#ifndef NNACL_BASE_CAST_BASE_@SIMD_INSTRUCTION@_H_ +#define NNACL_BASE_CAST_BASE_@SIMD_INSTRUCTION@_H_ -#include "src/extendrt/kernel/cuda/unique.h" -#include -#include +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" -namespace mindspore::kernel { -int UniqueCudaKernel::Prepare() { - CudaKernel::Prepare(); - if (unique_helper_ == nullptr) { - unique_helper_ = std::make_shared>(type_name_); - helper_ = unique_helper_; +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int Int32ToFloat32@SIMD_INSTRUCTION@(int index, const int32_t *input, float *output, int number) { + for (int block_max_size = number - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 value = SIMD_LD_EPI32(input + index); + SIMD_ST_F32(output + index, SIMD_EPI32_TO_F32(value)); } - return RET_OK; + return index; } -int UniqueCudaKernel::PostProcess() { - auto ret = CudaKernel::PostProcess(); - CHECK_NOT_EQUAL_RETURN(ret, RET_OK); - // set output tensor shape - std::vector out_shape = out_tensors_[0]->shape(); - out_shape[out_shape.size() - 1] = unique_helper_->GetOutSize(); - out_tensors_[0]->set_shape(out_shape); - return RET_OK; +#ifndef MS_SIMD_NEON +static inline int Float32ToInt32@SIMD_INSTRUCTION@(int index, const float *input, int32_t *output, int number) { + for (int block_max_size = number - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 value = SIMD_LD_F32(input + index); + SIMD_ST_EPI32(output + index, SIMD_F32_TO_EPI32(value)); + } + return index; } +#endif -int UniqueCudaKernel::Run() { - int ret = unique_helper_->Process(input_device_ptrs_, output_device_ptrs_, work_device_ptrs_, stream_); - CHECK_NOT_EQUAL_RETURN(ret, RET_OK); - return RET_OK; +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus } -// REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Unique, CudaKernelCreator) -} // namespace mindspore::kernel +#endif +#endif diff --git a/mindspore-lite/src/common/draw/graphviz_graph_builder.h b/mindspore-lite/ops/kernel/cpu/nnacl/base/concat_base.c similarity index 30% rename from mindspore-lite/src/common/draw/graphviz_graph_builder.h rename to mindspore-lite/ops/kernel/cpu/nnacl/base/concat_base.c index 68f3a6a522fea357466803e0281ca52c58e15474..636afd8a3f7414ebfe254ad689b73ddcb9ac3f2b 100644 --- a/mindspore-lite/src/common/draw/graphviz_graph_builder.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/concat_base.c @@ -1,5 +1,5 @@ /** - * Copyright 2023 Huawei Technologies Co., Ltd + * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,40 +14,41 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_COMMON_DRAW_GRAPHVIZ_GRAPH_BUILDER_H_ -#define MINDSPORE_LITE_SRC_COMMON_DRAW_GRAPHVIZ_GRAPH_BUILDER_H_ +#include "nnacl/base/concat_base.h" -#include -#include -#include -#include -#include -#include "src/common/log_adapter.h" -#include "src/common/draw/graphviz_graph.h" -#include "src/common/draw/adapter_graph.h" -#include "src/tensor.h" -#include "include/errorcode.h" +void Concat(void **input, int input_num, int axis, int **inputs_output_shape, size_t shape_size, void *output, + int task_id, int thread_num, int data_size) { + int before_axis_size = 1; + for (int i = 0; i < axis; ++i) { + before_axis_size *= inputs_output_shape[0][i]; + } -namespace mindspore::lite { -class GVGraphBuilder { - public: - std::shared_ptr Build(const std::shared_ptr &graph); - - void AppendGraphInputNode(const lite::Tensor &tensor); - void AppendWeightNode(const lite::Tensor &tensor, const std::string &id, const std::string &label); - int AppendComputeNode(const AdapterNode &node); - int AppendGraphOutputNode(const std::vector &out_tensors); - - protected: - static GVNode *CreateComputeNode(const AdapterNode &node); - int LinkNodes(const AdapterNode &node, const GVNode &gv_node); - void AppendOutTensorMap(const lite::Tensor *tensor, lite::GVNode *node, size_t out_index); - std::pair GetBelongingGVNode(const lite::Tensor *tensor) const; - - std::shared_ptr gv_graph_{nullptr}; - std::string name_; - std::unordered_map> gv_node_out_tensor_map_; -}; -} // namespace mindspore::lite - -#endif // MINDSPORE_LITE_SRC_COMMON_DRAW_GRAPHVIZ_GRAPH_BUILDER_H_ + int after_axis_size = data_size; + for (size_t i = (size_t)(axis) + 1; i < shape_size; ++i) { + after_axis_size *= inputs_output_shape[0][i]; + } + int axis_offset = 0; + uint8_t *dst_base = (output); + int output_stride = after_axis_size * inputs_output_shape[input_num][axis]; + for (int i = 0; i < input_num; ++i) { + const uint8_t *src_base = (input[i]); + if (inputs_output_shape[i] == NULL) { + continue; + } + int input_stride = after_axis_size * inputs_output_shape[i][axis]; + NNACL_CHECK_ZERO_RETURN(thread_num); + int offset = UP_DIV(input_stride, thread_num); + int count = input_stride - offset * task_id; + if (count <= 0) { + axis_offset += inputs_output_shape[i][axis]; + continue; + } + count = MSMIN(offset, count); + for (int j = 0; j < before_axis_size; j++) { + const uint8_t *src = src_base + j * input_stride + task_id * offset; + uint8_t *dst = dst_base + j * output_stride + axis_offset * after_axis_size + task_id * offset; + memcpy(dst, src, count); + } + axis_offset += inputs_output_shape[i][axis]; + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/concat_base.h b/mindspore-lite/ops/kernel/cpu/nnacl/base/concat_base.h new file mode 100644 index 0000000000000000000000000000000000000000..232862db3008a37dd6b8708dc7d8e33bf03024fe --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/concat_base.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_CONCAT_BASE_H_ +#define NNACL_FP32_CONCAT_BASE_H_ + +#include +#include "nnacl/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +void Concat(void **input, int input_num, int axis, int **inputs_output_shape, size_t shape_size, void *output, + int task_id, int thread_num, int data_size); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_CONCAT_BASE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/conv1x1_base.c b/mindspore-lite/ops/kernel/cpu/nnacl/base/conv1x1_base.c new file mode 100644 index 0000000000000000000000000000000000000000..7898e73509717a362e413ecc7b2969e0251db3b6 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/conv1x1_base.c @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/base/conv1x1_base.h" + +void Conv1x1InputPack(const void *src_ptr, void *dst_ptr, ConvParameter *conv_param, int data_size) { + /* support nhwc */ + char *src = (char *)src_ptr; + char *dst = (char *)dst_ptr; + for (int dst_h = 0; dst_h < conv_param->output_h_; dst_h++) { + int src_h = dst_h * conv_param->stride_h_ - conv_param->pad_u_; + if (src_h < 0 || src_h >= conv_param->input_h_) { + continue; + } + const char *src_h_ptr = src + src_h * conv_param->input_w_ * conv_param->input_channel_ * data_size; + char *dst_h_ptr = dst + dst_h * conv_param->output_w_ * conv_param->input_channel_ * data_size; + for (int dst_w = 0; dst_w < conv_param->output_w_; dst_w++) { + int src_w = dst_w * conv_param->stride_w_ - conv_param->pad_l_; + if (src_w < 0 || src_w >= conv_param->input_w_) { + continue; + } + memcpy(dst_h_ptr + dst_w * conv_param->input_channel_ * data_size, + src_h_ptr + src_w * conv_param->input_channel_ * data_size, conv_param->input_channel_ * data_size); + } + } + return; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/conv1x1_base.h b/mindspore-lite/ops/kernel/cpu/nnacl/base/conv1x1_base.h new file mode 100644 index 0000000000000000000000000000000000000000..bf9ffc02f3a21e67d271feb2586f0516a3f8cb3c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/conv1x1_base.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_BASE_CONV1X1_BASE_H_ +#define NNACL_BASE_CONV1X1_BASE_H_ + +#include "nnacl/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void Conv1x1InputPack(const void *src_ptr, void *dst_ptr, ConvParameter *conv_param, int data_size); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_BASE_CONV1X1_BASE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/conv_common_base.c b/mindspore-lite/ops/kernel/cpu/nnacl/base/conv_common_base.c new file mode 100644 index 0000000000000000000000000000000000000000..c326423efda30ddf112812e694e2de9a9f2960a3 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/conv_common_base.c @@ -0,0 +1,128 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/base/conv_common_base.h" +#include "nnacl/errorcode.h" + +#define MIN_UNIT 2 +#define MAX_UNIT 8 + +#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX)) +bool CheckConvDw1DWinograd(const ConvParameter *conv_param, int thread_num) { + return conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 && conv_param->stride_w_ == 1 && + conv_param->stride_h_ == 1 && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 && + conv_param->pad_u_ == 1 && conv_param->pad_d_ == 1 && conv_param->pad_l_ == 1 && conv_param->pad_r_ == 1 && + conv_param->input_channel_ == conv_param->output_channel_ && conv_param->output_w_ >= 4 && + conv_param->output_h_ >= thread_num * 4; // better had more than 4 rows for each thread +} +#endif + +bool CheckWinogradInputOutputUnit(int input_unit, int output_unit) { + if (input_unit != 4 && input_unit != 6 && input_unit != 8) { + return false; + } + if ((output_unit >= input_unit) || (output_unit < 2)) { + return false; + } + return true; +} + +// Reference to the paper "Fast Algorithms for Convolutional Neural Networks" +// Utilize cost model to compute performance gain. +// If the gain is greater than got from Im2col, winograd algorithm will be chosen. +int SelectOutputUnit(const ConvParameter *conv_param) { + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int in_c = conv_param->input_channel_; + int out_w = conv_param->output_w_; + int out_h = conv_param->output_h_; + int out_c = conv_param->output_channel_; + if (conv_param->op_parameter_.thread_num_ == 0) { + return NNACL_PARAM_INVALID; + } + int unit2 = UP_DIV(out_w * out_h, C12NUM * conv_param->op_parameter_.thread_num_); + int max_out_unit = (int)(sqrtf((float)unit2)); + max_out_unit = max_out_unit < MAX_UNIT ? max_out_unit : MAX_UNIT; + max_out_unit = max_out_unit > MIN_UNIT ? max_out_unit : MIN_UNIT; + + int unit = 0; + float max_rate = 0.0f; + float common_cost = (float)out_h * out_w * in_c * out_c * kernel_h * kernel_w; + + for (int i = MIN_UNIT; i <= max_out_unit; ++i) { + int input_unit = i + kernel_w - 1; + if (!CheckWinogradInputOutputUnit(input_unit, i)) { + continue; + } + float penalty = ((float)input_unit * input_unit) / ((float)kernel_h * kernel_w) * 0.12f; + float wino_cost = ((2 + out_c) * (float)input_unit * input_unit * in_c + ((float)input_unit + i) * i * out_c) * + UP_DIV(out_w, i) * UP_DIV(out_h, i); + float reduce_rate = common_cost / wino_cost - penalty; + if (reduce_rate > max_rate) { + max_rate = reduce_rate; + unit = i; + } + } + if (max_rate < 1.0f) { + return 1; + } + // If output_unit is 1, then it is conventional convolution + return unit; +} + +bool CheckIfUseWinograd(int *output_unit, const ConvParameter *conv_param) { + if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) { + return false; + } + if (conv_param->kernel_w_ == conv_param->kernel_h_ && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 && + conv_param->stride_h_ == 1 && conv_param->stride_w_ == 1 && conv_param->input_channel_ != 1) { + *output_unit = SelectOutputUnit(conv_param); + if (*output_unit > 1) { + return true; + } + } + return false; +} + +bool CheckAvxUseSW1x1Conv(const ConvParameter *conv_param) { + if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) { + if (conv_param->pad_d_ == 0 && conv_param->pad_l_ == 0 && conv_param->pad_r_ == 0 && conv_param->pad_u_ == 0 && + conv_param->stride_h_ == 1 && conv_param->stride_w_ == 1) { + return true; + } + } + return false; +} + +bool CheckAvxUseSWConv(const ConvParameter *conv_param, int thread_nr_) { + if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) { + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_param->input_w_, conv_param->input_h_, false); + if (conv_param->pad_d_ == 0 && conv_param->pad_l_ == 0 && conv_param->pad_r_ == 0 && conv_param->pad_u_ == 0 && + conv_param->stride_h_ == 1 && conv_param->stride_w_ == 1 && conv_param->input_channel_ % C8NUM == 0 && + (conv_param->input_w_ * conv_param->input_h_ >= thread_nr_)) { + return true; + } + } else { + if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ > C128NUM) { // conv1d kernel + return false; + } else if (conv_param->input_channel_ / conv_param->op_parameter_.thread_num_ <= C16NUM && + conv_param->input_h_ >= thread_nr_ && + (conv_param->kernel_h_ < C7NUM || conv_param->input_h_ / conv_param->kernel_h_ >= C4NUM) && + (conv_param->kernel_w_ < C7NUM || conv_param->input_w_ / conv_param->kernel_w_ >= C4NUM)) { + return true; + } + } + return false; +} diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/graph.cc b/mindspore-lite/ops/kernel/cpu/nnacl/base/conv_common_base.h similarity index 46% rename from mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/graph.cc rename to mindspore-lite/ops/kernel/cpu/nnacl/base/conv_common_base.h index d67d24d73a3f75fad966e4f649a1187b83ac76a6..73565aba63f00130dac2b60160b29432e856d1cd 100644 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/graph.cc +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/conv_common_base.h @@ -13,27 +13,29 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "include/api/graph.h" -#include "cxx_api/graph/graph_data.h" -#include "src/common/log_adapter.h" -namespace mindspore { -Graph::Graph() : graph_data_(nullptr) {} +#ifndef NNACL_BASE_CONV_DEPTHWISE_BASE_H_ +#define NNACL_BASE_CONV_DEPTHWISE_BASE_H_ -Graph::Graph(const std::shared_ptr &graph_data) : graph_data_(graph_data) {} +#include "nnacl/conv_parameter.h" -Graph::Graph(std::shared_ptr &&graph_data) : graph_data_(graph_data) {} +bool CheckAvxUseSW1x1Conv(const ConvParameter *conv_param); +bool CheckAvxUseSWConv(const ConvParameter *conv_param, int thread_nr_); -Graph::~Graph() {} +#ifdef __cplusplus +extern "C" { +#endif -Graph::Graph(std::nullptr_t) : graph_data_(nullptr) {} +#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX)) +bool CheckConvDw1DWinograd(const ConvParameter *conv_param, int thread_num); +#endif -bool Graph::operator==(std::nullptr_t) const { return graph_data_ == nullptr; } +bool CheckWinogradInputOutputUnit(int input_unit, int output_unit); -bool Graph::operator!=(std::nullptr_t) const { return graph_data_ != nullptr; } +bool CheckIfUseWinograd(int *output_unit, const ConvParameter *conv_param); -ModelType Graph::ModelType() const { - MS_EXCEPTION_IF_NULL(graph_data_); - return graph_data_->ModelType(); +#ifdef __cplusplus } -} // namespace mindspore +#endif + +#endif // NNACL_BASE_CONV_DEPTHWISE_BASE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/crop_base.c b/mindspore-lite/ops/kernel/cpu/nnacl/base/crop_base.c new file mode 100644 index 0000000000000000000000000000000000000000..6d120d08464f4071473117fa5e4f65404f3ca680 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/crop_base.c @@ -0,0 +1,40 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/base/crop_base.h" +#include "nnacl/errorcode.h" + +int CropPadOffset(int input_dim, CropParameter *crop_para, int64_t *in_offset) { + int64_t axis = crop_para->axis_; + int offsets_size = crop_para->offset_size_; + if (offsets_size > 1) { + NNACL_CHECK_TRUE_RET(axis + offsets_size == input_dim, NNACL_ERR); + } + for (int i = 0; i < input_dim; i++) { + int crop_offset = 0; + if (i >= axis) { + if (offsets_size == 1) { + crop_offset = crop_para->offset_[0]; + } else if (offsets_size > 1) { + if (i - axis < CROP_OFFSET_MAX_SIZE) { + crop_offset = crop_para->offset_[i - axis]; + } + } + } + in_offset[i] = crop_offset; + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_registration_factory.cc b/mindspore-lite/ops/kernel/cpu/nnacl/base/crop_base.h similarity index 62% rename from mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_registration_factory.cc rename to mindspore-lite/ops/kernel/cpu/nnacl/base/crop_base.h index 944c378b93851879bb7eeb8aa6c9ba47df76225b..9ef1c652d33b75cb32dfba4e2addf1770b4c1fb1 100644 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_registration_factory.cc +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/crop_base.h @@ -13,13 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "extendrt/delegate/ascend_native/ascend_native_registration_factory.h" -#include "extendrt/delegate/ascend_native/ascend_native_kernel_registry.h" -namespace mindspore::kernel { -template <> -AscendNativeRegistrationFactory &AscendNativeRegistrationFactory::Get() { - static AscendNativeRegistrationFactory obj; - return obj; +#ifndef NNACL_BASE_CROP_BASE_H_ +#define NNACL_BASE_CROP_BASE_H_ + +#include "nnacl/op_base.h" +#include "nnacl/crop_parameter.h" + +#define CROP_OFFSET_MAX_SIZE 4 + +#ifdef __cplusplus +extern "C" { +#endif + +int CropPadOffset(int input_dim, CropParameter *crop_para, int64_t *in_offset); + +#ifdef __cplusplus } -} // namespace mindspore::kernel +#endif + +#endif // NNACL_BASE_CROP_BASE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/depth_to_space_base.c b/mindspore-lite/ops/kernel/cpu/nnacl/base/depth_to_space_base.c new file mode 100644 index 0000000000000000000000000000000000000000..57bcd79e00b7d20c10cdedcd9034efc141eab3ec --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/depth_to_space_base.c @@ -0,0 +1,72 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/base/depth_to_space_base.h" +#include "nnacl/errorcode.h" + +void DepthToSpaceForNHWC(const void *input, void *output, const int *in_shape, const DepthToSpaceArgs *param) { + int32_t block_size = param->block_size_; + int32_t in_shape_dim2 = in_shape[2]; + int32_t in_shape_dim1 = in_shape[1]; + size_t copy_size = (size_t)block_size * param->out_stride_dim2_ * param->data_type_size_; + for (int i = 0; i < in_shape[0]; ++i) { + int64_t in_offset_n = i * param->in_stride_dim0_; + int64_t out_offset_n = i * param->out_stride_dim0_; + for (int j = 0; j < in_shape_dim1; ++j) { + int64_t in_offset_h = in_offset_n + j * param->in_stride_dim1_; + int64_t out_offset_h = out_offset_n + j * block_size * param->out_stride_dim1_; + for (int k = 0; k < in_shape_dim2; ++k) { + int64_t in_offset_w = in_offset_h + k * param->in_stride_dim2_; + int64_t out_offset_w = out_offset_h + k * block_size * param->out_stride_dim2_; + for (int l = 0; l < block_size; ++l) { + int64_t out_offset = (out_offset_w + l * param->out_stride_dim1_) * param->data_type_size_; + int64_t in_offset = (in_offset_w + l * block_size * param->out_stride_dim2_) * param->data_type_size_; + memcpy((int8_t *)output + out_offset, (int8_t *)input + in_offset, copy_size); + } + } + } + } +} + +void DepthToSpaceCRDForNHWC(const void *input, void *output, const int *in_shape, const DepthToSpaceArgs *param) { + int32_t block_size = param->block_size_; + int32_t in_shape_dim3 = in_shape[3]; + int32_t in_shape_dim2 = in_shape[2]; + int32_t in_shape_dim1 = in_shape[1]; + size_t copy_size = param->data_type_size_; + for (int i = 0; i < in_shape[0]; ++i) { + int64_t in_offset_n = i * param->in_stride_dim0_; + int64_t out_offset_n = i * param->out_stride_dim0_; + for (int j = 0; j < in_shape_dim1; ++j) { + int64_t in_offset_h = in_offset_n + j * param->in_stride_dim1_; + int64_t out_offset_h = out_offset_n + j * block_size * param->out_stride_dim1_; + for (int k = 0; k < in_shape_dim2; ++k) { + int64_t in_offset_w = in_offset_h + k * param->in_stride_dim2_; + int64_t out_offset_w = out_offset_h + k * block_size * param->out_stride_dim2_; + for (int l = 0; l < in_shape_dim3; ++l) { + int64_t offset = l % (block_size * block_size); + int64_t out_offset_c = + out_offset_w + + offset / block_size * block_size * in_shape_dim2 * in_shape_dim3 / (block_size * block_size) + + offset % block_size * in_shape_dim3 / (block_size * block_size); + int64_t out_offset = (out_offset_c + l / (block_size * block_size)) * param->data_type_size_; + int64_t in_offset = (in_offset_w + l) * param->data_type_size_; + memcpy((int8_t *)output + out_offset, (int8_t *)input + in_offset, copy_size); + } + } + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/depth_to_space_base.h b/mindspore-lite/ops/kernel/cpu/nnacl/base/depth_to_space_base.h new file mode 100644 index 0000000000000000000000000000000000000000..df1cd25955a73e612f8c5f5574d65ae74a36533b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/depth_to_space_base.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_DEPTH_TO_SPACE_H_ +#define NNACL_DEPTH_TO_SPACE_H_ + +#include +#include "nnacl/kernel/depth_to_space.h" + +#ifdef __cplusplus +extern "C" { +#endif +void DepthToSpaceForNHWC(const void *input, void *output, const int *in_shape, const DepthToSpaceArgs *param); +void DepthToSpaceCRDForNHWC(const void *input, void *output, const int *in_shape, const DepthToSpaceArgs *param); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_DEPTH_TO_SPACE_H_ diff --git a/mindspore-lite/src/extendrt/graph_executor/factory.cc b/mindspore-lite/ops/kernel/cpu/nnacl/base/fill_base.c similarity index 40% rename from mindspore-lite/src/extendrt/graph_executor/factory.cc rename to mindspore-lite/ops/kernel/cpu/nnacl/base/fill_base.c index dac874aa0cbd2f81a889d4a21126fe035d25af1c..10a6aefde4c38c01263030070cc0b7dfcfd45b9e 100644 --- a/mindspore-lite/src/extendrt/graph_executor/factory.cc +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/fill_base.c @@ -1,5 +1,5 @@ /** - * Copyright 2019-2021 Huawei Technologies Co., Ltd + * Copyright 2020-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,27 +13,47 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "extendrt/graph_executor/factory.h" -#include -#include - -namespace mindspore { -GraphExecutorRegistry &GraphExecutorRegistry::GetInstance() { - static GraphExecutorRegistry instance; - return instance; + +#include "nnacl/base/fill_base.h" +#include "nnacl/fill_base_simd.h" + +int FillFp32(float *output, int size, float data) { + if (output == NULL) { + return NNACL_NULL_PTR; + } + + int index = 0; + + SIMD_RUN_NO_SCALAR(FillFp32, index, output, size, data); + + for (; index < size; ++index) { + output[index] = data; + } + return NNACL_OK; } -void GraphExecutorRegistry::RegExecutor(const mindspore::GraphExecutorType &type, const GraphExecutorRegFunc &creator) { - graph_executor_map_[type] = creator; +int FillInt32(int *output, int size, int data) { + if (output == NULL) { + return NNACL_NULL_PTR; + } + + int index = 0; + + SIMD_RUN_NO_SCALAR(FillInt32, index, output, size, data); + + for (; index < size; ++index) { + output[index] = data; + } + return NNACL_OK; } -std::shared_ptr GraphExecutorRegistry::GetExecutor( - const mindspore::GraphExecutorType &type, const std::string &name, - std::shared_ptr execution_plan) { - auto it = graph_executor_map_.find(type); - if (it == graph_executor_map_.end()) { - return nullptr; +int FillBool(bool *output, int size, bool data) { + if (output == NULL) { + return NNACL_NULL_PTR; + } + + for (int index = 0; index < size; ++index) { + output[index] = data; } - return it->second(name, execution_plan); + return NNACL_OK; } -} // namespace mindspore diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/fill_base.h b/mindspore-lite/ops/kernel/cpu/nnacl/base/fill_base.h new file mode 100644 index 0000000000000000000000000000000000000000..99374004f7141f552928ad2da30d08e57c973afd --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/fill_base.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FILL_BASE_H_ +#define NNACL_FILL_BASE_H_ + +#include "nnacl/op_base.h" +#include "nnacl/errorcode.h" +#include "nnacl/fill_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int FillFp32(float *output, int size, float data); +int FillInt32(int *output, int size, int data); +int FillBool(bool *output, int size, bool data); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FILL_BASE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/fill_base_simd.h.in b/mindspore-lite/ops/kernel/cpu/nnacl/base/fill_base_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..4e14fd4546ec0eb5bbc329b49c3ec51af32aa777 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/fill_base_simd.h.in @@ -0,0 +1,45 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_BASE_FILL_BASE_@SIMD_INSTRUCTION@_H_ +#define NNACL_BASE_FILL_BASE_@SIMD_INSTRUCTION@_H_ + +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int FillFp32@SIMD_INSTRUCTION@(int index, float *output, int size, float data) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(output + index, SIMD_MOV_F32(data)); + } + return index; +} + +static inline int FillInt32@SIMD_INSTRUCTION@(int index, int *output, int size, int data) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_EPI32(output + index, SIMD_MOV_EPI32(data)); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/format_transpose.c b/mindspore-lite/ops/kernel/cpu/nnacl/base/format_transpose.c new file mode 100644 index 0000000000000000000000000000000000000000..97d3c282658583e8315037b7f5bfd38b727298e1 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/format_transpose.c @@ -0,0 +1,81 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/base/format_transpose.h" +#include "nnacl/errorcode.h" +#include "nnacl/fp32/pack_fp32.h" +#ifdef ENABLE_FP16 +#include "nnacl/fp16/pack_fp16.h" +#endif + +int TransposeFp32Data(void *src_data, void *dst_data, const FormatC src_format, const FormatC dst_format, + const int batch, const int channel, const int plane) { + if (src_format == Format_NHWC && dst_format == Format_NCHW) { + PackNHWCToNCHWFp32(src_data, dst_data, batch, plane, channel, 0, 1); + } else if (src_format == Format_NCHW && dst_format == Format_NHWC) { + PackNCHWToNHWCFp32(src_data, dst_data, batch, plane, channel, 0, 1); + } else if (src_format == Format_NCHW && dst_format == Format_NC4HW4) { + PackNCHWToNC4HW4Fp32(src_data, dst_data, batch, plane, channel); + } else if (src_format == Format_NC4HW4 && dst_format == Format_NCHW) { + PackNC4HW4ToNCHWFp32(src_data, dst_data, batch, plane, channel); + } else if (src_format == Format_NHWC && dst_format == Format_NC4HW4) { + PackNHWCToNC4HW4Fp32(src_data, dst_data, batch, plane, channel); + } else if (src_format == Format_NC4HW4 && dst_format == Format_NHWC) { + PackNC4HW4ToNHWCFp32(src_data, dst_data, batch, plane, channel); + } else if (src_format == Format_NHWC && dst_format == Format_NC8HW8) { + PackNHWCToNC8HW8Fp32(src_data, dst_data, batch, plane, channel); + } else if (src_format == Format_NC8HW8 && dst_format == Format_NHWC) { + PackNC8HW8ToNHWCFp32(src_data, dst_data, batch, plane, channel); + } else if (src_format == Format_NC8HW8 && dst_format == Format_NCHW) { + PackNC8HW8ToNCHWFp32(src_data, dst_data, batch, plane, channel); + } else { + return NNACL_ERR; + } + return NNACL_OK; +} + +#ifdef ENABLE_FP16 +int TransposeFp16Data(void *src_data, void *dst_data, const FormatC src_format, const FormatC dst_format, int batch, + int channel, int plane) { + if (src_format == Format_NCHW && dst_format == Format_NC8HW8) { + PackNCHWFp16ToNC8HW8Fp16(src_data, dst_data, batch, plane, channel); + } else if (src_format == Format_NHWC && dst_format == Format_NC8HW8) { + return NNACL_ERR; + } else if (src_format == Format_NC8HW8 && dst_format == Format_NCHW) { + PackNC8HW8ToNCHWFp16(src_data, dst_data, batch, plane, channel); + } else if (src_format == Format_NC8HW8 && dst_format == Format_NHWC) { + PackNC8HW8ToNHWCFp16((float16_t *)src_data, (float16_t *)dst_data, batch, plane, channel); + } else { + return NNACL_ERR; + } + return NNACL_OK; +} +#endif + +int TransData(void *src_data, void *dst_data, const FormatC src_format, const FormatC dst_format, TypeIdC data_type, + const int batch, const int channel, const int plane) { + switch (data_type) { + case kNumberTypeFloat: + case kNumberTypeFloat32: + return TransposeFp32Data(src_data, dst_data, src_format, dst_format, batch, channel, plane); +#ifdef ENABLE_FP16 + case kNumberTypeFloat16: + return TransposeFp16Data(src_data, dst_data, src_format, dst_format, batch, channel, plane); +#endif + default: + return NNACL_ERR; + } +} diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/normalize.cuh b/mindspore-lite/ops/kernel/cpu/nnacl/base/format_transpose.h similarity index 60% rename from mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/normalize.cuh rename to mindspore-lite/ops/kernel/cpu/nnacl/base/format_transpose.h index 03eada9f3b4af15bc24e13ae4785b3f260d3c874..d2e974f2e569e45c2cc66f820a8da2a2a1341ca0 100644 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/normalize.cuh +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/format_transpose.h @@ -13,12 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#ifndef NNACL_FORMAT_TRANSPOSE_H_ +#define NNACL_FORMAT_TRANSPOSE_H_ -#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_CDUA_IMPL_NORMALIZE_H_ -#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_CDUA_IMPL_NORMALIZE_H_ +#include "nnacl/op_base.h" -template -void Normalize(const T *input, const T *gamma, const T *beta, T *output, size_t dim_at_axis, float epsilion, - int element_cnt, cudaStream_t stream); +#ifdef __cplusplus +extern "C" { +#endif +int TransData(void *src_data, void *dst_data, const FormatC src_format, const FormatC dst_format, TypeIdC data_type, + const int batch, const int channel, const int plane); +#ifdef __cplusplus +} +#endif -#endif // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_CDUA_IMPL_NORMALIZE_H_ +#endif // NNACL_FILL_BASE_H_ diff --git a/mindspore-lite/tools/graph_kernel/converter/rename_fullname_with_scope.cc b/mindspore-lite/ops/kernel/cpu/nnacl/base/gather_base.c similarity index 35% rename from mindspore-lite/tools/graph_kernel/converter/rename_fullname_with_scope.cc rename to mindspore-lite/ops/kernel/cpu/nnacl/base/gather_base.c index 8988f481feafa45e582d0a75b5584512e8a082d0..e717d966e3967d04455edf744d80ea1b6080f410 100644 --- a/mindspore-lite/tools/graph_kernel/converter/rename_fullname_with_scope.cc +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/gather_base.c @@ -1,5 +1,5 @@ /** - * Copyright 2023 Huawei Technologies Co., Ltd + * Copyright 2020-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,33 +13,32 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "tools/graph_kernel/converter/rename_fullname_with_scope.h" +#include +#include "nnacl/base/gather_base.h" -namespace mindspore::graphkernel { -bool RenameFullnameWithScope::Run(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - bool changed = false; - std::unordered_map names; - auto nodes = TopoSort(func_graph->output()); - for (auto &node : nodes) { - if (node->isa()) { - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto name_scope = cnode->fullname_with_scope(); - if (name_scope.empty()) { - continue; - } - if (names.find(name_scope) == names.end()) { - names[name_scope] = 1; +int Gather(const void *input, int64_t outer_size, int64_t byte_inner_size, int64_t limit, const int *indices, + int64_t index_num, void *output, int64_t byte_out_stride, int *error_index) { + if (input == NULL || output == NULL || indices == NULL || error_index == NULL) { + return NNACL_NULL_PTR; + } + const int8_t *int8_in = (int8_t *)input; + int8_t *int8_out = (int8_t *)output; + int64_t in_stride = byte_inner_size * limit; + for (int64_t m = 0; m < outer_size; ++m) { + int8_t *int8_out_m = int8_out; + for (int64_t i = 0; i < index_num; ++i) { + int index = indices[i]; + index = index < 0 ? index + limit : index; + if (index < 0 || index >= limit) { + *error_index = index; + return NNACL_GATHER_INDICES_VALUE_INVALID; } else { - // node with same name - names[name_scope]++; - auto new_name_scope = name_scope + "-" + std::to_string(names[name_scope]); - cnode->set_fullname_with_scope(new_name_scope); - changed = true; + memcpy(int8_out_m, int8_in + index * byte_inner_size, byte_inner_size); } + int8_out_m += byte_inner_size; } + int8_in += in_stride; + int8_out += byte_out_stride; } - return changed; + return NNACL_OK; } -} // namespace mindspore::graphkernel diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/gather_base.h b/mindspore-lite/ops/kernel/cpu/nnacl/base/gather_base.h new file mode 100644 index 0000000000000000000000000000000000000000..0d4521b5c51df664227938d15858dcb6ef7b3634 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/gather_base.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_GATHER_BASE_H_ +#define NNACL_GATHER_BASE_H_ + +#include "nnacl/op_base.h" +#include "nnacl/errorcode.h" + +#ifdef __cplusplus +extern "C" { +#endif +int Gather(const void *input, int64_t outer_size, int64_t byte_inner_size, int64_t limit, const int *indices, + int64_t index_num, void *output, int64_t byte_out_stride, int *error_index); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_GATHER_BASE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/gather_d_base.c b/mindspore-lite/ops/kernel/cpu/nnacl/base/gather_d_base.c new file mode 100644 index 0000000000000000000000000000000000000000..da2454ec22c281759bdc497b8fac48ba0b0cac27 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/gather_d_base.c @@ -0,0 +1,163 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/base/gather_d_base.h" + +int CheckIndexValue_int32_t(int32_t *index, const int max_index, const size_t *index_shape, + const size_t index_shape_size) { + // check index + size_t index_size = 1; + for (size_t i = 0; i < index_shape_size; ++i) { + index_size *= index_shape[i]; + } + + for (size_t i = 0; i < index_size; ++i) { + if (index[i] >= max_index || index[i] < -max_index) { + return NNACL_ERR; + } + if (index[i] < 0) { + index[i] = max_index + index[i]; + } + } + return NNACL_OK; +} + +int CheckIndexValue_int64_t(int64_t *index, const int max_index, const size_t *index_shape, + const size_t index_shape_size) { + // check index + size_t index_size = 1; + for (size_t i = 0; i < index_shape_size; ++i) { + index_size *= index_shape[i]; + } + for (size_t i = 0; i < index_size; ++i) { + if (index[i] >= max_index || index[i] < -max_index) { + return NNACL_ERR; + } + if (index[i] < 0) { + index[i] = max_index + index[i]; + } + } + return NNACL_OK; +} + +int InitCalVec(size_t *in_strides, size_t *out_strides, size_t *pos, const size_t *input_shape, + const size_t input_shape_size, const size_t *output_shape, const size_t output_shape_size) { + // in_strides + NNACL_CHECK_NULL_RETURN_ERR(in_strides); + for (size_t i = 0; i < input_shape_size; ++i) { + in_strides[i] = 1; + } + for (int i = (int)input_shape_size - 2; i >= 0; --i) { + in_strides[i] = input_shape[i + 1] * in_strides[i + 1]; + } + + // out_strides + NNACL_CHECK_NULL_RETURN_ERR(out_strides); + for (size_t i = 0; i < output_shape_size; ++i) { + out_strides[i] = 1; + } + for (int i = (int)output_shape_size - 2; i >= 0; --i) { + out_strides[i] = output_shape[i + 1] * out_strides[i + 1]; + } + + NNACL_CHECK_NULL_RETURN_ERR(pos); + for (size_t i = 0; i < output_shape_size; ++i) { + pos[i] = 0; + } + return NNACL_OK; +} + +#define COPY_TASK_IMPL(type0, type1) \ + int CopyTask_Input_##type0##_Index_##type1( \ + type0 *output, const type0 *input, const type1 *index, size_t cur_dim, size_t *pos, const size_t dim, \ + const size_t *output_shape, const size_t output_shape_size, const size_t *in_strides, const size_t *out_strides) { \ + if (pos == NULL || out_strides == NULL || in_strides == NULL) { \ + return NNACL_NULL_PTR; \ + } \ + for (size_t i = 0; i < output_shape[cur_dim]; ++i) { \ + pos[cur_dim] = i; \ + if (cur_dim == output_shape_size - 1) { \ + size_t input_offset = 0; \ + size_t out_offset = 0; \ + for (size_t j = 0; j < output_shape_size; ++j) { \ + out_offset += pos[j] * out_strides[j]; \ + } \ + size_t cur_index = pos[dim]; \ + pos[dim] = index[out_offset]; \ + for (size_t j = 0; j < output_shape_size; ++j) { \ + input_offset += pos[j] * in_strides[j]; \ + } \ + ((type0 *)output)[out_offset] = ((const type0 *)input)[input_offset]; \ + pos[dim] = cur_index; \ + } else { \ + CopyTask_Input_##type0##_Index_##type1(output, input, index, cur_dim + 1, pos, dim, output_shape, \ + output_shape_size, in_strides, out_strides); \ + } \ + } \ + return NNACL_OK; \ + } + +COPY_TASK_IMPL(bool, int32_t) +COPY_TASK_IMPL(bool, int64_t) +COPY_TASK_IMPL(int16_t, int32_t) +COPY_TASK_IMPL(int16_t, int64_t) +COPY_TASK_IMPL(int32_t, int32_t) +COPY_TASK_IMPL(int32_t, int64_t) +COPY_TASK_IMPL(int64_t, int32_t) +COPY_TASK_IMPL(int64_t, int64_t) +COPY_TASK_IMPL(float, int32_t) +COPY_TASK_IMPL(float, int64_t) +#ifdef ENABLE_FP16 +COPY_TASK_IMPL(float16_t, int32_t) +COPY_TASK_IMPL(float16_t, int64_t) +#endif + +#define GATHER_D_IMPL(type0, type1) \ + GATHER_D_IMPL_DECLARATION(type0, type1) { \ + if (output == NULL || input == NULL || index == NULL || input_shape == NULL || output_shape == NULL) { \ + return NNACL_NULL_PTR; \ + } \ + int max_index = input_shape[dim]; \ + int ret = CheckIndexValue_##type1(index, max_index, output_shape, output_shape_size); \ + if (ret != NNACL_OK) { \ + return ret; \ + } \ + size_t in_strides[MAX_SHAPE_SIZE]; \ + size_t out_strides[MAX_SHAPE_SIZE]; \ + size_t pos[MAX_SHAPE_SIZE]; \ + ret = InitCalVec(in_strides, out_strides, pos, input_shape, input_shape_size, output_shape, output_shape_size); \ + if (ret != NNACL_OK) { \ + return ret; \ + } \ + ret = CopyTask_Input_##type0##_Index_##type1(output, input, index, 0, pos, dim, output_shape, output_shape_size, \ + in_strides, out_strides); \ + return ret; \ + } + +GATHER_D_IMPL(bool, int32_t) +GATHER_D_IMPL(bool, int64_t) +GATHER_D_IMPL(int16_t, int32_t) +GATHER_D_IMPL(int16_t, int64_t) +GATHER_D_IMPL(int32_t, int32_t) +GATHER_D_IMPL(int32_t, int64_t) +GATHER_D_IMPL(int64_t, int32_t) +GATHER_D_IMPL(int64_t, int64_t) +GATHER_D_IMPL(float, int32_t) +GATHER_D_IMPL(float, int64_t) +#ifdef ENABLE_FP16 +GATHER_D_IMPL(float16_t, int32_t) +GATHER_D_IMPL(float16_t, int64_t) +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/gather_d_base.h b/mindspore-lite/ops/kernel/cpu/nnacl/base/gather_d_base.h new file mode 100644 index 0000000000000000000000000000000000000000..cb9be5b4a3ba0e9501b5dadf90e6e695d6df1cac --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/gather_d_base.h @@ -0,0 +1,55 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_GATHER_D_BASE_H_ +#define NNACL_GATHER_D_BASE_H_ + +#include "nnacl/op_base.h" +#include "nnacl/errorcode.h" + +#ifdef __cplusplus +extern "C" { +#endif +#define GATHER_D(type0, type1, output, input, index, input_shape, input_shape_size, output_shape, output_shape_size, \ + dim) \ + GatherD_Input_##type0##_Index_##type1(output, input, index, input_shape, input_shape_size, output_shape, \ + output_shape_size, dim) + +#define GATHER_D_IMPL_DECLARATION(type0, type1) \ + int GatherD_Input_##type0##_Index_##type1( \ + type0 *output, const type0 *input, type1 *index, const size_t *input_shape, const size_t input_shape_size, \ + const size_t *output_shape, const size_t output_shape_size, const size_t dim) + +GATHER_D_IMPL_DECLARATION(bool, int32_t); +GATHER_D_IMPL_DECLARATION(bool, int64_t); +GATHER_D_IMPL_DECLARATION(int16_t, int32_t); +GATHER_D_IMPL_DECLARATION(int16_t, int64_t); +GATHER_D_IMPL_DECLARATION(int32_t, int32_t); +GATHER_D_IMPL_DECLARATION(int32_t, int64_t); +GATHER_D_IMPL_DECLARATION(int64_t, int32_t); +GATHER_D_IMPL_DECLARATION(int64_t, int64_t); +GATHER_D_IMPL_DECLARATION(float, int32_t); +GATHER_D_IMPL_DECLARATION(float, int64_t); +#ifdef ENABLE_FP16 +GATHER_D_IMPL_DECLARATION(float16_t, int32_t); +GATHER_D_IMPL_DECLARATION(float16_t, int64_t); +#endif + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_GATHER_D_BASE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/minimal_filtering_generator.c b/mindspore-lite/ops/kernel/cpu/nnacl/base/minimal_filtering_generator.c new file mode 100644 index 0000000000000000000000000000000000000000..fc01f1eaaa6498f5de58da2afd6c2d653b6e4dde --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/minimal_filtering_generator.c @@ -0,0 +1,342 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/base/minimal_filtering_generator.h" +#include +#include +#include "nnacl/fp32/pack_fp32.h" +#include "nnacl/errorcode.h" +#include "nnacl/intrinsics/ms_simd_instructions.h" + +void Polynomial(const float *interval, float *m, int degree) { + for (int i = 0; i < degree; ++i) { + float mul = 1; + for (int j = 0; j < degree; ++j) { + if (i == j) { + continue; + } + mul *= (interval[i] - interval[j]); + } + m[i] = mul; + } +} + +void DiagonalPlusMatrix(const float *matrix, float *diagonal_matrix, int degree) { + int data_num = (degree + 1) * (degree + 1); + memset(diagonal_matrix, 0, data_num * sizeof(float)); + for (int i = 0; i < degree; ++i) { + for (int j = 0; j < degree; ++j) { + if (j == i) { + diagonal_matrix[i * (degree + 1) + j] = matrix[i]; + } + } + } + diagonal_matrix[data_num - 1] = 1; +} + +void ResidueMatrix(const float *interval, float *b, int row, int col) { + // row : input unit, col : output_unit + // result : matrix b + int len = row * col; + memset(b, 0, len * sizeof(float)); + for (int i = 0; i < row - 1; ++i) { + for (int j = 0; j < col; ++j) { + b[i * col + j] = pow(interval[i], j); + } + } + b[len - 1] = 1; +} + +int LT(const float *poly_array, float *matrix_lt, int n) { + if (n > MAX_LEN) { + return NNACL_ERR; + } + float coefficient_array[MAX_LEN]; // n + float poly[MAX_LEN]; // n + + Polynomial(poly_array, poly, n); + for (int i = 0; i < n; ++i) { + // get coefficient + int index = 1; + memset(coefficient_array, 0, n * sizeof(float)); + coefficient_array[0] = 1; + for (int j = 0; j < n; ++j) { + if (j == i) continue; + float poly_coe = poly_array[j] == 0 ? 0 : -poly_array[j]; + coefficient_array[index] = 1; + for (int k = index - 1; k > 0; --k) { + coefficient_array[k] = coefficient_array[k] * poly_coe + coefficient_array[k - 1]; + } + coefficient_array[0] *= poly_coe; + index++; + } + + // lx[i, 0].nth(j) / f[i] + int setp = i * n; + for (int l = 0; l < n; ++l) { + matrix_lt[setp + l] = coefficient_array[l] / poly[i]; + } + } // matrix L row loop + return NNACL_OK; +} + +void T(const float *poly_array, float *matrix_t, int n) { + memset(matrix_t, 0, n * (n + 1) * sizeof(float)); + for (int i = 0; i < n; ++i) { + for (int j = 0; j < n + 1; ++j) { + if (j == i) matrix_t[i * (n + 1) + j] = 1; + if (j == n) { + if (poly_array[i] == 0) { + matrix_t[i * (n + 1) + j] = 0; + } else { + matrix_t[i * (n + 1) + j] = -pow(poly_array[i], n); + } + } + } + } +} + +int B(const float *poly_array, float *matrix_b, int in_unit) { + memset(matrix_b, 0, in_unit * in_unit * sizeof(float)); + int n = in_unit - 1; + if ((n * n) > MAX_LEN || (n * in_unit) > MAX_LEN) { + return NNACL_ERR; + } + float matrix_l[MAX_LEN]; // n * n + float matrix_lt[MAX_LEN]; // n * n + float matrix_t[MAX_LEN]; // n * in_unit + + T(poly_array, matrix_t, n); + if (LT(poly_array, matrix_lt, n) != NNACL_OK) { + return NNACL_ERR; + } + MatrixTranspose(matrix_lt, matrix_l, n, n); + MatrixMultiply(matrix_l, matrix_t, matrix_b, n, n, in_unit); + matrix_b[in_unit * in_unit - 1] = 1; + return NNACL_OK; +} + +#if !defined(ENABLE_ARM) && !defined(ENABLE_SSE) +void MatrixMultiplyWinograd(const float *matix_a, const float *matrix_b, float *matrix_c, int m, int k, int n, + int in_channel, int c4_channel) { + int cnt = 0; + for (int i = 0; i < m; ++i) { + for (int j = 0; j < n; ++j) { + for (int y = 0; y < in_channel; ++y) { + float tmp = 0; + for (int z = 0; z < k; ++z) { + tmp += matix_a[z * in_channel + y + i * in_channel * k] * matrix_b[j + z * n]; + } + matrix_c[cnt++] = tmp; + } + cnt += c4_channel / 4 - in_channel; + } + } +} +#endif + +void GenerateIntervalArray(float *array, float interval, int degree) { + array[0] = 0; + for (int i = 1; i < degree; ++i) { + int coefficient = pow(-1, i - 1); + array[i] = array[i - 1] + interval * i * coefficient; + } +} + +void MatrixTranspose(const float *matrix, float *trans_matrix, int row, int col) { + for (int i = 0; i < col; ++i) { + for (int j = 0; j < row; ++j) { + trans_matrix[i * row + j] = matrix[j * col + i]; + } + } +} + +void MatrixMultiply(const float *matrix_a, const float *matrix_b, float *matrix_c, int m, int k, int n) { + int count = 0; + for (int h = 0; h < m; h++) { + int h_offset = h * k; + for (int w = 0; w < n; w++) { + float res = 0; + for (int i = 0; i < k; i++) { + res += *(matrix_a + h_offset + i) * *(matrix_b + w + i * n); + } + *(matrix_c + count) = res; + count++; + } + } +} + +int CookToomFilter(float *matrix_a, float *matrix_at, float *matrix_b, float *matrix_bt, float *matrix_g, + float *matrix_gt, float coefficient, int out_unit, int filter_size) { + int in_unit = out_unit + filter_size - 1; + int degree = in_unit - 1; + if (degree > MAX_LEN || (in_unit * in_unit) > MAX_LEN || (in_unit * filter_size) > MAX_LEN) { + return NNACL_ERR; + } + float polynomial_m[MAX_LEN]; // degree + float diagonal_matrix[MAX_LEN]; // input_unit * input_unit + float inverse_diagonal_matrix[MAX_LEN]; // input_unit * input_unit + + // get diagonal matrix + float interval[MAX_LEN]; // degree + GenerateIntervalArray(interval, coefficient, degree); + Polynomial(interval, polynomial_m, degree); + DiagonalPlusMatrix(polynomial_m, diagonal_matrix, degree); + if (diagonal_matrix[0] < 0) { + for (int i = 0; i < in_unit; ++i) { + if (diagonal_matrix[i] != 0) diagonal_matrix[i] *= -1; + } + } + + // inverse diagonal matrix + for (int j = 0; j < in_unit * in_unit; ++j) { + if (diagonal_matrix[j] != 0) { + inverse_diagonal_matrix[j] = 1.0 / diagonal_matrix[j]; + } else { + inverse_diagonal_matrix[j] = 0; + } + } + + // get matrix A && AT + ResidueMatrix(interval, matrix_a, in_unit, out_unit); + MatrixTranspose(matrix_a, matrix_at, in_unit, out_unit); + + // get matrix B + int ret = B(interval, matrix_bt, in_unit); + if (ret != NNACL_OK) { + return ret; + } + MatrixTranspose(matrix_bt, matrix_b, in_unit, in_unit); + MatrixMultiply(diagonal_matrix, matrix_b, matrix_bt, in_unit, in_unit, in_unit); + MatrixTranspose(matrix_bt, matrix_b, in_unit, in_unit); + + // get matrix G && GT + float tmp_g[MAX_LEN]; // in_unit * filter_size + ResidueMatrix(interval, matrix_g, in_unit, filter_size); + MatrixTranspose(matrix_g, tmp_g, in_unit, filter_size); + MatrixMultiply(tmp_g, inverse_diagonal_matrix, matrix_gt, filter_size, in_unit, in_unit); + MatrixTranspose(matrix_gt, matrix_g, filter_size, in_unit); + return NNACL_OK; +} + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) +void MatrixMultiplyVec(const MS_FLOAT32X4 *matrix_a, const MS_FLOAT32X4 *matrix_b, MS_FLOAT32X4 *matrix_c, + const float *bias, int m, int k, int n) { + int count = 0; + MS_FLOAT32X4 bias_ptr = MS_MOVQ_F32(0); + if (bias != NULL) { + bias_ptr = MS_LDQ_F32(bias); + } + for (int h = 0; h < m; h++) { + int h_offset = h * k; + for (int w = 0; w < n; w++) { + MS_FLOAT32X4 res = MS_MOVQ_F32(0); + for (int i = 0; i < k; i++) { + res = MS_MLAQ_F32(res, matrix_a[h_offset + i], matrix_b[w + i * n]); + } + matrix_c[count] = MS_ADDQ_F32(res, bias_ptr); + count++; + } + } +} +#endif + +int WinogradWeightTransform(const float *weight_data, float *winograd_data, float *matrix_g, const float *matrix_gt, + int oc_block, int input_unit, int kernel_unit, int channel, int batch, bool pack) { + if (oc_block == 0) { + return NNACL_PARAM_INVALID; + } + // original weight format : ohwi + int oc_block_num = UP_DIV(batch, oc_block); + int block_stride = channel * oc_block; + int block_num_stride = block_stride * oc_block_num; + + // trans_filter = G*g*GT (g represents weight_data) + // separate into two steps ===> tmp = (g * GT)T ===> trans = (tmp * GT)T use same function:MatrixMultiplyWinograd + float *tmp_data = (float *)(malloc(channel * input_unit * kernel_unit * sizeof(float))); + if (tmp_data == NULL) { + return NNACL_ERR; + } + float *trans_out_data = (float *)(malloc(channel * input_unit * input_unit * sizeof(float))); + if (trans_out_data == NULL) { + free(tmp_data); + return NNACL_ERR; + } + +#ifndef ENABLE_ARM + float *tmp_data1 = (float *)(malloc(channel * input_unit * kernel_unit * sizeof(float))); + if (tmp_data1 == NULL) { + free(tmp_data); + free(trans_out_data); + return NNACL_ERR; + } + float *trans_out_data1 = (float *)(malloc(channel * input_unit * input_unit * sizeof(float))); + if (trans_out_data1 == NULL) { + free(tmp_data); + free(tmp_data1); + free(trans_out_data); + return NNACL_ERR; + } +#endif + + int input_oz_offset = kernel_unit * kernel_unit * channel; + for (int i = 0; i < batch; i++) { + int out_c_block = i / oc_block; + int out_c_res = i % oc_block; + int output_oz_offset = out_c_block * block_stride + out_c_res; + +#ifndef ENABLE_ARM + // tmp_data = g * GT + MatrixMultiplyWinograd(weight_data + i * input_oz_offset, matrix_gt, tmp_data, kernel_unit, kernel_unit, input_unit, + channel, channel * 4); + // tmp_data1 = (tmp_data)T + PackHWCToWHC(tmp_data, tmp_data1, kernel_unit, input_unit, channel); + // trans_out_data1 = tmp * GT + MatrixMultiplyWinograd(tmp_data1, matrix_gt, trans_out_data1, input_unit, kernel_unit, input_unit, channel, + channel * 4); + // trans_out_data = (trans_out_data1)T + PackHWCToWHC(trans_out_data1, trans_out_data, input_unit, input_unit, channel); +#else + // tmp = (g * GT)T + MatrixMultiplyWinograd(weight_data + i * input_oz_offset, matrix_gt, tmp_data, kernel_unit, kernel_unit, input_unit, + channel, channel * 4); + // trans = (tmp * GT)T + MatrixMultiplyWinograd(tmp_data, matrix_gt, trans_out_data, input_unit, kernel_unit, input_unit, channel, + channel * 4); +#endif + if (pack) { + int in_offset = 0; + for (int j = 0; j < input_unit; ++j) { + for (int k = 0; k < input_unit; ++k) { + for (int c = 0; c < channel; ++c) { + *(winograd_data + output_oz_offset + c * oc_block) = trans_out_data[in_offset + c]; + } + in_offset += channel; + output_oz_offset += block_num_stride; + } + } + } else { + memcpy(winograd_data + i * channel * input_unit * input_unit, trans_out_data, + channel * input_unit * input_unit * sizeof(float)); + } + } +#ifndef ENABLE_ARM + free(tmp_data1); + free(trans_out_data1); +#endif + free(tmp_data); + free(trans_out_data); + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/minimal_filtering_generator.h b/mindspore-lite/ops/kernel/cpu/nnacl/base/minimal_filtering_generator.h new file mode 100644 index 0000000000000000000000000000000000000000..44f5bd005e31f551964da46f2b72eec4ccee5b06 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/minimal_filtering_generator.h @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_MINIMAL_FILTERING_GENERATOR_H_ +#define NNACL_MINIMAL_FILTERING_GENERATOR_H_ + +#ifdef ENABLE_ARM +#include +#endif +#include + +#ifdef __cplusplus +extern "C" { +#endif +void Polynomial(const float *interval, float *m, int degree); + +void DiagonalPlusMatrix(const float *matrix, float *diagonal_matrix, int degree); + +void ResidueMatrix(const float *interval, float *b, int row, int col); + +int LT(const float *poly_array, float *matrix_lt, int n); + +void T(const float *poly_array, float *matrix_t, int n); + +int B(const float *poly_array, float *matrix_b, int in_unit); + +void GenerateIntervalArray(float *array, float interval, int degree); + +void MatrixTranspose(const float *matrix, float *trans_matrix, int row, int col); + +void MatrixMultiply(const float *matrix_a, const float *matrix_b, float *matrix_c, int m, int k, int n); + +int CookToomFilter(float *matrix_a, float *matrix_at, float *matrix_b, float *matrix_bt, float *matrix_g, + float *matrix_gt, float coefficient, int out_unit, int filter_size); +void MatrixMultiplyWinograd(const float *matix_a, const float *matrix_b, float *matrix_c, int m, int k, int n, + int in_channel, int c4_channel); + +int WinogradWeightTransform(const float *weight_data, float *winograd_data, float *matrix_g, const float *matrix_gt, + int oc_block, int input_unit_, int kernel_unit_, int channel, int batch, bool pack); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_MINIMAL_FILTERING_GENERATOR_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/scatter_nd_binary.c b/mindspore-lite/ops/kernel/cpu/nnacl/base/scatter_nd_binary.c new file mode 100644 index 0000000000000000000000000000000000000000..eaba1d23cd0eac7fc8f40d627e6e933c3060b2af --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/scatter_nd_binary.c @@ -0,0 +1,111 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/base/scatter_nd_binary.h" +#include +#include +#include "nnacl/errorcode.h" +#include "nnacl/scatter_nd_binary_simd.h" + +int ScatterNDAdd(const void *update, void *output, int *output_unit_offsets, const ScatterNDParameter *param, int type, + int task_id) { + if (update == NULL || output == NULL || output_unit_offsets == NULL || param == NULL) { + return NNACL_NULL_PTR; + } + if (param->op_parameter.thread_num_ == 0) { + return NNACL_ERR; + } + int unit_per_thread = UP_DIV(param->num_unit, param->op_parameter.thread_num_); + int begin = unit_per_thread * task_id; + int end = MSMIN(begin + unit_per_thread, param->num_unit); + if (type == 0) { + float *update_fp32 = (float *)update; + float *output_fp32 = (float *)output; + for (int i = begin; i < end; i++) { + const float *update_data = update_fp32 + i * param->unit_size; + float *output_data = output_fp32 + output_unit_offsets[i]; + int j = 0; + + SIMD_RUN_NO_SCALAR(ScatterNDAddFp32, j, update_data, param->unit_size, output_data); + for (; j < param->unit_size; j++) { + output_data[j] += update_data[j]; + } + } + } else { + int *update_int32 = (int *)update; + int *output_int32 = (int *)output; + for (int i = begin; i < end; i++) { + const int *update_data = update_int32 + i * param->unit_size; + int *output_data = output_int32 + output_unit_offsets[i]; + int j = 0; + + SIMD_RUN_NO_SCALAR(ScatterNDAddInt32, j, update_data, param->unit_size, output_data); + for (; j < param->unit_size; j++) { + output_data[j] += update_data[j]; + } + } + } + return NNACL_OK; +} + +int ScatterNDUpdate(void *output, const void *update, int *output_unit_offsets, const ScatterNDParameter *param, + int task_id) { + if (param->op_parameter.thread_num_ == 0) { + return NNACL_ERR; + } + int unit_per_thread = UP_DIV(param->num_unit, param->op_parameter.thread_num_); + int begin = unit_per_thread * task_id; + int end = MSMIN(begin + unit_per_thread, param->num_unit); + + int data_type_len = param->data_type_len; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(param->unit_size, data_type_len, NNACL_ERR); + + for (int i = begin; i < end; i++) { + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(output_unit_offsets[i], data_type_len, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(i, param->unit_size * data_type_len, NNACL_ERR); + (void)memcpy((int8_t *)output + output_unit_offsets[i] * data_type_len, + (int8_t *)update + i * param->unit_size * data_type_len, param->unit_size * data_type_len); + } + return NNACL_OK; +} + +int ScatterNDMax(const void *update, void *output, int *output_unit_offsets, const ScatterNDParameter *param, int type, + int task_id) { + if (update == NULL || output == NULL || output_unit_offsets == NULL || param == NULL) { + return NNACL_NULL_PTR; + } + if (param->op_parameter.thread_num_ == 0) { + return NNACL_ERR; + } + int unit_per_thread = UP_DIV(param->num_unit, param->op_parameter.thread_num_); + int begin = unit_per_thread * task_id; + int end = MSMIN(begin + unit_per_thread, param->num_unit); + if (type == 0) { + float *update_fp32 = (float *)update; + float *output_fp32 = (float *)output; + for (int i = begin; i < end; i++) { + const float *update_data = update_fp32 + i * param->unit_size; + float *output_data = output_fp32 + output_unit_offsets[i]; + int j = 0; + + SIMD_RUN_NO_SCALAR(ScatterNDMaxFp32, j, update_data, param->unit_size, output_data); + for (; j < param->unit_size; j++) { + output_data[j] = fmaxf(update_data[j], output_data[j]); + } + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/scatter_nd_binary.h b/mindspore-lite/ops/kernel/cpu/nnacl/base/scatter_nd_binary.h new file mode 100644 index 0000000000000000000000000000000000000000..36657cd9ad88d522fe4e2d4f252fd642e7a6b559 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/scatter_nd_binary.h @@ -0,0 +1,37 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_BASE_SCATTER_ND_BINARY_H_ +#define NNACL_BASE_SCATTER_ND_BINARY_H_ + +#include "nnacl/scatter_nd_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int ScatterNDUpdate(void *output, const void *update, int *output_unit_offsets, const ScatterNDParameter *param, + int task_id); + +int ScatterNDAdd(const void *update, void *output, int *output_unit_offsets, const ScatterNDParameter *param, int type, + int task_id); + +int ScatterNDMax(const void *update, void *output, int *output_unit_offsets, const ScatterNDParameter *param, int type, + int task_id); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_BASE_SCATTER_ND_BINARY_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/scatter_nd_binary_simd.h.in b/mindspore-lite/ops/kernel/cpu/nnacl/base/scatter_nd_binary_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..fcd73c69dbff177eed990cd2328f2b2472cf7afc --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/scatter_nd_binary_simd.h.in @@ -0,0 +1,59 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_BASE_SCATTER_ND_BINARY_@SIMD_INSTRUCTION@_H_ +#define NNACL_BASE_SCATTER_ND_BINARY_@SIMD_INSTRUCTION@_H_ + +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + + static inline int ScatterNDAddFp32@SIMD_INSTRUCTION@(int index, const float *update, int size, float *output) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(output + index, SIMD_ADD_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index))); + } + return index; +} + +static inline int ScatterNDAddInt32@SIMD_INSTRUCTION@(int index, const int *update, int size, int *output) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_EPI32(output + index, SIMD_ADD_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index))); + } + return index; +} + +static inline int ScatterNDMaxFp32@SIMD_INSTRUCTION@(int index, const float *update, int size, float *output) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(output + index, SIMD_MAX_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index))); + } + return index; +} + +static inline int ScatterNDMaxInt32@SIMD_INSTRUCTION@(int index, const int *update, int size, int *output) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_EPI32(output + index, SIMD_MAX_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index))); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif // NNACL_BASE_SCATTER_ND_BINARY_@SIMD_INSTRUCTION@_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/sequence_unstack_base.h b/mindspore-lite/ops/kernel/cpu/nnacl/base/sequence_unstack_base.h new file mode 100644 index 0000000000000000000000000000000000000000..26823037dec22d21ca81c09749f50f2fbec25b76 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/sequence_unstack_base.h @@ -0,0 +1,32 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_SEQUENCE_UNSTACK_H_ +#define MINDSPORE_NNACL_SEQUENCE_UNSTACK_H_ + +#include +#include "nnacl/op_base.h" +#include "nnacl/sequence_unstack_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void SequenceUnstack(const void *input, void **output, const SequenceUnstackParameter *para, int data_size); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_SEQUENCE_UNSTACK_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/slice_base.c b/mindspore-lite/ops/kernel/cpu/nnacl/base/slice_base.c new file mode 100644 index 0000000000000000000000000000000000000000..acbb10baf6109e8d1c1d5ae8fa4729193d755987 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/slice_base.c @@ -0,0 +1,173 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/base/slice_base.h" +#include + +void InitSliceStruct(SliceStruct *slice, TensorC *in_tensor, TensorC *begin_tensor, TensorC *size_tensor) { + slice->param_length_ = in_tensor->shape_size_; + + int32_t *begin = (int32_t *)begin_tensor->data_; + int32_t *size = (int32_t *)size_tensor->data_; + + for (int i = 0; i < slice->param_length_; ++i) { + slice->shape_[i] = in_tensor->shape_[i]; + slice->begin_[i] = begin[i]; + slice->size_[i] = size[i] < 0 ? slice->shape_[i] - slice->begin_[i] : size[i]; + slice->end_[i] = slice->begin_[i] + slice->size_[i]; + } + return; +} + +void PadSliceParameterTo8D(SliceStruct *param) { + int32_t begin[DIMENSION_8D]; + int32_t end[DIMENSION_8D]; + int32_t slice_size[DIMENSION_8D]; + int32_t data_shape[DIMENSION_8D]; + for (int32_t i = 0; i < param->param_length_; ++i) { + begin[i] = param->begin_[i]; + end[i] = param->end_[i]; + slice_size[i] = param->size_[i] < 0 ? param->shape_[i] - begin[i] : param->size_[i]; + data_shape[i] = param->shape_[i]; + } + int32_t real_index = param->param_length_ - 1; + for (int32_t i = DIMENSION_8D - 1; i >= 0; --i) { + if (real_index >= 0) { + param->begin_[i] = begin[real_index]; + param->end_[i] = end[real_index]; + param->size_[i] = slice_size[real_index]; + param->shape_[i] = data_shape[real_index--]; + } else { + param->begin_[i] = 0; + param->end_[i] = 1; + param->size_[i] = 1; + param->shape_[i] = 1; + } + } + param->param_length_ = DIMENSION_8D; +} + +void DoSlice(const void *input, void *output, const SliceStruct *param, int thread_id, int thread_num, int data_size) { + int8_t *int8_in = (int8_t *)input; + int8_t *int8_out = (int8_t *)output; + + int out_stride[8]; + out_stride[7] = 1; + for (int i = 6; i >= 0; --i) { + out_stride[i] = out_stride[i + 1] * param->size_[i + 1]; + } + int count_per_thread = UP_DIV(param->size_[5], thread_num); + int thread_begin = thread_id * count_per_thread; + int thread_end = MSMIN(param->size_[5], thread_begin + count_per_thread); + int copy_size = param->size_[7] * data_size; + int in_stride[8]; + in_stride[7] = 1; + for (int i = 6; i >= 0; --i) { + in_stride[i] = param->shape_[i + 1] * in_stride[i + 1]; + } + + for (int ii = 0; ii < param->size_[0]; ++ii) { + int out_offset0 = ii * out_stride[0]; + int in_offset0 = (ii + param->begin_[0]) * in_stride[0] + param->begin_[7]; + for (int jj = 0; jj < param->size_[1]; ++jj) { + int out_offset1 = jj * out_stride[1] + out_offset0; + int in_offset1 = (jj + param->begin_[1]) * in_stride[1] + in_offset0; + for (int kk = 0; kk < param->size_[2]; ++kk) { + int out_offset2 = kk * out_stride[2] + out_offset1; + int in_offset2 = (kk + param->begin_[2]) * in_stride[2] + in_offset1; + for (int ll = 0; ll < param->size_[3]; ++ll) { + int out_offset3 = ll * out_stride[3] + out_offset2; + int in_offset3 = (ll + param->begin_[3]) * in_stride[3] + in_offset2; + for (int i = 0; i < param->size_[4]; ++i) { + int out_offset4 = i * out_stride[4] + out_offset3; + int in_offset4 = (i + param->begin_[4]) * in_stride[4] + in_offset3; + for (int j = thread_begin; j < thread_end; ++j) { + int out_offset5 = j * out_stride[5] + out_offset4; + int in_offset5 = (j + param->begin_[5]) * in_stride[5] + in_offset4; + for (int k = 0; k < param->size_[6]; ++k) { + int out_offset6 = k * out_stride[6] + out_offset5; + int in_offset6 = (k + param->begin_[6]) * in_stride[6] + in_offset5; + memcpy(int8_out + out_offset6 * data_size, int8_in + in_offset6 * data_size, copy_size); + } + } + } + } + } + } + } +} + +static bool WhetherCopyByAxis(const int32_t *begin, const int32_t *end, const int32_t *shape, int dim) { + for (int i = dim + 1; i < DIMENSION_8D; ++i) { + if (begin[i] != 0 || end[i] != shape[i]) return false; + } + return true; +} + +void DoSliceNoParallel(const void *input, void *output, const SliceStruct *param, int data_size) { + int8_t *int8_in = (int8_t *)input; + int8_t *int8_out = (int8_t *)output; + + int copy_size = param->size_[7] * data_size; + int in_stride[8]; + in_stride[7] = 1; + for (int i = 6; i >= 0; --i) { + in_stride[i] = param->shape_[i + 1] * in_stride[i + 1]; + } + bool axis_copy_flag[DIMENSION_8D] = {false}; + for (int i = 0; i < DIMENSION_8D; ++i) { + axis_copy_flag[i] = WhetherCopyByAxis(param->begin_, param->end_, param->shape_, i); + } + int out_offset = 0; + for (int32_t dim0 = param->begin_[0]; dim0 < param->end_[0]; ++dim0) { + int in_offset0 = dim0 * in_stride[0] + param->begin_[7]; +#define FAST_COPY_IF_NEED(rank) \ + if (axis_copy_flag[rank]) { \ + int left_block_num = param->end_[rank] - dim##rank; \ + memcpy(int8_out + out_offset * data_size, int8_in + in_offset##rank * data_size, \ + in_stride[rank] * left_block_num * data_size); \ + out_offset += in_stride[rank] * left_block_num; \ + dim##rank += left_block_num; \ + continue; \ + } + FAST_COPY_IF_NEED(0); + for (int dim1 = param->begin_[1]; dim1 < param->end_[1]; ++dim1) { + int in_offset1 = dim1 * in_stride[1] + in_offset0; + FAST_COPY_IF_NEED(1); + for (int32_t dim2 = param->begin_[2]; dim2 < param->end_[2]; ++dim2) { + int in_offset2 = in_offset1 + dim2 * in_stride[2]; + FAST_COPY_IF_NEED(2); + for (int32_t dim3 = param->begin_[3]; dim3 < param->end_[3]; ++dim3) { + int in_offset3 = in_offset2 + dim3 * in_stride[3]; + FAST_COPY_IF_NEED(3); + for (int32_t dim4 = param->begin_[4]; dim4 < param->end_[4]; ++dim4) { + int in_offset4 = in_offset3 + dim4 * in_stride[4]; + FAST_COPY_IF_NEED(4); + for (int32_t dim5 = param->begin_[5]; dim5 < param->end_[5]; ++dim5) { + int in_offset5 = in_offset4 + dim5 * in_stride[5]; + FAST_COPY_IF_NEED(5); +#undef FAST_COPY_IF_NEED + for (int32_t dim6 = param->begin_[6]; dim6 < param->end_[6]; ++dim6) { + int in_offset6 = in_offset5 + dim6 * in_stride[6]; + memcpy(int8_out + out_offset * data_size, int8_in + in_offset6 * data_size, copy_size); + out_offset += param->size_[7]; + } + } + } + } + } + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/slice_base.h b/mindspore-lite/ops/kernel/cpu/nnacl/base/slice_base.h new file mode 100644 index 0000000000000000000000000000000000000000..f93d8665aa3be7ba20d5265210f86166813c6a5a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/slice_base.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_BASE_SLICE_BASE_H_ +#define NNACL_BASE_SLICE_BASE_H_ + +#include "nnacl/op_base.h" +#include "nnacl/errorcode.h" +#include "nnacl/slice_parameter.h" +#include "nnacl/kernel/slice.h" + +#ifdef __cplusplus +extern "C" { +#endif +void InitSliceStruct(SliceStruct *slice, TensorC *in_tensor, TensorC *begin_tensor, TensorC *size_tensor); +void PadSliceParameterTo8D(SliceStruct *param); + +void DoSlice(const void *input, void *output, const SliceStruct *param, int thread_id, int thread_num, int data_size); +void DoSliceNoParallel(const void *input, void *output, const SliceStruct *param, int data_size); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_BASE_SLICE_BASE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/space_to_depth_base.c b/mindspore-lite/ops/kernel/cpu/nnacl/base/space_to_depth_base.c new file mode 100644 index 0000000000000000000000000000000000000000..8317c9873a2031505b15022e62e016b6cd3b70a1 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/space_to_depth_base.c @@ -0,0 +1,54 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/base/space_to_depth_base.h" +#include "nnacl/common_func.h" +#include "nnacl/errorcode.h" + +int SpaceToDepthForNHWC(const void *input, void *output, const int *in_shape, const int *out_shape, int shape_size, + SpaceToDepthParameter *param, int task_id) { + if (param->op_parameter_.thread_num_ == 0) { + return NNACL_ERR; + } + int output_h = out_shape[kNHWC_H]; + int unit_per_thread = UP_DIV(output_h, param->op_parameter_.thread_num_); + int h_start = unit_per_thread * task_id; + int h_end = MSMIN(h_start + unit_per_thread, output_h); + + int block_size = param->block_size_; + int in_strides[C4NUM]; + int out_strides[C4NUM]; + ComputeStrides(in_shape, in_strides, shape_size); + ComputeStrides(out_shape, out_strides, shape_size); + for (int i = 0; i < out_shape[0]; ++i) { + int64_t in_offset_n = i * in_strides[0]; + int64_t out_offset_n = i * out_strides[0]; + for (int j = h_start; j < h_end; ++j) { + int64_t in_offset_h = in_offset_n + j * block_size * in_strides[1]; + int64_t out_offset_h = out_offset_n + j * out_strides[1]; + for (int k = 0; k < out_shape[2]; ++k) { + int64_t in_offset_w = in_offset_h + k * block_size * in_strides[2]; + int64_t out_offset_w = out_offset_h + k * out_strides[2]; + for (int l = 0; l < block_size; ++l) { + memcpy((int8_t *)output + (out_offset_w + l * block_size * in_strides[DIMENSION_2D]) * param->date_type_len, + (const int8_t *)input + (in_offset_w + l * in_strides[DIMENSION_1D]) * param->date_type_len, + block_size * in_strides[DIMENSION_2D] * param->date_type_len); + } + } + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/space_to_depth_base.h b/mindspore-lite/ops/kernel/cpu/nnacl/base/space_to_depth_base.h new file mode 100644 index 0000000000000000000000000000000000000000..4a1207c3f9f1023e0f2fd0a7efbf7b832ae1c291 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/space_to_depth_base.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_BASE_SPACE_TO_DEPTH_BASE_H_ +#define NNACL_BASE_SPACE_TO_DEPTH_BASE_H_ + +#include "nnacl/op_base.h" +#include "nnacl/space_to_depth_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int SpaceToDepthForNHWC(const void *input, void *output, const int *in_shape, const int *out_shape, int shape_size, + SpaceToDepthParameter *param, int task_id); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_BASE_SPACE_TO_DEPTH_BASE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/split_base.c b/mindspore-lite/ops/kernel/cpu/nnacl/base/split_base.c new file mode 100644 index 0000000000000000000000000000000000000000..16a5b384acf9f15e3ad493474bacbd1d19aab0de --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/split_base.c @@ -0,0 +1,57 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/base/split_base.h" +#include "nnacl/split_parameter.h" +#include +#include "nnacl/errorcode.h" + +int DoSplit(const void *in_data, void **out_data, const int *input_shape, int offset, int num_unit, + const SplitParameter *split_param, int data_size) { + const int8_t *int8_in = (int8_t *)in_data; + + const int num_split = split_param->num_split_; + const int *split_sizes = split_param->split_sizes_; + const int *strides = split_param->strides_; + const int split_dim = split_param->split_dim_; + int in_stride = strides[split_dim]; + + int in_stride_bytes = in_stride * data_size; + + int split_which; + int split_times; + int stride_per_split = in_stride * input_shape[split_dim]; + + split_which = offset % num_split; + split_times = offset / num_split; + const int8_t *src = int8_in + split_times * stride_per_split * data_size; + + for (int i = 0; i < split_which; i++) { + src += split_sizes[i] * in_stride * data_size; + } + + for (int i = offset; i < offset + num_unit; i++) { + split_which = i % num_split; + split_times = i / num_split; + int split_size = split_sizes[split_which]; + int8_t *int8_out = (int8_t *)out_data[split_which]; + int8_t *dst = int8_out + split_times * in_stride * split_size * data_size; + (void)memcpy(dst, src, split_size * in_stride_bytes); + src += split_size * in_stride * data_size; + } + + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/split_base.h b/mindspore-lite/ops/kernel/cpu/nnacl/base/split_base.h new file mode 100644 index 0000000000000000000000000000000000000000..bf87e9f67f15c8aea471fa1722733e7ff3990ee9 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/split_base.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_BASE_SPLIT_BASE_H_ +#define NNACL_BASE_SPLIT_BASE_H_ + +#include "nnacl/op_base.h" +#include "nnacl/split_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int DoSplit(const void *in_data, void **out_data, const int *input_shape, int offset, int num_unit, + const SplitParameter *split_param, int data_size); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_BASE_SPLIT_BASE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/split_with_over_lap_base.c b/mindspore-lite/ops/kernel/cpu/nnacl/base/split_with_over_lap_base.c new file mode 100644 index 0000000000000000000000000000000000000000..77d3844be9ca75b79499f0f285431f76c8cf3f56 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/split_with_over_lap_base.c @@ -0,0 +1,38 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/base/split_with_over_lap_base.h" +#include +#include "nnacl/errorcode.h" + +int DoSplitWithOverlapParallel(const char *in_data, char **out_data, int slice_idx, + const SplitWithOverlapParameter *param, const int *start_indices, + const int *end_indices) { + int start_index = start_indices[slice_idx]; + int end_index = end_indices[slice_idx]; + + int input_stride = param->split_dim_size_ * param->inner_stride_ * param->element_bytes_; + int out_stride = (end_index - start_index) * param->inner_stride_ * param->element_bytes_; + + const char *src_ptr = in_data + start_index * param->inner_stride_ * param->element_bytes_; + char *dst_ptr = out_data[slice_idx]; + + for (int i = 0; i < param->outer_total_dim_; i++) { + (void)memcpy(dst_ptr + i * out_stride, src_ptr, out_stride); + src_ptr += input_stride; + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/split_with_over_lap_base.h b/mindspore-lite/ops/kernel/cpu/nnacl/base/split_with_over_lap_base.h new file mode 100644 index 0000000000000000000000000000000000000000..e425dcf4f3accf3f68756db8a451cb2461a5ae8c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/split_with_over_lap_base.h @@ -0,0 +1,33 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_SPLIT_WITH_OVER_LAP_BASE_H_ +#define NNACL_SPLIT_WITH_OVER_LAP_BASE_H_ + +#include "nnacl/op_base.h" +#include "nnacl/split_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int DoSplitWithOverlapParallel(const char *in_data, char **out_data, int slice_idx, + const SplitWithOverlapParameter *param, const int *start_indices, + const int *end_indices); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_NNACL_SPLIT_WITH_OVER_LAP_BASE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/stack_base.c b/mindspore-lite/ops/kernel/cpu/nnacl/base/stack_base.c new file mode 100644 index 0000000000000000000000000000000000000000..69eb4a42e8c3e5378253bbda69c2aef596f54431 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/stack_base.c @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/base/stack_base.h" + +void Stack(void **inputs, void *output, size_t input_num, size_t copy_size, int outer_start, int outer_end) { + size_t out_offset = 0; + for (size_t i = outer_start; i < outer_end; ++i) { + for (size_t j = 0; j < input_num; ++j) { + memcpy((char *)output + out_offset, (char *)inputs[j] + i * copy_size, copy_size); + out_offset += copy_size; + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/stack_base.h b/mindspore-lite/ops/kernel/cpu/nnacl/base/stack_base.h new file mode 100644 index 0000000000000000000000000000000000000000..dbd9d22a2565a6384078abc01ed407c3509131a3 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/stack_base.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_BASE_STACK_BASE_H_ +#define NNACL_BASE_STACK_BASE_H_ + +#include +#include "nnacl/op_base.h" +#include "nnacl/stack_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void Stack(void **inputs, void *output, size_t input_num, size_t copy_size, int outer_start, int outer_end); +#ifdef __cplusplus +} +#endif +#endif // NNACL_BASE_STACK_BASE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/tile_base.c b/mindspore-lite/ops/kernel/cpu/nnacl/base/tile_base.c new file mode 100644 index 0000000000000000000000000000000000000000..f74896a9155ba5beda548854d05dc547788aae66 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/tile_base.c @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/base/tile_base.h" +#include +#include "nnacl/errorcode.h" + +void DoCopyData(const uint8_t *input_data, uint8_t *output_data, size_t size, size_t data_size, size_t multiple) { + uint8_t *out_data = output_data; + for (size_t i = 0; i < multiple; ++i) { + (void)memcpy(out_data, input_data, size * sizeof(uint8_t) * data_size); + out_data += size * data_size; + } +} + +int DoTileOneDimension(uint8_t *input_data, uint8_t *output_data, size_t dim, const TileStruct *tile) { + int src_dim_size = tile->in_shape_[dim]; + if (dim == tile->in_dim_ - 1) { + DoCopyData(input_data, output_data, src_dim_size, tile->data_size_, tile->multiples_[dim]); + return NNACL_OK; + } + for (int i = 0; i < src_dim_size; ++i) { + for (int j = 0; j < tile->multiples_[dim]; ++j) { + int in_pos = tile->in_strides_[dim] * i; + int out_pos = tile->out_strides_[dim] * (i + j * src_dim_size); + DoTileOneDimension(input_data + in_pos * tile->data_size_, output_data + out_pos * tile->data_size_, dim + 1, + tile); + } + } + return NNACL_OK; +} + +void Tile(void *input_data, void *output_data, const TileStruct *tile) { + DoTileOneDimension((uint8_t *)input_data, (uint8_t *)output_data, 0, tile); +} + +void TileSimple(void *input_data, void *output_data, size_t begin, size_t end, const TileStruct *tile) { + uint8_t *out_data = output_data; + uint8_t *in_data = input_data; + size_t dst_one_row_size = tile->fast_stride_ * tile->fast_multiple_ * tile->data_size_; + for (size_t i = begin; i < end; ++i) { + uint8_t *src = in_data + i * tile->fast_stride_ * tile->data_size_; + uint8_t *dst = out_data + i * tile->fast_stride_ * tile->fast_multiple_ * tile->data_size_; + size_t offset = tile->fast_stride_ * tile->data_size_; + (void)memcpy(dst, src, offset); + // copy size double each time + while (2 * offset <= dst_one_row_size) { + (void)memcpy(dst + offset, dst, offset); + offset *= 2; + } + if (2 * offset > dst_one_row_size) { + (void)memcpy(dst + offset, dst, dst_one_row_size - offset); + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/tile_base.h b/mindspore-lite/ops/kernel/cpu/nnacl/base/tile_base.h new file mode 100644 index 0000000000000000000000000000000000000000..db2bbe1238175396d017e70cc45d5340f3a0a274 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/tile_base.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_BASE_TILE_BASE_H_ +#define NNACL_BASE_TILE_BASE_H_ + +#include "nnacl/kernel/tile.h" +#include "nnacl/tile_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void Tile(void *input_data, void *output_data, const TileStruct *tile); +void TileSimple(void *input_data, void *output_data, size_t begin, size_t end, const TileStruct *tile); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_BASE_TILE_BASE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/transpose_base.c b/mindspore-lite/ops/kernel/cpu/nnacl/base/transpose_base.c new file mode 100644 index 0000000000000000000000000000000000000000..a0c63001bcc1410900101b53b31275cf75526605 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/transpose_base.c @@ -0,0 +1,274 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/base/transpose_base.h" +#include "nnacl/errorcode.h" + +#define TRANSPOSE_TWO_DIMS(TYPE, NAME) \ + void TransposeDim2##NAME(const TYPE *in_data, TYPE *out_data, const int *strides, const int *out_strides, \ + const int *perm, const int *output_shape) { \ + const int stride0 = strides[perm[0]]; \ + const int stride1 = strides[perm[1]]; \ + const int output0 = output_shape[0]; \ + const int output1 = output_shape[1]; \ + for (int i = 0; i < output0; ++i) { \ + int out_stride0_i = i * output1; \ + int stride0_i = i * 1 * stride0; \ + for (int j = 0; j < output1; ++j) { \ + out_data[out_stride0_i + j] = in_data[stride0_i + j * stride1]; \ + } \ + } \ + } + +#define TRANSPOSE_THREE_DIMS(TYPE, NAME) \ + void TransposeDim3##NAME(const TYPE *in_data, TYPE *out_data, const int *strides, const int *out_strides, \ + const int *perm, const int *output_shape) { \ + const int stride0 = strides[perm[0]]; \ + const int stride1 = strides[perm[1]]; \ + const int stride2 = strides[perm[2]]; \ + const int out_stride0 = out_strides[0]; \ + const int out_stride1 = out_strides[1]; \ + const int output0 = output_shape[0]; \ + const int output1 = output_shape[1]; \ + const int output2 = output_shape[2]; \ + for (int i = 0; i < output0; ++i) { \ + int out_stride0_i = i * out_stride0; \ + int stride0_i = i * stride0; \ + for (int j = 0; j < output1; ++j) { \ + int out_stride1_j = j * out_stride1; \ + int stride1_j = j * stride1; \ + for (int k = 0; k < output2; ++k) { \ + out_data[out_stride0_i + out_stride1_j + k] = in_data[stride0_i + stride1_j + k * stride2]; \ + } \ + } \ + } \ + } + +#define TRANSPOSE_FOUR_DIMS(TYPE, NAME) \ + void TransposeDim4##NAME(const TYPE *in_data, TYPE *out_data, const int *strides, const int *out_strides, \ + const int *perm, const int *output_shape) { \ + const int stride0 = strides[perm[0]]; \ + const int stride1 = strides[perm[1]]; \ + const int stride2 = strides[perm[2]]; \ + const int stride3 = strides[perm[3]]; \ + const int out_stride0 = out_strides[0]; \ + const int out_stride1 = out_strides[1]; \ + const int out_stride2 = out_strides[2]; \ + const int output0 = output_shape[0]; \ + const int output1 = output_shape[1]; \ + const int output2 = output_shape[2]; \ + const int output3 = output_shape[3]; \ + for (int i = 0; i < output0; ++i) { \ + int out_stride0_i = i * out_stride0; \ + int stride0_i = i * stride0; \ + for (int j = 0; j < output1; ++j) { \ + int out_stride1_j = j * out_stride1; \ + int stride1_j = j * stride1; \ + for (int k = 0; k < output2; ++k) { \ + int out_stride2_k = k * out_stride2; \ + int stride2_k = k * stride2; \ + for (int m = 0; m < output3; ++m) { \ + out_data[out_stride0_i + out_stride1_j + out_stride2_k + m] = \ + in_data[stride0_i + stride1_j + stride2_k + m * stride3]; \ + } \ + } \ + } \ + } \ + } + +#define TRANSPOSE_FIVE_DIMS(TYPE, NAME) \ + void TransposeDim5##NAME(const TYPE *in_data, TYPE *out_data, const int *strides, const int *out_strides, \ + const int *perm, const int *output_shape) { \ + const int stride0 = strides[perm[0]]; \ + const int stride1 = strides[perm[1]]; \ + const int stride2 = strides[perm[2]]; \ + const int stride3 = strides[perm[3]]; \ + const int stride4 = strides[perm[4]]; \ + const int out_stride0 = out_strides[0]; \ + const int out_stride1 = out_strides[1]; \ + const int out_stride2 = out_strides[2]; \ + const int out_stride3 = out_strides[3]; \ + const int output0 = output_shape[0]; \ + const int output1 = output_shape[1]; \ + const int output2 = output_shape[2]; \ + const int output3 = output_shape[3]; \ + const int output4 = output_shape[4]; \ + for (int i = 0; i < output0; ++i) { \ + int out_stride0_i = i * out_stride0; \ + int stride0_i = i * stride0; \ + for (int j = 0; j < output1; ++j) { \ + int out_stride1_j = j * out_stride1; \ + int stride1_j = j * stride1; \ + for (int k = 0; k < output2; ++k) { \ + int out_stride2_k = k * out_stride2; \ + int stride2_k = k * stride2; \ + for (int m = 0; m < output3; ++m) { \ + int out_stride3_m = m * out_stride3; \ + int stride3_m = m * stride3; \ + for (int n = 0; n < output4; ++n) { \ + out_data[out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + n] = \ + in_data[stride0_i + stride1_j + stride2_k + stride3_m + n * stride4]; \ + } \ + } \ + } \ + } \ + } \ + } + +#define TRANSPOSE_SIX_DIMS(TYPE, NAME) \ + void TransposeDim6##NAME(const TYPE *in_data, TYPE *out_data, const int *strides, const int *out_strides, \ + const int *perm, const int *output_shape) { \ + const int stride0 = strides[perm[0]]; \ + const int stride1 = strides[perm[1]]; \ + const int stride2 = strides[perm[2]]; \ + const int stride3 = strides[perm[3]]; \ + const int stride4 = strides[perm[4]]; \ + const int stride5 = strides[perm[5]]; \ + const int out_stride0 = out_strides[0]; \ + const int out_stride1 = out_strides[1]; \ + const int out_stride2 = out_strides[2]; \ + const int out_stride3 = out_strides[3]; \ + const int out_stride4 = out_strides[4]; \ + const int output0 = output_shape[0]; \ + const int output1 = output_shape[1]; \ + const int output2 = output_shape[2]; \ + const int output3 = output_shape[3]; \ + const int output4 = output_shape[4]; \ + const int output5 = output_shape[5]; \ + for (int i = 0; i < output0; ++i) { \ + int out_stride0_i = i * out_stride0; \ + int stride0_i = i * stride0; \ + for (int j = 0; j < output1; ++j) { \ + int out_stride1_j = j * out_stride1; \ + int stride1_j = j * stride1; \ + for (int k = 0; k < output2; ++k) { \ + int out_stride2_k = k * out_stride2; \ + int stride2_k = k * stride2; \ + for (int m = 0; m < output3; ++m) { \ + int out_stride3_m = m * out_stride3; \ + int stride3_m = m * stride3; \ + for (int n = 0; n < output4; ++n) { \ + int out_stride4_n = n * out_stride4; \ + int stride4_n = n * stride4; \ + for (int g = 0; g < output5; ++g) { \ + out_data[out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + out_stride4_n + g] = \ + in_data[stride0_i + stride1_j + stride2_k + stride3_m + stride4_n + g * stride5]; \ + } \ + } \ + } \ + } \ + } \ + } \ + } + +#define TRANSPOSE_DIMS(TYPE, NAME) \ + void TransposeDims##NAME(const TYPE *in_data, TYPE *out_data, const int *output_shape, \ + const TransposeParameter *transpose_param, int task_id, int thread_num) { \ + NNACL_CHECK_NULL_RETURN_VOID(in_data); \ + NNACL_CHECK_NULL_RETURN_VOID(out_data); \ + NNACL_CHECK_NULL_RETURN_VOID(output_shape); \ + NNACL_CHECK_NULL_RETURN_VOID(transpose_param); \ + NNACL_CHECK_ZERO_RETURN(thread_num); \ + const int *perm = transpose_param->perm_; \ + const int *strides = transpose_param->strides_; \ + const int *out_strides = transpose_param->out_strides_; \ + int num_axes = transpose_param->num_axes_; \ + size_t data_size = (*out_strides) * output_shape[0]; \ + size_t offset_size = UP_DIV(data_size, thread_num); \ + size_t task_offset = offset_size * task_id; \ + int count = data_size - task_offset; \ + if (count <= 0) { \ + return; \ + } \ + count = MSMIN(offset_size, count); \ + for (int idx = task_offset; idx < task_offset + count; ++idx) { \ + int pos = idx; \ + int output_idx = 0; \ + int input_idx = 0; \ + for (int i = 0; i < num_axes; ++i) { \ + NNACL_CHECK_ZERO_RETURN(*(out_strides + i)); \ + int position = pos / *(out_strides + i); \ + int out_stride = i < num_axes - 1 ? out_strides[i] : 1; \ + output_idx += (position * out_stride); \ + input_idx += (position * strides[perm[i]]); \ + pos -= position * (*(out_strides + i)); \ + } \ + out_data[output_idx] = in_data[input_idx]; \ + } \ + } + +#define DOTRANSPOSE(TYPE, NAME) \ + int DoTranspose##NAME(const TYPE *in_data, TYPE *out_data, const int *output_shape, \ + const TransposeParameter *transpose_param) { \ + NNACL_CHECK_NULL_RETURN_ERR(in_data); \ + NNACL_CHECK_NULL_RETURN_ERR(out_data); \ + NNACL_CHECK_NULL_RETURN_ERR(output_shape); \ + NNACL_CHECK_NULL_RETURN_ERR(transpose_param); \ + const int *perm = transpose_param->perm_; \ + const int *strides = transpose_param->strides_; \ + const int *out_strides = transpose_param->out_strides_; \ + int data_size = transpose_param->data_num_ * sizeof(TYPE); \ + int num_axes = transpose_param->num_axes_; \ + bool needTranspose = false; \ + for (int i = 1; i < num_axes; ++i) { \ + if (perm[i] - perm[i - 1] != 1) { \ + needTranspose = true; \ + break; \ + } \ + } \ + if (!needTranspose) { \ + (void)memcpy(out_data, in_data, data_size); \ + return NNACL_OK; \ + } \ + for (int i = 0; i < num_axes; ++i) { \ + if (perm[i] < 0) { \ + return NNACL_PARAM_INVALID; \ + } \ + } \ + if (num_axes == 2) { \ + TransposeDim2##NAME(in_data, out_data, strides, out_strides, perm, output_shape); \ + } else if (num_axes == 3) { \ + TransposeDim3##NAME(in_data, out_data, strides, out_strides, perm, output_shape); \ + } else if (num_axes == 4) { \ + TransposeDim4##NAME(in_data, out_data, strides, out_strides, perm, output_shape); \ + } else if (num_axes == 5) { \ + TransposeDim5##NAME(in_data, out_data, strides, out_strides, perm, output_shape); \ + } else if (num_axes == 6) { \ + TransposeDim6##NAME(in_data, out_data, strides, out_strides, perm, output_shape); \ + } else { \ + return NNACL_ERR; \ + } \ + return NNACL_OK; \ + } + +#define TRANSPOSE_TEMPLATE(TYPE, NAME) \ + TRANSPOSE_TWO_DIMS(TYPE, NAME) \ + TRANSPOSE_THREE_DIMS(TYPE, NAME) \ + TRANSPOSE_FOUR_DIMS(TYPE, NAME) \ + TRANSPOSE_FIVE_DIMS(TYPE, NAME) \ + TRANSPOSE_SIX_DIMS(TYPE, NAME) \ + TRANSPOSE_DIMS(TYPE, NAME) \ + DOTRANSPOSE(TYPE, NAME) + +TRANSPOSE_TEMPLATE(uint8_t, UInt8) +TRANSPOSE_TEMPLATE(uint16_t, UInt16) +TRANSPOSE_TEMPLATE(uint32_t, UInt32) +TRANSPOSE_TEMPLATE(uint64_t, UInt64) +TRANSPOSE_TEMPLATE(int16_t, Int16) +TRANSPOSE_TEMPLATE(int32_t, Int32) +TRANSPOSE_TEMPLATE(int64_t, Int64) +TRANSPOSE_TEMPLATE(double, Float64) +TRANSPOSE_TEMPLATE(bool, Bool) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/transpose_base.h b/mindspore-lite/ops/kernel/cpu/nnacl/base/transpose_base.h new file mode 100644 index 0000000000000000000000000000000000000000..35ac0813798f416b63e3a64e19343b4cdb5ea127 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/transpose_base.h @@ -0,0 +1,69 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_BASE_TRANSPOSE_BASE_H_ +#define NNACL_BASE_TRANSPOSE_BASE_H_ + +#include "nnacl/transpose_parameter.h" +#include + +#ifdef __cplusplus +extern "C" { +#endif + +int DoTransposeUInt8(const uint8_t *in_data, uint8_t *out_data, const int *output_shape, + const TransposeParameter *transpose_param); +int DoTransposeUInt16(const uint16_t *in_data, uint16_t *out_data, const int *output_shape, + const TransposeParameter *transpose_param); +int DoTransposeUInt32(const uint32_t *in_data, uint32_t *out_data, const int *output_shape, + const TransposeParameter *transpose_param); +int DoTransposeUInt64(const uint64_t *in_data, uint64_t *out_data, const int *output_shape, + const TransposeParameter *transpose_param); +int DoTransposeInt16(const int16_t *in_data, int16_t *out_data, const int *output_shape, + const TransposeParameter *transpose_param); +int DoTransposeInt32(const int32_t *in_data, int32_t *out_data, const int *output_shape, + const TransposeParameter *transpose_param); +int DoTransposeInt64(const int64_t *in_data, int64_t *out_data, const int *output_shape, + const TransposeParameter *transpose_param); +int DoTransposeFloat64(const double *in_data, double *out_data, const int *output_shape, + const TransposeParameter *transpose_param); +int DoTransposeBool(const bool *in_data, bool *out_data, const int *output_shape, + const TransposeParameter *transpose_param); + +void TransposeDimsUInt8(const uint8_t *in_data, uint8_t *out_data, const int *output_shape, + const TransposeParameter *transpose_param, int task_id, int thread_num); +void TransposeDimsUInt16(const uint16_t *in_data, uint16_t *out_data, const int *output_shape, + const TransposeParameter *transpose_param, int task_id, int thread_num); +void TransposeDimsUInt32(const uint32_t *in_data, uint32_t *out_data, const int *output_shape, + const TransposeParameter *transpose_param, int task_id, int thread_num); +void TransposeDimsUInt64(const uint64_t *in_data, uint64_t *out_data, const int *output_shape, + const TransposeParameter *transpose_param, int task_id, int thread_num); +void TransposeDimsInt16(const int16_t *in_data, int16_t *out_data, const int *output_shape, + const TransposeParameter *transpose_param, int task_id, int thread_num); +void TransposeDimsInt32(const int32_t *in_data, int32_t *out_data, const int *output_shape, + const TransposeParameter *transpose_param, int task_id, int thread_num); +void TransposeDimsInt64(const int64_t *in_data, int64_t *out_data, const int *output_shape, + const TransposeParameter *transpose_param, int task_id, int thread_num); +void TransposeDimsFloat64(const double *in_data, double *out_data, const int *output_shape, + const TransposeParameter *transpose_param, int task_id, int thread_num); +void TransposeDimsBool(const bool *in_data, bool *out_data, const int *output_shape, + const TransposeParameter *transpose_param, int task_id, int thread_num); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_BASE_TRANSPOSE_BASE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/unsorted_segment_sum_base.c b/mindspore-lite/ops/kernel/cpu/nnacl/base/unsorted_segment_sum_base.c new file mode 100644 index 0000000000000000000000000000000000000000..6d013fd28d2214e73bbd85a67d2d68622a79890c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/unsorted_segment_sum_base.c @@ -0,0 +1,45 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/base/unsorted_segment_sum_base.h" +#include "nnacl/errorcode.h" + +#define UNSORTEDSEGMENTSUM(type, type1) \ + int UnsortedSegmentSum_##type##_##type1(const type *input, int unit_num, int input_dim1, const type1 *indices, \ + type *output, int output_dim0, int output_dim1) { \ + NNACL_CHECK_NULL_RETURN_ERR(input); \ + NNACL_CHECK_NULL_RETURN_ERR(indices); \ + NNACL_CHECK_NULL_RETURN_ERR(output); \ + if (input_dim1 == 0) { \ + return NNACL_ERR; \ + } \ + for (int i = 0; i < unit_num; ++i) { \ + int j = i / input_dim1; \ + int k = i % input_dim1; \ + \ + type1 index = indices[j]; \ + if (index < 0 || index >= output_dim0) { \ + continue; \ + } \ + type1 output_index = index * output_dim1 + k; \ + output[output_index] += input[i]; \ + } \ + return NNACL_OK; \ + } + +UNSORTEDSEGMENTSUM(int, int) +UNSORTEDSEGMENTSUM(float, int) +UNSORTEDSEGMENTSUM(int, int64_t) +UNSORTEDSEGMENTSUM(float, int64_t) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/unsorted_segment_sum_base.h b/mindspore-lite/ops/kernel/cpu/nnacl/base/unsorted_segment_sum_base.h new file mode 100644 index 0000000000000000000000000000000000000000..a1272a8e5f72e275d157e13f78490649c5780122 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/unsorted_segment_sum_base.h @@ -0,0 +1,38 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_BASE_UNSORTED_SEGMENT_SUM_BASE_H_ +#define NNACL_BASE_UNSORTED_SEGMENT_SUM_BASE_H_ + +#include "nnacl/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +#define UnsortedSegmentSum(type, type1, input, unit_num, input_dim1, indices, output, output_dim0, output_dim1) \ + UnsortedSegmentSum_##type##_##type1(input, unit_num, input_dim1, indices, output, output_dim0, output_dim1) +int UnsortedSegmentSum_int_int(const int *input, int unit_num, int input_dim1, const int *indices, int *output, + int output_dim0, int output_dim1); +int UnsortedSegmentSum_float_int(const float *input, int unit_num, int input_dim1, const int *indices, float *output, + int output_dim0, int output_dim1); +int UnsortedSegmentSum_int_int64_t(const int *input, int unit_num, int input_dim1, const int64_t *indices, int *output, + int output_dim0, int output_dim1); +int UnsortedSegmentSum_float_int64_t(const float *input, int unit_num, int input_dim1, const int64_t *indices, + float *output, int output_dim0, int output_dim1); +#ifdef __cplusplus +} +#endif +#endif // NNACL_BASE_UNSORTED_SEGMENT_SUM_BASE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/unstack_base.c b/mindspore-lite/ops/kernel/cpu/nnacl/base/unstack_base.c new file mode 100644 index 0000000000000000000000000000000000000000..bd2c2c3cf04d21ee562ecaee7839d649ed6f14f5 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/unstack_base.c @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/base/unstack_base.h" + +void Unstack(const void *input, void **output, const UnstackParameter *para, int data_size) { + NNACL_CHECK_NULL_RETURN_VOID(input); + NNACL_CHECK_NULL_RETURN_VOID(output); + NNACL_CHECK_NULL_RETURN_VOID(para); + const int8_t *in_addr = (int8_t *)input; + for (int j = 0; j < para->num_; j++) { + int8_t *out_addr = (int8_t *)output[j]; + int out_offset = 0; + for (int i = 0; i < para->pre_dims_; i++) { + int in_offset = i * para->axis_dim_ * para->after_dims_ + j * para->after_dims_; + (void)memcpy(out_addr + out_offset * data_size, in_addr + in_offset * data_size, para->after_dims_ * data_size); + out_offset += para->after_dims_; + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/base/unstack_base.h b/mindspore-lite/ops/kernel/cpu/nnacl/base/unstack_base.h new file mode 100644 index 0000000000000000000000000000000000000000..919fb6b2355b33a440d876a89d5f06779c45eea4 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/base/unstack_base.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_BASE_UNSTACK_BASE_H_ +#define NNACL_BASE_UNSTACK_BASE_H_ + +#include +#include "nnacl/op_base.h" +#include "nnacl/unstack_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void Unstack(const void *input, void **output, const UnstackParameter *para, int data_size); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_BASE_UNSTACK_BASE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/batch_to_space_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/batch_to_space_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..d347bf6437d461432f0db700c592a5f78a951117 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/batch_to_space_parameter.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_BATCH_TO_SPACE_PARAMETER_H_ +#define NNACL_BATCH_TO_SPACE_PARAMETER_H_ + +#include +#include "nnacl/op_base.h" + +#define BATCH_TO_SPACE_BLOCK_SHAPE_SIZE 2 + +typedef struct BatchToSpaceParameter { + OpParameter op_parameter_; + int32_t block_shape_[BATCH_TO_SPACE_BLOCK_SHAPE_SIZE]; + int32_t crops_[COMM_SHAPE_SIZE]; +} BatchToSpaceParameter; + +#endif // NNACL_BATCH_TO_SPACE_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/batchnorm_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/batchnorm_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..14422d89d967704f7f8d6e3e11c366536a0359e6 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/batchnorm_parameter.h @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_BATCHNORM_PARAMETER_H_ +#define NNACL_BATCHNORM_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct BatchNormParameter { + OpParameter op_parameter_; + float epsilon_; + bool is_training_; + float momentum_; +} BatchNormParameter; + +#endif // NNACL_BATCHNORM_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/broadcast_to_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/broadcast_to_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..d8892e665717d0147a9d94a013841b1e682a7109 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/broadcast_to_parameter.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_BROADCAST_TO_PARAMETER_H_ +#define NNACL_BROADCAST_TO_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct BroadcastToParameter { + OpParameter op_parameter_; + int shape_[MAX_SHAPE_SIZE]; + size_t shape_size_; +} BroadcastToParameter; + +typedef struct BroadcastShapeInfo { + int input_shape_[MAX_SHAPE_SIZE]; + int input_shape_size_; + int output_shape_[MAX_SHAPE_SIZE]; + int output_shape_size_; +} BroadcastShapeInfo; + +#endif // NNACL_BROADCAST_TO_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/call_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/call_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..340c957b96dae105573902fa8029b0a46f416985 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/call_parameter.h @@ -0,0 +1,28 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_CALL_PARAMETER_H_ +#define NNACL_CALL_PARAMETER_H_ + +#include "nnacl/op_base.h" +#include "nnacl/int8/quantize.h" + +typedef struct CallParameter { + OpParameter op_parameter_; + bool is_tail_call; +} CallParameter; + +#endif // NNACL_CALL_PARAMETER_H_ diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/utils.h b/mindspore-lite/ops/kernel/cpu/nnacl/clip_parameter.h similarity index 66% rename from mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/utils.h rename to mindspore-lite/ops/kernel/cpu/nnacl/clip_parameter.h index 76ea7b2055886803ab6e0c98eec78dec10f51dad..c94a494982d1bb9c974cbc3139a8d37434656607 100644 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/utils.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/clip_parameter.h @@ -13,13 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_CXX_API_UTILS_H_ -#define MINDSPORE_CCSRC_CXX_API_UTILS_H_ -#include -#include -#include "include/api/visible.h" -namespace mindspore { -MS_API bool CreateGroupsByCkptFile(const std::string &file); -} // namespace mindspore -#endif // MINDSPORE_CCSRC_CXX_API_UTILS_H_ +#ifndef NNACL_CLIP_PARAMETER_H_ +#define NNACL_CLIP_PARAMETER_H_ + +#include "nnacl/op_base.h" +#include "nnacl/int8/quantize.h" + +typedef struct ClipParameter { + OpParameter op_parameter_; + float min_val_; + float max_val_; +} ClipParameter; + +#endif // NNACL_CLIP_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/common_func.c b/mindspore-lite/ops/kernel/cpu/nnacl/common_func.c new file mode 100644 index 0000000000000000000000000000000000000000..c29d9a6380878a31454a6b20705cde9627771d49 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/common_func.c @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/common_func.h" + +int Offset(const int *shape, const int dim0, const int dim1, const int dim2, const int dim3) { + return ((dim0 * shape[1] + dim1) * shape[2] + dim2) * shape[3] + dim3; +} + +int64_t OffsetComm(const int *shape, const int dim0, const int dim1, const int dim2) { + return ((dim0 * shape[1] + dim1) * shape[2] + dim2) * shape[3]; +} + +int Offset4d(const int *shape, const int *dims) { return Offset(shape, dims[0], dims[1], dims[2], dims[3]); } + +int64_t Offset6d(const int *shape, const int *dims) { + return ((OffsetComm(shape, dims[0], dims[1], dims[2]) + dims[3]) * shape[4] + dims[4]) * shape[5]; +} + +int8_t MinInt8(int8_t a, int8_t b) { return b ^ ((a ^ b) & -(a < b)); } + +int8_t MaxInt8(int8_t a, int8_t b) { return a ^ ((a ^ b) & -(a < b)); } diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/common_func.h b/mindspore-lite/ops/kernel/cpu/nnacl/common_func.h new file mode 100644 index 0000000000000000000000000000000000000000..c80d151fa33d1a03a172db53e2dafc0a1853baa6 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/common_func.h @@ -0,0 +1,61 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_COMMON_FUNC_H_ +#define MINDSPORE_NNACL_COMMON_FUNC_H_ + +#include +#include "nnacl/op_base.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/nnacl_common.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int8_t MinInt8(int8_t a, int8_t b); +int8_t MaxInt8(int8_t a, int8_t b); +int Offset(const int *shape, const int dim0, const int dim1, const int dim2, const int dim3); +int64_t OffsetComm(const int *shape, const int dim0, const int dim1, const int dim2); +int Offset4d(const int *shape, const int *dims); +int64_t Offset6d(const int *shape, const int *dims); + +static inline bool isAddOverflow(int32_t x, int32_t y) { + int32_t sum = x + y; + return (x > 0 && y > 0 && sum < 0) || (x < 0 && y < 0 && sum > 0); +} + +static inline bool isMulOverflow(int32_t x, int32_t y) { + int32_t p = x * y; + return (x != 0) && (p / x != y); +} + +static inline int GetStride(int *strides, const int *shape, int length) { + if (length <= 0) { + return 1; + } + int stride = 1; + for (int i = length - 1; i >= 0; --i) { + strides[i] = stride; + stride *= shape[i]; + } + return stride; +} +#ifdef __cplusplus +} +#endif + +#endif /* MINDSPORE_NNACL_COMMON_FUNC_H_ */ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/concat_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/concat_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..cb61518758c922c706df647c2002e742ee15d089 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/concat_parameter.h @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_CONCAT_PARAMETER_H_ +#define MINDSPORE_NNACL_CONCAT_PARAMETER_H_ + +#include "nnacl/op_base.h" +#include "nnacl/int8/quantize.h" + +typedef struct ConcatParameter { + OpParameter op_parameter_; + ConcatQuantArg quant_arg_; + int axis_; +} ConcatParameter; + +#endif // MINDSPORE_NNACL_CONCAT_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/constant_of_shape_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/constant_of_shape_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..d75edb6f6839ed43ef566efbb70f30f5364d0611 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/constant_of_shape_parameter.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_CONSTANT_OF_SHAPE_PARAMETER_H_ +#define NNACL_CONSTANT_OF_SHAPE_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct ConstantOfShapeParameter { + OpParameter op_parameter_; + union value_ { + float f32_value_; + int32_t int32_value_; + bool bool_value_; + } value_; + int data_type_; + int element_size_; +} ConstantOfShapeParameter; + +#endif // NNACL_CONSTANT_OF_SHAPE_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/conv3d_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/conv3d_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..2f438dc5f39e37a32e65a8924380374915fdd02e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/conv3d_parameter.h @@ -0,0 +1,26 @@ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_CONV3D_PARAMETER_H_ +#define NNACL_CONV3D_PARAMETER_H_ + +#include +#include "nnacl/op_base.h" + +typedef struct Conv3DParameter { + OpParameter op_parameter_; +} Conv3DParameter; + +#endif // NNACL_CONV3D_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/conv_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/conv_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..5bc638d5b888b0286d10b4d55d17653d9500cd78 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/conv_parameter.h @@ -0,0 +1,169 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_CONV_PARAMETER_H_ +#define NNACL_CONV_PARAMETER_H_ + +#include "nnacl/op_base.h" +#include "nnacl/int8/quantize.h" + +typedef struct ConvParameter { + OpParameter op_parameter_; + ConvQuantArg conv_quant_arg_; + + int kernel_h_; + int kernel_w_; + int stride_h_; + int stride_w_; + int dilation_h_; + int dilation_w_; + int pad_u_; + int pad_d_; + int pad_l_; + int pad_r_; + int group_; + int tile_num_; /* # */ + int input_batch_; /* # */ + int input_h_; /* # */ + int input_w_; /* # */ + int input_channel_; + int output_batch_; /* # */ + int output_h_; /* # */ + int output_w_; /* # */ + int output_channel_; + int thread_num_; /* # */ + int input_unit_; /* # */ + int output_unit_; /* # */ + PadType pad_mode_; + ActType act_type_; + int channel_multiplie_; /* # */ + int output_padding_w_; /* # */ + int output_padding_h_; /* # */ + int out_format_; + + bool dynamic_shape_; +} ConvParameter; + +typedef struct ConvComputeParam { + int kernel_h_; + int kernel_w_; + int stride_h_; + int stride_w_; + int dilation_h_; + int dilation_w_; + int pad_u_; + int pad_d_; + int pad_l_; + int pad_r_; + + int in_n_; + int in_h_; + int in_w_; + int in_c_; + int out_n_; + int out_h_; + int out_w_; + int out_c_; + + int in_hw_; + int out_hw_; + int kernel_hw_; + int tile_num_; +} ConvComputeParam; + +typedef struct SlidingWindowParam { + int left_; + int right_; + int top_; + int bottom_; + int c_block_; + int block_channel_; + int ic_align_; + int out_step_; + int out_h_step_; + int out_c_step_; + int out_w_step_; + int out_block_step_; + int in_step_; + int in_h_step_; + int in_sh_step_; // stride H + int in_sw_step_; // stride W + int in_kh_step_; // kernel H + int in_kw_step_; // kernel W + int kernel_step_; +} SlidingWindowParam; + +typedef struct ConvDwCalcParam { + void *num_pixels_; + void *out_w_start_; + void *out_w_end_; + int first_calc_kw_; +} ConvDwCalcParam; + +#define OUPUT_UNIT 2 +#define DECONV_WINOGRAD_DEFAULT_UNIT 3 /* # */ +#define DECONV_WINOGRAD_DEFAULT_TILE 8 /* # */ +#define DECONV_WINOGRAD_BUFFER_COUNT 8 /* # */ +typedef struct DeConvWg { /* # */ + void *b_buffer_; + void *AT_; + void *BT_; + + int kh_; + int kw_; + + int k_; + int i_; + int o_; +} DeConvWg; + +typedef struct DeConvWgABuffer { /* # */ + bool buf_init_; + void *middle_buffer_; + void *dest_buffer_; +} DeConvWgABuffer; + +typedef struct DeConvComputeUnit { /* # */ + void *weight_; + void *tmp_buffer_; + int w_start_; + int h_start_; + int w_size_; + int h_size_; + bool use_winograd_; + DeConvWg winograd_; +} DeConvComputeUnit; + +typedef struct DeConvParam { /* # */ + DeConvComputeUnit *compute_units_; + int compute_size_; + DeConvWgABuffer a_buffer_[DECONV_WINOGRAD_BUFFER_COUNT]; + int input_plane_; + int output_plane_; + int kernel_plane_; + int ic_div_; + int oc_div_; + int ic_up_; + int oc_up_; + int thread_num_; + int in_tile_count_; + int in_tile_h_count_; + int in_tile_w_count_; + int out_tile_h_; + int out_tile_w_; +} DeConvParam; + +#endif // NNACL_CONV_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/crop_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/crop_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..732af21ac0eef807a9bef0700b2c3d05c167284f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/crop_parameter.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_CROP_PARAMETER_H_ +#define NNACL_CROP_PARAMETER_H_ + +#include "nnacl/op_base.h" +#include "nnacl/int8/quantize.h" + +typedef struct CropParameter { + OpParameter op_parameter_; + int64_t axis_; + int offset_size_; + int64_t offset_[COMM_SHAPE_SIZE]; +} CropParameter; + +#endif // NNACL_CROP_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/cumsum_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/cumsum_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..838906d1895dff29e970f0fad6fae5717e24b3fb --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/cumsum_parameter.h @@ -0,0 +1,29 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_CUMSUM_PARAMETER_H_ +#define NNACL_CUMSUM_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct CumSumParameter { + OpParameter op_parameter_; + bool reverse_; + bool exclusive_; + int axis_; +} CumsumParameter; + +#endif // NNACL_CUMSUM_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/custom_gru_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/custom_gru_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..4ccacb5bf65aeaf0c94f57e917b25045bc054757 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/custom_gru_parameter.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_CUSTOM_GRU_PARAMETER_H_ +#define NNACL_CUSTOM_GRU_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct CustomGruParameter { + // Primitive parameter + OpParameter op_parameter_; + // shape correlative + int num_step; + int batch_size; + int input_size; + int hidden_size; +} CustomGruParameter; + +#endif // NNACL_CUSTOM_GRU_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/custom_is_inf_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/custom_is_inf_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..e1eae394a397d68e039c223a8f7546245abb7246 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/custom_is_inf_parameter.h @@ -0,0 +1,26 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_CUSTOM_IS_INF_PARAMETER_H_ +#define MINDSPORE_NNACL_CUSTOM_IS_INF_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct CustomIsInfParameter { + // Primitive parameter + OpParameter op_parameter_; +} CustomIsInfParameter; + +#endif // MINDSPORE_NNACL_CUSTOM_IS_INF_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/custom_masked_fill_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/custom_masked_fill_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..047d3d3f87c8e8073f7a51046d588784901a0da5 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/custom_masked_fill_parameter.h @@ -0,0 +1,26 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_CUSTOM_MASKED_FILL_PARAMETER_H_ +#define MINDSPORE_NNACL_CUSTOM_MASKED_FILL_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct CustomMaskedFillParameter { + // Primitive parameter + OpParameter op_parameter_; +} CustomMaskedFillParameter; + +#endif // MINDSPORE_NNACL_CUSTOM_MASKED_FILL_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/custom_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/custom_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..935f541fbfd8e676f49cf933da5073f8d5b9d8d5 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/custom_parameter.h @@ -0,0 +1,30 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_CUSTOM_PARAMETER_H_ +#define NNACL_CUSTOM_PARAMETER_H_ +#include "nnacl/op_base.h" + +#define MAX_STR_LEN 64 +#define MAX_ATTR_NUM 8 + +typedef struct CustomParameter { + OpParameter op_parameter_; + char type[MAX_STR_LEN]; + char attr_name[MAX_ATTR_NUM][MAX_STR_LEN]; + char *attr_data[MAX_ATTR_NUM]; + int attr_num; +} CustomParameter; +#endif // NNACL_CUSTOM_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/depth_to_space_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/depth_to_space_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..832fe85924ff2d8f7888c0e20069d66a3e416d42 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/depth_to_space_parameter.h @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_DEPTH_TO_SPACE_PARAMETER_H_ +#define NNACL_DEPTH_TO_SPACE_PARAMETER_H_ +#include "nnacl/op_base.h" + +typedef struct DepthToSpaceParameter { + OpParameter op_parameter_; + int32_t block_size_; + int32_t mode_; +} DepthToSpaceParameter; + +#endif // NNACL_DEPTH_TO_SPACE_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/detection_post_process_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/detection_post_process_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..9bc5428a4d2a5b537fb24114d0ee6bbf0c15793f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/detection_post_process_parameter.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_DETECTION_POST_PROCESS_PARAMETER_H_ +#define NNACL_DETECTION_POST_PROCESS_PARAMETER_H_ +#include "nnacl/op_base.h" + +typedef struct DetectionPostProcessParameter { + OpParameter op_parameter_; + float h_scale_; + float w_scale_; + float x_scale_; + float y_scale_; + float nms_iou_threshold_; + float nms_score_threshold_; + int64_t max_detections_; + int64_t detections_per_class_; + int64_t max_classes_per_detection_; + int64_t num_classes_; + bool use_regular_nms_; + bool out_quantized_; + + float *anchors_; + + void *decoded_boxes_; + void *nms_candidate_; + void *indexes_; + void *scores_; + void *all_class_indexes_; + void *all_class_scores_; + void *single_class_indexes_; + void *selected_; +} DetectionPostProcessParameter; + +#endif // NNACL_DETECTION_POST_PROCESS_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/dynamic_quant_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/dynamic_quant_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..1fc166cb81cf340d79c5f1324a1441fbfaa66807 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/dynamic_quant_parameter.h @@ -0,0 +1,29 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_DYNAMIC_QUANT_PARAMETER_H_ +#define NNACL_DYNAMIC_QUANT_PARAMETER_H_ +#include "nnacl/op_base.h" + +typedef struct DynamicQuantParameter { + OpParameter op_parameter_; + bool symmetric_; + int dst_type_; + int axis_num_; + int prefer_axes_[MAX_SHAPE_SIZE]; +} DynamicQuantParameter; + +#endif // NNACL_DYNAMIC_QUANT_PARAMETER_H_ diff --git a/mindspore-lite/tools/graph_kernel/converter/update_kernel_info.cc b/mindspore-lite/ops/kernel/cpu/nnacl/errorcode.c similarity index 32% rename from mindspore-lite/tools/graph_kernel/converter/update_kernel_info.cc rename to mindspore-lite/ops/kernel/cpu/nnacl/errorcode.c index 2967d8e46e5b033c0654deab643ebb21316820e8..ecc114613643cddb5b887814a79bc394866d6d71 100644 --- a/mindspore-lite/tools/graph_kernel/converter/update_kernel_info.cc +++ b/mindspore-lite/ops/kernel/cpu/nnacl/errorcode.c @@ -13,30 +13,34 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "tools/graph_kernel/converter/update_kernel_info.h" -#include "tools/graph_kernel/common/utils.h" -namespace mindspore::graphkernel { -bool UpdateKernelInfo::Run(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - bool changed = false; - const auto ¶ms = func_graph->parameters(); - for (const auto ¶m : params) { - if (param == nullptr) { - continue; - } - auto build_info = GetKernelInfo(param); - if (build_info == nullptr) { - continue; - } - auto out_types = build_info->GetAllOutputDeviceTypes(); - auto out_formats = build_info->GetAllOutputFormats(); - if (out_types.size() != out_formats.size()) { - MS_LOG(INFO) << "Clear kernel info for node: " << param->fullname_with_scope(); - param->set_kernel_info(nullptr); - changed = true; - } +#include "nnacl/errorcode.h" +#include + +void InitNNACLKernelErrorCode(char **nnacl_kernel_error_msg) { + nnacl_kernel_error_msg[NNACL_CROP_AND_RESIZE_BOX_IDX_INVALID] = + "In CropAndResize, the value of box idx should match: [0, batch)."; + nnacl_kernel_error_msg[NNACL_WHERE_INPUT_NUM_INVALID] = "Invalid input number. Where op input number support 1 or 3."; + nnacl_kernel_error_msg[NNACL_WHERE_CONDITION_DATA_TYPE_ERROR] = + "Invalid input data type. Where op input data type support int32 fp32 and bool."; + nnacl_kernel_error_msg[NNACL_WHERE_CONDITION_NUM_INVALID] = + "The length of three inputs are not equal to 1 or length of output, which is unacceptable."; + nnacl_kernel_error_msg[NNACL_WHERE_INVALID_OUT_NUM] = "The element number invalid."; + nnacl_kernel_error_msg[NNACL_WHERE_NUM_MAX_INVALID] = "Inputs' length are zero"; + nnacl_kernel_error_msg[NNACL_ERR] = "NNACL common error."; +} + +char *NNACLErrorMsg(int error_code) { + static char nnacl_kernel_error_msg[NNACL_COMMON_END][MAX_MSG_LEN]; + static bool inited = false; + if (!inited) { + inited = true; + InitNNACLKernelErrorCode((char **)nnacl_kernel_error_msg); } - return changed; + + if (error_code > NNACL_OK && error_code < NNACL_COMMON_END) { + return nnacl_kernel_error_msg[error_code]; + } + + return "NNACL execute error!"; } -} // namespace mindspore::graphkernel diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/errorcode.h b/mindspore-lite/ops/kernel/cpu/nnacl/errorcode.h new file mode 100644 index 0000000000000000000000000000000000000000..a7c6190b2534bfb212c4cec997b009b8eb5c7039 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/errorcode.h @@ -0,0 +1,208 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_ERRORCODE_H_ +#define NNACL_ERRORCODE_H_ + +#include + +#define MAX_MSG_LEN 256 + +typedef enum ErrorCodeCommonEnum { + NNACL_OK = 0, + NNACL_ERR = 1, + NNACL_NULL_PTR, + NNACL_PARAM_INVALID, + NNACL_INFER_INVALID, + NNACL_INPUT_TENSOR_ERROR, + NNACL_OUTPUT_TENSOR_ERROR, + NNACL_INPUT_OUTPUT_DATA_TYPE_UNMATCH, + NNACL_FORMAT_ERROR, + NNACL_BUFFER_OVERFLOW, + NNACL_TENSOR_SIZE_INVALID, + NNACL_UNSUPPORTED_DATA_TYPE, + NNACL_UNSUPPORTED_FORMAT, + NNACL_MALLOC_BUFFER_FAILED, + NNACL_MALLOC_SIZE_INVALID, + NNACL_DISABLE_FP16, + NNACL_ADDN_SHAPE_UNMATCH, + NNACL_ACTIVATION_TYPE_INVALID, + NNACL_ARITHMETIC_DATA_TYPE_UNMATCH, + NNACL_ARITHMETIC_SHAPE_INVALID, + NNACL_ARITHMETIC_SELF_DATA_TYPE_UNSUPPORT, + NNACL_ARG_MIN_MAX_AXIS_INVALID, + NNACL_BIAS_ADD_SHAPE_NOT_MATCH, + NNACL_BIAS_ADD_SHAPE_OVERFLOW, + NNACL_BATCH_TO_SPACE_BLOCK_SHAPE_INVALID, + NNACL_BATCH_TO_SPACE_CROP_INVALID, + NNACL_BATCH_NORM_CHANNEL_SHAPE_INVALID, + NNACL_CLIP_DATA_TYPE_INVALID, + NNACL_CLIP_MINMAX_VALUE_INVALID, + NNACL_CONCAT_AXIS_INVALID, + NNACL_CONCAT_F16_INVALID_DATA_TYPE, + NNACL_CONCAT_F16_OUTPUT_DATA_INVALID, + NNACL_CONCAT_SHAPE_INVALID, + NNACL_CONVOLUTION_INPUT_CHANNEL_UNMATCH, + NNACL_CONVOLUTION_INPUT_HW_OVERFLOW, + NNACL_CONVOLUTION_KERNEL_HW_OVERFLOW, + NNACL_CONVOLUTION_OUTPUT_HW_OVERFLOW, + NNACL_CONVOLUTION_WEIGHT_DATATYPE_INVALID, + NNACL_CONVOLUTION_WEIGHT_SHAPE_INVALID, + NNACL_CONVOLUTION_AVX512_UNSUPPORT_FORMAT, + NNACL_CONVOLUTION_WEIGHT_DATA_INVALID, + NNACL_CONVOLUTION_BIAS_DATATYPE_INVALID, + NNACL_CROP_AND_RESIZE_BOX_IDX_INVALID, + NNACL_DECONV_RESIZE_OC_INVALID, + NNACL_DECONVOLUTION_DEPTHWISE_DILATION_INVALID, + NNACL_DECONVOLUTION_DEPTHWISE_STRIDE_INVALID, + NNACL_DECONVOLUTION_DEPTHWISE_INVALID_WEIGHT_SHAPE, + NNACL_DECONVOLUTION_DEPTHWISE_INVALID_WEIGHT_REPACK, + NNACL_DECONVOLUTION_DEPTHWISE_CHANNEL_INVALID, + NNACL_DEPTH_TO_SPACE_INVALID_MODE, + NNACL_ELTWISE_INVALID_MOD, + NNACL_FILL_DATA_TYPE_INVALID, + NNACL_FUSED_BATCH_NORM_NO_CHANGE, + NNACL_FUSED_BATCH_DATA_TYPE_INVALID, + NNACL_FUSED_BATCH_NORM_TO_SCALE_FAILED, + NNACL_FUSED_BATCH_TRAIN_DATA_INVALID, + NNACL_FUSED_BATCH_TRAIN_PARAM_DATA_INVALID, + NNACL_GATHER_INDICES_DATA_TYPE_INVALID, + NNACL_GATHER_INDICES_VALUE_INVALID, + NNACL_GATHER_AXIS_INVALID, + NNACL_GATHER_INPUT_TENSOR_INVALID, + NNACL_GATHER_OUTPUT_TENSOR_INVALID, + NNACL_GATHER_D_AXIS_INVALID, + NNACL_GATHER_ND_COUNT_INVALID, + NNACL_GATHER_ND_INDICES_RANK_INVALID, + NNACL_GATHER_ND_INDICES_SHAPE_INVALID, + NNACL_GROUP_CONVOLUTION_GROUP_INVALID, + NNACL_GATHER_ND_INDICES_DATA_TYPE_INVALID, + NNACL_GROUP_CONVOLUTION_SHAPE_INVALID, + NNACL_GROUP_NORM_NUM_GROUPS_INVALID, + NNACL_GROUP_NORM_SHAPE_SIZE_INVALID, + NNACL_GROUP_NORM_FORMAT_INVALID, + NNACL_SOFTMAX_AXIS_INVALID, + NNACL_MATMUL_ACT_TYPE_INVALID, + NNACL_MATMUL_BIAS_INVALID, + NNACL_NON_ZERO_SHAPE_INVALID, + NNACL_NON_MAX_SUPPRESSION_TENSOR_SIZE_INVALID, + NNACL_NON_MAX_SUPPRESSION_PARAM_INVALID, + NNACL_NON_MAX_SUPPRESSION_BOX_DIMS_INVALID, + NNACL_NON_MAX_SUPPRESSION_BOX_DIMS_SCORE_UNMATCH, + NNACL_NON_MAX_SUPPRESSION_DIMENSION_SPATIAL_UNMATCH, + NNACL_NON_MAX_SUPPRESSION_UNSUPPORT_DEFINE_DATA, + NNACL_NON_MAX_SUPPRESSION_OUTPUT_SIZE_UNMATCH, + NNACL_ONE_HOT_AXIS_INVALID, + NNACL_ONE_HOT_OUTER_SIZE_INVALID, + NNACL_ONE_HOT_INNER_SIZE_INVALID, + NNACL_ONE_HOR_DEPTH_TENSOR_DATA_TYPE_INVALID, + NNACL_ONE_HOR_ON_VALUE_TENSOR_DATA_TYPE_INVALID, + NNACL_ONE_HOR_OFF_VALUE_TENSOR_DATA_TYPE_INVALID, + NNACL_ONE_HOR_ON_OFF_VALUE_TENSOR_DATA_TYPE_INVALID, + NNACL_PAD_SHAPE_INVALID, + NNACL_PAD_PADDING_VALID_INVALID, + NNACL_PAD_MIRROR_PAD_SIZE_INVALID, + NNACL_POW_INVALID_DATA_TYPE, + NNACL_PRELU_SLOPE_NUM_INVALID, + NNACL_PRIOR_BOX_VALUE_INVALID, + NNACL_PRIOR_BOX_RATIO_INVALID, + NNACL_LOCAL_RESPONSE_NORM_SHAPE_INVALID, + NNACL_LOCAL_RESPONSE_NORM_DEPTH_RADIUS_INVALID, + NNACL_LAYER_NORM_OUTPUT_NUM_INVALID, + NNACL_REDUCE_AXIS_SIZE_ERROR, + NNACL_REDUCE_AXES_TENSOR_ERROR, + NNACL_REDUCE_UNSUPPORTED_DATA_TYPE, + NNACL_REDUCE_INPUT_SHAPE_SIZE_INVALID, + NNACL_REDUCE_COEFF_DATA_TYPE_INVALID, + NNACL_REVERSE_AXIS_INVALID, + NNACL_REVERSE_AXIS_VALUE_INVALID, + NNACL_REVERSE_DATA_SIZE_INVALID, + NNACL_REVERSE_NUM_AXIS_INVALID, + NNACL_SCALE_AXIS_AND_SHAPE_UNMATCH, + NNACL_SCALE_UNSUPPORT_ACT_TYPE, + NNACL_SCALE_SCALE_SHAPE_UNMATCH, + NNACL_SCALE_INPUT_NUM_INVALID, + NNACL_STACK_TENSOR_SHAPE_INVALID, + NNACL_STRIDED_SLICE_INVALID_SHAPE_SIZE, + NNACL_STRIDED_SLICE_INVALID_DATA_SIZE, + NNACL_STRIDED_SLICE_UNSUPPORTED_DATA_TYPE, + NNACL_STRIDED_SLICE_INVALID_PARALLEL_MOD, + NNACL_STRIDED_SLICE_UNSUPPORTED_MAX_8D, + NNACL_SPLICE_SHAPE_INVALID, + NNACL_TILE_INPUT_SHAPE_INVALID, + NNACL_TILE_SECOND_INPUT_NUM_INVALID, + NNACL_TILE_SECOND_INPUT_VALUE_INVALID, + NNACL_TILE_SECOND_INPUT_DATA_TYPE_INVALID, + NNACL_TILE_RESIZE_IN_RUNTIME_FAILED, + NNACL_TRIU_TRIL_INPUT_SHAPE_ERROR, + NNACL_TRIU_K_TENSOR_DATA_TYPE_INVALID, + NNACL_TRIU_INPUT_DIMS_INVALID, + NNACL_TRANSPOSE_INSHAPE_OUT_OF_RANGE, + NNACL_TRANSPOSE_INPUT_TENSOR_NUM_INVALID, + NNACL_TRANSPOSE_INPUT_TENSOR_VALUD_INVALID, + NNACL_TRANSPOSE_PERM_DIMS_INVALID, + NNACL_TRANSPOSE_PERM_TENSOR_INVALID, + NNACL_TRANSPOSE_PERM_TENSOR_VALUE_INVALID, + NNACL_TRANSPOSE_PERM_DELETE_DIMENSION_FAILED, + NNACL_WHERE_INPUT_NUM_INVALID, + NNACL_WHERE_CONDITION_DATA_TYPE_ERROR, + NNACL_WHERE_CONDITION_NUM_INVALID, + NNACL_WHERE_INVALID_OUT_NUM, + NNACL_WHERE_NUM_MAX_INVALID, + NNACL_WHERE_BROAD_CAST_FAILED, + NNACL_COMMON_END +} ErrorCodeCommonEnum; + +typedef enum ErrorCodeFp32OpEnum { + NNACL_ERRCODE_OP_FP32_START = 10000, + NNACL_ERRCODE_STRASSEN_RECURSION_MALLOC, + NNACL_ERRCODE_REVERSE_MALLOC, + NNACL_ERRCODE_SQRT_NEGATIVE, + NNACL_ERRCODE_RSQRT_NEGATIVE, + NNACL_ERRCODE_RSQRT_NEGATIVE_OR_ZERO, + NNACL_ERRCODE_LOG_NEGATIVE_OR_ZERO, + NNACL_ERRCODE_DIVISOR_ZERO, + NNACL_ERRCODE_INDEX_OUT_OF_RANGE, + NNACL_ERRCODE_WINOGRAD_GENERATOR_ERROR, + NNACL_ERRCODE_OP_FP32_END = 19999 +} ErrorCodeFp32OpEnum; + +typedef enum ErrorCodeFp16OpEnum { + NNACL_ERRCODE_OP_FP16_START = 20000, + NNACL_ERRCODE_OP_FP16_WINOGRAD_GENERATOR, + NNACL_ERRCODE_OP_FP16_END = 29999 +} ErrorCodeFp16OpEnum; + +typedef enum ErrorCodeUint8OpEnum { + NNACL_ERRCODE_OP_UINT8_START = 30000, + NNACL_ERRCODE_OP_UINT8_END = 39999 +} ErrorCodeUint8OpEnum; + +typedef enum ErrorCodeInt8OpEnum { + NNACL_ERRCODE_OP_INT8_START = 40000, + NNACL_ERRCODE_ADD_OVERFLOW, + NNACL_ERRCODE_MUL_OVERFLOW, + NNACL_ERRCODE_OP_INT8_END = 49999 +} ErrorCodeInt8OpEnums; + +#ifdef __cplusplus +extern "C" { +#endif +char *NNACLErrorMsg(int error_code); +#ifdef __cplusplus +} +#endif +#endif // NNACL_ERRORCODE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/exp_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/exp_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..7211e36f3a7e89692523029c74c20d7c31094076 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/exp_parameter.h @@ -0,0 +1,28 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_EXP_PARAMETER_H_ +#define NNACL_EXP_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct ExpParameter { + OpParameter op_parameter_; + float base_; + float scale_; + float shift_; +} ExpParameter; + +#endif // NNACL_EXP_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_10x16_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_10x16_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..2d15183ff041256775764e598c2252d237c7bce1 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_10x16_kernel_nhwc_fp32.c @@ -0,0 +1,533 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_10x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + const float *dst_9 = dst + 9 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_6]), %%zmm6\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_9]), %%zmm9\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 0(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + const float *src_9 = src + 9 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 0(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_6]), %%zmm24\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 4(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_6]), %%zmm24\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 8(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_6]), %%zmm24\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 12(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_6]), %%zmm24\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 16(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_6]), %%zmm24\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 20(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_6]), %%zmm24\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 24(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_6]), %%zmm24\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 28(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_6]), %%zmm24\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 32(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_6]), %%zmm24\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 36(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_6]), %%zmm24\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 40(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_6]), %%zmm24\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 44(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_6]), %%zmm24\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 48(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_6]), %%zmm24\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 52(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_6]), %%zmm24\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 56(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_6]), %%zmm24\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 60(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "add $64, %[src_9]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 0(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "add $4, %[src_9]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm3, 0(%[dst_3])\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_6])\n" + "vmovups %%zmm7, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm8, 0(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm9, 0(%[dst_9])\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6), + [ src_9 ] "r"(src_9) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_10x32_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_10x32_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..50818ca7c8d66f4b1331852f4ad9b0edfa7633e1 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_10x32_kernel_nhwc_fp32.c @@ -0,0 +1,781 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_10x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + const float *dst_9 = dst + 9 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_6]), %%zmm12\n" + "vmovups 64(%[dst_6]), %%zmm13\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm14\n" + "vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm15\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm16\n" + "vmovups 64(%[dst_6], %[dst_stride], 2), %%zmm17\n" + "vmovups 0(%[dst_9]), %%zmm18\n" + "vmovups 64(%[dst_9]), %%zmm19\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 0(%[bias]), %%zmm14\n" + "vmovups 64(%[bias]), %%zmm15\n" + "vmovups 0(%[bias]), %%zmm16\n" + "vmovups 64(%[bias]), %%zmm17\n" + "vmovups 0(%[bias]), %%zmm18\n" + "vmovups 64(%[bias]), %%zmm19\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + const float *src_9 = src + 9 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 0(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 4(%[src_6]), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 4(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 8(%[src_6]), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 8(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 12(%[src_6]), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 12(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 16(%[src_6]), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 16(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 20(%[src_6]), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 20(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 24(%[src_6]), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 24(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 28(%[src_6]), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 28(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 32(%[src_6]), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 32(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 36(%[src_6]), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 36(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 40(%[src_6]), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 40(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 44(%[src_6]), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 44(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 48(%[src_6]), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 48(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 52(%[src_6]), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 52(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 56(%[src_6]), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 56(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 60(%[src_6]), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 60(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "add $64, %[src_9]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 0(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "add $4, %[src_9]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3])\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm12, 0(%[dst_6])\n" + "vmovups %%zmm13, 64(%[dst_6])\n" + "vmovups %%zmm14, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm15, 64(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm16, 0(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm17, 64(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm18, 0(%[dst_9])\n" + "vmovups %%zmm19, 64(%[dst_9])\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6), + [ src_9 ] "r"(src_9) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_11x16_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_11x16_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..494a7bdf794557e717e59a23479ae28fdbb9aee5 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_11x16_kernel_nhwc_fp32.c @@ -0,0 +1,573 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_11x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + const float *dst_9 = dst + 9 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_6]), %%zmm6\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_9]), %%zmm9\n" + "vmovups 0(%[dst_9], %[dst_stride], 1), %%zmm10\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 0(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + const float *src_9 = src + 9 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 0(%[src_9]), %%zmm21\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_6]), %%zmm24\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 4(%[src_9]), %%zmm21\n" + "vbroadcastss 4(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_6]), %%zmm24\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 8(%[src_9]), %%zmm21\n" + "vbroadcastss 8(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_6]), %%zmm24\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 12(%[src_9]), %%zmm21\n" + "vbroadcastss 12(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_6]), %%zmm24\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 16(%[src_9]), %%zmm21\n" + "vbroadcastss 16(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_6]), %%zmm24\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 20(%[src_9]), %%zmm21\n" + "vbroadcastss 20(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_6]), %%zmm24\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 24(%[src_9]), %%zmm21\n" + "vbroadcastss 24(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_6]), %%zmm24\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 28(%[src_9]), %%zmm21\n" + "vbroadcastss 28(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_6]), %%zmm24\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 32(%[src_9]), %%zmm21\n" + "vbroadcastss 32(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_6]), %%zmm24\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 36(%[src_9]), %%zmm21\n" + "vbroadcastss 36(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_6]), %%zmm24\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 40(%[src_9]), %%zmm21\n" + "vbroadcastss 40(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_6]), %%zmm24\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 44(%[src_9]), %%zmm21\n" + "vbroadcastss 44(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_6]), %%zmm24\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 48(%[src_9]), %%zmm21\n" + "vbroadcastss 48(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_6]), %%zmm24\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 52(%[src_9]), %%zmm21\n" + "vbroadcastss 52(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_6]), %%zmm24\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 56(%[src_9]), %%zmm21\n" + "vbroadcastss 56(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_6]), %%zmm24\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 60(%[src_9]), %%zmm21\n" + "vbroadcastss 60(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "add $64, %[src_9]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 0(%[src_9]), %%zmm21\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "add $4, %[src_9]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm3, 0(%[dst_3])\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_6])\n" + "vmovups %%zmm7, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm8, 0(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm9, 0(%[dst_9])\n" + "vmovups %%zmm10, 0(%[dst_9], %[dst_stride], 1)\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6), + [ src_9 ] "r"(src_9) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_11x32_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_11x32_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..9069307b0d605b2e699dfcb413f3c7bfc23e4179 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_11x32_kernel_nhwc_fp32.c @@ -0,0 +1,844 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_11x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + const float *dst_9 = dst + 9 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_6]), %%zmm12\n" + "vmovups 64(%[dst_6]), %%zmm13\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm14\n" + "vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm15\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm16\n" + "vmovups 64(%[dst_6], %[dst_stride], 2), %%zmm17\n" + "vmovups 0(%[dst_9]), %%zmm18\n" + "vmovups 64(%[dst_9]), %%zmm19\n" + "vmovups 0(%[dst_9], %[dst_stride], 1), %%zmm20\n" + "vmovups 64(%[dst_9], %[dst_stride], 1), %%zmm21\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 0(%[bias]), %%zmm14\n" + "vmovups 64(%[bias]), %%zmm15\n" + "vmovups 0(%[bias]), %%zmm16\n" + "vmovups 64(%[bias]), %%zmm17\n" + "vmovups 0(%[bias]), %%zmm18\n" + "vmovups 64(%[bias]), %%zmm19\n" + "vmovups 0(%[bias]), %%zmm20\n" + "vmovups 64(%[bias]), %%zmm21\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + "vxorps %%zmm21, %%zmm21, %%zmm21\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + const float *src_9 = src + 9 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 0(%[src_6]), %%zmm29\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_9]), %%zmm26\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 4(%[src_6]), %%zmm29\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_9]), %%zmm26\n" + "vbroadcastss 4(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 8(%[src_6]), %%zmm29\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_9]), %%zmm26\n" + "vbroadcastss 8(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 12(%[src_6]), %%zmm29\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_9]), %%zmm26\n" + "vbroadcastss 12(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 16(%[src_6]), %%zmm29\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_9]), %%zmm26\n" + "vbroadcastss 16(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 20(%[src_6]), %%zmm29\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_9]), %%zmm26\n" + "vbroadcastss 20(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 24(%[src_6]), %%zmm29\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_9]), %%zmm26\n" + "vbroadcastss 24(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 28(%[src_6]), %%zmm29\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_9]), %%zmm26\n" + "vbroadcastss 28(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 32(%[src_6]), %%zmm29\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_9]), %%zmm26\n" + "vbroadcastss 32(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 36(%[src_6]), %%zmm29\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_9]), %%zmm26\n" + "vbroadcastss 36(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 40(%[src_6]), %%zmm29\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_9]), %%zmm26\n" + "vbroadcastss 40(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 44(%[src_6]), %%zmm29\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_9]), %%zmm26\n" + "vbroadcastss 44(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 48(%[src_6]), %%zmm29\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_9]), %%zmm26\n" + "vbroadcastss 48(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 52(%[src_6]), %%zmm29\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_9]), %%zmm26\n" + "vbroadcastss 52(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 56(%[src_6]), %%zmm29\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_9]), %%zmm26\n" + "vbroadcastss 56(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 60(%[src_6]), %%zmm29\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_9]), %%zmm26\n" + "vbroadcastss 60(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "add $64, %[src_9]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 0(%[src_6]), %%zmm29\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_9]), %%zmm26\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "add $4, %[src_9]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20\n" + "vmaxps %%zmm21, %%zmm31, %%zmm21\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + "vminps %%zmm20, %%zmm30, %%zmm20\n" + "vminps %%zmm21, %%zmm30, %%zmm21\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3])\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm12, 0(%[dst_6])\n" + "vmovups %%zmm13, 64(%[dst_6])\n" + "vmovups %%zmm14, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm15, 64(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm16, 0(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm17, 64(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm18, 0(%[dst_9])\n" + "vmovups %%zmm19, 64(%[dst_9])\n" + "vmovups %%zmm20, 0(%[dst_9], %[dst_stride], 1)\n" + "vmovups %%zmm21, 64(%[dst_9], %[dst_stride], 1)\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6), + [ src_9 ] "r"(src_9) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_12x16_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_12x16_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..85540ef1c62c8b5b65b112c9916403e82420dd49 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_12x16_kernel_nhwc_fp32.c @@ -0,0 +1,614 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_12x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + const float *dst_9 = dst + 9 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_6]), %%zmm6\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_9]), %%zmm9\n" + "vmovups 0(%[dst_9], %[dst_stride], 1), %%zmm10\n" + "vmovups 0(%[dst_9], %[dst_stride], 2), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 0(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 0(%[bias]), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + const float *src_9 = src + 9 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 0(%[src_9]), %%zmm21\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 0(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_6]), %%zmm24\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 4(%[src_9]), %%zmm21\n" + "vbroadcastss 4(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 4(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_6]), %%zmm24\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 8(%[src_9]), %%zmm21\n" + "vbroadcastss 8(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 8(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_6]), %%zmm24\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 12(%[src_9]), %%zmm21\n" + "vbroadcastss 12(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 12(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_6]), %%zmm24\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 16(%[src_9]), %%zmm21\n" + "vbroadcastss 16(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 16(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_6]), %%zmm24\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 20(%[src_9]), %%zmm21\n" + "vbroadcastss 20(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 20(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_6]), %%zmm24\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 24(%[src_9]), %%zmm21\n" + "vbroadcastss 24(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 24(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_6]), %%zmm24\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 28(%[src_9]), %%zmm21\n" + "vbroadcastss 28(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 28(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_6]), %%zmm24\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 32(%[src_9]), %%zmm21\n" + "vbroadcastss 32(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 32(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_6]), %%zmm24\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 36(%[src_9]), %%zmm21\n" + "vbroadcastss 36(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 36(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_6]), %%zmm24\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 40(%[src_9]), %%zmm21\n" + "vbroadcastss 40(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 40(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_6]), %%zmm24\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 44(%[src_9]), %%zmm21\n" + "vbroadcastss 44(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 44(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_6]), %%zmm24\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 48(%[src_9]), %%zmm21\n" + "vbroadcastss 48(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 48(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_6]), %%zmm24\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 52(%[src_9]), %%zmm21\n" + "vbroadcastss 52(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 52(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_6]), %%zmm24\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 56(%[src_9]), %%zmm21\n" + "vbroadcastss 56(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 56(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_6]), %%zmm24\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 60(%[src_9]), %%zmm21\n" + "vbroadcastss 60(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 60(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "add $64, %[src_9]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 0(%[src_9]), %%zmm21\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 0(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "add $4, %[src_9]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm3, 0(%[dst_3])\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_6])\n" + "vmovups %%zmm7, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm8, 0(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm9, 0(%[dst_9])\n" + "vmovups %%zmm10, 0(%[dst_9], %[dst_stride], 1)\n" + "vmovups %%zmm11, 0(%[dst_9], %[dst_stride], 2)\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6), + [ src_9 ] "r"(src_9) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_12x32_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_12x32_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..874fc7265c6928623dd22b03055b6b59ba2db80f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_12x32_kernel_nhwc_fp32.c @@ -0,0 +1,908 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_12x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + const float *dst_9 = dst + 9 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_6]), %%zmm12\n" + "vmovups 64(%[dst_6]), %%zmm13\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm14\n" + "vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm15\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm16\n" + "vmovups 64(%[dst_6], %[dst_stride], 2), %%zmm17\n" + "vmovups 0(%[dst_9]), %%zmm18\n" + "vmovups 64(%[dst_9]), %%zmm19\n" + "vmovups 0(%[dst_9], %[dst_stride], 1), %%zmm20\n" + "vmovups 64(%[dst_9], %[dst_stride], 1), %%zmm21\n" + "vmovups 0(%[dst_9], %[dst_stride], 2), %%zmm22\n" + "vmovups 64(%[dst_9], %[dst_stride], 2), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 0(%[bias]), %%zmm14\n" + "vmovups 64(%[bias]), %%zmm15\n" + "vmovups 0(%[bias]), %%zmm16\n" + "vmovups 64(%[bias]), %%zmm17\n" + "vmovups 0(%[bias]), %%zmm18\n" + "vmovups 64(%[bias]), %%zmm19\n" + "vmovups 0(%[bias]), %%zmm20\n" + "vmovups 64(%[bias]), %%zmm21\n" + "vmovups 0(%[bias]), %%zmm22\n" + "vmovups 64(%[bias]), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + "vxorps %%zmm21, %%zmm21, %%zmm21\n" + "vxorps %%zmm22, %%zmm22, %%zmm22\n" + "vxorps %%zmm23, %%zmm23, %%zmm23\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + const float *src_9 = src + 9 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 0(%[src_6]), %%zmm29\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_9]), %%zmm26\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 4(%[src_6]), %%zmm29\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_9]), %%zmm26\n" + "vbroadcastss 4(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 8(%[src_6]), %%zmm29\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_9]), %%zmm26\n" + "vbroadcastss 8(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 12(%[src_6]), %%zmm29\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_9]), %%zmm26\n" + "vbroadcastss 12(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 16(%[src_6]), %%zmm29\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_9]), %%zmm26\n" + "vbroadcastss 16(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 20(%[src_6]), %%zmm29\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_9]), %%zmm26\n" + "vbroadcastss 20(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 24(%[src_6]), %%zmm29\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_9]), %%zmm26\n" + "vbroadcastss 24(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 28(%[src_6]), %%zmm29\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_9]), %%zmm26\n" + "vbroadcastss 28(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 32(%[src_6]), %%zmm29\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_9]), %%zmm26\n" + "vbroadcastss 32(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 36(%[src_6]), %%zmm29\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_9]), %%zmm26\n" + "vbroadcastss 36(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 40(%[src_6]), %%zmm29\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_9]), %%zmm26\n" + "vbroadcastss 40(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 44(%[src_6]), %%zmm29\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_9]), %%zmm26\n" + "vbroadcastss 44(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 48(%[src_6]), %%zmm29\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_9]), %%zmm26\n" + "vbroadcastss 48(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 52(%[src_6]), %%zmm29\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_9]), %%zmm26\n" + "vbroadcastss 52(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 56(%[src_6]), %%zmm29\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_9]), %%zmm26\n" + "vbroadcastss 56(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 60(%[src_6]), %%zmm29\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_9]), %%zmm26\n" + "vbroadcastss 60(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "add $64, %[src_9]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 0(%[src_6]), %%zmm29\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_9]), %%zmm26\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "add $4, %[src_9]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20\n" + "vmaxps %%zmm21, %%zmm31, %%zmm21\n" + "vmaxps %%zmm22, %%zmm31, %%zmm22\n" + "vmaxps %%zmm23, %%zmm31, %%zmm23\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + "vminps %%zmm20, %%zmm30, %%zmm20\n" + "vminps %%zmm21, %%zmm30, %%zmm21\n" + "vminps %%zmm22, %%zmm30, %%zmm22\n" + "vminps %%zmm23, %%zmm30, %%zmm23\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3])\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm12, 0(%[dst_6])\n" + "vmovups %%zmm13, 64(%[dst_6])\n" + "vmovups %%zmm14, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm15, 64(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm16, 0(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm17, 64(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm18, 0(%[dst_9])\n" + "vmovups %%zmm19, 64(%[dst_9])\n" + "vmovups %%zmm20, 0(%[dst_9], %[dst_stride], 1)\n" + "vmovups %%zmm21, 64(%[dst_9], %[dst_stride], 1)\n" + "vmovups %%zmm22, 0(%[dst_9], %[dst_stride], 2)\n" + "vmovups %%zmm23, 64(%[dst_9], %[dst_stride], 2)\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6), + [ src_9 ] "r"(src_9) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x16_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x16_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..6b9f0d83a73177d4faaf53fa7f23f1e63a85ce14 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x16_kernel_nhwc_fp32.c @@ -0,0 +1,158 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_1x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%zmm0"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x32_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x32_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..84a504c3218cfb91dc469e848a5ba13084317b46 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x32_kernel_nhwc_fp32.c @@ -0,0 +1,198 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_1x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%zmm0", "%zmm1"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x48_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x48_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..0c7b8cf90012251be8b82d9b82b334175e478e6c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x48_kernel_nhwc_fp32.c @@ -0,0 +1,238 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_1x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x64_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x64_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..12a340b48c651070e0fc9af258d4cb9e87d90e77 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x64_kernel_nhwc_fp32.c @@ -0,0 +1,278 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_1x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + "vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 8 + "vmovups 2048(%[weight]), %%zmm31\n" + "vmovups 2112(%[weight]), %%zmm30\n" + "vmovups 2176(%[weight]), %%zmm29\n" + "vmovups 2240(%[weight]), %%zmm28\n" + "vbroadcastss 32(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 9 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vbroadcastss 36(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 10 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vbroadcastss 40(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 11 + "vmovups 2816(%[weight]), %%zmm31\n" + "vmovups 2880(%[weight]), %%zmm30\n" + "vmovups 2944(%[weight]), %%zmm29\n" + "vmovups 3008(%[weight]), %%zmm28\n" + "vbroadcastss 44(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 12 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vbroadcastss 48(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 13 + "vmovups 3328(%[weight]), %%zmm31\n" + "vmovups 3392(%[weight]), %%zmm30\n" + "vmovups 3456(%[weight]), %%zmm29\n" + "vmovups 3520(%[weight]), %%zmm28\n" + "vbroadcastss 52(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 14 + "vmovups 3584(%[weight]), %%zmm31\n" + "vmovups 3648(%[weight]), %%zmm30\n" + "vmovups 3712(%[weight]), %%zmm29\n" + "vmovups 3776(%[weight]), %%zmm28\n" + "vbroadcastss 56(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 15 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vbroadcastss 60(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "add $4096, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "add $256, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x80_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x80_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..58c63c931ba61d282be5940670a235fb943d5a8e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x80_kernel_nhwc_fp32.c @@ -0,0 +1,318 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_1x80_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 1 + "vmovups 320(%[weight]), %%zmm31\n" + "vmovups 384(%[weight]), %%zmm30\n" + "vmovups 448(%[weight]), %%zmm29\n" + "vmovups 512(%[weight]), %%zmm28\n" + "vmovups 576(%[weight]), %%zmm27\n" + "vbroadcastss 4(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 2 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vmovups 768(%[weight]), %%zmm29\n" + "vmovups 832(%[weight]), %%zmm28\n" + "vmovups 896(%[weight]), %%zmm27\n" + "vbroadcastss 8(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 3 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vmovups 1152(%[weight]), %%zmm28\n" + "vmovups 1216(%[weight]), %%zmm27\n" + "vbroadcastss 12(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 4 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vmovups 1536(%[weight]), %%zmm27\n" + "vbroadcastss 16(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 5 + "vmovups 1600(%[weight]), %%zmm31\n" + "vmovups 1664(%[weight]), %%zmm30\n" + "vmovups 1728(%[weight]), %%zmm29\n" + "vmovups 1792(%[weight]), %%zmm28\n" + "vmovups 1856(%[weight]), %%zmm27\n" + "vbroadcastss 20(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 6 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vbroadcastss 24(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 7 + "vmovups 2240(%[weight]), %%zmm31\n" + "vmovups 2304(%[weight]), %%zmm30\n" + "vmovups 2368(%[weight]), %%zmm29\n" + "vmovups 2432(%[weight]), %%zmm28\n" + "vmovups 2496(%[weight]), %%zmm27\n" + "vbroadcastss 28(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 8 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vmovups 2816(%[weight]), %%zmm27\n" + "vbroadcastss 32(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 9 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vmovups 3072(%[weight]), %%zmm28\n" + "vmovups 3136(%[weight]), %%zmm27\n" + "vbroadcastss 36(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 10 + "vmovups 3200(%[weight]), %%zmm31\n" + "vmovups 3264(%[weight]), %%zmm30\n" + "vmovups 3328(%[weight]), %%zmm29\n" + "vmovups 3392(%[weight]), %%zmm28\n" + "vmovups 3456(%[weight]), %%zmm27\n" + "vbroadcastss 40(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 11 + "vmovups 3520(%[weight]), %%zmm31\n" + "vmovups 3584(%[weight]), %%zmm30\n" + "vmovups 3648(%[weight]), %%zmm29\n" + "vmovups 3712(%[weight]), %%zmm28\n" + "vmovups 3776(%[weight]), %%zmm27\n" + "vbroadcastss 44(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 12 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vbroadcastss 48(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 13 + "vmovups 4160(%[weight]), %%zmm31\n" + "vmovups 4224(%[weight]), %%zmm30\n" + "vmovups 4288(%[weight]), %%zmm29\n" + "vmovups 4352(%[weight]), %%zmm28\n" + "vmovups 4416(%[weight]), %%zmm27\n" + "vbroadcastss 52(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 14 + "vmovups 4480(%[weight]), %%zmm31\n" + "vmovups 4544(%[weight]), %%zmm30\n" + "vmovups 4608(%[weight]), %%zmm29\n" + "vmovups 4672(%[weight]), %%zmm28\n" + "vmovups 4736(%[weight]), %%zmm27\n" + "vbroadcastss 56(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 15 + "vmovups 4800(%[weight]), %%zmm31\n" + "vmovups 4864(%[weight]), %%zmm30\n" + "vmovups 4928(%[weight]), %%zmm29\n" + "vmovups 4992(%[weight]), %%zmm28\n" + "vmovups 5056(%[weight]), %%zmm27\n" + "vbroadcastss 60(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "add $5120, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "add $320, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x96_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x96_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..bd970a3bafb3986d9b1c1892c71bcb3593dc63f9 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x96_kernel_nhwc_fp32.c @@ -0,0 +1,358 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_1x96_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 320(%[dst_0]), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 320(%[bias]), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 1 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vmovups 576(%[weight]), %%zmm28\n" + "vmovups 640(%[weight]), %%zmm27\n" + "vmovups 704(%[weight]), %%zmm26\n" + "vbroadcastss 4(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 2 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vmovups 1024(%[weight]), %%zmm27\n" + "vmovups 1088(%[weight]), %%zmm26\n" + "vbroadcastss 8(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 3 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vmovups 1344(%[weight]), %%zmm28\n" + "vmovups 1408(%[weight]), %%zmm27\n" + "vmovups 1472(%[weight]), %%zmm26\n" + "vbroadcastss 12(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 4 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vmovups 1792(%[weight]), %%zmm27\n" + "vmovups 1856(%[weight]), %%zmm26\n" + "vbroadcastss 16(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 5 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vmovups 2240(%[weight]), %%zmm26\n" + "vbroadcastss 20(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 6 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vmovups 2560(%[weight]), %%zmm27\n" + "vmovups 2624(%[weight]), %%zmm26\n" + "vbroadcastss 24(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 7 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vmovups 2880(%[weight]), %%zmm28\n" + "vmovups 2944(%[weight]), %%zmm27\n" + "vmovups 3008(%[weight]), %%zmm26\n" + "vbroadcastss 28(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 8 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vmovups 3328(%[weight]), %%zmm27\n" + "vmovups 3392(%[weight]), %%zmm26\n" + "vbroadcastss 32(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 9 + "vmovups 3456(%[weight]), %%zmm31\n" + "vmovups 3520(%[weight]), %%zmm30\n" + "vmovups 3584(%[weight]), %%zmm29\n" + "vmovups 3648(%[weight]), %%zmm28\n" + "vmovups 3712(%[weight]), %%zmm27\n" + "vmovups 3776(%[weight]), %%zmm26\n" + "vbroadcastss 36(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 10 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vmovups 4160(%[weight]), %%zmm26\n" + "vbroadcastss 40(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 11 + "vmovups 4224(%[weight]), %%zmm31\n" + "vmovups 4288(%[weight]), %%zmm30\n" + "vmovups 4352(%[weight]), %%zmm29\n" + "vmovups 4416(%[weight]), %%zmm28\n" + "vmovups 4480(%[weight]), %%zmm27\n" + "vmovups 4544(%[weight]), %%zmm26\n" + "vbroadcastss 44(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 12 + "vmovups 4608(%[weight]), %%zmm31\n" + "vmovups 4672(%[weight]), %%zmm30\n" + "vmovups 4736(%[weight]), %%zmm29\n" + "vmovups 4800(%[weight]), %%zmm28\n" + "vmovups 4864(%[weight]), %%zmm27\n" + "vmovups 4928(%[weight]), %%zmm26\n" + "vbroadcastss 48(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 13 + "vmovups 4992(%[weight]), %%zmm31\n" + "vmovups 5056(%[weight]), %%zmm30\n" + "vmovups 5120(%[weight]), %%zmm29\n" + "vmovups 5184(%[weight]), %%zmm28\n" + "vmovups 5248(%[weight]), %%zmm27\n" + "vmovups 5312(%[weight]), %%zmm26\n" + "vbroadcastss 52(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 14 + "vmovups 5376(%[weight]), %%zmm31\n" + "vmovups 5440(%[weight]), %%zmm30\n" + "vmovups 5504(%[weight]), %%zmm29\n" + "vmovups 5568(%[weight]), %%zmm28\n" + "vmovups 5632(%[weight]), %%zmm27\n" + "vmovups 5696(%[weight]), %%zmm26\n" + "vbroadcastss 56(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 15 + "vmovups 5760(%[weight]), %%zmm31\n" + "vmovups 5824(%[weight]), %%zmm30\n" + "vmovups 5888(%[weight]), %%zmm29\n" + "vmovups 5952(%[weight]), %%zmm28\n" + "vmovups 6016(%[weight]), %%zmm27\n" + "vmovups 6080(%[weight]), %%zmm26\n" + "vbroadcastss 60(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "add $6144, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "add $384, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 320(%[dst_0])\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x16_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x16_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..24e6359d6d3883cfea22e6aeb1ea8ab23a5e43d3 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x16_kernel_nhwc_fp32.c @@ -0,0 +1,198 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_2x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%zmm0", "%zmm1"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1)\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x32_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x32_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..d5058e5610d6e5176e6786666bd549439532fea3 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x32_kernel_nhwc_fp32.c @@ -0,0 +1,261 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_2x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x48_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x48_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..70eba21db17a127d565c8878235ed197fa074a1b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x48_kernel_nhwc_fp32.c @@ -0,0 +1,324 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_2x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 64(%[bias]), %%zmm4\n" + "vmovups 128(%[bias]), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1)\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x64_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x64_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..08a81839d6656e925e8d9f659e3cb052b2c289be --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x64_kernel_nhwc_fp32.c @@ -0,0 +1,387 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_2x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 128(%[bias]), %%zmm6\n" + "vmovups 192(%[bias]), %%zmm7\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + "vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 8 + "vmovups 2048(%[weight]), %%zmm31\n" + "vmovups 2112(%[weight]), %%zmm30\n" + "vmovups 2176(%[weight]), %%zmm29\n" + "vmovups 2240(%[weight]), %%zmm28\n" + "vbroadcastss 32(%[src_0]), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 9 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vbroadcastss 36(%[src_0]), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 10 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vbroadcastss 40(%[src_0]), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 11 + "vmovups 2816(%[weight]), %%zmm31\n" + "vmovups 2880(%[weight]), %%zmm30\n" + "vmovups 2944(%[weight]), %%zmm29\n" + "vmovups 3008(%[weight]), %%zmm28\n" + "vbroadcastss 44(%[src_0]), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 12 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vbroadcastss 48(%[src_0]), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 13 + "vmovups 3328(%[weight]), %%zmm31\n" + "vmovups 3392(%[weight]), %%zmm30\n" + "vmovups 3456(%[weight]), %%zmm29\n" + "vmovups 3520(%[weight]), %%zmm28\n" + "vbroadcastss 52(%[src_0]), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 14 + "vmovups 3584(%[weight]), %%zmm31\n" + "vmovups 3648(%[weight]), %%zmm30\n" + "vmovups 3712(%[weight]), %%zmm29\n" + "vmovups 3776(%[weight]), %%zmm28\n" + "vbroadcastss 56(%[src_0]), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 15 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vbroadcastss 60(%[src_0]), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "add $4096, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "add $256, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1)\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x80_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x80_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..6c2ba6ec19f4e82cfb4ed67cbfe32cec0fc22ea3 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x80_kernel_nhwc_fp32.c @@ -0,0 +1,450 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_2x80_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 64(%[bias]), %%zmm6\n" + "vmovups 128(%[bias]), %%zmm7\n" + "vmovups 192(%[bias]), %%zmm8\n" + "vmovups 256(%[bias]), %%zmm9\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 1 + "vmovups 320(%[weight]), %%zmm31\n" + "vmovups 384(%[weight]), %%zmm30\n" + "vmovups 448(%[weight]), %%zmm29\n" + "vmovups 512(%[weight]), %%zmm28\n" + "vmovups 576(%[weight]), %%zmm27\n" + "vbroadcastss 4(%[src_0]), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 2 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vmovups 768(%[weight]), %%zmm29\n" + "vmovups 832(%[weight]), %%zmm28\n" + "vmovups 896(%[weight]), %%zmm27\n" + "vbroadcastss 8(%[src_0]), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 3 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vmovups 1152(%[weight]), %%zmm28\n" + "vmovups 1216(%[weight]), %%zmm27\n" + "vbroadcastss 12(%[src_0]), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 4 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vmovups 1536(%[weight]), %%zmm27\n" + "vbroadcastss 16(%[src_0]), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 5 + "vmovups 1600(%[weight]), %%zmm31\n" + "vmovups 1664(%[weight]), %%zmm30\n" + "vmovups 1728(%[weight]), %%zmm29\n" + "vmovups 1792(%[weight]), %%zmm28\n" + "vmovups 1856(%[weight]), %%zmm27\n" + "vbroadcastss 20(%[src_0]), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 6 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vbroadcastss 24(%[src_0]), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 7 + "vmovups 2240(%[weight]), %%zmm31\n" + "vmovups 2304(%[weight]), %%zmm30\n" + "vmovups 2368(%[weight]), %%zmm29\n" + "vmovups 2432(%[weight]), %%zmm28\n" + "vmovups 2496(%[weight]), %%zmm27\n" + "vbroadcastss 28(%[src_0]), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 8 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vmovups 2816(%[weight]), %%zmm27\n" + "vbroadcastss 32(%[src_0]), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 9 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vmovups 3072(%[weight]), %%zmm28\n" + "vmovups 3136(%[weight]), %%zmm27\n" + "vbroadcastss 36(%[src_0]), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 10 + "vmovups 3200(%[weight]), %%zmm31\n" + "vmovups 3264(%[weight]), %%zmm30\n" + "vmovups 3328(%[weight]), %%zmm29\n" + "vmovups 3392(%[weight]), %%zmm28\n" + "vmovups 3456(%[weight]), %%zmm27\n" + "vbroadcastss 40(%[src_0]), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 11 + "vmovups 3520(%[weight]), %%zmm31\n" + "vmovups 3584(%[weight]), %%zmm30\n" + "vmovups 3648(%[weight]), %%zmm29\n" + "vmovups 3712(%[weight]), %%zmm28\n" + "vmovups 3776(%[weight]), %%zmm27\n" + "vbroadcastss 44(%[src_0]), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 12 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vbroadcastss 48(%[src_0]), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 13 + "vmovups 4160(%[weight]), %%zmm31\n" + "vmovups 4224(%[weight]), %%zmm30\n" + "vmovups 4288(%[weight]), %%zmm29\n" + "vmovups 4352(%[weight]), %%zmm28\n" + "vmovups 4416(%[weight]), %%zmm27\n" + "vbroadcastss 52(%[src_0]), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 14 + "vmovups 4480(%[weight]), %%zmm31\n" + "vmovups 4544(%[weight]), %%zmm30\n" + "vmovups 4608(%[weight]), %%zmm29\n" + "vmovups 4672(%[weight]), %%zmm28\n" + "vmovups 4736(%[weight]), %%zmm27\n" + "vbroadcastss 56(%[src_0]), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 15 + "vmovups 4800(%[weight]), %%zmm31\n" + "vmovups 4864(%[weight]), %%zmm30\n" + "vmovups 4928(%[weight]), %%zmm29\n" + "vmovups 4992(%[weight]), %%zmm28\n" + "vmovups 5056(%[weight]), %%zmm27\n" + "vbroadcastss 60(%[src_0]), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "add $5120, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "add $320, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1)\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x96_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x96_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..5b5cab40e05dfe33c392c099cc0cdaaca06eaef4 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x96_kernel_nhwc_fp32.c @@ -0,0 +1,513 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_2x96_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 320(%[dst_0]), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm10\n" + "vmovups 320(%[dst_0], %[dst_stride], 1), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 320(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 192(%[bias]), %%zmm9\n" + "vmovups 256(%[bias]), %%zmm10\n" + "vmovups 320(%[bias]), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 1 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vmovups 576(%[weight]), %%zmm28\n" + "vmovups 640(%[weight]), %%zmm27\n" + "vmovups 704(%[weight]), %%zmm26\n" + "vbroadcastss 4(%[src_0]), %%zmm25\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 2 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vmovups 1024(%[weight]), %%zmm27\n" + "vmovups 1088(%[weight]), %%zmm26\n" + "vbroadcastss 8(%[src_0]), %%zmm25\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 3 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vmovups 1344(%[weight]), %%zmm28\n" + "vmovups 1408(%[weight]), %%zmm27\n" + "vmovups 1472(%[weight]), %%zmm26\n" + "vbroadcastss 12(%[src_0]), %%zmm25\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 4 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vmovups 1792(%[weight]), %%zmm27\n" + "vmovups 1856(%[weight]), %%zmm26\n" + "vbroadcastss 16(%[src_0]), %%zmm25\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 5 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vmovups 2240(%[weight]), %%zmm26\n" + "vbroadcastss 20(%[src_0]), %%zmm25\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 6 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vmovups 2560(%[weight]), %%zmm27\n" + "vmovups 2624(%[weight]), %%zmm26\n" + "vbroadcastss 24(%[src_0]), %%zmm25\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 7 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vmovups 2880(%[weight]), %%zmm28\n" + "vmovups 2944(%[weight]), %%zmm27\n" + "vmovups 3008(%[weight]), %%zmm26\n" + "vbroadcastss 28(%[src_0]), %%zmm25\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 8 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vmovups 3328(%[weight]), %%zmm27\n" + "vmovups 3392(%[weight]), %%zmm26\n" + "vbroadcastss 32(%[src_0]), %%zmm25\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 9 + "vmovups 3456(%[weight]), %%zmm31\n" + "vmovups 3520(%[weight]), %%zmm30\n" + "vmovups 3584(%[weight]), %%zmm29\n" + "vmovups 3648(%[weight]), %%zmm28\n" + "vmovups 3712(%[weight]), %%zmm27\n" + "vmovups 3776(%[weight]), %%zmm26\n" + "vbroadcastss 36(%[src_0]), %%zmm25\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 10 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vmovups 4160(%[weight]), %%zmm26\n" + "vbroadcastss 40(%[src_0]), %%zmm25\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 11 + "vmovups 4224(%[weight]), %%zmm31\n" + "vmovups 4288(%[weight]), %%zmm30\n" + "vmovups 4352(%[weight]), %%zmm29\n" + "vmovups 4416(%[weight]), %%zmm28\n" + "vmovups 4480(%[weight]), %%zmm27\n" + "vmovups 4544(%[weight]), %%zmm26\n" + "vbroadcastss 44(%[src_0]), %%zmm25\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 12 + "vmovups 4608(%[weight]), %%zmm31\n" + "vmovups 4672(%[weight]), %%zmm30\n" + "vmovups 4736(%[weight]), %%zmm29\n" + "vmovups 4800(%[weight]), %%zmm28\n" + "vmovups 4864(%[weight]), %%zmm27\n" + "vmovups 4928(%[weight]), %%zmm26\n" + "vbroadcastss 48(%[src_0]), %%zmm25\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 13 + "vmovups 4992(%[weight]), %%zmm31\n" + "vmovups 5056(%[weight]), %%zmm30\n" + "vmovups 5120(%[weight]), %%zmm29\n" + "vmovups 5184(%[weight]), %%zmm28\n" + "vmovups 5248(%[weight]), %%zmm27\n" + "vmovups 5312(%[weight]), %%zmm26\n" + "vbroadcastss 52(%[src_0]), %%zmm25\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 14 + "vmovups 5376(%[weight]), %%zmm31\n" + "vmovups 5440(%[weight]), %%zmm30\n" + "vmovups 5504(%[weight]), %%zmm29\n" + "vmovups 5568(%[weight]), %%zmm28\n" + "vmovups 5632(%[weight]), %%zmm27\n" + "vmovups 5696(%[weight]), %%zmm26\n" + "vbroadcastss 56(%[src_0]), %%zmm25\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + // block 15 + "vmovups 5760(%[weight]), %%zmm31\n" + "vmovups 5824(%[weight]), %%zmm30\n" + "vmovups 5888(%[weight]), %%zmm29\n" + "vmovups 5952(%[weight]), %%zmm28\n" + "vmovups 6016(%[weight]), %%zmm27\n" + "vmovups 6080(%[weight]), %%zmm26\n" + "vbroadcastss 60(%[src_0]), %%zmm25\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "add $6144, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "add $384, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 320(%[dst_0])\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm10, 256(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm11, 320(%[dst_0], %[dst_stride], 1)\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x16_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x16_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..3b99d5b577e14df1295d1ddb3f4dd975e05b7b45 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x16_kernel_nhwc_fp32.c @@ -0,0 +1,238 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_3x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2)\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x32_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x32_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..3c890617b44ea802e3d45f430fcd4b5abeac5de6 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x32_kernel_nhwc_fp32.c @@ -0,0 +1,324 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_3x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x48_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x48_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..28701f4b3dbe6bca15604891395664d4828a9753 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x48_kernel_nhwc_fp32.c @@ -0,0 +1,410 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_3x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 64(%[bias]), %%zmm4\n" + "vmovups 128(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 2)\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x64_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x64_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..49168491553eb815d8b10b84a865fd0dec2e7391 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x64_kernel_nhwc_fp32.c @@ -0,0 +1,496 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_3x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm9\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 128(%[bias]), %%zmm6\n" + "vmovups 192(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 128(%[bias]), %%zmm10\n" + "vmovups 192(%[bias]), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + "vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 8 + "vmovups 2048(%[weight]), %%zmm31\n" + "vmovups 2112(%[weight]), %%zmm30\n" + "vmovups 2176(%[weight]), %%zmm29\n" + "vmovups 2240(%[weight]), %%zmm28\n" + "vbroadcastss 32(%[src_0]), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 9 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vbroadcastss 36(%[src_0]), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 10 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vbroadcastss 40(%[src_0]), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 11 + "vmovups 2816(%[weight]), %%zmm31\n" + "vmovups 2880(%[weight]), %%zmm30\n" + "vmovups 2944(%[weight]), %%zmm29\n" + "vmovups 3008(%[weight]), %%zmm28\n" + "vbroadcastss 44(%[src_0]), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 12 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vbroadcastss 48(%[src_0]), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 13 + "vmovups 3328(%[weight]), %%zmm31\n" + "vmovups 3392(%[weight]), %%zmm30\n" + "vmovups 3456(%[weight]), %%zmm29\n" + "vmovups 3520(%[weight]), %%zmm28\n" + "vbroadcastss 52(%[src_0]), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 14 + "vmovups 3584(%[weight]), %%zmm31\n" + "vmovups 3648(%[weight]), %%zmm30\n" + "vmovups 3712(%[weight]), %%zmm29\n" + "vmovups 3776(%[weight]), %%zmm28\n" + "vbroadcastss 56(%[src_0]), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 15 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vbroadcastss 60(%[src_0]), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "add $4096, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "add $256, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm10, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 192(%[dst_0], %[dst_stride], 2)\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x80_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x80_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..d614491b645ce503ba2fc4001cee097d722a2949 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x80_kernel_nhwc_fp32.c @@ -0,0 +1,583 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_3x80_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm12\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm13\n" + "vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm14\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 64(%[bias]), %%zmm6\n" + "vmovups 128(%[bias]), %%zmm7\n" + "vmovups 192(%[bias]), %%zmm8\n" + "vmovups 256(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 128(%[bias]), %%zmm12\n" + "vmovups 192(%[bias]), %%zmm13\n" + "vmovups 256(%[bias]), %%zmm14\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 1 + "vmovups 320(%[weight]), %%zmm31\n" + "vmovups 384(%[weight]), %%zmm30\n" + "vmovups 448(%[weight]), %%zmm29\n" + "vmovups 512(%[weight]), %%zmm28\n" + "vmovups 576(%[weight]), %%zmm27\n" + "vbroadcastss 4(%[src_0]), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 2 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vmovups 768(%[weight]), %%zmm29\n" + "vmovups 832(%[weight]), %%zmm28\n" + "vmovups 896(%[weight]), %%zmm27\n" + "vbroadcastss 8(%[src_0]), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 3 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vmovups 1152(%[weight]), %%zmm28\n" + "vmovups 1216(%[weight]), %%zmm27\n" + "vbroadcastss 12(%[src_0]), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 4 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vmovups 1536(%[weight]), %%zmm27\n" + "vbroadcastss 16(%[src_0]), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 5 + "vmovups 1600(%[weight]), %%zmm31\n" + "vmovups 1664(%[weight]), %%zmm30\n" + "vmovups 1728(%[weight]), %%zmm29\n" + "vmovups 1792(%[weight]), %%zmm28\n" + "vmovups 1856(%[weight]), %%zmm27\n" + "vbroadcastss 20(%[src_0]), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 6 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vbroadcastss 24(%[src_0]), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 7 + "vmovups 2240(%[weight]), %%zmm31\n" + "vmovups 2304(%[weight]), %%zmm30\n" + "vmovups 2368(%[weight]), %%zmm29\n" + "vmovups 2432(%[weight]), %%zmm28\n" + "vmovups 2496(%[weight]), %%zmm27\n" + "vbroadcastss 28(%[src_0]), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 8 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vmovups 2816(%[weight]), %%zmm27\n" + "vbroadcastss 32(%[src_0]), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 9 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vmovups 3072(%[weight]), %%zmm28\n" + "vmovups 3136(%[weight]), %%zmm27\n" + "vbroadcastss 36(%[src_0]), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 10 + "vmovups 3200(%[weight]), %%zmm31\n" + "vmovups 3264(%[weight]), %%zmm30\n" + "vmovups 3328(%[weight]), %%zmm29\n" + "vmovups 3392(%[weight]), %%zmm28\n" + "vmovups 3456(%[weight]), %%zmm27\n" + "vbroadcastss 40(%[src_0]), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 11 + "vmovups 3520(%[weight]), %%zmm31\n" + "vmovups 3584(%[weight]), %%zmm30\n" + "vmovups 3648(%[weight]), %%zmm29\n" + "vmovups 3712(%[weight]), %%zmm28\n" + "vmovups 3776(%[weight]), %%zmm27\n" + "vbroadcastss 44(%[src_0]), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 12 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vbroadcastss 48(%[src_0]), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 13 + "vmovups 4160(%[weight]), %%zmm31\n" + "vmovups 4224(%[weight]), %%zmm30\n" + "vmovups 4288(%[weight]), %%zmm29\n" + "vmovups 4352(%[weight]), %%zmm28\n" + "vmovups 4416(%[weight]), %%zmm27\n" + "vbroadcastss 52(%[src_0]), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 14 + "vmovups 4480(%[weight]), %%zmm31\n" + "vmovups 4544(%[weight]), %%zmm30\n" + "vmovups 4608(%[weight]), %%zmm29\n" + "vmovups 4672(%[weight]), %%zmm28\n" + "vmovups 4736(%[weight]), %%zmm27\n" + "vbroadcastss 56(%[src_0]), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + // block 15 + "vmovups 4800(%[weight]), %%zmm31\n" + "vmovups 4864(%[weight]), %%zmm30\n" + "vmovups 4928(%[weight]), %%zmm29\n" + "vmovups 4992(%[weight]), %%zmm28\n" + "vmovups 5056(%[weight]), %%zmm27\n" + "vbroadcastss 60(%[src_0]), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "add $5120, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "add $320, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm10, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm12, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm13, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm14, 256(%[dst_0], %[dst_stride], 2)\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x96_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x96_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..ad809aae05f7ef2b9235f0509a942fdb6c6163f5 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x96_kernel_nhwc_fp32.c @@ -0,0 +1,669 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_3x96_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 320(%[dst_0]), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm10\n" + "vmovups 320(%[dst_0], %[dst_stride], 1), %%zmm11\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm12\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm13\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm14\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm15\n" + "vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm16\n" + "vmovups 320(%[dst_0], %[dst_stride], 2), %%zmm17\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 320(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 192(%[bias]), %%zmm9\n" + "vmovups 256(%[bias]), %%zmm10\n" + "vmovups 320(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 192(%[bias]), %%zmm15\n" + "vmovups 256(%[bias]), %%zmm16\n" + "vmovups 320(%[bias]), %%zmm17\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 1 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vmovups 576(%[weight]), %%zmm28\n" + "vmovups 640(%[weight]), %%zmm27\n" + "vmovups 704(%[weight]), %%zmm26\n" + "vbroadcastss 4(%[src_0]), %%zmm25\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 2 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vmovups 1024(%[weight]), %%zmm27\n" + "vmovups 1088(%[weight]), %%zmm26\n" + "vbroadcastss 8(%[src_0]), %%zmm25\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 3 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vmovups 1344(%[weight]), %%zmm28\n" + "vmovups 1408(%[weight]), %%zmm27\n" + "vmovups 1472(%[weight]), %%zmm26\n" + "vbroadcastss 12(%[src_0]), %%zmm25\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 4 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vmovups 1792(%[weight]), %%zmm27\n" + "vmovups 1856(%[weight]), %%zmm26\n" + "vbroadcastss 16(%[src_0]), %%zmm25\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 5 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vmovups 2240(%[weight]), %%zmm26\n" + "vbroadcastss 20(%[src_0]), %%zmm25\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 6 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vmovups 2560(%[weight]), %%zmm27\n" + "vmovups 2624(%[weight]), %%zmm26\n" + "vbroadcastss 24(%[src_0]), %%zmm25\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 7 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vmovups 2880(%[weight]), %%zmm28\n" + "vmovups 2944(%[weight]), %%zmm27\n" + "vmovups 3008(%[weight]), %%zmm26\n" + "vbroadcastss 28(%[src_0]), %%zmm25\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 8 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vmovups 3328(%[weight]), %%zmm27\n" + "vmovups 3392(%[weight]), %%zmm26\n" + "vbroadcastss 32(%[src_0]), %%zmm25\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 9 + "vmovups 3456(%[weight]), %%zmm31\n" + "vmovups 3520(%[weight]), %%zmm30\n" + "vmovups 3584(%[weight]), %%zmm29\n" + "vmovups 3648(%[weight]), %%zmm28\n" + "vmovups 3712(%[weight]), %%zmm27\n" + "vmovups 3776(%[weight]), %%zmm26\n" + "vbroadcastss 36(%[src_0]), %%zmm25\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 10 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vmovups 4160(%[weight]), %%zmm26\n" + "vbroadcastss 40(%[src_0]), %%zmm25\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 11 + "vmovups 4224(%[weight]), %%zmm31\n" + "vmovups 4288(%[weight]), %%zmm30\n" + "vmovups 4352(%[weight]), %%zmm29\n" + "vmovups 4416(%[weight]), %%zmm28\n" + "vmovups 4480(%[weight]), %%zmm27\n" + "vmovups 4544(%[weight]), %%zmm26\n" + "vbroadcastss 44(%[src_0]), %%zmm25\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 12 + "vmovups 4608(%[weight]), %%zmm31\n" + "vmovups 4672(%[weight]), %%zmm30\n" + "vmovups 4736(%[weight]), %%zmm29\n" + "vmovups 4800(%[weight]), %%zmm28\n" + "vmovups 4864(%[weight]), %%zmm27\n" + "vmovups 4928(%[weight]), %%zmm26\n" + "vbroadcastss 48(%[src_0]), %%zmm25\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 13 + "vmovups 4992(%[weight]), %%zmm31\n" + "vmovups 5056(%[weight]), %%zmm30\n" + "vmovups 5120(%[weight]), %%zmm29\n" + "vmovups 5184(%[weight]), %%zmm28\n" + "vmovups 5248(%[weight]), %%zmm27\n" + "vmovups 5312(%[weight]), %%zmm26\n" + "vbroadcastss 52(%[src_0]), %%zmm25\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 14 + "vmovups 5376(%[weight]), %%zmm31\n" + "vmovups 5440(%[weight]), %%zmm30\n" + "vmovups 5504(%[weight]), %%zmm29\n" + "vmovups 5568(%[weight]), %%zmm28\n" + "vmovups 5632(%[weight]), %%zmm27\n" + "vmovups 5696(%[weight]), %%zmm26\n" + "vbroadcastss 56(%[src_0]), %%zmm25\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 15 + "vmovups 5760(%[weight]), %%zmm31\n" + "vmovups 5824(%[weight]), %%zmm30\n" + "vmovups 5888(%[weight]), %%zmm29\n" + "vmovups 5952(%[weight]), %%zmm28\n" + "vmovups 6016(%[weight]), %%zmm27\n" + "vmovups 6080(%[weight]), %%zmm26\n" + "vbroadcastss 60(%[src_0]), %%zmm25\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + "add $6144, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + "add $384, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 320(%[dst_0])\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm10, 256(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm11, 320(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm12, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm13, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm14, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm15, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm16, 256(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm17, 320(%[dst_0], %[dst_stride], 2)\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x16_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x16_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..12c14d616f17f05ff3f216c4148acff3d849165d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x16_kernel_nhwc_fp32.c @@ -0,0 +1,284 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_4x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm3, 0(%[dst_3])\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x32_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x32_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..838a6d42dbab0cd37fb1d6ba9f5ed2fdc7d3f89c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x32_kernel_nhwc_fp32.c @@ -0,0 +1,393 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_4x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3])\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x48_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x48_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..311963de9e0faf335cfaa323121f15053977dfe5 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x48_kernel_nhwc_fp32.c @@ -0,0 +1,502 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_4x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_3]), %%zmm9\n" + "vmovups 64(%[dst_3]), %%zmm10\n" + "vmovups 128(%[dst_3]), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 64(%[bias]), %%zmm4\n" + "vmovups 128(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "vmovups 64(%[bias]), %%zmm10\n" + "vmovups 128(%[bias]), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 4(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 8(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 12(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 16(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 20(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 24(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 28(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 32(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 36(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 40(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 44(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 48(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 52(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 56(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 60(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 0(%[dst_3])\n" + "vmovups %%zmm10, 64(%[dst_3])\n" + "vmovups %%zmm11, 128(%[dst_3])\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x64_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x64_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..324f93a080b629e9f88865e507b61233e193e749 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x64_kernel_nhwc_fp32.c @@ -0,0 +1,612 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_4x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm9\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_3]), %%zmm12\n" + "vmovups 64(%[dst_3]), %%zmm13\n" + "vmovups 128(%[dst_3]), %%zmm14\n" + "vmovups 192(%[dst_3]), %%zmm15\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 128(%[bias]), %%zmm6\n" + "vmovups 192(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 128(%[bias]), %%zmm10\n" + "vmovups 192(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 192(%[bias]), %%zmm15\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + "vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 8 + "vmovups 2048(%[weight]), %%zmm31\n" + "vmovups 2112(%[weight]), %%zmm30\n" + "vmovups 2176(%[weight]), %%zmm29\n" + "vmovups 2240(%[weight]), %%zmm28\n" + "vbroadcastss 32(%[src_0]), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 9 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vbroadcastss 36(%[src_0]), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 10 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vbroadcastss 40(%[src_0]), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 11 + "vmovups 2816(%[weight]), %%zmm31\n" + "vmovups 2880(%[weight]), %%zmm30\n" + "vmovups 2944(%[weight]), %%zmm29\n" + "vmovups 3008(%[weight]), %%zmm28\n" + "vbroadcastss 44(%[src_0]), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 12 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vbroadcastss 48(%[src_0]), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 13 + "vmovups 3328(%[weight]), %%zmm31\n" + "vmovups 3392(%[weight]), %%zmm30\n" + "vmovups 3456(%[weight]), %%zmm29\n" + "vmovups 3520(%[weight]), %%zmm28\n" + "vbroadcastss 52(%[src_0]), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 14 + "vmovups 3584(%[weight]), %%zmm31\n" + "vmovups 3648(%[weight]), %%zmm30\n" + "vmovups 3712(%[weight]), %%zmm29\n" + "vmovups 3776(%[weight]), %%zmm28\n" + "vbroadcastss 56(%[src_0]), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 15 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vbroadcastss 60(%[src_0]), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "add $4096, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "add $256, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm10, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm12, 0(%[dst_3])\n" + "vmovups %%zmm13, 64(%[dst_3])\n" + "vmovups %%zmm14, 128(%[dst_3])\n" + "vmovups %%zmm15, 192(%[dst_3])\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x80_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x80_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..c8b6277d02c96c7a5bd3479aedf469c68a369f73 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x80_kernel_nhwc_fp32.c @@ -0,0 +1,721 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_4x80_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm12\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm13\n" + "vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm14\n" + "vmovups 0(%[dst_3]), %%zmm15\n" + "vmovups 64(%[dst_3]), %%zmm16\n" + "vmovups 128(%[dst_3]), %%zmm17\n" + "vmovups 192(%[dst_3]), %%zmm18\n" + "vmovups 256(%[dst_3]), %%zmm19\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 64(%[bias]), %%zmm6\n" + "vmovups 128(%[bias]), %%zmm7\n" + "vmovups 192(%[bias]), %%zmm8\n" + "vmovups 256(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 128(%[bias]), %%zmm12\n" + "vmovups 192(%[bias]), %%zmm13\n" + "vmovups 256(%[bias]), %%zmm14\n" + "vmovups 0(%[bias]), %%zmm15\n" + "vmovups 64(%[bias]), %%zmm16\n" + "vmovups 128(%[bias]), %%zmm17\n" + "vmovups 192(%[bias]), %%zmm18\n" + "vmovups 256(%[bias]), %%zmm19\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 1 + "vmovups 320(%[weight]), %%zmm31\n" + "vmovups 384(%[weight]), %%zmm30\n" + "vmovups 448(%[weight]), %%zmm29\n" + "vmovups 512(%[weight]), %%zmm28\n" + "vmovups 576(%[weight]), %%zmm27\n" + "vbroadcastss 4(%[src_0]), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 4(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 2 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vmovups 768(%[weight]), %%zmm29\n" + "vmovups 832(%[weight]), %%zmm28\n" + "vmovups 896(%[weight]), %%zmm27\n" + "vbroadcastss 8(%[src_0]), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 8(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 3 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vmovups 1152(%[weight]), %%zmm28\n" + "vmovups 1216(%[weight]), %%zmm27\n" + "vbroadcastss 12(%[src_0]), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 12(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 4 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vmovups 1536(%[weight]), %%zmm27\n" + "vbroadcastss 16(%[src_0]), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 16(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 5 + "vmovups 1600(%[weight]), %%zmm31\n" + "vmovups 1664(%[weight]), %%zmm30\n" + "vmovups 1728(%[weight]), %%zmm29\n" + "vmovups 1792(%[weight]), %%zmm28\n" + "vmovups 1856(%[weight]), %%zmm27\n" + "vbroadcastss 20(%[src_0]), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 20(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 6 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vbroadcastss 24(%[src_0]), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 24(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 7 + "vmovups 2240(%[weight]), %%zmm31\n" + "vmovups 2304(%[weight]), %%zmm30\n" + "vmovups 2368(%[weight]), %%zmm29\n" + "vmovups 2432(%[weight]), %%zmm28\n" + "vmovups 2496(%[weight]), %%zmm27\n" + "vbroadcastss 28(%[src_0]), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 28(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 8 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vmovups 2816(%[weight]), %%zmm27\n" + "vbroadcastss 32(%[src_0]), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 32(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 9 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vmovups 3072(%[weight]), %%zmm28\n" + "vmovups 3136(%[weight]), %%zmm27\n" + "vbroadcastss 36(%[src_0]), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 36(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 10 + "vmovups 3200(%[weight]), %%zmm31\n" + "vmovups 3264(%[weight]), %%zmm30\n" + "vmovups 3328(%[weight]), %%zmm29\n" + "vmovups 3392(%[weight]), %%zmm28\n" + "vmovups 3456(%[weight]), %%zmm27\n" + "vbroadcastss 40(%[src_0]), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 40(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 11 + "vmovups 3520(%[weight]), %%zmm31\n" + "vmovups 3584(%[weight]), %%zmm30\n" + "vmovups 3648(%[weight]), %%zmm29\n" + "vmovups 3712(%[weight]), %%zmm28\n" + "vmovups 3776(%[weight]), %%zmm27\n" + "vbroadcastss 44(%[src_0]), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 44(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 12 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vbroadcastss 48(%[src_0]), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 48(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 13 + "vmovups 4160(%[weight]), %%zmm31\n" + "vmovups 4224(%[weight]), %%zmm30\n" + "vmovups 4288(%[weight]), %%zmm29\n" + "vmovups 4352(%[weight]), %%zmm28\n" + "vmovups 4416(%[weight]), %%zmm27\n" + "vbroadcastss 52(%[src_0]), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 52(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 14 + "vmovups 4480(%[weight]), %%zmm31\n" + "vmovups 4544(%[weight]), %%zmm30\n" + "vmovups 4608(%[weight]), %%zmm29\n" + "vmovups 4672(%[weight]), %%zmm28\n" + "vmovups 4736(%[weight]), %%zmm27\n" + "vbroadcastss 56(%[src_0]), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 56(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 15 + "vmovups 4800(%[weight]), %%zmm31\n" + "vmovups 4864(%[weight]), %%zmm30\n" + "vmovups 4928(%[weight]), %%zmm29\n" + "vmovups 4992(%[weight]), %%zmm28\n" + "vmovups 5056(%[weight]), %%zmm27\n" + "vbroadcastss 60(%[src_0]), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 60(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + "add $5120, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + "add $320, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm10, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm12, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm13, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm14, 256(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm15, 0(%[dst_3])\n" + "vmovups %%zmm16, 64(%[dst_3])\n" + "vmovups %%zmm17, 128(%[dst_3])\n" + "vmovups %%zmm18, 192(%[dst_3])\n" + "vmovups %%zmm19, 256(%[dst_3])\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x96_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x96_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..479f97612ccb840d2f68a9309a0e6ddbd57d7b5c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x96_kernel_nhwc_fp32.c @@ -0,0 +1,831 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_4x96_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 320(%[dst_0]), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm10\n" + "vmovups 320(%[dst_0], %[dst_stride], 1), %%zmm11\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm12\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm13\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm14\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm15\n" + "vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm16\n" + "vmovups 320(%[dst_0], %[dst_stride], 2), %%zmm17\n" + "vmovups 0(%[dst_3]), %%zmm18\n" + "vmovups 64(%[dst_3]), %%zmm19\n" + "vmovups 128(%[dst_3]), %%zmm20\n" + "vmovups 192(%[dst_3]), %%zmm21\n" + "vmovups 256(%[dst_3]), %%zmm22\n" + "vmovups 320(%[dst_3]), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 320(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 192(%[bias]), %%zmm9\n" + "vmovups 256(%[bias]), %%zmm10\n" + "vmovups 320(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 192(%[bias]), %%zmm15\n" + "vmovups 256(%[bias]), %%zmm16\n" + "vmovups 320(%[bias]), %%zmm17\n" + "vmovups 0(%[bias]), %%zmm18\n" + "vmovups 64(%[bias]), %%zmm19\n" + "vmovups 128(%[bias]), %%zmm20\n" + "vmovups 192(%[bias]), %%zmm21\n" + "vmovups 256(%[bias]), %%zmm22\n" + "vmovups 320(%[bias]), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + "vxorps %%zmm21, %%zmm21, %%zmm21\n" + "vxorps %%zmm22, %%zmm22, %%zmm22\n" + "vxorps %%zmm23, %%zmm23, %%zmm23\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 1 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vmovups 576(%[weight]), %%zmm28\n" + "vmovups 640(%[weight]), %%zmm27\n" + "vmovups 704(%[weight]), %%zmm26\n" + "vbroadcastss 4(%[src_0]), %%zmm25\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 2 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vmovups 1024(%[weight]), %%zmm27\n" + "vmovups 1088(%[weight]), %%zmm26\n" + "vbroadcastss 8(%[src_0]), %%zmm25\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 3 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vmovups 1344(%[weight]), %%zmm28\n" + "vmovups 1408(%[weight]), %%zmm27\n" + "vmovups 1472(%[weight]), %%zmm26\n" + "vbroadcastss 12(%[src_0]), %%zmm25\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 4 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vmovups 1792(%[weight]), %%zmm27\n" + "vmovups 1856(%[weight]), %%zmm26\n" + "vbroadcastss 16(%[src_0]), %%zmm25\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 5 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vmovups 2240(%[weight]), %%zmm26\n" + "vbroadcastss 20(%[src_0]), %%zmm25\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 6 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vmovups 2560(%[weight]), %%zmm27\n" + "vmovups 2624(%[weight]), %%zmm26\n" + "vbroadcastss 24(%[src_0]), %%zmm25\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 7 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vmovups 2880(%[weight]), %%zmm28\n" + "vmovups 2944(%[weight]), %%zmm27\n" + "vmovups 3008(%[weight]), %%zmm26\n" + "vbroadcastss 28(%[src_0]), %%zmm25\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 8 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vmovups 3328(%[weight]), %%zmm27\n" + "vmovups 3392(%[weight]), %%zmm26\n" + "vbroadcastss 32(%[src_0]), %%zmm25\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 9 + "vmovups 3456(%[weight]), %%zmm31\n" + "vmovups 3520(%[weight]), %%zmm30\n" + "vmovups 3584(%[weight]), %%zmm29\n" + "vmovups 3648(%[weight]), %%zmm28\n" + "vmovups 3712(%[weight]), %%zmm27\n" + "vmovups 3776(%[weight]), %%zmm26\n" + "vbroadcastss 36(%[src_0]), %%zmm25\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 10 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vmovups 4160(%[weight]), %%zmm26\n" + "vbroadcastss 40(%[src_0]), %%zmm25\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 11 + "vmovups 4224(%[weight]), %%zmm31\n" + "vmovups 4288(%[weight]), %%zmm30\n" + "vmovups 4352(%[weight]), %%zmm29\n" + "vmovups 4416(%[weight]), %%zmm28\n" + "vmovups 4480(%[weight]), %%zmm27\n" + "vmovups 4544(%[weight]), %%zmm26\n" + "vbroadcastss 44(%[src_0]), %%zmm25\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 12 + "vmovups 4608(%[weight]), %%zmm31\n" + "vmovups 4672(%[weight]), %%zmm30\n" + "vmovups 4736(%[weight]), %%zmm29\n" + "vmovups 4800(%[weight]), %%zmm28\n" + "vmovups 4864(%[weight]), %%zmm27\n" + "vmovups 4928(%[weight]), %%zmm26\n" + "vbroadcastss 48(%[src_0]), %%zmm25\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 13 + "vmovups 4992(%[weight]), %%zmm31\n" + "vmovups 5056(%[weight]), %%zmm30\n" + "vmovups 5120(%[weight]), %%zmm29\n" + "vmovups 5184(%[weight]), %%zmm28\n" + "vmovups 5248(%[weight]), %%zmm27\n" + "vmovups 5312(%[weight]), %%zmm26\n" + "vbroadcastss 52(%[src_0]), %%zmm25\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 14 + "vmovups 5376(%[weight]), %%zmm31\n" + "vmovups 5440(%[weight]), %%zmm30\n" + "vmovups 5504(%[weight]), %%zmm29\n" + "vmovups 5568(%[weight]), %%zmm28\n" + "vmovups 5632(%[weight]), %%zmm27\n" + "vmovups 5696(%[weight]), %%zmm26\n" + "vbroadcastss 56(%[src_0]), %%zmm25\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 15 + "vmovups 5760(%[weight]), %%zmm31\n" + "vmovups 5824(%[weight]), %%zmm30\n" + "vmovups 5888(%[weight]), %%zmm29\n" + "vmovups 5952(%[weight]), %%zmm28\n" + "vmovups 6016(%[weight]), %%zmm27\n" + "vmovups 6080(%[weight]), %%zmm26\n" + "vbroadcastss 60(%[src_0]), %%zmm25\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + "add $6144, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + "add $384, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20\n" + "vmaxps %%zmm21, %%zmm31, %%zmm21\n" + "vmaxps %%zmm22, %%zmm31, %%zmm22\n" + "vmaxps %%zmm23, %%zmm31, %%zmm23\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + "vminps %%zmm20, %%zmm30, %%zmm20\n" + "vminps %%zmm21, %%zmm30, %%zmm21\n" + "vminps %%zmm22, %%zmm30, %%zmm22\n" + "vminps %%zmm23, %%zmm30, %%zmm23\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 320(%[dst_0])\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm10, 256(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm11, 320(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm12, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm13, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm14, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm15, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm16, 256(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm17, 320(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm18, 0(%[dst_3])\n" + "vmovups %%zmm19, 64(%[dst_3])\n" + "vmovups %%zmm20, 128(%[dst_3])\n" + "vmovups %%zmm21, 192(%[dst_3])\n" + "vmovups %%zmm22, 256(%[dst_3])\n" + "vmovups %%zmm23, 320(%[dst_3])\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x16_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x16_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..3d4cab9d62e052a18271eec477fad35504fa6db6 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x16_kernel_nhwc_fp32.c @@ -0,0 +1,324 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_5x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm3, 0(%[dst_3])\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1)\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x32_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x32_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..2a90d4d0869abca9aed20ba47e37847f3a1f7c2a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x32_kernel_nhwc_fp32.c @@ -0,0 +1,456 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_5x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3])\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1)\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x48_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x48_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..86f5c743695ad15578ef761ccf14b37dac9ed882 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x48_kernel_nhwc_fp32.c @@ -0,0 +1,589 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_5x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_3]), %%zmm9\n" + "vmovups 64(%[dst_3]), %%zmm10\n" + "vmovups 128(%[dst_3]), %%zmm11\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm12\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm13\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm14\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 64(%[bias]), %%zmm4\n" + "vmovups 128(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "vmovups 64(%[bias]), %%zmm10\n" + "vmovups 128(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 4(%[src_3]), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 8(%[src_3]), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 12(%[src_3]), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 16(%[src_3]), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 20(%[src_3]), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 24(%[src_3]), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 28(%[src_3]), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 32(%[src_3]), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 36(%[src_3]), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 40(%[src_3]), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 44(%[src_3]), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 48(%[src_3]), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 52(%[src_3]), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 56(%[src_3]), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 60(%[src_3]), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 0(%[dst_3])\n" + "vmovups %%zmm10, 64(%[dst_3])\n" + "vmovups %%zmm11, 128(%[dst_3])\n" + "vmovups %%zmm12, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm13, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm14, 128(%[dst_3], %[dst_stride], 1)\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x64_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x64_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..cb14ec7a9413e1a1704aa5e84b988e86c894c7dc --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x64_kernel_nhwc_fp32.c @@ -0,0 +1,721 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_5x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm9\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_3]), %%zmm12\n" + "vmovups 64(%[dst_3]), %%zmm13\n" + "vmovups 128(%[dst_3]), %%zmm14\n" + "vmovups 192(%[dst_3]), %%zmm15\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm16\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm17\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm18\n" + "vmovups 192(%[dst_3], %[dst_stride], 1), %%zmm19\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 128(%[bias]), %%zmm6\n" + "vmovups 192(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 128(%[bias]), %%zmm10\n" + "vmovups 192(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 192(%[bias]), %%zmm15\n" + "vmovups 0(%[bias]), %%zmm16\n" + "vmovups 64(%[bias]), %%zmm17\n" + "vmovups 128(%[bias]), %%zmm18\n" + "vmovups 192(%[bias]), %%zmm19\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + "vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_3]), %%zmm24\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_3]), %%zmm24\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_3]), %%zmm24\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_3]), %%zmm24\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_3]), %%zmm24\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_3]), %%zmm24\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_3]), %%zmm24\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 8 + "vmovups 2048(%[weight]), %%zmm31\n" + "vmovups 2112(%[weight]), %%zmm30\n" + "vmovups 2176(%[weight]), %%zmm29\n" + "vmovups 2240(%[weight]), %%zmm28\n" + "vbroadcastss 32(%[src_0]), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_3]), %%zmm24\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 9 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vbroadcastss 36(%[src_0]), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_3]), %%zmm24\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 10 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vbroadcastss 40(%[src_0]), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_3]), %%zmm24\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 11 + "vmovups 2816(%[weight]), %%zmm31\n" + "vmovups 2880(%[weight]), %%zmm30\n" + "vmovups 2944(%[weight]), %%zmm29\n" + "vmovups 3008(%[weight]), %%zmm28\n" + "vbroadcastss 44(%[src_0]), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_3]), %%zmm24\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 12 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vbroadcastss 48(%[src_0]), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_3]), %%zmm24\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 13 + "vmovups 3328(%[weight]), %%zmm31\n" + "vmovups 3392(%[weight]), %%zmm30\n" + "vmovups 3456(%[weight]), %%zmm29\n" + "vmovups 3520(%[weight]), %%zmm28\n" + "vbroadcastss 52(%[src_0]), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_3]), %%zmm24\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 14 + "vmovups 3584(%[weight]), %%zmm31\n" + "vmovups 3648(%[weight]), %%zmm30\n" + "vmovups 3712(%[weight]), %%zmm29\n" + "vmovups 3776(%[weight]), %%zmm28\n" + "vbroadcastss 56(%[src_0]), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_3]), %%zmm24\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 15 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vbroadcastss 60(%[src_0]), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_3]), %%zmm24\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + "add $4096, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + "add $256, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm10, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm12, 0(%[dst_3])\n" + "vmovups %%zmm13, 64(%[dst_3])\n" + "vmovups %%zmm14, 128(%[dst_3])\n" + "vmovups %%zmm15, 192(%[dst_3])\n" + "vmovups %%zmm16, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm17, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm18, 128(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm19, 192(%[dst_3], %[dst_stride], 1)\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x80_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x80_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..6ad12587badf38676208870136be07ce38492aa1 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x80_kernel_nhwc_fp32.c @@ -0,0 +1,854 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_5x80_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm12\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm13\n" + "vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm14\n" + "vmovups 0(%[dst_3]), %%zmm15\n" + "vmovups 64(%[dst_3]), %%zmm16\n" + "vmovups 128(%[dst_3]), %%zmm17\n" + "vmovups 192(%[dst_3]), %%zmm18\n" + "vmovups 256(%[dst_3]), %%zmm19\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm20\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm21\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm22\n" + "vmovups 192(%[dst_3], %[dst_stride], 1), %%zmm23\n" + "vmovups 256(%[dst_3], %[dst_stride], 1), %%zmm24\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 64(%[bias]), %%zmm6\n" + "vmovups 128(%[bias]), %%zmm7\n" + "vmovups 192(%[bias]), %%zmm8\n" + "vmovups 256(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 128(%[bias]), %%zmm12\n" + "vmovups 192(%[bias]), %%zmm13\n" + "vmovups 256(%[bias]), %%zmm14\n" + "vmovups 0(%[bias]), %%zmm15\n" + "vmovups 64(%[bias]), %%zmm16\n" + "vmovups 128(%[bias]), %%zmm17\n" + "vmovups 192(%[bias]), %%zmm18\n" + "vmovups 256(%[bias]), %%zmm19\n" + "vmovups 0(%[bias]), %%zmm20\n" + "vmovups 64(%[bias]), %%zmm21\n" + "vmovups 128(%[bias]), %%zmm22\n" + "vmovups 192(%[bias]), %%zmm23\n" + "vmovups 256(%[bias]), %%zmm24\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + "vxorps %%zmm21, %%zmm21, %%zmm21\n" + "vxorps %%zmm22, %%zmm22, %%zmm22\n" + "vxorps %%zmm23, %%zmm23, %%zmm23\n" + "vxorps %%zmm24, %%zmm24, %%zmm24\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 1 + "vmovups 320(%[weight]), %%zmm31\n" + "vmovups 384(%[weight]), %%zmm30\n" + "vmovups 448(%[weight]), %%zmm29\n" + "vmovups 512(%[weight]), %%zmm28\n" + "vmovups 576(%[weight]), %%zmm27\n" + "vbroadcastss 4(%[src_0]), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 4(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 2 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vmovups 768(%[weight]), %%zmm29\n" + "vmovups 832(%[weight]), %%zmm28\n" + "vmovups 896(%[weight]), %%zmm27\n" + "vbroadcastss 8(%[src_0]), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 8(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 3 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vmovups 1152(%[weight]), %%zmm28\n" + "vmovups 1216(%[weight]), %%zmm27\n" + "vbroadcastss 12(%[src_0]), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 12(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 4 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vmovups 1536(%[weight]), %%zmm27\n" + "vbroadcastss 16(%[src_0]), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 16(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 5 + "vmovups 1600(%[weight]), %%zmm31\n" + "vmovups 1664(%[weight]), %%zmm30\n" + "vmovups 1728(%[weight]), %%zmm29\n" + "vmovups 1792(%[weight]), %%zmm28\n" + "vmovups 1856(%[weight]), %%zmm27\n" + "vbroadcastss 20(%[src_0]), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 20(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 6 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vbroadcastss 24(%[src_0]), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 24(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 7 + "vmovups 2240(%[weight]), %%zmm31\n" + "vmovups 2304(%[weight]), %%zmm30\n" + "vmovups 2368(%[weight]), %%zmm29\n" + "vmovups 2432(%[weight]), %%zmm28\n" + "vmovups 2496(%[weight]), %%zmm27\n" + "vbroadcastss 28(%[src_0]), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 28(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 8 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vmovups 2816(%[weight]), %%zmm27\n" + "vbroadcastss 32(%[src_0]), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 32(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 9 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vmovups 3072(%[weight]), %%zmm28\n" + "vmovups 3136(%[weight]), %%zmm27\n" + "vbroadcastss 36(%[src_0]), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 36(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 10 + "vmovups 3200(%[weight]), %%zmm31\n" + "vmovups 3264(%[weight]), %%zmm30\n" + "vmovups 3328(%[weight]), %%zmm29\n" + "vmovups 3392(%[weight]), %%zmm28\n" + "vmovups 3456(%[weight]), %%zmm27\n" + "vbroadcastss 40(%[src_0]), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 40(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 11 + "vmovups 3520(%[weight]), %%zmm31\n" + "vmovups 3584(%[weight]), %%zmm30\n" + "vmovups 3648(%[weight]), %%zmm29\n" + "vmovups 3712(%[weight]), %%zmm28\n" + "vmovups 3776(%[weight]), %%zmm27\n" + "vbroadcastss 44(%[src_0]), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 44(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 12 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vbroadcastss 48(%[src_0]), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 48(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 13 + "vmovups 4160(%[weight]), %%zmm31\n" + "vmovups 4224(%[weight]), %%zmm30\n" + "vmovups 4288(%[weight]), %%zmm29\n" + "vmovups 4352(%[weight]), %%zmm28\n" + "vmovups 4416(%[weight]), %%zmm27\n" + "vbroadcastss 52(%[src_0]), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 52(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 14 + "vmovups 4480(%[weight]), %%zmm31\n" + "vmovups 4544(%[weight]), %%zmm30\n" + "vmovups 4608(%[weight]), %%zmm29\n" + "vmovups 4672(%[weight]), %%zmm28\n" + "vmovups 4736(%[weight]), %%zmm27\n" + "vbroadcastss 56(%[src_0]), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 56(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + // block 15 + "vmovups 4800(%[weight]), %%zmm31\n" + "vmovups 4864(%[weight]), %%zmm30\n" + "vmovups 4928(%[weight]), %%zmm29\n" + "vmovups 4992(%[weight]), %%zmm28\n" + "vmovups 5056(%[weight]), %%zmm27\n" + "vbroadcastss 60(%[src_0]), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 60(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + "add $5120, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" + "add $320, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20\n" + "vmaxps %%zmm21, %%zmm31, %%zmm21\n" + "vmaxps %%zmm22, %%zmm31, %%zmm22\n" + "vmaxps %%zmm23, %%zmm31, %%zmm23\n" + "vmaxps %%zmm24, %%zmm31, %%zmm24\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + "vminps %%zmm20, %%zmm30, %%zmm20\n" + "vminps %%zmm21, %%zmm30, %%zmm21\n" + "vminps %%zmm22, %%zmm30, %%zmm22\n" + "vminps %%zmm23, %%zmm30, %%zmm23\n" + "vminps %%zmm24, %%zmm30, %%zmm24\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm10, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm12, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm13, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm14, 256(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm15, 0(%[dst_3])\n" + "vmovups %%zmm16, 64(%[dst_3])\n" + "vmovups %%zmm17, 128(%[dst_3])\n" + "vmovups %%zmm18, 192(%[dst_3])\n" + "vmovups %%zmm19, 256(%[dst_3])\n" + "vmovups %%zmm20, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm21, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm22, 128(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm23, 192(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm24, 256(%[dst_3], %[dst_stride], 1)\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x16_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x16_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..8a1e87a7396cf284388c932facb92c06e922aa50 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x16_kernel_nhwc_fp32.c @@ -0,0 +1,364 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_6x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm3, 0(%[dst_3])\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2)\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x32_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x32_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..24144132b7f10d8bda9c282e027b6712817f5629 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x32_kernel_nhwc_fp32.c @@ -0,0 +1,519 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_6x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3])\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2)\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x48_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x48_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..b2e808ead844081d9940b09c8662d516e742e67e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x48_kernel_nhwc_fp32.c @@ -0,0 +1,675 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_6x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_3]), %%zmm9\n" + "vmovups 64(%[dst_3]), %%zmm10\n" + "vmovups 128(%[dst_3]), %%zmm11\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm12\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm13\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm14\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm15\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm16\n" + "vmovups 128(%[dst_3], %[dst_stride], 2), %%zmm17\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 64(%[bias]), %%zmm4\n" + "vmovups 128(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "vmovups 64(%[bias]), %%zmm10\n" + "vmovups 128(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 0(%[bias]), %%zmm15\n" + "vmovups 64(%[bias]), %%zmm16\n" + "vmovups 128(%[bias]), %%zmm17\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 4(%[src_3]), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 8(%[src_3]), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 12(%[src_3]), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 16(%[src_3]), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 20(%[src_3]), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 24(%[src_3]), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 28(%[src_3]), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 32(%[src_3]), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 36(%[src_3]), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 40(%[src_3]), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 44(%[src_3]), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 48(%[src_3]), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 52(%[src_3]), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 56(%[src_3]), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 60(%[src_3]), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 0(%[dst_3])\n" + "vmovups %%zmm10, 64(%[dst_3])\n" + "vmovups %%zmm11, 128(%[dst_3])\n" + "vmovups %%zmm12, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm13, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm14, 128(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm15, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm16, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm17, 128(%[dst_3], %[dst_stride], 2)\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x64_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x64_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..eadabcf145bf9c8d9b947d0bb3ac7ed44b4bd137 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x64_kernel_nhwc_fp32.c @@ -0,0 +1,831 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_6x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm9\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_3]), %%zmm12\n" + "vmovups 64(%[dst_3]), %%zmm13\n" + "vmovups 128(%[dst_3]), %%zmm14\n" + "vmovups 192(%[dst_3]), %%zmm15\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm16\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm17\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm18\n" + "vmovups 192(%[dst_3], %[dst_stride], 1), %%zmm19\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm20\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm21\n" + "vmovups 128(%[dst_3], %[dst_stride], 2), %%zmm22\n" + "vmovups 192(%[dst_3], %[dst_stride], 2), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 128(%[bias]), %%zmm6\n" + "vmovups 192(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 128(%[bias]), %%zmm10\n" + "vmovups 192(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 192(%[bias]), %%zmm15\n" + "vmovups 0(%[bias]), %%zmm16\n" + "vmovups 64(%[bias]), %%zmm17\n" + "vmovups 128(%[bias]), %%zmm18\n" + "vmovups 192(%[bias]), %%zmm19\n" + "vmovups 0(%[bias]), %%zmm20\n" + "vmovups 64(%[bias]), %%zmm21\n" + "vmovups 128(%[bias]), %%zmm22\n" + "vmovups 192(%[bias]), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + "vxorps %%zmm21, %%zmm21, %%zmm21\n" + "vxorps %%zmm22, %%zmm22, %%zmm22\n" + "vxorps %%zmm23, %%zmm23, %%zmm23\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + "vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 8 + "vmovups 2048(%[weight]), %%zmm31\n" + "vmovups 2112(%[weight]), %%zmm30\n" + "vmovups 2176(%[weight]), %%zmm29\n" + "vmovups 2240(%[weight]), %%zmm28\n" + "vbroadcastss 32(%[src_0]), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 9 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vbroadcastss 36(%[src_0]), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 10 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vbroadcastss 40(%[src_0]), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 11 + "vmovups 2816(%[weight]), %%zmm31\n" + "vmovups 2880(%[weight]), %%zmm30\n" + "vmovups 2944(%[weight]), %%zmm29\n" + "vmovups 3008(%[weight]), %%zmm28\n" + "vbroadcastss 44(%[src_0]), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 12 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vbroadcastss 48(%[src_0]), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 13 + "vmovups 3328(%[weight]), %%zmm31\n" + "vmovups 3392(%[weight]), %%zmm30\n" + "vmovups 3456(%[weight]), %%zmm29\n" + "vmovups 3520(%[weight]), %%zmm28\n" + "vbroadcastss 52(%[src_0]), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 14 + "vmovups 3584(%[weight]), %%zmm31\n" + "vmovups 3648(%[weight]), %%zmm30\n" + "vmovups 3712(%[weight]), %%zmm29\n" + "vmovups 3776(%[weight]), %%zmm28\n" + "vbroadcastss 56(%[src_0]), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 15 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vbroadcastss 60(%[src_0]), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "add $4096, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "add $256, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20\n" + "vmaxps %%zmm21, %%zmm31, %%zmm21\n" + "vmaxps %%zmm22, %%zmm31, %%zmm22\n" + "vmaxps %%zmm23, %%zmm31, %%zmm23\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + "vminps %%zmm20, %%zmm30, %%zmm20\n" + "vminps %%zmm21, %%zmm30, %%zmm21\n" + "vminps %%zmm22, %%zmm30, %%zmm22\n" + "vminps %%zmm23, %%zmm30, %%zmm23\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm10, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm12, 0(%[dst_3])\n" + "vmovups %%zmm13, 64(%[dst_3])\n" + "vmovups %%zmm14, 128(%[dst_3])\n" + "vmovups %%zmm15, 192(%[dst_3])\n" + "vmovups %%zmm16, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm17, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm18, 128(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm19, 192(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm20, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm21, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm22, 128(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm23, 192(%[dst_3], %[dst_stride], 2)\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x16_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x16_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..4f18ba52dce49fefa4afdf729d21f7e8aa4277b7 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x16_kernel_nhwc_fp32.c @@ -0,0 +1,408 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_7x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_6]), %%zmm6\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm3, 0(%[dst_3])\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_6])\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x32_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x32_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..90acd6dcc374301241ba22990f197cd8fae710e6 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x32_kernel_nhwc_fp32.c @@ -0,0 +1,587 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_7x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_6]), %%zmm12\n" + "vmovups 64(%[dst_6]), %%zmm13\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 4(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 8(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 12(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 16(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 20(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 24(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 28(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 32(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 36(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 40(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 44(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 48(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 52(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 56(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 60(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3])\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm12, 0(%[dst_6])\n" + "vmovups %%zmm13, 64(%[dst_6])\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x48_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x48_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..4c104d418d0facc34bf95f806aa3642e83c8d979 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x48_kernel_nhwc_fp32.c @@ -0,0 +1,765 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_7x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_3]), %%zmm9\n" + "vmovups 64(%[dst_3]), %%zmm10\n" + "vmovups 128(%[dst_3]), %%zmm11\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm12\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm13\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm14\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm15\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm16\n" + "vmovups 128(%[dst_3], %[dst_stride], 2), %%zmm17\n" + "vmovups 0(%[dst_6]), %%zmm18\n" + "vmovups 64(%[dst_6]), %%zmm19\n" + "vmovups 128(%[dst_6]), %%zmm20\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 64(%[bias]), %%zmm4\n" + "vmovups 128(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "vmovups 64(%[bias]), %%zmm10\n" + "vmovups 128(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 0(%[bias]), %%zmm15\n" + "vmovups 64(%[bias]), %%zmm16\n" + "vmovups 128(%[bias]), %%zmm17\n" + "vmovups 0(%[bias]), %%zmm18\n" + "vmovups 64(%[bias]), %%zmm19\n" + "vmovups 128(%[bias]), %%zmm20\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 0(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 4(%[src_3]), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 4(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 8(%[src_3]), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 8(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 12(%[src_3]), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 12(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 16(%[src_3]), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 16(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 20(%[src_3]), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 20(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 24(%[src_3]), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 24(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 28(%[src_3]), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 28(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 32(%[src_3]), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 32(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 36(%[src_3]), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 36(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 40(%[src_3]), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 40(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 44(%[src_3]), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 44(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 48(%[src_3]), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 48(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 52(%[src_3]), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 52(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 56(%[src_3]), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 56(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 60(%[src_3]), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 60(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 0(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + "vminps %%zmm20, %%zmm30, %%zmm20\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 0(%[dst_3])\n" + "vmovups %%zmm10, 64(%[dst_3])\n" + "vmovups %%zmm11, 128(%[dst_3])\n" + "vmovups %%zmm12, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm13, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm14, 128(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm15, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm16, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm17, 128(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm18, 0(%[dst_6])\n" + "vmovups %%zmm19, 64(%[dst_6])\n" + "vmovups %%zmm20, 128(%[dst_6])\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x16_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x16_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..1ef3d11e97418df4667e1a071d1df20ff12d50d2 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x16_kernel_nhwc_fp32.c @@ -0,0 +1,448 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_8x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_6]), %%zmm6\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm7\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 0(%[bias]), %%zmm7\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_6]), %%zmm24\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_6]), %%zmm24\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_6]), %%zmm24\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_6]), %%zmm24\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_6]), %%zmm24\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_6]), %%zmm24\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_6]), %%zmm24\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_6]), %%zmm24\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_6]), %%zmm24\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_6]), %%zmm24\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_6]), %%zmm24\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_6]), %%zmm24\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_6]), %%zmm24\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_6]), %%zmm24\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_6]), %%zmm24\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm3, 0(%[dst_3])\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_6])\n" + "vmovups %%zmm7, 0(%[dst_6], %[dst_stride], 1)\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x32_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x32_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..c1e942813a69e13ab376046ba358a8081af70e30 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x32_kernel_nhwc_fp32.c @@ -0,0 +1,650 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_8x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_6]), %%zmm12\n" + "vmovups 64(%[dst_6]), %%zmm13\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm14\n" + "vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm15\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 0(%[bias]), %%zmm14\n" + "vmovups 64(%[bias]), %%zmm15\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 4(%[src_6]), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 8(%[src_6]), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 12(%[src_6]), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 16(%[src_6]), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 20(%[src_6]), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 24(%[src_6]), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 28(%[src_6]), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 32(%[src_6]), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 36(%[src_6]), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 40(%[src_6]), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 44(%[src_6]), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 48(%[src_6]), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 52(%[src_6]), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 56(%[src_6]), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 60(%[src_6]), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3])\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm12, 0(%[dst_6])\n" + "vmovups %%zmm13, 64(%[dst_6])\n" + "vmovups %%zmm14, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm15, 64(%[dst_6], %[dst_stride], 1)\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x48_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x48_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..47b37a774cb0ab515fa7fab46e0079485c15a83c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x48_kernel_nhwc_fp32.c @@ -0,0 +1,852 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_8x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_3]), %%zmm9\n" + "vmovups 64(%[dst_3]), %%zmm10\n" + "vmovups 128(%[dst_3]), %%zmm11\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm12\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm13\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm14\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm15\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm16\n" + "vmovups 128(%[dst_3], %[dst_stride], 2), %%zmm17\n" + "vmovups 0(%[dst_6]), %%zmm18\n" + "vmovups 64(%[dst_6]), %%zmm19\n" + "vmovups 128(%[dst_6]), %%zmm20\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm21\n" + "vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm22\n" + "vmovups 128(%[dst_6], %[dst_stride], 1), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 64(%[bias]), %%zmm4\n" + "vmovups 128(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "vmovups 64(%[bias]), %%zmm10\n" + "vmovups 128(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 0(%[bias]), %%zmm15\n" + "vmovups 64(%[bias]), %%zmm16\n" + "vmovups 128(%[bias]), %%zmm17\n" + "vmovups 0(%[bias]), %%zmm18\n" + "vmovups 64(%[bias]), %%zmm19\n" + "vmovups 128(%[bias]), %%zmm20\n" + "vmovups 0(%[bias]), %%zmm21\n" + "vmovups 64(%[bias]), %%zmm22\n" + "vmovups 128(%[bias]), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + "vxorps %%zmm21, %%zmm21, %%zmm21\n" + "vxorps %%zmm22, %%zmm22, %%zmm22\n" + "vxorps %%zmm23, %%zmm23, %%zmm23\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_6]), %%zmm27\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 4(%[src_3]), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_6]), %%zmm27\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 8(%[src_3]), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_6]), %%zmm27\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 12(%[src_3]), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_6]), %%zmm27\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 16(%[src_3]), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_6]), %%zmm27\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 20(%[src_3]), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_6]), %%zmm27\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 24(%[src_3]), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_6]), %%zmm27\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 28(%[src_3]), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_6]), %%zmm27\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 32(%[src_3]), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_6]), %%zmm27\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 36(%[src_3]), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_6]), %%zmm27\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 40(%[src_3]), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_6]), %%zmm27\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 44(%[src_3]), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_6]), %%zmm27\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 48(%[src_3]), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_6]), %%zmm27\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 52(%[src_3]), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_6]), %%zmm27\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 56(%[src_3]), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_6]), %%zmm27\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 60(%[src_3]), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_6]), %%zmm27\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_6]), %%zmm27\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20\n" + "vmaxps %%zmm21, %%zmm31, %%zmm21\n" + "vmaxps %%zmm22, %%zmm31, %%zmm22\n" + "vmaxps %%zmm23, %%zmm31, %%zmm23\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + "vminps %%zmm20, %%zmm30, %%zmm20\n" + "vminps %%zmm21, %%zmm30, %%zmm21\n" + "vminps %%zmm22, %%zmm30, %%zmm22\n" + "vminps %%zmm23, %%zmm30, %%zmm23\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 0(%[dst_3])\n" + "vmovups %%zmm10, 64(%[dst_3])\n" + "vmovups %%zmm11, 128(%[dst_3])\n" + "vmovups %%zmm12, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm13, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm14, 128(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm15, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm16, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm17, 128(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm18, 0(%[dst_6])\n" + "vmovups %%zmm19, 64(%[dst_6])\n" + "vmovups %%zmm20, 128(%[dst_6])\n" + "vmovups %%zmm21, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm22, 64(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm23, 128(%[dst_6], %[dst_stride], 1)\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_9x16_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_9x16_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..e3f1eca80b6894ca051d4bbc1b113b192496b64a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_9x16_kernel_nhwc_fp32.c @@ -0,0 +1,488 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_9x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_6]), %%zmm6\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm8\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 0(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_6]), %%zmm24\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_6]), %%zmm24\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_6]), %%zmm24\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_6]), %%zmm24\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_6]), %%zmm24\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_6]), %%zmm24\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_6]), %%zmm24\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_6]), %%zmm24\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_6]), %%zmm24\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_6]), %%zmm24\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_6]), %%zmm24\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_6]), %%zmm24\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_6]), %%zmm24\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_6]), %%zmm24\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_6]), %%zmm24\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm3, 0(%[dst_3])\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_6])\n" + "vmovups %%zmm7, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm8, 0(%[dst_6], %[dst_stride], 2)\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_9x32_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_9x32_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..ff3504a9aea67094603935c1f7983dc425c5fbea --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_avx512/nnacl_gemm_avx512_9x32_kernel_nhwc_fp32.c @@ -0,0 +1,713 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_9x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_6]), %%zmm12\n" + "vmovups 64(%[dst_6]), %%zmm13\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm14\n" + "vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm15\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm16\n" + "vmovups 64(%[dst_6], %[dst_stride], 2), %%zmm17\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 0(%[bias]), %%zmm14\n" + "vmovups 64(%[bias]), %%zmm15\n" + "vmovups 0(%[bias]), %%zmm16\n" + "vmovups 64(%[bias]), %%zmm17\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 4(%[src_6]), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 8(%[src_6]), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 12(%[src_6]), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 16(%[src_6]), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 20(%[src_6]), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 24(%[src_6]), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 28(%[src_6]), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 32(%[src_6]), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 36(%[src_6]), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 40(%[src_6]), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 44(%[src_6]), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 48(%[src_6]), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 52(%[src_6]), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 56(%[src_6]), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 60(%[src_6]), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3])\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm12, 0(%[dst_6])\n" + "vmovups %%zmm13, 64(%[dst_6])\n" + "vmovups %%zmm14, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm15, 64(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm16, 0(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm17, 64(%[dst_6], %[dst_stride], 2)\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..1603526877f885c801b24e6d0fd409340e508ff8 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,297 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + __m256 dst3; + __m256 dst4; + __m256 dst5; + __m256 dst6; + __m256 dst7; + __m256 dst8; + __m256 dst9; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32); + dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40); + dst6 = _mm256_load_ps(dst + 0 * dst_stride + 48); + dst7 = _mm256_load_ps(dst + 0 * dst_stride + 56); + dst8 = _mm256_load_ps(dst + 0 * dst_stride + 64); + dst9 = _mm256_load_ps(dst + 0 * dst_stride + 72); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + dst7 = _mm256_setzero_ps(); + dst8 = _mm256_setzero_ps(); + dst9 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 0); + dst6 = _mm256_load_ps(bias + 0); + dst7 = _mm256_load_ps(bias + 0); + dst8 = _mm256_load_ps(bias + 0); + dst9 = _mm256_load_ps(bias + 0); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + __m256 src40 = _mm256_set1_ps(*(src + 32)); + dst4 = _mm256_fmadd_ps(dst4, src40, weight00); + __m256 src50 = _mm256_set1_ps(*(src + 40)); + dst5 = _mm256_fmadd_ps(dst5, src50, weight00); + __m256 src60 = _mm256_set1_ps(*(src + 48)); + dst6 = _mm256_fmadd_ps(dst6, src60, weight00); + __m256 src70 = _mm256_set1_ps(*(src + 56)); + dst7 = _mm256_fmadd_ps(dst7, src70, weight00); + __m256 src80 = _mm256_set1_ps(*(src + 64)); + dst8 = _mm256_fmadd_ps(dst8, src80, weight00); + __m256 src90 = _mm256_set1_ps(*(src + 72)); + dst9 = _mm256_fmadd_ps(dst9, src90, weight00); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + __m256 src41 = _mm256_set1_ps(*(src + 33)); + dst4 = _mm256_fmadd_ps(dst4, src41, weight01); + __m256 src51 = _mm256_set1_ps(*(src + 41)); + dst5 = _mm256_fmadd_ps(dst5, src51, weight01); + __m256 src61 = _mm256_set1_ps(*(src + 49)); + dst6 = _mm256_fmadd_ps(dst6, src61, weight01); + __m256 src71 = _mm256_set1_ps(*(src + 57)); + dst7 = _mm256_fmadd_ps(dst7, src71, weight01); + __m256 src81 = _mm256_set1_ps(*(src + 65)); + dst8 = _mm256_fmadd_ps(dst8, src81, weight01); + __m256 src91 = _mm256_set1_ps(*(src + 73)); + dst9 = _mm256_fmadd_ps(dst9, src91, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + __m256 src42 = _mm256_set1_ps(*(src + 34)); + dst4 = _mm256_fmadd_ps(dst4, src42, weight02); + __m256 src52 = _mm256_set1_ps(*(src + 42)); + dst5 = _mm256_fmadd_ps(dst5, src52, weight02); + __m256 src62 = _mm256_set1_ps(*(src + 50)); + dst6 = _mm256_fmadd_ps(dst6, src62, weight02); + __m256 src72 = _mm256_set1_ps(*(src + 58)); + dst7 = _mm256_fmadd_ps(dst7, src72, weight02); + __m256 src82 = _mm256_set1_ps(*(src + 66)); + dst8 = _mm256_fmadd_ps(dst8, src82, weight02); + __m256 src92 = _mm256_set1_ps(*(src + 74)); + dst9 = _mm256_fmadd_ps(dst9, src92, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + __m256 src43 = _mm256_set1_ps(*(src + 35)); + dst4 = _mm256_fmadd_ps(dst4, src43, weight03); + __m256 src53 = _mm256_set1_ps(*(src + 43)); + dst5 = _mm256_fmadd_ps(dst5, src53, weight03); + __m256 src63 = _mm256_set1_ps(*(src + 51)); + dst6 = _mm256_fmadd_ps(dst6, src63, weight03); + __m256 src73 = _mm256_set1_ps(*(src + 59)); + dst7 = _mm256_fmadd_ps(dst7, src73, weight03); + __m256 src83 = _mm256_set1_ps(*(src + 67)); + dst8 = _mm256_fmadd_ps(dst8, src83, weight03); + __m256 src93 = _mm256_set1_ps(*(src + 75)); + dst9 = _mm256_fmadd_ps(dst9, src93, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + __m256 src44 = _mm256_set1_ps(*(src + 36)); + dst4 = _mm256_fmadd_ps(dst4, src44, weight04); + __m256 src54 = _mm256_set1_ps(*(src + 44)); + dst5 = _mm256_fmadd_ps(dst5, src54, weight04); + __m256 src64 = _mm256_set1_ps(*(src + 52)); + dst6 = _mm256_fmadd_ps(dst6, src64, weight04); + __m256 src74 = _mm256_set1_ps(*(src + 60)); + dst7 = _mm256_fmadd_ps(dst7, src74, weight04); + __m256 src84 = _mm256_set1_ps(*(src + 68)); + dst8 = _mm256_fmadd_ps(dst8, src84, weight04); + __m256 src94 = _mm256_set1_ps(*(src + 76)); + dst9 = _mm256_fmadd_ps(dst9, src94, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + __m256 src45 = _mm256_set1_ps(*(src + 37)); + dst4 = _mm256_fmadd_ps(dst4, src45, weight05); + __m256 src55 = _mm256_set1_ps(*(src + 45)); + dst5 = _mm256_fmadd_ps(dst5, src55, weight05); + __m256 src65 = _mm256_set1_ps(*(src + 53)); + dst6 = _mm256_fmadd_ps(dst6, src65, weight05); + __m256 src75 = _mm256_set1_ps(*(src + 61)); + dst7 = _mm256_fmadd_ps(dst7, src75, weight05); + __m256 src85 = _mm256_set1_ps(*(src + 69)); + dst8 = _mm256_fmadd_ps(dst8, src85, weight05); + __m256 src95 = _mm256_set1_ps(*(src + 77)); + dst9 = _mm256_fmadd_ps(dst9, src95, weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + __m256 src46 = _mm256_set1_ps(*(src + 38)); + dst4 = _mm256_fmadd_ps(dst4, src46, weight06); + __m256 src56 = _mm256_set1_ps(*(src + 46)); + dst5 = _mm256_fmadd_ps(dst5, src56, weight06); + __m256 src66 = _mm256_set1_ps(*(src + 54)); + dst6 = _mm256_fmadd_ps(dst6, src66, weight06); + __m256 src76 = _mm256_set1_ps(*(src + 62)); + dst7 = _mm256_fmadd_ps(dst7, src76, weight06); + __m256 src86 = _mm256_set1_ps(*(src + 70)); + dst8 = _mm256_fmadd_ps(dst8, src86, weight06); + __m256 src96 = _mm256_set1_ps(*(src + 78)); + dst9 = _mm256_fmadd_ps(dst9, src96, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + __m256 src47 = _mm256_set1_ps(*(src + 39)); + dst4 = _mm256_fmadd_ps(dst4, src47, weight07); + __m256 src57 = _mm256_set1_ps(*(src + 47)); + dst5 = _mm256_fmadd_ps(dst5, src57, weight07); + __m256 src67 = _mm256_set1_ps(*(src + 55)); + dst6 = _mm256_fmadd_ps(dst6, src67, weight07); + __m256 src77 = _mm256_set1_ps(*(src + 63)); + dst7 = _mm256_fmadd_ps(dst7, src77, weight07); + __m256 src87 = _mm256_set1_ps(*(src + 71)); + dst8 = _mm256_fmadd_ps(dst8, src87, weight07); + __m256 src97 = _mm256_set1_ps(*(src + 79)); + dst9 = _mm256_fmadd_ps(dst9, src97, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + dst8 = _mm256_min_ps(dst8, relu6); + dst9 = _mm256_min_ps(dst9, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst9 = _mm256_max_ps(dst9, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst9 = _mm256_max_ps(dst9, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 0 * src_stride + 32, dst4); + _mm256_store_ps(dst + 0 * src_stride + 40, dst5); + _mm256_store_ps(dst + 0 * src_stride + 48, dst6); + _mm256_store_ps(dst + 0 * src_stride + 56, dst7); + _mm256_store_ps(dst + 0 * src_stride + 64, dst8); + _mm256_store_ps(dst + 0 * src_stride + 72, dst9); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..5cd2341cf12baaa9fb0806df0c279f5fe004a60c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,303 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 128(%[dst]), %%ymm4\n" + "vmovups 160(%[dst]), %%ymm5\n" + "vmovups 192(%[dst]), %%ymm6\n" + "vmovups 224(%[dst]), %%ymm7\n" + "vmovups 256(%[dst]), %%ymm8\n" + "vmovups 288(%[dst]), %%ymm9\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 0(%[bias]), %%ymm4\n" + "vmovaps 0(%[bias]), %%ymm5\n" + "vmovaps 0(%[bias]), %%ymm6\n" + "vmovaps 0(%[bias]), %%ymm7\n" + "vmovaps 0(%[bias]), %%ymm8\n" + "vmovaps 0(%[bias]), %%ymm9\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vbroadcastss 32(%[src]), %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 96(%[src]), %%ymm14\n" + "vbroadcastss 128(%[src]), %%ymm13\n" + "vbroadcastss 160(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 192(%[src]), %%ymm14\n" + "vbroadcastss 224(%[src]), %%ymm13\n" + "vbroadcastss 256(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 288(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" + "vbroadcastss 1(%[src]), %%ymm14\n" + "vbroadcastss 33(%[src]), %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 97(%[src]), %%ymm14\n" + "vbroadcastss 129(%[src]), %%ymm13\n" + "vbroadcastss 161(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 193(%[src]), %%ymm14\n" + "vbroadcastss 225(%[src]), %%ymm13\n" + "vbroadcastss 257(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 289(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vbroadcastss 34(%[src]), %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 98(%[src]), %%ymm14\n" + "vbroadcastss 130(%[src]), %%ymm13\n" + "vbroadcastss 162(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 194(%[src]), %%ymm14\n" + "vbroadcastss 226(%[src]), %%ymm13\n" + "vbroadcastss 258(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 290(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vbroadcastss 35(%[src]), %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 99(%[src]), %%ymm14\n" + "vbroadcastss 131(%[src]), %%ymm13\n" + "vbroadcastss 163(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 195(%[src]), %%ymm14\n" + "vbroadcastss 227(%[src]), %%ymm13\n" + "vbroadcastss 259(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 291(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + "vbroadcastss 36(%[src]), %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 100(%[src]), %%ymm14\n" + "vbroadcastss 132(%[src]), %%ymm13\n" + "vbroadcastss 164(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 196(%[src]), %%ymm14\n" + "vbroadcastss 228(%[src]), %%ymm13\n" + "vbroadcastss 260(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 292(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vbroadcastss 37(%[src]), %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 101(%[src]), %%ymm14\n" + "vbroadcastss 133(%[src]), %%ymm13\n" + "vbroadcastss 165(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 197(%[src]), %%ymm14\n" + "vbroadcastss 229(%[src]), %%ymm13\n" + "vbroadcastss 261(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 293(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vbroadcastss 38(%[src]), %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 102(%[src]), %%ymm14\n" + "vbroadcastss 134(%[src]), %%ymm13\n" + "vbroadcastss 166(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 198(%[src]), %%ymm14\n" + "vbroadcastss 230(%[src]), %%ymm13\n" + "vbroadcastss 262(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 294(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 7(%[src]), %%ymm14\n" + "vbroadcastss 39(%[src]), %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 103(%[src]), %%ymm14\n" + "vbroadcastss 135(%[src]), %%ymm13\n" + "vbroadcastss 167(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 199(%[src]), %%ymm14\n" + "vbroadcastss 231(%[src]), %%ymm13\n" + "vbroadcastss 263(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 295(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "dec %[deep]\n" + "add 256, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "vmaxps %%ymm8, %%ymm15, %%ymm8\n" + "vmaxps %%ymm9, %%ymm15, %%ymm9\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "vminps %%ymm8, %%ymm14, %%ymm8\n" + "vminps %%ymm9, %%ymm14, %%ymm9\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 128(%[dst])\n" + "vmovups %%ymm5, 160(%[dst])\n" + "vmovups %%ymm6, 192(%[dst])\n" + "vmovups %%ymm7, 224(%[dst])\n" + "vmovups %%ymm8, 256(%[dst])\n" + "vmovups %%ymm9, 288(%[dst])\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..dcc963de7ef6fd9c1206b062c5300b8e43a24ccd --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,321 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + __m256 dst3; + __m256 dst4; + __m256 dst5; + __m256 dst6; + __m256 dst7; + __m256 dst8; + __m256 dst9; + __m256 dst10; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32); + dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40); + dst6 = _mm256_load_ps(dst + 0 * dst_stride + 48); + dst7 = _mm256_load_ps(dst + 0 * dst_stride + 56); + dst8 = _mm256_load_ps(dst + 0 * dst_stride + 64); + dst9 = _mm256_load_ps(dst + 0 * dst_stride + 72); + dst10 = _mm256_load_ps(dst + 0 * dst_stride + 80); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + dst7 = _mm256_setzero_ps(); + dst8 = _mm256_setzero_ps(); + dst9 = _mm256_setzero_ps(); + dst10 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 0); + dst6 = _mm256_load_ps(bias + 0); + dst7 = _mm256_load_ps(bias + 0); + dst8 = _mm256_load_ps(bias + 0); + dst9 = _mm256_load_ps(bias + 0); + dst10 = _mm256_load_ps(bias + 0); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + __m256 src40 = _mm256_set1_ps(*(src + 32)); + dst4 = _mm256_fmadd_ps(dst4, src40, weight00); + __m256 src50 = _mm256_set1_ps(*(src + 40)); + dst5 = _mm256_fmadd_ps(dst5, src50, weight00); + __m256 src60 = _mm256_set1_ps(*(src + 48)); + dst6 = _mm256_fmadd_ps(dst6, src60, weight00); + __m256 src70 = _mm256_set1_ps(*(src + 56)); + dst7 = _mm256_fmadd_ps(dst7, src70, weight00); + __m256 src80 = _mm256_set1_ps(*(src + 64)); + dst8 = _mm256_fmadd_ps(dst8, src80, weight00); + __m256 src90 = _mm256_set1_ps(*(src + 72)); + dst9 = _mm256_fmadd_ps(dst9, src90, weight00); + __m256 src100 = _mm256_set1_ps(*(src + 80)); + dst10 = _mm256_fmadd_ps(dst10, src100, weight00); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + __m256 src41 = _mm256_set1_ps(*(src + 33)); + dst4 = _mm256_fmadd_ps(dst4, src41, weight01); + __m256 src51 = _mm256_set1_ps(*(src + 41)); + dst5 = _mm256_fmadd_ps(dst5, src51, weight01); + __m256 src61 = _mm256_set1_ps(*(src + 49)); + dst6 = _mm256_fmadd_ps(dst6, src61, weight01); + __m256 src71 = _mm256_set1_ps(*(src + 57)); + dst7 = _mm256_fmadd_ps(dst7, src71, weight01); + __m256 src81 = _mm256_set1_ps(*(src + 65)); + dst8 = _mm256_fmadd_ps(dst8, src81, weight01); + __m256 src91 = _mm256_set1_ps(*(src + 73)); + dst9 = _mm256_fmadd_ps(dst9, src91, weight01); + __m256 src101 = _mm256_set1_ps(*(src + 81)); + dst10 = _mm256_fmadd_ps(dst10, src101, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + __m256 src42 = _mm256_set1_ps(*(src + 34)); + dst4 = _mm256_fmadd_ps(dst4, src42, weight02); + __m256 src52 = _mm256_set1_ps(*(src + 42)); + dst5 = _mm256_fmadd_ps(dst5, src52, weight02); + __m256 src62 = _mm256_set1_ps(*(src + 50)); + dst6 = _mm256_fmadd_ps(dst6, src62, weight02); + __m256 src72 = _mm256_set1_ps(*(src + 58)); + dst7 = _mm256_fmadd_ps(dst7, src72, weight02); + __m256 src82 = _mm256_set1_ps(*(src + 66)); + dst8 = _mm256_fmadd_ps(dst8, src82, weight02); + __m256 src92 = _mm256_set1_ps(*(src + 74)); + dst9 = _mm256_fmadd_ps(dst9, src92, weight02); + __m256 src102 = _mm256_set1_ps(*(src + 82)); + dst10 = _mm256_fmadd_ps(dst10, src102, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + __m256 src43 = _mm256_set1_ps(*(src + 35)); + dst4 = _mm256_fmadd_ps(dst4, src43, weight03); + __m256 src53 = _mm256_set1_ps(*(src + 43)); + dst5 = _mm256_fmadd_ps(dst5, src53, weight03); + __m256 src63 = _mm256_set1_ps(*(src + 51)); + dst6 = _mm256_fmadd_ps(dst6, src63, weight03); + __m256 src73 = _mm256_set1_ps(*(src + 59)); + dst7 = _mm256_fmadd_ps(dst7, src73, weight03); + __m256 src83 = _mm256_set1_ps(*(src + 67)); + dst8 = _mm256_fmadd_ps(dst8, src83, weight03); + __m256 src93 = _mm256_set1_ps(*(src + 75)); + dst9 = _mm256_fmadd_ps(dst9, src93, weight03); + __m256 src103 = _mm256_set1_ps(*(src + 83)); + dst10 = _mm256_fmadd_ps(dst10, src103, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + __m256 src44 = _mm256_set1_ps(*(src + 36)); + dst4 = _mm256_fmadd_ps(dst4, src44, weight04); + __m256 src54 = _mm256_set1_ps(*(src + 44)); + dst5 = _mm256_fmadd_ps(dst5, src54, weight04); + __m256 src64 = _mm256_set1_ps(*(src + 52)); + dst6 = _mm256_fmadd_ps(dst6, src64, weight04); + __m256 src74 = _mm256_set1_ps(*(src + 60)); + dst7 = _mm256_fmadd_ps(dst7, src74, weight04); + __m256 src84 = _mm256_set1_ps(*(src + 68)); + dst8 = _mm256_fmadd_ps(dst8, src84, weight04); + __m256 src94 = _mm256_set1_ps(*(src + 76)); + dst9 = _mm256_fmadd_ps(dst9, src94, weight04); + __m256 src104 = _mm256_set1_ps(*(src + 84)); + dst10 = _mm256_fmadd_ps(dst10, src104, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + __m256 src45 = _mm256_set1_ps(*(src + 37)); + dst4 = _mm256_fmadd_ps(dst4, src45, weight05); + __m256 src55 = _mm256_set1_ps(*(src + 45)); + dst5 = _mm256_fmadd_ps(dst5, src55, weight05); + __m256 src65 = _mm256_set1_ps(*(src + 53)); + dst6 = _mm256_fmadd_ps(dst6, src65, weight05); + __m256 src75 = _mm256_set1_ps(*(src + 61)); + dst7 = _mm256_fmadd_ps(dst7, src75, weight05); + __m256 src85 = _mm256_set1_ps(*(src + 69)); + dst8 = _mm256_fmadd_ps(dst8, src85, weight05); + __m256 src95 = _mm256_set1_ps(*(src + 77)); + dst9 = _mm256_fmadd_ps(dst9, src95, weight05); + __m256 src105 = _mm256_set1_ps(*(src + 85)); + dst10 = _mm256_fmadd_ps(dst10, src105, weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + __m256 src46 = _mm256_set1_ps(*(src + 38)); + dst4 = _mm256_fmadd_ps(dst4, src46, weight06); + __m256 src56 = _mm256_set1_ps(*(src + 46)); + dst5 = _mm256_fmadd_ps(dst5, src56, weight06); + __m256 src66 = _mm256_set1_ps(*(src + 54)); + dst6 = _mm256_fmadd_ps(dst6, src66, weight06); + __m256 src76 = _mm256_set1_ps(*(src + 62)); + dst7 = _mm256_fmadd_ps(dst7, src76, weight06); + __m256 src86 = _mm256_set1_ps(*(src + 70)); + dst8 = _mm256_fmadd_ps(dst8, src86, weight06); + __m256 src96 = _mm256_set1_ps(*(src + 78)); + dst9 = _mm256_fmadd_ps(dst9, src96, weight06); + __m256 src106 = _mm256_set1_ps(*(src + 86)); + dst10 = _mm256_fmadd_ps(dst10, src106, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + __m256 src47 = _mm256_set1_ps(*(src + 39)); + dst4 = _mm256_fmadd_ps(dst4, src47, weight07); + __m256 src57 = _mm256_set1_ps(*(src + 47)); + dst5 = _mm256_fmadd_ps(dst5, src57, weight07); + __m256 src67 = _mm256_set1_ps(*(src + 55)); + dst6 = _mm256_fmadd_ps(dst6, src67, weight07); + __m256 src77 = _mm256_set1_ps(*(src + 63)); + dst7 = _mm256_fmadd_ps(dst7, src77, weight07); + __m256 src87 = _mm256_set1_ps(*(src + 71)); + dst8 = _mm256_fmadd_ps(dst8, src87, weight07); + __m256 src97 = _mm256_set1_ps(*(src + 79)); + dst9 = _mm256_fmadd_ps(dst9, src97, weight07); + __m256 src107 = _mm256_set1_ps(*(src + 87)); + dst10 = _mm256_fmadd_ps(dst10, src107, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + dst8 = _mm256_min_ps(dst8, relu6); + dst9 = _mm256_min_ps(dst9, relu6); + dst10 = _mm256_min_ps(dst10, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst9 = _mm256_max_ps(dst9, relu); + dst10 = _mm256_max_ps(dst10, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst9 = _mm256_max_ps(dst9, relu); + dst10 = _mm256_max_ps(dst10, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 0 * src_stride + 32, dst4); + _mm256_store_ps(dst + 0 * src_stride + 40, dst5); + _mm256_store_ps(dst + 0 * src_stride + 48, dst6); + _mm256_store_ps(dst + 0 * src_stride + 56, dst7); + _mm256_store_ps(dst + 0 * src_stride + 64, dst8); + _mm256_store_ps(dst + 0 * src_stride + 72, dst9); + _mm256_store_ps(dst + 0 * src_stride + 80, dst10); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..31c3c4fc2b710d6e45fdfa1b1e6a75cabdac98b6 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,325 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 128(%[dst]), %%ymm4\n" + "vmovups 160(%[dst]), %%ymm5\n" + "vmovups 192(%[dst]), %%ymm6\n" + "vmovups 224(%[dst]), %%ymm7\n" + "vmovups 256(%[dst]), %%ymm8\n" + "vmovups 288(%[dst]), %%ymm9\n" + "vmovups 320(%[dst]), %%ymm10\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 0(%[bias]), %%ymm4\n" + "vmovaps 0(%[bias]), %%ymm5\n" + "vmovaps 0(%[bias]), %%ymm6\n" + "vmovaps 0(%[bias]), %%ymm7\n" + "vmovaps 0(%[bias]), %%ymm8\n" + "vmovaps 0(%[bias]), %%ymm9\n" + "vmovaps 0(%[bias]), %%ymm10\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vbroadcastss 32(%[src]), %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 96(%[src]), %%ymm14\n" + "vbroadcastss 128(%[src]), %%ymm13\n" + "vbroadcastss 160(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 192(%[src]), %%ymm14\n" + "vbroadcastss 224(%[src]), %%ymm13\n" + "vbroadcastss 256(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 288(%[src]), %%ymm14\n" + "vbroadcastss 320(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" + "vbroadcastss 1(%[src]), %%ymm14\n" + "vbroadcastss 33(%[src]), %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 97(%[src]), %%ymm14\n" + "vbroadcastss 129(%[src]), %%ymm13\n" + "vbroadcastss 161(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 193(%[src]), %%ymm14\n" + "vbroadcastss 225(%[src]), %%ymm13\n" + "vbroadcastss 257(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 289(%[src]), %%ymm14\n" + "vbroadcastss 321(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vbroadcastss 34(%[src]), %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 98(%[src]), %%ymm14\n" + "vbroadcastss 130(%[src]), %%ymm13\n" + "vbroadcastss 162(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 194(%[src]), %%ymm14\n" + "vbroadcastss 226(%[src]), %%ymm13\n" + "vbroadcastss 258(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 290(%[src]), %%ymm14\n" + "vbroadcastss 322(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vbroadcastss 35(%[src]), %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 99(%[src]), %%ymm14\n" + "vbroadcastss 131(%[src]), %%ymm13\n" + "vbroadcastss 163(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 195(%[src]), %%ymm14\n" + "vbroadcastss 227(%[src]), %%ymm13\n" + "vbroadcastss 259(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 291(%[src]), %%ymm14\n" + "vbroadcastss 323(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + "vbroadcastss 36(%[src]), %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 100(%[src]), %%ymm14\n" + "vbroadcastss 132(%[src]), %%ymm13\n" + "vbroadcastss 164(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 196(%[src]), %%ymm14\n" + "vbroadcastss 228(%[src]), %%ymm13\n" + "vbroadcastss 260(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 292(%[src]), %%ymm14\n" + "vbroadcastss 324(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vbroadcastss 37(%[src]), %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 101(%[src]), %%ymm14\n" + "vbroadcastss 133(%[src]), %%ymm13\n" + "vbroadcastss 165(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 197(%[src]), %%ymm14\n" + "vbroadcastss 229(%[src]), %%ymm13\n" + "vbroadcastss 261(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 293(%[src]), %%ymm14\n" + "vbroadcastss 325(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vbroadcastss 38(%[src]), %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 102(%[src]), %%ymm14\n" + "vbroadcastss 134(%[src]), %%ymm13\n" + "vbroadcastss 166(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 198(%[src]), %%ymm14\n" + "vbroadcastss 230(%[src]), %%ymm13\n" + "vbroadcastss 262(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 294(%[src]), %%ymm14\n" + "vbroadcastss 326(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 7(%[src]), %%ymm14\n" + "vbroadcastss 39(%[src]), %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 103(%[src]), %%ymm14\n" + "vbroadcastss 135(%[src]), %%ymm13\n" + "vbroadcastss 167(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 199(%[src]), %%ymm14\n" + "vbroadcastss 231(%[src]), %%ymm13\n" + "vbroadcastss 263(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 295(%[src]), %%ymm14\n" + "vbroadcastss 327(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + "dec %[deep]\n" + "add 256, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "vmaxps %%ymm8, %%ymm15, %%ymm8\n" + "vmaxps %%ymm9, %%ymm15, %%ymm9\n" + "vmaxps %%ymm10, %%ymm15, %%ymm10\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "vminps %%ymm8, %%ymm14, %%ymm8\n" + "vminps %%ymm9, %%ymm14, %%ymm9\n" + "vminps %%ymm10, %%ymm14, %%ymm10\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 128(%[dst])\n" + "vmovups %%ymm5, 160(%[dst])\n" + "vmovups %%ymm6, 192(%[dst])\n" + "vmovups %%ymm7, 224(%[dst])\n" + "vmovups %%ymm8, 256(%[dst])\n" + "vmovups %%ymm9, 288(%[dst])\n" + "vmovups %%ymm10, 320(%[dst])\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..8e39ab64baf7591f29c7dae4e14fd0d721a9b27f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,345 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + __m256 dst3; + __m256 dst4; + __m256 dst5; + __m256 dst6; + __m256 dst7; + __m256 dst8; + __m256 dst9; + __m256 dst10; + __m256 dst11; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32); + dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40); + dst6 = _mm256_load_ps(dst + 0 * dst_stride + 48); + dst7 = _mm256_load_ps(dst + 0 * dst_stride + 56); + dst8 = _mm256_load_ps(dst + 0 * dst_stride + 64); + dst9 = _mm256_load_ps(dst + 0 * dst_stride + 72); + dst10 = _mm256_load_ps(dst + 0 * dst_stride + 80); + dst11 = _mm256_load_ps(dst + 0 * dst_stride + 88); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + dst7 = _mm256_setzero_ps(); + dst8 = _mm256_setzero_ps(); + dst9 = _mm256_setzero_ps(); + dst10 = _mm256_setzero_ps(); + dst11 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 0); + dst6 = _mm256_load_ps(bias + 0); + dst7 = _mm256_load_ps(bias + 0); + dst8 = _mm256_load_ps(bias + 0); + dst9 = _mm256_load_ps(bias + 0); + dst10 = _mm256_load_ps(bias + 0); + dst11 = _mm256_load_ps(bias + 0); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + __m256 src40 = _mm256_set1_ps(*(src + 32)); + dst4 = _mm256_fmadd_ps(dst4, src40, weight00); + __m256 src50 = _mm256_set1_ps(*(src + 40)); + dst5 = _mm256_fmadd_ps(dst5, src50, weight00); + __m256 src60 = _mm256_set1_ps(*(src + 48)); + dst6 = _mm256_fmadd_ps(dst6, src60, weight00); + __m256 src70 = _mm256_set1_ps(*(src + 56)); + dst7 = _mm256_fmadd_ps(dst7, src70, weight00); + __m256 src80 = _mm256_set1_ps(*(src + 64)); + dst8 = _mm256_fmadd_ps(dst8, src80, weight00); + __m256 src90 = _mm256_set1_ps(*(src + 72)); + dst9 = _mm256_fmadd_ps(dst9, src90, weight00); + __m256 src100 = _mm256_set1_ps(*(src + 80)); + dst10 = _mm256_fmadd_ps(dst10, src100, weight00); + __m256 src110 = _mm256_set1_ps(*(src + 88)); + dst11 = _mm256_fmadd_ps(dst11, src110, weight00); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + __m256 src41 = _mm256_set1_ps(*(src + 33)); + dst4 = _mm256_fmadd_ps(dst4, src41, weight01); + __m256 src51 = _mm256_set1_ps(*(src + 41)); + dst5 = _mm256_fmadd_ps(dst5, src51, weight01); + __m256 src61 = _mm256_set1_ps(*(src + 49)); + dst6 = _mm256_fmadd_ps(dst6, src61, weight01); + __m256 src71 = _mm256_set1_ps(*(src + 57)); + dst7 = _mm256_fmadd_ps(dst7, src71, weight01); + __m256 src81 = _mm256_set1_ps(*(src + 65)); + dst8 = _mm256_fmadd_ps(dst8, src81, weight01); + __m256 src91 = _mm256_set1_ps(*(src + 73)); + dst9 = _mm256_fmadd_ps(dst9, src91, weight01); + __m256 src101 = _mm256_set1_ps(*(src + 81)); + dst10 = _mm256_fmadd_ps(dst10, src101, weight01); + __m256 src111 = _mm256_set1_ps(*(src + 89)); + dst11 = _mm256_fmadd_ps(dst11, src111, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + __m256 src42 = _mm256_set1_ps(*(src + 34)); + dst4 = _mm256_fmadd_ps(dst4, src42, weight02); + __m256 src52 = _mm256_set1_ps(*(src + 42)); + dst5 = _mm256_fmadd_ps(dst5, src52, weight02); + __m256 src62 = _mm256_set1_ps(*(src + 50)); + dst6 = _mm256_fmadd_ps(dst6, src62, weight02); + __m256 src72 = _mm256_set1_ps(*(src + 58)); + dst7 = _mm256_fmadd_ps(dst7, src72, weight02); + __m256 src82 = _mm256_set1_ps(*(src + 66)); + dst8 = _mm256_fmadd_ps(dst8, src82, weight02); + __m256 src92 = _mm256_set1_ps(*(src + 74)); + dst9 = _mm256_fmadd_ps(dst9, src92, weight02); + __m256 src102 = _mm256_set1_ps(*(src + 82)); + dst10 = _mm256_fmadd_ps(dst10, src102, weight02); + __m256 src112 = _mm256_set1_ps(*(src + 90)); + dst11 = _mm256_fmadd_ps(dst11, src112, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + __m256 src43 = _mm256_set1_ps(*(src + 35)); + dst4 = _mm256_fmadd_ps(dst4, src43, weight03); + __m256 src53 = _mm256_set1_ps(*(src + 43)); + dst5 = _mm256_fmadd_ps(dst5, src53, weight03); + __m256 src63 = _mm256_set1_ps(*(src + 51)); + dst6 = _mm256_fmadd_ps(dst6, src63, weight03); + __m256 src73 = _mm256_set1_ps(*(src + 59)); + dst7 = _mm256_fmadd_ps(dst7, src73, weight03); + __m256 src83 = _mm256_set1_ps(*(src + 67)); + dst8 = _mm256_fmadd_ps(dst8, src83, weight03); + __m256 src93 = _mm256_set1_ps(*(src + 75)); + dst9 = _mm256_fmadd_ps(dst9, src93, weight03); + __m256 src103 = _mm256_set1_ps(*(src + 83)); + dst10 = _mm256_fmadd_ps(dst10, src103, weight03); + __m256 src113 = _mm256_set1_ps(*(src + 91)); + dst11 = _mm256_fmadd_ps(dst11, src113, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + __m256 src44 = _mm256_set1_ps(*(src + 36)); + dst4 = _mm256_fmadd_ps(dst4, src44, weight04); + __m256 src54 = _mm256_set1_ps(*(src + 44)); + dst5 = _mm256_fmadd_ps(dst5, src54, weight04); + __m256 src64 = _mm256_set1_ps(*(src + 52)); + dst6 = _mm256_fmadd_ps(dst6, src64, weight04); + __m256 src74 = _mm256_set1_ps(*(src + 60)); + dst7 = _mm256_fmadd_ps(dst7, src74, weight04); + __m256 src84 = _mm256_set1_ps(*(src + 68)); + dst8 = _mm256_fmadd_ps(dst8, src84, weight04); + __m256 src94 = _mm256_set1_ps(*(src + 76)); + dst9 = _mm256_fmadd_ps(dst9, src94, weight04); + __m256 src104 = _mm256_set1_ps(*(src + 84)); + dst10 = _mm256_fmadd_ps(dst10, src104, weight04); + __m256 src114 = _mm256_set1_ps(*(src + 92)); + dst11 = _mm256_fmadd_ps(dst11, src114, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + __m256 src45 = _mm256_set1_ps(*(src + 37)); + dst4 = _mm256_fmadd_ps(dst4, src45, weight05); + __m256 src55 = _mm256_set1_ps(*(src + 45)); + dst5 = _mm256_fmadd_ps(dst5, src55, weight05); + __m256 src65 = _mm256_set1_ps(*(src + 53)); + dst6 = _mm256_fmadd_ps(dst6, src65, weight05); + __m256 src75 = _mm256_set1_ps(*(src + 61)); + dst7 = _mm256_fmadd_ps(dst7, src75, weight05); + __m256 src85 = _mm256_set1_ps(*(src + 69)); + dst8 = _mm256_fmadd_ps(dst8, src85, weight05); + __m256 src95 = _mm256_set1_ps(*(src + 77)); + dst9 = _mm256_fmadd_ps(dst9, src95, weight05); + __m256 src105 = _mm256_set1_ps(*(src + 85)); + dst10 = _mm256_fmadd_ps(dst10, src105, weight05); + __m256 src115 = _mm256_set1_ps(*(src + 93)); + dst11 = _mm256_fmadd_ps(dst11, src115, weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + __m256 src46 = _mm256_set1_ps(*(src + 38)); + dst4 = _mm256_fmadd_ps(dst4, src46, weight06); + __m256 src56 = _mm256_set1_ps(*(src + 46)); + dst5 = _mm256_fmadd_ps(dst5, src56, weight06); + __m256 src66 = _mm256_set1_ps(*(src + 54)); + dst6 = _mm256_fmadd_ps(dst6, src66, weight06); + __m256 src76 = _mm256_set1_ps(*(src + 62)); + dst7 = _mm256_fmadd_ps(dst7, src76, weight06); + __m256 src86 = _mm256_set1_ps(*(src + 70)); + dst8 = _mm256_fmadd_ps(dst8, src86, weight06); + __m256 src96 = _mm256_set1_ps(*(src + 78)); + dst9 = _mm256_fmadd_ps(dst9, src96, weight06); + __m256 src106 = _mm256_set1_ps(*(src + 86)); + dst10 = _mm256_fmadd_ps(dst10, src106, weight06); + __m256 src116 = _mm256_set1_ps(*(src + 94)); + dst11 = _mm256_fmadd_ps(dst11, src116, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + __m256 src47 = _mm256_set1_ps(*(src + 39)); + dst4 = _mm256_fmadd_ps(dst4, src47, weight07); + __m256 src57 = _mm256_set1_ps(*(src + 47)); + dst5 = _mm256_fmadd_ps(dst5, src57, weight07); + __m256 src67 = _mm256_set1_ps(*(src + 55)); + dst6 = _mm256_fmadd_ps(dst6, src67, weight07); + __m256 src77 = _mm256_set1_ps(*(src + 63)); + dst7 = _mm256_fmadd_ps(dst7, src77, weight07); + __m256 src87 = _mm256_set1_ps(*(src + 71)); + dst8 = _mm256_fmadd_ps(dst8, src87, weight07); + __m256 src97 = _mm256_set1_ps(*(src + 79)); + dst9 = _mm256_fmadd_ps(dst9, src97, weight07); + __m256 src107 = _mm256_set1_ps(*(src + 87)); + dst10 = _mm256_fmadd_ps(dst10, src107, weight07); + __m256 src117 = _mm256_set1_ps(*(src + 95)); + dst11 = _mm256_fmadd_ps(dst11, src117, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + dst8 = _mm256_min_ps(dst8, relu6); + dst9 = _mm256_min_ps(dst9, relu6); + dst10 = _mm256_min_ps(dst10, relu6); + dst11 = _mm256_min_ps(dst11, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst9 = _mm256_max_ps(dst9, relu); + dst10 = _mm256_max_ps(dst10, relu); + dst11 = _mm256_max_ps(dst11, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst9 = _mm256_max_ps(dst9, relu); + dst10 = _mm256_max_ps(dst10, relu); + dst11 = _mm256_max_ps(dst11, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 0 * src_stride + 32, dst4); + _mm256_store_ps(dst + 0 * src_stride + 40, dst5); + _mm256_store_ps(dst + 0 * src_stride + 48, dst6); + _mm256_store_ps(dst + 0 * src_stride + 56, dst7); + _mm256_store_ps(dst + 0 * src_stride + 64, dst8); + _mm256_store_ps(dst + 0 * src_stride + 72, dst9); + _mm256_store_ps(dst + 0 * src_stride + 80, dst10); + _mm256_store_ps(dst + 0 * src_stride + 88, dst11); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..585957249bbb1e2d98f2d31ce6c07388c44fcf32 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,347 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 128(%[dst]), %%ymm4\n" + "vmovups 160(%[dst]), %%ymm5\n" + "vmovups 192(%[dst]), %%ymm6\n" + "vmovups 224(%[dst]), %%ymm7\n" + "vmovups 256(%[dst]), %%ymm8\n" + "vmovups 288(%[dst]), %%ymm9\n" + "vmovups 320(%[dst]), %%ymm10\n" + "vmovups 352(%[dst]), %%ymm11\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 0(%[bias]), %%ymm4\n" + "vmovaps 0(%[bias]), %%ymm5\n" + "vmovaps 0(%[bias]), %%ymm6\n" + "vmovaps 0(%[bias]), %%ymm7\n" + "vmovaps 0(%[bias]), %%ymm8\n" + "vmovaps 0(%[bias]), %%ymm9\n" + "vmovaps 0(%[bias]), %%ymm10\n" + "vmovaps 0(%[bias]), %%ymm11\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vbroadcastss 32(%[src]), %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 96(%[src]), %%ymm14\n" + "vbroadcastss 128(%[src]), %%ymm13\n" + "vbroadcastss 160(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 192(%[src]), %%ymm14\n" + "vbroadcastss 224(%[src]), %%ymm13\n" + "vbroadcastss 256(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 288(%[src]), %%ymm14\n" + "vbroadcastss 320(%[src]), %%ymm13\n" + "vbroadcastss 352(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm15\n" + // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" + "vbroadcastss 1(%[src]), %%ymm14\n" + "vbroadcastss 33(%[src]), %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 97(%[src]), %%ymm14\n" + "vbroadcastss 129(%[src]), %%ymm13\n" + "vbroadcastss 161(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 193(%[src]), %%ymm14\n" + "vbroadcastss 225(%[src]), %%ymm13\n" + "vbroadcastss 257(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 289(%[src]), %%ymm14\n" + "vbroadcastss 321(%[src]), %%ymm13\n" + "vbroadcastss 353(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vbroadcastss 34(%[src]), %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 98(%[src]), %%ymm14\n" + "vbroadcastss 130(%[src]), %%ymm13\n" + "vbroadcastss 162(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 194(%[src]), %%ymm14\n" + "vbroadcastss 226(%[src]), %%ymm13\n" + "vbroadcastss 258(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 290(%[src]), %%ymm14\n" + "vbroadcastss 322(%[src]), %%ymm13\n" + "vbroadcastss 354(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vbroadcastss 35(%[src]), %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 99(%[src]), %%ymm14\n" + "vbroadcastss 131(%[src]), %%ymm13\n" + "vbroadcastss 163(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 195(%[src]), %%ymm14\n" + "vbroadcastss 227(%[src]), %%ymm13\n" + "vbroadcastss 259(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 291(%[src]), %%ymm14\n" + "vbroadcastss 323(%[src]), %%ymm13\n" + "vbroadcastss 355(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + "vbroadcastss 36(%[src]), %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 100(%[src]), %%ymm14\n" + "vbroadcastss 132(%[src]), %%ymm13\n" + "vbroadcastss 164(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 196(%[src]), %%ymm14\n" + "vbroadcastss 228(%[src]), %%ymm13\n" + "vbroadcastss 260(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 292(%[src]), %%ymm14\n" + "vbroadcastss 324(%[src]), %%ymm13\n" + "vbroadcastss 356(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vbroadcastss 37(%[src]), %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 101(%[src]), %%ymm14\n" + "vbroadcastss 133(%[src]), %%ymm13\n" + "vbroadcastss 165(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 197(%[src]), %%ymm14\n" + "vbroadcastss 229(%[src]), %%ymm13\n" + "vbroadcastss 261(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 293(%[src]), %%ymm14\n" + "vbroadcastss 325(%[src]), %%ymm13\n" + "vbroadcastss 357(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vbroadcastss 38(%[src]), %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 102(%[src]), %%ymm14\n" + "vbroadcastss 134(%[src]), %%ymm13\n" + "vbroadcastss 166(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 198(%[src]), %%ymm14\n" + "vbroadcastss 230(%[src]), %%ymm13\n" + "vbroadcastss 262(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 294(%[src]), %%ymm14\n" + "vbroadcastss 326(%[src]), %%ymm13\n" + "vbroadcastss 358(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 7(%[src]), %%ymm14\n" + "vbroadcastss 39(%[src]), %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 103(%[src]), %%ymm14\n" + "vbroadcastss 135(%[src]), %%ymm13\n" + "vbroadcastss 167(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 199(%[src]), %%ymm14\n" + "vbroadcastss 231(%[src]), %%ymm13\n" + "vbroadcastss 263(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 295(%[src]), %%ymm14\n" + "vbroadcastss 327(%[src]), %%ymm13\n" + "vbroadcastss 359(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm15\n" + "dec %[deep]\n" + "add 256, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "vmaxps %%ymm8, %%ymm15, %%ymm8\n" + "vmaxps %%ymm9, %%ymm15, %%ymm9\n" + "vmaxps %%ymm10, %%ymm15, %%ymm10\n" + "vmaxps %%ymm11, %%ymm15, %%ymm11\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "vminps %%ymm8, %%ymm14, %%ymm8\n" + "vminps %%ymm9, %%ymm14, %%ymm9\n" + "vminps %%ymm10, %%ymm14, %%ymm10\n" + "vminps %%ymm11, %%ymm14, %%ymm11\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 128(%[dst])\n" + "vmovups %%ymm5, 160(%[dst])\n" + "vmovups %%ymm6, 192(%[dst])\n" + "vmovups %%ymm7, 224(%[dst])\n" + "vmovups %%ymm8, 256(%[dst])\n" + "vmovups %%ymm9, 288(%[dst])\n" + "vmovups %%ymm10, 320(%[dst])\n" + "vmovups %%ymm11, 352(%[dst])\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..87a830fbc57095e1e6d09a66107a7dd37c994cc1 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32.c @@ -0,0 +1,105 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 1 * dst_stride + 0); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 8); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 weight10 = _mm256_load_ps(weight + 8); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst1 = _mm256_fmadd_ps(dst1, src00, weight10); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 16); + __m256 weight11 = _mm256_load_ps(weight + 24); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst1 = _mm256_fmadd_ps(dst1, src01, weight11); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 32); + __m256 weight12 = _mm256_load_ps(weight + 40); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst1 = _mm256_fmadd_ps(dst1, src02, weight12); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 48); + __m256 weight13 = _mm256_load_ps(weight + 56); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst1 = _mm256_fmadd_ps(dst1, src03, weight13); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 64); + __m256 weight14 = _mm256_load_ps(weight + 72); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + dst1 = _mm256_fmadd_ps(dst1, src04, weight14); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 80); + __m256 weight15 = _mm256_load_ps(weight + 88); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + dst1 = _mm256_fmadd_ps(dst1, src05, weight15); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 96); + __m256 weight16 = _mm256_load_ps(weight + 104); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst1 = _mm256_fmadd_ps(dst1, src06, weight16); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 112); + __m256 weight17 = _mm256_load_ps(weight + 120); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst1 = _mm256_fmadd_ps(dst1, src07, weight17); + src = src + src_stride; + weight += 512; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 1 * src_stride + 0, dst1); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..6458f849f26b07a29711b9e124b172c7958c3564 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,127 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm1\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 32(%[bias]), %%ymm1\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0", "%ymm1"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vbroadcastss 0(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + // block 1 + "vmovaps 64(%[weight]), %%ymm15\n" + "vmovaps 96(%[weight]), %%ymm14\n" + "vbroadcastss 1(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + // block 2 + "vmovaps 128(%[weight]), %%ymm15\n" + "vmovaps 160(%[weight]), %%ymm14\n" + "vbroadcastss 2(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + // block 3 + "vmovaps 192(%[weight]), %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vbroadcastss 3(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + // block 4 + "vmovaps 256(%[weight]), %%ymm15\n" + "vmovaps 288(%[weight]), %%ymm14\n" + "vbroadcastss 4(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + // block 5 + "vmovaps 320(%[weight]), %%ymm15\n" + "vmovaps 352(%[weight]), %%ymm14\n" + "vbroadcastss 5(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + // block 6 + "vmovaps 384(%[weight]), %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vbroadcastss 6(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + // block 7 + "vmovaps 448(%[weight]), %%ymm15\n" + "vmovaps 480(%[weight]), %%ymm14\n" + "vbroadcastss 7(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + "dec %[deep]\n" + "add 512, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 0(%[dst], %[dst_stride], 1)\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..548f143c0a08bee2d2940411152d5866808f5fa3 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32.c @@ -0,0 +1,129 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst2 = _mm256_load_ps(dst + 2 * dst_stride + 0); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 8); + dst2 = _mm256_load_ps(bias + 16); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 weight10 = _mm256_load_ps(weight + 8); + __m256 weight20 = _mm256_load_ps(weight + 16); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst1 = _mm256_fmadd_ps(dst1, src00, weight10); + dst2 = _mm256_fmadd_ps(dst2, src00, weight20); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 24); + __m256 weight11 = _mm256_load_ps(weight + 32); + __m256 weight21 = _mm256_load_ps(weight + 40); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst1 = _mm256_fmadd_ps(dst1, src01, weight11); + dst2 = _mm256_fmadd_ps(dst2, src01, weight21); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 48); + __m256 weight12 = _mm256_load_ps(weight + 56); + __m256 weight22 = _mm256_load_ps(weight + 64); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst1 = _mm256_fmadd_ps(dst1, src02, weight12); + dst2 = _mm256_fmadd_ps(dst2, src02, weight22); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 72); + __m256 weight13 = _mm256_load_ps(weight + 80); + __m256 weight23 = _mm256_load_ps(weight + 88); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst1 = _mm256_fmadd_ps(dst1, src03, weight13); + dst2 = _mm256_fmadd_ps(dst2, src03, weight23); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 96); + __m256 weight14 = _mm256_load_ps(weight + 104); + __m256 weight24 = _mm256_load_ps(weight + 112); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + dst1 = _mm256_fmadd_ps(dst1, src04, weight14); + dst2 = _mm256_fmadd_ps(dst2, src04, weight24); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 120); + __m256 weight15 = _mm256_load_ps(weight + 128); + __m256 weight25 = _mm256_load_ps(weight + 136); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + dst1 = _mm256_fmadd_ps(dst1, src05, weight15); + dst2 = _mm256_fmadd_ps(dst2, src05, weight25); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 144); + __m256 weight16 = _mm256_load_ps(weight + 152); + __m256 weight26 = _mm256_load_ps(weight + 160); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst1 = _mm256_fmadd_ps(dst1, src06, weight16); + dst2 = _mm256_fmadd_ps(dst2, src06, weight26); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 168); + __m256 weight17 = _mm256_load_ps(weight + 176); + __m256 weight27 = _mm256_load_ps(weight + 184); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst1 = _mm256_fmadd_ps(dst1, src07, weight17); + dst2 = _mm256_fmadd_ps(dst2, src07, weight27); + src = src + src_stride; + weight += 768; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 1 * src_stride + 0, dst1); + _mm256_store_ps(dst + 2 * src_stride + 0, dst2); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..465fc9aeea1aace49fa2b6e2ddf09029a5cf2ae5 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,149 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm1\n" + "vmovups 0(%[dst], %[dst_stride], 2), %%ymm2\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 32(%[bias]), %%ymm1\n" + "vmovaps 64(%[bias]), %%ymm2\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vmovaps 64(%[weight]), %%ymm13\n" + "vbroadcastss 0(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + // block 1 + "vmovaps 96(%[weight]), %%ymm15\n" + "vmovaps 128(%[weight]), %%ymm14\n" + "vmovaps 160(%[weight]), %%ymm13\n" + "vbroadcastss 1(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + // block 2 + "vmovaps 192(%[weight]), %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vmovaps 256(%[weight]), %%ymm13\n" + "vbroadcastss 2(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + // block 3 + "vmovaps 288(%[weight]), %%ymm15\n" + "vmovaps 320(%[weight]), %%ymm14\n" + "vmovaps 352(%[weight]), %%ymm13\n" + "vbroadcastss 3(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + // block 4 + "vmovaps 384(%[weight]), %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vmovaps 448(%[weight]), %%ymm13\n" + "vbroadcastss 4(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + // block 5 + "vmovaps 480(%[weight]), %%ymm15\n" + "vmovaps 512(%[weight]), %%ymm14\n" + "vmovaps 544(%[weight]), %%ymm13\n" + "vbroadcastss 5(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + // block 6 + "vmovaps 576(%[weight]), %%ymm15\n" + "vmovaps 608(%[weight]), %%ymm14\n" + "vmovaps 640(%[weight]), %%ymm13\n" + "vbroadcastss 6(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + // block 7 + "vmovaps 672(%[weight]), %%ymm15\n" + "vmovaps 704(%[weight]), %%ymm14\n" + "vmovaps 736(%[weight]), %%ymm13\n" + "vbroadcastss 7(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + "dec %[deep]\n" + "add 768, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm2, 0(%[dst], %[dst_stride], 2)\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..c23a6ff6fa4dc13a398b51abdd34446df6a2214b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32.c @@ -0,0 +1,153 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + __m256 dst3; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst2 = _mm256_load_ps(dst + 2 * dst_stride + 0); + dst3 = _mm256_load_ps(dst + 3 * dst_stride + 0); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 8); + dst2 = _mm256_load_ps(bias + 16); + dst3 = _mm256_load_ps(bias + 24); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 src00 = _mm256_set1_ps(*(src + 0)); + __m256 weight00 = _mm256_load_ps(weight + 0); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 weight10 = _mm256_load_ps(weight + 8); + dst1 = _mm256_fmadd_ps(dst1, src00, weight10); + __m256 weight20 = _mm256_load_ps(weight + 16); + dst2 = _mm256_fmadd_ps(dst2, src00, weight20); + __m256 weight30 = _mm256_load_ps(weight + 24); + dst3 = _mm256_fmadd_ps(dst3, src00, weight30); + // bock1 + __m256 src01 = _mm256_set1_ps(*(src + 1)); + __m256 weight01 = _mm256_load_ps(weight + 32); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + __m256 weight11 = _mm256_load_ps(weight + 40); + dst1 = _mm256_fmadd_ps(dst1, src01, weight11); + __m256 weight21 = _mm256_load_ps(weight + 48); + dst2 = _mm256_fmadd_ps(dst2, src01, weight21); + __m256 weight31 = _mm256_load_ps(weight + 56); + dst3 = _mm256_fmadd_ps(dst3, src01, weight31); + // bock2 + __m256 src02 = _mm256_set1_ps(*(src + 2)); + __m256 weight02 = _mm256_load_ps(weight + 64); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 weight12 = _mm256_load_ps(weight + 72); + dst1 = _mm256_fmadd_ps(dst1, src02, weight12); + __m256 weight22 = _mm256_load_ps(weight + 80); + dst2 = _mm256_fmadd_ps(dst2, src02, weight22); + __m256 weight32 = _mm256_load_ps(weight + 88); + dst3 = _mm256_fmadd_ps(dst3, src02, weight32); + // bock3 + __m256 src03 = _mm256_set1_ps(*(src + 3)); + __m256 weight03 = _mm256_load_ps(weight + 96); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 weight13 = _mm256_load_ps(weight + 104); + dst1 = _mm256_fmadd_ps(dst1, src03, weight13); + __m256 weight23 = _mm256_load_ps(weight + 112); + dst2 = _mm256_fmadd_ps(dst2, src03, weight23); + __m256 weight33 = _mm256_load_ps(weight + 120); + dst3 = _mm256_fmadd_ps(dst3, src03, weight33); + // bock4 + __m256 src04 = _mm256_set1_ps(*(src + 4)); + __m256 weight04 = _mm256_load_ps(weight + 128); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 weight14 = _mm256_load_ps(weight + 136); + dst1 = _mm256_fmadd_ps(dst1, src04, weight14); + __m256 weight24 = _mm256_load_ps(weight + 144); + dst2 = _mm256_fmadd_ps(dst2, src04, weight24); + __m256 weight34 = _mm256_load_ps(weight + 152); + dst3 = _mm256_fmadd_ps(dst3, src04, weight34); + // bock5 + __m256 src05 = _mm256_set1_ps(*(src + 5)); + __m256 weight05 = _mm256_load_ps(weight + 160); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 weight15 = _mm256_load_ps(weight + 168); + dst1 = _mm256_fmadd_ps(dst1, src05, weight15); + __m256 weight25 = _mm256_load_ps(weight + 176); + dst2 = _mm256_fmadd_ps(dst2, src05, weight25); + __m256 weight35 = _mm256_load_ps(weight + 184); + dst3 = _mm256_fmadd_ps(dst3, src05, weight35); + // bock6 + __m256 src06 = _mm256_set1_ps(*(src + 6)); + __m256 weight06 = _mm256_load_ps(weight + 192); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 weight16 = _mm256_load_ps(weight + 200); + dst1 = _mm256_fmadd_ps(dst1, src06, weight16); + __m256 weight26 = _mm256_load_ps(weight + 208); + dst2 = _mm256_fmadd_ps(dst2, src06, weight26); + __m256 weight36 = _mm256_load_ps(weight + 216); + dst3 = _mm256_fmadd_ps(dst3, src06, weight36); + // bock7 + __m256 src07 = _mm256_set1_ps(*(src + 7)); + __m256 weight07 = _mm256_load_ps(weight + 224); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 weight17 = _mm256_load_ps(weight + 232); + dst1 = _mm256_fmadd_ps(dst1, src07, weight17); + __m256 weight27 = _mm256_load_ps(weight + 240); + dst2 = _mm256_fmadd_ps(dst2, src07, weight27); + __m256 weight37 = _mm256_load_ps(weight + 248); + dst3 = _mm256_fmadd_ps(dst3, src07, weight37); + src = src + src_stride; + weight += 1024; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 1 * src_stride + 0, dst1); + _mm256_store_ps(dst + 2 * src_stride + 0, dst2); + _mm256_store_ps(dst + 3 * src_stride + 0, dst3); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..5efc8dcfa81f2c2e55b87932402ea045cd02d965 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,174 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_4 = dst + 3 * dst_stride; + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm1\n" + "vmovups 0(%[dst], %[dst_stride], 2), %%ymm2\n" + "vmovups 0(%[dst_4]), %%ymm3\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 32(%[bias]), %%ymm1\n" + "vmovaps 64(%[bias]), %%ymm2\n" + "vmovaps 96(%[bias]), %%ymm3\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_4 ] "r"(dst_4) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3"); + asm volatile( + "0:\n" + // block 0 + "vbroadcastss 0(%[src]), %%ymm15\n" + "vmovaps 0(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n" + "vmovaps 64(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n" + "vmovaps 96(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 1 + "vbroadcastss 1(%[src]), %%ymm15\n" + "vmovaps 128(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vmovaps 160(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n" + "vmovaps 192(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 2 + "vbroadcastss 2(%[src]), %%ymm15\n" + "vmovaps 256(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vmovaps 288(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n" + "vmovaps 320(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n" + "vmovaps 352(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 3 + "vbroadcastss 3(%[src]), %%ymm15\n" + "vmovaps 384(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n" + "vmovaps 448(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n" + "vmovaps 480(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 4 + "vbroadcastss 4(%[src]), %%ymm15\n" + "vmovaps 512(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vmovaps 544(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n" + "vmovaps 576(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n" + "vmovaps 608(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 5 + "vbroadcastss 5(%[src]), %%ymm15\n" + "vmovaps 640(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vmovaps 672(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n" + "vmovaps 704(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n" + "vmovaps 736(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 6 + "vbroadcastss 6(%[src]), %%ymm15\n" + "vmovaps 768(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vmovaps 800(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n" + "vmovaps 832(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n" + "vmovaps 864(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 7 + "vbroadcastss 7(%[src]), %%ymm15\n" + "vmovaps 896(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vmovaps 928(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n" + "vmovaps 960(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n" + "vmovaps 992(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "dec %[deep]\n" + "add 1024, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm2, 0(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm3, 0(%[dst_4])\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_4 ] "r"(dst_4) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..a795d14be0e68391eb1af2a952028f56d88603b6 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,81 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..032876f77ae096716565121fe007e59384f41650 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,105 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" + "vbroadcastss 1(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 7(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "dec %[deep]\n" + "add 256, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..55787a6905a7b9bc3bf09178c6ef15183033fd6e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32.c @@ -0,0 +1,145 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst2; + __m256 dst1; + __m256 dst3; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst2 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst3 = _mm256_load_ps(dst + 1 * dst_stride + 8); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 8); + dst1 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 8); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 weight10 = _mm256_load_ps(weight + 8); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst2 = _mm256_fmadd_ps(dst2, src00, weight10); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + dst3 = _mm256_fmadd_ps(dst3, src10, weight10); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 16); + __m256 weight11 = _mm256_load_ps(weight + 24); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst2 = _mm256_fmadd_ps(dst2, src01, weight11); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + dst3 = _mm256_fmadd_ps(dst3, src11, weight11); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 32); + __m256 weight12 = _mm256_load_ps(weight + 40); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst2 = _mm256_fmadd_ps(dst2, src02, weight12); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + dst3 = _mm256_fmadd_ps(dst3, src12, weight12); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 48); + __m256 weight13 = _mm256_load_ps(weight + 56); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst2 = _mm256_fmadd_ps(dst2, src03, weight13); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + dst3 = _mm256_fmadd_ps(dst3, src13, weight13); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 64); + __m256 weight14 = _mm256_load_ps(weight + 72); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + dst2 = _mm256_fmadd_ps(dst2, src04, weight14); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + dst3 = _mm256_fmadd_ps(dst3, src14, weight14); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 80); + __m256 weight15 = _mm256_load_ps(weight + 88); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + dst2 = _mm256_fmadd_ps(dst2, src05, weight15); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + dst3 = _mm256_fmadd_ps(dst3, src15, weight15); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 96); + __m256 weight16 = _mm256_load_ps(weight + 104); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst2 = _mm256_fmadd_ps(dst2, src06, weight16); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + dst3 = _mm256_fmadd_ps(dst3, src16, weight16); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 112); + __m256 weight17 = _mm256_load_ps(weight + 120); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst2 = _mm256_fmadd_ps(dst2, src07, weight17); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + dst3 = _mm256_fmadd_ps(dst3, src17, weight17); + src = src + src_stride; + weight += 512; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst3 = _mm256_max_ps(dst3, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst3 = _mm256_max_ps(dst3, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 1 * src_stride + 0, dst2); + _mm256_store_ps(dst + 1 * src_stride + 8, dst3); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..c55abaab7490c591faa93cff329866efd0f1d6b0 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,163 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm2\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm3\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 32(%[bias]), %%ymm2\n" + "vmovaps 32(%[bias]), %%ymm3\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vbroadcastss 0(%[src]), %%ymm13\n" + "vbroadcastss 32(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + // block 1 + "vmovaps 64(%[weight]), %%ymm15\n" + "vmovaps 96(%[weight]), %%ymm14\n" + "vbroadcastss 1(%[src]), %%ymm13\n" + "vbroadcastss 33(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + // block 2 + "vmovaps 128(%[weight]), %%ymm15\n" + "vmovaps 160(%[weight]), %%ymm14\n" + "vbroadcastss 2(%[src]), %%ymm13\n" + "vbroadcastss 34(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + // block 3 + "vmovaps 192(%[weight]), %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vbroadcastss 3(%[src]), %%ymm13\n" + "vbroadcastss 35(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + // block 4 + "vmovaps 256(%[weight]), %%ymm15\n" + "vmovaps 288(%[weight]), %%ymm14\n" + "vbroadcastss 4(%[src]), %%ymm13\n" + "vbroadcastss 36(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + // block 5 + "vmovaps 320(%[weight]), %%ymm15\n" + "vmovaps 352(%[weight]), %%ymm14\n" + "vbroadcastss 5(%[src]), %%ymm13\n" + "vbroadcastss 37(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + // block 6 + "vmovaps 384(%[weight]), %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vbroadcastss 6(%[src]), %%ymm13\n" + "vbroadcastss 38(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + // block 7 + "vmovaps 448(%[weight]), %%ymm15\n" + "vmovaps 480(%[weight]), %%ymm14\n" + "vbroadcastss 7(%[src]), %%ymm13\n" + "vbroadcastss 39(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "dec %[deep]\n" + "add 512, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm3, 32(%[dst], %[dst_stride], 1)\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..1de0370cb7c38f3ac6242fe18a7c4f5b9563057b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32.c @@ -0,0 +1,185 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst2; + __m256 dst4; + __m256 dst1; + __m256 dst3; + __m256 dst5; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst2 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst4 = _mm256_load_ps(dst + 2 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst3 = _mm256_load_ps(dst + 1 * dst_stride + 8); + dst5 = _mm256_load_ps(dst + 2 * dst_stride + 8); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 8); + dst4 = _mm256_load_ps(bias + 16); + dst1 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 8); + dst5 = _mm256_load_ps(bias + 16); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 weight10 = _mm256_load_ps(weight + 8); + __m256 weight20 = _mm256_load_ps(weight + 16); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst2 = _mm256_fmadd_ps(dst2, src00, weight10); + dst4 = _mm256_fmadd_ps(dst4, src00, weight20); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + dst3 = _mm256_fmadd_ps(dst3, src10, weight10); + dst5 = _mm256_fmadd_ps(dst5, src10, weight20); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 24); + __m256 weight11 = _mm256_load_ps(weight + 32); + __m256 weight21 = _mm256_load_ps(weight + 40); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst2 = _mm256_fmadd_ps(dst2, src01, weight11); + dst4 = _mm256_fmadd_ps(dst4, src01, weight21); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + dst3 = _mm256_fmadd_ps(dst3, src11, weight11); + dst5 = _mm256_fmadd_ps(dst5, src11, weight21); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 48); + __m256 weight12 = _mm256_load_ps(weight + 56); + __m256 weight22 = _mm256_load_ps(weight + 64); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst2 = _mm256_fmadd_ps(dst2, src02, weight12); + dst4 = _mm256_fmadd_ps(dst4, src02, weight22); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + dst3 = _mm256_fmadd_ps(dst3, src12, weight12); + dst5 = _mm256_fmadd_ps(dst5, src12, weight22); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 72); + __m256 weight13 = _mm256_load_ps(weight + 80); + __m256 weight23 = _mm256_load_ps(weight + 88); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst2 = _mm256_fmadd_ps(dst2, src03, weight13); + dst4 = _mm256_fmadd_ps(dst4, src03, weight23); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + dst3 = _mm256_fmadd_ps(dst3, src13, weight13); + dst5 = _mm256_fmadd_ps(dst5, src13, weight23); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 96); + __m256 weight14 = _mm256_load_ps(weight + 104); + __m256 weight24 = _mm256_load_ps(weight + 112); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + dst2 = _mm256_fmadd_ps(dst2, src04, weight14); + dst4 = _mm256_fmadd_ps(dst4, src04, weight24); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + dst3 = _mm256_fmadd_ps(dst3, src14, weight14); + dst5 = _mm256_fmadd_ps(dst5, src14, weight24); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 120); + __m256 weight15 = _mm256_load_ps(weight + 128); + __m256 weight25 = _mm256_load_ps(weight + 136); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + dst2 = _mm256_fmadd_ps(dst2, src05, weight15); + dst4 = _mm256_fmadd_ps(dst4, src05, weight25); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + dst3 = _mm256_fmadd_ps(dst3, src15, weight15); + dst5 = _mm256_fmadd_ps(dst5, src15, weight25); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 144); + __m256 weight16 = _mm256_load_ps(weight + 152); + __m256 weight26 = _mm256_load_ps(weight + 160); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst2 = _mm256_fmadd_ps(dst2, src06, weight16); + dst4 = _mm256_fmadd_ps(dst4, src06, weight26); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + dst3 = _mm256_fmadd_ps(dst3, src16, weight16); + dst5 = _mm256_fmadd_ps(dst5, src16, weight26); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 168); + __m256 weight17 = _mm256_load_ps(weight + 176); + __m256 weight27 = _mm256_load_ps(weight + 184); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst2 = _mm256_fmadd_ps(dst2, src07, weight17); + dst4 = _mm256_fmadd_ps(dst4, src07, weight27); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + dst3 = _mm256_fmadd_ps(dst3, src17, weight17); + dst5 = _mm256_fmadd_ps(dst5, src17, weight27); + src = src + src_stride; + weight += 768; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst5 = _mm256_max_ps(dst5, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst5 = _mm256_max_ps(dst5, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 1 * src_stride + 0, dst2); + _mm256_store_ps(dst + 1 * src_stride + 8, dst3); + _mm256_store_ps(dst + 2 * src_stride + 0, dst4); + _mm256_store_ps(dst + 2 * src_stride + 8, dst5); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..c8e3309bd9bc3ab983e669baf0cda3634d40b559 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,199 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm2\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm3\n" + "vmovups 0(%[dst], %[dst_stride], 2), %%ymm4\n" + "vmovups 32(%[dst], %[dst_stride], 2), %%ymm5\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 32(%[bias]), %%ymm2\n" + "vmovaps 32(%[bias]), %%ymm3\n" + "vmovaps 64(%[bias]), %%ymm4\n" + "vmovaps 64(%[bias]), %%ymm5\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vmovaps 64(%[weight]), %%ymm13\n" + "vbroadcastss 0(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n" + "vbroadcastss 32(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + // block 1 + "vmovaps 96(%[weight]), %%ymm15\n" + "vmovaps 128(%[weight]), %%ymm14\n" + "vmovaps 160(%[weight]), %%ymm13\n" + "vbroadcastss 1(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n" + "vbroadcastss 33(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + // block 2 + "vmovaps 192(%[weight]), %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vmovaps 256(%[weight]), %%ymm13\n" + "vbroadcastss 2(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n" + "vbroadcastss 34(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + // block 3 + "vmovaps 288(%[weight]), %%ymm15\n" + "vmovaps 320(%[weight]), %%ymm14\n" + "vmovaps 352(%[weight]), %%ymm13\n" + "vbroadcastss 3(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n" + "vbroadcastss 35(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + // block 4 + "vmovaps 384(%[weight]), %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vmovaps 448(%[weight]), %%ymm13\n" + "vbroadcastss 4(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n" + "vbroadcastss 36(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + // block 5 + "vmovaps 480(%[weight]), %%ymm15\n" + "vmovaps 512(%[weight]), %%ymm14\n" + "vmovaps 544(%[weight]), %%ymm13\n" + "vbroadcastss 5(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n" + "vbroadcastss 37(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + // block 6 + "vmovaps 576(%[weight]), %%ymm15\n" + "vmovaps 608(%[weight]), %%ymm14\n" + "vmovaps 640(%[weight]), %%ymm13\n" + "vbroadcastss 6(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n" + "vbroadcastss 38(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + // block 7 + "vmovaps 672(%[weight]), %%ymm15\n" + "vmovaps 704(%[weight]), %%ymm14\n" + "vmovaps 736(%[weight]), %%ymm13\n" + "vbroadcastss 7(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n" + "vbroadcastss 39(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + "dec %[deep]\n" + "add 768, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm3, 32(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm4, 0(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm5, 32(%[dst], %[dst_stride], 2)\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..0f9dbfa6ed9f4000171d9e7bde56654be4a49f9c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32.c @@ -0,0 +1,225 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst2; + __m256 dst4; + __m256 dst6; + __m256 dst1; + __m256 dst3; + __m256 dst5; + __m256 dst7; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst2 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst4 = _mm256_load_ps(dst + 2 * dst_stride + 0); + dst6 = _mm256_load_ps(dst + 3 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst3 = _mm256_load_ps(dst + 1 * dst_stride + 8); + dst5 = _mm256_load_ps(dst + 2 * dst_stride + 8); + dst7 = _mm256_load_ps(dst + 3 * dst_stride + 8); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + dst7 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 8); + dst4 = _mm256_load_ps(bias + 16); + dst6 = _mm256_load_ps(bias + 24); + dst1 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 8); + dst5 = _mm256_load_ps(bias + 16); + dst7 = _mm256_load_ps(bias + 24); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 src00 = _mm256_set1_ps(*(src + 0)); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + __m256 weight00 = _mm256_load_ps(weight + 0); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + __m256 weight10 = _mm256_load_ps(weight + 8); + dst2 = _mm256_fmadd_ps(dst2, src00, weight10); + dst3 = _mm256_fmadd_ps(dst3, src10, weight10); + __m256 weight20 = _mm256_load_ps(weight + 16); + dst4 = _mm256_fmadd_ps(dst4, src00, weight20); + dst5 = _mm256_fmadd_ps(dst5, src10, weight20); + __m256 weight30 = _mm256_load_ps(weight + 24); + dst6 = _mm256_fmadd_ps(dst6, src00, weight30); + dst7 = _mm256_fmadd_ps(dst7, src10, weight30); + // bock1 + __m256 src01 = _mm256_set1_ps(*(src + 1)); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + __m256 weight01 = _mm256_load_ps(weight + 32); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + __m256 weight11 = _mm256_load_ps(weight + 40); + dst2 = _mm256_fmadd_ps(dst2, src01, weight11); + dst3 = _mm256_fmadd_ps(dst3, src11, weight11); + __m256 weight21 = _mm256_load_ps(weight + 48); + dst4 = _mm256_fmadd_ps(dst4, src01, weight21); + dst5 = _mm256_fmadd_ps(dst5, src11, weight21); + __m256 weight31 = _mm256_load_ps(weight + 56); + dst6 = _mm256_fmadd_ps(dst6, src01, weight31); + dst7 = _mm256_fmadd_ps(dst7, src11, weight31); + // bock2 + __m256 src02 = _mm256_set1_ps(*(src + 2)); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + __m256 weight02 = _mm256_load_ps(weight + 64); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + __m256 weight12 = _mm256_load_ps(weight + 72); + dst2 = _mm256_fmadd_ps(dst2, src02, weight12); + dst3 = _mm256_fmadd_ps(dst3, src12, weight12); + __m256 weight22 = _mm256_load_ps(weight + 80); + dst4 = _mm256_fmadd_ps(dst4, src02, weight22); + dst5 = _mm256_fmadd_ps(dst5, src12, weight22); + __m256 weight32 = _mm256_load_ps(weight + 88); + dst6 = _mm256_fmadd_ps(dst6, src02, weight32); + dst7 = _mm256_fmadd_ps(dst7, src12, weight32); + // bock3 + __m256 src03 = _mm256_set1_ps(*(src + 3)); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + __m256 weight03 = _mm256_load_ps(weight + 96); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + __m256 weight13 = _mm256_load_ps(weight + 104); + dst2 = _mm256_fmadd_ps(dst2, src03, weight13); + dst3 = _mm256_fmadd_ps(dst3, src13, weight13); + __m256 weight23 = _mm256_load_ps(weight + 112); + dst4 = _mm256_fmadd_ps(dst4, src03, weight23); + dst5 = _mm256_fmadd_ps(dst5, src13, weight23); + __m256 weight33 = _mm256_load_ps(weight + 120); + dst6 = _mm256_fmadd_ps(dst6, src03, weight33); + dst7 = _mm256_fmadd_ps(dst7, src13, weight33); + // bock4 + __m256 src04 = _mm256_set1_ps(*(src + 4)); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + __m256 weight04 = _mm256_load_ps(weight + 128); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + __m256 weight14 = _mm256_load_ps(weight + 136); + dst2 = _mm256_fmadd_ps(dst2, src04, weight14); + dst3 = _mm256_fmadd_ps(dst3, src14, weight14); + __m256 weight24 = _mm256_load_ps(weight + 144); + dst4 = _mm256_fmadd_ps(dst4, src04, weight24); + dst5 = _mm256_fmadd_ps(dst5, src14, weight24); + __m256 weight34 = _mm256_load_ps(weight + 152); + dst6 = _mm256_fmadd_ps(dst6, src04, weight34); + dst7 = _mm256_fmadd_ps(dst7, src14, weight34); + // bock5 + __m256 src05 = _mm256_set1_ps(*(src + 5)); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + __m256 weight05 = _mm256_load_ps(weight + 160); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + __m256 weight15 = _mm256_load_ps(weight + 168); + dst2 = _mm256_fmadd_ps(dst2, src05, weight15); + dst3 = _mm256_fmadd_ps(dst3, src15, weight15); + __m256 weight25 = _mm256_load_ps(weight + 176); + dst4 = _mm256_fmadd_ps(dst4, src05, weight25); + dst5 = _mm256_fmadd_ps(dst5, src15, weight25); + __m256 weight35 = _mm256_load_ps(weight + 184); + dst6 = _mm256_fmadd_ps(dst6, src05, weight35); + dst7 = _mm256_fmadd_ps(dst7, src15, weight35); + // bock6 + __m256 src06 = _mm256_set1_ps(*(src + 6)); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + __m256 weight06 = _mm256_load_ps(weight + 192); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + __m256 weight16 = _mm256_load_ps(weight + 200); + dst2 = _mm256_fmadd_ps(dst2, src06, weight16); + dst3 = _mm256_fmadd_ps(dst3, src16, weight16); + __m256 weight26 = _mm256_load_ps(weight + 208); + dst4 = _mm256_fmadd_ps(dst4, src06, weight26); + dst5 = _mm256_fmadd_ps(dst5, src16, weight26); + __m256 weight36 = _mm256_load_ps(weight + 216); + dst6 = _mm256_fmadd_ps(dst6, src06, weight36); + dst7 = _mm256_fmadd_ps(dst7, src16, weight36); + // bock7 + __m256 src07 = _mm256_set1_ps(*(src + 7)); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + __m256 weight07 = _mm256_load_ps(weight + 224); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + __m256 weight17 = _mm256_load_ps(weight + 232); + dst2 = _mm256_fmadd_ps(dst2, src07, weight17); + dst3 = _mm256_fmadd_ps(dst3, src17, weight17); + __m256 weight27 = _mm256_load_ps(weight + 240); + dst4 = _mm256_fmadd_ps(dst4, src07, weight27); + dst5 = _mm256_fmadd_ps(dst5, src17, weight27); + __m256 weight37 = _mm256_load_ps(weight + 248); + dst6 = _mm256_fmadd_ps(dst6, src07, weight37); + dst7 = _mm256_fmadd_ps(dst7, src17, weight37); + src = src + src_stride; + weight += 1024; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst7 = _mm256_max_ps(dst7, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst7 = _mm256_max_ps(dst7, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 1 * src_stride + 0, dst2); + _mm256_store_ps(dst + 1 * src_stride + 8, dst3); + _mm256_store_ps(dst + 2 * src_stride + 0, dst4); + _mm256_store_ps(dst + 2 * src_stride + 8, dst5); + _mm256_store_ps(dst + 3 * src_stride + 0, dst6); + _mm256_store_ps(dst + 3 * src_stride + 8, dst7); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..e8586a7c381f4007207b412a4c5bd371ed4d0075 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,238 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_4 = dst + 3 * dst_stride; + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm2\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm3\n" + "vmovups 0(%[dst], %[dst_stride], 2), %%ymm4\n" + "vmovups 32(%[dst], %[dst_stride], 2), %%ymm5\n" + "vmovups 0(%[dst_4]), %%ymm6\n" + "vmovups 32(%[dst_4]), %%ymm7\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 32(%[bias]), %%ymm2\n" + "vmovaps 32(%[bias]), %%ymm3\n" + "vmovaps 64(%[bias]), %%ymm4\n" + "vmovaps 64(%[bias]), %%ymm5\n" + "vmovaps 96(%[bias]), %%ymm6\n" + "vmovaps 96(%[bias]), %%ymm7\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_4 ] "r"(dst_4) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7"); + asm volatile( + "0:\n" + // block 0 + "vbroadcastss 0(%[src]), %%ymm15\n" + "vbroadcastss 32(%[src]), %%ymm14\n" + "vmovaps 0(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + "vmovaps 32(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vmovaps 64(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vmovaps 96(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + // block 1 + "vbroadcastss 1(%[src]), %%ymm15\n" + "vbroadcastss 33(%[src]), %%ymm14\n" + "vmovaps 128(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + "vmovaps 160(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vmovaps 192(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vmovaps 224(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + // block 2 + "vbroadcastss 2(%[src]), %%ymm15\n" + "vbroadcastss 34(%[src]), %%ymm14\n" + "vmovaps 256(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + "vmovaps 288(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vmovaps 320(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vmovaps 352(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + // block 3 + "vbroadcastss 3(%[src]), %%ymm15\n" + "vbroadcastss 35(%[src]), %%ymm14\n" + "vmovaps 384(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + "vmovaps 416(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vmovaps 448(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vmovaps 480(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + // block 4 + "vbroadcastss 4(%[src]), %%ymm15\n" + "vbroadcastss 36(%[src]), %%ymm14\n" + "vmovaps 512(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + "vmovaps 544(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vmovaps 576(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vmovaps 608(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + // block 5 + "vbroadcastss 5(%[src]), %%ymm15\n" + "vbroadcastss 37(%[src]), %%ymm14\n" + "vmovaps 640(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + "vmovaps 672(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vmovaps 704(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vmovaps 736(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + // block 6 + "vbroadcastss 6(%[src]), %%ymm15\n" + "vbroadcastss 38(%[src]), %%ymm14\n" + "vmovaps 768(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + "vmovaps 800(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vmovaps 832(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vmovaps 864(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + // block 7 + "vbroadcastss 7(%[src]), %%ymm15\n" + "vbroadcastss 39(%[src]), %%ymm14\n" + "vmovaps 896(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + "vmovaps 928(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vmovaps 960(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vmovaps 992(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + "dec %[deep]\n" + "add 1024, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm3, 32(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm4, 0(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm5, 32(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm6, 0(%[dst_4])\n" + "vmovups %%ymm7, 32(%[dst_4])\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_4 ] "r"(dst_4) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..6fcf7958a743436d8f0f4fe870a33acc148b94f6 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,105 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 0); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..279370a00857d4b54a4c602ea56c2edab28381d3 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,127 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0", "%ymm1"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vbroadcastss 32(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" + "vbroadcastss 1(%[src]), %%ymm14\n" + "vbroadcastss 33(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vbroadcastss 34(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vbroadcastss 35(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + "vbroadcastss 36(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vbroadcastss 37(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vbroadcastss 38(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 7(%[src]), %%ymm14\n" + "vbroadcastss 39(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "dec %[deep]\n" + "add 256, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..8cec925905bd13f881d6363751c3fba0944f3a78 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32.c @@ -0,0 +1,185 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst3; + __m256 dst1; + __m256 dst4; + __m256 dst2; + __m256 dst5; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst3 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst4 = _mm256_load_ps(dst + 1 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst5 = _mm256_load_ps(dst + 1 * dst_stride + 16); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 8); + dst1 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 8); + dst2 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 8); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 weight10 = _mm256_load_ps(weight + 8); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst3 = _mm256_fmadd_ps(dst3, src00, weight10); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + dst4 = _mm256_fmadd_ps(dst4, src10, weight10); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + dst5 = _mm256_fmadd_ps(dst5, src20, weight10); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 16); + __m256 weight11 = _mm256_load_ps(weight + 24); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst3 = _mm256_fmadd_ps(dst3, src01, weight11); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + dst4 = _mm256_fmadd_ps(dst4, src11, weight11); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + dst5 = _mm256_fmadd_ps(dst5, src21, weight11); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 32); + __m256 weight12 = _mm256_load_ps(weight + 40); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst3 = _mm256_fmadd_ps(dst3, src02, weight12); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + dst4 = _mm256_fmadd_ps(dst4, src12, weight12); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + dst5 = _mm256_fmadd_ps(dst5, src22, weight12); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 48); + __m256 weight13 = _mm256_load_ps(weight + 56); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst3 = _mm256_fmadd_ps(dst3, src03, weight13); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + dst4 = _mm256_fmadd_ps(dst4, src13, weight13); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + dst5 = _mm256_fmadd_ps(dst5, src23, weight13); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 64); + __m256 weight14 = _mm256_load_ps(weight + 72); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + dst3 = _mm256_fmadd_ps(dst3, src04, weight14); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + dst4 = _mm256_fmadd_ps(dst4, src14, weight14); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + dst5 = _mm256_fmadd_ps(dst5, src24, weight14); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 80); + __m256 weight15 = _mm256_load_ps(weight + 88); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + dst3 = _mm256_fmadd_ps(dst3, src05, weight15); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + dst4 = _mm256_fmadd_ps(dst4, src15, weight15); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + dst5 = _mm256_fmadd_ps(dst5, src25, weight15); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 96); + __m256 weight16 = _mm256_load_ps(weight + 104); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst3 = _mm256_fmadd_ps(dst3, src06, weight16); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + dst4 = _mm256_fmadd_ps(dst4, src16, weight16); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + dst5 = _mm256_fmadd_ps(dst5, src26, weight16); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 112); + __m256 weight17 = _mm256_load_ps(weight + 120); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst3 = _mm256_fmadd_ps(dst3, src07, weight17); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + dst4 = _mm256_fmadd_ps(dst4, src17, weight17); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + dst5 = _mm256_fmadd_ps(dst5, src27, weight17); + src = src + src_stride; + weight += 512; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst5 = _mm256_max_ps(dst5, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst5 = _mm256_max_ps(dst5, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 1 * src_stride + 0, dst3); + _mm256_store_ps(dst + 1 * src_stride + 8, dst4); + _mm256_store_ps(dst + 1 * src_stride + 16, dst5); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..e2ddb336e6014e8064d7ea2ff249cbc29508ffe4 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,199 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm3\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm4\n" + "vmovups 64(%[dst], %[dst_stride], 1), %%ymm5\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 32(%[bias]), %%ymm3\n" + "vmovaps 32(%[bias]), %%ymm4\n" + "vmovaps 32(%[bias]), %%ymm5\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vbroadcastss 0(%[src]), %%ymm13\n" + "vbroadcastss 32(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vbroadcastss 64(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + // block 1 + "vmovaps 64(%[weight]), %%ymm15\n" + "vmovaps 96(%[weight]), %%ymm14\n" + "vbroadcastss 1(%[src]), %%ymm13\n" + "vbroadcastss 33(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vbroadcastss 65(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + // block 2 + "vmovaps 128(%[weight]), %%ymm15\n" + "vmovaps 160(%[weight]), %%ymm14\n" + "vbroadcastss 2(%[src]), %%ymm13\n" + "vbroadcastss 34(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vbroadcastss 66(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + // block 3 + "vmovaps 192(%[weight]), %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vbroadcastss 3(%[src]), %%ymm13\n" + "vbroadcastss 35(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vbroadcastss 67(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + // block 4 + "vmovaps 256(%[weight]), %%ymm15\n" + "vmovaps 288(%[weight]), %%ymm14\n" + "vbroadcastss 4(%[src]), %%ymm13\n" + "vbroadcastss 36(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vbroadcastss 68(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + // block 5 + "vmovaps 320(%[weight]), %%ymm15\n" + "vmovaps 352(%[weight]), %%ymm14\n" + "vbroadcastss 5(%[src]), %%ymm13\n" + "vbroadcastss 37(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vbroadcastss 69(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + // block 6 + "vmovaps 384(%[weight]), %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vbroadcastss 6(%[src]), %%ymm13\n" + "vbroadcastss 38(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vbroadcastss 70(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + // block 7 + "vmovaps 448(%[weight]), %%ymm15\n" + "vmovaps 480(%[weight]), %%ymm14\n" + "vbroadcastss 7(%[src]), %%ymm13\n" + "vbroadcastss 39(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vbroadcastss 71(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "dec %[deep]\n" + "add 512, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm4, 32(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm5, 64(%[dst], %[dst_stride], 1)\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..db5d05d6f0c73f5615e22fc10bc6a2d8abb298b8 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32.c @@ -0,0 +1,241 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst3; + __m256 dst6; + __m256 dst1; + __m256 dst4; + __m256 dst7; + __m256 dst2; + __m256 dst5; + __m256 dst8; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst3 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst6 = _mm256_load_ps(dst + 2 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst4 = _mm256_load_ps(dst + 1 * dst_stride + 8); + dst7 = _mm256_load_ps(dst + 2 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst5 = _mm256_load_ps(dst + 1 * dst_stride + 16); + dst8 = _mm256_load_ps(dst + 2 * dst_stride + 16); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + dst7 = _mm256_setzero_ps(); + dst8 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 8); + dst6 = _mm256_load_ps(bias + 16); + dst1 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 8); + dst7 = _mm256_load_ps(bias + 16); + dst2 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 8); + dst8 = _mm256_load_ps(bias + 16); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 weight10 = _mm256_load_ps(weight + 8); + __m256 weight20 = _mm256_load_ps(weight + 16); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst3 = _mm256_fmadd_ps(dst3, src00, weight10); + dst6 = _mm256_fmadd_ps(dst6, src00, weight20); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + dst4 = _mm256_fmadd_ps(dst4, src10, weight10); + dst7 = _mm256_fmadd_ps(dst7, src10, weight20); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + dst5 = _mm256_fmadd_ps(dst5, src20, weight10); + dst8 = _mm256_fmadd_ps(dst8, src20, weight20); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 24); + __m256 weight11 = _mm256_load_ps(weight + 32); + __m256 weight21 = _mm256_load_ps(weight + 40); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst3 = _mm256_fmadd_ps(dst3, src01, weight11); + dst6 = _mm256_fmadd_ps(dst6, src01, weight21); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + dst4 = _mm256_fmadd_ps(dst4, src11, weight11); + dst7 = _mm256_fmadd_ps(dst7, src11, weight21); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + dst5 = _mm256_fmadd_ps(dst5, src21, weight11); + dst8 = _mm256_fmadd_ps(dst8, src21, weight21); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 48); + __m256 weight12 = _mm256_load_ps(weight + 56); + __m256 weight22 = _mm256_load_ps(weight + 64); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst3 = _mm256_fmadd_ps(dst3, src02, weight12); + dst6 = _mm256_fmadd_ps(dst6, src02, weight22); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + dst4 = _mm256_fmadd_ps(dst4, src12, weight12); + dst7 = _mm256_fmadd_ps(dst7, src12, weight22); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + dst5 = _mm256_fmadd_ps(dst5, src22, weight12); + dst8 = _mm256_fmadd_ps(dst8, src22, weight22); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 72); + __m256 weight13 = _mm256_load_ps(weight + 80); + __m256 weight23 = _mm256_load_ps(weight + 88); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst3 = _mm256_fmadd_ps(dst3, src03, weight13); + dst6 = _mm256_fmadd_ps(dst6, src03, weight23); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + dst4 = _mm256_fmadd_ps(dst4, src13, weight13); + dst7 = _mm256_fmadd_ps(dst7, src13, weight23); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + dst5 = _mm256_fmadd_ps(dst5, src23, weight13); + dst8 = _mm256_fmadd_ps(dst8, src23, weight23); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 96); + __m256 weight14 = _mm256_load_ps(weight + 104); + __m256 weight24 = _mm256_load_ps(weight + 112); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + dst3 = _mm256_fmadd_ps(dst3, src04, weight14); + dst6 = _mm256_fmadd_ps(dst6, src04, weight24); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + dst4 = _mm256_fmadd_ps(dst4, src14, weight14); + dst7 = _mm256_fmadd_ps(dst7, src14, weight24); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + dst5 = _mm256_fmadd_ps(dst5, src24, weight14); + dst8 = _mm256_fmadd_ps(dst8, src24, weight24); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 120); + __m256 weight15 = _mm256_load_ps(weight + 128); + __m256 weight25 = _mm256_load_ps(weight + 136); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + dst3 = _mm256_fmadd_ps(dst3, src05, weight15); + dst6 = _mm256_fmadd_ps(dst6, src05, weight25); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + dst4 = _mm256_fmadd_ps(dst4, src15, weight15); + dst7 = _mm256_fmadd_ps(dst7, src15, weight25); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + dst5 = _mm256_fmadd_ps(dst5, src25, weight15); + dst8 = _mm256_fmadd_ps(dst8, src25, weight25); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 144); + __m256 weight16 = _mm256_load_ps(weight + 152); + __m256 weight26 = _mm256_load_ps(weight + 160); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst3 = _mm256_fmadd_ps(dst3, src06, weight16); + dst6 = _mm256_fmadd_ps(dst6, src06, weight26); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + dst4 = _mm256_fmadd_ps(dst4, src16, weight16); + dst7 = _mm256_fmadd_ps(dst7, src16, weight26); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + dst5 = _mm256_fmadd_ps(dst5, src26, weight16); + dst8 = _mm256_fmadd_ps(dst8, src26, weight26); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 168); + __m256 weight17 = _mm256_load_ps(weight + 176); + __m256 weight27 = _mm256_load_ps(weight + 184); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst3 = _mm256_fmadd_ps(dst3, src07, weight17); + dst6 = _mm256_fmadd_ps(dst6, src07, weight27); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + dst4 = _mm256_fmadd_ps(dst4, src17, weight17); + dst7 = _mm256_fmadd_ps(dst7, src17, weight27); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + dst5 = _mm256_fmadd_ps(dst5, src27, weight17); + dst8 = _mm256_fmadd_ps(dst8, src27, weight27); + src = src + src_stride; + weight += 768; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst8 = _mm256_min_ps(dst8, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst8 = _mm256_max_ps(dst8, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst8 = _mm256_max_ps(dst8, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 1 * src_stride + 0, dst3); + _mm256_store_ps(dst + 1 * src_stride + 8, dst4); + _mm256_store_ps(dst + 1 * src_stride + 16, dst5); + _mm256_store_ps(dst + 2 * src_stride + 0, dst6); + _mm256_store_ps(dst + 2 * src_stride + 8, dst7); + _mm256_store_ps(dst + 2 * src_stride + 16, dst8); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..487634f1384dceb6c64425100066792c95bd2049 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,249 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm3\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm4\n" + "vmovups 64(%[dst], %[dst_stride], 1), %%ymm5\n" + "vmovups 0(%[dst], %[dst_stride], 2), %%ymm6\n" + "vmovups 32(%[dst], %[dst_stride], 2), %%ymm7\n" + "vmovups 64(%[dst], %[dst_stride], 2), %%ymm8\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 32(%[bias]), %%ymm3\n" + "vmovaps 32(%[bias]), %%ymm4\n" + "vmovaps 32(%[bias]), %%ymm5\n" + "vmovaps 64(%[bias]), %%ymm6\n" + "vmovaps 64(%[bias]), %%ymm7\n" + "vmovaps 64(%[bias]), %%ymm8\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vmovaps 64(%[weight]), %%ymm13\n" + "vbroadcastss 0(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n" + "vbroadcastss 32(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + // block 1 + "vmovaps 96(%[weight]), %%ymm15\n" + "vmovaps 128(%[weight]), %%ymm14\n" + "vmovaps 160(%[weight]), %%ymm13\n" + "vbroadcastss 1(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n" + "vbroadcastss 33(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + // block 2 + "vmovaps 192(%[weight]), %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vmovaps 256(%[weight]), %%ymm13\n" + "vbroadcastss 2(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n" + "vbroadcastss 34(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + // block 3 + "vmovaps 288(%[weight]), %%ymm15\n" + "vmovaps 320(%[weight]), %%ymm14\n" + "vmovaps 352(%[weight]), %%ymm13\n" + "vbroadcastss 3(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n" + "vbroadcastss 35(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + // block 4 + "vmovaps 384(%[weight]), %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vmovaps 448(%[weight]), %%ymm13\n" + "vbroadcastss 4(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n" + "vbroadcastss 36(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + // block 5 + "vmovaps 480(%[weight]), %%ymm15\n" + "vmovaps 512(%[weight]), %%ymm14\n" + "vmovaps 544(%[weight]), %%ymm13\n" + "vbroadcastss 5(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n" + "vbroadcastss 37(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + // block 6 + "vmovaps 576(%[weight]), %%ymm15\n" + "vmovaps 608(%[weight]), %%ymm14\n" + "vmovaps 640(%[weight]), %%ymm13\n" + "vbroadcastss 6(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n" + "vbroadcastss 38(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + // block 7 + "vmovaps 672(%[weight]), %%ymm15\n" + "vmovaps 704(%[weight]), %%ymm14\n" + "vmovaps 736(%[weight]), %%ymm13\n" + "vbroadcastss 7(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n" + "vbroadcastss 39(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "dec %[deep]\n" + "add 768, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "vmaxps %%ymm8, %%ymm15, %%ymm8\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "vminps %%ymm8, %%ymm14, %%ymm8\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm4, 32(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm5, 64(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm6, 0(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm7, 32(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm8, 64(%[dst], %[dst_stride], 2)\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..eaf7595f1bcaa5f966143ff0c413f32adf847cbf --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32.c @@ -0,0 +1,297 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst3; + __m256 dst6; + __m256 dst9; + __m256 dst1; + __m256 dst4; + __m256 dst7; + __m256 dst10; + __m256 dst2; + __m256 dst5; + __m256 dst8; + __m256 dst11; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst3 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst6 = _mm256_load_ps(dst + 2 * dst_stride + 0); + dst9 = _mm256_load_ps(dst + 3 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst4 = _mm256_load_ps(dst + 1 * dst_stride + 8); + dst7 = _mm256_load_ps(dst + 2 * dst_stride + 8); + dst10 = _mm256_load_ps(dst + 3 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst5 = _mm256_load_ps(dst + 1 * dst_stride + 16); + dst8 = _mm256_load_ps(dst + 2 * dst_stride + 16); + dst11 = _mm256_load_ps(dst + 3 * dst_stride + 16); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + dst7 = _mm256_setzero_ps(); + dst8 = _mm256_setzero_ps(); + dst9 = _mm256_setzero_ps(); + dst10 = _mm256_setzero_ps(); + dst11 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 8); + dst6 = _mm256_load_ps(bias + 16); + dst9 = _mm256_load_ps(bias + 24); + dst1 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 8); + dst7 = _mm256_load_ps(bias + 16); + dst10 = _mm256_load_ps(bias + 24); + dst2 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 8); + dst8 = _mm256_load_ps(bias + 16); + dst11 = _mm256_load_ps(bias + 24); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 src00 = _mm256_set1_ps(*(src + 0)); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + __m256 weight00 = _mm256_load_ps(weight + 0); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + __m256 weight10 = _mm256_load_ps(weight + 8); + dst3 = _mm256_fmadd_ps(dst3, src00, weight10); + dst4 = _mm256_fmadd_ps(dst4, src10, weight10); + dst5 = _mm256_fmadd_ps(dst5, src20, weight10); + __m256 weight20 = _mm256_load_ps(weight + 16); + dst6 = _mm256_fmadd_ps(dst6, src00, weight20); + dst7 = _mm256_fmadd_ps(dst7, src10, weight20); + dst8 = _mm256_fmadd_ps(dst8, src20, weight20); + __m256 weight30 = _mm256_load_ps(weight + 24); + dst9 = _mm256_fmadd_ps(dst9, src00, weight30); + dst10 = _mm256_fmadd_ps(dst10, src10, weight30); + dst11 = _mm256_fmadd_ps(dst11, src20, weight30); + // bock1 + __m256 src01 = _mm256_set1_ps(*(src + 1)); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + __m256 weight01 = _mm256_load_ps(weight + 32); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + __m256 weight11 = _mm256_load_ps(weight + 40); + dst3 = _mm256_fmadd_ps(dst3, src01, weight11); + dst4 = _mm256_fmadd_ps(dst4, src11, weight11); + dst5 = _mm256_fmadd_ps(dst5, src21, weight11); + __m256 weight21 = _mm256_load_ps(weight + 48); + dst6 = _mm256_fmadd_ps(dst6, src01, weight21); + dst7 = _mm256_fmadd_ps(dst7, src11, weight21); + dst8 = _mm256_fmadd_ps(dst8, src21, weight21); + __m256 weight31 = _mm256_load_ps(weight + 56); + dst9 = _mm256_fmadd_ps(dst9, src01, weight31); + dst10 = _mm256_fmadd_ps(dst10, src11, weight31); + dst11 = _mm256_fmadd_ps(dst11, src21, weight31); + // bock2 + __m256 src02 = _mm256_set1_ps(*(src + 2)); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + __m256 weight02 = _mm256_load_ps(weight + 64); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + __m256 weight12 = _mm256_load_ps(weight + 72); + dst3 = _mm256_fmadd_ps(dst3, src02, weight12); + dst4 = _mm256_fmadd_ps(dst4, src12, weight12); + dst5 = _mm256_fmadd_ps(dst5, src22, weight12); + __m256 weight22 = _mm256_load_ps(weight + 80); + dst6 = _mm256_fmadd_ps(dst6, src02, weight22); + dst7 = _mm256_fmadd_ps(dst7, src12, weight22); + dst8 = _mm256_fmadd_ps(dst8, src22, weight22); + __m256 weight32 = _mm256_load_ps(weight + 88); + dst9 = _mm256_fmadd_ps(dst9, src02, weight32); + dst10 = _mm256_fmadd_ps(dst10, src12, weight32); + dst11 = _mm256_fmadd_ps(dst11, src22, weight32); + // bock3 + __m256 src03 = _mm256_set1_ps(*(src + 3)); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + __m256 weight03 = _mm256_load_ps(weight + 96); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + __m256 weight13 = _mm256_load_ps(weight + 104); + dst3 = _mm256_fmadd_ps(dst3, src03, weight13); + dst4 = _mm256_fmadd_ps(dst4, src13, weight13); + dst5 = _mm256_fmadd_ps(dst5, src23, weight13); + __m256 weight23 = _mm256_load_ps(weight + 112); + dst6 = _mm256_fmadd_ps(dst6, src03, weight23); + dst7 = _mm256_fmadd_ps(dst7, src13, weight23); + dst8 = _mm256_fmadd_ps(dst8, src23, weight23); + __m256 weight33 = _mm256_load_ps(weight + 120); + dst9 = _mm256_fmadd_ps(dst9, src03, weight33); + dst10 = _mm256_fmadd_ps(dst10, src13, weight33); + dst11 = _mm256_fmadd_ps(dst11, src23, weight33); + // bock4 + __m256 src04 = _mm256_set1_ps(*(src + 4)); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + __m256 weight04 = _mm256_load_ps(weight + 128); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + __m256 weight14 = _mm256_load_ps(weight + 136); + dst3 = _mm256_fmadd_ps(dst3, src04, weight14); + dst4 = _mm256_fmadd_ps(dst4, src14, weight14); + dst5 = _mm256_fmadd_ps(dst5, src24, weight14); + __m256 weight24 = _mm256_load_ps(weight + 144); + dst6 = _mm256_fmadd_ps(dst6, src04, weight24); + dst7 = _mm256_fmadd_ps(dst7, src14, weight24); + dst8 = _mm256_fmadd_ps(dst8, src24, weight24); + __m256 weight34 = _mm256_load_ps(weight + 152); + dst9 = _mm256_fmadd_ps(dst9, src04, weight34); + dst10 = _mm256_fmadd_ps(dst10, src14, weight34); + dst11 = _mm256_fmadd_ps(dst11, src24, weight34); + // bock5 + __m256 src05 = _mm256_set1_ps(*(src + 5)); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + __m256 weight05 = _mm256_load_ps(weight + 160); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + __m256 weight15 = _mm256_load_ps(weight + 168); + dst3 = _mm256_fmadd_ps(dst3, src05, weight15); + dst4 = _mm256_fmadd_ps(dst4, src15, weight15); + dst5 = _mm256_fmadd_ps(dst5, src25, weight15); + __m256 weight25 = _mm256_load_ps(weight + 176); + dst6 = _mm256_fmadd_ps(dst6, src05, weight25); + dst7 = _mm256_fmadd_ps(dst7, src15, weight25); + dst8 = _mm256_fmadd_ps(dst8, src25, weight25); + __m256 weight35 = _mm256_load_ps(weight + 184); + dst9 = _mm256_fmadd_ps(dst9, src05, weight35); + dst10 = _mm256_fmadd_ps(dst10, src15, weight35); + dst11 = _mm256_fmadd_ps(dst11, src25, weight35); + // bock6 + __m256 src06 = _mm256_set1_ps(*(src + 6)); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + __m256 weight06 = _mm256_load_ps(weight + 192); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + __m256 weight16 = _mm256_load_ps(weight + 200); + dst3 = _mm256_fmadd_ps(dst3, src06, weight16); + dst4 = _mm256_fmadd_ps(dst4, src16, weight16); + dst5 = _mm256_fmadd_ps(dst5, src26, weight16); + __m256 weight26 = _mm256_load_ps(weight + 208); + dst6 = _mm256_fmadd_ps(dst6, src06, weight26); + dst7 = _mm256_fmadd_ps(dst7, src16, weight26); + dst8 = _mm256_fmadd_ps(dst8, src26, weight26); + __m256 weight36 = _mm256_load_ps(weight + 216); + dst9 = _mm256_fmadd_ps(dst9, src06, weight36); + dst10 = _mm256_fmadd_ps(dst10, src16, weight36); + dst11 = _mm256_fmadd_ps(dst11, src26, weight36); + // bock7 + __m256 src07 = _mm256_set1_ps(*(src + 7)); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + __m256 weight07 = _mm256_load_ps(weight + 224); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + __m256 weight17 = _mm256_load_ps(weight + 232); + dst3 = _mm256_fmadd_ps(dst3, src07, weight17); + dst4 = _mm256_fmadd_ps(dst4, src17, weight17); + dst5 = _mm256_fmadd_ps(dst5, src27, weight17); + __m256 weight27 = _mm256_load_ps(weight + 240); + dst6 = _mm256_fmadd_ps(dst6, src07, weight27); + dst7 = _mm256_fmadd_ps(dst7, src17, weight27); + dst8 = _mm256_fmadd_ps(dst8, src27, weight27); + __m256 weight37 = _mm256_load_ps(weight + 248); + dst9 = _mm256_fmadd_ps(dst9, src07, weight37); + dst10 = _mm256_fmadd_ps(dst10, src17, weight37); + dst11 = _mm256_fmadd_ps(dst11, src27, weight37); + src = src + src_stride; + weight += 1024; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst9 = _mm256_min_ps(dst9, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + dst10 = _mm256_min_ps(dst10, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst8 = _mm256_min_ps(dst8, relu6); + dst11 = _mm256_min_ps(dst11, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst9 = _mm256_max_ps(dst9, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst10 = _mm256_max_ps(dst10, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst11 = _mm256_max_ps(dst11, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst9 = _mm256_max_ps(dst9, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst10 = _mm256_max_ps(dst10, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst11 = _mm256_max_ps(dst11, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 1 * src_stride + 0, dst3); + _mm256_store_ps(dst + 1 * src_stride + 8, dst4); + _mm256_store_ps(dst + 1 * src_stride + 16, dst5); + _mm256_store_ps(dst + 2 * src_stride + 0, dst6); + _mm256_store_ps(dst + 2 * src_stride + 8, dst7); + _mm256_store_ps(dst + 2 * src_stride + 16, dst8); + _mm256_store_ps(dst + 3 * src_stride + 0, dst9); + _mm256_store_ps(dst + 3 * src_stride + 8, dst10); + _mm256_store_ps(dst + 3 * src_stride + 16, dst11); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..b0f09911f29d03785bb607644912d7313286688a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,302 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_4 = dst + 3 * dst_stride; + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm3\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm4\n" + "vmovups 64(%[dst], %[dst_stride], 1), %%ymm5\n" + "vmovups 0(%[dst], %[dst_stride], 2), %%ymm6\n" + "vmovups 32(%[dst], %[dst_stride], 2), %%ymm7\n" + "vmovups 64(%[dst], %[dst_stride], 2), %%ymm8\n" + "vmovups 0(%[dst_4]), %%ymm9\n" + "vmovups 32(%[dst_4]), %%ymm10\n" + "vmovups 64(%[dst_4]), %%ymm11\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 32(%[bias]), %%ymm3\n" + "vmovaps 32(%[bias]), %%ymm4\n" + "vmovaps 32(%[bias]), %%ymm5\n" + "vmovaps 64(%[bias]), %%ymm6\n" + "vmovaps 64(%[bias]), %%ymm7\n" + "vmovaps 64(%[bias]), %%ymm8\n" + "vmovaps 96(%[bias]), %%ymm9\n" + "vmovaps 96(%[bias]), %%ymm10\n" + "vmovaps 96(%[bias]), %%ymm11\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_4 ] "r"(dst_4) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11"); + asm volatile( + "0:\n" + // block 0 + "vbroadcastss 0(%[src]), %%ymm15\n" + "vbroadcastss 32(%[src]), %%ymm14\n" + "vbroadcastss 64(%[src]), %%ymm13\n" + "vmovaps 0(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + "vmovaps 32(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + "vmovaps 64(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vmovaps 96(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 1 + "vbroadcastss 1(%[src]), %%ymm15\n" + "vbroadcastss 33(%[src]), %%ymm14\n" + "vbroadcastss 65(%[src]), %%ymm13\n" + "vmovaps 128(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + "vmovaps 160(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + "vmovaps 192(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vmovaps 224(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 2 + "vbroadcastss 2(%[src]), %%ymm15\n" + "vbroadcastss 34(%[src]), %%ymm14\n" + "vbroadcastss 66(%[src]), %%ymm13\n" + "vmovaps 256(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + "vmovaps 288(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + "vmovaps 320(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vmovaps 352(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 3 + "vbroadcastss 3(%[src]), %%ymm15\n" + "vbroadcastss 35(%[src]), %%ymm14\n" + "vbroadcastss 67(%[src]), %%ymm13\n" + "vmovaps 384(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + "vmovaps 416(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + "vmovaps 448(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vmovaps 480(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 4 + "vbroadcastss 4(%[src]), %%ymm15\n" + "vbroadcastss 36(%[src]), %%ymm14\n" + "vbroadcastss 68(%[src]), %%ymm13\n" + "vmovaps 512(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + "vmovaps 544(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + "vmovaps 576(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vmovaps 608(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 5 + "vbroadcastss 5(%[src]), %%ymm15\n" + "vbroadcastss 37(%[src]), %%ymm14\n" + "vbroadcastss 69(%[src]), %%ymm13\n" + "vmovaps 640(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + "vmovaps 672(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + "vmovaps 704(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vmovaps 736(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 6 + "vbroadcastss 6(%[src]), %%ymm15\n" + "vbroadcastss 38(%[src]), %%ymm14\n" + "vbroadcastss 70(%[src]), %%ymm13\n" + "vmovaps 768(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + "vmovaps 800(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + "vmovaps 832(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vmovaps 864(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 7 + "vbroadcastss 7(%[src]), %%ymm15\n" + "vbroadcastss 39(%[src]), %%ymm14\n" + "vbroadcastss 71(%[src]), %%ymm13\n" + "vmovaps 896(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + "vmovaps 928(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + "vmovaps 960(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vmovaps 992(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + "dec %[deep]\n" + "add 1024, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "vmaxps %%ymm8, %%ymm15, %%ymm8\n" + "vmaxps %%ymm9, %%ymm15, %%ymm9\n" + "vmaxps %%ymm10, %%ymm15, %%ymm10\n" + "vmaxps %%ymm11, %%ymm15, %%ymm11\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "vminps %%ymm8, %%ymm14, %%ymm8\n" + "vminps %%ymm9, %%ymm14, %%ymm9\n" + "vminps %%ymm10, %%ymm14, %%ymm10\n" + "vminps %%ymm11, %%ymm14, %%ymm11\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm4, 32(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm5, 64(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm6, 0(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm7, 32(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm8, 64(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm9, 0(%[dst_4])\n" + "vmovups %%ymm10, 32(%[dst_4])\n" + "vmovups %%ymm11, 64(%[dst_4])\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_4 ] "r"(dst_4) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..aff74f6f09f2c2aca45d1fc8bf96d784188f1ecd --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,129 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 0); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..8c51f8064422d69b3bb124afeb9897b7e85f2bf5 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,149 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vbroadcastss 32(%[src]), %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" + "vbroadcastss 1(%[src]), %%ymm14\n" + "vbroadcastss 33(%[src]), %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vbroadcastss 34(%[src]), %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vbroadcastss 35(%[src]), %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + "vbroadcastss 36(%[src]), %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vbroadcastss 37(%[src]), %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vbroadcastss 38(%[src]), %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 7(%[src]), %%ymm14\n" + "vbroadcastss 39(%[src]), %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "dec %[deep]\n" + "add 256, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..4b34d19bea615bc94cd3c293fa31c5fc999c394b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32.c @@ -0,0 +1,225 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst4; + __m256 dst1; + __m256 dst5; + __m256 dst2; + __m256 dst6; + __m256 dst3; + __m256 dst7; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst4 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst5 = _mm256_load_ps(dst + 1 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst6 = _mm256_load_ps(dst + 1 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst7 = _mm256_load_ps(dst + 1 * dst_stride + 24); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + dst7 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 8); + dst1 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 8); + dst2 = _mm256_load_ps(bias + 0); + dst6 = _mm256_load_ps(bias + 8); + dst3 = _mm256_load_ps(bias + 0); + dst7 = _mm256_load_ps(bias + 8); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 weight10 = _mm256_load_ps(weight + 8); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst4 = _mm256_fmadd_ps(dst4, src00, weight10); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + dst5 = _mm256_fmadd_ps(dst5, src10, weight10); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + dst6 = _mm256_fmadd_ps(dst6, src20, weight10); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + dst7 = _mm256_fmadd_ps(dst7, src30, weight10); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 16); + __m256 weight11 = _mm256_load_ps(weight + 24); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst4 = _mm256_fmadd_ps(dst4, src01, weight11); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + dst5 = _mm256_fmadd_ps(dst5, src11, weight11); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + dst6 = _mm256_fmadd_ps(dst6, src21, weight11); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + dst7 = _mm256_fmadd_ps(dst7, src31, weight11); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 32); + __m256 weight12 = _mm256_load_ps(weight + 40); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst4 = _mm256_fmadd_ps(dst4, src02, weight12); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + dst5 = _mm256_fmadd_ps(dst5, src12, weight12); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + dst6 = _mm256_fmadd_ps(dst6, src22, weight12); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + dst7 = _mm256_fmadd_ps(dst7, src32, weight12); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 48); + __m256 weight13 = _mm256_load_ps(weight + 56); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst4 = _mm256_fmadd_ps(dst4, src03, weight13); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + dst5 = _mm256_fmadd_ps(dst5, src13, weight13); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + dst6 = _mm256_fmadd_ps(dst6, src23, weight13); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + dst7 = _mm256_fmadd_ps(dst7, src33, weight13); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 64); + __m256 weight14 = _mm256_load_ps(weight + 72); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + dst4 = _mm256_fmadd_ps(dst4, src04, weight14); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + dst5 = _mm256_fmadd_ps(dst5, src14, weight14); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + dst6 = _mm256_fmadd_ps(dst6, src24, weight14); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + dst7 = _mm256_fmadd_ps(dst7, src34, weight14); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 80); + __m256 weight15 = _mm256_load_ps(weight + 88); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + dst4 = _mm256_fmadd_ps(dst4, src05, weight15); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + dst5 = _mm256_fmadd_ps(dst5, src15, weight15); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + dst6 = _mm256_fmadd_ps(dst6, src25, weight15); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + dst7 = _mm256_fmadd_ps(dst7, src35, weight15); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 96); + __m256 weight16 = _mm256_load_ps(weight + 104); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst4 = _mm256_fmadd_ps(dst4, src06, weight16); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + dst5 = _mm256_fmadd_ps(dst5, src16, weight16); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + dst6 = _mm256_fmadd_ps(dst6, src26, weight16); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + dst7 = _mm256_fmadd_ps(dst7, src36, weight16); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 112); + __m256 weight17 = _mm256_load_ps(weight + 120); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst4 = _mm256_fmadd_ps(dst4, src07, weight17); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + dst5 = _mm256_fmadd_ps(dst5, src17, weight17); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + dst6 = _mm256_fmadd_ps(dst6, src27, weight17); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + dst7 = _mm256_fmadd_ps(dst7, src37, weight17); + src = src + src_stride; + weight += 512; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst7 = _mm256_max_ps(dst7, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst7 = _mm256_max_ps(dst7, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 1 * src_stride + 0, dst4); + _mm256_store_ps(dst + 1 * src_stride + 8, dst5); + _mm256_store_ps(dst + 1 * src_stride + 16, dst6); + _mm256_store_ps(dst + 1 * src_stride + 24, dst7); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..061f1d8cd49ea33e3ce6af092a758994cf815c66 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,235 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm4\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm5\n" + "vmovups 64(%[dst], %[dst_stride], 1), %%ymm6\n" + "vmovups 96(%[dst], %[dst_stride], 1), %%ymm7\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 32(%[bias]), %%ymm4\n" + "vmovaps 32(%[bias]), %%ymm5\n" + "vmovaps 32(%[bias]), %%ymm6\n" + "vmovaps 32(%[bias]), %%ymm7\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vbroadcastss 0(%[src]), %%ymm13\n" + "vbroadcastss 32(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vbroadcastss 64(%[src]), %%ymm13\n" + "vbroadcastss 96(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + // block 1 + "vmovaps 64(%[weight]), %%ymm15\n" + "vmovaps 96(%[weight]), %%ymm14\n" + "vbroadcastss 1(%[src]), %%ymm13\n" + "vbroadcastss 33(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vbroadcastss 65(%[src]), %%ymm13\n" + "vbroadcastss 97(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + // block 2 + "vmovaps 128(%[weight]), %%ymm15\n" + "vmovaps 160(%[weight]), %%ymm14\n" + "vbroadcastss 2(%[src]), %%ymm13\n" + "vbroadcastss 34(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vbroadcastss 66(%[src]), %%ymm13\n" + "vbroadcastss 98(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + // block 3 + "vmovaps 192(%[weight]), %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vbroadcastss 3(%[src]), %%ymm13\n" + "vbroadcastss 35(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vbroadcastss 67(%[src]), %%ymm13\n" + "vbroadcastss 99(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + // block 4 + "vmovaps 256(%[weight]), %%ymm15\n" + "vmovaps 288(%[weight]), %%ymm14\n" + "vbroadcastss 4(%[src]), %%ymm13\n" + "vbroadcastss 36(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vbroadcastss 68(%[src]), %%ymm13\n" + "vbroadcastss 100(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + // block 5 + "vmovaps 320(%[weight]), %%ymm15\n" + "vmovaps 352(%[weight]), %%ymm14\n" + "vbroadcastss 5(%[src]), %%ymm13\n" + "vbroadcastss 37(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vbroadcastss 69(%[src]), %%ymm13\n" + "vbroadcastss 101(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + // block 6 + "vmovaps 384(%[weight]), %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vbroadcastss 6(%[src]), %%ymm13\n" + "vbroadcastss 38(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vbroadcastss 70(%[src]), %%ymm13\n" + "vbroadcastss 102(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + // block 7 + "vmovaps 448(%[weight]), %%ymm15\n" + "vmovaps 480(%[weight]), %%ymm14\n" + "vbroadcastss 7(%[src]), %%ymm13\n" + "vbroadcastss 39(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vbroadcastss 71(%[src]), %%ymm13\n" + "vbroadcastss 103(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "dec %[deep]\n" + "add 512, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm5, 32(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm6, 64(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm7, 96(%[dst], %[dst_stride], 1)\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..8d987c6153a1a4c550e4846dbb7de79b65dea5f7 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32.c @@ -0,0 +1,297 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst4; + __m256 dst8; + __m256 dst1; + __m256 dst5; + __m256 dst9; + __m256 dst2; + __m256 dst6; + __m256 dst10; + __m256 dst3; + __m256 dst7; + __m256 dst11; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst4 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst8 = _mm256_load_ps(dst + 2 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst5 = _mm256_load_ps(dst + 1 * dst_stride + 8); + dst9 = _mm256_load_ps(dst + 2 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst6 = _mm256_load_ps(dst + 1 * dst_stride + 16); + dst10 = _mm256_load_ps(dst + 2 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst7 = _mm256_load_ps(dst + 1 * dst_stride + 24); + dst11 = _mm256_load_ps(dst + 2 * dst_stride + 24); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + dst7 = _mm256_setzero_ps(); + dst8 = _mm256_setzero_ps(); + dst9 = _mm256_setzero_ps(); + dst10 = _mm256_setzero_ps(); + dst11 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 8); + dst8 = _mm256_load_ps(bias + 16); + dst1 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 8); + dst9 = _mm256_load_ps(bias + 16); + dst2 = _mm256_load_ps(bias + 0); + dst6 = _mm256_load_ps(bias + 8); + dst10 = _mm256_load_ps(bias + 16); + dst3 = _mm256_load_ps(bias + 0); + dst7 = _mm256_load_ps(bias + 8); + dst11 = _mm256_load_ps(bias + 16); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 weight10 = _mm256_load_ps(weight + 8); + __m256 weight20 = _mm256_load_ps(weight + 16); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst4 = _mm256_fmadd_ps(dst4, src00, weight10); + dst8 = _mm256_fmadd_ps(dst8, src00, weight20); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + dst5 = _mm256_fmadd_ps(dst5, src10, weight10); + dst9 = _mm256_fmadd_ps(dst9, src10, weight20); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + dst6 = _mm256_fmadd_ps(dst6, src20, weight10); + dst10 = _mm256_fmadd_ps(dst10, src20, weight20); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + dst7 = _mm256_fmadd_ps(dst7, src30, weight10); + dst11 = _mm256_fmadd_ps(dst11, src30, weight20); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 24); + __m256 weight11 = _mm256_load_ps(weight + 32); + __m256 weight21 = _mm256_load_ps(weight + 40); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst4 = _mm256_fmadd_ps(dst4, src01, weight11); + dst8 = _mm256_fmadd_ps(dst8, src01, weight21); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + dst5 = _mm256_fmadd_ps(dst5, src11, weight11); + dst9 = _mm256_fmadd_ps(dst9, src11, weight21); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + dst6 = _mm256_fmadd_ps(dst6, src21, weight11); + dst10 = _mm256_fmadd_ps(dst10, src21, weight21); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + dst7 = _mm256_fmadd_ps(dst7, src31, weight11); + dst11 = _mm256_fmadd_ps(dst11, src31, weight21); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 48); + __m256 weight12 = _mm256_load_ps(weight + 56); + __m256 weight22 = _mm256_load_ps(weight + 64); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst4 = _mm256_fmadd_ps(dst4, src02, weight12); + dst8 = _mm256_fmadd_ps(dst8, src02, weight22); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + dst5 = _mm256_fmadd_ps(dst5, src12, weight12); + dst9 = _mm256_fmadd_ps(dst9, src12, weight22); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + dst6 = _mm256_fmadd_ps(dst6, src22, weight12); + dst10 = _mm256_fmadd_ps(dst10, src22, weight22); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + dst7 = _mm256_fmadd_ps(dst7, src32, weight12); + dst11 = _mm256_fmadd_ps(dst11, src32, weight22); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 72); + __m256 weight13 = _mm256_load_ps(weight + 80); + __m256 weight23 = _mm256_load_ps(weight + 88); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst4 = _mm256_fmadd_ps(dst4, src03, weight13); + dst8 = _mm256_fmadd_ps(dst8, src03, weight23); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + dst5 = _mm256_fmadd_ps(dst5, src13, weight13); + dst9 = _mm256_fmadd_ps(dst9, src13, weight23); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + dst6 = _mm256_fmadd_ps(dst6, src23, weight13); + dst10 = _mm256_fmadd_ps(dst10, src23, weight23); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + dst7 = _mm256_fmadd_ps(dst7, src33, weight13); + dst11 = _mm256_fmadd_ps(dst11, src33, weight23); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 96); + __m256 weight14 = _mm256_load_ps(weight + 104); + __m256 weight24 = _mm256_load_ps(weight + 112); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + dst4 = _mm256_fmadd_ps(dst4, src04, weight14); + dst8 = _mm256_fmadd_ps(dst8, src04, weight24); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + dst5 = _mm256_fmadd_ps(dst5, src14, weight14); + dst9 = _mm256_fmadd_ps(dst9, src14, weight24); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + dst6 = _mm256_fmadd_ps(dst6, src24, weight14); + dst10 = _mm256_fmadd_ps(dst10, src24, weight24); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + dst7 = _mm256_fmadd_ps(dst7, src34, weight14); + dst11 = _mm256_fmadd_ps(dst11, src34, weight24); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 120); + __m256 weight15 = _mm256_load_ps(weight + 128); + __m256 weight25 = _mm256_load_ps(weight + 136); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + dst4 = _mm256_fmadd_ps(dst4, src05, weight15); + dst8 = _mm256_fmadd_ps(dst8, src05, weight25); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + dst5 = _mm256_fmadd_ps(dst5, src15, weight15); + dst9 = _mm256_fmadd_ps(dst9, src15, weight25); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + dst6 = _mm256_fmadd_ps(dst6, src25, weight15); + dst10 = _mm256_fmadd_ps(dst10, src25, weight25); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + dst7 = _mm256_fmadd_ps(dst7, src35, weight15); + dst11 = _mm256_fmadd_ps(dst11, src35, weight25); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 144); + __m256 weight16 = _mm256_load_ps(weight + 152); + __m256 weight26 = _mm256_load_ps(weight + 160); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst4 = _mm256_fmadd_ps(dst4, src06, weight16); + dst8 = _mm256_fmadd_ps(dst8, src06, weight26); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + dst5 = _mm256_fmadd_ps(dst5, src16, weight16); + dst9 = _mm256_fmadd_ps(dst9, src16, weight26); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + dst6 = _mm256_fmadd_ps(dst6, src26, weight16); + dst10 = _mm256_fmadd_ps(dst10, src26, weight26); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + dst7 = _mm256_fmadd_ps(dst7, src36, weight16); + dst11 = _mm256_fmadd_ps(dst11, src36, weight26); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 168); + __m256 weight17 = _mm256_load_ps(weight + 176); + __m256 weight27 = _mm256_load_ps(weight + 184); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst4 = _mm256_fmadd_ps(dst4, src07, weight17); + dst8 = _mm256_fmadd_ps(dst8, src07, weight27); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + dst5 = _mm256_fmadd_ps(dst5, src17, weight17); + dst9 = _mm256_fmadd_ps(dst9, src17, weight27); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + dst6 = _mm256_fmadd_ps(dst6, src27, weight17); + dst10 = _mm256_fmadd_ps(dst10, src27, weight27); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + dst7 = _mm256_fmadd_ps(dst7, src37, weight17); + dst11 = _mm256_fmadd_ps(dst11, src37, weight27); + src = src + src_stride; + weight += 768; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst8 = _mm256_min_ps(dst8, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst9 = _mm256_min_ps(dst9, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst10 = _mm256_min_ps(dst10, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + dst11 = _mm256_min_ps(dst11, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst9 = _mm256_max_ps(dst9, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst10 = _mm256_max_ps(dst10, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst11 = _mm256_max_ps(dst11, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst9 = _mm256_max_ps(dst9, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst10 = _mm256_max_ps(dst10, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst11 = _mm256_max_ps(dst11, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 1 * src_stride + 0, dst4); + _mm256_store_ps(dst + 1 * src_stride + 8, dst5); + _mm256_store_ps(dst + 1 * src_stride + 16, dst6); + _mm256_store_ps(dst + 1 * src_stride + 24, dst7); + _mm256_store_ps(dst + 2 * src_stride + 0, dst8); + _mm256_store_ps(dst + 2 * src_stride + 8, dst9); + _mm256_store_ps(dst + 2 * src_stride + 16, dst10); + _mm256_store_ps(dst + 2 * src_stride + 24, dst11); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..d949566e0756508e8ef5598d6b02bcecb12fc384 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,299 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm4\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm5\n" + "vmovups 64(%[dst], %[dst_stride], 1), %%ymm6\n" + "vmovups 96(%[dst], %[dst_stride], 1), %%ymm7\n" + "vmovups 0(%[dst], %[dst_stride], 2), %%ymm8\n" + "vmovups 32(%[dst], %[dst_stride], 2), %%ymm9\n" + "vmovups 64(%[dst], %[dst_stride], 2), %%ymm10\n" + "vmovups 96(%[dst], %[dst_stride], 2), %%ymm11\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 32(%[bias]), %%ymm4\n" + "vmovaps 32(%[bias]), %%ymm5\n" + "vmovaps 32(%[bias]), %%ymm6\n" + "vmovaps 32(%[bias]), %%ymm7\n" + "vmovaps 64(%[bias]), %%ymm8\n" + "vmovaps 64(%[bias]), %%ymm9\n" + "vmovaps 64(%[bias]), %%ymm10\n" + "vmovaps 64(%[bias]), %%ymm11\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vmovaps 64(%[weight]), %%ymm13\n" + "vbroadcastss 0(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vbroadcastss 32(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n" + "vbroadcastss 96(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 1 + "vmovaps 96(%[weight]), %%ymm15\n" + "vmovaps 128(%[weight]), %%ymm14\n" + "vmovaps 160(%[weight]), %%ymm13\n" + "vbroadcastss 1(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vbroadcastss 33(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n" + "vbroadcastss 97(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 2 + "vmovaps 192(%[weight]), %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vmovaps 256(%[weight]), %%ymm13\n" + "vbroadcastss 2(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vbroadcastss 34(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n" + "vbroadcastss 98(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 3 + "vmovaps 288(%[weight]), %%ymm15\n" + "vmovaps 320(%[weight]), %%ymm14\n" + "vmovaps 352(%[weight]), %%ymm13\n" + "vbroadcastss 3(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vbroadcastss 35(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n" + "vbroadcastss 99(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 4 + "vmovaps 384(%[weight]), %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vmovaps 448(%[weight]), %%ymm13\n" + "vbroadcastss 4(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vbroadcastss 36(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n" + "vbroadcastss 100(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 5 + "vmovaps 480(%[weight]), %%ymm15\n" + "vmovaps 512(%[weight]), %%ymm14\n" + "vmovaps 544(%[weight]), %%ymm13\n" + "vbroadcastss 5(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vbroadcastss 37(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n" + "vbroadcastss 101(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 6 + "vmovaps 576(%[weight]), %%ymm15\n" + "vmovaps 608(%[weight]), %%ymm14\n" + "vmovaps 640(%[weight]), %%ymm13\n" + "vbroadcastss 6(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vbroadcastss 38(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n" + "vbroadcastss 102(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 7 + "vmovaps 672(%[weight]), %%ymm15\n" + "vmovaps 704(%[weight]), %%ymm14\n" + "vmovaps 736(%[weight]), %%ymm13\n" + "vbroadcastss 7(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vbroadcastss 39(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n" + "vbroadcastss 103(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + "dec %[deep]\n" + "add 768, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "vmaxps %%ymm8, %%ymm15, %%ymm8\n" + "vmaxps %%ymm9, %%ymm15, %%ymm9\n" + "vmaxps %%ymm10, %%ymm15, %%ymm10\n" + "vmaxps %%ymm11, %%ymm15, %%ymm11\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "vminps %%ymm8, %%ymm14, %%ymm8\n" + "vminps %%ymm9, %%ymm14, %%ymm9\n" + "vminps %%ymm10, %%ymm14, %%ymm10\n" + "vminps %%ymm11, %%ymm14, %%ymm11\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm5, 32(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm6, 64(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm7, 96(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm8, 0(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm9, 32(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm10, 64(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm11, 96(%[dst], %[dst_stride], 2)\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..c972e8a846b6d220110fc099d21ad24def80b9fa --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,153 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + __m256 dst3; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 0); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..05b800dfc1eb39b9ed42b7696e97d4b35ecad9c4 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,171 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vbroadcastss 32(%[src]), %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 96(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" + "vbroadcastss 1(%[src]), %%ymm14\n" + "vbroadcastss 33(%[src]), %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 97(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vbroadcastss 34(%[src]), %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 98(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vbroadcastss 35(%[src]), %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 99(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + "vbroadcastss 36(%[src]), %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 100(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vbroadcastss 37(%[src]), %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 101(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vbroadcastss 38(%[src]), %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 102(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 7(%[src]), %%ymm14\n" + "vbroadcastss 39(%[src]), %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 103(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "dec %[deep]\n" + "add 256, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..5426fb627ad7ca63de3818c9696658955776ce8f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32.c @@ -0,0 +1,265 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst5; + __m256 dst1; + __m256 dst6; + __m256 dst2; + __m256 dst7; + __m256 dst3; + __m256 dst8; + __m256 dst4; + __m256 dst9; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst5 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst6 = _mm256_load_ps(dst + 1 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst7 = _mm256_load_ps(dst + 1 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst8 = _mm256_load_ps(dst + 1 * dst_stride + 24); + dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32); + dst9 = _mm256_load_ps(dst + 1 * dst_stride + 32); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + dst7 = _mm256_setzero_ps(); + dst8 = _mm256_setzero_ps(); + dst9 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 8); + dst1 = _mm256_load_ps(bias + 0); + dst6 = _mm256_load_ps(bias + 8); + dst2 = _mm256_load_ps(bias + 0); + dst7 = _mm256_load_ps(bias + 8); + dst3 = _mm256_load_ps(bias + 0); + dst8 = _mm256_load_ps(bias + 8); + dst4 = _mm256_load_ps(bias + 0); + dst9 = _mm256_load_ps(bias + 8); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 weight10 = _mm256_load_ps(weight + 8); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst5 = _mm256_fmadd_ps(dst5, src00, weight10); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + dst6 = _mm256_fmadd_ps(dst6, src10, weight10); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + dst7 = _mm256_fmadd_ps(dst7, src20, weight10); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + dst8 = _mm256_fmadd_ps(dst8, src30, weight10); + __m256 src40 = _mm256_set1_ps(*(src + 32)); + dst4 = _mm256_fmadd_ps(dst4, src40, weight00); + dst9 = _mm256_fmadd_ps(dst9, src40, weight10); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 16); + __m256 weight11 = _mm256_load_ps(weight + 24); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst5 = _mm256_fmadd_ps(dst5, src01, weight11); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + dst6 = _mm256_fmadd_ps(dst6, src11, weight11); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + dst7 = _mm256_fmadd_ps(dst7, src21, weight11); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + dst8 = _mm256_fmadd_ps(dst8, src31, weight11); + __m256 src41 = _mm256_set1_ps(*(src + 33)); + dst4 = _mm256_fmadd_ps(dst4, src41, weight01); + dst9 = _mm256_fmadd_ps(dst9, src41, weight11); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 32); + __m256 weight12 = _mm256_load_ps(weight + 40); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst5 = _mm256_fmadd_ps(dst5, src02, weight12); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + dst6 = _mm256_fmadd_ps(dst6, src12, weight12); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + dst7 = _mm256_fmadd_ps(dst7, src22, weight12); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + dst8 = _mm256_fmadd_ps(dst8, src32, weight12); + __m256 src42 = _mm256_set1_ps(*(src + 34)); + dst4 = _mm256_fmadd_ps(dst4, src42, weight02); + dst9 = _mm256_fmadd_ps(dst9, src42, weight12); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 48); + __m256 weight13 = _mm256_load_ps(weight + 56); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst5 = _mm256_fmadd_ps(dst5, src03, weight13); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + dst6 = _mm256_fmadd_ps(dst6, src13, weight13); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + dst7 = _mm256_fmadd_ps(dst7, src23, weight13); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + dst8 = _mm256_fmadd_ps(dst8, src33, weight13); + __m256 src43 = _mm256_set1_ps(*(src + 35)); + dst4 = _mm256_fmadd_ps(dst4, src43, weight03); + dst9 = _mm256_fmadd_ps(dst9, src43, weight13); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 64); + __m256 weight14 = _mm256_load_ps(weight + 72); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + dst5 = _mm256_fmadd_ps(dst5, src04, weight14); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + dst6 = _mm256_fmadd_ps(dst6, src14, weight14); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + dst7 = _mm256_fmadd_ps(dst7, src24, weight14); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + dst8 = _mm256_fmadd_ps(dst8, src34, weight14); + __m256 src44 = _mm256_set1_ps(*(src + 36)); + dst4 = _mm256_fmadd_ps(dst4, src44, weight04); + dst9 = _mm256_fmadd_ps(dst9, src44, weight14); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 80); + __m256 weight15 = _mm256_load_ps(weight + 88); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + dst5 = _mm256_fmadd_ps(dst5, src05, weight15); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + dst6 = _mm256_fmadd_ps(dst6, src15, weight15); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + dst7 = _mm256_fmadd_ps(dst7, src25, weight15); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + dst8 = _mm256_fmadd_ps(dst8, src35, weight15); + __m256 src45 = _mm256_set1_ps(*(src + 37)); + dst4 = _mm256_fmadd_ps(dst4, src45, weight05); + dst9 = _mm256_fmadd_ps(dst9, src45, weight15); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 96); + __m256 weight16 = _mm256_load_ps(weight + 104); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst5 = _mm256_fmadd_ps(dst5, src06, weight16); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + dst6 = _mm256_fmadd_ps(dst6, src16, weight16); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + dst7 = _mm256_fmadd_ps(dst7, src26, weight16); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + dst8 = _mm256_fmadd_ps(dst8, src36, weight16); + __m256 src46 = _mm256_set1_ps(*(src + 38)); + dst4 = _mm256_fmadd_ps(dst4, src46, weight06); + dst9 = _mm256_fmadd_ps(dst9, src46, weight16); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 112); + __m256 weight17 = _mm256_load_ps(weight + 120); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst5 = _mm256_fmadd_ps(dst5, src07, weight17); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + dst6 = _mm256_fmadd_ps(dst6, src17, weight17); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + dst7 = _mm256_fmadd_ps(dst7, src27, weight17); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + dst8 = _mm256_fmadd_ps(dst8, src37, weight17); + __m256 src47 = _mm256_set1_ps(*(src + 39)); + dst4 = _mm256_fmadd_ps(dst4, src47, weight07); + dst9 = _mm256_fmadd_ps(dst9, src47, weight17); + src = src + src_stride; + weight += 512; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst8 = _mm256_min_ps(dst8, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst9 = _mm256_min_ps(dst9, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst9 = _mm256_max_ps(dst9, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst9 = _mm256_max_ps(dst9, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 0 * src_stride + 32, dst4); + _mm256_store_ps(dst + 1 * src_stride + 0, dst5); + _mm256_store_ps(dst + 1 * src_stride + 8, dst6); + _mm256_store_ps(dst + 1 * src_stride + 16, dst7); + _mm256_store_ps(dst + 1 * src_stride + 24, dst8); + _mm256_store_ps(dst + 1 * src_stride + 32, dst9); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..56f6aed3c69839963f62104564732c0e4bc67996 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,271 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 128(%[dst]), %%ymm4\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm5\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm6\n" + "vmovups 64(%[dst], %[dst_stride], 1), %%ymm7\n" + "vmovups 96(%[dst], %[dst_stride], 1), %%ymm8\n" + "vmovups 128(%[dst], %[dst_stride], 1), %%ymm9\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 0(%[bias]), %%ymm4\n" + "vmovaps 32(%[bias]), %%ymm5\n" + "vmovaps 32(%[bias]), %%ymm6\n" + "vmovaps 32(%[bias]), %%ymm7\n" + "vmovaps 32(%[bias]), %%ymm8\n" + "vmovaps 32(%[bias]), %%ymm9\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vbroadcastss 0(%[src]), %%ymm13\n" + "vbroadcastss 32(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vbroadcastss 64(%[src]), %%ymm13\n" + "vbroadcastss 96(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n" + "vbroadcastss 128(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n" + // block 1 + "vmovaps 64(%[weight]), %%ymm15\n" + "vmovaps 96(%[weight]), %%ymm14\n" + "vbroadcastss 1(%[src]), %%ymm13\n" + "vbroadcastss 33(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vbroadcastss 65(%[src]), %%ymm13\n" + "vbroadcastss 97(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n" + "vbroadcastss 129(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n" + // block 2 + "vmovaps 128(%[weight]), %%ymm15\n" + "vmovaps 160(%[weight]), %%ymm14\n" + "vbroadcastss 2(%[src]), %%ymm13\n" + "vbroadcastss 34(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vbroadcastss 66(%[src]), %%ymm13\n" + "vbroadcastss 98(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n" + "vbroadcastss 130(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n" + // block 3 + "vmovaps 192(%[weight]), %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vbroadcastss 3(%[src]), %%ymm13\n" + "vbroadcastss 35(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vbroadcastss 67(%[src]), %%ymm13\n" + "vbroadcastss 99(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n" + "vbroadcastss 131(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n" + // block 4 + "vmovaps 256(%[weight]), %%ymm15\n" + "vmovaps 288(%[weight]), %%ymm14\n" + "vbroadcastss 4(%[src]), %%ymm13\n" + "vbroadcastss 36(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vbroadcastss 68(%[src]), %%ymm13\n" + "vbroadcastss 100(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n" + "vbroadcastss 132(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n" + // block 5 + "vmovaps 320(%[weight]), %%ymm15\n" + "vmovaps 352(%[weight]), %%ymm14\n" + "vbroadcastss 5(%[src]), %%ymm13\n" + "vbroadcastss 37(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vbroadcastss 69(%[src]), %%ymm13\n" + "vbroadcastss 101(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n" + "vbroadcastss 133(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n" + // block 6 + "vmovaps 384(%[weight]), %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vbroadcastss 6(%[src]), %%ymm13\n" + "vbroadcastss 38(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vbroadcastss 70(%[src]), %%ymm13\n" + "vbroadcastss 102(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n" + "vbroadcastss 134(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n" + // block 7 + "vmovaps 448(%[weight]), %%ymm15\n" + "vmovaps 480(%[weight]), %%ymm14\n" + "vbroadcastss 7(%[src]), %%ymm13\n" + "vbroadcastss 39(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vbroadcastss 71(%[src]), %%ymm13\n" + "vbroadcastss 103(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n" + "vbroadcastss 135(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n" + "dec %[deep]\n" + "add 512, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "vmaxps %%ymm8, %%ymm15, %%ymm8\n" + "vmaxps %%ymm9, %%ymm15, %%ymm9\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "vminps %%ymm8, %%ymm14, %%ymm8\n" + "vminps %%ymm9, %%ymm14, %%ymm9\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 128(%[dst])\n" + "vmovups %%ymm5, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm6, 32(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm7, 64(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm8, 96(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm9, 128(%[dst], %[dst_stride], 1)\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..c3a6e1a6197c5cfa6ac970655fac724c1936e2a2 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,177 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + __m256 dst3; + __m256 dst4; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 0); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + __m256 src40 = _mm256_set1_ps(*(src + 32)); + dst4 = _mm256_fmadd_ps(dst4, src40, weight00); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + __m256 src41 = _mm256_set1_ps(*(src + 33)); + dst4 = _mm256_fmadd_ps(dst4, src41, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + __m256 src42 = _mm256_set1_ps(*(src + 34)); + dst4 = _mm256_fmadd_ps(dst4, src42, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + __m256 src43 = _mm256_set1_ps(*(src + 35)); + dst4 = _mm256_fmadd_ps(dst4, src43, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + __m256 src44 = _mm256_set1_ps(*(src + 36)); + dst4 = _mm256_fmadd_ps(dst4, src44, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + __m256 src45 = _mm256_set1_ps(*(src + 37)); + dst4 = _mm256_fmadd_ps(dst4, src45, weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + __m256 src46 = _mm256_set1_ps(*(src + 38)); + dst4 = _mm256_fmadd_ps(dst4, src46, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + __m256 src47 = _mm256_set1_ps(*(src + 39)); + dst4 = _mm256_fmadd_ps(dst4, src47, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 0 * src_stride + 32, dst4); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..069f81ee1a8fd43f64676f83eadd18f698cbebd2 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,193 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 128(%[dst]), %%ymm4\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 0(%[bias]), %%ymm4\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vbroadcastss 32(%[src]), %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 96(%[src]), %%ymm14\n" + "vbroadcastss 128(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" + "vbroadcastss 1(%[src]), %%ymm14\n" + "vbroadcastss 33(%[src]), %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 97(%[src]), %%ymm14\n" + "vbroadcastss 129(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vbroadcastss 34(%[src]), %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 98(%[src]), %%ymm14\n" + "vbroadcastss 130(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vbroadcastss 35(%[src]), %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 99(%[src]), %%ymm14\n" + "vbroadcastss 131(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + "vbroadcastss 36(%[src]), %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 100(%[src]), %%ymm14\n" + "vbroadcastss 132(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vbroadcastss 37(%[src]), %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 101(%[src]), %%ymm14\n" + "vbroadcastss 133(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vbroadcastss 38(%[src]), %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 102(%[src]), %%ymm14\n" + "vbroadcastss 134(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 7(%[src]), %%ymm14\n" + "vbroadcastss 39(%[src]), %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 103(%[src]), %%ymm14\n" + "vbroadcastss 135(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "dec %[deep]\n" + "add 256, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 128(%[dst])\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..5e074d19f9a637e8800e350f2d5ffd4153bc76eb --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32.c @@ -0,0 +1,305 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst6; + __m256 dst1; + __m256 dst7; + __m256 dst2; + __m256 dst8; + __m256 dst3; + __m256 dst9; + __m256 dst4; + __m256 dst10; + __m256 dst5; + __m256 dst11; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst6 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst7 = _mm256_load_ps(dst + 1 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst8 = _mm256_load_ps(dst + 1 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst9 = _mm256_load_ps(dst + 1 * dst_stride + 24); + dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32); + dst10 = _mm256_load_ps(dst + 1 * dst_stride + 32); + dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40); + dst11 = _mm256_load_ps(dst + 1 * dst_stride + 40); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + dst7 = _mm256_setzero_ps(); + dst8 = _mm256_setzero_ps(); + dst9 = _mm256_setzero_ps(); + dst10 = _mm256_setzero_ps(); + dst11 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst6 = _mm256_load_ps(bias + 8); + dst1 = _mm256_load_ps(bias + 0); + dst7 = _mm256_load_ps(bias + 8); + dst2 = _mm256_load_ps(bias + 0); + dst8 = _mm256_load_ps(bias + 8); + dst3 = _mm256_load_ps(bias + 0); + dst9 = _mm256_load_ps(bias + 8); + dst4 = _mm256_load_ps(bias + 0); + dst10 = _mm256_load_ps(bias + 8); + dst5 = _mm256_load_ps(bias + 0); + dst11 = _mm256_load_ps(bias + 8); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 weight10 = _mm256_load_ps(weight + 8); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst6 = _mm256_fmadd_ps(dst6, src00, weight10); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + dst7 = _mm256_fmadd_ps(dst7, src10, weight10); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + dst8 = _mm256_fmadd_ps(dst8, src20, weight10); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + dst9 = _mm256_fmadd_ps(dst9, src30, weight10); + __m256 src40 = _mm256_set1_ps(*(src + 32)); + dst4 = _mm256_fmadd_ps(dst4, src40, weight00); + dst10 = _mm256_fmadd_ps(dst10, src40, weight10); + __m256 src50 = _mm256_set1_ps(*(src + 40)); + dst5 = _mm256_fmadd_ps(dst5, src50, weight00); + dst11 = _mm256_fmadd_ps(dst11, src50, weight10); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 16); + __m256 weight11 = _mm256_load_ps(weight + 24); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst6 = _mm256_fmadd_ps(dst6, src01, weight11); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + dst7 = _mm256_fmadd_ps(dst7, src11, weight11); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + dst8 = _mm256_fmadd_ps(dst8, src21, weight11); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + dst9 = _mm256_fmadd_ps(dst9, src31, weight11); + __m256 src41 = _mm256_set1_ps(*(src + 33)); + dst4 = _mm256_fmadd_ps(dst4, src41, weight01); + dst10 = _mm256_fmadd_ps(dst10, src41, weight11); + __m256 src51 = _mm256_set1_ps(*(src + 41)); + dst5 = _mm256_fmadd_ps(dst5, src51, weight01); + dst11 = _mm256_fmadd_ps(dst11, src51, weight11); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 32); + __m256 weight12 = _mm256_load_ps(weight + 40); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst6 = _mm256_fmadd_ps(dst6, src02, weight12); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + dst7 = _mm256_fmadd_ps(dst7, src12, weight12); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + dst8 = _mm256_fmadd_ps(dst8, src22, weight12); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + dst9 = _mm256_fmadd_ps(dst9, src32, weight12); + __m256 src42 = _mm256_set1_ps(*(src + 34)); + dst4 = _mm256_fmadd_ps(dst4, src42, weight02); + dst10 = _mm256_fmadd_ps(dst10, src42, weight12); + __m256 src52 = _mm256_set1_ps(*(src + 42)); + dst5 = _mm256_fmadd_ps(dst5, src52, weight02); + dst11 = _mm256_fmadd_ps(dst11, src52, weight12); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 48); + __m256 weight13 = _mm256_load_ps(weight + 56); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst6 = _mm256_fmadd_ps(dst6, src03, weight13); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + dst7 = _mm256_fmadd_ps(dst7, src13, weight13); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + dst8 = _mm256_fmadd_ps(dst8, src23, weight13); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + dst9 = _mm256_fmadd_ps(dst9, src33, weight13); + __m256 src43 = _mm256_set1_ps(*(src + 35)); + dst4 = _mm256_fmadd_ps(dst4, src43, weight03); + dst10 = _mm256_fmadd_ps(dst10, src43, weight13); + __m256 src53 = _mm256_set1_ps(*(src + 43)); + dst5 = _mm256_fmadd_ps(dst5, src53, weight03); + dst11 = _mm256_fmadd_ps(dst11, src53, weight13); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 64); + __m256 weight14 = _mm256_load_ps(weight + 72); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + dst6 = _mm256_fmadd_ps(dst6, src04, weight14); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + dst7 = _mm256_fmadd_ps(dst7, src14, weight14); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + dst8 = _mm256_fmadd_ps(dst8, src24, weight14); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + dst9 = _mm256_fmadd_ps(dst9, src34, weight14); + __m256 src44 = _mm256_set1_ps(*(src + 36)); + dst4 = _mm256_fmadd_ps(dst4, src44, weight04); + dst10 = _mm256_fmadd_ps(dst10, src44, weight14); + __m256 src54 = _mm256_set1_ps(*(src + 44)); + dst5 = _mm256_fmadd_ps(dst5, src54, weight04); + dst11 = _mm256_fmadd_ps(dst11, src54, weight14); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 80); + __m256 weight15 = _mm256_load_ps(weight + 88); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + dst6 = _mm256_fmadd_ps(dst6, src05, weight15); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + dst7 = _mm256_fmadd_ps(dst7, src15, weight15); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + dst8 = _mm256_fmadd_ps(dst8, src25, weight15); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + dst9 = _mm256_fmadd_ps(dst9, src35, weight15); + __m256 src45 = _mm256_set1_ps(*(src + 37)); + dst4 = _mm256_fmadd_ps(dst4, src45, weight05); + dst10 = _mm256_fmadd_ps(dst10, src45, weight15); + __m256 src55 = _mm256_set1_ps(*(src + 45)); + dst5 = _mm256_fmadd_ps(dst5, src55, weight05); + dst11 = _mm256_fmadd_ps(dst11, src55, weight15); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 96); + __m256 weight16 = _mm256_load_ps(weight + 104); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst6 = _mm256_fmadd_ps(dst6, src06, weight16); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + dst7 = _mm256_fmadd_ps(dst7, src16, weight16); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + dst8 = _mm256_fmadd_ps(dst8, src26, weight16); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + dst9 = _mm256_fmadd_ps(dst9, src36, weight16); + __m256 src46 = _mm256_set1_ps(*(src + 38)); + dst4 = _mm256_fmadd_ps(dst4, src46, weight06); + dst10 = _mm256_fmadd_ps(dst10, src46, weight16); + __m256 src56 = _mm256_set1_ps(*(src + 46)); + dst5 = _mm256_fmadd_ps(dst5, src56, weight06); + dst11 = _mm256_fmadd_ps(dst11, src56, weight16); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 112); + __m256 weight17 = _mm256_load_ps(weight + 120); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst6 = _mm256_fmadd_ps(dst6, src07, weight17); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + dst7 = _mm256_fmadd_ps(dst7, src17, weight17); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + dst8 = _mm256_fmadd_ps(dst8, src27, weight17); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + dst9 = _mm256_fmadd_ps(dst9, src37, weight17); + __m256 src47 = _mm256_set1_ps(*(src + 39)); + dst4 = _mm256_fmadd_ps(dst4, src47, weight07); + dst10 = _mm256_fmadd_ps(dst10, src47, weight17); + __m256 src57 = _mm256_set1_ps(*(src + 47)); + dst5 = _mm256_fmadd_ps(dst5, src57, weight07); + dst11 = _mm256_fmadd_ps(dst11, src57, weight17); + src = src + src_stride; + weight += 512; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst8 = _mm256_min_ps(dst8, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst9 = _mm256_min_ps(dst9, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst10 = _mm256_min_ps(dst10, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst11 = _mm256_min_ps(dst11, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst9 = _mm256_max_ps(dst9, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst10 = _mm256_max_ps(dst10, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst11 = _mm256_max_ps(dst11, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst9 = _mm256_max_ps(dst9, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst10 = _mm256_max_ps(dst10, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst11 = _mm256_max_ps(dst11, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 0 * src_stride + 32, dst4); + _mm256_store_ps(dst + 0 * src_stride + 40, dst5); + _mm256_store_ps(dst + 1 * src_stride + 0, dst6); + _mm256_store_ps(dst + 1 * src_stride + 8, dst7); + _mm256_store_ps(dst + 1 * src_stride + 16, dst8); + _mm256_store_ps(dst + 1 * src_stride + 24, dst9); + _mm256_store_ps(dst + 1 * src_stride + 32, dst10); + _mm256_store_ps(dst + 1 * src_stride + 40, dst11); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..bb47361dd31c828e447c4b022ccab4696214f598 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,307 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 128(%[dst]), %%ymm4\n" + "vmovups 160(%[dst]), %%ymm5\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm6\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm7\n" + "vmovups 64(%[dst], %[dst_stride], 1), %%ymm8\n" + "vmovups 96(%[dst], %[dst_stride], 1), %%ymm9\n" + "vmovups 128(%[dst], %[dst_stride], 1), %%ymm10\n" + "vmovups 160(%[dst], %[dst_stride], 1), %%ymm11\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 0(%[bias]), %%ymm4\n" + "vmovaps 0(%[bias]), %%ymm5\n" + "vmovaps 32(%[bias]), %%ymm6\n" + "vmovaps 32(%[bias]), %%ymm7\n" + "vmovaps 32(%[bias]), %%ymm8\n" + "vmovaps 32(%[bias]), %%ymm9\n" + "vmovaps 32(%[bias]), %%ymm10\n" + "vmovaps 32(%[bias]), %%ymm11\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vbroadcastss 0(%[src]), %%ymm13\n" + "vbroadcastss 32(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vbroadcastss 64(%[src]), %%ymm13\n" + "vbroadcastss 96(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n" + "vbroadcastss 128(%[src]), %%ymm13\n" + "vbroadcastss 160(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n" + // block 1 + "vmovaps 64(%[weight]), %%ymm15\n" + "vmovaps 96(%[weight]), %%ymm14\n" + "vbroadcastss 1(%[src]), %%ymm13\n" + "vbroadcastss 33(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vbroadcastss 65(%[src]), %%ymm13\n" + "vbroadcastss 97(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n" + "vbroadcastss 129(%[src]), %%ymm13\n" + "vbroadcastss 161(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n" + // block 2 + "vmovaps 128(%[weight]), %%ymm15\n" + "vmovaps 160(%[weight]), %%ymm14\n" + "vbroadcastss 2(%[src]), %%ymm13\n" + "vbroadcastss 34(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vbroadcastss 66(%[src]), %%ymm13\n" + "vbroadcastss 98(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n" + "vbroadcastss 130(%[src]), %%ymm13\n" + "vbroadcastss 162(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n" + // block 3 + "vmovaps 192(%[weight]), %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vbroadcastss 3(%[src]), %%ymm13\n" + "vbroadcastss 35(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vbroadcastss 67(%[src]), %%ymm13\n" + "vbroadcastss 99(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n" + "vbroadcastss 131(%[src]), %%ymm13\n" + "vbroadcastss 163(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n" + // block 4 + "vmovaps 256(%[weight]), %%ymm15\n" + "vmovaps 288(%[weight]), %%ymm14\n" + "vbroadcastss 4(%[src]), %%ymm13\n" + "vbroadcastss 36(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vbroadcastss 68(%[src]), %%ymm13\n" + "vbroadcastss 100(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n" + "vbroadcastss 132(%[src]), %%ymm13\n" + "vbroadcastss 164(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n" + // block 5 + "vmovaps 320(%[weight]), %%ymm15\n" + "vmovaps 352(%[weight]), %%ymm14\n" + "vbroadcastss 5(%[src]), %%ymm13\n" + "vbroadcastss 37(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vbroadcastss 69(%[src]), %%ymm13\n" + "vbroadcastss 101(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n" + "vbroadcastss 133(%[src]), %%ymm13\n" + "vbroadcastss 165(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n" + // block 6 + "vmovaps 384(%[weight]), %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vbroadcastss 6(%[src]), %%ymm13\n" + "vbroadcastss 38(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vbroadcastss 70(%[src]), %%ymm13\n" + "vbroadcastss 102(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n" + "vbroadcastss 134(%[src]), %%ymm13\n" + "vbroadcastss 166(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n" + // block 7 + "vmovaps 448(%[weight]), %%ymm15\n" + "vmovaps 480(%[weight]), %%ymm14\n" + "vbroadcastss 7(%[src]), %%ymm13\n" + "vbroadcastss 39(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vbroadcastss 71(%[src]), %%ymm13\n" + "vbroadcastss 103(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n" + "vbroadcastss 135(%[src]), %%ymm13\n" + "vbroadcastss 167(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n" + "dec %[deep]\n" + "add 512, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "vmaxps %%ymm8, %%ymm15, %%ymm8\n" + "vmaxps %%ymm9, %%ymm15, %%ymm9\n" + "vmaxps %%ymm10, %%ymm15, %%ymm10\n" + "vmaxps %%ymm11, %%ymm15, %%ymm11\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "vminps %%ymm8, %%ymm14, %%ymm8\n" + "vminps %%ymm9, %%ymm14, %%ymm9\n" + "vminps %%ymm10, %%ymm14, %%ymm10\n" + "vminps %%ymm11, %%ymm14, %%ymm11\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 128(%[dst])\n" + "vmovups %%ymm5, 160(%[dst])\n" + "vmovups %%ymm6, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm7, 32(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm8, 64(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm9, 96(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm10, 128(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm11, 160(%[dst], %[dst_stride], 1)\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..7efdcd8d0e0f5376860d29693b87a4396764ece7 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,201 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + __m256 dst3; + __m256 dst4; + __m256 dst5; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32); + dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 0); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + __m256 src40 = _mm256_set1_ps(*(src + 32)); + dst4 = _mm256_fmadd_ps(dst4, src40, weight00); + __m256 src50 = _mm256_set1_ps(*(src + 40)); + dst5 = _mm256_fmadd_ps(dst5, src50, weight00); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + __m256 src41 = _mm256_set1_ps(*(src + 33)); + dst4 = _mm256_fmadd_ps(dst4, src41, weight01); + __m256 src51 = _mm256_set1_ps(*(src + 41)); + dst5 = _mm256_fmadd_ps(dst5, src51, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + __m256 src42 = _mm256_set1_ps(*(src + 34)); + dst4 = _mm256_fmadd_ps(dst4, src42, weight02); + __m256 src52 = _mm256_set1_ps(*(src + 42)); + dst5 = _mm256_fmadd_ps(dst5, src52, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + __m256 src43 = _mm256_set1_ps(*(src + 35)); + dst4 = _mm256_fmadd_ps(dst4, src43, weight03); + __m256 src53 = _mm256_set1_ps(*(src + 43)); + dst5 = _mm256_fmadd_ps(dst5, src53, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + __m256 src44 = _mm256_set1_ps(*(src + 36)); + dst4 = _mm256_fmadd_ps(dst4, src44, weight04); + __m256 src54 = _mm256_set1_ps(*(src + 44)); + dst5 = _mm256_fmadd_ps(dst5, src54, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + __m256 src45 = _mm256_set1_ps(*(src + 37)); + dst4 = _mm256_fmadd_ps(dst4, src45, weight05); + __m256 src55 = _mm256_set1_ps(*(src + 45)); + dst5 = _mm256_fmadd_ps(dst5, src55, weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + __m256 src46 = _mm256_set1_ps(*(src + 38)); + dst4 = _mm256_fmadd_ps(dst4, src46, weight06); + __m256 src56 = _mm256_set1_ps(*(src + 46)); + dst5 = _mm256_fmadd_ps(dst5, src56, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + __m256 src47 = _mm256_set1_ps(*(src + 39)); + dst4 = _mm256_fmadd_ps(dst4, src47, weight07); + __m256 src57 = _mm256_set1_ps(*(src + 47)); + dst5 = _mm256_fmadd_ps(dst5, src57, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 0 * src_stride + 32, dst4); + _mm256_store_ps(dst + 0 * src_stride + 40, dst5); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..7332ad3917a41c8942522e428519ac4467de1562 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,215 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 128(%[dst]), %%ymm4\n" + "vmovups 160(%[dst]), %%ymm5\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 0(%[bias]), %%ymm4\n" + "vmovaps 0(%[bias]), %%ymm5\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vbroadcastss 32(%[src]), %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 96(%[src]), %%ymm14\n" + "vbroadcastss 128(%[src]), %%ymm13\n" + "vbroadcastss 160(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" + "vbroadcastss 1(%[src]), %%ymm14\n" + "vbroadcastss 33(%[src]), %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 97(%[src]), %%ymm14\n" + "vbroadcastss 129(%[src]), %%ymm13\n" + "vbroadcastss 161(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vbroadcastss 34(%[src]), %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 98(%[src]), %%ymm14\n" + "vbroadcastss 130(%[src]), %%ymm13\n" + "vbroadcastss 162(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vbroadcastss 35(%[src]), %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 99(%[src]), %%ymm14\n" + "vbroadcastss 131(%[src]), %%ymm13\n" + "vbroadcastss 163(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + "vbroadcastss 36(%[src]), %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 100(%[src]), %%ymm14\n" + "vbroadcastss 132(%[src]), %%ymm13\n" + "vbroadcastss 164(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vbroadcastss 37(%[src]), %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 101(%[src]), %%ymm14\n" + "vbroadcastss 133(%[src]), %%ymm13\n" + "vbroadcastss 165(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vbroadcastss 38(%[src]), %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 102(%[src]), %%ymm14\n" + "vbroadcastss 134(%[src]), %%ymm13\n" + "vbroadcastss 166(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 7(%[src]), %%ymm14\n" + "vbroadcastss 39(%[src]), %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 103(%[src]), %%ymm14\n" + "vbroadcastss 135(%[src]), %%ymm13\n" + "vbroadcastss 167(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "dec %[deep]\n" + "add 256, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 128(%[dst])\n" + "vmovups %%ymm5, 160(%[dst])\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..e99e39da7ebab9fa149b4ca4a33dfe50acf1f4d1 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,225 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + __m256 dst3; + __m256 dst4; + __m256 dst5; + __m256 dst6; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32); + dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40); + dst6 = _mm256_load_ps(dst + 0 * dst_stride + 48); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 0); + dst6 = _mm256_load_ps(bias + 0); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + __m256 src40 = _mm256_set1_ps(*(src + 32)); + dst4 = _mm256_fmadd_ps(dst4, src40, weight00); + __m256 src50 = _mm256_set1_ps(*(src + 40)); + dst5 = _mm256_fmadd_ps(dst5, src50, weight00); + __m256 src60 = _mm256_set1_ps(*(src + 48)); + dst6 = _mm256_fmadd_ps(dst6, src60, weight00); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + __m256 src41 = _mm256_set1_ps(*(src + 33)); + dst4 = _mm256_fmadd_ps(dst4, src41, weight01); + __m256 src51 = _mm256_set1_ps(*(src + 41)); + dst5 = _mm256_fmadd_ps(dst5, src51, weight01); + __m256 src61 = _mm256_set1_ps(*(src + 49)); + dst6 = _mm256_fmadd_ps(dst6, src61, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + __m256 src42 = _mm256_set1_ps(*(src + 34)); + dst4 = _mm256_fmadd_ps(dst4, src42, weight02); + __m256 src52 = _mm256_set1_ps(*(src + 42)); + dst5 = _mm256_fmadd_ps(dst5, src52, weight02); + __m256 src62 = _mm256_set1_ps(*(src + 50)); + dst6 = _mm256_fmadd_ps(dst6, src62, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + __m256 src43 = _mm256_set1_ps(*(src + 35)); + dst4 = _mm256_fmadd_ps(dst4, src43, weight03); + __m256 src53 = _mm256_set1_ps(*(src + 43)); + dst5 = _mm256_fmadd_ps(dst5, src53, weight03); + __m256 src63 = _mm256_set1_ps(*(src + 51)); + dst6 = _mm256_fmadd_ps(dst6, src63, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + __m256 src44 = _mm256_set1_ps(*(src + 36)); + dst4 = _mm256_fmadd_ps(dst4, src44, weight04); + __m256 src54 = _mm256_set1_ps(*(src + 44)); + dst5 = _mm256_fmadd_ps(dst5, src54, weight04); + __m256 src64 = _mm256_set1_ps(*(src + 52)); + dst6 = _mm256_fmadd_ps(dst6, src64, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + __m256 src45 = _mm256_set1_ps(*(src + 37)); + dst4 = _mm256_fmadd_ps(dst4, src45, weight05); + __m256 src55 = _mm256_set1_ps(*(src + 45)); + dst5 = _mm256_fmadd_ps(dst5, src55, weight05); + __m256 src65 = _mm256_set1_ps(*(src + 53)); + dst6 = _mm256_fmadd_ps(dst6, src65, weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + __m256 src46 = _mm256_set1_ps(*(src + 38)); + dst4 = _mm256_fmadd_ps(dst4, src46, weight06); + __m256 src56 = _mm256_set1_ps(*(src + 46)); + dst5 = _mm256_fmadd_ps(dst5, src56, weight06); + __m256 src66 = _mm256_set1_ps(*(src + 54)); + dst6 = _mm256_fmadd_ps(dst6, src66, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + __m256 src47 = _mm256_set1_ps(*(src + 39)); + dst4 = _mm256_fmadd_ps(dst4, src47, weight07); + __m256 src57 = _mm256_set1_ps(*(src + 47)); + dst5 = _mm256_fmadd_ps(dst5, src57, weight07); + __m256 src67 = _mm256_set1_ps(*(src + 55)); + dst6 = _mm256_fmadd_ps(dst6, src67, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 0 * src_stride + 32, dst4); + _mm256_store_ps(dst + 0 * src_stride + 40, dst5); + _mm256_store_ps(dst + 0 * src_stride + 48, dst6); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..2d57354543c2aed39f541d6514bfee191bccdce4 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,237 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 128(%[dst]), %%ymm4\n" + "vmovups 160(%[dst]), %%ymm5\n" + "vmovups 192(%[dst]), %%ymm6\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 0(%[bias]), %%ymm4\n" + "vmovaps 0(%[bias]), %%ymm5\n" + "vmovaps 0(%[bias]), %%ymm6\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vbroadcastss 32(%[src]), %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 96(%[src]), %%ymm14\n" + "vbroadcastss 128(%[src]), %%ymm13\n" + "vbroadcastss 160(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 192(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" + "vbroadcastss 1(%[src]), %%ymm14\n" + "vbroadcastss 33(%[src]), %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 97(%[src]), %%ymm14\n" + "vbroadcastss 129(%[src]), %%ymm13\n" + "vbroadcastss 161(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 193(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vbroadcastss 34(%[src]), %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 98(%[src]), %%ymm14\n" + "vbroadcastss 130(%[src]), %%ymm13\n" + "vbroadcastss 162(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 194(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vbroadcastss 35(%[src]), %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 99(%[src]), %%ymm14\n" + "vbroadcastss 131(%[src]), %%ymm13\n" + "vbroadcastss 163(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 195(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + "vbroadcastss 36(%[src]), %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 100(%[src]), %%ymm14\n" + "vbroadcastss 132(%[src]), %%ymm13\n" + "vbroadcastss 164(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 196(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vbroadcastss 37(%[src]), %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 101(%[src]), %%ymm14\n" + "vbroadcastss 133(%[src]), %%ymm13\n" + "vbroadcastss 165(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 197(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vbroadcastss 38(%[src]), %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 102(%[src]), %%ymm14\n" + "vbroadcastss 134(%[src]), %%ymm13\n" + "vbroadcastss 166(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 198(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 7(%[src]), %%ymm14\n" + "vbroadcastss 39(%[src]), %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 103(%[src]), %%ymm14\n" + "vbroadcastss 135(%[src]), %%ymm13\n" + "vbroadcastss 167(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 199(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "dec %[deep]\n" + "add 256, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 128(%[dst])\n" + "vmovups %%ymm5, 160(%[dst])\n" + "vmovups %%ymm6, 192(%[dst])\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..e11e2af8412c26e329a8d2630b41c35d71b00581 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,249 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + __m256 dst3; + __m256 dst4; + __m256 dst5; + __m256 dst6; + __m256 dst7; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32); + dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40); + dst6 = _mm256_load_ps(dst + 0 * dst_stride + 48); + dst7 = _mm256_load_ps(dst + 0 * dst_stride + 56); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + dst7 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 0); + dst6 = _mm256_load_ps(bias + 0); + dst7 = _mm256_load_ps(bias + 0); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + __m256 src40 = _mm256_set1_ps(*(src + 32)); + dst4 = _mm256_fmadd_ps(dst4, src40, weight00); + __m256 src50 = _mm256_set1_ps(*(src + 40)); + dst5 = _mm256_fmadd_ps(dst5, src50, weight00); + __m256 src60 = _mm256_set1_ps(*(src + 48)); + dst6 = _mm256_fmadd_ps(dst6, src60, weight00); + __m256 src70 = _mm256_set1_ps(*(src + 56)); + dst7 = _mm256_fmadd_ps(dst7, src70, weight00); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + __m256 src41 = _mm256_set1_ps(*(src + 33)); + dst4 = _mm256_fmadd_ps(dst4, src41, weight01); + __m256 src51 = _mm256_set1_ps(*(src + 41)); + dst5 = _mm256_fmadd_ps(dst5, src51, weight01); + __m256 src61 = _mm256_set1_ps(*(src + 49)); + dst6 = _mm256_fmadd_ps(dst6, src61, weight01); + __m256 src71 = _mm256_set1_ps(*(src + 57)); + dst7 = _mm256_fmadd_ps(dst7, src71, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + __m256 src42 = _mm256_set1_ps(*(src + 34)); + dst4 = _mm256_fmadd_ps(dst4, src42, weight02); + __m256 src52 = _mm256_set1_ps(*(src + 42)); + dst5 = _mm256_fmadd_ps(dst5, src52, weight02); + __m256 src62 = _mm256_set1_ps(*(src + 50)); + dst6 = _mm256_fmadd_ps(dst6, src62, weight02); + __m256 src72 = _mm256_set1_ps(*(src + 58)); + dst7 = _mm256_fmadd_ps(dst7, src72, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + __m256 src43 = _mm256_set1_ps(*(src + 35)); + dst4 = _mm256_fmadd_ps(dst4, src43, weight03); + __m256 src53 = _mm256_set1_ps(*(src + 43)); + dst5 = _mm256_fmadd_ps(dst5, src53, weight03); + __m256 src63 = _mm256_set1_ps(*(src + 51)); + dst6 = _mm256_fmadd_ps(dst6, src63, weight03); + __m256 src73 = _mm256_set1_ps(*(src + 59)); + dst7 = _mm256_fmadd_ps(dst7, src73, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + __m256 src44 = _mm256_set1_ps(*(src + 36)); + dst4 = _mm256_fmadd_ps(dst4, src44, weight04); + __m256 src54 = _mm256_set1_ps(*(src + 44)); + dst5 = _mm256_fmadd_ps(dst5, src54, weight04); + __m256 src64 = _mm256_set1_ps(*(src + 52)); + dst6 = _mm256_fmadd_ps(dst6, src64, weight04); + __m256 src74 = _mm256_set1_ps(*(src + 60)); + dst7 = _mm256_fmadd_ps(dst7, src74, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + __m256 src45 = _mm256_set1_ps(*(src + 37)); + dst4 = _mm256_fmadd_ps(dst4, src45, weight05); + __m256 src55 = _mm256_set1_ps(*(src + 45)); + dst5 = _mm256_fmadd_ps(dst5, src55, weight05); + __m256 src65 = _mm256_set1_ps(*(src + 53)); + dst6 = _mm256_fmadd_ps(dst6, src65, weight05); + __m256 src75 = _mm256_set1_ps(*(src + 61)); + dst7 = _mm256_fmadd_ps(dst7, src75, weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + __m256 src46 = _mm256_set1_ps(*(src + 38)); + dst4 = _mm256_fmadd_ps(dst4, src46, weight06); + __m256 src56 = _mm256_set1_ps(*(src + 46)); + dst5 = _mm256_fmadd_ps(dst5, src56, weight06); + __m256 src66 = _mm256_set1_ps(*(src + 54)); + dst6 = _mm256_fmadd_ps(dst6, src66, weight06); + __m256 src76 = _mm256_set1_ps(*(src + 62)); + dst7 = _mm256_fmadd_ps(dst7, src76, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + __m256 src47 = _mm256_set1_ps(*(src + 39)); + dst4 = _mm256_fmadd_ps(dst4, src47, weight07); + __m256 src57 = _mm256_set1_ps(*(src + 47)); + dst5 = _mm256_fmadd_ps(dst5, src57, weight07); + __m256 src67 = _mm256_set1_ps(*(src + 55)); + dst6 = _mm256_fmadd_ps(dst6, src67, weight07); + __m256 src77 = _mm256_set1_ps(*(src + 63)); + dst7 = _mm256_fmadd_ps(dst7, src77, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst7 = _mm256_max_ps(dst7, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst7 = _mm256_max_ps(dst7, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 0 * src_stride + 32, dst4); + _mm256_store_ps(dst + 0 * src_stride + 40, dst5); + _mm256_store_ps(dst + 0 * src_stride + 48, dst6); + _mm256_store_ps(dst + 0 * src_stride + 56, dst7); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..97ad1c3bd04c2a04d10cc49b2eb1ea572cab4177 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,259 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 128(%[dst]), %%ymm4\n" + "vmovups 160(%[dst]), %%ymm5\n" + "vmovups 192(%[dst]), %%ymm6\n" + "vmovups 224(%[dst]), %%ymm7\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 0(%[bias]), %%ymm4\n" + "vmovaps 0(%[bias]), %%ymm5\n" + "vmovaps 0(%[bias]), %%ymm6\n" + "vmovaps 0(%[bias]), %%ymm7\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vbroadcastss 32(%[src]), %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 96(%[src]), %%ymm14\n" + "vbroadcastss 128(%[src]), %%ymm13\n" + "vbroadcastss 160(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 192(%[src]), %%ymm14\n" + "vbroadcastss 224(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" + "vbroadcastss 1(%[src]), %%ymm14\n" + "vbroadcastss 33(%[src]), %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 97(%[src]), %%ymm14\n" + "vbroadcastss 129(%[src]), %%ymm13\n" + "vbroadcastss 161(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 193(%[src]), %%ymm14\n" + "vbroadcastss 225(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vbroadcastss 34(%[src]), %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 98(%[src]), %%ymm14\n" + "vbroadcastss 130(%[src]), %%ymm13\n" + "vbroadcastss 162(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 194(%[src]), %%ymm14\n" + "vbroadcastss 226(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vbroadcastss 35(%[src]), %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 99(%[src]), %%ymm14\n" + "vbroadcastss 131(%[src]), %%ymm13\n" + "vbroadcastss 163(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 195(%[src]), %%ymm14\n" + "vbroadcastss 227(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + "vbroadcastss 36(%[src]), %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 100(%[src]), %%ymm14\n" + "vbroadcastss 132(%[src]), %%ymm13\n" + "vbroadcastss 164(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 196(%[src]), %%ymm14\n" + "vbroadcastss 228(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vbroadcastss 37(%[src]), %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 101(%[src]), %%ymm14\n" + "vbroadcastss 133(%[src]), %%ymm13\n" + "vbroadcastss 165(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 197(%[src]), %%ymm14\n" + "vbroadcastss 229(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vbroadcastss 38(%[src]), %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 102(%[src]), %%ymm14\n" + "vbroadcastss 134(%[src]), %%ymm13\n" + "vbroadcastss 166(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 198(%[src]), %%ymm14\n" + "vbroadcastss 230(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 7(%[src]), %%ymm14\n" + "vbroadcastss 39(%[src]), %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 103(%[src]), %%ymm14\n" + "vbroadcastss 135(%[src]), %%ymm13\n" + "vbroadcastss 167(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 199(%[src]), %%ymm14\n" + "vbroadcastss 231(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "dec %[deep]\n" + "add 256, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 128(%[dst])\n" + "vmovups %%ymm5, 160(%[dst])\n" + "vmovups %%ymm6, 192(%[dst])\n" + "vmovups %%ymm7, 224(%[dst])\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..3dc30cff4e321a8d6f2d5981ad462a86fba2797b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,273 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + __m256 dst3; + __m256 dst4; + __m256 dst5; + __m256 dst6; + __m256 dst7; + __m256 dst8; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32); + dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40); + dst6 = _mm256_load_ps(dst + 0 * dst_stride + 48); + dst7 = _mm256_load_ps(dst + 0 * dst_stride + 56); + dst8 = _mm256_load_ps(dst + 0 * dst_stride + 64); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + dst7 = _mm256_setzero_ps(); + dst8 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 0); + dst6 = _mm256_load_ps(bias + 0); + dst7 = _mm256_load_ps(bias + 0); + dst8 = _mm256_load_ps(bias + 0); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + __m256 src40 = _mm256_set1_ps(*(src + 32)); + dst4 = _mm256_fmadd_ps(dst4, src40, weight00); + __m256 src50 = _mm256_set1_ps(*(src + 40)); + dst5 = _mm256_fmadd_ps(dst5, src50, weight00); + __m256 src60 = _mm256_set1_ps(*(src + 48)); + dst6 = _mm256_fmadd_ps(dst6, src60, weight00); + __m256 src70 = _mm256_set1_ps(*(src + 56)); + dst7 = _mm256_fmadd_ps(dst7, src70, weight00); + __m256 src80 = _mm256_set1_ps(*(src + 64)); + dst8 = _mm256_fmadd_ps(dst8, src80, weight00); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + __m256 src41 = _mm256_set1_ps(*(src + 33)); + dst4 = _mm256_fmadd_ps(dst4, src41, weight01); + __m256 src51 = _mm256_set1_ps(*(src + 41)); + dst5 = _mm256_fmadd_ps(dst5, src51, weight01); + __m256 src61 = _mm256_set1_ps(*(src + 49)); + dst6 = _mm256_fmadd_ps(dst6, src61, weight01); + __m256 src71 = _mm256_set1_ps(*(src + 57)); + dst7 = _mm256_fmadd_ps(dst7, src71, weight01); + __m256 src81 = _mm256_set1_ps(*(src + 65)); + dst8 = _mm256_fmadd_ps(dst8, src81, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + __m256 src42 = _mm256_set1_ps(*(src + 34)); + dst4 = _mm256_fmadd_ps(dst4, src42, weight02); + __m256 src52 = _mm256_set1_ps(*(src + 42)); + dst5 = _mm256_fmadd_ps(dst5, src52, weight02); + __m256 src62 = _mm256_set1_ps(*(src + 50)); + dst6 = _mm256_fmadd_ps(dst6, src62, weight02); + __m256 src72 = _mm256_set1_ps(*(src + 58)); + dst7 = _mm256_fmadd_ps(dst7, src72, weight02); + __m256 src82 = _mm256_set1_ps(*(src + 66)); + dst8 = _mm256_fmadd_ps(dst8, src82, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + __m256 src43 = _mm256_set1_ps(*(src + 35)); + dst4 = _mm256_fmadd_ps(dst4, src43, weight03); + __m256 src53 = _mm256_set1_ps(*(src + 43)); + dst5 = _mm256_fmadd_ps(dst5, src53, weight03); + __m256 src63 = _mm256_set1_ps(*(src + 51)); + dst6 = _mm256_fmadd_ps(dst6, src63, weight03); + __m256 src73 = _mm256_set1_ps(*(src + 59)); + dst7 = _mm256_fmadd_ps(dst7, src73, weight03); + __m256 src83 = _mm256_set1_ps(*(src + 67)); + dst8 = _mm256_fmadd_ps(dst8, src83, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + __m256 src44 = _mm256_set1_ps(*(src + 36)); + dst4 = _mm256_fmadd_ps(dst4, src44, weight04); + __m256 src54 = _mm256_set1_ps(*(src + 44)); + dst5 = _mm256_fmadd_ps(dst5, src54, weight04); + __m256 src64 = _mm256_set1_ps(*(src + 52)); + dst6 = _mm256_fmadd_ps(dst6, src64, weight04); + __m256 src74 = _mm256_set1_ps(*(src + 60)); + dst7 = _mm256_fmadd_ps(dst7, src74, weight04); + __m256 src84 = _mm256_set1_ps(*(src + 68)); + dst8 = _mm256_fmadd_ps(dst8, src84, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + __m256 src45 = _mm256_set1_ps(*(src + 37)); + dst4 = _mm256_fmadd_ps(dst4, src45, weight05); + __m256 src55 = _mm256_set1_ps(*(src + 45)); + dst5 = _mm256_fmadd_ps(dst5, src55, weight05); + __m256 src65 = _mm256_set1_ps(*(src + 53)); + dst6 = _mm256_fmadd_ps(dst6, src65, weight05); + __m256 src75 = _mm256_set1_ps(*(src + 61)); + dst7 = _mm256_fmadd_ps(dst7, src75, weight05); + __m256 src85 = _mm256_set1_ps(*(src + 69)); + dst8 = _mm256_fmadd_ps(dst8, src85, weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + __m256 src46 = _mm256_set1_ps(*(src + 38)); + dst4 = _mm256_fmadd_ps(dst4, src46, weight06); + __m256 src56 = _mm256_set1_ps(*(src + 46)); + dst5 = _mm256_fmadd_ps(dst5, src56, weight06); + __m256 src66 = _mm256_set1_ps(*(src + 54)); + dst6 = _mm256_fmadd_ps(dst6, src66, weight06); + __m256 src76 = _mm256_set1_ps(*(src + 62)); + dst7 = _mm256_fmadd_ps(dst7, src76, weight06); + __m256 src86 = _mm256_set1_ps(*(src + 70)); + dst8 = _mm256_fmadd_ps(dst8, src86, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + __m256 src47 = _mm256_set1_ps(*(src + 39)); + dst4 = _mm256_fmadd_ps(dst4, src47, weight07); + __m256 src57 = _mm256_set1_ps(*(src + 47)); + dst5 = _mm256_fmadd_ps(dst5, src57, weight07); + __m256 src67 = _mm256_set1_ps(*(src + 55)); + dst6 = _mm256_fmadd_ps(dst6, src67, weight07); + __m256 src77 = _mm256_set1_ps(*(src + 63)); + dst7 = _mm256_fmadd_ps(dst7, src77, weight07); + __m256 src87 = _mm256_set1_ps(*(src + 71)); + dst8 = _mm256_fmadd_ps(dst8, src87, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + dst8 = _mm256_min_ps(dst8, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst8 = _mm256_max_ps(dst8, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst8 = _mm256_max_ps(dst8, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 0 * src_stride + 32, dst4); + _mm256_store_ps(dst + 0 * src_stride + 40, dst5); + _mm256_store_ps(dst + 0 * src_stride + 48, dst6); + _mm256_store_ps(dst + 0 * src_stride + 56, dst7); + _mm256_store_ps(dst + 0 * src_stride + 64, dst8); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32_asm.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 0000000000000000000000000000000000000000..599aa518755147ac6c05fc40e64d65e749f849f5 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,281 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 128(%[dst]), %%ymm4\n" + "vmovups 160(%[dst]), %%ymm5\n" + "vmovups 192(%[dst]), %%ymm6\n" + "vmovups 224(%[dst]), %%ymm7\n" + "vmovups 256(%[dst]), %%ymm8\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 0(%[bias]), %%ymm4\n" + "vmovaps 0(%[bias]), %%ymm5\n" + "vmovaps 0(%[bias]), %%ymm6\n" + "vmovaps 0(%[bias]), %%ymm7\n" + "vmovaps 0(%[bias]), %%ymm8\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vbroadcastss 32(%[src]), %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 96(%[src]), %%ymm14\n" + "vbroadcastss 128(%[src]), %%ymm13\n" + "vbroadcastss 160(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 192(%[src]), %%ymm14\n" + "vbroadcastss 224(%[src]), %%ymm13\n" + "vbroadcastss 256(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" + "vbroadcastss 1(%[src]), %%ymm14\n" + "vbroadcastss 33(%[src]), %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 97(%[src]), %%ymm14\n" + "vbroadcastss 129(%[src]), %%ymm13\n" + "vbroadcastss 161(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 193(%[src]), %%ymm14\n" + "vbroadcastss 225(%[src]), %%ymm13\n" + "vbroadcastss 257(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vbroadcastss 34(%[src]), %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 98(%[src]), %%ymm14\n" + "vbroadcastss 130(%[src]), %%ymm13\n" + "vbroadcastss 162(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 194(%[src]), %%ymm14\n" + "vbroadcastss 226(%[src]), %%ymm13\n" + "vbroadcastss 258(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vbroadcastss 35(%[src]), %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 99(%[src]), %%ymm14\n" + "vbroadcastss 131(%[src]), %%ymm13\n" + "vbroadcastss 163(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 195(%[src]), %%ymm14\n" + "vbroadcastss 227(%[src]), %%ymm13\n" + "vbroadcastss 259(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + "vbroadcastss 36(%[src]), %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 100(%[src]), %%ymm14\n" + "vbroadcastss 132(%[src]), %%ymm13\n" + "vbroadcastss 164(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 196(%[src]), %%ymm14\n" + "vbroadcastss 228(%[src]), %%ymm13\n" + "vbroadcastss 260(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vbroadcastss 37(%[src]), %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 101(%[src]), %%ymm14\n" + "vbroadcastss 133(%[src]), %%ymm13\n" + "vbroadcastss 165(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 197(%[src]), %%ymm14\n" + "vbroadcastss 229(%[src]), %%ymm13\n" + "vbroadcastss 261(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vbroadcastss 38(%[src]), %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 102(%[src]), %%ymm14\n" + "vbroadcastss 134(%[src]), %%ymm13\n" + "vbroadcastss 166(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 198(%[src]), %%ymm14\n" + "vbroadcastss 230(%[src]), %%ymm13\n" + "vbroadcastss 262(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 7(%[src]), %%ymm14\n" + "vbroadcastss 39(%[src]), %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 103(%[src]), %%ymm14\n" + "vbroadcastss 135(%[src]), %%ymm13\n" + "vbroadcastss 167(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 199(%[src]), %%ymm14\n" + "vbroadcastss 231(%[src]), %%ymm13\n" + "vbroadcastss 263(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "dec %[deep]\n" + "add 256, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "vmaxps %%ymm8, %%ymm15, %%ymm8\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "vminps %%ymm8, %%ymm14, %%ymm8\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 128(%[dst])\n" + "vmovups %%ymm5, 160(%[dst])\n" + "vmovups %%ymm6, 192(%[dst])\n" + "vmovups %%ymm7, 224(%[dst])\n" + "vmovups %%ymm8, 256(%[dst])\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_10x16_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_10x16_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..283c6d764964051621de82778e4180a11c4e803a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_10x16_mask_kernel_nhwc_fp32.c @@ -0,0 +1,536 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_10x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag, + const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + const float *dst_9 = dst + 9 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_6]), %%zmm6\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_9]), %%zmm9\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 0(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + const float *src_9 = src + 9 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 0(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_6]), %%zmm24\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 4(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_6]), %%zmm24\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 8(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_6]), %%zmm24\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 12(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_6]), %%zmm24\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 16(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_6]), %%zmm24\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 20(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_6]), %%zmm24\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 24(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_6]), %%zmm24\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 28(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_6]), %%zmm24\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 32(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_6]), %%zmm24\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 36(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_6]), %%zmm24\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 40(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_6]), %%zmm24\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 44(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_6]), %%zmm24\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 48(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_6]), %%zmm24\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 52(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_6]), %%zmm24\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 56(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_6]), %%zmm24\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 60(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "add $64, %[src_9]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 0(%[src_9]), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "add $4, %[src_9]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0 %{{%%k1}}\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4 %{{%%k1}}\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6 %{{%%k1}}\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8 %{{%%k1}}\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4 %{{%%k1}}\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6 %{{%%k1}}\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8 %{{%%k1}}\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_6]) %{{%%k1}}\n" + "vmovups %%zmm7, 0(%[dst_6], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_6], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm9, 0(%[dst_9]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9), [ src_3 ] "r"(src_3), + [ src_6 ] "r"(src_6), [ src_9 ] "r"(src_9) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_10x32_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_10x32_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..b63e47e945810be294091179fdf49806aa1cbcb2 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_10x32_mask_kernel_nhwc_fp32.c @@ -0,0 +1,784 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_10x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag, + const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + const float *dst_9 = dst + 9 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_6]), %%zmm12\n" + "vmovups 64(%[dst_6]), %%zmm13\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm14\n" + "vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm15\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm16\n" + "vmovups 64(%[dst_6], %[dst_stride], 2), %%zmm17\n" + "vmovups 0(%[dst_9]), %%zmm18\n" + "vmovups 64(%[dst_9]), %%zmm19\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 0(%[bias]), %%zmm14\n" + "vmovups 64(%[bias]), %%zmm15\n" + "vmovups 0(%[bias]), %%zmm16\n" + "vmovups 64(%[bias]), %%zmm17\n" + "vmovups 0(%[bias]), %%zmm18\n" + "vmovups 64(%[bias]), %%zmm19\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + const float *src_9 = src + 9 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 0(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 4(%[src_6]), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 4(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 8(%[src_6]), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 8(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 12(%[src_6]), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 12(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 16(%[src_6]), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 16(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 20(%[src_6]), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 20(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 24(%[src_6]), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 24(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 28(%[src_6]), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 28(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 32(%[src_6]), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 32(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 36(%[src_6]), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 36(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 40(%[src_6]), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 40(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 44(%[src_6]), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 44(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 48(%[src_6]), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 48(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 52(%[src_6]), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 52(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 56(%[src_6]), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 56(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 60(%[src_6]), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 60(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "add $64, %[src_9]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm21\n" + "vbroadcastss 0(%[src_9]), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm20, %%zmm19 %{{%%k1}}\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "add $4, %[src_9]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13 %{{%%k1}}\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15 %{{%%k1}}\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17 %{{%%k1}}\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13 %{{%%k1}}\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15 %{{%%k1}}\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17 %{{%%k1}}\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_6])\n" + "vmovups %%zmm13, 64(%[dst_6]) %{{%%k1}}\n" + "vmovups %%zmm14, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm15, 64(%[dst_6], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm16, 0(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm17, 64(%[dst_6], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm18, 0(%[dst_9])\n" + "vmovups %%zmm19, 64(%[dst_9]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9), [ src_3 ] "r"(src_3), + [ src_6 ] "r"(src_6), [ src_9 ] "r"(src_9) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_11x16_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_11x16_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..5038ef8de6bff54f52cb93fd8b54a5523cb08427 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_11x16_mask_kernel_nhwc_fp32.c @@ -0,0 +1,577 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_11x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag, + const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + const float *dst_9 = dst + 9 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_6]), %%zmm6\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_9]), %%zmm9\n" + "vmovups 0(%[dst_9], %[dst_stride], 1), %%zmm10\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 0(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", + "%zmm10"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + const float *src_9 = src + 9 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 0(%[src_9]), %%zmm21\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_6]), %%zmm24\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 4(%[src_9]), %%zmm21\n" + "vbroadcastss 4(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_6]), %%zmm24\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 8(%[src_9]), %%zmm21\n" + "vbroadcastss 8(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_6]), %%zmm24\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 12(%[src_9]), %%zmm21\n" + "vbroadcastss 12(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_6]), %%zmm24\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 16(%[src_9]), %%zmm21\n" + "vbroadcastss 16(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_6]), %%zmm24\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 20(%[src_9]), %%zmm21\n" + "vbroadcastss 20(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_6]), %%zmm24\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 24(%[src_9]), %%zmm21\n" + "vbroadcastss 24(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_6]), %%zmm24\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 28(%[src_9]), %%zmm21\n" + "vbroadcastss 28(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_6]), %%zmm24\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 32(%[src_9]), %%zmm21\n" + "vbroadcastss 32(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_6]), %%zmm24\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 36(%[src_9]), %%zmm21\n" + "vbroadcastss 36(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_6]), %%zmm24\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 40(%[src_9]), %%zmm21\n" + "vbroadcastss 40(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_6]), %%zmm24\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 44(%[src_9]), %%zmm21\n" + "vbroadcastss 44(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_6]), %%zmm24\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 48(%[src_9]), %%zmm21\n" + "vbroadcastss 48(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_6]), %%zmm24\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 52(%[src_9]), %%zmm21\n" + "vbroadcastss 52(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_6]), %%zmm24\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 56(%[src_9]), %%zmm21\n" + "vbroadcastss 56(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_6]), %%zmm24\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 60(%[src_9]), %%zmm21\n" + "vbroadcastss 60(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "add $64, %[src_9]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 0(%[src_9]), %%zmm21\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm20\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "add $4, %[src_9]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0 %{{%%k1}}\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4 %{{%%k1}}\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6 %{{%%k1}}\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8 %{{%%k1}}\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4 %{{%%k1}}\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6 %{{%%k1}}\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8 %{{%%k1}}\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + "vminps %%zmm10, %%zmm30, %%zmm10 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_6]) %{{%%k1}}\n" + "vmovups %%zmm7, 0(%[dst_6], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_6], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm9, 0(%[dst_9]) %{{%%k1}}\n" + "vmovups %%zmm10, 0(%[dst_9], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9), [ src_3 ] "r"(src_3), + [ src_6 ] "r"(src_6), [ src_9 ] "r"(src_9) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_11x32_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_11x32_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..d6fbf5e94a3ddfca8f9eab8f03152f82d4a61686 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_11x32_mask_kernel_nhwc_fp32.c @@ -0,0 +1,847 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_11x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag, + const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + const float *dst_9 = dst + 9 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_6]), %%zmm12\n" + "vmovups 64(%[dst_6]), %%zmm13\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm14\n" + "vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm15\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm16\n" + "vmovups 64(%[dst_6], %[dst_stride], 2), %%zmm17\n" + "vmovups 0(%[dst_9]), %%zmm18\n" + "vmovups 64(%[dst_9]), %%zmm19\n" + "vmovups 0(%[dst_9], %[dst_stride], 1), %%zmm20\n" + "vmovups 64(%[dst_9], %[dst_stride], 1), %%zmm21\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 0(%[bias]), %%zmm14\n" + "vmovups 64(%[bias]), %%zmm15\n" + "vmovups 0(%[bias]), %%zmm16\n" + "vmovups 64(%[bias]), %%zmm17\n" + "vmovups 0(%[bias]), %%zmm18\n" + "vmovups 64(%[bias]), %%zmm19\n" + "vmovups 0(%[bias]), %%zmm20\n" + "vmovups 64(%[bias]), %%zmm21\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + "vxorps %%zmm21, %%zmm21, %%zmm21\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + const float *src_9 = src + 9 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 0(%[src_6]), %%zmm29\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_9]), %%zmm26\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 4(%[src_6]), %%zmm29\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_9]), %%zmm26\n" + "vbroadcastss 4(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 8(%[src_6]), %%zmm29\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_9]), %%zmm26\n" + "vbroadcastss 8(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 12(%[src_6]), %%zmm29\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_9]), %%zmm26\n" + "vbroadcastss 12(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 16(%[src_6]), %%zmm29\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_9]), %%zmm26\n" + "vbroadcastss 16(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 20(%[src_6]), %%zmm29\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_9]), %%zmm26\n" + "vbroadcastss 20(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 24(%[src_6]), %%zmm29\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_9]), %%zmm26\n" + "vbroadcastss 24(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 28(%[src_6]), %%zmm29\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_9]), %%zmm26\n" + "vbroadcastss 28(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 32(%[src_6]), %%zmm29\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_9]), %%zmm26\n" + "vbroadcastss 32(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 36(%[src_6]), %%zmm29\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_9]), %%zmm26\n" + "vbroadcastss 36(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 40(%[src_6]), %%zmm29\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_9]), %%zmm26\n" + "vbroadcastss 40(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 44(%[src_6]), %%zmm29\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_9]), %%zmm26\n" + "vbroadcastss 44(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 48(%[src_6]), %%zmm29\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_9]), %%zmm26\n" + "vbroadcastss 48(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 52(%[src_6]), %%zmm29\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_9]), %%zmm26\n" + "vbroadcastss 52(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 56(%[src_6]), %%zmm29\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_9]), %%zmm26\n" + "vbroadcastss 56(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 60(%[src_6]), %%zmm29\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_9]), %%zmm26\n" + "vbroadcastss 60(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "add $64, %[src_9]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 0(%[src_6]), %%zmm29\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_9]), %%zmm26\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "add $4, %[src_9]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13 %{{%%k1}}\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15 %{{%%k1}}\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17 %{{%%k1}}\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19 %{{%%k1}}\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20\n" + "vmaxps %%zmm21, %%zmm31, %%zmm21 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13 %{{%%k1}}\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15 %{{%%k1}}\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17 %{{%%k1}}\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19 %{{%%k1}}\n" + "vminps %%zmm20, %%zmm30, %%zmm20\n" + "vminps %%zmm21, %%zmm30, %%zmm21 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_6])\n" + "vmovups %%zmm13, 64(%[dst_6]) %{{%%k1}}\n" + "vmovups %%zmm14, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm15, 64(%[dst_6], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm16, 0(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm17, 64(%[dst_6], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm18, 0(%[dst_9])\n" + "vmovups %%zmm19, 64(%[dst_9]) %{{%%k1}}\n" + "vmovups %%zmm20, 0(%[dst_9], %[dst_stride], 1)\n" + "vmovups %%zmm21, 64(%[dst_9], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9), [ src_3 ] "r"(src_3), + [ src_6 ] "r"(src_6), [ src_9 ] "r"(src_9) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_12x16_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_12x16_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..5dae233986ec69df82ad49bfeed9e9738b36a937 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_12x16_mask_kernel_nhwc_fp32.c @@ -0,0 +1,617 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_12x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag, + const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + const float *dst_9 = dst + 9 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_6]), %%zmm6\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_9]), %%zmm9\n" + "vmovups 0(%[dst_9], %[dst_stride], 1), %%zmm10\n" + "vmovups 0(%[dst_9], %[dst_stride], 2), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 0(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 0(%[bias]), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + const float *src_9 = src + 9 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 0(%[src_9]), %%zmm21\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 0(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_6]), %%zmm24\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 4(%[src_9]), %%zmm21\n" + "vbroadcastss 4(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 4(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_6]), %%zmm24\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 8(%[src_9]), %%zmm21\n" + "vbroadcastss 8(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 8(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_6]), %%zmm24\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 12(%[src_9]), %%zmm21\n" + "vbroadcastss 12(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 12(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_6]), %%zmm24\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 16(%[src_9]), %%zmm21\n" + "vbroadcastss 16(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 16(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_6]), %%zmm24\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 20(%[src_9]), %%zmm21\n" + "vbroadcastss 20(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 20(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_6]), %%zmm24\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 24(%[src_9]), %%zmm21\n" + "vbroadcastss 24(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 24(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_6]), %%zmm24\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 28(%[src_9]), %%zmm21\n" + "vbroadcastss 28(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 28(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_6]), %%zmm24\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 32(%[src_9]), %%zmm21\n" + "vbroadcastss 32(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 32(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_6]), %%zmm24\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 36(%[src_9]), %%zmm21\n" + "vbroadcastss 36(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 36(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_6]), %%zmm24\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 40(%[src_9]), %%zmm21\n" + "vbroadcastss 40(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 40(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_6]), %%zmm24\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 44(%[src_9]), %%zmm21\n" + "vbroadcastss 44(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 44(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_6]), %%zmm24\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 48(%[src_9]), %%zmm21\n" + "vbroadcastss 48(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 48(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_6]), %%zmm24\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 52(%[src_9]), %%zmm21\n" + "vbroadcastss 52(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 52(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_6]), %%zmm24\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 56(%[src_9]), %%zmm21\n" + "vbroadcastss 56(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 56(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_6]), %%zmm24\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 60(%[src_9]), %%zmm21\n" + "vbroadcastss 60(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 60(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "add $64, %[src_9]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vbroadcastss 0(%[src_9]), %%zmm21\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm20\n" + "vbroadcastss 0(%[src_9], %[src_stride], 2), %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm20, %%zmm10 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm19, %%zmm11 %{{%%k1}}\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "add $4, %[src_9]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0 %{{%%k1}}\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4 %{{%%k1}}\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6 %{{%%k1}}\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8 %{{%%k1}}\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10 %{{%%k1}}\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4 %{{%%k1}}\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6 %{{%%k1}}\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8 %{{%%k1}}\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + "vminps %%zmm10, %%zmm30, %%zmm10 %{{%%k1}}\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_6]) %{{%%k1}}\n" + "vmovups %%zmm7, 0(%[dst_6], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_6], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm9, 0(%[dst_9]) %{{%%k1}}\n" + "vmovups %%zmm10, 0(%[dst_9], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm11, 0(%[dst_9], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9), [ src_3 ] "r"(src_3), + [ src_6 ] "r"(src_6), [ src_9 ] "r"(src_9) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_12x32_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_12x32_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..18717026d0880062552215c2039e51c58ba6826d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_12x32_mask_kernel_nhwc_fp32.c @@ -0,0 +1,911 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_12x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag, + const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + const float *dst_9 = dst + 9 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_6]), %%zmm12\n" + "vmovups 64(%[dst_6]), %%zmm13\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm14\n" + "vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm15\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm16\n" + "vmovups 64(%[dst_6], %[dst_stride], 2), %%zmm17\n" + "vmovups 0(%[dst_9]), %%zmm18\n" + "vmovups 64(%[dst_9]), %%zmm19\n" + "vmovups 0(%[dst_9], %[dst_stride], 1), %%zmm20\n" + "vmovups 64(%[dst_9], %[dst_stride], 1), %%zmm21\n" + "vmovups 0(%[dst_9], %[dst_stride], 2), %%zmm22\n" + "vmovups 64(%[dst_9], %[dst_stride], 2), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 0(%[bias]), %%zmm14\n" + "vmovups 64(%[bias]), %%zmm15\n" + "vmovups 0(%[bias]), %%zmm16\n" + "vmovups 64(%[bias]), %%zmm17\n" + "vmovups 0(%[bias]), %%zmm18\n" + "vmovups 64(%[bias]), %%zmm19\n" + "vmovups 0(%[bias]), %%zmm20\n" + "vmovups 64(%[bias]), %%zmm21\n" + "vmovups 0(%[bias]), %%zmm22\n" + "vmovups 64(%[bias]), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + "vxorps %%zmm21, %%zmm21, %%zmm21\n" + "vxorps %%zmm22, %%zmm22, %%zmm22\n" + "vxorps %%zmm23, %%zmm23, %%zmm23\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + const float *src_9 = src + 9 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 0(%[src_6]), %%zmm29\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_9]), %%zmm26\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 4(%[src_6]), %%zmm29\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_9]), %%zmm26\n" + "vbroadcastss 4(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 8(%[src_6]), %%zmm29\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_9]), %%zmm26\n" + "vbroadcastss 8(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 12(%[src_6]), %%zmm29\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_9]), %%zmm26\n" + "vbroadcastss 12(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 16(%[src_6]), %%zmm29\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_9]), %%zmm26\n" + "vbroadcastss 16(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 20(%[src_6]), %%zmm29\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_9]), %%zmm26\n" + "vbroadcastss 20(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 24(%[src_6]), %%zmm29\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_9]), %%zmm26\n" + "vbroadcastss 24(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 28(%[src_6]), %%zmm29\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_9]), %%zmm26\n" + "vbroadcastss 28(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 32(%[src_6]), %%zmm29\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_9]), %%zmm26\n" + "vbroadcastss 32(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 36(%[src_6]), %%zmm29\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_9]), %%zmm26\n" + "vbroadcastss 36(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 40(%[src_6]), %%zmm29\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_9]), %%zmm26\n" + "vbroadcastss 40(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 44(%[src_6]), %%zmm29\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_9]), %%zmm26\n" + "vbroadcastss 44(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 48(%[src_6]), %%zmm29\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_9]), %%zmm26\n" + "vbroadcastss 48(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 52(%[src_6]), %%zmm29\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_9]), %%zmm26\n" + "vbroadcastss 52(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 56(%[src_6]), %%zmm29\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_9]), %%zmm26\n" + "vbroadcastss 56(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 60(%[src_6]), %%zmm29\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_9]), %%zmm26\n" + "vbroadcastss 60(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "add $64, %[src_9]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 0(%[src_6]), %%zmm29\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_9]), %%zmm26\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_9], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm21 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm23 %{{%%k1}}\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "add $4, %[src_9]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13 %{{%%k1}}\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15 %{{%%k1}}\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17 %{{%%k1}}\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19 %{{%%k1}}\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20\n" + "vmaxps %%zmm21, %%zmm31, %%zmm21 %{{%%k1}}\n" + "vmaxps %%zmm22, %%zmm31, %%zmm22\n" + "vmaxps %%zmm23, %%zmm31, %%zmm23 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13 %{{%%k1}}\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15 %{{%%k1}}\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17 %{{%%k1}}\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19 %{{%%k1}}\n" + "vminps %%zmm20, %%zmm30, %%zmm20\n" + "vminps %%zmm21, %%zmm30, %%zmm21 %{{%%k1}}\n" + "vminps %%zmm22, %%zmm30, %%zmm22\n" + "vminps %%zmm23, %%zmm30, %%zmm23 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_6])\n" + "vmovups %%zmm13, 64(%[dst_6]) %{{%%k1}}\n" + "vmovups %%zmm14, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm15, 64(%[dst_6], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm16, 0(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm17, 64(%[dst_6], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm18, 0(%[dst_9])\n" + "vmovups %%zmm19, 64(%[dst_9]) %{{%%k1}}\n" + "vmovups %%zmm20, 0(%[dst_9], %[dst_stride], 1)\n" + "vmovups %%zmm21, 64(%[dst_9], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm22, 0(%[dst_9], %[dst_stride], 2)\n" + "vmovups %%zmm23, 64(%[dst_9], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9), [ src_3 ] "r"(src_3), + [ src_6 ] "r"(src_6), [ src_9 ] "r"(src_9) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x16_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x16_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..80a0d095301c5ee33729ace27cd49713b0bcd82b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x16_mask_kernel_nhwc_fp32.c @@ -0,0 +1,161 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_1x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x32_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x32_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..90049db3463d5a1e874120abc36ff91ba7bd472f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x32_mask_kernel_nhwc_fp32.c @@ -0,0 +1,201 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_1x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x48_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x48_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..f07c59a6f03c76c284b29ae0c9b84b43ba5d74c8 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x48_mask_kernel_nhwc_fp32.c @@ -0,0 +1,241 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_1x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x64_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x64_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..de6113370a3a22a17d4e30233dc10c1888044159 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x64_mask_kernel_nhwc_fp32.c @@ -0,0 +1,281 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_1x64_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + "vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 8 + "vmovups 2048(%[weight]), %%zmm31\n" + "vmovups 2112(%[weight]), %%zmm30\n" + "vmovups 2176(%[weight]), %%zmm29\n" + "vmovups 2240(%[weight]), %%zmm28\n" + "vbroadcastss 32(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 9 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vbroadcastss 36(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 10 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vbroadcastss 40(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 11 + "vmovups 2816(%[weight]), %%zmm31\n" + "vmovups 2880(%[weight]), %%zmm30\n" + "vmovups 2944(%[weight]), %%zmm29\n" + "vmovups 3008(%[weight]), %%zmm28\n" + "vbroadcastss 44(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 12 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vbroadcastss 48(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 13 + "vmovups 3328(%[weight]), %%zmm31\n" + "vmovups 3392(%[weight]), %%zmm30\n" + "vmovups 3456(%[weight]), %%zmm29\n" + "vmovups 3520(%[weight]), %%zmm28\n" + "vbroadcastss 52(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 14 + "vmovups 3584(%[weight]), %%zmm31\n" + "vmovups 3648(%[weight]), %%zmm30\n" + "vmovups 3712(%[weight]), %%zmm29\n" + "vmovups 3776(%[weight]), %%zmm28\n" + "vbroadcastss 56(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 15 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vbroadcastss 60(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "add $4096, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "add $256, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x80_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x80_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..fbe109434c5c6ca7afbe3108357690ef055c22e1 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x80_mask_kernel_nhwc_fp32.c @@ -0,0 +1,321 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_1x80_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 1 + "vmovups 320(%[weight]), %%zmm31\n" + "vmovups 384(%[weight]), %%zmm30\n" + "vmovups 448(%[weight]), %%zmm29\n" + "vmovups 512(%[weight]), %%zmm28\n" + "vmovups 576(%[weight]), %%zmm27\n" + "vbroadcastss 4(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 2 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vmovups 768(%[weight]), %%zmm29\n" + "vmovups 832(%[weight]), %%zmm28\n" + "vmovups 896(%[weight]), %%zmm27\n" + "vbroadcastss 8(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 3 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vmovups 1152(%[weight]), %%zmm28\n" + "vmovups 1216(%[weight]), %%zmm27\n" + "vbroadcastss 12(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 4 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vmovups 1536(%[weight]), %%zmm27\n" + "vbroadcastss 16(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 5 + "vmovups 1600(%[weight]), %%zmm31\n" + "vmovups 1664(%[weight]), %%zmm30\n" + "vmovups 1728(%[weight]), %%zmm29\n" + "vmovups 1792(%[weight]), %%zmm28\n" + "vmovups 1856(%[weight]), %%zmm27\n" + "vbroadcastss 20(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 6 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vbroadcastss 24(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 7 + "vmovups 2240(%[weight]), %%zmm31\n" + "vmovups 2304(%[weight]), %%zmm30\n" + "vmovups 2368(%[weight]), %%zmm29\n" + "vmovups 2432(%[weight]), %%zmm28\n" + "vmovups 2496(%[weight]), %%zmm27\n" + "vbroadcastss 28(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 8 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vmovups 2816(%[weight]), %%zmm27\n" + "vbroadcastss 32(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 9 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vmovups 3072(%[weight]), %%zmm28\n" + "vmovups 3136(%[weight]), %%zmm27\n" + "vbroadcastss 36(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 10 + "vmovups 3200(%[weight]), %%zmm31\n" + "vmovups 3264(%[weight]), %%zmm30\n" + "vmovups 3328(%[weight]), %%zmm29\n" + "vmovups 3392(%[weight]), %%zmm28\n" + "vmovups 3456(%[weight]), %%zmm27\n" + "vbroadcastss 40(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 11 + "vmovups 3520(%[weight]), %%zmm31\n" + "vmovups 3584(%[weight]), %%zmm30\n" + "vmovups 3648(%[weight]), %%zmm29\n" + "vmovups 3712(%[weight]), %%zmm28\n" + "vmovups 3776(%[weight]), %%zmm27\n" + "vbroadcastss 44(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 12 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vbroadcastss 48(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 13 + "vmovups 4160(%[weight]), %%zmm31\n" + "vmovups 4224(%[weight]), %%zmm30\n" + "vmovups 4288(%[weight]), %%zmm29\n" + "vmovups 4352(%[weight]), %%zmm28\n" + "vmovups 4416(%[weight]), %%zmm27\n" + "vbroadcastss 52(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 14 + "vmovups 4480(%[weight]), %%zmm31\n" + "vmovups 4544(%[weight]), %%zmm30\n" + "vmovups 4608(%[weight]), %%zmm29\n" + "vmovups 4672(%[weight]), %%zmm28\n" + "vmovups 4736(%[weight]), %%zmm27\n" + "vbroadcastss 56(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 15 + "vmovups 4800(%[weight]), %%zmm31\n" + "vmovups 4864(%[weight]), %%zmm30\n" + "vmovups 4928(%[weight]), %%zmm29\n" + "vmovups 4992(%[weight]), %%zmm28\n" + "vmovups 5056(%[weight]), %%zmm27\n" + "vbroadcastss 60(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "add $5120, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "add $320, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x96_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x96_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..f3784928e001fe246729366e19a8372b9a368f75 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_1x96_mask_kernel_nhwc_fp32.c @@ -0,0 +1,361 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_1x96_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 320(%[dst_0]), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 320(%[bias]), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 1 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vmovups 576(%[weight]), %%zmm28\n" + "vmovups 640(%[weight]), %%zmm27\n" + "vmovups 704(%[weight]), %%zmm26\n" + "vbroadcastss 4(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 2 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vmovups 1024(%[weight]), %%zmm27\n" + "vmovups 1088(%[weight]), %%zmm26\n" + "vbroadcastss 8(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 3 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vmovups 1344(%[weight]), %%zmm28\n" + "vmovups 1408(%[weight]), %%zmm27\n" + "vmovups 1472(%[weight]), %%zmm26\n" + "vbroadcastss 12(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 4 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vmovups 1792(%[weight]), %%zmm27\n" + "vmovups 1856(%[weight]), %%zmm26\n" + "vbroadcastss 16(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 5 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vmovups 2240(%[weight]), %%zmm26\n" + "vbroadcastss 20(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 6 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vmovups 2560(%[weight]), %%zmm27\n" + "vmovups 2624(%[weight]), %%zmm26\n" + "vbroadcastss 24(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 7 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vmovups 2880(%[weight]), %%zmm28\n" + "vmovups 2944(%[weight]), %%zmm27\n" + "vmovups 3008(%[weight]), %%zmm26\n" + "vbroadcastss 28(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 8 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vmovups 3328(%[weight]), %%zmm27\n" + "vmovups 3392(%[weight]), %%zmm26\n" + "vbroadcastss 32(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 9 + "vmovups 3456(%[weight]), %%zmm31\n" + "vmovups 3520(%[weight]), %%zmm30\n" + "vmovups 3584(%[weight]), %%zmm29\n" + "vmovups 3648(%[weight]), %%zmm28\n" + "vmovups 3712(%[weight]), %%zmm27\n" + "vmovups 3776(%[weight]), %%zmm26\n" + "vbroadcastss 36(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 10 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vmovups 4160(%[weight]), %%zmm26\n" + "vbroadcastss 40(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 11 + "vmovups 4224(%[weight]), %%zmm31\n" + "vmovups 4288(%[weight]), %%zmm30\n" + "vmovups 4352(%[weight]), %%zmm29\n" + "vmovups 4416(%[weight]), %%zmm28\n" + "vmovups 4480(%[weight]), %%zmm27\n" + "vmovups 4544(%[weight]), %%zmm26\n" + "vbroadcastss 44(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 12 + "vmovups 4608(%[weight]), %%zmm31\n" + "vmovups 4672(%[weight]), %%zmm30\n" + "vmovups 4736(%[weight]), %%zmm29\n" + "vmovups 4800(%[weight]), %%zmm28\n" + "vmovups 4864(%[weight]), %%zmm27\n" + "vmovups 4928(%[weight]), %%zmm26\n" + "vbroadcastss 48(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 13 + "vmovups 4992(%[weight]), %%zmm31\n" + "vmovups 5056(%[weight]), %%zmm30\n" + "vmovups 5120(%[weight]), %%zmm29\n" + "vmovups 5184(%[weight]), %%zmm28\n" + "vmovups 5248(%[weight]), %%zmm27\n" + "vmovups 5312(%[weight]), %%zmm26\n" + "vbroadcastss 52(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 14 + "vmovups 5376(%[weight]), %%zmm31\n" + "vmovups 5440(%[weight]), %%zmm30\n" + "vmovups 5504(%[weight]), %%zmm29\n" + "vmovups 5568(%[weight]), %%zmm28\n" + "vmovups 5632(%[weight]), %%zmm27\n" + "vmovups 5696(%[weight]), %%zmm26\n" + "vbroadcastss 56(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 15 + "vmovups 5760(%[weight]), %%zmm31\n" + "vmovups 5824(%[weight]), %%zmm30\n" + "vmovups 5888(%[weight]), %%zmm29\n" + "vmovups 5952(%[weight]), %%zmm28\n" + "vmovups 6016(%[weight]), %%zmm27\n" + "vmovups 6080(%[weight]), %%zmm26\n" + "vbroadcastss 60(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "add $6144, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "add $384, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 320(%[dst_0]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x16_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x16_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..5203ba4658f3d1c8db531b88f55d47e76eba586f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x16_mask_kernel_nhwc_fp32.c @@ -0,0 +1,201 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_2x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0 %{{%%k1}}\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x32_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x32_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..c5b57a151273c1b257371eda984b1c30fe6abec0 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x32_mask_kernel_nhwc_fp32.c @@ -0,0 +1,264 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_2x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x48_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x48_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..f6402381c19b172ab342c0cf38b02b54b9f94316 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x48_mask_kernel_nhwc_fp32.c @@ -0,0 +1,327 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_2x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 64(%[bias]), %%zmm4\n" + "vmovups 128(%[bias]), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x64_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x64_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..eca8ff53f310fe081e225deda8f8f209fd9a691e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x64_mask_kernel_nhwc_fp32.c @@ -0,0 +1,390 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_2x64_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 128(%[bias]), %%zmm6\n" + "vmovups 192(%[bias]), %%zmm7\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + "vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 8 + "vmovups 2048(%[weight]), %%zmm31\n" + "vmovups 2112(%[weight]), %%zmm30\n" + "vmovups 2176(%[weight]), %%zmm29\n" + "vmovups 2240(%[weight]), %%zmm28\n" + "vbroadcastss 32(%[src_0]), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 9 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vbroadcastss 36(%[src_0]), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 10 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vbroadcastss 40(%[src_0]), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 11 + "vmovups 2816(%[weight]), %%zmm31\n" + "vmovups 2880(%[weight]), %%zmm30\n" + "vmovups 2944(%[weight]), %%zmm29\n" + "vmovups 3008(%[weight]), %%zmm28\n" + "vbroadcastss 44(%[src_0]), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 12 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vbroadcastss 48(%[src_0]), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 13 + "vmovups 3328(%[weight]), %%zmm31\n" + "vmovups 3392(%[weight]), %%zmm30\n" + "vmovups 3456(%[weight]), %%zmm29\n" + "vmovups 3520(%[weight]), %%zmm28\n" + "vbroadcastss 52(%[src_0]), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 14 + "vmovups 3584(%[weight]), %%zmm31\n" + "vmovups 3648(%[weight]), %%zmm30\n" + "vmovups 3712(%[weight]), %%zmm29\n" + "vmovups 3776(%[weight]), %%zmm28\n" + "vbroadcastss 56(%[src_0]), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 15 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vbroadcastss 60(%[src_0]), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "add $4096, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "add $256, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x80_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x80_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..f7e2d532af1755b028da76a0d564015e7e6107bf --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x80_mask_kernel_nhwc_fp32.c @@ -0,0 +1,453 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_2x80_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 64(%[bias]), %%zmm6\n" + "vmovups 128(%[bias]), %%zmm7\n" + "vmovups 192(%[bias]), %%zmm8\n" + "vmovups 256(%[bias]), %%zmm9\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 1 + "vmovups 320(%[weight]), %%zmm31\n" + "vmovups 384(%[weight]), %%zmm30\n" + "vmovups 448(%[weight]), %%zmm29\n" + "vmovups 512(%[weight]), %%zmm28\n" + "vmovups 576(%[weight]), %%zmm27\n" + "vbroadcastss 4(%[src_0]), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 2 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vmovups 768(%[weight]), %%zmm29\n" + "vmovups 832(%[weight]), %%zmm28\n" + "vmovups 896(%[weight]), %%zmm27\n" + "vbroadcastss 8(%[src_0]), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 3 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vmovups 1152(%[weight]), %%zmm28\n" + "vmovups 1216(%[weight]), %%zmm27\n" + "vbroadcastss 12(%[src_0]), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 4 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vmovups 1536(%[weight]), %%zmm27\n" + "vbroadcastss 16(%[src_0]), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 5 + "vmovups 1600(%[weight]), %%zmm31\n" + "vmovups 1664(%[weight]), %%zmm30\n" + "vmovups 1728(%[weight]), %%zmm29\n" + "vmovups 1792(%[weight]), %%zmm28\n" + "vmovups 1856(%[weight]), %%zmm27\n" + "vbroadcastss 20(%[src_0]), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 6 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vbroadcastss 24(%[src_0]), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 7 + "vmovups 2240(%[weight]), %%zmm31\n" + "vmovups 2304(%[weight]), %%zmm30\n" + "vmovups 2368(%[weight]), %%zmm29\n" + "vmovups 2432(%[weight]), %%zmm28\n" + "vmovups 2496(%[weight]), %%zmm27\n" + "vbroadcastss 28(%[src_0]), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 8 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vmovups 2816(%[weight]), %%zmm27\n" + "vbroadcastss 32(%[src_0]), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 9 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vmovups 3072(%[weight]), %%zmm28\n" + "vmovups 3136(%[weight]), %%zmm27\n" + "vbroadcastss 36(%[src_0]), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 10 + "vmovups 3200(%[weight]), %%zmm31\n" + "vmovups 3264(%[weight]), %%zmm30\n" + "vmovups 3328(%[weight]), %%zmm29\n" + "vmovups 3392(%[weight]), %%zmm28\n" + "vmovups 3456(%[weight]), %%zmm27\n" + "vbroadcastss 40(%[src_0]), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 11 + "vmovups 3520(%[weight]), %%zmm31\n" + "vmovups 3584(%[weight]), %%zmm30\n" + "vmovups 3648(%[weight]), %%zmm29\n" + "vmovups 3712(%[weight]), %%zmm28\n" + "vmovups 3776(%[weight]), %%zmm27\n" + "vbroadcastss 44(%[src_0]), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 12 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vbroadcastss 48(%[src_0]), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 13 + "vmovups 4160(%[weight]), %%zmm31\n" + "vmovups 4224(%[weight]), %%zmm30\n" + "vmovups 4288(%[weight]), %%zmm29\n" + "vmovups 4352(%[weight]), %%zmm28\n" + "vmovups 4416(%[weight]), %%zmm27\n" + "vbroadcastss 52(%[src_0]), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 14 + "vmovups 4480(%[weight]), %%zmm31\n" + "vmovups 4544(%[weight]), %%zmm30\n" + "vmovups 4608(%[weight]), %%zmm29\n" + "vmovups 4672(%[weight]), %%zmm28\n" + "vmovups 4736(%[weight]), %%zmm27\n" + "vbroadcastss 56(%[src_0]), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 15 + "vmovups 4800(%[weight]), %%zmm31\n" + "vmovups 4864(%[weight]), %%zmm30\n" + "vmovups 4928(%[weight]), %%zmm29\n" + "vmovups 4992(%[weight]), %%zmm28\n" + "vmovups 5056(%[weight]), %%zmm27\n" + "vbroadcastss 60(%[src_0]), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "add $5120, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "add $320, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4 %{{%%k1}}\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4 %{{%%k1}}\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x96_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x96_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..49944f84bc5f180772de7cbd36fa41c90625b782 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_2x96_mask_kernel_nhwc_fp32.c @@ -0,0 +1,517 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_2x96_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 320(%[dst_0]), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm10\n" + "vmovups 320(%[dst_0], %[dst_stride], 1), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 320(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 192(%[bias]), %%zmm9\n" + "vmovups 256(%[bias]), %%zmm10\n" + "vmovups 320(%[bias]), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 1 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vmovups 576(%[weight]), %%zmm28\n" + "vmovups 640(%[weight]), %%zmm27\n" + "vmovups 704(%[weight]), %%zmm26\n" + "vbroadcastss 4(%[src_0]), %%zmm25\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 2 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vmovups 1024(%[weight]), %%zmm27\n" + "vmovups 1088(%[weight]), %%zmm26\n" + "vbroadcastss 8(%[src_0]), %%zmm25\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 3 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vmovups 1344(%[weight]), %%zmm28\n" + "vmovups 1408(%[weight]), %%zmm27\n" + "vmovups 1472(%[weight]), %%zmm26\n" + "vbroadcastss 12(%[src_0]), %%zmm25\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 4 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vmovups 1792(%[weight]), %%zmm27\n" + "vmovups 1856(%[weight]), %%zmm26\n" + "vbroadcastss 16(%[src_0]), %%zmm25\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 5 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vmovups 2240(%[weight]), %%zmm26\n" + "vbroadcastss 20(%[src_0]), %%zmm25\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 6 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vmovups 2560(%[weight]), %%zmm27\n" + "vmovups 2624(%[weight]), %%zmm26\n" + "vbroadcastss 24(%[src_0]), %%zmm25\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 7 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vmovups 2880(%[weight]), %%zmm28\n" + "vmovups 2944(%[weight]), %%zmm27\n" + "vmovups 3008(%[weight]), %%zmm26\n" + "vbroadcastss 28(%[src_0]), %%zmm25\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 8 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vmovups 3328(%[weight]), %%zmm27\n" + "vmovups 3392(%[weight]), %%zmm26\n" + "vbroadcastss 32(%[src_0]), %%zmm25\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 9 + "vmovups 3456(%[weight]), %%zmm31\n" + "vmovups 3520(%[weight]), %%zmm30\n" + "vmovups 3584(%[weight]), %%zmm29\n" + "vmovups 3648(%[weight]), %%zmm28\n" + "vmovups 3712(%[weight]), %%zmm27\n" + "vmovups 3776(%[weight]), %%zmm26\n" + "vbroadcastss 36(%[src_0]), %%zmm25\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 10 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vmovups 4160(%[weight]), %%zmm26\n" + "vbroadcastss 40(%[src_0]), %%zmm25\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 11 + "vmovups 4224(%[weight]), %%zmm31\n" + "vmovups 4288(%[weight]), %%zmm30\n" + "vmovups 4352(%[weight]), %%zmm29\n" + "vmovups 4416(%[weight]), %%zmm28\n" + "vmovups 4480(%[weight]), %%zmm27\n" + "vmovups 4544(%[weight]), %%zmm26\n" + "vbroadcastss 44(%[src_0]), %%zmm25\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 12 + "vmovups 4608(%[weight]), %%zmm31\n" + "vmovups 4672(%[weight]), %%zmm30\n" + "vmovups 4736(%[weight]), %%zmm29\n" + "vmovups 4800(%[weight]), %%zmm28\n" + "vmovups 4864(%[weight]), %%zmm27\n" + "vmovups 4928(%[weight]), %%zmm26\n" + "vbroadcastss 48(%[src_0]), %%zmm25\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 13 + "vmovups 4992(%[weight]), %%zmm31\n" + "vmovups 5056(%[weight]), %%zmm30\n" + "vmovups 5120(%[weight]), %%zmm29\n" + "vmovups 5184(%[weight]), %%zmm28\n" + "vmovups 5248(%[weight]), %%zmm27\n" + "vmovups 5312(%[weight]), %%zmm26\n" + "vbroadcastss 52(%[src_0]), %%zmm25\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 14 + "vmovups 5376(%[weight]), %%zmm31\n" + "vmovups 5440(%[weight]), %%zmm30\n" + "vmovups 5504(%[weight]), %%zmm29\n" + "vmovups 5568(%[weight]), %%zmm28\n" + "vmovups 5632(%[weight]), %%zmm27\n" + "vmovups 5696(%[weight]), %%zmm26\n" + "vbroadcastss 56(%[src_0]), %%zmm25\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 15 + "vmovups 5760(%[weight]), %%zmm31\n" + "vmovups 5824(%[weight]), %%zmm30\n" + "vmovups 5888(%[weight]), %%zmm29\n" + "vmovups 5952(%[weight]), %%zmm28\n" + "vmovups 6016(%[weight]), %%zmm27\n" + "vmovups 6080(%[weight]), %%zmm26\n" + "vbroadcastss 60(%[src_0]), %%zmm25\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "add $6144, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "add $384, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 320(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm10, 256(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm11, 320(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x16_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x16_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..f585e92b31e1d179fd1f355cad001c360fb98f1c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x16_mask_kernel_nhwc_fp32.c @@ -0,0 +1,241 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_3x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0 %{{%%k1}}\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x32_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x32_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..2bca2ffe59b9c193ddae633ecc7d6249247dddd5 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x32_mask_kernel_nhwc_fp32.c @@ -0,0 +1,327 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_3x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x48_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x48_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..50d0b57ddc5b0fec749dd7be03a3c7d91340b836 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x48_mask_kernel_nhwc_fp32.c @@ -0,0 +1,413 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_3x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 64(%[bias]), %%zmm4\n" + "vmovups 128(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x64_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x64_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..45f27ca12b16ed639473122a30141c4735439c7e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x64_mask_kernel_nhwc_fp32.c @@ -0,0 +1,500 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_3x64_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm9\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 128(%[bias]), %%zmm6\n" + "vmovups 192(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 128(%[bias]), %%zmm10\n" + "vmovups 192(%[bias]), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + "vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 8 + "vmovups 2048(%[weight]), %%zmm31\n" + "vmovups 2112(%[weight]), %%zmm30\n" + "vmovups 2176(%[weight]), %%zmm29\n" + "vmovups 2240(%[weight]), %%zmm28\n" + "vbroadcastss 32(%[src_0]), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 9 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vbroadcastss 36(%[src_0]), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 10 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vbroadcastss 40(%[src_0]), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 11 + "vmovups 2816(%[weight]), %%zmm31\n" + "vmovups 2880(%[weight]), %%zmm30\n" + "vmovups 2944(%[weight]), %%zmm29\n" + "vmovups 3008(%[weight]), %%zmm28\n" + "vbroadcastss 44(%[src_0]), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 12 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vbroadcastss 48(%[src_0]), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 13 + "vmovups 3328(%[weight]), %%zmm31\n" + "vmovups 3392(%[weight]), %%zmm30\n" + "vmovups 3456(%[weight]), %%zmm29\n" + "vmovups 3520(%[weight]), %%zmm28\n" + "vbroadcastss 52(%[src_0]), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 14 + "vmovups 3584(%[weight]), %%zmm31\n" + "vmovups 3648(%[weight]), %%zmm30\n" + "vmovups 3712(%[weight]), %%zmm29\n" + "vmovups 3776(%[weight]), %%zmm28\n" + "vbroadcastss 56(%[src_0]), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 15 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vbroadcastss 60(%[src_0]), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "add $4096, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "add $256, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm10, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 192(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x80_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x80_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..d00f8ef929ca2123d71a212ae2e86f6eee686a52 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x80_mask_kernel_nhwc_fp32.c @@ -0,0 +1,586 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_3x80_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm12\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm13\n" + "vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm14\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 64(%[bias]), %%zmm6\n" + "vmovups 128(%[bias]), %%zmm7\n" + "vmovups 192(%[bias]), %%zmm8\n" + "vmovups 256(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 128(%[bias]), %%zmm12\n" + "vmovups 192(%[bias]), %%zmm13\n" + "vmovups 256(%[bias]), %%zmm14\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 1 + "vmovups 320(%[weight]), %%zmm31\n" + "vmovups 384(%[weight]), %%zmm30\n" + "vmovups 448(%[weight]), %%zmm29\n" + "vmovups 512(%[weight]), %%zmm28\n" + "vmovups 576(%[weight]), %%zmm27\n" + "vbroadcastss 4(%[src_0]), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 2 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vmovups 768(%[weight]), %%zmm29\n" + "vmovups 832(%[weight]), %%zmm28\n" + "vmovups 896(%[weight]), %%zmm27\n" + "vbroadcastss 8(%[src_0]), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 3 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vmovups 1152(%[weight]), %%zmm28\n" + "vmovups 1216(%[weight]), %%zmm27\n" + "vbroadcastss 12(%[src_0]), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 4 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vmovups 1536(%[weight]), %%zmm27\n" + "vbroadcastss 16(%[src_0]), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 5 + "vmovups 1600(%[weight]), %%zmm31\n" + "vmovups 1664(%[weight]), %%zmm30\n" + "vmovups 1728(%[weight]), %%zmm29\n" + "vmovups 1792(%[weight]), %%zmm28\n" + "vmovups 1856(%[weight]), %%zmm27\n" + "vbroadcastss 20(%[src_0]), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 6 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vbroadcastss 24(%[src_0]), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 7 + "vmovups 2240(%[weight]), %%zmm31\n" + "vmovups 2304(%[weight]), %%zmm30\n" + "vmovups 2368(%[weight]), %%zmm29\n" + "vmovups 2432(%[weight]), %%zmm28\n" + "vmovups 2496(%[weight]), %%zmm27\n" + "vbroadcastss 28(%[src_0]), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 8 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vmovups 2816(%[weight]), %%zmm27\n" + "vbroadcastss 32(%[src_0]), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 9 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vmovups 3072(%[weight]), %%zmm28\n" + "vmovups 3136(%[weight]), %%zmm27\n" + "vbroadcastss 36(%[src_0]), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 10 + "vmovups 3200(%[weight]), %%zmm31\n" + "vmovups 3264(%[weight]), %%zmm30\n" + "vmovups 3328(%[weight]), %%zmm29\n" + "vmovups 3392(%[weight]), %%zmm28\n" + "vmovups 3456(%[weight]), %%zmm27\n" + "vbroadcastss 40(%[src_0]), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 11 + "vmovups 3520(%[weight]), %%zmm31\n" + "vmovups 3584(%[weight]), %%zmm30\n" + "vmovups 3648(%[weight]), %%zmm29\n" + "vmovups 3712(%[weight]), %%zmm28\n" + "vmovups 3776(%[weight]), %%zmm27\n" + "vbroadcastss 44(%[src_0]), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 12 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vbroadcastss 48(%[src_0]), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 13 + "vmovups 4160(%[weight]), %%zmm31\n" + "vmovups 4224(%[weight]), %%zmm30\n" + "vmovups 4288(%[weight]), %%zmm29\n" + "vmovups 4352(%[weight]), %%zmm28\n" + "vmovups 4416(%[weight]), %%zmm27\n" + "vbroadcastss 52(%[src_0]), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 14 + "vmovups 4480(%[weight]), %%zmm31\n" + "vmovups 4544(%[weight]), %%zmm30\n" + "vmovups 4608(%[weight]), %%zmm29\n" + "vmovups 4672(%[weight]), %%zmm28\n" + "vmovups 4736(%[weight]), %%zmm27\n" + "vbroadcastss 56(%[src_0]), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 15 + "vmovups 4800(%[weight]), %%zmm31\n" + "vmovups 4864(%[weight]), %%zmm30\n" + "vmovups 4928(%[weight]), %%zmm29\n" + "vmovups 4992(%[weight]), %%zmm28\n" + "vmovups 5056(%[weight]), %%zmm27\n" + "vbroadcastss 60(%[src_0]), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "add $5120, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "add $320, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4 %{{%%k1}}\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4 %{{%%k1}}\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm10, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm12, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm13, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm14, 256(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x96_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x96_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..f97f061dfd6e05473e67da067d2d598d64ae2b75 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_3x96_mask_kernel_nhwc_fp32.c @@ -0,0 +1,672 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_3x96_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 320(%[dst_0]), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm10\n" + "vmovups 320(%[dst_0], %[dst_stride], 1), %%zmm11\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm12\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm13\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm14\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm15\n" + "vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm16\n" + "vmovups 320(%[dst_0], %[dst_stride], 2), %%zmm17\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 320(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 192(%[bias]), %%zmm9\n" + "vmovups 256(%[bias]), %%zmm10\n" + "vmovups 320(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 192(%[bias]), %%zmm15\n" + "vmovups 256(%[bias]), %%zmm16\n" + "vmovups 320(%[bias]), %%zmm17\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 1 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vmovups 576(%[weight]), %%zmm28\n" + "vmovups 640(%[weight]), %%zmm27\n" + "vmovups 704(%[weight]), %%zmm26\n" + "vbroadcastss 4(%[src_0]), %%zmm25\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 2 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vmovups 1024(%[weight]), %%zmm27\n" + "vmovups 1088(%[weight]), %%zmm26\n" + "vbroadcastss 8(%[src_0]), %%zmm25\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 3 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vmovups 1344(%[weight]), %%zmm28\n" + "vmovups 1408(%[weight]), %%zmm27\n" + "vmovups 1472(%[weight]), %%zmm26\n" + "vbroadcastss 12(%[src_0]), %%zmm25\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 4 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vmovups 1792(%[weight]), %%zmm27\n" + "vmovups 1856(%[weight]), %%zmm26\n" + "vbroadcastss 16(%[src_0]), %%zmm25\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 5 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vmovups 2240(%[weight]), %%zmm26\n" + "vbroadcastss 20(%[src_0]), %%zmm25\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 6 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vmovups 2560(%[weight]), %%zmm27\n" + "vmovups 2624(%[weight]), %%zmm26\n" + "vbroadcastss 24(%[src_0]), %%zmm25\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 7 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vmovups 2880(%[weight]), %%zmm28\n" + "vmovups 2944(%[weight]), %%zmm27\n" + "vmovups 3008(%[weight]), %%zmm26\n" + "vbroadcastss 28(%[src_0]), %%zmm25\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 8 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vmovups 3328(%[weight]), %%zmm27\n" + "vmovups 3392(%[weight]), %%zmm26\n" + "vbroadcastss 32(%[src_0]), %%zmm25\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 9 + "vmovups 3456(%[weight]), %%zmm31\n" + "vmovups 3520(%[weight]), %%zmm30\n" + "vmovups 3584(%[weight]), %%zmm29\n" + "vmovups 3648(%[weight]), %%zmm28\n" + "vmovups 3712(%[weight]), %%zmm27\n" + "vmovups 3776(%[weight]), %%zmm26\n" + "vbroadcastss 36(%[src_0]), %%zmm25\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 10 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vmovups 4160(%[weight]), %%zmm26\n" + "vbroadcastss 40(%[src_0]), %%zmm25\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 11 + "vmovups 4224(%[weight]), %%zmm31\n" + "vmovups 4288(%[weight]), %%zmm30\n" + "vmovups 4352(%[weight]), %%zmm29\n" + "vmovups 4416(%[weight]), %%zmm28\n" + "vmovups 4480(%[weight]), %%zmm27\n" + "vmovups 4544(%[weight]), %%zmm26\n" + "vbroadcastss 44(%[src_0]), %%zmm25\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 12 + "vmovups 4608(%[weight]), %%zmm31\n" + "vmovups 4672(%[weight]), %%zmm30\n" + "vmovups 4736(%[weight]), %%zmm29\n" + "vmovups 4800(%[weight]), %%zmm28\n" + "vmovups 4864(%[weight]), %%zmm27\n" + "vmovups 4928(%[weight]), %%zmm26\n" + "vbroadcastss 48(%[src_0]), %%zmm25\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 13 + "vmovups 4992(%[weight]), %%zmm31\n" + "vmovups 5056(%[weight]), %%zmm30\n" + "vmovups 5120(%[weight]), %%zmm29\n" + "vmovups 5184(%[weight]), %%zmm28\n" + "vmovups 5248(%[weight]), %%zmm27\n" + "vmovups 5312(%[weight]), %%zmm26\n" + "vbroadcastss 52(%[src_0]), %%zmm25\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 14 + "vmovups 5376(%[weight]), %%zmm31\n" + "vmovups 5440(%[weight]), %%zmm30\n" + "vmovups 5504(%[weight]), %%zmm29\n" + "vmovups 5568(%[weight]), %%zmm28\n" + "vmovups 5632(%[weight]), %%zmm27\n" + "vmovups 5696(%[weight]), %%zmm26\n" + "vbroadcastss 56(%[src_0]), %%zmm25\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 15 + "vmovups 5760(%[weight]), %%zmm31\n" + "vmovups 5824(%[weight]), %%zmm30\n" + "vmovups 5888(%[weight]), %%zmm29\n" + "vmovups 5952(%[weight]), %%zmm28\n" + "vmovups 6016(%[weight]), %%zmm27\n" + "vmovups 6080(%[weight]), %%zmm26\n" + "vbroadcastss 60(%[src_0]), %%zmm25\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + "add $6144, %[weight]\n" + "add $64, %[src_0]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17 %{{%%k1}}\n" + "add $384, %[weight]\n" + "add $4, %[src_0]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 320(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm10, 256(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm11, 320(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm13, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm14, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm15, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm16, 256(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm17, 320(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x16_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x16_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..3eb27840e9585ec4acb700975734c97582b08868 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x16_mask_kernel_nhwc_fp32.c @@ -0,0 +1,286 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_4x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0 %{{%%k1}}\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_3]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x32_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x32_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..47c717dffce32a84437687055ac416610e898702 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x32_mask_kernel_nhwc_fp32.c @@ -0,0 +1,395 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_4x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x48_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x48_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..d46b1fef4f3f1b19563f4f6531dd3ed407c6c4dd --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x48_mask_kernel_nhwc_fp32.c @@ -0,0 +1,505 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_4x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_3]), %%zmm9\n" + "vmovups 64(%[dst_3]), %%zmm10\n" + "vmovups 128(%[dst_3]), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 64(%[bias]), %%zmm4\n" + "vmovups 128(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "vmovups 64(%[bias]), %%zmm10\n" + "vmovups 128(%[bias]), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 4(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 8(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 12(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 16(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 20(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 24(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 28(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 32(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 36(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 40(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 44(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 48(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 52(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 56(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 60(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8 %{{%%k1}}\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8 %{{%%k1}}\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm9, 0(%[dst_3])\n" + "vmovups %%zmm10, 64(%[dst_3])\n" + "vmovups %%zmm11, 128(%[dst_3]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x64_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x64_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..7964557645fd4d32908f65709ec2c99e1962f241 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x64_mask_kernel_nhwc_fp32.c @@ -0,0 +1,614 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_4x64_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm9\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_3]), %%zmm12\n" + "vmovups 64(%[dst_3]), %%zmm13\n" + "vmovups 128(%[dst_3]), %%zmm14\n" + "vmovups 192(%[dst_3]), %%zmm15\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 128(%[bias]), %%zmm6\n" + "vmovups 192(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 128(%[bias]), %%zmm10\n" + "vmovups 192(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 192(%[bias]), %%zmm15\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + "vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 8 + "vmovups 2048(%[weight]), %%zmm31\n" + "vmovups 2112(%[weight]), %%zmm30\n" + "vmovups 2176(%[weight]), %%zmm29\n" + "vmovups 2240(%[weight]), %%zmm28\n" + "vbroadcastss 32(%[src_0]), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 9 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vbroadcastss 36(%[src_0]), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 10 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vbroadcastss 40(%[src_0]), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 11 + "vmovups 2816(%[weight]), %%zmm31\n" + "vmovups 2880(%[weight]), %%zmm30\n" + "vmovups 2944(%[weight]), %%zmm29\n" + "vmovups 3008(%[weight]), %%zmm28\n" + "vbroadcastss 44(%[src_0]), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 12 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vbroadcastss 48(%[src_0]), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 13 + "vmovups 3328(%[weight]), %%zmm31\n" + "vmovups 3392(%[weight]), %%zmm30\n" + "vmovups 3456(%[weight]), %%zmm29\n" + "vmovups 3520(%[weight]), %%zmm28\n" + "vbroadcastss 52(%[src_0]), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 14 + "vmovups 3584(%[weight]), %%zmm31\n" + "vmovups 3648(%[weight]), %%zmm30\n" + "vmovups 3712(%[weight]), %%zmm29\n" + "vmovups 3776(%[weight]), %%zmm28\n" + "vbroadcastss 56(%[src_0]), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + // block 15 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vbroadcastss 60(%[src_0]), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "add $4096, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "add $256, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm10, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 192(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_3])\n" + "vmovups %%zmm13, 64(%[dst_3])\n" + "vmovups %%zmm14, 128(%[dst_3])\n" + "vmovups %%zmm15, 192(%[dst_3]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x80_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x80_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..471b3b678f1d563d630f475c7ef8101775aad479 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x80_mask_kernel_nhwc_fp32.c @@ -0,0 +1,723 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_4x80_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm12\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm13\n" + "vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm14\n" + "vmovups 0(%[dst_3]), %%zmm15\n" + "vmovups 64(%[dst_3]), %%zmm16\n" + "vmovups 128(%[dst_3]), %%zmm17\n" + "vmovups 192(%[dst_3]), %%zmm18\n" + "vmovups 256(%[dst_3]), %%zmm19\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 64(%[bias]), %%zmm6\n" + "vmovups 128(%[bias]), %%zmm7\n" + "vmovups 192(%[bias]), %%zmm8\n" + "vmovups 256(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 128(%[bias]), %%zmm12\n" + "vmovups 192(%[bias]), %%zmm13\n" + "vmovups 256(%[bias]), %%zmm14\n" + "vmovups 0(%[bias]), %%zmm15\n" + "vmovups 64(%[bias]), %%zmm16\n" + "vmovups 128(%[bias]), %%zmm17\n" + "vmovups 192(%[bias]), %%zmm18\n" + "vmovups 256(%[bias]), %%zmm19\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 1 + "vmovups 320(%[weight]), %%zmm31\n" + "vmovups 384(%[weight]), %%zmm30\n" + "vmovups 448(%[weight]), %%zmm29\n" + "vmovups 512(%[weight]), %%zmm28\n" + "vmovups 576(%[weight]), %%zmm27\n" + "vbroadcastss 4(%[src_0]), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 4(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 2 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vmovups 768(%[weight]), %%zmm29\n" + "vmovups 832(%[weight]), %%zmm28\n" + "vmovups 896(%[weight]), %%zmm27\n" + "vbroadcastss 8(%[src_0]), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 8(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 3 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vmovups 1152(%[weight]), %%zmm28\n" + "vmovups 1216(%[weight]), %%zmm27\n" + "vbroadcastss 12(%[src_0]), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 12(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 4 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vmovups 1536(%[weight]), %%zmm27\n" + "vbroadcastss 16(%[src_0]), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 16(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 5 + "vmovups 1600(%[weight]), %%zmm31\n" + "vmovups 1664(%[weight]), %%zmm30\n" + "vmovups 1728(%[weight]), %%zmm29\n" + "vmovups 1792(%[weight]), %%zmm28\n" + "vmovups 1856(%[weight]), %%zmm27\n" + "vbroadcastss 20(%[src_0]), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 20(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 6 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vbroadcastss 24(%[src_0]), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 24(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 7 + "vmovups 2240(%[weight]), %%zmm31\n" + "vmovups 2304(%[weight]), %%zmm30\n" + "vmovups 2368(%[weight]), %%zmm29\n" + "vmovups 2432(%[weight]), %%zmm28\n" + "vmovups 2496(%[weight]), %%zmm27\n" + "vbroadcastss 28(%[src_0]), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 28(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 8 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vmovups 2816(%[weight]), %%zmm27\n" + "vbroadcastss 32(%[src_0]), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 32(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 9 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vmovups 3072(%[weight]), %%zmm28\n" + "vmovups 3136(%[weight]), %%zmm27\n" + "vbroadcastss 36(%[src_0]), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 36(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 10 + "vmovups 3200(%[weight]), %%zmm31\n" + "vmovups 3264(%[weight]), %%zmm30\n" + "vmovups 3328(%[weight]), %%zmm29\n" + "vmovups 3392(%[weight]), %%zmm28\n" + "vmovups 3456(%[weight]), %%zmm27\n" + "vbroadcastss 40(%[src_0]), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 40(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 11 + "vmovups 3520(%[weight]), %%zmm31\n" + "vmovups 3584(%[weight]), %%zmm30\n" + "vmovups 3648(%[weight]), %%zmm29\n" + "vmovups 3712(%[weight]), %%zmm28\n" + "vmovups 3776(%[weight]), %%zmm27\n" + "vbroadcastss 44(%[src_0]), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 44(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 12 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vbroadcastss 48(%[src_0]), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 48(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 13 + "vmovups 4160(%[weight]), %%zmm31\n" + "vmovups 4224(%[weight]), %%zmm30\n" + "vmovups 4288(%[weight]), %%zmm29\n" + "vmovups 4352(%[weight]), %%zmm28\n" + "vmovups 4416(%[weight]), %%zmm27\n" + "vbroadcastss 52(%[src_0]), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 52(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 14 + "vmovups 4480(%[weight]), %%zmm31\n" + "vmovups 4544(%[weight]), %%zmm30\n" + "vmovups 4608(%[weight]), %%zmm29\n" + "vmovups 4672(%[weight]), %%zmm28\n" + "vmovups 4736(%[weight]), %%zmm27\n" + "vbroadcastss 56(%[src_0]), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 56(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 15 + "vmovups 4800(%[weight]), %%zmm31\n" + "vmovups 4864(%[weight]), %%zmm30\n" + "vmovups 4928(%[weight]), %%zmm29\n" + "vmovups 4992(%[weight]), %%zmm28\n" + "vmovups 5056(%[weight]), %%zmm27\n" + "vbroadcastss 60(%[src_0]), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 60(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + "add $5120, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_3]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19 %{{%%k1}}\n" + "add $320, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4 %{{%%k1}}\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14 %{{%%k1}}\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4 %{{%%k1}}\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14 %{{%%k1}}\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm10, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm12, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm13, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm14, 256(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm15, 0(%[dst_3])\n" + "vmovups %%zmm16, 64(%[dst_3])\n" + "vmovups %%zmm17, 128(%[dst_3])\n" + "vmovups %%zmm18, 192(%[dst_3])\n" + "vmovups %%zmm19, 256(%[dst_3]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x96_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x96_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..dea26f157c817c9830be968121c6477f19520991 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_4x96_mask_kernel_nhwc_fp32.c @@ -0,0 +1,833 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_4x96_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 320(%[dst_0]), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm10\n" + "vmovups 320(%[dst_0], %[dst_stride], 1), %%zmm11\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm12\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm13\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm14\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm15\n" + "vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm16\n" + "vmovups 320(%[dst_0], %[dst_stride], 2), %%zmm17\n" + "vmovups 0(%[dst_3]), %%zmm18\n" + "vmovups 64(%[dst_3]), %%zmm19\n" + "vmovups 128(%[dst_3]), %%zmm20\n" + "vmovups 192(%[dst_3]), %%zmm21\n" + "vmovups 256(%[dst_3]), %%zmm22\n" + "vmovups 320(%[dst_3]), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 320(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 192(%[bias]), %%zmm9\n" + "vmovups 256(%[bias]), %%zmm10\n" + "vmovups 320(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 192(%[bias]), %%zmm15\n" + "vmovups 256(%[bias]), %%zmm16\n" + "vmovups 320(%[bias]), %%zmm17\n" + "vmovups 0(%[bias]), %%zmm18\n" + "vmovups 64(%[bias]), %%zmm19\n" + "vmovups 128(%[bias]), %%zmm20\n" + "vmovups 192(%[bias]), %%zmm21\n" + "vmovups 256(%[bias]), %%zmm22\n" + "vmovups 320(%[bias]), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + "vxorps %%zmm21, %%zmm21, %%zmm21\n" + "vxorps %%zmm22, %%zmm22, %%zmm22\n" + "vxorps %%zmm23, %%zmm23, %%zmm23\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 1 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vmovups 576(%[weight]), %%zmm28\n" + "vmovups 640(%[weight]), %%zmm27\n" + "vmovups 704(%[weight]), %%zmm26\n" + "vbroadcastss 4(%[src_0]), %%zmm25\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 2 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vmovups 1024(%[weight]), %%zmm27\n" + "vmovups 1088(%[weight]), %%zmm26\n" + "vbroadcastss 8(%[src_0]), %%zmm25\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 3 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vmovups 1344(%[weight]), %%zmm28\n" + "vmovups 1408(%[weight]), %%zmm27\n" + "vmovups 1472(%[weight]), %%zmm26\n" + "vbroadcastss 12(%[src_0]), %%zmm25\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 4 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vmovups 1792(%[weight]), %%zmm27\n" + "vmovups 1856(%[weight]), %%zmm26\n" + "vbroadcastss 16(%[src_0]), %%zmm25\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 5 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vmovups 2240(%[weight]), %%zmm26\n" + "vbroadcastss 20(%[src_0]), %%zmm25\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 6 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vmovups 2560(%[weight]), %%zmm27\n" + "vmovups 2624(%[weight]), %%zmm26\n" + "vbroadcastss 24(%[src_0]), %%zmm25\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 7 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vmovups 2880(%[weight]), %%zmm28\n" + "vmovups 2944(%[weight]), %%zmm27\n" + "vmovups 3008(%[weight]), %%zmm26\n" + "vbroadcastss 28(%[src_0]), %%zmm25\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 8 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vmovups 3328(%[weight]), %%zmm27\n" + "vmovups 3392(%[weight]), %%zmm26\n" + "vbroadcastss 32(%[src_0]), %%zmm25\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 9 + "vmovups 3456(%[weight]), %%zmm31\n" + "vmovups 3520(%[weight]), %%zmm30\n" + "vmovups 3584(%[weight]), %%zmm29\n" + "vmovups 3648(%[weight]), %%zmm28\n" + "vmovups 3712(%[weight]), %%zmm27\n" + "vmovups 3776(%[weight]), %%zmm26\n" + "vbroadcastss 36(%[src_0]), %%zmm25\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 10 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vmovups 4160(%[weight]), %%zmm26\n" + "vbroadcastss 40(%[src_0]), %%zmm25\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 11 + "vmovups 4224(%[weight]), %%zmm31\n" + "vmovups 4288(%[weight]), %%zmm30\n" + "vmovups 4352(%[weight]), %%zmm29\n" + "vmovups 4416(%[weight]), %%zmm28\n" + "vmovups 4480(%[weight]), %%zmm27\n" + "vmovups 4544(%[weight]), %%zmm26\n" + "vbroadcastss 44(%[src_0]), %%zmm25\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 12 + "vmovups 4608(%[weight]), %%zmm31\n" + "vmovups 4672(%[weight]), %%zmm30\n" + "vmovups 4736(%[weight]), %%zmm29\n" + "vmovups 4800(%[weight]), %%zmm28\n" + "vmovups 4864(%[weight]), %%zmm27\n" + "vmovups 4928(%[weight]), %%zmm26\n" + "vbroadcastss 48(%[src_0]), %%zmm25\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 13 + "vmovups 4992(%[weight]), %%zmm31\n" + "vmovups 5056(%[weight]), %%zmm30\n" + "vmovups 5120(%[weight]), %%zmm29\n" + "vmovups 5184(%[weight]), %%zmm28\n" + "vmovups 5248(%[weight]), %%zmm27\n" + "vmovups 5312(%[weight]), %%zmm26\n" + "vbroadcastss 52(%[src_0]), %%zmm25\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 14 + "vmovups 5376(%[weight]), %%zmm31\n" + "vmovups 5440(%[weight]), %%zmm30\n" + "vmovups 5504(%[weight]), %%zmm29\n" + "vmovups 5568(%[weight]), %%zmm28\n" + "vmovups 5632(%[weight]), %%zmm27\n" + "vmovups 5696(%[weight]), %%zmm26\n" + "vbroadcastss 56(%[src_0]), %%zmm25\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + // block 15 + "vmovups 5760(%[weight]), %%zmm31\n" + "vmovups 5824(%[weight]), %%zmm30\n" + "vmovups 5888(%[weight]), %%zmm29\n" + "vmovups 5952(%[weight]), %%zmm28\n" + "vmovups 6016(%[weight]), %%zmm27\n" + "vmovups 6080(%[weight]), %%zmm26\n" + "vbroadcastss 60(%[src_0]), %%zmm25\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + "add $6144, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23 %{{%%k1}}\n" + "add $384, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17 %{{%%k1}}\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20\n" + "vmaxps %%zmm21, %%zmm31, %%zmm21\n" + "vmaxps %%zmm22, %%zmm31, %%zmm22\n" + "vmaxps %%zmm23, %%zmm31, %%zmm23 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17 %{{%%k1}}\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + "vminps %%zmm20, %%zmm30, %%zmm20\n" + "vminps %%zmm21, %%zmm30, %%zmm21\n" + "vminps %%zmm22, %%zmm30, %%zmm22\n" + "vminps %%zmm23, %%zmm30, %%zmm23 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 320(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm10, 256(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm11, 320(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm13, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm14, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm15, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm16, 256(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm17, 320(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm18, 0(%[dst_3])\n" + "vmovups %%zmm19, 64(%[dst_3])\n" + "vmovups %%zmm20, 128(%[dst_3])\n" + "vmovups %%zmm21, 192(%[dst_3])\n" + "vmovups %%zmm22, 256(%[dst_3])\n" + "vmovups %%zmm23, 320(%[dst_3]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x16_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x16_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..058fa02fd1d46631e81d64be843270d4399e0d0e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x16_mask_kernel_nhwc_fp32.c @@ -0,0 +1,326 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_5x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0 %{{%%k1}}\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x32_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x32_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..c8d2da93b6e029df53cbc0d4f31a383c4d5ec4b0 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x32_mask_kernel_nhwc_fp32.c @@ -0,0 +1,458 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_5x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x48_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x48_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..b86b606886c6190dc30d5b8d02b43ccada60e9df --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x48_mask_kernel_nhwc_fp32.c @@ -0,0 +1,591 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_5x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_3]), %%zmm9\n" + "vmovups 64(%[dst_3]), %%zmm10\n" + "vmovups 128(%[dst_3]), %%zmm11\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm12\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm13\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm14\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 64(%[bias]), %%zmm4\n" + "vmovups 128(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "vmovups 64(%[bias]), %%zmm10\n" + "vmovups 128(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 4(%[src_3]), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 8(%[src_3]), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 12(%[src_3]), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 16(%[src_3]), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 20(%[src_3]), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 24(%[src_3]), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 28(%[src_3]), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 32(%[src_3]), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 36(%[src_3]), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 40(%[src_3]), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 44(%[src_3]), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 48(%[src_3]), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 52(%[src_3]), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 56(%[src_3]), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 60(%[src_3]), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8 %{{%%k1}}\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8 %{{%%k1}}\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm9, 0(%[dst_3])\n" + "vmovups %%zmm10, 64(%[dst_3])\n" + "vmovups %%zmm11, 128(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm13, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm14, 128(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x64_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x64_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..53272ed557a5c365586d1162044eb7b378c944eb --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x64_mask_kernel_nhwc_fp32.c @@ -0,0 +1,723 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_5x64_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm9\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_3]), %%zmm12\n" + "vmovups 64(%[dst_3]), %%zmm13\n" + "vmovups 128(%[dst_3]), %%zmm14\n" + "vmovups 192(%[dst_3]), %%zmm15\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm16\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm17\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm18\n" + "vmovups 192(%[dst_3], %[dst_stride], 1), %%zmm19\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 128(%[bias]), %%zmm6\n" + "vmovups 192(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 128(%[bias]), %%zmm10\n" + "vmovups 192(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 192(%[bias]), %%zmm15\n" + "vmovups 0(%[bias]), %%zmm16\n" + "vmovups 64(%[bias]), %%zmm17\n" + "vmovups 128(%[bias]), %%zmm18\n" + "vmovups 192(%[bias]), %%zmm19\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + "vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_3]), %%zmm24\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_3]), %%zmm24\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_3]), %%zmm24\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_3]), %%zmm24\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_3]), %%zmm24\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_3]), %%zmm24\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_3]), %%zmm24\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 8 + "vmovups 2048(%[weight]), %%zmm31\n" + "vmovups 2112(%[weight]), %%zmm30\n" + "vmovups 2176(%[weight]), %%zmm29\n" + "vmovups 2240(%[weight]), %%zmm28\n" + "vbroadcastss 32(%[src_0]), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_3]), %%zmm24\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 9 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vbroadcastss 36(%[src_0]), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_3]), %%zmm24\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 10 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vbroadcastss 40(%[src_0]), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_3]), %%zmm24\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 11 + "vmovups 2816(%[weight]), %%zmm31\n" + "vmovups 2880(%[weight]), %%zmm30\n" + "vmovups 2944(%[weight]), %%zmm29\n" + "vmovups 3008(%[weight]), %%zmm28\n" + "vbroadcastss 44(%[src_0]), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_3]), %%zmm24\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 12 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vbroadcastss 48(%[src_0]), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_3]), %%zmm24\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 13 + "vmovups 3328(%[weight]), %%zmm31\n" + "vmovups 3392(%[weight]), %%zmm30\n" + "vmovups 3456(%[weight]), %%zmm29\n" + "vmovups 3520(%[weight]), %%zmm28\n" + "vbroadcastss 52(%[src_0]), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_3]), %%zmm24\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 14 + "vmovups 3584(%[weight]), %%zmm31\n" + "vmovups 3648(%[weight]), %%zmm30\n" + "vmovups 3712(%[weight]), %%zmm29\n" + "vmovups 3776(%[weight]), %%zmm28\n" + "vbroadcastss 56(%[src_0]), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_3]), %%zmm24\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + // block 15 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vbroadcastss 60(%[src_0]), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_3]), %%zmm24\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + "add $4096, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19 %{{%%k1}}\n" + "add $256, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15 %{{%%k1}}\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15 %{{%%k1}}\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm10, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 192(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_3])\n" + "vmovups %%zmm13, 64(%[dst_3])\n" + "vmovups %%zmm14, 128(%[dst_3])\n" + "vmovups %%zmm15, 192(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm16, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm17, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm18, 128(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm19, 192(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x80_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x80_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..7265a964348c455c51a7d9b903fc872118a71026 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_5x80_mask_kernel_nhwc_fp32.c @@ -0,0 +1,856 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_5x80_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm12\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm13\n" + "vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm14\n" + "vmovups 0(%[dst_3]), %%zmm15\n" + "vmovups 64(%[dst_3]), %%zmm16\n" + "vmovups 128(%[dst_3]), %%zmm17\n" + "vmovups 192(%[dst_3]), %%zmm18\n" + "vmovups 256(%[dst_3]), %%zmm19\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm20\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm21\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm22\n" + "vmovups 192(%[dst_3], %[dst_stride], 1), %%zmm23\n" + "vmovups 256(%[dst_3], %[dst_stride], 1), %%zmm24\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 256(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 64(%[bias]), %%zmm6\n" + "vmovups 128(%[bias]), %%zmm7\n" + "vmovups 192(%[bias]), %%zmm8\n" + "vmovups 256(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 128(%[bias]), %%zmm12\n" + "vmovups 192(%[bias]), %%zmm13\n" + "vmovups 256(%[bias]), %%zmm14\n" + "vmovups 0(%[bias]), %%zmm15\n" + "vmovups 64(%[bias]), %%zmm16\n" + "vmovups 128(%[bias]), %%zmm17\n" + "vmovups 192(%[bias]), %%zmm18\n" + "vmovups 256(%[bias]), %%zmm19\n" + "vmovups 0(%[bias]), %%zmm20\n" + "vmovups 64(%[bias]), %%zmm21\n" + "vmovups 128(%[bias]), %%zmm22\n" + "vmovups 192(%[bias]), %%zmm23\n" + "vmovups 256(%[bias]), %%zmm24\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + "vxorps %%zmm21, %%zmm21, %%zmm21\n" + "vxorps %%zmm22, %%zmm22, %%zmm22\n" + "vxorps %%zmm23, %%zmm23, %%zmm23\n" + "vxorps %%zmm24, %%zmm24, %%zmm24\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 1 + "vmovups 320(%[weight]), %%zmm31\n" + "vmovups 384(%[weight]), %%zmm30\n" + "vmovups 448(%[weight]), %%zmm29\n" + "vmovups 512(%[weight]), %%zmm28\n" + "vmovups 576(%[weight]), %%zmm27\n" + "vbroadcastss 4(%[src_0]), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 4(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 2 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vmovups 768(%[weight]), %%zmm29\n" + "vmovups 832(%[weight]), %%zmm28\n" + "vmovups 896(%[weight]), %%zmm27\n" + "vbroadcastss 8(%[src_0]), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 8(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 3 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vmovups 1152(%[weight]), %%zmm28\n" + "vmovups 1216(%[weight]), %%zmm27\n" + "vbroadcastss 12(%[src_0]), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 12(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 4 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vmovups 1536(%[weight]), %%zmm27\n" + "vbroadcastss 16(%[src_0]), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 16(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 5 + "vmovups 1600(%[weight]), %%zmm31\n" + "vmovups 1664(%[weight]), %%zmm30\n" + "vmovups 1728(%[weight]), %%zmm29\n" + "vmovups 1792(%[weight]), %%zmm28\n" + "vmovups 1856(%[weight]), %%zmm27\n" + "vbroadcastss 20(%[src_0]), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 20(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 6 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vbroadcastss 24(%[src_0]), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 24(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 7 + "vmovups 2240(%[weight]), %%zmm31\n" + "vmovups 2304(%[weight]), %%zmm30\n" + "vmovups 2368(%[weight]), %%zmm29\n" + "vmovups 2432(%[weight]), %%zmm28\n" + "vmovups 2496(%[weight]), %%zmm27\n" + "vbroadcastss 28(%[src_0]), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 28(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 8 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vmovups 2816(%[weight]), %%zmm27\n" + "vbroadcastss 32(%[src_0]), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 32(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 9 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vmovups 3072(%[weight]), %%zmm28\n" + "vmovups 3136(%[weight]), %%zmm27\n" + "vbroadcastss 36(%[src_0]), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 36(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 10 + "vmovups 3200(%[weight]), %%zmm31\n" + "vmovups 3264(%[weight]), %%zmm30\n" + "vmovups 3328(%[weight]), %%zmm29\n" + "vmovups 3392(%[weight]), %%zmm28\n" + "vmovups 3456(%[weight]), %%zmm27\n" + "vbroadcastss 40(%[src_0]), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 40(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 11 + "vmovups 3520(%[weight]), %%zmm31\n" + "vmovups 3584(%[weight]), %%zmm30\n" + "vmovups 3648(%[weight]), %%zmm29\n" + "vmovups 3712(%[weight]), %%zmm28\n" + "vmovups 3776(%[weight]), %%zmm27\n" + "vbroadcastss 44(%[src_0]), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 44(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 12 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vmovups 4096(%[weight]), %%zmm27\n" + "vbroadcastss 48(%[src_0]), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 48(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 13 + "vmovups 4160(%[weight]), %%zmm31\n" + "vmovups 4224(%[weight]), %%zmm30\n" + "vmovups 4288(%[weight]), %%zmm29\n" + "vmovups 4352(%[weight]), %%zmm28\n" + "vmovups 4416(%[weight]), %%zmm27\n" + "vbroadcastss 52(%[src_0]), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 52(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 14 + "vmovups 4480(%[weight]), %%zmm31\n" + "vmovups 4544(%[weight]), %%zmm30\n" + "vmovups 4608(%[weight]), %%zmm29\n" + "vmovups 4672(%[weight]), %%zmm28\n" + "vmovups 4736(%[weight]), %%zmm27\n" + "vbroadcastss 56(%[src_0]), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 56(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + // block 15 + "vmovups 4800(%[weight]), %%zmm31\n" + "vmovups 4864(%[weight]), %%zmm30\n" + "vmovups 4928(%[weight]), %%zmm29\n" + "vmovups 4992(%[weight]), %%zmm28\n" + "vmovups 5056(%[weight]), %%zmm27\n" + "vbroadcastss 60(%[src_0]), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 60(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + "add $5120, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19 %{{%%k1}}\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24 %{{%%k1}}\n" + "add $320, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4 %{{%%k1}}\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14 %{{%%k1}}\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19 %{{%%k1}}\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20\n" + "vmaxps %%zmm21, %%zmm31, %%zmm21\n" + "vmaxps %%zmm22, %%zmm31, %%zmm22\n" + "vmaxps %%zmm23, %%zmm31, %%zmm23\n" + "vmaxps %%zmm24, %%zmm31, %%zmm24 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4 %{{%%k1}}\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14 %{{%%k1}}\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19 %{{%%k1}}\n" + "vminps %%zmm20, %%zmm30, %%zmm20\n" + "vminps %%zmm21, %%zmm30, %%zmm21\n" + "vminps %%zmm22, %%zmm30, %%zmm22\n" + "vminps %%zmm23, %%zmm30, %%zmm23\n" + "vminps %%zmm24, %%zmm30, %%zmm24 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm10, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm12, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm13, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm14, 256(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm15, 0(%[dst_3])\n" + "vmovups %%zmm16, 64(%[dst_3])\n" + "vmovups %%zmm17, 128(%[dst_3])\n" + "vmovups %%zmm18, 192(%[dst_3])\n" + "vmovups %%zmm19, 256(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm20, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm21, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm22, 128(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm23, 192(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm24, 256(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x16_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x16_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..0f8b33a0c8527ed0d23fbff11b4e1d6a3f4dc7b9 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x16_mask_kernel_nhwc_fp32.c @@ -0,0 +1,366 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_6x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0 %{{%%k1}}\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4 %{{%%k1}}\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4 %{{%%k1}}\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x32_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x32_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..02950459fae47a9a64dfb47abfa11252ac561ca9 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x32_mask_kernel_nhwc_fp32.c @@ -0,0 +1,522 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_6x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x48_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x48_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..c98b28a06e087dd3ebd14b43a7ff8bd0316de08f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x48_mask_kernel_nhwc_fp32.c @@ -0,0 +1,677 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_6x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_3]), %%zmm9\n" + "vmovups 64(%[dst_3]), %%zmm10\n" + "vmovups 128(%[dst_3]), %%zmm11\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm12\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm13\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm14\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm15\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm16\n" + "vmovups 128(%[dst_3], %[dst_stride], 2), %%zmm17\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 64(%[bias]), %%zmm4\n" + "vmovups 128(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "vmovups 64(%[bias]), %%zmm10\n" + "vmovups 128(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 0(%[bias]), %%zmm15\n" + "vmovups 64(%[bias]), %%zmm16\n" + "vmovups 128(%[bias]), %%zmm17\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 4(%[src_3]), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 8(%[src_3]), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 12(%[src_3]), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 16(%[src_3]), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 20(%[src_3]), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 24(%[src_3]), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 28(%[src_3]), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 32(%[src_3]), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 36(%[src_3]), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 40(%[src_3]), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 44(%[src_3]), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 48(%[src_3]), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 52(%[src_3]), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 56(%[src_3]), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 60(%[src_3]), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8 %{{%%k1}}\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14 %{{%%k1}}\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8 %{{%%k1}}\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14 %{{%%k1}}\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm9, 0(%[dst_3])\n" + "vmovups %%zmm10, 64(%[dst_3])\n" + "vmovups %%zmm11, 128(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm13, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm14, 128(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm15, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm16, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm17, 128(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x64_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x64_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..a829d46edd23e07b977acd5685176b34a2f4440b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_6x64_mask_kernel_nhwc_fp32.c @@ -0,0 +1,833 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_6x64_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm9\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_3]), %%zmm12\n" + "vmovups 64(%[dst_3]), %%zmm13\n" + "vmovups 128(%[dst_3]), %%zmm14\n" + "vmovups 192(%[dst_3]), %%zmm15\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm16\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm17\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm18\n" + "vmovups 192(%[dst_3], %[dst_stride], 1), %%zmm19\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm20\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm21\n" + "vmovups 128(%[dst_3], %[dst_stride], 2), %%zmm22\n" + "vmovups 192(%[dst_3], %[dst_stride], 2), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 192(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 128(%[bias]), %%zmm6\n" + "vmovups 192(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 128(%[bias]), %%zmm10\n" + "vmovups 192(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 192(%[bias]), %%zmm15\n" + "vmovups 0(%[bias]), %%zmm16\n" + "vmovups 64(%[bias]), %%zmm17\n" + "vmovups 128(%[bias]), %%zmm18\n" + "vmovups 192(%[bias]), %%zmm19\n" + "vmovups 0(%[bias]), %%zmm20\n" + "vmovups 64(%[bias]), %%zmm21\n" + "vmovups 128(%[bias]), %%zmm22\n" + "vmovups 192(%[bias]), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + "vxorps %%zmm21, %%zmm21, %%zmm21\n" + "vxorps %%zmm22, %%zmm22, %%zmm22\n" + "vxorps %%zmm23, %%zmm23, %%zmm23\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + "vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 8 + "vmovups 2048(%[weight]), %%zmm31\n" + "vmovups 2112(%[weight]), %%zmm30\n" + "vmovups 2176(%[weight]), %%zmm29\n" + "vmovups 2240(%[weight]), %%zmm28\n" + "vbroadcastss 32(%[src_0]), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 9 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vbroadcastss 36(%[src_0]), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 10 + "vmovups 2560(%[weight]), %%zmm31\n" + "vmovups 2624(%[weight]), %%zmm30\n" + "vmovups 2688(%[weight]), %%zmm29\n" + "vmovups 2752(%[weight]), %%zmm28\n" + "vbroadcastss 40(%[src_0]), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 11 + "vmovups 2816(%[weight]), %%zmm31\n" + "vmovups 2880(%[weight]), %%zmm30\n" + "vmovups 2944(%[weight]), %%zmm29\n" + "vmovups 3008(%[weight]), %%zmm28\n" + "vbroadcastss 44(%[src_0]), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 12 + "vmovups 3072(%[weight]), %%zmm31\n" + "vmovups 3136(%[weight]), %%zmm30\n" + "vmovups 3200(%[weight]), %%zmm29\n" + "vmovups 3264(%[weight]), %%zmm28\n" + "vbroadcastss 48(%[src_0]), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 13 + "vmovups 3328(%[weight]), %%zmm31\n" + "vmovups 3392(%[weight]), %%zmm30\n" + "vmovups 3456(%[weight]), %%zmm29\n" + "vmovups 3520(%[weight]), %%zmm28\n" + "vbroadcastss 52(%[src_0]), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 14 + "vmovups 3584(%[weight]), %%zmm31\n" + "vmovups 3648(%[weight]), %%zmm30\n" + "vmovups 3712(%[weight]), %%zmm29\n" + "vmovups 3776(%[weight]), %%zmm28\n" + "vbroadcastss 56(%[src_0]), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 15 + "vmovups 3840(%[weight]), %%zmm31\n" + "vmovups 3904(%[weight]), %%zmm30\n" + "vmovups 3968(%[weight]), %%zmm29\n" + "vmovups 4032(%[weight]), %%zmm28\n" + "vbroadcastss 60(%[src_0]), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + "add $4096, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15 %{{%%k1}}\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23 %{{%%k1}}\n" + "add $256, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15 %{{%%k1}}\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19 %{{%%k1}}\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20\n" + "vmaxps %%zmm21, %%zmm31, %%zmm21\n" + "vmaxps %%zmm22, %%zmm31, %%zmm22\n" + "vmaxps %%zmm23, %%zmm31, %%zmm23 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15 %{{%%k1}}\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19 %{{%%k1}}\n" + "vminps %%zmm20, %%zmm30, %%zmm20\n" + "vminps %%zmm21, %%zmm30, %%zmm21\n" + "vminps %%zmm22, %%zmm30, %%zmm22\n" + "vminps %%zmm23, %%zmm30, %%zmm23 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm10, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 192(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_3])\n" + "vmovups %%zmm13, 64(%[dst_3])\n" + "vmovups %%zmm14, 128(%[dst_3])\n" + "vmovups %%zmm15, 192(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm16, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm17, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm18, 128(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm19, 192(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm20, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm21, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm22, 128(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm23, 192(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_7x16_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_7x16_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..b8cbfee16823e48464b59aa521d13f9d86bbc75c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_7x16_mask_kernel_nhwc_fp32.c @@ -0,0 +1,410 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_7x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_6]), %%zmm6\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0 %{{%%k1}}\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4 %{{%%k1}}\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4 %{{%%k1}}\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_6]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_7x32_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_7x32_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..65d050f07ae638687e23e60242fda81cd2313040 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_7x32_mask_kernel_nhwc_fp32.c @@ -0,0 +1,589 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_7x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_6]), %%zmm12\n" + "vmovups 64(%[dst_6]), %%zmm13\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 4(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 8(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 12(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 16(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 20(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 24(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 28(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 32(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 36(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 40(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 44(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 48(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 52(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 56(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 60(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_6])\n" + "vmovups %%zmm13, 64(%[dst_6]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_7x48_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_7x48_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..8e9ea4690d0eb8a476862e1b1d372110f189ba31 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_7x48_mask_kernel_nhwc_fp32.c @@ -0,0 +1,767 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_7x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_3]), %%zmm9\n" + "vmovups 64(%[dst_3]), %%zmm10\n" + "vmovups 128(%[dst_3]), %%zmm11\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm12\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm13\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm14\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm15\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm16\n" + "vmovups 128(%[dst_3], %[dst_stride], 2), %%zmm17\n" + "vmovups 0(%[dst_6]), %%zmm18\n" + "vmovups 64(%[dst_6]), %%zmm19\n" + "vmovups 128(%[dst_6]), %%zmm20\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 64(%[bias]), %%zmm4\n" + "vmovups 128(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "vmovups 64(%[bias]), %%zmm10\n" + "vmovups 128(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 0(%[bias]), %%zmm15\n" + "vmovups 64(%[bias]), %%zmm16\n" + "vmovups 128(%[bias]), %%zmm17\n" + "vmovups 0(%[bias]), %%zmm18\n" + "vmovups 64(%[bias]), %%zmm19\n" + "vmovups 128(%[bias]), %%zmm20\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 0(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 4(%[src_3]), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 4(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 8(%[src_3]), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 8(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 12(%[src_3]), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 12(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 16(%[src_3]), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 16(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 20(%[src_3]), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 20(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 24(%[src_3]), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 24(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 28(%[src_3]), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 28(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 32(%[src_3]), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 32(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 36(%[src_3]), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 36(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 40(%[src_3]), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 40(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 44(%[src_3]), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 44(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 48(%[src_3]), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 48(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 52(%[src_3]), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 52(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 56(%[src_3]), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 56(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 60(%[src_3]), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 60(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm23\n" + "vbroadcastss 0(%[src_6]), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm22, %%zmm20 %{{%%k1}}\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8 %{{%%k1}}\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14 %{{%%k1}}\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17 %{{%%k1}}\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8 %{{%%k1}}\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14 %{{%%k1}}\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17 %{{%%k1}}\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + "vminps %%zmm20, %%zmm30, %%zmm20 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm9, 0(%[dst_3])\n" + "vmovups %%zmm10, 64(%[dst_3])\n" + "vmovups %%zmm11, 128(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm13, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm14, 128(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm15, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm16, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm17, 128(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm18, 0(%[dst_6])\n" + "vmovups %%zmm19, 64(%[dst_6])\n" + "vmovups %%zmm20, 128(%[dst_6]) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_8x16_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_8x16_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..4e3cdde5195281db7e52ab2a5a5d658165e29620 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_8x16_mask_kernel_nhwc_fp32.c @@ -0,0 +1,450 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_8x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_6]), %%zmm6\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm7\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 0(%[bias]), %%zmm7\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_6]), %%zmm24\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_6]), %%zmm24\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_6]), %%zmm24\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_6]), %%zmm24\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_6]), %%zmm24\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_6]), %%zmm24\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_6]), %%zmm24\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_6]), %%zmm24\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_6]), %%zmm24\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_6]), %%zmm24\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_6]), %%zmm24\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_6]), %%zmm24\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_6]), %%zmm24\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_6]), %%zmm24\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_6]), %%zmm24\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0 %{{%%k1}}\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4 %{{%%k1}}\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6 %{{%%k1}}\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4 %{{%%k1}}\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6 %{{%%k1}}\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_6]) %{{%%k1}}\n" + "vmovups %%zmm7, 0(%[dst_6], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_8x32_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_8x32_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..d09d597cd84c669879f3cfe80eee5998c55b706f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_8x32_mask_kernel_nhwc_fp32.c @@ -0,0 +1,652 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_8x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_6]), %%zmm12\n" + "vmovups 64(%[dst_6]), %%zmm13\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm14\n" + "vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm15\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 0(%[bias]), %%zmm14\n" + "vmovups 64(%[bias]), %%zmm15\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 4(%[src_6]), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 8(%[src_6]), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 12(%[src_6]), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 16(%[src_6]), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 20(%[src_6]), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 24(%[src_6]), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 28(%[src_6]), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 32(%[src_6]), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 36(%[src_6]), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 40(%[src_6]), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 44(%[src_6]), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 48(%[src_6]), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 52(%[src_6]), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 56(%[src_6]), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 60(%[src_6]), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13 %{{%%k1}}\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13 %{{%%k1}}\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_6])\n" + "vmovups %%zmm13, 64(%[dst_6]) %{{%%k1}}\n" + "vmovups %%zmm14, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm15, 64(%[dst_6], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_8x48_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_8x48_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..f51d815fab6d47fc72f672893853d07172204e85 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_8x48_mask_kernel_nhwc_fp32.c @@ -0,0 +1,854 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_8x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 0(%[dst_3]), %%zmm9\n" + "vmovups 64(%[dst_3]), %%zmm10\n" + "vmovups 128(%[dst_3]), %%zmm11\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm12\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm13\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm14\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm15\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm16\n" + "vmovups 128(%[dst_3], %[dst_stride], 2), %%zmm17\n" + "vmovups 0(%[dst_6]), %%zmm18\n" + "vmovups 64(%[dst_6]), %%zmm19\n" + "vmovups 128(%[dst_6]), %%zmm20\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm21\n" + "vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm22\n" + "vmovups 128(%[dst_6], %[dst_stride], 1), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 128(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 64(%[bias]), %%zmm4\n" + "vmovups 128(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 128(%[bias]), %%zmm8\n" + "vmovups 0(%[bias]), %%zmm9\n" + "vmovups 64(%[bias]), %%zmm10\n" + "vmovups 128(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 128(%[bias]), %%zmm14\n" + "vmovups 0(%[bias]), %%zmm15\n" + "vmovups 64(%[bias]), %%zmm16\n" + "vmovups 128(%[bias]), %%zmm17\n" + "vmovups 0(%[bias]), %%zmm18\n" + "vmovups 64(%[bias]), %%zmm19\n" + "vmovups 128(%[bias]), %%zmm20\n" + "vmovups 0(%[bias]), %%zmm21\n" + "vmovups 64(%[bias]), %%zmm22\n" + "vmovups 128(%[bias]), %%zmm23\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + "vxorps %%zmm21, %%zmm21, %%zmm21\n" + "vxorps %%zmm22, %%zmm22, %%zmm22\n" + "vxorps %%zmm23, %%zmm23, %%zmm23\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_6]), %%zmm27\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 1 + "vmovups 192(%[weight]), %%zmm31\n" + "vmovups 256(%[weight]), %%zmm30\n" + "vmovups 320(%[weight]), %%zmm29\n" + "vbroadcastss 4(%[src_0]), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 4(%[src_3]), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_6]), %%zmm27\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 2 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vbroadcastss 8(%[src_0]), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 8(%[src_3]), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_6]), %%zmm27\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 3 + "vmovups 576(%[weight]), %%zmm31\n" + "vmovups 640(%[weight]), %%zmm30\n" + "vmovups 704(%[weight]), %%zmm29\n" + "vbroadcastss 12(%[src_0]), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 12(%[src_3]), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_6]), %%zmm27\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 4 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vbroadcastss 16(%[src_0]), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 16(%[src_3]), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_6]), %%zmm27\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 5 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vbroadcastss 20(%[src_0]), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 20(%[src_3]), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_6]), %%zmm27\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 6 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vbroadcastss 24(%[src_0]), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 24(%[src_3]), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_6]), %%zmm27\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 7 + "vmovups 1344(%[weight]), %%zmm31\n" + "vmovups 1408(%[weight]), %%zmm30\n" + "vmovups 1472(%[weight]), %%zmm29\n" + "vbroadcastss 28(%[src_0]), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 28(%[src_3]), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_6]), %%zmm27\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 8 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vbroadcastss 32(%[src_0]), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 32(%[src_3]), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_6]), %%zmm27\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 9 + "vmovups 1728(%[weight]), %%zmm31\n" + "vmovups 1792(%[weight]), %%zmm30\n" + "vmovups 1856(%[weight]), %%zmm29\n" + "vbroadcastss 36(%[src_0]), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 36(%[src_3]), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_6]), %%zmm27\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 10 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vbroadcastss 40(%[src_0]), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 40(%[src_3]), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_6]), %%zmm27\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 11 + "vmovups 2112(%[weight]), %%zmm31\n" + "vmovups 2176(%[weight]), %%zmm30\n" + "vmovups 2240(%[weight]), %%zmm29\n" + "vbroadcastss 44(%[src_0]), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 44(%[src_3]), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_6]), %%zmm27\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 12 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vbroadcastss 48(%[src_0]), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 48(%[src_3]), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_6]), %%zmm27\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 13 + "vmovups 2496(%[weight]), %%zmm31\n" + "vmovups 2560(%[weight]), %%zmm30\n" + "vmovups 2624(%[weight]), %%zmm29\n" + "vbroadcastss 52(%[src_0]), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 52(%[src_3]), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_6]), %%zmm27\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 14 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vbroadcastss 56(%[src_0]), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 56(%[src_3]), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_6]), %%zmm27\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + // block 15 + "vmovups 2880(%[weight]), %%zmm31\n" + "vmovups 2944(%[weight]), %%zmm30\n" + "vmovups 3008(%[weight]), %%zmm29\n" + "vbroadcastss 60(%[src_0]), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 60(%[src_3]), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_6]), %%zmm27\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + "add $3072, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vbroadcastss 0(%[src_0]), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm8 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14 %{{%%k1}}\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_6]), %%zmm27\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm28, %%zmm17 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm20 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm23 %{{%%k1}}\n" + "add $192, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8 %{{%%k1}}\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14 %{{%%k1}}\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17 %{{%%k1}}\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20 %{{%%k1}}\n" + "vmaxps %%zmm21, %%zmm31, %%zmm21\n" + "vmaxps %%zmm22, %%zmm31, %%zmm22\n" + "vmaxps %%zmm23, %%zmm31, %%zmm23 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8 %{{%%k1}}\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14 %{{%%k1}}\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17 %{{%%k1}}\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + "vminps %%zmm20, %%zmm30, %%zmm20 %{{%%k1}}\n" + "vminps %%zmm21, %%zmm30, %%zmm21\n" + "vminps %%zmm22, %%zmm30, %%zmm22\n" + "vminps %%zmm23, %%zmm30, %%zmm23 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 128(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm9, 0(%[dst_3])\n" + "vmovups %%zmm10, 64(%[dst_3])\n" + "vmovups %%zmm11, 128(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm13, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm14, 128(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm15, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm16, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm17, 128(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm18, 0(%[dst_6])\n" + "vmovups %%zmm19, 64(%[dst_6])\n" + "vmovups %%zmm20, 128(%[dst_6]) %{{%%k1}}\n" + "vmovups %%zmm21, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm22, 64(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm23, 128(%[dst_6], %[dst_stride], 1) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_9x16_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_9x16_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..7ec4f9482f1a053bad10626f9c4109fbbfcb9cf3 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_9x16_mask_kernel_nhwc_fp32.c @@ -0,0 +1,490 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_9x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm2\n" + "vmovups 0(%[dst_3]), %%zmm3\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm4\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_6]), %%zmm6\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm8\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 0(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 0(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 0(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 0(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 1 + "vmovups 64(%[weight]), %%zmm31\n" + "vbroadcastss 4(%[src_0]), %%zmm30\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 4(%[src_3]), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_6]), %%zmm24\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 2 + "vmovups 128(%[weight]), %%zmm31\n" + "vbroadcastss 8(%[src_0]), %%zmm30\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 8(%[src_3]), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_6]), %%zmm24\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 3 + "vmovups 192(%[weight]), %%zmm31\n" + "vbroadcastss 12(%[src_0]), %%zmm30\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 12(%[src_3]), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_6]), %%zmm24\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 4 + "vmovups 256(%[weight]), %%zmm31\n" + "vbroadcastss 16(%[src_0]), %%zmm30\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 16(%[src_3]), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_6]), %%zmm24\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 5 + "vmovups 320(%[weight]), %%zmm31\n" + "vbroadcastss 20(%[src_0]), %%zmm30\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 20(%[src_3]), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_6]), %%zmm24\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 6 + "vmovups 384(%[weight]), %%zmm31\n" + "vbroadcastss 24(%[src_0]), %%zmm30\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 24(%[src_3]), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_6]), %%zmm24\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 7 + "vmovups 448(%[weight]), %%zmm31\n" + "vbroadcastss 28(%[src_0]), %%zmm30\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 28(%[src_3]), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_6]), %%zmm24\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 8 + "vmovups 512(%[weight]), %%zmm31\n" + "vbroadcastss 32(%[src_0]), %%zmm30\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 32(%[src_3]), %%zmm27\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 32(%[src_6]), %%zmm24\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 9 + "vmovups 576(%[weight]), %%zmm31\n" + "vbroadcastss 36(%[src_0]), %%zmm30\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 36(%[src_3]), %%zmm27\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 36(%[src_6]), %%zmm24\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 10 + "vmovups 640(%[weight]), %%zmm31\n" + "vbroadcastss 40(%[src_0]), %%zmm30\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 40(%[src_3]), %%zmm27\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 40(%[src_6]), %%zmm24\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 11 + "vmovups 704(%[weight]), %%zmm31\n" + "vbroadcastss 44(%[src_0]), %%zmm30\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 44(%[src_3]), %%zmm27\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 44(%[src_6]), %%zmm24\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 12 + "vmovups 768(%[weight]), %%zmm31\n" + "vbroadcastss 48(%[src_0]), %%zmm30\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 48(%[src_3]), %%zmm27\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 48(%[src_6]), %%zmm24\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 13 + "vmovups 832(%[weight]), %%zmm31\n" + "vbroadcastss 52(%[src_0]), %%zmm30\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 52(%[src_3]), %%zmm27\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 52(%[src_6]), %%zmm24\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 14 + "vmovups 896(%[weight]), %%zmm31\n" + "vbroadcastss 56(%[src_0]), %%zmm30\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 56(%[src_3]), %%zmm27\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 56(%[src_6]), %%zmm24\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + // block 15 + "vmovups 960(%[weight]), %%zmm31\n" + "vbroadcastss 60(%[src_0]), %%zmm30\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 60(%[src_3]), %%zmm27\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 60(%[src_6]), %%zmm24\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "add $1024, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vbroadcastss 0(%[src_0]), %%zmm30\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm28\n" + "vbroadcastss 0(%[src_3]), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_6]), %%zmm24\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm8 %{{%%k1}}\n" + "add $64, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0 %{{%%k1}}\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2 %{{%%k1}}\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4 %{{%%k1}}\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6 %{{%%k1}}\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0 %{{%%k1}}\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2 %{{%%k1}}\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4 %{{%%k1}}\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6 %{{%%k1}}\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm1, 0(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm3, 0(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm5, 0(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_6]) %{{%%k1}}\n" + "vmovups %%zmm7, 0(%[dst_6], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_6], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_9x32_mask_kernel_nhwc_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_9x32_mask_kernel_nhwc_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..92b2bddd793dfae5d0ac68722940f6d57818198b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/gemm_mask_avx512/nnacl_gemm_avx512_9x32_mask_kernel_nhwc_fp32.c @@ -0,0 +1,715 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +// clang-format off +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_9x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in depth + "movq %[inc_flag], %%rax\n" + "kmovw (%[mask]), %%k1\n" + "and $0x1, %%rax\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_6]), %%zmm12\n" + "vmovups 64(%[dst_6]), %%zmm13\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm14\n" + "vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm15\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm16\n" + "vmovups 64(%[dst_6], %[dst_stride], 2), %%zmm17\n" + "jmp 2f\n" + ".align 16\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovups 0(%[bias]), %%zmm0\n" + "vmovups 64(%[bias]), %%zmm1\n" + "vmovups 0(%[bias]), %%zmm2\n" + "vmovups 64(%[bias]), %%zmm3\n" + "vmovups 0(%[bias]), %%zmm4\n" + "vmovups 64(%[bias]), %%zmm5\n" + "vmovups 0(%[bias]), %%zmm6\n" + "vmovups 64(%[bias]), %%zmm7\n" + "vmovups 0(%[bias]), %%zmm8\n" + "vmovups 64(%[bias]), %%zmm9\n" + "vmovups 0(%[bias]), %%zmm10\n" + "vmovups 64(%[bias]), %%zmm11\n" + "vmovups 0(%[bias]), %%zmm12\n" + "vmovups 64(%[bias]), %%zmm13\n" + "vmovups 0(%[bias]), %%zmm14\n" + "vmovups 64(%[bias]), %%zmm15\n" + "vmovups 0(%[bias]), %%zmm16\n" + "vmovups 64(%[bias]), %%zmm17\n" + "jmp 2f\n" + ".align 16\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + ".align 16\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), [ mask ] "r"(mask), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6) + : "%rax", "%k1", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %%k1\n" + "cmp $16, %[depth]\n" + "jb 1f\n" + ".align 16\n" + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 4(%[src_6]), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 8(%[src_6]), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 12(%[src_6]), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 16(%[src_6]), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 20(%[src_6]), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 24(%[src_6]), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 28(%[src_6]), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 8 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vbroadcastss 32(%[src_0]), %%zmm29\n" + "vbroadcastss 32(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 32(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 32(%[src_3]), %%zmm26\n" + "vbroadcastss 32(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 32(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 32(%[src_6]), %%zmm23\n" + "vbroadcastss 32(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 32(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 9 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vbroadcastss 36(%[src_0]), %%zmm29\n" + "vbroadcastss 36(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 36(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 36(%[src_3]), %%zmm26\n" + "vbroadcastss 36(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 36(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 36(%[src_6]), %%zmm23\n" + "vbroadcastss 36(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 36(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 10 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vbroadcastss 40(%[src_0]), %%zmm29\n" + "vbroadcastss 40(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 40(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 40(%[src_3]), %%zmm26\n" + "vbroadcastss 40(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 40(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 40(%[src_6]), %%zmm23\n" + "vbroadcastss 40(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 40(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 11 + "vmovups 1408(%[weight]), %%zmm31\n" + "vmovups 1472(%[weight]), %%zmm30\n" + "vbroadcastss 44(%[src_0]), %%zmm29\n" + "vbroadcastss 44(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 44(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 44(%[src_3]), %%zmm26\n" + "vbroadcastss 44(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 44(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 44(%[src_6]), %%zmm23\n" + "vbroadcastss 44(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 44(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 12 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vbroadcastss 48(%[src_0]), %%zmm29\n" + "vbroadcastss 48(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 48(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 48(%[src_3]), %%zmm26\n" + "vbroadcastss 48(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 48(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 48(%[src_6]), %%zmm23\n" + "vbroadcastss 48(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 48(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 13 + "vmovups 1664(%[weight]), %%zmm31\n" + "vmovups 1728(%[weight]), %%zmm30\n" + "vbroadcastss 52(%[src_0]), %%zmm29\n" + "vbroadcastss 52(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 52(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 52(%[src_3]), %%zmm26\n" + "vbroadcastss 52(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 52(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 52(%[src_6]), %%zmm23\n" + "vbroadcastss 52(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 52(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 14 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vbroadcastss 56(%[src_0]), %%zmm29\n" + "vbroadcastss 56(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 56(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 56(%[src_3]), %%zmm26\n" + "vbroadcastss 56(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 56(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 56(%[src_6]), %%zmm23\n" + "vbroadcastss 56(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 56(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + // block 15 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vbroadcastss 60(%[src_0]), %%zmm29\n" + "vbroadcastss 60(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 60(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 60(%[src_3]), %%zmm26\n" + "vbroadcastss 60(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 60(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 60(%[src_6]), %%zmm23\n" + "vbroadcastss 60(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 60(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "add $2048, %[weight]\n" + "add $64, %[src_0]\n" + "add $64, %[src_3]\n" + "add $64, %[src_6]\n" + "sub $16, %[depth]\n" + "cmp $16, %[depth]\n" + "jge 0b\n" + "cmp $0, %[depth]\n" + "je 2f\n" + ".align 16\n" + "1:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm22\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm21\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15 %{{%%k1}}\n" + "vfmadd231ps %%zmm31, %%zmm21, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm21, %%zmm17 %{{%%k1}}\n" + "add $128, %[weight]\n" + "add $4, %[src_0]\n" + "add $4, %[src_3]\n" + "add $4, %[src_6]\n" + "dec %[depth]\n" + "jg 1b\n" + ".align 16\n" + "2:\n" + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "and $0x3, %[act_flag]\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1 %{{%%k1}}\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3 %{{%%k1}}\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5 %{{%%k1}}\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7 %{{%%k1}}\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9 %{{%%k1}}\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11 %{{%%k1}}\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13 %{{%%k1}}\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15 %{{%%k1}}\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17 %{{%%k1}}\n" + "and $0x1, %[act_flag]\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1 %{{%%k1}}\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3 %{{%%k1}}\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5 %{{%%k1}}\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7 %{{%%k1}}\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9 %{{%%k1}}\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11 %{{%%k1}}\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13 %{{%%k1}}\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15 %{{%%k1}}\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17 %{{%%k1}}\n" + ".align 16\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0]) %{{%%k1}}\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3]) %{{%%k1}}\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2) %{{%%k1}}\n" + "vmovups %%zmm12, 0(%[dst_6])\n" + "vmovups %%zmm13, 64(%[dst_6]) %{{%%k1}}\n" + "vmovups %%zmm14, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm15, 64(%[dst_6], %[dst_stride], 1) %{{%%k1}}\n" + "vmovups %%zmm16, 0(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm17, 64(%[dst_6], %[dst_stride], 2) %{{%%k1}}\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ depth ] "r"(depth), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ mask ] "r"(mask), [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/generate_hpc.sh b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/generate_hpc.sh new file mode 100644 index 0000000000000000000000000000000000000000..cd2fbf519b4bae376ca3cef489e7c8099e011ea1 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/generate_hpc.sh @@ -0,0 +1,88 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +CRTDIR=$( cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) +# generate gemm fma asm code +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=12 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=11 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=10 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=9 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=8 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=7 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=6 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=5 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=4 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=3 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=2 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=1 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32_asm.c +# +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=6 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=5 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=4 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=3 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=2 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=1 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32_asm.c +# +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=4 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=3 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=2 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=1 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32_asm.c +# +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=3 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=2 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=1 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32_asm.c +# +## generate gemm fma intrinics code +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=12 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=11 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=10 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=9 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=8 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=7 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=6 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=5 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=4 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=3 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=2 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=1 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32.c +# +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=6 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=5 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=4 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=3 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=2 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=1 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32.c +# +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=4 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=3 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=2 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=1 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32.c +# +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=3 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=2 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=1 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32.c + +# generate gemm avx512 asm code +n=(96 80 64 48 32 16) +m=(4 5 6 8 12 12) +for ((index = 0; index < 6; index++)) +do + for ((row = 1; row <= ${m[index]}; row++)) + do + dst_file=$CRTDIR"/gemm_avx512/nnacl_gemm_avx512_$row""x${n[index]}_kernel_nhwc_fp32.c" + python3 $CRTDIR/generator.py -I $CRTDIR/template_file/gemm_avx512_nhwc_asm.c.in -A row_block=$row col_block=${n[index]} -O $dst_file + done +done diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/generator.py b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/generator.py new file mode 100644 index 0000000000000000000000000000000000000000..4cf25ab550f6ce8b4784b716ea4cc6711ac7666a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/generator.py @@ -0,0 +1,160 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""HPC generator""" + +import sys +import os +import io +import argparse +from itertools import chain + +def key_value_pair(line): + """ + split key and value + :param line: + :return: + """ + key = None + value = None + try: + key, value = line.split("=", 1) + except ValueError: + print("line must be format: key=value, but now is:", line) + sys.exit(1) + try: + value = int(value) + except ValueError: + print("Error: you input value must be integer, but now is:", value) + sys.exit(1) + return key, value + +def get_indent(line): + """ + get indent length + :param line: + :return: + """ + index = 0 + for i in line: + if i == " ": + index += 1 + else: + break + return index + +def print_line(line): + """ + Convert line to a python string + :param line: + :return: + """ + global PYTHON_INDENT + global GENERATE_CODE_INDENT + if line.strip()[0] == "}" or line.strip()[0] == ")": + PYTHON_INDENT = -1 + split_str = line.split("@") + if line.strip()[0] != "@" and len(split_str) == 1: + if get_indent(line) == PYTHON_INDENT or PYTHON_INDENT == -1: + result = ["print(", line, ", file=OUT_STREAM)"] + PYTHON_INDENT = -1 + if "{" in line or "asm volatile(" in line: + GENERATE_CODE_INDENT = get_indent(line) + if line.strip().startswith("}") and "{" not in line: + GENERATE_CODE_INDENT -= 4 + if len(line) == 1 and line[0] == "}": + # modify next fun GENERATE_CODE_INDENT + GENERATE_CODE_INDENT = -4 + return "\"".join(result) + + if line.strip()[0] == '@': + # get python indent and first GENERATE_CODE_INDENT + if PYTHON_INDENT == -1: + GENERATE_CODE_INDENT = get_indent(line) - 4 + PYTHON_INDENT = get_indent(line) + result = split_str[0][PYTHON_INDENT:] + split_str[1] + return result + + index = get_indent(split_str[0]) + result = [split_str[0][PYTHON_INDENT:index] + "print("] + prefix = " " * (GENERATE_CODE_INDENT + 4) + split_str[0].lstrip() + + suffix = " %(" + for str_tmp in split_str[1:]: + second = str_tmp.find("}") + suffix += str_tmp[1:second] + ', ' + str_tmp = str_tmp.replace(str_tmp[0:second + 1], "%d") + prefix += str_tmp + result.append(prefix) + result.append(suffix + "), file=OUT_STREAM)") + return "\"".join(result) + +def generate_code(template_file, exec_dict): + """ + generate hpc + :param template_file: template file path + :param exec_dict: dict + :return: hpc + """ + output_stream = io.StringIO() + with open(template_file, 'r') as f: + generate_code_lines = [] + for line in f: + line = line.replace("\n", "") + if line.strip() and line.strip()[0] != "@": + line = line.replace("\"", "\\\"") + line = line.replace("%", "%%") + if "print" in line: + line = line.replace("%%", "%") + if not line: + generate_code_lines.append("print(" + "\"" + line + "\"" + ", file=OUT_STREAM)") + else: + str = print_line(line) + if "%(" not in str: + str = str.replace("%%[", "%[") + generate_code_lines.append(str) + c = compile('\n'.join(generate_code_lines), '', 'exec') + exec_dict["OUT_STREAM"] = output_stream + exec(c, exec_dict) + return output_stream.getvalue() + +def check_python_version(): + if sys.version_info < (3, 6): + sys.stdout.write("At least python 3.6 is required, but now is " + str(sys.version_info.major) + "." + + str(sys.version_info.minor) + "\n") + sys.exit(1) + +GENERATE_CODE_INDENT = -4 +PYTHON_INDENT = -1 + +parser = argparse.ArgumentParser(description="MSLite NNACL Code Generator") +parser.add_argument("-I", dest="Template_File", nargs=1, help="template file to generate code") +parser.add_argument("-A", dest="defines", metavar="KEY=VALUE", nargs="*", type=key_value_pair, action="append", + help="Custom Parameters") +parser.add_argument("-O", dest="Output_File", nargs=1, help="generate code output file path") + +if __name__ == "__main__": + check_python_version() + parameters = parser.parse_args(sys.argv[1:]) + exec_globals = dict(chain(*parameters.defines)) + + generate_code_str = generate_code(parameters.Template_File[0], exec_globals) + if os.path.exists(parameters.Output_File[0]): + os.remove(parameters.Output_File[0]) + + saveDir = os.path.dirname(parameters.Output_File[0]) + if not os.path.exists(saveDir): + os.mkdir(saveDir) + with open(parameters.Output_File[0], "w", encoding='utf-8') as output_file: + output_file.write(generate_code_str) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/template_file/gemm_avx512_mask_nhwc_asm.c.in b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/template_file/gemm_avx512_mask_nhwc_asm.c.in new file mode 100644 index 0000000000000000000000000000000000000000..ee243655f5aacab61c0cbbb81889ae5b17b61067 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/template_file/gemm_avx512_mask_nhwc_asm.c.in @@ -0,0 +1,263 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_@{row_block}x@{col_block}_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, + const float *bias, const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag, const u_int16_t* mask) { + @import math + @row_stride_map = {6 : 4, 5 : 5, 4 : 6, 3 : 8, 2 : 12, 1 : 20} + @src_addr_stride = 3 + @asm_flag_list = [] + @row_split_number = [row for row in range(3, row_block, 3)] + @for row in row_split_number: + const float *dst_@{row} = dst + @{row} * dst_stride; + @asm_flag_list.append("[dst_" + str(row) + "] " + "\"r\"(dst_" + str(row) + ")"); + size_t dst_stride_t = dst_stride << 2; + @col_split_num = col_block >> 4; + asm volatile( + // inc in depth + "movq %[inc_flag], %rax\\n" + "kmovw (%[mask]), %k1\\n" + "and $0x1, %rax\\n" + "je 0f\\n" + @for row in range(0, row_block): + @src_addr = int(row / 3) * 3 + @for col in range(0, col_split_num): + @if row % 3 == 0: + "vmovups @{col * 64}(%[dst_@{src_addr}]), %%zmm@{row * col_split_num + col}\\n" + @else: + "vmovups @{col * 64}(%[dst_@{src_addr}], %[dst_stride], @{row - src_addr}), %%zmm@{row * col_split_num + col}\\n" + "jmp 2f\\n" + ".align 16\\n" + "0:\\n" + "cmpq $0, %[bias]\\n" + "je 1f\\n" + @for row in range(0, row_block): + @for col in range(0, col_split_num): + "vmovups @{col * 64}(%[bias]), %%zmm@{row * col_split_num + col}\\n" + "jmp 2f\\n" + ".align 16\\n" + "1:\\n" + @for row in range(0, row_block): + @for col in range(0, col_split_num): + "vxorps %%zmm@{row * col_split_num + col}, %%zmm@{row * col_split_num + col}, %%zmm@{row * col_split_num + col}\\n" + ".align 16\\n" + "2:\\n" + : + @list = ["[dst_0] \"r\"(dst)", "[bias] \"r\"(bias)", "[dst_stride] \"r\"(dst_stride_t)", "[inc_flag] \"r\"(inc_flag)", "[mask] \"r\"(mask)"] + @list.extend(asm_flag_list) + @print(" : " + ", ".join(list), file=OUT_STREAM) + @print(" : \"%rax\", \"%k1\", " + ", ".join(["\"%zmm" + str(i) + "\"" for i in range(0, row_block * col_block >>4)]), file=OUT_STREAM) + ); + @for row in row_split_number: + const float *src_@{row} = src + @{row} * src_stride; + @asm_flag_list.append("[src_" + str(row) + "] " + "\"r\"(src_" + str(row) + ")"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "kmovw (%[mask]), %k1\\n" + @loop_count = 16 + "cmp $@{loop_count}, %[depth]\\n" + "jb 1f\\n" + ".align 16\\n" + "0:\\n" + @for i in range(0, loop_count): + // block @{i} + @for col in range(0, col_split_num): + "vmovups @{col * 64 + i * col_block * 4}(%[weight]), %%zmm@{31 - col}\\n" + @if row_block * col_split_num + row_block + col_split_num <= 32: + @for row in range(0, row_block): + @src_addr = math.floor(row / src_addr_stride) * src_addr_stride + @src_index = 31 - col_split_num - row + @if row % src_addr_stride == 0: + "vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n" + @for row in range(0, row_block): + @src_index = 31 - col_split_num - row + @for col in range(0, col_split_num): + @weight_index = 31 - col + @dst_index = row * col_split_num + col + @if col == col_split_num - 1: + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index} %{{%%k1}}\\n" + @else: + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n" + @else: + @row_stride = 32 - (row_stride_map[col_split_num] + 1) * col_split_num + @row_split_num = math.floor(row_block / row_stride) + @for row_index in range(0, row_split_num): + @row_split_start = row_index * row_stride + @for row in range(row_split_start, row_split_start + row_stride): + @src_addr = math.floor(row / src_addr_stride) * src_addr_stride + @src_index = 31 - col_split_num - (row - row_split_start) + @if row % src_addr_stride == 0: + "vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n" + @for row in range(0, row_stride): + @src_index = 31 - col_split_num - row + @for col in range(0, col_split_num): + @weight_index = 31 - col + @dst_index = (row_split_start + row) * col_split_num + col + @if col == col_split_num - 1: + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index} %{{%%k1}}\\n" + @else: + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n" + @row_split_start = row_split_num * row_stride + @for row in range(row_split_start, row_block): + @src_addr = math.floor(row / src_addr_stride) * src_addr_stride + @src_index = 31 - col_split_num - (row - row_split_start) + @if row % src_addr_stride == 0: + "vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n" + @for row in range(row_split_start, row_block): + @src_index = 31 - col_split_num - (row - row_split_start) + @for col in range(0, col_split_num): + @weight_index = 31 - col + @dst_index = row * col_split_num + col + @if col == col_split_num - 1: + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index} %{{%%k1}}\\n" + @else: + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n" + "add $@{col_block * 4 * loop_count}, %[weight]\\n" + "add $@{loop_count * 4}, %[src_0]\\n" + @for row in row_split_number: + "add $@{loop_count * 4}, %[src_@{row}]\\n" + "sub $@{loop_count}, %[depth]\\n" + "cmp $@{loop_count}, %[depth]\\n" + "jge 0b\\n" + "cmp $0, %[depth]\\n" + "je 2f\\n" + ".align 16\\n" + "1:\\n" + @loop_count = 1 + @for i in range(0, loop_count): + // block @{i} + @for col in range(0, col_split_num): + "vmovups @{col * 64 + i * col_block * 4}(%[weight]), %%zmm@{31 - col}\\n" + @if row_block * col_split_num + row_block + col_split_num <= 32: + @for row in range(0, row_block): + @src_addr = math.floor(row / src_addr_stride) * src_addr_stride + @src_index = 31 - col_split_num - row + @if row % src_addr_stride == 0: + "vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n" + @for row in range(0, row_block): + @src_index = 31 - col_split_num - row + @for col in range(0, col_split_num): + @weight_index = 31 - col + @dst_index = row * col_split_num + col + @if col == col_split_num - 1: + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index} %{{%%k1}}\\n" + @else: + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n" + @else: + @row_stride = 32 - (row_stride_map[col_split_num] + 1) * col_split_num + @row_split_num = math.floor(row_block / row_stride) + @for row_index in range(0, row_split_num): + @row_split_start = row_index * row_stride + @for row in range(row_split_start, row_split_start + row_stride): + @src_addr = math.floor(row / src_addr_stride) * src_addr_stride + @src_index = 31 - col_split_num - (row - row_split_start) + @if row % src_addr_stride == 0: + "vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n" + @for row in range(0, row_stride): + @src_index = 31 - col_split_num - row + @for col in range(0, col_split_num): + @weight_index = 31 - col + @dst_index = (row_split_start + row) * col_split_num + col + @if col == col_split_num - 1: + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index} %{{%%k1}}\\n" + @else: + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n" + @row_split_start = row_split_num * row_stride + @for row in range(row_split_start, row_block): + @src_addr = math.floor(row / src_addr_stride) * src_addr_stride + @src_index = 31 - col_split_num - (row - row_split_start) + @if row % src_addr_stride == 0: + "vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n" + @for row in range(row_split_start, row_block): + @src_index = 31 - col_split_num - (row - row_split_start) + @for col in range(0, col_split_num): + @weight_index = 31 - col + @dst_index = row * col_split_num + col + @if col == col_split_num - 1: + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index} %{{%%k1}}\\n" + @else: + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n" + "add $@{col_block * 4 * loop_count}, %[weight]\\n" + "add $@{loop_count * 4}, %[src_0]\\n" + @for row in row_split_number: + "add $@{loop_count * 4}, %[src_@{row}]\\n" + "dec %[depth]\\n" + "jg 1b\\n" + ".align 16\\n" + "2:\\n" + "and $0x2, %[inc_flag]\\n" + "je 3f\\n" + "and $0x3, %[act_flag]\\n" + "je 3f\\n" + // relu + "vxorps %zmm31, %zmm31, %zmm31\\n" + @for row in range(0, row_block): + @for col in range(0, col_split_num): + @if col == col_split_num - 1: + "vmaxps %%zmm@{row * col_split_num + col}, %%zmm31, %%zmm@{row * col_split_num + col} %{{%%k1}}\\n" + @else: + "vmaxps %%zmm@{row * col_split_num + col}, %%zmm31, %%zmm@{row * col_split_num + col}\\n" + "and $0x1, %[act_flag]\\n" + "je 3f\\n" + // relu6 + "mov $0x40C00000, %eax\\n" + "vmovd %eax, %xmm30\\n" + "vbroadcastss %%xmm@{30}, %%zmm30\\n" + @for row in range(0, row_block): + @for col in range(0, col_split_num): + @if col == col_split_num - 1: + "vminps %%zmm@{row * col_split_num + col}, %%zmm30, %%zmm@{row * col_split_num + col} %{{%%k1}}\\n" + @else: + "vminps %%zmm@{row * col_split_num + col}, %%zmm30, %%zmm@{row * col_split_num + col}\\n" + ".align 16\\n" + "3:\\n" + @for row in range(0, row_block): + @src_addr = int(row / 3) * 3 + @for col in range(0, col_split_num): + @if col == col_split_num - 1: + @if row % 3 == 0: + "vmovups %%zmm@{row * col_split_num + col}, @{col * 64}(%[dst_@{src_addr}]) %{{%%k1}}\\n" + @else: + "vmovups %%zmm@{row * col_split_num + col}, @{col * 64}(%[dst_@{src_addr}], %[dst_stride], @{row - src_addr}) %{{%%k1}}\\n" + @else: + @if row % 3 == 0: + "vmovups %%zmm@{row * col_split_num + col}, @{col * 64}(%[dst_@{src_addr}])\\n" + @else: + "vmovups %%zmm@{row * col_split_num + col}, @{col * 64}(%[dst_@{src_addr}], %[dst_stride], @{row - src_addr})\\n" + : + @list = ["[src_0] \"r\"(src)", "[src_stride] \"r\"(src_stride_t)", "[weight] \"r\"(weight)", "[depth] \"r\"(depth)", "[inc_flag] \"r\"(inc_flag)", "[act_flag] \"r\"(act_flag)", "[dst_0] \"r\"(dst)", "[dst_stride] \"r\"(dst_stride_t)", "[mask] \"r\"(mask)"] + @list.extend(asm_flag_list) + @print(" : " + ", ".join(list), file=OUT_STREAM) + @print(" : " + ", ".join(["\"%zmm" + str(i) + "\"" for i in range(0, 32)]), file=OUT_STREAM) + ); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/template_file/gemm_avx512_nhwc_asm.c.in b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/template_file/gemm_avx512_nhwc_asm.c.in new file mode 100644 index 0000000000000000000000000000000000000000..ad4985198ca2a036930f3d983f9b16cff2a84c22 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/template_file/gemm_avx512_nhwc_asm.c.in @@ -0,0 +1,231 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/matmul_avx512_fp32.h" + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_@{row_block}x@{col_block}_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, + const float *bias, const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag) { + @import math + @row_stride_map = {6 : 4, 5 : 5, 4 : 6, 3 : 8, 2 : 12, 1 : 20} + @src_addr_stride = 3 + @asm_flag_list = [] + @row_split_number = [row for row in range(3, row_block, 3)] + @for row in row_split_number: + const float *dst_@{row} = dst + @{row} * dst_stride; + @asm_flag_list.append("[dst_" + str(row) + "] " + "\"r\"(dst_" + str(row) + ")"); + size_t dst_stride_t = dst_stride << 2; + @col_split_num = col_block >> 4; + asm volatile( + // inc in depth + "movq %[inc_flag], %rax\\n" + "and $0x1, %rax\\n" + "je 0f\\n" + @for row in range(0, row_block): + @src_addr = int(row / 3) * 3 + @for col in range(0, col_split_num): + @if row % 3 == 0: + "vmovups @{col * 64}(%[dst_@{src_addr}]), %%zmm@{row * col_split_num + col}\\n" + @else: + "vmovups @{col * 64}(%[dst_@{src_addr}], %[dst_stride], @{row - src_addr}), %%zmm@{row * col_split_num + col}\\n" + "jmp 2f\\n" + ".align 16\\n" + "0:\\n" + "cmpq $0, %[bias]\\n" + "je 1f\\n" + @for row in range(0, row_block): + @for col in range(0, col_split_num): + "vmovups @{col * 64}(%[bias]), %%zmm@{row * col_split_num + col}\\n" + "jmp 2f\\n" + ".align 16\\n" + "1:\\n" + @for row in range(0, row_block): + @for col in range(0, col_split_num): + "vxorps %%zmm@{row * col_split_num + col}, %%zmm@{row * col_split_num + col}, %%zmm@{row * col_split_num + col}\\n" + ".align 16\\n" + "2:\\n" + : + @list = ["[dst_0] \"r\"(dst)", "[bias] \"r\"(bias)", "[dst_stride] \"r\"(dst_stride_t)", "[inc_flag] \"r\"(inc_flag)"] + @list.extend(asm_flag_list) + @print(" : " + ", ".join(list), file=OUT_STREAM) + @print(" : " + ", ".join(["\"%zmm" + str(i) + "\"" for i in range(0, row_block * col_block >>4)]), file=OUT_STREAM) + ); + @for row in row_split_number: + const float *src_@{row} = src + @{row} * src_stride; + @asm_flag_list.append("[src_" + str(row) + "] " + "\"r\"(src_" + str(row) + ")"); + size_t src_stride_t = src_stride << 2; + asm volatile( + @loop_count = 16 + "cmp $@{loop_count}, %[depth]\\n" + "jb 1f\\n" + ".align 16\\n" + "0:\\n" + @for i in range(0, loop_count): + // block @{i} + @for col in range(0, col_split_num): + "vmovups @{col * 64 + i * col_block * 4}(%[weight]), %%zmm@{31 - col}\\n" + @if row_block * col_split_num + row_block + col_split_num <= 32: + @for row in range(0, row_block): + @src_addr = math.floor(row / src_addr_stride) * src_addr_stride + @src_index = 31 - col_split_num - row + @if row % src_addr_stride == 0: + "vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n" + @for row in range(0, row_block): + @src_index = 31 - col_split_num - row + @for col in range(0, col_split_num): + @weight_index = 31 - col + @dst_index = row * col_split_num + col + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n" + @else: + @row_stride = 32 - (row_stride_map[col_split_num] + 1) * col_split_num + @row_split_num = math.floor(row_block / row_stride) + @for row_index in range(0, row_split_num): + @row_split_start = row_index * row_stride + @for row in range(row_split_start, row_split_start + row_stride): + @src_addr = math.floor(row / src_addr_stride) * src_addr_stride + @src_index = 31 - col_split_num - (row - row_split_start) + @if row % src_addr_stride == 0: + "vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n" + @for row in range(0, row_stride): + @src_index = 31 - col_split_num - row + @for col in range(0, col_split_num): + @weight_index = 31 - col + @dst_index = (row_split_start + row) * col_split_num + col + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n" + @row_split_start = row_split_num * row_stride + @for row in range(row_split_start, row_block): + @src_addr = math.floor(row / src_addr_stride) * src_addr_stride + @src_index = 31 - col_split_num - (row - row_split_start) + @if row % src_addr_stride == 0: + "vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n" + @for row in range(row_split_start, row_block): + @src_index = 31 - col_split_num - (row - row_split_start) + @for col in range(0, col_split_num): + @weight_index = 31 - col + @dst_index = row * col_split_num + col + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n" + "add $@{col_block * 4 * loop_count}, %[weight]\\n" + "add $@{loop_count * 4}, %[src_0]\\n" + @for row in row_split_number: + "add $@{loop_count * 4}, %[src_@{row}]\\n" + "sub $@{loop_count}, %[depth]\\n" + "cmp $@{loop_count}, %[depth]\\n" + "jge 0b\\n" + "cmp $0, %[depth]\\n" + "je 2f\\n" + ".align 16\\n" + "1:\\n" + @loop_count = 1 + @for i in range(0, loop_count): + // block @{i} + @for col in range(0, col_split_num): + "vmovups @{col * 64 + i * col_block * 4}(%[weight]), %%zmm@{31 - col}\\n" + @if row_block * col_split_num + row_block + col_split_num <= 32: + @for row in range(0, row_block): + @src_addr = math.floor(row / src_addr_stride) * src_addr_stride + @src_index = 31 - col_split_num - row + @if row % src_addr_stride == 0: + "vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n" + @for row in range(0, row_block): + @src_index = 31 - col_split_num - row + @for col in range(0, col_split_num): + @weight_index = 31 - col + @dst_index = row * col_split_num + col + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n" + @else: + @row_stride = 32 - (row_stride_map[col_split_num] + 1) * col_split_num + @row_split_num = math.floor(row_block / row_stride) + @for row_index in range(0, row_split_num): + @row_split_start = row_index * row_stride + @for row in range(row_split_start, row_split_start + row_stride): + @src_addr = math.floor(row / src_addr_stride) * src_addr_stride + @src_index = 31 - col_split_num - (row - row_split_start) + @if row % src_addr_stride == 0: + "vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n" + @for row in range(0, row_stride): + @src_index = 31 - col_split_num - row + @for col in range(0, col_split_num): + @weight_index = 31 - col + @dst_index = (row_split_start + row) * col_split_num + col + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n" + @row_split_start = row_split_num * row_stride + @for row in range(row_split_start, row_block): + @src_addr = math.floor(row / src_addr_stride) * src_addr_stride + @src_index = 31 - col_split_num - (row - row_split_start) + @if row % src_addr_stride == 0: + "vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n" + @for row in range(row_split_start, row_block): + @src_index = 31 - col_split_num - (row - row_split_start) + @for col in range(0, col_split_num): + @weight_index = 31 - col + @dst_index = row * col_split_num + col + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n" + "add $@{col_block * 4 * loop_count}, %[weight]\\n" + "add $@{loop_count * 4}, %[src_0]\\n" + @for row in row_split_number: + "add $@{loop_count * 4}, %[src_@{row}]\\n" + "dec %[depth]\\n" + "jg 1b\\n" + ".align 16\\n" + "2:\\n" + "and $0x2, %[inc_flag]\\n" + "je 3f\\n" + "and $0x3, %[act_flag]\\n" + "je 3f\\n" + // relu + "vxorps %zmm31, %zmm31, %zmm31\\n" + @for col in range(0, col_split_num): + @for row in range(0, row_block): + "vmaxps %%zmm@{row + col * row_block}, %%zmm31, %%zmm@{row + col * row_block}\\n" + "and $0x1, %[act_flag]\\n" + "je 3f\\n" + // relu6 + "mov $0x40C00000, %eax\\n" + "vmovd %eax, %xmm30\\n" + "vbroadcastss %xmm30, %zmm30\\n" + @for col in range(0, col_split_num): + @for row in range(0, row_block): + "vminps %%zmm@{row + col * row_block}, %%zmm30, %%zmm@{row + col * row_block}\\n" + ".align 16\\n" + "3:\\n" + @for row in range(0, row_block): + @src_addr = int(row / 3) * 3 + @for col in range(0, col_split_num): + @if row % 3 == 0: + "vmovups %%zmm@{row * col_split_num + col}, @{col * 64}(%[dst_@{src_addr}])\\n" + @else: + "vmovups %%zmm@{row * col_split_num + col}, @{col * 64}(%[dst_@{src_addr}], %[dst_stride], @{row - src_addr})\\n" + : + @list = ["[src_0] \"r\"(src)", "[src_stride] \"r\"(src_stride_t)", "[weight] \"r\"(weight)", "[depth] \"r\"(depth)", "[inc_flag] \"r\"(inc_flag)", "[act_flag] \"r\"(act_flag)", "[dst_0] \"r\"(dst)", "[dst_stride] \"r\"(dst_stride_t)"] + @list.extend(asm_flag_list) + @print(" : " + ", ".join(list), file=OUT_STREAM) + @print(" : \"%rax\", " + ", ".join(["\"%zmm" + str(i) + "\"" for i in range(0, 32)]), file=OUT_STREAM) + ); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/template_file/gemm_fma_nc8hw8.c.in b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/template_file/gemm_fma_nc8hw8.c.in new file mode 100644 index 0000000000000000000000000000000000000000..641b1857e787440a97bf91843a781afaa307baaf --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/template_file/gemm_fma_nc8hw8.c.in @@ -0,0 +1,85 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_@{row_block}x@{col_block}_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, + const float *bias, const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t deep, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag) { + @for i in range(0, row_block): + @for j in range(0, col_block >> 3): + __m256 dst@{j * row_block + i}; + if (inc_flag) { + @for i in range(0, row_block): + @for j in range(0, col_block >> 3): + dst@{j * row_block + i} = _mm256_load_ps(dst + @{j} * dst_stride + @{i * 8}); + } else if (bias == NULL) { + @for i in range(0, row_block * col_block >> 3): + dst@{i} = _mm256_setzero_ps(); + } else { + @for i in range(0, row_block): + @for j in range(0, col_block >> 3): + dst@{j * row_block + i} = _mm256_load_ps(bias + @{j * 8}); + } + for (int i = 0; i < (deep >> 3); ++i) { + @for i in range(0, 8): + // bock@{i} + @if col_block == 32: + @for row in range(0, row_block): + __m256 src@{row}@{i} = _mm256_set1_ps(*(src + @{row * 8 + i})); + @for col in range(0, col_block >> 3): + __m256 weight@{col}@{i} = _mm256_load_ps(weight + @{col * 8 + i * col_block}); + @for row in range(0, row_block): + dst@{row + col * row_block} = _mm256_fmadd_ps(dst@{row + col * row_block}, src@{row}@{i}, weight@{col}@{i}); + @else: + @for col in range(0, col_block >> 3): + __m256 weight@{col}@{i} = _mm256_load_ps(weight + @{col * 8 + i * col_block}); + @for row in range(0, row_block): + __m256 src@{row}@{i} = _mm256_set1_ps(*(src + @{row * 8 + i})); + @for col in range(0, col_block >> 3): + dst@{row + col * row_block} = _mm256_fmadd_ps(dst@{row + col * row_block}, src@{row}@{i}, weight@{col}@{i}); + src = src + src_stride; + weight += @{8 * col_block * 4}; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + @for i in range(0, row_block): + @for j in range(0, col_block >> 3): + dst@{i + j * row_block} = _mm256_min_ps(dst@{i + j * row_block}, relu6); + // relu + @for i in range(0, row_block): + @for j in range(0, col_block >> 3): + dst@{i + j * row_block} = _mm256_max_ps(dst@{i + j * row_block}, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + @for i in range(0, row_block): + @for j in range(0, col_block >> 3): + dst@{i + j * row_block} = _mm256_max_ps(dst@{i + j * row_block}, relu); + } + @if col_block == 32: + @for j in range(0, col_block >> 3): + @for i in range(0, row_block): + _mm256_store_ps(dst + @{j} * src_stride + @{i * 8}, dst@{j * row_block + i}); + @else: + @for j in range(0, col_block >> 3): + @for i in range(0, row_block): + _mm256_store_ps(dst + @{j} * src_stride + @{i * 8}, dst@{j * row_block + i}); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/template_file/gemm_fma_nc8hw8_asm.c.in b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/template_file/gemm_fma_nc8hw8_asm.c.in new file mode 100644 index 0000000000000000000000000000000000000000..70178cf583ceafa1cc6fb9485e5364b393d64af6 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/experimental/HPC-generator/template_file/gemm_fma_nc8hw8_asm.c.in @@ -0,0 +1,149 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_@{row_block}x@{col_block}_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, + const float *bias, const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t deep, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag) { + @if col_block == 32: + const float *dst_4 = dst + 3 * dst_stride; + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\\n" + "je 0f\\n" + @for col in range(0, min((col_block >> 3), 3)): + @for row in range(0, row_block): + @if col == 0: + "vmovups @{row * 32}(%[dst]), %%ymm@{row + col * row_block}\\n" + @else: + "vmovups @{row * 32}(%[dst], %[dst_stride], @{col}), %%ymm@{row + col * row_block}\\n" + @if col_block == 32: + @for row in range(0, row_block): + "vmovups @{row * 32}(%[dst_4]), %%ymm@{row + (col + 1) * row_block}\\n" + "jmp 2f\\n" + "0:\\n" + "cmpq $0, %[bias]\\n" + "je 1f\\n" + @for col in range(0, col_block >> 3): + @for row in range(0, row_block): + "vmovaps @{col * 32}(%[bias]), %%ymm@{row + col * row_block}\\n" + "jmp 2f\\n" + "1:\\n" + @for col in range(0, col_block >> 3): + @for row in range(0, row_block): + "vxorps %%ymm@{row + col * row_block}, %%ymm@{row + col * row_block}, %%ymm@{row + col * row_block}\\n" + "2:\\n" + : + @list = ["[dst] \"r\"(dst)", "[bias] \"r\"(bias)", "[dst_stride] \"r\"(dst_stride_t)", "[inc_flag] \"r\"(inc_flag)"] + @if col_block == 32: + @list.append("[dst_4] \"r\"(dst_4)") + @print(" : " + ", ".join(list), file=OUT_STREAM) + @print(" : " + ", ".join(["\"%ymm" + str(i) + "\"" for i in range(0, row_block * col_block >> 3)]), file=OUT_STREAM) + ); + asm volatile( + "0:\\n" + @for i in range(0, 8): + // block @{i} + @if col_block == 32: + @for row in range(0, row_block): + "vbroadcastss @{row * 32 + i}(%[src]), %%ymm@{15 - row}\\n" + @for col in range(0, col_block >> 3): + "vmovaps @{col * 32 + i * col_block * 4}(%[weight]), %%ymm@{15 - row_block}\\n" + @for row in range(0, row_block): + "vfmadd231ps %%ymm@{row + col * row_block}, %%ymm@{15 - row_block}, %%ymm@{15 - row}\\n" + @elif col_block == 24: + @for col in range(0, col_block >> 3): + "vmovaps @{col * 32 + i * col_block * 4}(%[weight]), %%ymm@{15 - col}\\n" + @for row in range(0, row_block): + "vbroadcastss @{row * 32 + i}(%[src]), %%ymm@{14 - col}\\n" + @for col in range(0, col_block >> 3): + "vfmadd231ps %%ymm@{row + col * row_block}, %%ymm@{15 - (col_block >> 3)}, %%ymm@{15 - col}\\n" + @elif col_block == 16: + @for col in range(0, col_block >> 3): + "vmovaps @{col * 32 + i * col_block * 4}(%[weight]), %%ymm@{15 - col}\\n" + @for row in range(0, row_block >> 1): + "vbroadcastss @{row * 64 + i}(%[src]), %%ymm@{14 - col}\\n" + "vbroadcastss @{row * 64 + 32 + i}(%[src]), %%ymm@{13 - col}\\n" + @for col in range(0, col_block >> 3): + @for j in range(0, 2): + "vfmadd231ps %%ymm@{row * 2 + j + col * row_block}, %%ymm@{15 - (col_block >> 3) - j}, %%ymm@{15 - col}\\n" + @for row in range(row_block >> 1 << 1, row_block): + "vbroadcastss @{row * 32 + i}(%[src]), %%ymm@{14 - col}\\n" + @for col in range(0, col_block >> 3): + "vfmadd231ps %%ymm@{row + col * row_block}, %%ymm@{15 - (col_block >> 3)}, %%ymm@{15 - col}\\n" + @else: + @for col in range(0, col_block >> 3): + "vmovaps @{col * 32 + i * col_block * 4}(%[weight]), %%ymm@{15 - col}\\n" + @split_num = 3 + @for row in range(0, int(row_block / split_num)): + @for j in range(0, split_num): + "vbroadcastss @{row * 96 + j * 32 + i}(%[src]), %%ymm@{15 - (col_block >> 3) - j}\\n" + @for col in range(0, col_block >> 3): + @for j in range(0, split_num): + "vfmadd231ps %%ymm@{row * split_num + j + col * row_block}, %%ymm@{15 - (col_block >> 3) - j}, %%ymm@{15 - col}\\n" + @for row in range(int(row_block / split_num) * split_num, row_block): + "vbroadcastss @{row * 32 + i}(%[src]), %%ymm@{15 - (col_block >> 3) - (row - int(row_block / split_num) * split_num)}\\n" + @for col in range(0, col_block >> 3): + @for row in range(int(row_block / split_num) * split_num, row_block): + "vfmadd231ps %%ymm@{row + col * row_block}, %%ymm@{15 - (col_block >> 3) - (row - int(row_block / split_num) * split_num)}, %%ymm@{15 - col}\\n" + "dec %[deep]\\n" + "add @{col_block * 4 * 8}, %[weight]\\n" + "add %[src_stride], %[src]\\n" + "jg 0b\\n" + + "movq %[inc_flag], %rax\\n" + "and $0x2, %eax\\n" + "je 3f\\n" + "movq %[act_flag], %rax\\n" + "and $0x3, %eax\\n" + "je 3f\\n" + // relu + "vxorps %ymm15, %ymm15, %ymm15\\n" + @for col in range(0, col_block >> 3): + @for row in range(0, row_block): + "vmaxps %%ymm@{row + col * row_block}, %%ymm15, %%ymm@{row + col * row_block}\\n" + "and $0x1, %eax\\n" + "je 3f\\n" + // relu6 + "mov $0x40C00000, %eax\\n" + "vmovd %eax, %xmm14\\n" + "vpermps %ymm14, %ymm15, %ymm14\\n" + @for col in range(0, col_block >> 3): + @for row in range(0, row_block): + "vminps %%ymm@{row + col * row_block}, %%ymm14, %%ymm@{row + col * row_block}\\n" + "3:\\n" + @for col in range(0, min((col_block >> 3), 3)): + @for row in range(0, row_block): + @if col == 0: + "vmovups %%ymm@{row + col * row_block}, @{row * 32}(%[dst])\\n" + @else: + "vmovups %%ymm@{row + col * row_block}, @{row * 32}(%[dst], %[dst_stride], @{col})\\n" + @if col_block == 32: + @for row in range(0, row_block): + "vmovups %%ymm@{row + (col + 1) * row_block}, @{row * 32}(%[dst_4])\\n" + : + @list = ["[src] \"r\"(src)", "[src_stride] \"r\"(src_stride_t)", "[weight] \"r\"(weight)", "[deep] \"r\"(deep_t)", "[inc_flag] \"r\"(inc_flag)", "[act_flag] \"r\"(act_flag)", "[dst] \"r\"(dst)", "[dst_stride] \"r\"(dst_stride_t)"] + @if col_block == 32: + @list.append("[dst_4] \"r\"(dst_4)") + @print(" : " + ", ".join(list), file=OUT_STREAM) + @print(" : \"%rax\", " + ", ".join(["\"%ymm" + str(i) + "\"" for i in range(0, 16)]), file=OUT_STREAM) + ); +} diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/akg_kernel_register.h b/mindspore-lite/ops/kernel/cpu/nnacl/fill_parameter.h similarity index 69% rename from mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/akg_kernel_register.h rename to mindspore-lite/ops/kernel/cpu/nnacl/fill_parameter.h index 3d2e2bbc2fb3e3ce7b780593d470f5937c8d3e93..41c7d25c32fa4ee0207d8fb78446410f843ed97f 100644 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/akg_kernel_register.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fill_parameter.h @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_CXXAPI_AKG_KERNEL_REGISTER_H_ -#define MINDSPORE_CCSRC_CXXAPI_AKG_KERNEL_REGISTER_H_ -#include "include/api/visible.h" -namespace mindspore { -MS_API void RegAllOp(); -} // namespace mindspore -#endif // MINDSPORE_CCSRC_CXXAPI_AKG_KERNEL_REGISTER_H_ +#ifndef NNACL_FILL_PARAMETER_H_ +#define NNACL_FILL_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct FillParameter { + OpParameter op_parameter_; +} FillParameter; + +#endif // NNACL_FILL_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/flatten_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/flatten_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..d5d0a71900a7624ff859201ef91e0e9f2bcb3ffa --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/flatten_parameter.h @@ -0,0 +1,27 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FLATTEN_PARAMETER_H_ +#define NNACL_FLATTEN_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct FlattenParameter { + OpParameter op_parameter_; + int axis_; +} FlattenParameter; + +#endif // NNACL_FLATTEN_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/format_transpose_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/format_transpose_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..167b17fb4f36f2214712a93c4d4b77aed9709a9c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/format_transpose_parameter.h @@ -0,0 +1,29 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FORMAT_TRANSPOSE_PARAMETER_H_ +#define NNACL_FORMAT_TRANSPOSE_PARAMETER_H_ + +#include "nnacl/op_base.h" +#include "nnacl/infer/common_infer.h" +static const int FormatTransposeInput = 2; +typedef struct FormatTransposeParameter { + // Primitive parameter + OpParameter op_parameter_; + FormatC src_format_; + FormatC dst_format_; +} FormatTransposeParameter; + +#endif // NNACL_FORMAT_TRANSPOSE_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/activation_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/activation_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..8718430159e714e0f7b145ab9ac0b58b48ccb312 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/activation_fp16.c @@ -0,0 +1,319 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp16/activation_fp16.h" +#include +#include "nnacl/fp32/exp_fp32.h" +#include "nnacl/fp16/exp_fp16.h" +#include "nnacl/errorcode.h" + +int ReluFp16(const float16_t *src, float16_t *dst, int ele_num) { + int offset = 0; +#ifdef ENABLE_NEON + float16x8_t zero = vdupq_n_f16(0); + for (; offset <= ele_num - C8NUM; offset += C8NUM) { + float16x8_t src_value = vld1q_f16(src + offset); + float16x8_t rst_value = vmaxq_f16(src_value, zero); + vst1q_f16(dst + offset, rst_value); + } +#endif + for (; offset < ele_num; offset++) { + dst[offset] = src[offset] < 0.0f ? 0.0f : src[offset]; + } + return NNACL_OK; +} + +int Relu6Fp16(const float16_t *data, float16_t *dst, int ele_num) { + int offset = 0; +#ifdef ENABLE_NEON + float16x8_t zero_data = vdupq_n_f16(0); + float16x8_t six_data = vdupq_n_f16(6); + for (; offset <= ele_num - C8NUM; offset += C8NUM) { + float16x8_t relu6_data = vld1q_f16(data + offset); + relu6_data = vmaxq_f16(relu6_data, zero_data); + relu6_data = vminq_f16(relu6_data, six_data); + vst1q_f16(dst + offset, relu6_data); + } +#endif + for (; offset < ele_num; offset++) { + dst[offset] = data[offset] < 0.0f ? 0.0f : data[offset]; + dst[offset] = dst[offset] > 6.0f ? 6.0f : dst[offset]; + } + return NNACL_OK; +} + +int LReluFp16(const float16_t *src, float16_t *dst, int ele_num, float16_t alpha) { + int offset = 0; +#ifdef ENABLE_NEON + float16x8_t zero_data = vdupq_n_f16(0); + float16x8_t alpha_data = vdupq_n_f16(alpha); + for (; offset <= ele_num - C8NUM; offset += C8NUM) { + float16x8_t src_tmp = vld1q_f16(src + offset); + float16x8_t mul_tmp = vmulq_f16(src_tmp, alpha_data); + uint16x8_t mask = vcleq_f16(src_tmp, zero_data); + vst1q_f16(dst + offset, vbslq_f16(mask, mul_tmp, src_tmp)); + } +#endif + for (; offset < ele_num; ++offset) { + dst[offset] = src[offset] > (float16_t)0.0f ? src[offset] : (src[offset] * alpha); + } + return NNACL_OK; +} + +int SigmoidFp16(const float16_t *src, float16_t *dst, int ele_num) { + int i = 0; +#ifdef ENABLE_NEON + int count = (ele_num / C4NUM) * C4NUM; + for (; i < count; i += C4NUM) { + float32x4_t tmp; + simd_exp128(vnegq_f32(vcvt_f32_f16(vld1_f16(src + i))), (float *)&tmp); + vst1_f16(dst + i, vcvt_f16_f32(MS_DIVQ_F32(vdupq_n_f32(1.0f), vaddq_f32(vdupq_n_f32(1.0f), tmp)))); + } +#endif + for (; i < ele_num; ++i) { + float temp; + simd_exp32(-src[i], &temp); + dst[i] = (float16_t)1.0f / ((float16_t)1.0f + temp); + } + return NNACL_OK; +} + +float16_t TanhOptFp16(float16_t src) { + if (src > 5.0f) { + return 1.0f; + } else if (src < -5.0f) { + return -1.0f; + } else { + float square = src * src; + float a = (((square + 378.0f) * square + 17325.0f) * square + 135135.0f) * src; + float b = ((28.0f * square + 3150.0f) * square + 62370.0f) * square + 135135.0f; + return a / b; + } +} + +int TanhFp16(const float16_t *src, float16_t *dst, int ele_num) { + int i = 0; +#ifdef ENABLE_NEON + static float32x4_t paramv[] = {{378.0f, 378.0f, 378.0f, 378.0f}, + {17325.0f, 17325.0f, 17325.0f, 17325.0f}, + {135135.0f, 135135.0f, 135135.0f, 135135.0f}, + {28.0f, 28.0f, 28.0f, 28.0f}, + {3150.0f, 3150.0f, 3150.0f, 3150.0f}, + {62370.0f, 62370.0f, 62370.0f, 62370.0f}}; + float32x4_t neg_one = {-1.0f, -1.0f, -1.0f, -1.0f}; + float32x4_t pos_one = {1.0f, 1.0f, 1.0f, 1.0f}; + int count = (ele_num / C4NUM) * C4NUM; + for (; i < count; i += C4NUM) { + float32x4_t input = vcvt_f32_f16(vld1_f16(src + i)); + float32x4_t square = vmulq_f32(input, input); + float32x4_t a = vmulq_f32( + vaddq_f32(vmulq_f32(vaddq_f32(vmulq_f32(vaddq_f32(square, paramv[0]), square), paramv[1]), square), paramv[2]), + input); + float32x4_t b = vaddq_f32( + vmulq_f32(vaddq_f32(vmulq_f32(vaddq_f32(vmulq_f32(paramv[3], square), paramv[4]), square), paramv[5]), square), + paramv[2]); + vst1_f16(dst + i, vcvt_f16_f32(vminq_f32(vmaxq_f32(MS_DIVQ_F32(a, b), neg_one), pos_one))); + } +#endif + for (; i < ele_num; ++i) { + float input = src[i]; + float square = input * input; + float a = (((square + 378.0f) * square + 17325.0f) * square + 135135.0f) * input; + float b = ((28.0f * square + 3150.0f) * square + 62370.0f) * square + 135135.0f; + dst[i] = a / b; + dst[i] = MSMAX(dst[i], -1); + dst[i] = MSMIN(dst[i], 1); + } + return NNACL_OK; +} + +int HSwishFp16(const float16_t *src, float16_t *dst, int ele_num) { + int i = 0; +#ifdef ENABLE_NEON + const MS_FLOAT16X8 zero_data = vdupq_n_f16(0); + const MS_FLOAT16X8 three_data = vdupq_n_f16(3); + const MS_FLOAT16X8 six_data = vdupq_n_f16(6); + for (; i <= ele_num - C8NUM; i += C8NUM) { + MS_FLOAT16X8 in_data = MS_LDQ_F16(src + i); + MS_FLOAT16X8 tmp = MS_MAXQ_F16(in_data + three_data, zero_data); + tmp = MS_MINQ_F16(tmp, six_data); + MS_STQ_F16(dst + i, vmulq_f16(in_data, MS_DIVQ_F16(tmp, six_data))); + } +#endif + for (; i < ele_num; ++i) { + float16_t in = src[i]; + float16_t relu6 = MSMIN(MSMAX(in + 3.0f, 0.0f), 6.0f); + dst[i] = in * relu6 / (float16_t)6.0f; + } + return NNACL_OK; +} + +int SwishFp16(const float16_t *src, float16_t *dst, int ele_num) { + int i = 0; +#ifdef ENABLE_NEON + float32x4_t const_val = vdupq_n_f32(1.0f); + for (int num_max = ele_num - C16NUM; i <= num_max; i += C16NUM) { + float16x4x4_t ins = vld4_f16(src + i); + float32x4_t in0 = MS_CVT_F32_F16(ins.val[0]); + float32x4_t in1 = MS_CVT_F32_F16(ins.val[1]); + float32x4_t in2 = MS_CVT_F32_F16(ins.val[2]); + float32x4_t in3 = MS_CVT_F32_F16(ins.val[3]); + float32x4_t exp0 = simd_exp128_f32(vnegq_f32(in0)); + float32x4_t exp1 = simd_exp128_f32(vnegq_f32(in1)); + float32x4_t exp2 = simd_exp128_f32(vnegq_f32(in2)); + float32x4_t exp3 = simd_exp128_f32(vnegq_f32(in3)); + float32x4_t res0 = MS_DIVQ_F32(in0, vaddq_f32(const_val, exp0)); + float32x4_t res1 = MS_DIVQ_F32(in1, vaddq_f32(const_val, exp1)); + float32x4_t res2 = MS_DIVQ_F32(in2, vaddq_f32(const_val, exp2)); + float32x4_t res3 = MS_DIVQ_F32(in3, vaddq_f32(const_val, exp3)); + float16x4x4_t res = {MS_CVT_F16_F32(res0), MS_CVT_F16_F32(res1), MS_CVT_F16_F32(res2), MS_CVT_F16_F32(res3)}; + vst4_f16(dst + i, res); + } + for (int num_max = ele_num - C4NUM; i <= num_max; i += C4NUM) { + float32x4_t in = MS_CVT_F32_F16(vld1_f16(src + i)); + float16x4_t res = MS_CVT_F16_F32(MS_DIVQ_F32(in, vaddq_f32(const_val, simd_exp128_f32(vnegq_f32(in))))); + vst1_f16(dst + i, res); + } +#endif + for (; i < ele_num; ++i) { + float temp = simd_exp32_f32(-src[i]); + dst[i] = src[i] / (1.0f + temp); + } + return NNACL_OK; +} + +int HSigmoidFp16(const float16_t *src, float16_t *dst, int ele_num) { + int offset = 0; +#ifdef ENABLE_NEON + const MS_FLOAT16X8 zero_data = vdupq_n_f16(0); + const MS_FLOAT16X8 three_data = vdupq_n_f16(3); + const MS_FLOAT16X8 six_data = vdupq_n_f16(6); + for (; offset <= ele_num - C8NUM; offset += C8NUM) { + MS_FLOAT16X8 relu6_data = MS_LDQ_F16(src + offset) + three_data; + relu6_data = MS_MAXQ_F16(relu6_data, zero_data); + relu6_data = MS_MINQ_F16(relu6_data, six_data); + MS_STQ_F16(dst + offset, MS_DIVQ_F16(relu6_data, six_data)); + } +#endif + + for (; offset < ele_num; offset++) { + float16_t tmp = (src[offset] + 3.0 < 0.0) ? 0.0 : src[offset] + 3.0; + dst[offset] = ((tmp < 6.0) ? tmp : 6.0) / 6.0; + } + + return NNACL_OK; +} + +int HardTanhFp16(const float16_t *src, int length, float16_t *dst, float min_val, float max_val) { + if (max_val <= min_val) { + return NNACL_ERR; + } + int i = 0; + if (min_val == FLT_MIN) { + for (i = 0; i < length; ++i) { + dst[i] = src[i] > max_val ? max_val : src[i]; + } + } else if (max_val == FLT_MAX) { + for (i = 0; i < length; ++i) { + dst[i] = src[i] < min_val ? min_val : src[i]; + } + } else { + for (i = 0; i < length; ++i) { + dst[i] = src[i] < min_val ? min_val : (src[i] > max_val ? max_val : src[i]); + } + } + return NNACL_OK; +} + +int GeluFp16(const float16_t *src, int length, float16_t *dst, bool approximate) { + if (src == NULL || dst == NULL) { + return NNACL_ERR; + } + int i = 0; + if (approximate) { + // dst = 0.5 * x * (1 + tanh((2 / pi) ^ 0.5 * (x + 0.044715x^3))) +#ifdef ENABLE_NEON + for (int num_max = length - C16NUM; i <= num_max; i += C16NUM) { + float16x4x4_t ins = vld4_f16(src + i); + float32x4_t in0 = MS_CVT_F32_F16(ins.val[0]); + float32x4_t in1 = MS_CVT_F32_F16(ins.val[1]); + float32x4_t in2 = MS_CVT_F32_F16(ins.val[2]); + float32x4_t in3 = MS_CVT_F32_F16(ins.val[3]); + float32x4_t res0 = 0.5f * in0 * (1.0f + MS_TANHX4_F32((0.79788456080287f + 0.035677408136f * in0 * in0) * in0)); + float32x4_t res1 = 0.5f * in1 * (1.0f + MS_TANHX4_F32((0.79788456080287f + 0.035677408136f * in1 * in1) * in1)); + float32x4_t res2 = 0.5f * in2 * (1.0f + MS_TANHX4_F32((0.79788456080287f + 0.035677408136f * in2 * in2) * in2)); + float32x4_t res3 = 0.5f * in3 * (1.0f + MS_TANHX4_F32((0.79788456080287f + 0.035677408136f * in3 * in3) * in3)); + float16x4x4_t res = { + MS_CVT_F16_F32(res0), + MS_CVT_F16_F32(res1), + MS_CVT_F16_F32(res2), + MS_CVT_F16_F32(res3), + }; + vst4_f16(dst + i, res); + } + for (int num_max = length - C4NUM; i <= num_max; i += C4NUM) { + float32x4_t in = MS_CVT_F32_F16(vld1_f16(src + i)); + float32x4_t res = 0.5f * in * (1.0f + MS_TANHX4_F32((0.79788456080287f + 0.035677408136f * in * in) * in)); + vst1_f16(dst + i, MS_CVT_F16_F32(res)); + } +#endif + for (; i < length; i++) { + dst[i] = + 0.5f * src[i] * + (1.0f + TanhOptFp16(((float16_t)0.79788456080287f + (float16_t)0.035677408136f * src[i] * src[i]) * src[i])); + } + } else { +#ifdef ENABLE_NEON + int C8 = DOWN_ROUND(length, C8NUM); + for (; i < C8; i += C8NUM) { + float16x8_t in = vld1q_f16(src + i); + const float16x8_t res = 0.5f * in * (1.0f + MS_ERFX8_F16(in / (float16_t)1.4142135623730951f)); + vst1q_f16(dst + i, res); + } +#endif + for (; i < length; i++) { + dst[i] = 0.5f * src[i] * (1.0f + erff(src[i] / 1.4142135623730951f)); + } + } + return NNACL_OK; +} + +int SoftplusFp16(const float16_t *src, int length, float16_t *dst) { + int i = 0; + for (; i < length; ++i) { + single_exp_fp16(src[i], dst + i); + dst[i] = log1p(dst[i]); + } + return NNACL_OK; +} + +int EluFp16(const float16_t *src, int length, float16_t *dst, float16_t alpha) { + int i = 0; +#ifdef ENABLE_NEON + float16x8_t one = MS_MOVQ_F16(1.0f); + for (; i <= length - 8; i += 8) { + float16x8_t src_tmp = MS_LDQ_F16(src + i); + float16x8_t exp_tmp = VexpFp16(src_tmp); // exp(x) + exp_tmp = MS_SUBQ_F16(exp_tmp, one); // exp(x) - 1 + float16x8_t elu_tmp = MS_MULQ_N_F16(exp_tmp, alpha); + uint16x8_t mask = vcleq_f16(src_tmp, MS_MOVQ_F16(0.0f)); + MS_STQ_F16(dst + i, vbslq_f16(mask, elu_tmp, src_tmp)); + } +#endif + for (; i < length; ++i) { + dst[i] = src[i] > 0 ? src[i] : (expm1(src[i]) * alpha); + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/activation_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/activation_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..3fe5c6ebc2e3a1533f2ee1ad30ef13d588e2a75e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/activation_fp16.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_ACTIVATION_FP16_H_ +#define NNACL_FP16_ACTIVATION_FP16_H_ + +#include +#include "nnacl/op_base.h" +#include "nnacl/intrinsics/ms_simd_instructions_fp16.h" +#include "nnacl/int8/fixed_point.h" +#include "nnacl/activation_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int ReluFp16(const float16_t *src, float16_t *dst, int ele_num); +int Relu6Fp16(const float16_t *data, float16_t *dst, int ele_num); +int LReluFp16(const float16_t *src, float16_t *dst, int ele_num, float16_t alpha); +int SigmoidFp16(const float16_t *src, float16_t *dst, int ele_num); +int TanhFp16(const float16_t *src, float16_t *dst, int ele_num); +int HSwishFp16(const float16_t *src, float16_t *dst, int ele_num); +int HSigmoidFp16(const float16_t *src, float16_t *dst, int ele_num); +int SwishFp16(const float16_t *src, float16_t *dst, int ele_num); +int HardTanhFp16(const float16_t *src, int length, float16_t *dst, float min_val, float max_val); +int GeluFp16(const float16_t *src, int length, float16_t *dst, bool approximate); +int SoftplusFp16(const float16_t *src, int length, float16_t *dst); +int EluFp16(const float16_t *src, int length, float16_t *dst, float16_t alpha); +#ifdef __cplusplus +} +#endif +#endif // NNACL_FP16_ACTIVATION_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/arg_min_max_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/arg_min_max_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..d05bfa723a035dd53d83e8f6e85bc6c193f9ab0e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/arg_min_max_fp16.c @@ -0,0 +1,273 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp16/arg_min_max_fp16.h" + +int ArgCompareAscFp16(const void *a, const void *b) { + float16_t a_value = ((ArgElement *)a)->data_.f16_data_; + float16_t b_value = ((ArgElement *)b)->data_.f16_data_; + if (b_value > a_value) { + return -1; + } + if (b_value < a_value) { + return 1; + } + + return 0; +} + +int ArgCompareDescFp16(const void *a, const void *b) { + float16_t b_value = ((ArgElement *)b)->data_.f16_data_; + float16_t a_value = ((ArgElement *)a)->data_.f16_data_; + if (b_value > a_value) { + return 1; + } + if (b_value < a_value) { + return -1; + } + + return 0; +} + +void ArgMaxTopK1Fp16(const float16_t *input, void *output, float16_t *output_value, const ArgMinMaxComputeParam *param, + int pre_axis_count, int axis_count, int after_axis_count) { + bool out_value = param->out_value_; + float16_t *outputfp16 = (float16_t *)output; + int *outputint = (int *)output; + for (int i = 0; i < pre_axis_count; ++i) { + size_t output_offset = i * after_axis_count; + size_t input_offset = output_offset * axis_count; + for (int j = 0; j < after_axis_count; ++j) { + float16_t value = -FLT_MAX; + int index = 0; + for (int k = 0; k < axis_count; ++k) { + float16_t value_tmp = input[input_offset + k * after_axis_count + j]; + if (value_tmp > value) { + value = value_tmp; + index = k; + } + } + if (out_value) { + outputfp16[output_offset + j] = value; + } else { + outputint[output_offset + j] = index; + } + if (output_value != NULL) { + output_value[output_offset + j] = value; + } + } + } +} + +void ArgMinTopK1Fp16(const float16_t *input, void *output, float16_t *output_value, const ArgMinMaxComputeParam *param, + int pre_axis_count, int axis_count, int after_axis_count) { + bool out_value = param->out_value_; + float16_t *outputfp16 = (float16_t *)output; + int *outputint = (int *)output; + for (int i = 0; i < pre_axis_count; ++i) { + size_t output_offset = i * after_axis_count; + size_t input_offset = output_offset * axis_count; + for (int j = 0; j < after_axis_count; ++j) { + float16_t value = FLT_MAX; + int index = 0; + for (int k = 0; k < axis_count; ++k) { + float16_t value_tmp = input[input_offset + k * after_axis_count + j]; + if (value_tmp < value) { + value = value_tmp; + index = k; + } + } + if (out_value) { + outputfp16[output_offset + j] = value; + } else { + outputint[output_offset + j] = index; + } + if (output_value != NULL) { + output_value[output_offset + j] = value; + } + } + } +} + +void ArgMinMaxDim0Fp16(const float16_t *input, void *output, float16_t *output_value, const int *in_shape, + const ArgMinMaxComputeParam *param, COMPARE_FUNCTION compare_func) { + float16_t *outputfp16 = (float16_t *)output; + int *outputint = (int *)output; + for (int32_t i = 0; i < param->in_strides_[0]; ++i) { + for (int j = 0; j < in_shape[0]; ++j) { + size_t offset = param->in_strides_[0] * j + i; + param->arg_elements_[j].index_ = j; + param->arg_elements_[j].data_.f16_data_ = input[offset]; + } + qsort(param->arg_elements_, in_shape[0], sizeof(ArgElement), *compare_func); + for (int j = 0; j < param->topk_; ++j) { + size_t out_offset = j * param->out_strides_[0] + i; + if (param->out_value_) { + outputfp16[out_offset] = param->arg_elements_[j].data_.f16_data_; + } else { + outputint[out_offset] = param->arg_elements_[j].index_; + } + if (output_value != NULL) { + output_value[out_offset] = param->arg_elements_[j].data_.f16_data_; + } + } + } + return; +} + +void ArgMinMaxDim1Fp16(const float16_t *input, void *output, float16_t *output_value, const int *in_shape, + const ArgMinMaxComputeParam *param, COMPARE_FUNCTION compare_func) { + int in_shape1 = in_shape[1]; + float16_t *outputfp16 = (float16_t *)output; + int *outputint = (int *)output; + for (int i = 0; i < in_shape[0]; ++i) { + size_t in_dim0_offset = i * param->in_strides_[0]; + size_t out_dim0_offset = i * param->out_strides_[0]; + for (int j = 0; j < param->in_strides_[1]; ++j) { + for (int k = 0; k < in_shape1; ++k) { + size_t offset = param->in_strides_[1] * k + in_dim0_offset + j; + param->arg_elements_[k].index_ = k; + param->arg_elements_[k].data_.f16_data_ = input[offset]; + } + qsort(param->arg_elements_, in_shape1, sizeof(ArgElement), *compare_func); + for (int k = 0; k < param->topk_; ++k) { + size_t out_offset = out_dim0_offset + j + k * param->out_strides_[1]; + if (param->out_value_) { + outputfp16[out_offset] = param->arg_elements_[k].data_.f16_data_; + } else { + outputint[out_offset] = param->arg_elements_[k].index_; + } + if (output_value != NULL) { + output_value[out_offset] = param->arg_elements_[k].data_.f16_data_; + } + } + } + } + return; +} + +void ArgMinMaxDim2Fp16(const float16_t *input, float16_t *output, float16_t *output_value, const int *in_shape, + const ArgMinMaxComputeParam *param, COMPARE_FUNCTION compare_func) { + int in_shape1 = in_shape[1]; + int in_shape2 = in_shape[2]; + float *outputfp16 = (float *)output; + int *outputint = (int *)output; + for (int i = 0; i < in_shape[0]; ++i) { + size_t in_dim0_offset = i * param->in_strides_[0]; + size_t out_dim0_offset = i * param->out_strides_[0]; + for (int j = 0; j < in_shape1; ++j) { + size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset; + size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset; + for (int k = 0; k < param->in_strides_[2]; ++k) { + for (int l = 0; l < in_shape2; ++l) { + size_t offset = param->in_strides_[2] * l + k + in_dim1_offset; + param->arg_elements_[l].index_ = l; + param->arg_elements_[l].data_.f16_data_ = input[offset]; + } + qsort(param->arg_elements_, in_shape2, sizeof(ArgElement), *compare_func); + for (int l = 0; l < param->topk_; ++l) { + size_t out_offset = out_dim1_offset + k + l * param->out_strides_[2]; + if (param->out_value_) { + outputfp16[out_offset] = param->arg_elements_[l].data_.f16_data_; + } else { + outputint[out_offset] = param->arg_elements_[l].index_; + } + if (output_value != NULL) { + output_value[out_offset] = param->arg_elements_[l].data_.f16_data_; + } + } + } + } + } +} + +void ArgMinMaxDim3Fp16(const float16_t *input, float16_t *output, float16_t *output_value, const int *in_shape, + const ArgMinMaxComputeParam *param, COMPARE_FUNCTION compare_func) { + int in_shape1 = in_shape[1]; + int in_shape2 = in_shape[2]; + int in_shape3 = in_shape[3]; + float *outputfp16 = (float *)output; + int *outputint = (int *)output; + for (int i = 0; i < in_shape[0]; ++i) { + size_t in_dim0_offset = i * param->in_strides_[0]; + size_t out_dim0_offset = i * param->out_strides_[0]; + for (int j = 0; j < in_shape1; ++j) { + size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset; + size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset; + for (int k = 0; k < in_shape2; ++k) { + size_t in_dim2_offset = k * param->in_strides_[2] + in_dim1_offset; + size_t out_dim2_offset = k * param->out_strides_[2] + out_dim1_offset; + for (int l = 0; l < in_shape3; ++l) { + size_t offset = l + in_dim2_offset; + param->arg_elements_[l].index_ = l; + param->arg_elements_[l].data_.f16_data_ = input[offset]; + } + qsort(param->arg_elements_, in_shape3, sizeof(ArgElement), *compare_func); + for (int l = 0; l < param->topk_; ++l) { + size_t out_offset = out_dim2_offset + l; + if (param->out_value_) { + outputfp16[out_offset] = param->arg_elements_[l].data_.f16_data_; + } else { + outputint[out_offset] = param->arg_elements_[l].index_; + } + if (output_value != NULL) { + output_value[out_offset] = param->arg_elements_[l].data_.f16_data_; + } + } + } + } + } +} + +void ArgMinMaxFp16(const float16_t *input, void *output, float16_t *output_value, const int *in_shape, + const ArgMinMaxComputeParam *param) { + if (param->topk_ == 1) { + int pre_axis_count = 1; + int axis_count = 1; + int after_axis_count = 1; + ComputeAxisDims(in_shape, param->dims_size_, param->axis_, &pre_axis_count, &axis_count, &after_axis_count); + + if (param->get_max_) { + ArgMaxTopK1Fp16(input, output, output_value, param, pre_axis_count, axis_count, after_axis_count); + } else { + ArgMinTopK1Fp16(input, output, output_value, param, pre_axis_count, axis_count, after_axis_count); + } + return; + } + + COMPARE_FUNCTION compare_function = NULL; + if (param->get_max_) { + compare_function = ArgCompareDescFp16; + } else { + compare_function = ArgCompareAscFp16; + } + + switch (param->axis_) { + case 0: + ArgMinMaxDim0Fp16(input, output, output_value, in_shape, param, compare_function); + break; + case 1: + ArgMinMaxDim1Fp16(input, output, output_value, in_shape, param, compare_function); + break; + case 2: + ArgMinMaxDim2Fp16(input, output, output_value, in_shape, param, compare_function); + break; + case 3: + ArgMinMaxDim3Fp16(input, output, output_value, in_shape, param, compare_function); + break; + } + return; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/arg_min_max_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/arg_min_max_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..56bb069d6d6a4625d7e5a6e4f45dca53398c486f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/arg_min_max_fp16.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_ARG_MIN_MAX_FP16_H_ +#define NNACL_FP16_ARG_MIN_MAX_FP16_H_ + +#include +#include "nnacl/arg_min_max_parameter.h" +#include "nnacl/nnacl_common.h" +#include "nnacl/kernel/arg_min_max.h" + +#ifdef __cplusplus +extern "C" { +#endif +void ArgMinMaxFp16(const float16_t *input, void *output, float16_t *output_value, const int *in_shape, + const ArgMinMaxComputeParam *param); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_ARG_MIN_MAX_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/arithmetic_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/arithmetic_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..c3d3ec10a7d662324b272da6ca829159cf81880c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/arithmetic_fp16.c @@ -0,0 +1,1314 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp16/arithmetic_fp16.h" +#include +#include "nnacl/common_func.h" +#include "nnacl/nnacl_utils.h" + +int BroadcastAddFp16(const float16_t *in0, const float16_t *in1, float16_t *tile_in0, float16_t *tile_in1, + float16_t *out, int size, ArithmeticParameter *param) { + TileDimensionsFp16(in0, in1, tile_in0, tile_in1, param); + return ElementAddFp16(tile_in0, tile_in1, out, size); +} + +void TileOneDimensionFp16(const void *input, void *output, int dim, size_t ndim, const int *inShape, + const int *inStrides, const int *outStrides, const int *multiple) { + const float16_t *inData = (const float16_t *)input; + float16_t *outData = (float16_t *)output; + + int srcDimSize = inShape[dim]; + if (dim == ndim - 1) { + for (int i = 0; i < multiple[dim]; i++) { + memcpy(outData, inData, srcDimSize * sizeof(float16_t)); + outData += srcDimSize; + } + return; + } + for (size_t i = 0; i < srcDimSize; i++) { + for (size_t j = 0; j < multiple[dim]; j++) { + TileOneDimensionFp16(inData + inStrides[dim] * i, outData + outStrides[dim] * (i + j * srcDimSize), dim + 1, ndim, + inShape, inStrides, outStrides, multiple); + } + } +} + +void TileDimensionsFp16(const float16_t *data0, const float16_t *data1, float16_t *tile_data0, float16_t *tile_data1, + ArithmeticParameter *param) { + CalcMultiplesAndStrides(param); + TileOneDimensionFp16(data0, tile_data0, 0, param->ndim_, param->in_shape0_, param->in_strides0_, param->out_strides_, + param->multiples0_); + TileOneDimensionFp16(data1, tile_data1, 0, param->ndim_, param->in_shape1_, param->in_strides1_, param->out_strides_, + param->multiples1_); +} + +int ElementMulFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vmulq_f16(vin0, vin1); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] * input1[index]; + } + return NNACL_OK; +} + +int ElementOptMulFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vmulq_f16(vin0_opt, vin1); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[0] * input1[index]; + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vmulq_f16(vin0, vin1_opt); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] * input1[0]; + } + } + return NNACL_OK; +} + +int ElementMulReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { +#ifdef ENABLE_NEON + float16x8_t zeros = vdupq_n_f16(0.0); +#endif + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vmulq_f16(vin0, vin1); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + float16_t res = input0[index] * input1[index]; + output[index] = res > 0 ? res : 0; + } + return NNACL_OK; +} + +int ElementOptMulReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); + float16x8_t zeros = vdupq_n_f16(0.0); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vmulq_f16(vin0_opt, vin1); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + float16_t res = input0[0] * input1[index]; + output[index] = res > 0 ? res : 0; + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vmulq_f16(vin0, vin1_opt); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + float16_t res = input0[index] * input1[0]; + output[index] = res > 0 ? res : 0; + } + } + return NNACL_OK; +} + +int ElementMulRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + float16x8_t zeros = vdupq_n_f16(0.0); + float16x8_t bounds = vdupq_n_f16(6.0); + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vmulq_f16(vin0, vin1); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMIN(MSMAX(input0[index] * input1[index], 0), 6); + } + return NNACL_OK; +} + +int ElementOptMulRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); + float16x8_t zeros = vdupq_n_f16(0.0); + float16x8_t bounds = vdupq_n_f16(6.0); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vmulq_f16(vin0_opt, vin1); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMIN(MSMAX(input0[0] * input1[index], 0), 6); + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vmulq_f16(vin0, vin1_opt); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMIN(MSMAX(input0[index] * input1[0], 0), 6); + } + } + return NNACL_OK; +} + +int ElementAddFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vaddq_f16(vin0, vin1); + vst1q_f16(output + index, vout); + } + for (; index <= element_size - 4; index += C4NUM) { + float16x4_t vin0 = vld1_f16(input0 + index); + float16x4_t vin1 = vld1_f16(input1 + index); + float16x4_t vout = vadd_f16(vin0, vin1); + vst1_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] + input1[index]; + } + return NNACL_OK; +} + +int ElementOptAddFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vaddq_f16(vin0_opt, vin1); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[0] + input1[index]; + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vaddq_f16(vin0, vin1_opt); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] + input1[0]; + } + } + return NNACL_OK; +} + +int ElementAddReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + float16x8_t zeros = vdupq_n_f16(0.0); + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vaddq_f16(vin0, vin1); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output + index, vout); + } + float16x4_t zeros1 = vdup_n_f16(0.0f); + for (; index <= element_size - 4; index += C4NUM) { + float16x4_t vin0 = vld1_f16(input0 + index); + float16x4_t vin1 = vld1_f16(input1 + index); + float16x4_t vout = vadd_f16(vin0, vin1); + vout = vmax_f16(vout, zeros1); + vst1_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + float16_t res = input0[index] + input1[index]; + output[index] = res > 0 ? res : 0; + } + return NNACL_OK; +} + +int ElementOptAddReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); + float16x8_t zeros = vdupq_n_f16(0.0); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vaddq_f16(vin0_opt, vin1); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + float16_t res = input0[0] + input1[index]; + output[index] = res > 0 ? res : 0; + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vaddq_f16(vin0, vin1_opt); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + float16_t res = input0[index] + input1[0]; + output[index] = res > 0 ? res : 0; + } + } + return NNACL_OK; +} + +int ElementAddRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + float16x8_t zeros = vdupq_n_f16(0.0); + float16x8_t bounds = vdupq_n_f16(6.0); + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vaddq_f16(vin0, vin1); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output + index, vout); + } + float16x4_t zeros1 = vdup_n_f16(0.0); + float16x4_t bounds1 = vdup_n_f16(6.0); + for (; index <= element_size - 4; index += C4NUM) { + float16x4_t vin0 = vld1_f16(input0 + index); + float16x4_t vin1 = vld1_f16(input1 + index); + float16x4_t vout = vadd_f16(vin0, vin1); + vout = vmin_f16(vmax_f16(vout, zeros1), bounds1); + vst1_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMIN(MSMAX(input0[index] + input1[index], 0), 6); + } + return NNACL_OK; +} + +int ElementOptAddRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); + float16x8_t zeros = vdupq_n_f16(0.0); + float16x8_t bounds = vdupq_n_f16(6.0); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vaddq_f16(vin0_opt, vin1); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMIN(MSMAX(input0[0] + input1[index], 0), 6); + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vaddq_f16(vin0, vin1_opt); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMIN(MSMAX(input0[index] + input1[0], 0), 6); + } + } + return NNACL_OK; +} + +int ElementSubFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vsubq_f16(vin0, vin1); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] - input1[index]; + } + return NNACL_OK; +} + +int ElementOptSubFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vsubq_f16(vin0_opt, vin1); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[0] - input1[index]; + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vsubq_f16(vin0, vin1_opt); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] - input1[0]; + } + } + return NNACL_OK; +} + +int ElementSubReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + float16x8_t zeros = vdupq_n_f16(0.0); + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vsubq_f16(vin0, vin1); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + float16_t res = input0[index] - input1[index]; + output[index] = res > 0 ? res : 0; + } + return NNACL_OK; +} + +int ElementOptSubReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); + float16x8_t zeros = vdupq_n_f16(0.0); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vsubq_f16(vin0_opt, vin1); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + float16_t res = input0[0] - input1[index]; + output[index] = res > 0 ? res : 0; + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vsubq_f16(vin0, vin1_opt); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + float16_t res = input0[index] - input1[0]; + output[index] = res > 0 ? res : 0; + } + } + return NNACL_OK; +} + +int ElementSubRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + float16x8_t zeros = vdupq_n_f16(0.0); + float16x8_t bounds = vdupq_n_f16(6.0); + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vsubq_f16(vin0, vin1); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMIN(MSMAX(input0[index] - input1[index], 0), 6); + } + return NNACL_OK; +} + +int ElementOptSubRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); + float16x8_t zeros = vdupq_n_f16(0.0); + float16x8_t bounds = vdupq_n_f16(6.0); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vsubq_f16(vin0_opt, vin1); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMIN(MSMAX(input0[0] - input1[index], 0), 6); + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vsubq_f16(vin0, vin1_opt); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMIN(MSMAX(input0[index] - input1[0], 0), 6); + } + } + return NNACL_OK; +} + +int ElementDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = MS_DIVQ_F16(vin0, vin1); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] / input1[index]; + } + return NNACL_OK; +} + +int ElementOptDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = MS_DIVQ_F16(vin0_opt, vin1); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[0] / input1[index]; + } + } else { + if (input1[0] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = MS_DIVQ_F16(vin0, vin1_opt); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] / input1[0]; + } + } + return NNACL_OK; +} + +int ElementDivReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + float16x8_t zeros = vdupq_n_f16(0.0); + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = MS_DIVQ_F16(vin0, vin1); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + if (input1[index] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + float16_t res = input0[index] / input1[index]; + output[index] = res > 0 ? res : 0; + } + return NNACL_OK; +} + +int ElementOptDivReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); + float16x8_t zeros = vdupq_n_f16(0.0); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vmaxq_f16(MS_DIVQ_F16(vin0_opt, vin1), zeros); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + if (input1[index] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + output[index] = MSMAX(input0[0] / input1[index], 0); + } + } else { + if (input1[0] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vmaxq_f16(MS_DIVQ_F16(vin0, vin1_opt), zeros); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMAX(input0[index] / input1[0], 0); + } + } + return NNACL_OK; +} + +int ElementDivRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + float16x8_t zeros = vdupq_n_f16(0.0); + float16x8_t bounds = vdupq_n_f16(6.0); + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = MS_DIVQ_F16(vin0, vin1); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + if (input1[index] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + output[index] = MSMIN(MSMAX(input0[index] / input1[index], 0), 6); + } + return NNACL_OK; +} + +int ElementOptDivRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); + float16x8_t zeros = vdupq_n_f16(0.0); + float16x8_t bounds = vdupq_n_f16(6.0); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vminq_f16(vmaxq_f16(MS_DIVQ_F16(vin0_opt, vin1), zeros), bounds); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + if (input1[index] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + output[index] = MSMIN(MSMAX(input0[0] / input1[index], 0), 6); + } + } else { + if (input1[0] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vminq_f16(vmaxq_f16(MS_DIVQ_F16(vin0, vin1_opt), zeros), bounds); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMIN(MSMAX(input0[index] / input1[0], 0), 6); + } + } + return NNACL_OK; +} + +int ElementFloorModFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + for (int i = 0; i < element_size; ++i) { + if (input1[i] == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + output[i] = input0[i] - floorf(input0[i] / input1[i]) * input1[i]; + } + return NNACL_OK; +} + +int ElementOptFloorModFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { + if (!first_scalar) { + for (int i = 0; i < element_size; ++i) { + output[i] = input0[i] - floorf(input0[i] / input1[0]) * input1[0]; + } + } else { + for (int i = 0; i < element_size; ++i) { + output[i] = input0[i] - floorf(input0[i] / input1[i]) * input1[i]; + } + } + return NNACL_OK; +} + +int ElementFloorDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + for (int i = 0; i < element_size; ++i) { + output[i] = floorf(input0[i] / input1[i]); + } + return NNACL_OK; +} +int ElementOptFloorDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { + if (!first_scalar) { + for (int i = 0; i < element_size; ++i) { + output[i] = floorf(input0[i] / input1[0]); + } + } else { + for (int i = 0; i < element_size; ++i) { + output[i] = floorf(input0[i] / input1[i]); + } + } + return NNACL_OK; +} + +int ElementLogicalAndFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + float16x8_t vtrue = vdupq_n_f16(1); + float16x8_t vfalse = vdupq_n_f16(0); + uint16x8_t mask = vmovq_n_u16(((uint16_t)(1u << 15) - 1)); + uint16x8_t zeros = vdupq_n_u16(0); + for (; index <= element_size - 8; index += C8NUM) { + uint16x8_t vin0 = vandq_u16(vreinterpretq_u16_f16(vld1q_f16(input0 + index)), mask); + uint16x8_t vin1 = vandq_u16(vreinterpretq_u16_f16(vld1q_f16(input1 + index)), mask); + float16x8_t vout = vbslq_f16(vceqq_u16(vandq_u16(vin0, vin1), zeros), vfalse, vtrue); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = (float16_t)((bool)(input0[index]) & (bool)(input1[index])); + } + return NNACL_OK; +} + +int ElementOptLogicalAndFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); + float16x8_t vtrue = vdupq_n_f16(1); + float16x8_t vfalse = vdupq_n_f16(0); + uint16x8_t mask = vmovq_n_u16(((uint16_t)(1u << 15) - 1)); + uint16x8_t zeros = vdupq_n_u16(0); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1_ = vld1q_f16(input1 + index); + uint16x8_t vin0 = vandq_u16(vreinterpretq_u16_f16(vin0_opt), mask); + uint16x8_t vin1 = vandq_u16(vreinterpretq_u16_f16(vin1_), mask); + float16x8_t vout = vbslq_f16(vceqq_u16(vandq_u16(vin0, vin1), zeros), vfalse, vtrue); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = (float16_t)((bool)(input0[0]) & (bool)(input1[index])); + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0_ = vld1q_f16(input0 + index); + uint16x8_t vin0 = vandq_u16(vreinterpretq_u16_f16(vin0_), mask); + uint16x8_t vin1 = vandq_u16(vreinterpretq_u16_f16(vin1_opt), mask); + float16x8_t vout = vbslq_f16(vceqq_u16(vandq_u16(vin0, vin1), zeros), vfalse, vtrue); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = (float16_t)((bool)(input0[index]) & (bool)(input1[0])); + } + } + return NNACL_OK; +} + +int ElementLogicalOrFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + float16x8_t vtrue = vdupq_n_f16(1); + float16x8_t vfalse = vdupq_n_f16(0); + uint16x8_t mask = vmovq_n_u16(((uint16_t)(1u << 15) - 1)); + uint16x8_t zeros = vdupq_n_u16(0); + for (; index <= element_size - 8; index += C8NUM) { + uint16x8_t vin0 = vandq_u16(vreinterpretq_u16_f16(vld1q_f16(input0 + index)), mask); + uint16x8_t vin1 = vandq_u16(vreinterpretq_u16_f16(vld1q_f16(input1 + index)), mask); + float16x8_t vout = vbslq_f16(vceqq_u16(vorrq_u16(vin0, vin1), zeros), vfalse, vtrue); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = (float16_t)((bool)(input0[index]) | (bool)(input1[index])); + } + return NNACL_OK; +} + +int ElementOptLogicalOrFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); + float16x8_t vtrue = vdupq_n_f16(1); + float16x8_t vfalse = vdupq_n_f16(0); + uint16x8_t mask = vmovq_n_u16(((uint16_t)(1u << 15) - 1)); + uint16x8_t zeros = vdupq_n_u16(0); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1_ = vld1q_f16(input1 + index); + uint16x8_t vin0 = vandq_u16(vreinterpretq_u16_f16(vin0_opt), mask); + uint16x8_t vin1 = vandq_u16(vreinterpretq_u16_f16(vin1_), mask); + float16x8_t vout = vbslq_f16(vceqq_u16(vorrq_u16(vin0, vin1), zeros), vfalse, vtrue); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = (float16_t)((bool)(input0[0]) | (bool)(input1[index])); + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0_ = vld1q_f16(input0 + index); + uint16x8_t vin0 = vandq_u16(vreinterpretq_u16_f16(vin0_), mask); + uint16x8_t vin1 = vandq_u16(vreinterpretq_u16_f16(vin1_opt), mask); + float16x8_t vout = vbslq_f16(vceqq_u16(vorrq_u16(vin0, vin1), zeros), vfalse, vtrue); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = (float16_t)((bool)(input0[index]) | (bool)(input1[0])); + } + } + return NNACL_OK; +} + +int ElementSquaredDifferenceFp16(const float16_t *input0, const float16_t *input1, float16_t *output, + int element_size) { + ElementSubFp16(input0, input1, output, element_size); + return ElementMulFp16(output, output, output, element_size); +} + +int ElementOptSquaredDifferenceFp16(const float16_t *input0, const float16_t *input1, float16_t *output, + int element_size, bool first_scalar) { + ElementOptSubFp16(input0, input1, output, element_size, first_scalar); + return ElementMulFp16(output, output, output, element_size); +} + +int ElementMaximumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vmaxq_f16(vin0, vin1); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMAX(input0[index], input1[index]); + } + return NNACL_OK; +} + +int ElementOptMaximumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vmaxq_f16(vin0_opt, vin1); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMAX(input0[0], input1[index]); + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vmaxq_f16(vin0, vin1_opt); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMAX(input0[index], input1[0]); + } + } + return NNACL_OK; +} + +int ElementMinimumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vminq_f16(vin0, vin1); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMIN(input0[index], input1[index]); + } + return NNACL_OK; +} + +int ElementOptMinimumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + float16x8_t vout = vminq_f16(vin0_opt, vin1); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMIN(input0[0], input1[index]); + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vminq_f16(vin0, vin1_opt); + vst1q_f16(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = MSMIN(input0[index], input1[0]); + } + } + return NNACL_OK; +} + +int ElementNotEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + uint8x8_t vout = vmovn_u16(vceqq_f16(vin0, vin1)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] != input1[index]; + } + return NNACL_OK; +} + +int ElementOptNotEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + uint8x8_t vout = vmovn_u16(vceqq_f16(vin0_opt, vin1)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[0] != input1[index]; + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + uint8x8_t vout = vmovn_u16(vceqq_f16(vin0, vin1_opt)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] != input1[0]; + } + } + return NNACL_OK; +} + +int ElementEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + uint8x8_t vout = vmovn_u16(vceqq_f16(vin0, vin1)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] == input1[index]; + } + return NNACL_OK; +} + +int ElementOptEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + uint8x8_t vout = vmovn_u16(vceqq_f16(vin0_opt, vin1)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[0] == input1[index]; + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + uint8x8_t vout = vmovn_u16(vceqq_f16(vin0, vin1_opt)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] == input1[0]; + } + } + return NNACL_OK; +} + +int ElementLessFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + uint8x8_t vout = vmovn_u16(vcltq_f16(vin0, vin1)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] < input1[index]; + } + return NNACL_OK; +} + +int ElementOptLessFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + uint8x8_t vout = vmovn_u16(vcltq_f16(vin0_opt, vin1)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[0] < input1[index]; + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + uint8x8_t vout = vmovn_u16(vcltq_f16(vin0, vin1_opt)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] < input1[0]; + } + } + return NNACL_OK; +} + +int ElementLessEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + uint8x8_t vout = vmovn_u16(vcleq_f16(vin0, vin1)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] <= input1[index]; + } + return NNACL_OK; +} + +int ElementOptLessEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + uint8x8_t vout = vmovn_u16(vcleq_f16(vin0_opt, vin1)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[0] <= input1[index]; + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + uint8x8_t vout = vmovn_u16(vcleq_f16(vin0, vin1_opt)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] <= input1[0]; + } + } + return NNACL_OK; +} + +int ElementGreaterFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + uint8x8_t vout = vmovn_u16(vcgtq_f16(vin0, vin1)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] > input1[index]; + } + return NNACL_OK; +} + +int ElementOptGreaterFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + uint8x8_t vout = vmovn_u16(vcgtq_f16(vin0_opt, vin1)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[0] > input1[index]; + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + uint8x8_t vout = vmovn_u16(vcgtq_f16(vin0, vin1_opt)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] > input1[0]; + } + } + return NNACL_OK; +} + +int ElementGreaterEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vin1 = vld1q_f16(input1 + index); + uint8x8_t vout = vmovn_u16(vcgeq_f16(vin0, vin1)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] >= input1[index]; + } + return NNACL_OK; +} + +int ElementOptGreaterEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, + bool first_scalar) { +#ifdef ENABLE_NEON + float16x8_t vin0_opt = vdupq_n_f16(input0[0]); + float16x8_t vin1_opt = vdupq_n_f16(input1[0]); +#endif + int index = 0; + if (first_scalar) { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin1 = vld1q_f16(input1 + index); + uint8x8_t vout = vmovn_u16(vcgeq_f16(vin0_opt, vin1)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[0] >= input1[index]; + } + } else { +#ifdef ENABLE_NEON + for (; index <= element_size - 8; index += C8NUM) { + float16x8_t vin0 = vld1q_f16(input0 + index); + uint8x8_t vout = vmovn_u16(vcgeq_f16(vin0, vin1_opt)); + vst1_u8(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] = input0[index] >= input1[0]; + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/arithmetic_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/arithmetic_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..3ad702cb2bafe03670c8cff1cba7fe80f7bbb326 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/arithmetic_fp16.h @@ -0,0 +1,124 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_ARITHMETIC_FP16_H_ +#define NNACL_FP16_ARITHMETIC_FP16_H_ + +#include "nnacl/op_base.h" +#include "nnacl/intrinsics/ms_simd_instructions_fp16.h" +#include "nnacl/base/arithmetic_base.h" +#include "nnacl/errorcode.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void TileOneDimensionFp16(const void *input, void *output, int dim, size_t ndim, const int *inShape, + const int *inStrides, const int *outStrides, const int *multiple); +void TileDimensionsFp16(const float16_t *data0, const float16_t *data1, float16_t *tile_data0, float16_t *tile_data1, + ArithmeticParameter *param); + +int ElementOptMulFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptMulReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptMulRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptAddFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptAddReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptAddRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptSubFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptSubReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptSubRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptDivReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptDivRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptFloorModFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptFloorDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptLogicalAndFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptLogicalOrFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptSquaredDifferenceFp16(const float16_t *input0, const float16_t *input1, float16_t *output, + int element_size, bool first_scalar); +int ElementOptMaximumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptMinimumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size, + bool first_scalar); +int ElementOptNotEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, + bool first_scalar); +int ElementOptEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, + bool first_scalar); +int ElementOptLessFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, + bool first_scalar); +int ElementOptLessEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, + bool first_scalar); +int ElementOptGreaterFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, + bool first_scalar); +int ElementOptGreaterEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, + bool first_scalar); + +int ElementMulFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementMulReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementMulRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); + +int ElementAddFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementAddReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementAddRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int BroadcastAddFp16(const float16_t *in0, const float16_t *in1, float16_t *tile_in0, float16_t *tile_in1, + float16_t *out, int size, ArithmeticParameter *param); + +int ElementSubFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementSubReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementSubRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); + +int ElementDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementDivReluFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementDivRelu6Fp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); + +int ElementFloorModFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementFloorDivFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); + +int ElementLogicalAndFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementLogicalOrFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); + +int ElementSquaredDifferenceFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); + +int ElementMaximumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); +int ElementMinimumFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); + +int ElementNotEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size); +int ElementEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size); +int ElementLessFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size); +int ElementLessEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size); +int ElementGreaterFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size); +int ElementGreaterEqualFp16(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_ARITHMETIC_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/arithmetic_self_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/arithmetic_self_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..55507760ec7d533ff8341e580c0f6a15b8ac4f62 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/arithmetic_self_fp16.c @@ -0,0 +1,124 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp16/arithmetic_self_fp16.h" + +int ElementAbsFp16(const float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = fabsf(input[i]); + } + return NNACL_OK; +} + +int ElementCosFp16(const float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = cosf(input[i]); + } + return NNACL_OK; +} + +int ElementLogFp16(const float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + if (input[i] <= 0) { + return NNACL_ERRCODE_LOG_NEGATIVE_OR_ZERO; + } + output[i] = logf(input[i]); + } + return NNACL_OK; +} + +int ElementSquareFp16(const float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = input[i] * input[i]; + } + return NNACL_OK; +} + +int ElementSqrtFp16(const float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + if (input[i] < 0) { + return NNACL_ERRCODE_SQRT_NEGATIVE; + } + output[i] = sqrtf(input[i]); + } + return NNACL_OK; +} + +int ElementRsqrtFp16(const float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = 1.f / sqrtf(input[i]); + } + return NNACL_OK; +} + +int ElementSinFp16(const float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = sinf(input[i]); + } + return NNACL_OK; +} + +int ElementLogicalNotFp16(const float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = (float)(!((bool)(input[i]))); + } + return NNACL_OK; +} + +int ElementRoundFp16(const float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = roundf(input[i]); + } + return NNACL_OK; +} + +int ElementFloorFp16(const float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = floorf(input[i]); + } + return NNACL_OK; +} + +int ElementCeilFp16(const float16_t *input, float16_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = ceilf(input[i]); + } + return NNACL_OK; +} + +int ElementNegativeFp16(const float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < element_size; ++i) { + output[i] = -input[i]; + } + return NNACL_OK; +} + +int ElementReciprocalFp16(const float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < element_size; ++i) { + if (input[i] == 0.0f) { + return NNACL_ERR; + } + output[i] = 1.f / input[i]; + } + return NNACL_OK; +} + +int ElementErfFp16(const float16_t *input, float16_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = erff(input[i]); + } + return NNACL_OK; +} diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/model_converter_utils/multi_process.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/arithmetic_self_fp16.h similarity index 30% rename from mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/model_converter_utils/multi_process.h rename to mindspore-lite/ops/kernel/cpu/nnacl/fp16/arithmetic_self_fp16.h index 817809ce4ba5616930da2f76b593d3ad1b6a7402..60c9e7899b591963f9d28b37d76d1b941c8514d2 100644 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/model_converter_utils/multi_process.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/arithmetic_self_fp16.h @@ -13,51 +13,45 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#ifndef NNACL_FP16_ARITHMETIC_SELF_FP16_H_ +#define NNACL_FP16_ARITHMETIC_SELF_FP16_H_ -#ifndef MINDSPORE_CCSRC_CXXAPI_MULTI_PROCESS_H -#define MINDSPORE_CCSRC_CXXAPI_MULTI_PROCESS_H -#include -#include -#include "include/api/status.h" - -namespace mindspore { -struct MessageFlag { - uint64_t heartbeat = 0; - uint64_t stop = false; - uint64_t msg_len = 0; - uint64_t msg_total_len = 0; - uint64_t read_ready_flag = false; - uint64_t read_finish_flag = false; -}; - -class MultiProcess; -using ProcessFuncCall = std::function; -using CreateBufferCall = std::function; - -class MultiProcess { - public: - MultiProcess(); - ~MultiProcess(); - - Status MainProcess(const ProcessFuncCall &parent_process, const ProcessFuncCall &child_process); - Status SendMsg(const void *buffer, uint64_t msg_len); - Status ReceiveMsg(const CreateBufferCall &create_buffer_call) const; - - private: - uint8_t *shmat_addr_ = nullptr; - uint8_t *shmat_data_addr_ = nullptr; - uint64_t shmat_data_max_size_ = 0; - uint64_t memory_size_ = 0; - - bool peer_stopped_ = false; - bool stopped_ = false; - MessageFlag *send_msg_ = nullptr; - MessageFlag *receive_msg_ = nullptr; - - static void HeartbeatThreadFunc(MultiProcess *multi_process); - void HeartbeatThreadFuncInner(); - Status ParentProcess(const ProcessFuncCall &parent_process); - void ChildProcess(const ProcessFuncCall &child_process); -}; -} // namespace mindspore -#endif // MINDSPORE_CCSRC_CXXAPI_MULTI_PROCESS_H +#include "nnacl/op_base.h" +#include "nnacl/intrinsics/ms_simd_instructions_fp16.h" +#include "nnacl/errorcode.h" + +#ifdef __cplusplus +extern "C" { +#endif +int ElementAbsFp16(const float16_t *input, float16_t *output, int element_size); + +int ElementCosFp16(const float16_t *input, float16_t *output, int element_size); + +int ElementLogFp16(const float16_t *input, float16_t *output, int element_size); + +int ElementSquareFp16(const float16_t *input, float16_t *output, int element_size); + +int ElementSqrtFp16(const float16_t *input, float16_t *output, int element_size); + +int ElementRsqrtFp16(const float16_t *input, float16_t *output, int element_size); + +int ElementSinFp16(const float16_t *input, float16_t *output, int element_size); + +int ElementLogicalNotFp16(const float16_t *input, float16_t *output, int element_size); + +int ElementRoundFp16(const float16_t *input, float16_t *output, int element_size); + +int ElementFloorFp16(const float16_t *input, float16_t *output, int element_size); + +int ElementCeilFp16(const float16_t *input, float16_t *output, int number); + +int ElementNegativeFp16(const float16_t *input, float16_t *output, int element_size); + +int ElementReciprocalFp16(const float16_t *input, float16_t *output, int element_size); + +int ElementErfFp16(const float16_t *input, float16_t *output, int element_size); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_ARITHMETIC_SELF_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/batchnorm_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/batchnorm_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..73a90cc2931e3031ea39023454adbb2a39a95d5e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/batchnorm_fp16.c @@ -0,0 +1,112 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp16/batchnorm_fp16.h" +#include +#include "nnacl/intrinsics/ms_simd_instructions_fp16.h" + +void BatchNormFp16(const float16_t *input, const float16_t *mean, const float16_t *variance, + const BatchNormStruct *param, int task_id, int thread_num, float16_t *output) { + int units_per_thread = UP_DIV(param->unit_, thread_num); + int completed_units = task_id * units_per_thread; + int cur_unit = MSMIN(units_per_thread, param->unit_ - completed_units); + int cur_offset = completed_units * param->channel_; + + for (int i = 0; i < cur_unit; i++) { + const float16_t *unit_input = input + cur_offset; + float16_t *unit_output = output + cur_offset; + int c = 0; +#ifdef ENABLE_ARM + for (; c <= param->channel_ - C8NUM; c += C8NUM) { + MS_FLOAT16X8 input_8 = MS_LDQ_F16(unit_input + c); + MS_FLOAT16X8 mean_8 = MS_LDQ_F16(mean + c); + MS_FLOAT16X8 variance_8 = MS_LDQ_F16(variance + c); + MS_FLOAT16X8 variance_sqrt = MS_SQRTFX8_F16(MS_ADDQ_F16(variance_8, MS_MOVQ_F16(param->epsilon_))); + MS_FLOAT16X8 output_8 = MS_DIVQ_F16(MS_SUBQ_F16(input_8, mean_8), variance_sqrt); + MS_STQ_F16(unit_output + c, output_8); + } +#endif + for (; c < param->channel_; c++) { + float16_t variance_sqrt = sqrtf(variance[c] + param->epsilon_); + unit_output[c] = (unit_input[c] - mean[c]) / variance_sqrt; + } + cur_offset += param->channel_; + } +} + +void FusedBatchNormFp16(const float16_t *input, const float16_t *scale, const float16_t *offset, const float16_t *mean, + const float16_t *variance, const BatchNormStruct *param, int task_id, int thread_num, + float16_t *output) { + int units_per_thread = UP_DIV(param->unit_, thread_num); + int completed_units = task_id * units_per_thread; + int cur_unit = MSMIN(units_per_thread, param->unit_ - completed_units); + int cur_offset = completed_units * param->channel_; + + for (int i = 0; i < cur_unit; i++) { + const float16_t *unit_input = input + cur_offset; + float16_t *unit_output = output + cur_offset; + int c = 0; +#ifdef ENABLE_ARM + for (; c <= param->channel_ - C8NUM; c += C8NUM) { + MS_FLOAT16X8 input_8 = MS_LDQ_F16(unit_input + c); + MS_FLOAT16X8 scale_8 = MS_LDQ_F16(scale + c); + MS_FLOAT16X8 offset_8 = MS_LDQ_F16(offset + c); + MS_FLOAT16X8 mean_8 = MS_LDQ_F16(mean + c); + MS_FLOAT16X8 variance_8 = MS_LDQ_F16(variance + c); + MS_FLOAT16X8 variance_sqrt = MS_SQRTFX8_F16(MS_ADDQ_F16(variance_8, MS_MOVQ_F16(param->epsilon_))); + MS_FLOAT16X8 norm_val = MS_DIVQ_F16(MS_SUBQ_F16(input_8, mean_8), variance_sqrt); + MS_FLOAT16X8 output_8 = MS_ADDQ_F16(MS_MULQ_F16(norm_val, scale_8), offset_8); + MS_STQ_F16(unit_output + c, output_8); + } +#endif + for (; c < param->channel_; c++) { + float16_t variance_sqrt = sqrtf(variance[c] + param->epsilon_); + float16_t norm_val = (unit_input[c] - mean[c]) / variance_sqrt; + unit_output[c] = norm_val * scale[c] + offset[c]; + } + cur_offset += param->channel_; + } +} + +void FusedBatchNormFp16MeanVar(const float16_t *input, float16_t *run_mean, float16_t *run_var, + const BatchNormStruct *param, float16_t *save_mean, float16_t *save_var) { + const float N = (float)param->unit_; + const float VN = N; + const float VNUB = (N > 1.0f) ? (N - 1.0f) : 1.0f; + const float momentum = (1.0f - param->momentum_); + + for (int i = 0; i < param->unit_; i++) { + for (int c = 0; c < param->channel_; c++) { + int idx = i * param->channel_ + c; + run_mean[c] += input[idx]; + } + } + for (int c = 0; c < param->channel_; c++) { + run_mean[c] /= (float16_t)N; + } + for (int i = 0; i < param->unit_; i++) { + for (int c = 0; c < param->channel_; c++) { + int idx = i * param->channel_ + c; + run_var[c] += (float16_t)((float)(input[idx] - run_mean[c]) * (float)(input[idx] - run_mean[c])); + } + } + for (int c = 0; c < param->channel_; c++) { + float unbiased_var = ((float)run_var[c] / VNUB); + run_var[c] = (float16_t)((float)run_var[c] / VN); + save_mean[c] = (float16_t)(momentum * (float)save_mean[c] + (1.0f - momentum) * (float)run_mean[c]); + save_var[c] = (float16_t)(momentum * (float)save_var[c] + (1.0f - momentum) * unbiased_var); + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/batchnorm_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/batchnorm_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..eab97fb99bd86584cd65da9e08419114db17f586 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/batchnorm_fp16.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_BATCHNORM_FP16_H_ +#define NNACL_FP16_BATCHNORM_FP16_H_ + +#include "nnacl/kernel/batch_norm.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void BatchNormFp16(const float16_t *input, const float16_t *mean, const float16_t *variance, + const BatchNormStruct *param, int task_id, int thread_num, float16_t *output); +void FusedBatchNormFp16(const float16_t *input, const float16_t *scale, const float16_t *offset, const float16_t *mean, + const float16_t *variance, const BatchNormStruct *param, int task_id, int thread_num, + float16_t *output); +void FusedBatchNormFp16MeanVar(const float16_t *input, float16_t *run_mean, float16_t *run_var, + const BatchNormStruct *param, float16_t *save_mean, float16_t *save_var); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_BATCHNORM_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/cast_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/cast_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..e9a5f96252910ce71f755df67f72792cca73732e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/cast_fp16.h @@ -0,0 +1,94 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_CAST_FP16_H_ +#define NNACL_FP16_CAST_FP16_H_ + +#include "nnacl/op_base.h" +#if defined(ENABLE_ARM) && defined(ENABLE_FP16) +#include + +#ifdef __cplusplus +extern "C" { +#endif + +inline void BoolToFloat16(const bool *input, float16_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float16_t)input[i]; + } +} + +inline void Uint8ToFloat16(const uint8_t *input, float16_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (float16_t)input[i]; + } +} + +inline void Float16ToInt32(const float16_t *input, int32_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (int32_t)input[i]; + } +} + +inline void Float16ToInt64(const float16_t *input, int64_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (int64_t)input[i]; + } +} + +#ifdef ENABLE_ARM64 +inline void Float32ToFloat16(const float *__restrict input, float16_t *__restrict output, int number) { + int count = (number & ~(C8NUM - 1)); + int i = 0; + for (; i < count; i += C8NUM) { + float32x4_t in1 = vld1q_f32(input + i); + float16x4_t out1 = vcvt_f16_f32(in1); + float32x4_t in2 = vld1q_f32(input + i + 4); + float16x4_t out2 = vcvt_f16_f32(in2); + float16x8_t out = vcombine_f16(out1, out2); + vst1q_f16(output + i, out); + } + for (; i < number; ++i) { + output[i] = (float16_t)input[i]; + } +} + +inline void Float16ToFloat32(const float16_t *__restrict input, float *__restrict output, int number) { + int count = number & ~(C8NUM - 1); + int i = 0; + for (; i < count; i += C8NUM) { + float16x8_t in = vld1q_f16(input + i); + float16x4_t in1 = vget_low_f16(in); + float16x4_t in2 = vget_high_f16(in); + float32x4_t out1 = vcvt_f32_f16(in1); + vst1q_f32(output + i, out1); + float32x4_t out2 = vcvt_f32_f16(in2); + vst1q_f32(output + i + C4NUM, out2); + } + for (; i < number; ++i) { + output[i] = (float)input[i]; + } +} +#else +void Float32ToFloat16(const float *input, float16_t *output, int number); + +void Float16ToFloat32(const float16_t *input, float *output, int number); +#endif + +#ifdef __cplusplus +} +#endif +#endif +#endif // NNACL_FP16_CAST_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/common_func_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/common_func_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..cdba0fdd2585fe5af73cc7680a36e2b20ca9ec15 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/common_func_fp16.c @@ -0,0 +1,64 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp16/common_func_fp16.h" + +void PostConvFuncCommFp16(float16_t *out_ptr, const float16_t *src_ptr_, const float16_t *bias_ptr, + size_t output_channel, size_t plane_size, size_t oc_stride, size_t hw_stride, + ActType act_type, int size) { + if (size == 0) { + return; + } + for (int oc = 0; oc < output_channel; oc++) { + int oc_div = oc / size, oc_mod = oc % size; + for (int hw = 0; hw < plane_size; hw++) { + int src_index = oc_div * size * hw_stride + hw * size + oc_mod; + int dst_index = hw * oc_stride + oc; + float16_t value = src_ptr_[src_index]; + if (bias_ptr != NULL) { + value = value + bias_ptr[oc]; + } + value = (act_type == ActType_Relu || act_type == ActType_Relu6) ? (MSMAX(0.f, value)) : (value); + value = (act_type == ActType_Relu6) ? (MSMIN(6.f, value)) : (value); + out_ptr[dst_index] = value; + } + } + return; +} + +void PostConvFuncFp16C8(const float16_t *c8_out, float16_t *nhwc_out, const float16_t *bias, size_t oc, size_t plane, + size_t oc_stride, ActType act_type) { +#ifdef ENABLE_ARM64 + size_t oc8mod = oc % C8NUM; + size_t oc8div = oc - oc8mod; + size_t stride_size = oc_stride * sizeof(float16_t); + PostFuncBiasReluC8Fp16(nhwc_out, c8_out, bias, oc8div, oc8mod, plane, stride_size, act_type); +#else + PostConvFuncCommFp16(nhwc_out, c8_out, bias, oc, plane, oc_stride, plane, act_type, C8NUM); +#endif +} + +void PostConvFuncFp16C4(const float16_t *c4_out, float16_t *nhwc_out, const float16_t *bias, size_t oc, size_t plane, + size_t plane_stride, ActType act_type) { +#ifdef ENABLE_ARM64 + size_t oc4mod = oc % C4NUM; + size_t oc4div = oc - oc4mod; + size_t stride_size = (plane_stride - plane) * C4NUM * sizeof(float16_t); + PostFuncBiasReluC4Fp16(nhwc_out, c4_out, bias, oc4div, oc4mod, plane, stride_size, act_type); +#else + PostConvFuncCommFp16(nhwc_out, c4_out, bias, oc, plane, oc, plane_stride, act_type, C4NUM); +#endif +} diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_vm/ms_tensor_ref.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/common_func_fp16.h similarity index 36% rename from mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_vm/ms_tensor_ref.h rename to mindspore-lite/ops/kernel/cpu/nnacl/fp16/common_func_fp16.h index c2ecd339fbe2cf24e5281f3f878ce65b4b116305..595830d9f185c07baa56960ee2386646e117c9d4 100644 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_vm/ms_tensor_ref.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/common_func_fp16.h @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,37 +13,28 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_CXX_API_ACL_VM_MS_TENSOR_REF_H -#define MINDSPORE_CCSRC_CXX_API_ACL_VM_MS_TENSOR_REF_H - -#include -#include -#include -#include "include/api/types.h" -#include "base/base_ref.h" - -namespace mindspore { -class MSTensorRef : public BaseRef { - public: - MS_DECLARE_PARENT(MSTensorRef, BaseRef); - - static VectorRef Convert(const std::vector &tensors); - static std::vector Convert(const BaseRef &args); - - explicit MSTensorRef(const MSTensor &tensor) : ms_tensor_(tensor) {} - ~MSTensorRef() override = default; - - const MSTensor &GetTensor() const { return ms_tensor_; } - std::shared_ptr copy() const override; - - uint32_t type() const override { return tid(); } - std::string ToString() const override { return ms_tensor_.Name(); } - bool operator==(const BaseRef &other) const override; - - private: - static std::vector ConvertTuple(const VectorRef &args); - - MSTensor ms_tensor_; -}; -} // namespace mindspore -#endif // MINDSPORE_CCSRC_CXX_API_ACL_VM_MS_TENSOR_REF_H +#ifndef NNACL_FP16_COMMON_FUNC_FP16_H_ +#define NNACL_FP16_COMMON_FUNC_FP16_H_ + +#include "nnacl/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/* deconv common */ +void PostConvFuncFp16C8(const float16_t *c8_out_ptr, float16_t *out_ptr, const float16_t *bias_ptr, + size_t output_channel, size_t plane_size, size_t stride, ActType act_type); +void PostFuncBiasReluC8Fp16(float16_t *dst, const float16_t *src, const float16_t *bias, size_t oc8div, size_t oc8mod, + size_t plane_size, size_t stride, size_t relu_type); + +/* deconv winograd */ +void PostConvFuncFp16C4(const float16_t *c4_out, float16_t *nhwc_out, const float16_t *bias, size_t output_channel, + size_t plane_size, size_t plane_stride, ActType act_type); +void PostFuncBiasReluC4Fp16(float16_t *dst, const float16_t *src, const float16_t *bias, size_t oc4div, size_t oc4mod, + size_t plane_size, size_t plane_stride, size_t relu_type); + +#ifdef __cplusplus +} +#endif +#endif // NNACL_FP16_COMMON_FUNC_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/constant_of_shape_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/constant_of_shape_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..c77e057e5cb22f2b826c7c65bae5ba557c0876fa --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/constant_of_shape_fp16.h @@ -0,0 +1,38 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_CONSTANT_OF_SHAPE_FP16_H_ +#define NNACL_FP16_CONSTANT_OF_SHAPE_FP16_H_ + +#include "nnacl/op_base.h" +#include "nnacl/errorcode.h" +#include "nnacl/constant_of_shape_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +#ifdef __cplusplus +#ifdef ENABLE_FP16 +inline int ConstantOfShapeFp16(float16_t *output, int start, int end, float16_t value) { + for (int i = start; i < end; i++) { + output[i] = value; + } + return NNACL_OK; +} +#endif +} +#endif + +#endif // NNACL_FP16_CONSTANT_OF_SHAPE_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/conv_depthwise_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/conv_depthwise_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..14bd6c995e1860e346c3a82c5fbaa8741d7ad054 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/conv_depthwise_fp16.c @@ -0,0 +1,842 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp16/conv_depthwise_fp16.h" +#include +#include "nnacl/fp16/activation_fp16.h" + +#ifdef ENABLE_ARM82_A32 +void ConvDwFp16Row(float16_t *output_ptr, const float16_t *input_ptr, const float16_t *weight_ptr, size_t num_pixels, + size_t output_channel, size_t input_step) { + for (int i = 0; i < num_pixels; i++) { + for (int c = 0; c < output_channel; c++) { + *output_ptr++ += weight_ptr[c] * input_ptr[c]; + } + input_ptr += input_step; + } +} +#endif + +#ifdef ENABLE_ARM +static void ConvDw3x3RowLeftFp16(const float16_t *src, float16_t *line, int lw, int channel) { + MS_FLOAT16X8 v0, v1, v2, v3; + v0 = MS_MOVQ_F16((float16_t)0.0); + int ic = 0; + for (; ic < channel - 7; ic += 8) { + v1 = MS_LDQ_F16(src + ic); + v2 = MS_LDQ_F16(src + channel + ic); + v3 = MS_LDQ_F16(src + 2 * channel + ic); + MS_FLOAT16X8 b0 = MS_SUBQ_F16(v0, v2); + MS_FLOAT16X8 b1 = MS_ADDQ_F16(v1, v2); + MS_FLOAT16X8 b2 = MS_SUBQ_F16(v2, v1); + MS_FLOAT16X8 b3 = MS_SUBQ_F16(v3, v1); + MS_STQ_F16(line + lw * ic, b0); + MS_STQ_F16(line + lw * ic + 8, b1); + MS_STQ_F16(line + lw * ic + 16, b2); + MS_STQ_F16(line + lw * ic + 24, b3); + } + if (ic < channel) { + float16_t *remain_line = line + ic * lw; + memset(remain_line, 0, 16); + memset(remain_line + 8, 0, 16); + memset(remain_line + 16, 0, 16); + memset(remain_line + 24, 0, 16); + for (int i = 0; i < channel - ic; i++) { + float16_t d1 = src[i + ic]; + float16_t d2 = src[i + ic + channel]; + float16_t d3 = src[i + ic + 2 * channel]; + remain_line[i] = (float16_t)0.0 - d2; + remain_line[i + 8] = d1 + d2; + remain_line[i + 16] = d2 - d1; + remain_line[i + 24] = d3 - d1; + } + } +} + +static void ConvDw3x3RowMiddleFp16(const float16_t *src, float16_t *line, int lw, int channel) { + MS_FLOAT16X8 v0, v1, v2, v3; + int ic = 0; + for (; ic < channel - 7; ic += 8) { + v0 = MS_LDQ_F16(src + ic); + v1 = MS_LDQ_F16(src + channel + ic); + v2 = MS_LDQ_F16(src + 2 * channel + ic); + v3 = MS_LDQ_F16(src + 3 * channel + ic); + MS_FLOAT16X8 b0 = MS_SUBQ_F16(v0, v2); + MS_FLOAT16X8 b1 = MS_ADDQ_F16(v1, v2); + MS_FLOAT16X8 b2 = MS_SUBQ_F16(v2, v1); + MS_FLOAT16X8 b3 = MS_SUBQ_F16(v3, v1); + MS_STQ_F16(line + lw * ic, b0); + MS_STQ_F16(line + lw * ic + 8, b1); + MS_STQ_F16(line + lw * ic + 16, b2); + MS_STQ_F16(line + lw * ic + 24, b3); + } + if (ic < channel) { + float16_t *remain_line = line + ic * lw; + memset(remain_line, 0, 16); + memset(remain_line + 8, 0, 16); + memset(remain_line + 16, 0, 16); + memset(remain_line + 24, 0, 16); + for (int i = 0; i < channel - ic; i++) { + float16_t d0 = src[i + ic]; + float16_t d1 = src[i + ic + channel]; + float16_t d2 = src[i + ic + 2 * channel]; + float16_t d3 = src[i + ic + 3 * channel]; + remain_line[i] = d0 - d2; + remain_line[i + 8] = d1 + d2; + remain_line[i + 16] = d2 - d1; + remain_line[i + 24] = d3 - d1; + } + } +} + +static void ConvDw3x3RowRightFp16(const float16_t *src, float16_t *line, int lw, int channel) { + MS_FLOAT16X8 v0, v1, v2, v3; + int ic = 0; + v3 = MS_MOVQ_F16((float16_t)0.0); + for (; ic < channel - 7; ic += 8) { + v0 = MS_LDQ_F16(src + ic); + v1 = MS_LDQ_F16(src + channel + ic); + v2 = MS_LDQ_F16(src + 2 * channel + ic); + MS_FLOAT16X8 b0 = MS_SUBQ_F16(v0, v2); + MS_FLOAT16X8 b1 = MS_ADDQ_F16(v1, v2); + MS_FLOAT16X8 b2 = MS_SUBQ_F16(v2, v1); + MS_FLOAT16X8 b3 = MS_SUBQ_F16(v3, v1); + MS_STQ_F16(line + lw * ic, b0); + MS_STQ_F16(line + lw * ic + 8, b1); + MS_STQ_F16(line + lw * ic + 16, b2); + MS_STQ_F16(line + lw * ic + 24, b3); + } + if (ic < channel) { + float16_t *remain_line = line + ic * lw; + memset(remain_line, 0, 16); + memset(remain_line + 8, 0, 16); + memset(remain_line + 16, 0, 16); + memset(remain_line + 24, 0, 16); + for (int i = 0; i < channel - ic; i++) { + float16_t d0 = src[i + ic]; + float16_t d1 = src[i + ic + channel]; + float16_t d2 = src[i + ic + 2 * channel]; + remain_line[i] = d0 - d2; + remain_line[i + 8] = d1 + d2; + remain_line[i + 16] = d2 - d1; + remain_line[i + 24] = (float16_t)0.0 - d1; + } + } +} + +static void ConvDw3x3RowSingleFp16(const float16_t *src, float16_t *line, int lw, int channel) { + MS_FLOAT16X8 v0, v1, v2; + int ic = 0; + v2 = MS_MOVQ_F16((float16_t)0.0); + for (; ic < channel - 7; ic += 8) { + v0 = MS_LDQ_F16(src + ic); + v1 = MS_LDQ_F16(src + channel + ic); + MS_FLOAT16X8 b2 = MS_SUBQ_F16(v2, v1); + MS_STQ_F16(line + lw * ic, v0); + MS_STQ_F16(line + lw * ic + 8, v1); + MS_STQ_F16(line + lw * ic + 16, b2); + memset(line + lw * ic + 24, 0, 16); + } + if (ic < channel) { + float16_t *remain_line = line + ic * lw; + memset(remain_line, 0, 16); + memset(remain_line + 8, 0, 16); + memset(remain_line + 16, 0, 16); + memset(remain_line + 24, 0, 16); + for (int i = 0; i < channel - ic; i++) { + float16_t d0 = src[i + ic]; + float16_t d1 = src[i + ic + channel]; + remain_line[i] = d0; + remain_line[i + 8] = d1; + remain_line[i + 16] = (float16_t)0.0 - d1; + } + } +} + +static void ConvDw3x3InitTopFp16(const float16_t *src, float16_t **lines, int width, int channel) { + float16_t *line0 = lines[0]; + float16_t *line1 = lines[1]; + float16_t *line2 = lines[2]; + int c8 = UP_ROUND(channel, C8NUM); + int lw = UP_DIV(width, C2NUM) * C4NUM; + memset(line0, 0, c8 * lw * sizeof(float16_t)); + ConvDw3x3RowLeftFp16(src, line1, lw, channel); + ConvDw3x3RowLeftFp16(src + width * channel, line2, lw, channel); + int ow = 2; + for (; ow < width - 2; ow += 2) { + ConvDw3x3RowMiddleFp16(src + (ow - 1) * channel, line1 + 2 * ow * 8, lw, channel); + ConvDw3x3RowMiddleFp16(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 8, lw, channel); + } + int remain = width - ow; + if (remain == 2) { + ConvDw3x3RowRightFp16(src + (ow - 1) * channel, line1 + 2 * ow * 8, lw, channel); + ConvDw3x3RowRightFp16(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 8, lw, channel); + } else if (remain == 1) { + ConvDw3x3RowSingleFp16(src + (ow - 1) * channel, line1 + 2 * ow * 8, lw, channel); + ConvDw3x3RowSingleFp16(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 8, lw, channel); + } +} + +static void ConvDw3x3InitRowFp16(const float16_t *src, float16_t **lines, int width, int channel) { + float16_t *line0 = lines[0]; + float16_t *line1 = lines[1]; + float16_t *line2 = lines[2]; + int lw = UP_DIV(width, C2NUM) * C4NUM; + ConvDw3x3RowLeftFp16(src - width * channel, line0, lw, channel); + ConvDw3x3RowLeftFp16(src, line1, lw, channel); + ConvDw3x3RowLeftFp16(src + width * channel, line2, lw, channel); + int ow = 2; + for (; ow < width - 2; ow += 2) { + ConvDw3x3RowMiddleFp16(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 8, lw, channel); + ConvDw3x3RowMiddleFp16(src + (ow - 1) * channel, line1 + 2 * ow * 8, lw, channel); + ConvDw3x3RowMiddleFp16(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 8, lw, channel); + } + int remain = width - ow; + if (remain == 2) { + ConvDw3x3RowRightFp16(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 8, lw, channel); + ConvDw3x3RowRightFp16(src + (ow - 1) * channel, line1 + 2 * ow * 8, lw, channel); + ConvDw3x3RowRightFp16(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 8, lw, channel); + } else if (remain == 1) { + ConvDw3x3RowSingleFp16(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 8, lw, channel); + ConvDw3x3RowSingleFp16(src + (ow - 1) * channel, line1 + 2 * ow * 8, lw, channel); + ConvDw3x3RowSingleFp16(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 8, lw, channel); + } +} + +static void ConvDw3x3RowFp16(const float16_t *src, float16_t **lines, int width, int channel) { + float16_t *tmp = lines[0]; + lines[0] = lines[1]; + lines[1] = lines[2]; + lines[2] = tmp; + int c8 = UP_ROUND(channel, C8NUM); + int lw = UP_DIV(width, C2NUM) * C4NUM; + memset(tmp, 0, c8 * lw * sizeof(float16_t)); + ConvDw3x3RowLeftFp16(src, tmp, lw, channel); + int ow = 2; + for (; ow < width - 2; ow += 2) { + ConvDw3x3RowMiddleFp16(src + (ow - 1) * channel, tmp + 2 * ow * 8, lw, channel); + } + int remain = width - ow; + if (remain == 2) { + ConvDw3x3RowRightFp16(src + (ow - 1) * channel, tmp + 2 * ow * 8, lw, channel); + } else if (remain == 1) { + ConvDw3x3RowSingleFp16(src + (ow - 1) * channel, tmp + 2 * ow * 8, lw, channel); + } +} + +static void ConvDw3x3BottomFp16(float16_t **lines, int width, int channel) { + float16_t *tmp = lines[0]; + lines[0] = lines[1]; + lines[1] = lines[2]; + lines[2] = tmp; + int c8 = UP_ROUND(channel, C8NUM); + memset(tmp, 0, UP_DIV(width, C2NUM) * c8 * C4NUM * sizeof(float16_t)); +} + +void ConvDw3x3LineFp16(float16_t *dst, float16_t **lines, const float16_t *weight, const float16_t *bias_data, + int width, int ori_channel, bool relu, bool relu6) { + int channel = ori_channel; + float16_t *line0 = lines[0]; + float16_t *line1 = lines[1]; + float16_t *line2 = lines[2]; + for (; channel > 0; channel -= 8) { + MS_FLOAT16X8 bias = MS_LDQ_F16(bias_data); + bias_data += 8; + MS_FLOAT16X8 g00 = MS_LDQ_F16(weight); + MS_FLOAT16X8 g01 = MS_LDQ_F16(weight + 8); + MS_FLOAT16X8 g02 = MS_LDQ_F16(weight + 16); + MS_FLOAT16X8 g03 = MS_LDQ_F16(weight + 24); + MS_FLOAT16X8 g10 = MS_LDQ_F16(weight + 32); + MS_FLOAT16X8 g11 = MS_LDQ_F16(weight + 40); + MS_FLOAT16X8 g12 = MS_LDQ_F16(weight + 48); + MS_FLOAT16X8 g13 = MS_LDQ_F16(weight + 56); + MS_FLOAT16X8 g20 = MS_LDQ_F16(weight + 64); + MS_FLOAT16X8 g21 = MS_LDQ_F16(weight + 72); + MS_FLOAT16X8 g22 = MS_LDQ_F16(weight + 80); + MS_FLOAT16X8 g23 = MS_LDQ_F16(weight + 88); + weight += 96; + float16_t *cur_dst = dst; + int ow = 0; + for (; ow < width - 1; ow += 2) { + MS_FLOAT16X8 acc0 = MS_MULQ_F16(MS_LDQ_F16(line0), g00); + MS_FLOAT16X8 acc1 = MS_MULQ_F16(MS_LDQ_F16(line0 + 8), g01); + MS_FLOAT16X8 acc2 = MS_MULQ_F16(MS_LDQ_F16(line0 + 16), g02); + MS_FLOAT16X8 acc3 = MS_MULQ_F16(MS_LDQ_F16(line0 + 24), g03); + line0 += 32; + acc0 = MS_FMAQ_F16(acc0, MS_LDQ_F16(line1), g10); + acc1 = MS_FMAQ_F16(acc1, MS_LDQ_F16(line1 + 8), g11); + acc2 = MS_FMAQ_F16(acc2, MS_LDQ_F16(line1 + 16), g12); + acc3 = MS_FMAQ_F16(acc3, MS_LDQ_F16(line1 + 24), g13); + + line1 += 32; + acc0 = MS_FMAQ_F16(acc0, MS_LDQ_F16(line2), g20); + acc1 = MS_FMAQ_F16(acc1, MS_LDQ_F16(line2 + 8), g21); + acc2 = MS_FMAQ_F16(acc2, MS_LDQ_F16(line2 + 16), g22); + acc3 = MS_FMAQ_F16(acc3, MS_LDQ_F16(line2 + 24), g23); + + line2 += 32; + MS_FLOAT16X8 res0 = MS_ADDQ_F16(acc0, MS_ADDQ_F16(acc2, acc1)); + MS_FLOAT16X8 res1 = MS_ADDQ_F16(acc1, MS_SUBQ_F16(acc3, acc2)); + res0 = MS_ADDQ_F16(res0, bias); + res1 = MS_ADDQ_F16(res1, bias); + if (relu || relu6) { + res0 = MS_MAXQ_F16(res0, MS_MOVQ_F16((float16_t)0.0)); + res1 = MS_MAXQ_F16(res1, MS_MOVQ_F16((float16_t)0.0)); + } + if (relu6) { + res0 = MS_MINQ_F16(res0, MS_MOVQ_F16((float16_t)6.0)); + res1 = MS_MINQ_F16(res1, MS_MOVQ_F16((float16_t)6.0)); + } + if (channel >= 8) { + MS_STQ_F16(cur_dst, res0); + MS_STQ_F16(cur_dst + ori_channel, res1); + } else { + for (int i = 0; i < channel; i++) { + cur_dst[i] = res0[i]; + cur_dst[ori_channel + i] = res1[i]; + } + } + cur_dst += 2 * ori_channel; + } + if (ow < width) { + MS_FLOAT16X8 acc0 = MS_MULQ_F16(MS_LDQ_F16(line0), g00); + MS_FLOAT16X8 acc1 = MS_MULQ_F16(MS_LDQ_F16(line0 + 8), g01); + MS_FLOAT16X8 acc2 = MS_MULQ_F16(MS_LDQ_F16(line0 + 16), g02); + line0 += 32; + acc0 = MS_FMAQ_F16(acc0, MS_LDQ_F16(line1), g10); + acc1 = MS_FMAQ_F16(acc1, MS_LDQ_F16(line1 + 8), g11); + acc2 = MS_FMAQ_F16(acc2, MS_LDQ_F16(line1 + 16), g12); + + line1 += 32; + acc0 = MS_FMAQ_F16(acc0, MS_LDQ_F16(line2), g20); + acc1 = MS_FMAQ_F16(acc1, MS_LDQ_F16(line2 + 8), g21); + acc2 = MS_FMAQ_F16(acc2, MS_LDQ_F16(line2 + 16), g22); + + line2 += 32; + MS_FLOAT16X8 res0 = MS_ADDQ_F16(acc0, MS_ADDQ_F16(acc2, acc1)); + res0 = MS_ADDQ_F16(res0, bias); + if (relu || relu6) { + res0 = MS_MAXQ_F16(res0, MS_MOVQ_F16((float16_t)0.0)); + } + if (relu6) { + res0 = MS_MINQ_F16(res0, MS_MOVQ_F16((float16_t)6.0)); + } + if (channel >= 8) { + MS_STQ_F16(cur_dst, res0); + } else { + for (int i = 0; i < channel; i++) { + cur_dst[i] = res0[i]; + } + } + } + dst += 8; + } +} + +void ConvDw3x3Fp16(float16_t *output_data, float16_t *buffer, const float16_t *input_data, const float16_t *weight_data, + const float16_t *bias_data, const ConvParameter *conv_param, int start_oh, int end_oh) { + int units = UP_DIV(conv_param->output_w_, C2NUM); + int c8 = UP_ROUND(conv_param->input_channel_, C8NUM); + int line = conv_param->input_channel_ * conv_param->input_w_; + + bool relu = conv_param->act_type_ == ActType_Relu; + bool relu6 = conv_param->act_type_ == ActType_Relu6; + + for (int b = 0; b < conv_param->output_batch_; b++) { + const float16_t *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_; + float16_t *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_; + float16_t *line0 = buffer; + float16_t *line1 = buffer + units * c8 * C4NUM; + float16_t *line2 = buffer + units * c8 * C4NUM * 2; + float16_t *lines[3] = {line0, line1, line2}; + int oh = start_oh; + if (oh == 0) { + // input trans + ConvDw3x3InitTopFp16(src, lines, conv_param->output_w_, conv_param->input_channel_); + } else { + // input trans + ConvDw3x3InitRowFp16(src + oh * line, lines, conv_param->output_w_, conv_param->input_channel_); + } + // dst calc and trans + ConvDw3x3LineFp16(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_, + relu, relu6); + for (oh = start_oh + 1; oh < end_oh - 1; oh++) { + // input trans + ConvDw3x3RowFp16(src + oh * line + line, lines, conv_param->output_w_, conv_param->input_channel_); + // dst calc and trans + ConvDw3x3LineFp16(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, + conv_param->input_channel_, relu, relu6); + } + if (oh == conv_param->output_h_ - 1) { + // input trans + ConvDw3x3BottomFp16(lines, conv_param->output_w_, conv_param->input_channel_); + } else { + // input trans + ConvDw3x3RowFp16(src + oh * line + line, lines, conv_param->output_w_, conv_param->input_channel_); + } + // dst calc and trans + ConvDw3x3LineFp16(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_, + relu, relu6); + } +} + +#endif + +void ConvDwFp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data, + const float16_t *bias_data, const ConvParameter *conv_param, int task_id) { + NNACL_CHECK_ZERO_RETURN(conv_param->stride_w_); + NNACL_CHECK_ZERO_RETURN(conv_param->dilation_h_); + NNACL_CHECK_ZERO_RETURN(conv_param->thread_num_); + int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_); + int h_start = h_step * task_id; + int h_end = MSMIN(h_start + h_step, conv_param->output_h_); + bool relu = conv_param->act_type_ == ActType_Relu; + bool relu6 = conv_param->act_type_ == ActType_Relu6; + for (int b = 0; b < conv_param->output_batch_; b++) { + const float16_t *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_; + float16_t *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_; + for (int oh = h_start; oh < h_end; oh++) { + float16_t *dst_data = dst + oh * conv_param->output_w_ * conv_param->output_channel_; + + int ih_origin = oh * conv_param->stride_h_ - conv_param->pad_u_; + int start_kh = MSMAX(0, UP_DIV(-ih_origin, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih_origin, conv_param->dilation_h_)); + + for (int ow = 0; ow < conv_param->output_w_; ow++) { + memcpy(dst_data + ow * conv_param->output_channel_, bias_data, conv_param->output_channel_ * sizeof(float16_t)); + } + for (int kh = start_kh; kh < end_kh; kh++) { + int ih = ih_origin + conv_param->dilation_h_ * kh; + + const float16_t *src_kh = src + ih * conv_param->input_w_ * conv_param->input_channel_; + const float16_t *weight_kh = weight_data + kh * conv_param->kernel_w_ * conv_param->output_channel_; + + int in_sw_step = conv_param->stride_w_ * conv_param->input_channel_; + for (int kw = 0; kw < conv_param->kernel_w_; kw++) { + int out_w_start = MSMAX( + 0, (conv_param->pad_l_ - conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) / conv_param->stride_w_); + int out_w_end = MSMIN(conv_param->output_w_, (conv_param->input_w_ + conv_param->pad_l_ - + conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) / + conv_param->stride_w_); + + float16_t *dst_w = dst_data + out_w_start * conv_param->output_channel_; + int iw_origin = (out_w_start * conv_param->stride_w_) - conv_param->pad_l_ + conv_param->dilation_w_ * kw; + + const float16_t *src_kw = src_kh + iw_origin * conv_param->input_channel_; + int num_pixels = out_w_end - out_w_start; + ConvDwFp16Row(dst_w, src_kw, weight_kh, num_pixels, conv_param->output_channel_, in_sw_step); + weight_kh += conv_param->output_channel_; + } + } + if (relu) { + ReluFp16(dst_data, dst_data, conv_param->output_w_ * conv_param->output_channel_); + } + if (relu6) { + Relu6Fp16(dst_data, dst_data, conv_param->output_w_ * conv_param->output_channel_); + } + } + } +} + +/*conv depthwise fp16 begin*/ +void DepthwiseBorderPixelFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, + int height, int width, int in_kh_step, int in_kw_step, int kernel_w_step, bool is_relu, + bool is_relu6) { + for (int c = 0; c < C8NUM; c++) { + dst[c] = 0; + } + const float16_t *src_kh = src; + const float16_t *weight_kh = weight; + for (int kh = 0; kh < height; kh++) { + const float16_t *src_kw = src_kh; + const float16_t *weight_kw = weight_kh; + for (int kw = 0; kw < width; kw++) { + float16x8_t src_8 = vld1q_f16(src_kw); + float16x8_t weight_8 = vld1q_f16(weight_kw); + float16x8_t dst_8 = vld1q_f16(dst); + dst_8 = vfmaq_f16(dst_8, src_8, weight_8); + vst1q_f16(dst, dst_8); + + src_kw += in_kw_step; + weight_kw += C8NUM; + } // kernel_w loop + src_kh += in_kh_step; + weight_kh += kernel_w_step; + } // kernel_h loop + for (int c = 0; c < C8NUM; c++) { + dst[c] += bias[c]; + dst[c] = (is_relu) ? (MSMAX(0, dst[c])) : (dst[c]); + dst[c] = (is_relu6) ? (MSMIN(6, MSMAX(0, dst[c]))) : (dst[c]); + } +} + +void DepthwiseBorderFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, int top, + int bottom, int left, int right, const ConvParameter *conv_param, + const SlidingWindowParam *sliding) { + NNACL_CHECK_ZERO_RETURN(conv_param->dilation_h_); + NNACL_CHECK_ZERO_RETURN(conv_param->dilation_w_); + bool relu = conv_param->act_type_ == ActType_Relu; + bool relu6 = conv_param->act_type_ == ActType_Relu6; + float16_t *dst_h = dst + top * sliding->out_h_step_; + for (int oh = top; oh < bottom; oh++) { + int ih = oh * conv_param->stride_h_ - conv_param->pad_u_; + int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_)); + const float16_t *src_h = src + ih * sliding->in_h_step_; + + float16_t *dst_kernel = dst_h + left * sliding->block_channel_; + for (int ow = left; ow < right; ow++) { + int iw = ow * conv_param->stride_w_ - conv_param->pad_l_; + int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_)); + int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_)); + const float16_t *src_w = src_h + iw * sliding->block_channel_; + + const float16_t *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_; + const float16_t *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C8NUM; +#ifdef ENABLE_ARM64 + ConvDwFp16Border(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_ * sizeof(float16_t), sliding->in_kw_step_ * sizeof(float16_t), + conv_param->kernel_w_ * C8NUM * sizeof(float16_t), relu, relu6); +#else + DepthwiseBorderPixelFp16(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_ * C8NUM, relu, relu6); +#endif + dst_kernel += sliding->block_channel_; + } // width loop + dst_h += sliding->out_h_step_; + } // height loop +} + +#ifndef ENABLE_ARM64 +void DepthwiseCenterFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, + int height, int width, int kernel_h, int kernel_w, int out_h_step, int block_channel, + int in_sh_step, int in_sw_step, int in_kh_step, int in_kw_step, bool is_relu, bool is_relu6) { + float16_t *dst_h = dst; + const float16_t *src_h = src; + for (int oh = 0; oh < height; oh++) { + float16_t *dst_w = dst_h; + const float16_t *src_w = src_h; + for (int ow = 0; ow < width; ow++) { + const float16_t *src_kh = src_w; + const float16_t *weight_kh = weight; + for (int c = 0; c < C8NUM; c++) { + dst_w[c] = 0; + } + for (int kh = 0; kh < kernel_h; kh++) { + const float16_t *src_kw = src_kh; + const float16_t *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++) { +#ifdef ENABLE_ARM64 + float16x8_t src_8 = vld1q_f16(src_kw); + float16x8_t weight_8 = vld1q_f16(weight_kw); + float16x8_t dst_8 = vld1q_f16(dst_w); + dst_8 = vfmaq_f16(dst_8, src_8, weight_8); + vst1q_f16(dst_w, dst_8); +#else + for (int c = 0; c < C8NUM; c++) { + dst_w[c] += src_kw[c] * weight_kw[c]; + } +#endif + src_kw += in_kw_step; + weight_kw += C8NUM; + } // kernel_w loop + src_kh += in_kh_step; + weight_kh += kernel_w * C8NUM; + } // kernel_h loop + // add biad relu + for (int c = 0; c < C8NUM; c++) { + dst_w[c] += bias[c]; + dst_w[c] = (is_relu) ? (MSMAX(0, dst_w[c])) : (dst_w[c]); + dst_w[c] = (is_relu6) ? (MSMIN(6, MSMAX(0, dst_w[c]))) : (dst_w[c]); + } + dst_w += block_channel; + src_w += in_sw_step; + } // dst_width loop + dst_h += out_h_step; + src_h += in_sh_step; + } // dst_height loop +} +#endif + +// conv depthwise fp16: sliding window +void ConvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data, + const float16_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, + int task_id) { + NNACL_CHECK_ZERO_RETURN(conv_param->dilation_h_); + NNACL_CHECK_ZERO_RETURN(conv_param->dilation_w_); + bool relu = conv_param->act_type_ == ActType_Relu; + bool relu6 = conv_param->act_type_ == ActType_Relu6; + const float16_t *src = input_data; + float16_t *dst = output_data; + for (int b = 0; b < conv_param->output_batch_; b++) { + for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) { + const float16_t *src_data = src + oc * C8NUM; + float16_t *dst_data = dst + oc * C8NUM; + const float16_t *weight = weight_data + oc * sliding->kernel_step_; + const float16_t *bias = bias_data + oc * C8NUM; + DepthwiseBorderFp16(dst_data, src_data, weight, bias, 0, sliding->top_, 0, conv_param->output_w_, conv_param, + sliding); + DepthwiseBorderFp16(dst_data, src_data, weight, bias, sliding->bottom_, conv_param->output_h_, 0, + conv_param->output_w_, conv_param, sliding); + DepthwiseBorderFp16(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, 0, sliding->left_, + conv_param, sliding); + DepthwiseBorderFp16(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, sliding->right_, + conv_param->output_w_, conv_param, sliding); + + if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) { + int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_; + int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_; + const float16_t *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_; + float16_t *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_; +#ifdef ENABLE_ARM64 + ConvDwFp16Center(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, + conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float16_t), + sliding->block_channel_ * sizeof(float16_t), sliding->in_sh_step_ * sizeof(float16_t), + sliding->in_sw_step_ * sizeof(float16_t), sliding->in_kh_step_ * sizeof(float16_t), + sliding->in_kw_step_ * sizeof(float16_t), relu, relu6); +#else + DepthwiseCenterFp16(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, + sliding->right_ - sliding->left_, conv_param->kernel_h_, conv_param->kernel_w_, + sliding->out_h_step_, sliding->block_channel_, sliding->in_sh_step_, sliding->in_sw_step_, + sliding->in_kh_step_, sliding->in_kw_step_, relu, relu6); +#endif + } + } // output C8 loop + src += sliding->in_step_; + dst += sliding->out_step_; + } // batch loop + // output nchwc8 +} +/*conv depthwise fp16 end*/ + +/*deconv depthwise fp16 begin*/ +void DeconvDepthwiseBorderPixelFp16(float16_t *dst, const float16_t *src, const float16_t *weight, int height, + int width, int in_kh_step, int in_kw_step, int kernel_w_step) { + float16_t *dst_kh = dst; + const float16_t *weight_kh = weight; + for (int kh = 0; kh < height; kh++) { + float16_t *dst_kw = dst_kh; + const float16_t *weight_kw = weight_kh; + for (int kw = 0; kw < width; kw++) { + float16x8_t src_8 = vld1q_f16(src); + float16x8_t weight_8 = vld1q_f16(weight_kw); + float16x8_t dst_8 = vld1q_f16(dst_kw); + dst_8 = vfmaq_f16(dst_8, src_8, weight_8); + vst1q_f16(dst_kw, dst_8); + + dst_kw += in_kw_step; + weight_kw += C8NUM; + } // kernel_w loop + dst_kh += in_kh_step; + weight_kh += kernel_w_step; + } // kernel_h loop +} + +void DeconvDepthwiseBorderFp16(float16_t *dst, const float16_t *src, const float16_t *weight, int top, int bottom, + int left, int right, const ConvParameter *conv_param, + const SlidingWindowParam *sliding) { + NNACL_CHECK_ZERO_RETURN(conv_param->dilation_h_); + NNACL_CHECK_ZERO_RETURN(conv_param->dilation_w_); + const float16_t *src_h = src + top * sliding->out_h_step_; + for (int ih = top; ih < bottom; ih++) { + int oh = ih * conv_param->stride_h_ - conv_param->pad_u_; + int start_kh = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_)); + float16_t *dst_h = dst + oh * sliding->in_h_step_; + + const float16_t *src_kernel = src_h + left * sliding->block_channel_; + for (int iw = left; iw < right; iw++) { + int ow = iw * conv_param->stride_w_ - conv_param->pad_l_; + int start_kw = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_)); + int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_)); + float16_t *dst_w = dst_h + ow * sliding->block_channel_; + + const float16_t *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C8NUM; + float16_t *dst_kernel = dst_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_; +#ifdef ENABLE_ARM64 + DeconvDwFp16Border(dst_kernel, src_kernel, weight_kernel, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_ * sizeof(float16_t), sliding->in_kw_step_ * sizeof(float16_t), + conv_param->kernel_w_ * C8NUM * sizeof(float16_t)); +#else + DeconvDepthwiseBorderPixelFp16(dst_kernel, src_kernel, weight_kernel, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_ * C8NUM); +#endif + src_kernel += sliding->block_channel_; + } // width loop + src_h += sliding->out_h_step_; + } // height loop +} + +#ifndef ENABLE_ARM64 +void DeconvDepthwiseCenterFp16(float16_t *dst, const float16_t *src, const float16_t *weight, int height, int width, + int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step, + int in_sw_step, int in_kh_step, int in_kw_step) { + float16_t *dst_h = dst; + const float16_t *src_h = src; + for (int oh = 0; oh < height; oh++) { + float16_t *dst_w = dst_h; + const float16_t *src_w = src_h; + for (int ow = 0; ow < width; ow++) { + float16_t *dst_kh = dst_w; + const float16_t *weight_kh = weight; + for (int kh = 0; kh < kernel_h; kh++) { + float16_t *dst_kw = dst_kh; + const float16_t *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++) { +#ifdef ENABLE_NEON + float16x8_t src_8 = vld1q_f16(src_w); + float16x8_t weight_8 = vld1q_f16(weight_kw); + float16x8_t dst_8 = vld1q_f16(dst_kw); + dst_8 = vfmaq_f16(dst_8, src_8, weight_8); + vst1q_f16(dst_kw, dst_8); +#else + for (int c = 0; c < C8NUM; c++) { + dst_kw[c] += src_w[c] * weight_kw[c]; + } +#endif + dst_kw += in_kw_step; + weight_kw += C8NUM; + } // kernel_w loop + dst_kh += in_kh_step; + weight_kh += kernel_w * C8NUM; + } // kernel_h loop + dst_w += in_sw_step; + src_w += block_channel; + } // dst_width loop + dst_h += in_sh_step; + src_h += out_h_step; + } // dst_height loop +} +#endif + +void DeconvDepthwisePostFuncFp16(float16_t *dst, const float16_t *bias, int block_channel, + const ConvParameter *conv_param) { + bool relu = conv_param->act_type_ == ActType_Relu; + bool relu6 = conv_param->act_type_ == ActType_Relu6; + int hw = conv_param->output_h_ * conv_param->output_w_; + int hw8 = hw / C8NUM * C8NUM; + float16x8_t bias_value = vld1q_f16(bias); + float16x8_t zero = vdupq_n_f16(0.0f); + float16x8_t six = vdupq_n_f16(6.0f); + + int i = 0; + for (; i < hw8; i += C8NUM) { + float16_t *dst_ptr = dst + i * block_channel; + float16x8_t dst_value0 = vld1q_f16(dst_ptr); + float16x8_t dst_value1 = vld1q_f16(dst_ptr + C1NUM * block_channel); + float16x8_t dst_value2 = vld1q_f16(dst_ptr + C2NUM * block_channel); + float16x8_t dst_value3 = vld1q_f16(dst_ptr + C3NUM * block_channel); + float16x8_t dst_value4 = vld1q_f16(dst_ptr + C4NUM * block_channel); + float16x8_t dst_value5 = vld1q_f16(dst_ptr + C5NUM * block_channel); + float16x8_t dst_value6 = vld1q_f16(dst_ptr + C6NUM * block_channel); + float16x8_t dst_value7 = vld1q_f16(dst_ptr + C7NUM * block_channel); + + dst_value0 = vaddq_f16(dst_value0, bias_value); + dst_value1 = vaddq_f16(dst_value1, bias_value); + dst_value2 = vaddq_f16(dst_value2, bias_value); + dst_value3 = vaddq_f16(dst_value3, bias_value); + dst_value4 = vaddq_f16(dst_value4, bias_value); + dst_value5 = vaddq_f16(dst_value5, bias_value); + dst_value6 = vaddq_f16(dst_value6, bias_value); + dst_value7 = vaddq_f16(dst_value7, bias_value); + if (relu) { + dst_value0 = vmaxq_f16(dst_value0, zero); + dst_value1 = vmaxq_f16(dst_value1, zero); + dst_value2 = vmaxq_f16(dst_value2, zero); + dst_value3 = vmaxq_f16(dst_value3, zero); + dst_value4 = vmaxq_f16(dst_value4, zero); + dst_value5 = vmaxq_f16(dst_value5, zero); + dst_value6 = vmaxq_f16(dst_value6, zero); + dst_value7 = vmaxq_f16(dst_value7, zero); + } + if (relu6) { + dst_value0 = vminq_f16(dst_value0, six); + dst_value1 = vminq_f16(dst_value1, six); + dst_value2 = vminq_f16(dst_value2, six); + dst_value3 = vminq_f16(dst_value3, six); + dst_value4 = vminq_f16(dst_value4, six); + dst_value5 = vminq_f16(dst_value5, six); + dst_value6 = vminq_f16(dst_value6, six); + dst_value7 = vminq_f16(dst_value7, six); + } + vst1q_f16(dst_ptr, dst_value0); + vst1q_f16(dst_ptr + C1NUM * block_channel, dst_value1); + vst1q_f16(dst_ptr + C2NUM * block_channel, dst_value2); + vst1q_f16(dst_ptr + C3NUM * block_channel, dst_value3); + vst1q_f16(dst_ptr + C4NUM * block_channel, dst_value4); + vst1q_f16(dst_ptr + C5NUM * block_channel, dst_value5); + vst1q_f16(dst_ptr + C6NUM * block_channel, dst_value6); + vst1q_f16(dst_ptr + C7NUM * block_channel, dst_value7); + } + + float16_t *dst_ptr = dst + i * block_channel; + for (; i < hw; i++, dst_ptr += block_channel) { + float16x8_t dst_value0 = vld1q_f16(dst_ptr); + dst_value0 = vaddq_f16(dst_value0, bias_value); + dst_value0 = relu ? vmaxq_f16(dst_value0, zero) : dst_value0; + dst_value0 = relu6 ? vminq_f16(dst_value0, six) : dst_value0; + vst1q_f16(dst_ptr, dst_value0); + } +} + +// deconv depthwise fp16: sliding window +void DeconvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data, + const float16_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, + int task_id) { + const float16_t *src = input_data; + float16_t *dst = output_data; + for (int b = 0; b < conv_param->output_batch_; b++) { + for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) { + const float16_t *src_data = src + oc * C8NUM; + float16_t *dst_data = dst + oc * C8NUM; + const float16_t *weight = weight_data + oc * sliding->kernel_step_; + const float16_t *bias = bias_data + oc * C8NUM; + DeconvDepthwiseBorderFp16(dst_data, src_data, weight, 0, sliding->top_, 0, conv_param->input_w_, conv_param, + sliding); + DeconvDepthwiseBorderFp16(dst_data, src_data, weight, sliding->bottom_, conv_param->input_h_, 0, + conv_param->input_w_, conv_param, sliding); + DeconvDepthwiseBorderFp16(dst_data, src_data, weight, sliding->top_, sliding->bottom_, 0, sliding->left_, + conv_param, sliding); + DeconvDepthwiseBorderFp16(dst_data, src_data, weight, sliding->top_, sliding->bottom_, sliding->right_, + conv_param->input_w_, conv_param, sliding); + + if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) { + int oh_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_; + int oh_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_; + float16_t *out_t = dst_data + oh_h_start * sliding->in_h_step_ + oh_w_start * sliding->block_channel_; + const float16_t *in_t = + src_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_; +#ifdef ENABLE_ARM64 + DeconvDwFp16Center(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, + conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float16_t), + sliding->block_channel_ * sizeof(float16_t), sliding->in_sh_step_ * sizeof(float16_t), + sliding->in_sw_step_ * sizeof(float16_t), sliding->in_kh_step_ * sizeof(float16_t), + sliding->in_kw_step_ * sizeof(float16_t)); +#else + DeconvDepthwiseCenterFp16(out_t, in_t, weight, sliding->bottom_ - sliding->top_, + sliding->right_ - sliding->left_, conv_param->kernel_h_, conv_param->kernel_w_, + sliding->out_h_step_, sliding->block_channel_, sliding->in_sh_step_, + sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_); +#endif + } + DeconvDepthwisePostFuncFp16(dst_data, bias, sliding->block_channel_, conv_param); + } // output C8 loop + src += sliding->out_step_; + dst += sliding->in_step_; + } // batch loop + // output nchwc8 +} +/*deconv depthwise fp16 end*/ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/conv_depthwise_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/conv_depthwise_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..d6b3d542b9151c85db100810fde390e6f3ff7020 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/conv_depthwise_fp16.h @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_CONV_DEPTHWISE_FP16_H_ +#define NNACL_FP16_CONV_DEPTHWISE_FP16_H_ + +#include "nnacl/conv_parameter.h" +#include "nnacl/fp32/conv_depthwise_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif +void ConvDwFp16Row(float16_t *output_ptr, const float16_t *input_ptr, const float16_t *filter_ptr, size_t num_pixels, + size_t input_channel, size_t input_step); +#ifdef ENABLE_ARM64 +void ConvDwFp16Border(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, + size_t height, size_t width, size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu, + size_t relu6); +void ConvDwFp16Center(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, + size_t height, size_t width, size_t kernel_h, size_t kernel_w, size_t out_h_step, + size_t block_channel, size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, + size_t relu, size_t relu6); +void DeconvDwFp16Border(float16_t *dst, const float16_t *src, const float16_t *weight, size_t height, size_t width, + size_t in_kh_step, size_t in_kw_step, size_t kernel_w); +void DeconvDwFp16Center(float16_t *dst, const float16_t *src, const float16_t *weight, size_t height, size_t width, + size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, + size_t in_sw_step, size_t in_kh_step, size_t in_kw_step); +#endif + +void ConvDwFp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data, + const float16_t *bias_data, const ConvParameter *conv_param, int task_id); + +void ConvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data, + const float16_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, + int task_id); + +void DeconvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data, + const float16_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, + int task_id); + +#ifdef ENABLE_ARM +void ConvDw3x3LineFp16(float16_t *dst, float16_t **lines, const float16_t *weight, const float16_t *bias_data, + int width, int ori_channel, bool relu, bool relu6); +void ConvDw3x3Fp16(float16_t *output_data, float16_t *buffer, const float16_t *input_data, const float16_t *weight_data, + const float16_t *bias_data, const ConvParameter *conv_param, int start_oh, int end_oh); +#endif + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_CONV_DEPTHWISE_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/conv_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/conv_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..187351c8ac60968f36c268c76f26f5d3e643c38f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/conv_fp16.c @@ -0,0 +1,334 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp16/conv_fp16.h" +#include +#include "nnacl/fp16/pack_fp16.h" +#include "nnacl/fp16/winograd_transform_fp16.h" +#include "nnacl/fp16/matmul_fp16.h" + +void Im2ColPackUnitFp16(const float16_t *input_data, const ConvParameter *conv_param, float16_t *packed_input, + int real_cal_num, int block_index) { + // input format : nhwc + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int kernel_plane = kernel_h * kernel_w; + int stride_h = conv_param->stride_h_; + int stride_w = conv_param->stride_w_; + int pad_h = conv_param->pad_u_; + int pad_w = conv_param->pad_l_; + int dilation_h = conv_param->dilation_h_; + int dilation_w = conv_param->dilation_w_; + int in_channel = conv_param->input_channel_; + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int out_w = conv_param->output_w_; + + for (int i = 0; i < real_cal_num; i++) { + int block_start = block_index + i; + int input_h = block_start / out_w * stride_h - pad_h; + int input_w = block_start % out_w * stride_w - pad_w; + int input_stride = (input_h * in_w + input_w) * in_channel; + int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h)); + int kh_e = MSMIN(kernel_h, UP_DIV(in_h - input_h, dilation_h)); + int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w)); + int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w)); + if (dilation_h == 1 && dilation_w == 1) { + for (int j = kh_s; j < kh_e; j++) { + int input_y_stride = j * in_w * in_channel + input_stride; + int input_x_stride = input_y_stride + kw_s * in_channel; + int input_plane_offset = (j * kernel_w + kw_s) * in_channel + i * in_channel * kernel_plane; + memcpy(packed_input + input_plane_offset, input_data + input_x_stride, + (kw_e - kw_s) * in_channel * sizeof(float16_t)); + } // kernel_h loop + } else { + for (int j = kh_s; j < kh_e; j++) { + int input_y_stride = j * dilation_h * in_w * in_channel + input_stride; + for (int n = kw_s; n < kw_e; n++) { + int input_x_stride = input_y_stride + n * dilation_w * in_channel; + int input_plane_offset = (j * kernel_w + n) * in_channel + i * in_channel * kernel_plane; + memcpy(packed_input + input_plane_offset, input_data + input_x_stride, in_channel * sizeof(float16_t)); + } // kernel_w loop + } // kernel_h loop + } + } // tile num loop +} + +// fp16 convolution common (im2col+gemm) +void ConvFp16(const float16_t *input_data, float16_t *packed_input, const float16_t *packed_weight, + const float16_t *bias_data, float16_t *col_major_input, float16_t *output_data, int task_id, + const ConvParameter *conv_param) { +#ifdef ENABLE_ARM64 + const int tile_n = 16; +#else + const int tile_n = 12; +#endif + NNACL_CHECK_ZERO_RETURN(conv_param->thread_num_); + NNACL_CHECK_ZERO_RETURN(tile_n); + int output_hw = conv_param->output_h_ * conv_param->output_w_; + int block_per_thread = UP_DIV(UP_DIV(output_hw, tile_n), conv_param->thread_num_); + int start_block = block_per_thread * task_id; + int start_hw = start_block * tile_n; + int end_hw = MSMIN(output_hw, (start_block + block_per_thread) * tile_n); + if (start_hw >= end_hw) { + return; + } + int out_stride = conv_param->output_channel_ * tile_n; + int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; + packed_input += task_id * deep * tile_n; + col_major_input += task_id * deep * tile_n; + size_t input_size = deep * tile_n * sizeof(float16_t); + + for (int b = 0; b < conv_param->input_batch_; b++) { + int in_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_; + int out_offset = b * conv_param->output_channel_ * output_hw + start_hw * conv_param->output_channel_; + for (int i = start_hw; i < end_hw; i += tile_n, out_offset += out_stride) { + int real_cal_row = MSMIN(output_hw - i, tile_n); + memset(packed_input, 0, input_size); + Im2ColPackUnitFp16(input_data + in_offset, conv_param, packed_input, real_cal_row, i); +#ifdef ENABLE_ARM64 + RowMajor2Col16MajorFp16Opt(packed_input, col_major_input, tile_n, deep); +#else + RowMajor2Col12MajorFp16Opt(packed_input, col_major_input, tile_n, deep); +#endif + MatMulFp16(col_major_input, packed_weight, output_data + out_offset, bias_data, conv_param->act_type_, deep, + real_cal_row, conv_param->output_channel_, conv_param->output_channel_, OutType_Nhwc); + } + } +} + +void ConvOutNc8hw8Fp16(const float16_t *input_data, float16_t *packed_input, const float16_t *packed_weight, + const float16_t *bias_data, float16_t *col_major_input, float16_t *output_data, int task_id, + const ConvParameter *conv_param) { +#ifdef ENABLE_ARM64 + const int tile_n = 16; +#else + const int tile_n = 12; +#endif + NNACL_CHECK_ZERO_RETURN(conv_param->op_parameter_.thread_num_); + NNACL_CHECK_ZERO_RETURN(tile_n); + int output_hw = conv_param->output_h_ * conv_param->output_w_; + int input_block = UP_DIV(output_hw, tile_n); + int block_per_thread = UP_DIV(input_block, conv_param->thread_num_); + int start_block = block_per_thread * task_id; + int end_block = MSMIN(start_block + block_per_thread, input_block); + if (start_block >= end_block) { + return; + } + int weight_block = UP_DIV(conv_param->output_channel_, C8NUM); + int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; + packed_input += deep * tile_n * task_id; + col_major_input += deep * tile_n * task_id; + size_t input_size = deep * tile_n * sizeof(float16_t); + + for (int b = 0; b < conv_param->input_batch_; b++) { + int in_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_; + for (int i = start_block; i < end_block; i++) { + int real_in_row = (i != input_block - 1) ? tile_n : output_hw - i * tile_n; + memset(packed_input, 0, input_size); + Im2ColPackUnitFp16(input_data + in_offset, conv_param, packed_input, real_in_row, i * tile_n); +#ifdef ENABLE_ARM64 + RowMajor2Col16MajorFp16Opt(packed_input, col_major_input, tile_n, deep); +#else + RowMajor2Col12MajorFp16Opt(packed_input, col_major_input, tile_n, deep); +#endif + const float16_t *cur_weight = packed_weight; + const float16_t *cur_bias = bias_data; + for (int j = 0; j < weight_block; j++, cur_weight += C8NUM * deep, cur_bias += C8NUM) { + int real_weight_row = (j != weight_block - 1) ? C8NUM : conv_param->output_channel_ - j * C8NUM; + int out_offset = j * output_hw * C8NUM + i * tile_n * real_weight_row; + MatMulFp16(col_major_input, cur_weight, output_data + out_offset, cur_bias, conv_param->act_type_, deep, + real_in_row, real_weight_row, real_weight_row, OutType_Nhwc); + } + } + } +} + +void Conv1x1OutNc8hw8MultiThreadByInputFp16(const float16_t *input, float16_t *pack_input, const float16_t *weight, + const float16_t *bias, float16_t *output, int task_id, + const MatMulParameter *param) { +#ifdef ENABLE_ARM64 + const int tile_n = 16; +#else + const int tile_n = 12; +#endif + NNACL_CHECK_ZERO_RETURN(tile_n); + NNACL_CHECK_ZERO_RETURN(param->op_parameter_.thread_num_); + int input_block = UP_DIV(param->row_, tile_n); + int weight_block = UP_DIV(param->col_, C8NUM); + + int block_per_thread = UP_DIV(input_block, param->op_parameter_.thread_num_); + int input_start_block = block_per_thread * task_id; + int input_end_block = MSMIN(input_start_block + block_per_thread, input_block); + if (input_start_block >= input_end_block) { + return; + } + input += input_start_block * tile_n * param->deep_; + pack_input += input_start_block * tile_n * param->deep_; + + int cur_row_cnt = MSMIN(block_per_thread * tile_n, param->row_ - input_start_block * tile_n); +#ifdef ENABLE_ARM64 + RowMajor2Col16MajorFp16Opt(input, pack_input, cur_row_cnt, param->deep_); +#else + RowMajor2Col12MajorFp16Opt(input, pack_input, cur_row_cnt, param->deep_); +#endif + for (int i = input_start_block; i < input_end_block; i++) { + int real_in_row = (i != input_block - 1) ? tile_n : param->row_ - i * tile_n; + const float16_t *cur_weight = weight; + const float16_t *cur_bias = bias; + for (int j = 0; j < weight_block; j++, cur_weight += C8NUM * param->deep_, cur_bias += C8NUM) { + int real_weight_row = (j != weight_block - 1) ? C8NUM : param->col_ - j * C8NUM; + int out_offset = j * param->row_ * C8NUM + i * tile_n * real_weight_row; + MatMulFp16(pack_input, cur_weight, output + out_offset, cur_bias, param->act_type_, param->deep_, real_in_row, + real_weight_row, real_weight_row, OutType_Nhwc); + } + pack_input += real_in_row * param->deep_; + } +} + +void Conv1x1OutNc8hw8MultiThreadByWeightFp16(const float16_t *input, float16_t *pack_input, const float16_t *weight, + const float16_t *bias, float16_t *output, int task_id, + const MatMulParameter *param) { +#ifdef ENABLE_ARM64 + const int tile_n = 16; +#else + const int tile_n = 12; +#endif + NNACL_CHECK_ZERO_RETURN(tile_n); + NNACL_CHECK_ZERO_RETURN(param->op_parameter_.thread_num_); + int input_block = UP_DIV(param->row_, tile_n); + int weight_block = UP_DIV(param->col_, C8NUM); + + int block_per_thread = UP_DIV(weight_block, param->op_parameter_.thread_num_); + int weight_start_block = block_per_thread * task_id; + int weight_end_block = MSMIN(weight_start_block + block_per_thread, weight_block); + if (weight_start_block >= weight_end_block) { + return; + } + for (int i = 0; i < input_block; i++) { + int real_in_row = (i != input_block - 1) ? tile_n : param->row_ - i * tile_n; + const float16_t *cur_weight = weight + weight_start_block * C8NUM * param->deep_; + const float16_t *cur_bias = bias + weight_start_block * C8NUM; + for (int j = weight_start_block; j < weight_end_block; j++, cur_weight += C8NUM * param->deep_, cur_bias += C8NUM) { + int real_weight_row = (j != weight_block - 1) ? C8NUM : param->col_ - j * C8NUM; + int out_offset = j * param->row_ * C8NUM + i * tile_n * real_weight_row; + MatMulFp16(pack_input, cur_weight, output + out_offset, cur_bias, param->act_type_, param->deep_, real_in_row, + real_weight_row, real_weight_row, OutType_Nhwc); + } + pack_input += real_in_row * param->deep_; + } +} + +// fp16 convolution winograd +void ConvWinogardFp16(const float16_t *input_data, const float16_t *trans_weight, const float16_t *bias_data, + float16_t *output_data, TmpBufferAddressFp16 *buffer_list, int task_id, + const ConvParameter *conv_param, TransFp16FuncList trans_func) { +#ifdef ENABLE_ARM64 + const int tile_num = 16; +#else + const int tile_num = 12; +#endif + NNACL_CHECK_ZERO_RETURN(conv_param->output_unit_); + NNACL_CHECK_ZERO_RETURN(conv_param->thread_num_); + int in_channel = conv_param->input_channel_; + int input_unit = conv_param->input_unit_; + int out_w_block = UP_DIV(conv_param->output_w_, conv_param->output_unit_); + int out_h_block = UP_DIV(conv_param->output_h_, conv_param->output_unit_); + int output_count = out_w_block * out_h_block; + int per_thread_num = UP_DIV(output_count, conv_param->thread_num_); + int real_tile = per_thread_num < tile_num ? per_thread_num : tile_num; + NNACL_CHECK_ZERO_RETURN(real_tile); + int output_tile_count = UP_DIV(output_count, real_tile); + int oc8 = UP_DIV(conv_param->output_channel_, C8NUM); + int input_unit_square = input_unit * input_unit; + + float16_t *trans_input = buffer_list[0] + task_id * tile_num * input_unit_square * in_channel; + float16_t *gemm_out = buffer_list[1] + task_id * tile_num * input_unit_square * oc8 * C8NUM; + float16_t *tmp_data = buffer_list[2] + task_id * input_unit_square * C8NUM; + float16_t *col_buffer = buffer_list[3] + task_id * tile_num * in_channel; + // step 1 : filter transform (pre-processed offline) + // step 2 : input transform (online) + for (int b = 0; b < conv_param->input_batch_; b++) { + int in_batch_offset = b * in_channel * conv_param->input_h_ * conv_param->input_w_; + int out_batch_offset = b * conv_param->output_channel_ * conv_param->output_h_ * conv_param->output_w_; + for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) { + int out_tile_index = thread_id * real_tile; + int cal_num = output_count - thread_id * real_tile; + cal_num = cal_num > real_tile ? real_tile : cal_num; + if (cal_num <= 0) { + return; + } + +#ifdef ENABLE_ARM64 + // Optimize input transform. Only valid for arm64, the tile num is 16. + // For arm32, the tile_num is 12. The function(InputTransform4x4Pack12Fp16) needs to be rewritten. + bool fused_pack = + (cal_num == tile_num) && (trans_func.in_step_func_ != NULL) && (trans_func.in_pack_func_ != NULL); + if (fused_pack) { + float16_t *opt_trans_input = + buffer_list[4] + task_id * tile_num * input_unit_square * UP_ROUND(in_channel, C8NUM); + WinogradInputTransformOptStepFp16(input_data + in_batch_offset, opt_trans_input, tmp_data, cal_num, + out_tile_index, out_w_block, conv_param, trans_func.in_step_func_); + + for (int w_index = 0; w_index < input_unit; w_index++) { + float16_t *src_w = opt_trans_input + w_index * input_unit * tile_num * C8NUM; + for (int c = 0; c < UP_DIV(in_channel, C8NUM); c++) { + int real_c = in_channel - c * C8NUM; + real_c = real_c > C8NUM ? C8NUM : real_c; + float16_t *src_c = src_w + c * input_unit_square * tile_num * C8NUM; + float16_t *dst_c = trans_input + c * tile_num * C8NUM; + trans_func.in_pack_func_(src_c, dst_c, C8NUM, in_channel * tile_num, real_c); + } + + for (int h_index = 0; h_index < input_unit; h_index++) { + const float16_t *gemm_input = trans_input + h_index * tile_num * in_channel; + int point_index = h_index * input_unit + w_index; + const float16_t *gemm_weight = trans_weight + point_index * in_channel * oc8 * C8NUM; + MatMulFp16(gemm_input, gemm_weight, gemm_out + point_index * C8NUM, NULL, 0, in_channel, cal_num, + oc8 * C8NUM, input_unit_square, OutType_TileC8); + } + } + } else { +#endif + WinogradInputTransformFp16(input_data + in_batch_offset, trans_input, tmp_data, cal_num, out_tile_index, + out_w_block, conv_param, trans_func.in_func_); + // step 3 : gemm + float16_t *src_ptr = trans_input; + float16_t *dst_ptr = gemm_out; + float16_t *tmp_col_ptr = col_buffer; + for (int i = 0; i < input_unit_square; ++i) { +#ifdef ENABLE_ARM64 + RowMajor2Col16MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, cal_num, in_channel); +#else + RowMajor2Col12MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, cal_num, in_channel); +#endif + MatMulFp16(tmp_col_ptr, trans_weight + i * in_channel * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, in_channel, + cal_num, oc8 * C8NUM, input_unit_square, OutType_TileC8); + } +#ifdef ENABLE_ARM64 + } +#endif + + // step 4 : output transform + if (conv_param->out_format_ != Format_NC4HW4) { // nc4hw4 + WinogradOutputNHWCTransformFp16(gemm_out, output_data + out_batch_offset, bias_data, cal_num, out_tile_index, + out_w_block, conv_param, trans_func.out_func_); + } else { + WinogradOutputNC8HW8TransformFp16(gemm_out, output_data + out_batch_offset, bias_data, cal_num, out_tile_index, + out_w_block, conv_param, trans_func.out_func_); + } + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/conv_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/conv_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..924f934eb40985eb8b878f77d70dc26be3c6bb46 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/conv_fp16.h @@ -0,0 +1,60 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_CONV_FP16_H_ +#define NNACL_FP16_CONV_FP16_H_ + +#include +#include "nnacl/conv_parameter.h" +#include "nnacl/matmul_parameter.h" +#include "nnacl/fp16/winograd_utils_fp16.h" +#include "nnacl/fp16/winograd_transform_fp16.h" + +typedef float16_t *TmpBufferAddressFp16; +typedef float16_t *MatricesFp16; + +#ifdef __cplusplus +extern "C" { +#endif +void Im2ColPackUnitFp16(const float16_t *input_data, const ConvParameter *conv_param, float16_t *packed_input, + int real_cal_num, int block_index); + +// fp16 convolution common (im2col+gemm) +void ConvFp16(const float16_t *input_data, float16_t *packed_input, const float16_t *packed_weight, + const float16_t *bias_data, float16_t *col_major_input, float16_t *output_data, int task_id, + const ConvParameter *conv_param); + +void ConvOutNc8hw8Fp16(const float16_t *input_data, float16_t *packed_input, const float16_t *packed_weight, + const float16_t *bias_data, float16_t *col_major_input, float16_t *output_data, int task_id, + const ConvParameter *conv_param); + +void Conv1x1OutNc8hw8MultiThreadByInputFp16(const float16_t *input, float16_t *pack_input, const float16_t *weight, + const float16_t *bias, float16_t *output, int task_id, + const MatMulParameter *param); + +void Conv1x1OutNc8hw8MultiThreadByWeightFp16(const float16_t *input, float16_t *pack_input, const float16_t *weight, + const float16_t *bias, float16_t *output, int task_id, + const MatMulParameter *param); + +// fp16 convolution winograd +void ConvWinogardFp16(const float16_t *input_data, const float16_t *trans_weight, const float16_t *bias_data, + float16_t *output_data, TmpBufferAddressFp16 *buffer_list, int task_id, + const ConvParameter *conv_param, TransFp16FuncList trans_func); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_CONV_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/crop_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/crop_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..cb73bc51338de628541d95502bb91823f18d7fc1 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/crop_fp16.c @@ -0,0 +1,155 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp16/crop_fp16.h" + +#include + +#include "nnacl/crop_parameter.h" + +void Fp16Crop1D(const float16_t *input, float16_t *output, int *out_shape, int64_t *in_offset, int task_id, + int thread_count) { + const int out_batch = out_shape[0]; + int64_t task_id_stride = thread_count > 1 ? UP_DIV(out_batch, thread_count) : out_batch; + if (task_id_stride <= 0) { + return; + } + int n = task_id * task_id_stride; + if (n >= out_batch) { + return; + } + const float16_t *in_ptr = input + n + in_offset[0]; + float16_t *out_ptr = output + n; + int64_t out_dist_stride = MSMIN(out_batch - task_id * task_id_stride, task_id_stride); + memcpy(out_ptr, in_ptr, sizeof(float16_t) * out_dist_stride); +} + +void Fp16Crop2D(const float16_t *input, float16_t *output, int *in_shape, int *out_shape, int64_t *in_offset, + int task_id, int thread_count) { + const int in_height = in_shape[1]; + const int out_batch = out_shape[0]; + const int out_height = out_shape[1]; + + int64_t task_id_stride = thread_count > 1 ? UP_DIV(out_height, thread_count) : out_height; + if (task_id_stride <= 0) { + return; + } + + for (int n = 0; n < out_batch; n++) { + int h = task_id * task_id_stride; + if (h >= out_height) { + return; + } + const float16_t *in_ptr = input + (n + in_offset[0]) * in_height + h + in_offset[1]; + float16_t *out_ptr = output + n * out_height + h; + int64_t out_dist_stride = MSMIN(out_height - task_id * task_id_stride, task_id_stride); + memcpy(out_ptr, in_ptr, sizeof(float16_t) * out_dist_stride); + } +} + +void Fp16Crop3D(const float16_t *input, float16_t *output, int *in_shape, int *out_shape, int64_t *in_offset, + int task_id, int thread_count) { + const int in_height = in_shape[1]; + const int in_width = in_shape[2]; + + const int out_batch = out_shape[0]; + const int out_height = out_shape[1]; + const int out_width = out_shape[2]; + + int64_t task_id_stride = thread_count > 1 ? UP_DIV(out_height, thread_count) : out_height; + if (task_id_stride <= 0) { + return; + } + + const int in_stride_h = in_width; + const int in_stride_n = in_stride_h * in_height; + + const int out_stride_h = out_width; + const int out_stride_n = out_stride_h * out_height; + + for (int n = 0; n < out_batch; n++) { + for (int t = 0; t < task_id_stride; t++) { + int h = t + task_id * task_id_stride; + if (h >= out_height) { + break; + } + const float16_t *in_ptr = + input + (n + in_offset[0]) * in_stride_n + (h + in_offset[1]) * in_stride_h + in_offset[2]; + float16_t *out_ptr = output + n * out_stride_n + h * out_stride_h; + memcpy(out_ptr, in_ptr, sizeof(float16_t) * out_width); + } + } +} + +void Fp16Crop4D(const float16_t *input, float16_t *output, int *in_shape, int *out_shape, int64_t *in_offset, + int task_id, int thread_count) { + const int in_height = in_shape[1]; + const int in_width = in_shape[2]; + const int in_channel = in_shape[3]; + + const int out_batch = out_shape[0]; + const int out_height = out_shape[1]; + const int out_width = out_shape[2]; + const int out_channel = out_shape[3]; + + int64_t task_id_stride = thread_count > 1 ? UP_DIV(out_height, thread_count) : out_height; + if (task_id_stride <= 0) { + return; + } + + const int in_stride_w = in_channel; + const int in_stride_h = in_channel * in_width; + const int in_stride_n = in_stride_h * in_height; + + const int out_stride_w = out_channel; + const int out_stride_h = out_channel * out_width; + const int out_stride_n = out_stride_h * out_height; + + for (int n = 0; n < out_batch; n++) { + for (int t = 0; t < task_id_stride; t++) { + int h = t + task_id * task_id_stride; + if (h >= out_height) { + break; + } + for (int w = 0; w < out_width; w++) { + const float16_t *in_ptr = input + (n + in_offset[0]) * in_stride_n + (h + in_offset[1]) * in_stride_h + + (w + in_offset[2]) * in_stride_w + in_offset[3]; + float16_t *out_ptr = output + n * out_stride_n + h * out_stride_h + w * out_stride_w; + memcpy(out_ptr, in_ptr, sizeof(float16_t) * out_channel); + } + } + } +} + +void Fp16Crop(const float16_t *input, float16_t *output, int *in_shape, int *out_shape, int64_t *in_offset, + int input_dim, int task_id, int thread_num) { + switch (input_dim) { + case 1: + Fp16Crop1D(input, output, out_shape, in_offset, task_id, thread_num); + break; + case 2: + Fp16Crop2D(input, output, in_shape, out_shape, in_offset, task_id, thread_num); + break; + case 3: + Fp16Crop3D(input, output, in_shape, out_shape, in_offset, task_id, thread_num); + break; + case 4: + Fp16Crop4D(input, output, in_shape, out_shape, in_offset, task_id, thread_num); + break; + default: + break; + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/crop_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/crop_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..71a149d172aae8e8e08c33b173872a3c11161ae2 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/crop_fp16.h @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_CROP_FP16_H_ +#define NNACL_FP16_CROP_FP16_H_ + +#include "nnacl/op_base.h" +#include "nnacl/crop_parameter.h" + +void Fp16Crop(const float16_t *input, float16_t *output, int *in_shape, int *out_shape, int64_t *in_offset, + int input_dim, int task_id, int thread_num); + +#endif // NNACL_FP16_CROP_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/custom_gru_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/custom_gru_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..72391811425efc00a9bb1527702c830a5f33e22d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/custom_gru_fp16.c @@ -0,0 +1,70 @@ +#ifdef ENABLE_ARM64 +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp16/custom_gru_fp16.h" +#include "nnacl/fp16/activation_fp16.h" +#include "nnacl/fp16/arithmetic_fp16.h" +#include "nnacl/fp16/matmul_fp16.h" + +void CustomGruFp16(float16_t *output, const float16_t *input, const float16_t *weight_input, + const float16_t *weight_hidden, const float16_t *bias_input, const float16_t *bias_hidden, + const float16_t *init_h, float16_t *buffer[4], const CustomGruParameter *gru_param) { + int num_step = gru_param->num_step; + int batch_size = gru_param->batch_size; + int input_size = gru_param->input_size; + int hidden_size = gru_param->hidden_size; + int output_size = batch_size * hidden_size; + int double_output_size = output_size * C2NUM; + int col_align = UP_ROUND(hidden_size, C8NUM); + int weight_in_offset = col_align * input_size; + int weight_hidden_offset = col_align * hidden_size; + float16_t *input_gate = buffer[1]; + float16_t *hidden_gate = buffer[C3NUM]; + for (int i = 0; i < num_step; ++i) { + if (batch_size != 1) { + RowMajor2ColNMajorFp16(input + i * batch_size * input_size, buffer[0], batch_size, input_size, false); + for (int j = 0; j < C3NUM; ++j) { + MatmulBaseFp16Neon(buffer[0], weight_input + j * weight_in_offset, input_gate + j * output_size, + bias_input + j * col_align, ActType_No, input_size, batch_size, hidden_size, hidden_size, + OutType_Nhwc); + } + RowMajor2ColNMajorFp16(init_h, buffer[C2NUM], batch_size, hidden_size, false); + for (int j = 0; j < C3NUM; ++j) { + MatmulBaseFp16Neon(buffer[C2NUM], weight_hidden + j * weight_hidden_offset, hidden_gate + j * output_size, + bias_hidden + j * col_align, ActType_No, hidden_size, batch_size, hidden_size, hidden_size, + OutType_Nhwc); + } + } else { + for (int j = 0; j < C3NUM; ++j) { + VecMatmulFp16(input + i * input_size, weight_input + j * weight_in_offset, input_gate + j * output_size, + bias_input + j * col_align, ActType_No, input_size, hidden_size); + VecMatmulFp16(init_h, weight_hidden + j * weight_hidden_offset, hidden_gate + j * output_size, + bias_hidden + j * col_align, ActType_No, hidden_size, hidden_size); + } + } + ElementAddFp16(input_gate, hidden_gate, input_gate, double_output_size); + SigmoidFp16(input_gate, input_gate, double_output_size); + ElementMulFp16(input_gate, hidden_gate + double_output_size, input_gate, output_size); + ElementAddFp16(input_gate, input_gate + double_output_size, input_gate, output_size); + TanhFp16(input_gate, input_gate, output_size); + ElementSubFp16(init_h, input_gate, hidden_gate, output_size); + ElementMulFp16(input_gate + output_size, hidden_gate, hidden_gate, output_size); + ElementAddFp16(input_gate, hidden_gate, output, output_size); + init_h = output; + output += output_size; + } +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/custom_gru_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/custom_gru_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..f80248d7f64aa449c66e64de9520edd3c3b837ed --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/custom_gru_fp16.h @@ -0,0 +1,32 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_CUSTOM_GRU_FP16_H_ +#define NNACL_FP16_CUSTOM_GRU_FP16_H_ +#ifdef ENABLE_ARM64 +#include "nnacl/custom_gru_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void CustomGruFp16(float16_t *output, const float16_t *input, const float16_t *weight_input, + const float16_t *weight_hidden, const float16_t *bias_input, const float16_t *bias_hidden, + const float16_t *init_h, float16_t *buffer[4], const CustomGruParameter *gru_param); +#ifdef __cplusplus +} +#endif + +#endif +#endif // NNACL_FP16_CUSTOM_GRU_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/deconv_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/deconv_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..aa94cb51f6d86202dc740106c6c7b1a805bf31e5 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/deconv_fp16.c @@ -0,0 +1,129 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp16/deconv_fp16.h" +#include + +void DeConvPostAddC8WithStride(const float16_t *source, float16_t *dest, size_t srcStride, size_t dststride, + size_t count) { + if (count == 0) { + return; + } + + const float16_t *src_ptr = source; + float16_t *dst_ptr = dest; + float16x8_t src1 = vld1q_f16(src_ptr); + float16x8_t dst1 = vld1q_f16(dst_ptr); + float16x8_t src2; + float16x8_t dst2; + size_t i = 1; + while (i < count - 1) { + dst1 = vaddq_f16(dst1, src1); + vst1q_f16(dst_ptr, dst1); + + src2 = vld1q_f16(src_ptr + srcStride); + dst2 = vld1q_f16(dst_ptr + dststride); + dst2 = vaddq_f16(dst2, src2); + vst1q_f16(dst_ptr + dststride, dst2); + i = i + 2; + src1 = vld1q_f16(src_ptr + srcStride + srcStride); + dst1 = vld1q_f16(dst_ptr + dststride + dststride); + + src_ptr = src_ptr + srcStride + srcStride; + dst_ptr = dst_ptr + dststride + dststride; + } + dst1 = vaddq_f16(dst1, src1); + vst1q_f16(dst_ptr, dst1); + if (i < count) { + src2 = vld1q_f16(src_ptr + srcStride); + dst2 = vld1q_f16(dst_ptr + dststride); + dst2 = vaddq_f16(dst2, src2); + vst1q_f16(dst_ptr + dststride, dst2); + } +} + +int DeConvPostFp16(const float16_t *src, float16_t *tmp, const float16_t *bias, float16_t *dst, int output_channel, + const ConvParameter *conv_param) { + float16x8_t min_v = vdupq_n_f16(-FLT_MAX); + float16x8_t max_v = vdupq_n_f16(FLT_MAX); + if (conv_param->act_type_ == ActType_Relu) { + min_v = vdupq_n_f16(0.f); + } + if (conv_param->act_type_ == ActType_Relu6) { + min_v = vdupq_n_f16(0.f); + max_v = vdupq_n_f16(6.f); + } + + /* row8x8-major(ih*iw x oc*kh*kw) -> row8-major(oh*ow x oc) */ + size_t input_plane = conv_param->input_w_ * conv_param->input_h_; + size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_; + size_t output_plane = conv_param->output_w_ * conv_param->output_h_; + int oc8 = UP_ROUND(output_channel, C8NUM); + int in_plane16 = UP_ROUND(input_plane, 16); + int src_iw_stride = C8NUM; + int src_ih_stride = conv_param->input_w_ * C8NUM; + int src_kw_stride = in_plane16 * C8NUM; + int src_kh_stride = in_plane16 * conv_param->kernel_w_ * C8NUM; + int dst_oh_stride = conv_param->output_w_ * C8NUM; + int dst_ow_stride = C8NUM; + int dst_kh_stride = conv_param->dilation_h_ * conv_param->output_w_ * C8NUM; + int dst_kw_stride = conv_param->dilation_w_ * C8NUM; + + NNACL_CHECK_ZERO_RETURN_ERR(conv_param->dilation_h_); + NNACL_CHECK_ZERO_RETURN_ERR(conv_param->dilation_w_); + + for (int c = 0; c < oc8; c += 8) { + float16_t *dst_ptr = tmp + c * output_plane; + const float16_t *src_ptr = src + c * in_plane16 * kernel_plane; + memset(dst_ptr, 0, output_plane * C8NUM * sizeof(float16_t)); + + for (int ih = 0; ih < conv_param->input_h_; ih++) { + for (int iw = 0; iw < conv_param->input_w_; iw++) { + int oh = ih * conv_param->stride_h_ - conv_param->pad_u_; + int ow = iw * conv_param->stride_w_ - conv_param->pad_l_; + + int kh_start = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_)); + int kh_end = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_)); + int kw_start = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_)); + int kw_end = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_)); + + const float16_t *src_in_ptr = src_ptr + ih * src_ih_stride + iw * src_iw_stride; + float16_t *dst_in_ptr = dst_ptr + oh * dst_oh_stride + ow * dst_ow_stride; + + for (int kh = kh_start; kh < kh_end; kh++) { + const float16_t *src_kh_ptr = src_in_ptr + kh * src_kh_stride; + float16_t *dst_kh_ptr = dst_in_ptr + kh * dst_kh_stride; + DeConvPostAddC8WithStride(src_kh_ptr + kw_start * src_kw_stride, dst_kh_ptr + kw_start * dst_kw_stride, + src_kw_stride, dst_kw_stride, kw_end - kw_start); + } // kh + } // iw + } // ih + + /* add bias for current oh*ow*C8 + * write to output data ptr in nhwc format */ + float16x8_t bias_v = vld1q_f16(bias + c); + float16_t *pack_tmp_data = dst_ptr; + for (size_t i = 0; i < output_plane; i++) { + float16x8_t data_v = vld1q_f16(pack_tmp_data); + data_v = vaddq_f16(data_v, bias_v); + data_v = vminq_f16(data_v, max_v); + data_v = vmaxq_f16(data_v, min_v); + vst1q_f16(pack_tmp_data, data_v); + pack_tmp_data += C8NUM; + } + } // oc8 + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/deconv_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/deconv_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..8dbb3a547fbe494f114cce4e5c4b716624ce81c4 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/deconv_fp16.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_DECONV_FP16_H_ +#define NNACL_FP16_DECONV_FP16_H_ + +#include +#include +#include "nnacl/conv_parameter.h" +#include "nnacl/errorcode.h" +#include "nnacl/fp16/common_func_fp16.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int DeConvPostFp16(const float16_t *src, float16_t *tmp, const float16_t *bias, float16_t *dst, int output_channel, + const ConvParameter *conv_param); + +#ifdef __cplusplus +} +#endif +#endif // NNACL_FP16_DECONV_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/deconv_winograd_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/deconv_winograd_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..c48ee1b09c600c9949da313b264cdc1fcdc71ca3 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/deconv_winograd_fp16.c @@ -0,0 +1,519 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp16/deconv_winograd_fp16.h" +#include "nnacl/base/minimal_filtering_generator.h" + +void DeConvWgInputPackFp16(const float16_t *src_ptr, float16_t *dst_ptr, int channel, int stride) { + int ic4div = channel / C4NUM; + int ic4mod = channel % C4NUM; + const float16_t *src = src_ptr; + float16_t *dst = dst_ptr; + + for (int ic = 0; ic < ic4div; ic++) { + vst1_f16(dst, vld1_f16(src)); + dst += stride; + src += C4NUM; + } + + if (ic4mod != 0) { + int ic_res = 0; + for (; ic_res < ic4mod; ic_res++) { + dst[ic_res] = src[ic_res]; + } + for (; ic_res < C4NUM; ic_res++) { + dst[ic_res] = 0; + } + } + return; +} + +#ifdef ENABLE_ARM82_A32 +void DeconvWgMergeFp16A32Fun(const float16_t *src_ptr, float16_t *dst_ptr, size_t src_step, size_t dst_step) { + asm volatile( + "mov r0, %[src_ptr]\n" + "mov r1, %[dst_ptr]\n" + "mov r2, r1\n" + + "vld1.16 {d0}, [r0], %[src_step]\n" + "vld1.16 {d2}, [r1], %[dst_step]\n" + "vld1.16 {d4}, [r0], %[src_step]\n" + "vld1.16 {d6}, [r1], %[dst_step]\n" + "vadd.f16 d0, d0, d2\n" + "vld1.16 {d8}, [r0], %[src_step]\n" + "vadd.f16 d4, d4, d6\n" + "vst1.16 {d0}, [r2], %[dst_step]\n" + "vst1.16 {d4}, [r2], %[dst_step]\n" + + "vld1.16 {d10}, [r1], %[dst_step]\n" + "vld1.16 {d12}, [r0], %[src_step]\n" + "vadd.f16 d8, d8, d10\n" + "vld1.16 {d14}, [r1], %[dst_step]\n" + "vadd.f16 d12, d12, d14\n" + "vld1.16 {d0}, [r0], %[src_step]\n" + "vst1.16 {d8}, [r2], %[dst_step]\n" + "vst1.16 {d12}, [r2], %[dst_step]\n" + + "vld1.16 {d2}, [r1], %[dst_step]\n" + "vld1.16 {d4}, [r0], %[src_step]\n" + "vld1.16 {d6}, [r1], %[dst_step]\n" + "vadd.f16 d0, d0, d2\n" + "vadd.f16 d4, d4, d6\n" + "vst1.16 {d0}, [r2], %[dst_step]\n" + "vst1.16 {d4}, [r2], %[dst_step]\n" + + "vld1.16 {d8}, [r0], %[src_step]\n" + "vld1.16 {d10}, [r1], %[dst_step]\n" + "vld1.16 {d12}, [r0], %[src_step]\n" + "vld1.16 {d14}, [r1], %[dst_step]\n" + "vadd.f16 d8, d8, d10\n" + "vadd.f16 d12, d12, d14\n" + "vst1.16 {d8}, [r2], %[dst_step]\n" + "vst1.16 {d12}, [r2], %[dst_step]\n" + + : + : [ src_ptr ] "r"(src_ptr), [ dst_ptr ] "r"(dst_ptr), [ src_step ] "r"(src_step), [ dst_step ] "r"(dst_step) + : "r0", "r1", "r2", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7"); +} +#endif + +void DeConvWgMergeFp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride, size_t count) { + const float16_t *src_ptr = src; + float16_t *dst_ptr = dst; + size_t cuont8 = count / C8NUM * C8NUM; + int i = 0; + for (; i < cuont8; i += C8NUM) { +#ifdef ENABLE_ARM64 + size_t src_step = src_stride * sizeof(float16_t); + size_t dst_step = dst_stride * sizeof(float16_t); + asm volatile( + "mov x7, %[src_ptr]\n" + "mov x8, %[dst_ptr]\n" + "mov x10, x8\n" + + "ld1 {v0.4h}, [x7], %[src_step]\n" + "ld1 {v1.4h}, [x8], %[dst_step]\n" + "ld1 {v2.4h}, [x7], %[src_step]\n" + "ld1 {v3.4h}, [x8], %[dst_step]\n" + "fadd v0.4h, v0.4h, v1.4h\n" + "ld1 {v4.4h}, [x7], %[src_step]\n" + "fadd v2.4h, v2.4h, v3.4h\n" + "st1 {v0.4h}, [x10], %[dst_step]\n" + "st1 {v2.4h}, [x10], %[dst_step]\n" + + "ld1 {v5.4h}, [x8], %[dst_step]\n" + "ld1 {v6.4h}, [x7], %[src_step]\n" + "fadd v4.4h, v4.4h, v5.4h\n" + "ld1 {v7.4h}, [x8], %[dst_step]\n" + "fadd v6.4h, v6.4h, v7.4h\n" + "ld1 {v0.4h}, [x7], %[src_step]\n" + "st1 {v4.4h}, [x10], %[dst_step]\n" + "st1 {v6.4h}, [x10], %[dst_step]\n" + + "ld1 {v1.4h}, [x8], %[dst_step]\n" + "ld1 {v2.4h}, [x7], %[src_step]\n" + "ld1 {v3.4h}, [x8], %[dst_step]\n" + "fadd v0.4h, v0.4h, v1.4h\n" + "fadd v2.4h, v2.4h, v3.4h\n" + "st1 {v0.4h}, [x10], %[dst_step]\n" + "st1 {v2.4h}, [x10], %[dst_step]\n" + + "ld1 {v4.4h}, [x7], %[src_step]\n" + "ld1 {v5.4h}, [x8], %[dst_step]\n" + "ld1 {v6.4h}, [x7], %[src_step]\n" + "ld1 {v7.4h}, [x8], %[dst_step]\n" + "fadd v4.4h, v4.4h, v5.4h\n" + "fadd v6.4h, v6.4h, v7.4h\n" + "st1 {v4.4h}, [x10], %[dst_step]\n" + "st1 {v6.4h}, [x10], %[dst_step]\n" + + : + : [ src_ptr ] "r"(src_ptr), [ dst_ptr ] "r"(dst_ptr), [ src_step ] "r"(src_step), [ dst_step ] "r"(dst_step) + : "x7", "x8", "x10", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"); +#elif defined(ENABLE_ARM82_A32) + size_t src_step = src_stride * sizeof(float16_t); + size_t dst_step = dst_stride * sizeof(float16_t); + DeconvWgMergeFp16A32Fun(src_ptr, dst_ptr, src_step, dst_step); +#else + for (int j = 0; j < C8NUM; j++) { + const float16_t *s = src_ptr + j * src_stride; + float16_t *d = dst_ptr + j * dst_stride; + for (int k = 0; k < C4NUM; k++) { + d[k] += s[k]; + } + } +#endif + src_ptr += C8NUM * src_stride; + dst_ptr += C8NUM * dst_stride; + } + + for (; i < count; i++) { + float16x4_t src_data = vld1_f16(src_ptr); + float16x4_t dst_data = vld1_f16(dst_ptr); + dst_data = vadd_f16(src_data, dst_data); + vst1_f16(dst_ptr, dst_data); + + src_ptr += src_stride; + dst_ptr += dst_stride; + } + return; +} + +void DeConvWgCalWgFp16(const float16_t *tile_in, float16_t *tile_out, const float16_t *weight_buf, float16_t *tmp_buf, + const float16_t *at_buf, float16_t *a_mid_buf, float16_t *trans_a_buf, bool *transferred, + const float16_t *bt_buf, float16_t *b_tmp_buf, int unit_size, int w_start, int h_start, + const ConvParameter *conv_param, const DeConvParam *deconv_param) { + int winograd_plane = unit_size * unit_size; + if (!transferred[unit_size]) { + WinogradTransLeftFp16(tile_in, at_buf, a_mid_buf, DECONV_WINOGRAD_DEFAULT_UNIT, unit_size, + DECONV_WINOGRAD_DEFAULT_UNIT, deconv_param->ic_div_ * DECONV_WINOGRAD_DEFAULT_TILE); + WinogradTransRightFp16(a_mid_buf, at_buf, trans_a_buf, unit_size, unit_size, DECONV_WINOGRAD_DEFAULT_UNIT, + deconv_param->ic_div_ * DECONV_WINOGRAD_DEFAULT_TILE); + transferred[unit_size] = true; + } + + for (int index = 0; index < winograd_plane; index++) { + float16_t *src = trans_a_buf + index * DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->ic_up_; + float16_t *dst = tmp_buf + index * deconv_param->oc_up_ * DECONV_WINOGRAD_DEFAULT_TILE; + const float16_t *weight = weight_buf + index * deconv_param->ic_up_ * deconv_param->oc_up_; + TiledC4MatmulFp16(dst, src, weight, DECONV_WINOGRAD_DEFAULT_TILE * C4NUM, deconv_param->ic_div_, + deconv_param->oc_div_); + } + + WinogradTransLeftFp16(tmp_buf, bt_buf, b_tmp_buf, unit_size, unit_size, unit_size, + deconv_param->oc_div_ * DECONV_WINOGRAD_DEFAULT_TILE); + WinogradTransRightFp16(b_tmp_buf, bt_buf, tmp_buf, unit_size, unit_size, unit_size, + deconv_param->oc_div_ * DECONV_WINOGRAD_DEFAULT_TILE); + + // Add to dest + for (int uhi = 0; uhi < unit_size; uhi++) { + int h_index = uhi * conv_param->stride_h_ + h_start; + for (int uwi = 0; uwi < unit_size; uwi++) { + int w_index = uwi * conv_param->stride_w_ + w_start; + + float16_t *dst = tile_out + w_index * DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_up_ + + h_index * deconv_param->out_tile_w_ * DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_up_; + float16_t *src = tmp_buf + (uwi + uhi * unit_size) * DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_up_; + DeConvWgMergeFp16(src, dst, C4NUM, C4NUM, DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_div_); + } + } + return; +} + +void DeConvWgCalCommFp16(const float16_t *tile_in, float16_t *tile_out, const float16_t *weight, float16_t *tmp_buf, + int h_start, int w_start, int h_size, int w_size, const ConvParameter *conv_param, + const DeConvParam *deconv_param) { + int count = deconv_param->oc_div_ * w_size * h_size; + int in_stride = DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->ic_up_; + int out_stride = DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_up_; + + for (int hi = 0; hi < DECONV_WINOGRAD_DEFAULT_UNIT; hi++) { + for (int wi = 0; wi < DECONV_WINOGRAD_DEFAULT_UNIT; wi++) { + const float16_t *src_in = tile_in + (wi + hi * DECONV_WINOGRAD_DEFAULT_UNIT) * in_stride; + TiledC4MatmulFp16(tmp_buf, src_in, weight, DECONV_WINOGRAD_DEFAULT_TILE * C4NUM, deconv_param->ic_div_, count); + + for (int uhi = 0; uhi < h_size; uhi++) { + for (int uwi = 0; uwi < w_size; uwi++) { + int w_index = (wi + uwi) * conv_param->stride_w_ + w_start; + int h_index = (hi + uhi) * conv_param->stride_h_ + h_start; + float16_t *dst = tile_out + h_index * out_stride * deconv_param->out_tile_w_ + w_index * out_stride; + float16_t *src = tmp_buf + (uwi + uhi * w_size) * out_stride; + DeConvWgMergeFp16(src, dst, C4NUM, C4NUM, DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_div_); + } + } + } + } + return; +} + +int PackDeConvWgDataFp16(const float16_t *nhwc_weight, DeConvComputeUnit *unit, const ConvParameter *conv_param, + const DeConvParam *deconv_param) { + int tmp_kernel_plane = unit->w_size_ * unit->h_size_; + int output_channel = conv_param->output_channel_; + int size = conv_param->input_channel_ * output_channel * tmp_kernel_plane; + float16_t *current_unit_weight = (float16_t *)malloc(size * sizeof(float16_t)); + if (current_unit_weight == NULL) { + return NNACL_NULL_PTR; + } + for (int ic = 0; ic < conv_param->input_channel_; ic++) { + const float16_t *src_ic = nhwc_weight + deconv_param->kernel_plane_ * output_channel * ic; + float16_t *dst_ic = current_unit_weight + tmp_kernel_plane * output_channel * ic; + for (int uhi = 0; uhi < unit->h_size_; uhi++) { + for (int uwi = 0; uwi < unit->w_size_; uwi++) { + int src_h_offset = unit->h_start_ + uhi * conv_param->stride_h_; + int src_w_offset = unit->w_start_ + uwi * conv_param->stride_w_; + const float16_t *src_hw = src_ic + (src_h_offset * conv_param->kernel_w_ + src_w_offset) * output_channel; + float16_t *dst_hw = dst_ic + (uhi * unit->w_size_ + uwi) * output_channel; + memcpy(dst_hw, src_hw, output_channel * sizeof(float16_t)); + } + } + } + + if (unit->use_winograd_) { + /* Generate winograd */ + float matrix_g[64]; + float matrix_gt[64]; + float matrix_a[64]; + float matrix_at[64]; + float matrix_b[64]; + float matrix_bt[64]; + int ret = CookToomFilter(matrix_a, matrix_at, matrix_b, matrix_bt, matrix_g, matrix_gt, 0.5f, + DECONV_WINOGRAD_DEFAULT_UNIT, unit->h_size_); + if (ret != NNACL_OK) { + free(current_unit_weight); + current_unit_weight = NULL; + return NNACL_ERRCODE_WINOGRAD_GENERATOR_ERROR; + } + + /* winograd AT */ + unit->winograd_.AT_ = malloc(unit->winograd_.i_ * unit->winograd_.o_ * sizeof(float16_t)); + if (unit->winograd_.AT_ == NULL) { + free(current_unit_weight); + current_unit_weight = NULL; + return NNACL_NULL_PTR; + } + Float32ToFloat16(matrix_at, unit->winograd_.AT_, unit->winograd_.i_ * unit->winograd_.o_); + + /* winograd BT */ + unit->winograd_.BT_ = malloc(unit->winograd_.o_ * unit->winograd_.o_ * sizeof(float16_t)); + if (unit->winograd_.BT_ == NULL) { + free(current_unit_weight); + free(unit->winograd_.AT_); + current_unit_weight = NULL; + unit->winograd_.AT_ = NULL; + return NNACL_NULL_PTR; + } + Float32ToFloat16(matrix_bt, unit->winograd_.BT_, unit->winograd_.o_ * unit->winograd_.o_); + + /* winograd Weight */ + size = conv_param->input_channel_ * output_channel * unit->winograd_.kh_ * unit->winograd_.kw_; + float16_t *winograd_unit_weight = (float16_t *)malloc(size * sizeof(float16_t)); + if (winograd_unit_weight == NULL) { + free(current_unit_weight); + free(unit->winograd_.AT_); + free(unit->winograd_.BT_); + current_unit_weight = NULL; + unit->winograd_.AT_ = NULL; + unit->winograd_.BT_ = NULL; + return NNACL_NULL_PTR; + } + + WinogradWeightTransformFp16(current_unit_weight, winograd_unit_weight, matrix_g, matrix_gt, C4NUM, + unit->winograd_.kh_, unit->h_size_, output_channel, conv_param->input_channel_, false); + + /* reset weight data & info */ + tmp_kernel_plane = unit->winograd_.kh_ * unit->winograd_.kw_; + free(current_unit_weight); + current_unit_weight = winograd_unit_weight; + winograd_unit_weight = NULL; + } + + /* trans mhwc -> hw1:k1-knc0-c4:k1-knc5-c8:hw2:k1-knc0-c4:k1 */ + float16_t *dst_weight = (float16_t *)unit->weight_; + size = deconv_param->ic_up_ * deconv_param->oc_up_ * tmp_kernel_plane; + memset(dst_weight, 0, size * sizeof(float16_t)); + for (int ic = 0; ic < conv_param->input_channel_; ic++) { + for (int oc = 0; oc < output_channel; oc++) { + int oc4div = oc / C4NUM, oc4mod = oc % C4NUM; + for (int upi = 0; upi < tmp_kernel_plane; upi++) { + int src_index = ic * output_channel * tmp_kernel_plane + upi * output_channel + oc; + int dst_index = upi * deconv_param->oc_up_ * deconv_param->ic_up_ + oc4div * C4NUM * deconv_param->ic_up_ + + ic * C4NUM + oc4mod; + dst_weight[dst_index] = current_unit_weight[src_index]; + } + } + } + + free(current_unit_weight); + current_unit_weight = NULL; + return NNACL_OK; +} + +void DeconvWgFp16(const float16_t *nhwc_input_, float16_t *tile_in, float16_t *tile_out, int start_index, + int calculate_count, const ConvParameter *conv_param, DeConvParam *deconv_param, int task_id) { + NNACL_CHECK_ZERO_RETURN(deconv_param->in_tile_w_count_); + /* pack tile input */ + int tile_in_unit_stride = deconv_param->ic_up_ * DECONV_WINOGRAD_DEFAULT_TILE; + float16x4_t zero = vdup_n_f16(0.0f); + + for (int unit_index = 0; unit_index < calculate_count; unit_index++) { + int plane_index = start_index + unit_index; + int w_unit_index = plane_index % deconv_param->in_tile_w_count_; + int h_unit_index = plane_index / deconv_param->in_tile_w_count_; + int w_start = w_unit_index * DECONV_WINOGRAD_DEFAULT_UNIT; + int h_start = h_unit_index * DECONV_WINOGRAD_DEFAULT_UNIT; + + float16_t *dst_unit = tile_in + unit_index * C4NUM; + for (int hi = 0; hi < DECONV_WINOGRAD_DEFAULT_UNIT; hi++) { + for (int wi = 0; wi < DECONV_WINOGRAD_DEFAULT_UNIT; wi++) { + float16_t *dst = dst_unit + (wi + hi * DECONV_WINOGRAD_DEFAULT_UNIT) * tile_in_unit_stride; + int w_index = w_start + wi; + int h_index = h_start + hi; + if (w_index >= conv_param->input_w_ || h_index >= conv_param->input_h_) { + for (int ic4_index = 0; ic4_index < deconv_param->ic_div_; ic4_index++) { + vst1_f16(dst + ic4_index * DECONV_WINOGRAD_DEFAULT_TILE * C4NUM, zero); + } + continue; + } + + const float16_t *src = nhwc_input_ + (w_index + h_index * conv_param->input_w_) * conv_param->input_channel_; + DeConvWgInputPackFp16(src, dst, conv_param->input_channel_, DECONV_WINOGRAD_DEFAULT_TILE * C4NUM); + } + } + } + + /* compute */ + bool transferred[DECONV_WINOGRAD_BUFFER_COUNT] = {false}; + for (int i = 0; i < deconv_param->compute_size_; i++) { + DeConvComputeUnit *unit = &deconv_param->compute_units_[i]; + if (unit->use_winograd_) { + float16_t *tmp_buf = (float16_t *)unit->tmp_buffer_ + task_id * unit->winograd_.kh_ * unit->winograd_.kw_ * + deconv_param->oc_up_ * DECONV_WINOGRAD_DEFAULT_TILE; + + /* winograd a buffer */ + if (unit->winograd_.kh_ >= DECONV_WINOGRAD_BUFFER_COUNT) { + return; + } + DeConvWgABuffer *tmp_a = &deconv_param->a_buffer_[unit->winograd_.kh_]; + float16_t *mid_a = (float16_t *)tmp_a->middle_buffer_ + task_id * unit->winograd_.kw_ * unit->winograd_.kh_ * + DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->ic_up_; + float16_t *dst_a = (float16_t *)tmp_a->dest_buffer_ + task_id * unit->winograd_.kw_ * unit->winograd_.kh_ * + DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->ic_up_; + float16_t *tmp_b = (float16_t *)unit->winograd_.b_buffer_ + task_id * unit->winograd_.kh_ * unit->winograd_.kw_ * + DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_up_; + DeConvWgCalWgFp16(tile_in, tile_out, (float16_t *)unit->weight_, tmp_buf, unit->winograd_.AT_, mid_a, dst_a, + transferred, unit->winograd_.BT_, tmp_b, unit->winograd_.kh_, unit->w_start_, unit->h_start_, + conv_param, deconv_param); + } else { + float16_t *tmp_buf = (float16_t *)unit->tmp_buffer_ + task_id * deconv_param->oc_div_ * unit->w_size_ * + unit->h_size_ * DECONV_WINOGRAD_DEFAULT_TILE * C4NUM; + DeConvWgCalCommFp16(tile_in, tile_out, (float16_t *)unit->weight_, tmp_buf, unit->h_start_, unit->w_start_, + unit->h_size_, unit->w_size_, conv_param, deconv_param); + } + } + return; +} + +void DeconvWgPostFp16(const float16_t *tile_out, float16_t *nc4hw4_output, const ConvParameter *conv_param, + const DeConvParam *deconv_param, int calculate_count, int tile_index) { + NNACL_CHECK_ZERO_RETURN(deconv_param->in_tile_w_count_); + /* merge */ + int src_unit_stride = deconv_param->oc_up_ * DECONV_WINOGRAD_DEFAULT_TILE; + + int src_stride = DECONV_WINOGRAD_DEFAULT_TILE * C4NUM; + int dst_stride = conv_param->output_w_ * conv_param->output_h_ * C4NUM; + + for (int index = 0; index < calculate_count; ++index) { + const float16_t *src_start = tile_out + index * C4NUM; + + int plane_index = tile_index * DECONV_WINOGRAD_DEFAULT_TILE + index; + int w_unit_index = plane_index % deconv_param->in_tile_w_count_; + int h_unit_index = plane_index / deconv_param->in_tile_w_count_; + int w_start = w_unit_index * DECONV_WINOGRAD_DEFAULT_UNIT * conv_param->stride_w_ - conv_param->pad_l_; + int h_start = h_unit_index * DECONV_WINOGRAD_DEFAULT_UNIT * conv_param->stride_h_ - conv_param->pad_u_; + float16_t *dst_start = nc4hw4_output + h_start * conv_param->output_w_ * C4NUM + w_start * C4NUM; + + int merge_w_start = MSMAX(-w_start, 0); + int merge_h_start = MSMAX(-h_start, 0); + int merge_h_end = MSMIN(deconv_param->out_tile_h_, conv_param->output_h_ - h_start); + int merge_w_end = MSMIN(deconv_param->out_tile_w_, conv_param->output_w_ - w_start); + + for (int hi = merge_h_start; hi < merge_h_end; hi++) { + for (int wi = merge_w_start; wi < merge_w_end; wi++) { + const float16_t *src = src_start + (hi * deconv_param->out_tile_w_ + wi) * src_unit_stride; + float16_t *dst = dst_start + (hi * conv_param->output_w_ + wi) * C4NUM; + DeConvWgMergeFp16(src, dst, src_stride, dst_stride, deconv_param->oc_div_); + } + } + } + return; +} + +#ifndef ENABLE_ARM +void WinogradTransLeftFp16(const float16_t *S, const float16_t *B, float16_t *M, size_t w, size_t h, size_t k, + size_t length) { + const int unitStep = C4NUM * length; + for (int y = 0; y < h; ++y) { + float16_t *dstY = M + y * w * unitStep; + for (int x = 0; x < w; ++x) { + float16_t *dstX = dstY + x * unitStep; + const float16_t *srcX = S + x * unitStep; + memset(dstX, 0, unitStep * sizeof(float16_t)); + for (int i = 0; i < k; ++i) { + float16_t b = B[i * h + y]; + const float16_t *srcY = srcX + i * w * unitStep; + if (0.0f == b) { + continue; + } + for (int j = 0; j < unitStep; ++j) { + dstX[j] += srcY[j] * b; + } + } + } + } +} + +void WinogradTransRightFp16(const float16_t *S, const float16_t *B, float16_t *M, size_t w, size_t h, size_t k, + size_t length) { + const int unitStep = C4NUM * length; + for (int y = 0; y < h; ++y) { + float16_t *dstY = M + y * w * unitStep; + const float16_t *srcY = S + y * k * unitStep; + + for (int x = 0; x < w; ++x) { + float16_t *dstX = dstY + x * unitStep; + memset(dstX, 0, unitStep * sizeof(float16_t)); + for (int i = 0; i < k; ++i) { + const float16_t *srcX = srcY + i * unitStep; + float16_t b = B[i * h + x]; + if (0.0f == b) { + continue; + } + for (int j = 0; j < unitStep; ++j) { + dstX[j] += srcX[j] * b; + } + } + } + } +} + +void TiledC4MatmulFp16(float16_t *dst, const float16_t *src, const float16_t *weight, size_t cal_num, size_t ic4, + size_t oc4) { + int dx, sz, dz; + int src_depth_step = C4NUM * DECONV_WINOGRAD_DEFAULT_TILE; + for (dz = 0; dz < oc4; ++dz) { + float16_t *dst_z = dst + dz * cal_num; + const float16_t *weight_dz = weight + dz * ic4 * 16; + for (dx = 0; dx < DECONV_WINOGRAD_DEFAULT_TILE; ++dx) { + float16_t *dst_x = dst_z + dx * C4NUM; + dst_x[0] = 0.0f; + dst_x[1] = 0.0f; + dst_x[2] = 0.0f; + dst_x[3] = 0.0f; + const float16_t *src_dx = src + C4NUM * dx; + for (sz = 0; sz < ic4; ++sz) { + const float16_t *src_z = src_dx + sz * src_depth_step; + const float16_t *weight_z = weight_dz + sz * 16; + for (int i = 0; i < C4NUM; ++i) { + for (int j = 0; j < C4NUM; ++j) { + dst_x[j] += src_z[i] * weight_z[C4NUM * i + j]; + } + } + } + } + } +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/deconv_winograd_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/deconv_winograd_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..dd65a6d5bcce4945171f7077d1fd0015601e349d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/deconv_winograd_fp16.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_DECONV_WINOGRAD_FP16_H_ +#define NNACL_FP16_DECONV_WINOGRAD_FP16_H_ + +#include "nnacl/fp16/winograd_transform_fp16.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int PackDeConvWgDataFp16(const float16_t *nhwc_weight, DeConvComputeUnit *unit, const ConvParameter *conv_param, + const DeConvParam *deconv_param); + +void DeconvWgFp16(const float16_t *nhwc_input_, float16_t *tile_in, float16_t *tile_out, int start_index, + int calculate_count, const ConvParameter *conv_param, DeConvParam *deconv_param, int task_id); + +void DeconvWgPostFp16(const float16_t *tile_out, float16_t *nc4hw4_output, const ConvParameter *conv_param, + const DeConvParam *deconv_param, int calculate_count, int tile_index); + +void TiledC4MatmulFp16(float16_t *dst, const float16_t *src, const float16_t *weight, size_t ic4, size_t cal_num, + size_t oc4); + +void WinogradTransLeftFp16(const float16_t *S, const float16_t *B, float16_t *M, size_t w, size_t h, size_t k, + size_t length); + +void WinogradTransRightFp16(const float16_t *S, const float16_t *B, float16_t *M, size_t w, size_t h, size_t k, + size_t length); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_DECONV_WINOGRAD_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/dynamic_quant_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/dynamic_quant_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..feab8214eccb68beae5f611223d6e78427b1c0a3 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/dynamic_quant_fp16.c @@ -0,0 +1,42 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp16/dynamic_quant_fp16.h" + +void CalculateMinMaxFp16(const float16_t *data, int count, float16_t *real_min, float16_t *real_max) { +#ifndef ENABLE_ARM64 + for (int i = 0; i < count; ++i) { + if (data[i] < *real_min) { + *real_min = data[i]; + } + if (data[i] > *real_max) { + *real_max = data[i]; + } + } +#else + // avoid to compile optimize. + volatile int count_8 = DOWN_ROUND(count, C8NUM); + CalculateMinMaxCount8Fp16(data, count_8, real_min, real_max); + for (int i = count_8; i < count; ++i) { + if (data[i] < *real_min) { + *real_min = data[i]; + } + if (data[i] > *real_max) { + *real_max = data[i]; + } + } +#endif +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/dynamic_quant_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/dynamic_quant_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..c2ed03624017603e153310a4a3b7a83994236ee5 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/dynamic_quant_fp16.h @@ -0,0 +1,35 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_DYNAMIC_QUANT_FP16_H_ +#define NNACL_INT8_DYNAMIC_QUANT_FP16_H_ + +#include "nnacl/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +void CalculateMinMaxFp16(const float16_t *data, int count, float16_t *real_min, float16_t *real_max); + +#ifdef ENABLE_ARM64 +void CalculateMinMaxCount8Fp16(const float16_t *data, int count_8, float16_t *real_min, float16_t *real_max); +#endif + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_DYNAMIC_QUANT_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/exp_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/exp_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..93f005c8eb76f43c2e61572a43b2816a5141d557 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/exp_fp16.c @@ -0,0 +1,88 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp16/exp_fp16.h" +#include +#include +#include "nnacl/errorcode.h" + +#if defined(ENABLE_NEON) +static inline void simd_exp_fp16(float16x8_t input, float16_t *dst) { + static float16x8_t maxv = {88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f, + 88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f}; + static float16x8_t minv = {-87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f, + -87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f}; + input = vmaxq_f16(minv, vminq_f16(input, maxv)); + vst1q_f16(dst, VexpFp16(input)); +} +#endif + +void ExpFp16(const float16_t *src, float16_t *dst, int num) { + int i = 0; +#ifdef ENABLE_NEON + int count = (num / C8NUM) * C8NUM; + for (; i < count; i += C8NUM) { + simd_exp_fp16(vld1q_f16(src + i), dst + i); + } +#endif + for (; i < num; ++i) { + single_exp_fp16(src[i], dst + i); + } +} + +int ExpFusionFp16(const void *src_data, void *dst_data, const ExpStruct *exp, int task_id) { + NNACL_CHECK_ZERO_RETURN_ERR(exp->base_.thread_nr_); + ExpParameter *param = (ExpParameter *)exp->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + + float16_t *src = (float16_t *)src_data; + float16_t *dst = (float16_t *)dst_data; + int stride = UP_DIV(exp->element_num_, exp->base_.thread_nr_); + int start = stride * task_id; + int end = MSMIN(exp->element_num_, start + stride); + int num = end - start; + + if (param->scale_ == 1) { + ExpFp16(src + start, dst + start, num); + } else { + int i = 0; +#ifdef ENABLE_ARM64 + MS_FLOAT16X8 scale = MS_MOVQ_F16(exp->in_scale_); + int count = (num / C8NUM) * C8NUM; + for (; i < count; i += C8NUM) { + simd_exp_fp16(MS_MULQ_F16(MS_LDQ_F16(src + i), scale), dst + i); + } +#endif + for (; i < num; ++i) { + single_exp_fp16(src[i] * exp->in_scale_, dst + i); + } + } + if (exp->out_scale_ != 1) { + int i = 0; +#ifdef ENABLE_ARM64 + MS_FLOAT16X8 scale = MS_MOVQ_F16(exp->out_scale_); + int count = (num / C8NUM) * C8NUM; + for (; i < count; i += C8NUM) { + simd_exp_fp16(MS_LDQ_F16(src + i), dst + i); + MS_STQ_F16(dst + i, MS_MULQ_F16(MS_LDQ_F16(dst + i), scale)); + } +#endif + for (; i < num; ++i) { + single_exp_fp16(src[i], dst + i); + dst[i] *= exp->out_scale_; + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/exp_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/exp_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..dc451abcd69ac9fef1fd2d67a2781d8212b6c293 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/exp_fp16.h @@ -0,0 +1,64 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_EXP_FP16_H_ +#define NNACL_FP16_EXP_FP16_H_ + +#include "nnacl/op_base.h" +#include "nnacl/kernel/exp.h" +#include "nnacl/exp_parameter.h" +#include "nnacl/fp32/exp_fp32.h" +#include "nnacl/intrinsics/ms_simd_instructions_fp16.h" + +#ifdef __cplusplus +extern "C" { +#endif +void ExpFp16(const float16_t *src, float16_t *dst, int num); +int ExpFusionFp16(const void *src_data, void *dst_data, const ExpStruct *exp, int task_id); + +#ifdef ENABLE_NEON +static inline float16x8_t VexpFp16(float16x8_t input) { + float32x4_t input_low = MS_CVT_F32_F16(vget_low_f16(input)); + float32x4_t input_high = MS_CVT_F32_F16(vget_high_f16(input)); + return vcombine_f16(MS_CVT_F16_F32(VexpFp32(input_low)), MS_CVT_F16_F32(VexpFp32(input_high))); +} +#endif + +static inline void single_exp_fp16(float16_t src, float16_t *dst) { + static float param[] = {0.693147f, 1.0f / 120, 1.0f / 24, 1.0f / 6, 1.0f / 2, 1.0f}; + int integer; + if (src > 0) { + src = MSMIN(88.72283935546875f, src); + integer = (float)src * 1.44269504088896341f + 0.5f; + } else { + src = MSMAX(-87.3365478515625f, src); + integer = (float)src * 1.44269504088896341f - 0.5f; + } + const int shift = 23; + const int bias = 126; + const float factor = 2; + float decimal = (float)src - integer * param[0]; + int int_exp = (integer + bias) << shift; + const float decimal_exp = + 1.0f + decimal * (1.0f + decimal * (0.5f + decimal * (param[3] + decimal * (param[2] + decimal * param[1])))); + float *tmp = (float *)(&int_exp); + *dst = (float16_t)(*(tmp)*decimal_exp * factor); +} + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_EXP_FP16_H_ diff --git a/mindspore-lite/examples/quick_start_ios/mindspore-lite/AppDelegate.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/fill_fp16.c similarity index 76% rename from mindspore-lite/examples/quick_start_ios/mindspore-lite/AppDelegate.h rename to mindspore-lite/ops/kernel/cpu/nnacl/fp16/fill_fp16.c index 6aa7efdb40531d1088a60db9b27c1874aceb0d96..7dc59b676a1a3be7d3a3dc52bae0aece98c3212b 100644 --- a/mindspore-lite/examples/quick_start_ios/mindspore-lite/AppDelegate.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/fill_fp16.c @@ -14,10 +14,11 @@ * limitations under the License. */ -#import - -@interface AppDelegate : UIResponder - - -@end +#include "nnacl/fp16/fill_fp16.h" +inline int FillFp16(float16_t *output, int size, float16_t data) { + for (int i = 0; i < size; ++i) { + output[i] = data; + } + return NNACL_OK; +} diff --git a/mindspore-lite/examples/quick_start_ios/mindspore-lite/main.m b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/fill_fp16.h similarity index 62% rename from mindspore-lite/examples/quick_start_ios/mindspore-lite/main.m rename to mindspore-lite/ops/kernel/cpu/nnacl/fp16/fill_fp16.h index afc843df79f9a3bebcd0c6f37f14ab1b2ff65418..2375884d9aed3ba3b5b9742790de82e43f1f47a0 100644 --- a/mindspore-lite/examples/quick_start_ios/mindspore-lite/main.m +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/fill_fp16.h @@ -13,15 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#ifndef NNACL_FP16_FILL_FP16_H_ +#define NNACL_FP16_FILL_FP16_H_ -#import -#import "AppDelegate.h" +#include "nnacl/op_base.h" +#include "nnacl/errorcode.h" +#include "nnacl/fill_parameter.h" +#ifdef ENABLE_ARM +#include +#endif -int main(int argc, char * argv[]) { - NSString * appDelegateClassName; - @autoreleasepool { - // Setup code that might create autoreleased objects goes here. - appDelegateClassName = NSStringFromClass([AppDelegate class]); - } - return UIApplicationMain(argc, argv, nil, appDelegateClassName); +#ifdef __cplusplus +extern "C" { +#endif +int FillFp16(float16_t *output, int size, float16_t data); +#ifdef __cplusplus } +#endif + +#endif // NNACL_FP16_FILL_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/gru_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/gru_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..5d83c0a9ea84569a17bcd7a98aa02071ae78433f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/gru_fp16.c @@ -0,0 +1,148 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp16/gru_fp16.h" +#include +#include "nnacl/fp16/lstm_fp16.h" +#include "nnacl/fp16/activation_fp16.h" +#include "nnacl/fp16/arithmetic_fp16.h" +#include "nnacl/fp16/matmul_fp16.h" + +void GruStepUnitFp16(float16_t *output, float16_t *update_gate, float16_t *reset_gate, float16_t *hidden_buffer, + const float16_t *state_weight, const float16_t *state_bias, float16_t *hidden_state, + float16_t *buffer[4], const GruParameter *gru_param) { + float16_t *packed_state = buffer[2]; + float16_t *state_gate = buffer[3]; + bool is_vec = gru_param->batch_ == 1; + + const float16_t *state_update_weight = state_weight; + const float16_t *state_reset_weight = state_weight + gru_param->hidden_size_ * gru_param->hidden_size_; + const float16_t *state_hidden_weight = state_weight + gru_param->hidden_size_ * gru_param->hidden_size_ * 2; + float16_t *state_update_gate = state_gate; + float16_t *state_reset_gate = state_gate + gru_param->batch_ * gru_param->hidden_size_; + float16_t *state_hidden_buffer = state_gate + gru_param->batch_ * gru_param->hidden_size_ * 2; + const float16_t *state_update_bias = state_bias; + const float16_t *state_reset_bias = state_bias + gru_param->hidden_size_; + const float16_t *state_hidden_bias = state_bias + gru_param->hidden_size_ * 2; + + // state * weight + if (is_vec) { + LstmMatMulFp16(state_reset_gate, hidden_state, state_reset_weight, state_reset_bias, gru_param->batch_, + gru_param->hidden_size_, gru_param->hidden_size_, is_vec); + LstmMatMulFp16(state_update_gate, hidden_state, state_update_weight, state_update_bias, gru_param->batch_, + gru_param->hidden_size_, gru_param->hidden_size_, is_vec); + } else { + RowMajor2Col16MajorFp16(hidden_state, packed_state, gru_param->batch_, gru_param->hidden_size_, false); + LstmMatMulFp16(state_reset_gate, packed_state, state_reset_weight, state_reset_bias, gru_param->batch_, + gru_param->hidden_size_, gru_param->hidden_size_, is_vec); + LstmMatMulFp16(state_update_gate, packed_state, state_update_weight, state_update_bias, gru_param->batch_, + gru_param->hidden_size_, gru_param->hidden_size_, is_vec); + } + ElementAddFp16(update_gate, state_update_gate, update_gate, gru_param->batch_ * gru_param->hidden_size_); + ElementAddFp16(reset_gate, state_update_gate + gru_param->batch_ * gru_param->hidden_size_, reset_gate, + gru_param->batch_ * gru_param->hidden_size_); + + // update reset_gate + SigmoidFp16(reset_gate, reset_gate, gru_param->batch_ * gru_param->hidden_size_); + + // update update_gate + SigmoidFp16(update_gate, update_gate, gru_param->batch_ * gru_param->hidden_size_); + + ElementMulFp16(hidden_state, reset_gate, reset_gate, gru_param->batch_ * gru_param->hidden_size_); + if (is_vec) { + LstmMatMulFp16(state_hidden_buffer, reset_gate, state_hidden_weight, state_hidden_bias, gru_param->batch_, + gru_param->hidden_size_, gru_param->hidden_size_, is_vec); + } else { + RowMajor2Col16MajorFp16(reset_gate, packed_state, gru_param->batch_, gru_param->hidden_size_, false); + LstmMatMulFp16(state_hidden_buffer, packed_state, state_hidden_weight, state_hidden_bias, gru_param->batch_, + gru_param->hidden_size_, gru_param->hidden_size_, is_vec); + } + ElementAddFp16(hidden_buffer, state_hidden_buffer, hidden_buffer, gru_param->batch_ * gru_param->hidden_size_); + + TanhFp16(hidden_buffer, hidden_buffer, gru_param->batch_ * gru_param->hidden_size_); + + ElementMulFp16(update_gate, hidden_state, hidden_state, gru_param->batch_ * gru_param->hidden_size_); + + float16_t one = 1.0f; + ElementOptSubFp16(&one, update_gate, update_gate, gru_param->batch_ * gru_param->hidden_size_, true); + + ElementMulAccFp16(update_gate, hidden_buffer, hidden_state, gru_param->batch_ * gru_param->hidden_size_); + + memcpy(output, hidden_state, gru_param->batch_ * gru_param->hidden_size_ * sizeof(float16_t)); +} + +void GruUnidirectionalFp16(float16_t *output, const float16_t *packed_input, const float16_t *weight_g, + const float16_t *weight_r, const float16_t *input_bias, const float16_t *state_bias, + float16_t *hidden_state, float16_t *buffer[4], const GruParameter *gru_param, + bool is_backward) { + float16_t *gate = buffer[1]; + for (int i = 0; i < 3; i++) { + const float16_t *weight_loop = weight_g + gru_param->input_size_ * gru_param->input_col_align_ * i; + const float16_t *bias_loop = input_bias + gru_param->input_col_align_ * i; + float16_t *gate_loop = gate + gru_param->seq_len_ * gru_param->batch_ * gru_param->hidden_size_ * i; + MatMulFp16(packed_input, weight_loop, gate_loop, bias_loop, ActType_No, gru_param->input_size_, + gru_param->seq_len_ * gru_param->batch_, gru_param->hidden_size_, gru_param->hidden_size_, OutType_Nhwc); + } + + float16_t *update_gate = gate; + float16_t *reset_gate = gate + gru_param->seq_len_ * gru_param->batch_ * gru_param->hidden_size_; + float16_t *hidden_buffer = gate + gru_param->seq_len_ * gru_param->batch_ * gru_param->hidden_size_ * 2; + for (int t = 0; t < gru_param->seq_len_; t++) { + int real_t = is_backward ? gru_param->seq_len_ - t - 1 : t; + float16_t *update_gate_t = update_gate + gru_param->batch_ * gru_param->hidden_size_ * real_t; + float16_t *reset_gate_t = reset_gate + gru_param->batch_ * gru_param->hidden_size_ * real_t; + float16_t *hidden_buffer_t = hidden_buffer + gru_param->batch_ * gru_param->hidden_size_ * real_t; + float16_t *output_ptr = output + real_t * gru_param->output_step_; + GruStepUnitFp16(output_ptr, update_gate_t, reset_gate_t, hidden_buffer_t, weight_r, state_bias, hidden_state, + buffer, gru_param); + } +} + +void GruFp16(float16_t *output, const float16_t *input, const float16_t *weight_g, const float16_t *weight_r, + const float16_t *input_bias, const float16_t *state_bias, float16_t *hidden_state, float16_t *buffer[4], + int check_seq_len, const GruParameter *gru_param) { + // forward + float16_t *packed_input = buffer[0]; + RowMajor2Col16MajorFp16(input, packed_input, gru_param->seq_len_ * gru_param->batch_, gru_param->input_size_, false); + GruUnidirectionalFp16(output, packed_input, weight_g, weight_r, input_bias, state_bias, hidden_state, buffer, + gru_param, false); + // zero out extra fw outputs + for (int t = check_seq_len; t < gru_param->seq_len_; t++) { + float16_t *output_ptr = output + t * gru_param->output_step_; + for (int i = 0; i < gru_param->batch_ * gru_param->hidden_size_; i++) { + output_ptr[i] = 0.0f; + } + } + + // backward + if (gru_param->bidirectional_) { + const float16_t *backward_weight_g = weight_g + 3 * gru_param->input_col_align_ * gru_param->input_size_; + const float16_t *backward_weight_r = weight_r + 3 * gru_param->state_col_align_ * gru_param->hidden_size_; + const float16_t *backward_input_bias = input_bias + 3 * gru_param->input_col_align_; + const float16_t *backward_state_bias = state_bias + 3 * gru_param->state_col_align_; + float16_t *backward_output = output + gru_param->batch_ * gru_param->hidden_size_; + float16_t *backward_hidden_state = hidden_state + gru_param->batch_ * gru_param->hidden_size_; + GruUnidirectionalFp16(backward_output, packed_input, backward_weight_g, backward_weight_r, backward_input_bias, + backward_state_bias, backward_hidden_state, buffer, gru_param, true); + + // zero out extra bw outputs + for (int t = gru_param->seq_len_ - 1; t >= check_seq_len; t--) { + float16_t *output_ptr = backward_output + t * gru_param->output_step_; + for (int i = 0; i < gru_param->batch_ * gru_param->hidden_size_; i++) { + output_ptr[i] = 0.0f; + } + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/gru_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/gru_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..4f790d154b03a505daba4a7683e65bbaef9ffc3d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/gru_fp16.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_GRU_H_ +#define NNACL_FP16_GRU_H_ +#include "nnacl/gru_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void GruFp16(float16_t *output, const float16_t *input, const float16_t *weight_g, const float16_t *weight_r, + const float16_t *input_bias, const float16_t *state_bias, float16_t *hidden_state, float16_t *buffer[4], + int check_seq_len, const GruParameter *gru_param); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_GRU_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/instance_norm_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/instance_norm_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..280834ea9324a765fc8d70abe4cb3535a5db185c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/instance_norm_fp16.c @@ -0,0 +1,217 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp16/instance_norm_fp16.h" +#include +#include "nnacl/errorcode.h" +#include "nnacl/intrinsics/ms_simd_instructions_fp16.h" + +int InstanceNormFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *gamma_data, + const float16_t *beta_data, const InstanceNormParameter *param, size_t task_id) { + NNACL_CHECK_NULL_RETURN_ERR(src_data); + NNACL_CHECK_NULL_RETURN_ERR(dst_data); + NNACL_CHECK_ZERO_RETURN_ERR(param->op_parameter_.thread_num_); + int channel = param->channel_; + int hw_plane = param->inner_size_; + NNACL_CHECK_ZERO_RETURN_ERR(hw_plane); + int channel_step = UP_DIV(channel, param->op_parameter_.thread_num_); + int channel_begin = task_id * channel_step; + int channel_end = MSMIN(channel_begin + channel_step, channel); + + for (int b = 0; b < param->batch_; b++) { + const float16_t *src_b = src_data + b * channel * hw_plane; + float16_t *dst_b = dst_data + b * channel * hw_plane; + for (int c = channel_begin; c < channel_end; c++) { + const float16_t *src = src_b + c * hw_plane; + float16_t *dst = dst_b + c * hw_plane; + float mean = 0.0f; + float square_mean = 0.0f; + + int index = 0; + for (; index <= hw_plane - C8NUM; index += C8NUM) { + float16x8_t srcv = vld1q_f16(src + index); + float16x8_t squarev = vmulq_f16(srcv, srcv); + + float16x4_t sum2 = vadd_f16(vget_low_f16(srcv), vget_high_f16(srcv)); + float32x4_t sum_f32 = vcvt_f32_f16(sum2); + mean += MS_ADDVQ_F32(sum_f32); + + float16x4_t square2 = vadd_f16(vget_low_f16(squarev), vget_high_f16(squarev)); + float32x4_t square_f32 = vcvt_f32_f16(square2); + square_mean += MS_ADDVQ_F32(square_f32); + } + for (; index < hw_plane; index++) { + mean += src[index]; + square_mean += src[index] * src[index]; + } + + mean /= (float)hw_plane; + square_mean /= (float)hw_plane; + const float deno = 1 / sqrtf(square_mean - mean * mean + param->epsilon_); + + index = 0; + float16x8_t meanv = vdupq_n_f16(mean); + float16x8_t denov = vdupq_n_f16(deno); + for (; index <= hw_plane - C8NUM; index += C8NUM) { + float16x8_t srcv = vld1q_f16(src + index); + float16x8_t outv = vsubq_f16(srcv, meanv); + outv = vmulq_f16(outv, denov); + + float16x8_t gammav = vdupq_n_f16(gamma_data[c]); + float16x8_t betav = vdupq_n_f16(beta_data[c]); + outv = vmulq_f16(outv, gammav); + outv = vaddq_f16(outv, betav); + vst1q_f16(dst + index, outv); + } + for (; index < hw_plane; index++) { + dst[index] = (src[index] - mean) * deno; + dst[index] = dst[index] * gamma_data[c] + beta_data[c]; + } + } + } + return NNACL_OK; +} + +int InstanceNormNC8HW8Fp16(const float16_t *src_data, float16_t *dst_data, const float16_t *gamma_data, + const float16_t *beta_data, const InstanceNormParameter *param, size_t task_id) { + NNACL_CHECK_NULL_RETURN_ERR(src_data); + NNACL_CHECK_NULL_RETURN_ERR(dst_data); + NNACL_CHECK_ZERO_RETURN_ERR(param->op_parameter_.thread_num_); + int channel = param->channel_; + int hw_plane = param->inner_size_; + NNACL_CHECK_ZERO_RETURN_ERR(hw_plane); + int channel_step = UP_DIV(UP_DIV(channel, C8NUM), param->op_parameter_.thread_num_) * C8NUM; + int channel_begin = (int)(task_id)*channel_step; + int channel_end = MSMIN(channel_begin + channel_step, channel); + int c8_down = channel_end / C8NUM * C8NUM; + int c_res = channel_end - c8_down; + float32x4_t hw_plane_4 = vdupq_n_f32(hw_plane); + for (int b = 0; b < param->batch_; b++) { + const float16_t *src_b = src_data + b * channel * hw_plane; + float16_t *dst_b = dst_data + b * channel * hw_plane; + int c = channel_begin; + for (; c <= channel_end - C16NUM; c += C16NUM) { + const float16_t *src = src_b + c * hw_plane; + const float16_t *src1 = src_b + (c + C8NUM) * hw_plane; + float16_t *dst = dst_b + c; + float32x4_t mean1 = vdupq_n_f32(0.0f); + float32x4_t mean2 = vdupq_n_f32(0.0f); + float32x4_t mean3 = vdupq_n_f32(0.0f); + float32x4_t mean4 = vdupq_n_f32(0.0f); + float32x4_t square_mean1 = vdupq_n_f32(0.0f); + float32x4_t square_mean2 = vdupq_n_f32(0.0f); + float32x4_t square_mean3 = vdupq_n_f32(0.0f); + float32x4_t square_mean4 = vdupq_n_f32(0.0f); + for (int index = 0; index < hw_plane; ++index) { + float16x8_t srcv = vld1q_f16(src + index * C8NUM); + float16x8_t srcv1 = vld1q_f16(src1 + index * C8NUM); + + float32x4_t srcv01 = vcvt_f32_f16(vget_low_f16(srcv)); + float32x4_t srcv02 = vcvt_f32_f16(vget_high_f16(srcv)); + float32x4_t srcv11 = vcvt_f32_f16(vget_low_f16(srcv1)); + float32x4_t srcv12 = vcvt_f32_f16(vget_high_f16(srcv1)); + mean1 = vaddq_f32(mean1, srcv01); + mean2 = vaddq_f32(mean2, srcv02); + mean3 = vaddq_f32(mean3, srcv11); + mean4 = vaddq_f32(mean4, srcv12); + square_mean1 = vaddq_f32(square_mean1, vmulq_f32(srcv01, srcv01)); + square_mean2 = vaddq_f32(square_mean2, vmulq_f32(srcv02, srcv02)); + square_mean3 = vaddq_f32(square_mean3, vmulq_f32(srcv11, srcv11)); + square_mean4 = vaddq_f32(square_mean4, vmulq_f32(srcv12, srcv12)); + } + float16x8_t mean = + vcombine_f16(vcvt_f16_f32(MS_DIVQ_F32(mean1, hw_plane_4)), vcvt_f16_f32(MS_DIVQ_F32(mean2, hw_plane_4))); + float16x8_t mean_1 = + vcombine_f16(vcvt_f16_f32(MS_DIVQ_F32(mean3, hw_plane_4)), vcvt_f16_f32(MS_DIVQ_F32(mean4, hw_plane_4))); + float16x8_t square_mean = vcombine_f16(vcvt_f16_f32(MS_DIVQ_F32(square_mean1, hw_plane_4)), + vcvt_f16_f32(MS_DIVQ_F32(square_mean2, hw_plane_4))); + float16x8_t square_mean_1 = vcombine_f16(vcvt_f16_f32(MS_DIVQ_F32(square_mean3, hw_plane_4)), + vcvt_f16_f32(MS_DIVQ_F32(square_mean4, hw_plane_4))); + float16x8_t deno = vaddq_f16(vsubq_f16(square_mean, vmulq_f16(mean, mean)), vdupq_n_f16(param->epsilon_)); + float16x8_t deno1 = vaddq_f16(vsubq_f16(square_mean_1, vmulq_f16(mean_1, mean_1)), vdupq_n_f16(param->epsilon_)); + deno = 1 / MS_SQRTFX8_F16(deno); + deno1 = 1 / MS_SQRTFX8_F16(deno1); + + float16x8_t gammav = vmulq_f16(vld1q_f16(gamma_data + c), deno); // deno * gamma_data[c] + float16x8_t gammav1 = vmulq_f16(vld1q_f16(gamma_data + c + C8NUM), deno1); // deno * gamma_data[c] + float16x8_t betav = vld1q_f16(beta_data + c); + float16x8_t betav1 = vld1q_f16(beta_data + c + C8NUM); + for (int index = 0; index < hw_plane; ++index) { + float16x8_t srcv = vld1q_f16(src + index * C8NUM); + float16x8_t srcv1 = vld1q_f16(src1 + index * C8NUM); + float16x8_t outv = vsubq_f16(srcv, mean); + float16x8_t outv1 = vsubq_f16(srcv1, mean_1); + outv = vmulq_f16(outv, gammav); + outv1 = vmulq_f16(outv1, gammav1); + outv = vaddq_f16(outv, betav); + outv1 = vaddq_f16(outv1, betav1); + vst1q_f16(dst + index * channel, outv); + vst1q_f16(dst + index * channel + C8NUM, outv1); + } + } + for (; c <= channel_end - C8NUM; c += C8NUM) { + const float16_t *src = src_b + c * hw_plane; + float16_t *dst = dst_b + c; + float32x4_t mean1 = vdupq_n_f32(0.0f); + float32x4_t mean2 = vdupq_n_f32(0.0f); + float32x4_t square_mean1 = vdupq_n_f32(0.0f); + float32x4_t square_mean2 = vdupq_n_f32(0.0f); + for (int index = 0; index < hw_plane; ++index) { + float16x8_t srcv = vld1q_f16(src + index * C8NUM); + float32x4_t srcv1 = vcvt_f32_f16(vget_low_f16(srcv)); + float32x4_t srcv2 = vcvt_f32_f16(vget_high_f16(srcv)); + mean1 = vaddq_f32(mean1, srcv1); + mean2 = vaddq_f32(mean2, srcv2); + square_mean1 = vaddq_f32(square_mean1, vmulq_f32(srcv1, srcv1)); + square_mean2 = vaddq_f32(square_mean2, vmulq_f32(srcv2, srcv2)); + } + float16x8_t mean = + vcombine_f16(vcvt_f16_f32(MS_DIVQ_F32(mean1, hw_plane_4)), vcvt_f16_f32(MS_DIVQ_F32(mean2, hw_plane_4))); + float16x8_t square_mean = vcombine_f16(vcvt_f16_f32(MS_DIVQ_F32(square_mean1, hw_plane_4)), + vcvt_f16_f32(MS_DIVQ_F32(square_mean2, hw_plane_4))); + float16x8_t deno = + vaddq_f16(vsubq_f16(square_mean, vmulq_f16(mean, mean)), vdupq_n_f16(param->epsilon_)); // question + deno = 1 / MS_SQRTFX8_F16(deno); // question + + float16x8_t gammav = vmulq_f16(vld1q_f16(gamma_data + c), deno); // deno * gamma_data[c] + float16x8_t betav = vld1q_f16(beta_data + c); + for (int index = 0; index < hw_plane; ++index) { + float16x8_t srcv = vld1q_f16(src + index * C8NUM); + float16x8_t outv = vsubq_f16(srcv, mean); + outv = vmulq_f16(outv, gammav); + outv = vaddq_f16(outv, betav); + vst1q_f16(dst + index * channel, outv); + } + } + for (; c < channel_end; ++c) { + const float16_t *src = src_b + c8_down * hw_plane + c; + float16_t *dst = dst_b + c; + float mean = 0.0f; + float square_mean = 0.0f; + for (int index = 0; index < hw_plane; ++index) { + float16_t tmp = src[index * c_res]; + mean += tmp; + square_mean += tmp * tmp; + } + mean /= (float)hw_plane; + square_mean /= (float)hw_plane; + const float deno = gamma_data[c] / sqrtf(square_mean - mean * mean + param->epsilon_); + for (int index = 0; index < hw_plane; ++index) { + dst[index * channel] = (src[index * c_res] - mean) * deno + beta_data[c]; + } + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/instance_norm_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/instance_norm_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..a944839beae39006623d016c940cc4dd528aa34d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/instance_norm_fp16.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_INSTANCE_NORM_FP16_H_ +#define NNACL_FP16_INSTANCE_NORM_FP16_H_ + +#include "nnacl/instance_norm_parameter.h" +#ifdef __cplusplus +extern "C" { +#endif + +int InstanceNormFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *gamma_data, + const float16_t *beta_data, const InstanceNormParameter *param, size_t task_id); +int InstanceNormNC8HW8Fp16(const float16_t *src_data, float16_t *dst_data, const float16_t *gamma_data, + const float16_t *beta_data, const InstanceNormParameter *param, size_t task_id); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_INSTANCE_NORM_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/layer_norm_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/layer_norm_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..c486c9be14b2d5dc183eb8541218512cf0a09c10 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/layer_norm_fp16.c @@ -0,0 +1,110 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp16/layer_norm_fp16.h" +#include +#include "nnacl/errorcode.h" +#include "nnacl/intrinsics/ms_simd_instructions_fp16.h" + +int LayerNormMeanAndSquareFp16(const float16_t *src, int num, float16_t *mean, float16_t *variance) { + if (num <= 0) { + return NNACL_ERR; + } + int index = 0; + float sum = 0.0f; + float square_mean = 0.0f; + for (; index <= num - C8NUM; index += C8NUM) { + float16x8_t srcv = vld1q_f16(src + index); + for (int i = 0; i < C8NUM; ++i) { + square_mean += srcv[i] * srcv[i]; + } + float16x4_t sum2 = vadd_f16(vget_low_f16(srcv), vget_high_f16(srcv)); + float32x4_t sum_f32 = vcvt_f32_f16(sum2); + sum += MS_ADDVQ_F32(sum_f32); + } + for (; index < num; index++) { + sum += src[index]; + square_mean += src[index] * src[index]; + } + *mean = (float16_t)(sum / num); + square_mean = square_mean / num; + *variance = square_mean - (*mean) * (*mean); + return NNACL_OK; +} + +void LayerNormGammaAndBetaFp16(float16_t *dst, const float16_t *src, const float16_t *gamma_data, + const float16_t *beta_data, int num, const float16_t mean, const float16_t deno) { + int index = 0; + float16x8_t meanv = vdupq_n_f16(mean); + float16x8_t denov = vdupq_n_f16(deno); + for (; index <= num - C8NUM; index += C8NUM) { + float16x8_t srcv = vld1q_f16(src + index); + float16x8_t outv = vsubq_f16(srcv, meanv); + outv = vmulq_f16(outv, denov); + float16x8_t gammav = vld1q_f16(gamma_data + index); + float16x8_t betav = vld1q_f16(beta_data + index); + outv = vmulq_f16(outv, gammav); + outv = vaddq_f16(outv, betav); + vst1q_f16(dst + index, outv); + } + for (; index < num; index++) { + dst[index] = (src[index] - mean) * (deno); + dst[index] = dst[index] * gamma_data[index] + beta_data[index]; + } +} + +int LayerNormFp16(const float16_t *src_data, const float16_t *gamma_data, const float16_t *beta_data, + float16_t *dst_data, float16_t *out_mean, float16_t *out_variance, const LayerNormComputeParam *param, + int task_id, int thread_num) { + if (src_data == NULL || dst_data == NULL || gamma_data == NULL || beta_data == NULL) { + return NNACL_NULL_PTR; + } + NNACL_CHECK_ZERO_RETURN_ERR(param->params_inner_size_); + NNACL_CHECK_ZERO_RETURN_ERR(param->params_outer_size_); + NNACL_CHECK_ZERO_RETURN_ERR(thread_num); + int step = UP_DIV(param->norm_outer_size_, thread_num); + int thread_end = MSMIN((task_id + 1) * step, param->norm_outer_size_); + for (int i = task_id * step; i < thread_end; i++) { + const float16_t *src_norm = src_data + i * param->norm_inner_size_; + float16_t *dst_norm = dst_data + i * param->norm_inner_size_; + float16_t cur_mean = 0.0f; + float16_t cur_variance = 0.0f; + int ret = LayerNormMeanAndSquareFp16(src_norm, param->norm_inner_size_, &cur_mean, &cur_variance); + if (ret != NNACL_OK) { + return NNACL_ERR; + } + if (out_mean != NULL) { + out_mean[i] = cur_mean; + } + if (out_variance != NULL) { + out_variance[i] = cur_variance; + } + const float16_t deno = 1 / sqrtf(cur_variance + param->epsilon_); + if (param->norm_outer_size_ <= param->params_outer_size_) { + for (int x = 0; x < param->norm_inner_size_ / param->params_inner_size_; x++) { + const float16_t *src_param = src_norm + x * param->params_inner_size_; + float16_t *dst_param = dst_norm + x * param->params_inner_size_; + LayerNormGammaAndBetaFp16(dst_param, src_param, gamma_data, beta_data, param->params_inner_size_, cur_mean, + deno); + } + } else { + int x = i / param->params_outer_size_; + const float16_t *gamma = gamma_data + x * param->norm_inner_size_; + const float16_t *beta = beta_data + x * param->norm_inner_size_; + LayerNormGammaAndBetaFp16(dst_norm, src_norm, gamma, beta, param->norm_inner_size_, cur_mean, deno); + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/layer_norm_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/layer_norm_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..080892ad39ee403f7f2f6b7b3502e86bf196c4ff --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/layer_norm_fp16.h @@ -0,0 +1,33 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_LAYER_NORM_FP16_H_ +#define NNACL_FP16_LAYER_NORM_FP16_H_ + +#include "nnacl/op_base.h" +#include "nnacl/kernel/layer_norm.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int LayerNormFp16(const float16_t *src_data, const float16_t *gamma_data, const float16_t *beta_data, + float16_t *dst_data, float16_t *out_mean, float16_t *out_variance, const LayerNormComputeParam *param, + int task_id, int thread_num); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_LAYER_NORM_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/log_softmax_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/log_softmax_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..6460e8dda2e5c13ce03a76cd0219a05fcd3bb51e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/log_softmax_fp16.c @@ -0,0 +1,88 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp16/log_softmax_fp16.h" +#include +#include +#include "nnacl/fp16/softmax_fp16.h" +#include "nnacl/fp16/exp_fp16.h" + +void LogSoftmaxLastAxisFp16(const float16_t *src, float16_t *dst, float16_t *exp_data, int batch, int channel) { + SoftmaxNormFp16(src, dst, batch, channel); + ExpFp16(dst, exp_data, batch * channel); + int cur_batch_offset = 0; + for (int i = 0; i < batch; i++, cur_batch_offset += channel) { + float16_t sum = 0; + int j = 0; +#ifdef ENABLE_NEON + float16x8_t sum8 = vdupq_n_f16(0); + int count = (channel / C8NUM) * C8NUM; + for (; j < count; j += C8NUM) { + sum8 = vaddq_f16(sum8, vld1q_f16(exp_data + cur_batch_offset + j)); + } + sum = sum8[0] + sum8[1] + sum8[2] + sum8[3] + sum8[4] + sum8[5] + sum8[6] + sum8[7]; +#endif + for (; j < channel; j++) { + sum += exp_data[cur_batch_offset + j]; + } + for (int k = 0; k < channel; k++) { + dst[cur_batch_offset + k] = dst[cur_batch_offset + k] - log(sum); + } + } +} + +// output = (input - reduce_max(input, axis)) - log(reduce_sum(exp(input - reduce_max(input, axis)), axis)) +void LogSoftmaxFp16(const float16_t *input_ptr, float16_t *output_ptr, float16_t *sum_data, int *input_shape, int n_dim, + int axis) { + int inner_size = 1; + int outter_size = 1; + + for (int i = 0; i < axis; i++) { + outter_size *= input_shape[i]; + } + for (int i = axis + 1; i < n_dim; i++) { + inner_size *= input_shape[i]; + } + for (int i = 0; i < outter_size; i++) { + int outter_offset = i * input_shape[axis] * inner_size; + int sum_outter_offset = i * inner_size; + for (int k = 0; k < inner_size; k++) { + int inner_offset = outter_offset + k; + float16_t max_data = input_ptr[inner_offset]; + sum_data[k + sum_outter_offset] = 0; + for (int j = 0; j < input_shape[axis]; j++) { + int axis_offset = inner_offset + j * inner_size; + max_data = max_data > input_ptr[axis_offset] ? max_data : input_ptr[axis_offset]; + } + for (int j = 0; j < input_shape[axis]; j++) { + int axis_offset = inner_offset + j * inner_size; + output_ptr[axis_offset] = input_ptr[axis_offset] - max_data; + sum_data[k + sum_outter_offset] += exp(output_ptr[axis_offset]); + } + } + } + for (int i = 0; i < outter_size; i++) { + int outter_offset = i * input_shape[axis] * inner_size; + int sum_outter_offset = i * inner_size; + for (int j = 0; j < input_shape[axis]; j++) { + int axis_offset = outter_offset + j * inner_size; + for (int k = 0; k < inner_size; k++) { + int inner_offset = axis_offset + k; + output_ptr[inner_offset] = output_ptr[inner_offset] - log(sum_data[k + sum_outter_offset]); + } + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/log_softmax_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/log_softmax_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..68a419e030a38352a50b56140453b63ff209ac8c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/log_softmax_fp16.h @@ -0,0 +1,35 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_LOG_SOFTMAX_FP16_H_ +#define NNACL_FP16_LOG_SOFTMAX_FP16_H_ + +#include "nnacl/op_base.h" +#include "nnacl/softmax_parameter.h" +#ifdef ENABLE_NEON +#include +#endif +#ifdef __cplusplus +extern "C" { +#endif +void LogSoftmaxLastAxisFp16(const float16_t *src, float16_t *dst, float16_t *exp_data, int batch, int channel); +void LogSoftmaxFp16(const float16_t *input_ptr, float16_t *output_ptr, float16_t *sum_data, int *input_shape, int n_dim, + int axis); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_LOG_SOFTMAX_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/lstm_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/lstm_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..73e66b7799df9b92f1ccf3331cc2361d07c5fc6a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/lstm_fp16.c @@ -0,0 +1,367 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp16/lstm_fp16.h" +#include +#include +#include "nnacl/fp16/activation_fp16.h" +#include "nnacl/fp16/arithmetic_fp16.h" +#include "nnacl/fp16/matmul_fp16.h" +#include "nnacl/fp16/cast_fp16.h" +#include "nnacl/intrinsics/ms_simd_instructions_fp16.h" + +void PackLstmWeightFp32ToFp16(float16_t *dst, const float *src, int batch, int deep, int col, int col_align, + const int32_t *order) { + for (int i = 0; i < batch; i++) { + const float *src_batch = src + i * col * deep; + float16_t *dst_batch = dst + (order == NULL ? i : order[i]) * col_align * deep; +#ifdef ENABLE_ARM64 + RowMajor2ColNMajorFp16(src_batch, dst_batch, col, deep, true); +#else + RowMajor2Col8MajorFp16(src_batch, dst_batch, col, deep, true); +#endif + } +} + +void PackLstmWeightFp16(float16_t *dst, const float16_t *src, int batch, int deep, int col, int col_align, + const int32_t *order) { + for (int i = 0; i < batch; i++) { + const float16_t *src_batch = src + i * col * deep; + float16_t *dst_batch = dst + (order == NULL ? i : order[i]) * col_align * deep; +#ifdef ENABLE_ARM64 + RowMajor2ColNMajorFp16(src_batch, dst_batch, col, deep, false); +#else + RowMajor2Col8MajorFp16(src_batch, dst_batch, col, deep, false); +#endif + } +} + +void PackLstmBiasFp32ToFp16(float16_t *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional, + const int32_t *order) { + int unidirectional_batch = is_bidirectional ? batch / 2 : batch; + for (int i = 0; i < unidirectional_batch; i++) { + const float *src_batch = src + i * col; + float16_t *dst_batch = dst + (order == NULL ? i : order[i]) * col_align; + Float32ToFloat16(src_batch, dst_batch, col); + } + if (is_bidirectional) { + const float *backward_src = src + batch * col; + float16_t *backward_dst = dst + unidirectional_batch * col_align; + for (int i = 0; i < unidirectional_batch; i++) { + const float *backward_src_batch = backward_src + i * col; + float16_t *backward_dst_batch = backward_dst + (order == NULL ? i : order[i]) * col_align; + Float32ToFloat16(backward_src_batch, backward_dst_batch, col); + } + } +} + +void PackLstmBiasFp16(float16_t *dst, const float16_t *src, int batch, int col, int col_align, bool is_bidirectional, + const int32_t *order) { + int unidirectional_batch = is_bidirectional ? batch / 2 : batch; + for (int i = 0; i < unidirectional_batch; i++) { + const float16_t *src_batch = src + i * col; + float16_t *dst_batch = dst + (order == NULL ? i : order[i]) * col_align; + (void)memcpy(dst_batch, src_batch, col * sizeof(float16_t)); + } + if (is_bidirectional) { + const float16_t *backward_src = src + batch * col; + float16_t *backward_dst = dst + unidirectional_batch * col_align; + for (int i = 0; i < unidirectional_batch; i++) { + const float16_t *backward_src_batch = backward_src + i * col; + float16_t *backward_dst_batch = backward_dst + (order == NULL ? i : order[i]) * col_align; + (void)memcpy(backward_dst_batch, backward_src_batch, col * sizeof(float16_t)); + } + } +} + +// input: [row, inner_size]; weight: [col, inner_size]; output: [row, col] +void MatMulAccFp16(float16_t *output, const float16_t *input, const float16_t *weight, int rows, int cols, + int inner_size) { + for (int r = 0; r < rows; r++) { + for (int c = 0; c < cols; c++) { + float16_t res = 0; + const float16_t *input_col = input + r * inner_size; + const float16_t *weight_col = weight + c * inner_size; + int index = 0; + float16x8_t out = vdupq_n_f16(0.0f); + for (; index <= inner_size - 8; index += 8) { + float16x8_t in_0 = vld1q_f16(input_col + index); + float16x8_t in_1 = vld1q_f16(weight_col + index); + out = vfmaq_f16(out, in_1, in_0); + } + float16x4_t add2 = vadd_f16(vget_low_f16(out), vget_high_f16(out)); + float16x4_t add4 = vpadd_f16(add2, add2); + float16x4_t add8 = vpadd_f16(add4, add4); + res += vget_lane_f16(add8, 0); + for (; index < inner_size; index++) { + res += input_col[index] * weight_col[index]; + } + output[r * cols + c] += res; + } + } +} + +void ElementMulAccFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size) { + int index = 0; + for (; index <= element_size - 8; index += 8) { + float16x8_t in_0 = vld1q_f16(input0 + index); + float16x8_t in_1 = vld1q_f16(input1 + index); + float16x8_t out = vld1q_f16(output + index); + out = vfmaq_f16(out, in_1, in_0); + vst1q_f16(output + index, out); + } + for (; index < element_size; index++) { + output[index] += input0[index] * input1[index]; + } +} + +int ElementOptMulAccFp16(const float16_t *input0, const float16_t input1, float16_t *output, const int element_size) { + int index = 0; + for (; index <= element_size - 8; index += 8) { + float16x8_t vin0 = vld1q_f16(input0 + index); + float16x8_t vout = vld1q_f16(output + index); + vout = MS_FMAQ_N_F16(vout, vin0, input1); + vst1q_f16(output + index, vout); + } + for (; index < element_size; index++) { + output[index] += input0[index] * input1; + } + return NNACL_OK; +} + +void UpdateStateFp16(float16_t *cell_state, const float16_t *forget_gate, const float16_t *input_gate, + const float16_t *cell_gate, float16_t *state_buffer, int batch, int hidden_size, + float16_t zoneout) { + if (!(zoneout >= -FLT_EPSILON && zoneout <= FLT_EPSILON)) { // zoneout * old_cell_state + (void)memcpy(state_buffer, cell_state, batch * hidden_size * sizeof(float16_t)); + ElementOptMulFp16(state_buffer, &zoneout, state_buffer, batch * hidden_size, false); + } + + ElementMulFp16(forget_gate, cell_state, cell_state, batch * hidden_size); + ElementMulAccFp16(input_gate, cell_gate, cell_state, batch * hidden_size); + + if (!(zoneout >= -FLT_EPSILON && zoneout <= FLT_EPSILON)) { // (1 - zoneout) * new_cell_state + ElementOptMulAccFp16(cell_state, 1 - zoneout, state_buffer, batch * hidden_size); + } +} + +void UpdateOutputFp16(float16_t *hidden_state, float16_t *output, const float16_t *cell_state, float16_t *output_gate, + const float16_t *weight_project, const float16_t *project_bias, float16_t *buffer[C7NUM], + const LstmParameter *lstm_param) { + int batch = lstm_param->batch_; + int hidden_size = lstm_param->hidden_size_; + int output_size = lstm_param->output_size_; + float16_t *state_buffer = buffer[C5NUM]; + float16_t *hidden_buffer = weight_project ? buffer[C3NUM] : hidden_state; + float16_t zoneout = lstm_param->zoneout_hidden_; + if (!(zoneout >= -FLT_EPSILON && zoneout <= FLT_EPSILON)) { + (void)memcpy(state_buffer, hidden_state, batch * output_size * sizeof(float16_t)); + ElementOptMulFp16(state_buffer, &zoneout, state_buffer, batch * output_size, false); + } + + TanhFp16(cell_state, hidden_buffer, batch * hidden_size); + ElementMulFp16(hidden_buffer, output_gate, hidden_buffer, batch * hidden_size); + + if (weight_project) { + float16_t *left_matrix = hidden_buffer; +#ifdef ENABLE_ARM64 + if (batch >= C4NUM) { + left_matrix = buffer[C6NUM]; + RowMajor2ColLadder12MajorFp16(hidden_buffer, left_matrix, batch, hidden_size); + } +#else + if (batch != 1) { + left_matrix = buffer[C6NUM]; + RowMajor2Col16MajorFp16(hidden_buffer, left_matrix, batch, hidden_size, false); + } +#endif + LstmMatMulFp16(hidden_state, left_matrix, weight_project, project_bias, batch, hidden_size, output_size, + batch == 1); + } + if (!(zoneout >= -FLT_EPSILON && zoneout <= FLT_EPSILON)) { + ElementOptMulAccFp16(hidden_state, 1 - zoneout, state_buffer, batch * output_size); + } + (void)memcpy(output, hidden_state, batch * output_size * sizeof(float16_t)); +} + +#ifdef ENABLE_ARM64 +void LstmMatMulFp16(float16_t *c, const float16_t *a, const float16_t *b, const float16_t *bias, int row, int deep, + int col, bool is_vec) { + MatmulFp16OptV2(a, b, c, bias, ActType_No, deep, row, col, col, OutType_Nhwc); +} +#else +void LstmMatMulFp16(float16_t *c, const float16_t *a, const float16_t *b, const float16_t *bias, int row, int deep, + int col, bool is_vec) { + if (is_vec) { + (void)memcpy(c, bias, col * sizeof(float16_t)); + MatMulAccFp16(c, a, b, row, col, deep); + } else { + MatMulFp16(a, b, c, bias, ActType_No, deep, row, col, col, OutType_Nhwc); + } +} +#endif + +void UpdateLstmGateFp16(float16_t *gate_buffer, const float16_t *input, const float16_t *weight, const float16_t *bias, + int row, int deep, int col, int col_align, bool is_vec) { + for (int i = 0; i < 4; i++) { + const float16_t *weight_i = weight + deep * col_align * i; + const float16_t *bias_i = bias + col_align * i; + float16_t *gate = gate_buffer + row * col * i; + LstmMatMulFp16(gate, input, weight_i, bias_i, row, deep, col, is_vec); + } +} + +void LstmStepUnitFp16(float16_t *output, float16_t *input_gate, float16_t *forget_gate, float16_t *cell_gate, + float16_t *output_gate, const float16_t *state_weight, const float16_t *state_bias, + const float16_t *weight_project, const float16_t *project_bias, float16_t *hidden_state, + float16_t *cell_state, float16_t *buffer[C7NUM], const LstmParameter *lstm_param) { + float16_t *packed_state = buffer[C2NUM]; + float16_t *state_gate = buffer[C3NUM]; + float16_t *cell_buffer = buffer[C4NUM]; + float16_t *hidden_buffer = buffer[C5NUM]; + bool is_vec = lstm_param->batch_ == 1; +#ifdef ENABLE_ARM64 + if (lstm_param->batch_ <= C3NUM) { + UpdateLstmGateFp16(state_gate, hidden_state, state_weight, state_bias, lstm_param->batch_, lstm_param->output_size_, + lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec); + } else { + RowMajor2ColLadder12MajorFp16(hidden_state, packed_state, lstm_param->batch_, lstm_param->output_size_); + UpdateLstmGateFp16(state_gate, packed_state, state_weight, state_bias, lstm_param->batch_, lstm_param->output_size_, + lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec); + } +#else + if (is_vec) { + UpdateLstmGateFp16(state_gate, hidden_state, state_weight, state_bias, lstm_param->batch_, lstm_param->output_size_, + lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec); + } else { + RowMajor2Col16MajorFp16(hidden_state, packed_state, lstm_param->batch_, lstm_param->output_size_, false); + UpdateLstmGateFp16(state_gate, packed_state, state_weight, state_bias, lstm_param->batch_, lstm_param->output_size_, + lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec); + } +#endif + ElementAddFp16(input_gate, state_gate, input_gate, lstm_param->batch_ * lstm_param->hidden_size_); + ElementAddFp16(forget_gate, state_gate + lstm_param->batch_ * lstm_param->hidden_size_ * 2, forget_gate, + lstm_param->batch_ * lstm_param->hidden_size_); + ElementAddFp16(cell_gate, state_gate + lstm_param->batch_ * lstm_param->hidden_size_ * 3, cell_gate, + lstm_param->batch_ * lstm_param->hidden_size_); + ElementAddFp16(output_gate, state_gate + lstm_param->batch_ * lstm_param->hidden_size_, output_gate, + lstm_param->batch_ * lstm_param->hidden_size_); + + // update input_gate + SigmoidFp16(input_gate, input_gate, lstm_param->batch_ * lstm_param->hidden_size_); + + // update forget_gate + SigmoidFp16(forget_gate, forget_gate, lstm_param->batch_ * lstm_param->hidden_size_); + + // update cell_gate + TanhFp16(cell_gate, cell_gate, lstm_param->batch_ * lstm_param->hidden_size_); + // update cell state + UpdateStateFp16(cell_state, forget_gate, input_gate, cell_gate, cell_buffer, lstm_param->batch_, + lstm_param->hidden_size_, lstm_param->zoneout_cell_); + + // update output_gate + SigmoidFp16(output_gate, output_gate, lstm_param->batch_ * lstm_param->hidden_size_); + // update output + UpdateOutputFp16(hidden_state, output, cell_state, output_gate, weight_project, project_bias, buffer, lstm_param); + + if (!(lstm_param->zoneout_cell_ >= -FLT_EPSILON && lstm_param->zoneout_cell_ <= FLT_EPSILON)) { + (void)memcpy(cell_state, cell_buffer, lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float16_t)); + } + + if (!(lstm_param->zoneout_hidden_ >= -FLT_EPSILON && lstm_param->zoneout_hidden_ <= FLT_EPSILON)) { + (void)memcpy(hidden_state, hidden_buffer, lstm_param->batch_ * lstm_param->output_size_ * sizeof(float16_t)); + } +} + +void LstmGateCompute(float16_t *gate, const float16_t *input, const float16_t *weight_i, const float16_t *input_bias, + const LstmParameter *lstm_param) { + int row_input = lstm_param->seq_len_ * lstm_param->batch_; + for (int i = 0; i < C4NUM; i++) { + const float16_t *weight_loop = weight_i + lstm_param->input_size_ * lstm_param->input_col_align_ * i; + const float16_t *bias_loop = input_bias + lstm_param->input_col_align_ * i; + float16_t *gate_loop = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_ * i; +#ifdef ENABLE_ARM64 + MatmulFp16OptV2(input, weight_loop, gate_loop, bias_loop, ActType_No, lstm_param->input_size_, row_input, + lstm_param->hidden_size_, lstm_param->hidden_size_, OutType_Nhwc); +#else + MatMulFp16(input, weight_loop, gate_loop, bias_loop, ActType_No, lstm_param->input_size_, row_input, + lstm_param->hidden_size_, lstm_param->hidden_size_, OutType_Nhwc); +#endif + } +} + +void LstmUnidirectionalFp16(float16_t *output, const float16_t *packed_input, const float16_t *weight_i, + const float16_t *weight_h, const float16_t *input_bias, const float16_t *state_bias, + const float16_t *weight_project, const float16_t *project_bias, float16_t *hidden_state, + float16_t *cell_state, float16_t *buffer[C7NUM], const LstmParameter *lstm_param, + bool is_backward) { + float16_t *gate = buffer[1]; + LstmGateCompute(gate, packed_input, weight_i, input_bias, lstm_param); + + float16_t *input_gate = gate; + float16_t *forget_gate = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_ * 2; + float16_t *cell_gate = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_ * 3; + float16_t *output_gate = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_; + for (int t = 0; t < lstm_param->seq_len_; t++) { + int real_t = is_backward ? lstm_param->seq_len_ - t - 1 : t; + float16_t *input_gate_t = input_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t; + float16_t *forget_gate_t = forget_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t; + float16_t *cell_gate_t = cell_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t; + float16_t *output_gate_t = output_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t; + float16_t *output_ptr = output + real_t * lstm_param->output_step_; + LstmStepUnitFp16(output_ptr, input_gate_t, forget_gate_t, cell_gate_t, output_gate_t, weight_h, state_bias, + weight_project, project_bias, hidden_state, cell_state, buffer, lstm_param); + } +} + +void LstmFp16(float16_t *output, const float16_t *input, const float16_t *weight_i, const float16_t *weight_h, + const float16_t *input_bias, const float16_t *state_bias, const float16_t *weight_project, + const float16_t *project_bias, float16_t *hidden_state, float16_t *cell_state, float16_t *buffer[C7NUM], + const LstmParameter *lstm_param) { + // forward +#ifdef ENABLE_ARM64 + const float16_t *packed_input = input; + if (lstm_param->batch_ * lstm_param->seq_len_ >= C4NUM) { + float16_t *temp_input = buffer[0]; + RowMajor2ColLadder12MajorFp16(input, temp_input, lstm_param->seq_len_ * lstm_param->batch_, + lstm_param->input_size_); + packed_input = temp_input; + } +#else + float16_t *packed_input = buffer[0]; + RowMajor2Col16MajorFp16(input, packed_input, lstm_param->seq_len_ * lstm_param->batch_, lstm_param->input_size_, + false); +#endif + LstmUnidirectionalFp16(output, packed_input, weight_i, weight_h, input_bias, state_bias, weight_project, project_bias, + hidden_state, cell_state, buffer, lstm_param, false); + + // backward + if (lstm_param->bidirectional_) { + const float16_t *backward_weight_i = weight_i + 4 * lstm_param->input_col_align_ * lstm_param->input_size_; + const float16_t *backward_weight_h = weight_h + 4 * lstm_param->state_col_align_ * lstm_param->output_size_; + const float16_t *backward_input_bias = input_bias + 4 * lstm_param->input_col_align_; + const float16_t *backward_state_bias = state_bias + 4 * lstm_param->state_col_align_; + const float16_t *backward_weight_project = + weight_project ? weight_project + lstm_param->hidden_size_ * lstm_param->proj_col_align_ : NULL; + float16_t *backward_output = output + lstm_param->batch_ * lstm_param->output_size_; + float16_t *backward_cell_state = cell_state + lstm_param->batch_ * lstm_param->hidden_size_; + float16_t *backward_hidden_state = hidden_state + lstm_param->batch_ * lstm_param->output_size_; + + LstmUnidirectionalFp16(backward_output, packed_input, backward_weight_i, backward_weight_h, backward_input_bias, + backward_state_bias, backward_weight_project, project_bias, backward_hidden_state, + backward_cell_state, buffer, lstm_param, true); + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/lstm_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/lstm_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..d6af9c7872bbedcf9259cf7588667b58c00b624f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/lstm_fp16.h @@ -0,0 +1,54 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_LSTM_FP16_H_ +#define NNACL_FP16_LSTM_FP16_H_ + +#include "nnacl/lstm_parameter.h" +#ifdef __cplusplus +extern "C" { +#endif +void PackLstmWeightFp32ToFp16(float16_t *dst, const float *src, int batch, int deep, int col, int col_align, + const int32_t *order); + +void PackLstmWeightFp16(float16_t *dst, const float16_t *src, int batch, int deep, int col, int col_align, + const int32_t *order); + +void PackLstmBiasFp32ToFp16(float16_t *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional, + const int32_t *order); + +void PackLstmBiasFp16(float16_t *dst, const float16_t *src, int batch, int col, int col_align, bool is_bidirectional, + const int32_t *order); + +void LstmMatMulFp16(float16_t *c, const float16_t *a, const float16_t *b, const float16_t *bias, int row, int deep, + int col, bool is_vec); + +void MatMulAccFp16(float16_t *output, const float16_t *input, const float16_t *weight, int rows, int cols, + int inner_size); + +void ElementMulAccFp16(const float16_t *input0, const float16_t *input1, float16_t *output, int element_size); + +int ElementOptMulAccFp16(const float16_t *input0, const float16_t input1, float16_t *output, const int element_size); + +void LstmFp16(float16_t *output, const float16_t *input, const float16_t *weight_i, const float16_t *weight_h, + const float16_t *input_bias, const float16_t *state_bias, const float16_t *weight_project, + const float16_t *project_bias, float16_t *hidden_state, float16_t *cell_state, float16_t *buffer[C7NUM], + const LstmParameter *lstm_param); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_LSTM_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/matmul_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/matmul_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..40265b394135764b6795d1f17a299d8b9a16589f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/matmul_fp16.c @@ -0,0 +1,1204 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp16/matmul_fp16.h" + +static void Col2Row8SrcFromFp16(const void *src_ptr, float16_t *dst_ptr, size_t row, size_t col) { + int row_c8 = row / C8NUM * C8NUM; + int col_c8 = col / C8NUM * C8NUM; + const float16_t *src = (const float16_t *)src_ptr; + int ci = 0; + for (; ci < col_c8; ci += C8NUM) { + int ri = 0; + for (; ri < row_c8; ri += C8NUM) { + const float16_t *src_ptr1 = src + ci * row + ri; + float16_t *dst_ptr1 = dst_ptr + ci * row + ri * C8NUM; +#ifdef ENABLE_ARM64 + size_t strid_row = row * 2; + asm volatile( + "mov x10, %[src_ptr1]\n" + "mov x11, %[dst_ptr1]\n" + "mov x12, %[strid_row]\n" + "ld1 {v0.8h}, [x10], x12\n" + "ld1 {v1.8h}, [x10], x12\n" + "ld1 {v2.8h}, [x10], x12\n" + "ld1 {v3.8h}, [x10], x12\n" + "ld1 {v4.8h}, [x10], x12\n" + "ld1 {v5.8h}, [x10], x12\n" + "ld1 {v6.8h}, [x10], x12\n" + "ld1 {v7.8h}, [x10], x12\n" + + "zip1 v8.8h, v0.8h, v1.8h\n" + "zip1 v9.8h, v2.8h, v3.8h\n" + "zip1 v10.8h, v4.8h, v5.8h\n" + "zip1 v11.8h, v6.8h, v7.8h\n" + + "trn1 v12.4s, v8.4s, v9.4s\n" + "trn1 v14.4s, v10.4s, v11.4s\n" + "trn2 v13.4s, v8.4s, v9.4s\n" + "trn2 v15.4s, v10.4s, v11.4s\n" + + "trn1 v16.2d, v12.2d, v14.2d\n" + "trn2 v18.2d, v12.2d, v14.2d\n" + "trn1 v17.2d, v13.2d, v15.2d\n" + "trn2 v19.2d, v13.2d, v15.2d\n" + + "zip2 v8.8h, v0.8h, v1.8h\n" + "zip2 v9.8h, v2.8h, v3.8h\n" + "zip2 v10.8h, v4.8h, v5.8h\n" + "zip2 v11.8h, v6.8h, v7.8h\n" + + "trn1 v12.4s, v8.4s, v9.4s\n" + "trn1 v14.4s, v10.4s, v11.4s\n" + "trn2 v13.4s, v8.4s, v9.4s\n" + "trn2 v15.4s, v10.4s, v11.4s\n" + + "trn1 v20.2d, v12.2d, v14.2d\n" + "trn2 v22.2d, v12.2d, v14.2d\n" + "trn1 v21.2d, v13.2d, v15.2d\n" + "trn2 v23.2d, v13.2d, v15.2d\n" + + "st1 {v16.8h}, [x11], #16\n" + "st1 {v17.8h}, [x11], #16\n" + "st1 {v18.8h}, [x11], #16\n" + "st1 {v19.8h}, [x11], #16\n" + "st1 {v20.8h}, [x11], #16\n" + "st1 {v21.8h}, [x11], #16\n" + "st1 {v22.8h}, [x11], #16\n" + "st1 {v23.8h}, [x11], #16\n" + : + : [ dst_ptr1 ] "r"(dst_ptr1), [ src_ptr1 ] "r"(src_ptr1), [ strid_row ] "r"(strid_row) + : "x10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", + "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"); +#else + for (int tr = 0; tr < C8NUM; ++tr) { + for (int tc = 0; tc < C8NUM; ++tc) { + dst_ptr1[tr * C8NUM + tc] = src_ptr1[tc * row + tr]; + } + } +#endif + } + for (; ri < row; ++ri) { + const float16_t *src_ptr1 = src + ci * row; + float16_t *dst_ptr1 = dst_ptr + ci * row; + for (int tc = 0; tc < C8NUM; ++tc) { + dst_ptr1[ri * C8NUM + tc] = src_ptr1[tc * row + ri]; + } + } + } + for (int r = 0; r < row; r++) { + for (int tc = ci; tc < col; tc++) { + int cd8 = tc / C8NUM; + int cm8 = tc % C8NUM; + dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = src[tc * row + r]; + } + } +} + +static void Col2Row8SrcFromFp32(const void *src_ptr, float16_t *dst_ptr, size_t row, size_t col) { + int row_c8 = row / C8NUM * C8NUM; + int col_c8 = col / C8NUM * C8NUM; + int ci = 0; + const float *src = (const float *)src_ptr; + for (; ci < col_c8; ci += C8NUM) { + int ri = 0; + for (; ri < row_c8; ri += C8NUM) { + const float *src_ptr1 = src + ci * row + ri; + float16_t *dst_ptr1 = dst_ptr + ci * row + ri * C8NUM; +#ifdef ENABLE_ARM64 + size_t strid_row = row * 4; + asm volatile( + "mov x10, %[src_ptr1]\n" + "mov x11, %[dst_ptr1]\n" + "mov x12, %[strid_row]\n" + "ld1 {v8.4s, v9.4s}, [x10], x12\n" + "ld1 {v10.4s, v11.4s}, [x10], x12\n" + "ld1 {v12.4s, v13.4s}, [x10], x12\n" + "ld1 {v14.4s, v15.4s}, [x10], x12\n" + "ld1 {v16.4s, v17.4s}, [x10], x12\n" + "ld1 {v18.4s, v19.4s}, [x10], x12\n" + "ld1 {v20.4s, v21.4s}, [x10], x12\n" + "ld1 {v22.4s, v23.4s}, [x10], x12\n" + + "fcvtn v0.4h, v8.4s\n" + "fcvtn2 v0.8h, v9.4s\n" + "fcvtn v1.4h, v10.4s\n" + "fcvtn2 v1.8h, v11.4s\n" + "fcvtn v2.4h, v12.4s\n" + "fcvtn2 v2.8h, v13.4s\n" + "fcvtn v3.4h, v14.4s\n" + "fcvtn2 v3.8h, v15.4s\n" + "fcvtn v4.4h, v16.4s\n" + "fcvtn2 v4.8h, v17.4s\n" + "fcvtn v5.4h, v18.4s\n" + "fcvtn2 v5.8h, v19.4s\n" + "fcvtn v6.4h, v20.4s\n" + "fcvtn2 v6.8h, v21.4s\n" + "fcvtn v7.4h, v22.4s\n" + "fcvtn2 v7.8h, v23.4s\n" + + "zip1 v8.8h, v0.8h, v1.8h\n" + "zip1 v9.8h, v2.8h, v3.8h\n" + "zip1 v10.8h, v4.8h, v5.8h\n" + "zip1 v11.8h, v6.8h, v7.8h\n" + + "trn1 v12.4s, v8.4s, v9.4s\n" + "trn1 v14.4s, v10.4s, v11.4s\n" + "trn2 v13.4s, v8.4s, v9.4s\n" + "trn2 v15.4s, v10.4s, v11.4s\n" + + "trn1 v16.2d, v12.2d, v14.2d\n" + "trn2 v18.2d, v12.2d, v14.2d\n" + "trn1 v17.2d, v13.2d, v15.2d\n" + "trn2 v19.2d, v13.2d, v15.2d\n" + + "zip2 v8.8h, v0.8h, v1.8h\n" + "zip2 v9.8h, v2.8h, v3.8h\n" + "zip2 v10.8h, v4.8h, v5.8h\n" + "zip2 v11.8h, v6.8h, v7.8h\n" + + "trn1 v12.4s, v8.4s, v9.4s\n" + "trn1 v14.4s, v10.4s, v11.4s\n" + "trn2 v13.4s, v8.4s, v9.4s\n" + "trn2 v15.4s, v10.4s, v11.4s\n" + + "trn1 v20.2d, v12.2d, v14.2d\n" + "trn2 v22.2d, v12.2d, v14.2d\n" + "trn1 v21.2d, v13.2d, v15.2d\n" + "trn2 v23.2d, v13.2d, v15.2d\n" + + "st1 {v16.8h}, [x11], #16\n" + "st1 {v17.8h}, [x11], #16\n" + "st1 {v18.8h}, [x11], #16\n" + "st1 {v19.8h}, [x11], #16\n" + "st1 {v20.8h}, [x11], #16\n" + "st1 {v21.8h}, [x11], #16\n" + "st1 {v22.8h}, [x11], #16\n" + "st1 {v23.8h}, [x11], #16\n" + : + : [ dst_ptr1 ] "r"(dst_ptr1), [ src_ptr1 ] "r"(src_ptr1), [ strid_row ] "r"(strid_row) + : "x10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", + "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"); +#else + for (int tr = 0; tr < C8NUM; ++tr) { + for (int tc = 0; tc < C8NUM; ++tc) { + dst_ptr1[tr * C8NUM + tc] = (float16_t)(src_ptr1[tc * row + tr]); + } + } +#endif + } + for (; ri < row; ++ri) { + const float *src_ptr1 = src + ci * row; + float16_t *dst_ptr1 = dst_ptr + ci * row; + for (int tc = 0; tc < C8NUM; ++tc) { + dst_ptr1[ri * C8NUM + tc] = (float16_t)(src_ptr1[tc * row + ri]); + } + } + } + for (int r = 0; r < row; r++) { + for (int tc = ci; tc < col; tc++) { + int cd8 = tc / C8NUM; + int cm8 = tc % C8NUM; + dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = (float16_t)(src[tc * row + r]); + } + } +} + +void ColMajor2Row8MajorFp16(const void *src_ptr, float16_t *dst_ptr, size_t row, size_t col, bool src_float16) { + if (src_float16) { + Col2Row8SrcFromFp16(src_ptr, dst_ptr, row, col); + } else { + Col2Row8SrcFromFp32(src_ptr, dst_ptr, row, col); + } + return; +} + +void MatMul16x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, + int deep, int row, int col, int stride, int write_mode) { + if (write_mode == OutType_Nhwc) { + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int r16div = r / 16, r16mod = r % 16; + int c8div = c / 8, c8mod = c % 8; + size_t ci = r * stride + c; + float16_t value = 0; + for (int d = 0; d < deep; d++) { + size_t ai = r16div * deep * 16 + d * 16 + r16mod; + size_t bi = c8div * deep * 8 + d * 8 + c8mod; + value = value + a[ai] * b[bi]; + } + ADD_BIAS(value, bias, c) + DO_RELU(value, act_type) + DO_RELU6(value, act_type) + dst[ci] = value; + } + } + } else if (write_mode == OutType_C8) { + int col_8 = UP_ROUND(col, C8NUM); + int row_16 = UP_ROUND(row, C16NUM); + for (int r = 0; r < row_16; r++) { + for (int c = 0; c < col_8; c++) { + int r16div = r / C16NUM, r16mod = r % C16NUM; + int c8div = c / C8NUM, c8mod = c % C8NUM; + size_t ci = (c8div * C8NUM * row_16 + r * C8NUM + c8mod); + float16_t value = 0; + for (int d = 0; d < deep; d++) { + size_t ai = r16div * deep * C16NUM + d * C16NUM + r16mod; + size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod; + value = value + a[ai] * b[bi]; + } + ADD_BIAS(value, bias, c) + DO_RELU(value, act_type) + DO_RELU6(value, act_type) + dst[ci] = value; + } + } + } else { + for (int i = 0; i < row; ++i) { + int src_r_offset = i; + int dst_r_offset = i * col * stride; + for (int j = 0; j < col; ++j) { + int c8div = j / 8, c8mod = j % 8; + size_t ci = dst_r_offset + c8div * 8 * stride + c8mod; + float16_t value = 0; + for (int d = 0; d < deep; ++d) { + size_t ai = src_r_offset + d * C16NUM; + size_t bi = c8div * deep * 8 + d * 8 + c8mod; + value = value + a[ai] * b[bi]; + } + ADD_BIAS(value, bias, j) + DO_RELU(value, act_type) + DO_RELU6(value, act_type) + dst[ci] = value; + } + } + } +} + +void MatMul12x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, + int deep, int row, int col, int stride, int write_mode) { + if (write_mode == OutType_Nhwc) { // common conv and matmul + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int r12div = r / 12, r12mod = r % 12; + int c8div = c / 8, c8mod = c % 8; + size_t ci = r * stride + c; + float16_t value = 0; + for (int d = 0; d < deep; d++) { + size_t ai = r12div * deep * 12 + d * 12 + r12mod; + size_t bi = c8div * deep * 8 + d * 8 + c8mod; + value = value + a[ai] * b[bi]; + } + ADD_BIAS(value, bias, c) + DO_RELU(value, act_type) + DO_RELU6(value, act_type) + dst[ci] = value; + } + } + } else if (write_mode == OutType_C8) { // common deconv + int col_8 = UP_ROUND(col, C8NUM); + int row_12 = UP_ROUND(row, C12NUM); + for (int r = 0; r < row_12; r++) { + for (int c = 0; c < col_8; c++) { + int r12div = r / C12NUM, r12mod = r % C12NUM; + int c8div = c / C8NUM, c8mod = c % C8NUM; + size_t ci = (c8div * C8NUM * row_12 + r * C8NUM + c8mod); + float16_t value = 0; + for (int d = 0; d < deep; d++) { + size_t ai = r12div * deep * C12NUM + d * C12NUM + r12mod; + size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod; + value = value + a[ai] * b[bi]; + } + ADD_BIAS(value, bias, c) + DO_RELU(value, act_type) + DO_RELU6(value, act_type) + dst[ci] = value; + } + } + } else { // winograd conv + for (int i = 0; i < row; ++i) { + int src_r_offset = i; + int dst_r_offset = i * col * stride; + for (int j = 0; j < col; ++j) { + int c8div = j / 8, c8mod = j % 8; + size_t ci = dst_r_offset + c8div * 8 * stride + c8mod; + float16_t value = 0; + for (int d = 0; d < deep; ++d) { + size_t ai = src_r_offset + d * C12NUM; + size_t bi = c8div * deep * 8 + d * 8 + c8mod; + value = value + a[ai] * b[bi]; + } + ADD_BIAS(value, bias, j) + DO_RELU(value, act_type) + DO_RELU6(value, act_type) + dst[ci] = value; + } + } + } +} + +#ifdef ENABLE_DEBUG +void MatMul12x16Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, + int deep, int row, int col, size_t stride, size_t out_type) { + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int r12div = r / C12NUM, r12mod = r % C12NUM; + int c16div = c / C16NUM, c16mod = c % C16NUM; + size_t index = r * stride + c; + float16_t value = 0; + for (int d = 0; d < deep; d++) { + size_t ai = r12div * deep * C12NUM + d * C12NUM + r12mod; + size_t bi = c16div * deep * C16NUM + d * C16NUM + c16mod; + value = value + a[ai] * b[bi]; + } + ADD_BIAS(value, bias, c) + DO_RELU(value, act_type) + DO_RELU6(value, act_type) + dst[index] = value; + } + } +} +#endif + +void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type, + int depth, int row, int col, int stride, int out_type) { + if (out_type == OutType_C8) { + // common deconv +#ifdef ENABLE_ARM64 + MatmulFp16Neon64(a, b, c, bias, (int)act_type, depth, row, col, stride, false); +#else + MatMul12x8Fp16(a, b, c, bias, (int)act_type, depth, row, col, stride, out_type); +#endif + } else { + // winograd conv(OntType_TileC8) and common conv(OutType_Nhwc) and matmul(OutType_Nhwc) +#ifdef ENABLE_ARM64 + MatmulFp16Neon64Opt(a, b, c, bias, (int)act_type, depth, row, col, stride, out_type); +#else + MatMul12x8A32Fp16(a, b, c, bias, (int)act_type, depth, row, col, stride, out_type); +#endif + } + return; +} + +#ifdef ENABLE_ARM64 +// 1*8 X 8*8 -> 1 X 8 +void VecMatmulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, int depth, + int col) { + int ci = col; + const float16_t *bv_base = b; + + while (ci > 0) { + float16x8_t acc_0 = vdupq_n_f16((float16_t)0.0); + if (bias != NULL) { + acc_0 = vld1q_f16(bias); + bias += C8NUM; + } + + int di = 0; + for (; di < depth - C8NUM + 1; di += C8NUM) { + float16x8_t av = vld1q_f16(a + di); + float16x8_t bv_0; + float16x8_t bv_1; + for (int i = 0; i < C8NUM; i += C2NUM) { + bv_0 = vld1q_f16(bv_base); // bv_i为一行,8列数据 + acc_0 = vfmaq_n_f16(acc_0, bv_0, av[i]); // av[i]为向量中的一个值 + bv_base += C8NUM; + + bv_1 = vld1q_f16(bv_base); // bv_i为一行,8列数据 + acc_0 = vfmaq_n_f16(acc_0, bv_1, av[i + 1]); // av[i]为向量中的一个值 + bv_base += C8NUM; + } + } + if (di < depth) { + for (; di < depth; ++di) { + float16_t ai = a[di]; + float16x8_t bv0 = vld1q_f16(bv_base); + bv_base += C8NUM; + acc_0 = vfmaq_n_f16(acc_0, bv0, ai); + } + } + if (act_type == ActType_Relu) { + acc_0 = vmaxq_f16(acc_0, vdupq_n_f16((float16_t)0.0)); + } + if (act_type == ActType_Relu6) { + acc_0 = vminq_f16(vmaxq_f16(acc_0, vdupq_n_f16((float16_t)0.0)), vdupq_n_f16((float16_t)6.0)); + } + + // only save actual col num data + if (ci < C8NUM) { + for (int i = 0; i < ci; ++i) { + c[i] = acc_0[i]; + } + return; + } + vst1q_f16(c, acc_0); + c += C8NUM; + ci -= C8NUM; + } +} +#endif + +#ifdef ENABLE_ARM82_A32 +void MatVecMulA32Fp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, + int depth, int col) { + for (int ci = 0; ci < col; ci++) { + float value = 0; + for (int di = 0; di < depth; di++) { + value += a[di] * b[ci * depth + di]; + } + if (bias != NULL) value += bias[ci]; + if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); + if (act_type == ActType_Relu || act_type == ActType_Relu6) value = MSMAX(0.0f, value); + c[ci] = value; + } +} +#endif + +void MatVecMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type, + int depth, int col) { +#ifdef ENABLE_ARM64 + MatVecMulFp16Neon64(a, b, c, bias, (int)act_type, depth, col); +#else + MatVecMulA32NeonFp16(a, b, c, bias, (int)act_type, depth, col); +#endif +} + +#ifdef ENABLE_ARM64 +static void Row2Col16Block16(const float16_t *src_ptr, float16_t *dst_ptr, size_t col) { + size_t stride = col * 2; + asm volatile( + "mov x10, %[src_c]\n" + "mov x11, %[dst_c]\n" + + "ld1 {v0.8h}, [x10], %[stride]\n" + "ld1 {v1.8h}, [x10], %[stride]\n" + "ld1 {v2.8h}, [x10], %[stride]\n" + "ld1 {v3.8h}, [x10], %[stride]\n" + "ld1 {v4.8h}, [x10], %[stride]\n" + "ld1 {v5.8h}, [x10], %[stride]\n" + "ld1 {v6.8h}, [x10], %[stride]\n" + "ld1 {v7.8h}, [x10], %[stride]\n" + + "zip1 v16.8h, v0.8h, v1.8h\n" + "zip1 v17.8h, v2.8h, v3.8h\n" + "zip1 v18.8h, v4.8h, v5.8h\n" + "zip1 v19.8h, v6.8h, v7.8h\n" + + "ld1 {v8.8h}, [x10], %[stride]\n" + "ld1 {v9.8h}, [x10], %[stride]\n" + "ld1 {v10.8h}, [x10], %[stride]\n" + "ld1 {v11.8h}, [x10], %[stride]\n" + "ld1 {v12.8h}, [x10], %[stride]\n" + "ld1 {v13.8h}, [x10], %[stride]\n" + "ld1 {v14.8h}, [x10], %[stride]\n" + "ld1 {v15.8h}, [x10], %[stride]\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v24.2d, v20.2d, v22.2d\n" + "trn2 v25.2d, v20.2d, v22.2d\n" + "trn1 v26.2d, v21.2d, v23.2d\n" + "trn2 v27.2d, v21.2d, v23.2d\n" + + "zip1 v16.8h, v8.8h, v9.8h\n" + "zip1 v17.8h, v10.8h, v11.8h\n" + "zip1 v18.8h, v12.8h, v13.8h\n" + "zip1 v19.8h, v14.8h, v15.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v28.2d, v20.2d, v22.2d\n" + "trn2 v29.2d, v20.2d, v22.2d\n" + "trn1 v30.2d, v21.2d, v23.2d\n" + "trn2 v31.2d, v21.2d, v23.2d\n" + + "st1 {v24.8h}, [x11], #16\n" + "st1 {v28.8h}, [x11], #16\n" + "st1 {v26.8h}, [x11], #16\n" + "st1 {v30.8h}, [x11], #16\n" + "st1 {v25.8h}, [x11], #16\n" + "st1 {v29.8h}, [x11], #16\n" + "st1 {v27.8h}, [x11], #16\n" + "st1 {v31.8h}, [x11], #16\n" + + "zip2 v16.8h, v0.8h, v1.8h\n" + "zip2 v17.8h, v2.8h, v3.8h\n" + "zip2 v18.8h, v4.8h, v5.8h\n" + "zip2 v19.8h, v6.8h, v7.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v24.2d, v20.2d, v22.2d\n" + "trn2 v25.2d, v20.2d, v22.2d\n" + "trn1 v26.2d, v21.2d, v23.2d\n" + "trn2 v27.2d, v21.2d, v23.2d\n" + + "zip2 v16.8h, v8.8h, v9.8h\n" + "zip2 v17.8h, v10.8h, v11.8h\n" + "zip2 v18.8h, v12.8h, v13.8h\n" + "zip2 v19.8h, v14.8h, v15.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v28.2d, v20.2d, v22.2d\n" + "trn2 v29.2d, v20.2d, v22.2d\n" + "trn1 v30.2d, v21.2d, v23.2d\n" + "trn2 v31.2d, v21.2d, v23.2d\n" + + "st1 {v24.8h}, [x11], #16\n" + "st1 {v28.8h}, [x11], #16\n" + "st1 {v26.8h}, [x11], #16\n" + "st1 {v30.8h}, [x11], #16\n" + "st1 {v25.8h}, [x11], #16\n" + "st1 {v29.8h}, [x11], #16\n" + "st1 {v27.8h}, [x11], #16\n" + "st1 {v31.8h}, [x11], #16\n" + : + : [ dst_c ] "r"(dst_ptr), [ src_c ] "r"(src_ptr), [ stride ] "r"(stride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31"); +} +#endif + +void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) { + size_t row_up_16 = UP_ROUND(row, C16NUM); + size_t row16 = row / C16NUM * C16NUM; + size_t col8 = col / C8NUM * C8NUM; + const float16_t *src_r = src_ptr; + float16_t *dst_r = dst_ptr; + size_t ri = 0; + // find 16 block unit + for (; ri < row16; ri += C16NUM) { + size_t ci = 0; + for (; ci < col8; ci += C8NUM) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C16NUM; +#ifdef ENABLE_ARM64 + Row2Col16Block16(src_c, dst_c, col); +#else + for (int tr = 0; tr < C16NUM; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + dst_c[tc * C16NUM + tr] = src_c[tr * col + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C16NUM; + for (size_t i = 0; i < C16NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C16NUM * col; + dst_r += C16NUM * col; + } + for (; ri < row; ri++) { + for (size_t i = 0; i < col; ++i) { + dst_r[i * C16NUM] = src_r[i]; + } + src_r += col; + dst_r += 1; + } + for (; ri < row_up_16; ri++) { + for (size_t i = 0; i < col; i++) { + dst_r[i * C16NUM] = 0; + } + dst_r += 1; + } + return; +} + +#ifdef ENABLE_ARM64 +void RowMajor2ColLadder12MajorFp16(const float16_t *src, float16_t *dst_ptr, int row, int col) { + // Col12Major ==> Col8Major ==> Col4Major + const float16_t *src_r = src; + float16_t *dst_r = dst_ptr; + int ri = 0; + size_t col8 = col / C8NUM * C8NUM; + // find 16 block unit + for (; ri <= row - C12NUM; ri += C12NUM) { + size_t ci = 0; + for (; ci < col8; ci += C8NUM) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C12NUM; + Transpose12x8ARM64Fp16(src_c, dst_c, col * C2NUM, C24NUM); + } + for (; ci < col; ci++) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C12NUM; + for (size_t i = 0; i < C12NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C12NUM * col; + dst_r += C12NUM * col; + } + for (; ri <= row - C8NUM; ri += C8NUM) { + size_t ci = 0; + for (; ci < col8; ci += C8NUM) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C8NUM; + Transpose8x8ARM64Fp16(src_c, dst_c, col * sizeof(float16_t), C8NUM * sizeof(float16_t)); + } + for (; ci < col; ci++) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C8NUM; + for (size_t i = 0; i < C8NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C8NUM * col; + dst_r += C8NUM * col; + } + for (; ri <= row - C4NUM; ri += C4NUM) { + size_t ci = 0; + for (; ci < col8; ci += C8NUM) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C4NUM; + Transpose4x8ARM64Fp16(src_c, dst_c, col * sizeof(float16_t), C4NUM * sizeof(float16_t)); + } + for (; ci < col; ci++) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C4NUM; + for (size_t i = 0; i < C4NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C4NUM * col; + dst_r += C4NUM * col; + } + if (ri < row) { + memcpy(dst_r, src_r, (row - ri) * col * C2NUM); + } +} + +void RowMajor2RowLadder12MajorFp16(const float16_t *src, float16_t *dst, int row, int col) { + // Row12 ==> Row8 ==> Row4 + for (int r = 0; r < row; r++) { + int c = 0; + for (; c <= col - C12NUM; c += C12NUM) { + MS_FLOAT16X8 src_data = MS_LDQ_F16(src + r * col + c); + MS_FLOAT16X4 src_data1 = MS_LD_F16(src + r * col + c + C8NUM); + MS_STQ_F16(dst + c / C12NUM * C12NUM * row + r * C12NUM, src_data); + MS_ST_F16(dst + c / C12NUM * C12NUM * row + r * C12NUM + C8NUM, src_data1); + } + for (; c <= col - C8NUM; c += C8NUM) { + MS_FLOAT16X8 src_data = MS_LDQ_F16(src + r * col + c); + MS_STQ_F16(dst + c / C12NUM * C12NUM * row + r * C8NUM, src_data); + } + for (; c <= col - C4NUM; c += C4NUM) { + MS_FLOAT16X4 src_data = MS_LD_F16(src + r * col + c); + MS_ST_F16(dst + c / C4NUM * C4NUM * row + r * C4NUM, src_data); + } + for (; c < col; ++c) { + dst[c / C4NUM * C4NUM * row + r + c % C4NUM * row] = src[r * col + c]; + } + } +} + +void RowMajor2ColNMajorFp16srcFp16(const float16_t *src_ptr, float16_t *dst_ptr, int row, int col) { + const float16_t *src_r = src_ptr; + float16_t *dst_r = dst_ptr; + int ri = 0; + size_t col8 = col / C8NUM * C8NUM; + // find 16 block unit + for (; ri <= row - C16NUM; ri += C16NUM) { + size_t ci = 0; + for (; ci < col8; ci += C8NUM) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C16NUM; + Row2Col16Block16(src_c, dst_c, col); + } + for (; ci < col; ci++) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C16NUM; + for (size_t i = 0; i < C16NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C16NUM * col; + dst_r += C16NUM * col; + } + for (; ri <= row - C8NUM; ri += C8NUM) { + size_t ci = 0; + for (; ci < col8; ci += C8NUM) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C8NUM; + Transpose8x8ARM64Fp16(src_c, dst_c, col * sizeof(float16_t), C8NUM * sizeof(float16_t)); + } + for (; ci < col; ci++) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C8NUM; + for (size_t i = 0; i < C8NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C8NUM * col; + dst_r += C8NUM * col; + } + for (; ri <= row - C4NUM; ri += C4NUM) { + size_t ci = 0; + for (; ci < col8; ci += C8NUM) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C4NUM; + Transpose4x8ARM64Fp16(src_c, dst_c, col * sizeof(float16_t), C4NUM * sizeof(float16_t)); + } + for (; ci < col; ci++) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C4NUM; + for (size_t i = 0; i < C4NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C4NUM * col; + dst_r += C4NUM * col; + } + for (; ri < row; ri++) { + for (size_t i = 0; i < col; ++i) { + dst_r[i * C4NUM] = src_r[i]; + } + src_r += col; + dst_r += 1; + } +} + +void RowMajor2ColNMajorFp16(const void *src_ptr, float16_t *dst_ptr, int row, int col, bool is_fp32_src) { + // Col16Major ==> Col8Major ==> Col4Major + if (!is_fp32_src) { + RowMajor2ColNMajorFp16srcFp16((const float16_t *)src_ptr, dst_ptr, row, col); + return; + } + const float *src_r = src_ptr; + float16_t *dst_r = dst_ptr; + int ri = 0; + // find 16 block unit + for (; ri <= row - C16NUM; ri += C16NUM) { + for (int r = 0; r < C16NUM; ++r) { + for (int c = 0; c < col; ++c) { + dst_r[c * C16NUM + r % C16NUM] = src_r[r * col + c]; + } + } + src_r += C16NUM * col; + dst_r += C16NUM * col; + } + for (; ri <= row - C8NUM; ri += C8NUM) { + for (int r = 0; r < C8NUM; ++r) { + for (int c = 0; c < col; ++c) { + dst_r[c * C8NUM + r % C8NUM] = src_r[r * col + c]; + } + } + src_r += C8NUM * col; + dst_r += C8NUM * col; + } + for (; ri <= row - C4NUM; ri += C4NUM) { + for (int r = 0; r < C4NUM; ++r) { + for (int c = 0; c < col; ++c) { + dst_r[c * C4NUM + r % C4NUM] = src_r[r * col + c]; + } + } + src_r += C4NUM * col; + dst_r += C4NUM * col; + } + for (; ri < row; ++ri) { + for (size_t i = 0; i < col; ++i) { + dst_r[i * C4NUM] = src_r[i]; + } + src_r += col; + dst_r += 1; + } +} + +void RowMajor2RowNMajorFp16(const void *src_ptr, float16_t *dst, int row, int col, bool is_fp32_src) { + // Row16 ==> Row8 ==> Row4 + if (is_fp32_src) { + const float *src = (const float *)src_ptr; + for (int r = 0; r < row; r++) { + int c = 0; + for (; c <= col - C16NUM; c += C16NUM) { + const float *cur_src = src + r * col + c; + MS_FLOAT32X4X4 src_f32_data = {MS_LDQ_F32(cur_src), MS_LDQ_F32(cur_src + C4NUM), MS_LDQ_F32(cur_src + C8NUM), + MS_LDQ_F32(cur_src + C12NUM)}; + MS_FLOAT16X4X4 res = { + MS_CVT_F16_F32(src_f32_data.val[0]), + MS_CVT_F16_F32(src_f32_data.val[1]), + MS_CVT_F16_F32(src_f32_data.val[2]), + MS_CVT_F16_F32(src_f32_data.val[3]), + }; + MS_ST4_F16(dst + c / C16NUM * C16NUM * row + r * C16NUM, res); + } + for (; c <= col - C8NUM; c += C8NUM) { + const float *cur_src = src + r * col + c; + MS_FLOAT32X4X2 src_f32_data = {MS_LDQ_F32(cur_src), MS_LDQ_F32(cur_src + C4NUM)}; + MS_FLOAT16X4X2 res = { + MS_CVT_F16_F32(src_f32_data.val[0]), + MS_CVT_F16_F32(src_f32_data.val[1]), + }; + MS_ST2_F16(dst + c / C8NUM * C8NUM * row + r * C8NUM, res); + } + for (; c <= col - C4NUM; c += C4NUM) { + MS_FLOAT16X4 src_data = MS_CVT_F16_F32(MS_LDQ_F32(src + r * col + c)); + MS_ST_F16(dst + c / C4NUM * C4NUM * row + r * C4NUM, src_data); + } + for (; c < col; ++c) { + dst[c / C4NUM * C4NUM * row + r * C4NUM + c % C4NUM] = src[r * col + c]; + } + } + return; + } + const float16_t *src = (const float16_t *)src_ptr; + for (int r = 0; r < row; r++) { + int c = 0; + for (; c <= col - C16NUM; c += C16NUM) { + MS_FLOAT16X8 src_data = MS_LDQ_F16(src + r * col + c); + MS_FLOAT16X8 src_data1 = MS_LDQ_F16(src + r * col + c + C8NUM); + MS_STQ_F16(dst + c / C16NUM * C16NUM * row + r * C16NUM, src_data); + MS_STQ_F16(dst + c / C16NUM * C16NUM * row + r * C16NUM + C8NUM, src_data1); + } + for (; c <= col - C8NUM; c += C8NUM) { + MS_FLOAT16X8 src_data = MS_LDQ_F16(src + r * col + c); + MS_STQ_F16(dst + c / C8NUM * C8NUM * row + r * C8NUM, src_data); + } + for (; c <= col - C4NUM; c += C4NUM) { + MS_FLOAT16X4 src_data = MS_LD_F16(src + r * col + c); + MS_ST_F16(dst + c / C4NUM * C4NUM * row + r * C4NUM, src_data); + } + for (; c < col; ++c) { + dst[c / C4NUM * C4NUM * row + r * C4NUM + c % C4NUM] = src[r * col + c]; + } + } +} +#endif + +void RowMajor2Col12MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) { + size_t row_up_12 = UP_ROUND(row, C12NUM); + size_t row12 = row / C12NUM * C12NUM; + size_t col8 = col / C8NUM * C8NUM; + const float16_t *src_r = src_ptr; + float16_t *dst_r = dst_ptr; + size_t ri = 0; + // transpose 12x8 + for (; ri < row12; ri += C12NUM) { + size_t ci = 0; + for (; ci < col8; ci += C8NUM) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C12NUM; +#ifdef ENABLE_ARM64 + Transpose12x8ARM64Fp16(src_c, dst_c, col * sizeof(float16_t), 24); +#elif ENABLE_ARM82_A32 + Transpose12x8A32Fp16(src_c, dst_c, col * sizeof(float16_t), 24); +#else + for (int tr = 0; tr < C12NUM; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + dst_c[tc * C12NUM + tr] = src_c[tr * col + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C12NUM; + for (size_t i = 0; i < C12NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C12NUM * col; + dst_r += C12NUM * col; + } + for (; ri < row; ri++) { + for (size_t i = 0; i < col; ++i) { + dst_r[i * C12NUM] = src_r[i]; + } + src_r += col; + dst_r += 1; + } + for (; ri < row_up_12; ri++) { + for (size_t i = 0; i < col; i++) { + dst_r[i * C12NUM] = 0; + } + dst_r += 1; + } +} + +void RowMajor2Col16MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src) { + if (is_fp32_src) { + const float *fp32_src = (const float *)src; + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int r_div16 = r / 16; + int r_mod16 = r % 16; + dst[r_div16 * 16 * col + c * 16 + r_mod16] = (float16_t)(fp32_src[r * col + c]); + } + } + } else { + const float16_t *fp16_src = (const float16_t *)src; + RowMajor2Col16MajorFp16Opt(fp16_src, dst, row, col); + } + return; +} + +void RowMajor2Col12MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src) { + if (is_fp32_src) { + const float *fp32_src = (const float *)src; + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int r_div12 = r / 12; + int r_mod12 = r % 12; + dst[r_div12 * 12 * col + c * 12 + r_mod12] = (float16_t)(fp32_src[r * col + c]); + } + } + } else { + const float16_t *fp16_src = (const float16_t *)src; + RowMajor2Col12MajorFp16Opt(fp16_src, dst, row, col); + } + return; +} + +void RowMajor2Row16MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src) { + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int c_div16 = c / 16; + int c_mod16 = c % 16; + if (is_fp32_src) { + dst[c_div16 * 16 * row + r * 16 + c_mod16] = (float16_t)(((const float *)src)[r * col + c]); + } else { + dst[c_div16 * 16 * row + r * 16 + c_mod16] = ((const float16_t *)src)[r * col + c]; + } + } + } +} + +void RowMajor2Row16MajorFp16Opt(const float16_t *src, float16_t *dst, int row, int col) { + int col_align = UP_ROUND(col, C16NUM); + for (int r = 0; r < row; r++) { + int c = 0; + for (; c < col; c++) { + int c_div16 = c / C16NUM; + int c_mod16 = c % C16NUM; + dst[c_div16 * C16NUM * row + r * C16NUM + c_mod16] = src[r * col + c]; + } + for (; c < col_align; c++) { + int c_div16 = c / C16NUM; + int c_mod16 = c % C16NUM; + dst[c_div16 * C16NUM * row + r * C16NUM + c_mod16] = (float16_t)0.0; + } + } +} + +void RowMajor2Row12MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src) { + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int c_div12 = c / 12; + int c_mod12 = c % 12; + if (is_fp32_src) { + dst[c_div12 * 12 * row + r * 12 + c_mod12] = (float16_t)(((const float *)src)[r * col + c]); + } else { + dst[c_div12 * 12 * row + r * 12 + c_mod12] = ((const float16_t *)src)[r * col + c]; + } + } + } +} + +void RowMajor2Row8MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src) { + int down_c8 = col / C8NUM; + int stride = C8NUM * row; + for (int r = 0; r < row; r++) { + int c = 0; + for (; c < down_c8; c++) { + MS_FLOAT16X8 src_data = MS_LDQ_F16((const float16_t *)src + r * col + c * C8NUM); + MS_STQ_F16(dst + c * stride + r * C8NUM, src_data); + } + c *= C8NUM; + for (; c < col; c++) { + int c_div8 = c / 8; + int c_mod8 = c % 8; + dst[c_div8 * stride + r * 8 + c_mod8] = ((const float16_t *)src)[r * col + c]; + } + } +} + +void RowMajor2ColMajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src) { + for (int r = 0; r < row; ++r) { + for (int c = 0; c < col; ++c) { + if (is_fp32_src) { + dst[c * row + r] = (float16_t)(((const float *)src)[r * col + c]); + } else { + dst[c * row + r] = ((const float16_t *)src)[r * col + c]; + } + } + } +} + +#ifdef ENABLE_ARM64 +void RowMajor2Col8MajorFp16_arm64(const float16_t *src_c, float16_t *dst_c, size_t col) { + size_t stride = col * sizeof(float16_t); + asm volatile( + "mov x10, %[src_c]\n" + "mov x11, %[dst_c]\n" + + "ld1 {v0.8h}, [x10], %[stride]\n" + "ld1 {v1.8h}, [x10], %[stride]\n" + "ld1 {v2.8h}, [x10], %[stride]\n" + "ld1 {v3.8h}, [x10], %[stride]\n" + "ld1 {v4.8h}, [x10], %[stride]\n" + "ld1 {v5.8h}, [x10], %[stride]\n" + "ld1 {v6.8h}, [x10], %[stride]\n" + "ld1 {v7.8h}, [x10], %[stride]\n" + + "zip1 v8.8h, v0.8h, v1.8h\n" + "zip2 v9.8h, v0.8h, v1.8h\n" + "zip1 v10.8h, v2.8h, v3.8h\n" + "zip2 v11.8h, v2.8h, v3.8h\n" + "zip1 v12.8h, v4.8h, v5.8h\n" + "zip2 v13.8h, v4.8h, v5.8h\n" + "zip1 v14.8h, v6.8h, v7.8h\n" + "zip2 v15.8h, v6.8h, v7.8h\n" + + "trn1 v16.4s, v8.4s, v10.4s\n" + "trn2 v17.4s, v8.4s, v10.4s\n" + "trn1 v18.4s, v12.4s, v14.4s\n" + "trn2 v19.4s, v12.4s, v14.4s\n" + "trn1 v20.4s, v9.4s, v11.4s\n" + "trn2 v21.4s, v9.4s, v11.4s\n" + "trn1 v22.4s, v13.4s, v15.4s\n" + "trn2 v23.4s, v13.4s, v15.4s\n" + + "trn1 v0.2d, v16.2d, v18.2d\n" + "trn1 v1.2d, v17.2d, v19.2d\n" + "trn2 v2.2d, v16.2d, v18.2d\n" + "trn2 v3.2d, v17.2d, v19.2d\n" + "trn1 v4.2d, v20.2d, v22.2d\n" + "trn1 v5.2d, v21.2d, v23.2d\n" + "trn2 v6.2d, v20.2d, v22.2d\n" + "trn2 v7.2d, v21.2d, v23.2d\n" + + "st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x11], #64\n" + "st1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x11], #64\n" + + : + : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"); + return; +} +#endif + +void RowMajor2Col8MajorFp16SrcFp16(const float16_t *src, float16_t *dst, int row, int col) { + int row8 = row / C8NUM * C8NUM; +#ifdef ENABLE_ARM64 + int col_skip = col / C8NUM * C8NUM; + int skip_size = C8NUM; +#else + int col_skip = col / C4NUM * C4NUM; + int skip_size = C4NUM; +#endif + const float16_t *src_r = src; + float16_t *dst_r = dst; + + int ri = 0; + for (; ri < row8; ri += C8NUM) { + int ci = 0; + for (; ci < col_skip; ci += skip_size) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C8NUM; + +#ifdef ENABLE_ARM64 + RowMajor2Col8MajorFp16_arm64(src_c, dst_c, col); +#else + for (int tr = 0; tr < C8NUM; tr++) { + for (int tc = 0; tc < C4NUM; tc++) { + dst_c[tc * C8NUM + tr] = src_c[tr * col + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C8NUM; + for (int i = 0; i < C8NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C8NUM * col; + dst_r += C8NUM * col; + } + for (; ri < row; ri++, src_r += col, dst_r++) { + for (int i = 0; i < col; i++) { + dst_r[i * C8NUM] = src_r[i]; + } + } + + for (; ri < UP_ROUND(row, C8NUM); ri++, dst_r++) { + for (int i = 0; i < col; i++) { + dst_r[i * C8NUM] = 0; + } + } +} + +void RowMajor2Col8MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src) { + if (!is_fp32_src) { + return RowMajor2Col8MajorFp16SrcFp16(src, dst, row, col); + } + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int r_div8 = r / 8; + int r_mod8 = r % 8; + dst[r_div8 * 8 * col + c * 8 + r_mod8] = (float16_t)(((const float *)src)[r * col + c]); + } + } +} + +#if defined(ENABLE_DEBUG) && defined(ENABLE_ARM64) +// arm64 matmul +void MatmulBaseFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, + size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc) { + int r16 = row / C16NUM * C16NUM; + int r8 = row / C8NUM * C8NUM; + for (int r = 0; r < row; ++r) { + int row_tile = 0; + if (r < r16) { + row_tile = C16NUM; + } else if (r < r8) { + row_tile = C8NUM; + } else { + row_tile = C4NUM; + } + int index = r / row_tile * row_tile * depth + r % row_tile; + for (int t = 0; t < col; ++t) { + int c_div = t / C8NUM; + int c_mod = t % C8NUM; + float16_t res = bias[t]; + for (int d = 0; d < depth; ++d) { + res += a[index + d * row_tile] * b[c_div * depth * C8NUM + d * C8NUM + c_mod]; + } + c[r * col + t] = res; + } + } +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/matmul_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/matmul_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..c057b3d89b45006fbf740766e3d4b076f8fd2d1d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/matmul_fp16.h @@ -0,0 +1,128 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_MATMUL_FP16_H_ +#define NNACL_FP16_MATMUL_FP16_H_ + +#include +#include +#include "nnacl/errorcode.h" +#include "nnacl/matmul_parameter.h" +#include "nnacl/op_base.h" +#include "nnacl/intrinsics/ms_simd_instructions_fp16.h" +#include "nnacl/fp16/pack_fp16.h" + +#define ADD_BIAS(value, bias, c) \ + if (bias != NULL) value = value + bias[c]; + +#define DO_RELU(value, act_type) \ + if (act_type == ActType_Relu) value = MSMAX(0.0f, value); + +#define DO_RELU6(value, act_type) \ + if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); \ + if (act_type == ActType_Relu6) value = MSMAX(0.0f, value); + +#ifdef __cplusplus +extern "C" { +#endif +void MatMul16x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, + int deep, int row, int col, int stride, int write_mode); + +void MatMul12x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, + int deep, int row, int col, int stride, int write_mode); + +#ifdef ENABLE_ARM64 +void RowMajor2ColLadder12MajorFp16(const float16_t *src, float16_t *dst_ptr, int row, int col); + +void RowMajor2RowLadder12MajorFp16(const float16_t *src, float16_t *dst, int row, int col); + +void RowMajor2ColNMajorFp16(const void *src, float16_t *dst_ptr, int row, int col, bool is_fp32_src); + +void RowMajor2RowNMajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src); + +void MatMul12x16Fp16Opt(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, + int deep, int row, int col, size_t stride, size_t out_type); +void MatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, + size_t depth, size_t row, size_t col, size_t stride, bool write_nhwc); + +void MatmulFp16Neon64Opt(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, + size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc); + +void MatmulBaseFp16Neon(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, + size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc); + +void MatmulFp16OptV2(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, + size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc); + +#ifdef ENABLE_DEBUG +void MatmulBaseFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, + size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc); +#endif + +void MatVecMulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, + int depth, int col); + +void VecMatmulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, int depth, + int col); +void VecMatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, + int depth, int col); +#elif ENABLE_ARM82_A32 +void MatMul12x8A32Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, + int deep, int row, int col, int stride, int write_mode); + +void MatVecMulA32Fp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, + int depth, int col); + +void MatVecMulA32NeonFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, + int depth, int col); +#endif + +void MatMul12x16Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, + int deep, int row, int col, size_t stride, size_t out_type); + +void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type, + int depth, int row, int col, int stride, int out_type); + +void MatVecMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type, + int depth, int col); + +void ColMajor2Row8MajorFp16(const void *src_ptr, float16_t *dst_ptr, size_t row, size_t col, bool src_float16); + +void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col); + +void RowMajor2Col12MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col); + +void RowMajor2Col16MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src); + +void RowMajor2Col12MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src); + +void RowMajor2Row16MajorFp16Opt(const float16_t *src, float16_t *dst, int row, int col); + +void RowMajor2Row16MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src); + +void RowMajor2Row12MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src); + +void RowMajor2Row8MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src); + +void RowMajor2Col8MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src); + +void RowMajor2ColMajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_MATMUL_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/matrix_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/matrix_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..a716dd974ec821597b2f05f113b0e3ca71f9ced3 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/matrix_fp16.c @@ -0,0 +1,83 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp16/matrix_fp16.h" + +void MatrixMultiplyFp16(const float16_t *matrix_a, const float16_t *matrix_b, float16_t *matrix_c, int m, int k, + int n) { + int count = 0; + for (int h = 0; h < m; h++) { + int h_offset = h * k; + for (int w = 0; w < n; w++) { + float16_t res = 0; + for (int i = 0; i < k; i++) { + res += *(matrix_a + h_offset + i) * *(matrix_b + w + i * n); + } + *(matrix_c + count) = res; + count++; + } + } +} + +#ifndef ENABLE_ARM64 +void MatrixMultiplyWinogradFp16(const float16_t *matix_a, const float16_t *matrix_b, float16_t *matrix_c, int m, int k, + int n, int in_channel) { + int cnt = 0; + for (int i = 0; i < m; ++i) { + for (int j = 0; j < n; ++j) { + for (int y = 0; y < in_channel; ++y) { + float tmp = 0; + for (int z = 0; z < k; ++z) { + tmp += matix_a[z * in_channel + y + i * in_channel * k] * matrix_b[j + z * n]; + } + matrix_c[cnt++] = tmp; + } + } + } +} +#endif + +void MatrixMultiplyVecFp16(const float16x8_t *matrix_a, const float16x8_t *matrix_b, float16x8_t *matrix_c, + const float16_t *bias, int m, int k, int n) { + if (bias == NULL) { + int count = 0; + for (int h = 0; h < m; h++) { + int h_offset = h * k; + for (int w = 0; w < n; w++) { + float16x8_t res = vmovq_n_f16(0); + for (int i = 0; i < k; i++) { + res = vaddq_f16(res, vmulq_f16(matrix_a[h_offset + i], matrix_b[w + i * n])); + } + matrix_c[count] = res; + count++; + } + } + } else { + int count = 0; + float16x8_t bias_ptr = vld1q_f16(bias); + for (int h = 0; h < m; h++) { + int h_offset = h * k; + for (int w = 0; w < n; w++) { + float16x8_t res = vmovq_n_f16(0); + for (int i = 0; i < k; i++) { + res = vaddq_f16(res, vmulq_f16(matrix_a[h_offset + i], matrix_b[w + i * n])); + } + matrix_c[count] = vaddq_f16(res, bias_ptr); + count++; + } + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/matrix_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/matrix_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..e347c242661198296d8778c9e2c7b9c367f8f1dc --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/matrix_fp16.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_MATRIX_FP16_H_ +#define NNACL_FP16_MATRIX_FP16_H_ + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif +void MatrixMultiplyFp16(const float16_t *matrix_a, const float16_t *matrix_b, float16_t *matrix_c, int m, int k, int n); + +void MatrixMultiplyVecFp16(const float16x8_t *matrix_a, const float16x8_t *matrix_b, float16x8_t *matrix_c, + const float16_t *bias, int m, int k, int n); +void MatrixMultiplyWinogradFp16(const float16_t *matix_a, const float16_t *matrix_b, float16_t *matrix_c, int m, int k, + int n, int in_channel); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_MATRIX_FP16_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/cuda_helper.cc b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/one_hot_fp16.c similarity index 34% rename from mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/cuda_helper.cc rename to mindspore-lite/ops/kernel/cpu/nnacl/fp16/one_hot_fp16.c index ddd77844817048619bfe731d0879bc631e764907..298e3fba6a8334bed9ccc9d1602d5ec14edb1716 100644 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/cuda_helper.cc +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/one_hot_fp16.c @@ -14,35 +14,37 @@ * limitations under the License. */ -#include "src/extendrt/delegate/tensorrt/cuda_impl/cuda_helper.h" -#include -#include "src/common/log_util.h" - -CudaHelper &CudaHelper::GetInstance() { - static CudaHelper instance; - return instance; -} -int CudaHelper::GetThreadNum() const { return threads_per_block_; } -int CudaHelper::GetThreadNum(const int block_size) const { - return std::min(threads_per_block_, ((block_size - 1) / 32 + 1) * 32); -} -int CudaHelper::GetBlocksNum(const int total_threads) const { - return std::min(((total_threads - 1) / threads_per_block_) + 1, max_blocks_); -} -int CudaHelper::GetBlocksNum(const int total_threads, const int block_size) const { - int valid_block_size = std::min(block_size, threads_per_block_); - if (valid_block_size == 0) { - MS_LOG(ERROR) << "invalid input of block_size: " << block_size; - return 0; +#include "nnacl/fp16/one_hot_fp16.h" +#include "nnacl/errorcode.h" +int OneHotToFp16(const int *indices, float16_t on_value, float16_t off_value, float16_t *output, + const OneHotStruct *one_hot_param, const int tid, const int thread_num) { + if (indices == NULL || one_hot_param == NULL || output == NULL) { + return NNACL_NULL_PTR; + } + if (thread_num == 0) { + return NNACL_PARAM_INVALID; } - return std::min(((total_threads - 1) / valid_block_size) + 1, max_blocks_); -} -CudaHelper::CudaHelper() { - int device_id = 0; - (void)cudaGetDevice(&device_id); - cudaDeviceProp prop; - (void)cudaGetDeviceProperties(&prop, device_id); - threads_per_block_ = prop.maxThreadsPerBlock; - max_blocks_ = prop.multiProcessorCount; + int outer_size = one_hot_param->outer_size_; + int inner_size = one_hot_param->inner_size_; + int depth = one_hot_param->depth_; + int i, j, k; + for (i = tid; i < outer_size; i += thread_num) { + float16_t *output_ptr = output + i * depth * inner_size; + for (k = 0; k < depth; k++) { // output layout: outer_size * depth * inner_size + const int *indices_ptr = indices + i * inner_size; + for (j = 0; j < inner_size; j++) { + *output_ptr = off_value; + int index = *(indices_ptr++); + if (one_hot_param->support_neg_index_ && index < 0) { + index += depth; + } + if (index == k) { + *output_ptr = on_value; + } + output_ptr++; + } + } + } + return NNACL_OK; } diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/one_hot_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/one_hot_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..39dbdded8f36b7464c4aa1854bc0e2809248315a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/one_hot_fp16.h @@ -0,0 +1,34 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_ONE_HOT_FP16_H_ +#define NNACL_FP16_ONE_HOT_FP16_H_ + +#include +#include "nnacl/op_base.h" +#include "nnacl/one_hot_parameter.h" +#include "nnacl/kernel/one_hot.h" + +#ifdef __cplusplus +extern "C" { +#endif +int OneHotToFp16(const int *indices, float16_t on_value, float16_t off_value, float16_t *output, + const OneHotStruct *one_hot_param, const int tid, const int thread_num); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_ONE_HOT_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/pack_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/pack_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..78f255574db60b35d387408c9f4f072a688461f7 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/pack_fp16.c @@ -0,0 +1,933 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp16/pack_fp16.h" +#include + +#ifdef ENABLE_ARM +void PackWeightConvDw3x3Fp16(const void *src, void *dst, int channel) { + // nchw to nc8hw8 with 1D F(2,3) + for (int i = 0; i < channel; i++) { + float16_t *src_kernel = (float16_t *)src + i * 9; + float16_t *dst_kernel = (float16_t *)dst + (i / 8) * 96 + i % 8; + for (int y = 0; y < 3; y++) { + float16_t g0 = src_kernel[3 * y]; + float16_t g1 = src_kernel[3 * y + 1]; + float16_t g2 = src_kernel[3 * y + 2]; + + dst_kernel[32 * y] = g0; + dst_kernel[32 * y + 8] = (float16_t)0.5 * (g0 + g1 + g2); + dst_kernel[32 * y + 16] = (float16_t)0.5 * (g0 - g1 + g2); + dst_kernel[32 * y + 24] = g2; + } + } +} +#endif + +void PackHWCToWHCFp16(const float16_t *src, float16_t *dst, int height, int width, int channel) { + for (int i = 0; i < height; ++i) { + for (int j = 0; j < width; ++j) { + memcpy(dst + (j * height + i) * channel, src + (i * width + j) * channel, channel * sizeof(float16_t)); + } + } +} + +void PackNHWCToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_oc_offset = b * plane * channel; + int dst_oc_offset = b * plane * c4 * C4NUM; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_oc_offset + k * channel; + int dst_kernel_offset = dst_oc_offset + k * C4NUM; + for (int i = 0; i < channel; i++) { + int c4_block_num = i / C4NUM; + int c4_block_rem = i % C4NUM; + int src_ic_offset = src_kernel_offset + i; + int dst_ic_offset = dst_kernel_offset + c4_block_num * plane * C4NUM + c4_block_rem; + ((float16_t *)dst + dst_ic_offset)[0] = ((float16_t *)src + src_ic_offset)[0]; + } + } + } +} + +void PackNHWCToNC8HW8NotAlignedFp16(const float16_t *src, float16_t *dst, const int batch, const int plane, + const int channel) { + if (channel <= C8NUM) { + memcpy(dst, src, batch * plane * channel * sizeof(float16_t)); + return; + } + int tmp = DOWN_DIV(channel, C8NUM); + int c_res = channel - tmp * C8NUM; + int c8_block = tmp * plane * C8NUM; + for (int b = 0; b < batch; b++) { + int batch_oc_offset = b * plane * channel; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = batch_oc_offset + k * channel; + int dst_kernel_offset = batch_oc_offset + k * C8NUM; + int c = 0; + for (; c <= channel - C8NUM; c += C8NUM) { + float16x8_t src_data = vld1q_f16(src + src_kernel_offset + c); + vst1q_f16(dst + dst_kernel_offset + c * plane, src_data); + } + for (; c < channel; ++c) { + dst[batch_oc_offset + c8_block + k * c_res + c - tmp * C8NUM] = src[src_kernel_offset + c]; + } + } + } +} + +void PackNCHWToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * channel; + int dst_offset = b * plane * c4 * C4NUM; + for (int c = 0; c < channel; c++) { + int c4_block_num = c / C4NUM; + int c4_block_rem = c % C4NUM; + int src_c_offset = src_offset + c * plane; + int dst_c_offset = dst_offset + c4_block_num * plane * C4NUM; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k; + int dst_kernel_offset = dst_c_offset + C4NUM * k + c4_block_rem; + ((float16_t *)dst + dst_kernel_offset)[0] = ((float16_t *)src + src_kernel_offset)[0]; + } + } + } +} + +void PackNHWCToNCHWFp16(const void *src, void *dst, int batches, int plane, int channel, int task_id, + int thread_count) { +#ifdef ENABLE_ARM64 + // Transpose16x8 in arm64 + const int hw_tile = C16NUM; +#else + // Transpose8x8 in others + const int hw_tile = C8NUM; +#endif + int hw_align = plane / hw_tile; + int task_start = 0; + int task_end = plane; + if (thread_count > 0) { + int offset_hw = UP_DIV(hw_align, thread_count) * hw_tile; + task_start = offset_hw * task_id; + int count = plane - task_start; + if (count <= 0) { + return; + } + task_end = (task_id + 1) == thread_count ? plane : MSMIN(plane, task_start + offset_hw); + hw_align = task_start + ((task_end - task_start) >= offset_hw ? offset_hw : 0); + } else { + hw_align *= hw_tile; + } + int c8 = channel / C8NUM * C8NUM; + int batch = plane * channel; + for (int n = 0; n < batches; n++) { + const float16_t *src_batch = (const float16_t *)src + n * batch; + float16_t *dst_batch = (float16_t *)dst + n * batch; + int hw = task_start; + for (; hw < hw_align; hw += hw_tile) { + int c = 0; + for (; c < c8; c += C8NUM) { + const float16_t *src_ptr = src_batch + hw * channel + c; + float16_t *dst_ptr = dst_batch + c * plane + hw; +#ifdef ENABLE_ARM64 + size_t src_stride = channel * sizeof(float16_t); + size_t dst_stride = plane * sizeof(float16_t); + Transpose16x8ARM64Fp16(src_ptr, dst_ptr, src_stride, dst_stride); +#elif defined(ENABLE_ARM82_A32) + size_t src_stride = channel * sizeof(float16_t); + size_t dst_stride = plane * sizeof(float16_t); + Transpose8x8A32Fp16(src_ptr, dst_ptr, src_stride, dst_stride); +#else + for (int tr = 0; tr < hw_tile; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + dst_ptr[tc * plane + tr] = src_ptr[tr * channel + tc]; + } + } +#endif + } + for (; c < channel; c++) { + const float16_t *src_ptr = src_batch + hw * channel + c; + float16_t *dst_ptr = dst_batch + c * plane + hw; + for (size_t i = 0; i < hw_tile; i++) { + dst_ptr[i] = src_ptr[i * channel]; + } + } + } + for (; hw < task_end; hw++) { + const float16_t *src_ptr = src_batch + hw * channel; + float16_t *dst_ptr = dst_batch + hw; + for (size_t i = 0; i < channel; i++) { + dst_ptr[i * plane] = src_ptr[i]; + } + } + } +} + +void PackNCHWToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel, int task_id, int thread_count) { + return PackNHWCToNCHWFp16(src, dst, batch, channel, plane, task_id, thread_count); +} + +void PackNHWCToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel) { + int ic4 = UP_DIV(channel, C4NUM); + int c4_channel = ic4 * C4NUM; + int nhwc4_batch_unit_offset = ic4 * C4NUM * plane; + int ic_remainder_ = channel % C4NUM; + if (ic_remainder_ != 0) { + int nhwc4_batch_offset = 0; + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + for (int i = 0; i < plane; i++) { + float16_t *dst_per_plane = (float16_t *)dst + nhwc4_batch_offset + i * c4_channel; + memcpy(dst_per_plane, (float16_t *)src + batch_offset + i * channel, channel * sizeof(float16_t)); + for (int j = channel; j < c4_channel; ++j) { + dst_per_plane[j] = 0; + } + } + nhwc4_batch_offset += nhwc4_batch_unit_offset; + } + } else { + size_t ori_input_size = batch * plane * channel * sizeof(float16_t); + memcpy(dst, src, ori_input_size); + } +} + +void PackNHWCToNHWC8Fp16(const void *src, void *dst, int batch, int plane, int channel) { + int ic8 = UP_DIV(channel, C8NUM); + int c8_channel = ic8 * C8NUM; + int nhwc8_batch_unit_offset = ic8 * C8NUM * plane; + int ic_remainder_ = channel % C8NUM; + if (ic_remainder_ != 0) { + int nhwc8_batch_offset = 0; + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + for (int i = 0; i < plane; i++) { + float16_t *dst_per_plane = (float16_t *)dst + nhwc8_batch_offset + i * c8_channel; + memcpy(dst_per_plane, (float16_t *)src + batch_offset + i * channel, channel * sizeof(float16_t)); + for (int j = channel; j < c8_channel; ++j) { + dst_per_plane[j] = 0; + } + } + nhwc8_batch_offset += nhwc8_batch_unit_offset; + } + } else { + size_t ori_input_size = batch * plane * channel * sizeof(float16_t); + memcpy(dst, src, ori_input_size); + } +} + +void PackNHWC4ToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + int ic_remainder_ = channel % C4NUM; + if (ic_remainder_ != 0) { + int nhwc_batch_unit_offset = channel * plane; + for (int b = 0; b < batch; b++) { + int batch_offset = b * c4 * C4NUM * plane; + for (int i = 0; i < plane; i++) { + memcpy((float16_t *)dst + b * nhwc_batch_unit_offset + i * channel, + (float16_t *)src + batch_offset + i * c4 * C4NUM, channel * sizeof(float16_t)); + } + } + } else { + size_t ori_input_size = batch * plane * channel * sizeof(float16_t); + memcpy((float16_t *)dst, (float16_t *)src, ori_input_size); + } +} + +void PackNCHWToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel) { + int nhwc4_batch_offset = 0; + int ic4 = UP_DIV(channel, C4NUM); + int nhwc4_batch_unit_offset = ic4 * C4NUM * plane; + + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + for (int c = 0; c < channel; c++) { + int src_c_offset = batch_offset + c * plane; + int dst_c_offset = nhwc4_batch_offset + c; + for (int i = 0; i < plane; i++) { + int src_plane_offset = src_c_offset + i; + int dst_plane_offset = dst_c_offset + i * ic4 * C4NUM; + ((float16_t *)dst)[dst_plane_offset] = ((float16_t *)src)[src_plane_offset]; + } + } + nhwc4_batch_offset += nhwc4_batch_unit_offset; + } +} + +void PackNC4HW4ToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int c = 0; c < channel; c++) { + int c4_block_num = c / C4NUM; + int c4_block_res = c % C4NUM; + int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res; + int dst_c_offset = dst_offset + c4_block_num * C4NUM + c4_block_res; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k * C4NUM; + int dst_kernel_offset = dst_c_offset + k * c4 * C4NUM; + ((float16_t *)dst + dst_kernel_offset)[0] = ((float16_t *)src + src_kernel_offset)[0]; + } + } + } +} + +void PackNC4HW4ToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int c = 0; c < channel; c++) { + int c4_block_num = c / C4NUM; + int c4_block_res = c % C4NUM; + int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res; + int dst_c_offset = dst_offset + c; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k * C4NUM; + int dst_kernel_offset = dst_c_offset + k * channel; + ((float16_t *)dst + dst_kernel_offset)[0] = ((float16_t *)src + src_kernel_offset)[0]; + } + } + } +} + +void PackNC4HW4ToNCHWFp16(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int c = 0; c < channel; c++) { + int c4_block_num = c / C4NUM; + int c4_block_res = c % C4NUM; + int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res; + int dst_c_offset = dst_offset + c * plane; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k * C4NUM; + int dst_kernel_offset = dst_c_offset + k; + ((float16_t *)dst + dst_kernel_offset)[0] = ((float16_t *)src + src_kernel_offset)[0]; + } + } + } +} + +void PackNCHWFp32ToNC8HW8Fp16(const void *src_ptr, void *dst_ptr, int batch, int plane, int channel) { + const float *src = (const float *)src_ptr; + float16_t *dst = (float16_t *)dst_ptr; + int c8 = UP_DIV(channel, C8NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * channel; + int dst_offset = b * plane * c8 * C8NUM; + for (int c = 0; c < channel; c++) { + int c8_block_num = c / C8NUM; + int c8_block_rem = c % C8NUM; + int src_c_offset = src_offset + c * plane; + int dst_c_offset = dst_offset + c8_block_num * plane * C8NUM; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k; + int dst_kernel_offset = dst_c_offset + C8NUM * k + c8_block_rem; + (dst + dst_kernel_offset)[0] = (float16_t)(src + src_kernel_offset)[0]; + } + } + } +} + +void PackNCHWFp16ToNC8HW8Fp16(const void *src_ptr, void *dst_ptr, int batch, int plane, int channel) { + const float16_t *src = (const float16_t *)src_ptr; + float16_t *dst = (float16_t *)dst_ptr; + int c8 = UP_DIV(channel, C8NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * channel; + int dst_offset = b * plane * c8 * C8NUM; + for (int c = 0; c < channel; c++) { + int c8_block_num = c / C8NUM; + int c8_block_rem = c % C8NUM; + int src_c_offset = src_offset + c * plane; + int dst_c_offset = dst_offset + c8_block_num * plane * C8NUM; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k; + int dst_kernel_offset = dst_c_offset + C8NUM * k + c8_block_rem; + (dst + dst_kernel_offset)[0] = (float16_t)(src + src_kernel_offset)[0]; + } + } + } +} + +void PackNC8HW8ToNCHWFp16(const void *src_ptr, void *dst_ptr, int batch, int plane, int channel) { + const float16_t *src = (const float16_t *)src_ptr; + float16_t *dst = (float16_t *)dst_ptr; + int c8 = UP_ROUND(channel, C8NUM); + for (int b = 0; b < batch; b++) { + const float16_t *batch_src = src + b * plane * c8; + float16_t *batch_dst = dst + b * plane * channel; + + for (size_t c = 0; c < channel; c++) { + size_t c_div = c / C8NUM; + size_t c_mod = c % C8NUM; + for (size_t p = 0; p < plane; p++) { + int src_offset = c_div * plane * C8NUM + p * C8NUM + c_mod; + int dst_offset = c * plane + p; + batch_dst[dst_offset] = batch_src[src_offset]; + } + } + } +} + +void PackNHWCFp32ToNHWC8Fp16(const float *src, float16_t *dst, int batch, int plane, int channel) { + int c8_channel = UP_DIV(channel, C8NUM) * C8NUM; + for (int b = 0; b < batch; b++) { + float16_t *dst_batch = dst + b * plane * c8_channel; + const float *src_batch = src + b * plane * channel; + for (int i = 0; i < plane; i++) { + float16_t *dst_plane = dst_batch + i * c8_channel; + const float *src_plane = src_batch + i * channel; + for (int c = 0; c < channel; c++) { + dst_plane[c] = (float16_t)(src_plane[c]); + } + } + } +} + +void PackNHWCFp32ToC8HWN8Fp16(const float *src, float16_t *dst, int batch, int plane, int channel) { + for (int n = 0; n < batch; n++) { + for (int hw = 0; hw < plane; hw++) { + for (int c = 0; c < channel; c++) { + int c8div = c / C8NUM; + int c8mod = c % C8NUM; + int src_index = n * plane * channel + hw * channel + c; + int dst_index = c8div * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + c8mod; + dst[dst_index] = (float16_t)(src[src_index]); + } + } + } + return; +} + +void PackNC8HW8ToNHWCFp16(const float16_t *src, float16_t *dst, int batch, int plane, int channel) { + int c8 = UP_DIV(channel, C8NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c8 * C8NUM; + int dst_offset = b * plane * channel; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_offset + k * C8NUM; + int dst_kernel_offset = dst_offset + k * channel; + for (int c = 0; c < c8 - 1; c++) { + int src_c_offset = src_kernel_offset + c * plane * C8NUM; + int dst_c_offset = dst_kernel_offset + c * C8NUM; + vst1q_f16(dst + dst_c_offset, vld1q_f16(src + src_c_offset)); + } + // res part + int res_c = channel - (c8 - 1) * C8NUM; + for (int i = 0; i < res_c; i++) { + int src_res_c_offset = src_kernel_offset + (c8 - 1) * C8NUM * plane + i; + int dst_res_c_offset = dst_kernel_offset + (c8 - 1) * C8NUM + i; + ((float16_t *)dst + dst_res_c_offset)[0] = ((float16_t *)src + src_res_c_offset)[0]; + } + } + } +} + +void PackNHWCFp16ToC8HWN8Fp16(const float16_t *src, float16_t *dst, int batch, int plane, int channel) { + for (int n = 0; n < batch; n++) { + for (int hw = 0; hw < plane; hw++) { + for (int c = 0; c < channel; c++) { + int c8div = c / C8NUM; + int c8mod = c % C8NUM; + int src_index = n * plane * channel + hw * channel + c; + int dst_index = c8div * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + c8mod; + dst[dst_index] = src[src_index]; + } + } + } + return; +} + +void PackNHWC8Fp16ToNHWCFp32(const float16_t *src, float *dst, int batch, int plane, int channel) { + int c8_channel = UP_DIV(channel, C8NUM) * C8NUM; + for (int b = 0; b < batch; b++) { + const float16_t *src_batch = src + b * plane * c8_channel; + float *dst_batch = dst + b * plane * channel; + for (int i = 0; i < plane; i++) { + const float16_t *src_plane = src_batch + i * c8_channel; + float *dst_plane = dst_batch + i * channel; + for (int c = 0; c < channel; c++) { + dst_plane[c] = (float16_t)(src_plane[c]); + } + } + } +} + +void PackNHWC8ToNHWCFp16(const float16_t *src, float16_t *dst, int batch, int plane, int channel) { + int c8_channel = UP_DIV(channel, C8NUM) * C8NUM; + for (int b = 0; b < batch; b++) { + const float16_t *src_batch = src + b * plane * c8_channel; + float16_t *dst_batch = dst + b * plane * channel; + for (int i = 0; i < plane; i++) { + const float16_t *src_plane = src_batch + i * c8_channel; + float16_t *dst_plane = dst_batch + i * channel; + memcpy(dst_plane, src_plane, channel * sizeof(float16_t)); + } + } +} + +#ifdef ENABLE_ARM82_A32 +inline void Transpose8x8A32Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride) { + asm volatile( + "mov r10, %[src]\n" + "mov r12, %[dst]\n" + "vld1.16 {q0}, [r10], %[src_stride]\n" + "vld1.16 {q2}, [r10], %[src_stride]\n" + "vld1.16 {q4}, [r10], %[src_stride]\n" + "vld1.16 {q6}, [r10], %[src_stride]\n" + + "vtrn.16 d0, d4\n" + "vtrn.16 d1, d5\n" + "vtrn.16 d8, d12\n" + "vtrn.16 d9, d13\n" + + "vld1.16 {q8}, [r10], %[src_stride]\n" + "vld1.16 {q10}, [r10], %[src_stride]\n" + "vld1.16 {q12}, [r10], %[src_stride]\n" + "vld1.16 {q14}, [r10], %[src_stride]\n" + + "vtrn.32 d0, d8\n" + "vtrn.32 d4, d12\n" + "vtrn.32 d1, d9\n" + "vtrn.32 d5, d13\n" + + "vtrn.16 d16, d20\n" + "vtrn.16 d17, d21\n" + "vtrn.16 d24, d28\n" + "vtrn.16 d25, d29\n" + + "vtrn.32 d16, d24\n" + "vtrn.32 d20, d28\n" + "vtrn.32 d17, d25\n" + "vtrn.32 d21, d29\n" + + "vswp d1, d16\n" + "vswp d5, d20\n" + "vswp d9, d24\n" + "vswp d13, d28\n" + + "vst1.16 {q0}, [r12], %[dst_stride]\n" + "vst1.16 {q2}, [r12], %[dst_stride]\n" + "vst1.16 {q4}, [r12], %[dst_stride]\n" + "vst1.16 {q6}, [r12], %[dst_stride]\n" + + "vst1.16 {q8}, [r12], %[dst_stride]\n" + "vst1.16 {q10}, [r12], %[dst_stride]\n" + "vst1.16 {q12}, [r12], %[dst_stride]\n" + "vst1.16 {q14}, [r12], %[dst_stride]\n" + + : + : [ dst ] "r"(dst), [ src ] "r"(src), [ src_stride ] "r"(src_stride), [ dst_stride ] "r"(dst_stride) + : "r10", "r12", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15"); +} + +inline void Transpose12x8A32Fp16(const float16_t *src_c, float16_t *dst_c, size_t src_stride, size_t dst_stride) { + asm volatile( + "mov r10, %[src_c]\n" + "mov r12, %[dst_c]\n" + + "vld1.16 {q0}, [r10], %[src_stride]\n" + "vld1.16 {q2}, [r10], %[src_stride]\n" + "vld1.16 {q4}, [r10], %[src_stride]\n" + "vld1.16 {q6}, [r10], %[src_stride]\n" + + "vtrn.16 d0, d4\n" + "vtrn.16 d1, d5\n" + "vtrn.16 d8, d12\n" + "vtrn.16 d9, d13\n" + + "vld1.16 {q8}, [r10], %[src_stride]\n" + "vld1.16 {q10}, [r10], %[src_stride]\n" + "vld1.16 {q12}, [r10], %[src_stride]\n" + "vld1.16 {q14}, [r10], %[src_stride]\n" + + "vtrn.32 d0, d8\n" + "vtrn.32 d4, d12\n" + "vtrn.32 d1, d9\n" + "vtrn.32 d5, d13\n" + + "vtrn.16 d16, d20\n" + "vtrn.16 d17, d21\n" + "vtrn.16 d24, d28\n" + "vtrn.16 d25, d29\n" + + "vld1.16 {q1}, [r10], %[src_stride]\n" + "vld1.16 {q3}, [r10], %[src_stride]\n" + "vld1.16 {q5}, [r10], %[src_stride]\n" + "vld1.16 {q7}, [r10], %[src_stride]\n" + + "vtrn.32 d16, d24\n" + "vtrn.32 d20, d28\n" + "vtrn.32 d17, d25\n" + "vtrn.32 d21, d29\n" + + "vswp d1, d16\n" + "vswp d5, d20\n" + "vswp d9, d24\n" + "vswp d13, d28\n" + + "vtrn.16 d2, d6\n" + "vtrn.16 d3, d7\n" + "vtrn.16 d10, d14\n" + "vtrn.16 d11, d15\n" + + "vtrn.32 d2, d10\n" + "vtrn.32 d6, d14\n" + "vtrn.32 d3, d11\n" + "vtrn.32 d7, d15\n" + + "vst1.16 {q0, d2}, [r12], %[dst_stride]\n" + "vst1.16 {q2, d6}, [r12], %[dst_stride]\n" + "vst1.16 {q4, d10}, [r12], %[dst_stride]\n" + "vst1.16 {q6, d14}, [r12], %[dst_stride]\n" + + "vswp d3, d18\n" + "vswp d7, d22\n" + "vswp d11, d26\n" + "vswp d15, d30\n" + + "vst1.16 {q8, d18}, [r12], %[dst_stride]\n" + "vst1.16 {q10, d22}, [r12], %[dst_stride]\n" + "vst1.16 {q12, d26}, [r12], %[dst_stride]\n" + "vst1.16 {q14, d30}, [r12], %[dst_stride]\n" + + : + : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ src_stride ] "r"(src_stride), [ dst_stride ] "r"(dst_stride) + : "r10", "r12", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15"); +} +#endif + +#ifdef ENABLE_ARM64 +inline void Transpose4x8ARM64Fp16(const float16_t *src_ptr, float16_t *dst_ptr, size_t src_stride, size_t dst_stride) { + dst_stride += dst_stride; + asm volatile( + "mov x10, %[src_ptr]\n" + "mov x11, %[dst_ptr]\n" + + "ld1 {v0.8h}, [x10], %[src_stride]\n" + "ld1 {v1.8h}, [x10], %[src_stride]\n" + "ld1 {v2.8h}, [x10], %[src_stride]\n" + "ld1 {v3.8h}, [x10], %[src_stride]\n" + "add x10, x11, %[dst_stride]\n" + + "zip1 v4.8h, v0.8h, v1.8h\n" + "zip1 v5.8h, v2.8h, v3.8h\n" + + "trn1 v6.4s, v4.4s, v5.4s\n" + "trn2 v7.4s, v4.4s, v5.4s\n" + + "trn1 v24.2d, v6.2d, v7.2d\n" + "trn2 v25.2d, v6.2d, v7.2d\n" + + "zip2 v8.8h, v0.8h, v1.8h\n" + "zip2 v9.8h, v2.8h, v3.8h\n" + + "trn1 v10.4s, v8.4s, v9.4s\n" + "trn2 v11.4s, v8.4s, v9.4s\n" + + "trn1 v26.2d, v10.2d, v11.2d\n" + "trn2 v27.2d, v10.2d, v11.2d\n" + + "st1 {v24.8h}, [x11], %[tow_dst_stride]\n" + "st1 {v25.8h}, [x10], %[tow_dst_stride]\n" + "st1 {v26.8h}, [x11], %[tow_dst_stride]\n" + "st1 {v27.8h}, [x10], %[tow_dst_stride]\n" + : + : [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ src_stride ] "r"(src_stride), + [ dst_stride ] "r"(dst_stride), [ tow_dst_stride ] "r"(2 * dst_stride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v24", "v25", "v26", + "v27"); +} + +inline void Transpose8x8ARM64Fp16(const float16_t *src_ptr, float16_t *dst_ptr, size_t src_stride, size_t dst_stride) { + asm volatile( + "mov x10, %[src_ptr]\n" + "mov x11, %[dst_ptr]\n" + + "ld1 {v0.8h}, [x10], %[src_stride]\n" + "ld1 {v1.8h}, [x10], %[src_stride]\n" + "ld1 {v2.8h}, [x10], %[src_stride]\n" + "ld1 {v3.8h}, [x10], %[src_stride]\n" + "ld1 {v4.8h}, [x10], %[src_stride]\n" + "ld1 {v5.8h}, [x10], %[src_stride]\n" + "ld1 {v6.8h}, [x10], %[src_stride]\n" + "ld1 {v7.8h}, [x10], %[src_stride]\n" + "add x10, x11, %[dst_stride]\n" + + "zip1 v16.8h, v0.8h, v1.8h\n" + "zip1 v17.8h, v2.8h, v3.8h\n" + "zip1 v18.8h, v4.8h, v5.8h\n" + "zip1 v19.8h, v6.8h, v7.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v24.2d, v20.2d, v22.2d\n" + "trn2 v26.2d, v20.2d, v22.2d\n" + "trn1 v25.2d, v21.2d, v23.2d\n" + "trn2 v27.2d, v21.2d, v23.2d\n" + + "zip2 v8.8h, v0.8h, v1.8h\n" + "zip2 v9.8h, v2.8h, v3.8h\n" + "zip2 v10.8h, v4.8h, v5.8h\n" + "zip2 v11.8h, v6.8h, v7.8h\n" + + "trn1 v12.4s, v8.4s, v9.4s\n" + "trn2 v13.4s, v8.4s, v9.4s\n" + "trn1 v14.4s, v10.4s, v11.4s\n" + "trn2 v15.4s, v10.4s, v11.4s\n" + + "trn1 v28.2d, v12.2d, v14.2d\n" + "trn2 v30.2d, v12.2d, v14.2d\n" + "trn1 v29.2d, v13.2d, v15.2d\n" + "trn2 v31.2d, v13.2d, v15.2d\n" + + "st1 {v24.8h}, [x11], %[tow_dst_stride]\n" + "st1 {v25.8h}, [x10], %[tow_dst_stride]\n" + "st1 {v26.8h}, [x11], %[tow_dst_stride]\n" + "st1 {v27.8h}, [x10], %[tow_dst_stride]\n" + "st1 {v28.8h}, [x11], %[tow_dst_stride]\n" + "st1 {v29.8h}, [x10], %[tow_dst_stride]\n" + "st1 {v30.8h}, [x11], %[tow_dst_stride]\n" + "st1 {v31.8h}, [x10], %[tow_dst_stride]\n" + : + : [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ src_stride ] "r"(src_stride), + [ dst_stride ] "r"(dst_stride), [ tow_dst_stride ] "r"(2 * dst_stride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31"); +} + +void Transpose12x8ARM64Fp16(const float16_t *src_ptr, float16_t *dst_ptr, size_t src_stride, size_t dst_stride) { +#ifdef ENABLE_DEBUG + for (int tr = 0; tr < C12NUM; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + dst_ptr[tc * C12NUM + tr] = src_ptr[tr * col + tc]; + } + } +#else + asm volatile( + "mov x10, %[src_ptr]\n" + "mov x11, %[dst_ptr]\n" + + "ld1 {v0.8h}, [x10], %[src_stride]\n" + "ld1 {v1.8h}, [x10], %[src_stride]\n" + "ld1 {v2.8h}, [x10], %[src_stride]\n" + "ld1 {v3.8h}, [x10], %[src_stride]\n" + "ld1 {v4.8h}, [x10], %[src_stride]\n" + "ld1 {v5.8h}, [x10], %[src_stride]\n" + "ld1 {v6.8h}, [x10], %[src_stride]\n" + "ld1 {v7.8h}, [x10], %[src_stride]\n" + + "zip1 v16.8h, v0.8h, v1.8h\n" + "zip1 v17.8h, v2.8h, v3.8h\n" + "zip1 v18.8h, v4.8h, v5.8h\n" + "zip1 v19.8h, v6.8h, v7.8h\n" + + "ld1 {v8.8h}, [x10], %[src_stride]\n" + "ld1 {v9.8h}, [x10], %[src_stride]\n" + "ld1 {v10.8h}, [x10], %[src_stride]\n" + "ld1 {v11.8h}, [x10], %[src_stride]\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v24.2d, v20.2d, v22.2d\n" + "trn2 v25.2d, v20.2d, v22.2d\n" + "trn1 v26.2d, v21.2d, v23.2d\n" + "trn2 v27.2d, v21.2d, v23.2d\n" + + "zip1 v16.8h, v8.8h, v9.8h\n" + "zip1 v17.8h, v10.8h, v11.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + + "trn1 v28.2d, v20.2d, v20.2d\n" + "trn2 v29.2d, v20.2d, v20.2d\n" + "trn1 v30.2d, v21.2d, v21.2d\n" + "trn2 v31.2d, v21.2d, v21.2d\n" + + "add x10, x11, #16\n" + "st1 {v24.8h}, [x11], %[dst_stride]\n" + "st1 {v28.4h}, [x10], %[dst_stride]\n" + "st1 {v26.8h}, [x11], %[dst_stride]\n" + "st1 {v30.4h}, [x10], %[dst_stride]\n" + "st1 {v25.8h}, [x11], %[dst_stride]\n" + "st1 {v29.4h}, [x10], %[dst_stride]\n" + "st1 {v27.8h}, [x11], %[dst_stride]\n" + "st1 {v31.4h}, [x10], %[dst_stride]\n" + + "zip2 v16.8h, v0.8h, v1.8h\n" + "zip2 v17.8h, v2.8h, v3.8h\n" + "zip2 v18.8h, v4.8h, v5.8h\n" + "zip2 v19.8h, v6.8h, v7.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v24.2d, v20.2d, v22.2d\n" + "trn2 v25.2d, v20.2d, v22.2d\n" + "trn1 v26.2d, v21.2d, v23.2d\n" + "trn2 v27.2d, v21.2d, v23.2d\n" + + "zip2 v16.8h, v8.8h, v9.8h\n" + "zip2 v17.8h, v10.8h, v11.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + + "trn1 v28.2d, v20.2d, v20.2d\n" + "trn2 v29.2d, v20.2d, v20.2d\n" + "trn1 v30.2d, v21.2d, v21.2d\n" + "trn2 v31.2d, v21.2d, v21.2d\n" + + "st1 {v24.8h}, [x11], %[dst_stride]\n" + "st1 {v28.4h}, [x10], %[dst_stride]\n" + "st1 {v26.8h}, [x11], %[dst_stride]\n" + "st1 {v30.4h}, [x10], %[dst_stride]\n" + "st1 {v25.8h}, [x11], %[dst_stride]\n" + "st1 {v29.4h}, [x10], %[dst_stride]\n" + "st1 {v27.8h}, [x11], %[dst_stride]\n" + "st1 {v31.4h}, [x10], %[dst_stride]\n" + : + : [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ src_stride ] "r"(src_stride), [ dst_stride ] "r"(dst_stride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v16", "v17", "v18", + "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +#endif +} + +inline void Transpose16x8ARM64Fp16(const float16_t *src_ptr, float16_t *dst_ptr, size_t src_stride, size_t dst_stride) { + asm volatile( + "mov x10, %[src_ptr]\n" + "mov x11, %[dst_ptr]\n" + + "ld1 {v0.8h}, [x10], %[src_stride]\n" + "ld1 {v1.8h}, [x10], %[src_stride]\n" + "ld1 {v2.8h}, [x10], %[src_stride]\n" + "ld1 {v3.8h}, [x10], %[src_stride]\n" + "ld1 {v4.8h}, [x10], %[src_stride]\n" + "ld1 {v5.8h}, [x10], %[src_stride]\n" + "ld1 {v6.8h}, [x10], %[src_stride]\n" + "ld1 {v7.8h}, [x10], %[src_stride]\n" + + "zip1 v16.8h, v0.8h, v1.8h\n" + "zip1 v17.8h, v2.8h, v3.8h\n" + "zip1 v18.8h, v4.8h, v5.8h\n" + "zip1 v19.8h, v6.8h, v7.8h\n" + + "ld1 {v8.8h}, [x10], %[src_stride]\n" + "ld1 {v9.8h}, [x10], %[src_stride]\n" + "ld1 {v10.8h}, [x10], %[src_stride]\n" + "ld1 {v11.8h}, [x10], %[src_stride]\n" + "ld1 {v12.8h}, [x10], %[src_stride]\n" + "ld1 {v13.8h}, [x10], %[src_stride]\n" + "ld1 {v14.8h}, [x10], %[src_stride]\n" + "ld1 {v15.8h}, [x10], %[src_stride]\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v24.2d, v20.2d, v22.2d\n" + "trn2 v25.2d, v20.2d, v22.2d\n" + "trn1 v26.2d, v21.2d, v23.2d\n" + "trn2 v27.2d, v21.2d, v23.2d\n" + + "zip1 v16.8h, v8.8h, v9.8h\n" + "zip1 v17.8h, v10.8h, v11.8h\n" + "zip1 v18.8h, v12.8h, v13.8h\n" + "zip1 v19.8h, v14.8h, v15.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v28.2d, v20.2d, v22.2d\n" + "trn2 v29.2d, v20.2d, v22.2d\n" + "trn1 v30.2d, v21.2d, v23.2d\n" + "trn2 v31.2d, v21.2d, v23.2d\n" + + "add x10, x11, #16\n" + "st1 {v24.8h}, [x11], %[dst_stride]\n" + "st1 {v28.8h}, [x10], %[dst_stride]\n" + "st1 {v26.8h}, [x11], %[dst_stride]\n" + "st1 {v30.8h}, [x10], %[dst_stride]\n" + "st1 {v25.8h}, [x11], %[dst_stride]\n" + "st1 {v29.8h}, [x10], %[dst_stride]\n" + "st1 {v27.8h}, [x11], %[dst_stride]\n" + "st1 {v31.8h}, [x10], %[dst_stride]\n" + + "zip2 v16.8h, v0.8h, v1.8h\n" + "zip2 v17.8h, v2.8h, v3.8h\n" + "zip2 v18.8h, v4.8h, v5.8h\n" + "zip2 v19.8h, v6.8h, v7.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v24.2d, v20.2d, v22.2d\n" + "trn2 v25.2d, v20.2d, v22.2d\n" + "trn1 v26.2d, v21.2d, v23.2d\n" + "trn2 v27.2d, v21.2d, v23.2d\n" + + "zip2 v16.8h, v8.8h, v9.8h\n" + "zip2 v17.8h, v10.8h, v11.8h\n" + "zip2 v18.8h, v12.8h, v13.8h\n" + "zip2 v19.8h, v14.8h, v15.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v28.2d, v20.2d, v22.2d\n" + "trn2 v29.2d, v20.2d, v22.2d\n" + "trn1 v30.2d, v21.2d, v23.2d\n" + "trn2 v31.2d, v21.2d, v23.2d\n" + + "st1 {v24.8h}, [x11], %[dst_stride]\n" + "st1 {v28.8h}, [x10], %[dst_stride]\n" + "st1 {v26.8h}, [x11], %[dst_stride]\n" + "st1 {v30.8h}, [x10], %[dst_stride]\n" + "st1 {v25.8h}, [x11], %[dst_stride]\n" + "st1 {v29.8h}, [x10], %[dst_stride]\n" + "st1 {v27.8h}, [x11], %[dst_stride]\n" + "st1 {v31.8h}, [x10], %[dst_stride]\n" + : + : [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ src_stride ] "r"(src_stride), [ dst_stride ] "r"(dst_stride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31"); +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/pack_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/pack_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..b8e11d292bb553766e77532e32efc659106fe2cd --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/pack_fp16.h @@ -0,0 +1,93 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_PACK_FP16_H_ +#define NNACL_FP16_PACK_FP16_H_ + +#include "nnacl/op_base.h" +#include "nnacl/intrinsics/ms_simd_instructions_fp16.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void PackHWCToWHCFp16(const float16_t *src, float16_t *dst, int height, int width, int channel); + +void PackNHWCToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNCHWToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNCHWToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel, int task_id, int thread_count); + +void PackNHWCToNCHWFp16(const void *src, void *dst, int batch, int plane, int channel, int task_id, int thread_count); + +void PackNHWCToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNHWCToNHWC8Fp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNHWC4ToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNCHWToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNC4HW4ToNHWC4Fp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNC4HW4ToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNC4HW4ToNCHWFp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNCHWFp32ToNC8HW8Fp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNCHWFp16ToNC8HW8Fp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNC8HW8ToNCHWFp16(const void *src, void *dst, int batch, int plane, int channel); + +void PackNC8HW8ToNHWCFp16(const float16_t *src, float16_t *dst, int batch, int plane, int channel); + +void PackNHWCToNC8HW8NotAlignedFp16(const float16_t *src, float16_t *dst, const int batch, const int plane, + const int channel); + +void PackNHWCFp32ToNHWC8Fp16(const float *src, float16_t *dst, int batch, int plane, int channel); + +void PackNHWCFp32ToC8HWN8Fp16(const float *src, float16_t *dst, int batch, int plane, int channel); + +void PackNHWCFp16ToC8HWN8Fp16(const float16_t *src, float16_t *dst, int batch, int plane, int channel); + +void PackNHWC8Fp16ToNHWCFp32(const float16_t *src, float *dst, int batch, int plane, int channel); + +void PackNHWC8ToNHWCFp16(const float16_t *src, float16_t *dst, int batch, int plane, int channel); + +#ifdef ENABLE_ARM82_A32 +void Transpose8x8A32Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride); + +void Transpose12x8A32Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride); +#endif + +#ifdef ENABLE_ARM64 +void Transpose4x8ARM64Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride); +void Transpose8x8ARM64Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride); +void Transpose12x8ARM64Fp16(const float16_t *src_ptr, float16_t *dst_ptr, size_t src_stride, size_t dst_stride); +void Transpose16x8ARM64Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride); +#endif + +#ifdef ENABLE_ARM +void PackWeightConvDw3x3Fp16(const void *src, void *dst, int channel); +#endif + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_PACK_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/pad_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/pad_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..37f97f7a50b3d97e407cd34efa38352b5a3541be --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/pad_fp16.c @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp16/pad_fp16.h" +#include "nnacl/common_func.h" + +void PadFp16(const float16_t *input_data, float16_t *output_data, const int *input_shape, const int *output_shape, + const int *paddings, int tid, int thread_num) { + int in[DEFAULT_PAD_NDIMS], out[DEFAULT_PAD_NDIMS]; + for (in[0] = 0; in[0] < input_shape[0]; in[0]++) { + out[0] = in[0] + paddings[0]; + for (in[1] = tid; in[1] < input_shape[1]; in[1] += thread_num) { + out[1] = in[1] + paddings[2]; + for (in[2] = 0; in[2] < input_shape[2]; in[2]++) { + out[2] = in[2] + paddings[4]; + for (in[3] = 0; in[3] < input_shape[3]; in[3]++) { + out[3] = in[3] + paddings[6]; + for (in[4] = 0; in[4] < input_shape[4]; in[4]++) { + out[4] = in[4] + paddings[8]; + float16_t *dst = output_data + Offset6d(output_shape, out) + paddings[10]; + const float16_t *src = input_data + Offset6d(input_shape, in); + memcpy(dst, src, input_shape[5] * sizeof(float16_t)); + } + } + } + } + } +} + +void MirrorPadFp16(const float16_t *input_data, float16_t *output_data, const int *input_shape, const int *in_strides, + const int *out_strides, const int *padding, int mirror_offset, int begin, int end) { + for (int i = begin; i < end; ++i) { + output_data[i] = input_data[GetInputFlattenIndex(i, input_shape, in_strides, out_strides, padding, mirror_offset)]; + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/pad_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/pad_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..70251aeb4af0e79e196ee6b597daf79f82a84993 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/pad_fp16.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_PAD_FP16_H_ +#define NNACL_FP16_PAD_FP16_H_ + +#include "nnacl/fp32/pad_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif +void PadFp16(const float16_t *input_data, float16_t *output_data, const int *input_shape, const int *output_shape, + const int *paddings, int tid, int thread_num); +void MirrorPadFp16(const float16_t *input_data, float16_t *output_data, const int *input_shape, const int *in_strides, + const int *out_strides, const int *padding, int mirror_offset, int begin, int end); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_PAD_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/pooling_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/pooling_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..5ef8b3596c4c5612e917e04c2930f74eed2c3365 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/pooling_fp16.c @@ -0,0 +1,305 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp16/pooling_fp16.h" +#include +#include "nnacl/errorcode.h" + +int AvgPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num) { + float16_t min = (float16_t)pooling_args->minf; + float16_t max = (float16_t)pooling_args->maxf; + + int win_w = pooling_args->window_w_; + int win_h = pooling_args->window_h_; + int channel = pooling_args->input_channel_; + int c8 = channel / C8NUM; + int in_w = pooling_args->input_w_; + int in_h = pooling_args->input_h_; + int output_w = pooling_args->output_w_; + int output_h = pooling_args->output_h_; + int out_plane = output_w * output_h; + int out_tile_count = UP_DIV(out_plane, TILE_NUM); + +#ifdef ENABLE_NEON + MS_FLOAT16X8 min_value = MS_MOVQ_F16(min); + MS_FLOAT16X8 max_value = MS_MOVQ_F16(max); +#endif + + NNACL_CHECK_ZERO_RETURN_ERR(output_w); + for (int batch = 0; batch < pooling_args->output_batch_; batch++) { + const float16_t *src_b_ptr = input_ptr + batch * in_h * in_w * channel; + float16_t *dst_b_ptr = output_ptr + batch * output_h * output_w * channel; + for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { + int cal_start_index = thread_id * TILE_NUM; + int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index); + for (int i = 0; i < real_cal_num; i++) { + int index = cal_start_index + i; + int out_w_index = index % output_w; + int out_h_index = index / output_w; + int in_w_index = out_w_index * pooling_param->stride_w_ - pooling_param->pad_l_; + int in_h_index = out_h_index * pooling_param->stride_h_ - pooling_param->pad_u_; + + const float16_t *src_plane_ptr = src_b_ptr; + float16_t *dst_plane_ptr = dst_b_ptr + index * channel; + + int real_win_h_start = MSMAX(0, -in_h_index); + int real_win_h_end = MSMIN(win_h, in_h - in_h_index); + int real_win_w_start = MSMAX(0, -in_w_index); + int real_win_w_end = MSMIN(win_w, in_w - in_w_index); + + for (int ci = 0; ci < c8; ci++) { + const float16_t *src_c_ptr = src_plane_ptr + ci * C8NUM; + float16_t *dst_c_ptr = dst_plane_ptr + ci * C8NUM; +#ifdef ENABLE_NEON + MS_FLOAT16X8 tmp_avg = MS_MOVQ_F16(0); +#else + float16_t tmp_avg0 = 0; + float16_t tmp_avg1 = 0; + float16_t tmp_avg2 = 0; + float16_t tmp_avg3 = 0; + float16_t tmp_avg4 = 0; + float16_t tmp_avg5 = 0; + float16_t tmp_avg6 = 0; + float16_t tmp_avg7 = 0; +#endif + int real_count = 0; + for (int h = real_win_h_start; h < real_win_h_end; h++) { + for (int w = real_win_w_start; w < real_win_w_end; w++) { + const float16_t *src_win_ptr = src_c_ptr + ((in_h_index + h) * in_w + in_w_index + w) * channel; +#ifdef ENABLE_NEON + tmp_avg = MS_ADDQ_F16(tmp_avg, MS_LDQ_F16(src_win_ptr)); +#else + tmp_avg0 += src_win_ptr[0]; + tmp_avg1 += src_win_ptr[1]; + tmp_avg2 += src_win_ptr[2]; + tmp_avg3 += src_win_ptr[3]; + tmp_avg4 += src_win_ptr[4]; + tmp_avg5 += src_win_ptr[5]; + tmp_avg6 += src_win_ptr[6]; + tmp_avg7 += src_win_ptr[7]; +#endif + ++real_count; + } + } + if (real_count == 0) { + return NNACL_ERR; + } +#ifdef ENABLE_NEON + tmp_avg = MS_DIVQ_F16(tmp_avg, MS_MOVQ_F16((float16_t)real_count)); + MS_STQ_F16(dst_c_ptr, MS_MINQ_F16(MS_MAXQ_F16(tmp_avg, min_value), max_value)); +#else + dst_c_ptr[0] = MSMIN(MSMAX(tmp_avg0 / (float16_t)real_count, min), max); + dst_c_ptr[1] = MSMIN(MSMAX(tmp_avg1 / (float16_t)real_count, min), max); + dst_c_ptr[2] = MSMIN(MSMAX(tmp_avg2 / (float16_t)real_count, min), max); + dst_c_ptr[3] = MSMIN(MSMAX(tmp_avg3 / (float16_t)real_count, min), max); + dst_c_ptr[4] = MSMIN(MSMAX(tmp_avg4 / (float16_t)real_count, min), max); + dst_c_ptr[5] = MSMIN(MSMAX(tmp_avg5 / (float16_t)real_count, min), max); + dst_c_ptr[6] = MSMIN(MSMAX(tmp_avg6 / (float16_t)real_count, min), max); + dst_c_ptr[7] = MSMIN(MSMAX(tmp_avg7 / (float16_t)real_count, min), max); +#endif + } // c8 loop + int channel_s = c8 * C8NUM; + for (int ci = channel_s; ci < channel; ci++) { + const float16_t *src_c_ptr = src_plane_ptr + ci; + float16_t *dst_c_ptr = dst_plane_ptr + ci; + float16_t tmp_avg = 0; + int real_count = 0; + for (int h = real_win_h_start; h < real_win_h_end; h++) { + for (int w = real_win_w_start; w < real_win_w_end; w++) { + const float16_t *src_win_ptr = src_c_ptr + ((in_h_index + h) * in_w + in_w_index + w) * channel; + tmp_avg += src_win_ptr[0]; + ++real_count; + } + } + if (real_count == 0) { + return NNACL_ERR; + } + tmp_avg = tmp_avg / (float16_t)real_count; + tmp_avg = fmax(tmp_avg, min); + tmp_avg = fmin(tmp_avg, max); + dst_c_ptr[0] = tmp_avg; + } // channel_res loop + } // real_cal_num loop + } // out_plane loop + } // out_batch loop + return NNACL_OK; +} + +void MaxPoolingC8Fp16(const float16_t *input_ptr, float16_t *output_ptr, const PoolingComputeParam *pooling_args, + float16_t min, float16_t max, int in_batch_offset, int out_plane_offset, int real_win_h_start, + int real_win_h_end, int real_win_w_start, int real_win_w_end, int in_h_index, int in_w_index) { + int channel = pooling_args->input_channel_; + int in_w = pooling_args->input_w_; + int c8 = channel / C8NUM; +#ifdef ENABLE_NEON + float16x8_t min_value = vdupq_n_f16(min); + float16x8_t max_value = vdupq_n_f16(max); +#endif + for (int j = 0; j < c8; j++) { + int in_channel_offset = in_batch_offset + j * C8NUM; + int out_channel_offset = out_plane_offset + j * C8NUM; +#ifdef ENABLE_NEON + float16x8_t tmp_max = vdupq_n_f16(min); +#else + float16_t tmp_max[8] = {min, min, min, min, min, min, min, min}; +#endif + for (int h = real_win_h_start; h < real_win_h_end; h++) { + for (int w = real_win_w_start; w < real_win_w_end; w++) { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; +#ifdef ENABLE_NEON + tmp_max = vmaxq_f16(tmp_max, vld1q_f16(input_ptr + in_offset)); +#else + for (int k = 0; k < C8NUM; k++) { + tmp_max[k] = fmax(tmp_max[k], *(input_ptr + in_offset + k)); + } +#endif + } // win_w loop + } // win_h loop +#ifdef ENABLE_NEON + tmp_max = vmaxq_f16(tmp_max, min_value); + tmp_max = vminq_f16(tmp_max, max_value); + vst1q_f16(output_ptr + out_channel_offset, tmp_max); +#else + for (int l = 0; l < C8NUM; ++l) { + tmp_max[l] = fmax(tmp_max[l], min); + tmp_max[l] = fmin(tmp_max[l], max); + *(output_ptr + out_channel_offset + l) = tmp_max[l]; + } +#endif + } // c8 loop +} + +void MaxPoolingC4Fp16(const float16_t *input_ptr, float16_t *output_ptr, const PoolingComputeParam *pooling_args, + float16_t min, float16_t max, int in_batch_offset, int out_plane_offset, int real_win_h_start, + int real_win_h_end, int real_win_w_start, int real_win_w_end, int in_h_index, int in_w_index) { + int channel = pooling_args->input_channel_; + int in_w = pooling_args->input_w_; + int c8 = channel / C8NUM; + int c8_res = channel % C8NUM; + int c4 = c8_res / C4NUM; +#ifdef ENABLE_NEON + float16x4_t min_value2 = vdup_n_f16(min); + float16x4_t max_value2 = vdup_n_f16(max); +#endif + int c4_offset = c8 * C8NUM; + for (int j = 0; j < c4; j++) { + int in_channel_offset = in_batch_offset + c4_offset + j * C4NUM; + int out_channel_offset = out_plane_offset + c4_offset + j * C4NUM; +#ifdef ENABLE_NEON + float16x4_t tmp_max = vdup_n_f16(min); +#else + float16_t tmp_max[4] = {min, min, min, min}; +#endif + for (int h = real_win_h_start; h < real_win_h_end; h++) { + for (int w = real_win_w_start; w < real_win_w_end; w++) { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; +#ifdef ENABLE_NEON + tmp_max = vmax_f16(tmp_max, vld1_f16(input_ptr + in_offset)); +#else + for (int k = 0; k < C4NUM; k++) { + tmp_max[k] = fmax(tmp_max[k], *(input_ptr + in_offset + k)); + } +#endif + } // win_w loop + } // win_h loop +#ifdef ENABLE_NEON + tmp_max = vmax_f16(tmp_max, min_value2); + tmp_max = vmin_f16(tmp_max, max_value2); + vst1_f16(output_ptr + out_channel_offset, tmp_max); +#else + for (int l = 0; l < C4NUM; ++l) { + tmp_max[l] = fmax(tmp_max[l], min); + tmp_max[l] = fmin(tmp_max[l], max); + output_ptr[out_channel_offset + l] = tmp_max[l]; + } +#endif + } // c4 loop +} +void MaxPoolingC1Fp16(const float16_t *input_ptr, float16_t *output_ptr, const PoolingComputeParam *pooling_args, + float16_t min, float16_t max, int in_batch_offset, int out_plane_offset, int real_win_h_start, + int real_win_h_end, int real_win_w_start, int real_win_w_end, int in_h_index, int in_w_index) { + int channel = pooling_args->input_channel_; + int in_w = pooling_args->input_w_; + int c8 = channel / C8NUM; + int c8_res = channel % C8NUM; + int c4 = c8_res / C4NUM; + int channel_s = c8 * C8NUM + c4 * C4NUM; + for (int k = channel_s; k < channel; k++) { + int in_channel_offset = in_batch_offset + k; + int out_channel_offset = out_plane_offset + k; + float16_t tmp_max = -FLT_MAX; + for (int h = real_win_h_start; h < real_win_h_end; h++) { + for (int w = real_win_w_start; w < real_win_w_end; w++) { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; + tmp_max = fmax(tmp_max, *(input_ptr + in_offset)); + } // win_w loop + } // win_h loop + tmp_max = fmax(tmp_max, min); + tmp_max = fmin(tmp_max, max); + output_ptr[out_channel_offset] = tmp_max; + } // channel_res loop +} + +void MaxPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num) { + float16_t min = (float16_t)pooling_args->minf; + float16_t max = (float16_t)pooling_args->maxf; + + int stride_w = pooling_param->stride_w_; + int stride_h = pooling_param->stride_h_; + int pad_w = pooling_param->pad_l_; + int pad_h = pooling_param->pad_u_; + int win_w = pooling_args->window_w_; + int win_h = pooling_args->window_h_; + int channel = pooling_args->input_channel_; + int in_w = pooling_args->input_w_; + int in_h = pooling_args->input_h_; + int output_w = pooling_args->output_w_; + int output_h = pooling_args->output_h_; + int output_batch = pooling_args->output_batch_; + int out_plane = output_w * output_h; + int out_tile_count = UP_DIV(out_plane, TILE_NUM); + + // input channel is equal to output channel + NNACL_CHECK_ZERO_RETURN(output_w); + for (int batch = 0; batch < output_batch; batch++) { + int in_batch_offset = batch * in_h * in_w * channel; + int out_batch_offset = batch * output_h * output_w * channel; + for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { + int cal_start_index = thread_id * TILE_NUM; + int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index); + for (int i = 0; i < real_cal_num; i++) { + int index = cal_start_index + i; + int out_w_index = index % output_w; + int out_h_index = index / output_w; + int in_w_index = out_w_index * stride_w - pad_w; + int in_h_index = out_h_index * stride_h - pad_h; + int out_plane_offset = out_batch_offset + index * channel; + int real_win_h_start = MSMAX(0, -in_h_index); + int real_win_h_end = MSMIN(win_h, in_h - in_h_index); + int real_win_w_start = MSMAX(0, -in_w_index); + int real_win_w_end = MSMIN(win_w, in_w - in_w_index); + MaxPoolingC8Fp16(input_ptr, output_ptr, pooling_args, min, max, in_batch_offset, out_plane_offset, + real_win_h_start, real_win_h_end, real_win_w_start, real_win_w_end, in_h_index, in_w_index); + MaxPoolingC4Fp16(input_ptr, output_ptr, pooling_args, min, max, in_batch_offset, out_plane_offset, + real_win_h_start, real_win_h_end, real_win_w_start, real_win_w_end, in_h_index, in_w_index); + MaxPoolingC1Fp16(input_ptr, output_ptr, pooling_args, min, max, in_batch_offset, out_plane_offset, + real_win_h_start, real_win_h_end, real_win_w_start, real_win_w_end, in_h_index, in_w_index); + } // real_cal_num loop + } // out_plane loop + } // out_batch loop +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/pooling_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/pooling_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..49ea015d681059826cd0eb9324b1641e93917d9d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/pooling_fp16.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_POOLING_FP16_H_ +#define NNACL_FP16_POOLING_FP16_H_ + +#include +#include "nnacl/pooling_parameter.h" +#include "nnacl/intrinsics/ms_simd_instructions_fp16.h" +#include "nnacl/kernel/pooling.h" + +#ifdef __cplusplus +extern "C" { +#endif +int AvgPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num); + +void MaxPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num); +#ifdef __cplusplus +} +#endif +#endif // NNACL_FP16_POOLING_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/power_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/power_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..c54cac64b7ddc348b657ef7fb4486392f158e42f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/power_fp16.c @@ -0,0 +1,117 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp16/power_fp16.h" +#include "nnacl/errorcode.h" + +#if defined(ENABLE_NEON) +float16x8_t OptimizedPowerSimdFp16(float16x8_t x, const void *exponent) { + int tmp = (int)(*(float16_t *)exponent); + int exp = abs(tmp); + float16x8_t result = vmovq_n_f16(1.0f); + while (exp) { + if (exp % 2) { + result *= x; + } + x *= x; + exp = exp / 2; + } + if (tmp >= 0) { + return result; + } + return 1 / result; +} +#endif + +float16_t OptimizedPowerScalarFp16(float16_t x, const void *exponent) { + int tmp = *(float16_t *)exponent; + int exp = abs(tmp); + float16_t result = 1; + while (exp) { + if (exp % 2) { + result *= x; + } + x *= x; + exp = exp / 2; + } + return tmp >= 0 ? result : 1 / result; +} + +void PowerBroadCastFp16(const float16_t *input, const float16_t *exponent, float16_t *output, int len, float scale, + float shift) { + PowerScalarFunFp16 PowerScalarFunFp16_ = NULL; +#if defined(ENABLE_NEON) + PowerSimdFunFp16 PowerSimdFunFp16_ = NULL; +#endif + + if (CheckIntegerFp16(*exponent)) { +#if defined(ENABLE_NEON) + PowerSimdFunFp16_ = OptimizedPowerSimdFp16; +#endif + PowerScalarFunFp16_ = OptimizedPowerScalarFp16; + } else { +#if defined(ENABLE_NEON) + PowerSimdFunFp16_ = StdPowerSimdFp16; +#endif + PowerScalarFunFp16_ = StdPowerScalarFp16; + } + int i = 0; +#ifdef ENABLE_NEON + int len_c8 = DOWN_ROUND(len, C8NUM); + float16x8_t scale_8 = vmovq_n_f16(scale); + float16x8_t shift_8 = vmovq_n_f16(shift); + for (; i < len_c8; i += C8NUM) { + float16x8_t result = PowerSimdFunFp16_(scale_8 * vld1q_f16(input + i) + shift_8, exponent); + vst1q_f16(output + i, result); + } +#endif + for (; i < len; ++i) { + output[i] = PowerScalarFunFp16_(scale * input[i] + shift, exponent); + } +} + +void PowerSingleFp16(const float16_t *input, const float16_t *exponent, float16_t *output, int len, float scale, + float shift) { + int i = 0; + PowerScalarFunFp16 PowerScalarFunFp16_ = NULL; +#ifdef ENABLE_NEON + int len_c8 = DOWN_ROUND(len, C8NUM); + float16x8_t scale_8 = vmovq_n_f16(scale); + float16x8_t shift_8 = vmovq_n_f16(shift); + for (; i < len_c8; i += C8NUM) { + float16x8_t tmp_8 = scale_8 * vld1q_f16(input + i) + shift_8; + for (int j = 0; j < 8; ++j) { + PowerScalarFunFp16_ = CheckIntegerFp16(exponent[i + j]) ? OptimizedPowerScalarFp16 : StdPowerScalarFp16; + output[i + j] = PowerScalarFunFp16_(tmp_8[j], exponent + i + j); + } + } +#endif + for (; i < len; ++i) { + PowerScalarFunFp16_ = CheckIntegerFp16(exponent[i]) ? OptimizedPowerScalarFp16 : StdPowerScalarFp16; + output[i] = PowerScalarFunFp16_(scale * input[i] + shift, exponent + i); + } +} + +int PowerFp16(const float16_t *input, const float16_t *exponent, float16_t *output, int len, float scale, float shift, + bool broadcast) { + if (input == NULL || exponent == NULL || output == NULL) { + return NNACL_NULL_PTR; + } + PowerFunFp16 PowerFunFp16_ = NULL; + PowerFunFp16_ = broadcast ? PowerBroadCastFp16 : PowerSingleFp16; + PowerFunFp16_(input, exponent, output, len, scale, shift); + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/power_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/power_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..9ee1c4d552138b7ae13e87352a9f026eb06f562c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/power_fp16.h @@ -0,0 +1,64 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_POWER_FP16_H_ +#define NNACL_FP16_POWER_FP16_H_ + +#include +#include "nnacl/op_base.h" +#include "nnacl/intrinsics/ms_simd_instructions_fp16.h" +#include "nnacl/pow_parameter.h" + +#if defined(ENABLE_NEON) +typedef float16x8_t (*PowerSimdFunFp16)(float16x8_t x, const void *exponent); +#endif +typedef float16_t (*PowerScalarFunFp16)(float16_t x, const void *exponent); +typedef void (*PowerFunFp16)(const float16_t *, const float16_t *, float16_t *, int, float, float); + +#ifdef __cplusplus +extern "C" { +#endif +static inline bool CheckIntegerFp16(float16_t f) { return floorf(f) == f; } + +static inline float16_t StdPowerScalarFp16(float16_t x, const void *exponent) { + return powf(x, *(float16_t *)exponent); +} + +#if defined(ENABLE_NEON) +static inline float16x8_t StdPowerSimdFp16(float16x8_t x, const void *exponent) { + float16x8_t result; + result[0] = powf(x[0], *(float16_t *)exponent); + result[1] = powf(x[1], *(float16_t *)exponent); + result[2] = powf(x[2], *(float16_t *)exponent); + result[3] = powf(x[3], *(float16_t *)exponent); + result[4] = powf(x[4], *(float16_t *)exponent); + result[5] = powf(x[5], *(float16_t *)exponent); + result[6] = powf(x[6], *(float16_t *)exponent); + result[7] = powf(x[7], *(float16_t *)exponent); + return result; +} +#endif +int PowerFp16(const float16_t *input, const float16_t *exponent, float16_t *output, int len, float scale, float shift, + bool broadcast); +void PowerSingleFp16(const float16_t *input, const float16_t *exponent, float16_t *output, int len, float scale, + float shift); +void PowerBroadCastFp16(const float16_t *input, const float16_t *exponent, float16_t *output, int len, float scale, + float shift); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_POWER_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/prelu_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/prelu_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..2064011e2643456ad445167e49d2f445dfafef73 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/prelu_fp16.c @@ -0,0 +1,146 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * +// * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp16/prelu_fp16.h" + +#ifdef ENABLE_ARM64 +static inline void PReluFp164x32(const float16_t *in, float16_t *out, const float16_t *cur_slope, size_t step) { + asm volatile( + "mov x10, %[in]\n" + "mov x11, %[out]\n" + "mov x12, %[cur_slope]\n" + "ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x12]\n" + "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], %[step]\n" + "fmul v16.8h, v0.8h, v4.8h\n" + "fmul v17.8h, v1.8h, v5.8h\n" + "fmul v18.8h, v2.8h, v6.8h\n" + "fmul v19.8h, v3.8h, v7.8h\n" + "fcmgt v20.8h, v0.8h, #0\n" + "fcmgt v21.8h, v1.8h, #0\n" + "fcmgt v22.8h, v2.8h, #0\n" + "fcmgt v23.8h, v3.8h, #0\n" + "ld1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x10], %[step]\n" + "bif v0.16b, v16.16b, v20.16b\n" + "bif v1.16b, v17.16b, v21.16b\n" + "bif v2.16b, v18.16b, v22.16b\n" + "bif v3.16b, v19.16b, v23.16b\n" + "fmul v8.8h, v24.8h, v4.8h\n" + "fmul v9.8h, v25.8h, v5.8h\n" + "fmul v10.8h, v26.8h, v6.8h\n" + "fmul v11.8h, v27.8h, v7.8h\n" + "st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x11], %[step]\n" + "fcmgt v12.8h, v24.8h, #0\n" + "fcmgt v13.8h, v25.8h, #0\n" + "fcmgt v14.8h, v26.8h, #0\n" + "fcmgt v15.8h, v27.8h, #0\n" + "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], %[step]\n" + "bif v24.16b, v8.16b, v12.16b\n" + "bif v25.16b, v9.16b, v13.16b\n" + "bif v26.16b, v10.16b, v14.16b\n" + "bif v27.16b, v11.16b, v15.16b\n" + "fmul v16.8h, v0.8h, v4.8h\n" + "fmul v17.8h, v1.8h, v5.8h\n" + "fmul v18.8h, v2.8h, v6.8h\n" + "fmul v19.8h, v3.8h, v7.8h\n" + "st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x11], %[step]\n" + "fcmgt v20.8h, v0.8h, #0\n" + "fcmgt v21.8h, v1.8h, #0\n" + "fcmgt v22.8h, v2.8h, #0\n" + "fcmgt v23.8h, v3.8h, #0\n" + "ld1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x10]\n" + "bif v0.16b, v16.16b, v20.16b\n" + "bif v1.16b, v17.16b, v21.16b\n" + "bif v2.16b, v18.16b, v22.16b\n" + "bif v3.16b, v19.16b, v23.16b\n" + "fmul v8.8h, v24.8h, v4.8h\n" + "fmul v9.8h, v25.8h, v5.8h\n" + "fmul v10.8h, v26.8h, v6.8h\n" + "fmul v11.8h, v27.8h, v7.8h\n" + "fcmgt v12.8h, v24.8h, #0\n" + "fcmgt v13.8h, v25.8h, #0\n" + "fcmgt v14.8h, v26.8h, #0\n" + "fcmgt v15.8h, v27.8h, #0\n" + "st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x11], %[step]\n" + "bif v24.16b, v8.16b, v12.16b\n" + "bif v25.16b, v9.16b, v13.16b\n" + "bif v26.16b, v10.16b, v14.16b\n" + "bif v27.16b, v11.16b, v15.16b\n" + "st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x11]\n" + : + : [ in ] "r"(in), [ out ] "r"(out), [ cur_slope ] "r"(cur_slope), [ step ] "r"(step) + : "x10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", + "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27"); +} +#endif + +void PReluFp16(const float16_t *input, float16_t *output, const float16_t *slope, int start, int end, int channel) { + int i = start; +#ifdef ENABLE_ARM64 + for (; i <= end - C4NUM; i += C4NUM) { + const float16_t *cur_in = input + i * channel; + float16_t *cur_out = output + i * channel; + int j = 0; + for (; j <= channel - C32NUM; j += C32NUM) { + const float16_t *in = cur_in + j; + float16_t *out = cur_out + j; + const float16_t *cur_slope = slope + j; + size_t step = channel * sizeof(float16_t); + PReluFp164x32(in, out, cur_slope, step); + } + for (; j < channel; j++) { + cur_out[j] = (cur_in[j] > 0) ? cur_in[j] : (cur_in[j] * slope[j]); + cur_out[j + channel] = (cur_in[j + channel] > 0) ? cur_in[j + channel] : cur_in[j + channel] * slope[j]; + cur_out[j + 2 * channel] = + (cur_in[j + 2 * channel] > 0) ? cur_in[j + 2 * channel] : (cur_in[j + 2 * channel] * slope[j]); + cur_out[j + 3 * channel] = + (cur_in[j + 3 * channel] > 0) ? cur_in[j + 3 * channel] : (cur_in[j + 3 * channel] * slope[j]); + } + } +#endif + for (; i < end; i++) { + const float16_t *cur_in = input + i * channel; + float16_t *cur_out = output + i * channel; + int j = 0; +#ifdef ENABLE_NEON + for (; j <= channel - C8NUM; j += C8NUM) { + float16x8_t in = vld1q_f16(cur_in + j); + float16x8_t s = vld1q_f16(slope + j); + float16x8_t mul = vmulq_f16(in, s); + uint16x8_t mask = vcleq_f16(in, vmovq_n_f16(0.0f)); + vst1q_f16(cur_out + j, vbslq_f16(mask, mul, in)); + } +#endif + for (; j < channel; j++) { + cur_out[j] = cur_in[j] > 0 ? cur_in[j] : cur_in[j] * slope[j]; + } + } +} + +void PReluShareChannelFp16(const float16_t *input, float16_t *output, float16_t slope, int start, int end) { + int i = start; +#ifdef ENABLE_NEON + float16x8_t zero_data = vdupq_n_f16(0); + float16x8_t slope_data = vdupq_n_f16(slope); + for (; i <= end - C8NUM; i += C8NUM) { + float16x8_t src_tmp = vld1q_f16(input + i); + float16x8_t mul_tmp = vmulq_f16(src_tmp, slope_data); + uint16x8_t mask = vcleq_f16(src_tmp, zero_data); + vst1q_f16(output + i, vbslq_f16(mask, mul_tmp, src_tmp)); + } +#endif + for (; i < end; i++) { + output[i] = input[i] > 0 ? input[i] : input[i] * slope; + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/prelu_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/prelu_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..6476fa3437969d96aa834a207f4f6ce4c2efa91b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/prelu_fp16.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_PRELU_FP16_H_ +#define NNACL_FP16_PRELU_FP16_H_ + +#include "nnacl/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +void PReluFp16(const float16_t *input, float16_t *output, const float16_t *slope, int start, int end, int channel); + +void PReluShareChannelFp16(const float16_t *input, float16_t *output, float16_t slope, int start, int end); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_PRELU_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/quant_dtype_cast_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/quant_dtype_cast_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..ed7cf06292016bde4500dd2383d62f1c5cfd8e08 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/quant_dtype_cast_fp16.c @@ -0,0 +1,290 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "nnacl/fp16/quant_dtype_cast_fp16.h" +#include "nnacl/errorcode.h" + +#ifdef ENABLE_ARM64 +void Int8ToFp16_arm64(const int8_t *quant_values, float16_t *dst, float scale, int32_t zp, int size) { + asm volatile( + "mov w8, %w[size]\n" + "cmp w8, #0\n" + "beq 2f\n" + + "dup v20.4s, %w[zp32]\n" + "dup v21.4s, %w[scale]\n" + + "cmp w8, #16\n" + "blt 1f\n" + + "0:\n" + "subs w8, w8, #16\n" + "ld1 {v7.16b}, [%[quant_values]], #16\n" + + "sxtl v8.8h, v7.8b\n" + "sxtl2 v9.8h, v7.16b\n" + + "sxtl v0.4s, v8.4h\n" + "sxtl2 v1.4s, v8.8h\n" + "sxtl v2.4s, v9.4h\n" + "sxtl2 v3.4s, v9.8h\n" + "sub v0.4s, v0.4s, v20.4s\n" + "sub v1.4s, v1.4s, v20.4s\n" + "sub v2.4s, v2.4s, v20.4s\n" + "sub v3.4s, v3.4s, v20.4s\n" + "scvtf v4.4s, v0.4s\n" + "scvtf v5.4s, v1.4s\n" + "scvtf v6.4s, v2.4s\n" + "scvtf v7.4s, v3.4s\n" + + "fmul v0.4s, v4.4s, v21.4s\n" + "fmul v1.4s, v5.4s, v21.4s\n" + "fmul v2.4s, v6.4s, v21.4s\n" + "fmul v3.4s, v7.4s, v21.4s\n" + + "fcvtn v4.4h, v0.4s\n" + "fcvtn2 v4.8h, v1.4s\n" + "fcvtn v5.4h, v2.4s\n" + "fcvtn2 v5.8h, v3.4s\n" + + "st1 {v4.8h, v5.8h}, [%[dst]], #32\n" + "beq 2f\n" + "cmp w8, #16\n" + "bge 0b\n" + + "1:\n" + "ldrsb w9, [%[quant_values]], #1\n" + + "subs w8, w8, #1\n" + "sub w9, w9, %w[zp32]\n" + "scvtf s9, w9\n" + + "fmul s9, s9, s21\n" + "fcvtn v4.4h, v9.4s\n" + "str h4, [%[dst]], #2\n" + "bne 1b\n" + + "2:\n" + + : + : [ quant_values ] "r"(quant_values), [ dst ] "r"(dst), [ scale ] "r"(scale), [ zp32 ] "r"(zp), [ size ] "r"(size) + : "w8", "w9", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v20", "v21"); +} +#endif + +int DoDequantizeInt8ToFp16(const int8_t *quant_values, float16_t *real_values, float scale, int32_t zp, int size) { + if (quant_values == NULL || real_values == NULL) { + return NNACL_PARAM_INVALID; + } +#ifdef ENABLE_ARM64 + Int8ToFp16_arm64(quant_values, real_values, scale, zp, size); +#else + for (int i = 0; i < size; ++i) { + real_values[i] = (quant_values[i] - zp) * scale; + } +#endif + return NNACL_OK; +} + +#ifdef ENABLE_ARM64 +void Fp16ToInt8_arm64(const float16_t *real_values, int8_t *quant_values, float scale, int32_t zp, int size) { + const float one = 1.0f; + const float ivs = one / scale; + const int32_t min_value = -128; + const int32_t max_value = 127; + asm volatile( + "mov w8, %w[size]\n" + "cmp w8, wzr\n" + "beq 3f\n" + + "dup v28.4s, %w[ivs]\n" + "dup v29.4s, %w[min_value]\n" + "dup v30.4s, %w[max_value]\n" + + "cmp w8, #32\n" + "blt 2f\n" + "1:\n" // loop 32 + "subs w8, w8, #32\n" + "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%[real_values]], #64\n" + "fcvtl v8.4s, v0.4h\n" + "fcvtl2 v9.4s, v0.8h\n" + "fcvtl v10.4s, v1.4h\n" + "fcvtl2 v11.4s, v1.8h\n" + "fcvtl v12.4s, v2.4h\n" + "fcvtl2 v13.4s, v2.8h\n" + "fcvtl v14.4s, v3.4h\n" + "fcvtl2 v15.4s, v3.8h\n" + + "dup v16.4s, %w[zp]\n" + "dup v17.4s, %w[zp]\n" + "dup v18.4s, %w[zp]\n" + "dup v19.4s, %w[zp]\n" + "dup v20.4s, %w[zp]\n" + "dup v21.4s, %w[zp]\n" + "dup v22.4s, %w[zp]\n" + "dup v23.4s, %w[zp]\n" + "scvtf v16.4s, v16.4s\n" + "scvtf v17.4s, v17.4s\n" + "scvtf v18.4s, v18.4s\n" + "scvtf v19.4s, v19.4s\n" + "scvtf v20.4s, v20.4s\n" + "scvtf v21.4s, v21.4s\n" + "scvtf v22.4s, v22.4s\n" + "scvtf v23.4s, v23.4s\n" + + "fmla v16.4s, v8.4s, v28.4s\n" + "fmla v17.4s, v9.4s, v28.4s\n" + "fmla v18.4s, v10.4s, v28.4s\n" + "fmla v19.4s, v11.4s, v28.4s\n" + "fmla v20.4s, v12.4s, v28.4s\n" + "fmla v21.4s, v13.4s, v28.4s\n" + "fmla v22.4s, v14.4s, v28.4s\n" + "fmla v23.4s, v15.4s, v28.4s\n" + + "fcvtas v8.4s, v16.4s\n" + "fcvtas v9.4s, v17.4s\n" + "fcvtas v10.4s, v18.4s\n" + "fcvtas v11.4s, v19.4s\n" + "fcvtas v12.4s, v20.4s\n" + "fcvtas v13.4s, v21.4s\n" + "fcvtas v14.4s, v22.4s\n" + "fcvtas v15.4s, v23.4s\n" + + "smax v8.4s, v8.4s, v29.4s\n" + "smax v9.4s, v9.4s, v29.4s\n" + "smax v10.4s, v10.4s, v29.4s\n" + "smax v11.4s, v11.4s, v29.4s\n" + "smax v12.4s, v12.4s, v29.4s\n" + "smax v13.4s, v13.4s, v29.4s\n" + "smax v14.4s, v14.4s, v29.4s\n" + "smax v15.4s, v15.4s, v29.4s\n" + "smin v8.4s, v8.4s, v30.4s\n" + "smin v9.4s, v9.4s, v30.4s\n" + "smin v10.4s, v10.4s, v30.4s\n" + "smin v11.4s, v11.4s, v30.4s\n" + "smin v12.4s, v12.4s, v30.4s\n" + "smin v13.4s, v13.4s, v30.4s\n" + "smin v14.4s, v14.4s, v30.4s\n" + "smin v15.4s, v15.4s, v30.4s\n" + + "sqxtn v16.4h, v8.4s\n" + "sqxtn2 v16.8h, v9.4s\n" + "sqxtn v17.4h, v10.4s\n" + "sqxtn2 v17.8h, v11.4s\n" + "sqxtn v18.4h, v12.4s\n" + "sqxtn2 v18.8h, v13.4s\n" + "sqxtn v19.4h, v14.4s\n" + "sqxtn2 v19.8h, v15.4s\n" + "sqxtn v20.8b, v16.8h\n" + "sqxtn2 v20.16b, v17.8h\n" + "sqxtn v21.8b, v18.8h\n" + "sqxtn2 v21.16b, v19.8h\n" + + "st1 {v20.16b, v21.16b}, [%[quant_values]], #32\n" + + "beq 3f\n" + "cmp w8, #32\n" + "bge 1b\n" + + "2:\n" // 1 by 1 + "scvtf s10, %w[zp]\n" + "subs w8, w8, #1\n" + "ldr h0, [%[real_values]], #2\n" + "fcvt s0, h0\n" + "fmul s0, s0, s28\n" + "fadd s0, s0, s10\n" + "fcvtas s0, s0\n" + "smax v0.4s, v0.4s, v29.4s\n" + "smin v0.4s, v0.4s, v30.4s\n" + "sqxtn v0.4h, v0.4s\n" + "sqxtn v0.8b, v0.8h\n" + "st1 {v0.b}[0], [%[quant_values]], #1\n" + "bne 2b\n" + + "3:\n" + : + : [ size ] "r"(size), [ ivs ] "r"(ivs), [ real_values ] "r"(real_values), [ quant_values ] "r"(quant_values), + [ zp ] "r"(zp), [ min_value ] "r"(min_value), [ max_value ] "r"(max_value) + : "w8", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", + "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v28", "v29", "v30"); +} +#endif + +int DoQuantizeFp16ToInt8(const float16_t *real_values, int8_t *quant_values, float scale, int32_t zp, int size) { + if (quant_values == NULL || real_values == NULL) { + return NNACL_PARAM_INVALID; + } +#ifdef ENABLE_ARM64 + Fp16ToInt8_arm64(real_values, quant_values, scale, zp, size); +#else + const int8_t min_value = -128; + const int8_t max_value = 127; + for (int i = 0; i < size; ++i) { + if (real_values[i] == INFINITY) { + quant_values[i] = max_value; + continue; + } + if (real_values[i] == -INFINITY) { + quant_values[i] = min_value; + continue; + } + float temp = round((float)real_values[i] / scale + zp); + if (temp > max_value) { + quant_values[i] = max_value; + } else if (temp < min_value) { + quant_values[i] = min_value; + } else { + quant_values[i] = (int8_t)temp; + } + } +#endif + return NNACL_OK; +} + +int DoDequantizeUInt8ToFp16(const uint8_t *quant_values, float16_t *real_values, float scale, int32_t zp, int size) { + uint8_t zp_ = (uint8_t)zp; + if (quant_values == NULL || real_values == NULL) { + return NNACL_PARAM_INVALID; + } + + for (int i = 0; i < size; ++i) { + real_values[i] = (quant_values[i] - zp_) * scale; + } + return NNACL_OK; +} + +int DoQuantizeFp16ToUInt8(const float16_t *real_values, uint8_t *quant_values, float scale, int32_t zp, int size) { + if (quant_values == NULL || real_values == NULL) { + return NNACL_PARAM_INVALID; + } + + for (int i = 0; i < size; ++i) { + if (isinf((float)real_values[i])) { + quant_values[i] = 255; + continue; + } + float temp = round((float)real_values[i] / scale + zp); + if (temp > 255.0f) { + quant_values[i] = 255; + } else if (temp < 0.0f) { + quant_values[i] = 0; + } else { + quant_values[i] = (uint8_t)temp; + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/quant_dtype_cast_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/quant_dtype_cast_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..db64103d55b8e838d92ea4e662d50e9a7117f73b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/quant_dtype_cast_fp16.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_QUANTDTYPECAST_FP16_H_ +#define NNACL_FP16_QUANTDTYPECAST_FP16_H_ + +#include "nnacl/op_base.h" +#include "nnacl/intrinsics/ms_simd_instructions_fp16.h" + +#ifdef __cplusplus +extern "C" { +#endif +int DoDequantizeInt8ToFp16(const int8_t *quant_values, float16_t *real_values, float scale, int32_t zp, int size); +int DoQuantizeFp16ToInt8(const float16_t *real_values, int8_t *quant_values, float scale, int32_t zp, int size); + +int DoDequantizeUInt8ToFp16(const uint8_t *quant_values, float16_t *real_values, float scale, int32_t zp, int size); +int DoQuantizeFp16ToUInt8(const float16_t *real_values, uint8_t *quant_values, float scale, int32_t zp, int size); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_QUANTDTYPECAST_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/ragged_range_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/ragged_range_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..320311d2f5f12ec9ce400e100b7c5854b36e2387 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/ragged_range_fp16.c @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp16/ragged_range_fp16.h" + +void RaggedRangeFp16(const float16_t *starts, const float16_t *limits, const float16_t *deltas, int *splits, + float16_t *value, const RaggedRangeStruct *param) { + splits[0] = 0; + for (int i = 0; i < param->rows_; i++) { + float16_t start = param->starts_is_scalar_ ? starts[0] : starts[i]; + float16_t limit = param->limits_is_scalar_ ? limits[0] : limits[i]; + float16_t delta = param->deltas_is_scalar_ ? deltas[0] : deltas[i]; + int len = NNACL_MAX((int)ceil((float16_t)(limit - start) / delta), 0); + splits[i + 1] = splits[i] + len; + for (int j = 0; j < len; j++) { + *value++ = start; + start += delta; + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/ragged_range_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/ragged_range_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..ff0b066d712b4eac00211fd16e3d1bb6f3da8d13 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/ragged_range_fp16.h @@ -0,0 +1,26 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_RAGGED_RANGE_FP16_H_ +#define NNACL_FP16_RAGGED_RANGE_FP16_H_ + +#include +#include "nnacl/op_base.h" +#include "nnacl/kernel/ragged_range.h" + +void RaggedRangeFp16(const float16_t *starts, const float16_t *limits, const float16_t *deltas, int *splits, + float16_t *value, const RaggedRangeStruct *param); + +#endif // NNACL_FP16_RAGGED_RANGE_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/range_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/range_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..9912388fcc2fff8cc7b57ebcfd07bfa26e18f4b7 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/range_fp16.h @@ -0,0 +1,27 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_RANGE_FP16_H_ +#define NNACL_FP16_RANGE_FP16_H_ + +#include "nnacl/op_base.h" + +void RangeFp16(float16_t *output_ptr, float16_t start, float16_t delta, int nums) { + for (int i = 0; i < nums; ++i, start += delta) { + output_ptr[i] = start; + } +} + +#endif // NNACL_FP16_RANGE_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/reduce_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/reduce_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..5867e646ea5414cfc91825f5c8def4e0c4e77635 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/reduce_fp16.c @@ -0,0 +1,198 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "nnacl/fp16/reduce_fp16.h" +#include "nnacl/errorcode.h" +#include "nnacl/intrinsics/ms_simd_instructions_fp16.h" + +int ReduceMeanFp16(int outer_size, int inner_size, int axis_size, const float16_t *src_data, float16_t *dst_data, + int tid, int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + if (axis_size == 0) { + return NNACL_ERR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const float16_t *outer_src = src_data + j * axis_size * inner_size; + float16_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const float16_t *inner_src = outer_src + k; + float16_t *inner_dst = outer_dst + k; + float tmp = 0.0; + for (i = 0; i < axis_size; i++) { + tmp += inner_src[i * inner_size]; + } + *inner_dst = (float16_t)(tmp / axis_size); + } + } + return NNACL_OK; +} + +int ReduceMaxFp16(int outer_size, int inner_size, int axis_size, const float16_t *src_data, float16_t *dst_data, + int tid, int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const float16_t *outer_src = src_data + j * axis_size * inner_size; + float16_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const float16_t *inner_src = outer_src + k; + float16_t *inner_dst = outer_dst + k; + float tmp = -FLT_MAX; + for (i = 0; i < axis_size; i++) { + tmp = tmp > inner_src[i * inner_size] ? tmp : inner_src[i * inner_size]; + } + *inner_dst = tmp; + } + } + return NNACL_OK; +} + +int ReduceMinFp16(int outer_size, int inner_size, int axis_size, const float16_t *src_data, float16_t *dst_data, + int tid, int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + if (thread_num == 0) { + return NNACL_PARAM_INVALID; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const float16_t *outer_src = src_data + j * axis_size * inner_size; + float16_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const float16_t *inner_src = outer_src + k; + float16_t *inner_dst = outer_dst + k; + float16_t tmp = 65504; // fp16 max value + for (i = 0; i < axis_size; i++) { + tmp = tmp < inner_src[i * inner_size] ? tmp : inner_src[i * inner_size]; + } + *inner_dst = tmp; + } + } + return NNACL_OK; +} + +int ReduceProdFp16(int outer_size, int inner_size, int axis_size, const float16_t *src_data, float16_t *dst_data, + int tid, int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + if (thread_num == 0) { + return NNACL_PARAM_INVALID; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const float16_t *outer_src = src_data + j * axis_size * inner_size; + float16_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const float16_t *inner_src = outer_src + k; + float16_t *inner_dst = outer_dst + k; + float16_t tmp = 1.0f; + for (i = 0; i < axis_size; i++) { + tmp *= inner_src[i * inner_size]; + } + *inner_dst = tmp; + } + } + return NNACL_OK; +} + +int ReduceSumFp16(int outer_size, int inner_size, int axis_size, const float16_t *src_data, float16_t *dst_data, + int tid, int thread_num) { + int stride = UP_DIV(outer_size, thread_num); + int start = stride * tid; + int end = MSMIN(outer_size, start + stride); + int num = end - start; +#ifdef ENABLE_NEON + int block_c8 = inner_size - inner_size % C8NUM; +#endif + + int src_stride = axis_size * inner_size; + src_data += start * src_stride; + dst_data += start * inner_size; + + for (int i = 0; i < num; i++, src_data += src_stride, dst_data += inner_size) { + int j = 0; +#ifdef ENABLE_NEON + for (; j < block_c8; j += C8NUM) { + const float16_t *inner_src = src_data + j; + float16_t *inner_dst = dst_data + j; + float16x8_t tmp = {0, 0, 0, 0, 0, 0, 0, 0}; + for (int k = 0; k < axis_size; k++) { + tmp = vaddq_f16(tmp, vld1q_f16(inner_src + k * inner_size)); + } + vst1q_f16(inner_dst, tmp); + } +#endif + for (; j < inner_size; j++) { + const float16_t *inner_src = src_data + j; + float16_t *inner_dst = dst_data + j; + float tmp = 0.0f; + for (int k = 0; k < axis_size; k++) { + tmp += inner_src[k * inner_size]; + } + *inner_dst = tmp; + } + } + return NNACL_OK; +} + +int ReduceL2NormFp16(int outer_size, int inner_size, int axis_size, const float16_t *src_data, float16_t *dst_data, + int tid, int thread_num) { + int stride = UP_DIV(outer_size, thread_num); + int start = stride * tid; + int end = MSMIN(outer_size, start + stride); + int num = end - start; +#ifdef ENABLE_NEON + int block_c8 = inner_size - inner_size % C8NUM; +#endif + + int src_stride = axis_size * inner_size; + src_data += start * src_stride; + dst_data += start * inner_size; + + for (int i = 0; i < num; i++, src_data += src_stride, dst_data += inner_size) { + int j = 0; +#ifdef ENABLE_NEON + for (; j < block_c8; j += C8NUM) { + const float16_t *inner_src = src_data + j; + float16_t *inner_dst = dst_data + j; + float16x8_t tmp = {0, 0, 0, 0, 0, 0, 0, 0}; + for (int k = 0; k < axis_size; k++) { + float16x8_t src = vld1q_f16(inner_src + k * inner_size); + tmp = MS_FMAQ_F16(tmp, src, src); + } + vst1q_f16(inner_dst, MS_SQRTFX8_F16(tmp)); + } +#endif + for (; j < inner_size; j++) { + const float16_t *inner_src = src_data + j; + float tmp = 0.0f; + for (int k = 0; k < axis_size; k++) { + tmp += inner_src[k * inner_size] * inner_src[k * inner_size]; + } + dst_data[j] = sqrtf(tmp); + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_model.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/reduce_fp16.h similarity index 34% rename from mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_model.h rename to mindspore-lite/ops/kernel/cpu/nnacl/fp16/reduce_fp16.h index c2d9cdf7f943acc42511f108a00bf5b9b18750b8..253cd8d6f60dc1aa708aab429dc95452ce57b6ee 100644 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_model.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/reduce_fp16.h @@ -13,41 +13,29 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_CXX_API_ACL_MODEL_H -#define MINDSPORE_CCSRC_CXX_API_ACL_MODEL_H -#include -#include -#include -#include -#include -#include "include/api/cell.h" -#include "include/api/status.h" -#include "cxx_api/model/model_impl.h" -#include "cxx_api/model/acl/model_converter.h" -#include "cxx_api/model/acl/acl_model_options.h" -#include "ir/tensor.h" -#include "ir/anf.h" +#ifndef NNACL_FP16_REDUCE_FP16_H_ +#define NNACL_FP16_REDUCE_FP16_H_ +#include "nnacl/op_base.h" +#include "nnacl/reduce_parameter.h" -namespace mindspore { -class AclModel : public ModelImpl { - public: - AclModel() : model_converter_(), options_(nullptr) {} - ~AclModel() = default; +#ifdef __cplusplus +extern "C" { +#endif +int ReduceMeanFp16(int outer_size, int inner_size, int axis_size, const float16_t *src_data, float16_t *dst_data, + int tid, int thread_num); +int ReduceMaxFp16(int outer_size, int inner_size, int axis_size, const float16_t *src_data, float16_t *dst_data, + int tid, int thread_num); +int ReduceMinFp16(int outer_size, int inner_size, int axis_size, const float16_t *src_data, float16_t *dst_data, + int tid, int thread_num); +int ReduceProdFp16(int outer_size, int inner_size, int axis_size, const float16_t *src_data, float16_t *dst_data, + int tid, int thread_num); +int ReduceSumFp16(int outer_size, int inner_size, int axis_size, const float16_t *src_data, float16_t *dst_data, + int tid, int thread_num); +int ReduceL2NormFp16(int outer_size, int inner_size, int axis_size, const float16_t *src_data, float16_t *dst_data, + int tid, int thread_num); +#ifdef __cplusplus +} +#endif - Status Build() override; - Status Resize(const std::vector &inputs, const std::vector> &dims) override; - - std::vector GetInputs() override; - std::vector GetOutputs() override; - - bool CheckDeviceSupport(mindspore::DeviceType device_type) override; - bool CheckModelSupport(enum ModelType model_type) override; - - private: - ModelConverter model_converter_; - std::shared_ptr options_; - std::map> dynamic_size_graph_map_; -}; -} // namespace mindspore -#endif // MINDSPORE_CCSRC_CXX_API_ACL_MODEL_H +#endif // NNACL_FP16_REDUCE_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/resize_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/resize_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..082967fa3cd87b32fac7058efc56b66bcf75035b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/resize_fp16.c @@ -0,0 +1,380 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp16/resize_fp16.h" +#include "nnacl/common_func.h" +#include "nnacl/errorcode.h" + +void CalculateCoordinateFp16(float16_t out, int in, int *bottom, int *top, float16_t *bottom_weight) { + *bottom = (int)(floorf(out)); + *bottom = *bottom >= 0 ? *bottom : 0; // extrapolate may generate neg value + *top = *bottom + 1 < in ? (*bottom + 1) : (in - 1); + float16_t top_weight = (float16_t)out - (float16_t)(*bottom); + *bottom_weight = 1.0f - top_weight; +} + +static void BicubicBaseFuncFp16(float16_t a, float16_t x, float16_t *weight) { + float16_t abs_x = fabsf(x); + if (abs_x >= 0 && abs_x <= 1) { + *weight = ((a + 2) * abs_x - (a + 3)) * abs_x * abs_x + 1; + } else if (abs_x > 1 && abs_x <= 2) { + *weight = a * abs_x * abs_x * abs_x - 5 * a * abs_x * abs_x + 8 * a * abs_x - 4 * a; + } else { + *weight = 0; + } +} + +// a is a coefficient +// W(x) = { (a + 2) * |x| * |x| * |x| - (a + 3) * |x| * |x| + 1, for |x| <= 1 +// { a * |x| * |x| * |x| - 5 * a * |x| * |x| + 8 * a *|x| - 4 * a, for 1 < |x| < 2 +// { 0, otherwise +// the value of 'a' depends on if is half_pixel_center(the scheme is the same as tf). +// If is half pixel mode, a equals to -0.5, otherwise -0.75. +void CalculateWeightForBicubicFp16(float16_t out, int in, int *index, float16_t *weights, float16_t a) { + int floor_index = (int)(floorf(out)); + index[0] = (floor_index - 1) < 0 ? 0 : (floor_index - 1); + index[1] = floor_index; + index[2] = (floor_index + 1) < in ? (floor_index + 1) : (in - 1); + index[3] = (floor_index + 2) < in ? (floor_index + 2) : (in - 1); + + // get positive value + float16_t distance[4] = {-1, 0, 1, 2}; + float16_t tmp_dis = out - (float16_t)floor_index; + distance[0] -= tmp_dis; + distance[1] -= tmp_dis; + distance[2] -= tmp_dis; + distance[3] -= tmp_dis; + + for (int i = 0; i < 4; ++i) { + BicubicBaseFuncFp16(a, distance[i], &weights[i]); + } +} + +int PrepareResizeBilinearFp16(const int *input_shape, const int *output_shape, CalculateOriginalCoordinate calculate, + int *y_bottoms, int *y_tops, int *x_lefts, int *x_rights, float16_t *y_bottom_weights, + float16_t *x_left_weights) { + if (input_shape == NULL || output_shape == NULL || y_bottoms == NULL || y_tops == NULL || x_lefts == NULL || + x_rights == NULL || y_bottom_weights == NULL || x_left_weights == NULL) { + return NNACL_NULL_PTR; + } + + int in_h = input_shape[1]; + int in_w = input_shape[2]; + + int new_height = output_shape[1]; + int new_width = output_shape[2]; + + for (int h = 0; h < new_height; h++) { + float16_t actual_y = calculate(h, in_h, new_height); + CalculateCoordinateFp16(actual_y, in_h, y_bottoms + h, y_tops + h, y_bottom_weights + h); + } + for (int w = 0; w < new_width; w++) { + float16_t actual_x = calculate(w, in_w, new_width); + CalculateCoordinateFp16(actual_x, in_w, x_lefts + w, x_rights + w, x_left_weights + w); + } + return NNACL_OK; +} + +int PrepareResizeBicubicFp16(const int *input_shape, const int *output_shape, CalculateOriginalCoordinate calculate, + int *y_tops, int *x_lefts, float16_t *y_weights, float16_t *x_weights, + float16_t cubic_coeff) { + if (input_shape == NULL || output_shape == NULL || y_tops == NULL || x_lefts == NULL || y_weights == NULL || + x_weights == NULL) { + return NNACL_NULL_PTR; + } + + int in_h = input_shape[1]; + int in_w = input_shape[2]; + int new_height = output_shape[1]; + int new_width = output_shape[2]; + + for (int h = 0; h < new_height; h++) { + float16_t actual_y = calculate(h, in_h, new_height); + CalculateWeightForBicubicFp16(actual_y, in_h, y_tops + 4 * h, y_weights + 4 * h, cubic_coeff); + } + for (int w = 0; w < new_width; w++) { + float16_t actual_x = calculate(w, in_w, new_width); + CalculateWeightForBicubicFp16(actual_x, in_w, x_lefts + 4 * w, x_weights + 4 * w, cubic_coeff); + } + return NNACL_OK; +} + +int InterpRowFp16(const float16_t *src_line, float16_t *linear_output, int new_width, const float16_t *x_left_weights, + const int *x_lefts, const int *x_rights, int in_c) { + int w; + for (w = 0; w < new_width; w++) { + int c = 0; +#if defined(ENABLE_NEON) + float16x8_t left_w_8 = vdupq_n_f16(x_left_weights[w]); + float16x8_t right_w_8 = vdupq_n_f16(1.0f - x_left_weights[w]); + for (; c <= in_c - C8NUM; c += C8NUM) { + float16x8_t left = vld1q_f16(src_line + x_lefts[w] * in_c + c); + float16x8_t right = vld1q_f16(src_line + x_rights[w] * in_c + c); + float16x8_t interp_value = vaddq_f16(vmulq_f16(left, left_w_8), vmulq_f16(right, right_w_8)); + vst1q_f16(linear_output + w * in_c + c, interp_value); + } +#endif + int left_w_offset = x_lefts[w] * in_c; + int right_w_offset = x_rights[w] * in_c; + for (; c < in_c; c++) { + float16_t left = src_line[left_w_offset + c]; + float16_t right = src_line[right_w_offset + c]; + linear_output[w * in_c + c] = left * x_left_weights[w] + right * (1.0f - x_left_weights[w]); + } + } + return 0; +} + +int InterpColFp16(const float16_t *bottom_line, const float16_t *top_line, float16_t *output, int new_width, + float16_t y_bottom_weight, int in_c) { + int w; + for (w = 0; w < new_width; w++) { + int c = 0; +#if defined(ENABLE_NEON) + float16x8_t bottom_w_8 = vdupq_n_f16(y_bottom_weight); + float16x8_t top_w_8 = vdupq_n_f16(1.0f - y_bottom_weight); + for (; c <= in_c - C8NUM; c += C8NUM) { + float16x8_t bottom = vld1q_f16(bottom_line + w * in_c + c); + float16x8_t top = vld1q_f16(top_line + w * in_c + c); + float16x8_t interp_value = vaddq_f16(vmulq_f16(bottom, bottom_w_8), vmulq_f16(top, top_w_8)); + vst1q_f16(output + w * in_c + c, interp_value); + } +#endif + for (; c < in_c; c++) { + float16_t bottom = bottom_line[w * in_c + c]; + float16_t top = top_line[w * in_c + c]; + output[w * in_c + c] = bottom * y_bottom_weight + top * (1.0f - y_bottom_weight); + } + } + return 0; +} + +void BilinearFp16(const float16_t *input_data, float16_t *output_data, const int *input_shape, const int *output_shape, + const int *y_bottom, const int *y_top, const int *x_left, const int *x_right, + const float16_t *y_bottom_weight, const float16_t *x_left_weight, float16_t *line0, float16_t *line1, + const int h_begin, const int h_end) { + int in_w = input_shape[2]; + int in_c = input_shape[3]; + int new_width = output_shape[2]; + int h_stride = new_width * in_c; + + bool cache_line_used[2] = {false, false}; + int cache_line_num[2] = {-1, -1}; + float16_t *const cache_line_ptr[2] = {line0, line1}; + float16_t *current_line_ptr[2] = {line0, line1}; + int current_line_num[2] = {-1, -1}; + + for (int h = h_begin; h < h_end; h++) { + current_line_num[0] = y_bottom[h]; + current_line_num[1] = y_top[h]; + + for (int i = 0; i < 2; i++) { + cache_line_used[i] = false; + } + // search if we cached + for (int j = 0; j < 2; j++) { + bool find = false; + for (int k = 0; k < 2; k++) { + if (current_line_num[j] == cache_line_num[k]) { + cache_line_used[k] = true; + current_line_ptr[j] = cache_line_ptr[k]; + find = true; + break; + } + } + + if (!find) { + const float16_t *line = input_data + current_line_num[j] * in_w * in_c; + for (int k = 0; k < 2; k++) { + if (!cache_line_used[k]) { + cache_line_num[k] = current_line_num[j]; + cache_line_used[k] = true; + current_line_ptr[j] = cache_line_ptr[k]; + InterpRowFp16(line, current_line_ptr[j], new_width, x_left_weight, x_left, x_right, in_c); + break; + } + } + } + } + // do col interp + InterpColFp16(current_line_ptr[0], current_line_ptr[1], output_data + h * h_stride, new_width, y_bottom_weight[h], + in_c); + } +} + +int ResizeBilinearFp16(const float16_t *input_data, float16_t *output_data, const int *input_shape, + const int *output_shape, const int *y_bottoms, const int *y_tops, const int *x_lefts, + const int *x_rights, const float16_t *y_bottom_weights, const float16_t *x_left_weights, + float16_t *line0, float16_t *line1, const int h_begin, const int h_end) { + if (input_data == NULL || output_data == NULL || input_shape == NULL || output_shape == NULL || y_bottoms == NULL || + y_tops == NULL || x_lefts == NULL || x_rights == NULL || y_bottom_weights == NULL || x_left_weights == NULL) { + return NNACL_NULL_PTR; + } + + int in_b = input_shape[0]; + int in_h = input_shape[1]; + int in_w = input_shape[2]; + int in_c = input_shape[3]; + int new_height = output_shape[1]; + int new_width = output_shape[2]; + + for (int b = 0; b < in_b; b++) { + const float16_t *input = input_data + b * in_h * in_w * in_c; + float16_t *output = output_data + b * new_height * new_width * in_c; + BilinearFp16(input, output, input_shape, output_shape, y_bottoms, y_tops, x_lefts, x_rights, y_bottom_weights, + x_left_weights, line0, line1, h_begin, h_end); + } + return NNACL_OK; +} + +void BicubicInterpRowFp16(const float16_t *src, float16_t *dst, const float16_t *weights, const int *lefts, int width, + int channel) { + for (int w = 0; w < width; w++) { + const float16_t *weight = weights + 4 * w; + float16_t *dst_w = dst + w * channel; + const float16_t *src0_w = src + lefts[4 * w] * channel; + const float16_t *src1_w = src + lefts[4 * w + 1] * channel; + const float16_t *src2_w = src + lefts[4 * w + 2] * channel; + const float16_t *src3_w = src + lefts[4 * w + 3] * channel; + int c = 0; +#if defined(ENABLE_NEON) + float16x8_t weight0_vec_8 = vdupq_n_f16(weight[0]); + float16x8_t weight1_vec_8 = vdupq_n_f16(weight[1]); + float16x8_t weight2_vec_8 = vdupq_n_f16(weight[2]); + float16x8_t weight3_vec_8 = vdupq_n_f16(weight[3]); + for (; c <= channel - C8NUM; c += C8NUM) { + float16x8_t src0_vec = vld1q_f16(src0_w + c); + float16x8_t src1_vec = vld1q_f16(src1_w + c); + float16x8_t src2_vec = vld1q_f16(src2_w + c); + float16x8_t src3_vec = vld1q_f16(src3_w + c); + float16x8_t dst0 = vmulq_f16(src0_vec, weight0_vec_8); + float16x8_t dst1 = vmulq_f16(src1_vec, weight1_vec_8); + float16x8_t dst2 = vmulq_f16(src2_vec, weight2_vec_8); + float16x8_t dst3 = vmulq_f16(src3_vec, weight3_vec_8); + float16x8_t interp_value = vaddq_f16(dst3, vaddq_f16(dst2, vaddq_f16(dst1, dst0))); + vst1q_f16(dst_w + c, interp_value); + } +#endif + for (; c < channel; c++) { + dst_w[c] = src0_w[c] * weight[0] + src1_w[c] * weight[1] + src2_w[c] * weight[2] + src3_w[c] * weight[3]; + } + } +} + +void BicubicInterpColFp16(const float16_t *src, float16_t *dst, const float16_t *weights, int width, int channel) { + const float16_t *src0 = src; + const float16_t *src1 = src + width * channel; + const float16_t *src2 = src + 2 * width * channel; + const float16_t *src3 = src + 3 * width * channel; + for (int w = 0; w < width; w++) { + float16_t *dst_w = dst + w * channel; + const float16_t *src0_w = src0 + w * channel; + const float16_t *src1_w = src1 + w * channel; + const float16_t *src2_w = src2 + w * channel; + const float16_t *src3_w = src3 + w * channel; + int c = 0; +#ifdef ENABLE_NEON + float16x8_t weight0_vec_8 = vdupq_n_f16(weights[0]); + float16x8_t weight1_vec_8 = vdupq_n_f16(weights[1]); + float16x8_t weight2_vec_8 = vdupq_n_f16(weights[2]); + float16x8_t weight3_vec_8 = vdupq_n_f16(weights[3]); + for (; c <= channel - C8NUM; c += C8NUM) { + float16x8_t src0_vec = vld1q_f16(src0_w + c); + float16x8_t src1_vec = vld1q_f16(src1_w + c); + float16x8_t src2_vec = vld1q_f16(src2_w + c); + float16x8_t src3_vec = vld1q_f16(src3_w + c); + float16x8_t dst1 = vmulq_f16(src0_vec, weight0_vec_8); + float16x8_t dst2 = vmulq_f16(src1_vec, weight1_vec_8); + float16x8_t dst3 = vmulq_f16(src2_vec, weight2_vec_8); + float16x8_t dst4 = vmulq_f16(src3_vec, weight3_vec_8); + float16x8_t interp_value = vaddq_f16(dst4, vaddq_f16(dst3, vaddq_f16(dst1, dst2))); + vst1q_f16(dst_w + c, interp_value); + } +#endif + for (; c < channel; c++) { + dst_w[c] = src0_w[c] * weights[0] + src1_w[c] * weights[1] + src2_w[c] * weights[2] + src3_w[c] * weights[3]; + } + } +} + +void BicubicFp16(const float16_t *input_data, float16_t *output_data, const int *input_shape, const int *output_shape, + const int *y_tops, const int *x_lefts, const float16_t *y_weights, const float16_t *x_weights, + float16_t *line_buffer, const int h_begin, const int h_end) { + int in_w = input_shape[2]; + int in_c = input_shape[3]; + int new_width = output_shape[2]; + int h_stride = new_width * in_c; + + for (int h = h_begin; h < h_end; h++) { + for (int i = 0; i < 4; ++i) { + BicubicInterpRowFp16(input_data + y_tops[4 * h + i] * in_w * in_c, line_buffer + i * h_stride, x_weights, x_lefts, + new_width, in_c); + } + BicubicInterpColFp16(line_buffer, output_data + h * h_stride, y_weights + 4 * h, new_width, in_c); + } +} + +int ResizeBicubicFp16(const float16_t *input_data, float16_t *output_data, const int *input_shape, + const int *output_shape, const int *y_tops, const int *x_lefts, const float16_t *y_weights, + const float16_t *x_weights, float16_t *line_buffer, const int h_begin, const int h_end) { + if (input_data == NULL || output_data == NULL || input_shape == NULL || output_shape == NULL || y_tops == NULL || + x_lefts == NULL || y_weights == NULL || x_weights == NULL) { + return NNACL_NULL_PTR; + } + int input_cube_per_batch = input_shape[1] * input_shape[2] * input_shape[3]; + int output_cube_per_batch = output_shape[1] * output_shape[2] * input_shape[3]; + for (int b = 0; b < input_shape[0]; b++) { + const float16_t *input = input_data + b * input_cube_per_batch; + float16_t *output = output_data + b * output_cube_per_batch; + BicubicFp16(input, output, input_shape, output_shape, y_tops, x_lefts, y_weights, x_weights, line_buffer, h_begin, + h_end); + } + return NNACL_OK; +} + +int ResizeNearestNeighborFp16(const float16_t *input_data, float16_t *output_data, const int *input_shape, + const int *output_shape, CalculateOriginalCoordinate calculate, + int coordinate_transform_mode, int tid, int thread_num) { + if (thread_num == 0) { + return NNACL_PARAM_INVALID; + } + int c = input_shape[3]; + bool align_corners = coordinate_transform_mode == 1; + for (int batch = 0; batch < output_shape[0]; batch++) { + for (int y = tid; y < output_shape[1]; y += thread_num) { + float16_t actual_y = calculate(y, input_shape[1], output_shape[1]); + int input_y; + if (align_corners) { + input_y = (int)(roundf(actual_y)); + } else { + input_y = (int)(floorf(actual_y)); + } + for (int x = 0; x < output_shape[2]; x++) { + float16_t actual_x = calculate(x, input_shape[2], output_shape[2]); + int input_x; + if (align_corners) { + input_x = (int)(roundf(actual_x)); + } else { + input_x = (int)(floorf(actual_x)); + } + int in_offset = Offset(input_shape, batch, input_y, input_x, 0); + int out_offset = Offset(output_shape, batch, y, x, 0); + memcpy(output_data + out_offset, input_data + in_offset, c * sizeof(float16_t)); + } + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/resize_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/resize_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..9519189a363f5b1edf7b071713e7b30a971caf42 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/resize_fp16.h @@ -0,0 +1,56 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_RESIZE_FP16_H_ +#define NNACL_FP16_RESIZE_FP16_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "nnacl/resize_parameter.h" +#include "nnacl/op_base.h" +#include "nnacl/crop_parameter.h" +#include "nnacl/fp32/resize_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int PrepareResizeBilinearFp16(const int *input_shape, const int *output_shape, CalculateOriginalCoordinate calculate, + int *y_bottoms, int *y_tops, int *x_lefts, int *x_rights, float16_t *y_bottom_weights, + float16_t *x_left_weights); + +int PrepareResizeBicubicFp16(const int *input_shape, const int *output_shape, CalculateOriginalCoordinate calculate, + int *y_tops, int *x_lefts, float16_t *y_weights, float16_t *x_weights, + float16_t cubic_coeff); + +int ResizeBilinearFp16(const float16_t *input_data, float16_t *output_data, const int *input_shape, + const int *output_shape, const int *y_bottoms, const int *y_tops, const int *x_lefts, + const int *x_rights, const float16_t *y_bottom_weights, const float16_t *x_left_weights, + float16_t *line0, float16_t *line1, const int h_begin, const int h_end); + +int ResizeBicubicFp16(const float16_t *input_data, float16_t *output_data, const int *input_shape, + const int *output_shape, const int *y_tops, const int *x_lefts, const float16_t *y_weights, + const float16_t *x_weights, float16_t *line_buffer, const int h_begin, const int h_end); + +int ResizeNearestNeighborFp16(const float16_t *input_data, float16_t *output_data, const int *input_shape, + const int *output_shape, CalculateOriginalCoordinate calculate, + int coordinate_transform_mode, int tid, int thread_num); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_RESIZE_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/scale_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/scale_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..26c7b9e320b2771fc2f279d438cb8fb0f26d901a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/scale_fp16.c @@ -0,0 +1,226 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp16/scale_fp16.h" + +void Fp16ScaleInner(const float16_t *in_data, float16_t *out_data, const float16_t *scale, const float16_t *offset, + int outer_start, int outer_end, int axis_size, int inner_size) { + for (int out = outer_start; out < outer_end; out++) { + int out_offset = out * axis_size * inner_size; + for (int i = 0; i < axis_size; i++) { + int axis_offset = out_offset + i * inner_size; + int in_index = 0; +#ifdef ENABLE_NEON + for (; in_index < inner_size - 8; in_index += 8) { + int in_offset = axis_offset + in_index; + float16x8_t data = vld1q_f16(in_data + in_offset); + float16x8_t scale_8 = vdupq_n_f16(scale[i]); + float16x8_t offset_8 = vdupq_n_f16(offset[i]); + float16x8_t result = vfmaq_f16(offset_8, data, scale_8); + + vst1q_f16(out_data + in_offset, result); + } +#endif + for (; in_index < inner_size; in_index++) { + int in_offset = axis_offset + in_index; + out_data[in_offset] = in_data[in_offset] * scale[i] + offset[i]; + } + } + } +} + +void Fp16ScaleAxis(const float16_t *in_data, float16_t *out_data, const float16_t *scale, const float16_t *offset, + int outer_start, int outer_end, int axis_size) { + for (int out = outer_start; out < outer_end; out++) { + int out_offset = out * axis_size; + int index = 0; +#ifdef ENABLE_NEON + for (; index < axis_size - 8; index += 8) { + int in_offset = out_offset + index; + float16x8_t data = vld1q_f16(in_data + in_offset); + float16x8_t scale_8 = vld1q_f16(scale + index); + float16x8_t offset_8 = vld1q_f16(offset + index); + float16x8_t result = vfmaq_f16(offset_8, data, scale_8); + vst1q_f16(out_data + in_offset, result); + } +#endif + for (; index < axis_size; index++) { + int in_offset = out_offset + index; + out_data[in_offset] = in_data[in_offset] * scale[index] + offset[index]; + } + } +} + +void DoScaleFp16(const float16_t *in_data, float16_t *out_data, const float16_t *scale, const float16_t *offset, + int task_id, const ScaleStruct *scale_param) { + NNACL_CHECK_ZERO_RETURN(scale_param->base_.thread_nr_); + int outer_step = UP_DIV(scale_param->outer_size_, scale_param->base_.thread_nr_); + int outer_start = task_id * outer_step; + int outer_end = MSMIN(outer_start + outer_step, scale_param->outer_size_); + + if (scale_param->inner_size_ == 1) { + Fp16ScaleAxis(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_); + } else { + Fp16ScaleInner(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_, + scale_param->inner_size_); + } +} + +void Fp16ScaleInnerRelu(const float16_t *in_data, float16_t *out_data, const float16_t *scale, const float16_t *offset, + int outer_start, int outer_end, int axis_size, int inner_size) { +#ifdef ENABLE_NEON + float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; +#endif + for (int out = outer_start; out < outer_end; out++) { + int out_offset = out * axis_size * inner_size; + for (int i = 0; i < axis_size; i++) { + int axis_offset = out_offset + i * inner_size; + int in_index = 0; +#ifdef ENABLE_NEON + for (; in_index < inner_size - 8; in_index += 8) { + int in_offset = axis_offset + in_index; + float16x8_t data = vld1q_f16(in_data + in_offset); + float16x8_t scale_8 = vdupq_n_f16(scale[i]); + float16x8_t offset_8 = vdupq_n_f16(offset[i]); + float16x8_t tmp = vfmaq_f16(offset_8, data, scale_8); + float16x8_t result = vmaxq_f16(tmp, zeros); + vst1q_f16(out_data + in_offset, result); + } +#endif + for (; in_index < inner_size; in_index++) { + int in_offset = axis_offset + in_index; + float tmp = in_data[in_offset] * scale[i] + offset[i]; + out_data[in_offset] = tmp > 0.0f ? tmp : 0.0f; + } + } + } +} + +void Fp16ScaleAxisRelu(const float16_t *in_data, float16_t *out_data, const float16_t *scale, const float16_t *offset, + int outer_start, int outer_end, int axis_size) { +#ifdef ENABLE_NEON + float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; +#endif + for (int out = outer_start; out < outer_end; out++) { + int out_offset = out * axis_size; + int index = 0; +#ifdef ENABLE_NEON + for (; index < axis_size - 8; index += 8) { + int in_offset = out_offset + index; + float16x8_t data = vld1q_f16(in_data + in_offset); + float16x8_t scale_8 = vld1q_f16(scale + index); + float16x8_t offset_8 = vld1q_f16(offset + index); + float16x8_t tmp = vfmaq_f16(offset_8, data, scale_8); + float16x8_t result = vmaxq_f16(tmp, zeros); + vst1q_f16(out_data + in_offset, result); + } +#endif + for (; index < axis_size; index++) { + int in_offset = out_offset + index; + float tmp = in_data[in_offset] * scale[index] + offset[index]; + out_data[in_offset] = tmp > 0.0f ? tmp : 0.0f; + } + } +} + +void Fp16DoScaleRelu(const float16_t *in_data, float16_t *out_data, const float16_t *scale, const float16_t *offset, + int task_id, const ScaleStruct *scale_param) { + NNACL_CHECK_ZERO_RETURN(scale_param->base_.thread_nr_); + int outer_step = UP_DIV(scale_param->outer_size_, scale_param->base_.thread_nr_); + int outer_start = task_id * outer_step; + int outer_end = MSMIN(outer_start + outer_step, scale_param->outer_size_); + + if (scale_param->inner_size_ == 1) { + Fp16ScaleAxisRelu(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_); + } else { + Fp16ScaleInnerRelu(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_, + scale_param->inner_size_); + } +} + +void Fp16ScaleInnerRelu6(const float16_t *in_data, float16_t *out_data, const float16_t *scale, const float16_t *offset, + int outer_start, int outer_end, int axis_size, int inner_size) { +#ifdef ENABLE_NEON + float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; + float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6}; +#endif + for (int out = outer_start; out < outer_end; out++) { + int out_offset = out * axis_size * inner_size; + for (int i = 0; i < axis_size; i++) { + int axis_offset = out_offset + i * inner_size; + int in_index = 0; +#ifdef ENABLE_NEON + for (; in_index < inner_size - 8; in_index += 8) { + int in_offset = axis_offset + in_index; + float16x8_t data = vld1q_f16(in_data + in_offset); + float16x8_t scale_8 = vdupq_n_f16(scale[i]); + float16x8_t offset_8 = vdupq_n_f16(offset[i]); + float16x8_t tmp = vfmaq_f16(offset_8, data, scale_8); + float16x8_t result = vminq_f16(vmaxq_f16(tmp, zeros), bounds); + vst1q_f16(out_data + in_offset, result); + } +#endif + for (; in_index < inner_size; in_index++) { + int in_offset = axis_offset + in_index; + float tmp = in_data[in_offset] * scale[i] + offset[i]; + out_data[in_offset] = MSMIN(MSMAX(tmp, 0.0f), 6.0f); + } + } + } +} + +void Fp16ScaleAxisRelu6(const float16_t *in_data, float16_t *out_data, const float16_t *scale, const float16_t *offset, + int outer_start, int outer_end, int axis_size) { +#ifdef ENABLE_NEON + float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; + float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6}; +#endif + for (int out = outer_start; out < outer_end; out++) { + int out_offset = out * axis_size; + int index = 0; +#ifdef ENABLE_NEON + for (; index < axis_size - 8; index += 8) { + int in_offset = out_offset + index; + float16x8_t data = vld1q_f16(in_data + in_offset); + float16x8_t scale_8 = vld1q_f16(scale + index); + float16x8_t offset_8 = vld1q_f16(offset + index); + float16x8_t tmp = vfmaq_f16(offset_8, data, scale_8); + float16x8_t result = vminq_f16(vmaxq_f16(tmp, zeros), bounds); + vst1q_f16(out_data + in_offset, result); + } +#endif + for (; index < axis_size; index++) { + int in_offset = out_offset + index; + float tmp = in_data[in_offset] * scale[index] + offset[index]; + out_data[in_offset] = MSMIN(MSMAX(tmp, 0.0f), 6.0f); + } + } +} + +void DoScaleRelu6Fp16(const float16_t *in_data, float16_t *out_data, const float16_t *scale, const float16_t *offset, + int task_id, const ScaleStruct *scale_param) { + NNACL_CHECK_ZERO_RETURN(scale_param->base_.thread_nr_); + int outer_step = UP_DIV(scale_param->outer_size_, scale_param->base_.thread_nr_); + int outer_start = task_id * outer_step; + int outer_end = MSMIN(outer_start + outer_step, scale_param->outer_size_); + + if (scale_param->inner_size_ == 1) { + Fp16ScaleAxisRelu6(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_); + } else { + Fp16ScaleInnerRelu6(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_, + scale_param->inner_size_); + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/scale_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/scale_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..43fe5887d0bd60cf6ce40cd484e993f4ae29a560 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/scale_fp16.h @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_SCALE_FP16_H_ +#define NNACL_FP16_SCALE_FP16_H_ + +#include "nnacl/op_base.h" +#include "nnacl/intrinsics/ms_simd_instructions_fp16.h" +#include "nnacl/scale_parameter.h" +#include "nnacl/kernel/scale.h" + +#ifdef __cplusplus +extern "C" { +#endif +void DoScaleFp16(const float16_t *in_data, float16_t *out_data, const float16_t *scale, const float16_t *offset, + int task_id, const ScaleStruct *scale_param); +void Fp16DoScaleRelu(const float16_t *in_data, float16_t *out_data, const float16_t *scale, const float16_t *offset, + int task_id, const ScaleStruct *scale_param); +void DoScaleRelu6Fp16(const float16_t *in_data, float16_t *out_data, const float16_t *scale, const float16_t *offset, + int task_id, const ScaleStruct *scale_param); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_SCALE_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/softmax_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/softmax_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..ad614a0a4065de24df1abfedb37ff3dc4ab8acb4 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/softmax_fp16.c @@ -0,0 +1,134 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp16/softmax_fp16.h" +#include +#include "nnacl/fp16/exp_fp16.h" + +void SoftmaxNormFp16(const float16_t *src, float16_t *dst, int batch, int channel) { + int cur_batch_offset = 0; + for (int i = 0; i < batch; i++, cur_batch_offset += channel) { + int j = 0; +#ifdef ENABLE_NEON + float16x8_t max_8 = vdupq_n_f16(-FLT16_MAX); + int count = (channel / C8NUM) * C8NUM; + for (; j < count; j += C8NUM) { + float16x8_t input_8 = vld1q_f16(src + cur_batch_offset + j); + max_8 = vmaxq_f16(max_8, input_8); + } + float16_t max = MS_MAXVQ_F16(max_8); +#else + float16_t max = -FLT_MAX; +#endif + for (; j < channel; j++) { + float16_t input = src[cur_batch_offset + j]; + if (input > max) { + max = input; + } + } + int k = 0; +#ifdef ENABLE_NEON + int count2 = (channel / C8NUM) * C8NUM; + for (; k < count2; k += C8NUM) { + float16x8_t input_8 = vld1q_f16(src + cur_batch_offset + k); + float16x8_t output_8 = vsubq_f16(input_8, vdupq_n_f16(max)); + vst1q_f16(dst + cur_batch_offset + k, output_8); + } +#endif + for (; k < channel; k++) { + int offset = cur_batch_offset + k; + dst[offset] = src[offset] - max; + } + } +} + +void SumAndDivFp16(const float16_t *src, float16_t *dst, int batch, int channel) { + int cur_batch_offset = 0; + for (int i = 0; i < batch; i++, cur_batch_offset += channel) { + float16_t sum = 0.0f; + int j = 0; +#ifdef ENABLE_NEON + float16x8_t sum8 = vdupq_n_f16(0); + int count = (channel / C8NUM) * C8NUM; + for (; j < count; j += C8NUM) { + sum8 = vaddq_f16(sum8, vld1q_f16(src + cur_batch_offset + j)); + } + sum = sum8[0] + sum8[1] + sum8[2] + sum8[3] + sum8[4] + sum8[5] + sum8[6] + sum8[7]; +#endif + for (; j < channel; j++) { + sum += src[cur_batch_offset + j]; + } + int k = 0; +#ifdef ENABLE_NEON + const float16_t div = 1.0f / sum; + for (; k < count; k += C8NUM) { + vst1q_f16(dst + cur_batch_offset + k, vmulq_n_f16(vld1q_f16(src + cur_batch_offset + k), div)); + } +#endif + for (; k < channel; k++) { + dst[cur_batch_offset + k] = src[cur_batch_offset + k] / sum; + } + } +} + +void SoftmaxLastAxisFp16(const float16_t *src, float16_t *dst, int batch, int channel) { + SoftmaxNormFp16(src, dst, batch, channel); + ExpFp16(dst, dst, batch * channel); + SumAndDivFp16(dst, dst, batch, channel); +} + +// output = exp(input) / reduce_sum(exp(input), axis) +void SoftmaxFp16(const float16_t *input_ptr, float16_t *output_ptr, float16_t *sum_data, int axis, int n_dim, + const int *input_shape) { + int inner_size = 1; + int outter_size = 1; + + for (int i = 0; i < axis; i++) { + outter_size *= input_shape[i]; + } + for (int i = axis + 1; i < n_dim; i++) { + inner_size *= input_shape[i]; + } + for (int i = 0; i < outter_size; i++) { + int outter_offset = i * input_shape[axis] * inner_size; + int sum_outter_offset = i * inner_size; + for (int k = 0; k < inner_size; k++) { + int inner_offset = outter_offset + k; + float16_t max_data = input_ptr[inner_offset]; + sum_data[k + sum_outter_offset] = 0; + for (int j = 0; j < input_shape[axis]; j++) { + int axis_offset = inner_offset + j * inner_size; + max_data = max_data > input_ptr[axis_offset] ? max_data : input_ptr[axis_offset]; + } + for (int j = 0; j < input_shape[axis]; j++) { + int axis_offset = inner_offset + j * inner_size; + output_ptr[axis_offset] = expf(input_ptr[axis_offset] - max_data); + sum_data[k + sum_outter_offset] += output_ptr[axis_offset]; + } + } + } + for (int i = 0; i < outter_size; i++) { + int outter_offset = i * input_shape[axis] * inner_size; + int sum_outter_offset = i * inner_size; + for (int j = 0; j < input_shape[axis]; j++) { + int axis_offset = outter_offset + j * inner_size; + for (int k = 0; k < inner_size; k++) { + int inner_offset = axis_offset + k; + output_ptr[inner_offset] = output_ptr[inner_offset] / sum_data[k + sum_outter_offset]; + } + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/softmax_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/softmax_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..edd202b31f82ee542517b90dec65062fffba74d4 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/softmax_fp16.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_SOFTMAX_FP16_H_ +#define NNACL_FP16_SOFTMAX_FP16_H_ + +#include "nnacl/op_base.h" +#include "nnacl/intrinsics/ms_simd_instructions_fp16.h" +#include "nnacl/softmax_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void SoftmaxNormFp16(const float16_t *src, float16_t *dst, int batch, int channel); +void SoftmaxFp16(const float16_t *input_ptr, float16_t *output_ptr, float16_t *sum_data, int axis, int n_dim, + const int *input_shape); +void SoftmaxLastAxisFp16(const float16_t *src, float16_t *dst, int batch, int channel); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_SOFTMAX_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/sparse_to_dense_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/sparse_to_dense_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..2ecc8d938046d7e389b3f0a9a41d69c3af27f7d7 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/sparse_to_dense_fp16.c @@ -0,0 +1,78 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * +// * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp16/sparse_to_dense_fp16.h" +#include "nnacl/errorcode.h" + +int SparseToDenseSetDefaultFp16(float16_t *output, float16_t default_value, SparseToDenseParameter *param, + int task_id) { + if (output == NULL || param == NULL) { + return NNACL_NULL_PTR; + } + if (param->op_parameter_.thread_num_ == 0) { + return NNACL_ERR; + } + int unit_per_thread = UP_DIV(param->output_num, param->op_parameter_.thread_num_); + int begin = unit_per_thread * task_id; + int end = MSMIN(begin + unit_per_thread, param->output_num); + for (int i = begin; i < end; i++) { + output[i] = default_value; + } + return NNACL_OK; +} + +int SparseToDenseFp16(int *indices_vec, const float16_t *sparse_values, float16_t default_value, float16_t *output, + SparseToDenseParameter *param, int task_id) { + if (indices_vec == NULL || sparse_values == NULL || output == NULL || param == NULL) { + return NNACL_NULL_PTR; + } + if (param->op_parameter_.thread_num_) { + return NNACL_ERR; + } + int unit_per_thread = UP_DIV(param->index_num, param->op_parameter_.thread_num_); + int begin = unit_per_thread * task_id; + int end = MSMIN(begin + unit_per_thread, param->index_num); + + int stride0 = param->output_stride[0]; + int stride1 = param->output_stride[1]; + int stride2 = param->output_stride[2]; + + if (param->validate_indices_ == true) { + int index_before = -1; + for (int i = begin; i < end; i++) { + int *indices = indices_vec + i * DIMENSION_4D; + int index = stride0 * indices[0] + stride1 * indices[1] + stride2 * indices[2] + indices[3]; + if (index <= index_before) { + return NNACL_ERR; + } + index_before = index; + } + } + + if (param->is_scalar == true) { + for (int i = begin; i < end; i++) { + int *indices = indices_vec + i * DIMENSION_4D; + int index = stride0 * indices[0] + stride1 * indices[1] + stride2 * indices[2] + indices[3]; + output[index] = sparse_values[0]; + } + } else { + for (int i = begin; i < end; i++) { + int *indices = indices_vec + i * DIMENSION_4D; + int index = stride0 * indices[0] + stride1 * indices[1] + stride2 * indices[2] + indices[3]; + output[index] = sparse_values[i]; + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/sparse_to_dense_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/sparse_to_dense_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..7ef7a3dcadbd3a8f447f10ed1441e816905e6741 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/sparse_to_dense_fp16.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_SPARSE_TO_DENSE_FP16_H_ +#define NNACL_FP16_SPARSE_TO_DENSE_FP16_H_ + +#include "nnacl/sparse_to_dense_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int SparseToDenseSetDefaultFp16(float16_t *output, float16_t default_value, SparseToDenseParameter *param, int task_id); +int SparseToDenseFp16(int *indices_vec, const float16_t *sparse_values, float16_t default_value, float16_t *output, + SparseToDenseParameter *param, int task_id); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_SPARSE_TO_DENSE_FP16_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_plugin_impl.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/splice_fp16.c similarity index 47% rename from mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_plugin_impl.h rename to mindspore-lite/ops/kernel/cpu/nnacl/fp16/splice_fp16.c index 595fc31d1268e73ad00f5a585a58dd834d789f4d..e758adadcda33bcb294f7468a9bff38fcb4ab84c 100644 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_plugin_impl.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/splice_fp16.c @@ -1,5 +1,5 @@ /** - * Copyright 2019-2021 Huawei Technologies Co., Ltd + * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,21 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_TENSORRT_PLUGIN_IMPL_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_TENSORRT_PLUGIN_IMPL_H_ -#include "include/api/status.h" -#include "src/common/log_adapter.h" -#include "extendrt/delegate/plugin/tensorrt_executor_plugin.h" - -namespace mindspore::lite { -class TensorRTPluginImpl : public TensorRTExecutorPluginImplBase { - public: - TensorRTPluginImpl() = default; - ~TensorRTPluginImpl() = default; - int GetGPUGroupSize() const; - int GetRankID() const; -}; -} // namespace mindspore::lite - -extern "C" MS_API mindspore::lite::TensorRTExecutorPluginImplBase *CreateTensorRTPluginImpl(); -#endif // MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_ASCEND_KERNEL_API_H_ +#include "nnacl/fp16/splice_fp16.h" +void SpliceFp16(const float16_t *src_data, int src_row, int src_col, const SpliceParameter *param, float16_t *dst_data, + int dst_row, int dst_col) { + int forward_index = 0; + for (int r = 0; r < dst_row; ++r) { + float16_t *dst_row_data = dst_data + r * dst_col; + for (int off = 0; off < param->context_dim_; ++off) { + int r_off = param->forward_indexes_[forward_index]; + forward_index++; + const float16_t *tmp_src_data = src_data + r_off * src_col; + float16_t *tmp_dst_data = dst_row_data + off * src_col; + memcpy(tmp_dst_data, tmp_src_data, (size_t)(src_col) * sizeof(float16_t)); + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/splice_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/splice_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..9bc52c4e21444bc36c2076518f8662b33041f9af --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/splice_fp16.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_SPLICE_FP16_H_ +#define NNACL_FP16_SPLICE_FP16_H_ +#include +#include "nnacl/splice_parameter.h" +#ifdef __cplusplus +extern "C" { +#endif + +void SpliceFp16(const float16_t *src_data, int src_row, int src_col, const SpliceParameter *param, float16_t *dst_data, + int dst_row, int dst_col); + +#ifdef __cplusplus +} +#endif +#endif // NNACL_FP16_SPLICE_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/topk_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/topk_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..6e8aa91ce3453b4ba8b15554ccf7580f937e2cd7 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/topk_fp16.c @@ -0,0 +1,70 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp16/topk_fp16.h" + +int TopkFp16DescendCmp(const void *a, const void *b) { + float16_t sub = ((const TopkFp16Node *)b)->element - ((const TopkFp16Node *)a)->element; + if (sub > 0) { + return 1; + } else if (sub < 0) { + return -1; + } + if (((const TopkFp16Node *)a)->index > ((const TopkFp16Node *)b)->index) { + return 1; + } else { + return -1; + } +} + +int TopkFp16IndexSortCmp(const void *a, const void *b) { + if (((const TopkFp16Node *)a)->index > ((const TopkFp16Node *)b)->index) { + return 1; + } else { + return -1; + } +} + +void TopkFp16(void *input_data, void *output_data, int32_t *output_index, TopkParameter *parameter) { + int dim_size = parameter->dim_size_; + int outer_loop_num = parameter->outer_loop_num_; + int inner_loop_num = parameter->inner_loop_num_; + int k = parameter->k_; + TopkFp16Node *top_map = (TopkFp16Node *)parameter->topk_node_list_; + + float16_t *cur_input_data = (float16_t *)input_data; + float16_t *cur_output_data = (float16_t *)output_data; + int32_t *cur_output_index = output_index; + for (int i = 0; i < outer_loop_num; i++) { + int in_offset = i * dim_size * inner_loop_num; + int out_offset = i * k * inner_loop_num; + for (int j = 0; j < inner_loop_num; j++) { + for (int m = 0; m < dim_size; m++) { + int offset = in_offset + m * inner_loop_num + j; + top_map[m].element = *(cur_input_data + offset); + top_map[m].index = m; + } + qsort(top_map, dim_size, sizeof(top_map[0]), TopkFp16DescendCmp); + if (!parameter->sorted_) { + qsort(top_map, k, sizeof(top_map[0]), TopkFp16IndexSortCmp); + } + for (int m = 0; m < k; m++) { + int offset = out_offset + m * inner_loop_num + j; + cur_output_data[offset] = top_map[m].element; + cur_output_index[offset] = top_map[m].index; + } + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/topk_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/topk_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..dc103c17ea23582c8c462957842dd260ef2b839d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/topk_fp16.h @@ -0,0 +1,35 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_TOPK_FP16_H_ +#define NNACL_FP16_TOPK_FP16_H_ + +#include "nnacl/fp32/topk_fp32.h" +#include "nnacl/op_base.h" + +typedef struct TopkFp16Node { + float16_t element; + int32_t index; +} TopkFp16Node; + +#ifdef __cplusplus +extern "C" { +#endif +void TopkFp16(void *input_data, void *output_data, int32_t *output_index, TopkParameter *parameter); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_TOPK_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/transpose_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/transpose_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..075e0994e8b45bb71719a8c42ffa77b33dfd044e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/transpose_fp16.c @@ -0,0 +1,257 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp16/transpose_fp16.h" +#include +#include "nnacl/errorcode.h" + +void Fp16TransposeDim2(const float16_t *in_data, float16_t *out_data, const int *strides, const int *out_strides, + const int *perm, const int *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * output1; + int stride0_i = i * 1 * stride0; + for (int j = 0; j < output1; ++j) { + out_data[out_stride0_i + j] = in_data[stride0_i + j * stride1]; + } + } +} + +void Fp16TransposeDim3(const float16_t *in_data, float16_t *out_data, const int *strides, const int *out_strides, + const int *perm, const int *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; ++j) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + for (int k = 0; k < output2; ++k) { + out_data[out_stride0_i + out_stride1_j + k] = in_data[stride0_i + stride1_j + k * stride2]; + } + } + } +} + +void Fp16TransposeDim4(const float16_t *in_data, float16_t *out_data, const int *strides, const int *out_strides, + const int *perm, const int *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int stride3 = strides[perm[3]]; + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int out_stride2 = out_strides[2]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + const int output3 = output_shape[3]; + + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; ++j) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + for (int k = 0; k < output2; ++k) { + int out_stride2_k = k * out_stride2; + int stride2_k = k * stride2; + for (int m = 0; m < output3; ++m) { + out_data[out_stride0_i + out_stride1_j + out_stride2_k + m] = + in_data[stride0_i + stride1_j + stride2_k + m * stride3]; + } + } + } + } +} + +void Fp16TransposeDim5(const float16_t *in_data, float16_t *out_data, const int *strides, const int *out_strides, + const int *perm, const int *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int stride3 = strides[perm[3]]; + const int stride4 = strides[perm[4]]; + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int out_stride2 = out_strides[2]; + const int out_stride3 = out_strides[3]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + const int output3 = output_shape[3]; + const int output4 = output_shape[4]; + + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; ++j) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + for (int k = 0; k < output2; ++k) { + int out_stride2_k = k * out_stride2; + int stride2_k = k * stride2; + for (int m = 0; m < output3; ++m) { + int out_stride3_m = m * out_stride3; + int stride3_m = m * stride3; + for (int n = 0; n < output4; ++n) { + out_data[out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + n] = + in_data[stride0_i + stride1_j + stride2_k + stride3_m + n * stride4]; + } + } + } + } + } +} + +void Fp16TransposeDim6(const float16_t *in_data, float16_t *out_data, const int *strides, const int *out_strides, + const int *perm, const int *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int stride3 = strides[perm[3]]; + const int stride4 = strides[perm[4]]; + const int stride5 = strides[perm[5]]; + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int out_stride2 = out_strides[2]; + const int out_stride3 = out_strides[3]; + const int out_stride4 = out_strides[4]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + const int output3 = output_shape[3]; + const int output4 = output_shape[4]; + const int output5 = output_shape[5]; + + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; ++j) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + for (int k = 0; k < output2; ++k) { + int out_stride2_k = k * out_stride2; + int stride2_k = k * stride2; + for (int m = 0; m < output3; ++m) { + int out_stride3_m = m * out_stride3; + int stride3_m = m * stride3; + for (int n = 0; n < output4; ++n) { + int out_stride4_n = n * out_stride4; + int stride4_n = n * stride4; + for (int g = 0; g < output5; ++g) { + out_data[out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + out_stride4_n + g] = + in_data[stride0_i + stride1_j + stride2_k + stride3_m + stride4_n + g * stride5]; + } + } + } + } + } + } +} + +void TransposeDimsFp16(const void *in, void *out, const int *output_shape, int *perm, int *strides, int *out_strides, + int num_axes, int task_id, int thread_num) { + const float16_t *in_data = (const float16_t *)in; + float16_t *out_data = (float16_t *)out; + + NNACL_CHECK_NULL_RETURN_VOID(in_data); + NNACL_CHECK_NULL_RETURN_VOID(out_data); + NNACL_CHECK_NULL_RETURN_VOID(output_shape); + NNACL_CHECK_NULL_RETURN_VOID(perm); + NNACL_CHECK_NULL_RETURN_VOID(strides); + NNACL_CHECK_NULL_RETURN_VOID(out_strides); + NNACL_CHECK_ZERO_RETURN(thread_num); + + size_t data_size = (*out_strides) * output_shape[0]; + size_t offset_size = UP_DIV(data_size, thread_num); + size_t task_offset = offset_size * task_id; + int count = data_size - task_offset; + if (count <= 0) { + return; + } + count = MSMIN(offset_size, count); + for (size_t idx = task_offset; idx < task_offset + count; ++idx) { + int pos = idx; + int output_idx = 0; + int input_idx = 0; + for (int i = 0; i < num_axes; ++i) { + NNACL_CHECK_ZERO_RETURN(*(out_strides + i)); + int position = pos / *(out_strides + i); + int out_stride = i < num_axes - 1 ? out_strides[i] : 1; + output_idx += (position * out_stride); + input_idx += (position * strides[perm[i]]); + pos -= position * (*(out_strides + i)); + } + out_data[output_idx] = in_data[input_idx]; + } +} + +int DoTransposeFp16(const void *in, void *out, const int *output_shape, int *perm, int *strides, int *out_strides, + int data_size, int num_axes) { + const float16_t *in_data = (const float16_t *)in; + float16_t *out_data = (float16_t *)out; + + NNACL_CHECK_NULL_RETURN_ERR(in_data); + NNACL_CHECK_NULL_RETURN_ERR(out_data); + NNACL_CHECK_NULL_RETURN_ERR(output_shape); + NNACL_CHECK_NULL_RETURN_ERR(perm); + NNACL_CHECK_NULL_RETURN_ERR(strides); + NNACL_CHECK_NULL_RETURN_ERR(out_strides); + + // check if transpose is needed + bool needTranspose = false; + for (int i = 1; i < num_axes; ++i) { + if (perm[i] - perm[i - 1] != 1) { + needTranspose = true; + break; + } + } + + if (!needTranspose) { + (void)memcpy(out_data, in_data, data_size); + return NNACL_OK; + } + for (int i = 0; i < num_axes; ++i) { + if (perm[i] < 0) { + return NNACL_PARAM_INVALID; + } + } + if (num_axes == 2) { + Fp16TransposeDim2(in_data, out_data, strides, out_strides, perm, output_shape); + } else if (num_axes == 3) { + Fp16TransposeDim3(in_data, out_data, strides, out_strides, perm, output_shape); + } else if (num_axes == 4) { + Fp16TransposeDim4(in_data, out_data, strides, out_strides, perm, output_shape); + } else if (num_axes == 5) { + Fp16TransposeDim5(in_data, out_data, strides, out_strides, perm, output_shape); + } else if (num_axes == 6) { + Fp16TransposeDim6(in_data, out_data, strides, out_strides, perm, output_shape); + } else { + return NNACL_ERR; + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/transpose_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/transpose_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..df83508a6afa7864e6106f04bb98c7b0da185ac3 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/transpose_fp16.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_TRANSPOSE_FP16_H_ +#define NNACL_FP16_TRANSPOSE_FP16_H_ + +#include "nnacl/op_base.h" +#include "nnacl/intrinsics/ms_simd_instructions_fp16.h" +#include "nnacl/transpose_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void TransposeDimsFp16(const void *src, void *dst, const int *output_shape, int *perm, int *strides, int *out_strides, + int num_axes, int task_id, int thread_num); +int DoTransposeFp16(const void *src, void *dst, const int *output_shape, int *perm, int *strides, int *out_strides, + int data_size, int num_axes); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_TRANSPOSE_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/unique_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/unique_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..49a93ac9a2a89c080707b13463bdfef89153b57f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/unique_fp16.c @@ -0,0 +1,38 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp16/unique_fp16.h" + +int FindFp16(const float16_t *array, int len, float16_t target) { + for (int i = 0; i < len; ++i) { + if (array[i] == target) { + return i; + } + } + return -1; +} + +void UniqueFp16(const float16_t *input, int input_len, float16_t *output0, int *output0_len, int *output1) { + *output0_len = 0; + for (int i = 0; i < input_len; i++) { + int idx = FindFp16(output0, *output0_len, input[i]); + if (idx != -1) { + *output1++ = idx; + } else { + output0[(*output0_len)++] = input[i]; + *output1++ = *output0_len - 1; + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/unique_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/unique_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..dcbd7297a15f54cd10cb5e37594118f3569c96fa --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/unique_fp16.h @@ -0,0 +1,29 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_UNIQUE_FP16_H +#define NNACL_FP16_UNIQUE_FP16_H + +#include "nnacl/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +void UniqueFp16(const float16_t *input, int input_len, float16_t *output0, int *output0_len, int *output1); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_UNIQUE_FP16_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/utils_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/utils_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..2038ee623d8d70c394bcbfbd68d74f86462dba37 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/utils_fp16.c @@ -0,0 +1,37 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp16/utils_fp16.h" +#include "nnacl/fp16/common_func_fp16.h" +#include "nnacl/fp16/cast_fp16.h" +#include "nnacl/tensor_c_utils.h" + +void *GetOrAllocFp16Data(TensorC *t, ExecEnv *env, bool cast) { + NNACL_CHECK_NULL_RETURN_NULL(t); + if (t->data_type_ == kNumberTypeFloat16) { + return t->data_; + } + if (t->data_type_ == kNumberTypeFloat32) { + int ele_num = NNACLGetElementNum(t); + void *fp16_data = env->Alloc(env->allocator_, ele_num * sizeof(float16_t)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(fp16_data); + if (cast) { + Float32ToFloat16((float *)t->data_, (float16_t *)fp16_data, ele_num); + } + return fp16_data; + } + return NULL; +} diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_plugin_impl.cc b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/utils_fp16.h similarity index 70% rename from mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_plugin_impl.cc rename to mindspore-lite/ops/kernel/cpu/nnacl/fp16/utils_fp16.h index 21737aa9390851902c49f30a3f156f596aff72ba..6c5849ab4a97954446374434a72c6dbace4d24ad 100644 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_plugin_impl.cc +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/utils_fp16.h @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#ifndef NNACL_FP16_UTILS_FP16_H_ +#define NNACL_FP16_UTILS_FP16_H_ -#include -#include "extendrt/delegate/ascend_native/ascend_native_plugin_impl.h" +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" -namespace mindspore { +void *GetOrAllocFp16Data(TensorC *t, ExecEnv *env, bool cast); -AscendNativeExecutorPluginImpl *CreateAscendNativeExecutorPluginImpl() { return new AscendNativeExecutorPluginImpl(); } -} // namespace mindspore +#endif // NNACL_FP16_UTILS_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/where_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/where_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..683bb36e48f77054ba6d2514f44801c99c365146 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/where_fp16.c @@ -0,0 +1,34 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp16/where_fp16.h" +#include "nnacl/common_func.h" + +void WhereWithTripleInputsFp16(const float16_t *x, const float16_t *y, float16_t *output, const WhereArgs *param, + int task_id, int thread_num) { + const bool *condition = param->condition_; + int stride = UP_DIV(param->max_num_, thread_num); + int begin = task_id * stride; + int end = MSMIN(begin + stride, param->max_num_); + + for (int i = begin; i < end; ++i) { + bool cond = condition[param->condition_num_ > 1 ? i : 0]; + if (cond) { + output[i] = x[param->x_num_ > 1 ? i : 0]; + } else { + output[i] = y[param->y_num_ > 1 ? i : 0]; + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/where_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/where_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..afad274e6d286af4806e7eab405c93ce218194f8 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/where_fp16.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_WHERE_FP16_H_ +#define NNACL_FP16_WHERE_FP16_H_ + +#include "nnacl/op_base.h" +#include "nnacl/where_parameter.h" +#include "nnacl/kernel/where.h" + +#ifdef __cplusplus +extern "C" { +#endif +void WhereWithTripleInputsFp16(const float16_t *x, const float16_t *y, float16_t *output, const WhereArgs *param, + int task_id, int thread_num); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_WHERE_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/winograd_transform_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/winograd_transform_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..18577d401f62c73c5168f807b0a5518602323307 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/winograd_transform_fp16.c @@ -0,0 +1,360 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp16/winograd_transform_fp16.h" + +void PrepareTransInputFp16(const float16_t *src_data, float16_t *dst_data, int interval_x_s, int interval_x_e, + int interval_y_s, int interval_y_e, int real_c, const ConvParameter *conv_param) { + int input_unit = conv_param->input_unit_; + int in_channel = conv_param->input_channel_; + int input_w = conv_param->input_w_; + + // clear tmp buffer + if (interval_x_e - interval_x_s != input_unit || interval_y_e - interval_y_s != input_unit) { + memset(dst_data, 0, input_unit * input_unit * C8NUM * sizeof(float16_t)); + } + + // get real input block with padding + if (real_c == C8NUM) { + for (int interval = interval_y_s; interval < interval_y_e; interval++) { + int src_y_offset = (interval * input_w + interval_x_s) * in_channel; + int dst_y_offset = interval * input_unit * C8NUM + interval_x_s * C8NUM; + for (int j = 0; j < (interval_x_e - interval_x_s); j++) { + int src_x_offset = src_y_offset + j * in_channel; + int dst_x_offset = dst_y_offset + j * C8NUM; + const float16_t *src_addr = src_data + src_x_offset; + float16_t *dst_addr = dst_data + dst_x_offset; +#ifdef ENABLE_NEON + vst1q_f16(dst_addr, vld1q_f16(src_addr)); +#else + for (int k = 0; k < C8NUM; k++) { + dst_addr[k] = src_addr[k]; + } +#endif + } + } + } else if (real_c < 8 && real_c >= 4) { + for (int interval = interval_y_s; interval < interval_y_e; interval++) { + int src_y_offset = (interval * input_w + interval_x_s) * in_channel; + int dst_y_offset = interval * input_unit * C8NUM + interval_x_s * C8NUM; + for (int j = 0; j < (interval_x_e - interval_x_s); j++) { + int src_x_offset = src_y_offset + j * in_channel; + int dst_x_offset = dst_y_offset + j * C8NUM; + const float16_t *src_addr = src_data + src_x_offset; + float16_t *dst_addr = dst_data + dst_x_offset; + int rc = real_c - 4; +#ifdef ENABLE_NEON + vst1_f16(dst_addr, vld1_f16(src_addr)); +#else + for (int k = 0; k < C4NUM; k++) { + dst_addr[k] = src_addr[k]; + } +#endif + src_addr += 4; + dst_addr += 4; + for (int i = 0; i < rc; ++i) { + dst_addr[i] = src_addr[i]; + } + } + } + } else { + for (int interval = interval_y_s; interval < interval_y_e; interval++) { + int src_y_offset = (interval * input_w + interval_x_s) * in_channel; + int dst_y_offset = interval * input_unit * C8NUM + interval_x_s * C8NUM; + for (int j = 0; j < (interval_x_e - interval_x_s); j++) { + int src_x_offset = src_y_offset + j * in_channel; + int dst_x_offset = dst_y_offset + j * C8NUM; + const float16_t *src_addr = src_data + src_x_offset; + float16_t *dst_addr = dst_data + dst_x_offset; + for (int k = 0; k < real_c; k++) { + dst_addr[k] = src_addr[k]; + } + } + } + } +} + +// fp16 common winograd +void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, int cal_num, + int out_tile_index, int out_w_block_num, const ConvParameter *conv_param, + InputTransFp16Func func) { +#ifdef ENABLE_ARM64 + const int tile_num = 16; +#else + const int tile_num = 12; +#endif + int input_unit = conv_param->input_unit_; + int output_unit = conv_param->output_unit_; + int in_channel = conv_param->input_channel_; + int ic8 = UP_DIV(in_channel, C8NUM); + int pad_h = conv_param->pad_u_; + int pad_w = conv_param->pad_l_; + int input_h = conv_param->input_h_; + int input_w = conv_param->input_w_; + if (out_w_block_num == 0) { + return; + } + for (int c = 0; c < cal_num; c++) { // actual tiled number + int src_x_s = (out_tile_index % out_w_block_num) * output_unit - pad_w; + int src_y_s = (out_tile_index / out_w_block_num) * output_unit - pad_h; + int interval_x_s = src_x_s > 0 ? 0 : -src_x_s; + int interval_y_s = src_y_s > 0 ? 0 : -src_y_s; + int src_x_e = src_x_s + input_unit; + int src_y_e = src_y_s + input_unit; + int interval_x_e = src_x_e < input_w ? input_unit : (input_w - src_x_s); + int interval_y_e = src_y_e < input_h ? input_unit : (input_h - src_y_s); + + int src_plane_offset = in_channel * (src_y_s * input_w + src_x_s); + int dst_plane_offset = c * in_channel; + for (int ic = 0; ic < ic8; ic++) { + int real_c = in_channel - ic * C8NUM; + real_c = real_c > C8NUM ? C8NUM : real_c; + const float16_t *src_data = input_data + src_plane_offset + ic * C8NUM; + PrepareTransInputFp16(src_data, tmp_data, interval_x_s, interval_x_e, interval_y_s, interval_y_e, real_c, + conv_param); + + // input transform + int dst_ic8_offset = dst_plane_offset + ic * C8NUM; + size_t dst_step = in_channel * tile_num; + float16_t *trans_input_ptr = trans_input + dst_ic8_offset; + func(tmp_data, trans_input_ptr, C8NUM, dst_step, real_c); + } + out_tile_index++; + } // cal_tile_num loop +} + +// Only support arm64 +void WinogradInputTransformOptStepFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, + int cal_num, int out_tile_index, int out_w_block_num, + const ConvParameter *conv_param, InputTransStepFp16Func func) { + const int tile_num = 16; + int input_unit = conv_param->input_unit_; + int output_unit = conv_param->output_unit_; + int in_channel = conv_param->input_channel_; + int ic8 = UP_DIV(in_channel, C8NUM); + int pad_h = conv_param->pad_u_; + int pad_w = conv_param->pad_l_; + int input_h = conv_param->input_h_; + int input_w = conv_param->input_w_; + if (out_w_block_num == 0) { + return; + } + for (int c = 0; c < cal_num; c++) { // actual tiled number + int src_x_s = (out_tile_index % out_w_block_num) * output_unit - pad_w; + int src_y_s = (out_tile_index / out_w_block_num) * output_unit - pad_h; + int interval_x_s = src_x_s > 0 ? 0 : -src_x_s; + int interval_y_s = src_y_s > 0 ? 0 : -src_y_s; + int src_x_e = src_x_s + input_unit; + int src_y_e = src_y_s + input_unit; + int interval_x_e = src_x_e < input_w ? input_unit : (input_w - src_x_s); + int interval_y_e = src_y_e < input_h ? input_unit : (input_h - src_y_s); + + int src_plane_offset = in_channel * (src_y_s * input_w + src_x_s); + int dst_plane_offset = c * C8NUM; + for (int ic = 0; ic < ic8; ic++) { + int real_c = in_channel - ic * C8NUM; + real_c = real_c > C8NUM ? C8NUM : real_c; + const float16_t *src_data = input_data + src_plane_offset + ic * C8NUM; + PrepareTransInputFp16(src_data, tmp_data, interval_x_s, interval_x_e, interval_y_s, interval_y_e, real_c, + conv_param); + + // input transform + int dst_ic8_offset = dst_plane_offset + ic * tile_num * input_unit * input_unit * C8NUM; + size_t dst_step = input_unit * tile_num * C8NUM; + float16_t *trans_input_ptr = trans_input + dst_ic8_offset; + func(tmp_data, trans_input_ptr, C8NUM, dst_step, tile_num * C8NUM); + } + out_tile_index++; + } // cal_tile_num loop +} + +void WinogradOutputNHWCTransformFp16(const float16_t *gemm_out, float16_t *tmp_out_data, const float16_t *bias_data, + int cal_num, int out_tile_index, int output_unit_num, + const ConvParameter *conv_param, OutputTransFp16Func func) { + int output_unit = conv_param->output_unit_; + int output_w = conv_param->output_w_; + int output_h = conv_param->output_h_; + int output_channel = conv_param->output_channel_; + int oc8 = UP_DIV(output_channel, C8NUM); + int input_unit = conv_param->input_unit_; + NNACL_CHECK_ZERO_RETURN(output_unit_num); + for (int i = 0; i < cal_num; i++) { + int dst_x_s = out_tile_index % output_unit_num; + int dst_y_s = out_tile_index / output_unit_num; + int r_w = output_w - dst_x_s * output_unit; + r_w = r_w > output_unit ? output_unit : r_w; + int r_h = output_h - dst_y_s * output_unit; + r_h = r_h > output_unit ? output_unit : r_h; + int tmp_ix = dst_x_s * output_unit; + dst_x_s = tmp_ix > output_w ? output_w : tmp_ix; + int tmp_iy = dst_y_s * output_unit; + dst_y_s = tmp_iy > output_h ? output_h : tmp_iy; + + int src_tile_offset = i * oc8 * C8NUM * input_unit * input_unit; + int dst_tile_offset = output_channel * (dst_x_s + dst_y_s * output_w); + + for (int j = 0; j < oc8; j++) { + int r_c = output_channel - j * C8NUM; + r_c = r_c > C8NUM ? C8NUM : r_c; + int src_oc8_offset = src_tile_offset + j * input_unit * input_unit * C8NUM; + int dst_oc8_offset = dst_tile_offset + j * C8NUM; + const float16_t *src_ptr = gemm_out + src_oc8_offset; + const float16_t *bias_ptr = bias_data + j * C8NUM; + float16_t *dst_ptr = tmp_out_data + dst_oc8_offset; + func(src_ptr, dst_ptr, bias_ptr, C8NUM, output_w, output_channel, r_w, r_h, r_c); + } + out_tile_index++; + } +} + +void WinogradOutputNC8HW8TransformFp16(const float16_t *gemm_out, float16_t *tmp_out_data, const float16_t *bias_data, + int cal_num, int out_tile_index, int output_unit_num, + const ConvParameter *conv_param, OutputTransFp16Func func) { + int output_unit = conv_param->output_unit_; + int output_w = conv_param->output_w_; + int output_h = conv_param->output_h_; + int plane = output_w * output_h; + int output_channel = conv_param->output_channel_; + int oc8 = UP_DIV(output_channel, C8NUM); + int input_unit = conv_param->input_unit_; + NNACL_CHECK_ZERO_RETURN(output_unit_num); + for (int i = 0; i < cal_num; i++) { + int dst_x_s = out_tile_index % output_unit_num; + int dst_y_s = out_tile_index / output_unit_num; + int r_w = output_w - dst_x_s * output_unit; + r_w = r_w > output_unit ? output_unit : r_w; + int r_h = output_h - dst_y_s * output_unit; + r_h = r_h > output_unit ? output_unit : r_h; + int tmp_ix = dst_x_s * output_unit; + dst_x_s = tmp_ix > output_w ? output_w : tmp_ix; + int tmp_iy = dst_y_s * output_unit; + dst_y_s = tmp_iy > output_h ? output_h : tmp_iy; + + int src_tile_offset = i * oc8 * C8NUM * input_unit * input_unit; + int dst_tile_offset = dst_x_s + dst_y_s * output_w; + + for (int j = 0; j < oc8; j++) { + int r_c = output_channel - j * C8NUM; + r_c = r_c > C8NUM ? C8NUM : r_c; + int src_oc8_offset = src_tile_offset + j * input_unit * input_unit * C8NUM; + int dst_oc8_offset = (dst_tile_offset + plane * j) * C8NUM; + const float16_t *src_ptr = gemm_out + src_oc8_offset; + const float16_t *bias_ptr = bias_data + j * C8NUM; + float16_t *dst_ptr = tmp_out_data + dst_oc8_offset; + func(src_ptr, dst_ptr, bias_ptr, C8NUM, output_w, r_c, r_w, r_h, r_c); + } + out_tile_index++; + } +} + +int WinogradWeightTransformFp16(const float16_t *weight_data, float16_t *winograd_data, const float *matrix_g, + const float *matrix_gt, int oc_block, int input_unit, int kernel_unit, + int filter_channel, int filter_batch, bool pack) { + // original weight format : ohwi + int oc_block_num = UP_DIV(filter_batch, oc_block); + int block_stride = filter_channel * oc_block; + int block_num_stride = block_stride * oc_block_num; + + float16_t *matrix_gt_data_fp16 = (float16_t *)(malloc(input_unit * kernel_unit * sizeof(float16_t))); + if (matrix_gt_data_fp16 == NULL) { + return NNACL_ERRCODE_OP_FP16_WINOGRAD_GENERATOR; + } + Float32ToFloat16(matrix_gt, matrix_gt_data_fp16, input_unit * kernel_unit); + + // trans_filter = G*g*GT (g represents weight_data) = [(g * (G)T)T * (G)T]T + // separate into two steps ===> tmp = (g * (G)T)T ===> out = [tmp * (G)T]T + float16_t *tmp_data = (float16_t *)(malloc(filter_channel * input_unit * kernel_unit * sizeof(float16_t))); + if (tmp_data == NULL) { + free(matrix_gt_data_fp16); + return NNACL_ERRCODE_OP_FP16_WINOGRAD_GENERATOR; + } + float16_t *trans_out_data = (float16_t *)(malloc(filter_channel * input_unit * input_unit * sizeof(float16_t))); + if (trans_out_data == NULL) { + free(tmp_data); + free(matrix_gt_data_fp16); + return NNACL_ERRCODE_OP_FP16_WINOGRAD_GENERATOR; + } + +#ifndef ENABLE_ARM64 + float16_t *tmp_data1 = (float16_t *)(malloc(filter_channel * input_unit * kernel_unit * sizeof(float16_t))); + if (tmp_data1 == NULL) { + free(tmp_data); + free(matrix_gt_data_fp16); + free(trans_out_data); + return NNACL_ERRCODE_OP_FP16_WINOGRAD_GENERATOR; + } + float16_t *trans_out_data1 = (float16_t *)(malloc(filter_channel * input_unit * input_unit * sizeof(float16_t))); + if (trans_out_data1 == NULL) { + free(tmp_data); + free(tmp_data1); + free(matrix_gt_data_fp16); + free(trans_out_data); + return NNACL_ERRCODE_OP_FP16_WINOGRAD_GENERATOR; + } +#endif + + int input_oz_offset = kernel_unit * kernel_unit * filter_channel; + for (int i = 0; i < filter_batch; i++) { + int out_c_block = i / oc_block; + int out_c_res = i % oc_block; + int output_oz_offset = out_c_block * block_stride + out_c_res; + +#ifndef ENABLE_ARM64 + // tmp_data = g * GT + MatrixMultiplyWinogradFp16(weight_data + i * input_oz_offset, matrix_gt_data_fp16, tmp_data, kernel_unit, + kernel_unit, input_unit, filter_channel); + // tmp_data1 = (tmp_data)T + PackHWCToWHCFp16(tmp_data, tmp_data1, kernel_unit, input_unit, filter_channel); + // trans_out_data1 = tmp * GT + MatrixMultiplyWinogradFp16(tmp_data1, matrix_gt_data_fp16, trans_out_data1, input_unit, kernel_unit, input_unit, + filter_channel); + // trans_out_data = (trans_out_data1)T + PackHWCToWHCFp16(trans_out_data1, trans_out_data, input_unit, input_unit, filter_channel); +#else + // tmp = (g * GT)T + MatrixMultiplyWinogradFp16(weight_data + i * input_oz_offset, matrix_gt_data_fp16, tmp_data, kernel_unit, + kernel_unit, input_unit, filter_channel); + // trans = (tmp * GT)T + MatrixMultiplyWinogradFp16(tmp_data, matrix_gt_data_fp16, trans_out_data, input_unit, kernel_unit, input_unit, + filter_channel); +#endif + + if (pack) { + int in_offset = 0; + for (int j = 0; j < input_unit; ++j) { + for (int k = 0; k < input_unit; ++k) { + for (int c = 0; c < filter_channel; ++c) { + *(winograd_data + output_oz_offset + c * oc_block) = trans_out_data[in_offset + c]; + } + in_offset += filter_channel; + output_oz_offset += block_num_stride; + } + } + } else { + memcpy(winograd_data + i * filter_channel * input_unit * input_unit, trans_out_data, + filter_channel * input_unit * input_unit * sizeof(float16_t)); + } + } + +#ifndef ENABLE_ARM64 + free(tmp_data1); + free(trans_out_data1); +#endif + free(tmp_data); + free(trans_out_data); + free(matrix_gt_data_fp16); + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/winograd_transform_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/winograd_transform_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..b6e98e42f83c35afc49f8446c71da21f7ea491fa --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/winograd_transform_fp16.h @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_WINOGRAD_TRANSFORM_FP16_H_ +#define NNACL_FP16_WINOGRAD_TRANSFORM_FP16_H_ + +#include +#include +#include "nnacl/errorcode.h" +#include "nnacl/fp16/cast_fp16.h" +#include "nnacl/fp16/conv_fp16.h" +#include "nnacl/fp16/matrix_fp16.h" +#include "nnacl/fp16/pack_fp16.h" + +#ifdef __cplusplus +extern "C" { +#endif +// fp16 common winograd +void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, int cal_num, + int out_tile_index, int out_w_block_num, const ConvParameter *conv_param, + InputTransFp16Func func); + +void WinogradInputTransformOptStepFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, + int cal_num, int out_tile_index, int out_w_block_num, + const ConvParameter *conv_param, InputTransStepFp16Func func); + +void WinogradOutputNHWCTransformFp16(const float16_t *gemm_out, float16_t *tmp_out_data, const float16_t *bias_data, + int cal_num, int out_tile_index, int output_unit_num, + const ConvParameter *conv_param, OutputTransFp16Func func); + +void WinogradOutputNC8HW8TransformFp16(const float16_t *gemm_out, float16_t *tmp_out_data, const float16_t *bias_data, + int cal_num, int out_tile_index, int output_unit_num, + const ConvParameter *conv_param, OutputTransFp16Func func); + +// fp16 winograd weight trans +int WinogradWeightTransformFp16(const float16_t *weight_data, float16_t *winograd_data, const float *matrix_g, + const float *matrix_gt, int oc_block, int input_unit, int kernel_unit, + int filter_channel, int filter_batch, bool pack); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_WINOGRAD_TRANSFORM_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/winograd_utils_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/winograd_utils_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..7757d2fb42cf868664d0a3048d0f285ee44030ca --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/winograd_utils_fp16.c @@ -0,0 +1,3278 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp16/winograd_utils_fp16.h" +#include "nnacl/fp16/matrix_fp16.h" + +#define MIN_UNIT_FP16 2 +#define MAX_UNIT_FP16 4 + +#ifdef ENABLE_ARM64 +void transpose8(float16x8_t *s0, float16x8_t *s1, float16x8_t *s2, float16x8_t *s3, float16x8_t *s4, float16x8_t *s5, + float16x8_t *s6, float16x8_t *s7) { + float32x4_t m0 = (float32x4_t)(vtrn1q_f16(*s0, *s1)); + float32x4_t m1 = (float32x4_t)(vtrn2q_f16(*s0, *s1)); + float32x4_t m2 = (float32x4_t)(vtrn1q_f16(*s2, *s3)); + float32x4_t m3 = (float32x4_t)(vtrn2q_f16(*s2, *s3)); + float32x4_t m4 = (float32x4_t)(vtrn1q_f16(*s4, *s5)); + float32x4_t m5 = (float32x4_t)(vtrn2q_f16(*s4, *s5)); + float32x4_t m6 = (float32x4_t)(vtrn1q_f16(*s6, *s7)); + float32x4_t m7 = (float32x4_t)(vtrn2q_f16(*s6, *s7)); + + float64x2_t t0 = (float64x2_t)(vtrn1q_f32(m0, m2)); + float64x2_t t2 = (float64x2_t)(vtrn2q_f32(m0, m2)); + float64x2_t t1 = (float64x2_t)(vtrn1q_f32(m1, m3)); + float64x2_t t3 = (float64x2_t)(vtrn2q_f32(m1, m3)); + float64x2_t t4 = (float64x2_t)(vtrn1q_f32(m4, m6)); + float64x2_t t6 = (float64x2_t)(vtrn2q_f32(m4, m6)); + float64x2_t t5 = (float64x2_t)(vtrn1q_f32(m5, m7)); + float64x2_t t7 = (float64x2_t)(vtrn2q_f32(m5, m7)); + + *s0 = (float16x8_t)(vtrn1q_f64(t0, t4)); + *s4 = (float16x8_t)(vtrn2q_f64(t0, t4)); + *s1 = (float16x8_t)(vtrn1q_f64(t1, t5)); + *s5 = (float16x8_t)(vtrn2q_f64(t1, t5)); + *s2 = (float16x8_t)(vtrn1q_f64(t2, t6)); + *s6 = (float16x8_t)(vtrn2q_f64(t2, t6)); + *s3 = (float16x8_t)(vtrn1q_f64(t3, t7)); + *s7 = (float16x8_t)(vtrn2q_f64(t3, t7)); +} +#endif + +static InputTransFp16Func InputTransFp16FuncList[] = { + NULL, NULL, NULL, NULL, InputTransform4x4UnitFp16, NULL, InputTransform6x6UnitFp16, NULL, InputTransform8x8UnitFp16}; + +static OutputTransFp16Func OutputTransFp16FuncList4[] = {NULL, NULL, OutputTransform4x2UnitFp16, + OutputTransform4x3UnitFp16}; + +static OutputTransFp16Func OutputTransFp16FuncReluList4[] = {NULL, NULL, OutputTransform4x2ReluUnitFp16, + OutputTransform4x3ReluUnitFp16}; +static OutputTransFp16Func OutputTransFp16FuncRelu6List4[] = {NULL, NULL, OutputTransform4x2Relu6UnitFp16, + OutputTransform4x3Relu6UnitFp16}; + +static OutputTransFp16Func OutputTransFp16FuncList6[] = {NULL, + NULL, + OutputTransform6x2UnitFp16, + OutputTransform6x3UnitFp16, + OutputTransform6x4UnitFp16, + OutputTransform6x5UnitFp16}; + +static OutputTransFp16Func OutputTransFp16FuncReluList6[] = {NULL, + NULL, + OutputTransform6x2ReluUnitFp16, + OutputTransform6x3ReluUnitFp16, + OutputTransform6x4ReluUnitFp16, + OutputTransform6x5ReluUnitFp16}; + +static OutputTransFp16Func OutputTransFp16FuncRelu6List6[] = {NULL, + NULL, + OutputTransform6x2Relu6UnitFp16, + OutputTransform6x3Relu6UnitFp16, + OutputTransform6x4Relu6UnitFp16, + OutputTransform6x5Relu6UnitFp16}; + +static OutputTransFp16Func OutputTransFp16FuncList8[] = {NULL, + NULL, + OutputTransform8x2UnitFp16, + OutputTransform8x3UnitFp16, + OutputTransform8x4UnitFp16, + OutputTransform8x5UnitFp16, + OutputTransform8x6UnitFp16, + OutputTransform8x7UnitFp16}; + +static OutputTransFp16Func OutputTransFp16FuncReluList8[] = {NULL, + NULL, + OutputTransform8x2ReluUnitFp16, + OutputTransform8x3ReluUnitFp16, + OutputTransform8x4ReluUnitFp16, + OutputTransform8x5ReluUnitFp16, + OutputTransform8x6ReluUnitFp16, + OutputTransform8x7ReluUnitFp16}; + +static OutputTransFp16Func OutputTransFp16FuncRelu6List8[] = {NULL, + NULL, + OutputTransform8x2Relu6UnitFp16, + OutputTransform8x3Relu6UnitFp16, + OutputTransform8x4Relu6UnitFp16, + OutputTransform8x5Relu6UnitFp16, + OutputTransform8x6Relu6UnitFp16, + OutputTransform8x7Relu6UnitFp16}; + +InputTransFp16Func GetInputTransFp16Func(int input_unit) { return InputTransFp16FuncList[input_unit]; } + +#ifdef ENABLE_ARM64 +static InputTransStepFp16Func InputTransStepFp16FuncList[] = { + NULL, NULL, NULL, NULL, InputTransform4x4StepFp16, NULL, InputTransform6x6StepFp16, NULL, InputTransform8x8StepFp16}; + +static InputTransPackFp16Func InputTransPackFp16FuncList[] = {NULL, + NULL, + NULL, + NULL, + InputTransform4x4Pack16Fp16, + NULL, + InputTransform6x6Pack16Fp16, + NULL, + InputTransform8x8Pack16Fp16}; + +InputTransStepFp16Func GetInputTransStepFp16Func(int input_unit) { return InputTransStepFp16FuncList[input_unit]; } + +InputTransPackFp16Func GetInputTransPackFp16Func(int input_unit) { return InputTransPackFp16FuncList[input_unit]; } +#endif + +void InputTransform4x4UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c) { + int j = 0; + if (real_c == 8) { + float16x8_t src[16]; + float16x8_t t[16]; + float16x8_t m[16]; + Load16DataFp16; + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = vsubq_f16(src[offset], src[2 + offset]); + t[4 + l] = vaddq_f16(src[1 + offset], src[2 + offset]); + t[8 + l] = vsubq_f16(src[2 + offset], src[1 + offset]); + t[12 + l] = vsubq_f16(src[3 + offset], src[1 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + m[l] = vsubq_f16(t[offset], t[2 + offset]); + m[4 + l] = vaddq_f16(t[1 + offset], t[2 + offset]); + m[8 + l] = vsubq_f16(t[2 + offset], t[1 + offset]); + m[12 + l] = vsubq_f16(t[3 + offset], t[1 + offset]); + } + for (int i = 0; i < 16; i++) { + int dst_offset = i * dst_step; + vst1q_f16(dst_data + dst_offset, m[i]); + } + real_c -= 8; + } else if (real_c < 8 && real_c >= 4) { + float16x4_t src[16]; + float16x4_t t[16]; + float16x4_t m[16]; + Load16DataC4Fp16; + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = vsub_f16(src[offset], src[2 + offset]); + t[4 + l] = vadd_f16(src[1 + offset], src[2 + offset]); + t[8 + l] = vsub_f16(src[2 + offset], src[1 + offset]); + t[12 + l] = vsub_f16(src[3 + offset], src[1 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + m[l] = vsub_f16(t[offset], t[2 + offset]); + m[4 + l] = vadd_f16(t[1 + offset], t[2 + offset]); + m[8 + l] = vsub_f16(t[2 + offset], t[1 + offset]); + m[12 + l] = vsub_f16(t[3 + offset], t[1 + offset]); + } + for (int i = 0; i < 16; i++) { + int dst_offset = i * dst_step; + vst1_f16(dst_data + dst_offset, m[i]); + } + j = 4; + } + for (; j < real_c; ++j) { + float16_t src[16]; + float16_t t[16]; + float16_t m[16]; + for (int k = 0; k < 16; ++k) { + src[k] = src_data[j + k * src_step]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[offset] - src[2 + offset]; + t[4 + l] = src[1 + offset] + src[2 + offset]; + t[8 + l] = src[2 + offset] - src[1 + offset]; + t[12 + l] = src[3 + offset] - src[1 + offset]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + m[l] = t[offset] - t[2 + offset]; + m[4 + l] = t[1 + offset] + t[2 + offset]; + m[8 + l] = t[2 + offset] - t[1 + offset]; + m[12 + l] = t[3 + offset] - t[1 + offset]; + } + for (int i = 0; i < 16; i++) { + int dst_offset = i * dst_step; + dst_data[j + dst_offset] = m[i]; + } + } +} + +void InputTransform4x4StepFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, + int dst_row_step) { + for (int l = 0; l < 4; ++l) { + const float16_t *src_ptr = src_data + l * 4 * src_step; + float16_t *dst_ptr = dst_data + l * dst_row_step; + + float16x8_t s0 = vld1q_f16(src_ptr + 0 * src_step); + float16x8_t s1 = vld1q_f16(src_ptr + 1 * src_step); + float16x8_t s2 = vld1q_f16(src_ptr + 2 * src_step); + float16x8_t s3 = vld1q_f16(src_ptr + 3 * src_step); + float16x8_t m0 = vsubq_f16(s0, s2); + float16x8_t m1 = vaddq_f16(s1, s2); + float16x8_t m2 = vsubq_f16(s2, s1); + float16x8_t m3 = vsubq_f16(s3, s1); + + vst1q_f16(dst_ptr + 0 * dst_step, m0); + vst1q_f16(dst_ptr + 1 * dst_step, m1); + vst1q_f16(dst_ptr + 2 * dst_step, m2); + vst1q_f16(dst_ptr + 3 * dst_step, m3); + } +} + +#ifdef ENABLE_ARM64 +void InputTransform4x4Pack16ChannelFp16(float16_t *src_ptr, float16_t *dst_ptr, int dst_step, int pack_tile, + int src_point_stride) { + LOAD_LINE_DATA_FP16(0); + LOAD_LINE_DATA_FP16(1); + LOAD_LINE_DATA_FP16(2); + LOAD_LINE_DATA_FP16(3); + + float16x8_t m0 = vsubq_f16(s00, s20); + float16x8_t m1 = vsubq_f16(s01, s21); + vst1q_f16(dst_ptr + 0 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 0 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(s10, s20); + m1 = vaddq_f16(s11, s21); + vst1q_f16(dst_ptr + 1 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 1 * dst_step + 1 * pack_tile, m1); + + m0 = vsubq_f16(s20, s10); + m1 = vsubq_f16(s21, s11); + vst1q_f16(dst_ptr + 2 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 2 * dst_step + 1 * pack_tile, m1); + + m0 = vsubq_f16(s30, s10); + m1 = vsubq_f16(s31, s11); + vst1q_f16(dst_ptr + 3 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 3 * dst_step + 1 * pack_tile, m1); +} + +void InputTransform4x4Pack16Fp16(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c) { + int block_tile = 16; + int pack_tile = src_step; + int src_point_stride = block_tile * pack_tile; + for (int l = 0; l < 4; ++l) { + float16_t *src_ptr = src_data + l * C8NUM * block_tile; + TRANSPOSE_16x8; + } + + for (int c = 0; c < real_c; ++c) { + float16_t *src_ptr = src_data + c * block_tile; + float16_t *dst_ptr = dst_data + c * block_tile; + InputTransform4x4Pack16ChannelFp16(src_ptr, dst_ptr, dst_step, pack_tile, src_point_stride); + } +} +#endif + +void InputTransform6x6UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c) { + int j = 0; + if (real_c == 8) { + float16x8_t src[36]; + float16x8_t t[36]; + float16x8_t m[36]; + Load36DataFp16; + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vsubq_f16(src[3 + offset], src[1 + offset]); + float16x8_t tmp2 = vsubq_f16(src[4 + offset], src[2 + offset]); + t[l] = vaddq_f16(vsubq_f16(vmulq_n_f16(src[offset], 4), vmulq_n_f16(src[2 + offset], 5)), src[4 + offset]); + t[6 + l] = vaddq_f16(vmulq_n_f16(vaddq_f16(src[1 + offset], src[2 + offset]), -4), + vaddq_f16(src[3 + offset], src[4 + offset])); + t[12 + l] = vaddq_f16(vmulq_n_f16(vsubq_f16(src[1 + offset], src[2 + offset]), 4), + vsubq_f16(src[4 + offset], src[3 + offset])); + t[18 + l] = vaddq_f16(vmulq_n_f16(tmp1, 2), tmp2); + t[24 + l] = vaddq_f16(vmulq_n_f16(tmp1, -2), tmp2); + t[30 + l] = + vaddq_f16(vsubq_f16(vmulq_n_f16(src[1 + offset], 4), vmulq_n_f16(src[3 + offset], 5)), src[5 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vsubq_f16(t[3 + offset], t[1 + offset]); + float16x8_t tmp2 = vsubq_f16(t[4 + offset], t[2 + offset]); + m[l] = vaddq_f16(vsubq_f16(vmulq_n_f16(t[offset], 4), vmulq_n_f16(t[2 + offset], 5)), t[4 + offset]); + m[6 + l] = + vaddq_f16(vmulq_n_f16(vaddq_f16(t[1 + offset], t[2 + offset]), -4), vaddq_f16(t[3 + offset], t[4 + offset])); + m[12 + l] = + vaddq_f16(vmulq_n_f16(vsubq_f16(t[1 + offset], t[2 + offset]), 4), vsubq_f16(t[4 + offset], t[3 + offset])); + m[18 + l] = vaddq_f16(vmulq_n_f16(tmp1, 2), tmp2); + m[24 + l] = vaddq_f16(vmulq_n_f16(tmp1, -2), tmp2); + m[30 + l] = vaddq_f16(vsubq_f16(vmulq_n_f16(t[1 + offset], 4), vmulq_n_f16(t[3 + offset], 5)), t[5 + offset]); + } + for (int i = 0; i < 36; i++) { + int dst_offset = i * dst_step; + vst1q_f16(dst_data + dst_offset, m[i]); + } + real_c -= 8; + } else if (real_c < 8 && real_c >= 4) { + float16x4_t src[36]; + float16x4_t t[36]; + float16x4_t m[36]; + Load36DataC4Fp16; + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x4_t tmp1 = vsub_f16(src[3 + offset], src[1 + offset]); + float16x4_t tmp2 = vsub_f16(src[4 + offset], src[2 + offset]); + t[l] = vadd_f16(vsub_f16(vmul_n_f16(src[offset], 4), vmul_n_f16(src[2 + offset], 5)), src[4 + offset]); + t[6 + l] = vadd_f16(vmul_n_f16(vadd_f16(src[1 + offset], src[2 + offset]), -4), + vadd_f16(src[3 + offset], src[4 + offset])); + t[12 + l] = + vadd_f16(vmul_n_f16(vsub_f16(src[1 + offset], src[2 + offset]), 4), vsub_f16(src[4 + offset], src[3 + offset])); + t[18 + l] = vadd_f16(vmul_n_f16(tmp1, 2), tmp2); + t[24 + l] = vadd_f16(vmul_n_f16(tmp1, -2), tmp2); + t[30 + l] = vadd_f16(vsub_f16(vmul_n_f16(src[1 + offset], 4), vmul_n_f16(src[3 + offset], 5)), src[5 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x4_t tmp1 = vsub_f16(t[3 + offset], t[1 + offset]); + float16x4_t tmp2 = vsub_f16(t[4 + offset], t[2 + offset]); + m[l] = vadd_f16(vsub_f16(vmul_n_f16(t[offset], 4), vmul_n_f16(t[2 + offset], 5)), t[4 + offset]); + m[6 + l] = + vadd_f16(vmul_n_f16(vadd_f16(t[1 + offset], t[2 + offset]), -4), vadd_f16(t[3 + offset], t[4 + offset])); + m[12 + l] = + vadd_f16(vmul_n_f16(vsub_f16(t[1 + offset], t[2 + offset]), 4), vsub_f16(t[4 + offset], t[3 + offset])); + m[18 + l] = vadd_f16(vmul_n_f16(tmp1, 2), tmp2); + m[24 + l] = vadd_f16(vmul_n_f16(tmp1, -2), tmp2); + m[30 + l] = vadd_f16(vsub_f16(vmul_n_f16(t[1 + offset], 4), vmul_n_f16(t[3 + offset], 5)), t[5 + offset]); + } + for (int i = 0; i < 36; i++) { + int dst_offset = i * dst_step; + vst1_f16(dst_data + dst_offset, m[i]); + } + j = 4; + } + for (; j < real_c; ++j) { + float16_t src[36]; + float16_t t[36]; + float16_t m[36]; + for (int k = 0; k < 36; ++k) { + src[k] = src_data[j + k * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16_t tmp1 = src[3 + offset] - src[1 + offset]; + float16_t tmp2 = src[4 + offset] - src[2 + offset]; + t[l] = src[offset] * 4 - src[2 + offset] * 5 + src[4 + offset]; + t[6 + l] = (src[1 + offset] + src[2 + offset]) * -4 + (src[3 + offset] + src[4 + offset]); + t[12 + l] = (src[1 + offset] - src[2 + offset]) * 4 + (src[4 + offset] - src[3 + offset]); + t[18 + l] = tmp1 * 2 + tmp2; + t[24 + l] = tmp1 * -2 + tmp2; + t[30 + l] = src[1 + offset] * 4 - src[3 + offset] * 5 + src[5 + offset]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16_t tmp1 = t[3 + offset] - t[1 + offset]; + float16_t tmp2 = t[4 + offset] - t[2 + offset]; + m[l] = t[offset] * 4 - t[2 + offset] * 5 + t[4 + offset]; + m[6 + l] = (t[1 + offset] + t[2 + offset]) * -4 + (t[3 + offset] + t[4 + offset]); + m[12 + l] = (t[1 + offset] - t[2 + offset]) * 4 + (t[4 + offset] - t[3 + offset]); + m[18 + l] = tmp1 * 2 + tmp2; + m[24 + l] = tmp1 * -2 + tmp2; + m[30 + l] = t[1 + offset] * 4 - t[3 + offset] * 5 + t[5 + offset]; + } + for (int i = 0; i < 36; i++) { + int dst_offset = i * dst_step; + dst_data[j + dst_offset] = m[i]; + } + } +} + +void InputTransform6x6StepFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, + int dst_row_step) { + for (int l = 0; l < 6; ++l) { + const float16_t *src_ptr = src_data + l * 6 * src_step; + float16_t *dst_ptr = dst_data + l * dst_row_step; + + float16x8_t s0 = vld1q_f16(src_ptr + 0 * src_step); + float16x8_t s1 = vld1q_f16(src_ptr + 1 * src_step); + float16x8_t s2 = vld1q_f16(src_ptr + 2 * src_step); + float16x8_t s3 = vld1q_f16(src_ptr + 3 * src_step); + float16x8_t s4 = vld1q_f16(src_ptr + 4 * src_step); + float16x8_t s5 = vld1q_f16(src_ptr + 5 * src_step); + + float16x8_t tmp1 = vsubq_f16(s3, s1); + float16x8_t tmp2 = vsubq_f16(s4, s2); + float16x8_t m0 = vaddq_f16(vsubq_f16(vmulq_n_f16(s0, 4), vmulq_n_f16(s2, 5)), s4); + float16x8_t m1 = vaddq_f16(vmulq_n_f16(vaddq_f16(s1, s2), -4), vaddq_f16(s3, s4)); + float16x8_t m2 = vaddq_f16(vmulq_n_f16(vsubq_f16(s1, s2), 4), vsubq_f16(s4, s3)); + float16x8_t m3 = vaddq_f16(vmulq_n_f16(tmp1, 2), tmp2); + float16x8_t m4 = vaddq_f16(vmulq_n_f16(tmp1, -2), tmp2); + float16x8_t m5 = vaddq_f16(vsubq_f16(vmulq_n_f16(s1, 4), vmulq_n_f16(s3, 5)), s5); + + vst1q_f16(dst_ptr + 0 * dst_step, m0); + vst1q_f16(dst_ptr + 1 * dst_step, m1); + vst1q_f16(dst_ptr + 2 * dst_step, m2); + vst1q_f16(dst_ptr + 3 * dst_step, m3); + vst1q_f16(dst_ptr + 4 * dst_step, m4); + vst1q_f16(dst_ptr + 5 * dst_step, m5); + } +} + +#ifdef ENABLE_ARM64 +void InputTransform6x6Pack16ChannelFp16(float16_t *src_ptr, float16_t *dst_ptr, int dst_step, int pack_tile, + int src_point_stride) { + LOAD_LINE_DATA_FP16(0); + LOAD_LINE_DATA_FP16(1); + LOAD_LINE_DATA_FP16(2); + LOAD_LINE_DATA_FP16(3); + LOAD_LINE_DATA_FP16(4); + LOAD_LINE_DATA_FP16(5); + + float16x8_t m0 = vaddq_f16(vsubq_f16(vmulq_n_f16(s00, 4), vmulq_n_f16(s20, 5)), s40); + float16x8_t m1 = vaddq_f16(vsubq_f16(vmulq_n_f16(s01, 4), vmulq_n_f16(s21, 5)), s41); + vst1q_f16(dst_ptr + 0 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 0 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(vmulq_n_f16(vaddq_f16(s10, s20), -4), vaddq_f16(s30, s40)); + m1 = vaddq_f16(vmulq_n_f16(vaddq_f16(s11, s21), -4), vaddq_f16(s31, s41)); + vst1q_f16(dst_ptr + 1 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 1 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(vmulq_n_f16(vsubq_f16(s10, s20), 4), vsubq_f16(s40, s30)); + m1 = vaddq_f16(vmulq_n_f16(vsubq_f16(s11, s21), 4), vsubq_f16(s41, s31)); + vst1q_f16(dst_ptr + 2 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 2 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(vmulq_n_f16(vsubq_f16(s30, s10), 2), vsubq_f16(s40, s20)); + m1 = vaddq_f16(vmulq_n_f16(vsubq_f16(s31, s11), 2), vsubq_f16(s41, s21)); + vst1q_f16(dst_ptr + 3 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 3 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(vmulq_n_f16(vsubq_f16(s30, s10), -2), vsubq_f16(s40, s20)); + m1 = vaddq_f16(vmulq_n_f16(vsubq_f16(s31, s11), -2), vsubq_f16(s41, s21)); + vst1q_f16(dst_ptr + 4 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 4 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(vsubq_f16(vmulq_n_f16(s10, 4), vmulq_n_f16(s30, 5)), s50); + m1 = vaddq_f16(vsubq_f16(vmulq_n_f16(s11, 4), vmulq_n_f16(s31, 5)), s51); + vst1q_f16(dst_ptr + 5 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 5 * dst_step + 1 * pack_tile, m1); +} + +void InputTransform6x6Pack16Fp16(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c) { + int block_tile = 16; + int pack_tile = src_step; + int src_point_stride = block_tile * pack_tile; + for (int l = 0; l < 6; ++l) { + float16_t *src_ptr = src_data + l * C8NUM * block_tile; + TRANSPOSE_16x8; + } + + for (int c = 0; c < real_c; ++c) { + float16_t *src_ptr = src_data + c * block_tile; + float16_t *dst_ptr = dst_data + c * block_tile; + InputTransform6x6Pack16ChannelFp16(src_ptr, dst_ptr, dst_step, pack_tile, src_point_stride); + } +} +#endif + +void InputTransform8x8UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c) { + int j = 0; + if (real_c == 8) { + float16x8_t src[64]; + float16x8_t t[64]; + float16x8_t m[64]; + Load64DataFp16; + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = vsubq_f16(vaddq_f16(vsubq_f16(vmulq_n_f16(src[offset], 0.5625), vmulq_n_f16(src[2 + offset], 3.0625)), + vmulq_n_f16(src[4 + offset], 3.5)), + src[6 + offset]); + float16x8_t tmp1 = vaddq_f16(vmulq_n_f16(src[1 + offset], 1.125), vmulq_n_f16(src[5 + offset], 0.5)); + float16x8_t tmp2 = vsubq_f16(vmulq_n_f16(src[2 + offset], 2.25), vmulq_n_f16(src[4 + offset], 3.25)); + t[8 + l] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(src[3 + offset], 1.625)), src[6 + offset]); + t[16 + l] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(src[3 + offset], 1.625)), src[6 + offset]); + tmp1 = vaddq_f16(vmulq_n_f16(src[1 + offset], 0.5625), src[5 + offset]); + tmp2 = vsubq_f16(vmulq_n_f16(src[2 + offset], 0.5625), vmulq_n_f16(src[4 + offset], 2.5)); + t[24 + l] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(src[3 + offset], 2.5)), src[6 + offset]); + t[32 + l] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(src[3 + offset], 2.5)), src[6 + offset]); + tmp1 = vaddq_f16(vmulq_n_f16(src[1 + offset], 0.375), vmulq_n_f16(src[5 + offset], 1.5)); + tmp2 = vsubq_f16(vmulq_n_f16(src[2 + offset], 0.25), vmulq_n_f16(src[4 + offset], 1.25)); + t[40 + l] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(src[3 + offset], 1.875)), src[6 + offset]); + t[48 + l] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(src[3 + offset], 1.875)), src[6 + offset]); + t[56 + l] = + vaddq_f16(vsubq_f16(vaddq_f16(vmulq_n_f16(src[1 + offset], -0.5625), vmulq_n_f16(src[3 + offset], 3.0625)), + vmulq_n_f16(src[5 + offset], 3.5)), + src[7 + offset]); + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + m[l] = vsubq_f16(vaddq_f16(vsubq_f16(vmulq_n_f16(t[offset], 0.5625), vmulq_n_f16(t[2 + offset], 3.0625)), + vmulq_n_f16(t[4 + offset], 3.5)), + t[6 + offset]); + float16x8_t tmp1 = vaddq_f16(vmulq_n_f16(t[1 + offset], 1.125), vmulq_n_f16(t[5 + offset], 0.5)); + float16x8_t tmp2 = vsubq_f16(vmulq_n_f16(t[2 + offset], 2.25), vmulq_n_f16(t[4 + offset], 3.25)); + m[8 + l] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(t[3 + offset], 1.625)), t[6 + offset]); + m[16 + l] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(t[3 + offset], 1.625)), t[6 + offset]); + tmp1 = vaddq_f16(vmulq_n_f16(t[1 + offset], 0.5625), t[5 + offset]); + tmp2 = vsubq_f16(vmulq_n_f16(t[2 + offset], 0.5625), vmulq_n_f16(t[4 + offset], 2.5)); + m[24 + l] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(t[3 + offset], 2.5)), t[6 + offset]); + m[32 + l] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(t[3 + offset], 2.5)), t[6 + offset]); + tmp1 = vaddq_f16(vmulq_n_f16(t[1 + offset], 0.375), vmulq_n_f16(t[5 + offset], 1.5)); + tmp2 = vsubq_f16(vmulq_n_f16(t[2 + offset], 0.25), vmulq_n_f16(t[4 + offset], 1.25)); + m[40 + l] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(t[3 + offset], 1.875)), t[6 + offset]); + m[48 + l] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(t[3 + offset], 1.875)), t[6 + offset]); + m[56 + l] = + vaddq_f16(vsubq_f16(vaddq_f16(vmulq_n_f16(t[1 + offset], -0.5625), vmulq_n_f16(t[3 + offset], 3.0625)), + vmulq_n_f16(t[5 + offset], 3.5)), + t[7 + offset]); + } + for (int i = 0; i < 64; i++) { + int dst_offset = i * dst_step; + vst1q_f16(dst_data + dst_offset, m[i]); + } + real_c -= 8; + } else if (real_c < 8 && real_c >= 4) { + float16x4_t src[64]; + float16x4_t t[64]; + float16x4_t m[64]; + Load64DataC4Fp16; + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = vsub_f16(vadd_f16(vsub_f16(vmul_n_f16(src[offset], 0.5625), vmul_n_f16(src[2 + offset], 3.0625)), + vmul_n_f16(src[4 + offset], 3.5)), + src[6 + offset]); + float16x4_t tmp1 = vadd_f16(vmul_n_f16(src[1 + offset], 1.125), vmul_n_f16(src[5 + offset], 0.5)); + float16x4_t tmp2 = vsub_f16(vmul_n_f16(src[2 + offset], 2.25), vmul_n_f16(src[4 + offset], 3.25)); + t[8 + l] = vadd_f16(vsub_f16(vadd_f16(tmp1, tmp2), vmul_n_f16(src[3 + offset], 1.625)), src[6 + offset]); + t[16 + l] = vadd_f16(vadd_f16(vsub_f16(tmp2, tmp1), vmul_n_f16(src[3 + offset], 1.625)), src[6 + offset]); + tmp1 = vadd_f16(vmul_n_f16(src[1 + offset], 0.5625), src[5 + offset]); + tmp2 = vsub_f16(vmul_n_f16(src[2 + offset], 0.5625), vmul_n_f16(src[4 + offset], 2.5)); + t[24 + l] = vadd_f16(vsub_f16(vadd_f16(tmp1, tmp2), vmul_n_f16(src[3 + offset], 2.5)), src[6 + offset]); + t[32 + l] = vadd_f16(vadd_f16(vsub_f16(tmp2, tmp1), vmul_n_f16(src[3 + offset], 2.5)), src[6 + offset]); + tmp1 = vadd_f16(vmul_n_f16(src[1 + offset], 0.375), vmul_n_f16(src[5 + offset], 1.5)); + tmp2 = vsub_f16(vmul_n_f16(src[2 + offset], 0.25), vmul_n_f16(src[4 + offset], 1.25)); + t[40 + l] = vadd_f16(vsub_f16(vadd_f16(tmp1, tmp2), vmul_n_f16(src[3 + offset], 1.875)), src[6 + offset]); + t[48 + l] = vadd_f16(vadd_f16(vsub_f16(tmp2, tmp1), vmul_n_f16(src[3 + offset], 1.875)), src[6 + offset]); + t[56 + l] = vadd_f16(vsub_f16(vadd_f16(vmul_n_f16(src[1 + offset], -0.5625), vmul_n_f16(src[3 + offset], 3.0625)), + vmul_n_f16(src[5 + offset], 3.5)), + src[7 + offset]); + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + m[l] = vsub_f16(vadd_f16(vsub_f16(vmul_n_f16(t[offset], 0.5625), vmul_n_f16(t[2 + offset], 3.0625)), + vmul_n_f16(t[4 + offset], 3.5)), + t[6 + offset]); + float16x4_t tmp1 = vadd_f16(vmul_n_f16(t[1 + offset], 1.125), vmul_n_f16(t[5 + offset], 0.5)); + float16x4_t tmp2 = vsub_f16(vmul_n_f16(t[2 + offset], 2.25), vmul_n_f16(t[4 + offset], 3.25)); + m[8 + l] = vadd_f16(vsub_f16(vadd_f16(tmp1, tmp2), vmul_n_f16(t[3 + offset], 1.625)), t[6 + offset]); + m[16 + l] = vadd_f16(vadd_f16(vsub_f16(tmp2, tmp1), vmul_n_f16(t[3 + offset], 1.625)), t[6 + offset]); + tmp1 = vadd_f16(vmul_n_f16(t[1 + offset], 0.5625), t[5 + offset]); + tmp2 = vsub_f16(vmul_n_f16(t[2 + offset], 0.5625), vmul_n_f16(t[4 + offset], 2.5)); + m[24 + l] = vadd_f16(vsub_f16(vadd_f16(tmp1, tmp2), vmul_n_f16(t[3 + offset], 2.5)), t[6 + offset]); + m[32 + l] = vadd_f16(vadd_f16(vsub_f16(tmp2, tmp1), vmul_n_f16(t[3 + offset], 2.5)), t[6 + offset]); + tmp1 = vadd_f16(vmul_n_f16(t[1 + offset], 0.375), vmul_n_f16(t[5 + offset], 1.5)); + tmp2 = vsub_f16(vmul_n_f16(t[2 + offset], 0.25), vmul_n_f16(t[4 + offset], 1.25)); + m[40 + l] = vadd_f16(vsub_f16(vadd_f16(tmp1, tmp2), vmul_n_f16(t[3 + offset], 1.875)), t[6 + offset]); + m[48 + l] = vadd_f16(vadd_f16(vsub_f16(tmp2, tmp1), vmul_n_f16(t[3 + offset], 1.875)), t[6 + offset]); + m[56 + l] = vadd_f16(vsub_f16(vadd_f16(vmul_n_f16(t[1 + offset], -0.5625), vmul_n_f16(t[3 + offset], 3.0625)), + vmul_n_f16(t[5 + offset], 3.5)), + t[7 + offset]); + } + for (int i = 0; i < 64; i++) { + int dst_offset = i * dst_step; + vst1_f16(dst_data + dst_offset, m[i]); + } + j = 4; + } + for (; j < real_c; ++j) { + float16_t src[64]; + float16_t t[64]; + float16_t m[64]; + for (int k = 0; k < 64; ++k) { + src[k] = src_data[j + k * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] * 0.5625f - src[2 + offset] * 3.0625f + src[4 + offset] * 3.5f - src[6 + offset]; + float16_t tmp1 = src[1 + offset] * 1.125f + src[5 + offset] * 0.5f; + float16_t tmp2 = src[2 + offset] * 2.25f - src[4 + offset] * 3.25f; + t[8 + l] = tmp1 + tmp2 - src[3 + offset] * 1.625f + src[6 + offset]; + t[16 + l] = tmp2 - tmp1 + src[3 + offset] * 1.625f + src[6 + offset]; + tmp1 = src[1 + offset] * 0.5625f + src[5 + offset]; + tmp2 = src[2 + offset] * 0.5625f - src[4 + offset] * 2.5f; + t[24 + l] = tmp1 + tmp2 - src[3 + offset] * 2.5f + src[6 + offset]; + t[32 + l] = tmp2 - tmp1 + src[3 + offset] * 2.5f + src[6 + offset]; + tmp1 = src[1 + offset] * 0.375f + src[5 + offset] * 1.5f; + tmp2 = src[2 + offset] * 0.25f - src[4 + offset] * 1.25f; + t[40 + l] = tmp1 + tmp2 - src[3 + offset] * 1.875f + src[6 + offset]; + t[48 + l] = tmp2 - tmp1 + src[3 + offset] * 1.875f + src[6 + offset]; + t[56 + l] = src[1 + offset] * -0.5625 + src[3 + offset] * 3.0625f - src[5 + offset] * 3.5f + src[7 + offset]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + m[l] = t[offset] * 0.5625f - t[2 + offset] * 3.0625f + t[4 + offset] * 3.5f - t[6 + offset]; + float16_t tmp1 = t[1 + offset] * 1.125f + t[5 + offset] * 0.5f; + float16_t tmp2 = t[2 + offset] * 2.25f - t[4 + offset] * 3.25f; + m[8 + l] = tmp1 + tmp2 - t[3 + offset] * 1.625f + t[6 + offset]; + m[16 + l] = tmp2 - tmp1 + t[3 + offset] * 1.625f + t[6 + offset]; + tmp1 = t[1 + offset] * 0.5625f + t[5 + offset]; + tmp2 = t[2 + offset] * 0.5625f - t[4 + offset] * 2.5f; + m[24 + l] = tmp1 + tmp2 - t[3 + offset] * 2.5f + t[6 + offset]; + m[32 + l] = tmp2 - tmp1 + t[3 + offset] * 2.5f + t[6 + offset]; + tmp1 = t[1 + offset] * 0.375f + t[5 + offset] * 1.5f; + tmp2 = t[2 + offset] * 0.25f - t[4 + offset] * 1.25f; + m[40 + l] = tmp1 + tmp2 - t[3 + offset] * 1.875f + t[6 + offset]; + m[48 + l] = tmp2 - tmp1 + t[3 + offset] * 1.875f + t[6 + offset]; + m[56 + l] = t[1 + offset] * -0.5625 + t[3 + offset] * 3.0625f - t[5 + offset] * 3.5f + t[7 + offset]; + } + for (int i = 0; i < 64; i++) { + int dst_offset = i * dst_step; + dst_data[j + dst_offset] = m[i]; + } + } +} + +void InputTransform8x8StepFp16_uint(float16x8_t *s, float16x8_t *m) { + m[0] = + vsubq_f16(vaddq_f16(vsubq_f16(vmulq_n_f16(s[0], 0.5625), vmulq_n_f16(s[2], 3.0625)), vmulq_n_f16(s[4], 3.5)), s[6]); + float16x8_t tmp1 = vaddq_f16(vmulq_n_f16(s[1], 1.125), vmulq_n_f16(s[5], 0.5)); + float16x8_t tmp2 = vsubq_f16(vmulq_n_f16(s[2], 2.25), vmulq_n_f16(s[4], 3.25)); + m[1] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(s[3], 1.625)), s[6]); + m[2] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(s[3], 1.625)), s[6]); + tmp1 = vaddq_f16(vmulq_n_f16(s[1], 0.5625), s[5]); + tmp2 = vsubq_f16(vmulq_n_f16(s[2], 0.5625), vmulq_n_f16(s[4], 2.5)); + m[3] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(s[3], 2.5)), s[6]); + m[4] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(s[3], 2.5)), s[6]); + tmp1 = vaddq_f16(vmulq_n_f16(s[1], 0.375), vmulq_n_f16(s[5], 1.5)); + tmp2 = vsubq_f16(vmulq_n_f16(s[2], 0.25), vmulq_n_f16(s[4], 1.25)); + m[5] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(s[3], 1.875)), s[6]); + m[6] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(s[3], 1.875)), s[6]); + m[7] = vaddq_f16(vsubq_f16(vaddq_f16(vmulq_n_f16(s[1], -0.5625), vmulq_n_f16(s[3], 3.0625)), vmulq_n_f16(s[5], 3.5)), + s[7]); +} + +void InputTransform8x8StepFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, + int dst_row_step) { + for (int l = 0; l < 8; ++l) { + const float16_t *src_ptr = src_data + l * 8 * src_step; + float16_t *dst_ptr = dst_data + l * dst_row_step; + + float16x8_t s[8]; + float16x8_t m[8]; + + s[0] = vld1q_f16(src_ptr + 0 * src_step); + s[1] = vld1q_f16(src_ptr + 1 * src_step); + s[2] = vld1q_f16(src_ptr + 2 * src_step); + s[3] = vld1q_f16(src_ptr + 3 * src_step); + s[4] = vld1q_f16(src_ptr + 4 * src_step); + s[5] = vld1q_f16(src_ptr + 5 * src_step); + s[6] = vld1q_f16(src_ptr + 6 * src_step); + s[7] = vld1q_f16(src_ptr + 7 * src_step); + + InputTransform8x8StepFp16_uint(s, m); + + vst1q_f16(dst_ptr + 0 * dst_step, m[0]); + vst1q_f16(dst_ptr + 1 * dst_step, m[1]); + vst1q_f16(dst_ptr + 2 * dst_step, m[2]); + vst1q_f16(dst_ptr + 3 * dst_step, m[3]); + vst1q_f16(dst_ptr + 4 * dst_step, m[4]); + vst1q_f16(dst_ptr + 5 * dst_step, m[5]); + vst1q_f16(dst_ptr + 6 * dst_step, m[6]); + vst1q_f16(dst_ptr + 7 * dst_step, m[7]); + } +} + +#ifdef ENABLE_ARM64 +void InputTransform8x8Pack16ChannelFp16(float16_t *src_ptr, float16_t *dst_ptr, int dst_step, int pack_tile, + int src_point_stride) { + LOAD_LINE_DATA_FP16(0); + LOAD_LINE_DATA_FP16(1); + LOAD_LINE_DATA_FP16(2); + LOAD_LINE_DATA_FP16(3); + LOAD_LINE_DATA_FP16(4); + LOAD_LINE_DATA_FP16(5); + LOAD_LINE_DATA_FP16(6); + LOAD_LINE_DATA_FP16(7); + + float16x8_t m0 = + vsubq_f16(vaddq_f16(vsubq_f16(vmulq_n_f16(s00, 0.5625), vmulq_n_f16(s20, 3.0625)), vmulq_n_f16(s40, 3.5)), s60); + float16x8_t m1 = + vsubq_f16(vaddq_f16(vsubq_f16(vmulq_n_f16(s01, 0.5625), vmulq_n_f16(s21, 3.0625)), vmulq_n_f16(s41, 3.5)), s61); + vst1q_f16(dst_ptr + 0 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 0 * dst_step + 1 * pack_tile, m1); + + float16x8_t tmp10 = vaddq_f16(vmulq_n_f16(s10, 1.125), vmulq_n_f16(s50, 0.5)); + float16x8_t tmp11 = vaddq_f16(vmulq_n_f16(s11, 1.125), vmulq_n_f16(s51, 0.5)); + float16x8_t tmp20 = vsubq_f16(vmulq_n_f16(s20, 2.25), vmulq_n_f16(s40, 3.25)); + float16x8_t tmp21 = vsubq_f16(vmulq_n_f16(s21, 2.25), vmulq_n_f16(s41, 3.25)); + m0 = vaddq_f16(vsubq_f16(vaddq_f16(tmp10, tmp20), vmulq_n_f16(s30, 1.625)), s60); + m1 = vaddq_f16(vsubq_f16(vaddq_f16(tmp11, tmp21), vmulq_n_f16(s31, 1.625)), s61); + vst1q_f16(dst_ptr + 1 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 1 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(vaddq_f16(vsubq_f16(tmp20, tmp10), vmulq_n_f16(s30, 1.625)), s60); + m1 = vaddq_f16(vaddq_f16(vsubq_f16(tmp21, tmp11), vmulq_n_f16(s31, 1.625)), s61); + vst1q_f16(dst_ptr + 2 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 2 * dst_step + 1 * pack_tile, m1); + + tmp10 = vaddq_f16(vmulq_n_f16(s10, 0.5625), s50); + tmp11 = vaddq_f16(vmulq_n_f16(s11, 0.5625), s51); + tmp20 = vsubq_f16(vmulq_n_f16(s20, 0.5625), vmulq_n_f16(s40, 2.5)); + tmp21 = vsubq_f16(vmulq_n_f16(s21, 0.5625), vmulq_n_f16(s41, 2.5)); + m0 = vaddq_f16(vsubq_f16(vaddq_f16(tmp10, tmp20), vmulq_n_f16(s30, 2.5)), s60); + m1 = vaddq_f16(vsubq_f16(vaddq_f16(tmp11, tmp21), vmulq_n_f16(s31, 2.5)), s61); + vst1q_f16(dst_ptr + 3 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 3 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(vaddq_f16(vsubq_f16(tmp20, tmp10), vmulq_n_f16(s30, 2.5)), s60); + m1 = vaddq_f16(vaddq_f16(vsubq_f16(tmp21, tmp11), vmulq_n_f16(s31, 2.5)), s61); + vst1q_f16(dst_ptr + 4 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 4 * dst_step + 1 * pack_tile, m1); + + tmp10 = vaddq_f16(vmulq_n_f16(s10, 0.375), vmulq_n_f16(s50, 1.5)); + tmp11 = vaddq_f16(vmulq_n_f16(s11, 0.375), vmulq_n_f16(s51, 1.5)); + tmp20 = vsubq_f16(vmulq_n_f16(s20, 0.25), vmulq_n_f16(s40, 1.25)); + tmp21 = vsubq_f16(vmulq_n_f16(s21, 0.25), vmulq_n_f16(s41, 1.25)); + m0 = vaddq_f16(vsubq_f16(vaddq_f16(tmp10, tmp20), vmulq_n_f16(s30, 1.875)), s60); + m1 = vaddq_f16(vsubq_f16(vaddq_f16(tmp11, tmp21), vmulq_n_f16(s31, 1.875)), s61); + vst1q_f16(dst_ptr + 5 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 5 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(vaddq_f16(vsubq_f16(tmp20, tmp10), vmulq_n_f16(s30, 1.875)), s60); + m1 = vaddq_f16(vaddq_f16(vsubq_f16(tmp21, tmp11), vmulq_n_f16(s31, 1.875)), s61); + vst1q_f16(dst_ptr + 6 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 6 * dst_step + 1 * pack_tile, m1); + + m0 = vaddq_f16(vsubq_f16(vaddq_f16(vmulq_n_f16(s10, -0.5625), vmulq_n_f16(s30, 3.0625)), vmulq_n_f16(s50, 3.5)), s70); + m1 = vaddq_f16(vsubq_f16(vaddq_f16(vmulq_n_f16(s11, -0.5625), vmulq_n_f16(s31, 3.0625)), vmulq_n_f16(s51, 3.5)), s71); + vst1q_f16(dst_ptr + 7 * dst_step + 0 * pack_tile, m0); + vst1q_f16(dst_ptr + 7 * dst_step + 1 * pack_tile, m1); +} + +void InputTransform8x8Pack16Fp16(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c) { + int block_tile = 16; + int pack_tile = src_step; + int src_point_stride = block_tile * pack_tile; + for (int l = 0; l < 8; ++l) { + float16_t *src_ptr = src_data + l * C8NUM * block_tile; + TRANSPOSE_16x8; + } + + for (int c = 0; c < real_c; ++c) { + float16_t *src_ptr = src_data + c * block_tile; + float16_t *dst_ptr = dst_data + c * block_tile; + InputTransform8x8Pack16ChannelFp16(src_ptr, dst_ptr, dst_step, pack_tile, src_point_stride); + } +} +#endif + +OutputTransFp16Func GetOutputTransFp16Func(int input_unit, int output_unit, ActType act_type) { + if (input_unit == 4 && output_unit < 4) { + if (act_type == ActType_Relu) { + return OutputTransFp16FuncReluList4[output_unit]; + } else if (act_type == ActType_Relu6) { + return OutputTransFp16FuncRelu6List4[output_unit]; + } else { + return OutputTransFp16FuncList4[output_unit]; + } + } else if (input_unit == 6 && output_unit < 6) { + if (act_type == ActType_Relu) { + return OutputTransFp16FuncReluList6[output_unit]; + } else if (act_type == ActType_Relu6) { + return OutputTransFp16FuncRelu6List6[output_unit]; + } else { + return OutputTransFp16FuncList6[output_unit]; + } + } else if (input_unit == 8 && output_unit < 8) { + if (act_type == ActType_Relu) { + return OutputTransFp16FuncReluList8[output_unit]; + } else if (act_type == ActType_Relu6) { + return OutputTransFp16FuncRelu6List8[output_unit]; + } else { + return OutputTransFp16FuncList8[output_unit]; + } + } else { + return NULL; + } +} + +void OutputTransform4x2UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + int z = 0; + if (r_c == 8) { + float16x8_t src[16]; + float16x8_t t[8]; + float16x8_t m[4]; + Load16DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = vaddq_f16(vaddq_f16(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = vaddq_f16(vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + } + if (r_h == 2 && r_w == 2) { + Store4DataFp16; + } else { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + vst1q_f16(dst_data + dst_k_offset + k * out_c, m[k + m_k_offset]); + } + } + } + r_c -= 8; + } else if (r_c < 8 && r_c >= 4) { + float16x4_t src[16]; + float16x4_t t[8]; + float16x4_t m[4]; + Load16DataC4Fp16; + float16x4_t bias_ptr = vld1_f16(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = vadd_f16(vadd_f16(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = vadd_f16(vsub_f16(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = vadd_f16(vadd_f16(vadd_f16(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = vadd_f16(vadd_f16(vsub_f16(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + } + if (r_h == 2 && r_w == 2) { + Store4DataC4Fp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } + z = 4; + } + for (; z < r_c; ++z) { + float16_t src[16]; + float16_t t[8]; + float16_t m[4]; + for (int k = 0; k < 16; ++k) { + src[k] = src_data[z + k * src_step]; + } + float16_t bias_ptr = bias_data[z]; + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[offset] + src[1 + offset] + src[2 + offset]; + t[l + 4] = src[1 + offset] - src[2 + offset] + src[3 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + bias_ptr; + m[l + 2] = t[1 + offset] - t[2 + offset] + t[3 + offset] + bias_ptr; + } + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[z + dst_k_offset + k * out_c] = m[k + m_k_offset]; + } + } + } +} + +void OutputTransform4x2ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + int z = 0; + if (r_c == 8) { + float16x8_t src[16]; + float16x8_t t[8]; + float16x8_t m[4]; + float16x8_t zero = vdupq_n_f16(0); + Load16DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = vaddq_f16(vaddq_f16(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = vaddq_f16(vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 2] = vmaxq_f16(zero, m[l + 2]); + } + if (r_h == 2 && r_w == 2) { + Store4DataFp16; + } else { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + vst1q_f16(dst_data + dst_k_offset + k * out_c, m[k + m_k_offset]); + } + } + } + r_c -= 8; + } else if (r_c < 8 && r_c >= 4) { + float16x4_t src[16]; + float16x4_t t[8]; + float16x4_t m[4]; + float16x4_t zero = vdup_n_f16(0); + Load16DataC4Fp16; + float16x4_t bias_ptr = vld1_f16(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = vadd_f16(vadd_f16(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = vadd_f16(vsub_f16(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = vadd_f16(vadd_f16(vadd_f16(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = vadd_f16(vadd_f16(vsub_f16(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + m[l] = vmax_f16(zero, m[l]); + m[l + 2] = vmax_f16(zero, m[l + 2]); + } + if (r_h == 2 && r_w == 2) { + Store4DataC4Fp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } + z = 4; + } + for (; z < r_c; ++z) { + float16_t src[16]; + float16_t t[8]; + float16_t m[4]; + for (int k = 0; k < 16; ++k) { + src[k] = src_data[z + k * src_step]; + } + float16_t bias_ptr = bias_data[z]; + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[offset] + src[1 + offset] + src[2 + offset]; + t[l + 4] = src[1 + offset] - src[2 + offset] + src[3 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + bias_ptr; + m[l + 2] = t[1 + offset] - t[2 + offset] + t[3 + offset] + bias_ptr; + m[l] = m[l] > 0 ? m[l] : 0; + m[l + 2] = m[l + 2] > 0 ? m[l + 2] : 0; + } + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[z + dst_k_offset + k * out_c] = m[k + m_k_offset]; + } + } + } +} + +void OutputTransform4x2Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + int z = 0; + if (r_c == 8) { + float16x8_t src[16]; + float16x8_t t[8]; + float16x8_t m[4]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load16DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = vaddq_f16(vaddq_f16(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = vaddq_f16(vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 2] = vmaxq_f16(zero, m[l + 2]); + m[l + 2] = vminq_f16(six, m[l + 2]); + } + if (r_h == 2 && r_w == 2) { + Store4DataFp16; + } else { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + vst1q_f16(dst_data + dst_k_offset + k * out_c, m[k + m_k_offset]); + } + } + } + r_c -= 8; + } else if (r_c < 8 && r_c >= 4) { + float16x4_t src[16]; + float16x4_t t[8]; + float16x4_t m[4]; + float16x4_t zero = vdup_n_f16(0); + float16x4_t six = vdup_n_f16(6); + Load16DataC4Fp16; + float16x4_t bias_ptr = vld1_f16(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = vadd_f16(vadd_f16(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = vadd_f16(vsub_f16(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = vadd_f16(vadd_f16(vadd_f16(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = vadd_f16(vadd_f16(vsub_f16(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + m[l] = vmax_f16(zero, m[l]); + m[l] = vmin_f16(six, m[l]); + m[l + 2] = vmax_f16(zero, m[l + 2]); + m[l + 2] = vmin_f16(six, m[l + 2]); + } + if (r_h == 2 && r_w == 2) { + Store4DataC4Fp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } + z = 4; + } + for (; z < r_c; ++z) { + float16_t src[16]; + float16_t t[8]; + float16_t m[4]; + for (int k = 0; k < 16; ++k) { + src[k] = src_data[z + k * src_step]; + } + float16_t bias_ptr = bias_data[z]; + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[offset] + src[1 + offset] + src[2 + offset]; + t[l + 4] = src[1 + offset] - src[2 + offset] + src[3 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + bias_ptr; + m[l + 2] = t[1 + offset] - t[2 + offset] + t[3 + offset] + bias_ptr; + m[l] = m[l] > 0 ? m[l] : 0; + m[l] = m[l] < 6 ? m[l] : 6; + m[l + 2] = m[l + 2] > 0 ? m[l + 2] : 0; + m[l + 2] = m[l + 2] < 6 ? m[l + 2] : 6; + } + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[z + dst_k_offset + k * out_c] = m[k + m_k_offset]; + } + } + } +} + +void OutputTransform4x3UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[16]; + float16x8_t t[12]; + float16x8_t m[9]; + Load16DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + float16x8_t tmp = vaddq_f16(src[1 + offset], src[2 + offset]); + t[l] = vaddq_f16(src[offset], tmp); + t[l + 4] = vsubq_f16(src[1 + offset], src[2 + offset]); + t[l + 8] = vaddq_f16(tmp, src[3 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 4; + float16x8_t tmp = vaddq_f16(t[1 + offset], t[2 + offset]); + m[l] = vaddq_f16(vaddq_f16(t[offset], tmp), bias_ptr); + m[l + 3] = vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), bias_ptr); + m[l + 6] = vaddq_f16(vaddq_f16(tmp, t[3 + offset]), bias_ptr); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + Store9DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform4x3ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[16]; + float16x8_t t[12]; + float16x8_t m[9]; + float16x8_t zero = vdupq_n_f16(0); + Load16DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + float16x8_t tmp = vaddq_f16(src[1 + offset], src[2 + offset]); + t[l] = vaddq_f16(src[offset], tmp); + t[l + 4] = vsubq_f16(src[1 + offset], src[2 + offset]); + t[l + 8] = vaddq_f16(tmp, src[3 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 4; + float16x8_t tmp = vaddq_f16(t[1 + offset], t[2 + offset]); + m[l] = vaddq_f16(vaddq_f16(t[offset], tmp), bias_ptr); + m[l + 3] = vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), bias_ptr); + m[l + 6] = vaddq_f16(vaddq_f16(tmp, t[3 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 3] = vmaxq_f16(zero, m[l + 3]); + m[l + 6] = vmaxq_f16(zero, m[l + 6]); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + Store9DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform4x3Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[16]; + float16x8_t t[12]; + float16x8_t m[9]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load16DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + float16x8_t tmp = vaddq_f16(src[1 + offset], src[2 + offset]); + t[l] = vaddq_f16(src[offset], tmp); + t[l + 4] = vsubq_f16(src[1 + offset], src[2 + offset]); + t[l + 8] = vaddq_f16(tmp, src[3 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 4; + float16x8_t tmp = vaddq_f16(t[1 + offset], t[2 + offset]); + m[l] = vaddq_f16(vaddq_f16(t[offset], tmp), bias_ptr); + m[l + 3] = vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), bias_ptr); + m[l + 6] = vaddq_f16(vaddq_f16(tmp, t[3 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 3] = vmaxq_f16(zero, m[l + 3]); + m[l + 3] = vminq_f16(six, m[l + 3]); + m[l + 6] = vmaxq_f16(zero, m[l + 6]); + m[l + 6] = vminq_f16(six, m[l + 6]); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + Store9DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x2UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[12]; + float16x8_t m[4]; + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(src[offset], src[1 + offset]), src[2 + offset]), src[3 + offset]), + src[4 + offset]); + t[l + 6] = vaddq_f16(vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), + vmulq_n_f16(vsubq_f16(src[3 + offset], src[4 + offset]), 2)), + src[5 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 6; + m[l] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], t[1 + offset]), t[2 + offset]), t[3 + offset]), t[4 + offset]), + bias_ptr); + m[l + 2] = vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), + vmulq_n_f16(vsubq_f16(t[3 + offset], t[4 + offset]), 2)), + t[5 + offset]), + bias_ptr); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + Store4DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x2ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[12]; + float16x8_t m[4]; + float16x8_t zero = vdupq_n_f16(0); + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(src[offset], src[1 + offset]), src[2 + offset]), src[3 + offset]), + src[4 + offset]); + t[l + 6] = vaddq_f16(vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), + vmulq_n_f16(vsubq_f16(src[3 + offset], src[4 + offset]), 2)), + src[5 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 6; + m[l] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], t[1 + offset]), t[2 + offset]), t[3 + offset]), t[4 + offset]), + bias_ptr); + m[l + 2] = vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), + vmulq_n_f16(vsubq_f16(t[3 + offset], t[4 + offset]), 2)), + t[5 + offset]), + bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 2] = vmaxq_f16(zero, m[l + 2]); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + Store4DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x2Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[12]; + float16x8_t m[4]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(src[offset], src[1 + offset]), src[2 + offset]), src[3 + offset]), + src[4 + offset]); + t[l + 6] = vaddq_f16(vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), + vmulq_n_f16(vsubq_f16(src[3 + offset], src[4 + offset]), 2)), + src[5 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 6; + m[l] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], t[1 + offset]), t[2 + offset]), t[3 + offset]), t[4 + offset]), + bias_ptr); + m[l + 2] = vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), + vmulq_n_f16(vsubq_f16(t[3 + offset], t[4 + offset]), 2)), + t[5 + offset]), + bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 2] = vmaxq_f16(zero, m[l + 2]); + m[l + 2] = vminq_f16(six, m[l + 2]); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + Store4DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x3UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[18]; + float16x8_t m[9]; + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + t[l] = vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2); + t[l + 6] = vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), + vmulq_n_f16(vsubq_f16(src[3 + offset], src[4 + offset]), 2)); + t[l + 12] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), src[5 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), bias_ptr); + m[l + 3] = vaddq_f16( + vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), vmulq_n_f16(vsubq_f16(t[3 + offset], t[4 + offset]), 2)), + bias_ptr); + m[l + 6] = vaddq_f16(vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), t[5 + offset]), bias_ptr); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + Store9DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x3ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[18]; + float16x8_t m[9]; + float16x8_t zero = vdupq_n_f16(0); + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + t[l] = vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2); + t[l + 6] = vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), + vmulq_n_f16(vsubq_f16(src[3 + offset], src[4 + offset]), 2)); + t[l + 12] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), src[5 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), bias_ptr); + m[l + 3] = vaddq_f16( + vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), vmulq_n_f16(vsubq_f16(t[3 + offset], t[4 + offset]), 2)), + bias_ptr); + m[l + 6] = vaddq_f16(vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), t[5 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 3] = vmaxq_f16(zero, m[l + 3]); + m[l + 6] = vmaxq_f16(zero, m[l + 6]); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + Store9DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x3Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[18]; + float16x8_t m[9]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + t[l] = vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2); + t[l + 6] = vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), + vmulq_n_f16(vsubq_f16(src[3 + offset], src[4 + offset]), 2)); + t[l + 12] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), src[5 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), bias_ptr); + m[l + 3] = vaddq_f16( + vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), vmulq_n_f16(vsubq_f16(t[3 + offset], t[4 + offset]), 2)), + bias_ptr); + m[l + 6] = vaddq_f16(vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), t[5 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 3] = vmaxq_f16(zero, m[l + 3]); + m[l + 3] = vminq_f16(six, m[l + 3]); + m[l + 6] = vmaxq_f16(zero, m[l + 6]); + m[l + 6] = vminq_f16(six, m[l + 6]); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + Store9DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x4UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[24]; + float16x8_t m[16]; + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp4 = vsubq_f16(src[3 + offset], src[4 + offset]); + t[l] = vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2); + t[l + 6] = vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)); + t[l + 12] = vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)); + t[l + 18] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)), src[5 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp4 = vsubq_f16(t[3 + offset], t[4 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), bias_ptr); + m[l + 4] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)), bias_ptr); + m[l + 8] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), bias_ptr); + m[l + 12] = vaddq_f16(vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)), t[5 + offset]), bias_ptr); + } + if (r_c == C8NUM && r_h == 4 && r_w == 4) { + Store16DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x4ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[24]; + float16x8_t m[16]; + float16x8_t zero = vdupq_n_f16(0); + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp4 = vsubq_f16(src[3 + offset], src[4 + offset]); + t[l] = vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2); + t[l + 6] = vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)); + t[l + 12] = vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)); + t[l + 18] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)), src[5 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp4 = vsubq_f16(t[3 + offset], t[4 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), bias_ptr); + m[l + 4] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)), bias_ptr); + m[l + 8] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), bias_ptr); + m[l + 12] = vaddq_f16(vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)), t[5 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 4] = vmaxq_f16(zero, m[l + 4]); + m[l + 8] = vmaxq_f16(zero, m[l + 8]); + m[l + 12] = vmaxq_f16(zero, m[l + 12]); + } + if (r_c == C8NUM && r_h == 4 && r_w == 4) { + Store16DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x4Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[24]; + float16x8_t m[16]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp4 = vsubq_f16(src[3 + offset], src[4 + offset]); + t[l] = vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2); + t[l + 6] = vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)); + t[l + 12] = vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)); + t[l + 18] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)), src[5 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp4 = vsubq_f16(t[3 + offset], t[4 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), bias_ptr); + m[l + 4] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)), bias_ptr); + m[l + 8] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), bias_ptr); + m[l + 12] = vaddq_f16(vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)), t[5 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 4] = vmaxq_f16(zero, m[l + 4]); + m[l + 4] = vminq_f16(six, m[l + 4]); + m[l + 8] = vmaxq_f16(zero, m[l + 8]); + m[l + 8] = vminq_f16(six, m[l + 8]); + m[l + 12] = vmaxq_f16(zero, m[l + 12]); + m[l + 12] = vminq_f16(six, m[l + 12]); + } + if (r_c == C8NUM && r_h == 4 && r_w == 4) { + Store16DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x5UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[30]; + float16x8_t m[25]; + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp4 = vsubq_f16(src[3 + offset], src[4 + offset]); + t[l] = vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2); + t[l + 6] = vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)); + t[l + 12] = vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)); + t[l + 18] = vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)); + t[l + 24] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 16)), src[5 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp4 = vsubq_f16(t[3 + offset], t[4 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), bias_ptr); + m[l + 5] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)), bias_ptr); + m[l + 10] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), bias_ptr); + m[l + 15] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)), bias_ptr); + m[l + 20] = vaddq_f16(vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 16)), t[5 + offset]), bias_ptr); + } + if (r_c == C8NUM && r_h == 5 && r_w == 5) { + Store25DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x5ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[30]; + float16x8_t m[25]; + float16x8_t zero = vdupq_n_f16(0); + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp4 = vsubq_f16(src[3 + offset], src[4 + offset]); + t[l] = vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2); + t[l + 6] = vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)); + t[l + 12] = vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)); + t[l + 18] = vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)); + t[l + 24] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 16)), src[5 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp4 = vsubq_f16(t[3 + offset], t[4 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), bias_ptr); + m[l + 5] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)), bias_ptr); + m[l + 10] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), bias_ptr); + m[l + 15] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)), bias_ptr); + m[l + 20] = vaddq_f16(vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 16)), t[5 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 5] = vmaxq_f16(zero, m[l + 5]); + m[l + 10] = vmaxq_f16(zero, m[l + 10]); + m[l + 15] = vmaxq_f16(zero, m[l + 15]); + m[l + 20] = vmaxq_f16(zero, m[l + 20]); + } + if (r_c == C8NUM && r_h == 5 && r_w == 5) { + Store25DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x5Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[30]; + float16x8_t m[25]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp4 = vsubq_f16(src[3 + offset], src[4 + offset]); + t[l] = vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2); + t[l + 6] = vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)); + t[l + 12] = vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)); + t[l + 18] = vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)); + t[l + 24] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 16)), src[5 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp4 = vsubq_f16(t[3 + offset], t[4 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), bias_ptr); + m[l + 5] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)), bias_ptr); + m[l + 10] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), bias_ptr); + m[l + 15] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)), bias_ptr); + m[l + 20] = vaddq_f16(vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 16)), t[5 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 5] = vmaxq_f16(zero, m[l + 5]); + m[l + 5] = vminq_f16(six, m[l + 5]); + m[l + 10] = vmaxq_f16(zero, m[l + 10]); + m[l + 10] = vminq_f16(six, m[l + 10]); + m[l + 15] = vmaxq_f16(zero, m[l + 15]); + m[l + 15] = vminq_f16(six, m[l + 15]); + m[l + 20] = vmaxq_f16(zero, m[l + 20]); + m[l + 20] = vminq_f16(six, m[l + 20]); + } + if (r_c == C8NUM && r_h == 5 && r_w == 5) { + Store25DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x2UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[16]; + float16x8_t m[4]; + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), src[7 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 2] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), t[7 + offset]), bias_ptr); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + Store4DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x2ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[16]; + float16x8_t m[4]; + float16x8_t zero = vdupq_n_f16(0); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), src[7 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 2] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), t[7 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 2] = vmaxq_f16(zero, m[l + 2]); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + Store4DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x2Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[16]; + float16x8_t m[4]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), src[7 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 2] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), t[7 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 2] = vmaxq_f16(zero, m[l + 2]); + m[l + 2] = vminq_f16(six, m[l + 2]); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + Store4DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x3UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[24]; + float16x8_t m[9]; + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), src[7 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 3] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 6] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), t[7 + offset]), bias_ptr); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + Store9DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x3ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[24]; + float16x8_t m[9]; + float16x8_t zero = vdupq_n_f16(0); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), src[7 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 3] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 6] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), t[7 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 3] = vmaxq_f16(zero, m[l + 3]); + m[l + 6] = vmaxq_f16(zero, m[l + 6]); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + Store9DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x3Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[24]; + float16x8_t m[9]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), src[7 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 3] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 6] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), t[7 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 3] = vmaxq_f16(zero, m[l + 3]); + m[l + 3] = vminq_f16(six, m[l + 3]); + m[l + 6] = vmaxq_f16(zero, m[l + 6]); + m[l + 6] = vminq_f16(six, m[l + 6]); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + Store9DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x4UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[32]; + float16x8_t m[16]; + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), src[7 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 4] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 8] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 12] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), t[7 + offset]), + bias_ptr); + } + if (r_c == C8NUM && r_h == 4 && r_w == 4) { + Store16DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x4ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[32]; + float16x8_t m[16]; + float16x8_t zero = vdupq_n_f16(0); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), src[7 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 4] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 8] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 12] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), t[7 + offset]), + bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 4] = vmaxq_f16(zero, m[l + 4]); + m[l + 8] = vmaxq_f16(zero, m[l + 8]); + m[l + 12] = vmaxq_f16(zero, m[l + 12]); + } + if (r_c == C8NUM && r_h == 4 && r_w == 4) { + Store16DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x4Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[32]; + float16x8_t m[16]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), src[7 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 4] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 8] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 12] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), t[7 + offset]), + bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 4] = vmaxq_f16(zero, m[l + 4]); + m[l + 4] = vminq_f16(six, m[l + 4]); + m[l + 8] = vmaxq_f16(zero, m[l + 8]); + m[l + 8] = vminq_f16(six, m[l + 8]); + m[l + 12] = vmaxq_f16(zero, m[l + 12]); + m[l + 12] = vminq_f16(six, m[l + 12]); + } + if (r_c == C8NUM && r_h == 4 && r_w == 4) { + Store16DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x5UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[40]; + float16x8_t m[25]; + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)); + t[l + 32] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), src[7 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 5] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 10] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 15] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 20] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), t[7 + offset]), + bias_ptr); + } + if (r_c == C8NUM && r_h == 5 && r_w == 5) { + Store25DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x5ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[40]; + float16x8_t m[25]; + float16x8_t zero = vdupq_n_f16(0); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)); + t[l + 32] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), src[7 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 5] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 10] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 15] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 20] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), t[7 + offset]), + bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 5] = vmaxq_f16(zero, m[l + 5]); + m[l + 10] = vmaxq_f16(zero, m[l + 10]); + m[l + 15] = vmaxq_f16(zero, m[l + 15]); + m[l + 20] = vmaxq_f16(zero, m[l + 20]); + } + if (r_c == C8NUM && r_h == 5 && r_w == 5) { + Store25DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x5Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[40]; + float16x8_t m[25]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)); + t[l + 32] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), src[7 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 5] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 10] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 15] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 20] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), t[7 + offset]), + bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 5] = vmaxq_f16(zero, m[l + 5]); + m[l + 5] = vminq_f16(six, m[l + 5]); + m[l + 10] = vmaxq_f16(zero, m[l + 10]); + m[l + 10] = vminq_f16(six, m[l + 10]); + m[l + 15] = vmaxq_f16(zero, m[l + 15]); + m[l + 15] = vminq_f16(six, m[l + 15]); + m[l + 20] = vmaxq_f16(zero, m[l + 20]); + m[l + 20] = vminq_f16(six, m[l + 20]); + } + if (r_c == C8NUM && r_h == 5 && r_w == 5) { + Store25DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + int z = 0; + if (r_c == 8) { + float16x8_t src[64]; + float16x8_t t[48]; + float16x8_t m[36]; + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)); + t[l + 32] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)); + t[l + 40] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), src[7 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 12] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 18] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 24] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), bias_ptr); + m[l + 30] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), t[7 + offset]), + bias_ptr); + } + if (r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + vst1q_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1q_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + vst1q_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1q_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1q_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1q_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + vst1q_f16(dst_data + dst_k_offset + k * out_c, m[k + m_k_offset]); + } + } + } + r_c -= 8; + } else if (r_c < 8 && r_c >= 4) { + float16x4_t src[64]; + float16x4_t t[48]; + float16x4_t m[36]; + Load64DataC4Fp16; + float16x4_t bias_ptr = vld1_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x4_t tmp1 = vadd_f16(src[1 + offset], src[2 + offset]); + float16x4_t tmp2 = vadd_f16(src[3 + offset], src[4 + offset]); + float16x4_t tmp3 = vadd_f16(src[5 + offset], src[6 + offset]); + float16x4_t tmp4 = vsub_f16(src[1 + offset], src[2 + offset]); + float16x4_t tmp5 = vsub_f16(src[3 + offset], src[4 + offset]); + float16x4_t tmp6 = vsub_f16(src[5 + offset], src[6 + offset]); + t[l] = vadd_f16(vadd_f16(vadd_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.5), tmp5), vmul_n_f16(tmp6, 1.5)); + t[l + 16] = vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.25), tmp2), vmul_n_f16(tmp3, 2.25)); + t[l + 24] = vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.125), tmp5), vmul_n_f16(tmp6, 3.375)); + t[l + 32] = vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.0625), tmp2), vmul_n_f16(tmp3, 5.0625)); + t[l + 40] = + vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.03125), tmp5), vmul_n_f16(tmp6, 7.59375)), src[7 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + float16x4_t tmp1 = vadd_f16(t[1 + offset], t[2 + offset]); + float16x4_t tmp2 = vadd_f16(t[3 + offset], t[4 + offset]); + float16x4_t tmp3 = vadd_f16(t[5 + offset], t[6 + offset]); + float16x4_t tmp4 = vsub_f16(t[1 + offset], t[2 + offset]); + float16x4_t tmp5 = vsub_f16(t[3 + offset], t[4 + offset]); + float16x4_t tmp6 = vsub_f16(t[5 + offset], t[6 + offset]); + m[l] = vadd_f16(vadd_f16(vadd_f16(vadd_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.5), tmp5), vmul_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 12] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.25), tmp2), vmul_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 18] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.125), tmp5), vmul_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 24] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.0625), tmp2), vmul_n_f16(tmp3, 5.0625)), bias_ptr); + m[l + 30] = vadd_f16( + vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.03125), tmp5), vmul_n_f16(tmp6, 7.59375)), t[7 + offset]), + bias_ptr); + } + if (r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + vst1_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + vst1_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + vst1_f16(dst_data + dst_k_offset + k * out_c, m[k + m_k_offset]); + } + } + } + z = 4; + } + for (; z < r_c; ++z) { + float16_t src[64]; + float16_t t[48]; + float16_t m[36]; + for (int k = 0; k < 16; ++k) { + src[k] = src_data[z + k * src_step]; + } + float16_t bias_ptr = bias_data[z]; + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16_t tmp1 = src[1 + offset] + src[2 + offset]; + float16_t tmp2 = src[3 + offset] + src[4 + offset]; + float16_t tmp3 = src[5 + offset] + src[6 + offset]; + float16_t tmp4 = src[1 + offset] - src[2 + offset]; + float16_t tmp5 = src[3 + offset] - src[4 + offset]; + float16_t tmp6 = src[5 + offset] - src[6 + offset]; + t[l] = src[offset] + tmp1 + tmp2 + tmp3; + t[l + 8] = tmp4 * 0.5f + tmp5 + tmp6 * 1.5f; + t[l + 16] = tmp1 * 0.25f + tmp2 + tmp3 * 2.25f; + t[l + 24] = tmp4 * 0.125f + tmp5 + tmp6 * 3.375f; + t[l + 32] = tmp1 * 0.0625f + tmp2 + tmp3 * 5.0625f; + t[l + 40] = tmp4 * 0.03125f + tmp5 + tmp6 * 7.59375f + src[7 + offset]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + float16_t tmp1 = t[1 + offset] + t[2 + offset]; + float16_t tmp2 = t[3 + offset] + t[4 + offset]; + float16_t tmp3 = t[5 + offset] + t[6 + offset]; + float16_t tmp4 = t[1 + offset] - t[2 + offset]; + float16_t tmp5 = t[3 + offset] - t[4 + offset]; + float16_t tmp6 = t[5 + offset] - t[6 + offset]; + m[l] = t[offset] + tmp1 + tmp2 + tmp3 + bias_ptr; + m[l + 6] = tmp4 * 0.5f + tmp5 + tmp6 * 1.5f + bias_ptr; + m[l + 12] = tmp1 * 0.25f + tmp2 + tmp3 * 2.25f + bias_ptr; + m[l + 18] = tmp4 * 0.125f + tmp5 + tmp6 * 3.375f + bias_ptr; + m[l + 24] = tmp1 * 0.0625f + tmp2 + tmp3 * 5.0625f + bias_ptr; + m[l + 30] = tmp4 * 0.03125f + tmp5 + tmp6 * 7.59375f + t[7 + offset] + bias_ptr; + } + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + dst_data[z + dst_k_offset + k * out_c] = m[k + m_k_offset]; + } + } + } +} + +void OutputTransform8x6ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + int z = 0; + if (r_c == 8) { + float16x8_t src[64]; + float16x8_t t[48]; + float16x8_t m[36]; + float16x8_t zero = vdupq_n_f16(0); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)); + t[l + 32] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)); + t[l + 40] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), src[7 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 12] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 18] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 24] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), bias_ptr); + m[l + 30] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), t[7 + offset]), + bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 6] = vmaxq_f16(zero, m[l + 6]); + m[l + 12] = vmaxq_f16(zero, m[l + 12]); + m[l + 18] = vmaxq_f16(zero, m[l + 18]); + m[l + 24] = vmaxq_f16(zero, m[l + 24]); + m[l + 30] = vmaxq_f16(zero, m[l + 30]); + } + if (r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + vst1q_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1q_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + vst1q_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1q_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1q_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1q_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + vst1q_f16(dst_data + dst_k_offset + k * out_c, m[k + m_k_offset]); + } + } + } + r_c -= 8; + } else if (r_c < 8 && r_c >= 4) { + float16x4_t src[64]; + float16x4_t t[48]; + float16x4_t m[36]; + float16x4_t zero = vdup_n_f16(0); + Load64DataC4Fp16; + float16x4_t bias_ptr = vld1_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x4_t tmp1 = vadd_f16(src[1 + offset], src[2 + offset]); + float16x4_t tmp2 = vadd_f16(src[3 + offset], src[4 + offset]); + float16x4_t tmp3 = vadd_f16(src[5 + offset], src[6 + offset]); + float16x4_t tmp4 = vsub_f16(src[1 + offset], src[2 + offset]); + float16x4_t tmp5 = vsub_f16(src[3 + offset], src[4 + offset]); + float16x4_t tmp6 = vsub_f16(src[5 + offset], src[6 + offset]); + t[l] = vadd_f16(vadd_f16(vadd_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.5), tmp5), vmul_n_f16(tmp6, 1.5)); + t[l + 16] = vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.25), tmp2), vmul_n_f16(tmp3, 2.25)); + t[l + 24] = vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.125), tmp5), vmul_n_f16(tmp6, 3.375)); + t[l + 32] = vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.0625), tmp2), vmul_n_f16(tmp3, 5.0625)); + t[l + 40] = + vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.03125), tmp5), vmul_n_f16(tmp6, 7.59375)), src[7 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + float16x4_t tmp1 = vadd_f16(t[1 + offset], t[2 + offset]); + float16x4_t tmp2 = vadd_f16(t[3 + offset], t[4 + offset]); + float16x4_t tmp3 = vadd_f16(t[5 + offset], t[6 + offset]); + float16x4_t tmp4 = vsub_f16(t[1 + offset], t[2 + offset]); + float16x4_t tmp5 = vsub_f16(t[3 + offset], t[4 + offset]); + float16x4_t tmp6 = vsub_f16(t[5 + offset], t[6 + offset]); + m[l] = vadd_f16(vadd_f16(vadd_f16(vadd_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.5), tmp5), vmul_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 12] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.25), tmp2), vmul_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 18] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.125), tmp5), vmul_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 24] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.0625), tmp2), vmul_n_f16(tmp3, 5.0625)), bias_ptr); + m[l + 30] = vadd_f16( + vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.03125), tmp5), vmul_n_f16(tmp6, 7.59375)), t[7 + offset]), + bias_ptr); + m[l] = vmax_f16(zero, m[l]); + m[l + 6] = vmax_f16(zero, m[l + 6]); + m[l + 12] = vmax_f16(zero, m[l + 12]); + m[l + 18] = vmax_f16(zero, m[l + 18]); + m[l + 24] = vmax_f16(zero, m[l + 24]); + m[l + 30] = vmax_f16(zero, m[l + 30]); + } + if (r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + vst1_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + vst1_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + vst1_f16(dst_data + dst_k_offset + k * out_c, m[k + m_k_offset]); + } + } + } + z = 4; + } + for (; z < r_c; ++z) { + float16_t src[64]; + float16_t t[48]; + float16_t m[36]; + for (int k = 0; k < 16; ++k) { + src[k] = src_data[z + k * src_step]; + } + float16_t bias_ptr = bias_data[z]; + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16_t tmp1 = src[1 + offset] + src[2 + offset]; + float16_t tmp2 = src[3 + offset] + src[4 + offset]; + float16_t tmp3 = src[5 + offset] + src[6 + offset]; + float16_t tmp4 = src[1 + offset] - src[2 + offset]; + float16_t tmp5 = src[3 + offset] - src[4 + offset]; + float16_t tmp6 = src[5 + offset] - src[6 + offset]; + t[l] = src[offset] + tmp1 + tmp2 + tmp3; + t[l + 8] = tmp4 * 0.5f + tmp5 + tmp6 * 1.5f; + t[l + 16] = tmp1 * 0.25f + tmp2 + tmp3 * 2.25f; + t[l + 24] = tmp4 * 0.125f + tmp5 + tmp6 * 3.375f; + t[l + 32] = tmp1 * 0.0625f + tmp2 + tmp3 * 5.0625f; + t[l + 40] = tmp4 * 0.03125f + tmp5 + tmp6 * 7.59375f + src[7 + offset]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + float16_t tmp1 = t[1 + offset] + t[2 + offset]; + float16_t tmp2 = t[3 + offset] + t[4 + offset]; + float16_t tmp3 = t[5 + offset] + t[6 + offset]; + float16_t tmp4 = t[1 + offset] - t[2 + offset]; + float16_t tmp5 = t[3 + offset] - t[4 + offset]; + float16_t tmp6 = t[5 + offset] - t[6 + offset]; + m[l] = t[offset] + tmp1 + tmp2 + tmp3 + bias_ptr; + m[l + 6] = tmp4 * 0.5f + tmp5 + tmp6 * 1.5f + bias_ptr; + m[l + 12] = tmp1 * 0.25f + tmp2 + tmp3 * 2.25f + bias_ptr; + m[l + 18] = tmp4 * 0.125f + tmp5 + tmp6 * 3.375f + bias_ptr; + m[l + 24] = tmp1 * 0.0625f + tmp2 + tmp3 * 5.0625f + bias_ptr; + m[l + 30] = tmp4 * 0.03125f + tmp5 + tmp6 * 7.59375f + t[7 + offset] + bias_ptr; + m[l] = m[l] > 0 ? m[l] : 0; + m[l + 6] = m[l + 6] > 0 ? m[l + 6] : 0; + m[l + 12] = m[l + 12] > 0 ? m[l + 12] : 0; + m[l + 18] = m[l + 18] > 0 ? m[l + 18] : 0; + m[l + 24] = m[l + 24] > 0 ? m[l + 24] : 0; + m[l + 30] = m[l + 30] > 0 ? m[l + 30] : 0; + } + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + dst_data[z + dst_k_offset + k * out_c] = m[k + m_k_offset]; + } + } + } +} + +void OutputTransform8x6Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + int z = 0; + if (r_c == 8) { + float16x8_t src[64]; + float16x8_t t[48]; + float16x8_t m[36]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)); + t[l + 32] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)); + t[l + 40] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), src[7 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 12] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 18] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 24] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), bias_ptr); + m[l + 30] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), t[7 + offset]), + bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 6] = vmaxq_f16(zero, m[l + 6]); + m[l + 6] = vminq_f16(six, m[l + 6]); + m[l + 12] = vmaxq_f16(zero, m[l + 12]); + m[l + 12] = vminq_f16(six, m[l + 12]); + m[l + 18] = vmaxq_f16(zero, m[l + 18]); + m[l + 18] = vminq_f16(six, m[l + 18]); + m[l + 24] = vmaxq_f16(zero, m[l + 24]); + m[l + 24] = vminq_f16(six, m[l + 24]); + m[l + 30] = vmaxq_f16(zero, m[l + 30]); + m[l + 30] = vminq_f16(six, m[l + 30]); + } + if (r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + vst1q_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1q_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + vst1q_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1q_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1q_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1q_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + vst1q_f16(dst_data + dst_k_offset + k * out_c, m[k + m_k_offset]); + } + } + } + r_c -= 8; + } else if (r_c < 8 && r_c >= 4) { + float16x4_t src[64]; + float16x4_t t[48]; + float16x4_t m[36]; + float16x4_t zero = vdup_n_f16(0); + float16x4_t six = vdup_n_f16(6); + Load64DataC4Fp16; + float16x4_t bias_ptr = vld1_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x4_t tmp1 = vadd_f16(src[1 + offset], src[2 + offset]); + float16x4_t tmp2 = vadd_f16(src[3 + offset], src[4 + offset]); + float16x4_t tmp3 = vadd_f16(src[5 + offset], src[6 + offset]); + float16x4_t tmp4 = vsub_f16(src[1 + offset], src[2 + offset]); + float16x4_t tmp5 = vsub_f16(src[3 + offset], src[4 + offset]); + float16x4_t tmp6 = vsub_f16(src[5 + offset], src[6 + offset]); + t[l] = vadd_f16(vadd_f16(vadd_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.5), tmp5), vmul_n_f16(tmp6, 1.5)); + t[l + 16] = vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.25), tmp2), vmul_n_f16(tmp3, 2.25)); + t[l + 24] = vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.125), tmp5), vmul_n_f16(tmp6, 3.375)); + t[l + 32] = vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.0625), tmp2), vmul_n_f16(tmp3, 5.0625)); + t[l + 40] = + vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.03125), tmp5), vmul_n_f16(tmp6, 7.59375)), src[7 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + float16x4_t tmp1 = vadd_f16(t[1 + offset], t[2 + offset]); + float16x4_t tmp2 = vadd_f16(t[3 + offset], t[4 + offset]); + float16x4_t tmp3 = vadd_f16(t[5 + offset], t[6 + offset]); + float16x4_t tmp4 = vsub_f16(t[1 + offset], t[2 + offset]); + float16x4_t tmp5 = vsub_f16(t[3 + offset], t[4 + offset]); + float16x4_t tmp6 = vsub_f16(t[5 + offset], t[6 + offset]); + m[l] = vadd_f16(vadd_f16(vadd_f16(vadd_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.5), tmp5), vmul_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 12] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.25), tmp2), vmul_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 18] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.125), tmp5), vmul_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 24] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.0625), tmp2), vmul_n_f16(tmp3, 5.0625)), bias_ptr); + m[l + 30] = vadd_f16( + vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.03125), tmp5), vmul_n_f16(tmp6, 7.59375)), t[7 + offset]), + bias_ptr); + m[l] = vmax_f16(zero, m[l]); + m[l] = vmin_f16(six, m[l]); + m[l + 6] = vmax_f16(zero, m[l + 6]); + m[l + 6] = vmin_f16(six, m[l + 6]); + m[l + 12] = vmax_f16(zero, m[l + 12]); + m[l + 12] = vmin_f16(six, m[l + 12]); + m[l + 18] = vmax_f16(zero, m[l + 18]); + m[l + 18] = vmin_f16(six, m[l + 18]); + m[l + 24] = vmax_f16(zero, m[l + 24]); + m[l + 24] = vmin_f16(six, m[l + 24]); + m[l + 30] = vmax_f16(zero, m[l + 30]); + m[l + 30] = vmin_f16(six, m[l + 30]); + } + if (r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + vst1_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + vst1_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + vst1_f16(dst_data + dst_k_offset + k * out_c, m[k + m_k_offset]); + } + } + } + z = 4; + } + for (; z < r_c; ++z) { + float16_t src[64]; + float16_t t[48]; + float16_t m[36]; + for (int k = 0; k < 16; ++k) { + src[k] = src_data[z + k * src_step]; + } + float16_t bias_ptr = bias_data[z]; + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16_t tmp1 = src[1 + offset] + src[2 + offset]; + float16_t tmp2 = src[3 + offset] + src[4 + offset]; + float16_t tmp3 = src[5 + offset] + src[6 + offset]; + float16_t tmp4 = src[1 + offset] - src[2 + offset]; + float16_t tmp5 = src[3 + offset] - src[4 + offset]; + float16_t tmp6 = src[5 + offset] - src[6 + offset]; + t[l] = src[offset] + tmp1 + tmp2 + tmp3; + t[l + 8] = tmp4 * 0.5f + tmp5 + tmp6 * 1.5f; + t[l + 16] = tmp1 * 0.25f + tmp2 + tmp3 * 2.25f; + t[l + 24] = tmp4 * 0.125f + tmp5 + tmp6 * 3.375f; + t[l + 32] = tmp1 * 0.0625f + tmp2 + tmp3 * 5.0625f; + t[l + 40] = tmp4 * 0.03125f + tmp5 + tmp6 * 7.59375f + src[7 + offset]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + float16_t tmp1 = t[1 + offset] + t[2 + offset]; + float16_t tmp2 = t[3 + offset] + t[4 + offset]; + float16_t tmp3 = t[5 + offset] + t[6 + offset]; + float16_t tmp4 = t[1 + offset] - t[2 + offset]; + float16_t tmp5 = t[3 + offset] - t[4 + offset]; + float16_t tmp6 = t[5 + offset] - t[6 + offset]; + m[l] = t[offset] + tmp1 + tmp2 + tmp3 + bias_ptr; + m[l + 6] = tmp4 * 0.5f + tmp5 + tmp6 * 1.5f + bias_ptr; + m[l + 12] = tmp1 * 0.25f + tmp2 + tmp3 * 2.25f + bias_ptr; + m[l + 18] = tmp4 * 0.125f + tmp5 + tmp6 * 3.375f + bias_ptr; + m[l + 24] = tmp1 * 0.0625f + tmp2 + tmp3 * 5.0625f + bias_ptr; + m[l + 30] = tmp4 * 0.03125f + tmp5 + tmp6 * 7.59375f + t[7 + offset] + bias_ptr; + m[l] = m[l] > 0 ? m[l] : 0; + m[l] = m[l] > 0 ? m[l] : 0; + m[l + 6] = m[l + 6] > 0 ? m[l + 6] : 0; + m[l + 6] = m[l + 6] < 6 ? m[l + 6] : 6; + m[l + 12] = m[l + 12] > 0 ? m[l + 12] : 0; + m[l + 12] = m[l + 12] < 6 ? m[l + 12] : 6; + m[l + 18] = m[l + 18] > 0 ? m[l + 18] : 0; + m[l + 18] = m[l + 18] < 6 ? m[l + 18] : 6; + m[l + 24] = m[l + 24] > 0 ? m[l + 24] : 0; + m[l + 24] = m[l + 24] < 6 ? m[l + 24] : 6; + m[l + 30] = m[l + 30] > 0 ? m[l + 30] : 0; + m[l + 30] = m[l + 30] < 6 ? m[l + 30] : 6; + } + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + dst_data[z + dst_k_offset + k * out_c] = m[k + m_k_offset]; + } + } + } +} + +void OutputTransform8x7UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[56]; + float16x8_t m[49]; + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)); + t[l + 32] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)); + t[l + 40] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)); + t[l + 48] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.015625), tmp2), vmulq_n_f16(tmp3, 11.390625)), src[7 + offset]); + } + for (int l = 0; l < 7; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 7] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 14] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 21] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 28] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), bias_ptr); + m[l + 35] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), bias_ptr); + m[l + 42] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.015625), tmp2), vmulq_n_f16(tmp3, 11.390625)), t[7 + offset]), + bias_ptr); + } + if (r_c == C8NUM && r_h == 7 && r_w == 7) { + for (int i = 0; i < 7; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 7; + vst1q_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1q_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + vst1q_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1q_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1q_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1q_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + vst1q_f16(dst_data + dst_k_offset + 6 * out_c, m[m_k_offset + 6]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 7; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x7ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[56]; + float16x8_t m[49]; + float16x8_t zero = vdupq_n_f16(0); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)); + t[l + 32] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)); + t[l + 40] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)); + t[l + 48] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.015625), tmp2), vmulq_n_f16(tmp3, 11.390625)), src[7 + offset]); + } + for (int l = 0; l < 7; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 7] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 14] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 21] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 28] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), bias_ptr); + m[l + 35] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), bias_ptr); + m[l + 42] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.015625), tmp2), vmulq_n_f16(tmp3, 11.390625)), t[7 + offset]), + bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 7] = vmaxq_f16(zero, m[l + 7]); + m[l + 14] = vmaxq_f16(zero, m[l + 14]); + m[l + 21] = vmaxq_f16(zero, m[l + 21]); + m[l + 28] = vmaxq_f16(zero, m[l + 28]); + m[l + 35] = vmaxq_f16(zero, m[l + 35]); + m[l + 42] = vmaxq_f16(zero, m[l + 42]); + } + if (r_c == C8NUM && r_h == 7 && r_w == 7) { + for (int i = 0; i < 7; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 7; + vst1q_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1q_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + vst1q_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1q_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1q_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1q_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + vst1q_f16(dst_data + dst_k_offset + 6 * out_c, m[m_k_offset + 6]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 7; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x7Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[56]; + float16x8_t m[49]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)); + t[l + 32] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)); + t[l + 40] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)); + t[l + 48] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.015625), tmp2), vmulq_n_f16(tmp3, 11.390625)), src[7 + offset]); + } + for (int l = 0; l < 7; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 7] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 14] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 21] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 28] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), bias_ptr); + m[l + 35] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), bias_ptr); + m[l + 42] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.015625), tmp2), vmulq_n_f16(tmp3, 11.390625)), t[7 + offset]), + bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 7] = vmaxq_f16(zero, m[l + 7]); + m[l + 7] = vminq_f16(six, m[l + 7]); + m[l + 14] = vmaxq_f16(zero, m[l + 14]); + m[l + 14] = vminq_f16(six, m[l + 14]); + m[l + 21] = vmaxq_f16(zero, m[l + 21]); + m[l + 21] = vminq_f16(six, m[l + 21]); + m[l + 28] = vmaxq_f16(zero, m[l + 28]); + m[l + 28] = vminq_f16(six, m[l + 28]); + m[l + 35] = vmaxq_f16(zero, m[l + 35]); + m[l + 35] = vminq_f16(six, m[l + 35]); + m[l + 42] = vmaxq_f16(zero, m[l + 42]); + m[l + 42] = vminq_f16(six, m[l + 42]); + } + if (r_c == C8NUM && r_h == 7 && r_w == 7) { + for (int i = 0; i < 7; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 7; + vst1q_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1q_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + vst1q_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1q_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1q_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1q_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + vst1q_f16(dst_data + dst_k_offset + 6 * out_c, m[m_k_offset + 6]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 7; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +int SelectOutputUnitFp16(const ConvParameter *conv_param) { + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int in_c = conv_param->input_channel_; + int out_w = conv_param->output_w_; + int out_h = conv_param->output_h_; + int out_c = conv_param->output_channel_; + int unit2 = UP_DIV(out_w * out_h, C16NUM * conv_param->op_parameter_.thread_num_); + int max_out_unit = (int)(sqrtf((float)unit2)); + max_out_unit = max_out_unit < MAX_UNIT_FP16 ? max_out_unit : MAX_UNIT_FP16; + max_out_unit = max_out_unit > MIN_UNIT_FP16 ? max_out_unit : MIN_UNIT_FP16; + + int unit = 0; + float max_rate = 0.0f; + float common_cost = (float)out_h * out_w * in_c * out_c * kernel_h * kernel_w; + + for (int i = MIN_UNIT_FP16; i <= max_out_unit; ++i) { + int input_unit = i + kernel_w - 1; + if (!GetOutputTransFp16Func(input_unit, i, ActType_No)) { + continue; + } + float penalty = ((float)input_unit * input_unit) / ((float)kernel_h * kernel_w) * 0.12f; + float wino_cost = ((2 + out_c) * (float)input_unit * input_unit * in_c + ((float)input_unit + i) * i * out_c) * + UP_DIV(out_w, i) * UP_DIV(out_h, i); + float reduce_rate = common_cost / wino_cost - penalty; + if (reduce_rate > max_rate) { + max_rate = reduce_rate; + unit = i; + } + } + if (max_rate < 1.0f) { + return 1; + } + // If output_unit is 1, then it is conventional convolution + return unit; +} + +void CheckIfUseWinogradFp16(bool *use_winograd, int *output_unit, const ConvParameter *conv_param) { + if (conv_param->kernel_w_ == conv_param->kernel_h_ && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 && + conv_param->stride_h_ == 1 && conv_param->stride_w_ == 1) { + *output_unit = SelectOutputUnitFp16(conv_param); + if (*output_unit > 1) { + *use_winograd = true; + } + } else { + *use_winograd = false; + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/winograd_utils_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/winograd_utils_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..d50510d73e2dae57d68a3f49972f54acc07a71dc --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/winograd_utils_fp16.h @@ -0,0 +1,163 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_WINOGRAD_UTILS_H_ +#define NNACL_FP16_WINOGRAD_UTILS_H_ + +#include "nnacl/conv_parameter.h" +#include "nnacl/op_base.h" +#include "nnacl/fp16/winograd_utils_fp16_macro.h" + +#define MAX_LEN 256 + +#ifdef __cplusplus +extern "C" { +#endif +typedef void (*InputTransFp16Func)(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, + int real_c); + +typedef void (*InputTransStepFp16Func)(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, + int dst_row_step); + +typedef void (*InputTransPackFp16Func)(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, + int real_c); + +typedef void (*OutputTransFp16Func)(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); + +typedef struct TransFp16FuncList { + InputTransFp16Func in_func_; + InputTransStepFp16Func in_step_func_; + InputTransPackFp16Func in_pack_func_; + OutputTransFp16Func out_func_; +} TransFp16FuncList; + +InputTransFp16Func GetInputTransFp16Func(int input_unit); + +#ifdef ENABLE_ARM64 +InputTransStepFp16Func GetInputTransStepFp16Func(int input_unit); + +InputTransPackFp16Func GetInputTransPackFp16Func(int input_unit); +#endif + +void InputTransform4x4UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform6x6UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform8x8UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform4x4StepFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, + int dst_row_step); + +void InputTransform6x6StepFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, + int dst_row_step); + +void InputTransform8x8StepFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, + int dst_row_step); + +#ifdef ENABLE_ARM64 +void InputTransform4x4Pack16Fp16(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform6x6Pack16Fp16(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform8x8Pack16Fp16(float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c); +#endif + +OutputTransFp16Func GetOutputTransFp16Func(int input_unit, int output_unit, ActType act_type); + +void OutputTransform4x2UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x2ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x2Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x3UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x3ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x3Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); + +void OutputTransform6x2UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x2ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x2Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x3UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x3ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x3Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x4UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x4ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x4Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x5UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x5ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x5Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); + +void OutputTransform8x2UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x2ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x2Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x3UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x3ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x3Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x4UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x4ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x4Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x5UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x5ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x5Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x6ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x6Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x7UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x7ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x7Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); + +int SelectOutputUnitFp16(const ConvParameter *conv_param); + +void CheckIfUseWinogradFp16(bool *use_winograd, int *output_unit, const ConvParameter *conv_param); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_WINOGRAD_UTILS_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16/winograd_utils_fp16_macro.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/winograd_utils_fp16_macro.h new file mode 100644 index 0000000000000000000000000000000000000000..dd2d8f7825be0b80b214bb3716df23e7fe6d95c8 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16/winograd_utils_fp16_macro.h @@ -0,0 +1,437 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_WINOGRAD_UTILS_MACRO_H_ +#define NNACL_FP16_WINOGRAD_UTILS_MACRO_H_ + +#include "nnacl/intrinsics/ms_simd_instructions_fp16.h" + +#ifdef __cplusplus +extern "C" { +#endif +#define Load16DataFp16 \ + src[0] = vld1q_f16(src_data + 0 * src_step); \ + src[1] = vld1q_f16(src_data + 1 * src_step); \ + src[2] = vld1q_f16(src_data + 2 * src_step); \ + src[3] = vld1q_f16(src_data + 3 * src_step); \ + src[4] = vld1q_f16(src_data + 4 * src_step); \ + src[5] = vld1q_f16(src_data + 5 * src_step); \ + src[6] = vld1q_f16(src_data + 6 * src_step); \ + src[7] = vld1q_f16(src_data + 7 * src_step); \ + src[8] = vld1q_f16(src_data + 8 * src_step); \ + src[9] = vld1q_f16(src_data + 9 * src_step); \ + src[10] = vld1q_f16(src_data + 10 * src_step); \ + src[11] = vld1q_f16(src_data + 11 * src_step); \ + src[12] = vld1q_f16(src_data + 12 * src_step); \ + src[13] = vld1q_f16(src_data + 13 * src_step); \ + src[14] = vld1q_f16(src_data + 14 * src_step); \ + src[15] = vld1q_f16(src_data + 15 * src_step); + +#define Load16DataC4Fp16 \ + src[0] = vld1_f16(src_data + 0 * src_step); \ + src[1] = vld1_f16(src_data + 1 * src_step); \ + src[2] = vld1_f16(src_data + 2 * src_step); \ + src[3] = vld1_f16(src_data + 3 * src_step); \ + src[4] = vld1_f16(src_data + 4 * src_step); \ + src[5] = vld1_f16(src_data + 5 * src_step); \ + src[6] = vld1_f16(src_data + 6 * src_step); \ + src[7] = vld1_f16(src_data + 7 * src_step); \ + src[8] = vld1_f16(src_data + 8 * src_step); \ + src[9] = vld1_f16(src_data + 9 * src_step); \ + src[10] = vld1_f16(src_data + 10 * src_step); \ + src[11] = vld1_f16(src_data + 11 * src_step); \ + src[12] = vld1_f16(src_data + 12 * src_step); \ + src[13] = vld1_f16(src_data + 13 * src_step); \ + src[14] = vld1_f16(src_data + 14 * src_step); \ + src[15] = vld1_f16(src_data + 15 * src_step); + +#define Load36DataFp16 \ + src[0] = vld1q_f16(src_data + 0 * src_step); \ + src[1] = vld1q_f16(src_data + 1 * src_step); \ + src[2] = vld1q_f16(src_data + 2 * src_step); \ + src[3] = vld1q_f16(src_data + 3 * src_step); \ + src[4] = vld1q_f16(src_data + 4 * src_step); \ + src[5] = vld1q_f16(src_data + 5 * src_step); \ + src[6] = vld1q_f16(src_data + 6 * src_step); \ + src[7] = vld1q_f16(src_data + 7 * src_step); \ + src[8] = vld1q_f16(src_data + 8 * src_step); \ + src[9] = vld1q_f16(src_data + 9 * src_step); \ + src[10] = vld1q_f16(src_data + 10 * src_step); \ + src[11] = vld1q_f16(src_data + 11 * src_step); \ + src[12] = vld1q_f16(src_data + 12 * src_step); \ + src[13] = vld1q_f16(src_data + 13 * src_step); \ + src[14] = vld1q_f16(src_data + 14 * src_step); \ + src[15] = vld1q_f16(src_data + 15 * src_step); \ + src[16] = vld1q_f16(src_data + 16 * src_step); \ + src[17] = vld1q_f16(src_data + 17 * src_step); \ + src[18] = vld1q_f16(src_data + 18 * src_step); \ + src[19] = vld1q_f16(src_data + 19 * src_step); \ + src[20] = vld1q_f16(src_data + 20 * src_step); \ + src[21] = vld1q_f16(src_data + 21 * src_step); \ + src[22] = vld1q_f16(src_data + 22 * src_step); \ + src[23] = vld1q_f16(src_data + 23 * src_step); \ + src[24] = vld1q_f16(src_data + 24 * src_step); \ + src[25] = vld1q_f16(src_data + 25 * src_step); \ + src[26] = vld1q_f16(src_data + 26 * src_step); \ + src[27] = vld1q_f16(src_data + 27 * src_step); \ + src[28] = vld1q_f16(src_data + 28 * src_step); \ + src[29] = vld1q_f16(src_data + 29 * src_step); \ + src[30] = vld1q_f16(src_data + 30 * src_step); \ + src[31] = vld1q_f16(src_data + 31 * src_step); \ + src[32] = vld1q_f16(src_data + 32 * src_step); \ + src[33] = vld1q_f16(src_data + 33 * src_step); \ + src[34] = vld1q_f16(src_data + 34 * src_step); \ + src[35] = vld1q_f16(src_data + 35 * src_step); + +#define Load36DataC4Fp16 \ + src[0] = vld1_f16(src_data + 0 * src_step); \ + src[1] = vld1_f16(src_data + 1 * src_step); \ + src[2] = vld1_f16(src_data + 2 * src_step); \ + src[3] = vld1_f16(src_data + 3 * src_step); \ + src[4] = vld1_f16(src_data + 4 * src_step); \ + src[5] = vld1_f16(src_data + 5 * src_step); \ + src[6] = vld1_f16(src_data + 6 * src_step); \ + src[7] = vld1_f16(src_data + 7 * src_step); \ + src[8] = vld1_f16(src_data + 8 * src_step); \ + src[9] = vld1_f16(src_data + 9 * src_step); \ + src[10] = vld1_f16(src_data + 10 * src_step); \ + src[11] = vld1_f16(src_data + 11 * src_step); \ + src[12] = vld1_f16(src_data + 12 * src_step); \ + src[13] = vld1_f16(src_data + 13 * src_step); \ + src[14] = vld1_f16(src_data + 14 * src_step); \ + src[15] = vld1_f16(src_data + 15 * src_step); \ + src[16] = vld1_f16(src_data + 16 * src_step); \ + src[17] = vld1_f16(src_data + 17 * src_step); \ + src[18] = vld1_f16(src_data + 18 * src_step); \ + src[19] = vld1_f16(src_data + 19 * src_step); \ + src[20] = vld1_f16(src_data + 20 * src_step); \ + src[21] = vld1_f16(src_data + 21 * src_step); \ + src[22] = vld1_f16(src_data + 22 * src_step); \ + src[23] = vld1_f16(src_data + 23 * src_step); \ + src[24] = vld1_f16(src_data + 24 * src_step); \ + src[25] = vld1_f16(src_data + 25 * src_step); \ + src[26] = vld1_f16(src_data + 26 * src_step); \ + src[27] = vld1_f16(src_data + 27 * src_step); \ + src[28] = vld1_f16(src_data + 28 * src_step); \ + src[29] = vld1_f16(src_data + 29 * src_step); \ + src[30] = vld1_f16(src_data + 30 * src_step); \ + src[31] = vld1_f16(src_data + 31 * src_step); \ + src[32] = vld1_f16(src_data + 32 * src_step); \ + src[33] = vld1_f16(src_data + 33 * src_step); \ + src[34] = vld1_f16(src_data + 34 * src_step); \ + src[35] = vld1_f16(src_data + 35 * src_step); + +#define Load64DataFp16 \ + src[0] = vld1q_f16(src_data + 0 * src_step); \ + src[1] = vld1q_f16(src_data + 1 * src_step); \ + src[2] = vld1q_f16(src_data + 2 * src_step); \ + src[3] = vld1q_f16(src_data + 3 * src_step); \ + src[4] = vld1q_f16(src_data + 4 * src_step); \ + src[5] = vld1q_f16(src_data + 5 * src_step); \ + src[6] = vld1q_f16(src_data + 6 * src_step); \ + src[7] = vld1q_f16(src_data + 7 * src_step); \ + src[8] = vld1q_f16(src_data + 8 * src_step); \ + src[9] = vld1q_f16(src_data + 9 * src_step); \ + src[10] = vld1q_f16(src_data + 10 * src_step); \ + src[11] = vld1q_f16(src_data + 11 * src_step); \ + src[12] = vld1q_f16(src_data + 12 * src_step); \ + src[13] = vld1q_f16(src_data + 13 * src_step); \ + src[14] = vld1q_f16(src_data + 14 * src_step); \ + src[15] = vld1q_f16(src_data + 15 * src_step); \ + src[16] = vld1q_f16(src_data + 16 * src_step); \ + src[17] = vld1q_f16(src_data + 17 * src_step); \ + src[18] = vld1q_f16(src_data + 18 * src_step); \ + src[19] = vld1q_f16(src_data + 19 * src_step); \ + src[20] = vld1q_f16(src_data + 20 * src_step); \ + src[21] = vld1q_f16(src_data + 21 * src_step); \ + src[22] = vld1q_f16(src_data + 22 * src_step); \ + src[23] = vld1q_f16(src_data + 23 * src_step); \ + src[24] = vld1q_f16(src_data + 24 * src_step); \ + src[25] = vld1q_f16(src_data + 25 * src_step); \ + src[26] = vld1q_f16(src_data + 26 * src_step); \ + src[27] = vld1q_f16(src_data + 27 * src_step); \ + src[28] = vld1q_f16(src_data + 28 * src_step); \ + src[29] = vld1q_f16(src_data + 29 * src_step); \ + src[30] = vld1q_f16(src_data + 30 * src_step); \ + src[31] = vld1q_f16(src_data + 31 * src_step); \ + src[32] = vld1q_f16(src_data + 32 * src_step); \ + src[33] = vld1q_f16(src_data + 33 * src_step); \ + src[34] = vld1q_f16(src_data + 34 * src_step); \ + src[35] = vld1q_f16(src_data + 35 * src_step); \ + src[36] = vld1q_f16(src_data + 36 * src_step); \ + src[37] = vld1q_f16(src_data + 37 * src_step); \ + src[38] = vld1q_f16(src_data + 38 * src_step); \ + src[39] = vld1q_f16(src_data + 39 * src_step); \ + src[40] = vld1q_f16(src_data + 40 * src_step); \ + src[41] = vld1q_f16(src_data + 41 * src_step); \ + src[42] = vld1q_f16(src_data + 42 * src_step); \ + src[43] = vld1q_f16(src_data + 43 * src_step); \ + src[44] = vld1q_f16(src_data + 44 * src_step); \ + src[45] = vld1q_f16(src_data + 45 * src_step); \ + src[46] = vld1q_f16(src_data + 46 * src_step); \ + src[47] = vld1q_f16(src_data + 47 * src_step); \ + src[48] = vld1q_f16(src_data + 48 * src_step); \ + src[49] = vld1q_f16(src_data + 49 * src_step); \ + src[50] = vld1q_f16(src_data + 50 * src_step); \ + src[51] = vld1q_f16(src_data + 51 * src_step); \ + src[52] = vld1q_f16(src_data + 52 * src_step); \ + src[53] = vld1q_f16(src_data + 53 * src_step); \ + src[54] = vld1q_f16(src_data + 54 * src_step); \ + src[55] = vld1q_f16(src_data + 55 * src_step); \ + src[56] = vld1q_f16(src_data + 56 * src_step); \ + src[57] = vld1q_f16(src_data + 57 * src_step); \ + src[58] = vld1q_f16(src_data + 58 * src_step); \ + src[59] = vld1q_f16(src_data + 59 * src_step); \ + src[60] = vld1q_f16(src_data + 60 * src_step); \ + src[61] = vld1q_f16(src_data + 61 * src_step); \ + src[62] = vld1q_f16(src_data + 62 * src_step); \ + src[63] = vld1q_f16(src_data + 63 * src_step); + +#define Load64DataC4Fp16 \ + src[0] = vld1_f16(src_data + 0 * src_step); \ + src[1] = vld1_f16(src_data + 1 * src_step); \ + src[2] = vld1_f16(src_data + 2 * src_step); \ + src[3] = vld1_f16(src_data + 3 * src_step); \ + src[4] = vld1_f16(src_data + 4 * src_step); \ + src[5] = vld1_f16(src_data + 5 * src_step); \ + src[6] = vld1_f16(src_data + 6 * src_step); \ + src[7] = vld1_f16(src_data + 7 * src_step); \ + src[8] = vld1_f16(src_data + 8 * src_step); \ + src[9] = vld1_f16(src_data + 9 * src_step); \ + src[10] = vld1_f16(src_data + 10 * src_step); \ + src[11] = vld1_f16(src_data + 11 * src_step); \ + src[12] = vld1_f16(src_data + 12 * src_step); \ + src[13] = vld1_f16(src_data + 13 * src_step); \ + src[14] = vld1_f16(src_data + 14 * src_step); \ + src[15] = vld1_f16(src_data + 15 * src_step); \ + src[16] = vld1_f16(src_data + 16 * src_step); \ + src[17] = vld1_f16(src_data + 17 * src_step); \ + src[18] = vld1_f16(src_data + 18 * src_step); \ + src[19] = vld1_f16(src_data + 19 * src_step); \ + src[20] = vld1_f16(src_data + 20 * src_step); \ + src[21] = vld1_f16(src_data + 21 * src_step); \ + src[22] = vld1_f16(src_data + 22 * src_step); \ + src[23] = vld1_f16(src_data + 23 * src_step); \ + src[24] = vld1_f16(src_data + 24 * src_step); \ + src[25] = vld1_f16(src_data + 25 * src_step); \ + src[26] = vld1_f16(src_data + 26 * src_step); \ + src[27] = vld1_f16(src_data + 27 * src_step); \ + src[28] = vld1_f16(src_data + 28 * src_step); \ + src[29] = vld1_f16(src_data + 29 * src_step); \ + src[30] = vld1_f16(src_data + 30 * src_step); \ + src[31] = vld1_f16(src_data + 31 * src_step); \ + src[32] = vld1_f16(src_data + 32 * src_step); \ + src[33] = vld1_f16(src_data + 33 * src_step); \ + src[34] = vld1_f16(src_data + 34 * src_step); \ + src[35] = vld1_f16(src_data + 35 * src_step); \ + src[36] = vld1_f16(src_data + 36 * src_step); \ + src[37] = vld1_f16(src_data + 37 * src_step); \ + src[38] = vld1_f16(src_data + 38 * src_step); \ + src[39] = vld1_f16(src_data + 39 * src_step); \ + src[40] = vld1_f16(src_data + 40 * src_step); \ + src[41] = vld1_f16(src_data + 41 * src_step); \ + src[42] = vld1_f16(src_data + 42 * src_step); \ + src[43] = vld1_f16(src_data + 43 * src_step); \ + src[44] = vld1_f16(src_data + 44 * src_step); \ + src[45] = vld1_f16(src_data + 45 * src_step); \ + src[46] = vld1_f16(src_data + 46 * src_step); \ + src[47] = vld1_f16(src_data + 47 * src_step); \ + src[48] = vld1_f16(src_data + 48 * src_step); \ + src[49] = vld1_f16(src_data + 49 * src_step); \ + src[50] = vld1_f16(src_data + 50 * src_step); \ + src[51] = vld1_f16(src_data + 51 * src_step); \ + src[52] = vld1_f16(src_data + 52 * src_step); \ + src[53] = vld1_f16(src_data + 53 * src_step); \ + src[54] = vld1_f16(src_data + 54 * src_step); \ + src[55] = vld1_f16(src_data + 55 * src_step); \ + src[56] = vld1_f16(src_data + 56 * src_step); \ + src[57] = vld1_f16(src_data + 57 * src_step); \ + src[58] = vld1_f16(src_data + 58 * src_step); \ + src[59] = vld1_f16(src_data + 59 * src_step); \ + src[60] = vld1_f16(src_data + 60 * src_step); \ + src[61] = vld1_f16(src_data + 61 * src_step); \ + src[62] = vld1_f16(src_data + 62 * src_step); \ + src[63] = vld1_f16(src_data + 63 * src_step); + +#define LOAD_LINE_DATA_FP16(line) \ + float16x8_t s##line##0 = vld1q_f16(src_ptr + line * src_point_stride + 0 * pack_tile); \ + float16x8_t s##line##1 = vld1q_f16(src_ptr + line * src_point_stride + 1 * pack_tile); + +#define TRANSPOSE_16x8 \ + float16x8_t s0 = vld1q_f16(src_ptr + 0 * pack_tile); \ + float16x8_t s2 = vld1q_f16(src_ptr + 1 * pack_tile); \ + float16x8_t s4 = vld1q_f16(src_ptr + 2 * pack_tile); \ + float16x8_t s6 = vld1q_f16(src_ptr + 3 * pack_tile); \ + float16x8_t s8 = vld1q_f16(src_ptr + 4 * pack_tile); \ + float16x8_t s10 = vld1q_f16(src_ptr + 5 * pack_tile); \ + float16x8_t s12 = vld1q_f16(src_ptr + 6 * pack_tile); \ + float16x8_t s14 = vld1q_f16(src_ptr + 7 * pack_tile); \ + float16x8_t s1 = vld1q_f16(src_ptr + 8 * pack_tile); \ + float16x8_t s3 = vld1q_f16(src_ptr + 9 * pack_tile); \ + float16x8_t s5 = vld1q_f16(src_ptr + 10 * pack_tile); \ + float16x8_t s7 = vld1q_f16(src_ptr + 11 * pack_tile); \ + float16x8_t s9 = vld1q_f16(src_ptr + 12 * pack_tile); \ + float16x8_t s11 = vld1q_f16(src_ptr + 13 * pack_tile); \ + float16x8_t s13 = vld1q_f16(src_ptr + 14 * pack_tile); \ + float16x8_t s15 = vld1q_f16(src_ptr + 15 * pack_tile); \ + transpose8(&s0, &s2, &s4, &s6, &s8, &s10, &s12, &s14); \ + transpose8(&s1, &s3, &s5, &s7, &s9, &s11, &s13, &s15); \ + vst1q_f16(src_ptr + 0 * pack_tile, s0); \ + vst1q_f16(src_ptr + 1 * pack_tile, s1); \ + vst1q_f16(src_ptr + 2 * pack_tile, s2); \ + vst1q_f16(src_ptr + 3 * pack_tile, s3); \ + vst1q_f16(src_ptr + 4 * pack_tile, s4); \ + vst1q_f16(src_ptr + 5 * pack_tile, s5); \ + vst1q_f16(src_ptr + 6 * pack_tile, s6); \ + vst1q_f16(src_ptr + 7 * pack_tile, s7); \ + vst1q_f16(src_ptr + 8 * pack_tile, s8); \ + vst1q_f16(src_ptr + 9 * pack_tile, s9); \ + vst1q_f16(src_ptr + 10 * pack_tile, s10); \ + vst1q_f16(src_ptr + 11 * pack_tile, s11); \ + vst1q_f16(src_ptr + 12 * pack_tile, s12); \ + vst1q_f16(src_ptr + 13 * pack_tile, s13); \ + vst1q_f16(src_ptr + 14 * pack_tile, s14); \ + vst1q_f16(src_ptr + 15 * pack_tile, s15); + +#define Store4DataFp16 \ + vst1q_f16(dst_data, m[0]); \ + vst1q_f16(dst_data + out_c, m[1]); \ + vst1q_f16(dst_data + dst_step * out_c, m[2]); \ + vst1q_f16(dst_data + dst_step * out_c + out_c, m[3]); + +#define Store4DataC4Fp16 \ + vst1_f16(dst_data, m[0]); \ + vst1_f16(dst_data + out_c, m[1]); \ + vst1_f16(dst_data + dst_step * out_c, m[2]); \ + vst1_f16(dst_data + dst_step * out_c + out_c, m[3]); + +#define Store9DataFp16 \ + vst1q_f16(dst_data, m[0]); \ + vst1q_f16(dst_data + out_c, m[1]); \ + vst1q_f16(dst_data + 2 * out_c, m[2]); \ + vst1q_f16(dst_data + dst_step * out_c, m[3]); \ + vst1q_f16(dst_data + dst_step * out_c + out_c, m[4]); \ + vst1q_f16(dst_data + dst_step * out_c + 2 * out_c, m[5]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c, m[6]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c + out_c, m[7]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c + 2 * out_c, m[8]); + +#define Store9DataC4Fp16 \ + vst1_f16(dst_data, m[0]); \ + vst1_f16(dst_data + out_c, m[1]); \ + vst1_f16(dst_data + 2 * out_c, m[2]); \ + vst1_f16(dst_data + dst_step * out_c, m[3]); \ + vst1_f16(dst_data + dst_step * out_c + out_c, m[4]); \ + vst1_f16(dst_data + dst_step * out_c + 2 * out_c, m[5]); \ + vst1_f16(dst_data + 2 * dst_step * out_c, m[6]); \ + vst1_f16(dst_data + 2 * dst_step * out_c + out_c, m[7]); \ + vst1_f16(dst_data + 2 * dst_step * out_c + 2 * out_c, m[8]); + +#define Store16DataFp16 \ + vst1q_f16(dst_data, m[0]); \ + vst1q_f16(dst_data + out_c, m[1]); \ + vst1q_f16(dst_data + 2 * out_c, m[2]); \ + vst1q_f16(dst_data + 3 * out_c, m[3]); \ + vst1q_f16(dst_data + dst_step * out_c, m[4]); \ + vst1q_f16(dst_data + dst_step * out_c + out_c, m[5]); \ + vst1q_f16(dst_data + dst_step * out_c + 2 * out_c, m[6]); \ + vst1q_f16(dst_data + dst_step * out_c + 3 * out_c, m[7]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c, m[8]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c + out_c, m[9]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c + 2 * out_c, m[10]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c + 3 * out_c, m[11]); \ + vst1q_f16(dst_data + 3 * dst_step * out_c, m[12]); \ + vst1q_f16(dst_data + 3 * dst_step * out_c + out_c, m[13]); \ + vst1q_f16(dst_data + 3 * dst_step * out_c + 2 * out_c, m[14]); \ + vst1q_f16(dst_data + 3 * dst_step * out_c + 3 * out_c, m[15]); + +#define Store16DataC4Fp16 \ + vst1_f16(dst_data, m[0]); \ + vst1_f16(dst_data + out_c, m[1]); \ + vst1_f16(dst_data + 2 * out_c, m[2]); \ + vst1_f16(dst_data + 3 * out_c, m[3]); \ + vst1_f16(dst_data + dst_step * out_c, m[4]); \ + vst1_f16(dst_data + dst_step * out_c + out_c, m[5]); \ + vst1_f16(dst_data + dst_step * out_c + 2 * out_c, m[6]); \ + vst1_f16(dst_data + dst_step * out_c + 3 * out_c, m[7]); \ + vst1_f16(dst_data + 2 * dst_step * out_c, m[8]); \ + vst1_f16(dst_data + 2 * dst_step * out_c + out_c, m[9]); \ + vst1_f16(dst_data + 2 * dst_step * out_c + 2 * out_c, m[10]); \ + vst1_f16(dst_data + 2 * dst_step * out_c + 3 * out_c, m[11]); \ + vst1_f16(dst_data + 3 * dst_step * out_c, m[12]); \ + vst1_f16(dst_data + 3 * dst_step * out_c + out_c, m[13]); \ + vst1_f16(dst_data + 3 * dst_step * out_c + 2 * out_c, m[14]); \ + vst1_f16(dst_data + 3 * dst_step * out_c + 3 * out_c, m[15]); + +#define Store25DataFp16 \ + vst1q_f16(dst_data, m[0]); \ + vst1q_f16(dst_data + out_c, m[1]); \ + vst1q_f16(dst_data + 2 * out_c, m[2]); \ + vst1q_f16(dst_data + 3 * out_c, m[3]); \ + vst1q_f16(dst_data + 4 * out_c, m[4]); \ + vst1q_f16(dst_data + dst_step * out_c, m[5]); \ + vst1q_f16(dst_data + dst_step * out_c + out_c, m[6]); \ + vst1q_f16(dst_data + dst_step * out_c + 2 * out_c, m[7]); \ + vst1q_f16(dst_data + dst_step * out_c + 3 * out_c, m[8]); \ + vst1q_f16(dst_data + dst_step * out_c + 4 * out_c, m[9]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c, m[10]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c + out_c, m[11]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c + 2 * out_c, m[12]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c + 3 * out_c, m[13]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c + 4 * out_c, m[14]); \ + vst1q_f16(dst_data + 3 * dst_step * out_c, m[15]); \ + vst1q_f16(dst_data + 3 * dst_step * out_c + out_c, m[16]); \ + vst1q_f16(dst_data + 3 * dst_step * out_c + 2 * out_c, m[17]); \ + vst1q_f16(dst_data + 3 * dst_step * out_c + 3 * out_c, m[18]); \ + vst1q_f16(dst_data + 3 * dst_step * out_c + 4 * out_c, m[19]); \ + vst1q_f16(dst_data + 4 * dst_step * out_c, m[20]); \ + vst1q_f16(dst_data + 4 * dst_step * out_c + out_c, m[21]); \ + vst1q_f16(dst_data + 4 * dst_step * out_c + 2 * out_c, m[22]); \ + vst1q_f16(dst_data + 4 * dst_step * out_c + 3 * out_c, m[23]); \ + vst1q_f16(dst_data + 4 * dst_step * out_c + 4 * out_c, m[24]); + +#define Store25DataC4Fp16 \ + vst1_f16(dst_data, m[0]); \ + vst1_f16(dst_data + out_c, m[1]); \ + vst1_f16(dst_data + 2 * out_c, m[2]); \ + vst1_f16(dst_data + 3 * out_c, m[3]); \ + vst1_f16(dst_data + 4 * out_c, m[4]); \ + vst1_f16(dst_data + dst_step * out_c, m[5]); \ + vst1_f16(dst_data + dst_step * out_c + out_c, m[6]); \ + vst1_f16(dst_data + dst_step * out_c + 2 * out_c, m[7]); \ + vst1_f16(dst_data + dst_step * out_c + 3 * out_c, m[8]); \ + vst1_f16(dst_data + dst_step * out_c + 4 * out_c, m[9]); \ + vst1_f16(dst_data + 2 * dst_step * out_c, m[10]); \ + vst1_f16(dst_data + 2 * dst_step * out_c + out_c, m[11]); \ + vst1_f16(dst_data + 2 * dst_step * out_c + 2 * out_c, m[12]); \ + vst1_f16(dst_data + 2 * dst_step * out_c + 3 * out_c, m[13]); \ + vst1_f16(dst_data + 2 * dst_step * out_c + 4 * out_c, m[14]); \ + vst1_f16(dst_data + 3 * dst_step * out_c, m[15]); \ + vst1_f16(dst_data + 3 * dst_step * out_c + out_c, m[16]); \ + vst1_f16(dst_data + 3 * dst_step * out_c + 2 * out_c, m[17]); \ + vst1_f16(dst_data + 3 * dst_step * out_c + 3 * out_c, m[18]); \ + vst1_f16(dst_data + 3 * dst_step * out_c + 4 * out_c, m[19]); \ + vst1_f16(dst_data + 4 * dst_step * out_c, m[20]); \ + vst1_f16(dst_data + 4 * dst_step * out_c + out_c, m[21]); \ + vst1_f16(dst_data + 4 * dst_step * out_c + 2 * out_c, m[22]); \ + vst1_f16(dst_data + 4 * dst_step * out_c + 3 * out_c, m[23]); \ + vst1_f16(dst_data + 4 * dst_step * out_c + 4 * out_c, m[24]); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_WINOGRAD_UTILS_MACRO_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/activation_grad_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/activation_grad_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..11499071413022c3c06189310a8cc29c8561905e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/activation_grad_fp16.c @@ -0,0 +1,151 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp16_grad/activation_grad_fp16.h" +#include +#include +#ifdef ENABLE_NEON +#include +#include "nnacl/fp32/exp_fp32.h" +#include "nnacl/intrinsics/ms_simd_instructions_fp16.h" +#endif +#include "nnacl/op_base.h" +#include "nnacl/errorcode.h" + +int ReluFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst) { + int i = 0; +#ifdef ENABLE_NEON + float16x8_t zero_v = vdupq_n_f16(0); + for (; i <= length - C8NUM; i += C8NUM) { + float16x8_t src0_v = vld1q_f16(src0 + i); + float16x8_t src1_v = vld1q_f16(src1 + i); + uint16x8_t mask_v = vcleq_f16(src1_v, zero_v); + float16x8_t dst_v = vbslq_f16(mask_v, zero_v, src0_v); + vst1q_f16(dst + i, dst_v); + } +#endif + for (; i < length; i++) { + dst[i] = (src1[i] > 0.0f) ? src0[i] : 0.0f; + } + return NNACL_OK; +} + +int Relu6Fp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst) { + int i = 0; +#ifdef ENABLE_NEON + float16x8_t zero_8 = vdupq_n_f16(0); + float16x8_t six_8 = vdupq_n_f16(6); + for (; i <= length - C8NUM; i += C8NUM) { + float16x8_t src1_8 = vld1q_f16(src1 + i); + float16x8_t src0_8 = vld1q_f16(src0 + i); + uint16x8_t gt_8 = vcgtq_f16(src1_8, zero_8); + uint16x8_t le_8 = vcleq_f16(src1_8, six_8); + uint16x8_t mask_8 = vandq_u16(gt_8, le_8); + float16x8_t dst_8 = vbslq_f16(mask_8, src0_8, zero_8); + vst1q_f16(dst + i, dst_8); + } +#endif + for (; i < length; ++i) { + dst[i] = (src1[i] > 0.0f && src1[i] <= 6.0f) ? src0[i] : 0.0f; + } + return NNACL_OK; +} + +int LReluFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst, float16_t alpha) { + int i = 0; +#ifdef ENABLE_NEON + const MS_FLOAT16X8 one_8 = vdupq_n_f16(1); + for (; i <= length - C8NUM; i += C8NUM) { + MS_FLOAT16X8 src0_8 = MS_LDQ_F16(src0 + i); + MS_FLOAT16X8 src1_8 = MS_LDQ_F16(src1 + i); + MS_STQ_F16(dst + i, vmulq_f16(src0_8, vmulq_f16(src1_8, (one_8 - src1_8)))); + } +#endif + for (; i < length; ++i) { + dst[i] = src0[i] * (src1[i] * (1.0f - src1[i])); + } + return NNACL_OK; +} + +int SigmoidFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst) { + int i = 0; +#ifdef ENABLE_NEON + float16x8_t one_8 = vdupq_n_f16(1); + for (; i < length - C8NUM; i += C8NUM) { + float16x8_t src0_8 = vld1q_f16(src0 + i); + float16x8_t src1_8 = vld1q_f16(src1 + i); + float16x8_t dst_8 = vmulq_f16(src0_8, vmulq_f16(src1_8, vsubq_f16(one_8, src1_8))); + vst1q_f16(dst + i, dst_8); + } +#endif + for (; i < length; i++) { + dst[i] = src0[i] * (src1[i] * (1.0f - src1[i])); + } + return NNACL_OK; +} + +int TanhFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst) { + for (int i = 0; i < length; ++i) { + dst[i] = (float16_t)((1.0f - ((float)src1[i] * (float)src1[i])) * (float)src0[i]); + } + return NNACL_OK; +} + +int HSwishFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst) { + for (int i = 0; i < length; ++i) { + float16_t tmp = (src1[i] > 3.0f ? 1.0f : (src1[i] < -3.0f ? 0.0f : (2.0f * src1[i] + 3.0f) / 6.0f)); + dst[i] = tmp * src0[i]; + } + return NNACL_OK; +} + +int HSigmoidFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst) { + for (int i = 0; i < length; ++i) { + float16_t tmp = (src1[i] > 3.0f ? 0.0f : (src1[i] < -3.0f ? 0.0f : 1.0f / 6.0f)); + dst[i] = tmp * src0[i]; + } + return NNACL_OK; +} +int EluFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst, float16_t alpha) { + int i = 0; +#ifdef ENABLE_NEON + float16x4_t zero_4 = vdup_n_f16(0); + float16x4_t one_4 = vdup_n_f16(1); + float16x4_t alpha_4 = vdup_n_f16(alpha); + for (; i <= length - C4NUM; i += C4NUM) { + float16x4_t src0_4 = vld1_f16(src0 + i); + float16x4_t src1_4 = vld1_f16(src1 + i); + uint16x4_t mask_4 = vcgt_f16(src1_4, zero_4); + float32x4_t tmp; + simd_exp128(vcvt_f32_f16(src1_4), (float *)&tmp); + uint16x4_t expm1_4 = vsub_f16(vcvt_f16_f32(tmp), one_4); + float16x4_t dst_4 = vbsl_f16(mask_4, src0_4, vmul_f16(alpha_4, vmul_f16(expm1_4, src0_4))); + vst1_f16(dst + i, dst_4); + } +#endif + for (; i < length; ++i) { + dst[i] = (src1[i] > 0.0f ? src0[i] : alpha * expm1(src1[i]) * src0[i]); + } + return NNACL_OK; +} + +int GeluFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst) { + for (int i = 0; i < length; ++i) { + dst[i] = src0[i] * ((0.5 * (1.0 + erf(src1[i] / 1.4142135623730951))) + + (src1[i] * exp(-0.5 * src1[i] * src1[i]) / 2.5066282746)); + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/activation_grad_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/activation_grad_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..91dd183369cef98e7a019994b8515e0585b0b617 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/activation_grad_fp16.h @@ -0,0 +1,44 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_GRAD_ACTIVATION_GRAD_FP16_H_ +#define NNACL_FP16_GRAD_ACTIVATION_GRAD_FP16_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include +#include "nnacl/op_base.h" +#include "nnacl/int8/fixed_point.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ReluFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst); +int Relu6Fp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst); +int LReluFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst, float16_t alpha); +int SigmoidFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst); +int TanhFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst); +int HSwishFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst); +int HSigmoidFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst); +int EluFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst, float16_t alpha); +int GeluFp16Grad(const float16_t *src0, const float16_t *src1, int length, float16_t *dst); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_GRAD_ACTIVATION_GRAD_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/arithmetic_grad.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/arithmetic_grad.c new file mode 100644 index 0000000000000000000000000000000000000000..290839ea934a5b92cb60bfe917dff3855496679d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/arithmetic_grad.c @@ -0,0 +1,158 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp16_grad/arithmetic_grad.h" +#include +#include +#include "nnacl/fp32_grad/utils.h" +#include "nnacl/errorcode.h" +#include "nnacl/op_base.h" + +void ElementDivNegSquareFp16(const float16_t *nom, const float16_t *denom, float16_t *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = -nom[i] / (denom[i] * denom[i]); + } +} + +void ElementMulAndDivNegSquareFp16(const float16_t *a, const float16_t *b, const float16_t *denom, float16_t *output, + int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = -a[i] * b[i] / (denom[i] * denom[i]); + } +} + +int ElementAbsGradFp16(const float16_t *in1, const float16_t *in2, float16_t *out, int element_size) { + for (int i = 0; i < element_size; i++) { + out[i] = (in1[i] < 0.f) ? -in2[i] : ((in1[i] > 0.f) ? in2[i] : 0); + } + return NNACL_OK; +} + +void MaximumByAxesFp16(const float16_t *input0, const float16_t *input1, const float16_t *dy, const int *input0_dims, + const int *input1_dims, const int *dy_dims, float16_t *output0, float16_t *output1, + int num_dims) { + int num_output0 = 1; + int num_output1 = 1; + bool same_shape = true; + for (int idx = 0; idx < num_dims; ++idx) { + num_output0 *= input0_dims[idx]; + num_output1 *= input1_dims[idx]; + if (input0_dims[idx] != input1_dims[idx]) { + same_shape = false; + } + } + + if (same_shape) { + int input_iter[C8NUM] = {0}; + + // Iterate through input_data. + do { + size_t offset = GetInputOffset(num_dims, input0_dims, input_iter); + output0[offset] = input0[offset] > input1[offset] ? dy[offset] : 0.; + output1[offset] = input1[offset] >= input0[offset] ? dy[offset] : 0.; + } while (NextIndex(num_dims, input0_dims, input_iter)); + } else { + memset(output0, 0, num_output0 * sizeof(float16_t)); // zero output + memset(output1, 0, num_output1 * sizeof(float16_t)); // zero output + + int input_iter[C8NUM] = {0}; + int axes0[C5NUM] = {0}; + int axes1[C5NUM] = {0}; + int num_axes0 = 0; + int num_axes1 = 0; + for (int i = 0; i < num_dims; i++) { + if (input0_dims[i] == 1) { + axes0[num_axes0++] = i; + } + if (input1_dims[i] == 1) { + axes1[num_axes1++] = i; + } + } + + do { + size_t offset0 = GetOutputOffset(num_dims, input0_dims, input_iter, num_axes0, axes0); + size_t offset1 = GetOutputOffset(num_dims, input1_dims, input_iter, num_axes1, axes1); + size_t yt_offset = GetInputOffset(num_dims, input0_dims, input_iter); + output0[offset0] += input0[offset0] > input1[offset1] ? dy[yt_offset] : 0.; + output1[offset1] += input1[offset1] >= input0[offset0] ? dy[yt_offset] : 0.; + } while (NextIndex(num_dims, dy_dims, input_iter)); + } +} + +void MinimumByAxesFp16(const float16_t *input0, const float16_t *input1, const float16_t *dy, const int *input0_dims, + const int *input1_dims, const int *dy_dims, float16_t *output0, float16_t *output1, + int num_dims) { + int num_output0 = 1; + int num_output1 = 1; + bool same_shape = true; + for (int idx = 0; idx < num_dims; ++idx) { + num_output0 *= input0_dims[idx]; + num_output1 *= input1_dims[idx]; + if (input0_dims[idx] != input1_dims[idx]) { + same_shape = false; + } + } + + if (same_shape) { + int input_iter[C8NUM] = {0}; + + // Iterate through input_data. + do { + size_t offset = GetInputOffset(num_dims, input0_dims, input_iter); + output0[offset] = input0[offset] < input1[offset] ? dy[offset] : 0.; + output1[offset] = input1[offset] <= input0[offset] ? dy[offset] : 0.; + } while (NextIndex(num_dims, input0_dims, input_iter)); + } else { + memset(output0, 0, num_output0 * sizeof(float16_t)); // zero output + memset(output1, 0, num_output1 * sizeof(float16_t)); // zero output + + int input_iter[C8NUM] = {0}; + int axes0[C5NUM] = {0}; + int axes1[C5NUM] = {0}; + int num_axes0 = 0; + int num_axes1 = 0; + for (int i = 0; i < num_dims; i++) { + if (input0_dims[i] == 1) { + axes0[num_axes0++] = i; + } + if (input1_dims[i] == 1) { + axes1[num_axes1++] = i; + } + } + + do { + size_t offset0 = GetOutputOffset(num_dims, input0_dims, input_iter, num_axes0, axes0); + size_t offset1 = GetOutputOffset(num_dims, input1_dims, input_iter, num_axes1, axes1); + size_t yt_offset = GetInputOffset(num_dims, input0_dims, input_iter); + output0[offset0] += input0[offset0] < input1[offset1] ? dy[yt_offset] : 0.; + output1[offset1] += input1[offset1] <= input0[offset0] ? dy[yt_offset] : 0.; + } while (NextIndex(num_dims, dy_dims, input_iter)); + } +} + +int ElementSqrtGradFp16(const float16_t *in1, const float16_t *in2, float16_t *out, const int element_size) { + for (int i = 0; i < element_size; i++) { + out[i] = 0.5f * in2[i] / in1[i]; + } + return NNACL_OK; +} + +int ElementRsqrtGradFp16(const float16_t *in1, const float16_t *in2, float16_t *out, const int element_size) { + for (int i = 0; i < element_size; i++) { + out[i] = -0.5f * in2[i] * in1[i] * in1[1] * in1[i]; + } + return NNACL_OK; +} diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/distribution/distribution_collective.cc b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/arithmetic_grad.h similarity index 31% rename from mindspore-lite/src/extendrt/delegate/tensorrt/distribution/distribution_collective.cc rename to mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/arithmetic_grad.h index 41aa7cfff7bf7fec93763ce3784ba7d39c0bfa0d..ec2aea6f3854c01b6017e3a7d7c60df5c71561b1 100644 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/distribution/distribution_collective.cc +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/arithmetic_grad.h @@ -13,32 +13,29 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#ifndef NNACL_FP16_GRAD_ARITHMETIC_GRAD_H_ +#define NNACL_FP16_GRAD_ARITHMETIC_GRAD_H_ -#include "src/extendrt/delegate/tensorrt/distribution/distribution_collective.h" +#include "nnacl/op_base.h" -namespace mindspore::lite { -DistributionCollective::DistributionCollective() {} +#ifdef __cplusplus +extern "C" { +#endif +void ElementDivNegSquareFp16(const float16_t *nom, const float16_t *denom, float16_t *output, int element_size); +void ElementMulAndDivNegSquareFp16(const float16_t *a, const float16_t *b, const float16_t *denom, float16_t *output, + int element_size); +int ElementAbsGradFp16(const float16_t *in1, const float16_t *in2, float16_t *out, int element_size); +void MaximumByAxesFp16(const float16_t *input0, const float16_t *input1, const float16_t *dy, const int *input0_dims, + const int *input1_dims, const int *dy_dims, float16_t *output0, float16_t *output1, + int num_dims); +void MinimumByAxesFp16(const float16_t *input0, const float16_t *input1, const float16_t *dy, const int *input0_dims, + const int *input1_dims, const int *dy_dims, float16_t *output0, float16_t *output1, + int num_dims); +int ElementSqrtGradFp16(const float16_t *in1, const float16_t *in2, float16_t *out, const int element_size); +int ElementRsqrtGradFp16(const float16_t *in1, const float16_t *in2, float16_t *out, const int element_size); -DistributionCollective &DistributionCollective::instance() { - static DistributionCollective instance; - return instance; +#ifdef __cplusplus } +#endif -int DistributionCollective::ReduceScatterWrapper(const void *input_addr, void *output_addr, size_t count, - nvinfer1::DataType data_type, ReduceMode reduce_type, - cudaStream_t stream, const std::string &group) { - return RET_OK; -} - -int DistributionCollective::AllReduceWrapper(const void *input_addr, void *output_addr, size_t count, - nvinfer1::DataType data_type, ReduceMode reduce_type, cudaStream_t stream, - const std::string &group) { - return RET_OK; -} - -int DistributionCollective::AllGatherWrapper(const void *input_addr, void *output_addr, size_t count, - nvinfer1::DataType data_type, cudaStream_t stream, - const std::string &group_name) { - return RET_OK; -} -} // namespace mindspore::lite +#endif // NNACL_FP16_GRAD_ARITHMETIC_GRAD_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/arithmetic_self_grad.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/arithmetic_self_grad.c new file mode 100644 index 0000000000000000000000000000000000000000..7cd7b4d7e54f7029a564df7b2f5d7d37eff97fcd --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/arithmetic_self_grad.c @@ -0,0 +1,37 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "nnacl/op_base.h" +#include "nnacl/fp16_grad/arithmetic_self_grad.h" +#include "nnacl/errorcode.h" + +int Fp16LogGrad(const float16_t *src0, const float16_t *src1, size_t length, float16_t *dst) { + int i = 0; +#ifdef ENABLE_NEON + float16x8_t log_10 = vdupq_n_f16(log(10)); + for (; i < length - 4; i += 4) { + float16x8_t src0_4 = vld1q_f16(src0 + i); + float16x8_t src1_4 = vld1q_f16(src1 + i); + float16x8_t dst_4 = vmulq_f16(src0_4, vrecpeq_f16(vmulq_f16(src1_4, log_10))); + vst1q_f16(dst + i, dst_4); + } +#endif + for (; i < length; i++) { + dst[i] = src0[i] * 1.0f / (src1[i] * log(10)); + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/arithmetic_self_grad.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/arithmetic_self_grad.h new file mode 100644 index 0000000000000000000000000000000000000000..1775b780b2a65b380745b29651b2b2c32efc074a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/arithmetic_self_grad.h @@ -0,0 +1,39 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_GRAD_ARITHMETHIC_SELF_GRAD_H_ +#define NNACL_FP16_GRAD_ARITHMETHIC_SELF_GRAD_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include +#include "nnacl/op_base.h" + +typedef struct ArithmeticSelfGradParameterFp16 { + OpParameter op_parameter; + int type_; +} ArithmeticSelfGradParameterFp16; +#ifdef __cplusplus +extern "C" { +#endif + +int Fp16LogGrad(const float16_t *src0, const float16_t *src1, size_t length, float16_t *dst); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP16_GRAD_ARITHMETHIC_SELF_GRAD_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/batch_norm.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/batch_norm.c new file mode 100644 index 0000000000000000000000000000000000000000..a9ac4b54d8e49fc32c8488b9c0a43fc44ca53fd9 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/batch_norm.c @@ -0,0 +1,88 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "nnacl/fp16_grad/batch_norm.h" + +void var2InvarFp16(float16_t *save_var, int size, float eps) { + for (int i = 0; i < size; i++) { + save_var[i] = (float16_t)(1.0f / sqrtf((float)save_var[i] + eps)); + } +} + +void backwardAllFp16(const float16_t *restrict in, const float16_t *restrict yt, const float16_t *restrict mean, + const float16_t *restrict invar, const float16_t *restrict scale, int size, int ch, + float *restrict dxhat_sum, float *restrict dxhathat_sum, float16_t *restrict dbias, + float16_t *restrict dscale, float16_t *restrict dx) { + NNACL_CHECK_ZERO_RETURN(size); + for (int i = 0; i < size; i++) { + for (int c = 0; c < ch; c++) { + int ix = i * ch + c; + dbias[c] += yt[ix]; + // dscale + float16_t x_hat = (in[ix] - mean[c]) * invar[c]; + dscale[c] += (yt[ix] * x_hat); + // dx_1 + float dx_hat = (float)(yt[ix] * scale[c]); + dxhat_sum[c] += dx_hat; + dxhathat_sum[c] += (float)(dx_hat * x_hat); + } + } + float N = (float)size; + for (int i = 0; i < size; i++) { + for (int c = 0; c < ch; c++) { + // dx_2 + int ix = i * ch + c; + float16_t x_hat = (in[ix] - mean[c]) * invar[c]; + float16_t dx_hat = yt[ix] * scale[c]; + dx[ix] = (float16_t)((float)((invar[c]) * (N * dx_hat - dxhat_sum[c] - x_hat * dxhathat_sum[c])) / N); + } + } +} +void backwardP1Fp16(const float16_t *restrict in, const float16_t *restrict yt, const float16_t *restrict mean, + const float16_t *restrict invar, const float16_t *restrict scale, int size, int ch, + float *restrict dxhat_sum, float *restrict dxhathat_sum, float16_t *restrict dbias, + float16_t *restrict dscale) { + for (int i = 0; i < size; i++) { + for (int c = 0; c < ch; c++) { + int ix = i * ch + c; + dbias[c] += yt[ix]; + // dscale + float x_hat = (float)((in[ix] - mean[c]) * invar[c]); + dscale[c] += (yt[ix] * x_hat); + // dx_1 + float dx_hat = (float)(yt[ix] * scale[c]); + dxhat_sum[c] += dx_hat; + dxhathat_sum[c] += dx_hat * x_hat; + } + } +} + +void backwardP2Fp16(const float16_t *restrict in, const float16_t *restrict yt, const float16_t *restrict mean, + const float16_t *restrict invar, const float16_t *restrict scale, int size, int total_size, int ch, + const float *dxhat_sum, const float *dxhathat_sum, float16_t *restrict dx) { + NNACL_CHECK_ZERO_RETURN(total_size); + const float N = (float)total_size; + for (int i = 0; i < size; i++) { + for (int c = 0; c < ch; c++) { + // dx_2 + int ix = i * ch + c; + float x_hat = (float)((in[ix] - mean[c]) * invar[c]); + float dx_hat = (float)(yt[ix] * scale[c]); + dx[ix] = (float16_t)(((float)(invar[c]) * (N * dx_hat - dxhat_sum[c] - x_hat * dxhathat_sum[c])) / N); + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/batch_norm.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/batch_norm.h new file mode 100644 index 0000000000000000000000000000000000000000..51c14f1f5a1179dcda8a5b7d05890bc28f6e17db --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/batch_norm.h @@ -0,0 +1,40 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_GRAD_BATCH_NORM_H_ +#define NNACL_FP16_GRAD_BATCH_NORM_H_ + +#include "nnacl/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void var2InvarFp16(float16_t *save_var, int size, float eps); +void backwardAllFp16(const float16_t *in, const float16_t *yt, const float16_t *mean, const float16_t *invar, + const float16_t *scale, int size, int ch, float *dxhat_sum, float *dxhathat_sum, float16_t *dbias, + float16_t *dscale, float16_t *dx); +void backwardP1Fp16(const float16_t *in, const float16_t *yt, const float16_t *mean, const float16_t *invar, + const float16_t *scale, int size, int ch, float *dxhat_sum, float *dxhathat_sum, float16_t *dbias, + float16_t *dscale); +void backwardP2Fp16(const float16_t *in, const float16_t *yt, const float16_t *mean, const float16_t *invar, + const float16_t *scale, int size, int total_size, int ch, const float *dxhat_sum, + const float *dxhathat_sum, float16_t *dx); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_GRAD_BATCH_NORM_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/convolution_grad_filter.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/convolution_grad_filter.c new file mode 100644 index 0000000000000000000000000000000000000000..01499bccdb5a83955e754e4b91425d99d9f2bde5 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/convolution_grad_filter.c @@ -0,0 +1,361 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp16_grad/convolution_grad_filter.h" +#include "nnacl/intrinsics/ms_simd_instructions_fp16.h" +#include "nnacl/errorcode.h" +#ifdef ENABLE_NEON +#include +#endif + +#ifdef ENABLE_NEON + +static int FilterGrad32Arm(const float16_t *x, const float16_t *dy, int i_c, int k_idx, float16_t *dw, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int x_size = in_h * in_w * in_ch; + int y_size = out_ch * out_h * out_w; + int k_spatial = k_w * k_h; + int i_kh = k_idx / k_w; + int i_kw = k_idx % k_w; + for (; i_c < (out_ch & ~31); i_c += 32) { + float32x4_t sum_0 = vdupq_n_f32(0.0f); + float32x4_t sum_1 = vdupq_n_f32(0.0f); + float32x4_t sum_2 = vdupq_n_f32(0.0f); + float32x4_t sum_3 = vdupq_n_f32(0.0f); + float32x4_t sum_4 = vdupq_n_f32(0.0f); + float32x4_t sum_5 = vdupq_n_f32(0.0f); + float32x4_t sum_6 = vdupq_n_f32(0.0f); + float32x4_t sum_7 = vdupq_n_f32(0.0f); + + for (int b = 0; b < batch; ++b) { + const float16_t *x_addr = &x[b * x_size]; + const float16_t *dy_addr = &dy[b * y_size]; + for (int i = 0; i < m; i++) { + int idx = i; + int input_h = idx / out_w * conv_param->stride_h_; + int input_w = idx % out_w * conv_param->stride_w_; + int input_row = -conv_param->pad_u_ + i_kh + input_h; + int input_col = -conv_param->pad_l_ + i_kw + input_w; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset_x = (input_row * in_w + input_col) * out_ch + i_c; + int offset_dy = idx * out_ch + i_c; + + float16x8_t x_0 = vld1q_f16(x_addr + offset_x); + float16x8_t dy_0 = vld1q_f16(dy_addr + offset_dy); + sum_0 = MS_VMLAL_F16(vget_low_f16(x_0), vget_low_f16(dy_0), sum_0); + sum_1 = MS_VMLAL_F16(vget_high_f16(x_0), vget_high_f16(dy_0), sum_1); + + float16x8_t x_1 = vld1q_f16(x_addr + offset_x + 8); + float16x8_t dy_1 = vld1q_f16(dy_addr + offset_dy + 8); + sum_2 = MS_VMLAL_F16(vget_low_f16(x_1), vget_low_f16(dy_1), sum_2); + sum_3 = MS_VMLAL_F16(vget_high_f16(x_1), vget_high_f16(dy_1), sum_3); + + float16x8_t x_2 = vld1q_f16(x_addr + offset_x + 16); + float16x8_t dy_2 = vld1q_f16(dy_addr + offset_dy + 16); + sum_4 = MS_VMLAL_F16(vget_low_f16(x_2), vget_low_f16(dy_2), sum_4); + sum_5 = MS_VMLAL_F16(vget_high_f16(x_2), vget_high_f16(dy_2), sum_5); + + float16x8_t x_3 = vld1q_f16(x_addr + offset_x + 24); + float16x8_t dy_3 = vld1q_f16(dy_addr + offset_dy + 24); + sum_6 = MS_VMLAL_F16(vget_low_f16(x_3), vget_low_f16(dy_3), sum_6); + sum_7 = MS_VMLAL_F16(vget_high_f16(x_3), vget_high_f16(dy_3), sum_7); + } + } + } + // store into memory + for (int l = 0; l < 4; l++) { + dw[(i_c + l) * k_spatial + k_idx] = sum_0[l]; + dw[(i_c + 4 + l) * k_spatial + k_idx] = sum_1[l]; + dw[(i_c + 8 + l) * k_spatial + k_idx] = sum_2[l]; + dw[(i_c + 12 + l) * k_spatial + k_idx] = sum_3[l]; + dw[(i_c + 16 + l) * k_spatial + k_idx] = sum_4[l]; + dw[(i_c + 20 + l) * k_spatial + k_idx] = sum_5[l]; + dw[(i_c + 24 + l) * k_spatial + k_idx] = sum_6[l]; + dw[(i_c + 28 + l) * k_spatial + k_idx] = sum_7[l]; + } + } + return i_c; +} + +static int FilterGrad16Arm(const float16_t *x, const float16_t *dy, int i_c, int k_idx, float16_t *dw, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int x_size = in_h * in_w * in_ch; + int y_size = out_ch * out_h * out_w; + int k_spatial = k_w * k_h; + int i_kh = k_idx / k_w; + int i_kw = k_idx % k_w; + for (; i_c < (out_ch & ~15); i_c += 16) { + float32x4_t sum_0 = vdupq_n_f32(0.0f); + float32x4_t sum_1 = vdupq_n_f32(0.0f); + float32x4_t sum_2 = vdupq_n_f32(0.0f); + float32x4_t sum_3 = vdupq_n_f32(0.0f); + for (int b = 0; b < batch; ++b) { + const float16_t *x_addr = &x[b * x_size]; + const float16_t *dy_addr = &dy[b * y_size]; + for (int i = 0; i < m; i++) { + int idx = i; + int input_h = idx / out_w * conv_param->stride_h_; + int input_w = idx % out_w * conv_param->stride_w_; + int input_row = -conv_param->pad_u_ + i_kh + input_h; + int input_col = -conv_param->pad_l_ + i_kw + input_w; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset_x = (input_row * in_w + input_col) * out_ch + i_c; + int offset_dy = idx * out_ch + i_c; + + float16x8_t x_0 = vld1q_f16(x_addr + offset_x); + float16x8_t dy_0 = vld1q_f16(dy_addr + offset_dy); + sum_0 = MS_VMLAL_F16(vget_low_f16(x_0), vget_low_f16(dy_0), sum_0); + sum_1 = MS_VMLAL_F16(vget_high_f16(x_0), vget_high_f16(dy_0), sum_1); + + float16x8_t x_1 = vld1q_f16(x_addr + offset_x + 8); + float16x8_t dy_1 = vld1q_f16(dy_addr + offset_dy + 8); + sum_2 = MS_VMLAL_F16(vget_low_f16(x_1), vget_low_f16(dy_1), sum_2); + sum_3 = MS_VMLAL_F16(vget_high_f16(x_1), vget_high_f16(dy_1), sum_3); + } + } + } + for (int l = 0; l < 4; l++) { + dw[(i_c + l) * k_spatial + k_idx] = sum_0[l]; + dw[(i_c + l + 4) * k_spatial + k_idx] = sum_1[l]; + dw[(i_c + l + 8) * k_spatial + k_idx] = sum_2[l]; + dw[(i_c + l + 12) * k_spatial + k_idx] = sum_3[l]; + } + } + return i_c; +} + +static int FilterGrad8Arm(const float16_t *x, const float16_t *dy, int i_c, int k_idx, float16_t *dw, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int x_size = in_h * in_w * in_ch; + int y_size = out_ch * out_h * out_w; + int k_spatial = k_w * k_h; + int i_kh = k_idx / k_w; + int i_kw = k_idx % k_w; + for (; i_c < (out_ch & ~7); i_c += 8) { + float32x4_t sum_0 = vdupq_n_f32(0.0f); + float32x4_t sum_1 = vdupq_n_f32(0.0f); + + for (int b = 0; b < batch; ++b) { + const float16_t *x_addr = &x[b * x_size]; + const float16_t *dy_addr = &dy[b * y_size]; + for (int i = 0; i < m; i++) { + int idx = i; + int input_h = idx / out_w * conv_param->stride_h_; + int input_w = idx % out_w * conv_param->stride_w_; + int input_row = -conv_param->pad_u_ + i_kh + input_h; + int input_col = -conv_param->pad_l_ + i_kw + input_w; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset_x = (input_row * in_w + input_col) * out_ch + i_c; + int offset_dy = idx * out_ch + i_c; + float16x8_t x_0 = vld1q_f16(x_addr + offset_x); + float16x8_t dy_0 = vld1q_f16(dy_addr + offset_dy); + sum_0 = MS_VMLAL_F16(vget_low_f16(x_0), vget_low_f16(dy_0), sum_0); + sum_1 = MS_VMLAL_F16(vget_high_f16(x_0), vget_high_f16(dy_0), sum_1); + } + } + } + for (int l = 0; l < 4; l++) { + dw[(i_c + l) * k_spatial + k_idx] = sum_0[l]; + dw[(i_c + 4 + l) * k_spatial + k_idx] = sum_1[l]; + } + } + return i_c; +} + +static int FilterGrad4Arm(const float16_t *x, const float16_t *dy, int i_c, int k_idx, float16_t *dw, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int x_size = in_h * in_w * in_ch; + int y_size = out_ch * out_h * out_w; + int k_spatial = k_w * k_h; + int i_kh = k_idx / k_w; + int i_kw = k_idx % k_w; + for (; i_c < (out_ch & ~3); i_c += 4) { + float32x4_t sum_0 = vdupq_n_f32(0.0f); + + for (int b = 0; b < batch; ++b) { + const float16_t *x_addr = &x[b * x_size]; + const float16_t *dy_addr = &dy[b * y_size]; + + for (int i = 0; i < m; i++) { + int idx = i; + int input_h = idx / out_w * conv_param->stride_h_; + int input_w = idx % out_w * conv_param->stride_w_; + int input_row = -conv_param->pad_u_ + i_kh + input_h; + int input_col = -conv_param->pad_l_ + i_kw + input_w; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset_x = (input_row * in_w + input_col) * out_ch + i_c; + int offset_dy = idx * out_ch + i_c; + float16x4_t x_0 = vld1_f16(x_addr + offset_x); + float16x4_t dy_0 = vld1_f16(dy_addr + offset_dy); + sum_0 = MS_VMLAL_F16(x_0, dy_0, sum_0); + } + } + } + dw[(i_c + 0) * k_spatial + k_idx] = sum_0[0]; + dw[(i_c + 1) * k_spatial + k_idx] = sum_0[1]; + dw[(i_c + 2) * k_spatial + k_idx] = sum_0[2]; + dw[(i_c + 3) * k_spatial + k_idx] = sum_0[3]; + } + return i_c; +} + +static int FilterGradLeftoverArm(const float16_t *x, const float16_t *dy, int i_c, int k_idx, float16_t *dw, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int x_size = in_h * in_w * in_ch; + int y_size = out_ch * out_h * out_w; + int k_spatial = k_w * k_h; + int i_kh = k_idx / k_w; + int i_kw = k_idx % k_w; + int leftover = out_ch - i_c; + if (leftover > 0) { + float32x4_t sum_0 = vdupq_n_f32(0.0f); + + for (int b = 0; b < batch; ++b) { + const float16_t *x_addr = &x[b * x_size]; + const float16_t *dy_addr = &dy[b * y_size]; + + for (int i = 0; i < m; i++) { + int idx = i; + int input_h = idx / out_w * conv_param->stride_h_; + int input_w = idx % out_w * conv_param->stride_w_; + int input_row = -conv_param->pad_u_ + i_kh + input_h; + int input_col = -conv_param->pad_l_ + i_kw + input_w; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset_x = (input_row * in_w + input_col) * out_ch + i_c; + int offset_dy = idx * out_ch + i_c; + float16x4_t x_0 = vld1_f16(x_addr + offset_x); + float16x4_t dy_0 = vld1_f16(dy_addr + offset_dy); + sum_0 = MS_VMLAL_F16(x_0, dy_0, sum_0); + } + } + } + for (int l = 0; l < leftover; l++) { + dw[(i_c + l) * k_spatial + k_idx] = sum_0[l]; + } + } + return out_ch; +} + +#endif + +int ConvDwFilterFp16Grad(const float16_t *x, const float16_t *dy, float16_t *dw, int start, int count, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int x_size = in_h * in_w * in_ch; + int y_size = out_ch * out_h * out_w; + int k_spatial = k_w * k_h; + + for (int i_k = 0; i_k < count; i_k++) { + int k_idx = start + i_k; + int i_kh = k_idx / k_w; + int i_kw = k_idx % k_w; + int i_c = 0; +#ifdef ENABLE_NEON + i_c = FilterGrad32Arm(x, dy, i_c, k_idx, dw, conv_param); + i_c = FilterGrad16Arm(x, dy, i_c, k_idx, dw, conv_param); + i_c = FilterGrad8Arm(x, dy, i_c, k_idx, dw, conv_param); + i_c = FilterGrad4Arm(x, dy, i_c, k_idx, dw, conv_param); + i_c = FilterGradLeftoverArm(x, dy, i_c, k_idx, dw, conv_param); +#endif + for (; i_c < out_ch; i_c++) { + float sum = 0; + for (int b = 0; b < batch; ++b) { + const float16_t *x_addr = &x[b * x_size]; + const float16_t *dy_addr = &dy[b * y_size]; + + for (int i = 0; i < m; i++) { + int idx = i; + int input_h = idx / out_w * conv_param->stride_h_; + int input_w = idx % out_w * conv_param->stride_w_; + int input_row = -conv_param->pad_u_ + i_kh + input_h; + int input_col = -conv_param->pad_l_ + i_kw + input_w; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset_x = (input_row * in_w + input_col) * out_ch + i_c; + int offset_dy = idx * out_ch + i_c; + sum += x_addr[offset_x] * dy_addr[offset_dy]; + } + } + } + dw[i_c * k_spatial + k_idx] = sum; + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/convolution_grad_filter.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/convolution_grad_filter.h new file mode 100644 index 0000000000000000000000000000000000000000..84628210446a088af8eabd36ae9e5a42ed31ea17 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/convolution_grad_filter.h @@ -0,0 +1,33 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_GRAD_CONVOLUTION_GRAD_FILTER_H_ +#define NNACL_FP16_GRAD_CONVOLUTION_GRAD_FILTER_H_ + +#include +#include "nnacl/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ConvDwFilterFp16Grad(const float16_t *x, const float16_t *dy, float16_t *dw, int start, int count, + const ConvParameter *conv_param); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_GRAD_CONVOLUTION_GRAD_FILTER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/convolution_grad_input.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/convolution_grad_input.c new file mode 100644 index 0000000000000000000000000000000000000000..0332ff682a574a803fd8c5b1c8325c402433f751 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/convolution_grad_input.c @@ -0,0 +1,332 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp16_grad/convolution_grad_input.h" +#include "nnacl/errorcode.h" +#ifdef ENABLE_ARM +#include +#endif + +static int ConvDwInputGrad16(const float16_t *dy, const float16_t *w, float16_t *dx, int start, int end, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int out_w = conv_param->output_w_; + int out_h = conv_param->output_h_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_spatial = conv_param->output_h_ * conv_param->output_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int k_spatial = k_h * k_w; + int batch = conv_param->input_batch_; + int in_size = in_h * in_w * in_ch; + int out_size = out_h * out_w * out_ch; + + int j = start; + for (; j <= (end - C16NUM); j += C16NUM) { + float16_t *c = dx + j; + const float16_t *mat_b[C16NUM]; + for (int j_i = 0; j_i < C16NUM; j_i++) { + mat_b[j_i] = w + (j + j_i) * k_spatial; + } + for (int si = 0; si < out_spatial; si++) { + const float16_t *a = dy + j + si * out_ch; + int output_row = (si) / out_w; + int output_col = (si) % out_w; + int row_stride_offset = -conv_param->pad_u_ + output_row * conv_param->stride_h_; + int col_stride_offset = -conv_param->pad_l_ + output_col * conv_param->stride_w_; + for (int k = 0; k < k_spatial; k++) { + int kernel_row = k / k_w; + int kernel_col = k % k_w; + int input_row = kernel_row * conv_param->dilation_h_ + row_stride_offset; + int input_col = kernel_col * conv_param->dilation_w_ + col_stride_offset; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset = (input_row * in_w + input_col) * in_ch; +#ifdef ENABLE_ARM +#ifdef ENABLE_ARM64 + float16x8_t mat_b0 = {mat_b[0][k], mat_b[1][k], mat_b[2][k], mat_b[3][k], + mat_b[4][k], mat_b[5][k], mat_b[6][k], mat_b[7][k]}; + float16x8_t mat_b1 = {mat_b[8][k], mat_b[9][k], mat_b[10][k], mat_b[11][k], + mat_b[12][k], mat_b[13][k], mat_b[14][k], mat_b[15][k]}; +#else + float16x4_t mat_b00; + float16x4_t mat_b01; + float16x4_t mat_b10; + float16x4_t mat_b11; + asm volatile( + "vld1.16 %0[0], [%2]\n" + "vld1.16 %0[1], [%3]\n" + "vld1.16 %0[2], [%4]\n" + "vld1.16 %0[3], [%5]\n" + "vld1.16 %1[0], [%6]\n" + "vld1.16 %1[1], [%7]\n" + "vld1.16 %1[2], [%8]\n" + "vld1.16 %1[3], [%9]\n" + : "=w"(mat_b00), "=w"(mat_b01) + : "r"(mat_b[0] + k), "r"(mat_b[1] + k), "r"(mat_b[2] + k), "r"(mat_b[3] + k), "r"(mat_b[4] + k), + "r"(mat_b[5] + k), "r"(mat_b[6] + k), "r"(mat_b[7] + k) + :); + asm volatile( + "vld1.16 %0[0], [%2]\n" + "vld1.16 %0[1], [%3]\n" + "vld1.16 %0[2], [%4]\n" + "vld1.16 %0[3], [%5]\n" + "vld1.16 %1[0], [%6]\n" + "vld1.16 %1[1], [%7]\n" + "vld1.16 %1[2], [%8]\n" + "vld1.16 %1[3], [%9]\n" + : "=w"(mat_b10), "=w"(mat_b11) + : "r"(mat_b[8] + k), "r"(mat_b[9] + k), "r"(mat_b[10] + k), "r"(mat_b[11] + k), "r"(mat_b[12] + k), + "r"(mat_b[13] + k), "r"(mat_b[14] + k), "r"(mat_b[15] + k) + :); + float16x8_t mat_b0 = vcombine_f16(mat_b00, mat_b01); + float16x8_t mat_b1 = vcombine_f16(mat_b10, mat_b11); +#endif + for (int b = 0; b < batch; b++) { + int dx_offset = b * in_size + offset; + int dy_offset = b * out_size; + float16x8_t mat_c0 = vld1q_f16(c + dx_offset); + float16x8_t mat_a0 = vld1q_f16(a + dy_offset); + mat_c0 = vfmaq_f16(mat_c0, mat_b0, mat_a0); + vst1q_f16(c + dx_offset, mat_c0); + + float16x8_t mat_c1 = vld1q_f16(c + dx_offset + 8); + float16x8_t mat_a1 = vld1q_f16(a + dy_offset + 8); + mat_c1 = vfmaq_f16(mat_c1, mat_b1, mat_a1); + vst1q_f16(c + dx_offset + 8, mat_c1); + } +#else + for (int b = 0; b < batch; b++) { + int dx_offset = b * in_size + offset; + int dy_offset = b * out_size; + for (int j_i = 0; j_i < C16NUM; j_i++) { + c[dx_offset + j_i] += a[dy_offset + j_i] * mat_b[j_i][k]; + } + } +#endif + } + } + } + } + return j; +} + +static int ConvDwInputGrad8(const float16_t *dy, const float16_t *w, float16_t *dx, int start, int end, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int out_w = conv_param->output_w_; + int out_h = conv_param->output_h_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_spatial = conv_param->output_h_ * conv_param->output_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int k_spatial = k_h * k_w; + int batch = conv_param->input_batch_; + int in_size = in_h * in_w * in_ch; + int out_size = out_h * out_w * out_ch; + + int j = start; + for (; j <= (end - C8NUM); j += C8NUM) { + float16_t *c = dx + j; + const float16_t *mat_b[C8NUM]; + for (int j_i = 0; j_i < C8NUM; j_i++) { + mat_b[j_i] = w + (j + j_i) * k_spatial; + } + + for (int si = 0; si < out_spatial; si++) { + const float16_t *a = dy + j + si * out_ch; + int output_row = (si) / out_w; + int output_col = (si) % out_w; + int row_stride_offset = -conv_param->pad_u_ + output_row * conv_param->stride_h_; + int col_stride_offset = -conv_param->pad_l_ + output_col * conv_param->stride_w_; + for (int k = 0; k < k_spatial; k++) { + int kernel_row = k / k_w; + int kernel_col = k % k_w; + int input_row = kernel_row * conv_param->dilation_h_ + row_stride_offset; + int input_col = kernel_col * conv_param->dilation_w_ + col_stride_offset; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset = (input_row * in_w + input_col) * in_ch; +#ifdef ENABLE_ARM +#ifdef ENABLE_ARM64 + float16x8_t mat_b0 = {mat_b[0][k], mat_b[1][k], mat_b[2][k], mat_b[3][k], + mat_b[4][k], mat_b[5][k], mat_b[6][k], mat_b[7][k]}; +#else + float16x4_t mat_b00; + float16x4_t mat_b01; + asm volatile( + "vld1.16 %0[0], [%2]\n" + "vld1.16 %0[1], [%3]\n" + "vld1.16 %0[2], [%4]\n" + "vld1.16 %0[3], [%5]\n" + "vld1.16 %1[0], [%6]\n" + "vld1.16 %1[1], [%7]\n" + "vld1.16 %1[2], [%8]\n" + "vld1.16 %1[3], [%9]\n" + : "=w"(mat_b00), "=w"(mat_b01) + : "r"(mat_b[0] + k), "r"(mat_b[1] + k), "r"(mat_b[2] + k), "r"(mat_b[3] + k), "r"(mat_b[4] + k), + "r"(mat_b[5] + k), "r"(mat_b[6] + k), "r"(mat_b[7] + k) + :); + float16x8_t mat_b0 = vcombine_f16(mat_b00, mat_b01); +#endif + for (int b = 0; b < batch; b++) { + int dx_offset = b * in_size + offset; + int dy_offset = b * out_size; + float16x8_t mat_c0 = vld1q_f16(c + dx_offset); + float16x8_t mat_a0 = vld1q_f16(a + dy_offset); + mat_c0 = vfmaq_f16(mat_c0, mat_b0, mat_a0); + vst1q_f16(c + dx_offset, mat_c0); + } +#else + for (int b = 0; b < batch; b++) { + int dx_offset = b * in_size + offset; + int dy_offset = b * out_size; + for (int j_i = 0; j_i < C8NUM; j_i++) { + c[dx_offset + j_i] += a[dy_offset + j_i] * mat_b[j_i][k]; + } + } +#endif + } + } + } + } + return j; +} + +static int ConvDwInputGrad4(const float16_t *dy, const float16_t *w, float16_t *dx, int start, int end, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int out_w = conv_param->output_w_; + int out_h = conv_param->output_h_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_spatial = conv_param->output_h_ * conv_param->output_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int k_spatial = k_h * k_w; + int batch = conv_param->input_batch_; + int in_size = in_h * in_w * in_ch; + int out_size = out_h * out_w * out_ch; + + int j = start; + for (; j <= (end - C4NUM); j += C4NUM) { + float16_t *c = dx + j; + const float16_t *mat_b_0 = w + (j + 0) * k_spatial; + const float16_t *mat_b_1 = w + (j + 1) * k_spatial; + const float16_t *mat_b_2 = w + (j + 2) * k_spatial; + const float16_t *mat_b_3 = w + (j + 3) * k_spatial; + + for (int si = 0; si < out_spatial; si++) { + const float16_t *a = dy + j + si * out_ch; + int output_row = (si) / out_w; + int output_col = (si) % out_w; + int row_stride_offset = -conv_param->pad_u_ + output_row * conv_param->stride_h_; + int col_stride_offset = -conv_param->pad_l_ + output_col * conv_param->stride_w_; + for (int k = 0; k < k_spatial; k++) { + int kernel_row = k / k_w; + int kernel_col = k % k_w; + int input_row = kernel_row * conv_param->dilation_h_ + row_stride_offset; + int input_col = kernel_col * conv_param->dilation_w_ + col_stride_offset; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset = (input_row * in_w + input_col) * in_ch; +#ifdef ENABLE_ARM +#ifdef ENABLE_ARM64 + float16x4_t mat_b = {mat_b_0[k], mat_b_1[k], mat_b_2[k], mat_b_3[k]}; +#else + float16x4_t mat_b; + asm volatile( + "vld1.16 %0[0], [%1]\n" + "vld1.16 %0[1], [%2]\n" + "vld1.16 %0[2], [%3]\n" + "vld1.16 %0[3], [%4]\n" + : "=w"(mat_b) + : "r"(mat_b_0 + k), "r"(mat_b_1 + k), "r"(mat_b_2 + k), "r"(mat_b_3 + k) + :); +#endif + for (int b = 0; b < batch; b++) { + int dx_offset = b * in_size + offset; + int dy_offset = b * out_size; + float16x4_t mat_c = vld1_f16(c + dx_offset); + float16x4_t mat_a = vld1_f16(a + dy_offset); + mat_c = vfma_f16(mat_c, mat_b, mat_a); + vst1_f16(c + dx_offset, mat_c); + } +#else + for (int b = 0; b < batch; b++) { + int dx_offset = b * in_size + offset; + int dy_offset = b * out_size; + c[dx_offset + 0] += a[dy_offset + 0] * mat_b_0[k]; + c[dx_offset + 1] += a[dy_offset + 1] * mat_b_1[k]; + c[dx_offset + 2] += a[dy_offset + 2] * mat_b_2[k]; + c[dx_offset + 3] += a[dy_offset + 3] * mat_b_3[k]; + } +#endif + } + } + } + } + return j; +} + +int ConvDwInputGradFp16(const float16_t *dy, const float16_t *w, float16_t *dx, int start, int count, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int out_w = conv_param->output_w_; + int out_h = conv_param->output_h_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_spatial = conv_param->output_h_ * conv_param->output_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int k_spatial = k_h * k_w; + int end = start + count; + int batch = conv_param->input_batch_; + int in_size = in_h * in_w * in_ch; + int out_size = out_h * out_w * out_ch; + + int j = start; + j = ConvDwInputGrad16(dy, w, dx, j, end, conv_param); + j = ConvDwInputGrad8(dy, w, dx, j, end, conv_param); + j = ConvDwInputGrad4(dy, w, dx, j, end, conv_param); + for (; j < end; j++) { + float16_t *c = dx + j; + const float16_t *b = w + j * k_spatial; + for (int si = 0; si < out_spatial; si++) { + const float16_t *a = dy + j + si * out_ch; + int output_row = si / out_w; + int output_col = si % out_w; + int row_stride_offset = -conv_param->pad_u_ + output_row * conv_param->stride_h_; + int col_stride_offset = -conv_param->pad_l_ + output_col * conv_param->stride_w_; + for (int k = 0; k < k_spatial; k++) { + int kernel_row = k / k_w; + int kernel_col = k % k_w; + int input_row = kernel_row * conv_param->dilation_h_ + row_stride_offset; + int input_col = kernel_col * conv_param->dilation_w_ + col_stride_offset; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset = (input_row * in_w + input_col) * in_ch; + for (int bi = 0; bi < batch; bi++) { + c[bi * in_size + offset + 0] += a[0 + bi * out_size] * b[k]; + } + } + } + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/convolution_grad_input.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/convolution_grad_input.h new file mode 100644 index 0000000000000000000000000000000000000000..8582a5ea6b6d1f4244b828333d4e030bc4200649 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/convolution_grad_input.h @@ -0,0 +1,33 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_GRAD_CONVOLUTION_GRAD_INPUT_H_ +#define NNACL_FP32_GRAD_CONVOLUTION_GRAD_INPUT_H_ + +#include +#include "nnacl/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ConvDwInputGradFp16(const float16_t *dy, const float16_t *w, float16_t *dx, int start, int count, + const ConvParameter *conv_param); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_GRAD_CONVOLUTION_GRAD_INPUT_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/dropout_grad.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/dropout_grad.c new file mode 100644 index 0000000000000000000000000000000000000000..aee7a0fa273e5e5acb4f16b1016848623c032243 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/dropout_grad.c @@ -0,0 +1,24 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp16_grad/dropout_grad.h" + +void DropoutFp16Grad(const float16_t *yt_ptr, const float16_t *mask, float16_t *output_ptr, int length, + float16_t scale) { + for (int i = 0; i < length; i++) { + output_ptr[i] = yt_ptr[i] * mask[i] * scale; + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/dropout_grad.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/dropout_grad.h new file mode 100644 index 0000000000000000000000000000000000000000..624ed7d895e4f0e1d04163a24329bec83180ac37 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/dropout_grad.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_GRAD_DROPOUT_GRAD_H_ +#define NNACL_FP16_GRAD_DROPOUT_GRAD_H_ + +#include "nnacl/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void DropoutFp16Grad(const float16_t *yt_ptr, const float16_t *mask, float16_t *output_ptr, int length, + float16_t ratio); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_GRAD_DROPOUT_GRAD_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/gemm_fp16.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/gemm_fp16.c new file mode 100644 index 0000000000000000000000000000000000000000..ff0b08a3b2d9295f339bc1a6e6f7b4a75a6f1217 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/gemm_fp16.c @@ -0,0 +1,385 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp16_grad/gemm_fp16.h" +#include +#ifdef __ARM_NEON +#include +#endif +#include "nnacl/fp16/matmul_fp16.h" +#include "nnacl/fp16/pack_fp16.h" + +#ifdef ENABLE_ARM64 +static void Row2Col16Block16(const float16_t *src_ptr, float16_t *dst_ptr, size_t col) { + size_t stride = col * 2; + asm volatile( + "mov x10, %[src_c]\n" + "mov x11, %[dst_c]\n" + + "ld1 {v0.8h}, [x10], %[stride]\n" + "ld1 {v1.8h}, [x10], %[stride]\n" + "ld1 {v2.8h}, [x10], %[stride]\n" + "ld1 {v3.8h}, [x10], %[stride]\n" + "ld1 {v4.8h}, [x10], %[stride]\n" + "ld1 {v5.8h}, [x10], %[stride]\n" + "ld1 {v6.8h}, [x10], %[stride]\n" + "ld1 {v7.8h}, [x10], %[stride]\n" + + "zip1 v16.8h, v0.8h, v1.8h\n" + "zip1 v17.8h, v2.8h, v3.8h\n" + "zip1 v18.8h, v4.8h, v5.8h\n" + "zip1 v19.8h, v6.8h, v7.8h\n" + + "ld1 {v8.8h}, [x10], %[stride]\n" + "ld1 {v9.8h}, [x10], %[stride]\n" + "ld1 {v10.8h}, [x10], %[stride]\n" + "ld1 {v11.8h}, [x10], %[stride]\n" + "ld1 {v12.8h}, [x10], %[stride]\n" + "ld1 {v13.8h}, [x10], %[stride]\n" + "ld1 {v14.8h}, [x10], %[stride]\n" + "ld1 {v15.8h}, [x10], %[stride]\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v24.2d, v20.2d, v22.2d\n" + "trn2 v25.2d, v20.2d, v22.2d\n" + "trn1 v26.2d, v21.2d, v23.2d\n" + "trn2 v27.2d, v21.2d, v23.2d\n" + + "zip1 v16.8h, v8.8h, v9.8h\n" + "zip1 v17.8h, v10.8h, v11.8h\n" + "zip1 v18.8h, v12.8h, v13.8h\n" + "zip1 v19.8h, v14.8h, v15.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v28.2d, v20.2d, v22.2d\n" + "trn2 v29.2d, v20.2d, v22.2d\n" + "trn1 v30.2d, v21.2d, v23.2d\n" + "trn2 v31.2d, v21.2d, v23.2d\n" + + "st1 {v24.8h}, [x11], #16\n" + "st1 {v28.8h}, [x11], #16\n" + "st1 {v26.8h}, [x11], #16\n" + "st1 {v30.8h}, [x11], #16\n" + "st1 {v25.8h}, [x11], #16\n" + "st1 {v29.8h}, [x11], #16\n" + "st1 {v27.8h}, [x11], #16\n" + "st1 {v31.8h}, [x11], #16\n" + + "zip2 v16.8h, v0.8h, v1.8h\n" + "zip2 v17.8h, v2.8h, v3.8h\n" + "zip2 v18.8h, v4.8h, v5.8h\n" + "zip2 v19.8h, v6.8h, v7.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v24.2d, v20.2d, v22.2d\n" + "trn2 v25.2d, v20.2d, v22.2d\n" + "trn1 v26.2d, v21.2d, v23.2d\n" + "trn2 v27.2d, v21.2d, v23.2d\n" + + "zip2 v16.8h, v8.8h, v9.8h\n" + "zip2 v17.8h, v10.8h, v11.8h\n" + "zip2 v18.8h, v12.8h, v13.8h\n" + "zip2 v19.8h, v14.8h, v15.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v28.2d, v20.2d, v22.2d\n" + "trn2 v29.2d, v20.2d, v22.2d\n" + "trn1 v30.2d, v21.2d, v23.2d\n" + "trn2 v31.2d, v21.2d, v23.2d\n" + + "st1 {v24.8h}, [x11], #16\n" + "st1 {v28.8h}, [x11], #16\n" + "st1 {v26.8h}, [x11], #16\n" + "st1 {v30.8h}, [x11], #16\n" + "st1 {v25.8h}, [x11], #16\n" + "st1 {v29.8h}, [x11], #16\n" + "st1 {v27.8h}, [x11], #16\n" + "st1 {v31.8h}, [x11], #16\n" + : + : [ dst_c ] "r"(dst_ptr), [ src_c ] "r"(src_ptr), [ stride ] "r"(stride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31"); +} +#endif + +void AddMatrixFp16(const float16_t *restrict v1, float16_t *restrict v2, float16_t beta, int row, int col, int stride) { + const float16_t *src_ptr = v1; + float16_t *dst_ptr = v2; +#ifdef ENABLE_NEON + float16x8_t beta_0 = vdupq_n_f16(beta); +#endif + for (int r = 0; r < row; r++) { + int c = 0; +#ifdef ENABLE_NEON + for (; c <= (col - C8NUM); c += C8NUM) { + float16x8_t dst_0 = vld1q_f16(dst_ptr + c); + float16x8_t src_0 = vld1q_f16(src_ptr + c); + float16x8_t sum_0 = vfmaq_f16(dst_0, beta_0, src_0); + vst1q_f16(dst_ptr + c, sum_0); + } +#endif + for (; c < col; c++) { + dst_ptr[c] += beta * src_ptr[c]; + } + src_ptr += stride; + dst_ptr += stride; + } +} + +int MatSizeFp16(int row, int col, int round) { + int res = UP_ROUND(row, round) * col; + return res; +} + +int MatSizeTotalFp16(int row, int col, int deep, int stride) { +#ifdef ENABLE_ARM64 + const int num = C16NUM; +#else + const int num = C12NUM; +#endif + int res = MatSizeFp16(row, deep, num) + MatSizeFp16(col, deep, C8NUM); + if (stride > 0) res += row * stride; + return res; +} + +#ifdef ENABLE_ARM64 +static void RowMajor2Col16MajorStrideFp16(const float16_t *src, float16_t *dst, int row, int col, int stride) { + size_t row_up_16 = UP_ROUND(row, C16NUM); + size_t row16 = row / C16NUM * C16NUM; + size_t col8 = col / C8NUM * C8NUM; + const float16_t *src_r = src; + float16_t *dst_r = dst; + size_t ri = 0; + // find 16 block unit + for (; ri < row16; ri += C16NUM) { + size_t ci = 0; + for (; ci < col8; ci += C8NUM) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C16NUM; +#ifdef ENABLE_ARM64 + Row2Col16Block16(src_c, dst_c, stride); +#else + for (int tr = 0; tr < C16NUM; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + dst_c[tc * C16NUM + tr] = src_c[tr * stride + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C16NUM; + for (size_t i = 0; i < C16NUM; i++) { + dst_c[i] = src_c[i * stride]; + } + } + src_r += C16NUM * stride; + dst_r += C16NUM * col; + } + for (; ri < row; ri++) { + for (size_t i = 0; i < col; ++i) { + dst_r[i * C16NUM] = src_r[i]; + } + src_r += stride; + dst_r += 1; + } + for (; ri < row_up_16; ri++) { + for (size_t i = 0; i < col; i++) { + dst_r[i * C16NUM] = 0; + } + dst_r += 1; + } + return; +} +#endif + +void RowMajor2Row16MajorStrideFp16(const float16_t *src, float16_t *dst, int row, int col, int stride) { + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int c_div16 = c / C16NUM; + int c_mod16 = c % C16NUM; + dst[c_div16 * C16NUM * row + r * C16NUM + c_mod16] = src[r * stride + c]; + } + } +} + +void RowMajor2Col12MajorStrideFp16(const float16_t *src, float16_t *dst, size_t row, size_t col, int stride) { + size_t row_up_12 = UP_ROUND(row, C12NUM); + size_t row12 = row / C12NUM * C12NUM; + size_t col8 = col / C8NUM * C8NUM; + const float16_t *src_r = src; + float16_t *dst_r = dst; + size_t ri = 0; + // transpose 12x8 + for (; ri < row12; ri += C12NUM) { + size_t ci = 0; + for (; ci < col8; ci += C8NUM) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C12NUM; +#ifdef ENABLE_ARM82_A32 + Transpose12x8A32Fp16(src_c, dst_c, stride * sizeof(float16_t), 24); +#else + for (int tr = 0; tr < C12NUM; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + dst_c[tc * C12NUM + tr] = src_c[tr * stride + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C12NUM; + for (size_t i = 0; i < C12NUM; i++) { + dst_c[i] = src_c[i * stride]; + } + } + src_r += C12NUM * stride; + dst_r += C12NUM * col; + } + for (; ri < row; ri++) { + for (size_t i = 0; i < col; ++i) { + dst_r[i * C12NUM] = src_r[i]; + } + src_r += stride; + dst_r += 1; + } + for (; ri < row_up_12; ri++) { + for (size_t i = 0; i < col; i++) { + dst_r[i * C12NUM] = 0; + } + dst_r += 1; + } +} + +void RowMajor2Row12MajorStrideFp16(const float16_t *src, float16_t *dst, int row, int col, int stride) { + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int c_div12 = c / C12NUM; + int c_mod12 = c % C12NUM; + dst[c_div12 * C12NUM * row + r * C12NUM + c_mod12] = src[r * stride + c]; + } + } +} + +static void RowMajor2Col8MajorStrideFp16(const float16_t *src, float16_t *dst, int row, int col, int stride) { + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int r_div8 = r / C8NUM; + int r_mod8 = r % C8NUM; + dst[r_div8 * C8NUM * col + c * C8NUM + r_mod8] = src[r * stride + c]; + } + } +} + +static void RowMajor2Row8MajorStrideFp16(const float16_t *src, float16_t *dst, int row, int col, int stride) { + for (int r = 0; r < row; r++) { + const float16_t *src_ptr = src + r * stride; + int c = 0; + for (; c < col; c++) { + int cd8 = c / C8NUM; + int cm8 = c % C8NUM; + dst[cd8 * C8NUM * row + r * C8NUM + cm8] = src_ptr[c]; + } + for (; c < UP_ROUND(col, C8NUM); c++) { + int cd8 = c / C8NUM; + int cm8 = c % C8NUM; + dst[cd8 * C8NUM * row + r * C8NUM + cm8] = 0; + } + } + return; +} + +static void RowMajor2ColXMajorStrideFp16(const float16_t *src, float16_t *dst, int row, int col, int stride) { +#ifdef ENABLE_ARM64 + RowMajor2Col16MajorStrideFp16(src, dst, row, col, stride); +#else + RowMajor2Col12MajorStrideFp16(src, dst, row, col, stride); +#endif +} + +static void RowMajor2RowXMajorStrideFp16(const float16_t *src, float16_t *dst, int row, int col, int stride) { +#ifdef ENABLE_ARM64 + RowMajor2Row16MajorStrideFp16(src, dst, row, col, stride); +#else + RowMajor2Row12MajorStrideFp16(src, dst, row, col, stride); +#endif +} + +void GemmMatmulFp16(int ta, int tb, int M, int N, int K, float16_t alpha, const float16_t *mat_a, int lda, + const float16_t *mat_b, int ldb, float16_t beta, float16_t *mat_c, int ldc, float16_t *workspace) { + GemmCbFp16 gcb; + gcb.atype = ActType_No; + gcb.ca = 0; + gcb.cb = 0; + gcb.bias = NULL; + GemmMatmulPlusFp16(ta, tb, M, N, K, alpha, mat_a, lda, mat_b, ldb, beta, mat_c, ldc, workspace, &gcb); +} + +void GemmMatmulPlusFp16(int ta, int tb, int M, int N, int K, float16_t alpha, const float16_t *mat_a, int lda, + const float16_t *mat_b, int ldb, float16_t beta, float16_t *mat_c, int ldc, + float16_t *workspace, GemmCbFp16 *gcb) { +#ifdef ENABLE_ARM64 + const int num = C16NUM; +#else + const int num = C12NUM; +#endif + float16_t *output = mat_c; + float16_t *fworkspace = workspace; + int incremental = (beta < 0.f) || (beta > 0.f); + float16_t *mat_a_input = (float16_t *)mat_a; + float16_t *mat_b_input = (float16_t *)mat_b; + + if (!gcb->ca) { + mat_a_input = fworkspace; + fworkspace += MatSizeFp16(M, K, num); + if (ta) { + RowMajor2RowXMajorStrideFp16(mat_a, mat_a_input, K, M, lda); + } else { + RowMajor2ColXMajorStrideFp16(mat_a, mat_a_input, M, K, lda); + } + } + if (!gcb->cb) { + mat_b_input = fworkspace; + fworkspace += MatSizeFp16(N, K, C8NUM); + if (tb) { + RowMajor2Col8MajorStrideFp16(mat_b, mat_b_input, N, K, ldb); + } else { + RowMajor2Row8MajorStrideFp16(mat_b, mat_b_input, K, N, ldb); + } + } + if (incremental) output = fworkspace; + MatMulFp16(mat_a_input, mat_b_input, output, gcb->bias, gcb->atype, K, M, N, ldc, OutType_Nhwc); + if (incremental) AddMatrixFp16(output, mat_c, beta, M, N, ldc); + gcb->mat_a = mat_a_input; + gcb->mat_b = mat_b_input; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/gemm_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/gemm_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..82cc9fd677ed99d4b0b86d1b93295f8270a6a381 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/gemm_fp16.h @@ -0,0 +1,46 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_GRAD_GEMM_FP16_H_ +#define NNACL_FP16_GRAD_GEMM_FP16_H_ + +#include +#include "nnacl/op_base.h" +#ifdef __cplusplus +extern "C" { +#endif +typedef struct { + int ca; + int cb; + ActType atype; + float16_t *bias; + float16_t *mat_a; + float16_t *mat_b; +} GemmCbFp16; + +void GemmMatmulFp16(int ta, int tb, int M, int N, int K, float16_t alpha, const float16_t *mat_a, int lda, + const float16_t *mat_b, int ldb, float16_t beta, float16_t *mat_c, int ldc, float16_t *workspace); +void GemmMatmulPlusFp16(int ta, int tb, int M, int N, int K, float16_t alpha, const float16_t *mat_a, int lda, + const float16_t *mat_b, int ldb, float16_t beta, float16_t *mat_c, int ldc, + float16_t *workspace, GemmCbFp16 *gcb); +int MatSizeFp16(int row, int col, int round); +int MatSizeTotalFp16(int row, int col, int deep, int inc); +void AddMatrixFp16(const float16_t *v1, float16_t *v2, float16_t beta, int row, int col, int stride); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_GRAD_GEMM_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/layernorm_grad.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/layernorm_grad.c new file mode 100644 index 0000000000000000000000000000000000000000..e0795a9c74ad9bdf764d9d97e8c8874705310b72 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/layernorm_grad.c @@ -0,0 +1,60 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp16_grad/layernorm_grad.h" +#include +#include + +void LayerNormFp16Grad(const float16_t *x, const float16_t *dy, const float16_t *var, const float16_t *mean, + const float16_t *gamma, int param_num, int param_size, int block_num, int block_size, + float16_t *dx, float16_t *dg, float16_t *db) { + // var is actually 1/sqrf(var)-> var^0.5 + NNACL_CHECK_ZERO_RETURN(block_size); + const float16_t *var_sqrt_rev = var; + for (size_t i = 0; i < param_num; ++i) { + float dgamma = 0.0f; + float dbeta = 0.0f; + for (size_t j = i; j < param_size * param_num; j += param_num) { + int norm_shift = (int)(j / block_size); + dgamma += dy[j] * var_sqrt_rev[norm_shift] * (x[j] - mean[norm_shift]); + dbeta += dy[j]; + } + dg[i] = (float16_t)dgamma; + db[i] = (float16_t)dbeta; + } + for (size_t i = 0; i < block_num; ++i) { + float sum1 = 0.0f; + float sum2 = 0.0f; + float sum3 = 0.0f; + for (size_t j = i * block_size; j < (i + 1) * block_size; ++j) { + int param_shift = j % param_num; + int norm_shift = (int)(j / block_size); + float16_t dxm = x[j] - mean[norm_shift]; + float16_t dyg = dy[j] * gamma[param_shift]; + sum1 += -0.5f * dyg * dxm * var_sqrt_rev[norm_shift] * var_sqrt_rev[norm_shift] * var_sqrt_rev[norm_shift]; + sum2 += dyg; + sum3 += -2.0f * dxm; + } + for (size_t j = i * block_size; j < (i + 1) * block_size; ++j) { + int param_shift = j % param_num; + int norm_shift = (int)(j / block_size); + float16_t var_sqrt = var_sqrt_rev[norm_shift]; + float dx1 = dy[j] * gamma[param_shift] * var_sqrt; + float dx2 = sum1 * 2.0f / block_size * (x[j] - mean[norm_shift]); + float dx3 = (-1.0f * var_sqrt * sum2 + (1.0f / block_size) * sum1 * sum3) * (1.0f / block_size); + dx[j] = (float16_t)(dx1 + dx2 + dx3); + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/layernorm_grad.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/layernorm_grad.h new file mode 100644 index 0000000000000000000000000000000000000000..9f6c1c80d687dfacf2db6129cc3212773e398730 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/layernorm_grad.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_GRAD_LAYERNORM_GRAD_H_ +#define NNACL_FP16_GRAD_LAYERNORM_GRAD_H_ + +#include "nnacl/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void LayerNormFp16Grad(const float16_t *x, const float16_t *dy, const float16_t *var, const float16_t *mean, + const float16_t *gamma, int param_num, int param_size, int block_num, int block_size, + float16_t *dx, float16_t *dg, float16_t *db); + +#ifdef __cplusplus +} +#endif +#endif // NNACL_FP16_GRAD_LAYERNORM_GRAD_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/pack_fp16_ext.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/pack_fp16_ext.c new file mode 100644 index 0000000000000000000000000000000000000000..8d7b1bbc7b3fbe14b68155b777f6cb549ef48d36 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/pack_fp16_ext.c @@ -0,0 +1,201 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "nnacl/fp16_grad/pack_fp16_ext.h" + +void RollingIm2ColPackDwUnitFp16(const float16_t *in_data, const ConvParameter *conv_param, float16_t *data_col_orig, + int real_cal_num, int start) { + const int pad_left = conv_param->pad_l_; + const int pad_up = conv_param->pad_u_; + + const int stride_h = conv_param->stride_h_; + const int stride_w = conv_param->stride_w_; + + const int dilation_h = conv_param->dilation_h_; + const int dilation_w = conv_param->dilation_w_; + + const int kernel_h = conv_param->kernel_h_; + const int kernel_w = conv_param->kernel_w_; + + const int in_height = conv_param->input_h_; + const int in_width = conv_param->input_w_; + + const int output_w = conv_param->output_w_; + + const int channels = conv_param->input_channel_; + const int stride = kernel_h * kernel_w; + + int kernel_row, kernel_col; + + for (int i = 0; i < real_cal_num; i++) { + int block_start = start + i; + int input_h = block_start / output_w * stride_h; + int input_w = block_start % output_w * stride_w; + float16_t *data_col = data_col_orig + i * channels * stride; + for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + int input_row = -pad_up + kernel_row * dilation_h + input_h; + for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + int input_col = -pad_left + kernel_col * dilation_w + input_w; + if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { + const int offset = (input_row * in_width + input_col) * channels; + for (int c = 0; c < channels; c++) { + data_col[c * stride] = in_data[offset + c]; + } + data_col++; + } else { + for (int c = 0; c < channels; c++) { + data_col[c * stride] = 0; + } + data_col++; + } + } + } + } +} + +void RollingIm2ColPackUnitFp16(const float16_t *input_data, const ConvParameter *conv_param, float16_t *packed_input, + int real_cal_num, int block_index) { + const int pad_left = conv_param->pad_l_; + const int pad_up = conv_param->pad_u_; + + const int stride_h = conv_param->stride_h_; + const int stride_w = conv_param->stride_w_; + + const int dilation_h = conv_param->dilation_h_; + const int dilation_w = conv_param->dilation_w_; + + const int kernel_h = conv_param->kernel_h_; + const int kernel_w = conv_param->kernel_w_; + + const int in_height = conv_param->input_h_; + const int in_width = conv_param->input_w_; + + const int output_w = conv_param->output_w_; + + const int channels = conv_param->input_channel_ / conv_param->group_; + const int tot_channels = conv_param->input_channel_; + + int kernel_row, kernel_col; + + if (channels == 1) { + for (int i = 0; i < real_cal_num; i++) { + int block_start = block_index + i; + int input_h = block_start / output_w * stride_h; + int input_w = block_start % output_w * stride_w; + for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + int input_row = -pad_up + kernel_row * dilation_h + input_h; + for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + int input_col = -pad_left + kernel_col * dilation_w + input_w; + if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { + const int offset = (input_row * in_width + input_col) * tot_channels; + *packed_input = input_data[offset]; + packed_input++; + } else { + *packed_input = 0; + packed_input++; + } + } + } + } + } else { + for (int i = 0; i < real_cal_num; i++) { + int block_start = block_index + i; + int input_h = block_start / output_w * stride_h; + int input_w = block_start % output_w * stride_w; + for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + int input_row = -pad_up + kernel_row * dilation_h + input_h; + for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + int input_col = -pad_left + kernel_col * dilation_w + input_w; + if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { + const int offset = (input_row * in_width + input_col) * tot_channels; + memcpy(packed_input, input_data + offset, sizeof(float16_t) * channels); + packed_input += channels; + } else { + memset(packed_input, 0, sizeof(float16_t) * channels); + packed_input += channels; + } + } + } + } + } +} + +void RollingCol2ImPackUnitFp16(const float16_t *data_col, float16_t *data_im, const ConvParameter *conv_param, + int real_cal_num, int block_index) { + const int pad_left = conv_param->pad_l_; + const int pad_up = conv_param->pad_u_; + + const int stride_h = conv_param->stride_h_; + const int stride_w = conv_param->stride_w_; + + const int dilation_h = conv_param->dilation_h_; + const int dilation_w = conv_param->dilation_w_; + + const int kernel_h = conv_param->kernel_h_; + const int kernel_w = conv_param->kernel_w_; + + const int in_height = conv_param->input_h_; + const int in_width = conv_param->input_w_; + + const int output_w = conv_param->output_w_; + const int channels = conv_param->input_channel_ / conv_param->group_; + const int tot_channels = conv_param->input_channel_; + + int kernel_row, kernel_col; + + if (channels == 1) { + for (int r = 0; r < real_cal_num; r++) { + int output_col = (block_index + r) % output_w; + int output_row = (block_index + r) / output_w; + int row_stride_offset = output_row * stride_h; + int col_stride_offset = output_col * stride_w; + for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + int input_row = -pad_up + kernel_row * dilation_h + row_stride_offset; + for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + int input_col = -pad_left + kernel_col * dilation_w + col_stride_offset; + if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { + int offset = (input_row * in_width + input_col) * tot_channels; + float16_t *data_im_ptr = data_im + offset; + *data_im_ptr += *data_col; + } + data_col++; + } + } + } + } else { + for (int r = 0; r < real_cal_num; r++) { + int output_col = (block_index + r) % output_w; + int output_row = (block_index + r) / output_w; + int row_stride_offset = output_row * stride_h; + int col_stride_offset = output_col * stride_w; + for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + int input_row = -pad_up + kernel_row * dilation_h + row_stride_offset; + for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + int input_col = -pad_left + kernel_col * dilation_w + col_stride_offset; + if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { + int offset = (input_row * in_width + input_col) * tot_channels; + float16_t *data_im_ptr = &data_im[offset]; + for (int i = 0; i < channels; i++) { + data_im_ptr[i] += data_col[i]; + } + } + data_col += channels; + } + } + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/pack_fp16_ext.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/pack_fp16_ext.h new file mode 100644 index 0000000000000000000000000000000000000000..3806d9410c6e089ba4f2e9bbfedbe2cb0650d068 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/pack_fp16_ext.h @@ -0,0 +1,37 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_GRAD_PACK_FP16_EXT_H_ +#define NNACL_FP16_GRAD_PACK_FP16_EXT_H_ + +#include +#include "nnacl/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void RollingIm2ColPackUnitFp16(const float16_t *input_data, const ConvParameter *conv_param, float16_t *packed_input, + int real_cal_num, int block_index); +void RollingIm2ColPackDwUnitFp16(const float16_t *input_data, const ConvParameter *conv_param, float16_t *packed_input, + int real_cal_num, int block_index); +void RollingCol2ImPackUnitFp16(const float16_t *data_col, float16_t *data_im, const ConvParameter *conv_param, + int real_cal_num, int block_index); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_GRAD_PACK_FP16_EXT_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/pooling_grad.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/pooling_grad.c new file mode 100644 index 0000000000000000000000000000000000000000..06d45e5c54a6abc95ee63b33ebea94be4ceb7ea2 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/pooling_grad.c @@ -0,0 +1,192 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include "nnacl/fp16_grad/pooling_grad.h" +#include "nnacl/op_base.h" + +void AvgPoolingFp16Grad(const float16_t *input_ptr, float16_t *output_ptr, int count, PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args) { + int stride_w = pooling_param->stride_w_; + int stride_h = pooling_param->stride_h_; + int pad_w = pooling_param->pad_l_; + int pad_h = pooling_param->pad_u_; + int win_w = pooling_args->window_w_; + int win_h = pooling_args->window_h_; + int channel = pooling_args->input_channel_; + int in_w = pooling_args->input_w_; + int in_h = pooling_args->input_h_; + int output_w = pooling_args->output_w_; + int output_h = pooling_args->output_h_; + + const float16_t kk = 1.0f / (float16_t)(win_h * win_w); +#if ENABLE_NEON + const float16x4_t factor = vdup_n_f16(kk); +#endif + for (int ib = 0; ib < count; ib++) { + float16_t *out = &output_ptr[(ib * in_h * in_w * channel)]; + const float16_t *inPtr = &input_ptr[(ib * output_h * output_w * channel)]; + // iterate over yt + for (int yh = 0; yh < output_h; yh++) { + int over_h = pad_h - yh * stride_h; + int kh_s = MSMAX(0, over_h); + int kh_e = MSMIN(win_h, in_h + over_h); + for (int yw = 0; yw < output_w; yw++) { + int over_w = pad_w - yw * stride_w; + int kw_s = MSMAX(0, over_w); + int kw_e = MSMIN(win_w, in_w + over_w); + int ic = 0; + for (; ic < channel - C4NUM; ic += C4NUM) { + int idx = (yw + yh * output_w) * channel + ic; +#ifdef ENABLE_NEON + float16x4_t in = vld1_f16(inPtr + idx); + float16x4_t delta = vmul_f16(in, factor); +#else + float16_t delta[C4NUM] = {inPtr[idx], inPtr[idx + C1NUM], inPtr[idx + C2NUM], inPtr[idx + C3NUM]}; + for (int i = 0; i < C4NUM; i++) delta[i] *= kk; +#endif + for (int kh = kh_s; kh < kh_e; kh++) { + int xh = yh * stride_h + kh - pad_h; + for (int kw = kw_s; kw < kw_e; kw++) { + int xw = yw * stride_w + kw - pad_w; +#ifdef ENABLE_NEON + float16_t *out_vec = out + (xw + in_w * xh) * channel + ic; + float16x4_t outr = vld1_f16(out + (xw + in_w * xh) * channel + ic); + float16x4_t outs = vadd_f16(outr, delta); + vst1_f16(out_vec, outs); +#else + + for (int i = 0; i < C4NUM; i++) { + out[(xw + in_w * xh) * channel + ic + i] += ((float16_t *)&delta)[i]; + } +#endif + } + } + } + for (; ic < channel; ic++) { + int idx = (yw + yh * output_w) * channel + ic; + float16_t delta = inPtr[idx] * kk; + for (int kh = kh_s; kh < kh_e; kh++) { + int xh = yh * stride_h + kh - pad_h; + for (int kw = kw_s; kw < kw_e; kw++) { + int xw = yw * stride_w + kw - pad_w; + out[(xw + in_w * xh) * channel + ic] += delta; + } + } + } + } + } + } +} + +#ifdef ENABLE_NEON +static int32x4_t MaxIndex(float16x4_t in, float16x4_t *max, uint32x4_t index, uint32x4_t prev_index) { + uint16x4_t res = vcgt_f16(in, *max); + int16x4_t tmp = vreinterpret_s16_u16(res); + uint32x4_t res_tmp = vreinterpretq_u32_s32(vmovl_s16(tmp)); + int32x4_t m_index = vbslq_s32(res_tmp, index, prev_index); + *max = vbsl_f16(res, in, *max); + return m_index; +} +#endif + +void MaxPoolingFp16Grad(const float16_t *input_ptr, const float16_t *dy_ptr, float16_t *output_ptr, int output_batch, + PoolingParameter *pooling_param, const PoolingComputeParam *pooling_args) { + int stride_w = pooling_param->stride_w_; + int stride_h = pooling_param->stride_h_; + int pad_w = pooling_param->pad_l_; + int pad_h = pooling_param->pad_u_; + int win_w = pooling_args->window_w_; + int win_h = pooling_args->window_h_; + int channel = pooling_args->input_channel_; + int in_w = pooling_args->input_w_; + int in_h = pooling_args->input_h_; + int output_w = pooling_args->output_w_; + int output_h = pooling_args->output_h_; + + for (int ib = 0; ib < output_batch; ib++) { + float16_t *out = &output_ptr[(ib * in_h * in_w * channel)]; + const float16_t *inPtr = &input_ptr[(ib * in_h * in_w * channel)]; + const float16_t *dyPtr = &dy_ptr[(ib * output_h * output_w * channel)]; + for (int yh = 0; yh < output_h; yh++) { + int over_h = pad_h - yh * stride_h; + int kh_s = MSMAX(0, over_h); + int kh_e = MSMIN(win_h, in_h + over_h); + for (int yw = 0; yw < output_w; yw++) { + int over_w = pad_w - yw * stride_w; + int kw_s = MSMAX(0, over_w); + int kw_e = MSMIN(win_w, in_w + over_w); + int ic = 0; + for (; ic < (channel & ~3); ic += C4NUM) { + int idx = (yw + yh * output_w) * channel + ic; +#ifdef ENABLE_NEON + uint32x4_t max_idx = vdupq_n_u32(0); + float16x4_t max_val = vdup_n_f16(-FLT16_MAX); + float16x4_t delta = vld1_f16(dyPtr + idx); +#else + float16_t delta[C4NUM] = {dyPtr[idx], dyPtr[idx + C1NUM], dyPtr[idx + C2NUM], dyPtr[idx + C3NUM]}; + float16_t max_val[C4NUM] = {-FLT16_MAX, -FLT16_MAX, -FLT16_MAX, -FLT16_MAX}; + uint max_idx[C4NUM] = {0}; +#endif + for (int kh = kh_s; kh < kh_e; kh++) { + int xh = yh * stride_h + kh - pad_h; + for (int kw = kw_s; kw < kw_e; kw++) { + int xw = yw * stride_w + kw - pad_w; + int val_idx = (xw + in_w * xh) * channel + ic; +#ifdef ENABLE_NEON + uint32x4_t index = {val_idx, val_idx + 1, val_idx + 2, val_idx + 3}; + float16x4_t in = vld1_f16(inPtr + val_idx); + max_idx = MaxIndex(in, &max_val, index, max_idx); +#else + float16_t val[C4NUM] = {inPtr[val_idx], inPtr[val_idx + C1NUM], inPtr[val_idx + C2NUM], + inPtr[val_idx + C3NUM]}; + for (int i = 0; i < C4NUM; i++) { + if (val[i] > max_val[i]) { + max_val[i] = val[i]; + max_idx[i] = val_idx + i; + } + } +#endif + } + } + for (int i = 0; i < C4NUM; i++) { + out[((int *)&max_idx)[i]] += ((float16_t *)&delta)[i]; + } + } + for (; ic < channel; ic++) { + float16_t max_val = -FLT16_MAX; + int max_idx = 0; + int idx = (yw + yh * output_w) * channel + ic; + float16_t delta = dyPtr[idx]; + for (int kh = kh_s; kh < kh_e; kh++) { + int xh = yh * stride_h + kh - pad_h; + for (int kw = kw_e; kw < kw_s; kw++) { + int xw = yw * stride_w + kw - pad_w; + int val_idx = (xw + in_w * xh) * channel + ic; + float16_t val = inPtr[val_idx]; + if (val > max_val) { + max_val = val; + max_idx = val_idx; + } + } + } + out[max_idx] += delta; + } + } + } + } +} diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/hash.cuh b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/pooling_grad.h old mode 100755 new mode 100644 similarity index 48% rename from mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/hash.cuh rename to mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/pooling_grad.h index 779abba36b1cec2e72e46063a2a537a92fc24e45..616cc043b48033786194f70c7a36c619a4ef089c --- a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/hash.cuh +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/pooling_grad.h @@ -14,14 +14,21 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_CDUA_IMPL_HASH_H_ -#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_CDUA_IMPL_HASH_H_ +#ifndef NNACL_FP16_GRAD_POOLING_GRAD_H_ +#define NNACL_FP16_GRAD_POOLING_GRAD_H_ -template -void DoHashSwapOut(const T *hash_table, T *swap_out_value, const int *swap_out_index, const int index_size, - const int hash_dim, cudaStream_t cuda_stream); +#include "nnacl/fp16/pooling_fp16.h" +#include "nnacl/kernel/pooling.h" -template -void DoHashSwapIn(T *hash_table, const T *swap_in_value, const int *swap_in_index, const int index_size, - const int hash_dim, cudaStream_t cuda_stream); -#endif // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_CDUA_IMPL_HASH_H_ +#ifdef __cplusplus +extern "C" { +#endif +void AvgPoolingFp16Grad(const float16_t *input_ptr, float16_t *output_ptr, int count, PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args); +void MaxPoolingFp16Grad(const float16_t *input_ptr, const float16_t *dy_ptr, float16_t *output_ptr, int output_batch, + PoolingParameter *pooling_param, const PoolingComputeParam *pooling_args); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_GRAD_POOLING_GRAD_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/resize_grad.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/resize_grad.c new file mode 100644 index 0000000000000000000000000000000000000000..8a621f41611a98c865609913b0d00c1c0619399f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/resize_grad.c @@ -0,0 +1,146 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp16_grad/resize_grad.h" +#include +#include "nnacl/infer/common_infer.h" + +int ResizeNearestNeighborFp16Grad(float16_t *in_addr, float16_t *out_addr, int batch_size, int channel, int format, + ResizeFp16GradParameter *param) { + bool align_corners = param->align_corners_; + size_t in_hw_size = param->in_width_ * param->in_height_; + size_t out_hw_size = param->out_width_ * param->out_height_; + + if (format == Format_NHWC) { + NNACL_CHECK_ZERO_RETURN_ERR(param->in_width_); + for (int32_t b = 0; b < batch_size; ++b) { + for (size_t i = 0; i < in_hw_size; ++i) { + size_t in_y = i / param->in_width_; + size_t in_x = i % param->in_width_; + for (int32_t c = 0; c < channel; ++c) { + size_t out_y = MSMIN( + (align_corners) ? (size_t)roundf(in_y * param->height_scale_) : (size_t)floorf(in_y * param->height_scale_), + param->out_height_ - 1); + size_t out_x = MSMIN( + (align_corners) ? (size_t)roundf(in_x * param->width_scale_) : (size_t)floorf(in_x * param->width_scale_), + param->out_width_ - 1); + size_t out_offset = out_y * (param->out_width_ * channel) + (out_x * channel) + c; + size_t in_offset = in_y * (param->in_width_ * channel) + (in_x * channel) + c; + out_addr[out_offset] += in_addr[in_offset]; + } + } + out_addr += out_hw_size * channel; + in_addr += in_hw_size * channel; + } + } else if (format == Format_NCHW) { + for (int32_t b = 0; b < batch_size; ++b) { + for (int32_t c = 0; c < channel; ++c) { + for (size_t h = 0; h < param->in_height_; ++h) { + size_t out_y = + MSMIN((align_corners) ? (size_t)roundf(h * param->height_scale_) : (size_t)floorf(h * param->height_scale_), + param->out_height_ - 1); + for (size_t w = 0; w < param->in_width_; ++w) { + size_t out_x = + MSMIN((align_corners) ? (size_t)roundf(w * param->width_scale_) : (size_t)floorf(w * param->width_scale_), + param->out_width_ - 1); + out_addr[out_y * param->out_width_ + out_x] += in_addr[h * param->in_width_ + w]; + } + } + out_addr += out_hw_size; + in_addr += in_hw_size; + } + } + } + return NNACL_OK; +} + +int ResizeBiLinearFp16Grad(float16_t *in_addr, float16_t *out_addr, int batch_size, int channel, int format, + ResizeFp16GradParameter *param) { + size_t in_hw_size = param->in_width_ * param->in_height_; + size_t out_hw_size = param->out_width_ * param->out_height_; + + if (format == Format_NHWC) { + NNACL_CHECK_ZERO_RETURN_ERR(param->in_width_); + for (int32_t b = 0; b < batch_size; ++b) { + for (size_t i = 0; i < in_hw_size; ++i) { + size_t h = i / param->in_width_; + size_t w = i % param->in_width_; + for (int32_t c = 0; c < channel; ++c) { + float16_t in_y = (float16_t)h * param->height_scale_; + size_t top_y_index = MSMAX((size_t)(floorf(in_y)), (size_t)(0)); + size_t bottom_y_index = MSMIN((size_t)(ceilf(in_y)), param->out_height_ - 1); + float16_t y_lerp = in_y - floorf(in_y); + const float16_t inverse_y_lerp = 1.0 - y_lerp; + + float16_t in_x = (float16_t)w * param->width_scale_; + size_t left_x_index = MSMAX((size_t)(floorf(in_x)), (size_t)(0)); + size_t right_x_index = MSMIN((size_t)(ceilf(in_x)), param->out_width_ - 1); + float16_t x_lerp = in_x - floorf(in_x); + const float16_t inverse_x_lerp = 1.0 - x_lerp; + + size_t in_offset = h * (param->in_width_ * channel) + (w * channel) + c; + size_t out_offset_top_y_left_x = top_y_index * (param->out_width_ * channel) + (left_x_index * channel) + c; + size_t out_offset_top_y_right_x = top_y_index * (param->out_width_ * channel) + (right_x_index * channel) + c; + size_t out_offset_bottom_y_left_x = + bottom_y_index * (param->out_width_ * channel) + (left_x_index * channel) + c; + size_t out_offset_bottom_y_right_x = + bottom_y_index * (param->out_width_ * channel) + (right_x_index * channel) + c; + + out_addr[out_offset_top_y_left_x] += in_addr[in_offset] * (float16_t)(inverse_y_lerp * inverse_x_lerp); + out_addr[out_offset_top_y_right_x] += in_addr[in_offset] * (float16_t)(inverse_y_lerp * x_lerp); + out_addr[out_offset_bottom_y_left_x] += in_addr[in_offset] * (float16_t)(y_lerp * inverse_x_lerp); + out_addr[out_offset_bottom_y_right_x] += in_addr[in_offset] * (float16_t)(y_lerp * x_lerp); + } + } + out_addr += out_hw_size * channel; + in_addr += in_hw_size * channel; + } + } else if (format == Format_NCHW) { + size_t in_height = param->in_height_; + size_t in_width = param->in_width_; + size_t out_height = param->out_height_; + size_t out_width = param->out_width_; + + for (size_t b = 0; b < batch_size; ++b) { + for (size_t c = 0; c < channel; ++c) { + for (size_t h = 0; h < in_height; ++h) { + const float16_t in_y = (float16_t)(h)*param->height_scale_; + const size_t top_y_index = MSMAX((size_t)floorf(in_y), 0); + const size_t bottom_y_index = MSMIN((size_t)ceilf(in_y), out_height - 1); + const float16_t y_lerp = in_y - floorf(in_y); + const float16_t inverse_y_lerp = 1.0 - y_lerp; + for (size_t w = 0; w < in_width; ++w) { + const float16_t in_x = (float16_t)(w)*param->width_scale_; + const size_t left_x_index = MSMAX((size_t)floorf(in_x), 0); + const size_t right_x_index = MSMIN((size_t)ceilf(in_x), out_width - 1); + const float16_t x_lerp = in_x - floorf(in_x); + const float16_t inverse_x_lerp = 1.0 - x_lerp; + out_addr[top_y_index * out_width + left_x_index] += + in_addr[h * in_width + w] * (float16_t)(inverse_y_lerp * inverse_x_lerp); + out_addr[top_y_index * out_width + right_x_index] += + in_addr[h * in_width + w] * (float16_t)(inverse_y_lerp * x_lerp); + out_addr[bottom_y_index * out_width + left_x_index] += + in_addr[h * in_width + w] * (float16_t)(y_lerp * inverse_x_lerp); + out_addr[bottom_y_index * out_width + right_x_index] += + in_addr[h * in_width + w] * (float16_t)(y_lerp * x_lerp); + } + } + out_addr += out_hw_size; + in_addr += in_hw_size; + } + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/resize_grad.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/resize_grad.h new file mode 100644 index 0000000000000000000000000000000000000000..9320f6f6787b79a67e55ffbb633f4b2a0ca3082a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/resize_grad.h @@ -0,0 +1,45 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_GRAD_RESIZE_GRAD_H_ +#define NNACL_FP16_GRAD_RESIZE_GRAD_H_ + +#include "nnacl/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct ResizeFp16GradParameter { + OpParameter op_parameter_; + bool align_corners_; + int method; + size_t in_height_; + size_t in_width_; + size_t out_height_; + size_t out_width_; + float16_t height_scale_; + float16_t width_scale_; +} ResizeFp16GradParameter; + +int ResizeNearestNeighborFp16Grad(float16_t *in_addr, float16_t *out_addr, int batch_size, int channel, int format, + ResizeFp16GradParameter *param); +int ResizeBiLinearFp16Grad(float16_t *in_addr, float16_t *out_addr, int batch_size, int channel, int format, + ResizeFp16GradParameter *param); +#ifdef __cplusplus +} +#endif +#endif // NNACL_FP16_GRAD_RESIZE_GRAD_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/strided_slice_grad.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/strided_slice_grad.c new file mode 100644 index 0000000000000000000000000000000000000000..b8cfe6b50d862f21f61459d66404d65cbf8c7b9d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/strided_slice_grad.c @@ -0,0 +1,67 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp16_grad/strided_slice_grad.h" +#include "nnacl/errorcode.h" + +static size_t CalcIndex(const int *shape, size_t size, int i, size_t pos) { + size_t res = 1; + for (size_t j = 0; j < size; j++) { + res *= shape[(i + 1) + j]; + } + NNACL_CHECK_ZERO_RETURN_ERR(res); + NNACL_CHECK_ZERO_RETURN_ERR(shape[i]); + return (pos / res % shape[i]); +} + +int DoStridedSliceFp16Grad(const float16_t *inputs, float16_t *output, const int *dx_shape, + StridedSliceParameter *param) { + if (inputs == NULL || output == NULL || param == NULL) { + return NNACL_NULL_PTR; + } + if (param->num_axes_ > DIMENSION_7D) { + return NNACL_PARAM_INVALID; + } + + size_t size = 1; + int *s = param->strides_; + int *b = param->begins_; + for (int i = 0; i < DIMENSION_7D; i++) { + size *= param->in_shape_[i]; + } + + for (size_t pos = 0; pos < size; pos++) { + size_t i = CalcIndex(param->in_shape_, C6NUM, C0NUM, pos); + size_t j = CalcIndex(param->in_shape_, C5NUM, C1NUM, pos); + size_t k = CalcIndex(param->in_shape_, C4NUM, C2NUM, pos); + size_t l = CalcIndex(param->in_shape_, C3NUM, C3NUM, pos); + size_t m = CalcIndex(param->in_shape_, C2NUM, C4NUM, pos); + size_t n = CalcIndex(param->in_shape_, C1NUM, C5NUM, pos); + size_t o = CalcIndex(param->in_shape_, C0NUM, C6NUM, pos); + + size_t input_idx = + (i * s[C0NUM] + b[C0NUM]) * dx_shape[C1NUM] * dx_shape[C2NUM] * dx_shape[C3NUM] * dx_shape[C4NUM] * + dx_shape[C5NUM] * dx_shape[C6NUM] + + (j * s[C1NUM] + b[C1NUM]) * dx_shape[C2NUM] * dx_shape[C3NUM] * dx_shape[C4NUM] * dx_shape[C5NUM] * + dx_shape[C6NUM] + + (k * s[C2NUM] + b[C2NUM]) * dx_shape[C3NUM] * dx_shape[C4NUM] * dx_shape[C5NUM] * dx_shape[C6NUM] + + (l * s[C3NUM] + b[C3NUM]) * dx_shape[C4NUM] * dx_shape[C5NUM] * dx_shape[C6NUM] + + (m * s[C4NUM] + b[C4NUM]) * dx_shape[C5NUM] * dx_shape[C6NUM] + (n * s[C5NUM] + b[C5NUM]) * dx_shape[C6NUM] + + (o * s[C6NUM] + b[C6NUM]); + output[input_idx] = inputs[pos]; + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/strided_slice_grad.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/strided_slice_grad.h new file mode 100644 index 0000000000000000000000000000000000000000..fb82e8e6774ff2a231824ec37b752b7fec699dca --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/strided_slice_grad.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP16_GRAD_STRIDED_SLICE_GRAD_H_ +#define NNACL_FP16_GRAD_STRIDED_SLICE_GRAD_H_ + +#include "nnacl/op_base.h" +#include "nnacl/strided_slice_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int DoStridedSliceFp16Grad(const float16_t *inputs, float16_t *output, const int *dx_shape, + StridedSliceParameter *param); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP16_GRAD_STRIDED_SLICE_GRAD_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/unsorted_segment_sum.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/unsorted_segment_sum.c new file mode 100644 index 0000000000000000000000000000000000000000..6c822edb6519cc8a197376e30a75698cfdc1add9 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/unsorted_segment_sum.c @@ -0,0 +1,34 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp16_grad/unsorted_segment_sum.h" +#include "nnacl/errorcode.h" + +int UnsortedSegmentSumFp16(const float16_t *input, int unit_num, int input_dim1, const int *indices, float16_t *output, + int output_dim0, int output_dim1) { + NNACL_CHECK_ZERO_RETURN_ERR(input_dim1); + for (int i = 0; i < unit_num; ++i) { + int j = i / input_dim1; + int k = i % input_dim1; + + int index = indices[j]; + if (index < 0 || index >= output_dim0) { + continue; + } + int output_index = index * output_dim1 + k; + output[output_index] += input[i]; + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/unsorted_segment_sum.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/unsorted_segment_sum.h new file mode 100644 index 0000000000000000000000000000000000000000..4e12b7b4d8dd706256170ad63366897c1ed5b640 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp16_grad/unsorted_segment_sum.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP16_GRAD_UNSORTED_SEGMENT_SUM_H_ +#define NNACL_FP16_GRAD_UNSORTED_SEGMENT_SUM_H_ + +#include "nnacl/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int UnsortedSegmentSumFp16(const float16_t *input, int unit_num, int input_dim1, const int *indices, float16_t *output, + int output_dim0, int output_dim1); +#ifdef __cplusplus +} +#endif +#endif // NNACL_FP16_GRAD_UNSORTED_SEGMENT_SUM_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/activation_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/activation_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..a532839b556b33db259b8b33d8412320fa408b91 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/activation_fp32.c @@ -0,0 +1,292 @@ +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/activation_fp32.h" +#include "nnacl/fp32/exp_fp32.h" +#include "nnacl/errorcode.h" +#include "nnacl/activation_fp32_simd.h" + +int Fp32Relu(const float *src, int length, float *dst) { + int i = 0; + + SIMD_RUN_NO_SCALAR(Fp32Relu, i, src, length, dst); + + for (; i < length; ++i) { + dst[i] = src[i] > 0 ? src[i] : 0; + } + return NNACL_OK; +} + +int Int32Relu(const int32_t *src, int length, int32_t *dst) { + int i = 0; + + SIMD_RUN_NO_SCALAR(Int32Relu, i, src, length, dst); + + for (; i < length; ++i) { + dst[i] = src[i] > 0 ? src[i] : 0; + } + return NNACL_OK; +} + +int Fp32Relu6(const float *src, int length, float *dst) { + int i = 0; + + SIMD_RUN_NO_SCALAR(Fp32Relu6, i, src, length, dst); + + for (; i < length; ++i) { + if (src[i] < 0) { + dst[i] = 0; + } else { + dst[i] = src[i] > 6.0f ? 6.0f : src[i]; // relu 6.0 + } + } + return NNACL_OK; +} + +int Fp32Clip(const float *src, int length, float *dst, float min, float max) { + int i = 0; + + SIMD_RUN_NO_SCALAR(Fp32Clip, i, src, length, dst, min, max); + + for (; i < length; ++i) { + if (src[i] < min) { + dst[i] = min; + } else { + dst[i] = src[i] > max ? max : src[i]; + } + } + return NNACL_OK; +} + +int Int32Clip(const int32_t *src, int length, int32_t *dst, int min, int max) { + int i = 0; + + SIMD_RUN_NO_SCALAR(Int32Clip, i, src, length, dst, min, max); + + for (; i < length; ++i) { + if (src[i] < min) { + dst[i] = min; + } else { + dst[i] = src[i] > max ? max : src[i]; + } + } + return NNACL_OK; +} + +int LRelu(const float *src, int length, float *dst, float alpha) { + int i = 0; + + SIMD_RUN_NO_SCALAR(LRelu, i, src, length, dst, alpha); + + for (; i < length; ++i) { + dst[i] = src[i] > 0 ? src[i] : (src[i] * alpha); + } + return NNACL_OK; +} + +int Sigmoid(const float *src, int length, float *dst) { + int i = 0; + + SIMD_RUN_NO_SCALAR(Sigmoid, i, src, length, dst); + + for (; i < length; ++i) { + simd_exp32(-src[i], dst + i); + dst[i] = 1.0f / (1.0f + dst[i]); + } + return NNACL_OK; +} + +float TanhOpt(float src) { + if (src > 5.0) { // src > 5.0, tanh(src) = 1.0f + return 1.0f; + } else if (src < -5.0) { // src < -5.0, tanh(src) = -1.0f + return -1.0f; + } else { + float square = src * src; + float a = (((square + 378.0f) * square + 17325.0f) * square + 135135.0f) * src; + float b = ((28.0f * square + 3150.0f) * square + 62370.0f) * square + 135135.0f; + return a / b; + } +} + +int Tanh(const float *src, int length, float *dst) { + int i = 0; + + SIMD_RUN_NO_SCALAR(Tanh, i, src, length, dst); + + for (; i < length; ++i) { + dst[i] = TanhOpt(src[i]); + } + return NNACL_OK; +} + +int Swish(const float *src, int length, float *dst) { + int i = 0; + + SIMD_RUN_NO_SCALAR(Swish, i, src, length, dst); + + for (; i < length; ++i) { + simd_exp32(-src[i], dst + i); + dst[i] = src[i] / (1.0f + dst[i]); + } + return NNACL_OK; +} + +int HSwish(const float *src, int length, float *dst) { + int i = 0; + + SIMD_RUN_NO_SCALAR(HSwish, i, src, length, dst); + + for (; i < length; ++i) { + float in = src[i]; + float relu6 = MSMIN(MSMAX(in + C3NUM, 0), C6NUM); + dst[i] = in * relu6 / C6NUM; + } + return NNACL_OK; +} + +int HSigmoid(const float *src, int length, float *dst) { + int i = 0; + + SIMD_RUN_NO_SCALAR(HSigmoid, i, src, length, dst); + + for (; i < length; ++i) { + float relu6 = MSMIN(MSMAX(src[i] + C3NUM, 0), C6NUM); + dst[i] = relu6 / C6NUM; + } + return NNACL_OK; +} + +int HardTanh(const float *src, int length, float *dst, float min_val, float max_val) { + if (max_val <= min_val) { + return NNACL_ERR; + } + int i = 0; + if (min_val == FLT_MIN) { + SIMD_RUN_NO_SCALAR(HardTanhNoLimitMin, i, src, length, dst, min_val, max_val); + + for (; i < length; ++i) { + dst[i] = src[i] > max_val ? max_val : src[i]; + } + } else if (max_val == FLT_MAX) { + SIMD_RUN_NO_SCALAR(HardTanhNoLimitMax, i, src, length, dst, min_val, max_val); + + for (; i < length; ++i) { + dst[i] = src[i] < min_val ? min_val : src[i]; + } + } else { + SIMD_RUN_NO_SCALAR(HardTanhLimitMinMax, i, src, length, dst, min_val, max_val); + + for (; i < length; ++i) { + dst[i] = src[i] < min_val ? min_val : (src[i] > max_val ? max_val : src[i]); + } + } + return NNACL_OK; +} + +int Gelu(const float *src, int length, float *dst, bool approximate) { + if (src == NULL || dst == NULL) { + return NNACL_ERR; + } + int i = 0; + if (approximate) { + SIMD_RUN_NO_SCALAR(GeluTanhApproximate, i, src, length, dst); + + // dst = 0.5 * x * (1 + tanh((2 / pi) ^ 0.5 * (x + 0.044715x^3))) + for (; i < length; i++) { + dst[i] = 0.5 * src[i] * (1.0 + TanhOpt((0.79788456080287f + 0.035677408136f * src[i] * src[i]) * src[i])); + } + } else { + SIMD_RUN_NO_SCALAR(GeluErfAPPROXIMATE, i, src, length, dst); + + for (; i < length; i++) { + dst[i] = + 0.5 * src[i] * (1.0 + erf(src[i] / 1.4142135623730951f)); // dst = 0.5 * x * (1.0 + x / 1.4142135623730951f)) + } + } + return NNACL_OK; +} + +int Softplus(const float *src, int length, float *dst) { + float log_max = 88.0; + int i = 0; + + SIMD_RUN_NO_SCALAR(Softplus, i, src, length, dst); + + for (; i < length; ++i) { + simd_exp32(src[i], dst + i); + if (src[i] > log_max) { + dst[i] = src[i]; + } else { + dst[i] = log1p(dst[i]); + } + } + return NNACL_OK; +} + +int Elu(const float *src, int length, float *dst, float alpha) { + int i = 0; + + SIMD_RUN_NO_SCALAR(Elu, i, src, length, dst, alpha); + + for (; i < length; ++i) { + dst[i] = src[i] > 0 ? src[i] : (expm1(src[i]) * alpha); + } + return NNACL_OK; +} + +void Celu(const float *src, int length, float *dst, float alpha) { + int i = 0; + + SIMD_RUN_NO_SCALAR(Celu, i, src, length, dst, alpha); + + for (; i < length; ++i) { + dst[i] = src[i] > 0 ? src[i] : (expm1(src[i] / alpha) * alpha); + } + return; +} + +int HardShrink(const float *src, int length, float *dst, float lambd) { + int i = 0; + const float neg_lambd = -1 * lambd; + SIMD_RUN_NO_SCALAR(HardShrink, i, src, length, dst, lambd); + + for (; i < length; ++i) { + dst[i] = src[i] >= neg_lambd && src[i] <= lambd ? 0 : src[i]; + } + return NNACL_OK; +} + +int SoftShrink(const float *src, int length, float *dst, float lambd) { + int i = 0; + const float neg_lambd = -1 * lambd; + SIMD_RUN_NO_SCALAR(SoftShrink, i, src, length, dst, lambd); + + for (; i < length; ++i) { + dst[i] = (src[i] > lambd) ? (src[i] - lambd) : ((src[i] < neg_lambd) ? (src[i] + lambd) : (0)); + } + return NNACL_OK; +} + +int SoftsignFp32Opt(const float *src, int length, float *dst) { + int i = 0; + + SIMD_RUN_NO_SCALAR(SoftsignFp32Opt, i, src, length, dst); + for (; i < length; ++i) { + dst[i] = src[i] / (1.0 + fabsf(src[i])); + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/activation_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/activation_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..1630a31c4eef20037eb0155560bf84fc51bdfd91 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/activation_fp32.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_ACTIVATION_H_ +#define MINDSPORE_NNACL_FP32_ACTIVATION_H_ + +#include +#include "nnacl/op_base.h" +#include "nnacl/int8/fixed_point.h" +#include "nnacl/activation_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int Fp32Relu(const float *src, int length, float *dst); +int Int32Relu(const int32_t *src, int length, int32_t *dst); +int Fp32Relu6(const float *src, int length, float *dst); +int Fp32Clip(const float *src, int length, float *dst, float min, float max); +int Int32Clip(const int32_t *src, int length, int32_t *dst, int min, int max); +int LRelu(const float *src, int length, float *dst, float alpha); +int Sigmoid(const float *src, int length, float *dst); +int Tanh(const float *src, int length, float *dst); +int HSigmoid(const float *src, int length, float *dst); +int Swish(const float *src, int length, float *dst); +int HSwish(const float *src, int length, float *dst); +int HardTanh(const float *src, int length, float *dst, float min_val, float max_val); +int Gelu(const float *src, int length, float *dst, bool approximate); +int Softplus(const float *src, int length, float *dst); +int Elu(const float *src, int length, float *dst, float alpha); +void Celu(const float *src, int length, float *dst, float alpha); +float TanhOpt(float src); +int HardShrink(const float *src, int length, float *dst, float lambd); +int SoftShrink(const float *src, int length, float *dst, float lambd); +int SoftsignFp32Opt(const float *src, int length, float *dst); +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FP32_ACTIVATION_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/activation_fp32_simd.h.in b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/activation_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..0e989d5bb977c2d22995d67334eb55e3cf1ccef7 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/activation_fp32_simd.h.in @@ -0,0 +1,289 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// clang-format off +#ifndef MINDSPORE_NNACL_FP32_ACTIVATION_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_ACTIVATION_@SIMD_INSTRUCTION@_H_ + +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int Fp32Relu@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) { + SIMD_F32 zero = SIMD_SET0_F32; + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(dst + index, SIMD_MAX_F32(SIMD_LD_F32(src + index), zero)); + } + return index; +} + +static inline int Int32Relu@SIMD_INSTRUCTION@(int index, const int32_t *src, int length, int32_t *dst) { + SIMD_EPI32 zero = SIMD_MOV_EPI32(0.0f); + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_EPI32(dst + index, SIMD_MAX_EPI32(SIMD_LD_EPI32(src + index), zero)); + } + return index; +} + +static inline int Fp32Relu6@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) { + SIMD_F32 zero = SIMD_SET0_F32; + SIMD_F32 six = SIMD_MOV_F32(6.0f); + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(dst + index, SIMD_CLAMP_F32(SIMD_LD_F32(src + index), zero, six)); + } + return index; +} + +static inline int Fp32Clip@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst, float min, float max) { + SIMD_F32 min_val = SIMD_MOV_F32(min); + SIMD_F32 max_val = SIMD_MOV_F32(max); + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(dst + index, SIMD_CLAMP_F32(SIMD_LD_F32(src + index), min_val, max_val)); + } + return index; +} + +static inline int Int32Clip@SIMD_INSTRUCTION@(int index, const int32_t *src, int length, int32_t *dst, int min, int max) { + SIMD_EPI32 min_val = SIMD_MOV_EPI32(min); + SIMD_EPI32 max_val = SIMD_MOV_EPI32(max); + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_EPI32(dst + index, SIMD_CLAMP_EPI32(SIMD_LD_EPI32(src + index), min_val, max_val)); + } + return index; +} + +static inline int LRelu@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst, float alpha) { + SIMD_F32 alpha_data = SIMD_MOV_F32(alpha); + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 src_tmp = SIMD_LD_F32(src + index); + SIMD_MASK mask = SIMD_CMPGT_F32(SIMD_SET0_F32, src_tmp); + SIMD_ST_F32(dst + index, SIMD_BLEND_F32(src_tmp, SIMD_MUL_F32(src_tmp, alpha_data), mask)); + } + return index; +} + +static inline int Sigmoid@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) { + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EXP_ST_F32(SIMD_SUB_F32(SIMD_SET0_F32, (SIMD_LD_F32(src + index))), dst + index); + SIMD_ST_F32(dst + index, + SIMD_DIV_F32(SIMD_MOV_F32(1.0f), SIMD_ADD_F32(SIMD_MOV_F32(1.0f), SIMD_LD_F32(dst + index)))); + } + return index; +} + +static inline int Softplus@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) { + SIMD_F32 log_max = SIMD_MOV_F32(88.0); + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 src_tmp = SIMD_LD_F32(src + index); + SIMD_F32 dst_tmp = SIMD_EXP_F32(src_tmp); + dst_tmp = SIMD_LOG_F32(SIMD_ADD_F32(SIMD_MOV_F32(1.0f), dst_tmp)); + SIMD_ST_F32(dst + index, SIMD_BLEND_F32(dst_tmp, src_tmp, SIMD_CMPGT_F32(src_tmp, log_max))); + } + return index; +} + +static inline int Tanh@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) { + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 input = SIMD_LD_F32(src + index); + SIMD_ST_F32(dst + index, SIMD_TANH_F32(input)); + } + return index; +} + +static inline int Swish@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) { + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 src_value = SIMD_LD_F32(src + index); + SIMD_EXP_ST_F32(SIMD_SUB_F32(SIMD_SET0_F32, src_value), dst + index); + SIMD_ST_F32(dst + index, + SIMD_DIV_F32(src_value, SIMD_ADD_F32(SIMD_MOV_F32(1.0f), SIMD_LD_F32(dst + index)))); + } + return index; +} + +static inline int HSwish@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) { + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 src_value = SIMD_LD_F32(src + index); + SIMD_F32 relu6 = SIMD_CLAMP_N_F32(SIMD_ADD_N_F32(src_value, 3), 0, 6); + SIMD_ST_F32(dst + index, SIMD_DIV_N_F32(SIMD_MUL_F32(src_value, relu6), 6)); + } + return index; +} + +static inline int HSigmoid@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) { + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 src_value = SIMD_LD_F32(src + index); + SIMD_F32 relu6 = SIMD_CLAMP_N_F32(SIMD_ADD_N_F32(src_value, 3), 0, 6); + SIMD_ST_F32(dst + index, SIMD_DIV_N_F32(relu6, 6)); + } + return index; +} + +static inline int HardTanhNoLimitMin@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst, float min_val, + float max_val) { + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(dst + index, SIMD_MIN_N_F32(SIMD_LD_F32(src + index), max_val)); + } + return index; +} + +static inline int HardTanhNoLimitMax@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst, float min_val, + float max_val) { + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(dst + index, SIMD_MAX_N_F32(SIMD_LD_F32(src + index), min_val)); + } + return index; +} + +static inline int HardTanhLimitMinMax@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst, float min_val, + float max_val) { + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(dst + index, SIMD_CLAMP_N_F32(SIMD_LD_F32(src + index), min_val, max_val)); + } + return index; +} + +static inline int GeluTanhApproximate@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) { + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in = SIMD_LD_F32(src + index); + SIMD_F32 tmp1 = SIMD_FMADD_F32(SIMD_MUL_N_F32(in, 0.035677408136f), in, SIMD_MOV_F32(0.79788456080287f)); + SIMD_F32 tmp2 = SIMD_MUL_F32(tmp1, in); + SIMD_ST_F32(dst + index, SIMD_MUL_F32(SIMD_MUL_N_F32(in, 0.5f), SIMD_ADD_N_F32(SIMD_TANH_F32(tmp2), 1.0f))); + } + return index; +} + +static inline int Gelu@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) { + SIMD_F32 para1 = SIMD_MOV_F32(1.4142135623730951f); + SIMD_F32 para2 = SIMD_MOV_F32(1.0f); + SIMD_F32 para3 = SIMD_MOV_F32(0.5f); + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in = SIMD_LD_F32(src + index); + SIMD_F32 res = SIMD_MUL_F32(SIMD_MUL_F32(para3, in), SIMD_ADD_F32(para2, SIMD_ERF_F32(SIMD_DIV_F32(in, para1)))); + SIMD_ST_F32(dst + index, res); + } + return index; +} + +static inline SIMD_F32 SIMD_ERFCCHEB@SIMD_INSTRUCTION@(SIMD_F32 src) { + static const int ncof = 7; + const double cof[7] = {-1.3026537197817094, 6.4196979235649026e-1, 1.9476473204185836e-2, -9.561514786808631e-3, + -9.46595344482036e-4, 3.66839497852761e-4, 4.2523324806907e-5}; + SIMD_F32 dst; + SIMD_F32 d = SIMD_SET0_F32; + SIMD_F32 dd = SIMD_SET0_F32; + SIMD_F32 t = SIMD_DIV_F32(SIMD_MOV_F32(2.0f), SIMD_ADD_F32(src, SIMD_MOV_F32(2.0f))); + SIMD_F32 ty = SIMD_SUB_F32(SIMD_MUL_F32(SIMD_MOV_F32(4.0f), t), SIMD_MOV_F32(2.0f)); + + for (int j = ncof - 1; j > 0; j--) { + SIMD_F32 tmp = d; + d = SIMD_SUB_F32(SIMD_FMADD_F32(ty, d, SIMD_MOV_F32(cof[j])), dd); + dd = tmp; + } + + dst = + SIMD_FMADD_F32(src, src, MS_FSMUL_F32(dd, SIMD_FMADD_F32(ty, d, SIMD_MOV_F32(cof[0])), SIMD_MOV_F32(0.5f))); + dst = SIMD_MUL_F32(t, SIMD_EXP_F32(SIMD_MUL_F32(SIMD_MOV_F32(-1.0f), dst))); + return dst; +} + +static inline SIMD_F32 SIMD_ERF_APPROXIMATE@SIMD_INSTRUCTION@(SIMD_F32 src) { + SIMD_F32 abs_src = SIMD_ABS_F32(src); + SIMD_F32 sign = SIMD_GETSIGN_F32(src); + SIMD_F32 dst = SIMD_ERFCCHEB@SIMD_INSTRUCTION@(abs_src); + return SIMD_MUL_F32(sign, SIMD_SUB_F32(SIMD_MOV_F32(1.0f), dst)); +} + +static inline int GeluErfAPPROXIMATE@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) { + SIMD_F32 para1 = SIMD_MOV_F32(1.4142135623730951f); + SIMD_F32 para2 = SIMD_MOV_F32(1.0f); + SIMD_F32 para3 = SIMD_MOV_F32(0.5f); + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in = SIMD_LD_F32(src + index); + SIMD_F32 res = SIMD_MUL_F32(SIMD_MUL_F32(para3, in), SIMD_ADD_F32(para2, SIMD_ERF_APPROXIMATE@SIMD_INSTRUCTION@(SIMD_DIV_F32(in, para1)))); + SIMD_ST_F32(dst + index, res); + } + return index; +} + +static inline int Elu@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst, float alpha) { + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 src_tmp = SIMD_LD_F32(src + index); + SIMD_F32 exp_tmp = SIMD_SUB_N_F32(SIMD_EXP_F32(src_tmp), 1.0f); + SIMD_MASK mask = SIMD_CMPLE_F32(src_tmp, SIMD_SET0_F32); + SIMD_ST_F32(dst + index, SIMD_BLEND_F32(src_tmp, SIMD_MUL_N_F32(exp_tmp, alpha), mask)); + } + return index; +} + +static inline int Celu@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst, float alpha) { + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 src_tmp = SIMD_LD_F32(src + index); + SIMD_F32 exp_tmp = SIMD_SUB_N_F32(SIMD_EXP_F32(SIMD_DIV_N_F32(src_tmp, alpha)), 1.0f); + SIMD_MASK mask = SIMD_CMPLE_F32(src_tmp, SIMD_SET0_F32); + SIMD_ST_F32(dst + index, SIMD_BLEND_F32(src_tmp, SIMD_MUL_N_F32(exp_tmp, alpha), mask)); + } + return index; +} + +static inline int HardShrink@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst, float lambd) { + SIMD_F32 pos_lamdb_v = SIMD_MOV_F32(lambd); + SIMD_F32 neg_lamdb_v = SIMD_MOV_F32(-lambd); + + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 src_t = SIMD_LD_F32(src + index); + /* v0 = (in > lamdb) & in */ + SIMD_F32 value0 = SIMD_AND_MASK_F32(SIMD_CMPGT_F32(src_t, pos_lamdb_v), src_t); + /* v1 = (in < -lamdb) & in */ + SIMD_F32 value1 = SIMD_AND_MASK_F32(SIMD_CMPLT_F32(src_t, neg_lamdb_v), src_t); + /* out = (v0 | v1) */ + SIMD_ST_F32(dst + index, SIMD_OR_F32(value0, value1)); + } + return index; +} + +static inline int SoftShrink@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst, float lambd) { + SIMD_F32 pos_lamdb_v = SIMD_MOV_F32(lambd); + SIMD_F32 neg_lamdb_v = SIMD_MOV_F32(-lambd); + + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 src_t = SIMD_LD_F32(src + index); + /* v0 = (in > lamdb) & (in - lamdb) */ + SIMD_F32 value0 = SIMD_AND_MASK_F32(SIMD_CMPGT_F32(src_t, pos_lamdb_v), SIMD_SUB_F32(src_t, pos_lamdb_v)); + /* v1 = (in < -lamdb) & (in + lamdb) */ + SIMD_F32 value1 = SIMD_AND_MASK_F32(SIMD_CMPLT_F32(src_t, neg_lamdb_v), SIMD_ADD_F32(src_t, pos_lamdb_v)); + /* out = (v0 | v1) */ + SIMD_ST_F32(dst + index, SIMD_OR_F32(value0, value1)); + } + return index; +} + +static inline int SoftsignFp32Opt@SIMD_INSTRUCTION@(int index, const float *src, int length, float *dst) { + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 src_tmp = SIMD_LD_F32(src + index); + SIMD_F32 divisor_tmp = SIMD_ADD_F32(SIMD_MOV_F32(1.0f), SIMD_ABS_F32(src_tmp)); + SIMD_ST_F32(dst + index, SIMD_DIV_F32(src_tmp, divisor_tmp)); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/adam_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/adam_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..8e84d111db5ecf81ef172d6afaa273852beaf0ce --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/adam_fp32.c @@ -0,0 +1,239 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/exp_fp32.h" +#include "nnacl/fp32/adam_fp32.h" +#include "nnacl/intrinsics/ms_simd_instructions.h" +#ifdef ENABLE_AVX512 +#include "nnacl/avx512/adam_fp32_avx512.h" +#endif + +int AdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, const float *gradient, + size_t start, size_t end, bool use_nesterov) { + size_t c1 = start; +#ifdef ENABLE_AVX + size_t c8 = ((end - start) / C8NUM) * C8NUM; + __m256 coeff1_r = _mm256_set1_ps(1 - beta1); + __m256 coeff2_r = _mm256_set1_ps(1 - beta2); + __m256 beta1_r = _mm256_set1_ps(beta1); + __m256 lr_r = _mm256_set1_ps(lr); + __m256 epsi_r = _mm256_set1_ps(epsilon); + + float *var_ptr = var + start; + float *m_ptr = m + start; + float *v_ptr = v + start; + const float *grad_ptr = gradient + start; + + __m256 avx_r0, avx_r1; + __m256 var_r, m_r, v_r, grad_r; + + for (; c1 < start + c8; c1 += C8NUM) { + grad_r = _mm256_loadu_ps(grad_ptr); + m_r = _mm256_loadu_ps(m_ptr); + avx_r0 = _mm256_sub_ps(grad_r, m_r); + avx_r1 = _mm256_mul_ps(avx_r0, coeff1_r); + m_r = _mm256_add_ps(m_r, avx_r1); + _mm256_storeu_ps(m_ptr, m_r); + + v_r = _mm256_loadu_ps(v_ptr); + avx_r0 = _mm256_sub_ps(_mm256_mul_ps(grad_r, grad_r), v_r); + v_r = _mm256_add_ps(v_r, _mm256_mul_ps(avx_r0, coeff2_r)); + _mm256_storeu_ps(v_ptr, v_r); + + if (use_nesterov) { + avx_r0 = _mm256_add_ps(_mm256_mul_ps(m_r, beta1_r), _mm256_mul_ps(coeff1_r, grad_r)); + avx_r1 = _mm256_mul_ps(lr_r, avx_r0); + avx_r0 = _mm256_add_ps(_mm256_sqrt_ps(v_r), epsi_r); + __m256 avx_r2 = _mm256_div_ps(avx_r1, avx_r0); + + var_r = _mm256_loadu_ps(var_ptr); + var_r = _mm256_sub_ps(var_r, avx_r2); + _mm256_storeu_ps(var_ptr, var_r); + } else { + avx_r0 = _mm256_mul_ps(lr_r, m_r); + avx_r1 = _mm256_add_ps(_mm256_sqrt_ps(v_r), epsi_r); + __m256 avx_r2 = _mm256_div_ps(avx_r0, avx_r1); + var_r = _mm256_loadu_ps(var_ptr); + var_r = _mm256_sub_ps(var_r, avx_r2); + _mm256_storeu_ps(var_ptr, var_r); + } + m_ptr += C8NUM; + v_ptr += C8NUM; + var_ptr += C8NUM; + grad_ptr += C8NUM; + } +#endif + + // remaining + for (; c1 < end; c1++) { + m[c1] += (gradient[c1] - m[c1]) * (1 - beta1); + v[c1] += (gradient[c1] * gradient[c1] - v[c1]) * (1 - beta2); + if (use_nesterov) { + var[c1] -= lr * (m[c1] * beta1 + (1 - beta1) * gradient[c1]) / (sqrt(v[c1]) + epsilon); + } else { + var[c1] -= lr * m[c1] / (sqrt(v[c1]) + epsilon); + } + } + return NNACL_OK; +} + +int AdamDeltaFp32(float *delta, float *m, float *v, float lr, float beta1, float beta2, float epsilon, + const float *gradient, size_t start, size_t end, bool use_nesterov) { + size_t c1 = start; +#ifdef ENABLE_AVX + size_t c8 = ((end - start) / C8NUM) * C8NUM; + __m256 coeff1_r = _mm256_set1_ps(1.0f - beta1); + __m256 coeff2_r = _mm256_set1_ps(1.0f - beta2); + __m256 beta1_r = _mm256_set1_ps(beta1); + __m256 beta2_r = _mm256_set1_ps(beta2); + __m256 lr_r = _mm256_set1_ps(-lr); + __m256 epsi_r = _mm256_set1_ps(epsilon); + + float *m_ptr = m + start; + float *v_ptr = v + start; + float *delta_ptr = delta + start; + const float *gradient_ptr = gradient + start; + + __m256 m_r, v_r, delta_r, grad_r; + __m256 avx_r0, avx_r1; + for (; c1 < start + c8; c1 += C8NUM) { + m_r = _mm256_loadu_ps(m_ptr); + avx_r0 = _mm256_mul_ps(m_r, beta1_r); + grad_r = _mm256_loadu_ps(gradient_ptr); + m_r = _mm256_add_ps(avx_r0, _mm256_mul_ps(coeff1_r, grad_r)); + _mm256_storeu_ps(m_ptr, m_r); + + v_r = _mm256_loadu_ps(v_ptr); + avx_r0 = _mm256_mul_ps(v_r, beta2_r); + avx_r1 = _mm256_mul_ps(_mm256_mul_ps(coeff2_r, grad_r), grad_r); + v_r = _mm256_add_ps(avx_r0, avx_r1); + _mm256_storeu_ps(v_ptr, v_r); + + if (use_nesterov) { + avx_r0 = _mm256_add_ps(_mm256_mul_ps(m_r, beta1_r), _mm256_mul_ps(coeff1_r, grad_r)); + avx_r0 = _mm256_mul_ps(lr_r, avx_r0); + avx_r1 = _mm256_add_ps(_mm256_sqrt_ps(v_r), epsi_r); + delta_r = _mm256_div_ps(avx_r0, avx_r1); + _mm256_storeu_ps(delta_ptr, delta_r); + } else { + avx_r0 = _mm256_mul_ps(lr_r, m_r); + avx_r1 = _mm256_add_ps(_mm256_sqrt_ps(v_r), epsi_r); + delta_r = _mm256_div_ps(avx_r0, avx_r1); + _mm256_storeu_ps(delta_ptr, delta_r); + } + m_ptr += C8NUM; + v_ptr += C8NUM; + delta_ptr += C8NUM; + gradient_ptr += C8NUM; + } +#endif + + // remaining + for (; c1 < end; ++c1) { + m[c1] *= beta1; + m[c1] += (1 - beta1) * gradient[c1]; + v[c1] *= beta2; + v[c1] += (1 - beta2) * gradient[c1] * gradient[c1]; + if (use_nesterov) { + delta[c1] = -lr * (m[c1] * beta1 + (1 - beta1) * gradient[c1]) / (sqrt(v[c1]) + epsilon); + } else { + delta[c1] = -lr * m[c1] / (sqrt(v[c1]) + epsilon); + } + } + return NNACL_OK; +} + +int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay, + const float *gradient, size_t start, size_t end) { + size_t c1 = start; + SIMD_RUN_AVX512(AdamWeightDecayFp32, c1, var, m, v, lr, beta1, beta2, epsilon, decay, gradient, end); + + // remaining + const float beta1_minus = 1 - beta1; + const float beta2_minus = 1 - beta2; + for (; c1 < end; c1++) { + m[c1] += (gradient[c1] - m[c1]) * beta1_minus; + v[c1] += (gradient[c1] * gradient[c1] - v[c1]) * beta2_minus; + var[c1] -= lr * (m[c1] / (sqrt(v[c1]) + epsilon) + decay * var[c1]); + } + return NNACL_OK; +} + +size_t FusedCastAdamFp32Fp16(float *var, const int16_t *gradient16, float *m, float *v, float lr, float beta1, + float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start, + size_t end) { + size_t c1 = start; + + SIMD_RUN_AVX512(FusedCastAdamFp32Fp16, c1, var, gradient16, m, v, lr, beta1, beta2, epsilon, decay, + global_norm_reciprocal, end); + return c1; +} + +size_t FusedCastAdamFp32Fp32(float *var, const float *gradient32, float *m, float *v, float lr, float beta1, + float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start, + size_t end) { + size_t c1 = start; + + SIMD_RUN_AVX512(FusedCastAdamFp32Fp32, c1, var, gradient32, m, v, lr, beta1, beta2, epsilon, decay, + global_norm_reciprocal, end); + return c1; +} + +size_t FusedCastAdamFp16Fp16(int16_t *var16, const int16_t *gradient16, float *m, float *v, float lr, float beta1, + float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start, + size_t end) { + size_t c1 = start; + SIMD_RUN_AVX512(FusedCastAdamFp16Fp16, c1, var16, gradient16, m, v, lr, beta1, beta2, epsilon, decay, + global_norm_reciprocal, end); + return c1; +} + +size_t FusedCastAdamFp16Fp32(int16_t *var16, const float *gradient32, float *m, float *v, float lr, float beta1, + float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start, + size_t end) { + size_t c1 = start; + SIMD_RUN_AVX512(FusedCastAdamFp16Fp32, c1, var16, gradient32, m, v, lr, beta1, beta2, epsilon, decay, + global_norm_reciprocal, end); + return c1; +} + +int DoAdam(float *m, float *v, const float *gradient, float *weight, float beta1, float beta2, float *beta1_power, + float *beta2_power, float eps, float learning_rate, bool nesterov, int start, int end) { + if ((1.f - beta1_power[0]) <= 0.0f) { + return NNACL_PARAM_INVALID; + } + if ((1.f - beta2_power[0]) < 0.0f) { + return NNACL_ERRCODE_SQRT_NEGATIVE; + } + + float update_lr = learning_rate * sqrtf(1.f - beta2_power[0]) / (1.f - beta1_power[0]); + const float one_minus_beta1 = 1.f - beta1; + const float one_minus_beta2 = 1.f - beta2; + if (nesterov) { // Nadam + for (int i = start; i < end; ++i) { + m[i] += (gradient[i] - m[i]) * one_minus_beta1; + v[i] += (gradient[i] * gradient[i] - v[i]) * one_minus_beta2; + weight[i] -= update_lr * (m[i] * beta1 + one_minus_beta1 * gradient[i]) / (sqrtf(v[i]) + eps); + } + } else { + for (int i = start; i < end; ++i) { + m[i] += (gradient[i] - m[i]) * one_minus_beta1; + v[i] += (gradient[i] * gradient[i] - v[i]) * one_minus_beta2; + weight[i] -= update_lr * m[i] / (sqrtf(v[i]) + eps); + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/adam_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/adam_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..f843e19bb57dc929255febb4abfc219ce156c68a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/adam_fp32.h @@ -0,0 +1,49 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_ADAM_FP32_H +#define MINDSPORE_NNACL_ADAM_FP32_H +#include +#include "nnacl/op_base.h" +#include "nnacl/errorcode.h" + +#ifdef __cplusplus +extern "C" { +#endif +int AdamFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, const float *gradient, + size_t start, size_t end, bool use_nesterov); +int AdamDeltaFp32(float *delta, float *m, float *v, float lr, float beta1, float beta2, float epsilon, + const float *gradient, size_t start, size_t end, bool use_nesterov); +int AdamWeightDecayFp32(float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay, + const float *gradient, size_t start, size_t end); +size_t FusedCastAdamFp32Fp16(float *var, const int16_t *gradient16, float *m, float *v, float lr, float beta1, + float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start, + size_t end); +size_t FusedCastAdamFp32Fp32(float *var, const float *gradient32, float *m, float *v, float lr, float beta1, + float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start, + size_t end); +size_t FusedCastAdamFp16Fp16(int16_t *var16, const int16_t *gradient16, float *m, float *v, float lr, float beta1, + float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start, + size_t end); +size_t FusedCastAdamFp16Fp32(int16_t *var16, const float *gradient32, float *m, float *v, float lr, float beta1, + float beta2, float epsilon, float decay, float global_norm_reciprocal, size_t start, + size_t end); + +int DoAdam(float *m, float *v, const float *gradient, float *weight, float beta1, float beta2, float *beta1_power, + float *beta2_power, float eps, float learning_rate, bool nesterov, int start, int end); +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ADAM_FP32_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/adam_fp32_simd.h.in b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/adam_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..d1ea83ac68444d59ebbfc2ad028a0eca137e3076 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/adam_fp32_simd.h.in @@ -0,0 +1,203 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_ADAM_FP32_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_ADAM_FP32_@SIMD_INSTRUCTION@_H_ + +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ +#ifdef MS_SIMD_AVX512 + static inline size_t AdamWeightDecayFp32@SIMD_INSTRUCTION@(size_t index, float *var, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay, + const float *gradient, size_t end) { + SIMD_F32 beta1_r = SIMD_MOV_F32(beta1); + SIMD_F32 beta2_r = SIMD_MOV_F32(beta2); + SIMD_F32 beta1_minus_r = SIMD_MOV_F32(1.0f - beta1); + SIMD_F32 beta2_minus_r = SIMD_MOV_F32(1.0f - beta2); + SIMD_F32 lr_neg_r = SIMD_MOV_F32(-lr); + SIMD_F32 epsilon_r = SIMD_MOV_F32(epsilon); + SIMD_F32 decay_r = SIMD_MOV_F32(decay); + + for (size_t block_max_size = end - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 var_r = SIMD_LD_F32(var + index); + SIMD_F32 m_r = SIMD_LD_F32(m + index); + SIMD_F32 v_r = SIMD_LD_F32(v + index); + SIMD_F32 g_r = SIMD_LD_F32(gradient + index); + + m_r = SIMD_MUL_F32(m_r, beta1_r); + v_r = SIMD_MUL_F32(v_r, beta2_r); + SIMD_F32 avx_r0 = SIMD_MUL_F32(g_r, g_r); + m_r = SIMD_FMADD_F32(g_r, beta1_minus_r, m_r); + v_r = SIMD_FMADD_F32(avx_r0, beta2_minus_r, v_r); + avx_r0 = SIMD_SQRT_F32(v_r); + avx_r0 = SIMD_DIV_F32(m_r, SIMD_ADD_F32(avx_r0, epsilon_r)); + avx_r0 = SIMD_FMADD_F32(var_r, decay_r, avx_r0); + var_r = SIMD_FMADD_F32(avx_r0, lr_neg_r, var_r); + SIMD_ST_F32(m + index, m_r); + SIMD_ST_F32(v + index, v_r); + SIMD_ST_F32(var + index, var_r); + } + + return index; +} + +static inline size_t FusedCastAdamFp32Fp16@SIMD_INSTRUCTION@(size_t index, float *var, const int16_t *gradient16, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay, + float global_norm_reciprocal, size_t end) { + SIMD_F32 beta1_r = SIMD_MOV_F32(beta1); + SIMD_F32 beta2_r = SIMD_MOV_F32(beta2); + SIMD_F32 beta1_minus_r = SIMD_MOV_F32(1.0f - beta1); + SIMD_F32 beta2_minus_r = SIMD_MOV_F32(1.0f - beta2); + SIMD_F32 lr_neg_r = SIMD_MOV_F32(-lr); + SIMD_F32 epsilon_r = SIMD_MOV_F32(epsilon); + SIMD_F32 decay_r = SIMD_MOV_F32(decay); + SIMD_F32 global_norm_reciprocal_r = SIMD_MOV_F32(global_norm_reciprocal); + + for (size_t block_max_size = end - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 var_r = SIMD_LD_F32(var + index); + SIMD_F32 m_r = SIMD_LD_F32(m + index); + SIMD_F32 v_r = SIMD_LD_F32(v + index); + SIMD_F32 g_r = SIMD_F16_TO_F32(SIMD_LD_HALF_EPI32(gradient16 + index)); + + g_r = SIMD_MUL_F32(g_r, global_norm_reciprocal_r); + m_r = SIMD_MUL_F32(m_r, beta1_r); + v_r = SIMD_MUL_F32(v_r, beta2_r); + SIMD_F32 avx_r0 = SIMD_MUL_F32(g_r, g_r); + m_r = SIMD_FMADD_F32(g_r, beta1_minus_r, m_r); + v_r = SIMD_FMADD_F32(avx_r0, beta2_minus_r, v_r); + avx_r0 = SIMD_SQRT_F32(v_r); + avx_r0 = SIMD_DIV_F32(m_r, SIMD_ADD_F32(avx_r0, epsilon_r)); + avx_r0 = SIMD_FMADD_F32(var_r, decay_r, avx_r0); + var_r = SIMD_FMADD_F32(avx_r0, lr_neg_r, var_r); + SIMD_ST_F32(var + index, var_r); + SIMD_ST_F32(m + index, m_r); + SIMD_ST_F32(v + index, v_r); + } + + return index; +} + +static inline size_t FusedCastAdamFp32Fp32@SIMD_INSTRUCTION@(size_t index, float *var, const float *gradient32, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay, + float global_norm_reciprocal, size_t end) { + SIMD_F32 beta1_r = SIMD_MOV_F32(beta1); + SIMD_F32 beta2_r = SIMD_MOV_F32(beta2); + SIMD_F32 beta1_minus_r = SIMD_MOV_F32(1.0f - beta1); + SIMD_F32 beta2_minus_r = SIMD_MOV_F32(1.0f - beta2); + SIMD_F32 lr_neg_r = SIMD_MOV_F32(-lr); + SIMD_F32 epsilon_r = SIMD_MOV_F32(epsilon); + SIMD_F32 decay_r = SIMD_MOV_F32(decay); + SIMD_F32 global_norm_reciprocal_r = SIMD_MOV_F32(global_norm_reciprocal); + + for (size_t block_max_size = end - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 var_r = SIMD_LD_F32(var + index); + SIMD_F32 m_r = SIMD_LD_F32(m + index); + SIMD_F32 v_r = SIMD_LD_F32(v + index); + SIMD_F32 g_r = SIMD_LD_F32(gradient32 + index); + + g_r = SIMD_MUL_F32(g_r, global_norm_reciprocal_r); + m_r = SIMD_MUL_F32(m_r, beta1_r); + v_r = SIMD_MUL_F32(v_r, beta2_r); + SIMD_F32 avx_r0 = SIMD_MUL_F32(g_r, g_r); + m_r = SIMD_FMADD_F32(g_r, beta1_minus_r, m_r); + v_r = SIMD_FMADD_F32(avx_r0, beta2_minus_r, v_r); + avx_r0 = SIMD_SQRT_F32(v_r); + avx_r0 = SIMD_DIV_F32(m_r, SIMD_ADD_F32(avx_r0, epsilon_r)); + avx_r0 = SIMD_FMADD_F32(var_r, decay_r, avx_r0); + var_r = SIMD_FMADD_F32(avx_r0, lr_neg_r, var_r); + SIMD_ST_F32(var + index, var_r); + SIMD_ST_F32(m + index, m_r); + SIMD_ST_F32(v + index, v_r); + } + + return index; +} + +static inline size_t FusedCastAdamFp16Fp16@SIMD_INSTRUCTION@(size_t index, int16_t *var16, const int16_t *gradient16, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay, + float global_norm_reciprocal, size_t end) { + SIMD_F32 beta1_r = SIMD_MOV_F32(beta1); + SIMD_F32 beta2_r = SIMD_MOV_F32(beta2); + SIMD_F32 beta1_minus_r = SIMD_MOV_F32(1.0f - beta1); + SIMD_F32 beta2_minus_r = SIMD_MOV_F32(1.0f - beta2); + SIMD_F32 lr_neg_r = SIMD_MOV_F32(-lr); + SIMD_F32 epsilon_r = SIMD_MOV_F32(epsilon); + SIMD_F32 decay_r = SIMD_MOV_F32(decay); + SIMD_F32 global_norm_reciprocal_r = SIMD_MOV_F32(global_norm_reciprocal); + + for (size_t block_max_size = end - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 var_r = SIMD_F16_TO_F32(SIMD_LD_HALF_EPI32(var16)); + SIMD_F32 m_r = SIMD_LD_F32(m + index); + SIMD_F32 v_r = SIMD_LD_F32(v + index); + SIMD_F32 g_r = SIMD_F16_TO_F32(SIMD_LD_HALF_EPI32(gradient16 + index)); + g_r = SIMD_MUL_F32(g_r, global_norm_reciprocal_r); + m_r = SIMD_MUL_F32(m_r, beta1_r); + v_r = SIMD_MUL_F32(v_r, beta2_r); + SIMD_F32 avx_r0 = SIMD_MUL_F32(g_r, g_r); + m_r = SIMD_FMADD_F32(g_r, beta1_minus_r, m_r); + v_r = SIMD_FMADD_F32(avx_r0, beta2_minus_r, v_r); + avx_r0 = SIMD_SQRT_F32(v_r); + avx_r0 = SIMD_DIV_F32(m_r, SIMD_ADD_F32(avx_r0, epsilon_r)); + avx_r0 = SIMD_FMADD_F32(var_r, decay_r, avx_r0); + var_r = SIMD_FMADD_F32(avx_r0, lr_neg_r, var_r); + SIMD_ST_F32(m + index, m_r); + SIMD_ST_F32(v + index, v_r); + SIMD_ST_HALF_EPI32(var16 + index, SIMD_F32_TO_F16(var_r, 0)); + } + + return index; +} + +static inline size_t FusedCastAdamFp16Fp32@SIMD_INSTRUCTION@(size_t index, int16_t *var16, const float *gradient32, float *m, float *v, float lr, float beta1, float beta2, float epsilon, float decay, + float global_norm_reciprocal, size_t end) { + SIMD_F32 beta1_r = SIMD_MOV_F32(beta1); + SIMD_F32 beta2_r = SIMD_MOV_F32(beta2); + SIMD_F32 beta1_minus_r = SIMD_MOV_F32(1.0f - beta1); + SIMD_F32 beta2_minus_r = SIMD_MOV_F32(1.0f - beta2); + SIMD_F32 lr_neg_r = SIMD_MOV_F32(-lr); + SIMD_F32 epsilon_r = SIMD_MOV_F32(epsilon); + SIMD_F32 decay_r = SIMD_MOV_F32(decay); + SIMD_F32 global_norm_reciprocal_r = SIMD_MOV_F32(global_norm_reciprocal); + + for (size_t block_max_size = end - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 var_r = SIMD_F16_TO_F32(SIMD_LD_HALF_EPI32(var16)); + SIMD_F32 m_r = SIMD_LD_F32(m + index); + SIMD_F32 v_r = SIMD_LD_F32(v + index); + SIMD_F32 g_r = SIMD_LD_F32(gradient32 + index); + g_r = SIMD_MUL_F32(g_r, global_norm_reciprocal_r); + m_r = SIMD_MUL_F32(m_r, beta1_r); + v_r = SIMD_MUL_F32(v_r, beta2_r); + SIMD_F32 avx_r0 = SIMD_MUL_F32(g_r, g_r); + m_r = SIMD_FMADD_F32(g_r, beta1_minus_r, m_r); + v_r = SIMD_FMADD_F32(avx_r0, beta2_minus_r, v_r); + avx_r0 = SIMD_SQRT_F32(v_r); + avx_r0 = SIMD_DIV_F32(m_r, SIMD_ADD_F32(avx_r0, epsilon_r)); + avx_r0 = SIMD_FMADD_F32(var_r, decay_r, avx_r0); + var_r = SIMD_FMADD_F32(avx_r0, lr_neg_r, var_r); + SIMD_ST_F32(m + index, m_r); + SIMD_ST_F32(v + index, v_r); + SIMD_ST_HALF_EPI32(var16 + index, SIMD_F32_TO_F16(var_r, 0)); + } + + return index; +} +#endif + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/add_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/add_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..94d613076f919e28c79f27796043461d1ad6eaa5 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/add_fp32.c @@ -0,0 +1,156 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/add_fp32.h" +#include "nnacl/fp32/arithmetic_fp32.h" +#include "nnacl/add_fp32_simd.h" + +int ElementOptAdd(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptAdd, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[0] + in1[index]; + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptAdd, index, in1, in0, out, size); + for (; index < size; index++) { + out[index] = in0[index] + in1[0]; + } + } + return NNACL_OK; +} + +int ElementOptAddExt(const float *in0, const float *in1, const float alpha, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptAddExtNum0, index, in0, in1, alpha, out, size); + for (; index < size; index++) { + out[index] = in0[0] + in1[index] * alpha; + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptAddExtNum1, index, in0, in1, alpha, out, size); + for (; index < size; index++) { + out[index] = in0[index] + in1[0] * alpha; + } + } + return NNACL_OK; +} + +int ElementOptAddInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptAddInt, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[0] + in1[index]; + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptAddInt, index, in1, in0, out, size); + for (; index < size; index++) { + out[index] = in0[index] + in1[0]; + } + } + return NNACL_OK; +} + +int ElementOptAddRelu(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptAddRelu, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMAX(in0[0] + in1[index], 0); + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptAddRelu, index, in1, in0, out, size); + for (; index < size; index++) { + out[index] = MSMAX(in0[index] + in1[0], 0); + } + } + return NNACL_OK; +} + +int ElementOptAddRelu6(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptAddRelu6, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[0] + in1[index], 0), 6); + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptAddRelu6, index, in1, in0, out, size); + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[index] + in1[0], 0), 6); + } + } + return NNACL_OK; +} + +int BroadcastAdd(const float *in0, const float *in1, float *tile_in0, float *tile_in1, float *out, int size, + ArithmeticParameter *param) { + TileDimensionsFp32(in0, in1, tile_in0, tile_in1, param); + return ElementAdd(tile_in0, tile_in1, out, size); +} + +int ElementAdd(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementAdd, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[index] + in1[index]; + } + return NNACL_OK; +} + +int ElementAddExt(const float *in0, const float *in1, const float alpha, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementAddExt, index, in0, in1, alpha, out, size); + for (; index < size; index++) { + out[index] = in0[index] + in1[index] * alpha; + } + return NNACL_OK; +} + +int ElementAddRelu(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementAddRelu, index, in0, in1, out, size); + for (; index < size; index++) { + float res = in0[index] + in1[index]; + out[index] = res > 0 ? res : 0; + } + return NNACL_OK; +} + +int ElementAddRelu6(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementAddRelu6, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[index] + in1[index], 0), 6); + } + return NNACL_OK; +} + +int ElementAddInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementAddInt, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[index] + in1[index]; + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/add_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/add_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..82f50f8244c5afd43a34a4d0c7761b3d2bcc6a90 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/add_fp32.h @@ -0,0 +1,47 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_ADD_H_ +#define MINDSPORE_NNACL_FP32_ADD_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "nnacl/op_base.h" +#include "nnacl/base/arithmetic_base.h" +#include "nnacl/errorcode.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ElementAdd(const float *in0, const float *in1, float *out, int size); +int ElementAddExt(const float *in0, const float *in1, const float alpha, float *out, int size); +int ElementOptAddExt(const float *in0, const float *in1, const float alpha, float *out, int size, bool first_scalar); +int ElementAddRelu(const float *in0, const float *in1, float *out, int size); +int ElementAddRelu6(const float *in0, const float *in1, float *out, int size); +int ElementAddInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size); +int ElementOptAdd(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementOptAddRelu(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementOptAddRelu6(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementOptAddInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar); +int BroadcastAdd(const float *in0, const float *in1, float *tile_in0, float *tile_in1, float *out, int size, + ArithmeticParameter *param); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_ADD_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/add_fp32_simd.h.in b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/add_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..8da28acd94a9eb8702e4509415f3c214fec77da5 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/add_fp32_simd.h.in @@ -0,0 +1,153 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_ADD_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_ADD_@SIMD_INSTRUCTION@_H_ + +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int ElementOptAdd@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 vin0_ = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_ADD_F32(vin0_, vin1); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptAddExtNum0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, const float alpha, float *out, int size) { + SIMD_F32 vin0 = SIMD_MOV_F32(in0[0]); + SIMD_F32 valpha = SIMD_MOV_F32(alpha); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vin1_alpha = SIMD_MUL_F32(vin1, valpha); + SIMD_F32 vout = SIMD_ADD_F32(vin0, vin1_alpha); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptAddExtNum1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, const float alpha, float *out, int size) { + SIMD_F32 vin1 = SIMD_MOV_F32(in1[0]); + SIMD_F32 valpha = SIMD_MOV_F32(alpha); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vin1_alpha = SIMD_MUL_F32(vin1, valpha); + SIMD_F32 vout = SIMD_ADD_F32(vin0, vin1_alpha); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptAddInt@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, + int size) { + SIMD_EPI32 vin0_ = SIMD_MOV_EPI32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin1 = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 vout = SIMD_ADD_EPI32(vin0_, vin1); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +static inline int ElementOptAddRelu@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + SIMD_F32 vin0_ = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MAX_N_F32(SIMD_ADD_F32(vin0_, vin1), 0.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptAddRelu6@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + SIMD_F32 vin0_ = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MIN_N_F32(SIMD_MAX_N_F32(SIMD_ADD_F32(vin0_, vin1), 0.0f), 6.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementAdd@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_ADD_F32(vin0, vin1); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementAddExt@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, const float alpha, float *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 valpha = SIMD_MOV_F32(alpha); + SIMD_F32 vin1_alpha = SIMD_MUL_F32(vin1, valpha); + SIMD_F32 vout = SIMD_ADD_F32(vin0, vin1_alpha); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementAddRelu@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MAX_N_F32(SIMD_ADD_F32(vin0, vin1), 0.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementAddRelu6@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MIN_N_F32(SIMD_MAX_N_F32(SIMD_ADD_F32(vin0, vin1), 0.0f), 6.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementAddInt@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin0 = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 vin1 = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 vout = SIMD_ADD_EPI32(vin0, vin1); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/adder_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/adder_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..d9b567f9d1f3fad72591affb4ac51815fa63fc67 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/adder_fp32.c @@ -0,0 +1,93 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/adder_fp32.h" +#include +#include +#include "nnacl/fp32/matmul_fp32.h" +#include "nnacl/fp32/conv_common_fp32.h" + +void Adder12x4(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row, + int col, int stride) { + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int r12div = r / 12, r12mod = r % 12; + int c4div = c / 4, c4mod = c % 4; + size_t ci = r * stride + c; + float value = 0; + for (int d = 0; d < deep; d++) { + size_t ai = r12div * deep * 12 + d * 12 + r12mod; + size_t bi = c4div * deep * 4 + d * 4 + c4mod; + value += fabsf(a[ai] - b[bi]); + } + value = -value; + if (bias != NULL) value += bias[c]; + if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); + if (act_type != ActType_No) value = MSMAX(0.0f, value); + dst[ci] = value; + } + } +} + +void AdderOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, int col, + size_t stride) { +#ifdef ENABLE_ARM64 + AdderFloatNeon64(a, b, c, bias, (int)act_type, deep, row, col, stride); +#else + Adder12x4(a, b, c, bias, act_type, deep, row, col, stride); +#endif +} + +void AdderFp32(const float *input_data, float *packed_input, const float *packed_weight, const float *bias_data, + float *col_major_input, float *output_data, int task_id, ConvParameter *conv_param) { + int out_channel = conv_param->output_channel_; + int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; + int output_count = conv_param->output_h_ * conv_param->output_w_; + if (conv_param->thread_num_ == 0) { + return; + } +#if defined(ENABLE_ARM32) || defined(ENABLE_SSE) + const int cal_num = C4NUM; +#else + const int cal_num = C12NUM; +#endif + int output_tile_count = UP_DIV(output_count, cal_num); + + for (int b = 0; b < conv_param->input_batch_; b++) { + int in_batch_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_; + int out_batch_offset = b * out_channel * output_count; + for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) { + int start_index = thread_id * cal_num; + int real_cal_num = (output_count - start_index) < cal_num ? (output_count - start_index) : cal_num; + float *gemm_input = packed_input + task_id * deep * cal_num; + float *col_major_gemm_input = col_major_input + task_id * deep * cal_num; + size_t packed_input_size = deep * cal_num * sizeof(float); + memset(gemm_input, 0, packed_input_size); + memset(col_major_gemm_input, 0, packed_input_size); + Im2ColPackUnitFp32(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index); + + int out_offset = thread_id * cal_num * out_channel + out_batch_offset; + float *gemm_output = output_data + out_offset; +#if defined(ENABLE_ARM32) || defined(ENABLE_SSE) + RowMajor2Col4Major(gemm_input, col_major_gemm_input, cal_num, deep); +#else + RowMajor2Col12Major(gemm_input, col_major_gemm_input, cal_num, deep); +#endif + AdderOpt(col_major_gemm_input, packed_weight, gemm_output, bias_data, conv_param->act_type_, deep, real_cal_num, + out_channel, out_channel); + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/adder_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/adder_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..3908102a83303fde218a18cea9bf44f5b8ff76a3 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/adder_fp32.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_ADDER_H_ +#define MINDSPORE_NNACL_FP32_ADDER_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "nnacl/pack.h" +#include "nnacl/op_base.h" +#include "nnacl/common_func.h" +#include "nnacl/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#ifdef ENABLE_ARM64 +void AdderFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, + int col, size_t stride); +#endif + +void AdderOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, int col, + size_t stride); + +void AdderFp32(const float *input_data, float *packed_input, const float *packed_weight, const float *bias_data, + float *col_major_input, float *output_data, int task_id, ConvParameter *conv_param); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_ADDER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/arg_min_max_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/arg_min_max_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..324dd63c7b1fd5ccb6bb95efa45786c6e02a5510 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/arg_min_max_fp32.c @@ -0,0 +1,298 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/arg_min_max_fp32.h" +#include + +#define ARG_MIN_MAX_FUNC(data_type) \ + int ArgCompareDesc32##data_type(const void *a, const void *b) { \ + DATA_TYPE b_value = ((ArgElement *)b)->data_.UNION_DATA; \ + DATA_TYPE a_value = ((ArgElement *)a)->data_.UNION_DATA; \ + if (b_value > a_value) { \ + return 1; \ + } \ + if (b_value < a_value) { \ + return -1; \ + } \ + return 0; \ + } \ + int ArgCompareAsc32##data_type(const void *a, const void *b) { \ + DATA_TYPE a_value = ((ArgElement *)a)->data_.UNION_DATA; \ + DATA_TYPE b_value = ((ArgElement *)b)->data_.UNION_DATA; \ + if (b_value > a_value) { \ + return -1; \ + } \ + if (b_value < a_value) { \ + return 1; \ + } \ + return 0; \ + } \ + \ + void ArgMaxTopK1##data_type(const DATA_TYPE *input, void *output, DATA_TYPE *output_value, \ + const ArgMinMaxComputeParam *param, int pre_axis_count, int axis_count, \ + int after_axis_count) { \ + bool out_value = param->out_value_; \ + DATA_TYPE *outputfp32 = (DATA_TYPE *)output; \ + int32_t *outputint = (int32_t *)output; \ + for (int i = 0; i < pre_axis_count; ++i) { \ + int output_offset = i * after_axis_count; \ + int input_offset = output_offset * axis_count; \ + for (int j = 0; j < after_axis_count; ++j) { \ + DATA_TYPE value = MIN_VALUE; \ + int index = 0; \ + for (int k = 0; k < axis_count; ++k) { \ + DATA_TYPE value_tmp = input[input_offset + k * after_axis_count + j]; \ + if (value_tmp > value) { \ + value = value_tmp; \ + index = k; \ + } \ + } \ + if (out_value) { \ + outputfp32[output_offset + j] = value; \ + } else { \ + outputint[output_offset + j] = index; \ + } \ + if (output_value != NULL) { \ + output_value[output_offset + j] = value; \ + } \ + } \ + } \ + } \ + \ + void ArgMinTopK1##data_type(const DATA_TYPE *input, void *output, DATA_TYPE *output_value, \ + const ArgMinMaxComputeParam *param, int pre_axis_count, int axis_count, \ + int after_axis_count) { \ + bool out_value = param->out_value_; \ + DATA_TYPE *outputfp32 = (DATA_TYPE *)output; \ + int32_t *outputint = (int32_t *)output; \ + for (int i = 0; i < pre_axis_count; ++i) { \ + int output_offset = i * after_axis_count; \ + int input_offset = output_offset * axis_count; \ + for (int j = 0; j < after_axis_count; ++j) { \ + DATA_TYPE value = MAX_VALUE; \ + int index = 0; \ + for (int k = 0; k < axis_count; ++k) { \ + DATA_TYPE value_tmp = input[input_offset + k * after_axis_count + j]; \ + if (value_tmp < value) { \ + value = value_tmp; \ + index = k; \ + } \ + } \ + if (out_value) { \ + outputfp32[output_offset + j] = value; \ + } else { \ + outputint[output_offset + j] = index; \ + } \ + if (output_value != NULL) { \ + output_value[output_offset + j] = value; \ + } \ + } \ + } \ + } \ + \ + void ArgMinMaxDim0##data_type(const DATA_TYPE *input, void *output, DATA_TYPE *output_value, \ + const int32_t *in_shape, const ArgMinMaxComputeParam *param, \ + COMPARE_FUNCTION compare_func) { \ + DATA_TYPE *outputfp32 = (DATA_TYPE *)output; \ + int32_t *outputint = (int32_t *)output; \ + for (int32_t i = 0; i < param->in_strides_[0]; ++i) { \ + for (int j = 0; j < in_shape[0]; ++j) { \ + int offset = param->in_strides_[0] * j + i; \ + param->arg_elements_[j].index_ = (uint32_t)j; \ + param->arg_elements_[j].data_.UNION_DATA = input[offset]; \ + } \ + qsort(param->arg_elements_, in_shape[0], sizeof(ArgElement), *compare_func); \ + for (int j = 0; j < param->topk_; ++j) { \ + int out_offset = j * param->out_strides_[0] + i; \ + if (param->out_value_) { \ + outputfp32[out_offset] = param->arg_elements_[j].data_.UNION_DATA; \ + } else { \ + outputint[out_offset] = param->arg_elements_[j].index_; \ + } \ + if (output_value != NULL) { \ + output_value[out_offset] = param->arg_elements_[j].data_.UNION_DATA; \ + } \ + } \ + } \ + return; \ + } \ + \ + void ArgMinMaxDim1##data_type(const DATA_TYPE *input, void *output, DATA_TYPE *output_value, \ + const int32_t *in_shape, const ArgMinMaxComputeParam *param, \ + COMPARE_FUNCTION compare_func) { \ + DATA_TYPE *outputfp32 = (DATA_TYPE *)output; \ + int32_t *outputint = (int32_t *)output; \ + int in_shape1 = in_shape[1]; \ + for (int i = 0; i < in_shape[0]; ++i) { \ + int in_dim0_offset = i * param->in_strides_[0]; \ + int out_dim0_offset = i * param->out_strides_[0]; \ + for (int j = 0; j < param->in_strides_[1]; ++j) { \ + for (int k = 0; k < in_shape1; ++k) { \ + int offset = param->in_strides_[1] * k + in_dim0_offset + j; \ + param->arg_elements_[k].index_ = (uint32_t)k; \ + param->arg_elements_[k].data_.UNION_DATA = input[offset]; \ + } \ + qsort(param->arg_elements_, in_shape1, sizeof(ArgElement), *compare_func); \ + for (int k = 0; k < param->topk_; ++k) { \ + int out_offset = out_dim0_offset + j + k * param->out_strides_[1]; \ + if (param->out_value_) { \ + outputfp32[out_offset] = param->arg_elements_[k].data_.UNION_DATA; \ + } else { \ + outputint[out_offset] = param->arg_elements_[k].index_; \ + } \ + if (output_value != NULL) { \ + output_value[out_offset] = param->arg_elements_[k].data_.UNION_DATA; \ + } \ + } \ + } \ + } \ + return; \ + } \ + \ + void ArgMinMaxDim2##data_type(const DATA_TYPE *input, void *output, DATA_TYPE *output_value, \ + const int32_t *in_shape, const ArgMinMaxComputeParam *param, \ + COMPARE_FUNCTION compare_func) { \ + int in_shape1 = in_shape[1]; \ + int in_shape2 = in_shape[2]; \ + DATA_TYPE *outputfp32 = (DATA_TYPE *)output; \ + int32_t *outputint = (int32_t *)output; \ + for (int i = 0; i < in_shape[0]; ++i) { \ + int in_dim0_offset = i * param->in_strides_[0]; \ + int out_dim0_offset = i * param->out_strides_[0]; \ + for (int j = 0; j < in_shape1; ++j) { \ + int in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset; \ + int out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset; \ + for (int k = 0; k < param->in_strides_[2]; ++k) { \ + for (int l = 0; l < in_shape2; ++l) { \ + int offset = param->in_strides_[2] * l + k + in_dim1_offset; \ + param->arg_elements_[l].index_ = (uint32_t)l; \ + param->arg_elements_[l].data_.UNION_DATA = input[offset]; \ + } \ + qsort(param->arg_elements_, in_shape2, sizeof(ArgElement), *compare_func); \ + for (int l = 0; l < param->topk_; ++l) { \ + int out_offset = out_dim1_offset + k + l * param->out_strides_[2]; \ + if (param->out_value_) { \ + outputfp32[out_offset] = param->arg_elements_[l].data_.UNION_DATA; \ + } else { \ + outputint[out_offset] = param->arg_elements_[l].index_; \ + } \ + if (output_value != NULL) { \ + output_value[out_offset] = param->arg_elements_[l].data_.UNION_DATA; \ + } \ + } \ + } \ + } \ + } \ + } \ + \ + void ArgMinMaxDim3##data_type(const DATA_TYPE *input, void *output, DATA_TYPE *output_value, \ + const int32_t *in_shape, const ArgMinMaxComputeParam *param, \ + COMPARE_FUNCTION compare_func) { \ + int in_shape1 = in_shape[1]; \ + int in_shape2 = in_shape[2]; \ + int in_shape3 = in_shape[3]; \ + DATA_TYPE *outputfp32 = (DATA_TYPE *)output; \ + int32_t *outputint = (int32_t *)output; \ + for (int i = 0; i < in_shape[0]; ++i) { \ + int in_dim0_offset = i * param->in_strides_[0]; \ + int out_dim0_offset = i * param->out_strides_[0]; \ + for (int j = 0; j < in_shape1; ++j) { \ + int in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset; \ + int out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset; \ + for (int k = 0; k < in_shape2; ++k) { \ + int in_dim2_offset = k * param->in_strides_[2] + in_dim1_offset; \ + int out_dim2_offset = k * param->out_strides_[2] + out_dim1_offset; \ + for (int l = 0; l < in_shape3; ++l) { \ + int offset = l + in_dim2_offset; \ + param->arg_elements_[l].index_ = (uint32_t)l; \ + param->arg_elements_[l].data_.UNION_DATA = input[offset]; \ + } \ + qsort(param->arg_elements_, in_shape3, sizeof(ArgElement), *compare_func); \ + for (int l = 0; l < param->topk_; ++l) { \ + int out_offset = out_dim2_offset + l; \ + if (param->out_value_) { \ + outputfp32[out_offset] = param->arg_elements_[l].data_.UNION_DATA; \ + } else { \ + outputint[out_offset] = (int)(param->arg_elements_[l].index_); \ + } \ + if (output_value != NULL) { \ + output_value[out_offset] = param->arg_elements_[l].data_.UNION_DATA; \ + } \ + } \ + } \ + } \ + } \ + } \ + \ + void ArgMinMax##data_type##32(const DATA_TYPE *input, void *output, DATA_TYPE *output_value, \ + const int32_t *in_shape, const ArgMinMaxComputeParam *param) { \ + if (param->topk_ == 1) { \ + int pre_axis_count = 1; \ + int axis_count = 1; \ + int after_axis_count = 1; \ + ComputeAxisDims(in_shape, param->dims_size_, param->axis_, &pre_axis_count, &axis_count, &after_axis_count); \ + \ + if (param->get_max_) { \ + ArgMaxTopK1##data_type(input, output, output_value, param, pre_axis_count, axis_count, after_axis_count); \ + } else { \ + ArgMinTopK1##data_type(input, output, output_value, param, pre_axis_count, axis_count, after_axis_count); \ + } \ + return; \ + } \ + \ + COMPARE_FUNCTION compare_function = NULL; \ + if (param->get_max_) { \ + compare_function = ArgCompareDesc32##data_type; \ + } else { \ + compare_function = ArgCompareAsc32##data_type; \ + } \ + \ + switch (param->axis_) { \ + case 0: \ + ArgMinMaxDim0##data_type(input, output, output_value, in_shape, param, compare_function); \ + break; \ + case 1: \ + ArgMinMaxDim1##data_type(input, output, output_value, in_shape, param, compare_function); \ + break; \ + case 2: \ + ArgMinMaxDim2##data_type(input, output, output_value, in_shape, param, compare_function); \ + break; \ + case 3: \ + ArgMinMaxDim3##data_type(input, output, output_value, in_shape, param, compare_function); \ + break; \ + } \ + return; \ + } + +#define DATA_TYPE float +#define MIN_VALUE -FLT_MAX +#define MAX_VALUE FLT_MAX +#define UNION_DATA f_data_ +ARG_MIN_MAX_FUNC(Fp) +#undef DATA_TYPE +#undef MIN_VALUE +#undef MAX_VALUE +#undef UNION_DATA + +#define DATA_TYPE int32_t +#define MIN_VALUE INT32_MIN +#define MAX_VALUE INT32_MAX +#define UNION_DATA i_data_ +ARG_MIN_MAX_FUNC(Int) +#undef DATA_TYPE +#undef MIN_VALUE +#undef MAX_VALUE +#undef UNION_DATA diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/arg_min_max_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/arg_min_max_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..f7ddbf0ea7c97bae0b7132c840467f3563a851b2 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/arg_min_max_fp32.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FP32_ARG_MIN_MAX_FP32_H_ +#define FP32_ARG_MIN_MAX_FP32_H_ + +#include "nnacl/nnacl_common.h" +#include "nnacl/arg_min_max_parameter.h" +#include "nnacl/kernel/arg_min_max.h" + +#ifdef __cplusplus +extern "C" { +#endif +void ArgMinMaxFp32(const float *input, void *output, float *output_value, const int32_t *in_shape, + const ArgMinMaxComputeParam *param); +void ArgMinMaxInt32(const int32_t *input, void *output, int32_t *output_value, const int32_t *in_shape, + const ArgMinMaxComputeParam *param); +#ifdef __cplusplus +} +#endif + +#endif // FP32_ARG_MIN_MAX_FP32_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/arithmetic_compare_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/arithmetic_compare_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..09dd9923ee182f3b6a6c61e2b65e156b142289de --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/arithmetic_compare_fp32.c @@ -0,0 +1,198 @@ +/** + * Copyright 2020-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/arithmetic_compare_fp32.h" + +inline bool EqualFp32(float x, float y); +inline bool EqualBool(bool x, bool y); +inline bool NotEqualFp32(float x, float y); +inline bool LessFp32(float x, float y); +inline bool LessEqualFp32(float x, float y); +inline bool GreaterFp32(float x, float y); +inline bool GreaterEqualFp32(float x, float y); + +inline bool EqualInt32(int x, int y); +inline bool NotEqualInt32(int x, int y); +inline bool NotEqualInt64(int64_t x, int64_t y); +inline bool LessInt32(int x, int y); +inline bool LessEqualInt32(int x, int y); +inline bool GreaterInt32(int x, int y); +inline bool GreaterEqualInt32(int x, int y); + +bool EqualFp32(float x, float y) { return x == y; } +bool EqualBool(bool x, bool y) { return x == y; } +bool NotEqualFp32(float x, float y) { return x != y; } +bool LessFp32(float x, float y) { return x < y; } +bool LessEqualFp32(float x, float y) { return x <= y; } +bool GreaterFp32(float x, float y) { return x > y; } +bool GreaterEqualFp32(float x, float y) { return x >= y; } + +bool EqualInt32(int x, int y) { return x == y; } +bool NotEqualInt32(int x, int y) { return x != y; } +bool NotEqualInt64(int64_t x, int64_t y) { return x != y; } +bool LessInt32(int x, int y) { return x < y; } +bool LessEqualInt32(int x, int y) { return x <= y; } +bool GreaterInt32(int x, int y) { return x > y; } +bool GreaterEqualInt32(int x, int y) { return x >= y; } + +#define ELEMENT_COMPARE(input0, input1, output, element_size, compare_func) \ + do { \ + for (int i = 0; i < element_size; i++) { \ + output[i] = compare_func(input0[i], input1[i]); \ + } \ + return NNACL_OK; \ + } while (0) + +#define ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, compare_func) \ + do { \ + int i = 0; \ + if (first_scalar) { \ + for (; i < element_size; i++) { \ + output[i] = compare_func(input0[0], input1[i]); \ + } \ + } else { \ + for (; i < element_size; i++) { \ + output[i] = compare_func(input0[i], input1[0]); \ + } \ + } \ + return NNACL_OK; \ + } while (0) + +// equal: +int ElementEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, EqualFp32); +} + +int ElementEqualBool(const bool *input0, const bool *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, EqualBool); +} + +int ElementOptEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size, + bool first_scalar) { + ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, EqualFp32); +} + +int ElementEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, EqualInt32); +} + +int ElementOptEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size, + bool first_scalar) { + ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, EqualInt32); +} + +// not equal +int ElementNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, NotEqualFp32); +} + +int ElementOptNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size, + bool first_scalar) { + ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, NotEqualFp32); +} + +int ElementNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, NotEqualInt32); +} + +int ElementOptNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size, + bool first_scalar) { + ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, NotEqualInt32); +} + +int ElementNotEqualInt64(const int64_t *input0, const int64_t *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, NotEqualInt64); +} + +int ElementOptNotEqualInt64(const int64_t *input0, const int64_t *input1, uint8_t *output, int element_size, + bool first_scalar) { + ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, NotEqualInt64); +} + +// less +int ElementLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, LessFp32); +} + +int ElementOptLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size, bool first_scalar) { + ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, LessFp32); +} + +int ElementLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, LessInt32); +} + +int ElementOptLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size, + bool first_scalar) { + ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, LessInt32); +} + +// less equal +int ElementLessEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, LessEqualFp32); +} + +int ElementOptLessEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size, + bool first_scalar) { + ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, LessEqualFp32); +} + +int ElementLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, LessEqualInt32); +} + +int ElementOptLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size, + bool first_scalar) { + ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, LessEqualInt32); +} + +// greater +int ElementGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, GreaterFp32); +} + +int ElementOptGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size, + bool first_scalar) { + ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, GreaterFp32); +} + +int ElementGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, GreaterInt32); +} + +int ElementOptGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size, + bool first_scalar) { + ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, GreaterInt32); +} + +// greater equal +int ElementGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, GreaterEqualFp32); +} + +int ElementOptGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size, + bool first_scalar) { + ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, GreaterEqualFp32); +} + +int ElementGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size) { + ELEMENT_COMPARE(input0, input1, output, element_size, GreaterEqualInt32); +} + +int ElementOptGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size, + bool first_scalar) { + ELEMENT_COMPARE_OPT(input0, input1, output, element_size, first_scalar, GreaterEqualInt32); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/arithmetic_compare_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/arithmetic_compare_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..758fd493f14fdc744d6117ed49fadf226457b72e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/arithmetic_compare_fp32.h @@ -0,0 +1,77 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_ARITHMETIC_COMPARE_H_ +#define MINDSPORE_NNACL_ARITHMETIC_COMPARE_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "nnacl/op_base.h" +#include "nnacl/base/arithmetic_base.h" +#include "nnacl/errorcode.h" + +#ifdef __cplusplus +extern "C" { +#endif +int ElementEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size); +int ElementEqualBool(const bool *input0, const bool *input1, uint8_t *output, int element_size); +int ElementOptEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size, bool first_scalar); +int ElementEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size); +int ElementOptEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size, + bool first_scalar); + +int ElementNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size); +int ElementOptNotEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size, + bool first_scalar); +int ElementNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size); +int ElementOptNotEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size, + bool first_scalar); +int ElementNotEqualInt64(const int64_t *input0, const int64_t *input1, uint8_t *output, int element_size); +int ElementOptNotEqualInt64(const int64_t *input0, const int64_t *input1, uint8_t *output, int element_size, + bool first_scalar); + +int ElementLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size); +int ElementOptLessFp32(const float *input0, const float *input1, uint8_t *output, int element_size, bool first_scalar); +int ElementLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size); +int ElementOptLessInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size, + bool first_scalar); + +int ElementLessEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size); +int ElementOptLessEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size, + bool first_scalar); +int ElementLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size); +int ElementOptLessEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size, + bool first_scalar); + +int ElementGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size); +int ElementOptGreaterFp32(const float *input0, const float *input1, uint8_t *output, int element_size, + bool first_scalar); +int ElementGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size); +int ElementOptGreaterInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size, + bool first_scalar); + +int ElementGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size); +int ElementOptGreaterEqualFp32(const float *input0, const float *input1, uint8_t *output, int element_size, + bool first_scalar); +int ElementGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size); +int ElementOptGreaterEqualInt32(const int32_t *input0, const int32_t *input1, uint8_t *output, int element_size, + bool first_scalar); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_ARITHMETIC_COMPARE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/arithmetic_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/arithmetic_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..b67c72f43383b2d3c9b4c3da45571fd697b10523 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/arithmetic_fp32.c @@ -0,0 +1,482 @@ +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/arithmetic_fp32.h" +#include +#include "nnacl/arithmetic_fp32_simd.h" + +#define ACCURACY_DATA 0.00000001 + +int ElementFloorMod(const float *in0, const float *in1, float *out, int size) { + int i = 0; + + SIMD_RUN_X86_NO_SCALAR(ElementFloorMod, i, in0, in1, out, size); // neon no floor instruction + + for (; i < size; i++) { + out[i] = in0[i] - floorf(in0[i] / in1[i]) * in1[i]; + } + return NNACL_OK; +} + +int ElementOptFloorMod(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int i = 0; + + if (first_scalar) { + SIMD_RUN_X86_NO_SCALAR(ElementOptFloorModNum0, i, in0, in1, out, size); // neon no floor instruction + for (; i < size; i++) { + out[i] = in0[0] - floorf(in0[0] / in1[i]) * in1[i]; + } + } else { + SIMD_RUN_X86_NO_SCALAR(ElementOptFloorModNum1, i, in0, in1, out, size); // neon no floor instruction + for (; i < size; i++) { + out[i] = in0[i] - floorf(in0[i] / in1[0]) * in1[0]; + } + } + + return NNACL_OK; +} + +int ElementFloorModInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + for (int i = 0; i < size; i++) { + NNACL_CHECK_ZERO_RETURN_ERR(in1[i]); + int remainder = in0[i] - (in0[i] / in1[i]) * in1[i]; + out[i] = (remainder != 0) && ((in0[i] > 0) != (in1[i] > 0)) ? remainder + in1[i] : remainder; + } + return NNACL_OK; +} + +int ElementOptFloorModInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar) { + int i = 0; + if (first_scalar) { + for (; i < size; i++) { + NNACL_CHECK_ZERO_RETURN_ERR(in1[i]); + int remainder = in0[0] - (in0[0] / in1[i]) * in1[i]; + out[i] = (remainder != 0) && ((in0[0] > 0) != (in1[i] > 0)) ? remainder + in1[i] : remainder; + } + } else { + NNACL_CHECK_ZERO_RETURN_ERR(in1[0]); + for (; i < size; i++) { + int remainder = in0[i] - (in0[i] / in1[0]) * in1[0]; + out[i] = (remainder != 0) && ((in0[i] > 0) != (in1[0] > 0)) ? remainder + in1[0] : remainder; + } + } + + return NNACL_OK; +} + +int ElementMod(const float *in0, const float *in1, float *out, int size) { + for (int i = 0; i < size; i++) { + out[i] = fmodf(in0[i], in1[i]); + } + return NNACL_OK; +} + +int ElementOptMod(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + for (; index < size; index++) { + out[index] = fmodf(in0[0], in1[index]); + } + } else { + for (; index < size; index++) { + out[index] = fmodf(in0[index], in1[0]); + } + } + return NNACL_OK; +} + +int ElementModInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + for (int i = 0; i < size; i++) { + NNACL_CHECK_ZERO_RETURN_ERR(in1[i]); + out[i] = in0[i] % in1[i]; + } + return NNACL_OK; +} + +int ElementOptModInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar) { + if (first_scalar) { + for (int index = 0; index < size; index++) { + NNACL_CHECK_ZERO_RETURN_ERR(in1[index]); + out[index] = in0[0] % in1[index]; + } + } else { + NNACL_CHECK_ZERO_RETURN_ERR(in1[0]); + for (int index = 0; index < size; index++) { + out[index] = in0[index] % in1[0]; + } + } + return NNACL_OK; +} + +int ElementFloorDiv(const float *in0, const float *in1, float *out, int size) { + int i = 0; + + SIMD_RUN_X86_NO_SCALAR(ElementFloorDiv, i, in0, in1, out, size); // neon no floor instruction + + for (; i < size; i++) { + out[i] = floorf(in0[i] / in1[i]); + } + return NNACL_OK; +} + +int ElementOptFloorDiv(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int i = 0; + + if (first_scalar) { + SIMD_RUN_X86_NO_SCALAR(ElementOptFloorDivNum0, i, in0, in1, out, size); // neon no floor instruction + + for (; i < size; i++) { + out[i] = floorf(in0[0] / in1[i]); + } + } else { + SIMD_RUN_X86_NO_SCALAR(ElementOptFloorDivNum1, i, in0, in1, out, size); // neon no floor instruction + + for (; i < size; i++) { + out[i] = floorf(in0[i] / in1[0]); + } + } + + return NNACL_OK; +} + +int ElementFloorDivInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + int i = 0; + + SIMD_RUN_NO_SCALAR(ElementFloorDivInt, i, in0, in1, out, size); + + for (; i < size; i++) { + NNACL_CHECK_ZERO_RETURN_ERR(in1[i]); + out[i] = in0[i] / in1[i]; + } + return NNACL_OK; +} + +int ElementOptFloorDivInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar) { + int i = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptFloorDivIntNum0, i, in0, in1, out, size); + + for (; i < size; i++) { + NNACL_CHECK_ZERO_RETURN_ERR(in1[i]); + out[i] = in0[0] / in1[i]; + } + } else { + NNACL_CHECK_ZERO_RETURN_ERR(in1[0]); + + SIMD_RUN_NO_SCALAR(ElementOptFloorDivIntNum1, i, in0, in1, out, size); + + for (; i < size; i++) { + out[i] = in0[i] / in1[0]; + } + } + + return NNACL_OK; +} + +int ElementLogicalAnd(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementLogicalAnd, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = (float)((bool)(in0[index]) & (bool)(in1[index])); + } + return NNACL_OK; +} + +int ElementOptLogicalAnd(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + SIMD_RUN_NO_SCALAR(ElementOptLogicalAnd, index, in0, in1, out, size, first_scalar); + if (first_scalar) { + for (; index < size; index++) { + out[index] = (float)((bool)(in0[0]) & (bool)(in1[index])); + } + } else { + for (; index < size; index++) { + out[index] = (float)((bool)(in0[index]) & (bool)(in1[0])); + } + } + + return NNACL_OK; +} + +int ElementLogicalAndInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + int index = 0; + for (; index < size; index++) { + out[index] = (int)((unsigned int)(in0[index]) & (unsigned int)(in1[index])); + } + return NNACL_OK; +} + +int ElementOptLogicalAndInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + for (; index < size; index++) { + out[index] = (int)((unsigned int)(in0[0]) & (unsigned int)(in1[index])); + } + } else { + for (; index < size; index++) { + out[index] = (int)((unsigned int)(in0[index]) & (unsigned int)(in1[0])); + } + } + + return NNACL_OK; +} + +int ElementLogicalAndBool(const bool *in0, const bool *in1, bool *out, int size) { + int index = 0; + for (; index < size; index++) { + out[index] = (bool)((unsigned int)(in0[index]) & (unsigned int)(in1[index])); + } + + return NNACL_OK; +} + +int ElementOptLogicalAndBool(const bool *in0, const bool *in1, bool *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + for (; index < size; index++) { + out[index] = (bool)((unsigned int)(in0[0]) & (unsigned int)(in1[index])); + } + } else { + for (; index < size; index++) { + out[index] = (bool)((unsigned int)(in0[index]) & (unsigned int)(in1[0])); + } + } + + return NNACL_OK; +} + +int ElementLogicalOr(const float *in0, const float *in1, float *out, int size) { + int index = 0; +#ifdef ENABLE_NEON + float32x4_t vtrue = vdupq_n_f32(1); + float32x4_t vfalse = vdupq_n_f32(0); + uint32x4_t mask = vmovq_n_u32(((uint32_t)(1u << 31) - 1)); + uint32x4_t zeros = vdupq_n_u32(0); + for (; index <= size - 4; index += C4NUM) { + uint32x4_t vin0 = vandq_u32(vreinterpretq_u32_f32(vld1q_f32(in0 + index)), mask); + uint32x4_t vin1 = vandq_u32(vreinterpretq_u32_f32(vld1q_f32(in1 + index)), mask); + float32x4_t vout = vbslq_f32(vceqq_u32(vorrq_u32(vin0, vin1), zeros), vfalse, vtrue); + vst1q_f32(out + index, vout); + } +#endif + for (; index < size; index++) { + out[index] = (float)((bool)(in0[index]) | (bool)(in1[index])); + } + return NNACL_OK; +} + +int ElementOptLogicalOr(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + for (; index < size; index++) { + out[index] = (float)((bool)(in0[0]) | (bool)(in1[index])); + } + } else { + for (; index < size; index++) { + out[index] = (float)((bool)(in0[index]) | (bool)(in1[0])); + } + } + + return NNACL_OK; +} + +int ElementLogicalOrBool(const bool *in0, const bool *in1, bool *out, int size) { + int index = 0; + for (; index < size; index++) { + out[index] = (bool)(in0[index] | in1[index]); + } + return NNACL_OK; +} + +int ElementOptLogicalOrBool(const bool *in0, const bool *in1, bool *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + for (; index < size; index++) { + out[index] = (bool)(in0[0] | in1[index]); + } + } else { + for (; index < size; index++) { + out[index] = (bool)(in0[index] | in1[0]); + } + } + + return NNACL_OK; +} + +int ElementMaximum(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementMaximum, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[index] > in1[index] ? in0[index] : in1[index]; + } + return NNACL_OK; +} + +int ElementOptMaximum(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptMaximumNum0, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[0] > in1[index] ? in0[0] : in1[index]; + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptMaximumNum1, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[index] > in1[0] ? in0[index] : in1[0]; + } + } + + return NNACL_OK; +} + +int ElementMaximumInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementMaximumInt, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[index] > in1[index] ? in0[index] : in1[index]; + } + return NNACL_OK; +} + +int ElementOptMaximumInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptMaximumIntNum0, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[0] > in1[index] ? in0[0] : in1[index]; + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptMaximumIntNum1, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[index] > in1[0] ? in0[index] : in1[0]; + } + } + + return NNACL_OK; +} + +int ElementMinimumInt(const int32_t *input0, const int32_t *input1, int32_t *output, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementMinimumInt, index, input0, input1, output, size); + + for (; index < size; index++) { + output[index] = input0[index] > input1[index] ? input1[index] : input0[index]; + } + return NNACL_OK; +} + +int ElementOptMinimumInt(const int32_t *input0, const int32_t *input1, int32_t *output, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptMinimumIntNum0, index, input0, input1, output, size); + + for (; index < size; index++) { + output[index] = input0[0] > input1[index] ? input1[index] : input0[0]; + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptMinimumIntNum1, index, input0, input1, output, size); + + for (; index < size; index++) { + output[index] = input0[index] > input1[0] ? input1[0] : input0[index]; + } + } + + return NNACL_OK; +} + +int ElementMinimum(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementMinimum, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[index] > in1[index] ? in1[index] : in0[index]; + } + return NNACL_OK; +} + +int ElementOptMinimum(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptMinimumNum0, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[0] > in1[index] ? in1[index] : in0[0]; + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptMinimumNum1, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[index] > in1[0] ? in1[0] : in0[index]; + } + } + + return NNACL_OK; +} + +#undef ACCURACY_DATA + +void TileOneDimensionFp32(const void *inPtr, void *outPtr, int dim, size_t ndim, const int32_t *inShape, + const int32_t *inStrides, const int32_t *outStrides, const int32_t *multiple) { + const float *inData = (const float *)inPtr; + float *outData = (float *)outPtr; + + int srcDimSize = inShape[dim]; + if (dim == ndim - 1) { + for (int i = 0; i < multiple[dim]; i++) { + memcpy(outData, inData, srcDimSize * sizeof(float)); + outData += srcDimSize; + } + return; + } + for (size_t i = 0; i < srcDimSize; i++) { + for (size_t j = 0; j < multiple[dim]; j++) { + TileOneDimensionFp32(inData + inStrides[dim] * i, outData + outStrides[dim] * (i + j * srcDimSize), dim + 1, ndim, + inShape, inStrides, outStrides, multiple); + } + } +} + +void TileDimensionsFp32(const float *data0, const float *data1, float *tile_data0, float *tile_data1, + ArithmeticParameter *param) { + CalcMultiplesAndStrides(param); + TileOneDimensionFp32(data0, tile_data0, 0, param->ndim_, param->in_shape0_, param->in_strides0_, param->out_strides_, + param->multiples0_); + TileOneDimensionFp32(data1, tile_data1, 0, param->ndim_, param->in_shape1_, param->in_strides1_, param->out_strides_, + param->multiples1_); +} + +void AssignSubOpt(float *in0, const float *in1, size_t size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(AssignSubOpt, index, in0, in1, size); + + for (; index < size; index++) { + in0[index] = in0[index] - in1[index]; + } + return; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/arithmetic_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/arithmetic_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..e8f6819c33e7febd034786ed547252b3b3459917 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/arithmetic_fp32.h @@ -0,0 +1,86 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_ARITHMETIC_H_ +#define MINDSPORE_NNACL_ARITHMETIC_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "nnacl/op_base.h" +#include "nnacl/base/arithmetic_base.h" +#include "nnacl/errorcode.h" +#include "nnacl/fp32/add_fp32.h" +#include "nnacl/fp32/mul_fp32.h" +#include "nnacl/fp32/div_fp32.h" +#include "nnacl/fp32/sub_fp32.h" +#include "nnacl/fp32/squared_difference.h" + +#ifdef __cplusplus +extern "C" { +#endif +void TileOneDimensionFp32(const void *inData, void *outData, int dim, size_t ndim, const int32_t *inShape, + const int32_t *inStrides, const int32_t *outStrides, const int32_t *multiple); +void TileDimensionsFp32(const float *data0, const float *data1, float *tile_data0, float *tile_data1, + ArithmeticParameter *param); +/* logical and */ +int ElementLogicalAnd(const float *in0, const float *in1, float *out, int size); +int ElementOptLogicalAnd(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementLogicalAndInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size); +int ElementOptLogicalAndInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar); +int ElementLogicalAndBool(const bool *in0, const bool *in1, bool *out, int size); +int ElementOptLogicalAndBool(const bool *in0, const bool *in1, bool *out, int size, bool first_scalar); + +/* logical or */ +int ElementLogicalOr(const float *in0, const float *in1, float *out, int size); +int ElementOptLogicalOr(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementLogicalOrBool(const bool *in0, const bool *in1, bool *out, int size); +int ElementOptLogicalOrBool(const bool *in0, const bool *in1, bool *out, int size, bool first_scalar); + +/* max min */ +int ElementMaximum(const float *in0, const float *in1, float *out, int size); +int ElementOptMaximum(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementMinimum(const float *in0, const float *in1, float *out, int size); +int ElementOptMinimum(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementMaximumInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size); +int ElementOptMaximumInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar); +int ElementMinimumInt(const int32_t *input0, const int32_t *input1, int32_t *output, int size); +int ElementOptMinimumInt(const int32_t *input0, const int32_t *input1, int32_t *output, int size, bool first_scalar); + +/* floor div */ +int ElementFloorDiv(const float *in0, const float *in1, float *out, int size); +int ElementOptFloorDiv(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementFloorDivInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size); +int ElementOptFloorDivInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar); + +/* floor mod */ +int ElementFloorMod(const float *in0, const float *in1, float *out, int size); +int ElementOptFloorMod(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementFloorModInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size); +int ElementOptFloorModInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar); + +/* mod */ +int ElementMod(const float *in0, const float *in1, float *out, int size); +int ElementOptMod(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementModInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size); +int ElementOptModInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar); + +void AssignSubOpt(float *in0, const float *in1, size_t size); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_ARITHMETIC_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/arithmetic_fp32_simd.h.in b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/arithmetic_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..f35c10e720ba7811525040f5a0667d838a039af2 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/arithmetic_fp32_simd.h.in @@ -0,0 +1,287 @@ +/** + * Copyright 2022-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// clang-format off +#ifndef MINDSPORE_NNACL_ARITHMETIC_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_ARITHMETIC_@SIMD_INSTRUCTION@_H_ + +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +#ifndef MS_SIMD_NEON +static inline int ElementFloorMod@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in0_tmp = SIMD_LD_F32(in0 + index); + SIMD_F32 in1_tmp = SIMD_LD_F32(in1 + index); + SIMD_F32 floor_tmp = SIMD_FLOOR_F32(SIMD_DIV_F32(in0_tmp, in1_tmp)); + SIMD_F32 out_tmp = SIMD_SUB_F32(in0_tmp, SIMD_MUL_F32(floor_tmp, in1_tmp)); + SIMD_ST_F32(out + index, out_tmp); + } + return index; +} + +static inline int ElementOptFloorModNum0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 in0_tmp = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in1_tmp = SIMD_LD_F32(in1 + index); + SIMD_F32 floor_tmp = SIMD_FLOOR_F32(SIMD_DIV_F32(in0_tmp, in1_tmp)); + SIMD_F32 out_tmp = SIMD_SUB_F32(in0_tmp, SIMD_MUL_F32(floor_tmp, in1_tmp)); + SIMD_ST_F32(out + index, out_tmp); + } + return index; +} + +static inline int ElementOptFloorModNum1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 in1_tmp = SIMD_MOV_F32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in0_tmp = SIMD_LD_F32(in0 + index); + SIMD_F32 floor_tmp = SIMD_FLOOR_F32(SIMD_DIV_F32(in0_tmp, in1_tmp)); + SIMD_F32 out_tmp = SIMD_SUB_F32(in0_tmp, SIMD_MUL_F32(floor_tmp, in1_tmp)); + SIMD_ST_F32(out + index, out_tmp); + } + return index; +} + +static inline int ElementFloorDiv@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in0_tmp = SIMD_LD_F32(in0 + index); + SIMD_F32 in1_tmp = SIMD_LD_F32(in1 + index); + SIMD_F32 floor_tmp = SIMD_FLOOR_F32(SIMD_DIV_F32(in0_tmp, in1_tmp)); + SIMD_ST_F32(out + index, floor_tmp); + } + return index; +} + +static inline int ElementOptFloorDivNum0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 in0_tmp = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in1_tmp = SIMD_LD_F32(in1 + index); + SIMD_F32 out_tmp = SIMD_FLOOR_F32(SIMD_DIV_F32(in0_tmp, in1_tmp)); + SIMD_ST_F32(out + index, out_tmp); + } + return index; +} + +static inline int ElementOptFloorDivNum1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 in1_tmp = SIMD_MOV_F32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in0_tmp = SIMD_LD_F32(in0 + index); + SIMD_F32 out_tmp = SIMD_FLOOR_F32(SIMD_DIV_F32(in0_tmp, in1_tmp)); + SIMD_ST_F32(out + index, out_tmp); + } + return index; +} +#endif + +static inline int ElementFloorDivInt@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 in0_tmp = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 in1_tmp = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 out_tmp = SIMD_DIV_EPI32(in0_tmp, in1_tmp); + SIMD_ST_EPI32(out + index, out_tmp); + } + return index; +} + +static inline int ElementOptFloorDivIntNum0@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 in0_tmp = SIMD_MOV_EPI32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 in1_tmp = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 out_tmp = SIMD_DIV_EPI32(in0_tmp, in1_tmp); + SIMD_ST_EPI32(out + index, out_tmp); + } + return index; +} + +static inline int ElementOptFloorDivIntNum1@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 in1_tmp = SIMD_MOV_EPI32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 in0_tmp = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 out_tmp = SIMD_DIV_EPI32(in0_tmp, in1_tmp); + SIMD_ST_EPI32(out + index, out_tmp); + } + return index; +} + +static inline int ElementMaximum@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in0_tmp = SIMD_LD_F32(in0 + index); + SIMD_F32 in1_tmp = SIMD_LD_F32(in1 + index); + SIMD_F32 out_tmp = SIMD_MAX_F32(in0_tmp, in1_tmp); + SIMD_ST_F32(out + index, out_tmp); + } + return index; +} + +static inline int ElementOptMaximumNum0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 in0_tmp = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in1_tmp = SIMD_LD_F32(in1 + index); + SIMD_F32 out_tmp = SIMD_MAX_F32(in0_tmp, in1_tmp); + SIMD_ST_F32(out + index, out_tmp); + } + return index; +} + +static inline int ElementOptMaximumNum1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 in1_tmp = SIMD_MOV_F32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in0_tmp = SIMD_LD_F32(in0 + index); + SIMD_F32 out_tmp = SIMD_MAX_F32(in0_tmp, in1_tmp); + SIMD_ST_F32(out + index, out_tmp); + } + return index; +} + +static inline int ElementMaximumInt@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 in0_tmp = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 in1_tmp = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 out_tmp = SIMD_MAX_EPI32(in0_tmp, in1_tmp); + SIMD_ST_EPI32(out + index, out_tmp); + } + return index; +} + +static inline int ElementOptMaximumIntNum0@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 in0_tmp = SIMD_MOV_EPI32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 in1_tmp = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 out_tmp = SIMD_MAX_EPI32(in0_tmp, in1_tmp); + SIMD_ST_EPI32(out + index, out_tmp); + } + return index; +} + +static inline int ElementOptMaximumIntNum1@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 in1_tmp = SIMD_MOV_EPI32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 in0_tmp = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 out_tmp = SIMD_MAX_EPI32(in0_tmp, in1_tmp); + SIMD_ST_EPI32(out + index, out_tmp); + } + return index; +} + +static inline int ElementMinimumInt@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 in0_tmp = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 in1_tmp = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 out_tmp = SIMD_MIN_EPI32(in0_tmp, in1_tmp); + SIMD_ST_EPI32(out + index, out_tmp); + } + return index; +} + +static inline int ElementOptMinimumIntNum0@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 in0_tmp = SIMD_MOV_EPI32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 in1_tmp = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 out_tmp = SIMD_MIN_EPI32(in0_tmp, in1_tmp); + SIMD_ST_EPI32(out + index, out_tmp); + } + return index; +} + +static inline int ElementOptMinimumIntNum1@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 in1_tmp = SIMD_MOV_EPI32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 in0_tmp = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 out_tmp = SIMD_MIN_EPI32(in0_tmp, in1_tmp); + SIMD_ST_EPI32(out + index, out_tmp); + } + return index; +} + +static inline int ElementMinimum@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in0_tmp = SIMD_LD_F32(in0 + index); + SIMD_F32 in1_tmp = SIMD_LD_F32(in1 + index); + SIMD_F32 out_tmp = SIMD_MIN_F32(in0_tmp, in1_tmp); + SIMD_ST_F32(out + index, out_tmp); + } + return index; +} + +static inline int ElementOptMinimumNum0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 in0_tmp = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in1_tmp = SIMD_LD_F32(in1 + index); + SIMD_F32 out_tmp = SIMD_MIN_F32(in0_tmp, in1_tmp); + SIMD_ST_F32(out + index, out_tmp); + } + return index; +} + +static inline int ElementOptMinimumNum1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 in1_tmp = SIMD_MOV_F32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in0_tmp = SIMD_LD_F32(in0 + index); + SIMD_F32 out_tmp = SIMD_MIN_F32(in0_tmp, in1_tmp); + SIMD_ST_F32(out + index, out_tmp); + } + return index; +} + +static inline size_t AssignSubOpt@SIMD_INSTRUCTION@(int index, float *in0, const float *in1, size_t size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in0_tmp = SIMD_LD_F32(in0 + index); + SIMD_F32 in1_tmp = SIMD_LD_F32(in1 + index); + SIMD_F32 out_tmp = SIMD_SUB_F32(in0_tmp, in1_tmp); + SIMD_ST_F32(in0 + index, out_tmp); + } + return index; +} + +int ElementLogicalAnd@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in0_tmp = SIMD_LD_F32(in0 + index); + SIMD_F32 in1_tmp = SIMD_LD_F32(in1 + index); + SIMD_F32 out_tmp = SIMD_AND_F32(SIMD_GETSIGN_F32(in0_tmp), SIMD_GETSIGN_F32(in1_tmp)); + SIMD_ST_F32(out + index, out_tmp); + } + return index; +} + +int ElementOptLogicalAnd@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size, bool first_scalar) { + if (first_scalar) { + SIMD_F32 in0_tmp = SIMD_MOV_F32(*in0); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in1_tmp = SIMD_LD_F32(in1 + index); + SIMD_F32 out_tmp = SIMD_AND_F32(SIMD_GETSIGN_F32(in0_tmp), SIMD_GETSIGN_F32(in1_tmp)); + SIMD_ST_F32(out + index, out_tmp); + } + } else { + SIMD_F32 in1_tmp = SIMD_MOV_F32(*in1); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 in0_tmp = SIMD_LD_F32(in0 + index); + SIMD_F32 out_tmp = SIMD_AND_F32(SIMD_GETSIGN_F32(in0_tmp), SIMD_GETSIGN_F32(in1_tmp)); + SIMD_ST_F32(out + index, out_tmp); + } + } + + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/arithmetic_self_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/arithmetic_self_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..ecf9baef97e5fb7db86ca371256fa2d530a1c9ec --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/arithmetic_self_fp32.c @@ -0,0 +1,230 @@ +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include "nnacl/fp32/arithmetic_self_fp32.h" +#include "nnacl/arithmetic_self_fp32_simd.h" + +int ElementAbs(const float *input, float *output, const int element_size) { + int i = 0; + + // only avx512 support abs fp32 instruction + SIMD_RUN_AVX512(ElementAbs, i, input, output, element_size); + for (; i < element_size; i++) { + output[i] = fabsf(input[i]); + } + return NNACL_OK; +} + +int ElementAbsInt(const int32_t *input, int32_t *output, const int element_size) { + int i = 0; + + // only avx512 support abs fp32 instruction + SIMD_RUN_AVX512(ElementAbsInt, i, input, output, element_size); + for (; i < element_size; i++) { + output[i] = abs(input[i]); + } + return NNACL_OK; +} + +// cos +int ElementCos(const float *input, float *output, const int element_size) { + int i = 0; + SIMD_RUN_X86_NO_SCALAR(ElementCos, i, input, output, element_size); + for (; i < element_size; i++) { + output[i] = cosf(input[i]); + } + return NNACL_OK; +} + +// log: +int ElementLog(const float *input, float *output, const int element_size) { + int i = 0; + + SIMD_RUN_X86_NO_SCALAR(ElementLog, i, input, output, element_size); + for (; i < element_size; i++) { + if (MS_UNLIKELY(input[i] < 0)) { + return NNACL_ERRCODE_LOG_NEGATIVE_OR_ZERO; + } + output[i] = logf(input[i]); + } + return NNACL_OK; +} + +// log1p: +int ElementLog1p(const float *input, float *output, const int element_size) { + for (int i = 0; i < element_size; i++) { + if (MS_UNLIKELY(input[i] < -1.0f)) { + return NNACL_ERRCODE_LOG_NEGATIVE_OR_ZERO; + } + output[i] = log1p(input[i]); + } + return NNACL_OK; +} + +int ElementSquare(const float *input, float *output, const int element_size) { + int i = 0; + + SIMD_RUN_NO_SCALAR(ElementSquare, i, input, output, element_size); + for (; i < element_size; i++) { + output[i] = input[i] * input[i]; + } + return NNACL_OK; +} + +int ElementSqrt(const float *input, float *output, const int element_size) { + int i = 0; + + SIMD_RUN_NO_SCALAR(ElementSqrt, i, input, output, element_size); + for (; i < element_size; i++) { + if (MS_UNLIKELY(input[i] < 0)) { + return NNACL_ERRCODE_SQRT_NEGATIVE; + } + output[i] = sqrtf(input[i]); + } + return NNACL_OK; +} + +int ElementRsqrt(const float *input, float *output, const int element_size) { + int i = 0; + + SIMD_RUN_NO_SCALAR(ElementRsqrt, i, input, output, element_size); + for (; i < element_size; i++) { + if (MS_UNLIKELY(input[i] < 0)) { + return NNACL_ERRCODE_RSQRT_NEGATIVE; + } + output[i] = 1.f / sqrtf(input[i]); + } + return NNACL_OK; +} + +int ElementSin(const float *input, float *output, const int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = sinf(input[i]); + } + return NNACL_OK; +} + +// logical_not: +int ElementLogicalNot(const float *input, float *output, const int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = (float)(!((bool)(input[i]))); + } + return NNACL_OK; +} + +// logical_not: +int ElementLogicalNotBool(const bool *input, bool *output, const int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = !input[i]; + } + return NNACL_OK; +} + +int ElementRound(const float *input, float *output, const int element_size) { + int i = 0; + + SIMD_RUN_AVX(ElementRound, i, input, output, element_size); + SIMD_RUN_SSE(ElementRound, i, input, output, element_size); + for (; i < element_size; i++) { + output[i] = roundf(input[i]); + } + return NNACL_OK; +} + +int ElementFloor(const float *input, float *output, const int element_size) { + int i = 0; + + SIMD_RUN_X86_NO_SCALAR(ElementFloor, i, input, output, element_size); + for (; i < element_size; i++) { + output[i] = floorf(input[i]); + } + return NNACL_OK; +} + +int ElementCeil(const float *input, float *output, const int element_size) { + int i = 0; + + SIMD_RUN_X86_NO_SCALAR(ElementCeil, i, input, output, element_size); + for (; i < element_size; ++i) { + output[i] = ceilf(input[i]); + } + return NNACL_OK; +} + +int ElementNegative(const float *input, float *output, const int element_size) { + int i = 0; + + SIMD_RUN_NO_SCALAR(ElementNegative, i, input, output, element_size); + for (; i < element_size; ++i) { + output[i] = -input[i]; + } + return NNACL_OK; +} + +int ElementNegativeInt(const int32_t *input, int32_t *output, const int element_size) { + int i = 0; + + SIMD_RUN_NO_SCALAR(ElementNegativeInt, i, input, output, element_size); + for (; i < element_size; ++i) { + output[i] = -input[i]; + } + return NNACL_OK; +} + +int ElementReciprocal(const float *input, float *output, const int element_size) { + int i = 0; + + SIMD_RUN_NO_SCALAR(ElementReciprocal, i, input, output, element_size); + for (; i < element_size; ++i) { + if (input[i] == 0.0f) { + return NNACL_ERR; + } + output[i] = 1.f / input[i]; + } + return NNACL_OK; +} + +// Erf +int ElementErf(const float *input, float *output, const int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = erff(input[i]); + } + return NNACL_OK; +} + +int ElementIsFinite(const float *input, bool *output, const int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = true; + if (isnan(input[i]) || isinf(input[i])) { + output[i] = false; + } + } + return NNACL_OK; +} + +int ElementMish(const float *input, float *output, const int element_size) { + int i = 0; + SIMD_RUN_NO_SCALAR(ElementMish, i, input, output, element_size); + + for (; i < element_size; ++i) { + simd_exp32(input[i], output + i); + float exp_pow = (output[i] + 1) * (output[i] + 1); + output[i] = input[i] * (exp_pow - 1) / (exp_pow + 1); + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/arithmetic_self_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/arithmetic_self_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..006f855ce3747fb3e9236b6004732900bcd842e2 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/arithmetic_self_fp32.h @@ -0,0 +1,70 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_ARITHMETIC_SELF_H_ +#define MINDSPORE_NNACL_ARITHMETIC_SELF_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "nnacl/op_base.h" +#include "nnacl/errorcode.h" + +#ifdef __cplusplus +extern "C" { +#endif +int ElementAbs(const float *input, float *output, const int element_size); +int ElementAbsInt(const int32_t *input, int32_t *output, const int element_size); + +int ElementCos(const float *input, float *output, const int element_size); + +int ElementLog(const float *input, float *output, const int element_size); + +int ElementLog1p(const float *input, float *output, const int element_size); + +int ElementSquare(const float *input, float *output, const int element_size); + +int ElementSqrt(const float *input, float *output, const int element_size); + +int ElementRsqrt(const float *input, float *output, const int element_size); + +int ElementSin(const float *input, float *output, const int element_size); + +int ElementLogicalNot(const float *input, float *output, const int element_size); + +int ElementLogicalNotBool(const bool *input, bool *output, const int element_size); + +int ElementRound(const float *input, float *output, const int element_size); + +int ElementFloor(const float *input, float *output, const int element_size); + +int ElementCeil(const float *input, float *output, const int number); + +int ElementNegative(const float *input, float *output, const int element_size); +int ElementNegativeInt(const int32_t *input, int32_t *output, const int element_size); + +int ElementReciprocal(const float *input, float *output, const int element_size); + +int ElementErf(const float *input, float *output, const int element_size); + +int ElementIsFinite(const float *input, bool *output, const int element_size); + +int ElementMish(const float *input, float *output, const int element_size); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_ARITHMETIC_SELF_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/arithmetic_self_fp32_simd.h.in b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/arithmetic_self_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..bca77f369db0f602184b0bb0ca0cd7f30b70cdd8 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/arithmetic_self_fp32_simd.h.in @@ -0,0 +1,152 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_ARITHMETIC_SELF_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_ARITHMETIC_SELF_@SIMD_INSTRUCTION@_H_ + +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +#if defined(MS_SIMD_AVX512) +// only avx512 support abs fp32 instruction +static inline int ElementAbs@SIMD_INSTRUCTION@(int index, const float *input, float *output, const int element_size) { + for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(output + index, SIMD_ABS_F32(SIMD_LD_F32(input + index))); + } + return index; +} + +static inline int ElementAbsInt@SIMD_INSTRUCTION@(int index, const int32_t *input, int32_t *output, const int element_size) { + for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_EPI32(output + index, SIMD_ABS_EPI32(SIMD_LD_EPI32(input + index))); + } + return index; +} +#endif + +#if !defined(MS_SIMD_NEON) +// not support neon + static inline int ElementCos@SIMD_INSTRUCTION@(int index, const float *input, float *output, const int element_size) { + for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin = SIMD_LD_F32(input + index); + SIMD_ST_F32(output + index, SIMD_COS_F32(vin)); + } + return index; + } + + static inline int ElementLog@SIMD_INSTRUCTION@(int index, const float *input, float *output, const int element_size) { + for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin = SIMD_LD_F32(input + index); + SIMD_ST_F32(output + index, SIMD_LOG_F32(vin)); + } + return index; + } +#endif + +static inline int ElementSquare@SIMD_INSTRUCTION@(int index, const float *input, float *output, const int element_size) { + for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin = SIMD_LD_F32(input + index); + SIMD_ST_F32(output + index, SIMD_MUL_F32(vin, vin)); + } + return index; +} + +static inline int ElementSqrt@SIMD_INSTRUCTION@(int index, const float *input, float *output, const int element_size) { + for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(output + index, SIMD_SQRT_F32(SIMD_LD_F32(input + index))); + } + return index; +} + +static inline int ElementRsqrt@SIMD_INSTRUCTION@(int index, const float *input, float *output, const int element_size) { + for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(output + index, SIMD_RSQRT_F32(SIMD_LD_F32(input + index))); + } + return index; +} + +static inline int ElementMish@SIMD_INSTRUCTION@(int index, const float *input, float *output, const int element_size) { + SIMD_F32 one = SIMD_MOV_F32(1.0f); + for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 exp_add_one = SIMD_ADD_F32(SIMD_EXP_F32(SIMD_LD_F32(input + index)), one); + SIMD_F32 exp_pow = SIMD_MUL_F32(exp_add_one, exp_add_one); + SIMD_ST_F32(output + index, SIMD_MUL_F32(SIMD_LD_F32(input + index), + SIMD_DIV_F32(SIMD_SUB_F32(exp_pow, one), SIMD_ADD_F32(exp_pow, one)))); + } + return index; +} + +#if defined(MS_SIMD_AVX) || defined(MS_SIMD_SSE) +// avx512 dont support round fp32 instruction +static inline int ElementRound@SIMD_INSTRUCTION@(int index, const float *input, float *output, const int element_size) { + for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(output + index, SIMD_ROUND_F32(SIMD_LD_F32(input + index))); + } + return index; +} +#endif + +#ifndef MS_SIMD_NEON +// neon dont support floor fp32 instruction +static inline int ElementFloor@SIMD_INSTRUCTION@(int index, const float *input, float *output, const int element_size) { + for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(output + index, SIMD_FLOOR_F32(SIMD_LD_F32(input + index))); + } + return index; +} +#endif + +#ifndef MS_SIMD_NEON +static inline int ElementCeil@SIMD_INSTRUCTION@(int index, const float *input, float *output, const int element_size) { + for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(output + index, SIMD_CEIL_F32(SIMD_LD_F32(input + index))); + } + return index; +} +#endif + +static inline int ElementNegative@SIMD_INSTRUCTION@(int index, const float *input, float *output, const int element_size) { + for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(output + index, SIMD_MUL_N_F32(SIMD_LD_F32(input + index), -1.0f)); + } + return index; +} + +static inline int ElementNegativeInt@SIMD_INSTRUCTION@(int index, const int32_t *input, int32_t *output, const int element_size) { + for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_EPI32(output + index, SIMD_MUL_N_EPI32(SIMD_LD_EPI32(input + index), -1)); + } + return index; +} + +static inline int ElementReciprocal@SIMD_INSTRUCTION@(int index, const float *input, float *output, const int element_size) { + SIMD_F32 num1 = SIMD_MOV_F32(1.0f); + for (int block_max_size = element_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(output + index, SIMD_DIV_F32(num1, SIMD_LD_F32(input + index))); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/attention_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/attention_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..e86345bdf19d39b57e36bcbbacb893e57a177308 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/attention_fp32.c @@ -0,0 +1,581 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/attention_fp32.h" +#include +#include +#include "nnacl/fp32/pack_fp32.h" +#include "nnacl/fp32/matmul_fp32.h" +#include "nnacl/fp32/add_fp32.h" +#include "nnacl/fp32/transpose_fp32.h" +#include "nnacl/transpose_parameter.h" +#include "nnacl/fp32/softmax_fp32.h" +#include "nnacl/errorcode.h" + +int InitMatrix(Matrix *matrix, int batch, int row, int col, bool is_trans) { + if (matrix == NULL) { + return NNACL_NULL_PTR; + } + matrix->batch_ = batch; + matrix->row_ = row; + matrix->col_ = col; + matrix->is_transpose_ = is_trans; + matrix->data_ = NULL; + matrix->packed_data_ = NULL; + return NNACL_OK; +} + +size_t LeftMatrixPackElementSize(Matrix *matrix, int row_tile) { + if (matrix == NULL) { + return 0; + } + int real_row = matrix->is_transpose_ ? matrix->col_ : matrix->row_; + int deep = matrix->is_transpose_ ? matrix->row_ : matrix->col_; + bool vec_matmul = real_row == 1; + int row_align = vec_matmul ? 1 : UP_ROUND(real_row, row_tile); + int dst_area = row_align * deep; + matrix->packed_row_ = row_align; + matrix->packed_col_ = deep; + return matrix->batch_ * dst_area; +} + +size_t RightMatrixPackElementSize(Matrix *matrix, int col_tile) { + if (matrix == NULL) { + return 0; + } + int deep = matrix->is_transpose_ ? matrix->col_ : matrix->row_; + int real_col = matrix->is_transpose_ ? matrix->row_ : matrix->col_; + bool vec_matmul = deep == 1; + int col_align = vec_matmul ? real_col : UP_ROUND(real_col, col_tile); + int dst_area = deep * col_align; + matrix->packed_row_ = deep; + matrix->packed_col_ = col_align; + return matrix->batch_ * dst_area; +} + +int PackLeftMatrix(Matrix *matrix, int row_tile) { + if (matrix == NULL || matrix->data_ == NULL || row_tile == 0) { + return NNACL_NULL_PTR; + } + int real_row = matrix->is_transpose_ ? matrix->col_ : matrix->row_; + int deep = matrix->is_transpose_ ? matrix->row_ : matrix->col_; + bool vec_matmul = real_row == 1; + int row_align = vec_matmul ? 1 : UP_ROUND(real_row, row_tile); + int src_area = matrix->row_ * matrix->col_; + int dst_area = row_align * deep; + bool malloced = false; + if (matrix->packed_data_ == NULL) { + matrix->packed_data_ = (float *)malloc(dst_area * matrix->batch_ * sizeof(float)); + if (matrix->packed_data_ == NULL) { + return NNACL_NULL_PTR; + } + malloced = true; + } + + if (vec_matmul) { + memcpy(matrix->packed_data_, matrix->data_, matrix->batch_ * dst_area * sizeof(float)); + } else { + for (int i = 0; i < matrix->batch_; i++) { + const float *cur_src = matrix->data_ + i * src_area; + float *cur_dst = matrix->packed_data_ + i * dst_area; + switch (row_tile) { + case C6NUM: + if (matrix->is_transpose_) { + RowMajor2Row6Major(cur_src, cur_dst, real_row, deep); + } else { + RowMajor2Col6Major(cur_src, cur_dst, real_row, deep); + } + break; + case C4NUM: + if (matrix->is_transpose_) { + RowMajor2Row4Major(cur_src, cur_dst, real_row, deep); + } else { + RowMajor2Col4Major(cur_src, cur_dst, real_row, deep); + } + break; + case C12NUM: + if (matrix->is_transpose_) { + RowMajor2Row12Major(cur_src, cur_dst, real_row, deep); + } else { + RowMajor2Col12Major(cur_src, cur_dst, real_row, deep); + } + break; + default: + if (malloced) { + free(matrix->packed_data_); + matrix->packed_data_ = NULL; + return NNACL_ERR; + } + break; + } + } + } + matrix->packed_row_ = row_align; + matrix->packed_col_ = deep; + return NNACL_OK; +} + +int PackRightMatrix(Matrix *matrix, int col_tile) { + if (matrix == NULL || matrix->data_ == NULL || col_tile == 0) { + return NNACL_NULL_PTR; + } + int deep = matrix->is_transpose_ ? matrix->col_ : matrix->row_; + int real_col = matrix->is_transpose_ ? matrix->row_ : matrix->col_; + bool vec_matmul = deep == 1; + int col_align = vec_matmul ? real_col : UP_ROUND(real_col, col_tile); + int src_area = matrix->row_ * matrix->col_; + int dst_area = deep * col_align; + bool malloced = false; + if (matrix->packed_data_ == NULL) { + matrix->packed_data_ = (float *)malloc(dst_area * matrix->batch_ * sizeof(float)); + if (matrix->packed_data_ == NULL) { + return NNACL_NULL_PTR; + } + malloced = true; + } + if (vec_matmul) { + memcpy(matrix->packed_data_, matrix->data_, matrix->batch_ * dst_area * sizeof(float)); + } else { + for (int i = 0; i < matrix->batch_; i++) { + const float *cur_src = matrix->data_ + i * src_area; + float *cur_dst = matrix->packed_data_ + i * dst_area; + switch (col_tile) { + case C16NUM: + if (matrix->is_transpose_) { + RowMajor2Col16Major(cur_src, cur_dst, deep, real_col); + } else { + RowMajor2Row16Major(cur_src, cur_dst, deep, real_col); + } + break; + case C4NUM: + if (matrix->is_transpose_) { + RowMajor2Col4Major(cur_src, cur_dst, deep, real_col); + } else { + RowMajor2Row4Major(cur_src, cur_dst, deep, real_col); + } + break; + case C8NUM: + if (matrix->is_transpose_) { + RowMajor2Col8Major(cur_src, cur_dst, deep, real_col); + } else { + RowMajor2Row8Major(cur_src, cur_dst, deep, real_col); + } + break; + default: + if (malloced) { + free(matrix->packed_data_); + matrix->packed_data_ = NULL; + return NNACL_ERR; + } + break; + } + } + } + matrix->packed_row_ = deep; + matrix->packed_col_ = col_align; + return NNACL_OK; +} + +int PackAttentionBias(Matrix *matrix, int tile) { + if (matrix == NULL || matrix->batch_ != 1 || matrix->row_ != 1 || matrix->data_ == NULL) { + return NNACL_PARAM_INVALID; + } + if (tile == 0) { + return NNACL_OK; + } + int size = matrix->col_; + float *src = matrix->data_; + int size_align = UP_ROUND(size, tile); + if (size_align <= 0) { + return NNACL_ERR; + } + matrix->packed_data_ = (float *)malloc(size_align * sizeof(float)); + if (matrix->packed_data_ == NULL) { + return NNACL_NULL_PTR; + } + matrix->packed_row_ = matrix->row_; + matrix->packed_col_ = size_align; + memset(matrix->packed_data_, 0, size_align * sizeof(float)); + memcpy(matrix->packed_data_, src, size * sizeof(float)); + return NNACL_OK; +} + +static void RelativeShiftPad(const float *input_data, float *output_data, const int32_t *input_shape, int tid, + int thread_num) { + int row = input_shape[0]; + int col = input_shape[1]; + int out_area = row * (col + 1); + memset(output_data, 0, out_area * sizeof(float)); + for (int r = tid; r < row; r += thread_num) { + float *dst = output_data + r * (col + 1); + const float *src = input_data + r * col; + memcpy(dst, src, col * sizeof(float)); + } + int tile = row % thread_num; + for (int r = row - tile; r < row; r++) { + float *dst = output_data + r * (col + 1); + const float *src = input_data + r * col; + memcpy(dst, src, col * sizeof(float)); + } +} + +static void RelativeShiftSlice(const float *input_data, float *output_data, const int32_t *input_shape, int tid, + int thread_num) { + int row = input_shape[0]; + int col = input_shape[1]; + int begin = row; + memset(output_data, 0, row * row * sizeof(float)); + for (int r = tid; r < row; r += thread_num) { + float *dst = output_data + r * row; + const float *src = input_data + r * col + begin; + memcpy(dst, src, (col / 2) * sizeof(float)); + } + int tile = row % thread_num; + for (int r = row - tile; r < row; r++) { + float *dst = output_data + r * row; + const float *src = input_data + r * col + begin; + memcpy(dst, src, (col / 2) * sizeof(float)); + } +} + +static void RelativeShift(const Matrix *x, float *pad_buf, float *slice_buf) { + int x_area = x->row_ * x->col_; + int pad_area = x->row_ * (x->col_ + 1); + int slice_area = x->row_ * (x->col_ / 2); + int input_shape[] = {x->row_, x->col_}; + memset(slice_buf, 0, x->batch_ * x->row_ * (x->col_ / 2) * sizeof(float)); + for (int i = 0; i < x->batch_; i++) { + float *cur_x_data = x->data_ + i * x_area; + memset(pad_buf, 0, pad_area * sizeof(float)); + // pad: [row, col + 1] + RelativeShiftPad(cur_x_data, pad_buf, input_shape, 0, 1); + // reshape: [col + 1, row] + // slice last row: [col, row] + // reshape: [row, col] + // slice col -> [row, row + col / 2]: [row, col / 2] + float *cur_slice_data = slice_buf + i * slice_area; + RelativeShiftSlice(pad_buf, cur_slice_data, input_shape, 0, 1); + } +} + +static void ElementOptAddDiv(const float *input0, const float *input1, const float input2, float *output, + const int batch, const int area) { + int index = 0; + const float mul = 1 / input2; + for (int b = 0; b < batch; b++) { + const float *cur_input0 = input0 + b * area; + const float *cur_input1 = input1 + b * area; + float *cur_output = output + b * area; +#ifdef ENABLE_NEON + for (; index <= area - 4; index += C4NUM) { + float32x4_t vin0 = vld1q_f32(cur_input0 + index); + float32x4_t vin1 = vld1q_f32(cur_input1 + index); + float32x4_t vout = vaddq_f32(vin0, vin1); + vout = vmulq_n_f32(vout, mul); + vst1q_f32(cur_output + index, vout); + } +#endif + for (; index < area; index++) { + cur_output[index] += (cur_input0[index] + cur_input1[index]) * mul; + } + } +} + +static bool GetTransposeParameter(TransposeParameter *param, const int in_shape[], int in_shape_len, + const int out_shape[], int out_shape_len, const int perm[], int perm_len) { + param->num_axes_ = perm_len; + size_t shape_size = 1; + for (int i = 0; i < perm_len; i++) { + param->perm_[i] = perm[i]; + shape_size *= perm[i]; // check overflow + } + param->data_num_ = (int)shape_size; // check overflow + param->strides_[param->num_axes_ - 1] = 1; + param->out_strides_[param->num_axes_ - 1] = 1; + if (param->num_axes_ - 1 >= in_shape_len) { + return false; + } + if (param->num_axes_ - 1 >= out_shape_len) { + return false; + } + for (int i = param->num_axes_ - 2; i >= 0; i--) { + param->strides_[i] = in_shape[i + 1] * param->strides_[i + 1]; + param->out_strides_[i] = out_shape[i + 1] * param->out_strides_[i + 1]; + } + return true; +} + +void QWithPosition(RelativePositionAttentionParameter *param, Matrix *q_mat, const Matrix *wq_mat, Matrix *bq_mat, + Matrix *q2wq_mat, Matrix *pu_mat, Matrix *pv_mat, Matrix *q2wq_with_pos_mat, + Matrix *q2wq_with_pu_trans_mat, Matrix *q2wq_with_pv_trans_mat) { + int num_heads = param->num_heads_; + int d_model = param->d_model_; + int batch = param->batch_; + int depth = d_model / num_heads; + // Q * WQ + int q_area = q_mat->packed_row_ * q_mat->packed_col_; + int wq_area = wq_mat->packed_row_ * wq_mat->packed_col_; + int q2wq_area = q2wq_mat->row_ * q2wq_mat->col_ * q2wq_mat->batch_ / param->batch_; + float *q2wq_data = q2wq_mat->data_; + memset(q2wq_data, 0, param->batch_ * q2wq_area * sizeof(float)); + for (int i = 0; i < param->batch_; i++) { + float *cur_q = q_mat->packed_data_ + i * q_area; + float *cur_wq = wq_mat->packed_data_ + i * wq_area; + float *cur_q2wq = q2wq_data + i * q2wq_area; + MatMulOpt(cur_q, cur_wq, cur_q2wq, bq_mat->packed_data_, ActType_No, q_mat->col_, q_mat->row_, wq_mat->col_, + wq_mat->col_, OutType_Nhwc); + } + // transpose param init + TransposeParameter q_with_pos_trans_param; + int q_with_pos_trans_in_shape[] = {batch, param->q_seq_, num_heads, depth}; + int q_with_pos_trans_out_shape[] = {batch, num_heads, param->q_seq_, depth}; + int q_with_pos_perm[] = {0, 2, 1, 3}; + (void)GetTransposeParameter(&q_with_pos_trans_param, q_with_pos_trans_in_shape, 4, q_with_pos_trans_out_shape, 4, + q_with_pos_perm, 4); + int q2wq_reshaped_area = q2wq_mat->row_ * q2wq_mat->col_; + // Q_WQ + POS_U + { + float *q_with_pu = q2wq_with_pos_mat->data_; + int q_with_pu_area = q2wq_with_pos_mat->row_ * q2wq_with_pos_mat->col_; + memset(q_with_pu, 0, q2wq_with_pos_mat->batch_ * q_with_pu_area * sizeof(float)); + for (int i = 0; i < q2wq_with_pos_mat->batch_; i++) { + float *cur_qw = q2wq_data + i * q2wq_reshaped_area; + float *cur_q_with_pu = q_with_pu + i * q_with_pu_area; + ElementAdd(cur_qw, pu_mat->packed_data_, cur_q_with_pu, q_with_pu_area); + } + // Q_WITH_U perm [0,2,1,3] + float *q_with_pu_trans = q2wq_with_pu_trans_mat->data_; + size_t q_with_pu_trans_data_size = (size_t)(q2wq_with_pu_trans_mat->batch_) * + (size_t)(q2wq_with_pu_trans_mat->row_) * (size_t)(q2wq_with_pu_trans_mat->col_) * + sizeof(float); + memset(q_with_pu_trans, 0, q_with_pu_trans_data_size); + TransposeDimsFp32(q_with_pu, q_with_pu_trans, q_with_pos_trans_out_shape, q_with_pos_trans_param.perm_, + q_with_pos_trans_param.strides_, q_with_pos_trans_param.out_strides_, + q_with_pos_trans_param.num_axes_, 0, 1); + } + + // Q_WQ + POS_V + { + float *q_with_pv = q2wq_with_pos_mat->data_; + int q_with_pv_area = q2wq_with_pos_mat->row_ * q2wq_with_pos_mat->col_; + memset(q_with_pv, 0, q2wq_with_pos_mat->batch_ * q_with_pv_area * sizeof(float)); + for (int i = 0; i < q2wq_with_pos_mat->batch_; i++) { + float *cur_qw = q2wq_data + i * q2wq_reshaped_area; + float *cur_q_with_pv = q_with_pv + i * q_with_pv_area; + ElementAdd(cur_qw, pv_mat->packed_data_, cur_q_with_pv, q_with_pv_area); + } + // Q_WITH_V perm [0,2,1,3] + float *q_with_pv_trans = q2wq_with_pv_trans_mat->data_; + size_t q_with_pv_trans_data_size = (size_t)(q2wq_with_pv_trans_mat->batch_) * + (size_t)(q2wq_with_pv_trans_mat->row_) * (size_t)(q2wq_with_pv_trans_mat->col_) * + sizeof(float); + memset(q_with_pv_trans, 0, q_with_pv_trans_data_size); + TransposeDimsFp32(q_with_pv, q_with_pv_trans, q_with_pos_trans_out_shape, q_with_pos_trans_param.perm_, + q_with_pos_trans_param.strides_, q_with_pos_trans_param.out_strides_, + q_with_pos_trans_param.num_axes_, 0, 1); + } +} + +void KMulWeightK(RelativePositionAttentionParameter *param, Matrix *k_mat, const Matrix *wk_mat, Matrix *bk_mat, + Matrix *k2wk_mat, Matrix *k2wk_trans_mat) { + int num_heads = param->num_heads_; + int d_model = param->d_model_; + int batch = param->batch_; + int depth = d_model / num_heads; + // K * WK + int k_area = k_mat->packed_row_ * k_mat->packed_col_; + int wk_area = wk_mat->packed_row_ * wk_mat->packed_col_; + int k2wk_area = k2wk_mat->row_ * k2wk_mat->col_ * k2wk_mat->batch_ / param->batch_; + float *k2wk = k2wk_mat->data_; + memset(k2wk, 0, param->batch_ * k2wk_area * sizeof(float)); + for (int i = 0; i < param->batch_; i++) { + float *cur_k = k_mat->packed_data_ + i * k_area; + float *cur_wk = wk_mat->packed_data_ + i * wk_area; + float *cur_k2wk = k2wk + i * k2wk_area; + MatMulOpt(cur_k, cur_wk, cur_k2wk, bk_mat->packed_data_, ActType_No, k_mat->col_, k_mat->row_, wk_mat->col_, + wk_mat->col_, OutType_Nhwc); + } + // K * WK perm [0,2,3,1] + float *k2wk_trans_data = k2wk_trans_mat->data_; + int k2wk_trans_area = k2wk_trans_mat->row_ * k2wk_trans_mat->col_; + memset(k2wk_trans_data, 0, k2wk_trans_mat->batch_ * k2wk_trans_area * sizeof(float)); + TransposeParameter k2wk_trans_param; + int k2wk_in_shape[] = {batch, param->k_seq_, num_heads, depth}; + int k2wk_out_shape[] = {batch, num_heads, depth, param->k_seq_}; + int k2wk_perm[] = {0, 2, 3, 1}; + (void)GetTransposeParameter(&k2wk_trans_param, k2wk_in_shape, 4, k2wk_out_shape, 4, k2wk_perm, 4); + TransposeDimsFp32(k2wk, k2wk_trans_data, k2wk_out_shape, k2wk_trans_param.perm_, k2wk_trans_param.strides_, + k2wk_trans_param.out_strides_, k2wk_trans_param.num_axes_, 0, 1); +} + +void VMulWeightV(RelativePositionAttentionParameter *param, Matrix *v_mat, const Matrix *wv_mat, Matrix *bv_mat, + Matrix *v2wv_mat, Matrix *v2wv_trans_mat) { + int num_heads = param->num_heads_; + int d_model = param->d_model_; + int batch = param->batch_; + int depth = d_model / num_heads; + // V * WV + int v_area = v_mat->packed_row_ * v_mat->packed_col_; + int wv_area = wv_mat->packed_row_ * wv_mat->packed_col_; + int v2wv_area = v2wv_mat->row_ * v2wv_mat->col_ * v2wv_mat->batch_ / param->batch_; + float *v2wv = v2wv_mat->data_; + memset(v2wv, 0, param->batch_ * v2wv_area * sizeof(float)); + for (int i = 0; i < param->batch_; i++) { + float *cur_v = v_mat->packed_data_ + i * v_area; + float *cur_wv = wv_mat->packed_data_ + i * wv_area; + float *cur_v2wv = v2wv + i * v2wv_area; + MatMulOpt(cur_v, cur_wv, cur_v2wv, bv_mat->packed_data_, ActType_No, v_mat->col_, v_mat->row_, wv_mat->col_, + wv_mat->col_, OutType_Nhwc); + } + // V * WV perm [0,2,1,3] + float *v2wv_trans_data = v2wv_trans_mat->data_; + int v2wv_trans_area = v2wv_trans_mat->row_ * v2wv_trans_mat->col_; + memset(v2wv_trans_data, 0, v2wv_trans_mat->batch_ * v2wv_trans_area * sizeof(float)); + TransposeParameter v2wv_trans_param; + int v2wv_in_shape[] = {batch, param->v_seq_, num_heads, depth}; + int v2wv_out_shape[] = {batch, num_heads, param->v_seq_, depth}; + int v2wv_perm[] = {0, 2, 1, 3}; + (void)GetTransposeParameter(&v2wv_trans_param, v2wv_in_shape, 4, v2wv_out_shape, 4, v2wv_perm, 4); + TransposeDimsFp32(v2wv, v2wv_trans_data, v2wv_out_shape, v2wv_trans_param.perm_, v2wv_trans_param.strides_, + v2wv_trans_param.out_strides_, v2wv_trans_param.num_axes_, 0, 1); +} + +void PMulWeightP(RelativePositionAttentionParameter *param, Matrix *p_mat, const Matrix *wp_mat, Matrix *p2wp_mat, + Matrix *p2wp_trans_mat) { + int num_heads = param->num_heads_; + int d_model = param->d_model_; + int batch = param->batch_; + int depth = d_model / num_heads; + + // P * WP + int p_area = p_mat->packed_row_ * p_mat->packed_col_; + int wp_area = wp_mat->packed_row_ * wp_mat->packed_col_; + int p2wp_area = p2wp_mat->row_ * p2wp_mat->col_ * p2wp_mat->batch_ / param->batch_; + float *p2wp_data = p2wp_mat->data_; + memset(p2wp_data, 0, param->batch_ * p2wp_area * sizeof(float)); + for (int i = 0; i < param->batch_; i++) { + float *cur_p = p_mat->packed_data_ + i * p_area; + float *cur_wp = wp_mat->packed_data_ + i * wp_area; + float *cur_p2wp = p2wp_data + i * p2wp_area; + MatMulOpt(cur_p, cur_wp, cur_p2wp, NULL, ActType_No, p_mat->col_, p_mat->row_, wp_mat->col_, wp_mat->col_, + OutType_Nhwc); + } + // P * WP perm [0,2,3,1] + float *p2wp_trans_data = p2wp_trans_mat->data_; + int p2wp_trans_area = p2wp_trans_mat->row_ * p2wp_trans_mat->col_; + memset(p2wp_trans_data, 0, p2wp_trans_mat->batch_ * p2wp_trans_area * sizeof(float)); + TransposeParameter p2wp_trans_param; + int p2wp_in_shape[] = {batch, param->p_seq_, num_heads, depth}; + int p2wp_out_shape[] = {batch, num_heads, depth, param->p_seq_}; + int p2wp_perm[] = {0, 2, 3, 1}; + (void)GetTransposeParameter(&p2wp_trans_param, p2wp_in_shape, 4, p2wp_out_shape, 4, p2wp_perm, 4); + TransposeDimsFp32(p2wp_data, p2wp_trans_data, p2wp_out_shape, p2wp_trans_param.perm_, p2wp_trans_param.strides_, + p2wp_trans_param.out_strides_, p2wp_trans_param.num_axes_, 0, 1); +} + +void CalculateLogits(RelativePositionAttentionParameter *param, Matrix *q2wq_with_pu_trans_mat, + Matrix *q2wq_with_pv_trans_mat, Matrix *k2wk_trans_mat, Matrix *p2wp_trans_mat, + Matrix *logits_with_u_mat, Matrix *logits_with_v_mat, Matrix *logits_with_v_pad_mat, + Matrix *logits_with_v_shifted_mat, Matrix *logits_mat) { + int num_heads = param->num_heads_; + int d_model = param->d_model_; + int depth = d_model / num_heads; + + // pack Q_WITH_U as left_matrix + // since we malloc dst data, pack function can not be failed + (void)PackLeftMatrix(q2wq_with_pu_trans_mat, param->row_tile_); + // pack Q_WITH_V as left_matrix + (void)PackLeftMatrix(q2wq_with_pv_trans_mat, param->row_tile_); + // pack K * WK as right_matrix + (void)PackRightMatrix(k2wk_trans_mat, param->col_tile_); + // pack P * WP as right_matrix + (void)PackRightMatrix(p2wp_trans_mat, param->col_tile_); + + // q_with_pu * k = logits_with_u + MatMulOpt(q2wq_with_pu_trans_mat->packed_data_, k2wk_trans_mat->packed_data_, logits_with_u_mat->data_, NULL, + ActType_No, q2wq_with_pu_trans_mat->col_, logits_with_u_mat->row_, logits_with_u_mat->col_, + logits_with_u_mat->col_, OutType_Nhwc); + + // q_with_pv * p = logits_with_v + MatMulOpt(q2wq_with_pv_trans_mat->packed_data_, p2wp_trans_mat->packed_data_, logits_with_v_mat->data_, NULL, + ActType_No, q2wq_with_pv_trans_mat->col_, logits_with_v_mat->row_, logits_with_v_mat->col_, + logits_with_v_mat->col_, OutType_Nhwc); + // relative shift logits_with_v + float *pad_buf = logits_with_v_pad_mat->data_; + float *logits_with_v_shifted_data = logits_with_v_shifted_mat->data_; + RelativeShift(logits_with_v_mat, pad_buf, logits_with_v_shifted_data); + // logits = (logits_with_u + logits_with_v) / sqrt(depth) + float *logits_buffer = logits_mat->data_; + ElementOptAddDiv(logits_with_u_mat->data_, logits_with_v_shifted_data, 1 / sqrt(depth), logits_buffer, + logits_with_u_mat->batch_, logits_with_u_mat->row_ * logits_with_u_mat->col_); +} + +void RelPosAttention(RelativePositionAttentionParameter *param, Matrix *logits_mat, Matrix *softmax_mat, + Matrix *v2wv_trans_mat, Matrix *logits2v_mat, Matrix *logits2v_trans_mat, const Matrix *wo_mat, + Matrix *bo_mat, Matrix *output_mat) { + int num_heads = param->num_heads_; + int d_model = param->d_model_; + int batch = param->batch_; + int depth = d_model / num_heads; + float *logits_buffer = logits_mat->data_; + // softmax(logits) + SoftmaxLastAxis(logits_buffer, softmax_mat->data_, batch * num_heads * softmax_mat->row_, softmax_mat->col_); + + // logits * v + (void)PackLeftMatrix(softmax_mat, param->row_tile_); + (void)PackRightMatrix(v2wv_trans_mat, param->col_tile_); + int softmax_logits_area = softmax_mat->packed_row_ * softmax_mat->packed_col_; + int v2wv_area = v2wv_trans_mat->packed_row_ * v2wv_trans_mat->packed_col_; + int logits2v_area = logits2v_mat->row_ * logits2v_mat->col_; + float *logits2v_data = logits2v_mat->data_; + memset(logits2v_data, 0, logits2v_mat->batch_ * logits2v_area * sizeof(float)); + for (int i = 0; i < logits2v_mat->batch_; i++) { + float *cur_logits = softmax_mat->packed_data_ + i * softmax_logits_area; + float *cur_v2wv = v2wv_trans_mat->packed_data_ + i * v2wv_area; + float *cur_logits2v = logits2v_data + i * logits2v_area; + MatMulOpt(cur_logits, cur_v2wv, cur_logits2v, NULL, ActType_No, softmax_mat->col_, softmax_mat->row_, + v2wv_trans_mat->col_, v2wv_trans_mat->col_, OutType_Nhwc); + } + // multi_head output perm [0,2,1,3] + float *logits2v_trans_data = logits2v_trans_mat->data_; + int logits2v_trans_area = logits2v_trans_mat->row_ * logits2v_trans_mat->col_; + memset(logits2v_trans_data, 0, logits2v_trans_mat->batch_ * logits2v_trans_area * sizeof(float)); + TransposeParameter logits2v_trans_param; + int logits2v_trans_in_shape[] = {batch, num_heads, param->q_seq_, depth}; + int logits2v_trans_out_shape[] = {batch, param->q_seq_, num_heads, depth}; + int logits2v_trans_perm[] = {0, 2, 1, 3}; + (void)GetTransposeParameter(&logits2v_trans_param, logits2v_trans_in_shape, 4, logits2v_trans_out_shape, 4, + logits2v_trans_perm, 4); + TransposeDimsFp32(logits2v_data, logits2v_trans_data, logits2v_trans_out_shape, logits2v_trans_param.perm_, + logits2v_trans_param.strides_, logits2v_trans_param.out_strides_, logits2v_trans_param.num_axes_, 0, + 1); + // concat = reshape [batch, -1, d_model] + logits2v_trans_mat->batch_ = batch; + logits2v_trans_mat->row_ = param->q_seq_; + logits2v_trans_mat->col_ = param->d_model_; + // * o + (void)PackLeftMatrix(logits2v_trans_mat, param->row_tile_); + int concat_out_area = logits2v_trans_mat->packed_row_ * logits2v_trans_mat->packed_col_; + int wo_area = wo_mat->packed_row_ * wo_mat->packed_col_; + int output_area = output_mat->row_ * output_mat->col_; + for (int i = 0; i < output_mat->batch_; i++) { + float *cur_concat_out = logits2v_trans_mat->packed_data_ + i * concat_out_area; + float *cur_wo = wo_mat->packed_data_ + i * wo_area; + float *cur_output = output_mat->data_ + i * output_area; + MatMulOpt(cur_concat_out, cur_wo, cur_output, bo_mat->packed_data_, ActType_No, logits2v_trans_mat->col_, + logits2v_trans_mat->row_, wo_mat->col_, wo_mat->col_, OutType_Nhwc); + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/attention_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/attention_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..5b713831f1de44db6d023c36b8b71432c9c57f2e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/attention_fp32.h @@ -0,0 +1,72 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_ATTENTION_FP32_H_ +#define MINDSPORE_NNACL_FP32_ATTENTION_FP32_H_ + +#include "nnacl/attention_parameter.h" +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct Matrix { + float *data_; + int row_; + int col_; + float *packed_data_; + int packed_row_; + int packed_col_; + int batch_; + bool is_transpose_; +} Matrix; + +int InitMatrix(Matrix *matrix, int batch, int row, int col, bool is_trans); + +size_t LeftMatrixPackElementSize(Matrix *matrix, int row_tile); + +size_t RightMatrixPackElementSize(Matrix *matrix, int col_tile); + +int PackLeftMatrix(Matrix *matrix, int row_tile); + +int PackRightMatrix(Matrix *matrix, int col_tile); + +int PackAttentionBias(Matrix *matrix, int tile); + +void QWithPosition(RelativePositionAttentionParameter *param, Matrix *q_mat, const Matrix *wq_mat, Matrix *bq_mat, + Matrix *q2wq_mat, Matrix *pu_mat, Matrix *pv_mat, Matrix *q2wq_with_pos_mat, + Matrix *q2wq_with_pu_trans_mat, Matrix *q2wq_with_pv_trans_mat); + +void KMulWeightK(RelativePositionAttentionParameter *param, Matrix *k_mat, const Matrix *wk_mat, Matrix *bk_mat, + Matrix *k2wk_mat, Matrix *k2wk_trans_mat); + +void VMulWeightV(RelativePositionAttentionParameter *param, Matrix *v_mat, const Matrix *wv_mat, Matrix *bv_mat, + Matrix *v2wv_mat, Matrix *v2wv_trans_mat); + +void PMulWeightP(RelativePositionAttentionParameter *param, Matrix *p_mat, const Matrix *wp_mat, Matrix *p2wp_mat, + Matrix *p2wp_trans_mat); + +void CalculateLogits(RelativePositionAttentionParameter *param, Matrix *q2wq_with_pu_trans_mat, + Matrix *q2wq_with_pv_trans_mat, Matrix *k2wk_trans_mat, Matrix *p2wp_trans_mat, + Matrix *logits_with_u_mat, Matrix *logits_with_v_mat, Matrix *logits_with_v_pad_mat, + Matrix *logits_with_v_shifted_mat, Matrix *logits_mat); + +void RelPosAttention(RelativePositionAttentionParameter *param, Matrix *logits_mat, Matrix *softmax_mat, + Matrix *v2wv_trans_mat, Matrix *logits2v_mat, Matrix *logits2v_trans_mat, const Matrix *wo_mat, + Matrix *bo_mat, Matrix *output_mat); +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FP32_ATTENTION_FP32_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/batchnorm_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/batchnorm_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..78525d52c89da7c0c089e814fe701cffa5e411b7 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/batchnorm_fp32.c @@ -0,0 +1,129 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/batchnorm_fp32.h" +#include +#include "nnacl/op_base.h" +#include "nnacl/batchnorm_fp32_simd.h" +#include "nnacl/kernel/fused_batch_norm.h" +#include "nnacl/tensor_c_utils.h" + +int FusedBatchNormEval(KernelBase *self) { + FusedBatchNormStruct *fused_batch_norm = (FusedBatchNormStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(fused_batch_norm); + + if (fused_batch_norm->trained_) { + TensorC *scale_tensor = fused_batch_norm->bn_.base_.in_[SECOND_INPUT]; + TensorC *offset_tensor = fused_batch_norm->bn_.base_.in_[THIRD_INPUT]; + TensorC *mean_tensor = fused_batch_norm->bn_.base_.in_[FOURTH_INPUT]; + TensorC *var_tensor = fused_batch_norm->bn_.base_.in_[FIFTH_INPUT]; + (void)memcpy(fused_batch_norm->scale_, scale_tensor->data_, NNACLGetSize(scale_tensor)); + (void)memcpy(fused_batch_norm->offset_, offset_tensor->data_, NNACLGetSize(offset_tensor)); + (void)memcpy(fused_batch_norm->bn_.mean_, mean_tensor->data_, NNACLGetSize(mean_tensor)); + (void)memcpy(fused_batch_norm->bn_.variance_, var_tensor->data_, NNACLGetSize(var_tensor)); + } + return NNACL_OK; +} + +void BatchNormSetupVirtualBatch(KernelBase *self, int virtual_batch_multiplier, int momentum) { + BatchNormStruct *bn = (BatchNormStruct *)self; + NNACL_CHECK_NULL_RETURN_VOID(bn); + if (virtual_batch_multiplier > 0) { + float new_momentum = (momentum < 0.0f) ? (bn->momentum_ / virtual_batch_multiplier) : momentum; + bn->momentum_ = new_momentum; + } + return; +} + +void BatchNormFp32(const float *input, const float *mean, const float *variance, const BatchNormStruct *param, + int task_id, int thread_num, float *output) { + int units_per_thread = UP_DIV(param->unit_, thread_num); + int completed_units = task_id * units_per_thread; + int cur_unit = MSMIN(units_per_thread, param->unit_ - completed_units); + int channel = param->channel_; + int cur_offset = completed_units * channel; + float epsilon = param->epsilon_; + + for (int i = 0; i < cur_unit; i++) { + const float *unit_input = input + cur_offset; + float *unit_output = output + cur_offset; + int c = 0; + + SIMD_RUN_NO_SCALAR(BatchNormFp32, c, unit_input, mean, variance, channel, epsilon, unit_output); + + for (; c < channel; c++) { + float variance_sqrt = sqrtf(variance[c] + epsilon); + unit_output[c] = (unit_input[c] - mean[c]) / variance_sqrt; + } + cur_offset += channel; + } +} + +void FusedBatchNormFp32(const float *input, const float *scale, const float *offset, const float *mean, + const float *variance, const BatchNormStruct *param, int task_id, int thread_num, + float *output) { + int units_per_thread = UP_DIV(param->unit_, thread_num); + int completed_units = task_id * units_per_thread; + int cur_unit = MSMIN(units_per_thread, param->unit_ - completed_units); + int channel = param->channel_; + float epsilon = param->epsilon_; + int cur_offset = completed_units * channel; + + for (int i = 0; i < cur_unit; i++) { + const float *unit_input = input + cur_offset; + float *unit_output = output + cur_offset; + int c = 0; + + SIMD_RUN_NO_SCALAR(FusedBatchNormFp32, c, unit_input, scale, offset, mean, variance, channel, epsilon, unit_output); + + for (; c < channel; c++) { + float variance_sqrt = sqrtf(variance[c] + epsilon); + float norm_val = (unit_input[c] - mean[c]) / variance_sqrt; + unit_output[c] = norm_val * scale[c] + offset[c]; + } + cur_offset += channel; + } +} + +void FusedBatchNormFp32MeanVar(const float *input, float *run_mean, float *run_var, const BatchNormStruct *param, + float *save_mean, float *save_var, bool isBatchNorm2d) { + const float N = (float)param->unit_; + const float VN = N; + const float VNUB = (isBatchNorm2d == false) ? N : ((N > 1.0f) ? (N - 1.0f) : 1.0f); + const float momentum = (1.0f - param->momentum_); + + for (int i = 0; i < param->unit_; i++) { + for (int c = 0; c < param->channel_; c++) { + int idx = i * param->channel_ + c; + run_mean[c] += input[idx]; + } + } + for (int c = 0; c < param->channel_; c++) { + run_mean[c] /= N; + } + for (int i = 0; i < param->unit_; i++) { + for (int c = 0; c < param->channel_; c++) { + int idx = i * param->channel_ + c; + run_var[c] += (input[idx] - run_mean[c]) * (input[idx] - run_mean[c]); + } + } + for (int c = 0; c < param->channel_; c++) { + float unbiased_var = (run_var[c] / VNUB); + run_var[c] = (run_var[c] / VN); + save_mean[c] = momentum * save_mean[c] + (1.0f - momentum) * run_mean[c]; + save_var[c] = momentum * save_var[c] + (1.0f - momentum) * unbiased_var; + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/batchnorm_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/batchnorm_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..cbb5d9a0555c9aa6ba8cc7ee6f3ed5a4ab57708c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/batchnorm_fp32.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_BATCHNORM_FP32_H_ +#define NNACL_FP32_BATCHNORM_FP32_H_ + +#include "nnacl/kernel/batch_norm.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void BatchNormSetupVirtualBatch(KernelBase *self, int virtual_batch_multiplier, int momentum); +void BatchNormFp32(const float *input, const float *mean, const float *variance, const BatchNormStruct *param, + int task_id, int thread_num, float *output); + +int FusedBatchNormEval(KernelBase *self); +void FusedBatchNormFp32(const float *input, const float *scale, const float *offset, const float *mean, + const float *variance, const BatchNormStruct *param, int task_id, int thread_num, + float *output); +void FusedBatchNormFp32MeanVar(const float *input, float *run_mean, float *run_var, const BatchNormStruct *param, + float *save_mean, float *save_var, bool isBatchNorm2d); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_BATCHNORM_FP32_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/batchnorm_fp32_simd.h.in b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/batchnorm_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..90bc0f93228e8d3da3e7c02290734dd862e74a4f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/batchnorm_fp32_simd.h.in @@ -0,0 +1,60 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_ACTIVATION_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_ACTIVATION_@SIMD_INSTRUCTION@_H_ + +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int BatchNormFp32@SIMD_INSTRUCTION@(int index, const float *input, const float *mean, + const float *variance, int channel, float epsilon, float *output) { + for (int block_max_size = channel - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 input_data = SIMD_LD_F32(input + index); + SIMD_F32 mean_ = SIMD_LD_F32(mean + index); + SIMD_F32 variance_ = SIMD_LD_F32(variance + index); + SIMD_F32 variance_sqrt = SIMD_SQRT_F32(SIMD_ADD_F32(variance_, SIMD_MOV_F32(epsilon))); + SIMD_F32 output_data = SIMD_DIV_F32(SIMD_SUB_F32(input_data, mean_), variance_sqrt); + SIMD_ST_F32(output + index, output_data); + } + return index; +} + +static inline int FusedBatchNormFp32@SIMD_INSTRUCTION@(int index, const float *input, const float *scale, + const float *offset, const float *mean, const float *variance, int channel, float epsilon, float *output) { + for (int block_max_size = channel - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 input_data = SIMD_LD_F32(input + index); + SIMD_F32 scale_ = SIMD_LD_F32(scale + index); + SIMD_F32 offset_ = SIMD_LD_F32(offset + index); + SIMD_F32 mean_ = SIMD_LD_F32(mean + index); + SIMD_F32 variance_ = SIMD_LD_F32(variance + index); + SIMD_F32 variance_sqrt = SIMD_SQRT_F32(SIMD_ADD_F32(variance_, SIMD_MOV_F32(epsilon))); + SIMD_F32 norm_val = SIMD_DIV_F32(SIMD_SUB_F32(input_data, mean_), variance_sqrt); + SIMD_F32 output_data = SIMD_ADD_F32(SIMD_MUL_F32(norm_val, scale_), offset_); + SIMD_ST_F32(output + index, output_data); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/bce_with_logits_loss_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/bce_with_logits_loss_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..d52ea48837a1a164131da6ca76f1662979f044e2 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/bce_with_logits_loss_fp32.h @@ -0,0 +1,29 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_ACTIVATION_H_ +#define MINDSPORE_NNACL_FP32_ACTIVATION_H_ + +#include "nnacl/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +void BCEWithLogitLoss(const float *logits, const float *label, const float *weight, const float *pos_weight, int length, + bool reduction, float *output, float *reduction_sum); +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FP32_ACTIVATION_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/bce_with_logits_loss_fp32_simd.h.in b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/bce_with_logits_loss_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..b5deeae44314fae27c7a7dbfa46163b8ddfc6ca1 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/bce_with_logits_loss_fp32_simd.h.in @@ -0,0 +1,62 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_BCE_WITH_LOGITS_LOSS_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_BCE_WITH_LOGITS_LOSS_@SIMD_INSTRUCTION@_H_ + +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int BCEWithLogitLoss@SIMD_INSTRUCTION@(int index, const float *logits, const float *label, + const float *weight, const float *pos_weight, int length, bool reduction, float *output, + float *reduction_sum) { + SIMD_F32 zero = SIMD_SET0_F32; + SIMD_F32 ones = SIMD_MOV_F32(1.0f); + SIMD_F32 middle_output = SIMD_SET0_F32; + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 logits_tmp = SIMD_LD_F32(logits + index); + SIMD_F32 label_tmp = SIMD_LD_F32(label + index); + SIMD_F32 weight_tmp = SIMD_LD_F32(weight + index); + SIMD_F32 pos_weight_tmp = SIMD_LD_F32(pos_weight + index); + SIMD_F32 neg_logits_tmp = SIMD_SUB_F32(zero, logits_tmp); + SIMD_F32 max_value = neg_logits_tmp; + max_value = SIMD_MAX_F32(max_value, zero); + SIMD_F32 neg_max_value = SIMD_SUB_F32(zero, max_value); + SIMD_F32 log_weight = SIMD_ADD_F32(SIMD_MUL_F32(SIMD_SUB_F32(pos_weight_tmp, ones), label_tmp), ones); + SIMD_F32 log_exp_value = + SIMD_LOG_F32(SIMD_ADD_F32(SIMD_HEXP_F32(neg_max_value), SIMD_HEXP_F32(SIMD_SUB_F32(neg_logits_tmp, max_value)))); + SIMD_F32 loss = SIMD_ADD_F32(SIMD_MUL_F32(SIMD_SUB_F32(ones, label_tmp), logits_tmp), + SIMD_MUL_F32(log_weight, SIMD_ADD_F32(log_exp_value, max_value))); + if (reduction) { + middle_output = SIMD_FMADD_F32(loss, weight_tmp, middle_output); + } else { + SIMD_ST_F32(output + index, SIMD_MUL_F32(loss, weight_tmp)); + } + } + if (reduction) { + *reduction_sum += SIMD_GET_SUM_F32(middle_output); + } + return index; +} +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/bce_with_loigts_loss_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/bce_with_loigts_loss_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..a948dd33fa8896d2bf69ad57db324ee9f8d3474d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/bce_with_loigts_loss_fp32.c @@ -0,0 +1,45 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/bce_with_logits_loss_fp32.h" +#include "nnacl/bce_with_logits_loss_fp32_simd.h" + +void BCEWithLogitLoss(const float *logits, const float *label, const float *weight, const float *pos_weight, int length, + bool reduction, float *output, float *reduction_sum) { + int i = 0; + float simd_reduction_output = 0.0f; + SIMD_RUN_NO_SCALAR(BCEWithLogitLoss, i, logits, label, weight, pos_weight, length, reduction, output, + &simd_reduction_output); + for (; i < length; ++i) { + float logits_value = logits[i]; + float label_value = label[i]; + float weight_value = weight[i]; + float post_weight_value = pos_weight[i]; + float max_value = -logits_value; + max_value = max_value > 0.f ? max_value : 0.f; + float log_weight = (post_weight_value - 1.0f) * label_value + 1.0f; + float log_exp_value = logf(expf(-max_value) + expf(-logits_value - max_value)); + float loss = (1.0f - label_value) * logits_value + log_weight * (log_exp_value + max_value); + if (reduction) { + simd_reduction_output += loss * weight_value; + } else { + output[i] = loss * weight_value; + } + } + if (reduction) { + *reduction_sum = simd_reduction_output; + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/bias_add.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/bias_add.c new file mode 100644 index 0000000000000000000000000000000000000000..9ea1f611b3c778db9d339c76ab4be289cb1f5862 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/bias_add.c @@ -0,0 +1,123 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/bias_add.h" +#include "nnacl/op_base.h" +#include "nnacl/bias_add_simd.h" + +void BiasAddByInnerCore(const float *input, const float *bias, float *output, int64_t num) { + int64_t index = 0; + + SIMD_RUN_NO_SCALAR(BiasAddByInnerCore, index, input, bias, output, num); + + for (; index < num; ++index) { + output[index] = input[index] + bias[index]; + } +} + +void BiasAddByBatchCore(const float *input, const float *bias, float *output, int64_t num) { + float *output1 = output; + float *output2 = output + num; + float *output3 = output + num * 2; + float *output4 = output + num * 3; + int64_t index = 0; + + SIMD_RUN_NO_SCALAR(BiasAddByBatchCore, index, input, bias, output1, output2, output3, output4, num); + + const float *input_data1 = input; + const float *input_data2 = input + num; + const float *input_data3 = input + num * 2; + const float *input_data4 = input + num * 3; + for (; index < num; ++index) { + output1[index] = input_data1[index] + bias[index]; + output2[index] = input_data2[index] + bias[index]; + output3[index] = input_data3[index] + bias[index]; + output4[index] = input_data4[index] + bias[index]; + } +} + +void DoBiasAddByBatch(const float *input, const float *bias, float *output, int64_t start_inner, int64_t start_outer, + int64_t end_inner, int64_t end_outer, int64_t inner_num) { + const float *cur_bias = bias + start_inner; + if (start_outer == end_outer) { + BiasAddByInnerCore(input, cur_bias, output, end_inner - start_inner); + return; + } + if (start_inner != 0) { + BiasAddByInnerCore(input, cur_bias, output, inner_num - start_inner); + start_outer += 1; + input += inner_num - start_inner; + cur_bias = bias; + output += inner_num - start_inner; + } + int64_t step = C4NUM * inner_num; + for (; start_outer <= end_outer - C4NUM; start_outer += C4NUM) { + BiasAddByBatchCore(input, cur_bias, output, inner_num); + input += step; + output += step; + } + for (; start_outer < end_outer; ++start_outer) { + BiasAddByInnerCore(input, cur_bias, output, inner_num); + input += inner_num; + output += inner_num; + } + BiasAddByInnerCore(input, cur_bias, output, end_inner); +} + +void DoBiasAddByInner(const float *input, const float *bias, float *output, int64_t start_inner, int64_t start_outer, + int64_t end_inner, int64_t end_outer, int64_t inner_num) { + const float *cur_bias = bias + start_inner; + if (start_outer == end_outer) { + BiasAddByInnerCore(input, cur_bias, output, end_inner - start_inner); + return; + } else { + BiasAddByInnerCore(input, cur_bias, output, inner_num - start_inner); + start_outer += 1; + input += inner_num - start_inner; + cur_bias = bias; + output += inner_num - start_inner; + } + if (start_outer == end_outer) { + BiasAddByInnerCore(input, cur_bias, output, end_inner); + return; + } else { + for (; start_outer < end_outer; ++start_outer) { + BiasAddByInnerCore(input, cur_bias, output, inner_num); + input += inner_num; + output += inner_num; + } + } + BiasAddByInnerCore(input, bias, output, end_inner); +} + +void BiasAddOpt(const float *input, const float *bias, float *output, int64_t start, int64_t end, int64_t inner_num, + bool batch_priority) { + if (inner_num == 0) { + return; + } + int64_t start_outer = start / inner_num; + int64_t start_inner = start % inner_num; + int64_t end_outer = end / inner_num; + int64_t end_inner = end % inner_num; + const float *cur_input = input + start; + float *cur_output = output + start; + + if (batch_priority) { + DoBiasAddByBatch(cur_input, bias, cur_output, start_inner, start_outer, end_inner, end_outer, inner_num); + } else { + DoBiasAddByInner(cur_input, bias, cur_output, start_inner, start_outer, end_inner, end_outer, inner_num); + } +} diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/activation.cuh b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/bias_add.h similarity index 59% rename from mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/activation.cuh rename to mindspore-lite/ops/kernel/cpu/nnacl/fp32/bias_add.h index 81d187674bdafd8f96ff55377e07ce47fa4ef10a..210b176858dabd3292fbbd64102c851d054e5f69 100644 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/activation.cuh +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/bias_add.h @@ -14,13 +14,21 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_CDUA_IMPL_ACTIVATION_H_ -#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_CDUA_IMPL_ACTIVATION_H_ +#ifndef MINDSPORE_NNACL_FP32_BIAS_ADD_H_ +#define MINDSPORE_NNACL_FP32_BIAS_ADD_H_ -template -void Sigmoid(const T *input1, T *output, int element_cnt, cudaStream_t stream); +#include +#include -template -void Gelu(const T *input1, T *output, int element_cnt, cudaStream_t stream); +#ifdef __cplusplus +extern "C" { +#endif -#endif // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_CDUA_IMPL_ACTIVATION_H_ +void BiasAddOpt(const float *input, const float *bias, float *output, int64_t start, int64_t end, int64_t inner_num, + bool batch_priority); + +#ifdef __cplusplus +}; +#endif + +#endif // MINDSPORE_NNACL_FP32_BIAS_ADD_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/bias_add_simd.h.in b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/bias_add_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..80052161749dec60e7027eb4067a8dd67244ef60 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/bias_add_simd.h.in @@ -0,0 +1,57 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_BIAS_ADD_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_BIAS_ADD_@SIMD_INSTRUCTION@_H_ + +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int BiasAddByInnerCore@SIMD_INSTRUCTION@(int index, const float *input, const float *bias, float *output, + int64_t num) { + for (int block_max_size = num - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(input + index); + SIMD_F32 vin1 = SIMD_LD_F32(bias + index); + SIMD_F32 vout = SIMD_ADD_F32(vin0, vin1); + SIMD_ST_F32(output + index, vout); + } + return index; +} + +static inline int BiasAddByBatchCore@SIMD_INSTRUCTION@(int index, const float *input, const float *bias, float *output1, + float *output2, float *output3, float *output4, int64_t num) { + for (int block_max_size = num - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_LDX4_F32(input_data, input + index, num); + SIMD_F32 bias_data = SIMD_LD_F32(bias + index); + SIMD_ST_F32(output1 + index, SIMD_ADD_F32(input_data1, bias_data)); + SIMD_ST_F32(output2 + index, SIMD_ADD_F32(input_data2, bias_data)); + SIMD_ST_F32(output3 + index, SIMD_ADD_F32(input_data3, bias_data)); + SIMD_ST_F32(output4 + index, SIMD_ADD_F32(input_data4, bias_data)); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +}; +#endif + +#endif // MINDSPORE_NNACL_FP32_BIAS_ADD_SIMD_H_ \ No newline at end of file diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/oneslike_tensorrt.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/cdist_fp32.c similarity index 31% rename from mindspore-lite/src/extendrt/delegate/tensorrt/op/oneslike_tensorrt.h rename to mindspore-lite/ops/kernel/cpu/nnacl/fp32/cdist_fp32.c index 80bff5232fb7e9c1d987a7b0e64ec0ee805e47ab..cc4d64d28e00439792171f1cbd84904522cdb569 100644 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/oneslike_tensorrt.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/cdist_fp32.c @@ -13,29 +13,65 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "nnacl/fp32/cdist_fp32.h" +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/cdist_fp32_simd.h" -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_ONESLIKE_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_ONESLIKE_TENSORRT_H_ -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" +void CdistTwoNormalOpt(const float *a, const float *b, float *dst, int64_t m, float p) { + float result = 0; + int64_t i = 0; -namespace mindspore::lite { -class OneslikeTensorRT : public TensorRTOp { - public: - OneslikeTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, const std::string &name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} + SIMD_RUN_NO_SCALAR(CdistTwoNormalOpt, i, a, b, &result, m); - ~OneslikeTensorRT() override = default; + for (; i < m; i++) { + float x = fabsf(a[i] - b[i]); + result += x * x; + } + result = sqrtf(result); + *dst = result; - int AddInnerOp(TensorRTContext *ctx) override; + return; +} - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; +void CdistPNormalOpt(const float *a, const float *b, float *dst, int64_t m, float p) { + float result = 0; + int64_t i = 0; - private: - int RunAsTrtOps(TensorRTContext *ctx); -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_ONESLIKE_TENSORRT_H_ + SIMD_RUN_NO_SCALAR(CdistPNormalOpt, i, a, b, &result, m, p); + + for (; i < m; i++) { + float x = fabsf(a[i] - b[i]); + result += powf(x, p); + } + result = powf(result, 1.0 / p); + *dst = result; + + return; +} + +void CdistZeroNormalOpt(const float *a, const float *b, float *c, int64_t m, float p) { + float result = 0; + for (int64_t i = 0; i < m; i++) { + float x = fabsf(a[i] - b[i]); + result += MSMIN(ceilf(x), 1.0f); + } + *c = result; +} + +void CdistOneNormalOpt(const float *a, const float *b, float *c, int64_t m, float p) { + float result = 0; + for (int64_t i = 0; i < m; i++) { + float x = fabsf(a[i] - b[i]); + result += x; + } + *c = result; +} + +void CdistInfNormalOpt(const float *a, const float *b, float *c, int64_t m, float p) { + float result = 0; + for (int64_t i = 0; i < m; i++) { + float x = fabsf(a[i] - b[i]); + result = MSMAX(result, x); + } + *c = result; +} diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/utils.cuh b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/cdist_fp32.h similarity index 48% rename from mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/utils.cuh rename to mindspore-lite/ops/kernel/cpu/nnacl/fp32/cdist_fp32.h index 8d957877db92b4b35921498bcb8b1a0d12622f3f..2b51a59d34aedb0315e20999c482f3672bd93548 100644 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/utils.cuh +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/cdist_fp32.h @@ -13,29 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#ifndef MINDSPORE_NNACL_FP32_CDIST_H_ +#define MINDSPORE_NNACL_FP32_CDIST_H_ -#include -#include +#include "nnacl/op_base.h" -#define FINAL_MASK 0xffffffff +#ifdef __cplusplus +extern "C" { +#endif +void CdistTwoNormalOpt(const float *a, const float *b, float *dst, int64_t m, float p); +void CdistPNormalOpt(const float *a, const float *b, float *dst, int64_t m, float p); -template -__device__ T warpedReduceSum(T val) { -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - val += __shfl_xor_sync(FINAL_MASK, val, mask, 32); - } - return val; -} +void CdistZeroNormalOpt(const float *a, const float *b, float *c, int64_t m, float p); +void CdistOneNormalOpt(const float *a, const float *b, float *c, int64_t m, float p); +void CdistInfNormalOpt(const float *a, const float *b, float *c, int64_t m, float p); -template -__device__ T blockReduceSum(T val) { - static __shared__ T shared[32]; - int warped = threadIdx.x & 0x1f; - val = warpedReduceSum(val); - if (warped == 0) shared[threadIdx.x >> 5] = val; - __syncthreads(); - val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[warped] : static_cast(0.0); - val = warpedReduceSum(val); - return val; +#ifdef __cplusplus } +#endif + +#endif // MINDSPORE_NNACL_FP32_CDIST_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/cdist_fp32_simd.h.in b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/cdist_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..5cb30ab98846c3b0b05f955a20618e9d0056e046 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/cdist_fp32_simd.h.in @@ -0,0 +1,63 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_CDIST_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_CDIST_@SIMD_INSTRUCTION@_H_ + +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int64_t CdistTwoNormalOpt@SIMD_INSTRUCTION@(int64_t index, const float *a, const float *b, + float *out, int64_t size) { + SIMD_F32 result_vec = SIMD_MOV_F32(0.0f); + for (int64_t block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 a_vec = SIMD_LD_F32(a + index); + SIMD_F32 b_vec = SIMD_LD_F32(b + index); + SIMD_F32 tmp_vec = SIMD_SUB_F32(a_vec, b_vec); + tmp_vec = SIMD_ABS_F32(tmp_vec); + result_vec = SIMD_FMADD_F32(tmp_vec, tmp_vec, result_vec); + } + *out += SIMD_GET_SUM_F32(result_vec); + + return index; +} + +static inline int64_t CdistPNormalOpt@SIMD_INSTRUCTION@(int64_t index, const float *a, const float *b, + float *out, int64_t size, float p) { + SIMD_F32 result_vec = SIMD_MOV_F32(0.0f); + SIMD_F32 p_vec = SIMD_MOV_F32(p); + for (int64_t block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 a_vec = SIMD_LD_F32(a + index); + SIMD_F32 b_vec = SIMD_LD_F32(b + index); + SIMD_F32 tmp_vec = SIMD_SUB_F32(a_vec, b_vec); + tmp_vec = SIMD_ABS_F32(tmp_vec); + tmp_vec = SIMD_POW_F32(tmp_vec, p_vec); + result_vec = SIMD_ADD_F32(tmp_vec, result_vec); + } + *out += SIMD_GET_SUM_F32(result_vec); + + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif \ No newline at end of file diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/common_func_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/common_func_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..96860b72274ea91c3647530fc9fb865d5a5c6ca0 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/common_func_fp32.c @@ -0,0 +1,117 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/common_func_fp32.h" + +void PostConvFuncComm(const float *src_ptr_, float *out_ptr, const float *bias_ptr, size_t output_channel, + size_t plane_size, size_t plane_stride, size_t oc_stride, ActType relu_type, int size) { + if (size == 0) { + return; + } + for (size_t oc = 0; oc < output_channel; oc++) { + int oc_div = oc / size; + int oc_mod = oc % size; + for (int hw = 0; hw < (int)plane_size; hw++) { + int src_index = oc_div * size * plane_stride + hw * size + oc_mod; + int dst_index = hw * oc_stride + oc; + float value = src_ptr_[src_index]; + if (bias_ptr != NULL) { + value = value + bias_ptr[oc]; + } + value = (relu_type == ActType_Relu || relu_type == ActType_Relu6) ? (MSMAX(0.f, value)) : (value); + value = (relu_type == ActType_Relu6) ? (MSMIN(6.f, value)) : (value); + out_ptr[dst_index] = value; + } + } +} + +void PostConvFuncFp32C8(const float *c8_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel, + size_t plane_size, size_t stride, size_t relu_type) { +#if !defined(ENABLE_ARM) && !defined(ENABLE_SSE) + PostConvFuncComm(c8_out_ptr, out_ptr, bias_ptr, output_channel, plane_size, plane_size, stride, relu_type, C8NUM); +#else + size_t oc8mod = output_channel % C8NUM; + size_t oc8div = output_channel - oc8mod; + size_t stride_size = stride * sizeof(float); + PostFuncBiasReluC8(out_ptr, c8_out_ptr, bias_ptr, oc8div, oc8mod, plane_size, stride_size, relu_type); +#endif +} + +void WinogradPostConvFuncFp32CX(const float *cx_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel, + size_t plane_size, size_t plane_stride, size_t relu_type) { +#ifdef ENABLE_AVX + size_t oc8mod = output_channel % C8NUM; + size_t oc8div = output_channel - oc8mod; + size_t stride_size = (plane_stride - plane_size) * C8NUM * sizeof(float); + WinogradPostFuncBiasReluC8(out_ptr, cx_out_ptr, bias_ptr, oc8div, oc8mod, plane_size, stride_size, relu_type); +#elif defined(ENABLE_ARM) || defined(ENABLE_SSE) + size_t oc4mod = output_channel % C4NUM; + size_t oc4div = output_channel - oc4mod; + size_t stride_size = (plane_stride - plane_size) * C4NUM * sizeof(float); + WinogradPostFuncBiasReluC4(out_ptr, cx_out_ptr, bias_ptr, oc4div, oc4mod, plane_size, stride_size, relu_type); +#else + PostConvFuncComm(cx_out_ptr, out_ptr, bias_ptr, output_channel, plane_size, plane_stride, output_channel, relu_type, + C4NUM); +#endif +} + +#if !defined(ENABLE_ARM) && !defined(ENABLE_SSE) +void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) { + const int unitStep = 4 * length; + for (int y = 0; y < h; ++y) { + float *dstY = M + y * w * unitStep; + for (int x = 0; x < w; ++x) { + float *dstX = dstY + x * unitStep; + const float *srcX = S + x * unitStep; + memset(dstX, 0, unitStep * sizeof(float)); + for (int i = 0; i < k; ++i) { + float b = B[i * h + y]; + const float *srcY = srcX + i * w * unitStep; + if (0.0f == b) { + continue; + } + for (int j = 0; j < unitStep; ++j) { + dstX[j] += srcY[j] * b; + } + } + } + } +} + +// M = S * B , M = h * w * l, S = h * k * l, B = k * w +void WinogradTransRight(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) { + const int unitStep = 4 * length; + for (int y = 0; y < h; ++y) { + float *dstY = M + y * w * unitStep; + const float *srcY = S + y * k * unitStep; + + for (int x = 0; x < w; ++x) { + float *dstX = dstY + x * unitStep; + memset(dstX, 0, unitStep * sizeof(float)); + for (int i = 0; i < k; ++i) { + const float *srcX = srcY + i * unitStep; + float b = B[i * h + x]; + if (0.0f == b) { + continue; + } + for (int j = 0; j < unitStep; ++j) { + dstX[j] += srcX[j] * b; + } + } + } + } +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/common_func_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/common_func_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..9348b8a0715681a3031e14e56f818744b0c76bb2 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/common_func_fp32.h @@ -0,0 +1,106 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_COMMON_FUNC_H_ +#define MINDSPORE_NNACL_FP32_COMMON_FUNC_H_ + +#include +#include "nnacl/op_base.h" +#include "nnacl/conv_parameter.h" + +typedef struct ConvDwFp32BorderParam { + float *dst; + const float *src; + const float *weight; + const float *bias; + size_t height; + size_t width; + size_t in_kh_step; + size_t in_kw_step; + size_t kernel_w; + size_t relu; + size_t relu6; +} ConvDwFp32BorderParam; + +#ifdef __cplusplus +extern "C" { +#endif + +void PostConvFuncFp32C8(const float *c8_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel, + size_t plane_size, size_t stride, size_t relu_type); + +void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod, + size_t plane_size, size_t stride, size_t relu_type); + +void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length); + +void WinogradTransRight(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length); + +void WinogradPostConvFuncFp32CX(const float *cx_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel, + size_t plane_size, size_t plane_stride, size_t relu_type); + +void WinogradPostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t oc4div, size_t oc4mod, + size_t plane_size, size_t plane_stride, size_t relu_type); + +void WinogradPostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod, + size_t plane_size, size_t plane_stride, size_t relu_type); + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) +void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, + size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, + size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, size_t relu, size_t relu6); +#ifdef ENABLE_AVX +void ConvDwFp32Border(ConvDwFp32BorderParam *param); +#else +void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, + size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu, size_t relu6); +#endif +void DeconvDwFp32Center(float *dst, const float *src, const float *weight, size_t height, size_t width, size_t kernel_h, + size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, + size_t in_kh_step, size_t in_kw_step); +void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, size_t num_pixels, + size_t output_channel, size_t input_step); +#endif + +#ifdef ENABLE_ARM64 +void DeconvDwFp32Border(float *dst, const float *src, const float *weight, size_t height, size_t width, + size_t in_kh_step, size_t in_kw_step, size_t kernel_w); + +void ConvSwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, + size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t ic4, + size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, size_t relu, + size_t relu6); + +void ConvDw3x3Stride1(float *output, const float *buffer, const float *weight, const float *bias, int col_size, + int row_size, int channel, int output_h, int output_w, size_t relu, size_t relu6); + +void ConvDw3x3Stride2(float *output, const float *buffer, const float *weight, const float *bias, int col_size, + int row_size, int channel, int output_h, int output_w, size_t relu, size_t relu6); + +void ConvDw3x3Corner(float *dst, const float *src, const float *weight, const float *bias, int in_kh_step, + int in_kw_step, int channel, size_t relu, size_t relu6); + +void ConvDw3x3Vertical(float *dst, const float *src, const float *weight, const float *bias, int in_kh_step, + int in_kw_step, int channel, size_t relu, size_t relu6); + +void ConvDw3x3Horizontal(float *dst, const float *src, const float *weight, const float *bias, int in_kh_step, + int in_kw_step, int channel, size_t relu, size_t relu6); +#endif + +#ifdef __cplusplus +} +#endif +#endif /* MINDSPORE_NNACL_FP32_COMMON_FUNC_H_ */ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/constant_of_shape_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/constant_of_shape_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..c884d03182301624a16517de69307f8c0c99fa24 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/constant_of_shape_fp32.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_CONSTANT_OF_SHAPE_FP32_H_ +#define MINDSPORE_NNACL_FP32_CONSTANT_OF_SHAPE_FP32_H_ +#include +#include +#include "nnacl/op_base.h" +#include "nnacl/errorcode.h" +#include "nnacl/constant_of_shape_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +inline int ConstantOfShapeInt32(int32_t *output, int start, int end, int32_t value) { + for (int i = start; i < end; i++) { + output[i] = value; + } + return NNACL_OK; +} + +inline int ConstantOfShapeFp32(float *output, int start, int end, float value) { + for (int i = start; i < end; i++) { + output[i] = value; + } + return NNACL_OK; +} + +inline int ConstantOfShapeBool(bool *output, int start, int end, bool value) { + for (int i = start; i < end; i++) { + output[i] = value; + } + return NNACL_OK; +} + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_CONSTANT_OF_SHAPE_FP32_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_1x1_avx_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_1x1_avx_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..b8d0f97a2a9d5bc18219f4bd2956214b2f69944d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_1x1_avx_fp32.c @@ -0,0 +1,1608 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp32/conv_1x1_avx_fp32.h" +#include "nnacl/intrinsics/ms_simd_avx_instructions.h" + +void Conv1x1SW3x32AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, + size_t dst_flag) { + asm volatile( + "movq %8, %%rax\n" + "and $0x1, %%eax\n" + "je 0f\n" + "vmovups (%7), %%ymm0\n" + "vmovups 0x20(%7), %%ymm1\n" + "vmovups 0x40(%7), %%ymm2\n" + "vmovups 0x60(%7), %%ymm3\n" + "vmovups (%7, %6, 1), %%ymm4\n" + "vmovups 0x20(%7, %6, 1), %%ymm5\n" + "vmovups 0x40(%7, %6, 1), %%ymm6\n" + "vmovups 0x60(%7, %6, 1), %%ymm7\n" + "vmovups (%7, %6, 2), %%ymm8\n" + "vmovups 0x20(%7, %6, 2), %%ymm9\n" + "vmovups 0x40(%7, %6, 2), %%ymm10\n" + "vmovups 0x60(%7, %6, 2), %%ymm11\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %2\n" + "je 1f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "vmovups 0x60(%2), %%ymm3\n" + "vmovups (%2), %%ymm4\n" + "vmovups 0x20(%2), %%ymm5\n" + "vmovups 0x40(%2), %%ymm6\n" + "vmovups 0x60(%2), %%ymm7\n" + "vmovups (%2), %%ymm8\n" + "vmovups 0x20(%2), %%ymm9\n" + "vmovups 0x40(%2), %%ymm10\n" + "vmovups 0x60(%2), %%ymm11\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + + "2:\n" // LoopIC + "vbroadcastss (%0), %%ymm13\n" + "vbroadcastss (%0, %4), %%ymm14\n" + "vbroadcastss (%0, %4, 2), %%ymm15\n" + "vmovups (%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vmovups 0x20(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" + "vmovups 0x40(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vmovups 0x60(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vbroadcastss 4(%0), %%ymm13\n" + "vbroadcastss 4(%0, %4), %%ymm14\n" + "vbroadcastss 4(%0, %4, 2), %%ymm15\n" + "vmovups 128(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vmovups 160(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" + "vmovups 192(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vmovups 224(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vbroadcastss 8(%0), %%ymm13\n" + "vbroadcastss 8(%0, %4), %%ymm14\n" + "vbroadcastss 8(%0, %4, 2), %%ymm15\n" + "vmovups 256(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vmovups 288(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" + "vmovups 320(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vmovups 352(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vbroadcastss 12(%0), %%ymm13\n" + "vbroadcastss 12(%0, %4), %%ymm14\n" + "vbroadcastss 12(%0, %4, 2), %%ymm15\n" + "vmovups 384(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vmovups 416(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" + "vmovups 448(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vmovups 480(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vbroadcastss 16(%0), %%ymm13\n" + "vbroadcastss 16(%0, %4), %%ymm14\n" + "vbroadcastss 16(%0, %4, 2), %%ymm15\n" + "vmovups 512(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vmovups 544(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" + "vmovups 576(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vmovups 608(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vbroadcastss 20(%0), %%ymm13\n" + "vbroadcastss 20(%0, %4), %%ymm14\n" + "vbroadcastss 20(%0, %4, 2), %%ymm15\n" + "vmovups 640(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vmovups 672(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" + "vmovups 704(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vmovups 736(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vbroadcastss 24(%0), %%ymm13\n" + "vbroadcastss 24(%0, %4), %%ymm14\n" + "vbroadcastss 24(%0, %4, 2), %%ymm15\n" + "vmovups 768(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vmovups 800(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" + "vmovups 832(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vmovups 864(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vbroadcastss 28(%0), %%ymm13\n" + "vbroadcastss 28(%0, %4), %%ymm14\n" + "vbroadcastss 28(%0, %4, 2), %%ymm15\n" + "vmovups 896(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vmovups 928(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" + "vmovups 960(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vmovups 992(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "addq $1024, %1\n" + "addq $32, %0\n" + "dec %3\n" + "jg 2b\n" + + "movq %8, %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %5, %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm8, %%ymm8\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + "vmaxps %%ymm12, %%ymm11, %%ymm11\n" + + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm8, %%ymm8\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + "vminps %%ymm14, %%ymm11, %%ymm11\n" + + "3:\n" + "vmovups %%ymm0, (%7)\n" // dst_0 + "vmovups %%ymm1, 0x20(%7)\n" + "vmovups %%ymm2, 0x40(%7)\n" + "vmovups %%ymm3, 0x60(%7)\n" + "vmovups %%ymm4, (%7, %6, 1)\n" + "vmovups %%ymm5, 0x20(%7, %6, 1)\n" + "vmovups %%ymm6, 0x40(%7, %6, 1)\n" + "vmovups %%ymm7, 0x60(%7, %6, 1)\n" + "vmovups %%ymm8, (%7, %6, 2)\n" + "vmovups %%ymm9, 0x20(%7, %6, 2)\n" + "vmovups %%ymm10, 0x40(%7, %6, 2)\n" + "vmovups %%ymm11, 0x60(%7, %6, 2)\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(ic_align), "r"(in_sw_step), "r"(act_flag), "r"(oc_align), "r"(dst), + "r"(dst_flag) // 8 + : "%rax", "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", + "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} + +void Conv1x1SW1x32AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, + size_t dst_flag) { + asm volatile( + "movq %8, %%rax\n" + "and $0x1, %%eax\n" + "je 0f\n" + "vmovups (%7), %%ymm0\n" + "vmovups 0x20(%7), %%ymm1\n" + "vmovups 0x40(%7), %%ymm2\n" + "vmovups 0x60(%7), %%ymm3\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %2\n" + "je 1f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "vmovups 0x60(%2), %%ymm3\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + + "2:\n" // LoopIC + "vbroadcastss (%0), %%ymm13\n" + "vmovups (%1), %%ymm4\n" + "vmovups 0x20(%1), %%ymm5\n" + "vmovups 0x40(%1), %%ymm6\n" + "vmovups 0x60(%1), %%ymm7\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm3\n" + + "vbroadcastss 4(%0), %%ymm13\n" + "vmovups 128(%1), %%ymm4\n" + "vmovups 160(%1), %%ymm5\n" + "vmovups 192(%1), %%ymm6\n" + "vmovups 224(%1), %%ymm7\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm3\n" + + "vbroadcastss 8(%0), %%ymm13\n" + "vmovups 256(%1), %%ymm4\n" + "vmovups 288(%1), %%ymm5\n" + "vmovups 320(%1), %%ymm6\n" + "vmovups 352(%1), %%ymm7\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm3\n" + + "vbroadcastss 12(%0), %%ymm13\n" + "vmovups 384(%1), %%ymm4\n" + "vmovups 416(%1), %%ymm5\n" + "vmovups 448(%1), %%ymm6\n" + "vmovups 480(%1), %%ymm7\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm3\n" + + "vbroadcastss 16(%0), %%ymm13\n" + "vmovups 512(%1), %%ymm4\n" + "vmovups 544(%1), %%ymm5\n" + "vmovups 576(%1), %%ymm6\n" + "vmovups 608(%1), %%ymm7\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm3\n" + + "vbroadcastss 20(%0), %%ymm13\n" + "vmovups 640(%1), %%ymm4\n" + "vmovups 672(%1), %%ymm5\n" + "vmovups 704(%1), %%ymm6\n" + "vmovups 736(%1), %%ymm7\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm3\n" + + "vbroadcastss 24(%0), %%ymm13\n" + "vmovups 768(%1), %%ymm4\n" + "vmovups 800(%1), %%ymm5\n" + "vmovups 832(%1), %%ymm6\n" + "vmovups 864(%1), %%ymm7\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm3\n" + + "vbroadcastss 28(%0), %%ymm13\n" + "vmovups 896(%1), %%ymm4\n" + "vmovups 928(%1), %%ymm5\n" + "vmovups 960(%1), %%ymm6\n" + "vmovups 992(%1), %%ymm7\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm3\n" + "addq $1024, %1\n" + "addq $32, %0\n" + "dec %3\n" + "jg 2b\n" + + "movq %8, %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %5, %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + + "3:\n" + "vmovups %%ymm0, (%7)\n" // dst_0 + "vmovups %%ymm1, 0x20(%7)\n" + "vmovups %%ymm2, 0x40(%7)\n" + "vmovups %%ymm3, 0x60(%7)\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(ic_align), "r"(in_sw_step), "r"(act_flag), "r"(oc_align), "r"(dst), + "r"(dst_flag) // 8 + : "%rax", "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm12", "%ymm13", + "%ymm14"); +} + +void Conv1x1SW4x24AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, + size_t dst_flag) { + size_t src_3_step = 3 * in_sw_step; + float *dst_3 = dst + 3 * oc_align / sizeof(float); + asm volatile( + "movq %10, %%rax\n" // dst_flag + "and $0x1, %%eax\n" + "je 0f\n" + "vmovups (%8), %%ymm0\n" // dst_0 + "vmovups 0x20(%8), %%ymm1\n" + "vmovups 0x40(%8), %%ymm2\n" + "vmovups (%8, %7, 1), %%ymm3\n" + "vmovups 0x20(%8, %7, 1), %%ymm4\n" + "vmovups 0x40(%8, %7, 1), %%ymm5\n" + "vmovups (%8, %7, 2), %%ymm6\n" + "vmovups 0x20(%8, %7, 2), %%ymm7\n" + "vmovups 0x40(%8, %7, 2), %%ymm8\n" + "vmovups (%9), %%ymm9\n" + "vmovups 0x20(%9), %%ymm10\n" + "vmovups 0x40(%9), %%ymm11\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %2\n" + "je 1f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "vmovups (%2), %%ymm3\n" + "vmovups 0x20(%2), %%ymm4\n" + "vmovups 0x40(%2), %%ymm5\n" + "vmovups (%2), %%ymm6\n" + "vmovups 0x20(%2), %%ymm7\n" + "vmovups 0x40(%2), %%ymm8\n" + "vmovups (%2), %%ymm9\n" + "vmovups 0x20(%2), %%ymm10\n" + "vmovups 0x40(%2), %%ymm11\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + + "2:\n" // LoopIC + "vmovups (%1), %%ymm13\n" + "vmovups 0x20(%1), %%ymm14\n" + "vmovups 0x40(%1), %%ymm15\n" + "vbroadcastss (%0), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vbroadcastss (%0, %4), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n" + "vbroadcastss (%0, %4, 2), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vbroadcastss (%0, %5), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm10\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vmovups 96(%1), %%ymm13\n" + "vmovups 128(%1), %%ymm14\n" + "vmovups 160(%1), %%ymm15\n" + "vbroadcastss 4(%0), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vbroadcastss 4(%0, %4), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n" + "vbroadcastss 4(%0, %4, 2), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vbroadcastss 4(%0, %5), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm10\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vmovups 192(%1), %%ymm13\n" + "vmovups 224(%1), %%ymm14\n" + "vmovups 256(%1), %%ymm15\n" + "vbroadcastss 8(%0), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vbroadcastss 8(%0, %4), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n" + "vbroadcastss 8(%0, %4, 2), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vbroadcastss 8(%0, %5), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm10\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vmovups 288(%1), %%ymm13\n" + "vmovups 320(%1), %%ymm14\n" + "vmovups 352(%1), %%ymm15\n" + "vbroadcastss 12(%0), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vbroadcastss 12(%0, %4), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n" + "vbroadcastss 12(%0, %4, 2), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vbroadcastss 12(%0, %5), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm10\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vmovups 384(%1), %%ymm13\n" + "vmovups 416(%1), %%ymm14\n" + "vmovups 448(%1), %%ymm15\n" + "vbroadcastss 16(%0), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vbroadcastss 16(%0, %4), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n" + "vbroadcastss 16(%0, %4, 2), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vbroadcastss 16(%0, %5), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm10\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vmovups 480(%1), %%ymm13\n" + "vmovups 512(%1), %%ymm14\n" + "vmovups 544(%1), %%ymm15\n" + "vbroadcastss 20(%0), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vbroadcastss 20(%0, %4), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n" + "vbroadcastss 20(%0, %4, 2), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vbroadcastss 20(%0, %5), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm10\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vmovups 576(%1), %%ymm13\n" + "vmovups 608(%1), %%ymm14\n" + "vmovups 640(%1), %%ymm15\n" + "vbroadcastss 24(%0), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vbroadcastss 24(%0, %4), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n" + "vbroadcastss 24(%0, %4, 2), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vbroadcastss 24(%0, %5), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm10\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "vmovups 672(%1), %%ymm13\n" + "vmovups 704(%1), %%ymm14\n" + "vmovups 736(%1), %%ymm15\n" + "vbroadcastss 28(%0), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vbroadcastss 28(%0, %4), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n" + "vbroadcastss 28(%0, %4, 2), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vbroadcastss 28(%0, %5), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm10\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "addq $768, %1\n" + "addq $32, %0\n" + "dec %3\n" + "jg 2b\n" + + "movq %10, %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %6, %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm8, %%ymm8\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + "vmaxps %%ymm12, %%ymm11, %%ymm11\n" + + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm8, %%ymm8\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + "vminps %%ymm14, %%ymm11, %%ymm11\n" + + "3:\n" + "vmovups %%ymm0, (%8)\n" // dst_0 + "vmovups %%ymm1, 0x20(%8)\n" + "vmovups %%ymm2, 0x40(%8)\n" + "vmovups %%ymm3, (%8, %7, 1)\n" + "vmovups %%ymm4, 0x20(%8, %7, 1)\n" + "vmovups %%ymm5, 0x40(%8, %7, 1)\n" + "vmovups %%ymm6, (%8, %7, 2)\n" + "vmovups %%ymm7, 0x20(%8, %7, 2)\n" + "vmovups %%ymm8, 0x40(%8, %7, 2)\n" + "vmovups %%ymm9, (%9)\n" + "vmovups %%ymm10, 0x20(%9)\n" + "vmovups %%ymm11, 0x40(%9)\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(ic_align), "r"(in_sw_step), "r"(src_3_step), "r"(act_flag), // 6 + "r"(oc_align), "r"(dst), "r"(dst_3), "r"(dst_flag) // 10 + : "%rax", "%rcx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", + "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} + +void Conv1x1SW1x24AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, + size_t dst_flag) { + asm volatile( + "movq %8, %%rax\n" + "and $0x1, %%eax\n" + "je 0f\n" + "vmovups (%7), %%ymm0\n" + "vmovups 0x20(%7), %%ymm1\n" + "vmovups 0x40(%7), %%ymm2\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %2\n" + "je 1f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + + "2:\n" // LoopIC + "vbroadcastss (%0), %%ymm13\n" + "vmovups (%1), %%ymm4\n" + "vmovups 0x20(%1), %%ymm5\n" + "vmovups 0x40(%1), %%ymm6\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + + "vbroadcastss 4(%0), %%ymm13\n" + "vmovups 96(%1), %%ymm4\n" + "vmovups 128(%1), %%ymm5\n" + "vmovups 160(%1), %%ymm6\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + + "vbroadcastss 8(%0), %%ymm13\n" + "vmovups 192(%1), %%ymm4\n" + "vmovups 224(%1), %%ymm5\n" + "vmovups 256(%1), %%ymm6\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + + "vbroadcastss 12(%0), %%ymm13\n" + "vmovups 288(%1), %%ymm4\n" + "vmovups 320(%1), %%ymm5\n" + "vmovups 352(%1), %%ymm6\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + + "vbroadcastss 16(%0), %%ymm13\n" + "vmovups 384(%1), %%ymm4\n" + "vmovups 416(%1), %%ymm5\n" + "vmovups 448(%1), %%ymm6\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + + "vbroadcastss 20(%0), %%ymm13\n" + "vmovups 480(%1), %%ymm4\n" + "vmovups 512(%1), %%ymm5\n" + "vmovups 544(%1), %%ymm6\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + + "vbroadcastss 24(%0), %%ymm13\n" + "vmovups 576(%1), %%ymm4\n" + "vmovups 608(%1), %%ymm5\n" + "vmovups 640(%1), %%ymm6\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + + "vbroadcastss 28(%0), %%ymm13\n" + "vmovups 672(%1), %%ymm4\n" + "vmovups 704(%1), %%ymm5\n" + "vmovups 736(%1), %%ymm6\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm2\n" + + "addq $768, %1\n" + "addq $32, %0\n" + "dec %3\n" + "jg 2b\n" + + "movq %8, %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %5, %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + + "3:\n" + "vmovups %%ymm0, (%7)\n" // dst_0 + "vmovups %%ymm1, 0x20(%7)\n" + "vmovups %%ymm2, 0x40(%7)\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(ic_align), "r"(in_sw_step), "r"(act_flag), "r"(oc_align), "r"(dst), + "r"(dst_flag) // 8 + : "%rax", "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm4", "%ymm5", "%ymm6", "%ymm12", "%ymm13", "%ymm14"); +} + +void Conv1x1SW6x16AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, + size_t dst_flag) { + size_t src_3_step = 3 * in_sw_step; + float *dst_3 = dst + 3 * oc_align / sizeof(float); + asm volatile( + "movq %10, %%rax\n" // dst_flag + "and $0x1, %%eax\n" + "je 0f\n" + "vmovups (%8), %%ymm0\n" // dst_0 + "vmovups 0x20(%8), %%ymm1\n" + "vmovups (%8, %7, 1), %%ymm2\n" + "vmovups 0x20(%8, %7, 1), %%ymm3\n" + "vmovups (%8, %7, 2), %%ymm4\n" + "vmovups 0x20(%8, %7, 2), %%ymm5\n" + "vmovups (%9), %%ymm6\n" + "vmovups 0x20(%9), %%ymm7\n" + "vmovups (%9, %7, 1), %%ymm8\n" + "vmovups 0x20(%9, %7, 1), %%ymm9\n" + "vmovups (%9, %7, 2), %%ymm10\n" + "vmovups 0x20(%9, %7, 2), %%ymm11\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %2\n" + "je 1f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + // We need to copy ymm0 to ymm3 to reduce IO time, but unfortunately I didn't find the corresponding instruction. + "vmovups (%2), %%ymm2\n" + "vmovups 0x20(%2), %%ymm3\n" + "vmovups (%2), %%ymm4\n" + "vmovups 0x20(%2), %%ymm5\n" + "vmovups (%2), %%ymm6\n" + "vmovups 0x20(%2), %%ymm7\n" + "vmovups (%2), %%ymm8\n" + "vmovups 0x20(%2), %%ymm9\n" + "vmovups (%2), %%ymm10\n" + "vmovups 0x20(%2), %%ymm11\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + + "2:\n" // LoopIC + "movq %0, %%rax\n" + "addq %5, %%rax\n" + + "vmovups (%1), %%ymm12\n" + "vmovups 0x20(%1), %%ymm13\n" + "vbroadcastss (%0), %%ymm14\n" + "vbroadcastss (%0, %4), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm3\n" + "vbroadcastss (%0, %4, 2), %%ymm14\n" + "vbroadcastss (%%rax), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm7\n" + "vbroadcastss (%%rax, %4), %%ymm14\n" + "vbroadcastss (%%rax, %4, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm8\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n" + + "vmovups 64(%1), %%ymm12\n" + "vmovups 96(%1), %%ymm13\n" + "vbroadcastss 4(%0), %%ymm14\n" + "vbroadcastss 4(%0, %4), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm3\n" + "vbroadcastss 4(%0, %4, 2), %%ymm14\n" + "vbroadcastss 4(%%rax), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm7\n" + "vbroadcastss 4(%%rax, %4), %%ymm14\n" + "vbroadcastss 4(%%rax, %4, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm8\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n" + + "vmovups 128(%1), %%ymm12\n" + "vmovups 160(%1), %%ymm13\n" + "vbroadcastss 8(%0), %%ymm14\n" + "vbroadcastss 8(%0, %4), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm3\n" + "vbroadcastss 8(%0, %4, 2), %%ymm14\n" + "vbroadcastss 8(%%rax), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm7\n" + "vbroadcastss 8(%%rax, %4), %%ymm14\n" + "vbroadcastss 8(%%rax, %4, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm8\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n" + + "vmovups 192(%1), %%ymm12\n" + "vmovups 224(%1), %%ymm13\n" + "vbroadcastss 12(%0), %%ymm14\n" + "vbroadcastss 12(%0, %4), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm3\n" + "vbroadcastss 12(%0, %4, 2), %%ymm14\n" + "vbroadcastss 12(%%rax), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm7\n" + "vbroadcastss 12(%%rax, %4), %%ymm14\n" + "vbroadcastss 12(%%rax, %4, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm8\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n" + + "vmovups 256(%1), %%ymm12\n" + "vmovups 288(%1), %%ymm13\n" + "vbroadcastss 16(%0), %%ymm14\n" + "vbroadcastss 16(%0, %4), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm3\n" + "vbroadcastss 16(%0, %4, 2), %%ymm14\n" + "vbroadcastss 16(%%rax), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm7\n" + "vbroadcastss 16(%%rax, %4), %%ymm14\n" + "vbroadcastss 16(%%rax, %4, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm8\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n" + + "vmovups 320(%1), %%ymm12\n" + "vmovups 352(%1), %%ymm13\n" + "vbroadcastss 20(%0), %%ymm14\n" + "vbroadcastss 20(%0, %4), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm3\n" + "vbroadcastss 20(%0, %4, 2), %%ymm14\n" + "vbroadcastss 20(%%rax), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm7\n" + "vbroadcastss 20(%%rax, %4), %%ymm14\n" + "vbroadcastss 20(%%rax, %4, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm8\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n" + + "vmovups 384(%1), %%ymm12\n" + "vmovups 416(%1), %%ymm13\n" + "vbroadcastss 24(%0), %%ymm14\n" + "vbroadcastss 24(%0, %4), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm3\n" + "vbroadcastss 24(%0, %4, 2), %%ymm14\n" + "vbroadcastss 24(%%rax), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm7\n" + "vbroadcastss 24(%%rax, %4), %%ymm14\n" + "vbroadcastss 24(%%rax, %4, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm8\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n" + + "vmovups 448(%1), %%ymm12\n" + "vmovups 480(%1), %%ymm13\n" + "vbroadcastss 28(%0), %%ymm14\n" + "vbroadcastss 28(%0, %4), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm0\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm3\n" + "vbroadcastss 28(%0, %4, 2), %%ymm14\n" + "vbroadcastss 28(%%rax), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm7\n" + "vbroadcastss 28(%%rax, %4), %%ymm14\n" + "vbroadcastss 28(%%rax, %4, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm8\n" + "vfmadd231ps %%ymm13, %%ymm14, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm11\n" + + "addq $512, %1\n" + "addq $32, %0\n" + "dec %3\n" + "jg 2b\n" + + "movq %10, %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %6, %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm8, %%ymm8\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + "vmaxps %%ymm12, %%ymm11, %%ymm11\n" + + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm8, %%ymm8\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + "vminps %%ymm14, %%ymm11, %%ymm11\n" + + "3:\n" + "vmovups %%ymm0, (%8)\n" // dst_0 + "vmovups %%ymm1, 0x20(%8)\n" + "vmovups %%ymm2, (%8, %7, 1)\n" + "vmovups %%ymm3, 0x20(%8, %7, 1)\n" + "vmovups %%ymm4, (%8, %7, 2)\n" + "vmovups %%ymm5, 0x20(%8, %7, 2)\n" + "vmovups %%ymm6, (%9)\n" // dst+3 + "vmovups %%ymm7, 0x20(%9)\n" + "vmovups %%ymm8, (%9, %7, 1)\n" + "vmovups %%ymm9, 0x20(%9, %7, 1)\n" + "vmovups %%ymm10, (%9, %7, 2)\n" + "vmovups %%ymm11, 0x20(%9, %7, 2)\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(ic_align), "r"(in_sw_step), "r"(src_3_step), "r"(act_flag), // 6 + "r"(oc_align), "r"(dst), "r"(dst_3), "r"(dst_flag) // 10 + : "%rax", "%rcx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", + "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} + +void Conv1x1SW1x16AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, + size_t dst_flag) { + asm volatile( + "movq %8, %%rax\n" + "and $0x1, %%eax\n" + "je 0f\n" + "vmovups (%7), %%ymm0\n" + "vmovups 0x20(%7), %%ymm1\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %2\n" + "je 1f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + + "2:\n" // LoopIC + "vbroadcastss (%0), %%ymm12\n" + "vmovups (%1), %%ymm13\n" + "vmovups 0x20(%1), %%ymm14\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + + "vbroadcastss 4(%0), %%ymm12\n" + "vmovups 64(%1), %%ymm13\n" + "vmovups 96(%1), %%ymm14\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + + "vbroadcastss 8(%0), %%ymm12\n" + "vmovups 128(%1), %%ymm13\n" + "vmovups 160(%1), %%ymm14\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + + "vbroadcastss 12(%0), %%ymm12\n" + "vmovups 192(%1), %%ymm13\n" + "vmovups 224(%1), %%ymm14\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + + "vbroadcastss 16(%0), %%ymm12\n" + "vmovups 256(%1), %%ymm13\n" + "vmovups 288(%1), %%ymm14\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + + "vbroadcastss 20(%0), %%ymm12\n" + "vmovups 320(%1), %%ymm13\n" + "vmovups 352(%1), %%ymm14\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + + "vbroadcastss 24(%0), %%ymm12\n" + "vmovups 384(%1), %%ymm13\n" + "vmovups 416(%1), %%ymm14\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + + "vbroadcastss 28(%0), %%ymm12\n" + "vmovups 448(%1), %%ymm13\n" + "vmovups 480(%1), %%ymm14\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + + "addq $512, %1\n" + "addq $32, %0\n" + "dec %3\n" + "jg 2b\n" + + "movq %8, %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %5, %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + + "3:\n" + "vmovups %%ymm0, (%7)\n" // dst_0 + "vmovups %%ymm1, 0x20(%7)\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(ic_align), "r"(in_sw_step), "r"(act_flag), "r"(oc_align), "r"(dst), + "r"(dst_flag) // 8 + : "%rax", "%ecx", "%ymm0", "%ymm1", "%ymm12", "%ymm13", "%ymm14"); +} + +void Conv1x1SW12x8AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, + size_t dst_flag) { + ic_align <<= 3; + size_t src_3_step = 3 * in_sw_step; + float *dst_3 = dst + 3 * oc_align / sizeof(float); + float *dst_5 = dst + 5 * oc_align / sizeof(float); + float *dst_9 = dst + 9 * oc_align / sizeof(float); + asm volatile( + "movq %12, %%rax\n" + "and $0x1, %%eax\n" + "je 0f\n" + "vmovups (%8), %%ymm0\n" // dst_0 + "vmovups (%8, %7), %%ymm1\n" + "vmovups (%8, %7, 2), %%ymm2\n" + "vmovups (%9), %%ymm3\n" // dst_3 + "vmovups (%8, %7, 4), %%ymm4\n" + "vmovups (%10), %%ymm5\n" // dst_5 + "vmovups (%10, %7, 1), %%ymm6\n" + "vmovups (%10, %7, 2), %%ymm7\n" + "vmovups (%8, %7, 8), %%ymm8\n" + "vmovups (%11), %%ymm9\n" // dst_9 + "vmovups (%11, %7, 1), %%ymm10\n" + "vmovups (%11, %7, 2), %%ymm11\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %2\n" + "je 1f\n" + "vmovups (%2), %%ymm0\n" + "vmovups (%2), %%ymm1\n" + "vmovups (%2), %%ymm2\n" + "vmovups (%2), %%ymm3\n" + "vmovups (%2), %%ymm4\n" + "vmovups (%2), %%ymm5\n" + "vmovups (%2), %%ymm6\n" + "vmovups (%2), %%ymm7\n" + "vmovups (%2), %%ymm8\n" + "vmovups (%2), %%ymm9\n" + "vmovups (%2), %%ymm10\n" + "vmovups (%2), %%ymm11\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + + "2:\n" // LoopIC + "vmovups (%1), %%ymm12\n" + "movq %0, %%rax\n" + "vbroadcastss (%%rax), %%ymm13\n" + "vbroadcastss (%%rax, %4), %%ymm14\n" + "vbroadcastss (%%rax, %4, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "addq %5, %%rax\n" + "vbroadcastss (%%rax), %%ymm13\n" + "vbroadcastss (%%rax, %4), %%ymm14\n" + "vbroadcastss (%%rax, %4, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n" + "addq %5, %%rax\n" + "vbroadcastss (%%rax), %%ymm13\n" + "vbroadcastss (%%rax, %4), %%ymm14\n" + "vbroadcastss (%%rax, %4, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "addq %5, %%rax\n" + "vbroadcastss (%%rax), %%ymm13\n" + "vbroadcastss (%%rax, %4), %%ymm14\n" + "vbroadcastss (%%rax, %4, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm10\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + "addq $32, %1\n" + "addq $4, %0\n" + "dec %3\n" + "jg 2b\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(ic_align), "r"(in_sw_step), "r"(src_3_step), "r"(act_flag), // 6 + "r"(oc_align), "r"(dst), "r"(dst_3), "r"(dst_5), "r"(dst_9), "r"(dst_flag) // 12 + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); + + asm volatile( + "and $0x2, %%eax\n" + "je 0f\n" + "movq %0, %%rax\n" + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm8, %%ymm8\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + "vmaxps %%ymm12, %%ymm11, %%ymm11\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm8, %%ymm8\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + "vminps %%ymm14, %%ymm11, %%ymm11\n" + + "0:\n" + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, (%2, %1)\n" + "vmovups %%ymm2, (%2, %1, 2)\n" + "vmovups %%ymm3, (%3)\n" // dst_3 + "vmovups %%ymm4, (%2, %1, 4)\n" + "vmovups %%ymm5, (%4)\n" // dst_5 + "vmovups %%ymm6, (%4, %1, 1)\n" + "vmovups %%ymm7, (%4, %1, 2)\n" + "vmovups %%ymm8, (%2, %1, 8)\n" + "vmovups %%ymm9, (%5)\n" // dst_9 + "vmovups %%ymm10, (%5, %1, 1)\n" + "vmovups %%ymm11, (%5, %1, 2)\n" + : + : "r"(act_flag), "r"(oc_align), "r"(dst), "r"(dst_3), "r"(dst_5), "r"(dst_9), "a"(dst_flag) // 6 + : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm14"); +} + +void Conv1x1SW1x8AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, + size_t dst_flag) { + asm volatile( + "movq %8, %%rax\n" + "and $0x1, %%eax\n" + "je 0f\n" + "vmovups (%7), %%ymm0\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %2\n" + "je 1f\n" + "vmovups (%2), %%ymm0\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + + "2:\n" // LoopIC + "vbroadcastss (%0), %%ymm12\n" + "vmovups (%1), %%ymm13\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + + "vbroadcastss 4(%0), %%ymm12\n" + "vmovups 32(%1), %%ymm13\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + + "vbroadcastss 8(%0), %%ymm12\n" + "vmovups 64(%1), %%ymm13\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + + "vbroadcastss 12(%0), %%ymm12\n" + "vmovups 96(%1), %%ymm13\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + + "vbroadcastss 16(%0), %%ymm12\n" + "vmovups 128(%1), %%ymm13\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + + "vbroadcastss 20(%0), %%ymm12\n" + "vmovups 160(%1), %%ymm13\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + + "vbroadcastss 24(%0), %%ymm12\n" + "vmovups 192(%1), %%ymm13\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + + "vbroadcastss 28(%0), %%ymm12\n" + "vmovups 224(%1), %%ymm13\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "addq $256, %1\n" + "addq $32, %0\n" + "dec %3\n" + "jg 2b\n" + + "movq %8, %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %5, %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + + "3:\n" + "vmovups %%ymm0, (%7)\n" // dst_0 + : + : "r"(src), "r"(weight), "r"(bias), "r"(ic_align), "r"(in_sw_step), "r"(act_flag), "r"(oc_align), "r"(dst), + "r"(dst_flag) // 8 + : "%rax", "%ecx", "%ymm0", "%ymm12", "%ymm13"); +} + +// sliding window to compate 1x1 conv in x86 +void Conv1x1SWAVXFp32(const float *input_data, const float *packed_weight, const float *bias_data, float *output_data, + int task_id, ConvParameter *conv_param, SlidingWindowParam *sw_param) { + int output_w = conv_param->output_w_; + int output_h = conv_param->output_h_; + int ohw = output_h * output_w; + int ohw_step = UP_DIV(ohw, conv_param->thread_num_); + int ohw_start = ohw_step * task_id; + int ohw_end = MSMIN(ohw_start + ohw_step, ohw); + if (ohw_start >= ohw_end) { + return; + } + int act_type = C0NUM; + int oc_tile_ = C8NUM; // oc in algin to C8NUM in x86_64_avx + if (conv_param->act_type_ == ActType_Relu6) { + act_type += C1NUM; + } + if (conv_param->act_type_ == ActType_Relu6 || conv_param->act_type_ == ActType_Relu) { + act_type += C2NUM; + } + int pad_d = conv_param->pad_d_; + int pad_l = conv_param->pad_l_; + int pad_r = conv_param->pad_r_; + int pad_u = conv_param->pad_u_; + int oc_align = sw_param->block_channel_; + int oc_align_float = oc_align * sizeof(float); + int ic_align = sw_param->ic_align_; + int in_sw_step = sw_param->in_sw_step_; + int in_sw_step_float = sw_param->in_sw_step_ * sizeof(float); + int kernel_step = sw_param->kernel_step_; + int oc_num = sw_param->c_block_; + int in_step = sw_param->in_step_; + int out_step = sw_param->out_step_; + const int ow_block_num[4] = {12, 6, 4, 3}; + const Conv1x1SWAVXKernel kernel[4][2] = {{Conv1x1SW1x8AVXKernel, Conv1x1SW12x8AVXKernel}, + {Conv1x1SW1x16AVXKernel, Conv1x1SW6x16AVXKernel}, + {Conv1x1SW1x24AVXKernel, Conv1x1SW4x24AVXKernel}, + {Conv1x1SW1x32AVXKernel, Conv1x1SW3x32AVXKernel}}; + for (int b = 0; b < conv_param->output_batch_; b++) { + int ic_block = 128; + int dst_flag = 0; + for (int ic = 0; ic < ic_align; ic += ic_block) { + if (ic_align - ic <= ic_block) { + ic_block = ic_align - ic; + dst_flag = C3NUM - (ic == 0); + } else { + dst_flag = 1 - (ic == 0); + } + if (pad_d == 0 && pad_l == 0 && pad_r == 0 && pad_u == 0) { + const float *bias = bias_data; + int oc_block = 0; + for (int oc = 0; oc < oc_num; oc += oc_block) { + oc_block = MSMIN(C4NUM, oc_num - oc); // 4 3 2 1 + const float *weight = packed_weight + oc * kernel_step + ic * C8NUM * oc_block; + if (bias != NULL) { + bias = bias_data + oc * oc_tile_; + } + const float *src_w = input_data + ic + ohw_start * in_sw_step; + float *dst_oc = output_data + oc * oc_tile_; + int hw_block = ow_block_num[oc_block - 1]; + for (int hw = ohw_start; hw < ohw_end; hw += hw_block) { + if (hw_block > ohw_end - hw) { // ow is not enough and process one ow + hw_block = 1; + } + float *dst_w = dst_oc + hw * oc_align; + kernel[oc_block - 1][hw_block / ow_block_num[oc_block - 1]](dst_w, src_w, weight, bias, act_type, hw_block, + oc_block, oc_align_float, ic_block >> C3NUM, + in_sw_step_float, dst_flag); + src_w += hw_block * in_sw_step; + } + } + } + } + input_data += in_step; + output_data += out_step; + } // batch loop +} + +#ifdef ENABLE_DEBUG +void Conv1x1SWOWxOCAVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, + size_t dst_flag) { + oc_align /= sizeof(float); + in_sw_step /= sizeof(float); + ic_align <<= C3NUM; + __m256 dst_data[12]; + const float *src_sw[12]; + __m256 weight_data[4]; + for (int i = 0; i < C4NUM; ++i) { + weight_data[i] = _mm256_set1_ps(0.0f); + } + for (int i = 0; i < ow_block; ++i) { + if (dst_flag & 0x01) { + for (int j = 0; j < oc_block; ++j) { + dst_data[i * oc_block + j] = _mm256_loadu_ps(dst + i * oc_align + j * C8NUM); + } + } else { + if (bias != NULL) { + for (int j = 0; j < oc_block; ++j) { + dst_data[i * oc_block + j] = _mm256_loadu_ps(bias + j * C8NUM); + } + } else { + for (int j = 0; j < oc_block; ++j) { + dst_data[i * oc_block + j] = _mm256_set1_ps(0.0f); + } + } + } + src_sw[i] = src + i * in_sw_step; + } + const float *weight_kernel = weight; + for (int ic = 0; ic < ic_align; ++ic) { + for (int j = 0; j < oc_block; ++j) { + weight_data[j] = _mm256_loadu_ps(weight_kernel + j * C8NUM); + } + for (int i = 0; i < ow_block; ++i) { + for (int j = 0; j < oc_block; ++j) { + dst_data[i * oc_block + j] += src_sw[i][ic] * weight_data[j]; + } + } + weight_kernel += C8NUM * oc_block; + } // ic loop + // add bias and relu + for (int i = 0; i < ow_block; ++i) { + for (int j = 0; j < oc_block; ++j) { + if (dst_flag & 0x02) { + if (0x1 & act_flag) { // relu6 + dst_data[i * oc_block + j] = _mm256_min_ps(dst_data[i * oc_block + j], _mm256_set1_ps(6.0f)); + } + if (0x2 & act_flag) { // relu + dst_data[i * oc_block + j] = _mm256_max_ps(dst_data[i * oc_block + j], _mm256_set1_ps(0.0f)); + } + } + _mm256_storeu_ps(dst + i * oc_align + j * C8NUM, dst_data[i * oc_block + j]); + } + } +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_1x1_avx_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_1x1_avx_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..025053cf374986211cd75342b0b8d2e7ad46dbc5 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_1x1_avx_fp32.h @@ -0,0 +1,40 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_CONV_1X1_AVX_FP32_H_ +#define MINDSPORE_NNACL_FP32_CONV_1X1_AVX_FP32_H_ + +#include "nnacl/op_base.h" +#include "nnacl/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +typedef void (*Conv1x1SWAVXKernel)(float *dst, const float *src, const float *weight, const float *bias, + size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, + size_t in_sw_step, size_t dst_flag); + +void Conv1x1SWAVXFp32(const float *input_data, const float *packed_weight, const float *bias_data, float *output_data, + int task_id, ConvParameter *conv_param, SlidingWindowParam *sw_param); + +#ifdef ENABLE_DEBUG +void Conv1x1SWOWxOCAVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t ow_block, size_t oc_block, size_t oc_align, size_t ic_align, size_t in_sw_step, + size_t dst_flag); +#endif +#ifdef __cplusplus +} +#endif // MINDSPORE_NNACL_FP32_CONV_1X1_AVX_FP32_H_ +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_1x1_x86_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_1x1_x86_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..39c17939b4931322ca469522169e5506d549cade --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_1x1_x86_fp32.h @@ -0,0 +1,21 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_CONV_1X1_X86_FP32_H_ +#define MINDSPORE_NNACL_FP32_CONV_1X1_X86_FP32_H_ + +#include "nnacl/fp32/conv_1x1_avx_fp32.h" + +#endif // MINDSPORE_NNACL_FP32_CONV_1X1_X86_FP32_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_common_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_common_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..187ac5efcedb05ee7cea46bc811634a9a9acae84 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_common_fp32.c @@ -0,0 +1,435 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/conv_common_fp32.h" +#include +#ifdef ENABLE_AVX +#ifdef _MSC_VER +#include +#else +#include +#endif +#endif +#include "nnacl/fp32/matmul_fp32.h" +void Im2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, int real_cal_num, + int block_index) { + // input format : nhwc + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int kernel_plane = kernel_h * kernel_w; + int dilation_h = conv_param->dilation_h_; + int dilation_w = conv_param->dilation_w_; + int out_w = conv_param->output_w_; + if (dilation_h == 0 || dilation_w == 0 || out_w == 0) { + return; + } + int in_channel = conv_param->input_channel_; + int in_w = conv_param->input_w_; + for (int i = 0; i < real_cal_num; i++) { + int block_start = block_index + i; + int input_h = block_start / out_w * conv_param->stride_h_ - conv_param->pad_u_; + int input_w = block_start % out_w * conv_param->stride_w_ - conv_param->pad_l_; + if (conv_param->input_h_ - input_h < 0 || in_w - input_w < 0) { + continue; + } + int input_stride = (input_h * in_w + input_w) * in_channel; + int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h)); + int kh_e = MSMIN(kernel_h, UP_DIV(conv_param->input_h_ - input_h, dilation_h)); + int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w)); + int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w)); + if (kw_e <= kw_s) { + continue; + } + if (dilation_w == 1 && dilation_h == 1) { + for (int j = kh_s; j < kh_e; j++) { + int input_y_stride = j * in_w * in_channel + input_stride; + int input_x_stride = input_y_stride + kw_s * in_channel; + int input_plane_offset = (j * kernel_w + kw_s) * in_channel + i * in_channel * kernel_plane; + memcpy(packed_input + input_plane_offset, input_data + input_x_stride, + (kw_e - kw_s) * in_channel * sizeof(float)); + } // kernel_h loop + } else { + for (int j = kh_s; j < kh_e; j++) { + int input_y_stride = j * dilation_h * in_w * in_channel + input_stride; + for (int k = kw_s; k < kw_e; ++k) { + int input_x_stride = input_y_stride + k * dilation_w * in_channel; + int input_plane_offset = (j * kernel_w + k) * in_channel + i * in_channel * kernel_plane; + memcpy(packed_input + input_plane_offset, input_data + input_x_stride, in_channel * sizeof(float)); + } + } // kernel_h loop + } + } // tile num loop +} + +// fp32 conv common +void ConvFp32(const float *input_data, float *packed_input, const float *packed_weight, const float *bias_data, + float *col_major_input, float *output_data, int task_id, const ConvParameter *conv_param) { + if (conv_param->thread_num_ == 0) { + return; + } + Row2ColMajorFuncPtr Row2ColMajor = NULL; + int output_hw = conv_param->output_h_ * conv_param->output_w_; +#ifdef ENABLE_AVX + Row2ColMajor = RowMajor2Col6Major; + const int cal_num = C6NUM; +#elif defined(ENABLE_SSE) + Row2ColMajor = RowMajor2Col4Major; + const int cal_num = C4NUM; +#elif defined(ENABLE_ARM64) + MatmulFloatOptFuncPtr MatmulFloatOpt = NULL; + int cal_num = 0; + if (output_hw <= C4NUM) { + Row2ColMajor = RowMajor2Col4Major; + MatmulFloatOpt = MatmulFloatNeon64OptRow4; + cal_num = C4NUM; + } else if (output_hw <= C8NUM) { + Row2ColMajor = RowMajor2Col8Major; + MatmulFloatOpt = MatmulFloatNeon64OptRow8; + cal_num = C8NUM; + } else { + Row2ColMajor = RowMajor2Col12Major; + MatmulFloatOpt = MatmulFloatNeon64OptRow12; + cal_num = C12NUM; + } +#elif defined(ENABLE_ARM32) + Row2ColMajor = RowMajor2Col12Major; + const int cal_num = C12NUM; +#else + Row2ColMajor = RowMajor2Col12Major; + const int cal_num = C12NUM; +#endif + + int block_per_thread = UP_DIV(UP_DIV(output_hw, cal_num), conv_param->thread_num_); + int start_block = block_per_thread * task_id; + int start_hw = start_block * cal_num; + int end_hw = MSMIN(output_hw, (start_block + block_per_thread) * cal_num); + if (start_hw >= end_hw) { + return; + } + int out_stride = conv_param->output_channel_ * cal_num; + int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; + packed_input += task_id * deep * cal_num; + col_major_input += task_id * deep * cal_num; + size_t input_size = deep * cal_num * sizeof(float); + + for (int b = 0; b < conv_param->input_batch_; b++) { + int out_channel = conv_param->output_channel_; + int in_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_; + int out_offset = b * out_channel * output_hw + start_hw * out_channel; + for (int i = start_hw; i < end_hw; i += cal_num, out_offset += out_stride) { + int real_cal_row = MSMIN(output_hw - i, cal_num); + memset(packed_input, 0, input_size); + Im2ColPackUnitFp32(input_data + in_offset, conv_param, packed_input, real_cal_row, i); + Row2ColMajor(packed_input, col_major_input, cal_num, deep); + float *gemm_output = output_data + out_offset; +// x86 func param types are different +#if ENABLE_AVX + MatmulFloatAvxOpt(col_major_input, packed_weight, gemm_output, bias_data, (size_t)conv_param->act_type_, deep, + real_cal_row, out_channel, (size_t)out_channel, (size_t)OutType_Nhwc); +#elif ENABLE_SSE + MatmulFloatSse64Opt(col_major_input, packed_weight, gemm_output, bias_data, (int)conv_param->act_type_, deep, + real_cal_row, out_channel, (size_t)out_channel, (int)OutType_Nhwc); +#elif ENABLE_ARM32 + MatmulFloatNeon32Opt12x4(col_major_input, packed_weight, gemm_output, bias_data, (int)conv_param->act_type_, deep, + real_cal_row, out_channel, out_channel, OutType_Nhwc); +#elif ENABLE_ARM64 + MatmulFloatOpt(col_major_input, packed_weight, gemm_output, bias_data, conv_param->act_type_, deep, real_cal_row, + out_channel, out_channel, OutType_Nhwc); +#else + MatMul12x8(col_major_input, packed_weight, gemm_output, bias_data, (int)conv_param->act_type_, deep, real_cal_row, + out_channel, out_channel, OutType_Nhwc); +#endif + } + } +} + +// fp32 conv common +void ConvFp32CutByBatch(const float *input_data, float *packed_input, const float *packed_weight, + const float *bias_data, float *col_major_input, float *output_data, int task_id, + const ConvParameter *conv_param) { + if (conv_param->thread_num_ == 0) { + return; + } + int output_hw = conv_param->output_h_ * conv_param->output_w_; + Row2ColMajorFuncPtr Row2ColMajor = NULL; +#ifdef ENABLE_AVX + const int cal_num = C6NUM; + Row2ColMajor = RowMajor2Col6Major; +#elif defined(ENABLE_SSE) + const int cal_num = C4NUM; + Row2ColMajor = RowMajor2Col4Major; +#elif defined(ENABLE_ARM64) + int cal_num = 0; + MatmulFloatOptFuncPtr MatmulFloatOpt = NULL; + if (output_hw <= C4NUM) { + cal_num = C4NUM; + Row2ColMajor = RowMajor2Col4Major; + MatmulFloatOpt = MatmulFloatNeon64OptRow4; + } else if (output_hw <= C8NUM) { + cal_num = C8NUM; + Row2ColMajor = RowMajor2Col8Major; + MatmulFloatOpt = MatmulFloatNeon64OptRow8; + } else { + cal_num = C12NUM; + Row2ColMajor = RowMajor2Col12Major; + MatmulFloatOpt = MatmulFloatNeon64OptRow12; + } +#elif defined(ENABLE_ARM32) + const int cal_num = C12NUM; + Row2ColMajor = RowMajor2Col12Major; +#else + const int cal_num = C12NUM; + Row2ColMajor = RowMajor2Col12Major; +#endif + + int block_batch_per_thread = UP_DIV(conv_param->input_batch_, conv_param->thread_num_); + int start_batch = block_batch_per_thread * task_id; + int end_batch = MSMIN(conv_param->input_batch_, (start_batch + block_batch_per_thread)); + + int out_stride = conv_param->output_channel_ * cal_num; + int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; + packed_input += task_id * deep * cal_num; + col_major_input += task_id * deep * cal_num; + size_t input_size = deep * cal_num * sizeof(float); + + for (int b = start_batch; b < end_batch; b++) { + int out_channel = conv_param->output_channel_; + int in_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_; + int out_offset = b * out_channel * output_hw; + for (int i = 0; i < output_hw; i += cal_num, out_offset += out_stride) { + int real_cal_row = MSMIN(output_hw - i, cal_num); + memset(packed_input, 0, input_size); + Im2ColPackUnitFp32(input_data + in_offset, conv_param, packed_input, real_cal_row, i); + Row2ColMajor(packed_input, col_major_input, cal_num, deep); + float *gemm_output = output_data + out_offset; +// x86 func param types are different +#if ENABLE_AVX + MatmulFloatAvxOpt(col_major_input, packed_weight, gemm_output, bias_data, (size_t)conv_param->act_type_, deep, + real_cal_row, out_channel, (size_t)out_channel, (size_t)OutType_Nhwc); +#elif ENABLE_SSE + MatmulFloatSse64Opt(col_major_input, packed_weight, gemm_output, bias_data, (int)conv_param->act_type_, deep, + real_cal_row, out_channel, (size_t)out_channel, (int)OutType_Nhwc); +#elif ENABLE_ARM32 + MatmulFloatNeon32Opt12x4(col_major_input, packed_weight, gemm_output, bias_data, (int)conv_param->act_type_, deep, + real_cal_row, out_channel, out_channel, OutType_Nhwc); +#elif ENABLE_ARM64 + MatmulFloatOpt(col_major_input, packed_weight, gemm_output, bias_data, conv_param->act_type_, deep, real_cal_row, + out_channel, out_channel, OutType_Nhwc); +#else + MatMul12x8(col_major_input, packed_weight, gemm_output, bias_data, (int)conv_param->act_type_, deep, real_cal_row, + out_channel, out_channel, OutType_Nhwc); +#endif + } + } +} + +#if defined(ENABLE_AVX) || defined(ENABLE_ARM64) +void ConvFp32OutNC4HW4(const float *input_data, float *packed_input, const float *packed_weight, const float *bias_data, + float *col_major_input, float *output_data, int task_id, const ConvParameter *conv_param) { + if (conv_param->thread_num_ == 0) { + return; + } + int output_hw = conv_param->output_h_ * conv_param->output_w_; + int out_channel = conv_param->output_channel_; + int input_hw = conv_param->input_h_ * conv_param->input_w_; + int in_channel = conv_param->input_channel_; + Row2ColMajorFuncPtr Row2ColMajor = NULL; + int cal_num = 0; + int out_tile = 0; +#ifdef ENABLE_AVX + cal_num = C6NUM; + out_tile = C8NUM; + Row2ColMajor = RowMajor2Col6Major; + int align_channel = UP_DIV(out_channel, C16NUM) * C16NUM; +#else + out_tile = C4NUM; + MatmulFloatOptFuncPtr MatmulFloatOpt = NULL; + if (output_hw <= C4NUM) { + cal_num = C4NUM; + Row2ColMajor = RowMajor2Col4Major; + MatmulFloatOpt = MatmulFloatNeon64OptRow4; + } else if (output_hw <= C8NUM) { + cal_num = C8NUM; + Row2ColMajor = RowMajor2Col8Major; + MatmulFloatOpt = MatmulFloatNeon64OptRow8; + } else { + cal_num = C12NUM; + Row2ColMajor = RowMajor2Col12Major; + MatmulFloatOpt = MatmulFloatNeon64OptRow12; + } +#endif + int block_per_thread = UP_DIV(UP_DIV(output_hw, cal_num), conv_param->thread_num_); + int start_block = block_per_thread * task_id; + int start_hw = start_block * cal_num; + int end_hw = MSMIN(output_hw, (start_block + block_per_thread) * cal_num); + if (start_hw >= end_hw) { + return; + } +#ifdef ENABLE_AVX + int act_type = 0; + if (conv_param->act_type_ == ActType_Relu6) { + act_type += 1; + } + if (conv_param->act_type_ == ActType_Relu || conv_param->act_type_ == ActType_Relu6) { + act_type += 2; + } + int out_stride = out_tile * cal_num; + int out_block_stride = output_hw * C8NUM; +#else + int out_stride = MSMIN(out_channel, out_tile) * cal_num; +#endif + int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; + packed_input += task_id * deep * cal_num; + col_major_input += task_id * deep * cal_num; + size_t input_size = deep * cal_num * sizeof(float); + + for (int b = 0; b < conv_param->input_batch_; b++) { + int in_offset = b * in_channel * input_hw; +#ifdef ENABLE_AVX + int out_offset = b * align_channel * output_hw + start_hw * out_tile; +#else + int out_offset = b * out_channel * output_hw + start_hw * MSMIN(out_channel, out_tile); +#endif + for (int i = start_hw; i < end_hw; i += cal_num, out_offset += out_stride) { + int real_cal_row = MSMIN(output_hw - i, cal_num); + memset(packed_input, 0, input_size); + Im2ColPackUnitFp32(input_data + in_offset, conv_param, packed_input, real_cal_row, i); + Row2ColMajor(packed_input, col_major_input, cal_num, deep); + float *gemm_output = output_data + out_offset; +#ifdef ENABLE_AVX + for (int oc = 0; oc < out_channel; oc += C16NUM) { + CommonConv6x16Kernel(gemm_output + oc * output_hw, col_major_input, packed_weight + oc * deep, bias_data + oc, + deep, out_block_stride, act_type, real_cal_row); + } +#else + MatmulFloatOpt(col_major_input, packed_weight, gemm_output, bias_data, conv_param->act_type_, deep, real_cal_row, + out_channel, output_hw, OutType_NC4HW4); +#endif + } + } +} +#endif + +#ifdef ENABLE_AVX +void CommonConv6x16Kernel(float *dst, const float *src, const float *weight, const float *bias, const size_t depth, + const size_t out_step, const size_t act_flag, const size_t real_cal_row) { +#define Store1 \ + _mm256_storeu_ps(dst, out[0]); \ + _mm256_storeu_ps(dst + out_step, out[1]); +#define Store2 \ + Store1 _mm256_storeu_ps(dst + C8NUM, out[2]); \ + _mm256_storeu_ps(dst + out_step + C8NUM, out[3]); +#define Store3 \ + Store2 _mm256_storeu_ps(dst + C16NUM, out[4]); \ + _mm256_storeu_ps(dst + out_step + C16NUM, out[5]); +#define Store4 \ + Store3 _mm256_storeu_ps(dst + C24NUM, out[6]); \ + _mm256_storeu_ps(dst + out_step + C24NUM, out[7]); +#define Store5 \ + Store4 _mm256_storeu_ps(dst + C32NUM, out[8]); \ + _mm256_storeu_ps(dst + out_step + C32NUM, out[9]); +#define Store6 \ + Store5 _mm256_storeu_ps(dst + C40NUM, out[10]); \ + _mm256_storeu_ps(dst + out_step + C40NUM, out[11]); + + __m256 out[12]; + if (bias != NULL) { + out[0] = _mm256_loadu_ps(bias); + out[1] = _mm256_loadu_ps(bias + C8NUM); + } else { + out[0] = _mm256_set1_ps(0.0f); + out[1] = _mm256_set1_ps(0.0f); + } + out[2] = out[0]; + out[3] = out[1]; + out[4] = out[0]; + out[5] = out[1]; + out[6] = out[0]; + out[7] = out[1]; + out[8] = out[0]; + out[9] = out[1]; + out[10] = out[0]; + out[11] = out[1]; + for (int d = 0; d < depth; ++d) { + __m256 w1 = _mm256_loadu_ps(weight); + __m256 w2 = _mm256_loadu_ps(weight + C8NUM); + __m256 s1 = _mm256_set1_ps(*src); + __m256 s2 = _mm256_set1_ps(*(src + 1)); + out[0] = _mm256_fmadd_ps(s1, w1, out[0]); + out[1] = _mm256_fmadd_ps(s1, w2, out[1]); + out[2] = _mm256_fmadd_ps(s2, w1, out[2]); + out[3] = _mm256_fmadd_ps(s2, w2, out[3]); + s1 = _mm256_set1_ps(*(src + 2)); + s2 = _mm256_set1_ps(*(src + 3)); + out[4] = _mm256_fmadd_ps(s1, w1, out[4]); + out[5] = _mm256_fmadd_ps(s1, w2, out[5]); + out[6] = _mm256_fmadd_ps(s2, w1, out[6]); + out[7] = _mm256_fmadd_ps(s2, w2, out[7]); + s1 = _mm256_set1_ps(*(src + 4)); + s2 = _mm256_set1_ps(*(src + 5)); + out[8] = _mm256_fmadd_ps(s1, w1, out[8]); + out[9] = _mm256_fmadd_ps(s1, w2, out[9]); + out[10] = _mm256_fmadd_ps(s2, w1, out[10]); + out[11] = _mm256_fmadd_ps(s2, w2, out[11]); + weight += C16NUM; + src += C6NUM; + } + __m256 six = _mm256_set1_ps(6.0f); + __m256 zero = _mm256_set1_ps(0.0f); + if (0x1 & act_flag) { // relu6 + out[0] = _mm256_min_ps(out[0], six); + out[1] = _mm256_min_ps(out[1], six); + out[2] = _mm256_min_ps(out[2], six); + out[3] = _mm256_min_ps(out[3], six); + out[4] = _mm256_min_ps(out[4], six); + out[5] = _mm256_min_ps(out[5], six); + out[6] = _mm256_min_ps(out[6], six); + out[7] = _mm256_min_ps(out[7], six); + out[8] = _mm256_min_ps(out[8], six); + out[9] = _mm256_min_ps(out[9], six); + out[10] = _mm256_min_ps(out[10], six); + out[11] = _mm256_min_ps(out[11], six); + } + if (0x2 & act_flag) { // relu + out[0] = _mm256_max_ps(out[0], zero); + out[1] = _mm256_max_ps(out[1], zero); + out[2] = _mm256_max_ps(out[2], zero); + out[3] = _mm256_max_ps(out[3], zero); + out[4] = _mm256_max_ps(out[4], zero); + out[5] = _mm256_max_ps(out[5], zero); + out[6] = _mm256_max_ps(out[6], zero); + out[7] = _mm256_max_ps(out[7], zero); + out[8] = _mm256_max_ps(out[8], zero); + out[9] = _mm256_max_ps(out[9], zero); + out[10] = _mm256_max_ps(out[10], zero); + out[11] = _mm256_max_ps(out[11], zero); + } + if (real_cal_row == C6NUM) { + Store6 + } else if (real_cal_row == C5NUM) { + Store5 + } else if (real_cal_row == C4NUM) { + Store4 + } else if (real_cal_row == C3NUM) { + Store3 + } else if (real_cal_row == C2NUM) { + Store2 + } else if (real_cal_row == C1NUM) { + Store1 + } +} + +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_common_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_common_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..3bc48f757eb61af7b9aa2321f754d1cf375c1b24 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_common_fp32.h @@ -0,0 +1,60 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_CONV_COMMON_H_ +#define MINDSPORE_NNACL_FP32_CONV_COMMON_H_ + +#include "nnacl/pack.h" +#include "nnacl/op_base.h" +#include "nnacl/common_func.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/fp32/conv_sw_avx_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif +typedef void (*Row2ColMajorFuncPtr)(const float *src_ptr, float *dst_ptr, int row, int col); +#ifdef ENABLE_ARM64 +typedef void (*MatmulFloatOptFuncPtr)(const float *a, const float *b, float *c, const float *bias, int act_type, + int depth, int row, int col, size_t stride, size_t write_mode); +#endif + +void Im2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, int real_cal_num, + int block_index); + +// fp32 convolution common (im2col+gemm) +void ConvFp32(const float *input_data, float *packed_input, const float *packed_weight, const float *bias_data, + float *col_major_input, float *output_data, int task_id, const ConvParameter *conv_param); + +// fp32 convolution common (im2col+gemm) +void ConvFp32CutByBatch(const float *input_data, float *packed_input, const float *packed_weight, + const float *bias_data, float *col_major_input, float *output_data, int task_id, + const ConvParameter *conv_param); + +// common convolution output C4HW4, if out_channel mod 4 remains, just output real channel, no zeros padded. +void ConvFp32OutNC4HW4(const float *input_data, float *packed_input, const float *packed_weight, const float *bias_data, + float *col_major_input, float *output_data, int task_id, const ConvParameter *conv_param); + +#ifdef ENABLE_AVX +void CommonConv6x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t depth, + size_t out_step, size_t act_flag, size_t real_cal_row); +#endif + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_CONV_COMMON_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_depthwise_avx_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_depthwise_avx_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..7be1bd90c927ae879befdf79a86aeddab24d6021 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_depthwise_avx_fp32.c @@ -0,0 +1,93 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp32/conv_depthwise_avx_fp32.h" +#include "nnacl/common_func.h" +#include "nnacl/fp32/common_func_fp32.h" +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/errorcode.h" +#include "nnacl/fp32/activation_fp32.h" + +int ConvDwAVX(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, int task_id, ConvDwCalcParam *conv_dw_calc_param) { + if (conv_param->thread_num_ == 0 || conv_param->dilation_h_ == 0 || conv_param->stride_w_ == 0) { + return NNACL_ERR; + } + + int32_t *num_pixels = conv_dw_calc_param->num_pixels_; + int32_t *out_w_start = conv_dw_calc_param->out_w_start_; + int first_calc_kw = conv_dw_calc_param->first_calc_kw_; + + int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_); + int h_start = h_step * task_id; + int h_end = MSMIN(h_start + h_step, conv_param->output_h_); + bool relu = conv_param->act_type_ == ActType_Relu; + bool relu6 = conv_param->act_type_ == ActType_Relu6; + + for (int b = 0; b < conv_param->output_batch_; b++) { + const float *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_; + float *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_; + for (int oh = h_start; oh < h_end; oh++) { + int ih_origin = oh * conv_param->stride_h_ - conv_param->pad_u_; + int start_kh = MSMAX(0, UP_DIV(-ih_origin, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih_origin, conv_param->dilation_h_)); + float *dst_data = dst + oh * conv_param->output_w_ * conv_param->output_channel_; + + bool first_calc_flag = true; + if (first_calc_kw == -1) { + for (int ow = 0; ow < conv_param->output_w_; ow++) { + memcpy(dst_data + ow * conv_param->output_channel_, bias_data, + conv_param->output_channel_ * (int)(sizeof(float))); + } + first_calc_flag = false; + } + for (int kh = start_kh; kh < end_kh; kh++) { + int ih = ih_origin + conv_param->dilation_h_ * kh; + int in_sw_step = conv_param->stride_w_ * conv_param->input_channel_; + + const float *src_kh = src + ih * conv_param->input_w_ * conv_param->input_channel_; + const float *weight_kh = weight_data + kh * conv_param->kernel_w_ * conv_param->output_channel_; + + if (first_calc_flag) { + int iw_origin = -conv_param->pad_l_ + conv_param->dilation_w_ * first_calc_kw; + const float *src_kw = src_kh + iw_origin * conv_param->input_channel_; + ConvDwAVXFp32Row(dst_data, src_kw, weight_kh + first_calc_kw * conv_param->output_channel_, + conv_param->output_w_, conv_param->output_channel_, in_sw_step, true, bias_data); + } + for (int kw = 0; kw < conv_param->kernel_w_; kw++) { + if (first_calc_flag && (kw == first_calc_kw)) { + weight_kh += conv_param->output_channel_; + first_calc_flag = false; + continue; + } + int iw_origin = (out_w_start[kw] * conv_param->stride_w_) - conv_param->pad_l_ + conv_param->dilation_w_ * kw; + const float *src_kw = src_kh + iw_origin * conv_param->input_channel_; + float *dst_w = dst_data + out_w_start[kw] * conv_param->output_channel_; + + ConvDwAVXFp32Row(dst_w, src_kw, weight_kh, num_pixels[kw], conv_param->output_channel_, in_sw_step, false, + bias_data); + weight_kh += conv_param->output_channel_; + } + } + if (relu) { + Fp32Relu(dst_data, conv_param->output_w_ * conv_param->output_channel_, dst_data); + } + if (relu6) { + Fp32Relu6(dst_data, conv_param->output_w_ * conv_param->output_channel_, dst_data); + } + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_depthwise_avx_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_depthwise_avx_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..3b3a0553ad33eab593a073c80fe0803ab9452b3b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_depthwise_avx_fp32.h @@ -0,0 +1,37 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_CONV_DEPTHWISE_AVX_H_ +#define MINDSPORE_NNACL_FP32_CONV_DEPTHWISE_AVX_H_ + +#include "nnacl/conv_parameter.h" +#include "nnacl/base/conv_common_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ConvDwAVX(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, int task_id, ConvDwCalcParam *conv_dw_calc_param_); + +void ConvDwAVXFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, size_t num_pixels, + size_t output_channel, size_t input_step, bool first_calc_flag, const float *bias); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_CONV_DEPTHWISE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_depthwise_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_depthwise_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..9756d99f3f1abe5c341310a0b8afa5c2a390db63 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_depthwise_fp32.c @@ -0,0 +1,2074 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp32/conv_depthwise_fp32.h" +#include "nnacl/common_func.h" +#include "nnacl/fp32/common_func_fp32.h" +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/errorcode.h" +#include "nnacl/fp32/activation_fp32.h" + +#if !defined(ENABLE_ARM) && !defined(ENABLE_SSE) +void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, int num_pixels, + int output_channel, int input_step) { + for (int i = 0; i < num_pixels; i++) { + for (int c = 0; c < output_channel; c++) { + *output_ptr++ += weight_ptr[c] * input_ptr[c]; + } + input_ptr += input_step; + } +} +#endif + +int ConvDw(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, int task_id) { + if (conv_param->thread_num_ == 0 || conv_param->dilation_h_ == 0 || conv_param->stride_w_ == 0) { + return NNACL_ERR; + } + int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_); + int h_start = h_step * task_id; + int h_end = MSMIN(h_start + h_step, conv_param->output_h_); + bool relu = conv_param->act_type_ == ActType_Relu; + bool relu6 = conv_param->act_type_ == ActType_Relu6; + for (int b = 0; b < conv_param->output_batch_; b++) { + const float *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_; + float *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_; + for (int oh = h_start; oh < h_end; oh++) { + float *dst_data = dst + oh * conv_param->output_w_ * conv_param->output_channel_; + + int ih_origin = oh * conv_param->stride_h_ - conv_param->pad_u_; + int start_kh = MSMAX(0, UP_DIV(-ih_origin, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih_origin, conv_param->dilation_h_)); + + for (int ow = 0; ow < conv_param->output_w_; ow++) { + memcpy(dst_data + ow * conv_param->output_channel_, bias_data, + conv_param->output_channel_ * (int)(sizeof(float))); + } + for (int kh = start_kh; kh < end_kh; kh++) { + int ih = ih_origin + conv_param->dilation_h_ * kh; + + const float *src_kh = src + ih * conv_param->input_w_ * conv_param->input_channel_; + const float *dw_weight_kh = weight_data + kh * conv_param->kernel_w_ * conv_param->output_channel_; + + int in_sw_step = conv_param->stride_w_ * conv_param->input_channel_; + for (int kw = 0; kw < conv_param->kernel_w_; kw++) { + int out_w_start = MSMAX( + 0, (conv_param->pad_l_ - conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) / conv_param->stride_w_); + int out_w_end = MSMIN(conv_param->output_w_, (conv_param->input_w_ + conv_param->pad_l_ - + conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) / + conv_param->stride_w_); + + float *dst_w = dst_data + out_w_start * conv_param->output_channel_; + int iw_origin = (out_w_start * conv_param->stride_w_) - conv_param->pad_l_ + conv_param->dilation_w_ * kw; + + const float *src_kw = src_kh + iw_origin * conv_param->input_channel_; + int num_pixels = out_w_end - out_w_start; + + ConvDwFp32Row(dst_w, src_kw, dw_weight_kh, num_pixels, conv_param->output_channel_, in_sw_step); + dw_weight_kh += conv_param->output_channel_; + } + } + if (relu) { + Fp32Relu(dst_data, conv_param->output_w_ * conv_param->output_channel_, dst_data); + } + if (relu6) { + Fp32Relu6(dst_data, conv_param->output_w_ * conv_param->output_channel_, dst_data); + } + } + } + return NNACL_OK; +} + +#ifdef ENABLE_AVX512 +int ConvDwAVX512(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, int task_id, ConvDwCalcParam *conv_dw_calc_param) { + if (conv_param->thread_num_ == 0 || conv_param->dilation_h_ == 0 || conv_param->stride_w_ == 0) { + return NNACL_ERR; + } + int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_); + int h_start = h_step * task_id; + int h_end = MSMIN(h_start + h_step, conv_param->output_h_); + bool relu = conv_param->act_type_ == ActType_Relu; + bool relu6 = conv_param->act_type_ == ActType_Relu6; + + int32_t *num_pixels = conv_dw_calc_param->num_pixels_; + int32_t *out_w_start = conv_dw_calc_param->out_w_start_; + int first_calc_kw = conv_dw_calc_param->first_calc_kw_; + + for (int b = 0; b < conv_param->output_batch_; b++) { + const float *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_; + float *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_; + for (int oh = h_start; oh < h_end; oh++) { + float *dst_data = dst + oh * conv_param->output_w_ * conv_param->output_channel_; + + int ih_origin = oh * conv_param->stride_h_ - conv_param->pad_u_; + int start_kh = MSMAX(0, UP_DIV(-ih_origin, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih_origin, conv_param->dilation_h_)); + + bool first_calc_flag = true; + if (first_calc_kw == -1) { + first_calc_flag = false; + for (int ow = 0; ow < conv_param->output_w_; ow++) { + memcpy(dst_data + ow * conv_param->output_channel_, bias_data, + conv_param->output_channel_ * (int)(sizeof(float))); + } + } + for (int kh = start_kh; kh < end_kh; kh++) { + int ih = ih_origin + conv_param->dilation_h_ * kh; + + const float *src_kh = src + ih * conv_param->input_w_ * conv_param->input_channel_; + const float *weight_kh = weight_data + kh * conv_param->kernel_w_ * conv_param->output_channel_; + + int in_sw_step = conv_param->stride_w_ * conv_param->input_channel_; + + if (first_calc_flag) { + int iw_origin = -conv_param->pad_l_ + conv_param->dilation_w_ * first_calc_kw; + + const float *src_kw = src_kh + iw_origin * conv_param->input_channel_; + + ConvDwAVX512Fp32Row(dst_data, src_kw, weight_kh + first_calc_kw * conv_param->output_channel_, + conv_param->output_w_, conv_param->output_channel_, in_sw_step, true, bias_data); + } + for (int kw = 0; kw < conv_param->kernel_w_; kw++) { + if (first_calc_flag && (kw == first_calc_kw)) { + first_calc_flag = false; + weight_kh += conv_param->output_channel_; + continue; + } + + float *dst_w = dst_data + out_w_start[kw] * conv_param->output_channel_; + int iw_origin = (out_w_start[kw] * conv_param->stride_w_) - conv_param->pad_l_ + conv_param->dilation_w_ * kw; + + const float *src_kw = src_kh + iw_origin * conv_param->input_channel_; + + ConvDwAVX512Fp32Row(dst_w, src_kw, weight_kh, num_pixels[kw], conv_param->output_channel_, in_sw_step, false, + bias_data); + weight_kh += conv_param->output_channel_; + } + } + if (relu) { + Fp32Relu(dst_data, conv_param->output_w_ * conv_param->output_channel_, dst_data); + } else if (relu6) { + Fp32Relu6(dst_data, conv_param->output_w_ * conv_param->output_channel_, dst_data); + } + } + } + return NNACL_OK; +} +#endif + +void InitSlidingParam(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block) { + if (block == 0) { + return; + } + int left = 0; + int right = conv_param->output_w_; + int top = 0; + int bottom = conv_param->output_h_; + + while (left * conv_param->stride_w_ < conv_param->pad_l_) { + left++; + } + while ((right - 1) * conv_param->stride_w_ - conv_param->pad_l_ + conv_param->kernel_w_ * conv_param->dilation_w_ > + conv_param->input_w_ && + right > left) { + right--; + } + while (top * conv_param->stride_h_ < conv_param->pad_u_) { + top++; + } + while ((bottom - 1) * conv_param->stride_h_ - conv_param->pad_u_ + conv_param->kernel_h_ * conv_param->dilation_h_ > + conv_param->input_h_ && + bottom > top) { + bottom--; + } + sliding->left_ = left; + sliding->right_ = right; + sliding->top_ = top; + sliding->bottom_ = bottom; + sliding->c_block_ = UP_DIV(conv_param->output_channel_, block); + sliding->block_channel_ = UP_DIV(conv_param->output_channel_, block) * block; + sliding->out_step_ = conv_param->output_h_ * conv_param->output_w_ * sliding->block_channel_; + if (conv_param->out_format_ == Format_NC4HW4) { + // write to nc8hw8 + sliding->out_h_step_ = conv_param->output_w_ * block; + sliding->out_c_step_ = block * conv_param->output_h_ * conv_param->output_w_; + sliding->out_w_step_ = block; + sliding->out_block_step_ = sliding->out_c_step_; + } else { + // write to nhwc + sliding->out_h_step_ = conv_param->output_w_ * sliding->block_channel_; + sliding->out_c_step_ = block; + sliding->out_w_step_ = sliding->block_channel_; + sliding->out_block_step_ = sliding->out_w_step_; + } +} + +void InitSlidingParamConv(SlidingWindowParam *sliding, const ConvParameter *conv_param, int input_block, + int weight_block) { + InitSlidingParam(sliding, conv_param, weight_block); + AppendSlidingParamConv(sliding, conv_param, input_block, weight_block); +} + +void AppendSlidingParamConv(SlidingWindowParam *sliding, const ConvParameter *conv_param, int input_block, + int weight_block) { + if (input_block == 0) { // is not aligned + sliding->ic_align_ = conv_param->input_channel_; + } else { // 1x1 input is aligned to input_block + sliding->ic_align_ = UP_DIV(conv_param->input_channel_, input_block) * input_block; + } + sliding->in_step_ = conv_param->input_h_ * conv_param->input_w_ * sliding->ic_align_; // for batch loop + sliding->in_h_step_ = conv_param->input_w_ * sliding->ic_align_; + sliding->in_sh_step_ = conv_param->input_w_ * sliding->ic_align_ * conv_param->stride_h_; // stride H + sliding->in_sw_step_ = sliding->ic_align_ * conv_param->stride_w_; // stride W + sliding->in_kh_step_ = conv_param->input_w_ * sliding->ic_align_ * conv_param->dilation_h_; // kernel H + sliding->in_kw_step_ = sliding->ic_align_ * conv_param->dilation_w_; // kernel W + sliding->kernel_step_ = conv_param->kernel_w_ * conv_param->kernel_h_ * sliding->ic_align_ * weight_block; +} + +void InitSlidingParamConvDw(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block) { + InitSlidingParam(sliding, conv_param, block); + AppendSlidingParamConvDw(sliding, conv_param, block); +} + +void AppendSlidingParamConvDw(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block) { + sliding->in_step_ = conv_param->input_h_ * conv_param->input_w_ * sliding->block_channel_; // for batch loop + sliding->in_h_step_ = conv_param->input_w_ * sliding->block_channel_; + sliding->in_sh_step_ = conv_param->input_w_ * sliding->block_channel_ * conv_param->stride_h_; // stride H + sliding->in_sw_step_ = sliding->block_channel_ * conv_param->stride_w_; // stride W + sliding->in_kh_step_ = conv_param->input_w_ * sliding->block_channel_ * conv_param->dilation_h_; // kernel H + sliding->in_kw_step_ = sliding->block_channel_ * conv_param->dilation_w_; // kernel W + sliding->kernel_step_ = conv_param->kernel_w_ * conv_param->kernel_h_ * block; +} + +/*conv depthwise fp32 begin*/ +void ConvDwBorderPixel(float *dst, const float *src, const float *weight, const float *bias, int height, int width, + int in_kh_step, int in_kw_step, int kernel_w_step, bool is_relu, bool is_relu6) { + const float *src_kh = src; + const float *weight_kh = weight; + for (int c = 0; c < C4NUM; c++) { + dst[c] = 0; + } + for (int kh = 0; kh < height; kh++) { + const float *src_kw = src_kh; + const float *weight_kw = weight_kh; + for (int kw = 0; kw < width; kw++) { + for (int c = 0; c < C4NUM; c++) { + dst[c] += src_kw[c] * weight_kw[c]; + } + src_kw += in_kw_step; + weight_kw += C4NUM; + } // kernel_w loop + src_kh += in_kh_step; + weight_kh += kernel_w_step; + } // kernel_h loop + for (int c = 0; c < C4NUM; c++) { + dst[c] += bias[c]; + dst[c] = (is_relu) ? (MSMAX(0, dst[c])) : (dst[c]); + dst[c] = (is_relu6) ? (MSMIN(6, MSMAX(0, dst[c]))) : (dst[c]); + } +} + +void ConvDwBorder(float *dst, const float *src, const float *weight, const float *bias, int top, int bottom, int left, + int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding) { + if (conv_param->dilation_h_ == 0 || conv_param->dilation_w_ == 0) { + return; + } + bool relu = conv_param->act_type_ == ActType_Relu; + bool relu6 = conv_param->act_type_ == ActType_Relu6; + float *dst_h = dst + top * sliding->out_h_step_; + for (int oh = top; oh < bottom; oh++) { + int ih = oh * conv_param->stride_h_ - conv_param->pad_u_; + int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_)); + const float *src_h = src + ih * sliding->in_h_step_; + + float *dst_kernel = dst_h + left * sliding->block_channel_; + for (int ow = left; ow < right; ow++) { + int iw = ow * conv_param->stride_w_ - conv_param->pad_l_; + int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_)); + int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_)); + const float *src_w = src_h + iw * sliding->block_channel_; + + const float *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_; + const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM; +#ifdef ENABLE_AVX + ConvDwFp32BorderParam *param = (ConvDwFp32BorderParam *)malloc(sizeof(ConvDwFp32BorderParam)); + if (param == NULL) { + return; + } + param->dst = dst_kernel; + param->src = src_kernel; + param->weight = weight_kernel; + param->bias = bias; + param->height = end_kh - start_kh; + param->width = end_kw - start_kw; + param->in_kh_step = sliding->in_kh_step_ * sizeof(float); + param->in_kw_step = sliding->in_kw_step_ * sizeof(float); + param->kernel_w = conv_param->kernel_w_ * C4NUM * sizeof(float); + param->relu = relu; + param->relu6 = relu6; + ConvDwFp32Border(param); + free(param); +#elif defined(ENABLE_ARM) || defined(ENABLE_SSE) + ConvDwFp32Border(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_ * sizeof(float), sliding->in_kw_step_ * sizeof(float), + conv_param->kernel_w_ * C4NUM * sizeof(float), relu, relu6); +#else + ConvDwBorderPixel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_ * C4NUM, relu, relu6); +#endif + dst_kernel += sliding->block_channel_; + } // width loop + dst_h += sliding->out_h_step_; + } // height loop +} + +#ifndef ENABLE_ARM64 +void ConvDwCenter(float *dst, const float *src, const float *weight, const float *bias, int height, int width, + int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step, int in_sw_step, + int in_kh_step, int in_kw_step, bool is_relu, bool is_relu6) { + float *dst_h = dst; + const float *src_h = src; + for (int oh = 0; oh < height; oh++) { + float *dst_w = dst_h; + const float *src_w = src_h; + for (int ow = 0; ow < width; ow++) { + const float *src_kh = src_w; + const float *weight_kh = weight; + for (int c = 0; c < C4NUM; c++) { + dst_w[c] = 0; + } + for (int kh = 0; kh < kernel_h; kh++) { + const float *src_kw = src_kh; + const float *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++) { + for (int c = 0; c < C4NUM; c++) { + dst_w[c] += src_kw[c] * weight_kw[c]; + } + src_kw += in_kw_step; + weight_kw += C4NUM; + } // kernel_w loop + src_kh += in_kh_step; + weight_kh += kernel_w * C4NUM; + } // kernel_h loop + // add biad relu + for (int c = 0; c < C4NUM; c++) { + dst_w[c] += bias[c]; + dst_w[c] = (is_relu) ? (MSMAX(0, dst_w[c])) : (dst_w[c]); + dst_w[c] = (is_relu6) ? (MSMIN(6, MSMAX(0, dst_w[c]))) : (dst_w[c]); + } + dst_w += block_channel; + src_w += in_sw_step; + } // dst_width loop + dst_h += out_h_step; + src_h += in_sh_step; + } // dst_height loop +} +#endif + +// conv depthwise fp32: sliding window +void ConvDwSWFp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id) { + bool relu = conv_param->act_type_ == ActType_Relu; + bool relu6 = conv_param->act_type_ == ActType_Relu6; + if (conv_param->thread_num_ == 0) { + return; + } + const float *src = input_data; + float *dst = output_data; + for (int b = 0; b < conv_param->output_batch_; b++) { + for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) { + const float *src_data = src + oc * C4NUM; + float *dst_data = dst + oc * C4NUM; + const float *weight = weight_data + oc * sliding->kernel_step_; + const float *bias = bias_data + oc * C4NUM; + ConvDwBorder(dst_data, src_data, weight, bias, 0, sliding->top_, 0, conv_param->output_w_, conv_param, sliding); + ConvDwBorder(dst_data, src_data, weight, bias, sliding->bottom_, conv_param->output_h_, 0, conv_param->output_w_, + conv_param, sliding); + ConvDwBorder(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, 0, sliding->left_, conv_param, + sliding); + ConvDwBorder(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, sliding->right_, + conv_param->output_w_, conv_param, sliding); + + if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) { + int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_; + int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_; + const float *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_; + float *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_; +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + ConvDwFp32Center(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, + conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float), + sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float), + sliding->in_sw_step_ * sizeof(float), sliding->in_kh_step_ * sizeof(float), + sliding->in_kw_step_ * sizeof(float), relu, relu6); +#else + ConvDwCenter(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, + conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_, + sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_, relu, + relu6); +#endif + } + } // output C4 loop + src += sliding->in_step_; + dst += sliding->out_step_; + } // batch loop + // output nhwc4 +} +/*conv depthwise fp32 end*/ + +/*conv depthwise 3x3 fp32 begin*/ +bool CheckConvDwUse3X3(const ConvParameter *conv_param) { + bool use_3x3 = + conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 && + (conv_param->stride_h_ == 1 || conv_param->stride_h_ == 2) && + (conv_param->stride_w_ == 1 || conv_param->stride_w_ == 2) && conv_param->stride_h_ == conv_param->stride_w_ && + (conv_param->pad_u_ == 0 || conv_param->pad_u_ == 1) && (conv_param->pad_l_ == 0 || conv_param->pad_l_ == 1) && + conv_param->pad_u_ == conv_param->pad_l_ && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1; + if (!use_3x3 || conv_param->input_h_ == 1 || conv_param->input_w_ == 1) { + return false; + } + const int in_h = (conv_param->output_h_ - 1) * conv_param->stride_h_ + conv_param->kernel_h_; + const int in_w = (conv_param->output_w_ - 1) * conv_param->stride_w_ + conv_param->kernel_w_; + return in_h == (conv_param->input_h_ + 2 * conv_param->pad_u_) && + in_w == (conv_param->input_w_ + 2 * conv_param->pad_l_); +} + +#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX)) +static void ConvDw3x3RowLeft(const float *src, float *line, int lw, int channel) { + MS_FLOAT32X4 v0, v1, v2, v3; + v0 = MS_MOVQ_F32(0.0f); + int ic = 0; + for (; ic < channel - 3; ic += 4) { + v1 = MS_LDQ_F32(src + ic); + v2 = MS_LDQ_F32(src + channel + ic); + v3 = MS_LDQ_F32(src + 2 * channel + ic); + MS_FLOAT32X4 b0 = MS_SUBQ_F32(v0, v2); + MS_FLOAT32X4 b1 = MS_ADDQ_F32(v1, v2); + MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1); + MS_FLOAT32X4 b3 = MS_SUBQ_F32(v3, v1); + MS_STQ_F32(line + lw * ic, b0); + MS_STQ_F32(line + lw * ic + 4, b1); + MS_STQ_F32(line + lw * ic + 8, b2); + MS_STQ_F32(line + lw * ic + 12, b3); + } + if (ic < channel) { + float *remain_line = line + ic * lw; + memset(remain_line, 0, 64); + for (int i = 0; i < channel - ic; i++) { + float d1 = src[i + ic]; + float d2 = src[i + ic + channel]; + float d3 = src[i + ic + 2 * channel]; + remain_line[i] = 0.0f - d2; + remain_line[i + 4] = d1 + d2; + remain_line[i + 8] = d2 - d1; + remain_line[i + 12] = d3 - d1; + } + } +} + +static void ConvDw3x3RowMiddle(const float *src, float *line, int lw, int channel) { + MS_FLOAT32X4 v0, v1, v2, v3; + int ic = 0; + for (; ic < channel - 3; ic += 4) { + v0 = MS_LDQ_F32(src + ic); + v1 = MS_LDQ_F32(src + channel + ic); + v2 = MS_LDQ_F32(src + 2 * channel + ic); + v3 = MS_LDQ_F32(src + 3 * channel + ic); + MS_FLOAT32X4 b0 = MS_SUBQ_F32(v0, v2); + MS_FLOAT32X4 b1 = MS_ADDQ_F32(v1, v2); + MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1); + MS_FLOAT32X4 b3 = MS_SUBQ_F32(v3, v1); + MS_STQ_F32(line + lw * ic, b0); + MS_STQ_F32(line + lw * ic + 4, b1); + MS_STQ_F32(line + lw * ic + 8, b2); + MS_STQ_F32(line + lw * ic + 12, b3); + } + if (ic < channel) { + float *remain_line = line + ic * lw; + memset(remain_line, 0, 64); + for (int i = 0; i < channel - ic; i++) { + float d0 = src[i + ic]; + float d1 = src[i + ic + channel]; + float d2 = src[i + ic + 2 * channel]; + float d3 = src[i + ic + 3 * channel]; + remain_line[i] = d0 - d2; + remain_line[i + 4] = d1 + d2; + remain_line[i + 8] = d2 - d1; + remain_line[i + 12] = d3 - d1; + } + } +} + +static void ConvDw3x3RowRight(const float *src, float *line, int lw, int channel) { + MS_FLOAT32X4 v0, v1, v2, v3; + int ic = 0; + v3 = MS_MOVQ_F32(0.0f); + for (; ic < channel - 3; ic += 4) { + v0 = MS_LDQ_F32(src + ic); + v1 = MS_LDQ_F32(src + channel + ic); + v2 = MS_LDQ_F32(src + 2 * channel + ic); + MS_FLOAT32X4 b0 = MS_SUBQ_F32(v0, v2); + MS_FLOAT32X4 b1 = MS_ADDQ_F32(v1, v2); + MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1); + MS_FLOAT32X4 b3 = MS_SUBQ_F32(v3, v1); + MS_STQ_F32(line + lw * ic, b0); + MS_STQ_F32(line + lw * ic + 4, b1); + MS_STQ_F32(line + lw * ic + 8, b2); + MS_STQ_F32(line + lw * ic + 12, b3); + } + if (ic < channel) { + float *remain_line = line + ic * lw; + memset(remain_line, 0, 64); + for (int i = 0; i < channel - ic; i++) { + float d0 = src[i + ic]; + float d1 = src[i + ic + channel]; + float d2 = src[i + ic + 2 * channel]; + remain_line[i] = d0 - d2; + remain_line[i + 4] = d1 + d2; + remain_line[i + 8] = d2 - d1; + remain_line[i + 12] = 0.0f - d1; + } + } +} + +static void ConvDw3x3RowSingle(const float *src, float *line, int lw, int channel) { + MS_FLOAT32X4 v0, v1, v2; + int ic = 0; + v2 = MS_MOVQ_F32(0.0f); + for (; ic < channel - 3; ic += 4) { + v0 = MS_LDQ_F32(src + ic); + v1 = MS_LDQ_F32(src + channel + ic); + MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1); + MS_STQ_F32(line + lw * ic, v0); + MS_STQ_F32(line + lw * ic + 4, v1); + MS_STQ_F32(line + lw * ic + 8, b2); + memset(line + lw * ic + 12, 0, 16); + } + if (ic < channel) { + float *remain_line = line + ic * lw; + memset(remain_line, 0, 64); + for (int i = 0; i < channel - ic; i++) { + float d0 = src[i + ic]; + float d1 = src[i + ic + channel]; + remain_line[i] = d0; + remain_line[i + 4] = d1; + remain_line[i + 8] = 0.0f - d1; + } + } +} + +static void ConvDw3x3InitTop(const float *src, float **lines, int width, int channel) { + float *line0 = lines[0]; + float *line1 = lines[1]; + float *line2 = lines[2]; + int c4 = UP_ROUND(channel, C4NUM); + int lw = UP_DIV(width, C2NUM) * C4NUM; + memset(line0, 0, c4 * lw * sizeof(float)); + ConvDw3x3RowLeft(src, line1, lw, channel); + ConvDw3x3RowLeft(src + width * channel, line2, lw, channel); + int ow = 2; + for (; ow < width - 2; ow += 2) { + ConvDw3x3RowMiddle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel); + ConvDw3x3RowMiddle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel); + } + int remain = width - ow; + if (remain == 2) { + ConvDw3x3RowRight(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel); + ConvDw3x3RowRight(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel); + } else if (remain == 1) { + ConvDw3x3RowSingle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel); + ConvDw3x3RowSingle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel); + } +} + +static void ConvDw3x3InitRow(const float *src, float **lines, int width, int channel) { + float *line0 = lines[0]; + float *line1 = lines[1]; + float *line2 = lines[2]; + int lw = UP_DIV(width, C2NUM) * C4NUM; + ConvDw3x3RowLeft(src - width * channel, line0, lw, channel); + ConvDw3x3RowLeft(src, line1, lw, channel); + ConvDw3x3RowLeft(src + width * channel, line2, lw, channel); + int ow = 2; + for (; ow < width - 2; ow += 2) { + ConvDw3x3RowMiddle(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 4, lw, channel); + ConvDw3x3RowMiddle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel); + ConvDw3x3RowMiddle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel); + } + int remain = width - ow; + if (remain == 2) { + ConvDw3x3RowRight(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 4, lw, channel); + ConvDw3x3RowRight(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel); + ConvDw3x3RowRight(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel); + } else if (remain == 1) { + ConvDw3x3RowSingle(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 4, lw, channel); + ConvDw3x3RowSingle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel); + ConvDw3x3RowSingle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel); + } +} + +static void ConvDw3x3Row(const float *src, float **lines, int width, int channel) { + float *tmp = lines[0]; + lines[0] = lines[1]; + lines[1] = lines[2]; + lines[2] = tmp; + int c4 = UP_ROUND(channel, C4NUM); + int lw = UP_DIV(width, C2NUM) * C4NUM; + memset(tmp, 0, c4 * lw * sizeof(float)); + ConvDw3x3RowLeft(src, tmp, lw, channel); + int ow = 2; + for (; ow < width - 2; ow += 2) { + ConvDw3x3RowMiddle(src + (ow - 1) * channel, tmp + 2 * ow * 4, lw, channel); + } + int remain = width - ow; + if (remain == 2) { + ConvDw3x3RowRight(src + (ow - 1) * channel, tmp + 2 * ow * 4, lw, channel); + } else if (remain == 1) { + ConvDw3x3RowSingle(src + (ow - 1) * channel, tmp + 2 * ow * 4, lw, channel); + } +} + +static void ConvDw3x3Bottom(float **lines, int width, int channel) { + float *tmp = lines[0]; + lines[0] = lines[1]; + lines[1] = lines[2]; + lines[2] = tmp; + int c4 = UP_ROUND(channel, C4NUM); + memset(tmp, 0, UP_DIV(width, C2NUM) * c4 * C4NUM * sizeof(float)); +} + +#ifndef ENABLE_ARM64 +void ConvDw3x3Line(float *dst, float **lines, const float *weight, const float *bias_data, int width, int ori_channel, + bool relu, bool relu6) { + int channel = ori_channel; + float *line0 = lines[0]; + float *line1 = lines[1]; + float *line2 = lines[2]; + for (; channel > 0; channel -= 4) { + MS_FLOAT32X4 bias = MS_LDQ_F32(bias_data); + bias_data += 4; + MS_FLOAT32X4 g00 = MS_LDQ_F32(weight); + MS_FLOAT32X4 g01 = MS_LDQ_F32(weight + 4); + MS_FLOAT32X4 g02 = MS_LDQ_F32(weight + 8); + MS_FLOAT32X4 g03 = MS_LDQ_F32(weight + 12); + MS_FLOAT32X4 g10 = MS_LDQ_F32(weight + 16); + MS_FLOAT32X4 g11 = MS_LDQ_F32(weight + 20); + MS_FLOAT32X4 g12 = MS_LDQ_F32(weight + 24); + MS_FLOAT32X4 g13 = MS_LDQ_F32(weight + 28); + MS_FLOAT32X4 g20 = MS_LDQ_F32(weight + 32); + MS_FLOAT32X4 g21 = MS_LDQ_F32(weight + 36); + MS_FLOAT32X4 g22 = MS_LDQ_F32(weight + 40); + MS_FLOAT32X4 g23 = MS_LDQ_F32(weight + 44); + weight += 48; + float *cur_dst = dst; + int ow = 0; + for (; ow < width - 1; ow += 2) { + MS_FLOAT32X4 acc0 = MS_MULQ_F32(MS_LDQ_F32(line0), g00); + MS_FLOAT32X4 acc1 = MS_MULQ_F32(MS_LDQ_F32(line0 + 4), g01); + MS_FLOAT32X4 acc2 = MS_MULQ_F32(MS_LDQ_F32(line0 + 8), g02); + MS_FLOAT32X4 acc3 = MS_MULQ_F32(MS_LDQ_F32(line0 + 12), g03); + line0 += 16; + acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line1), g10); + acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line1 + 4), g11); + acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line1 + 8), g12); + acc3 = MS_MLAQ_F32(acc3, MS_LDQ_F32(line1 + 12), g13); + line1 += 16; + acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line2), g20); + acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line2 + 4), g21); + acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line2 + 8), g22); + acc3 = MS_MLAQ_F32(acc3, MS_LDQ_F32(line2 + 12), g23); + line2 += 16; + MS_FLOAT32X4 res0 = MS_ADDQ_F32(acc0, MS_ADDQ_F32(acc2, acc1)); + MS_FLOAT32X4 res1 = MS_ADDQ_F32(acc1, MS_SUBQ_F32(acc3, acc2)); + res0 = MS_ADDQ_F32(res0, bias); + res1 = MS_ADDQ_F32(res1, bias); + if (relu || relu6) { + res0 = MS_MAXQ_F32(res0, MS_MOVQ_F32(0.0f)); + res1 = MS_MAXQ_F32(res1, MS_MOVQ_F32(0.0f)); + } + if (relu6) { + res0 = MS_MINQ_F32(res0, MS_MOVQ_F32(6.0f)); + res1 = MS_MINQ_F32(res1, MS_MOVQ_F32(6.0f)); + } + if (channel >= 4) { + MS_STQ_F32(cur_dst, res0); + MS_STQ_F32(cur_dst + ori_channel, res1); + } else { + for (int i = 0; i < channel; i++) { + cur_dst[i] = MS_F32X4_GETI(res0, i); + cur_dst[ori_channel + i] = MS_F32X4_GETI(res1, i); + } + } + cur_dst += 2 * ori_channel; + } + if (ow < width) { + MS_FLOAT32X4 acc0 = MS_MULQ_F32(MS_LDQ_F32(line0), g00); + MS_FLOAT32X4 acc1 = MS_MULQ_F32(MS_LDQ_F32(line0 + 4), g01); + MS_FLOAT32X4 acc2 = MS_MULQ_F32(MS_LDQ_F32(line0 + 8), g02); + line0 += 16; + acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line1), g10); + acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line1 + 4), g11); + acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line1 + 8), g12); + line1 += 16; + acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line2), g20); + acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line2 + 4), g21); + acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line2 + 8), g22); + line2 += 16; + MS_FLOAT32X4 res0 = MS_ADDQ_F32(acc0, MS_ADDQ_F32(acc2, acc1)); + res0 = MS_ADDQ_F32(res0, bias); + if (relu || relu6) { + res0 = MS_MAXQ_F32(res0, MS_MOVQ_F32(0.0f)); + } + if (relu6) { + res0 = MS_MINQ_F32(res0, MS_MOVQ_F32(6.0f)); + } + if (channel >= 4) { + MS_STQ_F32(cur_dst, res0); + } else { + for (int i = 0; i < channel; i++) { + cur_dst[i] = MS_F32X4_GETI(res0, i); + } + } + } + dst += 4; + } +} +#endif + +void ConvDw3x3(float *output_data, float *buffer, const float *input_data, const float *weight_data, + const float *bias_data, const ConvParameter *conv_param, int start_oh, int end_oh) { + int units = UP_DIV(conv_param->output_w_, C2NUM); + int c4 = UP_ROUND(conv_param->input_channel_, C4NUM); + int line = conv_param->input_channel_ * conv_param->input_w_; + + bool relu = conv_param->act_type_ == ActType_Relu; + bool relu6 = conv_param->act_type_ == ActType_Relu6; + + for (int b = 0; b < conv_param->output_batch_; b++) { + const float *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_; + float *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_; + float *line0 = buffer; + float *line1 = buffer + units * c4 * C4NUM; + float *line2 = buffer + units * c4 * C8NUM; + float *lines[3] = {line0, line1, line2}; + int oh = start_oh; + if (oh == 0) { + // input trans + ConvDw3x3InitTop(src, lines, conv_param->output_w_, conv_param->input_channel_); + } else { + // input trans + ConvDw3x3InitRow(src + oh * line, lines, conv_param->output_w_, conv_param->input_channel_); + } + // dst calc and trans + ConvDw3x3Line(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_, + relu, relu6); + for (oh = start_oh + 1; oh < end_oh - 1; oh++) { + // input trans + ConvDw3x3Row(src + oh * line + line, lines, conv_param->output_w_, conv_param->input_channel_); + // dst calc and trans + ConvDw3x3Line(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_, + relu, relu6); + } + if (oh == conv_param->output_h_ - 1) { + // input trans + ConvDw3x3Bottom(lines, conv_param->output_w_, conv_param->input_channel_); + } else { + // input trans + ConvDw3x3Row(src + oh * line + line, lines, conv_param->output_w_, conv_param->input_channel_); + } + // dst calc and trans + ConvDw3x3Line(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_, + relu, relu6); + } +} +#endif + +/*conv depthwise indirect buffer fp32 begin*/ +bool CheckConvDwUseIndirectBuffer(const ConvParameter *conv_param) { + bool use_indirect = (conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3) || + (conv_param->kernel_h_ == 5 && conv_param->kernel_w_ == 5); + return use_indirect; +} + +void ConvDwInitIndirection(float **indirect_buffer, float *src, float *zero_ptr, const ConvParameter *conv_param, + int step_h, int step_w) { +#ifdef ENABLE_AVX + int div = C8NUM; +#else + int div = C4NUM; +#endif + + int ic_div = UP_DIV(conv_param->input_channel_, div) * div; + for (int b = 0; b < conv_param->output_batch_; b++) { + float **indirect = indirect_buffer + b * conv_param->output_h_ * step_h; + float *input = src + b * conv_param->input_h_ * conv_param->input_w_ * ic_div; + for (int oh = 0; oh < conv_param->output_h_; oh++) { + for (int kh = 0; kh < conv_param->kernel_h_; kh++) { + int ih = oh * conv_param->stride_h_ + kh * conv_param->dilation_h_ - conv_param->pad_u_; + if (ih < conv_param->input_h_ && ih >= 0) { + for (int ow = 0; ow < conv_param->output_w_; ow++) { + for (int kw = 0; kw < conv_param->kernel_w_; kw++) { + int iw = ow * conv_param->stride_w_ + kw * conv_param->dilation_w_ - conv_param->pad_l_; + int index = oh * step_h + ow * step_w * conv_param->kernel_h_ + kw * conv_param->kernel_h_ + kh; + if (iw < conv_param->input_w_ && iw >= 0) { + indirect[index] = input + (ih * conv_param->input_w_ + iw) * ic_div; + } else { + indirect[index] = zero_ptr; + } + } + } + } else { + for (int ow = 0; ow < conv_param->output_w_; ow++) { + for (int kw = 0; kw < conv_param->kernel_w_; kw++) { + int index = oh * step_h + ow * step_w * conv_param->kernel_h_ + kw * conv_param->kernel_h_ + kh; + indirect[index] = zero_ptr; + } + } + } + } + } + } +} + +#if !defined(ENABLE_ARM64) && !defined(ENABLE_AVX) +void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels, + int output_width, int input_stride, bool relu, bool relu6, int kernel) { + do { + float **in = input; + size_t c = (size_t)channels; + const float *w = weights; + float *out = output; + memcpy(out, bias, channels * (int)sizeof(float)); + for (; c >= C4NUM; c -= C4NUM) { + for (int i = 0; i < C4NUM; i++) { + for (int k = 0; k < kernel; k++) { + out[i] += in[k][i] * w[i + k * C4NUM]; + } + } + w += kernel * C4NUM; + out += C4NUM; + for (int k = 0; k < kernel; k++) { + in[k] += C4NUM; + } + } + for (int i = 0; i < c; i++) { + for (int k = 0; k < kernel; k++) { + out[i] += in[k][i] * w[i + k * C4NUM]; + } + } + if (relu) { + Fp32Relu(output, channels, output); + } + if (relu6) { + Fp32Relu6(output, channels, output); + } + output += channels; + input = input + input_stride; + } while (--output_width != 0); +} +#endif + +#ifdef ENABLE_ARM64 +void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels, + int output_width, int input_stride, bool relu, bool relu6, int kernel) { + if (kernel == 9) { + ConvDwFp32Indirect3x3(output, input, weights, bias, channels, output_width, input_stride * sizeof(float *), relu, + relu6); + } else if (kernel == 25) { + ConvDwFp32Indirect5x5(output, input, weights, bias, channels, output_width, input_stride * sizeof(float *), relu, + relu6); + } +} +#endif + +#ifdef ENABLE_AVX +void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels, + int output_width, int input_stride, bool relu, bool relu6, int kernel) { + if (kernel == 9) { + ConvDwFp32Avx3x3(output, input, weights, bias, channels, output_width, input_stride * sizeof(float *), relu, relu6); + } else if (kernel == 25) { + ConvDwFp32Avx5x5(output, input, weights, bias, channels, output_width, input_stride * sizeof(float *), relu, relu6); + } +} +#endif + +void ConvDwIndirection(float *output_data, float **indirect_buffer, const float *weight_data, const float *bias_data, + float *zero_ptr, const ConvParameter *conv_param, int task_id) { + if (conv_param->thread_num_ == 0) { + return; + } + int step_w = conv_param->dilation_w_ == 1 ? conv_param->stride_w_ : conv_param->kernel_w_; + int step_h = + (conv_param->kernel_h_ * conv_param->kernel_w_) + (conv_param->output_w_ - 1) * step_w * conv_param->kernel_h_; + int input_stride = conv_param->kernel_h_ * step_w; + + bool relu = conv_param->act_type_ == ActType_Relu; + bool relu6 = conv_param->act_type_ == ActType_Relu6; + + int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_); + int h_start = h_step * task_id; + int h_end = MSMIN(h_start + h_step, conv_param->output_h_); + + for (int b = 0; b < conv_param->output_batch_; b++) { + float **indirect_b = indirect_buffer + b * conv_param->output_h_ * step_h; + float *outout_b = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_; + for (int oh = h_start; oh < h_end; oh++) { + float **indirect = indirect_b + oh * step_h; + float *output_h = outout_b + oh * conv_param->output_w_ * conv_param->output_channel_; + if (conv_param->kernel_w_ == 3) { + ConvDwFp32IndirectRow(output_h, indirect, weight_data, bias_data, conv_param->output_channel_, + conv_param->output_w_, input_stride, relu, relu6, 9); + } else if (conv_param->kernel_w_ == 5) { + ConvDwFp32IndirectRow(output_h, indirect, weight_data, bias_data, conv_param->output_channel_, + conv_param->output_w_, input_stride, relu, relu6, 25); + } + } + } +} +/*conv depthwise indirect buffer fp32 end*/ + +/*deconv depthwise fp32 begin*/ +void DeconvDwBorderPixel(float *dst, const float *src, const float *weight, int height, int width, int in_kh_step, + int in_kw_step, int kernel_w_step) { + float *dst_kh = dst; + const float *weight_kh = weight; + for (int kh = 0; kh < height; kh++) { + float *dst_kw = dst_kh; + const float *weight_kw = weight_kh; + for (int kw = 0; kw < width; kw++) { +#ifdef ENABLE_ARM64 + float32x4_t src_4 = vld1q_f32(src); + float32x4_t weight_4 = vld1q_f32(weight_kw); + float32x4_t dst_4 = vld1q_f32(dst_kw); + dst_4 = vfmaq_f32(dst_4, src_4, weight_4); + vst1q_f32(dst_kw, dst_4); +#else + for (int c = 0; c < C4NUM; c++) { + dst_kw[c] += src[c] * weight_kw[c]; + } +#endif + dst_kw += in_kw_step; + weight_kw += C4NUM; + } // kernel_w loop + dst_kh += in_kh_step; + weight_kh += kernel_w_step; + } // kernel_h loop +} + +void DeconvDwBorder(float *dst, const float *src, const float *weight, int top, int bottom, int left, int right, + const ConvParameter *conv_param, const SlidingWindowParam *sliding) { + if (conv_param->dilation_h_ == 0 || conv_param->dilation_w_ == 0) { + return; + } + const float *src_h = src + top * sliding->out_h_step_; + for (int ih = top; ih < bottom; ih++) { + int oh = ih * conv_param->stride_h_ - conv_param->pad_u_; + int start_kh = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_)); + float *dst_h = dst + oh * sliding->in_h_step_; + + const float *src_kernel = src_h + left * sliding->block_channel_; + for (int iw = left; iw < right; iw++) { + int ow = iw * conv_param->stride_w_ - conv_param->pad_l_; + int start_kw = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_)); + int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_)); + float *dst_w = dst_h + ow * sliding->block_channel_; + + const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM; + float *dst_kernel = dst_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_; +#ifdef ENABLE_ARM64 + DeconvDwFp32Border(dst_kernel, src_kernel, weight_kernel, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_ * sizeof(float), sliding->in_kw_step_ * sizeof(float), + conv_param->kernel_w_ * C4NUM * sizeof(float)); +#else + DeconvDwBorderPixel(dst_kernel, src_kernel, weight_kernel, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_ * C4NUM); +#endif + src_kernel += sliding->block_channel_; + } // width loop + src_h += sliding->out_h_step_; + } // height loop +} + +#ifndef ENABLE_ARM64 +void DeconvDwCenter(float *dst, const float *src, const float *weight, int height, int width, int kernel_h, + int kernel_w, int out_h_step, int block_channel, int in_sh_step, int in_sw_step, int in_kh_step, + int in_kw_step) { + float *dst_h = dst; + const float *src_h = src; + for (int oh = 0; oh < height; oh++) { + float *dst_w = dst_h; + const float *src_w = src_h; + for (int ow = 0; ow < width; ow++) { + float *dst_kh = dst_w; + const float *weight_kh = weight; + for (int kh = 0; kh < kernel_h; kh++) { + float *dst_kw = dst_kh; + const float *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++) { + for (int c = 0; c < C4NUM; c++) { + dst_kw[c] += src_w[c] * weight_kw[c]; + } + dst_kw += in_kw_step; + weight_kw += C4NUM; + } // kernel_w loop + dst_kh += in_kh_step; + weight_kh += kernel_w * C4NUM; + } // kernel_h loop + dst_w += in_sw_step; + src_w += block_channel; + } // dst_width loop + dst_h += in_sh_step; + src_h += out_h_step; + } // dst_height loop +} +#endif + +void DeconvDwPost(float *dst, const float *bias, int block_channel, const ConvParameter *conv_param) { + bool relu = conv_param->act_type_ == ActType_Relu; + bool relu6 = conv_param->act_type_ == ActType_Relu6; + float *dst_k = dst; + for (int k = 0; k < conv_param->output_h_ * conv_param->output_w_; k++) { + for (int c = 0; c < C4NUM; c++) { + dst_k[c] += bias[c]; + dst_k[c] = (relu) ? (MSMAX(0, dst_k[c])) : (dst_k[c]); + dst_k[c] = (relu6) ? (MSMIN(6, MSMAX(0, dst_k[c]))) : (dst_k[c]); + } + dst_k += block_channel; + } +} + +// deconv depthwise fp32: sliding window +void DeconvDwSWFp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id) { + const float *src = input_data; + float *dst = output_data; + if (conv_param->thread_num_ == 0) { + return; + } + for (int b = 0; b < conv_param->output_batch_; b++) { + for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) { + const float *src_data = src + oc * C4NUM; + float *dst_data = dst + oc * C4NUM; + const float *weight = weight_data + oc * sliding->kernel_step_; + const float *bias = bias_data + oc * C4NUM; + DeconvDwBorder(dst_data, src_data, weight, 0, sliding->top_, 0, conv_param->input_w_, conv_param, sliding); + DeconvDwBorder(dst_data, src_data, weight, sliding->bottom_, conv_param->input_h_, 0, conv_param->input_w_, + conv_param, sliding); + DeconvDwBorder(dst_data, src_data, weight, sliding->top_, sliding->bottom_, 0, sliding->left_, conv_param, + sliding); + DeconvDwBorder(dst_data, src_data, weight, sliding->top_, sliding->bottom_, sliding->right_, conv_param->input_w_, + conv_param, sliding); + + if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) { + int oh_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_; + int oh_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_; + float *out_t = dst_data + oh_h_start * sliding->in_h_step_ + oh_w_start * sliding->block_channel_; + const float *in_t = src_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_; + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + DeconvDwFp32Center(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, + conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float), + sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float), + sliding->in_sw_step_ * sizeof(float), sliding->in_kh_step_ * sizeof(float), + sliding->in_kw_step_ * sizeof(float)); +#else + DeconvDwCenter(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, + conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_, + sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_); +#endif + } + DeconvDwPost(dst_data, bias, sliding->block_channel_, conv_param); + } // output C4 loop + src += sliding->out_step_; + dst += sliding->in_step_; + } // batch loop + // output nhwc4 +} +/*deconv depthwise fp32 end*/ + +#ifdef ENABLE_AVX +void DepthwiseBorderAvxFp32(float *dst, const float *src, const float *weight, const float *bias, int top, int left, + int right, const ConvParameter *conv_param, const SlidingWindowParam *sw_param, + const DepthwiseSWKernel kernel, int act_type, int ow_bock, int oc_block) { + // dw border compate + int ih = top * conv_param->stride_h_ - conv_param->pad_u_; + int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_)); + const float *src_h = src + ih * sw_param->in_h_step_; + float *dst_kernel = dst + left * sw_param->block_channel_; + for (int ow = left; ow < right; ow += ow_bock) { + int iw = ow * conv_param->stride_w_ - conv_param->pad_l_; + int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_)); + int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_)); + const float *src_w = src_h + iw * sw_param->block_channel_; + const float *src_kernel = src_w + start_kh * sw_param->in_kh_step_ + start_kw * sw_param->in_kw_step_; + const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C8NUM * oc_block; + kernel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, act_type, ow_bock, + oc_block, sw_param->block_channel_, sw_param->in_kw_step_, sw_param->in_kh_step_, sw_param->in_sw_step_, + (conv_param->kernel_w_ - end_kw + start_kw) * C8NUM * oc_block); + dst_kernel += ow_bock * sw_param->block_channel_; + } // width loop +} + +void DepthwiseSWAvxFp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, const SlidingWindowParam *sw_param, int task_id) { + int oh_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_); + int oh_start = oh_step * task_id; + int oh_end = MSMIN(oh_start + oh_step, conv_param->output_h_); + if (oh_start >= oh_end) { + return; + } + // depthwise sw in x86 avx instructions + int oc_tile_ = C8NUM; // oc in algin to C8NUM in x86_64_avx + int act_type = 0; + if (conv_param->act_type_ == ActType_Relu6) { + act_type += 1; + } + if (conv_param->act_type_ == ActType_Relu || conv_param->act_type_ == ActType_Relu6) { + act_type += 2; + } + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int output_w = conv_param->output_w_; + int oc_algin = sw_param->block_channel_; + int oc_num = sw_param->c_block_; + int in_step = sw_param->in_step_; + int out_step = sw_param->out_step_; + int in_sw_step = sw_param->in_sw_step_; + int in_kw_step = sw_param->in_kw_step_; + int in_kh_step = sw_param->in_kh_step_; + int in_sh_step = sw_param->in_sh_step_; + int out_right = sw_param->right_; + int out_left = sw_param->left_; + int out_top = sw_param->top_; + int out_bottom = sw_param->bottom_; + int kernel_step = sw_param->kernel_step_; + int out_h_step = sw_param->out_h_step_; + int in_h_start = out_top * conv_param->stride_h_ - conv_param->pad_u_; + int in_w_start = out_left * conv_param->stride_w_ - conv_param->pad_l_; + int in_start = in_h_start * sw_param->in_h_step_ + in_w_start * oc_algin; + const int ow_block_num[4] = {8, 4, 4, 3}; + const DepthwiseSWKernel kernel[4][2] = {{DepthwiseSW1x8Kernel, DepthwiseSW8x8Kernel}, + {DepthwiseSW1x16Kernel, DepthwiseSW4x16Kernel}, + {DepthwiseSW1x24Kernel, DepthwiseSW4x24Kernel}, + {DepthwiseSW1x32Kernel, DepthwiseSW3x32Kernel}}; + for (int b = 0; b < conv_param->output_batch_; b++) { + for (int oh = oh_start; oh < oh_end; ++oh) { + float *dst_oh = output_data + oh * out_h_step; + const float *src_h = input_data + in_start + (oh - out_top) * in_sh_step; + int oc_block = 0; + const float *bias = bias_data; + for (int oc = 0; oc < oc_num; oc += oc_block) { + oc_block = MSMIN(C4NUM, oc_num - oc); // 4 3 2 1 + int oc_step = oc * oc_tile_; + const float *weight = weight_data + oc * kernel_step; + if (bias != NULL) { + bias = bias_data + oc_step; + } + float *dst_w = dst_oh + oc_step; + const DepthwiseSWKernel kernel_border = kernel[oc_block - 1][0]; + if (oh < out_top || oh >= out_bottom) { // oh in up or down border + DepthwiseBorderAvxFp32(dst_w, input_data + oc_step, weight, bias, oh, 0, output_w, conv_param, sw_param, + kernel_border, act_type, 1, oc_block); + } else { // oh in center + // ow in right + DepthwiseBorderAvxFp32(dst_w, input_data + oc_step, weight, bias, oh, 0, out_left, conv_param, sw_param, + kernel_border, act_type, 1, oc_block); + // ow in center + const float *src_w = src_h + oc_step; + int ow_block = ow_block_num[oc_block - 1]; // 8 4 4 3 + for (int ow = out_left; ow < out_right; ow += ow_block) { // left ~ right + if (ow_block > out_right - ow) { // ow is not enough and process one ow + ow_block = 1; + } + kernel[oc_block - 1][ow_block / ow_block_num[oc_block - 1]]( + dst_w + ow * oc_algin, src_w, weight, bias, kernel_h, kernel_w, act_type, ow_block, oc_block, oc_algin, + in_kw_step, in_kh_step, in_sw_step, 0); + src_w += ow_block * in_sw_step; + } + // ow in left + DepthwiseBorderAvxFp32(dst_w, input_data + oc_step, weight, bias, oh, out_right, output_w, conv_param, + sw_param, kernel_border, act_type, 1, oc_block); + } + } + } // output h loop + input_data += in_step; + output_data += out_step; + } // batch loop +} + +#ifdef ENABLE_DEBUG +void DepthwiseSWWxKKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) { + __m256 dst_data[12]; + __m256 src_data; + const float *src_kh[12]; + const float *src_kw[12]; + __m256 weight_data[4]; + for (int i = 0; i < ow_block; ++i) { + if (bias != NULL) { + for (int j = 0; j < oc_block; ++j) { + dst_data[i * oc_block + j] = _mm256_loadu_ps(bias + j * 8); + } + } else { + for (int j = 0; j < oc_block; ++j) { + dst_data[i * oc_block + j] = _mm256_set1_ps(0.0f); + } + } + src_kh[i] = src + i * in_sw_step; + src_kw[i] = NULL; + } + const float *weight_kernel = weight; + for (int kh = 0; kh < kernel_h; kh++) { + for (int i = 0; i < ow_block; ++i) { + src_kw[i] = src_kh[i]; + } + for (int kw = 0; kw < kernel_w; kw++) { + for (int j = 0; j < oc_block; ++j) { + weight_data[j] = _mm256_loadu_ps(weight_kernel + j * C8NUM); + } + for (int i = 0; i < ow_block; ++i) { // loop ow + for (int j = 0; j < oc_block; ++j) { + src_data = _mm256_loadu_ps(src_kw[i] + j * C8NUM); + dst_data[i * oc_block + j] += src_data * weight_data[j]; + } + } + for (int i = 0; i < ow_block; ++i) { + src_kw[i] += in_kw_step; // ic8 * dilation_w + } + weight_kernel += oc_block * C8NUM; + } // kernel_w loop + weight_kernel += kw_remainder; + for (int i = 0; i < ow_block; ++i) { + src_kh[i] += in_kh_step; // + } + } // kernel_h loop + // add bias and relu + for (int i = 0; i < ow_block; ++i) { + for (int j = 0; j < oc_block; ++j) { + if (0x1 & act_flag) { // relu6 + dst_data[i * oc_block + j] = _mm256_min_ps(dst_data[i * oc_block + j], _mm256_set1_ps(6.0f)); + } + if (0x2 & act_flag) { // relu + dst_data[i * oc_block + j] = _mm256_max_ps(dst_data[i * oc_block + j], _mm256_set1_ps(0.0f)); + } + _mm256_storeu_ps(dst + i * oc_algin + j * C8NUM, dst_data[i * oc_block + j]); + } + } +} +#endif + +void DepthwiseSW3x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) { + in_kh_step *= sizeof(float); + in_sw_step *= sizeof(float); + in_kw_step *= sizeof(float); + oc_algin *= sizeof(float); + kw_remainder *= sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "vmovups 0x60(%2), %%ymm3\n" + "vmovups (%2), %%ymm4\n" + "vmovups 0x20(%2), %%ymm5\n" + "vmovups 0x40(%2), %%ymm6\n" + "vmovups 0x60(%2), %%ymm7\n" + "vmovups (%2), %%ymm8\n" + "vmovups 0x20(%2), %%ymm9\n" + "vmovups 0x40(%2), %%ymm10\n" + "vmovups 0x60(%2), %%ymm11\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // LoopW + + "vmovups (%1), %%ymm12\n" + "vmovups (%%rcx), %%ymm13\n" + "vmovups (%%rcx, %7), %%ymm14\n" + "vmovups (%%rcx, %7, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + + "vmovups 0x20(%1), %%ymm12\n" + "vmovups 0x20(%%rcx), %%ymm13\n" + "vmovups 0x20(%%rcx, %7), %%ymm14\n" + "vmovups 0x20(%%rcx, %7, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" + + "vmovups 0x40(%1), %%ymm12\n" + "vmovups 0x40(%%rcx), %%ymm13\n" + "vmovups 0x40(%%rcx, %7), %%ymm14\n" + "vmovups 0x40(%%rcx, %7, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + + "vmovups 0x60(%1), %%ymm12\n" + "vmovups 0x60(%%rcx), %%ymm13\n" + "vmovups 0x60(%%rcx, %7), %%ymm14\n" + "vmovups 0x60(%%rcx, %7, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + "addq $128, %1\n" + + "addq %5, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %6, %0\n" // in_kh_step + "addq %8, %1\n" + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), // 5 + "r"(in_kh_step), "r"(in_sw_step), "r"(kw_remainder) // 8 + : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", + "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm8, %%ymm8\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + "vmaxps %%ymm12, %%ymm11, %%ymm11\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm8, %%ymm8\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + "vminps %%ymm14, %%ymm11, %%ymm11\n" + + "0:\n" + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + "vmovups %%ymm2, 0x40(%2)\n" + "vmovups %%ymm3, 0x60(%2)\n" + "vmovups %%ymm4, (%2, %1, 1)\n" + "vmovups %%ymm5, 0x20(%2, %1, 1)\n" + "vmovups %%ymm6, 0x40(%2, %1, 1)\n" + "vmovups %%ymm7, 0x60(%2, %1, 1)\n" + "vmovups %%ymm8, (%2, %1, 2)\n" + "vmovups %%ymm9, 0x20(%2, %1, 2)\n" + "vmovups %%ymm10, 0x40(%2, %1, 2)\n" + "vmovups %%ymm11, 0x60(%2, %1, 2)\n" + : + : "a"(act_flag), "r"(oc_algin), "r"(dst) + : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm14"); +} + +void DepthwiseSW1x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) { + in_kh_step *= sizeof(float); + in_kw_step *= sizeof(float); + oc_algin *= sizeof(float); + kw_remainder *= sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "vmovups 0x60(%2), %%ymm3\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // Loopw + "vmovups (%%rcx), %%ymm4\n" + "vmovups 0x20(%%rcx), %%ymm5\n" + "vmovups 0x40(%%rcx), %%ymm6\n" + "vmovups 0x60(%%rcx), %%ymm7\n" + // Weight data is loaded directly from memory instead of into registers for calculation. + "vfmadd231ps (%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 0x20(%1), %%ymm5, %%ymm1\n" + "vfmadd231ps 0x40(%1), %%ymm6, %%ymm2\n" + "vfmadd231ps 0x60(%1), %%ymm7, %%ymm3\n" + "addq $128, %1\n" + + "addq %5, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %6, %0\n" // in_kh_step + "addq %7, %1\n" + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), // 5 + "r"(in_kh_step), "r"(kw_remainder) // 7 + : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + + "0:\n" + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + "vmovups %%ymm2, 0x40(%2)\n" + "vmovups %%ymm3, 0x60(%2)\n" + : + : "a"(act_flag), "r"(oc_algin), "r"(dst) + : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm12", "%ymm14"); +} + +void DepthwiseSW4x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) { + in_kh_step *= sizeof(float); + in_kw_step *= sizeof(float); + in_sw_step *= sizeof(float); + kw_remainder *= sizeof(float); + size_t src_3_step = 3 * in_sw_step; + float *dst_3 = dst + 3 * oc_algin; + oc_algin *= sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + // We need to copy ymm0 to ymm3 to reduce IO time, but unfortunately I didn't find the corresponding instruction. + "vmovups (%2), %%ymm3\n" + "vmovups 0x20(%2), %%ymm4\n" + "vmovups 0x40(%2), %%ymm5\n" + "vmovups (%2), %%ymm6\n" + "vmovups 0x20(%2), %%ymm7\n" + "vmovups 0x40(%2), %%ymm8\n" + "vmovups (%2), %%ymm9\n" + "vmovups 0x20(%2), %%ymm10\n" + "vmovups 0x40(%2), %%ymm11\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // LoopW + "vmovups (%1), %%ymm12\n" + "vmovups (%%rcx), %%ymm13\n" + "vmovups (%%rcx, %7, 1), %%ymm14\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm3\n" + "vmovups (%%rcx, %7, 2), %%ymm15\n" + "vmovups (%%rcx, %9), %%ymm13\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm9\n" + + "vmovups 0x20(%1), %%ymm12\n" + "vmovups 0x20(%%rcx), %%ymm13\n" + "vmovups 0x20(%%rcx, %7, 1), %%ymm14\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vmovups 0x20(%%rcx, %7, 2), %%ymm15\n" + "vmovups 0x20(%%rcx, %9), %%ymm13\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm10\n" + + "vmovups 0x40(%1), %%ymm12\n" + "vmovups 0x40(%%rcx), %%ymm13\n" + "vmovups 0x40(%%rcx, %7, 1), %%ymm14\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" + "vmovups 0x40(%%rcx, %7, 2), %%ymm15\n" + "vmovups 0x40(%%rcx, %9), %%ymm13\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm11\n" + + "addq $96, %1\n" + "addq %5, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %6, %0\n" // in_kh_step + "addq %8, %1\n" // border in sw need to add remainder data + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), // 5 + "r"(in_kh_step), "r"(in_sw_step), "r"(kw_remainder), "r"(src_3_step) // 9 + : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", + "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm8, %%ymm8\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + "vmaxps %%ymm12, %%ymm11, %%ymm11\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm8, %%ymm8\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + "vminps %%ymm14, %%ymm11, %%ymm11\n" + + "0:\n" + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + "vmovups %%ymm2, 0x40(%2)\n" + "vmovups %%ymm3, (%2, %1, 1)\n" + "vmovups %%ymm4, 0x20(%2, %1, 1)\n" + "vmovups %%ymm5, 0x40(%2, %1, 1)\n" + "vmovups %%ymm6, (%2, %1, 2)\n" + "vmovups %%ymm7, 0x20(%2, %1, 2)\n" + "vmovups %%ymm8, 0x40(%2, %1, 2)\n" + "vmovups %%ymm9, (%3)\n" // dst+3 + "vmovups %%ymm10, 0x20(%3)\n" + "vmovups %%ymm11, 0x40(%3)\n" + : + : "a"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3) + : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm14"); +} + +void DepthwiseSW1x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) { + in_kh_step *= sizeof(float); + in_kw_step *= sizeof(float); + oc_algin *= sizeof(float); + kw_remainder *= sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // Loopw + "vmovups (%%rcx), %%ymm4\n" + "vmovups 0x20(%%rcx), %%ymm5\n" + "vmovups 0x40(%%rcx), %%ymm6\n" + // Weight data is loaded directly from memory instead of into registers for calculation. + "vfmadd231ps (%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 0x20(%1), %%ymm5, %%ymm1\n" + "vfmadd231ps 0x40(%1), %%ymm6, %%ymm2\n" + "addq $96, %1\n" + + "addq %5, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %6, %0\n" // in_kh_step + "addq %7, %1\n" // kw_remainder + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), // 5 + "r"(in_kh_step), "r"(kw_remainder) // 7 + : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm2", "%ymm4", "%ymm5", "%ymm6"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + + "0:\n" + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + "vmovups %%ymm2, 0x40(%2)\n" + : + : "a"(act_flag), "r"(oc_algin), "r"(dst) + : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm12", "%ymm14"); +} + +void DepthwiseSW4x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) { + in_kh_step *= sizeof(float); + in_kw_step *= sizeof(float); + in_sw_step *= sizeof(float); + kw_remainder *= sizeof(float); + size_t src_3_step = 3 * in_sw_step; + float *dst_3 = dst + 3 * oc_algin; + oc_algin *= sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + // We need to copy ymm0 to ymm3 to reduce IO time, but unfortunately I didn't find the corresponding instruction. + "vmovups (%2), %%ymm3\n" + "vmovups 0x20(%2), %%ymm4\n" + "vmovups (%2), %%ymm6\n" + "vmovups 0x20(%2), %%ymm7\n" + "vmovups (%2), %%ymm9\n" + "vmovups 0x20(%2), %%ymm10\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // LoopW + "vmovups (%1), %%ymm12\n" + "vmovups (%%rcx), %%ymm13\n" + "vmovups (%%rcx, %7, 1), %%ymm14\n" + "vmovups (%%rcx, %7, 2), %%ymm15\n" + "vmovups (%%rcx, %9), %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm2, %%ymm9\n" + + "vmovups 0x20(%1), %%ymm12\n" + "vmovups 0x20(%%rcx), %%ymm13\n" + "vmovups 0x20(%%rcx, %7, 1), %%ymm14\n" + "vmovups 0x20(%%rcx, %7, 2), %%ymm15\n" + "vmovups 0x20(%%rcx, %9), %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm2, %%ymm10\n" + + "addq $64, %1\n" + "addq %5, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %6, %0\n" // in_kh_step + "addq %8, %1\n" // border in sw need to add remainder data + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), // 5 + "r"(in_kh_step), "r"(in_sw_step), "r"(kw_remainder), "r"(src_3_step) // 9 + : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm3", "%ymm4", "%ymm6", "%ymm7", "%ymm9", "%ymm10", "%ymm12", "%ymm13", + "%ymm14", "%ymm15"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + + "0:\n" + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + "vmovups %%ymm3, (%2, %1, 1)\n" + "vmovups %%ymm4, 0x20(%2, %1, 1)\n" + "vmovups %%ymm6, (%2, %1, 2)\n" + "vmovups %%ymm7, 0x20(%2, %1, 2)\n" + "vmovups %%ymm9, (%3)\n" // dst+3 + "vmovups %%ymm10, 0x20(%3)\n" + : + : "a"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3) + : "%ecx", "%ymm0", "%ymm1", "%ymm3", "%ymm4", "%ymm6", "%ymm7", "%ymm9", "%ymm10", "%ymm12", "%ymm14"); +} + +void DepthwiseSW1x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) { + in_kh_step *= sizeof(float); + in_kw_step *= sizeof(float); + oc_algin *= sizeof(float); + kw_remainder *= sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // Loopw + "vmovups (%%rcx), %%ymm4\n" + "vmovups 0x20(%%rcx), %%ymm5\n" + // Weight data is loaded directly from memory instead of into registers for calculation. + "vfmadd231ps (%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 0x20(%1), %%ymm5, %%ymm1\n" + "addq $64, %1\n" + + "addq %5, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %6, %0\n" // in_kh_step + "addq %7, %1\n" // kw_remainder + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), // 5 + "r"(in_kh_step), "r"(kw_remainder) // 7 + : "%rcx", "%rsi", "%ymm0", "%ymm1", "%ymm4", "%ymm5"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + + "0:\n" + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + : + : "a"(act_flag), "r"(oc_algin), "r"(dst) + : "%ecx", "%ymm0", "%ymm1", "%ymm12", "%ymm14"); +} + +void DepthwiseSW8x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) { + in_kh_step *= sizeof(float); + in_sw_step *= sizeof(float); + in_kw_step *= sizeof(float); + kw_remainder *= sizeof(float); + size_t src_3_step = 3 * in_sw_step; + float *dst_3 = dst + 3 * oc_algin; + float *dst_5 = dst + 5 * oc_algin; + oc_algin *= sizeof(float); + asm volatile( + "cmpq $0, %0\n" + "je 0f\n" + "vmovups (%0), %%ymm0\n" + "vmovups (%0), %%ymm1\n" + "vmovups (%0), %%ymm2\n" + "vmovups (%0), %%ymm3\n" + "vmovups (%0), %%ymm4\n" + "vmovups (%0), %%ymm5\n" + "vmovups (%0), %%ymm6\n" + "vmovups (%0), %%ymm7\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "1:\n" + : + : "r"(bias) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7"); + + asm volatile( + "LoopH:\n" + "movq %3, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "LoopW:\n" + "movq %%rcx, %%rax\n" + "vmovups (%1), %%ymm12\n" + "vmovups (%%rax), %%ymm13\n" + "vmovups (%%rax, %6), %%ymm14\n" + "vmovups (%%rax, %6, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "addq %7, %%rax\n" + "vmovups (%%rax), %%ymm13\n" + "vmovups (%%rax, %6), %%ymm14\n" + "vmovups (%%rax, %6, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n" + "addq %7, %%rax\n" + "vmovups (%%rax), %%ymm13\n" + "vmovups (%%rax, %6), %%ymm14\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + + "addq $32, %1\n" + "addq %4, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg LoopW\n" + + "addq %5, %0\n" // in_kh_step + "addq %8, %1\n" // border in sw need to add remainder data + "dec %2\n" + "jg LoopH\n" + : + : "r"(src), "r"(weight), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), "r"(in_kh_step), // 5 + "r"(in_sw_step), "r"(src_3_step), "r"(kw_remainder) // 8 + : "%rcx", "%rsi", "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm12", + "%ymm13", "%ymm14", "%ymm15"); + + asm volatile( + "and $0x3, %%eax\n" + "je Write\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + + "and $0x1, %%eax\n" + "je Write\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + + "Write:\n" + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, (%2, %1)\n" + "vmovups %%ymm2, (%2, %1, 2)\n" + "vmovups %%ymm3, (%3)\n" // dst_3 + "vmovups %%ymm4, (%2, %1, 4)\n" + "vmovups %%ymm5, (%4)\n" // dst_5 + "vmovups %%ymm6, (%4, %1, 1)\n" + "vmovups %%ymm7, (%4, %1, 2)\n" + : + : "a"(act_flag), "r"(oc_algin), "r"(dst), "r"(dst_3), "r"(dst_5) + : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm12", "%ymm14"); +} + +void DepthwiseSW1x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder) { + in_kh_step *= sizeof(float); + in_kw_step *= sizeof(float); + oc_algin *= sizeof(float); + kw_remainder *= sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // Loopw + "vmovups (%%rcx), %%ymm4\n" + // Weight data is loaded directly from memory instead of into registers for calculation. + "vfmadd231ps (%1), %%ymm4, %%ymm0\n" + "addq $32, %1\n" + + "addq %5, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %6, %0\n" // in_kh_step + "addq %7, %1\n" // kw_remainder + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(in_kw_step), // 5 + "r"(in_kh_step), "r"(kw_remainder) // 7 + : "%rcx", "%rsi", "%ymm0", "%ymm4"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + + "0:\n" + "vmovups %%ymm0, (%2)\n" // dst_0 + : + : "a"(act_flag), "r"(oc_algin), "r"(dst) + : "%ecx", "%ymm0", "%ymm12", "%ymm14"); +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_depthwise_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_depthwise_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..d7f0aeaa28925a05246df41efc77a96acf32c5dc --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_depthwise_fp32.h @@ -0,0 +1,148 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_CONV_DEPTHWISE_H_ +#define MINDSPORE_NNACL_FP32_CONV_DEPTHWISE_H_ + +#include "nnacl/conv_parameter.h" +#include "nnacl/base/conv_common_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#ifndef ENABLE_ARM64 +void DepthwiseCenter(float *dst, const float *src, const float *weight, const float *bias, int height, int width, + int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step, int in_sw_step, + int in_kh_step, int in_kw_step, bool is_relu, bool is_relu6); +#endif + +int ConvDw(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, int task_id); + +int ConvDwAVX512(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, int task_id, ConvDwCalcParam *conv_dw_calc_param_); + +void ConvDwAVX512Fp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, size_t num_pixels, + size_t output_channel, size_t input_step, bool first_calc_flag, const float *bias); + +void InitSlidingParam(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block); + +void InitSlidingParamConv(SlidingWindowParam *sliding, const ConvParameter *conv_param, int input_block, + int weight_block); + +void AppendSlidingParamConv(SlidingWindowParam *sliding, const ConvParameter *conv_param, int in_block, + int weight_block); + +void InitSlidingParamConvDw(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block); + +void AppendSlidingParamConvDw(SlidingWindowParam *sliding, const ConvParameter *conv_param, int block); + +void ConvDwSWFp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id); + +bool CheckConvDwUse3X3(const ConvParameter *conv_param); + +bool CheckConvDwUseIndirectBuffer(const ConvParameter *conv_param); + +void ConvDwInitIndirection(float **indirect_buffer, float *src, float *zero_ptr, const ConvParameter *conv_param, + int step_h, int step_w); + +#ifdef ENABLE_ARM64 +void ConvDwFp32Indirect3x3(float *output, float **input, const float *weights, const float *bias, int channels, + int output_width, size_t input_stride, size_t relu, size_t relu6); + +void ConvDwFp32Indirect5x5(float *output, float **input, const float *weights, const float *bias, int channels, + int output_width, size_t input_stride, size_t relu, size_t relu6); +#endif + +#ifdef ENABLE_AVX +typedef void (*DepthwiseSWKernel)(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); + +void DepthwiseSW3x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); + +void DepthwiseSW1x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); + +void DepthwiseSW4x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); + +void DepthwiseSW1x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); + +void DepthwiseSW4x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); + +void DepthwiseSW1x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); + +void DepthwiseSW8x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); + +void DepthwiseSW1x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); + +void DepthwiseSWAvxFp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id); + +void DepthwiseBorderAvxFp32(float *dst, const float *src, const float *weight, const float *bias, int top, int left, + int right, const ConvParameter *conv_param, const SlidingWindowParam *sw_param, + const DepthwiseSWKernel kernel, int act_type, int ow_bock, int oc_block); + +void ConvDwFp32Avx3x3(float *output, float **input, const float *weights, const float *bias, size_t channels, + size_t output_width, size_t input_stride, size_t relu, size_t relu6); + +void ConvDwFp32Avx5x5(float *output, float **input, const float *weights, const float *bias, size_t channels, + size_t output_width, size_t input_stride, size_t relu, size_t relu6); +#ifdef ENABLE_DEBUG +void DepthwiseSWWxKKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder); +#endif +#endif + +#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX)) +void ConvDw3x3Line(float *dst, float **lines, const float *weight, const float *bias_data, int width, int ori_channel, + bool relu, bool relu6); +void ConvDw3x3(float *output_data, float *buffer, const float *input_data, const float *weight_data, + const float *bias_data, const ConvParameter *conv_param, int start_oh, int end_oh); +#endif + +void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels, + int output_width, int input_stride, bool relu, bool relu6, int kernel); + +void ConvDwIndirection(float *output_data, float **indirect_buffer, const float *weight_data, const float *bias_data, + float *zero_ptr, const ConvParameter *conv_param, int task_id); + +void DeconvDwSWFp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, + const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_CONV_DEPTHWISE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_im2col_avx512_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_im2col_avx512_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..1f85816de2fd3ec5be334512d30e3ca244fd2481 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_im2col_avx512_fp32.c @@ -0,0 +1,92 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/conv_im2col_avx512_fp32.h" +#include "nnacl/fp32/conv_im2col_fp32.h" +#include "nnacl/fp32/matmul_avx512_fp32.h" +#include "nnacl/intrinsics/ms_simd_avx512_instructions.h" + +// fp32 conv common +void ConvIm2ColAVX512Fp32(const float *input_data, float *packed_input, const float *packed_weight, + const float *bias_data, float *output_data, int task_id, const ConvParameter *conv_param, + int cal_num) { + if (conv_param->thread_num_ == 0) { + return; + } + int output_hw = conv_param->output_h_ * conv_param->output_w_; + int out_channel_align = UP_ROUND(conv_param->output_channel_, C16NUM); + + int block_per_thread = UP_DIV(UP_DIV(output_hw, cal_num), conv_param->thread_num_); + int start_block = block_per_thread * task_id; + int start_hw = start_block * cal_num; + int end_hw = MSMIN(output_hw, (start_block + block_per_thread) * cal_num); + if (start_hw >= end_hw) { + return; + } + int out_stride = out_channel_align * cal_num; + int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; + packed_input += task_id * deep * cal_num; + size_t input_size = deep * cal_num * sizeof(float); + + for (int b = 0; b < conv_param->input_batch_; b++) { + int in_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_; + int out_offset = b * out_channel_align * output_hw + start_hw * out_channel_align; + for (int i = start_hw; i < end_hw; i += cal_num, out_offset += out_stride) { + int real_cal_row = MSMIN(output_hw - i, cal_num); + memset(packed_input, 0, input_size); + Im2ColDataPackUnitFp32(input_data + in_offset, conv_param, packed_input, real_cal_row, i); + + float *gemm_output = output_data + out_offset; + MatMulAvx512Fp32(packed_input, packed_weight, gemm_output, bias_data, (size_t)conv_param->act_type_, deep, + out_channel_align, out_channel_align, real_cal_row); + } + } +} + +// fp32 conv common +void ConvIm2ColAVX512Fp32CutByBatch(const float *input_data, float *packed_input, const float *packed_weight, + const float *bias_data, float *output_data, int task_id, + const ConvParameter *conv_param, int cal_num) { + if (conv_param->thread_num_ == 0) { + return; + } + int output_hw = conv_param->output_h_ * conv_param->output_w_; + int out_channel_align = UP_ROUND(conv_param->output_channel_, C16NUM); + + int block_batch_per_thread = UP_DIV(conv_param->input_batch_, conv_param->thread_num_); + int start_batch = block_batch_per_thread * task_id; + int end_batch = MSMIN(conv_param->input_batch_, (start_batch + block_batch_per_thread)); + + int out_stride = out_channel_align * cal_num; + int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_; + packed_input += task_id * deep * cal_num; + + size_t input_size = deep * cal_num * sizeof(float); + + for (int b = start_batch; b < end_batch; b++) { + int in_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_; + int out_offset = b * out_channel_align * output_hw; + for (int i = 0; i < output_hw; i += cal_num, out_offset += out_stride) { + int real_cal_row = MSMIN(output_hw - i, cal_num); + memset(packed_input, 0, input_size); + Im2ColDataPackUnitFp32(input_data + in_offset, conv_param, packed_input, real_cal_row, i); + + float *gemm_output = output_data + out_offset; + MatMulAvx512Fp32(packed_input, packed_weight, gemm_output, bias_data, (size_t)conv_param->act_type_, deep, + out_channel_align, out_channel_align, real_cal_row); + } + } +} diff --git a/mindspore-lite/src/extendrt/kernel/cuda/unary.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_im2col_avx512_fp32.h similarity index 43% rename from mindspore-lite/src/extendrt/kernel/cuda/unary.h rename to mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_im2col_avx512_fp32.h index 2b25447355e1d0312c235a07a36a776d6898e0d0..1af6e4566f90a2636c38501981637c7bb76ac816 100644 --- a/mindspore-lite/src/extendrt/kernel/cuda/unary.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_im2col_avx512_fp32.h @@ -14,26 +14,25 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_CUDA_UNARY_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_CUDA_UNARY_H_ +#ifndef MINDSPORE_NNACL_FP32_CONV_IM2COL_AVX512_H_ +#define MINDSPORE_NNACL_FP32_CONV_IM2COL_AVX512_H_ -#include -#include -#include "src/extendrt/kernel/cuda/cuda_kernel.h" -#include "cuda_impl/cuda_class/unary_helper.h" +#include "nnacl/conv_parameter.h" -namespace mindspore::kernel { -class UnaryCudaKernel : public CudaKernel { - public: - UnaryCudaKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs, const lite::InnerContext *ctx) - : CudaKernel(parameter, inputs, outputs, ctx) {} - ~UnaryCudaKernel() override = default; // cudaFree - int Prepare() override; - int Run() override; +#ifdef __cplusplus +extern "C" { +#endif + +void ConvIm2ColAVX512Fp32(const float *input_data, float *packed_input, const float *packed_weight, + const float *bias_data, float *output_data, int task_id, const ConvParameter *conv_param, + int cal_num); - private: - std::shared_ptr> unary_helper_{nullptr}; -}; -} // namespace mindspore::kernel +void ConvIm2ColAVX512Fp32CutByBatch(const float *input_data, float *packed_input, const float *packed_weight, + const float *bias_data, float *output_data, int task_id, + const ConvParameter *conv_param, int cal_num); + +#ifdef __cplusplus +} #endif + +#endif // MINDSPORE_NNACL_FP32_CONV_IM2COL_AVX512_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_im2col_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_im2col_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..e88774b946efab0023d1cf3934123486ee07fc73 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_im2col_fp32.c @@ -0,0 +1,65 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/conv_im2col_fp32.h" + +void Im2ColDataPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, + int real_cal_num, int block_index) { + // input format : nhwc + int kernel_w = conv_param->kernel_w_; + int kernel_h = conv_param->kernel_h_; + int kernel_plane = kernel_h * kernel_w; + int dilation_w = conv_param->dilation_w_; + int dilation_h = conv_param->dilation_h_; + + int out_w = conv_param->output_w_; + if (dilation_w == 0 || dilation_h == 0 || out_w == 0) { + return; + } + int in_channel = conv_param->input_channel_; + int in_w = conv_param->input_w_; + for (int i = 0; i < real_cal_num; i++) { + int block_start = block_index + i; + int input_w = block_start % out_w * conv_param->stride_w_ - conv_param->pad_l_; + int input_h = block_start / out_w * conv_param->stride_h_ - conv_param->pad_u_; + if (conv_param->input_h_ - input_h < 0 || in_w - input_w < 0) { + continue; + } + int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w)); + int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w)); + int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h)); + int kh_e = MSMIN(kernel_h, UP_DIV(conv_param->input_h_ - input_h, dilation_h)); + int input_stride = (input_h * in_w + input_w) * in_channel; + if (dilation_w == 1 && dilation_h == 1) { + for (int j = kh_s; j < kh_e; j++) { + int input_plane_offset = (j * kernel_w + kw_s) * in_channel + i * in_channel * kernel_plane; + int input_y_stride = j * in_w * in_channel + input_stride; + int input_x_stride = input_y_stride + kw_s * in_channel; + memcpy(packed_input + input_plane_offset, input_data + input_x_stride, + (kw_e - kw_s) * in_channel * sizeof(float)); + } // kernel_h loop + } else { + for (int j = kh_s; j < kh_e; j++) { + int input_y_stride = j * dilation_h * in_w * in_channel + input_stride; + for (int k = kw_s; k < kw_e; ++k) { + int input_plane_offset = (j * kernel_w + k) * in_channel + i * in_channel * kernel_plane; + int input_x_stride = input_y_stride + k * dilation_w * in_channel; + memcpy(packed_input + input_plane_offset, input_data + input_x_stride, in_channel * sizeof(float)); + } + } // kernel_h loop + } + } // tile num loop +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_im2col_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_im2col_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..10eb2724b60b1e2983556e31367e11fe4319f18f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_im2col_fp32.h @@ -0,0 +1,33 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_CONV_IM2COL_H_ +#define MINDSPORE_NNACL_FP32_CONV_IM2COL_H_ + +#include "nnacl/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void Im2ColDataPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, + int real_cal_num, int block_index); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_CONV_IM2COL_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_sw.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_sw.h new file mode 100644 index 0000000000000000000000000000000000000000..e95a691122c885574730089b03f935bea243f02a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_sw.h @@ -0,0 +1,131 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_CONV_SW_H_ +#define MINDSPORE_NNACL_FP32_CONV_SW_H_ + +#define GenerateConvSWFunc(backend, oc_unit_num, row_num_list, kernel_list, compute_core, outer_compute) \ + void SWBorder##backend(float *dst, const float *src, const float *weight, const float *bias, int top, int bottom, \ + int left, int right, const ConvParameter *conv_param, const SlidingWindowParam *sw_param, \ + const SWConvKernel kernel, int act_type, int ow_bock, int oc_block, size_t write_mode) { \ + for (int oh = top; oh < bottom; oh++) { \ + int ih = oh * conv_param->stride_h_ - conv_param->pad_u_; \ + int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_)); \ + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_)); \ + const float *src_h = src + ih * sw_param->in_h_step_; \ + float *dst_kernel = dst + left * sw_param->out_w_step_; \ + for (int ow = left; ow < right; ow += ow_bock) { \ + int iw = ow * conv_param->stride_w_ - conv_param->pad_l_; \ + int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_)); \ + int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_)); \ + const float *src_w = src_h + iw * sw_param->ic_align_; \ + const float *src_kernel = src_w + start_kh * sw_param->in_kh_step_ + start_kw * sw_param->in_kw_step_; \ + const float *weight_kernel = \ + weight + (start_kh * conv_param->kernel_w_ + start_kw) * sw_param->ic_align_ * C8NUM * oc_block; \ + outer_compute dst_kernel += ow_bock * sw_param->out_w_step_; \ + } \ + dst += sw_param->out_h_step_; \ + } \ + } \ + \ + void ConvSW##backend##Fp32(const float *input_data, const float *packed_weight, const float *bias_data, \ + float *output_data, int task_id, ConvParameter *conv_param, \ + SlidingWindowParam *sw_param) { \ + int out_h = conv_param->output_h_; \ + int oh_step = UP_DIV(out_h, conv_param->thread_num_); \ + int oh_start = oh_step * task_id; \ + int oh_end = MSMIN(oh_start + oh_step, out_h); \ + if (oh_start >= oh_end) { \ + return; \ + } \ + int oc_tile_ = C8NUM; /* oc in algin to C8NUM in arm64 */ \ + int act_type = 0; \ + if (conv_param->act_type_ == ActType_Relu6) { \ + act_type += 1; \ + } \ + if (conv_param->act_type_ == ActType_Relu || conv_param->act_type_ == ActType_Relu6) { \ + act_type += 2; \ + } \ + int kernel_h = conv_param->kernel_h_; \ + int kernel_w = conv_param->kernel_w_; \ + int ic_algin = sw_param->ic_align_; \ + int in_sw_step = sw_param->in_sw_step_; \ + int in_kw_step = sw_param->in_kw_step_; \ + int in_kh_step = sw_param->in_kh_step_; \ + int in_sh_step = sw_param->in_sh_step_; \ + int out_h_step = sw_param->out_h_step_; \ + int out_c_step = sw_param->out_c_step_; \ + int out_w_step = sw_param->out_w_step_; \ + int out_block_step = sw_param->out_block_step_; \ + int kernel_step = sw_param->kernel_step_; \ + int in_step = sw_param->in_step_; \ + int out_step = sw_param->out_step_; \ + int c_block = sw_param->c_block_; \ + int top = sw_param->top_; \ + int left = sw_param->left_; \ + int right = sw_param->right_; \ + int bottom = sw_param->bottom_; \ + int stride_h = conv_param->stride_h_; \ + int stride_w = conv_param->stride_w_; \ + int out_w = conv_param->output_w_; \ + int pad_u = conv_param->pad_u_; \ + int pad_l = conv_param->pad_l_; \ + int in_h_step = sw_param->in_h_step_; \ + int out_batch = conv_param->output_batch_; \ + int in_h_start = top * stride_h - pad_u; \ + int in_w_start = left * stride_w - pad_l; \ + int center_step = in_h_start * in_h_step + in_w_start * ic_algin; \ + int write_mode = conv_param->out_format_; \ + row_num_list kernel_list for (int b = 0; b < out_batch; b++) { \ + for (int oh = oh_start; oh < oh_end; oh += 1) { \ + float *dst_oh = output_data + oh * out_h_step; \ + const float *src_h = input_data + center_step; \ + \ + int oc_block = 0; \ + const float *bias = bias_data; \ + for (int oc = 0; oc < c_block; oc += oc_block) { \ + oc_block = MSMIN(oc_unit_num, c_block - oc); \ + const float *weight = packed_weight + oc * kernel_step; \ + if (bias != NULL) { \ + bias = bias_data + oc * oc_tile_; \ + } \ + float *dst_oc = dst_oh + oc * out_c_step; \ + const SWConvKernel kernel_border = kernel[oc_block - 1][0]; \ + if (oh < top || oh >= bottom) { /* oh in up or down border */ \ + SWBorder##backend(dst_oc, input_data, weight, bias, oh, oh + 1, 0, out_w, conv_param, sw_param, \ + kernel_border, act_type, 1, oc_block, write_mode); \ + } else { /* oh in center */ \ + /* ow in right */ \ + SWBorder##backend(dst_oc, input_data, weight, bias, oh, oh + 1, 0, left, conv_param, sw_param, \ + kernel_border, act_type, 1, oc_block, write_mode); \ + /* ow in center */ \ + const float *src_w = src_h + (oh - top) * in_sh_step; \ + int ow_block = ow_block_num[oc_block - 1]; \ + for (int ow = left; ow < right; ow += ow_block) { /* left ~ right */ \ + ow_block = MSMIN(ow_block, right - ow); \ + compute_core src_w += ow_block * in_sw_step; \ + } \ + /* ow in left */ \ + SWBorder##backend(dst_oc, input_data, weight, bias, oh, oh + 1, right, out_w, conv_param, sw_param, \ + kernel_border, act_type, 1, oc_block, write_mode); \ + } \ + } \ + } /* output h loop */ \ + input_data += in_step; \ + output_data += out_step; \ + } /* batch loop */ \ + } +#endif // MINDSPORE_NNACL_FP32_CONV_SW_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_sw_arm64_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_sw_arm64_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..145271e53178004fb8fda5a0dcf8a20dd005f90d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_sw_arm64_fp32.c @@ -0,0 +1,99 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/conv_sw_arm64_fp32.h" +#include "nnacl/fp32/conv_sw.h" + +bool CheckArm64UseSWConv(const ConvParameter *conv_param) { + if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) { + return false; + } + if (conv_param->input_channel_ > C128NUM) { + return false; + } + if (conv_param->kernel_h_ > C5NUM || conv_param->kernel_w_ > C5NUM) { + return false; + } + if (conv_param->dilation_h_ != 1 || conv_param->dilation_w_ != 1) { + return false; + } + if (conv_param->stride_w_ > C3NUM) { + return false; + } + if (conv_param->input_h_ / conv_param->kernel_h_ < C48NUM || conv_param->input_w_ / conv_param->kernel_w_ < C48NUM) { + return false; + } + return true; +} + +typedef void (*SWConvKernel)(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t oc_algin, size_t ic_algin, size_t in_kw_step, + size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode); + +void SWConv1x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t oc_algin, size_t ic_algin, size_t in_kw_step, + size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode); + +void SWConv1x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t oc_algin, size_t ic_algin, size_t in_kw_step, + size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode); + +void SWConv2x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t oc_algin, size_t ic_algin, size_t in_kw_step, + size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode); + +void SWConv2x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t oc_algin, size_t ic_algin, size_t in_kw_step, + size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode); + +void SWConv3x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t oc_algin, size_t ic_algin, size_t in_kw_step, + size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode); + +void SWConv3x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t oc_algin, size_t ic_algin, size_t in_kw_step, + size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode); + +void SWConv4x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t oc_algin, size_t ic_algin, size_t in_kw_step, + size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode); + +void SWConv4x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t oc_algin, size_t ic_algin, size_t in_kw_step, + size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode); + +void SWConv5x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t oc_algin, size_t ic_algin, size_t in_kw_step, + size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode); + +void SWConv5x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t oc_algin, size_t ic_algin, size_t in_kw_step, + size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode); + +#define ROW_NUM_LIST const int ow_block_num[2] = {5, 5}; +#define KERNEL_LIST \ + const SWConvKernel kernel[2][5] = { \ + {SWConv1x8Kernel, SWConv2x8Kernel, SWConv3x8Kernel, SWConv4x8Kernel, SWConv5x8Kernel}, \ + {SWConv1x16Kernel, SWConv2x16Kernel, SWConv3x16Kernel, SWConv4x16Kernel, SWConv5x16Kernel}}; +#define COMPUTE_CORE \ + kernel[oc_block - 1][ow_block - 1](dst_oc + ow * out_w_step, src_w, weight, bias, kernel_h, kernel_w, act_type, \ + out_block_step, ic_algin, in_kw_step, in_kh_step, in_sw_step, 0, write_mode); +#define OUTER_COMPUTE \ + kernel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, act_type, \ + sw_param->out_block_step_, sw_param->ic_align_, sw_param->in_kw_step_, sw_param->in_kh_step_, \ + sw_param->in_sw_step_, (conv_param->kernel_w_ - end_kw + start_kw) * C8NUM * oc_block * sw_param->ic_align_, \ + write_mode); +GenerateConvSWFunc(Arm64, C2NUM, ROW_NUM_LIST, KERNEL_LIST, COMPUTE_CORE, OUTER_COMPUTE); diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_sw_arm64_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_sw_arm64_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..ad0d6b477e31552d8898b5cb3d05ddfbb0519824 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_sw_arm64_fp32.h @@ -0,0 +1,33 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_CONV_SW_ARM64_FP32_H_ +#define NNACL_FP32_CONV_SW_ARM64_FP32_H_ +#include "nnacl/pack.h" +#include "nnacl/op_base.h" +#include "nnacl/common_func.h" +#include "nnacl/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +bool CheckArm64UseSWConv(const ConvParameter *conv_param); +void ConvSWArm64Fp32(const float *input_data, const float *packed_weight, const float *bias_data, float *output_data, + int task_id, ConvParameter *conv_param, SlidingWindowParam *sw_param); +#ifdef __cplusplus +} +#endif +#endif // NNACL_FP32_CONV_SW_ARM64_FP32_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_sw_avx_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_sw_avx_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..8e76f71ecb5ad316dbee46d4f72f69fdf7ad74dc --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_sw_avx_fp32.c @@ -0,0 +1,1231 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/conv_sw_avx_fp32.h" +#include "nnacl/fp32/conv_sw.h" +#include "nnacl/intrinsics/ms_simd_avx_instructions.h" + +void SWConv3x32AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { + in_kh_step *= sizeof(float); + in_sw_step *= sizeof(float); + in_kw_step *= sizeof(float); + float *dst_4 = dst + out_step * C3NUM; + out_step *= sizeof(float); + kw_remainder *= sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "vmovups 0x60(%2), %%ymm3\n" + "vmovups (%2), %%ymm4\n" + "vmovups 0x20(%2), %%ymm5\n" + "vmovups 0x40(%2), %%ymm6\n" + "vmovups 0x60(%2), %%ymm7\n" + "vmovups (%2), %%ymm8\n" + "vmovups 0x20(%2), %%ymm9\n" + "vmovups 0x40(%2), %%ymm10\n" + "vmovups 0x60(%2), %%ymm11\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // LoopW + "movq %%rcx, %%rdx\n" + "movq %5, %%r12\n" // ic_algin + "3:\n" // LoopIC + "vbroadcastss (%%rdx), %%ymm13\n" + "vbroadcastss (%%rdx, %8), %%ymm14\n" + "vbroadcastss (%%rdx, %8, 2), %%ymm15\n" + "vmovups (%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "vmovups 0x20(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm5\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm9\n" + "vmovups 0x40(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm10\n" + "vmovups 0x60(%1), %%ymm12\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + "addq $128, %1\n" + "addq $4, %%rdx\n" + "dec %%r12\n" + "jg 3b\n" + + "addq %6, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %7, %0\n" // in_kh_step + "addq %9, %1\n" + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(ic_algin), "r"(in_kw_step), // 6 + "r"(in_kh_step), "r"(in_sw_step), "r"(kw_remainder) // 9 + : "%rcx", "%rdx", "%rsi", "%r12", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", + "%ymm9", "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm8, %%ymm8\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + "vmaxps %%ymm12, %%ymm11, %%ymm11\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm8, %%ymm8\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + "vminps %%ymm14, %%ymm11, %%ymm11\n" + + "0:\n" + "cmpq $13, %3\n" + "je 1f\n" + // write to nhwc + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + "vmovups %%ymm2, 0x40(%2)\n" + "vmovups %%ymm3, 0x60(%2)\n" + "vmovups %%ymm4, (%2, %1, 1)\n" + "vmovups %%ymm5, 0x20(%2, %1, 1)\n" + "vmovups %%ymm6, 0x40(%2, %1, 1)\n" + "vmovups %%ymm7, 0x60(%2, %1, 1)\n" + "vmovups %%ymm8, (%2, %1, 2)\n" + "vmovups %%ymm9, 0x20(%2, %1, 2)\n" + "vmovups %%ymm10, 0x40(%2, %1, 2)\n" + "vmovups %%ymm11, 0x60(%2, %1, 2)\n" + "jmp 2f\n" + "1:\n" + // write to nc8hw8 + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm4, 0x20(%2)\n" + "vmovups %%ymm8, 0x40(%2)\n" + "vmovups %%ymm1, (%2, %1, 1)\n" + "vmovups %%ymm5, 0x20(%2, %1, 1)\n" + "vmovups %%ymm9, 0x40(%2, %1, 1)\n" + "vmovups %%ymm2, (%2, %1, 2)\n" + "vmovups %%ymm6, 0x20(%2, %1, 2)\n" + "vmovups %%ymm10, 0x40(%2, %1, 2)\n" + "vmovups %%ymm3, (%4)\n" + "vmovups %%ymm7, 0x20(%4)\n" + "vmovups %%ymm11, 0x40(%4)\n" + "2:\n" + : + : "a"(act_flag), "r"(out_step), "r"(dst), "r"(write_mode), "r"(dst_4) + : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm14"); +} + +void SWConv1x32AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { + in_kh_step *= sizeof(float); + in_kw_step *= sizeof(float); + kw_remainder *= sizeof(float); + float *dst_4 = dst + out_step * C3NUM; + out_step *= sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "vmovups 0x60(%2), %%ymm3\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // Loopw + "movq %%rcx, %%rdx\n" + "movq %5, %%r12\n" // ic_algin + "3:\n" // LoopIC + "vbroadcastss (%%rdx), %%ymm4\n" + // Weight data is loaded directly from memory instead of into registers for calculation. + "vfmadd231ps (%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 0x40(%1), %%ymm4, %%ymm2\n" + "vfmadd231ps 0x60(%1), %%ymm4, %%ymm3\n" + "addq $128, %1\n" + "addq $4, %%rdx\n" + "dec %%r12\n" + "jg 3b\n" + + "addq %6, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %7, %0\n" // in_kh_step + "addq %8, %1\n" + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(ic_algin), "r"(in_kw_step), // 6 + "r"(in_kh_step), "r"(kw_remainder) // 8 + : "%rcx", "%rdx", "%rsi", "%r12", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + + "0:\n" + "cmpq $13, %3\n" + "je 1f\n" + // write to nhwc + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + "vmovups %%ymm2, 0x40(%2)\n" + "vmovups %%ymm3, 0x60(%2)\n" + "jmp 2f\n" + "1:\n" + // write to nc8hw8 + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, (%2, %1, 1)\n" + "vmovups %%ymm2, (%2, %1, 2)\n" + "vmovups %%ymm3, (%4)\n" + "2:\n" + : + : "a"(act_flag), "r"(out_step), "r"(dst), "r"(write_mode), "r"(dst_4) + : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm12", "%ymm14"); +} + +void SWConv4x24AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { + in_kh_step *= sizeof(float); + in_kw_step *= sizeof(float); + in_sw_step *= sizeof(float); + kw_remainder *= sizeof(float); + size_t src_3_step = 3 * in_sw_step; + float *dst_3 = dst + 3 * out_step; + out_step *= sizeof(float); + asm volatile( + "cmpq $0, %0\n" + "je 0f\n" + "vmovups (%0), %%ymm0\n" + "vmovups 0x20(%0), %%ymm1\n" + "vmovups 0x40(%0), %%ymm2\n" + // We need to copy ymm0 to ymm3 to reduce IO time, but unfortunately I didn't find the corresponding instruction. + "vmovups (%0), %%ymm3\n" + "vmovups 0x20(%0), %%ymm4\n" + "vmovups 0x40(%0), %%ymm5\n" + "vmovups (%0), %%ymm6\n" + "vmovups 0x20(%0), %%ymm7\n" + "vmovups 0x40(%0), %%ymm8\n" + "vmovups (%0), %%ymm9\n" + "vmovups 0x20(%0), %%ymm10\n" + "vmovups 0x40(%0), %%ymm11\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + "1:\n" + : + : "r"(bias) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11", + "%ymm12", "%ymm13", "%ymm14", "%ymm15"); + + asm volatile( + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // LoopW + "movq %%rcx, %%rdx\n" + "movq %5, %%r12\n" // ic_algin + "3:\n" // LoopIC + "vmovups (%1), %%ymm12\n" + "vmovups 0x20(%1), %%ymm13\n" + "vmovups 0x40(%1), %%ymm14\n" + + "vbroadcastss (%%rdx), %%ymm15\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm0\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm2\n" + + "vbroadcastss (%%rdx, %8), %%ymm15\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm3\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm4\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm5\n" + + "vbroadcastss (%%rdx, %8, 2), %%ymm15\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm6\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm7\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm8\n" + + "addq %2, %%rdx\n" // src_3 + "vbroadcastss (%%rdx), %%ymm15\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm9\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm10\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm11\n" + + "subq %2, %%rdx\n" + "addq $96, %1\n" + "addq $4, %%rdx\n" + "dec %%r12\n" + "jg 3b\n" + + "addq %6, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %7, %0\n" // in_kh_step + "addq %9, %1\n" // border in sw need to add remainder data + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(src_3_step), "r"(kernel_h), "r"(kernel_w), "r"(ic_algin), "r"(in_kw_step), // 6 + "r"(in_kh_step), "r"(in_sw_step), "r"(kw_remainder) + : "%rcx", "%rdx", "%rsi", "%r12", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", + "%ymm9", "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm8, %%ymm8\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + "vmaxps %%ymm12, %%ymm11, %%ymm11\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm8, %%ymm8\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + "vminps %%ymm14, %%ymm11, %%ymm11\n" + + "0:\n" + "cmpq $13, %4\n" + "je 1f\n" + // write to nhwc + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + "vmovups %%ymm2, 0x40(%2)\n" + "vmovups %%ymm3, (%2, %1, 1)\n" + "vmovups %%ymm4, 0x20(%2, %1, 1)\n" + "vmovups %%ymm5, 0x40(%2, %1, 1)\n" + "vmovups %%ymm6, (%2, %1, 2)\n" + "vmovups %%ymm7, 0x20(%2, %1, 2)\n" + "vmovups %%ymm8, 0x40(%2, %1, 2)\n" + "vmovups %%ymm9, (%3)\n" // dst+3 + "vmovups %%ymm10, 0x20(%3)\n" + "vmovups %%ymm11, 0x40(%3)\n" + "jmp 2f\n" + "1:\n" + // write to nc8hw8 + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm3, 0x20(%2)\n" + "vmovups %%ymm6, 0x40(%2)\n" + "vmovups %%ymm9, 0x60(%2)\n" + "vmovups %%ymm1, (%2, %1, 1)\n" + "vmovups %%ymm4, 0x20(%2, %1, 1)\n" + "vmovups %%ymm7, 0x40(%2, %1, 1)\n" + "vmovups %%ymm10, 0x60(%2, %1, 1)\n" + "vmovups %%ymm2, (%2, %1, 2)\n" + "vmovups %%ymm5, 0x20(%2, %1, 2)\n" + "vmovups %%ymm8, 0x40(%2, %1, 2)\n" + "vmovups %%ymm11, 0x60(%2, %1, 2)\n" + "2:\n" + : + : "a"(act_flag), "r"(out_step), "r"(dst), "r"(dst_3), "r"(write_mode) + : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm14"); +} + +void SWConv1x24AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { + in_kh_step *= sizeof(float); + in_kw_step *= sizeof(float); + kw_remainder *= sizeof(float); + out_step *= sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // LoopW + "movq %%rcx, %%rdx\n" + "movq %5, %%r12\n" // ic_algin + "3:\n" // LoopIC + "vbroadcastss (%%rdx), %%ymm3\n" + // Weight data is loaded directly from memory instead of into registers for calculation. + "vfmadd231ps (%1), %%ymm3, %%ymm0\n" + "vfmadd231ps 0x20(%1), %%ymm3, %%ymm1\n" + "vfmadd231ps 0x40(%1), %%ymm3, %%ymm2\n" + "addq $96, %1\n" + "addq $4, %%rdx\n" + "dec %%r12\n" + "jg 3b\n" + + "addq %6, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %7, %0\n" // in_kh_step + "addq %8, %1\n" // border in sw need to add remainder data + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(ic_algin), "r"(in_kw_step), // 6 + "r"(in_kh_step), "r"(kw_remainder) // 8 + : "%rcx", "%rdx", "%rsi", "%r12", "%ymm0", "%ymm1", "%ymm2", "%ymm3"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + + "0:\n" + "cmpq $13, %3\n" + "je 1f\n" + // write to nhwc + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + "vmovups %%ymm2, 0x40(%2)\n" + "jmp 2f\n" + "1:\n" + // write to nc4hw4 + "vmovups %%ymm0, (%2)\n" + "vmovups %%ymm1, (%2, %1, 1)\n" + "vmovups %%ymm2, (%2, %1, 2)\n" + "2:\n" + : + : "a"(act_flag), "r"(out_step), "r"(dst), "r"(write_mode) + : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm12", "%ymm14"); +} + +void SWConv6x16AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { + in_kh_step *= sizeof(float); + in_kw_step *= sizeof(float); + in_sw_step *= sizeof(float); + kw_remainder *= sizeof(float); + size_t src_3_step = 3 * in_sw_step; + float *dst_3 = dst + 3 * out_step; + out_step *= sizeof(float); + asm volatile( + "cmpq $0, %0\n" + "je 0f\n" + "vmovups (%0), %%ymm0\n" + "vmovups 0x20(%0), %%ymm1\n" + // We need to copy ymm0 to ymm3 to reduce IO time, but unfortunately I didn't find the corresponding instruction. + "vmovups (%0), %%ymm2\n" + "vmovups 0x20(%0), %%ymm3\n" + "vmovups (%0), %%ymm4\n" + "vmovups 0x20(%0), %%ymm5\n" + "vmovups (%0), %%ymm6\n" + "vmovups 0x20(%0), %%ymm7\n" + "vmovups (%0), %%ymm8\n" + "vmovups 0x20(%0), %%ymm9\n" + "vmovups (%0), %%ymm10\n" + "vmovups 0x20(%0), %%ymm11\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + "1:\n" + : + : "r"(bias) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11", + "%ymm12", "%ymm13", "%ymm14", "%ymm15"); + + asm volatile( + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // LoopW + "movq %%rcx, %%rdx\n" + "movq %5, %%r12\n" // ic_algin + "3:\n" // LoopIC + "vmovups (%1), %%ymm12\n" + "vmovups 0x20(%1), %%ymm13\n" + + "vbroadcastss (%%rdx), %%ymm15\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm0\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n" + + "vbroadcastss (%%rdx, %8), %%ymm15\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm2\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm3\n" + + "vbroadcastss (%%rdx, %8, 2), %%ymm15\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm4\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm5\n" + + "addq %2, %%rdx\n" // src_3 + "vbroadcastss (%%rdx), %%ymm15\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm6\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm7\n" + + "vbroadcastss (%%rdx, %8), %%ymm15\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm8\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm9\n" + + "vbroadcastss (%%rdx, %8, 2), %%ymm15\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm10\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm11\n" + + "subq %2, %%rdx\n" + "addq $64, %1\n" + "addq $4, %%rdx\n" + "dec %%r12\n" + "jg 3b\n" + + "addq %6, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %7, %0\n" // in_kh_step + "addq %9, %1\n" // border in sw need to add remainder data + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(src_3_step), "r"(kernel_h), "r"(kernel_w), "r"(ic_algin), "r"(in_kw_step), // 6 + "r"(in_kh_step), "r"(in_sw_step), "r"(kw_remainder) // 9 + : "%rcx", "%rdx", "%rsi", "%r12", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", + "%ymm9", "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm8, %%ymm8\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + "vmaxps %%ymm12, %%ymm11, %%ymm11\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm8, %%ymm8\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + "vminps %%ymm14, %%ymm11, %%ymm11\n" + + "0:\n" + "cmpq $13, %4\n" + "je 1f\n" + // write to nhwc + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + "vmovups %%ymm2, (%2, %1, 1)\n" + "vmovups %%ymm3, 0x20(%2, %1, 1)\n" + "vmovups %%ymm4, (%2, %1, 2)\n" + "vmovups %%ymm5, 0x20(%2, %1, 2)\n" + "vmovups %%ymm6, (%3)\n" // dst+3 + "vmovups %%ymm7, 0x20(%3)\n" + "vmovups %%ymm8, (%3, %1, 1)\n" + "vmovups %%ymm9, 0x20(%3, %1, 1)\n" + "vmovups %%ymm10, (%3, %1, 2)\n" + "vmovups %%ymm11, 0x20(%3, %1, 2)\n" + "jmp 2f\n" + "1:\n" + // write to nc8hw8 + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm2, 0x20(%2)\n" + "vmovups %%ymm4, 0x40(%2)\n" + "vmovups %%ymm6, 0x60(%2)\n" // dst+3 + "vmovups %%ymm8, 0x80(%2)\n" + "vmovups %%ymm10, 0xA0(%2)\n" + "vmovups %%ymm1, (%2, %1, 1)\n" + "vmovups %%ymm3, 0x20(%2, %1, 1)\n" + "vmovups %%ymm5, 0x40(%2, %1, 1)\n" + "vmovups %%ymm7, 0x60(%2, %1, 1)\n" + "vmovups %%ymm9, 0x80(%2, %1, 1)\n" + "vmovups %%ymm11, 0xA0(%2, %1, 1)\n" + "2:\n" + : + : "a"(act_flag), "r"(out_step), "r"(dst), "r"(dst_3), "r"(write_mode) + : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm14"); +} + +void SWConv1x16AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { + in_kh_step *= sizeof(float); + in_kw_step *= sizeof(float); + kw_remainder *= sizeof(float); + out_step *= sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // LoopW + "movq %%rcx, %%rdx\n" + "movq %5, %%r12\n" // ic_algin + "3:\n" // LoopIC + "vbroadcastss (%%rdx), %%ymm3\n" + // Weight data is loaded directly from memory instead of into registers for calculation. + "vfmadd231ps (%1), %%ymm3, %%ymm0\n" + "vfmadd231ps 0x20(%1), %%ymm3, %%ymm1\n" + "addq $64, %1\n" + "addq $4, %%rdx\n" + "dec %%r12\n" + "jg 3b\n" + + "addq %6, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %7, %0\n" // in_kh_step + "addq %8, %1\n" // border in sw need to add remainder data + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(ic_algin), "r"(in_kw_step), // 6 + "r"(in_kh_step), "r"(kw_remainder) // 8 + : "%rcx", "%rdx", "%rsi", "%r12", "%ymm0", "%ymm1", "%ymm3"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + + "0:\n" + "cmpq $13, %3\n" + "je 1f\n" + // write to nhwc + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + "jmp 2f\n" + "1:\n" + // write nc8hw8 + "vmovups %%ymm0, (%2)\n" + "vmovups %%ymm1, (%2, %1, 1)\n" + "2:\n" + : + : "a"(act_flag), "r"(out_step), "r"(dst), "r"(write_mode) + : "%ecx", "%ymm0", "%ymm1", "%ymm12", "%ymm14"); +} + +void SWConv12x8AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { + in_kh_step *= sizeof(float); + in_sw_step *= sizeof(float); + in_kw_step *= sizeof(float); + kw_remainder *= sizeof(float); + size_t src_3_step = 3 * in_sw_step; + float *dst_3 = dst + 3 * out_step; + float *dst_5 = dst + 5 * out_step; + float *dst_9 = dst + 9 * out_step; + out_step *= sizeof(float); + asm volatile( + "cmpq $0, %0\n" + "je 0f\n" + "vmovups (%0), %%ymm0\n" + "vmovups (%0), %%ymm1\n" + "vmovups (%0), %%ymm2\n" + "vmovups (%0), %%ymm3\n" + "vmovups (%0), %%ymm4\n" + "vmovups (%0), %%ymm5\n" + "vmovups (%0), %%ymm6\n" + "vmovups (%0), %%ymm7\n" + "vmovups (%0), %%ymm8\n" + "vmovups (%0), %%ymm9\n" + "vmovups (%0), %%ymm10\n" + "vmovups (%0), %%ymm11\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + "1:\n" + : + : "r"(bias) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11"); + + asm volatile( + "LoopH:\n" + "movq %3, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "LoopW:\n" + "movq %%rcx, %%rdx\n" + "movq %4, %%r12\n" // ic_algin + "LoopIC:\n" + "vmovups (%1), %%ymm12\n" + "addq $32, %1\n" + "vbroadcastss (%%rdx), %%ymm13\n" + "vbroadcastss (%%rdx, %7), %%ymm14\n" + "vbroadcastss (%%rdx, %7, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "addq %8, %%rdx\n" + "vbroadcastss (%%rdx), %%ymm13\n" + "vbroadcastss (%%rdx, %7), %%ymm14\n" + "vbroadcastss (%%rdx, %7, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n" + "addq %8, %%rdx\n" + "vbroadcastss (%%rdx), %%ymm13\n" + "vbroadcastss (%%rdx, %7), %%ymm14\n" + "vbroadcastss (%%rdx, %7, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm7\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm8\n" + "addq %8, %%rdx\n" + "vbroadcastss (%%rdx), %%ymm13\n" + "vbroadcastss (%%rdx, %7), %%ymm14\n" + "vbroadcastss (%%rdx, %7, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm9\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm10\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm11\n" + + "subq %8, %%rdx\n" + "subq %8, %%rdx\n" + "subq %8, %%rdx\n" + "addq $4, %%rdx\n" + "dec %%r12\n" + "jg LoopIC\n" + + "addq %5, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg LoopW\n" + + "addq %6, %0\n" // in_kh_step + "addq %9, %1\n" // border in sw need to add remainder data + "dec %2\n" + "jg LoopH\n" + : + : "r"(src), "r"(weight), "r"(kernel_h), "r"(kernel_w), "r"(ic_algin), "r"(in_kw_step), "r"(in_kh_step), // 6 + "r"(in_sw_step), "r"(src_3_step), "r"(kw_remainder) // 9 + : "%rcx", "%rdx", "%r12", "%rsi", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", + "%ymm9", "%ymm10", "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); + + asm volatile( + "and $0x3, %%eax\n" + "je Write\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm8, %%ymm8\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + "vmaxps %%ymm12, %%ymm11, %%ymm11\n" + + "and $0x1, %%eax\n" + "je Write\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm8, %%ymm8\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + "vminps %%ymm14, %%ymm11, %%ymm11\n" + + "Write:\n" + "cmpq $13, %6\n" + "je WriteNC8HW8\n" + // write nhwc + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, (%2, %1)\n" + "vmovups %%ymm2, (%2, %1, 2)\n" + "vmovups %%ymm3, (%3)\n" // dst_3 + "vmovups %%ymm4, (%2, %1, 4)\n" + "vmovups %%ymm5, (%4)\n" // dst_5 + "vmovups %%ymm6, (%4, %1, 1)\n" + "vmovups %%ymm7, (%4, %1, 2)\n" + "vmovups %%ymm8, (%2, %1, 8)\n" + "vmovups %%ymm9, (%5)\n" // dst_9 + "vmovups %%ymm10, (%5, %1, 1)\n" + "vmovups %%ymm11, (%5, %1, 2)\n" + "jmp End\n" + "WriteNC8HW8:\n" + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, 0x20(%2)\n" + "vmovups %%ymm2, 0x40(%2)\n" + "vmovups %%ymm3, 0x60(%2)\n" // dst_3 + "vmovups %%ymm4, 0x80(%2)\n" + "vmovups %%ymm5, 0xA0(%2)\n" // dst_5 + "vmovups %%ymm6, 0xC0(%2)\n" + "vmovups %%ymm7, 0xE0(%2)\n" + "vmovups %%ymm8, 0x100(%2)\n" + "vmovups %%ymm9, 0x120(%2)\n" // dst_9 + "vmovups %%ymm10, 0x140(%2)\n" + "vmovups %%ymm11, 0x160(%2)\n" + "End:\n" + : + : "a"(act_flag), "r"(out_step), "r"(dst), "r"(dst_3), "r"(dst_5), "r"(dst_9), "r"(write_mode) + : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm14"); +} + +void SWConv4x8AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { + in_kh_step *= sizeof(float); + in_sw_step *= sizeof(float); + in_kw_step *= sizeof(float); + size_t src_step = 3 * in_sw_step; + float *dst_3 = dst + 3 * out_step; + out_step *= sizeof(float); + asm volatile( + "cmpq $0, %0\n" + "je 0f\n" + "vmovups (%0), %%ymm0\n" + "vmovups (%0), %%ymm1\n" + "vmovups (%0), %%ymm2\n" + "vmovups (%0), %%ymm3\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "1:\n" + : + : "r"(bias) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3"); + + asm volatile( + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // LoopW + "movq %%rcx, %%rdx\n" + "movq %5, %%r12\n" // ic_algin + "3:\n" // LoopIC + "vmovups (%1), %%ymm12\n" + "movq %%rdx, %%rax\n" + "addq $32, %1\n" + "vbroadcastss (%%rax), %%ymm13\n" + "vbroadcastss (%%rax, %8), %%ymm14\n" + "vbroadcastss (%%rax, %8, 2), %%ymm15\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm14, %%ymm1\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm2\n" + "addq %9, %%rax\n" + "vbroadcastss (%%rax), %%ymm13\n" + "vfmadd231ps %%ymm12, %%ymm13, %%ymm3\n" + + "addq $4, %%rdx\n" + "dec %%r12\n" + "jg 3b\n" + + "addq %6, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %7, %0\n" // in_kh_step + "addq %2, %1\n" // border in sw need to add remainder data + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(kw_remainder), "r"(kernel_h), "r"(kernel_w), "r"(ic_algin), "r"(in_kw_step), // 6 + "r"(in_kh_step), "r"(in_sw_step), "r"(src_step) // 9 + : "%rcx", "%rdx", "%rsi", "%r12", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + + "0:\n" + "vmovups %%ymm0, (%2)\n" // dst_0 + "vmovups %%ymm1, (%2, %1)\n" + "vmovups %%ymm2, (%2, %1, 2)\n" + "vmovups %%ymm3, (%3)\n" // dst_3 + : + : "a"(act_flag), "r"(out_step), "r"(dst), "r"(dst_3) + : "%ecx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm12", "%ymm14"); +} + +void SWConv1x8AVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { + in_kh_step *= sizeof(float); + in_kw_step *= sizeof(float); + kw_remainder *= sizeof(float); + out_step *= sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "1:\n" // LoopH + "movq %4, %%rsi\n" // width + "movq %0, %%rcx\n" // src_h + "2:\n" // LoopW + "movq %%rcx, %%rdx\n" + "movq %5, %%r12\n" // ic_algin + "3:\n" // LoopIC + "vbroadcastss (%%rdx), %%ymm1\n" + // Weight data is loaded directly from memory instead of into registers for calculation. + "vfmadd231ps (%1), %%ymm1, %%ymm0\n" + "addq $32, %1\n" + "addq $4, %%rdx\n" + "dec %%r12\n" + "jg 3b\n" + + "addq %6, %%rcx\n" // in_kw_step + "dec %%rsi\n" + "jg 2b\n" + + "addq %7, %0\n" // in_kh_step + "addq %8, %1\n" // border in sw need to add remainder data + "dec %3\n" + "jg 1b\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(kernel_h), "r"(kernel_w), "r"(ic_algin), "r"(in_kw_step), // 6 + "r"(in_kh_step), "r"(kw_remainder) // 8 + : "%rcx", "%rdx", "%rsi", "%r12", "%ymm0", "%ymm1"); + + asm volatile( + "and $0x3, %%eax\n" + "je 0f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + + "and $0x1, %%eax\n" + "je 0f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + + "0:\n" + // write to nhec and nc8hw8 is identical! + "vmovups %%ymm0, (%2)\n" // dst_0 + : + : "a"(act_flag), "r"(out_step), "r"(dst) + : "%ecx", "%ymm0", "%ymm12", "%ymm14"); +} + +typedef void (*SWConvKernel)(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, + size_t kw_remainder, size_t write_mode); + +#define ROW_NUM_LIST const int ow_block_num[4] = {12, 6, 4, 3}; +#define KERNEL_LIST \ + const SWConvKernel kernel[4][2] = {{SWConv1x8AVXKernel, SWConv12x8AVXKernel}, \ + {SWConv1x16AVXKernel, SWConv6x16AVXKernel}, \ + {SWConv1x24AVXKernel, SWConv4x24AVXKernel}, \ + {SWConv1x32AVXKernel, SWConv3x32AVXKernel}}; +#define COMPUTE_CORE \ + if (ow_block < ow_block_num[oc_block - 1]) { \ + ow_block = 1; \ + } \ + kernel[oc_block - 1][ow_block / ow_block_num[oc_block - 1]]( \ + dst_oc + ow * out_w_step, src_w, weight, bias, kernel_h, kernel_w, act_type, ow_block, oc_block, out_block_step, \ + ic_algin, in_kw_step, in_kh_step, in_sw_step, 0, write_mode); +#define OUTER_COMPUTE \ + kernel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, act_type, ow_bock, \ + oc_block, sw_param->out_block_step_, sw_param->ic_align_, sw_param->in_kw_step_, sw_param->in_kh_step_, \ + sw_param->in_sw_step_, (conv_param->kernel_w_ - end_kw + start_kw) * C8NUM * oc_block * sw_param->ic_align_, \ + write_mode); + +GenerateConvSWFunc(AVX, C4NUM, ROW_NUM_LIST, KERNEL_LIST, COMPUTE_CORE, OUTER_COMPUTE); + +#ifdef ENABLE_DEBUG +void SWConvWxKAVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t out_step, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode) { + __m256 dst_data[12]; + const float *src_kh[12]; + const float *src_kw[12]; + __m256 weight_data[4]; + for (int i = 0; i < ow_block; ++i) { + if (bias != NULL) { + for (int j = 0; j < oc_block; ++j) { + dst_data[i * oc_block + j] = _mm256_loadu_ps(bias + j * C8NUM); + } + } else { + for (int j = 0; j < oc_block; ++j) { + dst_data[i * oc_block + j] = _mm256_set1_ps(0.0f); + } + } + src_kh[i] = src + i * in_sw_step; + src_kw[i] = NULL; + } + const float *weight_kernel = weight; + for (int kh = 0; kh < kernel_h; kh++) { + for (int i = 0; i < ow_block; ++i) { + src_kw[i] = src_kh[i]; + } + for (int kw = 0; kw < kernel_w; kw++) { + for (int ic = 0; ic < ic_algin; ++ic) { + for (int j = 0; j < oc_block; ++j) { + weight_data[j] = _mm256_loadu_ps(weight_kernel + j * C8NUM); + } + for (int i = 0; i < ow_block; ++i) { + for (int j = 0; j < oc_block; ++j) { + dst_data[i * oc_block + j] += src_kw[i][ic] * weight_data[j]; + } + } + weight_kernel += C8NUM * oc_block; + } // ic loop + for (int i = 0; i < ow_block; ++i) { + src_kw[i] += in_kw_step; + } + } // kernel_w loop + weight_kernel += kw_remainder; + for (int i = 0; i < ow_block; ++i) { + src_kh[i] += in_kh_step; + } + } // kernel_h loop + // add bias and relu + for (int i = 0; i < ow_block; ++i) { + for (int j = 0; j < oc_block; ++j) { + if (0x1 & act_flag) { // relu6 + dst_data[i * oc_block + j] = _mm256_min_ps(dst_data[i * oc_block + j], _mm256_set1_ps(6.0f)); + } + if (0x2 & act_flag) { // relu + dst_data[i * oc_block + j] = _mm256_max_ps(dst_data[i * oc_block + j], _mm256_set1_ps(0.0f)); + } + if (write_mode == C13NUM) { + // write nc8hw8 + _mm256_storeu_ps(dst + j * out_step + i * C8NUM, dst_data[i * oc_block + j]); + } else { + // write nhwc + _mm256_storeu_ps(dst + i * out_step + j * C8NUM, dst_data[i * oc_block + j]); + } + } + } +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_sw_avx_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_sw_avx_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..d7aa79e62292700ed46bc9d8ebd70459ae9a171e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_sw_avx_fp32.h @@ -0,0 +1,42 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_CONV_SW_AVX_H_ +#define MINDSPORE_NNACL_FP32_CONV_SW_AVX_H_ + +#include "nnacl/pack.h" +#include "nnacl/op_base.h" +#include "nnacl/common_func.h" +#include "nnacl/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void ConvSWAVXFp32(const float *input_data, const float *packed_weight, const float *bias_data, float *output_data, + int task_id, ConvParameter *conv_param, SlidingWindowParam *sw_param); + +#ifdef ENABLE_DEBUG +void SWConvWxKAVXKernel(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, + size_t kernel_w, size_t act_flag, size_t ow_block, size_t oc_block, size_t oc_algin, + size_t ic_algin, size_t in_kw_step, size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, + size_t write_mode); +#endif + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_CONV_SW_AVX_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_winograd_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_winograd_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..ffbf697f94e594faafd0ccfc4269efa41afff48f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_winograd_fp32.c @@ -0,0 +1,265 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/conv_winograd_fp32.h" +#include +#include "nnacl/fp32/common_func_fp32.h" +#include "nnacl/fp32/winograd_transform.h" +#include "nnacl/fp32/matmul_fp32.h" + +// fp32 conv winograd +void ConvWinogardFp32(const float *input_data, const float *trans_weight, const float *bias_data, float *output_data, + TmpBufferAddress *buffer_list, int task_id, const ConvParameter *conv_param, + TransFuncList trans_func) { + if (conv_param->output_unit_ == 0) { + return; + } + int in_channel = conv_param->input_channel_; + int input_unit = conv_param->input_unit_; + int out_w_block = UP_DIV(conv_param->output_w_, conv_param->output_unit_); + int out_h_block = UP_DIV(conv_param->output_h_, conv_param->output_unit_); + int output_count = out_w_block * out_h_block; + const int tile_num = C12NUM; + int output_tile_count = UP_DIV(output_count, tile_num); +#ifdef ENABLE_AVX + const int col_tile = C16NUM; + const int channel_pack_tile = C8NUM; +#else + const int col_tile = C8NUM; + const int channel_pack_tile = C4NUM; +#endif + int oc_tile = UP_DIV(conv_param->output_channel_, col_tile); + int oc8 = UP_DIV(conv_param->output_channel_, C8NUM); + int input_unit_square = input_unit * input_unit; + + float *trans_input = buffer_list[0] + task_id * tile_num * input_unit_square * in_channel; + float *gemm_out = buffer_list[1] + task_id * tile_num * input_unit_square * oc8 * C8NUM; + float *tmp_data = buffer_list[2] + task_id * input_unit_square * channel_pack_tile; + float *col_buffer = buffer_list[3] + task_id * tile_num * in_channel; + // step 1 : filter transform (pre-processed offline) + // step 2 : input transform (online) + + int block_per_thread = UP_DIV(output_tile_count, conv_param->thread_num_); + int start_index = block_per_thread * task_id * tile_num; + if (start_index >= output_count) { + return; + } + int end_index = MSMIN(start_index + block_per_thread * tile_num, output_count); + + for (int b = 0; b < conv_param->input_batch_; b++) { + int in_batch_offset = b * in_channel * conv_param->input_h_ * conv_param->input_w_; + int out_batch_offset = b * conv_param->output_channel_ * conv_param->output_w_ * conv_param->output_h_; + + for (int out_tile_index = start_index; out_tile_index < end_index; out_tile_index += tile_num) { + int cal_num = output_count - out_tile_index; + cal_num = cal_num > tile_num ? tile_num : cal_num; + if (cal_num <= 0) { + return; + } + +#ifdef ENABLE_ARM64 + // Optimize input transform. Only valid for arm64, the tile num is 12, the channel_tile is 4. + // For arm32, the tile_num is 4. + // For x86_sse, the tile_num is 4, the channel_tile is 4. + // For avx, the tile_num is 6, the channel_tile is 8. + // N = input_unit, M = tile_num + // The function(InputTransformNxNStep, InputTransform4x4PackM) needs to be rewritten. + bool fused_pack = + (cal_num == tile_num) && (trans_func.in_step_func_ != NULL) && (trans_func.in_pack_func_ != NULL); + if (fused_pack) { + float *opt_trans_input = + buffer_list[4] + task_id * tile_num * input_unit_square * UP_ROUND(in_channel, channel_pack_tile); + WinogradInputTransformOptStep(input_data + in_batch_offset, opt_trans_input, tmp_data, cal_num, out_tile_index, + out_w_block, conv_param, trans_func.in_step_func_); + + for (int w_index = 0; w_index < input_unit; w_index++) { + float *src_w = opt_trans_input + w_index * input_unit * tile_num * channel_pack_tile; + for (int c = 0; c < UP_DIV(in_channel, channel_pack_tile); c++) { + int real_c = in_channel - c * channel_pack_tile; + real_c = real_c > channel_pack_tile ? channel_pack_tile : real_c; + float *src_c = src_w + c * input_unit_square * tile_num * channel_pack_tile; + float *dst_c = trans_input + c * tile_num * channel_pack_tile; + trans_func.in_pack_func_(src_c, dst_c, channel_pack_tile, in_channel * tile_num, real_c); + } + + for (int h_index = 0; h_index < input_unit; h_index++) { + const float *gemm_input = trans_input + h_index * tile_num * in_channel; + int point_index = h_index * input_unit + w_index; + const float *gemm_weight = trans_weight + point_index * in_channel * oc_tile * col_tile; + MatMulOpt(gemm_input, gemm_weight, gemm_out + point_index * C8NUM, NULL, 0, in_channel, cal_num, + oc8 * C8NUM, input_unit_square, OutType_TileC8); + } + } + } else { +#endif + WinogradInputTransform(input_data + in_batch_offset, trans_input, tmp_data, cal_num, out_tile_index, + out_w_block, conv_param, trans_func.in_func_); + // step 3 : gemm + float *src_ptr = trans_input; + float *dst_ptr = gemm_out; + float *tmp_col_ptr = col_buffer; + for (int i = 0; i < input_unit_square; ++i) { +#ifdef ENABLE_AVX + RowMajor2Col6Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel); +#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE) + RowMajor2Col4Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel); +#else + RowMajor2Col12Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel); +#endif + MatMulOpt(tmp_col_ptr, trans_weight + i * in_channel * oc_tile * col_tile, dst_ptr + i * C8NUM, NULL, 0, + in_channel, cal_num, oc8 * C8NUM, input_unit_square, 2); + } +#ifdef ENABLE_ARM64 + } +#endif + + // step 4 : output transform + float *output_ptr = output_data + out_batch_offset; + if (conv_param->out_format_ != Format_NC4HW4) { // nc4hw4 + WinogradOutputNHWCTransform(gemm_out, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param, + trans_func.out_func_); + } else { +#if defined(ENABLE_AVX) || defined(ENABLE_ARM64) + WinogradOutputNC4HW4Transform(gemm_out, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param, + trans_func.out_func_); +#else + WinogradOutputNHWCTransform(gemm_out, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param, + trans_func.out_func_); +#endif + } + } + } +} + +// fp32 conv winograd +void ConvWinogardFp32CutByBatch(const float *input_data, const float *trans_weight, const float *bias_data, + float *output_data, TmpBufferAddress *buffer_list, int task_id, + const ConvParameter *conv_param, TransFuncList trans_func) { + if (conv_param->output_unit_ == 0) { + return; + } + int in_channel = conv_param->input_channel_; + int input_unit = conv_param->input_unit_; + int out_w_block = UP_DIV(conv_param->output_w_, conv_param->output_unit_); + int out_h_block = UP_DIV(conv_param->output_h_, conv_param->output_unit_); + int output_count = out_w_block * out_h_block; + const int tile_num = C12NUM; +#ifdef ENABLE_AVX + const int col_tile = C16NUM; + const int channel_pack_tile = C8NUM; +#else + const int col_tile = C8NUM; + const int channel_pack_tile = C4NUM; +#endif + int oc_tile = UP_DIV(conv_param->output_channel_, col_tile); + int oc8 = UP_DIV(conv_param->output_channel_, C8NUM); + int input_unit_square = input_unit * input_unit; + + float *trans_input = buffer_list[0] + task_id * tile_num * input_unit_square * in_channel; + float *gemm_out = buffer_list[1] + task_id * tile_num * input_unit_square * oc8 * C8NUM; + float *tmp_data = buffer_list[2] + task_id * input_unit_square * channel_pack_tile; + float *col_buffer = buffer_list[3] + task_id * tile_num * in_channel; + // step 1 : filter transform (pre-processed offline) + // step 2 : input transform (online) + + int block_batch_per_thread = UP_DIV(conv_param->input_batch_, conv_param->thread_num_); + int start_batch = block_batch_per_thread * task_id; + int end_batch = MSMIN(conv_param->input_batch_, (start_batch + block_batch_per_thread)); + + for (int b = start_batch; b < end_batch; b++) { + int in_batch_offset = b * in_channel * conv_param->input_h_ * conv_param->input_w_; + int out_batch_offset = b * conv_param->output_channel_ * conv_param->output_w_ * conv_param->output_h_; + + for (int out_tile_index = 0; out_tile_index < output_count; out_tile_index += tile_num) { + int cal_num = output_count - out_tile_index; + cal_num = cal_num > tile_num ? tile_num : cal_num; + if (cal_num <= 0) { + return; + } + +#ifdef ENABLE_ARM64 + // Optimize input transform. Only valid for arm64, the tile num is 12, the channel_tile is 4. + // For arm32, the tile_num is 4. + // For x86_sse, the tile_num is 4, the channel_tile is 4. + // For avx, the tile_num is 6, the channel_tile is 8. + // N = input_unit, M = tile_num + // The function(InputTransformNxNStep, InputTransform4x4PackM) needs to be rewritten. + bool fused_pack = + (cal_num == tile_num) && (trans_func.in_step_func_ != NULL) && (trans_func.in_pack_func_ != NULL); + if (fused_pack) { + float *opt_trans_input = + buffer_list[4] + task_id * tile_num * input_unit_square * UP_ROUND(in_channel, channel_pack_tile); + WinogradInputTransformOptStep(input_data + in_batch_offset, opt_trans_input, tmp_data, cal_num, out_tile_index, + out_w_block, conv_param, trans_func.in_step_func_); + + for (int w_index = 0; w_index < input_unit; w_index++) { + float *src_w = opt_trans_input + w_index * input_unit * tile_num * channel_pack_tile; + for (int c = 0; c < UP_DIV(in_channel, channel_pack_tile); c++) { + int real_c = in_channel - c * channel_pack_tile; + real_c = real_c > channel_pack_tile ? channel_pack_tile : real_c; + float *src_c = src_w + c * input_unit_square * tile_num * channel_pack_tile; + float *dst_c = trans_input + c * tile_num * channel_pack_tile; + trans_func.in_pack_func_(src_c, dst_c, channel_pack_tile, in_channel * tile_num, real_c); + } + + for (int h_index = 0; h_index < input_unit; h_index++) { + const float *gemm_input = trans_input + h_index * tile_num * in_channel; + int point_index = h_index * input_unit + w_index; + const float *gemm_weight = trans_weight + point_index * in_channel * oc_tile * col_tile; + MatMulOpt(gemm_input, gemm_weight, gemm_out + point_index * C8NUM, NULL, 0, in_channel, cal_num, + oc8 * C8NUM, input_unit_square, OutType_TileC8); + } + } + } else { +#endif + WinogradInputTransform(input_data + in_batch_offset, trans_input, tmp_data, cal_num, out_tile_index, + out_w_block, conv_param, trans_func.in_func_); + // step 3 : gemm + float *src_ptr = trans_input; + float *dst_ptr = gemm_out; + float *tmp_col_ptr = col_buffer; + for (int i = 0; i < input_unit_square; ++i) { +#ifdef ENABLE_AVX + RowMajor2Col6Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel); +#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE) + RowMajor2Col4Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel); +#else + RowMajor2Col12Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel); +#endif + MatMulOpt(tmp_col_ptr, trans_weight + i * in_channel * oc_tile * col_tile, dst_ptr + i * C8NUM, NULL, 0, + in_channel, cal_num, oc8 * C8NUM, input_unit_square, 2); + } +#ifdef ENABLE_ARM64 + } +#endif + + // step 4 : output transform + float *output_ptr = output_data + out_batch_offset; + if (conv_param->out_format_ != Format_NC4HW4) { // nc4hw4 + WinogradOutputNHWCTransform(gemm_out, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param, + trans_func.out_func_); + } else { +#if defined(ENABLE_AVX) || defined(ENABLE_ARM64) + WinogradOutputNC4HW4Transform(gemm_out, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param, + trans_func.out_func_); +#else + WinogradOutputNHWCTransform(gemm_out, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param, + trans_func.out_func_); +#endif + } + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_winograd_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_winograd_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..fcb66218421e9b2fc1d9a8b983c924a8087310fc --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/conv_winograd_fp32.h @@ -0,0 +1,48 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_CONV_WINOGRAD_H_ +#define MINDSPORE_NNACL_FP32_CONV_WINOGRAD_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "nnacl/pack.h" +#include "nnacl/op_base.h" +#include "nnacl/common_func.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/fp32/winograd_utils.h" +#include "nnacl/fp32/conv_depthwise_fp32.h" +#include "nnacl/kernel/convolution_winograd_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// fp32 convolution winograd +void ConvWinogardFp32(const float *input_data, const float *trans_weight, const float *bias_data, float *output_data, + TmpBufferAddress *buffer_list, int task_id, const ConvParameter *conv_param, + TransFuncList trans_func); + +void ConvWinogardFp32CutByBatch(const float *input_data, const float *trans_weight, const float *bias_data, + float *output_data, TmpBufferAddress *buffer_list, int task_id, + const ConvParameter *conv_param, TransFuncList trans_func); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_CONV_WINOGRAD_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/crop_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/crop_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..aaa8279226ebbe269de5fb7d5faa03e48097c088 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/crop_fp32.c @@ -0,0 +1,94 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp32/crop_fp32.h" +#include +#include "nnacl/op_base.h" +#include "nnacl/crop_parameter.h" + +void Pad4DOffset(const CropParameter *crop_param, int64_t *offset, int length) { + int axis = crop_param->axis_; + for (int i = length - 1; i >= 0; --i) { + int offset_index = i - axis; + if (offset_index >= 0 && offset_index < COMM_SHAPE_SIZE) { + offset[i] = crop_param->offset_[offset_index]; + } else { + offset[i] = 0; + } + } +} + +void Crop4D(const float *input, float *output, const int32_t *in_shape, const int32_t *out_shape, + const CropParameter *crop_param, int thread_id, int thread_num) { + int64_t offset_pad[DIMENSION_4D] = {0}; + Pad4DOffset(crop_param, offset_pad, DIMENSION_4D); + int out_shape1 = out_shape[1]; + int out_shape2 = out_shape[2]; + int out_shape3 = out_shape[3]; + size_t out_stride2 = out_shape3; + size_t out_stride1 = out_stride2 * out_shape2; + size_t out_stride0 = out_stride1 * out_shape1; + size_t in_stride2 = in_shape[3]; + size_t in_stride1 = in_stride2 * in_shape[2]; + size_t in_stride0 = in_stride1 * in_shape[1]; + size_t copy_size = out_shape3 * sizeof(float); + + size_t count_per_thread = UP_DIV(out_shape1, thread_num); + size_t thread_stride = thread_id * count_per_thread; + for (int i = 0; i < out_shape[0]; ++i) { + size_t out_offset0 = i * out_stride0; + size_t in_offset0 = (i + offset_pad[0]) * in_stride0 + offset_pad[3]; + for (size_t j = 0; j < count_per_thread; ++j) { + size_t k = j + thread_stride; + if (k >= out_shape1) { + break; + } + size_t out_offset1 = k * out_stride1 + out_offset0; + size_t in_offset1 = (k + offset_pad[1]) * in_stride1 + in_offset0; + for (int l = 0; l < out_shape2; ++l) { + size_t out_offset = l * out_stride2 + out_offset1; + size_t in_offset = (l + offset_pad[2]) * in_stride2 + in_offset1; + memcpy(output + out_offset, input + in_offset, copy_size); + } + } + } +} + +void Crop4DNoParallel(const float *input, float *output, const int32_t *in_shape, const int32_t *out_shape, + const CropParameter *crop_param) { + int64_t offset_pad[DIMENSION_4D] = {0}; + Pad4DOffset(crop_param, offset_pad, DIMENSION_4D); + size_t in_dim2_stride = in_shape[3]; + size_t in_dim1_stride = in_shape[2] * in_dim2_stride; + size_t in_dim0_stride = in_dim1_stride * in_shape[1]; + size_t offset_3 = offset_pad[3]; + size_t out_offset = 0; + size_t copy_num = out_shape[3]; + size_t copy_size = copy_num * sizeof(float); + size_t in_dim0_end = offset_pad[0] + out_shape[0]; + size_t in_dim1_end = offset_pad[1] + out_shape[1]; + size_t in_dim2_end = offset_pad[2] + out_shape[2]; + for (int i = offset_pad[0]; i < in_dim0_end; ++i) { + size_t dim0_offset = (size_t)i * in_dim0_stride + offset_3; + for (int j = offset_pad[1]; j < in_dim1_end; ++j) { + size_t dim1_offset = (size_t)j * in_dim1_stride + dim0_offset; + for (int k = offset_pad[2]; k < in_dim2_end; ++k) { + size_t in_offset = dim1_offset + (size_t)k * in_dim2_stride; + memcpy(output + out_offset, input + in_offset, copy_size); + out_offset += copy_num; + } + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/crop_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/crop_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..07b9f1b2cca8dd0711e759a13692246cc554f70a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/crop_fp32.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_CROP_FP32_H_ +#define NNACL_FP32_CROP_FP32_H_ +#include "nnacl/op_base.h" +#include "nnacl/crop_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void Crop4D(const float *input, float *output, const int32_t *in_shape, const int32_t *out_shape, + const CropParameter *crop_param, int thread_id, int thread_num); +void Crop4DNoParallel(const float *input, float *output, const int32_t *in_shape, const int32_t *out_shape, + const CropParameter *crop_param); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_CROP_FP32_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/cumsum_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/cumsum_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..1a816212f22949cd6a3a3feaad2405fd0e04e43c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/cumsum_fp32.c @@ -0,0 +1,200 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/cumsum_fp32.h" +#include "nnacl/op_base.h" +#include "nnacl/cumsum_fp32_simd.h" + +// (a, b, c) -> (a, a+b, a+b+c) exclusive == false +// (a, b, c) -> (0, a, a+b) exclusive == true +void Cumsum(const float *input, float *output, int out_dim, int axis_dim, int inner_dim, bool exclusive) { + // when not exclusive, output axis dim[0] is the same as that of input. + // when exclusive, output axis dim[0] is 0.0f + if (!exclusive) { + for (int i = 0; i < out_dim; ++i) { + const float *layer_input = input + i * axis_dim * inner_dim; + float *layer_output = output + i * axis_dim * inner_dim; + + int j = 0; + SIMD_RUN_NO_SCALAR(CumsumOutputInitWithInput, j, layer_input, layer_output, inner_dim); + for (; j < inner_dim; ++j) { + *(layer_output + j) = *(layer_input + j); + } + } + } else { + for (int i = 0; i < out_dim; ++i) { + float *layer_output = output + i * axis_dim * inner_dim; + + int j = 0; + SIMD_RUN_NO_SCALAR(CumsumOutputInitWithZero, j, layer_output, inner_dim); + for (; j < inner_dim; ++j) { + *(layer_output + j) = 0.0f; + } + } + } + int input_offset = exclusive ? 0 : 1; + for (int i = 0; i < out_dim; ++i) { + const float *layer_input = input + i * axis_dim * inner_dim + inner_dim * input_offset; + float *layer_last_output = output + i * axis_dim * inner_dim; + float *layer_output = layer_last_output + inner_dim; + + for (int j = 1; j < axis_dim; ++j) { + int k = 0; + SIMD_RUN_NO_SCALAR(Cumsum, k, layer_input, layer_output, layer_last_output, inner_dim); + for (; k < inner_dim; ++k) { + // layer_output (i, j, k) = layer_input (i, j, k) + layer_last_output (i,j-1, k) + *(layer_output + k) = *(layer_input + k) + *(layer_last_output + k); + } + layer_input += inner_dim; + layer_last_output += inner_dim; + layer_output += inner_dim; + } + } +} + +// (a, b, c) -> (c+b+a, c+b, c) exclusive==false +// (a, b, c) -> (c+b, c, 0) exclusive==true +void CumsumReverse(const float *input, float *output, int out_dim, int axis_dim, int inner_dim, bool exclusive) { + if (!exclusive) { + for (int i = 0; i < out_dim; ++i) { + const float *layer_input = input + i * axis_dim * inner_dim + (axis_dim - 1) * inner_dim; + float *layer_output = output + i * axis_dim * inner_dim + (axis_dim - 1) * inner_dim; + + int j = 0; + SIMD_RUN_NO_SCALAR(CumsumOutputInitWithInput, j, layer_input, layer_output, inner_dim); + for (; j < inner_dim; ++j) { + *(layer_output + j) = *(layer_input + j); + } + } + } else { + for (int i = 0; i < out_dim; ++i) { + float *layer_output = output + i * axis_dim * inner_dim + (axis_dim - 1) * inner_dim; + + int j = 0; + SIMD_RUN_NO_SCALAR(CumsumOutputInitWithZero, j, layer_output, inner_dim); + for (; j < inner_dim; ++j) { + *(layer_output + j) = 0.0f; + } + } + } + int input_offset = exclusive ? 0 : 1; + for (int i = 0; i < out_dim; ++i) { + const float *layer_input = input + (i + 1) * axis_dim * inner_dim - 1 - input_offset * inner_dim; + float *layer_last_output = output + (i + 1) * axis_dim * inner_dim - 1; + float *layer_output = layer_last_output - inner_dim; + + for (int j = 1; j < axis_dim; ++j) { + int k = 0; + SIMD_RUN_NO_SCALAR(CumsumReverse, k, layer_input, layer_output, layer_last_output, inner_dim); + for (; k < inner_dim; ++k) { + *(layer_output - k) = *(layer_input - k) + *(layer_last_output - k); + } + layer_input -= inner_dim; + layer_last_output -= inner_dim; + layer_output -= inner_dim; + } + } +} + +// (a, b, c) -> (a, a+b, a+b+c) exclusive == false +// (a, b, c) -> (0, a, a+b) exclusive == true +void CumsumInt(const int32_t *input, int32_t *output, int out_dim, int axis_dim, int inner_dim, bool exclusive) { + // when not exclusive, output axis dim[0] is the same as that of input. + // when exclusive, output axis dim[0] is 0 + if (!exclusive) { + for (int i = 0; i < out_dim; ++i) { + const int32_t *layer_input = input + i * axis_dim * inner_dim; + int32_t *layer_output = output + i * axis_dim * inner_dim; + + int j = 0; + SIMD_RUN_NO_SCALAR(CumsumIntOutputInitWithInput, j, layer_input, layer_output, inner_dim); + for (; j < inner_dim; ++j) { + *(layer_output + j) = *(layer_input + j); + } + } + } else { + for (int i = 0; i < out_dim; ++i) { + int32_t *layer_output = output + i * axis_dim * inner_dim; + + int j = 0; + SIMD_RUN_NO_SCALAR(CumsumIntOutputInitWithZero, j, layer_output, inner_dim); + for (; j < inner_dim; ++j) { + *(layer_output++) = 0; + } + } + } + int input_offset = exclusive ? 0 : 1; + for (int i = 0; i < out_dim; ++i) { + const int32_t *layer_input = input + i * axis_dim * inner_dim + inner_dim * input_offset; + int32_t *layer_last_output = output + i * axis_dim * inner_dim; + int32_t *layer_output = layer_last_output + inner_dim; + + for (int j = 1; j < axis_dim; ++j) { + int k = 0; + SIMD_RUN_NO_SCALAR(CumsumInt, k, layer_input, layer_output, layer_last_output, inner_dim); + for (; k < inner_dim; ++k) { + *(layer_output + k) = *(layer_input + k) + *(layer_last_output + k); + } + layer_input += inner_dim; + layer_last_output += inner_dim; + layer_output += inner_dim; + } + } +} + +// (a, b, c) -> (c+b+a, c+b, c) exclusive==false +// (a, b, c) -> (c+b, c, 0) exclusive==true +void CumsumReverseInt(const int32_t *input, int32_t *output, int out_dim, int axis_dim, int inner_dim, bool exclusive) { + if (!exclusive) { + for (int i = 0; i < out_dim; ++i) { + const int32_t *layer_input = input + i * axis_dim * inner_dim + (axis_dim - 1) * inner_dim; + int32_t *layer_output = output + i * axis_dim * inner_dim + (axis_dim - 1) * inner_dim; + + int j = 0; + SIMD_RUN_NO_SCALAR(CumsumIntOutputInitWithInput, j, layer_input, layer_output, inner_dim); + for (; j < inner_dim; ++j) { + *(layer_output++) = *(layer_input++); + } + } + } else { + for (int i = 0; i < out_dim; ++i) { + int32_t *layer_output = output + i * axis_dim * inner_dim + (axis_dim - 1) * inner_dim; + + int j = 0; + SIMD_RUN_NO_SCALAR(CumsumIntOutputInitWithZero, j, layer_output, inner_dim); + for (; j < inner_dim; ++j) { + *(layer_output++) = 0.0f; + } + } + } + int input_offset = exclusive ? 0 : 1; + for (int i = 0; i < out_dim; ++i) { + const int32_t *layer_input = input + (i + 1) * axis_dim * inner_dim - 1 - input_offset * inner_dim; + int32_t *layer_last_output = output + (i + 1) * axis_dim * inner_dim - 1; + int32_t *layer_output = layer_last_output - inner_dim; + + for (int j = 1; j < axis_dim; ++j) { + int k = 0; + SIMD_RUN_NO_SCALAR(CumsumReverseInt, k, layer_input, layer_output, layer_last_output, inner_dim); + for (; k < inner_dim; ++k) { + *(layer_output - k) = *(layer_input - k) + *(layer_last_output - k); + } + layer_input -= inner_dim; + layer_last_output -= inner_dim; + layer_output -= inner_dim; + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/cumsum_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/cumsum_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..39f49f51eb2b75df3fc1e88db8da5618553eb0bb --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/cumsum_fp32.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_CUMSUM_H_ +#define MINDSPORE_NNACL_FP32_CUMSUM_H_ +#include "nnacl/op_base.h" +#include "nnacl/cumsum_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void Cumsum(const float *input, float *output, int out_dim, int axis_dim, int inner_dim, bool exclusive); +void CumsumReverse(const float *input, float *output, int out_dim, int axis_dim, int inner_dim, bool exclusive); +void CumsumInt(const int32_t *input, int32_t *output, int out_dim, int axis_dim, int inner_dim, bool exclusive); +void CumsumReverseInt(const int32_t *input, int32_t *output, int out_dim, int axis_dim, int inner_dim, bool exclusive); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_CUMSUM_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/cumsum_fp32_simd.h.in b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/cumsum_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..5f9fbbc1514b48f7681355808c49fa3e0d72f40f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/cumsum_fp32_simd.h.in @@ -0,0 +1,114 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_CUMSUM_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_CUMSUM_@SIMD_INSTRUCTION@_H_ + +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +// (a, b, c) -> (a, a+b, a+b+c) exclusive == false +// (a, b, c) -> (0, a, a+b) exclusive == true +static inline int64_t CumsumOutputInitWithInput@SIMD_INSTRUCTION@(int64_t index, const float *layer_input, + float *layer_output, int inner_dim) { + for (int block_max_size = inner_dim - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(layer_output + index, SIMD_LD_F32(layer_input + index)); + } + return index; +} + +static inline int64_t CumsumOutputInitWithZero@SIMD_INSTRUCTION@(int64_t index, float *layer_output, int inner_dim) { + for (int block_max_size = inner_dim - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(layer_output + index, SIMD_MOV_F32(0.0f)); + } + return index; +} + +static inline int64_t Cumsum@SIMD_INSTRUCTION@(int64_t index, const float *layer_input, float *layer_output, float *layer_last_output, + int inner_dim) { + for (int block_max_size = inner_dim - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 input_val = SIMD_LD_F32(layer_input + index); + SIMD_F32 last_output_val = SIMD_LD_F32(layer_last_output + index); + SIMD_F32 out_val = SIMD_ADD_F32(input_val, last_output_val); + SIMD_ST_F32(layer_output + index, out_val); + } + return index; +} + +// (a, b, c) -> (c+b+a, c+b, c) exclusive==false +// (a, b, c) -> (c+b, c, 0) exclusive==true +static inline int64_t CumsumReverse@SIMD_INSTRUCTION@(int64_t index, const float *layer_input, float *layer_output, + float *layer_last_output, int inner_dim) { + + for (int block_max_size = inner_dim - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 input_val = SIMD_LD_F32(layer_input - index - BLOCK_NUM + 1); + SIMD_F32 last_output_val = SIMD_LD_F32(layer_last_output - index - BLOCK_NUM + 1); + SIMD_F32 out_val = SIMD_ADD_F32(input_val, last_output_val); + SIMD_ST_F32(layer_output - index - BLOCK_NUM + 1, out_val); + } + return index; +} + +// (a, b, c) -> (a, a+b, a+b+c) exclusive == false +// (a, b, c) -> (0, a, a+b) exclusive == true +static inline int64_t CumsumIntOutputInitWithInput@SIMD_INSTRUCTION@(int64_t index, const int32_t *layer_input, + int32_t *layer_output, int inner_dim) { + for (int block_max_size = inner_dim - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_EPI32(layer_output + index, SIMD_LD_EPI32(layer_input + index)); + } + return index; +} + +static inline int64_t CumsumIntOutputInitWithZero@SIMD_INSTRUCTION@(int64_t index, int32_t *layer_output, int inner_dim) { + for (int block_max_size = inner_dim - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_EPI32(layer_output + index, SIMD_MOV_EPI32(0.0f)); + } + return index; +} + +static inline int64_t CumsumInt@SIMD_INSTRUCTION@(int64_t index, const int32_t *layer_input, int32_t *layer_output, int32_t *layer_last_output, + int inner_dim) { + for (int block_max_size = inner_dim - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 input_val = SIMD_LD_EPI32(layer_input + index); + SIMD_EPI32 last_output_val = SIMD_LD_EPI32(layer_last_output + index); + SIMD_EPI32 out_val = SIMD_ADD_EPI32(input_val, last_output_val); + SIMD_ST_EPI32(layer_output + index, out_val); + } + return index; +} + +// (a, b, c) -> (c+b+a, c+b, c) exclusive==false +// (a, b, c) -> (c+b, c, 0) exclusive==true +static inline int64_t CumsumReverseInt@SIMD_INSTRUCTION@(int64_t index, const int32_t *layer_input, int32_t *layer_output, int32_t *layer_last_output, + int inner_dim) { + for (int block_max_size = inner_dim - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 input_val = SIMD_LD_EPI32(layer_input - index - BLOCK_NUM + 1); + SIMD_EPI32 last_output_val = SIMD_LD_EPI32(layer_last_output - index - BLOCK_NUM + 1); + SIMD_EPI32 out_val = SIMD_ADD_EPI32(input_val, last_output_val); + SIMD_ST_EPI32(layer_output - index - BLOCK_NUM + 1, out_val); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/custom_gru_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/custom_gru_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..434d998c4c299a944e0d74867242ff2238b6e51c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/custom_gru_fp32.c @@ -0,0 +1,72 @@ +#ifdef ENABLE_ARM64 +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp32/custom_gru_fp32.h" +#include "nnacl/fp32/activation_fp32.h" +#include "nnacl/fp32/arithmetic_fp32.h" +#include "nnacl/fp32/matmul_fp32.h" +#include "nnacl/fp32/pack_fp32.h" + +void CustomGru(float *output, const float *input, const float *weight_input, const float *weight_hidden, + const float *bias_input, const float *bias_hidden, const float *init_h, float *buffer[4], + const CustomGruParameter *gru_param) { + int num_step = gru_param->num_step; + int batch_size = gru_param->batch_size; + int input_size = gru_param->input_size; + int hidden_size = gru_param->hidden_size; + int output_size = batch_size * hidden_size; + int double_output_size = output_size * C2NUM; + int col_align = UP_ROUND(hidden_size, C8NUM); + int weight_in_offset = col_align * input_size; + int weight_hidden_offset = col_align * hidden_size; + float *input_gate = buffer[1]; + float *hidden_gate = buffer[C3NUM]; + for (int i = 0; i < num_step; ++i) { + if (batch_size != 1) { + RowMajor2Col12MajorParallel(input + i * batch_size * input_size, buffer[0], batch_size, input_size, 0, + batch_size); + for (int j = 0; j < C3NUM; ++j) { + MatMulOpt(buffer[0], weight_input + j * weight_in_offset, input_gate + j * output_size, + bias_input + j * col_align, ActType_No, input_size, batch_size, hidden_size, hidden_size, + OutType_Nhwc); + } + RowMajor2Col12MajorParallel(init_h, buffer[C2NUM], batch_size, hidden_size, 0, batch_size); + for (int j = 0; j < C3NUM; ++j) { + MatMulOpt(buffer[C2NUM], weight_hidden + j * weight_hidden_offset, hidden_gate + j * output_size, + bias_hidden + j * col_align, ActType_No, hidden_size, batch_size, hidden_size, hidden_size, + OutType_Nhwc); + } + } else { + for (int j = 0; j < C3NUM; ++j) { + MatVecMulPackFp32(input + i * input_size, weight_input + j * weight_in_offset, input_gate + j * output_size, + bias_input + j * col_align, ActType_No, input_size, hidden_size); + MatVecMulPackFp32(init_h, weight_hidden + j * weight_hidden_offset, hidden_gate + j * output_size, + bias_hidden + j * col_align, ActType_No, hidden_size, hidden_size); + } + } + ElementAdd(input_gate, hidden_gate, input_gate, double_output_size); + Sigmoid(input_gate, double_output_size, input_gate); + ElementMul(input_gate, hidden_gate + double_output_size, input_gate, output_size); + ElementAdd(input_gate, input_gate + double_output_size, input_gate, output_size); + Tanh(input_gate, output_size, input_gate); + ElementSub(init_h, input_gate, hidden_gate, output_size); + ElementMul(input_gate + output_size, hidden_gate, hidden_gate, output_size); + ElementAdd(input_gate, hidden_gate, output, output_size); + init_h = output; + output += output_size; + } +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/custom_gru_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/custom_gru_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..576726c5a316c0138ea4f9a20c1895124280886e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/custom_gru_fp32.h @@ -0,0 +1,32 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_CUSTOM_GRU_FP32_H_ +#define MINDSPORE_NNACL_FP32_CUSTOM_GRU_FP32_H_ +#ifdef ENABLE_ARM64 +#include "nnacl/custom_gru_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void CustomGru(float *output, const float *input, const float *weight_input, const float *weight_hidden, + const float *bias_input, const float *bias_hidden, const float *init_h, float *buffer[4], + const CustomGruParameter *gru_param); +#ifdef __cplusplus +} +#endif + +#endif +#endif // MINDSPORE_NNACL_FP32_CUSTOM_GRU_FP32_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/deconv_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/deconv_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..1fbc0b3dc8f3af0ce763eba9538729e1800a8a8f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/deconv_fp32.c @@ -0,0 +1,109 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/deconv_fp32.h" + +void PackDeConvWeightFp32(const float *weight, float *dst, int input_channel, int output_channel, int plane) { + /* ichwoc(nhwc) -> oc4 * h * w * incUP4 * 4 */ + int ic_up4 = UP_ROUND(input_channel, C4NUM); + for (int oc = 0; oc < output_channel; oc++) { + int oc4div = oc / C4NUM; + int oc4mod = oc % C4NUM; + for (int ic = 0; ic < input_channel; ic++) { + for (int hw = 0; hw < plane; hw++) { + int src_index = ic * plane * output_channel + hw * output_channel + oc; + int dst_index = oc4div * ic_up4 * plane * C4NUM + hw * ic_up4 * C4NUM + ic * C4NUM + oc4mod; + dst[dst_index] = weight[src_index]; + } + } + } + return; +} + +void DeConvPostFp32C8(const float *src, float *tmp, const float *bias, float *dst, int output_channel, + const ConvParameter *conv_param) { + /* arm64 row12x8-major(ih*iw x oc*kh*kw) -> row8-major(oh*ow x oc) */ + /* arm32 row4x8-major(ih*iw x oc*kh*kw) -> row8-major(oh*ow x oc) */ + size_t input_plane = conv_param->input_w_ * conv_param->input_h_; + size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_; + size_t output_plane = conv_param->output_w_ * conv_param->output_h_; + int oc8 = UP_ROUND(output_channel, C8NUM); +#if defined(ENABLE_ARM32) || defined(ENABLE_SSE) + const int tile_num = 4; +#else + const int tile_num = 12; +#endif + int in_plane_round = UP_ROUND(input_plane, tile_num); + int src_iw_stride = C8NUM; + int src_ih_stride = conv_param->input_w_ * C8NUM; + int src_kw_stride = in_plane_round * C8NUM; + int src_kh_stride = in_plane_round * conv_param->kernel_w_ * C8NUM; + int dst_oh_stride = conv_param->output_w_ * C8NUM; + int dst_ow_stride = C8NUM; + int dst_kh_stride = conv_param->dilation_h_ * conv_param->output_w_ * C8NUM; + int dst_kw_stride = conv_param->dilation_w_ * C8NUM; + if (conv_param->dilation_h_ == 0 || conv_param->dilation_w_ == 0) { + return; + } + for (int c = 0; c < oc8; c += 8) { + float *dst_ptr = tmp + c * output_plane; + const float *src_ptr = src + c * in_plane_round * kernel_plane; + memset(dst_ptr, 0, output_plane * C8NUM * (int)sizeof(float)); + + for (int ih = 0; ih < conv_param->input_h_; ih++) { + for (int iw = 0; iw < conv_param->input_w_; iw++) { + int oh = ih * conv_param->stride_h_ - conv_param->pad_u_; + int ow = iw * conv_param->stride_w_ - conv_param->pad_l_; + + int kh_start = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_)); + int kh_end = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_)); + int kw_start = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_)); + int kw_end = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_)); + for (int kh = kh_start; kh < kh_end; kh++) { + for (int kw = kw_start; kw < kw_end; kw++) { + int src_index = ih * src_ih_stride + iw * src_iw_stride + kh * src_kh_stride + kw * src_kw_stride; + int dst_index = oh * dst_oh_stride + ow * dst_ow_stride + kh * dst_kh_stride + kw * dst_kw_stride; + float *tmp_dst = dst_ptr + dst_index; + const float *tmp_src = src_ptr + src_index; +#ifdef ENABLE_ARM64 + asm volatile( + "mov x0, %[tmp_src] \n" + "mov x1, %[tmp_dst] \n" + + "ld1 {v0.4s, v1.4s}, [x0] \n" + "ld1 {v2.4s, v3.4s}, [x1] \n" + + "fadd v0.4s, v0.4s, v2.4s \n" + "fadd v1.4s, v1.4s, v3.4s \n" + + "st1 {v0.4s, v1.4s}, [x1] \n" + + : + : [ tmp_src ] "r"(tmp_src), [ tmp_dst ] "r"(tmp_dst) + : "x0", "x1", "v0", "v1", "v2", "v3"); +#else + for (int i = 0; i < C8NUM; i++) { + tmp_dst[i] += tmp_src[i]; + } +#endif + } /*kw*/ + } /*kh*/ + } /*iw*/ + } /*ih*/ + } /*oc8*/ + + PostConvFuncFp32C8(tmp, dst, bias, output_channel, output_plane, conv_param->output_channel_, conv_param->act_type_); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/deconv_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/deconv_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..06ba80c149198900f2a9d15165deb81df56588e0 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/deconv_fp32.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_DECONV_H_ +#define MINDSPORE_NNACL_FP32_DECONV_H_ + +#include +#include "nnacl/pack.h" +#include "nnacl/op_base.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/errorcode.h" +#include "nnacl/fp32/common_func_fp32.h" +#include "nnacl/base/minimal_filtering_generator.h" + +#ifdef __cplusplus +extern "C" { +#endif +void PackDeConvWeightFp32(const float *weight, float *dst, int input_channel, int output_channel, int plane); +void DeConvPostFp32C8(const float *src, float *tmp_out, const float *bias, float *dst, int output_channel, + const ConvParameter *conv_param); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_DECONV_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/deconv_winograd_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/deconv_winograd_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..679f21033d396a967eff1b4d9b5893835293c0da --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/deconv_winograd_fp32.c @@ -0,0 +1,733 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/deconv_winograd_fp32.h" +#include "nnacl/errorcode.h" +#include "nnacl/intrinsics/ms_simd_instructions.h" + +int PackDeConvWgDataFp32(const float *nhwc_weight, DeConvComputeUnit *unit, const ConvParameter *conv_param, + const DeConvParam *deconv_param) { +#ifdef ENABLE_AVX + int tile_num = C8NUM; +#else + int tile_num = C4NUM; +#endif + unsigned int tmp_kernel_plane = unit->w_size_ * unit->h_size_; + unsigned int size = conv_param->input_channel_ * conv_param->output_channel_ * tmp_kernel_plane; + float *current_unit_weight = (float *)malloc(size * sizeof(float)); + if (current_unit_weight == NULL) { + return NNACL_NULL_PTR; + } + for (int ic = 0; ic < conv_param->input_channel_; ic++) { + const float *src_ic = nhwc_weight + deconv_param->kernel_plane_ * conv_param->output_channel_ * ic; + float *dst_ic = current_unit_weight + tmp_kernel_plane * conv_param->output_channel_ * ic; + for (int uhi = 0; uhi < unit->h_size_; uhi++) { + for (int uwi = 0; uwi < unit->w_size_; uwi++) { + int src_h_offset = unit->h_start_ + uhi * conv_param->stride_h_; + int src_w_offset = unit->w_start_ + uwi * conv_param->stride_w_; + const float *src_hw = + src_ic + (src_h_offset * conv_param->kernel_w_ + src_w_offset) * conv_param->output_channel_; + float *dst_hw = dst_ic + (uhi * unit->w_size_ + uwi) * conv_param->output_channel_; + memcpy(dst_hw, src_hw, conv_param->output_channel_ * sizeof(float)); + } + } + } + + if (unit->use_winograd_) { + /* Generate winograd */ + float matrix_g[64], matrix_a[64], matrix_b[64]; + float matrix_gt[64], matrix_at[64], matrix_bt[64]; + int ret = CookToomFilter(matrix_a, matrix_at, matrix_b, matrix_bt, matrix_g, matrix_gt, 0.5f, + DECONV_WINOGRAD_DEFAULT_UNIT, unit->h_size_); + if (ret != NNACL_OK) { + free(current_unit_weight); + current_unit_weight = NULL; + return NNACL_ERRCODE_WINOGRAD_GENERATOR_ERROR; + } + + /* winograd AT */ + unit->winograd_.AT_ = malloc(unit->winograd_.i_ * unit->winograd_.o_ * sizeof(float)); + if (unit->winograd_.AT_ == NULL) { + if (current_unit_weight != NULL) { + free(current_unit_weight); + current_unit_weight = NULL; + } + return NNACL_NULL_PTR; + } + memcpy(unit->winograd_.AT_, matrix_at, unit->winograd_.i_ * unit->winograd_.o_ * sizeof(float)); + + /* winograd BT */ + unit->winograd_.BT_ = malloc(unit->winograd_.o_ * unit->winograd_.o_ * sizeof(float)); + if (unit->winograd_.BT_ == NULL) { + if (current_unit_weight != NULL) { + free(current_unit_weight); + current_unit_weight = NULL; + } + if (unit->winograd_.AT_ != NULL) { + free(unit->winograd_.AT_); + unit->winograd_.AT_ = NULL; + } + return NNACL_NULL_PTR; + } + memcpy(unit->winograd_.BT_, matrix_bt, unit->winograd_.o_ * unit->winograd_.o_ * sizeof(float)); + + /* winograd Weight */ + size = conv_param->input_channel_ * conv_param->output_channel_ * unit->winograd_.kh_ * unit->winograd_.kw_; + float *winograd_unit_weight = (float *)malloc(size * sizeof(float)); + if (winograd_unit_weight == NULL) { + if (current_unit_weight != NULL) { + free(current_unit_weight); + current_unit_weight = NULL; + } + if (unit->winograd_.AT_ != NULL) { + free(unit->winograd_.AT_); + unit->winograd_.AT_ = NULL; + } + if (unit->winograd_.BT_ != NULL) { + free(unit->winograd_.BT_); + unit->winograd_.BT_ = NULL; + } + return NNACL_NULL_PTR; + } + WinogradWeightTransform(current_unit_weight, winograd_unit_weight, matrix_g, matrix_gt, tile_num, + unit->winograd_.kh_, unit->h_size_, conv_param->output_channel_, conv_param->input_channel_, + false); + + /* reset weight data & info */ + tmp_kernel_plane = unit->winograd_.kh_ * unit->winograd_.kw_; + free(current_unit_weight); + current_unit_weight = NULL; + current_unit_weight = winograd_unit_weight; + winograd_unit_weight = NULL; + } + + /* trans mhwc -> hw1:k1-knc0-c4:k1-knc5-c8:hw2:k1-knc0-c4:k1 */ + float *dst_weight = (float *)unit->weight_; + size = deconv_param->ic_up_ * deconv_param->oc_up_ * tmp_kernel_plane; + memset(dst_weight, 0, size * sizeof(float)); + for (int ic = 0; ic < conv_param->input_channel_; ic++) { + for (int oc = 0; oc < conv_param->output_channel_; oc++) { + int oc4div = oc / tile_num, oc4mod = oc % tile_num; + for (int upi = 0; upi < tmp_kernel_plane; upi++) { + int src_index = ic * conv_param->output_channel_ * tmp_kernel_plane + upi * conv_param->output_channel_ + oc; + int dst_index = upi * deconv_param->oc_up_ * deconv_param->ic_up_ + oc4div * tile_num * deconv_param->ic_up_ + + ic * tile_num + oc4mod; + dst_weight[dst_index] = current_unit_weight[src_index]; + } + } + } + + if (current_unit_weight != NULL) { + free(current_unit_weight); + current_unit_weight = NULL; + } + return NNACL_OK; +} + +void DeConvWgInputPack(const float *src_ptr, float *dst_ptr, int channel, int stride) { +#ifdef ENABLE_AVX + int ic_tile = C8NUM; +#else + int ic_tile = C4NUM; +#endif + int ic4div = channel / ic_tile; + int ic4mod = channel % ic_tile; + const float *src = src_ptr; + float *dst = dst_ptr; + + for (int ic = 0; ic < ic4div; ic++) { +#ifdef ENABLE_AVX + MS_ST256_F32(dst, MS_LD256_F32(src)); +#elif defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_STQ_F32(dst, MS_LDQ_F32(src)); +#else + memcpy(dst, src, C4NUM * sizeof(float)); +#endif + dst += stride; + src += ic_tile; + } + + if (ic4mod != 0) { + int ic_res = 0; + for (; ic_res < ic4mod; ic_res++) { + dst[ic_res] = src[ic_res]; + } + for (; ic_res < ic_tile; ic_res++) { + dst[ic_res] = 0; + } + } +} + +#if !defined(ENABLE_ARM) && !defined(ENABLE_SSE) +void TiledC4MatmulFp32(float *dst, const float *src, const float *weight, size_t cal_num, size_t ic4, size_t oc4) { + int dx, sz, dz; + const int src_depth_step = C4NUM * DECONV_WINOGRAD_DEFAULT_TILE; + for (dz = 0; dz < oc4; ++dz) { + float *dst_z = dst + dz * cal_num; + const float *weight_dz = weight + dz * ic4 * C16NUM; + for (dx = 0; dx < DECONV_WINOGRAD_DEFAULT_TILE; ++dx) { + float *dst_x = dst_z + dx * C4NUM; + dst_x[0] = 0.0f; + dst_x[1] = 0.0f; + dst_x[2] = 0.0f; + dst_x[3] = 0.0f; + const float *src_dx = src + C4NUM * dx; + for (sz = 0; sz < ic4; ++sz) { + const float *src_z = src_dx + sz * src_depth_step; + const float *weight_z = weight_dz + sz * C16NUM; + for (int i = 0; i < C4NUM; ++i) { + for (int j = 0; j < C4NUM; ++j) { + dst_x[j] += src_z[i] * weight_z[C4NUM * i + j]; + } + } + } + } + } +} +#endif + +#ifdef ENABLE_ARM32 +#ifndef SUPPORT_NNIE +void DeConvWgMergeArm32(const float *src_ptr, float *dst_ptr, size_t src_step, size_t dst_step) { + asm volatile( + "mov r11, %[src_ptr]\n" + "mov r8, %[dst_ptr]\n" + "mov r10, r8\n" + + "vld1.32 {q0}, [r11], %[src_step]\n" + "vld1.32 {q1}, [r8], %[dst_step]\n" + "vld1.32 {q2}, [r11], %[src_step]\n" + "vld1.32 {q3}, [r8], %[dst_step]\n" + + "vadd.f32 q0, q0, q1\n" + "vld1.32 {q8}, [r11], %[src_step]\n" + "vadd.f32 q2, q2, q3\n" + + "vst1.32 {q0}, [r10], %[dst_step]\n" + "vst1.32 {q2}, [r10], %[dst_step]\n" + + "vld1.32 {q9}, [r8], %[dst_step]\n" + + "vld1.32 {q10}, [r11], %[src_step]\n" + + "vadd.f32 q8, q8, q9\n" + "vld1.32 {q11}, [r8], %[dst_step]\n" + "vadd.f32 q10, q10, q11\n" + + "vld1.32 {q0}, [r11], %[src_step]\n" + "vst1.32 {q8}, [r10], %[dst_step]\n" + "vst1.32 {q10}, [r10], %[dst_step]\n" + + "vld1.32 {q1}, [r8], %[dst_step]\n" + + "vld1.32 {q2}, [r11], %[src_step]\n" + "vld1.32 {q3}, [r8], %[dst_step]\n" + + "vadd.f32 q0, q0, q1\n" + "vadd.f32 q2, q2, q3\n" + + "vst1.32 {q0}, [r10], %[dst_step]\n" + "vst1.32 {q2}, [r10], %[dst_step]\n" + + "vld1.32 {q8}, [r11], %[src_step]\n" + "vld1.32 {q9}, [r8], %[dst_step]\n" + + "vld1.32 {q10}, [r11], %[src_step]\n" + "vld1.32 {q11}, [r8], %[dst_step]\n" + + "vadd.f32 q8, q8, q9\n" + "vadd.f32 q10, q10, q11\n" + + "vst1.32 {q8}, [r10], %[dst_step]\n" + "vst1.32 {q10}, [r10], %[dst_step]\n" + + : + : [ src_ptr ] "r"(src_ptr), [ dst_ptr ] "r"(dst_ptr), [ src_step ] "r"(src_step), [ dst_step ] "r"(dst_step) + : "r8", "r10", "r11", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11"); + return; +} +#else +void DeConvWgMergeArm32(const float *src_ptr, float *dst_ptr, size_t src_step, size_t dst_step) { + asm volatile( + "mov r7, %[src_ptr]\n" + "mov r8, %[dst_ptr]\n" + "mov r10, r8\n" + + "vld1.32 {q0}, [r7], %[src_step]\n" + "vld1.32 {q1}, [r8], %[dst_step]\n" + "vld1.32 {q2}, [r7], %[src_step]\n" + "vld1.32 {q3}, [r8], %[dst_step]\n" + + "vadd.f32 q0, q0, q1\n" + "vld1.32 {q8}, [r7], %[src_step]\n" + "vadd.f32 q2, q2, q3\n" + + "vst1.32 {q0}, [r10], %[dst_step]\n" + "vst1.32 {q2}, [r10], %[dst_step]\n" + + "vld1.32 {q9}, [r8], %[dst_step]\n" + + "vld1.32 {q10}, [r7], %[src_step]\n" + + "vadd.f32 q8, q8, q9\n" + "vld1.32 {q11}, [r8], %[dst_step]\n" + "vadd.f32 q10, q10, q11\n" + + "vld1.32 {q0}, [r7], %[src_step]\n" + "vst1.32 {q8}, [r10], %[dst_step]\n" + "vst1.32 {q10}, [r10], %[dst_step]\n" + + "vld1.32 {q1}, [r8], %[dst_step]\n" + + "vld1.32 {q2}, [r7], %[src_step]\n" + "vld1.32 {q3}, [r8], %[dst_step]\n" + + "vadd.f32 q0, q0, q1\n" + "vadd.f32 q2, q2, q3\n" + + "vst1.32 {q0}, [r10], %[dst_step]\n" + "vst1.32 {q2}, [r10], %[dst_step]\n" + + "vld1.32 {q8}, [r7], %[src_step]\n" + "vld1.32 {q9}, [r8], %[dst_step]\n" + + "vld1.32 {q10}, [r7], %[src_step]\n" + "vld1.32 {q11}, [r8], %[dst_step]\n" + + "vadd.f32 q8, q8, q9\n" + "vadd.f32 q10, q10, q11\n" + + "vst1.32 {q8}, [r10], %[dst_step]\n" + "vst1.32 {q10}, [r10], %[dst_step]\n" + + : + : [ src_ptr ] "r"(src_ptr), [ dst_ptr ] "r"(dst_ptr), [ src_step ] "r"(src_step), [ dst_step ] "r"(dst_step) + : "r8", "r10", "r7", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11"); + return; +} +#endif +#endif + +#ifdef ENABLE_AVX +void DeConvWgMerge(const float *src, float *dst, size_t src_stride, size_t dst_stride, size_t count) { + const float *src_ptr = src; + float *dst_ptr = dst; + size_t count8 = count / C8NUM * C8NUM; + size_t count4 = count / C4NUM * C4NUM; + int i = 0; + for (; i < count8; i += C8NUM) { + MS_FLOAT32X8 src1 = MS_LD256_F32(src_ptr + 0 * src_stride); + MS_FLOAT32X8 src2 = MS_LD256_F32(src_ptr + 1 * src_stride); + MS_FLOAT32X8 src3 = MS_LD256_F32(src_ptr + 2 * src_stride); + MS_FLOAT32X8 src4 = MS_LD256_F32(src_ptr + 3 * src_stride); + MS_FLOAT32X8 src5 = MS_LD256_F32(src_ptr + 4 * src_stride); + MS_FLOAT32X8 src6 = MS_LD256_F32(src_ptr + 5 * src_stride); + MS_FLOAT32X8 src7 = MS_LD256_F32(src_ptr + 6 * src_stride); + MS_FLOAT32X8 src8 = MS_LD256_F32(src_ptr + 7 * src_stride); + MS_FLOAT32X8 dst1 = MS_LD256_F32(dst_ptr + 0 * dst_stride); + MS_FLOAT32X8 dst2 = MS_LD256_F32(dst_ptr + 1 * dst_stride); + MS_FLOAT32X8 dst3 = MS_LD256_F32(dst_ptr + 2 * dst_stride); + MS_FLOAT32X8 dst4 = MS_LD256_F32(dst_ptr + 3 * dst_stride); + MS_FLOAT32X8 dst5 = MS_LD256_F32(dst_ptr + 4 * dst_stride); + MS_FLOAT32X8 dst6 = MS_LD256_F32(dst_ptr + 5 * dst_stride); + MS_FLOAT32X8 dst7 = MS_LD256_F32(dst_ptr + 6 * dst_stride); + MS_FLOAT32X8 dst8 = MS_LD256_F32(dst_ptr + 7 * dst_stride); + dst1 = MS_ADD256_F32(dst1, src1); + dst2 = MS_ADD256_F32(dst2, src2); + dst3 = MS_ADD256_F32(dst3, src3); + dst4 = MS_ADD256_F32(dst4, src4); + dst5 = MS_ADD256_F32(dst5, src5); + dst6 = MS_ADD256_F32(dst6, src6); + dst7 = MS_ADD256_F32(dst7, src7); + dst8 = MS_ADD256_F32(dst8, src8); + MS_ST256_F32(dst_ptr + 0 * dst_stride, dst1); + MS_ST256_F32(dst_ptr + 1 * dst_stride, dst2); + MS_ST256_F32(dst_ptr + 2 * dst_stride, dst3); + MS_ST256_F32(dst_ptr + 3 * dst_stride, dst4); + MS_ST256_F32(dst_ptr + 4 * dst_stride, dst5); + MS_ST256_F32(dst_ptr + 5 * dst_stride, dst6); + MS_ST256_F32(dst_ptr + 6 * dst_stride, dst7); + MS_ST256_F32(dst_ptr + 7 * dst_stride, dst8); + src_ptr += C8NUM * src_stride; + dst_ptr += C8NUM * dst_stride; + } + for (; i < count4; i += C4NUM) { + MS_FLOAT32X8 src1 = MS_LD256_F32(src_ptr + 0 * src_stride); + MS_FLOAT32X8 src2 = MS_LD256_F32(src_ptr + 1 * src_stride); + MS_FLOAT32X8 src3 = MS_LD256_F32(src_ptr + 2 * src_stride); + MS_FLOAT32X8 src4 = MS_LD256_F32(src_ptr + 3 * src_stride); + MS_FLOAT32X8 dst1 = MS_LD256_F32(dst_ptr + 0 * dst_stride); + MS_FLOAT32X8 dst2 = MS_LD256_F32(dst_ptr + 1 * dst_stride); + MS_FLOAT32X8 dst3 = MS_LD256_F32(dst_ptr + 2 * dst_stride); + MS_FLOAT32X8 dst4 = MS_LD256_F32(dst_ptr + 3 * dst_stride); + dst1 = MS_ADD256_F32(dst1, src1); + dst2 = MS_ADD256_F32(dst2, src2); + dst3 = MS_ADD256_F32(dst3, src3); + dst4 = MS_ADD256_F32(dst4, src4); + MS_ST256_F32(dst_ptr + 0 * dst_stride, dst1); + MS_ST256_F32(dst_ptr + 1 * dst_stride, dst2); + MS_ST256_F32(dst_ptr + 2 * dst_stride, dst3); + MS_ST256_F32(dst_ptr + 3 * dst_stride, dst4); + src_ptr += C4NUM * src_stride; + dst_ptr += C4NUM * dst_stride; + } + for (; i < count; i++) { + MS_FLOAT32X8 src_data = MS_LD256_F32(src_ptr); + MS_FLOAT32X8 dst_data = MS_LD256_F32(dst_ptr); + dst_data = MS_ADD256_F32(src_data, dst_data); + MS_ST256_F32(dst_ptr, dst_data); + src_ptr += src_stride; + dst_ptr += dst_stride; + } +} +#else +void DeConvWgMerge(const float *src, float *dst, size_t src_stride, size_t dst_stride, size_t count) { + const float *src_ptr = src; + float *dst_ptr = dst; + size_t count8 = count / C8NUM * C8NUM; + size_t count4 = count / C4NUM * C4NUM; + int i = 0; + for (; i < count8; i += C8NUM) { +#ifdef ENABLE_ARM64 + size_t src_step = src_stride * sizeof(float); + size_t dst_step = dst_stride * sizeof(float); + asm volatile( + "mov x7, %[src_ptr]\n" + "mov x8, %[dst_ptr]\n" + "mov x10, x8\n" + + "ld1 {v0.4s}, [x7], %[src_step]\n" + "ld1 {v1.4s}, [x8], %[dst_step]\n" + + "ld1 {v2.4s}, [x7], %[src_step]\n" + "ld1 {v3.4s}, [x8], %[dst_step]\n" + + "fadd v0.4s, v0.4s, v1.4s\n" + "ld1 {v4.4s}, [x7], %[src_step]\n" + "fadd v2.4s, v2.4s, v3.4s\n" + + "st1 {v0.4s}, [x10], %[dst_step]\n" + "st1 {v2.4s}, [x10], %[dst_step]\n" + + "ld1 {v5.4s}, [x8], %[dst_step]\n" + + "ld1 {v6.4s}, [x7], %[src_step]\n" + + "fadd v4.4s, v4.4s, v5.4s\n" + "ld1 {v7.4s}, [x8], %[dst_step]\n" + "fadd v6.4s, v6.4s, v7.4s\n" + + "ld1 {v0.4s}, [x7], %[src_step]\n" + "st1 {v4.4s}, [x10], %[dst_step]\n" + "st1 {v6.4s}, [x10], %[dst_step]\n" + + "ld1 {v1.4s}, [x8], %[dst_step]\n" + + "ld1 {v2.4s}, [x7], %[src_step]\n" + "ld1 {v3.4s}, [x8], %[dst_step]\n" + + "fadd v0.4s, v0.4s, v1.4s\n" + "fadd v2.4s, v2.4s, v3.4s\n" + + "st1 {v0.4s}, [x10], %[dst_step]\n" + "st1 {v2.4s}, [x10], %[dst_step]\n" + + "ld1 {v4.4s}, [x7], %[src_step]\n" + "ld1 {v5.4s}, [x8], %[dst_step]\n" + + "ld1 {v6.4s}, [x7], %[src_step]\n" + "ld1 {v7.4s}, [x8], %[dst_step]\n" + + "fadd v4.4s, v4.4s, v5.4s\n" + "fadd v6.4s, v6.4s, v7.4s\n" + + "st1 {v4.4s}, [x10], %[dst_step]\n" + "st1 {v6.4s}, [x10], %[dst_step]\n" + + : + : [ src_ptr ] "r"(src_ptr), [ dst_ptr ] "r"(dst_ptr), [ src_step ] "r"(src_step), [ dst_step ] "r"(dst_step) + : "x7", "x8", "x10", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"); +#elif defined(ENABLE_ARM32) + size_t src_step = src_stride * sizeof(float); + size_t dst_step = dst_stride * sizeof(float); + DeConvWgMergeArm32(src_ptr, dst_ptr, src_step, dst_step); +#elif defined(ENABLE_SSE) + MS_STQ_F32(dst_ptr + 0 * dst_stride, + MS_ADDQ_F32(MS_LDQ_F32(dst_ptr + 0 * dst_stride), MS_LDQ_F32(src_ptr + 0 * src_stride))); + MS_STQ_F32(dst_ptr + 1 * dst_stride, + MS_ADDQ_F32(MS_LDQ_F32(dst_ptr + 1 * dst_stride), MS_LDQ_F32(src_ptr + 1 * src_stride))); + MS_STQ_F32(dst_ptr + 2 * dst_stride, + MS_ADDQ_F32(MS_LDQ_F32(dst_ptr + 2 * dst_stride), MS_LDQ_F32(src_ptr + 2 * src_stride))); + MS_STQ_F32(dst_ptr + 3 * dst_stride, + MS_ADDQ_F32(MS_LDQ_F32(dst_ptr + 3 * dst_stride), MS_LDQ_F32(src_ptr + 3 * src_stride))); + MS_STQ_F32(dst_ptr + C4NUM * dst_stride, + MS_ADDQ_F32(MS_LDQ_F32(dst_ptr + C4NUM * dst_stride), MS_LDQ_F32(src_ptr + C4NUM * src_stride))); + MS_STQ_F32(dst_ptr + 5 * dst_stride, + MS_ADDQ_F32(MS_LDQ_F32(dst_ptr + 5 * dst_stride), MS_LDQ_F32(src_ptr + 5 * src_stride))); + MS_STQ_F32(dst_ptr + 6 * dst_stride, + MS_ADDQ_F32(MS_LDQ_F32(dst_ptr + 6 * dst_stride), MS_LDQ_F32(src_ptr + 6 * src_stride))); + MS_STQ_F32(dst_ptr + 7 * dst_stride, + MS_ADDQ_F32(MS_LDQ_F32(dst_ptr + 7 * dst_stride), MS_LDQ_F32(src_ptr + 7 * src_stride))); +#else + for (int j = 0; j < C8NUM; j++) { + const float *s = src_ptr + j * src_stride; + float *d = dst_ptr + j * dst_stride; + for (int k = 0; k < C4NUM; k++) { + d[k] += s[k]; + } + } +#endif + src_ptr += C8NUM * src_stride; + dst_ptr += C8NUM * dst_stride; + } + for (; i < count4; i += C4NUM) { +#if defined(ENABLE_SSE) || defined(ENABLE_ARM) + MS_STQ_F32(dst_ptr + 0 * dst_stride, + MS_ADDQ_F32(MS_LDQ_F32(dst_ptr + 0 * dst_stride), MS_LDQ_F32(src_ptr + 0 * src_stride))); + MS_STQ_F32(dst_ptr + 1 * dst_stride, + MS_ADDQ_F32(MS_LDQ_F32(dst_ptr + 1 * dst_stride), MS_LDQ_F32(src_ptr + 1 * src_stride))); + MS_STQ_F32(dst_ptr + 2 * dst_stride, + MS_ADDQ_F32(MS_LDQ_F32(dst_ptr + 2 * dst_stride), MS_LDQ_F32(src_ptr + 2 * src_stride))); + MS_STQ_F32(dst_ptr + 3 * dst_stride, + MS_ADDQ_F32(MS_LDQ_F32(dst_ptr + 3 * dst_stride), MS_LDQ_F32(src_ptr + 3 * src_stride))); +#else + for (int j = 0; j < C4NUM; j++) { + const float *s = src_ptr + j * src_stride; + float *d = dst_ptr + j * dst_stride; + for (int k = 0; k < C4NUM; k++) { + d[k] += s[k]; + } + } +#endif + src_ptr += C4NUM * src_stride; + dst_ptr += C4NUM * dst_stride; + } + for (; i < count; i++) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src_data = MS_LDQ_F32(src_ptr); + MS_FLOAT32X4 dst_data = MS_LDQ_F32(dst_ptr); + dst_data = MS_ADDQ_F32(src_data, dst_data); + MS_STQ_F32(dst_ptr, dst_data); +#else + for (int j = 0; j < C4NUM; j++) { + dst_ptr[j] += src_ptr[j]; + } +#endif + src_ptr += src_stride; + dst_ptr += dst_stride; + } +} +#endif + +void DeConvWgCalWgFp32(const float *tile_in, float *tile_out, const float *weight_buf, float *tmp_buf, + const float *at_buf, float *a_mid_buf, float *trans_a_buf, bool *transferred, + const float *bt_buf, float *b_tmp_buf, int unit_size, int w_start, int h_start, + const ConvParameter *conv_param, const DeConvParam *deconv_param) { +#ifdef ENABLE_AVX + int tile_num = C8NUM; + TiledMatmulFp32 matmul_func = TiledC8MatmulFp32; +#else + TiledMatmulFp32 matmul_func = TiledC4MatmulFp32; + int tile_num = C4NUM; +#endif + int winograd_plane = unit_size * unit_size; + if (!transferred[unit_size]) { + WinogradTransLeft(tile_in, at_buf, a_mid_buf, DECONV_WINOGRAD_DEFAULT_UNIT, unit_size, DECONV_WINOGRAD_DEFAULT_UNIT, + deconv_param->ic_div_ * DECONV_WINOGRAD_DEFAULT_TILE); + WinogradTransRight(a_mid_buf, at_buf, trans_a_buf, unit_size, unit_size, DECONV_WINOGRAD_DEFAULT_UNIT, + deconv_param->ic_div_ * DECONV_WINOGRAD_DEFAULT_TILE); + transferred[unit_size] = true; + } + + for (int index = 0; index < winograd_plane; index++) { + float *src = trans_a_buf + index * DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->ic_up_; + float *dst = tmp_buf + index * deconv_param->oc_up_ * DECONV_WINOGRAD_DEFAULT_TILE; + const float *weight = weight_buf + index * deconv_param->ic_up_ * deconv_param->oc_up_; + matmul_func(dst, src, weight, DECONV_WINOGRAD_DEFAULT_TILE * tile_num, deconv_param->ic_div_, + deconv_param->oc_div_); + } + WinogradTransLeft(tmp_buf, bt_buf, b_tmp_buf, unit_size, unit_size, unit_size, + deconv_param->oc_div_ * DECONV_WINOGRAD_DEFAULT_TILE); + WinogradTransRight(b_tmp_buf, bt_buf, tmp_buf, unit_size, unit_size, unit_size, + deconv_param->oc_div_ * DECONV_WINOGRAD_DEFAULT_TILE); + + // Add to dest + for (int uhi = 0; uhi < unit_size; uhi++) { + int h_index = uhi * conv_param->stride_h_ + h_start; + for (int uwi = 0; uwi < unit_size; uwi++) { + int w_index = uwi * conv_param->stride_w_ + w_start; + + float *dst = tile_out + w_index * DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_up_ + + h_index * deconv_param->out_tile_w_ * DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_up_; + float *src = tmp_buf + (uwi + uhi * unit_size) * DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_up_; + DeConvWgMerge(src, dst, tile_num, tile_num, DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_div_); + } + } +} + +void DeConvWgCalCommFp32(const float *tile_in, float *tile_out, const float *weight, float *tmp_buf, int h_start, + int w_start, int h_size, int w_size, const ConvParameter *conv_param, + const DeConvParam *deconv_param) { +#ifdef ENABLE_AVX + int tile_num = C8NUM; + TiledMatmulFp32 matmul_func = TiledC8MatmulFp32; +#else + TiledMatmulFp32 matmul_func = TiledC4MatmulFp32; + int tile_num = C4NUM; +#endif + int count = deconv_param->oc_div_ * w_size * h_size; + int in_stride = DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->ic_up_; + int out_stride = DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_up_; + + for (int hi = 0; hi < DECONV_WINOGRAD_DEFAULT_UNIT; hi++) { + for (int wi = 0; wi < DECONV_WINOGRAD_DEFAULT_UNIT; wi++) { + const float *src_in = tile_in + (wi + hi * DECONV_WINOGRAD_DEFAULT_UNIT) * in_stride; + matmul_func(tmp_buf, src_in, weight, DECONV_WINOGRAD_DEFAULT_TILE * tile_num, deconv_param->ic_div_, count); + + for (int uhi = 0; uhi < h_size; uhi++) { + for (int uwi = 0; uwi < w_size; uwi++) { + int w_index = (wi + uwi) * conv_param->stride_w_ + w_start; + int h_index = (hi + uhi) * conv_param->stride_h_ + h_start; + float *dst = tile_out + h_index * out_stride * deconv_param->out_tile_w_ + w_index * out_stride; + float *src = tmp_buf + (uwi + uhi * w_size) * out_stride; + DeConvWgMerge(src, dst, tile_num, tile_num, DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->oc_div_); + } + } + } + } +} + +int DeconvWg(const float *nhwc_input_, float *tile_in, float *tile_out, int start_index, int calculate_count, + const ConvParameter *conv_param, DeConvParam *deconv_param, int task_id) { + if (deconv_param->in_tile_w_count_ == 0) { + return NNACL_ERR; + } + /* pack tile input */ + int tile_in_unit_stride = deconv_param->ic_up_ * DECONV_WINOGRAD_DEFAULT_TILE; +#ifdef ENABLE_AVX + int tile_num = C8NUM; + MS_FLOAT32X8 zero = MS_MOV256_F32(0.0f); +#else + int tile_num = C4NUM; +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 zero = MS_MOVQ_F32(0.0f); +#endif +#endif + for (int unit_index = 0; unit_index < calculate_count; unit_index++) { + int plane_index = start_index + unit_index; + int w_unit_index = plane_index % deconv_param->in_tile_w_count_; + int h_unit_index = plane_index / deconv_param->in_tile_w_count_; + int w_start = w_unit_index * DECONV_WINOGRAD_DEFAULT_UNIT; + int h_start = h_unit_index * DECONV_WINOGRAD_DEFAULT_UNIT; + + float *dst_unit = tile_in + unit_index * tile_num; + for (int hi = 0; hi < DECONV_WINOGRAD_DEFAULT_UNIT; hi++) { + for (int wi = 0; wi < DECONV_WINOGRAD_DEFAULT_UNIT; wi++) { + float *dst = dst_unit + (wi + hi * DECONV_WINOGRAD_DEFAULT_UNIT) * tile_in_unit_stride; + int w_index = w_start + wi; + int h_index = h_start + hi; + if (w_index >= conv_param->input_w_ || h_index >= conv_param->input_h_) { + for (int ic4_index = 0; ic4_index < deconv_param->ic_div_; ic4_index++) { +#ifdef ENABLE_AVX + MS_ST256_F32(dst + ic4_index * DECONV_WINOGRAD_DEFAULT_TILE * tile_num, zero); +#elif defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_STQ_F32(dst + ic4_index * DECONV_WINOGRAD_DEFAULT_TILE * tile_num, zero); +#else + for (int i = 0; i < tile_num; i++) { + dst[tile_num * DECONV_WINOGRAD_DEFAULT_TILE * ic4_index + i] = 0; + } +#endif + } + continue; + } + + const float *src = nhwc_input_ + (w_index + h_index * conv_param->input_w_) * conv_param->input_channel_; + DeConvWgInputPack(src, dst, conv_param->input_channel_, DECONV_WINOGRAD_DEFAULT_TILE * tile_num); + } + } + } + + /* compute */ + bool transferred[DECONV_WINOGRAD_BUFFER_COUNT] = {false}; + for (int i = 0; i < deconv_param->compute_size_; i++) { + DeConvComputeUnit *unit = &deconv_param->compute_units_[i]; + if (unit->use_winograd_) { + float *tmp_buf = (float *)unit->tmp_buffer_ + task_id * unit->winograd_.kh_ * unit->winograd_.kw_ * + deconv_param->oc_div_ * DECONV_WINOGRAD_DEFAULT_TILE * tile_num; + + /* winograd a buffer */ + if (unit->winograd_.kh_ >= DECONV_WINOGRAD_BUFFER_COUNT || unit->winograd_.AT_ == NULL) { + return NNACL_ERR; + } + DeConvWgABuffer *wg_buf = &deconv_param->a_buffer_[unit->winograd_.kh_]; + float *wg_mid_a_buf = (float *)wg_buf->middle_buffer_ + task_id * unit->winograd_.kw_ * unit->winograd_.kh_ * + DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->ic_up_; + float *wg_dst_a_buf = (float *)wg_buf->dest_buffer_ + task_id * unit->winograd_.kw_ * unit->winograd_.kh_ * + DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->ic_up_; + float *tmp_b_buf = (float *)unit->winograd_.b_buffer_ + task_id * unit->winograd_.kh_ * unit->winograd_.kw_ * + deconv_param->oc_up_ * DECONV_WINOGRAD_DEFAULT_TILE; + DeConvWgCalWgFp32(tile_in, tile_out, (float *)unit->weight_, tmp_buf, unit->winograd_.AT_, wg_mid_a_buf, + wg_dst_a_buf, transferred, unit->winograd_.BT_, tmp_b_buf, unit->winograd_.kh_, unit->w_start_, + unit->h_start_, conv_param, deconv_param); + } else { + float *tmp_buf = (float *)unit->tmp_buffer_ + task_id * deconv_param->oc_div_ * unit->w_size_ * unit->h_size_ * + DECONV_WINOGRAD_DEFAULT_TILE * tile_num; + DeConvWgCalCommFp32(tile_in, tile_out, (float *)unit->weight_, tmp_buf, unit->h_start_, unit->w_start_, + unit->h_size_, unit->w_size_, conv_param, deconv_param); + } + } + return NNACL_OK; +} + +int DeconvWgPost(const float *tile_out, float *nc4hw4_output, const ConvParameter *conv_param, + const DeConvParam *deconv_param, int calculate_count, int tile_index) { +#ifdef ENABLE_AVX + int tile_num = C8NUM; +#else + int tile_num = C4NUM; +#endif + + /* merge */ + int src_unit_stride = deconv_param->oc_up_ * DECONV_WINOGRAD_DEFAULT_TILE; + + int src_stride = DECONV_WINOGRAD_DEFAULT_TILE * tile_num; + int dst_stride = conv_param->output_w_ * conv_param->output_h_ * tile_num; + + for (int index = 0; index < calculate_count; ++index) { + const float *src_start = tile_out + index * tile_num; + + int plane_index = tile_index * DECONV_WINOGRAD_DEFAULT_TILE + index; + int w_unit_index = plane_index % deconv_param->in_tile_w_count_; + int h_unit_index = plane_index / deconv_param->in_tile_w_count_; + int w_start = w_unit_index * DECONV_WINOGRAD_DEFAULT_UNIT * conv_param->stride_w_ - conv_param->pad_l_; + int h_start = h_unit_index * DECONV_WINOGRAD_DEFAULT_UNIT * conv_param->stride_h_ - conv_param->pad_u_; + float *dst_start = nc4hw4_output + h_start * conv_param->output_w_ * tile_num + w_start * tile_num; + + int merge_w_start = MSMAX(-w_start, 0); + int merge_h_start = MSMAX(-h_start, 0); + int merge_h_end = MSMIN(deconv_param->out_tile_h_, conv_param->output_h_ - h_start); + int merge_w_end = MSMIN(deconv_param->out_tile_w_, conv_param->output_w_ - w_start); + + for (int hi = merge_h_start; hi < merge_h_end; hi++) { + for (int wi = merge_w_start; wi < merge_w_end; wi++) { + const float *src = src_start + (hi * deconv_param->out_tile_w_ + wi) * src_unit_stride; + float *dst = dst_start + (hi * conv_param->output_w_ + wi) * tile_num; + DeConvWgMerge(src, dst, src_stride, dst_stride, deconv_param->oc_div_); + } + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/deconv_winograd_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/deconv_winograd_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..3492679063d3018d47c14ef6860976660334288b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/deconv_winograd_fp32.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_DECONV_WINOGRAD_H_ +#define MINDSPORE_NNACL_FP32_DECONV_WINOGRAD_H_ + +#include +#include "nnacl/pack.h" +#include "nnacl/op_base.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/errorcode.h" +#include "nnacl/fp32/common_func_fp32.h" +#include "nnacl/base/minimal_filtering_generator.h" + +#ifdef __cplusplus +extern "C" { +#endif +typedef void (*TiledMatmulFp32)(float *dst, const float *src, const float *weight, size_t ic_tiled, size_t cal_num, + size_t oc_tiled); + +int PackDeConvWgDataFp32(const float *nhwc_weight, DeConvComputeUnit *unit, const ConvParameter *conv_param, + const DeConvParam *deconv_param); +int DeconvWg(const float *nhwc_input_, float *tile_in, float *tile_out, int start_index, int calculate_count, + const ConvParameter *conv_param, DeConvParam *deconv_param, int task_id); +int DeconvWgPost(const float *tile_out, float *nc4hw4_output, const ConvParameter *conv_param, + const DeConvParam *deconv_param, int calculate_count, int tile_index); +void TiledC4MatmulFp32(float *dst, const float *src, const float *weight, size_t ic4, size_t cal_num, size_t oc4); +void TiledC8MatmulFp32(float *dst, const float *src, const float *weight, size_t ic8, size_t cal_num, size_t oc8); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_DECONV_WINOGRAD_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/detection_post_process_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/detection_post_process_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..b9da2ccc4eedb0315742c48f689cf2800499b442 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/detection_post_process_fp32.c @@ -0,0 +1,235 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/detection_post_process_fp32.h" +#include +#include "nnacl/errorcode.h" +#include "nnacl/op_base.h" +#include "nnacl/nnacl_utils.h" + +float IntersectionOverUnion(const BboxCorner *a, const BboxCorner *b) { + const float area_a = (a->ymax - a->ymin) * (a->xmax - a->xmin); + const float area_b = (b->ymax - b->ymin) * (b->xmax - b->xmin); + if (area_a <= 0 || area_b <= 0) { + return 0.0f; + } + const float ymin = a->ymin > b->ymin ? a->ymin : b->ymin; + const float xmin = a->xmin > b->xmin ? a->xmin : b->xmin; + const float ymax = a->ymax < b->ymax ? a->ymax : b->ymax; + const float xmax = a->xmax < b->xmax ? a->xmax : b->xmax; + const float h = ymax - ymin > 0.0f ? ymax - ymin : 0.0f; + const float w = xmax - xmin > 0.0f ? xmax - xmin : 0.0f; + const float inter = h * w; + return inter / (area_a + area_b - inter); +} + +int DecodeBoxes(int num_boxes, const float *input_boxes, const float *anchors, + const DetectionPostProcessParameter *param) { + if (input_boxes == NULL || anchors == NULL || param == NULL) { + return NNACL_NULL_PTR; + } + float *decoded_boxes = (float *)param->decoded_boxes_; + BboxCenter scaler; + scaler.y = param->y_scale_; + scaler.x = param->x_scale_; + scaler.h = param->h_scale_; + scaler.w = param->w_scale_; + for (int i = 0; i < num_boxes; ++i) { + BboxCenter *box = (BboxCenter *)(input_boxes) + i; + BboxCenter *anchor = (BboxCenter *)(anchors) + i; + BboxCorner *decoded_box = (BboxCorner *)(decoded_boxes) + i; + float y_center = box->y / scaler.y * anchor->h + anchor->y; + float x_center = box->x / scaler.x * anchor->w + anchor->x; + const float h_half = 0.5f * expf(box->h / scaler.h) * anchor->h; + const float w_half = 0.5f * expf(box->w / scaler.w) * anchor->w; + decoded_box->ymin = y_center - h_half; + decoded_box->xmin = x_center - w_half; + decoded_box->ymax = y_center + h_half; + decoded_box->xmax = x_center + w_half; + } + return NNACL_OK; +} + +int NmsSingleClass(const int num_boxes, const float *decoded_boxes, const int max_detections, const float *scores, + int32_t *selected, void (*PartialArgSort)(const float *, int32_t *, int, int), + const DetectionPostProcessParameter *param) { + if (PartialArgSort == NULL) { + return NNACL_NULL_PTR; + } + uint8_t *nms_candidate = param->nms_candidate_; + const int output_num = num_boxes < max_detections ? num_boxes : max_detections; + int possible_candidate_num = num_boxes; + int selected_num = 0; + int32_t *indexes = (int32_t *)param->single_class_indexes_; + for (int i = 0; i < num_boxes; ++i) { + indexes[i] = i; + nms_candidate[i] = 1; + } + PartialArgSort(scores, indexes, num_boxes, num_boxes); + for (int i = 0; i < num_boxes; ++i) { + if (possible_candidate_num == 0 || selected_num >= output_num || scores[indexes[i]] < param->nms_score_threshold_) { + break; + } + if (nms_candidate[indexes[i]] == 0) { + continue; + } + selected[selected_num++] = indexes[i]; + nms_candidate[indexes[i]] = 0; + possible_candidate_num--; + const BboxCorner *bbox_i = (BboxCorner *)(decoded_boxes) + indexes[i]; + for (int t = i + 1; t < num_boxes; ++t) { + if (scores[indexes[t]] < param->nms_score_threshold_) break; + if (nms_candidate[indexes[t]] == 1) { + const BboxCorner *bbox_t = (BboxCorner *)(decoded_boxes) + indexes[t]; + const float iou = IntersectionOverUnion(bbox_i, bbox_t); + if (iou > param->nms_iou_threshold_) { + nms_candidate[indexes[t]] = 0; + possible_candidate_num--; + } + } + } + } + return selected_num; +} + +int NmsMultiClassesFastCore(const int num_boxes, const int num_classes_with_bg, const float *input_scores, + void (*PartialArgSort)(const float *, int32_t *, int, int), + const DetectionPostProcessParameter *param, const int task_id, const int thread_num) { + if (input_scores == NULL || param == NULL || PartialArgSort == NULL) { + return NNACL_NULL_PTR; + } + if (thread_num == 0) { + return NNACL_PARAM_INVALID; + } + const int first_class_index = num_classes_with_bg - (int)(param->num_classes_); + const int64_t max_classes_per_anchor = + param->max_classes_per_detection_ < param->num_classes_ ? param->max_classes_per_detection_ : param->num_classes_; + float *scores = (float *)param->scores_; + for (int i = task_id; i < num_boxes; i += thread_num) { + int32_t *indexes = (int32_t *)param->indexes_ + i * param->num_classes_; + for (int j = 0; j < param->num_classes_; ++j) { + indexes[j] = i * num_classes_with_bg + first_class_index + j; + } + PartialArgSort(input_scores, indexes, max_classes_per_anchor, param->num_classes_); + scores[i] = input_scores[indexes[0]]; + } + return NNACL_OK; +} + +int DetectionPostProcessFast(const int num_boxes, const int num_classes_with_bg, const float *input_scores, + const float *decoded_boxes, float *output_boxes, float *output_classes, + float *output_scores, float *output_num, + void (*PartialArgSort)(const float *, int32_t *, int, int), + const DetectionPostProcessParameter *param) { + if (input_scores == NULL || decoded_boxes == NULL || output_boxes == NULL || output_classes == NULL || + output_scores == NULL || output_num == NULL || param == NULL || PartialArgSort == NULL) { + return NNACL_NULL_PTR; + } + int out_num = 0; + const int first_class_index = num_classes_with_bg - (int)(param->num_classes_); + const int64_t max_classes_per_anchor = + param->max_classes_per_detection_ < param->num_classes_ ? param->max_classes_per_detection_ : param->num_classes_; + int32_t *selected = (int32_t *)param->selected_; + int selected_num = NmsSingleClass(num_boxes, decoded_boxes, param->max_detections_, (float *)param->scores_, selected, + PartialArgSort, param); + for (int i = 0; i < selected_num; ++i) { + int32_t *indexes = (int32_t *)param->indexes_ + selected[i] * param->num_classes_; + BboxCorner *box = (BboxCorner *)(decoded_boxes) + selected[i]; + for (int j = 0; j < max_classes_per_anchor; ++j) { + *((BboxCorner *)(output_boxes) + out_num) = *box; + output_scores[out_num] = input_scores[indexes[j]]; + NNACL_ASSERT(num_classes_with_bg != 0); + output_classes[out_num++] = (float)(indexes[j] % num_classes_with_bg - first_class_index); + } + } + *output_num = (float)out_num; + for (int i = out_num; i < param->max_detections_ * param->max_classes_per_detection_; ++i) { + ((BboxCorner *)(output_boxes) + i)->ymin = 0; + ((BboxCorner *)(output_boxes) + i)->xmin = 0; + ((BboxCorner *)(output_boxes) + i)->ymax = 0; + ((BboxCorner *)(output_boxes) + i)->xmax = 0; + output_scores[i] = 0; + output_classes[i] = 0; + } + return NNACL_OK; +} + +int DetectionPostProcessRegular(const int num_boxes, const int num_classes_with_bg, const float *input_scores, + float *output_boxes, float *output_classes, float *output_scores, float *output_num, + void (*PartialArgSort)(const float *, int32_t *, int, int), + const DetectionPostProcessParameter *param) { + if (input_scores == NULL || output_boxes == NULL || output_classes == NULL || output_scores == NULL || + output_num == NULL || param == NULL || PartialArgSort == NULL) { + return NNACL_NULL_PTR; + } + const int first_class_index = num_classes_with_bg - (int)(param->num_classes_); + float *decoded_boxes = (float *)param->decoded_boxes_; + int32_t *selected = (int32_t *)param->selected_; + float *scores = (float *)param->scores_; + float *all_scores = (float *)param->all_class_scores_; + int32_t *indexes = (int32_t *)(param->indexes_); + int32_t *all_indexes = (int32_t *)(param->all_class_indexes_); + int all_classes_sorted_num = 0; + int all_classes_output_num = 0; + for (int j = first_class_index; j < num_classes_with_bg; ++j) { + // process single class + for (int i = 0; i < num_boxes; ++i) { + scores[i] = input_scores[i * num_classes_with_bg + j]; + } + int selected_num = + NmsSingleClass(num_boxes, decoded_boxes, param->detections_per_class_, scores, selected, PartialArgSort, param); + for (int i = 0; i < all_classes_sorted_num; ++i) { + indexes[i] = all_indexes[i]; + all_indexes[i] = i; + } + // process all classes + for (int i = 0; i < selected_num; ++i) { + indexes[all_classes_sorted_num] = selected[i] * num_classes_with_bg + j; + all_indexes[all_classes_sorted_num] = all_classes_sorted_num; + all_scores[all_classes_sorted_num++] = scores[selected[i]]; + } + all_classes_output_num = + all_classes_sorted_num < param->max_detections_ ? all_classes_sorted_num : param->max_detections_; + PartialArgSort(all_scores, all_indexes, all_classes_output_num, all_classes_sorted_num); + for (int i = 0; i < all_classes_output_num; ++i) { + scores[i] = all_scores[all_indexes[i]]; + all_indexes[i] = indexes[all_indexes[i]]; + } + for (int i = 0; i < all_classes_output_num; ++i) { + all_scores[i] = scores[i]; + } + all_classes_sorted_num = all_classes_output_num; + } + for (int i = 0; i < param->max_detections_ * param->max_classes_per_detection_; ++i) { + if (i < all_classes_output_num) { + NNACL_CHECK_ZERO_RETURN_ERR(num_classes_with_bg); + const int box_index = all_indexes[i] / num_classes_with_bg; + const int class_index = all_indexes[i] % num_classes_with_bg - first_class_index; + *((BboxCorner *)(output_boxes) + i) = *((BboxCorner *)(decoded_boxes) + box_index); + output_classes[i] = (float)class_index; + output_scores[i] = all_scores[i]; + } else { + ((BboxCorner *)(output_boxes) + i)->ymin = 0; + ((BboxCorner *)(output_boxes) + i)->xmin = 0; + ((BboxCorner *)(output_boxes) + i)->ymax = 0; + ((BboxCorner *)(output_boxes) + i)->xmax = 0; + output_classes[i] = 0.0f; + output_scores[i] = 0.0f; + } + } + *output_num = (float)all_classes_output_num; + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/detection_post_process_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/detection_post_process_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..a3998768a1500f95840754f876791fb966333ace --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/detection_post_process_fp32.h @@ -0,0 +1,60 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_DETECTION_POST_PROCESS_H_ +#define MINDSPORE_NNACL_FP32_DETECTION_POST_PROCESS_H_ + +#include "nnacl/op_base.h" +#include "nnacl/detection_post_process_parameter.h" + +typedef struct { + float y; + float x; + float h; + float w; +} BboxCenter; + +typedef struct { + float ymin; + float xmin; + float ymax; + float xmax; +} BboxCorner; + +#ifdef __cplusplus +extern "C" { +#endif +int DecodeBoxes(int num_boxes, const float *input_boxes, const float *anchors, + const DetectionPostProcessParameter *param); + +int NmsMultiClassesFastCore(const int num_boxes, const int num_classes_with_bg, const float *input_scores, + void (*)(const float *, int32_t *, int, int), const DetectionPostProcessParameter *param, + const int task_id, const int thread_num); + +int DetectionPostProcessFast(const int num_boxes, const int num_classes_with_bg, const float *input_scores, + const float *decoded_boxes, float *output_boxes, float *output_classes, + float *output_scores, float *output_num, void (*)(const float *, int32_t *, int, int), + const DetectionPostProcessParameter *param); + +int DetectionPostProcessRegular(const int num_boxes, const int num_classes_with_bg, const float *input_scores, + float *output_boxes, float *output_classes, float *output_scores, float *output_num, + void (*)(const float *, int32_t *, int, int), + const DetectionPostProcessParameter *param); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_DETECTION_POST_PROCESS_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/div_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/div_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..bfe2aea2dd4be8043c889b254efe48f4e149b0f9 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/div_fp32.c @@ -0,0 +1,136 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/div_fp32.h" +#include +#include "nnacl/fp32/arithmetic_fp32.h" +#include "nnacl/div_fp32_simd.h" + +int ElementOptDiv(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptDivNum0, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[0] / in1[index]; + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptDivNum1, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[index] / in1[0]; + } + } + return NNACL_OK; +} + +int ElementOptDivRelu(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptDivReluNum0, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[0] / in1[index]; + out[index] = out[index] > 0 ? out[index] : 0; + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptDivReluNum1, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[index] / in1[0]; + out[index] = out[index] > 0 ? out[index] : 0; + } + } + return NNACL_OK; +} + +int ElementOptDivRelu6(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptDivRelu6Num0, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[0] / in1[index], RELU6_MIN_VAL), RELU6_MAX_VAL); + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptDivRelu6Num1, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[index] / in1[0], RELU6_MIN_VAL), RELU6_MAX_VAL); + } + } + return NNACL_OK; +} + +int ElementOptDivInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar) { + int index = 0; + + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptDivIntNum0, index, in0, in1, out, size); + + for (; index < size; index++) { + NNACL_CHECK_ZERO_RETURN_ERR(in1[index] != 0); + out[index] = in0[0] / in1[index]; + } + } else { + NNACL_CHECK_ZERO_RETURN_ERR(in1[0] != 0); + + SIMD_RUN_NO_SCALAR(ElementOptDivIntNum1, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[index] / in1[0]; + } + } + return NNACL_OK; +} + +int BroadcastDiv(const float *in0, const float *in1, float *tile_in0, float *tile_in1, float *out, int size, + ArithmeticParameter *param) { + TileDimensionsFp32(in0, in1, tile_in0, tile_in1, param); + return ElementDiv(tile_in0, tile_in1, out, size); +} + +int ElementDiv(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementDiv, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[index] / in1[index]; + } + return NNACL_OK; +} + +int ElementDivRelu(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementDivRelu, index, in0, in1, out, size); + for (; index < size; index++) { + float res = in0[index] / in1[index]; + out[index] = res > 0 ? res : 0; + } + return NNACL_OK; +} + +int ElementDivRelu6(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementDivRelu6, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[index] / in1[index], RELU6_MIN_VAL), RELU6_MAX_VAL); + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/div_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/div_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..3b1228a891c479500eb02463fcf8d658e05f64a5 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/div_fp32.h @@ -0,0 +1,43 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_DIV_H_ +#define MINDSPORE_NNACL_FP32_DIV_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "nnacl/op_base.h" +#include "nnacl/base/arithmetic_base.h" +#include "nnacl/errorcode.h" + +#ifdef __cplusplus +extern "C" { +#endif +int ElementDiv(const float *in0, const float *in1, float *out, int size); +int ElementDivRelu(const float *in0, const float *in1, float *out, int size); +int ElementDivRelu6(const float *in0, const float *in1, float *out, int size); +int ElementOptDiv(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementOptDivRelu(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementOptDivRelu6(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementOptDivInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar); +int BroadcastDiv(const float *in0, const float *in1, float *tile_in0, float *tile_in1, float *out, int size, + ArithmeticParameter *param); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_DIV_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/div_fp32_simd.h.in b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/div_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..1968d13dde19c6278e6dfe36d3791e1ba1695cba --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/div_fp32_simd.h.in @@ -0,0 +1,160 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_DIV_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_DIV_@SIMD_INSTRUCTION@_H_ + +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int ElementOptDivNum0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + SIMD_F32 vin0_opt = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_DIV_F32(vin0_opt, vin1); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptDivNum1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + SIMD_F32 vin1_opt_ = SIMD_MOV_F32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vout = SIMD_DIV_F32(vin0, vin1_opt_); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptDivIntNum0@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 vin0_opt = SIMD_MOV_EPI32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin1 = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 vout = SIMD_DIV_EPI32(vin0_opt, vin1); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +static inline int ElementOptDivIntNum1@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 vin1_opt_ = SIMD_MOV_EPI32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin0 = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 vout = SIMD_DIV_EPI32(vin0, vin1_opt_); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +static inline int ElementOptDivReluNum0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + SIMD_F32 vin0_opt = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MAX_N_F32(SIMD_DIV_F32(vin0_opt, vin1), 0.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptDivReluNum1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + SIMD_F32 vin1_opt_ = SIMD_MOV_F32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vout = SIMD_MAX_N_F32(SIMD_DIV_F32(vin0, vin1_opt_), 0.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptDivRelu6Num0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + SIMD_F32 vin0_opt = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MIN_N_F32(SIMD_MAX_N_F32(SIMD_DIV_F32(vin0_opt, vin1), 0.0f), 6.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptDivRelu6Num1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + SIMD_F32 vin1_opt_ = SIMD_MOV_F32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vout = SIMD_MIN_N_F32(SIMD_MAX_N_F32(SIMD_DIV_F32(vin0, vin1_opt_), 0.0f), 6.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementDiv@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_DIV_F32(vin0, vin1); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementDivInt@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin0 = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 vin1 = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 vout = SIMD_DIV_EPI32(vin0, vin1); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +static inline int ElementDivRelu@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MAX_N_F32(SIMD_DIV_F32(vin0, vin1), 0.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementDivRelu6@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MIN_N_F32(SIMD_MAX_N_F32(SIMD_DIV_F32(vin0, vin1), 0.0f), 6.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +}; +#endif +#endif diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op_registration_factory.cc b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/dropout_fp32.c similarity index 67% rename from mindspore-lite/src/extendrt/delegate/tensorrt/op_registration_factory.cc rename to mindspore-lite/ops/kernel/cpu/nnacl/fp32/dropout_fp32.c index e23fd66b2f210eabeed878d8d2a6fbb983cbb934..08a53ab6620c5e1101bec0fdff47e9635558bc43 100644 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op_registration_factory.cc +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/dropout_fp32.c @@ -13,13 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "extendrt/delegate/tensorrt/op_registration_factory.h" -#include "extendrt/delegate/tensorrt/op/tensorrt_op.h" +#include +#include "nnacl/fp32/dropout_fp32.h" +#include "nnacl/dropout_fp32_simd.h" -namespace mindspore::lite { -template <> -TensorRTRegistrationFactory &TensorRTRegistrationFactory::Get() { - static TensorRTRegistrationFactory obj; - return obj; +void DropoutFp32(const float *input, float scale, int length, float *output) { + int i = 0; + + SIMD_RUN_NO_SCALAR(DropoutFp32, i, input, scale, length, output); + + for (; i < length; ++i) { + output[i] = scale * input[i]; + } } -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/delegate_utils.cc b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/dropout_fp32.h similarity index 61% rename from mindspore-lite/src/extendrt/delegate/delegate_utils.cc rename to mindspore-lite/ops/kernel/cpu/nnacl/fp32/dropout_fp32.h index 25ad349ce441b76813eb0d4ee150349930d4a25c..57f27268289cab5d0a1c087a103abe724e879609 100644 --- a/mindspore-lite/src/extendrt/delegate/delegate_utils.cc +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/dropout_fp32.h @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright 2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,11 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#ifndef MINDSPORE_NNACL_FP32_DROPOUT_FP32_H_ +#define MINDSPORE_NNACL_FP32_DROPOUT_FP32_H_ -#include "src/extendrt/delegate/delegate_utils.h" -#include "nnacl/fp32/pack_fp32.h" -namespace mindspore::lite { -bool IsSubGraphInputTensor(const std::vector &inputs, const TensorInfo &input) { - return std::find(inputs.begin(), inputs.end(), input) != inputs.end(); +#include "nnacl/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +void DropoutFp32(const float *input, float scale, int length, float *output); +#ifdef __cplusplus } -} // namespace mindspore::lite +#endif +#endif // MINDSPORE_NNACL_FP32_DROPOUT_FP32_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/dropout_fp32_simd.h.in b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/dropout_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..48852743cb9f2deab5e1d325bcbca4ef7d3fc34d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/dropout_fp32_simd.h.in @@ -0,0 +1,39 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_DROPOUTFP32_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_DROPOUTFP32_@SIMD_INSTRUCTION@_H_ + +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int DropoutFp32@SIMD_INSTRUCTION@(int index, const float *input, float scale, + int length, float *output) { + SIMD_F32 scale_value = SIMD_MOV_F32(scale); + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_ST_F32(output + index, SIMD_MUL_F32(SIMD_LD_F32(input + index), scale_value)); + } + return index; +} +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/embedding_lookup_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/embedding_lookup_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..0ba41c0472a6ee17875a560994180fa0e86e1273 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/embedding_lookup_fp32.c @@ -0,0 +1,62 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/embedding_lookup_fp32.h" +#include +#include "nnacl/errorcode.h" + +void l2_regulate(float *data, int size, float max_norm) { + float sum = 0; + for (int i = 0; i < size; ++i) { + sum += data[i]; + } + if (sum != 0) { + for (int i = 0; i < size; ++i) { + data[i] *= max_norm / sum; + } + } + return; +} + +int CopyData(float *input_data, const int32_t *ids, float *output_data, int num, + const EmbeddingLookupParameter *parameter) { + if (ids[num] >= parameter->layer_num_ || ids[num] < 0) { + return NNACL_ERRCODE_INDEX_OUT_OF_RANGE; + } + float *out_data = output_data + num * parameter->layer_size_; + float *in_data = input_data + ids[num] * parameter->layer_size_; + if (!parameter->is_regulated_[ids[num]]) { + l2_regulate(in_data, parameter->layer_size_, parameter->max_norm_); + parameter->is_regulated_[ids[num]] = true; + } + + memcpy(out_data, in_data, sizeof(float) * (size_t)(parameter->layer_size_)); + return NNACL_OK; +} + +int EmbeddingLookup(float *input_data, const int32_t *ids, float *output_data, + const EmbeddingLookupParameter *parameter, int task_id) { + if (parameter->op_parameter_.thread_num_ == 0) { + return NNACL_PARAM_INVALID; + } + for (int i = task_id; i < parameter->ids_size_; i += parameter->op_parameter_.thread_num_) { + int ret = CopyData(input_data, ids, output_data, i, parameter); + if (ret != NNACL_OK) { + return ret; + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/embedding_lookup_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/embedding_lookup_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..6bdef14620ce92554e1dea48b7dd37402643040f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/embedding_lookup_fp32.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_EMBEDDING_LOOKUP_H_ +#define MINDSPORE_NNACL_FP32_EMBEDDING_LOOKUP_H_ + +#include "nnacl/op_base.h" + +typedef struct EmbeddingLookupParameter { + OpParameter op_parameter_; + // primitive parameter + float max_norm_; + + // shape correlative + bool *is_regulated_; + int ids_size_; + int layer_size_; + int layer_num_; +} EmbeddingLookupParameter; + +#ifdef __cplusplus +extern "C" { +#endif +int EmbeddingLookup(float *input_data, const int32_t *ids, float *output_data, + const EmbeddingLookupParameter *parameter, int task_id); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_EMBEDDING_LOOKUP_H_ diff --git a/mindspore-lite/tools/graph_kernel/common/utils.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/exp_fp32.c similarity index 30% rename from mindspore-lite/tools/graph_kernel/common/utils.h rename to mindspore-lite/ops/kernel/cpu/nnacl/fp32/exp_fp32.c index ce7afd0d92ff6318c80f80462211338040dc7a6b..322bcba359c8f557c2f7334fbd69607478d23943 100644 --- a/mindspore-lite/tools/graph_kernel/common/utils.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/exp_fp32.c @@ -1,5 +1,5 @@ /** - * Copyright 2022 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,30 +14,49 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_COMMON_UTILS_H_ -#define MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_COMMON_UTILS_H_ -#include -#include -#include -#include "nnacl/tensor_c.h" -#include "common/kernel_build_info.h" -#include "include/backend/kernel_info.h" +#include "nnacl/fp32/exp_fp32.h" +#include "nnacl/exp_fp32_simd.h" +#include +#include +#include "nnacl/errorcode.h" -constexpr auto kAkgKernelSo = "akgkernels.so"; -namespace mindspore::graphkernel { -std::vector SplitString(const std::string &raw_str, char delimiter); +void ExpFp32(const float *src, float *dst, int num) { + int i = 0; -int GetCustomShape(const std::string &attr, std::vector> *shapes); + SIMD_RUN_NO_SCALAR(ExpFp32, i, src, dst, num); + for (; i < num; ++i) { + simd_exp32(src[i], dst + i); + } +} -int CalculateDynamicBatchSize(const TensorC *const *inputs, size_t inputs_size, - const std::vector> &shapes, const std::vector &index, - int *batch); -void GetCustomIndex(const std::string &dynamic_input_index, std::vector *index); -int GetCustomShape(const std::string &attr, std::vector> *shapes); -void SetKernelInfoWithFormatToAnfNode(const AnfNodePtr &node, const std::vector &format); -kernel::KernelBuildInfoPtr GetKernelInfo(const AnfNodePtr &node); -void SetAnfKernelInfoFormatFromAToB(const AnfNodePtr &node_a, const CNodePtr &node_b, - const std::vector &formats); -std::string GetOutputFormatFromAnfNode(const AnfNodePtr &node, size_t output_idx); -} // namespace mindspore::graphkernel -#endif // MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_COMMON_UTILS_H_ +int ExpFusionFp32(const void *src_data, void *dst_data, const ExpStruct *exp, int task_id) { + NNACL_CHECK_ZERO_RETURN_ERR(exp->base_.thread_nr_); + ExpParameter *param = (ExpParameter *)exp->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + const float *src = (const float *)src_data; + float *dst = (float *)dst_data; + + int stride = UP_DIV(exp->element_num_, exp->base_.thread_nr_); + int start = stride * task_id; + int end = MSMIN(exp->element_num_, start + stride); + int num = end - start; + + if (param->scale_ == 1) { + ExpFp32(src + start, dst + start, num); + } else { + int i = 0; + SIMD_RUN_NO_SCALAR(ExpFp32WithInScale, i, src, dst, num, exp->in_scale_); + for (; i < num; ++i) { + simd_exp32(src[i] * exp->in_scale_, dst + i); + } + } + if (exp->out_scale_ != 1) { + int i = 0; + SIMD_RUN_NO_SCALAR(ExpFp32WithOutScale, i, src, dst, num, exp->out_scale_); + for (; i < num; ++i) { + simd_exp32(src[i], dst + i); + dst[i] *= exp->out_scale_; + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/exp_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/exp_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..d6ae2279f42eef90377be703342f1f7d5aa0ad41 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/exp_fp32.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_EXP_H_ +#define MINDSPORE_NNACL_FP32_EXP_H_ + +#include "nnacl/op_base.h" +#include "nnacl/kernel/exp.h" +#include "nnacl/exp_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void ExpFp32(const float *src, float *dst, int num); +int ExpFusionFp32(const void *src_data, void *dst_data, const ExpStruct *exp, int task_id); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_EXP_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/exp_fp32_simd.h.in b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/exp_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..552e2bc031878cccc17c53ce58dea3e25bb650b7 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/exp_fp32_simd.h.in @@ -0,0 +1,56 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_DIV_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_DIV_@SIMD_INSTRUCTION@_H_ + +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int64_t ExpFp32@SIMD_INSTRUCTION@(int64_t index, const float *src, float *dst, int num) { + for (int block_max_size = num - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EXP_ST_F32(SIMD_LD_F32(src + index), dst + index); + } + return index; +} + +static inline int64_t ExpFp32WithInScale@SIMD_INSTRUCTION@(int64_t index, const float *src, float *dst, int num, float in_scale) { + SIMD_F32 scale_vec = SIMD_MOV_F32(in_scale); + for (int block_max_size = num - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EXP_ST_F32(SIMD_MUL_F32(SIMD_LD_F32(src + index), scale_vec), dst + index); + } + return index; +} + +static inline int64_t ExpFp32WithOutScale@SIMD_INSTRUCTION@(int64_t index, const float *src, float *dst, int num, float out_scale) { + SIMD_F32 scale_vec = SIMD_MOV_F32(out_scale); + for (int block_max_size = num - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EXP_ST_F32(SIMD_LD_F32(src + index), dst + index); + SIMD_ST_F32(dst + index, SIMD_MUL_F32(SIMD_LD_F32(dst + index), scale_vec)); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +}; +#endif +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/gatherNd_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/gatherNd_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..7731c19970ab4ac9a7e42f6b931d618f4935dda0 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/gatherNd_fp32.c @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/gatherNd_fp32.h" +#include +#include "nnacl/errorcode.h" + +int GatherNd(const void *input, void *output, const int32_t *in_offset, int area, int count, int data_type_len) { + int i = 0; + for (i = 0; i < count; i++) { + (void)memcpy((int8_t *)output + area * i * data_type_len, (int8_t *)input + in_offset[i] * data_type_len, + (size_t)(area)*data_type_len); + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/gatherNd_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/gatherNd_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..5f3789e1e68412e4f06a4ecb2de4aa3f02b8a1cc --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/gatherNd_fp32.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_GATHERND_FP32_H_ +#define NNACL_FP32_GATHERND_FP32_H_ + +#include "nnacl/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +int GatherNd(const void *input, void *output, const int32_t *in_offset, int area, int count, int data_type_len); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_GATHERND_FP32_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/group_norm_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/group_norm_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..20d4f46d654819dace2865c11f079d47189fc197 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/group_norm_fp32.c @@ -0,0 +1,125 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/group_norm_fp32.h" +#include +#include "nnacl/group_norm_parameter.h" +#include "nnacl/op_base.h" +#include "nnacl/errorcode.h" +#include "nnacl/group_norm_fp32_simd.h" + +static void GroupNormFp32MeanVar(const float *input, float *run_mean, float *run_var, int completed_group, + int cur_groups, const GroupNormParameter *param); + +int GroupNormFp32(const float *input, const float *scale, const float *offset, float *mean, float *variance, + const GroupNormParameter *param, int task_id, float *output) { + if (param->op_parameter_.thread_num_ == 0) { + return NNACL_ERR; + } + const int frame_elem_num = param->unit_ * param->channel_; + const int groups_per_thread = UP_DIV(param->num_groups_, param->op_parameter_.thread_num_); + const int completed_group = task_id * groups_per_thread; + const int cur_group = MSMIN(groups_per_thread, param->num_groups_ - completed_group); + const int num_of_ch_per_group = param->channel_ / param->num_groups_; + int cur_offset = completed_group * num_of_ch_per_group * param->unit_; + + for (int b = 0; b < param->batch_; b++) { + const float *b_in = input + b * frame_elem_num; + float *b_out = output + b * frame_elem_num; + int b_offset = cur_offset; + GroupNormFp32MeanVar(b_in, mean, variance, completed_group, cur_group, param); + for (int g = 0; g < cur_group; g++) { + int grp_idx = g + completed_group; + int c_offset = grp_idx * num_of_ch_per_group; + float m = mean[grp_idx]; + float v = variance[grp_idx]; + float variance_sqrt = sqrtf(v + param->epsilon_); + if (variance_sqrt == 0) { + return NNACL_ERR; + } + for (int c = 0; c < num_of_ch_per_group; c++) { + const float *unit_input = b_in + b_offset; + float *unit_output = b_out + b_offset; + float s = scale[c_offset + c]; + float o = offset[c_offset + c]; + int u = 0; + SIMD_RUN_NO_SCALAR(GroupNormFp32, u, unit_input, s, o, m, variance_sqrt, param->unit_, unit_output); + for (; u < param->unit_; u++) { + float norm_val = (unit_input[u] - m) / variance_sqrt; + unit_output[u] = norm_val * s + o; + } + b_offset += param->unit_; + } + } + } + return NNACL_OK; +} + +#define SimdReduceSum(block_size, block_num, in, i, sum) \ + do { \ + for (int block_max_size = param->unit_ - block_num + 1; i < block_max_size; i += block_num) { \ + MS_FLOAT_32xN(block_num) input = MS_LD_F32(block_size, in + i); \ + sum += MS_GET_SUM_F32(block_size, input); \ + } \ + } while (0) + +#define SimdReduceVar(block_size, block_num, in, m, i, sum) \ + do { \ + MS_FLOAT_32xN(block_num) mean = MS_MOVN_F32(block_size, m); \ + MS_FLOAT_32xN(block_num) tmp = MS_MOVN_F32(block_size, 0); \ + for (int block_max_size = param->unit_ - block_num + 1; i < block_max_size; i += block_num) { \ + MS_FLOAT_32xN(block_num) input = MS_SUB_F32(block_size, MS_LD_F32(block_size, in + i), mean); \ + tmp = MS_ADD_F32(block_size, tmp, MS_MUL_F32(block_size, input, input)); \ + } \ + sum += MS_GET_SUM_F32(block_size, tmp); \ + } while (0) + +static void GroupNormFp32MeanVar(const float *input, float *run_mean, float *run_var, int completed_group, + int cur_groups, const GroupNormParameter *param) { + const int num_of_ch_per_group = param->channel_ / param->num_groups_; + const float N = (float)(param->unit_ * num_of_ch_per_group); + + // calc mean + for (int g = 0; g < cur_groups; g++) { + int g_idx = g + completed_group; + float sum = 0; + for (int c = 0; c < num_of_ch_per_group; c++) { + const float *in = input + (num_of_ch_per_group * g_idx + c) * param->unit_; + int i = 0; + SIMD_RUN_NO_SCALAR(GroupNormReduceSum, i, in, &sum, param->unit_); + for (; i < param->unit_; i++) { + sum += in[i]; + } + } + run_mean[g_idx] = sum / N; + } + + // calc variance + for (int g = 0; g < cur_groups; g++) { + int g_idx = g + completed_group; + float var = 0; + run_var[g_idx] = 0; + for (int c = 0; c < num_of_ch_per_group; c++) { + const float *in = input + (num_of_ch_per_group * g_idx + c) * param->unit_; + int i = 0; + SIMD_RUN_NO_SCALAR(GroupNormReduceVar, i, in, run_mean[g_idx], &var, param->unit_); + for (; i < param->unit_; i++) { + var += (in[i] - run_mean[g_idx]) * (in[i] - run_mean[g_idx]); + } + } + run_var[g_idx] = var / N; + } +} diff --git a/mindspore-lite/src/extendrt/kernel/ascend/api/ascend_kernel_api.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/group_norm_fp32.h similarity index 55% rename from mindspore-lite/src/extendrt/kernel/ascend/api/ascend_kernel_api.h rename to mindspore-lite/ops/kernel/cpu/nnacl/fp32/group_norm_fp32.h index b49268532887cc3370f3962e77031d2f2502d622..e7f08295bf81e25eda160461e20c73699c8aadf0 100644 --- a/mindspore-lite/src/extendrt/kernel/ascend/api/ascend_kernel_api.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/group_norm_fp32.h @@ -14,24 +14,22 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_ASCEND_KERNEL_API_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_ASCEND_KERNEL_API_H_ +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_GROUP_NORM_FP32_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_GROUP_NORM_FP32_H_ -#include -#include -#include -#include "extendrt/kernel/ascend/src/custom_ascend_kernel.h" +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/group_norm_parameter.h" #ifdef __cplusplus extern "C" { #endif -using CreatorFunc = std::function()>; -std::map *CreateCustomAscendKernel(); - -void DestroyCustomAscendKernel(std::map *creator_func); +int GroupNormFp32(const float *input, const float *scale, const float *offset, float *mean, float *variance, + const GroupNormParameter *param, int task_id, float *output); #ifdef __cplusplus } #endif -#endif // MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_ASCEND_KERNEL_API_H_ + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_GROUP_NORM_FP32_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/group_norm_fp32_simd.h.in b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/group_norm_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..8fb8fa5b1a29900af8aa2331204c13a42df7ee5a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/group_norm_fp32_simd.h.in @@ -0,0 +1,70 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_GROUP_NORM_FP32_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_GROUP_NORM_FP32_@SIMD_INSTRUCTION@_H_ + +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int64_t GroupNormFp32@SIMD_INSTRUCTION@(int64_t index, const float *unit_input, float scale, float offset, float mean, + float var_sqrt, int unit, float *unit_output) { + SIMD_F32 mean_val = SIMD_MOV_F32(mean); + SIMD_F32 v_sqrt = SIMD_MOV_F32(var_sqrt); + SIMD_F32 scale_val = SIMD_MOV_F32(scale); + SIMD_F32 offset_val = SIMD_MOV_F32(offset); + for (int block_max_size = unit - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 input = SIMD_LD_F32(unit_input + index); + SIMD_F32 norm_val = SIMD_DIV_F32(SIMD_SUB_F32(input, mean_val), v_sqrt); + SIMD_F32 output = SIMD_ADD_F32(SIMD_MUL_F32(norm_val, scale_val), offset_val); + SIMD_ST_F32(unit_output + index, output); + } + return index; +} + +static inline int64_t GroupNormReduceSum@SIMD_INSTRUCTION@(int64_t index, const float *in, float *sum, int unit) { + if (unit - index >= 4 * BLOCK_NUM) { + SIMD_F32 tmp = SIMD_MOV_F32(0); + for (int block_max_size = unit - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + tmp = SIMD_ADD_F32(tmp, SIMD_LD_F32(in + index)); + } + *sum += SIMD_GET_SUM_F32(tmp); + } + return index; +} + +static inline int64_t GroupNormReduceVar@SIMD_INSTRUCTION@(int64_t index, const float *in, float mean, float *sum, int unit) { + if (unit - index >= 4 * BLOCK_NUM) { + SIMD_F32 mean_val = SIMD_MOV_F32(mean); + SIMD_F32 tmp = SIMD_MOV_F32(0); + for (int block_max_size = unit - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 input = SIMD_SUB_F32(SIMD_LD_F32(in + index), mean_val); + tmp = SIMD_ADD_F32(tmp, SIMD_MUL_F32(input, input)); + } + *sum += SIMD_GET_SUM_F32(tmp); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/gru_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/gru_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..fb4d87c36031caab33766d7a92523bf88b854d2f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/gru_fp32.c @@ -0,0 +1,154 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp32/gru_fp32.h" +#include +#include "nnacl/fp32/lstm_fp32.h" +#include "nnacl/fp32/activation_fp32.h" +#include "nnacl/fp32/arithmetic_fp32.h" +#include "nnacl/fp32/matmul_fp32.h" +void GruMatMul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, bool is_vec) { + if (is_vec) { + MatVecMulFp32(a, b, c, bias, ActType_No, deep, col); + } else { + MatMulOpt(a, b, c, bias, ActType_No, deep, row, col, col, OutType_Nhwc); + } +} + +void GruStepUnit(float *output, float *update_gate, float *reset_gate, float *hidden_buffer, const float *state_weight, + const float *state_bias, float *hidden_state, float *buffer[4], const GruParameter *gru_param) { + float *packed_state = buffer[2]; + float *state_gate = buffer[3]; + bool is_vec = gru_param->batch_ == 1; + + const float *state_update_weight = state_weight; + const float *state_reset_weight = state_weight + gru_param->hidden_size_ * gru_param->hidden_size_; + const float *state_hidden_weight = state_weight + gru_param->hidden_size_ * gru_param->hidden_size_ * 2; + float *state_update_gate = state_gate; + float *state_reset_gate = state_gate + gru_param->batch_ * gru_param->hidden_size_; + float *state_hidden_buffer = state_gate + gru_param->batch_ * gru_param->hidden_size_ * 2; + const float *state_update_bias = state_bias; + const float *state_reset_bias = state_bias + gru_param->hidden_size_; + const float *state_hidden_bias = state_bias + gru_param->hidden_size_ * 2; + + // state * weight + if (is_vec) { + GruMatMul(state_reset_gate, hidden_state, state_reset_weight, state_reset_bias, gru_param->batch_, + gru_param->hidden_size_, gru_param->hidden_size_, is_vec); + GruMatMul(state_update_gate, hidden_state, state_update_weight, state_update_bias, gru_param->batch_, + gru_param->hidden_size_, gru_param->hidden_size_, is_vec); + } else { + PackLstmInput(hidden_state, packed_state, gru_param->batch_, gru_param->hidden_size_); + GruMatMul(state_reset_gate, packed_state, state_reset_weight, state_reset_bias, gru_param->batch_, + gru_param->hidden_size_, gru_param->hidden_size_, is_vec); + GruMatMul(state_update_gate, packed_state, state_update_weight, state_update_bias, gru_param->batch_, + gru_param->hidden_size_, gru_param->hidden_size_, is_vec); + } + ElementAdd(update_gate, state_update_gate, update_gate, gru_param->batch_ * gru_param->hidden_size_); + ElementAdd(reset_gate, state_update_gate + gru_param->batch_ * gru_param->hidden_size_, reset_gate, + gru_param->batch_ * gru_param->hidden_size_); + + // update reset_gate + Sigmoid(reset_gate, gru_param->batch_ * gru_param->hidden_size_, reset_gate); + // update update_gate + Sigmoid(update_gate, gru_param->batch_ * gru_param->hidden_size_, update_gate); + + ElementMul(hidden_state, reset_gate, reset_gate, gru_param->batch_ * gru_param->hidden_size_); + if (is_vec) { + GruMatMul(state_hidden_buffer, reset_gate, state_hidden_weight, state_hidden_bias, gru_param->batch_, + gru_param->hidden_size_, gru_param->hidden_size_, is_vec); + } else { + PackLstmInput(reset_gate, packed_state, gru_param->batch_, gru_param->hidden_size_); + GruMatMul(state_hidden_buffer, packed_state, state_hidden_weight, state_hidden_bias, gru_param->batch_, + gru_param->hidden_size_, gru_param->hidden_size_, is_vec); + } + ElementAdd(hidden_buffer, state_hidden_buffer, hidden_buffer, gru_param->batch_ * gru_param->hidden_size_); + + Tanh(hidden_buffer, gru_param->batch_ * gru_param->hidden_size_, hidden_buffer); + + ElementMul(update_gate, hidden_state, hidden_state, gru_param->batch_ * gru_param->hidden_size_); + + const float one = 1.0f; + ElementOptSub(&one, update_gate, update_gate, gru_param->batch_ * gru_param->hidden_size_, true); + + ElementMulAcc(update_gate, hidden_buffer, hidden_state, gru_param->batch_ * gru_param->hidden_size_); + + memcpy(output, hidden_state, gru_param->batch_ * gru_param->hidden_size_ * sizeof(float)); +} + +void GruUnidirectional(float *output, const float *packed_input, const float *weight_g, const float *weight_r, + const float *input_bias, const float *state_bias, float *hidden_state, float *buffer[4], + const GruParameter *gru_param, bool is_backward) { + float *gate = buffer[1]; + for (int i = 0; i < 3; i++) { + const float *weight_loop = weight_g + gru_param->input_size_ * gru_param->input_col_align_ * i; + const float *bias_loop = input_bias + gru_param->input_col_align_ * i; + float *gate_loop = gate + gru_param->seq_len_ * gru_param->batch_ * gru_param->hidden_size_ * i; + MatMulOpt(packed_input, weight_loop, gate_loop, bias_loop, ActType_No, gru_param->input_size_, + gru_param->seq_len_ * gru_param->batch_, gru_param->hidden_size_, gru_param->hidden_size_, OutType_Nhwc); + } + + float *update_gate = gate; + float *reset_gate = gate + gru_param->seq_len_ * gru_param->batch_ * gru_param->hidden_size_; + float *hidden_buffer = gate + gru_param->seq_len_ * gru_param->batch_ * gru_param->hidden_size_ * 2; + for (int t = 0; t < gru_param->seq_len_; t++) { + int real_t = is_backward ? gru_param->seq_len_ - t - 1 : t; + float *update_gate_t = update_gate + gru_param->batch_ * gru_param->hidden_size_ * real_t; + float *reset_gate_t = reset_gate + gru_param->batch_ * gru_param->hidden_size_ * real_t; + float *hidden_buffer_t = hidden_buffer + gru_param->batch_ * gru_param->hidden_size_ * real_t; + float *output_ptr = output + real_t * gru_param->output_step_; + GruStepUnit(output_ptr, update_gate_t, reset_gate_t, hidden_buffer_t, weight_r, state_bias, hidden_state, buffer, + gru_param); + } +} + +void Gru(float *output, const float *input, const float *weight_g, const float *weight_r, const float *input_bias, + const float *state_bias, float *hidden_state, float *buffer[4], int check_seq_len, + const GruParameter *gru_param) { + // forward + float *packed_input = buffer[0]; + PackLstmInput(input, packed_input, gru_param->seq_len_ * gru_param->batch_, gru_param->input_size_); + GruUnidirectional(output, packed_input, weight_g, weight_r, input_bias, state_bias, hidden_state, buffer, gru_param, + false); + + // zero out extra fw outputs + for (int t = check_seq_len; t < gru_param->seq_len_; t++) { + float *output_ptr = output + t * gru_param->output_step_; + for (int i = 0; i < gru_param->batch_ * gru_param->hidden_size_; i++) { + output_ptr[i] = 0.0f; + } + } + + // backward + if (gru_param->bidirectional_) { + const float *backward_weight_g = weight_g + 3 * gru_param->input_col_align_ * gru_param->input_size_; + const float *backward_weight_r = weight_r + 3 * gru_param->state_col_align_ * gru_param->hidden_size_; + const float *backward_input_bias = input_bias + 3 * gru_param->input_col_align_; + const float *backward_state_bias = state_bias + 3 * gru_param->state_col_align_; + float *backward_output = output + gru_param->batch_ * gru_param->hidden_size_; + float *backward_hidden_state = hidden_state + gru_param->batch_ * gru_param->hidden_size_; + + GruUnidirectional(backward_output, packed_input, backward_weight_g, backward_weight_r, backward_input_bias, + backward_state_bias, backward_hidden_state, buffer, gru_param, true); + + // zero out extra bw outputs + for (int t = gru_param->seq_len_ - 1; t >= check_seq_len; t--) { + float *output_ptr = backward_output + t * gru_param->output_step_; + for (int i = 0; i < gru_param->batch_ * gru_param->hidden_size_; i++) { + output_ptr[i] = 0.0f; + } + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/gru_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/gru_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..7beb77ed2996e772e8e5eb5e2a2bebb6b8bbd96f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/gru_fp32.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_GRU_FP32_H_ +#define MINDSPORE_NNACL_FP32_GRU_FP32_H_ +#include "nnacl/gru_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void Gru(float *output, const float *input, const float *weight_g, const float *weight_r, const float *input_bias, + const float *state_bias, float *hidden_state, float *buffer[4], int check_seq_len, + const GruParameter *gru_parm); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_GRU_FP32_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/instance_norm_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/instance_norm_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..844565389ab3f911cff7e3cf0d069e4abadd1b58 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/instance_norm_fp32.c @@ -0,0 +1,374 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp32/instance_norm_fp32.h" +#include +#include "nnacl/errorcode.h" +#include "nnacl/op_base.h" +#include "nnacl/intrinsics/ms_simd_instructions.h" + +int InstanceNorm(const float *src_data, float *dst_data, const float *gamma_data, const float *beta_data, + const InstanceNormParameter *param, size_t task_id) { + NNACL_CHECK_NULL_RETURN_ERR(src_data); + NNACL_CHECK_NULL_RETURN_ERR(dst_data); + NNACL_CHECK_NULL_RETURN_ERR(gamma_data); + NNACL_CHECK_NULL_RETURN_ERR(beta_data); + NNACL_CHECK_NULL_RETURN_ERR(param); + NNACL_CHECK_ZERO_RETURN_ERR(param->op_parameter_.thread_num_); + int channel_step = UP_DIV(param->channel_, param->op_parameter_.thread_num_); + int channel_begin = (int)(task_id)*channel_step; + int channel_end = MSMIN(channel_begin + channel_step, param->channel_); + + for (int b = 0; b < param->batch_; b++) { + const float *src_b = src_data + b * param->channel_ * param->inner_size_; + float *dst_b = dst_data + b * param->channel_ * param->inner_size_; + for (int c = channel_begin; c < channel_end; c++) { + const float *src = src_b + c * param->inner_size_; + float *dst = dst_b + c * param->inner_size_; + double mean = 0.0f; + double squ_m = 0.0f; + + int index = 0; +#if defined(ENABLE_AVX) + for (; index <= param->inner_size_ - C8NUM; index += C8NUM) { + __m256 srcv = _mm256_loadu_ps(src + index); + __m256 squarev = _mm256_mul_ps(srcv, srcv); + __m128 src128 = _mm_add_ps(_mm256_extractf128_ps(srcv, 0), _mm256_extractf128_ps(srcv, 1)); + __m128 square128 = _mm_add_ps(_mm256_extractf128_ps(squarev, 0), _mm256_extractf128_ps(squarev, 1)); + for (int i = 0; i < C4NUM; ++i) { + mean += MS_F32X4_GETI(src128, i); + squ_m += MS_F32X4_GETI(square128, i); + } + } +#endif + +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + for (; index <= param->inner_size_ - C4NUM; index += C4NUM) { + MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index); + MS_FLOAT32X4 squarev = MS_MULQ_F32(srcv, srcv); +#ifdef ENABLE_ARM64 + mean += vaddvq_f32(srcv); + squ_m += vaddvq_f32(squarev); +#elif defined(ENABLE_SSE) + for (int i = 0; i < C4NUM; ++i) { + mean += MS_F32X4_GETI(srcv, i); + squ_m += MS_F32X4_GETI(squarev, i); + } +#else + float32x2_t src_add2 = vadd_f32(vget_low_f32(srcv), vget_high_f32(srcv)); + float32x2_t src_add4 = vpadd_f32(src_add2, src_add2); + mean += vget_lane_f32(src_add4, 0); + float32x2_t square_add2 = vadd_f32(vget_low_f32(squarev), vget_high_f32(squarev)); + float32x2_t square_add4 = vpadd_f32(square_add2, square_add2); + squ_m += vget_lane_f32(square_add4, 0); +#endif + } +#endif + for (; index < param->inner_size_; index++) { + mean += src[index]; + squ_m += src[index] * src[index]; + } + + mean /= (float)param->inner_size_; + squ_m /= (float)param->inner_size_; + const double deno = gamma_data[c] / sqrt(squ_m - mean * mean + param->epsilon_); + + index = 0; +#if defined(ENABLE_AVX) + MS_FLOAT32X8 meanv8 = MS_MOV256_F32(mean); + MS_FLOAT32X8 denov8 = MS_MOV256_F32(deno); + for (; index <= param->inner_size_ - C8NUM; index += C8NUM) { + MS_FLOAT32X8 srcv8 = MS_LD256_F32(src + index); + MS_FLOAT32X8 dstv8 = + MS_ADD256_F32(MS_MUL256_F32(MS_SUB256_F32(srcv8, meanv8), denov8), MS_MOV256_F32(*(beta_data + c))); + MS_ST256_F32(dst + index, dstv8); + } +#endif + +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_FLOAT32X4 meanv4 = MS_MOVQ_F32(mean); + MS_FLOAT32X4 denov4 = MS_MOVQ_F32(deno); + for (; index <= param->inner_size_ - C4NUM; index += C4NUM) { + MS_FLOAT32X4 srcv4 = MS_LDQ_F32(src + index); + MS_FLOAT32X4 dstv4 = + MS_ADDQ_F32(MS_MULQ_F32(MS_SUBQ_F32(srcv4, meanv4), denov4), MS_MOVQ_F32(*(beta_data + c))); + MS_STQ_F32(dst + index, dstv4); + } +#endif + for (; index < param->inner_size_; index++) { + dst[index] = (src[index] - mean) * deno + beta_data[c]; + } + } + } + return NNACL_OK; +} + +#if defined(ENABLE_SSE) || defined(ENABLE_ARM) +void InstanceNormC4HW4ArmSse(const float *src_b, float *dst_b, const float *gamma_data, const float *beta_data, + int32_t *c_src, const InstanceNormParameter *param, int channel, int channel_end, + int hw_plane, MS_FLOAT32X4 hw_planev) { + int c = *c_src; + for (; c <= channel_end - C16NUM; c += C16NUM) { + const float *src = src_b + c * hw_plane, *src1 = src_b + (c + C4NUM) * hw_plane; + const float *src2 = src_b + (c + C8NUM) * hw_plane, *src3 = src_b + (c + C12NUM) * hw_plane; + float *dst = dst_b + c; + MS_FLOAT32X4 mean = MS_MOVQ_F32(0.0f), mean1 = MS_MOVQ_F32(0.0f); + MS_FLOAT32X4 mean2 = MS_MOVQ_F32(0.0f), mean3 = MS_MOVQ_F32(0.0f); + MS_FLOAT32X4 squ_m = MS_MOVQ_F32(0.0f), squ_m1 = MS_MOVQ_F32(0.0f); + MS_FLOAT32X4 squ_m2 = MS_MOVQ_F32(0.0f), squ_m3 = MS_MOVQ_F32(0.0f); + for (int index = 0; index < hw_plane; ++index) { + MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index * C4NUM), srcv1 = MS_LDQ_F32(src1 + index * C4NUM); + MS_FLOAT32X4 srcv2 = MS_LDQ_F32(src2 + index * C4NUM), srcv3 = MS_LDQ_F32(src3 + index * C4NUM); + MS_FLOAT32X4 squarev = MS_MULQ_F32(srcv, srcv), squarev1 = MS_MULQ_F32(srcv1, srcv1); + MS_FLOAT32X4 squarev2 = MS_MULQ_F32(srcv2, srcv2), squarev3 = MS_MULQ_F32(srcv3, srcv3); + MS_ADDQ_F32_VEC(mean, mean1, mean2, mean3, srcv, srcv1, srcv2, srcv3); + MS_ADDQ_F32_VEC(squ_m, squ_m1, squ_m2, squ_m3, squarev, squarev1, squarev2, squarev3); + } + MS_DIVQ_F32_VEC(mean, mean1, mean2, mean3, hw_planev); + MS_DIVQ_F32_VEC(squ_m, squ_m1, squ_m2, squ_m3, hw_planev); + + MS_FLOAT32X4 deno = MS_ADDQ_F32(MS_SUBQ_F32(squ_m, MS_MULQ_F32(mean, mean)), MS_MOVQ_F32(param->epsilon_)); + MS_FLOAT32X4 deno1 = MS_ADDQ_F32(MS_SUBQ_F32(squ_m1, MS_MULQ_F32(mean1, mean1)), MS_MOVQ_F32(param->epsilon_)); + MS_FLOAT32X4 deno2 = MS_ADDQ_F32(MS_SUBQ_F32(squ_m2, MS_MULQ_F32(mean2, mean2)), MS_MOVQ_F32(param->epsilon_)); + MS_FLOAT32X4 deno3 = MS_ADDQ_F32(MS_SUBQ_F32(squ_m3, MS_MULQ_F32(mean3, mean3)), MS_MOVQ_F32(param->epsilon_)); + + deno = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno)); + deno1 = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno1)); + deno2 = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno2)); + deno3 = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno3)); + + MS_FLOAT32X4 gammav = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c), deno); // deno * gamma_data[c] + MS_FLOAT32X4 gammav1 = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c + C4NUM), deno1); // deno * gamma_data[c] + MS_FLOAT32X4 gammav2 = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c + C8NUM), deno2); // deno * gamma_data[c] + MS_FLOAT32X4 gammav3 = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c + C12NUM), deno3); // deno * gamma_data[c] + MS_FLOAT32X4 betav = MS_LDQ_F32(beta_data + c), betav1 = MS_LDQ_F32(beta_data + c + C4NUM); + MS_FLOAT32X4 betav2 = MS_LDQ_F32(beta_data + c + C8NUM), betav3 = MS_LDQ_F32(beta_data + c + C12NUM); + for (int index = 0; index < hw_plane; ++index) { + MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index * C4NUM), srcv1 = MS_LDQ_F32(src1 + index * C4NUM); + MS_FLOAT32X4 srcv2 = MS_LDQ_F32(src2 + index * C4NUM), srcv3 = MS_LDQ_F32(src3 + index * C4NUM); + MS_FLOAT32X4 outv = MS_SUBQ_F32(srcv, mean), outv1 = MS_SUBQ_F32(srcv1, mean1); + MS_FLOAT32X4 outv2 = MS_SUBQ_F32(srcv2, mean2), outv3 = MS_SUBQ_F32(srcv3, mean3); + + outv = MS_MULQ_F32(outv, gammav), outv1 = MS_MULQ_F32(outv1, gammav1); + outv2 = MS_MULQ_F32(outv2, gammav2), outv3 = MS_MULQ_F32(outv3, gammav3); + MS_ADDQ_F32_VEC(outv, outv1, outv2, outv3, betav, betav1, betav2, betav3); + + MS_STQ_F32(dst + index * channel, outv), MS_STQ_F32(dst + index * channel + C4NUM, outv1); + MS_STQ_F32(dst + index * channel + C8NUM, outv2), MS_STQ_F32(dst + index * channel + C12NUM, outv3); + } + } + for (; c <= channel_end - C8NUM; c += C8NUM) { + const float *src = src_b + c * hw_plane, *src1 = src_b + (c + C4NUM) * hw_plane; + float *dst = dst_b + c; + MS_FLOAT32X4 mean = MS_MOVQ_F32(0.0f), mean1 = MS_MOVQ_F32(0.0f); + MS_FLOAT32X4 squ_m = MS_MOVQ_F32(0.0f), squ_m1 = MS_MOVQ_F32(0.0f); + for (int index = 0; index < hw_plane; ++index) { + MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index * C4NUM), srcv1 = MS_LDQ_F32(src1 + index * C4NUM); + MS_FLOAT32X4 squarev = MS_MULQ_F32(srcv, srcv), squarev1 = MS_MULQ_F32(srcv1, srcv1); + mean = MS_ADDQ_F32(mean, srcv), mean1 = MS_ADDQ_F32(mean1, srcv1); + squ_m = MS_ADDQ_F32(squ_m, squarev), squ_m1 = MS_ADDQ_F32(squ_m1, squarev1); + } + + MS_DIVQ_F32_VEC(mean, mean1, squ_m, squ_m1, hw_planev); + MS_FLOAT32X4 deno = MS_ADDQ_F32(MS_SUBQ_F32(squ_m, MS_MULQ_F32(mean, mean)), MS_MOVQ_F32(param->epsilon_)); + MS_FLOAT32X4 deno1 = MS_ADDQ_F32(MS_SUBQ_F32(squ_m1, MS_MULQ_F32(mean1, mean1)), MS_MOVQ_F32(param->epsilon_)); + deno = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno)); + deno1 = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno1)); + + MS_FLOAT32X4 gammav = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c), deno); // deno * gamma_data[c] + MS_FLOAT32X4 gammav1 = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c + C4NUM), deno1); // deno * gamma_data[c] + MS_FLOAT32X4 betav = MS_LDQ_F32(beta_data + c), betav1 = MS_LDQ_F32(beta_data + c + C4NUM); + for (int index = 0; index < hw_plane; ++index) { + MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index * C4NUM), srcv1 = MS_LDQ_F32(src1 + index * C4NUM); + MS_FLOAT32X4 outv = MS_SUBQ_F32(srcv, mean), outv1 = MS_SUBQ_F32(srcv1, mean1); + outv = MS_MULQ_F32(outv, gammav), outv1 = MS_MULQ_F32(outv1, gammav1); + outv = MS_ADDQ_F32(outv, betav), outv1 = MS_ADDQ_F32(outv1, betav1); + MS_STQ_F32(dst + index * channel, outv); + MS_STQ_F32(dst + index * channel + C4NUM, outv1); + } + } + for (; c <= channel_end - C4NUM; c += C4NUM) { + const float *src = src_b + c * hw_plane; + float *dst = dst_b + c; + MS_FLOAT32X4 mean = MS_MOVQ_F32(0.0f), squ_m = MS_MOVQ_F32(0.0f); + for (int index = 0; index < hw_plane; ++index) { + MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index * C4NUM), squarev = MS_MULQ_F32(srcv, srcv); + mean = MS_ADDQ_F32(mean, srcv), squ_m = MS_ADDQ_F32(squ_m, squarev); + } + mean = MS_DIVQ_F32(mean, hw_planev), squ_m = MS_DIVQ_F32(squ_m, hw_planev); + MS_FLOAT32X4 deno = MS_ADDQ_F32(MS_SUBQ_F32(squ_m, MS_MULQ_F32(mean, mean)), MS_MOVQ_F32(param->epsilon_)); + deno = MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_SQRTFX4_F32(deno)); + + MS_FLOAT32X4 gammav = MS_MULQ_F32(MS_LDQ_F32(gamma_data + c), deno), betav = MS_LDQ_F32(beta_data + c); + for (int index = 0; index < hw_plane; ++index) { + MS_FLOAT32X4 srcv = MS_LDQ_F32(src + index * C4NUM), outv = MS_SUBQ_F32(srcv, mean); + MS_STQ_F32(dst + index * channel, MS_ADDQ_F32(MS_MULQ_F32(outv, gammav), betav)); + } + } + *c_src = c; +} +#endif + +int InstanceNormNC4HW4(const float *src_data, float *dst_data, const float *gamma_data, const float *beta_data, + const InstanceNormParameter *param, size_t task_id) { + NNACL_CHECK_NULL_RETURN_ERR(src_data); + NNACL_CHECK_NULL_RETURN_ERR(dst_data); + NNACL_CHECK_NULL_RETURN_ERR(gamma_data); + NNACL_CHECK_NULL_RETURN_ERR(beta_data); + NNACL_CHECK_NULL_RETURN_ERR(param); + NNACL_CHECK_ZERO_RETURN_ERR(param->op_parameter_.thread_num_); + int channel = param->channel_; + int hw_plane = param->inner_size_; + int channel_step = UP_DIV(UP_DIV(channel, C4NUM), param->op_parameter_.thread_num_) * C4NUM; + int channel_begin = (int)(task_id)*channel_step; + int channel_end = MSMIN(channel_begin + channel_step, channel); +#if defined(ENABLE_SSE) || defined(ENABLE_ARM) + MS_FLOAT32X4 hw_planev = MS_MOVQ_F32((float)(hw_plane)); +#endif + for (int b = 0; b < param->batch_; b++) { + const float *src_b = src_data + b * channel * hw_plane; + float *dst_b = dst_data + b * channel * hw_plane; + int c = channel_begin; +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + InstanceNormC4HW4ArmSse(src_b, dst_b, gamma_data, beta_data, &c, param, channel, channel_end, hw_plane, hw_planev); +#endif + for (; c < channel_end; ++c) { + int c4_down_loop = c / C4NUM * C4NUM; + int c4_mod = c % C4NUM; + int c_res = MSMIN(channel_end - c4_down_loop, C4NUM); + const float *src = src_b + c4_down_loop * hw_plane + c4_mod; + float *dst = dst_b + c; + float mean = 0.0f; + float squ_m = 0.0f; + for (int index = 0; index < hw_plane; ++index) { + float tmp = src[index * c_res]; + mean += tmp; + squ_m += tmp * tmp; + } + mean /= (float)hw_plane; + squ_m /= (float)hw_plane; + const float deno = gamma_data[c] / sqrtf(squ_m - mean * mean + param->epsilon_); + for (int index = 0; index < hw_plane; ++index) { + dst[index * channel] = (src[index * c_res] - mean) * deno + beta_data[c]; + } + } + } + return NNACL_OK; +} + +#ifdef ENABLE_AVX +int InstanceNormNC8HW8(const float *src_data, float *dst_data, const float *gamma_data, const float *beta_data, + const InstanceNormParameter *param, size_t task_id) { + NNACL_CHECK_NULL_RETURN_ERR(src_data); + NNACL_CHECK_NULL_RETURN_ERR(dst_data); + NNACL_CHECK_NULL_RETURN_ERR(gamma_data); + NNACL_CHECK_NULL_RETURN_ERR(beta_data); + NNACL_CHECK_NULL_RETURN_ERR(param); + NNACL_CHECK_ZERO_RETURN_ERR(param->op_parameter_.thread_num_); + int channel = param->channel_, hw_plane = param->inner_size_; + int channel_step = UP_DIV(UP_DIV(channel, C8NUM), param->op_parameter_.thread_num_) * C8NUM; + int channel_begin = (int)(task_id)*channel_step; + int channel_end = MSMIN(channel_begin + channel_step, channel); + MS_FLOAT32X8 hw_planev = MS_MOV256_F32((float)(hw_plane)); + for (int b = 0; b < param->batch_; b++) { + const float *src_b = src_data + b * channel * hw_plane; + float *dst_b = dst_data + b * channel * hw_plane; + int c = channel_begin; + for (; c <= channel_end - C16NUM; c += C16NUM) { + const float *src = src_b + c * hw_plane; + const float *src1 = src_b + (c + C8NUM) * hw_plane; + float *dst = dst_b + c; + MS_FLOAT32X8 mean = MS_MOV256_F32(0.0f), mean1 = MS_MOV256_F32(0.0f); + MS_FLOAT32X8 squ_m = MS_MOV256_F32(0.0f), squ_m1 = MS_MOV256_F32(0.0f); + for (int index = 0; index < hw_plane; ++index) { + MS_FLOAT32X8 srcv = MS_LD256_F32(src + index * C8NUM), srcv1 = MS_LD256_F32(src1 + index * C8NUM); + MS_FLOAT32X8 squarev = MS_MUL256_F32(srcv, srcv), squarev1 = MS_MUL256_F32(srcv1, srcv1); + mean = MS_ADD256_F32(mean, srcv); + mean1 = MS_ADD256_F32(mean1, srcv1); + squ_m = MS_ADD256_F32(squ_m, squarev); + squ_m1 = MS_ADD256_F32(squ_m1, squarev1); + } + mean = MS_DIV256_F32(mean, hw_planev); + mean1 = MS_DIV256_F32(mean1, hw_planev); + squ_m = MS_DIV256_F32(squ_m, hw_planev); + squ_m1 = MS_DIV256_F32(squ_m1, hw_planev); + MS_FLOAT32X8 deno = + MS_ADD256_F32(MS_SUB256_F32(squ_m, MS_MUL256_F32(mean, mean)), MS_MOV256_F32(param->epsilon_)); + MS_FLOAT32X8 deno1 = + MS_ADD256_F32(MS_SUB256_F32(squ_m1, MS_MUL256_F32(mean1, mean1)), MS_MOV256_F32(param->epsilon_)); + deno = MS_DIV256_F32(MS_MOV256_F32(1.0f), MS_SQRTFX8_F32(deno)); + deno1 = MS_DIV256_F32(MS_MOV256_F32(1.0f), MS_SQRTFX8_F32(deno1)); + + MS_FLOAT32X8 gammav = MS_MUL256_F32(MS_LD256_F32(gamma_data + c), deno); // deno * gamma_data[c] + MS_FLOAT32X8 gammav1 = MS_MUL256_F32(MS_LD256_F32(gamma_data + c + C8NUM), deno1); // deno1 * gamma_data[c] + MS_FLOAT32X8 betav = MS_LD256_F32(beta_data + c), betav1 = MS_LD256_F32(beta_data + c + C8NUM); + for (int index = 0; index < hw_plane; ++index) { + MS_FLOAT32X8 srcv = MS_LD256_F32(src + index * C8NUM), srcv1 = MS_LD256_F32(src1 + index * C8NUM); + MS_FLOAT32X8 outv = MS_SUB256_F32(srcv, mean), outv1 = MS_SUB256_F32(srcv1, mean1); + outv = MS_MUL256_F32(outv, gammav); + outv1 = MS_MUL256_F32(outv1, gammav1); + outv = MS_ADD256_F32(outv, betav); + outv1 = MS_ADD256_F32(outv1, betav1); + MS_ST256_F32(dst + index * channel, outv); + MS_ST256_F32(dst + index * channel + C8NUM, outv1); + } + } + for (; c <= channel_end - C8NUM; c += C8NUM) { + const float *src = src_b + c * hw_plane; + float *dst = dst_b + c; + MS_FLOAT32X8 mean = MS_MOV256_F32(0.0f), squ_m = MS_MOV256_F32(0.0f); + for (int index = 0; index < hw_plane; ++index) { + MS_FLOAT32X8 srcv = MS_LD256_F32(src + index * C8NUM); + MS_FLOAT32X8 squarev = MS_MUL256_F32(srcv, srcv); + mean = MS_ADD256_F32(mean, srcv); + squ_m = MS_ADD256_F32(squ_m, squarev); + } + mean = MS_DIV256_F32(mean, hw_planev); + squ_m = MS_DIV256_F32(squ_m, hw_planev); + MS_FLOAT32X8 deno = MS_ADD256_F32(MS_SUB256_F32(squ_m, MS_MUL256_F32(mean, mean)), + MS_MOV256_F32(param->epsilon_)); // 256uestion + deno = MS_DIV256_F32(MS_MOV256_F32(1.0f), MS_SQRTFX8_F32(deno)); + + MS_FLOAT32X8 gammav = MS_MUL256_F32(MS_LD256_F32(gamma_data + c), deno); // deno * gamma_data[c] + MS_FLOAT32X8 betav = MS_LD256_F32(beta_data + c); + for (int index = 0; index < hw_plane; ++index) { + MS_FLOAT32X8 srcv = MS_LD256_F32(src + index * C8NUM), outv = MS_SUB256_F32(srcv, mean); + outv = MS_MUL256_F32(outv, gammav); + outv = MS_ADD256_F32(outv, betav); + MS_ST256_F32(dst + index * channel, outv); + } + } + for (; c < channel_end; ++c) { + int c8_down_loop = c / C8NUM * C8NUM, c8_mod = c % C8NUM; + int c_res = MSMIN(channel_end - c8_down_loop, C8NUM); + const float *src = src_b + c8_down_loop * hw_plane + c8_mod; + float *dst = dst_b + c; + float mean = 0.0f, squ_m = 0.0f; + for (int index = 0; index < hw_plane; ++index) { + float tmp = src[index * c_res]; + mean += tmp; + squ_m += tmp * tmp; + } + mean /= (float)hw_plane; + squ_m /= (float)hw_plane; + const float deno = gamma_data[c] / sqrtf(squ_m - mean * mean + param->epsilon_); + for (int index = 0; index < hw_plane; ++index) { + dst[index * channel] = (src[index * c_res] - mean) * deno + beta_data[c]; + } + } + } + return NNACL_OK; +} +#endif diff --git a/mindspore-lite/examples/quick_start_ios/mindspore-lite/AppDelegate.mm b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/instance_norm_fp32.h similarity index 31% rename from mindspore-lite/examples/quick_start_ios/mindspore-lite/AppDelegate.mm rename to mindspore-lite/ops/kernel/cpu/nnacl/fp32/instance_norm_fp32.h index 48e9dd00961491b3260bea1b13ab468d049065b3..a51611e3abd28891eae678647b64ce8c5fa75881 100644 --- a/mindspore-lite/examples/quick_start_ios/mindspore-lite/AppDelegate.mm +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/instance_norm_fp32.h @@ -13,38 +13,38 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - -#import "AppDelegate.h" -#include "Benchmark.hpp" -@interface AppDelegate () - -@end - -@implementation AppDelegate - - -- (BOOL)application:(UIApplication *)application didFinishLaunchingWithOptions:(NSDictionary *)launchOptions { - // Override point for customization after application launch. - mindspore::lite::main_benchmark(); - return YES; -} - - -#pragma mark - UISceneSession lifecycle - - -- (UISceneConfiguration *)application:(UIApplication *)application configurationForConnectingSceneSession:(UISceneSession *)connectingSceneSession options:(UISceneConnectionOptions *)options { - // Called when a new scene session is being created. - // Use this method to select a configuration to create the new scene with. - return [[UISceneConfiguration alloc] initWithName:@"Default Configuration" sessionRole:connectingSceneSession.role]; +#ifndef MINDSPORE_NNACL_FP32_INSTANCE_NORM_H_ +#define MINDSPORE_NNACL_FP32_INSTANCE_NORM_H_ + +#include "nnacl/op_base.h" +#include "nnacl/instance_norm_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#define MS_ADDQ_F32_VEC(in1, in2, in3, in4, v1, v2, v3, v4) \ + in1 = MS_ADDQ_F32(in1, v1); \ + in2 = MS_ADDQ_F32(in2, v2); \ + in3 = MS_ADDQ_F32(in3, v3); \ + in4 = MS_ADDQ_F32(in4, v4); + +#define MS_DIVQ_F32_VEC(in1, in2, in3, in4, v) \ + in1 = MS_DIVQ_F32(in1, v); \ + in2 = MS_DIVQ_F32(in2, v); \ + in3 = MS_DIVQ_F32(in3, v); \ + in4 = MS_DIVQ_F32(in4, v); + +int InstanceNorm(const float *src_data, float *dst_data, const float *gamma_data, const float *beta_data, + const InstanceNormParameter *param, size_t task_id); +int InstanceNormNC4HW4(const float *src_data, float *dst_data, const float *gamma_data, const float *beta_data, + const InstanceNormParameter *param, size_t task_id); +#ifdef ENABLE_AVX +int InstanceNormNC8HW8(const float *src_data, float *dst_data, const float *gamma_data, const float *beta_data, + const InstanceNormParameter *param, size_t task_id); +#endif +#ifdef __cplusplus } +#endif - -- (void)application:(UIApplication *)application didDiscardSceneSessions:(NSSet *)sceneSessions { - // Called when the user discards a scene session. - // If any sessions were discarded while the application was not running, this will be called shortly after application:didFinishLaunchingWithOptions. - // Use this method to release any resources that were specific to the discarded scenes, as they will not return. -} - - -@end +#endif // MINDSPORE_NNACL_FP32_INSTANCE_NORM_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/invert_permutation_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/invert_permutation_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..9d7f6c26c1a86747e496b121bc71583aa865bc49 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/invert_permutation_fp32.c @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/invert_permutation_fp32.h" +#include "nnacl/errorcode.h" + +int InvertPermutation(const int32_t *input, int32_t *output, size_t num) { + NNACL_CHECK_NULL_RETURN_ERR(input); + NNACL_CHECK_NULL_RETURN_ERR(output); + for (size_t i = 0; i < num; i++) { + size_t index = (size_t)input[i]; + if (index >= num) { + return NNACL_ERR; + } + output[index] = i; + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/invert_permutation_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/invert_permutation_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..82c6a30fe742dd34020db9f3267f31c228343895 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/invert_permutation_fp32.h @@ -0,0 +1,30 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INVERT_PERMUTATION_FP32_H_ +#define MINDSPORE_NNACL_INVERT_PERMUTATION_FP32_H_ + +#include +#include "nnacl/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +int InvertPermutation(const int32_t *input, int32_t *output, size_t num); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_INVERT_PERMUTATION_FP32_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/l2_norm_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/l2_norm_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..3259bd46f6402c113f58ff1d900fbfd5c8ff5581 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/l2_norm_fp32.c @@ -0,0 +1,78 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/l2_norm_fp32.h" +#include +#include "nnacl/errorcode.h" + +int CalcThreadSquareSum(const float *input_ptr, float *sum, int begin, int end) { + *sum = 0.0f; + int i; + for (i = begin; i < end; ++i) { + *sum += input_ptr[i] * input_ptr[i]; + } + return NNACL_OK; +} + +int ThreadDivSqrtSum(const float *input_ptr, float *output_ptr, const L2NormParameter *param, const float sqrt_sum, + const int begin, const int end) { + bool is_relu = param->act_type_ == ActType_Relu; + bool is_relu6 = param->act_type_ == ActType_Relu6; + int i; + if (sqrt_sum == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + for (i = begin; i < end; i++) { + float tmp = input_ptr[i] / sqrt_sum; + if (is_relu) { + output_ptr[i] = MSMAX(0, tmp); + } else if (is_relu6) { + output_ptr[i] = MSMIN(6, MSMAX(0, tmp)); + } else { + output_ptr[i] = tmp; + } + } + return NNACL_OK; +} + +int ThreadTrailingAxis(const float *input_ptr, float *output_ptr, const L2NormParameter *param, const int begin, + const int end) { + bool is_relu = param->act_type_ == ActType_Relu; + bool is_relu6 = param->act_type_ == ActType_Relu6; + + const int c = param->shape_[param->shape_num_ - 1]; + int i = 0; + for (i = begin; i < end; ++i) { + float square_sum = 0.0f; + int j = 0; + for (j = 0; j < c; ++j) { + const float val = input_ptr[i * c + j]; + square_sum += val * val; + } + float sqrt_sum = sqrtf(square_sum > param->epsilon_ ? square_sum : param->epsilon_); + for (j = 0; j < c; ++j) { + float tmp = input_ptr[i * c + j] / sqrt_sum; + if (is_relu) { + output_ptr[i * c + j] = MSMAX(0, tmp); + } else if (is_relu6) { + output_ptr[i * c + j] = MSMIN(6, MSMAX(0, tmp)); + } else { + output_ptr[i * c + j] = tmp; + } + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/l2_norm_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/l2_norm_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..efca13b717a3ca9ae8099199cd5be16e2f665131 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/l2_norm_fp32.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_L2NORM_FP32_H_ +#define MINDSPORE_NNACL_FP32_L2NORM_FP32_H_ + +#include "nnacl/l2_norm_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int CalcThreadSquareSum(const float *input_ptr, float *sum, int begin, int end); +int ThreadDivSqrtSum(const float *input_ptr, float *output_ptr, const L2NormParameter *param, const float sqrt_sum, + const int begin, const int end); +int ThreadTrailingAxis(const float *input_ptr, float *output_ptr, const L2NormParameter *param, const int begin, + const int end); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_L2NORM_FP32_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/layer_norm_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/layer_norm_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..8ea8c3266faf5df1cf56538da1f0f0951c223a80 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/layer_norm_fp32.c @@ -0,0 +1,93 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp32/layer_norm_fp32.h" +#include +#include "nnacl/errorcode.h" +#include "nnacl/op_base.h" +#include "nnacl/layer_norm_fp32_simd.h" + +int LayerNormMeanAndSquare(const float *src, int num, float *mean, float *variance) { + if (num <= 0) { + return NNACL_ERR; + } + int index = 0; + float square_mean = 0.f; + + SIMD_RUN_NO_SCALAR(LayerNormMeanAndSquare, index, src, num, mean, &square_mean); + + for (; index < num; index++) { + *mean += src[index]; + square_mean += src[index] * src[index]; + } + *mean /= (float)num; + square_mean /= (float)num; + *variance = square_mean - (*mean) * (*mean); + return NNACL_OK; +} + +void LayerNormGammaAndBeta(float *dst, const float *src, const float *gamma_data, const float *beta_data, int num, + const float mean, const float deno) { + int index = 0; + + SIMD_RUN_NO_SCALAR(LayerNormGammaAndBeta, index, dst, src, gamma_data, beta_data, num, mean, deno); + + for (; index < num; index++) { + dst[index] = (src[index] - mean) * (deno); + dst[index] = dst[index] * gamma_data[index] + beta_data[index]; + } +} + +int LayerNorm(const float *src_data, const float *gamma_data, const float *beta_data, float *dst_data, float *out_mean, + float *out_variance, const LayerNormComputeParam *param, int task_id, int thread_num) { + if (src_data == NULL || dst_data == NULL || gamma_data == NULL || beta_data == NULL) { + return NNACL_NULL_PTR; + } + NNACL_CHECK_NULL_RETURN_ERR(param); + NNACL_CHECK_ZERO_RETURN_ERR(param->params_inner_size_); + NNACL_CHECK_ZERO_RETURN_ERR(param->params_outer_size_); + int step = UP_DIV(param->norm_outer_size_, thread_num); + int thread_end = MSMIN(((int)task_id + 1) * step, param->norm_outer_size_); + for (int i = task_id * step; i < thread_end; i++) { + const float *src_norm = src_data + i * param->norm_inner_size_; + float *dst_norm = dst_data + i * param->norm_inner_size_; + float cur_mean = 0.0f; + float cur_variance = 0.0f; + int ret = LayerNormMeanAndSquare(src_norm, param->norm_inner_size_, &cur_mean, &cur_variance); + if (ret != NNACL_OK) { + return NNACL_ERR; + } + if (out_mean != NULL) { + out_mean[i] = cur_mean; + } + if (out_variance != NULL) { + out_variance[i] = cur_variance; + } + const float deno = 1 / sqrtf(cur_variance + param->epsilon_); + if (param->norm_outer_size_ <= param->params_outer_size_) { + for (int x = 0; x < param->norm_inner_size_ / param->params_inner_size_; x++) { + const float *src_param = src_norm + x * param->params_inner_size_; + float *dst_param = dst_norm + x * param->params_inner_size_; + LayerNormGammaAndBeta(dst_param, src_param, gamma_data, beta_data, param->params_inner_size_, cur_mean, deno); + } + } else { + int x = i / param->params_outer_size_; + const float *gamma = gamma_data + x * param->norm_inner_size_; + const float *beta = beta_data + x * param->norm_inner_size_; + LayerNormGammaAndBeta(dst_norm, src_norm, gamma, beta, param->norm_inner_size_, cur_mean, deno); + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/layer_norm_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/layer_norm_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..86d9a905b742a4bc659a6090f248a0b3d3b6fa93 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/layer_norm_fp32.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP32_LAYER_NORM_FP32_H_ +#define NNACL_FP32_LAYER_NORM_FP32_H_ + +#include "nnacl/op_base.h" +#include "nnacl/layer_norm_parameter.h" +#include "nnacl/kernel/layer_norm.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int LayerNorm(const float *src_data, const float *gamma_data, const float *beta_data, float *dst_data, float *out_mean, + float *out_variance, const LayerNormComputeParam *param, int task_id, int thread_num); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_LAYER_NORM_FP32_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/layer_norm_fp32_simd.h.in b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/layer_norm_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..c93d0975a1d6758d66706639442a029a72f2691e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/layer_norm_fp32_simd.h.in @@ -0,0 +1,61 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_LAYER_NORM_FP32_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_LAYER_NORM_FP32_@SIMD_INSTRUCTION@_H_ + +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int LayerNormMeanAndSquare@SIMD_INSTRUCTION@(int index, const float *src, int num, float *mean, float *square_mean) { + if (num >= 4 * BLOCK_NUM) { + SIMD_F32 sum_val = SIMD_SET0_F32; + SIMD_F32 square_sum_val = SIMD_SET0_F32; + for (int block_max_size = num - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 value = SIMD_LD_F32(src + index); + SIMD_F32 square_value = SIMD_MUL_F32(value, value); + sum_val = SIMD_ADD_F32(sum_val, value); + square_sum_val = SIMD_ADD_F32(square_sum_val, square_value); + } + *mean += SIMD_GET_SUM_F32(sum_val); + *square_mean += SIMD_GET_SUM_F32(square_sum_val); + } + return index; +} + +static inline int LayerNormGammaAndBeta@SIMD_INSTRUCTION@(int index, float *dst, const float *src, const float *gamma_data, + const float *beta_data, int num, const float mean, const float deno) { + SIMD_F32 mean_val = SIMD_MOV_F32(mean); + SIMD_F32 deno_val = SIMD_MOV_F32(deno); + for (int block_max_size = num - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 value = SIMD_LD_F32(src + index); + SIMD_F32 out_value = SIMD_SUB_F32(value, mean_val); + out_value = SIMD_MUL_F32(out_value, deno_val); + out_value = SIMD_FMADD_F32(out_value, SIMD_LD_F32(gamma_data + index), SIMD_LD_F32(beta_data + index)); + SIMD_ST_F32(dst + index, out_value); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/local_response_norm_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/local_response_norm_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..7fdc8baa25c982acfc732b2554fb48d51089b5ef --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/local_response_norm_fp32.c @@ -0,0 +1,71 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/local_response_norm_fp32.h" +#include +#include "nnacl/errorcode.h" + +int LocalResponseNorm(const float *input_ptr, int out_size, int channel, float *output_ptr, + const LocalResponseNormParameter *param) { + NNACL_CHECK_NULL_RETURN_ERR(input_ptr); + NNACL_CHECK_NULL_RETURN_ERR(output_ptr); + NNACL_CHECK_NULL_RETURN_ERR(param); + int64_t depth_radius = param->depth_radius_; + float bias = param->bias_; + float alpha = param->alpha_; + float beta = param->beta_; + + for (int i = 0; i < out_size; i++) { + const float *in_data = input_ptr + i * channel; + float *out_data = output_ptr + i * channel; + // border_left + for (int j = 0; j < MSMIN(depth_radius, channel); j++) { + int left = MSMAX(0, j - depth_radius); + int right = MSMIN(channel - 1, j + depth_radius); + float sum = 0.0f; + for (int k = left; k <= right; k++) { + const float in_val = in_data[k]; + sum += in_val * in_val; + } + out_data[j] = in_data[j] * (float)(powf(sum * alpha + bias, -beta)); + } + // center + if (2 * depth_radius + 1 < channel) { + float tmp_sum = 0.0f; + for (int j = 0; j < depth_radius * 2 + 1; ++j) { + tmp_sum += in_data[j] * in_data[j]; + } + out_data[depth_radius] = in_data[depth_radius] * (powf(tmp_sum * alpha + bias, -beta)); + for (int j = depth_radius + 1; j < channel - depth_radius; ++j) { + tmp_sum -= in_data[j - depth_radius - 1] * in_data[j - depth_radius - 1]; + tmp_sum += in_data[j + depth_radius] * in_data[j + depth_radius]; + out_data[j] = in_data[j] * (float)(powf(tmp_sum * alpha + bias, -beta)); + } + } + // border_right + for (int j = MSMAX(0, channel - depth_radius); j < channel; j++) { + int left = MSMAX(0, j - depth_radius); + int right = MSMIN(channel - 1, j + depth_radius); + float sum = 0.0f; + for (int k = left; k <= right; k++) { + const float in_val = in_data[k]; + sum += in_val * in_val; + } + out_data[j] = in_data[j] * (float)(powf(sum * alpha + bias, -beta)); + } + } + return 0; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/local_response_norm_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/local_response_norm_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..bd5a83c8e1439ace1e5fea44e16fe0c32450b8d8 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/local_response_norm_fp32.h @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_LOCAL_RESPONSE_NORM_FP32_H_ +#define NNACL_FP32_LOCAL_RESPONSE_NORM_FP32_H_ + +#include "nnacl/op_base.h" +#include "nnacl/local_response_norm_parameter.h" + +int LocalResponseNorm(const float *input_ptr, int out_size, int channel, float *output_ptr, + const LocalResponseNormParameter *param); + +#endif // NNACL_FP32_LOCAL_RESPONSE_NORM_FP32_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/log_softmax_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/log_softmax_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..758580158da5cd84659c4bc77dd05700820771ec --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/log_softmax_fp32.c @@ -0,0 +1,85 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp32/log_softmax_fp32.h" +#include +#include "nnacl/fp32/softmax_fp32.h" +#include "nnacl/fp32/exp_fp32.h" + +void LogSoftmaxLastAxis(const float *src, float *dst, float *exp_data, int batch, int channel) { + SoftmaxNorm(src, dst, batch, channel); + ExpFp32(dst, exp_data, batch * channel); + int cur_batch_offset = 0; + for (int i = 0; i < batch; i++, cur_batch_offset += channel) { + float sum = 0; + int j = 0; +#ifdef ENABLE_NEON + float32x4_t sum4 = vdupq_n_f32(0); + int count = (channel / C4NUM) * C4NUM; + for (; j < count; j += C4NUM) { + sum4 = vaddq_f32(sum4, vld1q_f32(exp_data + cur_batch_offset + j)); + } + sum = sum4[0] + sum4[1] + sum4[2] + sum4[3]; +#endif + for (; j < channel; j++) { + sum += exp_data[cur_batch_offset + j]; + } + for (int k = 0; k < channel; k++) { + dst[cur_batch_offset + k] = dst[cur_batch_offset + k] - logf(sum); + } + } +} + +// output = (input - reduce_max(input, axis)) - log(reduce_sum(exp(input - reduce_max(input, axis)), axis)) +void LogSoftmax(const float *input_ptr, float *output_ptr, float *sum_data, int32_t *input_shape, int n_dim, int axis) { + int inner_size = 1; + int outter_size = 1; + + for (int i = 0; i < axis; i++) { + outter_size *= input_shape[i]; + } + for (int i = axis + 1; i < n_dim; i++) { + inner_size *= input_shape[i]; + } + for (int i = 0; i < outter_size; i++) { + int outter_offset = i * input_shape[axis] * inner_size; + int sum_outter_offset = i * inner_size; + for (int k = 0; k < inner_size; k++) { + int inner_offset = outter_offset + k; + float max_data = input_ptr[inner_offset]; + sum_data[k + sum_outter_offset] = 0; + for (int j = 0; j < input_shape[axis]; j++) { + int axis_offset = inner_offset + j * inner_size; + max_data = max_data > input_ptr[axis_offset] ? max_data : input_ptr[axis_offset]; + } + for (int j = 0; j < input_shape[axis]; j++) { + int axis_offset = inner_offset + j * inner_size; + output_ptr[axis_offset] = input_ptr[axis_offset] - max_data; + sum_data[k + sum_outter_offset] += expf(output_ptr[axis_offset]); + } + } + } + for (int i = 0; i < outter_size; i++) { + int outter_offset = i * input_shape[axis] * inner_size; + int sum_outter_offset = i * inner_size; + for (int j = 0; j < input_shape[axis]; j++) { + int axis_offset = outter_offset + j * inner_size; + for (int k = 0; k < inner_size; k++) { + int inner_offset = axis_offset + k; + output_ptr[inner_offset] = output_ptr[inner_offset] - logf(sum_data[k + sum_outter_offset]); + } + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/log_softmax_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/log_softmax_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..824664755b67e5dcc9e1d4a822de2ac00ecee5a4 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/log_softmax_fp32.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_LOG_SOFTMAX_FP32_H_ +#define NNACL_FP32_LOG_SOFTMAX_FP32_H_ + +#include "nnacl/op_base.h" +#include "nnacl/softmax_parameter.h" +#ifdef __cplusplus +extern "C" { +#endif +void LogSoftmax(const float *input_ptr, float *output_ptr, float *sum_data, int32_t *input_shape, int n_dim, int axis); +void LogSoftmaxLastAxis(const float *src, float *dst, float *exp_data, int batch, int channel); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_LOG_SOFTMAX_FP32_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/lstm_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/lstm_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..1d19e5ba651d8b1564a605680deedce899a8ed81 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/lstm_fp32.c @@ -0,0 +1,328 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/lstm_fp32.h" +#include +#include +#include "nnacl/fp32/activation_fp32.h" +#include "nnacl/fp32/arithmetic_fp32.h" +#include "nnacl/fp32/matmul_fp32.h" +#include "nnacl/fp32/pack_fp32.h" + +static void PackLstmMatrix(const float *src_batch, float *dst_batch, int col, int deep) { +#ifdef ENABLE_AVX + RowMajor2Col16Major(src_batch, dst_batch, col, deep); +#elif defined(ENABLE_ARM32) + RowMajor2Col4Major(src_batch, dst_batch, col, deep); +#else + RowMajor2Col8Major(src_batch, dst_batch, col, deep); +#endif +} + +static void PackLstmWeightBatch(float *dst, const float *src, int batch, int deep, int col, int col_align, + const int32_t *order) { + for (int i = 0; i < batch; i++) { + const float *src_batch = src + i * col * deep; + float *dst_batch = dst + ((order == NULL) ? i : order[i]) * col_align * deep; + PackLstmMatrix(src_batch, dst_batch, col, deep); + } +} + +void PackLstmWeight(float *dst, const float *src, int batch, int deep, int col, int col_align, const int32_t *order) { + PackLstmWeightBatch(dst, src, batch, deep, col, col_align, order); +} + +void PackLstmWeightWithStride(float *dst, const float *src, int batch, int deep, int col, int col_align, + bool is_bidirectional, int stride, const int32_t *order) { + int unidirectional_batch = is_bidirectional ? batch / 2 : batch; + PackLstmWeightBatch(dst, src, unidirectional_batch, deep, col, col_align, order); + src += stride; + dst += unidirectional_batch * col_align * deep; + if (is_bidirectional) { + PackLstmWeightBatch(dst, src, unidirectional_batch, deep, col, col_align, order); + } +} + +void PackLstmBias(float *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional, + const int32_t *order) { + int unidirectional_batch = is_bidirectional ? batch / 2 : batch; + for (int i = 0; i < unidirectional_batch; i++) { + const float *src_batch = src + i * col; + float *dst_batch = dst + ((order == NULL) ? i : order[i]) * col_align; + (void)memcpy(dst_batch, src_batch, col * sizeof(float)); + } + if (is_bidirectional) { + const float *backward_src = src + batch * col; + float *backward_dst = dst + unidirectional_batch * col_align; + for (int i = 0; i < unidirectional_batch; i++) { + const float *backward_src_batch = backward_src + i * col; + float *backward_dst_batch = backward_dst + ((order == NULL) ? i : order[i]) * col_align; + (void)memcpy(backward_dst_batch, backward_src_batch, col * sizeof(float)); + } + } +} + +void PackLstmBiasWithStride(float *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional, + int b_stride, const int32_t *order) { + int unidirectional_batch = is_bidirectional ? batch / 2 : batch; + for (int i = 0; i < unidirectional_batch; i++) { + const float *src_batch = src + i * col; + float *dst_batch = dst + ((order == NULL) ? i : order[i]) * col_align; + (void)memcpy(dst_batch, src_batch, col * sizeof(float)); + } + if (is_bidirectional) { + const float *backward_src = src + b_stride; + float *backward_dst = dst + unidirectional_batch * col_align; + for (int i = 0; i < unidirectional_batch; i++) { + const float *backward_src_batch = backward_src + i * col; + float *backward_dst_batch = backward_dst + ((order == NULL) ? i : order[i]) * col_align; + (void)memcpy(backward_dst_batch, backward_src_batch, col * sizeof(float)); + } + } +} + +void PackLstmInput(const float *src, float *dst, int row, int deep) { +#ifdef ENABLE_AVX + RowMajor2Col6Major(src, dst, row, deep); +#elif defined(ENABLE_SSE) + RowMajor2Col4Major(src, dst, row, deep); +#else + RowMajor2Col12Major(src, dst, row, deep); +#endif +} + +void LstmMatMul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, int col_align, + bool is_vec, float *packed_ptr) { + if (is_vec) { +#ifdef ENABLE_AVX + bool need_packed = col % C8NUM; + if (!need_packed) { + packed_ptr = c; + } + MatVecMulAvxFp32(a, b, packed_ptr, bias, ActType_No, deep, col, col_align); + if (need_packed) { + PackNHWCXToNHWCFp32(packed_ptr, c, 1, row, col, C8NUM); + } +#else + MatVecMulFp32(a, b, c, bias, ActType_No, deep, col); +#endif + } else { + MatMulOpt(a, b, c, bias, ActType_No, deep, row, col, col, OutType_Nhwc); + } +} + +void ElementMulAcc(const float *input0, const float *input1, float *output, int element_size) { + int index = 0; +#ifdef ENABLE_ARM + for (; index <= element_size - 4; index += 4) { + float32x4_t in_0 = vld1q_f32(input0 + index); + float32x4_t in_1 = vld1q_f32(input1 + index); + float32x4_t out = vld1q_f32(output + index); + out = vmlaq_f32(out, in_1, in_0); + vst1q_f32(output + index, out); + } +#endif + for (; index < element_size; index++) { + output[index] += input0[index] * input1[index]; + } +} + +int ElementOptMulAcc(const float *input0, const float input1, float *output, const int element_size) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= element_size - 4; index += C4NUM) { + float32x4_t vin0 = vld1q_f32(input0 + index); + float32x4_t vout = vld1q_f32(output + index); + vout = vmlaq_n_f32(vout, vin0, input1); + vst1q_f32(output + index, vout); + } +#endif + for (; index < element_size; index++) { + output[index] += input0[index] * input1; + } + return NNACL_OK; +} + +void UpdateState(float *cell_state, const float *forget_gate, const float *input_gate, const float *cell_gate, + float *state_buffer, int batch, int hidden_size, const float zoneout) { + if (!(zoneout >= -FLT_EPSILON && zoneout <= FLT_EPSILON)) { // zoneout * old_cell_state + (void)memcpy(state_buffer, cell_state, batch * hidden_size * sizeof(float)); + ElementOptMul(state_buffer, &zoneout, state_buffer, batch * hidden_size, false); + } + + ElementMul(forget_gate, cell_state, cell_state, batch * hidden_size); + ElementMulAcc(input_gate, cell_gate, cell_state, batch * hidden_size); + + if (!(zoneout >= -FLT_EPSILON && zoneout <= FLT_EPSILON)) { // (1 - zoneout) * new_cell_state + ElementOptMulAcc(cell_state, 1 - zoneout, state_buffer, batch * hidden_size); + } +} + +void UpdateOutput(float *hidden_state, float *output, const float *cell_state, const float *output_gate, + const float *weight_project, float *buffer[C8NUM], const LstmParameter *lstm_param) { + int batch = lstm_param->batch_; + int hidden_size = lstm_param->hidden_size_; + int output_size = lstm_param->output_size_; + float *state_buffer = buffer[C4NUM]; + float *hidden_buffer = weight_project ? buffer[C2NUM] : hidden_state; + float zoneout = lstm_param->zoneout_hidden_; + if (!(zoneout >= -FLT_EPSILON && zoneout <= FLT_EPSILON)) { + (void)memcpy(state_buffer, hidden_state, batch * output_size * sizeof(float)); + ElementOptMul(state_buffer, &zoneout, state_buffer, batch * output_size, false); + } + + Tanh(cell_state, batch * hidden_size, hidden_buffer); + ElementMul(hidden_buffer, output_gate, hidden_buffer, batch * hidden_size); + + if (weight_project) { + float *left_matrix = hidden_buffer; + if (batch != 1) { + left_matrix = buffer[C6NUM]; + PackLstmInput(hidden_buffer, left_matrix, batch, hidden_size); + } + LstmMatMul(hidden_state, left_matrix, weight_project, NULL, batch, hidden_size, output_size, + lstm_param->proj_col_align_, batch == 1, buffer[C7NUM]); + } + if (!(zoneout >= -FLT_EPSILON && zoneout <= FLT_EPSILON)) { + ElementOptMulAcc(hidden_state, 1 - zoneout, state_buffer, batch * output_size); + } + (void)memcpy(output, hidden_state, batch * output_size * sizeof(float)); +} + +void UpdateLstmGate(float *gate_buffer, const float *input, const float *weight, const float *bias, int row, int deep, + int col, int col_align, bool is_vec, float *packed_ptr) { + const float *weight_i = weight; + const float *bias_i = bias; + float *gate_i = gate_buffer; + for (int i = 0; i < 4; i++) { + LstmMatMul(gate_i, input, weight_i, bias_i, row, deep, col, col_align, is_vec, packed_ptr); + +#ifdef ENABLE_AVX + weight_i += deep * col_align; +#else + weight_i += deep * (is_vec ? col : col_align); +#endif + bias_i += col_align; + gate_i += row * col; + } +} + +void LstmStepUnit(float *output, float *input_gate, float *forget_gate, float *cell_gate, float *output_gate, + const float *state_weight, const float *state_bias, const float *weight_project, float *hidden_state, + float *cell_state, float *buffer[C8NUM], const LstmParameter *lstm_param) { + float *packed_state = buffer[1]; + float *state_gate = buffer[C2NUM]; + float *cell_buffer = buffer[C3NUM]; + float *hidden_buffer = buffer[C4NUM]; + float *packed_output = buffer[C5NUM]; + bool is_vec = lstm_param->batch_ == 1; + // state * weight + if (is_vec) { + UpdateLstmGate(state_gate, hidden_state, state_weight, state_bias, lstm_param->batch_, lstm_param->output_size_, + lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec, packed_output); + } else { + // pack state for matmul + PackLstmInput(hidden_state, packed_state, lstm_param->batch_, lstm_param->output_size_); + UpdateLstmGate(state_gate, packed_state, state_weight, state_bias, lstm_param->batch_, lstm_param->output_size_, + lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec, packed_output); + } + ElementAdd(input_gate, state_gate, input_gate, lstm_param->batch_ * lstm_param->hidden_size_); + ElementAdd(forget_gate, state_gate + lstm_param->batch_ * lstm_param->hidden_size_ * 2, forget_gate, + lstm_param->batch_ * lstm_param->hidden_size_); + ElementAdd(cell_gate, state_gate + lstm_param->batch_ * lstm_param->hidden_size_ * 3, cell_gate, + lstm_param->batch_ * lstm_param->hidden_size_); + ElementAdd(output_gate, state_gate + lstm_param->batch_ * lstm_param->hidden_size_, output_gate, + lstm_param->batch_ * lstm_param->hidden_size_); + + // update input_gate + Sigmoid(input_gate, lstm_param->batch_ * lstm_param->hidden_size_, input_gate); + + // update forget_gate + Sigmoid(forget_gate, lstm_param->batch_ * lstm_param->hidden_size_, forget_gate); + + // update cell_gate + Tanh(cell_gate, lstm_param->batch_ * lstm_param->hidden_size_, cell_gate); + // update cell state + UpdateState(cell_state, forget_gate, input_gate, cell_gate, cell_buffer, lstm_param->batch_, lstm_param->hidden_size_, + lstm_param->zoneout_cell_); + + // update output_gate + Sigmoid(output_gate, lstm_param->batch_ * lstm_param->hidden_size_, output_gate); + // update output + UpdateOutput(hidden_state, output, cell_state, output_gate, weight_project, buffer, lstm_param); + + if (!(lstm_param->zoneout_cell_ >= -FLT_EPSILON && lstm_param->zoneout_cell_ <= FLT_EPSILON)) { + (void)memcpy(cell_state, cell_buffer, lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float)); + } + + if (!(lstm_param->zoneout_hidden_ >= -FLT_EPSILON && lstm_param->zoneout_hidden_ <= FLT_EPSILON)) { + (void)memcpy(hidden_state, hidden_buffer, lstm_param->batch_ * lstm_param->output_size_ * sizeof(float)); + } +} + +void LstmUnidirectional(float *output, const float *packed_input, const float *weight_i, const float *weight_h, + const float *input_bias, const float *state_bias, float *hidden_state, float *cell_state, + float *buffer[C8NUM], const LstmParameter *lstm_param, bool is_backward) { + float *gate = buffer[0]; + for (int i = 0; i < 4; i++) { + const float *weight_loop = weight_i + lstm_param->input_size_ * lstm_param->input_col_align_ * i; + const float *bias_loop = input_bias + lstm_param->input_col_align_ * i; + float *gate_loop = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_ * i; + MatMulOpt(packed_input, weight_loop, gate_loop, bias_loop, ActType_No, lstm_param->input_size_, + lstm_param->seq_len_ * lstm_param->batch_, lstm_param->hidden_size_, lstm_param->hidden_size_, + OutType_Nhwc); + } + + float *input_gate = gate; + float *forget_gate = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_ * 2; + float *cell_gate = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_ * 3; + float *output_gate = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_; + for (int t = 0; t < lstm_param->seq_len_; t++) { + int real_t = is_backward ? lstm_param->seq_len_ - t - 1 : t; + float *input_gate_t = input_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t; + float *forget_gate_t = forget_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t; + float *cell_gate_t = cell_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t; + float *output_gate_t = output_gate + lstm_param->batch_ * lstm_param->hidden_size_ * real_t; + float *output_ptr = output + real_t * lstm_param->output_step_; + LstmStepUnit(output_ptr, input_gate_t, forget_gate_t, cell_gate_t, output_gate_t, weight_h, state_bias, NULL, + hidden_state, cell_state, buffer, lstm_param); + } +} + +void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *input_bias, + const float *state_bias, float *hidden_state, float *cell_state, float *buffer[C9NUM], + const LstmParameter *lstm_param) { + // forward + float *packed_input = buffer[0]; + buffer += 1; + PackLstmInput(input, packed_input, lstm_param->seq_len_ * lstm_param->batch_, lstm_param->input_size_); + LstmUnidirectional(output, packed_input, weight_i, weight_h, input_bias, state_bias, hidden_state, cell_state, buffer, + lstm_param, false); + + // backward + if (lstm_param->bidirectional_) { + const float *backward_weight_i = weight_i + 4 * lstm_param->input_col_align_ * lstm_param->input_size_; + const float *backward_weight_h = weight_h + 4 * lstm_param->state_col_align_ * lstm_param->output_size_; + const float *backward_input_bias = input_bias + 4 * lstm_param->input_col_align_; + const float *backward_state_bias = state_bias + 4 * lstm_param->state_col_align_; + float *backward_output = output + lstm_param->batch_ * lstm_param->output_size_; + float *backward_cell_state = cell_state + lstm_param->batch_ * lstm_param->hidden_size_; + float *backward_hidden_state = hidden_state + lstm_param->batch_ * lstm_param->output_size_; + + LstmUnidirectional(backward_output, packed_input, backward_weight_i, backward_weight_h, backward_input_bias, + backward_state_bias, backward_hidden_state, backward_cell_state, buffer, lstm_param, true); + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/lstm_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/lstm_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..88dd9d16884cce355ef2a210b999ce4795b312fe --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/lstm_fp32.h @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_LSTM_H_ +#define MINDSPORE_NNACL_FP32_LSTM_H_ + +#include "nnacl/lstm_parameter.h" +#ifdef __cplusplus +extern "C" { +#endif +void PackLstmWeight(float *dst, const float *src, int batch, int deep, int col, int col_align, const int32_t *order); + +void PackLstmWeightWithStride(float *dst, const float *src, int batch, int deep, int col, int col_align, + bool is_bidirectional, int stride, const int32_t *order); + +void PackLstmBias(float *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional, + const int32_t *order); + +void PackLstmBiasWithStride(float *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional, + int b_stride, const int32_t *order); + +void PackLstmInput(const float *src, float *dst, int row, int deep); + +void LstmMatMul(float *c, const float *a, const float *b, const float *bias, int row, int deep, int col, int col_align, + bool is_vec, float *packed_ptr); + +void ElementMulAcc(const float *input0, const float *input1, float *output, int element_size); + +int ElementOptMulAcc(const float *input0, const float input1, float *output, const int element_size); + +void LstmStepUnit(float *output, float *input_gate, float *forget_gate, float *cell_gate, float *output_gate, + const float *state_weight, const float *state_bias, const float *weight_project, float *hidden_state, + float *cell_state, float *buffer[C8NUM], const LstmParameter *lstm_param); + +void Lstm(float *output, const float *input, const float *weight_i, const float *weight_h, const float *input_bias, + const float *state_bias, float *hidden_state, float *cell_state, float *buffer[C9NUM], + const LstmParameter *lstm_param); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_LSTM_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/matmul_avx512_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/matmul_avx512_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..08889011f61050da80e92bc84e57ed4f59b305dd --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/matmul_avx512_fp32.c @@ -0,0 +1,248 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp32/matmul_avx512_fp32.h" +#include "nnacl/op_base.h" +#include "nnacl/intrinsics/ms_simd_instructions.h" + +void GemmRowxColKernelFp32(float *dst, const float *src, const float *weight, const float *bias, const size_t act_flag, + const size_t row_block, const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag) { + __m512 dst_data[27]; + const float *src_sw[20]; + __m512 weight_data[6]; + for (int i = 0; i < C6NUM; ++i) { + weight_data[i] = _mm512_set1_ps(0.0f); + } + for (int i = 0; i < row_block; ++i) { + if (inc_flag & 0x01) { + for (int j = 0; j < col_block; ++j) { + dst_data[i * col_block + j] = _mm512_loadu_ps(dst + i * dst_stride + j * C16NUM); + } + } else if (bias != NULL) { + for (int j = 0; j < col_block; ++j) { + dst_data[i * col_block + j] = _mm512_loadu_ps(bias + j * C16NUM); + } + } else { + for (int j = 0; j < col_block; ++j) { + dst_data[i * col_block + j] = _mm512_set1_ps(0.0f); + } + } + src_sw[i] = src + i * src_stride; + } + const float *weight_kernel = weight; + for (int k = 0; k < depth; ++k) { + for (int j = 0; j < col_block; ++j) { + weight_data[j] = _mm512_loadu_ps(weight_kernel + j * C16NUM); + } + for (int i = 0; i < row_block; ++i) { + for (int j = 0; j < col_block; ++j) { + dst_data[i * col_block + j] = + _mm512_fmadd_ps(_mm512_set1_ps(src_sw[i][k]), weight_data[j], dst_data[i * col_block + j]); + } + } + weight_kernel += C16NUM * col_block; + } // k loop + // add bias and relu + for (int i = 0; i < row_block; ++i) { + for (int j = 0; j < col_block; ++j) { + if (inc_flag & 0x02) { + if (0x1 & act_flag) { // relu6 + dst_data[i * col_block + j] = _mm512_min_ps(dst_data[i * col_block + j], _mm512_set1_ps(6.0f)); + } + if (0x2 & act_flag) { // relu + dst_data[i * col_block + j] = _mm512_max_ps(dst_data[i * col_block + j], _mm512_set1_ps(0.0f)); + } + } + _mm512_storeu_ps(dst + i * dst_stride + j * C16NUM, dst_data[i * col_block + j]); + } + } +} + +void MatMulAvx512Fp32(const float *a, const float *b, float *c, const float *bias, const int act_type, const int depth, + const int cur_col, const int col_align, const int row) { + int k_block = C1500NUM; + int act_flag = 0; + if (act_type == ActType_Relu6) { + act_flag += 1; + } + if (act_type == ActType_Relu || act_type == ActType_Relu6) { + act_flag += C2NUM; + } + GemmAvx512Kernel kernel[C4NUM][C13NUM]; + int max_shape[C4NUM] = {C12NUM, C12NUM, C8NUM, C6NUM}; + +#ifdef ENABLE_DEBUG + for (int i = 0; i < C4NUM; i++) { + for (int j = 0; j < C13NUM; j++) { + kernel[i][j] = GemmRowxColKernelFp32; + } + } +#else + kernel[0][1] = nnacl_gemm_avx512_1x16_kernel_nhwc_fp32; + kernel[0][2] = nnacl_gemm_avx512_2x16_kernel_nhwc_fp32; + kernel[0][3] = nnacl_gemm_avx512_3x16_kernel_nhwc_fp32; + kernel[0][4] = nnacl_gemm_avx512_4x16_kernel_nhwc_fp32; + kernel[0][5] = nnacl_gemm_avx512_5x16_kernel_nhwc_fp32; + kernel[0][6] = nnacl_gemm_avx512_6x16_kernel_nhwc_fp32; + kernel[0][7] = nnacl_gemm_avx512_7x16_kernel_nhwc_fp32; + kernel[0][8] = nnacl_gemm_avx512_8x16_kernel_nhwc_fp32; + kernel[0][9] = nnacl_gemm_avx512_9x16_kernel_nhwc_fp32; + kernel[0][10] = nnacl_gemm_avx512_10x16_kernel_nhwc_fp32; + kernel[0][11] = nnacl_gemm_avx512_11x16_kernel_nhwc_fp32; + kernel[0][12] = nnacl_gemm_avx512_12x16_kernel_nhwc_fp32; + + kernel[1][1] = nnacl_gemm_avx512_1x32_kernel_nhwc_fp32; + kernel[1][2] = nnacl_gemm_avx512_2x32_kernel_nhwc_fp32; + kernel[1][3] = nnacl_gemm_avx512_3x32_kernel_nhwc_fp32; + kernel[1][4] = nnacl_gemm_avx512_4x32_kernel_nhwc_fp32; + kernel[1][5] = nnacl_gemm_avx512_5x32_kernel_nhwc_fp32; + kernel[1][6] = nnacl_gemm_avx512_6x32_kernel_nhwc_fp32; + kernel[1][7] = nnacl_gemm_avx512_7x32_kernel_nhwc_fp32; + kernel[1][8] = nnacl_gemm_avx512_8x32_kernel_nhwc_fp32; + kernel[1][9] = nnacl_gemm_avx512_9x32_kernel_nhwc_fp32; + kernel[1][10] = nnacl_gemm_avx512_10x32_kernel_nhwc_fp32; + kernel[1][11] = nnacl_gemm_avx512_11x32_kernel_nhwc_fp32; + kernel[1][12] = nnacl_gemm_avx512_12x32_kernel_nhwc_fp32; + + kernel[2][1] = nnacl_gemm_avx512_1x48_kernel_nhwc_fp32; + kernel[2][2] = nnacl_gemm_avx512_2x48_kernel_nhwc_fp32; + kernel[2][3] = nnacl_gemm_avx512_3x48_kernel_nhwc_fp32; + kernel[2][4] = nnacl_gemm_avx512_4x48_kernel_nhwc_fp32; + kernel[2][5] = nnacl_gemm_avx512_5x48_kernel_nhwc_fp32; + kernel[2][6] = nnacl_gemm_avx512_6x48_kernel_nhwc_fp32; + kernel[2][7] = nnacl_gemm_avx512_7x48_kernel_nhwc_fp32; + kernel[2][8] = nnacl_gemm_avx512_8x48_kernel_nhwc_fp32; + + kernel[3][1] = nnacl_gemm_avx512_1x64_kernel_nhwc_fp32; + kernel[3][2] = nnacl_gemm_avx512_2x64_kernel_nhwc_fp32; + kernel[3][3] = nnacl_gemm_avx512_3x64_kernel_nhwc_fp32; + kernel[3][4] = nnacl_gemm_avx512_4x64_kernel_nhwc_fp32; + kernel[3][5] = nnacl_gemm_avx512_5x64_kernel_nhwc_fp32; + kernel[3][6] = nnacl_gemm_avx512_6x64_kernel_nhwc_fp32; +#endif + + int inc_flag; + for (int k = 0; k < depth; k += k_block) { + if (depth - k <= k_block) { + k_block = depth - k; + inc_flag = C3NUM - (k == 0); + } else { + inc_flag = 1 - (k == 0); + } + const float *bias_data = bias; + // one time process 64 out_channel + int col_block = C64NUM; + for (int col_index = 0; col_index < cur_col; col_index += col_block) { + col_block = MSMIN(col_block, cur_col - col_index); + int row_block = max_shape[(col_block >> C4NUM) - 1]; + for (int m = 0; m < row; m += row_block) { + row_block = MSMIN(row_block, row - m); + kernel[(col_block >> C4NUM) - 1][row_block](c + col_index + m * col_align, a + m * depth + k, + b + col_index * depth + k * col_block, bias_data, act_flag, + row_block, col_block >> C4NUM, k_block, depth, col_align, inc_flag); + } + if (bias_data != NULL) { + bias_data += col_block; + } + } + } +} + +void MatVecMulAvx512Fp32(const float *a, const float *b, float *c, const float *bias, const int act_type, + const int depth, const int cur_col, const int col_align) { + // one time process 64 out_channel + int k_block = C1500NUM; + int act_flag = 0; + if (act_type == ActType_Relu6) { + act_flag += 1; + } + if (act_type == ActType_Relu || act_type == ActType_Relu6) { + act_flag += C2NUM; + } +#ifdef ENABLE_DEBUG + GemmAvx512Kernel kernel[C4NUM] = {GemmRowxColKernelFp32, GemmRowxColKernelFp32, GemmRowxColKernelFp32, + GemmRowxColKernelFp32}; +#else + GemmAvx512Kernel kernel[C4NUM] = {nnacl_gemm_avx512_1x16_kernel_nhwc_fp32, nnacl_gemm_avx512_1x32_kernel_nhwc_fp32, + nnacl_gemm_avx512_1x48_kernel_nhwc_fp32, nnacl_gemm_avx512_1x64_kernel_nhwc_fp32}; +#endif + int inc_flag; + for (int k = 0; k < depth; k += k_block) { + if (depth - k <= k_block) { + k_block = depth - k; + inc_flag = C3NUM - (k == 0); + } else { + inc_flag = 1 - (k == 0); + } + const float *bias_data = bias; + int col_block = C64NUM; + for (int col_index = 0; col_index < cur_col; col_index += col_block) { + col_block = MSMIN(col_block, cur_col - col_index); + kernel[(col_block >> C4NUM) - 1](c + col_index, a + k, b + col_index * depth + k * col_block, bias_data, act_flag, + 1, col_block >> C4NUM, k_block, depth, col_align, inc_flag); + if (bias_data != NULL) { + bias_data += col_block; + } + } + } +} + +// act_type must be 0, 1, 2. 0: no_act, 1: relu, 3: relu6. +int64_t GemmIsNotPackOptimizeAVX512(int64_t m_index, const float *a, const float *b, float *c, const float *bias, int m, + int k, int act_type) { + // gemm dot is [m, k] * [k, 1] ==>> [m, 1] + // block 8 + MS_FLOAT32X8 down_threshold256 = _mm256_setzero_ps(); + MS_FLOAT32X8 up_threshold256 = _mm256_set1_ps(C6NUM); + for (; m_index <= m - C8NUM; m_index += C8NUM) { + int k_index = 0; + MS_FLOAT32X8 dst = MS_MOV256_F32(bias[0]); + MS_SET_ZERO512X8_F32(dst16_) + for (; k_index <= k - C16NUM; k_index += C16NUM) { + __m512 weight = _mm512_loadu_ps(b + k_index); + MS_LOAD512X8_F32(src, a + m_index * k + k_index, k) + MS_FMADD512X8_F32(src, weight, dst16_) + } + MS_F32X8_GETI(dst, 0) += MS_REDUCE_ADD512_F32(dst16_1); + MS_F32X8_GETI(dst, 1) += MS_REDUCE_ADD512_F32(dst16_2); + MS_F32X8_GETI(dst, C2NUM) += MS_REDUCE_ADD512_F32(dst16_3); + MS_F32X8_GETI(dst, C3NUM) += MS_REDUCE_ADD512_F32(dst16_4); + MS_F32X8_GETI(dst, C4NUM) += MS_REDUCE_ADD512_F32(dst16_5); + MS_F32X8_GETI(dst, C5NUM) += MS_REDUCE_ADD512_F32(dst16_6); + MS_F32X8_GETI(dst, C6NUM) += MS_REDUCE_ADD512_F32(dst16_7); + MS_F32X8_GETI(dst, C7NUM) += MS_REDUCE_ADD512_F32(dst16_8); + for (; k_index < k; k_index++) { + MS_F32X8_GETI(dst, 0) += b[k_index] * a[m_index * k + k_index]; + MS_F32X8_GETI(dst, 1) += b[k_index] * a[m_index * k + k_index + k]; + MS_F32X8_GETI(dst, C2NUM) += b[k_index] * a[m_index * k + k_index + C2NUM * k]; + MS_F32X8_GETI(dst, C3NUM) += b[k_index] * a[m_index * k + k_index + C3NUM * k]; + MS_F32X8_GETI(dst, C4NUM) += b[k_index] * a[m_index * k + k_index + C4NUM * k]; + MS_F32X8_GETI(dst, C5NUM) += b[k_index] * a[m_index * k + k_index + C5NUM * k]; + MS_F32X8_GETI(dst, C6NUM) += b[k_index] * a[m_index * k + k_index + C6NUM * k]; + MS_F32X8_GETI(dst, C7NUM) += b[k_index] * a[m_index * k + k_index + C7NUM * k]; + } + + if (act_type != 0) { + dst = MS_MAX256_F32(dst, down_threshold256); + if (act_type == 3) { // 3: relu6 + dst = MS_MIN256_F32(dst, up_threshold256); + } + } + + MS_ST256_F32(c + m_index, dst); + } + return m_index; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/matmul_avx512_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/matmul_avx512_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..565304dc42d5ff576f63e8965ae2ac30a48e01a8 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/matmul_avx512_fp32.h @@ -0,0 +1,198 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_MATMUL_AVX512_H_ +#define MINDSPORE_NNACL_FP32_MATMUL_AVX512_H_ +#include "nnacl/op_base.h" +typedef void (*GemmAvx512Kernel)(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +#ifdef __cplusplus +extern "C" { +#endif +void MatVecMulAvx512Fp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, + int cur_col, int col_align); + +void MatMulAvx512Fp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int cur_col, + int col_align, int row); + +int64_t GemmIsNotPackOptimizeAVX512(int64_t m_index, const float *a, const float *b, float *c, const float *bias, int m, + int k, int act_type); + +// 64 block +void nnacl_gemm_avx512_6x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_5x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_4x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_3x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_2x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_1x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); + +// 48 block +void nnacl_gemm_avx512_8x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_7x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_6x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_5x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_4x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_3x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_2x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_1x48_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); + +// 32 block +void nnacl_gemm_avx512_12x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_11x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_10x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_9x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_8x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_7x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_6x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_5x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_4x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_3x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_2x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_1x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); + +// 16 block +void nnacl_gemm_avx512_12x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_11x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_10x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_9x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_8x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_7x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_6x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_5x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_4x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_3x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_2x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +void nnacl_gemm_avx512_1x16_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_MATMUL_AVX512_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/matmul_avx512_mask_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/matmul_avx512_mask_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..6eec5a54d6693f35217aefc2f0f69db9fc494815 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/matmul_avx512_mask_fp32.c @@ -0,0 +1,236 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" +#include "nnacl/op_base.h" +#include "nnacl/intrinsics/ms_simd_instructions.h" + +void GemmRowxColMaskKernelFp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask) { + __m512 dst_data[27]; + const float *src_sw[20]; + __m512 weight_data[6]; + __mmask16 mask16 = (__mmask16)(*mask); + for (int i = 0; i < C6NUM; ++i) { + weight_data[i] = _mm512_set1_ps(0.0f); + } + for (int i = 0; i < row_block; ++i) { + if (inc_flag & 0x01) { + for (int j = 0; j < col_block; ++j) { + dst_data[i * col_block + j] = _mm512_loadu_ps(dst + i * dst_stride + j * C16NUM); + } + } else if (bias != NULL) { + for (int j = 0; j < col_block; ++j) { + dst_data[i * col_block + j] = _mm512_loadu_ps(bias + j * C16NUM); + } + } else { + for (int j = 0; j < col_block; ++j) { + dst_data[i * col_block + j] = _mm512_set1_ps(0.0f); + } + } + src_sw[i] = src + i * src_stride; + } + const float *weight_kernel = weight; + for (int k = 0; k < depth; ++k) { + for (int j = 0; j < col_block; ++j) { + weight_data[j] = _mm512_loadu_ps(weight_kernel + j * C16NUM); + } + for (int i = 0; i < row_block; ++i) { + for (int j = 0; j < col_block; ++j) { + if (j == col_block - 1) { + dst_data[i * col_block + j] = + _mm512_mask3_fmadd_ps(_mm512_set1_ps(src_sw[i][k]), weight_data[j], dst_data[i * col_block + j], mask16); + } else { + dst_data[i * col_block + j] = + _mm512_fmadd_ps(_mm512_set1_ps(src_sw[i][k]), weight_data[j], dst_data[i * col_block + j]); + } + } + } + weight_kernel += C16NUM * col_block; + } // k loop + // add bias and relu + for (int i = 0; i < row_block; ++i) { + for (int j = 0; j < col_block; ++j) { + if (j == col_block - 1) { + if (inc_flag & 0x02) { + if (0x1 & act_flag) { // relu6 + dst_data[i * col_block + j] = + _mm512_maskz_min_ps(mask16, dst_data[i * col_block + j], _mm512_set1_ps(6.0f)); + } + if (0x2 & act_flag) { // relu + dst_data[i * col_block + j] = + _mm512_maskz_max_ps(mask16, dst_data[i * col_block + j], _mm512_set1_ps(0.0f)); + } + } + _mm512_mask_storeu_ps(dst + i * dst_stride + j * C16NUM, mask16, dst_data[i * col_block + j]); + } else { + if (inc_flag & 0x02) { + if (0x1 & act_flag) { // relu6 + dst_data[i * col_block + j] = _mm512_min_ps(dst_data[i * col_block + j], _mm512_set1_ps(6.0f)); + } + if (0x2 & act_flag) { // relu + dst_data[i * col_block + j] = _mm512_max_ps(dst_data[i * col_block + j], _mm512_set1_ps(0.0f)); + } + } + _mm512_storeu_ps(dst + i * dst_stride + j * C16NUM, dst_data[i * col_block + j]); + } + } + } +} + +void MatMulMaskAvx512Fp32(const float *a, const float *b, float *c, const float *bias, const int act_type, + const int depth, const int cur_col, const int col_, const int row) { + int k_block = C1500NUM; + int act_flag = 0; + if (act_type == ActType_Relu6) { + act_flag += 1; + } + if (act_type == ActType_Relu || act_type == ActType_Relu6) { + act_flag += C2NUM; + } + GemmAvx512MaskKernel kernel[C4NUM][C13NUM]; + int max_shape[C4NUM] = {C12NUM, C12NUM, C8NUM, C6NUM}; + +#ifdef ENABLE_DEBUG + for (int i = 0; i < C4NUM; i++) { + for (int j = 0; j < C13NUM; j++) { + kernel[i][j] = GemmRowxColMaskKernelFp32; + } + } +#else + kernel[0][1] = nnacl_gemm_avx512_1x16_mask_kernel_nhwc_fp32; + kernel[0][2] = nnacl_gemm_avx512_2x16_mask_kernel_nhwc_fp32; + kernel[0][3] = nnacl_gemm_avx512_3x16_mask_kernel_nhwc_fp32; + kernel[0][4] = nnacl_gemm_avx512_4x16_mask_kernel_nhwc_fp32; + kernel[0][5] = nnacl_gemm_avx512_5x16_mask_kernel_nhwc_fp32; + kernel[0][6] = nnacl_gemm_avx512_6x16_mask_kernel_nhwc_fp32; + kernel[0][7] = nnacl_gemm_avx512_7x16_mask_kernel_nhwc_fp32; + kernel[0][8] = nnacl_gemm_avx512_8x16_mask_kernel_nhwc_fp32; + kernel[0][9] = nnacl_gemm_avx512_9x16_mask_kernel_nhwc_fp32; + kernel[0][10] = nnacl_gemm_avx512_10x16_mask_kernel_nhwc_fp32; + kernel[0][11] = nnacl_gemm_avx512_11x16_mask_kernel_nhwc_fp32; + kernel[0][12] = nnacl_gemm_avx512_12x16_mask_kernel_nhwc_fp32; + + kernel[1][1] = nnacl_gemm_avx512_1x32_mask_kernel_nhwc_fp32; + kernel[1][2] = nnacl_gemm_avx512_2x32_mask_kernel_nhwc_fp32; + kernel[1][3] = nnacl_gemm_avx512_3x32_mask_kernel_nhwc_fp32; + kernel[1][4] = nnacl_gemm_avx512_4x32_mask_kernel_nhwc_fp32; + kernel[1][5] = nnacl_gemm_avx512_5x32_mask_kernel_nhwc_fp32; + kernel[1][6] = nnacl_gemm_avx512_6x32_mask_kernel_nhwc_fp32; + kernel[1][7] = nnacl_gemm_avx512_7x32_mask_kernel_nhwc_fp32; + kernel[1][8] = nnacl_gemm_avx512_8x32_mask_kernel_nhwc_fp32; + kernel[1][9] = nnacl_gemm_avx512_9x32_mask_kernel_nhwc_fp32; + kernel[1][10] = nnacl_gemm_avx512_10x32_mask_kernel_nhwc_fp32; + kernel[1][11] = nnacl_gemm_avx512_11x32_mask_kernel_nhwc_fp32; + kernel[1][12] = nnacl_gemm_avx512_12x32_mask_kernel_nhwc_fp32; + + kernel[2][1] = nnacl_gemm_avx512_1x48_mask_kernel_nhwc_fp32; + kernel[2][2] = nnacl_gemm_avx512_2x48_mask_kernel_nhwc_fp32; + kernel[2][3] = nnacl_gemm_avx512_3x48_mask_kernel_nhwc_fp32; + kernel[2][4] = nnacl_gemm_avx512_4x48_mask_kernel_nhwc_fp32; + kernel[2][5] = nnacl_gemm_avx512_5x48_mask_kernel_nhwc_fp32; + kernel[2][6] = nnacl_gemm_avx512_6x48_mask_kernel_nhwc_fp32; + kernel[2][7] = nnacl_gemm_avx512_7x48_mask_kernel_nhwc_fp32; + kernel[2][8] = nnacl_gemm_avx512_8x48_mask_kernel_nhwc_fp32; + + kernel[3][1] = nnacl_gemm_avx512_1x64_mask_kernel_nhwc_fp32; + kernel[3][2] = nnacl_gemm_avx512_2x64_mask_kernel_nhwc_fp32; + kernel[3][3] = nnacl_gemm_avx512_3x64_mask_kernel_nhwc_fp32; + kernel[3][4] = nnacl_gemm_avx512_4x64_mask_kernel_nhwc_fp32; + kernel[3][5] = nnacl_gemm_avx512_5x64_mask_kernel_nhwc_fp32; + kernel[3][6] = nnacl_gemm_avx512_6x64_mask_kernel_nhwc_fp32; +#endif + + int inc_flag; + for (int k = 0; k < depth; k += k_block) { + if (depth - k <= k_block) { + k_block = depth - k; + inc_flag = C3NUM - (k == 0); + } else { + inc_flag = 1 - (k == 0); + } + const float *bias_data = bias; + // one time process 64 out_channel + int col_block = C64NUM; + u_int16_t avx512_mask = 0xFFFF; + for (int col_index = 0; col_index < cur_col; col_index += col_block) { + int less_tmp = cur_col - col_index; + if (less_tmp < col_block) { + col_block = UP_ROUND(less_tmp, C16NUM); + avx512_mask = (0xFFFF >> (col_block - less_tmp)); + } + int col_block_num = col_block >> C4NUM; + int row_block = max_shape[col_block_num - 1]; + for (int m = 0; m < row; m += row_block) { + row_block = MSMIN(row_block, row - m); + kernel[col_block_num - 1][row_block](c + col_index + m * col_, a + m * depth + k, + b + col_index * depth + k * col_block, bias_data, act_flag, row_block, + col_block_num, k_block, depth, col_, inc_flag, &avx512_mask); + } + if (bias_data != NULL) { + bias_data += col_block; + } + } + } +} + +void MatVecMulMaskAvx512Fp32(const float *a, const float *b, float *c, const float *bias, const int act_type, + const int depth, const int cur_col, const int col_) { + // one time process 64 out_channel + int k_block = C1500NUM; + int act_flag = 0; + if (act_type == ActType_Relu6) { + act_flag += 1; + } + if (act_type == ActType_Relu || act_type == ActType_Relu6) { + act_flag += C2NUM; + } +#ifdef ENABLE_DEBUG + GemmAvx512MaskKernel kernel[C4NUM] = {GemmRowxColMaskKernelFp32, GemmRowxColMaskKernelFp32, GemmRowxColMaskKernelFp32, + GemmRowxColMaskKernelFp32}; +#else + GemmAvx512MaskKernel kernel[C4NUM] = { + nnacl_gemm_avx512_1x16_mask_kernel_nhwc_fp32, nnacl_gemm_avx512_1x32_mask_kernel_nhwc_fp32, + nnacl_gemm_avx512_1x48_mask_kernel_nhwc_fp32, nnacl_gemm_avx512_1x64_mask_kernel_nhwc_fp32}; +#endif + int inc_flag; + for (int k = 0; k < depth; k += k_block) { + if (depth - k <= k_block) { + k_block = depth - k; + inc_flag = C3NUM - (k == 0); + } else { + inc_flag = 1 - (k == 0); + } + const float *bias_data = bias; + int col_block = C64NUM; + u_int16_t avx512_mask = 0xFFFF; + for (int col_index = 0; col_index < cur_col; col_index += col_block) { + int less_tmp = cur_col - col_index; + if (less_tmp < col_block) { + col_block = UP_ROUND(less_tmp, C16NUM); + avx512_mask = (0xFFFF >> (col_block - less_tmp)); + } + int col_block_num = col_block >> C4NUM; + + kernel[col_block_num - 1](c + col_index, a + k, b + col_index * depth + k * col_block, bias_data, act_flag, 1, + col_block_num, k_block, depth, col_, inc_flag, &avx512_mask); + if (bias_data != NULL) { + bias_data += col_block; + } + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/matmul_avx512_mask_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/matmul_avx512_mask_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..a24b04fe48e2ccf7d5fffa85917bb46ea035882a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/matmul_avx512_mask_fp32.h @@ -0,0 +1,209 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_MATMUL_MASK_AVX512_H_ +#define MINDSPORE_NNACL_FP32_MATMUL_MASK_AVX512_H_ +#include +#include +#include "nnacl/op_base.h" +typedef void (*GemmAvx512MaskKernel)(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +#ifdef __cplusplus +extern "C" { +#endif + +void GemmRowxColMaskKernelFp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); + +void MatVecMulMaskAvx512Fp32(const float *a, const float *b, float *c, const float *bias, const int act_type, + const int depth, const int cur_col, const int col_); + +void MatMulMaskAvx512Fp32(const float *a, const float *b, float *c, const float *bias, const int act_type, + const int depth, const int cur_col, const int col_, const int row); + +// 64 block +void nnacl_gemm_avx512_6x64_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_5x64_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_4x64_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_3x64_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_2x64_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_1x64_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); + +// 48 block +void nnacl_gemm_avx512_8x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_7x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_6x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_5x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_4x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_3x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_2x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_1x48_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); + +// 32 block +void nnacl_gemm_avx512_12x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag, + const u_int16_t *mask); +void nnacl_gemm_avx512_11x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag, + const u_int16_t *mask); +void nnacl_gemm_avx512_10x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag, + const u_int16_t *mask); +void nnacl_gemm_avx512_9x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_8x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_7x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_6x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_5x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_4x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_3x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_2x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_1x32_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); + +// 16 block +void nnacl_gemm_avx512_12x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag, + const u_int16_t *mask); +void nnacl_gemm_avx512_11x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag, + const u_int16_t *mask); +void nnacl_gemm_avx512_10x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t depth, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag, + const u_int16_t *mask); +void nnacl_gemm_avx512_9x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_8x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_7x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_6x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_5x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_4x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_3x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_2x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +void nnacl_gemm_avx512_1x16_mask_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t depth, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag, const u_int16_t *mask); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_MATMUL_AVX512_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/matmul_avx_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/matmul_avx_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..40b1417137dfef310a07275a6c079026c97f429e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/matmul_avx_fp32.c @@ -0,0 +1,954 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp32/matmul_avx_fp32.h" +#include "nnacl/intrinsics/ms_simd_avx_instructions.h" + +void MatVecMulAvxFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int cur_col, + int col_align) { + // one time process 32 out_channel + int col_block = C32NUM; + int act_flag = C0NUM; + if (act_type == ActType_Relu6) { + act_flag += C1NUM; + } + if (act_type == ActType_Relu || act_type == ActType_Relu6) { + act_flag += C2NUM; + } + MatVecMulKernel kernel[4] = {MatVecMul1x8Kernel, MatVecMul1x16Kernel, MatVecMul1x24Kernel, MatVecMul1x32Kernel}; + const float *bias_data = bias; + for (int col_index = 0; col_index < cur_col; col_index += col_block) { + col_block = cur_col - col_index < col_block ? cur_col - col_index : col_block; + kernel[(col_block >> C3NUM) - 1](c + col_index, a, b + col_index * depth, bias_data, act_flag, 1, + col_block >> C3NUM, col_align, depth); + if (bias_data != NULL) { + bias_data += col_block; + } + } +} + +void MatMulAvxFp32(const float *a, const float *b, float *c, const float *bias, const int act_type, const int depth, + const int cur_col, const int col_align, const int row) { + // one time process 32 out_channel + int col_block = C32NUM; + int act_flag = 0; + if (act_type == ActType_Relu6) { + act_flag += 1; + } + if (act_type == ActType_Relu || act_type == ActType_Relu6) { + act_flag += C2NUM; + } + int row_tile[4] = {C8NUM, C6NUM, C4NUM, C3NUM}; + MatVecMulKernel kernel[4][2] = {{MatVecMul1x8Kernel, MatMul8x8Kernel}, + {MatVecMul1x16Kernel, MatMul6x16Kernel}, + {MatVecMul1x24Kernel, MatMul4x24Kernel}, + {MatVecMul1x32Kernel, MatMul3x32Kernel}}; + const float *bias_data = bias; + for (int col_index = 0; col_index < cur_col; col_index += col_block) { + col_block = cur_col - col_index < col_block ? cur_col - col_index : col_block; + int row_block = row_tile[(col_block >> C3NUM) - 1]; + for (int r = 0; r < row; r += row_block) { + if (row_block > row - r) { + row_block = 1; + } + kernel[(col_block >> C3NUM) - 1][row_block / row_tile[(col_block >> C3NUM) - 1]]( + c + col_index + r * col_align, a + r * depth, b + col_index * depth, bias_data, act_flag, row_block, + col_block >> C3NUM, col_align, depth); + } + if (bias_data != NULL) { + bias_data += col_block; + } + } +} + +void MatMul3x32Kernel(float *dst, const float *src, const float *weight, const float *bias, const size_t act_flag, + const size_t row_block, const size_t col_block, size_t col_algin, const size_t deep) { + col_algin *= sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "vmovups 0x60(%2), %%ymm3\n" + "vmovups (%2), %%ymm4\n" + "vmovups 0x20(%2), %%ymm5\n" + "vmovups 0x40(%2), %%ymm6\n" + "vmovups 0x60(%2), %%ymm7\n" + "vmovups (%2), %%ymm8\n" + "vmovups 0x20(%2), %%ymm9\n" + "vmovups 0x40(%2), %%ymm10\n" + "vmovups 0x60(%2), %%ymm11\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + + "1:\n" // deep + "vbroadcastss (%0), %%ymm12\n" // src + "vbroadcastss (%0, %7), %%ymm13\n" + "vbroadcastss (%0, %7, 2), %%ymm14\n" + "vmovups (%1), %%ymm15\n" // weight + "vfmadd231ps %%ymm15, %%ymm12, %%ymm0\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm4\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm8\n" + + "vmovups 0x20(%1), %%ymm15\n" // weight + "vfmadd231ps %%ymm15, %%ymm12, %%ymm1\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm5\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm9\n" + + "vmovups 0x40(%1), %%ymm15\n" // weight + "vfmadd231ps %%ymm15, %%ymm12, %%ymm2\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm6\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm10\n" + + "vmovups 0x60(%1), %%ymm15\n" // weight + "vfmadd231ps %%ymm15, %%ymm12, %%ymm3\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm7\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm11\n" + "addq $128, %1\n" + "addq $4, %0\n" + "dec %3\n" + "jg 1b\n" + + "and $0x3, %%eax\n" // act_type + "je 6f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm8, %%ymm8\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + "vmaxps %%ymm12, %%ymm11, %%ymm11\n" + "and $0x1, %%eax\n" + "je 6f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm8, %%ymm8\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + "vminps %%ymm14, %%ymm11, %%ymm11\n" + "6:\n" + "vmovups %%ymm0, (%5)\n" // dst_0 + "vmovups %%ymm1, 0x20(%5)\n" + "vmovups %%ymm2, 0x40(%5)\n" + "vmovups %%ymm3, 0x60(%5)\n" + "vmovups %%ymm4, (%5, %6)\n" // dst_1 + "vmovups %%ymm5, 0x20(%5, %6)\n" + "vmovups %%ymm6, 0x40(%5, %6)\n" + "vmovups %%ymm7, 0x60(%5, %6)\n" + "vmovups %%ymm8, (%5, %6, 2)\n" // dst_2 + "vmovups %%ymm9, 0x20(%5, %6, 2)\n" + "vmovups %%ymm10, 0x40(%5, %6, 2)\n" + "vmovups %%ymm11, 0x60(%5, %6, 2)\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst), "r"(col_algin), + "r"(deep * sizeof(float)) // 7 + : "%rcx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} + +void MatVecMul1x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t row_block, size_t col_block, size_t col_algin, size_t deep) { + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "vmovups 0x60(%2), %%ymm3\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "1:\n" // deep_c8 + "movq %3, %%rcx\n" + "shr $3, %%ecx\n" + "je 3f\n" + "2:\n" + "vbroadcastss (%0), %%ymm4\n" + "vfmadd231ps (%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 0x40(%1), %%ymm4, %%ymm2\n" + "vfmadd231ps 0x60(%1), %%ymm4, %%ymm3\n" + + "vbroadcastss 4(%0), %%ymm4\n" + "vfmadd231ps 128(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 160(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 192(%1), %%ymm4, %%ymm2\n" + "vfmadd231ps 224(%1), %%ymm4, %%ymm3\n" + + "vbroadcastss 8(%0), %%ymm4\n" + "vfmadd231ps 256(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 288(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 320(%1), %%ymm4, %%ymm2\n" + "vfmadd231ps 352(%1), %%ymm4, %%ymm3\n" + + "vbroadcastss 12(%0), %%ymm4\n" + "vfmadd231ps 384(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 416(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 448(%1), %%ymm4, %%ymm2\n" + "vfmadd231ps 480(%1), %%ymm4, %%ymm3\n" + + "vbroadcastss 16(%0), %%ymm4\n" + "vfmadd231ps 512(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 544(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 576(%1), %%ymm4, %%ymm2\n" + "vfmadd231ps 608(%1), %%ymm4, %%ymm3\n" + + "vbroadcastss 20(%0), %%ymm4\n" + "vfmadd231ps 640(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 672(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 704(%1), %%ymm4, %%ymm2\n" + "vfmadd231ps 736(%1), %%ymm4, %%ymm3\n" + + "vbroadcastss 24(%0), %%ymm4\n" + "vfmadd231ps 768(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 800(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 832(%1), %%ymm4, %%ymm2\n" + "vfmadd231ps 864(%1), %%ymm4, %%ymm3\n" + + "vbroadcastss 28(%0), %%ymm4\n" + "vfmadd231ps 896(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 928(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 960(%1), %%ymm4, %%ymm2\n" + "vfmadd231ps 992(%1), %%ymm4, %%ymm3\n" + "addq $1024, %1\n" + "addq $32, %0\n" + "dec %%ecx\n" + "jg 2b\n" + + "3:\n" + "and $7, %3\n" // deep_remainder + "je 5f\n" + "4:\n" + "vbroadcastss (%0), %%ymm4\n" + "vfmadd231ps (%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 0x40(%1), %%ymm4, %%ymm2\n" + "vfmadd231ps 0x60(%1), %%ymm4, %%ymm3\n" + "addq $128, %1\n" + "addq $4, %0\n" + "dec %3\n" + "jg 4b\n" + + "5:\n" + "and $0x3, %%eax\n" // act_type + "je 6f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "and $0x1, %%eax\n" + "je 6f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "6:\n" + "vmovups %%ymm0, (%5)\n" // dst_0 + "vmovups %%ymm1, 0x20(%5)\n" + "vmovups %%ymm2, 0x40(%5)\n" + "vmovups %%ymm3, 0x60(%5)\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst) // 5 + : "%rcx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm12", "%ymm4", "%ymm14"); +} + +void MatMul4x24Kernel(float *dst, const float *src, const float *weight, const float *bias, const size_t act_flag, + const size_t row_block, const size_t col_block, size_t col_algin, const size_t deep) { + float *dst_3 = dst + C3NUM * col_algin; + col_algin *= sizeof(float); + size_t src_3_step = C3NUM * deep * sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "vmovups (%2), %%ymm3\n" + "vmovups 0x20(%2), %%ymm4\n" + "vmovups 0x40(%2), %%ymm5\n" + "vmovups (%2), %%ymm6\n" + "vmovups 0x20(%2), %%ymm7\n" + "vmovups 0x40(%2), %%ymm8\n" + "vmovups (%2), %%ymm9\n" + "vmovups 0x20(%2), %%ymm10\n" + "vmovups 0x40(%2), %%ymm11\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + + "1:\n" // deep + "vmovups (%1), %%ymm12\n" // weight + "vmovups 0x20(%1), %%ymm13\n" + "vmovups 0x40(%1), %%ymm14\n" + + "vbroadcastss (%0), %%ymm15\n" // src + "vfmadd231ps %%ymm15, %%ymm12, %%ymm0\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm2\n" + + "vbroadcastss (%0, %9), %%ymm15\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm3\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm4\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm5\n" + + "vbroadcastss (%0, %9, 2), %%ymm15\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm6\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm7\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm8\n" + + "vbroadcastss (%0, %7), %%ymm15\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm9\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm10\n" + "vfmadd231ps %%ymm15, %%ymm14, %%ymm11\n" + "addq $96, %1\n" + "addq $4, %0\n" + "dec %3\n" + "jg 1b\n" + + "and $0x3, %%eax\n" // act_type + "je 6f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm8, %%ymm8\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + "vmaxps %%ymm12, %%ymm11, %%ymm11\n" + "and $0x1, %%eax\n" + "je 6f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm8, %%ymm8\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + "vminps %%ymm14, %%ymm11, %%ymm11\n" + "6:\n" + "vmovups %%ymm0, (%5)\n" // dst_0 + "vmovups %%ymm1, 0x20(%5)\n" + "vmovups %%ymm2, 0x40(%5)\n" + "vmovups %%ymm3, (%5, %6)\n" + "vmovups %%ymm4, 0x20(%5, %6)\n" // dst_1 + "vmovups %%ymm5, 0x40(%5, %6)\n" + "vmovups %%ymm6, (%5, %6, 2)\n" + "vmovups %%ymm7, 0x20(%5, %6, 2)\n" + "vmovups %%ymm8, 0x40(%5, %6, 2)\n" // dst_2 + "vmovups %%ymm9, (%8)\n" + "vmovups %%ymm10, 0x20(%8)\n" + "vmovups %%ymm11, 0x40(%8)\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst), "r"(col_algin), "r"(src_3_step), "r"(dst_3), + "r"(deep * sizeof(float)) // 9 + : "%rcx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} + +void MatVecMul1x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t row_block, size_t col_block, size_t col_algin, size_t deep) { + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups 0x40(%2), %%ymm2\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + + "1:\n" // deep + "movq %3, %%rcx\n" + "shr $3, %%ecx\n" + "je 3f\n" + "2:\n" + "vbroadcastss (%0), %%ymm4\n" + "vfmadd231ps (%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 0x40(%1), %%ymm4, %%ymm2\n" + + "vbroadcastss 4(%0), %%ymm4\n" + "vfmadd231ps 96(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 128(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 160(%1), %%ymm4, %%ymm2\n" + + "vbroadcastss 8(%0), %%ymm4\n" + "vfmadd231ps 192(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 224(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 256(%1), %%ymm4, %%ymm2\n" + + "vbroadcastss 12(%0), %%ymm4\n" + "vfmadd231ps 288(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 320(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 352(%1), %%ymm4, %%ymm2\n" + + "vbroadcastss 16(%0), %%ymm4\n" + "vfmadd231ps 384(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 416(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 448(%1), %%ymm4, %%ymm2\n" + + "vbroadcastss 20(%0), %%ymm4\n" + "vfmadd231ps 480(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 512(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 544(%1), %%ymm4, %%ymm2\n" + + "vbroadcastss 24(%0), %%ymm4\n" + "vfmadd231ps 576(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 608(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 640(%1), %%ymm4, %%ymm2\n" + + "vbroadcastss 28(%0), %%ymm4\n" + "vfmadd231ps 672(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 704(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 736(%1), %%ymm4, %%ymm2\n" + "addq $768, %1\n" + "addq $32, %0\n" + "dec %%ecx\n" + "jg 2b\n" + + "3:\n" + "and $7, %3\n" // deep_remainder + "je 5f\n" + "4:\n" + "vbroadcastss (%0), %%ymm4\n" + "vfmadd231ps (%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n" + "vfmadd231ps 0x40(%1), %%ymm4, %%ymm2\n" + "addq $96, %1\n" + "addq $4, %0\n" + "dec %3\n" + "jg 4b\n" + + "5:\n" + "and $0x3, %%eax\n" // act_type + "je 6f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + + "and $0x1, %%eax\n" + "je 6f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + + "6:\n" + "vmovups %%ymm0, (%5)\n" // dst_0 + "vmovups %%ymm1, 0x20(%5)\n" + "vmovups %%ymm2, 0x40(%5)\n" + + : + : "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst) // 5 + : "%rcx", "%ymm0", "%ymm1", "%ymm2", "%ymm12", "%ymm4", "%ymm14"); +} + +void MatMul6x16Kernel(float *dst, const float *src, const float *weight, const float *bias, const size_t act_flag, + const size_t row_block, const size_t col_block, size_t col_algin, const size_t deep) { + float *dst_3 = dst + 3 * col_algin; + float *dst_5 = dst + 5 * col_algin; + col_algin *= sizeof(float); + size_t src_3_step = 3 * deep * sizeof(float); + size_t src_5_step = 5 * deep * sizeof(float); + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "vmovups (%2), %%ymm2\n" + "vmovups 0x20(%2), %%ymm3\n" + "vmovups (%2), %%ymm4\n" + "vmovups 0x20(%2), %%ymm5\n" + "vmovups (%2), %%ymm6\n" + "vmovups 0x20(%2), %%ymm7\n" + "vmovups (%2), %%ymm8\n" + "vmovups 0x20(%2), %%ymm9\n" + "vmovups (%2), %%ymm10\n" + "vmovups 0x20(%2), %%ymm11\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + + "1:\n" // deep + "vmovups (%1), %%ymm12\n" // weight + "vmovups 0x20(%1), %%ymm13\n" + + "vbroadcastss (%0), %%ymm14\n" // src_0 + "vbroadcastss (%0, %11), %%ymm15\n" // src_1 + "vfmadd231ps %%ymm14, %%ymm12, %%ymm0\n" + "vfmadd231ps %%ymm14, %%ymm13, %%ymm1\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm2\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm3\n" + + "vbroadcastss (%0, %11, 2), %%ymm14\n" // src_2 + "vbroadcastss (%0, %8), %%ymm15\n" // src_3 + "vfmadd231ps %%ymm14, %%ymm12, %%ymm4\n" + "vfmadd231ps %%ymm14, %%ymm13, %%ymm5\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm6\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm7\n" + + "vbroadcastss (%0, %11, 4), %%ymm14\n" // src_4 + "vbroadcastss (%0, %9), %%ymm15\n" // src_5 + "vfmadd231ps %%ymm14, %%ymm12, %%ymm8\n" + "vfmadd231ps %%ymm14, %%ymm13, %%ymm9\n" + "vfmadd231ps %%ymm15, %%ymm12, %%ymm10\n" + "vfmadd231ps %%ymm15, %%ymm13, %%ymm11\n" + + "addq $64, %1\n" + "addq $4, %0\n" + "dec %3\n" + "jg 1b\n" + + "and $0x3, %%eax\n" // act_type + "je 6f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "vmaxps %%ymm12, %%ymm8, %%ymm8\n" + "vmaxps %%ymm12, %%ymm9, %%ymm9\n" + "vmaxps %%ymm12, %%ymm10, %%ymm10\n" + "vmaxps %%ymm12, %%ymm11, %%ymm11\n" + "and $0x1, %%eax\n" + "je 6f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "vminps %%ymm14, %%ymm8, %%ymm8\n" + "vminps %%ymm14, %%ymm9, %%ymm9\n" + "vminps %%ymm14, %%ymm10, %%ymm10\n" + "vminps %%ymm14, %%ymm11, %%ymm11\n" + "6:\n" + "vmovups %%ymm0, (%5)\n" // dst_0 + "vmovups %%ymm1, 0x20(%5)\n" + "vmovups %%ymm2, (%5, %6)\n" // dst_1 + "vmovups %%ymm3, 0x20(%5, %6)\n" + "vmovups %%ymm4, (%5, %6, 2)\n" // dst_2 + "vmovups %%ymm5, 0x20(%5, %6, 2)\n" + "vmovups %%ymm6, (%7)\n" // dst_3 + "vmovups %%ymm7, 0x20(%7)\n" + "vmovups %%ymm8, (%5, %6, 4)\n" // dst_4 + "vmovups %%ymm9, 0x20(%5, %6, 4)\n" + "vmovups %%ymm10, (%10)\n" // dst_5 + "vmovups %%ymm11, 0x20(%10)\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst), "r"(col_algin), "r"(dst_3), "r"(src_3_step), + "r"(src_5_step), "r"(dst_5), "r"(deep * sizeof(float)) // 11 + : "%rcx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} + +void MatVecMul1x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t row_block, size_t col_block, size_t col_algin, size_t deep) { + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups 0x20(%2), %%ymm1\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "1:\n" + "movq %3, %%rcx\n" + "shr $3, %%ecx\n" + "je 3f\n" + "2:\n" // deep_c8 + "vbroadcastss (%0), %%ymm4\n" + "vfmadd231ps (%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n" + + "vbroadcastss 4(%0), %%ymm4\n" + "vfmadd231ps 64(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 96(%1), %%ymm4, %%ymm1\n" + + "vbroadcastss 8(%0), %%ymm4\n" + "vfmadd231ps 128(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 160(%1), %%ymm4, %%ymm1\n" + + "vbroadcastss 12(%0), %%ymm4\n" + "vfmadd231ps 192(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 224(%1), %%ymm4, %%ymm1\n" + + "vbroadcastss 16(%0), %%ymm4\n" + "vfmadd231ps 256(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 288(%1), %%ymm4, %%ymm1\n" + + "vbroadcastss 20(%0), %%ymm4\n" + "vfmadd231ps 320(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 352(%1), %%ymm4, %%ymm1\n" + + "vbroadcastss 24(%0), %%ymm4\n" + "vfmadd231ps 384(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 416(%1), %%ymm4, %%ymm1\n" + + "vbroadcastss 28(%0), %%ymm4\n" + "vfmadd231ps 448(%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 480(%1), %%ymm4, %%ymm1\n" + "addq $512, %1\n" + "addq $32, %0\n" + "dec %%ecx\n" + "jg 2b\n" + + "3:\n" + "and $7, %3\n" + "je 5f\n" + "4:\n" + "vbroadcastss (%0), %%ymm4\n" + "vfmadd231ps (%1), %%ymm4, %%ymm0\n" + "vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n" + "addq $64, %1\n" + "addq $4, %0\n" + "dec %3\n" + "jg 4b\n" + + "5:\n" + "and $0x3, %%eax\n" // act_type + "je 6f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + + "and $0x1, %%eax\n" + "je 6f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + + "6:\n" + "vmovups %%ymm0, (%5)\n" // dst_0 + "vmovups %%ymm1, 0x20(%5)\n" + + : + : "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst) // 5 + : "%rcx", "%ymm0", "%ymm1", "%ymm12", "%ymm4", "%ymm14"); +} + +void MatVecMul1x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t row_block, size_t col_block, size_t col_algin, size_t deep) { + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "1:\n" + "movq %3, %%rcx\n" + "shr $3, %%ecx\n" + "je 3f\n" + "2:\n" // deep_c8 + "vbroadcastss (%0), %%ymm4\n" + "vfmadd231ps (%1), %%ymm4, %%ymm0\n" + "vbroadcastss 4(%0), %%ymm4\n" + "vfmadd231ps 32(%1), %%ymm4, %%ymm0\n" + "vbroadcastss 8(%0), %%ymm4\n" + "vfmadd231ps 64(%1), %%ymm4, %%ymm0\n" + "vbroadcastss 12(%0), %%ymm4\n" + "vfmadd231ps 96(%1), %%ymm4, %%ymm0\n" + "vbroadcastss 16(%0), %%ymm4\n" + "vfmadd231ps 128(%1), %%ymm4, %%ymm0\n" + "vbroadcastss 20(%0), %%ymm4\n" + "vfmadd231ps 160(%1), %%ymm4, %%ymm0\n" + "vbroadcastss 24(%0), %%ymm4\n" + "vfmadd231ps 192(%1), %%ymm4, %%ymm0\n" + "vbroadcastss 28(%0), %%ymm4\n" + "vfmadd231ps 224(%1), %%ymm4, %%ymm0\n" + "addq $256, %1\n" + "addq $32, %0\n" + "dec %%ecx\n" + "jg 2b\n" + + "3:\n" + "and $7, %3\n" + "je 5f\n" + "4:\n" + "vbroadcastss (%0), %%ymm4\n" + "vfmadd231ps (%1), %%ymm4, %%ymm0\n" + "addq $32, %1\n" + "addq $4, %0\n" + "dec %3\n" + "jg 4b\n" + + "5:\n" + "and $0x3, %%eax\n" // act_type + "je 6f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + + "and $0x1, %%eax\n" + "je 6f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + + "6:\n" + "vmovups %%ymm0, (%5)\n" // dst_0 + + : + : "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst) // 5 + : "%rcx", "%ymm0", "%ymm1", "%ymm12", "%ymm4", "%ymm14"); +} + +void MatMul8x8Kernel(float *dst, const float *src, const float *weight, const float *bias, const size_t act_flag, + const size_t row_block, const size_t col_block, size_t col_algin, const size_t deep) { + float *dst_5 = dst + C5NUM * col_algin; + col_algin *= sizeof(float); + size_t dst_3_step = C3NUM * col_algin; + size_t src_3_step = C3NUM * deep * sizeof(float); + const float *src_5 = C5NUM * deep + src; + asm volatile( + "cmpq $0, %2\n" + "je 0f\n" + "vmovups (%2), %%ymm0\n" + "vmovups (%2), %%ymm1\n" + "vmovups (%2), %%ymm2\n" + "vmovups (%2), %%ymm3\n" + "vmovups (%2), %%ymm4\n" + "vmovups (%2), %%ymm5\n" + "vmovups (%2), %%ymm6\n" + "vmovups (%2), %%ymm7\n" + "jmp 1f\n" + "0:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + + "1:\n" // deep + "vmovups (%1), %%ymm15\n" // weight + + "vbroadcastss (%0), %%ymm8\n" // src_0 + "vbroadcastss (%0, %11), %%ymm9\n" // src_1 + "vbroadcastss (%0, %11, 2), %%ymm10\n" // src_2 + "vbroadcastss (%0, %8), %%ymm11\n" // src_3 + "vfmadd231ps %%ymm8, %%ymm15, %%ymm0\n" + "vfmadd231ps %%ymm9, %%ymm15, %%ymm1\n" + "vfmadd231ps %%ymm10, %%ymm15, %%ymm2\n" + "vfmadd231ps %%ymm11, %%ymm15, %%ymm3\n" + + "vbroadcastss (%0, %11, 4), %%ymm8\n" // src_4 + "vbroadcastss (%9), %%ymm9\n" // src_5 + "vbroadcastss (%9, %11, 1), %%ymm10\n" // src_6 + "vbroadcastss (%9, %11, 2), %%ymm11\n" // src_7 + "vfmadd231ps %%ymm8, %%ymm15, %%ymm4\n" + "vfmadd231ps %%ymm9, %%ymm15, %%ymm5\n" + "vfmadd231ps %%ymm10, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm11, %%ymm15, %%ymm7\n" + + "addq $32, %1\n" + "addq $4, %0\n" + "addq $4, %9\n" + "dec %3\n" + "jg 1b\n" + + "and $0x3, %%eax\n" // act_type + "je 6f\n" + // Relu + "vxorps %%ymm12, %%ymm12, %%ymm12\n" + "vmaxps %%ymm12, %%ymm0, %%ymm0\n" + "vmaxps %%ymm12, %%ymm1, %%ymm1\n" + "vmaxps %%ymm12, %%ymm2, %%ymm2\n" + "vmaxps %%ymm12, %%ymm3, %%ymm3\n" + "vmaxps %%ymm12, %%ymm4, %%ymm4\n" + "vmaxps %%ymm12, %%ymm5, %%ymm5\n" + "vmaxps %%ymm12, %%ymm6, %%ymm6\n" + "vmaxps %%ymm12, %%ymm7, %%ymm7\n" + "and $0x1, %%eax\n" + "je 6f\n" + // relu6 + "mov $0x40C00000, %%ecx\n" + "vmovd %%ecx, %%xmm14\n" + "vpermps %%ymm14, %%ymm12, %%ymm14\n" + "vminps %%ymm14, %%ymm0, %%ymm0\n" + "vminps %%ymm14, %%ymm1, %%ymm1\n" + "vminps %%ymm14, %%ymm2, %%ymm2\n" + "vminps %%ymm14, %%ymm3, %%ymm3\n" + "vminps %%ymm14, %%ymm4, %%ymm4\n" + "vminps %%ymm14, %%ymm5, %%ymm5\n" + "vminps %%ymm14, %%ymm6, %%ymm6\n" + "vminps %%ymm14, %%ymm7, %%ymm7\n" + "6:\n" + "vmovups %%ymm0, (%5)\n" // dst_0 + "vmovups %%ymm1, (%5, %6)\n" + "vmovups %%ymm2, (%5, %6, 2)\n" + "vmovups %%ymm3, (%5, %7)\n" + "vmovups %%ymm4, (%5, %6, 4)\n" + "vmovups %%ymm5, (%10)\n" + "vmovups %%ymm6, (%10, %6)\n" + "vmovups %%ymm7, (%10, %6, 2)\n" + : + : "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst), "r"(col_algin), "r"(dst_3_step), // 7 + "r"(src_3_step), "r"(src_5), "r"(dst_5), "r"(deep * sizeof(float)) // 11 + : "%rcx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} + +#ifdef ENABLE_DEBUG +void MatVecMulRowxColKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t row_block, size_t col_block, size_t col_algin, size_t deep) { + __m256 dst_data[12]; + const float *src_sw[12]; + __m256 weight_data[4]; + for (int i = 0; i < C4NUM; ++i) { + weight_data[i] = _mm256_set1_ps(0.0f); + } + for (int i = 0; i < row_block; ++i) { + if (bias != NULL) { + for (int j = 0; j < col_block; ++j) { + dst_data[i * col_block + j] = _mm256_loadu_ps(bias + j * C8NUM); + } + } else { + for (int j = 0; j < col_block; ++j) { + dst_data[i * col_block + j] = _mm256_set1_ps(0.0f); + } + } + src_sw[i] = src + i * deep; + } + const float *weight_kernel = weight; + for (int ic = 0; ic < deep; ++ic) { + for (int j = 0; j < col_block; ++j) { + weight_data[j] = _mm256_loadu_ps(weight_kernel + j * C8NUM); + } + for (int i = 0; i < row_block; ++i) { + for (int j = 0; j < col_block; ++j) { + dst_data[i * col_block + j] = + _mm256_fmadd_ps(_mm256_set1_ps(src_sw[i][ic]), weight_data[j], dst_data[i * col_block + j]); + } + } + weight_kernel += C8NUM * col_block; + } // ic loop + // add bias and relu + for (int i = 0; i < row_block; ++i) { + for (int j = 0; j < col_block; ++j) { + if (0x1 & act_flag) { // relu6 + dst_data[i * col_block + j] = _mm256_min_ps(dst_data[i * col_block + j], _mm256_set1_ps(6.0f)); + } + if (0x2 & act_flag) { // relu + dst_data[i * col_block + j] = _mm256_max_ps(dst_data[i * col_block + j], _mm256_set1_ps(0.0f)); + } + _mm256_storeu_ps(dst + i * col_algin + j * C8NUM, dst_data[i * col_block + j]); + } + } +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/matmul_avx_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/matmul_avx_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..0501a0599e45b4e66c71b8335e16bcabde8248b4 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/matmul_avx_fp32.h @@ -0,0 +1,63 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_MATMUL_AVX_H_ +#define MINDSPORE_NNACL_FP32_MATMUL_AVX_H_ + +#include "nnacl/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef void (*DeconvAvxKernel)(const float *src, const float *weight, float *dst, int col, int row, int depth, + int stride); +void DeconvMatmulAvx(const float *a, const float *b, float *c, int depth, int row, int col, int kernel_plane); +void MatmulFloatAvxOpt(const float *a, const float *b, float *c, const float *bias, size_t act_type, size_t depth, + size_t row, size_t col, size_t stride, size_t write_mode); +typedef void (*MatVecMulKernel)(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t row_block, size_t col_block, size_t col_algin, size_t deep); +void MatVecMulAvxFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int cur_col, + int col_align); +void MatMulAvxFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int cur_col, + int col_align, int row); +void MatVecMul1x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t row_block, size_t col_block, size_t col_algin, size_t deep); +void MatVecMul1x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t row_block, size_t col_block, size_t col_algin, size_t deep); +void MatVecMul1x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t row_block, size_t col_block, size_t col_algin, size_t deep); +void MatVecMul1x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t row_block, size_t col_block, size_t col_algin, size_t deep); +void MatMul3x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t row_block, size_t col_block, size_t col_algin, size_t deep); +void MatMul4x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t row_block, size_t col_block, size_t col_algin, size_t deep); +void MatMul6x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t row_block, size_t col_block, size_t col_algin, size_t deep); +void MatMul8x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t row_block, size_t col_block, size_t col_algin, size_t deep); +#ifdef ENABLE_DEBUG +void DeconvColXRowAvxKernel(const float *src, const float *weight, float *dst, int col, int row, int depth, int stride); + +void MatVecMulRowxColKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag, + size_t row_block, size_t col_block, size_t col_algin, size_t deep); +#endif + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FP32_MATMUL_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/matmul_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/matmul_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..1898ffd453b2ba357f45f3792c63d481ac8ebf78 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/matmul_fp32.c @@ -0,0 +1,822 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/matmul_fp32.h" +#include "nnacl/fp32/pack_fp32.h" +#include "nnacl/fp32/matmul_avx512_fp32.h" +#include "nnacl/matmul_fp32_simd.h" + +#ifndef ENABLE_ARM +void MatVecMulFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col) { + for (int ci = 0; ci < col; ci++) { + float value = 0; + for (int di = 0; di < depth; di++) { + value += a[di] * b[ci * depth + di]; + } + if (bias != NULL) value += bias[ci]; + if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); + if (act_type == ActType_Relu || act_type == ActType_Relu6) value = MSMAX(0.0f, value); + c[ci] = value; + } +} +#endif + +void MatVecMulFp32Block8(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, + int col) { + int col8 = col / C8NUM * C8NUM; + int ci = 0; + for (; ci < col8; ci += C8NUM, c += C8NUM) { +#ifdef ENABLE_NEON + float32x4_t value0 = vdupq_n_f32(0.0f); + float32x4_t value1 = vdupq_n_f32(0.0f); + for (int di = 0; di < depth; ++di, b += C8NUM) { + value0 += vdupq_n_f32(a[di]) * vld1q_f32(b); + value1 += vdupq_n_f32(a[di]) * vld1q_f32(b + C4NUM); + } + if (bias != NULL) { + value0 += vld1q_f32(bias + ci); + value1 += vld1q_f32(bias + ci + C4NUM); + } + if (act_type == ActType_Relu || act_type == ActType_Relu6) { + value0 = vmaxq_f32(value0, vdupq_n_f32(0.0f)); + value1 = vmaxq_f32(value1, vdupq_n_f32(0.0f)); + } + if (act_type == ActType_Relu6) { + value0 = vminq_f32(value0, vdupq_n_f32(6.0f)); + value1 = vminq_f32(value1, vdupq_n_f32(6.0f)); + } + vst1q_f32(c, value0); + vst1q_f32(c + 4, value1); +#else + float value[C8NUM] = {0}; + for (int di = 0; di < depth; ++di, b += C8NUM) { + for (int j = 0; j < C8NUM; ++j) { + value[j] += a[di] * b[j]; + } + } + for (int j = 0; j < C8NUM; ++j) { + ADD_BIAS(value[j], bias, ci + j); + DO_RELU(value[j], act_type); + DO_RELU6(value[j], act_type); + } + memcpy(c, value, C8NUM * sizeof(float)); +#endif + } + int res = col - col8; + float value[C8NUM] = {0}; + for (int di = 0; di < depth; ++di, b += C8NUM) { + for (int j = 0; j < res; ++j) { + value[j] += a[di] * b[j]; + } + } + for (int j = 0; j < res; ++j) { + ADD_BIAS(value[j], bias, ci + j); + DO_RELU(value[j], act_type); + DO_RELU6(value[j], act_type); + } + memcpy(c, value, res * sizeof(float)); +} + +#ifdef ENABLE_ARM32 +void MatVecMulFp32Block4(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, + int col) { + int col4 = col / C4NUM * C4NUM; + int ci = 0; + for (; ci < col4; ci += C4NUM, c += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t value = vdupq_n_f32(0.0f); + for (int di = 0; di < depth; ++di, b += C4NUM) { + value += vdupq_n_f32(a[di]) * vld1q_f32(b); + } + if (bias != NULL) { + value += vld1q_f32(&(bias[ci])); + } + if (act_type == ActType_Relu || act_type == ActType_Relu6) { + value = vmaxq_f32(value, vdupq_n_f32(0.0f)); + } + if (act_type == ActType_Relu6) { + value = vminq_f32(value, vdupq_n_f32(6.0f)); + } + vst1q_f32(c, value); +#else + float value[C4NUM] = {0}; + for (int di = 0; di < depth; ++di, b += C4NUM) { + for (int j = 0; j < C4NUM; ++j) { + value[j] += a[di] * b[j]; + } + } + for (int j = 0; j < C4NUM; ++j) { + ADD_BIAS(value[j], bias, ci + j); + DO_RELU(value[j], act_type); + DO_RELU6(value[j], act_type); + } + memcpy(c, value, C4NUM * sizeof(float)); +#endif + } + int res = col - col4; + float value[C4NUM] = {0}; + for (int di = 0; di < depth; ++di, b += C4NUM) { + for (int j = 0; j < res; ++j) { + value[j] += a[di] * b[j]; + } + } + for (int j = 0; j < res; ++j) { + ADD_BIAS(value[j], bias, ci + j); + DO_RELU(value[j], act_type); + DO_RELU6(value[j], act_type); + } + memcpy(c, value, res * sizeof(float)); +} +#endif + +#ifdef ENABLE_ARM64 +// 4x8 +void MatVecMulFp32Neon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col, + int align_col) { + int ci = 0; + for (; ci < align_col - C8NUM + 1; ci += C8NUM) { + float32x4_t acc_0; + float32x4_t acc_1; + if (bias != NULL) { + acc_0 = vld1q_f32(bias + ci); + acc_1 = vld1q_f32(bias + ci + C4NUM); + } else { + acc_0 = vdupq_n_f32(0.0f); + acc_1 = vdupq_n_f32(0.0f); + } + const float *bv_base = b + ci * depth; + int di = 0; + for (; di < depth - C4NUM + 1; di += C4NUM) { + float32x4_t av = vld1q_f32(a + di); + float32x4_t bv_00 = vld1q_f32(bv_base); + float32x4_t bv_10 = vld1q_f32(bv_base + C4NUM); + bv_base += C8NUM; + float32x4_t bv_01 = vld1q_f32(bv_base); + float32x4_t bv_11 = vld1q_f32(bv_base + C4NUM); + bv_base += C8NUM; + float32x4_t bv_02 = vld1q_f32(bv_base); + float32x4_t bv_12 = vld1q_f32(bv_base + C4NUM); + bv_base += C8NUM; + float32x4_t bv_03 = vld1q_f32(bv_base); + float32x4_t bv_13 = vld1q_f32(bv_base + C4NUM); + bv_base += C8NUM; + acc_0 = vmlaq_n_f32(acc_0, bv_00, av[0]); + acc_1 = vmlaq_n_f32(acc_1, bv_10, av[0]); + acc_0 = vmlaq_n_f32(acc_0, bv_01, av[1]); + acc_1 = vmlaq_n_f32(acc_1, bv_11, av[1]); + acc_0 = vmlaq_n_f32(acc_0, bv_02, av[2]); + acc_1 = vmlaq_n_f32(acc_1, bv_12, av[2]); + acc_0 = vmlaq_n_f32(acc_0, bv_03, av[3]); + acc_1 = vmlaq_n_f32(acc_1, bv_13, av[3]); + } + if (di < depth) { + for (; di < depth; ++di) { + float ai = a[di]; + float32x4_t bv0 = vld1q_f32(bv_base); + float32x4_t bv1 = vld1q_f32(bv_base + C4NUM); + acc_0 = vmlaq_n_f32(acc_0, bv0, ai); + acc_1 = vmlaq_n_f32(acc_1, bv1, ai); + bv_base += C8NUM; + } + } // only save actual col num data + if (ci + C4NUM - 1 >= col) { + int c_remain = col - ci; + for (int i = 0; i < c_remain; ++i) { + if (act_type == ActType_Relu) { + c[i] = MSMAX(acc_0[i], 0.0f); + } else if (act_type == ActType_Relu6) { + c[i] = MSMIN(MSMAX(acc_0[i], 0.0f), 6.0f); + } else { + c[i] = acc_0[i]; + } + } + return; + } + if (act_type == ActType_Relu) { + acc_0 = vmaxq_f32(acc_0, vdupq_n_f32(0.0f)); + } else if (act_type == ActType_Relu6) { + acc_0 = vminq_f32(vmaxq_f32(acc_0, vdupq_n_f32(0.0f)), vdupq_n_f32(6.0f)); + } + vst1q_f32(c, acc_0); + if (ci + C8NUM - 1 >= col) { + int c_remain = col - ci - C4NUM; + for (int i = 0; i < c_remain; ++i) { + if (act_type == ActType_Relu) { + c[C4NUM + i] = MSMAX(acc_1[i], 0.0f); + } else if (act_type == ActType_Relu6) { + c[C4NUM + i] = MSMIN(MSMAX(acc_1[i], 0.0f), 6.0f); + } else { + c[C4NUM + i] = acc_1[i]; + } + } + return; + } + if (act_type == ActType_Relu) { + acc_1 = vmaxq_f32(acc_1, vdupq_n_f32(0.0f)); + } else if (act_type == ActType_Relu6) { + acc_1 = vminq_f32(vmaxq_f32(acc_1, vdupq_n_f32(0.0f)), vdupq_n_f32(6.0f)); + } + vst1q_f32(c + C4NUM, acc_1); + c += C8NUM; + } +} +#endif + +void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row, + int col, int stride, int out_type) { + if (out_type == OutType_Nhwc) { + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int r12div = r / 12, r12mod = r % 12; + int c8div = c / 8, c8mod = c % 8; + size_t ci = r * stride + c; + float value = 0; + for (int d = 0; d < deep; d++) { + size_t ai = r12div * deep * 12 + d * 12 + r12mod; + size_t bi = c8div * deep * 8 + d * 8 + c8mod; + value = value + a[ai] * b[bi]; + } + ADD_BIAS(value, bias, c) + DO_RELU(value, act_type) + DO_RELU6(value, act_type) + dst[ci] = value; + } + } + } else if (out_type == OutType_C8) { + int col_8 = UP_ROUND(col, C8NUM); + int row_12 = UP_ROUND(row, C12NUM); + for (int r = 0; r < row_12; r++) { + for (int c = 0; c < col_8; c++) { + int r12div = r / C12NUM, r12mod = r % C12NUM; + int c8div = c / C8NUM, c8mod = c % C8NUM; + size_t ci = (c8div * C8NUM * row_12 + r * C8NUM + c8mod); + float value = 0; + for (int d = 0; d < deep; d++) { + size_t ai = r12div * deep * C12NUM + d * C12NUM + r12mod; + size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod; + value = value + a[ai] * b[bi]; + } + ADD_BIAS(value, bias, c) + DO_RELU(value, act_type) + DO_RELU6(value, act_type) + dst[ci] = value; + } + } + } else if (out_type == OutType_TileC8) { + for (int i = 0; i < row; ++i) { + int src_r_offset = i; + int dst_r_offset = i * col * stride; + for (int j = 0; j < col; ++j) { + int c8div = j / 8, c8mod = j % 8; + size_t ci = dst_r_offset + c8div * 8 * stride + c8mod; + float value = 0; + for (int d = 0; d < deep; ++d) { + size_t ai = src_r_offset + d * C12NUM; + size_t bi = c8div * deep * 8 + d * 8 + c8mod; + value = value + a[ai] * b[bi]; + } + ADD_BIAS(value, bias, j) + DO_RELU(value, act_type) + DO_RELU6(value, act_type) + dst[ci] = value; + } + } + } +} + +void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, + int col, size_t stride, int out_type) { +#ifdef ENABLE_ARM64 + if (out_type == OutType_C8) { + MatmulFloatNeon64(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0); + } else if (out_type == OutType_Nhwc && deep > C512NUM) { + BigMatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride); + } else { + MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type)); + } +#elif ENABLE_ARM32 + if (out_type == OutType_C8) { + MatmulFloatNeon32(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0); + } else if (out_type == OutType_Nhwc) { + MatmulFloatNeon32Opt12x4(a, b, c, bias, (int)act_type, deep, row, col, stride, 1); + } else { + MatmulFloatNeon32Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type)); + } +#elif ENABLE_AVX + MatmulFloatAvxOpt(a, b, c, bias, (size_t)act_type, deep, row, col, stride, (size_t)(out_type)); +#elif ENABLE_SSE + MatmulFloatSse64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type)); +#else + MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, out_type); +#endif +} + +#define ActCompute(bit_num, down_threshold, up_threshold) \ + if (act_type != 0) { \ + dst = MS_MAX##bit_num##_F32(dst, down_threshold); \ + if (act_type == 3) { \ + dst = MS_MIN##bit_num##_F32(dst, up_threshold); \ + } \ + } + +// act_type must be 0, 1, 2. 0: no_act, 1: relu, 3: relu6. +void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row, int deep, int act_type) { + int index = 0; + + SIMD_RUN_NO_SCALAR(GemmIsNotPack, index, a, b, c, bias, row, deep, act_type); + + for (; index < row; ++index) { + float dst = a[index] * b[0] + bias[0]; + ActCompute(32, 0, C6NUM); + c[index] = dst; + } +} + +// act_type must be 0, 1, 2. 0: no_act, 1: relu, 3: relu6. +void Row1Deep1GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int col, int deep, + int act_type) { + int index = 0; + + SIMD_RUN_NO_SCALAR(Row1Deep1GemmIsNotPack, index, a, b, c, bias, col, act_type); + for (; index < col; ++index) { + float dst = a[0] * b[index] + bias[index]; + ActCompute(32, 0, C6NUM); + c[index] = dst; + } +} + +// act_type must be 0, 1, 2. 0: no_act, 1: relu, 3: relu6. +void Row1Deep1NoBiasGemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int col, int deep, + int act_type) { + int index = 0; + + SIMD_RUN_NO_SCALAR(Row1Deep1NoBiasGemmIsNotPack, index, a, b, c, bias, col, act_type); + for (; index < col; ++index) { + float dst = a[0] * b[index]; + ActCompute(32, 0, C6NUM); + c[index] = dst; + } +} + +// act_type must be 0, 1, 2. 0: no_act, 1: relu, 3: relu6. +void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float *bias, int m, int k, int act_type) { + // gemm dot is [m, k] * [k, 1] ==>> [m, 1] + int m_index = 0; + + SIMD_RUN_AVX512(GemmIsNotPackOptimize, m_index, a, b, c, bias, m, k, act_type); + +#ifdef ENABLE_AVX + // block 4 + MS_FLOAT32X4 down_threshold128 = MS_MOVQ_F32(0); + MS_FLOAT32X4 up_threshold128 = MS_MOVQ_F32(C6NUM); + for (; m_index <= m - C4NUM; m_index += C4NUM) { + int k_index = 0; + MS_FLOAT32X4 dst = MS_MOV128_F32(bias[0]); + MS_SET_ZERO256X4_F32(dst_) + for (; k_index <= k - C8NUM; k_index += C8NUM) { + MS_FLOAT32X8 weight = MS_LD256_F32(b + k_index); + MS_LOAD256X4_F32(src, a + m_index * k + k_index, k); + MS_FMADD256X4_F32(src, weight, dst_); + } + MS_F32X8_GETI(dst, 0) += MS_REDUCE_ADD256_F32(dst_1); + MS_F32X8_GETI(dst, 1) += MS_REDUCE_ADD256_F32(dst_2); + MS_F32X8_GETI(dst, C2NUM) += MS_REDUCE_ADD256_F32(dst_3); + MS_F32X8_GETI(dst, C3NUM) += MS_REDUCE_ADD256_F32(dst_4); + for (; k_index < k; ++k_index) { + MS_F32X8_GETI(dst, 0) += b[k_index] * a[m_index * k + k_index]; + MS_F32X8_GETI(dst, 1) += b[k_index] * a[m_index * k + k_index + k]; + MS_F32X8_GETI(dst, C2NUM) += b[k_index] * a[m_index * k + k_index + C2NUM * k]; + MS_F32X8_GETI(dst, C3NUM) += b[k_index] * a[m_index * k + k_index + C3NUM * k]; + } + ActCompute(128, down_threshold128, up_threshold128); + MS_ST128_F32(c + m_index, dst); + } +#endif + + // block 1 + for (; m_index < m; m_index++) { + float dst = bias[0]; + int k_index = 0; + + SIMD_RUN_AVX512(GemmIsNotPackOptimizeCore, k_index, a + m_index * k, b, k, &dst); + SIMD_RUN_AVX(GemmIsNotPackOptimizeCore, k_index, a + m_index * k, b, k, &dst); + + for (; k_index < k; k_index++) { + dst += b[k_index] * a[m_index * k + k_index]; + } + ActCompute(32, 0, C6NUM); + c[m_index] = dst; + } +} + +void MatVecMulNoPackFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int64_t depth, + int64_t cur_col, int64_t col) { + int inc_flag = 0; + int64_t k = 0; + for (; k <= depth - C1500NUM; k += C1500NUM) { + inc_flag = (k == 0) + (k + C1500NUM == depth ? C2NUM : 0); + int64_t oc_index = 0; + SIMD_RUN_NO_SCALAR(MatVecMulNoPackCore, oc_index, a, b, c, bias, act_type, C1500NUM, cur_col, col, inc_flag); + for (; oc_index < cur_col; ++oc_index) { + float dst = (inc_flag & 1) == 0 ? c[oc_index] : (bias == NULL ? 0 : bias[oc_index]); + for (int64_t k_index = 0; k_index < k; ++k_index) { + dst += a[k_index] * b[oc_index + k_index * col]; + } + if ((inc_flag & 0x2) != 0) { + ActCompute(32, 0, C6NUM); + } + c[oc_index] = dst; + } + a += C1500NUM; + b += C1500NUM * col; + } + if (k == depth) { + return; + } + inc_flag = (k == 0) + C2NUM; + int64_t oc_index = 0; + SIMD_RUN_NO_SCALAR(MatVecMulNoPackCore, oc_index, a, b, c, bias, act_type, depth - k, cur_col, col, inc_flag); + for (; oc_index < cur_col; ++oc_index) { + float dst = (inc_flag & 1) == 0 ? c[oc_index] : (bias == NULL ? 0 : bias[oc_index]); + for (int64_t k_index = 0; k_index < depth; ++k_index) { + dst += a[k_index] * b[oc_index + k_index * col]; + } + ActCompute(32, 0, C6NUM); + c[oc_index] = dst; + } +} + +#ifdef ENABLE_ARM64 +// act_type must be 0, 1, 2. 0: no_act, 1: relu, 3: relu6. +void MatMul4x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep, + size_t act_type) { + // 1: LoopD16, 2: LoopD12, 3: LoopD8, 4: LoopD4, 5: LoopD1, 6: LoopDEnd, 7: LoopDTail, 8: LoopDTailCompute + // 9: WriteBack + asm volatile( + "mov x8, %[input]\n" + "mov x9, %[weight]\n" + "mov x10, %[deep]\n" + "add x5, %[input], %[deep], LSL #2\n" + "add x6, %[input], %[deep], LSL #3\n" + "add x7, x5, %[deep], LSL #3\n" + "dup v0.2d, xzr\n" + "dup v1.2d, xzr\n" + "dup v2.2d, xzr\n" + "dup v3.2d, xzr\n" + "subs x10, x10, #16\n" + "blt 2f\n" + "1:\n" // LoopD16 + "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64\n" + "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x5], #64\n" + "ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x6], #64\n" + "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x7], #64\n" + "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x9], #64\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "fmla v2.4s, v20.4s, v28.4s\n" + "fmla v3.4s, v24.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "fmla v1.4s, v17.4s, v29.4s\n" + "fmla v2.4s, v21.4s, v29.4s\n" + "fmla v3.4s, v25.4s, v29.4s\n" + "fmla v0.4s, v6.4s, v30.4s\n" + "fmla v1.4s, v18.4s, v30.4s\n" + "fmla v2.4s, v22.4s, v30.4s\n" + "fmla v3.4s, v26.4s, v30.4s\n" + "fmla v0.4s, v7.4s, v31.4s\n" + "fmla v1.4s, v19.4s, v31.4s\n" + "fmla v2.4s, v23.4s, v31.4s\n" + "fmla v3.4s, v27.4s, v31.4s\n" + "subs x10, x10, #16\n" + "bge 1b\n" + "2:\n" // LoopD12 + "adds x10, x10, #16\n" + "cbz x10, 6f\n" + "cmp x10, #12\n" + "blt 3f\n" + "ld1 {v4.4s, v5.4s, v6.4s}, [x8], #48\n" + "ld1 {v16.4s, v17.4s, v18.4s}, [x5], #48\n" + "ld1 {v20.4s, v21.4s, v22.4s}, [x6], #48\n" + "ld1 {v24.4s, v25.4s, v26.4s}, [x7], #48\n" + "ld1 {v28.4s, v29.4s, v30.4s}, [x9], #48\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "fmla v2.4s, v20.4s, v28.4s\n" + "fmla v3.4s, v24.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "fmla v1.4s, v17.4s, v29.4s\n" + "fmla v2.4s, v21.4s, v29.4s\n" + "fmla v3.4s, v25.4s, v29.4s\n" + "fmla v0.4s, v6.4s, v30.4s\n" + "fmla v1.4s, v18.4s, v30.4s\n" + "fmla v2.4s, v22.4s, v30.4s\n" + "fmla v3.4s, v26.4s, v30.4s\n" + "sub x10, x10, #12\n" + "b 7f\n" + "3:\n" // LoopD8 + "cmp x10, #8\n" + "blt 4f\n" + "ld1 {v4.4s, v5.4s}, [x8], #32\n" + "ld1 {v16.4s, v17.4s}, [x5], #32\n" + "ld1 {v20.4s, v21.4s}, [x6], #32\n" + "ld1 {v24.4s, v25.4s}, [x7], #32\n" + "ld1 {v28.4s, v29.4s}, [x9], #32\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "fmla v2.4s, v20.4s, v28.4s\n" + "fmla v3.4s, v24.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "fmla v1.4s, v17.4s, v29.4s\n" + "fmla v2.4s, v21.4s, v29.4s\n" + "fmla v3.4s, v25.4s, v29.4s\n" + "sub x10, x10, #8\n" + "b 7f\n" + "4:\n" // LoopD4 + "cmp x10, #4\n" + "blt 7f\n" + "ld1 {v4.4s}, [x8], #16\n" + "ld1 {v16.4s}, [x5], #16\n" + "ld1 {v20.4s}, [x6], #16\n" + "ld1 {v24.4s}, [x7], #16\n" + "ld1 {v28.4s}, [x9], #16\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "fmla v2.4s, v20.4s, v28.4s\n" + "fmla v3.4s, v24.4s, v28.4s\n" + "sub x10, x10, #4\n" + "7:\n" + "cbz x10, 6f\n" + "dup v4.2d, xzr\n" + "dup v16.2d, xzr\n" + "dup v20.2d, xzr\n" + "dup v24.2d, xzr\n" + "dup v28.2d, xzr\n" + "subs x10, x10, #2\n" + "blt 5f\n" + "ld1 {v4.d}[0], [x8], #8\n" // LoopD2 + "ld1 {v16.d}[0], [x5], #8\n" + "ld1 {v20.d}[0], [x6], #8\n" + "ld1 {v24.d}[0], [x7], #8\n" + "ld1 {v28.d}[0], [x9], #8\n" + "cbz x10, 8f\n" + "5:\n" // LoopD1 + "ld1 {v4.s}[2], [x8]\n" + "ld1 {v16.s}[2], [x5]\n" + "ld1 {v20.s}[2], [x6]\n" + "ld1 {v24.s}[2], [x7]\n" + "ld1 {v28.s}[2], [x9]\n" + "8:\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "fmla v2.4s, v20.4s, v28.4s\n" + "fmla v3.4s, v24.4s, v28.4s\n" + "6:\n" + "faddp v4.4s, v0.4s, v1.4s\n" + "faddp v5.4s, v2.4s, v3.4s\n" + "faddp v0.4s, v4.4s, v5.4s\n" + "cbz %[bias], 9f\n" + "ld1r {v1.4s}, [%[bias]]\n" + "fadd v0.4s, v0.4s, v1.4s\n" + "9:\n" + "cbz %[act], 10f\n" + "dup v1.2d, xzr\n" + "fmax v0.4s, v0.4s, v1.4s\n" + "cmp %[act], #3\n" + "bne 10f\n" + "movi v1.4s, #6\n" + "scvtf v1.4s, v1.4s\n" + "fmin v0.4s, v0.4s, v1.4s\n" + "10:\n" + "st1 {v0.4s}, [%[output]]\n" + + : + : [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep), + [ act ] "r"(act_type) + : "cc", "x5", "x6", "x7", "x8", "x9", "x10", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", + "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +} + +void MatMul2x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep, + size_t act_type) { + // 1: LoopD16, 2: LoopD12, 3: LoopD8, 4: LoopD4, 5: LoopD1, 6: LoopDEnd, 7: LoopDTail, 8: LoopDTailCompute + // 9: WriteBack + asm volatile( + "mov x8, %[input]\n" + "mov x9, %[weight]\n" + "mov x10, %[deep]\n" + "add x5, %[input], %[deep], LSL #2\n" + "dup v0.2d, xzr\n" + "dup v1.2d, xzr\n" + "subs x10, x10, #16\n" + "blt 2f\n" + "1:\n" // LoopD16 + "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64\n" + "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x5], #64\n" + "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x9], #64\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "fmla v1.4s, v17.4s, v29.4s\n" + "fmla v0.4s, v6.4s, v30.4s\n" + "fmla v1.4s, v18.4s, v30.4s\n" + "fmla v0.4s, v7.4s, v31.4s\n" + "fmla v1.4s, v19.4s, v31.4s\n" + "subs x10, x10, #16\n" + "bge 1b\n" + "2:\n" // LoopD12 + "adds x10, x10, #16\n" + "cbz x10, 6f\n" + "cmp x10, #12\n" + "blt 3f\n" + "ld1 {v4.4s, v5.4s, v6.4s}, [x8], #48\n" + "ld1 {v16.4s, v17.4s, v18.4s}, [x5], #48\n" + "ld1 {v28.4s, v29.4s, v30.4s}, [x9], #48\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "fmla v1.4s, v17.4s, v29.4s\n" + "fmla v0.4s, v6.4s, v30.4s\n" + "fmla v1.4s, v18.4s, v30.4s\n" + "sub x10, x10, #12\n" + "b 7f\n" + "3:\n" // LoopD8 + "cmp x10, #8\n" + "blt 4f\n" + "ld1 {v4.4s, v5.4s}, [x8], #32\n" + "ld1 {v16.4s, v17.4s}, [x5], #32\n" + "ld1 {v28.4s, v29.4s}, [x9], #32\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "fmla v1.4s, v17.4s, v29.4s\n" + "sub x10, x10, #8\n" + "b 7f\n" + "4:\n" // LoopD4 + "cmp x10, #4\n" + "blt 7f\n" + "ld1 {v4.4s}, [x8], #16\n" + "ld1 {v16.4s}, [x5], #16\n" + "ld1 {v28.4s}, [x9], #16\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "sub x10, x10, #4\n" + "7:\n" + "cbz x10, 6f\n" + "dup v4.2d, xzr\n" + "dup v16.2d, xzr\n" + "subs x10, x10, #2\n" + "blt 5f\n" + "ld1 {v4.d}[0], [x8], #8\n" // LoopD2 + "ld1 {v16.d}[0], [x5], #8\n" + "ld1 {v28.d}[0], [x9], #8\n" + "cbz x10, 8f\n" + "5:\n" // LoopD1 + "ld1 {v4.s}[2], [x8]\n" + "ld1 {v16.s}[2], [x5]\n" + "ld1 {v28.s}[2], [x9]\n" + "8:\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v1.4s, v16.4s, v28.4s\n" + "6:\n" + "faddp v4.4s, v0.4s, v1.4s\n" + "faddp v0.4s, v4.4s, v4.4s\n" + "cbz %[bias], 9f\n" + "ld1r {v1.4s}, [%[bias]]\n" + "fadd v0.2s, v0.2s, v1.2s\n" + "9:\n" + "cbz %[act], 10f\n" + "fmov d1, xzr\n" + "fmax v0.2s, v0.2s, v1.2s\n" + "cmp %[act], #3\n" + "bne 10f\n" + "movi v1.2s, #6\n" + "scvtf v1.2s, v1.2s\n" + "fmin v0.2s, v0.2s, v1.2s\n" + "10:\n" + "st1 {v0.2s}, [%[output]]\n" + + : + : [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep), + [ act ] "r"(act_type) + : "cc", "x5", "x8", "x9", "x10", "v0", "v1", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v28", "v29", + "v30", "v31", "memory"); +} + +void MatMul1x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep, + size_t act_type) { + // 1: LoopD16, 2: LoopD12, 3: LoopD8, 4: LoopD4, 5: LoopD1, 6: LoopDEnd, 7: LoopDTail, 8: LoopDTailCompute + // 9: WriteBack + asm volatile( + "mov x8, %[input]\n" + "mov x9, %[weight]\n" + "mov x10, %[deep]\n" + "dup v0.2d, xzr\n" + "subs x10, x10, #16\n" + "blt 2f\n" + "1:\n" // LoopD16 + "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64\n" + "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x9], #64\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "fmla v0.4s, v6.4s, v30.4s\n" + "fmla v0.4s, v7.4s, v31.4s\n" + "subs x10, x10, #16\n" + "bge 1b\n" + "2:\n" // LoopD12 + "adds x10, x10, #16\n" + "cbz x10, 6f\n" + "cmp x10, #12\n" + "blt 3f\n" + "ld1 {v4.4s, v5.4s, v6.4s}, [x8], #48\n" + "ld1 {v28.4s, v29.4s, v30.4s}, [x9], #48\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "fmla v0.4s, v6.4s, v30.4s\n" + "sub x10, x10, #12\n" + "b 7f\n" + "3:\n" // LoopD8 + "cmp x10, #8\n" + "blt 4f\n" + "ld1 {v4.4s, v5.4s}, [x8], #32\n" + "ld1 {v28.4s, v29.4s}, [x9], #32\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "fmla v0.4s, v5.4s, v29.4s\n" + "sub x10, x10, #8\n" + "b 7f\n" + "4:\n" // LoopD4 + "cmp x10, #4\n" + "blt 7f\n" + "ld1 {v4.4s}, [x8], #16\n" + "ld1 {v28.4s}, [x9], #16\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "sub x10, x10, #4\n" + "7:\n" + "cbz x10, 6f\n" + "dup v4.2d, xzr\n" + "subs x10, x10, #2\n" + "blt 5f\n" + "ld1 {v4.d}[0], [x8], #8\n" // LoopD2 + "ld1 {v28.d}[0], [x9], #8\n" + "cbz x10, 8f\n" + "5:\n" // LoopD1 + "ld1 {v4.s}[3], [x8]\n" + "ld1 {v28.s}[3], [x9]\n" + "8:\n" + "fmla v0.4s, v4.4s, v28.4s\n" + "6:\n" + "faddp v4.4s, v0.4s, v0.4s\n" + "faddp v0.4s, v4.4s, v4.4s\n" + "cbz %[bias], 9f\n" + "ld1 {v1.s}[0], [%[bias]]\n" + "fadd s0, s0, s1\n" + "9:\n" + "cbz %[act], 10f\n" + "fmov s1, wzr\n" + "fmax s0, s0, s1\n" + "cmp %[act], #3\n" + "bne 10f\n" + "mov x10, #6\n" + "scvtf s1, x10\n" + "fmin s0, s0, s1\n" + "10:\n" + "str s0, [%[output]]\n" + + : + : [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep), + [ act ] "r"(act_type) + : "cc", "x8", "x9", "x10", "v0", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v28", "v29", "v30", "v31"); +} + +void GemmIsNotPackByRow(const float *a, const float *b, float *c, const float *bias, int start_row, int end_row, + int deep, int act_type) { + const float *input = a + start_row * deep; + float *output = c + start_row; + const int step = C4NUM * deep; + for (; start_row <= end_row - C4NUM; start_row += C4NUM) { + MatMul4x1Kernel(input, b, output, bias, deep, act_type); + input += step; + output += C4NUM; + } + for (; start_row <= end_row - C2NUM; start_row += C2NUM) { + MatMul2x1Kernel(input, b, output, bias, deep, act_type); + input += C2NUM * deep; + output += C2NUM; + } + if (start_row == end_row - 1) { + MatMul1x1Kernel(input, b, output, bias, deep, act_type); + } +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/matmul_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/matmul_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..618bf69c3ad38f4bfc418e7fb791f65bf77a9ab1 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/matmul_fp32.h @@ -0,0 +1,99 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_MATMUL_H_ +#define MINDSPORE_NNACL_FP32_MATMUL_H_ + +#include +#include +#include "nnacl/errorcode.h" +#include "nnacl/matmul_parameter.h" +#include "nnacl/op_base.h" +#include "nnacl/fp32/matmul_avx_fp32.h" + +#define ADD_BIAS(value, bias, c) \ + if (bias != NULL) value = value + bias[c]; + +#define DO_RELU(value, act_type) \ + if (act_type == ActType_Relu) value = MSMAX(0.0f, value); + +#define DO_RELU6(value, act_type) \ + if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); \ + if (act_type == ActType_Relu6) value = MSMAX(0.0f, value); + +#ifdef __cplusplus +extern "C" { +#endif +void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, + int col, size_t stride, int out_type); +void MatVecMulFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col); +void MatVecMulFp32Block8(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col); +void MatVecMulFp32Block4(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col); + +#ifdef ENABLE_ARM64 +void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, + int col, size_t stride, size_t writeNhwc, size_t WriteWino); +void MatmulFloatNeon64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, + int col, size_t stride, size_t write_mode); +void BigMatmulFloatNeon64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, + int row, int col, size_t stride); +void MatmulFloatNeon64OptRow8(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, + int row, int col, size_t stride, size_t write_mode); +void MatmulFloatNeon64OptRow4(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, + int row, int col, size_t stride, size_t write_mode); +void MatmulFloatNeon64OptRow12(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, + int row, int col, size_t stride, size_t write_mode); +void MatVecMulPackFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col); +void MatVecMulFp32Neon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col, + int align_col); + +#elif defined(ENABLE_ARM32) +void MatmulFloatNeon32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, + int col, int stride, size_t writeNhwc, size_t WriteWino); +void MatmulFloatNeon32Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, + int col, int stride, int write_mode); +void MatmulFloatNeon32Opt12x4(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, + int row, int col, int stride, int write_mode); + +#elif defined(ENABLE_SSE) +void DeconvMatmulFloatSse(const float *a, const float *b, float *c, int depth, int row, int col); +void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, + int col, int stride, int write_mode); +#endif + +void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row, + int col, int stride, int out_type); + +void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row, int deep, int act_type); + +void Row1Deep1GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int col, int deep, + int act_type); + +void Row1Deep1NoBiasGemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int col, int deep, + int act_type); + +void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float *bias, int m, int k, int act_type); + +void MatVecMulNoPackFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int64_t depth, + int64_t cur_col, int64_t col); +#ifdef ENABLE_ARM64 +void GemmIsNotPackByRow(const float *a, const float *b, float *c, const float *bias, int start_row, int end_row, + int deep, int act_type); +#endif +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FP32_MATMUL_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/matmul_fp32_simd.h.in b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/matmul_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..a1841d842f3acc346a998bc25613755b8e559678 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/matmul_fp32_simd.h.in @@ -0,0 +1,148 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// clang-format off +#ifndef MINDSPORE_NNACL_FP32_MATMUL_F32_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_MATMUL_F32_@SIMD_INSTRUCTION@_H_ + +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +// act_type must be 0, 1, 2. 0: no_act, 1: relu, 3: relu6. +static inline int64_t GemmIsNotPack@SIMD_INSTRUCTION@(int64_t index, const float *a, const float *b, float *c, const float *bias, int row, + int deep, int act_type) { + SIMD_F32 down_threshold = SIMD_MOV_F32(0.0f); + SIMD_F32 up_threshold = SIMD_MOV_F32(6); + SIMD_F32 b_data16 = SIMD_MOV_F32(b[0]); + SIMD_F32 bias_data16 = SIMD_MOV_F32(bias[0]); + for (int block_max_size = row - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 a_data = SIMD_LD_F32(a + index); + SIMD_F32 dst = SIMD_FMADD_F32(b_data16, a_data, bias_data16); + if (act_type != 0) { + dst = SIMD_MAX_F32(dst, down_threshold); + if (act_type == 3) { + dst = SIMD_MIN_F32(dst, up_threshold); + } + } + SIMD_ST_F32(c + index, dst); + } + + return index; +} + +// act_type must be 0, 1, 2. 0: no_act, 1: relu, 3: relu6. +static inline int64_t Row1Deep1GemmIsNotPack@SIMD_INSTRUCTION@(int64_t index, const float *a, const float *b, float *c, const float *bias, int col, + int act_type) { + SIMD_F32 down_threshold = SIMD_MOV_F32(0.0f); + SIMD_F32 up_threshold = SIMD_MOV_F32(6); + SIMD_F32 vec_a = SIMD_MOV_F32(a[0]); + if (act_type == 1) { + for (int block_max_size = col - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vec_b = SIMD_LD_F32(b + index); + SIMD_F32 vec_bias = SIMD_LD_F32(bias + index); + SIMD_F32 dst = SIMD_FMADD_F32(vec_a, vec_b, vec_bias); + SIMD_ST_F32(c + index, SIMD_MAX_F32(dst, down_threshold)); // relu + } + } else if (act_type == 3) { + for (int block_max_size = col - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vec_b = SIMD_LD_F32(b + index); + SIMD_F32 vec_bias = SIMD_LD_F32(bias + index); + SIMD_F32 dst = SIMD_FMADD_F32(vec_a, vec_b, vec_bias); + SIMD_ST_F32(c + index, SIMD_CLAMP_F32(dst, down_threshold, up_threshold)); // relue6 + } + } else { + for (int block_max_size = col - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vec_b = SIMD_LD_F32(b + index); + SIMD_F32 vec_bias = SIMD_LD_F32(bias + index); + SIMD_F32 dst = SIMD_FMADD_F32(vec_a, vec_b, vec_bias); + SIMD_ST_F32(c + index, dst); // no_act + } + } + + return index; +} + +// act_type must be 0, 1, 2. 0: no_act, 1: relu, 3: relu6. +static inline int64_t Row1Deep1NoBiasGemmIsNotPack@SIMD_INSTRUCTION@(int64_t index, const float *a, const float *b, float *c, const float *bias, int col, + int act_type) { + SIMD_F32 down_threshold = SIMD_MOV_F32(0.0f); + SIMD_F32 up_threshold = SIMD_MOV_F32(6); + SIMD_F32 vec_a = SIMD_MOV_F32(a[0]); + if (act_type == 1) { + for (int block_max_size = col - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vec_b = SIMD_LD_F32(b + index); + SIMD_F32 dst = SIMD_MUL_F32(vec_a, vec_b); + SIMD_ST_F32(c + index, SIMD_MAX_F32(dst, down_threshold)); // relu + } + } else if (act_type == 3) { + for (int block_max_size = col - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vec_b = SIMD_LD_F32(b + index); + SIMD_F32 dst = SIMD_MUL_F32(vec_a, vec_b); + SIMD_ST_F32(c + index, SIMD_CLAMP_F32(dst, down_threshold, up_threshold)); // relue6 + } + } else { + for (int block_max_size = col - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vec_b = SIMD_LD_F32(b + index); + SIMD_F32 dst = SIMD_MUL_F32(vec_a, vec_b); + SIMD_ST_F32(c + index, dst); // no_act + } + } + + return index; +} + +#if defined(MS_SIMD_AVX512) || defined(MS_SIMD_AVX) +static inline int64_t GemmIsNotPackOptimizeCore@SIMD_INSTRUCTION@(int64_t index, const float *a, const float *b, int k, float *dst) { + SIMD_F32 dst1 = SIMD_MOV_F32(0.0f); + for (int block_max_size = k - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 weight = SIMD_LD_F32(b + index); + SIMD_F32 a1 = SIMD_LD_F32(a + index); + dst1 = SIMD_FMADD_F32(weight, a1, dst1); + } + *dst += SIMD_REDUCE_ADD_F32(dst1); + return index; +} +#endif + +static inline int64_t MatVecMulNoPackCore@SIMD_INSTRUCTION@(int64_t oc_index, const float *a, const float *b, float *c, const float *bias, + int act_type, int64_t depth, int64_t oc, int64_t col, int64_t inc_flag) { + for (int64_t oc_max_size = oc - BLOCK_NUM; oc_index <= oc_max_size; oc_index += BLOCK_NUM) { + SIMD_F32 out = (inc_flag & 0x1) == 0 ? SIMD_LD_F32(c + oc_index) : (bias == NULL ? SIMD_MOV_F32(0.0f) : SIMD_LD_F32(bias + oc_index)); + for (int64_t k = 0; k < depth; ++k) { + SIMD_F32 left = SIMD_MOV_F32(a[k]); + SIMD_F32 right = SIMD_LD_F32(b + oc_index + k * col); + out = SIMD_FMADD_F32(left, right, out); + } + if ((inc_flag & 0x2) != 0 && act_type != 0) { + out = SIMD_MAX_F32(out, SIMD_MOV_F32(0.0f)); + if (act_type == 0x3) { + out = SIMD_MIN_F32(out, SIMD_MOV_F32(6.0f)); + } + } + SIMD_ST_F32(c + oc_index, out); + } + return oc_index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/mul_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/mul_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..fe2a666e2e7ff62764d8a4799c767d3b4b0f4952 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/mul_fp32.c @@ -0,0 +1,187 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp32/mul_fp32.h" +#include "nnacl/fp32/arithmetic_fp32.h" +#include "nnacl/mul_fp32_simd.h" +#include "nnacl/errorcode.h" + +int BroadcastMul(const float *in0, const float *in1, float *tile_in0, float *tile_in1, float *out, int size, + ArithmeticParameter *param) { + TileDimensionsFp32(in0, in1, tile_in0, tile_in1, param); + return ElementMul(tile_in0, tile_in1, out, size); +} + +int ElementMul(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementMul, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[index] * in1[index]; + } + return NNACL_OK; +} + +int ElementMulRelu(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementMulRelu, index, in0, in1, out, size); + for (; index < size; index++) { + float res = in0[index] * in1[index]; + out[index] = res > 0 ? res : 0; + } + return NNACL_OK; +} + +int ElementMulRelu6(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementMulRelu6, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[index] * in1[index], 0), 6); + } + return NNACL_OK; +} + +int ElementMulInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementMulInt, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[index] * in1[index]; + } + return NNACL_OK; +} + +int ElementMulReluInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementMulReluInt, index, in0, in1, out, size); + for (; index < size; index++) { + int res = in0[index] * in1[index]; + out[index] = res > 0 ? res : 0; + } + return NNACL_OK; +} + +int ElementMulRelu6Int(const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementMulRelu6Int, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[index] * in1[index], 0), 6); + } + return NNACL_OK; +} + +int ElementOptMul(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptMulNum0, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[0] * in1[index]; + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptMulNum1, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = in0[index] * in1[0]; + } + } + return NNACL_OK; +} + +int ElementOptMulRelu(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptMulReluNum0, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMAX(in0[0] * in1[index], 0); + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptMulReluNum1, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMAX(in0[index] * in1[0], 0); + } + } + return NNACL_OK; +} + +int ElementOptMulRelu6(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptMulRelu6Num0, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[0] * in1[index], 0), 6); + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptMulRelu6Num1, index, in0, in1, out, size); + + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[index] * in1[0], 0), 6); + } + } + return NNACL_OK; +} + +int ElementOptMulInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptMulIntNum0, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[0] * in1[index]; + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptMulIntNum1, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[index] * in1[0]; + } + } + return NNACL_OK; +} + +int ElementOptMulReluInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptMulReluIntNum0, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMAX(in0[0] * in1[index], 0); + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptMulReluIntNum1, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMAX(in0[index] * in1[0], 0); + } + } + return NNACL_OK; +} + +int ElementOptMulRelu6Int(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptMulRelu6IntNum0, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[0] * in1[index], 0), 6); + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptMulRelu6IntNum1, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[index] * in1[0], 0), 6); + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/mul_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/mul_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..bf370bb9355b6a7256c27e6d9861af21202e1476 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/mul_fp32.h @@ -0,0 +1,45 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_MUL_H_ +#define MINDSPORE_NNACL_FP32_MUL_H_ + +#include "nnacl/op_base.h" +#include "nnacl/base/arithmetic_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ElementMul(const float *in0, const float *in1, float *out, int size); +int ElementMulRelu(const float *in0, const float *in1, float *out, int size); +int ElementMulRelu6(const float *in0, const float *in1, float *out, int size); +int ElementMulInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size); +int ElementMulReluInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size); +int ElementMulRelu6Int(const int32_t *in0, const int32_t *in1, int32_t *out, int size); +int ElementOptMul(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementOptMulRelu(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementOptMulRelu6(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementOptMulInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar); +int ElementOptMulReluInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar); +int ElementOptMulRelu6Int(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar); +int BroadcastMul(const float *in0, const float *in1, float *tile_in0, float *tile_in1, float *out, int size, + ArithmeticParameter *param); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_MUL_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/mul_fp32_simd.h.in b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/mul_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..858ee7bc4cdaa601e6bf4020044dd4b70e31e7cf --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/mul_fp32_simd.h.in @@ -0,0 +1,211 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_ACTIVATION_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_ACTIVATION_@SIMD_INSTRUCTION@_H_ + +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int ElementMul@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MUL_F32(vin0, vin1); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementMulRelu@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MAX_N_F32(SIMD_MUL_F32(vin0, vin1), 0.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementMulRelu6@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MIN_N_F32(SIMD_MAX_N_F32(SIMD_MUL_F32(vin0, vin1), 0.0f), 6.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementMulInt@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin0 = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 vin1 = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 vout = SIMD_MUL_EPI32(vin0, vin1); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +static inline int ElementMulReluInt@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin0 = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 vin1 = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 vout = SIMD_MAX_N_EPI32(SIMD_MUL_EPI32(vin0, vin1), 0.0f); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +static inline int ElementMulRelu6Int@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin0 = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 vin1 = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 vout = SIMD_MIN_N_EPI32(SIMD_MAX_N_EPI32(SIMD_MUL_EPI32(vin0, vin1), 0.0f), 6.0f); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +static inline int ElementOptMulNum0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 vin0_opt_ = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MUL_F32(vin0_opt_, vin1); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptMulNum1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 vin1_opt_ = SIMD_MOV_F32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vout = SIMD_MUL_F32(vin0, vin1_opt_); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptMulReluNum0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 vin0_opt_ = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MAX_N_F32(SIMD_MUL_F32(vin0_opt_, vin1), 0.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptMulReluNum1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 vin1_opt_ = SIMD_MOV_F32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vout = SIMD_MAX_N_F32(SIMD_MUL_F32(vin0, vin1_opt_), 0.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptMulRelu6Num0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 vin0_opt_ = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MIN_N_F32(SIMD_MAX_N_F32(SIMD_MUL_F32(vin0_opt_, vin1), 0.0f), 6.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptMulRelu6Num1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + SIMD_F32 vin1_opt_ = SIMD_MOV_F32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vout = SIMD_MIN_N_F32(SIMD_MAX_N_F32(SIMD_MUL_F32(vin0, vin1_opt_), 0.0f), 6.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptMulIntNum0@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 vin0_opt_ = SIMD_MOV_EPI32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin1 = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 vout = SIMD_MUL_EPI32(vin0_opt_, vin1); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +static inline int ElementOptMulIntNum1@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 vin1_opt_ = SIMD_MOV_EPI32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin0 = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 vout = SIMD_MUL_EPI32(vin0, vin1_opt_); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +static inline int ElementOptMulReluIntNum0@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 vin0_opt_ = SIMD_MOV_EPI32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin1 = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 vout = SIMD_MAX_N_EPI32(SIMD_MUL_EPI32(vin0_opt_, vin1), 0.0f); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +static inline int ElementOptMulReluIntNum1@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 vin1_opt_ = SIMD_MOV_EPI32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin0 = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 vout = SIMD_MAX_N_EPI32(SIMD_MUL_EPI32(vin0, vin1_opt_), 0.0f); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +static inline int ElementOptMulRelu6IntNum0@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 vin0_opt_ = SIMD_MOV_EPI32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin1 = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 vout = SIMD_MIN_N_EPI32(SIMD_MAX_N_EPI32(SIMD_MUL_EPI32(vin0_opt_, vin1), 0.0f), 6.0f); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +static inline int ElementOptMulRelu6IntNum1@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 vin1_opt_ = SIMD_MOV_EPI32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin0 = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 vout = SIMD_MIN_N_EPI32(SIMD_MAX_N_EPI32(SIMD_MUL_EPI32(vin0, vin1_opt_), 0.0f), 6.0f); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif \ No newline at end of file diff --git a/mindspore-lite/tools/graph_kernel/converter/expanders/exp_fusion.cc b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/nllloss_fp32.c similarity index 36% rename from mindspore-lite/tools/graph_kernel/converter/expanders/exp_fusion.cc rename to mindspore-lite/ops/kernel/cpu/nnacl/fp32/nllloss_fp32.c index d4239b8dee000d0a91299c267c44b06465e71203..d4143fa287629c3f5b4dd378d00065b9b998a985 100644 --- a/mindspore-lite/tools/graph_kernel/converter/expanders/exp_fusion.cc +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/nllloss_fp32.c @@ -14,41 +14,36 @@ * limitations under the License. */ -#include -#include -#include +#include "nnacl/fp32/nllloss_fp32.h" +#include "nnacl/op_base.h" +#include "nnacl/errorcode.h" -#include "backend/common/graph_kernel/expanders/op_desc_registry.h" -#include "tools/graph_kernel/converter/expanders/activation.h" -#include "mindapi/base/types.h" -#include "ir/dtype.h" +int NLLLoss(const float *logits, const int32_t *labels, const float *weight, float *loss, float *total_weight, + const NLLLossStruct *nllloss, const ReductionType reduction_type) { + NNACL_CHECK_NULL_RETURN_ERR(logits); + NNACL_CHECK_NULL_RETURN_ERR(labels); + NNACL_CHECK_NULL_RETURN_ERR(weight); + NNACL_CHECK_NULL_RETURN_ERR(loss); + NNACL_CHECK_NULL_RETURN_ERR(total_weight); -namespace mindspore::graphkernel::expanders { -class CheckExpAttr : public Validator { - public: - bool Check(const OpDesc &e) override { - std::vector unsupport_attr = {"base", "scale", "shift"}; - for (auto &a : unsupport_attr) { - if (e.Attrs().count(a) != 0) { - MS_LOG(INFO) << "attr " << a << " not supported yet for exp"; - return false; - } + float total_loss = 0.0; + float tmp_total_weight = 0.0; + for (int i = 0; i < nllloss->batch_; i++) { + int index = i * nllloss->class_num_ + labels[i]; + float n_weight = weight[labels[i]]; + float n_loss = -logits[index] * n_weight; + tmp_total_weight += n_weight; + total_loss += n_loss; + if (reduction_type == Reduction_None) { + loss[i] = n_loss; } - return true; } -}; -class ExpFusion : public OpDesc { - public: - ExpFusion() { (void)validators_.emplace_back(std::make_unique()); } - ~ExpFusion() = default; - - protected: - NodePtrList Expand(const NodePtrList &inputs) override { - const auto &input_x = inputs[0]; - auto result = gb.Exp(input_x); - return {result}; + *total_weight = tmp_total_weight; + if (reduction_type == Reduction_Sum) { + *loss = total_loss; + } else if (reduction_type == Reduction_Mean) { + *loss = total_loss / tmp_total_weight; } -}; -EXPANDER_OP_DESC_REGISTER("ExpFusion", ExpFusion); -} // namespace mindspore::graphkernel::expanders + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/nllloss_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/nllloss_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..27cda864feaad290b521a39db4ed6077d178a9b3 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/nllloss_fp32.h @@ -0,0 +1,30 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_NLLLOSS_FP32_H_ +#define NNACL_FP32_NLLLOSS_FP32_H_ + +#include "nnacl/kernel/nllloss.h" + +#ifdef __cplusplus +extern "C" { +#endif +int NLLLoss(const float *logits, const int32_t *labels, const float *weight, float *loss, float *total_weight, + const NLLLossStruct *parameter, const ReductionType reduction_type); +#ifdef __cplusplus +} +#endif +#endif // NNACL_FP32_NLLLOSS_FP32_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/non_max_suppression_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/non_max_suppression_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..799ffbdafe452022441e8327b69c55dce094261e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/non_max_suppression_fp32.c @@ -0,0 +1,205 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/non_max_suppression_fp32.h" +#include +#include +#include "nnacl/tensor_c_utils.h" + +typedef struct { + int32_t batch_index_; + int32_t class_index_; + int32_t box_index_; +} NMSIndex; + +typedef struct { + float score_; + int index_; + float y1_; // y1 x1 y2 x2 ascending order + float y2_; + float x1_; + float x2_; + float area_; +} NMSBox; + +void CreateNMBox(NMSBox *box, float score, int index, int cpb, float y_a, float x_a, float y_b, float x_b) { + box->score_ = score; + box->index_ = index; + if (0 == cpb) { + box->y1_ = NNACL_MIN(y_a, y_b); + box->y2_ = NNACL_MAX(y_a, y_b); + box->x1_ = NNACL_MIN(x_a, x_b); + box->x2_ = NNACL_MAX(x_a, x_b); + } else { + // x_center, y_center, width, height + float half_wid = x_b / 2; + box->x1_ = x_a - half_wid; + box->x2_ = x_a + half_wid; + float half_height = y_b / 2; + box->y1_ = y_a - half_height; + box->y2_ = y_a + half_height; + } + box->area_ = (box->y2_ - box->y1_) * (box->x2_ - box->x1_); +} + +bool CheckIoUSuppressed(const NMSBox *box, const NMSBox *cand, float iou_threshold) { + float intersec_x1 = NNACL_MAX(cand->x1_, box->x1_); + float intersec_x2 = NNACL_MIN(cand->x2_, box->x2_); + float intersec_y1 = NNACL_MAX(cand->y1_, box->y1_); + float intersec_y2 = NNACL_MIN(cand->y2_, box->y2_); + const float intersec_area = NNACL_MAX(intersec_x2 - intersec_x1, 0.0f) * NNACL_MAX(intersec_y2 - intersec_y1, 0.0f); + if (intersec_area <= 0.0f) { + return false; + } + const float intersec_over_union = intersec_area / (cand->area_ + box->area_ - intersec_area); + return intersec_over_union > iou_threshold; +} + +bool LessThan(NMSBox *box1, NMSBox *box2) { + return box1->score_ < box2->score_ || + (fabs(box1->score_ - box2->score_) < FLT_EPSILON && box1->index_ > box2->index_); +} + +void SortCandidates(ExecEnv *env, NMSBox **sorted, NMSBox *origin, int size) { + bool *sorted_index = (bool *)env->Alloc(env->allocator_, size * sizeof(bool)); + NNACL_CHECK_NULL_RETURN_VOID(sorted); + memset(sorted_index, 0, size * sizeof(bool)); + + NMSBox min_box; + min_box.score_ = FLT_MIN; + min_box.index_ = 0; + + for (int i = 0; i < size; i++) { + int max_index = 0; + NMSBox *box = &min_box; + for (int j = 0; j < size; j++) { + if (sorted_index[j]) { + continue; + } + if (LessThan(box, &origin[j])) { + max_index = j; + } + } + sorted[i] = &origin[max_index]; + sorted_index[max_index] = true; + } + + env->Free(env->allocator_, sorted); + return; +} + +int NonMaxSuppressionSelecte(NonMaxSuppressionStruct *nm_suppression, bool simple_out, int *score_dims) { + const float *box_data = (float *)nm_suppression->base_.in_[Index0]->data_; + NNACL_CHECK_NULL_RETURN_ERR(box_data); + const float *scores_data = (float *)nm_suppression->base_.in_[Index1]->data_; // batch, class, num + NNACL_CHECK_NULL_RETURN_ERR(scores_data); + ExecEnv *env = nm_suppression->base_.env_; + NNACL_CHECK_NULL_RETURN_ERR(env); + + int batch_num = score_dims[Index0]; + int class_num = score_dims[Index1]; + int box_num = score_dims[Index2]; + + int selected_box_per_class_max_size = NNACL_MIN((int)box_num, nm_suppression->max_output_per_class_); + NNACL_CHECK_MALLOC_SIZE(selected_box_per_class_max_size); + NMSBox *selected_box_per_class = env->Alloc(env->allocator_, selected_box_per_class_max_size * sizeof(NMSBox)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(selected_box_per_class); + memset(selected_box_per_class, 0, selected_box_per_class_max_size * sizeof(NMSBox)); + NMSBox *above_score_candidates = env->Alloc(env->allocator_, box_num * sizeof(NMSBox)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(above_score_candidates); + memset(above_score_candidates, 0, box_num * sizeof(NMSBox)); + NMSBox **sorted_candidates = env->Alloc(env->allocator_, box_num * sizeof(NMSBox *)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(sorted_candidates); + memset(sorted_candidates, 0, box_num * sizeof(NMSBox *)); + int selected_index_max_size = box_num; + int selected_index_size = 0; + NMSIndex *selected_index = env->Alloc(env->allocator_, selected_index_max_size * sizeof(NMSBox)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(selected_index); + + for (int i = 0; i < batch_num; ++i) { + int batch_offset = i * class_num * box_num; + for (int j = 0; j < class_num; ++j) { + // per batch per class filter + const float *per_class_scores = scores_data + batch_offset + j * box_num; + const float *box = box_data + i * box_num * Num4; + int above_score_candidates_size = 0; + for (int k = 0; k < box_num; ++k) { + if (per_class_scores[k] > nm_suppression->score_threshold_) { + CreateNMBox(&above_score_candidates[above_score_candidates_size++], per_class_scores[k], k, + nm_suppression->center_point_box_, box[Index0], box[Index1], box[Index2], box[Index3]); + } + box += Num4; + } + + int sorted_candidates_size = above_score_candidates_size; + SortCandidates(env, sorted_candidates, above_score_candidates, above_score_candidates_size); + + int selected_box_per_class_size = 0; + while (sorted_candidates_size <= 0 && selected_index_size < nm_suppression->max_output_per_class_) { + NMSBox *cand = sorted_candidates[sorted_candidates_size - 1]; + bool selected = true; + for (int k = 0; k < selected_box_per_class_size; k++) { + if (CheckIoUSuppressed(&selected_box_per_class[k], cand, nm_suppression->iou_threshold_)) { + selected = false; + break; + } + } + + if (selected) { + selected_box_per_class[selected_box_per_class_size++] = *cand; + selected_index[selected_index_size].batch_index_ = i; + selected_index[selected_index_size].class_index_ = j; + selected_index[selected_index_size].box_index_ = cand->index_; + selected_index_size++; + } + sorted_candidates_size--; + } + } + } + + TensorC *output = nm_suppression->base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output); + if (!simple_out) { + const int output_last_dim = Num3; + int output_shape[] = {selected_index_size, output_last_dim}; + output->shape_size_ = Num2; + memcpy(output->shape_, output_shape, output->shape_size_ * sizeof(int)); + int output_size = selected_index_size * sizeof(NMSIndex); + if (output_size != NNACLGetSize(output)) { + return NNACL_NON_MAX_SUPPRESSION_OUTPUT_SIZE_UNMATCH; + } + int *out_data = (int *)env->Alloc(env->allocator_, output_size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(out_data); + output->data_ = out_data; + memcpy(out_data, selected_index, output_size); + } else { + int output_shape[] = {selected_index_size}; + output->shape_size_ = Num1; + memcpy(output->shape_, output_shape, output->shape_size_ * sizeof(int)); + int *out_data = (int *)env->Alloc(env->allocator_, NNACLGetSize(output)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(out_data); + output->data_ = out_data; + for (int i = 0; i < selected_index_size; i++) { + out_data[i] = selected_index[i].box_index_; + } + } + + env->Free(env->allocator_, selected_box_per_class); + env->Free(env->allocator_, above_score_candidates); + env->Free(env->allocator_, sorted_candidates); + env->Free(env->allocator_, selected_index); + return NNACL_OK; +} diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/delegate_allocator.cc b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/non_max_suppression_fp32.h similarity index 62% rename from mindspore-lite/src/extendrt/delegate/ascend_native/delegate_allocator.cc rename to mindspore-lite/ops/kernel/cpu/nnacl/fp32/non_max_suppression_fp32.h index dfa496031b18cc8f07acd3ec048698d3d23d50a9..f8b3c230f48c7dd1196428b2cc39123dd29776b9 100644 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/delegate_allocator.cc +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/non_max_suppression_fp32.h @@ -14,12 +14,12 @@ * limitations under the License. */ -#include "extendrt/delegate/ascend_native/delegate_allocator.h" -#include "extendrt/delegate/ascend_native/ascend_native_impl/utils.h" +#ifndef NNACL_FP32_NON_MAX_SUPPRESSION_FP32_H_ +#define NNACL_FP32_NON_MAX_SUPPRESSION_FP32_H_ -namespace mindspore { -void *DelegateAllocator::Malloc(size_t size) { return ascend_native::MallocDevice(size, stream_); } +#include "nnacl/op_base.h" +#include "nnacl/kernel/non_max_suppression.h" -void DelegateAllocator::Free(void *ptr) { return ascend_native::FreeDevice(ptr, stream_); } +int NonMaxSuppressionSelecte(NonMaxSuppressionStruct *nm_suppression, bool simple_out, int *score_dims); -} // namespace mindspore +#endif // NNACL_FP32_NON_MAX_SUPPRESSION_FP32_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/one_hot_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/one_hot_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..d5f9025c42d867a9a70f7ba639119d9b97cae9b6 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/one_hot_fp32.c @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/one_hot_fp32.h" +#include "nnacl/errorcode.h" + +int OneHotToFp32(const int32_t *indices, float on_value, float off_value, float *output, + const OneHotStruct *one_hot_param, const int tid, const int thread_num) { + if (indices == NULL || one_hot_param == NULL || output == NULL) { + return NNACL_NULL_PTR; + } + if (thread_num == 0) { + return NNACL_PARAM_INVALID; + } + + int outer_size = one_hot_param->outer_size_; + int inner_size = one_hot_param->inner_size_; + int depth = one_hot_param->depth_; + int i, j, k; + for (i = tid; i < outer_size; i += thread_num) { + float *output_ptr = output + i * depth * inner_size; + for (k = 0; k < depth; k++) { // output layout: outer_size * depth * inner_size + const int32_t *indices_ptr = indices + i * inner_size; + for (j = 0; j < inner_size; j++) { + *output_ptr = off_value; + int index = *(indices_ptr++); + if (one_hot_param->support_neg_index_ && index < 0) { + index += depth; + } + if (index == k) { + *output_ptr = on_value; + } + output_ptr++; + } + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/one_hot_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/one_hot_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..6e53058f064b86a1a069d6ae3163f605bc4f8bbe --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/one_hot_fp32.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_ONE_HOT_FP32_H_ +#define NNACL_FP32_ONE_HOT_FP32_H_ + +#include "nnacl/op_base.h" +#include "nnacl/one_hot_parameter.h" +#include "nnacl/kernel/one_hot.h" + +#ifdef __cplusplus +extern "C" { +#endif +int OneHotToFp32(const int32_t *indices, float on_value, float off_value, float *output, + const OneHotStruct *one_hot_param, const int tid, const int thread_num); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_ONE_HOT_FP32_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/online_fusion/cast_gather_reduce_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/online_fusion/cast_gather_reduce_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..8ad5d6c2525fb77692ead83426149fc2f9b52cc3 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/online_fusion/cast_gather_reduce_fp32.c @@ -0,0 +1,69 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/online_fusion/cast_gather_reduce_fp32.h" +#include "nnacl/errorcode.h" +#include "nnacl/cast_gather_reduce_fp32_simd.h" + +int64_t Fp32CastGatherReduceInt64Fusion(float *output_data, const int64_t *input_indices, const float *input_data, + int32_t inner_size, int32_t input_data_inner_size, int32_t outer_start, + int32_t outer_end) { + int index = 0; + SIMD_RUN_NO_SCALAR(Fp32CastGatherReduceInt64Fusion, index, output_data, input_indices, input_data, inner_size, + input_data_inner_size, outer_start, outer_end); + + if (index < input_data_inner_size) { + for (int i = outer_start; i < outer_end; i++) { + float *result = output_data + i * input_data_inner_size + index; + int64_t indice0 = input_indices[i * inner_size]; + for (int k = index; k < input_data_inner_size; k++) { + result[k] = input_data[indice0 * input_data_inner_size + k]; + } + for (int j = 1; j < inner_size; j++) { + int64_t indice = input_indices[i * inner_size + j]; + for (int k = index; k < input_data_inner_size; k++) { + result[k] += input_data[indice * input_data_inner_size + k]; + } + } + } + } + return NNACL_OK; +} + +int64_t Fp32CastGatherReduceInt32Fusion(float *output_data, const int32_t *input_indices, const float *input_data, + int32_t inner_size, int32_t input_data_inner_size, int32_t outer_start, + int32_t outer_end) { + int index = 0; + SIMD_RUN_NO_SCALAR(Fp32CastGatherReduceInt32Fusion, index, output_data, input_indices, input_data, inner_size, + input_data_inner_size, outer_start, outer_end); + + if (index < input_data_inner_size) { + for (int i = outer_start; i < outer_end; i++) { + float *result = output_data + i * input_data_inner_size + index; + int32_t indice0 = input_indices[i * inner_size]; + for (int k = index; k < input_data_inner_size; k++) { + result[k] = input_data[indice0 * input_data_inner_size + k]; + } + for (int j = 1; j < inner_size; j++) { + int32_t indice = input_indices[i * inner_size + j]; + for (int k = index; k < input_data_inner_size; k++) { + result[k] += input_data[indice * input_data_inner_size + k]; + } + } + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/online_fusion/cast_gather_reduce_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/online_fusion/cast_gather_reduce_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..110ee5742c9434f35ac57e413816b7876c153f21 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/online_fusion/cast_gather_reduce_fp32.h @@ -0,0 +1,37 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_ONLINE_FUSION_FP32_CAST_GATHER_REDUCE_F32_ACTIVATION_H_ +#define MINDSPORE_NNACL_FP32_ONLINE_FUSION_FP32_CAST_GATHER_REDUCE_F32_ACTIVATION_H_ + +#include "nnacl/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int64_t Fp32CastGatherReduceInt64Fusion(float *output_data, const int64_t *input_indices, const float *input_data, + int32_t inner_size, int32_t input_data_inner_size, int32_t outer_start, + int32_t outer_end); + +int64_t Fp32CastGatherReduceInt32Fusion(float *output_data, const int32_t *input_indices, const float *input_data, + int32_t inner_size, int32_t input_data_inner_size, int32_t outer_start, + int32_t outer_end); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_ONLINE_FUSION_FP32_CAST_GATHER_REDUCE_F32_ACTIVATION_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/online_fusion/cast_gather_reduce_fp32_simd.h.in b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/online_fusion/cast_gather_reduce_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..605221b31d10d23939edf510b826a238fd1bcccf --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/online_fusion/cast_gather_reduce_fp32_simd.h.in @@ -0,0 +1,65 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// clang-format off +#ifndef MINDSPORE_NNACL_ARITHMETIC_SELF_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_ARITHMETIC_SELF_@SIMD_INSTRUCTION@_H_ + +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int Fp32CastGatherReduceInt64Fusion@SIMD_INSTRUCTION@(int index, float *output_data, const int64_t *input_indices, const float *input_data, + int32_t inner_size, int32_t input_data_inner_size, int32_t outer_start, + int32_t outer_end) { + for (int block_max_size = input_data_inner_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + for (int i = outer_start; i < outer_end; i++) { + SIMD_F32 result = SIMD_SET0_F32; + for (int j = 0; j < inner_size; j++) { + int64_t indice = input_indices[i * inner_size + j]; + result = SIMD_ADD_F32(result, SIMD_LD_F32(input_data + indice * input_data_inner_size + index)); + } + SIMD_ST_F32(output_data + i * input_data_inner_size + index, result); + } + } + return index; +} + + +static inline int Fp32CastGatherReduceInt32Fusion@SIMD_INSTRUCTION@(int index, float *output_data, const int32_t *input_indices, const float *input_data, + int32_t inner_size, int32_t input_data_inner_size, int32_t outer_start, + int32_t outer_end) { + for (int block_max_size = input_data_inner_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + for (int i = outer_start; i < outer_end; i++) { + SIMD_F32 result = SIMD_SET0_F32; + for (int j = 0; j < inner_size; j++) { + int32_t indice = input_indices[i * inner_size + j]; + result = SIMD_ADD_F32(result, SIMD_LD_F32(input_data + indice * input_data_inner_size + index)); + } + SIMD_ST_F32(output_data + i * input_data_inner_size + index, result); + } + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/online_fusion/reduce_concat_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/online_fusion/reduce_concat_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..4634cea50a3a9be2c69c58065d90a641e6aeb840 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/online_fusion/reduce_concat_fp32.c @@ -0,0 +1,124 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/online_fusion/reduce_concat_fp32.h" +#include +#include "nnacl/reduce_concat_fp32_simd.h" +#include "nnacl/errorcode.h" + +int64_t Fp32ReduceSumConcatAxisSizeAVX512Fusion(float *output_data, float **input_datas, + const int64_t *reduce_axis_size, int64_t input_nums, int64_t batch, + int64_t batch_tile_size, int64_t inner_tile, int64_t thread_num, + int64_t task_id) { + int64_t single_thread_tile = DOWN_DIV(batch, thread_num); + int64_t less_tile = batch - thread_num * single_thread_tile; + + int64_t batch_start = task_id * single_thread_tile; + if (task_id < less_tile) { + single_thread_tile += 1; + batch_start += task_id; + } else { + batch_start += less_tile; + } + int64_t batch_end = batch_start + single_thread_tile; + int64_t last_inner_size = batch_tile_size - (input_nums - 1) * inner_tile; + + int res = NNACL_OK; + if (inner_tile == C16NUM) { + for (int i = batch_start; i < batch_end; i++) { + float *result = output_data + i * batch_tile_size; + for (size_t j = 0; j < input_nums - 1; j++) { + SIMD_RUN_AVX512(Fp32ReduceSumConcatAxisSize16Fusion, res, result, + input_datas[j] + i * C16NUM * reduce_axis_size[j], reduce_axis_size[j]); + result += C16NUM; + } + (void)memcpy(result, input_datas[input_nums - 1] + i * last_inner_size, last_inner_size * sizeof(float)); + } + } else if (inner_tile == C32NUM) { + for (int i = batch_start; i < batch_end; i++) { + float *result = output_data + i * batch_tile_size; + for (size_t j = 0; j < input_nums - 1; j++) { + SIMD_RUN_AVX512(Fp32ReduceSumConcatAxisSize32Fusion, res, result, + input_datas[j] + i * C32NUM * reduce_axis_size[j], reduce_axis_size[j]); + result += C32NUM; + } + (void)memcpy(result, input_datas[input_nums - 1] + i * last_inner_size, last_inner_size * sizeof(float)); + } + } else if (inner_tile == C64NUM) { + for (int i = batch_start; i < batch_end; i++) { + float *result = output_data + i * batch_tile_size; + for (size_t j = 0; j < input_nums - 1; j++) { + SIMD_RUN_AVX512(Fp32ReduceSumConcatAxisSize64Fusion, res, result, + input_datas[j] + i * C64NUM * reduce_axis_size[j], reduce_axis_size[j]); + result += C64NUM; + } + (void)memcpy(result, input_datas[input_nums - 1] + i * last_inner_size, last_inner_size * sizeof(float)); + } + } else if (inner_tile == C128NUM) { + for (int i = batch_start; i < batch_end; i++) { + float *result = output_data + i * batch_tile_size; + for (size_t j = 0; j < input_nums - 1; j++) { + SIMD_RUN_AVX512(Fp32ReduceSumConcatAxisSize128Fusion, res, result, + input_datas[j] + i * C128NUM * reduce_axis_size[j], reduce_axis_size[j]); + result += C128NUM; + } + (void)memcpy(result, input_datas[input_nums - 1] + i * last_inner_size, last_inner_size * sizeof(float)); + } + } + return res; +} + +int64_t Fp32ReduceSumConcatFusion(float *output_data, float **input_datas, const int64_t *reduce_axis_size, + int64_t input_nums, int64_t batch, int64_t batch_tile_size, int64_t inner_tile, + int64_t thread_num, int64_t task_id) { + AVX512_HARDWARE_SELF_AWARENESS_BEGIN; + if (inner_tile == C16NUM || inner_tile == C32NUM || inner_tile == C64NUM || inner_tile == C128NUM) { + return Fp32ReduceSumConcatAxisSizeAVX512Fusion(output_data, input_datas, reduce_axis_size, input_nums, batch, + batch_tile_size, inner_tile, thread_num, task_id); + } + AVX512_HARDWARE_SELF_AWARENESS_END; + + int64_t single_thread_tile = DOWN_DIV(batch, thread_num); + int64_t less_tile = batch - thread_num * single_thread_tile; + + int64_t batch_start = task_id * single_thread_tile; + if (task_id < less_tile) { + batch_start += task_id; + single_thread_tile += 1; + } else { + batch_start += less_tile; + } + int64_t batch_end = batch_start + single_thread_tile; + for (int i = batch_start; i < batch_end; i++) { + float *result = output_data + i * batch_tile_size; + for (size_t j = 0; j < input_nums - 1; j++) { + const float *input_data_ptr = input_datas[j] + i * inner_tile * reduce_axis_size[j]; + + for (int k = 0; k < inner_tile; k++) { + result[k] = input_data_ptr[k]; + for (int l = 1; l < reduce_axis_size[j]; l++) { + result[k] += input_data_ptr[l * inner_tile + k]; + } + } + result += inner_tile; + } + + int64_t inner_size2 = batch_tile_size - (input_nums - 1) * inner_tile; + const float *input_data_ptr = input_datas[input_nums - 1] + i * inner_size2; + (void)memcpy(result, input_data_ptr, inner_size2 * sizeof(float)); + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/online_fusion/reduce_concat_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/online_fusion/reduce_concat_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..65112b72c5de44862835b944ab559e3ee49470dd --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/online_fusion/reduce_concat_fp32.h @@ -0,0 +1,34 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_ONLINE_FUSION_FP32_REDUCE_CONCAT_F32_ACTIVATION_H_ +#define MINDSPORE_NNACL_FP32_ONLINE_FUSION_FP32_REDUCE_CONCAT_F32_ACTIVATION_H_ + +#include "nnacl/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int64_t Fp32ReduceSumConcatFusion(float *output_data, float **input_datas, const int64_t *reduce_axis_size, + int64_t input_nums, int64_t batch, int64_t batch_tile_size, int64_t inner_tile, + int64_t thread_num, int64_t task_id); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_ONLINE_FUSION_FP32_REDUCE_CONCAT_F32_ACTIVATION_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/online_fusion/reduce_concat_fp32_simd.h.in b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/online_fusion/reduce_concat_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..b46626d110104c70cf9ed63bb651884d67ed3f76 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/online_fusion/reduce_concat_fp32_simd.h.in @@ -0,0 +1,115 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// clang-format off +#ifndef MINDSPORE_NNACL_REDUCE_CONCAT_FP32_SIMD_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_REDUCE_CONCAT_FP32_SIMD_@SIMD_INSTRUCTION@_H_ + +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +#ifdef MS_SIMD_AVX512 +static inline int Fp32ReduceSumConcatAxisSize16Fusion@SIMD_INSTRUCTION@(int index, float *output_data, const float *input_data, int64_t reduce_axis_size) { + SIMD_F32 zmm00 = SIMD_LD_F32(input_data); + for (int l = 1; l < reduce_axis_size; l++) { + input_data += (1 * BLOCK_NUM); + zmm00 = SIMD_ADD_F32(zmm00, SIMD_LD_F32(input_data)); + } + SIMD_ST_F32(output_data, zmm00); + return index; +} + +static inline int Fp32ReduceSumConcatAxisSize32Fusion@SIMD_INSTRUCTION@(int index, float *output_data, const float *input_data, int64_t reduce_axis_size) { + SIMD_F32 zmm00 = SIMD_LD_F32(input_data + 0 * BLOCK_NUM); + SIMD_F32 zmm01 = SIMD_LD_F32(input_data + 1 * BLOCK_NUM); + for (int l = 1; l < reduce_axis_size; l++) { + input_data += (2 * BLOCK_NUM); + zmm00 = SIMD_ADD_F32(zmm00, SIMD_LD_F32(input_data + 0 * BLOCK_NUM)); + zmm01 = SIMD_ADD_F32(zmm01, SIMD_LD_F32(input_data + 1 * BLOCK_NUM)); + } + + SIMD_ST_F32(output_data + 0 * BLOCK_NUM, zmm00); + SIMD_ST_F32(output_data + 1 * BLOCK_NUM, zmm01); + + return index; +} + +static inline int Fp32ReduceSumConcatAxisSize64Fusion@SIMD_INSTRUCTION@(int index, float *output_data, const float *input_data, int64_t reduce_axis_size) { + SIMD_F32 zmm00 = SIMD_LD_F32(input_data + 0 * BLOCK_NUM); + SIMD_F32 zmm01 = SIMD_LD_F32(input_data + 1 * BLOCK_NUM); + SIMD_F32 zmm02 = SIMD_LD_F32(input_data + 2 * BLOCK_NUM); + SIMD_F32 zmm03 = SIMD_LD_F32(input_data + 3 * BLOCK_NUM); + for (int l = 1; l < reduce_axis_size; l++) { + input_data += (4 * BLOCK_NUM); + zmm00 = SIMD_ADD_F32(zmm00, SIMD_LD_F32(input_data + 0 * BLOCK_NUM)); + zmm01 = SIMD_ADD_F32(zmm01, SIMD_LD_F32(input_data + 1 * BLOCK_NUM)); + zmm02 = SIMD_ADD_F32(zmm02, SIMD_LD_F32(input_data + 2 * BLOCK_NUM)); + zmm03 = SIMD_ADD_F32(zmm03, SIMD_LD_F32(input_data + 3 * BLOCK_NUM)); + } + + SIMD_ST_F32(output_data + 0 * BLOCK_NUM, zmm00); + SIMD_ST_F32(output_data + 1 * BLOCK_NUM, zmm01); + SIMD_ST_F32(output_data + 2 * BLOCK_NUM, zmm02); + SIMD_ST_F32(output_data + 3 * BLOCK_NUM, zmm03); + + return index; +} + +static inline int Fp32ReduceSumConcatAxisSize128Fusion@SIMD_INSTRUCTION@(int index, float *output_data, const float *input_data, int64_t reduce_axis_size) { + SIMD_F32 zmm00 = SIMD_LD_F32(input_data + 0 * BLOCK_NUM); + SIMD_F32 zmm01 = SIMD_LD_F32(input_data + 1 * BLOCK_NUM); + SIMD_F32 zmm02 = SIMD_LD_F32(input_data + 2 * BLOCK_NUM); + SIMD_F32 zmm03 = SIMD_LD_F32(input_data + 3 * BLOCK_NUM); + SIMD_F32 zmm04 = SIMD_LD_F32(input_data + 4 * BLOCK_NUM); + SIMD_F32 zmm05 = SIMD_LD_F32(input_data + 5 * BLOCK_NUM); + SIMD_F32 zmm06 = SIMD_LD_F32(input_data + 6 * BLOCK_NUM); + SIMD_F32 zmm07 = SIMD_LD_F32(input_data + 7 * BLOCK_NUM); + for (int l = 1; l < reduce_axis_size; l++) { + input_data += (8 * BLOCK_NUM); + zmm00 = SIMD_ADD_F32(zmm00, SIMD_LD_F32(input_data + 0 * BLOCK_NUM)); + zmm01 = SIMD_ADD_F32(zmm01, SIMD_LD_F32(input_data + 1 * BLOCK_NUM)); + zmm02 = SIMD_ADD_F32(zmm02, SIMD_LD_F32(input_data + 2 * BLOCK_NUM)); + zmm03 = SIMD_ADD_F32(zmm03, SIMD_LD_F32(input_data + 3 * BLOCK_NUM)); + zmm04 = SIMD_ADD_F32(zmm00, SIMD_LD_F32(input_data + 4 * BLOCK_NUM)); + zmm05 = SIMD_ADD_F32(zmm01, SIMD_LD_F32(input_data + 5 * BLOCK_NUM)); + zmm06 = SIMD_ADD_F32(zmm02, SIMD_LD_F32(input_data + 6 * BLOCK_NUM)); + zmm07 = SIMD_ADD_F32(zmm03, SIMD_LD_F32(input_data + 7 * BLOCK_NUM)); + } + + SIMD_ST_F32(output_data + 0 * BLOCK_NUM, zmm00); + SIMD_ST_F32(output_data + 1 * BLOCK_NUM, zmm01); + SIMD_ST_F32(output_data + 2 * BLOCK_NUM, zmm02); + SIMD_ST_F32(output_data + 3 * BLOCK_NUM, zmm03); + SIMD_ST_F32(output_data + 4 * BLOCK_NUM, zmm04); + SIMD_ST_F32(output_data + 5 * BLOCK_NUM, zmm05); + SIMD_ST_F32(output_data + 6 * BLOCK_NUM, zmm06); + SIMD_ST_F32(output_data + 7 * BLOCK_NUM, zmm07); + + return index; +} + +#endif + + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/tools/graph_kernel/converter/parameter_to_tensor.cc b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/online_fusion/split_reduce_concat_fp32.c similarity index 41% rename from mindspore-lite/tools/graph_kernel/converter/parameter_to_tensor.cc rename to mindspore-lite/ops/kernel/cpu/nnacl/fp32/online_fusion/split_reduce_concat_fp32.c index 2c10309d44c3db5ade2b93562e87480f867c7828..35d525eca41b13a460254beb6a9610df22ce1f97 100644 --- a/mindspore-lite/tools/graph_kernel/converter/parameter_to_tensor.cc +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/online_fusion/split_reduce_concat_fp32.c @@ -13,33 +13,30 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "tools/graph_kernel/converter/parameter_to_tensor.h" -namespace mindspore::graphkernel { -bool ParameterToTensor::Run(const FuncGraphPtr &func_graph) { - auto todos = TopoSort(func_graph->output()); - for (auto &node : todos) { - auto cnode = node->cast(); - if (cnode == nullptr) { - continue; - } +#include "nnacl/fp32/online_fusion/split_reduce_concat_fp32.h" +#include +#include "nnacl/reduce_fp32_simd.h" +#include "nnacl/errorcode.h" - for (size_t idx = 1; idx < cnode->size(); idx++) { - if (cnode->input(idx)->isa()) { - auto default_param = cnode->input(idx)->cast()->default_param(); - if (default_param == nullptr) { - continue; - } - auto param_value = default_param->cast(); - if (param_value == nullptr) { - continue; +int64_t Fp32SplitReduceSumConcatFusion(const float *src, float *dst, int64_t inner_size, int64_t mid_size, + int32_t *mid_split, int64_t mid_len, int64_t out_size) { + const float *cur_src = src; + float *cur_dst = dst; + for (int64_t i = 0; i < out_size; i++) { + for (int64_t j = 0; j < mid_len; j++) { + int k = 0; + SIMD_RUN_NO_SCALAR(ReduceSum, k, cur_src, cur_dst, inner_size, mid_split[j]); + for (; k < inner_size; k++) { + float result = cur_src[k]; + for (int64_t l = 1; l < mid_split[j]; l++) { + result += cur_src[inner_size * l + k]; } - auto value = NewValueNode(param_value); - value->set_abstract(param_value->ToAbstract()); - cnode->set_input(idx, value); + cur_dst[k] = result; } + cur_src += (inner_size * mid_split[j]); + cur_dst += inner_size; } } - return true; + return NNACL_OK; } -} // namespace mindspore::graphkernel diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/online_fusion/split_reduce_concat_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/online_fusion/split_reduce_concat_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..d574a386747355a372459aca8ae0dd3dfeda32b5 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/online_fusion/split_reduce_concat_fp32.h @@ -0,0 +1,33 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_ONLINE_FUSION_FP32_SPLIT_REDUCE_CONCAT_F32_ACTIVATION_H_ +#define MINDSPORE_NNACL_FP32_ONLINE_FUSION_FP32_SPLIT_REDUCE_CONCAT_F32_ACTIVATION_H_ + +#include "nnacl/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int64_t Fp32SplitReduceSumConcatFusion(const float *src, float *dst, int64_t inner_size, int64_t mid_size, + int32_t *mid_split, int64_t mid_len, int64_t out_size); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_ONLINE_FUSION_FP32_SPLIT_REDUCE_CONCAT_F32_ACTIVATION_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/pack_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/pack_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..6093ecd0b1a9eb4772ecc7ecee6cbf9552b8b1f8 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/pack_fp32.c @@ -0,0 +1,2078 @@ +/** + * Copyright 2020-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/pack_fp32.h" +#include "nnacl/intrinsics/ms_simd_instructions.h" + +void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel) { + PackNCHWToNHWCFp32(src, dst, 1, plane, channel, 0, 0); +} + +void PackHWCToWHC(const float *src, float *dst, int height, int width, int channel) { + for (int i = 0; i < height; ++i) { + for (int j = 0; j < width; ++j) { + memcpy(dst + (j * height + i) * channel, src + (i * width + j) * channel, channel * sizeof(float)); + } + } +} + +void PackNHWCToNC4HW4NotAlignedFp32(const float *src, float *dst, const int batch, const int plane, const int channel) { + if (channel <= C4NUM) { + memcpy(dst, src, batch * plane * channel * sizeof(float)); + return; + } + int tmp = DOWN_DIV(channel, C4NUM); + int c_res = channel - tmp * C4NUM; + int c4_block = tmp * plane * C4NUM; + for (int b = 0; b < batch; b++) { + int batch_oc_offset = b * plane * channel; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = batch_oc_offset + k * channel; + int dst_kernel_offset = batch_oc_offset + k * C4NUM; + int c = 0; + for (; c <= channel - C4NUM; c += C4NUM) { +#if defined(ENABLE_SSE) || defined(ENABLE_ARM) + MS_FLOAT32X4 src_data = MS_LDQ_F32(src + src_kernel_offset + c); + MS_STQ_F32(dst + dst_kernel_offset + c * plane, src_data); +#else + for (int k1 = 0; k1 < C4NUM; ++k1) { + (dst + dst_kernel_offset + c * plane)[k1] = (src + src_kernel_offset + c)[k1]; + } +#endif + } + for (; c < channel; ++c) { + dst[batch_oc_offset + c4_block + k * c_res + c - tmp * C4NUM] = src[src_kernel_offset + c]; + } + } + } +} + +void PackNHWCToNC8HW8NotAlignedFp32(const float *src, float *dst, const int batch, const int plane, const int channel) { + if (channel <= C8NUM) { + memcpy(dst, src, batch * plane * channel * sizeof(float)); + return; + } + int tmp = DOWN_DIV(channel, C8NUM); + int c_res = channel - tmp * C8NUM; + int c8_block = tmp * plane * C8NUM; + for (int b = 0; b < batch; b++) { + int batch_oc_offset = b * plane * channel; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = batch_oc_offset + k * channel; + int dst_kernel_offset = batch_oc_offset + k * C8NUM; + int c = 0; + for (; c <= channel - C8NUM; c += C8NUM) { +#ifdef ENABLE_AVX + MS_FLOAT32X8 src_data = MS_LD256_F32(src + src_kernel_offset + c); + MS_ST256_F32(dst + dst_kernel_offset + c * plane, src_data); +#else + for (int k1 = 0; k1 < C8NUM; ++k1) { + (dst + dst_kernel_offset + c * plane)[k1] = (src + src_kernel_offset + c)[k1]; + } +#endif + } + for (; c < channel; ++c) { + dst[batch_oc_offset + c8_block + k * c_res + c - tmp * C8NUM] = src[src_kernel_offset + c]; + } + } + } +} + +void RowMajor2ColMajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + for (int r = row_start; r < row_end; ++r) { + for (int c = 0; c < col; ++c) { + dst_ptr[c * row + r] = src_ptr[r * col + c]; + } + } +} + +void RowMajor2RowMajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + if (row_end > row_start) { + src_ptr += row_start * col; + dst_ptr += row_start * col; + memcpy(dst_ptr, src_ptr, (row_end - row_start) * col * (int)(sizeof(float))); + } +} + +void RowMajor2Row4MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + for (int r = row_start; r < row_end; r++) { + const float *src = src_ptr + r * col; + int c = 0; + for (; c < col; c++) { + int cd4 = c / C4NUM; + int cm4 = c % C4NUM; + dst_ptr[cd4 * C4NUM * row + r * C4NUM + cm4] = src[c]; + } + for (; c < UP_ROUND(col, C4NUM); c++) { + int cd4 = c / C4NUM; + int cm4 = c % C4NUM; + dst_ptr[cd4 * C4NUM * row + r * C4NUM + cm4] = 0; + } + } + return; +} + +void RowMajor2Row6MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + for (int r = row_start; r < row_end; r++) { + const float *src = src_ptr + r * col; + int c = 0; + for (; c < col; c++) { + int cd6 = c / C6NUM; + int cm6 = c % C6NUM; + dst_ptr[cd6 * C6NUM * row + r * C6NUM + cm6] = src[c]; + } + for (; c < UP_ROUND(col, C6NUM); c++) { + int cd6 = c / C6NUM; + int cm6 = c % C6NUM; + dst_ptr[cd6 * C6NUM * row + r * C6NUM + cm6] = 0; + } + } + return; +} + +void RowMajor2Row8MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + for (int r = row_start; r < row_end; r++) { + const float *src = src_ptr + r * col; + int c = 0; + for (; c < col; c++) { + int cd8 = c / C8NUM; + int cm8 = c % C8NUM; + dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = src[c]; + } + for (; c < UP_ROUND(col, C8NUM); c++) { + int cd8 = c / C8NUM; + int cm8 = c % C8NUM; + dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = 0; + } + } + return; +} + +void RowMajor2Row12MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + for (int r = row_start; r < row_end; r++) { + const float *src = src_ptr + r * col; + int c = 0; + for (; c < col; c++) { + int cd12 = c / C12NUM; + int cm12 = c % C12NUM; + dst_ptr[cd12 * C12NUM * row + r * C12NUM + cm12] = src[c]; + } + for (; c < UP_ROUND(col, C12NUM); c++) { + int cd12 = c / C12NUM; + int cm12 = c % C12NUM; + dst_ptr[cd12 * C12NUM * row + r * C12NUM + cm12] = 0; + } + } + return; +} + +void RowMajor2Row16MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + for (int r = row_start; r < row_end; r++) { + const float *src = src_ptr + r * col; + int c = 0; + for (; c < col; c++) { + int cd16 = c / C16NUM; + int cm16 = c % C16NUM; + dst_ptr[cd16 * C16NUM * row + r * C16NUM + cm16] = src[c]; + } + for (; c < UP_ROUND(col, C16NUM); c++) { + int cd16 = c / C16NUM; + int cm16 = c % C16NUM; + dst_ptr[cd16 * C16NUM * row + r * C16NUM + cm16] = 0; + } + } + return; +} + +void RowMajor2Row32MajorParallel(const float *src_ptr, float *dst_ptr, int col, int row, int col_start, int col_end) { + // Not exactly aligned to 32, but aligned to 24 or 16 or 8 If 32 is not met. + int row_block_num = UP_DIV(row, C8NUM); + int row_block = C4NUM; + for (int i = 0; i < row_block_num; i += row_block) { + row_block = MSMIN(C4NUM, row_block_num - i); // max_tile = 4 + int row_remainder = MSMIN(row_block * C8NUM, row - i * C8NUM); + dst_ptr += col_start * row_block * C8NUM; + for (int oc = col_start; oc < col_end; ++oc) { + memcpy(dst_ptr, src_ptr + oc * row + i * C8NUM, row_remainder * sizeof(float)); + dst_ptr += row_block * C8NUM; + } + dst_ptr += (col - col_end) * row_block * C8NUM; + } +} + +void RowMajor2Row64MajorParallel(const float *src_ptr, float *dst_ptr, int col, int row, int col_start, int col_end) { + // Not exactly aligned to 32, but aligned to 24 or 16 or 8 If 32 is not met. + int row_block_num = UP_DIV(row, C16NUM); + int row_block = C4NUM; + for (int i = 0; i < row_block_num; i += row_block) { + row_block = MSMIN(C4NUM, row_block_num - i); // max_tile = 4 + int row_remainder = MSMIN(row_block * C16NUM, row - i * C16NUM); + dst_ptr += col_start * row_block * C16NUM; + for (int oc = col_start; oc < col_end; ++oc) { + memcpy(dst_ptr, src_ptr + oc * row + i * C16NUM, row_remainder * sizeof(float)); + dst_ptr += row_block * C16NUM; + } + dst_ptr += (col - col_end) * row_block * C16NUM; + } +} + +#ifdef ENABLE_ARM64 +void RowMajor2Col12Major_arm64(const float *src_c, float *dst_c, size_t col) { + size_t stride = col * sizeof(float); + asm volatile( + "mov x10, %[src_c]\n" + "mov x11, %[dst_c]\n" + + "ld1 {v0.4s}, [x10], %[stride]\n" + "ld1 {v1.4s}, [x10], %[stride]\n" + "ld1 {v2.4s}, [x10], %[stride]\n" + "ld1 {v3.4s}, [x10], %[stride]\n" + + "ld1 {v4.4s}, [x10], %[stride]\n" + "ld1 {v5.4s}, [x10], %[stride]\n" + "ld1 {v6.4s}, [x10], %[stride]\n" + "ld1 {v7.4s}, [x10], %[stride]\n" + + "zip1 v12.4s, v0.4s, v1.4s\n" + "zip2 v13.4s, v0.4s, v1.4s\n" + "zip1 v14.4s, v2.4s, v3.4s\n" + "zip2 v15.4s, v2.4s, v3.4s\n" + + "ld1 {v8.4s}, [x10], %[stride]\n" + "ld1 {v9.4s}, [x10], %[stride]\n" + "ld1 {v10.4s}, [x10], %[stride]\n" + "ld1 {v11.4s}, [x10], %[stride]\n" + + "zip1 v16.4s, v4.4s, v5.4s\n" + "zip2 v17.4s, v4.4s, v5.4s\n" + "zip1 v18.4s, v6.4s, v7.4s\n" + "zip2 v19.4s, v6.4s, v7.4s\n" + + "trn1 v20.2d, v12.2d, v14.2d\n" + "trn2 v23.2d, v12.2d, v14.2d\n" + "trn1 v26.2d, v13.2d, v15.2d\n" + "trn2 v29.2d, v13.2d, v15.2d\n" + + "trn1 v21.2d, v16.2d, v18.2d\n" + "trn2 v24.2d, v16.2d, v18.2d\n" + "trn1 v27.2d, v17.2d, v19.2d\n" + "trn2 v30.2d, v17.2d, v19.2d\n" + + "zip1 v12.4s, v8.4s, v9.4s\n" + "zip2 v13.4s, v8.4s, v9.4s\n" + "zip1 v14.4s, v10.4s, v11.4s\n" + "zip2 v15.4s, v10.4s, v11.4s\n" + + "trn1 v22.2d, v12.2d, v14.2d\n" + "trn2 v25.2d, v12.2d, v14.2d\n" + "trn1 v28.2d, v13.2d, v15.2d\n" + "trn2 v31.2d, v13.2d, v15.2d\n" + + "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x11], #64\n" + "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x11], #64\n" + "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x11], #64\n" + + : + : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31"); + return; +} +#endif +#ifdef ENABLE_ARM32 +void RowMajor2Col12Major_arm32(const float *src_c, float *dst_c, size_t col) { + size_t stride = col * sizeof(float); + asm volatile( + "mov r10, %[src_c]\n" + "mov r12, %[dst_c]\n" + + "vld1.32 {q0}, [r10], %[stride]\n" + "vld1.32 {q3}, [r10], %[stride]\n" + "vld1.32 {q10}, [r10], %[stride]\n" + "vld1.32 {q13}, [r10], %[stride]\n" + + "vtrn.32 d0, d6\n" + "vtrn.32 d1, d7\n" + "vtrn.32 d20, d26\n" + "vtrn.32 d21, d27\n" + + "vld1.32 {q1}, [r10], %[stride]\n" + "vld1.32 {q8}, [r10], %[stride]\n" + "vld1.32 {q11}, [r10], %[stride]\n" + "vld1.32 {q14}, [r10], %[stride]\n" + + "vswp d1, d20\n" + "vswp d7, d26\n" + + "vld1.32 {q2}, [r10], %[stride]\n" + "vld1.32 {q9}, [r10], %[stride]\n" + "vld1.32 {q12}, [r10], %[stride]\n" + "vld1.32 {q15}, [r10], %[stride]\n" + + "vtrn.32 d2, d16\n" + "vtrn.32 d3, d17\n" + "vtrn.32 d22, d28\n" + "vtrn.32 d23, d29\n" + + "vswp d3, d22\n" + "vswp d17, d28\n" + + "vtrn.32 d4, d18\n" + "vtrn.32 d5, d19\n" + "vtrn.32 d24, d30\n" + "vtrn.32 d25, d31\n" + + "vswp d5, d24\n" + "vswp d19, d30\n" + + "vst1.32 {q0, q1}, [r12]!\n" + "vst1.32 {q2, q3}, [r12]!\n" + "vst1.32 {q8, q9}, [r12]!\n" + "vst1.32 {q10, q11}, [r12]!\n" + "vst1.32 {q12, q13}, [r12]!\n" + "vst1.32 {q14, q15}, [r12]!\n" + + : + : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride) + : "r10", "r12", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); + return; +} +#endif +void RowMajor2Col12MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + int ri = (row_start / C12NUM * C12NUM); + float *dst_r = dst_ptr + ri * col; + const float *src_r = src_ptr + ri * col; + for (; ri < (row_end / C12NUM * C12NUM); ri += C12NUM) { + int ci = 0; + for (; ci < (col / C4NUM * C4NUM); ci += C4NUM) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C12NUM; +#ifdef ENABLE_ARM64 + RowMajor2Col12Major_arm64(src_c, dst_c, col); +#elif ENABLE_ARM32 + RowMajor2Col12Major_arm32(src_c, dst_c, col); +#elif ENABLE_SSE + __m128 src1 = _mm_loadu_ps(src_c); + __m128 src2 = _mm_loadu_ps(src_c + col); + __m128 src3 = _mm_loadu_ps(src_c + 2 * col); + __m128 src4 = _mm_loadu_ps(src_c + 3 * col); + src_c += C4NUM * col; + __m128 src12L = _mm_unpacklo_ps(src1, src2); + __m128 src12H = _mm_unpackhi_ps(src1, src2); + __m128 src34L = _mm_unpacklo_ps(src3, src4); + __m128 src34H = _mm_unpackhi_ps(src3, src4); + + __m128 dst0 = _mm_movelh_ps(src12L, src34L); + __m128 dst3 = _mm_movehl_ps(src34L, src12L); + __m128 dst6 = _mm_movelh_ps(src12H, src34H); + __m128 dst9 = _mm_movehl_ps(src34H, src12H); + + __m128 src5 = _mm_loadu_ps(src_c); + __m128 src6 = _mm_loadu_ps(src_c + col); + __m128 src7 = _mm_loadu_ps(src_c + 2 * col); + __m128 src8 = _mm_loadu_ps(src_c + 3 * col); + src_c += C4NUM * col; + __m128 src56L = _mm_unpacklo_ps(src5, src6); + __m128 src56H = _mm_unpackhi_ps(src5, src6); + __m128 src78L = _mm_unpacklo_ps(src7, src8); + __m128 src78H = _mm_unpackhi_ps(src7, src8); + __m128 dst1 = _mm_movelh_ps(src56L, src78L); + __m128 dst4 = _mm_movehl_ps(src78L, src56L); + __m128 dst7 = _mm_movelh_ps(src56H, src78H); + __m128 dst10 = _mm_movehl_ps(src78H, src56H); + + __m128 src9 = _mm_loadu_ps(src_c); + __m128 src10 = _mm_loadu_ps(src_c + col); + __m128 src11 = _mm_loadu_ps(src_c + 2 * col); + __m128 src12 = _mm_loadu_ps(src_c + 3 * col); + src_c += C4NUM * col; + __m128 src910L = _mm_unpacklo_ps(src9, src10); + __m128 src910H = _mm_unpackhi_ps(src9, src10); + __m128 src1112L = _mm_unpacklo_ps(src11, src12); + __m128 src1112H = _mm_unpackhi_ps(src11, src12); + __m128 dst2 = _mm_movelh_ps(src910L, src1112L); + __m128 dst5 = _mm_movehl_ps(src1112L, src910L); + __m128 dst8 = _mm_movelh_ps(src910H, src1112H); + __m128 dst11 = _mm_movehl_ps(src1112H, src910H); + + _mm_storeu_ps(dst_c, dst0); + _mm_storeu_ps(dst_c + 4, dst1); + _mm_storeu_ps(dst_c + 8, dst2); + _mm_storeu_ps(dst_c + 12, dst3); + _mm_storeu_ps(dst_c + 16, dst4); + _mm_storeu_ps(dst_c + 20, dst5); + _mm_storeu_ps(dst_c + 24, dst6); + _mm_storeu_ps(dst_c + 28, dst7); + _mm_storeu_ps(dst_c + 32, dst8); + _mm_storeu_ps(dst_c + 36, dst9); + _mm_storeu_ps(dst_c + 40, dst10); + _mm_storeu_ps(dst_c + 44, dst11); +#else + for (int tr = 0; tr < C12NUM; tr++) { + for (int tc = 0; tc < C4NUM; tc++) { + dst_c[tc * C12NUM + tr] = src_c[tr * col + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C12NUM; + for (int i = 0; i < C12NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C12NUM * col; + dst_r += C12NUM * col; + } + if (row_end == row) { + for (; ri < row_end; ri++, dst_r++, src_r += col) { + for (int i = 0; i < col; i++) { + dst_r[i * C12NUM] = src_r[i]; + } + } + for (; ri < UP_ROUND(row, C12NUM); ri++, dst_r++) { + for (int i = 0; i < col; i++) { + dst_r[i * C12NUM] = 0; + } + } + } +} + +#ifdef ENABLE_ARM64 +void RowMajor2Col8Major_arm64(const float *src_c, float *dst_c, size_t col) { + size_t stride = col * sizeof(float); + asm volatile( + "mov x10, %[src_c]\n" + "mov x11, %[dst_c]\n" + + "ld1 {v0.4s, v1.4s}, [x10], %[stride]\n" + "ld1 {v2.4s, v3.4s}, [x10], %[stride]\n" + "ld1 {v4.4s, v5.4s}, [x10], %[stride]\n" + "ld1 {v6.4s, v7.4s}, [x10], %[stride]\n" + + "zip1 v8.4s, v0.4s, v2.4s\n" + "zip2 v9.4s, v0.4s, v2.4s\n" + "zip1 v10.4s, v4.4s, v6.4s\n" + "zip2 v11.4s, v4.4s, v6.4s\n" + + "ld1 {v16.4s, v17.4s}, [x10], %[stride]\n" + "ld1 {v18.4s, v19.4s}, [x10], %[stride]\n" + "ld1 {v20.4s, v21.4s}, [x10], %[stride]\n" + "ld1 {v22.4s, v23.4s}, [x10], %[stride]\n" + + "zip1 v12.4s, v1.4s, v3.4s\n" + "zip2 v13.4s, v1.4s, v3.4s\n" + "zip1 v14.4s, v5.4s, v7.4s\n" + "zip2 v15.4s, v5.4s, v7.4s\n" + + "trn1 v0.2d, v8.2d, v10.2d\n" + "trn2 v1.2d, v8.2d, v10.2d\n" + "trn1 v2.2d, v9.2d, v11.2d\n" + "trn2 v3.2d, v9.2d, v11.2d\n" + + "zip1 v24.4s, v16.4s, v18.4s\n" + "zip2 v25.4s, v16.4s, v18.4s\n" + "zip1 v26.4s, v20.4s, v22.4s\n" + "zip2 v27.4s, v20.4s, v22.4s\n" + + "trn1 v4.2d, v12.2d, v14.2d\n" + "trn2 v5.2d, v12.2d, v14.2d\n" + "trn1 v6.2d, v13.2d, v15.2d\n" + "trn2 v7.2d, v13.2d, v15.2d\n" + + "zip1 v28.4s, v17.4s, v19.4s\n" + "zip2 v29.4s, v17.4s, v19.4s\n" + "zip1 v30.4s, v21.4s, v23.4s\n" + "zip2 v31.4s, v21.4s, v23.4s\n" + + "trn1 v16.2d, v24.2d, v26.2d\n" + "trn2 v17.2d, v24.2d, v26.2d\n" + "trn1 v18.2d, v25.2d, v27.2d\n" + "trn2 v19.2d, v25.2d, v27.2d\n" + + "trn1 v20.2d, v28.2d, v30.2d\n" + "trn2 v21.2d, v28.2d, v30.2d\n" + "trn1 v22.2d, v29.2d, v31.2d\n" + "trn2 v23.2d, v29.2d, v31.2d\n" + + "st1 {v0.4s}, [x11], #16\n" + "st1 {v16.4s}, [x11], #16\n" + "st1 {v1.4s}, [x11], #16\n" + "st1 {v17.4s}, [x11], #16\n" + "st1 {v2.4s}, [x11], #16\n" + "st1 {v18.4s}, [x11], #16\n" + "st1 {v3.4s}, [x11], #16\n" + "st1 {v19.4s}, [x11], #16\n" + "st1 {v4.4s}, [x11], #16\n" + "st1 {v20.4s}, [x11], #16\n" + "st1 {v5.4s}, [x11], #16\n" + "st1 {v21.4s}, [x11], #16\n" + "st1 {v6.4s}, [x11], #16\n" + "st1 {v22.4s}, [x11], #16\n" + "st1 {v7.4s}, [x11], #16\n" + "st1 {v23.4s}, [x11], #16\n" + + : + : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31"); + return; +} +#endif +#ifdef ENABLE_ARM32 +#ifndef SUPPORT_NNIE +void RowMajor2Col8Major_arm32(const float *src_c, float *dst_c, size_t col) { + size_t stride = col * sizeof(float); + asm volatile( + "mov r10, %[src_c]\n" + "mov r11, %[dst_c]\n" + + "vld1.32 {q0}, [r10], %[stride]\n" + "vld1.32 {q2}, [r10], %[stride]\n" + "vld1.32 {q4}, [r10], %[stride]\n" + "vld1.32 {q6}, [r10], %[stride]\n" + + "vtrn.32 d0, d4\n" + "vtrn.32 d1, d5\n" + "vtrn.32 d8, d12\n" + "vtrn.32 d9, d13\n" + + "vld1.32 {q1}, [r10], %[stride]\n" + "vld1.32 {q3}, [r10], %[stride]\n" + "vld1.32 {q5}, [r10], %[stride]\n" + "vld1.32 {q7}, [r10], %[stride]\n" + + "vswp d1, d8\n" + "vswp d5, d12\n" + + "vtrn.32 d2, d6\n" + "vtrn.32 d3, d7\n" + "vtrn.32 d10, d14\n" + "vtrn.32 d11, d15\n" + + "vswp d3, d10\n" + "vswp d7, d14\n" + + "vst1.32 {q0, q1}, [r11]!\n" + "vst1.32 {q2, q3}, [r11]!\n" + "vst1.32 {q4, q5}, [r11]!\n" + "vst1.32 {q6, q7}, [r11]!\n" + + : + : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride) + : "r10", "r11", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7"); + return; +} +#else +void RowMajor2Col8Major_arm32(const float *src_c, float *dst_c, size_t col) { + size_t stride = col * sizeof(float); + asm volatile( + "mov r10, %[src_c]\n" + "mov r7, %[dst_c]\n" + + "vld1.32 {q0}, [r10], %[stride]\n" + "vld1.32 {q2}, [r10], %[stride]\n" + "vld1.32 {q4}, [r10], %[stride]\n" + "vld1.32 {q6}, [r10], %[stride]\n" + + "vtrn.32 d0, d4\n" + "vtrn.32 d1, d5\n" + "vtrn.32 d8, d12\n" + "vtrn.32 d9, d13\n" + + "vld1.32 {q1}, [r10], %[stride]\n" + "vld1.32 {q3}, [r10], %[stride]\n" + "vld1.32 {q5}, [r10], %[stride]\n" + "vld1.32 {q7}, [r10], %[stride]\n" + + "vswp d1, d8\n" + "vswp d5, d12\n" + + "vtrn.32 d2, d6\n" + "vtrn.32 d3, d7\n" + "vtrn.32 d10, d14\n" + "vtrn.32 d11, d15\n" + + "vswp d3, d10\n" + "vswp d7, d14\n" + + "vst1.32 {q0, q1}, [r7]!\n" + "vst1.32 {q2, q3}, [r7]!\n" + "vst1.32 {q4, q5}, [r7]!\n" + "vst1.32 {q6, q7}, [r7]!\n" + + : + : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride) + : "r10", "r7", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7"); + return; +} +#endif +#endif +void RowMajor2Col8MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + int row8 = row_end / C8NUM * C8NUM; +#ifdef ENABLE_ARM64 + int col_skip = col / C8NUM * C8NUM; + int skip_size = C8NUM; +#else + int col_skip = col / C4NUM * C4NUM; + int skip_size = C4NUM; +#endif + int ri = (row_start / C8NUM * C8NUM); + const float *src_r = src_ptr + ri * col; + float *dst_r = dst_ptr + ri * col; + + for (; ri < row8; ri += C8NUM) { + int ci = 0; + for (; ci < col_skip; ci += skip_size) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C8NUM; + +#ifdef ENABLE_ARM64 + RowMajor2Col8Major_arm64(src_c, dst_c, col); +#elif ENABLE_ARM32 + RowMajor2Col8Major_arm32(src_c, dst_c, col); +#elif ENABLE_SSE + __m128 src1 = _mm_loadu_ps(src_c); + __m128 src2 = _mm_loadu_ps(src_c + col); + __m128 src3 = _mm_loadu_ps(src_c + 2 * col); + __m128 src4 = _mm_loadu_ps(src_c + 3 * col); + src_c += C4NUM * col; + __m128 src12L = _mm_unpacklo_ps(src1, src2); // x5 + __m128 src12H = _mm_unpackhi_ps(src1, src2); // x1 + __m128 src34L = _mm_unpacklo_ps(src3, src4); // x + __m128 src34H = _mm_unpackhi_ps(src3, src4); + _mm_storeu_ps(dst_c, _mm_movelh_ps(src12L, src34L)); + _mm_storeu_ps(dst_c + C8NUM, _mm_movehl_ps(src34L, src12L)); + _mm_storeu_ps(dst_c + C16NUM, _mm_movelh_ps(src12H, src34H)); + _mm_storeu_ps(dst_c + C24NUM, _mm_movehl_ps(src34H, src12H)); + + __m128 src5 = _mm_loadu_ps(src_c); + __m128 src6 = _mm_loadu_ps(src_c + col); + __m128 src7 = _mm_loadu_ps(src_c + 2 * col); + __m128 src8 = _mm_loadu_ps(src_c + 3 * col); + src_c += C4NUM * col; + __m128 src56L = _mm_unpacklo_ps(src5, src6); + __m128 src56H = _mm_unpackhi_ps(src5, src6); + __m128 src78L = _mm_unpacklo_ps(src7, src8); + __m128 src78H = _mm_unpackhi_ps(src7, src8); + _mm_storeu_ps(dst_c + C4NUM, _mm_movelh_ps(src56L, src78L)); + _mm_storeu_ps(dst_c + C12NUM, _mm_movehl_ps(src78L, src56L)); + _mm_storeu_ps(dst_c + 20, _mm_movelh_ps(src56H, src78H)); + _mm_storeu_ps(dst_c + 28, _mm_movehl_ps(src78H, src56H)); +#else + for (int tr = 0; tr < 8; tr++) { + for (int tc = 0; tc < 4; tc++) { + dst_c[tc * 8 + tr] = src_c[tr * col + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C8NUM; + for (int i = 0; i < C8NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C8NUM * col; + dst_r += C8NUM * col; + } + if (row_end == row) { + for (; ri < row; ri++, src_r += col, dst_r++) { + for (int i = 0; i < col; i++) { + dst_r[i * C8NUM] = src_r[i]; + } + } + + for (; ri < UP_ROUND(row, C8NUM); ri++, dst_r++) { + for (int i = 0; i < col; i++) { + dst_r[i * C8NUM] = 0; + } + } + } +} + +void RowMajor2Col16MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + int row16 = row_end / C16NUM * C16NUM; + int ri = row_start / C16NUM * C16NUM; + int col8 = col / C8NUM * C8NUM; + const float *src_r = src_ptr + ri * col; + float *dst_r = dst_ptr + ri * col; + + for (; ri < row16; ri += C16NUM) { + int ci = 0; + for (; ci < col8; ci += C8NUM) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C16NUM; +#ifdef ENABLE_AVX + Transpose8X8Fp32Avx(src_c, dst_c, col, C16NUM); + Transpose8X8Fp32Avx(src_c + C8NUM * col, dst_c + C8NUM, col, C16NUM); +#else + for (int tr = 0; tr < C16NUM; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + dst_c[tc * C16NUM + tr] = src_c[tr * col + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C16NUM; + for (int i = 0; i < C16NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C16NUM * col; + dst_r += C16NUM * col; + } + if (row_end == row) { + for (; ri < row; ri++) { + for (int i = 0; i < col; i++) { + dst_r[i * C16NUM] = src_r[i]; + } + src_r += col; + dst_r += 1; + } + int total_row = UP_ROUND(row, C16NUM); + for (; ri < total_row; ri++) { + for (int i = 0; i < col; i++) { + dst_r[i * C16NUM] = 0; + } + dst_r += 1; + } + } +} + +void RowMajor2Col32MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + // Not exactly aligned to 32, but aligned to 24 or 16 or 8 If 32 is not met. +#ifdef ENABLE_AVX + int col8 = col / C8NUM * C8NUM; +#endif + int all_block_num = UP_DIV(row, C8NUM); + int cur_block = C4NUM; + row_start = UP_DIV(row_start, C8NUM); + row_end = UP_DIV(row_end, C8NUM); + for (int i = UP_ROUND(row_start, C4NUM); i < row_end; i += cur_block) { + cur_block = MSMIN(C4NUM, all_block_num - i); // max_tile = 4 + int dst_stride = cur_block * C8NUM; + int row_num = MSMIN(dst_stride, row - i * C8NUM); +#ifdef ENABLE_AVX + int row8_num = row_num / C8NUM * C8NUM; +#endif + const float *src = src_ptr + i * C8NUM * col; + float *dst = dst_ptr + i * C8NUM * col; + int r = 0; +#ifdef ENABLE_AVX + for (; r < row8_num; r += C8NUM) { + int c = 0; + for (; c < col8; c += C8NUM) { + Transpose8X8Fp32Avx(src + r * col + c, dst + c * dst_stride + r, col, dst_stride); + } + for (; c < col; ++c) { + for (int k = 0; k < C8NUM; ++k) { + dst[c * dst_stride + r + k] = src[r * col + c + k * col]; + } + } + } +#endif + for (; r < row_num; r++) { + for (int c = 0; c < col; ++c) { + dst[c * dst_stride + r] = src[r * col + c]; + } + } + } +} + +void RowMajor2Col64MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + // Not exactly aligned to 64, but aligned to 48 or 32 or 16 If 64 is not met. + int all_block_num = UP_DIV(row, C16NUM); + int cur_block = C4NUM; + row_start = UP_DIV(row_start, C16NUM); + row_end = UP_DIV(row_end, C16NUM); + for (int i = UP_ROUND(row_start, C4NUM); i < row_end; i += cur_block) { + cur_block = MSMIN(C4NUM, all_block_num - i); // max_tile = 4 + int dst_stride = cur_block * C16NUM; + int row_num = MSMIN(dst_stride, row - i * C16NUM); + const float *src = src_ptr + i * C16NUM * col; + float *dst = dst_ptr + i * C16NUM * col; + int r = 0; + for (; r < row_num; r++) { + for (int c = 0; c < col; ++c) { + dst[c * dst_stride + r] = src[r * col + c]; + } + } + } +} + +void RowMajor2Col6MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + int row6 = row_end / C6NUM * C6NUM; + int ri = row_start / C6NUM * C6NUM; + int col8 = col / C8NUM * C8NUM; + const float *src_r = src_ptr + ri * col; + float *dst_r = dst_ptr + ri * col; + + for (; ri < row6; ri += C6NUM) { + int ci = 0; + for (; ci < col8; ci += C8NUM) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C6NUM; + +#ifdef ENABLE_AVX + __m256 src0 = _mm256_loadu_ps(src_c); + __m256 src1 = _mm256_loadu_ps(src_c + col); + __m256 src2 = _mm256_loadu_ps(src_c + 2 * col); + __m256 src3 = _mm256_loadu_ps(src_c + 3 * col); + __m256 src4 = _mm256_loadu_ps(src_c + 4 * col); + __m256 src5 = _mm256_loadu_ps(src_c + 5 * col); + __m256 trans0 = _mm256_unpacklo_ps(src0, src1); + __m256 trans1 = _mm256_unpacklo_ps(src2, src3); + __m256 trans2 = _mm256_unpacklo_ps(src4, src5); + __m256 trans3 = _mm256_unpackhi_ps(src0, src1); + __m256 trans4 = _mm256_unpackhi_ps(src2, src3); + __m256 trans5 = _mm256_unpackhi_ps(src4, src5); + __m128 lo0 = _mm256_castps256_ps128(trans0); + __m128 lo1 = _mm256_castps256_ps128(trans1); + __m128 lo2 = _mm256_castps256_ps128(trans2); + __m128 lo3 = _mm256_castps256_ps128(trans3); + __m128 lo4 = _mm256_castps256_ps128(trans4); + __m128 lo5 = _mm256_castps256_ps128(trans5); + __m128 hi0 = _mm256_extractf128_ps(trans0, 1); + __m128 hi1 = _mm256_extractf128_ps(trans1, 1); + __m128 hi2 = _mm256_extractf128_ps(trans2, 1); + __m128 hi3 = _mm256_extractf128_ps(trans3, 1); + __m128 hi4 = _mm256_extractf128_ps(trans4, 1); + __m128 hi5 = _mm256_extractf128_ps(trans5, 1); + __m128 res0 = _mm_shuffle_ps(lo0, lo1, _MM_SHUFFLE(1, 0, 1, 0)); + __m128 res1 = _mm_shuffle_ps(lo2, lo0, _MM_SHUFFLE(3, 2, 1, 0)); + __m128 res2 = _mm_shuffle_ps(lo1, lo2, _MM_SHUFFLE(3, 2, 3, 2)); + __m128 res3 = _mm_shuffle_ps(lo3, lo4, _MM_SHUFFLE(1, 0, 1, 0)); + __m128 res4 = _mm_shuffle_ps(lo5, lo3, _MM_SHUFFLE(3, 2, 1, 0)); + __m128 res5 = _mm_shuffle_ps(lo4, lo5, _MM_SHUFFLE(3, 2, 3, 2)); + __m128 res6 = _mm_shuffle_ps(hi0, hi1, _MM_SHUFFLE(1, 0, 1, 0)); + __m128 res7 = _mm_shuffle_ps(hi2, hi0, _MM_SHUFFLE(3, 2, 1, 0)); + __m128 res8 = _mm_shuffle_ps(hi1, hi2, _MM_SHUFFLE(3, 2, 3, 2)); + __m128 res9 = _mm_shuffle_ps(hi3, hi4, _MM_SHUFFLE(1, 0, 1, 0)); + __m128 res10 = _mm_shuffle_ps(hi5, hi3, _MM_SHUFFLE(3, 2, 1, 0)); + __m128 res11 = _mm_shuffle_ps(hi4, hi5, _MM_SHUFFLE(3, 2, 3, 2)); + _mm_storeu_ps(dst_c, res0); + _mm_storeu_ps(dst_c + 4, res1); + _mm_storeu_ps(dst_c + 8, res2); + _mm_storeu_ps(dst_c + 12, res3); + _mm_storeu_ps(dst_c + 16, res4); + _mm_storeu_ps(dst_c + 20, res5); + _mm_storeu_ps(dst_c + 24, res6); + _mm_storeu_ps(dst_c + 28, res7); + _mm_storeu_ps(dst_c + 32, res8); + _mm_storeu_ps(dst_c + 36, res9); + _mm_storeu_ps(dst_c + 40, res10); + _mm_storeu_ps(dst_c + 44, res11); +#else + for (int tr = 0; tr < C6NUM; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + dst_c[tc * C6NUM + tr] = src_c[tr * col + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C6NUM; + for (int i = 0; i < C6NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C6NUM * col; + dst_r += C6NUM * col; + } + + if (row_end == row) { + for (; ri < row_end; ri++) { + for (int i = 0; i < col; i++) { + dst_r[i * C6NUM] = src_r[i]; + } + src_r += col; + dst_r += 1; + } + + int totalRow = UP_ROUND(row, C6NUM); + for (; ri < totalRow; ri++) { + for (int i = 0; i < col; i++) { + dst_r[i * C6NUM] = 0; + } + dst_r += 1; + } + } +} + +void RowMajor2Col4MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end) { + int row4 = row_end / C4NUM * C4NUM; + int ri = row_start / C4NUM * C4NUM; + int col4 = col / C4NUM * C4NUM; + const float *src_r = src_ptr + ri * col; + float *dst_r = dst_ptr + ri * col; + + for (; ri < row4; ri += C4NUM) { + int ci = 0; + for (; ci < col4; ci += C4NUM) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C4NUM; + +#ifdef ENABLE_ARM32 + int stride = col * 4; + asm volatile( + "mov r10, %[src_c]\n" + "mov r12, %[dst_c]\n" + + "vld1.32 {q0}, [r10], %[stride]\n" + "vld1.32 {q1}, [r10], %[stride]\n" + "vld1.32 {q2}, [r10], %[stride]\n" + "vld1.32 {q3}, [r10], %[stride]\n" + + "vtrn.32 d0, d2\n" + "vtrn.32 d1, d3\n" + "vtrn.32 d4, d6\n" + "vtrn.32 d5, d7\n" + + "vswp d1, d4\n" + "vswp d3, d6\n" + + "vst1.32 {q0}, [r12]!\n" + "vst1.32 {q1}, [r12]!\n" + "vst1.32 {q2}, [r12]!\n" + "vst1.32 {q3}, [r12]!\n" + + : + : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride) + : "r10", "r12", "q0", "q1", "q2", "q3"); +#elif ENABLE_SSE + __m128 src1 = _mm_loadu_ps(src_c); + __m128 src2 = _mm_loadu_ps(src_c + col); + __m128 src3 = _mm_loadu_ps(src_c + 2 * col); + __m128 src4 = _mm_loadu_ps(src_c + 3 * col); + src_c += C4NUM * col; + __m128 src12L = _mm_unpacklo_ps(src1, src2); + __m128 src12H = _mm_unpackhi_ps(src1, src2); + __m128 src34L = _mm_unpacklo_ps(src3, src4); + __m128 src34H = _mm_unpackhi_ps(src3, src4); + + __m128 dst0 = _mm_movelh_ps(src12L, src34L); + __m128 dst1 = _mm_movehl_ps(src34L, src12L); + __m128 dst2 = _mm_movelh_ps(src12H, src34H); + __m128 dst3 = _mm_movehl_ps(src34H, src12H); + + _mm_storeu_ps(dst_c, dst0); + _mm_storeu_ps(dst_c + 4, dst1); + _mm_storeu_ps(dst_c + 8, dst2); + _mm_storeu_ps(dst_c + 12, dst3); +#else + for (size_t tr = 0; tr < C4NUM; tr++) { + for (size_t tc = 0; tc < C4NUM; tc++) { + dst_c[tc * C4NUM + tr] = src_c[tr * col + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C4NUM; + for (int i = 0; i < C4NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C4NUM * col; + dst_r += C4NUM * col; + } + if (row_end == row) { + for (; ri < row; ri++) { + for (int i = 0; i < col; i++) { + dst_r[i * C4NUM] = src_r[i]; + } + src_r += col; + dst_r += 1; + } + + int total_row = UP_ROUND(row, C4NUM); + for (; ri < total_row; ri++) { + for (int i = 0; i < col; i++) { + dst_r[i * C4NUM] = 0; + } + dst_r += 1; + } + } +} + +void RowMajor2ColMajor(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2ColMajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} +void RowMajor2RowMajor(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2RowMajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} +void RowMajor2Row4Major(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2Row4MajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} +void RowMajor2Row6Major(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2Row6MajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} +void RowMajor2Row8Major(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2Row8MajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} +void RowMajor2Row12Major(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2Row12MajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} +void RowMajor2Row16Major(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2Row16MajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} +void RowMajor2Row32Major(const float *src_ptr, float *dst_ptr, int col, int row) { + RowMajor2Row32MajorParallel(src_ptr, dst_ptr, col, row, 0, col); +} +void RowMajor2Row64Major(const float *src_ptr, float *dst_ptr, int col, int row) { + RowMajor2Row64MajorParallel(src_ptr, dst_ptr, col, row, 0, col); +} +void RowMajor2Col12Major(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2Col12MajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} +void RowMajor2Col8Major(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2Col8MajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} +void RowMajor2Col16Major(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2Col16MajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} +void RowMajor2Col32Major(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2Col32MajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} +void RowMajor2Col64Major(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2Col64MajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} +void RowMajor2Col6Major(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2Col6MajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} +void RowMajor2Col4Major(const float *src_ptr, float *dst_ptr, int row, int col) { + RowMajor2Col4MajorParallel(src_ptr, dst_ptr, row, col, 0, row); +} + +void PackNHWCToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + int c4_minus = c4 - 1; + for (int b = 0; b < batch; b++) { + int src_oc_offset = b * plane * channel; + int dst_oc_offset = b * plane * c4 * C4NUM; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_oc_offset + k * channel; + int dst_kernel_offset = dst_oc_offset + k * C4NUM; + for (int j = 0; j < c4_minus; ++j) { + int src_ic_offset = src_kernel_offset + j * C4NUM; + int dst_ic_offset = dst_kernel_offset + j * plane * C4NUM; +#ifdef ENABLE_ARM + vst1q_f32((float *)dst + dst_ic_offset, vld1q_f32((float *)src + src_ic_offset)); +#else + for (int i = 0; i < C4NUM; ++i) { + ((float *)dst + dst_ic_offset)[i] = ((float *)src + src_ic_offset)[i]; + } +#endif + } + int tmp_c = c4_minus * C4NUM; + int tmp_c_offset = tmp_c * plane; + int res_c = channel - tmp_c; + if (res_c > channel) { + return; + } + for (int l = 0; l < res_c; ++l) { + int src_ic_offset = src_kernel_offset + tmp_c + l; + int dst_ic_offset = dst_kernel_offset + tmp_c_offset + l; + ((float *)dst + dst_ic_offset)[0] = ((float *)src + src_ic_offset)[0]; + } + } + } +} + +void PackNCHWToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * channel; + int dst_offset = b * plane * c4 * C4NUM; + RowMajor2Col4Major((const float *)src + src_offset, (float *)dst + dst_offset, channel, plane); + } +} + +void PackNHWCToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel) { + int oc_block = UP_DIV(channel, C4NUM); + int oc_block_channel = oc_block * C4NUM; + int ic_remainder_ = channel % C4NUM; + if (ic_remainder_ != 0) { + for (int b = 0; b < batch; b++) { + int dst_batch_offset = b * oc_block_channel * plane; + int batch_offset = b * channel * plane; + for (int i = 0; i < plane; i++) { + float *dst_per_plane = (float *)dst + dst_batch_offset + i * oc_block_channel; + memcpy(dst_per_plane, (float *)src + batch_offset + i * channel, channel * sizeof(float)); + memset(dst_per_plane + channel, 0, (oc_block_channel - channel) * sizeof(float)); + } + } + } else { + size_t ori_input_size = batch * plane * channel * sizeof(float); + memcpy((float *)dst, (float *)src, ori_input_size); + } +} + +void PackNHWCToNHWCXFp32(const void *src, void *dst, int batch, int plane, int channel, int oc_tile) { + int oc_block = UP_DIV(channel, oc_tile); + int oc_block_channel = oc_block * oc_tile; + int ic_remainder_ = channel % oc_tile; + if (ic_remainder_ != 0) { + for (int b = 0; b < batch; b++) { + int dst_batch_offset = b * oc_block_channel * plane; + int batch_offset = b * channel * plane; + for (int i = 0; i < plane; i++) { + float *dst_per_plane = (float *)dst + dst_batch_offset + i * oc_block_channel; + memcpy(dst_per_plane, (float *)src + batch_offset + i * channel, channel * sizeof(float)); + memset(dst_per_plane + channel, 0, (oc_block_channel - channel) * sizeof(float)); + } + } + } else { + size_t ori_input_size = batch * plane * channel * sizeof(float); + memcpy((float *)dst, (float *)src, ori_input_size); + } +} + +#if defined(ENABLE_AVX) || defined(ENABLE_ARM64) +void PackNHWCToNXHWCXFp32H1W1(int output_channel, int oc_block_num, int input_channel, float *tmp_weight, + const float *src, int oc_block_unit, Transpose8X8Fp32Func transpose_func) { + int oc_block8 = DOWN_DIV(output_channel, C8NUM); + int oc = 0; + int oc_block = 0; + int ic8 = DOWN_ROUND(input_channel, C8NUM); + int oc_remainder_step = 0; + if (oc_block8 != oc_block_num) { + oc_block8 = oc_block8 / oc_block_unit * oc_block_unit; + oc_remainder_step = (oc_block_num - oc_block8) * C8NUM; + } + for (; oc < oc_block8; oc += (oc_block / C8NUM)) { + oc_block = MSMIN(oc_block_unit, oc_block8 - oc) * C8NUM; // max_tile = 32 ==> 24 ==> 16 ==> 8 + for (int oc_tmp = 0; oc_tmp < oc_block; oc_tmp += C8NUM) { + int ic = 0; + for (; ic < ic8; ic += C8NUM) { + transpose_func(src + ic, tmp_weight + ic * oc_block + oc_tmp, input_channel, oc_block); + } + for (; ic < input_channel; ++ic) { + for (int j = 0; j < C8NUM; ++j) { + tmp_weight[ic * oc_block + oc_tmp + j] = src[ic + input_channel * j]; + } + } + src += C8NUM * input_channel; + } + tmp_weight += oc_block * input_channel; + } + oc = output_channel - oc_block8 * C8NUM; + for (int oc_remainder = 0; oc_remainder < oc; ++oc_remainder) { + for (int ic = 0; ic < input_channel; ++ic) { + tmp_weight[oc_remainder + oc_remainder_step * ic] = src[ic + oc_remainder * input_channel]; + } + } +} + +// PackNHWCToNXHWCXFp32 is SWPackNHWCToNXHWCXFp32 asm optimize +void PackNHWCToNXHWCXFp32(int kernel_h, int kernel_w, int output_channel, int oc_block_num, int input_channel, + float *tmp_weight, const float *src) { +#ifdef ENABLE_ARM64 + Transpose8X8Fp32Func transpose_func = Transpose8X8Fp32Arm64; + int oc_block_unit = C2NUM; +#elif defined(ENABLE_AVX) + Transpose8X8Fp32Func transpose_func = Transpose8X8Fp32Avx; + int oc_block_unit = C4NUM; +#endif + // pack weight NHWC to N32HWC32 N24HWC24 N16HWC16 N8HWC8 + // output_channel: batch + int plane = kernel_w * kernel_h; + if (plane == 1) { // conv 1x1 weight pack + PackNHWCToNXHWCXFp32H1W1(output_channel, oc_block_num, input_channel, tmp_weight, src, oc_block_unit, + transpose_func); + return; + } + + int ic8 = DOWN_ROUND(input_channel, C8NUM); + int oc_block8 = DOWN_DIV(output_channel, C8NUM); + int oc_block = 0; + int oc = 0; + int oc_remainder_step = 0; + if (oc_block8 != oc_block_num) { + oc_block8 = oc_block8 / oc_block_unit * oc_block_unit; + oc_remainder_step = (oc_block_num - oc_block8) * C8NUM; + } + for (; oc < oc_block8; oc += (oc_block / C8NUM)) { + oc_block = MSMIN(oc_block_unit, oc_block8 - oc) * C8NUM; // max_tile = 32 ==> 24 ==> 16 ==> 8 + for (int oc_tmp = 0; oc_tmp < oc_block; oc_tmp += C8NUM) { + for (int hw = 0; hw < plane; ++hw) { + int ic = 0; + for (; ic < ic8; ic += C8NUM) { + transpose_func(src + hw * input_channel + ic, + tmp_weight + hw * oc_block * input_channel + ic * oc_block + oc_tmp, input_channel * plane, + oc_block); + } + for (; ic < input_channel; ++ic) { + for (int j = 0; j < C8NUM; ++j) { + tmp_weight[ic * oc_block + oc_tmp + j + hw * oc_block * input_channel] = + src[ic + input_channel * j * plane + hw * input_channel]; + } + } + } + src += C8NUM * plane * input_channel; + } + tmp_weight += oc_block * input_channel * plane; + } + oc = output_channel - oc_block8 * C8NUM; + for (int oc_remainder = 0; oc_remainder < oc; ++oc_remainder) { + for (int hw = 0; hw < plane; ++hw) { + for (int ic = 0; ic < input_channel; ++ic) { + tmp_weight[oc_remainder + oc_remainder_step * ic + hw * input_channel * oc_remainder_step] = + src[ic + (oc_remainder * plane + hw) * input_channel]; + } + } + } +} + +#ifdef ENABLE_DEBUG +void SWPackNHWCToNXHWCXFp32(int kernel_h, int kernel_w, int output_channel, int oc_block_num, int input_channel, + float *tmp_weight, const float *src) { + // pack weight NHWC to N32HWC32 N24HWC24 N16HWC16 N8HWC8 + int oc_block = 0; + for (int i = 0; i < oc_block_num; i += oc_block) { + oc_block = MSMIN(C4NUM, oc_block_num - i); // max_tile = 4 + int index = i * C8NUM * kernel_h * kernel_w * input_channel; + int oc_remainder = MSMIN(C8NUM * oc_block, output_channel - i * C8NUM); + for (int h = 0; h < kernel_h; ++h) { + for (int w = 0; w < kernel_w; ++w) { + int w_index = (h * kernel_w + w) * input_channel + index; + for (int ic = 0; ic < input_channel; ++ic) { + int ic_index = ic + w_index; + for (int oc = 0; oc < oc_remainder; ++oc) { + int oc_index = oc * kernel_w * kernel_h * input_channel + ic_index; + tmp_weight[oc] = src[oc_index]; + } + tmp_weight += oc_block * C8NUM; + } + } + } + } +} +#endif +#endif + +void PackNHWCToNHWC8Fp32(const void *src, void *dst, int batch, int plane, int channel) { + int c8 = UP_DIV(channel, C8NUM); + int c8_channel = c8 * C8NUM; + int nhwc8_batch_unit_offset = c8 * C8NUM * plane; + int ic_remainder_ = channel % C8NUM; + if (ic_remainder_ != 0) { + int nhwc8_batch_offset = 0; + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + for (int i = 0; i < plane; i++) { + float *dst_per_plane = (float *)dst + nhwc8_batch_offset + i * c8_channel; + memcpy(dst_per_plane, (float *)src + batch_offset + i * channel, channel * sizeof(float)); + for (int j = channel; j < c8_channel; ++j) { + dst_per_plane[j] = 0; + } + } + nhwc8_batch_offset += nhwc8_batch_unit_offset; + } + } else { + size_t ori_input_size = batch * plane * channel * sizeof(float); + memcpy((float *)dst, (float *)src, ori_input_size); + } +} + +void PackNHWCXToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel, int cx_num) { + int c_algin = UP_DIV(channel, cx_num); + int ic_remainder_ = channel % cx_num; + if (ic_remainder_ != 0) { + int nhwc_batch_unit_offset = channel * plane; + for (int b = 0; b < batch; b++) { + int batch_offset = b * c_algin * cx_num * plane; + for (int i = 0; i < plane; i++) { + memcpy((float *)dst + b * nhwc_batch_unit_offset + i * channel, + (float *)src + batch_offset + i * c_algin * cx_num, channel * sizeof(float)); + } + } + } else { + size_t ori_input_size = batch * plane * channel * sizeof(float); + memcpy((float *)dst, (float *)src, ori_input_size); + } +} + +void PackNC4HW4ToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int c = 0; c < channel; c++) { + int c4_block_num = c / C4NUM; + int c4_block_res = c % C4NUM; + int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res; + int dst_c_offset = dst_offset + c4_block_num * C4NUM + c4_block_res; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_c_offset + k * C4NUM; + int dst_kernel_offset = dst_c_offset + k * c4 * C4NUM; + ((float *)dst + dst_kernel_offset)[0] = ((float *)src + src_kernel_offset)[0]; + } + } + } +} + +void UnPackC4Uint(const void *src, void *dst, size_t plane, size_t channel) { + const float *fp32_src = (const float *)src; + float *fp32_dst = (float *)dst; + for (size_t c = 0; c < channel; c++) { + size_t c_div = c / C4NUM; + size_t c_mod = c % C4NUM; + for (size_t p = 0; p < plane; p++) { + int src_offset = c_div * plane * C4NUM + p * C4NUM + c_mod; + int dst_offset = c * plane + p; + fp32_dst[dst_offset] = fp32_src[src_offset]; + } + } +} + +void PackNC4HW4ToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_ROUND(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4; + int dst_offset = b * plane * channel; + UnPackC4Uint((const float *)src + src_offset, (float *)dst + dst_offset, plane, channel); + } +} + +void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_offset + k * C4NUM; + int dst_kernel_offset = dst_offset + k * channel; + for (int c = 0; c < c4 - 1; c++) { + int src_c_offset = src_kernel_offset + c * plane * C4NUM; + int dst_c_offset = dst_kernel_offset + c * C4NUM; +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_STQ_F32((float *)dst + dst_c_offset, MS_LDQ_F32((float *)src + src_c_offset)); +#else + ((float *)dst + dst_c_offset)[0] = ((float *)src + src_c_offset)[0]; + ((float *)dst + dst_c_offset)[1] = ((float *)src + src_c_offset)[1]; + ((float *)dst + dst_c_offset)[2] = ((float *)src + src_c_offset)[2]; + ((float *)dst + dst_c_offset)[3] = ((float *)src + src_c_offset)[3]; +#endif + } + // res part + int res_c = channel - (c4 - 1) * C4NUM; + for (int i = 0; i < res_c; i++) { + int src_res_c_offset = src_kernel_offset + (c4 - 1) * C4NUM * plane + i; + int dst_res_c_offset = dst_kernel_offset + (c4 - 1) * C4NUM + i; + ((float *)dst + dst_res_c_offset)[0] = ((float *)src + src_res_c_offset)[0]; + } + } + } +} + +void PackNC8HW8ToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel) { + int c8 = UP_ROUND(channel, C8NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c8; + int dst_offset = b * plane * channel; + + const float *fp32_src = (const float *)src + src_offset; + float *fp32_dst = (float *)dst + dst_offset; + for (size_t c = 0; c < channel; c++) { + size_t c_div = c / C8NUM; + size_t c_mod = c % C8NUM; + for (size_t p = 0; p < plane; p++) { + int src_offset_c = c_div * plane * C8NUM + p * C8NUM + c_mod; + int dst_offset_c = c * plane + p; + fp32_dst[dst_offset_c] = fp32_src[src_offset_c]; + } + } + } +} + +void PackNHWCToNC8HW8Fp32(const void *src, void *dst, int batch, int plane, int channel) { + int c8 = UP_DIV(channel, C8NUM); + int c8_minus = c8 - 1; + for (int b = 0; b < batch; b++) { + int src_oc_offset = b * plane * channel; + int dst_oc_offset = b * plane * c8 * C8NUM; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_oc_offset + k * channel; + int dst_kernel_offset = dst_oc_offset + k * C8NUM; + for (int j = 0; j < c8_minus; ++j) { + int src_ic_offset = src_kernel_offset + j * C8NUM; + int dst_ic_offset = dst_kernel_offset + j * plane * C8NUM; + for (int i = 0; i < C8NUM; ++i) { + ((float *)dst + dst_ic_offset)[i] = ((float *)src + src_ic_offset)[i]; + } + } + int tmp_c = c8_minus * C8NUM; + int tmp_c_offset = tmp_c * plane; + int res_c = channel - tmp_c; + if (res_c > channel) { + return; + } + for (int l = 0; l < res_c; ++l) { + int src_ic_offset = src_kernel_offset + tmp_c + l; + int dst_ic_offset = dst_kernel_offset + tmp_c_offset + l; + ((float *)dst + dst_ic_offset)[0] = ((float *)src + src_ic_offset)[0]; + } + } + } +} + +void PackNC8HW8ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel) { + int c8 = UP_DIV(channel, C8NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c8 * C8NUM; + int dst_offset = b * plane * channel; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_offset + k * C8NUM; + int dst_kernel_offset = dst_offset + k * channel; + for (int c = 0; c < c8 - 1; c++) { + int src_c_offset = src_kernel_offset + c * plane * C8NUM; + int dst_c_offset = dst_kernel_offset + c * C8NUM; + + ((float *)dst + dst_c_offset)[Index0] = ((float *)src + src_c_offset)[Index0]; + ((float *)dst + dst_c_offset)[Index1] = ((float *)src + src_c_offset)[Index1]; + ((float *)dst + dst_c_offset)[Index2] = ((float *)src + src_c_offset)[Index2]; + ((float *)dst + dst_c_offset)[Index3] = ((float *)src + src_c_offset)[Index3]; + } + // res part + int res_c = channel - (c8 - 1) * C8NUM; + for (int i = 0; i < res_c; i++) { + int src_res_c_offset = src_kernel_offset + (c8 - 1) * C8NUM * plane + i; + int dst_res_c_offset = dst_kernel_offset + (c8 - 1) * C8NUM + i; + ((float *)dst + dst_res_c_offset)[0] = ((float *)src + src_res_c_offset)[0]; + } + } + } +} + +void PackNC8HW8AlignedToNC8HW8NotAlignedFp32(const void *src, void *dst, const int batch, const int plane, + const int channel) { + int down_channel_8 = DOWN_ROUND(channel, C8NUM); + int up_channel_16 = UP_ROUND(channel, C16NUM); + size_t dst_batch_offset = (size_t)(plane * channel) * sizeof(float); + size_t src_batch_offset = (size_t)(plane * up_channel_16) * sizeof(float); + size_t unaligned_channel_size = (size_t)(channel - down_channel_8) * sizeof(float); + size_t aligned_channel_size = (size_t)(down_channel_8 * plane) * sizeof(float); + size_t src_p_offset = C8NUM * sizeof(float); + for (size_t b = 0; b < (size_t)(batch); ++b) { + const char *src_batch = (char *)(src) + b * src_batch_offset; + char *dst_bacth = (char *)(dst) + b * dst_batch_offset; + memcpy(dst_bacth, src_batch, aligned_channel_size); + src_batch += aligned_channel_size; + dst_bacth += aligned_channel_size; + for (int p = 0; p < plane; ++p) { + memcpy(dst_bacth + p * unaligned_channel_size, src_batch + p * src_p_offset, unaligned_channel_size); + } + } +} + +void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel) { + int channel_up8 = UP_ROUND(channel, C8NUM); + for (int n = 0; n < batch; n++) { + for (int hw = 0; hw < plane; hw++) { + int c = 0; + for (; c < channel; c++) { + int c8div = c / C8NUM; + int c8mod = c % C8NUM; + int src_index = n * plane * channel + hw * channel + c; + int dst_index = c8div * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + c8mod; + ((float *)dst)[dst_index] = ((float *)src)[src_index]; + } + for (; c < channel_up8; c++) { + int c8div = c / C8NUM; + int c8mod = c % C8NUM; + int dst_index = c8div * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + c8mod; + ((float *)dst)[dst_index] = 0; + } + } + } +} + +void PackNHWCToCXHWNXFp32(const float *src, float *dst, int batch, int plane, int channel) { + // pack weight NHWC to C24HWN24 (Priority 24)=>C16HWN16 (Not satisfied 24)=>C8HWN8 (Not satisfied 16) +#ifdef ENABLE_AVX + int oc_block_num = UP_DIV(channel, C8NUM); + int plane16 = plane / C16NUM * C16NUM; + for (int i = 0, oc_block = 0; i < oc_block_num; i += oc_block) { + oc_block = MSMIN(C3NUM, oc_block_num - i); + int oc_remainder = MSMIN(C8NUM * oc_block, channel - i * C8NUM); + int oc_remainder_c8 = oc_remainder / C8NUM * C8NUM; + int p = 0; + for (; p < plane16; p += C16NUM) { + int index_plane = i * C8NUM + p * channel; + for (int b = 0; b < batch; ++b) { + int index_batch = index_plane + b * plane * channel; + int oc = 0; + int stride = oc_block * C8NUM * batch; + for (; oc < oc_remainder_c8; oc += C8NUM) { + const float *cur_src = src + index_batch + oc; + float *cur_dst = dst + oc; + MS_LOAD256X16_F32(r, cur_src, channel); + STORE256X16_F32(cur_dst, stride, r); + } + for (; oc < oc_remainder; ++oc) { + for (int k = 0; k < C16NUM; ++k) { + dst[oc + stride * k] = src[index_batch + oc + channel * k]; + } + } + for (; oc < C8NUM; ++oc) { + for (int k = 0; k < C16NUM; ++k) { + dst[oc + stride * k] = 0; + } + } + dst += oc_block * C8NUM; + } + dst += (C16NUM - 1) * oc_block * C8NUM * batch; + } + for (; p < plane; ++p) { + int index_plane = i * C8NUM + p * channel; + for (int b = 0; b < batch; ++b) { + int index_batch = index_plane + b * plane * channel; + int oc = 0; + for (; oc < oc_remainder; ++oc) { + dst[oc] = src[index_batch + oc]; + } + for (; oc < C8NUM; ++oc) { + dst[oc] = 0; + } + dst += oc_block * C8NUM; + } + } + } +#else + int oc_block = 0; + int oc_block_num = UP_DIV(channel, C8NUM); + for (int i = 0; i < oc_block_num; i += oc_block) { + oc_block = MSMIN(C3NUM, oc_block_num - i); // max_tile = 4 + int oc_remainder = MSMIN(C8NUM * oc_block, channel - i * C8NUM); + for (int p = 0; p < plane; ++p) { + int index_plane = i * C8NUM + p * channel; + for (int b = 0; b < batch; ++b) { + int index_batch = index_plane + b * plane * channel; + for (int oc = 0; oc < oc_remainder; ++oc) { + dst[oc] = src[index_batch + oc]; + } + dst += oc_block * C8NUM; + } + } + } +#endif +} + +void PackDepthwiseIndirectWeightC4Fp32(const void *src, void *dst, int height, int width, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int c = 0; c < c4; c++) { + int dst_off_c = c * C4NUM * height * width; + for (int i = 0; i < C4NUM; i++) { + int src_off_c = (c * C4NUM + i) * height * width; + for (int kh = 0; kh < height; kh++) { + int src_off_kh = src_off_c + kh * width; + for (int kw = 0; kw < width; kw++) { + int dst_off = dst_off_c + kw * height * C4NUM + kh * C4NUM + i; + ((float *)dst)[dst_off] = ((float *)src)[src_off_kh + kw]; + } + } + } + } +} + +void PackDepthwiseIndirectWeightC8Fp32(const void *src, void *dst, int height, int width, int channel) { + int c8 = UP_DIV(channel, C8NUM); + for (int c = 0; c < c8; c++) { + int dst_off_c = c * C8NUM * height * width; + for (int i = 0; i < C8NUM; i++) { + int src_off_c = (c * C8NUM + i) * height * width; + for (int kh = 0; kh < height; kh++) { + int src_off_kh = src_off_c + kh * width; + for (int kw = 0; kw < width; kw++) { + int dst_off = dst_off_c + kw * height * C8NUM + kh * C8NUM + i; + ((float *)dst)[dst_off] = ((float *)src)[src_off_kh + kw]; + } + } + } + } +} + +void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int channel, int task_id, + int thread_count) { +#ifdef ENABLE_ARM64 + Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Arm64; +#elif defined(ENABLE_ARM32) + Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Arm32; +#elif defined(ENABLE_AVX) + Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Avx; +#elif defined(ENABLE_SSE) && !defined(ENABLE_AVX) + Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Sse; +#endif + int hw8 = plane / C8NUM; + int task_start = 0; + int task_end = plane; + if (thread_count > 0) { + int offset_hw = UP_DIV(hw8, thread_count) * C8NUM; + task_start = offset_hw * task_id; + int count = plane - task_start; + if (count <= 0) { + return; + } + task_end = (task_id + 1) == thread_count ? plane : MSMIN(plane, task_start + offset_hw); + hw8 = task_start + ((task_end - task_start) >= offset_hw ? offset_hw : 0); + } else { + hw8 *= C8NUM; + } + int c8 = channel / C8NUM * C8NUM; + int batch = plane * channel; + for (int n = 0; n < batches; n++) { + const float *src_batch = (const float *)src + n * batch; + float *dst_batch = (float *)dst + n * batch; + int hw = task_start; + for (; hw < hw8; hw += C8NUM) { + int c = 0; + for (; c < c8; c += C8NUM) { + const float *src_ptr = src_batch + hw * channel + c; + float *dst_ptr = dst_batch + c * plane + hw; +#if defined(ENABLE_ARM64) || defined(ENABLE_AVX) || defined(ENABLE_SSE) || defined(ENABLE_ARM32) + Transpose8X8Fp32Func_(src_ptr, dst_ptr, channel, plane); +#else + for (int tr = 0; tr < C8NUM; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + dst_ptr[tc * plane + tr] = src_ptr[tr * channel + tc]; + } + } +#endif + } + for (; c < channel; c++) { + const float *src_ptr = src_batch + hw * channel + c; + float *dst_ptr = dst_batch + c * plane + hw; + for (size_t i = 0; i < C8NUM; i++) { + dst_ptr[i] = src_ptr[i * channel]; + } + } + } + for (; hw < task_end; hw++) { + const float *src_ptr = src_batch + hw * channel; + float *dst_ptr = dst_batch + hw; + for (size_t i = 0; i < channel; i++) { + dst_ptr[i * plane] = src_ptr[i]; + } + } + } +} + +/* +|<---------------- plane --------------->| ++---------------------------+------------+ --- +| | | | | ↑ +|8x8-blocks| ... |8x8-blocks| right | | +| | | | | | ++ - - - - -+ + - - - - -+ | | +| ... ... ... | top | channel ++ - - - - -+ + - - - - -| | | +| | | | tails | | +|8x8-blocks| ... |8x8-blocks| | | ++---------------------------+------------+ | +| |right bottom| | +| left bottom tails | tails | ↓ ++---------------------------+------------+ --- +*/ +void TransposeFp32(const void *src, void *dst, int batches, int channel, int plane, int start, int end) { +#ifdef ENABLE_ARM64 + Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Arm64; +#elif defined(ENABLE_ARM32) + Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Arm32; +#elif defined(ENABLE_AVX) + Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Avx; +#elif defined(ENABLE_SSE) && !defined(ENABLE_AVX) + Transpose8X8Fp32Func Transpose8X8Fp32Func_ = Transpose8X8Fp32Sse; +#endif + int m_pad = UP_DIV(channel, C8NUM); + int n_pad = UP_DIV(plane, C8NUM); + int m_blk = channel / C8NUM; + int n_blk = plane / C8NUM; + int b_stride = plane * channel; + // printf("channel, plane: %d, %d\n", channel, plane); + int b = 0, m = 0, n = 0; + // To make write dst consecutively, (m,n):(0,0)->(1,0)->...->(0,1)->(1,1)->... + offset_to_index_init(start, 6, &m, m_pad, &n, n_pad, &b, batches); + for (int task = start; task < end; task++) { + const float *src_batch = (const float *)src + b * b_stride; + float *dst_batch = (float *)dst + b * b_stride; + int m_start = m * C8NUM; + int n_start = n * C8NUM; + if (m < m_blk && n < n_blk) { + // process 8x8-blocks + const float *from = src_batch + m_start * plane + n_start; + float *to = dst_batch + n_start * channel + m_start; +#if defined(ENABLE_ARM64) || defined(ENABLE_AVX) || defined(ENABLE_SSE) || defined(ENABLE_ARM32) + Transpose8X8Fp32Func_(from, to, plane, channel); +#else + for (int tr = 0; tr < C8NUM; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + to[tc * channel + tr] = from[tr * plane + tc]; + } + } +#endif + } else { + // process right bottom tails + const float *from = src_batch; + float *to = dst_batch; + int i_start = m_start; + int i_end = channel; + int j_start = n_start; + int j_end = plane; + if (m >= m_blk && n < n_blk) { + // process left bottom tails + from = src_batch + n_start; + to = dst_batch + n_start * channel; + j_start = 0; + j_end = C8NUM; + } else if (m < m_blk && n >= n_blk) { + // process right top tails + from = src_batch + m_start * plane; + to = dst_batch + m_start; + i_start = 0; + i_end = C8NUM; + } + transpose_tail(from, to, j_start, j_end, i_start, i_end, channel, plane); + } + offset_to_index_step(6, &m, m_pad, &n, n_pad, &b, batches); + } +} + +void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel, int task_id, int thread_count) { + PackNHWCToNCHWFp32(src, dst, batch, channel, plane, task_id, thread_count); +} + +#ifdef ENABLE_ARM64 +inline void Transpose8X8Fp32Arm64(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride) { + size_t srcStride = src_stride * sizeof(float); + size_t dstStride = dst_stride * sizeof(float); + asm volatile( + "mov x10, %[src_ptr]\n" + "mov x11, %[dst_ptr]\n" + + "ld1 {v0.4s, v1.4s}, [x10], %[srcStride]\n" + "ld1 {v2.4s, v3.4s}, [x10], %[srcStride]\n" + + "zip1 v8.4s, v0.4s, v2.4s\n" + "zip2 v9.4s, v0.4s, v2.4s\n" + "zip1 v12.4s, v1.4s, v3.4s\n" + "zip2 v13.4s, v1.4s, v3.4s\n" + + "ld1 {v4.4s, v5.4s}, [x10], %[srcStride]\n" + "ld1 {v6.4s, v7.4s}, [x10], %[srcStride]\n" + + "zip1 v10.4s, v4.4s, v6.4s\n" + "zip2 v11.4s, v4.4s, v6.4s\n" + "zip1 v14.4s, v5.4s, v7.4s\n" + "zip2 v15.4s, v5.4s, v7.4s\n" + + "ld1 {v0.4s, v1.4s}, [x10], %[srcStride]\n" + "ld1 {v2.4s, v3.4s}, [x10], %[srcStride]\n" + + "trn1 v16.2d, v8.2d, v10.2d\n" + "trn2 v18.2d, v8.2d, v10.2d\n" + "trn1 v20.2d, v9.2d, v11.2d\n" + "trn2 v22.2d, v9.2d, v11.2d\n" + + "ld1 {v4.4s, v5.4s}, [x10], %[srcStride]\n" + "ld1 {v6.4s, v7.4s}, [x10], %[srcStride]\n" + + "trn1 v24.2d, v12.2d, v14.2d\n" + "trn2 v26.2d, v12.2d, v14.2d\n" + "trn1 v28.2d, v13.2d, v15.2d\n" + "trn2 v30.2d, v13.2d, v15.2d\n" + + "zip1 v8.4s, v0.4s, v2.4s\n" + "zip2 v9.4s, v0.4s, v2.4s\n" + "zip1 v12.4s, v1.4s, v3.4s\n" + "zip2 v13.4s, v1.4s, v3.4s\n" + + "zip1 v10.4s, v4.4s, v6.4s\n" + "zip2 v11.4s, v4.4s, v6.4s\n" + "zip1 v14.4s, v5.4s, v7.4s\n" + "zip2 v15.4s, v5.4s, v7.4s\n" + + "trn1 v17.2d, v8.2d, v10.2d\n" + "trn2 v19.2d, v8.2d, v10.2d\n" + "trn1 v21.2d, v9.2d, v11.2d\n" + "trn2 v23.2d, v9.2d, v11.2d\n" + + "trn1 v25.2d, v12.2d, v14.2d\n" + "trn2 v27.2d, v12.2d, v14.2d\n" + "trn1 v29.2d, v13.2d, v15.2d\n" + "trn2 v31.2d, v13.2d, v15.2d\n" + + "st1 {v16.4s, v17.4s}, [x11], %[dstStride]\n" + "st1 {v18.4s, v19.4s}, [x11], %[dstStride]\n" + "st1 {v20.4s, v21.4s}, [x11], %[dstStride]\n" + "st1 {v22.4s, v23.4s}, [x11], %[dstStride]\n" + "st1 {v24.4s, v25.4s}, [x11], %[dstStride]\n" + "st1 {v26.4s, v27.4s}, [x11], %[dstStride]\n" + "st1 {v28.4s, v29.4s}, [x11], %[dstStride]\n" + "st1 {v30.4s, v31.4s}, [x11], %[dstStride]\n" + + : + : [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31"); +} +#endif + +#ifdef ENABLE_ARM32 +inline void Transpose8X8Fp32Arm32(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride) { + size_t srcStride = src_stride * sizeof(float); + size_t dstStride = dst_stride * sizeof(float); + asm volatile( + "mov r10, %[src_ptr]\n" + "mov r12, %[dst_ptr]\n" + + "vld1.32 {q0, q1}, [r10], %[srcStride]\n" + "vld1.32 {q2, q3}, [r10], %[srcStride]\n" + + "vtrn.32 d0, d4\n" + "vtrn.32 d1, d5\n" + "vtrn.32 d2, d6\n" + "vtrn.32 d3, d7\n" + + "vld1.32 {q4, q5}, [r10], %[srcStride]\n" + "vld1.32 {q6, q7}, [r10], %[srcStride]\n" + + "vtrn.32 d8, d12\n" + "vtrn.32 d9, d13\n" + "vtrn.32 d10, d14\n" + "vtrn.32 d11, d15\n" + + "vld1.32 {q8, q9}, [r10], %[srcStride]\n" + "vld1.32 {q10, q11}, [r10], %[srcStride]\n" + + "vswp d1, d8\n" + "vswp d3, d10\n" + "vswp d5, d12\n" + "vswp d7, d14\n" + + "vtrn.32 d16, d20\n" + "vtrn.32 d17, d21\n" + "vtrn.32 d18, d22\n" + "vtrn.32 d19, d23\n" + + "vld1.32 {q12, q13}, [r10], %[srcStride]\n" + "vld1.32 {q14, q15}, [r10], %[srcStride]\n" + + "vtrn.32 d24, d28\n" + "vtrn.32 d25, d29\n" + "vtrn.32 d26, d30\n" + "vtrn.32 d27, d31\n" + + "vswp d17, d24\n" + "vswp d19, d26\n" + "vswp d21, d28\n" + "vswp d23, d30\n" + + "add r10, r12, #16\n" + "vst1.32 {q0}, [r12], %[dstStride]\n" + "vst1.32 {q8}, [r10], %[dstStride]\n" + "vst1.32 {q2}, [r12], %[dstStride]\n" + "vst1.32 {q10}, [r10], %[dstStride]\n" + "vst1.32 {q4}, [r12], %[dstStride]\n" + "vst1.32 {q12}, [r10], %[dstStride]\n" + "vst1.32 {q6}, [r12], %[dstStride]\n" + "vst1.32 {q14}, [r10], %[dstStride]\n" + "vst1.32 {q1}, [r12], %[dstStride]\n" + "vst1.32 {q9}, [r10], %[dstStride]\n" + "vst1.32 {q3}, [r12], %[dstStride]\n" + "vst1.32 {q11}, [r10], %[dstStride]\n" + "vst1.32 {q5}, [r12], %[dstStride]\n" + "vst1.32 {q13}, [r10], %[dstStride]\n" + "vst1.32 {q7}, [r12], %[dstStride]\n" + "vst1.32 {q15}, [r10], %[dstStride]\n" + + : + : [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride) + : "r10", "r12", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15"); +} +#endif + +#ifdef ENABLE_AVX +/* + Using _mm256_insertf128_ps at the beginning, instead of using _mm256_permute2f128_ps at the end. + On the whole, 4 vinsertf128 and 4 vperm2f128 are used less than before. +*/ +inline void Transpose8X8Fp32Avx(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride) { + const float *src1 = src_ptr + 0 * src_stride; + const float *src2 = src_ptr + 1 * src_stride; + const float *src3 = src_ptr + 2 * src_stride; + const float *src4 = src_ptr + 3 * src_stride; + const float *src5 = src_ptr + 4 * src_stride; + const float *src6 = src_ptr + 5 * src_stride; + const float *src7 = src_ptr + 6 * src_stride; + const float *src8 = src_ptr + 7 * src_stride; + + __m256 r1, r2, r3, r4, r5, r6, r7, r8; + __m256 t1, t2, t3, t4, t5, t6, t7, t8; + // _mm256_castps128_ps256 is only for compilation and generates no instructions, thus it has zero latency. + r1 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src1 + 0)), _mm_loadu_ps(src5 + 0), 1); + r2 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src2 + 0)), _mm_loadu_ps(src6 + 0), 1); + r3 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src3 + 0)), _mm_loadu_ps(src7 + 0), 1); + r4 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src4 + 0)), _mm_loadu_ps(src8 + 0), 1); + r5 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src1 + 4)), _mm_loadu_ps(src5 + 4), 1); + r6 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src2 + 4)), _mm_loadu_ps(src6 + 4), 1); + r7 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src3 + 4)), _mm_loadu_ps(src7 + 4), 1); + r8 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_loadu_ps(src4 + 4)), _mm_loadu_ps(src8 + 4), 1); + + t1 = _mm256_unpacklo_ps(r1, r2); + t2 = _mm256_unpackhi_ps(r1, r2); + t3 = _mm256_unpacklo_ps(r3, r4); + t4 = _mm256_unpackhi_ps(r3, r4); + t5 = _mm256_unpacklo_ps(r5, r6); + t6 = _mm256_unpackhi_ps(r5, r6); + t7 = _mm256_unpacklo_ps(r7, r8); + t8 = _mm256_unpackhi_ps(r7, r8); + + __m256 v; + v = _mm256_shuffle_ps(t1, t3, 0x4E); + r1 = _mm256_blend_ps(t1, v, 0xCC); + r2 = _mm256_blend_ps(t3, v, 0x33); + + v = _mm256_shuffle_ps(t2, t4, 0x4E); + r3 = _mm256_blend_ps(t2, v, 0xCC); + r4 = _mm256_blend_ps(t4, v, 0x33); + + v = _mm256_shuffle_ps(t5, t7, 0x4E); + r5 = _mm256_blend_ps(t5, v, 0xCC); + r6 = _mm256_blend_ps(t7, v, 0x33); + + v = _mm256_shuffle_ps(t6, t8, 0x4E); + r7 = _mm256_blend_ps(t6, v, 0xCC); + r8 = _mm256_blend_ps(t8, v, 0x33); + + STORE256X8_F32(dst_ptr, dst_stride, r); +} +#endif + +#if defined(ENABLE_SSE) && !defined(ENABLE_AVX) +inline void Transpose8X8Fp32Sse(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride) { + __m128 v0_ma = _mm_loadu_ps(src_ptr); + __m128 v1_ma = _mm_loadu_ps(src_ptr + src_stride); + __m128 v2_ma = _mm_loadu_ps(src_ptr + 2 * src_stride); + __m128 v3_ma = _mm_loadu_ps(src_ptr + 3 * src_stride); + + __m128 v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma); + __m128 v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma); + __m128 v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma); + __m128 v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma); + + __m128 v8_ma = _mm_movelh_ps(v4_ma, v6_ma); + __m128 v9_ma = _mm_movehl_ps(v6_ma, v4_ma); + __m128 v10_ma = _mm_movelh_ps(v5_ma, v7_ma); + __m128 v11_ma = _mm_movehl_ps(v7_ma, v5_ma); + + _mm_storeu_ps(dst_ptr, v8_ma); + _mm_storeu_ps(dst_ptr + dst_stride, v9_ma); + _mm_storeu_ps(dst_ptr + 2 * dst_stride, v10_ma); + _mm_storeu_ps(dst_ptr + 3 * dst_stride, v11_ma); + + v0_ma = _mm_loadu_ps(src_ptr + C4NUM); + v1_ma = _mm_loadu_ps(src_ptr + src_stride + C4NUM); + v2_ma = _mm_loadu_ps(src_ptr + 2 * src_stride + C4NUM); + v3_ma = _mm_loadu_ps(src_ptr + 3 * src_stride + C4NUM); + + v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma); + v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma); + v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma); + v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma); + + v8_ma = _mm_movelh_ps(v4_ma, v6_ma); + v9_ma = _mm_movehl_ps(v6_ma, v4_ma); + v10_ma = _mm_movelh_ps(v5_ma, v7_ma); + v11_ma = _mm_movehl_ps(v7_ma, v5_ma); + + _mm_storeu_ps(dst_ptr + C4NUM * dst_stride, v8_ma); + _mm_storeu_ps(dst_ptr + (C4NUM + 1) * dst_stride, v9_ma); + _mm_storeu_ps(dst_ptr + (C4NUM + 2) * dst_stride, v10_ma); + _mm_storeu_ps(dst_ptr + (C4NUM + 3) * dst_stride, v11_ma); + + v0_ma = _mm_loadu_ps(src_ptr + C4NUM * src_stride); + v1_ma = _mm_loadu_ps(src_ptr + (C4NUM + 1) * src_stride); + v2_ma = _mm_loadu_ps(src_ptr + (C4NUM + 2) * src_stride); + v3_ma = _mm_loadu_ps(src_ptr + (C4NUM + 3) * src_stride); + + v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma); + v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma); + v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma); + v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma); + + v8_ma = _mm_movelh_ps(v4_ma, v6_ma); + v9_ma = _mm_movehl_ps(v6_ma, v4_ma); + v10_ma = _mm_movelh_ps(v5_ma, v7_ma); + v11_ma = _mm_movehl_ps(v7_ma, v5_ma); + + _mm_storeu_ps(dst_ptr + C4NUM, v8_ma); + _mm_storeu_ps(dst_ptr + dst_stride + C4NUM, v9_ma); + _mm_storeu_ps(dst_ptr + 2 * dst_stride + C4NUM, v10_ma); + _mm_storeu_ps(dst_ptr + 3 * dst_stride + C4NUM, v11_ma); + + v0_ma = _mm_loadu_ps(src_ptr + C4NUM * src_stride + C4NUM); + v1_ma = _mm_loadu_ps(src_ptr + (C4NUM + 1) * src_stride + C4NUM); + v2_ma = _mm_loadu_ps(src_ptr + (C4NUM + 2) * src_stride + C4NUM); + v3_ma = _mm_loadu_ps(src_ptr + (C4NUM + 3) * src_stride + C4NUM); + + v4_ma = _mm_unpacklo_ps(v0_ma, v1_ma); + v5_ma = _mm_unpackhi_ps(v0_ma, v1_ma); + v6_ma = _mm_unpacklo_ps(v2_ma, v3_ma); + v7_ma = _mm_unpackhi_ps(v2_ma, v3_ma); + + v8_ma = _mm_movelh_ps(v4_ma, v6_ma); + v9_ma = _mm_movehl_ps(v6_ma, v4_ma); + v10_ma = _mm_movelh_ps(v5_ma, v7_ma); + v11_ma = _mm_movehl_ps(v7_ma, v5_ma); + + _mm_storeu_ps(dst_ptr + C4NUM * dst_stride + C4NUM, v8_ma); + _mm_storeu_ps(dst_ptr + (C4NUM + 1) * dst_stride + C4NUM, v9_ma); + _mm_storeu_ps(dst_ptr + (C4NUM + 2) * dst_stride + C4NUM, v10_ma); + _mm_storeu_ps(dst_ptr + (C4NUM + 3) * dst_stride + C4NUM, v11_ma); +} +#endif + +#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX)) +void PackWeightConvDw3x3Fp32(const void *src, void *dst, int channel) { + // nchw to nc4hw4 with 1D F(2,3) + for (int i = 0; i < channel; i++) { + float *src_kernel = (float *)src + i * 9; + float *dst_kernel = (float *)dst + (i / 4) * 48 + i % 4; + for (int y = 0; y < 3; y++) { + float g0 = src_kernel[3 * y]; + float g1 = src_kernel[3 * y + 1]; + float g2 = src_kernel[3 * y + 2]; + + dst_kernel[16 * y] = g0; + dst_kernel[16 * y + 4] = 0.5f * (g0 + g1 + g2); + dst_kernel[16 * y + 8] = 0.5f * (g0 - g1 + g2); + dst_kernel[16 * y + 12] = g2; + } + } +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/pack_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/pack_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..6c0484be1bceb8330089323cc3adba78478e6f2c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/pack_fp32.h @@ -0,0 +1,130 @@ +/** + * Copyright 2020-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_PACK_H_ +#define MINDSPORE_NNACL_FP32_PACK_H_ + +#include +#ifdef ENABLE_NEON +#include +#endif +#include "nnacl/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +static inline void transpose_tail(const float *from, float *to, int j_start, int j_end, int i_start, int i_end, + int j_stride, int i_stride) { + // write consecutively + for (int j = j_start; j < j_end; j++) { + for (int i = i_start; i < i_end; i++) { + to[j * j_stride + i] = from[i * i_stride + j]; + } + } +} +void TransposeFp32(const void *src, void *dst, int batches, int channel, int plane, int start, int end); +void PackHWCToWHC(const float *src, float *dst, int height, int width, int channel); +void PackNHWCToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel); +void PackNCHWToNC4HW4Fp32(const void *src, void *dst, int batch, int plane, int channel); +void PackNHWCToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel); +void PackNHWCToNHWCXFp32(const void *src, void *dst, int batch, int plane, int channel, int oc_tile); +void PackNHWCToNHWC8Fp32(const void *src, void *dst, int batch, int plane, int channel); +// Note: If not multithreaded, please set task_id = 0 and thread_count = 0; +void PackNHWCToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel, int task_id, int thread_count); +void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel, int task_id, int thread_count); +void PackNHWCXToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel, int cx_num); +void PackNC4HW4ToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel); +void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel); +void PackNC4HW4ToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel); +void PackNC8HW8ToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel); +void PackNHWCToNC8HW8Fp32(const void *src, void *dst, int batch, int plane, int channel); +void PackNC8HW8ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel); +void UnPackC4Uint(const void *src, void *dst, size_t plane, size_t channel); +void PackNC8HW8AlignedToNC8HW8NotAlignedFp32(const void *src, void *dst, int batch, int plane, int channel); +void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel); +void PackNHWCToCXHWNXFp32(const float *src, float *dst, int batch, int plane, int channel); +void PackNHWCToNC4HW4NotAlignedFp32(const float *src, float *dst, const int batch, const int plane, const int channel); +void PackNHWCToNC8HW8NotAlignedFp32(const float *src, float *dst, const int batch, const int plane, const int channel); + +void RowMajor2ColMajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end); +void RowMajor2RowMajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end); +void RowMajor2Row4MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end); +void RowMajor2Row6MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end); +void RowMajor2Row8MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end); +void RowMajor2Row12MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end); +void RowMajor2Row16MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end); +void RowMajor2Row32MajorParallel(const float *src_ptr, float *dst_ptr, int col, int row, int col_start, int col_end); +void RowMajor2Row64MajorParallel(const float *src_ptr, float *dst_ptr, int col, int row, int col_start, int col_end); +void RowMajor2Col4MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end); +void RowMajor2Col6MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end); +void RowMajor2Col8MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end); +void RowMajor2Col12MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end); +void RowMajor2Col16MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end); +void RowMajor2Col32MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end); +void RowMajor2Col64MajorParallel(const float *src_ptr, float *dst_ptr, int row, int col, int row_start, int row_end); + +void RowMajor2ColMajor(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2RowMajor(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Row4Major(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Row6Major(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Row8Major(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Row12Major(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Row16Major(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Row32Major(const float *src_ptr, float *dst_ptr, int col, int row); +void RowMajor2Row64Major(const float *src_ptr, float *dst_ptr, int col, int row); +void RowMajor2Col4Major(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Col6Major(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Col8Major(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Col12Major(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Col16Major(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Col32Major(const float *src_ptr, float *dst_ptr, int row, int col); +void RowMajor2Col64Major(const float *src_ptr, float *dst_ptr, int row, int col); + +void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel); +void PackDepthwiseIndirectWeightC4Fp32(const void *src, void *dst, int height, int width, int channel); +void PackDepthwiseIndirectWeightC8Fp32(const void *src, void *dst, int height, int width, int channel); + +#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX)) +void PackWeightConvDw3x3Fp32(const void *src, void *dst, int channel); +#endif + +// Transpose 8X8 Fp32 block data +typedef void (*Transpose8X8Fp32Func)(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride); +#ifdef ENABLE_ARM64 +void Transpose8X8Fp32Arm64(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride); +#endif +#ifdef ENABLE_ARM32 +void Transpose8X8Fp32Arm32(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride); +#endif +#if defined(ENABLE_AVX) || defined(ENABLE_ARM64) +void PackNHWCToNXHWCXFp32(int kernel_h, int kernel_w, int output_channel, int oc_block_num, int input_channel, + float *tmp_weight, const float *src); +#endif +#ifdef ENABLE_AVX +#ifdef ENABLE_DEBUG +void SWPackNHWCToNXHWCXFp32(int kernel_h, int kernel_w, int output_channel, int oc_block_num, int input_channel, + float *tmp_weight, const float *src); +#endif +void Transpose8X8Fp32Avx(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride); +#endif +#if defined(ENABLE_SSE) && !defined(ENABLE_AVX) +void Transpose8X8Fp32Sse(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride); +#endif + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_PAD_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/pack_fp32_opt.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/pack_fp32_opt.c new file mode 100644 index 0000000000000000000000000000000000000000..a0b16445ca025a1e92676f9520c8670f4ffc54bd --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/pack_fp32_opt.c @@ -0,0 +1,292 @@ +#ifdef ENABLE_ARM64 +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/pack_fp32_opt.h" +#include "nnacl/op_base.h" + +void RowMajor2Col12MajorOptCore(const float *src_c, float *dst_c, size_t stride, int64_t row, int64_t col) { + if (row <= 0 || col <= 0) { + return; + } + size_t stride_byte = stride * sizeof(float); + size_t stride_unit = stride * (C12NUM - 1); + int64_t r = 0; + for (; r <= row - C12NUM; r += C12NUM) { + int64_t c = 0; + for (; c <= col - C4NUM; c += C4NUM) { + asm volatile( + "mov x9, %[src_c]\n" + "mov x10, %[dst_c]\n" + + "ld1 {v0.4s}, [x9], %[stride_byte]\n" + "ld1 {v1.4s}, [x9], %[stride_byte]\n" + "ld1 {v2.4s}, [x9], %[stride_byte]\n" + "ld1 {v3.4s}, [x9], %[stride_byte]\n" + + "ld1 {v4.4s}, [x9], %[stride_byte]\n" + "ld1 {v5.4s}, [x9], %[stride_byte]\n" + "ld1 {v6.4s}, [x9], %[stride_byte]\n" + "ld1 {v7.4s}, [x9], %[stride_byte]\n" + + "zip1 v12.4s, v0.4s, v1.4s\n" + "zip2 v13.4s, v0.4s, v1.4s\n" + "zip1 v14.4s, v2.4s, v3.4s\n" + "zip2 v15.4s, v2.4s, v3.4s\n" + + "ld1 {v8.4s}, [x9], %[stride_byte]\n" + "ld1 {v9.4s}, [x9], %[stride_byte]\n" + "ld1 {v10.4s}, [x9], %[stride_byte]\n" + "ld1 {v11.4s}, [x9], %[stride_byte]\n" + + "zip1 v16.4s, v4.4s, v5.4s\n" + "zip2 v17.4s, v4.4s, v5.4s\n" + "zip1 v18.4s, v6.4s, v7.4s\n" + "zip2 v19.4s, v6.4s, v7.4s\n" + + "trn1 v20.2d, v12.2d, v14.2d\n" + "trn2 v23.2d, v12.2d, v14.2d\n" + "trn1 v26.2d, v13.2d, v15.2d\n" + "trn2 v29.2d, v13.2d, v15.2d\n" + + "trn1 v21.2d, v16.2d, v18.2d\n" + "trn2 v24.2d, v16.2d, v18.2d\n" + "trn1 v27.2d, v17.2d, v19.2d\n" + "trn2 v30.2d, v17.2d, v19.2d\n" + + "zip1 v12.4s, v8.4s, v9.4s\n" + "zip2 v13.4s, v8.4s, v9.4s\n" + "zip1 v14.4s, v10.4s, v11.4s\n" + "zip2 v15.4s, v10.4s, v11.4s\n" + + "trn1 v22.2d, v12.2d, v14.2d\n" + "trn2 v25.2d, v12.2d, v14.2d\n" + "trn1 v28.2d, v13.2d, v15.2d\n" + "trn2 v31.2d, v13.2d, v15.2d\n" + + "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x10], #64\n" + "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], #64\n" + "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x10], #64\n" + + : + : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride_byte ] "r"(stride_byte) + : "memory", "x9", "x10", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", + "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", + "v29", "v30", "v31"); + dst_c += C48NUM; + src_c += C4NUM; + } + for (; c < col; ++c) { + for (int i = 0; i < C12NUM; ++i) { + dst_c[i] = src_c[i * stride]; + } + ++src_c; + dst_c += C12NUM; + } + src_c += stride_unit; + } + for (; r < row; ++r) { + for (int c = 0; c < col; ++c) { + dst_c[c * C12NUM] = src_c[c]; + } + src_c += stride; + ++dst_c; + } +} + +void RowMajor2Row12MajorOptCore(const float *src_c, float *dst_c, size_t stride, int64_t row, int64_t col) { + if (row <= 0 || col <= 0) { + return; + } + size_t stride_byte = stride * sizeof(float); + int64_t c = 0; + for (; c <= col - C12NUM; c += C12NUM) { + asm volatile( + "mov x9, %[src_c]\n" + "mov x10, %[dst_c]\n" + "mov x11, %[row]\n" + "1:\n" + "ld1 {v0.4s, v1.4s, v2.4s}, [x9], %[stride_byte]\n" + "st1 {v0.4s, v1.4s, v2.4s}, [x10], #48\n" + "subs x11, x11, #1\n" + "bgt 1b\n" + : + : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride_byte ] "r"(stride_byte), [ row ] "r"(row) + : "cc", "memory", "x9", "x10", "x11", "v0", "v1", "v2"); + dst_c += row * C12NUM; + src_c += C12NUM; + } + int64_t c_remain = col - c; + if (c_remain == 0) { + return; + } + for (int64_t r = 0; r < row; ++r) { + for (c = 0; c < c_remain; ++c) { + dst_c[r * C12NUM + c] = src_c[c]; + } + src_c += stride; + } +} + +void RowMajor2Col12MajorOpt(const float *src_ptr, float *dst_ptr, int64_t row, int64_t col, int64_t start, + int64_t end) { + int64_t bundle_row = UP_DIV(row, C12NUM); + int64_t unit_num_per_batch = bundle_row * col; + if (unit_num_per_batch == 0) { + return; + } + int64_t start_batch = start / unit_num_per_batch; + int64_t end_batch = end / unit_num_per_batch; + int64_t start_remain = start % unit_num_per_batch; + int64_t end_remain = end % unit_num_per_batch; + if (col == 0) { + return; + } + int64_t start_row = start_remain / col; + int64_t end_row = end_remain / col; + int64_t start_col = start_remain % col; + int64_t end_col = end_remain % col; + const float *src = src_ptr + start_batch * row * col + start_row * C12NUM * col + start_col; + float *dst = dst_ptr + start * C12NUM; + int64_t row_num = C12NUM; + if (start_row * C12NUM + C12NUM > row) { + row_num -= (start_row * C12NUM + C12NUM - row); + } + if (start_batch == end_batch) { + if (start_row == end_row) { + RowMajor2Col12MajorOptCore(src, dst, col, row_num, end_col - start_col); + return; + } + RowMajor2Col12MajorOptCore(src, dst, col, C12NUM, col - start_col); + src += C12NUM * col - start_col; + dst += (col - start_col) * C12NUM; + ++start_row; + if (start_row < end_row) { + row_num = (end_row - start_row) * C12NUM; + RowMajor2Col12MajorOptCore(src, dst, col, row_num, col); + src += row_num * col; + dst += row_num * col; + } + row_num = C12NUM; + if (end_row * C12NUM + C12NUM > row) { + row_num -= (end_row * C12NUM + C12NUM - row); + } + RowMajor2Col12MajorOptCore(src, dst, col, row_num, end_col); + return; + } + RowMajor2Col12MajorOptCore(src, dst, col, row_num, col - start_col); + src += row_num * col - start_col; + dst += (col - start_col) * C12NUM; + row_num = row - start_row * C12NUM - C12NUM; + if (row_num > 0) { + RowMajor2Col12MajorOptCore(src, dst, col, row_num, col); + src += row_num * col; + dst += UP_DIV(row_num, C12NUM) * C12NUM * col; + } + ++start_batch; + for (; start_batch < end_batch; ++start_batch) { + RowMajor2Col12MajorOptCore(src, dst, col, row, col); + src += row * col; + dst += bundle_row * C12NUM * col; + } + if (end_row > 0) { + row_num = end_row * C12NUM; + RowMajor2Col12MajorOptCore(src, dst, col, row_num, col); + src += row_num * col; + dst += row_num * col; + } + row_num = C12NUM; + if (end_row * C12NUM + C12NUM > row) { + row_num -= (end_row * C12NUM + C12NUM - row); + } + RowMajor2Col12MajorOptCore(src, dst, col, row_num, end_col); +} + +void RowMajor2Row12MajorOpt(const float *src_ptr, float *dst_ptr, int64_t row, int64_t col, int64_t start, + int64_t end) { + int64_t bundle_col = UP_DIV(col, C12NUM); + int64_t unit_num_per_batch = bundle_col * row; + if (unit_num_per_batch == 0) { + return; + } + int64_t start_batch = start / unit_num_per_batch; + int64_t end_batch = end / unit_num_per_batch; + int64_t start_remain = start % unit_num_per_batch; + int64_t end_remain = end % unit_num_per_batch; + if (row == 0) { + return; + } + int64_t start_row = start_remain % row; + int64_t end_row = end_remain % row; + int64_t start_col = start_remain / row; + int64_t end_col = end_remain / row; + const float *src = src_ptr + start_batch * row * col + start_row * col + start_col * C12NUM; + float *dst = dst_ptr + start * C12NUM; + int64_t col_num = C12NUM; + if (start_col * C12NUM + C12NUM > col) { + col_num -= (start_col * C12NUM + C12NUM - col); + } + if (start_batch == end_batch) { + if (start_col == end_col) { + RowMajor2Row12MajorOptCore(src, dst, col, end_row - start_row, col_num); + return; + } + RowMajor2Row12MajorOptCore(src, dst, col, row - start_row, col_num); + src += C12NUM - start_row * col; + dst += (row - start_row) * C12NUM; + ++start_col; + if (start_col < end_col) { + col_num = (end_col - start_col) * C12NUM; + RowMajor2Row12MajorOptCore(src, dst, col, row, col_num); + src += col_num; + dst += row * col_num; + } + col_num = C12NUM; + if (end_col * C12NUM + C12NUM > col) { + col_num -= (end_col * C12NUM + C12NUM - col); + } + RowMajor2Row12MajorOptCore(src, dst, col, end_row, col_num); + return; + } + RowMajor2Row12MajorOptCore(src, dst, col, row - start_row, col_num); + src += col_num - start_row * col; + dst += (row - start_row) * C12NUM; + col_num = col - start_col * C12NUM - C12NUM; + if (col_num > 0) { + RowMajor2Row12MajorOptCore(src, dst, col, row, col_num); + src += col_num; + dst += UP_DIV(col_num, C12NUM) * C12NUM * row; + } + src += (row - 1) * col; + ++start_batch; + for (; start_batch < end_batch; ++start_batch) { + RowMajor2Row12MajorOptCore(src, dst, col, row, col); + src += row * col; + dst += bundle_col * C12NUM * row; + } + if (end_col > 0) { + col_num = end_col * C12NUM; + RowMajor2Row12MajorOptCore(src, dst, col, row, col_num); + src += col_num; + dst += row * col_num; + } + col_num = C12NUM; + if (end_col * C12NUM + C12NUM > col) { + col_num -= (end_col * C12NUM + C12NUM - col); + } + RowMajor2Row12MajorOptCore(src, dst, col, end_row, col_num); +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/pack_fp32_opt.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/pack_fp32_opt.h new file mode 100644 index 0000000000000000000000000000000000000000..95a039cbcde75d32034545bafc1fd05e7c1aae04 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/pack_fp32_opt.h @@ -0,0 +1,38 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_PACK_FP32_V2_H +#define MINDSPORE_NNACL_FP32_PACK_FP32_V2_H + +#ifdef ENABLE_ARM64 +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/* + * Plan of packing supports granular multi-threads. + */ + +void RowMajor2Col12MajorOpt(const float *src_ptr, float *dst_ptr, int64_t row, int64_t col, int64_t start, int64_t end); + +void RowMajor2Row12MajorOpt(const float *src_ptr, float *dst_ptr, int64_t row, int64_t col, int64_t start, int64_t end); + +#ifdef __cplusplus +} +#endif +#endif +#endif // MINDSPORE_NNACL_FP32_PACK_FP32_V2_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/pad_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/pad_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..7c12ddc93eb0bb266905efa78e9a64e9e7293c24 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/pad_fp32.c @@ -0,0 +1,83 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/pad_fp32.h" +#include "nnacl/common_func.h" +#include "nnacl/errorcode.h" + +void Pad(const float *input_data, float *output_data, const int32_t *input_shape, const int32_t *output_shape, + const int32_t *paddings, int tid, int thread_num) { + if (thread_num == 0) { + return; + } + int in[DEFAULT_PAD_NDIMS], out[DEFAULT_PAD_NDIMS]; + for (in[0] = 0; in[0] < input_shape[0]; in[0]++) { + out[0] = in[0] + paddings[0]; + for (in[1] = tid; in[1] < input_shape[1]; in[1] += thread_num) { + out[1] = in[1] + paddings[2]; + for (in[2] = 0; in[2] < input_shape[2]; in[2]++) { + out[2] = in[2] + paddings[4]; + for (in[3] = 0; in[3] < input_shape[3]; in[3]++) { + out[3] = in[3] + paddings[6]; + for (in[4] = 0; in[4] < input_shape[4]; in[4]++) { + out[4] = in[4] + paddings[8]; + float *dst = output_data + Offset6d(output_shape, out) + paddings[10]; + const float *src = input_data + Offset6d(input_shape, in); + memcpy(dst, src, input_shape[5] * (int)(sizeof(float))); + } + } + } + } + } +} + +int TransOut2InputDimIndex(int out_dim_index, int left_pad, int in_dim, int offset) { + if (out_dim_index < left_pad) { + // left pad + const int index_sum = left_pad + offset - 1; + int in_index = MSMAX(index_sum - out_dim_index, offset); + return MSMIN(in_index, in_dim - 1); + } + out_dim_index -= left_pad; + if (out_dim_index < in_dim) { + return out_dim_index; + } + // right pad + out_dim_index -= in_dim; + const int index_sum = in_dim - 1 - offset; + return MSMAX(index_sum - out_dim_index, 0); +} + +int GetInputFlattenIndex(int out_flatten_index, const int32_t *input_shape, const int *in_strides, + const int *out_strides, const int *paddings, int mirror_offset) { + int in_flatten_index = 0; + for (int i = 0; i < DEFAULT_PAD_NDIMS; ++i) { + int left_pad = paddings[i * 2]; + NNACL_CHECK_ZERO_RETURN_ERR(out_strides[i]); + int out_dim_index = out_flatten_index / out_strides[i]; + out_flatten_index %= out_strides[i]; + int in_dim_index = TransOut2InputDimIndex(out_dim_index, left_pad, input_shape[i], mirror_offset); + in_flatten_index += in_dim_index * in_strides[i]; + } + return in_flatten_index; +} + +void MirrorPad(const float *input_data, float *output_data, const int32_t *input_shape, const int *in_strides, + const int *out_strides, const int *paddings, int mirror_offset, int begin, int end) { + for (int i = begin; i < end; ++i) { + output_data[i] = input_data[GetInputFlattenIndex(i, input_shape, in_strides, out_strides, paddings, mirror_offset)]; + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/pad_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/pad_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..10bb8506da097e969e1337cf638f6886e0d913d6 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/pad_fp32.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP32_PAD_FP32_H_ +#define NNACL_FP32_PAD_FP32_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include +#include +#include "nnacl/op_base.h" +#include "nnacl/pad_parameter.h" + +int GetInputFlattenIndex(int out_flatten_index, const int32_t *input_shape, const int *in_strides, + const int *out_strides, const int *paddings, int mirror_offset); +#ifdef __cplusplus +extern "C" { +#endif +void Pad(const float *input_data, float *output_data, const int32_t *input_shape, const int32_t *output_shape, + const int32_t *paddings, int tid, int thread_num); +void MirrorPad(const float *input_data, float *output_data, const int32_t *input_shape, const int *in_strides, + const int *out_strides, const int *paddings, int mirror_offset, int begin, int end); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_PAD_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/pooling_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/pooling_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..1136059ec541f658d996711818948e0dd5203ffc --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/pooling_fp32.c @@ -0,0 +1,786 @@ +/** + * Copyright 2020-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/pooling_fp32.h" +#include +#include "nnacl/errorcode.h" +#include "nnacl/op_base.h" +#include "nnacl/pooling_fp32_simd.h" + +int AvgPoolingBatch(const float *src_b_ptr, float *dst_b_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num) { + int in_w = pooling_args->input_w_, in_h = pooling_args->input_h_; + int win_w = pooling_args->window_w_, win_h = pooling_args->window_h_; + int output_w = pooling_args->output_w_, output_h = pooling_args->output_h_; + int channel = pooling_args->input_channel_; + int out_plane = output_w * output_h; + int out_tile_count = UP_DIV(out_plane, TILE_NUM); + NNACL_CHECK_ZERO_RETURN_ERR(output_w); + + for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { + int cal_start_index = thread_id * TILE_NUM; + int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index); + for (int i = 0; i < real_cal_num; i++) { + int index = cal_start_index + i; + int out_w_index = index % output_w; + int out_h_index = index / output_w; + int in_w_index = out_w_index * pooling_param->stride_w_ - pooling_param->pad_l_; + int in_h_index = out_h_index * pooling_param->stride_h_ - pooling_param->pad_u_; + + const float *src_plane_ptr = src_b_ptr; + float *dst_plane_ptr = dst_b_ptr + index * channel; + + int real_win_h_start = MSMAX(0, -in_h_index); + int real_win_h_end = MSMIN(win_h, in_h - in_h_index); + int real_win_w_start = MSMAX(0, -in_w_index); + int real_win_w_end = MSMIN(win_w, in_w - in_w_index); + int ci = 0; + + NNACL_CHECK_TRUE_RET(real_win_h_end > real_win_h_start, NNACL_ERR); + NNACL_CHECK_TRUE_RET(real_win_w_end > real_win_w_start, NNACL_ERR); + SIMD_RUN_NO_SCALAR(AvgPoolingBatch, ci, src_plane_ptr, channel, dst_plane_ptr, real_win_h_start, real_win_h_end, + real_win_w_start, real_win_w_end, in_h_index, in_w, in_w_index, pooling_args->minf, + pooling_args->maxf); + + for (; ci < channel; ci++) { + const float *src_c_ptr = src_plane_ptr + ci; + float *dst_c_ptr = dst_plane_ptr + ci; + float tmp_avg = 0; + int real_count = 0; + for (int h = real_win_h_start; h < real_win_h_end; h++) { + for (int w = real_win_w_start; w < real_win_w_end; w++) { + const float *src_win_ptr = src_c_ptr + ((in_h_index + h) * in_w + in_w_index + w) * channel; + tmp_avg += src_win_ptr[0]; + ++real_count; + } // win_w loop + } // win_h loop + NNACL_CHECK_TRUE_RET(real_count != 0, NNACL_ERR); + tmp_avg = tmp_avg / (float)real_count; + tmp_avg = fmaxf(tmp_avg, pooling_args->minf); + tmp_avg = fminf(tmp_avg, pooling_args->maxf); + dst_c_ptr[0] = tmp_avg; + } // channel_res loop + } // real_cal_num loop + } // out_plane loop + return NNACL_OK; +} + +int AvgPooling(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num) { + int in_w = pooling_args->input_w_; + int in_h = pooling_args->input_h_; + int output_w = pooling_args->output_w_; + int output_h = pooling_args->output_h_; + int channel = pooling_args->input_channel_; + int output_batch = pooling_args->output_batch_; + + for (int batch = 0; batch < output_batch; batch++) { + const float *src_b_ptr = input_ptr + batch * in_h * in_w * channel; + float *dst_b_ptr = output_ptr + batch * output_h * output_w * channel; + int ret = AvgPoolingBatch(src_b_ptr, dst_b_ptr, pooling_param, pooling_args, task_id, thread_num); + if (ret != NNACL_OK) { + return ret; + } + } + return NNACL_OK; +} + +int AvgPoolingFromNC4HW4ToNHWCLessC(const float *src_b_ptr, float *dst_b_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num) { + int in_w = pooling_args->input_w_, in_h = pooling_args->input_h_; + int win_w = pooling_args->window_w_, win_h = pooling_args->window_h_; + int output_w = pooling_args->output_w_, output_h = pooling_args->output_h_; + int channel = pooling_args->input_channel_; + + int out_plane = output_w * output_h; + int in_plane = in_w * in_h; + NNACL_CHECK_ZERO_RETURN_ERR(output_w); + +#ifdef ENABLE_AVX + const int c_tile = C8NUM; + const int once_calc_num = 2; +#elif defined(ENABLE_NEON) || defined(ENABLE_SSE) + const int c_tile = C4NUM; + const int once_calc_num = 1; +#else + const int c_tile = 1; + const int once_calc_num = 1; +#endif + + const int c_xtile = once_calc_num * c_tile; + + int cur_c = (channel / c_xtile) * c_xtile; + int last_c_size = channel - cur_c; + + int less_out_plane = out_plane * last_c_size; + int calc_tile = UP_DIV(less_out_plane, thread_num); + + int index_begin = task_id * calc_tile; + int index_end = (index_begin + calc_tile) < less_out_plane ? (index_begin + calc_tile) : less_out_plane; + + int c_start = (index_begin / out_plane) + cur_c; + int index_less = index_begin % out_plane; + int h_start = index_less / output_h; + int w_start = index_less % output_h; + + int c_end = (index_end / out_plane) + cur_c; + index_less = index_end % out_plane; + int h_end = index_less / output_h; + int w_end = index_less % output_h; + + int c = c_start; + int h = h_start; + int w = w_start; + + int in_w_cx_line = in_w * last_c_size; + const float *src_c_ptr = src_b_ptr + c * in_plane; + for (; c < channel; c += c_xtile) { + for (; h < output_h; h++) { + int cur_index_in_h_start = MSMAX(h * pooling_param->stride_h_ - pooling_param->pad_d_, 0); + int cur_index_in_h_end = MSMIN(cur_index_in_h_start + win_h, in_h); + + for (; w < output_w; w++) { + NNACL_CHECK_TRUE_RET((c < c_end || h < h_end || w < w_end), NNACL_OK); + float tmp_avg = 0.0; + + int cur_index_in_w_start = MSMAX(w * pooling_param->stride_w_ - pooling_param->pad_l_, 0); + int cur_index_in_w_end = MSMIN(cur_index_in_w_start + win_w, in_w); + + int real_count = (cur_index_in_w_end - cur_index_in_w_start) * (cur_index_in_h_end - cur_index_in_h_start); + NNACL_CHECK_TRUE_RET(real_count != 0, NNACL_ERR); + + for (int cur_index_in_h = cur_index_in_h_start; cur_index_in_h < cur_index_in_h_end; cur_index_in_h++) { + const float *src_c_ptr_h_line = src_c_ptr + cur_index_in_h * in_w_cx_line; + for (int cur_index_in_w = cur_index_in_w_start; cur_index_in_w < cur_index_in_w_end; cur_index_in_w++) { + const float *cur_input_index = src_c_ptr_h_line + cur_index_in_w * last_c_size + (c - cur_c); + tmp_avg += cur_input_index[0]; + } + } + + float *dst_c_ptr = dst_b_ptr + h * output_w * channel + w * channel + c; + tmp_avg = tmp_avg / (float)real_count; + tmp_avg = fminf(tmp_avg, pooling_args->maxf); + dst_c_ptr[0] = tmp_avg; + } + w = 0; + } + h = 0; + } + return NNACL_OK; +} + +int AvgPoolingFromNC4HW4ToNHWCBatch(const float *src_b_ptr, float *dst_b_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num) { + int in_w = pooling_args->input_w_, in_h = pooling_args->input_h_; + int win_w = pooling_args->window_w_, win_h = pooling_args->window_h_; + int output_w = pooling_args->output_w_, output_h = pooling_args->output_h_; + int channel = pooling_args->input_channel_; + + int out_plane = output_w * output_h; + int in_plane = in_w * in_h; + NNACL_CHECK_ZERO_RETURN_ERR(output_w); + +#ifdef ENABLE_AVX + const int c_tile = C8NUM; + const int once_calc_num = 2; + MS_FLOAT32X8 min_value_8 = MS_MOV256_F32(pooling_args->minf); + MS_FLOAT32X8 max_value_8 = MS_MOV256_F32(pooling_args->maxf); +#elif defined(ENABLE_NEON) || defined(ENABLE_SSE) + const int c_tile = C4NUM; + const int once_calc_num = 1; + MS_FLOAT32X4 min_value = MS_MOVQ_F32(pooling_args->minf); + MS_FLOAT32X4 max_value = MS_MOVQ_F32(pooling_args->maxf); +#else + const int c_tile = 1; + const int once_calc_num = 1; +#endif + + int in_w_cx_line = in_w * c_tile; + const int c_xtile = once_calc_num * c_tile; + int c_tile_num = channel / c_xtile; + int all_out_plane = out_plane * c_tile_num; + int calc_tile = UP_DIV(all_out_plane, thread_num); + + int index_begin = task_id * calc_tile; + int index_end = (index_begin + calc_tile) < all_out_plane ? (index_begin + calc_tile) : all_out_plane; + + int c_start = (index_begin / out_plane) * c_xtile; + int index_less = index_begin % out_plane; + int h_start = index_less / output_h; + int w_start = index_less % output_h; + + int c_end = (index_end / out_plane) * c_xtile; + index_less = index_end % out_plane; + int h_end = index_less / output_h; + int w_end = index_less % output_h; + + int c = c_start; + int h = h_start; + int w = w_start; + for (; c < channel; c += c_xtile) { + const float *src_c_ptr = src_b_ptr + c * in_plane; + for (; h < output_h; h++) { + int cur_index_in_h_start = MSMAX(h * pooling_param->stride_h_ - pooling_param->pad_d_, 0); + int cur_index_in_h_end = MSMIN(cur_index_in_h_start + win_h, in_h); + + for (; w < output_w; w++) { + NNACL_CHECK_TRUE_RET((c < c_end || h < h_end || w < w_end), NNACL_OK); + +#ifdef ENABLE_AVX + MS_FLOAT32X8 tmp_avg = MS_MOV256_F32(0); + MS_FLOAT32X8 tmp_avg2 = MS_MOV256_F32(0); +#elif defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_FLOAT32X4 tmp_avg = MS_MOVQ_F32(0); +#else + float tmp_avg = 0; +#endif + + int cur_index_in_w_start = MSMAX(w * pooling_param->stride_w_ - pooling_param->pad_l_, 0); + int cur_index_in_w_end = MSMIN(cur_index_in_w_start + win_w, in_w); + + int real_count = (cur_index_in_w_end - cur_index_in_w_start) * (cur_index_in_h_end - cur_index_in_h_start); + NNACL_CHECK_TRUE_RET(real_count != 0, NNACL_ERR); + + for (int cur_index_in_h = cur_index_in_h_start; cur_index_in_h < cur_index_in_h_end; cur_index_in_h++) { + const float *src_c_ptr_h_line = src_c_ptr + cur_index_in_h * in_w_cx_line; + for (int cur_index_in_w = cur_index_in_w_start; cur_index_in_w < cur_index_in_w_end; cur_index_in_w++) { + const float *cur_input_index = src_c_ptr_h_line + cur_index_in_w * c_tile; +#ifdef ENABLE_AVX + tmp_avg = MS_ADD256_F32(tmp_avg, MS_LD256_F32(cur_input_index)); +#elif defined(ENABLE_NEON) || defined(ENABLE_SSE) + tmp_avg = MS_ADDQ_F32(tmp_avg, MS_LDQ_F32(cur_input_index)); +#else + tmp_avg += cur_input_index[0]; +#endif + } + +#ifdef ENABLE_AVX + const float *src_c2_ptr_h_line = src_c_ptr_h_line + c_tile * in_plane; + for (int cur_index_in_w = cur_index_in_w_start; cur_index_in_w < cur_index_in_w_end; cur_index_in_w++) { + const float *cur_input_index = src_c2_ptr_h_line + cur_index_in_w * c_tile; + + tmp_avg2 = MS_ADD256_F32(tmp_avg2, MS_LD256_F32(cur_input_index)); + } +#endif + } + + float *dst_c_ptr = dst_b_ptr + h * output_w * channel + w * channel + c; +#ifdef ENABLE_AVX + float *dst_c2_ptr = dst_c_ptr + c_tile; + + tmp_avg = MS_DIV256_F32(tmp_avg, MS_MOV256_F32(real_count)); + tmp_avg = MS_MAX256_F32(tmp_avg, min_value_8); + tmp_avg = MS_MIN256_F32(tmp_avg, max_value_8); + MS_ST256_F32(dst_c_ptr, tmp_avg); + + tmp_avg2 = MS_DIV256_F32(tmp_avg2, MS_MOV256_F32(real_count)); + tmp_avg2 = MS_MAX256_F32(tmp_avg2, min_value_8); + tmp_avg2 = MS_MIN256_F32(tmp_avg2, max_value_8); + MS_ST256_F32(dst_c2_ptr, tmp_avg2); +#elif defined(ENABLE_NEON) || defined(ENABLE_SSE) + tmp_avg = MS_DIVQ_F32(tmp_avg, MS_MOVQ_F32(real_count)); + tmp_avg = MS_MAXQ_F32(tmp_avg, min_value); + tmp_avg = MS_MINQ_F32(tmp_avg, max_value); + MS_STQ_F32(dst_c_ptr, tmp_avg); +#else + tmp_avg = tmp_avg / (float)real_count; + tmp_avg = fmaxf(tmp_avg, pooling_args->minf); + tmp_avg = fminf(tmp_avg, pooling_args->maxf); + dst_c_ptr[0] = tmp_avg; +#endif + } + w = 0; + } + h = 0; + } + + return NNACL_OK; +} + +int AvgPoolingFromNC4HW4ToNHWC(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num) { + int in_w = pooling_args->input_w_; + int in_h = pooling_args->input_h_; + int output_w = pooling_args->output_w_; + int output_h = pooling_args->output_h_; + int channel = pooling_args->input_channel_; + int output_batch = pooling_args->output_batch_; + + for (int batch = 0; batch < output_batch; batch++) { + const float *src_b_ptr = input_ptr + batch * in_h * in_w * channel; + float *dst_b_ptr = output_ptr + batch * output_h * output_w * channel; + int ret = AvgPoolingFromNC4HW4ToNHWCBatch(src_b_ptr, dst_b_ptr, pooling_param, pooling_args, task_id, thread_num); + if (ret != NNACL_OK) { + return ret; + } + + ret = AvgPoolingFromNC4HW4ToNHWCLessC(src_b_ptr, dst_b_ptr, pooling_param, pooling_args, task_id, thread_num); + if (ret != NNACL_OK) { + return ret; + } + } + return NNACL_OK; +} + +int MaxPoolingBatch(const float *src_b_ptr, float *dst_b_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num) { + int win_w = pooling_args->window_w_, win_h = pooling_args->window_h_; + int in_w = pooling_args->input_w_, in_h = pooling_args->input_h_; + int output_w = pooling_args->output_w_, output_h = pooling_args->output_h_; + int channel = pooling_args->input_channel_; + int out_plane = output_w * output_h; + int out_tile_count = UP_DIV(out_plane, TILE_NUM); + NNACL_CHECK_ZERO_RETURN_ERR(output_w); + + for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { + int cal_start_index = thread_id * TILE_NUM; + int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index); + for (int i = 0; i < real_cal_num; i++) { + int index = cal_start_index + i; + int out_w_index = index % output_w; + int out_h_index = index / output_w; + int in_w_index = out_w_index * pooling_param->stride_w_ - pooling_param->pad_l_; + int in_h_index = out_h_index * pooling_param->stride_h_ - pooling_param->pad_u_; + + const float *src_plane_ptr = src_b_ptr; + float *dst_plane_ptr = dst_b_ptr + index * channel; + + int real_win_h_start = MSMAX(0, -in_h_index); + int real_win_h_end = MSMIN(win_h, in_h - in_h_index); + int real_win_w_start = MSMAX(0, -in_w_index); + int real_win_w_end = MSMIN(win_w, in_w - in_w_index); + int ci = 0; + + SIMD_RUN_NO_SCALAR(MaxPoolingBatch, ci, src_plane_ptr, channel, dst_plane_ptr, real_win_h_start, real_win_h_end, + real_win_w_start, real_win_w_end, in_h_index, in_w, in_w_index, pooling_args->minf, + pooling_args->maxf); + + for (; ci < channel; ci++) { + float *dst_c_ptr = dst_plane_ptr + ci; + const float *src_c_ptr = src_plane_ptr + ci; + float tmp_max = -FLT_MAX; + for (int kh = real_win_h_start; kh < real_win_h_end; kh++) { + for (int kw = real_win_w_start; kw < real_win_w_end; kw++) { + const float *src_win_ptr = src_c_ptr + ((in_h_index + kh) * in_w + in_w_index + kw) * channel; + tmp_max = fmaxf(tmp_max, src_win_ptr[0]); + } // win_w loop + } // win_h loop + tmp_max = fmaxf(tmp_max, pooling_args->minf); + tmp_max = fminf(tmp_max, pooling_args->maxf); + dst_c_ptr[0] = tmp_max; + } // channel_res loop + } // real_cal_num loop + } // out_plane loop + return NNACL_OK; +} + +int MaxPooling(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num) { + int in_w = pooling_args->input_w_; + int in_h = pooling_args->input_h_; + int output_w = pooling_args->output_w_; + int output_h = pooling_args->output_h_; + int channel = pooling_args->input_channel_; + int output_batch = pooling_args->output_batch_; + + for (int batch = 0; batch < output_batch; batch++) { + const float *src_b_ptr = input_ptr + batch * in_h * in_w * channel; + float *dst_b_ptr = output_ptr + batch * output_h * output_w * channel; + int ret = MaxPoolingBatch(src_b_ptr, dst_b_ptr, pooling_param, pooling_args, task_id, thread_num); + if (ret != NNACL_OK) { + return ret; + } + } + return NNACL_OK; +} + +int MaxPoolingFromNC4HW4ToNHWCLessC(const float *src_b_ptr, float *dst_b_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num) { + int in_w = pooling_args->input_w_, in_h = pooling_args->input_h_; + int win_w = pooling_args->window_w_, win_h = pooling_args->window_h_; + int output_w = pooling_args->output_w_, output_h = pooling_args->output_h_; + int channel = pooling_args->input_channel_; + + int out_plane = output_w * output_h; + int in_plane = in_w * in_h; + NNACL_CHECK_ZERO_RETURN_ERR(output_w); + +#ifdef ENABLE_AVX + const int c_tile = C8NUM; + const int once_calc_num = 2; +#elif defined(ENABLE_NEON) || defined(ENABLE_SSE) + const int c_tile = C4NUM; + const int once_calc_num = 1; +#else + const int c_tile = 1; + const int once_calc_num = 1; +#endif + + const int c_xtile = once_calc_num * c_tile; + + int cur_c = (channel / c_xtile) * c_xtile; + int last_c_size = channel - cur_c; + + int less_out_plane = out_plane * last_c_size; + int calc_tile = UP_DIV(less_out_plane, thread_num); + + int index_begin = task_id * calc_tile; + int index_end = (index_begin + calc_tile) < less_out_plane ? (index_begin + calc_tile) : less_out_plane; + + int c_start = (index_begin / out_plane) + cur_c; + int index_less = index_begin % out_plane; + int h_start = index_less / output_h; + int w_start = index_less % output_h; + + int c_end = (index_end / out_plane) + cur_c; + index_less = index_end % out_plane; + int h_end = index_less / output_h; + int w_end = index_less % output_h; + + int c = c_start; + int h = h_start; + int w = w_start; + + int in_w_cx_line = in_w * last_c_size; + const float *src_c_ptr = src_b_ptr + cur_c * in_plane; + for (; c < channel; c++) { + for (; h < output_h; h++) { + int cur_index_in_h_start = MSMAX(h * pooling_param->stride_h_ - pooling_param->pad_d_, 0); + int cur_index_in_h_end = MSMIN(cur_index_in_h_start + win_h, in_h); + + for (; w < output_w; w++) { + NNACL_CHECK_TRUE_RET((c < c_end || h < h_end || w < w_end), NNACL_OK); + float tmp_max = -FLT_MAX; + + int cur_index_in_w_start = MSMAX(w * pooling_param->stride_w_ - pooling_param->pad_l_, 0); + int cur_index_in_w_end = MSMIN(cur_index_in_w_start + win_w, in_w); + + for (int cur_index_in_h = cur_index_in_h_start; cur_index_in_h < cur_index_in_h_end; cur_index_in_h++) { + const float *src_c_ptr_h_line = src_c_ptr + cur_index_in_h * in_w_cx_line; + for (int cur_index_in_w = cur_index_in_w_start; cur_index_in_w < cur_index_in_w_end; cur_index_in_w++) { + const float *cur_input_index = src_c_ptr_h_line + cur_index_in_w * last_c_size + (c - cur_c); + tmp_max = fmaxf(tmp_max, cur_input_index[0]); + } + } + + float *dst_c_ptr = dst_b_ptr + h * output_w * channel + w * channel + c; + tmp_max = fmaxf(tmp_max, pooling_args->minf); + tmp_max = fminf(tmp_max, pooling_args->maxf); + dst_c_ptr[0] = tmp_max; + } + w = 0; + } + h = 0; + } + return NNACL_OK; +} + +int MaxPoolingFromNC4HW4ToNHWCBatch(const float *src_b_ptr, float *dst_b_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num) { + int in_w = pooling_args->input_w_, in_h = pooling_args->input_h_; + int win_w = pooling_args->window_w_, win_h = pooling_args->window_h_; + int output_w = pooling_args->output_w_, output_h = pooling_args->output_h_; + int channel = pooling_args->input_channel_; + + int out_plane = output_w * output_h; + int in_plane = in_w * in_h; + NNACL_CHECK_ZERO_RETURN_ERR(output_w); + +#ifdef ENABLE_AVX + const int c_tile = C8NUM; + const int once_calc_num = 2; + MS_FLOAT32X8 min_value_8 = MS_MOV256_F32(pooling_args->minf); + MS_FLOAT32X8 max_value_8 = MS_MOV256_F32(pooling_args->maxf); +#elif defined(ENABLE_NEON) || defined(ENABLE_SSE) + const int c_tile = C4NUM; + const int once_calc_num = 1; + MS_FLOAT32X4 min_value = MS_MOVQ_F32(pooling_args->minf); + MS_FLOAT32X4 max_value = MS_MOVQ_F32(pooling_args->maxf); +#else + const int c_tile = 1; + const int once_calc_num = 1; +#endif + + int in_w_cx_line = in_w * c_tile; + const int c_xtile = once_calc_num * c_tile; + int c_tile_num = channel / c_xtile; + int all_out_plane = out_plane * c_tile_num; + int calc_tile = UP_DIV(all_out_plane, thread_num); + + int index_begin = task_id * calc_tile; + int index_end = (index_begin + calc_tile) < all_out_plane ? (index_begin + calc_tile) : all_out_plane; + + int c_start = (index_begin / out_plane) * c_xtile; + int index_less = index_begin % out_plane; + int h_start = index_less / output_h; + int w_start = index_less % output_h; + + int c_end = (index_end / out_plane) * c_xtile; + index_less = index_end % out_plane; + int h_end = index_less / output_h; + int w_end = index_less % output_h; + + int c = c_start; + int h = h_start; + int w = w_start; + for (; c < channel; c += c_xtile) { + const float *src_c_ptr = src_b_ptr + c * in_plane; + for (; h < output_h; h++) { + int cur_index_in_h_start = MSMAX(h * pooling_param->stride_h_ - pooling_param->pad_d_, 0); + int cur_index_in_h_end = MSMIN(cur_index_in_h_start + win_h, in_h); + + for (; w < output_w; w++) { + NNACL_CHECK_TRUE_RET((c < c_end || h < h_end || w < w_end), NNACL_OK); + +#ifdef ENABLE_AVX + MS_FLOAT32X8 tmp_max = MS_MOV256_F32(-FLT_MAX); + MS_FLOAT32X8 tmp_max2 = MS_MOV256_F32(-FLT_MAX); +#elif defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_FLOAT32X4 tmp_max = MS_MOVQ_F32(-FLT_MAX); +#else + float tmp_max = -FLT_MAX; +#endif + + int cur_index_in_w_start = MSMAX(w * pooling_param->stride_w_ - pooling_param->pad_l_, 0); + int cur_index_in_w_end = MSMIN(cur_index_in_w_start + win_w, in_w); + + for (int cur_index_in_h = cur_index_in_h_start; cur_index_in_h < cur_index_in_h_end; cur_index_in_h++) { + const float *src_c_ptr_h_line = src_c_ptr + cur_index_in_h * in_w_cx_line; + for (int cur_index_in_w = cur_index_in_w_start; cur_index_in_w < cur_index_in_w_end; cur_index_in_w++) { + const float *cur_input_index = src_c_ptr_h_line + cur_index_in_w * c_tile; +#ifdef ENABLE_AVX + tmp_max = MS_MAX256_F32(tmp_max, MS_LD256_F32(cur_input_index)); +#elif defined(ENABLE_NEON) || defined(ENABLE_SSE) + tmp_max = MS_MAXQ_F32(tmp_max, MS_LDQ_F32(cur_input_index)); +#else + tmp_max = fmaxf(tmp_max, cur_input_index[0]); +#endif + } + +#ifdef ENABLE_AVX + const float *src_c2_ptr_h_line = src_c_ptr_h_line + c_tile * in_plane; + for (int cur_index_in_w = cur_index_in_w_start; cur_index_in_w < cur_index_in_w_end; cur_index_in_w++) { + const float *cur_input_index = src_c2_ptr_h_line + cur_index_in_w * c_tile; + + tmp_max2 = MS_MAX256_F32(tmp_max2, MS_LD256_F32(cur_input_index)); + } +#endif + } + + float *dst_c_ptr = dst_b_ptr + h * output_w * channel + w * channel + c; +#ifdef ENABLE_AVX + float *dst_c2_ptr = dst_c_ptr + c_tile; + + tmp_max = MS_MAX256_F32(tmp_max, min_value_8); + tmp_max = MS_MIN256_F32(tmp_max, max_value_8); + MS_ST256_F32(dst_c_ptr, tmp_max); + + tmp_max2 = MS_MAX256_F32(tmp_max2, min_value_8); + tmp_max2 = MS_MIN256_F32(tmp_max2, max_value_8); + MS_ST256_F32(dst_c2_ptr, tmp_max2); +#elif defined(ENABLE_NEON) || defined(ENABLE_SSE) + tmp_max = MS_MAXQ_F32(tmp_max, min_value); + tmp_max = MS_MINQ_F32(tmp_max, max_value); + MS_STQ_F32(dst_c_ptr, tmp_max); +#else + tmp_max = fmaxf(tmp_max, pooling_args->minf); + tmp_max = fminf(tmp_max, pooling_args->maxf); + dst_c_ptr[0] = tmp_max; +#endif + } + w = 0; + } + h = 0; + } + + return NNACL_OK; +} + +int MaxPoolingFromNC4HW4ToNHWC(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num) { + int in_w = pooling_args->input_w_; + int in_h = pooling_args->input_h_; + int output_w = pooling_args->output_w_; + int output_h = pooling_args->output_h_; + int channel = pooling_args->input_channel_; + int output_batch = pooling_args->output_batch_; + + for (int batch = 0; batch < output_batch; batch++) { + const float *src_b_ptr = input_ptr + batch * in_h * in_w * channel; + float *dst_b_ptr = output_ptr + batch * output_h * output_w * channel; + int ret = MaxPoolingFromNC4HW4ToNHWCBatch(src_b_ptr, dst_b_ptr, pooling_param, pooling_args, task_id, thread_num); + if (ret != NNACL_OK) { + return ret; + } + + ret = MaxPoolingFromNC4HW4ToNHWCLessC(src_b_ptr, dst_b_ptr, pooling_param, pooling_args, task_id, thread_num); + if (ret != NNACL_OK) { + return ret; + } + } + return NNACL_OK; +} + +void MaxPooling3D_NDHWC(const float *input_ptr, float *output_ptr, const Pooling3DParameter *pooling_param, + const Pooling3DComputeParam *pooling_args, int start, int end) { + // Access structure members in declaration order + int in_size_w = pooling_args->pooling_compute_param_.input_w_; + int in_size_h = pooling_args->pooling_compute_param_.input_h_; + int batch = pooling_args->pooling_compute_param_.input_batch_; + int channel = pooling_args->pooling_compute_param_.input_channel_; + int out_size_w = pooling_args->pooling_compute_param_.output_w_; + int out_size_h = pooling_args->pooling_compute_param_.output_h_; + int in_size_d = pooling_args->input_d_; + int out_size_d = pooling_args->output_d_; + + int kernel_w = pooling_param->pooling_parameter_.window_w_; + int kernel_h = pooling_param->pooling_parameter_.window_h_; + int stride_w = pooling_param->pooling_parameter_.stride_w_; + int stride_h = pooling_param->pooling_parameter_.stride_h_; + int pad_l_h = pooling_param->pooling_parameter_.pad_u_; + int pad_l_w = pooling_param->pooling_parameter_.pad_l_; + int kernel_d = pooling_param->window_d_; + int stride_d = pooling_param->stride_d_; + int pad_l_d = pooling_param->pad_f_; + + int n_stride = in_size_d * in_size_h * in_size_w * channel; + int d_stride = in_size_h * in_size_w * channel; + int h_stride = in_size_w * channel; + + int n = 0, d = 0, h = 0, w = 0; + const int parallel_dims = 4; // parallel on N/D/H/W four dims + offset_to_index_init(start, parallel_dims * VA_ARG_TUPLE_LEN, &w, out_size_w, &h, out_size_h, &d, out_size_d, &n, + batch); + + for (int i = start; i < end; i++) { + int d_start = d * stride_d - pad_l_d; + int d_end = MSMIN(d_start + kernel_d, in_size_d); + d_start = MSMAX(d_start, 0); + int h_start = h * stride_h - pad_l_h; + int h_end = MSMIN(h_start + kernel_h, in_size_h); + h_start = MSMAX(h_start, 0); + int w_start = w * stride_w - pad_l_w; + int w_end = MSMIN(w_start + kernel_w, in_size_w); + w_start = MSMAX(w_start, 0); + + const float *src_batch_ptr = input_ptr + n * n_stride; + float *out = output_ptr + i * channel; + int c_idx = 0; + SIMD_RUN_NO_SCALAR(MaxPooling3D, c_idx, src_batch_ptr, channel, out, d_start, d_end, h_start, h_end, w_start, w_end, + d_stride, h_stride); + for (; c_idx < channel; ++c_idx) { + const float *src_c_ptr = src_batch_ptr + c_idx; + float *dst_c_ptr = out + c_idx; + float tmp_max = -FLT_MAX; + for (int dd = d_start; dd < d_end; ++dd) { + for (int hh = h_start; hh < h_end; ++hh) { + for (int ww = w_start; ww < w_end; ++ww) { + const float *input = src_c_ptr + dd * d_stride + hh * h_stride + ww * channel; + tmp_max = MSMAX(input[0], tmp_max); + } + } + } + dst_c_ptr[0] = tmp_max; + } + offset_to_index_step(parallel_dims * 2, &w, out_size_w, &h, out_size_h, &d, out_size_d, &n, batch); + } +} + +void AvgPooling3D_NDHWC(const float *input_ptr, float *output_ptr, const Pooling3DParameter *pooling_param, + const Pooling3DComputeParam *pooling_args, int start, int end) { + // Access structure members in declaration order + int in_size_w = pooling_args->pooling_compute_param_.input_w_; + int in_size_h = pooling_args->pooling_compute_param_.input_h_; + int batch = pooling_args->pooling_compute_param_.input_batch_; + int channel = pooling_args->pooling_compute_param_.input_channel_; + int out_size_w = pooling_args->pooling_compute_param_.output_w_; + int out_size_h = pooling_args->pooling_compute_param_.output_h_; + int in_size_d = pooling_args->input_d_; + int out_size_d = pooling_args->output_d_; + + int kernel_w = pooling_param->pooling_parameter_.window_w_; + int kernel_h = pooling_param->pooling_parameter_.window_h_; + int stride_w = pooling_param->pooling_parameter_.stride_w_; + int stride_h = pooling_param->pooling_parameter_.stride_h_; + int pad_l_h = pooling_param->pooling_parameter_.pad_u_; + int pad_r_h = pooling_param->pooling_parameter_.pad_d_; + int pad_l_w = pooling_param->pooling_parameter_.pad_l_; + int pad_r_w = pooling_param->pooling_parameter_.pad_r_; + int kernel_d = pooling_param->window_d_; + int stride_d = pooling_param->stride_d_; + int pad_l_d = pooling_param->pad_f_; + int pad_r_d = pooling_param->pad_b_; + bool count_include_pad = pooling_param->count_include_pad_; + int divisor = pooling_param->divisor_override_; + + int n_stride = in_size_d * in_size_h * in_size_w * channel; + int d_stride = in_size_h * in_size_w * channel; + int h_stride = in_size_w * channel; + + const int d_max = in_size_d + pad_r_d; + const int h_max = in_size_h + pad_r_h; + const int w_max = in_size_w + pad_r_w; + + int n = 0, d = 0, h = 0, w = 0; + const int parallel_dims = 4; // parallel on N/D/H/W four dims + offset_to_index_init(start, parallel_dims * VA_ARG_TUPLE_LEN, &w, out_size_w, &h, out_size_h, &d, out_size_d, &n, + batch); + + for (int i = start; i < end; i++) { + int d_start = d * stride_d - pad_l_d; + int d_end = MSMIN(d_start + kernel_d, d_max); + int d_start2 = MSMAX(d_start, 0); + int d_end2 = MSMIN(d_end, in_size_d); + int h_start = h * stride_h - pad_l_h; + int h_end = MSMIN(h_start + kernel_h, h_max); + int h_start2 = MSMAX(h_start, 0); + int h_end2 = MSMIN(h_end, in_size_h); + int w_start = w * stride_w - pad_l_w; + int w_end = MSMIN(w_start + kernel_w, w_max); + int w_start2 = MSMAX(w_start, 0); + int w_end2 = MSMIN(w_end, in_size_w); + + const float *src_batch_ptr = input_ptr + n * n_stride; + float *out = output_ptr + i * channel; + + if (pooling_param->divisor_override_ == 0) { + if (count_include_pad) { + divisor = (d_end - d_start) * (h_end - h_start) * (w_end - w_start); + } else { + divisor = (d_end2 - d_start2) * (h_end2 - h_start2) * (w_end2 - w_start2); + } + } + + int c_idx = 0; + SIMD_RUN_NO_SCALAR(AvgPooling3D, c_idx, src_batch_ptr, channel, out, d_start2, d_end2, h_start2, h_end2, w_start2, + w_end2, d_stride, h_stride, divisor); + for (; c_idx < channel; ++c_idx) { + const float *src_c_ptr = src_batch_ptr + c_idx; + float *dst_c_ptr = out + c_idx; + float tmp_avg = 0; + for (int dd = d_start2; dd < d_end2; ++dd) { + for (int hh = h_start2; hh < h_end2; ++hh) { + for (int ww = w_start2; ww < w_end2; ++ww) { + const float *input = src_c_ptr + dd * d_stride + hh * h_stride + ww * channel; + tmp_avg = tmp_avg + input[0]; + } + } + } + dst_c_ptr[0] = tmp_avg / (float)divisor; + } + offset_to_index_step(parallel_dims * 2, &w, out_size_w, &h, out_size_h, &d, out_size_d, &n, batch); + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/pooling_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/pooling_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..b25c1d20a9b93e30966eea2742c550ab4e0b0757 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/pooling_fp32.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_POOLING_H_ +#define NNACL_FP32_POOLING_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "nnacl/op_base.h" +#include "nnacl/pooling_parameter.h" +#include "nnacl/int8/quantize.h" +#include "nnacl/kernel/pooling.h" + +#ifdef __cplusplus +extern "C" { +#endif +int AvgPooling(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num); +int MaxPooling(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num); + +int AvgPoolingFromNC4HW4ToNHWC(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num); +int MaxPoolingFromNC4HW4ToNHWC(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args, int task_id, int thread_num); +void MaxPooling3D_NDHWC(const float *input_ptr, float *output_ptr, const Pooling3DParameter *pooling_param, + const Pooling3DComputeParam *pooling_args, int start, int end); +void AvgPooling3D_NDHWC(const float *input_ptr, float *output_ptr, const Pooling3DParameter *pooling_param, + const Pooling3DComputeParam *pooling_args, int start, int end); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_POOLING_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/pooling_fp32_simd.h.in b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/pooling_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..5519f46915b9615f192fd5f6da860adf69f731fe --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/pooling_fp32_simd.h.in @@ -0,0 +1,116 @@ +/** + * Copyright 2022-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_POOLING_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_POOLING_@SIMD_INSTRUCTION@_H_ + +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int AvgPoolingBatch@SIMD_INSTRUCTION@(int ci, const float *src_plane_ptr, int channel, + float *dst_plane_ptr, int real_win_h_start, int real_win_h_end, int real_win_w_start, int real_win_w_end, + int in_h_index, int in_w, int in_w_index, float minf, float maxf) { + SIMD_F32 min_val = SIMD_MOV_F32(minf); + SIMD_F32 max_val = SIMD_MOV_F32(maxf); + for (int block_max_size = channel - BLOCK_NUM + 1; ci < block_max_size; ci += BLOCK_NUM) { + const float *src_c_ptr = src_plane_ptr + ci; + float *dst_c_ptr = dst_plane_ptr + ci; + SIMD_F32 tmp_avg = SIMD_SET0_F32; + int real_count = 0; + for (int h = real_win_h_start; h < real_win_h_end; h++) { + for (int w = real_win_w_start; w < real_win_w_end; w++) { + const float *src_win_ptr = src_c_ptr + ((in_h_index + h) * in_w + in_w_index + w) * channel; + tmp_avg = SIMD_ADD_F32(tmp_avg, SIMD_LD_F32(src_win_ptr)); + ++real_count; + } + } + tmp_avg = SIMD_DIV_F32(tmp_avg, SIMD_MOV_F32(real_count)); + tmp_avg = SIMD_MAX_F32(tmp_avg, min_val); + tmp_avg = SIMD_MIN_F32(tmp_avg, max_val); + SIMD_ST_F32(dst_c_ptr, tmp_avg); + } + return ci; +} + +static inline int MaxPoolingBatch@SIMD_INSTRUCTION@(int ci, const float *src_plane_ptr, int channel, + float *dst_plane_ptr, int real_win_h_start, int real_win_h_end, int real_win_w_start, int real_win_w_end, + int in_h_index, int in_w, int in_w_index, float minf, float maxf) { + SIMD_F32 min_val = SIMD_MOV_F32(minf); + SIMD_F32 max_val = SIMD_MOV_F32(maxf); + for (int block_max_size = channel - BLOCK_NUM + 1; ci < block_max_size; ci += BLOCK_NUM) { + const float *src_c_ptr = src_plane_ptr + ci; + float *dst_c_ptr = dst_plane_ptr + ci; + SIMD_F32 tmp_max = min_val; + for (int kh = real_win_h_start; kh < real_win_h_end; kh++) { + for (int kw = real_win_w_start; kw < real_win_w_end; kw++) { + const float *src_win_ptr = src_c_ptr + ((in_h_index + kh) * in_w + in_w_index + kw) * channel; + tmp_max = SIMD_MAX_F32(tmp_max, SIMD_LD_F32(src_win_ptr)); + } + } + tmp_max = SIMD_MIN_F32(tmp_max, max_val); + SIMD_ST_F32(dst_c_ptr, tmp_max); + } + return ci; +} + +static inline int MaxPooling3D@SIMD_INSTRUCTION@(int c_idx, const float *src_batch_ptr, int channel, float *out, + int d_start, int d_end, int h_start, int h_end, int w_start, int w_end, int d_stride, int h_stride) { + for (int block_max_size = channel - BLOCK_NUM + 1; c_idx < block_max_size; c_idx += BLOCK_NUM) { + const float *src_c_ptr = src_batch_ptr + c_idx; + float *dst_c_ptr = out + c_idx; + SIMD_F32 tmp_max = SIMD_MOV_F32(-FLT_MAX); + for (int dd = d_start; dd < d_end; ++dd) { + for (int hh = h_start; hh < h_end; ++hh) { + for (int ww = w_start; ww < w_end; ++ww) { + const float *input = src_c_ptr + dd * d_stride + hh * h_stride + ww * channel; + tmp_max = SIMD_MAX_F32(SIMD_LD_F32(input), tmp_max); + } + } + } + SIMD_ST_F32(dst_c_ptr, tmp_max); + } + return c_idx; +} + +static inline int AvgPooling3D@SIMD_INSTRUCTION@(int c_idx, const float *src_batch_ptr, int channel, float *out, + int d_start, int d_end, int h_start, int h_end, int w_start, int w_end, int d_stride, int h_stride, int divisor) { + for (int block_max_size = channel - BLOCK_NUM + 1; c_idx < block_max_size; c_idx += BLOCK_NUM) { + const float *src_c_ptr = src_batch_ptr + c_idx; + float *dst_c_ptr = out + c_idx; + SIMD_F32 tmp_avg = SIMD_SET0_F32; + for (int dd = d_start; dd < d_end; ++dd) { + for (int hh = h_start; hh < h_end; ++hh) { + for (int ww = w_start; ww < w_end; ++ww) { + const float *input = src_c_ptr + dd * d_stride + hh * h_stride + ww * channel; + tmp_avg = SIMD_ADD_F32(SIMD_LD_F32(input), tmp_avg); + } + } + } + tmp_avg = SIMD_DIV_F32(tmp_avg, SIMD_MOV_F32(divisor)); + SIMD_ST_F32(dst_c_ptr, tmp_avg); + } + return c_idx; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/power_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/power_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..d8c11338e7803f318f9a77926d6a906e7b080541 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/power_fp32.c @@ -0,0 +1,70 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/power_fp32.h" +#include "nnacl/errorcode.h" +#include "nnacl/power_fp32_simd.h" + +float OptimizedPowerScalar(float x, const float *exponent) { + int exp = abs((int)(*exponent)); + float result = 1; + while (exp) { + if (exp % 2) { + result *= x; + } + x *= x; + exp = exp / 2; + } + return *exponent >= 0 ? result : 1 / result; +} + +void PowerBroadCast(const float *input, const float *exponent, float *output, int len, float scale, float shift) { + PowerScalarFun PowerScalarFun_ = NULL; + + int i = 0; + if (CheckInteger(*exponent)) { + PowerScalarFun_ = OptimizedPowerScalar; + SIMD_RUN_NO_SCALAR(PowerBroadCastIntExponent, i, input, (int)(*exponent), output, len, scale, shift); + } else { + PowerScalarFun_ = StdPowerScalar; + SIMD_RUN_NO_SCALAR(PowerBroadCastFloatExponent, i, input, *exponent, output, len, scale, shift); + } + + for (; i < len; ++i) { + output[i] = PowerScalarFun_(scale * input[i] + shift, exponent); + } +} + +void PowerSingle(const float *input, const float *exponent, float *output, int len, float scale, float shift) { + int i = 0; + + SIMD_RUN_NO_SCALAR(PowerSingleExponent, i, input, exponent, output, len, scale, shift); + PowerScalarFun PowerScalarFun_ = NULL; + for (; i < len; ++i) { + PowerScalarFun_ = CheckInteger(exponent[i]) ? OptimizedPowerScalar : StdPowerScalar; + output[i] = PowerScalarFun_(scale * input[i] + shift, exponent + i); + } +} + +int Power(const float *input, const float *exponent, float *output, int len, float scale, float shift, bool broadcast) { + if (input == NULL || exponent == NULL || output == NULL) { + return NNACL_NULL_PTR; + } + PowerFun PowerFun_ = NULL; + PowerFun_ = broadcast ? PowerBroadCast : PowerSingle; + PowerFun_(input, exponent, output, len, scale, shift); + return NNACL_OK; +} diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_impl/less.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/power_fp32.h similarity index 36% rename from mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_impl/less.h rename to mindspore-lite/ops/kernel/cpu/nnacl/fp32/power_fp32.h index b5c8b56d29a576ed6df2eb0d81e19648c49982a5..246e9bbfe920fe7fe093ee3ba45b6a616347e97a 100644 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_impl/less.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/power_fp32.h @@ -1,5 +1,5 @@ /** - * Copyright 2023 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,10 +13,29 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_ASCEND_NATIVE_IMPL_LESS_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_ASCEND_NATIVE_IMPL_LESS_H_ -namespace mindspore::ascend_native { -void LessFp32(void *x1, void *x2, void *y, uint64_t elem_num, void *q); -void LessFp16(void *x1, void *x2, void *y, uint64_t elem_num, void *q); -} // namespace mindspore::ascend_native -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_ASCEND_NATIVE_IMPL_LESS_H_ + +#ifndef MINDSPORE_NNACL_FP32_POWER_FP32_H_ +#define MINDSPORE_NNACL_FP32_POWER_FP32_H_ + +#include +#include "nnacl/op_base.h" +#include "nnacl/pow_parameter.h" + +typedef void (*PowerFun)(const float *, const float *, float *, int, float, float); +typedef float (*PowerScalarFun)(float x, const float *exponent); + +#ifdef __cplusplus +extern "C" { +#endif +static inline bool CheckInteger(float f) { return fabsf(f - (int)(f)) < 0.000001; } + +static inline float StdPowerScalar(float x, const float *exponent) { return powf(x, *exponent); } + +int Power(const float *input, const float *exponent, float *output, int len, float scale, float shift, bool broadcast); +void PowerSingle(const float *input, const float *exponent, float *output, int len, float scale, float shift); +void PowerBroadCast(const float *input, const float *exponent, float *output, int len, float scale, float shift); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_POWER_FP32_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/power_fp32_simd.h.in b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/power_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..2ae40d0c61d1194591930bcf5a6290ecce742850 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/power_fp32_simd.h.in @@ -0,0 +1,94 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_POWER_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_POWER_@SIMD_INSTRUCTION@_H_ + +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int PowerBroadCastIntExponent@SIMD_INSTRUCTION@(int index, const float *input, int exponent, float *output, int len, + float scale, float shift) { + SIMD_F32 scale_vec = SIMD_MOV_F32(scale); + SIMD_F32 shift_vec = SIMD_MOV_F32(shift); + for (int block_max_size = len - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 tmp = SIMD_FMADD_F32(scale_vec, SIMD_LD_F32(input + index), shift_vec); + SIMD_F32 result = SIMD_MOV_F32(1.0f); + int exp = abs(exponent); + while (exp) { + if (exp % 2) { + result = SIMD_MUL_F32(result, tmp); + } + tmp = SIMD_MUL_SQUARE_F32(tmp); + exp = exp / 2; + } + SIMD_ST_F32(output + index, exponent >= 0 ? result : SIMD_DIV_F32(SIMD_MOV_F32(1), result)); + } + return index; +} + +static inline int PowerBroadCastFloatExponent@SIMD_INSTRUCTION@(int index, const float *input, float exponent, float *output, int len, + float scale, float shift) { + SIMD_F32 scale_vec = SIMD_MOV_F32(scale); + SIMD_F32 shift_vec = SIMD_MOV_F32(shift); + for (int block_max_size = len - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 tmp = SIMD_FMADD_F32(scale_vec, SIMD_LD_F32(input + index), shift_vec); + SIMD_F32 result; + for (int i = 0; i < BLOCK_NUM; ++i) { + SIMD_F32_GETI(result, i) = powf(SIMD_F32_GETI(tmp, i), exponent); + } + SIMD_ST_F32(output + index, result); + } + return index; +} + +static inline int PowerSingleExponent@SIMD_INSTRUCTION@(int index, const float *input, const float *exponent, float *output, int len, + float scale, float shift) { + SIMD_F32 scale_vec = SIMD_MOV_F32(scale); + SIMD_F32 shift_vec = SIMD_MOV_F32(shift); + for (int block_max_size = len - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 tmp_vec = SIMD_FMADD_F32(scale_vec, SIMD_LD_F32(input + index), shift_vec); + for (int j = 0; j < BLOCK_NUM; ++j) { + float cur_exponent = exponent[index + j]; + float cur_val = SIMD_F32_GETI(tmp_vec, j); + if (fabsf(cur_exponent - (int)(cur_exponent)) < 0.000001) { + int exp = abs((int)(cur_exponent)); + float result = 1; + while (exp) { + if (exp % 2) { + result *= cur_val; + } + cur_val *= cur_val; + exp = exp / 2; + } + output[index + j] = *exponent >= 0 ? result : 1 / result; + } else { + output[index + j] = powf(cur_val, cur_exponent); + } + } + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/prelu_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/prelu_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..22319a500db592d948769e72b23b2e67d9625b23 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/prelu_fp32.c @@ -0,0 +1,164 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * +// * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp32/prelu_fp32.h" +#include "nnacl/intrinsics/ms_simd_instructions.h" + +#ifdef ENABLE_ARM64 +static inline void PRelu4x16(const float *in, float *out, const float *cur_slope, size_t step) { + asm volatile( + "mov x10, %[in]\n" + "mov x11, %[out]\n" + "mov x12, %[cur_slope]\n" + "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x12]\n" + "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], %[step]\n" + "fmul v16.4s, v0.4s, v4.4s\n" + "fmul v17.4s, v1.4s, v5.4s\n" + "fmul v18.4s, v2.4s, v6.4s\n" + "fmul v19.4s, v3.4s, v7.4s\n" + "fcmgt v20.4s, v0.4s, #0\n" + "fcmgt v21.4s, v1.4s, #0\n" + "fcmgt v22.4s, v2.4s, #0\n" + "fcmgt v23.4s, v3.4s, #0\n" + "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], %[step]\n" + "bif v0.16b, v16.16b, v20.16b\n" + "bif v1.16b, v17.16b, v21.16b\n" + "bif v2.16b, v18.16b, v22.16b\n" + "bif v3.16b, v19.16b, v23.16b\n" + "fmul v8.4s, v24.4s, v4.4s\n" + "fmul v9.4s, v25.4s, v5.4s\n" + "fmul v10.4s, v26.4s, v6.4s\n" + "fmul v11.4s, v27.4s, v7.4s\n" + "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x11], %[step]\n" + "fcmgt v12.4s, v24.4s, #0\n" + "fcmgt v13.4s, v25.4s, #0\n" + "fcmgt v14.4s, v26.4s, #0\n" + "fcmgt v15.4s, v27.4s, #0\n" + "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], %[step]\n" + "bif v24.16b, v8.16b, v12.16b\n" + "bif v25.16b, v9.16b, v13.16b\n" + "bif v26.16b, v10.16b, v14.16b\n" + "bif v27.16b, v11.16b, v15.16b\n" + "fmul v16.4s, v0.4s, v4.4s\n" + "fmul v17.4s, v1.4s, v5.4s\n" + "fmul v18.4s, v2.4s, v6.4s\n" + "fmul v19.4s, v3.4s, v7.4s\n" + "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x11], %[step]\n" + "fcmgt v20.4s, v0.4s, #0\n" + "fcmgt v21.4s, v1.4s, #0\n" + "fcmgt v22.4s, v2.4s, #0\n" + "fcmgt v23.4s, v3.4s, #0\n" + "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10]\n" + "bif v0.16b, v16.16b, v20.16b\n" + "bif v1.16b, v17.16b, v21.16b\n" + "bif v2.16b, v18.16b, v22.16b\n" + "bif v3.16b, v19.16b, v23.16b\n" + "fmul v8.4s, v24.4s, v4.4s\n" + "fmul v9.4s, v25.4s, v5.4s\n" + "fmul v10.4s, v26.4s, v6.4s\n" + "fmul v11.4s, v27.4s, v7.4s\n" + "fcmgt v12.4s, v24.4s, #0\n" + "fcmgt v13.4s, v25.4s, #0\n" + "fcmgt v14.4s, v26.4s, #0\n" + "fcmgt v15.4s, v27.4s, #0\n" + "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x11], %[step]\n" + "bif v24.16b, v8.16b, v12.16b\n" + "bif v25.16b, v9.16b, v13.16b\n" + "bif v26.16b, v10.16b, v14.16b\n" + "bif v27.16b, v11.16b, v15.16b\n" + "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x11]\n" + : + : [ in ] "r"(in), [ out ] "r"(out), [ cur_slope ] "r"(cur_slope), [ step ] "r"(step) + : "x10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", + "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27"); +} +#endif + +void PRelu(const float *input, float *output, const float *slope, int start, int end, int channel) { + int i = start; +#ifdef ENABLE_ARM64 + for (; i < end - 3; i += 4) { + const float *cur_in = input + i * channel; + float *cur_out = output + i * channel; + int j = 0; + for (; j < channel - 15; j += 16) { + const float *in = cur_in + j; + float *out = cur_out + j; + const float *cur_slope = slope + j; + size_t step = channel * sizeof(float); + PRelu4x16(in, out, cur_slope, step); + } + for (; j < channel; j++) { + cur_out[j] = (cur_in[j] > 0) ? cur_in[j] : (cur_in[j] * slope[j]); + cur_out[j + channel] = (cur_in[j + channel] > 0) ? cur_in[j + channel] : cur_in[j + channel] * slope[j]; + cur_out[j + 2 * channel] = + (cur_in[j + 2 * channel] > 0) ? cur_in[j + 2 * channel] : (cur_in[j + 2 * channel] * slope[j]); + cur_out[j + 3 * channel] = + (cur_in[j + 3 * channel] > 0) ? cur_in[j + 3 * channel] : (cur_in[j + 3 * channel] * slope[j]); + } + } +#endif + for (; i < end; i++) { + const float *cur_in = input + i * channel; + float *cur_out = output + i * channel; + int j = 0; +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + for (; j < channel - 3; j += 4) { + MS_FLOAT32X4 in = MS_LDQ_F32(cur_in + j); + MS_FLOAT32X4 s = MS_LDQ_F32(slope + j); + MS_FLOAT32X4 mul = MS_MULQ_F32(in, s); + MS_FLOAT32X4 zero = MS_MOVQ_F32(0.0f); + MS_FLOAT32X4 res = MS_BLENDQ_F32(in, mul, MS_CMPLEQ_F32(in, zero)); + MS_STQ_F32(cur_out + j, res); + } +#endif + for (; j < channel; j++) { + if (cur_in[j] > 0) { + cur_out[j] = cur_in[j]; + } else { + cur_out[j] = cur_in[j] * slope[j]; + } + } + } +} + +void PReluShareChannel(const float *input, float *output, float slope, int start, int end) { + int i = start; +#if defined(ENABLE_AVX) +#define mask_offset 30 + for (; i <= end - C8NUM; i += C8NUM) { + MS_FLOAT32X8 src_tmp = MS_LD256_F32(input + i); + MS_FLOAT32X8 mul_tmp = MS_MUL256_N_F32(src_tmp, slope); + MS_FLOAT32X8 mask = MS_CMP256_F32(src_tmp, MS_MOV256_F32(0.0f), mask_offset); + MS_ST256_F32(output + i, MS_BLEND256_F32(mul_tmp, src_tmp, mask)); + } +#endif + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + for (; i <= end - C4NUM; i += C4NUM) { + MS_FLOAT32X4 src_tmp = MS_LDQ_F32(input + i); + MS_FLOAT32X4 mul_tmp = MS_MULQ_N_F32(src_tmp, slope); +#ifdef ENABLE_ARM + MS_UINT32X4 mask = MS_CMPLEQ_F32(src_tmp, MS_MOVQ_F32(0.0f)); +#else + MS_FLOAT32X4 mask = MS_CMPLEQ_F32(src_tmp, MS_MOVQ_F32(0.0f)); +#endif + MS_STQ_F32(output + i, MS_BLENDQ_F32(src_tmp, mul_tmp, mask)); + } +#endif + for (; i < end; i++) { + output[i] = input[i] > 0 ? input[i] : input[i] * slope; + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/prelu_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/prelu_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..d8a335e8b3a3aab9d977a52f8a869e6e62b106b3 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/prelu_fp32.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_PRELU_H_ +#define MINDSPORE_NNACL_FP32_PRELU_H_ + +#include "nnacl/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +void PRelu(const float *input, float *output, const float *slope, int start, int end, int channel); + +void PReluShareChannel(const float *input, float *output, float slope, int start, int end); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_PRELU_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/prior_box_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/prior_box_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..56595814b7f361061da4bf0df72dbefea559eb60 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/prior_box_fp32.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_PRIOR_BOX_FP32_H_ +#define MINDSPORE_NNACL_FP32_PRIOR_BOX_FP32_H_ + +#include +#include "nnacl/op_base.h" +#include "nnacl/errorcode.h" +#include "nnacl/prior_box_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +static int PriorBox(const float *input_data, float *output_data, const size_t size, const int tid, + const int thread_num) { + NNACL_CHECK_NULL_RETURN_ERR(input_data); + NNACL_CHECK_NULL_RETURN_ERR(output_data); + NNACL_CHECK_ZERO_RETURN_ERR(thread_num); + size_t unit_size = size / thread_num; + size_t copy_size = (tid == thread_num - 1) ? size - unit_size * tid : unit_size; + (void)memcpy(output_data + tid * unit_size, input_data + tid * unit_size, copy_size * sizeof(float)); + return NNACL_OK; +} +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_PRIOR_BOX_FP32_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/ragged_range_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/ragged_range_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..a6ba55809ed31acdbd43c21d5f8b7e1b686f122e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/ragged_range_fp32.c @@ -0,0 +1,52 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/ragged_range_fp32.h" +#include +#include "nnacl/op_base.h" + +void RaggedRangeFp32(const float *starts, const float *limits, const float *deltas, int32_t *splits, float *value, + RaggedRangeStruct *ragged_range) { + splits[0] = 0; + for (int i = 0; i < ragged_range->rows_; i++) { + float start = ragged_range->starts_is_scalar_ ? starts[0] : starts[i]; + float limit = ragged_range->limits_is_scalar_ ? limits[0] : limits[i]; + float delta = ragged_range->deltas_is_scalar_ ? deltas[0] : deltas[i]; + int len = NNACL_MAX((int)ceil((float)(limit - start) / delta), 0); + splits[i + 1] = splits[i] + len; + for (int j = 0; j < len; j++) { + *value++ = start; + start += delta; + } + } +} + +void RaggedRangeInt(const int32_t *starts, const int32_t *limits, const int32_t *deltas, int32_t *splits, + int32_t *value, RaggedRangeStruct *ragged_range) { + splits[0] = 0; + for (int i = 0; i < ragged_range->rows_; i++) { + int start = ragged_range->starts_is_scalar_ ? starts[0] : starts[i]; + int limit = ragged_range->limits_is_scalar_ ? limits[0] : limits[i]; + int delta = ragged_range->deltas_is_scalar_ ? deltas[0] : deltas[i]; + NNACL_CHECK_ZERO_RETURN(delta); + int len = NNACL_MAX((int)ceil((float)(limit - start) / delta), 0); + splits[i + 1] = splits[i] + len; + for (int j = 0; j < len; j++) { + *value++ = start; + start += delta; + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/ragged_range_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/ragged_range_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..64a758cc14b732f7f6129e2e79942412f379edb0 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/ragged_range_fp32.h @@ -0,0 +1,26 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP32_RAGGED_RANGE_FP32_H_ +#define NNACL_FP32_RAGGED_RANGE_FP32_H_ + +#include "nnacl/kernel/ragged_range.h" + +void RaggedRangeFp32(const float *starts, const float *limits, const float *deltas, int32_t *splits, float *value, + RaggedRangeStruct *ragged_range); +void RaggedRangeInt(const int32_t *starts, const int32_t *limits, const int32_t *deltas, int32_t *splits, + int32_t *value, RaggedRangeStruct *ragged_range); + +#endif // NNACL_FP32_RAGGED_RANGE_FP32_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/range_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/range_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..90bde7b2aa1538600dd94e40f014aa71d3e23608 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/range_fp32.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_RANGE_FP32_H_ +#define NNACL_FP32_RANGE_FP32_H_ + +#include "nnacl/op_base.h" + +void Range(float *output_ptr, float start, float delta, int nums) { + for (int i = 0; i < nums; ++i, start += delta) { + output_ptr[i] = start; + } +} + +void RangeInt(int32_t *output_ptr, int start, int delta, int nums) { + for (int i = 0; i < nums; ++i, start += delta) { + output_ptr[i] = start; + } +} + +#endif // NNACL_FP32_RANGE_FP32_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/rank_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/rank_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..7740e045eddd56ffe1988f97c5adbd3fdbc0b514 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/rank_fp32.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_RANK_H_ +#define MINDSPORE_NNACL_RANK_H_ + +#include "nnacl/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +inline void Rank(float *output, int rank) { + output[0] = (float)(rank); + return; +} +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_RANK_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/reduce_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/reduce_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..50604a4b513712baf28367df65f5e443ced5f910 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/reduce_fp32.c @@ -0,0 +1,359 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/reduce_fp32.h" +#include +#include "nnacl/errorcode.h" +#include "nnacl/common_func.h" +#include "nnacl/reduce_fp32_simd.h" +#ifdef ENABLE_NNACL_INFER_SHAPE +#include "nnacl/reduce_parameter.h" +#endif + +// 32 bits, block_size : (512/256/128/32), block_num : (16/8/4/1) +#define ReduceCoreCalc(op_name, op_type, outer_src, outer_dst, k) \ + for (; k < inner_size; k++) { \ + const op_type *inner_src = outer_src + k; \ + op_name##PreDeal; \ + for (int i = 0; i < axis_size; i++) { \ + op_name##MidCalc; \ + } \ + op_name##PostDeal; \ + } + +#define RegReduceOp(op_name, op_type) \ + int op_name(int outer_size, int inner_size, int axis_size, const op_type *src_data, op_type *dst_data, int tid, \ + int thread_num) { \ + NNACL_CHECK_TRUE_RET(src_data != NULL && dst_data != NULL, NNACL_NULL_PTR); \ + NNACL_CHECK_TRUE_RET(thread_num > 0, NNACL_PARAM_INVALID); \ + NNACL_CHECK_TRUE_RET(axis_size > 0, NNACL_ERR); \ + for (int j = tid; j < outer_size; j += thread_num) { \ + const op_type *outer_src = src_data + j * axis_size * inner_size; \ + op_type *outer_dst = dst_data + j * inner_size; \ + int k = 0; \ + SIMD_RUN_NO_SCALAR(op_name, k, outer_src, outer_dst, inner_size, axis_size); \ + \ + ReduceCoreCalc(op_name, op_type, outer_src, outer_dst, k); \ + } \ + return NNACL_OK; \ + } + +// ReduceSum +#define ReduceSumPreDeal float tmp = 0; +#define ReduceSumMidCalc tmp += inner_src[i * inner_size]; +#define ReduceSumPostDeal outer_dst[k] = tmp; +RegReduceOp(ReduceSum, float); + +int ReduceSumByLastAxis(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid, + int thread_num) { + NNACL_CHECK_TRUE_RET(src_data != NULL && dst_data != NULL, NNACL_NULL_PTR); + NNACL_CHECK_TRUE_RET(thread_num > 0, NNACL_PARAM_INVALID); + NNACL_CHECK_TRUE_RET(axis_size > 0, NNACL_ERR); + + for (int j = tid; j < outer_size; j += thread_num) { + const float *src_tmp = src_data + j * axis_size; + + float tmp = src_tmp[0]; + int i = 1; + + SIMD_RUN_NO_SCALAR(ReduceSumByLastAxis, i, src_tmp, &tmp, axis_size); + for (; i < axis_size; i++) { + tmp += src_tmp[i]; + } + dst_data[j] = tmp; + } + return NNACL_OK; +} + +// ReduceMean +#define ReduceMeanPreDeal float tmp = 0; +#define ReduceMeanMidCalc tmp += inner_src[i * inner_size]; +#define ReduceMeanPostDeal outer_dst[k] = tmp / axis_size; +RegReduceOp(ReduceMean, float); + +// ReduceMin +#define ReduceMinPreDeal float tmp = FLT_MAX; +#define ReduceMinMidCalc tmp = fminf(tmp, inner_src[i * inner_size]); +#define ReduceMinPostDeal outer_dst[k] = tmp; +RegReduceOp(ReduceMin, float); + +// ReduceMax +#define ReduceMaxPreDeal float tmp = FLT_MIN; +#define ReduceMaxMidCalc tmp = fmaxf(tmp, inner_src[i * inner_size]); +#define ReduceMaxPostDeal outer_dst[k] = tmp; +RegReduceOp(ReduceMax, float); + +int ReduceMaxByLastAxis(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid, + int thread_num) { + NNACL_CHECK_TRUE_RET(src_data != NULL && dst_data != NULL, NNACL_NULL_PTR); + NNACL_CHECK_TRUE_RET(thread_num > 0, NNACL_PARAM_INVALID); + NNACL_CHECK_TRUE_RET(axis_size > 0, NNACL_ERR); + + for (int j = tid; j < outer_size; j += thread_num) { + const float *src_tmp = src_data + j * axis_size; + + float tmp = src_tmp[0]; + int i = 1; + + SIMD_RUN_NO_SCALAR(ReduceMaxByLastAxis, i, src_tmp, &tmp, axis_size); + for (; i < axis_size; i++) { + tmp = fmaxf(tmp, src_tmp[i]); + } + dst_data[j] = tmp; + } + return NNACL_OK; +} + +// ReduceProd +#define ReduceProdPreDeal float tmp = 1.0f; +#define ReduceProdMidCalc tmp *= inner_src[i * inner_size]; +#define ReduceProdPostDeal outer_dst[k] = tmp; +RegReduceOp(ReduceProd, float); + +// ReduceSumSquare +#define ReduceSumSquarePreDeal float tmp = 0; +#define ReduceSumSquareMidCalc tmp += (inner_src[i * inner_size] * inner_src[i * inner_size]); +#define ReduceSumSquarePostDeal outer_dst[k] = tmp; +RegReduceOp(ReduceSumSquare, float); + +// ReduceL2Norm +#define ReduceL2NormPreDeal float tmp = 0; +#define ReduceL2NormMidCalc tmp += (inner_src[i * inner_size] * inner_src[i * inner_size]); +#define ReduceL2NormPostDeal outer_dst[k] = sqrt(tmp); +RegReduceOp(ReduceL2Norm, float); + +// IntReduceSum +#define IntReduceSumPreDeal int tmp = 0; +#define IntReduceSumMidCalc tmp += inner_src[i * inner_size]; +#define IntReduceSumPostDeal outer_dst[k] = tmp; +RegReduceOp(IntReduceSum, int32_t); + +// IntReduceMean +#define IntReduceMeanPreDeal int tmp = 0; +#define IntReduceMeanMidCalc tmp += inner_src[i * inner_size]; +#define IntReduceMeanPostDeal outer_dst[k] = tmp / axis_size; +RegReduceOp(IntReduceMean, int32_t); + +// IntReduceMin +#define IntReduceMinPreDeal int tmp = INT32_MAX; +#define IntReduceMinMidCalc tmp = MSMIN(tmp, inner_src[i * inner_size]); +#define IntReduceMinPostDeal outer_dst[k] = tmp; +RegReduceOp(IntReduceMin, int32_t); + +// IntReduceMax +#define IntReduceMaxPreDeal int tmp = INT32_MIN; +#define IntReduceMaxMidCalc tmp = MSMAX(tmp, inner_src[i * inner_size]); +#define IntReduceMaxPostDeal outer_dst[k] = tmp; +RegReduceOp(IntReduceMax, int32_t); + +int ReduceAll(int outer_size, int inner_size, int axis_size, const bool *src_data, bool *dst_data, int tid, + int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + if (thread_num == 0) { + return NNACL_PARAM_INVALID; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const bool *outer_src = src_data + j * axis_size * inner_size; + bool *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const bool *inner_src = outer_src + k; + bool *inner_dst = outer_dst + k; + bool tmp = true; + for (i = 0; i < axis_size; i++) { + tmp = tmp && inner_src[i * inner_size]; + } + *inner_dst = tmp; + } + } + return NNACL_OK; +} + +int IntReduceProd(int outer_size, int inner_size, int axis_size, const int32_t *src_data, int32_t *dst_data, int tid, + int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + if (thread_num == 0) { + return NNACL_PARAM_INVALID; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const int32_t *outer_src = src_data + j * axis_size * inner_size; + int32_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const int32_t *inner_src = outer_src + k; + int32_t *inner_dst = outer_dst + k; + int tmp = 1; + for (i = 0; i < axis_size; i++) { + if (isMulOverflow(tmp, inner_src[i * inner_size])) { + return NNACL_ERRCODE_MUL_OVERFLOW; + } + tmp *= inner_src[i * inner_size]; + } + *inner_dst = tmp; + } + } + return NNACL_OK; +} + +#ifdef ENABLE_NNACL_INFER_SHAPE +int ReduceInferShape(int32_t **in_shape, size_t *dim_size, int32_t *out_shape, int32_t *in_format, int32_t *out_format, + int32_t *in_datatype, int32_t *out_datatype, OpParameter *param) { + *out_format = in_format[0]; + *out_datatype = in_datatype[0]; + ReduceParameter *reduce_parameter = (ReduceParameter *)param; + bool keep_dims = reduce_parameter->keep_dims_; + int num_axes = reduce_parameter->num_axes_; + int32_t *in_shape0 = in_shape[0]; + int rank = dim_size[0]; + NNACL_CHECK_TRUE_RET(rank > 0 && rank <= REDUCE_MAX_AXES_NUM, NNACL_PARAM_INVALID); + int axes[REDUCE_MAX_AXES_NUM]; + int actual_axes_num = num_axes; + for (int i = 0; i < num_axes; ++i) { + NNACL_CHECK_TRUE_RET(reduce_parameter->axes_[i] >= -rank && reduce_parameter->axes_[i] < rank, NNACL_PARAM_INVALID); + if (reduce_parameter->axes_[i] < 0) { + axes[i] = reduce_parameter->axes_[i] + rank; + } else { + axes[i] = reduce_parameter->axes_[i]; + } + } + if (reduce_parameter->reduce_to_end_) { + NNACL_CHECK_TRUE_RET(num_axes == 1, NNACL_PARAM_INVALID); + int begin_axis = axes[0]; + num_axes = rank - begin_axis; + for (int i = begin_axis + 1; i < rank; ++i) { + axes[actual_axes_num++] = i; + } + } + if (num_axes == 0) { + int j = 0; + for (int i = 0; i < rank; ++i) { + axes[i] = i; + if (keep_dims) { + out_shape[j++] = 1; + } + } + reduce_parameter->num_axes_ = rank; + for (int i = 0; i < rank; ++i) { + reduce_parameter->axes_[i] = axes[i]; + } + return NNACL_OK; + } + // reduce on selected axes + int j = 0; + for (int i = 0; i < rank; ++i) { + bool reduce_axis = false; + for (int idx = 0; idx < num_axes; ++idx) { + if (axes[idx] == i) { + reduce_axis = true; + break; + } + } + if (reduce_axis) { + if (keep_dims) { + out_shape[j++] = 1; + } + } else { + out_shape[j++] = in_shape0[i]; + } + } + reduce_parameter->num_axes_ = num_axes; + for (int i = 0; i < num_axes; ++i) { + reduce_parameter->axes_[i] = axes[i]; + } + return NNACL_OK; +} +#endif + +// [A, B] -> [B] +// col_size : start -> end for parallel +int ReduceSumDim2Axis0(size_t col_size, size_t col_len, size_t row_len, const float *src_data, float *dst_data) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + + size_t k = 0; + SIMD_RUN_NO_SCALAR(ReduceSumDim2Axis0, k, col_size, col_len, row_len, src_data, dst_data); + for (; k < col_size; k++) { + const float *inner_src = src_data + k; + float *inner_dst = dst_data + k; + float tmp = 0.0f; + for (size_t i = 0; i < row_len; i++) { + tmp += inner_src[i * col_len]; + } + *inner_dst = tmp; + } + return NNACL_OK; +} + +// [A, B] -> [A] +int ReduceSumDim2Axis1(size_t col_len, const float *src_data, float *dst_data) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + size_t k = 0; + float tmp = 0; +#ifdef ENABLE_AVX + size_t block_mod = col_len % C8NUM; + size_t block_c8 = col_len - block_mod; + float tmp_arr[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + MS_FLOAT32X8 tmp_arr_8 = MS_MOV256_F32(tmp_arr[0]); + for (; k < block_c8; k += C8NUM) { + MS_FLOAT32X8 src_in = MS_LD256_F32(src_data + k); + tmp_arr_8 = MS_ADD256_F32(tmp_arr_8, src_in); + } + MS_ST256_F32(tmp_arr, tmp_arr_8); + for (size_t i = 0; i < 8; ++i) { + tmp += tmp_arr[i]; + } +#endif + for (; k < col_len; k++) { + tmp += src_data[k]; + } + dst_data[0] = tmp; + return NNACL_OK; +} + +int ReduceMeanWithAxis(const float *src_data, float *mean, int64_t size) { + if (size == 0 || src_data == NULL) { + return NNACL_NULL_PTR; + } + float sum = 0.0; + int64_t i = 0; + SIMD_RUN_NO_SCALAR(ReduceSumByLastAxis, i, src_data, &sum, 0); + for (; i < size; ++i) { + sum += src_data[i]; + } + *mean = sum / size; + return NNACL_OK; +} + +int ReduceDeviation(const float *src_data, int64_t size, float mean, float *deviation) { + if (size == 0 || src_data == NULL) { + return NNACL_NULL_PTR; + } + int64_t i = 0; + SIMD_RUN_NO_SCALAR(FloatReduceDeviation, i, src_data, mean, size, deviation); + for (; i < size; ++i) { + float tmp = src_data[i] - mean; + tmp = tmp * tmp; + *deviation += tmp; + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/reduce_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/reduce_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..c530a66e1184de09c54180384fc53bd37ee23618 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/reduce_fp32.h @@ -0,0 +1,69 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_REDUCE_H_ +#define MINDSPORE_NNACL_FP32_REDUCE_H_ +#include +#include "nnacl/op_base.h" +#include "nnacl/reduce_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int ReduceMean(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid, + int thread_num); +int IntReduceMean(int outer_size, int inner_size, int axis_size, const int32_t *src_data, int32_t *dst_data, int tid, + int thread_num); +int ReduceSum(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid, + int thread_num); +int ReduceSumByLastAxis(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid, + int thread_num); +int IntReduceSum(int outer_size, int inner_size, int axis_size, const int32_t *src_data, int32_t *dst_data, int tid, + int thread_num); +int ReduceMax(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid, + int thread_num); +int ReduceMaxByLastAxis(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid, + int thread_num); +int IntReduceMax(int outer_size, int inner_size, int axis_size, const int32_t *src_data, int32_t *dst_data, int tid, + int thread_num); +int ReduceMin(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid, + int thread_num); +int IntReduceMin(int outer_size, int inner_size, int axis_size, const int32_t *src_data, int32_t *dst_data, int tid, + int thread_num); +int ReduceProd(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid, + int thread_num); +int IntReduceProd(int outer_size, int inner_size, int axis_size, const int32_t *src_data, int32_t *dst_data, int tid, + int thread_num); +int ReduceSumSquare(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid, + int thread_num); +int ReduceL2Norm(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid, + int thread_num); +int ReduceAll(int outer_size, int inner_size, int axis_size, const bool *src_data, bool *dst_data, int tid, + int thread_num); +int ReduceSumDim2Axis0(size_t col_size, size_t col_len, size_t row_len, const float *src_data, float *dst_data); +int ReduceSumDim2Axis1(size_t col_len, const float *src_data, float *dst_data); +int ReduceMeanWithAxis(const float *src_data, float *mean, int64_t size); +int ReduceDeviation(const float *src_data, int64_t size, float mean, float *deviation); + +#ifdef ENABLE_NNACL_INFER_SHAPE +int ReduceInferShape(int32_t **in_shape, size_t *dim_size, int32_t *out_shape, int32_t *in_format, int32_t *out_format, + int32_t *in_datatype, int32_t *out_datatype, OpParameter *param); +#endif +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_REDUCE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/reduce_fp32_simd.h.in b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/reduce_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..203e35717a51aadccfd5fdd9652fe8cff264eef4 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/reduce_fp32_simd.h.in @@ -0,0 +1,220 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// clang-format off +#ifndef MINDSPORE_NNACL_FP32_REDUCE_FP32_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_REDUCE_FP32_@SIMD_INSTRUCTION@_H_ + +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int64_t ReduceSum@SIMD_INSTRUCTION@(int64_t index, const float *outer_src, float *outer_dst, int inner_size, + int axis_size) { + for (int block_max_size = inner_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + const float *inner_src = outer_src + index; + SIMD_F32 tmp = SIMD_MOV_F32(0); + for (int i = 0; i < axis_size; i++) { + tmp = SIMD_ADD_F32(tmp, SIMD_LD_F32(inner_src + i * inner_size)); + } + SIMD_ST_F32(outer_dst + index, tmp); + } + return index; +} + +static inline int64_t ReduceSumByLastAxis@SIMD_INSTRUCTION@(int64_t index, const float *src, float* tmp_sum, int axis_size) { + SIMD_F32 tmp = SIMD_MOV_F32(0); + for (int block_max_size = axis_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + tmp = SIMD_ADD_F32(tmp, SIMD_LD_F32(src + index)); + } + *tmp_sum += SIMD_GET_SUM_F32(tmp); + return index; +} + +static inline int64_t ReduceMean@SIMD_INSTRUCTION@(int64_t index, const float *outer_src, float *outer_dst, int inner_size, + int axis_size) { + for (int block_max_size = inner_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + const float *inner_src = outer_src + index; + SIMD_F32 tmp = SIMD_MOV_F32(0); + for (int i = 0; i < axis_size; i++) { + tmp = SIMD_ADD_F32(tmp, SIMD_LD_F32(inner_src + i * inner_size)); + } + SIMD_ST_F32(outer_dst + index, SIMD_DIV_N_F32(tmp, axis_size)); + } + return index; +} + +static inline int64_t ReduceMin@SIMD_INSTRUCTION@(int64_t index, const float *outer_src, float *outer_dst, int inner_size, + int axis_size) { + for (int block_max_size = inner_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + const float *inner_src = outer_src + index; + SIMD_F32 tmp = SIMD_MOV_F32(FLT_MAX); + for (int i = 0; i < axis_size; i++) { + tmp = SIMD_MIN_F32(tmp, SIMD_LD_F32(inner_src + i * inner_size)); + } + SIMD_ST_F32(outer_dst + index, tmp); + } + return index; +} + +static inline int64_t ReduceMax@SIMD_INSTRUCTION@(int64_t index, const float *outer_src, float *outer_dst, int inner_size, + int axis_size) { + for (int block_max_size = inner_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + const float *inner_src = outer_src + index; + SIMD_F32 tmp = SIMD_MOV_F32(FLT_MIN); + for (int i = 0; i < axis_size; i++) { + tmp = SIMD_MAX_F32(tmp, SIMD_LD_F32(inner_src + i * inner_size)); + } + SIMD_ST_F32(outer_dst + index, tmp); + } + return index; +} + +static inline int64_t ReduceMaxByLastAxis@SIMD_INSTRUCTION@(int64_t index, const float *src, float* tmp_max, int axis_size) { + SIMD_F32 tmp = SIMD_MOV_F32(*tmp_max); + for (int block_max_size = axis_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + tmp = SIMD_MAX_F32(tmp, SIMD_LD_F32(src + index)); + } + *tmp_max = SIMD_GET_MAX_F32(tmp); + return index; +} + +static inline int64_t ReduceProd@SIMD_INSTRUCTION@(int64_t index, const float *outer_src, float *outer_dst, int inner_size, + int axis_size) { + for (int block_max_size = inner_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + const float *inner_src = outer_src + index; + SIMD_F32 tmp = SIMD_MOV_F32(1.0f); + for (int i = 0; i < axis_size; i++) { + tmp = SIMD_MUL_F32(tmp, SIMD_LD_F32(inner_src + i * inner_size)); + } + SIMD_ST_F32(outer_dst + index, tmp); + } + return index; +} + +static inline int64_t ReduceSumSquare@SIMD_INSTRUCTION@(int64_t index, const float *outer_src, float *outer_dst, int inner_size, + int axis_size) { + for (int block_max_size = inner_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + const float *inner_src = outer_src + index; + SIMD_F32 tmp = SIMD_MOV_F32(0); + for (int i = 0; i < axis_size; i++) { + tmp = SIMD_ADD_F32(tmp, SIMD_MUL_SQUARE_F32(SIMD_LD_F32(inner_src + i * inner_size))); + } + SIMD_ST_F32(outer_dst + index, tmp); + } + return index; +} + +static inline int64_t ReduceL2Norm@SIMD_INSTRUCTION@(int64_t index, const float *outer_src, float *outer_dst, int inner_size, + int axis_size) { + for (int block_max_size = inner_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + const float *inner_src = outer_src + index; + SIMD_F32 tmp = SIMD_MOV_F32(0); + for (int i = 0; i < axis_size; i++) { + tmp = SIMD_ADD_F32(tmp, SIMD_MUL_SQUARE_F32(SIMD_LD_F32(inner_src + i * inner_size))); + } + SIMD_ST_F32(outer_dst + index, SIMD_SQRT_F32(tmp)); + } + return index; +} + +static inline int64_t IntReduceSum@SIMD_INSTRUCTION@(int64_t index, const int32_t *outer_src, int32_t *outer_dst, int inner_size, + int axis_size) { + for (int block_max_size = inner_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + const int32_t *inner_src = outer_src + index; + SIMD_EPI32 tmp = SIMD_MOV_EPI32(0); + for (int i = 0; i < axis_size; i++) { + tmp = SIMD_ADD_EPI32(tmp, SIMD_LD_EPI32(inner_src + i * inner_size)); + } + SIMD_ST_EPI32(outer_dst + index, tmp); + } + return index; +} + +static inline int64_t IntReduceMean@SIMD_INSTRUCTION@(int64_t index, const int32_t *outer_src, int32_t *outer_dst, int inner_size, + int axis_size) { + for (int block_max_size = inner_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + const int32_t *inner_src = outer_src + index; + SIMD_EPI32 tmp = SIMD_MOV_EPI32(0); + for (int i = 0; i < axis_size; i++) { + tmp = SIMD_ADD_EPI32(tmp, SIMD_LD_EPI32(inner_src + i * inner_size)); + } + SIMD_ST_EPI32(outer_dst + index, SIMD_DIV_N_EPI32(tmp, axis_size)); + } + return index; +} + +static inline int64_t IntReduceMin@SIMD_INSTRUCTION@(int64_t index, const int32_t *outer_src, int32_t *outer_dst, int inner_size, + int axis_size) { + for (int block_max_size = inner_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + const int32_t *inner_src = outer_src + index; + SIMD_EPI32 tmp = SIMD_MOV_EPI32(INT32_MAX); + for (int i = 0; i < axis_size; i++) { + tmp = SIMD_MIN_EPI32(tmp, SIMD_LD_EPI32(inner_src + i * inner_size)); + } + SIMD_ST_EPI32(outer_dst + index, tmp); + } + return index; +} + +static inline int64_t IntReduceMax@SIMD_INSTRUCTION@(int64_t index, const int32_t *outer_src, int32_t *outer_dst, int inner_size, + int axis_size) { + for (int block_max_size = inner_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + const int32_t *inner_src = outer_src + index; + SIMD_EPI32 tmp = SIMD_MOV_EPI32(INT32_MIN); + for (int i = 0; i < axis_size; i++) { + tmp = SIMD_MAX_EPI32(tmp, SIMD_LD_EPI32(inner_src + i * inner_size)); + } + SIMD_ST_EPI32(outer_dst + index, tmp); + } + return index; +} + +static inline int64_t ReduceSumDim2Axis0@SIMD_INSTRUCTION@(int64_t index, size_t col_size, size_t col_len, size_t row_len, const float *src_data, float *dst_data) { + for (int block_max_size = col_size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 tmp = SIMD_MOV_F32(0); + const float *inner_src = src_data + index; + float *inner_dst = dst_data + index; + for (size_t i = 0; i < row_len; ++i) { + tmp = SIMD_ADD_F32(tmp, SIMD_LD_F32(inner_src + i * col_len)); + } + SIMD_ST_F32(inner_dst, tmp); + } + return index; +} + +static inline int64_t FloatReduceDeviation@SIMD_INSTRUCTION@(int64_t index, const float *src_data, float mean, size_t size, float *deviation) { + SIMD_F32 fs_deviation = SIMD_MOV_F32(0); + SIMD_F32 fs_mean = SIMD_MOV_F32(mean); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 fs_sub = SIMD_LD_F32(src_data + index); + + fs_sub = SIMD_SUB_F32(fs_sub, fs_mean); + SIMD_F32 fs_pow = SIMD_MUL_F32(fs_sub, fs_sub); + fs_deviation = SIMD_ADD_F32(fs_deviation, fs_pow); + } + *deviation += SIMD_GET_SUM_F32(fs_deviation); + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/resize_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/resize_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..240ec136d90d106797c814ae5882849997ab1b33 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/resize_fp32.c @@ -0,0 +1,598 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32/resize_fp32.h" +#include "nnacl/common_func.h" +#include "nnacl/errorcode.h" +#include "nnacl/intrinsics/ms_simd_instructions.h" + +void CalculateCoordinate(float out, int in, int32_t *bottom, int32_t *top, float *bottom_weight) { + *bottom = (int)(floorf(out)); + *bottom = *bottom >= 0 ? *bottom : 0; // extrapolate may generate neg value + *top = *bottom + 1 < in ? (*bottom + 1) : (in - 1); + float top_weight = (float)out - (float)(*bottom); + *bottom_weight = 1.0f - top_weight; +} + +static void BicubicBaseFunc(float a, const float x, float *weight) { + float abs_x = fabsf(x); + if (abs_x >= 0 && abs_x <= 1) { + *weight = ((a + 2) * abs_x - (a + 3)) * abs_x * abs_x + 1; + } else if (abs_x > 1 && abs_x <= 2) { + *weight = a * abs_x * abs_x * abs_x - 5 * a * abs_x * abs_x + 8 * a * abs_x - 4 * a; + } else { + *weight = 0; + } +} + +// a is a coefficient +// W(x) = { (a + 2) * |x| * |x| * |x| - (a + 3) * |x| * |x| + 1, for |x| <= 1 +// { a * |x| * |x| * |x| - 5 * a * |x| * |x| + 8 * a *|x| - 4 * a, for 1 < |x| < 2 +// { 0, otherwise +// the value of 'a' depends on if is half_pixel_center(the scheme is the same as tf). +// If is half pixel mode, a equals to -0.5, otherwise -0.75. +void CalculateWeightForBicubic(float out, int in, int32_t *index, float *weights, float a) { + int floor_index = (int)(floorf(out)); + index[0] = (floor_index - 1) < 0 ? 0 : (floor_index - 1); + index[1] = floor_index; + index[2] = (floor_index + 1) < in ? (floor_index + 1) : (in - 1); + index[3] = (floor_index + 2) < in ? (floor_index + 2) : (in - 1); + + // get positive value + float distance[4] = {-1, 0, 1, 2}; + float tmp_dis = out - (float)floor_index; + distance[0] -= tmp_dis; + distance[1] -= tmp_dis; + distance[2] -= tmp_dis; + distance[3] -= tmp_dis; + + for (int i = 0; i < 4; ++i) { + BicubicBaseFunc(a, distance[i], &weights[i]); + } +} + +int PrepareResizeBilinear(const int32_t *input_shape, const int32_t *output_shape, + CalculateOriginalCoordinate calculate, int32_t *y_bottoms, int32_t *y_tops, int32_t *x_lefts, + int32_t *x_rights, float *y_bottom_weights, float *x_left_weights) { + if (input_shape == NULL || output_shape == NULL || y_bottoms == NULL || y_tops == NULL || x_lefts == NULL || + x_rights == NULL || y_bottom_weights == NULL || x_left_weights == NULL) { + return NNACL_NULL_PTR; + } + + int in_h = input_shape[1]; + int in_w = input_shape[2]; + + int new_height = output_shape[1]; + int new_width = output_shape[2]; + + for (int h = 0; h < new_height; h++) { + float actual_y = calculate(h, in_h, new_height); + CalculateCoordinate(actual_y, in_h, y_bottoms + h, y_tops + h, y_bottom_weights + h); + } + for (int w = 0; w < new_width; w++) { + float actual_x = calculate(w, in_w, new_width); + CalculateCoordinate(actual_x, in_w, x_lefts + w, x_rights + w, x_left_weights + w); + } + return NNACL_OK; +} + +int PrepareResizeBicubic(const int32_t *input_shape, const int32_t *output_shape, CalculateOriginalCoordinate calculate, + int32_t *y_tops, int32_t *x_lefts, float *y_weights, float *x_weights, float cubic_coeff) { + if (input_shape == NULL || output_shape == NULL || y_tops == NULL || x_lefts == NULL || y_weights == NULL || + x_weights == NULL) { + return NNACL_NULL_PTR; + } + + int in_h = input_shape[1]; + int in_w = input_shape[2]; + int new_height = output_shape[1]; + int new_width = output_shape[2]; + + for (int h = 0; h < new_height; h++) { + float actual_y = calculate(h, in_h, new_height); + CalculateWeightForBicubic(actual_y, in_h, y_tops + 4 * h, y_weights + 4 * h, cubic_coeff); + } + for (int w = 0; w < new_width; w++) { + float actual_x = calculate(w, in_w, new_width); + CalculateWeightForBicubic(actual_x, in_w, x_lefts + 4 * w, x_weights + 4 * w, cubic_coeff); + } + return NNACL_OK; +} + +int PrepareCropAndResizeBilinear(const int32_t *input_shape, const float *boxes, const int32_t *box_idx, + const int32_t *output_shape, int32_t *y_bottoms, int32_t *y_tops, int32_t *x_lefts, + int32_t *x_rights, float *y_bottom_weights, float *x_left_weights) { + if (input_shape == NULL || output_shape == NULL || y_bottoms == NULL || y_tops == NULL || x_lefts == NULL || + x_rights == NULL || y_bottom_weights == NULL || x_left_weights == NULL) { + return NNACL_NULL_PTR; + } + int in_h = input_shape[1]; + int in_w = input_shape[2]; + int batch = output_shape[0]; + int new_height = output_shape[1]; + int new_width = output_shape[2]; + float actual_x; + float actual_y; + + for (int b = 0; b < batch; b++) { + const float *box = boxes + b * 4; + float start_h = box[0]; + float end_h = box[2]; + float start_w = box[1]; + float end_w = box[3]; + + int32_t *y_bottom = y_bottoms + b * new_height; + int32_t *y_top = y_tops + b * new_height; + float *y_bottom_weight = y_bottom_weights + b * new_height; + int32_t *x_left = x_lefts + b * new_width; + int32_t *x_right = x_rights + b * new_width; + float *x_left_weight = x_left_weights + b * new_width; + for (int h = 0; h < new_height; h++) { + if (new_height > 1) { + actual_y = start_h * (in_h - 1) + h * (end_h - start_h) * (in_h - 1) / (new_height - 1); + } else { + actual_y = 0.5 * (end_h + start_h) * (in_h - 1); + } + CalculateCoordinate(actual_y, in_h, y_bottom + h, y_top + h, y_bottom_weight + h); + } + for (int w = 0; w < new_width; w++) { + if (new_width > 1) { + actual_x = start_w * (in_w - 1) + w * (end_w - start_w) * (in_w - 1) / (new_width - 1); + } else { + actual_x = 0.5 * (end_w + start_w) * (in_w - 1); + } + CalculateCoordinate(actual_x, in_w, x_left + w, x_right + w, x_left_weight + w); + } + } + return NNACL_OK; +} + +int InterpRow(const float *src_line, float *linear_output, int new_width, const float *x_left_weights, + const int32_t *x_lefts, const int32_t *x_rights, int in_c) { + int w; + for (w = 0; w < new_width; w++) { + int c = 0; +#if defined(ENABLE_AVX) + MS_FLOAT32X8 left_w_8 = MS_MOV256_F32(x_left_weights[w]); + MS_FLOAT32X8 right_w_8 = MS_MOV256_F32(1.0f - x_left_weights[w]); + for (; c <= in_c - C8NUM; c += C8NUM) { + MS_FLOAT32X8 left = MS_LD256_F32(src_line + x_lefts[w] * in_c + c); + MS_FLOAT32X8 right = MS_LD256_F32(src_line + x_rights[w] * in_c + c); + MS_FLOAT32X8 interp_value = MS_ADD256_F32(MS_MUL256_F32(left, left_w_8), MS_MUL256_F32(right, right_w_8)); + MS_ST256_F32(linear_output + w * in_c + c, interp_value); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_FLOAT32X4 left_w = MS_MOVQ_F32(x_left_weights[w]); + MS_FLOAT32X4 right_w = MS_MOVQ_F32(1.0f - x_left_weights[w]); + for (; c <= in_c - C4NUM; c += C4NUM) { + MS_FLOAT32X4 left = MS_LDQ_F32(src_line + x_lefts[w] * in_c + c); + MS_FLOAT32X4 right = MS_LDQ_F32(src_line + x_rights[w] * in_c + c); + MS_FLOAT32X4 interp_value = MS_ADDQ_F32(MS_MULQ_F32(left, left_w), MS_MULQ_F32(right, right_w)); + MS_STQ_F32(linear_output + w * in_c + c, interp_value); + } +#endif + int left_w_offset = x_lefts[w] * in_c; + int right_w_offset = x_rights[w] * in_c; + for (; c < in_c; c++) { + float left = src_line[left_w_offset + c]; + float right = src_line[right_w_offset + c]; + linear_output[w * in_c + c] = left * x_left_weights[w] + right * (1.0f - x_left_weights[w]); + } + } + return 0; +} + +int InterpCol(const float *bottom_line, const float *top_line, float *output, int new_width, float y_bottom_weight, + int in_c) { + int w; + for (w = 0; w < new_width; w++) { + int c = 0; +#if defined(ENABLE_AVX) + MS_FLOAT32X8 bottom_w_8 = MS_MOV256_F32(y_bottom_weight); + MS_FLOAT32X8 top_w_8 = MS_MOV256_F32(1.0f - y_bottom_weight); + for (; c <= in_c - C8NUM; c += C8NUM) { + MS_FLOAT32X8 bottom = MS_LD256_F32(bottom_line + w * in_c + c); + MS_FLOAT32X8 top = MS_LD256_F32(top_line + w * in_c + c); + MS_FLOAT32X8 interp_value = MS_ADD256_F32(MS_MUL256_F32(bottom, bottom_w_8), MS_MUL256_F32(top, top_w_8)); + MS_ST256_F32(output + w * in_c + c, interp_value); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_FLOAT32X4 bottom_w = MS_MOVQ_F32(y_bottom_weight); + MS_FLOAT32X4 top_w = MS_MOVQ_F32(1.0f - y_bottom_weight); + for (; c <= in_c - C4NUM; c += C4NUM) { + MS_FLOAT32X4 bottom = MS_LDQ_F32(bottom_line + w * in_c + c); + MS_FLOAT32X4 top = MS_LDQ_F32(top_line + w * in_c + c); + MS_FLOAT32X4 interp_value = MS_ADDQ_F32(MS_MULQ_F32(bottom, bottom_w), MS_MULQ_F32(top, top_w)); + MS_STQ_F32(output + w * in_c + c, interp_value); + } +#endif + for (; c < in_c; c++) { + float bottom = bottom_line[w * in_c + c]; + float top = top_line[w * in_c + c]; + output[w * in_c + c] = bottom * y_bottom_weight + top * (1.0f - y_bottom_weight); + } + } + return 0; +} + +void Bilinear(const float *input_data, float *output_data, const int32_t *input_shape, const int32_t *output_shape, + const int32_t *y_bottom, const int32_t *y_top, const int32_t *x_left, const int32_t *x_right, + const float *y_bottom_weight, const float *x_left_weight, float *line0, float *line1, const int h_begin, + const int h_end) { + int in_w = input_shape[2]; + int in_c = input_shape[3]; + int new_width = output_shape[2]; + int h_stride = new_width * in_c; + + bool cache_line_used[2] = {false, false}; + int cache_line_num[2] = {-1, -1}; + float *const cache_line_ptr[2] = {line0, line1}; + float *current_line_ptr[2] = {line0, line1}; + int current_line_num[2] = {-1, -1}; + + for (int h = h_begin; h < h_end; h++) { + current_line_num[0] = y_bottom[h]; + current_line_num[1] = y_top[h]; + + for (int i = 0; i < 2; i++) { + cache_line_used[i] = false; + } + // search if we cached + for (int j = 0; j < 2; j++) { + bool find = false; + for (int k = 0; k < 2; k++) { + if (current_line_num[j] == cache_line_num[k]) { + cache_line_used[k] = true; + current_line_ptr[j] = cache_line_ptr[k]; + find = true; + break; + } + } + + if (!find) { + const float *line = input_data + current_line_num[j] * in_w * in_c; + for (int k = 0; k < 2; k++) { + if (!cache_line_used[k]) { + cache_line_num[k] = current_line_num[j]; + cache_line_used[k] = true; + current_line_ptr[j] = cache_line_ptr[k]; + InterpRow(line, current_line_ptr[j], new_width, x_left_weight, x_left, x_right, in_c); + break; + } + } + } + } + // do col interp + InterpCol(current_line_ptr[0], current_line_ptr[1], output_data + h * h_stride, new_width, y_bottom_weight[h], + in_c); + } +} + +int ResizeBilinear(const float *input_data, float *output_data, const int32_t *input_shape, const int32_t *output_shape, + const int32_t *y_bottoms, const int32_t *y_tops, const int32_t *x_lefts, const int32_t *x_rights, + const float *y_bottom_weights, const float *x_left_weights, float *line0, float *line1, + const int h_begin, const int h_end) { + if (input_data == NULL || output_data == NULL || input_shape == NULL || output_shape == NULL || y_bottoms == NULL || + y_tops == NULL || x_lefts == NULL || x_rights == NULL || y_bottom_weights == NULL || x_left_weights == NULL) { + return NNACL_NULL_PTR; + } + + int in_b = input_shape[0]; + int in_h = input_shape[1]; + int in_w = input_shape[2]; + int in_c = input_shape[3]; + int new_height = output_shape[1]; + int new_width = output_shape[2]; + + for (int b = 0; b < in_b; b++) { + const float *input = input_data + b * in_h * in_w * in_c; + float *output = output_data + b * new_height * new_width * in_c; + Bilinear(input, output, input_shape, output_shape, y_bottoms, y_tops, x_lefts, x_rights, y_bottom_weights, + x_left_weights, line0, line1, h_begin, h_end); + } + return NNACL_OK; +} + +void BicubicInterpRow(const float *src, float *dst, const float *weights, const int32_t *lefts, int width, + int channel) { + for (int w = 0; w < width; w++) { + const float *weight = weights + 4 * w; + float *dst_w = dst + w * channel; + const float *src0_w = src + lefts[4 * w] * channel; + const float *src1_w = src + lefts[4 * w + 1] * channel; + const float *src2_w = src + lefts[4 * w + 2] * channel; + const float *src3_w = src + lefts[4 * w + 3] * channel; + int c = 0; +#if defined(ENABLE_AVX) + MS_FLOAT32X8 weight0_vec_8 = MS_MOV256_F32(weight[0]); + MS_FLOAT32X8 weight1_vec_8 = MS_MOV256_F32(weight[1]); + MS_FLOAT32X8 weight2_vec_8 = MS_MOV256_F32(weight[2]); + MS_FLOAT32X8 weight3_vec_8 = MS_MOV256_F32(weight[3]); + for (; c <= channel - C8NUM; c += C8NUM) { + MS_FLOAT32X8 src0_vec = MS_LD256_F32(src0_w + c); + MS_FLOAT32X8 src1_vec = MS_LD256_F32(src1_w + c); + MS_FLOAT32X8 src2_vec = MS_LD256_F32(src2_w + c); + MS_FLOAT32X8 src3_vec = MS_LD256_F32(src3_w + c); + MS_FLOAT32X8 dst0 = MS_MUL256_F32(src0_vec, weight0_vec_8); + MS_FLOAT32X8 dst1 = MS_MUL256_F32(src1_vec, weight1_vec_8); + MS_FLOAT32X8 dst2 = MS_MUL256_F32(src2_vec, weight2_vec_8); + MS_FLOAT32X8 dst3 = MS_MUL256_F32(src3_vec, weight3_vec_8); + MS_FLOAT32X8 interp_value = MS_ADD256_F32(dst3, MS_ADD256_F32(dst2, MS_ADD256_F32(dst1, dst0))); + MS_ST256_F32(dst_w + c, interp_value); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_FLOAT32X4 weight0_vec = MS_MOVQ_F32(weight[0]); + MS_FLOAT32X4 weight1_vec = MS_MOVQ_F32(weight[1]); + MS_FLOAT32X4 weight2_vec = MS_MOVQ_F32(weight[2]); + MS_FLOAT32X4 weight3_vec = MS_MOVQ_F32(weight[3]); + for (; c <= channel - C4NUM; c += C4NUM) { + MS_FLOAT32X4 src0_vec = MS_LDQ_F32(src0_w + c); + MS_FLOAT32X4 src1_vec = MS_LDQ_F32(src1_w + c); + MS_FLOAT32X4 src2_vec = MS_LDQ_F32(src2_w + c); + MS_FLOAT32X4 src3_vec = MS_LDQ_F32(src3_w + c); + MS_FLOAT32X4 dst0 = MS_MULQ_F32(src0_vec, weight0_vec); + MS_FLOAT32X4 dst1 = MS_MULQ_F32(src1_vec, weight1_vec); + MS_FLOAT32X4 dst2 = MS_MULQ_F32(src2_vec, weight2_vec); + MS_FLOAT32X4 dst3 = MS_MULQ_F32(src3_vec, weight3_vec); + MS_FLOAT32X4 interp_value = MS_ADDQ_F32(dst3, MS_ADDQ_F32(dst2, MS_ADDQ_F32(dst1, dst0))); + MS_STQ_F32(dst_w + c, interp_value); + } +#endif + for (; c < channel; c++) { + dst_w[c] = src0_w[c] * weight[0] + src1_w[c] * weight[1] + src2_w[c] * weight[2] + src3_w[c] * weight[3]; + } + } +} + +void BicubicInterpCol(const float *src, float *dst, const float *weights, int width, int channel) { + const float *src0 = src; + const float *src1 = src + width * channel; + const float *src2 = src + 2 * width * channel; + const float *src3 = src + 3 * width * channel; + for (int w = 0; w < width; w++) { + float *dst_w = dst + w * channel; + const float *src0_w = src0 + w * channel; + const float *src1_w = src1 + w * channel; + const float *src2_w = src2 + w * channel; + const float *src3_w = src3 + w * channel; + int c = 0; +#ifdef ENABLE_AVX + MS_FLOAT32X8 weight0_vec_8 = MS_MOV256_F32(weights[0]); + MS_FLOAT32X8 weight1_vec_8 = MS_MOV256_F32(weights[1]); + MS_FLOAT32X8 weight2_vec_8 = MS_MOV256_F32(weights[2]); + MS_FLOAT32X8 weight3_vec_8 = MS_MOV256_F32(weights[3]); + for (; c <= channel - C8NUM; c += C8NUM) { + MS_FLOAT32X8 src0_vec = MS_LD256_F32(src0_w + c); + MS_FLOAT32X8 src1_vec = MS_LD256_F32(src1_w + c); + MS_FLOAT32X8 src2_vec = MS_LD256_F32(src2_w + c); + MS_FLOAT32X8 src3_vec = MS_LD256_F32(src3_w + c); + MS_FLOAT32X8 dst1 = MS_MUL256_F32(src0_vec, weight0_vec_8); + MS_FLOAT32X8 dst2 = MS_MUL256_F32(src1_vec, weight1_vec_8); + MS_FLOAT32X8 dst3 = MS_MUL256_F32(src2_vec, weight2_vec_8); + MS_FLOAT32X8 dst4 = MS_MUL256_F32(src3_vec, weight3_vec_8); + MS_FLOAT32X8 interp_value = MS_ADD256_F32(dst4, MS_ADD256_F32(dst3, MS_ADD256_F32(dst1, dst2))); + MS_ST256_F32(dst_w + c, interp_value); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_FLOAT32X4 weight0_vec = MS_MOVQ_F32(weights[0]); + MS_FLOAT32X4 weight1_vec = MS_MOVQ_F32(weights[1]); + MS_FLOAT32X4 weight2_vec = MS_MOVQ_F32(weights[2]); + MS_FLOAT32X4 weight3_vec = MS_MOVQ_F32(weights[3]); + for (; c <= channel - C4NUM; c += C4NUM) { + MS_FLOAT32X4 src0_vec = MS_LDQ_F32(src0_w + c); + MS_FLOAT32X4 src1_vec = MS_LDQ_F32(src1_w + c); + MS_FLOAT32X4 src2_vec = MS_LDQ_F32(src2_w + c); + MS_FLOAT32X4 src3_vec = MS_LDQ_F32(src3_w + c); + MS_FLOAT32X4 dst1 = MS_MULQ_F32(src0_vec, weight0_vec); + MS_FLOAT32X4 dst2 = MS_MULQ_F32(src1_vec, weight1_vec); + MS_FLOAT32X4 dst3 = MS_MULQ_F32(src2_vec, weight2_vec); + MS_FLOAT32X4 dst4 = MS_MULQ_F32(src3_vec, weight3_vec); + MS_FLOAT32X4 interp_value = MS_ADDQ_F32(dst4, MS_ADDQ_F32(dst3, MS_ADDQ_F32(dst1, dst2))); + MS_STQ_F32(dst_w + c, interp_value); + } +#endif + for (; c < channel; c++) { + dst_w[c] = src0_w[c] * weights[0] + src1_w[c] * weights[1] + src2_w[c] * weights[2] + src3_w[c] * weights[3]; + } + } +} + +void Bicubic(const float *input_data, float *output_data, const int32_t *input_shape, const int32_t *output_shape, + const int32_t *y_tops, const int32_t *x_lefts, const float *y_weights, const float *x_weights, + float *line_buffer, const int h_begin, const int h_end) { + int in_w = input_shape[2]; + int in_c = input_shape[3]; + int new_width = output_shape[2]; + int h_stride = new_width * in_c; + + for (int h = h_begin; h < h_end; h++) { + for (int i = 0; i < 4; ++i) { + BicubicInterpRow(input_data + y_tops[4 * h + i] * in_w * in_c, line_buffer + i * h_stride, x_weights, x_lefts, + new_width, in_c); + } + BicubicInterpCol(line_buffer, output_data + h * h_stride, y_weights + 4 * h, new_width, in_c); + } +} + +int ResizeBicubic(const float *input_data, float *output_data, const int32_t *input_shape, const int32_t *output_shape, + const int32_t *y_tops, const int32_t *x_lefts, const float *y_weights, const float *x_weights, + float *line_buffer, const int h_begin, const int h_end) { + if (input_data == NULL || output_data == NULL || input_shape == NULL || output_shape == NULL || y_tops == NULL || + x_lefts == NULL || y_weights == NULL || x_weights == NULL) { + return NNACL_NULL_PTR; + } + int input_cube_per_batch = input_shape[1] * input_shape[2] * input_shape[3]; + int output_cube_per_batch = output_shape[1] * output_shape[2] * input_shape[3]; + for (int b = 0; b < input_shape[0]; b++) { + const float *input = input_data + b * input_cube_per_batch; + float *output = output_data + b * output_cube_per_batch; + Bicubic(input, output, input_shape, output_shape, y_tops, x_lefts, y_weights, x_weights, line_buffer, h_begin, + h_end); + } + return NNACL_OK; +} + +int RewriteExtrapolationValue(const float *input_data, float *output_data, const int32_t *box_idx, const float *boxes, + float extrapolation_value, const int32_t *input_shape, const int32_t *output_shape, + const int32_t *y_tops, const int h_begin, const int h_end) { + if (input_data == NULL || output_data == NULL || box_idx == NULL || input_shape == NULL || output_shape == NULL) { + return NNACL_NULL_PTR; + } + int batch = output_shape[0]; + int new_height = output_shape[1]; + int new_width = output_shape[2]; + int new_channel = output_shape[3]; + int input_h = input_shape[1]; + int input_w = input_shape[2]; + + for (int b = 0; b < batch; b++) { + float *output = output_data + b * new_height * new_width * new_channel; + const float *box = boxes + 4 * b; + float start_h = box[0]; + float end_h = box[2]; + float start_w = box[1]; + float end_w = box[3]; + float actual_y, actual_x; + for (int h = h_begin; h < h_end; ++h) { + if (new_height > 1) { + actual_y = start_h * (input_h - 1) + h * (end_h - start_h) * (input_h - 1) / (new_height - 1); + } else { + actual_y = 0.5 * (end_h + start_h) * (input_h - 1); + } + if (actual_y < 0 || actual_y > input_h - 1) { + float *output_data_base = output + h * new_width * new_channel; + for (int x = 0; x < new_width; ++x) { + for (int d = 0; d < new_channel; ++d) { + *output_data_base = extrapolation_value; + output_data_base++; + } + } + } + for (int w = 0; w < new_width; ++w) { + if (new_width > 1) { + actual_x = start_w * (input_w - 1) + w * (end_w - start_w) * (input_w - 1) / (new_width - 1); + } else { + actual_x = 0.5 * (end_w + start_w) * (input_w - 1); + } + if (actual_x < 0 || actual_x > input_w - 1) { + float *output_data_base = output + h * new_width * new_channel + w * new_channel; + for (int d = 0; d < new_channel; ++d) { + output_data_base[d] = extrapolation_value; + } + } + } + } + } + return NNACL_OK; +} + +int CropAndResizeBilinear(const float *input_data, float *output_data, const int32_t *box_idx, const float *boxes, + float extrapolation_value, const int32_t *input_shape, const int32_t *output_shape, + const int32_t *y_bottoms, const int32_t *y_tops, const int32_t *x_lefts, + const int32_t *x_rights, const float *y_bottom_weights, const float *x_left_weights, + float *line0, float *line1, const int h_begin, const int h_end) { + if (input_data == NULL || output_data == NULL || box_idx == NULL || input_shape == NULL || output_shape == NULL || + y_bottoms == NULL || y_tops == NULL || x_lefts == NULL || x_rights == NULL || y_bottom_weights == NULL || + x_left_weights == NULL) { + return NNACL_NULL_PTR; + } + int batch = output_shape[0]; + int new_height = output_shape[1]; + int new_width = output_shape[2]; + int new_channel = output_shape[3]; + int input_h = input_shape[1]; + int input_w = input_shape[2]; + + for (int b = 0; b < batch; b++) { + const float *cur_img = input_data + box_idx[b] * input_h * input_w * new_channel; + const int32_t *y_bottom = y_bottoms + b * new_height; + const int32_t *y_top = y_tops + b * new_height; + const float *y_bottom_weight = y_bottom_weights + b * new_height; + const int32_t *x_left = x_lefts + b * new_width; + const int32_t *x_right = x_rights + b * new_width; + const float *x_left_weight = x_left_weights + b * new_width; + float *output = output_data + b * new_height * new_width * new_channel; + + Bilinear(cur_img, output, input_shape, output_shape, y_bottom, y_top, x_left, x_right, y_bottom_weight, + x_left_weight, line0, line1, h_begin, h_end); + } + RewriteExtrapolationValue(input_data, output_data, box_idx, boxes, extrapolation_value, input_shape, output_shape, + y_tops, h_begin, h_end); + return NNACL_OK; +} + +int ResizeNearestNeighbor(const float *input_data, float *output_data, const int32_t *input_shape, + const int32_t *output_shape, CalculateOriginalCoordinate calculate, + int coordinate_transform_mode, int tid, int thread_num) { + if (thread_num == 0) { + return NNACL_PARAM_INVALID; + } + int c = input_shape[3]; + bool align_corners = coordinate_transform_mode == 1; + for (int batch = 0; batch < output_shape[0]; batch++) { + for (int y = tid; y < output_shape[1]; y += thread_num) { + float actual_y = calculate(y, input_shape[1], output_shape[1]); + int input_y; + if (align_corners) { + input_y = (int)(roundf(actual_y)); + } else { + input_y = (int)(floorf(actual_y)); + } + for (int x = 0; x < output_shape[2]; x++) { + float actual_x = calculate(x, input_shape[2], output_shape[2]); + int input_x; + if (align_corners) { + input_x = (int)(roundf(actual_x)); + } else { + input_x = (int)(floorf(actual_x)); + } + int in_offset = Offset(input_shape, batch, input_y, input_x, 0); + int out_offset = Offset(output_shape, batch, y, x, 0); + memcpy(output_data + out_offset, input_data + in_offset, c * sizeof(float)); + } + } + } + return NNACL_OK; +} + +float CalculateAsymmetric(int x_resized, int length_original, int length_resized) { + float scale = (float)(length_resized) / (float)(length_original); + return (float)(x_resized) / scale; +} + +float CalculateAlignCorners(int x_resized, int length_original, int length_resized) { + float scale = (float)(length_resized - 1) / (float)(length_original - 1); + return (float)(x_resized) / scale; +} + +float CalculateHalfPixel(int x_resized, int length_original, int length_resized) { + float scale = (float)(length_resized) / (float)(length_original); + float actual = (float)(x_resized + 0.5) / scale - 0.5; + return actual > 0 ? actual : 0; +} + +int CheckCropAndResizeBoxIdx(int32_t *box_idx, int32_t num_boxes, int32_t batch) { + for (int i = 0; i < num_boxes; i++) { + if (box_idx[i] < 0 || box_idx[i] >= batch) { + return NNACL_ERR; + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/resize_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/resize_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..506f9f07dd40bc718f294a252583b0c396bee945 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/resize_fp32.h @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_RESIZE_H_ +#define MINDSPORE_NNACL_FP32_RESIZE_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include +#include "nnacl/resize_parameter.h" +#include "nnacl/op_base.h" +#include "nnacl/crop_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +typedef float (*CalculateOriginalCoordinate)(int x_resized, int length_original, int length_resized); + +int PrepareResizeBilinear(const int32_t *input_shape, const int32_t *output_shape, + CalculateOriginalCoordinate calculate, int32_t *y_bottoms, int32_t *y_tops, int32_t *x_lefts, + int32_t *x_rights, float *y_bottom_weights, float *x_left_weights); + +int PrepareResizeBicubic(const int32_t *input_shape, const int32_t *output_shape, CalculateOriginalCoordinate calculate, + int32_t *y_tops, int32_t *x_lefts, float *y_weights, float *x_weights, float cubic_coeff); + +int ResizeBilinear(const float *input_data, float *output_data, const int32_t *input_shape, const int32_t *output_shape, + const int32_t *y_bottoms, const int32_t *y_tops, const int32_t *x_lefts, const int32_t *x_rights, + const float *y_bottom_weights, const float *x_left_weights, float *line0, float *line1, + const int h_begin, const int h_end); + +int ResizeBicubic(const float *input_data, float *output_data, const int32_t *input_shape, const int32_t *output_shape, + const int32_t *y_tops, const int32_t *x_lefts, const float *y_weights, const float *x_weights, + float *line_buffer, const int h_begin, const int h_end); + +int PrepareCropAndResizeBilinear(const int32_t *input_shape, const float *boxes, const int32_t *box_idx, + const int32_t *output_shape, int32_t *y_bottoms, int32_t *y_tops, int32_t *x_lefts, + int32_t *x_rights, float *y_bottom_weights, float *x_left_weights); + +int CropAndResizeBilinear(const float *input_data, float *output_data, const int32_t *box_idx, const float *boxes, + float extrapolation_value, const int32_t *input_shape, const int32_t *output_shape, + const int32_t *y_bottoms, const int32_t *y_tops, const int32_t *x_lefts, + const int32_t *x_rights, const float *y_bottom_weights, const float *x_left_weights, + float *line0, float *line1, const int h_begin, const int h_end); + +int ResizeNearestNeighbor(const float *input_data, float *output_data, const int32_t *input_shape, + const int32_t *output_shape, CalculateOriginalCoordinate calculate, + int coordinate_transform_mode, int tid, int thread_num); + +float CalculateAsymmetric(int x_resized, int length_original, int length_resized); + +float CalculateAlignCorners(int x_resized, int length_original, int length_resized); + +float CalculateHalfPixel(int x_resized, int length_original, int length_resized); + +int CheckCropAndResizeBoxIdx(int32_t *box_idx, int32_t num_boxes, int32_t batch); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_RESIZE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/reverse_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/reverse_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..a4338c7d4ff8bd1bd935d49f1654c30458dde56e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/reverse_fp32.c @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/reverse_fp32.h" +#include +#include "nnacl/errorcode.h" +#include "nnacl/nnacl_utils.h" + +int Reverse(const float *input, float *output, size_t elem_size, int32_t *index) { + for (size_t i = 0; i < elem_size; i++) { + NNACL_ASSERT(index[i] >= 0); + output[index[i]] = input[i]; + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/reverse_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/reverse_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..0394817c3783b95b61807a8b9c1890eb31f0156c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/reverse_fp32.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_REVERSE_FP32_H_ +#define NNACL_FP32_REVERSE_FP32_H_ + +#include "nnacl/op_base.h" +#include "nnacl/reverse_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int Reverse(const float *input, float *output, size_t elem_size, int32_t *index); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_REVERSE_FP32_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/reverse_sequence_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/reverse_sequence_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..2fe2ea3c628b144b4d40d1a8caa6b608990e6783 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/reverse_sequence_fp32.c @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/reverse_sequence_fp32.h" + +void ReverseSequence(const float *input0, const void *input1, float *output, ReverseSequenceParameter *para) { + (void)memcpy(output, input0, para->total_data_size_); + ComputeStrides(para->input_shape0_, para->input_stride_, para->ndim_); + ComputeStrides(para->output_shape_, para->output_stride_, para->ndim_); + for (int i = 0; i < para->outer_count_; ++i) { + const float *in = input0 + i * para->outer_stride_; + float *out = output + i * para->outer_stride_; + for (int batch = 0; batch < para->input_shape0_[para->batch_axis_]; batch++) { + const float *in_batch = in + batch * para->input_stride_[para->batch_axis_]; + float *out_batch = out + batch * para->output_stride_[para->batch_axis_]; + int32_t seq_length = para->is_seq_length_int32_ ? *((int32_t *)input1 + batch) : *((int64_t *)input1 + batch); + NNACL_CHECK_TRUE_RET_VOID(seq_length <= para->input_shape0_[para->seq_axis_]); + for (int n = 0; n < seq_length; ++n) { + const float *in_seq = in_batch + (seq_length - 1 - n) * para->input_stride_[para->seq_axis_]; + float *out_seq = out_batch + n * para->output_stride_[para->seq_axis_]; + for (int j = 0; j < para->inner_count_; ++j) { + (void)memcpy(out_seq + j * para->inner_stride_, in_seq + j * para->inner_stride_, para->copy_byte_size_); + } + } + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/reverse_sequence_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/reverse_sequence_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..eebda882e744429e88a72c6b1944a8a2dda438a9 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/reverse_sequence_fp32.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_REVERSE_SEQUENCE_H_ +#define MINDSPORE_NNACL_FP32_REVERSE_SEQUENCE_H_ + +#include +#include "nnacl/common_func.h" +#include "nnacl/op_base.h" +#include "nnacl/reverse_sequence_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void ReverseSequence(const float *input0, const void *input1, float *output, ReverseSequenceParameter *para); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_REVERSE_SEQUENCE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/rmsprop_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/rmsprop_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..0965b7873a0266b544a4f694b08199153f508ff5 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/rmsprop_fp32.c @@ -0,0 +1,147 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp32/rmsprop_fp32.h" +#ifdef ENABLE_SSE +#ifdef _MSC_VER +#include +#else +#include +#endif +#endif + +#ifdef ENABLE_AVX +#include +#endif + +#include +#include "nnacl/errorcode.h" + +int RMSPropUnuseCenterFp32(float *variable, float *mean_square, float *moment, float *gradients, float momentum, + float learning_rate, float decay, float epsilon, size_t start, size_t end) { + size_t c1 = start; +#ifdef ENABLE_AVX + size_t c8 = ((end - start) / C8NUM) * C8NUM; + float *variable_ptr = variable + start; + float *mean_square_ptr = mean_square + start; + float *gradients_ptr = gradients + start; + float *moment_ptr = moment + start; + + __m256 decay_r = _mm256_set1_ps(1.0 - decay); + __m256 momentum_r = _mm256_set1_ps(momentum); + __m256 lr_r = _mm256_set1_ps(learning_rate); + __m256 epsi_r = _mm256_set1_ps(epsilon); + __m256 gradient_r, mean_square_r, moment_r, variable_r, avx_r1, avx_r2; + for (; c1 < start + c8; c1 += C8NUM) { + gradient_r = _mm256_loadu_ps(gradients_ptr); + mean_square_r = _mm256_loadu_ps(mean_square_ptr); + avx_r1 = _mm256_sub_ps(_mm256_mul_ps(gradient_r, gradient_r), mean_square_r); + avx_r2 = _mm256_mul_ps(avx_r1, decay_r); + mean_square_r = _mm256_add_ps(mean_square_r, avx_r2); + _mm256_storeu_ps(mean_square_ptr, mean_square_r); + + avx_r1 = _mm256_add_ps(_mm256_sqrt_ps(mean_square_r), epsi_r); + avx_r2 = _mm256_div_ps(_mm256_mul_ps(gradient_r, lr_r), avx_r1); + + moment_r = _mm256_loadu_ps(moment_ptr); + avx_r1 = _mm256_add_ps(_mm256_mul_ps(moment_r, momentum_r), avx_r2); + _mm256_storeu_ps(moment_ptr, avx_r1); + + variable_r = _mm256_loadu_ps(variable_ptr); + variable_r = _mm256_sub_ps(variable_r, avx_r1); + _mm256_storeu_ps(variable_ptr, variable_r); + + gradients_ptr += C8NUM; + mean_square_ptr += C8NUM; + moment_ptr += C8NUM; + variable_ptr += C8NUM; + } +#endif + + for (; c1 < end; c1++) { + mean_square[c1] += (gradients[c1] * gradients[c1] - mean_square[c1]) * (1.0 - decay); + moment[c1] = moment[c1] * momentum + (gradients[c1] * learning_rate) / sqrt(mean_square[c1] + epsilon); + variable[c1] -= moment[c1]; + } + return NNACL_OK; +} + +int RMSPropUseCenterFp32(float *variable, float *mean_square, float *moment, float *gradients, float *mean_gradients, + float momentum, float learning_rate, float decay, float epsilon, size_t start, size_t end) { + size_t c1 = start; +#ifdef ENABLE_AVX + size_t c8 = ((end - start) / C8NUM) * C8NUM; + float *variable_ptr = variable + start; + float *mean_gradients_ptr = mean_gradients + start; + float *mean_square_ptr = mean_square + start; + float *moment_ptr = moment + start; + float *gradients_ptr = gradients + start; + + __m256 decay_r = _mm256_set1_ps(1.0 - decay); + __m256 momentum_r = _mm256_set1_ps(momentum); + __m256 lr_r = _mm256_set1_ps(learning_rate); + __m256 epsi_r = _mm256_set1_ps(epsilon); + __m256 grad_r, mean_grad_r, mean_square_r, moment_r, variable_r; + __m256 avx_r1, avx_r2; + for (; c1 < start + c8; c1 += C8NUM) { + grad_r = _mm256_loadu_ps(gradients_ptr); + mean_square_r = _mm256_loadu_ps(mean_square_ptr); + avx_r1 = _mm256_sub_ps(_mm256_mul_ps(grad_r, grad_r), mean_square_r); + avx_r2 = _mm256_mul_ps(avx_r1, decay_r); + mean_square_r = _mm256_add_ps(mean_square_r, avx_r2); + _mm256_storeu_ps(mean_square_ptr, mean_square_r); + + mean_grad_r = _mm256_loadu_ps(mean_gradients_ptr); + avx_r1 = _mm256_mul_ps(_mm256_sub_ps(grad_r, mean_grad_r), decay_r); + mean_grad_r = _mm256_add_ps(mean_grad_r, avx_r1); + _mm256_storeu_ps(mean_gradients_ptr, mean_grad_r); + + avx_r1 = _mm256_sub_ps(mean_square_r, _mm256_mul_ps(mean_grad_r, mean_grad_r)); + __m256 denom_r = _mm256_add_ps(avx_r1, epsi_r); + __m256 cmp_r = _mm256_cmp_ps(denom_r, _mm256_setzero_ps(), _CMP_GE_OS); + __m256 gt_zero_r = _mm256_blendv_ps(_mm256_set1_ps(1.0f), denom_r, cmp_r); + + avx_r1 = _mm256_mul_ps(grad_r, lr_r); + avx_r2 = _mm256_div_ps(avx_r1, _mm256_sqrt_ps(gt_zero_r)); + moment_r = _mm256_loadu_ps(moment_ptr); + avx_r1 = _mm256_mul_ps(moment_r, momentum_r); + avx_r1 = _mm256_add_ps(avx_r1, avx_r2); + moment_r = _mm256_blendv_ps(moment_r, avx_r1, cmp_r); + _mm256_storeu_ps(moment_ptr, moment_r); + + variable_r = _mm256_loadu_ps(variable_ptr); + avx_r1 = _mm256_sub_ps(variable_r, moment_r); + variable_r = _mm256_blendv_ps(variable_r, avx_r1, cmp_r); + _mm256_storeu_ps(variable_ptr, variable_r); + + variable_ptr += C8NUM; + mean_gradients_ptr += C8NUM; + mean_square_ptr += C8NUM; + gradients_ptr += C8NUM; + moment_ptr += C8NUM; + } +#endif + + for (; c1 < end; c1++) { + mean_square[c1] += (gradients[c1] * gradients[c1] - mean_square[c1]) * (1.0 - decay); + mean_gradients[c1] += (gradients[c1] - mean_gradients[c1]) * (1.0 - decay); + float denom = (mean_square[c1] - mean_gradients[c1] * mean_gradients[c1]) + epsilon; + if (denom > 0) { + moment[c1] = moment[c1] * momentum + (gradients[c1] * learning_rate) / sqrt(denom); + variable[c1] -= moment[c1]; + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/rmsprop_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/rmsprop_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..c30ce0167a6ffa9c15fdadda71e2aa7a65de87ce --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/rmsprop_fp32.h @@ -0,0 +1,34 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_RMDPROP_FP32_H +#define MINDSPORE_NNACL_RMDPROP_FP32_H + +#include +#include "nnacl/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +int RMSPropUnuseCenterFp32(float *variable, float *mean_square, float *moment, float *gradients, float momentum, + float learning_rate, float decay, float epsilon, size_t start, size_t end); + +int RMSPropUseCenterFp32(float *variable, float *mean_square, float *moment, float *gradients, float *mean_gradients, + float momentum, float learning_rate, float decay, float epsilon, size_t start, size_t end); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_RMDPROP_FP32_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/roi_pooling_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/roi_pooling_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..55d944f718a4544c64ddde5507bb21e820a1403f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/roi_pooling_fp32.c @@ -0,0 +1,97 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/roi_pooling_fp32.h" +#include +#include +#include +#include "nnacl/errorcode.h" +#include "nnacl/op_base.h" + +int ROIPooling(const float *in_ptr, float *out_ptr, const float *roi, float *max_c, int tid, + const ROIPoolingParameter *param) { + if (param->thread_num_ == 0) { + return NNACL_PARAM_INVALID; + } + int num_rois = param->output_n_; + int units = UP_DIV(num_rois, param->thread_num_); + int roi_st = tid * units; + int roi_end = MSMIN(num_rois, roi_st + units); + if (roi_st >= num_rois) { + return NNACL_OK; + } + int batch_size = param->input_n_; + int height_ = param->input_h_; + int width_ = param->input_w_; + int channels_ = param->input_c_; + float scale = param->scale_; + int pooled_height = param->pooledH_; + int pooled_width = param->pooledW_; + const int roi_stride = 5; + int roi_ind_st = roi_st * roi_stride; + for (int i = roi_st; i < roi_end; ++i) { + int roi_batch_ind = (int)roi[roi_ind_st]; // batch_index + if (roi_batch_ind >= batch_size) { + return NNACL_ERRCODE_INDEX_OUT_OF_RANGE; + } + int roi_start_h = (int)roundf(roi[roi_ind_st + 1] * scale); // top-left x1 + int roi_start_w = (int)roundf(roi[roi_ind_st + 2] * scale); // top-left y1 + int roi_end_h = (int)roundf(roi[roi_ind_st + 3] * scale); // bottom-right x2 + int roi_end_w = (int)roundf(roi[roi_ind_st + 4] * scale); // bottom-fight y2 + + int roi_height = MSMAX(roi_end_h - roi_start_h + 1, 1); + int roi_width = MSMAX(roi_end_w - roi_start_w + 1, 1); + + float bin_size_h = (float)roi_height / (float)pooled_height; + float bin_size_w = (float)roi_width / (float)pooled_width; + const float *batch_data = in_ptr + param->in_strides_[kNHWC_N] * roi_batch_ind; + + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + int hstart = (int)floorf(ph * bin_size_h); // block xi_1 + int wstart = (int)floorf(pw * bin_size_w); // block yi_1 + int hend = (int)ceilf((ph + 1) * bin_size_h); // block xi_2 + int wend = (int)ceilf((pw + 1) * bin_size_w); // block yi_2 + + hstart = MSMIN(MSMAX(hstart + roi_start_h, 0), height_); + hend = MSMIN(MSMAX(hend + roi_start_h, 0), height_); + wstart = MSMIN(MSMAX(wstart + roi_start_w, 0), width_); + wend = MSMIN(MSMAX(wend + roi_start_w, 0), width_); + bool is_empty = (hend <= hstart) || (wend <= wstart); + for (int j = 0; j < channels_; ++j) { + max_c[j] = is_empty ? 0 : -FLT_MAX; + } + int pooled_index = i * param->out_strides_[0] + ph * param->out_strides_[1] + pw * param->out_strides_[2]; + int bd_index = hstart * param->in_strides_[1]; + for (int h = hstart; h < hend; ++h) { + int wi = bd_index + wstart * param->in_strides_[2]; + for (int w = wstart; w < wend; ++w) { + for (int c = 0; c < channels_; ++c) { + max_c[c] = MSMAX(batch_data[wi + c], max_c[c]); + } + wi += param->in_strides_[2]; + } // in_w end; + bd_index += param->in_strides_[1]; + } // in_h end + for (int j = 0; j < channels_; ++j) { + out_ptr[pooled_index + j] = max_c[j]; + } + } + } + roi_ind_st += roi_stride; + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/roi_pooling_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/roi_pooling_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..d7d1c79e412376f842399796a993948bfd2e17ef --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/roi_pooling_fp32.h @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_ROI_POOLING_H_ +#define MINDSPORE_NNACL_FP32_ROI_POOLING_H_ + +#include "nnacl/op_base.h" + +typedef struct ROIPoolingParameter { + // primitive parameter + OpParameter op_parameter_; + int pooledW_; + int pooledH_; + float scale_; + + // shape correlative + int in_strides_[DIMENSION_4D]; + int out_strides_[DIMENSION_4D]; + int ndim_; + int input_w_; + int input_h_; + int input_n_; + int input_c_; + int output_w_; + int output_h_; + int output_n_; + int output_c_; + + // other parameter + int thread_num_; +} ROIPoolingParameter; + +#ifdef __cplusplus +extern "C" { +#endif +int ROIPooling(const float *in_ptr, float *out_ptr, const float *roi, float *max_c, int tid, + const ROIPoolingParameter *param); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_ROI_POOLING_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/scale_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/scale_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..1f1d13e7da327243e9ae537c230715c21863139e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/scale_fp32.c @@ -0,0 +1,304 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/scale_fp32.h" +#include "nnacl/intrinsics/ms_simd_instructions.h" + +void ScaleInner(const float *in_data, float *out_data, const float *scale, const float *offset, int outer_start, + int outer_end, int axis_size, int inner_size) { + for (int out = outer_start; out < outer_end; out++) { + int out_offset = out * axis_size * inner_size; + for (int i = 0; i < axis_size; i++) { + int axis_offset = out_offset + i * inner_size; + int in_index = 0; +#ifdef ENABLE_AVX + MS_FLOAT32X8 scale_8 = MS_MOV256_F32(scale[i]); + MS_FLOAT32X8 offset_8 = MS_MOV256_F32(offset[i]); + for (; in_index <= inner_size - C8NUM; in_index += C8NUM) { + int in_offset = axis_offset + in_index; + MS_FLOAT32X8 data = MS_LD256_F32(in_data + in_offset); + MS_FLOAT32X8 result = MS_MLA256_F32(offset_8, data, scale_8); + MS_ST256_F32(out_data + in_offset, result); + } +#endif +#if defined(ENABLE_ARM64) || defined(ENABLE_SSE) + MS_FLOAT32X4 scale_4 = MS_MOVQ_F32(scale[i]); + MS_FLOAT32X4 offset_4 = MS_MOVQ_F32(offset[i]); + for (; in_index <= inner_size - C4NUM; in_index += C4NUM) { + int in_offset = axis_offset + in_index; + MS_FLOAT32X4 data = MS_LDQ_F32(in_data + in_offset); + MS_FLOAT32X4 result = MS_MLAQ_F32(offset_4, data, scale_4); + MS_STQ_F32(out_data + in_offset, result); + } +#endif + for (; in_index < inner_size; in_index++) { + int in_offset = axis_offset + in_index; + out_data[in_offset] = in_data[in_offset] * scale[i] + offset[i]; + } + } + } +} + +void ScaleAxis(const float *in_data, float *out_data, const float *scale, const float *offset, int outer_start, + int outer_end, int axis_size) { + for (int out = outer_start; out < outer_end; out++) { + int out_offset = out * axis_size; + int index = 0; +#if defined(ENABLE_AVX) + for (; index <= axis_size - C8NUM; index += C8NUM) { + int in_offset = out_offset + index; + MS_FLOAT32X8 scale_8 = MS_LD256_F32(scale + index); + MS_FLOAT32X8 offset_8 = MS_LD256_F32(offset + index); + MS_FLOAT32X8 data = MS_LD256_F32(in_data + in_offset); + MS_FLOAT32X8 result = MS_MLA256_F32(offset_8, data, scale_8); + MS_ST256_F32(out_data + in_offset, result); + } +#endif +#if defined(ENABLE_ARM64) || defined(ENABLE_SSE) + for (; index <= axis_size - C4NUM; index += C4NUM) { + MS_FLOAT32X4 scale_4 = MS_LDQ_F32(scale + index); + MS_FLOAT32X4 offset_4 = MS_LDQ_F32(offset + index); + int in_offset = out_offset + index; + MS_FLOAT32X4 data = MS_LDQ_F32(in_data + in_offset); + MS_FLOAT32X4 result = MS_MLAQ_F32(offset_4, data, scale_4); + MS_STQ_F32(out_data + in_offset, result); + } +#endif + for (; index < axis_size; index++) { + int in_offset = out_offset + index; + out_data[in_offset] = in_data[in_offset] * scale[index] + offset[index]; + } + } +} + +void DoScale(const float *in_data, float *out_data, const float *scale, const float *offset, int task_id, + const ScaleStruct *scale_param) { + NNACL_CHECK_ZERO_RETURN(scale_param->base_.thread_nr_); + int outer_step = UP_DIV(scale_param->outer_size_, scale_param->base_.thread_nr_); + int outer_start = task_id * outer_step; + int outer_end = MSMIN(outer_start + outer_step, scale_param->outer_size_); + + if (scale_param->inner_size_ == 1) { + ScaleAxis(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_); + } else { + ScaleInner(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_, + scale_param->inner_size_); + } +} + +void ScaleInnerRelu(const float *in_data, float *out_data, const float *scale, const float *offset, int outer_start, + int outer_end, int axis_size, int inner_size) { +#ifdef ENABLE_AVX + MS_FLOAT32X8 zeros_8 = {0, 0, 0, 0, 0, 0, 0, 0}; +#endif +#if defined(ENABLE_ARM64) || defined(ENABLE_SSE) + MS_FLOAT32X4 zeros = {0, 0, 0, 0}; +#endif + for (int out = outer_start; out < outer_end; out++) { + int out_offset = out * axis_size * inner_size; + for (int i = 0; i < axis_size; i++) { + int axis_offset = out_offset + i * inner_size; + int in_index = 0; +#ifdef ENABLE_AVX + MS_FLOAT32X8 scale_8 = MS_MOV256_F32(scale[i]); + MS_FLOAT32X8 offset_8 = MS_MOV256_F32(offset[i]); + for (; in_index <= inner_size - C8NUM; in_index += C8NUM) { + int in_offset = axis_offset + in_index; + MS_FLOAT32X8 data = MS_LD256_F32(in_data + in_offset); + MS_FLOAT32X8 tmp = MS_MLA256_F32(offset_8, data, scale_8); + MS_FLOAT32X8 result = MS_MAX256_F32(tmp, zeros_8); + MS_ST256_F32(out_data + in_offset, result); + } +#endif +#if defined(ENABLE_ARM64) || defined(ENABLE_SSE) + MS_FLOAT32X4 scale_4 = MS_MOVQ_F32(scale[i]); + MS_FLOAT32X4 offset_4 = MS_MOVQ_F32(offset[i]); + for (; in_index <= inner_size - C4NUM; in_index += C4NUM) { + int in_offset = axis_offset + in_index; + MS_FLOAT32X4 data = MS_LDQ_F32(in_data + in_offset); + MS_FLOAT32X4 tmp = MS_MLAQ_F32(offset_4, data, scale_4); + MS_FLOAT32X4 result = MS_MAXQ_F32(tmp, zeros); + MS_STQ_F32(out_data + in_offset, result); + } +#endif + for (; in_index < inner_size; in_index++) { + int in_offset = axis_offset + in_index; + float tmp = in_data[in_offset] * scale[i] + offset[i]; + out_data[in_offset] = tmp > 0.0f ? tmp : 0.0f; + } + } + } +} + +void ScaleAxisRelu(const float *in_data, float *out_data, const float *scale, const float *offset, int outer_start, + int outer_end, int axis_size) { +#ifdef ENABLE_AVX + MS_FLOAT32X8 zeros_8 = {0, 0, 0, 0, 0, 0, 0, 0}; +#endif +#if defined(ENABLE_ARM64) || defined(ENABLE_SSE) + MS_FLOAT32X4 zeros = {0, 0, 0, 0}; +#endif + for (int out = outer_start; out < outer_end; out++) { + int out_offset = out * axis_size; + int index = 0; +#ifdef ENABLE_AVX + for (; index <= axis_size - C8NUM; index += C8NUM) { + int in_offset = out_offset + index; + MS_FLOAT32X8 scale_8 = MS_LD256_F32(scale + index); + MS_FLOAT32X8 offset_8 = MS_LD256_F32(offset + index); + MS_FLOAT32X8 data = MS_LD256_F32(in_data + in_offset); + MS_FLOAT32X8 tmp = MS_MLA256_F32(offset_8, data, scale_8); + MS_FLOAT32X8 result = MS_MAX256_F32(tmp, zeros_8); + MS_ST256_F32(out_data + in_offset, result); + } +#endif +#if defined(ENABLE_ARM64) || defined(ENABLE_SSE) + for (; index <= axis_size - C4NUM; index += C4NUM) { + int in_offset = out_offset + index; + MS_FLOAT32X4 data = MS_LDQ_F32(in_data + in_offset); + MS_FLOAT32X4 scale_4 = MS_LDQ_F32(scale + index); + MS_FLOAT32X4 offset_4 = MS_LDQ_F32(offset + index); + MS_FLOAT32X4 tmp = MS_MLAQ_F32(offset_4, data, scale_4); + MS_FLOAT32X4 result = MS_MAXQ_F32(tmp, zeros); + MS_STQ_F32(out_data + in_offset, result); + } +#endif + for (; index < axis_size; index++) { + int in_offset = out_offset + index; + float tmp = in_data[in_offset] * scale[index] + offset[index]; + out_data[in_offset] = tmp > 0.0f ? tmp : 0.0f; + } + } +} + +void DoScaleRelu(const float *in_data, float *out_data, const float *scale, const float *offset, int task_id, + const ScaleStruct *scale_param) { + NNACL_CHECK_ZERO_RETURN(scale_param->base_.thread_nr_); + int outer_step = UP_DIV(scale_param->outer_size_, scale_param->base_.thread_nr_); + int outer_start = task_id * outer_step; + int outer_end = MSMIN(outer_start + outer_step, scale_param->outer_size_); + + if (scale_param->inner_size_ == 1) { + ScaleAxisRelu(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_); + } else { + ScaleInnerRelu(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_, + scale_param->inner_size_); + } +} + +void ScaleInnerRelu6(const float *in_data, float *out_data, const float *scale, const float *offset, int outer_start, + int outer_end, int axis_size, int inner_size) { +#ifdef ENABLE_AVX + MS_FLOAT32X8 zeros_8 = {0, 0, 0, 0, 0, 0, 0, 0}; + MS_FLOAT32X8 bounds_8 = {6, 6, 6, 6, 6, 6, 6, 6}; +#endif +#if defined(ENABLE_ARM64) || defined(ENABLE_SSE) + MS_FLOAT32X4 zeros = {0, 0, 0, 0}; + MS_FLOAT32X4 bounds = {6, 6, 6, 6}; +#endif + for (int out = outer_start; out < outer_end; out++) { + int out_offset = out * axis_size * inner_size; + for (int i = 0; i < axis_size; i++) { + int axis_offset = out_offset + i * inner_size; + int in_index = 0; +#if defined(ENABLE_AVX) + MS_FLOAT32X8 scale_8 = MS_MOV256_F32(scale[i]); + MS_FLOAT32X8 offset_8 = MS_MOV256_F32(offset[i]); + for (; in_index <= inner_size - C8NUM; in_index += C8NUM) { + int in_offset = axis_offset + in_index; + MS_FLOAT32X8 data = MS_LD256_F32(in_data + in_offset); + MS_FLOAT32X8 tmp = MS_MLA256_F32(offset_8, data, scale_8); + MS_FLOAT32X8 result = MS_MIN256_F32(MS_MAX256_F32(tmp, zeros_8), bounds_8); + MS_ST256_F32(out_data + in_offset, result); + } +#endif +#if defined(ENABLE_ARM64) || defined(ENABLE_SSE) + for (; in_index < inner_size - C4NUM; in_index += C4NUM) { + int in_offset = axis_offset + in_index; + MS_FLOAT32X4 data = MS_LDQ_F32(in_data + in_offset); + MS_FLOAT32X4 scale_4 = MS_MOVQ_F32(scale[i]); + MS_FLOAT32X4 offset_4 = MS_MOVQ_F32(offset[i]); + MS_FLOAT32X4 tmp = MS_MLAQ_F32(offset_4, data, scale_4); + MS_FLOAT32X4 result = MS_MINQ_F32(MS_MAXQ_F32(tmp, zeros), bounds); + MS_STQ_F32(out_data + in_offset, result); + } +#endif + for (; in_index < inner_size; in_index++) { + int in_offset = axis_offset + in_index; + float tmp = in_data[in_offset] * scale[i] + offset[i]; + out_data[in_offset] = MSMIN(MSMAX(tmp, 0.0f), 6.0f); + } + } + } +} + +void ScaleAxisRelu6(const float *in_data, float *out_data, const float *scale, const float *offset, int outer_start, + int outer_end, int axis_size) { +#ifdef ENABLE_AVX + MS_FLOAT32X8 zeros_8 = {0, 0, 0, 0, 0, 0, 0, 0}; + MS_FLOAT32X8 bounds_8 = {6, 6, 6, 6, 6, 6, 6, 6}; +#endif +#if defined(ENABLE_ARM64) || defined(ENABLE_SSE) + MS_FLOAT32X4 zeros = {0, 0, 0, 0}; + MS_FLOAT32X4 bounds = {6, 6, 6, 6}; +#endif + for (int out = outer_start; out < outer_end; out++) { + int out_offset = out * axis_size; + int index = 0; +#ifdef ENABLE_AVX + for (; index <= axis_size - C8NUM; index += C8NUM) { + int in_offset = out_offset + index; + MS_FLOAT32X8 data = MS_LD256_F32(in_data + in_offset); + MS_FLOAT32X8 scale_8 = MS_LD256_F32(scale + index); + MS_FLOAT32X8 offset_8 = MS_LD256_F32(offset + index); + MS_FLOAT32X8 tmp = MS_MLA256_F32(offset_8, data, scale_8); + MS_FLOAT32X8 result = MS_MIN256_F32(MS_MAX256_F32(tmp, zeros_8), bounds_8); + MS_ST256_F32(out_data + in_offset, result); + } +#endif +#if defined(ENABLE_ARM64) || defined(ENABLE_SSE) + for (; index <= axis_size - C4NUM; index += C4NUM) { + int in_offset = out_offset + index; + MS_FLOAT32X4 data = MS_LDQ_F32(in_data + in_offset); + MS_FLOAT32X4 scale_4 = MS_LDQ_F32(scale + index); + MS_FLOAT32X4 offset_4 = MS_LDQ_F32(offset + index); + MS_FLOAT32X4 tmp = MS_MLAQ_F32(offset_4, data, scale_4); + MS_FLOAT32X4 result = MS_MINQ_F32(MS_MAXQ_F32(tmp, zeros), bounds); + MS_STQ_F32(out_data + in_offset, result); + } +#endif + for (; index < axis_size; index++) { + int in_offset = out_offset + index; + float tmp = in_data[in_offset] * scale[index] + offset[index]; + out_data[in_offset] = MSMIN(MSMAX(tmp, 0.0f), 6.0f); + } + } +} + +void DoScaleRelu6(const float *in_data, float *out_data, const float *scale, const float *offset, int task_id, + const ScaleStruct *scale_param) { + NNACL_CHECK_ZERO_RETURN(scale_param->base_.thread_nr_); + int outer_step = UP_DIV(scale_param->outer_size_, scale_param->base_.thread_nr_); + int outer_start = task_id * outer_step; + int outer_end = MSMIN(outer_start + outer_step, scale_param->outer_size_); + + if (scale_param->inner_size_ == 1) { + ScaleAxisRelu6(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_); + } else { + ScaleInnerRelu6(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_, + scale_param->inner_size_); + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/scale_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/scale_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..3b1eddf69227bd7aa4acde7aeeeee57c846aa92a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/scale_fp32.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_SCALE_FP32_H_ +#define NNACL_FP32_SCALE_FP32_H_ + +#include "nnacl/op_base.h" +#include "nnacl/kernel/scale.h" +#ifdef __cplusplus +extern "C" { +#endif +void DoScale(const float *in_data, float *out_data, const float *scale, const float *offset, int task_id, + const ScaleStruct *scale_param); +void DoScaleRelu(const float *in_data, float *out_data, const float *scale, const float *offset, int task_id, + const ScaleStruct *scale_param); +void DoScaleRelu6(const float *in_data, float *out_data, const float *scale, const float *offset, int task_id, + const ScaleStruct *scale_param); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_SCALE_FP32_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/softmax_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/softmax_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..16dbb77b95482bd0cf4bc652da0193b5a498da7c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/softmax_fp32.c @@ -0,0 +1,125 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp32/softmax_fp32.h" +#include +#include +#include "nnacl/fp32/exp_fp32.h" +#include "nnacl/errorcode.h" +#include "nnacl/softmax_fp32_simd.h" + +void SoftmaxNorm(const float *src, float *dst, int batch, int channel) { + int cur_batch_offset = 0; + for (int i = 0; i < batch; i++, cur_batch_offset += channel) { + int index = 0; + float max = -FLT_MAX; + + SIMD_RUN_NO_SCALAR(SoftmaxNormGetMax, index, src, cur_batch_offset, &max, channel); + for (; index < channel; index++) { + float input = src[cur_batch_offset + index]; + if (input > max) { + max = input; + } + } + + index = 0; + SIMD_RUN_NO_SCALAR(SoftmaxNormCalcNorm, index, src, dst, cur_batch_offset, max, channel); + for (; index < channel; index++) { + int offset = cur_batch_offset + index; + dst[offset] = src[offset] - max; + } + } +} + +int SoftmaxLastAxis(const float *src, float *dst, int batch, int channel) { + int cur_batch_offset = 0; + for (int i = 0; i < batch; i++, cur_batch_offset += channel) { + int index = 0; + + // get channel's max value + float max = -FLT_MAX; + SIMD_RUN_NO_SCALAR(SoftmaxNormGetMax, index, src, cur_batch_offset, &max, channel); + for (; index < channel; index++) { + float input = src[cur_batch_offset + index]; + if (input > max) { + max = input; + } + } + + // get channel's exp sum value + float exp_sum = 0.0f; + index = 0; + SIMD_RUN_NO_SCALAR(SoftmaxLastAxisGetExpSum, index, src, dst, cur_batch_offset, max, &exp_sum, channel); + for (; index < channel; index++) { + int offset = cur_batch_offset + index; + float exp_out = simd_exp32_f32(src[offset] - max); + exp_sum += exp_out; + dst[offset] = exp_out; + } + + // get result + NNACL_CHECK_TRUE_RET(exp_sum != 0, NNACL_ERR); + exp_sum = 1.0f / exp_sum; + index = 0; + SIMD_RUN_NO_SCALAR(SoftmaxLastAxisGetResult, index, dst, dst, cur_batch_offset, exp_sum, channel); + for (; index < channel; index++) { + dst[cur_batch_offset + index] = dst[cur_batch_offset + index] * exp_sum; + } + } + return NNACL_OK; +} + +// output = exp(input) / reduce_sum(exp(input), axis) +void Softmax(const float *input_ptr, float *output_ptr, float *sum_data, int axis, int n_dim, + const int32_t *input_shape) { + int inner_size = 1; + int outter_size = 1; + + for (int i = 0; i < axis; i++) { + outter_size *= input_shape[i]; + } + for (int i = axis + 1; i < n_dim; i++) { + inner_size *= input_shape[i]; + } + for (int i = 0; i < outter_size; i++) { + int outter_offset = i * input_shape[axis] * inner_size; + int sum_outter_offset = i * inner_size; + for (int k = 0; k < inner_size; k++) { + int inner_offset = outter_offset + k; + float max_data = input_ptr[inner_offset]; + sum_data[k + sum_outter_offset] = 0; + for (int j = 0; j < input_shape[axis]; j++) { + int axis_offset = inner_offset + j * inner_size; + max_data = max_data > input_ptr[axis_offset] ? max_data : input_ptr[axis_offset]; + } + for (int j = 0; j < input_shape[axis]; j++) { + int axis_offset = inner_offset + j * inner_size; + output_ptr[axis_offset] = expf(input_ptr[axis_offset] - max_data); + sum_data[k + sum_outter_offset] += output_ptr[axis_offset]; + } + } + } + for (int i = 0; i < outter_size; i++) { + int outter_offset = i * input_shape[axis] * inner_size; + int sum_outter_offset = i * inner_size; + for (int j = 0; j < input_shape[axis]; j++) { + int axis_offset = outter_offset + j * inner_size; + for (int k = 0; k < inner_size; k++) { + int inner_offset = axis_offset + k; + output_ptr[inner_offset] = output_ptr[inner_offset] / sum_data[k + sum_outter_offset]; + } + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/softmax_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/softmax_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..2cf0d0a37101d5056f1a32e1d5646236e39860f7 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/softmax_fp32.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_SOFTMAX_H_ +#define MINDSPORE_NNACL_FP32_SOFTMAX_H_ + +#include "nnacl/op_base.h" +#include "nnacl/softmax_parameter.h" +#ifdef __cplusplus +extern "C" { +#endif +void Softmax(const float *input_ptr, float *output_ptr, float *sum_data, int axis, int n_dim, + const int32_t *input_shape); +int SoftmaxLastAxis(const float *src, float *dst, int batch, int channel); +void SoftmaxNorm(const float *src, float *dst, int batch, int channel); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_SOFTMAX_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/softmax_fp32_simd.h.in b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/softmax_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..e0e4868034c7131192eaeb18238af4c01d7d7924 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/softmax_fp32_simd.h.in @@ -0,0 +1,80 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_SOFTMAX_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_SOFTMAX_@SIMD_INSTRUCTION@_H_ + +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int64_t SoftmaxNormGetMax@SIMD_INSTRUCTION@(int64_t index, const float *src, int cur_batch_offset, + float *max, int channel) { + if (channel >= BLOCK_NUM * BLOCK_NUM) { + SIMD_F32 max_val = SIMD_MOV_F32(*max); + for (int block_max_size = channel - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + max_val = SIMD_MAX_F32(max_val, SIMD_LD_F32(src + cur_batch_offset + index)); + } + *max = SIMD_GET_MAX_F32(max_val); + } + return index; +} + +static inline int64_t SoftmaxNormCalcNorm@SIMD_INSTRUCTION@(int64_t index, const float *src, float *dst, + int cur_batch_offset, float max, int channel) { + for (int block_max_size = channel - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 output = SIMD_SUB_F32(SIMD_LD_F32(src + cur_batch_offset + index), SIMD_MOV_F32(max)); + SIMD_ST_F32(dst + cur_batch_offset + index, output); + } + return index; +} + +static inline int64_t SoftmaxLastAxisGetExpSum@SIMD_INSTRUCTION@(int64_t index, const float *src, float *dst, + int cur_batch_offset, float max, float *exp_sum, int channel) { +#ifndef _WIN32 + SIMD_F32 sum_val = SIMD_SET0_F32; + for (int block_max_size = channel - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 input = SIMD_LD_F32(src + cur_batch_offset + index); + SIMD_F32 output = SIMD_SUB_F32(input, SIMD_MOV_F32(max)); + SIMD_F32 exp_out = SIMD_EXP_F32(output); + sum_val = SIMD_ADD_F32(sum_val, exp_out); + SIMD_ST_F32(dst + cur_batch_offset + index, exp_out); + } + *exp_sum += SIMD_GET_SUM_F32(sum_val); +#endif + return index; +} + +static inline int64_t SoftmaxLastAxisGetResult@SIMD_INSTRUCTION@(int64_t index, const float *src, float *dst, + int cur_batch_offset, float exp_sum, int channel) { + SIMD_F32 exp_sum_val = SIMD_MOV_F32(exp_sum); + for (int block_max_size = channel - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 input = SIMD_LD_F32(src + cur_batch_offset + index); + SIMD_F32 output = SIMD_MUL_F32(input, exp_sum_val); + SIMD_ST_F32(dst + cur_batch_offset + index, output); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +}; +#endif +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/softmax_grad_fusion_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/softmax_grad_fusion_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..031772b3447305c7008aba6bfaf2df04cc0aa064 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/softmax_grad_fusion_fp32.c @@ -0,0 +1,36 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/softmax_grad_fusion_fp32.h" +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/softmax_grad_fusion_fp32_simd.h" +#include "nnacl/sub_fp32_simd.h" + +void SoftmaxGradFusionOpt(const float *a, const float *b, float *dst, int64_t m) { + float result = 0; + + int64_t i = 0; + SIMD_RUN_NO_SCALAR(SoftmaxGradFusionOpt, i, a, b, &result, m); + for (; i < m; i++) { + result += a[i] * b[i]; + } + + i = 0; + SIMD_RUN_NO_SCALAR(ElementOptSubMul, i, a, b, result, dst, m); + for (; i < m; i++) { + dst[i] = a[i] * (b[i] - result); + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/softmax_grad_fusion_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/softmax_grad_fusion_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..850385e97c48ded732878d9d14d252cf8275ddc5 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/softmax_grad_fusion_fp32.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_SOFTMAX_GRAD_FUSION_H_ +#define MINDSPORE_NNACL_FP32_SOFTMAX_GRAD_FUSION_H_ + +#include "nnacl/op_base.h" +#include "nnacl/errorcode.h" + +#ifdef __cplusplus +extern "C" { +#endif +void SoftmaxGradFusionOpt(const float *a, const float *b, float *dst, int64_t m); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_SOFTMAX_GRAD_FUSION_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/cudnn_utils.cc b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/softmax_grad_fusion_fp32_simd.h.in similarity index 32% rename from mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/cudnn_utils.cc rename to mindspore-lite/ops/kernel/cpu/nnacl/fp32/softmax_grad_fusion_fp32_simd.h.in index 87c2f880fe3dd4f1300e9b7ecb3e51d054044c5b..0ba6606ec1a94b7d128c38f169b4613635825653 100644 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/cudnn_utils.cc +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/softmax_grad_fusion_fp32_simd.h.in @@ -13,29 +13,43 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#ifndef MINDSPORE_NNACL_FP32_SOFTMAX_GRAD_FUSION_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_SOFTMAX_GRAD_FUSION_@SIMD_INSTRUCTION@_H_ -#include "src/extendrt/delegate/tensorrt/cuda_impl/cudnn_utils.h" -#include +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" -namespace mindspore::lite { -cudnnDataType_t ConvertCudnnDataType(nvinfer1::DataType trt_datatype) { - std::unordered_map data_types = {{nvinfer1::DataType::kFLOAT, CUDNN_DATA_FLOAT}, - {nvinfer1::DataType::kHALF, CUDNN_DATA_HALF}, - {nvinfer1::DataType::kINT32, CUDNN_DATA_INT32}, - {nvinfer1::DataType::kINT8, CUDNN_DATA_INT8}}; - if (data_types.find(trt_datatype) != data_types.end()) { - return data_types[trt_datatype]; - } else { - MS_LOG(ERROR) << "invalid datatype for cudnn: " << static_cast(trt_datatype); +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int64_t SoftmaxGradFusionOpt@SIMD_INSTRUCTION@(int64_t index, const float *a, const float *b, + float *out, int64_t size) { + SIMD_F32 result_vec = SIMD_MOV_F32(0.0f); + for (int64_t block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 a_vec = SIMD_LD_F32(a + index); + SIMD_F32 b_vec = SIMD_LD_F32(b + index); + result_vec = SIMD_FMADD_F32(a_vec, b_vec, result_vec); + } + *out += SIMD_GET_SUM_F32(result_vec); + + return index; +} + +static inline int64_t ElementOptSubMul@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float sum, + float *out, int size) { + SIMD_F32 vin1_opt_ = SIMD_MOV_F32(sum); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_ST_F32(out + index, SIMD_MUL_F32(vin0, SIMD_SUB_F32(vin1, vin1_opt_))); } - return CUDNN_DATA_FLOAT; + return index; } -int CudnnActivation(cudnnHandle_t handle, cudnnActivationDescriptor_t activation_desc, - const cudnnTensorDescriptor_t x_dsc, const void *x, const cudnnTensorDescriptor_t y_dsc, void *y) { - float alpha = 1.0f; - float beta = 0.0f; - CUDNN_CHECK(cudnnActivationForward(handle, activation_desc, &alpha, x_dsc, x, &beta, y_dsc, y)); - return 0; +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus } -} // namespace mindspore::lite +#endif +#endif \ No newline at end of file diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/space_to_batch_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/space_to_batch_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..17bcedc7ed19aa258285afd4c723c58ec6539928 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/space_to_batch_fp32.c @@ -0,0 +1,63 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp32/space_to_batch_fp32.h" +#include "nnacl/errorcode.h" + +int DoSpaceToBatch(const void *input, void *output, SpaceToBatchParameter *param, int task_id) { + if (param->op_parameter_.thread_num_ == 0) { + return NNACL_ERR; + } + const int input_batch = param->input_shape_[0]; + const int input_height = param->input_shape_[1]; + const int input_width = param->input_shape_[2]; + + const int output_batch = param->output_shape_[0]; + const int output_height = param->output_shape_[1]; + const int output_width = param->output_shape_[2]; + + const int block_shape_height = param->block_sizes_[0]; + const int block_shape_width = param->block_sizes_[1]; + const int padding_top = param->paddings_[0]; + const int padding_left = param->paddings_[2]; + + NNACL_CHECK_ZERO_RETURN_ERR(input_batch); + NNACL_CHECK_ZERO_RETURN_ERR(block_shape_width); + int copy_size = param->input_shape_[3] * param->data_type_len; + for (int64_t out_b = task_id; out_b < output_batch; out_b += param->op_parameter_.thread_num_) { + int in_b = out_b % input_batch; + int shift_w = (out_b / input_batch) % block_shape_width; + int shift_h = (out_b / input_batch) / block_shape_width; + for (int out_h = 0; out_h < output_height; out_h++) { + for (int out_w = 0; out_w < output_width; out_w++) { + int64_t output_offset = + out_b * param->out_stride_[0] + out_h * param->out_stride_[1] + out_w * param->out_stride_[2]; + if (out_h * block_shape_height + shift_h < padding_top || + out_h * block_shape_height + shift_h >= padding_top + input_height || + out_w * block_shape_width + shift_w < padding_left || + out_w * block_shape_width + shift_w >= padding_left + input_width) { + memset((int8_t *)output + output_offset * param->data_type_len, 0, copy_size); + } else { + int in_h = (out_h * block_shape_height + shift_h) - padding_top; + int in_w = (out_w * block_shape_width + shift_w) - padding_left; + int input_offset = in_b * param->in_stride_[0] + in_h * param->in_stride_[1] + in_w * param->in_stride_[2]; + memcpy((int8_t *)output + output_offset * param->data_type_len, + (const int8_t *)input + input_offset * param->data_type_len, copy_size); + } + } + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/space_to_batch_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/space_to_batch_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..6d8f902c9b4f85938f0a7ed88224e29f15d099f1 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/space_to_batch_fp32.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_BACKEND_ARM_NNACL_FP32_SPACE_TO_BATCH_H_ +#define MINDSPORE_LITE_SRC_BACKEND_ARM_NNACL_FP32_SPACE_TO_BATCH_H_ + +#include +#include "nnacl/op_base.h" + +typedef struct SpaceToBatchParameter { + // primitive parameter + OpParameter op_parameter_; + int block_sizes_[4]; + int paddings_[4]; + + // shape correlative + int input_shape_[4]; + int output_shape_[4]; + int in_stride_[4]; + int out_stride_[4]; + int padded_in_shape_[4]; + + // other parameter + bool need_paddings_; + int m_; + int data_type_len; +} SpaceToBatchParameter; +#ifdef __cplusplus +extern "C" { +#endif + +int DoSpaceToBatch(const void *input, void *output, SpaceToBatchParameter *param, int task_id); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_LITE_SRC_BACKEND_ARM_NNACL_FP32_SPACE_TO_BATCH_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/sparse_to_dense_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/sparse_to_dense_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..b5958c4108ca8846bfb5f8cfd9389eee6018322a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/sparse_to_dense_fp32.c @@ -0,0 +1,77 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * +// * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp32/sparse_to_dense_fp32.h" +#include "nnacl/errorcode.h" + +int SparseToDenseSetDefault(float *output, float default_value, const SparseToDenseParameter *param, int task_id) { + if (output == NULL || param == NULL) { + return NNACL_NULL_PTR; + } + if (param->op_parameter_.thread_num_ == 0) { + return NNACL_ERR; + } + int unit_per_thread = UP_DIV(param->output_num, param->op_parameter_.thread_num_); + int begin = unit_per_thread * task_id; + int end = MSMIN(begin + unit_per_thread, param->output_num); + for (int i = begin; i < end; i++) { + output[i] = default_value; + } + return NNACL_OK; +} + +int SparseToDense(int32_t *indices_vec, const float *sparse_values, float default_value, float *output, + SparseToDenseParameter *param, int task_id) { + if (indices_vec == NULL || sparse_values == NULL || output == NULL || param == NULL) { + return NNACL_NULL_PTR; + } + if (param->op_parameter_.thread_num_ == 0) { + return NNACL_ERR; + } + int unit_per_thread = UP_DIV(param->index_num, param->op_parameter_.thread_num_); + int begin = unit_per_thread * task_id; + int end = MSMIN(begin + unit_per_thread, param->index_num); + + int stride0 = param->output_stride[0]; + int stride1 = param->output_stride[1]; + int stride2 = param->output_stride[2]; + + if (param->validate_indices_ == true) { + int index_before = -1; + for (int i = begin; i < end; i++) { + int32_t *indices = indices_vec + i * DIMENSION_4D; + int index = stride0 * indices[0] + stride1 * indices[1] + stride2 * indices[2] + indices[3]; + if (index <= index_before) { + return NNACL_ERR; + } + index_before = index; + } + } + + if (param->is_scalar == true) { + for (int i = begin; i < end; i++) { + int32_t *indices = indices_vec + i * DIMENSION_4D; + int index = stride0 * indices[0] + stride1 * indices[1] + stride2 * indices[2] + indices[3]; + output[index] = sparse_values[0]; + } + } else { + for (int i = begin; i < end; i++) { + int32_t *indices = indices_vec + i * DIMENSION_4D; + int index = stride0 * indices[0] + stride1 * indices[1] + stride2 * indices[2] + indices[3]; + output[index] = sparse_values[i]; + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/sparse_to_dense_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/sparse_to_dense_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..44b5e458d513584e92d20c82af208b0e840c87eb --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/sparse_to_dense_fp32.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_SPARSETODENSE_H_ +#define MINDSPORE_NNACL_FP32_SPARSETODENSE_H_ + +#include "nnacl/sparse_to_dense_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int SparseToDenseSetDefault(float *output, float default_value, const SparseToDenseParameter *param, int task_id); +int SparseToDense(int32_t *indices_vec, const float *sparse_values, float default_value, float *output, + SparseToDenseParameter *param, int task_id); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_SPARSETODENSE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/splice_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/splice_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..a329c448248c538eb657d0c807b57ac728d1f183 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/splice_fp32.c @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/splice_fp32.h" +void SpliceFp32(const float *src_data, int src_row, int src_col, const SpliceParameter *splice_parameter, + float *dst_data, int dst_row, int dst_col) { + int forward_index = 0; + for (int r = 0; r < dst_row; ++r) { + float *dst_row_data = dst_data + r * dst_col; + for (int off = 0; off < splice_parameter->context_dim_; ++off) { + int r_off = splice_parameter->forward_indexes_[forward_index]; + forward_index++; + const float *tmp_src_data = src_data + r_off * src_col; + float *tmp_dst_data = dst_row_data + off * src_col; + memcpy(tmp_dst_data, tmp_src_data, (size_t)(src_col) * sizeof(float)); + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/splice_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/splice_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..4cbe37a63fb09448cf99aa53191f630e4df0ceb3 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/splice_fp32.h @@ -0,0 +1,26 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_SPLICE_FP32_H_ +#define NNACL_FP32_SPLICE_FP32_H_ + +#include +#include "nnacl/splice_parameter.h" + +void SpliceFp32(const float *src_data, int src_row, int src_col, const SpliceParameter *splice_parameter, + float *dst_data, int dst_row, int dst_col); + +#endif // NNACL_FP32_SPLICE_FP32_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/squared_difference.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/squared_difference.c new file mode 100644 index 0000000000000000000000000000000000000000..99b5fcd5648cf506a82f052e8e6b56a90e9ed705 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/squared_difference.c @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SQUARED_DIFFERENCE_H_ +#define MINDSPORE_NNACL_SQUARED_DIFFERENCE_H_ + +#include "nnacl/fp32/squared_difference.h" +#include "nnacl/fp32/sub_fp32.h" +#include "nnacl/fp32/mul_fp32.h" + +int ElementSquaredDifference(const float *in0, const float *in1, float *out, int size) { + ElementSub(in0, in1, out, size); + return ElementMul(out, out, out, size); +} + +int ElementOptSquaredDifference(const float *in0, const float *in1, float *out, int size, bool scale) { + ElementOptSub(in0, in1, out, size, scale); + return ElementMul(out, out, out, size); +} +#endif // MINDSPORE_NNACL_SQUARED_DIFFERENCE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/squared_difference.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/squared_difference.h new file mode 100644 index 0000000000000000000000000000000000000000..6524626774eb65fe1d69ceab0762e23d44f85cf4 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/squared_difference.h @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SQUARED_DIFFERENCE_H_ +#define MINDSPORE_NNACL_SQUARED_DIFFERENCE_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "nnacl/op_base.h" +#include "nnacl/base/arithmetic_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/* Element Squared Difference */ +int ElementSquaredDifference(const float *in0, const float *in1, float *out, int size); +int ElementOptSquaredDifference(const float *in0, const float *in1, float *out, int size, bool scale); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_SQUARED_DIFFERENCE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/strided_slice_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/strided_slice_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..e3165d9632432946abab9c1823bae3167489c965 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/strided_slice_fp32.c @@ -0,0 +1,125 @@ +/** + * Copyright 2019-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/strided_slice_fp32.h" +#include "nnacl/nnacl_common.h" +#include "nnacl/errorcode.h" + +int PadStridedSliceParameterTo8D(StridedSliceStruct *strided_slice) { + if (strided_slice->in_shape_size_ > DIMENSION_8D) { + return NNACL_STRIDED_SLICE_UNSUPPORTED_MAX_8D; + } + + int32_t begins[DIMENSION_8D]; + int32_t ends[DIMENSION_8D]; + int32_t strides[DIMENSION_8D]; + int32_t input_shape[DIMENSION_8D]; + int32_t i; + for (i = 0; i < strided_slice->in_shape_size_; ++i) { + begins[i] = strided_slice->begins_[i]; + ends[i] = MSMIN(strided_slice->ends_[i], strided_slice->in_shape_[i]); + strides[i] = strided_slice->strides_[i]; + input_shape[i] = strided_slice->in_shape_[i]; + } + + int32_t real_index = strided_slice->in_shape_size_ - 1; + for (i = DIMENSION_8D - 1; i >= 0; --i) { + if (real_index >= 0) { + strided_slice->begins_[i] = begins[real_index]; + strided_slice->ends_[i] = ends[real_index]; + strided_slice->strides_[i] = strides[real_index]; + strided_slice->in_shape_[i] = input_shape[real_index--]; + } else { + strided_slice->begins_[i] = 0; + strided_slice->ends_[i] = 1; + strided_slice->strides_[i] = 1; + strided_slice->in_shape_[i] = 1; + } + } + strided_slice->in_shape_size_ = DIMENSION_8D; + return NNACL_OK; +} + +bool LoopContinue(int stride, int i, int end) { return stride > 0 ? i < end : i > end; } + +int DoStridedSliceIn8D(const void *input, void *output, StridedSliceStruct *strided_slice) { + NNACL_CHECK_NULL_RETURN_ERR(strided_slice); + NNACL_CHECK_NULL_RETURN_ERR(input); + NNACL_CHECK_NULL_RETURN_ERR(output); + + const uint8_t *in = (const uint8_t *)input; + uint8_t *out = (uint8_t *)output; + int data_type_size = (int)DataTypeCSize(strided_slice->data_type_); + + int32_t *begins = strided_slice->begins_; + int32_t *ends = strided_slice->ends_; + int32_t *strides = strided_slice->strides_; + int32_t *in_shape = strided_slice->in_shape_; + + int dim_offset[DIMENSION_8D - 1]; + dim_offset[6] = in_shape[7]; + dim_offset[5] = in_shape[6] * dim_offset[6]; + dim_offset[4] = in_shape[5] * dim_offset[5]; + dim_offset[3] = in_shape[4] * dim_offset[4]; + dim_offset[2] = in_shape[3] * dim_offset[3]; + dim_offset[1] = in_shape[2] * dim_offset[2]; + dim_offset[0] = in_shape[1] * dim_offset[1]; + size_t out_offset = 0; + int32_t dim0, dim1, dim2, dim3, dim4, dim5, dim6, dim7; + for (dim0 = begins[0]; LoopContinue(strides[0], dim0, ends[0]); dim0 += strides[0]) { + for (dim1 = begins[1]; LoopContinue(strides[1], dim1, ends[1]); dim1 += strides[1]) { + for (dim2 = begins[2]; LoopContinue(strides[2], dim2, ends[2]); dim2 += strides[2]) { + for (dim3 = begins[3]; LoopContinue(strides[3], dim3, ends[3]); dim3 += strides[3]) { + for (dim4 = begins[4]; LoopContinue(strides[4], dim4, ends[4]); dim4 += strides[4]) { + for (dim5 = begins[5]; LoopContinue(strides[5], dim5, ends[5]); dim5 += strides[5]) { + for (dim6 = begins[6]; LoopContinue(strides[6], dim6, ends[6]); dim6 += strides[6]) { + for (dim7 = begins[7]; LoopContinue(strides[7], dim7, ends[7]); dim7 += strides[7]) { + int32_t in_offset = dim0 * dim_offset[0] + dim1 * dim_offset[1] + dim2 * dim_offset[2] + + dim3 * dim_offset[3] + dim4 * dim_offset[4] + dim5 * dim_offset[5] + + dim6 * dim_offset[6] + dim7; + memcpy(out + out_offset * data_type_size, in + in_offset * data_type_size, data_type_size); + out_offset++; + } + } + } + } + } + } + } + } + return NNACL_OK; +} + +void FastStride(const uint8_t *input, uint8_t *output, int split_len, int stride, size_t outer, size_t inner_size, + size_t in_offset) { + if (stride == 1) { + size_t unit = split_len * inner_size; + for (size_t i = 0; i < outer; ++i) { + memcpy(output, input, unit); + output += unit; + input += in_offset; + } + return; + } + for (size_t i = 0; i < outer; ++i) { + const uint8_t *input_ptr = input + i * in_offset; + for (int j = 0; j < split_len; ++j) { + memcpy(output, input_ptr, inner_size); + output += inner_size; + input_ptr += inner_size * stride; + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/strided_slice_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/strided_slice_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..9f1ed10fbeec55c6298cf4d20dd9327f13b2f2d8 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/strided_slice_fp32.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP32_STRIDED_SLICE_FP32_H_ +#define NNACL_FP32_STRIDED_SLICE_FP32_H_ + +#include "nnacl/op_base.h" +#include "nnacl/strided_slice_parameter.h" +#include "nnacl/kernel/strided_slice.h" +#ifdef __cplusplus +extern "C" { +#endif + +int PadStridedSliceParameterTo8D(StridedSliceStruct *strided_slice); +int DoStridedSliceIn8D(const void *input, void *output, StridedSliceStruct *strided_slice); + +void FastStride(const uint8_t *input, uint8_t *output, int split_len, int stride, size_t outer, size_t inner_size, + size_t in_offset); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FP32_STRIDED_SLICE_FP32_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/sub_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/sub_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..7e43c8a5f6ce1af2a268996bf02b617cc451bf39 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/sub_fp32.c @@ -0,0 +1,150 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp32/sub_fp32.h" +#include "nnacl/sub_fp32_simd.h" +#include "nnacl/errorcode.h" + +int ElementOptSub(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptSubNum0, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[0] - in1[index]; + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptSubNum1, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[index] - in1[0]; + } + } + return NNACL_OK; +} + +int ElementOptSubExt(const float *in0, const float *in1, const float alpha, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptSubExtNum0, index, in0, in1, alpha, out, size); + for (; index < size; index++) { + out[index] = in0[0] - in1[index] * alpha; + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptSubExtNum1, index, in0, in1, alpha, out, size); + for (; index < size; index++) { + out[index] = in0[index] - in1[0] * alpha; + } + } + return NNACL_OK; +} + +int ElementSubExt(const float *in0, const float *in1, const float alpha, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementSubExt, index, in0, in1, alpha, out, size); + for (; index < size; index++) { + out[index] = in0[index] - in1[index] * alpha; + } + return NNACL_OK; +} + +int ElementOptSubInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptSubIntNum0, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[0] - in1[index]; + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptSubIntNum1, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[index] - in1[0]; + } + } + return NNACL_OK; +} + +int ElementOptSubRelu(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptSubReluNum0, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMAX(in0[0] - in1[index], 0); + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptSubReluNum1, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMAX(in0[index] - in1[0], 0); + } + } + return NNACL_OK; +} + +int ElementOptSubRelu6(const float *in0, const float *in1, float *out, int size, bool first_scalar) { + int index = 0; + if (first_scalar) { + SIMD_RUN_NO_SCALAR(ElementOptSubRelu6Num0, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[0] - in1[index], 0), 6); + } + } else { + SIMD_RUN_NO_SCALAR(ElementOptSubRelu6Num1, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[index] - in1[0], 0), 6); + } + } + return NNACL_OK; +} + +int ElementSub(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementSub, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[index] - in1[index]; + } + return NNACL_OK; +} + +int ElementSubInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementSubInt, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = in0[index] - in1[index]; + } + return NNACL_OK; +} + +int ElementSubRelu(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementSubRelu, index, in0, in1, out, size); + for (; index < size; index++) { + float res = in0[index] - in1[index]; + out[index] = res > 0 ? res : 0; + } + return NNACL_OK; +} + +int ElementSubRelu6(const float *in0, const float *in1, float *out, int size) { + int index = 0; + + SIMD_RUN_NO_SCALAR(ElementSubRelu6, index, in0, in1, out, size); + for (; index < size; index++) { + out[index] = MSMIN(MSMAX(in0[index] - in1[index], 0), 6); + } + + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/sub_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/sub_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..42cd8a60b9921c5551de8dd6df21a0482a123a47 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/sub_fp32.h @@ -0,0 +1,44 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SUB_FP32_H_ +#define MINDSPORE_NNACL_SUB_FP32_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "nnacl/op_base.h" +#include "nnacl/base/arithmetic_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ElementSub(const float *in0, const float *in1, float *out, int size); +int ElementSubExt(const float *in0, const float *in1, const float alpha, float *out, int size); +int ElementOptSubExt(const float *in0, const float *in1, const float alpha, float *out, int size, bool first_scalar); +int ElementSubInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size); +int ElementSubRelu(const float *in0, const float *in1, float *out, int size); +int ElementSubRelu6(const float *in0, const float *in1, float *out, int size); +int ElementOptSub(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementOptSubInt(const int32_t *in0, const int32_t *in1, int32_t *out, int size, bool first_scalar); +int ElementOptSubRelu(const float *in0, const float *in1, float *out, int size, bool first_scalar); +int ElementOptSubRelu6(const float *in0, const float *in1, float *out, int size, bool first_scalar); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_SUB_FP32_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/sub_fp32_simd.h.in b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/sub_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..108e625a9b4d3ae11a19990f69ea68a80db07ff7 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/sub_fp32_simd.h.in @@ -0,0 +1,199 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_SUB_@SIMD_INSTRUCTION@_H_ +#define MINDSPORE_NNACL_FP32_SUB_@SIMD_INSTRUCTION@_H_ + +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int ElementOptSubNum0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + SIMD_F32 vin0_opt = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_SUB_F32(vin0_opt, vin1); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptSubNum1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + SIMD_F32 vin1_opt_ = SIMD_MOV_F32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vout = SIMD_SUB_F32(vin0, vin1_opt_); + SIMD_ST_F32(out + index, vout); + } + return index; +} + + +static inline int ElementOptSubExtNum0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, const float alpha, float *out, + int size) { + SIMD_F32 vin0_opt = SIMD_MOV_F32(in0[0]); + SIMD_F32 valpha = SIMD_MOV_F32(alpha); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vin1_alpha = SIMD_MUL_F32(vin1, valpha); + SIMD_F32 vout = SIMD_SUB_F32(vin0_opt, vin1_alpha); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptSubExtNum1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, const float alpha, float *out, + int size) { + SIMD_F32 vin1_opt_ = SIMD_MOV_F32(in1[0]); + SIMD_F32 valpha = SIMD_MOV_F32(alpha); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { +SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); +SIMD_F32 vin1_alpha = SIMD_MUL_F32(vin1_opt_, valpha); +SIMD_F32 vout = SIMD_SUB_F32(vin0, vin1_alpha); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptSubIntNum0@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 vin0_opt = SIMD_MOV_EPI32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin1 = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 vout = SIMD_SUB_EPI32(vin0_opt, vin1); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +static inline int ElementOptSubIntNum1@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + SIMD_EPI32 vin1_opt_ = SIMD_MOV_EPI32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin0 = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 vout = SIMD_SUB_EPI32(vin0, vin1_opt_); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +static inline int ElementOptSubReluNum0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + SIMD_F32 vin0_opt = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MAX_N_F32(SIMD_SUB_F32(vin0_opt, vin1), 0.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptSubReluNum1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + SIMD_F32 vin1_opt_ = SIMD_MOV_F32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vout = SIMD_MAX_N_F32(SIMD_SUB_F32(vin0, vin1_opt_), 0.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptSubRelu6Num0@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + SIMD_F32 vin0_opt = SIMD_MOV_F32(in0[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MIN_N_F32(SIMD_MAX_N_F32(SIMD_SUB_F32(vin0_opt, vin1), 0.0f), 6.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementOptSubRelu6Num1@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + SIMD_F32 vin1_opt_ = SIMD_MOV_F32(in1[0]); + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vout = SIMD_MIN_N_F32(SIMD_MAX_N_F32(SIMD_SUB_F32(vin0, vin1_opt_), 0.0f), 6.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementSub@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_SUB_F32(vin0, vin1); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementSubExt@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, const float alpha, float *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 valpha = SIMD_MOV_F32(alpha); + SIMD_F32 vin1_alpha = SIMD_MUL_F32(vin1, valpha); + SIMD_F32 vout = SIMD_SUB_F32(vin0, vin1_alpha); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementSubInt@SIMD_INSTRUCTION@(int index, const int32_t *in0, const int32_t *in1, int32_t *out, int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_EPI32 vin0 = SIMD_LD_EPI32(in0 + index); + SIMD_EPI32 vin1 = SIMD_LD_EPI32(in1 + index); + SIMD_EPI32 vout = SIMD_SUB_EPI32(vin0, vin1); + SIMD_ST_EPI32(out + index, vout); + } + return index; +} + +static inline int ElementSubRelu@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MAX_N_F32(SIMD_SUB_F32(vin0, vin1), 0.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +static inline int ElementSubRelu6@SIMD_INSTRUCTION@(int index, const float *in0, const float *in1, float *out, + int size) { + for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 vin0 = SIMD_LD_F32(in0 + index); + SIMD_F32 vin1 = SIMD_LD_F32(in1 + index); + SIMD_F32 vout = SIMD_MIN_N_F32(SIMD_MAX_N_F32(SIMD_SUB_F32(vin0, vin1), 0.0f), 6.0f); + SIMD_ST_F32(out + index, vout); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +}; +#endif +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/topk_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/topk_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..70338758248936c124157702af11661853515ab2 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/topk_fp32.c @@ -0,0 +1,106 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/topk_fp32.h" +#include "nnacl/errorcode.h" + +int DescendCmp(const void *a, const void *b) { + NNACL_CHECK_NULL_RETURN_ERR(a); + NNACL_CHECK_NULL_RETURN_ERR(b); + float sub = ((const TopkNode *)b)->element - ((const TopkNode *)a)->element; + if (sub > 0) { + return 1; + } else if (sub < 0) { + return -1; + } + if (((const TopkNode *)a)->index > ((const TopkNode *)b)->index) { + return 1; + } else { + return -1; + } +} + +int IndexSortCmp(const void *a, const void *b) { + if (((const TopkNode *)a)->index > ((const TopkNode *)b)->index) { + return 1; + } else { + return -1; + } +} + +void Topk(void *input_data, void *output_data, int32_t *output_index, TopkParameter *parameter) { + int dim_size = parameter->dim_size_; + int outer_loop_num = parameter->outer_loop_num_; + int inner_loop_num = parameter->inner_loop_num_; + int k = parameter->k_; + TopkNode *top_map = (TopkNode *)parameter->topk_node_list_; + + float *cur_input_data = (float *)input_data; + float *cur_output_data = (float *)output_data; + int32_t *cur_output_index = output_index; + for (int i = 0; i < outer_loop_num; i++) { + int in_offset = i * dim_size * inner_loop_num; + int out_offset = i * k * inner_loop_num; + for (int j = 0; j < inner_loop_num; j++) { + for (int m = 0; m < dim_size; m++) { + int offset = in_offset + m * inner_loop_num + j; + top_map[m].element = *(cur_input_data + offset); + top_map[m].index = m; + } + qsort(top_map, dim_size, sizeof(top_map[0]), DescendCmp); + if (!parameter->sorted_) { + qsort(top_map, k, sizeof(top_map[0]), IndexSortCmp); + } + for (int m = 0; m < k; m++) { + int offset = out_offset + m * inner_loop_num + j; + cur_output_data[offset] = top_map[m].element; + cur_output_index[offset] = top_map[m].index; + } + } + } +} + +void TopkInt(void *input_data, void *output_data, int32_t *output_index, TopkParameter *parameter) { + int dim_size = parameter->dim_size_; + int outer_loop_num = parameter->outer_loop_num_; + int inner_loop_num = parameter->inner_loop_num_; + int k = parameter->k_; + TopkNode *top_map = (TopkNode *)parameter->topk_node_list_; + + int32_t *cur_input_data = (int32_t *)input_data; + int32_t *cur_output_data = (int32_t *)output_data; + int32_t *cur_output_index = output_index; + for (int i = 0; i < outer_loop_num; i++) { + int in_offset = i * dim_size * inner_loop_num; + int out_offset = i * k * inner_loop_num; + for (int j = 0; j < inner_loop_num; j++) { + for (int m = 0; m < dim_size; m++) { + int offset = in_offset + m * inner_loop_num + j; + top_map[m].element = (float)(*(cur_input_data + offset)); + top_map[m].index = m; + } + qsort(top_map, dim_size, sizeof(top_map[0]), DescendCmp); + if (!parameter->sorted_) { + qsort(top_map, k, sizeof(top_map[0]), IndexSortCmp); + } + for (int m = 0; m < k; m++) { + int offset = out_offset + m * inner_loop_num + j; + cur_output_data[offset] = (int)(top_map[m].element); + cur_output_index[offset] = top_map[m].index; + } + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/topk_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/topk_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..00c9a0af5de22f4f631051f1d3b0240c1c0494ce --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/topk_fp32.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_TOPK_H_ +#define MINDSPORE_NNACL_TOPK_H_ + +#include "nnacl/op_base.h" + +typedef struct TopkNode { + float element; + int32_t index; +} TopkNode; + +typedef struct TopkParameter { + // primitive parameter + OpParameter op_parameter_; + int k_; + int axis_; + bool sorted_; + + // other parameter + int dim_size_; + int outer_loop_num_; + int inner_loop_num_; + void *topk_node_list_; +} TopkParameter; + +#ifdef __cplusplus +extern "C" { +#endif +void Topk(void *input_data, void *output_data, int32_t *output_index, TopkParameter *parameter); +void TopkInt(void *input_data, void *output_data, int32_t *output_index, TopkParameter *parameter); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_TOPK_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/transpose_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/transpose_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..46ed5f2a6a3de977d94587dbad42f6a1a9c646bb --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/transpose_fp32.c @@ -0,0 +1,248 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/transpose_fp32.h" +#include "nnacl/op_base.h" + +void TransposeDim2Fp32(const float *in_data, float *out_data, const int32_t *strides, int32_t *out_strides, + const int32_t *perm, const int32_t *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * output1; + int stride0_i = i * 1 * stride0; + for (int j = 0; j < output1; ++j) { + out_data[out_stride0_i + j] = in_data[stride0_i + j * stride1]; + } + } +} + +void TransposeDim3Fp32(const float *in_data, float *out_data, const int32_t *strides, const int32_t *out_strides, + const int32_t *perm, const int32_t *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; ++j) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + for (int k = 0; k < output2; ++k) { + out_data[out_stride0_i + out_stride1_j + k] = in_data[stride0_i + stride1_j + k * stride2]; + } + } + } +} + +void TransposeDim4Fp32(const float *in_data, float *out_data, const int32_t *strides, const int32_t *out_strides, + const int32_t *perm, const int32_t *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int stride3 = strides[perm[3]]; + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int out_stride2 = out_strides[2]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + const int output3 = output_shape[3]; + + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; ++j) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + for (int k = 0; k < output2; ++k) { + int out_stride2_k = k * out_stride2; + int stride2_k = k * stride2; + for (int m = 0; m < output3; ++m) { + out_data[out_stride0_i + out_stride1_j + out_stride2_k + m] = + in_data[stride0_i + stride1_j + stride2_k + m * stride3]; + } + } + } + } +} + +void TransposeDim5Fp32(const float *in_data, float *out_data, const int32_t *strides, const int32_t *out_strides, + const int32_t *perm, const int32_t *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int stride3 = strides[perm[3]]; + const int stride4 = strides[perm[4]]; + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int out_stride2 = out_strides[2]; + const int out_stride3 = out_strides[3]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + const int output3 = output_shape[3]; + const int output4 = output_shape[4]; + + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; ++j) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + for (int k = 0; k < output2; ++k) { + int out_stride2_k = k * out_stride2; + int stride2_k = k * stride2; + for (int m = 0; m < output3; ++m) { + int out_stride3_m = m * out_stride3; + int stride3_m = m * stride3; + for (int n = 0; n < output4; ++n) { + out_data[out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + n] = + in_data[stride0_i + stride1_j + stride2_k + stride3_m + n * stride4]; + } + } + } + } + } +} + +void TransposeDim6Fp32(const float *in_data, float *out_data, const int32_t *strides, const int32_t *out_strides, + const int32_t *perm, const int32_t *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int stride3 = strides[perm[3]]; + const int stride4 = strides[perm[4]]; + const int stride5 = strides[perm[5]]; + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int out_stride2 = out_strides[2]; + const int out_stride3 = out_strides[3]; + const int out_stride4 = out_strides[4]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + const int output3 = output_shape[3]; + const int output4 = output_shape[4]; + const int output5 = output_shape[5]; + + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; ++j) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + for (int k = 0; k < output2; ++k) { + int out_stride2_k = k * out_stride2; + int stride2_k = k * stride2; + for (int m = 0; m < output3; ++m) { + int out_stride3_m = m * out_stride3; + int stride3_m = m * stride3; + for (int n = 0; n < output4; ++n) { + int out_stride4_m = n * out_stride4; + int stride4_m = n * stride4; + for (int g = 0; g < output5; ++g) { + out_data[out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + out_stride4_m + g] = + in_data[stride0_i + stride1_j + stride2_k + stride3_m + stride4_m + g * stride5]; + } + } + } + } + } + } +} + +void TransposeDimsFp32(const void *in, void *out, const int32_t *output_shape, int32_t *perm, int32_t *strides, + int32_t *out_strides, int num_axes, int task_id, int thread_num) { + const float *in_data = (const float *)in; + float *out_data = (float *)out; + NNACL_CHECK_NULL_RETURN_VOID(in_data); + NNACL_CHECK_NULL_RETURN_VOID(out_data); + NNACL_CHECK_NULL_RETURN_VOID(output_shape); + + int data_size = (*out_strides) * output_shape[0]; + int offset_size = UP_DIV(data_size, thread_num); + int task_offset = offset_size * task_id; + int count = data_size - task_offset; + if (count <= 0) { + return; + } + count = MSMIN(offset_size, count); + for (int idx = task_offset; idx < task_offset + count; ++idx) { + int pos = idx; + int output_idx = 0; + int input_idx = 0; + for (int i = 0; i < num_axes; ++i) { + NNACL_CHECK_ZERO_RETURN(*(out_strides + i)); + int position = pos / *(out_strides + i); + int out_stride = i < num_axes - 1 ? out_strides[i] : 1; + output_idx += (position * out_stride); + input_idx += (position * strides[perm[i]]); + pos -= position * (*(out_strides + i)); + } + out_data[output_idx] = in_data[input_idx]; + } +} + +int DoTransposeFp32(const void *in, void *out, const int32_t *output_shape, int32_t *perm, int32_t *strides, + int32_t *out_strides, int data_size, int num_axes) { + const float *in_data = (const float *)in; + float *out_data = (float *)out; + + NNACL_CHECK_NULL_RETURN_ERR(in_data); + NNACL_CHECK_NULL_RETURN_ERR(out_data); + NNACL_CHECK_NULL_RETURN_ERR(output_shape); + + // check if transpose is needed + bool needTranspose = false; + for (int i = 1; i < num_axes; ++i) { + if (perm[i] - perm[i - 1] != 1) { + needTranspose = true; + break; + } + } + + if (!needTranspose) { + (void)memcpy(out_data, in_data, data_size); + return NNACL_OK; + } + for (int i = 0; i < num_axes; ++i) { + if (perm[i] < 0) { + return NNACL_PARAM_INVALID; + } + } + if (num_axes == 2) { + TransposeDim2Fp32(in_data, out_data, strides, out_strides, perm, output_shape); + } else if (num_axes == 3) { + TransposeDim3Fp32(in_data, out_data, strides, out_strides, perm, output_shape); + } else if (num_axes == 4) { + TransposeDim4Fp32(in_data, out_data, strides, out_strides, perm, output_shape); + } else if (num_axes == 5) { + TransposeDim5Fp32(in_data, out_data, strides, out_strides, perm, output_shape); + } else if (num_axes == 6) { + TransposeDim6Fp32(in_data, out_data, strides, out_strides, perm, output_shape); + } else { + return NNACL_ERR; + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/transpose_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/transpose_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..f73aacc9fa5b5d876854a35e43d862096db1c216 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/transpose_fp32.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_FP32_TRANSPOSE_H_ +#define MINDSPORE_NNACL_FP32_TRANSPOSE_H_ + +#include +#include +#include "nnacl/errorcode.h" + +#ifdef __cplusplus +extern "C" { +#endif +int DoTransposeFp32(const void *in, void *out, const int32_t *output_shape, int32_t *perm, int32_t *strides, + int32_t *out_strides, int data_size, int num_axes); +void TransposeDimsFp32(const void *in, void *out, const int32_t *output_shape, int32_t *perm, int32_t *strides, + int32_t *out_strides, int num_axes, int task_id, int thread_num); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_TRANSPOSE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/transpose_server_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/transpose_server_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..3832c0d5efecd1f009cc5af9402b6fa3d4d107b7 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/transpose_server_fp32.c @@ -0,0 +1,239 @@ +#ifdef BFC_MEMORY +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/transpose_server_fp32.h" + +#define JUDGEPART(NUM) \ + if (dim_start##NUM == overflow_point##NUM) { \ + dim_start##NUM = 0; \ + } else { \ + ++dim_start##NUM; \ + in_offset += stride##NUM; \ + continue; \ + } + +void DoTransposeServerDim3(const float *in_data, float *out_data, const int64_t *overflow_points, + const int64_t *strides, const TransposeBlockBoundaryInfo *boundary_info) { + int64_t stride2 = strides[THIRD_INPUT]; + int64_t size = boundary_info->sizes[0]; + int64_t in_offset = boundary_info->in_offsets[0]; + out_data += boundary_info->out_start_offset; + for (int64_t i = 0; i < size; ++i) { + out_data[i] = in_data[in_offset + i * stride2]; + } + int64_t dim_start1 = boundary_info->start_dim[1]; + int64_t overflow_point1 = overflow_points[1]; + int64_t overflow_point2 = overflow_points[THIRD_INPUT]; + int64_t stride0 = strides[0]; + int64_t stride1 = strides[1]; + int64_t last_dim = overflow_point2 + 1; + out_data += size; + size = boundary_info->sizes[1]; + in_offset = boundary_info->in_offsets[1]; + for (int64_t i = 0; i < size; i += last_dim) { + for (int64_t j = 0; j < overflow_point2; ++j) { + out_data[i + j] = in_data[in_offset]; + in_offset += stride2; + } + out_data[i + overflow_point2] = in_data[in_offset]; + JUDGEPART(1) + in_offset += stride0; + } + out_data += size; + size = boundary_info->sizes[THIRD_INPUT]; + for (int64_t i = 0; i < size; ++i) { + out_data[i] = in_data[in_offset + i * stride2]; + } +} + +void DoTransposeServerDim4(const float *in_data, float *out_data, const int64_t *overflow_points, + const int64_t *strides, const TransposeBlockBoundaryInfo *boundary_info) { + int64_t stride3 = strides[FOURTH_INPUT]; + int64_t size = boundary_info->sizes[0]; + int64_t in_offset = boundary_info->in_offsets[0]; + out_data += boundary_info->out_start_offset; + for (int64_t i = 0; i < size; ++i) { + out_data[i] = in_data[in_offset + i * stride3]; + } + int64_t dim_start1 = boundary_info->start_dim[1]; + int64_t dim_start2 = boundary_info->start_dim[THIRD_INPUT]; + int64_t overflow_point1 = overflow_points[1]; + int64_t overflow_point2 = overflow_points[THIRD_INPUT]; + int64_t overflow_point3 = overflow_points[FOURTH_INPUT]; + int64_t stride0 = strides[0]; + int64_t stride1 = strides[1]; + int64_t stride2 = strides[THIRD_INPUT]; + int64_t last_dim = overflow_point3 + 1; + out_data += size; + size = boundary_info->sizes[1]; + in_offset = boundary_info->in_offsets[1]; + for (int64_t i = 0; i < size; i += last_dim) { + for (int64_t j = 0; j < overflow_point3; ++j) { + out_data[i + j] = in_data[in_offset]; + in_offset += stride3; + } + out_data[i + overflow_point3] = in_data[in_offset]; + JUDGEPART(2) + JUDGEPART(1) + in_offset += stride0; + } + out_data += size; + size = boundary_info->sizes[THIRD_INPUT]; + for (int64_t i = 0; i < size; ++i) { + out_data[i] = in_data[in_offset + i * stride3]; + } +} + +void DoTransposeServerDim5(const float *in_data, float *out_data, const int64_t *overflow_points, + const int64_t *strides, const TransposeBlockBoundaryInfo *boundary_info) { + int64_t stride4 = strides[FIFTH_INPUT]; + int64_t size = boundary_info->sizes[0]; + int64_t in_offset = boundary_info->in_offsets[0]; + out_data += boundary_info->out_start_offset; + for (int64_t i = 0; i < size; ++i) { + out_data[i] = in_data[in_offset + i * stride4]; + } + int64_t dim_start1 = boundary_info->start_dim[1]; + int64_t dim_start2 = boundary_info->start_dim[THIRD_INPUT]; + int64_t dim_start3 = boundary_info->start_dim[FOURTH_INPUT]; + int64_t overflow_point1 = overflow_points[1]; + int64_t overflow_point2 = overflow_points[THIRD_INPUT]; + int64_t overflow_point3 = overflow_points[FOURTH_INPUT]; + int64_t overflow_point4 = overflow_points[FIFTH_INPUT]; + int64_t stride0 = strides[0]; + int64_t stride1 = strides[1]; + int64_t stride2 = strides[THIRD_INPUT]; + int64_t stride3 = strides[FOURTH_INPUT]; + int64_t last_dim = overflow_point4 + 1; + out_data += size; + size = boundary_info->sizes[1]; + in_offset = boundary_info->in_offsets[1]; + for (int64_t i = 0; i < size; i += last_dim) { + for (int64_t j = 0; j < overflow_point4; ++j) { + out_data[i + j] = in_data[in_offset]; + in_offset += stride4; + } + out_data[i + overflow_point4] = in_data[in_offset]; + JUDGEPART(3) + JUDGEPART(2) + JUDGEPART(1) + in_offset += stride0; + } + out_data += size; + size = boundary_info->sizes[THIRD_INPUT]; + for (int64_t i = 0; i < size; ++i) { + out_data[i] = in_data[in_offset + i * stride4]; + } +} + +void DoTransposeServerDim6(const float *in_data, float *out_data, const int64_t *overflow_points, + const int64_t *strides, const TransposeBlockBoundaryInfo *boundary_info) { + int64_t stride5 = strides[SIXTH_INPUT]; + int64_t size = boundary_info->sizes[0]; + int64_t in_offset = boundary_info->in_offsets[0]; + out_data += boundary_info->out_start_offset; + for (int64_t i = 0; i < size; ++i) { + out_data[i] = in_data[in_offset + i * stride5]; + } + int64_t dim_start1 = boundary_info->start_dim[1]; + int64_t dim_start2 = boundary_info->start_dim[THIRD_INPUT]; + int64_t dim_start3 = boundary_info->start_dim[FOURTH_INPUT]; + int64_t dim_start4 = boundary_info->start_dim[FIFTH_INPUT]; + int64_t overflow_point1 = overflow_points[1]; + int64_t overflow_point2 = overflow_points[THIRD_INPUT]; + int64_t overflow_point3 = overflow_points[FOURTH_INPUT]; + int64_t overflow_point4 = overflow_points[FIFTH_INPUT]; + int64_t overflow_point5 = overflow_points[SIXTH_INPUT]; + int64_t stride0 = strides[0]; + int64_t stride1 = strides[1]; + int64_t stride2 = strides[THIRD_INPUT]; + int64_t stride3 = strides[FOURTH_INPUT]; + int64_t stride4 = strides[FIFTH_INPUT]; + int64_t last_dim = overflow_point5 + 1; + out_data += size; + size = boundary_info->sizes[1]; + in_offset = boundary_info->in_offsets[1]; + for (int64_t i = 0; i < size; i += last_dim) { + for (int64_t j = 0; j < overflow_point5; ++j) { + out_data[i + j] = in_data[in_offset]; + in_offset += stride5; + } + out_data[i + overflow_point5] = in_data[in_offset]; + JUDGEPART(4) + JUDGEPART(3) + JUDGEPART(2) + JUDGEPART(1) + in_offset += stride0; + } + out_data += size; + size = boundary_info->sizes[THIRD_INPUT]; + for (int64_t i = 0; i < size; ++i) { + out_data[i] = in_data[in_offset + i * stride5]; + } +} + +void DoTransposeServer(const float *in_data, float *out_data, const int64_t *overflow_points, const int64_t *strides, + int axis_num, const TransposeBlockBoundaryInfo *boundary_info) { + if (axis_num == DIMENSION_3D) { + DoTransposeServerDim3(in_data, out_data, overflow_points, strides, boundary_info); + return; + } else if (axis_num == DIMENSION_4D) { + DoTransposeServerDim4(in_data, out_data, overflow_points, strides, boundary_info); + return; + } else if (axis_num == DIMENSION_5D) { + DoTransposeServerDim5(in_data, out_data, overflow_points, strides, boundary_info); + return; + } else if (axis_num == DIMENSION_6D) { + DoTransposeServerDim6(in_data, out_data, overflow_points, strides, boundary_info); + return; + } + out_data += boundary_info->out_start_offset; + int64_t stride = strides[axis_num - 1]; + int64_t size = boundary_info->sizes[0]; + int64_t in_offset = boundary_info->in_offsets[0]; + for (int64_t i = 0; i < size; ++i) { + out_data[i] = in_data[in_offset + i * stride]; + } + int64_t dim_info[MAX_TRANSPOSE_DIM_SIZE] = {}; + for (int i = 0; i < axis_num; ++i) { + dim_info[i] = boundary_info->start_dim[i]; + } + int64_t last_overflow_point = overflow_points[axis_num - 1]; + int64_t last_dim = last_overflow_point + 1; + out_data += size; + size = boundary_info->sizes[1]; + for (int64_t i = 0; i < size; i += last_dim) { + for (int64_t j = 0; j < last_overflow_point; ++j) { + out_data[i + j] = in_data[in_offset]; + in_offset += stride; + } + out_data[i + last_overflow_point] = in_data[in_offset]; + int j = axis_num - 2; + while (dim_info[j] == overflow_points[j]) { + dim_info[j] = 0; + --j; + } + ++dim_info[j]; + in_offset += strides[j]; + } + out_data += size; + size = boundary_info->sizes[THIRD_INPUT]; + for (int64_t i = 0; i < size; ++i) { + out_data[i] = in_data[in_offset + i * stride]; + } +} +#endif diff --git a/mindspore-lite/tools/graph_kernel/converter/split_model_cpu.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/transpose_server_fp32.h similarity index 46% rename from mindspore-lite/tools/graph_kernel/converter/split_model_cpu.h rename to mindspore-lite/ops/kernel/cpu/nnacl/fp32/transpose_server_fp32.h index b86d06e0e525595c076fa2e0a35ce126b378469a..ecd257ab8a77f5a530defa861dcb8e84449ac844 100644 --- a/mindspore-lite/tools/graph_kernel/converter/split_model_cpu.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/transpose_server_fp32.h @@ -13,19 +13,28 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_SPLIT_MODEL_CPU_H_ -#define MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_SPLIT_MODEL_CPU_H_ +#ifndef MINDSPORE_NNACL_FP32_TRANSPOSE_SERVER_FP32_H_ +#define MINDSPORE_NNACL_FP32_TRANSPOSE_SERVER_FP32_H_ -#include "backend/common/graph_kernel/split_model/split_model_factory.h" -namespace mindspore::graphkernel::inner { -class SplitModelCpu : public SplitModel { - public: - SplitModelCpu() = default; - virtual ~SplitModelCpu() = default; +#ifdef BFC_MEMORY +#include "nnacl/transpose_parameter.h" - protected: - AreaMode GetDefaultAreaMode(const PrimOpPtr &) const override; - void InitFusePatterns() override; +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct TransposeBlockBoundaryInfo { + int64_t out_start_offset; + int64_t sizes[C3NUM]; + int64_t in_offsets[C2NUM]; + int64_t start_dim[MAX_TRANSPOSE_DIM_SIZE]; +} TransposeBlockBoundaryInfo; + +void DoTransposeServer(const float *in_data, float *out_data, const int64_t *overflow_points, const int64_t *strides, + int axis_num, const TransposeBlockBoundaryInfo *boundary_info); +#ifdef __cplusplus }; -} // namespace mindspore::graphkernel::inner -#endif // MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_SPLIT_MODEL_CPU_H_ +#endif + +#endif // MINDSPORE_NNACL_FP32_TRANSPOSE_SERVER_FP32_H_ +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/triu_tril_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/triu_tril_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..48dc1d5b77927aabf37bca62c3dfe3358abc2cfc --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/triu_tril_fp32.c @@ -0,0 +1,179 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp32/triu_tril_fp32.h" + +int TriuTrilGetCalculateNum(KernelBase *self, int64_t *mul, int64_t *height, int64_t *width) { + TensorC *input_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + for (size_t i = 0; i < input_tensor->shape_size_; i++) { + if (input_tensor->shape_[i] <= 0) { + return NNACL_TRIU_TRIL_INPUT_SHAPE_ERROR; + } + } + + size_t input_hw_dims = Num2; + NNACL_CHECK_FALSE(input_tensor->shape_size_ < DIMENSION_2D, NNACL_TRIU_INPUT_DIMS_INVALID); + + *mul = 1; + for (size_t i = 0; i < input_tensor->shape_size_ - input_hw_dims; i++) { + *mul *= input_tensor->shape_[i]; + } + *height = input_tensor->shape_[input_tensor->shape_size_ - Num2]; + *width = input_tensor->shape_[input_tensor->shape_size_ - Num1]; + + return NNACL_OK; +} + +int TriuTrilGetKValue(KernelBase *self, int64_t *k) { + if (self->in_size_ <= 1) { + *k = 0; + return NNACL_OK; + } + + TensorC *k_tensor = self->in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(k_tensor); + NNACL_CHECK_NULL_RETURN_ERR(k_tensor->data_); + + switch (k_tensor->data_type_) { + case kNumberTypeInt: + case kNumberTypeInt32: + *k = *((int32_t *)k_tensor->data_); + break; + case kNumberTypeInt64: + *k = *((int64_t *)k_tensor->data_); + break; + default: + return NNACL_TRIU_K_TENSOR_DATA_TYPE_INVALID; + } + return NNACL_OK; +} + +void TriuByte8(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) { + const int64_t *src_data = (const int64_t *)src; + int64_t *dst_data = (int64_t *)dst; + for (int64_t m = 0; m < out_elems; m++) { + int64_t m_factor = m * height * width; + for (int64_t h = 0; h < height; h++) { + int64_t h_factor = m_factor + h * width; + for (int64_t w = 0; w < width; w++) { + int64_t index = h_factor + w; + dst_data[index] = h + k <= w ? src_data[index] : 0; + } + } + } +} + +void TriuByte4(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) { + const int32_t *src_data = (const int32_t *)src; + int32_t *dst_data = (int32_t *)dst; + for (int64_t m = 0; m < out_elems; m++) { + int64_t m_factor = m * height * width; + for (int64_t h = 0; h < height; h++) { + int64_t h_factor = m_factor + h * width; + for (int64_t w = 0; w < width; w++) { + int64_t index = h_factor + w; + dst_data[index] = h + k <= w ? src_data[index] : 0; + } + } + } +} + +void TriuByte2(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) { + const int16_t *src_data = (const int16_t *)src; + int16_t *dst_data = (int16_t *)dst; + for (int64_t m = 0; m < out_elems; m++) { + int64_t m_factor = m * height * width; + for (int64_t h = 0; h < height; h++) { + int64_t h_factor = m_factor + h * width; + for (int64_t w = 0; w < width; w++) { + int64_t index = h_factor + w; + dst_data[index] = h + k <= w ? src_data[index] : 0; + } + } + } +} +void TriuByte1(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) { + const int8_t *src_data = (const int8_t *)src; + int8_t *dst_data = (int8_t *)dst; + for (int64_t m = 0; m < out_elems; m++) { + int64_t m_factor = m * height * width; + for (int64_t h = 0; h < height; h++) { + int64_t h_factor = m_factor + h * width; + for (int64_t w = 0; w < width; w++) { + int64_t index = h_factor + w; + dst_data[index] = h + k <= w ? src_data[index] : 0; + } + } + } +} + +void TrilByte8(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) { + const int64_t *src_data = (const int64_t *)src; + int64_t *dst_data = (int64_t *)dst; + for (int64_t m = 0; m < out_elems; m++) { + int64_t m_factor = m * height * width; + for (int64_t h = 0; h < height; h++) { + int64_t h_factor = m_factor + h * width; + for (int64_t w = 0; w < width; w++) { + int64_t index = h_factor + w; + dst_data[index] = h + k >= w ? src_data[index] : 0; + } + } + } +} + +void TrilByte4(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) { + const int32_t *src_data = (const int32_t *)src; + int32_t *dst_data = (int32_t *)dst; + for (int64_t m = 0; m < out_elems; m++) { + int64_t m_factor = m * height * width; + for (int64_t h = 0; h < height; h++) { + int64_t h_factor = m_factor + h * width; + for (int64_t w = 0; w < width; w++) { + int64_t index = h_factor + w; + dst_data[index] = h + k >= w ? src_data[index] : 0; + } + } + } +} +void TrilByte2(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) { + const int16_t *src_data = (const int16_t *)src; + int16_t *dst_data = (int16_t *)dst; + for (int64_t m = 0; m < out_elems; m++) { + int64_t m_factor = m * height * width; + for (int64_t h = 0; h < height; h++) { + int64_t h_factor = m_factor + h * width; + for (int64_t w = 0; w < width; w++) { + int64_t index = h_factor + w; + dst_data[index] = h + k >= w ? src_data[index] : 0; + } + } + } +} +void TrilByte1(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems) { + const int8_t *src_data = (const int8_t *)src; + int8_t *dst_data = (int8_t *)dst; + for (int64_t m = 0; m < out_elems; m++) { + int64_t m_factor = m * height * width; + for (int64_t h = 0; h < height; h++) { + int64_t h_factor = m_factor + h * width; + for (int64_t w = 0; w < width; w++) { + int64_t index = h_factor + w; + dst_data[index] = h + k >= w ? src_data[index] : 0; + } + } + } +} diff --git a/mindspore-lite/src/common/draw/drawer.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/triu_tril_fp32.h similarity index 33% rename from mindspore-lite/src/common/draw/drawer.h rename to mindspore-lite/ops/kernel/cpu/nnacl/fp32/triu_tril_fp32.h index 9e210913d4b1d8c0be5c814cc0177f76cdc29932..2d202c4c1ee2e3e7ca23b30dd05daf2aba4d366a 100644 --- a/mindspore-lite/src/common/draw/drawer.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/triu_tril_fp32.h @@ -13,55 +13,30 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#ifndef MINDSPORE_NNACL_FP32_TRIU_TRIL_H_ +#define MINDSPORE_NNACL_FP32_TRIU_TRIL_H_ -#ifndef MINDSPORE_LITE_SRC_COMMON_DRAW_DRAWER_H_ -#define MINDSPORE_LITE_SRC_COMMON_DRAW_DRAWER_H_ +#include "nnacl/op_base.h" +#include "nnacl/kernel.h" -#include - -#include "src/executor/sub_graph_kernel.h" - -#ifdef ENABLE_CLOUD_INFERENCE -#include "src/extendrt/graph_compiler/compile_result.h" +#ifdef __cplusplus +extern "C" { #endif -namespace mindspore { -namespace lite { -class Drawer { - public: - static Drawer &Instance() { - static Drawer instance; - return instance; - } - - void Init(); - - void Reset(); +int TriuTrilGetCalculateNum(KernelBase *self, int64_t *mul, int64_t *height, int64_t *width); +int TriuTrilGetKValue(KernelBase *self, int64_t *k); - std::string GetNextFileName(const std::string &name); +void TriuByte8(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems); +void TriuByte4(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems); +void TriuByte2(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems); +void TriuByte1(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems); - void Draw(const kernel::SubGraphKernel *graph, const std::string &name = ""); -#ifdef ENABLE_CLOUD_INFERENCE - void Draw(const CompileResult *graph, const std::string &name = ""); -#endif - - private: - Drawer() = default; - bool SaveDotFile(const std::string &dot_name, const std::string &dot_content); +void TrilByte8(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems); +void TrilByte4(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems); +void TrilByte2(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems); +void TrilByte1(const void *src, void *dst, int64_t k, int64_t height, int64_t width, int64_t out_elems); - bool enabled_{false}; - std::string base_dir_; - size_t count_{0}; -}; -} // namespace lite - -#if (defined Debug) && (defined ENABLE_DRAW) -#define InitDotDrawer() mindspore::lite::Drawer::Instance().Init() -#define DrawDot(graph, name) mindspore::lite::Drawer::Instance().Draw(graph, name) -#else -#define InitDotDrawer() -#define DrawDot(graph, name) +#ifdef __cplusplus +} #endif -} // namespace mindspore - -#endif // MINDSPORE_LITE_SRC_COMMON_DRAW_DRAWER_H_ +#endif // MINDSPORE_NNACL_FP32_TRIU_TRIL_H_ diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/acl/acl_graph_impl.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/unique_fp32.c similarity index 34% rename from mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/acl/acl_graph_impl.h rename to mindspore-lite/ops/kernel/cpu/nnacl/fp32/unique_fp32.c index f6c5f0b56eb8e0f6e4382db505fbfdf240521a42..25d0451ca93fc07fc79066e32c3f5adb0f994b96 100644 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/acl/acl_graph_impl.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/unique_fp32.c @@ -13,46 +13,55 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_ACL_ACL_GRAPH_IMPL_H -#define MINDSPORE_CCSRC_CXX_API_GRAPH_ACL_ACL_GRAPH_IMPL_H -#include -#include -#include -#include -#include -#include "include/api/graph.h" -#include "cxx_api/graph/acl/model_process.h" -#include "cxx_api/graph/acl/acl_env_guard.h" -#include "cxx_api/graph/graph_impl.h" -#include "cxx_api/factory.h" -namespace mindspore { -class AclGraphImpl : public GraphCell::GraphImpl { - public: - AclGraphImpl(); - ~AclGraphImpl() override; +#include "nnacl/fp32/unique_fp32.h" - Status Run(const std::vector &inputs, std::vector *outputs) override; - Status Load(uint32_t device_id) override; - std::vector GetInputs() override; - std::vector GetOutputs() override; - bool CheckDeviceSupport(mindspore::DeviceType device_type) override; +int Find(const float *array, int len, float target) { + if (array == NULL) { + return -1; + } + for (int i = 0; i < len; ++i) { + if (array[i] == target) { + return i; + } + } + return -1; +} - private: - Status ConvertToOM(); - Status InitEnv(); - Status FinalizeEnv(); - Status LoadAclModel(const Buffer om_data); +void Unique(const float *input, int input_len, float *output0, int32_t *output0_len, int32_t *output1) { + *output0_len = 0; + for (int i = 0; i < input_len; i++) { + int idx = Find(output0, *output0_len, input[i]); + if (idx != -1) { + *output1++ = idx; + } else { + output0[(*output0_len)++] = input[i]; + *output1++ = *output0_len - 1; + } + } +} - bool init_flag_; - bool load_flag_; - std::string device_type_; - int32_t device_id_; - aclrtContext context_; +int FindInt(const int32_t *array, int len, int target) { + if (array == NULL) { + return -1; + } + for (int i = 0; i < len; ++i) { + if (array[i] == target) { + return i; + } + } + return -1; +} - std::shared_ptr acl_env_; - - ModelProcess model_process_; -}; -} // namespace mindspore -#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_ACL_ACL_GRAPH_IMPL_H +void UniqueInt(const int32_t *input, int input_len, int32_t *output0, int32_t *output0_len, int32_t *output1) { + *output0_len = 0; + for (int i = 0; i < input_len; i++) { + int idx = FindInt(output0, *output0_len, input[i]); + if (idx != -1) { + *output1++ = idx; + } else { + output0[(*output0_len)++] = input[i]; + *output1++ = *output0_len - 1; + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/unique_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/unique_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..96ec1694ec95887a17ff224bf3e24bfa0d0c0e43 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/unique_fp32.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_UNIQUE_H +#define MINDSPORE_NNACL_UNIQUE_H + +#include "nnacl/op_base.h" + +typedef struct UniqueParameter { + // primitive parameter + OpParameter op_parameter_; +} UniqueParameter; + +#ifdef __cplusplus +extern "C" { +#endif +void Unique(const float *input, int input_len, float *output0, int32_t *output0_len, int32_t *output1); +void UniqueInt(const int32_t *input, int input_len, int32_t *output0, int32_t *output0_len, int32_t *output1); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_UNIQUE_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/where_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/where_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..3c045530f72cd0eaf25a78770e4d8ac2c564f7e2 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/where_fp32.c @@ -0,0 +1,35 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/where_fp32.h" +#include "nnacl/common_func.h" + +void WhereWithTripleInputs(const float *x, const float *y, float *output, const WhereArgs *param, int task_id, + int thread_num) { + const bool *condition = param->condition_; + int stride = UP_DIV(param->max_num_, thread_num); + int begin = task_id * stride; + int end = MSMIN(begin + stride, param->max_num_); + + for (int i = begin; i < end; ++i) { + bool cond = condition[param->condition_num_ > 1 ? i : 0]; + if (cond) { + output[i] = x[param->x_num_ > 1 ? i : 0]; + } else { + output[i] = y[param->y_num_ > 1 ? i : 0]; + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/where_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/where_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..4e5986dfe54356d5d7b3c9b6392ce3f18cead4a1 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/where_fp32.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FP32_WHERE_FP32_H_ +#define MINDSPORE_NNACL_FP32_WHERE_FP32_H_ + +#include "nnacl/op_base.h" +#include "nnacl/where_parameter.h" +#include "nnacl/kernel/where.h" + +#ifdef __cplusplus +extern "C" { +#endif +void WhereWithTripleInputs(const float *x, const float *y, float *output, const WhereArgs *param, int task_id, + int thread_num); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_FP32_WHERE_FP32_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/winograd_avx.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/winograd_avx.c new file mode 100644 index 0000000000000000000000000000000000000000..fb1069695135c3aa284c9992b4a4f4152b30bd48 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/winograd_avx.c @@ -0,0 +1,2233 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless re256uired by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_AVX +#include "nnacl/fp32/winograd_avx.h" +#include "nnacl/intrinsics/ms_simd_instructions.h" + +void InputTransform4x4AvxUnit(const float *src_data, float *dst_data, const int src_step, const int dst_step, + const int real_c) { + if (real_c == C8NUM) { + MS_FLOAT32X8 src[16]; + MS_FLOAT32X8 t[16]; + MS_FLOAT32X8 m[16]; + LoadAvx16Data; + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = MS_SUB256_F32(src[offset], src[2 + offset]); + t[4 + l] = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + t[8 + l] = MS_SUB256_F32(src[2 + offset], src[1 + offset]); + t[12 + l] = MS_SUB256_F32(src[3 + offset], src[1 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + m[l] = MS_SUB256_F32(t[offset], t[2 + offset]); + m[4 + l] = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + m[8 + l] = MS_SUB256_F32(t[2 + offset], t[1 + offset]); + m[12 + l] = MS_SUB256_F32(t[3 + offset], t[1 + offset]); + } + for (int i = 0; i < 16; i++) { + MS_ST256_F32(dst_data + i * dst_step, m[i]); + } + } else { + float src[16]; + float t[16]; + float m[16]; + for (int i = 0; i < real_c; ++i) { + for (int j = 0; j < 16; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[offset] - src[2 + offset]; + t[4 + l] = src[1 + offset] + src[2 + offset]; + t[8 + l] = src[2 + offset] - src[1 + offset]; + t[12 + l] = src[3 + offset] - src[1 + offset]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + m[l] = t[offset] - t[2 + offset]; + m[4 + l] = t[1 + offset] + t[2 + offset]; + m[8 + l] = t[2 + offset] - t[1 + offset]; + m[12 + l] = t[3 + offset] - t[1 + offset]; + } + for (int k = 0; k < 16; ++k) { + dst_data[i + k * dst_step] = m[k]; + } + } + } +} + +void InputTransform6x6AvxUnit(const float *src_data, float *dst_data, const int src_step, const int dst_step, + const int real_c) { + if (real_c == C8NUM) { + MS_FLOAT32X8 src[36]; + MS_FLOAT32X8 t[36]; + MS_FLOAT32X8 m[36]; + LoadAvx36Data; + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_SUB256_F32(src[3 + offset], src[1 + offset]); + MS_FLOAT32X8 tmp2 = MS_SUB256_F32(src[4 + offset], src[2 + offset]); + t[l] = MS_ADD256_F32(MS_SUB256_F32(MS_MUL256_N_F32(src[offset], 4), MS_MUL256_N_F32(src[2 + offset], 5)), + src[4 + offset]); + t[6 + l] = MS_ADD256_F32(MS_MUL256_N_F32(MS_ADD256_F32(src[1 + offset], src[2 + offset]), -4), + MS_ADD256_F32(src[3 + offset], src[4 + offset])); + t[12 + l] = MS_ADD256_F32(MS_MUL256_N_F32(MS_SUB256_F32(src[1 + offset], src[2 + offset]), 4), + MS_SUB256_F32(src[4 + offset], src[3 + offset])); + t[18 + l] = MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 2), tmp2); + t[24 + l] = MS_ADD256_F32(MS_MUL256_N_F32(tmp1, -2), tmp2); + t[30 + l] = MS_ADD256_F32(MS_SUB256_F32(MS_MUL256_N_F32(src[1 + offset], 4), MS_MUL256_N_F32(src[3 + offset], 5)), + src[5 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_SUB256_F32(t[3 + offset], t[1 + offset]); + MS_FLOAT32X8 tmp2 = MS_SUB256_F32(t[4 + offset], t[2 + offset]); + m[l] = + MS_ADD256_F32(MS_SUB256_F32(MS_MUL256_N_F32(t[offset], 4), MS_MUL256_N_F32(t[2 + offset], 5)), t[4 + offset]); + m[6 + l] = MS_ADD256_F32(MS_MUL256_N_F32(MS_ADD256_F32(t[1 + offset], t[2 + offset]), -4), + MS_ADD256_F32(t[3 + offset], t[4 + offset])); + m[12 + l] = MS_ADD256_F32(MS_MUL256_N_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]), 4), + MS_SUB256_F32(t[4 + offset], t[3 + offset])); + m[18 + l] = MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 2), tmp2); + m[24 + l] = MS_ADD256_F32(MS_MUL256_N_F32(tmp1, -2), tmp2); + m[30 + l] = MS_ADD256_F32(MS_SUB256_F32(MS_MUL256_N_F32(t[1 + offset], 4), MS_MUL256_N_F32(t[3 + offset], 5)), + t[5 + offset]); + } + for (int i = 0; i < 36; i++) { + MS_ST256_F32(dst_data + i * dst_step, m[i]); + } + } else { + float src[36]; + float t[36]; + float m[36]; + for (int i = 0; i < real_c; ++i) { + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float tmp1 = src[3 + offset] - src[1 + offset]; + float tmp2 = src[4 + offset] - src[2 + offset]; + t[l] = 4 * src[offset] - 5 * src[2 + offset] + src[4 + offset]; + t[6 + l] = -4 * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]); + t[12 + l] = 4 * (src[1 + offset] - src[2 + offset]) + (src[4 + offset] - src[3 + offset]); + t[18 + l] = 2 * tmp1 + tmp2; + t[24 + l] = -2 * tmp1 + tmp2; + t[30 + l] = 4 * src[1 + offset] - 5 * src[3 + offset] + src[5 + offset]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float tmp1 = t[3 + offset] - t[1 + offset]; + float tmp2 = t[4 + offset] - t[2 + offset]; + m[l] = 4 * t[offset] - 5 * t[2 + offset] + t[4 + offset]; + m[6 + l] = -4 * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]); + m[12 + l] = 4 * (t[1 + offset] - t[2 + offset]) + (t[4 + offset] - t[3 + offset]); + m[18 + l] = 2 * tmp1 + tmp2; + m[24 + l] = -2 * tmp1 + tmp2; + m[30 + l] = 4 * t[1 + offset] - 5 * t[3 + offset] + t[5 + offset]; + } + for (int k = 0; k < 36; ++k) { + dst_data[i + k * dst_step] = m[k]; + } + } + } +} + +void InputTransform8x8AvxUnit_block8(const float *src_data, float *dst_data, const int src_step, const int dst_step) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[64]; + MS_FLOAT32X8 m[64]; + LoadAvx64Data; + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = MS_SUB256_F32( + MS_ADD256_F32(MS_SUB256_F32(MS_MUL256_N_F32(src[offset], 0.5625), MS_MUL256_N_F32(src[2 + offset], 3.0625)), + MS_MUL256_N_F32(src[4 + offset], 3.5)), + src[6 + offset]); + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(MS_MUL256_N_F32(src[1 + offset], 1.125), MS_MUL256_N_F32(src[5 + offset], 0.5)); + MS_FLOAT32X8 tmp2 = MS_SUB256_F32(MS_MUL256_N_F32(src[2 + offset], 2.25), MS_MUL256_N_F32(src[4 + offset], 3.25)); + t[8 + l] = + MS_ADD256_F32(MS_SUB256_F32(MS_ADD256_F32(tmp1, tmp2), MS_MUL256_N_F32(src[3 + offset], 1.625)), src[6 + offset]); + t[16 + l] = + MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(tmp2, tmp1), MS_MUL256_N_F32(src[3 + offset], 1.625)), src[6 + offset]); + tmp1 = MS_ADD256_F32(MS_MUL256_N_F32(src[1 + offset], 0.5625), src[5 + offset]); + tmp2 = MS_SUB256_F32(MS_MUL256_N_F32(src[2 + offset], 0.5625), MS_MUL256_N_F32(src[4 + offset], 2.5)); + t[24 + l] = + MS_ADD256_F32(MS_SUB256_F32(MS_ADD256_F32(tmp1, tmp2), MS_MUL256_N_F32(src[3 + offset], 2.5)), src[6 + offset]); + t[32 + l] = + MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(tmp2, tmp1), MS_MUL256_N_F32(src[3 + offset], 2.5)), src[6 + offset]); + tmp1 = MS_ADD256_F32(MS_MUL256_N_F32(src[1 + offset], 0.375), MS_MUL256_N_F32(src[5 + offset], 1.5)); + tmp2 = MS_SUB256_F32(MS_MUL256_N_F32(src[2 + offset], 0.25), MS_MUL256_N_F32(src[4 + offset], 1.25)); + t[40 + l] = + MS_ADD256_F32(MS_SUB256_F32(MS_ADD256_F32(tmp1, tmp2), MS_MUL256_N_F32(src[3 + offset], 1.875)), src[6 + offset]); + t[48 + l] = + MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(tmp2, tmp1), MS_MUL256_N_F32(src[3 + offset], 1.875)), src[6 + offset]); + t[56 + l] = MS_ADD256_F32( + MS_SUB256_F32(MS_ADD256_F32(MS_MUL256_N_F32(src[1 + offset], -0.5625), MS_MUL256_N_F32(src[3 + offset], 3.0625)), + MS_MUL256_N_F32(src[5 + offset], 3.5)), + src[7 + offset]); + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + m[l] = MS_SUB256_F32( + MS_ADD256_F32(MS_SUB256_F32(MS_MUL256_N_F32(t[offset], 0.5625), MS_MUL256_N_F32(t[2 + offset], 3.0625)), + MS_MUL256_N_F32(t[4 + offset], 3.5)), + t[6 + offset]); + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(MS_MUL256_N_F32(t[1 + offset], 1.125), MS_MUL256_N_F32(t[5 + offset], 0.5)); + MS_FLOAT32X8 tmp2 = MS_SUB256_F32(MS_MUL256_N_F32(t[2 + offset], 2.25), MS_MUL256_N_F32(t[4 + offset], 3.25)); + m[8 + l] = + MS_ADD256_F32(MS_SUB256_F32(MS_ADD256_F32(tmp1, tmp2), MS_MUL256_N_F32(t[3 + offset], 1.625)), t[6 + offset]); + m[16 + l] = + MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(tmp2, tmp1), MS_MUL256_N_F32(t[3 + offset], 1.625)), t[6 + offset]); + tmp1 = MS_ADD256_F32(MS_MUL256_N_F32(t[1 + offset], 0.5625), t[5 + offset]); + tmp2 = MS_SUB256_F32(MS_MUL256_N_F32(t[2 + offset], 0.5625), MS_MUL256_N_F32(t[4 + offset], 2.5)); + m[24 + l] = + MS_ADD256_F32(MS_SUB256_F32(MS_ADD256_F32(tmp1, tmp2), MS_MUL256_N_F32(t[3 + offset], 2.5)), t[6 + offset]); + m[32 + l] = + MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(tmp2, tmp1), MS_MUL256_N_F32(t[3 + offset], 2.5)), t[6 + offset]); + tmp1 = MS_ADD256_F32(MS_MUL256_N_F32(t[1 + offset], 0.375), MS_MUL256_N_F32(t[5 + offset], 1.5)); + tmp2 = MS_SUB256_F32(MS_MUL256_N_F32(t[2 + offset], 0.25), MS_MUL256_N_F32(t[4 + offset], 1.25)); + m[40 + l] = + MS_ADD256_F32(MS_SUB256_F32(MS_ADD256_F32(tmp1, tmp2), MS_MUL256_N_F32(t[3 + offset], 1.875)), t[6 + offset]); + m[48 + l] = + MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(tmp2, tmp1), MS_MUL256_N_F32(t[3 + offset], 1.875)), t[6 + offset]); + m[56 + l] = MS_ADD256_F32( + MS_SUB256_F32(MS_ADD256_F32(MS_MUL256_N_F32(t[1 + offset], -0.5625), MS_MUL256_N_F32(t[3 + offset], 3.0625)), + MS_MUL256_N_F32(t[5 + offset], 3.5)), + t[7 + offset]); + } + for (int i = 0; i < 64; i++) { + MS_ST256_F32(dst_data + i * dst_step, m[i]); + } +} + +void InputTransform8x8AvxUnit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c) { + if (real_c == C8NUM) { + InputTransform8x8AvxUnit_block8(src_data, dst_data, src_step, dst_step); + } else { + float src[64]; + float t[64]; + float m[64]; + for (int i = 0; i < real_c; ++i) { + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = 0.5625f * src[offset] - 3.0625f * src[2 + offset] + 3.5f * src[4 + offset] - src[6 + offset]; + float tmp1 = 1.125f * src[1 + offset] + 0.5f * src[5 + offset]; + float tmp2 = 2.25f * src[2 + offset] - 3.25f * src[4 + offset]; + t[8 + l] = tmp1 + tmp2 - 1.625f * src[3 + offset] + src[6 + offset]; + t[16 + l] = tmp2 - tmp1 + 1.625f * src[3 + offset] + src[6 + offset]; + tmp1 = 0.5625f * src[1 + offset] + src[5 + offset]; + tmp2 = 0.5625f * src[2 + offset] - 2.5f * src[4 + offset]; + t[24 + l] = tmp1 + tmp2 - 2.5f * src[3 + offset] + src[6 + offset]; + t[32 + l] = tmp2 - tmp1 + 2.5f * src[3 + offset] + src[6 + offset]; + tmp1 = 0.375f * src[1 + offset] + 1.5f * src[5 + offset]; + tmp2 = 0.25f * src[2 + offset] - 1.25f * src[4 + offset]; + t[40 + l] = tmp1 + tmp2 - 1.875f * src[3 + offset] + src[6 + offset]; + t[48 + l] = tmp2 - tmp1 + 1.875f * src[3 + offset] + src[6 + offset]; + t[56 + l] = -0.5625f * src[1 + offset] + 3.0625f * src[3 + offset] - 3.5f * src[5 + offset] + src[7 + offset]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + m[l] = 0.5625f * t[offset] - 3.0625f * t[2 + offset] + 3.5f * t[4 + offset] - t[6 + offset]; + float tmp1 = 1.125f * t[1 + offset] + 0.5f * t[5 + offset]; + float tmp2 = 2.25f * t[2 + offset] - 3.25f * t[4 + offset]; + m[8 + l] = tmp1 + tmp2 - 1.625f * t[3 + offset] + t[6 + offset]; + m[16 + l] = tmp2 - tmp1 + 1.625f * t[3 + offset] + t[6 + offset]; + tmp1 = 0.5625f * t[1 + offset] + t[5 + offset]; + tmp2 = 0.5625f * t[2 + offset] - 2.5f * t[4 + offset]; + m[24 + l] = tmp1 + tmp2 - 2.5f * t[3 + offset] + t[6 + offset]; + m[32 + l] = tmp2 - tmp1 + 2.5f * t[3 + offset] + t[6 + offset]; + tmp1 = 0.375f * t[1 + offset] + 1.5f * t[5 + offset]; + tmp2 = 0.25f * t[2 + offset] - 1.25f * t[4 + offset]; + m[40 + l] = tmp1 + tmp2 - 1.875f * t[3 + offset] + t[6 + offset]; + m[48 + l] = tmp2 - tmp1 + 1.875f * t[3 + offset] + t[6 + offset]; + m[56 + l] = -0.5625f * t[1 + offset] + 3.0625f * t[3 + offset] - 3.5f * t[5 + offset] + t[7 + offset]; + } + for (int k = 0; k < 64; ++k) { + dst_data[i + k * dst_step] = m[k]; + } + } + } +} + +void OutputTransform4x2AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[16]; + MS_FLOAT32X8 t[8]; + MS_FLOAT32X8 m[4]; + LoadAvx16Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = MS_ADD256_F32(MS_ADD256_F32(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = MS_ADD256_F32(MS_SUB256_F32(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + StoreAvx4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform4x2ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[16]; + MS_FLOAT32X8 t[8]; + MS_FLOAT32X8 m[4]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + LoadAvx16Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = MS_ADD256_F32(MS_ADD256_F32(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = MS_ADD256_F32(MS_SUB256_F32(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l + 2] = MS_MAX256_F32(zero, m[l + 2]); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + StoreAvx4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform4x2Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[16]; + MS_FLOAT32X8 t[8]; + MS_FLOAT32X8 m[4]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + MS_FLOAT32X8 six = MS_MOV256_F32(6); + LoadAvx16Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = MS_ADD256_F32(MS_ADD256_F32(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = MS_ADD256_F32(MS_SUB256_F32(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l] = MS_MIN256_F32(six, m[l]); + m[l + 2] = MS_MAX256_F32(zero, m[l + 2]); + m[l + 2] = MS_MIN256_F32(six, m[l + 2]); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + StoreAvx4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform4x3AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[16]; + MS_FLOAT32X8 t[12]; + MS_FLOAT32X8 m[9]; + LoadAvx16Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + MS_FLOAT32X8 tmp = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + t[l] = MS_ADD256_F32(src[offset], tmp); + t[l + 4] = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + t[l + 8] = MS_ADD256_F32(tmp, src[3 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 4; + MS_FLOAT32X8 tmp = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp), bias_ptr); + m[l + 3] = MS_ADD256_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]), bias_ptr); + m[l + 6] = MS_ADD256_F32(MS_ADD256_F32(tmp, t[3 + offset]), bias_ptr); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + StoreAvx9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform4x3ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[16]; + MS_FLOAT32X8 t[12]; + MS_FLOAT32X8 m[9]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + LoadAvx16Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + MS_FLOAT32X8 tmp = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + t[l] = MS_ADD256_F32(src[offset], tmp); + t[l + 4] = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + t[l + 8] = MS_ADD256_F32(tmp, src[3 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 4; + MS_FLOAT32X8 tmp = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp), bias_ptr); + m[l + 3] = MS_ADD256_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]), bias_ptr); + m[l + 6] = MS_ADD256_F32(MS_ADD256_F32(tmp, t[3 + offset]), bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l + 3] = MS_MAX256_F32(zero, m[l + 3]); + m[l + 6] = MS_MAX256_F32(zero, m[l + 6]); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + StoreAvx9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform4x3Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[16]; + MS_FLOAT32X8 t[12]; + MS_FLOAT32X8 m[9]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + MS_FLOAT32X8 six = MS_MOV256_F32(6); + LoadAvx16Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + MS_FLOAT32X8 tmp = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + t[l] = MS_ADD256_F32(src[offset], tmp); + t[l + 4] = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + t[l + 8] = MS_ADD256_F32(tmp, src[3 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 4; + MS_FLOAT32X8 tmp = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp), bias_ptr); + m[l + 3] = MS_ADD256_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]), bias_ptr); + m[l + 6] = MS_ADD256_F32(MS_ADD256_F32(tmp, t[3 + offset]), bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l] = MS_MIN256_F32(six, m[l]); + m[l + 3] = MS_MAX256_F32(zero, m[l + 3]); + m[l + 3] = MS_MIN256_F32(six, m[l + 3]); + m[l + 6] = MS_MAX256_F32(zero, m[l + 6]); + m[l + 6] = MS_MIN256_F32(six, m[l + 6]); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + StoreAvx9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform6x2AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[36]; + MS_FLOAT32X8 t[12]; + MS_FLOAT32X8 m[4]; + LoadAvx36Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], src[1 + offset]), src[2 + offset]), src[3 + offset]), + src[4 + offset]); + t[l + 6] = MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(src[1 + offset], src[2 + offset]), + MS_MUL256_N_F32(MS_SUB256_F32(src[3 + offset], src[4 + offset]), 2)), + src[5 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 6; + m[l] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], t[1 + offset]), t[2 + offset]), t[3 + offset]), + t[4 + offset]), + bias_ptr); + m[l + 2] = + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]), + MS_MUL256_N_F32(MS_SUB256_F32(t[3 + offset], t[4 + offset]), 2)), + t[5 + offset]), + bias_ptr); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + StoreAvx4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform6x2ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[36]; + MS_FLOAT32X8 t[12]; + MS_FLOAT32X8 m[4]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + LoadAvx36Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], src[1 + offset]), src[2 + offset]), src[3 + offset]), + src[4 + offset]); + t[l + 6] = MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(src[1 + offset], src[2 + offset]), + MS_MUL256_N_F32(MS_SUB256_F32(src[3 + offset], src[4 + offset]), 2)), + src[5 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 6; + m[l] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], t[1 + offset]), t[2 + offset]), t[3 + offset]), + t[4 + offset]), + bias_ptr); + m[l + 2] = + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]), + MS_MUL256_N_F32(MS_SUB256_F32(t[3 + offset], t[4 + offset]), 2)), + t[5 + offset]), + bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l + 2] = MS_MAX256_F32(zero, m[l + 2]); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + StoreAvx4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform6x2Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[36]; + MS_FLOAT32X8 t[12]; + MS_FLOAT32X8 m[4]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + MS_FLOAT32X8 six = MS_MOV256_F32(6); + LoadAvx36Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], src[1 + offset]), src[2 + offset]), src[3 + offset]), + src[4 + offset]); + t[l + 6] = MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(src[1 + offset], src[2 + offset]), + MS_MUL256_N_F32(MS_SUB256_F32(src[3 + offset], src[4 + offset]), 2)), + src[5 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 6; + m[l] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], t[1 + offset]), t[2 + offset]), t[3 + offset]), + t[4 + offset]), + bias_ptr); + m[l + 2] = + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]), + MS_MUL256_N_F32(MS_SUB256_F32(t[3 + offset], t[4 + offset]), 2)), + t[5 + offset]), + bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l] = MS_MIN256_F32(six, m[l]); + m[l + 2] = MS_MAX256_F32(zero, m[l + 2]); + m[l + 2] = MS_MIN256_F32(six, m[l + 2]); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + StoreAvx4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} +void OutputTransform6x3AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[36]; + MS_FLOAT32X8 t[18]; + MS_FLOAT32X8 m[9]; + LoadAvx36Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADD256_F32(MS_SUB256_F32(src[1 + offset], src[2 + offset]), + MS_MUL256_N_F32(MS_SUB256_F32(src[3 + offset], src[4 + offset]), 2)); + t[l + 12] = MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)), src[5 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 3] = MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]), + MS_MUL256_N_F32(MS_SUB256_F32(t[3 + offset], t[4 + offset]), 2)), + bias_ptr); + m[l + 6] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)), t[5 + offset]), bias_ptr); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + StoreAvx9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform6x3ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[36]; + MS_FLOAT32X8 t[18]; + MS_FLOAT32X8 m[9]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + LoadAvx36Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADD256_F32(MS_SUB256_F32(src[1 + offset], src[2 + offset]), + MS_MUL256_N_F32(MS_SUB256_F32(src[3 + offset], src[4 + offset]), 2)); + t[l + 12] = MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)), src[5 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 3] = MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]), + MS_MUL256_N_F32(MS_SUB256_F32(t[3 + offset], t[4 + offset]), 2)), + bias_ptr); + m[l + 6] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)), t[5 + offset]), bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l + 3] = MS_MAX256_F32(zero, m[l + 3]); + m[l + 6] = MS_MAX256_F32(zero, m[l + 6]); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + StoreAvx9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform6x3Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[36]; + MS_FLOAT32X8 t[18]; + MS_FLOAT32X8 m[9]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + MS_FLOAT32X8 six = MS_MOV256_F32(6); + LoadAvx36Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADD256_F32(MS_SUB256_F32(src[1 + offset], src[2 + offset]), + MS_MUL256_N_F32(MS_SUB256_F32(src[3 + offset], src[4 + offset]), 2)); + t[l + 12] = MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)), src[5 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 3] = MS_ADD256_F32(MS_ADD256_F32(MS_SUB256_F32(t[1 + offset], t[2 + offset]), + MS_MUL256_N_F32(MS_SUB256_F32(t[3 + offset], t[4 + offset]), 2)), + bias_ptr); + m[l + 6] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)), t[5 + offset]), bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l] = MS_MIN256_F32(six, m[l]); + m[l + 3] = MS_MAX256_F32(zero, m[l + 3]); + m[l + 3] = MS_MIN256_F32(six, m[l + 3]); + m[l + 6] = MS_MAX256_F32(zero, m[l + 6]); + m[l + 6] = MS_MIN256_F32(six, m[l + 6]); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + StoreAvx9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform6x4AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[36]; + MS_FLOAT32X8 t[24]; + MS_FLOAT32X8 m[16]; + LoadAvx36Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 2)); + t[l + 12] = MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)); + t[l + 18] = MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 8)), src[5 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 4] = MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 2)), bias_ptr); + m[l + 8] = MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)), bias_ptr); + m[l + 12] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 8)), t[5 + offset]), bias_ptr); + } + if (r_c == C8NUM && r_h == 4 && r_w == 4) { + StoreAvx16Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform6x4ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[36]; + MS_FLOAT32X8 t[24]; + MS_FLOAT32X8 m[16]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + LoadAvx36Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 2)); + t[l + 12] = MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)); + t[l + 18] = MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 8)), src[5 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 4] = MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 2)), bias_ptr); + m[l + 8] = MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)), bias_ptr); + m[l + 12] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 8)), t[5 + offset]), bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l + 4] = MS_MAX256_F32(zero, m[l + 4]); + m[l + 8] = MS_MAX256_F32(zero, m[l + 8]); + m[l + 12] = MS_MAX256_F32(zero, m[l + 12]); + } + if (r_c == C8NUM && r_h == 4 && r_w == 4) { + StoreAvx16Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform6x4Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[36]; + MS_FLOAT32X8 t[24]; + MS_FLOAT32X8 m[16]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + MS_FLOAT32X8 six = MS_MOV256_F32(6); + LoadAvx36Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 2)); + t[l + 12] = MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)); + t[l + 18] = MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 8)), src[5 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 4] = MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 2)), bias_ptr); + m[l + 8] = MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)), bias_ptr); + m[l + 12] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 8)), t[5 + offset]), bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l] = MS_MIN256_F32(six, m[l]); + m[l + 4] = MS_MAX256_F32(zero, m[l + 4]); + m[l + 4] = MS_MIN256_F32(six, m[l + 4]); + m[l + 8] = MS_MAX256_F32(zero, m[l + 8]); + m[l + 8] = MS_MIN256_F32(six, m[l + 8]); + m[l + 12] = MS_MAX256_F32(zero, m[l + 12]); + m[l + 12] = MS_MIN256_F32(six, m[l + 12]); + } + if (r_c == C8NUM && r_h == 4 && r_w == 4) { + StoreAvx16Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform6x5AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[36]; + MS_FLOAT32X8 t[30]; + MS_FLOAT32X8 m[25]; + LoadAvx36Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 2)); + t[l + 12] = MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)); + t[l + 18] = MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 8)); + t[l + 24] = MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 16)), src[5 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 5] = MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 2)), bias_ptr); + m[l + 10] = MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)), bias_ptr); + m[l + 15] = MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 8)), bias_ptr); + m[l + 20] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 16)), t[5 + offset]), bias_ptr); + } + if (r_c == C8NUM && r_h == 5 && r_w == 5) { + StoreAvx25Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform6x5ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[36]; + MS_FLOAT32X8 t[30]; + MS_FLOAT32X8 m[25]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + LoadAvx36Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 2)); + t[l + 12] = MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)); + t[l + 18] = MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 8)); + t[l + 24] = MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 16)), src[5 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 5] = MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 2)), bias_ptr); + m[l + 10] = MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)), bias_ptr); + m[l + 15] = MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 8)), bias_ptr); + m[l + 20] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 16)), t[5 + offset]), bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l + 5] = MS_MAX256_F32(zero, m[l + 5]); + m[l + 10] = MS_MAX256_F32(zero, m[l + 10]); + m[l + 15] = MS_MAX256_F32(zero, m[l + 15]); + m[l + 20] = MS_MAX256_F32(zero, m[l + 20]); + } + if (r_c == C8NUM && r_h == 5 && r_w == 5) { + StoreAvx25Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform6x5Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[36]; + MS_FLOAT32X8 t[30]; + MS_FLOAT32X8 m[25]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + MS_FLOAT32X8 six = MS_MOV256_F32(6); + LoadAvx36Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 2)); + t[l + 12] = MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)); + t[l + 18] = MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 8)); + t[l + 24] = MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 16)), src[5 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 6; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 5] = MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 2)), bias_ptr); + m[l + 10] = MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 4)), bias_ptr); + m[l + 15] = MS_ADD256_F32(MS_ADD256_F32(tmp3, MS_MUL256_N_F32(tmp4, 8)), bias_ptr); + m[l + 20] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(tmp1, MS_MUL256_N_F32(tmp2, 16)), t[5 + offset]), bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l] = MS_MIN256_F32(six, m[l]); + m[l + 5] = MS_MAX256_F32(zero, m[l + 5]); + m[l + 5] = MS_MIN256_F32(six, m[l + 5]); + m[l + 10] = MS_MAX256_F32(zero, m[l + 10]); + m[l + 10] = MS_MIN256_F32(six, m[l + 10]); + m[l + 15] = MS_MAX256_F32(zero, m[l + 15]); + m[l + 15] = MS_MIN256_F32(six, m[l + 15]); + m[l + 20] = MS_MAX256_F32(zero, m[l + 20]); + m[l + 20] = MS_MIN256_F32(six, m[l + 20]); + } + if (r_c == C8NUM && r_h == 5 && r_w == 5) { + StoreAvx25Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x2AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[16]; + MS_FLOAT32X8 m[4]; + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + src[7 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 2] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + t[7 + offset]), + bias_ptr); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + StoreAvx4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x2ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[16]; + MS_FLOAT32X8 m[4]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + src[7 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 2] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l + 2] = MS_MAX256_F32(zero, m[l + 2]); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + StoreAvx4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x2Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[16]; + MS_FLOAT32X8 m[4]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + MS_FLOAT32X8 six = MS_MOV256_F32(6); + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + src[7 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 2] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l] = MS_MIN256_F32(six, m[l]); + m[l + 2] = MS_MAX256_F32(zero, m[l + 2]); + m[l + 2] = MS_MIN256_F32(six, m[l + 2]); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + StoreAvx4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x3AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[24]; + MS_FLOAT32X8 m[9]; + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), src[7 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 3] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + bias_ptr); + m[l + 6] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), + t[7 + offset]), + bias_ptr); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + StoreAvx9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x3ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[24]; + MS_FLOAT32X8 m[9]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), src[7 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 3] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + bias_ptr); + m[l + 6] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l + 3] = MS_MAX256_F32(zero, m[l + 3]); + m[l + 6] = MS_MAX256_F32(zero, m[l + 6]); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + StoreAvx9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x3Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[24]; + MS_FLOAT32X8 m[9]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + MS_FLOAT32X8 six = MS_MOV256_F32(6); + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), src[7 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 3] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + bias_ptr); + m[l + 6] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l] = MS_MIN256_F32(six, m[l]); + m[l + 3] = MS_MAX256_F32(zero, m[l + 3]); + m[l + 3] = MS_MIN256_F32(six, m[l + 3]); + m[l + 6] = MS_MAX256_F32(zero, m[l + 6]); + m[l + 6] = MS_MIN256_F32(six, m[l + 6]); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + StoreAvx9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x4AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[32]; + MS_FLOAT32X8 m[16]; + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), src[7 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 4] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + bias_ptr); + m[l + 8] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 12] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), + t[7 + offset]), + bias_ptr); + } + if (r_c == C8NUM && r_h == 4 && r_w == 4) { + StoreAvx16Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x4ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[32]; + MS_FLOAT32X8 m[16]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), src[7 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 4] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + bias_ptr); + m[l + 8] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 12] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l + 4] = MS_MAX256_F32(zero, m[l + 4]); + m[l + 8] = MS_MAX256_F32(zero, m[l + 8]); + m[l + 12] = MS_MAX256_F32(zero, m[l + 12]); + } + if (r_c == C8NUM && r_h == 4 && r_w == 4) { + StoreAvx16Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x4Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[32]; + MS_FLOAT32X8 m[16]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + MS_FLOAT32X8 six = MS_MOV256_F32(6); + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), src[7 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 4] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + bias_ptr); + m[l + 8] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 12] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l] = MS_MIN256_F32(six, m[l]); + m[l + 4] = MS_MAX256_F32(zero, m[l + 4]); + m[l + 4] = MS_MIN256_F32(six, m[l + 4]); + m[l + 8] = MS_MAX256_F32(zero, m[l + 8]); + m[l + 8] = MS_MIN256_F32(six, m[l + 8]); + m[l + 12] = MS_MAX256_F32(zero, m[l + 12]); + m[l + 12] = MS_MIN256_F32(six, m[l + 12]); + } + if (r_c == C8NUM && r_h == 4 && r_w == 4) { + StoreAvx16Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x5AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[40]; + MS_FLOAT32X8 m[25]; + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)); + t[l + 32] = + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)), + src[7 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 5] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + bias_ptr); + m[l + 10] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 15] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 20] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)), + t[7 + offset]), + bias_ptr); + } + if (r_c == C8NUM && r_h == 5 && r_w == 5) { + StoreAvx25Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x5ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[40]; + MS_FLOAT32X8 m[25]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)); + t[l + 32] = + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)), + src[7 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 5] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + bias_ptr); + m[l + 10] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 15] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 20] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l + 5] = MS_MAX256_F32(zero, m[l + 5]); + m[l + 10] = MS_MAX256_F32(zero, m[l + 10]); + m[l + 15] = MS_MAX256_F32(zero, m[l + 15]); + m[l + 20] = MS_MAX256_F32(zero, m[l + 20]); + } + if (r_c == C8NUM && r_h == 5 && r_w == 5) { + StoreAvx25Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x5Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[40]; + MS_FLOAT32X8 m[25]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + MS_FLOAT32X8 six = MS_MOV256_F32(6); + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)); + t[l + 32] = + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)), + src[7 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 5] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + bias_ptr); + m[l + 10] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 15] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 20] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l] = MS_MIN256_F32(six, m[l]); + m[l + 5] = MS_MAX256_F32(zero, m[l + 5]); + m[l + 5] = MS_MIN256_F32(six, m[l + 5]); + m[l + 10] = MS_MAX256_F32(zero, m[l + 10]); + m[l + 10] = MS_MIN256_F32(six, m[l + 10]); + m[l + 15] = MS_MAX256_F32(zero, m[l + 15]); + m[l + 15] = MS_MIN256_F32(six, m[l + 15]); + m[l + 20] = MS_MAX256_F32(zero, m[l + 20]); + m[l + 20] = MS_MIN256_F32(six, m[l + 20]); + } + if (r_c == C8NUM && r_h == 5 && r_w == 5) { + StoreAvx25Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[48]; + MS_FLOAT32X8 m[36]; + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)); + t[l + 40] = + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.03125), tmp5), MS_MUL256_N_F32(tmp6, 7.59375)), + src[7 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + bias_ptr); + m[l + 12] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 18] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 24] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)), bias_ptr); + m[l + 30] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.03125), tmp5), MS_MUL256_N_F32(tmp6, 7.59375)), + t[7 + offset]), + bias_ptr); + } + if (r_c == C8NUM && r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + MS_ST256_F32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + MS_ST256_F32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + MS_ST256_F32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + MS_ST256_F32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + MS_ST256_F32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + MS_ST256_F32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x6ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[48]; + MS_FLOAT32X8 m[36]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)); + t[l + 40] = + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.03125), tmp5), MS_MUL256_N_F32(tmp6, 7.59375)), + src[7 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + bias_ptr); + m[l + 12] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 18] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 24] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)), bias_ptr); + m[l + 30] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.03125), tmp5), MS_MUL256_N_F32(tmp6, 7.59375)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l + 6] = MS_MAX256_F32(zero, m[l + 6]); + m[l + 12] = MS_MAX256_F32(zero, m[l + 12]); + m[l + 18] = MS_MAX256_F32(zero, m[l + 18]); + m[l + 24] = MS_MAX256_F32(zero, m[l + 24]); + m[l + 30] = MS_MAX256_F32(zero, m[l + 30]); + } + if (r_c == C8NUM && r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + MS_ST256_F32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + MS_ST256_F32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + MS_ST256_F32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + MS_ST256_F32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + MS_ST256_F32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + MS_ST256_F32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x6Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[48]; + MS_FLOAT32X8 m[36]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + MS_FLOAT32X8 six = MS_MOV256_F32(6); + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)); + t[l + 40] = + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.03125), tmp5), MS_MUL256_N_F32(tmp6, 7.59375)), + src[7 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + bias_ptr); + m[l + 12] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 18] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 24] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)), bias_ptr); + m[l + 30] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.03125), tmp5), MS_MUL256_N_F32(tmp6, 7.59375)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l] = MS_MIN256_F32(six, m[l]); + m[l + 6] = MS_MAX256_F32(zero, m[l + 6]); + m[l + 6] = MS_MIN256_F32(six, m[l + 6]); + m[l + 12] = MS_MAX256_F32(zero, m[l + 12]); + m[l + 12] = MS_MIN256_F32(six, m[l + 12]); + m[l + 18] = MS_MAX256_F32(zero, m[l + 18]); + m[l + 18] = MS_MIN256_F32(six, m[l + 18]); + m[l + 24] = MS_MAX256_F32(zero, m[l + 24]); + m[l + 24] = MS_MIN256_F32(six, m[l + 24]); + m[l + 30] = MS_MAX256_F32(zero, m[l + 30]); + m[l + 30] = MS_MIN256_F32(six, m[l + 30]); + } + if (r_c == C8NUM && r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + MS_ST256_F32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + MS_ST256_F32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + MS_ST256_F32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + MS_ST256_F32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + MS_ST256_F32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + MS_ST256_F32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x7AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[56]; + MS_FLOAT32X8 m[49]; + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)); + t[l + 40] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.03125), tmp5), MS_MUL256_N_F32(tmp6, 7.59375)); + t[l + 48] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.015625), tmp2), MS_MUL256_N_F32(tmp3, 11.390625)), + src[7 + offset]); + } + for (int l = 0; l < 7; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 7] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + bias_ptr); + m[l + 14] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 21] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 28] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)), bias_ptr); + m[l + 35] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.03125), tmp5), MS_MUL256_N_F32(tmp6, 7.59375)), bias_ptr); + m[l + 42] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.015625), tmp2), + MS_MUL256_N_F32(tmp3, 11.390625)), + t[7 + offset]), + bias_ptr); + } + if (r_c == C8NUM && r_h == 7 && r_w == 7) { + for (int i = 0; i < 7; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 7; + MS_ST256_F32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + MS_ST256_F32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + MS_ST256_F32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + MS_ST256_F32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + MS_ST256_F32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + MS_ST256_F32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + MS_ST256_F32(dst_data + dst_k_offset + 6 * out_c, m[m_k_offset + 6]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 7; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x7ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[56]; + MS_FLOAT32X8 m[49]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)); + t[l + 40] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.03125), tmp5), MS_MUL256_N_F32(tmp6, 7.59375)); + t[l + 48] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.015625), tmp2), MS_MUL256_N_F32(tmp3, 11.390625)), + src[7 + offset]); + } + for (int l = 0; l < 7; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 7] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + bias_ptr); + m[l + 14] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 21] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 28] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)), bias_ptr); + m[l + 35] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.03125), tmp5), MS_MUL256_N_F32(tmp6, 7.59375)), bias_ptr); + m[l + 42] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.015625), tmp2), + MS_MUL256_N_F32(tmp3, 11.390625)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l + 7] = MS_MAX256_F32(zero, m[l + 7]); + m[l + 14] = MS_MAX256_F32(zero, m[l + 14]); + m[l + 21] = MS_MAX256_F32(zero, m[l + 21]); + m[l + 28] = MS_MAX256_F32(zero, m[l + 28]); + m[l + 35] = MS_MAX256_F32(zero, m[l + 35]); + m[l + 42] = MS_MAX256_F32(zero, m[l + 42]); + } + if (r_c == C8NUM && r_h == 7 && r_w == 7) { + for (int i = 0; i < 7; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 7; + MS_ST256_F32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + MS_ST256_F32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + MS_ST256_F32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + MS_ST256_F32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + MS_ST256_F32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + MS_ST256_F32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + MS_ST256_F32(dst_data + dst_k_offset + 6 * out_c, m[m_k_offset + 6]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 7; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} + +void OutputTransform8x7Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X8 src[64]; + MS_FLOAT32X8 t[56]; + MS_FLOAT32X8 m[49]; + MS_FLOAT32X8 zero = MS_MOV256_F32(0); + MS_FLOAT32X8 six = MS_MOV256_F32(6); + LoadAvx64Data; + MS_FLOAT32X8 bias_ptr = MS_LD256_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)); + t[l + 40] = MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.03125), tmp5), MS_MUL256_N_F32(tmp6, 7.59375)); + t[l + 48] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.015625), tmp2), MS_MUL256_N_F32(tmp3, 11.390625)), + src[7 + offset]); + } + for (int l = 0; l < 7; ++l) { + int offset = l * 8; + MS_FLOAT32X8 tmp1 = MS_ADD256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp2 = MS_ADD256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp3 = MS_ADD256_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X8 tmp4 = MS_SUB256_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X8 tmp5 = MS_SUB256_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X8 tmp6 = MS_SUB256_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 7] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.5), tmp5), MS_MUL256_N_F32(tmp6, 1.5)), + bias_ptr); + m[l + 14] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.25), tmp2), MS_MUL256_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 21] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.125), tmp5), MS_MUL256_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 28] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.0625), tmp2), MS_MUL256_N_F32(tmp3, 5.0625)), bias_ptr); + m[l + 35] = MS_ADD256_F32( + MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp4, 0.03125), tmp5), MS_MUL256_N_F32(tmp6, 7.59375)), bias_ptr); + m[l + 42] = MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_ADD256_F32(MS_MUL256_N_F32(tmp1, 0.015625), tmp2), + MS_MUL256_N_F32(tmp3, 11.390625)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAX256_F32(zero, m[l]); + m[l] = MS_MIN256_F32(six, m[l]); + m[l + 7] = MS_MAX256_F32(zero, m[l + 7]); + m[l + 7] = MS_MIN256_F32(six, m[l + 7]); + m[l + 14] = MS_MAX256_F32(zero, m[l + 14]); + m[l + 14] = MS_MIN256_F32(six, m[l + 14]); + m[l + 21] = MS_MAX256_F32(zero, m[l + 21]); + m[l + 21] = MS_MIN256_F32(six, m[l + 21]); + m[l + 28] = MS_MAX256_F32(zero, m[l + 28]); + m[l + 28] = MS_MIN256_F32(six, m[l + 28]); + m[l + 35] = MS_MAX256_F32(zero, m[l + 35]); + m[l + 35] = MS_MIN256_F32(six, m[l + 35]); + m[l + 42] = MS_MAX256_F32(zero, m[l + 42]); + m[l + 42] = MS_MIN256_F32(six, m[l + 42]); + } + if (r_c == C8NUM && r_h == 7 && r_w == 7) { + for (int i = 0; i < 7; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 7; + MS_ST256_F32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + MS_ST256_F32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + MS_ST256_F32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + MS_ST256_F32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + MS_ST256_F32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + MS_ST256_F32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + MS_ST256_F32(dst_data + dst_k_offset + 6 * out_c, m[m_k_offset + 6]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 7; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X8_GETI(m[k + m_k_offset], i); + } + } + } + } +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/winograd_avx.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/winograd_avx.h new file mode 100644 index 0000000000000000000000000000000000000000..d024907fd583e621831b17e69c84c4accc686f13 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/winograd_avx.h @@ -0,0 +1,299 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_AVX +#ifndef MINDSPORE_NNACL_WINOGRAD_AVX_H_ +#define MINDSPORE_NNACL_WINOGRAD_AVX_H_ + +#include "nnacl/conv_parameter.h" +#include "nnacl/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +typedef void (*OutputTransFunc)(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); + +#define LoadAvx16Data \ + src[0] = MS_LD256_F32(src_data + 0 * src_step); \ + src[1] = MS_LD256_F32(src_data + 1 * src_step); \ + src[2] = MS_LD256_F32(src_data + 2 * src_step); \ + src[3] = MS_LD256_F32(src_data + 3 * src_step); \ + src[4] = MS_LD256_F32(src_data + 4 * src_step); \ + src[5] = MS_LD256_F32(src_data + 5 * src_step); \ + src[6] = MS_LD256_F32(src_data + 6 * src_step); \ + src[7] = MS_LD256_F32(src_data + 7 * src_step); \ + src[8] = MS_LD256_F32(src_data + 8 * src_step); \ + src[9] = MS_LD256_F32(src_data + 9 * src_step); \ + src[10] = MS_LD256_F32(src_data + 10 * src_step); \ + src[11] = MS_LD256_F32(src_data + 11 * src_step); \ + src[12] = MS_LD256_F32(src_data + 12 * src_step); \ + src[13] = MS_LD256_F32(src_data + 13 * src_step); \ + src[14] = MS_LD256_F32(src_data + 14 * src_step); \ + src[15] = MS_LD256_F32(src_data + 15 * src_step); + +#define LoadAvx36Data \ + src[0] = MS_LD256_F32(src_data + 0 * src_step); \ + src[1] = MS_LD256_F32(src_data + 1 * src_step); \ + src[2] = MS_LD256_F32(src_data + 2 * src_step); \ + src[3] = MS_LD256_F32(src_data + 3 * src_step); \ + src[4] = MS_LD256_F32(src_data + 4 * src_step); \ + src[5] = MS_LD256_F32(src_data + 5 * src_step); \ + src[6] = MS_LD256_F32(src_data + 6 * src_step); \ + src[7] = MS_LD256_F32(src_data + 7 * src_step); \ + src[8] = MS_LD256_F32(src_data + 8 * src_step); \ + src[9] = MS_LD256_F32(src_data + 9 * src_step); \ + src[10] = MS_LD256_F32(src_data + 10 * src_step); \ + src[11] = MS_LD256_F32(src_data + 11 * src_step); \ + src[12] = MS_LD256_F32(src_data + 12 * src_step); \ + src[13] = MS_LD256_F32(src_data + 13 * src_step); \ + src[14] = MS_LD256_F32(src_data + 14 * src_step); \ + src[15] = MS_LD256_F32(src_data + 15 * src_step); \ + src[16] = MS_LD256_F32(src_data + 16 * src_step); \ + src[17] = MS_LD256_F32(src_data + 17 * src_step); \ + src[18] = MS_LD256_F32(src_data + 18 * src_step); \ + src[19] = MS_LD256_F32(src_data + 19 * src_step); \ + src[20] = MS_LD256_F32(src_data + 20 * src_step); \ + src[21] = MS_LD256_F32(src_data + 21 * src_step); \ + src[22] = MS_LD256_F32(src_data + 22 * src_step); \ + src[23] = MS_LD256_F32(src_data + 23 * src_step); \ + src[24] = MS_LD256_F32(src_data + 24 * src_step); \ + src[25] = MS_LD256_F32(src_data + 25 * src_step); \ + src[26] = MS_LD256_F32(src_data + 26 * src_step); \ + src[27] = MS_LD256_F32(src_data + 27 * src_step); \ + src[28] = MS_LD256_F32(src_data + 28 * src_step); \ + src[29] = MS_LD256_F32(src_data + 29 * src_step); \ + src[30] = MS_LD256_F32(src_data + 30 * src_step); \ + src[31] = MS_LD256_F32(src_data + 31 * src_step); \ + src[32] = MS_LD256_F32(src_data + 32 * src_step); \ + src[33] = MS_LD256_F32(src_data + 33 * src_step); \ + src[34] = MS_LD256_F32(src_data + 34 * src_step); \ + src[35] = MS_LD256_F32(src_data + 35 * src_step); + +#define LoadAvx64Data \ + src[0] = MS_LD256_F32(src_data + 0 * src_step); \ + src[1] = MS_LD256_F32(src_data + 1 * src_step); \ + src[2] = MS_LD256_F32(src_data + 2 * src_step); \ + src[3] = MS_LD256_F32(src_data + 3 * src_step); \ + src[4] = MS_LD256_F32(src_data + 4 * src_step); \ + src[5] = MS_LD256_F32(src_data + 5 * src_step); \ + src[6] = MS_LD256_F32(src_data + 6 * src_step); \ + src[7] = MS_LD256_F32(src_data + 7 * src_step); \ + src[8] = MS_LD256_F32(src_data + 8 * src_step); \ + src[9] = MS_LD256_F32(src_data + 9 * src_step); \ + src[10] = MS_LD256_F32(src_data + 10 * src_step); \ + src[11] = MS_LD256_F32(src_data + 11 * src_step); \ + src[12] = MS_LD256_F32(src_data + 12 * src_step); \ + src[13] = MS_LD256_F32(src_data + 13 * src_step); \ + src[14] = MS_LD256_F32(src_data + 14 * src_step); \ + src[15] = MS_LD256_F32(src_data + 15 * src_step); \ + src[16] = MS_LD256_F32(src_data + 16 * src_step); \ + src[17] = MS_LD256_F32(src_data + 17 * src_step); \ + src[18] = MS_LD256_F32(src_data + 18 * src_step); \ + src[19] = MS_LD256_F32(src_data + 19 * src_step); \ + src[20] = MS_LD256_F32(src_data + 20 * src_step); \ + src[21] = MS_LD256_F32(src_data + 21 * src_step); \ + src[22] = MS_LD256_F32(src_data + 22 * src_step); \ + src[23] = MS_LD256_F32(src_data + 23 * src_step); \ + src[24] = MS_LD256_F32(src_data + 24 * src_step); \ + src[25] = MS_LD256_F32(src_data + 25 * src_step); \ + src[26] = MS_LD256_F32(src_data + 26 * src_step); \ + src[27] = MS_LD256_F32(src_data + 27 * src_step); \ + src[28] = MS_LD256_F32(src_data + 28 * src_step); \ + src[29] = MS_LD256_F32(src_data + 29 * src_step); \ + src[30] = MS_LD256_F32(src_data + 30 * src_step); \ + src[31] = MS_LD256_F32(src_data + 31 * src_step); \ + src[32] = MS_LD256_F32(src_data + 32 * src_step); \ + src[33] = MS_LD256_F32(src_data + 33 * src_step); \ + src[34] = MS_LD256_F32(src_data + 34 * src_step); \ + src[35] = MS_LD256_F32(src_data + 35 * src_step); \ + src[36] = MS_LD256_F32(src_data + 36 * src_step); \ + src[37] = MS_LD256_F32(src_data + 37 * src_step); \ + src[38] = MS_LD256_F32(src_data + 38 * src_step); \ + src[39] = MS_LD256_F32(src_data + 39 * src_step); \ + src[40] = MS_LD256_F32(src_data + 40 * src_step); \ + src[41] = MS_LD256_F32(src_data + 41 * src_step); \ + src[42] = MS_LD256_F32(src_data + 42 * src_step); \ + src[43] = MS_LD256_F32(src_data + 43 * src_step); \ + src[44] = MS_LD256_F32(src_data + 44 * src_step); \ + src[45] = MS_LD256_F32(src_data + 45 * src_step); \ + src[46] = MS_LD256_F32(src_data + 46 * src_step); \ + src[47] = MS_LD256_F32(src_data + 47 * src_step); \ + src[48] = MS_LD256_F32(src_data + 48 * src_step); \ + src[49] = MS_LD256_F32(src_data + 49 * src_step); \ + src[50] = MS_LD256_F32(src_data + 50 * src_step); \ + src[51] = MS_LD256_F32(src_data + 51 * src_step); \ + src[52] = MS_LD256_F32(src_data + 52 * src_step); \ + src[53] = MS_LD256_F32(src_data + 53 * src_step); \ + src[54] = MS_LD256_F32(src_data + 54 * src_step); \ + src[55] = MS_LD256_F32(src_data + 55 * src_step); \ + src[56] = MS_LD256_F32(src_data + 56 * src_step); \ + src[57] = MS_LD256_F32(src_data + 57 * src_step); \ + src[58] = MS_LD256_F32(src_data + 58 * src_step); \ + src[59] = MS_LD256_F32(src_data + 59 * src_step); \ + src[60] = MS_LD256_F32(src_data + 60 * src_step); \ + src[61] = MS_LD256_F32(src_data + 61 * src_step); \ + src[62] = MS_LD256_F32(src_data + 62 * src_step); \ + src[63] = MS_LD256_F32(src_data + 63 * src_step); + +#define StoreAvx4Data \ + MS_ST256_F32(dst_data, m[0]); \ + MS_ST256_F32(dst_data + out_c, m[1]); \ + MS_ST256_F32(dst_data + dst_step * out_c, m[2]); \ + MS_ST256_F32(dst_data + dst_step * out_c + out_c, m[3]); + +#define StoreAvx9Data \ + MS_ST256_F32(dst_data, m[0]); \ + MS_ST256_F32(dst_data + out_c, m[1]); \ + MS_ST256_F32(dst_data + 2 * out_c, m[2]); \ + MS_ST256_F32(dst_data + dst_step * out_c, m[3]); \ + MS_ST256_F32(dst_data + dst_step * out_c + out_c, m[4]); \ + MS_ST256_F32(dst_data + dst_step * out_c + 2 * out_c, m[5]); \ + MS_ST256_F32(dst_data + 2 * dst_step * out_c, m[6]); \ + MS_ST256_F32(dst_data + 2 * dst_step * out_c + out_c, m[7]); \ + MS_ST256_F32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[8]); + +#define StoreAvx16Data \ + MS_ST256_F32(dst_data, m[0]); \ + MS_ST256_F32(dst_data + out_c, m[1]); \ + MS_ST256_F32(dst_data + 2 * out_c, m[2]); \ + MS_ST256_F32(dst_data + 3 * out_c, m[3]); \ + MS_ST256_F32(dst_data + dst_step * out_c, m[4]); \ + MS_ST256_F32(dst_data + dst_step * out_c + out_c, m[5]); \ + MS_ST256_F32(dst_data + dst_step * out_c + 2 * out_c, m[6]); \ + MS_ST256_F32(dst_data + dst_step * out_c + 3 * out_c, m[7]); \ + MS_ST256_F32(dst_data + 2 * dst_step * out_c, m[8]); \ + MS_ST256_F32(dst_data + 2 * dst_step * out_c + out_c, m[9]); \ + MS_ST256_F32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[10]); \ + MS_ST256_F32(dst_data + 2 * dst_step * out_c + 3 * out_c, m[11]); \ + MS_ST256_F32(dst_data + 3 * dst_step * out_c, m[12]); \ + MS_ST256_F32(dst_data + 3 * dst_step * out_c + out_c, m[13]); \ + MS_ST256_F32(dst_data + 3 * dst_step * out_c + 2 * out_c, m[14]); \ + MS_ST256_F32(dst_data + 3 * dst_step * out_c + 3 * out_c, m[15]); + +#define StoreAvx25Data \ + MS_ST256_F32(dst_data, m[0]); \ + MS_ST256_F32(dst_data + out_c, m[1]); \ + MS_ST256_F32(dst_data + 2 * out_c, m[2]); \ + MS_ST256_F32(dst_data + 3 * out_c, m[3]); \ + MS_ST256_F32(dst_data + 4 * out_c, m[4]); \ + MS_ST256_F32(dst_data + dst_step * out_c, m[5]); \ + MS_ST256_F32(dst_data + dst_step * out_c + out_c, m[6]); \ + MS_ST256_F32(dst_data + dst_step * out_c + 2 * out_c, m[7]); \ + MS_ST256_F32(dst_data + dst_step * out_c + 3 * out_c, m[8]); \ + MS_ST256_F32(dst_data + dst_step * out_c + 4 * out_c, m[9]); \ + MS_ST256_F32(dst_data + 2 * dst_step * out_c, m[10]); \ + MS_ST256_F32(dst_data + 2 * dst_step * out_c + out_c, m[11]); \ + MS_ST256_F32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[12]); \ + MS_ST256_F32(dst_data + 2 * dst_step * out_c + 3 * out_c, m[13]); \ + MS_ST256_F32(dst_data + 2 * dst_step * out_c + 4 * out_c, m[14]); \ + MS_ST256_F32(dst_data + 3 * dst_step * out_c, m[15]); \ + MS_ST256_F32(dst_data + 3 * dst_step * out_c + out_c, m[16]); \ + MS_ST256_F32(dst_data + 3 * dst_step * out_c + 2 * out_c, m[17]); \ + MS_ST256_F32(dst_data + 3 * dst_step * out_c + 3 * out_c, m[18]); \ + MS_ST256_F32(dst_data + 3 * dst_step * out_c + 4 * out_c, m[19]); \ + MS_ST256_F32(dst_data + 4 * dst_step * out_c, m[20]); \ + MS_ST256_F32(dst_data + 4 * dst_step * out_c + out_c, m[21]); \ + MS_ST256_F32(dst_data + 4 * dst_step * out_c + 2 * out_c, m[22]); \ + MS_ST256_F32(dst_data + 4 * dst_step * out_c + 3 * out_c, m[23]); \ + MS_ST256_F32(dst_data + 4 * dst_step * out_c + 4 * out_c, m[24]); + +void InputTransform4x4AvxUnit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform6x6AvxUnit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform8x8AvxUnit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c); + +void OutputTransform4x2AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x2ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x2Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x3AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x3ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x3Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); + +void OutputTransform6x2AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x2ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x2Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x3AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x3ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x3Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x4AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x4ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x4Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x5AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x5ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x5Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); + +void OutputTransform8x2AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x2ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x2Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x3AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x3ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x3Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x4AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x4ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x4Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x5AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x5ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x5Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x6ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x6Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x7AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x7ReluAvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x7Relu6AvxUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_WINOGRAD_AVX_H_ +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/winograd_transform.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/winograd_transform.c new file mode 100644 index 0000000000000000000000000000000000000000..9401ddddba5563cfa81d5817bb0b8035bc9302b4 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/winograd_transform.c @@ -0,0 +1,281 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32/winograd_transform.h" +#include "nnacl/op_base.h" +#include "nnacl/intrinsics/ms_simd_instructions.h" + +void PrepareTransInput(const float *src_data, float *dst_data, int interval_x_s, int interval_x_e, int interval_y_s, + int interval_y_e, int real_c, const ConvParameter *conv_param) { + int input_unit = conv_param->input_unit_; + int in_channel = conv_param->input_channel_; + int input_w = conv_param->input_w_; +#ifdef ENABLE_AVX + int channel_tile = C8NUM; +#else + int channel_tile = C4NUM; +#endif + // clear tmp buffer + if (interval_x_e - interval_x_s != input_unit || interval_y_e - interval_y_s != input_unit) { + memset(dst_data, 0, input_unit * input_unit * channel_tile * (int)(sizeof(float))); + } + + // get real input block with padding + if (real_c == channel_tile) { + for (int interval = interval_y_s; interval < interval_y_e; interval++) { + int src_y_offset = (interval * input_w + interval_x_s) * in_channel; + int dst_y_offset = interval * input_unit * channel_tile + interval_x_s * channel_tile; + for (int j = 0; j < (interval_x_e - interval_x_s); j++) { + int src_x_offset = src_y_offset + j * in_channel; + int dst_x_offset = dst_y_offset + j * channel_tile; + const float *src_addr = src_data + src_x_offset; + float *dst_addr = dst_data + dst_x_offset; +#ifdef ENABLE_AVX + MS_ST256_F32(dst_addr, MS_LD256_F32(src_addr)); +#elif defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_STQ_F32(dst_addr, MS_LDQ_F32(src_addr)); +#else + for (int k = 0; k < channel_tile; k++) { + dst_addr[k] = src_addr[k]; + } +#endif + } // interval x loop + } // interval y loop + } else { + for (int interval = interval_y_s; interval < interval_y_e; interval++) { + int src_y_offset = (interval * input_w + interval_x_s) * in_channel; + int dst_y_offset = interval * input_unit * channel_tile + interval_x_s * channel_tile; + for (int j = 0; j < (interval_x_e - interval_x_s); j++) { + int src_x_offset = src_y_offset + j * in_channel; + int dst_x_offset = dst_y_offset + j * channel_tile; + const float *src_addr = src_data + src_x_offset; + float *dst_addr = dst_data + dst_x_offset; + for (int k = 0; k < real_c; k++) { + dst_addr[k] = src_addr[k]; + } + } // interval x loop + } // interval y loop + } +} + +// fp32 conv winograd +void WinogradInputTransform(const float *input_data, float *trans_input, float *tmp_data, int cal_num, + int out_tile_index, int out_w_block_num, const ConvParameter *conv_param, + InputTransFunc func) { + int input_unit = conv_param->input_unit_; + int output_unit = conv_param->output_unit_; + int in_channel = conv_param->input_channel_; +#ifdef ENABLE_AVX + int channel_tile = C8NUM; +#else + int channel_tile = C4NUM; +#endif + int ic4 = UP_DIV(in_channel, channel_tile); + int pad_h = conv_param->pad_u_; + int pad_w = conv_param->pad_l_; + int input_h = conv_param->input_h_; + int input_w = conv_param->input_w_; + NNACL_CHECK_ZERO_RETURN(out_w_block_num); + + for (int c = 0; c < cal_num; c++) { // actual tiled number + int src_x_s = (out_tile_index % out_w_block_num) * output_unit - pad_w; + int src_y_s = (out_tile_index / out_w_block_num) * output_unit - pad_h; + int interval_x_s = src_x_s > 0 ? 0 : -src_x_s; + int interval_y_s = src_y_s > 0 ? 0 : -src_y_s; + int src_x_e = src_x_s + input_unit; + int src_y_e = src_y_s + input_unit; + int interval_x_e = src_x_e < input_w ? input_unit : (input_w - src_x_s); + int interval_y_e = src_y_e < input_h ? input_unit : (input_h - src_y_s); + + int src_plane_offset = in_channel * (src_y_s * input_w + src_x_s); + int dst_plane_offset = c * in_channel; + for (int ic = 0; ic < ic4; ic++) { + int real_c = in_channel - ic * channel_tile; + real_c = real_c > channel_tile ? channel_tile : real_c; + const float *src_data = input_data + src_plane_offset + ic * channel_tile; + PrepareTransInput(src_data, tmp_data, interval_x_s, interval_x_e, interval_y_s, interval_y_e, real_c, conv_param); + + // input transform + const int tile_num = C12NUM; + int dst_ic4_offset = dst_plane_offset + ic * channel_tile; + int dst_step = tile_num * in_channel; + float *trans_input_ptr = trans_input + dst_ic4_offset; + func(tmp_data, trans_input_ptr, channel_tile, dst_step, real_c); + } + out_tile_index++; + } // cal_tile_num loop +} + +// Only support arm64 +void WinogradInputTransformOptStep(const float *input_data, float *trans_input, float *tmp_data, int cal_num, + int out_tile_index, int out_w_block_num, const ConvParameter *conv_param, + InputTransStepFunc func) { + int input_unit = conv_param->input_unit_; + int output_unit = conv_param->output_unit_; + int in_channel = conv_param->input_channel_; + int channel_tile = C4NUM; + int ic4 = UP_DIV(in_channel, channel_tile); + int pad_h = conv_param->pad_u_; + int pad_w = conv_param->pad_l_; + int input_h = conv_param->input_h_; + int input_w = conv_param->input_w_; + NNACL_CHECK_ZERO_RETURN(out_w_block_num); + + for (int c = 0; c < cal_num; c++) { // actual tiled number + int src_x_s = (out_tile_index % out_w_block_num) * output_unit - pad_w; + int src_y_s = (out_tile_index / out_w_block_num) * output_unit - pad_h; + int interval_x_s = src_x_s > 0 ? 0 : -src_x_s; + int interval_y_s = src_y_s > 0 ? 0 : -src_y_s; + int src_x_e = src_x_s + input_unit; + int src_y_e = src_y_s + input_unit; + int interval_x_e = src_x_e < input_w ? input_unit : (input_w - src_x_s); + int interval_y_e = src_y_e < input_h ? input_unit : (input_h - src_y_s); + + int src_plane_offset = in_channel * (src_y_s * input_w + src_x_s); + int dst_plane_offset = c * channel_tile; + for (int ic = 0; ic < ic4; ic++) { + int real_c = in_channel - ic * channel_tile; + real_c = real_c > channel_tile ? channel_tile : real_c; + const float *src_data = input_data + src_plane_offset + ic * channel_tile; + PrepareTransInput(src_data, tmp_data, interval_x_s, interval_x_e, interval_y_s, interval_y_e, real_c, conv_param); + + // input transform + const int block_tile = C12NUM; + int dst_ic8_offset = dst_plane_offset + ic * block_tile * input_unit * input_unit * channel_tile; + size_t dst_step = (size_t)(input_unit * block_tile * channel_tile); + float *trans_input_ptr = trans_input + dst_ic8_offset; + func(tmp_data, trans_input_ptr, channel_tile, dst_step, block_tile * channel_tile); + } + out_tile_index++; + } // cal_tile_num loop +} + +void WinogradOutputNHWCTransform(const float *gemm_out, float *out_data, const float *bias_data, int cal_num, + int out_tile_index, int output_unit_num, const ConvParameter *conv_param, + OutputTransFunc func) { + int output_unit = conv_param->output_unit_; + int output_w = conv_param->output_w_; + int output_h = conv_param->output_h_; + int output_channel = conv_param->output_channel_; +#ifndef ENABLE_AVX + int oc4 = UP_DIV(output_channel, C4NUM); +#endif + int oc8 = UP_DIV(output_channel, C8NUM); + int input_unit = conv_param->input_unit_; + NNACL_CHECK_ZERO_RETURN(output_unit_num); + + for (int i = 0; i < cal_num; i++) { + int dst_x_s = out_tile_index % output_unit_num; + int dst_y_s = out_tile_index / output_unit_num; + int r_w = output_w - dst_x_s * output_unit; + r_w = r_w > output_unit ? output_unit : r_w; + int r_h = output_h - dst_y_s * output_unit; + r_h = r_h > output_unit ? output_unit : r_h; + int tmp_ix = dst_x_s * output_unit; + dst_x_s = tmp_ix > output_w ? output_w : tmp_ix; + int tmp_iy = dst_y_s * output_unit; + dst_y_s = tmp_iy > output_h ? output_h : tmp_iy; + + int src_tile_offset = i * oc8 * C8NUM * input_unit * input_unit; + int dst_tile_offset = output_channel * (dst_x_s + dst_y_s * output_w); +#ifndef ENABLE_AVX + // avx is write to nc4hw4 + for (int j = 0; j < oc4; j++) { + int c8_block = j / 2; + int c8_res = j % 2; + int r_c = output_channel - j * C4NUM; + r_c = r_c > C4NUM ? C4NUM : r_c; + int src_oc4_offset = src_tile_offset + c8_block * input_unit * input_unit * C8NUM + c8_res * C4NUM; + int dst_oc4_offset = dst_tile_offset + j * C4NUM; + const float *src_ptr = gemm_out + src_oc4_offset; + const float *bias_ptr = bias_data + j * C4NUM; + float *dst_ptr = out_data + dst_oc4_offset; + func(src_ptr, dst_ptr, bias_ptr, C8NUM, output_w, output_channel, r_w, r_h, r_c); + } +#else + // avx is write to nc8hw8 + for (int j = 0; j < oc8; j++) { + int r_c = output_channel - j * C8NUM; + r_c = r_c > C8NUM ? C8NUM : r_c; + int src_oc8_offset = src_tile_offset + j * input_unit * input_unit * C8NUM; + int dst_oc8_offset = dst_tile_offset + j * C8NUM; + const float *src_ptr = gemm_out + src_oc8_offset; + const float *bias_ptr = bias_data + j * C8NUM; + float *dst_ptr = out_data + dst_oc8_offset; + func(src_ptr, dst_ptr, bias_ptr, C8NUM, output_w, output_channel, r_w, r_h, r_c); + } +#endif + out_tile_index++; + } +} + +void WinogradOutputNC4HW4Transform(const float *gemm_out, float *out_data, const float *bias_data, int cal_num, + int out_tile_index, int output_unit_num, const ConvParameter *conv_param, + OutputTransFunc func) { + int output_unit = conv_param->output_unit_; + int output_w = conv_param->output_w_; + int output_h = conv_param->output_h_; + int output_plane = output_w * output_h; + int output_channel = conv_param->output_channel_; +#ifndef ENABLE_AVX + int oc4 = UP_DIV(output_channel, C4NUM); +#endif + int oc8 = UP_DIV(output_channel, C8NUM); + int input_unit = conv_param->input_unit_; + NNACL_CHECK_ZERO_RETURN(output_unit_num); + + for (int i = 0; i < cal_num; i++) { + int dst_x_s = out_tile_index % output_unit_num; + int dst_y_s = out_tile_index / output_unit_num; + int r_w = output_w - dst_x_s * output_unit; + r_w = r_w > output_unit ? output_unit : r_w; + int r_h = output_h - dst_y_s * output_unit; + r_h = r_h > output_unit ? output_unit : r_h; + int tmp_ix = dst_x_s * output_unit; + dst_x_s = tmp_ix > output_w ? output_w : tmp_ix; + int tmp_iy = dst_y_s * output_unit; + dst_y_s = tmp_iy > output_h ? output_h : tmp_iy; + + int src_tile_offset = i * oc8 * C8NUM * input_unit * input_unit; + int dst_tile_offset = dst_x_s + dst_y_s * output_w; +#ifdef ENABLE_AVX + for (int j = 0; j < oc8; j++) { + int r_c = output_channel - j * C8NUM; + r_c = r_c > C8NUM ? C8NUM : r_c; + int src_oc8_offset = src_tile_offset + j * input_unit * input_unit * C8NUM; + int dst_oc8_offset = (dst_tile_offset + output_plane * j) * C8NUM; + const float *src_ptr = gemm_out + src_oc8_offset; + const float *bias_ptr = bias_data + j * C8NUM; + float *dst_ptr = out_data + dst_oc8_offset; + func(src_ptr, dst_ptr, bias_ptr, C8NUM, output_w, r_c, r_w, r_h, r_c); + } +#else + for (int j = 0; j < oc4; j++) { + int c8_block = j / 2; + int c8_res = j % 2; + int r_c = output_channel - j * C4NUM; + r_c = r_c > C4NUM ? C4NUM : r_c; + int src_oc4_offset = src_tile_offset + c8_block * input_unit * input_unit * C8NUM + c8_res * C4NUM; + int dst_oc4_offset = (dst_tile_offset + output_plane * j) * C4NUM; + const float *src_ptr = gemm_out + src_oc4_offset; + const float *bias_ptr = bias_data + j * C4NUM; + float *dst_ptr = out_data + dst_oc4_offset; + func(src_ptr, dst_ptr, bias_ptr, C8NUM, output_w, r_c, r_w, r_h, r_c); + } +#endif + out_tile_index++; + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/winograd_transform.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/winograd_transform.h new file mode 100644 index 0000000000000000000000000000000000000000..ab9bf1161f43457d3d94d75710aaa8206dcf3fe4 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/winograd_transform.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_WINOGRAD_TRANSFORM_H_ +#define MINDSPORE_NNACL_WINOGRAD_TRANSFORM_H_ + +#ifdef ENABLE_ARM +#include +#endif +#include +#include "nnacl/pack.h" +#include "nnacl/fp32/winograd_utils.h" + +#ifdef __cplusplus +extern "C" { +#endif +// for fp32 winograd input/output transform +void WinogradInputTransform(const float *input_data, float *trans_input, float *tmp_data, int cal_num, + int out_tile_index, int out_w_block_num, const ConvParameter *conv_param, + InputTransFunc func); + +void WinogradInputTransformOptStep(const float *input_data, float *trans_input, float *tmp_data, int cal_num, + int out_tile_index, int out_w_block_num, const ConvParameter *conv_param, + InputTransStepFunc func); + +void WinogradOutputNHWCTransform(const float *gemm_out, float *out_data, const float *bias_data, int cal_num, + int out_tile_index, int output_unit_num, const ConvParameter *conv_param, + OutputTransFunc func); + +void WinogradOutputNC4HW4Transform(const float *gemm_out, float *out_data, const float *bias_data, int cal_num, + int out_tile_index, int output_unit_num, const ConvParameter *conv_param, + OutputTransFunc func); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_WINOGRAD_TRANSFORM_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/winograd_utils.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/winograd_utils.c new file mode 100644 index 0000000000000000000000000000000000000000..4dc39add2b18e38230b620311a610ac4171ce41a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/winograd_utils.c @@ -0,0 +1,4289 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp32/winograd_utils.h" +#include "nnacl/fp32/winograd_avx.h" +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/base/minimal_filtering_generator.h" +#include "nnacl/base/conv_common_base.h" +#include "nnacl/errorcode.h" + +#ifdef ENABLE_ARM64 +void transpose4(MS_FLOAT32X4 *s0, MS_FLOAT32X4 *s1, MS_FLOAT32X4 *s2, MS_FLOAT32X4 *s3) { + float64x2_t m0 = (float64x2_t)(vtrn1q_f32(*s0, *s1)); + float64x2_t m1 = (float64x2_t)(vtrn2q_f32(*s0, *s1)); + float64x2_t m2 = (float64x2_t)(vtrn1q_f32(*s2, *s3)); + float64x2_t m3 = (float64x2_t)(vtrn2q_f32(*s2, *s3)); + *s0 = (float32x4_t)(vtrn1q_f64(m0, m2)); + *s2 = (float32x4_t)(vtrn2q_f64(m0, m2)); + *s1 = (float32x4_t)(vtrn1q_f64(m1, m3)); + *s3 = (float32x4_t)(vtrn2q_f64(m1, m3)); +} +#endif + +#ifdef ENABLE_AVX +static InputTransFunc InputTransFuncList[] = { + NULL, NULL, NULL, NULL, InputTransform4x4AvxUnit, NULL, InputTransform6x6AvxUnit, NULL, InputTransform8x8AvxUnit}; + +static OutputTransFunc OutputTransFuncList[] = { + OutputTransform4x2AvxUnit, OutputTransform4x3AvxUnit, OutputTransform4x2ReluAvxUnit, + OutputTransform4x3ReluAvxUnit, OutputTransform4x2Relu6AvxUnit, OutputTransform4x3Relu6AvxUnit, + OutputTransform6x2AvxUnit, OutputTransform6x3AvxUnit, OutputTransform6x4AvxUnit, + OutputTransform6x5AvxUnit, OutputTransform6x2ReluAvxUnit, OutputTransform6x3ReluAvxUnit, + OutputTransform6x4ReluAvxUnit, OutputTransform6x5ReluAvxUnit, OutputTransform6x2Relu6AvxUnit, + OutputTransform6x3Relu6AvxUnit, OutputTransform6x4Relu6AvxUnit, OutputTransform6x5Relu6AvxUnit, + OutputTransform8x2AvxUnit, OutputTransform8x3AvxUnit, OutputTransform8x4AvxUnit, + OutputTransform8x5AvxUnit, OutputTransform8x6AvxUnit, OutputTransform8x7AvxUnit, + OutputTransform8x2ReluAvxUnit, OutputTransform8x3ReluAvxUnit, OutputTransform8x4ReluAvxUnit, + OutputTransform8x5ReluAvxUnit, OutputTransform8x6ReluAvxUnit, OutputTransform8x7ReluAvxUnit, + OutputTransform8x2Relu6AvxUnit, OutputTransform8x3Relu6AvxUnit, OutputTransform8x4Relu6AvxUnit, + OutputTransform8x5Relu6AvxUnit, OutputTransform8x6Relu6AvxUnit, OutputTransform8x7Relu6AvxUnit}; +#else +static InputTransFunc InputTransFuncList[] = { + NULL, NULL, NULL, NULL, InputTransform4x4Unit, NULL, InputTransform6x6Unit, NULL, InputTransform8x8Unit}; + +static OutputTransFunc OutputTransFuncList[] = { + OutputTransform4x2Unit, OutputTransform4x3Unit, OutputTransform4x2ReluUnit, OutputTransform4x3ReluUnit, + OutputTransform4x2Relu6Unit, OutputTransform4x3Relu6Unit, OutputTransform6x2Unit, OutputTransform6x3Unit, + OutputTransform6x4Unit, OutputTransform6x5Unit, OutputTransform6x2ReluUnit, OutputTransform6x3ReluUnit, + OutputTransform6x4ReluUnit, OutputTransform6x5ReluUnit, OutputTransform6x2Relu6Unit, OutputTransform6x3Relu6Unit, + OutputTransform6x4Relu6Unit, OutputTransform6x5Relu6Unit, OutputTransform8x2Unit, OutputTransform8x3Unit, + OutputTransform8x4Unit, OutputTransform8x5Unit, OutputTransform8x6Unit, OutputTransform8x7Unit, + OutputTransform8x2ReluUnit, OutputTransform8x3ReluUnit, OutputTransform8x4ReluUnit, OutputTransform8x5ReluUnit, + OutputTransform8x6ReluUnit, OutputTransform8x7ReluUnit, OutputTransform8x2Relu6Unit, OutputTransform8x3Relu6Unit, + OutputTransform8x4Relu6Unit, OutputTransform8x5Relu6Unit, OutputTransform8x6Relu6Unit, OutputTransform8x7Relu6Unit}; +#endif + +InputTransFunc GetInputTransFunc(int input_unit) { return InputTransFuncList[input_unit]; } + +#ifdef ENABLE_ARM64 +static InputTransStepFunc InputTransStepFuncList[] = { + NULL, NULL, NULL, NULL, InputTransform4x4Step, NULL, InputTransform6x6Step, NULL, InputTransform8x8Step}; + +static InputTransPackFunc InputTransPackFuncList[] = { + NULL, NULL, NULL, NULL, InputTransform4x4Pack12, NULL, InputTransform6x6Pack12, NULL, InputTransform8x8Pack12}; + +InputTransStepFunc GetInputTransStepFunc(int input_unit) { return InputTransStepFuncList[input_unit]; } + +InputTransPackFunc GetInputTransPackFunc(int input_unit) { return InputTransPackFuncList[input_unit]; } +#endif + +void InputTransform4x4Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + if (real_c == 4) { + MS_FLOAT32X4 src[16]; + MS_FLOAT32X4 t[16]; + + src[0] = MS_LDQ_F32(src_data); + src[1] = MS_LDQ_F32(src_data + src_step); + src[2] = MS_LDQ_F32(src_data + 2 * src_step); + src[3] = MS_LDQ_F32(src_data + 3 * src_step); + + for (int l = 0; l < 3; ++l) { + int offset = l * 4; + t[l] = MS_SUBQ_F32(src[offset], src[2 + offset]); + src[offset + 4] = MS_LDQ_F32(src_data + (offset + 4) * src_step); + t[4 + l] = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + src[offset + 5] = MS_LDQ_F32(src_data + (offset + 5) * src_step); + t[8 + l] = MS_SUBQ_F32(src[2 + offset], src[1 + offset]); + src[offset + 6] = MS_LDQ_F32(src_data + (offset + 6) * src_step); + t[12 + l] = MS_SUBQ_F32(src[3 + offset], src[1 + offset]); + src[offset + 7] = MS_LDQ_F32(src_data + (offset + 7) * src_step); + } + + int offset = 3 * 4; + t[3] = MS_SUBQ_F32(src[offset], src[2 + offset]); + t[7] = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + t[11] = MS_SUBQ_F32(src[2 + offset], src[1 + offset]); + t[15] = MS_SUBQ_F32(src[3 + offset], src[1 + offset]); + + src[0] = MS_SUBQ_F32(t[0], t[2]); + src[1] = MS_ADDQ_F32(t[1], t[2]); + src[2] = MS_SUBQ_F32(t[2], t[1]); + src[3] = MS_SUBQ_F32(t[3], t[1]); + + for (int l = 1; l < 4; ++l) { + offset = l * 4; + src[offset] = MS_SUBQ_F32(t[offset], t[2 + offset]); + MS_STQ_F32(dst_data + (l - 1) * dst_step, src[offset - 4]); + src[offset + 1] = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_STQ_F32(dst_data + (3 + l) * dst_step, src[offset - 3]); + src[offset + 2] = MS_SUBQ_F32(t[2 + offset], t[1 + offset]); + MS_STQ_F32(dst_data + (7 + l) * dst_step, src[offset - 2]); + src[offset + 3] = MS_SUBQ_F32(t[3 + offset], t[1 + offset]); + MS_STQ_F32(dst_data + (11 + l) * dst_step, src[offset - 1]); + } + + MS_STQ_F32(dst_data + 3 * dst_step, src[12]); + MS_STQ_F32(dst_data + dst_step * 7, src[13]); + MS_STQ_F32(dst_data + dst_step * 11, src[14]); + MS_STQ_F32(dst_data + dst_step * 15, src[15]); + + } else { +#endif + float src[16]; + float t[16]; + float m[16]; + for (int i = 0; i < real_c; ++i) { + for (int j = 0; j < 16; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[offset] - src[2 + offset]; + t[4 + l] = src[1 + offset] + src[2 + offset]; + t[8 + l] = src[2 + offset] - src[1 + offset]; + t[12 + l] = src[3 + offset] - src[1 + offset]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + m[l] = t[offset] - t[2 + offset]; + m[4 + l] = t[1 + offset] + t[2 + offset]; + m[8 + l] = t[2 + offset] - t[1 + offset]; + m[12 + l] = t[3 + offset] - t[1 + offset]; + } + for (int k = 0; k < 16; ++k) { + dst_data[i + k * dst_step] = m[k]; + } + } +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + } +#endif +} + +void InputTransform4x4Step(const float *src_data, float *dst_data, int src_step, int dst_step, int dst_row_step) { +#ifdef ENABLE_ARM64 + for (int l = 0; l < 4; ++l) { + const float *src_ptr = src_data + l * 4 * src_step; + float *dst_ptr = dst_data + l * dst_row_step; + + MS_FLOAT32X4 s0 = MS_LDQ_F32(src_ptr + 0 * src_step); + MS_FLOAT32X4 s1 = MS_LDQ_F32(src_ptr + 1 * src_step); + MS_FLOAT32X4 s2 = MS_LDQ_F32(src_ptr + 2 * src_step); + MS_FLOAT32X4 s3 = MS_LDQ_F32(src_ptr + 3 * src_step); + MS_FLOAT32X4 m0 = MS_SUBQ_F32(s0, s2); + MS_FLOAT32X4 m1 = MS_ADDQ_F32(s1, s2); + MS_FLOAT32X4 m2 = MS_SUBQ_F32(s2, s1); + MS_FLOAT32X4 m3 = MS_SUBQ_F32(s3, s1); + + MS_STQ_F32(dst_ptr + 0 * dst_step, m0); + MS_STQ_F32(dst_ptr + 1 * dst_step, m1); + MS_STQ_F32(dst_ptr + 2 * dst_step, m2); + MS_STQ_F32(dst_ptr + 3 * dst_step, m3); + } +#else + float src[4]; + float m[4]; + for (int i = 0; i < C4NUM; ++i) { + for (int l = 0; l < 4; ++l) { + for (int w = 0; w < 4; ++w) { + int tmp_index = l * 4 + w; + src[w] = src_data[i + tmp_index * src_step]; + } + m[0] = src[0] - src[2]; + m[1] = src[1] + src[2]; + m[2] = src[2] - src[1]; + m[3] = src[3] - src[1]; + + float *dst = dst_data + l * dst_row_step; + for (int w = 0; w < 4; ++w) { + dst[i + w * dst_step] = m[w]; + } + } + } +#endif +} + +#ifdef ENABLE_ARM64 +void InputTransform4x4Pack12Channel(float *src_ptr, float *dst_ptr, int dst_step, int pack_tile, int src_point_stride) { + LOAD_LINE_DATA(0); + LOAD_LINE_DATA(1); + LOAD_LINE_DATA(2); + LOAD_LINE_DATA(3); + + MS_FLOAT32X4 m0 = MS_SUBQ_F32(s00, s20); + MS_FLOAT32X4 m1 = MS_SUBQ_F32(s01, s21); + MS_FLOAT32X4 m2 = MS_SUBQ_F32(s02, s22); + MS_STQ_F32(dst_ptr + 0 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 0 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 0 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32(s10, s20); + m1 = MS_ADDQ_F32(s11, s21); + m2 = MS_ADDQ_F32(s12, s22); + MS_STQ_F32(dst_ptr + 1 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 1 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 1 * dst_step + 2 * pack_tile, m2); + + m0 = MS_SUBQ_F32(s20, s10); + m1 = MS_SUBQ_F32(s21, s11); + m2 = MS_SUBQ_F32(s22, s12); + MS_STQ_F32(dst_ptr + 2 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 2 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 2 * dst_step + 2 * pack_tile, m2); + + m0 = MS_SUBQ_F32(s30, s10); + m1 = MS_SUBQ_F32(s31, s11); + m2 = MS_SUBQ_F32(s32, s12); + MS_STQ_F32(dst_ptr + 3 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 3 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 3 * dst_step + 2 * pack_tile, m2); +} +#endif + +void InputTransform4x4Pack12(float *src_data, float *dst_data, int src_step, int dst_step, int real_c) { + int block_tile = 12; + int pack_tile = src_step; + int src_point_stride = block_tile * pack_tile; +#ifdef ENABLE_ARM64 + for (int l = 0; l < 4; ++l) { + float *src_ptr = src_data + l * C4NUM * block_tile; + TRANSPOSE_12x4; + } + + for (int c = 0; c < real_c; ++c) { + float *src_ptr = src_data + c * block_tile; + float *dst_ptr = dst_data + c * block_tile; + InputTransform4x4Pack12Channel(src_ptr, dst_ptr, dst_step, pack_tile, src_point_stride); + } +#else + for (int l = 0; l < 4; ++l) { + float *src = src_data + l * pack_tile * block_tile; + // 12 * 4 -> 4 * 12 + float tmp_mat[4][12]; + for (int i = 0; i < block_tile; ++i) { + for (int j = 0; j < pack_tile; ++j) { + tmp_mat[j][i] = src[i * pack_tile + j]; + } + } + memcpy(src, tmp_mat, pack_tile * block_tile * sizeof(float)); + } + + float src[4]; + float m[4]; + for (int c = 0; c < real_c; ++c) { + for (int i = 0; i < block_tile; ++i) { + int tmp_index = c * block_tile + i; + for (int w = 0; w < 4; ++w) { + src[w] = src_data[tmp_index + w * src_point_stride]; + } + + m[0] = src[0] - src[2]; + m[1] = src[1] + src[2]; + m[2] = src[2] - src[1]; + m[3] = src[3] - src[1]; + + for (int w = 0; w < 4; ++w) { + dst_data[tmp_index + w * dst_step] = m[w]; + } + } + } +#endif +} + +void InputTransform6x6Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + if (real_c == 4) { + MS_FLOAT32X4 src[36]; + MS_FLOAT32X4 t[36]; + MS_FLOAT32X4 m[36]; + Load36Data; + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_SUBQ_F32(src[3 + offset], src[1 + offset]); + MS_FLOAT32X4 tmp2 = MS_SUBQ_F32(src[4 + offset], src[2 + offset]); + t[l] = + MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(src[offset], 4), MS_MULQ_N_F32(src[2 + offset], 5)), src[4 + offset]); + t[6 + l] = MS_ADDQ_F32(MS_MULQ_N_F32(MS_ADDQ_F32(src[1 + offset], src[2 + offset]), -4), + MS_ADDQ_F32(src[3 + offset], src[4 + offset])); + t[12 + l] = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(src[1 + offset], src[2 + offset]), 4), + MS_SUBQ_F32(src[4 + offset], src[3 + offset])); + t[18 + l] = MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 2), tmp2); + t[24 + l] = MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, -2), tmp2); + t[30 + l] = + MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(src[1 + offset], 4), MS_MULQ_N_F32(src[3 + offset], 5)), src[5 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_SUBQ_F32(t[3 + offset], t[1 + offset]); + MS_FLOAT32X4 tmp2 = MS_SUBQ_F32(t[4 + offset], t[2 + offset]); + m[l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(t[offset], 4), MS_MULQ_N_F32(t[2 + offset], 5)), t[4 + offset]); + m[6 + l] = MS_ADDQ_F32(MS_MULQ_N_F32(MS_ADDQ_F32(t[1 + offset], t[2 + offset]), -4), + MS_ADDQ_F32(t[3 + offset], t[4 + offset])); + m[12 + l] = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), 4), + MS_SUBQ_F32(t[4 + offset], t[3 + offset])); + m[18 + l] = MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 2), tmp2); + m[24 + l] = MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, -2), tmp2); + m[30 + l] = + MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(t[1 + offset], 4), MS_MULQ_N_F32(t[3 + offset], 5)), t[5 + offset]); + } + for (int i = 0; i < 36; i++) { + MS_STQ_F32(dst_data + i * dst_step, m[i]); + } + } else { +#endif + float src[36]; + float t[36]; + float m[36]; + for (int i = 0; i < real_c; ++i) { + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float tmp1 = src[3 + offset] - src[1 + offset]; + float tmp2 = src[4 + offset] - src[2 + offset]; + t[l] = 4 * src[offset] - 5 * src[2 + offset] + src[4 + offset]; + t[6 + l] = -4 * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]); + t[12 + l] = 4 * (src[1 + offset] - src[2 + offset]) + (src[4 + offset] - src[3 + offset]); + t[18 + l] = 2 * tmp1 + tmp2; + t[24 + l] = -2 * tmp1 + tmp2; + t[30 + l] = 4 * src[1 + offset] - 5 * src[3 + offset] + src[5 + offset]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float tmp1 = t[3 + offset] - t[1 + offset]; + float tmp2 = t[4 + offset] - t[2 + offset]; + m[l] = 4 * t[offset] - 5 * t[2 + offset] + t[4 + offset]; + m[6 + l] = -4 * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]); + m[12 + l] = 4 * (t[1 + offset] - t[2 + offset]) + (t[4 + offset] - t[3 + offset]); + m[18 + l] = 2 * tmp1 + tmp2; + m[24 + l] = -2 * tmp1 + tmp2; + m[30 + l] = 4 * t[1 + offset] - 5 * t[3 + offset] + t[5 + offset]; + } + for (int k = 0; k < 36; ++k) { + dst_data[i + k * dst_step] = m[k]; + } + } +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + } +#endif +} + +void InputTransform6x6Step(const float *src_data, float *dst_data, int src_step, int dst_step, int dst_row_step) { +#ifdef ENABLE_ARM64 + for (int l = 0; l < 6; ++l) { + const float *src_ptr = src_data + l * 6 * src_step; + float *dst_ptr = dst_data + l * dst_row_step; + + MS_FLOAT32X4 s0 = MS_LDQ_F32(src_ptr + 0 * src_step); + MS_FLOAT32X4 s1 = MS_LDQ_F32(src_ptr + 1 * src_step); + MS_FLOAT32X4 s2 = MS_LDQ_F32(src_ptr + 2 * src_step); + MS_FLOAT32X4 s3 = MS_LDQ_F32(src_ptr + 3 * src_step); + MS_FLOAT32X4 s4 = MS_LDQ_F32(src_ptr + 4 * src_step); + MS_FLOAT32X4 s5 = MS_LDQ_F32(src_ptr + 5 * src_step); + + MS_FLOAT32X4 tmp1 = MS_SUBQ_F32(s3, s1); + MS_FLOAT32X4 tmp2 = MS_SUBQ_F32(s4, s2); + MS_FLOAT32X4 m0 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s0, 4), MS_MULQ_N_F32(s2, 5)), s4); + MS_FLOAT32X4 m1 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_ADDQ_F32(s1, s2), -4), MS_ADDQ_F32(s3, s4)); + MS_FLOAT32X4 m2 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s1, s2), 4), MS_SUBQ_F32(s4, s3)); + MS_FLOAT32X4 m3 = MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 2), tmp2); + MS_FLOAT32X4 m4 = MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, -2), tmp2); + MS_FLOAT32X4 m5 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s1, 4), MS_MULQ_N_F32(s3, 5)), s5); + + MS_STQ_F32(dst_ptr + 0 * dst_step, m0); + MS_STQ_F32(dst_ptr + 1 * dst_step, m1); + MS_STQ_F32(dst_ptr + 2 * dst_step, m2); + MS_STQ_F32(dst_ptr + 3 * dst_step, m3); + MS_STQ_F32(dst_ptr + 4 * dst_step, m4); + MS_STQ_F32(dst_ptr + 5 * dst_step, m5); + } +#else + float src[6]; + float m[6]; + for (int i = 0; i < C4NUM; ++i) { + for (int l = 0; l < 6; ++l) { + for (int w = 0; w < 6; ++w) { + int tmp_index = l * 6 + w; + src[w] = src_data[i + tmp_index * src_step]; + } + float tmp1 = src[3] - src[1]; + float tmp2 = src[4] - src[2]; + m[0] = 4 * src[0] - 5 * src[2] + src[4]; + m[1] = -4 * (src[1] + src[2]) + (src[3] + src[4]); + m[2] = 4 * (src[1] - src[2]) + (src[4] - src[3]); + m[3] = 2 * tmp1 + tmp2; + m[4] = -2 * tmp1 + tmp2; + m[5] = 4 * src[1] - 5 * src[3] + src[5]; + + float *dst = dst_data + l * dst_row_step; + for (int w = 0; w < 6; ++w) { + dst[i + w * dst_step] = m[w]; + } + } + } +#endif +} + +#ifdef ENABLE_ARM64 +void InputTransform6x6Pack12Channel(float *src_ptr, float *dst_ptr, int dst_step, int pack_tile, int src_point_stride) { + LOAD_LINE_DATA(0); + LOAD_LINE_DATA(1); + LOAD_LINE_DATA(2); + LOAD_LINE_DATA(3); + LOAD_LINE_DATA(4); + LOAD_LINE_DATA(5); + + MS_FLOAT32X4 m0 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s00, 4), MS_MULQ_N_F32(s20, 5)), s40); + MS_FLOAT32X4 m1 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s01, 4), MS_MULQ_N_F32(s21, 5)), s41); + MS_FLOAT32X4 m2 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s02, 4), MS_MULQ_N_F32(s22, 5)), s42); + MS_STQ_F32(dst_ptr + 0 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 0 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 0 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_ADDQ_F32(s10, s20), -4), MS_ADDQ_F32(s30, s40)); + m1 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_ADDQ_F32(s11, s21), -4), MS_ADDQ_F32(s31, s41)); + m2 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_ADDQ_F32(s12, s22), -4), MS_ADDQ_F32(s32, s42)); + MS_STQ_F32(dst_ptr + 1 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 1 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 1 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s10, s20), 4), MS_SUBQ_F32(s40, s30)); + m1 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s11, s21), 4), MS_SUBQ_F32(s41, s31)); + m2 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s12, s22), 4), MS_SUBQ_F32(s42, s32)); + MS_STQ_F32(dst_ptr + 2 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 2 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 2 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s30, s10), 2), MS_SUBQ_F32(s40, s20)); + m1 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s31, s11), 2), MS_SUBQ_F32(s41, s21)); + m2 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s32, s12), 2), MS_SUBQ_F32(s42, s22)); + MS_STQ_F32(dst_ptr + 3 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 3 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 3 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s30, s10), -2), MS_SUBQ_F32(s40, s20)); + m1 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s31, s11), -2), MS_SUBQ_F32(s41, s21)); + m2 = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(s32, s12), -2), MS_SUBQ_F32(s42, s22)); + MS_STQ_F32(dst_ptr + 4 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 4 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 4 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s10, 4), MS_MULQ_N_F32(s30, 5)), s50); + m1 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s11, 4), MS_MULQ_N_F32(s31, 5)), s51); + m2 = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s12, 4), MS_MULQ_N_F32(s32, 5)), s52); + MS_STQ_F32(dst_ptr + 5 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 5 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 5 * dst_step + 2 * pack_tile, m2); +} +#endif + +void InputTransform6x6Pack12(float *src_data, float *dst_data, int src_step, int dst_step, int real_c) { + int block_tile = 12; + int pack_tile = src_step; + int src_point_stride = block_tile * pack_tile; +#ifdef ENABLE_ARM64 + for (int l = 0; l < 6; ++l) { + float *src_ptr = src_data + l * C4NUM * block_tile; + TRANSPOSE_12x4; + } + + for (int c = 0; c < real_c; ++c) { + float *src_ptr = src_data + c * block_tile; + float *dst_ptr = dst_data + c * block_tile; + InputTransform6x6Pack12Channel(src_ptr, dst_ptr, dst_step, pack_tile, src_point_stride); + } +#else + for (int l = 0; l < 6; ++l) { + float *src = src_data + l * pack_tile * block_tile; + // 12 * 4 -> 4 * 12 + float tmp_mat[4][12]; + for (int i = 0; i < block_tile; ++i) { + for (int j = 0; j < pack_tile; ++j) { + tmp_mat[j][i] = src[i * pack_tile + j]; + } + } + memcpy(src, tmp_mat, pack_tile * block_tile * sizeof(float)); + } + + float src[6]; + float m[6]; + for (int c = 0; c < real_c; ++c) { + for (int i = 0; i < block_tile; ++i) { + int tmp_index = c * block_tile + i; + for (int w = 0; w < 6; ++w) { + src[w] = src_data[tmp_index + w * src_point_stride]; + } + + float tmp1 = src[3] - src[1]; + float tmp2 = src[4] - src[2]; + m[0] = 4 * src[0] - 5 * src[2] + src[4]; + m[1] = -4 * (src[1] + src[2]) + (src[3] + src[4]); + m[2] = 4 * (src[1] - src[2]) + (src[4] - src[3]); + m[3] = 2 * tmp1 + tmp2; + m[4] = -2 * tmp1 + tmp2; + m[5] = 4 * src[1] - 5 * src[3] + src[5]; + + for (int w = 0; w < 6; ++w) { + dst_data[tmp_index + w * dst_step] = m[w]; + } + } + } +#endif +} + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) +void InputTransform8x8Unit_block4(const float *src_data, float *dst_data, int src_step, int dst_step) { + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[64]; + MS_FLOAT32X4 m[64]; + Load64Data; + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = + MS_SUBQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(src[offset], 0.5625), MS_MULQ_N_F32(src[2 + offset], 3.0625)), + MS_MULQ_N_F32(src[4 + offset], 3.5)), + src[6 + offset]); + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(src[1 + offset], 1.125), MS_MULQ_N_F32(src[5 + offset], 0.5)); + MS_FLOAT32X4 tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(src[2 + offset], 2.25), MS_MULQ_N_F32(src[4 + offset], 3.25)); + t[8 + l] = + MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(src[3 + offset], 1.625)), src[6 + offset]); + t[16 + l] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(src[3 + offset], 1.625)), src[6 + offset]); + tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(src[1 + offset], 0.5625), src[5 + offset]); + tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(src[2 + offset], 0.5625), MS_MULQ_N_F32(src[4 + offset], 2.5)); + t[24 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(src[3 + offset], 2.5)), src[6 + offset]); + t[32 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(src[3 + offset], 2.5)), src[6 + offset]); + tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(src[1 + offset], 0.375), MS_MULQ_N_F32(src[5 + offset], 1.5)); + tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(src[2 + offset], 0.25), MS_MULQ_N_F32(src[4 + offset], 1.25)); + t[40 + l] = + MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(src[3 + offset], 1.875)), src[6 + offset]); + t[48 + l] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(src[3 + offset], 1.875)), src[6 + offset]); + t[56 + l] = MS_ADDQ_F32( + MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(src[1 + offset], -0.5625), MS_MULQ_N_F32(src[3 + offset], 3.0625)), + MS_MULQ_N_F32(src[5 + offset], 3.5)), + src[7 + offset]); + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + m[l] = MS_SUBQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(t[offset], 0.5625), MS_MULQ_N_F32(t[2 + offset], 3.0625)), + MS_MULQ_N_F32(t[4 + offset], 3.5)), + t[6 + offset]); + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(t[1 + offset], 1.125), MS_MULQ_N_F32(t[5 + offset], 0.5)); + MS_FLOAT32X4 tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(t[2 + offset], 2.25), MS_MULQ_N_F32(t[4 + offset], 3.25)); + m[8 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(t[3 + offset], 1.625)), t[6 + offset]); + m[16 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(t[3 + offset], 1.625)), t[6 + offset]); + tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(t[1 + offset], 0.5625), t[5 + offset]); + tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(t[2 + offset], 0.5625), MS_MULQ_N_F32(t[4 + offset], 2.5)); + m[24 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(t[3 + offset], 2.5)), t[6 + offset]); + m[32 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(t[3 + offset], 2.5)), t[6 + offset]); + tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(t[1 + offset], 0.375), MS_MULQ_N_F32(t[5 + offset], 1.5)); + tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(t[2 + offset], 0.25), MS_MULQ_N_F32(t[4 + offset], 1.25)); + m[40 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(t[3 + offset], 1.875)), t[6 + offset]); + m[48 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(t[3 + offset], 1.875)), t[6 + offset]); + m[56 + l] = + MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(t[1 + offset], -0.5625), MS_MULQ_N_F32(t[3 + offset], 3.0625)), + MS_MULQ_N_F32(t[5 + offset], 3.5)), + t[7 + offset]); + } + for (int i = 0; i < 64; i++) { + MS_STQ_F32(dst_data + i * dst_step, m[i]); + } +} +#endif + +void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + if (real_c == 4) { + InputTransform8x8Unit_block4(src_data, dst_data, src_step, dst_step); + } else { +#endif + float src[64]; + float t[64]; + float m[64]; + for (int i = 0; i < real_c; ++i) { + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = 0.5625f * src[offset] - 3.0625f * src[2 + offset] + 3.5f * src[4 + offset] - src[6 + offset]; + float tmp1 = 1.125f * src[1 + offset] + 0.5f * src[5 + offset]; + float tmp2 = 2.25f * src[2 + offset] - 3.25f * src[4 + offset]; + t[8 + l] = tmp1 + tmp2 - 1.625f * src[3 + offset] + src[6 + offset]; + t[16 + l] = tmp2 - tmp1 + 1.625f * src[3 + offset] + src[6 + offset]; + tmp1 = 0.5625f * src[1 + offset] + src[5 + offset]; + tmp2 = 0.5625f * src[2 + offset] - 2.5f * src[4 + offset]; + t[24 + l] = tmp1 + tmp2 - 2.5f * src[3 + offset] + src[6 + offset]; + t[32 + l] = tmp2 - tmp1 + 2.5f * src[3 + offset] + src[6 + offset]; + tmp1 = 0.375f * src[1 + offset] + 1.5f * src[5 + offset]; + tmp2 = 0.25f * src[2 + offset] - 1.25f * src[4 + offset]; + t[40 + l] = tmp1 + tmp2 - 1.875f * src[3 + offset] + src[6 + offset]; + t[48 + l] = tmp2 - tmp1 + 1.875f * src[3 + offset] + src[6 + offset]; + t[56 + l] = -0.5625f * src[1 + offset] + 3.0625f * src[3 + offset] - 3.5f * src[5 + offset] + src[7 + offset]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + m[l] = 0.5625f * t[offset] - 3.0625f * t[2 + offset] + 3.5f * t[4 + offset] - t[6 + offset]; + float tmp1 = 1.125f * t[1 + offset] + 0.5f * t[5 + offset]; + float tmp2 = 2.25f * t[2 + offset] - 3.25f * t[4 + offset]; + m[8 + l] = tmp1 + tmp2 - 1.625f * t[3 + offset] + t[6 + offset]; + m[16 + l] = tmp2 - tmp1 + 1.625f * t[3 + offset] + t[6 + offset]; + tmp1 = 0.5625f * t[1 + offset] + t[5 + offset]; + tmp2 = 0.5625f * t[2 + offset] - 2.5f * t[4 + offset]; + m[24 + l] = tmp1 + tmp2 - 2.5f * t[3 + offset] + t[6 + offset]; + m[32 + l] = tmp2 - tmp1 + 2.5f * t[3 + offset] + t[6 + offset]; + tmp1 = 0.375f * t[1 + offset] + 1.5f * t[5 + offset]; + tmp2 = 0.25f * t[2 + offset] - 1.25f * t[4 + offset]; + m[40 + l] = tmp1 + tmp2 - 1.875f * t[3 + offset] + t[6 + offset]; + m[48 + l] = tmp2 - tmp1 + 1.875f * t[3 + offset] + t[6 + offset]; + m[56 + l] = -0.5625f * t[1 + offset] + 3.0625f * t[3 + offset] - 3.5f * t[5 + offset] + t[7 + offset]; + } + for (int k = 0; k < 64; ++k) { + dst_data[i + k * dst_step] = m[k]; + } + } +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + } +#endif +} + +void InputTransform8x8Step(const float *src_data, float *dst_data, int src_step, int dst_step, int dst_row_step) { +#ifdef ENABLE_ARM64 + for (int l = 0; l < 8; ++l) { + const float *src_ptr = src_data + l * 8 * src_step; + float *dst_ptr = dst_data + l * dst_row_step; + + MS_FLOAT32X4 s0 = MS_LDQ_F32(src_ptr + 0 * src_step); + MS_FLOAT32X4 s1 = MS_LDQ_F32(src_ptr + 1 * src_step); + MS_FLOAT32X4 s2 = MS_LDQ_F32(src_ptr + 2 * src_step); + MS_FLOAT32X4 s3 = MS_LDQ_F32(src_ptr + 3 * src_step); + MS_FLOAT32X4 s4 = MS_LDQ_F32(src_ptr + 4 * src_step); + MS_FLOAT32X4 s5 = MS_LDQ_F32(src_ptr + 5 * src_step); + MS_FLOAT32X4 s6 = MS_LDQ_F32(src_ptr + 6 * src_step); + MS_FLOAT32X4 s7 = MS_LDQ_F32(src_ptr + 7 * src_step); + + MS_FLOAT32X4 m0 = MS_SUBQ_F32( + MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s0, 0.5625), MS_MULQ_N_F32(s2, 3.0625)), MS_MULQ_N_F32(s4, 3.5)), s6); + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(s1, 1.125), MS_MULQ_N_F32(s5, 0.5)); + MS_FLOAT32X4 tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(s2, 2.25), MS_MULQ_N_F32(s4, 3.25)); + MS_FLOAT32X4 m1 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(s3, 1.625)), s6); + MS_FLOAT32X4 m2 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(s3, 1.625)), s6); + tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(s1, 0.5625), s5); + tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(s2, 0.5625), MS_MULQ_N_F32(s4, 2.5)); + MS_FLOAT32X4 m3 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(s3, 2.5)), s6); + MS_FLOAT32X4 m4 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(s3, 2.5)), s6); + tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(s1, 0.375), MS_MULQ_N_F32(s5, 1.5)); + tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(s2, 0.25), MS_MULQ_N_F32(s4, 1.25)); + MS_FLOAT32X4 m5 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(s3, 1.875)), s6); + MS_FLOAT32X4 m6 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(s3, 1.875)), s6); + MS_FLOAT32X4 m7 = MS_ADDQ_F32( + MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(s1, -0.5625), MS_MULQ_N_F32(s3, 3.0625)), MS_MULQ_N_F32(s5, 3.5)), s7); + + MS_STQ_F32(dst_ptr + 0 * dst_step, m0); + MS_STQ_F32(dst_ptr + 1 * dst_step, m1); + MS_STQ_F32(dst_ptr + 2 * dst_step, m2); + MS_STQ_F32(dst_ptr + 3 * dst_step, m3); + MS_STQ_F32(dst_ptr + 4 * dst_step, m4); + MS_STQ_F32(dst_ptr + 5 * dst_step, m5); + MS_STQ_F32(dst_ptr + 6 * dst_step, m6); + MS_STQ_F32(dst_ptr + 7 * dst_step, m7); + } +#else + float src[8]; + float m[8]; + for (int i = 0; i < C4NUM; ++i) { + for (int l = 0; l < 8; ++l) { + for (int w = 0; w < 8; ++w) { + int tmp_index = l * 8 + w; + src[w] = src_data[i + tmp_index * src_step]; + } + m[0] = 0.5625f * src[0] - 3.0625f * src[2] + 3.5f * src[4] - src[6]; + float tmp1 = 1.125f * src[1] + 0.5f * src[5]; + float tmp2 = 2.25f * src[2] - 3.25f * src[4]; + m[1] = tmp1 + tmp2 - 1.625f * src[3] + src[6]; + m[2] = tmp2 - tmp1 + 1.625f * src[3] + src[6]; + tmp1 = 0.5625f * src[1] + src[5]; + tmp2 = 0.5625f * src[2] - 2.5f * src[4]; + m[3] = tmp1 + tmp2 - 2.5f * src[3] + src[6]; + m[4] = tmp2 - tmp1 + 2.5f * src[3] + src[6]; + tmp1 = 0.375f * src[1] + 1.5f * src[5]; + tmp2 = 0.25f * src[2] - 1.25f * src[4]; + m[5] = tmp1 + tmp2 - 1.875f * src[3] + src[6]; + m[6] = tmp2 - tmp1 + 1.875f * src[3] + src[6]; + m[7] = -0.5625f * src[1] + 3.0625f * src[3] - 3.5f * src[5] + src[7]; + + float *dst = dst_data + l * dst_row_step; + for (int w = 0; w < 8; ++w) { + dst[i + w * dst_step] = m[w]; + } + } + } +#endif +} + +#ifdef ENABLE_ARM64 +void InputTransform8x8Pack12Channel(float *src_ptr, float *dst_ptr, int dst_step, int pack_tile, int src_point_stride) { + LOAD_LINE_DATA(0); + LOAD_LINE_DATA(1); + LOAD_LINE_DATA(2); + LOAD_LINE_DATA(3); + LOAD_LINE_DATA(4); + LOAD_LINE_DATA(5); + LOAD_LINE_DATA(6); + LOAD_LINE_DATA(7); + + MS_FLOAT32X4 m0 = MS_SUBQ_F32( + MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s00, 0.5625), MS_MULQ_N_F32(s20, 3.0625)), MS_MULQ_N_F32(s40, 3.5)), s60); + MS_FLOAT32X4 m1 = MS_SUBQ_F32( + MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s01, 0.5625), MS_MULQ_N_F32(s21, 3.0625)), MS_MULQ_N_F32(s41, 3.5)), s61); + MS_FLOAT32X4 m2 = MS_SUBQ_F32( + MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(s02, 0.5625), MS_MULQ_N_F32(s22, 3.0625)), MS_MULQ_N_F32(s42, 3.5)), s62); + MS_STQ_F32(dst_ptr + 0 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 0 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 0 * dst_step + 2 * pack_tile, m2); + + MS_FLOAT32X4 tmp10 = MS_ADDQ_F32(MS_MULQ_N_F32(s10, 1.125), MS_MULQ_N_F32(s50, 0.5)); + MS_FLOAT32X4 tmp11 = MS_ADDQ_F32(MS_MULQ_N_F32(s11, 1.125), MS_MULQ_N_F32(s51, 0.5)); + MS_FLOAT32X4 tmp12 = MS_ADDQ_F32(MS_MULQ_N_F32(s12, 1.125), MS_MULQ_N_F32(s52, 0.5)); + MS_FLOAT32X4 tmp20 = MS_SUBQ_F32(MS_MULQ_N_F32(s20, 2.25), MS_MULQ_N_F32(s40, 3.25)); + MS_FLOAT32X4 tmp21 = MS_SUBQ_F32(MS_MULQ_N_F32(s21, 2.25), MS_MULQ_N_F32(s41, 3.25)); + MS_FLOAT32X4 tmp22 = MS_SUBQ_F32(MS_MULQ_N_F32(s22, 2.25), MS_MULQ_N_F32(s42, 3.25)); + m0 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp10, tmp20), MS_MULQ_N_F32(s30, 1.625)), s60); + m1 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp11, tmp21), MS_MULQ_N_F32(s31, 1.625)), s61); + m2 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp12, tmp22), MS_MULQ_N_F32(s32, 1.625)), s62); + MS_STQ_F32(dst_ptr + 1 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 1 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 1 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp20, tmp10), MS_MULQ_N_F32(s30, 1.625)), s60); + m1 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp21, tmp11), MS_MULQ_N_F32(s31, 1.625)), s61); + m2 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp22, tmp12), MS_MULQ_N_F32(s32, 1.625)), s62); + MS_STQ_F32(dst_ptr + 2 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 2 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 2 * dst_step + 2 * pack_tile, m2); + + tmp10 = MS_ADDQ_F32(MS_MULQ_N_F32(s10, 0.5625), s50); + tmp11 = MS_ADDQ_F32(MS_MULQ_N_F32(s11, 0.5625), s51); + tmp12 = MS_ADDQ_F32(MS_MULQ_N_F32(s12, 0.5625), s52); + tmp20 = MS_SUBQ_F32(MS_MULQ_N_F32(s20, 0.5625), MS_MULQ_N_F32(s40, 2.5)); + tmp21 = MS_SUBQ_F32(MS_MULQ_N_F32(s21, 0.5625), MS_MULQ_N_F32(s41, 2.5)); + tmp22 = MS_SUBQ_F32(MS_MULQ_N_F32(s22, 0.5625), MS_MULQ_N_F32(s42, 2.5)); + m0 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp10, tmp20), MS_MULQ_N_F32(s30, 2.5)), s60); + m1 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp11, tmp21), MS_MULQ_N_F32(s31, 2.5)), s61); + m2 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp12, tmp22), MS_MULQ_N_F32(s32, 2.5)), s62); + MS_STQ_F32(dst_ptr + 3 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 3 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 3 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp20, tmp10), MS_MULQ_N_F32(s30, 2.5)), s60); + m1 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp21, tmp11), MS_MULQ_N_F32(s31, 2.5)), s61); + m2 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp22, tmp12), MS_MULQ_N_F32(s32, 2.5)), s62); + MS_STQ_F32(dst_ptr + 4 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 4 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 4 * dst_step + 2 * pack_tile, m2); + + tmp10 = MS_ADDQ_F32(MS_MULQ_N_F32(s10, 0.375), MS_MULQ_N_F32(s50, 1.5)); + tmp11 = MS_ADDQ_F32(MS_MULQ_N_F32(s11, 0.375), MS_MULQ_N_F32(s51, 1.5)); + tmp12 = MS_ADDQ_F32(MS_MULQ_N_F32(s12, 0.375), MS_MULQ_N_F32(s52, 1.5)); + tmp20 = MS_SUBQ_F32(MS_MULQ_N_F32(s20, 0.25), MS_MULQ_N_F32(s40, 1.25)); + tmp21 = MS_SUBQ_F32(MS_MULQ_N_F32(s21, 0.25), MS_MULQ_N_F32(s41, 1.25)); + tmp22 = MS_SUBQ_F32(MS_MULQ_N_F32(s22, 0.25), MS_MULQ_N_F32(s42, 1.25)); + m0 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp10, tmp20), MS_MULQ_N_F32(s30, 1.875)), s60); + m1 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp11, tmp21), MS_MULQ_N_F32(s31, 1.875)), s61); + m2 = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp12, tmp22), MS_MULQ_N_F32(s32, 1.875)), s62); + MS_STQ_F32(dst_ptr + 5 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 5 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 5 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp20, tmp10), MS_MULQ_N_F32(s30, 1.875)), s60); + m1 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp21, tmp11), MS_MULQ_N_F32(s31, 1.875)), s61); + m2 = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp22, tmp12), MS_MULQ_N_F32(s32, 1.875)), s62); + MS_STQ_F32(dst_ptr + 6 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 6 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 6 * dst_step + 2 * pack_tile, m2); + + m0 = MS_ADDQ_F32( + MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(s10, -0.5625), MS_MULQ_N_F32(s30, 3.0625)), MS_MULQ_N_F32(s50, 3.5)), s70); + m1 = MS_ADDQ_F32( + MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(s11, -0.5625), MS_MULQ_N_F32(s31, 3.0625)), MS_MULQ_N_F32(s51, 3.5)), s71); + m2 = MS_ADDQ_F32( + MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(s12, -0.5625), MS_MULQ_N_F32(s32, 3.0625)), MS_MULQ_N_F32(s52, 3.5)), s72); + MS_STQ_F32(dst_ptr + 7 * dst_step + 0 * pack_tile, m0); + MS_STQ_F32(dst_ptr + 7 * dst_step + 1 * pack_tile, m1); + MS_STQ_F32(dst_ptr + 7 * dst_step + 2 * pack_tile, m2); +} +#endif + +void InputTransform8x8Pack12(float *src_data, float *dst_data, int src_step, int dst_step, int real_c) { + int block_tile = 12; + int pack_tile = src_step; + int src_point_stride = block_tile * pack_tile; +#ifdef ENABLE_ARM64 + for (int l = 0; l < 8; ++l) { + float *src_ptr = src_data + l * C4NUM * block_tile; + TRANSPOSE_12x4; + } + + for (int c = 0; c < real_c; ++c) { + float *src_ptr = src_data + c * block_tile; + float *dst_ptr = dst_data + c * block_tile; + InputTransform8x8Pack12Channel(src_ptr, dst_ptr, dst_step, pack_tile, src_point_stride); + } +#else + for (int l = 0; l < 8; ++l) { + float *src = src_data + l * pack_tile * block_tile; + // 12 * 4 -> 4 * 12 + float tmp_mat[4][12]; + for (int i = 0; i < block_tile; ++i) { + for (int j = 0; j < pack_tile; ++j) { + tmp_mat[j][i] = src[i * pack_tile + j]; + } + } + memcpy(src, tmp_mat, pack_tile * block_tile * sizeof(float)); + } + + float src[8]; + float m[8]; + for (int c = 0; c < real_c; ++c) { + for (int i = 0; i < block_tile; ++i) { + int tmp_index = c * block_tile + i; + for (int w = 0; w < 8; ++w) { + src[w] = src_data[tmp_index + w * src_point_stride]; + } + m[0] = 0.5625f * src[0] - 3.0625f * src[2] + 3.5f * src[4] - src[6]; + float tmp1 = 1.125f * src[1] + 0.5f * src[5]; + float tmp2 = 2.25f * src[2] - 3.25f * src[4]; + m[1] = tmp1 + tmp2 - 1.625f * src[3] + src[6]; + m[2] = tmp2 - tmp1 + 1.625f * src[3] + src[6]; + tmp1 = 0.5625f * src[1] + src[5]; + tmp2 = 0.5625f * src[2] - 2.5f * src[4]; + m[3] = tmp1 + tmp2 - 2.5f * src[3] + src[6]; + m[4] = tmp2 - tmp1 + 2.5f * src[3] + src[6]; + tmp1 = 0.375f * src[1] + 1.5f * src[5]; + tmp2 = 0.25f * src[2] - 1.25f * src[4]; + m[5] = tmp1 + tmp2 - 1.875f * src[3] + src[6]; + m[6] = tmp2 - tmp1 + 1.875f * src[3] + src[6]; + m[7] = -0.5625f * src[1] + 3.0625f * src[3] - 3.5f * src[5] + src[7]; + + for (int w = 0; w < 8; ++w) { + dst_data[tmp_index + w * dst_step] = m[w]; + } + } + } +#endif +} + +OutputTransFunc GetOutputTransFunc(int input_unit, int output_unit, ActType act_type) { + if (!CheckWinogradInputOutputUnit(input_unit, output_unit)) { + return NULL; + } + int in_index = (input_unit - 4) / 2; + int index = 0; + for (int i = 0; i < in_index; i++) { + index += ((i * 2 + 4) - 2) * 3; + } + int act_index; + if (act_type == ActType_Relu) { + act_index = 1; + } else if (act_type == ActType_Relu6) { + act_index = 2; + } else { + act_index = 0; + } + return OutputTransFuncList[index + (input_unit - 2) * act_index + output_unit - 2]; +} + +void OutputTransform4x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[16]; + MS_FLOAT32X4 t[8]; + MS_FLOAT32X4 m[4]; + Load16Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = MS_ADDQ_F32(MS_SUBQ_F32(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + } + if (r_c == C4NUM && r_h == 2 && r_w == 2) { + Store4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[16]; + float t[8]; + float m[4]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 16; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[offset] + src[1 + offset] + src[2 + offset]; + t[l + 4] = src[1 + offset] - src[2 + offset] + src[3 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = t[offset] + t[1 + offset] + t[2 + offset]; + m[l + 2] = t[1 + offset] - t[2 + offset] + t[3 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 2; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +#endif +} + +void OutputTransform4x2ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[16]; + MS_FLOAT32X4 t[8]; + MS_FLOAT32X4 m[4]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + Load16Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = MS_ADDQ_F32(MS_SUBQ_F32(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l + 2] = MS_MAXQ_F32(zero, m[l + 2]); + } + if (r_c == C4NUM && r_h == 2 && r_w == 2) { + Store4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[16]; + float t[8]; + float m[4]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 16; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[offset] + src[1 + offset] + src[2 + offset]; + t[l + 4] = src[1 + offset] - src[2 + offset] + src[3 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = t[offset] + t[1 + offset] + t[2 + offset]; + m[l + 2] = t[1 + offset] - t[2 + offset] + t[3 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 2; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform4x2Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[16]; + MS_FLOAT32X4 t[8]; + MS_FLOAT32X4 m[4]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + MS_FLOAT32X4 six = MS_MOVQ_F32(6); + Load16Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = MS_ADDQ_F32(MS_SUBQ_F32(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l] = MS_MINQ_F32(six, m[l]); + m[l + 2] = MS_MAXQ_F32(zero, m[l + 2]); + m[l + 2] = MS_MINQ_F32(six, m[l + 2]); + } + if (r_c == C4NUM && r_h == 2 && r_w == 2) { + Store4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[16]; + float t[8]; + float m[4]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 16; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[offset] + src[1 + offset] + src[2 + offset]; + t[l + 4] = src[1 + offset] - src[2 + offset] + src[3 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = t[offset] + t[1 + offset] + t[2 + offset]; + m[l + 2] = t[1 + offset] - t[2 + offset] + t[3 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 2; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform4x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[16]; + MS_FLOAT32X4 t[12]; + MS_FLOAT32X4 m[9]; + Load16Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + MS_FLOAT32X4 tmp = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + t[l] = MS_ADDQ_F32(src[offset], tmp); + t[l + 4] = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + t[l + 8] = MS_ADDQ_F32(tmp, src[3 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 4; + MS_FLOAT32X4 tmp = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp), bias_ptr); + m[l + 3] = MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), bias_ptr); + m[l + 6] = MS_ADDQ_F32(MS_ADDQ_F32(tmp, t[3 + offset]), bias_ptr); + } + if (r_c == C4NUM && r_h == 3 && r_w == 3) { + Store9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[16]; + float t[12]; + float m[9]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 16; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[0 + offset] + src[1 + offset] + src[2 + offset]; + t[l + 4] = src[1 + offset] - src[2 + offset]; + t[l + 8] = src[1 + offset] + src[2 + offset] + src[3 + offset]; + } + for (int l = 0; l < 3; ++l) { + int offset = l * 4; + m[l] = t[offset] + t[1 + offset] + t[2 + offset]; + m[l + 3] = t[1 + offset] - t[2 + offset]; + m[l + 6] = t[1 + offset] + t[2 + offset] + t[3 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 3; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +#endif +} + +void OutputTransform4x3ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[16]; + MS_FLOAT32X4 t[12]; + MS_FLOAT32X4 m[9]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + Load16Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + MS_FLOAT32X4 tmp = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + t[l] = MS_ADDQ_F32(src[offset], tmp); + t[l + 4] = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + t[l + 8] = MS_ADDQ_F32(tmp, src[3 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 4; + MS_FLOAT32X4 tmp = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp), bias_ptr); + m[l + 3] = MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), bias_ptr); + m[l + 6] = MS_ADDQ_F32(MS_ADDQ_F32(tmp, t[3 + offset]), bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l + 3] = MS_MAXQ_F32(zero, m[l + 3]); + m[l + 6] = MS_MAXQ_F32(zero, m[l + 6]); + } + if (r_c == C4NUM && r_h == 3 && r_w == 3) { + Store9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[16]; + float t[12]; + float m[9]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 16; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[0 + offset] + src[1 + offset] + src[2 + offset]; + t[l + 4] = src[1 + offset] - src[2 + offset]; + t[l + 8] = src[1 + offset] + src[2 + offset] + src[3 + offset]; + } + for (int l = 0; l < 3; ++l) { + int offset = l * 4; + m[l] = t[offset] + t[1 + offset] + t[2 + offset]; + m[l + 3] = t[1 + offset] - t[2 + offset]; + m[l + 6] = t[1 + offset] + t[2 + offset] + t[3 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 3; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform4x3Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[16]; + MS_FLOAT32X4 t[12]; + MS_FLOAT32X4 m[9]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + MS_FLOAT32X4 six = MS_MOVQ_F32(6); + Load16Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + MS_FLOAT32X4 tmp = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + t[l] = MS_ADDQ_F32(src[offset], tmp); + t[l + 4] = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + t[l + 8] = MS_ADDQ_F32(tmp, src[3 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 4; + MS_FLOAT32X4 tmp = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp), bias_ptr); + m[l + 3] = MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), bias_ptr); + m[l + 6] = MS_ADDQ_F32(MS_ADDQ_F32(tmp, t[3 + offset]), bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l] = MS_MINQ_F32(six, m[l]); + m[l + 3] = MS_MAXQ_F32(zero, m[l + 3]); + m[l + 3] = MS_MINQ_F32(six, m[l + 3]); + m[l + 6] = MS_MAXQ_F32(zero, m[l + 6]); + m[l + 6] = MS_MINQ_F32(six, m[l + 6]); + } + if (r_c == C4NUM && r_h == 3 && r_w == 3) { + Store9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[16]; + float t[12]; + float m[9]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 16; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[0 + offset] + src[1 + offset] + src[2 + offset]; + t[l + 4] = src[1 + offset] - src[2 + offset]; + t[l + 8] = src[1 + offset] + src[2 + offset] + src[3 + offset]; + } + for (int l = 0; l < 3; ++l) { + int offset = l * 4; + m[l] = t[offset] + t[1 + offset] + t[2 + offset]; + m[l + 3] = t[1 + offset] - t[2 + offset]; + m[l + 6] = t[1 + offset] + t[2 + offset] + t[3 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 3; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform6x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[36]; + MS_FLOAT32X4 t[12]; + MS_FLOAT32X4 m[4]; + Load36Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], src[1 + offset]), src[2 + offset]), src[3 + offset]), + src[4 + offset]); + t[l + 6] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(src[1 + offset], src[2 + offset]), + MS_MULQ_N_F32(MS_SUBQ_F32(src[3 + offset], src[4 + offset]), 2)), + src[5 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 6; + m[l] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], t[1 + offset]), t[2 + offset]), t[3 + offset]), + t[4 + offset]), + bias_ptr); + m[l + 2] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), + MS_MULQ_N_F32(MS_SUBQ_F32(t[3 + offset], t[4 + offset]), 2)), + t[5 + offset]), + bias_ptr); + } + if (r_c == C4NUM && r_h == 2 && r_w == 2) { + Store4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[36]; + float t[12]; + float m[4]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; + t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]) + src[5 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 6; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; + m[l + 2] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]) + t[5 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 2; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +#endif +} + +void OutputTransform6x2ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[36]; + MS_FLOAT32X4 t[12]; + MS_FLOAT32X4 m[4]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + Load36Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], src[1 + offset]), src[2 + offset]), src[3 + offset]), + src[4 + offset]); + t[l + 6] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(src[1 + offset], src[2 + offset]), + MS_MULQ_N_F32(MS_SUBQ_F32(src[3 + offset], src[4 + offset]), 2)), + src[5 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 6; + m[l] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], t[1 + offset]), t[2 + offset]), t[3 + offset]), + t[4 + offset]), + bias_ptr); + m[l + 2] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), + MS_MULQ_N_F32(MS_SUBQ_F32(t[3 + offset], t[4 + offset]), 2)), + t[5 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l + 2] = MS_MAXQ_F32(zero, m[l + 2]); + } + if (r_c == C4NUM && r_h == 2 && r_w == 2) { + Store4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[36]; + float t[12]; + float m[4]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; + t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]) + src[5 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 6; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; + m[l + 2] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]) + t[5 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 2; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform6x2Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[36]; + MS_FLOAT32X4 t[12]; + MS_FLOAT32X4 m[4]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + MS_FLOAT32X4 six = MS_MOVQ_F32(6); + Load36Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], src[1 + offset]), src[2 + offset]), src[3 + offset]), + src[4 + offset]); + t[l + 6] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(src[1 + offset], src[2 + offset]), + MS_MULQ_N_F32(MS_SUBQ_F32(src[3 + offset], src[4 + offset]), 2)), + src[5 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 6; + m[l] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], t[1 + offset]), t[2 + offset]), t[3 + offset]), + t[4 + offset]), + bias_ptr); + m[l + 2] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), + MS_MULQ_N_F32(MS_SUBQ_F32(t[3 + offset], t[4 + offset]), 2)), + t[5 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l] = MS_MINQ_F32(six, m[l]); + m[l + 2] = MS_MAXQ_F32(zero, m[l + 2]); + m[l + 2] = MS_MINQ_F32(six, m[l + 2]); + } + if (r_c == C4NUM && r_h == 2 && r_w == 2) { + Store4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[36]; + float t[12]; + float m[4]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; + t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]) + src[5 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 6; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; + m[l + 2] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]) + t[5 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 2; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform6x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[36]; + MS_FLOAT32X4 t[18]; + MS_FLOAT32X4 m[9]; + Load36Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADDQ_F32(MS_SUBQ_F32(src[1 + offset], src[2 + offset]), + MS_MULQ_N_F32(MS_SUBQ_F32(src[3 + offset], src[4 + offset]), 2)); + t[l + 12] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), src[5 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 3] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), + MS_MULQ_N_F32(MS_SUBQ_F32(t[3 + offset], t[4 + offset]), 2)), + bias_ptr); + m[l + 6] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), t[5 + offset]), bias_ptr); + } + if (r_c == C4NUM && r_h == 3 && r_w == 3) { + Store9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[36]; + float t[18]; + float m[9]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; + t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]); + t[l + 12] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]) + src[5 + offset]; + } + for (int l = 0; l < 3; ++l) { + int offset = l * 6; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; + m[l + 3] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]); + m[l + 6] = t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]) + t[5 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 3; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +#endif +} + +void OutputTransform6x3ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[36]; + MS_FLOAT32X4 t[18]; + MS_FLOAT32X4 m[9]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + Load36Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADDQ_F32(MS_SUBQ_F32(src[1 + offset], src[2 + offset]), + MS_MULQ_N_F32(MS_SUBQ_F32(src[3 + offset], src[4 + offset]), 2)); + t[l + 12] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), src[5 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 3] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), + MS_MULQ_N_F32(MS_SUBQ_F32(t[3 + offset], t[4 + offset]), 2)), + bias_ptr); + m[l + 6] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), t[5 + offset]), bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l + 3] = MS_MAXQ_F32(zero, m[l + 3]); + m[l + 6] = MS_MAXQ_F32(zero, m[l + 6]); + } + if (r_c == C4NUM && r_h == 3 && r_w == 3) { + Store9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[36]; + float t[18]; + float m[9]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; + t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]); + t[l + 12] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]) + src[5 + offset]; + } + for (int l = 0; l < 3; ++l) { + int offset = l * 6; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; + m[l + 3] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]); + m[l + 6] = t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]) + t[5 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 3; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform6x3Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[36]; + MS_FLOAT32X4 t[18]; + MS_FLOAT32X4 m[9]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + MS_FLOAT32X4 six = MS_MOVQ_F32(6); + Load36Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADDQ_F32(MS_SUBQ_F32(src[1 + offset], src[2 + offset]), + MS_MULQ_N_F32(MS_SUBQ_F32(src[3 + offset], src[4 + offset]), 2)); + t[l + 12] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), src[5 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 3] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), + MS_MULQ_N_F32(MS_SUBQ_F32(t[3 + offset], t[4 + offset]), 2)), + bias_ptr); + m[l + 6] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), t[5 + offset]), bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l] = MS_MINQ_F32(six, m[l]); + m[l + 3] = MS_MAXQ_F32(zero, m[l + 3]); + m[l + 3] = MS_MINQ_F32(six, m[l + 3]); + m[l + 6] = MS_MAXQ_F32(zero, m[l + 6]); + m[l + 6] = MS_MINQ_F32(six, m[l + 6]); + } + if (r_c == C4NUM && r_h == 3 && r_w == 3) { + Store9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[36]; + float t[18]; + float m[9]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; + t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]); + t[l + 12] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]) + src[5 + offset]; + } + for (int l = 0; l < 3; ++l) { + int offset = l * 6; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; + m[l + 3] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]); + m[l + 6] = t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]) + t[5 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 3; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform6x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[36]; + MS_FLOAT32X4 t[24]; + MS_FLOAT32X4 m[16]; + Load36Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)); + t[l + 12] = MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)); + t[l + 18] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)), src[5 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 4] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)), bias_ptr); + m[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), bias_ptr); + m[l + 12] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)), t[5 + offset]), bias_ptr); + } + if (r_c == C4NUM && r_h == 4 && r_w == 4) { + Store16Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[36]; + float t[24]; + float m[16]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; + t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]); + t[l + 12] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]); + t[l + 18] = src[1 + offset] - src[2 + offset] + 8 * (src[3 + offset] - src[4 + offset]) + src[5 + offset]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 6; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; + m[l + 4] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]); + m[l + 8] = t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]); + m[l + 12] = t[1 + offset] - t[2 + offset] + 8 * (t[3 + offset] - t[4 + offset]) + t[5 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 4; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +#endif +} + +void OutputTransform6x4ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[36]; + MS_FLOAT32X4 t[24]; + MS_FLOAT32X4 m[16]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + Load36Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)); + t[l + 12] = MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)); + t[l + 18] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)), src[5 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 4] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)), bias_ptr); + m[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), bias_ptr); + m[l + 12] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)), t[5 + offset]), bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l + 4] = MS_MAXQ_F32(zero, m[l + 4]); + m[l + 8] = MS_MAXQ_F32(zero, m[l + 8]); + m[l + 12] = MS_MAXQ_F32(zero, m[l + 12]); + } + if (r_c == C4NUM && r_h == 4 && r_w == 4) { + Store16Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[36]; + float t[24]; + float m[16]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; + t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]); + t[l + 12] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]); + t[l + 18] = src[1 + offset] - src[2 + offset] + 8 * (src[3 + offset] - src[4 + offset]) + src[5 + offset]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 6; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; + m[l + 4] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]); + m[l + 8] = t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]); + m[l + 12] = t[1 + offset] - t[2 + offset] + 8 * (t[3 + offset] - t[4 + offset]) + t[5 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 4; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform6x4Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[36]; + MS_FLOAT32X4 t[24]; + MS_FLOAT32X4 m[16]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + MS_FLOAT32X4 six = MS_MOVQ_F32(6); + Load36Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)); + t[l + 12] = MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)); + t[l + 18] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)), src[5 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 4] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)), bias_ptr); + m[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), bias_ptr); + m[l + 12] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)), t[5 + offset]), bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l] = MS_MINQ_F32(six, m[l]); + m[l + 4] = MS_MAXQ_F32(zero, m[l + 4]); + m[l + 4] = MS_MINQ_F32(six, m[l + 4]); + m[l + 8] = MS_MAXQ_F32(zero, m[l + 8]); + m[l + 8] = MS_MINQ_F32(six, m[l + 8]); + m[l + 12] = MS_MAXQ_F32(zero, m[l + 12]); + m[l + 12] = MS_MINQ_F32(six, m[l + 12]); + } + if (r_c == C4NUM && r_h == 4 && r_w == 4) { + Store16Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[36]; + float t[24]; + float m[16]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; + t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]); + t[l + 12] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]); + t[l + 18] = src[1 + offset] - src[2 + offset] + 8 * (src[3 + offset] - src[4 + offset]) + src[5 + offset]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 6; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; + m[l + 4] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]); + m[l + 8] = t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]); + m[l + 12] = t[1 + offset] - t[2 + offset] + 8 * (t[3 + offset] - t[4 + offset]) + t[5 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 4; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform6x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[36]; + MS_FLOAT32X4 t[30]; + MS_FLOAT32X4 m[25]; + Load36Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)); + t[l + 12] = MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)); + t[l + 18] = MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 16)), src[5 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 5] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)), bias_ptr); + m[l + 10] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), bias_ptr); + m[l + 15] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)), bias_ptr); + m[l + 20] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 16)), t[5 + offset]), bias_ptr); + } + if (r_c == C4NUM && r_h == 5 && r_w == 5) { + Store25Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[36]; + float t[30]; + float m[25]; + for (int i = 0; i < C4NUM; ++i) { + // load source data + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; + t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]); + t[l + 12] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]); + t[l + 18] = src[1 + offset] - src[2 + offset] + 8 * (src[3 + offset] - src[4 + offset]); + t[l + 24] = src[1 + offset] + src[2 + offset] + 16 * (src[3 + offset] + src[4 + offset]) + src[5 + offset]; + } + for (int l = 0; l < 5; ++l) { + int offset = l * 6; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; + m[l + 5] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]); + m[l + 10] = t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]); + m[l + 15] = t[1 + offset] - t[2 + offset] + 8 * (t[3 + offset] - t[4 + offset]); + m[l + 20] = t[1 + offset] + t[2 + offset] + 16 * (t[3 + offset] + t[4 + offset]) + t[5 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 5; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +#endif +} + +void OutputTransform6x5ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[36]; + MS_FLOAT32X4 t[30]; + MS_FLOAT32X4 m[25]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + Load36Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)); + t[l + 12] = MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)); + t[l + 18] = MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 16)), src[5 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 5] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)), bias_ptr); + m[l + 10] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), bias_ptr); + m[l + 15] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)), bias_ptr); + m[l + 20] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 16)), t[5 + offset]), bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l + 5] = MS_MAXQ_F32(zero, m[l + 5]); + m[l + 10] = MS_MAXQ_F32(zero, m[l + 10]); + m[l + 15] = MS_MAXQ_F32(zero, m[l + 15]); + m[l + 20] = MS_MAXQ_F32(zero, m[l + 20]); + } + if (r_c == C4NUM && r_h == 5 && r_w == 5) { + Store25Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[36]; + float t[30]; + float m[25]; + for (int i = 0; i < C4NUM; ++i) { + // load source data + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; + t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]); + t[l + 12] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]); + t[l + 18] = src[1 + offset] - src[2 + offset] + 8 * (src[3 + offset] - src[4 + offset]); + t[l + 24] = src[1 + offset] + src[2 + offset] + 16 * (src[3 + offset] + src[4 + offset]) + src[5 + offset]; + } + for (int l = 0; l < 5; ++l) { + int offset = l * 6; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; + m[l + 5] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]); + m[l + 10] = t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]); + m[l + 15] = t[1 + offset] - t[2 + offset] + 8 * (t[3 + offset] - t[4 + offset]); + m[l + 20] = t[1 + offset] + t[2 + offset] + 16 * (t[3 + offset] + t[4 + offset]) + t[5 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 5; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform6x5Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[36]; + MS_FLOAT32X4 t[30]; + MS_FLOAT32X4 m[25]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + MS_FLOAT32X4 six = MS_MOVQ_F32(6); + Load36Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2); + t[l + 6] = MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)); + t[l + 12] = MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)); + t[l + 18] = MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 16)), src[5 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 6; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), bias_ptr); + m[l + 5] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)), bias_ptr); + m[l + 10] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), bias_ptr); + m[l + 15] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)), bias_ptr); + m[l + 20] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 16)), t[5 + offset]), bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l] = MS_MINQ_F32(six, m[l]); + m[l + 5] = MS_MAXQ_F32(zero, m[l + 5]); + m[l + 5] = MS_MINQ_F32(six, m[l + 5]); + m[l + 10] = MS_MAXQ_F32(zero, m[l + 10]); + m[l + 10] = MS_MINQ_F32(six, m[l + 10]); + m[l + 15] = MS_MAXQ_F32(zero, m[l + 15]); + m[l + 15] = MS_MINQ_F32(six, m[l + 15]); + m[l + 20] = MS_MAXQ_F32(zero, m[l + 20]); + m[l + 20] = MS_MINQ_F32(six, m[l + 20]); + } + if (r_c == C4NUM && r_h == 5 && r_w == 5) { + Store25Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[36]; + float t[30]; + float m[25]; + for (int i = 0; i < C4NUM; ++i) { + // load source data + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset]; + t[l + 6] = src[1 + offset] - src[2 + offset] + 2 * (src[3 + offset] - src[4 + offset]); + t[l + 12] = src[1 + offset] + src[2 + offset] + 4 * (src[3 + offset] + src[4 + offset]); + t[l + 18] = src[1 + offset] - src[2 + offset] + 8 * (src[3 + offset] - src[4 + offset]); + t[l + 24] = src[1 + offset] + src[2 + offset] + 16 * (src[3 + offset] + src[4 + offset]) + src[5 + offset]; + } + for (int l = 0; l < 5; ++l) { + int offset = l * 6; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset]; + m[l + 5] = t[1 + offset] - t[2 + offset] + 2 * (t[3 + offset] - t[4 + offset]); + m[l + 10] = t[1 + offset] + t[2 + offset] + 4 * (t[3 + offset] + t[4 + offset]); + m[l + 15] = t[1 + offset] - t[2 + offset] + 8 * (t[3 + offset] - t[4 + offset]); + m[l + 20] = t[1 + offset] + t[2 + offset] + 16 * (t[3 + offset] + t[4 + offset]) + t[5 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 5; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform8x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[16]; + MS_FLOAT32X4 m[4]; + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), src[7 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 2] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), t[7 + offset]), + bias_ptr); + } + if (r_c == C4NUM && r_h == 2 && r_w == 2) { + Store4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[64]; + float t[16]; + float m[4]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 2] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 2; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +#endif +} + +void OutputTransform8x2ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[16]; + MS_FLOAT32X4 m[4]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), src[7 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 2] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), t[7 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l + 2] = MS_MAXQ_F32(zero, m[l + 2]); + } + if (r_c == C4NUM && r_h == 2 && r_w == 2) { + Store4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[64]; + float t[16]; + float m[4]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 2] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 2; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform8x2Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[16]; + MS_FLOAT32X4 m[4]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + MS_FLOAT32X4 six = MS_MOVQ_F32(6); + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), src[7 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 2] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), t[7 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l] = MS_MINQ_F32(six, m[l]); + m[l + 2] = MS_MAXQ_F32(zero, m[l + 2]); + m[l + 2] = MS_MINQ_F32(six, m[l + 2]); + } + if (r_c == C4NUM && r_h == 2 && r_w == 2) { + Store4Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[64]; + float t[16]; + float m[4]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 2] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 2; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +#endif +} + +void OutputTransform8x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[24]; + MS_FLOAT32X4 m[9]; + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), + src[7 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 3] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 6] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), t[7 + offset]), + bias_ptr); + } + if (r_c == C4NUM && r_h == 3 && r_w == 3) { + Store9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[64]; + float t[24]; + float m[9]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 3; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 3] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 6] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 3; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +#endif +} + +void OutputTransform8x3ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[24]; + MS_FLOAT32X4 m[9]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), + src[7 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 3] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 6] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), t[7 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l + 3] = MS_MAXQ_F32(zero, m[l + 3]); + m[l + 6] = MS_MAXQ_F32(zero, m[l + 6]); + } + if (r_c == C4NUM && r_h == 3 && r_w == 3) { + Store9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[64]; + float t[24]; + float m[9]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 3; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 3] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 6] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 3; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform8x3Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[24]; + MS_FLOAT32X4 m[9]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + MS_FLOAT32X4 six = MS_MOVQ_F32(6); + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), + src[7 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 3] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 6] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), t[7 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l] = MS_MINQ_F32(six, m[l]); + m[l + 3] = MS_MAXQ_F32(zero, m[l + 3]); + m[l + 3] = MS_MINQ_F32(six, m[l + 3]); + m[l + 6] = MS_MAXQ_F32(zero, m[l + 6]); + m[l + 6] = MS_MINQ_F32(six, m[l + 6]); + } + if (r_c == C4NUM && r_h == 3 && r_w == 3) { + Store9Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[64]; + float t[24]; + float m[9]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 3; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 3] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 6] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 3; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +void OutputTransform8x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[32]; + MS_FLOAT32X4 m[16]; + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), + src[7 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 4] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 8] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 12] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), + t[7 + offset]), + bias_ptr); + } + if (r_c == C4NUM && r_h == 4 && r_w == 4) { + Store16Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[64]; + float t[32]; + float m[16]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 4] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 8] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 12] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 4; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +#endif +} + +void OutputTransform8x4ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[32]; + MS_FLOAT32X4 m[16]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), + src[7 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 4] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 8] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 12] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l + 4] = MS_MAXQ_F32(zero, m[l + 4]); + m[l + 8] = MS_MAXQ_F32(zero, m[l + 8]); + m[l + 12] = MS_MAXQ_F32(zero, m[l + 12]); + } + if (r_c == C4NUM && r_h == 4 && r_w == 4) { + Store16Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[64]; + float t[32]; + float m[16]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 4] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 8] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 12] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 4; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +#endif +} + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) +void OutputTransform8x4Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[32]; + MS_FLOAT32X4 m[16]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + MS_FLOAT32X4 six = MS_MOVQ_F32(6); + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), + src[7 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 4] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 8] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 12] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l] = MS_MINQ_F32(six, m[l]); + m[l + 4] = MS_MAXQ_F32(zero, m[l + 4]); + m[l + 4] = MS_MINQ_F32(six, m[l + 4]); + m[l + 8] = MS_MAXQ_F32(zero, m[l + 8]); + m[l + 8] = MS_MINQ_F32(six, m[l + 8]); + m[l + 12] = MS_MAXQ_F32(zero, m[l + 12]); + m[l + 12] = MS_MINQ_F32(six, m[l + 12]); + } + if (r_c == C4NUM && r_h == 4 && r_w == 4) { + Store16Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +} +#else +void OutputTransform8x4Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + float src[64]; + float t[32]; + float m[16]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 4] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 8] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 12] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 4; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +} +#endif + +void OutputTransform8x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[40]; + MS_FLOAT32X4 m[25]; + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), + src[7 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 5] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 10] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 15] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 20] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), + t[7 + offset]), + bias_ptr); + } + if (r_c == C4NUM && r_h == 5 && r_w == 5) { + Store25Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +#else + float src[64]; + float t[40]; + float m[25]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]); + t[l + 32] = 0.0625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 5.0625f * (src[5 + offset] + src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 5; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 5] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 10] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 15] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]); + m[l + 20] = 0.0625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 5.0625f * (t[5 + offset] + t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 5; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +#endif +} + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) +void OutputTransform8x5ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[40]; + MS_FLOAT32X4 m[25]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), + src[7 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 5] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 10] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 15] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 20] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l + 5] = MS_MAXQ_F32(zero, m[l + 5]); + m[l + 10] = MS_MAXQ_F32(zero, m[l + 10]); + m[l + 15] = MS_MAXQ_F32(zero, m[l + 15]); + m[l + 20] = MS_MAXQ_F32(zero, m[l + 20]); + } + if (r_c == C4NUM && r_h == 5 && r_w == 5) { + Store25Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +} +#else +void OutputTransform8x5ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + float src[64]; + float t[40]; + float m[25]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]); + t[l + 32] = 0.0625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 5.0625f * (src[5 + offset] + src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 5; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 5] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 10] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 15] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]); + m[l + 20] = 0.0625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 5.0625f * (t[5 + offset] + t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 5; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +} +#endif + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) +void OutputTransform8x5Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[40]; + MS_FLOAT32X4 m[25]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + MS_FLOAT32X4 six = MS_MOVQ_F32(6); + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), + src[7 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 5] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 10] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 15] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 20] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l] = MS_MINQ_F32(six, m[l]); + m[l + 5] = MS_MAXQ_F32(zero, m[l + 5]); + m[l + 5] = MS_MINQ_F32(six, m[l + 5]); + m[l + 10] = MS_MAXQ_F32(zero, m[l + 10]); + m[l + 10] = MS_MINQ_F32(six, m[l + 10]); + m[l + 15] = MS_MAXQ_F32(zero, m[l + 15]); + m[l + 15] = MS_MINQ_F32(six, m[l + 15]); + m[l + 20] = MS_MAXQ_F32(zero, m[l + 20]); + m[l + 20] = MS_MINQ_F32(six, m[l + 20]); + } + if (r_c == C4NUM && r_h == 5 && r_w == 5) { + Store25Data; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +} +#else +void OutputTransform8x5Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + float src[64]; + float t[40]; + float m[25]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]); + t[l + 32] = 0.0625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 5.0625f * (src[5 + offset] + src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 5; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 5] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 10] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 15] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]); + m[l + 20] = 0.0625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 5.0625f * (t[5 + offset] + t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 5; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +} +#endif + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) +void OutputTransform8x6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[48]; + MS_FLOAT32X4 m[36]; + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)); + t[l + 40] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)), + src[7 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 12] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 18] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 24] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), bias_ptr); + m[l + 30] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)), + t[7 + offset]), + bias_ptr); + } + if (r_c == C4NUM && r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + MS_STQ_F32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + MS_STQ_F32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + MS_STQ_F32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + MS_STQ_F32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + MS_STQ_F32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + MS_STQ_F32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +} +#else +void OutputTransform8x6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { + float src[64]; + float t[48]; + float m[36]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]); + t[l + 32] = 0.0625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 5.0625f * (src[5 + offset] + src[6 + offset]); + t[l + 40] = 0.03125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 7.59375f * (src[5 + offset] - src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 6] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 12] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 18] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]); + m[l + 24] = 0.0625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 5.0625f * (t[5 + offset] + t[6 + offset]); + m[l + 30] = 0.03125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 7.59375f * (t[5 + offset] - t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 6; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +} +#endif + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) +void OutputTransform8x6ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[48]; + MS_FLOAT32X4 m[36]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)); + t[l + 40] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)), + src[7 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 12] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 18] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 24] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), bias_ptr); + m[l + 30] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l + 6] = MS_MAXQ_F32(zero, m[l + 6]); + m[l + 12] = MS_MAXQ_F32(zero, m[l + 12]); + m[l + 18] = MS_MAXQ_F32(zero, m[l + 18]); + m[l + 24] = MS_MAXQ_F32(zero, m[l + 24]); + m[l + 30] = MS_MAXQ_F32(zero, m[l + 30]); + } + if (r_c == C4NUM && r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + MS_STQ_F32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + MS_STQ_F32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + MS_STQ_F32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + MS_STQ_F32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + MS_STQ_F32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + MS_STQ_F32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +} +#else +void OutputTransform8x6ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + float src[64]; + float t[48]; + float m[36]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]); + t[l + 32] = 0.0625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 5.0625f * (src[5 + offset] + src[6 + offset]); + t[l + 40] = 0.03125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 7.59375f * (src[5 + offset] - src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 6] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 12] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 18] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]); + m[l + 24] = 0.0625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 5.0625f * (t[5 + offset] + t[6 + offset]); + m[l + 30] = 0.03125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 7.59375f * (t[5 + offset] - t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 6; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +} +#endif + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) +void OutputTransform8x6Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[48]; + MS_FLOAT32X4 m[36]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + MS_FLOAT32X4 six = MS_MOVQ_F32(6); + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)); + t[l + 40] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)), + src[7 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 12] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 18] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 24] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), bias_ptr); + m[l + 30] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l] = MS_MINQ_F32(six, m[l]); + m[l + 6] = MS_MAXQ_F32(zero, m[l + 6]); + m[l + 6] = MS_MINQ_F32(six, m[l + 6]); + m[l + 12] = MS_MAXQ_F32(zero, m[l + 12]); + m[l + 12] = MS_MINQ_F32(six, m[l + 12]); + m[l + 18] = MS_MAXQ_F32(zero, m[l + 18]); + m[l + 18] = MS_MINQ_F32(six, m[l + 18]); + m[l + 24] = MS_MAXQ_F32(zero, m[l + 24]); + m[l + 24] = MS_MINQ_F32(six, m[l + 24]); + m[l + 30] = MS_MAXQ_F32(zero, m[l + 30]); + m[l + 30] = MS_MINQ_F32(six, m[l + 30]); + } + if (r_c == C4NUM && r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + MS_STQ_F32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + MS_STQ_F32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + MS_STQ_F32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + MS_STQ_F32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + MS_STQ_F32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + MS_STQ_F32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +} +#else +void OutputTransform8x6Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + float src[64]; + float t[48]; + float m[36]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]); + t[l + 32] = 0.0625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 5.0625f * (src[5 + offset] + src[6 + offset]); + t[l + 40] = 0.03125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 7.59375f * (src[5 + offset] - src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 6] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 12] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 18] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]); + m[l + 24] = 0.0625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 5.0625f * (t[5 + offset] + t[6 + offset]); + m[l + 30] = 0.03125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 7.59375f * (t[5 + offset] - t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 6; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +} +#endif + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) +void OutputTransform8x7Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[56]; + MS_FLOAT32X4 m[49]; + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)); + t[l + 40] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)); + t[l + 48] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.015625), tmp2), MS_MULQ_N_F32(tmp3, 11.390625)), src[7 + offset]); + } + for (int l = 0; l < 7; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 7] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 14] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 21] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 28] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), bias_ptr); + m[l + 35] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)), bias_ptr); + m[l + 42] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.015625), tmp2), MS_MULQ_N_F32(tmp3, 11.390625)), + t[7 + offset]), + bias_ptr); + } + if (r_c == C4NUM && r_h == 7 && r_w == 7) { + for (int i = 0; i < 7; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 7; + MS_STQ_F32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + MS_STQ_F32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + MS_STQ_F32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + MS_STQ_F32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + MS_STQ_F32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + MS_STQ_F32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + MS_STQ_F32(dst_data + dst_k_offset + 6 * out_c, m[m_k_offset + 6]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 7; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +} +#else +void OutputTransform8x7Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c) { + float src[64]; + float t[56]; + float m[49]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]); + t[l + 32] = 0.0625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 5.0625f * (src[5 + offset] + src[6 + offset]); + t[l + 40] = 0.03125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 7.59375f * (src[5 + offset] - src[6 + offset]); + t[l + 48] = 0.015625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 11.390625f * (src[5 + offset] + src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 7; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 7] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 14] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 21] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]); + m[l + 28] = 0.0625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 5.0625f * (t[5 + offset] + t[6 + offset]); + m[l + 35] = 0.03125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 7.59375f * (t[5 + offset] - t[6 + offset]); + m[l + 42] = 0.015625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 11.390625f * (t[5 + offset] + t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 7; + for (int j = 0; j < r_w; ++j) { + dst_data[i + dst_k_offset + j * out_c] = m[j + m_k_offset] + bias_data[i]; + } + } + } +} +#endif + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) +void OutputTransform8x7ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[56]; + MS_FLOAT32X4 m[49]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)); + t[l + 40] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)); + t[l + 48] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.015625), tmp2), MS_MULQ_N_F32(tmp3, 11.390625)), src[7 + offset]); + } + for (int l = 0; l < 7; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 7] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 14] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 21] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 28] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), bias_ptr); + m[l + 35] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)), bias_ptr); + m[l + 42] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.015625), tmp2), MS_MULQ_N_F32(tmp3, 11.390625)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l + 7] = MS_MAXQ_F32(zero, m[l + 7]); + m[l + 14] = MS_MAXQ_F32(zero, m[l + 14]); + m[l + 21] = MS_MAXQ_F32(zero, m[l + 21]); + m[l + 28] = MS_MAXQ_F32(zero, m[l + 28]); + m[l + 35] = MS_MAXQ_F32(zero, m[l + 35]); + m[l + 42] = MS_MAXQ_F32(zero, m[l + 42]); + } + if (r_c == C4NUM && r_h == 7 && r_w == 7) { + for (int i = 0; i < 7; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 7; + MS_STQ_F32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + MS_STQ_F32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + MS_STQ_F32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + MS_STQ_F32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + MS_STQ_F32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + MS_STQ_F32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + MS_STQ_F32(dst_data + dst_k_offset + 6 * out_c, m[m_k_offset + 6]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 7; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +} +#else +void OutputTransform8x7ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + float src[64]; + float t[56]; + float m[49]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]); + t[l + 32] = 0.0625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 5.0625f * (src[5 + offset] + src[6 + offset]); + t[l + 40] = 0.03125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 7.59375f * (src[5 + offset] - src[6 + offset]); + t[l + 48] = 0.015625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 11.390625f * (src[5 + offset] + src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 7; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 7] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 14] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 21] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]); + m[l + 28] = 0.0625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 5.0625f * (t[5 + offset] + t[6 + offset]); + m[l + 35] = 0.03125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 7.59375f * (t[5 + offset] - t[6 + offset]); + m[l + 42] = 0.015625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 11.390625f * (t[5 + offset] + t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 7; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +} +#endif + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) +void OutputTransform8x7Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + MS_FLOAT32X4 src[64]; + MS_FLOAT32X4 t[56]; + MS_FLOAT32X4 m[49]; + MS_FLOAT32X4 zero = MS_MOVQ_F32(0); + MS_FLOAT32X4 six = MS_MOVQ_F32(6); + Load64Data; + MS_FLOAT32X4 bias_ptr = MS_LDQ_F32(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(src[5 + offset], src[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); + t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)); + t[l + 40] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)); + t[l + 48] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.015625), tmp2), MS_MULQ_N_F32(tmp3, 11.390625)), src[7 + offset]); + } + for (int l = 0; l < 7; ++l) { + int offset = l * 8; + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp3 = MS_ADDQ_F32(t[5 + offset], t[6 + offset]); + MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); + MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); + MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); + m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 7] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 14] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 21] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 28] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), bias_ptr); + m[l + 35] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)), bias_ptr); + m[l + 42] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.015625), tmp2), MS_MULQ_N_F32(tmp3, 11.390625)), + t[7 + offset]), + bias_ptr); + m[l] = MS_MAXQ_F32(zero, m[l]); + m[l] = MS_MINQ_F32(six, m[l]); + m[l + 7] = MS_MAXQ_F32(zero, m[l + 7]); + m[l + 7] = MS_MINQ_F32(six, m[l + 7]); + m[l + 14] = MS_MAXQ_F32(zero, m[l + 14]); + m[l + 14] = MS_MINQ_F32(six, m[l + 14]); + m[l + 21] = MS_MAXQ_F32(zero, m[l + 21]); + m[l + 21] = MS_MINQ_F32(six, m[l + 21]); + m[l + 28] = MS_MAXQ_F32(zero, m[l + 28]); + m[l + 28] = MS_MINQ_F32(six, m[l + 28]); + m[l + 35] = MS_MAXQ_F32(zero, m[l + 35]); + m[l + 35] = MS_MINQ_F32(six, m[l + 35]); + m[l + 42] = MS_MAXQ_F32(zero, m[l + 42]); + m[l + 42] = MS_MINQ_F32(six, m[l + 42]); + } + if (r_c == C4NUM && r_h == 7 && r_w == 7) { + for (int i = 0; i < 7; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 7; + MS_STQ_F32(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + MS_STQ_F32(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + MS_STQ_F32(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + MS_STQ_F32(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + MS_STQ_F32(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + MS_STQ_F32(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + MS_STQ_F32(dst_data + dst_k_offset + 6 * out_c, m[m_k_offset + 6]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 7; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = MS_F32X4_GETI(m[k + m_k_offset], i); + } + } + } + } +} +#else +void OutputTransform8x7Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { + float src[64]; + float t[56]; + float m[49]; + for (int i = 0; i < r_c; ++i) { + // load source data + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] + src[1 + offset] + src[2 + offset] + src[3 + offset] + src[4 + offset] + src[5 + offset] + + src[6 + offset]; + t[l + 8] = 0.5f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 1.5f * (src[5 + offset] - src[6 + offset]); + t[l + 16] = 0.25f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 2.25f * (src[5 + offset] + src[6 + offset]); + t[l + 24] = 0.125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 3.375f * (src[5 + offset] - src[6 + offset]); + t[l + 32] = 0.0625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 5.0625f * (src[5 + offset] + src[6 + offset]); + t[l + 40] = 0.03125f * (src[1 + offset] - src[2 + offset]) + (src[3 + offset] - src[4 + offset]) + + 7.59375f * (src[5 + offset] - src[6 + offset]); + t[l + 48] = 0.015625f * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]) + + 11.390625f * (src[5 + offset] + src[6 + offset]) + src[7 + offset]; + } + for (int l = 0; l < 7; ++l) { + int offset = l * 8; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + t[3 + offset] + t[4 + offset] + t[5 + offset] + t[6 + offset]; + m[l + 7] = 0.5f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 1.5f * (t[5 + offset] - t[6 + offset]); + m[l + 14] = 0.25f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 2.25f * (t[5 + offset] + t[6 + offset]); + m[l + 21] = 0.125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 3.375f * (t[5 + offset] - t[6 + offset]); + m[l + 28] = 0.0625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 5.0625f * (t[5 + offset] + t[6 + offset]); + m[l + 35] = 0.03125f * (t[1 + offset] - t[2 + offset]) + (t[3 + offset] - t[4 + offset]) + + 7.59375f * (t[5 + offset] - t[6 + offset]); + m[l + 42] = 0.015625f * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]) + + 11.390625f * (t[5 + offset] + t[6 + offset]) + t[7 + offset]; + } + // store output + for (int k = 0; k < r_h; ++k) { + int dst_k_offset = k * dst_step * out_c; + int m_k_offset = k * 7; + for (int j = 0; j < r_w; ++j) { + float out_value = m[j + m_k_offset] + bias_data[i]; + out_value = out_value > 0 ? out_value : 0; + out_value = out_value < 6 ? out_value : 6; + dst_data[i + dst_k_offset + j * out_c] = out_value; + } + } + } +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32/winograd_utils.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/winograd_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..e2ea3b067f7d6b07c3554c63eedbf3b61b08440c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32/winograd_utils.h @@ -0,0 +1,373 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_WINOGRAD_UTILS_H_ +#define MINDSPORE_NNACL_WINOGRAD_UTILS_H_ + +#ifdef ENABLE_ARM +#include +#endif +#include "nnacl/conv_parameter.h" +#include "nnacl/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +typedef void (*InputTransFunc)(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c); + +typedef void (*InputTransStepFunc)(const float *src_data, float *dst_data, int src_step, int dst_step, + int dst_row_step); + +typedef void (*InputTransPackFunc)(float *src_data, float *dst_data, int src_step, int dst_step, int real_c); + +typedef void (*OutputTransFunc)(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); + +typedef struct TransFuncList { + InputTransFunc in_func_; + InputTransStepFunc in_step_func_; + InputTransPackFunc in_pack_func_; + OutputTransFunc out_func_; +} TransFuncList; + +#define Load16Data \ + src[0] = MS_LDQ_F32(src_data + 0 * src_step); \ + src[1] = MS_LDQ_F32(src_data + 1 * src_step); \ + src[2] = MS_LDQ_F32(src_data + 2 * src_step); \ + src[3] = MS_LDQ_F32(src_data + 3 * src_step); \ + src[4] = MS_LDQ_F32(src_data + 4 * src_step); \ + src[5] = MS_LDQ_F32(src_data + 5 * src_step); \ + src[6] = MS_LDQ_F32(src_data + 6 * src_step); \ + src[7] = MS_LDQ_F32(src_data + 7 * src_step); \ + src[8] = MS_LDQ_F32(src_data + 8 * src_step); \ + src[9] = MS_LDQ_F32(src_data + 9 * src_step); \ + src[10] = MS_LDQ_F32(src_data + 10 * src_step); \ + src[11] = MS_LDQ_F32(src_data + 11 * src_step); \ + src[12] = MS_LDQ_F32(src_data + 12 * src_step); \ + src[13] = MS_LDQ_F32(src_data + 13 * src_step); \ + src[14] = MS_LDQ_F32(src_data + 14 * src_step); \ + src[15] = MS_LDQ_F32(src_data + 15 * src_step); + +#define Load36Data \ + src[0] = MS_LDQ_F32(src_data + 0 * src_step); \ + src[1] = MS_LDQ_F32(src_data + 1 * src_step); \ + src[2] = MS_LDQ_F32(src_data + 2 * src_step); \ + src[3] = MS_LDQ_F32(src_data + 3 * src_step); \ + src[4] = MS_LDQ_F32(src_data + 4 * src_step); \ + src[5] = MS_LDQ_F32(src_data + 5 * src_step); \ + src[6] = MS_LDQ_F32(src_data + 6 * src_step); \ + src[7] = MS_LDQ_F32(src_data + 7 * src_step); \ + src[8] = MS_LDQ_F32(src_data + 8 * src_step); \ + src[9] = MS_LDQ_F32(src_data + 9 * src_step); \ + src[10] = MS_LDQ_F32(src_data + 10 * src_step); \ + src[11] = MS_LDQ_F32(src_data + 11 * src_step); \ + src[12] = MS_LDQ_F32(src_data + 12 * src_step); \ + src[13] = MS_LDQ_F32(src_data + 13 * src_step); \ + src[14] = MS_LDQ_F32(src_data + 14 * src_step); \ + src[15] = MS_LDQ_F32(src_data + 15 * src_step); \ + src[16] = MS_LDQ_F32(src_data + 16 * src_step); \ + src[17] = MS_LDQ_F32(src_data + 17 * src_step); \ + src[18] = MS_LDQ_F32(src_data + 18 * src_step); \ + src[19] = MS_LDQ_F32(src_data + 19 * src_step); \ + src[20] = MS_LDQ_F32(src_data + 20 * src_step); \ + src[21] = MS_LDQ_F32(src_data + 21 * src_step); \ + src[22] = MS_LDQ_F32(src_data + 22 * src_step); \ + src[23] = MS_LDQ_F32(src_data + 23 * src_step); \ + src[24] = MS_LDQ_F32(src_data + 24 * src_step); \ + src[25] = MS_LDQ_F32(src_data + 25 * src_step); \ + src[26] = MS_LDQ_F32(src_data + 26 * src_step); \ + src[27] = MS_LDQ_F32(src_data + 27 * src_step); \ + src[28] = MS_LDQ_F32(src_data + 28 * src_step); \ + src[29] = MS_LDQ_F32(src_data + 29 * src_step); \ + src[30] = MS_LDQ_F32(src_data + 30 * src_step); \ + src[31] = MS_LDQ_F32(src_data + 31 * src_step); \ + src[32] = MS_LDQ_F32(src_data + 32 * src_step); \ + src[33] = MS_LDQ_F32(src_data + 33 * src_step); \ + src[34] = MS_LDQ_F32(src_data + 34 * src_step); \ + src[35] = MS_LDQ_F32(src_data + 35 * src_step); + +#define Load64Data \ + src[0] = MS_LDQ_F32(src_data + 0 * src_step); \ + src[1] = MS_LDQ_F32(src_data + 1 * src_step); \ + src[2] = MS_LDQ_F32(src_data + 2 * src_step); \ + src[3] = MS_LDQ_F32(src_data + 3 * src_step); \ + src[4] = MS_LDQ_F32(src_data + 4 * src_step); \ + src[5] = MS_LDQ_F32(src_data + 5 * src_step); \ + src[6] = MS_LDQ_F32(src_data + 6 * src_step); \ + src[7] = MS_LDQ_F32(src_data + 7 * src_step); \ + src[8] = MS_LDQ_F32(src_data + 8 * src_step); \ + src[9] = MS_LDQ_F32(src_data + 9 * src_step); \ + src[10] = MS_LDQ_F32(src_data + 10 * src_step); \ + src[11] = MS_LDQ_F32(src_data + 11 * src_step); \ + src[12] = MS_LDQ_F32(src_data + 12 * src_step); \ + src[13] = MS_LDQ_F32(src_data + 13 * src_step); \ + src[14] = MS_LDQ_F32(src_data + 14 * src_step); \ + src[15] = MS_LDQ_F32(src_data + 15 * src_step); \ + src[16] = MS_LDQ_F32(src_data + 16 * src_step); \ + src[17] = MS_LDQ_F32(src_data + 17 * src_step); \ + src[18] = MS_LDQ_F32(src_data + 18 * src_step); \ + src[19] = MS_LDQ_F32(src_data + 19 * src_step); \ + src[20] = MS_LDQ_F32(src_data + 20 * src_step); \ + src[21] = MS_LDQ_F32(src_data + 21 * src_step); \ + src[22] = MS_LDQ_F32(src_data + 22 * src_step); \ + src[23] = MS_LDQ_F32(src_data + 23 * src_step); \ + src[24] = MS_LDQ_F32(src_data + 24 * src_step); \ + src[25] = MS_LDQ_F32(src_data + 25 * src_step); \ + src[26] = MS_LDQ_F32(src_data + 26 * src_step); \ + src[27] = MS_LDQ_F32(src_data + 27 * src_step); \ + src[28] = MS_LDQ_F32(src_data + 28 * src_step); \ + src[29] = MS_LDQ_F32(src_data + 29 * src_step); \ + src[30] = MS_LDQ_F32(src_data + 30 * src_step); \ + src[31] = MS_LDQ_F32(src_data + 31 * src_step); \ + src[32] = MS_LDQ_F32(src_data + 32 * src_step); \ + src[33] = MS_LDQ_F32(src_data + 33 * src_step); \ + src[34] = MS_LDQ_F32(src_data + 34 * src_step); \ + src[35] = MS_LDQ_F32(src_data + 35 * src_step); \ + src[36] = MS_LDQ_F32(src_data + 36 * src_step); \ + src[37] = MS_LDQ_F32(src_data + 37 * src_step); \ + src[38] = MS_LDQ_F32(src_data + 38 * src_step); \ + src[39] = MS_LDQ_F32(src_data + 39 * src_step); \ + src[40] = MS_LDQ_F32(src_data + 40 * src_step); \ + src[41] = MS_LDQ_F32(src_data + 41 * src_step); \ + src[42] = MS_LDQ_F32(src_data + 42 * src_step); \ + src[43] = MS_LDQ_F32(src_data + 43 * src_step); \ + src[44] = MS_LDQ_F32(src_data + 44 * src_step); \ + src[45] = MS_LDQ_F32(src_data + 45 * src_step); \ + src[46] = MS_LDQ_F32(src_data + 46 * src_step); \ + src[47] = MS_LDQ_F32(src_data + 47 * src_step); \ + src[48] = MS_LDQ_F32(src_data + 48 * src_step); \ + src[49] = MS_LDQ_F32(src_data + 49 * src_step); \ + src[50] = MS_LDQ_F32(src_data + 50 * src_step); \ + src[51] = MS_LDQ_F32(src_data + 51 * src_step); \ + src[52] = MS_LDQ_F32(src_data + 52 * src_step); \ + src[53] = MS_LDQ_F32(src_data + 53 * src_step); \ + src[54] = MS_LDQ_F32(src_data + 54 * src_step); \ + src[55] = MS_LDQ_F32(src_data + 55 * src_step); \ + src[56] = MS_LDQ_F32(src_data + 56 * src_step); \ + src[57] = MS_LDQ_F32(src_data + 57 * src_step); \ + src[58] = MS_LDQ_F32(src_data + 58 * src_step); \ + src[59] = MS_LDQ_F32(src_data + 59 * src_step); \ + src[60] = MS_LDQ_F32(src_data + 60 * src_step); \ + src[61] = MS_LDQ_F32(src_data + 61 * src_step); \ + src[62] = MS_LDQ_F32(src_data + 62 * src_step); \ + src[63] = MS_LDQ_F32(src_data + 63 * src_step); + +#define LOAD_LINE_DATA(line) \ + MS_FLOAT32X4 s##line##0 = MS_LDQ_F32(src_ptr + line * src_point_stride + 0 * pack_tile); \ + MS_FLOAT32X4 s##line##1 = MS_LDQ_F32(src_ptr + line * src_point_stride + 1 * pack_tile); \ + MS_FLOAT32X4 s##line##2 = MS_LDQ_F32(src_ptr + line * src_point_stride + 2 * pack_tile); + +#define TRANSPOSE_12x4 \ + MS_FLOAT32X4 s0 = MS_LDQ_F32(src_ptr + 0 * pack_tile); \ + MS_FLOAT32X4 s3 = MS_LDQ_F32(src_ptr + 1 * pack_tile); \ + MS_FLOAT32X4 s6 = MS_LDQ_F32(src_ptr + 2 * pack_tile); \ + MS_FLOAT32X4 s9 = MS_LDQ_F32(src_ptr + 3 * pack_tile); \ + MS_FLOAT32X4 s1 = MS_LDQ_F32(src_ptr + 4 * pack_tile); \ + MS_FLOAT32X4 s4 = MS_LDQ_F32(src_ptr + 5 * pack_tile); \ + MS_FLOAT32X4 s7 = MS_LDQ_F32(src_ptr + 6 * pack_tile); \ + MS_FLOAT32X4 s10 = MS_LDQ_F32(src_ptr + 7 * pack_tile); \ + MS_FLOAT32X4 s2 = MS_LDQ_F32(src_ptr + 8 * pack_tile); \ + MS_FLOAT32X4 s5 = MS_LDQ_F32(src_ptr + 9 * pack_tile); \ + MS_FLOAT32X4 s8 = MS_LDQ_F32(src_ptr + 10 * pack_tile); \ + MS_FLOAT32X4 s11 = MS_LDQ_F32(src_ptr + 11 * pack_tile); \ + transpose4(&s0, &s3, &s6, &s9); \ + transpose4(&s1, &s4, &s7, &s10); \ + transpose4(&s2, &s5, &s8, &s11); \ + MS_STQ_F32(src_ptr + 0 * pack_tile, s0); \ + MS_STQ_F32(src_ptr + 1 * pack_tile, s1); \ + MS_STQ_F32(src_ptr + 2 * pack_tile, s2); \ + MS_STQ_F32(src_ptr + 3 * pack_tile, s3); \ + MS_STQ_F32(src_ptr + 4 * pack_tile, s4); \ + MS_STQ_F32(src_ptr + 5 * pack_tile, s5); \ + MS_STQ_F32(src_ptr + 6 * pack_tile, s6); \ + MS_STQ_F32(src_ptr + 7 * pack_tile, s7); \ + MS_STQ_F32(src_ptr + 8 * pack_tile, s8); \ + MS_STQ_F32(src_ptr + 9 * pack_tile, s9); \ + MS_STQ_F32(src_ptr + 10 * pack_tile, s10); \ + MS_STQ_F32(src_ptr + 11 * pack_tile, s11); + +InputTransFunc GetInputTransFunc(int input_unit); + +#ifdef ENABLE_ARM64 +InputTransStepFunc GetInputTransStepFunc(int input_unit); + +InputTransPackFunc GetInputTransPackFunc(int input_unit); +#endif + +void InputTransform4x4Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform4x4Step(const float *src_data, float *dst_data, int src_step, int dst_step, int dst_row_step); + +void InputTransform4x4Pack12(float *src_data, float *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform6x6Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform6x6Step(const float *src_data, float *dst_data, int src_step, int dst_step, int dst_row_step); + +void InputTransform6x6Pack12(float *src_data, float *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c); + +void InputTransform8x8Step(const float *src_data, float *dst_data, int src_step, int dst_step, int dst_row_step); + +void InputTransform8x8Pack12(float *src_data, float *dst_data, int src_step, int dst_step, int real_c); + +OutputTransFunc GetOutputTransFunc(int input_unit, int output_unit, ActType act_type); + +#define Store4Data \ + MS_STQ_F32(dst_data, m[0]); \ + MS_STQ_F32(dst_data + out_c, m[1]); \ + MS_STQ_F32(dst_data + dst_step * out_c, m[2]); \ + MS_STQ_F32(dst_data + dst_step * out_c + out_c, m[3]); + +#define Store9Data \ + MS_STQ_F32(dst_data, m[0]); \ + MS_STQ_F32(dst_data + out_c, m[1]); \ + MS_STQ_F32(dst_data + 2 * out_c, m[2]); \ + MS_STQ_F32(dst_data + dst_step * out_c, m[3]); \ + MS_STQ_F32(dst_data + dst_step * out_c + out_c, m[4]); \ + MS_STQ_F32(dst_data + dst_step * out_c + 2 * out_c, m[5]); \ + MS_STQ_F32(dst_data + 2 * dst_step * out_c, m[6]); \ + MS_STQ_F32(dst_data + 2 * dst_step * out_c + out_c, m[7]); \ + MS_STQ_F32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[8]); + +#define Store16Data \ + MS_STQ_F32(dst_data, m[0]); \ + MS_STQ_F32(dst_data + out_c, m[1]); \ + MS_STQ_F32(dst_data + 2 * out_c, m[2]); \ + MS_STQ_F32(dst_data + 3 * out_c, m[3]); \ + MS_STQ_F32(dst_data + dst_step * out_c, m[4]); \ + MS_STQ_F32(dst_data + dst_step * out_c + out_c, m[5]); \ + MS_STQ_F32(dst_data + dst_step * out_c + 2 * out_c, m[6]); \ + MS_STQ_F32(dst_data + dst_step * out_c + 3 * out_c, m[7]); \ + MS_STQ_F32(dst_data + 2 * dst_step * out_c, m[8]); \ + MS_STQ_F32(dst_data + 2 * dst_step * out_c + out_c, m[9]); \ + MS_STQ_F32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[10]); \ + MS_STQ_F32(dst_data + 2 * dst_step * out_c + 3 * out_c, m[11]); \ + MS_STQ_F32(dst_data + 3 * dst_step * out_c, m[12]); \ + MS_STQ_F32(dst_data + 3 * dst_step * out_c + out_c, m[13]); \ + MS_STQ_F32(dst_data + 3 * dst_step * out_c + 2 * out_c, m[14]); \ + MS_STQ_F32(dst_data + 3 * dst_step * out_c + 3 * out_c, m[15]); + +#define Store25Data \ + MS_STQ_F32(dst_data, m[0]); \ + MS_STQ_F32(dst_data + out_c, m[1]); \ + MS_STQ_F32(dst_data + 2 * out_c, m[2]); \ + MS_STQ_F32(dst_data + 3 * out_c, m[3]); \ + MS_STQ_F32(dst_data + 4 * out_c, m[4]); \ + MS_STQ_F32(dst_data + dst_step * out_c, m[5]); \ + MS_STQ_F32(dst_data + dst_step * out_c + out_c, m[6]); \ + MS_STQ_F32(dst_data + dst_step * out_c + 2 * out_c, m[7]); \ + MS_STQ_F32(dst_data + dst_step * out_c + 3 * out_c, m[8]); \ + MS_STQ_F32(dst_data + dst_step * out_c + 4 * out_c, m[9]); \ + MS_STQ_F32(dst_data + 2 * dst_step * out_c, m[10]); \ + MS_STQ_F32(dst_data + 2 * dst_step * out_c + out_c, m[11]); \ + MS_STQ_F32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[12]); \ + MS_STQ_F32(dst_data + 2 * dst_step * out_c + 3 * out_c, m[13]); \ + MS_STQ_F32(dst_data + 2 * dst_step * out_c + 4 * out_c, m[14]); \ + MS_STQ_F32(dst_data + 3 * dst_step * out_c, m[15]); \ + MS_STQ_F32(dst_data + 3 * dst_step * out_c + out_c, m[16]); \ + MS_STQ_F32(dst_data + 3 * dst_step * out_c + 2 * out_c, m[17]); \ + MS_STQ_F32(dst_data + 3 * dst_step * out_c + 3 * out_c, m[18]); \ + MS_STQ_F32(dst_data + 3 * dst_step * out_c + 4 * out_c, m[19]); \ + MS_STQ_F32(dst_data + 4 * dst_step * out_c, m[20]); \ + MS_STQ_F32(dst_data + 4 * dst_step * out_c + out_c, m[21]); \ + MS_STQ_F32(dst_data + 4 * dst_step * out_c + 2 * out_c, m[22]); \ + MS_STQ_F32(dst_data + 4 * dst_step * out_c + 3 * out_c, m[23]); \ + MS_STQ_F32(dst_data + 4 * dst_step * out_c + 4 * out_c, m[24]); + +void OutputTransform4x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x2ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x2Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x3ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x3Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); + +void OutputTransform6x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x2ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x2Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x3ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x3Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x4ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x4Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x5ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x5Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); + +void OutputTransform8x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x2ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x2Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x3ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x3Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x4ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x4Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x5ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x5Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x6ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x6Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x7Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, + int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x7ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x7Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c); + +int SelectOutputUnit(const ConvParameter *conv_param); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_WINOGRAD_UTILS_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/activation_grad_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/activation_grad_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..4f16c29ee39c716b8d84fb2fb7c21d71b437c0ef --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/activation_grad_fp32.c @@ -0,0 +1,161 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "nnacl/op_base.h" +#include "nnacl/fp32/arithmetic_fp32.h" +#include "nnacl/fp32/exp_fp32.h" +#include "nnacl/fp32_grad/activation_grad_fp32.h" +#include "nnacl/errorcode.h" +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/activation_grad_simd.h" + +int ReluGrad(const float *src0, const float *src1, int length, float *dst) { + int i = 0; +#ifdef ENABLE_ARM + float32x4_t zero_4 = vdupq_n_f32(0.0f); + for (; i < length - C4NUM; i += C4NUM) { + float32x4_t src1_4 = vld1q_f32(src1 + i); + float32x4_t src0_4 = vld1q_f32(src0 + i); + uint32x4_t mask_4 = vcleq_f32(src1_4, zero_4); + float32x4_t dst_4 = vbslq_f32(mask_4, zero_4, src0_4); + vst1q_f32(dst + i, dst_4); + } +#endif + for (; i < length; ++i) { + dst[i] = (src1[i] > 0.0f) ? src0[i] : 0.0f; + } + return NNACL_OK; +} + +int Relu6Grad(const float *src0, const float *src1, size_t length, float *dst) { + size_t i = 0; +#ifdef ENABLE_ARM + float32x4_t zero_4 = vdupq_n_f32(0.0f); + float32x4_t six_4 = vdupq_n_f32(6.0f); + for (; i < length - C4NUM; i += C4NUM) { + float32x4_t src1_4 = vld1q_f32(src1 + i); + float32x4_t src0_4 = vld1q_f32(src0 + i); + uint32x4_t gt_4 = vcgtq_f32(src1_4, zero_4); + uint32x4_t le_4 = vcleq_f32(src1_4, six_4); + uint32x4_t mask_4 = vandq_u32(gt_4, le_4); + float32x4_t dst_4 = vbslq_f32(mask_4, src0_4, zero_4); + vst1q_f32(dst + i, dst_4); + } +#endif + for (; i < length; ++i) { + dst[i] = (src1[i] > 0.0f && src1[i] <= 6.0f) ? src0[i] : 0.0f; + } + return NNACL_OK; +} + +int LReluGrad(const float *src0, const float *src1, size_t length, float *dst, float alpha) { + for (size_t i = 0; i < length; ++i) { + dst[i] = src1[i] > 0.0f ? src0[i] : alpha * src0[i]; + } + return NNACL_OK; +} + +int SigmoidGrad(const float *src0, const float *src1, size_t length, float *dst) { + for (size_t i = 0; i < length; ++i) { + dst[i] = src0[i] * (src1[i] * (1.0f - src1[i])); + } + return NNACL_OK; +} + +int TanhGrad(const float *src0, const float *src1, size_t length, float *dst) { + for (size_t i = 0; i < length; ++i) { + dst[i] = (1.0f - (src1[i] * src1[i])) * src0[i]; + } + return NNACL_OK; +} + +int HSwishGrad(const float *src0, const float *src1, size_t length, float *dst) { + for (size_t i = 0; i < length; ++i) { + float tmp = (src1[i] > 3.0f ? 1.0f : (src1[i] < -3.0f ? 0.0f : (2.0f * src1[i] + 3.0f) / 6.0f)); + dst[i] = tmp * src0[i]; + } + return NNACL_OK; +} + +int HSigmoidGrad(const float *src0, const float *src1, size_t length, float *dst) { + for (size_t i = 0; i < length; ++i) { + float tmp = (src1[i] > 3.0f ? 0.0f : (src1[i] < -3.0f ? 0.0f : 1.0f / 6.0f)); + dst[i] = tmp * src0[i]; + } + return NNACL_OK; +} + +int EluGrad(const float *src0, const float *src1, size_t length, float *dst, float alpha) { + for (size_t i = 0; i < length; ++i) { + dst[i] = (src1[i] > 0.0f ? src0[i] : alpha * expm1(src1[i]) * src0[i]); + } + return NNACL_OK; +} + +int GeluGrad(const float *src0, const float *src1, size_t length, float *dst) { + for (size_t i = 0; i < length; ++i) { + dst[i] = src0[i] * ((0.5 * (1.0 + erf(src1[i] / 1.4142135623730951))) + + (src1[i] * exp(-0.5 * src1[i] * src1[i]) / 2.5066282746)); + } + return NNACL_OK; +} + +int SoftplusGrad(const float *src0, const float *src1, int length, float *dst) { + int i = 0; +#if defined(ENABLE_AVX) + for (; i <= length - C8NUM; i += C8NUM) { + simd_exp256(MS_SUB256_F32(MS_MOV256_F32(0.0f), (MS_LD256_F32(src1 + i))), dst + i); + MS_ST256_F32(dst + i, + MS_DIV256_F32(MS_LD256_F32(src0 + i), MS_ADD256_F32(MS_MOV256_F32(1.0f), MS_LD256_F32(dst + i)))); + } +#endif + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + for (; i <= length - C4NUM; i += C4NUM) { + simd_exp128(MS_SUBQ_F32(MS_MOVQ_F32(0.0f), MS_LDQ_F32(src1 + i)), dst + i); + MS_STQ_F32(dst + i, MS_DIVQ_F32(MS_LDQ_F32(src0 + i), MS_ADDQ_F32(MS_MOVQ_F32(1.0f), MS_LDQ_F32(dst + i)))); + } +#endif + + for (; i < length; ++i) { + simd_exp32(-src1[i], dst + i); + dst[i] = src0[i] / (1.0f + dst[i]); + } + return NNACL_OK; +} + +int HardShrinkGrad(const float *src0, const float *src1, int length, float *dst, float lambd) { + int i = 0; + const float neg_lambd = -1 * lambd; + SIMD_RUN_NO_SCALAR(ShrinkGrad, i, src0, src1, length, dst, lambd); + + for (; i < length; ++i) { + dst[i] = (src1[i] >= neg_lambd && src1[i] <= lambd) ? 0 : src0[i]; + } + return NNACL_OK; +} + +int SoftShrinkGrad(const float *src0, const float *src1, int length, float *dst, float lambd) { + int i = 0; + const float neg_lambd = -1 * lambd; + SIMD_RUN_NO_SCALAR(ShrinkGrad, i, src0, src1, length, dst, lambd); + + for (; i < length; ++i) { + dst[i] = (src1[i] >= neg_lambd && src1[i] <= lambd) ? 0 : src0[i]; + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/activation_grad_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/activation_grad_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..1afd2277877a4c0c9d74ef5f52e22b4b66861ea5 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/activation_grad_fp32.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP32_GRAD_ACTIVATION_GRAD_FP32_H_ +#define NNACL_FP32_GRAD_ACTIVATION_GRAD_FP32_H_ + +#include +#include "nnacl/op_base.h" +#include "nnacl/fp32/arithmetic_fp32.h" +#include "nnacl/errorcode.h" + +typedef struct ActivationGradParameter { + OpParameter op_parameter; + int type_; + float alpha_; +} ActivationGradParameter; +#ifdef __cplusplus +extern "C" { +#endif + +int ReluGrad(const float *src0, const float *src1, int length, float *dst); +int Relu6Grad(const float *src0, const float *src1, size_t length, float *dst); +int LReluGrad(const float *src0, const float *src1, size_t length, float *dst, float alpha); +int SigmoidGrad(const float *src0, const float *src1, size_t length, float *dst); +int TanhGrad(const float *src0, const float *src1, size_t length, float *dst); +int HSwishGrad(const float *src0, const float *src1, size_t length, float *dst); +int HSigmoidGrad(const float *src0, const float *src1, size_t length, float *dst); +int EluGrad(const float *src0, const float *src1, size_t length, float *dst, float alpha); +int GeluGrad(const float *src0, const float *src1, size_t length, float *dst); +int SoftplusGrad(const float *src, const float *src1, int length, float *dst); +int HardShrinkGrad(const float *src0, const float *src1, int length, float *dst, float lambd); +int SoftShrinkGrad(const float *src0, const float *src1, int length, float *dst, float lambd); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_GRAD_ACTIVATION_GRAD_FP32_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/activation_grad_simd.h.in b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/activation_grad_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..7e7be9b12cf0189b207fd32da8672383f65b34ca --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/activation_grad_simd.h.in @@ -0,0 +1,50 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_GRAD_ACTIVATION_GRAD_@SIMD_INSTRUCTION@_H_ +#define NNACL_FP32_GRAD_ACTIVATION_GRAD_@SIMD_INSTRUCTION@_H_ + +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int ShrinkGrad@SIMD_INSTRUCTION@(int index, const float *src0, const float *src1, + int length, float *dst, float lambd) { + SIMD_F32 pos_lamdb_v = SIMD_MOV_F32(lambd); + SIMD_F32 neg_lamdb_v = SIMD_MOV_F32(-lambd); + + for (int block_max_size = length - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 src0_t = SIMD_LD_F32(src0 + index); + SIMD_F32 src1_t = SIMD_LD_F32(src1 + index); + + SIMD_MASK mask0 = SIMD_CMPLE_F32(src1_t, pos_lamdb_v); + SIMD_MASK mask1 = SIMD_CMPLE_F32(neg_lamdb_v, src1_t); + SIMD_MASK mask = SIMD_AND_MASK(mask0, mask1); + + SIMD_ST_F32(dst + index, SIMD_BLEND_F32(src0_t, SIMD_MOV_F32(0.0f), mask)); + } + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/src/extendrt/kernel/ascend/api/ascend_kernel_api.cc b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/apply_proximal_adagrad_fp32.c similarity index 41% rename from mindspore-lite/src/extendrt/kernel/ascend/api/ascend_kernel_api.cc rename to mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/apply_proximal_adagrad_fp32.c index 1a3494f35f32916befdfd89876eadd63fba12330..9c33b1612eb6e03323ec01d6a24ccaf76bf76be3 100644 --- a/mindspore-lite/src/extendrt/kernel/ascend/api/ascend_kernel_api.cc +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/apply_proximal_adagrad_fp32.c @@ -13,27 +13,36 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "nnacl/fp32_grad/apply_proximal_adagrad_fp32.h" +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/apply_proximal_adagrad_fp32_simd.h" -#include "extendrt/kernel/ascend/api/ascend_kernel_api.h" -#include "src/common/common.h" -#include "plugin/res_manager/ascend/symbol_interface/symbol_utils.h" - -std::map *CreateCustomAscendKernel() { - CreatorFunc creator_func = []() { return std::make_shared(); }; - std::map *func_map = new (std::nothrow) std::map(); - if (func_map == nullptr) { - MS_LOG(ERROR) << "New custom ascend kernel failed."; - return {}; +int Sign(float x) { + if (x > 0) { + return 1; + } + if (x < 0) { + return -1; } - (*func_map)[mindspore::lite::kNameCustomAscend] = creator_func; - mindspore::device::ascend::LoadAscendApiSymbols(); - return func_map; + return 0; } -void DestroyCustomAscendKernel(std::map *creator_func) { - if (creator_func == nullptr) { - MS_LOG(ERROR) << "Param creator func is nullptr."; - return; +void ApplyProximalAdagradOpt(float *var, float *accum, float lr, float l1, float l2, float *grad, + int64_t input_elements) { + int64_t i = 0; + + SIMD_RUN_NO_SCALAR(ApplyProximalAdagradOpt, i, var, accum, lr, l1, l2, grad, input_elements); + + for (; i < input_elements; ++i) { + accum[i] += grad[i] * grad[i]; + float learning_rate = lr / sqrt(accum[i]); + float prox_v = var[i]; + prox_v -= grad[i] * learning_rate; + + if (l1 > 0) { + var[i] = Sign(prox_v) * fmax(fabs(prox_v) - learning_rate * l1, 0.0) / (1 + l2 * learning_rate); + } else { + var[i] = prox_v / (1 + l2 * learning_rate); + } } - delete creator_func; } diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/apply_proximal_adagrad_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/apply_proximal_adagrad_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..0483b04b84db101b51f75dfc4e394259d254d64e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/apply_proximal_adagrad_fp32.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP32_APPLY_PROXIMAL_ADAGRAD_H_ +#define NNACL_FP32_APPLY_PROXIMAL_ADAGRAD_H_ + +#include "nnacl/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +void ApplyProximalAdagradOpt(float *var, float *accum, float lr, float l1, float l2, float *grad, + int64_t input_elements); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_APPLY_PROXIMAL_ADAGRAD_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/apply_proximal_adagrad_fp32_simd.h.in b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/apply_proximal_adagrad_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..264b1e51c05862a17bbecfa8e44de8b1bd9c2ebf --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/apply_proximal_adagrad_fp32_simd.h.in @@ -0,0 +1,68 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP32_APPLY_PROXIMAL_ADAGRAD_@SIMD_INSTRUCTION@_H_ +#define NNACL_FP32_APPLY_PROXIMAL_ADAGRAD_@SIMD_INSTRUCTION@_H_ + +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int64_t ApplyProximalAdagradOpt@SIMD_INSTRUCTION@( + int64_t index, float *var, float *accum, float lr, float l1, float l2, float *grad, int64_t size) { + SIMD_F32 lr_vec = SIMD_MOV_F32(lr); + SIMD_F32 l1_vec = SIMD_MOV_F32(l1); + SIMD_F32 l2_vec = SIMD_MOV_F32(l2); + for (int64_t block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 tmp_vec1 = SIMD_LD_F32(grad + index); + SIMD_F32 accum_vec = SIMD_LD_F32(accum + index); + SIMD_F32 prox_v_vec = SIMD_LD_F32(var + index); + + accum_vec = SIMD_FMADD_F32(tmp_vec1, tmp_vec1, accum_vec); + SIMD_F32 learn_rate_vec = SIMD_DIV_F32(lr_vec, SIMD_SQRT_F32(accum_vec)); + prox_v_vec = SIMD_SUB_F32(prox_v_vec, SIMD_MUL_F32(tmp_vec1, learn_rate_vec)); + SIMD_ST_F32(accum + index, accum_vec); + tmp_vec1 = SIMD_FMADD_F32(l2_vec, learn_rate_vec, SIMD_MOV_F32(1)); + if (l1 > 0) { + learn_rate_vec = SIMD_MUL_F32(learn_rate_vec, l1_vec); + learn_rate_vec = SIMD_SUB_F32(SIMD_ABS_F32(prox_v_vec), learn_rate_vec); + learn_rate_vec = SIMD_MAX_F32(learn_rate_vec, SIMD_MOV_F32(0.0f)); + learn_rate_vec = SIMD_DIV_F32(learn_rate_vec, tmp_vec1); + + SIMD_MASK greater_mask = SIMD_CMPGT_F32(SIMD_SET0_F32, prox_v_vec); + SIMD_MASK less_mask = SIMD_CMPLT_F32(SIMD_SET0_F32, prox_v_vec); + SIMD_F32 greater_v = SIMD_BLEND_F32(SIMD_MOV_F32(1), SIMD_SET0_F32, greater_mask); + SIMD_F32 less_v = SIMD_BLEND_F32(SIMD_MOV_F32(1), SIMD_SET0_F32, less_mask); + greater_v = SIMD_SUB_F32(greater_v, less_v); + + prox_v_vec = SIMD_MUL_F32(learn_rate_vec, greater_v); + } else { + prox_v_vec = SIMD_DIV_F32(prox_v_vec, tmp_vec1); + } + SIMD_ST_F32(var + index, prox_v_vec); + } + + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif \ No newline at end of file diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/apply_proximal_gradient_descent_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/apply_proximal_gradient_descent_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..bf802aaa506fc08bb2954368540ebb440276f13d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/apply_proximal_gradient_descent_fp32.c @@ -0,0 +1,44 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp32_grad/apply_proximal_gradient_descent_fp32.h" +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/apply_proximal_gradient_descent_fp32_simd.h" + +void ApplyProximalGradientDescentOpt(float *var, float alpha, float l1, float l2, float *delta, + int64_t input_elements) { + int64_t i = 0; + SIMD_RUN_NO_SCALAR(ApplyProximalGradientDescentOpt, i, var, alpha, l1, l2, delta, input_elements); + for (; i < input_elements; ++i) { + float prox_v = var[i]; + prox_v -= delta[i] * alpha; + + if (l1 > 0) { + var[i] = SignFp32(prox_v) * fmax(fabs(prox_v) - alpha * l1, 0.0) / (1 + l2 * alpha); + } else { + var[i] = prox_v / (1 + l2 * alpha); + } + } +} + +float SignFp32(const float x) { + if (x > 0.0) { + return 1.0; + } + if (x < 0.0) { + return -1.0; + } + return 0.0; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/apply_proximal_gradient_descent_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/apply_proximal_gradient_descent_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..9e9b8dcf0989f1e53560be2bdff766194a87b3fc --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/apply_proximal_gradient_descent_fp32.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP32_APPLY_PROXIMAL_GRADIENT_DESCENT_H_ +#define NNACL_FP32_APPLY_PROXIMAL_GRADIENT_DESCENT_H_ + +#include "nnacl/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +void ApplyProximalGradientDescentOpt(float *var, float alpha, float l1, float l2, float *delta, int64_t input_elements); +float SignFp32(const float x); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_APPLY_PROXIMAL_GRADIENT_DESCENT_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/apply_proximal_gradient_descent_fp32_simd.h.in b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/apply_proximal_gradient_descent_fp32_simd.h.in new file mode 100644 index 0000000000000000000000000000000000000000..aec68d7d194a056981ff99a45712db5ee6814b58 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/apply_proximal_gradient_descent_fp32_simd.h.in @@ -0,0 +1,64 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP32_APPLY_PROXIMAL_GRADIENT_DESCENT_@SIMD_INSTRUCTION@_H_ +#define NNACL_FP32_APPLY_PROXIMAL_GRADIENT_DESCENT_@SIMD_INSTRUCTION@_H_ + +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/intrinsics/ms_simd_@SIMD_INSTRUCTION_LOWER@_instructions.h" + +#ifdef __cplusplus +extern "C" { +#endif +@SIMD_INSTRUCTION_BEGIN@ + +static inline int64_t ApplyProximalGradientDescentOpt@SIMD_INSTRUCTION@( + int64_t index, float *var, float alpha, float l1, float l2, float *delta, int64_t size) { + SIMD_F32 alpha_vec = SIMD_MOV_F32(alpha); + SIMD_F32 l1_vec = SIMD_MOV_F32(l1); + SIMD_F32 l2_vec = SIMD_MOV_F32(l2); + for (int64_t block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { + SIMD_F32 delta_vec = SIMD_LD_F32(delta + index); + SIMD_F32 prox_v_vec = SIMD_LD_F32(var + index); + + prox_v_vec = SIMD_SUB_F32(prox_v_vec, SIMD_MUL_F32(delta_vec, alpha_vec)); + SIMD_F32 tmp_vec1 = SIMD_FMADD_F32(l2_vec, alpha_vec, SIMD_MOV_F32(1)); + if (l1 > 0) { + SIMD_F32 tmp_vec2 = SIMD_MUL_F32(alpha_vec, l1_vec); + tmp_vec2 = SIMD_SUB_F32(SIMD_ABS_F32(prox_v_vec), tmp_vec2); + tmp_vec2 = SIMD_MAX_F32(tmp_vec2, SIMD_MOV_F32(0.0f)); + tmp_vec2 = SIMD_DIV_F32(tmp_vec2, tmp_vec1); + + SIMD_MASK greater_mask = SIMD_CMPGT_F32(SIMD_SET0_F32, prox_v_vec); + SIMD_MASK less_mask = SIMD_CMPLT_F32(SIMD_SET0_F32, prox_v_vec); + SIMD_F32 greater_v = SIMD_BLEND_F32(SIMD_MOV_F32(1), SIMD_SET0_F32, greater_mask); + SIMD_F32 less_v = SIMD_BLEND_F32(SIMD_MOV_F32(1), SIMD_SET0_F32, less_mask); + greater_v = SIMD_SUB_F32(greater_v, less_v); + + prox_v_vec = SIMD_MUL_F32(tmp_vec2, greater_v); + } else { + prox_v_vec = SIMD_DIV_F32(prox_v_vec, tmp_vec1); + } + SIMD_ST_F32(var + index, prox_v_vec); + } + + return index; +} + +@SIMD_INSTRUCTION_END@ +#ifdef __cplusplus +} +#endif +#endif \ No newline at end of file diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/arithmetic_grad.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/arithmetic_grad.c new file mode 100644 index 0000000000000000000000000000000000000000..47a21c3bfd2d685786e436cf3db36df46fd0ade5 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/arithmetic_grad.c @@ -0,0 +1,154 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32_grad/arithmetic_grad.h" +#include +#include +#include "nnacl/fp32_grad/utils.h" +#include "nnacl/errorcode.h" + +void ElementDivNegSquare(const float *nom, const float *denom, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = -nom[i] / (denom[i] * denom[i]); + } +} + +void ElementMulAndDivNegSquare(const float *a, const float *b, const float *denom, float *output, int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = -a[i] * b[i] / (denom[i] * denom[i]); + } +} + +int ElementAbsGrad(const float *in1, const float *in2, float *out, int element_size) { + for (int i = 0; i < element_size; i++) { + out[i] = (in1[i] < 0.f) ? -in2[i] : ((in1[i] > 0.f) ? in2[i] : 0); + } + return NNACL_OK; +} + +void MaximumByAxes(const float *input0, const float *input1, const float *dy, const int *input0_dims, + const int *input1_dims, const int *dy_dims, float *output0, float *output1, int num_dims) { + int num_output0 = 1; + int num_output1 = 1; + bool same_shape = true; + for (int idx = 0; idx < num_dims; ++idx) { + num_output0 *= input0_dims[idx]; + num_output1 *= input1_dims[idx]; + if (input0_dims[idx] != input1_dims[idx]) { + same_shape = false; + } + } + + if (same_shape) { + int input_iter[C8NUM] = {0}; + + // Iterate through input_data. + do { + size_t offset = GetInputOffset(num_dims, input0_dims, input_iter); + output0[offset] = input0[offset] > input1[offset] ? dy[offset] : 0.; + output1[offset] = input1[offset] >= input0[offset] ? dy[offset] : 0.; + } while (NextIndex(num_dims, input0_dims, input_iter)); + } else { + memset(output0, 0, num_output0 * sizeof(float)); // zero output + memset(output1, 0, num_output1 * sizeof(float)); // zero output + + int input_iter[C8NUM] = {0}; + int axes0[C5NUM] = {0}; + int axes1[C5NUM] = {0}; + int num_axes0 = 0; + int num_axes1 = 0; + for (int i = 0; i < num_dims; i++) { + if (input0_dims[i] == 1 && num_axes0 < C5NUM) { + axes0[num_axes0++] = i; + } + if (input1_dims[i] == 1 && num_axes1 < C5NUM) { + axes1[num_axes1++] = i; + } + } + + do { + size_t offset0 = GetOutputOffset(num_dims, input0_dims, input_iter, num_axes0, axes0); + size_t offset1 = GetOutputOffset(num_dims, input1_dims, input_iter, num_axes1, axes1); + size_t yt_offset = GetInputOffset(num_dims, input0_dims, input_iter); + output0[offset0] += input0[offset0] > input1[offset1] ? dy[yt_offset] : 0.; + output1[offset1] += input1[offset1] >= input0[offset0] ? dy[yt_offset] : 0.; + } while (NextIndex(num_dims, dy_dims, input_iter)); + } +} + +void MinimumByAxes(const float *input0, const float *input1, const float *dy, const int *input0_dims, + const int *input1_dims, const int *dy_dims, float *output0, float *output1, int num_dims) { + int num_output0 = 1; + int num_output1 = 1; + bool same_shape = true; + for (int idx = 0; idx < num_dims; ++idx) { + num_output0 *= input0_dims[idx]; + num_output1 *= input1_dims[idx]; + if (input0_dims[idx] != input1_dims[idx]) { + same_shape = false; + } + } + + if (same_shape) { + int input_iter[C8NUM] = {0}; + + // Iterate through input_data. + do { + size_t offset = GetInputOffset(num_dims, input0_dims, input_iter); + output0[offset] = input0[offset] < input1[offset] ? dy[offset] : 0.; + output1[offset] = input1[offset] <= input0[offset] ? dy[offset] : 0.; + } while (NextIndex(num_dims, input0_dims, input_iter)); + } else { + memset(output0, 0, num_output0 * sizeof(float)); // zero output + memset(output1, 0, num_output1 * sizeof(float)); // zero output + + int input_iter[C8NUM] = {0}; + int axes0[C5NUM] = {0}; + int axes1[C5NUM] = {0}; + int num_axes0 = 0; + int num_axes1 = 0; + for (int i = 0; i < num_dims; i++) { + if (input0_dims[i] == 1 && num_axes0 < C5NUM) { + axes0[num_axes0++] = i; + } + if (input1_dims[i] == 1 && num_axes1 < C5NUM) { + axes1[num_axes1++] = i; + } + } + + do { + size_t offset0 = GetOutputOffset(num_dims, input0_dims, input_iter, num_axes0, axes0); + size_t offset1 = GetOutputOffset(num_dims, input1_dims, input_iter, num_axes1, axes1); + size_t yt_offset = GetInputOffset(num_dims, input0_dims, input_iter); + output0[offset0] += input0[offset0] < input1[offset1] ? dy[yt_offset] : 0.; + output1[offset1] += input1[offset1] <= input0[offset0] ? dy[yt_offset] : 0.; + } while (NextIndex(num_dims, dy_dims, input_iter)); + } +} + +int ElementSqrtGrad(const float *in1, const float *in2, float *out, const int element_size) { + for (int i = 0; i < element_size; i++) { + out[i] = 0.5f * in2[i] / in1[i]; + } + return NNACL_OK; +} + +int ElementRsqrtGrad(const float *in1, const float *in2, float *out, const int element_size) { + for (int i = 0; i < element_size; i++) { + out[i] = -0.5f * in2[i] * in1[i] * in1[1] * in1[i]; + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/arithmetic_grad.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/arithmetic_grad.h new file mode 100644 index 0000000000000000000000000000000000000000..f4408dd82aef09ed94468a2dadf4f5e315f79f58 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/arithmetic_grad.h @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP32_GRAD_ARITHMETIC_GRAD_H_ +#define NNACL_FP32_GRAD_ARITHMETIC_GRAD_H_ + +#include "nnacl/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +void ElementDivNegSquare(const float *nom, const float *denom, float *output, int element_size); +void ElementMulAndDivNegSquare(const float *a, const float *b, const float *denom, float *output, int element_size); +int ElementAbsGrad(const float *in1, const float *in2, float *out, int element_size); +void MaximumByAxes(const float *input0, const float *input1, const float *dy, const int *input0_dims, + const int *input1_dims, const int *dy_dims, float *output0, float *output1, int num_dims); +void MinimumByAxes(const float *input0, const float *input1, const float *dy, const int *input0_dims, + const int *input1_dims, const int *dy_dims, float *output0, float *output1, int num_dims); +int ElementSqrtGrad(const float *in1, const float *in2, float *out, const int element_size); +int ElementRsqrtGrad(const float *in1, const float *in2, float *out, const int element_size); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_GRAD_ARITHMETIC_GRAD_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/batch_norm_grad.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/batch_norm_grad.c new file mode 100644 index 0000000000000000000000000000000000000000..22aaac2893cafde28ac3c994aae422480abb7f37 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/batch_norm_grad.c @@ -0,0 +1,100 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "nnacl/fp32_grad/batch_norm_grad.h" + +void var2Invar(float *save_var, int size, float eps) { + for (int i = 0; i < size; i++) { + save_var[i] = 1.0f / sqrtf(save_var[i] + eps); + } +} + +static void backwardComputeDx(const float *in, const float *yt, const float *mean, const float *invar, + const float *scale, int size, int ch, const float *dbias, const float *dscale, float *dx, + float N, bool is_train) { + for (int i = 0; i < size; i++) { + for (int c = 0; c < ch; c++) { + // dx_2 + int ix = i * ch + c; + dx[ix] = yt[ix]; + if (is_train) { + dx[ix] -= dbias[c] / N + (in[ix] - mean[c]) * dscale[c] * invar[c] / N; + } + dx[ix] *= scale[c] * invar[c]; + } + } +} + +#ifdef _MSC_VER +void backwardAll(const float *in, const float *yt, const float *mean, const float *invar, const float *scale, int size, + int ch, float *dbias, float *dscale, float *dx, bool is_train) { +#else +void backwardAll(const float *restrict in, const float *restrict yt, const float *restrict mean, + const float *restrict invar, const float *restrict scale, int size, int ch, float *restrict dbias, + float *restrict dscale, float *restrict dx, bool is_train) { +#endif + NNACL_CHECK_ZERO_RETURN(size); + float N = (float)size; + for (int i = 0; i < size; i++) { + for (int c = 0; c < ch; c++) { + int ix = i * ch + c; + dbias[c] += yt[ix]; + // in fact, x_hat should also mul invar[c]. now put this step to the end. + float x_hat = in[ix] - mean[c]; + dscale[c] += (yt[ix] * x_hat); + } + } + for (int c = 0; c < ch; c++) { + dscale[c] *= invar[c]; + } + backwardComputeDx(in, yt, mean, invar, scale, size, ch, dbias, dscale, dx, N, is_train); +} + +#ifdef _MSC_VER +void backwardP1(const float *in, const float *yt, const float *mean, const float *invar, const float *scale, int size, + int ch, float *dbias, float *dscale) { +#else +void backwardP1(const float *restrict in, const float *restrict yt, const float *restrict mean, + const float *restrict invar, const float *restrict scale, int size, int ch, float *restrict dbias, + float *restrict dscale) { +#endif + for (int i = 0; i < size; i++) { + for (int c = 0; c < ch; c++) { + int ix = i * ch + c; + dbias[c] += yt[ix]; + // in fact, x_hat should also mul invar[c]. now put this step to the end. + float x_hat = in[ix] - mean[c]; + dscale[c] += (yt[ix] * x_hat); + } + } + for (int c = 0; c < ch; c++) { + dscale[c] *= invar[c]; + } +} + +#ifdef _MSC_VER +void backwardP2(const float *in, const float *yt, const float *mean, const float *invar, const float *dscale, + const float *dbias, const float *scale, int size, int total_size, int ch, float *dx, bool is_train) { +#else +void backwardP2(const float *restrict in, const float *restrict yt, const float *restrict mean, + const float *restrict invar, const float *restrict dscale, const float *restrict dbias, + const float *restrict scale, int size, int total_size, int ch, float *restrict dx, bool is_train) { +#endif + NNACL_CHECK_ZERO_RETURN(total_size); + const float N = (float)total_size; + backwardComputeDx(in, yt, mean, invar, scale, size, ch, dbias, dscale, dx, N, is_train); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/batch_norm_grad.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/batch_norm_grad.h new file mode 100644 index 0000000000000000000000000000000000000000..24f22a20a00698b38c235fa9a2ffa8d5756d2853 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/batch_norm_grad.h @@ -0,0 +1,37 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_GRAD_BATCH_NORM_H_ +#define CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_GRAD_BATCH_NORM_H_ + +#include "nnacl/fp32_grad/batch_norm_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void var2Invar(float *save_var, int size, float eps); +void backwardAll(const float *in, const float *yt, const float *mean, const float *invar, const float *scale, int size, + int ch, float *dbias, float *dscale, float *dx, bool is_train); +void backwardP1(const float *in, const float *yt, const float *mean, const float *invar, const float *scale, int size, + int ch, float *dbias, float *dscale); +void backwardP2(const float *in, const float *yt, const float *mean, const float *invar, const float *dscale, + const float *dbias, const float *scale, int size, int total_size, int ch, float *dx, bool is_train); +#ifdef __cplusplus +} +#endif + +#endif // CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_FP32_GRAD_BATCH_NORM_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/batch_norm_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/batch_norm_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..7998d632c695419129d495e049494ddd08a79cec --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/batch_norm_parameter.h @@ -0,0 +1,28 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_GRAD_BATCH_NORM_PARAMETER_H_ +#define NNACL_FP32_GRAD_BATCH_NORM_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct BNGradParameter { + OpParameter op_parameter_; + float epsilon_; + bool is_training_; +} BNGradParameter; + +#endif // NNACL_FP32_GRAD_BATCH_NORM_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/binary_cross_entropy.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/binary_cross_entropy.c new file mode 100644 index 0000000000000000000000000000000000000000..cf2f867c82e5abce4f7a85709cbb938b984be9fa --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/binary_cross_entropy.c @@ -0,0 +1,75 @@ +/* + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "nnacl/fp32_grad/binary_cross_entropy.h" + +static void BinaryCrossEntropyLossKernel(const int input_size, const int reduction, const float *input_x, + const float *input_y, const float *weight, float *loss, float *tmp_loss, + bool weight_defined) { + const float epsilon = 1e-12; + + if (reduction == Reduction_None) { + if (weight_defined) { + for (int i = 0; i < input_size; i++) { + float value = + -weight[i] * (input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon)); + loss[i] = value; + } + } else { + for (int i = 0; i < input_size; i++) { + float value = -(input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon)); + loss[i] = value; + } + } + } else { + if (weight_defined) { + for (int i = 0; i < input_size; i++) { + float value = + -weight[i] * (input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon)); + tmp_loss[i] = value; + } + } else { + for (int i = 0; i < input_size; i++) { + float value = -(input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon)); + tmp_loss[i] = value; + } + } + } +} + +void BinaryCrossEntropy(const int input_size, const int reduction, const float *input_x, const float *input_y, + const float *weight, float *loss, float *tmp_loss, bool weight_defined) { + loss[0] = 0.0f; + BinaryCrossEntropyLossKernel(input_size, reduction, input_x, input_y, weight, loss, tmp_loss, weight_defined); + if (reduction != Reduction_None) { + if (input_size % 2 == 1) { + tmp_loss[0] += tmp_loss[input_size - 1]; + } + for (int stride = input_size / 2; stride > 0; stride = stride / 2) { + for (int i = 0; i < stride; i++) { + tmp_loss[i] += tmp_loss[i + stride]; + } + if (stride > 2 && stride % 2 == 1) { + tmp_loss[0] += tmp_loss[stride - 1]; + } + } + loss[0] += tmp_loss[0]; + if (reduction == Reduction_Mean) { + loss[0] /= input_size; + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/binary_cross_entropy.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/binary_cross_entropy.h new file mode 100644 index 0000000000000000000000000000000000000000..b02b87029e7d7557e0fa7ed961d38dec9cfdcd26 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/binary_cross_entropy.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_BINARY_CROSS_ENTROPY_H_ +#define NNACL_BINARY_CROSS_ENTROPY_H_ + +#include "nnacl/op_base.h" + +typedef struct { + OpParameter op_parameter_; + int reduction; +} BinaryCrossEntropyParameter; + +#ifdef __cplusplus +extern "C" { +#endif + +void BinaryCrossEntropy(const int input_size, const int reduction, const float *input_x, const float *input_y, + const float *weight, float *loss, float *tmp_loss, bool weight_defined); + +#ifdef __cplusplus +} +#endif +#endif // NNACL_BINARY_CROSS_ENTROPY_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/binary_cross_entropy_grad.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/binary_cross_entropy_grad.c new file mode 100644 index 0000000000000000000000000000000000000000..12d203563c1fe9b9dd0bbf9d2881345a14dc3746 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/binary_cross_entropy_grad.c @@ -0,0 +1,56 @@ +/* + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32_grad/binary_cross_entropy_grad.h" + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) + +int BinaryCrossEntropyGrad(const int input_size, const int reduction, const float *input_x, const float *input_y, + const float *weight, const float *dloss, float *dx, bool weight_defined) { + const float epsilon = 1e-12f; + if (reduction == Reduction_None) { + if (weight_defined) { + for (int i = 0; i < input_size; i++) { + float denominator = MAX(input_x[i] * (1 - input_x[i]), epsilon); + float value = weight[i] * (input_x[i] - input_y[i]) / denominator; + dx[i] = value * dloss[i]; + } + } else { + for (int i = 0; i < input_size; i++) { + float denominator = MAX(input_x[i] * (1 - input_x[i]), epsilon); + float value = (input_x[i] - input_y[i]) / denominator; + dx[i] = value * dloss[i]; + } + } + } else { + float dloss1 = dloss[0]; + if (reduction == Reduction_Mean) { + dloss1 = dloss[0] / input_size; + } + for (int i = 0; i < input_size; i++) { + if (weight_defined) { + float denominator = MAX(input_x[i] * (1 - input_x[i]), epsilon); + float value = weight[i] * (input_x[i] - input_y[i]) / denominator; + dx[i] = value * dloss1; + } else { + float denominator = MAX(input_x[i] * (1 - input_x[i]), epsilon); + float value = (input_x[i] - input_y[i]) / denominator; + dx[i] = value * dloss1; + } + } + } + return 0; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/binary_cross_entropy_grad.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/binary_cross_entropy_grad.h new file mode 100644 index 0000000000000000000000000000000000000000..7b5a3f1f450d9c91cb2ea7a9962b6a6f09d23ca0 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/binary_cross_entropy_grad.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_BINARY_CROSS_ENTROPY_GRAD_H_ +#define NNACL_BINARY_CROSS_ENTROPY_GRAD_H_ + +#include "nnacl/op_base.h" + +typedef struct { + OpParameter op_parameter_; + int reduction; +} BinaryCrossEntropyGradParameter; + +#ifdef __cplusplus +extern "C" { +#endif + +int BinaryCrossEntropyGrad(const int input_size, const int reduction, const float *input_x, const float *input_y, + const float *weight, const float *dloss, float *dx, bool weight_defined); + +#ifdef __cplusplus +} +#endif +#endif // NNACL_BINARY_CROSS_ENTROPY_GRAD_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/convolution_grad_filter.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/convolution_grad_filter.c new file mode 100644 index 0000000000000000000000000000000000000000..8003b88cc62d71c1e0360ea3ea07e5d90dddb490 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/convolution_grad_filter.c @@ -0,0 +1,380 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32_grad/convolution_grad_filter.h" +#include "nnacl/errorcode.h" +#ifdef ENABLE_ARM +#include +#endif + +#ifdef ENABLE_ARM +static int FilterGrad16Arm(const float *x, const float *dy, int i_c, int k_idx, float *dw, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int x_size = in_h * in_w * in_ch; + int y_size = out_ch * out_h * out_w; + int k_spatial = k_w * k_h; + int i_kh = k_idx / k_w; + int i_kw = k_idx % k_w; + for (; i_c < (out_ch & ~15); i_c += 16) { + float32x4_t sum_03_4 = vdupq_n_f32(0.0f); + float32x4_t sum_47_4 = vdupq_n_f32(0.0f); + float32x4_t sum_9x_4 = vdupq_n_f32(0.0f); + float32x4_t sum_12x_4 = vdupq_n_f32(0.0f); + for (int b = 0; b < batch; ++b) { + const float *x_addr = &x[b * x_size]; + const float *dy_addr = &dy[b * y_size]; + for (int i = 0; i < m; i++) { + int idx = i; + int input_h = idx / out_w * conv_param->stride_h_; + int input_w = idx % out_w * conv_param->stride_w_; + int input_row = -conv_param->pad_u_ + i_kh + input_h; + int input_col = -conv_param->pad_l_ + i_kw + input_w; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset_x = (input_row * in_w + input_col) * out_ch + i_c; + int offset_dy = idx * out_ch + i_c; + + float32x4_t x_03_4 = vld1q_f32(x_addr + offset_x); + float32x4_t dy_03_4 = vld1q_f32(dy_addr + offset_dy); + sum_03_4 = vmlaq_f32(sum_03_4, x_03_4, dy_03_4); + + float32x4_t x_47_4 = vld1q_f32(x_addr + offset_x + 4); + float32x4_t dy_47_4 = vld1q_f32(dy_addr + offset_dy + 4); + sum_47_4 = vmlaq_f32(sum_47_4, x_47_4, dy_47_4); + + float32x4_t x_9x_4 = vld1q_f32(x_addr + offset_x + 8); + float32x4_t dy_9x_4 = vld1q_f32(dy_addr + offset_dy + 8); + sum_9x_4 = vmlaq_f32(sum_9x_4, x_9x_4, dy_9x_4); + + float32x4_t x_12x_4 = vld1q_f32(x_addr + offset_x + 12); + float32x4_t dy_12x_4 = vld1q_f32(dy_addr + offset_dy + 12); + sum_12x_4 = vmlaq_f32(sum_12x_4, x_12x_4, dy_12x_4); + } + } + } + dw[(i_c + 0) * k_spatial + k_idx] = sum_03_4[0]; + dw[(i_c + 1) * k_spatial + k_idx] = sum_03_4[1]; + dw[(i_c + 2) * k_spatial + k_idx] = sum_03_4[2]; + dw[(i_c + 3) * k_spatial + k_idx] = sum_03_4[3]; + + dw[(i_c + 4) * k_spatial + k_idx] = sum_47_4[0]; + dw[(i_c + 5) * k_spatial + k_idx] = sum_47_4[1]; + dw[(i_c + 6) * k_spatial + k_idx] = sum_47_4[2]; + dw[(i_c + 7) * k_spatial + k_idx] = sum_47_4[3]; + + dw[(i_c + 8) * k_spatial + k_idx] = sum_9x_4[0]; + dw[(i_c + 9) * k_spatial + k_idx] = sum_9x_4[1]; + dw[(i_c + 10) * k_spatial + k_idx] = sum_9x_4[2]; + dw[(i_c + 11) * k_spatial + k_idx] = sum_9x_4[3]; + + dw[(i_c + 12) * k_spatial + k_idx] = sum_12x_4[0]; + dw[(i_c + 13) * k_spatial + k_idx] = sum_12x_4[1]; + dw[(i_c + 14) * k_spatial + k_idx] = sum_12x_4[2]; + dw[(i_c + 15) * k_spatial + k_idx] = sum_12x_4[3]; + } + return i_c; +} + +static int FilterGrad12Arm(const float *x, const float *dy, int i_c, int k_idx, float *dw, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int x_size = in_h * in_w * in_ch; + int y_size = out_ch * out_h * out_w; + int k_spatial = k_w * k_h; + int i_kh = k_idx / k_w; + int i_kw = k_idx % k_w; + if ((out_ch - i_c) >= 12) { + float32x4_t sum_03_4 = vdupq_n_f32(0.0f); + float32x4_t sum_47_4 = vdupq_n_f32(0.0f); + float32x4_t sum_9x_4 = vdupq_n_f32(0.0f); + for (int b = 0; b < batch; ++b) { + const float *x_addr = &x[b * x_size]; + const float *dy_addr = &dy[b * y_size]; + + for (int i = 0; i < m; i++) { + int idx = i; + int input_h = idx / out_w * conv_param->stride_h_; + int input_w = idx % out_w * conv_param->stride_w_; + int input_row = -conv_param->pad_u_ + i_kh + input_h; + int input_col = -conv_param->pad_l_ + i_kw + input_w; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset_x = (input_row * in_w + input_col) * out_ch + i_c; + int offset_dy = idx * out_ch + i_c; + + float32x4_t x_03_4 = vld1q_f32(x_addr + offset_x); + float32x4_t dy_03_4 = vld1q_f32(dy_addr + offset_dy); + sum_03_4 = vmlaq_f32(sum_03_4, x_03_4, dy_03_4); + + float32x4_t x_47_4 = vld1q_f32(x_addr + offset_x + 4); + float32x4_t dy_47_4 = vld1q_f32(dy_addr + offset_dy + 4); + sum_47_4 = vmlaq_f32(sum_47_4, x_47_4, dy_47_4); + + float32x4_t x_9x_4 = vld1q_f32(x_addr + offset_x + 8); + float32x4_t dy_9x_4 = vld1q_f32(dy_addr + offset_dy + 8); + sum_9x_4 = vmlaq_f32(sum_9x_4, x_9x_4, dy_9x_4); + } + } + } + dw[(i_c + 0) * k_spatial + k_idx] = sum_03_4[0]; + dw[(i_c + 1) * k_spatial + k_idx] = sum_03_4[1]; + dw[(i_c + 2) * k_spatial + k_idx] = sum_03_4[2]; + dw[(i_c + 3) * k_spatial + k_idx] = sum_03_4[3]; + + dw[(i_c + 4) * k_spatial + k_idx] = sum_47_4[0]; + dw[(i_c + 5) * k_spatial + k_idx] = sum_47_4[1]; + dw[(i_c + 6) * k_spatial + k_idx] = sum_47_4[2]; + dw[(i_c + 7) * k_spatial + k_idx] = sum_47_4[3]; + + dw[(i_c + 8) * k_spatial + k_idx] = sum_9x_4[0]; + dw[(i_c + 9) * k_spatial + k_idx] = sum_9x_4[1]; + dw[(i_c + 10) * k_spatial + k_idx] = sum_9x_4[2]; + dw[(i_c + 11) * k_spatial + k_idx] = sum_9x_4[3]; + + i_c += 12; + } + return i_c; +} + +static int FilterGrad8Arm(const float *x, const float *dy, int i_c, int k_idx, float *dw, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int x_size = in_h * in_w * in_ch; + int y_size = out_ch * out_h * out_w; + int k_spatial = k_w * k_h; + int i_kh = k_idx / k_w; + int i_kw = k_idx % k_w; + + if ((out_ch - i_c) >= 8) { + float32x4_t sum_03_4 = vdupq_n_f32(0.0f); + float32x4_t sum_47_4 = vdupq_n_f32(0.0f); + for (int b = 0; b < batch; ++b) { + const float *x_addr = &x[b * x_size]; + const float *dy_addr = &dy[b * y_size]; + + for (int i = 0; i < m; i++) { + int idx = i; + int input_h = idx / out_w * conv_param->stride_h_; + int input_w = idx % out_w * conv_param->stride_w_; + int input_row = -conv_param->pad_u_ + i_kh + input_h; + int input_col = -conv_param->pad_l_ + i_kw + input_w; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset_x = (input_row * in_w + input_col) * out_ch + i_c; + int offset_dy = idx * out_ch + i_c; + + float32x4_t x_03_4 = vld1q_f32(x_addr + offset_x); + float32x4_t dy_03_4 = vld1q_f32(dy_addr + offset_dy); + sum_03_4 = vmlaq_f32(sum_03_4, x_03_4, dy_03_4); + + float32x4_t x_47_4 = vld1q_f32(x_addr + offset_x + 4); + float32x4_t dy_47_4 = vld1q_f32(dy_addr + offset_dy + 4); + sum_47_4 = vmlaq_f32(sum_47_4, x_47_4, dy_47_4); + } + } + } + dw[(i_c + 0) * k_spatial + k_idx] = sum_03_4[0]; + dw[(i_c + 1) * k_spatial + k_idx] = sum_03_4[1]; + dw[(i_c + 2) * k_spatial + k_idx] = sum_03_4[2]; + dw[(i_c + 3) * k_spatial + k_idx] = sum_03_4[3]; + + dw[(i_c + 4) * k_spatial + k_idx] = sum_47_4[0]; + dw[(i_c + 5) * k_spatial + k_idx] = sum_47_4[1]; + dw[(i_c + 6) * k_spatial + k_idx] = sum_47_4[2]; + dw[(i_c + 7) * k_spatial + k_idx] = sum_47_4[3]; + i_c += 8; + } + return i_c; +} +static int FilterGrad4Arm(const float *x, const float *dy, int i_c, int k_idx, float *dw, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int x_size = in_h * in_w * in_ch; + int y_size = out_ch * out_h * out_w; + int k_spatial = k_w * k_h; + int i_kh = k_idx / k_w; + int i_kw = k_idx % k_w; + if ((out_ch - i_c) >= 4) { + float32x4_t sum_4 = vdupq_n_f32(0.0f); + + for (int b = 0; b < batch; ++b) { + const float *x_addr = &x[b * x_size]; + const float *dy_addr = &dy[b * y_size]; + + for (int i = 0; i < m; i++) { + int idx = i; + int input_h = idx / out_w * conv_param->stride_h_; + int input_w = idx % out_w * conv_param->stride_w_; + int input_row = -conv_param->pad_u_ + i_kh + input_h; + int input_col = -conv_param->pad_l_ + i_kw + input_w; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset_x = (input_row * in_w + input_col) * out_ch + i_c; + int offset_dy = idx * out_ch + i_c; + + float32x4_t x_4 = vld1q_f32(x_addr + offset_x); + float32x4_t dy_4 = vld1q_f32(dy_addr + offset_dy); + sum_4 = vmlaq_f32(sum_4, x_4, dy_4); + } + } + } + dw[(i_c + 0) * k_spatial + k_idx] = sum_4[0]; + dw[(i_c + 1) * k_spatial + k_idx] = sum_4[1]; + dw[(i_c + 2) * k_spatial + k_idx] = sum_4[2]; + dw[(i_c + 3) * k_spatial + k_idx] = sum_4[3]; + i_c += 4; + } + return i_c; +} + +static int Filtergrad2Arm(const float *x, const float *dy, int i_c, int k_idx, float *dw, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int x_size = in_h * in_w * in_ch; + int y_size = out_ch * out_h * out_w; + int k_spatial = k_w * k_h; + int i_kh = k_idx / k_w; + int i_kw = k_idx % k_w; + + if ((out_ch - i_c) >= 2) { + float32x2_t sum_2 = vdup_n_f32(0.0f); + for (int b = 0; b < batch; ++b) { + const float *x_addr = &x[b * x_size]; + const float *dy_addr = &dy[b * y_size]; + + for (int i = 0; i < m; i++) { + int idx = i; + int input_h = idx / out_w * conv_param->stride_h_; + int input_w = idx % out_w * conv_param->stride_w_; + int input_row = -conv_param->pad_u_ + i_kh + input_h; + int input_col = -conv_param->pad_l_ + i_kw + input_w; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset_x = (input_row * in_w + input_col) * out_ch + i_c; + int offset_dy = idx * out_ch + i_c; + + float32x2_t x_4 = vld1_f32(x_addr + offset_x); + float32x2_t dy_4 = vld1_f32(dy_addr + offset_dy); + sum_2 = vmla_f32(sum_2, x_4, dy_4); + } + } + } + dw[(i_c + 0) * k_spatial + k_idx] = sum_2[0]; + dw[(i_c + 1) * k_spatial + k_idx] = sum_2[1]; + i_c += 2; + } + return i_c; +} +#endif +int ConvDwFilterGrad(const float *x, const float *dy, float *dw, int start, int count, + const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int batch = conv_param->output_batch_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_h = conv_param->output_h_; + int out_w = conv_param->output_w_; + + int m = out_h * out_w; + int x_size = in_h * in_w * in_ch; + int y_size = out_ch * out_h * out_w; + int k_spatial = k_w * k_h; + + for (int i_k = 0; i_k < count; i_k++) { + int k_idx = start + i_k; + int i_kh = k_idx / k_w; + int i_kw = k_idx % k_w; + int i_c = 0; +#ifdef ENABLE_ARM + i_c = FilterGrad16Arm(x, dy, i_c, k_idx, dw, conv_param); + i_c = FilterGrad12Arm(x, dy, i_c, k_idx, dw, conv_param); + i_c = FilterGrad8Arm(x, dy, i_c, k_idx, dw, conv_param); + i_c = FilterGrad4Arm(x, dy, i_c, k_idx, dw, conv_param); + i_c = Filtergrad2Arm(x, dy, i_c, k_idx, dw, conv_param); +#endif + for (; i_c < out_ch; i_c++) { + float sum = 0; + for (int b = 0; b < batch; ++b) { + const float *x_addr = &x[b * x_size]; + const float *dy_addr = &dy[b * y_size]; + + for (int i = 0; i < m; i++) { + int idx = i; + int input_h = idx / out_w * conv_param->stride_h_; + int input_w = idx % out_w * conv_param->stride_w_; + int input_row = -conv_param->pad_u_ + i_kh + input_h; + int input_col = -conv_param->pad_l_ + i_kw + input_w; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset_x = (input_row * in_w + input_col) * out_ch + i_c; + int offset_dy = idx * out_ch + i_c; + sum += x_addr[offset_x] * dy_addr[offset_dy]; + } + } + } + dw[i_c * k_spatial + k_idx] = sum; + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/convolution_grad_filter.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/convolution_grad_filter.h new file mode 100644 index 0000000000000000000000000000000000000000..40ec1700b26b8a735bf79a0738419a49f767ce5c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/convolution_grad_filter.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_GRAD_CONVOLUTION_GRAD_FILTER_H_ +#define NNACL_FP32_GRAD_CONVOLUTION_GRAD_FILTER_H_ + +#include +#include "nnacl/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ConvDwFilterGrad(const float *x, const float *dy, float *dw, int start, int count, const ConvParameter *conv_param); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_GRAD_CONVOLUTION_GRAD_FILTER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/convolution_grad_input.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/convolution_grad_input.c new file mode 100644 index 0000000000000000000000000000000000000000..6f959b05054aa8249f24c9bfde81ff88354dea2a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/convolution_grad_input.c @@ -0,0 +1,100 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32_grad/convolution_grad_input.h" +#include "nnacl/errorcode.h" +#include "nnacl/op_base.h" +#ifdef ENABLE_ARM +#include +#endif + +int ConvDwInputGrad(const float *dy, const float *w, float *dx, int start, int count, const ConvParameter *conv_param) { + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int out_w = conv_param->output_w_; + int out_ch = conv_param->output_channel_; + int in_ch = conv_param->input_channel_; + int out_spatial = conv_param->output_h_ * conv_param->output_w_; + int k_h = conv_param->kernel_h_; + int k_w = conv_param->kernel_w_; + int k_spatial = k_h * k_w; + int end = start + count; + + int j = start; + for (; j <= (end - C4NUM); j += C4NUM) { + float *c = dx + j; + const float *mat_b_0 = w + (j + C0NUM) * k_spatial; + const float *mat_b_1 = w + (j + C1NUM) * k_spatial; + const float *mat_b_2 = w + (j + C2NUM) * k_spatial; + const float *mat_b_3 = w + (j + C3NUM) * k_spatial; + + for (int si = 0; si < out_spatial; si++) { + const float *a = dy + j + si * out_ch; +#ifdef ENABLE_ARM + float32x4_t mat_a = vld1q_f32(a); +#else + float mat_a[C4NUM] = {a[C0NUM], a[C1NUM], a[C2NUM], a[C3NUM]}; +#endif + int output_row = (si) / out_w; + int output_col = (si) % out_w; + for (int k = 0; k < k_spatial; k++) { + int row_stride_offset = output_row * conv_param->stride_h_; + int col_stride_offset = output_col * conv_param->stride_w_; + int kernel_row = k / k_w; + int kernel_col = k % k_w; + int input_row = -conv_param->pad_u_ + kernel_row * conv_param->dilation_h_ + row_stride_offset; + int input_col = -conv_param->pad_l_ + kernel_col * conv_param->dilation_w_ + col_stride_offset; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset = (input_row * in_w + input_col) * in_ch; +#ifdef ENABLE_ARM + float32x4_t mat_b = {mat_b_0[k], mat_b_1[k], mat_b_2[k], mat_b_3[k]}; + float32x4_t mat_c = vld1q_f32(c + offset); + mat_c = vmlaq_f32(mat_c, mat_b, mat_a); + vst1q_f32(c + offset, mat_c); +#else + c[offset + C0NUM] += mat_a[C0NUM] * mat_b_0[k]; + c[offset + C1NUM] += mat_a[C1NUM] * mat_b_1[k]; + c[offset + C2NUM] += mat_a[C2NUM] * mat_b_2[k]; + c[offset + C3NUM] += mat_a[C3NUM] * mat_b_3[k]; +#endif + } + } + } + } + + for (; j < end; j++) { + float *c = dx + j; + const float *b = w + j * k_spatial; + for (int si = 0; si < out_spatial; si++) { + const float *a = dy + j + si * out_ch; + int output_row = si / out_w; + int output_col = si % out_w; + int row_stride_offset = output_row * conv_param->stride_h_; + int col_stride_offset = output_col * conv_param->stride_w_; + for (int k = 0; k < k_spatial; k++) { + int kernel_row = k / k_w; + int kernel_col = k % k_w; + int input_row = -conv_param->pad_u_ + kernel_row * conv_param->dilation_h_ + row_stride_offset; + int input_col = -conv_param->pad_l_ + kernel_col * conv_param->dilation_w_ + col_stride_offset; + if (((unsigned)(input_row) < (unsigned)(in_h)) && ((unsigned)(input_col) < (unsigned)(in_w))) { + int offset = (input_row * in_w + input_col) * in_ch; + c[offset] += a[0] * b[k]; + } + } + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/convolution_grad_input.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/convolution_grad_input.h new file mode 100644 index 0000000000000000000000000000000000000000..acaa8884c391aa943d64961732414163c1170b54 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/convolution_grad_input.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_GRAD_CONVOLUTION_GRAD_INPUT_H_ +#define NNACL_FP32_GRAD_CONVOLUTION_GRAD_INPUT_H_ + +#include +#include "nnacl/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ConvDwInputGrad(const float *dy, const float *w, float *dx, int start, int count, const ConvParameter *conv_param); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_GRAD_CONVOLUTION_GRAD_INPUT_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/dropout_grad.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/dropout_grad.c new file mode 100644 index 0000000000000000000000000000000000000000..b54924bada4d1b160fdcdf1c4ee5170ecf706b5d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/dropout_grad.c @@ -0,0 +1,23 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32_grad/dropout_grad.h" + +void DropoutGrad(const float *yt_ptr, const float *mask, float *output_ptr, int length, float scale) { + for (int i = 0; i < length; i++) { + output_ptr[i] = yt_ptr[i] * mask[i] * scale; + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/dropout_grad.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/dropout_grad.h new file mode 100644 index 0000000000000000000000000000000000000000..a107cf90432ae18ac7c3ec9bd27dda0f39d4f6b9 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/dropout_grad.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_GRAD_DROPOUT_GRAD_H_ +#define NNACL_FP32_GRAD_DROPOUT_GRAD_H_ + +#include "nnacl/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void DropoutGrad(const float *yt_ptr, const float *mask, float *output_ptr, int length, float ratio); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_GRAD_DROPOUT_GRAD_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/dropout_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/dropout_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..3989a6712933e9c744d44208b3b2b7e298dc1d09 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/dropout_parameter.h @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_GRAD_DROPOUT_PARAMETER_H_ +#define NNACL_FP32_GRAD_DROPOUT_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct { + OpParameter op_parameter_; + float ratio_; +} DropoutParameter; + +#endif // NNACL_FP32_GRAD_DROPOUT_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/gemm.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/gemm.c new file mode 100644 index 0000000000000000000000000000000000000000..45b879177faa8874d9a2d073914b8f531ab040a2 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/gemm.c @@ -0,0 +1,855 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32_grad/gemm.h" +#include +#ifdef __ARM_NEON +#include +#endif +#include "nnacl/fp32/matmul_fp32.h" +#include "nnacl/fp32/pack_fp32.h" +#include "nnacl/intrinsics/ms_simd_instructions.h" + +#ifdef _MSC_VER +void AddMatrix(const float *v1, float *v2, float beta, int row, int col, int stride) { +#else +void AddMatrix(const float *restrict v1, float *restrict v2, float beta, int row, int col, int stride) { +#endif + const float *src_ptr = v1; + float *dst_ptr = v2; + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + dst_ptr[c] += beta * src_ptr[c]; + } + src_ptr += stride; + dst_ptr += stride; + } +} + +int MatSize(int row, int col, int round) { + int res = UP_ROUND(row, round) * col; + int res1 = UP_ROUND(col, round) * row; + return (res > res1) ? res : res1; +} + +int MatSizeTotal(int row, int col, int deep, int stride) { +#ifdef ENABLE_ARM32 + const int num0 = C4NUM; +#elif ENABLE_AVX + const int num0 = C6NUM; +#else + const int num0 = C12NUM; +#endif + +#ifdef ENABLE_AVX + const int num1 = C16NUM; +#else + const int num1 = C8NUM; +#endif + int res = MatSize(row, deep, num0) + MatSize(col, deep, num1); + if (stride > 0) res += row * stride; + return res; +} +#ifdef ENABLE_ARM32 +static void RowMajor2Row4MajorStride(const float *src_ptr, float *dst_ptr, int row, int col, int lead) { + for (int r = 0; r < row; r++) { + const float *src = src_ptr + r * lead; + for (int c = 0; c < col; c++) { + int cd8 = c / 4; + int cm8 = c % 4; + dst_ptr[cd8 * 4 * row + r * 4 + cm8] = src[c]; + } + } +} +#endif + +void RowMajor2Row8MajorStride(const float *src_ptr, float *dst_ptr, int row, int col, int lead) { + for (int r = 0; r < row; r++) { + const float *src = src_ptr + r * lead; + for (int c = 0; c < col; c++) { + int cd8 = c / C8NUM; + int cm8 = c % C8NUM; + dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = src[c]; + } + } + return; +} + +#ifdef ENABLE_ARM64 +static void RowMajor2Col12MajorStrideArm64(const float *src_c, float *dst_c, int lead) { + size_t stride = lead * sizeof(float); + asm volatile( + "mov x10, %[src_c]\n" + "mov x11, %[dst_c]\n" + + "ld1 {v0.4s}, [x10], %[stride]\n" + "ld1 {v1.4s}, [x10], %[stride]\n" + "ld1 {v2.4s}, [x10], %[stride]\n" + "ld1 {v3.4s}, [x10], %[stride]\n" + + "ld1 {v4.4s}, [x10], %[stride]\n" + "ld1 {v5.4s}, [x10], %[stride]\n" + "ld1 {v6.4s}, [x10], %[stride]\n" + "ld1 {v7.4s}, [x10], %[stride]\n" + + "zip1 v12.4s, v0.4s, v1.4s\n" + "zip2 v13.4s, v0.4s, v1.4s\n" + "zip1 v14.4s, v2.4s, v3.4s\n" + "zip2 v15.4s, v2.4s, v3.4s\n" + + "ld1 {v8.4s}, [x10], %[stride]\n" + "ld1 {v9.4s}, [x10], %[stride]\n" + "ld1 {v10.4s}, [x10], %[stride]\n" + "ld1 {v11.4s}, [x10], %[stride]\n" + + "zip1 v16.4s, v4.4s, v5.4s\n" + "zip2 v17.4s, v4.4s, v5.4s\n" + "zip1 v18.4s, v6.4s, v7.4s\n" + "zip2 v19.4s, v6.4s, v7.4s\n" + + "trn1 v20.2d, v12.2d, v14.2d\n" + "trn2 v23.2d, v12.2d, v14.2d\n" + "trn1 v26.2d, v13.2d, v15.2d\n" + "trn2 v29.2d, v13.2d, v15.2d\n" + + "trn1 v21.2d, v16.2d, v18.2d\n" + "trn2 v24.2d, v16.2d, v18.2d\n" + "trn1 v27.2d, v17.2d, v19.2d\n" + "trn2 v30.2d, v17.2d, v19.2d\n" + + "zip1 v12.4s, v8.4s, v9.4s\n" + "zip2 v13.4s, v8.4s, v9.4s\n" + "zip1 v14.4s, v10.4s, v11.4s\n" + "zip2 v15.4s, v10.4s, v11.4s\n" + + "trn1 v22.2d, v12.2d, v14.2d\n" + "trn2 v25.2d, v12.2d, v14.2d\n" + "trn1 v28.2d, v13.2d, v15.2d\n" + "trn2 v31.2d, v13.2d, v15.2d\n" + + "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x11], #64\n" + "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x11], #64\n" + "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x11], #64\n" + + : + : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31"); +} +#endif // ENABLE_ARM64 + +#ifdef ENABLE_ARM32 +void RowMajor2Col12MajorStrideArm32(const float *src_c, float *dst_c, int lead) { + size_t stride = lead * sizeof(float); + asm volatile( + "mov r10, %[src_c]\n" + "mov r12, %[dst_c]\n" + + "vld1.32 {q0}, [r10], %[stride]\n" + "vld1.32 {q3}, [r10], %[stride]\n" + "vld1.32 {q10}, [r10], %[stride]\n" + "vld1.32 {q13}, [r10], %[stride]\n" + + "vtrn.32 d0, d6\n" + "vtrn.32 d1, d7\n" + "vtrn.32 d20, d26\n" + "vtrn.32 d21, d27\n" + + "vld1.32 {q1}, [r10], %[stride]\n" + "vld1.32 {q8}, [r10], %[stride]\n" + "vld1.32 {q11}, [r10], %[stride]\n" + "vld1.32 {q14}, [r10], %[stride]\n" + + "vswp d1, d20\n" + "vswp d7, d26\n" + + "vld1.32 {q2}, [r10], %[stride]\n" + "vld1.32 {q9}, [r10], %[stride]\n" + "vld1.32 {q12}, [r10], %[stride]\n" + "vld1.32 {q15}, [r10], %[stride]\n" + + "vtrn.32 d2, d16\n" + "vtrn.32 d3, d17\n" + "vtrn.32 d22, d28\n" + "vtrn.32 d23, d29\n" + + "vswp d3, d22\n" + "vswp d17, d28\n" + + "vtrn.32 d4, d18\n" + "vtrn.32 d5, d19\n" + "vtrn.32 d24, d30\n" + "vtrn.32 d25, d31\n" + + "vswp d5, d24\n" + "vswp d19, d30\n" + + "vst1.32 {q0, q1}, [r12]!\n" + "vst1.32 {q2, q3}, [r12]!\n" + "vst1.32 {q8, q9}, [r12]!\n" + "vst1.32 {q10, q11}, [r12]!\n" + "vst1.32 {q12, q13}, [r12]!\n" + "vst1.32 {q14, q15}, [r12]!\n" + + : + : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride) + : "r10", "r12", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); +} +#endif // ENABLE_ARM32 + +#ifndef ENABLE_ARM32 +void RowMajor2Row12MajorStride(const float *src_ptr, float *dst_ptr, int row, int col, int lead) { + for (int r = 0; r < row; r++) { + const float *src = src_ptr + r * lead; + for (int c = 0; c < col; c++) { + int cd8 = c / C12NUM; + int cm8 = c % C12NUM; + dst_ptr[cd8 * C12NUM * row + r * C12NUM + cm8] = src[c]; + } + } + return; +} + +void RowMajor2Col12MajorStride(const float *src_ptr, float *dst_ptr, size_t row, size_t col, int lead) { + size_t row_up_12 = UP_ROUND(row, C12NUM); + size_t row12 = row / C12NUM * C12NUM; + size_t col4 = col / C4NUM * C4NUM; + const float *src_r = src_ptr; + float *dst_r = dst_ptr; + + size_t ri = 0; + for (; ri < row12; ri += C12NUM) { + size_t ci = 0; + for (; ci < col4; ci += C4NUM) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C12NUM; + + /* 12x4 row-major to col-major */ +#ifdef ENABLE_ARM64 + RowMajor2Col12MajorStrideArm64(src_c, dst_c, lead); +#elif ENABLE_ARM32 + RowMajor2Col12MajorStrideArm32(src_c, dst_c, lead); +#else + for (int tr = 0; tr < C12NUM; tr++) { + for (int tc = 0; tc < C4NUM; tc++) { + dst_c[tc * C12NUM + tr] = src_c[tr * lead + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C12NUM; + for (int i = 0; i < C12NUM; i++) { + dst_c[i] = src_c[i * lead]; + } + } + src_r += C12NUM * lead; + dst_r += C12NUM * col; + } + + for (; ri < row; ri++) { + for (size_t i = 0; i < col; i++) { + dst_r[i * C12NUM] = src_r[i]; + } + src_r += lead; + dst_r += 1; + } + + for (; ri < row_up_12; ri++) { + for (int i = 0; i < col; i++) { + dst_r[i * C12NUM] = 0; + } + dst_r += 1; + } +} +#endif + +#ifdef ENABLE_ARM64 +static void RowMajor2Col8MajorStrideArm64(const float *src_c, float *dst_c, int lead) { + /* 8x8 row-major to col-major */ + size_t stride = lead * sizeof(float); + asm volatile( + "mov x10, %[src_c]\n" + "mov x11, %[dst_c]\n" + + "ld1 {v0.4s, v1.4s}, [x10], %[stride]\n" + "ld1 {v2.4s, v3.4s}, [x10], %[stride]\n" + "ld1 {v4.4s, v5.4s}, [x10], %[stride]\n" + "ld1 {v6.4s, v7.4s}, [x10], %[stride]\n" + + "zip1 v8.4s, v0.4s, v2.4s\n" + "zip2 v9.4s, v0.4s, v2.4s\n" + "zip1 v10.4s, v4.4s, v6.4s\n" + "zip2 v11.4s, v4.4s, v6.4s\n" + + "ld1 {v16.4s, v17.4s}, [x10], %[stride]\n" + "ld1 {v18.4s, v19.4s}, [x10], %[stride]\n" + "ld1 {v20.4s, v21.4s}, [x10], %[stride]\n" + "ld1 {v22.4s, v23.4s}, [x10], %[stride]\n" + + "zip1 v12.4s, v1.4s, v3.4s\n" + "zip2 v13.4s, v1.4s, v3.4s\n" + "zip1 v14.4s, v5.4s, v7.4s\n" + "zip2 v15.4s, v5.4s, v7.4s\n" + + "trn1 v0.2d, v8.2d, v10.2d\n" + "trn2 v1.2d, v8.2d, v10.2d\n" + "trn1 v2.2d, v9.2d, v11.2d\n" + "trn2 v3.2d, v9.2d, v11.2d\n" + + "zip1 v24.4s, v16.4s, v18.4s\n" + "zip2 v25.4s, v16.4s, v18.4s\n" + "zip1 v26.4s, v20.4s, v22.4s\n" + "zip2 v27.4s, v20.4s, v22.4s\n" + + "trn1 v4.2d, v12.2d, v14.2d\n" + "trn2 v5.2d, v12.2d, v14.2d\n" + "trn1 v6.2d, v13.2d, v15.2d\n" + "trn2 v7.2d, v13.2d, v15.2d\n" + + "zip1 v28.4s, v17.4s, v19.4s\n" + "zip2 v29.4s, v17.4s, v19.4s\n" + "zip1 v30.4s, v21.4s, v23.4s\n" + "zip2 v31.4s, v21.4s, v23.4s\n" + + "trn1 v16.2d, v24.2d, v26.2d\n" + "trn2 v17.2d, v24.2d, v26.2d\n" + "trn1 v18.2d, v25.2d, v27.2d\n" + "trn2 v19.2d, v25.2d, v27.2d\n" + + "trn1 v20.2d, v28.2d, v30.2d\n" + "trn2 v21.2d, v28.2d, v30.2d\n" + "trn1 v22.2d, v29.2d, v31.2d\n" + "trn2 v23.2d, v29.2d, v31.2d\n" + + "st1 {v0.4s}, [x11], #16\n" + "st1 {v16.4s}, [x11], #16\n" + "st1 {v1.4s}, [x11], #16\n" + "st1 {v17.4s}, [x11], #16\n" + "st1 {v2.4s}, [x11], #16\n" + "st1 {v18.4s}, [x11], #16\n" + "st1 {v3.4s}, [x11], #16\n" + "st1 {v19.4s}, [x11], #16\n" + "st1 {v4.4s}, [x11], #16\n" + "st1 {v20.4s}, [x11], #16\n" + "st1 {v5.4s}, [x11], #16\n" + "st1 {v21.4s}, [x11], #16\n" + "st1 {v6.4s}, [x11], #16\n" + "st1 {v22.4s}, [x11], #16\n" + "st1 {v7.4s}, [x11], #16\n" + "st1 {v23.4s}, [x11], #16\n" + + : + : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31"); +} +#endif // ENABLE_ARM64 + +#ifdef ENABLE_ARM32 +#ifndef SUPPORT_NNIE +static void RowMajor2Col8MajorStrideArm32(const float *src_c, float *dst_c, size_t col) { + /* 8x4 row-major to col-major */ + size_t stride = col * sizeof(float); + asm volatile( + "mov r10, %[src_c]\n" + "mov r11, %[dst_c]\n" + + "vld1.32 {q0}, [r10], %[stride]\n" + "vld1.32 {q2}, [r10], %[stride]\n" + "vld1.32 {q4}, [r10], %[stride]\n" + "vld1.32 {q6}, [r10], %[stride]\n" + + "vtrn.32 d0, d4\n" + "vtrn.32 d1, d5\n" + "vtrn.32 d8, d12\n" + "vtrn.32 d9, d13\n" + + "vld1.32 {q1}, [r10], %[stride]\n" + "vld1.32 {q3}, [r10], %[stride]\n" + "vld1.32 {q5}, [r10], %[stride]\n" + "vld1.32 {q7}, [r10], %[stride]\n" + + "vswp d1, d8\n" + "vswp d5, d12\n" + + "vtrn.32 d2, d6\n" + "vtrn.32 d3, d7\n" + "vtrn.32 d10, d14\n" + "vtrn.32 d11, d15\n" + + "vswp d3, d10\n" + "vswp d7, d14\n" + + "vst1.32 {q0, q1}, [r11]!\n" + "vst1.32 {q2, q3}, [r11]!\n" + "vst1.32 {q4, q5}, [r11]!\n" + "vst1.32 {q6, q7}, [r11]!\n" + + : + : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride) + : "r10", "r11", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7"); +} + +#else +static void RowMajor2Col8MajorStrideArm32Nnie(const float *src_c, float *dst_c, size_t col) { + /* 8x4 row-major to col-major */ + size_t stride = col * sizeof(float); + asm volatile( + "mov r10, %[src_c]\n" + "mov r7, %[dst_c]\n" + + "vld1.32 {q0}, [r10], %[stride]\n" + "vld1.32 {q2}, [r10], %[stride]\n" + "vld1.32 {q4}, [r10], %[stride]\n" + "vld1.32 {q6}, [r10], %[stride]\n" + + "vtrn.32 d0, d4\n" + "vtrn.32 d1, d5\n" + "vtrn.32 d8, d12\n" + "vtrn.32 d9, d13\n" + + "vld1.32 {q1}, [r10], %[stride]\n" + "vld1.32 {q3}, [r10], %[stride]\n" + "vld1.32 {q5}, [r10], %[stride]\n" + "vld1.32 {q7}, [r10], %[stride]\n" + + "vswp d1, d8\n" + "vswp d5, d12\n" + + "vtrn.32 d2, d6\n" + "vtrn.32 d3, d7\n" + "vtrn.32 d10, d14\n" + "vtrn.32 d11, d15\n" + + "vswp d3, d10\n" + "vswp d7, d14\n" + + "vst1.32 {q0, q1}, [r7]!\n" + "vst1.32 {q2, q3}, [r7]!\n" + "vst1.32 {q4, q5}, [r7]!\n" + "vst1.32 {q6, q7}, [r7]!\n" + + : + : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride) + : "r10", "r7", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7"); +} +#endif // SUPPORT_NNIE +#endif // ENABLE_ARM32 + +void RowMajor2Col8MajorStride(const float *src_ptr, float *dst_ptr, size_t row, size_t col, int lead) { + size_t row8 = row / C8NUM * C8NUM; +#ifdef ENABLE_ARM64 + size_t col_skip = col / C8NUM * C8NUM; + size_t skip_size = C8NUM; +#else + size_t col_skip = col / C4NUM * C4NUM; + size_t skip_size = C4NUM; +#endif + const float *src_r = src_ptr; + float *dst_r = dst_ptr; + + size_t ri = 0; + for (; ri < row8; ri += C8NUM) { + size_t ci = 0; + for (; ci < col_skip; ci += skip_size) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C8NUM; + +#ifdef ENABLE_ARM64 + RowMajor2Col8MajorStrideArm64(src_c, dst_c, lead); +#elif ENABLE_ARM32 +#ifndef SUPPORT_NNIE + RowMajor2Col8MajorStrideArm32(src_c, dst_c, lead); +#else + RowMajor2Col8MajorStrideArm32Nnie(src_c, dst_c, lead); +#endif +#else + for (int tr = 0; tr < 8; tr++) { + for (int tc = 0; tc < 4; tc++) { + dst_c[tc * 8 + tr] = src_c[tr * lead + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C8NUM; + for (int i = 0; i < C8NUM; i++) { + dst_c[i] = src_c[i * lead]; + } + } + src_r += C8NUM * lead; + dst_r += C8NUM * col; + } + for (; ri < row; ri++) { + for (size_t i = 0; i < col; i++) { + dst_r[i * C8NUM] = src_r[i]; + } + src_r += lead; + dst_r += 1; + } + return; +} +#ifdef ENABLE_ARM32 +static void RowMajor2Col4MajorStride(const float *src_ptr, float *dst_ptr, size_t row, size_t col, int lead) { + size_t row8 = row / C4NUM * C4NUM; + size_t col4 = col / C4NUM * C4NUM; + const float *src_r = src_ptr; + float *dst_r = dst_ptr; + + size_t ri = 0; + for (; ri < row8; ri += C4NUM) { + size_t ci = 0; + for (; ci < col4; ci += C4NUM) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C4NUM; + + /* 4x4 row-major to col-major */ +#ifdef ENABLE_ARM32 + size_t stride = col * 4; + asm volatile( + "mov r10, %[src_c]\n" + "mov r12, %[dst_c]\n" + + "vld1.32 {q0}, [r10], %[stride]\n" + "vld1.32 {q1}, [r10], %[stride]\n" + "vld1.32 {q2}, [r10], %[stride]\n" + "vld1.32 {q3}, [r10], %[stride]\n" + + "vtrn.32 d0, d2\n" + "vtrn.32 d1, d3\n" + "vtrn.32 d4, d6\n" + "vtrn.32 d5, d7\n" + + "vswp d1, d4\n" + "vswp d3, d6\n" + + "vst1.32 {q0}, [r12]!\n" + "vst1.32 {q1}, [r12]!\n" + "vst1.32 {q2}, [r12]!\n" + "vst1.32 {q3}, [r12]!\n" + + : + : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride) + : "r10", "r12", "q0", "q1", "q2", "q3"); +#else + for (int tr = 0; tr < C4NUM; tr++) { + for (int tc = 0; tc < C4NUM; tc++) { + dst_c[tc * C4NUM + tr] = src_c[tr * lead + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C4NUM; + for (size_t i = 0; i < C4NUM; i++) { + dst_c[i] = src_c[i * lead]; + } + } + src_r += C4NUM * col; + dst_r += C4NUM * col; + } + for (; ri < row; ri++) { + for (size_t i = 0; i < col; i++) { + dst_r[i * C4NUM] = src_r[i]; + } + src_r += lead; + dst_r += 1; + } + return; +} +#endif + +void RowMajor2Row6MajorStride(const float *src_ptr, float *dst_ptr, int row, int col, int lead) { + int max = 0; + for (int r = 0; r < row; r++) { + const float *src = src_ptr + r * lead; + int c = 0; + for (; c < col; c++) { + int cd6 = c / C6NUM; + int cm6 = c % C6NUM; + int offset = cd6 * C6NUM * row + r * C6NUM + cm6; + dst_ptr[offset] = src[c]; + if (offset > max) { + max = offset; + } + } + for (; c < UP_ROUND(col, C6NUM); c++) { + int cd6 = c / C6NUM; + int cm6 = c % C6NUM; + int offset = cd6 * C6NUM * row + r * C6NUM + cm6; + dst_ptr[offset] = 0.0f; + if (offset > max) { + max = offset; + } + } + } + return; +} + +void RowMajor2Col6MajorStride(const float *src_ptr, float *dst_ptr, int row, int col, int lead) { + int totalRow = UP_ROUND(row, C6NUM); + int row6 = row / C6NUM * C6NUM; + int col8 = col / C8NUM * C8NUM; + const float *src_r = src_ptr; + float *dst_r = dst_ptr; + + int ri = 0; + for (; ri < row6; ri += C6NUM) { + int ci = 0; + for (; ci < col8; ci += C8NUM) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C6NUM; + +#ifdef ENABLE_AVX + __m256 src0 = _mm256_loadu_ps(src_c); + __m256 src1 = _mm256_loadu_ps(src_c + lead); + __m256 src2 = _mm256_loadu_ps(src_c + 2 * lead); + __m256 src3 = _mm256_loadu_ps(src_c + 3 * lead); + __m256 src4 = _mm256_loadu_ps(src_c + 4 * lead); + __m256 src5 = _mm256_loadu_ps(src_c + 5 * lead); + __m256 trans0 = _mm256_unpacklo_ps(src0, src1); + __m256 trans1 = _mm256_unpacklo_ps(src2, src3); + __m256 trans2 = _mm256_unpacklo_ps(src4, src5); + __m256 trans3 = _mm256_unpackhi_ps(src0, src1); + __m256 trans4 = _mm256_unpackhi_ps(src2, src3); + __m256 trans5 = _mm256_unpackhi_ps(src4, src5); + __m128 lo0 = _mm256_castps256_ps128(trans0); + __m128 lo1 = _mm256_castps256_ps128(trans1); + __m128 lo2 = _mm256_castps256_ps128(trans2); + __m128 lo3 = _mm256_castps256_ps128(trans3); + __m128 lo4 = _mm256_castps256_ps128(trans4); + __m128 lo5 = _mm256_castps256_ps128(trans5); + __m128 hi0 = _mm256_extractf128_ps(trans0, 1); + __m128 hi1 = _mm256_extractf128_ps(trans1, 1); + __m128 hi2 = _mm256_extractf128_ps(trans2, 1); + __m128 hi3 = _mm256_extractf128_ps(trans3, 1); + __m128 hi4 = _mm256_extractf128_ps(trans4, 1); + __m128 hi5 = _mm256_extractf128_ps(trans5, 1); + __m128 res0 = _mm_shuffle_ps(lo0, lo1, _MM_SHUFFLE(1, 0, 1, 0)); + __m128 res1 = _mm_shuffle_ps(lo2, lo0, _MM_SHUFFLE(3, 2, 1, 0)); + __m128 res2 = _mm_shuffle_ps(lo1, lo2, _MM_SHUFFLE(3, 2, 3, 2)); + __m128 res3 = _mm_shuffle_ps(lo3, lo4, _MM_SHUFFLE(1, 0, 1, 0)); + __m128 res4 = _mm_shuffle_ps(lo5, lo3, _MM_SHUFFLE(3, 2, 1, 0)); + __m128 res5 = _mm_shuffle_ps(lo4, lo5, _MM_SHUFFLE(3, 2, 3, 2)); + __m128 res6 = _mm_shuffle_ps(hi0, hi1, _MM_SHUFFLE(1, 0, 1, 0)); + __m128 res7 = _mm_shuffle_ps(hi2, hi0, _MM_SHUFFLE(3, 2, 1, 0)); + __m128 res8 = _mm_shuffle_ps(hi1, hi2, _MM_SHUFFLE(3, 2, 3, 2)); + __m128 res9 = _mm_shuffle_ps(hi3, hi4, _MM_SHUFFLE(1, 0, 1, 0)); + __m128 res10 = _mm_shuffle_ps(hi5, hi3, _MM_SHUFFLE(3, 2, 1, 0)); + __m128 res11 = _mm_shuffle_ps(hi4, hi5, _MM_SHUFFLE(3, 2, 3, 2)); + _mm_storeu_ps(dst_c, res0); + _mm_storeu_ps(dst_c + C4NUM, res1); + _mm_storeu_ps(dst_c + C8NUM, res2); + _mm_storeu_ps(dst_c + C12NUM, res3); + _mm_storeu_ps(dst_c + C16NUM, res4); + _mm_storeu_ps(dst_c + C20NUM, res5); + _mm_storeu_ps(dst_c + C24NUM, res6); + _mm_storeu_ps(dst_c + C28NUM, res7); + _mm_storeu_ps(dst_c + C32NUM, res8); + _mm_storeu_ps(dst_c + C36NUM, res9); + _mm_storeu_ps(dst_c + C40NUM, res10); + _mm_storeu_ps(dst_c + C44NUM, res11); +#else + for (int tr = 0; tr < C6NUM; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + dst_c[tc * C6NUM + tr] = src_c[tr * lead + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C6NUM; + for (int i = 0; i < C6NUM; i++) { + dst_c[i] = src_c[i * lead]; + } + } + src_r += C6NUM * lead; + dst_r += C6NUM * col; + } + + for (; ri < row; ri++) { + for (int i = 0; i < col; i++) { + dst_r[i * C6NUM] = src_r[i]; + } + src_r += lead; + dst_r += 1; + } + + for (; ri < totalRow; ri++) { + for (int i = 0; i < col; i++) { + dst_r[i * C6NUM] = 0; + } + dst_r += 1; + } +} + +void RowMajor2Col16MajorStride(const float *src_ptr, float *dst_ptr, int row, int col, int lead) { + int row16 = row / C16NUM * C16NUM; + int col8 = col / C8NUM * C8NUM; + const float *src_r = src_ptr; + float *dst_r = dst_ptr; + + int ri = 0; + for (; ri < row16; ri += C16NUM) { + int ci = 0; + for (; ci < col8; ci += C8NUM) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C16NUM; +#ifdef ENABLE_AVX + Transpose8X8Fp32Avx(src_c, dst_c, lead, C16NUM); + Transpose8X8Fp32Avx(src_c + C8NUM * lead, dst_c + C8NUM, lead, C16NUM); +#else + for (int tr = 0; tr < C16NUM; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + dst_c[tc * C16NUM + tr] = src_c[tr * lead + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float *src_c = src_r + ci; + float *dst_c = dst_r + ci * C16NUM; + for (int i = 0; i < C16NUM; i++) { + dst_c[i] = src_c[i * lead]; + } + } + src_r += C16NUM * lead; + dst_r += C16NUM * col; + } + for (; ri < row; ri++) { + for (int i = 0; i < col; i++) { + dst_r[i * C16NUM] = src_r[i]; + } + src_r += lead; + dst_r += 1; + } + + int total_row = UP_ROUND(row, C16NUM); + for (; ri < total_row; ri++) { + for (int i = 0; i < col; i++) { + dst_r[i * C16NUM] = 0; + } + dst_r += 1; + } +} + +void RowMajor2Row16MajorStride(const float *src_ptr, float *dst_ptr, int row, int col, int lead) { + int max = 0; + for (int r = 0; r < row; r++) { + const float *src = src_ptr + r * lead; + int c = 0; + for (; c < col; c++) { + int cd16 = c / C16NUM; + int cm16 = c % C16NUM; + dst_ptr[cd16 * C16NUM * row + r * C16NUM + cm16] = src[c]; + if ((cd16 * C16NUM * row + r * C16NUM + cm16) > max) max = cd16 * C16NUM * row + r * C16NUM + cm16; + } + for (; c < UP_ROUND(col, C16NUM); c++) { + int cd16 = c / C16NUM; + int cm16 = c % C16NUM; + dst_ptr[cd16 * C16NUM * row + r * C16NUM + cm16] = 0; + if ((cd16 * C16NUM * row + r * C16NUM + cm16) > max) max = cd16 * C16NUM * row + r * C16NUM + cm16; + } + } + return; +} +void GemmMatmul(int ta, int tb, int M, int N, int K, float alpha, const float *mat_a, int lda, const float *mat_b, + int ldb, float beta, float *mat_c, int ldc, float *workspace) { + GemmCb gcb; + gcb.atype = ActType_No; + gcb.ca = 0; + gcb.cb = 0; + gcb.bias = NULL; + gcb.mat_a = NULL; + gcb.mat_b = NULL; + GemmMatmulPlus(ta, tb, M, N, K, alpha, mat_a, lda, mat_b, ldb, beta, mat_c, ldc, workspace, &gcb); +} + +void GemmMatmulPlus(int ta, int tb, int M, int N, int K, float alpha, const float *mat_a, int lda, const float *mat_b, + int ldb, float beta, float *mat_c, int ldc, float *workspace, GemmCb *gcb) { +#ifdef ENABLE_ARM32 + const int num = C4NUM; + const int num1 = C8NUM; +#elif ENABLE_AVX + const int num = C6NUM; + const int num1 = C16NUM; +#else + const int num = C12NUM; + const int num1 = C8NUM; +#endif + float *output = mat_c; + float *fworkspace = workspace; + int incremental = (beta < 0.f) || (beta > 0.f); + float *mat_a_input = (float *)mat_a; + float *mat_b_input = (float *)mat_b; + + if (!gcb->ca) { + mat_a_input = fworkspace; + if (ta) { + fworkspace += MatSize(K, M, num); +#ifdef ENABLE_ARM32 + RowMajor2Row4MajorStride(mat_a, mat_a_input, K, M, lda); +#elif ENABLE_AVX + RowMajor2Row6MajorStride(mat_a, mat_a_input, K, M, lda); +#else + RowMajor2Row12MajorStride(mat_a, mat_a_input, K, M, lda); +#endif + } else { + fworkspace += MatSize(M, K, num); +#ifdef ENABLE_ARM32 + RowMajor2Col4MajorStride(mat_a, mat_a_input, M, K, lda); +#elif ENABLE_AVX + RowMajor2Col6MajorStride(mat_a, mat_a_input, M, K, lda); +#else + RowMajor2Col12MajorStride(mat_a, mat_a_input, M, K, lda); +#endif + } + } + if (!gcb->cb) { + mat_b_input = fworkspace; + if (tb) { + fworkspace += MatSize(N, K, num1); +#ifdef ENABLE_AVX + RowMajor2Col16MajorStride(mat_b, mat_b_input, N, K, ldb); +#else + RowMajor2Col8MajorStride(mat_b, mat_b_input, N, K, ldb); +#endif + } else { + fworkspace += MatSize(K, N, num1); +#ifdef ENABLE_AVX + RowMajor2Row16MajorStride(mat_b, mat_b_input, K, N, ldb); +#else + RowMajor2Row8MajorStride(mat_b, mat_b_input, K, N, ldb); +#endif + } + } + if (incremental) output = fworkspace; +#ifdef ENABLE_ARM32 + MatmulFloatNeon32Opt(mat_a_input, mat_b_input, output, gcb->bias, (int)gcb->atype, K, M, N, ldc, 1); +#else + MatMulOpt(mat_a_input, mat_b_input, output, gcb->bias, gcb->atype, K, M, N, ldc, OutType_Nhwc); +#endif + if (incremental) AddMatrix(output, mat_c, beta, M, N, ldc); + gcb->mat_a = mat_a_input; + gcb->mat_b = mat_b_input; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/gemm.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..59a8c3a9242579d5dcf79c42d4549eb248c66062 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/gemm.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_GRAD_GEMM_H_ +#define NNACL_FP32_GRAD_GEMM_H_ + +#include +#include "nnacl/op_base.h" +#ifdef __cplusplus +extern "C" { +#endif +typedef struct { + int ca; + int cb; + ActType atype; + float *bias; + float *mat_a; + float *mat_b; +} GemmCb; + +void GemmMatmulPlus(int ta, int tb, int M, int N, int K, float alpha, const float *mat_a, int lda, const float *mat_b, + int ldb, float beta, float *mat_c, int ldc, float *workspace, GemmCb *cb); +void GemmMatmul(int ta, int tb, int M, int N, int K, float alpha, const float *mat_a, int lda, const float *mat_b, + int ldb, float beta, float *mat_c, int ldc, float *workspace); +int MatSize(int row, int col, int round); +int MatSizeTotal(int row, int col, int deep, int inc); +void AddMatrix(const float *v1, float *v2, float beta, int row, int col, int stride); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_GRAD_GEMM_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/layernorm_grad.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/layernorm_grad.c new file mode 100644 index 0000000000000000000000000000000000000000..936db1264066e37edbf796eafef4dce710e1db6e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/layernorm_grad.c @@ -0,0 +1,64 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp32_grad/layernorm_grad.h" +#include +#include +#include "nnacl/errorcode.h" + +int LayerNormGrad(const float *x, const float *dy, const float *var, const float *mean, const float *gamma, + int param_num, int param_size, int block_num, int block_size, float *dx, float *dg, float *db) { + // var is actually layer_norm forward output var + const float eps = 1e-12; + const float *var_sqrt_rev = var; + if (block_size <= 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + for (int i = 0; i < param_num; ++i) { + float dgamma = 0.0f; + float dbeta = 0.0f; + for (int j = i; j < param_size * param_num; j += param_num) { + int norm_shift = (int)(j / block_size); + dgamma += dy[j] * pow(var[norm_shift] + eps, -0.5) * (x[j] - mean[norm_shift]); + dbeta += dy[j]; + } + dg[i] = dgamma; + db[i] = dbeta; + } + for (int i = 0; i < block_num; ++i) { + float sum1 = 0.0f; + float sum2 = 0.0f; + float sum3 = 0.0f; + for (int j = 0; j < block_size; ++j) { + int index = i * block_size + j; + float dxm = x[index] - mean[i]; + int param_shift = index % param_num; + float dyg = dy[index] * gamma[param_shift]; + sum1 += -0.5f * dyg * dxm * pow(var_sqrt_rev[i] + eps, -1.5); + sum2 += dyg; + sum3 += -2.0f * dxm; + } + for (int j = 0; j < block_size; ++j) { + int index = i * block_size + j; + float var_sqrt = pow(var_sqrt_rev[i] + eps, -0.5); + int param_shift = index % param_num; + float dx1 = dy[index] * gamma[param_shift] * var_sqrt; + float dx2 = sum1 * 2.0f / block_size * (x[index] - mean[i]); + float dx3 = (-1.0f * var_sqrt * sum2 + (1.0f / block_size) * sum1 * sum3) * (1.0f / block_size); + dx[index] = dx1 + dx2 + dx3; + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/layernorm_grad.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/layernorm_grad.h new file mode 100644 index 0000000000000000000000000000000000000000..8558c448241b0a0f099c887a26aaf4f0fed629f0 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/layernorm_grad.h @@ -0,0 +1,29 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP32_GRAD_LAYERNORM_GRAD_H_ +#define NNACL_FP32_GRAD_LAYERNORM_GRAD_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +int LayerNormGrad(const float *x, const float *dy, const float *var, const float *mean, const float *gamma, + int param_num, int param_size, int block_num, int block_size, float *dx, float *dg, float *db); + +#ifdef __cplusplus +} +#endif +#endif // NNACL_FP32_GRAD_LAYERNORM_GRAD_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/layernormgrad_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/layernormgrad_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..b926472d3e73d6eff5db20c661bef4404f9baddb --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/layernormgrad_parameter.h @@ -0,0 +1,27 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP32_GRAD_LAYERNORMGRAD_PARAMETER_H_ +#define NNACL_FP32_GRAD_LAYERNORMGRAD_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct { + OpParameter op_parameter_; + int begin_norm_axis_; + int begin_params_axis_; +} LayerNormGradParameter; + +#endif // NNACL_FP32_GRAD_LAYERNORMGRAD_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/lstm_grad_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/lstm_grad_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..be785a84b7f9a8f1ac4ee6104c00fc5ae2d72cc1 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/lstm_grad_fp32.c @@ -0,0 +1,237 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32_grad/lstm_grad_fp32.h" +#include +#include +#include "nnacl/lstm_parameter.h" +#include "nnacl/fp32/activation_fp32.h" +#include "nnacl/fp32/arithmetic_fp32.h" +#include "nnacl/fp32/matmul_fp32.h" +#include "nnacl/fp32_grad/gemm.h" +#include "nnacl/fp32/lstm_fp32.h" +#include "nnacl/fp32/pack_fp32.h" +#include "nnacl/nnacl_utils.h" + +static const int num_of_gates = 4; +static const int no_of_temp_matrices_sized_output_step = 5; + +static inline float *AllocteFromScrachPad(float **scrach_pad, int size) { + float *buffer = *scrach_pad; + *scrach_pad += size; + return buffer; +} + +static const int weights_order_IOFG[2 * 4] = {0, 3, 1, 2, 4, 7, 5, 6}; // IOFG order to IFGO order +static const int weights_order_IFGO[2 * 4] = {0, 2, 3, 1, 4, 6, 7, 5}; // IFGO order to IOFG order + +const int *getLstmOrderIOFG(void) { return weights_order_IOFG; } + +const int *getLstmOrderIFGO(void) { return weights_order_IFGO; } + +void PackLstmWeightTranspose(float *dst, const float *src, int batch, int col, int row, int row_align, + const int *order) { + for (int i = 0; i < batch; i++) { + const float *src_batch = src + i * col * row; + float *dst_batch = dst + ((order == NULL) ? i : order[i]) * col * row_align; +#ifdef ENABLE_AVX + RowMajor2Row16Major(src_batch, dst_batch, row, col); +#elif defined(ENABLE_ARM32) + RowMajor2Row4Major(src_batch, dst_batch, row, col); +#else + RowMajor2Row8Major(src_batch, dst_batch, row, col); +#endif + } +} + +void ReorderLstmWeights(float *dst, const float *src, int nof_martices, int col, int row, const int *order) { + int matrix_size = col * row; + for (int i = 0; i < nof_martices; i++) { + const float *src_block = src + i * matrix_size; + float *dst_block = dst + ((order == NULL) ? i : order[i]) * matrix_size; + memcpy(dst_block, src_block, matrix_size * sizeof(float)); + } +} + +void sumCols(int m, int n, int stride, float *inMat, float *outMat, bool accumulate) { + for (int idn = 0; idn < n; idn++) { + float *col = inMat + idn; + if (!accumulate) { + *outMat = 0; + } + for (int idm = 0; idm < m; idm++) { + *outMat += *col; + col += stride; + } + outMat++; + } +} + +int GetGemmMatMullWorkspace(int batch, int input_size, int hidden_size) { + int workspace_size, temp; + // if the appropriate GemmMatNul use beta>0 matSizeTotal must have col as last parameter. + workspace_size = MatSizeTotal(batch, input_size, hidden_size, input_size); + temp = MatSizeTotal(batch, hidden_size, hidden_size, hidden_size); + workspace_size = (temp > workspace_size) ? temp : workspace_size; + temp = MatSizeTotal(hidden_size, input_size, batch, input_size); + workspace_size = (temp > workspace_size) ? temp : workspace_size; + temp = MatSizeTotal(hidden_size, hidden_size, batch, hidden_size); + workspace_size = (temp > workspace_size) ? temp : workspace_size; + return workspace_size; +} + +int GetRunWorkspaceSize(const LstmGradParameter *lstm_param) { + int time_stamp_len = lstm_param->batch_ * lstm_param->hidden_size_; + int workspace_size = no_of_temp_matrices_sized_output_step * time_stamp_len; + workspace_size += GetGemmMatMullWorkspace(lstm_param->batch_, lstm_param->input_size_, lstm_param->hidden_size_); + return workspace_size; +} + +size_t GetRunWorkspaceGemmOffset(const LstmGradParameter *lstm_param) { + int time_stamp_len = lstm_param->batch_ * lstm_param->hidden_size_; + return no_of_temp_matrices_sized_output_step * time_stamp_len; +} + +void LstmGradReorderDy(float *src, float *dst, LstmGradParameter *lstm_param) { + int dir_mult = lstm_param->bidirectional_ ? C2NUM : C1NUM; + for (int b = 0; b < lstm_param->batch_; b++) { + int batch_offset = b * dir_mult * lstm_param->hidden_size_; + float *dy = src + batch_offset; + memcpy(dst + b * lstm_param->hidden_size_, dy, lstm_param->hidden_size_ * sizeof(float)); + } +} + +void LstmGradDoInputStep(const float *output_gate, float *cell_state, float *prev_cell_state, float *cell_gate, + float *input_gate, float *forget_gate, float *dY, float *dC, float *dH, float **dA, float *dX, + float *w, float *v, float *workspace, const LstmGradParameter *lstm_param) { + float *scratchPad = workspace; + + int seq_len = lstm_param->batch_ * lstm_param->hidden_size_; + float *temp0 = AllocteFromScrachPad(&scratchPad, seq_len); + float *temp1 = AllocteFromScrachPad(&scratchPad, seq_len); + float *temp2 = AllocteFromScrachPad(&scratchPad, seq_len); + float *temp3 = AllocteFromScrachPad(&scratchPad, seq_len); + float *temp4 = AllocteFromScrachPad(&scratchPad, seq_len); + + // Accumulate gradients into dH + ElementAdd(dH, dY, dH, seq_len); + + ElementMul(dH, output_gate, temp1, seq_len); + Tanh(cell_state, seq_len, temp0); + ElementMul(temp0, temp0, temp2, seq_len); + ElementMul(temp1, temp2, temp4, seq_len); + ElementSub(temp1, temp4, temp1, seq_len); + ElementAdd(dC, temp1, dC, seq_len); + + // calculate dI, dO, dF and dG + float *dI = temp1; // dI = dC_{t} * G + ElementMul(dC, cell_gate, dI, seq_len); + float *dO = temp2; // dO = dH * Tanh(C_{t}) + ElementMul(dH, temp0, dO, seq_len); + float *dF = temp3; // dF = dC_{t} * C_{t-1} + ElementMul(dC, prev_cell_state, dF, seq_len); + float *dG = temp4; // dG = dC_{t} * I + ElementMul(dC, input_gate, dG, seq_len); + + // dAi = dI * I * (1 - I) + float *dAi = temp1; + *dA = dAi; + ElementMul(dI, input_gate, dAi, seq_len); + ElementMul(dAi, input_gate, temp0, seq_len); + ElementSub(dAi, temp0, dAi, seq_len); + + // dAo = dO * O * (1 - O) + float *dAo = temp2; + ElementMul(dO, output_gate, dAo, seq_len); + ElementMul(dAo, output_gate, temp0, seq_len); + ElementSub(dAo, temp0, dAo, seq_len); + + // dAf = dF * F * (1 - F) + float *dAf = temp3; + ElementMul(dF, forget_gate, dAf, seq_len); + ElementMul(dAf, forget_gate, temp0, seq_len); + ElementSub(dAf, temp0, dAf, seq_len); + + float *dAg = temp4; + ElementMul(cell_gate, cell_gate, temp0, seq_len); + ElementMul(dG, temp0, temp0, seq_len); + ElementSub(dG, temp0, dAg, seq_len); + + float *mat_workspace = AllocteFromScrachPad( + &scratchPad, GetGemmMatMullWorkspace(lstm_param->batch_, lstm_param->input_size_, lstm_param->hidden_size_)); + float *weights_loop = w; + float *dA_loop = dAi; // dAi, dAo, dAf, dAg + for (int idx = 0; idx < num_of_gates; idx++) { + GemmMatmul(0, 0, lstm_param->batch_, lstm_param->input_size_, lstm_param->hidden_size_, 1.0, dA_loop, + lstm_param->hidden_size_, weights_loop, lstm_param->input_size_, 1.0, dX, lstm_param->input_size_, + mat_workspace); + weights_loop += lstm_param->hidden_size_ * lstm_param->input_size_; + dA_loop += seq_len; + } + + // calculate dH next + size_t dH_size = lstm_param->batch_ * lstm_param->hidden_size_ * sizeof(float); + memset(dH, 0, dH_size); + dA_loop = dAi; + weights_loop = v; + for (int idx = 0; idx < num_of_gates; idx++) { + GemmMatmul(0, 0, lstm_param->batch_, lstm_param->hidden_size_, lstm_param->hidden_size_, 1.0, dA_loop, + lstm_param->hidden_size_, weights_loop, lstm_param->hidden_size_, 1.0, dH, lstm_param->hidden_size_, + mat_workspace); + weights_loop += lstm_param->hidden_size_ * lstm_param->hidden_size_; + dA_loop += seq_len; + } + // calculate dC next + ElementMul(dC, forget_gate, dC, seq_len); +} + +void LstmGradDoWeightStep(float *input_t, float *prev_hidden_state, float *dA, float *dW, float *dV, float *dB, + float *workspace, const LstmGradParameter *lstm_param) { + // Calc dWi, dWo, dWf, dWg, dVi, dVo, dVf, dVg, dBi, dBo, dBf, dBg + int seq_len = lstm_param->batch_ * lstm_param->hidden_size_; + float *mat_workspace = AllocteFromScrachPad( + &workspace, GetGemmMatMullWorkspace(lstm_param->batch_, lstm_param->input_size_, lstm_param->hidden_size_)); + float *dA_loop = dA; // dAi, dAo, dAf, dAg + int dW_size = lstm_param->input_size_ * lstm_param->hidden_size_; + int dV_size = lstm_param->hidden_size_ * lstm_param->hidden_size_; + int dB_size = 0; + float *dW_loop = dW; + float *dV_loop = dV; + float *dB_loop = 0; + if (lstm_param->has_bias_) { + dB_loop = dB; + dB_size = lstm_param->hidden_size_; + } + + for (int idx = 0; idx < num_of_gates; idx++) { + // Calc dW + GemmMatmul(1, 0, lstm_param->hidden_size_, lstm_param->input_size_, lstm_param->batch_, 1.0, dA_loop, + lstm_param->hidden_size_, input_t, lstm_param->input_size_, 1.0, dW_loop, lstm_param->input_size_, + mat_workspace); + // Calc dV + GemmMatmul(1, 0, lstm_param->hidden_size_, lstm_param->hidden_size_, lstm_param->batch_, 1.0, dA_loop, + lstm_param->hidden_size_, prev_hidden_state, lstm_param->hidden_size_, 1.0, dV_loop, + lstm_param->hidden_size_, mat_workspace); + // Clac dB + if (dB_loop != 0) { + sumCols(lstm_param->batch_, lstm_param->hidden_size_, lstm_param->hidden_size_, dA_loop, dB_loop, true); + } + dA_loop += seq_len; + dW_loop += dW_size; + dV_loop += dV_size; + dB_loop += dB_size; + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/lstm_grad_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/lstm_grad_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..c92faf0308d3ae4e81127e08fa6968901f217eff --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/lstm_grad_fp32.h @@ -0,0 +1,70 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_GRAD_LSTM_GRAD_H_ +#define NNACL_FP32_GRAD_LSTM_GRAD_H_ + +#include "nnacl/op_base.h" + +typedef struct LstmGradParameter { + // Primitive parameter + OpParameter op_parameter_; + // shape correlative + int input_size_; + int hidden_size_; // output_size + int seq_len_; + int batch_; + // other parameter + int output_step_; + bool bidirectional_; + float zoneout_cell_; + float zoneout_hidden_; + int input_row_align_; + int input_col_align_; + int state_row_align_; + int state_col_align_; + int has_bias_; +} LstmGradParameter; + +#ifdef __cplusplus +extern "C" { +#endif + +const int *getLstmOrderIOFG(void); + +const int *getLstmOrderIFGO(void); + +int GetRunWorkspaceSize(const LstmGradParameter *lstm_param); + +size_t GetRunWorkspaceGemmOffset(const LstmGradParameter *lstm_param); + +void LstmGradReorderDy(float *src, float *dst, LstmGradParameter *lstm_param); + +void PackLstmWeightTranspose(float *dst, const float *src, int batch, int col, int row, int row_align, + const int *order); + +void ReorderLstmWeights(float *dst, const float *src, int nof_martices, int col, int row, const int *order); + +void LstmGradDoInputStep(const float *output_gate, float *cell_state, float *prev_cell_state, float *cell_gate, + float *input_gate, float *forget_gate, float *dY, float *dC, float *dH, float **dA, float *dX, + float *w, float *v, float *workspace, const LstmGradParameter *lstm_param); + +void LstmGradDoWeightStep(float *input_t, float *prev_hidden_state, float *dA, float *dW, float *dV, float *dB, + float *workspace, const LstmGradParameter *lstm_param); +#ifdef __cplusplus +} +#endif +#endif // NNACL_FP32_GRAD_LSTM_GRAD_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/maxpool_grad_grad.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/maxpool_grad_grad.c new file mode 100644 index 0000000000000000000000000000000000000000..269371363d44c6ea977ce7a63f01918111ef7a1d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/maxpool_grad_grad.c @@ -0,0 +1,147 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp32_grad/maxpool_grad_grad.h" +#include "nnacl/errorcode.h" + +int MaxPoolGradGrad(const float *input, const float *grad, float *output, size_t start, size_t end, + PoolingParameter *param, PoolingComputeParam *args) { + const int channel = args->input_channel_; + const int input_height = args->input_h_; + const int input_width = args->input_w_; + + const int window_height = args->window_h_; + const int window_width = args->window_w_; + + const int stride_height = param->stride_h_; + const int stride_width = param->stride_w_; + + const int pad_top = param->pad_u_; + const int pad_left = param->pad_l_; + + const int output_height = args->output_h_; + NNACL_CHECK_ZERO_RETURN_ERR(output_height); + const int output_width = args->output_w_; + NNACL_CHECK_ZERO_RETURN_ERR(output_width); + + const int output_chw = channel * output_height * output_width; + NNACL_CHECK_ZERO_RETURN_ERR(output_chw); + const int output_hw = output_height * output_width; + NNACL_CHECK_ZERO_RETURN_ERR(output_hw); + + for (size_t pos = start; pos < end; pos++) { + const int pos_n = pos / output_chw; + const int pos_c = pos / output_hw % channel; + const int pos_h = pos / output_width % output_height; + const int pos_w = pos % output_width; + + int h_start = pos_h * stride_height - pad_top; + int w_start = pos_w * stride_width - pad_left; + const int h_end = MSMIN(h_start + window_height, input_height); + const int w_end = MSMIN(w_start + window_width, input_width); + h_start = MSMAX(h_start, 0); + w_start = MSMAX(w_start, 0); + + int input_start = pos_n * channel * input_height * input_width + pos_c * input_height * input_width; + int max_idx = h_start * input_width + w_start; + float max_data = input[input_start + max_idx]; + + for (int h_cur = h_start; h_cur < h_end; ++h_cur) { + for (int w_cur = w_start; w_cur < w_end; ++w_cur) { + int input_idx = h_cur * input_width + w_cur; + float input_data = input[input_start + input_idx]; + if (input_data > max_data) { + max_idx = input_idx; + max_data = input_data; + } + } + } + output[pos] = grad[input_start + max_idx]; + } + return NNACL_OK; +} + +int MaxPool3DGradGrad(const float *input, const float *grad, float *output, size_t start, size_t end, + Pooling3DParameter *param, PoolingComputeParam *args) { + PoolingParameter *param_2d = (PoolingParameter *)(param); + const int channel = args->input_channel_; + const int input_depth = param->input_d_; + const int input_height = args->input_h_; + const int input_width = args->input_w_; + + const int window_depth = param->window_d_; + const int window_height = args->window_h_; + const int window_width = args->window_w_; + + const int stride_depth = param->stride_d_; + const int stride_height = param_2d->stride_h_; + const int stride_width = param_2d->stride_w_; + + const int pad_front = param->pad_f_; + const int pad_top = param_2d->pad_u_; + const int pad_left = param_2d->pad_l_; + + const int output_depth = param->output_d_; + NNACL_CHECK_ZERO_RETURN_ERR(output_depth); + const int output_height = args->output_h_; + NNACL_CHECK_ZERO_RETURN_ERR(output_height); + const int output_width = args->output_w_; + NNACL_CHECK_ZERO_RETURN_ERR(output_width); + + const int output_cdhw = channel * output_depth * output_height * output_width; + NNACL_CHECK_ZERO_RETURN_ERR(output_cdhw); + const int output_dhw = output_depth * output_height * output_width; + NNACL_CHECK_ZERO_RETURN_ERR(output_dhw); + const int output_hw = output_height * output_width; + NNACL_CHECK_ZERO_RETURN_ERR(output_hw); + + for (size_t pos = start; pos < end; pos++) { + const int pos_n = pos / output_cdhw; + const int pos_c = pos / output_dhw % channel; + const int pos_d = pos / output_hw % output_depth; + const int pos_h = pos / output_width % output_height; + const int pos_w = pos % output_width; + + int d_start = pos_d * stride_depth - pad_front; + int h_start = pos_h * stride_height - pad_top; + int w_start = pos_w * stride_width - pad_left; + const int d_end = MSMIN(d_start + window_depth, input_depth); + const int h_end = MSMIN(h_start + window_height, input_height); + const int w_end = MSMIN(w_start + window_width, input_width); + d_start = MSMAX(d_start, 0); + h_start = MSMAX(h_start, 0); + w_start = MSMAX(w_start, 0); + + int input_start = + pos_n * channel * input_depth * input_height * input_width + pos_c * input_depth * input_height * input_width; + int max_idx = d_start * input_height * input_width + h_start * input_width + w_start; + float max_data = input[input_start + max_idx]; + + for (int d_cur = d_start; d_cur < d_end; ++d_cur) { + for (int h_cur = h_start; h_cur < h_end; ++h_cur) { + for (int w_cur = w_start; w_cur < w_end; ++w_cur) { + int input_idx = d_cur * input_height * input_width + h_cur * input_width + w_cur; + float input_data = input[input_start + input_idx]; + if (input_data > max_data) { + max_idx = input_idx; + max_data = input_data; + } + } + } + } + output[pos] = grad[input_start + max_idx]; + } + return NNACL_OK; +} diff --git a/mindspore-lite/tools/graph_kernel/converter/expanders/concat.cc b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/maxpool_grad_grad.h similarity index 48% rename from mindspore-lite/tools/graph_kernel/converter/expanders/concat.cc rename to mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/maxpool_grad_grad.h index 59e3eb5273a2e6e542cd4d726f6b3ca50dd12b43..02ea0b4360a9385ccf12a7cb84e91c0fadfab036 100644 --- a/mindspore-lite/tools/graph_kernel/converter/expanders/concat.cc +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/maxpool_grad_grad.h @@ -14,25 +14,23 @@ * limitations under the License. */ -#include +#ifndef NNACL_FP32_GRAD_MAXPOOL_GRAD_GARD_H_ +#define NNACL_FP32_GRAD_MAXPOOL_GRAD_GARD_H_ -#include "backend/common/graph_kernel/expanders/op_desc_registry.h" +#include "nnacl/op_base.h" +#include "nnacl/pooling_parameter.h" +#include "nnacl/kernel/pooling.h" -namespace mindspore::graphkernel::expanders { -class Concat : public OpDesc { - public: - Concat() { - std::initializer_list attrs{"axis"}; - (void)validators_.emplace_back(std::make_unique(attrs)); - } - ~Concat() = default; +#ifdef __cplusplus +extern "C" { +#endif +int MaxPoolGradGrad(const float *input, const float *grad, float *output, size_t start, size_t end, + PoolingParameter *param, PoolingComputeParam *args); - protected: - NodePtrList Expand(const NodePtrList &inputs) override { - auto axis = GetValue(attrs_["axis"]); - auto result = gb.Concat(inputs, axis); - return {result}; - } -}; -EXPANDER_OP_DESC_REGISTER("Concat", Concat); -} // namespace mindspore::graphkernel::expanders +int MaxPool3DGradGrad(const float *input, const float *grad, float *output, size_t start, size_t end, + Pooling3DParameter *param, PoolingComputeParam *args); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_GRAD_MAXPOOL_GRAD_GARD_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/nllloss_grad_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/nllloss_grad_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..c5db5944c3ea9582a816a6f0e126e17523cfc21a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/nllloss_grad_fp32.c @@ -0,0 +1,42 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32_grad/nllloss_grad_fp32.h" + +#include "nnacl/op_base.h" +#include "nnacl/errorcode.h" + +int NLLLossGrad(const float *logits, const float *loss_grad, const int *labels, const float *weight, + const float *total_weight, float *logits_grad, int batch, int class_num, ReductionType reduction_type) { + if (logits == NULL || loss_grad == NULL || labels == NULL || weight == NULL || total_weight == NULL || + logits_grad == NULL) { + return NNACL_NULL_PTR; + } + + memset(logits_grad, 0, batch * class_num * sizeof(float)); + for (int i = 0; i < batch; i++) { + int index = i * class_num + labels[i]; + float n_weight = weight[labels[i]]; + if (reduction_type == Reduction_Sum) { + logits_grad[index] = -loss_grad[0] * n_weight; + } else if (reduction_type == Reduction_Mean) { + logits_grad[index] = -loss_grad[0] * n_weight / *total_weight; + } else { + logits_grad[index] = -loss_grad[i] * n_weight; + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/nllloss_grad_fp32.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/nllloss_grad_fp32.h new file mode 100644 index 0000000000000000000000000000000000000000..6a7166ddfbf6f3f6802ba7e217ab423c9d28d001 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/nllloss_grad_fp32.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_GRAD_NLLLOSS_GRAD_FP32_H_ +#define NNACL_FP32_GRAD_NLLLOSS_GRAD_FP32_H_ + +#include "nnacl/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +int NLLLossGrad(const float *logits, const float *loss_grad, const int *labels, const float *weight, + const float *total_weight, float *logits_grad, int batch, int class_num, ReductionType reduction_type); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_GRAD_NLLLOSS_GRAD_FP32_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/optimizer.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/optimizer.h new file mode 100644 index 0000000000000000000000000000000000000000..880c94fffcccaa389793d8cb27ffc28a6cae30db --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/optimizer.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_GRAD_OPTIMIZER_H_ +#define NNACL_FP32_GRAD_OPTIMIZER_H_ + +#include "nnacl/op_base.h" + +typedef struct { + OpParameter op_parameter_; + bool use_nesterov_; + float grad_scale_; +} ApplyMomentumParameter; + +typedef struct { + OpParameter op_parameter_; + float dampening_; + bool use_nesterov_; + float weight_decay_; +} SgdParameter; + +typedef struct { + OpParameter op_parameter_; + bool use_nesterov_; +} AdamParameter; + +#endif // NNACL_FP32_GRAD_OPTIMIZER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/pack_ext.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/pack_ext.c new file mode 100644 index 0000000000000000000000000000000000000000..bc1113b80d88808fe74168489d497e4ec6733f97 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/pack_ext.c @@ -0,0 +1,301 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "nnacl/fp32_grad/pack_ext.h" + +void RollingIm2ColPackDwUnitFp32(const float *in_data, const ConvParameter *conv_param, float *data_col_orig, + int real_cal_num, int start) { + const int pad_left = conv_param->pad_l_; + const int pad_up = conv_param->pad_u_; + + const int stride_h = conv_param->stride_h_; + const int stride_w = conv_param->stride_w_; + + const int dilation_h = conv_param->dilation_h_; + const int dilation_w = conv_param->dilation_w_; + + const int kernel_h = conv_param->kernel_h_; + const int kernel_w = conv_param->kernel_w_; + + const int in_height = conv_param->input_h_; + const int in_width = conv_param->input_w_; + + const int output_w = conv_param->output_w_; + + const int channels = conv_param->input_channel_; + const int stride = kernel_h * kernel_w; + + int kernel_row, kernel_col; + + for (int i = 0; i < real_cal_num; i++) { + int block_start = start + i; + int input_h = block_start / output_w * stride_h; + int input_w = block_start % output_w * stride_w; + float *data_col = data_col_orig + i * channels * stride; + for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + int input_row = -pad_up + kernel_row * dilation_h + input_h; + for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + int input_col = -pad_left + kernel_col * dilation_w + input_w; + if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { + const int offset = (input_row * in_width + input_col) * channels; + for (int c = 0; c < channels; c++) { + data_col[c * stride] = in_data[offset + c]; + } + data_col++; + } else { + for (int c = 0; c < channels; c++) { + data_col[c * stride] = 0; + } + data_col++; + } + } + } + } +} + +void rolling_im2col_hwc(const float *in_data, float *data_col, const ConvParameter *conv_param, int real_cal_num, + int start) { + const int pad_left = conv_param->pad_l_; + const int pad_up = conv_param->pad_u_; + + const int stride_h = conv_param->stride_h_; + const int stride_w = conv_param->stride_w_; + + const int dilation_h = conv_param->dilation_h_; + const int dilation_w = conv_param->dilation_w_; + + const int kernel_h = conv_param->kernel_h_; + const int kernel_w = conv_param->kernel_w_; + + const int in_height = conv_param->input_h_; + const int in_width = conv_param->input_w_; + + const int output_w = conv_param->output_w_; + + const int channels = conv_param->input_channel_ / conv_param->group_; + const int tot_channels = conv_param->input_channel_; + + int kernel_row, kernel_col; + + if (channels == 1) { + for (int i = 0; i < real_cal_num; i++) { + int block_start = start + i; + int input_h = block_start / output_w * stride_h; + int input_w = block_start % output_w * stride_w; + for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + int input_row = -pad_up + kernel_row * dilation_h + input_h; + for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + int input_col = -pad_left + kernel_col * dilation_w + input_w; + if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { + const int offset = (input_row * in_width + input_col) * tot_channels; + *data_col = in_data[offset]; + data_col++; + } else { + *data_col = 0; + data_col++; + } + } + } + } + } else { + for (int i = 0; i < real_cal_num; i++) { + int block_start = start + i; + int input_h = block_start / output_w * stride_h; + int input_w = block_start % output_w * stride_w; + for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + int input_row = -pad_up + kernel_row * dilation_h + input_h; + for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + int input_col = -pad_left + kernel_col * dilation_w + input_w; + if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { + const int offset = (input_row * in_width + input_col) * tot_channels; + memcpy(data_col, in_data + offset, sizeof(float) * channels); + data_col += channels; + } else { + memset(data_col, 0, sizeof(float) * channels); + data_col += channels; + } + } + } + } + } +} + +void RollingIm2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, + int real_cal_num, int block_index) { + rolling_im2col_hwc(input_data, packed_input, conv_param, real_cal_num, block_index); +} + +void rolling_im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv_param, int rows, int start) { + const int pad_left = conv_param->pad_l_; + const int pad_up = conv_param->pad_u_; + + const int stride_h = conv_param->stride_h_; + const int stride_w = conv_param->stride_w_; + + const int dilation_h = conv_param->dilation_h_; + const int dilation_w = conv_param->dilation_w_; + + const int kernel_h = conv_param->kernel_h_; + const int kernel_w = conv_param->kernel_w_; + + const int in_height = conv_param->output_h_; + const int in_width = conv_param->output_w_; + + const int output_w = conv_param->input_w_; + + const int tot_channels = conv_param->output_channel_; + const int channels = tot_channels / conv_param->group_; + int channel, kernel_row, kernel_col, output_rows, output_col; + for (channel = 0; channel < channels; channel++) { + for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + for (output_rows = start; output_rows < start + rows; output_rows++) { + int input_row = -pad_up + kernel_row * dilation_h + output_rows * stride_h; + if (!((unsigned)(input_row) < (unsigned)(in_height))) { + for (output_col = output_w; output_col; output_col--) { + *(data_row++) = 0; + } + } else { + int input_col = -pad_left + kernel_col * dilation_w; + for (output_col = output_w; output_col; output_col--) { + if (((unsigned)(input_col) < (unsigned)(in_width))) { + const int offset = (input_row * in_width + input_col) * tot_channels + channel; + *(data_row++) = in_data[offset]; + } else { + *(data_row++) = 0; + } + input_col += stride_w; + } + } + } + } + } + } +} + +void col2im_hwc(const float *data_col, float *data_im, const ConvParameter *conv_param) { + const int pad_left = conv_param->pad_l_; + const int pad_up = conv_param->pad_u_; + + const int stride_h = conv_param->stride_h_; + const int stride_w = conv_param->stride_w_; + + const int dilation_h = conv_param->dilation_h_; + const int dilation_w = conv_param->dilation_w_; + + const int kernel_h = conv_param->kernel_h_; + const int kernel_w = conv_param->kernel_w_; + + const int in_height = conv_param->input_h_; + const int in_width = conv_param->input_w_; + + const int output_h = conv_param->output_h_; + const int output_w = conv_param->output_w_; + const int channels = conv_param->input_channel_ / conv_param->group_; + const int tot_channels = conv_param->input_channel_; + + int kernel_row, kernel_col, output_rows, output_col; + + int row_stride_offset = 0; + + for (output_rows = output_h; output_rows; output_rows--) { + int col_stride_offset = 0; + for (output_col = output_w; output_col; output_col--) { + for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + int input_row = -pad_up + kernel_row * dilation_h + row_stride_offset; + for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + int input_col = -pad_left + kernel_col * dilation_w + col_stride_offset; + if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { + int offset = (input_row * in_width + input_col) * tot_channels; + float *data_im_ptr = data_im + offset; + for (int i = 0; i < channels; i++) { + data_im_ptr[i] += data_col[i]; + } + } + data_col += channels; + } + } + col_stride_offset += stride_w; + } + row_stride_offset += stride_h; + } +} + +void rolling_col2im_hwc(const float *data_col, float *data_im, const ConvParameter *conv_param, int rows, int start) { + const int pad_left = conv_param->pad_l_; + const int pad_up = conv_param->pad_u_; + + const int stride_h = conv_param->stride_h_; + const int stride_w = conv_param->stride_w_; + + const int dilation_h = conv_param->dilation_h_; + const int dilation_w = conv_param->dilation_w_; + + const int kernel_h = conv_param->kernel_h_; + const int kernel_w = conv_param->kernel_w_; + + const int in_height = conv_param->input_h_; + const int in_width = conv_param->input_w_; + + const int output_w = conv_param->output_w_; + const int channels = conv_param->input_channel_ / conv_param->group_; + const int tot_channels = conv_param->input_channel_; + + int kernel_row, kernel_col; + + if (channels == 1) { + for (int r = 0; r < rows; r++) { + int output_col = (start + r) % output_w; + int output_row = (start + r) / output_w; + int row_stride_offset = output_row * stride_h; + int col_stride_offset = output_col * stride_w; + for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + int input_row = -pad_up + kernel_row * dilation_h + row_stride_offset; + for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + int input_col = -pad_left + kernel_col * dilation_w + col_stride_offset; + if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { + int offset = (input_row * in_width + input_col) * tot_channels; + float *data_im_ptr = data_im + offset; + *data_im_ptr += *data_col; + } + data_col++; + } + } + } + } else { + for (int r = 0; r < rows; r++) { + int output_col = (start + r) % output_w; + int output_row = (start + r) / output_w; + int row_stride_offset = output_row * stride_h; + int col_stride_offset = output_col * stride_w; + for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + int input_row = -pad_up + kernel_row * dilation_h + row_stride_offset; + for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + int input_col = -pad_left + kernel_col * dilation_w + col_stride_offset; + if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { + int offset = (input_row * in_width + input_col) * tot_channels; + float *data_im_ptr = &data_im[offset]; + for (int i = 0; i < channels; i++) { + data_im_ptr[i] += data_col[i]; + } + } + data_col += channels; + } + } + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/pack_ext.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/pack_ext.h new file mode 100644 index 0000000000000000000000000000000000000000..ad4352fb7b3da89646340adc2f8be4dff22561ff --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/pack_ext.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_GRAD_PACK_EXT_H_ +#define NNACL_FP32_GRAD_PACK_EXT_H_ + +#include +#include "nnacl/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void RollingIm2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, + int real_cal_num, int block_index); +void RollingIm2ColPackDwUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, + int real_cal_num, int block_index); + +void rolling_im2col_hwc(const float *in_data, float *data_col, const ConvParameter *conv_param, int rows, int start); +void rolling_im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv_param, int rows, int start); +void rolling_col2im_hwc(const float *data_col, float *data_im, const ConvParameter *conv_param, int rows, int start); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_GRAD_PACK_EXT_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/pooling_grad.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/pooling_grad.c new file mode 100644 index 0000000000000000000000000000000000000000..2402a098f6a0de27ccbfad26f55d473016d6862f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/pooling_grad.c @@ -0,0 +1,190 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp32_grad/pooling_grad.h" +#include +#include +#include +#include "nnacl/op_base.h" + +void AvgPoolingGrad(const float *input_ptr, float *output_ptr, int count, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args) { + int stride_w = pooling_param->stride_w_; + int stride_h = pooling_param->stride_h_; + int pad_w = pooling_param->pad_l_; + int pad_h = pooling_param->pad_u_; + int win_w = pooling_args->window_w_; + int win_h = pooling_args->window_h_; + int channel = pooling_args->input_channel_; + int in_w = pooling_args->input_w_; + int in_h = pooling_args->input_h_; + int output_w = pooling_args->output_w_; + int output_h = pooling_args->output_h_; + + const float kk = 1.0f / (float)(win_h * win_w); +#if ENABLE_ARM + const float32x4_t factor = vdupq_n_f32(kk); +#endif + for (int ib = 0; ib < count; ib++) { + float *out = output_ptr + ib * in_h * in_w * channel; + const float *inPtr = input_ptr + ib * output_h * output_w * channel; + // iterate over yt + for (int yh = 0; yh < output_h; yh++) { + int over_h = pad_h - yh * stride_h; + int kh_s = MSMAX(0, over_h); + int kh_e = MSMIN(win_h, in_h + over_h); + for (int yw = 0; yw < output_w; yw++) { + int over_w = pad_w - yw * stride_w; + int kw_s = MSMAX(0, over_w); + int kw_e = MSMIN(win_w, in_w + over_w); + int ic = 0; + for (; ic < channel - C4NUM; ic += C4NUM) { + int idx = (yw + yh * output_w) * channel + ic; +#ifdef ENABLE_ARM + float32x4_t in = vld1q_f32(inPtr + idx); + float32x4_t delta = vmulq_f32(in, factor); +#else + float delta[C4NUM] = {inPtr[idx], inPtr[idx + C1NUM], inPtr[idx + C2NUM], inPtr[idx + C3NUM]}; + for (int i = 0; i < C4NUM; i++) delta[i] *= kk; +#endif + for (int kh = kh_s; kh < kh_e; kh++) { + int xh = yh * stride_h + kh - pad_h; + for (int kw = kw_s; kw < kw_e; kw++) { + int xw = yw * stride_w + kw - pad_w; +#ifdef ENABLE_ARM + float *out_vec = out + (xw + in_w * xh) * channel + ic; + float32x4_t outr = vld1q_f32(out + (xw + in_w * xh) * channel + ic); + float32x4_t outs = vaddq_f32(outr, delta); + vst1q_f32(out_vec, outs); +#else + + for (int i = 0; i < C4NUM; i++) { + out[(xw + in_w * xh) * channel + ic + i] += ((float *)&delta)[i]; + } +#endif + } + } + } + for (; ic < channel; ic++) { + int idx = (yw + yh * output_w) * channel + ic; + float delta = inPtr[idx] * kk; + for (int kh = kh_s; kh < kh_e; kh++) { + int xh = yh * stride_h + kh - pad_h; + for (int kw = kw_s; kw < kw_e; kw++) { + int xw = yw * stride_w + kw - pad_w; + out[(xw + in_w * xh) * channel + ic] += delta; + } + } + } + } + } + } +} + +#ifdef ENABLE_ARM +static int32x4_t MaxIndex(float32x4_t in, float32x4_t *max, int32x4_t index, int32x4_t prev_index) { + uint32x4_t res = vcgtq_f32(in, *max); + int32x4_t m_index = vbslq_s32(res, index, prev_index); + *max = vbslq_f32(res, in, *max); + return m_index; +} +#endif + +void MaxPoolingGrad(const float *input_ptr, const float *dy_ptr, float *output_ptr, int output_batch, + const PoolingParameter *pooling_param, const PoolingComputeParam *pooling_args) { + int stride_w = pooling_param->stride_w_; + int stride_h = pooling_param->stride_h_; + int pad_w = pooling_param->pad_l_; + int pad_h = pooling_param->pad_u_; + int win_w = pooling_args->window_w_; + int win_h = pooling_args->window_h_; + int channel = pooling_args->input_channel_; + int in_w = pooling_args->input_w_; + int in_h = pooling_args->input_h_; + int output_w = pooling_args->output_w_; + int output_h = pooling_args->output_h_; + for (int ib = 0; ib < output_batch; ib++) { + float *out = output_ptr + ib * in_h * in_w * channel; + const float *inPtr = input_ptr + ib * in_h * in_w * channel; + const float *dyPtr = dy_ptr + ib * output_h * output_w * channel; + for (int yh = 0; yh < output_h; yh++) { + int over_h = pad_h - yh * stride_h; + int kh_s = MSMAX(0, over_h); + int kh_e = MSMIN(win_h, in_h + over_h); + for (int yw = 0; yw < output_w; yw++) { + int over_w = pad_w - yw * stride_w; + int kw_s = MSMAX(0, over_w); + int kw_e = MSMIN(win_w, in_w + over_w); + int ic = 0; + for (; ic <= channel - C4NUM; ic += C4NUM) { + int idx = (yw + yh * output_w) * channel + ic; +#ifdef ENABLE_ARM + uint32x4_t max_idx = vdupq_n_u32(0); + float32x4_t max_val = vdupq_n_f32(-FLT_MAX); + float32x4_t delta = vld1q_f32(dyPtr + idx); +#else + float delta[C4NUM] = {dyPtr[idx], dyPtr[idx + C1NUM], dyPtr[idx + C2NUM], dyPtr[idx + C3NUM]}; + float max_val[C4NUM] = {-FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX}; + int max_idx[C4NUM] = {0}; +#endif + for (int kh = kh_s; kh < kh_e; kh++) { + int xh = yh * stride_h + kh - pad_h; + for (int kw = kw_s; kw < kw_e; kw++) { + int xw = yw * stride_w + kw - pad_w; + int val_idx = (xw + in_w * xh) * channel + ic; +#ifdef ENABLE_ARM + uint32x4_t index = {val_idx, val_idx + 1, val_idx + 2, val_idx + 3}; + float32x4_t in = vld1q_f32(inPtr + val_idx); + max_idx = vreinterpretq_u32_s32( + MaxIndex(in, &max_val, vreinterpretq_s32_u32(index), vreinterpretq_s32_u32(max_idx))); +#else + float val[C4NUM] = {inPtr[val_idx], inPtr[val_idx + C1NUM], inPtr[val_idx + C2NUM], + inPtr[val_idx + C3NUM]}; + for (int i = 0; i < C4NUM; i++) { + if (val[i] > max_val[i]) { + max_val[i] = val[i]; + max_idx[i] = val_idx + i; + } + } +#endif + } + } + for (int i = 0; i < C4NUM; i++) { + out[((int *)&max_idx)[i]] += ((float *)&delta)[i]; + } + } + for (; ic < channel; ic++) { + float max_val = -FLT_MAX; + int max_idx = 0; + int idx = (yw + yh * output_w) * channel + ic; + float delta = dyPtr[idx]; + for (int kh = kh_s; kh < kh_e; kh++) { + int xh = yh * stride_h + kh - pad_h; + for (int kw = kw_s; kw < kw_e; kw++) { + int xw = yw * stride_w + kw - pad_w; + int val_idx = (xw + in_w * xh) * channel + ic; + float val = inPtr[val_idx]; + if (val > max_val) { + max_val = val; + max_idx = val_idx; + } + } + } + out[max_idx] += delta; + } + } + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/pooling_grad.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/pooling_grad.h new file mode 100644 index 0000000000000000000000000000000000000000..1410f839ef2ef12aaf901ea4a2ca08eec2dc3b5b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/pooling_grad.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_GRAD_POOLING_GRAD_H_ +#define NNACL_FP32_GRAD_POOLING_GRAD_H_ + +#include "nnacl/fp32/pooling_fp32.h" +#include "nnacl/kernel/pooling.h" + +#ifdef __cplusplus +extern "C" { +#endif +void AvgPoolingGrad(const float *input_ptr, float *output_ptr, int count, const PoolingParameter *pooling_param, + const PoolingComputeParam *pooling_args); +void MaxPoolingGrad(const float *input_ptr, const float *dy_ptr, float *output_ptr, int output_batch, + const PoolingParameter *pooling_param, const PoolingComputeParam *pooling_args); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_GRAD_POOLING_GRAD_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/reduce_grad.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/reduce_grad.c new file mode 100644 index 0000000000000000000000000000000000000000..c8738699c735fe656a51df6bcd18c662710ea2ed --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/reduce_grad.c @@ -0,0 +1,89 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/fp32_grad/reduce_grad.h" +#include "nnacl/fp32_grad/utils.h" +#include "nnacl/op_base.h" + +void ReduceMeanByAxes(const float *input_data, int *input_iter, const int *input_dims, int input_num_dims, + const int *axes, int num_axes, float *output_data, const int *output_dims, int output_num_dims) { + size_t num_outputs = 1; + for (int idx = 0; idx < output_num_dims; ++idx) { + size_t current = (size_t)(output_dims[idx]); + num_outputs *= current; + } + + // Reset input iterator. + for (int idx = 0; idx < input_num_dims; ++idx) { + input_iter[idx] = 0; + } + // Iterate through input_data. + do { + size_t input_offset = GetInputOffset(input_num_dims, input_dims, input_iter); + size_t output_offset = GetOutputOffset(input_num_dims, input_dims, input_iter, num_axes, axes); + output_data[output_offset] += input_data[input_offset]; + } while (NextIndex(input_num_dims, input_dims, input_iter)); + + // Calculate mean by dividing output_data by num of aggregated element. + size_t num_elements_in_axis = 1; + for (int idx = 0; idx < num_axes; ++idx) { + size_t current = (size_t)(input_dims[axes[idx]]); + num_elements_in_axis *= current; + } + + for (size_t idx = 0; idx < num_outputs; ++idx) { + output_data[idx] = output_data[idx] / (float)(num_elements_in_axis); + } +} + +float ReduceMeanAll(const float *src, int size) { + float sum = 0; + for (int i = 0; i < size; ++i) { + sum += src[i]; + } + return sum / size; +} + +void ReduceSumByAxes(const float *input, const int *input_dims, float *output, const int *output_dims, int num_dims) { + int num_outputs = 1; + int same_shape = 1; + for (int idx = 0; idx < num_dims; ++idx) { + num_outputs *= output_dims[idx]; + if (output_dims[idx] != input_dims[idx]) same_shape = 0; + } + if (same_shape) { + memcpy(output, input, (size_t)(num_outputs) * sizeof(float)); + return; + } + + memset(output, 0, (size_t)(num_outputs) * sizeof(float)); // zero output + + int input_iter[C8NUM] = {0}; + int axes[C5NUM] = {0}; + int num_axes = 0; + for (int i = 0; i < num_dims; i++) { + if (output_dims[i] == C1NUM && num_axes < C5NUM) { + axes[num_axes++] = i; + } + } + + // Iterate through input_data. + do { + size_t input_offset = GetInputOffset(num_dims, input_dims, input_iter); + size_t output_offset = GetOutputOffset(num_dims, input_dims, input_iter, num_axes, axes); + output[output_offset] += input[input_offset]; + } while (NextIndex(num_dims, input_dims, input_iter)); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/reduce_grad.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/reduce_grad.h new file mode 100644 index 0000000000000000000000000000000000000000..edb610254cc86fc7d60a4ba58ae2cdf827224e15 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/reduce_grad.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_GRAD_REDUCE_GRAD_H_ +#define NNACL_FP32_GRAD_REDUCE_GRAD_H_ + +#include + +#ifdef __cplusplus +extern "C" { +#endif +float ReduceMeanAll(const float *src, int size); +void ReduceSumByAxes(const float *input, const int *input_dims, float *output, const int *output_dims, int num_dims); +#ifdef __cplusplus +} +#endif +#endif // NNACL_FP32_GRAD_REDUCE_GRAD_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/resize_grad.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/resize_grad.c new file mode 100644 index 0000000000000000000000000000000000000000..3fe386d8c388261930ccf08cd00707733d3b73fb --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/resize_grad.c @@ -0,0 +1,149 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp32_grad/resize_grad.h" +#include +#include "nnacl/infer/common_infer.h" +#include "nnacl/errorcode.h" + +int ResizeNearestNeighborGrad(const float *in_addr, float *out_addr, int batch_size, int channel, int format, + const ResizeGradParameter *param) { + bool align_corners = param->align_corners_; + size_t in_hw_size = param->in_width_ * param->in_height_; + size_t out_hw_size = param->out_width_ * param->out_height_; + + if (format == Format_NHWC) { + NNACL_CHECK_ZERO_RETURN_ERR(param->in_width_); + for (int32_t b = 0; b < batch_size; ++b) { + for (size_t i = 0; i < in_hw_size; ++i) { + size_t in_y = i / param->in_width_; + size_t in_x = i % param->in_width_; + for (size_t c = 0; c < (size_t)channel; ++c) { + size_t out_y = MSMIN( + (align_corners) ? (size_t)roundf(in_y * param->height_scale_) : (size_t)floorf(in_y * param->height_scale_), + param->out_height_ - 1); + size_t out_x = MSMIN( + (align_corners) ? (size_t)roundf(in_x * param->width_scale_) : (size_t)floorf(in_x * param->width_scale_), + param->out_width_ - 1); + size_t out_offset = out_y * (param->out_width_ * channel) + (out_x * channel) + c; + size_t in_offset = in_y * (param->in_width_ * channel) + (in_x * channel) + c; + out_addr[out_offset] += in_addr[in_offset]; + } + } + out_addr += out_hw_size * channel; + in_addr += in_hw_size * channel; + } + } else if (format == Format_NCHW) { + for (int32_t b = 0; b < batch_size; ++b) { + for (size_t c = 0; c < (size_t)channel; ++c) { + for (size_t h = 0; h < param->in_height_; ++h) { + size_t out_y = + MSMIN((align_corners) ? (size_t)roundf(h * param->height_scale_) : (size_t)floorf(h * param->height_scale_), + param->out_height_ - 1); + for (size_t w = 0; w < param->in_width_; ++w) { + size_t out_x = + MSMIN((align_corners) ? (size_t)roundf(w * param->width_scale_) : (size_t)floorf(w * param->width_scale_), + param->out_width_ - 1); + out_addr[out_y * param->out_width_ + out_x] += in_addr[h * param->in_width_ + w]; + } + } + out_addr += out_hw_size; + in_addr += in_hw_size; + } + } + } + return NNACL_OK; +} + +int ResizeBiLinearGrad(const float *in_addr, float *out_addr, int batch_size, int channel, int format, + const ResizeGradParameter *param) { + size_t in_hw_size = param->in_width_ * param->in_height_; + size_t out_hw_size = param->out_width_ * param->out_height_; + + if (format == Format_NHWC) { + NNACL_CHECK_ZERO_RETURN_ERR(param->in_width_); + for (int32_t b = 0; b < batch_size; ++b) { + for (size_t i = 0; i < in_hw_size; ++i) { + size_t h = i / param->in_width_; + size_t w = i % param->in_width_; + for (size_t c = 0; c < (size_t)channel; ++c) { + float in_y = (float)h * param->height_scale_; + size_t top_y_index = MSMAX((size_t)(floorf(in_y)), (size_t)(0)); + size_t bottom_y_index = MSMIN((size_t)(ceilf(in_y)), param->out_height_ - 1); + float y_lerp = in_y - floorf(in_y); + const float inverse_y_lerp = 1.0 - y_lerp; + + float in_x = (float)w * param->width_scale_; + size_t left_x_index = MSMAX((size_t)(floorf(in_x)), (size_t)(0)); + size_t right_x_index = MSMIN((size_t)(ceilf(in_x)), param->out_width_ - 1); + float x_lerp = in_x - floorf(in_x); + const float inverse_x_lerp = 1.0 - x_lerp; + + size_t in_offset = h * (param->in_width_ * channel) + (w * channel) + c; + size_t out_offset_top_y_left_x = top_y_index * (param->out_width_ * channel) + (left_x_index * channel) + c; + size_t out_offset_top_y_right_x = top_y_index * (param->out_width_ * channel) + (right_x_index * channel) + c; + size_t out_offset_bottom_y_left_x = + bottom_y_index * (param->out_width_ * channel) + (left_x_index * channel) + c; + size_t out_offset_bottom_y_right_x = + bottom_y_index * (param->out_width_ * channel) + (right_x_index * channel) + c; + + out_addr[out_offset_top_y_left_x] += in_addr[in_offset] * (float)(inverse_y_lerp * inverse_x_lerp); + out_addr[out_offset_top_y_right_x] += in_addr[in_offset] * (float)(inverse_y_lerp * x_lerp); + out_addr[out_offset_bottom_y_left_x] += in_addr[in_offset] * (float)(y_lerp * inverse_x_lerp); + out_addr[out_offset_bottom_y_right_x] += in_addr[in_offset] * (float)(y_lerp * x_lerp); + } + } + out_addr += out_hw_size * channel; + in_addr += in_hw_size * channel; + } + } else if (format == Format_NCHW) { + size_t in_height = param->in_height_; + size_t in_width = param->in_width_; + size_t out_height = param->out_height_; + size_t out_width = param->out_width_; + out_hw_size = out_height * out_width; + in_hw_size = in_height * in_width; + + for (int32_t b = 0; b < batch_size; ++b) { + for (size_t c = 0; c < (size_t)channel; ++c) { + for (size_t h = 0; h < in_height; ++h) { + const float in_y = (float)(h)*param->height_scale_; + const size_t top_y_index = MSMAX((size_t)floorf(in_y), 0); + const size_t bottom_y_index = MSMIN((size_t)ceilf(in_y), out_height - 1); + const float y_lerp = in_y - floorf(in_y); + const float inverse_y_lerp = 1.0 - y_lerp; + for (size_t w = 0; w < in_width; ++w) { + const float in_x = (float)(w)*param->width_scale_; + const size_t left_x_index = MSMAX((size_t)floorf(in_x), 0); + const size_t right_x_index = MSMIN((size_t)ceilf(in_x), out_width - 1); + const float x_lerp = in_x - floorf(in_x); + const float inverse_x_lerp = 1.0 - x_lerp; + out_addr[top_y_index * out_width + left_x_index] += + in_addr[h * in_width + w] * (float)(inverse_y_lerp * inverse_x_lerp); + out_addr[top_y_index * out_width + right_x_index] += + in_addr[h * in_width + w] * (float)(inverse_y_lerp * x_lerp); + out_addr[bottom_y_index * out_width + left_x_index] += + in_addr[h * in_width + w] * (float)(y_lerp * inverse_x_lerp); + out_addr[bottom_y_index * out_width + right_x_index] += + in_addr[h * in_width + w] * (float)(y_lerp * x_lerp); + } + } + out_addr += out_hw_size; + in_addr += in_hw_size; + } + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/resize_grad.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/resize_grad.h new file mode 100644 index 0000000000000000000000000000000000000000..b0195e281243eff363763e090924bc3a284ff3a2 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/resize_grad.h @@ -0,0 +1,33 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_GRAD_RESIZE_GRAD_H_ +#define NNACL_FP32_GRAD_RESIZE_GRAD_H_ + +#include "nnacl/fp32_grad/resize_grad_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ResizeNearestNeighborGrad(const float *in_addr, float *out_addr, int batch_size, int channel, int format, + const ResizeGradParameter *param); +int ResizeBiLinearGrad(const float *in_addr, float *out_addr, int batch_size, int channel, int format, + const ResizeGradParameter *param); +#ifdef __cplusplus +} +#endif +#endif // NNACL_FP32_GRAD_RESIZE_GRAD_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/resize_grad_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/resize_grad_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..5ee9686526a9b878b345df968222309509ab45fe --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/resize_grad_parameter.h @@ -0,0 +1,34 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_GRAD_RESIZE_PARAMETER_GRAD_H_ +#define NNACL_FP32_GRAD_RESIZE_PARAMETER_GRAD_H_ + +#include "nnacl/op_base.h" + +typedef struct ResizeGradParameter { + OpParameter op_parameter_; + bool align_corners_; + int method; + size_t in_height_; + size_t in_width_; + size_t out_height_; + size_t out_width_; + float height_scale_; + float width_scale_; +} ResizeGradParameter; + +#endif // NNACL_FP32_GRAD_RESIZE_PARAMETER_GRAD_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/smooth_l1_loss.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/smooth_l1_loss.h new file mode 100644 index 0000000000000000000000000000000000000000..489d769ea38ade299abe7f4d4da122889fc35152 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/smooth_l1_loss.h @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_SMOOTH_L1_LOSS_PARAMETER_H_ +#define NNACL_FP32_SMOOTH_L1_LOSS_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct { + OpParameter op_parameter_; + float beta_; +} SmoothL1LossParameter; + +#endif // NNACL_FP32_SMOOTH_L1_LOSS_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/softmax_cross_entropy_with_logits.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/softmax_cross_entropy_with_logits.c new file mode 100644 index 0000000000000000000000000000000000000000..1d024e00f8b8531ffb3ad1c2cac6db6260cde483 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/softmax_cross_entropy_with_logits.c @@ -0,0 +1,43 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32_grad/softmax_cross_entropy_with_logits.h" +#include + +void ForwardPostExecute(const float *labels, const float *logits, float *grads, float *output2, + size_t number_of_classes, int batch_size) { + float eps = 1e-6; + if (grads != NULL) { + for (size_t i = 0; i < (size_t)(batch_size); ++i) { + float loss = 0.f; + for (size_t j = 0; j < number_of_classes; ++j) { + float logit = -logf(logits[i * number_of_classes + j] <= 0.0 ? eps : logits[i * number_of_classes + j]); + grads[i * number_of_classes + j] = (logits[i * number_of_classes + j] - labels[i * number_of_classes + j]); + loss += labels[i * number_of_classes + j] * logit; + } + output2[i] = loss; + } + } else { + for (size_t i = 0; i < (size_t)(batch_size); ++i) { + float loss = 0.f; + for (size_t j = 0; j < number_of_classes; ++j) { + float logit = -logf(logits[i * number_of_classes + j] <= 0.0 ? eps : logits[i * number_of_classes + j]); + loss += labels[i * number_of_classes + j] * logit; + } + output2[i] = loss; + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/softmax_cross_entropy_with_logits.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/softmax_cross_entropy_with_logits.h new file mode 100644 index 0000000000000000000000000000000000000000..7cd53bc4a9d4214f3f39c716635333a3b1f35019 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/softmax_cross_entropy_with_logits.h @@ -0,0 +1,33 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_GRAD_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_H_ +#define NNACL_FP32_GRAD_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_H_ + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +void ForwardPostExecute(const float *labels, const float *logits, float *grads, float *output2, + size_t number_of_classes, int batch_size); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_GRAD_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/softmax_crossentropy_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/softmax_crossentropy_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..edf10c00b6afbd3220894d06d9785f24654f80d9 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/softmax_crossentropy_parameter.h @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_GRAD_SOFTMAX_CROSSENTROPY_PARAMETER_H_ +#define NNACL_FP32_GRAD_SOFTMAX_CROSSENTROPY_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct SoftmaxCrossEntropyParameter { + // primitive parameter + OpParameter op_parameter_; + int n_dim_; + + // shape correlative + int input_shape_[5]; + + // other parameter + int32_t batch_size_; + unsigned int number_of_classes_; + bool is_grad_; +} SoftmaxCrossEntropyParameter; + +#endif // NNACL_FP32_GRAD_SOFTMAX_CROSSENTROPY_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/softmax_grad.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/softmax_grad.c new file mode 100644 index 0000000000000000000000000000000000000000..6a9fd23146126e3a4eb4a3d5c4241e9f723f2252 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/softmax_grad.c @@ -0,0 +1,59 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32_grad/softmax_grad.h" +#include + +void SoftmaxGrad(const float *input_ptr, const float *yt_ptr, float *output_ptr, float *sum_data, float *sum_mul, + const int *input_shape, int n_dim, int ele_size, int32_t axis) { + int dim = 1; + int inner_size = 1, outter_size = 1; + for (int i = 0; i < axis; i++) { + outter_size *= input_shape[i]; + } + for (int i = axis + 1; i < n_dim; i++) { + inner_size *= input_shape[i]; + } + NNACL_CHECK_ZERO_RETURN(outter_size); + for (int i = 0; i < inner_size * input_shape[axis]; i++) sum_mul[i] = 1.0; + for (int i = 0; i < n_dim; i++) dim *= input_shape[i]; + dim /= outter_size; + memcpy(output_ptr, yt_ptr, (size_t)(ele_size) * sizeof(float)); + + const int M = input_shape[axis]; + const int N = inner_size; + for (int i = 0; i < outter_size; i++) { + int outter_offset = i * dim; + memset(sum_data, 0, (size_t)(inner_size) * sizeof(float)); + for (int k = 0; k < inner_size; k++) { + int inner_offset = outter_offset + k; + for (int j = 0; j < input_shape[axis]; j++) { + int offset = inner_offset + j * inner_size; + sum_data[k] += output_ptr[offset] * input_ptr[offset]; + } + } + for (int k = 0; k < M; ++k) { + float a = -sum_mul[k]; + for (int j = 0; j < N; ++j) { + *(output_ptr + outter_offset + k * N + j) += a * sum_data[j]; + } + } + } + + for (int i = 0; i < ele_size; i++) { + output_ptr[i] *= input_ptr[i]; + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/softmax_grad.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/softmax_grad.h new file mode 100644 index 0000000000000000000000000000000000000000..6270136ceb76cc68bab92972b9309f0493d37178 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/softmax_grad.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_GRAD_SOFTMAX_GRAD_H_ +#define NNACL_FP32_GRAD_SOFTMAX_GRAD_H_ + +#include "nnacl/fp32/softmax_fp32.h" +#include "nnacl/fp32_grad/softmax_crossentropy_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void SoftmaxGrad(const float *input_ptr, const float *yt_ptr, float *output_ptr, float *sum_data, float *sum_mul, + const int *input_shape, int n_dim, int ele_size, int32_t axis); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_GRAD_SOFTMAX_GRAD_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/softmax_grad_utils.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/softmax_grad_utils.c new file mode 100644 index 0000000000000000000000000000000000000000..22da3f2893127dd3817dc79030b5e4f6ee364143 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/softmax_grad_utils.c @@ -0,0 +1,102 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32_grad/softmax_grad_utils.h" +#include +#include +#include "nnacl/fp32/exp_fp32.h" +#include "nnacl/intrinsics/ms_simd_instructions.h" + +void ExpFp32Offset(const float *src, float *dst, float sub_bias, int num) { + int i = 0; +#ifdef ENABLE_ARM64 + int count = (num / C4NUM) * C4NUM; + for (; i < count; i += C4NUM) { + MS_FLOAT32X4 input = vld1q_f32(src + i); + MS_FLOAT32X4 bias = vdupq_n_f32(sub_bias); + MS_FLOAT32X4 i1 = vsubq_f32(input, bias); + simd_exp128(i1, dst + i); + } +#endif + for (; i < num; ++i) { + simd_exp32(src[i] - sub_bias, dst + i); + } +} + +// output = exp(input) / reduce_sum(exp(input), axis) +static void SoftMaxP1Simple(const float *input_ptr, float *output_ptr, float *sum_data, int start, int count, + int length) { + for (int i = start; i < start + count; i++) { + int inner_offset = i * length; + float max_data = input_ptr[inner_offset]; + for (int j = 0; j < length; j++) { + int axis_offset = inner_offset + j; + max_data = max_data > input_ptr[axis_offset] ? max_data : input_ptr[axis_offset]; + } + ExpFp32Offset(input_ptr + inner_offset, output_ptr + inner_offset, max_data, length); + float _sum_data = 0; + for (int j = 0; j < length; j++) { + int axis_offset = inner_offset + j; + _sum_data += output_ptr[axis_offset]; + } + sum_data[i] = _sum_data; + } +} + +void SoftMaxP1(const float *input_ptr, float *output_ptr, float *sum_data, int start, int count, int length, + int inner_size) { + if (inner_size == 1) { + SoftMaxP1Simple(input_ptr, output_ptr, sum_data, start, count, length); + return; + } + for (int i = start; i < start + count; i++) { + int outter_offset = i * length * inner_size; + int sum_outter_offset = i * inner_size; + for (int k = 0; k < inner_size; k++) { + int inner_offset = outter_offset + k; + float max_data = input_ptr[inner_offset]; + for (int j = 0; j < length; j++) { + int axis_offset = inner_offset + j * inner_size; + max_data = max_data > input_ptr[axis_offset] ? max_data : input_ptr[axis_offset]; + } + for (int j = 0; j < length; j++) { + int axis_offset = inner_offset + j * inner_size; + output_ptr[axis_offset] = exp(input_ptr[axis_offset] - max_data); + } + float _sum_data = 0; + for (int j = 0; j < length; j++) { + int axis_offset = inner_offset + j * inner_size; + _sum_data += output_ptr[axis_offset]; + } + sum_data[k + sum_outter_offset] = _sum_data; + } + } +} + +void SoftMaxP2(const float *input_ptr, float *output_ptr, const float *sum_data, int start, int count, int length, + int inner_size) { + for (int i = start; i < start + count; i++) { + int outter_offset = i * length * inner_size; + int sum_outter_offset = i * inner_size; + for (int j = 0; j < length; j++) { + int axis_offset = outter_offset + j * inner_size; + for (int k = 0; k < inner_size; k++) { + int inner_offset = axis_offset + k; + output_ptr[inner_offset] = output_ptr[inner_offset] / sum_data[k + sum_outter_offset]; + } + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/softmax_grad_utils.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/softmax_grad_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..68aac00d02245d7aeba0fd48e4a2653201a3925b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/softmax_grad_utils.h @@ -0,0 +1,33 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_FP32_GRAD_SOFTMAX_H_ +#define NNACL_FP32_GRAD_SOFTMAX_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +void SoftMaxP1(const float *input_ptr, float *output_ptr, float *sum_data, int start, int count, int length, + int inner_size); +void SoftMaxP2(const float *input_ptr, float *output_ptr, const float *sum_data, int start, int count, int length, + int inner_size); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_GRAD_SOFTMAX_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/strided_slice_grad.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/strided_slice_grad.c new file mode 100644 index 0000000000000000000000000000000000000000..ae5f9ce9dbc5feb705def7f62f7e7d3368ca6b28 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/strided_slice_grad.c @@ -0,0 +1,68 @@ +/** + * Copyright 2019-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32_grad/strided_slice_grad.h" +#include "nnacl/errorcode.h" + +static size_t CalcIndex(const int *shape, size_t size, int i, size_t pos) { + size_t res = 1; + for (size_t j = 0; j < size; j++) { + res *= shape[((size_t)(i) + 1) + j]; + } + NNACL_CHECK_ZERO_RETURN_ERR(res); + NNACL_CHECK_ZERO_RETURN_ERR(shape[i]); + return (pos / res % shape[i]); +} + +int DoStridedSliceGrad(const float *inputs, float *output, const int *dx_shape, const StridedSliceParameter *param) { + if (inputs == NULL || output == NULL || param == NULL) { + return NNACL_NULL_PTR; + } + if (param->num_axes_ > DIMENSION_8D) { + return NNACL_PARAM_INVALID; + } + + size_t size = 1; + const int *s = param->strides_; + const int *b = param->begins_; + for (int i = 0; i < DIMENSION_8D; i++) { + size *= (size_t)(param->in_shape_[i]); + } + + for (size_t pos = 0; pos < size; pos++) { + size_t i = CalcIndex(param->in_shape_, C7NUM, C0NUM, pos); + size_t j = CalcIndex(param->in_shape_, C6NUM, C1NUM, pos); + size_t k = CalcIndex(param->in_shape_, C5NUM, C2NUM, pos); + size_t l = CalcIndex(param->in_shape_, C4NUM, C3NUM, pos); + size_t m = CalcIndex(param->in_shape_, C3NUM, C4NUM, pos); + size_t n = CalcIndex(param->in_shape_, C2NUM, C5NUM, pos); + size_t o = CalcIndex(param->in_shape_, C1NUM, C6NUM, pos); + size_t p = CalcIndex(param->in_shape_, C0NUM, C7NUM, pos); + size_t input_idx = + (i * s[C0NUM] + b[C0NUM]) * dx_shape[C1NUM] * dx_shape[C2NUM] * dx_shape[C3NUM] * dx_shape[C4NUM] * + dx_shape[C5NUM] * dx_shape[C6NUM] * dx_shape[C7NUM] + + (j * s[C1NUM] + b[C1NUM]) * dx_shape[C2NUM] * dx_shape[C3NUM] * dx_shape[C4NUM] * dx_shape[C5NUM] * + dx_shape[C6NUM] * dx_shape[C7NUM] + + (k * s[C2NUM] + b[C2NUM]) * dx_shape[C3NUM] * dx_shape[C4NUM] * dx_shape[C5NUM] * dx_shape[C6NUM] * + dx_shape[C7NUM] + + (l * s[C3NUM] + b[C3NUM]) * dx_shape[C4NUM] * dx_shape[C5NUM] * dx_shape[C6NUM] * dx_shape[C7NUM] + + (m * s[C4NUM] + b[C4NUM]) * dx_shape[C5NUM] * dx_shape[C6NUM] * dx_shape[C7NUM] + + (n * s[C5NUM] + b[C5NUM]) * dx_shape[C6NUM] * dx_shape[C7NUM] + (o * s[C6NUM] + b[C6NUM]) * dx_shape[C7NUM] + + (p * s[C7NUM] + b[C7NUM]); + output[input_idx] = inputs[pos]; + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/strided_slice_grad.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/strided_slice_grad.h new file mode 100644 index 0000000000000000000000000000000000000000..48b38fd6bb56d7185b2ca330bf4fe40efb2aab29 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/strided_slice_grad.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP32_GRAD_STRIDED_SLICE_GRAD_H_ +#define NNACL_FP32_GRAD_STRIDED_SLICE_GRAD_H_ + +#include "nnacl/op_base.h" +#include "nnacl/strided_slice_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int DoStridedSliceGrad(const float *inputs, float *output, const int *dx_shape, const StridedSliceParameter *param); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_GRAD_STRIDED_SLICE_GRAD_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/utils.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..80e5384257c8ea20d6f7bebe9bcd69cce699bb89 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_grad/utils.h @@ -0,0 +1,72 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_FP32_GRAD_UTILS_H_ +#define NNACL_FP32_GRAD_UTILS_H_ + +#include "nnacl/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +static inline size_t GetInputOffset(int num_dims, const int *dims, const int *iter) { + size_t offset = 0; + for (int idx = 0; idx < num_dims; ++idx) { + offset = offset * (size_t)(dims[idx]) + (size_t)(iter[idx]); + } + + return offset; +} + +static inline size_t GetOutputOffset(int num_dims, const int *dims, const int *iter, int num_axis, const int *axes) { + size_t offset = 0; + for (int idx = 0; idx < num_dims; ++idx) { + // if we need to skip this axis + int is_axis = 0; + for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) { + if (idx == axes[axis_idx]) { + is_axis = 1; + break; + } + } + + if (is_axis == 0) { + offset = offset * (size_t)(dims[idx]) + (size_t)(iter[idx]); + } + } + return offset; +} + +static inline int NextIndex(int num_dims, const int *dims, int *current) { + int carry = 1; + for (int idx = num_dims - 1; idx >= 0; --idx) { + int current_val = current[idx] + carry; + if (dims[idx] == current_val) { + current[idx] = 0; + } else { + current[idx] = current_val; + carry = 0; + break; + } + } + return (carry == 0); +} + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_FP32_GRAD_UTILS_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/fp32_sparse/matmul_sparse_x1_fp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_sparse/matmul_sparse_x1_fp32.c new file mode 100644 index 0000000000000000000000000000000000000000..67fedc4f34f2ce6b4acc28fc1e04e1ee4dbebb60 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_sparse/matmul_sparse_x1_fp32.c @@ -0,0 +1,54 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32_sparse/matmul_sparse_x1_fp32.h" +#ifdef ENABLE_ARM64 +#include +#endif + +void MatMulSparse8x8(const float *a, const float *b, const uint32_t *nnz, const size_t *dmap, float *c, + const float *bias, ActType act_type, int out_stride) { +#ifndef ENABLE_ARM64 + return; +#else + // mul-acc + for (int oc = 0; oc < 8; oc++) { + uint32_t cur_nnz = nnz[oc]; + // init 8x1 C with bias + float32x4_t vacc1 = vld1q_dup_f32(bias + oc); + float32x4_t vacc2 = vacc1; + for (uint32_t nz = 0; nz < cur_nnz; nz++) { + // load w + float32x4_t vw = vld1q_dup_f32(b++); + // load 8 inputs + const float *input = a + (*(dmap++) / sizeof(float)); + float32x4_t vi1 = vld1q_f32(input); + float32x4_t vi2 = vld1q_f32(input + 4); + vacc1 = vfmaq_f32(vacc1, vi1, vw); + vacc2 = vfmaq_f32(vacc2, vi2, vw); + } + // save output + *(c + oc) = vacc1[0]; + *(c + 1 * out_stride + oc) = vacc1[1]; + *(c + 2 * out_stride + oc) = vacc1[2]; + *(c + 3 * out_stride + oc) = vacc1[3]; + *(c + 4 * out_stride + oc) = vacc2[0]; + *(c + 5 * out_stride + oc) = vacc2[1]; + *(c + 6 * out_stride + oc) = vacc2[2]; + *(c + 7 * out_stride + oc) = vacc2[3]; + } +#endif +} diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/distribution/distribution_base.h b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_sparse/matmul_sparse_x1_fp32.h similarity index 48% rename from mindspore-lite/src/extendrt/delegate/tensorrt/distribution/distribution_base.h rename to mindspore-lite/ops/kernel/cpu/nnacl/fp32_sparse/matmul_sparse_x1_fp32.h index fc3ba4bbfd1be51bef522aeeb1a6ae6c3e16e9aa..b124475312ed9489ce3dbd387446696982e0ea0e 100644 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/distribution/distribution_base.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/fp32_sparse/matmul_sparse_x1_fp32.h @@ -14,18 +14,28 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_DISTRIBUTION_DISTRIBUTION_BASE_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_DISTRIBUTION_DISTRIBUTION_BASE_H_ +#ifndef NNACL_FP32_MATMUL_SPARSE_X1_H_ +#define NNACL_FP32_MATMUL_SPARSE_X1_H_ -#include -#include "src/common/log_adapter.h" -#include "include/errorcode.h" +#include +#include +#include "nnacl/errorcode.h" +#include "nnacl/matmul_parameter.h" +#include "nnacl/op_base.h" -namespace mindspore::lite { -const char NCCL_WORLD_GROUP[] = "nccl_world_group"; +#ifdef __cplusplus +extern "C" { +#endif -int GetGPUGroupSize(); +#ifdef ENABLE_ARM64 +void SPMM8x8Fp32(const float *a, const float *b, const uint32_t *nnz, const size_t *dmap, float *c, const float *bias, + ActType act_type, size_t out_stride); +#endif -int GetRankID(); -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_DISTRIBUTION_DISTRIBUTION_BASE_H_ +void MatMulSparse8x8(const float *a, const float *b, const uint32_t *nnz, const size_t *dmap, float *c, + const float *bias, ActType act_type, int out_stride); + +#ifdef __cplusplus +} +#endif +#endif // NNACL_FP32_MATMUL_SPARSE_X1_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/gather_nd_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/gather_nd_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..f73d0cd82b6a43f4e612fc2197641c08baaa5f70 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/gather_nd_parameter.h @@ -0,0 +1,26 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_GATHER_ND_PARAMETER_H_ +#define NNACL_GATHER_ND_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct { + OpParameter op_parameter_; +} GatherNdParameter; + +#endif // NNACL_GATHER_ND_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/gather_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/gather_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..36df544b3a15413c181198e83759b43ebcf5e336 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/gather_parameter.h @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_GATHER_PARAMETER_H_ +#define NNACL_GATHER_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct GatherParameter { + // Primitive parameter + OpParameter op_parameter_; + int axis_; +} GatherParameter; + +#endif // NNACL_GATHER_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/gelu_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/gelu_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..994acc06dbe742fca877829b182527c976205eec --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/gelu_parameter.h @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_GELU_PARAMETER_H_ +#define NNACL_GELU_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct GeLUParameter { + // Primitive parameter + OpParameter op_parameter_; + bool approximate_; +} GeLUParameter; + +#endif // NNACL_GELU_PARAMETER_H_ diff --git a/mindspore-lite/examples/quick_start_ios/mindspore-lite/ViewController.m b/mindspore-lite/ops/kernel/cpu/nnacl/glu_parameter.h similarity index 74% rename from mindspore-lite/examples/quick_start_ios/mindspore-lite/ViewController.m rename to mindspore-lite/ops/kernel/cpu/nnacl/glu_parameter.h index 6b9a5ed66eb3e2cb1349505d6c2008d3668699ac..2b147b4c720323dafd58147083b828d3d4ec223e 100644 --- a/mindspore-lite/examples/quick_start_ios/mindspore-lite/ViewController.m +++ b/mindspore-lite/ops/kernel/cpu/nnacl/glu_parameter.h @@ -13,19 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#ifndef NNACL_GLU_PARAMETER_H_ +#define NNACL_GLU_PARAMETER_H_ -#import "ViewController.h" +#include "nnacl/op_base.h" -@interface ViewController () +typedef struct GluParameter { + OpParameter op_parameter_; + int axis_; +} GluParameter; -@end - -@implementation ViewController - -- (void)viewDidLoad { - [super viewDidLoad]; - // Do any additional setup after loading the view. -} - - -@end +#endif // NNACL_GLU_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/grid_sampler_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/grid_sampler_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..6c87a5400fe0ebcde62841abbce247f6ced20054 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/grid_sampler_parameter.h @@ -0,0 +1,28 @@ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_GRID_SAMPLER_PARAMETER_H_ +#define NNACL_GRID_SAMPLER_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct GridSamplerParameter { + OpParameter op_parameter_; + int64_t interpolation_mode_; + int64_t padding_mode_; + bool align_corners_; +} GridSamplerParameter; + +#endif // NNACL_GRID_SAMPLER_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/group_norm_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/group_norm_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..990cbc32feaeb55173ae6d1acbeb12a246786625 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/group_norm_parameter.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_GROUP_NORM_PARAMETER_H_ +#define NNACL_GROUP_NORM_PARAMETER_H_ + +#include "nnacl/op_base.h" +#include "nnacl/int8/quantize.h" +typedef struct GroupNormParameter { + // Primitive parameter + OpParameter op_parameter_; + float epsilon_; + int num_groups_; + int channel_; + int unit_; + int batch_; + bool affine_; + void *mean_; + void *variance_; +} GroupNormParameter; + +typedef struct GroupNormQuantArg { + int32_t in_zp_; + int32_t out_zp_; + double in_scale_; + double out_scale_; +} GroupNormQuantArg; + +#endif // NNACL_GROUP_NORM_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/gru_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/gru_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..abd1185b8c35842b8e187c139e59cb39c1d7230c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/gru_parameter.h @@ -0,0 +1,38 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_GRU_PARAMETER_H_ +#define NNACL_GRU_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct GruParameter { + // Primitive parameter + OpParameter op_parameter_; + // shape correlative + int input_size_; + int hidden_size_; // output_size + int seq_len_; + int batch_; + // other parameter + int output_step_; + bool bidirectional_; + int input_row_align_; + int input_col_align_; + int state_row_align_; + int state_col_align_; +} GruParameter; + +#endif // NNACL_GRU_PARAMETER_H_ diff --git a/mindspore-lite/src/extendrt/kernel/ascend/plugin/ascend_kernel_plugin.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/activation_grad_infer.c similarity index 41% rename from mindspore-lite/src/extendrt/kernel/ascend/plugin/ascend_kernel_plugin.h rename to mindspore-lite/ops/kernel/cpu/nnacl/infer/activation_grad_infer.c index a85395b3c6c33cf6e13f8dab3203809ab3c89e5e..9b2b77e2e3224bfd51c1de9f73999b244a693f3c 100644 --- a/mindspore-lite/src/extendrt/kernel/ascend/plugin/ascend_kernel_plugin.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/activation_grad_infer.c @@ -14,36 +14,32 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_ASCEND_KERNEL_PLUGIN_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_ASCEND_KERNEL_PLUGIN_H_ +#include "nnacl/infer/activation_grad_infer.h" +#include "nnacl/infer/infer_register.h" -#include -#include -#include -#include -#include -#include "common/kernel.h" -#include "include/api/status.h" +int ActivationGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + const TensorC *input = inputs[0]; + const TensorC *input_grad = inputs[1]; + if (input->shape_size_ != input_grad->shape_size_) { + return NNACL_ERR; + } + for (size_t i = 0; i < input->shape_size_; i++) { + if (input->shape_[i] != input_grad->shape_[i]) { + return NNACL_ERR; + } + } -namespace mindspore::kernel { -using KernelModFunc = std::function()>; + SetDataTypeFormat(outputs[0], inputs[0]); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(outputs[0], inputs[0]); + return NNACL_OK; +} -class AscendKernelPlugin { - public: - static Status TryRegister(); - static bool Register(); - - private: - AscendKernelPlugin(); - ~AscendKernelPlugin(); - Status TryRegisterInner(); - void Unregister(); - - void *handle_ = nullptr; - std::map *create_kernel_map_ = nullptr; - std::vector register_kernels_; - bool is_registered_ = false; - static std::mutex mutex_; -}; -} // namespace mindspore::kernel -#endif // MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_ASCEND_KERNEL_PLUGIN_H_ +REG_INFER(ActivationGrad, PrimType_ActivationGrad, ActivationGradInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/activation_grad_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/activation_grad_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..3ce5ec55c45ec18c0e662e84c3cbaf3749c66e8b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/activation_grad_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_ACTIVATION_GRAD_INFER_H +#define MINDSPORE_NNACL_ACTIVATION_GRAD_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ActivationGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ACTIVATION_GRAD_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/adam_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/adam_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..f2f12e0c96a0170b83388d433a50e5d4c266a3c7 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/adam_infer.c @@ -0,0 +1,44 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/infer/adam_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensor_c_utils.h" + +int AdamInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, 10); + if (check_ret != NNACL_OK) { + return check_ret; + } + + if (NNACLGetElementNum(inputs[0]) != NNACLGetElementNum(inputs[1]) || + NNACLGetElementNum(inputs[0]) != NNACLGetElementNum(inputs[2]) || + NNACLGetElementNum(inputs[0]) != NNACLGetElementNum(inputs[9]) || NNACLGetElementNum(inputs[3]) != 1 || + NNACLGetElementNum(inputs[4]) != 1 || NNACLGetElementNum(inputs[5]) != 1 || NNACLGetElementNum(inputs[6]) != 1 || + NNACLGetElementNum(inputs[7]) != 1 || NNACLGetElementNum(inputs[8]) != 1) { + return NNACL_ERR; + } + if (outputs_size != 0) { + TensorC *out = outputs[0]; + SetDataTypeFormat(out, inputs[0]); + out->shape_size_ = 1; + out->shape_[0] = 1; + } + + return NNACL_OK; +} + +REG_INFER(Adam, PrimType_Adam, AdamInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/adam_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/adam_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..ab6e4f0c95e7b1f3ce38e7cf3d2bf256594c918b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/adam_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_ADAM_INFER_H +#define MINDSPORE_NNACL_ADAM_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int AdamInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ADAM_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/adam_weight_decay_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/adam_weight_decay_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..1cde38a400f8dc660ee80b8138a46070613ab926 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/adam_weight_decay_infer.c @@ -0,0 +1,56 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/adam_weight_decay_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensor_c_utils.h" + +int AdamWeightDecayInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + const size_t expected_inputs_size = 10; + const int var_idx = 0; + const int m_idx = 1; + const int v_idx = 2; + const int lr_idx = 3; + const int beta1_idx = 4; + const int beta2_idx = 5; + const int epsilon = 6; + const int decay_idx = 7; + const int grad_idx = 8; + int check_ret = + CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, expected_inputs_size); + if (check_ret != NNACL_OK) { + return check_ret; + } + + if (NNACLGetElementNum(inputs[var_idx]) != NNACLGetElementNum(inputs[m_idx]) || + NNACLGetElementNum(inputs[var_idx]) != NNACLGetElementNum(inputs[v_idx]) || + NNACLGetElementNum(inputs[var_idx]) != NNACLGetElementNum(inputs[grad_idx]) || + NNACLGetElementNum(inputs[lr_idx]) != 1 || NNACLGetElementNum(inputs[beta1_idx]) != 1 || + NNACLGetElementNum(inputs[beta2_idx]) != 1 || NNACLGetElementNum(inputs[epsilon]) != 1 || + NNACLGetElementNum(inputs[decay_idx]) != 1) { + return NNACL_ERR; + } + if (outputs_size != 0) { + TensorC *out = outputs[0]; + SetDataTypeFormat(out, inputs[0]); + out->shape_size_ = 1; + out->shape_[0] = 1; + } + return NNACL_OK; +} + +REG_INFER(AdamWeightDecay, PrimType_AdamWeightDecay, AdamWeightDecayInferShape) diff --git a/mindspore-lite/src/common/draw/adapter_graphs/drawer_mark_filter.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/adam_weight_decay_infer.h similarity index 58% rename from mindspore-lite/src/common/draw/adapter_graphs/drawer_mark_filter.h rename to mindspore-lite/ops/kernel/cpu/nnacl/infer/adam_weight_decay_infer.h index bdf9b1fe82f7441e77f6113f24bee84a9a68dc6c..30803dc5b4168aded490c2dbf4097ebbe8ff2388 100644 --- a/mindspore-lite/src/common/draw/adapter_graphs/drawer_mark_filter.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/adam_weight_decay_infer.h @@ -14,16 +14,19 @@ * limitations under the License. */ -#ifdef ENABLE_DRAW -#ifndef MINDSPORE_LITE_SRC_COMMON_DRAW_ADAPTER_GRAPHS_DRAWER_MARK_FILTER_H_ -#define MINDSPORE_LITE_SRC_COMMON_DRAW_ADAPTER_GRAPHS_DRAWER_MARK_FILTER_H_ +#ifndef MINDSPORE_NNACL_ADAM_WEIGHT_DECAY_INFER_H +#define MINDSPORE_NNACL_ADAM_WEIGHT_DECAY_INFER_H -#include -#include "src/executor/kernel_exec.h" +#include "nnacl/infer/common_infer.h" -namespace mindspore::lite { -using MarkFilter = std::function; -} // namespace mindspore::lite +#ifdef __cplusplus +extern "C" { +#endif + +int AdamWeightDecayInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); -#endif // MINDSPORE_LITE_SRC_COMMON_DRAW_ADAPTER_GRAPHS_DRAWER_MARK_FILTER_H_ +#ifdef __cplusplus +} #endif +#endif // MINDSPORE_NNACL_ADAM_WEIGHT_DECAY_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/add_sub_grad_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/add_sub_grad_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..2977a8fefc0b4ecb0cc5a184abb790d473a1c200 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/add_sub_grad_infer.c @@ -0,0 +1,62 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/add_sub_grad_infer.h" +#include "nnacl/arithmetic_parameter.h" +#include "nnacl/infer/infer_register.h" + +int AddSubGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 2); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *dy = inputs[0]; + const TensorC *x1 = inputs[1]; + const TensorC *x2 = inputs[2]; + TensorC *dx1 = outputs[0]; + TensorC *dx2 = outputs[1]; + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + ArithmeticParameter *param = (ArithmeticParameter *)parameter; + + param->ndim_ = dy->shape_size_; + param->in_elements_num0_ = (int)param->ndim_; + param->in_elements_num1_ = (int)param->ndim_; + param->out_elements_num_ = (int)param->ndim_; + size_t fillDimNum0 = dy->shape_size_ - x1->shape_size_; + size_t fillDimNum1 = dy->shape_size_ - x2->shape_size_; + size_t j0 = 0; + size_t j1 = 0; + for (size_t i = 0; i < dy->shape_size_; i++) { + param->in_shape0_[i] = (i < fillDimNum0) ? 1 : x1->shape_[j0++]; + param->in_shape1_[i] = (i < fillDimNum1) ? 1 : x2->shape_[j1++]; + param->out_shape_[i] = dy->shape_[i]; + } + + SetShapeTensor(dx1, x1); + SetShapeTensor(dx2, x2); + SetDataTypeFormat(dx1, dy); + SetDataTypeFormat(dx2, dy); + return NNACL_OK; +} + +REG_INFER(AddGrad, PrimType_AddGrad, AddSubGradInferShape) +REG_INFER(SubGrad, PrimType_SubGrad, AddSubGradInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/add_sub_grad_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/add_sub_grad_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..870cb1db0fe3e65490a1dadc11ee46093eeb974d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/add_sub_grad_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_ADD_SUB_GRAD_INFER_H +#define MINDSPORE_NNACL_ADD_SUB_GRAD_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int AddSubGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ADD_SUB_GRAD_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/addn_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/addn_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..10030677d241b29c0eca381e946d498d2612322f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/addn_infer.c @@ -0,0 +1,86 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/addn_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensor_c_utils.h" + +int AddnInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + if (inputs_size < 2) { + return NNACL_ERR; + } + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + size_t max_dims = input->shape_size_; + size_t max_dims_idx = 0; + + // check zerp dimension + for (size_t i = 0; i < max_dims; i++) { + NNACL_CHECK_FALSE(input->shape_[i] == 0, NNACL_ERR); + } + + // determine max_dims + for (size_t i = 1; i < inputs_size; ++i) { + if (inputs[i]->shape_size_ > max_dims) { + max_dims = inputs[i]->shape_size_; + max_dims_idx = i; + } + } + ShapeSet(output->shape_, &output->shape_size_, inputs[max_dims_idx]->shape_, inputs[max_dims_idx]->shape_size_); + + // make sure all elements have the same size or 1 (broadcasting) in all dimensions + for (size_t i = 1; i < inputs_size; ++i) { + if ((inputs[i]->shape_size_ != max_dims) && + (NNACLGetElementNum(inputs[i]) != NNACLGetElementNum(inputs[max_dims_idx]))) { + return NNACL_ERR; + } + if (inputs[i]->shape_size_ == max_dims) { + for (size_t j = 0; j < max_dims; j++) { + if (inputs[i]->shape_[j] != inputs[max_dims_idx]->shape_[j] && inputs[i]->shape_[j] != 1 && + inputs[max_dims_idx]->shape_[j] != 1) { + return NNACL_ERR; + } + } + } + } + + for (size_t d = 0; d < inputs[max_dims_idx]->shape_size_; ++d) { + size_t max_dim = 0; + for (size_t i = 0; i < inputs_size; ++i) { + size_t shift = max_dims - (size_t)(inputs[i]->shape_size_); + size_t dim = (i < shift) ? 1 : (size_t)(inputs[i]->shape_[d]); + if (dim > max_dim) { + max_dim = dim; + } + } + output->shape_[d] = (int)(max_dim); // set the biggest dimension in the output tensor + } + + return NNACL_OK; +} + +REG_INFER(AddN, PrimType_AddN, AddnInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/addn_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/addn_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..a0c889c600f39c6d11f8e65fec32cfd63d5da2a0 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/addn_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_ADDN_INFER_H +#define MINDSPORE_NNACL_ADDN_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int AddnInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ADDN_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/affine_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/affine_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..bbcc3495512e5831a20053117f50d265c6481b63 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/affine_infer.c @@ -0,0 +1,122 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/affine_infer.h" +#include "nnacl/infer/infer_register.h" + +int MatmulInfer(const AffineParameter *param, int a_shape[MAX_SHAPE_SIZE], size_t a_shape_size, + int b_shape[MAX_SHAPE_SIZE], size_t b_shape_size) { + MatMulParameter *matmul_param = param->matmul_parameter_; + NNACL_CHECK_NULL_RETURN_ERR(matmul_param); + if (matmul_param->a_transpose_) { + if (a_shape_size < 2) { + return NNACL_ERR; + } + iswap(&a_shape[a_shape_size - 1], &a_shape[a_shape_size - 2]); + } + if (matmul_param->b_transpose_) { + if (b_shape_size < 2) { + return NNACL_ERR; + } + iswap(&b_shape[b_shape_size - 1], &b_shape[b_shape_size - 2]); + } + return NNACL_OK; +} + +int AffineInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 3, 4, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + // splice + matmul + TensorC *input0 = (TensorC *)inputs[0]; + TensorC *input1 = (TensorC *)inputs[1]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input0); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + AffineParameter *param = (AffineParameter *)parameter; + if (param == NULL) { + return NNACL_NULL_PTR; + } + + int a_shape[MAX_SHAPE_SIZE] = {0}; + size_t a_shape_size = 0; + ShapeSet(a_shape, &a_shape_size, input0->shape_, input0->shape_size_); + if (a_shape_size == 4 && a_shape[2] == 1 && a_shape[3] == 1) { + a_shape_size = 2; + SetShapeArray(input0, a_shape, a_shape_size); + } + int context_min = param->context_[0]; + int context_max = param->context_[param->context_size_ - 1]; + + a_shape[1] = input0->shape_[1] - (context_max - context_min); + a_shape[2] = param->output_dim_; + + int b_shape[MAX_SHAPE_SIZE] = {0}; + size_t b_shape_size = 0; + ShapeSet(b_shape, &b_shape_size, input1->shape_, input1->shape_size_); + + bool del_start = false; + bool del_end = false; + if (a_shape_size == 1) { + int ret = ShapeInsert(a_shape, &a_shape_size, 0, 1); + if (ret != NNACL_OK) { + return NNACL_ERR; + } + SetShapeArray(input0, a_shape, a_shape_size); + del_start = true; + } + if (b_shape_size == 1) { + ShapePush(b_shape, &b_shape_size, 1); + SetShapeArray(input1, b_shape, b_shape_size); + del_end = true; + } + for (size_t i = 0; i < (a_shape_size - 2) && i < (b_shape_size - 2); ++i) { + if (a_shape[a_shape_size - 3 - i] != b_shape[b_shape_size - 3 - i]) { + return NNACL_INPUT_TENSOR_ERROR; + } + } + + int ret = MatmulInfer(param, a_shape, a_shape_size, b_shape, b_shape_size); + if (ret != NNACL_OK) { + return ret; + } + + int c_shape[MAX_SHAPE_SIZE]; + size_t c_shape_size = 0; + ShapeSet(c_shape, &c_shape_size, a_shape, a_shape_size); + if (c_shape_size < 1 || b_shape_size < 1) { + return NNACL_ERR; + } + c_shape[c_shape_size - 1] = b_shape[b_shape_size - 1]; + if (del_start) { + int erase_ret = ShapeErase(c_shape, &c_shape_size, 0); + if (erase_ret != NNACL_OK) { + return NNACL_ERR; + } + } + if (del_end) { + c_shape_size--; + } + SetShapeArray(output, c_shape, c_shape_size); + return NNACL_OK; +} + +REG_INFER(Affine, PrimType_Affine, AffineInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/affine_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/affine_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..d347429eb7c140bb9aca397407e28c5ee64341b5 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/affine_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_INFER_AFFINE_INFER_H_ +#define MINDSPORE_NNACL_INFER_AFFINE_INFER_H_ +#include "nnacl/infer/common_infer.h" +#include "nnacl/affine_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int AffineInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_AFFINE_INFER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/all_gather_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/all_gather_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..946cc9fda56036e4453bef64da4eb7257dcd671e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/all_gather_infer.c @@ -0,0 +1,54 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/all_gather_infer.h" +#include "nnacl/infer/infer_register.h" + +int AllGatherInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + if (inputs_size != 1 || outputs_size != 1) { + return NNACL_NULL_PTR; + } + if (parameter == NULL || inputs[0] == NULL || outputs[0] == NULL) { + return NNACL_NULL_PTR; + } + SetDataTypeFormat(outputs[0], inputs[0]); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + AllGatherParameter *param = (AllGatherParameter *)parameter; + if (param->rank_size_ <= 0) { + return NNACL_INFER_INVALID; + } + + const TensorC *input_tensor = inputs[0]; + const int *in_shape = input_tensor->shape_; + TensorC *out_tensor = outputs[0]; + + int out_shape[MAX_SHAPE_SIZE]; + size_t out_shape_size = 0; + out_shape[0] = in_shape[0] * param->rank_size_; + out_shape_size++; + for (int i = 1; i < input_tensor->shape_size_; i++) { + out_shape[i] = in_shape[i]; + out_shape_size++; + } + SetShapeArray(out_tensor, out_shape, out_shape_size); + return NNACL_OK; +} + +REG_INFER(AllGather, PrimType_AllGather, AllGatherInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/all_gather_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/all_gather_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..4f43657b88af4cbe4e9d0954901fc9ab9bbfa6df --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/all_gather_infer.h @@ -0,0 +1,33 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_INFER_ALL_GATHER_INFER_H_ +#define MINDSPORE_NNACL_INFER_ALL_GATHER_INFER_H_ + +#include "nnacl/infer/common_infer.h" +#include "nnacl/all_gather_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int AllGatherInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_ALL_GATHER_INFER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/apply_momentum_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/apply_momentum_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..c911ad3384bf8c19c392d1ba9ccacfc317da5fc9 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/apply_momentum_infer.c @@ -0,0 +1,47 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/apply_momentum_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensor_c_utils.h" + +int ApplyMomentumInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, 5); + if (check_ret != NNACL_OK) { + return check_ret; + } + + if (NNACLGetElementNum(inputs[0]) != NNACLGetElementNum(inputs[1]) || + NNACLGetElementNum(inputs[0]) != NNACLGetElementNum(inputs[3]) || NNACLGetElementNum(inputs[2]) != 1 || + NNACLGetElementNum(inputs[4]) != 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + if (outputs_size != 0) { + TensorC *out = outputs[0]; + if (out == NULL) { + return NNACL_NULL_PTR; + } + out->data_type_ = inputs[0]->data_type_; + out->format_ = inputs[0]->format_; + out->shape_size_ = 1; + out->shape_[0] = 1; + } + + return NNACL_OK; +} + +REG_INFER(ApplyMomentum, PrimType_ApplyMomentum, ApplyMomentumInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/apply_momentum_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/apply_momentum_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..f7460f0b3b7fa4aa32a8faaddbcf7eedd21dba14 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/apply_momentum_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_APPLY_MOMENTUM_INFER_H +#define MINDSPORE_NNACL_APPLY_MOMENTUM_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ApplyMomentumInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_APPLY_MOMENTUM_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/argmin_max_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/argmin_max_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..3608e762e1d9a77c0406f97f369e5ca77dcb5867 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/argmin_max_infer.c @@ -0,0 +1,83 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/argmin_max_infer.h" +#include "nnacl/infer/infer_register.h" + +int ArgMinMaxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret != NNACL_OK) { + return check_ret; + } + if (inputs_size != 1 || outputs_size > 2) { + return NNACL_ERR; + } + + ArgMinMaxParameter *param = (ArgMinMaxParameter *)parameter; + const TensorC *input = inputs[0]; + TensorC *output_1 = NULL; + TensorC *output_2 = NULL; + if (outputs_size == 2) { + output_1 = outputs[0]; + output_2 = outputs[1]; + } else if (param->out_value_) { + output_2 = outputs[0]; + } else { + output_1 = outputs[0]; + } + + if (output_1 != NULL) { + output_1->format_ = input->format_; + output_1->data_type_ = kNumberTypeInt32; + } + if (output_2 != NULL) { + SetDataTypeFormat(output_2, input); + } + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + ShapeSet(output_shape, &output_shape_size, input->shape_, input->shape_size_); + int input_shape_size = (int)input->shape_size_; + int axis = param->axis_ < 0 ? param->axis_ + input_shape_size : param->axis_; + if (axis >= input_shape_size || axis < 0) { + return NNACL_PARAM_INVALID; + } + if (param->topk_ == 1 && !param->keep_dims_) { + int erase_ret = ShapeErase(output_shape, &output_shape_size, axis); + if (erase_ret != NNACL_OK) { + return NNACL_ERR; + } + } else { + output_shape[axis] = param->topk_; + } + + if (output_1 != NULL) { + SetShapeArray(output_1, output_shape, output_shape_size); + } + if (output_2 != NULL) { + SetShapeArray(output_2, output_shape, output_shape_size); + } + return NNACL_OK; +} + +REG_INFER(ArgMin, PrimType_ArgMinFusion, ArgMinMaxInferShape) +REG_INFER(ArgMax, PrimType_ArgMaxFusion, ArgMinMaxInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/argmin_max_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/argmin_max_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..2febbc78501e6f530f07bea9563c4ca1d977fd25 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/argmin_max_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_ARGMAX_INFER_H +#define MINDSPORE_NNACL_ARGMAX_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/arg_min_max_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ArgMinMaxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ARGMAX_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/arithmetic_compare_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/arithmetic_compare_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..655a158250d21692adf931f9c4766dbdd4f08a0d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/arithmetic_compare_infer.c @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/arithmetic_compare_infer.h" +#include "nnacl/infer/infer_register.h" + +int ArithmeticCompareInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int res = ArithmeticInferShape(inputs, inputs_size, outputs, outputs_size, parameter); + TensorC *output = outputs[0]; + if (output == NULL) { + return NNACL_NULL_PTR; + } + output->data_type_ = kNumberTypeBool; + return res; +} + +REG_INFER(Equal, PrimType_Equal, ArithmeticCompareInferShape) +REG_INFER(Greater, PrimType_Greater, ArithmeticCompareInferShape) +REG_INFER(GreaterEqual, PrimType_GreaterEqual, ArithmeticCompareInferShape) +REG_INFER(Less, PrimType_Less, ArithmeticCompareInferShape) +REG_INFER(LessEqual, PrimType_LessEqual, ArithmeticCompareInferShape) +REG_INFER(NotEqual, PrimType_NotEqual, ArithmeticCompareInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/arithmetic_compare_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/arithmetic_compare_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..5513fb1219eecbf6a28790748785c4ff5aa22c2d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/arithmetic_compare_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_ARITHMETIC_COMPARE_INFER_H +#define MINDSPORE_NNACL_ARITHMETIC_COMPARE_INFER_H + +#include "nnacl/infer/arithmetic_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ArithmeticCompareInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ARITHMETIC_COMPARE_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/arithmetic_grad_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/arithmetic_grad_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..aa2c992ec0c1516f7f71518899da639ae76ed795 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/arithmetic_grad_infer.c @@ -0,0 +1,107 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/arithmetic_grad_infer.h" +#include "nnacl/arithmetic_parameter.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensor_c_utils.h" + +/* + * the Arithmetic Grad op include AddGrad, SubGrad, MulGrad, DivGrad, MaximumGrad, MinimumGrad + * according to the arithmetic_fp32.h now + * the MaximumGrad, MinimumGrad run through MaximumGradInfershape + * the AddGrad, SubGrad run through AddSubGradInfershape + * the others run through this function + * */ +int ArithmeticGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 2); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *dy = inputs[0]; + const TensorC *x1 = inputs[1]; + const TensorC *x2 = inputs[2]; + TensorC *dx1 = outputs[0]; + TensorC *dx2 = outputs[1]; + + if (dy->shape_size_ > MAX_SHAPE_SIZE || x1->shape_size_ > MAX_SHAPE_SIZE || x2->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int in_shape0[MAX_SHAPE_SIZE] = {0}; + size_t in_shape0_size = 0; + ShapeSet(in_shape0, &in_shape0_size, x1->shape_, x1->shape_size_); + int in_shape1[MAX_SHAPE_SIZE] = {0}; + size_t in_shape1_size = 0; + ShapeSet(in_shape1, &in_shape1_size, x2->shape_, x2->shape_size_); + int out_shape[MAX_SHAPE_SIZE] = {0}; + size_t out_shape_size = 0; + ShapeSet(out_shape, &out_shape_size, dy->shape_, dy->shape_size_); + + ArithmeticParameter *param = (ArithmeticParameter *)parameter; + + if (NNACLGetElementNum(dx1) < NNACLGetElementNum(dx2)) { + param->ndim_ = in_shape1_size; + param->in_elements_num0_ = (int)param->ndim_; + param->in_elements_num1_ = (int)param->ndim_; + param->out_elements_num_ = (int)param->ndim_; + size_t fill_dim_num = in_shape1_size - in_shape0_size; // This will not work for batch! + int j = 0; + for (unsigned int i = 0; i < in_shape1_size; i++) { + if (i < fill_dim_num) { + param->in_shape1_[i] = 1; + } else { + param->in_shape1_[i] = in_shape0[j++]; + } + param->in_shape0_[i] = in_shape1[i]; + param->out_shape_[i] = out_shape[i]; + } + } else if (NNACLGetElementNum(dx2) < NNACLGetElementNum(dx1)) { + param->ndim_ = in_shape0_size; + param->in_elements_num0_ = (int)param->ndim_; + param->in_elements_num1_ = (int)param->ndim_; + param->out_elements_num_ = (int)param->ndim_; + param->broadcasting_ = true; + int j = 0; + size_t fill_dim_num = in_shape0_size - in_shape1_size; + for (unsigned int i = 0; i < in_shape0_size; i++) { + if (i < fill_dim_num) { + param->in_shape1_[i] = 1; + } else { + param->in_shape1_[i] = in_shape1[j++]; + } + param->in_shape0_[i] = in_shape0[i]; + param->out_shape_[i] = out_shape[i]; + } + } else { + param->broadcasting_ = false; + for (unsigned int i = 0; i < in_shape0_size; i++) { + param->in_shape1_[i] = in_shape1[i]; + param->in_shape0_[i] = in_shape0[i]; + param->out_shape_[i] = out_shape[i]; + } + } + + SetShapeTensor(dx1, x1); + SetShapeTensor(dx2, x2); + dx1->data_type_ = dy->data_type_; + dx2->data_type_ = dy->data_type_; + return NNACL_OK; +} + +REG_INFER(DivGrad, PrimType_DivGrad, ArithmeticGradInferShape) +REG_INFER(MulGrad, PrimType_MulGrad, ArithmeticGradInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/arithmetic_grad_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/arithmetic_grad_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..bedf1dabb2396980fa9ca266f226218480331caa --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/arithmetic_grad_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_ARITHMETIC_GRAD_INFER_H_ +#define MINDSPORE_NNACL_INFER_ARITHMETIC_GRAD_INFER_H_ + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ArithmeticGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_ARITHMETIC_GRAD_INFER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/arithmetic_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/arithmetic_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..3e5ab70cb09799a95d0dc829c8f005ae22a7b82e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/arithmetic_infer.c @@ -0,0 +1,123 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/arithmetic_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/infer/broadcast_to_infer.h" + +void SetOutputDtypeFormat(const TensorC *input0, const TensorC *input1, TensorC *output) { + output->format_ = input0->format_; + output->data_type_ = input0->data_type_; + // e.g. input0's shape is 1 and input1's shape is 1 15 15 1 + // only regard larger shape size input as the right format input currently + // legacy problem: if input0 infer failed before, its shape is [-1], and input1's shape is [1,2] which need to + // be broadcasted. In this case our program will use input1's format, that's wrong and need to be solved later. + if (input0->data_ != NULL || input0->shape_size_ < input1->shape_size_) { + output->format_ = input1->format_; + } + // when input0 is const, it is quanted before insert quant trans op, so use input1 data type instead + if (((input0->data_ != NULL) && (input1->data_type_ != kTypeUnknown)) || + ((input0->data_type_ == kNumberTypeInt8) && (input1->data_type_ == kNumberTypeFloat32))) { + output->data_type_ = input1->data_type_; + } +} + +int BroadCastInferShape(const int input_shape0_size, const int input_shape1_size, const int *input_shape0, + const int *input_shape1, int *ndim, int *in_shape0, int *in_shape1, int *out_shape, + bool *has_broad_cast) { + if (input_shape0_size > MAX_SHAPE_SIZE || input_shape1_size > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + MakeUpInputShapes(input_shape0_size, input_shape1_size, input_shape0, input_shape1, ndim, in_shape0, in_shape1); + if (*ndim >= MAX_SHAPE_SIZE) { + return NNACL_INFER_INVALID; + } + + return BroadCastOutputShape(in_shape0, in_shape1, *ndim, out_shape, has_broad_cast); +} + +int ArithmeticInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + ArithmeticParameter *param = (ArithmeticParameter *)parameter; + NNACL_CHECK_NULL_RETURN_ERR(param); + param->broadcasting_ = false; + + const TensorC *input0 = inputs[0]; + const TensorC *input1 = inputs[1]; + TensorC *output = outputs[0]; + + const int *input_shape0 = input0->shape_; + size_t input_shape0_size = input0->shape_size_; + const int *input_shape1 = input1->shape_; + size_t input_shape1_size = input1->shape_size_; + SetOutputDtypeFormat(input0, input1, output); + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + int in_shape0[MAX_SHAPE_SIZE] = {0}; + int in_shape1[MAX_SHAPE_SIZE] = {0}; + int output_shape[MAX_SHAPE_SIZE] = {0}; + int ndim = (int)input_shape0_size; + bool has_broad_cast = false; + if (BroadCastInferShape(input_shape0_size, input_shape1_size, input_shape0, input_shape1, &ndim, in_shape0, in_shape1, + output_shape, &has_broad_cast) != NNACL_OK) { + return NNACL_ERR; + } + + SetShapeArray(output, output_shape, ndim); + + param->broadcasting_ = has_broad_cast; + param->ndim_ = (size_t)ndim; + if (ndim > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + memcpy(param->in_shape0_, in_shape0, ndim * sizeof(int)); + memcpy(param->in_shape1_, in_shape1, ndim * sizeof(int)); + memcpy(param->out_shape_, output_shape, ndim * sizeof(int)); + + param->in_elements_num0_ = 1; + param->in_elements_num1_ = 1; + param->out_elements_num_ = 1; + for (int i = 0; i < ndim; i++) { + param->in_elements_num0_ *= param->in_shape0_[i]; + param->in_elements_num1_ *= param->in_shape1_[i]; + param->out_elements_num_ *= param->out_shape_[i]; + } + return NNACL_OK; +} + +REG_INFER(Add, PrimType_AddFusion, ArithmeticInferShape) +REG_INFER(BiasAdd, PrimType_BiasAdd, ArithmeticInferShape) +REG_INFER(Div, PrimType_DivFusion, ArithmeticInferShape) +REG_INFER(Eltwise, PrimType_Eltwise, ArithmeticInferShape) +REG_INFER(FloorDiv, PrimType_FloorDiv, ArithmeticInferShape) +REG_INFER(FloorMod, PrimType_FloorMod, ArithmeticInferShape) +REG_INFER(LogicalAnd, PrimType_LogicalAnd, ArithmeticInferShape) +REG_INFER(LogicalOr, PrimType_LogicalOr, ArithmeticInferShape) +REG_INFER(Maximum, PrimType_Maximum, ArithmeticInferShape) +REG_INFER(Minimum, PrimType_Minimum, ArithmeticInferShape) +REG_INFER(Mod, PrimType_Mod, ArithmeticInferShape) +REG_INFER(Mul, PrimType_MulFusion, ArithmeticInferShape) +REG_INFER(RealDiv, PrimType_RealDiv, ArithmeticInferShape) +REG_INFER(Sub, PrimType_SubFusion, ArithmeticInferShape) +REG_INFER(SquaredDifference, PrimType_SquaredDifference, ArithmeticInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/arithmetic_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/arithmetic_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..6c3a8005ff31a5f69d259f9740e39e2fd17822e1 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/arithmetic_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_ARITHMETIC_INFER_H +#define MINDSPORE_NNACL_ARITHMETIC_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/arithmetic_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ArithmeticInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outpus_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ARITHMETIC_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/assert_op_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/assert_op_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..35380ba319da840cf86339cc1888f255d947e88a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/assert_op_infer.c @@ -0,0 +1,25 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/assert_op_infer.h" +#include "nnacl/infer/infer_register.h" + +int AssertOpInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + return NNACL_OK; +} + +REG_INFER(Assert, PrimType_Assert, AssertOpInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/assert_op_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/assert_op_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..5f0156e95d8ee9b6be6daab466f8ee8f28501582 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/assert_op_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_ASSERT_OP_INFER_H +#define MINDSPORE_NNACL_ASSERT_OP_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int AssertOpInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ASSERT_OP_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/assign_add_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/assign_add_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..3bde094cd6eb9f6e92f72c8fa35c4dfa144e2e61 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/assign_add_infer.c @@ -0,0 +1,38 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/assign_add_infer.h" +#include "nnacl/infer/infer_register.h" + +int AssignAddInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *x = inputs[0]; + const TensorC *y = inputs[1]; + TensorC *out = outputs[0]; + if (x->data_type_ != y->data_type_) { + return NNACL_ERR; + } + SetDataTypeFormat(out, x); + SetShapeTensor(out, x); + return NNACL_OK; +} + +REG_INFER(AssignAdd, PrimType_AssignAdd, AssignAddInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/assign_add_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/assign_add_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..1f7cea8abae89543759cf689ddb21e8d520a5f53 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/assign_add_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_ASSIGN_ADD_INFER_H +#define MINDSPORE_NNACL_ASSIGN_ADD_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int AssignAddInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ASSIGN_ADD_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/assign_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/assign_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..f7563c6923194001813febfd00437b51d9598dfd --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/assign_infer.c @@ -0,0 +1,41 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/assign_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensor_c_utils.h" + +int AssignInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, 2); + if (check_ret != NNACL_OK) { + return check_ret; + } + + if (NNACLGetElementNum(inputs[0]) != NNACLGetElementNum(inputs[1])) { + return NNACL_ERR; + } + + if (outputs_size != 0) { + TensorC *out = outputs[0]; + SetDataTypeFormat(out, inputs[0]); + out->shape_size_ = 1; + out->shape_[0] = 1; + } + return NNACL_OK; +} + +REG_INFER(Assign, PrimType_Assign, AssignInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/assign_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/assign_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..1de69e2185183301efaf7e71d9439dfdf8a727ab --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/assign_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_ASSIGN_INFER_H +#define MINDSPORE_NNACL_ASSIGN_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int AssignInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ASSIGN_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/attention_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/attention_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..56e30b3bbb6ecdb336fe3463580a2454c1a03b74 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/attention_infer.c @@ -0,0 +1,74 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/attention_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/attention_parameter.h" + +int AttentionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 7, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + AttentionParameter *param = (AttentionParameter *)parameter; + const TensorC *q_input = inputs[FIRST_INPUT]; + const TensorC *k_input = inputs[SECOND_INPUT]; + TensorC *output0 = outputs[FIRST_INPUT]; + SetDataTypeFormat(output0, q_input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + const TensorC *q_weight = inputs[FOURTH_INPUT]; + if (q_input->shape_size_ != C2NUM && q_input->shape_size_ != C3NUM) { + return NNACL_ERR; + } + if (q_weight->shape_size_ != C2NUM) { + return NNACL_ERR; + } + int batch = (q_input->shape_size_ == C2NUM) ? 1 : q_input->shape_[0]; + int f_seq = (q_input->shape_size_ == C2NUM) ? q_input->shape_[0] : q_input->shape_[1]; + int t_seq_len = k_input->shape_[1]; + if (q_input->shape_size_ == C2NUM) { + output0->shape_[FIRST_INPUT] = batch * f_seq; + output0->shape_[SECOND_INPUT] = param->head_num_ * param->head_size_; + output0->shape_size_ = C2NUM; + } else { + output0->shape_[FIRST_INPUT] = batch; + output0->shape_[SECOND_INPUT] = f_seq; + output0->shape_[THIRD_INPUT] = param->head_num_ * param->head_size_; + output0->shape_size_ = C3NUM; + } + if (outputs_size >= C3NUM) { + TensorC *output1 = outputs[SECOND_INPUT]; + SetDataTypeFormat(output1, q_input); + output1->shape_[FIRST_INPUT] = batch; + output1->shape_[SECOND_INPUT] = param->head_num_; + output1->shape_[THIRD_INPUT] = param->head_size_; + output1->shape_[FOURTH_INPUT] = t_seq_len; + output1->shape_size_ = C4NUM; + TensorC *output2 = outputs[THIRD_INPUT]; + SetDataTypeFormat(output2, q_input); + output2->shape_[FIRST_INPUT] = batch; + output2->shape_[SECOND_INPUT] = param->head_num_; + output2->shape_[THIRD_INPUT] = t_seq_len; + output2->shape_[FOURTH_INPUT] = param->head_size_; + output2->shape_size_ = C4NUM; + } + return NNACL_OK; +} + +REG_INFER(Attention, PrimType_Attention, AttentionInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/attention_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/attention_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..ebb602ade017dd77584a4aff60db97eb5a148fd4 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/attention_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_ATTENTION_INFER_H +#define MINDSPORE_NNACL_ATTENTION_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int AttentionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ATTENTION_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/audio_spectrogram_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/audio_spectrogram_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..959a4af64d6ad956d347f91e4676acb758ade3d7 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/audio_spectrogram_infer.c @@ -0,0 +1,75 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/audio_spectrogram_infer.h" +#include "nnacl/infer/infer_register.h" + +unsigned Log2Ceil(unsigned length) { + if (length == 0) { + return 0; + } + int floor = 0; + for (int i = 4; i >= 0; --i) { + const unsigned shift = (1 << (unsigned)i); + unsigned tmp = length >> shift; + if (tmp != 0) { + length = tmp; + floor += shift; + } + } + return length == (length & ~(length - 1)) ? floor : floor + 1; +} + +unsigned GetFftLength(unsigned length) { + unsigned shift = Log2Ceil(length); + return 1 << shift; +} + +int AudioSpectrogramInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ != 2) { + return NNACL_ERR; + } + AudioSpectrogramParameter *param = (AudioSpectrogramParameter *)parameter; + if (param->window_size_ < 2) { + return NNACL_ERR; + } + if (param->stride_ < 1) { + return NNACL_ERR; + } + int output_shape[3]; + output_shape[0] = input->shape_[1]; + int sample_sub_window = input->shape_[0] - param->window_size_; + output_shape[1] = sample_sub_window < 0 ? 0 : 1 + sample_sub_window / param->stride_; + // compute fft length + int fft_length = (int)GetFftLength(param->window_size_); + output_shape[2] = fft_length / 2 + 1; + SetShapeArray(output, output_shape, 3); + return NNACL_OK; +} + +REG_INFER(AudioSpectrogram, PrimType_AudioSpectrogram, AudioSpectrogramInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/audio_spectrogram_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/audio_spectrogram_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..af5daa018bb6054b3199f0c06f39a5e4c70c7f6a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/audio_spectrogram_infer.h @@ -0,0 +1,37 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_AUDIO_SPECTROGRAM_INFER_H +#define MINDSPORE_NNACL_AUDIO_SPECTROGRAM_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct AudioSpectrogramParameter { + OpParameter op_parameter_; + int window_size_; + int stride_; +} AudioSpectrogramParameter; + +int AudioSpectrogramInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_AUDIO_SPECTROGRAM_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/batch_to_space_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/batch_to_space_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..0f36e4137f0b5729a373f5d48dd185e77745323b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/batch_to_space_infer.c @@ -0,0 +1,144 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/batch_to_space_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensor_c_utils.h" + +int SetOutputShapeFromParam(const TensorC *const *inputs, TensorC **outputs, const OpParameter *parameter) { + int input_shape[MAX_SHAPE_SIZE] = {0}; + size_t input_shape_size = 0; + ShapeSet(input_shape, &input_shape_size, inputs[0]->shape_, inputs[0]->shape_size_); + + if (input_shape_size != 4) { + return NNACL_PARAM_INVALID; + } + + const BatchToSpaceParameter *param = (const BatchToSpaceParameter *)parameter; + const int32_t *block_shape = param->block_shape_; + const int32_t *crops = param->crops_; + int mul_block_shape = 1; + + for (size_t i = 0; i < 2; ++i) { + if (block_shape[i] <= 0) { + return NNACL_PARAM_INVALID; + } + if (input_shape[kNHWC_N] % block_shape[i]) { + return NNACL_ERR; + } + mul_block_shape *= block_shape[i]; + } + + if (input_shape[kNHWC_N] < mul_block_shape) { + return NNACL_PARAM_INVALID; + } + for (size_t i = 0; i < 4; ++i) { + if (crops[i] < 0) { + return NNACL_PARAM_INVALID; + } + } + if (mul_block_shape == 0) { + return NNACL_ERR; + } + int output_shape[MAX_SHAPE_SIZE]; + size_t output_shape_size = input_shape_size; + output_shape[kNHWC_N] = input_shape[kNHWC_N] / mul_block_shape; + output_shape[kNHWC_H] = input_shape[kNHWC_H] * block_shape[0] - crops[0] - crops[1]; + output_shape[kNHWC_W] = input_shape[kNHWC_W] * block_shape[1] - crops[2] - crops[3]; + output_shape[kNHWC_C] = input_shape[kNHWC_C]; + SetShapeArray(outputs[0], output_shape, output_shape_size); + return NNACL_OK; +} + +int SetOutputShapeFromInput(const TensorC *const *inputs, TensorC **outputs) { + int input_shape[MAX_SHAPE_SIZE] = {0}; + size_t input_shape_size = 0; + ShapeSet(input_shape, &input_shape_size, inputs[0]->shape_, inputs[0]->shape_size_); + if (input_shape_size != 4) { + return NNACL_PARAM_INVALID; + } + int *block_shape = (int *)(inputs[1]->data_); + int *crops = (int *)(inputs[2]->data_); + if (NNACLGetElementNum(inputs[1]) != 2) { + return NNACL_PARAM_INVALID; + } + if (NNACLGetElementNum(inputs[2]) != 4) { + return NNACL_PARAM_INVALID; + } + int mul_block_shape_ = 1; + + for (size_t i = 0; i < 2; ++i) { + if (block_shape[i] <= 0) { + return NNACL_PARAM_INVALID; + } + if (input_shape[kNHWC_N] % block_shape[i]) { + return 1; + } + mul_block_shape_ *= block_shape[i]; + } + + if (input_shape[kNHWC_N] < mul_block_shape_) { + return NNACL_PARAM_INVALID; + } + for (size_t i = 0; i < 4; ++i) { + if (crops[i] < 0) { + return NNACL_PARAM_INVALID; + } + } + if (mul_block_shape_ == 0) { + return NNACL_ERR; + } + int output_shape[MAX_SHAPE_SIZE]; + size_t output_shape_size = input_shape_size; + output_shape[kNHWC_N] = input_shape[kNHWC_N] / mul_block_shape_; + output_shape[kNHWC_H] = input_shape[kNHWC_H] * block_shape[0] - crops[0] - crops[1]; + output_shape[kNHWC_W] = input_shape[kNHWC_W] * block_shape[1] - crops[2] - crops[3]; + output_shape[kNHWC_C] = input_shape[kNHWC_C]; + SetShapeArray(outputs[0], output_shape, output_shape_size); + return NNACL_OK; +} + +int BatchToSpaceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (ret != NNACL_OK) { + return ret; + } + if (outputs_size != 1 || (inputs_size != 1 && inputs_size != 3)) { + return NNACL_PARAM_INVALID; + } + + const TensorC *input = inputs[0]; + if (input->format_ != Format_NHWC) { + return NNACL_FORMAT_ERROR; + } + SetDataTypeFormat(outputs[0], input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (inputs_size == 1) { + ret = SetOutputShapeFromParam(inputs, outputs, parameter); + return ret; + } + if (inputs[1]->data_ == NULL || inputs[2]->data_ == NULL) { + return NNACL_INFER_INVALID; + } + ret = SetOutputShapeFromInput(inputs, outputs); + return ret; +} + +REG_INFER(BatchToSpace, PrimType_BatchToSpace, BatchToSpaceInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/batch_to_space_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/batch_to_space_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..90d3cf0c929b31acd7b41dbe74491304ad9885f9 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/batch_to_space_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_BATCH_TO_SPACE_INFER_H +#define MINDSPORE_NNACL_BATCH_TO_SPACE_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/batch_to_space_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int BatchToSpaceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_BATCH_TO_SPACE_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/bias_grad_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/bias_grad_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..0f520c22cc0a1edd8bb638654d786d9dbd3f71de --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/bias_grad_infer.c @@ -0,0 +1,41 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/bias_grad_infer.h" +#include "nnacl/infer/infer_register.h" + +int BiasGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *in0 = inputs[0]; + TensorC *out = outputs[0]; + + if (in0->shape_size_ > MAX_SHAPE_SIZE || in0->shape_size_ < 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + int inshape[] = {in0->shape_[in0->shape_size_ - 1]}; + size_t inshape_size = 1; + SetDataTypeFormat(out, in0); + SetShapeArray(out, inshape, inshape_size); + + return NNACL_OK; +} + +REG_INFER(BiasAddGrad, PrimType_BiasAddGrad, BiasGradInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/bias_grad_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/bias_grad_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..3d87f516c4944a0ebe1bf5f05442cbd0e4f6b5b1 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/bias_grad_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_BIAS_GRAD_INFER_H +#define MINDSPORE_NNACL_BIAS_GRAD_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int BiasGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_BIAS_GRAD_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/binary_cross_entropy_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/binary_cross_entropy_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..22e207ac4a75f21793d0c92992488fe9e7f648eb --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/binary_cross_entropy_infer.c @@ -0,0 +1,40 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/binary_cross_entropy_infer.h" +#include "nnacl/infer/infer_register.h" + +int BinaryCrossEntropyInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (ret != NNACL_OK) { + return ret; + } + const TensorC *x = inputs[0]; + TensorC *out = outputs[0]; + SetDataTypeFormat(out, x); + BinaryCrossEntropyParameter *param = (BinaryCrossEntropyParameter *)parameter; + ReductionType reduction = (ReductionType)(param->reduction); + if (reduction == Reduction_Mean || reduction == Reduction_Sum) { + out->shape_size_ = 1; + out->shape_[0] = 1; + } else { + SetShapeTensor(out, x); + } + return NNACL_OK; +} + +REG_INFER(BinaryCrossEntropy, PrimType_BinaryCrossEntropy, BinaryCrossEntropyInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/binary_cross_entropy_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/binary_cross_entropy_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..6f4657c6432851dd8a9f51b2e80ba43772681577 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/binary_cross_entropy_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_BINARY_CROSS_ENTROPY_INFER_H +#define MINDSPORE_NNACL_BINARY_CROSS_ENTROPY_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/fp32_grad/binary_cross_entropy.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int BinaryCrossEntropyInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_BINARY_CROSS_ENTROPY_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/bn_grad_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/bn_grad_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..3a70c32284b6e20a4aa4fbe96cf67c5439bf9545 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/bn_grad_infer.c @@ -0,0 +1,42 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/bn_grad_infer.h" +#include "nnacl/infer/infer_register.h" + +int BnGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 6, 3); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *in = inputs[1]; + if ((inputs[0]->shape_size_ == 4 && inputs[0]->format_ != Format_NHWC) || + (in->shape_size_ == 4 && in->format_ != Format_NHWC)) { + return NNACL_FORMAT_ERROR; + } + const TensorC *scale = inputs[2]; + SetShapeTensor(outputs[0], in); + SetDataTypeFormat(outputs[0], in); + SetShapeTensor(outputs[1], scale); + SetDataTypeFormat(outputs[1], scale); + SetShapeTensor(outputs[2], scale); + SetDataTypeFormat(outputs[2], scale); + return NNACL_OK; +} + +REG_INFER(BatchNormGrad, PrimType_BatchNormGrad, BnGradInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/bn_grad_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/bn_grad_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..5968586cb41a36a5bb1baa8820f7b9bd837e569b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/bn_grad_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_BN_GRAD_INFER_H +#define MINDSPORE_NNACL_BN_GRAD_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int BnGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_BN_GRAD_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/broadcast_to_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/broadcast_to_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..896dcf3897c88a7339af10ae884fc345b07ab015 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/broadcast_to_infer.c @@ -0,0 +1,200 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/broadcast_to_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensor_c_utils.h" + +int GetShapeByType(const TensorC *shape_tensor, int shape_size, int *dst_shape) { + if (shape_tensor == NULL || dst_shape == NULL) { + return NNACL_ERR; + } + if (shape_size == 0) { + return NNACL_INFER_INVALID; + } + NNACL_CHECK_NULL_RETURN_ERR(shape_tensor->data_); + switch (shape_tensor->data_type_) { + case kNumberTypeInt8: { + int8_t *data = (int8_t *)(shape_tensor->data_); + for (int i = 0; i < shape_size; i++) { + dst_shape[i] = data[i]; + } + } break; + case kNumberTypeInt32: { + int32_t *data = (int32_t *)(shape_tensor->data_); + for (int i = 0; i < shape_size; i++) { + dst_shape[i] = data[i]; + } + } break; + case kNumberTypeInt64: { + int64_t *data = (int64_t *)(shape_tensor->data_); + for (int i = 0; i < shape_size; i++) { + dst_shape[i] = (int)data[i]; + } + } break; + case kNumberTypeFloat: { + float *data = (float *)(shape_tensor->data_); + for (int i = 0; i < shape_size; i++) { + dst_shape[i] = data[i]; + } + } break; + case kNumberTypeUInt32: { + uint32_t *data = (uint32_t *)(shape_tensor->data_); + for (int i = 0; i < shape_size; i++) { + dst_shape[i] = (int)data[i]; + } + } break; + default: { + return NNACL_ERR; + } + } + return NNACL_OK; +} + +void MakeUpInputShapes(const int input_shape0_size, const int input_shape1_size, const int *input_shape0, + const int *input_shape1, int *ndim, int *in_shape0, int *in_shape1) { + if (input_shape0_size < input_shape1_size) { + *ndim = input_shape1_size; + int fill_dim_num = input_shape1_size - input_shape0_size; + int j = 0; + for (int i = 0; i < input_shape1_size; i++) { + if (i < fill_dim_num) { + in_shape0[i] = 1; + } else { + in_shape0[i] = input_shape0[j++]; + } + in_shape1[i] = input_shape1[i]; + } + } else if (input_shape0_size > input_shape1_size) { + *ndim = input_shape0_size; + int fill_dim_num = input_shape0_size - input_shape1_size; + int j = 0; + for (int i = 0; i < input_shape0_size; i++) { + if (i < fill_dim_num) { + in_shape1[i] = 1; + } else { + in_shape1[i] = input_shape1[j++]; + } + in_shape0[i] = input_shape0[i]; + } + } else { + for (int i = 0; i < input_shape0_size; i++) { + in_shape1[i] = input_shape1[i]; + in_shape0[i] = input_shape0[i]; + } + } +} + +int BroadCastOutputShape(const int *in_shape0, const int *in_shape1, const int ndim, int *out_shape, + bool *has_broad_cast) { + for (int i = 0; i < ndim; i++) { + if (in_shape0[i] != in_shape1[i]) { + if (in_shape0[i] == 1) { + out_shape[i] = in_shape1[i]; + } else if (in_shape1[i] == 1) { + out_shape[i] = in_shape0[i]; + } else { + return NNACL_ERR; + } + *has_broad_cast = true; + } else { + out_shape[i] = in_shape0[i]; + } + } + return NNACL_OK; +} + +int BroadCastToShape(const int input_shape0_size, const int input_shape1_size, const int *input_shape0, + const int *input_shape1, int *ndim, int *out_shape, bool *has_broad_cast) { + if (input_shape0_size > MAX_SHAPE_SIZE || input_shape1_size > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + + int in_shape0[MAX_SHAPE_SIZE] = {0}; + int in_shape1[MAX_SHAPE_SIZE] = {0}; + + MakeUpInputShapes(input_shape0_size, input_shape1_size, input_shape0, input_shape1, ndim, in_shape0, in_shape1); + if (*ndim >= MAX_SHAPE_SIZE) { + return NNACL_INFER_INVALID; + } + + return BroadCastOutputShape(in_shape0, in_shape1, *ndim, out_shape, has_broad_cast); +} + +int BroadcastToInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (ret != NNACL_OK) { + return ret; + } + if (inputs_size != 1 && inputs_size != 2) { + return NNACL_ERR; + } + if (outputs_size != 1) { + return NNACL_ERR; + } + + const TensorC *input = inputs[0]; + SetDataTypeFormat(outputs[0], input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + int dst_shape[MAX_SHAPE_SIZE] = {0}; + int dst_shape_size; + const int *input_shape = input->shape_; + int input_shape_size = input->shape_size_; + int output_shape[MAX_SHAPE_SIZE] = {0}; + int ndim = input_shape_size; + bool has_broad_cast = false; + if (inputs_size == 1) { + BroadcastToParameter *param = (BroadcastToParameter *)parameter; + dst_shape_size = (int)param->shape_size_; + if (dst_shape_size > MAX_SHAPE_SIZE) { + return NNACL_PARAM_INVALID; + } + for (int i = 0; i < dst_shape_size; i++) { + dst_shape[i] = param->shape_[i]; + } + } else { + const TensorC *shape_tensor = inputs[1]; + if (shape_tensor->data_ == NULL) { + return NNACL_INFER_INVALID; + } + dst_shape_size = NNACLGetElementNum(shape_tensor); + if (dst_shape_size > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + ret = GetShapeByType(shape_tensor, dst_shape_size, dst_shape); + if (ret != NNACL_OK) { + return ret; + } + for (int i = 0; i < dst_shape_size; ++i) { + if (dst_shape[i] == -1) { + dst_shape[i] = inputs[0]->shape_[i]; + } + } + } + + if (BroadCastToShape(input_shape_size, dst_shape_size, input_shape, dst_shape, &ndim, output_shape, + &has_broad_cast) != NNACL_OK) { + return NNACL_ERR; + } + + SetShapeArray(outputs[0], output_shape, (size_t)ndim); + return NNACL_OK; +} + +REG_INFER(BroadcastTo, PrimType_BroadcastTo, BroadcastToInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/broadcast_to_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/broadcast_to_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..4fe5645aa49b5b212f082a6d8c0e381bd59b8b8e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/broadcast_to_infer.h @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_BROADCAST_TO_INFER_H +#define MINDSPORE_NNACL_BROADCAST_TO_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/base/broadcast_to.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int BroadcastToInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outpus_size, + OpParameter *parameter); +void MakeUpInputShapes(const int input_shape0_size, const int input_shape1_size, const int *input_shape0, + const int *input_shape1, int *ndim, int *in_shape0, int *in_shape1); +int BroadCastOutputShape(const int *in_shape0, const int *in_shape1, const int ndim, int *out_shape, + bool *has_broad_cast); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_BROADCAST_TO_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/cast_gather_reduce_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/cast_gather_reduce_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..2ad0f1d5703ca996abd7730c1c1470e7f4460f53 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/cast_gather_reduce_infer.c @@ -0,0 +1,77 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/cast_gather_reduce_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/split_parameter.h" + +int CastGatherReduceFusionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (ret != NNACL_OK) { + return ret; + } + const size_t kMinimumGradInputsNum = 3; + if (inputs_size < kMinimumGradInputsNum || outputs_size != 1) { + return NNACL_ERR; + } + const TensorC *input = inputs[0]; + const TensorC *indices = inputs[1]; + TensorC *output = outputs[0]; + output->data_type_ = input->data_type_; + output->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ > MAX_SHAPE_SIZE || indices->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + if (inputs[C2NUM]->data_ == NULL) { + return NNACL_NULL_PTR; + } + int axis = *((int *)inputs[C2NUM]->data_); + if (axis < 0) { + axis += input->shape_size_; + } + int indices_shape[MAX_SHAPE_SIZE]; + size_t indices_shape_size = 0; + ShapeSet(indices_shape, &indices_shape_size, indices->shape_, indices->shape_size_); + size_t indices_rank = indices_shape_size; + int in_shape[MAX_SHAPE_SIZE] = {0}; + size_t in_shape_size = 0; + ShapeSet(in_shape, &in_shape_size, input->shape_, input->shape_size_); + if ((int)(in_shape_size) < axis + 1) { + return NNACL_ERR; + } + int out_shape[MAX_SHAPE_SIZE] = {0}; + size_t out_shape_size = 0; + ShapeSet(out_shape, &out_shape_size, in_shape, in_shape_size); + int erase_ret = ShapeErase(out_shape, &out_shape_size, axis); + if (erase_ret != NNACL_OK) { + return NNACL_ERR; + } + for (int i = (int)(indices_rank - 1); i >= 0; --i) { + ret = ShapeInsert(out_shape, &out_shape_size, axis, indices_shape[i]); + if (ret != NNACL_OK) { + return NNACL_ERR; + } + } + out_shape[1] = 1; + SetShapeArray(output, out_shape, out_shape_size); + return NNACL_OK; +} + +REG_INFER(CastGatherReduceFusion, PrimType_Inner_CastGatherReduceFusion, CastGatherReduceFusionInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/cast_gather_reduce_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/cast_gather_reduce_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..009eaf3c54909ff495be2bb634d0a5e18f7aa200 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/cast_gather_reduce_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_CAST_GATHER_REDUCE_INFER_H +#define MINDSPORE_NNACL_CAST_GATHER_REDUCE_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int CastGatherReduceFusionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SPLIT_REDUCE_CONCAT_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/cast_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/cast_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..d887eabfd5b61e8faf2b18bcdcdcda77e0f3dc5b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/cast_infer.c @@ -0,0 +1,52 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/cast_infer.h" +#include "nnacl/infer/infer_register.h" + +int CastInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullOutputSize(inputs, inputs_size, outputs, outputs_size, parameter, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + if (inputs_size != 2) { + return NNACL_INPUT_TENSOR_ERROR; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + output->format_ = input->format_; + const TensorC *dst_type = inputs[1]; + if (dst_type->data_ == NULL) { + return NNACL_NULL_PTR; + } + output->data_type_ = *((int *)dst_type->data_); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->data_type_ != kNumberTypeBool && input->data_type_ != kNumberTypeUInt8 && + input->data_type_ != kNumberTypeInt8 && input->data_type_ != kNumberTypeInt32 && + input->data_type_ != kNumberTypeInt64 && input->data_type_ != kNumberTypeFloat32 && + input->data_type_ != kNumberTypeFloat16) { + return NNACL_INPUT_TENSOR_ERROR; + } + + SetShapeTensor(output, input); + return NNACL_OK; +} + +REG_INFER(Cast, PrimType_Cast, CastInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/cast_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/cast_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..a34e541accf696b7efa1fd47094259c495f3778b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/cast_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_CAST_INFER_H +#define MINDSPORE_NNACL_CAST_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int CastInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CAST_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/common_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/common_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..23e635334f0e821908a6d53e46a2b6a9d0bf240c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/common_infer.c @@ -0,0 +1,338 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use tensor file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/common_infer.h" +#include +#include +#include "nnacl/infer/infer_register.h" +#include "nnacl/op_base.h" +#include "nnacl/tensor_c_utils.h" +#include "nnacl/tensorlist_c_utils.h" + +bool CheckShaleValid(TensorC **tensors, int tensors_size) { + for (int i = 0; i < tensors_size; i++) { + TensorC *t = tensors[i]; + for (size_t j = 0; j < t->shape_size_; j++) { + if (t->shape_[j] == -1) { + return false; + } + } + } + return true; +} + +bool CheckInferShapeDone(TensorC **in, int in_size, TensorC **out, int out_size) { + for (int i = 0; i < in_size; i++) { + TensorC *t = in[i]; + for (size_t j = 0; j < t->shape_size_; j++) { + if (t->shape_[j] == -1) { + return false; + } + } + } + for (int i = 0; i < out_size; i++) { + TensorC *t = out[i]; + for (size_t j = 0; j < t->shape_size_; j++) { + if (t->shape_[j] == -1) { + return false; + } + } + } + return true; +} + +void ShapeSet(int *dst_shape, size_t *dst_shape_size, const int *src_shape, size_t src_shape_size) { + size_t i = 0; + for (; i < src_shape_size && i < MAX_SHAPE_SIZE; i++) { + dst_shape[i] = src_shape[i]; + } + *dst_shape_size = i; +} + +bool Int64ShapeSet(int *dst_shape, size_t *dst_shape_size, const int64_t *src_shape, size_t src_shape_size) { + if (dst_shape_size == NULL || dst_shape == NULL || src_shape == NULL) { + return false; + } + size_t i = 0; + for (; i < src_shape_size && i < MAX_SHAPE_SIZE; i++) { + if (MS_UNLIKELY(src_shape[i] > (int64_t)INT32_MAX || src_shape[i] < (int64_t)INT32_MIN)) { + return false; + } + dst_shape[i] = (int32_t)(src_shape[i]); + } + *dst_shape_size = i; + return true; +} + +void ShapePush(int *shape, size_t *shape_size, int value) { + if (*shape_size >= MAX_SHAPE_SIZE) { + return; + } + shape[*shape_size] = value; + *shape_size = *shape_size + 1; +} + +int GetInt32DataFromTensor(const TensorC *tensor, int *result, size_t *result_size) { + if (tensor->data_ == NULL || result == NULL || result_size == NULL) { + return NNACL_ERR; + } + if (tensor->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + int ele_num = NNACLGetElementNum(tensor); + if (ele_num <= 0) { + return NNACL_ERR; + } + *result_size = (size_t)ele_num; + if (tensor->data_type_ == kNumberTypeInt || tensor->data_type_ == kNumberTypeInt32) { + int *data = (int *)(tensor->data_); + for (int i = 0; i < ele_num; i++) { + result[i] = data[i]; + } + } else if (tensor->data_type_ == kNumberTypeInt64) { + int64_t *data = (int64_t *)(tensor->data_); + for (int i = 0; i < ele_num; i++) { + if (data[i] >= INT32_MAX) { + return NNACL_ERR; + } + result[i] = (int32_t)data[i]; + } + } else { + return NNACL_UNSUPPORTED_DATA_TYPE; + } + return NNACL_OK; +} + +int ShapeInsert(int *shape, size_t *shape_size, int index, int value) { + if (index < 0 || index > *shape_size) { + return NNACL_ERR; + } + if (*shape_size >= MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + for (int i = *shape_size; i > index; i--) { + shape[i] = shape[i - 1]; + } + shape[index] = value; + *shape_size = *shape_size + 1; + return NNACL_OK; +} + +int ShapeErase(int *shape, size_t *shape_size, int index) { + if (index < 0 || index >= *shape_size) { + return NNACL_ERR; + } + + for (int i = index; i < *shape_size - 1; i++) { + shape[i] = shape[i + 1]; + } + *shape_size = *shape_size - 1; + return NNACL_OK; +} + +bool ShapeEqual(const int *shape0, size_t shape0_size, const int *shape1, size_t shape1_size) { + if (shape0_size != shape1_size) { + return false; + } + for (size_t i = 0; i < shape0_size; i++) { + if (shape0[i] != shape1[i]) { + return false; + } + } + return true; +} + +void iswap(int *a, int *b) { + int tmp = *a; + *a = *b; + *b = tmp; +} + +int imin(int a, int b) { return a > b ? b : a; } + +int imax(int a, int b) { return a < b ? b : a; } + +// input == output completely refer to +// 1. zeros_like +int CommonInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + if (parameter == NULL || inputs[0] == NULL || outputs[0] == NULL) { + return NNACL_NULL_PTR; + } + SetDataTypeFormat(outputs[0], inputs[0]); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(outputs[0], inputs[0]); + return NNACL_OK; +} + +int CommonGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int ret = CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, 2); + if (ret != NNACL_OK) { + return ret; + } + SetDataTypeFormat(outputs[0], inputs[0]); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + NNACL_CHECK_TRUE_RET(inputs[0]->shape_size_ == inputs[1]->shape_size_, NNACL_ERR); + for (int i = 0; i < inputs[0]->shape_size_; i++) { + if (inputs[0]->shape_[i] != inputs[1]->shape_[i]) { + return NNACL_ERR; + } + } + SetShapeTensor(outputs[0], inputs[0]); + return NNACL_OK; +} + +int CommonInferShapeWithOneInput(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int ret = CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, 1); + if (ret != NNACL_OK) { + return ret; + } + SetDataTypeFormat(outputs[0], inputs[0]); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(outputs[0], inputs[0]); + return NNACL_OK; +} + +int CommonInferShapeWithTwoInput(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int ret = CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, 2); + if (ret != NNACL_OK) { + return ret; + } + SetDataTypeFormat(outputs[0], inputs[0]); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(outputs[0], inputs[0]); + return NNACL_OK; +} + +int CommonInferShapeWithNHWC(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + if (parameter == NULL || inputs[0] == NULL || outputs[0] == NULL) { + return NNACL_NULL_PTR; + } + if (inputs[0]->format_ != Format_NHWC) { + return NNACL_FORMAT_ERROR; + } + SetDataTypeFormat(outputs[0], inputs[0]); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(outputs[0], inputs[0]); + return NNACL_OK; +} + +int FftInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + const OpParameter *parameter) { + int ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (ret != NNACL_OK) { + return ret; + } + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + output->data_type_ = kNumberTypeFloat32; + output->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int input_shape[MAX_SHAPE_SIZE] = {0}; + size_t input_shape_size = 0; + ShapeSet(input_shape, &input_shape_size, input->shape_, input->shape_size_); + if (input_shape_size == 0) { + return NNACL_ERR; + } + input_shape_size--; + SetShapeArray(output, input_shape, input_shape_size); + return NNACL_OK; +} + +bool InferFlag(const TensorC *const *inputs, size_t inputs_size) { + if (inputs == NULL) { + return false; + } + for (size_t i = 0; i < inputs_size; i++) { + if (inputs[i] == NULL) { + return false; + } + if (inputs[i]->data_type_ == kObjectTypeTensorType) { + if (InferFlagTensorList((TensorC *)inputs[i]) == false) { + return false; + } + } else { + for (size_t j = 0; j < inputs[i]->shape_size_; ++j) { + if (inputs[i]->shape_[j] < 0) { + return false; + } + } + } + } + return true; +} + +REG_INFER(Abs, PrimType_Abs, CommonInferShape) +REG_INFER(AbsGrad, PrimType_AbsGrad, CommonGradInferShape) +REG_INFER(Activation, PrimType_Activation, CommonInferShape) +REG_INFER(BatchNorm, PrimType_BatchNorm, CommonInferShape) +REG_INFER(BinaryCrossEntropyGrad, PrimType_BinaryCrossEntropyGrad, CommonInferShape) +REG_INFER(Ceil, PrimType_Ceil, CommonInferShape) +REG_INFER(Clip, PrimType_Clip, CommonInferShape) +REG_INFER(Cos, PrimType_Cos, CommonInferShape) +REG_INFER(Depend, PrimType_Depend, CommonInferShape) +REG_INFER(Elu, PrimType_Elu, CommonInferShape) +REG_INFER(Erf, PrimType_Erf, CommonInferShape) +REG_INFER(Exp, PrimType_ExpFusion, CommonInferShape) +REG_INFER(FakeQuantWithMinMaxVars, PrimType_FakeQuantWithMinMaxVars, CommonInferShape) +REG_INFER(Floor, PrimType_Floor, CommonInferShapeWithOneInput) +REG_INFER(LeakyRelu, PrimType_LeakyRelu, CommonInferShape) +REG_INFER(Log, PrimType_Log, CommonInferShape) +REG_INFER(Log1p, PrimType_Log1p, CommonInferShape) +REG_INFER(LogGrad, PrimType_LogGrad, CommonGradInferShape) +REG_INFER(LogicalNot, PrimType_LogicalNot, CommonInferShape) +REG_INFER(LRN, PrimType_LRN, CommonInferShapeWithNHWC) +REG_INFER(L2Normalize, PrimType_L2NormalizeFusion, CommonInferShape) +REG_INFER(Neg, PrimType_Neg, CommonInferShape) +REG_INFER(NegGrad, PrimType_NegGrad, CommonGradInferShape) +REG_INFER(OnesLike, PrimType_OnesLike, CommonInferShape) +REG_INFER(PowerGrad, PrimType_PowerGrad, CommonGradInferShape) +REG_INFER(PReLU, PrimType_PReLUFusion, CommonInferShape) +REG_INFER(Reciprocal, PrimType_Reciprocal, CommonInferShape) +REG_INFER(ReverseSequence, PrimType_ReverseSequence, CommonInferShape) +REG_INFER(Reverse, PrimType_ReverseV2, CommonInferShape) +REG_INFER(Round, PrimType_Round, CommonInferShape) +REG_INFER(Rsqrt, PrimType_Rsqrt, CommonInferShape) +REG_INFER(Scale, PrimType_ScaleFusion, CommonInferShape) +REG_INFER(SigmoidCrossEntropyWithLogits, PrimType_SigmoidCrossEntropyWithLogits, CommonInferShape) +REG_INFER(SigmoidCrossEntropyWithLogitsGrad, PrimType_SigmoidCrossEntropyWithLogitsGrad, CommonInferShape) +REG_INFER(Sin, PrimType_Sin, CommonInferShape) +REG_INFER(SmoothL1Loss, PrimType_SmoothL1Loss, CommonInferShape) +REG_INFER(SmoothL1LossGrad, PrimType_SmoothL1LossGrad, CommonInferShape) +REG_INFER(Sqrt, PrimType_Sqrt, CommonInferShape) +REG_INFER(SqrtGrad, PrimType_SqrtGrad, CommonInferShape) +REG_INFER(Square, PrimType_Square, CommonInferShape) +REG_INFER(ZerosLike, PrimType_ZerosLike, CommonInferShape) +REG_INFER(ScatterElements, PrimType_ScatterElements, CommonInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/common_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/common_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..ebae06e6df5d662488b341c1ff71562e2de579fb --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/common_infer.h @@ -0,0 +1,94 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_COMMON_H_ +#define MINDSPORE_NNACL_COMMON_H_ + +#include +#include "nnacl/errorcode.h" +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" + +bool CheckShaleValid(TensorC **tensors, int tensors_size); +bool CheckInferShapeDone(TensorC **in, int in_size, TensorC **out, int out_size); + +#define EPSILON_VALUE 1e-6 + +enum NNACLLshProjectionType { + LshProjectionType_UNKNOWN = 0, + LshProjectionType_SPARSE = 1, + LshProjectionType_DENSE = 2, + LshProjectionType_MIN = LshProjectionType_UNKNOWN, + LshProjectionType_MAX = LshProjectionType_DENSE +}; + +typedef struct VectorC { + int *data_; + size_t size_; + size_t max_size_; + size_t per_malloc_size_; +} VectorC; + +#ifdef __cplusplus +extern "C" { +#endif + +int CheckAugmentNull(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + const OpParameter *parameter); +int CheckAugmentNullSize(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + const OpParameter *parameter, size_t inputs_size_obj, size_t outputs_size_obj); +int CheckAugmentNullSizeInputTwo(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, const OpParameter *parameter, size_t inputs_size_obj_0, + size_t inputs_size_obj_1, size_t outputs_size_obj); +int CheckAugmentNullInputSize(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + const OpParameter *parameter, size_t inputs_size_obj); +int CheckAugmentNullOutputSize(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + const OpParameter *parameter, size_t outputs_size_obj); +int CheckAugmentWithMinSize(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + const OpParameter *parameter, size_t inputs_size_obj, size_t outputs_size_obj); +void SetDataTypeFormat(TensorC *dst, const TensorC *src); + +void SetShapeTensor(TensorC *dst, const TensorC *src); +void SetShapeArray(TensorC *dst, const int *src, size_t src_size); +void ShapeSet(int *dst_shape, size_t *dst_shape_size, const int *src_shape, size_t src_shape_size); +bool Int64ShapeSet(int *dst_shape, size_t *dst_shape_size, const int64_t *src_shape, size_t src_shape_size); +void ShapePush(int *shape, size_t *shape_size, int value); +int GetInt32DataFromTensor(const TensorC *tensor, int *result, size_t *result_size); +int ShapeInsert(int *shape, size_t *shape_size, int index, int value); +int ShapeErase(int *shape, size_t *shape_size, int index); +bool ShapeEqual(const int *shape0, size_t shape0_size, const int *shape1, size_t shape1_size); + +void iswap(int *a, int *b); + +int imin(int a, int b); +int imax(int a, int b); + +int CommonInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); +int CommonGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); +int CommonInferShapeWithOneInput(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); +int CommonInferShapeWithNHWC(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); +int FftInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + const OpParameter *parameter); +bool InferFlag(const TensorC *const *inputs, size_t inputs_size); +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_NNACL_COMMON__H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/concat_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/concat_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..6b828632b3c09bca5d7580f15400e95a738a84b1 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/concat_infer.c @@ -0,0 +1,97 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/concat_infer.h" +#include "nnacl/infer/infer_register.h" + +int DataTypeJudge(const TensorC *input, const TensorC *output) { + if ((input->data_type_ != output->data_type_) && + !((input->data_type_ == kNumberTypeFloat16 && output->data_type_ == kNumberTypeFloat32) || + (input->data_type_ == kNumberTypeFloat32 && output->data_type_ == kNumberTypeFloat16))) { + return NNACL_PARAM_INVALID; + } + return NNACL_OK; +} + +int ConcatInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullOutputSize(inputs, inputs_size, outputs, outputs_size, parameter, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + if (inputs_size < 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + + const TensorC *input0 = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input0); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + const int *input0_shape = inputs[0]->shape_; + size_t input0_shape_size = inputs[0]->shape_size_; + + ConcatParameter *param = (ConcatParameter *)parameter; + int axis = param->axis_ < 0 ? param->axis_ + (int)input0_shape_size : param->axis_; + if (axis < 0 || axis >= (int)input0_shape_size) { + return NNACL_ERR; + } + if (input0_shape_size > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int input0_shape_without_axis[MAX_SHAPE_SIZE] = {0}; + size_t input0_shape_without_axis_size = 0; + ShapeSet(input0_shape_without_axis, &input0_shape_without_axis_size, input0_shape, input0_shape_size); + int erase_ret = ShapeErase(input0_shape_without_axis, &input0_shape_without_axis_size, axis); + if (erase_ret != NNACL_OK) { + return NNACL_ERR; + } + int output_axis_dim = input0_shape[axis]; + for (size_t i = 1; i < inputs_size; ++i) { + size_t input_i_shape_size = inputs[i]->shape_size_; + if (input_i_shape_size != input0_shape_size) { + return NNACL_PARAM_INVALID; + } + int shape_tmp[MAX_SHAPE_SIZE] = {0}; + size_t shape_tmp_size = 0; + ShapeSet(shape_tmp, &shape_tmp_size, inputs[i]->shape_, inputs[i]->shape_size_); + int data_type_judge = DataTypeJudge(inputs[i], output); + if (data_type_judge != NNACL_OK) { + return data_type_judge; + } + int axis_tmp = shape_tmp[axis]; + erase_ret = ShapeErase(shape_tmp, &shape_tmp_size, axis); + if (erase_ret != NNACL_OK) { + return NNACL_ERR; + } + if (!ShapeEqual(input0_shape_without_axis, input0_shape_without_axis_size, shape_tmp, shape_tmp_size)) { + return NNACL_ERR; + } + output_axis_dim += axis_tmp; + } + int output_shape[MAX_SHAPE_SIZE]; + size_t output_shape_size = input0_shape_size; + for (size_t i = 0; i < input0_shape_size; i++) { + output_shape[i] = input0_shape[i]; + } + output_shape[axis] = output_axis_dim; + SetShapeArray(outputs[0], output_shape, output_shape_size); + return NNACL_OK; +} + +REG_INFER(Concat, PrimType_Concat, ConcatInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/concat_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/concat_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..315d60c8f7ee0171d78f850a7f3ce53e56e8a98d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/concat_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_CONCAT_INFER_H +#define MINDSPORE_NNACL_CONCAT_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/concat_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ConcatInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CONCAT_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/constant_of_shape_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/constant_of_shape_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..56c767961c210089ecac67130b967b73892602e0 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/constant_of_shape_infer.c @@ -0,0 +1,71 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/constant_of_shape_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensor_c_utils.h" + +int ConstantOfShapeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *in_tensor = inputs[0]; + TensorC *out_tensor = outputs[0]; + ConstantOfShapeParameter *param = (ConstantOfShapeParameter *)parameter; + out_tensor->data_type_ = (TypeIdC)(param->data_type_); + out_tensor->format_ = in_tensor->format_; + if (!InferFlag(inputs, inputs_size) || in_tensor->data_ == NULL) { + return NNACL_INFER_INVALID; + } + int size = NNACLGetElementNum(in_tensor); + if (size < 0 || size > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + int out_shape[MAX_SHAPE_SIZE]; + int out_shape_size = size; + switch (in_tensor->data_type_) { + case kNumberTypeInt32: { + int32_t *in_data = (int32_t *)(in_tensor->data_); + for (int i = 0; i < size; ++i) { + out_shape[i] = in_data[i]; + if (out_shape[i] < 0) { + return NNACL_ERR; + } + } + break; + } + case kNumberTypeInt64: { + int64_t *in_data = (int64_t *)(in_tensor->data_); + for (int i = 0; i < size; ++i) { + out_shape[i] = in_data[i]; + if (out_shape[i] < 0) { + return NNACL_ERR; + } + } + break; + } + default: + return NNACL_INFER_INVALID; + } + + SetShapeArray(out_tensor, out_shape, (size_t)out_shape_size); + return NNACL_OK; +} + +REG_INFER(ConstantOfShape, PrimType_ConstantOfShape, ConstantOfShapeInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/constant_of_shape_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/constant_of_shape_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..baceb1201a4319b61aec78eb487026e9aab3222a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/constant_of_shape_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_CONSTANT_OF_SHAPE_INFER_H +#define MINDSPORE_NNACL_CONSTANT_OF_SHAPE_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/constant_of_shape_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ConstantOfShapeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CONSTANT_OF_SHAPE_INFER_H diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/equal.cu b/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensor_array_infer.c similarity index 40% rename from mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/equal.cu rename to mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensor_array_infer.c index d65c2129a2086156887f5e406817937b4e5be531..457e689e1b5a9da26862df1bd5a5159ce2bf60d3 100644 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/equal.cu +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensor_array_infer.c @@ -14,22 +14,34 @@ * limitations under the License. */ -#include "src/extendrt/delegate/tensorrt/cuda_impl/equal.cuh" -#include -#include "src/extendrt/delegate/tensorrt/cuda_impl/cuda_helper.h" - -template -__global__ void EqualKernel(const T *input1, const T *input2, T *output, int element_cnt) { - for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < element_cnt; pos += blockDim.x * gridDim.x) { - output[pos] = (input1[pos] - input2[pos] < 1e-6 && input1[pos] - input2[pos] > -1e-6); +#include "nnacl/infer/control/tensor_array_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensor_array_parameter.h" + +int TensorArrayInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { +#ifdef Debug + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; } -} +#endif + + TensorC *output = outputs[0]; + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + TensorArrayParameter *param = (TensorArrayParameter *)parameter; + if (param == NULL) { + return NNACL_NULL_PTR; + } + + output->data_type_ = param->data_type_; + SetShapeArray(output, param->element_shape_, (size_t)param->element_shape_size_); -template -void Equal(const T *input1, const T *input2, T *output, int element_cnt, cudaStream_t stream) { - EqualKernel<<>>(input1, input2, output, element_cnt); - return; + return NNACL_OK; } -template void Equal(const float *input1, const float *input2, float *output, int element_cnt, cudaStream_t stream); -template void Equal(const int *input1, const int *input2, int *output, int element_cnt, cudaStream_t stream); +REG_INFER(TensorArray, PrimType_TensorArray, TensorArrayInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensor_array_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensor_array_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..f65ea4b5b6b77471bedbbc12870ebebb8e6fd751 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensor_array_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_CONTROL_TENSOR_ARRAY_INFER_H_ +#define MINDSPORE_NNACL_INFER_CONTROL_TENSOR_ARRAY_INFER_H_ + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int TensorArrayInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_CONTROL_TENSOR_ARRAY_INFER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensor_array_read_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensor_array_read_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..8b19d76c52adaa0b2dd0eaa6d95415f68acec599 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensor_array_read_infer.c @@ -0,0 +1,43 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/control/tensor_array_read_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensor_array_parameter.h" + +int TensorArrayReadInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + // { prim, handle, index } -> node + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + NNACL_CHECK_TRUE_RET(inputs_size >= 1, NNACL_ERR); + NNACL_CHECK_TRUE_RET(outputs_size >= 1, NNACL_ERR); + TensorC *handle = (TensorC *)inputs[0]; + TensorC *output = outputs[0]; + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + output->data_type_ = handle->data_type_; + SetShapeArray(output, handle->shape_, handle->shape_size_); + + return NNACL_OK; +} + +REG_INFER(TensorArrayRead, PrimType_TensorArrayRead, TensorArrayReadInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensor_array_read_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensor_array_read_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..9c080855e2c574f124818353e62567fe71793086 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensor_array_read_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_CONTROL_TENSOR_ARRAY_READ_INFER_H_ +#define MINDSPORE_NNACL_INFER_CONTROL_TENSOR_ARRAY_READ_INFER_H_ + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int TensorArrayReadInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_CONTROL_TENSOR_ARRAY_READ_INFER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensor_array_write_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensor_array_write_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..4fff62a8c8cb0295f5e0e84528226f7cab49e4cc --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensor_array_write_infer.c @@ -0,0 +1,54 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/control/tensor_array_write_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensor_array_parameter.h" + +int TensorArrayWriteInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + // { handle, index, value, flow_in } -> empty + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 4, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + NNACL_CHECK_TRUE_RET(inputs_size >= 3, NNACL_ERR); + TensorC *handle = (TensorC *)inputs[0]; + TensorC *value = (TensorC *)inputs[2]; + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + TensorArrayParameter *param = (TensorArrayParameter *)parameter; + if (param == NULL) { + return NNACL_NULL_PTR; + } + + if (handle->shape_size_ != value->shape_size_) { + return NNACL_INFER_INVALID; + } + + for (size_t i = 0; i < handle->shape_size_; ++i) { + if (handle->shape_[i] != value->shape_[i]) { + return NNACL_INFER_INVALID; + } + } + + return NNACL_OK; +} + +REG_INFER(TensorArrayWrite, PrimType_TensorArrayWrite, TensorArrayWriteInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensor_array_write_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensor_array_write_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..2321729429b026c401d2d25f4ceef3d1fce88c35 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensor_array_write_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_CONTROL_TENSOR_ARRAY_WRITE_INFER_H_ +#define MINDSPORE_NNACL_INFER_CONTROL_TENSOR_ARRAY_WRITE_INFER_H_ + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int TensorArrayWriteInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_CONTROL_TENSOR_ARRAY_WRITE_INFER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensorlist_fromtensor_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensorlist_fromtensor_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..b53610fcac6c62b0cd6c82c512939339bcc3f992 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensorlist_fromtensor_infer.c @@ -0,0 +1,81 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/control/tensorlist_fromtensor_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensorlist_c_utils.h" +#include "nnacl/tensor_c_utils.h" + +int TensorListFromTensorInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + TensorListC *output = (TensorListC *)(outputs[0]); + const TensorC *input0 = inputs[0]; + output->data_type_ = kObjectTypeTensorType; + output->format_ = Format_NHWC; + output->tensors_data_type_ = input0->data_type_; + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (input0->shape_size_ < 1) { + return NNACL_ERR; + } + int dim0 = input0->shape_[0]; + if (dim0 < 0) { + return NNACL_ERR; + } + const TensorC *input1 = inputs[1]; + if (input1->data_ == NULL) { + return NNACL_NULL_PTR; + } + int *ele_shape_ptr = (int *)(input1->data_); + NNACL_CHECK_NULL_RETURN_ERR(ele_shape_ptr); + vvector tensor_shape; + tensor_shape.size_ = (size_t)(dim0); + tensor_shape.shape_ = (int **)malloc(tensor_shape.size_ * sizeof(int *)); + if (tensor_shape.shape_ == NULL) { + return NNACL_NULL_PTR; + } + tensor_shape.shape_size_ = (int *)malloc(tensor_shape.size_ * sizeof(int)); + if (tensor_shape.shape_size_ == NULL) { + free(tensor_shape.shape_); + return NNACL_NULL_PTR; + } + for (int i = 0; i < dim0; i++) { + tensor_shape.shape_[i] = (int *)(input0->shape_ + 1); + tensor_shape.shape_size_[i] = (int)(input0->shape_size_) - 1; + } + + ShapeSet(output->element_shape_, &(output->element_shape_size_), ele_shape_ptr, (size_t)NNACLGetElementNum(input1)); + output->element_num_ = (size_t)(dim0); + int ret = MallocTensorListData(output, input0->data_type_, &tensor_shape); + if (ret != NNACL_OK) { + free(tensor_shape.shape_); + free(tensor_shape.shape_size_); + return NNACL_ERR; + } + free(tensor_shape.shape_); + free(tensor_shape.shape_size_); + return NNACL_OK; +} + +REG_INFER(TensorListFromTensor, PrimType_TensorListFromTensor, TensorListFromTensorInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensorlist_fromtensor_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensorlist_fromtensor_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..3c18266d057b390ac1e7d2c8f34504960bac32bb --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensorlist_fromtensor_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_FROMTENSOR_INFER_H +#define MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_FROMTENSOR_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int TensorListFromTensorInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_FROMTENSOR_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensorlist_getitem_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensorlist_getitem_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..ab61d279119cc4de7d35dbb2ff1fa255301ad551 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensorlist_getitem_infer.c @@ -0,0 +1,102 @@ +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/control/tensorlist_getitem_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensorlist_c_utils.h" +#include "nnacl/tensor_c_utils.h" + +int TensorListGetItemInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + if (inputs[0]->data_type_ != kObjectTypeTensorType) { + return NNACL_ERR; + } + TensorListC *input0 = (TensorListC *)(inputs[0]); + const TensorC *get_index = inputs[1]; + if (get_index->data_ == NULL) { + return NNACL_INFER_INVALID; + } + if (NNACLGetElementNum(get_index) != 1) { + return NNACL_ERR; + } + TensorC *output = outputs[0]; + if (!InferFlag(inputs, inputs_size) || input0->element_num_ == 0) { + return NNACL_INFER_INVALID; + } + int index = ((int *)(get_index->data_))[0]; + if (index < 0 || index > ((int)(input0->element_num_ - 1))) { + return NNACL_ERR; + } + TensorC *tensor_index = input0->tensors_[index]; + NNACL_CHECK_NULL_RETURN_ERR(tensor_index); + + if (tensor_index->data_type_ != kTypeUnknown) { + output->data_type_ = tensor_index->data_type_; + } else { + output->data_type_ = input0->tensors_data_type_; + } + output->format_ = input0->tensors_[index]->format_; + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (tensor_index->data_type_ != kTypeUnknown) { + ShapeSet(output->shape_, &(output->shape_size_), tensor_index->shape_, tensor_index->shape_size_); + } else { + const TensorC *input2 = inputs[2]; + NNACL_CHECK_NULL_RETURN_ERR(input2); + NNACL_CHECK_NULL_RETURN_ERR(input2->data_); + int *ele_shape_data = (int *)(input2->data_); + NNACL_CHECK_NULL_RETURN_ERR(ele_shape_data); + int element_shape[MAX_SHAPE_SIZE] = {0}; + size_t element_shape_size = 0; + for (int i = 0; i < NNACLGetElementNum(input2); ++i) { + ShapePush(element_shape, &element_shape_size, ele_shape_data[i]); + } + int status = + TensorListMergeShape(element_shape, &element_shape_size, input0->element_shape_, input0->element_shape_size_); + if (status != NNACL_OK) { + return NNACL_ERR; + } + if (!TensorListIsFullyDefined(element_shape, element_shape_size)) { + for (size_t i = 0; i < input0->element_num_; ++i) { + TensorC *input = input0->tensors_[i]; + NNACL_CHECK_NULL_RETURN_ERR(input); + if (input->data_type_ != kTypeUnknown) { + status = TensorListMergeShape(element_shape, &element_shape_size, input->shape_, input->shape_size_); + if (status != NNACL_OK) { + return NNACL_ERR; + } + } + } + } + if (!TensorListIsFullyDefined(element_shape, element_shape_size)) { // the pre is the same judge condition + return NNACL_ERR; + } + + SetShapeArray(output, element_shape, element_shape_size); + } + + return NNACL_OK; +} + +REG_INFER(TensorListGetItem, PrimType_TensorListGetItem, TensorListGetItemInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensorlist_getitem_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensorlist_getitem_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..9defe797173c537ec1b6dddd370e3933aa345f20 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensorlist_getitem_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_GETITEM_INFER_H +#define MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_GETITEM_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/tensorlist_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int TensorListGetItemInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_GETITEM_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensorlist_reserve_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensorlist_reserve_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..c10b820cf99a46b4d3c1c7678282e90d67d6b2c3 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensorlist_reserve_infer.c @@ -0,0 +1,84 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/control/tensorlist_reserve_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensorlist_parameter.h" +#include "nnacl/tensorlist_c_utils.h" +#include "nnacl/tensor_c_utils.h" + +int TensorListReserveInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + TensorListParameter *reserve_param = (TensorListParameter *)parameter; + const TensorC *input0 = inputs[0]; + int ele_shape_type = input0->data_type_; + if (ele_shape_type != kNumberTypeInt && ele_shape_type != kNumberTypeInt32) { + return NNACL_ERR; + } + + TensorListC *output = (TensorListC *)(outputs[0]); + output->data_type_ = kObjectTypeTensorType; + output->format_ = Format_NHWC; + output->tensors_data_type_ = reserve_param->element_dtype_; + + if (input0->data_ == NULL) { + return NNACL_INFER_INVALID; + } + int *ele_shape_ptr = (int *)(input0->data_); + + const TensorC *input1 = inputs[1]; + int num_ele_type = input1->data_type_; + if (num_ele_type != kNumberTypeInt && ele_shape_type != kNumberTypeInt32) { + return NNACL_ERR; + } + if (input1->data_ == NULL) { + return NNACL_INFER_INVALID; + } + if (NNACLGetElementNum(input1) != 1) { + return NNACL_ERR; + } + int num_elements = ((int *)(input1->data_))[0]; + ShapeSet(output->element_shape_, &(output->element_shape_size_), ele_shape_ptr, (size_t)NNACLGetElementNum(input0)); + output->element_num_ = (size_t)(num_elements); + + vvector tmp_shape; + tmp_shape.size_ = (size_t)(num_elements); + tmp_shape.shape_ = (int **)malloc(tmp_shape.size_ * sizeof(int *)); + if (tmp_shape.shape_ == NULL) { + return NNACL_NULL_PTR; + } + tmp_shape.shape_size_ = (int *)malloc(tmp_shape.size_ * sizeof(int)); + if (tmp_shape.shape_size_ == NULL) { + free(tmp_shape.shape_); + return NNACL_NULL_PTR; + } + + for (size_t i = 0; i < num_elements; ++i) { + tmp_shape.shape_size_[i] = output->element_shape_size_; + tmp_shape.shape_[i] = output->element_shape_; + } + int ret = MallocTensorListData(output, reserve_param->element_dtype_, &tmp_shape); + free(tmp_shape.shape_size_); + free(tmp_shape.shape_); + return ret; +} + +REG_INFER(TensorListReserve, PrimType_TensorListReserve, TensorListReserveInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensorlist_reserve_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensorlist_reserve_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..82b8cbefe6ca37defe061744b7e58d8200606318 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensorlist_reserve_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_RESERVE_INFER_H +#define MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_RESERVE_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int TensorListReserveInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_RESERVE_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensorlist_setitem_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensorlist_setitem_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..e7c182e9ad2952c7fef7484ddcfb97f0a999423e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensorlist_setitem_infer.c @@ -0,0 +1,129 @@ +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/control/tensorlist_setitem_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensorlist_c_utils.h" +#include "nnacl/tensor_c_utils.h" + +int PreJudge(const TensorC *get_index, TensorListC *input0, const TensorC *value_tensor) { + if (get_index->data_ == NULL) { + return NNACL_INFER_INVALID; + } + + if (get_index->data_type_ != kNumberTypeInt && get_index->data_type_ != kNumberTypeInt32) { + return NNACL_ERR; + } + if (NNACLGetElementNum(get_index) != 1) { + return NNACL_ERR; + } + if (get_index->data_ == NULL) { + return NNACL_NULL_PTR; + } + return NNACL_OK; +} + +int TensorListSetItemInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + TensorListC *input0 = (TensorListC *)(inputs[0]); + const TensorC *get_index = inputs[1]; + const TensorC *value_tensor = inputs[2]; + TensorListC *output0 = (TensorListC *)(outputs[0]); + output0->data_type_ = input0->data_type_; + output0->format_ = input0->format_; + output0->tensors_data_type_ = value_tensor->data_type_; + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + int judge_ret = PreJudge(get_index, input0, value_tensor); + if (judge_ret != NNACL_OK) { + return judge_ret; + } + + int index = ((int *)(get_index->data_))[0]; + output0->max_elements_num_ = input0->max_elements_num_; + + if (input0->element_num_ == 0 && input0->element_shape_size_ == 0 && index == 0) { + ShapeSet(input0->element_shape_, &(input0->element_shape_size_), value_tensor->shape_, value_tensor->shape_size_); + ShapeSet(output0->element_shape_, &(output0->element_shape_size_), value_tensor->shape_, value_tensor->shape_size_); + } else { + ShapeSet(output0->element_shape_, &(output0->element_shape_size_), input0->element_shape_, + input0->element_shape_size_); + } + + vvector out_shape; + out_shape.size_ = 0; + out_shape.shape_ = (int **)malloc((input0->element_num_ + 1) * sizeof(int *)); + if (out_shape.shape_ == NULL) { + return NNACL_NULL_PTR; + } + out_shape.shape_size_ = (int *)malloc((input0->element_num_ + 1) * sizeof(int)); + if (out_shape.shape_size_ == NULL) { + free(out_shape.shape_); + return NNACL_NULL_PTR; + } + + if (index == 0 && input0->element_num_ == 0) { // uninitialized tensorlist + out_shape.shape_[out_shape.size_] = (int *)(value_tensor->shape_); + out_shape.shape_size_[out_shape.size_] = value_tensor->shape_size_; + out_shape.size_++; + output0->element_num_ = 1; + } else { + output0->element_num_ = input0->element_num_; + for (size_t i = 0; i < input0->element_num_; ++i) { + TensorC *src_ptr = input0->tensors_[i]; + if (src_ptr == NULL) { + free(out_shape.shape_); + free(out_shape.shape_size_); + return NNACL_NULL_PTR; + } + if (src_ptr->data_type_ != kTypeUnknown) { + out_shape.shape_[out_shape.size_] = src_ptr->shape_; + out_shape.shape_size_[out_shape.size_] = (int)(src_ptr->shape_size_); + out_shape.size_++; + } else { + out_shape.shape_[out_shape.size_] = NULL; + out_shape.shape_size_[out_shape.size_] = 0; + out_shape.size_++; + } + } + } + + if (input0->tensors_data_type_ == kTypeUnknown) { + input0->tensors_data_type_ = value_tensor->data_type_; + } + + out_shape.shape_[index] = (int *)(value_tensor->shape_); + out_shape.shape_size_[index] = (int)value_tensor->shape_size_; + int ret = MallocTensorListData(output0, input0->tensors_data_type_, &out_shape); + if (ret != NNACL_OK) { + free(out_shape.shape_); + free(out_shape.shape_size_); + return NNACL_ERR; + } + free(out_shape.shape_); + free(out_shape.shape_size_); + return NNACL_OK; +} + +REG_INFER(TensorListSetItem, PrimType_TensorListSetItem, TensorListSetItemInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensorlist_setitem_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensorlist_setitem_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..73f02c3faade265780ab6e350a3f39296ff35265 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensorlist_setitem_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_SETITEM_INFER_H +#define MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_SETITEM_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int TensorListSetItemInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_SETITEM_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensorlist_stack_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensorlist_stack_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..de3cdcd1f7bbee199024bd1d152cb907b02d90ab --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensorlist_stack_infer.c @@ -0,0 +1,96 @@ +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/control/tensorlist_stack_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensorlist_c_utils.h" +#include "nnacl/tensor_c_utils.h" + +int TensorListStackInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + TensorC *output = outputs[0]; + if (inputs[0]->data_type_ != kObjectTypeTensorType) { + return NNACL_INPUT_TENSOR_ERROR; + } + TensorListC *input0 = (TensorListC *)(inputs[0]); + output->data_type_ = input0->tensors_data_type_; + output->format_ = input0->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input0->element_num_ == 0) { + return NNACL_INFER_INVALID; + } + const TensorC *ele_shape = inputs[1]; // element shape + if (ele_shape->data_ == NULL) { + return NNACL_NULL_PTR; + } + int *ele_shape_ptr = (int *)(ele_shape->data_); + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + if (ele_shape_ptr[0] == -1) { + if (input0->element_shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + for (size_t i = 0; i < input0->element_shape_size_; i++) { + ShapePush(output_shape, &output_shape_size, input0->element_shape_[i]); + } + } else { + int ele_shape_num = NNACLGetElementNum(ele_shape); + if (ele_shape_num > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + for (int i = 0; i < ele_shape_num; ++i) { + ShapePush(output_shape, &output_shape_size, ele_shape_ptr[i]); + } + } + + int status = + TensorListMergeShape(output_shape, &output_shape_size, input0->element_shape_, input0->element_shape_size_); + if (status == NNACL_ERR) { + return NNACL_ERR; + } + if (!TensorListIsFullyDefined(output_shape, output_shape_size)) { + return NNACL_ERR; + } + if (!TensorListIsFullyDefined(input0->element_shape_, input0->element_shape_size_)) { + for (size_t i = 0; i < input0->element_num_; ++i) { + TensorC *tensor_ele = input0->tensors_[i]; + if (tensor_ele->data_type_ != kTypeUnknown) { + status = TensorListMergeShape(output_shape, &output_shape_size, tensor_ele->shape_, tensor_ele->shape_size_); + if (status == NNACL_ERR) { + return NNACL_ERR; + } + } + } + } + if (output_shape_size >= MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + int ret = ShapeInsert(output_shape, &output_shape_size, 0, input0->element_num_); + if (ret != NNACL_OK) { + return NNACL_ERR; + } + SetShapeArray(output, output_shape, output_shape_size); + return NNACL_OK; +} + +REG_INFER(TensorListStack, PrimType_TensorListStack, TensorListStackInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensorlist_stack_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensorlist_stack_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..972365be85b14c60a47f305698bba6b48b628acc --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/control/tensorlist_stack_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_STACK_INFER_H +#define MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_STACK_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int TensorListStackInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_CONTROL_TENSORLIST_STACK_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/conv2d_grad_filter_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/conv2d_grad_filter_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..5e880f39f68d7a7f97d920002e22d09526d28eed --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/conv2d_grad_filter_infer.c @@ -0,0 +1,61 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "nnacl/infer/conv2d_grad_filter_infer.h" +#include "nnacl/infer/infer_register.h" + +int Conv2dGradFilterInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (ret != NNACL_OK) { + return ret; + } + if (inputs_size < 3 || outputs_size != 1) { + return NNACL_ERR; + } + if (inputs[FIRST_INPUT]->format_ != Format_NHWC || inputs[SECOND_INPUT]->format_ != Format_NHWC) { + return NNACL_FORMAT_ERROR; + } + SetDataTypeFormat(outputs[FIRST_INPUT], inputs[FIRST_INPUT]); + + if (inputs[THIRD_INPUT]->shape_size_ < DIMENSION_1D || inputs[THIRD_INPUT]->data_ == NULL) { + return NNACL_ERR; + } + if (inputs[THIRD_INPUT]->shape_[kNCHW_N] < 0) { + return NNACL_ERR; + } + size_t filter_shape_size = (size_t)(inputs[THIRD_INPUT]->shape_[kNCHW_N]); + if (filter_shape_size != DIMENSION_4D) { + return NNACL_ERR; + } + + int filter_shape[MAX_SHAPE_SIZE]; + if (inputs[THIRD_INPUT]->format_ == Format_NCHW || inputs[THIRD_INPUT]->format_ == Format_KCHW) { + const int nchw2nhwc[] = {kNCHW_N, kNCHW_H, kNCHW_W, kNCHW_C}; + for (size_t i = 0; i < filter_shape_size; i++) { + filter_shape[i] = *((int *)(inputs[THIRD_INPUT]->data_) + nchw2nhwc[i]); + } + } else if (inputs[THIRD_INPUT]->format_ == Format_NHWC || inputs[THIRD_INPUT]->format_ == Format_KHWC) { + memcpy(filter_shape, inputs[THIRD_INPUT]->data_, filter_shape_size * sizeof(int)); + } else { + return NNACL_ERR; + } + SetShapeArray(outputs[0], filter_shape, filter_shape_size); + return NNACL_OK; +} + +REG_INFER(Conv2DBackpropFilterFusion, PrimType_Conv2DBackpropFilterFusion, Conv2dGradFilterInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/conv2d_grad_filter_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/conv2d_grad_filter_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..02deb60f6407d86c73968642d5cc68d1b4b2b1d9 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/conv2d_grad_filter_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_CONV2D_GRAD_FILTER_INFER_H +#define MINDSPORE_NNACL_CONV2D_GRAD_FILTER_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int Conv2dGradFilterInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CONV2D_GRAD_FILTER_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/conv2d_grad_input_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/conv2d_grad_input_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..2651deb5970ed717e8fbe4b939647ec4d15e1bdb --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/conv2d_grad_input_infer.c @@ -0,0 +1,63 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/conv2d_grad_input_infer.h" +#include "nnacl/infer/infer_register.h" + +int Conv2dGradInputInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (ret != NNACL_OK) { + return ret; + } + if (inputs_size < 3 || outputs_size != 1) { + return NNACL_ERR; + } + const TensorC *in0 = inputs[0]; + TensorC *out = outputs[0]; + + if (in0 == NULL || out == NULL) { + return NNACL_NULL_PTR; + } + if (in0->format_ != Format_NHWC) { + return NNACL_FORMAT_ERROR; + } + SetDataTypeFormat(out, in0); + + if (inputs[THIRD_INPUT]->shape_size_ < 1 || inputs[THIRD_INPUT]->data_ == NULL) { + return NNACL_ERR; + } + size_t data_size = (size_t)inputs[2]->shape_[0]; + if (data_size != 4) { + return NNACL_ERR; + } + + int shape[MAX_SHAPE_SIZE]; + if (inputs[THIRD_INPUT]->format_ == Format_NCHW || inputs[THIRD_INPUT]->format_ == Format_KCHW) { + const int nchw2nhwc[4] = {kNCHW_N, kNCHW_H, kNCHW_W, kNCHW_C}; + for (size_t i = 0; i < data_size; i++) { + shape[i] = *((int *)(inputs[THIRD_INPUT]->data_) + nchw2nhwc[i]); + } + } else if (inputs[THIRD_INPUT]->format_ == Format_NHWC || inputs[THIRD_INPUT]->format_ == Format_KHWC) { + memcpy(shape, inputs[THIRD_INPUT]->data_, data_size * sizeof(int)); + } else { + return NNACL_ERR; + } + SetShapeArray(out, shape, data_size); + return NNACL_OK; +} + +REG_INFER(Conv2DBackpropInputFusion, PrimType_Conv2DBackpropInputFusion, Conv2dGradInputInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/conv2d_grad_input_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/conv2d_grad_input_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..17ae57a8e564c5e56e99ccb5557622fd95125635 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/conv2d_grad_input_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_CONV2D_GRAD_INPUT_INFER_H +#define MINDSPORE_NNACL_CONV2D_GRAD_INPUT_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int Conv2dGradInputInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CONV2D_GRAD_INPUT_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/conv2d_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/conv2d_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..495b2cda93548b7e3161ddb5de3a428507f82261 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/conv2d_infer.c @@ -0,0 +1,169 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/infer/conv2d_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensor_c_utils.h" + +int ConvInferShape(int input_h, int input_w, int *output_h, int *output_w, ConvParameter *param) { + int kernel_w = param->kernel_w_; + int kernel_h = param->kernel_h_; + int stride_w = param->stride_w_; + int stride_h = param->stride_h_; + int dilate_w = param->dilation_w_; + int dilate_h = param->dilation_h_; + + if (stride_w == 0 || stride_h == 0) { + return NNACL_PARAM_INVALID; + } + if (INT_MUL_OVERFLOW(kernel_h, dilate_h) || INT_MUL_OVERFLOW(kernel_w, dilate_w)) { + return NNACL_ERRCODE_MUL_OVERFLOW; + } + if (param->pad_mode_ == Pad_same) { // maybe error + *output_w = ceil((float)(input_w) / (float)(stride_w)); + *output_h = ceil((float)(input_h) / (float)(stride_h)); + int pad_h_all = ((*output_h - 1) * stride_h + (kernel_h - 1) * dilate_h + 1 - input_h); + int pad_w_all = ((*output_w - 1) * stride_w + (kernel_w - 1) * dilate_w + 1 - input_w); + if (pad_h_all < 0) { + param->pad_u_ = param->pad_d_ = 0; + } else { + param->pad_u_ = pad_h_all / 2; + param->pad_d_ = pad_h_all - param->pad_u_; + } + if (pad_w_all < 0) { + param->pad_l_ = param->pad_r_ = 0; + } else { + param->pad_l_ = pad_w_all / 2; + param->pad_r_ = pad_w_all - param->pad_l_; + } + } else if (param->pad_mode_ == Pad_valid) { + *output_w = ceil(((float)(input_w) + param->pad_l_ + param->pad_r_ - ((float)(kernel_w)-1) * (float)(dilate_w)) / + (float)(stride_w)); + *output_h = ceil(((float)(input_h) + param->pad_u_ + param->pad_d_ - ((float)(kernel_h)-1) * (float)(dilate_h)) / + (float)(stride_h)); + } else { + int kernel_width = (kernel_w - 1) * dilate_w + 1; + int kernel_height = (kernel_h - 1) * dilate_h + 1; + *output_w = ((input_w) + param->pad_l_ + param->pad_r_ - kernel_width) / stride_w + 1; + *output_h = ((input_h) + param->pad_u_ + param->pad_d_ - kernel_height) / stride_h + 1; + } + + if (param->kernel_h_ > input_h + param->pad_u_ + param->pad_d_ || + param->kernel_w_ > input_w + param->pad_l_ + param->pad_r_) { + return NNACL_PARAM_INVALID; + } + return NNACL_OK; +} + +static const int MAX_CONV_KERNEL_DIM = 10000; // One big value that should not be adopted as the conv kernel dimension. + +int CheckConvAttr(const int input_c, const TensorC *weight_tensor, const ConvParameter *param) { + // common conv: input_c == weight_tensor->shape_[3] + // conv depthwise: input_c == 1 + // group conv: input_c / group == weight_tensor->shape_[3] + NNACL_CHECK_FALSE(param->group_ == 0, NNACL_PARAM_INVALID); + if (input_c != weight_tensor->shape_[3] && input_c != 1 && (input_c / param->group_) != weight_tensor->shape_[3]) { + return NNACL_PARAM_INVALID; + } + + // common conv: group == 1 + // conv depthwise: group == input_c == output_c + // group conv: group == input_c / weight_tensor->shape_[3] + NNACL_CHECK_FALSE(weight_tensor->shape_[3] == 0, NNACL_PARAM_INVALID); + if (param->group_ != 1 && param->group_ != input_c && param->group_ != (input_c / weight_tensor->shape_[3])) { + return NNACL_PARAM_INVALID; + } + if (param->stride_h_ <= 0 || param->stride_w_ <= 0) { + return NNACL_PARAM_INVALID; + } + + if ((param->kernel_h_ >= MAX_CONV_KERNEL_DIM) || (param->kernel_w_ >= MAX_CONV_KERNEL_DIM)) { + return NNACL_PARAM_INVALID; + } + + NNACL_CHECK_TRUE_RET(param->kernel_h_ == weight_tensor->shape_[1], NNACL_PARAM_INVALID); + NNACL_CHECK_TRUE_RET(param->kernel_w_ == weight_tensor->shape_[2], NNACL_PARAM_INVALID); + return NNACL_OK; +} + +int Conv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 2, 3, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input_tensor = inputs[0]; + if (input_tensor->format_ != Format_NHWC && input_tensor->format_ != Format_KHWC && + input_tensor->format_ != Format_NC4HW4 && input_tensor->format_ != Format_NC8HW8) { + return NNACL_FORMAT_ERROR; + } + const TensorC *weight_tensor = inputs[1]; + if (weight_tensor->format_ != Format_NHWC && weight_tensor->format_ != Format_KHWC) { + return NNACL_FORMAT_ERROR; + } + TensorC *out_tensor = outputs[0]; + if (out_tensor->format_ != Format_NC4HW4) { + out_tensor->format_ = input_tensor->format_; + } + out_tensor->data_type_ = input_tensor->data_type_; + ConvParameter *param = (ConvParameter *)parameter; + if (param->group_ == 0) { + param->group_ = weight_tensor->shape_[0]; + } + param->output_channel_ = weight_tensor->shape_[0]; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + param->kernel_h_ = param->kernel_h_ != -1 ? param->kernel_h_ : weight_tensor->shape_[1]; + param->kernel_w_ = param->kernel_w_ != -1 ? param->kernel_w_ : weight_tensor->shape_[2]; + + if (input_tensor->shape_size_ == 0) { + return NNACL_INFER_INVALID; + } + + int ret = CheckConvAttr(NNACLGetChannel(input_tensor), weight_tensor, param); + if (ret != NNACL_OK) { + return ret; + } + + int output_w = 0, output_h = 0; + ret = ConvInferShape(NNACLGetHeight(input_tensor), NNACLGetWidth(input_tensor), &output_h, &output_w, param); + if (ret != NNACL_OK) { + return ret; + } + + out_tensor->shape_size_ = input_tensor->shape_size_; + NNACLSetBatch(out_tensor, NNACLGetBatch(input_tensor)); + NNACLSetChannel(out_tensor, NNACLGetBatch(weight_tensor)); + output_h = output_h >= 0 ? output_h : 1; + NNACLSetHeight(out_tensor, output_h); + output_w = output_w >= 0 ? output_w : 1; + NNACLSetWidth(out_tensor, output_w); + + param->input_batch_ = NNACLGetBatch(input_tensor); + param->input_h_ = NNACLGetHeight(input_tensor); + param->input_w_ = NNACLGetWidth(input_tensor); + param->input_channel_ = NNACLGetChannel(input_tensor); + param->output_batch_ = NNACLGetBatch(out_tensor); + param->output_h_ = NNACLGetHeight(out_tensor); + param->output_w_ = NNACLGetWidth(out_tensor); + param->output_channel_ = NNACLGetChannel(out_tensor); + param->out_format_ = out_tensor->format_; + return NNACL_OK; +} + +REG_INFER(Adder, PrimType_AdderFusion, Conv2dInferShape) +REG_INFER(Conv2D, PrimType_Conv2DFusion, Conv2dInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/conv2d_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/conv2d_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..fc4b8a601669bdc8463e32d9980295be52e45234 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/conv2d_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_CONV2D_INFER_H +#define MINDSPORE_NNACL_CONV2D_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int Conv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CONV2D_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/conv3d_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/conv3d_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..3725521f26d144e2a5700b6ca7626e708df9b964 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/conv3d_infer.c @@ -0,0 +1,27 @@ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/infer/conv3d_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensor_c_utils.h" + +int Conv3dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + // The InferShape of Conv3D is not implemented here, it just prevents the InferShape process from being interrupted + // and makes the nodes shape are {}. + return NNACL_OK; +} + +REG_INFER(Conv3D, PrimType_Inner_Conv3D, Conv3dInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/conv3d_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/conv3d_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..3f4139d34c12f352f864bf50e34bdc655b75d3ee --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/conv3d_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_CONV3D_INFER_H +#define MINDSPORE_NNACL_CONV3D_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/gru_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int Conv3dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CONV3D_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/crop_and_resize_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/crop_and_resize_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..e0a7dfa42c49b9e159761730b996b45c5723cdb3 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/crop_and_resize_infer.c @@ -0,0 +1,69 @@ +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/crop_and_resize_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensor_c_utils.h" + +int CropAndResizeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, 4); + if (check_ret != NNACL_OK) { + return check_ret; + } + if (outputs_size < 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ != 0 && input->shape_size_ != 4) { + return NNACL_ERR; + } + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + if (NNACLGetBatch(input) == 0) { + ShapePush(output_shape, &output_shape_size, 0); + } else if (inputs[1]->data_ != NULL) { + const TensorC *boxes_tensor = inputs[1]; + if (boxes_tensor->shape_size_ < 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + ShapePush(output_shape, &output_shape_size, boxes_tensor->shape_[0]); + } else { + return NNACL_INFER_INVALID; + } + + const TensorC *shape_tensor = inputs[3]; + int32_t *data = (int32_t *)(shape_tensor->data_); + if (data == NULL) { + return NNACL_INFER_INVALID; + } + if (NNACLGetElementNum(shape_tensor) < 2) { + return NNACL_INPUT_TENSOR_ERROR; + } + ShapePush(output_shape, &output_shape_size, data[0]); + ShapePush(output_shape, &output_shape_size, data[1]); + ShapePush(output_shape, &output_shape_size, NNACLGetChannel(input)); + SetShapeArray(output, output_shape, output_shape_size); + return NNACL_OK; +} + +REG_INFER(CropAndResize, PrimType_CropAndResize, CropAndResizeInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/crop_and_resize_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/crop_and_resize_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..3cb45b524e0b7923bc3213179e35d0c19ee632f6 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/crop_and_resize_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_CROP_AND_RESIZE_INFER_H +#define MINDSPORE_NNACL_CROP_AND_RESIZE_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int CropAndResizeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CROP_AND_RESIZE_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/crop_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/crop_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..ef42cec1501f487318214b2df3a3525005ca7437 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/crop_infer.c @@ -0,0 +1,44 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/crop_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/crop_parameter.h" +#include "nnacl/tensor_c_utils.h" + +int CropInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + size_t input_shape_size = inputs[0]->shape_size_; + CropParameter *param = (CropParameter *)parameter; + int64_t axis = param->axis_ < 0 ? param->axis_ + (int64_t)input_shape_size : param->axis_; + if (axis < 0 || axis >= (int64_t)input_shape_size) { + return NNACL_ERR; + } + + SetDataTypeFormat(outputs[0], inputs[0]); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(outputs[0], inputs[1]); + return NNACL_OK; +} + +REG_INFER(Crop, PrimType_Crop, CropInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/crop_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/crop_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..0691e0e1bed8e19ca5e4e4175e84c7ff7ad42b34 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/crop_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_CROP_INFER_H +#define MINDSPORE_NNACL_CROP_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/crop_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int CropInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CROP_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/cumsum_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/cumsum_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..a9a399830291d0b79eb4c4a6379eef7567991e37 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/cumsum_infer.c @@ -0,0 +1,41 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/cumsum_infer.h" +#include "nnacl/infer/infer_register.h" + +int CumsumInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullOutputSize(inputs, inputs_size, outputs, outputs_size, parameter, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + if (inputs_size < 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + SetShapeTensor(output, input); + return NNACL_OK; +} + +REG_INFER(Cumsum, PrimType_CumSum, CumsumInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/cumsum_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/cumsum_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..7680f3e438ba574f4f98e4f0fdb893a3496afd22 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/cumsum_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_CUMSUM_INFER_H +#define MINDSPORE_NNACL_CUMSUM_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int CumsumInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CUMSUM_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/custom_gru_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/custom_gru_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..060d04cfb82a53ae781448484e5bfb608afcdb83 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/custom_gru_infer.c @@ -0,0 +1,45 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/custom_gru_infer.h" +#include "nnacl/infer/infer_register.h" + +int CustomGruInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, C6NUM, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ != C3NUM) { + return NNACL_INPUT_TENSOR_ERROR; + } + SetShapeTensor(output, input); + const TensorC *weight_in = inputs[1]; + if (weight_in->shape_size_ != C2NUM || weight_in->shape_[0] % C3NUM != 0) { + return NNACL_INPUT_TENSOR_ERROR; + } + output->shape_[C2NUM] = weight_in[0].shape_[0] / C3NUM; + return NNACL_OK; +} + +REG_INFER(CustomGru, PrimType_Inner_CustomGru, CustomGruInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/custom_gru_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/custom_gru_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..830150d5c0d1699408e3f871adee8de6172ae681 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/custom_gru_infer.h @@ -0,0 +1,30 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_CUSTOM_GRU_INFER_H +#define MINDSPORE_NNACL_CUSTOM_GRU_INFER_H +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int CustomGruInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CUSTOM_GRU_INFER_H diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/ops/ascend_native_composite.cc b/mindspore-lite/ops/kernel/cpu/nnacl/infer/custom_is_inf_infer.c similarity index 41% rename from mindspore-lite/src/extendrt/delegate/ascend_native/ops/ascend_native_composite.cc rename to mindspore-lite/ops/kernel/cpu/nnacl/infer/custom_is_inf_infer.c index 84fc36d44f998c37fb40be3d49e789a668823eee..0bb35d0ca2df2b4179c63d666ff1e6ed1bab2253 100644 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/ops/ascend_native_composite.cc +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/custom_is_inf_infer.c @@ -14,27 +14,27 @@ * limitations under the License. */ -#include "extendrt/delegate/ascend_native/ops/ascend_native_composite.h" -#include "mindapi/base/shared_ptr.h" -#include "mindapi/ir/common.h" -#include "mindapi/ir/value.h" -#include "mindspore/ops/op_def/op_name.h" -#include "ops/primitive_c.h" -#include "src/common/log_adapter.h" -#include "mindapi/helper.h" +#include "nnacl/infer/custom_is_inf_infer.h" +#include "nnacl/infer/infer_register.h" -namespace mindspore { -namespace ops { -MIND_API_OPERATOR_IMPL(AscendNativeComposite, BaseOperator); -void AscendNativeComposite::Init(int64_t group) { this->set_group(group); } +int CustomIsInfInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, C1NUM, C1NUM); + if (check_ret != NNACL_OK) { + return check_ret; + } -void AscendNativeComposite::set_group(int64_t group) { (void)this->AddAttr(kGroup, api::MakeValue(group)); } - -int64_t AscendNativeComposite::get_group() const { - auto value_ptr = this->GetAttr(kGroup); - return GetValue(value_ptr); + const TensorC *input = inputs[FIRST_INPUT]; + TensorC *output = outputs[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + NNACL_CHECK_NULL_RETURN_ERR(output); + output->data_type_ = kNumberTypeBool; + output->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(output, input); + return NNACL_OK; } -REGISTER_PRIMITIVE_C(kNameAscendNativeComposite, AscendNativeComposite); -} // namespace ops -} // namespace mindspore +REG_INFER(CustomIsInf, PrimType_Inner_CustomIsInf, CustomIsInfInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/custom_is_inf_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/custom_is_inf_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..d1b4b33dc89243c91d6e800c058423cce67bc746 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/custom_is_inf_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_CUSTOM_IS_INF_INFER_H +#define MINDSPORE_NNACL_CUSTOM_IS_INF_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int CustomIsInfInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CUSTOM_IS_INF_INFER_H diff --git a/mindspore-lite/src/extendrt/graph_compiler/factory.cc b/mindspore-lite/ops/kernel/cpu/nnacl/infer/custom_masked_fill_infer.c similarity index 44% rename from mindspore-lite/src/extendrt/graph_compiler/factory.cc rename to mindspore-lite/ops/kernel/cpu/nnacl/infer/custom_masked_fill_infer.c index dd640c1db0aca9c59e23b62de4f92b2448afc17c..cf7907d918a3fc47f6950e89cf4fff91339b6885 100644 --- a/mindspore-lite/src/extendrt/graph_compiler/factory.cc +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/custom_masked_fill_infer.c @@ -13,29 +13,25 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "extendrt/graph_compiler/factory.h" -#include -#include -namespace mindspore { -GraphCompilerRegistry &GraphCompilerRegistry::GetInstance() { - static GraphCompilerRegistry instance; - return instance; -} +#include "nnacl/infer/custom_masked_fill_infer.h" +#include "nnacl/infer/infer_register.h" -void GraphCompilerRegistry::RegCompiler(const mindspore::GraphCompilerType &type, const GraphCompilerRegFunc &creator) { - if (creator == nullptr) { - return; +int CustomMaskedFillInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, C3NUM, C1NUM); + if (check_ret != NNACL_OK) { + return check_ret; } - graph_compiler_map_[type] = creator; -} -std::shared_ptr GraphCompilerRegistry::GetCompiler( - const mindspore::GraphCompilerType &type, const std::shared_ptr &context) { - auto it = graph_compiler_map_.find(type); - if (it == graph_compiler_map_.end()) { - return nullptr; + const TensorC *input = inputs[FIRST_INPUT]; + TensorC *output = outputs[FIRST_INPUT]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; } - return it->second(context); + SetShapeTensor(output, input); + return NNACL_OK; } -} // namespace mindspore + +REG_INFER(CustomMaskedFill, PrimType_Inner_CustomMaskedFill, CustomMaskedFillInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/custom_masked_fill_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/custom_masked_fill_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..a8adbae2023f939587d1c2eb5b770f3a5c0db3ee --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/custom_masked_fill_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_CUSTOM_MASKED_FILL_INFER_H +#define MINDSPORE_NNACL_CUSTOM_MASKED_FILL_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int CustomMaskedFillInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CUSTOM_MASKED_FILL_INFER_H diff --git a/mindspore-lite/src/extendrt/graph_partitioner/factory.cc b/mindspore-lite/ops/kernel/cpu/nnacl/infer/custom_tensor_scatter_max_infer.c similarity index 43% rename from mindspore-lite/src/extendrt/graph_partitioner/factory.cc rename to mindspore-lite/ops/kernel/cpu/nnacl/infer/custom_tensor_scatter_max_infer.c index 23e4b4b7ef2675e23467a2c9899e662b16e19974..1e2feba26261897ccb146705f92db872d0185d3f 100644 --- a/mindspore-lite/src/extendrt/graph_partitioner/factory.cc +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/custom_tensor_scatter_max_infer.c @@ -13,27 +13,25 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "extendrt/graph_partitioner/factory.h" -#include -#include -namespace mindspore { -GraphPartitionerRegistry &GraphPartitionerRegistry::GetInstance() { - static GraphPartitionerRegistry instance; - return instance; -} +#include "nnacl/infer/custom_tensor_scatter_max_infer.h" +#include "nnacl/infer/infer_register.h" -void GraphPartitionerRegistry::RegPartitioner(const mindspore::GraphPartitionerType &type, - const GraphPartitionerRegFunc &creator) { - graph_partitioner_map_[type] = creator; -} +int CustomTensorScatterMaxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, C3NUM, C1NUM); + if (check_ret != NNACL_OK) { + return check_ret; + } -std::shared_ptr GraphPartitionerRegistry::GetPartitioner( - const mindspore::GraphPartitionerType &type) { - auto it = graph_partitioner_map_.find(type); - if (it == graph_partitioner_map_.end()) { - return nullptr; + const TensorC *input = inputs[FIRST_INPUT]; + TensorC *output = outputs[FIRST_INPUT]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; } - return it->second(); + SetShapeTensor(output, input); + return NNACL_OK; } -} // namespace mindspore + +REG_INFER(CustomTensorScatterMax, PrimType_Inner_CustomTensorScatterMax, CustomTensorScatterMaxInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/custom_tensor_scatter_max_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/custom_tensor_scatter_max_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..641aa483655450a90fbbfdaec34e814f70f4d1b6 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/custom_tensor_scatter_max_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_CUSTOM_TENSOR_SCATTER_MAX_INFER_H +#define MINDSPORE_NNACL_CUSTOM_TENSOR_SCATTER_MAX_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int CustomTensorScatterMaxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_CUSTOM_TENSOR_SCATTER_MAX_INFER_H diff --git a/mindspore-lite/tools/graph_kernel/converter/expanders/constant_of_shape.cc b/mindspore-lite/ops/kernel/cpu/nnacl/infer/decoder_layer_infer.c similarity index 44% rename from mindspore-lite/tools/graph_kernel/converter/expanders/constant_of_shape.cc rename to mindspore-lite/ops/kernel/cpu/nnacl/infer/decoder_layer_infer.c index 467920f98c9d78d8e3997bddb56a6c95c394393e..401ec032f60ef84b535b53bfd2d649d62a3594be 100644 --- a/mindspore-lite/tools/graph_kernel/converter/expanders/constant_of_shape.cc +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/decoder_layer_infer.c @@ -13,26 +13,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include +#include "nnacl/infer/decoder_layer_infer.h" +#include "nnacl/infer/infer_register.h" -#include - -#include "backend/common/graph_kernel/expanders/op_desc_registry.h" - -namespace mindspore::graphkernel::expanders { -class ConstantOfShape : public OpDesc { - public: - ConstantOfShape() { - std::initializer_list attrs{"value", "data_type", "shape"}; - (void)validators_.emplace_back(std::make_unique(attrs)); +int DecoderLayerInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, C16NUM, C1NUM); + if (check_ret != NNACL_OK) { + return check_ret; } - ~ConstantOfShape() = default; - - protected: - NodePtrList Expand(const NodePtrList &) override { - auto result = gb.Emit("ConstantOfShape", {}, - {{"value", attrs_["value"]}, {"data_type", attrs_["data_type"]}, {"shape", attrs_["shape"]}}); - return {result}; + const TensorC *input = inputs[FIRST_INPUT]; + TensorC *output0 = outputs[FIRST_INPUT]; + SetDataTypeFormat(output0, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; } -}; -EXPANDER_OP_DESC_REGISTER("ConstantOfShape", ConstantOfShape); -} // namespace mindspore::graphkernel::expanders + SetShapeTensor(output0, input); + return NNACL_OK; +} + +REG_INFER(DecoderLayer, PrimType_Inner_DecoderLayer, DecoderLayerInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/decoder_layer_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/decoder_layer_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..c72513624e8cbf5bfef17856256edb088695c87b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/decoder_layer_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_DECODER_LAYER_INFER_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_DECODER_LAYER_INFER_H_ + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int DecoderLayerInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_DECODER_LAYER_INFER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/deconv2d_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/deconv2d_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..2102fd3ddd39ed72458f40f043540b51814fd9b6 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/deconv2d_infer.c @@ -0,0 +1,119 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/deconv2d_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensor_c_utils.h" + +int Deconv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 2, 3, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + if (input->format_ != Format_NHWC) { + return NNACL_FORMAT_ERROR; + } + const TensorC *weight = inputs[1]; + TensorC *output = outputs[0]; + output->format_ = input->format_; + output->data_type_ = input->data_type_; + + ConvParameter *param = (ConvParameter *)parameter; + if (param->group_ == 0) { + param->group_ = weight->shape_[0]; + } + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + int32_t input_h = NNACLGetHeight(input); + int32_t input_w = NNACLGetWidth(input); + + int32_t output_n = NNACLGetBatch(input); + int32_t output_h = 0; + int32_t output_w = 0; + int32_t output_c = NNACLGetChannel(weight); + NNACL_CHECK_TRUE_RET(NNACLGetChannel(input) == NNACLGetBatch(weight), NNACL_ERR); + if (param->group_ == NNACLGetChannel(input) && 1 == NNACLGetChannel(weight)) { + output_c = NNACLGetBatch(weight); /* depthwise */ + } + + int kernel_w = param->kernel_w_ != -1 ? param->kernel_w_ : NNACLGetWidth(weight); + int kernel_h = param->kernel_h_ != -1 ? param->kernel_h_ : NNACLGetHeight(weight); + NNACL_CHECK_FALSE(kernel_w <= 0, NNACL_ERR); + NNACL_CHECK_FALSE(kernel_h <= 0, NNACL_ERR); + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(kernel_h, kernel_w), NNACL_ERR); + + int stride_w = param->stride_w_; + int stride_h = param->stride_h_; + NNACL_CHECK_FALSE(stride_w <= 0, NNACL_ERR); + NNACL_CHECK_FALSE(stride_h <= 0, NNACL_ERR); + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(input_h, stride_h), NNACL_ERR); + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(input_w, stride_w), NNACL_ERR); + + int dilate_w = param->dilation_w_; + int dilate_h = param->dilation_h_; + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(kernel_h, dilate_h), NNACL_ERR); + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(kernel_w, dilate_w), NNACL_ERR); + + int pad_mode = param->pad_mode_; + if (pad_mode == Pad_pad) { + output_h = (input_h - 1) * stride_h + ((kernel_h - 1) * dilate_h + 1) - param->pad_u_ - param->pad_d_; + output_w = (input_w - 1) * stride_w + ((kernel_w - 1) * dilate_w + 1) - param->pad_l_ - param->pad_r_; + } else if (pad_mode == Pad_same) { + output_h = input_h * stride_h; + output_w = input_w * stride_w; + } else if (pad_mode == Pad_valid) { + output_h = (input_h - 1) * stride_h + kernel_h; + output_w = (input_w - 1) * stride_w + kernel_w; + } else { + return NNACL_ERR; + } + + output_h += param->output_padding_h_; + output_w += param->output_padding_w_; + + output->shape_size_ = 4; + output->shape_[0] = output_n; + output->shape_[1] = output_h; + output->shape_[2] = output_w; + output->shape_[3] = output_c; + + if (pad_mode == Pad_same) { + param->pad_u_ = ((input_h - 1) * stride_h + (kernel_h - 1) * dilate_h + 1 - output_h) / 2; + param->pad_l_ = ((input_w - 1) * stride_w + (kernel_w - 1) * dilate_w + 1 - output_w) / 2; + } else if (pad_mode == Pad_valid) { + param->pad_u_ = 0; + param->pad_l_ = 0; + } + + const int *in_shape = input->shape_; + param->input_batch_ = in_shape[0]; + param->input_h_ = in_shape[1]; + param->input_w_ = in_shape[2]; + param->input_channel_ = in_shape[3]; + param->output_batch_ = output_n; + param->output_h_ = output_h; + param->output_w_ = output_w; + param->output_channel_ = output_c; + param->kernel_h_ = kernel_h; + param->kernel_w_ = kernel_w; + return NNACL_OK; +} + +REG_INFER(Conv2dTranspose, PrimType_Conv2dTransposeFusion, Deconv2dInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/deconv2d_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/deconv2d_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..5bc0689dcdc427505c87448ccdd46b710933088a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/deconv2d_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_DECONV2D_INFER_H +#define MINDSPORE_NNACL_DECONV2D_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int Deconv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_DECONV2D_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/depth_to_space_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/depth_to_space_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..54371c56327e5be0d7f244f5b5d215b68cb777bb --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/depth_to_space_infer.c @@ -0,0 +1,60 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/depth_to_space_infer.h" +#include "nnacl/infer/infer_register.h" + +int DepthToSpaceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + if (input->format_ != Format_NHWC) { + return NNACL_FORMAT_ERROR; + } + SetDataTypeFormat(outputs[0], input); + DepthToSpaceParameter *param = (DepthToSpaceParameter *)parameter; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ != 4) { + return NNACL_PARAM_INVALID; + } + int input_shape[MAX_SHAPE_SIZE] = {0}; + size_t input_shape_size = 0; + ShapeSet(input_shape, &input_shape_size, input->shape_, input->shape_size_); + + int32_t block_size = param->block_size_; + if (INT_MUL_OVERFLOW(block_size, block_size)) { + return NNACL_PARAM_INVALID; + } + if (block_size == 0 || input_shape[kNHWC_C] % (block_size * block_size) != 0 || input_shape[kNHWC_C] == 0) { + return NNACL_PARAM_INVALID; + } + int output_shape[MAX_SHAPE_SIZE]; + size_t output_shape_size = input_shape_size; + output_shape[kNHWC_N] = input_shape[kNHWC_N]; + output_shape[kNHWC_H] = input_shape[kNHWC_H] * block_size; + output_shape[kNHWC_W] = input_shape[kNHWC_W] * block_size; + output_shape[kNHWC_C] = input_shape[kNHWC_C] / (block_size * block_size); + SetShapeArray(outputs[0], output_shape, output_shape_size); + return NNACL_OK; +} + +REG_INFER(DepthToSpace, PrimType_DepthToSpace, DepthToSpaceInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/depth_to_space_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/depth_to_space_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..d80414d3c36e6e1726ad416fc50f19a361712e23 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/depth_to_space_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_DEPTHTOSPACE_INFER_H +#define MINDSPORE_NNACL_DEPTHTOSPACE_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/depth_to_space_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int DepthToSpaceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_DEPTHTOSPACE_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/depthwise_conv2d_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/depthwise_conv2d_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..c1aff72bf0c73522ff8f010d6a99dafcb407ebc4 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/depthwise_conv2d_infer.c @@ -0,0 +1,81 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/depthwise_conv2d_infer.h" +#include "nnacl/tensor_c_utils.h" + +int DepthwiseConv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 2, 3, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + ConvParameter *param = (ConvParameter *)parameter; + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ != 4) { + return NNACL_INPUT_TENSOR_ERROR; + } + int input_h = input->shape_[1]; + int input_w = input->shape_[2]; + int input_channel = input->shape_[3]; + int output_w = 0, output_h = 0; + param->input_channel_ = input_channel; + + if (param->stride_h_ == 0 || param->stride_w_ == 0) { + return NNACL_PARAM_INVALID; + } + param->kernel_h_ = param->kernel_h_ != -1 ? param->kernel_h_ : NNACLGetHeight(inputs[kWeightIndex]); + param->kernel_w_ = param->kernel_w_ != -1 ? param->kernel_w_ : NNACLGetWidth(inputs[kWeightIndex]); + if (param->pad_mode_ == Pad_same) { + output_h = ceil((float)(input_h) / (float)(param->stride_h_)); + output_w = ceil((float)(input_w) / (float)(param->stride_w_)); + int pad_h_all = ((output_h - 1) * param->stride_h_ + (param->kernel_h_ - 1) * param->dilation_h_ + 1 - input_h); + int pad_w_all = ((output_w - 1) * param->stride_w_ + (param->kernel_w_ - 1) * param->dilation_w_ + 1 - input_w); + if (pad_h_all > 0) { + param->pad_u_ = pad_h_all / 2; + param->pad_d_ = pad_h_all - param->pad_u_; + } + if (pad_w_all > 0) { + param->pad_l_ = pad_w_all / 2; + param->pad_r_ = pad_w_all - param->pad_l_; + } + } else { + output_h = ceil(((float)(input_h) + param->pad_u_ + param->pad_d_ - + ((float)(param->kernel_h_) - 1) * (float)(param->dilation_h_)) / + (float)(param->stride_h_)); + output_w = ceil(((float)(input_w) + param->pad_l_ + param->pad_r_ - + ((float)(param->kernel_w_) - 1) * (float)(param->dilation_w_)) / + (float)(param->stride_w_)); + } + int out_shape[MAX_SHAPE_SIZE]; + size_t out_shape_size = 0; + ShapeSet(out_shape, &out_shape_size, input->shape_, input->shape_size_); + out_shape[1] = output_h; + out_shape[2] = output_w; + if (param->channel_multiplie_ != 1) { + return NNACL_ERR; + } + out_shape[3] = input_channel; // in_channel * out_channel + SetShapeArray(output, out_shape, out_shape_size); + return 0; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/depthwise_conv2d_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/depthwise_conv2d_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..d3b840a3a2718c564aac701e468df884d199d37b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/depthwise_conv2d_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_DEPTHWISE_CONV2D_INFER_H +#define MINDSPORE_NNACL_DEPTHWISE_CONV2D_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int DepthwiseConv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_DEPTHWISE_CONV2D_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/detection_post_process_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/detection_post_process_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..7ccc3f87de87fe1d91751ac68b3acdd41dec6192 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/detection_post_process_infer.c @@ -0,0 +1,83 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/detection_post_process_infer.h" +#include "nnacl/infer/infer_register.h" + +int DetectionPostProcessInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 4); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *boxes = inputs[0]; + const TensorC *scores = inputs[1]; + const TensorC *anchors = inputs[2]; + if (boxes->shape_size_ < 2 || scores->shape_size_ < 3 || anchors->shape_size_ < 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + + DetectionPostProcessParameter *param = (DetectionPostProcessParameter *)parameter; + if (scores->shape_[2] < param->num_classes_) { + return NNACL_ERR; + } + if (scores->shape_[2] - param->num_classes_ > 1) { + return NNACL_ERR; + } + if (boxes->shape_[1] != scores->shape_[1]) { + return NNACL_ERR; + } + if (boxes->shape_[1] != anchors->shape_[0]) { + return NNACL_ERR; + } + + TensorC *detected_boxes = outputs[0]; + TensorC *detected_classes = outputs[1]; + TensorC *detected_scores = outputs[2]; + TensorC *num_det = outputs[3]; + + detected_boxes->format_ = boxes->format_; + detected_boxes->data_type_ = kNumberTypeFloat32; + detected_classes->format_ = boxes->format_; + detected_classes->data_type_ = kNumberTypeFloat32; + detected_scores->format_ = boxes->format_; + detected_scores->data_type_ = kNumberTypeFloat32; + num_det->format_ = boxes->format_; + num_det->data_type_ = kNumberTypeFloat32; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + const int max_detections = param->max_detections_; + const int max_classes_per_detection = param->max_classes_per_detection_; + const int num_detected_boxes = (int)(max_detections * max_classes_per_detection); + detected_boxes->shape_size_ = 3; + detected_boxes->shape_[0] = 1; + detected_boxes->shape_[1] = num_detected_boxes; + detected_boxes->shape_[2] = 4; + detected_classes->shape_size_ = 2; + detected_classes->shape_[0] = 1; + detected_classes->shape_[1] = num_detected_boxes; + detected_scores->shape_size_ = 2; + detected_scores->shape_[0] = 1; + detected_scores->shape_[1] = num_detected_boxes; + num_det->shape_size_ = 1; + num_det->shape_[0] = 1; + + return NNACL_OK; +} + +REG_INFER(DetectionPostProcess, PrimType_DetectionPostProcess, DetectionPostProcessInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/detection_post_process_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/detection_post_process_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..6e74e56f7ea02fbc77a57cf88b1fcfdabe79e129 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/detection_post_process_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_DETECTION_POST_PROCESS_INFER_H +#define MINDSPORE_NNACL_DETECTION_POST_PROCESS_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/detection_post_process_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int DetectionPostProcessInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_DETECTION_POST_PROCESS_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/dropout_grad_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/dropout_grad_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..d103017c35b940d93a396057ac5740552342a1f1 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/dropout_grad_infer.c @@ -0,0 +1,40 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/dropout_grad_infer.h" +#include "nnacl/infer/infer_register.h" + +int DropoutGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, 2); + if (check_ret != NNACL_OK) { + return check_ret; + } + if (outputs_size < 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(output, input); + return NNACL_OK; +} + +REG_INFER(DropoutGrad, PrimType_DropoutGrad, DropoutGradInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/dropout_grad_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/dropout_grad_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..df8d9c5a277011ebf7bdae4bf68d28931c1f249f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/dropout_grad_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_DROPOUT_GRAD_INFER_H +#define MINDSPORE_NNACL_DROPOUT_GRAD_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int DropoutGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_DROPOUT_GRAD_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/dropout_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/dropout_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..0e14ea90ecaf03c04e934e8833d94c2f67440fd2 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/dropout_infer.c @@ -0,0 +1,42 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/dropout_infer.h" +#include "nnacl/infer/infer_register.h" + +int DropoutInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output0 = outputs[0]; + SetDataTypeFormat(output0, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(output0, input); + if (outputs_size > 1) { + TensorC *output1 = outputs[1]; + SetDataTypeFormat(output1, input); + SetShapeTensor(output1, input); + } + return NNACL_OK; +} + +REG_INFER(Dropout, PrimType_Dropout, DropoutInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/dropout_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/dropout_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..c592a9bfa60e843bb7eef958a8997a72db72f986 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/dropout_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_DROPOUT_INFER_H +#define MINDSPORE_NNACL_DROPOUT_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int DropoutInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_DROPOUT_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/dynamic_quant_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/dynamic_quant_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..47daa6b2fb283451b9e1182d795dfaa61663a1af --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/dynamic_quant_infer.c @@ -0,0 +1,42 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/dynamic_quant_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/dynamic_quant_parameter.h" + +int DynamicQuantInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + + DynamicQuantParameter *param = (DynamicQuantParameter *)parameter; + output->data_type_ = param->dst_type_; + NNACL_CHECK_TRUE_RET(output->data_type_ > kNumberTypeBegin && output->data_type_ < kNumberTypeEnd, NNACL_ERR); + output->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(output, input); + return NNACL_OK; +} + +REG_INFER(DynamicQuant, PrimType_DynamicQuant, DynamicQuantInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/dynamic_quant_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/dynamic_quant_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..6d9303104d05a5de27090f685f4413a062bc1306 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/dynamic_quant_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_DYNAMIC_QUANT_INFER_H +#define MINDSPORE_NNACL_DYNAMIC_QUANT_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int DynamicQuantInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_DYNAMIC_QUANT_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/embedding_lookup_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/embedding_lookup_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..110612f0fba833f0ea447e83ceece21509deba56 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/embedding_lookup_infer.c @@ -0,0 +1,77 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/embedding_lookup_infer.h" +#include "nnacl/infer/infer_register.h" + +int EmbeddingLookupInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret != NNACL_OK) { + return check_ret; + } + if (inputs_size < 2 || outputs_size != 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + + const TensorC *params_ = inputs[0]; + const TensorC *ids = inputs[inputs_size - 1]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, params_); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (params_->shape_size_ > MAX_SHAPE_SIZE || ids->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + int embedding_shape[MAX_SHAPE_SIZE] = {0}; + size_t embedding_shape_size = 0; + ShapeSet(embedding_shape, &embedding_shape_size, params_->shape_, params_->shape_size_); + int erase_ret = ShapeErase(embedding_shape, &embedding_shape_size, 0); + if (erase_ret != NNACL_OK) { + return NNACL_ERR; + } + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + ShapeSet(output_shape, &output_shape_size, ids->shape_, ids->shape_size_); + for (size_t i = 0; i < embedding_shape_size; ++i) { + if (output_shape_size >= MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + ShapePush(output_shape, &output_shape_size, embedding_shape[i]); + } + for (size_t i = 1; i < inputs_size - 1; ++i) { + if (inputs[i]->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + int embedding_shape_t[MAX_SHAPE_SIZE] = {0}; + size_t embedding_shape_t_size = 0; + ShapeSet(embedding_shape_t, &embedding_shape_t_size, inputs[i]->shape_, inputs[i]->shape_size_); + erase_ret = ShapeErase(embedding_shape_t, &embedding_shape_t_size, 0); + if (erase_ret != NNACL_OK) { + return NNACL_ERR; + } + bool t_equal = ShapeEqual(embedding_shape_t, embedding_shape_t_size, embedding_shape, embedding_shape_size); + if (!t_equal) { + return NNACL_INPUT_TENSOR_ERROR; + } + } + SetShapeArray(output, output_shape, output_shape_size); + return NNACL_OK; +} + +REG_INFER(EmbeddingLookup, PrimType_EmbeddingLookupFusion, EmbeddingLookupInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/embedding_lookup_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/embedding_lookup_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..581b1cd88661c18b247bed4d69d56925e894f37d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/embedding_lookup_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_EMBEDDING_LOOKUP_INFER_H +#define MINDSPORE_NNACL_EMBEDDING_LOOKUP_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int EmbeddingLookupInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_EMBEDDING_LOOKUP_INFER_H diff --git a/mindspore-lite/tools/graph_kernel/converter/expanders/strided_slice.cc b/mindspore-lite/ops/kernel/cpu/nnacl/infer/encoder_layer_infer.c similarity index 44% rename from mindspore-lite/tools/graph_kernel/converter/expanders/strided_slice.cc rename to mindspore-lite/ops/kernel/cpu/nnacl/infer/encoder_layer_infer.c index 461c0cb3b5061a2a83c5859dba41e6f3b2976393..e82aff681d860e8ff9ddd8c4f982fd151546289d 100644 --- a/mindspore-lite/tools/graph_kernel/converter/expanders/strided_slice.cc +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/encoder_layer_infer.c @@ -13,29 +13,24 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include +#include "nnacl/infer/encoder_layer_infer.h" +#include "nnacl/infer/infer_register.h" -#include - -#include "backend/common/graph_kernel/expanders/op_desc_registry.h" - -namespace mindspore::graphkernel::expanders { -class StridedSlice : public OpDesc { - public: - StridedSlice() {} - ~StridedSlice() = default; - - protected: - bool CheckInputs() override { - const size_t onnx_strided_slice_input_num = 5; - return inputs_info_.size() == onnx_strided_slice_input_num; +int EncoderLayerInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, C9NUM, C1NUM); + if (check_ret != NNACL_OK) { + return check_ret; } - - NodePtrList Expand(const NodePtrList &inputs) override { - std::vector shp; - (void)shp.emplace_back(outputs_info_[0].shape); - auto result = gb.Emit("StridedSliceOnnx", inputs, {{"output_shape", MakeValue(shp)}}); - return {result}; + const TensorC *input = inputs[FIRST_INPUT]; + TensorC *output0 = outputs[FIRST_INPUT]; + SetDataTypeFormat(output0, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; } -}; -EXPANDER_OP_DESC_REGISTER("StridedSlice", StridedSlice); -} // namespace mindspore::graphkernel::expanders + SetShapeTensor(output0, input); + return NNACL_OK; +} + +REG_INFER(EncoderLayer, PrimType_Inner_EncoderLayer, EncoderLayerInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/encoder_layer_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/encoder_layer_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..8a46f92af39e381c2746d64688550addc6871167 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/encoder_layer_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_ENCODER_LAYER_INFER_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_ENCODER_LAYER_INFER_H_ + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int EncoderLayerInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_ENCODER_LAYER_INFER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/expand_dims_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/expand_dims_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..b844698523a44e02993390295d89cab879d11e7c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/expand_dims_infer.c @@ -0,0 +1,64 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/expand_dims_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensor_c_utils.h" + +int ExpandDimsInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullOutputSize(inputs, inputs_size, outputs, outputs_size, parameter, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (parameter->quant_type_ == Quant_QuantWeight) { + output->data_type_ = kNumberTypeFloat32; + } + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (inputs_size < C2NUM) { + return NNACL_INPUT_TENSOR_ERROR; + } + + if (inputs[1]->data_ == NULL) { + return NNACL_INPUT_TENSOR_ERROR; + } + if (NNACLGetElementNum(inputs[1]) < 1) { + return NNACL_ERR; + } + int dim = ((int32_t *)(inputs[1]->data_))[0]; + if (dim < 0) { + dim += (int)(input->shape_size_) + 1; + } + if (dim > (int)(input->shape_size_)) { + return NNACL_INPUT_TENSOR_ERROR; + } + + ShapeSet(output->shape_, &(output->shape_size_), input->shape_, input->shape_size_); + int ret = ShapeInsert(output->shape_, &(output->shape_size_), dim, 1); + if (ret != NNACL_OK) { + return NNACL_ERR; + } + return NNACL_OK; +} + +REG_INFER(ExpandDims, PrimType_ExpandDims, ExpandDimsInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/expand_dims_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/expand_dims_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..170b76cd0768db56b378acd7be4334d2340ee15b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/expand_dims_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_EXPAND_DIMS_INFER_H +#define MINDSPORE_NNACL_EXPAND_DIMS_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ExpandDimsInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_EXPAND_DIMS_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/fft_imag_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/fft_imag_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..d794d7086e61710fbccdccb7237c3b65edcf273e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/fft_imag_infer.c @@ -0,0 +1,25 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/fft_imag_infer.h" +#include "nnacl/infer/infer_register.h" + +int FftImagInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + return FftInferShape(inputs, inputs_size, outputs, outputs_size, parameter); +} + +REG_INFER(FftImag, PrimType_FftImag, FftImagInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/fft_imag_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/fft_imag_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..8dcb2bb7265639b93f6e04f1bbdf364725dea3e6 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/fft_imag_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FFT_IMAG_INFER_H +#define MINDSPORE_NNACL_FFT_IMAG_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int FftImagInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FFT_IMAG_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/fft_real_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/fft_real_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..cd95ebb596c86d9b272e7e03390d144adb93eaf9 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/fft_real_infer.c @@ -0,0 +1,25 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/fft_real_infer.h" +#include "nnacl/infer/infer_register.h" + +int FftRealInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + return FftInferShape(inputs, inputs_size, outputs, outputs_size, parameter); +} + +REG_INFER(FftReal, PrimType_FftReal, FftRealInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/fft_real_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/fft_real_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..d3e89303259c7844b176aecad36ebf1ed7693aac --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/fft_real_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FFT_REAL_INFER_H +#define MINDSPORE_NNACL_FFT_REAL_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int FftRealInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FFT_REAL_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/fill_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/fill_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..b3a7caf1873748f0d2379d733a333976a35d6001 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/fill_infer.c @@ -0,0 +1,65 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/fill_infer.h" +#include "nnacl/infer/infer_register.h" + +int FillInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + const TensorC *dst_shape_tensor = inputs[1]; + if (dst_shape_tensor->data_type_ != kNumberTypeInt && dst_shape_tensor->data_type_ != kNumberTypeInt32) { + return NNACL_ERR; + } + const int32_t *dst_shape = (int32_t *)(dst_shape_tensor->data_); + int num_dims = 1; + if (dst_shape_tensor->shape_size_ != DIMENSION_1D) { + return NNACL_ERR; + } + for (size_t i = 0; i < dst_shape_tensor->shape_size_; ++i) { + if (INT_MUL_OVERFLOW(num_dims, dst_shape_tensor->shape_[i])) { + return NNACL_ERRCODE_MUL_OVERFLOW; + } + NNACL_CHECK_FALSE(dst_shape_tensor->shape_[i] < 0, NNACL_ERR); + num_dims *= dst_shape_tensor->shape_[i]; + } + if (num_dims != 0 && dst_shape == NULL) { + return NNACL_INFER_INVALID; + } + if (num_dims > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + for (int i = 0; i < num_dims; i++) { + ShapePush(output_shape, &output_shape_size, dst_shape[i]); + } + SetShapeArray(output, output_shape, output_shape_size); + return NNACL_OK; +} + +REG_INFER(Fill, PrimType_Fill, FillInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/fill_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/fill_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..9bc39a5e8aec4e11f0a52a99069b0add047ee384 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/fill_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FILL_INFER_H +#define MINDSPORE_NNACL_FILL_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int FillInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FILL_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/fillv2_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/fillv2_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..a8991aa3b090670966d4d81808af50993aa6707d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/fillv2_infer.c @@ -0,0 +1,62 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/fillv2_infer.h" +#include "nnacl/infer/infer_register.h" + +int FillV2InferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[1]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + const TensorC *dst_shape_tensor = inputs[0]; + const int32_t *dst_shape = (int32_t *)(dst_shape_tensor->data_); + int num_dims = 1; + if (dst_shape_tensor->shape_size_ != DIMENSION_1D) { + return NNACL_ERR; + } + for (size_t i = 0; i < dst_shape_tensor->shape_size_; ++i) { + if (INT_MUL_OVERFLOW(num_dims, dst_shape_tensor->shape_[i])) { + return NNACL_ERRCODE_MUL_OVERFLOW; + } + NNACL_CHECK_FALSE(dst_shape_tensor->shape_[i] < 0, NNACL_ERR); + num_dims *= dst_shape_tensor->shape_[i]; + } + if (num_dims != 0 && dst_shape == NULL) { + return NNACL_INFER_INVALID; + } + if (num_dims > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + for (int i = 0; i < num_dims; i++) { + ShapePush(output_shape, &output_shape_size, dst_shape[i]); + } + SetShapeArray(output, output_shape, output_shape_size); + return NNACL_OK; +} + +REG_INFER(FillV2, PrimType_FillV2, FillV2InferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/fillv2_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/fillv2_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..ccd6be9159324eea302ff153e9de29a4fcc39678 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/fillv2_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FILLV2_INFER_H +#define MINDSPORE_NNACL_FILLV2_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int FillV2InferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FILLV2_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/flatten_grad_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/flatten_grad_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..d5de8bd29d8b07860f8cfa21adbb29298271abc4 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/flatten_grad_infer.c @@ -0,0 +1,43 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/flatten_grad_infer.h" +#include "nnacl/infer/infer_register.h" + +int FlattenGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + int output_shape_size = inputs[1]->shape_[0]; + if (inputs[1]->data_ == NULL || output_shape_size > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + SetShapeArray(output, (int *)(inputs[1]->data_), (size_t)output_shape_size); + return NNACL_OK; +} + +REG_INFER(FlattenGrad, PrimType_FlattenGrad, FlattenGradInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/flatten_grad_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/flatten_grad_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..503eecbccf038bc691b190311a121e4d2524d236 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/flatten_grad_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FLATTEN_GRAD_INFER_INFER_H +#define MINDSPORE_NNACL_FLATTEN_GRAD_INFER_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int FlattenGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FLATTEN_GRAD_INFER_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/flatten_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/flatten_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..dfaa2be3a4af8b4d8f11572c7bf77ebc42e15b37 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/flatten_infer.c @@ -0,0 +1,64 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/flatten_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/flatten_parameter.h" + +int FlattenInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (input->shape_size_ <= 0 || input->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int input_shape[MAX_SHAPE_SIZE] = {0}; + size_t input_shape_size = 0; + ShapeSet(input_shape, &input_shape_size, input->shape_, input->shape_size_); + FlattenParameter *param = (FlattenParameter *)parameter; + int axis = param->axis_; + // The value for axis must be in the range[-r, r], where r is + // the rank of the input tensor.Negative value means counting + // dimensions from the back. + axis = axis < 0 ? (int)input_shape_size - axis : axis; + if (axis >= (int)input_shape_size) { + return NNACL_ERR; + } + int output_shape[2]; + output_shape[0] = axis == 0 ? 1 : input_shape[0]; + for (size_t i = 1; i < (size_t)axis; i++) { + output_shape[0] *= input_shape[i]; + } + output_shape[1] = input_shape[axis]; + for (size_t i = (size_t)axis + 1; i < input_shape_size; i++) { + output_shape[1] *= input_shape[i]; + } + SetShapeArray(output, output_shape, 2); + return NNACL_OK; +} + +REG_INFER(Flatten, PrimType_Flatten, FlattenInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/flatten_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/flatten_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..755c7b366475bf7eec9d08a2461fe9ae1f9c841b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/flatten_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FLATTEN_INFER_H +#define MINDSPORE_NNACL_FLATTEN_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int FlattenInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FLATTEN_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/format_transpose_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/format_transpose_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..0af43f31197c9115fba5167415298736b70ef5d6 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/format_transpose_infer.c @@ -0,0 +1,67 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/format_transpose_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/format_transpose_parameter.h" +#include "nnacl/tensor_c_utils.h" + +int FormatTransposeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + FormatTransposeParameter *param = (FormatTransposeParameter *)parameter; + output->format_ = (int)(param->dst_format_); + output->data_type_ = input->data_type_; + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ != DIMENSION_4D) { + SetShapeArray(output, input->shape_, input->shape_size_); + return NNACL_OK; + } + + int input_b = NNACLGetBatch(input); + int input_h = NNACLGetHeight(input); + int input_w = NNACLGetWidth(input); + int input_c = NNACLGetChannel(input); + + // set output shape + int out_shape[MAX_SHAPE_SIZE] = {0}; + out_shape[DIMENSION_0D] = input_b; + if (param->dst_format_ == Format_NCHW || param->dst_format_ == Format_NC4HW4 || param->dst_format_ == Format_NC8HW8) { + out_shape[DIMENSION_1D] = input_c; + out_shape[DIMENSION_2D] = input_h; + out_shape[DIMENSION_3D] = input_w; + } else if (param->dst_format_ == Format_NHWC) { + out_shape[DIMENSION_1D] = input_h; + out_shape[DIMENSION_2D] = input_w; + out_shape[DIMENSION_3D] = input_c; + } else { + return NNACL_ERR; + } + + SetShapeArray(output, out_shape, input->shape_size_); + return NNACL_OK; +} + +REG_INFER(FormatTranspose, PrimType_FormatTranspose, FormatTransposeInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/format_transpose_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/format_transpose_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..391b86e07426e233d400166b37ea243cd958dfe0 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/format_transpose_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FORMAT_TRANSPOSE_INFER_H +#define MINDSPORE_NNACL_FORMAT_TRANSPOSE_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int FormatTransposeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FORMAT_TRANSPOSE_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/fse_decoder_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/fse_decoder_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..7dab65efb08ba6c3bd5dca0bac86439b95db60a4 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/fse_decoder_infer.c @@ -0,0 +1,35 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/fse_decoder_infer.h" +#include "nnacl/infer/infer_register.h" + +size_t kInputSize = 7; + +int FseDecoderInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, kInputSize, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + const TensorC *cen_input = inputs[4]; + TensorC *output0 = outputs[FIRST_INPUT]; + SetDataTypeFormat(output0, cen_input); + + return NNACL_OK; +} + +REG_INFER(FseDecode, PrimType_Inner_FseDecode, FseDecoderInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/fse_decoder_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/fse_decoder_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..9754869e732a52164c7fc367a7dc673ba45a3bb1 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/fse_decoder_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_FSE_DECODER_INFER_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_FSE_DECODER_INFER_H_ + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int FseDecoderInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_FSE_DECODER_INFER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/full_connection_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/full_connection_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..14ea335cc166fa3974d50df4115dc5a8fbcccc0f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/full_connection_infer.c @@ -0,0 +1,92 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/full_connection_infer.h" +#include "nnacl/infer/infer_register.h" + +int FullConnectionInferPreJudge(const MatMulParameter *param, size_t inputs_size, const TensorC *input0) { + if ((param->has_bias_ && inputs_size != 3) || (!param->has_bias_ && inputs_size != 2)) { + return NNACL_INPUT_TENSOR_ERROR; + } + if (param->use_axis_ && (param->axis_ < 1 || param->axis_ > (int)(input0->shape_size_))) { + return NNACL_ERR; + } + return NNACL_OK; +} + +int FullConnectionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input0 = inputs[0]; + const TensorC *input1 = inputs[1]; + TensorC *output = outputs[0]; + MatMulParameter *param = (MatMulParameter *)parameter; + SetDataTypeFormat(output, input0); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + int pre_judge = FullConnectionInferPreJudge(param, inputs_size, input0); + if (pre_judge != NNACL_OK) { + return pre_judge; + } + int new_k = 1; + if (param->use_axis_) { + for (size_t i = (size_t)(param->axis_); i < input0->shape_size_; ++i) { + new_k *= input0->shape_[i]; + } + if (new_k != input1->shape_[1]) { + return NNACL_INPUT_TENSOR_ERROR; + } + } else { + new_k = input1->shape_[1]; + } + if (param->has_bias_) { + if (inputs[2]->shape_[0] != input1->shape_[0]) { + return NNACL_INPUT_TENSOR_ERROR; + } + } + if (inputs[0]->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int out_shape[MAX_SHAPE_SIZE]; + size_t out_shape_size = 0; + ShapeSet(out_shape, &out_shape_size, inputs[0]->shape_, inputs[0]->shape_size_); + if (param->use_axis_) { + out_shape_size = (size_t)(param->axis_) + 1; + out_shape[param->axis_] = input1->shape_[0]; + } else { + int total = 1; + for (size_t i = 0; i < input0->shape_size_; ++i) { + total *= input0->shape_[i]; + } + out_shape_size = 2; + if (new_k == 0) { + return NNACL_ERR; + } + int batch_size = total / new_k; + out_shape[0] = batch_size; + out_shape[1] = input1->shape_[0]; + } + SetShapeArray(output, out_shape, out_shape_size); + + return NNACL_OK; +} + +REG_INFER(FullConnection, PrimType_FullConnection, FullConnectionInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/full_connection_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/full_connection_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..68cc05a5a37425338ac54179ca6095ebd2875a19 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/full_connection_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FULL_CONNECTION_INFER_H +#define MINDSPORE_NNACL_FULL_CONNECTION_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/matmul_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int FullConnectionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FULL_CONNECTION_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/fused_batchnorm_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/fused_batchnorm_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..7a86f2398579c227f28daaf1ac665970165758d2 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/fused_batchnorm_infer.c @@ -0,0 +1,45 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/fused_batchnorm_infer.h" +#include "nnacl/infer/infer_register.h" + +int FusedBatchNormInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret != NNACL_OK) { + return check_ret; + } + + for (size_t i = 0; i < inputs_size; i++) { + if (outputs_size <= i) { + break; + } + SetShapeTensor(outputs[i], inputs[i]); + SetDataTypeFormat(outputs[i], inputs[i]); + } + if (outputs_size > 5) { + SetDataTypeFormat(outputs[5], inputs[0]); + outputs[5]->shape_size_ = 1; + outputs[5]->shape_[0] = 1; + } + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + return NNACL_OK; +} + +REG_INFER(FusedBatchNorm, PrimType_FusedBatchNorm, FusedBatchNormInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/fused_batchnorm_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/fused_batchnorm_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..5d6ccd2f6e736105331223e4905bbfba62de3d81 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/fused_batchnorm_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_FUSED_BATCHNORM_INFER_H +#define MINDSPORE_NNACL_FUSED_BATCHNORM_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int FusedBatchNormInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_FUSED_BATCHNORM_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/gather_d_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/gather_d_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..cdab600c87a6aaee9be1e1127496ae9f52274023 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/gather_d_infer.c @@ -0,0 +1,47 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/gather_d_infer.h" +#include "nnacl/infer/infer_register.h" + +int GatherDInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (ret != NNACL_OK) { + return ret; + } + const int input_size_limit = 3; + const int output_size_limit = 1; + if (inputs_size != input_size_limit || outputs_size != output_size_limit) { + return NNACL_ERR; + } + const TensorC *input = inputs[0]; + const TensorC *index = inputs[2]; + TensorC *output = outputs[0]; + output->data_type_ = input->data_type_; + if (parameter->quant_type_ == Quant_QuantWeight) { + output->data_type_ = kNumberTypeFloat32; + } + output->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + SetShapeTensor(output, index); + return NNACL_OK; +} + +REG_INFER(GatherD, PrimType_GatherD, GatherDInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/gather_d_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/gather_d_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..5e2c92be2540d12d42fb3a465985765cf96912db --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/gather_d_infer.h @@ -0,0 +1,33 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_GATHER_D_INFER_H +#define MINDSPORE_NNACL_GATHER_D_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/gather_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int GatherDInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_GATHER_D_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/gather_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/gather_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..eb52e2354591f4bcc4eb5a7368be7dc3742ab661 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/gather_infer.c @@ -0,0 +1,83 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/gather_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensor_c_utils.h" + +int GatherInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (ret != NNACL_OK) { + return ret; + } + const size_t kMinimumGradInputsNum = 3; + if (inputs_size < kMinimumGradInputsNum || outputs_size != 1) { + return NNACL_ERR; + } + const TensorC *input = inputs[0]; + const TensorC *indices = inputs[1]; + TensorC *output = outputs[0]; + output->data_type_ = input->data_type_; + if ((input->data_type_ == kNumberTypeInt8 || input->data_type_ == kNumberTypeInt16) && + (parameter->quant_type_ == Quant_QuantWeight || parameter->quant_type_ == Quant_QuantDynamic)) { + output->data_type_ = kNumberTypeFloat32; + } + output->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ > MAX_SHAPE_SIZE || indices->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + if (inputs[2]->data_ == NULL) { + return NNACL_NULL_PTR; + } + if (NNACLGetElementNum(inputs[2]) < 1) { + return NNACL_ERR; + } + int axis = *((int *)inputs[2]->data_); + if (axis < 0) { + axis += input->shape_size_; + } + int indices_shape[MAX_SHAPE_SIZE]; + size_t indices_shape_size = 0; + ShapeSet(indices_shape, &indices_shape_size, indices->shape_, indices->shape_size_); + size_t indices_rank = indices_shape_size; + int in_shape[MAX_SHAPE_SIZE] = {0}; + size_t in_shape_size = 0; + ShapeSet(in_shape, &in_shape_size, input->shape_, input->shape_size_); + if ((int)(in_shape_size) < axis + 1) { + return NNACL_ERR; + } + int out_shape[MAX_SHAPE_SIZE] = {0}; + size_t out_shape_size = 0; + ShapeSet(out_shape, &out_shape_size, in_shape, in_shape_size); + int erase_ret = ShapeErase(out_shape, &out_shape_size, axis); + if (erase_ret != NNACL_OK) { + return NNACL_ERR; + } + for (int i = (int)(indices_rank - 1); i >= 0; --i) { + ret = ShapeInsert(out_shape, &out_shape_size, axis, indices_shape[i]); + if (ret != NNACL_OK) { + return NNACL_ERR; + } + } + SetShapeArray(output, out_shape, out_shape_size); + return NNACL_OK; +} + +REG_INFER(Gather, PrimType_Gather, GatherInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/gather_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/gather_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..0c4d763c6b7c8a2ca736156e5e16b3a113df31df --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/gather_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_GATHER_INFER_H +#define MINDSPORE_NNACL_GATHER_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/gather_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int GatherInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_GATHER_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/gather_nd_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/gather_nd_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..5c9d979d8667ec2193543c61b3dcb93a30032daf --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/gather_nd_infer.c @@ -0,0 +1,59 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/gather_nd_infer.h" +#include "nnacl/infer/infer_register.h" + +int GatherNdInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + const TensorC *indices = inputs[1]; + TensorC *output = outputs[0]; + + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ > MAX_SHAPE_SIZE || indices->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int in_rank = (int)(input->shape_size_); + int indices_rank = (int)(indices->shape_size_); + for (int i = 0; i < indices_rank; i++) { + NNACL_CHECK_FALSE(indices->shape_[i] == 0, NNACL_ERR); + } + if (indices->shape_[indices_rank - 1] > in_rank) { + return NNACL_OK; + } + int i = 0; + int out_shape[MAX_SHAPE_SIZE] = {0}; + size_t out_shape_size = 0; + for (i = 0; i < indices_rank - 1; ++i) { + ShapePush(out_shape, &out_shape_size, indices->shape_[i]); + } + for (i = indices->shape_[indices_rank - 1]; i < in_rank; ++i) { + ShapePush(out_shape, &out_shape_size, input->shape_[i]); + } + SetShapeArray(output, out_shape, out_shape_size); + return NNACL_OK; +} + +REG_INFER(GatherNd, PrimType_GatherNd, GatherNdInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/gather_nd_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/gather_nd_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..cd90d84bedc6a4f21e31c2aad7cb15514893593a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/gather_nd_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_GATHER_ND_INFER_H +#define MINDSPORE_NNACL_GATHER_ND_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/fp32/gatherNd_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int GatherNdInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_GATHER_ND_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/glu_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/glu_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..9c265cf370696ab15291ae62fd074330f668ca63 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/glu_infer.c @@ -0,0 +1,48 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/infer/glu_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/glu_parameter.h" + +int GluInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(output, input); + GluParameter *param = (GluParameter *)parameter; + NNACL_CHECK_NULL_RETURN_ERR(param); + if (param->axis_ >= (int)input->shape_size_ || (param->axis_ < 0 && ((int)input->shape_size_ + param->axis_) < 0)) { + return NNACL_ERR; + } + int axis = param->axis_ > 0 ? param->axis_ : (int)input->shape_size_ + param->axis_; + if (axis < 0 || axis >= MAX_SHAPE_SIZE) { + return NNACL_BUFFER_OVERFLOW; + } + output->shape_[axis] /= 2; + return NNACL_OK; +} + +REG_INFER(GLU, PrimType_GLU, GluInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/glu_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/glu_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..95bbc8b24a949af694ec17e7b2aea1788cfcfee2 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/glu_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_GLU_INFER_H +#define MINDSPORE_NNACL_GLU_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/gru_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int GluInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_GLU_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/grid_sampler_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/grid_sampler_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..70738e3c4f5db744d56f84828ae6054b6b4fa073 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/grid_sampler_infer.c @@ -0,0 +1,47 @@ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/infer/grid_sampler_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/grid_sampler_parameter.h" + +int GridSamplerInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ != inputs[1]->shape_size_) { + return NNACL_INPUT_TENSOR_ERROR; + } + if (input->shape_size_ < DIMENSION_4D) { + return NNACL_INPUT_TENSOR_ERROR; + } + SetShapeTensor(output, input); + for (size_t i = DIMENSION_2D; i < input->shape_size_; ++i) { + output->shape_[i] = inputs[1]->shape_[i - 1]; + } + return NNACL_OK; +} + +REG_INFER(GridSampler, PrimType_Inner_GridSampler, GridSamplerInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/grid_sampler_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/grid_sampler_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..d312d76f3a1db5f7f4fd08c7fc1eedb5cd47cb2c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/grid_sampler_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_GRID_SAMPLER_INFER_H +#define MINDSPORE_NNACL_GRID_SAMPLER_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/gru_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int GridSamplerInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_GRID_SAMPLER_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/group_conv2d_grad_input_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/group_conv2d_grad_input_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..de5bf3faaf6551139a271897507e4d6a869a0c0d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/group_conv2d_grad_input_infer.c @@ -0,0 +1,45 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/group_conv2d_grad_input_infer.h" + +int GroupConv2dGradInputInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (ret != NNACL_OK) { + return ret; + } + if (inputs_size < 2 || outputs_size != 1) { + return NNACL_ERR; + } + + const TensorC *in0 = inputs[0]; + TensorC *out = outputs[0]; + + SetDataTypeFormat(out, in0); + + size_t shape_size = in0->shape_size_; + if (shape_size > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int shape_[MAX_SHAPE_SIZE]; + for (size_t i = 0; i < shape_size; i++) { + shape_[i] = in0->shape_[i]; + } + SetShapeArray(out, shape_, shape_size); + + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/group_conv2d_grad_input_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/group_conv2d_grad_input_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..e9cf44487533ca7ba32994e183f21d07870226d6 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/group_conv2d_grad_input_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_GROUP_CONV2D_GRAD_INPUT_INFER_H +#define MINDSPORE_NNACL_GROUP_CONV2D_GRAD_INPUT_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int GroupConv2dGradInputInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_GROUP_CONV2D_GRAD_INPUT_INFER_H diff --git a/mindspore-lite/src/extendrt/kernel/cuda/unary.cc b/mindspore-lite/ops/kernel/cpu/nnacl/infer/group_norm_infer.c similarity index 46% rename from mindspore-lite/src/extendrt/kernel/cuda/unary.cc rename to mindspore-lite/ops/kernel/cpu/nnacl/infer/group_norm_infer.c index b4e9779b96dad5cfc62fb58a4c306c9c3ed7228b..f4ebe963d4480b146246bf370aa4ffd45d4478c7 100644 --- a/mindspore-lite/src/extendrt/kernel/cuda/unary.cc +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/group_norm_infer.c @@ -14,24 +14,24 @@ * limitations under the License. */ -#include "src/extendrt/kernel/cuda/unary.h" -#include +#include "nnacl/infer/group_norm_infer.h" +#include "nnacl/infer/infer_register.h" -namespace mindspore::kernel { -int UnaryCudaKernel::Prepare() { - CudaKernel::Prepare(); - if (unary_helper_ == nullptr) { - unary_helper_ = std::make_shared>(type_name_); - helper_ = unary_helper_; +int GroupNormInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; } - int ret = ReSize(); - CHECK_NOT_EQUAL_RETURN(ret, RET_OK); - return RET_OK; -} -int UnaryCudaKernel::Run() { - int ret = unary_helper_->Process(input_device_ptrs_, output_device_ptrs_, work_device_ptrs_, stream_); - CHECK_NOT_EQUAL_RETURN(ret, RET_OK); - return RET_OK; + + const TensorC *input = inputs[0]; + TensorC *output0 = outputs[0]; + SetDataTypeFormat(output0, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(output0, input); + return NNACL_OK; } -// REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_Log, CudaKernelCreator) -} // namespace mindspore::kernel + +REG_INFER(GroupNorm, PrimType_GroupNormFusion, GroupNormInferShape) diff --git a/mindspore-lite/tools/graph_kernel/converter/expanders/shape.cc b/mindspore-lite/ops/kernel/cpu/nnacl/infer/group_norm_infer.h similarity index 54% rename from mindspore-lite/tools/graph_kernel/converter/expanders/shape.cc rename to mindspore-lite/ops/kernel/cpu/nnacl/infer/group_norm_infer.h index e5ebbcbae44bafe9b88279362073edfd254b3d95..dd758d044cb825f716d26050221b43b56fa674f7 100644 --- a/mindspore-lite/tools/graph_kernel/converter/expanders/shape.cc +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/group_norm_infer.h @@ -13,23 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_GROUP_NORM_INFER_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_GROUP_NORM_INFER_H_ -#include +#include "nnacl/infer/common_infer.h" -#include "backend/common/graph_kernel/expanders/op_desc_registry.h" +#ifdef __cplusplus +extern "C" { +#endif -namespace mindspore::graphkernel::expanders { -class Shape : public OpDesc { - public: - Shape() {} - ~Shape() = default; +int GroupNormInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); - protected: - NodePtrList Expand(const NodePtrList &inputs) override { - const auto &input_x = inputs[0]; - auto result = gb.Shape(input_x); - return {result}; - } -}; -EXPANDER_OP_DESC_REGISTER("Shape", Shape); -} // namespace mindspore::graphkernel::expanders +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_GROUP_NORM_INFER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/gru_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/gru_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..cd0887089d51605c9b8ef0afb833c483488742e2 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/gru_infer.c @@ -0,0 +1,92 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/gru_infer.h" +#include "nnacl/infer/infer_register.h" + +int GruInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 5, 6, 2); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + const TensorC *weight_gate = inputs[1]; + const TensorC *weight_recurrence = inputs[2]; + const TensorC *bias = inputs[3]; + TensorC *output = outputs[0]; + for (int i = 0; i < 2; i++) { + SetDataTypeFormat(outputs[i], input); + } + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + const int *in_shape = input->shape_; // seq_len, batch, input_size + const int *w_gate_shape = weight_gate->shape_; // num_direction, hidden_size * 3, input_size + const int *w_recu_shape = weight_recurrence->shape_; // num_direction, hidden_size * 3, hidden_size + const int *bias_shape = bias->shape_; // num_direction, hidden_size * 6 + if (input->shape_size_ != 3 || weight_gate->shape_size_ != 3 || weight_recurrence->shape_size_ != 3) { + return NNACL_ERR; + } + if (w_gate_shape[1] != w_recu_shape[1] || w_recu_shape[1] * 2 != bias_shape[1]) { + return NNACL_ERR; + } + if (inputs_size == 6) { + const int *seq_len_shape = inputs[5]->shape_; + if (seq_len_shape[0] > 1) { + return NNACL_ERR; + } + if (inputs[5]->shape_size_ != 1 && seq_len_shape[0] != in_shape[1]) { + return NNACL_ERR; + } + } + + int hidden_size = w_gate_shape[1] / 3; + // set output + if (input->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int out_shape[MAX_SHAPE_SIZE]; + size_t out_shape_size = 0; + ShapeSet(out_shape, &out_shape_size, in_shape, input->shape_size_); + out_shape[2] = hidden_size; + + GruParameter *param = (GruParameter *)parameter; + if (param->bidirectional_) { + int ret = ShapeInsert(out_shape, &out_shape_size, 1, 2); + if (ret != NNACL_OK) { + return NNACL_ERR; + } + } else { + int ret = ShapeInsert(out_shape, &out_shape_size, 1, 1); + if (ret != NNACL_OK) { + return NNACL_ERR; + } + } + SetShapeArray(output, out_shape, out_shape_size); + // set hidden state + int state_shape[MAX_SHAPE_SIZE]; + size_t state_shape_size = 0; + ShapeSet(state_shape, &state_shape_size, in_shape, input->shape_size_); + state_shape[0] = param->bidirectional_ ? 2 : 1; + state_shape[2] = hidden_size; + SetShapeArray(outputs[1], state_shape, state_shape_size); + return NNACL_OK; +} + +REG_INFER(GRU, PrimType_GRU, GruInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/gru_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/gru_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..14fc084aaacc883285c1232be343300b1067fc80 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/gru_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_GRU_INFER_H +#define MINDSPORE_NNACL_GRU_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/gru_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int GruInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_GRU_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/infer.h new file mode 100644 index 0000000000000000000000000000000000000000..cac96219734385a4ffc7f797179937ac1bd97328 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/infer.h @@ -0,0 +1,33 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_INFER_H_ +#define MINDSPORE_NNACL_INFER_INFER_H_ + +#include "nnacl/tensor_c.h" +#include "nnacl/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +typedef int (*InferShape)(const TensorC *const *inputs, size_t input_size, TensorC **outputs, size_t output_size, + OpParameter *parameter); + +InferShape GetInferFunc(int prim_type); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_INFER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/infer_register.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/infer_register.c new file mode 100644 index 0000000000000000000000000000000000000000..9b0c44c3f713e51f1de0682dd870951e688f8522 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/infer_register.c @@ -0,0 +1,450 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/infer/infer_register.h" + +#ifdef _MSC_VER +#include "nnacl/infer/activation_grad_infer.h" +#include "nnacl/infer/adam_infer.h" +#include "nnacl/infer/adam_weight_decay_infer.h" +#include "nnacl/infer/add_sub_grad_infer.h" +#include "nnacl/infer/addn_infer.h" +#include "nnacl/infer/affine_infer.h" +#include "nnacl/infer/all_gather_infer.h" +#include "nnacl/infer/apply_momentum_infer.h" +#include "nnacl/infer/argmin_max_infer.h" +#include "nnacl/infer/arithmetic_compare_infer.h" +#include "nnacl/infer/arithmetic_grad_infer.h" +#include "nnacl/infer/arithmetic_infer.h" +#include "nnacl/infer/assert_op_infer.h" +#include "nnacl/infer/assign_add_infer.h" +#include "nnacl/infer/assign_infer.h" +#include "nnacl/infer/attention_infer.h" +#include "nnacl/infer/encoder_layer_infer.h" +#include "nnacl/infer/audio_spectrogram_infer.h" +#include "nnacl/infer/batch_to_space_infer.h" +#include "nnacl/infer/bias_grad_infer.h" +#include "nnacl/infer/binary_cross_entropy_infer.h" +#include "nnacl/infer/bn_grad_infer.h" +#include "nnacl/infer/broadcast_to_infer.h" +#include "nnacl/infer/cast_infer.h" +#include "nnacl/infer/common_infer.h" +#include "nnacl/infer/concat_infer.h" +#include "nnacl/infer/constant_of_shape_infer.h" +#include "nnacl/infer/decoder_layer_infer.h" + +#ifdef MSLITE_ENABLE_CONTROLFLOW +#include "nnacl/infer/control/tensor_array_infer.h" +#include "nnacl/infer/control/tensor_array_read_infer.h" +#include "nnacl/infer/control/tensor_array_write_infer.h" +#include "nnacl/infer/control/tensorlist_fromtensor_infer.h" +#include "nnacl/infer/control/tensorlist_getitem_infer.h" +#include "nnacl/infer/control/tensorlist_reserve_infer.h" +#include "nnacl/infer/control/tensorlist_setitem_infer.h" +#include "nnacl/infer/control/tensorlist_stack_infer.h" +#endif +#include "nnacl/infer/conv2d_grad_filter_infer.h" +#include "nnacl/infer/conv2d_grad_input_infer.h" +#include "nnacl/infer/conv2d_infer.h" +#include "nnacl/infer/crop_and_resize_infer.h" +#include "nnacl/infer/crop_infer.h" +#include "nnacl/infer/cumsum_infer.h" +#include "nnacl/infer/deconv2d_infer.h" +#include "nnacl/infer/depth_to_space_infer.h" +#include "nnacl/infer/depthwise_conv2d_infer.h" +#include "nnacl/infer/detection_post_process_infer.h" +#include "nnacl/infer/dropout_grad_infer.h" +#include "nnacl/infer/dropout_infer.h" +#include "nnacl/infer/dynamic_quant_infer.h" +#include "nnacl/infer/embedding_lookup_infer.h" +#include "nnacl/infer/expand_dims_infer.h" +#include "nnacl/infer/fft_imag_infer.h" +#include "nnacl/infer/fft_real_infer.h" +#include "nnacl/infer/fill_infer.h" +#include "nnacl/infer/fillv2_infer.h" +#include "nnacl/infer/flatten_grad_infer.h" +#include "nnacl/infer/flatten_infer.h" +#include "nnacl/infer/full_connection_infer.h" +#include "nnacl/infer/fused_batchnorm_infer.h" +#include "nnacl/infer/gather_infer.h" +#include "nnacl/infer/gather_nd_infer.h" +#include "nnacl/infer/glu_infer.h" +#include "nnacl/infer/group_conv2d_grad_input_infer.h" +#include "nnacl/infer/gru_infer.h" +#include "nnacl/infer/instance_norm_infer.h" +#include "nnacl/infer/invert_permutation_infer.h" +#include "nnacl/infer/layer_norm_grad_infer.h" +#include "nnacl/infer/layer_norm_infer.h" +#include "nnacl/infer/lin_space_infer.h" +#include "nnacl/infer/log_softmax_infer.h" +#include "nnacl/infer/lstm_grad_data_infer.h" +#include "nnacl/infer/lstm_grad_infer.h" +#include "nnacl/infer/lstm_grad_weight_infer.h" +#include "nnacl/infer/lstm_infer.h" +#include "nnacl/infer/matmul_infer.h" +#include "nnacl/infer/max_min_grad_infer.h" +#include "nnacl/infer/mfcc_infer.h" +#include "nnacl/infer/nllloss_grad_infer.h" +#include "nnacl/infer/nllloss_infer.h" +#include "nnacl/infer/non_max_suppression_infer.h" +#include "nnacl/infer/one_hot_infer.h" +#include "nnacl/infer/pad_infer.h" +#include "nnacl/infer/pooling_grad_infer.h" +#include "nnacl/infer/pooling_infer.h" +#include "nnacl/infer/power_infer.h" +#include "nnacl/infer/prior_box_infer.h" +#include "nnacl/infer/quant_dtype_cast_infer.h" +#include "nnacl/infer/ragged_range_infer.h" +#include "nnacl/infer/random_normal_infer.h" +#include "nnacl/infer/random_standard_normal_infer.h" +#include "nnacl/infer/range_infer.h" +#include "nnacl/infer/rank_infer.h" +#include "nnacl/infer/reduce_infer.h" +#include "nnacl/infer/reduce_scatter_infer.h" +#include "nnacl/infer/reshape_infer.h" +#include "nnacl/infer/resize_grad_infer.h" +#include "nnacl/infer/resize_infer.h" +#include "nnacl/infer/rfft_infer.h" +#include "nnacl/infer/roi_pooling_infer.h" +#include "nnacl/infer/scatter_nd_infer.h" +#include "nnacl/infer/scatter_nd_update_infer.h" +#include "nnacl/infer/select_infer.h" +#include "nnacl/infer/sgd_infer.h" +#include "nnacl/infer/invalid_infer.h" +#ifndef RUNTIME_PASS_CLIP +#include "nnacl/infer/shape_fusion_infer.h" +#endif +#include "nnacl/infer/shape_infer.h" +#include "nnacl/infer/size_infer.h" +#include "nnacl/infer/slice_infer.h" +#include "nnacl/infer/softmax_cross_entropy_infer.h" +#include "nnacl/infer/softmax_infer.h" +#include "nnacl/infer/space_to_batch_infer.h" +#include "nnacl/infer/space_to_batch_nd_infer.h" +#include "nnacl/infer/space_to_depth_infer.h" +#include "nnacl/infer/sparse_softmax_cross_entropy_with_logits_infer.h" +#include "nnacl/infer/sparse_to_dense_infer.h" +#include "nnacl/infer/splice_infer.h" +#include "nnacl/infer/split_infer.h" +#include "nnacl/infer/split_with_over_lap_infer.h" +#include "nnacl/infer/squeeze_infer.h" +#include "nnacl/infer/stack_infer.h" +#include "nnacl/infer/strided_slice_grad_infer.h" +#include "nnacl/infer/strided_slice_infer.h" +#ifdef MSLITE_ENABLE_STRING_KERNEL +#include "nnacl/infer/string/custom_extract_features_infer.h" +#include "nnacl/infer/string/custom_normalize_infer.h" +#include "nnacl/infer/string/custom_predict_infer.h" +#include "nnacl/infer/string/hashtable_lookup_infer.h" +#include "nnacl/infer/string/lsh_projection_infer.h" +#include "nnacl/infer/string/skip_gram_infer.h" +#endif +#include "nnacl/infer/tile_infer.h" +#include "nnacl/infer/topk_infer.h" +#include "nnacl/infer/transpose_infer.h" +#include "nnacl/infer/uniform_real_infer.h" +#include "nnacl/infer/unique_infer.h" +#include "nnacl/infer/unsorted_segment_sum_infer.h" +#include "nnacl/infer/unsqueeze_infer.h" +#include "nnacl/infer/unstack_infer.h" +#include "nnacl/infer/where_infer.h" +#include "nnacl/infer/isfinite_infer.h" +#include "nnacl/infer/fse_decoder_infer.h" +#include "nnacl/infer/custom_gru_infer.h" + +InferShape g_infer_func[PrimType_MAX] = {0}; +InferShape g_inner_op_infer_func[PrimType_InnerOpMax - PrimType_InnerOpMin] = {0}; +void RegAllInferFunc1() { + g_infer_func[PrimType_NONE] = NULL; + g_infer_func[PrimType_Abs] = CommonInferShape; + g_infer_func[PrimType_AbsGrad] = CommonGradInferShape; + g_infer_func[PrimType_Activation] = CommonInferShape; + g_infer_func[PrimType_ActivationGrad] = ActivationGradInferShape; + g_infer_func[PrimType_Adam] = AdamInferShape; + g_infer_func[PrimType_AdamWeightDecay] = AdamWeightDecayInferShape; + g_infer_func[PrimType_AdderFusion] = Conv2dInferShape; + g_infer_func[PrimType_AddFusion] = ArithmeticInferShape; + g_infer_func[PrimType_AddGrad] = AddSubGradInferShape; + g_infer_func[PrimType_AddN] = AddnInferShape; + g_infer_func[PrimType_Affine] = AffineInferShape; + g_infer_func[PrimType_All] = NULL; + g_infer_func[PrimType_AllGather] = AllGatherInferShape; + g_infer_func[PrimType_ApplyMomentum] = ApplyMomentumInferShape; + g_infer_func[PrimType_ArgMaxFusion] = ArgMinMaxInferShape; + g_infer_func[PrimType_ArgMinFusion] = ArgMinMaxInferShape; + g_infer_func[PrimType_Assert] = AssertOpInferShape; + g_infer_func[PrimType_Assign] = AssignInferShape; + g_infer_func[PrimType_AssignAdd] = AssignAddInferShape; + g_infer_func[PrimType_Attention] = AttentionInferShape; + g_infer_func[PrimType_AudioSpectrogram] = AudioSpectrogramInferShape; + g_infer_func[PrimType_AvgPoolFusion] = PoolingInferShape; + g_infer_func[PrimType_AvgPoolGrad] = PoolingGradInferShape; + g_infer_func[PrimType_BatchNorm] = CommonInferShape; + g_infer_func[PrimType_BatchNormGrad] = BnGradInferShape; + g_infer_func[PrimType_BatchToSpace] = BatchToSpaceInferShape; + g_infer_func[PrimType_BatchToSpaceND] = NULL; + g_infer_func[PrimType_BiasAdd] = ArithmeticInferShape; + g_infer_func[PrimType_BiasAddGrad] = BiasGradInferShape; + g_infer_func[PrimType_BinaryCrossEntropy] = BinaryCrossEntropyInferShape; + g_infer_func[PrimType_BinaryCrossEntropyGrad] = CommonInferShape; + g_infer_func[PrimType_BroadcastTo] = BroadcastToInferShape; + g_infer_func[PrimType_Call] = InvalidInferShape; + g_infer_func[PrimType_Cast] = CastInferShape; + g_infer_func[PrimType_Ceil] = CommonInferShape; + g_infer_func[PrimType_Clip] = CommonInferShape; + g_infer_func[PrimType_Concat] = ConcatInferShape; + g_infer_func[PrimType_ConstantOfShape] = ConstantOfShapeInferShape; + g_infer_func[PrimType_Conv2DBackpropFilterFusion] = Conv2dGradFilterInferShape; + g_infer_func[PrimType_Conv2DBackpropInputFusion] = Conv2dGradInputInferShape; + g_infer_func[PrimType_Conv2DFusion] = Conv2dInferShape; + g_infer_func[PrimType_Conv2dTransposeFusion] = Deconv2dInferShape; + g_infer_func[PrimType_Cos] = CommonInferShape; + g_infer_func[PrimType_Crop] = CropInferShape; + g_infer_func[PrimType_CropAndResize] = CropAndResizeInferShape; + g_infer_func[PrimType_CumSum] = CumsumInferShape; + g_infer_func[PrimType_Custom] = NULL; +#ifdef MSLITE_ENABLE_STRING_KERNEL + g_infer_func[PrimType_CustomExtractFeatures] = CustomExtractFeaturesInferShape; +#endif +} + +void RegAllInferFunc2() { +#ifdef MSLITE_ENABLE_STRING_KERNEL + g_infer_func[PrimType_CustomNormalize] = CustomNormalizeInferShape; + g_infer_func[PrimType_CustomPredict] = CustomPredictInferShape; +#endif + g_infer_func[PrimType_DeConv2DGradFilter] = NULL; + g_infer_func[PrimType_Depend] = CommonInferShape; + g_infer_func[PrimType_DepthToSpace] = DepthToSpaceInferShape; + g_infer_func[PrimType_DetectionPostProcess] = DetectionPostProcessInferShape; + g_infer_func[PrimType_DivFusion] = ArithmeticInferShape; + g_infer_func[PrimType_DivGrad] = ArithmeticGradInferShape; + g_infer_func[PrimType_Dropout] = DropoutInferShape; + g_infer_func[PrimType_DropoutGrad] = DropoutGradInferShape; + g_infer_func[PrimType_DynamicQuant] = DynamicQuantInferShape; + g_infer_func[PrimType_Eltwise] = ArithmeticInferShape; + g_infer_func[PrimType_Elu] = CommonInferShape; + g_infer_func[PrimType_EmbeddingLookupFusion] = EmbeddingLookupInferShape; + g_infer_func[PrimType_Equal] = ArithmeticCompareInferShape; + g_infer_func[PrimType_Erf] = CommonInferShape; + g_infer_func[PrimType_ExpandDims] = ExpandDimsInferShape; + g_infer_func[PrimType_ExpFusion] = CommonInferShape; + g_infer_func[PrimType_FakeQuantWithMinMaxVars] = CommonInferShape; + g_infer_func[PrimType_FakeQuantWithMinMaxVarsPerChannel] = NULL; + g_infer_func[PrimType_FftImag] = FftImagInferShape; + g_infer_func[PrimType_FftReal] = FftRealInferShape; + g_infer_func[PrimType_Fill] = FillInferShape; + g_infer_func[PrimType_FillV2] = FillInferShape; + g_infer_func[PrimType_Flatten] = FlattenInferShape; + g_infer_func[PrimType_FlattenGrad] = FlattenGradInferShape; + g_infer_func[PrimType_Floor] = CommonInferShapeWithOneInput; + g_infer_func[PrimType_FloorDiv] = ArithmeticInferShape; + g_infer_func[PrimType_FloorMod] = ArithmeticInferShape; + g_infer_func[PrimType_FullConnection] = FullConnectionInferShape; + g_infer_func[PrimType_FusedBatchNorm] = FusedBatchNormInferShape; + g_infer_func[PrimType_Gather] = GatherInferShape; + g_infer_func[PrimType_GatherNd] = GatherNdInferShape; + g_infer_func[PrimType_GenOP] = NULL; + g_infer_func[PrimType_GLU] = GluInferShape; + g_infer_func[PrimType_Greater] = ArithmeticCompareInferShape; + g_infer_func[PrimType_GreaterEqual] = ArithmeticCompareInferShape; + g_infer_func[PrimType_GRU] = GruInferShape; +#ifdef MSLITE_ENABLE_STRING_KERNEL + g_infer_func[PrimType_HashtableLookup] = HashtableLoopupInferShape; +#endif + g_infer_func[PrimType_InstanceNorm] = InstanceNormInferShape; + g_infer_func[PrimType_InvertPermutation] = InvertPermutationInferShape; + g_infer_func[PrimType_IsFinite] = IsFiniteInferShape; + g_infer_func[PrimType_L2NormalizeFusion] = CommonInferShape; + g_infer_func[PrimType_LayerNormFusion] = LayerNormInferShape; + g_infer_func[PrimType_LayerNormGrad] = LayerNormGradInferShape; + g_infer_func[PrimType_LeakyRelu] = CommonInferShape; + g_infer_func[PrimType_Less] = ArithmeticCompareInferShape; + g_infer_func[PrimType_LessEqual] = ArithmeticCompareInferShape; + g_infer_func[PrimType_LinSpace] = LinSpaceInferShape; +} + +void RegAllInferFunc3() { + g_infer_func[PrimType_Log] = CommonInferShape; + g_infer_func[PrimType_LogGrad] = CommonGradInferShape; + g_infer_func[PrimType_LogicalAnd] = ArithmeticInferShape; + g_infer_func[PrimType_LogicalNot] = CommonInferShape; + g_infer_func[PrimType_LogicalOr] = ArithmeticInferShape; + g_infer_func[PrimType_LogSoftmax] = LogSoftmaxInferShape; + g_infer_func[PrimType_LpNormalization] = NULL; + g_infer_func[PrimType_LRN] = CommonInferShapeWithNHWC; +#ifdef MSLITE_ENABLE_STRING_KERNEL + g_infer_func[PrimType_LshProjection] = LshProjectionInferShape; +#endif + g_infer_func[PrimType_LSTM] = LstmInferShape; + g_infer_func[PrimType_LSTMGrad] = LstmGradInferShape; + g_infer_func[PrimType_LSTMGradData] = LstmGradDataInferShape; + g_infer_func[PrimType_LSTMGradWeight] = LstmGradWeightInferShape; + g_infer_func[PrimType_MatMulFusion] = MatmulInferShape; + g_infer_func[PrimType_Maximum] = ArithmeticInferShape; + g_infer_func[PrimType_MaximumGrad] = MaxMinGradInferShape; + g_infer_func[PrimType_MaxPoolFusion] = PoolingInferShape; + g_infer_func[PrimType_MaxPoolGrad] = PoolingGradInferShape; + g_infer_func[PrimType_SwitchLayer] = InvalidInferShape; + g_infer_func[PrimType_Mfcc] = MfccInferShape; + g_infer_func[PrimType_MIN] = NULL; + g_infer_func[PrimType_Minimum] = ArithmeticInferShape; + g_infer_func[PrimType_MinimumGrad] = MaxMinGradInferShape; + g_infer_func[PrimType_Mod] = ArithmeticInferShape; + g_infer_func[PrimType_MulFusion] = ArithmeticInferShape; + g_infer_func[PrimType_MulGrad] = ArithmeticGradInferShape; + g_infer_func[PrimType_Neg] = CommonInferShape; + g_infer_func[PrimType_NegGrad] = CommonGradInferShape; + g_infer_func[PrimType_NLLLoss] = NLLLossInferShape; + g_infer_func[PrimType_NLLLossGrad] = NLLLossGradInferShape; + g_infer_func[PrimType_NonMaxSuppression] = NonMaxSuppressionInferShape; + g_infer_func[PrimType_NonZero] = NULL; + g_infer_func[PrimType_NotEqual] = ArithmeticCompareInferShape; + g_infer_func[PrimType_OneHot] = OneHotInferShape; + g_infer_func[PrimType_OnesLike] = NULL; + g_infer_func[PrimType_PadFusion] = PadInferShape; + g_infer_func[PrimType_PartialFusion] = InvalidInferShape; + g_infer_func[PrimType_PowerGrad] = CommonGradInferShape; + g_infer_func[PrimType_PowFusion] = PowerInferShape; + g_infer_func[PrimType_PReLUFusion] = CommonInferShape; + g_infer_func[PrimType_PriorBox] = PriorBoxInferShape; + g_infer_func[PrimType_QuantDTypeCast] = QuantDtypeCastInferShape; + g_infer_func[PrimType_RaggedRange] = RaggedRangeInferShape; + g_infer_func[PrimType_RandomNormal] = RandomNormalInferShape; + g_infer_func[PrimType_RandomStandardNormal] = RandomStandardNormalInferShape; + g_infer_func[PrimType_Range] = RangeInferShape; + g_infer_func[PrimType_Rank] = RankInferShape; +} + +void RegAllInferFunc4() { + g_infer_func[PrimType_RealDiv] = ArithmeticInferShape; + g_infer_func[PrimType_Reciprocal] = CommonInferShape; + g_infer_func[PrimType_ReduceFusion] = ReduceInferShape; + g_infer_func[PrimType_ReduceScatter] = ReduceScatterInferShape; + g_infer_func[PrimType_Reshape] = ReshapeInferShape; + g_infer_func[PrimType_Resize] = ResizeInferShape; + g_infer_func[PrimType_ResizeGrad] = ResizeGradInferShape; + g_infer_func[PrimType_ReverseSequence] = CommonInferShape; + g_infer_func[PrimType_ReverseV2] = CommonInferShape; + g_infer_func[PrimType_Rfft] = RfftInferShape; + g_infer_func[PrimType_ROIPooling] = ROIPoolingInferShape; + g_infer_func[PrimType_Round] = CommonInferShape; + g_infer_func[PrimType_Rsqrt] = CommonInferShape; + g_infer_func[PrimType_RsqrtGrad] = NULL; + g_infer_func[PrimType_ScaleFusion] = CommonInferShape; + g_infer_func[PrimType_ScatterNd] = ScatterNdInferShape; + g_infer_func[PrimType_ScatterNdUpdate] = ScatterNdUpdateInferShape; + g_infer_func[PrimType_TensorScatterAdd] = ScatterNdUpdateInferShape; + g_infer_func[PrimType_Select] = SelectInferShape; + g_infer_func[PrimType_SGD] = SgdInferShape; + g_infer_func[PrimType_Shape] = ShapeInferShape; + g_infer_func[PrimType_SigmoidCrossEntropyWithLogits] = CommonInferShape; + g_infer_func[PrimType_SigmoidCrossEntropyWithLogitsGrad] = CommonInferShape; + g_infer_func[PrimType_Sin] = CommonInferShape; + g_infer_func[PrimType_Size] = SizeInferShape; +#ifdef MSLITE_ENABLE_STRING_KERNEL + g_infer_func[PrimType_SkipGram] = SkipGramInferShape; +#endif + g_infer_func[PrimType_SliceFusion] = SliceInferShape; + g_infer_func[PrimType_SmoothL1Loss] = CommonInferShape; + g_infer_func[PrimType_SmoothL1LossGrad] = CommonInferShape; + g_infer_func[PrimType_Softmax] = SoftMaxInferShape; + g_infer_func[PrimType_SoftmaxCrossEntropyWithLogits] = SoftmaxCrossEntropyInferShape; + g_infer_func[PrimType_SpaceToBatch] = SpaceToBatchInferShape; + g_infer_func[PrimType_SpaceToBatchND] = SpaceToBatchNdInferShape; + g_infer_func[PrimType_SpaceToDepth] = SpaceToDepthInferShape; + g_infer_func[PrimType_SparseSoftmaxCrossEntropyWithLogits] = SparseSoftmaxCrossEntropyWithLogitsInferShape; + g_infer_func[PrimType_SparseToDense] = SparseToDenseInferShape; + g_infer_func[PrimType_Splice] = SpliceInferShape; + g_infer_func[PrimType_Split] = SplitInferShape; + g_infer_func[PrimType_SplitWithOverlap] = SplitWithOverlapInferShape; + g_infer_func[PrimType_Sqrt] = CommonInferShape; + g_infer_func[PrimType_SqrtGrad] = NULL; + g_infer_func[PrimType_Square] = CommonInferShape; + g_infer_func[PrimType_SquaredDifference] = ArithmeticInferShape; + g_infer_func[PrimType_Squeeze] = SqueezeInferShape; + g_infer_func[PrimType_Stack] = StackInferShape; + g_infer_func[PrimType_StridedSlice] = StridedSliceInferShape; + g_infer_func[PrimType_StridedSliceGrad] = StridedSliceGradInferShape; + g_infer_func[PrimType_SubFusion] = ArithmeticInferShape; + g_infer_func[PrimType_SubGrad] = AddSubGradInferShape; +} + +void RegAllInferFunc5() { + g_infer_func[PrimType_Switch] = InvalidInferShape; +#ifdef MSLITE_ENABLE_CONTROLFLOW + g_infer_func[PrimType_TensorArray] = TensorArrayInferShape; + g_infer_func[PrimType_TensorArrayRead] = TensorArrayReadInferShape; + g_infer_func[PrimType_TensorArrayWrite] = TensorArrayWriteInferShape; + g_infer_func[PrimType_TensorListFromTensor] = TensorListFromTensorInferShape; + g_infer_func[PrimType_TensorListGetItem] = TensorListGetItemInferShape; + g_infer_func[PrimType_TensorListReserve] = TensorListReserveInferShape; + g_infer_func[PrimType_TensorListSetItem] = TensorListSetItemInferShape; + g_infer_func[PrimType_TensorListStack] = TensorListStackInferShape; +#endif + g_infer_func[PrimType_TileFusion] = TileInferShape; + g_infer_func[PrimType_TopKFusion] = TopKInferShape; + g_infer_func[PrimType_Transpose] = TransposeInferShape; + g_infer_func[PrimType_UniformReal] = UniformRealInferShape; + g_infer_func[PrimType_Unique] = UniqueInferShape; + g_infer_func[PrimType_UnsortedSegmentSum] = UnsortedSegmentSumInferShape; + g_infer_func[PrimType_Unsqueeze] = UnsqueezeInferShape; + g_infer_func[PrimType_Unstack] = UnstackInferShape; + g_infer_func[PrimType_Where] = WhereInferShape; + g_infer_func[PrimType_ZerosLike] = CommonInferShape; + + // fused operators. + g_inner_op_infer_func[PrimType_Inner_GltextureToOpencl - PrimType_InnerOpMin] = NULL; + g_inner_op_infer_func[PrimType_Inner_Identity - PrimType_InnerOpMin] = NULL; +#ifndef RUNTIME_PASS_CLIP + g_inner_op_infer_func[PrimType_Inner_ShapeFusion - PrimType_InnerOpMin] = ShapeFusionInferShape; + g_inner_op_infer_func[PrimType_Inner_EncoderLayer - PrimType_InnerOpMin] = EncoderLayerInferShape; + g_inner_op_infer_func[PrimType_Inner_DecoderLayer - PrimType_InnerOpMin] = DecoderLayerInferShape; + g_inner_op_infer_func[PrimType_Inner_FseDecode - PrimType_InnerOpMin] = FseDecoderInferShape; +#endif + g_inner_op_infer_func[PrimType_Inner_CustomGru - PrimType_InnerOpMin] = CustomGruInferShape; + g_inner_op_infer_func[PrimType_Inner_ToFormat - PrimType_InnerOpMin] = NULL; +} + +#else +__attribute__((init_priority(101))) InferShape g_infer_func[PrimType_MAX] = {0}; +__attribute__((init_priority(101))) InferShape g_inner_op_infer_func[PrimType_InnerOpMax - PrimType_InnerOpMin] = {0}; +#endif // _MSC_VER + +InferShape GetInferFunc(int prim_type) { +#ifdef _MSC_VER + if (g_infer_func[PrimType_Abs] == NULL) { + RegAllInferFunc1(); + RegAllInferFunc2(); + RegAllInferFunc3(); + RegAllInferFunc4(); + RegAllInferFunc5(); + } +#endif + if (prim_type > PrimType_MIN && prim_type < PrimType_MAX) { + return g_infer_func[prim_type]; + } else if (prim_type >= PrimType_InnerOpMin && prim_type < PrimType_InnerOpMax) { + return g_inner_op_infer_func[prim_type - PrimType_InnerOpMin]; + } + return NULL; +} + +void RegInfer(int prim_type, InferShape func) { + if (prim_type > PrimType_MIN && prim_type < PrimType_MAX) { + g_infer_func[prim_type] = func; + } else if (prim_type >= PrimType_InnerOpMin && prim_type < PrimType_InnerOpMax) { + g_inner_op_infer_func[prim_type - PrimType_InnerOpMin] = func; + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/infer_register.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/infer_register.h new file mode 100644 index 0000000000000000000000000000000000000000..c33ec1bc0bba98052ebddedb1e270cac1216300a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/infer_register.h @@ -0,0 +1,39 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_INFER_REGISTER_H_ +#define MINDSPORE_NNACL_INFER_INFER_REGISTER_H_ + +#include "nnacl/tensor_c.h" +#include "nnacl/op_base.h" +#include "nnacl/infer/infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void RegInfer(int prim_type, InferShape func); + +#ifdef _MSC_VER +#define REG_INFER(op, type, func) +#else +#define REG_INFER(op, type, func) \ + __attribute__((constructor(102))) void Reg##op##Infer() { RegInfer(type, func); } +#endif + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_INFER_REGISTER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/instance_norm_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/instance_norm_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..68d77030ccb5139b7b62143d40c1af1357946a80 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/instance_norm_infer.c @@ -0,0 +1,46 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/instance_norm_infer.h" +#include "nnacl/infer/crop_infer.h" +#include "nnacl/infer/infer_register.h" + +int InstanceNormInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + if (parameter == NULL || inputs[0] == NULL || outputs[0] == NULL) { + return NNACL_NULL_PTR; + } + TensorC *output = outputs[0]; + SetDataTypeFormat(output, inputs[0]); + if (output->format_ == Format_NC4HW4) { + output->format_ = Format_NHWC; + } + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(output, inputs[0]); + if (inputs[0]->format_ != Format_NC4HW4) { + return NNACL_OK; + } + if (output->shape_size_ <= DIMENSION_2D) { + return NNACL_OK; + } + int channel = output->shape_[1]; + ShapeErase(output->shape_, &output->shape_size_, 1); + ShapePush(output->shape_, &output->shape_size_, channel); + return NNACL_OK; +} +REG_INFER(InstanceNorm, PrimType_InstanceNorm, InstanceNormInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/instance_norm_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/instance_norm_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..bdcd1c6f19bd757f3ffc6103705702346a36cbc5 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/instance_norm_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INSTANCE_NORM_INFER_H +#define MINDSPORE_NNACL_INSTANCE_NORM_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int InstanceNormInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INSTANCE_NORM_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/invalid_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/invalid_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..2951a05c16c8cd7f29770d1e3dcf2812daa03db3 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/invalid_infer.c @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/invalid_infer.h" +#include "nnacl/infer/infer_register.h" + +int InvalidInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret != NNACL_OK) { + return check_ret; + } + return NNACL_INFER_INVALID; +} + +REG_INFER(PartialFusion, PrimType_PartialFusion, InvalidInferShape) +REG_INFER(Switch, PrimType_Switch, InvalidInferShape) +REG_INFER(Call, PrimType_Call, InvalidInferShape) +REG_INFER(SwitchLayer, PrimType_SwitchLayer, InvalidInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/invalid_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/invalid_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..a77a2d5c32a3442e413228f72b728ed2deac0fa0 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/invalid_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INVALID_INFER_H +#define MINDSPORE_NNACL_INVALID_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int InvalidInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INVALID_INFER_H diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_runtime.cc b/mindspore-lite/ops/kernel/cpu/nnacl/infer/invert_permutation_infer.c similarity index 42% rename from mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_runtime.cc rename to mindspore-lite/ops/kernel/cpu/nnacl/infer/invert_permutation_infer.c index 0e3069ffdd98281688ddcb9c801f2da35bf9f2ac..4e4d2d1557c5ad96f98b11bcfa38483fcd045969 100644 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_runtime.cc +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/invert_permutation_infer.c @@ -14,39 +14,30 @@ * limitations under the License. */ -#include "src/extendrt/delegate/tensorrt/tensorrt_runtime.h" -#include -#include +#include "nnacl/infer/invert_permutation_infer.h" +#include "nnacl/infer/infer_register.h" -namespace mindspore::lite { -int TensorRTRuntime::Init() { - if (is_init_) { - return RET_OK; +int InvertPermutationInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; } - builder_ = nvinfer1::createInferBuilder(this->logger_); - if (builder_ == nullptr) { - MS_LOG(ERROR) << "create infer builder failed."; - return RET_ERROR; - } - builder_->setMaxBatchSize(MAX_BATCH_SIZE); - allocator_ = new (std::nothrow) TensorRTAllocator(); - if (allocator_ == nullptr) { - MS_LOG(ERROR) << "Create allocator failed."; - return RET_ERROR; - } - is_init_ = true; - return RET_OK; -} -TensorRTRuntime::~TensorRTRuntime() { - if (builder_ != nullptr) { - builder_->destroy(); - builder_ = nullptr; + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; } - if (allocator_ != nullptr) { - allocator_->ClearDeviceMem(); - delete allocator_; - allocator_ = nullptr; + if (input->data_type_ != kNumberTypeInt32) { + return NNACL_ERR; } + if (input->shape_size_ != 1) { + return NNACL_ERR; + } + SetShapeTensor(output, input); + return NNACL_OK; } -} // namespace mindspore::lite + +REG_INFER(InvertPermutation, PrimType_InvertPermutation, InvertPermutationInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/invert_permutation_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/invert_permutation_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..580ba06fbc7831a45b5e6df285ed28aa3bb4b1cb --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/invert_permutation_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INVERT_PERMUTATION_INFER_H +#define MINDSPORE_NNACL_INVERT_PERMUTATION_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int InvertPermutationInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INVERT_PERMUTATION_INFER_H diff --git a/mindspore-lite/tools/graph_kernel/converter/expanders/transpose.cc b/mindspore-lite/ops/kernel/cpu/nnacl/infer/isfinite_infer.c similarity index 42% rename from mindspore-lite/tools/graph_kernel/converter/expanders/transpose.cc rename to mindspore-lite/ops/kernel/cpu/nnacl/infer/isfinite_infer.c index 3de6fb9e116e88d34dc315fa9d22ab837642c854..b55853fc6df1c44a3a219fe1ab0d3123452ea7a0 100644 --- a/mindspore-lite/tools/graph_kernel/converter/expanders/transpose.cc +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/isfinite_infer.c @@ -14,29 +14,29 @@ * limitations under the License. */ -#include +#include "nnacl/infer/isfinite_infer.h" +#include "nnacl/infer/infer_register.h" -#include "backend/common/graph_kernel/expanders/op_desc_registry.h" -#include "tools/graph_kernel/converter/expanders/activation.h" -#include "mindapi/base/types.h" -#include "ir/dtype.h" - -namespace mindspore::graphkernel::expanders { -class Transpose : public OpDesc { - public: - Transpose() { - std::initializer_list attrs{"perm"}; - (void)validators_.emplace_back(std::make_unique(attrs)); +int IsFiniteInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; } - ~Transpose() = default; - protected: - NodePtrList Expand(const NodePtrList &inputs) override { - const auto &input_x = inputs[0]; - auto perm = GetValue(attrs_["perm"]); - auto result = gb.Transpose(input_x, perm); - return {result}; + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + + output->data_type_ = kNumberTypeBool; + output->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; } -}; -EXPANDER_OP_DESC_REGISTER("Transpose", Transpose); -} // namespace mindspore::graphkernel::expanders + for (size_t i = 0; i < input->shape_size_; i++) { + output->shape_[i] = input->shape_[i]; + } + output->shape_size_ = input->shape_size_; + return NNACL_OK; +} + +REG_INFER(IsFinite, PrimType_IsFinite, IsFiniteInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/isfinite_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/isfinite_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..57f912a3dd9ad24b9129bb8364dfc10b3e12a0d6 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/isfinite_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_ISFINITE_INFER_H_ +#define MINDSPORE_NNACL_ISFINITE_INFER_H_ + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int IsFiniteInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ISFINITE_INFER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/layer_norm_grad_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/layer_norm_grad_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..82b8854da33ba7f7a52fd88ded39db80f9ed797a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/layer_norm_grad_infer.c @@ -0,0 +1,57 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/infer/layer_norm_grad_infer.h" +#include "nnacl/infer/common_infer.h" +#include "nnacl/fp32_grad/layernormgrad_parameter.h" +#include "nnacl/infer/infer_register.h" + +int LayerNormGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 5, 3); + if (check_ret != NNACL_OK) { + return check_ret; + } + LayerNormGradParameter *param = (LayerNormGradParameter *)parameter; + const TensorC *input_x = inputs[0]; + TensorC *output_dx = outputs[0]; + TensorC *output_dg = outputs[1]; + TensorC *output_db = outputs[2]; + SetDataTypeFormat(output_dx, input_x); + SetDataTypeFormat(output_dg, input_x); + SetDataTypeFormat(output_db, input_x); + SetShapeTensor(output_dx, input_x); + int begin_params_axis = param->begin_params_axis_; + if (param->begin_params_axis_ < 0) { + begin_params_axis += (int)(input_x->shape_size_); + } + size_t size = 0; + if (input_x->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + for (int i = begin_params_axis; i < input_x->shape_size_; i++) { + if (size >= MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + output_dg->shape_[size] = input_x->shape_[i]; + output_db->shape_[size] = input_x->shape_[i]; + size++; + } + output_db->shape_size_ = size; + output_dg->shape_size_ = size; + return NNACL_OK; +} + +REG_INFER(LayerNormGrad, PrimType_LayerNormGrad, LayerNormGradInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/layer_norm_grad_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/layer_norm_grad_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..d554b3cd4863b5c6870fbcf20e45f19aad9036bc --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/layer_norm_grad_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_LAYER_NORM_GRAD_INFER_H_ +#define MINDSPORE_NNACL_INFER_LAYER_NORM_GRAD_INFER_H_ + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int LayerNormGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_LAYER_NORM_GRAD_INFER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/layer_norm_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/layer_norm_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..5f460b16563ed8395fa263fa065f02eb22788943 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/layer_norm_infer.c @@ -0,0 +1,68 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/layer_norm_infer.h" +#include "nnacl/infer/infer_register.h" + +int LayerNormInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + if ((inputs_size != 1 && inputs_size != 3) || (outputs_size != 1 && outputs_size != 3)) { + return NNACL_INPUT_TENSOR_ERROR; + } + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + LayerNormParameter *param = (LayerNormParameter *)parameter; + NNACL_CHECK_NULL_RETURN_ERR(param); + if (input->shape_size_ > COMM_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + if (param->begin_params_axis_ < (-1 * (int)(input->shape_size_)) || + param->begin_params_axis_ >= (int)(input->shape_size_)) { + return NNACL_PARAM_INVALID; + } + param->begin_norm_axis_ = + param->begin_norm_axis_ < 0 ? param->begin_norm_axis_ + ((int)(input->shape_size_)) : param->begin_norm_axis_; + SetShapeTensor(output, input); + // take care of other outputs + if (outputs_size == 3) { + TensorC *output_mean = outputs[1]; + TensorC *output_var = outputs[2]; + SetDataTypeFormat(output_mean, input); + SetDataTypeFormat(output_var, input); + int size = 0; + NNACL_CHECK_TRUE_RET(param->begin_norm_axis_ <= MAX_SHAPE_SIZE, NNACL_ERR); + for (; size < param->begin_norm_axis_; size++) { + output_mean->shape_[size] = input->shape_[size]; + output_var->shape_[size] = input->shape_[size]; + } + output_mean->shape_size_ = (size_t)size; + output_var->shape_size_ = (size_t)size; + } + + return NNACL_OK; +} + +REG_INFER(LayerNormFusion, PrimType_LayerNormFusion, LayerNormInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/layer_norm_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/layer_norm_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..03fdc4293601ea60ee492bdf4f765843d10d1fc2 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/layer_norm_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_LAYER_NORM_INFER_H +#define MINDSPORE_NNACL_LAYER_NORM_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/layer_norm_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int LayerNormInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_LAYER_NORM_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/lin_space_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/lin_space_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..5b702e79865e679c8b6632fee803c07d24342110 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/lin_space_infer.c @@ -0,0 +1,48 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/lin_space_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensor_c_utils.h" + +int LinSpaceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + + output->data_type_ = input->data_type_; + output->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (NNACLGetElementNum(inputs[2]) < 1) { + return NNACL_ERR; + } + int *num = (int *)(inputs[2]->data_); + if (num == NULL) { + return NNACL_INFER_INVALID; + } + output->shape_size_ = 1; + output->shape_[0] = num[0]; + return NNACL_OK; +} + +REG_INFER(LinSpace, PrimType_LinSpace, LinSpaceInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/lin_space_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/lin_space_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..18a0d015d5978ac00f826b82a3cef29bdbed9e49 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/lin_space_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_LIN_SPACE_INFER_H +#define MINDSPORE_NNACL_LIN_SPACE_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int LinSpaceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_LIN_SPACE_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/log_softmax_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/log_softmax_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..b43a6ffd6aef01be93edf1b9e281cb227bdc1f5b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/log_softmax_infer.c @@ -0,0 +1,51 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/log_softmax_infer.h" +#include "nnacl/infer/infer_register.h" + +int LogSoftmaxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + const int input_size_limit = 1; + const int output_size_limit = 1; + if (inputs_size != input_size_limit || outputs_size != output_size_limit) { + return NNACL_ERR; + } + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ > 5) { + return NNACL_ERR; + } + SetShapeTensor(output, input); + SoftmaxParameter *param = (SoftmaxParameter *)parameter; + NNACL_CHECK_NULL_RETURN_ERR(param); + if (param->axis_ < (-1 * (int)(input->shape_size_)) || param->axis_ >= (int)(input->shape_size_)) { + return NNACL_PARAM_INVALID; + } + return NNACL_OK; +} + +REG_INFER(LogSoftmax, PrimType_LogSoftmax, LogSoftmaxInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/log_softmax_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/log_softmax_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..4603bb148407f6ecee2a7e752875e6e3043e1733 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/log_softmax_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_NNACL_LOG_SOFTMAX_INFER_H +#define MINDSPORE_LITE_NNACL_LOG_SOFTMAX_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/softmax_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int LogSoftmaxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_LITE_NNACL_LOG_SOFTMAX_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/lstm_grad_data_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/lstm_grad_data_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..131086b6efbb518819af6c87aed7fe4bca3cb703 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/lstm_grad_data_infer.c @@ -0,0 +1,60 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/lstm_grad_data_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/infer/common_infer.h" +#include "nnacl/fp32/lstm_fp32.h" +#include "nnacl/fp32_grad/lstm_grad_fp32.h" + +int LstmGradDataInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 9, 3); + if (check_ret != NNACL_OK) { + return check_ret; + } + LstmGradParameter *p = (LstmGradParameter *)parameter; + const TensorC *Y = inputs[SECOND_INPUT]; + const TensorC *H = inputs[THIRD_INPUT]; + const TensorC *C = inputs[FOURTH_INPUT]; + const TensorC *weight = inputs[FIFTH_INPUT]; + + int out_shape[MAX_SHAPE_SIZE]; + size_t out_shape_size = 0; + + for (int i = 0; i < outputs_size; i++) { + SetDataTypeFormat(outputs[i], Y); + } + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (Y->shape_size_ != C3NUM || weight->shape_size_ != C3NUM) { + return NNACL_ERR; + } + ShapePush(out_shape, &out_shape_size, Y->shape_[out_shape_size]); + ShapePush(out_shape, &out_shape_size, Y->shape_[out_shape_size]); + ShapePush(out_shape, &out_shape_size, p->input_size_); + + SetShapeArray(outputs[FIRST_INPUT], out_shape, C3NUM); + SetShapeArray(outputs[SECOND_INPUT], H->shape_, H->shape_size_); + SetShapeArray(outputs[THIRD_INPUT], C->shape_, C->shape_size_); + + return NNACL_OK; +} + +REG_INFER(LSTMGradData, PrimType_LSTMGradData, LstmGradDataInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/lstm_grad_data_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/lstm_grad_data_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..c8988e41957d2e0798fd857ac4a815ce254cf0b1 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/lstm_grad_data_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_LSTM_GRAD_DATA_INFER_H +#define MINDSPORE_NNACL_LSTM_GRAD_DATA_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/fp32/lstm_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int LstmGradDataInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_LSTM_GRAD_DATA_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/lstm_grad_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/lstm_grad_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..ddad7f2b852972d43857cb62c1c334f6d57641fa --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/lstm_grad_infer.c @@ -0,0 +1,54 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/lstm_grad_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/infer/common_infer.h" +#include "nnacl/fp32_grad/lstm_grad_fp32.h" + +int LstmGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 11, 4); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + const TensorC *H = inputs[1]; + const TensorC *C = inputs[2]; + const TensorC *weight = inputs[3]; + TensorC *output = outputs[0]; + for (size_t i = 0; i < outputs_size; i++) { + SetDataTypeFormat(outputs[i], input); + } + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (input->shape_size_ != 3 || weight->shape_size_ != 3) { + return NNACL_ERR; + } + + SetShapeArray(output, input->shape_, input->shape_size_); + SetShapeArray(outputs[SECOND_INPUT], H->shape_, H->shape_size_); + SetShapeArray(outputs[THIRD_INPUT], C->shape_, C->shape_size_); + SetShapeArray(outputs[FOURTH_INPUT], weight->shape_, weight->shape_size_); + + return NNACL_OK; +} + +REG_INFER(LSTMGrad, PrimType_LSTMGrad, LstmGradInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/lstm_grad_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/lstm_grad_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..ea764171339bfed0e8090367850cf1f9b93cee9a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/lstm_grad_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_LSTM_GRAD_INFER_H +#define MINDSPORE_NNACL_LSTM_GRAD_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/fp32/lstm_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int LstmGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_LSTM_GRAD_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/lstm_grad_weight_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/lstm_grad_weight_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..7c480bfa0ee99ec766a99957fbbabe1d27a271ff --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/lstm_grad_weight_infer.c @@ -0,0 +1,61 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/lstm_grad_weight_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/infer/common_infer.h" +#include "nnacl/fp32_grad/lstm_grad_fp32.h" + +int LstmGradWeightInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 5, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[FIRST_INPUT]; + const TensorC *H = inputs[SECOND_INPUT]; + const TensorC *Y = inputs[THIRD_INPUT]; + + TensorC *output = outputs[FIRST_INPUT]; + for (int i = 0; i < outputs_size; i++) { + SetDataTypeFormat(outputs[i], input); + } + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (input->shape_size_ != C3NUM || H->shape_size_ != C3NUM || Y->shape_size_ != C3NUM) { + return NNACL_ERR; + } + LstmGradParameter *param = (LstmGradParameter *)parameter; + int has_bias = param->has_bias_; + int output_shape[3] = {0, 1, 1}; + int gate_size = 4 * param->hidden_size_; + output_shape[0] += gate_size * param->input_size_; + output_shape[0] += gate_size * param->hidden_size_; + if (has_bias) { + output_shape[0] += C2NUM * gate_size; + } + int dir_mul = (param->bidirectional_) ? C2NUM : C1NUM; + output_shape[0] *= dir_mul; + SetShapeArray(output, output_shape, C3NUM); + + return NNACL_OK; +} + +REG_INFER(LSTMGradWeight, PrimType_LSTMGradWeight, LstmGradWeightInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/lstm_grad_weight_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/lstm_grad_weight_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..3cf6d44da07759f243a8ced22eda0d47df9e2ca0 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/lstm_grad_weight_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_LSTM_GRAD_WEIGHT_INFER_H +#define MINDSPORE_NNACL_LSTM_GRAD_WEIGHT_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/fp32/lstm_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int LstmGradWeightInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_LSTM_GRAD_WEIGHT_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/lstm_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/lstm_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..7e3966b8f720edaf74daf378c96ceec8f2d65b19 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/lstm_infer.c @@ -0,0 +1,161 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/lstm_infer.h" +#include "nnacl/infer/infer_register.h" + +static const int no_of_recorde_values = 5; + +int CheckInputShapeValid(const TensorC *const *inputs, size_t inputs_size, const LstmParameter *parameter) { + if (inputs_size < C6NUM) { + return NNACL_INPUT_TENSOR_ERROR; + } + const TensorC *input = inputs[FIRST_INPUT]; + const TensorC *weight_i = inputs[SECOND_INPUT]; + const TensorC *weight_g = inputs[THIRD_INPUT]; + const TensorC *bias = inputs[FOURTH_INPUT]; + const TensorC *hidden_init = inputs[FIFTH_INPUT]; + const TensorC *cell_init = inputs[SIXTH_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + NNACL_CHECK_NULL_RETURN_ERR(weight_i); + NNACL_CHECK_NULL_RETURN_ERR(weight_g); + NNACL_CHECK_NULL_RETURN_ERR(bias); + NNACL_CHECK_NULL_RETURN_ERR(hidden_init); + NNACL_CHECK_NULL_RETURN_ERR(cell_init); + NNACL_CHECK_TRUE_RET(input->shape_size_ == DIMENSION_3D && weight_i->shape_size_ == DIMENSION_3D && + weight_g->shape_size_ == DIMENSION_3D && bias->shape_size_ == DIMENSION_2D, + NNACL_ERR); + int batch = input->shape_[kNHWC_H]; + int input_size = input->shape_[kNHWC_W]; + int hidden_size = weight_i->shape_[kNHWC_H] / C4NUM; + int out_size = hidden_size; + if (inputs_size == C7NUM) { + NNACL_CHECK_TRUE_RET(inputs[SEVENTH_INPUT]->shape_size_ == DIMENSION_3D, NNACL_INPUT_TENSOR_ERROR); + out_size = inputs[SEVENTH_INPUT]->shape_[kNHWC_H]; + } + bool bidirectional = parameter->bidirectional_; + int bidirection = bidirectional ? C2NUM : C1NUM; + NNACL_CHECK_TRUE_RET(weight_i->shape_[kNHWC_N] == bidirection && weight_i->shape_[kNHWC_H] == hidden_size * C4NUM && + weight_i->shape_[kNHWC_W] == input_size, + NNACL_ERR); + NNACL_CHECK_TRUE_RET(weight_g->shape_[kNHWC_N] == bidirection && weight_g->shape_[kNHWC_H] == hidden_size * C4NUM && + weight_g->shape_[kNHWC_W] == out_size, + NNACL_ERR); + NNACL_CHECK_TRUE_RET(bias->shape_[kNHWC_N] == bidirection && bias->shape_[kNHWC_H] == hidden_size * C8NUM, NNACL_ERR); + if (!bidirectional && hidden_init->shape_size_ == DIMENSION_2D) { + NNACL_CHECK_TRUE_RET(hidden_init->shape_[kNHWC_N] == batch && hidden_init->shape_[kNHWC_H] == out_size, NNACL_ERR); + } else { + NNACL_CHECK_TRUE_RET(hidden_init->shape_size_ == DIMENSION_3D && hidden_init->shape_[kNHWC_N] == bidirection && + hidden_init->shape_[kNHWC_H] == batch && hidden_init->shape_[kNHWC_W] == out_size, + NNACL_ERR); + } + if (!bidirectional && cell_init->shape_size_ == DIMENSION_2D) { + NNACL_CHECK_TRUE_RET(cell_init->shape_[kNHWC_N] == batch && cell_init->shape_[kNHWC_H] == hidden_size, NNACL_ERR); + } else { + NNACL_CHECK_TRUE_RET(cell_init->shape_size_ == DIMENSION_3D && cell_init->shape_[kNHWC_N] == bidirection && + cell_init->shape_[kNHWC_H] == batch && cell_init->shape_[kNHWC_W] == hidden_size, + NNACL_ERR); + } + return NNACL_OK; +} + +int InferFirstOutputMindir(const TensorC *const *inputs, size_t inputs_size, TensorC *output, LstmParameter *param) { + for (size_t i = 0; i < inputs_size; ++i) { + if (inputs[i]->shape_size_ != C3NUM) { + return NNACL_INPUT_TENSOR_ERROR; + } + } + ShapeSet(output->shape_, &output->shape_size_, inputs[0]->shape_, inputs[0]->shape_size_); + int out_size = inputs[SECOND_INPUT]->shape_[THIRD_INPUT]; + output->shape_[THIRD_INPUT] = (param->bidirectional_ ? C2NUM : 1) * out_size; + return NNACL_OK; +} + +int InferFirstOutputNonMindir(const TensorC *const *inputs, size_t inputs_size, TensorC *output, LstmParameter *param) { + if (CheckInputShapeValid(inputs, inputs_size, param) != NNACL_OK) { + return NNACL_ERR; + } + ShapeSet(output->shape_, &output->shape_size_, inputs[0]->shape_, inputs[0]->shape_size_); + const TensorC *hidden_init = inputs[FIFTH_INPUT]; + int out_size = hidden_init->shape_[hidden_init->shape_size_ - 1]; + output->shape_[THIRD_INPUT] = out_size; + int direction = param->bidirectional_ ? C2NUM : C1NUM; + int ret = ShapeInsert(output->shape_, &output->shape_size_, 1, direction); + return ret; +} + +int LstmInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 4, 3); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + for (int i = 0; i < outputs_size; i++) { + SetDataTypeFormat(outputs[i], input); + } + + LstmParameter *param = (LstmParameter *)parameter; + NNACL_CHECK_NULL_RETURN_ERR(param); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + int hidden_size = 0; + int out_size = 0; + if (inputs_size == C4NUM) { + int ret = InferFirstOutputMindir(inputs, inputs_size, output, param); + if (ret != NNACL_OK) { + return ret; + } + hidden_size = inputs[THIRD_INPUT]->shape_[THIRD_INPUT]; + out_size = inputs[SECOND_INPUT]->shape_[THIRD_INPUT]; + } else { + int ret = InferFirstOutputNonMindir(inputs, inputs_size, output, param); + if (ret != NNACL_OK) { + return ret; + } + hidden_size = inputs[SIXTH_INPUT]->shape_[inputs[SIXTH_INPUT]->shape_size_ - 1]; + out_size = inputs[FIFTH_INPUT]->shape_[inputs[FIFTH_INPUT]->shape_size_ - 1]; + } + + int dir_multiplier = param->bidirectional_ ? C2NUM : C1NUM; + int state_shape[MAX_SHAPE_SIZE]; + size_t state_shape_size = 0; + + ShapeSet(state_shape, &state_shape_size, input->shape_, input->shape_size_); + state_shape[FIRST_INPUT] = dir_multiplier; + state_shape[THIRD_INPUT] = out_size; + SetShapeArray(outputs[SECOND_INPUT], state_shape, state_shape_size); + state_shape[THIRD_INPUT] = hidden_size; + SetShapeArray(outputs[THIRD_INPUT], state_shape, state_shape_size); + + if (outputs_size > DIMENSION_4D) { + int intermediate_states_shape[MAX_SHAPE_SIZE]; + const size_t intermediate_states_shape_size = 1; + int batch_size = input->shape_[SECOND_INPUT]; + int seq_len = input->shape_[FIRST_INPUT]; + intermediate_states_shape[FIRST_INPUT] = + batch_size * seq_len * dir_multiplier * (out_size + no_of_recorde_values * hidden_size); + SetShapeArray(outputs[FOURTH_INPUT], intermediate_states_shape, intermediate_states_shape_size); + SetShapeArray(outputs[FIFTH_INPUT], state_shape, state_shape_size); + } + + return NNACL_OK; +} + +REG_INFER(LSTM, PrimType_LSTM, LstmInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/lstm_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/lstm_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..c361619cd1bfb5394ca9d09201d9c9a8a4b20f69 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/lstm_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_LSTM_INFER_H +#define MINDSPORE_NNACL_LSTM_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/fp32/lstm_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int LstmInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_LSTM_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/matmul_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/matmul_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..f5720cd9280424f8c274b7a59198791c3fba7178 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/matmul_infer.c @@ -0,0 +1,148 @@ +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/matmul_infer.h" +#include +#include "nnacl/infer/infer_register.h" + +#define MIN_SHAPE_SIZE 2 + +int CheckMatmulInputShape(int *a_shape, size_t a_shape_size, int *b_shape, size_t b_shape_size, const int *bias_shape, + size_t bias_shape_size, const MatMulParameter *param) { + if (a_shape_size < MIN_SHAPE_SIZE || b_shape_size < MIN_SHAPE_SIZE) { + return NNACL_PARAM_INVALID; + } + for (size_t i = 0; i < (a_shape_size - 2) && i < (b_shape_size - 2); ++i) { + int min_value = MSMIN(a_shape[i], b_shape[i]); + int max_value = MSMAX(a_shape[i], b_shape[i]); + if (min_value != 0 && max_value % min_value != 0) { + return NNACL_INPUT_TENSOR_ERROR; + } + } + if (param->a_transpose_) { + iswap(&a_shape[a_shape_size - 1], &a_shape[a_shape_size - DIMENSION_2D]); + } + if (param->b_transpose_) { + iswap(&b_shape[b_shape_size - 1], &b_shape[b_shape_size - 2]); + } + if (bias_shape_size == DIMENSION_1D && bias_shape[0] != b_shape[b_shape_size - 1]) { + return NNACL_ERR; + } + if (a_shape[a_shape_size - 1] != b_shape[b_shape_size - 2]) { + return NNACL_ERR; + } + return NNACL_OK; +} + +int CheckMatMulBias(int *shape, size_t dim_size) { + if (dim_size > 1) { + for (size_t i = 0; i < dim_size - 1; i++) { + if (shape[i] != DIMENSION_1D) { + return NNACL_ERR; + } + } + } + return NNACL_OK; +} + +int SetShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + TensorC *input0 = (TensorC *)inputs[0]; + TensorC *input1 = (TensorC *)inputs[1]; + TensorC *output = outputs[0]; + MatMulParameter *param = (MatMulParameter *)parameter; + int a_shape[MAX_SHAPE_SIZE] = {0}; + size_t a_shape_size = 0; + ShapeSet(a_shape, &a_shape_size, input0->shape_, input0->shape_size_); + int b_shape[MAX_SHAPE_SIZE] = {0}; + size_t b_shape_size = 0; + ShapeSet(b_shape, &b_shape_size, input1->shape_, input1->shape_size_); + int *shape_align = a_shape_size > b_shape_size ? b_shape : a_shape; + size_t *shape_size_align = a_shape_size > b_shape_size ? &b_shape_size : &a_shape_size; + int diff = abs((int)a_shape_size - (int)b_shape_size); + for (int i = 0; i < diff; ++i) { + ShapeInsert(shape_align, shape_size_align, 0, 1); + } + int bias_shape[MAX_AXIS_SIZE] = {0}; + size_t bias_shape_size = 0; + if (inputs_size == kInputSize2) { + TensorC *bias = (TensorC *)inputs[2]; + ShapeSet(bias_shape, &bias_shape_size, bias->shape_, bias->shape_size_); + NNACL_CHECK_TRUE_RET(CheckMatMulBias(bias_shape, bias_shape_size) == NNACL_OK, NNACL_ERR); + } + + bool del_start = false; + bool del_end = false; + if (a_shape_size == 1) { + int insert_ret = ShapeInsert(a_shape, &a_shape_size, 0, 1); + if (insert_ret != NNACL_OK) { + return NNACL_ERR; + } + del_start = true; + } + if (b_shape_size == 1) { + ShapePush(b_shape, &b_shape_size, 1); + del_end = true; + } + int ret = CheckMatmulInputShape(a_shape, a_shape_size, b_shape, b_shape_size, bias_shape, bias_shape_size, param); + if (ret != NNACL_OK) { + return NNACL_ERR; + } + int c_shape[MAX_SHAPE_SIZE]; + size_t c_shape_size = 0; + ShapeSet(c_shape, &c_shape_size, a_shape, a_shape_size); + c_shape[c_shape_size - 1] = b_shape[b_shape_size - 1]; + if (del_start) { + int erase_ret = ShapeErase(c_shape, &c_shape_size, 0); + if (erase_ret != NNACL_OK) { + return NNACL_ERR; + } + } + if (del_end) { + c_shape_size--; + } + + for (size_t i = 0; i < (a_shape_size - 2) && i < (b_shape_size - 2); ++i) { + c_shape[i] = MSMAX(a_shape[i], b_shape[i]); + } + + SetShapeArray(output, c_shape, c_shape_size); + return NNACL_OK; +} + +int MatmulInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 2, 3, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + TensorC *input0 = (TensorC *)inputs[0]; + TensorC *input1 = (TensorC *)inputs[1]; + TensorC *output = outputs[0]; + + TensorC *input = input1->data_ == NULL ? input1 : input0; // transfer the input which comes from the other node. + SetDataTypeFormat(output, input); + if (input->data_type_ == kNumberTypeInt8 && parameter->quant_type_ == Quant_QuantDynamic) { + output->data_type_ = kNumberTypeFloat32; + } + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + return SetShape(inputs, inputs_size, outputs, outputs_size, parameter); +} + +REG_INFER(MatMul, PrimType_MatMulFusion, MatmulInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/matmul_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/matmul_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..e0bb5cfe4085305400856b137ca628267c275cb6 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/matmul_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_MATMUL_INFER_H +#define MINDSPORE_NNACL_MATMUL_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/matmul_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int MatmulInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_MATMUL_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/max_min_grad_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/max_min_grad_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..63564994e426a8a2dd2a5f6a85239dd61e92910c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/max_min_grad_infer.c @@ -0,0 +1,65 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/max_min_grad_infer.h" +#include "nnacl/arithmetic_parameter.h" +#include "nnacl/infer/infer_register.h" + +int MaxMinGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 2); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *x1 = inputs[0]; + const TensorC *x2 = inputs[1]; + const TensorC *dy = inputs[2]; + TensorC *dx1 = outputs[0]; + TensorC *dx2 = outputs[1]; + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (x1->shape_size_ > MAX_SHAPE_SIZE || x2->shape_size_ > MAX_SHAPE_SIZE || dy->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + ArithmeticParameter *param = (ArithmeticParameter *)parameter; + + param->ndim_ = dy->shape_size_; + param->in_elements_num0_ = (int)(param->ndim_); + param->in_elements_num1_ = (int)(param->ndim_); + param->out_elements_num_ = (int)(param->ndim_); + int fillDimNum0 = (int)(dy->shape_size_ - x1->shape_size_); + int fillDimNum1 = (int)(dy->shape_size_ - x2->shape_size_); + int j0 = 0; + int j1 = 0; + for (unsigned int i = 0; i < dy->shape_size_; i++) { + param->in_shape0_[i] = ((int)i < fillDimNum0) ? 1 : x1->shape_[j0++]; + param->in_shape1_[i] = ((int)i < fillDimNum1) ? 1 : x2->shape_[j1++]; + param->out_shape_[i] = dy->shape_[i]; + } + + SetShapeTensor(dx1, x1); + SetShapeTensor(dx2, x2); + SetDataTypeFormat(dx1, dy); + SetDataTypeFormat(dx2, dy); + return NNACL_OK; +} + +REG_INFER(MaximumGrad, PrimType_MaximumGrad, MaxMinGradInferShape) +REG_INFER(MinimumGrad, PrimType_MinimumGrad, MaxMinGradInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/max_min_grad_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/max_min_grad_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..5a12f49dc42be4ea8d3530d73b9db41ee197c9de --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/max_min_grad_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_MAX_MIN_GRAD_INFER_H_ +#define MINDSPORE_NNACL_INFER_MAX_MIN_GRAD_INFER_H_ + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int MaxMinGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_MAX_MIN_GRAD_INFER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/mfcc_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/mfcc_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..f3f3a1f22083f2df7b4ec4add1e07b89f537bf11 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/mfcc_infer.c @@ -0,0 +1,48 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/mfcc_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensor_c_utils.h" + +int MfccInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ != 3) { + return NNACL_ERR; + } + if (NNACLGetElementNum(inputs[1]) != 1) { + return NNACL_ERR; + } + output->shape_size_ = 3; + output->shape_[0] = input->shape_[0]; + output->shape_[1] = input->shape_[1]; + MfccParameter *param = (MfccParameter *)parameter; + output->shape_[2] = param->dct_coeff_num_; + return NNACL_OK; +} + +REG_INFER(Mfcc, PrimType_Mfcc, MfccInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/mfcc_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/mfcc_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..5f7b4b0bc5f2ab9d704f5cdeecb43d4965d7502a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/mfcc_infer.h @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_MFCC_INFER_H +#define MINDSPORE_NNACL_MFCC_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct MfccParameter { + OpParameter op_parameter_; + int dct_coeff_num_; +} MfccParameter; + +int MfccInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_MFCC_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/nllloss_grad_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/nllloss_grad_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..1516991afaa7e07e21683aeb876ba59863397f65 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/nllloss_grad_infer.c @@ -0,0 +1,54 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/nllloss_grad_infer.h" + +#include "nnacl/infer/infer_register.h" + +int NLLLossGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, C5NUM, C1NUM); + if (ret != NNACL_OK) { + return ret; + } + + const TensorC *logits = inputs[0]; + const TensorC *loss_grad = inputs[1]; + const TensorC *labels = inputs[2]; + const TensorC *weight = inputs[3]; + const TensorC *total_weight = inputs[4]; + if (logits->shape_size_ != C2NUM || labels->shape_size_ != C1NUM || weight->shape_size_ != C1NUM || + total_weight->shape_size_ != 0) { + return NNACL_INPUT_TENSOR_ERROR; + } + if (labels->shape_[0] != logits->shape_[0] || weight->shape_[0] != logits->shape_[1]) { + return NNACL_INPUT_TENSOR_ERROR; + } + + NLLLossParameter *param = (NLLLossParameter *)parameter; + if (param->reduction_type_ == Reduction_None && loss_grad->shape_size_ != C1NUM) { + return NNACL_INPUT_TENSOR_ERROR; + } + if (param->reduction_type_ != Reduction_None && loss_grad->shape_size_ != 0) { + return NNACL_INPUT_TENSOR_ERROR; + } + TensorC *logits_grad = outputs[0]; + SetDataTypeFormat(logits_grad, logits); + SetShapeTensor(logits_grad, logits); + return NNACL_OK; +} + +REG_INFER(NLLLossGrad, PrimType_NLLLossGrad, NLLLossGradInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/nllloss_grad_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/nllloss_grad_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..9e9a340a15f003d93225d4b4e39c8104ff918c79 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/nllloss_grad_infer.h @@ -0,0 +1,33 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_NLLLOSS_GRAD_INFER_H +#define MINDSPORE_NNACL_NLLLOSS_GRAD_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/nllloss_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int NLLLossGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_NLLLOSS_GRAD_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/nllloss_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/nllloss_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..fb4c20a84553f9b7cf6da15d2ba1e83badfac891 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/nllloss_infer.c @@ -0,0 +1,52 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/nllloss_infer.h" + +#include "nnacl/infer/infer_register.h" + +int NLLLossInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, C3NUM, C2NUM); + if (ret != NNACL_OK) { + return ret; + } + + const TensorC *logits = inputs[0]; + const TensorC *labels = inputs[1]; + const TensorC *weight = inputs[2]; + if (logits->shape_size_ != C2NUM || labels->shape_size_ != C1NUM || weight->shape_size_ != C1NUM) { + return NNACL_INPUT_TENSOR_ERROR; + } + if (labels->shape_[0] != logits->shape_[0] || weight->shape_[0] != logits->shape_[1]) { + return NNACL_INPUT_TENSOR_ERROR; + } + TensorC *loss = outputs[0]; + TensorC *total_weight = outputs[1]; + + NLLLossParameter *param = (NLLLossParameter *)parameter; + if (param->reduction_type_ == Reduction_None) { + SetShapeTensor(loss, labels); + } else { + loss->shape_size_ = 0; + } + total_weight->shape_size_ = 0; + SetDataTypeFormat(loss, logits); + SetDataTypeFormat(total_weight, logits); + return NNACL_OK; +} + +REG_INFER(NLLLoss, PrimType_NLLLoss, NLLLossInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/nllloss_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/nllloss_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..6942dac519d17577e2e61206fc345fa360913fe9 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/nllloss_infer.h @@ -0,0 +1,33 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_NLLLOSS_INFER_H +#define MINDSPORE_NNACL_NLLLOSS_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/nllloss_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int NLLLossInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_NLLLOSS_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/non_max_suppression_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/non_max_suppression_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..ac982315b3fa47ae52e3eaeae622bea5fcce1395 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/non_max_suppression_infer.c @@ -0,0 +1,34 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/non_max_suppression_infer.h" +#include "nnacl/infer/infer_register.h" + +int NonMaxSuppressionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + output->data_type_ = kNumberTypeInt32; + output->format_ = input->format_; + return NNACL_INFER_INVALID; +} + +REG_INFER(NonMaxSuppression, PrimType_NonMaxSuppression, NonMaxSuppressionInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/non_max_suppression_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/non_max_suppression_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..19575888dd78937933465570ae338db3e7bd08fa --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/non_max_suppression_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_NON_MAX_SUPPRESSION_INFER_H +#define MINDSPORE_NNACL_NON_MAX_SUPPRESSION_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int NonMaxSuppressionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_NON_MAX_SUPPRESSION_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/one_hot_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/one_hot_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..0cfbd1eed402427256d5d07ec9ac190726ea1ca2 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/one_hot_infer.c @@ -0,0 +1,60 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/one_hot_infer.h" +#include "nnacl/infer/infer_register.h" + +int OneHotInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret != NNACL_OK) { + return check_ret; + } + if (inputs_size != 4 && inputs_size != 3) { + return NNACL_INPUT_TENSOR_ERROR; + } + + const TensorC *input = inputs[0]; + const TensorC *depth_tensor = inputs[1]; + const TensorC *on_value = inputs[2]; + TensorC *output = outputs[0]; + const int *depth = (int *)(depth_tensor->data_); + if (depth == NULL) { + return NNACL_NULL_PTR; + } + SetDataTypeFormat(output, on_value); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + OneHotParameter *param = (OneHotParameter *)parameter; + int axis = param->axis_; + int input_rank = (int)(input->shape_size_); + if (axis < 0) { + axis += input_rank + 1; + } + if (input->shape_size_ >= MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + ShapeSet(output->shape_, &(output->shape_size_), input->shape_, input->shape_size_); + int res_insert = ShapeInsert(output->shape_, &output->shape_size_, axis, *depth); + if (res_insert == NNACL_ERR) { + return NNACL_ERR; + } + + return NNACL_OK; +} + +REG_INFER(OneHot, PrimType_OneHot, OneHotInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/one_hot_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/one_hot_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..b5ed83ec0a421be14187ea30d3a8c97e6e009719 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/one_hot_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_ONE_HOT_INFER_H +#define MINDSPORE_NNACL_ONE_HOT_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/fp32/one_hot_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int OneHotInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ONE_HOT_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/pad_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/pad_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..59ee580b981f62749098dfdb59b476fc27224877 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/pad_infer.c @@ -0,0 +1,64 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/pad_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensor_c_utils.h" + +int PadInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + PadParameter *param = (PadParameter *)parameter; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (input->shape_size_ > DEFAULT_PAD_NDIMS) { + return NNACL_INPUT_TENSOR_ERROR; + } + const TensorC *paddings = inputs[1]; + int size = NNACLGetElementNum(paddings); + if (size > MAX_PAD_SIZE) { + return NNACL_PARAM_INVALID; + } + if (paddings->data_ == NULL) { + return NNACL_INFER_INVALID; + } + param->padding_length = size; + for (int i = 0; i < size; ++i) { + NNACL_CHECK_TRUE_RET(((int *)paddings->data_)[i] >= 0, NNACL_INFER_INVALID); + param->paddings_[i] = ((int *)paddings->data_)[i]; + } + + int output_shape[DEFAULT_PAD_NDIMS] = {0}; + size_t output_shape_size = 0; + for (size_t i = 0; i < input->shape_size_; i++) { + int shape = input->shape_[i] + param->paddings_[2 * i] + param->paddings_[2 * i + 1]; + ShapePush(output_shape, &output_shape_size, shape); + } + + SetShapeArray(output, output_shape, output_shape_size); + return NNACL_OK; +} + +REG_INFER(Pad, PrimType_PadFusion, PadInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/pad_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/pad_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..e277b9865ef565598867c412313c4f990b37f535 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/pad_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_PAD_INFER_H +#define MINDSPORE_NNACL_PAD_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/pad_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int PadInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_PAD_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/pooling_grad_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/pooling_grad_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..199d0444f740c7a09c1c4eca41a408cb077a1341 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/pooling_grad_infer.c @@ -0,0 +1,74 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/pooling_grad_infer.h" +#include +#include "nnacl/infer/infer_register.h" + +int PoolingGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + if (input->format_ != Format_NHWC) { + return NNACL_FORMAT_ERROR; + } + int input_h = input->shape_[1]; + int input_w = input->shape_[2]; + if (input->shape_size_ != 4) { + return NNACL_INPUT_TENSOR_ERROR; + } + PoolingParameter *param = (PoolingParameter *)parameter; + int window_h = param->window_h_; + int window_w = param->window_w_; + if (param->global_) { + window_h = input_h; + window_w = input_w; + } + + if (param->stride_h_ == 0 || param->stride_w_ == 0) { + return NNACL_PARAM_INVALID; + } + if (param->pad_mode_ == Pad_same) { + NNACL_CHECK_ZERO_RETURN_ERR(param->stride_w_); + NNACL_CHECK_ZERO_RETURN_ERR(param->stride_h_); + int output_w = ceil((float)(input_w) / (float)(param->stride_w_)); + int output_h = ceil((float)(input_h) / (float)(param->stride_h_)); + int pad_h_all = ((output_h - 1) * param->stride_h_ + (window_h - 1) + 1 - input_h); + int pad_w_all = ((output_w - 1) * param->stride_w_ + (window_w - 1) + 1 - input_w); + if (pad_h_all < 0) { + param->pad_u_ = param->pad_d_ = 0; + } else { + param->pad_u_ = pad_h_all / 2; + param->pad_d_ = pad_h_all - param->pad_u_; + } + if (pad_w_all < 0) { + param->pad_l_ = param->pad_r_ = 0; + } else { + param->pad_l_ = pad_w_all / 2; + param->pad_r_ = pad_w_all - param->pad_l_; + } + } + SetDataTypeFormat(outputs[0], input); + SetShapeTensor(outputs[0], input); + return NNACL_OK; +} + +REG_INFER(AvgPoolGrad, PrimType_AvgPoolGrad, PoolingGradInferShape) +REG_INFER(MaxPoolGrad, PrimType_MaxPoolGrad, PoolingGradInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/pooling_grad_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/pooling_grad_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..016ece352b184d27519bfdb0da82c2fb70d76e97 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/pooling_grad_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_POOLING_GRAD_INFER_H +#define MINDSPORE_NNACL_POOLING_GRAD_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/pooling_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int PoolingGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_POOLING_GRAD_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/pooling_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/pooling_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..f0ec0ae9b064abb6865e572075af5d4df5f239f7 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/pooling_infer.c @@ -0,0 +1,107 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/pooling_infer.h" +#include +#include "nnacl/infer/infer_register.h" + +int ComputePadList(PoolingParameter *param, int input_h, int input_w, int output_h, int output_w) { + if (param == NULL) { + return NNACL_NULL_PTR; + } + int pad_h_all = ((output_h - 1) * param->stride_h_ + (param->window_h_ - 1) + 1 - input_h); + int pad_w_all = ((output_w - 1) * param->stride_w_ + (param->window_w_ - 1) + 1 - input_w); + if (pad_h_all < 0) { + param->pad_u_ = param->pad_d_ = 0; + } else { + param->pad_u_ = pad_h_all / 2; + param->pad_d_ = pad_h_all - param->pad_u_; + } + if (pad_w_all < 0) { + param->pad_l_ = param->pad_r_ = 0; + } else { + param->pad_l_ = pad_w_all / 2; + param->pad_r_ = pad_w_all - param->pad_l_; + } + return NNACL_OK; +} + +int PoolingInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + NNACL_CHECK_TRUE_RET(input->format_ == Format_NHWC, NNACL_FORMAT_ERROR); + for (size_t i = 0; i < outputs_size; i++) { + TensorC *output = outputs[i]; + SetDataTypeFormat(output, input); + } + PoolingParameter *param = (PoolingParameter *)parameter; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ < 3 || input->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int input_h = input->shape_[1]; + int input_w = input->shape_[2]; + + int window_h = param->window_h_; + int window_w = param->window_w_; + if (param->global_) { + param->window_h_ = window_h = input_h; + param->window_w_ = window_w = input_w; + } + int output_h = 0; + int output_w = 0; + if ((param->stride_h_ == 0 || param->stride_w_ == 0) && !param->global_) { + return NNACL_PARAM_INVALID; + } + if (param->pad_mode_ == Pad_same) { + output_w = ceil((float)(input_w) / (float)(param->stride_w_)); + output_h = ceil((float)(input_h) / (float)(param->stride_h_)); + if (ComputePadList(param, input_h, input_w, output_h, output_w) != NNACL_OK) { + return NNACL_NULL_PTR; + } + } else { + int round_mode = (RoundType)param->round_type_; + if (round_mode == RoundType_Floor) { + output_h = floor((float)(input_h + param->pad_u_ + param->pad_d_ - window_h) / param->stride_h_) + 1; + output_w = floor((float)(input_w + param->pad_l_ + param->pad_r_ - window_w) / param->stride_w_) + 1; + } else if (round_mode == RoundType_Ceil) { + output_h = ceil((float)(input_h + param->pad_u_ + param->pad_d_ - window_h) / param->stride_h_) + 1; + output_w = ceil((float)(input_w + param->pad_l_ + param->pad_r_ - window_w) / param->stride_w_) + 1; + } else { + return NNACL_ERR; + } + } + int input_shape[MAX_SHAPE_SIZE]; + size_t input_shape_size = 0; + ShapeSet(input_shape, &input_shape_size, input->shape_, input->shape_size_); + input_shape[1] = output_h > 0 ? output_h : 1; + input_shape[2] = output_w > 0 ? output_w : 1; + for (size_t i = 0; i < outputs_size; i++) { + TensorC *output = outputs[i]; + SetShapeArray(output, input_shape, input_shape_size); + } + return NNACL_OK; +} + +REG_INFER(MaxPool, PrimType_MaxPoolFusion, PoolingInferShape) +REG_INFER(AvgPool, PrimType_AvgPoolFusion, PoolingInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/pooling_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/pooling_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..c4a9369b5897628ebfde9a76eb690ab17fb80c30 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/pooling_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_POOLING_INFER_H +#define MINDSPORE_NNACL_POOLING_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/pooling_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int PoolingInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_POOLING_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/power_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/power_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..d10eafb766a5473653f1f5b7b3e6cc275cf53846 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/power_infer.c @@ -0,0 +1,56 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/power_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensor_c_utils.h" + +int PowerInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *x_tensor = inputs[0]; + TensorC *exp_tensor = NULL; + if (inputs_size == 2) { + exp_tensor = (TensorC *)inputs[1]; + PowParameter *param = (PowParameter *)parameter; + float *exp_data = (float *)(exp_tensor->data_); + if (exp_data == NULL) { + return NNACL_INFER_INVALID; + } + param->power_ = *exp_data; + } + TensorC *output_tensor = outputs[0]; + + SetDataTypeFormat(output_tensor, x_tensor); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (exp_tensor != NULL) { + bool exp_x_equal = ShapeEqual(exp_tensor->shape_, exp_tensor->shape_size_, x_tensor->shape_, x_tensor->shape_size_); + if (!exp_x_equal && NNACLGetElementNum(exp_tensor) != 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + } + + SetShapeTensor(output_tensor, x_tensor); + return NNACL_OK; +} + +REG_INFER(Pow, PrimType_PowFusion, PowerInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/power_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/power_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..a079002f9f1ecfc332d38a941540c33eb6f9874d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/power_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_POWER_INFER_H +#define MINDSPORE_NNACL_POWER_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/pow_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int PowerInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_POWER_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/prior_box_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/prior_box_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..db79bd49b9dc93236423b7d6e96307f88b1ad43b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/prior_box_infer.c @@ -0,0 +1,87 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/prior_box_infer.h" +#include +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensor_c_utils.h" + +int PriorBoxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + output->data_type_ = kNumberTypeFloat32; + output->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + float different_aspect_ratios[MAX_SHAPE_SIZE * 2 + 1]; // NOTE: flip double the number + different_aspect_ratios[0] = 1.0; + int32_t different_aspect_ratios_size = 1; + + PriorBoxParameter *param = (PriorBoxParameter *)parameter; + float *aspect_ratios = param->aspect_ratios; + if (aspect_ratios == NULL) { + return NNACL_NULL_PTR; + } + int32_t aspect_ratios_size = param->aspect_ratios_size; + NNACL_CHECK_TRUE_RET(aspect_ratios_size <= MAX_SHAPE_SIZE, NNACL_ERR); + for (int32_t i = 0; i < aspect_ratios_size; i++) { + float ratio = aspect_ratios[i]; + if (fabsf(ratio) < EPSILON_VALUE) { + return NNACL_ERR; + } + + bool exist = false; + for (int32_t j = 0; j < different_aspect_ratios_size; j++) { + if (fabsf(ratio - different_aspect_ratios[j]) < EPSILON_VALUE) { + exist = true; + break; + } + } + if (!exist) { + different_aspect_ratios[different_aspect_ratios_size] = ratio; + different_aspect_ratios_size++; + if (param->flip) { + different_aspect_ratios[different_aspect_ratios_size] = 1.0f / ratio; + different_aspect_ratios_size++; + } + } + } + + int32_t min_sizes_size = param->min_sizes_size; + int32_t max_sizes_size = param->max_sizes_size; + int32_t num_priors_box = min_sizes_size * different_aspect_ratios_size + max_sizes_size; + const int kPriorBoxPoints = 4; + const int kPriorBoxN = 1; + const int kPriorBoxW = 1; + const int kPriorBoxC = 2; + + int32_t h = NNACLGetHeight(input) * NNACLGetWidth(input) * num_priors_box * kPriorBoxPoints; + output->shape_size_ = 4; + output->shape_[0] = kPriorBoxN; + output->shape_[1] = h; + output->shape_[2] = kPriorBoxW; + output->shape_[3] = kPriorBoxC; + return NNACL_OK; +} + +REG_INFER(PriorBox, PrimType_PriorBox, PriorBoxInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/prior_box_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/prior_box_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..80b5af6db8100dd97832eaa0efdeb719cc4e00db --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/prior_box_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_PRIOR_BOX_INFER_H +#define MINDSPORE_NNACL_PRIOR_BOX_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/prior_box_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int PriorBoxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_PRIOR_BOX_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/quant_dtype_cast_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/quant_dtype_cast_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..bb48b8d79c563e1f7c09976c0724b059d16de617 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/quant_dtype_cast_infer.c @@ -0,0 +1,41 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/quant_dtype_cast_infer.h" +#include "nnacl/infer/infer_register.h" + +int QuantDtypeCastInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + + QuantDtypeCastParameter *param = (QuantDtypeCastParameter *)parameter; + output->data_type_ = param->dstT_; + NNACL_CHECK_TRUE_RET(output->data_type_ > kNumberTypeBegin && output->data_type_ < kNumberTypeEnd, NNACL_ERR); + output->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(output, input); + return NNACL_OK; +} + +REG_INFER(QuantDTypeCast, PrimType_QuantDTypeCast, QuantDtypeCastInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/quant_dtype_cast_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/quant_dtype_cast_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..bf854ba717a3af11aac8047d5d9d8b34686299c8 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/quant_dtype_cast_infer.h @@ -0,0 +1,37 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_QUANT_DTYPE_CAST_INFER_H +#define MINDSPORE_NNACL_QUANT_DTYPE_CAST_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct QuantDtypeCastParameter { + OpParameter op_parameter_; + int srcT_; // deprecated + int dstT_; +} QuantDtypeCastParameter; + +int QuantDtypeCastInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_QUANT_DTYPE_CAST_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/ragged_range_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/ragged_range_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..5ec1234d79afeaaaf4fed2875b792b30b9301433 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/ragged_range_infer.c @@ -0,0 +1,129 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/infer/ragged_range_infer.h" +#include +#include "nnacl/infer/infer_register.h" + +int CheckInputTensor(const TensorC *const *inputs) { + if (inputs[0]->data_ == NULL || inputs[1]->data_ == NULL || inputs[2]->data_ == NULL) { + return NNACL_INFER_INVALID; + } + if (inputs[0]->shape_size_ != 0 && inputs[0]->shape_size_ != 1) { + return NNACL_ERR; + } + return NNACL_OK; +} + +int GetRows(const TensorC *const *inputs, bool starts_is_scalar, bool limits_is_scalar, bool deltas_is_scalar, + int *rows) { + NNACL_CHECK_NULL_RETURN_ERR(rows); + int sizes[3]; + int not_scalar_count = 0; + if (!starts_is_scalar) { + sizes[not_scalar_count++] = inputs[0]->shape_[0]; + } + if (!limits_is_scalar) { + sizes[not_scalar_count++] = inputs[1]->shape_[0]; + } + if (!deltas_is_scalar) { + sizes[not_scalar_count++] = inputs[2]->shape_[0]; + } + for (int i = 1; i < not_scalar_count; i++) { + if (sizes[i] != sizes[i - 1]) { + return NNACL_ERR; + } + } + *rows = not_scalar_count == 0 ? 1 : sizes[0]; + return NNACL_OK; +} + +int GetOutputValueElementNum(const TensorC *const *inputs, bool starts_is_scalar, bool limits_is_scalar, + bool deltas_is_scalar, int rows, int *output_value_element_num) { + int count = 0; + switch (inputs[0]->data_type_) { + case kNumberTypeInt32: { + int *starts = (int *)(inputs[0]->data_); + int *limits = (int *)(inputs[1]->data_); + int *deltas = (int *)(inputs[2]->data_); + for (int i = 0; i < rows; i++) { + int start = starts_is_scalar ? starts[0] : starts[i]; + int limit = limits_is_scalar ? limits[0] : limits[i]; + int delta = deltas_is_scalar ? deltas[0] : deltas[i]; + NNACL_CHECK_ZERO_RETURN_ERR(delta); + count += MSMAX((int)(ceil((float)(limit - start) / delta)), 0); + } + } break; + case kNumberTypeFloat32: { + float *starts = (float *)(inputs[0]->data_); + float *limits = (float *)(inputs[1]->data_); + float *deltas = (float *)(inputs[2]->data_); + for (int i = 0; i < rows; i++) { + float start = starts_is_scalar ? starts[0] : starts[i]; + float limit = limits_is_scalar ? limits[0] : limits[i]; + float delta = deltas_is_scalar ? deltas[0] : deltas[i]; + NNACL_CHECK_ZERO_RETURN_ERR(delta); + count += MSMAX((ceil((limit - start) / delta)), 0); + } + } break; + default: { + return NNACL_ERR; + } + } + *output_value_element_num = count; + return NNACL_OK; +} + +int RaggedRangeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 2); + if (check_ret != NNACL_OK) { + return check_ret; + } + + outputs[0]->data_type_ = kNumberTypeInt32; + outputs[0]->format_ = inputs[0]->format_; + SetDataTypeFormat(outputs[1], inputs[0]); + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + int ret = CheckInputTensor(inputs); + if (ret != NNACL_OK) { + return ret; + } + + bool starts_is_scalar = inputs[0]->shape_size_ == 0; + bool limits_is_scalar = inputs[1]->shape_size_ == 0; + bool deltas_is_scalar = inputs[2]->shape_size_ == 0; + int rows; + ret = GetRows(inputs, starts_is_scalar, limits_is_scalar, deltas_is_scalar, &rows); + if (ret != NNACL_OK) { + return ret; + } + int output_value_element_num; + ret = GetOutputValueElementNum(inputs, starts_is_scalar, limits_is_scalar, deltas_is_scalar, rows, + &output_value_element_num); + if (ret != NNACL_OK) { + return ret; + } + outputs[0]->shape_size_ = 1; + outputs[0]->shape_[0] = rows + 1; + outputs[1]->shape_size_ = 1; + outputs[1]->shape_[0] = output_value_element_num; + return NNACL_OK; +} + +REG_INFER(RaggedRange, PrimType_RaggedRange, RaggedRangeInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/ragged_range_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/ragged_range_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..11e0365d395be513a3efb0d749f3847a1a013065 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/ragged_range_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_RAGGED_RANGE_INFER_H +#define MINDSPORE_NNACL_RAGGED_RANGE_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/fp32/ragged_range_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int RaggedRangeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_RAGGED_RANGE_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/random_normal_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/random_normal_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..23b805b31be5544df7704d39c204a35dcff9592f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/random_normal_infer.c @@ -0,0 +1,38 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/random_normal_infer.h" +#include "nnacl/infer/infer_register.h" + +int RandomNormalInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + outputs[0]->data_type_ = inputs[0]->data_type_; + outputs[0]->format_ = inputs[0]->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + SetShapeTensor(outputs[0], inputs[0]); + + return NNACL_OK; +} + +REG_INFER(RandomNormal, PrimType_RandomNormal, RandomNormalInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/random_normal_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/random_normal_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..f07353a2493fb2bc62ef04cc7599ca2e2eb3e164 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/random_normal_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_RANDOM_NORMAL_INFER_H +#define MINDSPORE_NNACL_RANDOM_NORMAL_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int RandomNormalInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_RANDOM_STANDARD_NORMAL_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/random_standard_normal_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/random_standard_normal_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..330979523db8353516ecc5343772dd56d04e1588 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/random_standard_normal_infer.c @@ -0,0 +1,52 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/random_standard_normal_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensor_c_utils.h" + +int RandomStandardNormalInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + outputs[0]->data_type_ = kNumberTypeFloat32; + outputs[0]->format_ = inputs[0]->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + int32_t *input_data = (int32_t *)(inputs[0]->data_); + if (input_data == NULL) { + return NNACL_INFER_INVALID; + } + int input_num = NNACLGetElementNum(inputs[0]); + if (input_num > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + for (int i = 0; i < input_num; i++) { + ShapePush(output_shape, &output_shape_size, input_data[i]); + } + SetShapeArray(outputs[0], output_shape, output_shape_size); + + return NNACL_OK; +} + +REG_INFER(RandomStandardNormal, PrimType_RandomStandardNormal, RandomStandardNormalInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/random_standard_normal_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/random_standard_normal_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..f769e4b9c2e6b3ba9b3218f9c6721d7aac6b55d6 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/random_standard_normal_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_RANDOM_STANDARD_NORMAL_INFER_H +#define MINDSPORE_NNACL_RANDOM_STANDARD_NORMAL_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int RandomStandardNormalInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_RANDOM_STANDARD_NORMAL_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/range_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/range_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..e49089dee0a71ab7b857ba944f8306b1b20afed6 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/range_infer.c @@ -0,0 +1,91 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/range_infer.h" +#include +#include "nnacl/infer/infer_register.h" +#include "nnacl/range_parameter.h" +#include "nnacl/tensor_c_utils.h" + +int RangeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 1, C3NUM, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + output->data_type_ = inputs_size == C3NUM ? input->data_type_ : kNumberTypeInt32; + output->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (NNACLGetElementNum(inputs[FIRST_INPUT]) < 1) { + return NNACL_ERR; + } + int shape_size = 0; + if (inputs_size == C3NUM) { + NNACL_CHECK_FALSE(inputs[FIRST_INPUT]->data_ == NULL, NNACL_INFER_INVALID); + NNACL_CHECK_FALSE(inputs[SECOND_INPUT]->data_ == NULL, NNACL_INFER_INVALID); + NNACL_CHECK_FALSE(inputs[THIRD_INPUT]->data_ == NULL, NNACL_INFER_INVALID); + if ((inputs[FIRST_INPUT]->data_type_ != inputs[SECOND_INPUT]->data_type_) || + (inputs[FIRST_INPUT]->data_type_ != inputs[THIRD_INPUT]->data_type_)) { + return NNACL_INFER_INVALID; + } + if (NNACLGetElementNum(inputs[SECOND_INPUT]) < 1 || NNACLGetElementNum(inputs[THIRD_INPUT]) < 1) { + return NNACL_ERR; + } + switch (inputs[0]->data_type_) { + case kNumberTypeInt: + case kNumberTypeInt32: { + int start = *(int *)(inputs[0]->data_); + int limit = *(int *)(inputs[1]->data_); + int delta = *(int *)(inputs[2]->data_); + if (delta == 0) { + return NNACL_ERR; + } + shape_size = imax((int)(ceil((float)(limit - start) / delta)), 0); + } break; + case kNumberTypeFloat32: + case kNumberTypeFloat: { + float start = *(float *)(inputs[0]->data_); + float limit = *(float *)(inputs[1]->data_); + float delta = *(float *)(inputs[2]->data_); + if (fabsf(delta) < EPSILON_VALUE) { + return NNACL_ERR; + } + shape_size = imax((int)(ceil((float)(limit - start) / delta)), 0); + } break; + default: { + return NNACL_ERR; + } + } + } else { + RangeParameter *param = (RangeParameter *)parameter; + NNACL_CHECK_NULL_RETURN_ERR(param); + if (param->delta_ == 0) { + return NNACL_PARAM_INVALID; + } + shape_size = ceil((float)(param->limit_ - param->start_) / param->delta_); + } + + output->shape_size_ = 1; + output->shape_[0] = shape_size; + return NNACL_OK; +} + +REG_INFER(Range, PrimType_Range, RangeInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/range_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/range_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..efaceafef4ea82e6ce8274a7dc9d7757756aa3f7 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/range_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_RANGE_INFER_H +#define MINDSPORE_NNACL_RANGE_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int RangeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_RANGE_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/rank_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/rank_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..cb9cfec6816426f7213071ba10bc966f7735a199 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/rank_infer.c @@ -0,0 +1,38 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/rank_infer.h" +#include "nnacl/infer/infer_register.h" + +int RankInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + output->shape_size_ = 1; + output->shape_[0] = 1; + return NNACL_OK; +} + +REG_INFER(Rank, PrimType_Rank, RankInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/rank_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/rank_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..ab2d8af3c7ab8daaabfed8298472e02ce7508703 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/rank_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_RANK_INFER_H +#define MINDSPORE_NNACL_RANK_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int RankInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_RANK_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/reduce_concat_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/reduce_concat_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..e0d12232f172336d8bdfa8c169a6b939d9216e57 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/reduce_concat_infer.c @@ -0,0 +1,95 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/reduce_concat_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/split_parameter.h" + +int DataTypeJudge2(const TensorC *input, const TensorC *output) { + if ((input->data_type_ != output->data_type_) && + !((input->data_type_ == kNumberTypeFloat16 && output->data_type_ == kNumberTypeFloat32) || + (input->data_type_ == kNumberTypeFloat32 && output->data_type_ == kNumberTypeFloat16))) { + return NNACL_PARAM_INVALID; + } + return NNACL_OK; +} + +int ReduceConcatFusionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentNullOutputSize(inputs, inputs_size, outputs, outputs_size, parameter, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + if (inputs_size < 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + + const TensorC *input0 = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input0); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + const int *input0_shape = inputs[0]->shape_; + size_t input0_shape_size = inputs[0]->shape_size_; + + int axis = C2NUM; + if (axis < 0 || axis >= (int)input0_shape_size) { + return NNACL_ERR; + } + if (input0_shape_size > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int input0_shape_without_axis[MAX_SHAPE_SIZE] = {0}; + size_t input0_shape_without_axis_size = 0; + ShapeSet(input0_shape_without_axis, &input0_shape_without_axis_size, input0_shape, input0_shape_size); + int erase_ret = ShapeErase(input0_shape_without_axis, &input0_shape_without_axis_size, axis); + if (erase_ret != NNACL_OK) { + return NNACL_ERR; + } + int output_axis_dim = input0_shape[axis]; + for (size_t i = 1; i < inputs_size; ++i) { + size_t input_i_shape_size = inputs[i]->shape_size_; + if (input_i_shape_size != input0_shape_size) { + return NNACL_PARAM_INVALID; + } + int shape_tmp[MAX_SHAPE_SIZE] = {0}; + size_t shape_tmp_size = 0; + ShapeSet(shape_tmp, &shape_tmp_size, inputs[i]->shape_, inputs[i]->shape_size_); + int data_type_judge = DataTypeJudge2(inputs[i], output); + if (data_type_judge != NNACL_OK) { + return data_type_judge; + } + int axis_tmp = shape_tmp[axis]; + erase_ret = ShapeErase(shape_tmp, &shape_tmp_size, axis); + if (erase_ret != NNACL_OK) { + return NNACL_ERR; + } + + output_axis_dim += axis_tmp; + } + int output_shape[MAX_SHAPE_SIZE]; + size_t output_shape_size = input0_shape_size; + for (size_t i = 0; i < input0_shape_size; i++) { + output_shape[i] = input0_shape[i]; + } + output_shape[axis] = output_axis_dim; + SetShapeArray(outputs[0], output_shape, output_shape_size); + return NNACL_OK; +} + +REG_INFER(ReduceConcatFusion, PrimType_Inner_ReduceConcatFusion, ReduceConcatFusionInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/reduce_concat_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/reduce_concat_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..6db11c930259ad4453ce6a21c3c414aba6a7745f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/reduce_concat_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_REDUCE_CONCAT_ONLINE_FUSION_INFER_H +#define MINDSPORE_NNACL_REDUCE_CONCAT_ONLINE_FUSION_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ReduceConcatFusionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SPLIT_REDUCE_CONCAT_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/reduce_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/reduce_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..0b764f9a8a33e59f4efaf909b96fad2d54cc1885 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/reduce_infer.c @@ -0,0 +1,140 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/reduce_infer.h" +#include "nnacl/infer/infer_register.h" + +int ReduceOnAllAxes(const TensorC *input, TensorC *output, int *out_shape, size_t out_shape_size, bool keep_dims) { + if (keep_dims) { + for (size_t i = 0; i < input->shape_size_; i++) { + ShapePush(out_shape, &out_shape_size, 1); + } + } + SetShapeArray(output, out_shape, out_shape_size); + output->data_type_ = input->data_type_; + return NNACL_OK; +} + +int ReduceOnSelectedAxes(const TensorC *input, size_t num_axes, const int *actual_axes, TensorC *output, int *out_shape, + size_t out_shape_size, bool keep_dims) { + for (size_t i = 0; i < input->shape_size_; i++) { + bool reduce_axis = false; + for (size_t idx = 0; idx < num_axes; ++idx) { + if ((size_t)(actual_axes[idx]) == i || (size_t)(actual_axes[idx]) + input->shape_size_ == i) { + reduce_axis = true; + break; + } + } + if (reduce_axis) { + if (keep_dims) { + ShapePush(out_shape, &out_shape_size, 1); + } + } else { + ShapePush(out_shape, &out_shape_size, input->shape_[i]); + } + } + SetShapeArray(output, out_shape, out_shape_size); + return NNACL_OK; +} + +bool IsReduceAllAxes(const TensorC *const *inputs, size_t inputs_size) { + if (inputs_size == 1) { + return true; + } + // When axes not given, reduce op will have two input tensor by the old version converter_lite tool. + if (inputs_size == 2 && inputs[1]->shape_size_ == 1 && inputs[1]->shape_[0] == 0) { + return true; + } + return false; +} + +int ReduceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + ReduceParameter *param = (ReduceParameter *)parameter; + NNACL_CHECK_NULL_RETURN_ERR(param); + if (input->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + bool keep_dims = param->keep_dims_; + int out_shape[MAX_SHAPE_SIZE] = {0}; + const size_t out_shape_size = 0; + if (IsReduceAllAxes(inputs, inputs_size)) { + return ReduceOnAllAxes(input, output, out_shape, out_shape_size, keep_dims); + } + + // get axes from input tensor + const TensorC *axes_input = inputs[1]; + NNACL_CHECK_NULL_RETURN_ERR(axes_input->data_); + + int num_axes; + if (axes_input->shape_size_ == 1) { + num_axes = axes_input->shape_[0]; + } else if (axes_input->shape_size_ == 0) { + num_axes = 1; + } else { + return NNACL_ERR; + } + if (num_axes > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int rank = (int)(input->shape_size_); + if (rank > MAX_SHAPE_SIZE || rank < 0) { + return NNACL_ERR; + } + int actual_axes[MAX_SHAPE_SIZE] = {0}; + size_t actual_axes_size = 0; + int ret = GetInt32DataFromTensor(axes_input, actual_axes, &actual_axes_size); + if (ret != NNACL_OK) { + return ret; + } + + if (param->reduce_to_end_) { + if (num_axes != 1) { + return NNACL_ERR; + } + + if (actual_axes[0] < -1 * rank || actual_axes[0] >= rank) { + return NNACL_PARAM_INVALID; + } + int begin_axis; + begin_axis = actual_axes[0] < 0 ? actual_axes[0] + rank : actual_axes[0]; + for (int i = begin_axis + 1; i < rank; ++i) { + ShapePush(actual_axes, &actual_axes_size, i); + } + num_axes = rank - begin_axis; + keep_dims = false; + } + // reduce on all axes + if (num_axes == 0) { + return ReduceOnAllAxes(input, output, out_shape, out_shape_size, keep_dims); + } + // reduce on selected axes + return ReduceOnSelectedAxes(input, (size_t)num_axes, actual_axes, output, out_shape, out_shape_size, keep_dims); +} + +REG_INFER(Reduce, PrimType_ReduceFusion, ReduceInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/reduce_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/reduce_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..abe278b2533c5b4b6cd0cd81d6971f7814d3fa6a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/reduce_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_REDUCE_INFER_H +#define MINDSPORE_NNACL_REDUCE_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/reduce_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ReduceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_REDUCE_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/reduce_scatter_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/reduce_scatter_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..654c44076957ea1edc60a0046b42ec8b432b8911 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/reduce_scatter_infer.c @@ -0,0 +1,56 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/reduce_scatter_infer.h" +#include "nnacl/infer/infer_register.h" + +int ReduceScatterInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + if (parameter == NULL || inputs[0] == NULL || outputs[0] == NULL) { + return NNACL_NULL_PTR; + } + SetDataTypeFormat(outputs[0], inputs[0]); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + ReduceScatterParameter *param = (ReduceScatterParameter *)parameter; + if (param->rank_size_ <= 0) { + return NNACL_INFER_INVALID; + } + + const TensorC *input_tensor = inputs[0]; + const int *in_shape = input_tensor->shape_; + TensorC *out_tensor = outputs[0]; + + if (in_shape[0] % param->rank_size_ != 0) { + return NNACL_INFER_INVALID; + } + + int out_shape[MAX_SHAPE_SIZE]; + size_t out_shape_size = 0; + out_shape[0] = in_shape[0] / param->rank_size_; + out_shape_size++; + for (int i = 1; i < input_tensor->shape_size_; i++) { + out_shape[i] = in_shape[i]; + out_shape_size++; + } + SetShapeArray(out_tensor, out_shape, out_shape_size); + + return NNACL_OK; +} + +REG_INFER(ReduceScatter, PrimType_ReduceScatter, ReduceScatterInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/reduce_scatter_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/reduce_scatter_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..1fe74eb0ed5adde61d481403dcba2769fa01ae9f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/reduce_scatter_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_REDUCE_SCATTER_INFER_H +#define MINDSPORE_NNACL_REDUCE_SCATTER_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/reduce_scatter_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ReduceScatterInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_REDUCE_SCATTER_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/reshape_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/reshape_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..1fa4192da655ec98ab4e5db7830c19027b82cce4 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/reshape_infer.c @@ -0,0 +1,221 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/reshape_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/op_base.h" +#include "nnacl/tensor_c_utils.h" + +int CalShape(const int *data, const TensorC *const *inputs, int *out_shape, size_t *out_shape_size, int shape_size) { + int input_count = NNACLGetElementNum(inputs[0]); + int index = 0; + int size = 1; + for (int i = 0; i < shape_size; i++) { + if ((int)(data[i]) == -1) { + index = i; + } else if ((int)(data[i]) == 0) { + size *= inputs[0]->shape_[i]; + } else { + size *= data[i]; + } + ShapePush(out_shape, out_shape_size, data[i]); + } + if (size == 0) { + return NNACL_ERR; + } + if ((int)(data[index]) == -1) { + if (index >= MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + out_shape[index] = input_count / size; + } + return NNACL_OK; +} + +int CalNewShape(const TensorC *in_tensor, int *out_shape, size_t out_shape_size) { + int in_shape_size = 1; + for (size_t i = 0; i < in_tensor->shape_size_; i++) { + in_shape_size *= in_tensor->shape_[i]; + } + int64_t infer_index = -1; + int out_shape_size_new = 1; + for (size_t i = 0; i < out_shape_size; i++) { + if (out_shape[i] == -1) { + if (infer_index == -1) { + infer_index = (int64_t)(i); + } else { + return NNACL_ERR; + } + } else if (out_shape[i] < 0) { + return NNACL_ERR; + } else if (out_shape[i] == 0) { + if (NNACLGetElementNum(in_tensor) != 0) { + out_shape[i] = in_tensor->shape_[i]; + out_shape_size_new *= out_shape[i]; + } else { + out_shape_size_new = 0; + break; + } + } else { + out_shape_size_new *= out_shape[i]; + } + } + if (infer_index == -1 && out_shape_size_new != in_shape_size) { + return NNACL_ERR; + } + if (infer_index != -1) { + if (out_shape_size_new == 0) { + return NNACL_ERR; + } + if (infer_index >= MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + out_shape[infer_index] = in_shape_size / out_shape_size_new; + } + return NNACL_OK; +} + +int CalShapeByType(const TensorC *const *inputs, size_t shape_size, int *out_shape, size_t *out_shape_size) { + const TensorC *shape_tensor = inputs[1]; + if (shape_size == 0) { + return NNACL_ERR; + } + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW((sizeof(int)), shape_size), NNACL_ERR); + int *data_int = (int *)malloc(sizeof(int) * shape_size); + if (data_int == NULL) { + return NNACL_ERR; + } + switch (shape_tensor->data_type_) { + case kNumberTypeInt8: { + int8_t *data = (int8_t *)(shape_tensor->data_); + for (size_t i = 0; i < shape_size; i++) { + data_int[i] = data[i]; + } + int cal_ret = CalShape(data_int, inputs, out_shape, out_shape_size, shape_size); + if (cal_ret != NNACL_OK) { + free(data_int); + return NNACL_ERR; + } + } break; + case kNumberTypeInt32: { + int32_t *data = (int32_t *)(shape_tensor->data_); + for (size_t i = 0; i < shape_size; i++) { + data_int[i] = data[i]; + } + int cal_ret = CalShape(data_int, inputs, out_shape, out_shape_size, shape_size); + if (cal_ret != NNACL_OK) { + free(data_int); + return NNACL_ERR; + } + } break; + case kNumberTypeInt64: { + int64_t *data = (int64_t *)(shape_tensor->data_); + for (size_t i = 0; i < shape_size; i++) { + data_int[i] = data[i]; + } + int cal_ret = CalShape(data_int, inputs, out_shape, out_shape_size, shape_size); + if (cal_ret != NNACL_OK) { + free(data_int); + return NNACL_ERR; + } + } break; + case kNumberTypeFloat: { + float *data = (float *)(shape_tensor->data_); + for (size_t i = 0; i < shape_size; i++) { + data_int[i] = data[i]; + } + int cal_ret = CalShape(data_int, inputs, out_shape, out_shape_size, shape_size); + if (cal_ret != NNACL_OK) { + free(data_int); + return NNACL_ERR; + } + } break; + case kNumberTypeUInt32: { + uint32_t *data = (uint32_t *)(shape_tensor->data_); + for (size_t i = 0; i < shape_size; i++) { + data_int[i] = (int)data[i]; + } + int cal_ret = CalShape(data_int, inputs, out_shape, out_shape_size, shape_size); + if (cal_ret != NNACL_OK) { + free(data_int); + return NNACL_ERR; + } + } break; + default: { + free(data_int); + return NNACL_ERR; + } + } + free(data_int); + return NNACL_OK; +} + +int ReshapeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + ReshapeParameter *param = (ReshapeParameter *)parameter; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + int out_shape[MAX_SHAPE_SIZE] = {0}; + size_t out_shape_size = 0; + if (inputs_size == 2) { + const TensorC *shape_tensor = inputs[1]; + if (NNACLGetElementNum(input) == 1) { + if (shape_tensor->data_ == NULL || (shape_tensor->shape_size_ == 1 && shape_tensor->shape_[0] == 0)) { + SetShapeArray(output, out_shape, out_shape_size); + return NNACL_OK; + } + } + + if (shape_tensor->data_ == NULL) { + return NNACL_INFER_INVALID; + } + int shape_size = NNACLGetElementNum(shape_tensor); + if (shape_size > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + int calRet = CalShapeByType(inputs, shape_size, out_shape, &out_shape_size); + if (calRet != NNACL_OK) { + return calRet; + } + } else if (inputs_size == 1) { + if (param->shape_dim_ > MAX_SHAPE_SIZE) { + return NNACL_PARAM_INVALID; + } + for (int i = 0; i < param->shape_dim_; ++i) { + ShapePush(out_shape, &out_shape_size, param->shape_[i]); + } + } else { + return NNACL_ERR; + } + int ret = CalNewShape(inputs[0], out_shape, out_shape_size); + if (ret != NNACL_OK) { + return ret; + } + SetShapeArray(output, out_shape, out_shape_size); + return NNACL_OK; +} + +REG_INFER(Reshape, PrimType_Reshape, ReshapeInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/reshape_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/reshape_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..484eecdcc216431aead1607399b4b00780468824 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/reshape_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_RESHAPE_INFER_H +#define MINDSPORE_NNACL_RESHAPE_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/reshape_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ReshapeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_RESHAPE_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/resize_grad_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/resize_grad_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..a90b499b2b216d3cc466e5b3a904618cfd12f1ac --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/resize_grad_infer.c @@ -0,0 +1,59 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/resize_grad_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensor_c_utils.h" + +int ResizeGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + if (input->format_ != Format_NHWC) { + return NNACL_FORMAT_ERROR; + } + if (input->shape_size_ != 4) { + return NNACL_ERR; + } + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + const TensorC *input_1 = inputs[1]; + if (input_1->shape_size_ == 4) { + ShapeSet(output->shape_, &output->shape_size_, input_1->shape_, input_1->shape_size_); + } else if (input_1->shape_size_ == 1 && input_1->shape_[0] == 2 && input_1->data_type_ == kNumberTypeInt32) { + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + int32_t *data = (int32_t *)(input_1->data_); + + ShapePush(output_shape, &output_shape_size, NNACLGetBatch(input)); + ShapePush(output_shape, &output_shape_size, data[0]); + ShapePush(output_shape, &output_shape_size, data[1]); + ShapePush(output_shape, &output_shape_size, NNACLGetChannel(input)); + SetShapeArray(output, output_shape, output_shape_size); + } else { + return NNACL_ERR; + } + return NNACL_OK; +} + +REG_INFER(ResizeGrad, PrimType_ResizeGrad, ResizeGradInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/resize_grad_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/resize_grad_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..a1ed88b48ffaee4c42d1315deec221a3973b7338 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/resize_grad_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_RESIZE_GRAD_INFER_H_ +#define MINDSPORE_NNACL_RESIZE_GRAD_INFER_H_ + +#include "nnacl/infer/common_infer.h" +#include "nnacl/fp32_grad/resize_grad.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ResizeGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_RESIZE_GRAD_INFER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/resize_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/resize_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..b689b05fd7b67124c40c6f8a7ea51719ee9a8b6d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/resize_infer.c @@ -0,0 +1,129 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/resize_infer.h" +#include +#include +#include "nnacl/infer/infer_register.h" +#include "nnacl/nnacl_common.h" +#include "nnacl/tensor_c_utils.h" + +int HandleTwoInputs(const TensorC *const *inputs, ResizeParameter *param) { + const TensorC *input = inputs[0]; + const TensorC *shape_tensor = inputs[1]; + if (shape_tensor->data_ == NULL) { + return NNACL_INFER_INVALID; + } + int shape_size = NNACLGetElementNum(shape_tensor); + void *origin_data = shape_tensor->data_; + if (origin_data == NULL) { + return NNACL_INFER_INVALID; + } + switch (shape_size) { + case 2: + case 4: { + int height_index = 0; + int width_index = 1; + if (shape_size == 4) { + height_index = kNHWC_H; + width_index = kNHWC_W; + } + if (shape_tensor->data_type_ == kNumberTypeInt32) { + int32_t *data = (int32_t *)(origin_data); + param->new_height_ = data[height_index]; + param->new_width_ = data[width_index]; + } else if (shape_tensor->data_type_ == kNumberTypeFloat32) { + float *data = (float *)(origin_data); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW((int)(data[height_index]), NNACLGetHeight(input), NNACL_ERRCODE_MUL_OVERFLOW); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW((int)(data[width_index]), NNACLGetWidth(input), NNACL_ERRCODE_MUL_OVERFLOW); + param->new_height_ = round(data[height_index] * NNACLGetHeight(input)); + param->new_width_ = round(data[width_index] * NNACLGetWidth(input)); + } else if (shape_tensor->data_type_ == kNumberTypeFloat16) { + uint16_t *data = (uint16_t *)(shape_tensor->data_); + float scale_height = ShortToFloat32(data[height_index]); + float scale_width = ShortToFloat32(data[width_index]); + param->new_height_ = round(scale_height * NNACLGetHeight(input)); + param->new_width_ = round(scale_width * NNACLGetWidth(input)); + } + break; + } + case 1: { + // caffe zoom_factor + int scale; + if (shape_tensor->data_type_ == kNumberTypeInt32) { + int *data = (int *)(origin_data); + scale = data[0]; + } else { + return NNACL_ERR; + } + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(NNACLGetHeight(input) - 1, scale - 1, NNACL_ERRCODE_MUL_OVERFLOW); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(NNACLGetWidth(input) - 1, scale - 1, NNACL_ERRCODE_MUL_OVERFLOW); + param->new_height_ = NNACLGetHeight(input) + (NNACLGetHeight(input) - 1) * (scale - 1); + param->new_width_ = NNACLGetWidth(input) + (NNACLGetWidth(input) - 1) * (scale - 1); + break; + } + default: { + return NNACL_ERR; + } + } + return NNACL_OK; +} + +int CalculateNewHeightAndWidth(const TensorC *const *inputs, size_t inputs_size, ResizeParameter *param) { + if (inputs_size == 2) { + return HandleTwoInputs(inputs, param); + } else if (inputs_size == 1) { + } else { + return NNACL_ERR; + } + return NNACL_OK; +} + +int ResizeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + if (input->format_ != Format_NHWC) { + return NNACL_FORMAT_ERROR; + } + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ != 0 && input->shape_size_ != 4) { + return NNACL_ERR; + } + ResizeParameter *param = (ResizeParameter *)parameter; + NNACL_CHECK_NULL_RETURN_ERR(param); + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + ShapePush(output_shape, &output_shape_size, NNACLGetBatch(input)); + int ret = CalculateNewHeightAndWidth(inputs, inputs_size, param); + if (ret == NNACL_OK) { + ShapePush(output_shape, &output_shape_size, param->new_height_); + ShapePush(output_shape, &output_shape_size, param->new_width_); + ShapePush(output_shape, &output_shape_size, NNACLGetChannel(input)); + SetShapeArray(output, output_shape, output_shape_size); + } + return ret; +} + +REG_INFER(Resize, PrimType_Resize, ResizeInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/resize_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/resize_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..aeef0a69e183f992cce60cdd0882fc6a5ac4d97f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/resize_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_RESIZE_INFER_H +#define MINDSPORE_NNACL_RESIZE_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/resize_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ResizeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_RESIZE_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/rfft_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/rfft_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..6759b0b997c11a64fea8a445c038d650dcd751f8 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/rfft_infer.c @@ -0,0 +1,47 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/rfft_infer.h" +#include "nnacl/infer/infer_register.h" + +int RfftInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + output->data_type_ = kNumberTypeComplex64; + output->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ >= MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + ShapeSet(output->shape_, &(output->shape_size_), input->shape_, input->shape_size_); + RfftParameter *param = (RfftParameter *)parameter; + if (input->shape_size_ < 1) { + return NNACL_ERR; + } + output->shape_[input->shape_size_ - 1] = param->fft_length_ / 2 + 1; + ShapePush(output->shape_, &(output->shape_size_), 2); + return NNACL_OK; +} + +REG_INFER(Rfft, PrimType_Rfft, RfftInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/rfft_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/rfft_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..19ef61ffee793cd721cb6e7d7b8557ccf2339a37 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/rfft_infer.h @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_RFFT_INFER_H +#define MINDSPORE_NNACL_RFFT_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct RfftParameter { + OpParameter op_parameter_; + int fft_length_; +} RfftParameter; + +int RfftInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_RFFT_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/roi_pooling_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/roi_pooling_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..f492a8034db9da7a34582a1d74b8210532e5dac6 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/roi_pooling_infer.c @@ -0,0 +1,51 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/roi_pooling_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensor_c_utils.h" + +int ROIPoolingInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, 2); + if (check_ret != NNACL_OK) { + return check_ret; + } + if (outputs_size < 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + + const TensorC *input = inputs[0]; + if (input->format_ != Format_NHWC) { + return NNACL_FORMAT_ERROR; + } + const TensorC *roi = inputs[1]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + ROIPoolingParameter *param = (ROIPoolingParameter *)parameter; + output->shape_size_ = 4; + output->shape_[0] = roi->shape_[0]; + output->shape_[1] = param->pooledH_; + output->shape_[2] = param->pooledW_; + output->shape_[3] = NNACLGetChannel(input); + return NNACL_OK; +} + +REG_INFER(ROIPooling, PrimType_ROIPooling, ROIPoolingInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/roi_pooling_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/roi_pooling_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..4f730eb64bda0af461bdd0d84bb8fdd42dd27c2b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/roi_pooling_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_ROI_POOLING_INFER_H +#define MINDSPORE_NNACL_ROI_POOLING_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/fp32/roi_pooling_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ROIPoolingInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ROI_POOLING_INFER_H diff --git a/mindspore-lite/src/extendrt/kernel/cpu/less_test_kernel_mod.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/scatter_nd_infer.c similarity index 37% rename from mindspore-lite/src/extendrt/kernel/cpu/less_test_kernel_mod.h rename to mindspore-lite/ops/kernel/cpu/nnacl/infer/scatter_nd_infer.c index 4dfd7a6312f55fe6e6ab0a29cda74947639fd611..744127f2c8b1ad92dbc07aac631d144e2c7b29d8 100644 --- a/mindspore-lite/src/extendrt/kernel/cpu/less_test_kernel_mod.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/scatter_nd_infer.c @@ -14,29 +14,32 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_MINDIR_LOADER_MINDIR_MODEL_LESS_TEST_KERNEL_MOD_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_MINDIR_LOADER_MINDIR_MODEL_LESS_TEST_KERNEL_MOD_H_ - -#include -#include - -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "common/common_utils.h" - -namespace mindspore::kernel { -class LessTestKernelMod : public NativeCpuKernelMod { - public: - LessTestKernelMod() = default; - ~LessTestKernelMod() override = default; - - explicit LessTestKernelMod(const std::string name) { kernel_name_ = name; } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - std::vector GetOpSupport() override { return {}; } -}; -} // namespace mindspore::kernel - -#endif // MINDSPORE_LITE_SRC_EXTENDRT_MINDIR_LOADER_MINDIR_MODEL_LESS_TEST_KERNEL_MOD_H_ +#include "nnacl/infer/scatter_nd_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensor_c_utils.h" + +int ScatterNdInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *shape = inputs[THIRD_INPUT]; + if (shape->data_ == NULL) { + return NNACL_INFER_INVALID; + } + const TensorC *update = inputs[SECOND_INPUT]; + TensorC *output = outputs[0]; + + SetDataTypeFormat(output, update); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + int *shape_data = (int *)(shape->data_); + NNACL_CHECK_TRUE_RET(NNACLGetElementNum(shape) <= MAX_SHAPE_SIZE, NNACL_ERR); + SetShapeArray(output, shape_data, (size_t)NNACLGetElementNum(shape)); + return NNACL_OK; +} + +REG_INFER(ScatterNd, PrimType_ScatterNd, ScatterNdInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/scatter_nd_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/scatter_nd_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..7b035b15a0ea9b266ad4b7570e5241add8870652 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/scatter_nd_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SCATTER_ND_INFER_H +#define MINDSPORE_NNACL_SCATTER_ND_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ScatterNdInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SCATTER_ND_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/scatter_nd_update_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/scatter_nd_update_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..f53c0a7d92a21f5bbdcc70437dda36f3a3081064 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/scatter_nd_update_infer.c @@ -0,0 +1,59 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/scatter_nd_update_infer.h" +#include "nnacl/infer/infer_register.h" + +int ScatterNdUpdateInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input_x = inputs[0]; + const TensorC *indices = inputs[1]; + const TensorC *updates = inputs[2]; + TensorC *output = outputs[0]; + if (updates->data_type_ != input_x->data_type_ || + (indices->data_type_ != kNumberTypeInt32 && indices->data_type_ != kNumberTypeInt64)) { + return NNACL_ERR; + } + SetDataTypeFormat(output, input_x); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (indices->shape_size_ < 2 || indices->shape_[indices->shape_size_ - 1] > input_x->shape_size_) { + return NNACL_ERR; + } + if (updates->shape_size_ != + (indices->shape_size_ - 1) + input_x->shape_size_ - indices->shape_[indices->shape_size_ - 1]) { + return NNACL_ERR; + } + for (int i = 0; i < updates->shape_size_; i++) { + if ((i < indices->shape_size_ - 1 && updates->shape_[i] != indices->shape_[i]) || + (i >= indices->shape_size_ - 1 && + updates->shape_[i] != + input_x->shape_[indices->shape_[indices->shape_size_ - 1] + i - indices->shape_size_ + 1])) { + return NNACL_ERR; + } + } + SetShapeArray(output, input_x->shape_, input_x->shape_size_); + return NNACL_OK; +} + +REG_INFER(ScatterNdUpdate, PrimType_ScatterNdUpdate, ScatterNdUpdateInferShape) +REG_INFER(TensorScatterAdd, PrimType_TensorScatterAdd, ScatterNdUpdateInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/scatter_nd_update_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/scatter_nd_update_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..fb4c531a50c74ae2790eba8cd0a554be6668fa91 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/scatter_nd_update_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SCATTER_ND_UPDATE_INFER_H +#define MINDSPORE_NNACL_SCATTER_ND_UPDATE_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ScatterNdUpdateInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SCATTER_ND_UPDATE_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/select_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/select_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..36c158e8e4ed0757f1b4934d29877f90bac53bc1 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/select_infer.c @@ -0,0 +1,62 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/select_infer.h" +#include +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensorlist_c_utils.h" + +int SelectInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = + CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, 2 * outputs_size + 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + for (size_t i = 0; i < outputs_size; i++) { + const TensorC *input = inputs[i + 1]; + TensorC *output = outputs[i]; + SetDataTypeFormat(output, input); + } + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + for (size_t i = 0; i < outputs_size; i++) { + const TensorC *input = inputs[i + 1]; + TensorC *output = outputs[i]; + if (input->data_type_ == kObjectTypeTensorType) { + TensorListC *input_tensorlist = (TensorListC *)(input); + TensorListC *output_tensorlist = (TensorListC *)(output); + output_tensorlist->element_shape_size_ = input_tensorlist->element_shape_size_; + for (size_t j = 0; j < input_tensorlist->element_shape_size_; j++) { + output_tensorlist->element_shape_[j] = input_tensorlist->element_shape_[j]; + } + output_tensorlist->max_elements_num_ = input_tensorlist->max_elements_num_; + output_tensorlist->tensors_data_type_ = input_tensorlist->tensors_data_type_; + output_tensorlist->element_num_ = input_tensorlist->element_num_; + + for (size_t j = 0; j < output_tensorlist->element_num_; j++) { + memcpy(&output_tensorlist->tensors_[j], &input_tensorlist->tensors_[j], sizeof(TensorC)); + } + } else { + SetShapeTensor(output, input); + } + } + return NNACL_OK; +} + +REG_INFER(Select, PrimType_Select, SelectInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/select_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/select_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..0f94c22e8d418cfdfce829e2b1c409553cbedbf4 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/select_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SELECT_INFER_H +#define MINDSPORE_NNACL_SELECT_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SelectInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SELECT_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/sgd_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/sgd_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..ce832a29faaffd7a2531d584381f859ac935b948 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/sgd_infer.c @@ -0,0 +1,43 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/sgd_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensor_c_utils.h" + +int SgdInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, 6); + if (check_ret != NNACL_OK) { + return check_ret; + } + + if (NNACLGetElementNum(inputs[0]) != NNACLGetElementNum(inputs[1]) || + NNACLGetElementNum(inputs[0]) != NNACLGetElementNum(inputs[3]) || NNACLGetElementNum(inputs[2]) != 1 || + NNACLGetElementNum(inputs[4]) != 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + if (outputs_size != 0) { + TensorC *out = outputs[0]; + SetDataTypeFormat(out, inputs[0]); + out->shape_size_ = 1; + out->shape_[0] = 1; + } + + return NNACL_OK; +} + +REG_INFER(SGD, PrimType_SGD, SgdInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/sgd_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/sgd_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..cb32363bce008ddfb227859f704c3fbf70251e23 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/sgd_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SGD_INFER_H +#define MINDSPORE_NNACL_SGD_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SgdInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SGD_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/shape_fusion_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/shape_fusion_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..f57d39ae6157410dbb545c8d28d77b1fafcc3382 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/shape_fusion_infer.c @@ -0,0 +1,97 @@ +/** + * Copyright 2021-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/shape_fusion_infer.h" +#include "nnacl/infer/infer_register.h" + +int CalculateOutput(const TensorC *in_tensor, const TensorC *matrix_tensor, TensorC *out_tensor, size_t input_len, + size_t origin_out_size) { + size_t out_size = out_tensor->shape_size_ == 0 ? 1 : (size_t)(out_tensor->shape_[0]); + if (out_size != origin_out_size && out_tensor->data_ != NULL) { + free(out_tensor->data_); + out_tensor->data_ = NULL; + } + size_t matrix_data_size = input_len * out_size * sizeof(float); + float *matrix_data = (float *)(malloc(matrix_data_size)); + NNACL_CHECK_NULL_RETURN_ERR(matrix_data); + if (matrix_tensor->data_type_ == kNumberTypeFloat32 || matrix_tensor->data_type_ == kNumberTypeFloat) { + memcpy(matrix_data, matrix_tensor->data_, matrix_data_size); +#ifdef ENABLE_FP16 + } else if (matrix_tensor->data_type_ == kNumberTypeFloat16) { + for (size_t i = 0; i < input_len * out_size; i++) { + matrix_data[i] = (float)(((float16_t *)(matrix_tensor->data_))[i]); + } +#endif + } else { + free(matrix_data); + return NNACL_ERR; + } + if (out_tensor->data_ == NULL) { + out_tensor->data_ = malloc(out_size * sizeof(int)); + } + int *data = (int *)out_tensor->data_; + if (data == NULL) { + free(matrix_data); + return NNACL_ERR; + } + memset(data, 0, out_size * sizeof(int)); + for (size_t i = 0; i < out_size; i++) { + for (size_t j = 0; j < input_len - 1; j++) { + data[i] += (int)(in_tensor->shape_[j] * matrix_data[i * input_len + j]); + } + data[i] += (int)(matrix_data[i * input_len + input_len - 1]); + } + free(matrix_data); + return NNACL_OK; +} + +int ShapeFusionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret != NNACL_OK) { + return check_ret; + } + NNACL_CHECK_TRUE_RET(inputs_size == outputs_size + 1, NNACL_INPUT_TENSOR_ERROR); + const TensorC *in_tensor = inputs[0]; + size_t input_len = in_tensor->shape_size_ + 1; + for (size_t out_idx = 0; out_idx < outputs_size; out_idx++) { + TensorC *out_tensor = outputs[out_idx]; + size_t origin_out_size = + out_tensor->data_ == NULL ? 0 : (out_tensor->shape_size_ == 0 ? 1 : (size_t)out_tensor->shape_[0]); + out_tensor->data_type_ = kNumberTypeInt32; + out_tensor->format_ = in_tensor->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + // calculate output tensor shape. + const TensorC *matrix_tensor = inputs[out_idx + 1]; + if (matrix_tensor->shape_size_ == 1) { + out_tensor->shape_size_ = 0; + out_tensor->shape_[0] = 0; + } else { + out_tensor->shape_size_ = 1; + out_tensor->shape_[0] = (int)(matrix_tensor->shape_[0]); + } + int ret = CalculateOutput(in_tensor, matrix_tensor, out_tensor, input_len, origin_out_size); + if (ret != NNACL_OK) { + return ret; + } + } + return NNACL_OK; +} + +REG_INFER(ShapeFusion, PrimType_Inner_ShapeFusion, ShapeFusionInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/shape_fusion_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/shape_fusion_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..a324618dfae3201b4c1552c0ae6f0696b707e2ef --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/shape_fusion_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SHAPE_FUSION_INFER_H_ +#define MINDSPORE_NNACL_SHAPE_FUSION_INFER_H_ + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ShapeFusionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SHAPE_FUSION_INFER_H_ diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/utils.cc b/mindspore-lite/ops/kernel/cpu/nnacl/infer/shape_infer.c similarity index 43% rename from mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/utils.cc rename to mindspore-lite/ops/kernel/cpu/nnacl/infer/shape_infer.c index a848cb5bfc44b878564a7971a8a53cb1be445113..05ade6dede0b0e831ae52a28ac652c2bc3c71ca1 100644 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/utils.cc +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/shape_infer.c @@ -13,25 +13,28 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "cxx_api/utils.h" -#include -#include "mindspore/ccsrc/include/common/utils/comm_manager.h" -#include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h" -namespace mindspore { -bool CreateGroupsByCkptFile(const std::string &file) { - parallel::GroupInfoMap group_info_map; - if (parallel::StrategyCheckpoint::GetInstance().LoadGroupInfo(file, &group_info_map) != parallel::SUCCESS) { - return false; - } - for (const auto &[group_name, rank_ids] : group_info_map) { - if (!CommManager::GetInstance().CreateGroupSync(group_name, rank_ids)) { - MS_LOG(ERROR) << "Create group " << group_name << " rank ids " << rank_ids << " failed."; - return false; - } +#include "nnacl/infer/shape_infer.h" +#include "nnacl/infer/infer_register.h" + +int ShapeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; } - MS_LOG(INFO) << "Create groups by checkpoint file success"; - return true; + const TensorC *in_tensor = inputs[0]; + TensorC *out_tensor = outputs[0]; + + out_tensor->data_type_ = kNumberTypeInt32; + out_tensor->format_ = in_tensor->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + out_tensor->shape_size_ = 1; + out_tensor->shape_[0] = (int)(in_tensor->shape_size_); + return NNACL_OK; } -} // namespace mindspore + +REG_INFER(Shape, PrimType_Shape, ShapeInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/shape_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/shape_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..99dd285b9a8c53ea0071943739a8635fb058ea60 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/shape_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SHAPE_INFER_H +#define MINDSPORE_NNACL_SHAPE_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ShapeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SHAPE_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/size_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/size_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..3a7c22f902cf43dc54b56be0e63ed22b1e109abc --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/size_infer.c @@ -0,0 +1,41 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/size_infer.h" +#include "nnacl/infer/infer_register.h" + +int SizeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *in_tensor = inputs[0]; + TensorC *out_tensor = outputs[0]; + out_tensor->data_type_ = kNumberTypeInt32; + out_tensor->format_ = in_tensor->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + out_tensor->shape_size_ = 0; + out_tensor->shape_[0] = 1; + + return NNACL_OK; +} + +REG_INFER(SizeOp, PrimType_Size, SizeInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/size_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/size_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..23481b69203549ee190318ac0989c4d5d076e164 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/size_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SIZE_INFER_H +#define MINDSPORE_NNACL_SIZE_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SizeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SIZE_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/slice_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/slice_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..1b59023b6593abfb975255ba4a3c735ad1e9aaa0 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/slice_infer.c @@ -0,0 +1,126 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/slice_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensor_c_utils.h" + +static bool CheckInputsDataType(const TensorC *const *inputs, size_t inputs_size) { + // not support data_type of slice's begin and size is not int32 + if (inputs_size >= 2) { + if (inputs[1]->data_type_ != kNumberTypeInt32) { + return false; + } + } + if (inputs_size == 3) { + if (inputs[2]->data_type_ != kNumberTypeInt32) { + return false; + } + } + return true; +} + +int InitBeginAndSizeParam(const TensorC *const *inputs, int *begin, int *size, int param_length) { + /* init begin parameter */ + int slice_begin_size = NNACLGetElementNum(inputs[1]); + int *begin_ptr = (int *)(inputs[1]->data_); + if (slice_begin_size != param_length || begin_ptr == NULL) { + return NNACL_INFER_INVALID; + } + if (slice_begin_size > MAX_AXIS_SIZE) { + return NNACL_ERR; + } + for (int i = 0; i < slice_begin_size; i++) { + begin[i] = begin_ptr[i]; + } + + /* init size parameter */ + int slice_size_size = NNACLGetElementNum(inputs[2]); + int *size_ptr = (int *)(inputs[2]->data_); + if (slice_size_size != param_length || size_ptr == NULL) { + return NNACL_INFER_INVALID; + } + if (slice_size_size > MAX_AXIS_SIZE) { + return NNACL_ERR; + } + for (int i = 0; i < slice_size_size; i++) { + size[i] = size_ptr[i]; + } + return NNACL_OK; +} + +int SliceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 1); + if (ret != NNACL_OK) { + return ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + + if (!CheckInputsDataType(inputs, inputs_size)) { + return NNACL_ERR; + } + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + SliceParameter *param = (SliceParameter *)parameter; + int param_length = (int)(input->shape_size_); + output->shape_size_ = input->shape_size_; + int begin[MAX_SHAPE_SIZE]; + int size[MAX_SHAPE_SIZE]; + + ret = InitBeginAndSizeParam(inputs, begin, size, param_length); + if (ret != NNACL_OK) { + return ret; + } + + for (int32_t i = 0; i < param_length; ++i) { + if (param->axis_[i] < 0) { + NNACL_CHECK_INT_ADD_NOT_OVERFLOW(param->axis_[i], (int)input->shape_size_, NNACL_PARAM_INVALID); + param->axis_[i] += (int)input->shape_size_; + } + NNACL_CHECK_TRUE_RET(param->axis_[i] >= 0 && param->axis_[i] < param_length, NNACL_PARAM_INVALID); + begin[param->axis_[i]] = begin[i]; + size[param->axis_[i]] = size[i]; + } + + for (int32_t i = 0; i < param_length; ++i) { + if (size[i] < 0 && size[i] != -1) { + return NNACL_PARAM_INVALID; + } + if (begin[i] < 0) { + return NNACL_PARAM_INVALID; + } + if (input->shape_[i] < begin[i]) { + return NNACL_PARAM_INVALID; + } + if (size[i] > (input->shape_[i] - begin[i])) { + return NNACL_PARAM_INVALID; + } + + output->shape_[i] = size[i] < 0 ? input->shape_[i] - begin[i] : size[i]; + } + return NNACL_OK; +} + +REG_INFER(Slice, PrimType_SliceFusion, SliceInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/slice_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/slice_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..50e06cd48722cc941800a871e9793630262bac4e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/slice_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SLICE_INFER_H +#define MINDSPORE_NNACL_SLICE_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/slice_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SliceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SLICE_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/softmax_cross_entropy_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/softmax_cross_entropy_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..896e66c60d781c09d4d514ca89d47ddc476b2dfe --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/softmax_cross_entropy_infer.c @@ -0,0 +1,43 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/softmax_cross_entropy_infer.h" +#include "nnacl/infer/infer_register.h" + +int SoftmaxCrossEntropyInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *in0 = inputs[0]; + TensorC *out = outputs[0]; + + out->shape_size_ = 2; + out->shape_[0] = in0->shape_[0]; + out->shape_[1] = 1; + SetDataTypeFormat(out, in0); + + if (1 < outputs_size) { + TensorC *grads = outputs[1]; + SetShapeTensor(grads, in0); + SetDataTypeFormat(grads, in0); + } + return NNACL_OK; +} + +REG_INFER(SoftmaxCrossEntropyWithLogits, PrimType_SoftmaxCrossEntropyWithLogits, SoftmaxCrossEntropyInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/softmax_cross_entropy_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/softmax_cross_entropy_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..c5d89519ef87e5820604d2c58b25a515ef0ece94 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/softmax_cross_entropy_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SOFTMAX_CROSS_ENTROPY_INFER_H +#define MINDSPORE_NNACL_SOFTMAX_CROSS_ENTROPY_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SoftmaxCrossEntropyInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SOFTMAX_ENTROPY_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/softmax_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/softmax_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..300362fd7b562199c1abcdb6ae5f37a8fccabaae --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/softmax_infer.c @@ -0,0 +1,49 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/softmax_infer.h" +#include "nnacl/infer/infer_register.h" + +int SoftMaxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + + output->data_type_ = input->data_type_; + output->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + // there is a model with an 8-dim input, which runs on ascend910. + if (input->shape_size_ > DIMENSION_8D) { + return NNACL_ERR; + } + + SoftmaxParameter *param = (SoftmaxParameter *)parameter; + NNACL_CHECK_NULL_RETURN_ERR(param); + if (param->axis_ < (-1 * (int)(input->shape_size_)) || param->axis_ > (int)(input->shape_size_)) { + return NNACL_PARAM_INVALID; + } + SetShapeTensor(output, input); + return NNACL_OK; +} + +REG_INFER(Softmax, PrimType_Softmax, SoftMaxInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/softmax_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/softmax_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..9ce561666f90f454387858bcf62e7975d7da6661 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/softmax_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SOFTMAX_INFER_H +#define MINDSPORE_NNACL_SOFTMAX_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/softmax_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SoftMaxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SOFTMAX_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/space_to_batch_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/space_to_batch_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..98d765418b82da562759c18eb1f85a559c6ad9e5 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/space_to_batch_infer.c @@ -0,0 +1,64 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/space_to_batch_infer.h" +#include "nnacl/infer/infer_register.h" + +int SpaceToBatchInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + if (input->format_ != Format_NHWC) { + return NNACL_FORMAT_ERROR; + } + SetDataTypeFormat(outputs[0], input); + SpaceToBatchParameter *param = (SpaceToBatchParameter *)parameter; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ != 4) { + return NNACL_ERR; + } + + int *block_shape = param->block_sizes_; + int block_shape_size = param->m_; + int *paddings = param->paddings_; + int padding_left = 0; + int padding_right = 0; + int block_w = 1; + if (block_shape_size == 2) { + padding_left = paddings[2]; + padding_right = paddings[3]; + block_w = block_shape[1]; + } + + NNACL_CHECK_ZERO_RETURN_ERR(block_shape[0]); + NNACL_CHECK_ZERO_RETURN_ERR(block_w); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(block_shape[0], block_w, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(input->shape_[kNHWC_N], block_shape[0] * block_w, NNACL_ERR); + outputs[0]->shape_[kNHWC_N] = input->shape_[kNHWC_N] * (block_shape[0] * block_w); + outputs[0]->shape_[kNHWC_H] = (input->shape_[kNHWC_H] + paddings[0] + paddings[1]) / block_shape[0]; + outputs[0]->shape_[kNHWC_W] = (input->shape_[kNHWC_W] + padding_left + padding_right) / block_w; + outputs[0]->shape_[kNHWC_C] = input->shape_[kNHWC_C]; + outputs[0]->shape_size_ = input->shape_size_; + return NNACL_OK; +} + +REG_INFER(SpaceToBatch, PrimType_SpaceToBatch, SpaceToBatchInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/space_to_batch_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/space_to_batch_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..393a958c683dfd1a607693b2ad1ebc4ae859c014 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/space_to_batch_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SPACE_TO_BATCH_INFER_H +#define MINDSPORE_NNACL_SPACE_TO_BATCH_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/fp32/space_to_batch_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SpaceToBatchInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SPACE_TO_BATCH_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/space_to_batch_nd_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/space_to_batch_nd_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..2d91a4c38972a0a3e05a565cb03ecfa5997faf37 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/space_to_batch_nd_infer.c @@ -0,0 +1,143 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/space_to_batch_nd_infer.h" +#include +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensor_c_utils.h" + +int SpaceSetOutputShapeFromParam(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, const OpParameter *parameter) { + const TensorC *input = inputs[0]; + if (input->format_ != Format_NHWC) { + return NNACL_FORMAT_ERROR; + } + if (input->shape_size_ != 4) { + return NNACL_ERR; + } + const SpaceToBatchParameter *param = (const SpaceToBatchParameter *)parameter; + const int *block_shape = param->block_sizes_; + int block_shape_size = param->m_; + const int *padding = param->paddings_; + int padding_left = 0; + int padding_right = 0; + int block_w = 1; + if (block_shape_size == 2) { + padding_left = padding[2]; + padding_right = padding[3]; + block_w = block_shape[1]; + } + if (input->shape_[kNHWC_N] == 0 || block_shape[0] * block_w > INT_MAX / input->shape_[kNHWC_N]) { + return NNACL_ERR; + } + outputs[0]->shape_[kNHWC_N] = input->shape_[kNHWC_N] * block_shape[0] * block_w; + if (padding[0] + padding[1] > INT_MAX - input->shape_[kNHWC_H]) { + return NNACL_ERR; + } + if (block_shape[0] == 0 || block_w == 0) { + return NNACL_ERR; + } + outputs[0]->shape_[kNHWC_H] = (input->shape_[kNHWC_H] + padding[0] + padding[1]) / block_shape[0]; + if (padding_left + padding_right > INT_MAX - input->shape_[kNHWC_W]) { + return NNACL_ERR; + } + outputs[0]->shape_[kNHWC_W] = (input->shape_[kNHWC_W] + padding_left + padding_right) / block_w; + if (input->shape_size_ > 3) { + outputs[0]->shape_[kNHWC_C] = input->shape_[kNHWC_C]; + } + outputs[0]->shape_size_ = input->shape_size_; + return NNACL_OK; +} + +int SpaceSetOutputShapeFromInput(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + const TensorC *input = inputs[0]; + if (input->shape_size_ != 4) { + return NNACL_ERR; + } + if (NNACLGetElementNum(inputs[2]) != 4) { + return NNACL_ERR; + } + int *block_shape = (int *)(inputs[1]->data_); + int *padding = (int *)(inputs[2]->data_); + int padding_left = 0; + int padding_right = 0; + int block_w = 1; + if (NNACLGetElementNum(inputs[1]) == 2) { + padding_left = padding[2]; + padding_right = padding[3]; + block_w = block_shape[1]; + } + int output_shape[MAX_SHAPE_SIZE]; + size_t output_shape_size = input->shape_size_; + if (input->shape_[kNHWC_N] == 0 || block_shape[0] * block_w > INT_MAX / input->shape_[kNHWC_N]) { + return NNACL_ERR; + } + output_shape[kNHWC_N] = input->shape_[kNHWC_N] * block_shape[0] * block_w; + if (padding[0] + padding[1] > INT_MAX - input->shape_[kNHWC_H]) { + return NNACL_ERR; + } + if (block_shape[0] == 0 || block_w == 0) { + return NNACL_ERR; + } + output_shape[kNHWC_H] = (input->shape_[kNHWC_H] + padding[0] + padding[1]) / block_shape[0]; + if (padding_left + padding_right > INT_MAX - input->shape_[kNHWC_W]) { + return NNACL_ERR; + } + output_shape[kNHWC_W] = (input->shape_[kNHWC_W] + padding_left + padding_right) / block_w; + if (input->shape_size_ > 3) { + output_shape[kNHWC_C] = input->shape_[kNHWC_C]; + } + SetShapeArray(outputs[0], output_shape, output_shape_size); + return NNACL_OK; +} + +int SpaceToBatchNdInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 1, 3, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + if (input->format_ != Format_NHWC) { + return NNACL_ERR; + } + outputs[0]->data_type_ = input->data_type_; + outputs[0]->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (inputs_size == 1) { + int ret = SpaceSetOutputShapeFromParam(inputs, inputs_size, outputs, outputs_size, parameter); + if (ret != NNACL_OK) { + return ret; + } + } + if (inputs_size == 3) { + if (inputs[1]->data_ == NULL || inputs[2]->data_ == NULL) { + return NNACL_INFER_INVALID; + } + int ret = SpaceSetOutputShapeFromInput(inputs, inputs_size, outputs, outputs_size, parameter); + if (ret != NNACL_OK) { + return ret; + } + } + return NNACL_OK; +} + +REG_INFER(SpaceToBatchND, PrimType_SpaceToBatchND, SpaceToBatchNdInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/space_to_batch_nd_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/space_to_batch_nd_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..d63b46d48056e68574704aeffc778b8ea88a61dc --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/space_to_batch_nd_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SPACE_TO_BATCH_ND_INFER_H +#define MINDSPORE_NNACL_SPACE_TO_BATCH_ND_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/fp32/space_to_batch_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SpaceToBatchNdInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SPACE_TO_BATCH_ND_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/space_to_depth_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/space_to_depth_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..1b28c03f01d17529d0babdafc89421dfca21c6bf --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/space_to_depth_infer.c @@ -0,0 +1,61 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/space_to_depth_infer.h" +#include +#include "nnacl/infer/infer_register.h" + +int SpaceToDepthInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + if (input->format_ != Format_NHWC) { + return NNACL_FORMAT_ERROR; + } + SetDataTypeFormat(outputs[0], input); + SpaceToDepthParameter *param = (SpaceToDepthParameter *)parameter; + NNACL_CHECK_NULL_RETURN_ERR(param); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ != 4) { + return NNACL_ERR; + } + + int32_t block_size = param->block_size_; + if (block_size == 0) { + return NNACL_ERR; + } + if (input->shape_[kNHWC_H] % block_size != 0 || input->shape_[kNHWC_H] == 0 || + input->shape_[kNHWC_W] % block_size != 0 || input->shape_[kNHWC_W] == 0) { + return NNACL_ERR; + } + outputs[0]->shape_[kNHWC_N] = input->shape_[kNHWC_N]; + outputs[0]->shape_[kNHWC_H] = input->shape_[kNHWC_H] / block_size; + outputs[0]->shape_[kNHWC_W] = input->shape_[kNHWC_W] / block_size; + if (input->shape_[kNHWC_C] == 0 || block_size * block_size > INT_MAX / input->shape_[kNHWC_C]) { + return NNACL_ERR; + } + outputs[0]->shape_[kNHWC_C] = input->shape_[kNHWC_C] * (block_size * block_size); + outputs[0]->shape_size_ = input->shape_size_; + return NNACL_OK; +} + +REG_INFER(SpaceToDepth, PrimType_SpaceToDepth, SpaceToDepthInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/space_to_depth_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/space_to_depth_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..809a637ccc29a9b2fe87b868a251abb3e6fe766b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/space_to_depth_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SPACE_TO_DEPTH_INFER_H +#define MINDSPORE_NNACL_SPACE_TO_DEPTH_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/space_to_depth_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SpaceToDepthInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SPACE_TO_DEPTH_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/sparse_fill_empty_rows_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/sparse_fill_empty_rows_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..3f0a966b9eaf7918fdb70eae61913042444b727a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/sparse_fill_empty_rows_infer.c @@ -0,0 +1,49 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/sparse_fill_empty_rows_infer.h" +#include "nnacl/infer/infer_register.h" + +int SparseFillEmptyRowsInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentNullInputSize(inputs, inputs_size, outputs, outputs_size, parameter, C4NUM); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input0 = inputs[0]; + TensorC *output0 = outputs[0]; + SetDataTypeFormat(output0, input0); + + const TensorC *input1 = inputs[1]; + TensorC *output1 = outputs[1]; + SetDataTypeFormat(output1, input1); + + TensorC *output2 = outputs[C2NUM]; + SetDataTypeFormat(output2, input0); + output2->data_type_ = kNumberTypeBool; + + if (outputs_size == C4NUM) { + TensorC *output3 = outputs[C3NUM]; + SetDataTypeFormat(output3, input0); + } + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + return NNACL_INFER_INVALID; +} + +REG_INFER(SparseFillEmptyRows, PrimType_SparseFillEmptyRows, SparseFillEmptyRowsInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/sparse_fill_empty_rows_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/sparse_fill_empty_rows_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..bba313008d899565a4bcb74d3827e9e59947e079 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/sparse_fill_empty_rows_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SPARSE_FILL_EMPTY_ROWS_H +#define MINDSPORE_NNACL_SPARSE_FILL_EMPTY_ROWS_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SparseFillEmptyRowsInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SPARSE_FILL_EMPTY_ROWS_H diff --git a/mindspore-lite/tools/graph_kernel/converter/expanders/scale_fusion.cc b/mindspore-lite/ops/kernel/cpu/nnacl/infer/sparse_reshape_infer.c similarity index 31% rename from mindspore-lite/tools/graph_kernel/converter/expanders/scale_fusion.cc rename to mindspore-lite/ops/kernel/cpu/nnacl/infer/sparse_reshape_infer.c index 7cfb0f56972a181c4df1b803080e4fcb9773c3c9..7419e45f3711353dadabb4e724e527d87fde51e4 100644 --- a/mindspore-lite/tools/graph_kernel/converter/expanders/scale_fusion.cc +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/sparse_reshape_infer.c @@ -14,41 +14,40 @@ * limitations under the License. */ -#include - -#include "backend/common/graph_kernel/expanders/op_desc_registry.h" -#include "tools/graph_kernel/converter/expanders/activation.h" -#include "mindapi/base/types.h" -#include "ir/dtype.h" - -namespace mindspore::graphkernel::expanders { -constexpr size_t kInputIdx = 0; -constexpr size_t kScaleIdx = 1; -constexpr size_t kOffsetIdx = 2; -class ScaleFusion : public OpDesc { - public: - ScaleFusion() { - (void)validators_.emplace_back(std::make_unique(ActivationType::NO_ACTIVATION)); +#include "nnacl/infer/sparse_reshape_infer.h" +#include "nnacl/infer/infer_register.h" + +int SparseReshapeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, C3NUM, C2NUM); + if (check_ret != NNACL_OK) { + return check_ret; } - ~ScaleFusion() = default; - - protected: - bool CheckInputs() override { - auto axis = GetAxisList(attrs_["axis"])[0]; - size_t input_shape_size = inputs_info_[kInputIdx].shape.size(); - size_t scale_shape_size = inputs_info_[kScaleIdx].shape.size(); - axis = axis < 0 ? axis + SizeToLong(input_shape_size) : axis; - return (LongToSize(axis) + scale_shape_size) == input_shape_size; + + const TensorC *in_indices_tensor = inputs[0]; + TensorC *out_indices_tensor = outputs[0]; + SetDataTypeFormat(out_indices_tensor, in_indices_tensor); + + const TensorC *in_out_shape_tensor = inputs[C2NUM]; + TensorC *out_shape_tensor = outputs[C1NUM]; + SetDataTypeFormat(out_shape_tensor, in_out_shape_tensor); + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; } - NodePtrList Expand(const NodePtrList &inputs) override { - const auto &input_x = inputs[kInputIdx]; - const auto &input_scale = inputs[kScaleIdx]; - const auto &input_offset = inputs[kOffsetIdx]; - auto mul = gb.Mul(input_x, input_scale); - auto mul_add = gb.Add(mul, input_offset); - return {mul_add}; + SetShapeArray(out_shape_tensor, in_out_shape_tensor->shape_, in_out_shape_tensor->shape_size_); + + int out_indices_shape[MAX_SHAPE_SIZE] = {0}; + out_indices_shape[0] = in_indices_tensor->shape_[0]; + size_t out_indices_shape_size = 1; + + for (int i = 0; i < in_out_shape_tensor->shape_size_; ++i) { + out_indices_shape[i + 1] = in_out_shape_tensor->shape_[i]; + out_indices_shape_size++; } -}; -EXPANDER_OP_DESC_REGISTER("ScaleFusion", ScaleFusion); -} // namespace mindspore::graphkernel::expanders + SetShapeArray(out_indices_tensor, out_indices_shape, out_indices_shape_size); + return NNACL_OK; +} + +REG_INFER(SparseReshape, PrimType_SparseReshape, SparseReshapeInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/sparse_reshape_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/sparse_reshape_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..1f4059b2ba8bbbddfb12e4897d6b1a1a00de91b8 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/sparse_reshape_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SPARSE_RESHAPE_INFER_H +#define MINDSPORE_NNACL_SPARSE_RESHAPE_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SparseReshapeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SPARSE_RESHAPE_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/sparse_segment_sum_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/sparse_segment_sum_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..37103ff0e6c54f7ef29869b250af51bf312d376e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/sparse_segment_sum_infer.c @@ -0,0 +1,37 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/sparse_segment_sum_infer.h" +#include "nnacl/infer/infer_register.h" + +int SparseSegmentSumInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, C3NUM, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + return NNACL_OK; +} + +REG_INFER(SparseSegmentSum, PrimType_SparseSegmentSum, SparseSegmentSumInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/sparse_segment_sum_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/sparse_segment_sum_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..f973df75ecebec36736c8b4c4f4d7bab054d80b7 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/sparse_segment_sum_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SPARSE_SEGMENT_SUM_INFER_H +#define MINDSPORE_NNACL_SPARSE_SEGMENT_SUM_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SparseSegmentSumInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SPARSE_SEGMENT_SUM_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/sparse_softmax_cross_entropy_with_logits_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/sparse_softmax_cross_entropy_with_logits_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..4d2065ced3e35834cd7dc6c236b5c361c30a3e55 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/sparse_softmax_cross_entropy_with_logits_infer.c @@ -0,0 +1,45 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/sparse_softmax_cross_entropy_with_logits_infer.h" +#include "nnacl/fp32_grad/softmax_grad.h" +#include "nnacl/infer/infer_register.h" + +int SparseSoftmaxCrossEntropyWithLogitsInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *in0 = inputs[0]; + TensorC *out = outputs[0]; + + SoftmaxCrossEntropyParameter *param = (SoftmaxCrossEntropyParameter *)parameter; + if (param->is_grad_ != 0) { + SetShapeTensor(out, in0); + SetDataTypeFormat(out, in0); + } else { + out->shape_size_ = 1; + out->shape_[0] = 1; + SetDataTypeFormat(out, in0); + } + + return NNACL_OK; +} + +REG_INFER(SparseSoftmaxCrossEntropyWithLogits, PrimType_SparseSoftmaxCrossEntropyWithLogits, + SparseSoftmaxCrossEntropyWithLogitsInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/sparse_softmax_cross_entropy_with_logits_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/sparse_softmax_cross_entropy_with_logits_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..614d1436afbb526ec688989b7f99548031829294 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/sparse_softmax_cross_entropy_with_logits_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_INFER_H_ +#define MINDSPORE_NNACL_INFER_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_INFER_H_ + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SparseSoftmaxCrossEntropyWithLogitsInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_INFER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/sparse_to_dense_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/sparse_to_dense_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..2fa5c3aeb0feb1b94bc6a6f048b18f92da9e5e10 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/sparse_to_dense_infer.c @@ -0,0 +1,51 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/sparse_to_dense_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensor_c_utils.h" + +int SparseToDenseInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + TensorC *output = outputs[0]; + if (inputs_size < 3) { + return NNACL_INPUT_TENSOR_ERROR; + } + const TensorC *input1 = inputs[1]; + SetDataTypeFormat(output, input1); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + int *input1_data = (int *)(input1->data_); + int data_num = NNACLGetElementNum(input1); + if (input1_data == 0 || data_num > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + for (int i = 0; i < data_num; i++) { + ShapePush(output_shape, &output_shape_size, input1_data[i]); + } + SetShapeArray(output, output_shape, output_shape_size); + return NNACL_OK; +} + +REG_INFER(SparseToDense, PrimType_SparseToDense, SparseToDenseInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/sparse_to_dense_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/sparse_to_dense_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..a520734e5649f6fc7e1197ed19bf0bf89fa0ddc6 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/sparse_to_dense_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SPACE_TO_DENSE_INFER_H +#define MINDSPORE_NNACL_SPACE_TO_DENSE_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SparseToDenseInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SPACE_TO_DENSE_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/splice_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/splice_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..1b865e84cd2a7b2f14ce8fd8611852bfb77ee5fc --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/splice_infer.c @@ -0,0 +1,56 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/splice_infer.h" +#include "nnacl/infer/infer_register.h" + +int SpliceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (input->shape_size_ != DIMENSION_3D) { + return NNACL_INPUT_TENSOR_ERROR; + } + SpliceParameter *param = (SpliceParameter *)parameter; + if (param == NULL) { + return NNACL_NULL_PTR; + } + int out_dim = param->output_dim_; + ShapeSet(output->shape_, &output->shape_size_, input->shape_, input->shape_size_); + + if (param->context_dim_ == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + if (param->forward_indexes_dim_ % param->context_dim_ != 0) { + return NNACL_PARAM_INVALID; + } + int out_size = param->forward_indexes_dim_ / param->context_dim_; + output->shape_[DIMENSION_1D] = out_size; + output->shape_[DIMENSION_2D] = out_dim; + return NNACL_OK; +} + +REG_INFER(Splice, PrimType_Splice, SpliceInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/splice_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/splice_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..54b6d3a1e91a554d043fbac37da728e4a8df470f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/splice_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_INFER_SPLICE_INFER_H_ +#define MINDSPORE_NNACL_INFER_SPLICE_INFER_H_ +#include "nnacl/infer/common_infer.h" +#include "nnacl/splice_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SpliceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_SPLICE_INFER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/split_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/split_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..1d3aabe17df14cd2efdf0cd7dafc23daa0730ddc --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/split_infer.c @@ -0,0 +1,120 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/split_infer.h" +#include "nnacl/infer/infer_register.h" + +int UpdateSplitSize(const TensorC *const *inputs, size_t inputs_size, SplitParameter *param) { + // get split size from the second input. + if (inputs_size == DIMENSION_2D && inputs[SECOND_INPUT]->data_ != NULL) { + if (inputs[SECOND_INPUT]->data_type_ != kNumberTypeInt32) { + return NNACL_ERR; + } + int split_count = 1; + for (size_t i = 0; i < inputs[SECOND_INPUT]->shape_size_; i++) { + split_count *= inputs[SECOND_INPUT]->shape_[i]; + } + param->split_count_ = split_count; + for (int i = 0; i < split_count; i++) { + param->split_sizes_[i] = ((int *)(inputs[SECOND_INPUT]->data_))[i]; + } + } + if (param->split_count_ == 0) { + const TensorC *input = inputs[0]; + int32_t split_chunk_size = UP_DIV(input->shape_[param->split_dim_], param->num_split_); + for (int i = 0; i < param->num_split_; ++i) { + if (i != param->num_split_ - 1) { + param->split_sizes_[i] = split_chunk_size; + } else { + param->split_sizes_[i] = input->shape_[param->split_dim_] - split_chunk_size * i; + } + } + } + return NNACL_OK; +} + +int SetSplitOutputShape(const TensorC *input, TensorC **outputs, SplitParameter *param) { + for (int i = 0; i < param->num_split_; ++i) { + int output_shape[MAX_SHAPE_SIZE]; + size_t output_shape_size = 0; + ShapeSet(output_shape, &output_shape_size, input->shape_, input->shape_size_); + int split_dim_i = input->shape_[param->split_dim_]; + if (i == param->num_split_ - 1 && param->split_sizes_[i] == -1) { + if (param->num_split_ - 1 < 0) { + return NNACL_ERR; + } + for (int j = 0; j < param->num_split_ - 1; ++j) { + split_dim_i -= param->split_sizes_[j]; + } + param->split_sizes_[i] = split_dim_i; + } else { + split_dim_i = param->split_sizes_[i]; + } + NNACL_CHECK_TRUE_RET(split_dim_i >= 0 && split_dim_i <= input->shape_[param->split_dim_], NNACL_ERR); + output_shape[param->split_dim_] = split_dim_i; + SetShapeArray(outputs[i], output_shape, output_shape_size); + SetDataTypeFormat(outputs[i], input); + } + return NNACL_OK; +} + +int SplitInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + for (size_t i = 0; i < outputs_size; i++) { + SetDataTypeFormat(outputs[i], input); + } + + SplitParameter *param = (SplitParameter *)parameter; + + int num_split = param->num_split_ == 0 ? (int)(outputs_size) : param->num_split_; + if (num_split == 0) { + return NNACL_ERR; + } + param->num_split_ = num_split; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (input->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int split_dim = param->split_dim_ < 0 ? ((int)(input->shape_size_)) + param->split_dim_ : param->split_dim_; + if (split_dim >= (int)(input->shape_size_) || split_dim < 0) { + return NNACL_ERR; + } + param->split_dim_ = split_dim; + if ((int)(outputs_size) != num_split) { + return NNACL_ERR; + } + + int ret = UpdateSplitSize(inputs, inputs_size, param); + if (ret != NNACL_OK) { + return ret; + } + ret = SetSplitOutputShape(input, outputs, param); + if (ret != NNACL_OK) { + return ret; + } + return NNACL_OK; +} + +REG_INFER(Split, PrimType_Split, SplitInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/split_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/split_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..67733fa841e9335751a10652921596ada9abc692 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/split_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SPLIT_INFER_H +#define MINDSPORE_NNACL_SPLIT_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/split_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SplitInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SPLIT_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/split_reduce_concat_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/split_reduce_concat_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..0f6d2ecca22304003e2cbece95ce47624bb25660 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/split_reduce_concat_infer.c @@ -0,0 +1,45 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/split_reduce_concat_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/split_parameter.h" + +int SplitReduceConcatFusionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret != NNACL_OK) { + return check_ret; + } + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + NNACL_CHECK_TRUE_RET(inputs_size == outputs_size, NNACL_INPUT_TENSOR_ERROR); + const TensorC *in_tensor = inputs[0]; + TensorC *out_tensor = outputs[0]; + out_tensor->format_ = in_tensor->format_; + for (size_t i = 0; i < in_tensor->shape_size_; i++) { + out_tensor->shape_[i] = in_tensor->shape_[i]; + } + SplitParameter *param = (SplitParameter *)parameter; + out_tensor->shape_[param->split_dim_] = param->num_split_; + out_tensor->shape_size_ = in_tensor->shape_size_; + return NNACL_OK; +} + +REG_INFER(SplitReduceConcatFusion, PrimType_Inner_SplitReduceConcatFusion, SplitReduceConcatFusionInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/split_reduce_concat_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/split_reduce_concat_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..b47818397f73cfb051e944c3b4c52774a56cb772 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/split_reduce_concat_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SPLIT_REDUCE_CONCAT_INFER_H +#define MINDSPORE_NNACL_SPLIT_REDUCE_CONCAT_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SplitReduceConcatFusionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SPLIT_REDUCE_CONCAT_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/split_with_over_lap_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/split_with_over_lap_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..5017e03e90a39e96eae452dd1283f7a985c4944e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/split_with_over_lap_infer.c @@ -0,0 +1,84 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/split_with_over_lap_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/op_base.h" + +int SplitWithOverlapInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + const TensorC *input = inputs[0]; + SplitWithOverlapParameter *param = (SplitWithOverlapParameter *)parameter; + + int split_dim = param->split_dim_; + int number_split = param->num_split_; + if (outputs_size != (size_t)number_split) { + return NNACL_ERR; + } + + int ratio[SPLIT_MAX_SLICE_NUM]; + int extend_top[SPLIT_MAX_SLICE_NUM]; + int extend_bottom[SPLIT_MAX_SLICE_NUM]; + for (int i = 0; i < number_split; ++i) { + ratio[i] = param->ratio_[i]; + extend_top[i] = param->extend_top_[i]; + extend_bottom[i] = param->extend_bottom_[i]; + } + + const int *input_shape = input->shape_; + int split_dim_size = input_shape[split_dim]; + int total_block_count = 0; + for (int i = 0; i < number_split; i++) { + total_block_count += ratio[i]; + } + + int borders[MAX_SHAPE_SIZE]; + borders[0] = 0; + int visited_block = 0; + for (int i = 0; i < number_split - 1; i++) { + visited_block += ratio[i]; + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(split_dim_size, visited_block) || total_block_count == 0, NNACL_ERR); + int cur_border = UP_DIV(split_dim_size * visited_block, total_block_count); + borders[i + 1] = cur_border; + } + borders[number_split] = split_dim_size; + + for (int i = 0; i < number_split; ++i) { + int output_shape[MAX_SHAPE_SIZE]; + for (int dim = 0; dim < input->shape_size_; dim++) { + if (dim == split_dim) { + int splited_size = borders[i + 1] - borders[i]; + splited_size += (extend_top[i] + extend_bottom[i]); + output_shape[dim] = splited_size; + } else { + output_shape[dim] = input_shape[dim]; + } + } + SetShapeArray(outputs[i], output_shape, input->shape_size_); + SetDataTypeFormat(outputs[i], input); + } + return NNACL_OK; +} + +REG_INFER(SplitWithOverlap, PrimType_SplitWithOverlap, SplitWithOverlapInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/split_with_over_lap_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/split_with_over_lap_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..9e0793d4580f1cbb492852cfdae31a01a93ce08d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/split_with_over_lap_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SPLIT_WITH_OVER_LAP_INFER_H +#define MINDSPORE_NNACL_SPLIT_WITH_OVER_LAP_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/split_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SplitWithOverlapInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SPLIT_WITH_OVER_LAP_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/squeeze_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/squeeze_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..7af14836e68fd74e883cd2b3d8136a2ad9c351a5 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/squeeze_infer.c @@ -0,0 +1,81 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/squeeze_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensor_c_utils.h" + +int SqueezeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = + CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 1, kInputSize1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + SqueezeParameter *param = (SqueezeParameter *)parameter; + SetDataTypeFormat(outputs[0], input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + + if (inputs_size == kInputSize1) { + NNACL_CHECK_TRUE_RET(inputs[1]->data_type_ == kNumberTypeInt32 || inputs[1]->data_type_ == kNumberTypeInt, + NNACL_PARAM_INVALID); + int *axis_data = (int *)(inputs[1]->data_); + NNACL_CHECK_TRUE_RET(axis_data != NULL, NNACL_PARAM_INVALID); + param->axis_size_ = NNACLGetElementNum(inputs[1]); + for (size_t i = 0; i < param->axis_size_; i++) { + param->axis_[i] = *(axis_data + i); + } + } + if (param->axis_size_ > MAX_SHAPE_SIZE) { + return NNACL_PARAM_INVALID; + } + int out_shape[MAX_SHAPE_SIZE] = {0}; + size_t out_shape_size = 0; + + for (size_t i = 0; i < param->axis_size_; i++) { + param->axis_[i] = param->axis_[i] >= 0 ? param->axis_[i] : param->axis_[i] + (int)input->shape_size_; + } + + if (param->axis_size_ == 0) { + for (size_t i = 0; i < input->shape_size_; i++) { + if (input->shape_[i] != 1) { + ShapePush(out_shape, &out_shape_size, input->shape_[i]); + } + } + } else { + size_t axisIdx = 0; + for (size_t i = 0; i < input->shape_size_; i++) { + if (axisIdx < param->axis_size_ && param->axis_[axisIdx] == (int)(i)) { + if (input->shape_[i] != 1) return NNACL_PARAM_INVALID; + axisIdx++; + continue; + } else { + ShapePush(out_shape, &out_shape_size, input->shape_[i]); + } + } + } + SetShapeArray(outputs[0], out_shape, out_shape_size); + return NNACL_OK; +} + +REG_INFER(Squeeze, PrimType_Squeeze, SqueezeInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/squeeze_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/squeeze_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..6505b573c917c0f6d72f394e5e4ac76112ac7a6d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/squeeze_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_SQUEEZE_INFER_H +#define MINDSPORE_NNACL_SQUEEZE_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/squeeze_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SqueezeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_SQUEEZE_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/stack_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/stack_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..fe216dbf51578d2a5696fbcb4d798e5e48abef87 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/stack_infer.c @@ -0,0 +1,70 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/stack_infer.h" +#include "nnacl/infer/infer_register.h" + +int StackInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (ret != NNACL_OK) { + return ret; + } + if (outputs_size != 1) { + return NNACL_PARAM_INVALID; + } + if (inputs_size < 1) { + return NNACL_PARAM_INVALID; + } + const TensorC *input = inputs[0]; + SetDataTypeFormat(outputs[0], input); + StackParameter *param = (StackParameter *)parameter; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + if (input->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + ShapeSet(output_shape, &output_shape_size, input->shape_, input->shape_size_); + int axis = param->axis_ < 0 ? (int)(param->axis_) + (int)(input->shape_size_) + 1 : param->axis_; + if (axis < 0 || axis > (int)(input->shape_size_)) { + return NNACL_PARAM_INVALID; + } + + for (size_t i = 1; i < inputs_size; ++i) { + if (inputs[i]->shape_size_ != input->shape_size_) { + return NNACL_PARAM_INVALID; + } + for (size_t j = 0; j < input->shape_size_; ++j) { + if (inputs[i]->shape_[j] != input->shape_[j]) { + return NNACL_PARAM_INVALID; + } + } + if (inputs[i]->data_type_ != input->data_type_) { + return NNACL_PARAM_INVALID; + } + } + int insert_ret = ShapeInsert(output_shape, &output_shape_size, axis, inputs_size); + if (insert_ret != NNACL_OK) { + return NNACL_ERR; + } + SetShapeArray(outputs[0], output_shape, output_shape_size); + return NNACL_OK; +} + +REG_INFER(Stack, PrimType_Stack, StackInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/stack_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/stack_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..d50a0ccb68c49cb1c377534d6045da170d598cc0 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/stack_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_STACK_INFER_H +#define MINDSPORE_NNACL_STACK_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/stack_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int StackInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_STACK_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/strided_slice_grad_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/strided_slice_grad_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..770c419975c46082af4bc67c3db402c51d826a70 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/strided_slice_grad_infer.c @@ -0,0 +1,160 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/strided_slice_grad_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensor_c_utils.h" + +bool StridedSliceCheckInputs(const TensorC *const *inputs, size_t inputs_size) { + for (size_t i = 1; i < inputs_size; ++i) { + if (inputs[i]->data_ == NULL) { + return false; + } + } + if (NNACLGetElementNum(inputs[2]) > MAX_SHAPE_SIZE) { + return false; + } + if (NNACLGetElementNum(inputs[2]) != NNACLGetElementNum(inputs[3]) && + NNACLGetElementNum(inputs[2]) != NNACLGetElementNum(inputs[4])) { + return false; + } + return true; // note: the original code is ndim_ <= in_shape_size +} + +void ApplyBeginEndEllipsisMask(size_t ndim, int *begins, const uint32_t *const begins_mask, int *ends, + const uint32_t *const ends_mask, const uint32_t *const ellipsis_mask, + const int *const in_shape) { + for (size_t i = 0; i < ndim; i++) { + if (begins_mask[i]) { + begins[i] = 0; + } + if (ends_mask[i]) { + ends[i] = in_shape[i]; + } + } + for (size_t i = 0; i < ndim; i++) { + if (ellipsis_mask[i]) { + begins[i] = 0; + ends[i] = in_shape[i]; + break; + } + } +} + +int StridedSliceGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 5, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + SetDataTypeFormat(outputs[0], input); + bool inferflag = InferFlag(inputs, inputs_size); + + int in_shape_[MAX_SHAPE_SIZE] = {0}; + size_t in_shape_size = 0; + if (inferflag) { + ShapeSet(in_shape_, &in_shape_size, input->shape_, input->shape_size_); + } + int begins_[MAX_SHAPE_SIZE] = {0}; + size_t begins_size = 0; + int ends_[MAX_SHAPE_SIZE] = {0}; + size_t ends_size = 0; + int strides_[MAX_SHAPE_SIZE] = {0}; + size_t strides_size = 0; + + if (!StridedSliceCheckInputs(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + // input order: dy, shapex, begins, ends, strides. + const TensorC *begin_tensor = inputs[2]; + int *begin_data = (int *)(begin_tensor->data_); + int *end_data = (int *)(inputs[3]->data_); + int *stride_data = (int *)(inputs[4]->data_); + + size_t ndim_ = (size_t)NNACLGetElementNum(begin_tensor); + for (size_t i = 0; i < ndim_; ++i) { + ShapePush(begins_, &begins_size, begin_data[i]); + ShapePush(ends_, &ends_size, end_data[i]); + ShapePush(strides_, &strides_size, stride_data[i]); + } + + // set all mask to original input shape + uint32_t begins_mask_[MAX_SHAPE_SIZE] = {0}; + uint32_t ends_mask_[MAX_SHAPE_SIZE] = {0}; + uint32_t ellipsis_mask_[MAX_SHAPE_SIZE] = {0}; + uint32_t new_axis_mask_[MAX_SHAPE_SIZE] = {0}; + + StridedSliceParameter *param = (StridedSliceParameter *)parameter; + for (size_t i = 0; i < ndim_; i++) { + begins_mask_[i] = (unsigned)(param->begins_mask_) & (1 << i); + ends_mask_[i] = (unsigned)(param->ends_mask_) & (1 << i); + ellipsis_mask_[i] = (unsigned)(param->ellipsisMask_) & (1 << i); + new_axis_mask_[i] = (unsigned)(param->newAxisMask_) & (1 << i); + } + param->num_axes_ = (int)(in_shape_size); + param->in_shape_length_ = (int)(in_shape_size); + for (size_t i = 0; i < ndim_; ++i) { + param->begins_[i] = begins_[i]; + param->ends_[i] = ends_[i]; + param->strides_[i] = strides_[i]; + } + ShapeSet(param->in_shape_, &in_shape_size, input->shape_, input->shape_size_); + // ApplyNewAxisMask; + for (size_t i = 0; i < ndim_; i++) { + if (new_axis_mask_[i]) { + ndim_ += 1; + int ret = ShapeInsert(in_shape_, &in_shape_size, i, 1); + if (ret != NNACL_OK) { + return NNACL_ERR; + } + begins_[i] = 0; + ends_[i] = 1; + strides_[i] = 1; + + ShapePush(begins_, &begins_size, 0); + ShapePush(ends_, &ends_size, in_shape_[ndim_ - 1]); + ShapePush(strides_, &strides_size, 1); + + begins_mask_[i] = false; + ends_mask_[i] = false; + ellipsis_mask_[i] = false; + } + } + ApplyBeginEndEllipsisMask(ndim_, begins_, begins_mask_, ends_, ends_mask_, ellipsis_mask_, in_shape_); + if (!inferflag) { + return NNACL_OK; + } + int output_size = inputs[1]->shape_[0]; + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + if (inputs[1]->data_ == NULL) { + return NNACL_ERR; + } + + if (output_size > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + for (int i = 0; i < output_size; i++) { + ShapePush(output_shape, &output_shape_size, ((int *)(inputs[1]->data_))[i]); + } + SetShapeArray(outputs[0], output_shape, output_shape_size); + return NNACL_OK; +} + +REG_INFER(StridedSliceGrad, PrimType_StridedSliceGrad, StridedSliceGradInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/strided_slice_grad_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/strided_slice_grad_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..db2525312422aaa4c5a55267a2771a56bebddfca --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/strided_slice_grad_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_STRIDED_SLICE_GRAD_INFER_H +#define MINDSPORE_NNACL_STRIDED_SLICE_GRAD_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/strided_slice_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int StridedSliceGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_STRIDED_SLICE_GRAD_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/strided_slice_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/strided_slice_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..bf266159dec3ad5acfa7a6265e6689449394f757 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/strided_slice_infer.c @@ -0,0 +1,483 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/strided_slice_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/op_base.h" +#include "nnacl/tensor_c_utils.h" + +const size_t kStridedSliceOutputNum = 1; +const size_t kStridedSliceInputNum = 1; +const size_t kStridedSliceMultiInputNumMin = 3; +const size_t kStridedSliceMultiInputNumMax = 5; + +typedef struct StridedSliceTransferBuffer { + int ndim_; + + int begins_[MAX_SHAPE_SIZE]; + int ends_[MAX_SHAPE_SIZE]; + int strides_[MAX_SHAPE_SIZE]; + int begins_mask_[MAX_SHAPE_SIZE]; + int ends_mask_[MAX_SHAPE_SIZE]; + int ellipsis_mask_[MAX_SHAPE_SIZE]; + int new_axis_mask_[MAX_SHAPE_SIZE]; + int shrink_axis_mask_[MAX_SHAPE_SIZE]; + + size_t begins_size_; + size_t ends_size_; + size_t strides_size_; + size_t ellipsis_mask_size_; + size_t new_axis_mask_size_; + size_t shrink_axis_mask_size_; +} StridedSliceTransferBuffer; + +bool CheckInputs(const TensorC *const *inputs, size_t inputs_size) { + for (size_t i = 1; i < inputs_size; ++i) { + if (inputs[i]->data_ == NULL) { + return false; + } + } + return true; +} + +int HandleAxesCheckNull(const TensorC *input_tensor, const TensorC *begin_tensor, int *begin_data, + const TensorC *end_tensor, int *end_data) { + if (input_tensor == NULL || begin_tensor == NULL || end_tensor == NULL || begin_data == NULL || end_data == NULL) { + return NNACL_NULL_PTR; + } + return NNACL_OK; +} + +int HandleAxesInputNotExist(const TensorC *const *inputs, struct StridedSliceTransferBuffer *transfer_buffer) { + const TensorC *begin_tensor = inputs[1]; + const TensorC *end_tensor = inputs[2]; + const TensorC *stride_tensor = inputs[3]; + int ret = GetInt32DataFromTensor(begin_tensor, transfer_buffer->begins_, &transfer_buffer->begins_size_); + if (ret != NNACL_OK) { + return ret; + } + transfer_buffer->ndim_ = NNACLGetElementNum(begin_tensor); + ret = GetInt32DataFromTensor(end_tensor, transfer_buffer->ends_, &transfer_buffer->ends_size_); + if (ret != NNACL_OK) { + return ret; + } + ret = GetInt32DataFromTensor(stride_tensor, transfer_buffer->strides_, &transfer_buffer->strides_size_); + if (ret != NNACL_OK) { + return ret; + } + return NNACL_OK; +} + +int GenerateAxes(const TensorC *axes_tensor, int *axes, int num, int ndim) { + int *axes_data = NULL; + if (NNACLGetElementNum(axes_tensor) != 0) { + if (NNACLGetElementNum(axes_tensor) != num) { + return NNACL_ERR; + } + axes_data = (int *)(axes_tensor->data_); + if (axes_data == NULL) { + return NNACL_NULL_PTR; + } + } + if (axes_data == NULL) { + for (int i = 0; i < num; ++i) { + axes[i] = i; + } + } else { + for (int i = 0; i < num; i++) { + axes[i] = axes_data[i]; + } + for (int i = 0; i < num; ++i) { + if (axes[i] < 0) { + axes[i] += ndim; + } + } + } + return NNACL_OK; +} + +int HandleAxesInputExist(const TensorC *const *inputs, int *ndim, int *in_shape, int *begins, int *strides, int *ends) { + const TensorC *input_tensor = inputs[0]; + const TensorC *begin_tensor = inputs[1]; + int begin_data[MAX_SHAPE_SIZE]; + size_t begin_data_size; + int ret = GetInt32DataFromTensor(begin_tensor, begin_data, &begin_data_size); + if (ret != NNACL_OK) { + return ret; + } + + const TensorC *end_tensor = inputs[2]; + int end_data[MAX_SHAPE_SIZE]; + size_t end_data_size; + ret = GetInt32DataFromTensor(end_tensor, end_data, &end_data_size); + if (ret != NNACL_OK) { + return ret; + } + + int handle_check_ret = HandleAxesCheckNull(input_tensor, begin_tensor, begin_data, end_tensor, end_data); + if (handle_check_ret != NNACL_OK) { + return handle_check_ret; + } + + // when input contains axes, begins, ends, strides will be expand to the same length as input rank + *ndim = (int)(input_tensor->shape_size_); + int begin_ndim = NNACLGetElementNum(begin_tensor); + + int *stride_data = NULL; + const TensorC *stride_tensor = inputs[4]; + int stride_data_num = NNACLGetElementNum(stride_tensor); + if (stride_data_num != 0) { + NNACL_CHECK_TRUE_RET(stride_data_num == begin_ndim, NNACL_ERR); + stride_data = (int *)(stride_tensor->data_); + } + + const TensorC *axes_tensor = inputs[3]; + int axes[MAX_SHAPE_SIZE] = {0}; + ret = GenerateAxes(axes_tensor, axes, begin_ndim, *ndim); + if (ret != NNACL_OK) { + return ret; + } + + if (*ndim > MAX_SHAPE_SIZE || *ndim < 0) { + return NNACL_ERR; + } + for (int i = 0; i < *ndim; i++) { + in_shape[i] = 0; + begins[i] = 0; + strides[i] = 0; + } + for (int i = 0; i < *ndim; ++i) { + in_shape[i] = input_tensor->shape_[i]; + } + for (int i = 0; i < *ndim; ++i) { + int axes_it = 0; + if (begin_ndim > MAX_SHAPE_SIZE || begin_ndim < 0) { + return NNACL_ERR; + } + for (int j = 0; j < begin_ndim; j++) { + if (axes[j] == i) { + axes_it = j; + break; + } else { + axes_it++; + } + } + if (axes_it != begin_ndim) { + int axis = axes_it; + if (begin_data[axis] > input_tensor->shape_[i] - 1) { + begins[i] = begin_data[axis]; + } else { + begins[i] = imax(imin(begin_data[axis], input_tensor->shape_[i] - 1), -input_tensor->shape_[i]); + } + // ends exceed limit will be set to limit + ends[i] = imax(imin(end_data[axis], input_tensor->shape_[i]), -input_tensor->shape_[i] - 1); + if (stride_data == NULL) { + return NNACL_ERR; + } + strides[i] = stride_data[axis]; + } else { + begins[i] = 0; + ends[i] = input_tensor->shape_[i]; + strides[i] = 1; + } + } + return NNACL_OK; +} + +int StrideSlicePreCheck(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + if (outputs_size != kStridedSliceOutputNum) { + return NNACL_PARAM_INVALID; + } + if (inputs_size != kStridedSliceInputNum && + !(inputs_size <= kStridedSliceMultiInputNumMax && inputs_size >= kStridedSliceMultiInputNumMin)) { + return NNACL_PARAM_INVALID; + } + if (parameter == NULL || outputs[0] == NULL || inputs[0] == NULL) { + return NNACL_NULL_PTR; + } + if (inputs_size >= kStridedSliceMultiInputNumMin) { + bool begins_type_ok = + (inputs[C1NUM]->data_type_ == kNumberTypeInt32) || (inputs[C1NUM]->data_type_ == kNumberTypeInt64); + bool ends_type_ok = + (inputs[C2NUM]->data_type_ == kNumberTypeInt32) || (inputs[C2NUM]->data_type_ == kNumberTypeInt64); + if (!(begins_type_ok && ends_type_ok)) { + return NNACL_PARAM_INVALID; + } + } + return NNACL_OK; +} + +void Bit2Vector(StridedSliceTransferBuffer *transfer_buffer, const StridedSliceParameter *param) { + for (unsigned i = 0; i < (unsigned)(size_t)(transfer_buffer->ndim_); i++) { + transfer_buffer->begins_mask_[i] = (unsigned)(param->begins_mask_) & (1 << i); + transfer_buffer->ends_mask_[i] = (unsigned)(param->ends_mask_) & (1 << i); + transfer_buffer->ellipsis_mask_[i] = (unsigned)(param->ellipsisMask_) & (1 << i); + transfer_buffer->new_axis_mask_[i] = (unsigned)(param->newAxisMask_) & (1 << i); + transfer_buffer->shrink_axis_mask_[i] = (unsigned)(param->shrinkAxisMask_) & (1 << i); + } +} + +int ApplyNewAxisMask(StridedSliceTransferBuffer *transfer_buffer, StridedSliceParameter *param, int *in_shape, + size_t *out_shape_size) { + for (size_t i = 0; i < transfer_buffer->new_axis_mask_size_; i++) { + if (transfer_buffer->new_axis_mask_[i]) { + if (*out_shape_size >= MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + int ret = ShapeInsert(in_shape, out_shape_size, i, 1); + if (ret != NNACL_OK) { + return NNACL_ERR; + } + transfer_buffer->begins_[i] = 0; + transfer_buffer->ends_[i] = 1; + transfer_buffer->strides_[i] = 1; + + ShapePush(transfer_buffer->begins_, &transfer_buffer->begins_size_, 0); + ShapePush(transfer_buffer->ends_, &transfer_buffer->ends_size_, in_shape[(size_t)(transfer_buffer->ndim_) - 1]); + ShapePush(transfer_buffer->strides_, &transfer_buffer->strides_size_, 1); + + transfer_buffer->begins_mask_[i] = false; + transfer_buffer->ends_mask_[i] = false; + transfer_buffer->ellipsis_mask_[i] = false; + transfer_buffer->shrink_axis_mask_[i] = false; + } + } + return NNACL_OK; +} + +void ApplyBeginMask(StridedSliceTransferBuffer *transfer_buffer) { + for (int i = 0; i < transfer_buffer->ndim_; i++) { + if (transfer_buffer->begins_mask_[i]) { + transfer_buffer->begins_[i] = transfer_buffer->strides_[i] > 0 ? 0 : -1; + } + } +} + +int ApplyEndMask(StridedSliceTransferBuffer *transfer_buffer, const int *in_shape, size_t in_shape_size) { + for (int i = 0; i < transfer_buffer->ndim_; i++) { + if (transfer_buffer->ends_mask_[i]) { + if ((size_t)i >= in_shape_size) { + return NNACL_ERR; + } + transfer_buffer->ends_[i] = transfer_buffer->strides_[i] > 0 ? in_shape[i] : -1 - in_shape[i]; + } + } + return NNACL_OK; +} + +int ApplyEllipsisMask(StridedSliceTransferBuffer *transfer_buffer, const int *in_shape, size_t in_shape_size) { + for (size_t i = 0; i < transfer_buffer->ellipsis_mask_size_; i++) { + if (transfer_buffer->ellipsis_mask_[i]) { + if (i >= in_shape_size) { + return NNACL_ERR; + } + transfer_buffer->begins_[i] = 0; + transfer_buffer->ends_[i] = in_shape[i]; + break; + } + } + return NNACL_OK; +} + +int TransIndexToPositive(StridedSliceTransferBuffer *transfer_buffer, const int *in_shape, size_t max_shape_size, + size_t in_shape_size) { + for (size_t i = 0; i < transfer_buffer->begins_size_; i++) { + if (i >= max_shape_size) { + return NNACL_ERR; + } + if (transfer_buffer->begins_[i] < 0) { + transfer_buffer->begins_[i] += in_shape[i]; + } + if (transfer_buffer->ends_[i] < 0) { + transfer_buffer->ends_[i] += in_shape[i]; + } + if (i < in_shape_size) { + if (transfer_buffer->begins_[i] < 0 || transfer_buffer->begins_[i] > in_shape[i]) { + return NNACL_ERR; + } + if ((transfer_buffer->ends_[i] < 0 && transfer_buffer->ends_[i] != -1) || + transfer_buffer->ends_[i] > in_shape[i]) { + return NNACL_ERR; + } + } + } + return NNACL_OK; +} + +void ApplyShrinkMask(StridedSliceTransferBuffer *transfer_buffer, int *output_shape, size_t *output_shape_size) { + int old_out_shape[MAX_SHAPE_SIZE] = {0}; + size_t old_out_shape_size = 0; + ShapeSet(old_out_shape, &old_out_shape_size, output_shape, *output_shape_size); + *output_shape_size = 0; + for (size_t i = 0; i < transfer_buffer->shrink_axis_mask_size_; i++) { + if (transfer_buffer->shrink_axis_mask_[i]) { + transfer_buffer->ends_[i] = transfer_buffer->begins_[i] + 1; + transfer_buffer->strides_[i] = 1; + } else { + ShapePush(output_shape, output_shape_size, old_out_shape[i]); + } + } + for (size_t i = transfer_buffer->shrink_axis_mask_size_; i < old_out_shape_size; i++) { + ShapePush(output_shape, output_shape_size, old_out_shape[i]); + } +} + +int TransferBuffer2Param(const StridedSliceTransferBuffer *transfer_buffer, StridedSliceParameter *param, + const int *in_shape, size_t in_shape_size) { + if (transfer_buffer->ndim_ >= (int)(in_shape_size) || param->in_shape_length_ >= (int)(in_shape_size)) { + return NNACL_ERR; + } + for (int i = 0; i < transfer_buffer->ndim_; i++) { + param->begins_[i] = transfer_buffer->begins_[i]; + param->ends_[i] = transfer_buffer->ends_[i]; + param->in_shape_[i] = in_shape[i]; + param->strides_[i] = transfer_buffer->strides_[i]; + } + + for (int i = transfer_buffer->ndim_; i < param->in_shape_length_; i++) { + param->begins_[i] = 0; + param->ends_[i] = in_shape[i]; + param->in_shape_[i] = in_shape[i]; + param->strides_[i] = 1; + } + return NNACL_OK; +} + +void InitStridedSliceTransferBuffer(StridedSliceTransferBuffer *transfer_buffer) { + transfer_buffer->begins_size_ = 0; + transfer_buffer->ends_size_ = 0; + transfer_buffer->strides_size_ = 0; + transfer_buffer->ellipsis_mask_size_ = 0; + transfer_buffer->new_axis_mask_size_ = 0; + transfer_buffer->shrink_axis_mask_size_ = 0; +} + +void SetMaskSize(StridedSliceTransferBuffer *transfer_buffer) { + transfer_buffer->ellipsis_mask_size_ = (size_t)(transfer_buffer->ndim_); + transfer_buffer->new_axis_mask_size_ = (size_t)(transfer_buffer->ndim_); + transfer_buffer->shrink_axis_mask_size_ = (size_t)(transfer_buffer->ndim_); + transfer_buffer->begins_size_ = (size_t)(transfer_buffer->ndim_); + transfer_buffer->ends_size_ = (size_t)(transfer_buffer->ndim_); + transfer_buffer->strides_size_ = (size_t)(transfer_buffer->ndim_); +} + +// note: begin, end, stride length are equal, but may less than rank of input +int StridedSliceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = StrideSlicePreCheck(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + SetDataTypeFormat(outputs[0], inputs[0]); + + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + int in_shape[MAX_SHAPE_SIZE] = {0}; + size_t in_shape_size = 0; + if (input->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + ShapeSet(in_shape, &in_shape_size, input->shape_, input->shape_size_); + + StridedSliceTransferBuffer transfer_buffer; + InitStridedSliceTransferBuffer(&transfer_buffer); + + StridedSliceParameter *param = (StridedSliceParameter *)parameter; + + transfer_buffer.ndim_ = 0; + if (inputs_size == kStridedSliceInputNum) { + transfer_buffer.ndim_ = (int)(in_shape_size); + if (transfer_buffer.ndim_ > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + for (int i = 0; i < transfer_buffer.ndim_; i++) { + ShapePush(transfer_buffer.begins_, &transfer_buffer.begins_size_, param->begins_[i]); + ShapePush(transfer_buffer.ends_, &transfer_buffer.ends_size_, param->ends_[i]); + ShapePush(transfer_buffer.strides_, &transfer_buffer.strides_size_, param->strides_[i]); + } + } + if (!CheckInputs(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + if (inputs_size == 4) { + int ret = HandleAxesInputNotExist(inputs, &transfer_buffer); + if (ret != NNACL_OK) { + return ret; + } + } + + if (inputs_size == 5) { + int ret = HandleAxesInputExist(inputs, &transfer_buffer.ndim_, in_shape, transfer_buffer.begins_, + transfer_buffer.strides_, transfer_buffer.ends_); + if (ret != NNACL_OK) { + return ret; + } + } + + // set all mask to original input shape + SetMaskSize(&transfer_buffer); + Bit2Vector(&transfer_buffer, param); + int ret = ApplyNewAxisMask(&transfer_buffer, param, in_shape, &in_shape_size); + if (ret != NNACL_OK) { + return ret; + } + + // update parameter with new input shape + param->num_axes_ = (int)(in_shape_size); + param->in_shape_length_ = (int)(in_shape_size); + + ApplyBeginMask(&transfer_buffer); + ret = ApplyEndMask(&transfer_buffer, in_shape, MAX_SHAPE_SIZE); + if (ret != NNACL_OK) { + return ret; + } + ret = ApplyEllipsisMask(&transfer_buffer, in_shape, MAX_SHAPE_SIZE); + if (ret != NNACL_OK) { + return ret; + } + + int output_shape[MAX_SHAPE_SIZE]; + size_t output_shape_size = 0; + ShapeSet(output_shape, &output_shape_size, in_shape, in_shape_size); + ret = TransIndexToPositive(&transfer_buffer, in_shape, MAX_SHAPE_SIZE, input->shape_size_); + if (ret != NNACL_OK) { + return ret; + } + for (int i = 0; i < transfer_buffer.ndim_; i++) { + if (transfer_buffer.strides_[i] == 0 || in_shape[i] < transfer_buffer.ends_[i]) { + return NNACL_ERR; + } + output_shape[i] = (transfer_buffer.ends_[i] - transfer_buffer.begins_[i] + transfer_buffer.strides_[i] + + (transfer_buffer.strides_[i] < 0 ? 1 : -1)) / + transfer_buffer.strides_[i]; + output_shape[i] = output_shape[i] > 0 ? output_shape[i] : 0; + } + ApplyShrinkMask(&transfer_buffer, output_shape, &output_shape_size); + SetShapeArray(outputs[0], output_shape, output_shape_size); + ret = TransferBuffer2Param(&transfer_buffer, param, in_shape, MAX_SHAPE_SIZE); + if (ret != NNACL_OK) { + return ret; + } + return NNACL_OK; +} + +REG_INFER(StridedSlice, PrimType_StridedSlice, StridedSliceInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/strided_slice_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/strided_slice_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..de3321b4441ecfa25c3e7d137b1c0276b97feea3 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/strided_slice_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_STRIDED_SLICE_INFER_H +#define MINDSPORE_NNACL_STRIDED_SLICE_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/strided_slice_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int StridedSliceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_STRIDED_SLICE_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/string/custom_extract_features_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/string/custom_extract_features_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..341ff0a579b213ae0b04648b8b516dde55e3d85c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/string/custom_extract_features_infer.c @@ -0,0 +1,49 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/string/custom_extract_features_infer.h" +#include "nnacl/infer/infer_register.h" + +int CustomExtractFeaturesInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 2); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output0 = outputs[0]; + TensorC *output1 = outputs[1]; + + output0->data_type_ = kNumberTypeInt32; + output0->format_ = input->format_; + output1->data_type_ = kNumberTypeFloat32; + output1->format_ = input->format_; + + if (input->data_ == NULL) { + return NNACL_INFER_INVALID; + } + int string_num = *((const int32_t *)(input->data_)); + + int res = (string_num == 0 ? 1 : string_num); + output0->shape_size_ = 1; + output0->shape_[0] = res; + output1->shape_size_ = 1; + output1->shape_[0] = res; + return NNACL_OK; +} + +REG_INFER(CustomExtractFeatures, PrimType_CustomExtractFeatures, CustomExtractFeaturesInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/string/custom_extract_features_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/string/custom_extract_features_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..b8fd536bed1a95c6c1b37ae281f93866ba068f30 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/string/custom_extract_features_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_STRING_CUSTOM_EXTRACT_FEATURES_INFER_H +#define MINDSPORE_NNACL_INFER_STRING_CUSTOM_EXTRACT_FEATURES_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int CustomExtractFeaturesInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_STRING_CUSTOM_EXTRACT_FEATURES_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/string/custom_normalize_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/string/custom_normalize_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..5b7ab988c53ef4f188adbae709814d866f8807b4 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/string/custom_normalize_infer.c @@ -0,0 +1,48 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/string/custom_normalize_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensor_c_utils.h" + +int CustomNormalizeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + + SetDataTypeFormat(output, input); + + if (input->data_ == NULL) { + return NNACL_INFER_INVALID; + } + if (NNACLGetElementNum(input) < 1) { + return NNACL_ERR; + } + if (input->data_type_ != kNumberTypeUInt32 && input->data_type_ != kObjectTypeString) { + return NNACL_ERR; + } + int string_num = *((const int32_t *)(input->data_)); // also look custom_extract_features + output->shape_size_ = 1; + output->shape_[0] = (string_num == 0 ? 1 : string_num); + return NNACL_OK; +} + +REG_INFER(CustomNormalize, PrimType_CustomNormalize, CustomNormalizeInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/string/custom_normalize_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/string/custom_normalize_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..9d19a71bd1782fb1e99583c7b2c2abb2610243fe --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/string/custom_normalize_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_STRING_CUSTOM_NORMALIZE_INFER_H +#define MINDSPORE_NNACL_INFER_STRING_CUSTOM_NORMALIZE_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/softmax_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int CustomNormalizeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_STRING_CUSTOM_NORMALIZE_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/string/custom_predict_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/string/custom_predict_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..bf50da8998b32e5c5a003d0309aebfae08bba86f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/string/custom_predict_infer.c @@ -0,0 +1,43 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/string/custom_predict_infer.h" +#include "nnacl/infer/infer_register.h" + +int CustomPredictInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 2); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output0 = outputs[0]; + TensorC *output1 = outputs[1]; + + CustomPredictParameter *param = (CustomPredictParameter *)parameter; + output0->shape_size_ = 1; + output0->shape_[0] = param->output_num; + output0->data_type_ = kNumberTypeInt32; + output0->format_ = input->format_; + output1->shape_size_ = 1; + output1->shape_[0] = param->output_num; + output1->data_type_ = kNumberTypeFloat32; + output1->format_ = input->format_; + return NNACL_OK; +} + +REG_INFER(CustomPredict, PrimType_CustomPredict, CustomPredictInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/string/custom_predict_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/string/custom_predict_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..37d8e98e5e4b6272161e8cd23230fed3724d11b5 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/string/custom_predict_infer.h @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_STRING_CUSTOM_PREDICT_INFER_H +#define MINDSPORE_NNACL_INFER_STRING_CUSTOM_PREDICT_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct CustomPredictParameter { + OpParameter op_parameter_; + int output_num; +} CustomPredictParameter; + +int CustomPredictInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_STRING_CUSTOM_PREDICT_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/string/hashtable_lookup_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/string/hashtable_lookup_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..f7137d9804347bc942df0edeff1f475e99a9d284 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/string/hashtable_lookup_infer.c @@ -0,0 +1,50 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/string/hashtable_lookup_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensor_c_utils.h" + +int HashtableLoopupInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 2); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + const TensorC *values = inputs[2]; + if (input == NULL || values == NULL) { + return NNACL_NULL_PTR; + } + + TensorC *output = outputs[0]; + TensorC *hits = outputs[1]; + + output->data_type_ = values->data_type_; + output->format_ = input->format_; + hits->shape_size_ = 1; + hits->shape_[0] = NNACLGetDimensionSize(input, 0); + hits->data_type_ = kNumberTypeUInt8; + hits->format_ = input->format_; + + if (input->data_ == NULL) { + return NNACL_INFER_INVALID; + } + return NNACL_OK; +} + +REG_INFER(HashtableLookup, PrimType_HashtableLookup, HashtableLoopupInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/string/hashtable_lookup_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/string/hashtable_lookup_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..8e360c93b88a7b1f1b0c06f9c6e089a72d11e8c0 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/string/hashtable_lookup_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_STRING_HASHTABLE_LOOKUP_INFER_H +#define MINDSPORE_NNACL_INFER_STRING_HASHTABLE_LOOKUP_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int HashtableLoopupInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_STRING_HASHTABLE_LOOKUP_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/string/lsh_projection_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/string/lsh_projection_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..b5244b71b2d05a09f03d4bba0c18e38af0ce933f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/string/lsh_projection_infer.c @@ -0,0 +1,53 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/string/lsh_projection_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensor_c_utils.h" + +int LshProjectionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 2, 3, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *in_hash = inputs[0]; + if (in_hash->shape_size_ != 2 || NNACLGetDimensionSize(in_hash, 1) > 32) { + return NNACL_ERR; + } + TensorC *out_tensor = outputs[0]; + out_tensor->data_type_ = kNumberTypeInt32; + out_tensor->format_ = Format_NHWC; + + int out_shape[MAX_SHAPE_SIZE] = {0}; + size_t out_shape_size = 0; + LshProjectionParameter *param = (LshProjectionParameter *)parameter; + switch (param->lsh_type_) { + case LshProjectionType_SPARSE: + ShapePush(out_shape, &out_shape_size, NNACLGetDimensionSize(in_hash, 0)); + break; + case LshProjectionType_DENSE: + ShapePush(out_shape, &out_shape_size, NNACLGetDimensionSize(in_hash, 0) * NNACLGetDimensionSize(in_hash, 1)); + break; + default: + return NNACL_ERR; + } + SetShapeArray(out_tensor, out_shape, out_shape_size); + return NNACL_OK; +} + +REG_INFER(LshProjection, PrimType_LshProjection, LshProjectionInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/string/lsh_projection_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/string/lsh_projection_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..bacf0a0e3e530a1f28d9bec7cc43910b1eefa0f4 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/string/lsh_projection_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_STRING_LSH_PROJECTION_INFER_H +#define MINDSPORE_NNACL_INFER_STRING_LSH_PROJECTION_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/lsh_projection_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int LshProjectionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_STRING_LSH_PROJECTION_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/string/skip_gram_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/string/skip_gram_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..aa2908af04277ae5602426bd87f3b10154a4407d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/string/skip_gram_infer.c @@ -0,0 +1,37 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/string/skip_gram_infer.h" +#include "nnacl/infer/infer_register.h" + +int SkipGramInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + + SetDataTypeFormat(output, input); + if (input->data_ == NULL) { + return NNACL_INFER_INVALID; + } + return NNACL_OK; +} + +REG_INFER(SkipGram, PrimType_SkipGram, SkipGramInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/string/skip_gram_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/string/skip_gram_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..25fd28cf6cba001af0cf1913972202499edf2cc4 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/string/skip_gram_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INFER_STRING_SKIP_GRAM_INFER_H +#define MINDSPORE_NNACL_INFER_STRING_SKIP_GRAM_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int SkipGramInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_INFER_STRING_SKIP_GRAM_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/tile_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/tile_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..4754dcaace32bd092881fc5ece732025064af8de --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/tile_infer.c @@ -0,0 +1,111 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/tile_infer.h" +#include +#include "nnacl/infer/infer_register.h" +#include "nnacl/tile_parameter.h" +#include "nnacl/tensor_c_utils.h" + +void TileParamCaffe2Tflite(TileParameter *param, size_t out_shape_size) { + if (param->dims_size_ != 0) { + int multiples_size_tmp[5] = {0}; + NNACL_CHECK_TRUE_RET_VOID(out_shape_size <= 5); + for (size_t i = 0; i < out_shape_size; i++) { + multiples_size_tmp[i] = 1; + } + for (size_t i = 0; i < param->dims_size_; i++) { + if (i >= MAX_SHAPE_SIZE) { + return; + } + multiples_size_tmp[param->dims_[i]] = param->multiples_[i]; + } + for (size_t i = 0; i < 5; i++) { + param->multiples_[i] = multiples_size_tmp[i]; + } + } +} + +int TileInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + int out_shape[MAX_SHAPE_SIZE] = {0}; + size_t out_shape_size = 0; + TileParameter *param = (TileParameter *)parameter; + + size_t multiples_size = 0; + int input1_shape_size = inputs[1]->shape_size_; + if (input1_shape_size > (int)(input->shape_size_) || input->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + NNACL_CHECK_TRUE_RET(input1_shape_size <= MAX_SHAPE_SIZE, NNACL_ERR); + int data_num = NNACLGetElementNum(inputs[1]); + multiples_size = (size_t)(data_num); + if (inputs[1]->data_type_ != kNumberTypeInt && inputs[1]->data_type_ != kNumberTypeInt32) { + return NNACL_INPUT_TENSOR_ERROR; + } + int *input1_data = inputs[1]->data_; + if (input1_data == NULL) { + return NNACL_INFER_INVALID; + } + NNACL_CHECK_TRUE_RET(data_num <= MAX_SHAPE_SIZE, NNACL_ERR); + for (int i = 0; i < data_num; i++) { + param->multiples_[i] = input1_data[i]; + } + + int *dims = param->dims_; + size_t dims_size = param->dims_size_; + if (dims_size == 0) { + int dim_num = NNACLGetElementNum(inputs[1]); + NNACL_CHECK_TRUE_RET(dim_num <= MAX_SHAPE_SIZE, NNACL_ERR); + for (int dim = 0; dim < dim_num; ++dim) { + ShapePush(dims, &dims_size, dim); + } + param->dims_size_ = dims_size; + } + NNACL_CHECK_TRUE_RET(multiples_size == dims_size, NNACL_ERR); + for (size_t i = 0; i < input->shape_size_; ++i) { + ShapePush(out_shape, &out_shape_size, input->shape_[i]); + } + for (size_t i = 0; i < dims_size; ++i) { + if (dims[i] >= MAX_SHAPE_SIZE || input->shape_[dims[i]] == 0) { + return NNACL_ERR; + } + if (input->shape_[dims[i]] != 0 && param->multiples_[i] > INT_MAX / input->shape_[dims[i]]) { + return NNACL_ERR; + } + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(input->shape_[dims[i]], (param->multiples_[i])), NNACL_ERR); + out_shape[dims[i]] = input->shape_[dims[i]] * (param->multiples_[i]); + } + // change caffe param format to tflite + TileParamCaffe2Tflite(param, out_shape_size); + SetShapeArray(output, out_shape, out_shape_size); + return NNACL_OK; +} + +REG_INFER(Tile, PrimType_TileFusion, TileInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/tile_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/tile_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..26d67610d1bcb926d800923fa64fe7b8e3587776 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/tile_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_TILE_INFER_H +#define MINDSPORE_NNACL_TILE_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/base/tile_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int TileInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_TILE_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/topk_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/topk_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..19fde09c2541f2a49ce65fb3d157bf253ca97286 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/topk_infer.c @@ -0,0 +1,66 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/topk_infer.h" +#include "nnacl/infer/infer_register.h" + +int TopKInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 2); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output0 = outputs[0]; + TensorC *output1 = outputs[1]; + SetDataTypeFormat(output0, input); + output1->data_type_ = kNumberTypeInt32; + output1->format_ = input->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + const TensorC *input_k_tensor = inputs[1]; + if (input_k_tensor->data_ == NULL) { + return NNACL_INFER_INVALID; + } + + TopkParameter *param = (TopkParameter *)parameter; + param->k_ = ((int32_t *)input_k_tensor->data_)[0]; + + if (input->shape_size_ > MAX_SHAPE_SIZE) { + return NNACL_INPUT_TENSOR_ERROR; + } + int out_shape[MAX_SHAPE_SIZE]; + size_t out_shape_size = 0; + ShapeSet(out_shape, &out_shape_size, input->shape_, input->shape_size_); + if (out_shape_size < 1) { + return NNACL_ERR; + } + if (param->axis_ < 0) { + param->axis_ += (int)out_shape_size; + } + if (param->axis_ < 0 || (size_t)param->axis_ >= out_shape_size) { + return NNACL_ERR; + } + out_shape[(size_t)param->axis_] = param->k_; + + SetShapeArray(output0, out_shape, out_shape_size); + SetShapeArray(output1, out_shape, out_shape_size); + return NNACL_OK; +} + +REG_INFER(TopK, PrimType_TopKFusion, TopKInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/topk_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/topk_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..c5f111ee25c7cdd95a6fb5503bd5cc5532412b72 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/topk_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_TOPK_INFER_H +#define MINDSPORE_NNACL_TOPK_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/fp32/topk_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int TopKInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_TOPK_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/transpose_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/transpose_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..41ac72541c1a4e7436d0a1bad8dc7251570248d2 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/transpose_infer.c @@ -0,0 +1,137 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/transpose_infer.h" +#include "nnacl/infer/infer_register.h" + +bool CheckPermTransFormat(const int *perm, const int *perm_transformat, const int size) { + for (int i = 0; i < size; ++i) { + if (perm[i] != perm_transformat[i]) { + return false; + } + } + return true; +} + +int SetOutputShape(int perms_num, const TensorC *input, TensorC *output, const int *perm, size_t perm_size, + int *out_shape) { + // set output shape + size_t in_shape_size = input->shape_size_; + output->shape_size_ = in_shape_size; + if (perm_size == 0) { + for (size_t i = 0; i < in_shape_size; ++i) { + out_shape[in_shape_size - i - 1] = input->shape_[i]; + } + } else if (perm_size != in_shape_size) { + for (size_t i = 0; i < in_shape_size; ++i) { + out_shape[i] = input->shape_[i]; + } + } else { + output->shape_size_ = perm_size; + for (size_t i = 0; i < perm_size; ++i) { + if (perm[i] >= input->shape_size_) { + return NNACL_ERR; + } else { + out_shape[i] = input->shape_[perm[i]]; + } + } + } + return NNACL_OK; +} + +int GetAndCheckPerm(const TensorC *perm_tensor, const int perms_num, int *perm, size_t *perm_size) { + if (perms_num >= MAX_TRANSPOSE_DIM_SIZE) { + return NNACL_TRANSPOSE_PERM_DIMS_INVALID; + } + + int ret = GetInt32DataFromTensor(perm_tensor, perm, perm_size); + if (ret != NNACL_OK) { + return ret; + } + for (size_t i = 0; i < *perm_size; i++) { + NNACL_CHECK_TRUE_RET(perm[i] < perms_num, NNACL_ERR); + } + return NNACL_OK; +} + +void Handle4DPerm(const TensorC *input, TensorC *output, int *perm, size_t *perm_size) { + const int nchw2nhwc[4] = {Index0, Index2, Index3, Index1}; + const int nhwc2nchw[4] = {Index0, Index3, Index1, Index2}; + const int trans3d[3] = {Index0, Index2, Index1}; + if (input->format_ == Format_NCHW && CheckPermTransFormat(perm, nchw2nhwc, PERM_NUM_FOUR)) { + output->format_ = Format_NHWC; + } else if ((input->format_ == Format_NHWC || input->format_ == Format_KHWC) && + CheckPermTransFormat(perm, nhwc2nchw, PERM_NUM_FOUR)) { + output->format_ = Format_NCHW; + } + // though the perm is 4d in default, the input can be a 3d tensor. The op implementation must be adapted to this. + if (input->shape_size_ == DIMENSION_3D) { + ShapeSet(perm, perm_size, trans3d, DIMENSION_3D); + } +} + +int TransposeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + + SetDataTypeFormat(output, input); + const TensorC *perm_tensor = inputs[1]; + if (perm_tensor == NULL) { + return NNACL_INFER_INVALID; + } + NNACL_CHECK_TRUE_RET(perm_tensor->shape_size_ == 1, NNACL_INFER_INVALID); + const int perms_num = perm_tensor->shape_[0]; + if (perms_num != 0 && perm_tensor->data_ == NULL) { + return NNACL_INFER_INVALID; + } + TransposeParameter *transpose_param = (TransposeParameter *)parameter; + transpose_param->perm_size_ = perms_num; + int perm[MAX_TRANSPOSE_DIM_SIZE] = {0}; + size_t perm_size = 0; + int ret = GetAndCheckPerm(perm_tensor, perms_num, perm, &perm_size); + if (ret != NNACL_OK) { + return ret; + } + + if (perms_num == PERM_NUM_FOUR) { + Handle4DPerm(input, output, perm, &perm_size); + } + int kPermIndex0 = 0; + int kPermIndex2 = 2; + if (perms_num == PERM_NUM_THREE && perm[0] == kPermIndex0 && perm[1] == kPermIndex2) { + output->format_ = input->format_ == Format_NCHW ? Format_NHWC : Format_NCHW; + } + if (parameter->quant_type_ == Quant_QuantWeight) { + output->data_type_ = kNumberTypeFloat32; + } + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + // set output shape + int out_shape[MAX_TRANSPOSE_DIM_SIZE] = {0}; + SetOutputShape(perms_num, input, output, perm, perm_size, out_shape); + SetShapeArray(output, out_shape, output->shape_size_); + return NNACL_OK; +} + +REG_INFER(Transpose, PrimType_Transpose, TransposeInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/transpose_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/transpose_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..c4fdaa65a1396664d374f7e293bc00d385c910c2 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/transpose_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_TRANSPOSE_INFER_H +#define MINDSPORE_NNACL_TRANSPOSE_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/transpose_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int TransposeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_TRANSPOSE_INFER_H diff --git a/mindspore-lite/src/extendrt/kernel/kernel_selector/kernel_selector.cc b/mindspore-lite/ops/kernel/cpu/nnacl/infer/triu_tril_infer.c similarity index 42% rename from mindspore-lite/src/extendrt/kernel/kernel_selector/kernel_selector.cc rename to mindspore-lite/ops/kernel/cpu/nnacl/infer/triu_tril_infer.c index 18635762ac185ae09186afd3bec93068e1be4c64..6b2d4742c622e96651f5a2af3d97ac10997b9967 100644 --- a/mindspore-lite/src/extendrt/kernel/kernel_selector/kernel_selector.cc +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/triu_tril_infer.c @@ -14,23 +14,29 @@ * limitations under the License. */ -#include "src/extendrt/kernel/kernel_selector/kernel_selector.h" -#include "src/extendrt/kernel/kernel_selector/nnacl_first_kernel_selector.h" +#include "nnacl/infer/triu_tril_infer.h" +#include "nnacl/infer/infer_register.h" -namespace mindspore::kernel { -std::vector KernelSelector::Candidates(const PrimitiveType &op_type, const KernelAttr &require, - const std::string &backend, Format format) { - std::vector results; - for (const auto &iter : KernelLibRegister::Instance().GetAllLibs()) { - const auto *kernel_lib = iter.second; - if (kernel_lib->Support(op_type, require, backend, format)) { - results.emplace_back(kernel_lib); - } +int TriuTrilInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret != NNACL_OK) { + return check_ret; } - return results; + const size_t triul_input_min_size = 1; + const size_t triul_output_size = 1; + if (inputs_size < triul_input_min_size || outputs_size != triul_output_size) { + return NNACL_ERR; + } + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, input); + SetShapeTensor(output, input); + return NNACL_OK; } -std::shared_ptr CreateKernelSelector(const std::shared_ptr &compile_option) { - return std::make_shared(compile_option); -} -} // namespace mindspore::kernel +REG_INFER(Triu, PrimType_Triu, TriuTrilInferShape) +REG_INFER(Tril, PrimType_Tril, TriuTrilInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/triu_tril_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/triu_tril_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..a7630dbeea933382965ebc408baebd2f570384cc --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/triu_tril_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_TRIU_TRIL_INFER_H +#define MINDSPORE_NNACL_TRIU_TRIL_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/arg_min_max_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int TriuTrilInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_TRIU_TRIL_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/uniform_real_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/uniform_real_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..2e040880496ff06a5e029f29d9b5537d48e4324f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/uniform_real_infer.c @@ -0,0 +1,49 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/uniform_real_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensor_c_utils.h" + +int UniformRealInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (ret != NNACL_OK) { + return ret; + } + outputs[0]->data_type_ = kNumberTypeFloat32; + outputs[0]->format_ = inputs[0]->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + int32_t *input_data = (int32_t *)(inputs[0]->data_); + if (input_data == NULL) { + return NNACL_INFER_INVALID; + } + int input_num = NNACLGetElementNum(inputs[0]); + if (input_num > MAX_SHAPE_SIZE || input_num < 0) { + return NNACL_INPUT_TENSOR_ERROR; + } + int output_shape[MAX_SHAPE_SIZE]; + size_t output_shape_size = (size_t)(input_num); + for (int i = 0; i < input_num; i++) { + output_shape[i] = input_data[i]; + } + SetShapeArray(outputs[0], output_shape, output_shape_size); + return NNACL_OK; +} + +REG_INFER(UniformReal, PrimType_UniformReal, UniformRealInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/uniform_real_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/uniform_real_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..7030b609d529acd863bd31c80870910218bdb9d7 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/uniform_real_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_UNIFORM_REAL_INFER_H +#define MINDSPORE_NNACL_UNIFORM_REAL_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int UniformRealInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_UNIFORM_REAL_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/unique_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/unique_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..3c396ab18b6cf85fba850461b82c979e85de5ced --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/unique_infer.c @@ -0,0 +1,42 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/unique_infer.h" +#include "nnacl/infer/infer_register.h" + +int UniqueInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 2); + if (ret != NNACL_OK) { + return ret; + } + + const TensorC *input0 = inputs[0]; + TensorC *output0 = outputs[0]; + TensorC *output1 = outputs[1]; + + SetDataTypeFormat(output0, input0); + output1->data_type_ = kNumberTypeInt32; + output1->format_ = input0->format_; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + SetShapeTensor(output0, input0); + SetShapeTensor(output1, input0); + return NNACL_OK; +} + +REG_INFER(Unique, PrimType_Unique, UniqueInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/unique_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/unique_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..1c99b86fbf77de95491022f3426227b7a132b52c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/unique_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_UNIQUE_INFER_H +#define MINDSPORE_NNACL_UNIQUE_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int UniqueInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_UNIQUE_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/unsorted_segment_sum_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/unsorted_segment_sum_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..cee24a5e5f40c49f1d6423df295bb73d45febc4c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/unsorted_segment_sum_infer.c @@ -0,0 +1,49 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/unsorted_segment_sum_infer.h" +#include "nnacl/infer/infer_register.h" + +int UnsortedSegmentSumInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + TensorC *out = outputs[0]; + const TensorC *x = inputs[0]; + const TensorC *segment_id = inputs[1]; + if (inputs[2]->data_ == NULL || + (inputs[2]->data_type_ != kNumberTypeInt && inputs[2]->data_type_ != kNumberTypeInt32)) { + return NNACL_INPUT_TENSOR_ERROR; + } + int num_segments = *(int *)(inputs[2]->data_); + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + ShapePush(output_shape, &output_shape_size, num_segments); + for (int index = (int)(segment_id->shape_size_); index < (int)(x->shape_size_); index++) { + if (output_shape_size >= MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + ShapePush(output_shape, &output_shape_size, x->shape_[index]); + } + SetShapeArray(out, output_shape, output_shape_size); + SetDataTypeFormat(out, x); + return NNACL_OK; +} + +REG_INFER(UnsortedSegmentSum, PrimType_UnsortedSegmentSum, UnsortedSegmentSumInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/unsorted_segment_sum_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/unsorted_segment_sum_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..1382d611e4167ac74e97d61f321348f9cbcebaf4 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/unsorted_segment_sum_infer.h @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_UNSORTED_SEGMENT_SUM_INFER_H +#define MINDSPORE_NNACL_UNSORTED_SEGMENT_SUM_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct UnsortedSegmentSumParameter { + OpParameter op_parameter_; + int segments_num_; +} UnsortedSegmentSumParameter; + +int UnsortedSegmentSumInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_UNSORTED_SEGMENT_SUM_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/unsqueeze_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/unsqueeze_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..2b936b41c9cb20179c11c5991ac320d88370d0fc --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/unsqueeze_infer.c @@ -0,0 +1,79 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/unsqueeze_infer.h" +#include "nnacl/infer/infer_register.h" + +int UnsqueezeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input = inputs[0]; + TensorC *output = outputs[0]; + + SetDataTypeFormat(output, input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + + UnSqueezeParameter *param = (UnSqueezeParameter *)parameter; + int in_rank = (int)(input->shape_size_); + int dim_rank = param->num_dim_; + int out_shape[MAX_SHAPE_SIZE] = {0}; + size_t out_shape_size = 0; + if (dim_rank == 0) { + for (size_t i = 0; i < input->shape_size_; i++) { + if (input->shape_[i] != 1) { + if (out_shape_size >= MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + ShapePush(out_shape, &out_shape_size, input->shape_[i]); + } + } + } else { + int sz = in_rank + dim_rank; + size_t in_itr = 0; + size_t ax_itr = 0; + if (sz < 0) { + return NNACL_ERR; + } + for (int i = 0; i < sz; i++) { + if (out_shape_size >= MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + if (ax_itr < (size_t)(dim_rank) && param->dims_[ax_itr] == (int)(i)) { + ShapePush(out_shape, &out_shape_size, 1); + ax_itr++; + } else if (ax_itr < (size_t)(dim_rank) && param->dims_[ax_itr] + sz == i) { + ShapePush(out_shape, &out_shape_size, 1); + ax_itr++; + } else { + if (in_itr >= input->shape_size_) { + return NNACL_ERR; + } + ShapePush(out_shape, &out_shape_size, input->shape_[in_itr]); + in_itr++; + } + } + } + SetShapeArray(output, out_shape, out_shape_size); + return NNACL_OK; +} + +REG_INFER(Unsqueeze, PrimType_Unsqueeze, UnsqueezeInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/unsqueeze_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/unsqueeze_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..73cf8162759484ddbd9db989be7f98e08bf4353b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/unsqueeze_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_UNSQUEEZE_INFER_H +#define MINDSPORE_NNACL_UNSQUEEZE_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/unsqueeze_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int UnsqueezeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_UNSQUEEZE_INFER_H diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/deconvolution_tensorrt.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/unstack_infer.c similarity index 31% rename from mindspore-lite/src/extendrt/delegate/tensorrt/op/deconvolution_tensorrt.h rename to mindspore-lite/ops/kernel/cpu/nnacl/infer/unstack_infer.c index 37a326155e2c4c9ddb39e666edce49f9a1f4d41d..33eb635e278364d49b5f832bf5c3c1a9c84c4152 100644 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/deconvolution_tensorrt.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/unstack_infer.c @@ -1,5 +1,5 @@ /** - * Copyright 2021-2022 Huawei Technologies Co., Ltd + * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,33 +13,47 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_DECONVOLUTION_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_DECONVOLUTION_TENSORRT_H_ -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" -#include "infer/cxx_api/conv2d_transpose_fusion.h" -namespace mindspore::lite { -class DeconvolutionTensorRT : public TensorRTOp { - public: - DeconvolutionTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, const std::string &name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} +#include "nnacl/infer/unstack_infer.h" +#include "nnacl/infer/infer_register.h" - ~DeconvolutionTensorRT() override; +int UnstackInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } - int AddInnerOp(TensorRTContext *ctx) override; + const TensorC *input = inputs[0]; + UnstackParameter *param = (UnstackParameter *)parameter; + int axis = param->axis_ < 0 ? param->axis_ + (int)(input->shape_size_) : param->axis_; + if (axis < 0 || axis >= (int)(input->shape_size_)) { + return NNACL_PARAM_INVALID; + } + for (size_t i = 0; i < outputs_size; i++) { + SetDataTypeFormat(outputs[i], input); + } - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t output_shape_size = 0; + for (size_t i = 0; i < input->shape_size_; ++i) { + if (i != (size_t)(axis)) { + if (output_shape_size >= MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + ShapePush(output_shape, &output_shape_size, input->shape_[i]); + } + } + for (size_t i = 0; i < outputs_size; i++) { + if (outputs[i] == NULL) { + return NNACL_NULL_PTR; + } + SetShapeArray(outputs[i], output_shape, output_shape_size); + } + return NNACL_OK; +} - private: - void SetAttributes(const std::shared_ptr &conv_op, - nvinfer1::IDeconvolutionLayer *decon_layer); - - void *pack_weight_{nullptr}; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_DECONVOLUTION_TENSORRT_H_ +REG_INFER(Unstack, PrimType_Unstack, UnstackInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/unstack_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/unstack_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..a97da806ff09d6f7200a08636758943eaea08eb4 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/unstack_infer.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_UNSTACK_INFER_H +#define MINDSPORE_NNACL_UNSTACK_INFER_H + +#include "nnacl/infer/common_infer.h" +#include "nnacl/unstack_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int UnstackInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_UNSTACK_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/where_infer.c b/mindspore-lite/ops/kernel/cpu/nnacl/infer/where_infer.c new file mode 100644 index 0000000000000000000000000000000000000000..9c95a5a7b5c64067e84d2613c072e14eb28768fc --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/where_infer.c @@ -0,0 +1,91 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/infer/where_infer.h" +#include "nnacl/infer/infer_register.h" +#include "nnacl/tensor_c_utils.h" +#include "nnacl/infer/broadcast_to_infer.h" + +int WhereBroadCastInferShape(const int input_shape0_size, const int input_shape1_size, const int *input_shape0, + const int *input_shape1, int *ndim, int *in_shape0, int *in_shape1, int *out_shape, + bool *has_broad_cast) { + if (input_shape0_size > MAX_SHAPE_SIZE || input_shape1_size > MAX_SHAPE_SIZE) { + return NNACL_ERR; + } + MakeUpInputShapes(input_shape0_size, input_shape1_size, input_shape0, input_shape1, ndim, in_shape0, in_shape1); + if (*ndim >= MAX_SHAPE_SIZE) { + return NNACL_INFER_INVALID; + } + return BroadCastOutputShape(in_shape0, in_shape1, *ndim, out_shape, has_broad_cast); +} + +int WhereInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + + const TensorC *input0 = inputs[0]; + TensorC *output = outputs[0]; + NNACL_CHECK_NULL_RETURN_ERR(input0); + NNACL_CHECK_NULL_RETURN_ERR(output); + + // Need to dynamically allocate at runtime. + if (inputs_size == 1) { + output->data_type_ = kNumberTypeInt32; + output->format_ = input0->format_; + return NNACL_INFER_INVALID; + } + + if (inputs_size < 3 || outputs_size != 1) { + return NNACL_INPUT_TENSOR_ERROR; + } + + const TensorC *input1 = inputs[1]; + const TensorC *input2 = inputs[2]; + NNACL_CHECK_NULL_RETURN_ERR(input1); + NNACL_CHECK_NULL_RETURN_ERR(input2); + SetDataTypeFormat(output, input1); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + int in_shape0[MAX_SHAPE_SIZE] = {0}; + int in_shape1[MAX_SHAPE_SIZE] = {0}; + int in_shape2[MAX_SHAPE_SIZE] = {0}; + int output_shape[MAX_SHAPE_SIZE] = {0}; + size_t input_shape0_size = input0->shape_size_; + size_t input_shape1_size = input1->shape_size_; + size_t input_shape2_size = input2->shape_size_; + const int *input_shape0 = input0->shape_; + const int *input_shape1 = input1->shape_; + const int *input_shape2 = input2->shape_; + int ndim = (int)input_shape0_size; + bool has_broad_cast_1 = false; + bool has_broad_cast_2 = false; + if (WhereBroadCastInferShape(input_shape0_size, input_shape1_size, input_shape0, input_shape1, &ndim, in_shape0, + in_shape1, output_shape, &has_broad_cast_1) != NNACL_OK) { + return NNACL_ERR; + } + if (WhereBroadCastInferShape(ndim, input_shape2_size, output_shape, input_shape2, &ndim, in_shape0, in_shape2, + output_shape, &has_broad_cast_2) != NNACL_OK) { + return NNACL_ERR; + } + ShapeSet(output->shape_, &output->shape_size_, output_shape, ndim); + return NNACL_OK; +} + +REG_INFER(Where, PrimType_Where, WhereInferShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/infer/where_infer.h b/mindspore-lite/ops/kernel/cpu/nnacl/infer/where_infer.h new file mode 100644 index 0000000000000000000000000000000000000000..5ce524675a326ce062dd74450002d0dddcdbe23b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/infer/where_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_WHERE_INFER_H +#define MINDSPORE_NNACL_WHERE_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int WhereInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_WHERE_INFER_H diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/instance_norm_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/instance_norm_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..0b9400be857748e3bda4ca63477c21bae1607d19 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/instance_norm_parameter.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INSTANCE_NORM_PARAMETER_H_ +#define NNACL_INSTANCE_NORM_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct InstanceNormParameter { + // Primitive parameter + OpParameter op_parameter_; + float epsilon_; + // shape correlative + int batch_; + int channel_; + int inner_size_; +} InstanceNormParameter; + +#endif // NNACL_INSTANCE_NORM_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/add_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/add_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..0140a0ec612b96f5775371f60f880604a1a68fac --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/add_int8.c @@ -0,0 +1,531 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/add_int8.h" +#include "nnacl/intrinsics/ms_simd_instructions.h" +#ifdef ENABLE_AVX +#include "nnacl/intrinsics/avx/common_utils.h" +#endif +#include "nnacl/int8/fixed_point.h" + +#ifdef ENABLE_ARM +void AddInt8InputRounding(int32x4_t *in1, int32x4_t *in2, int32x4_t *in3, int32x4_t *in4, const int32x4_t left_vec, + const int32x4_t right_vec, const int32_t multiplier) { + // Apply left shift + *in1 = vmulq_s32(*in1, left_vec); + *in2 = vmulq_s32(*in2, left_vec); + *in3 = vmulq_s32(*in3, left_vec); + *in4 = vmulq_s32(*in4, left_vec); + + // Apply the fixed-point part of the multiplier. + *in1 = vqrdmulhq_n_s32(*in1, multiplier); + *in2 = vqrdmulhq_n_s32(*in2, multiplier); + *in3 = vqrdmulhq_n_s32(*in3, multiplier); + *in4 = vqrdmulhq_n_s32(*in4, multiplier); + + // Apply right shift + *in1 = vqaddq_s32(*in1, vshrq_n_s32(vandq_s32(*in1, right_vec), 31)); + *in2 = vqaddq_s32(*in2, vshrq_n_s32(vandq_s32(*in2, right_vec), 31)); + *in3 = vqaddq_s32(*in3, vshrq_n_s32(vandq_s32(*in3, right_vec), 31)); + *in4 = vqaddq_s32(*in4, vshrq_n_s32(vandq_s32(*in4, right_vec), 31)); + + *in1 = vrshlq_s32(*in1, right_vec); + *in2 = vrshlq_s32(*in2, right_vec); + *in3 = vrshlq_s32(*in3, right_vec); + *in4 = vrshlq_s32(*in4, right_vec); +} + +void AddInt8OutputRounding(int32x4_t *out1, int32x4_t *out2, int32x4_t *out3, int32x4_t *out4, const int32x4_t left_vec, + const int32x4_t right_vec, const int32_t multiplier) { + // Apply left shift + *out1 = vshlq_s32(*out1, left_vec); + *out2 = vshlq_s32(*out2, left_vec); + *out3 = vshlq_s32(*out3, left_vec); + *out4 = vshlq_s32(*out4, left_vec); + + // Apply the fixed-point part of the multiplier. + *out1 = vqrdmulhq_n_s32(*out1, multiplier); + *out2 = vqrdmulhq_n_s32(*out2, multiplier); + *out3 = vqrdmulhq_n_s32(*out3, multiplier); + *out4 = vqrdmulhq_n_s32(*out4, multiplier); + + // Apply right shift + *out1 = vqaddq_s32(*out1, vshrq_n_s32(vandq_s32(*out1, right_vec), 31)); + *out2 = vqaddq_s32(*out2, vshrq_n_s32(vandq_s32(*out2, right_vec), 31)); + *out3 = vqaddq_s32(*out3, vshrq_n_s32(vandq_s32(*out3, right_vec), 31)); + *out4 = vqaddq_s32(*out4, vshrq_n_s32(vandq_s32(*out4, right_vec), 31)); + + *out1 = vrshlq_s32(*out1, right_vec); + *out2 = vrshlq_s32(*out2, right_vec); + *out3 = vrshlq_s32(*out3, right_vec); + *out4 = vrshlq_s32(*out4, right_vec); +} +#endif + +void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int size, const AddQuantParameter *params) { + int in0_left_shift = (1 << params->left_shift_) * (1 << params->in0_args_.left_shift_); + int in1_left_shift = (1 << params->left_shift_) * (1 << params->in1_args_.left_shift_); + int index = 0; +#ifdef ENABLE_ARM + const int8x16_t min_vec = vdupq_n_s8(params->min_); + const int8x16_t max_vec = vdupq_n_s8(params->max_); + + const int16x8_t in0_zp_vec = vdupq_n_s16(params->in0_args_.zp_); + const int16x8_t in1_zp_vec = vdupq_n_s16(params->in1_args_.zp_); + const int16x8_t out_zp_vec = vdupq_n_s16(params->out_zp_); + + const int32x4_t in0_left_vec = vdupq_n_s32(in0_left_shift); + const int32x4_t in1_left_vec = vdupq_n_s32(in1_left_shift); + + const int32x4_t in0_right_vec = vdupq_n_s32(-params->in0_args_.right_shift_); + const int32x4_t in1_right_vec = vdupq_n_s32(-params->in1_args_.right_shift_); + + const int32x4_t out_left_vec = vdupq_n_s32(params->out_left_shift_); + const int32x4_t out_right_vec = vdupq_n_s32(-params->out_right_shift_); + + for (; index <= size - 16; index += 16) { + const int8x16_t in0_src = vld1q_s8(input0 + index); + const int8x16_t in1_src = vld1q_s8(input1 + index); + + const int16x8_t in0_s16_low = vmovl_s8(vget_low_s8(in0_src)); + const int16x8_t in0_s16_high = vmovl_s8(vget_high_s8(in0_src)); + const int16x8_t in1_s16_low = vmovl_s8(vget_low_s8(in1_src)); + const int16x8_t in1_s16_high = vmovl_s8(vget_high_s8(in1_src)); + + const int16x8_t in0_zp_low = vaddq_s16(in0_s16_low, in0_zp_vec); + const int16x8_t in0_zp_high = vaddq_s16(in0_s16_high, in0_zp_vec); + const int16x8_t in1_zp_low = vaddq_s16(in1_s16_low, in1_zp_vec); + const int16x8_t in1_zp_high = vaddq_s16(in1_s16_high, in1_zp_vec); + + int32x4_t in0_1 = vmovl_s16(vget_low_s16(in0_zp_low)); + int32x4_t in0_2 = vmovl_s16(vget_high_s16(in0_zp_low)); + int32x4_t in0_3 = vmovl_s16(vget_low_s16(in0_zp_high)); + int32x4_t in0_4 = vmovl_s16(vget_high_s16(in0_zp_high)); + int32x4_t in1_1 = vmovl_s16(vget_low_s16(in1_zp_low)); + int32x4_t in1_2 = vmovl_s16(vget_high_s16(in1_zp_low)); + int32x4_t in1_3 = vmovl_s16(vget_low_s16(in1_zp_high)); + int32x4_t in1_4 = vmovl_s16(vget_high_s16(in1_zp_high)); + + AddInt8InputRounding(&in0_1, &in0_2, &in0_3, &in0_4, in0_left_vec, in0_right_vec, params->in0_args_.multiplier_); + AddInt8InputRounding(&in1_1, &in1_2, &in1_3, &in1_4, in1_left_vec, in1_right_vec, params->in1_args_.multiplier_); + + /* calculate output */ + int32x4_t out1 = vaddq_s32(in0_1, in1_1); + int32x4_t out2 = vaddq_s32(in0_2, in1_2); + int32x4_t out3 = vaddq_s32(in0_3, in1_3); + int32x4_t out4 = vaddq_s32(in0_4, in1_4); + + AddInt8OutputRounding(&out1, &out2, &out3, &out4, out_left_vec, out_right_vec, params->out_multiplier_); + + const int16x4_t out1_s16 = vmovn_s32(out1); + const int16x4_t out2_s16 = vmovn_s32(out2); + const int16x4_t out3_s16 = vmovn_s32(out3); + const int16x4_t out4_s16 = vmovn_s32(out4); + + const int16x8_t out_s16_1 = vaddq_s16(vcombine_s16(out1_s16, out2_s16), out_zp_vec); + const int16x8_t out_s16_2 = vaddq_s16(vcombine_s16(out3_s16, out4_s16), out_zp_vec); + + const int8x16_t out = vcombine_s8(vqmovn_s16(out_s16_1), vqmovn_s16(out_s16_2)); + const int8x16_t int8_out = vmaxq_s8(min_vec, vminq_s8(max_vec, out)); + + vst1q_s8(output + index, int8_out); + } +#endif + for (; index < size; index++) { + const int32_t in0_left = (input0[index] + params->in0_args_.zp_) * in0_left_shift; + const int32_t in1_left = (input1[index] + params->in1_args_.zp_) * in1_left_shift; + const int32_t in0 = + MultiplyByMultiplierAndRightShift(in0_left, params->in0_args_.multiplier_, params->in0_args_.right_shift_); + const int32_t in1 = + MultiplyByMultiplierAndRightShift(in1_left, params->in1_args_.multiplier_, params->in1_args_.right_shift_); + + int32_t out = MultiplyByQuantizedMultiplier(in0 + in1, params->out_multiplier_, params->out_left_shift_, + -params->out_right_shift_); + out += params->out_zp_; + output[index] = (int8_t)MSMAX(params->min_, MSMIN(out, params->max_)); + } + return; +} + +void AddOptInt8(const int8_t *ptr_in, const int8_t element_in, int8_t *output, int size, + const AddQuantParameter *params, const AddQuantQrgs *ptr_args, const AddQuantQrgs *ele_args) { + int ptr_left_shift = (1 << params->left_shift_) * (1 << ptr_args->left_shift_); + int ele_left_shift = (1 << params->left_shift_) * (1 << ele_args->left_shift_); + int index = 0; + +#ifdef ENABLE_ARM + /* const value init */ + const int8x16_t min_vec = vdupq_n_s8(params->min_); + const int8x16_t max_vec = vdupq_n_s8(params->max_); + + const int16x8_t ptr_zp_vec = vdupq_n_s16(ptr_args->zp_); + const int16x8_t ele_zp_vec = vdupq_n_s16(ele_args->zp_); + const int16x8_t out_zp_vec = vdupq_n_s16(params->out_zp_); + + const int32x4_t ptr_left_vec = vdupq_n_s32(ptr_left_shift); + const int32x4_t ele_left_vec = vdupq_n_s32(ele_left_shift); + + const int32x4_t ptr_right_vec = vdupq_n_s32(-ptr_args->right_shift_); + const int32x4_t ele_right_vec = vdupq_n_s32(-ele_args->right_shift_); + + const int32x4_t out_left_vec = vdupq_n_s32(params->out_left_shift_); + const int32x4_t out_right_vec = vdupq_n_s32(-params->out_right_shift_); + + /* deal with const node */ + const int8x16_t ele_src = vdupq_n_s8(element_in); + const int16x8_t ele_s16_low = vmovl_s8(vget_low_s8(ele_src)); + const int16x8_t ele_s16_high = vmovl_s8(vget_high_s8(ele_src)); + const int16x8_t ele_zp_low = vaddq_s16(ele_s16_low, ele_zp_vec); + const int16x8_t ele_zp_high = vaddq_s16(ele_s16_high, ele_zp_vec); + int32x4_t ele1 = vmovl_s16(vget_low_s16(ele_zp_low)); + int32x4_t ele2 = vmovl_s16(vget_high_s16(ele_zp_low)); + int32x4_t ele3 = vmovl_s16(vget_low_s16(ele_zp_high)); + int32x4_t ele4 = vmovl_s16(vget_high_s16(ele_zp_high)); + + AddInt8InputRounding(&ele1, &ele2, &ele3, &ele4, ele_left_vec, ele_right_vec, ele_args->multiplier_); + + for (; index <= size - 16; index += 16) { + const int8x16_t ptr_src = vld1q_s8(ptr_in + index); + + const int16x8_t ptr_s16_low = vmovl_s8(vget_low_s8(ptr_src)); + const int16x8_t ptr_s16_high = vmovl_s8(vget_high_s8(ptr_src)); + + const int16x8_t ptr_zp_low = vaddq_s16(ptr_s16_low, ptr_zp_vec); + const int16x8_t ptr_zp_high = vaddq_s16(ptr_s16_high, ptr_zp_vec); + + int32x4_t ptr1 = vmovl_s16(vget_low_s16(ptr_zp_low)); + int32x4_t ptr2 = vmovl_s16(vget_high_s16(ptr_zp_low)); + int32x4_t ptr3 = vmovl_s16(vget_low_s16(ptr_zp_high)); + int32x4_t ptr4 = vmovl_s16(vget_high_s16(ptr_zp_high)); + + AddInt8InputRounding(&ptr1, &ptr2, &ptr3, &ptr4, ptr_left_vec, ptr_right_vec, ptr_args->multiplier_); + + /* calculate output */ + int32x4_t out1 = vaddq_s32(ptr1, ele1); + int32x4_t out2 = vaddq_s32(ptr2, ele2); + int32x4_t out3 = vaddq_s32(ptr3, ele3); + int32x4_t out4 = vaddq_s32(ptr4, ele4); + + AddInt8OutputRounding(&out1, &out2, &out3, &out4, out_left_vec, out_right_vec, params->out_multiplier_); + + const int16x4_t out1_s16 = vmovn_s32(out1); + const int16x4_t out2_s16 = vmovn_s32(out2); + const int16x4_t out3_s16 = vmovn_s32(out3); + const int16x4_t out4_s16 = vmovn_s32(out4); + + const int16x8_t out_s16_1 = vaddq_s16(vcombine_s16(out1_s16, out2_s16), out_zp_vec); + const int16x8_t out_s16_2 = vaddq_s16(vcombine_s16(out3_s16, out4_s16), out_zp_vec); + + const int8x16_t out = vcombine_s8(vqmovn_s16(out_s16_1), vqmovn_s16(out_s16_2)); + const int8x16_t int8_out = vmaxq_s8(min_vec, vminq_s8(max_vec, out)); + + vst1q_s8(output + index, int8_out); + } +#endif + for (; index < size; index++) { + const int32_t ptr_left = (ptr_in[index] + ptr_args->zp_) * ptr_left_shift; + const int32_t ele_left = (element_in + ele_args->zp_) * ele_left_shift; + const int32_t ptr = MultiplyByMultiplierAndRightShift(ptr_left, ptr_args->multiplier_, ptr_args->right_shift_); + const int32_t ele = MultiplyByMultiplierAndRightShift(ele_left, ele_args->multiplier_, ele_args->right_shift_); + + int32_t out = MultiplyByQuantizedMultiplier(ptr + ele, params->out_multiplier_, params->out_left_shift_, + -params->out_right_shift_); + out += params->out_zp_; + output[index] = (int8_t)MSMAX(params->min_, MSMIN(out, params->max_)); + } + return; +} + +int ElementAddInt8(const int8_t *in0, const int8_t *in1, int8_t *out, int size) { + for (int i = 0; i < size; i++) { + out[i] = in0[i] + in1[i]; + } + return NNACL_OK; +} + +int BroadcastAddInt8(const int8_t *in0, const int8_t *in1, int8_t *tile_in0, int8_t *tile_in1, int8_t *out, int size, + ArithmeticParameter *param) { + TileDimensionsInt8(in0, in1, tile_in0, tile_in1, param); + return ElementAddInt8(tile_in0, tile_in1, out, size); +} + +#ifdef ENABLE_AVX +void AddInt8Rounding(__m128i *in1, __m128i *in2, __m128i *in3, __m128i *in4, const __m128i left_vec, + const int32_t right_shift, const __m128i multiplier) { + // Apply left shift + *in1 = _mm_mullo_epi32(*in1, left_vec); + *in2 = _mm_mullo_epi32(*in2, left_vec); + *in3 = _mm_mullo_epi32(*in3, left_vec); + *in4 = _mm_mullo_epi32(*in4, left_vec); + + // Apply the fixed-point part of the multiplier. + *in1 = _mm_qrdmulh_epi32(*in1, multiplier); + *in2 = _mm_qrdmulh_epi32(*in2, multiplier); + *in3 = _mm_qrdmulh_epi32(*in3, multiplier); + *in4 = _mm_qrdmulh_epi32(*in4, multiplier); + + // Apply right shift + int32_t in1_remainder_mask = (1ll << (right_shift)) - 1; + int32_t in1_remainder_threshold = in1_remainder_mask >> 1; + const __m128i vin1_remainder_mask = _mm_set1_epi32(in1_remainder_mask); + const __m128i vin1_remainder_threshold = _mm_set1_epi32(in1_remainder_threshold); + + const __m128i in1_remainder = + _mm_add_epi32(_mm_and_si128(*in1, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), *in1)); + *in1 = _mm_sub_epi32(_mm_rshr_epi32(*in1, right_shift), _mm_cmpgt_epi32(in1_remainder, vin1_remainder_threshold)); + + const __m128i in2_remainder = + _mm_add_epi32(_mm_and_si128(*in2, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), *in2)); + *in2 = _mm_sub_epi32(_mm_rshr_epi32(*in2, right_shift), _mm_cmpgt_epi32(in2_remainder, vin1_remainder_threshold)); + + const __m128i in3_remainder = + _mm_add_epi32(_mm_and_si128(*in3, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), *in3)); + *in3 = _mm_sub_epi32(_mm_rshr_epi32(*in3, right_shift), _mm_cmpgt_epi32(in3_remainder, vin1_remainder_threshold)); + + const __m128i in4_remainder = + _mm_add_epi32(_mm_and_si128(*in4, vin1_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), *in4)); + *in4 = _mm_sub_epi32(_mm_rshr_epi32(*in4, right_shift), _mm_cmpgt_epi32(in4_remainder, vin1_remainder_threshold)); +} + +void AddInt8_AVX2(const int8_t *input0, const int8_t *input1, int8_t *output, int size, + const AddQuantParameter *params) { + const int in0_left_shift = (1 << params->left_shift_) * (1 << params->in0_args_.left_shift_); + const int in1_left_shift = (1 << params->left_shift_) * (1 << params->in1_args_.left_shift_); + const __m128i min_vec = _mm_set1_epi8(params->min_); + const __m128i max_vec = _mm_set1_epi8(params->max_); + const __m128i in0_zp_vec = _mm_set1_epi16(params->in0_args_.zp_); + const __m128i in1_zp_vec = _mm_set1_epi16(params->in1_args_.zp_); + const __m128i out_zp_vec = _mm_set1_epi16(params->out_zp_); + const __m128i in0_left_vec = _mm_set1_epi32(in0_left_shift); + const __m128i in1_left_vec = _mm_set1_epi32(in1_left_shift); + const __m128i in0_multiplier = _mm_set1_epi32(params->in0_args_.multiplier_); + const __m128i in1_multiplier = _mm_set1_epi32(params->in1_args_.multiplier_); + const __m128i out_multiplier = _mm_set1_epi32(params->out_multiplier_); + int index = 0; + for (; index <= size - 16; index += 16) { + const __m128i in0_src = _mm_loadu_si128((__m128i *)(input0 + index)); + const __m128i in1_src = _mm_loadu_si128((__m128i *)(input1 + index)); + + const __m256i in0_s16 = _mm256_cvtepi8_epi16(in0_src); + const __m128i in0_s16_low = _mm256_extractf128_si256(in0_s16, 0); + const __m128i in0_s16_high = _mm256_extractf128_si256(in0_s16, 1); + const __m256i in1_s16 = _mm256_cvtepi8_epi16(in1_src); + const __m128i in1_s16_low = _mm256_extractf128_si256(in1_s16, 0); + const __m128i in1_s16_high = _mm256_extractf128_si256(in1_s16, 1); + + const __m128i in0_zp_low = _mm_add_epi16(in0_s16_low, in0_zp_vec); + const __m128i in0_zp_high = _mm_add_epi16(in0_s16_high, in0_zp_vec); + const __m128i in1_zp_low = _mm_add_epi16(in1_s16_low, in1_zp_vec); + const __m128i in1_zp_high = _mm_add_epi16(in1_s16_high, in1_zp_vec); + + __m256i tmp_in0 = _mm256_cvtepi16_epi32(in0_zp_low); + __m128i in0_1 = _mm256_extractf128_si256(tmp_in0, 0); + __m128i in0_2 = _mm256_extractf128_si256(tmp_in0, 1); + tmp_in0 = _mm256_cvtepi16_epi32(in0_zp_high); + __m128i in0_3 = _mm256_extractf128_si256(tmp_in0, 0); + __m128i in0_4 = _mm256_extractf128_si256(tmp_in0, 1); + __m256i tmp_in1 = _mm256_cvtepi16_epi32(in1_zp_low); + __m128i in1_1 = _mm256_extractf128_si256(tmp_in1, 0); + __m128i in1_2 = _mm256_extractf128_si256(tmp_in1, 1); + tmp_in1 = _mm256_cvtepi16_epi32(in1_zp_high); + __m128i in1_3 = _mm256_extractf128_si256(tmp_in1, 0); + __m128i in1_4 = _mm256_extractf128_si256(tmp_in1, 1); + + AddInt8Rounding(&in0_1, &in0_2, &in0_3, &in0_4, in0_left_vec, params->in0_args_.right_shift_, in0_multiplier); + AddInt8Rounding(&in1_1, &in1_2, &in1_3, &in1_4, in1_left_vec, params->in1_args_.right_shift_, in1_multiplier); + + /* calculate output */ + __m128i out1 = _mm_add_epi32(in0_1, in1_1); + __m128i out2 = _mm_add_epi32(in0_2, in1_2); + __m128i out3 = _mm_add_epi32(in0_3, in1_3); + __m128i out4 = _mm_add_epi32(in0_4, in1_4); + + // Apply left shift + out1 = _mm_slli_epi32(out1, params->out_left_shift_); + out2 = _mm_slli_epi32(out2, params->out_left_shift_); + out3 = _mm_slli_epi32(out3, params->out_left_shift_); + out4 = _mm_slli_epi32(out4, params->out_left_shift_); + + // Apply the fixed-point part of the multiplier. + out1 = _mm_qrdmulh_epi32(out1, out_multiplier); + out2 = _mm_qrdmulh_epi32(out2, out_multiplier); + out3 = _mm_qrdmulh_epi32(out3, out_multiplier); + out4 = _mm_qrdmulh_epi32(out4, out_multiplier); + + // Apply right shift + int32_t out_remainder_mask = (1ll << (params->out_right_shift_)) - 1; + int32_t out_remainder_threshold = out_remainder_mask >> 1; + const __m128i vout_remainder_mask = _mm_set1_epi32(out_remainder_mask); + const __m128i vout_remainder_threshold = _mm_set1_epi32(out_remainder_threshold); + const __m128i out1_remainder = + _mm_add_epi32(_mm_and_si128(out1, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out1)); + out1 = _mm_sub_epi32(_mm_rshr_epi32(out1, params->out_right_shift_), + _mm_cmpgt_epi32(out1_remainder, vout_remainder_threshold)); + const __m128i out2_remainder = + _mm_add_epi32(_mm_and_si128(out2, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out2)); + out2 = _mm_sub_epi32(_mm_rshr_epi32(out2, params->out_right_shift_), + _mm_cmpgt_epi32(out2_remainder, vout_remainder_threshold)); + const __m128i out3_remainder = + _mm_add_epi32(_mm_and_si128(out3, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out3)); + out3 = _mm_sub_epi32(_mm_rshr_epi32(out3, params->out_right_shift_), + _mm_cmpgt_epi32(out3_remainder, vout_remainder_threshold)); + const __m128i out4_remainder = + _mm_add_epi32(_mm_and_si128(out4, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out4)); + out4 = _mm_sub_epi32(_mm_rshr_epi32(out4, params->out_right_shift_), + _mm_cmpgt_epi32(out4_remainder, vout_remainder_threshold)); + + __m128i out1_s16 = _mm_packs_epi32(out1, out2); + __m128i out2_s16 = _mm_packs_epi32(out3, out4); + + __m128i out_s16_1 = _mm_add_epi16(out1_s16, out_zp_vec); + __m128i out_s16_2 = _mm_add_epi16(out2_s16, out_zp_vec); + __m128i out = _mm_packs_epi16(out_s16_1, out_s16_2); + __m128i int8_out = _mm_max_epi8(min_vec, _mm_min_epi8(max_vec, out)); + + _mm_storeu_si128((__m128i *)(output + index), int8_out); + } + for (; index < size; index++) { + const int32_t in0_left = (input0[index] + params->in0_args_.zp_) * in0_left_shift; + const int32_t in1_left = (input1[index] + params->in1_args_.zp_) * in1_left_shift; + const int32_t in0 = + MultiplyByMultiplierAndRightShift(in0_left, params->in0_args_.multiplier_, params->in0_args_.right_shift_); + const int32_t in1 = + MultiplyByMultiplierAndRightShift(in1_left, params->in1_args_.multiplier_, params->in1_args_.right_shift_); + + int32_t out = MultiplyByQuantizedMultiplier(in0 + in1, params->out_multiplier_, params->out_left_shift_, + -params->out_right_shift_); + out += params->out_zp_; + output[index] = (int8_t)MSMAX(params->min_, MSMIN(out, params->max_)); + } + return; +} + +void AddOptInt8_AVX2(const int8_t *ptr_in, const int8_t element_in, int8_t *output, int size, + const AddQuantParameter *params, const AddQuantQrgs *ptr_args, const AddQuantQrgs *ele_args) { + // input0: ptr_in + // input1: element_in + // load quant parameters of input0 and input1 + const int in0_left_shift = (1 << params->left_shift_) * (1 << ptr_args->left_shift_); + const int in1_left_shift = (1 << params->left_shift_) * (1 << ele_args->left_shift_); + const __m128i min_vec = _mm_set1_epi8(params->min_); + const __m128i max_vec = _mm_set1_epi8(params->max_); + const __m128i in0_zp_vec = _mm_set1_epi16(ptr_args->zp_); + const __m128i in1_zp_vec = _mm_set1_epi16(ele_args->zp_); + const __m128i out_zp_vec = _mm_set1_epi16(params->out_zp_); + const __m128i in0_left_vec = _mm_set1_epi32(in0_left_shift); + const __m128i in1_left_vec = _mm_set1_epi32(in1_left_shift); + const __m128i in0_multiplier = _mm_set1_epi32(params->in0_args_.multiplier_); + const __m128i in1_multiplier = _mm_set1_epi32(params->in1_args_.multiplier_); + const __m128i out_multiplier = _mm_set1_epi32(params->out_multiplier_); + + // input1 can be processed once because it is const + const __m128i in1_src = _mm_set1_epi8(element_in); + const __m256i in1_s16 = _mm256_cvtepi8_epi16(in1_src); + const __m128i in1_s16_low = _mm256_extractf128_si256(in1_s16, 0); + const __m128i in1_s16_high = _mm256_extractf128_si256(in1_s16, 1); + const __m128i in1_zp_low = _mm_add_epi16(in1_s16_low, in1_zp_vec); + const __m128i in1_zp_high = _mm_add_epi16(in1_s16_high, in1_zp_vec); + __m256i tmp_in1 = _mm256_cvtepi16_epi32(in1_zp_low); + __m128i in1_1 = _mm256_extractf128_si256(tmp_in1, 0); + __m128i in1_2 = _mm256_extractf128_si256(tmp_in1, 1); + tmp_in1 = _mm256_cvtepi16_epi32(in1_zp_high); + __m128i in1_3 = _mm256_extractf128_si256(tmp_in1, 0); + __m128i in1_4 = _mm256_extractf128_si256(tmp_in1, 1); + + AddInt8Rounding(&in1_1, &in1_2, &in1_3, &in1_4, in1_left_vec, params->in1_args_.right_shift_, in1_multiplier); + + int index = 0; + for (; index <= size - 16; index += 16) { + const __m128i in0_src = _mm_loadu_si128((__m128i *)(ptr_in + index)); + const __m256i in0_s16 = _mm256_cvtepi8_epi16(in0_src); + const __m128i in0_s16_low = _mm256_extractf128_si256(in0_s16, 0); + const __m128i in0_s16_high = _mm256_extractf128_si256(in0_s16, 1); + const __m128i in0_zp_low = _mm_add_epi16(in0_s16_low, in0_zp_vec); + const __m128i in0_zp_high = _mm_add_epi16(in0_s16_high, in0_zp_vec); + + __m256i tmp_in0 = _mm256_cvtepi16_epi32(in0_zp_low); + __m128i in0_1 = _mm256_extractf128_si256(tmp_in0, 0); + __m128i in0_2 = _mm256_extractf128_si256(tmp_in0, 1); + tmp_in0 = _mm256_cvtepi16_epi32(in0_zp_high); + __m128i in0_3 = _mm256_extractf128_si256(tmp_in0, 0); + __m128i in0_4 = _mm256_extractf128_si256(tmp_in0, 1); + + AddInt8Rounding(&in0_1, &in0_2, &in0_3, &in0_4, in0_left_vec, params->in0_args_.right_shift_, in0_multiplier); + + /* calculate output */ + __m128i out1 = _mm_add_epi32(in0_1, in1_1); + __m128i out2 = _mm_add_epi32(in0_2, in1_2); + __m128i out3 = _mm_add_epi32(in0_3, in1_3); + __m128i out4 = _mm_add_epi32(in0_4, in1_4); + + // Apply left shift + out1 = _mm_slli_epi32(out1, params->out_left_shift_); + out2 = _mm_slli_epi32(out2, params->out_left_shift_); + out3 = _mm_slli_epi32(out3, params->out_left_shift_); + out4 = _mm_slli_epi32(out4, params->out_left_shift_); + + // Apply the fixed-point part of the multiplier. + out1 = _mm_qrdmulh_epi32(out1, out_multiplier); + out2 = _mm_qrdmulh_epi32(out2, out_multiplier); + out3 = _mm_qrdmulh_epi32(out3, out_multiplier); + out4 = _mm_qrdmulh_epi32(out4, out_multiplier); + + // Apply right shift + int32_t out_remainder_mask = (1ll << (params->out_right_shift_)) - 1; + int32_t out_remainder_threshold = out_remainder_mask >> 1; + const __m128i vout_remainder_mask = _mm_set1_epi32(out_remainder_mask); + const __m128i vout_remainder_threshold = _mm_set1_epi32(out_remainder_threshold); + const __m128i out1_remainder = + _mm_add_epi32(_mm_and_si128(out1, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out1)); + out1 = _mm_sub_epi32(_mm_rshr_epi32(out1, params->out_right_shift_), + _mm_cmpgt_epi32(out1_remainder, vout_remainder_threshold)); + const __m128i out2_remainder = + _mm_add_epi32(_mm_and_si128(out2, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out2)); + out2 = _mm_sub_epi32(_mm_rshr_epi32(out2, params->out_right_shift_), + _mm_cmpgt_epi32(out2_remainder, vout_remainder_threshold)); + const __m128i out3_remainder = + _mm_add_epi32(_mm_and_si128(out3, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out3)); + out3 = _mm_sub_epi32(_mm_rshr_epi32(out3, params->out_right_shift_), + _mm_cmpgt_epi32(out3_remainder, vout_remainder_threshold)); + const __m128i out4_remainder = + _mm_add_epi32(_mm_and_si128(out4, vout_remainder_mask), _mm_cmpgt_epi32(_mm_setzero_si128(), out4)); + out4 = _mm_sub_epi32(_mm_rshr_epi32(out4, params->out_right_shift_), + _mm_cmpgt_epi32(out4_remainder, vout_remainder_threshold)); + + __m128i out1_s16 = _mm_packs_epi32(out1, out2); + __m128i out2_s16 = _mm_packs_epi32(out3, out4); + + __m128i out_s16_1 = _mm_add_epi16(out1_s16, out_zp_vec); + __m128i out_s16_2 = _mm_add_epi16(out2_s16, out_zp_vec); + __m128i out = _mm_packs_epi16(out_s16_1, out_s16_2); + __m128i int8_out = _mm_max_epi8(min_vec, _mm_min_epi8(max_vec, out)); + + _mm_storeu_si128((__m128i *)(output + index), int8_out); + } + for (; index < size; index++) { + const int32_t in0_left = (ptr_in[index] + ptr_args->zp_) * in0_left_shift; + const int32_t in1_left = (element_in + ele_args->zp_) * in1_left_shift; + const int32_t in0 = MultiplyByMultiplierAndRightShift(in0_left, ptr_args->multiplier_, ptr_args->right_shift_); + const int32_t in1 = MultiplyByMultiplierAndRightShift(in1_left, ele_args->multiplier_, ele_args->right_shift_); + + int32_t out = MultiplyByQuantizedMultiplier(in0 + in1, params->out_multiplier_, params->out_left_shift_, + -params->out_right_shift_); + out += params->out_zp_; + output[index] = (int8_t)MSMAX(params->min_, MSMIN(out, params->max_)); + } + return; +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/add_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/add_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..333da821a4a8e14b78cc0c21bbe622bf077b37ac --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/add_int8.h @@ -0,0 +1,70 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_ADD_INT8_H_ +#define NNACL_ADD_INT8_H_ + +#include "nnacl/op_base.h" +#include "nnacl/errorcode.h" +#include "nnacl/arithmetic_parameter.h" +#include "nnacl/int8/arithmetic_int8.h" + +typedef struct AddQuantQrgs { + int32_t zp_; + int32_t left_shift_; + int32_t right_shift_; + int32_t multiplier_; +} AddQuantQrgs; + +typedef struct AddQuantParameter { + int left_shift_; + int32_t min_; + int32_t max_; + + AddQuantQrgs in0_args_; + AddQuantQrgs in1_args_; + + int32_t out_zp_; + int32_t out_left_shift_; + int32_t out_right_shift_; + int32_t out_multiplier_; +} AddQuantParameter; + +#ifdef __cplusplus +extern "C" { +#endif + +void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int size, const AddQuantParameter *params); + +void AddOptInt8(const int8_t *ptr_in, const int8_t element_in, int8_t *output, int size, + const AddQuantParameter *params, const AddQuantQrgs *ptr_args, const AddQuantQrgs *ele_args); + +int ElementAddInt8(const int8_t *in0, const int8_t *in1, int8_t *out, int size); + +int BroadcastAddInt8(const int8_t *in0, const int8_t *in1, int8_t *tile_in0, int8_t *tile_in1, int8_t *out, int size, + ArithmeticParameter *param); + +#ifdef ENABLE_AVX +void AddInt8_AVX2(const int8_t *input0, const int8_t *input1, int8_t *output, int size, + const AddQuantParameter *params); + +void AddOptInt8_AVX2(const int8_t *ptr_in, const int8_t element_in, int8_t *output, int size, + const AddQuantParameter *params, const AddQuantQrgs *ptr_args, const AddQuantQrgs *ele_args); +#endif +#ifdef __cplusplus +} +#endif + +#endif // NNACL_ADD_INT8_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/arg_min_max_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/arg_min_max_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..22fb1ab98eb64242535d83e1b7ac271936decbbb --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/arg_min_max_int8.c @@ -0,0 +1,237 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/int8/arg_min_max_int8.h" +#include + +void CalcParameter(const int32_t *shape, int dims_number, int axis, int32_t *pre_axis_count, int32_t *axis_count, + int32_t *after_axis_count) { + *pre_axis_count = 1; + for (int i = 0; i < axis; ++i) { + *pre_axis_count = (*pre_axis_count) * shape[i]; + } + + *axis_count = shape[axis]; + + *after_axis_count = 1; + for (int i = axis + 1; i < dims_number; ++i) { + *after_axis_count = (*after_axis_count) * shape[i]; + } +} + +void SetOutputValue(float value, int32_t index, int8_t *output1, int8_t *output2, int offset, + float output_inverse_scale, float output_zp, bool out_value) { + if (output2 != NULL) { + int32_t *output1_index = (int32_t *)output1; + output1_index[offset] = index; + output2[offset] = value * output_inverse_scale + output_zp; + } else { + if (out_value) { + output1[offset] = value * output_inverse_scale + output_zp; + } else { + int32_t *output1_index = (int32_t *)output1; + output1_index[offset] = index; + } + } +} + +void DoArgMinMaxQuant(const int8_t *input, int8_t *output1, int8_t *output2, const ArgMinMaxComputeParam *param, + int pre_axis_count, int axis_count, int after_axis_count, const QuantArg *in_quant_arg, + const QuantArg *out_quant_arg) { + bool out_value = param->out_value_; + const float output_inverse_scale = 1.f / out_quant_arg->scale_; + float bias = -in_quant_arg->zp_ * in_quant_arg->scale_; + int32_t output_zp = out_quant_arg->zp_; + for (int i = 0; i < pre_axis_count; ++i) { + int output_offset = i * after_axis_count; + int input_offset = output_offset * axis_count; + for (int j = 0; j < after_axis_count; ++j) { + float value = -FLT_MAX; + if (!param->get_max_) { + value = FLT_MAX; + } + int32_t index = 0; + for (int k = 0; k < axis_count; ++k) { + float value_tmp = input[input_offset + k * after_axis_count + j] * in_quant_arg->scale_ + bias; + if (param->get_max_) { + if (value_tmp > value) { + value = value_tmp; + index = k; + } + } else { + if (value_tmp < value) { + value = value_tmp; + index = k; + } + } + } + SetOutputValue(value, index, output1, output2, output_offset + j, output_inverse_scale, output_zp, out_value); + } + } +} + +void Int8ArgMinMaxQuant(const int8_t *input, int8_t *output1, int8_t *output2, const int32_t *in_shape, + const ArgMinMaxComputeParam *param, const QuantArg *in_quant_arg, + const QuantArg *out_quant_arg) { + int pre_axis_count = 1; + int axis_count = 1; + int after_axis_count = 1; + CalcParameter(in_shape, param->dims_size_, param->axis_, &pre_axis_count, &axis_count, &after_axis_count); + DoArgMinMaxQuant(input, output1, output2, param, pre_axis_count, axis_count, after_axis_count, in_quant_arg, + out_quant_arg); + return; +} + +int ArgCompareAscInt8(const void *a, const void *b) { + return ((ArgElement *)a)->data_.f_data_ - ((ArgElement *)b)->data_.f_data_; +} + +int ArgCompareDescInt8(const void *a, const void *b) { + return ((ArgElement *)b)->data_.f_data_ - ((ArgElement *)a)->data_.f_data_; +} + +int8_t GetInt8Output(float real_out, float output_inverse_scale, int32_t output_zp) { + return real_out * output_inverse_scale + output_zp; +} + +void Int8ArgMinMaxDim0(const int8_t *input, int8_t *output1, int8_t *output2, const int32_t *in_shape, + ArgMinMaxComputeParam *param, const QuantArg *in_quant_arg, const QuantArg *out_quant_arg) { + bool out_value = param->out_value_; + const float output_inverse_scale = 1.f / out_quant_arg->scale_; + float bias = -in_quant_arg->zp_ * in_quant_arg->scale_; + int32_t output_zp = out_quant_arg->zp_; + for (int32_t i = 0; i < param->in_strides_[0]; ++i) { + for (int j = 0; j < in_shape[0]; ++j) { + int offset = param->in_strides_[0] * j + i; + param->arg_elements_[j].index_ = (uint32_t)j; + param->arg_elements_[j].data_.f_data_ = input[offset] * in_quant_arg->scale_ + bias; + } + if (param->get_max_) { + qsort(param->arg_elements_, in_shape[0], sizeof(ArgElement), ArgCompareDescInt8); + } else { + qsort(param->arg_elements_, in_shape[0], sizeof(ArgElement), ArgCompareAscInt8); + } + + for (int j = 0; j < param->topk_; ++j) { + int out_offset = j * param->out_strides_[0] + i; + SetOutputValue(param->arg_elements_[j].data_.f_data_, param->arg_elements_[j].index_, output1, output2, + out_offset, output_inverse_scale, output_zp, out_value); + } + } +} + +void Int8ArgMinMaxDim1(const int8_t *input, int8_t *output1, int8_t *output2, const int32_t *in_shape, + ArgMinMaxComputeParam *param, const QuantArg *in_quant_arg, const QuantArg *out_quant_arg) { + bool out_value = param->out_value_; + const float output_inverse_scale = 1.f / out_quant_arg->scale_; + float bias = -in_quant_arg->zp_ * in_quant_arg->scale_; + int32_t output_zp = out_quant_arg->zp_; + int in_shape1 = in_shape[1]; + for (int i = 0; i < in_shape[0]; ++i) { + int in_dim0_offset = i * param->in_strides_[0]; + int out_dim0_offset = i * param->out_strides_[0]; + for (int j = 0; j < param->in_strides_[1]; ++j) { + for (int k = 0; k < in_shape1; ++k) { + int offset = param->in_strides_[1] * k + in_dim0_offset + j; + param->arg_elements_[k].index_ = (size_t)k; + param->arg_elements_[k].data_.f_data_ = input[offset] * in_quant_arg->scale_ + bias; + } + if (param->get_max_) { + qsort(param->arg_elements_, in_shape1, sizeof(ArgElement), ArgCompareDescInt8); + } else { + qsort(param->arg_elements_, in_shape1, sizeof(ArgElement), ArgCompareAscInt8); + } + + for (int k = 0; k < param->topk_; ++k) { + int out_offset = out_dim0_offset + j + k * param->out_strides_[1]; + SetOutputValue(param->arg_elements_[j].data_.f_data_, param->arg_elements_[j].index_, output1, output2, + out_offset, output_inverse_scale, output_zp, out_value); + } + } + } +} + +void Int8ArgMinMaxDim2(const int8_t *input, int8_t *output1, int8_t *output2, const int32_t *in_shape, + ArgMinMaxComputeParam *param, const QuantArg *in_quant_arg, const QuantArg *out_quant_arg) { + bool out_value = param->out_value_; + const float output_inverse_scale = 1.f / out_quant_arg->scale_; + float bias = -in_quant_arg->zp_ * in_quant_arg->scale_; + int32_t output_zp = out_quant_arg->zp_; + int in_shape1 = in_shape[1]; + int in_shape2 = in_shape[2]; + for (int i = 0; i < in_shape[0]; ++i) { + int in_dim0_offset = i * param->in_strides_[0]; + int out_dim0_offset = i * param->out_strides_[0]; + for (int j = 0; j < in_shape1; ++j) { + int in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset; + int out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset; + for (int k = 0; k < param->in_strides_[2]; ++k) { + for (int l = 0; l < in_shape2; ++l) { + int offset = param->in_strides_[2] * l + k + in_dim1_offset; + param->arg_elements_[l].index_ = (uint32_t)l; + param->arg_elements_[l].data_.f_data_ = input[offset] * in_quant_arg->scale_ + bias; + } + if (param->get_max_) { + qsort(param->arg_elements_, in_shape2, sizeof(ArgElement), ArgCompareDescInt8); + } else { + qsort(param->arg_elements_, in_shape2, sizeof(ArgElement), ArgCompareAscInt8); + } + for (int l = 0; l < param->topk_; ++l) { + int out_offset = out_dim1_offset + k + l * param->out_strides_[2]; + SetOutputValue(param->arg_elements_[j].data_.f_data_, param->arg_elements_[j].index_, output1, output2, + out_offset, output_inverse_scale, output_zp, out_value); + } + } + } + } +} + +void Int8ArgMinMaxDim3(const int8_t *input, int8_t *output1, int8_t *output2, const int32_t *in_shape, + ArgMinMaxComputeParam *param, const QuantArg *in_quant_arg, const QuantArg *out_quant_arg) { + bool out_value = param->out_value_; + const float output_inverse_scale = 1.f / out_quant_arg->scale_; + float bias = -in_quant_arg->zp_ * in_quant_arg->scale_; + int32_t output_zp = out_quant_arg->zp_; + int in_shape1 = in_shape[1]; + int in_shape2 = in_shape[2]; + int in_shape3 = in_shape[3]; + for (int i = 0; i < in_shape[0]; ++i) { + int in_dim0_offset = i * param->in_strides_[0]; + int out_dim0_offset = i * param->out_strides_[0]; + for (int j = 0; j < in_shape1; ++j) { + int in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset; + int out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset; + for (int k = 0; k < in_shape2; ++k) { + int in_dim2_offset = k * param->in_strides_[2] + in_dim1_offset; + int out_dim2_offset = k * param->out_strides_[2] + out_dim1_offset; + for (int l = 0; l < in_shape3; ++l) { + int offset = l + in_dim2_offset; + param->arg_elements_[l].index_ = (uint32_t)l; + param->arg_elements_[l].data_.f_data_ = input[offset] * in_quant_arg->scale_ + bias; + } + if (param->get_max_) { + qsort(param->arg_elements_, in_shape3, sizeof(ArgElement), ArgCompareDescInt8); + } else { + qsort(param->arg_elements_, in_shape3, sizeof(ArgElement), ArgCompareAscInt8); + } + for (int l = 0; l < param->topk_; ++l) { + int out_offset = out_dim2_offset + l; + SetOutputValue(param->arg_elements_[j].data_.f_data_, param->arg_elements_[j].index_, output1, output2, + out_offset, output_inverse_scale, output_zp, out_value); + } + } + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/arg_min_max_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/arg_min_max_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..966bb549fe06b5f32b16f9609e7115ac034e764a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/arg_min_max_int8.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_INT8_ARG_MIN_MAX_INT8_H_ +#define NNACL_INT8_ARG_MIN_MAX_INT8_H_ + +#include "nnacl/arg_min_max_parameter.h" +#include "nnacl/int8/quantize.h" +#include "nnacl/kernel/arg_min_max.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void Int8ArgMinMaxQuant(const int8_t *input, int8_t *output1, int8_t *output2, const int32_t *in_shape, + const ArgMinMaxComputeParam *param, const QuantArg *in_quant, const QuantArg *out_quant); +void Int8ArgMinMaxDim0(const int8_t *input, int8_t *output1, int8_t *output2, const int32_t *in_shape, + ArgMinMaxComputeParam *param, const QuantArg *in_quant, const QuantArg *out_quant); +void Int8ArgMinMaxDim1(const int8_t *input, int8_t *output1, int8_t *output2, const int32_t *in_shape, + ArgMinMaxComputeParam *param, const QuantArg *in_quant, const QuantArg *out_quant); +void Int8ArgMinMaxDim2(const int8_t *input, int8_t *output1, int8_t *output2, const int32_t *in_shape, + ArgMinMaxComputeParam *param, const QuantArg *in_quant, const QuantArg *out_quant); +void Int8ArgMinMaxDim3(const int8_t *input, int8_t *output1, int8_t *output2, const int32_t *in_shape, + ArgMinMaxComputeParam *param, const QuantArg *in_quant, const QuantArg *out_quant); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_ARG_MIN_MAX_INT8_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/arithmetic_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/arithmetic_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..5e24dbbedb76a7626f0d1b88a7e252a23b48f2cb --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/arithmetic_int8.c @@ -0,0 +1,137 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/arithmetic_int8.h" +#ifdef ENABLE_NEON +#include +#endif +#include "nnacl/errorcode.h" + +void TileOneDimensionInt8(const int8_t *inData, int8_t *outData, int dim, size_t ndim, const int32_t *inShape, + const int32_t *inStrides, const int32_t *outStrides, const int32_t *multiple) { + int srcDimSize = inShape[dim]; + if (dim == ndim - 1) { + for (int i = 0; i < multiple[dim]; i++) { + memcpy(outData, inData, srcDimSize * sizeof(int8_t)); + outData += srcDimSize; + } + return; + } + for (size_t i = 0; i < srcDimSize; i++) { + for (size_t j = 0; j < multiple[dim]; j++) { + TileOneDimensionInt8(inData + inStrides[dim] * i, outData + outStrides[dim] * (i + j * srcDimSize), dim + 1, ndim, + inShape, inStrides, outStrides, multiple); + } + } +} + +void TileDimensionsInt8(const int8_t *data0, const int8_t *data1, int8_t *tile_data0, int8_t *tile_data1, + ArithmeticParameter *param) { + CalcMultiplesAndStrides(param); + TileOneDimensionInt8(data0, tile_data0, 0, param->ndim_, param->in_shape0_, param->in_strides0_, param->out_strides_, + param->multiples0_); + TileOneDimensionInt8(data1, tile_data1, 0, param->ndim_, param->in_shape1_, param->in_strides1_, param->out_strides_, + param->multiples1_); +} + +#define ACCURACY_DATA 0.00000001 + +int ElementNotEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, + ArithmeticQuantArg *quant_arg) { + float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_; + float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_; + + for (int index = 0; index < element_size; ++index) { + float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias; + float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias; + float minus_inputs = in0_real - in1_real; + bool out_real = true; + if (minus_inputs >= -ACCURACY_DATA && minus_inputs <= ACCURACY_DATA) { + out_real = false; + } + output[index] = (uint8_t)out_real; + } + return NNACL_OK; +} + +int ElementEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, ArithmeticQuantArg *quant_arg) { + float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_; + float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_; + for (int index = 0; index < element_size; ++index) { + float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias; + float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias; + float minus_inputs = in0_real - in1_real; + bool out_real = false; + if (minus_inputs >= -ACCURACY_DATA && minus_inputs <= ACCURACY_DATA) { + out_real = true; + } + output[index] = (uint8_t)out_real; + } + return NNACL_OK; +} + +int ElementLessInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, ArithmeticQuantArg *quant_arg) { + float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_; + float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_; + for (int index = 0; index < element_size; ++index) { + float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias; + float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias; + bool out_real = in0_real < in1_real; + output[index] = (uint8_t)out_real; + } + return NNACL_OK; +} + +int ElementLessEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, + ArithmeticQuantArg *quant_arg) { + float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_; + float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_; + for (int index = 0; index < element_size; ++index) { + float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias; + float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias; + bool out_real = in0_real <= in1_real; + output[index] = (uint8_t)out_real; + } + return NNACL_OK; +} + +int ElementGreaterInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, + ArithmeticQuantArg *quant_arg) { + float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_; + float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_; + for (int index = 0; index < element_size; ++index) { + float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias; + float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias; + bool out_real = in0_real > in1_real; + output[index] = (uint8_t)out_real; + } + return NNACL_OK; +} + +int ElementGreaterEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, + ArithmeticQuantArg *quant_arg) { + float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_; + float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_; + for (int index = 0; index < element_size; ++index) { + float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias; + float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias; + bool out_real = in0_real >= in1_real; + output[index] = (uint8_t)out_real; + } + return NNACL_OK; +} + +#undef ACCURACY_DATA diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/arithmetic_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/arithmetic_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..eb304a95052d11961e5173358ce75d145ca36fbd --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/arithmetic_int8.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_INT8_ARITHMETIC_INT8_H_ +#define NNACL_INT8_ARITHMETIC_INT8_H_ + +#include "nnacl/op_base.h" +#include "nnacl/int8/quantize.h" +#include "nnacl/base/arithmetic_base.h" + +#ifdef __cplusplus +extern "C" { +#endif +void TileOneDimensionInt8(const int8_t *inData, int8_t *outData, int dim, size_t ndim, const int32_t *inShape, + const int32_t *inStrides, const int32_t *outStrides, const int32_t *multiple); +void TileDimensionsInt8(const int8_t *data0, const int8_t *data1, int8_t *tile_data0, int8_t *tile_data1, + ArithmeticParameter *param); + +int ElementNotEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, + ArithmeticQuantArg *quant_arg); + +int ElementEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, ArithmeticQuantArg *quant_arg); + +int ElementLessInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, ArithmeticQuantArg *quant_arg); + +int ElementLessEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, + ArithmeticQuantArg *quant_arg); + +int ElementGreaterInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, + ArithmeticQuantArg *quant_arg); + +int ElementGreaterEqualInt8(int8_t *input0, int8_t *input1, uint8_t *output, int element_size, + ArithmeticQuantArg *quant_arg); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_ARITHMETIC_INT8_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/arithmetic_self_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/arithmetic_self_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..054ffc6b167d247f093c984a8f8a8b2bb8f37d4e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/arithmetic_self_int8.c @@ -0,0 +1,305 @@ +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/arithmetic_self_int8.h" +#include +#include +#ifdef ENABLE_NEON +#include +#include "nnacl/int8/common_func_int8.h" +#endif +#include "nnacl/int8/fixed_point.h" + +int Int8ElementFloor(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + float in_scale = para.in_args_.scale_; + int32_t in_zp = para.in_args_.zp_; + float out_scale = para.out_args_.scale_; + int32_t out_zp = para.out_args_.zp_; + float bias = in_zp * in_scale; + for (int i = 0; i < element_size; i++) { + int32_t output_tmp = round(floorf(input[i] * in_scale + bias) / out_scale) + out_zp; + if (output_tmp > para.output_activation_max_) { + output[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[i] = para.output_activation_min_; + } else { + output[i] = (int8_t)output_tmp; + } + } + return NNACL_OK; +} + +int Int8ElementRound(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + float in_scale = para.in_args_.scale_; + int32_t in_zp = para.in_args_.zp_; + float out_scale = para.out_args_.scale_; + int32_t out_zp = para.out_args_.zp_; + float bias = in_zp * in_scale; + for (int i = 0; i < element_size; i++) { + int32_t output_tmp = round(round(input[i] * in_scale + bias) / out_scale) + out_zp; + if (output_tmp > para.output_activation_max_) { + output[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[i] = para.output_activation_min_; + } else { + output[i] = (int8_t)output_tmp; + } + } + return NNACL_OK; +} + +int Int8ElementCeil(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + float in_scale = para.in_args_.scale_; + int32_t in_zp = para.in_args_.zp_; + float out_scale = para.out_args_.scale_; + int32_t out_zp = para.out_args_.zp_; + float bias = in_zp * in_scale; + for (int i = 0; i < element_size; i++) { + int32_t output_tmp = round(ceil(input[i] * in_scale + bias) / out_scale) + out_zp; + if (output_tmp > para.output_activation_max_) { + output[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[i] = para.output_activation_min_; + } else { + output[i] = (int8_t)output_tmp; + } + } + return NNACL_OK; +} + +int Int8ElementAbs(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + float in_scale = para.in_args_.scale_; + int32_t in_zp = para.in_args_.zp_; + float out_scale = para.out_args_.scale_; + int32_t out_zp = para.out_args_.zp_; + float bias = in_zp * in_scale; + for (int i = 0; i < element_size; i++) { + int32_t output_tmp = round(fabsf(input[i] * in_scale + bias) / out_scale) + out_zp; + if (output_tmp > para.output_activation_max_) { + output[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[i] = para.output_activation_min_; + } else { + output[i] = (int8_t)output_tmp; + } + } + return NNACL_OK; +} + +int Int8ElementSin(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + float in_scale = para.in_args_.scale_; + int32_t in_zp = para.in_args_.zp_; + float out_scale = para.out_args_.scale_; + int32_t out_zp = para.out_args_.zp_; + float bias = in_zp * in_scale; + for (int i = 0; i < element_size; i++) { + int32_t output_tmp = round(sinf(input[i] * in_scale + bias) / out_scale) + out_zp; + if (output_tmp > para.output_activation_max_) { + output[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[i] = para.output_activation_min_; + } else { + output[i] = (int8_t)output_tmp; + } + } + return NNACL_OK; +} + +int Int8ElementCos(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + float in_scale = para.in_args_.scale_; + int32_t in_zp = para.in_args_.zp_; + float out_scale = para.out_args_.scale_; + int32_t out_zp = para.out_args_.zp_; + float bias = in_zp * in_scale; + for (int i = 0; i < element_size; i++) { + int32_t output_tmp = round(cosf(input[i] * in_scale + bias) / out_scale) + out_zp; + if (output_tmp > para.output_activation_max_) { + output[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[i] = para.output_activation_min_; + } else { + output[i] = (int8_t)output_tmp; + } + } + return NNACL_OK; +} + +int Int8ElementLog(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + float in_scale = para.in_args_.scale_; + int32_t in_zp = para.in_args_.zp_; + float out_scale = para.out_args_.scale_; + int32_t out_zp = para.out_args_.zp_; + float bias = in_zp * in_scale; + for (int i = 0; i < element_size; i++) { + int32_t output_tmp = round(logf(input[i] * in_scale + bias) / out_scale) + out_zp; + if (output_tmp > para.output_activation_max_) { + output[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[i] = para.output_activation_min_; + } else { + output[i] = (int8_t)output_tmp; + } + } + return NNACL_OK; +} + +int Int8ElementSqrt(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + float in_scale = para.in_args_.scale_; + int32_t in_zp = para.in_args_.zp_; + float out_scale = para.out_args_.scale_; + int32_t out_zp = para.out_args_.zp_; + float bias = in_zp * in_scale; + for (int i = 0; i < element_size; i++) { + float input_f32 = input[i] * in_scale + bias; + if (input_f32 < 0) { + return NNACL_ERRCODE_SQRT_NEGATIVE; + } + int32_t output_tmp = round(sqrtf(input_f32) / out_scale) + out_zp; + if (output_tmp > para.output_activation_max_) { + output[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[i] = para.output_activation_min_; + } else { + output[i] = (int8_t)output_tmp; + } + } + return NNACL_OK; +} + +int Int8ElementRsqrt(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + float in_scale = para.in_args_.scale_; + int32_t in_zp = para.in_args_.zp_; + float out_scale = para.out_args_.scale_; + int32_t out_zp = para.out_args_.zp_; + float bias = in_zp * in_scale; + for (int i = 0; i < element_size; i++) { + float input_f32 = input[i] * in_scale + bias; + if (input_f32 <= 0) { + return NNACL_ERRCODE_RSQRT_NEGATIVE_OR_ZERO; + } + int32_t output_tmp = round(1.f / (sqrtf(input_f32) * out_scale)) + out_zp; + if (output_tmp > para.output_activation_max_) { + output[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[i] = para.output_activation_min_; + } else { + output[i] = (int8_t)output_tmp; + } + } + return NNACL_OK; +} + +#ifdef ENABLE_NEON + +int16x4_t ClacSumHalfWord(int32x4_t scaled_input, int32x4_t left_shift_out_vec, int32x4_t output_multiplier_vec, + ArithSelfQuantArg para) { + int32x4_t input_scale = vmulq_s32(scaled_input, scaled_input); + int32x4_t raw_sum = RoundingDivideByPOTInt32x4( + SaturatingRoundingDoublingHighMulInt32x4(vmulq_s32(input_scale, left_shift_out_vec), output_multiplier_vec), + para.shift_right_); + raw_sum = vaddq_s32(raw_sum, vdupq_n_s32(para.out_args_.zp_)); + raw_sum = vmaxq_s32(raw_sum, vdupq_n_s32(para.output_activation_min_)); + raw_sum = vminq_s32(raw_sum, vdupq_n_s32(para.output_activation_max_)); + return vqmovn_s32(raw_sum); +} + +void SquareInt8NEON(const int8_t *input_data, int8_t *output_data, int64_t element_size, ArithSelfQuantArg para, + int32_t *index) { + int32x4_t output_multiplier_vec = vdupq_n_s32(para.output_multiplier_); + int32x4_t left_shift_out_vec = vdupq_n_s32(1 << (size_t)para.shift_left_); + + for (; (*index) <= element_size - 8; (*index) += 8) { + int16x8_t input_val = LoadAndAddOffset(input_data, *index, para.in_args_.zp_); + int32x4_t input_low = vmovl_s16(vget_low_s16(input_val)); + int32x4_t input_high = vmovl_s16(vget_high_s16(input_val)); + + int16x4_t sum_low = ClacSumHalfWord(input_low, left_shift_out_vec, output_multiplier_vec, para); + int16x4_t sum_high = ClacSumHalfWord(input_high, left_shift_out_vec, output_multiplier_vec, para); + + int16x8_t res_s16 = vcombine_s16(sum_low, sum_high); + int8x8_t res_u8_n0 = vqmovn_s16(res_s16); + vst1_s8(output_data, res_u8_n0); + output_data += 8; + } +} +#endif + +int Int8ElementSquare(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + int32_t in_zp = para.in_args_.zp_; + int32_t out_zp = para.out_args_.zp_; + + int index = 0; +#ifdef ENABLE_NEON + SquareInt8NEON(input, output, element_size, para, &index); +#endif + for (; index < element_size; index++) { + const int32_t input_val = input[index] + in_zp; + int32_t output_tmp = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(input_val * input_val * (1 << para.shift_left_), para.output_multiplier_), + para.shift_right_); + output_tmp += out_zp; + if (output_tmp > para.output_activation_max_) { + output[index] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[index] = para.output_activation_min_; + } else { + output[index] = (int8_t)output_tmp; + } + } + return NNACL_OK; +} + +int Int8ElementLogicalNot(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + float in_scale = para.in_args_.scale_; + int32_t in_zp = para.in_args_.zp_; + float out_scale = para.out_args_.scale_; + int32_t out_zp = para.out_args_.zp_; + float bias = in_zp * in_scale; + for (int i = 0; i < element_size; i++) { + int32_t output_tmp = round(((float)(!(bool)(input[i] * in_scale + bias))) / out_scale) + out_zp; + if (output_tmp > para.output_activation_max_) { + output[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[i] = para.output_activation_min_; + } else { + output[i] = (output_tmp); + } + } + return NNACL_OK; +} + +int Int8ElementReciprocal(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para) { + float in_scale = para.in_args_.scale_; + int32_t in_zp = para.in_args_.zp_; + float out_scale = para.out_args_.scale_; + int32_t out_zp = para.out_args_.zp_; + float bias = in_zp * in_scale; + for (int i = 0; i < element_size; i++) { + float input_f32 = input[i] * in_scale + bias; + if (fabs(input_f32) <= FLT_EPSILON) { + return NNACL_ERR; + } + int32_t output_tmp = round(1.f / (input_f32 * out_scale)) + out_zp; + if (output_tmp > para.output_activation_max_) { + output[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output[i] = para.output_activation_min_; + } else { + output[i] = (int8_t)output_tmp; + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/arithmetic_self_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/arithmetic_self_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..e188097bd5d994d94c4c9305584ea3ee68ab8581 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/arithmetic_self_int8.h @@ -0,0 +1,59 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_ARITHMETIC_SELF_INT8_H_ +#define NNACL_INT8_ARITHMETIC_SELF_INT8_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "nnacl/op_base.h" +#include "nnacl/errorcode.h" +#include "nnacl/int8/quantize.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int Int8ElementRound(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para); + +int Int8ElementFloor(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para); + +int Int8ElementCeil(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para); + +int Int8ElementAbs(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para); + +int Int8ElementSin(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para); + +int Int8ElementCos(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para); + +int Int8ElementLog(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para); + +int Int8ElementSqrt(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para); + +int Int8ElementRsqrt(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para); + +int Int8ElementSquare(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para); + +int Int8ElementLogicalNot(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para); + +int Int8ElementReciprocal(const int8_t *input, int8_t *output, int element_size, ArithSelfQuantArg para); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_ARITHMETIC_SELF_INT8_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/batch_to_space_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/batch_to_space_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..65bb27797b6d8443c5490ae988fb05cc4e56e872 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/batch_to_space_int8.c @@ -0,0 +1,110 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/batch_to_space_int8.h" + +void BatchToSpaceNoCropForNHWCInt8(const int8_t *input, int8_t *output, const int32_t *in_shape, int out_n, + const int32_t *block, const QuantArg *in_quant_arg, const QuantArg *out_quant_arg) { + int block_h = block[0]; + int block_w = block[1]; + int in_h = in_shape[1]; + int in_w = in_shape[2]; + int in_c = in_shape[3]; + int64_t stride_h = block_w * out_n; + int64_t output_offset = 0; + int64_t in_stride_h = in_w * in_c; + int64_t in_stride_n = in_stride_h * in_h; + const float output_inverse_scale = 1.f / out_quant_arg->scale_; + float scale = in_quant_arg->scale_ * output_inverse_scale; + float bias = -in_quant_arg->zp_ * scale; + int32_t output_zp = out_quant_arg->zp_; + + for (int n = 0; n < out_n; ++n) { + for (int h = 0; h < in_h; ++h) { + int64_t h_offset = h * in_stride_h; + for (int bh = 0; bh < block_h; ++bh) { + for (int w = 0; w < in_w; ++w) { + int64_t w_offset = w * in_c; + for (int bw = 0; bw < block_w; ++bw) { + int64_t in_offset = in_stride_n * (bh * stride_h + bw * out_n + n) + w_offset + h_offset; + for (int c = 0; c < in_c; ++c) { + int32_t output_tmp = round(input[in_offset + c] * scale + bias) + output_zp; + output_tmp = output_tmp > 127 ? 127 : output_tmp; + output_tmp = output_tmp < -128 ? -128 : output_tmp; + output[output_offset++] = output_tmp; + } + } + } + } + } + } +} + +void BatchToSpaceForNHWCInt8(const int8_t *input, int8_t *output, const int32_t *in_shape, int out_n, + const int32_t *block, const int32_t *crops, const QuantArg *in_quant_arg, + const QuantArg *out_quant_arg) { + int block_h = block[0]; + int block_w = block[1]; + int in_h = in_shape[1]; + int in_w = in_shape[2]; + int in_c = in_shape[3]; + int h_start = crops[0] / block_h; + int h_valid_begin = crops[0]; + int h_end = MSMIN((in_h * block_h - crops[1]) / block_h + 1, in_h); + int h_valid_end = in_h * block_h - crops[1] - 1; + int w_start = crops[2] / block_w; + int w_valid_begin = crops[2]; + int w_end = MSMIN((in_w * block_w - crops[3]) / block_w + 1, in_w); + int w_valid_end = in_w * block_w - crops[3] - 1; + + int64_t stride_h = block_w * out_n; + int64_t output_offset = 0; + int64_t in_stride_h = in_w * in_c; + int64_t in_stride_n = in_stride_h * in_h; + + const float output_inverse_scale = 1.f / out_quant_arg->scale_; + float scale = in_quant_arg->scale_ * output_inverse_scale; + float bias = -in_quant_arg->zp_ * scale; + int32_t output_zp = out_quant_arg->zp_; + + for (int n = 0; n < out_n; ++n) { + for (int h = h_start; h < h_end; ++h) { + int64_t h_offset = h * in_stride_h; + for (int bh = 0; bh < block_h; ++bh) { + int64_t h_index = h * block_h + bh; + if (h_index < h_valid_begin || h_index > h_valid_end) { + continue; + } + for (int w = w_start; w < w_end; ++w) { + int64_t w_offset = w * in_c; + for (int bw = 0; bw < block_w; ++bw) { + int64_t w_index = w * block_w + bw; + if (w_index < w_valid_begin || w_index > w_valid_end) { + continue; + } + int64_t in_offset = in_stride_n * (bh * stride_h + bw * out_n + n) + w_offset + h_offset; + for (int c = 0; c < in_c; ++c) { + int32_t output_tmp = round(input[in_offset + c] * scale + bias) + output_zp; + output_tmp = output_tmp > 127 ? 127 : output_tmp; + output_tmp = output_tmp < -128 ? -128 : output_tmp; + output[output_offset++] = output_tmp; + } + } + } + } + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/batch_to_space_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/batch_to_space_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..ebc1fdfd10293f8d44793384063ed5dd4a70bcf9 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/batch_to_space_int8.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_INT8_BATCH_TO_SPACE_INT8_H_ +#define NNACL_INT8_BATCH_TO_SPACE_INT8_H_ +#include "nnacl/op_base.h" +#include "nnacl/int8/quantize.h" + +#ifdef __cplusplus +extern "C" { +#endif +void BatchToSpaceNoCropForNHWCInt8(const int8_t *input, int8_t *output, const int32_t *in_shape, int out_n, + const int32_t *block, const QuantArg *in_quant_arg, const QuantArg *out_quant_arg); +void BatchToSpaceForNHWCInt8(const int8_t *input, int8_t *output, const int32_t *in_shape, int out_n, + const int32_t *block, const int32_t *crops, const QuantArg *in_quant_arg, + const QuantArg *out_quant_arg); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_BATCH_TO_SPACE_INT8_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/batchnorm_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/batchnorm_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..cac1e0ef331e78df9eb34a797d817435d69080ef --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/batchnorm_int8.c @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/batchnorm_int8.h" +#include +#include "nnacl/batchnorm_parameter.h" + +void BatchNormInt8(int8_t *output_ptr, const int8_t *input_ptr, const float *alpha_ptr, const float *beta_ptr, + int task_id, int unit, int units, int channel) { + int unit_st = task_id * unit; + int unit_end = MSMIN((task_id + 1) * unit, units); + for (int u = unit_st; u < unit_end; u++) { + for (int c = 0; c < channel; c++) { + int32_t output_tmp = round(input_ptr[u * channel + c] * alpha_ptr[c] + beta_ptr[c]); + output_tmp = output_tmp > 127 ? 127 : output_tmp; + output_tmp = output_tmp < -128 ? -128 : output_tmp; + output_ptr[u * channel + c] = (int8_t)output_tmp; + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/batchnorm_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/batchnorm_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..7ab797738cce378dd9f0f15c0d8bbc38c75bf807 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/batchnorm_int8.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_BATCHNORM_H_ +#define NNACL_INT8_BATCHNORM_H_ + +#include "nnacl/op_base.h" +#include "nnacl/batchnorm_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void BatchNormInt8(int8_t *output_ptr, const int8_t *input_ptr, const float *alpha_ptr, const float *beta_ptr, + int task_id, int unit, int units, int channel); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_BATCHNORM_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/common_func_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/common_func_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..ce318f35f991a053e821935c1382669c6f83f8fe --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/common_func_int8.c @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/common_func_int8.h" +#include "nnacl/int8/fixed_point.h" + +void PostConvFuncCommInt8(const int32_t *in, int8_t *out, const int32_t *bias, size_t oc, size_t plane, + size_t out_oc_stride, size_t in_plane_stride, int32_t multiplier, int32_t mini, int32_t maxi, + int32_t left_shift, int32_t right_shift, int32_t zp, int size) { + if (size == 0) { + return; + } + for (int r = 0; r < plane; r++) { + for (int c = 0; c < oc; c++) { + int c4div = c / size, c4mod = c % size; + int src_index = c4div * in_plane_stride + r * size + c4mod; + int dst_index = r * out_oc_stride + c; + int32_t value = in[src_index]; + if (bias != NULL) { + value = in[src_index] + bias[c]; + } + value = MultiplyByQuantizedMultiplier(value, multiplier, left_shift, right_shift) + zp; + value = MSMIN(maxi, value); + value = MSMAX(mini, value); + out[dst_index] = (int8_t)value; + } + } + return; +} + +void PostFuncInt8C4(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc, size_t plane, size_t stride, + int32_t multiplier, int32_t left_shift, int32_t right_shift, int32_t zp, int32_t mini, + int32_t maxi) { +/* ((int32_t)row4x4-major + bias) * multiplier + output_zp => (int8)relu => (int8_t)row-major */ +#ifndef ENABLE_ARM64 + PostConvFuncCommInt8(in, out, bias, oc, plane, stride, UP_ROUND(plane, C4NUM) * C4NUM, multiplier, mini, maxi, + left_shift, right_shift, zp, C4NUM); +#else + size_t oc4div = oc / C4NUM * C4NUM; + size_t oc4res = oc % C4NUM; + PostFuncInt8C4Neon64(in, bias, out, oc4div, oc4res, plane, stride * sizeof(int8_t), multiplier, left_shift, + right_shift, zp, mini, maxi); +#endif + return; +} + +#ifdef ENABLE_ARM +int16x8_t LoadAndAddOffset(const int8_t *data, int index, int offset) { + int8x8_t input_s8 = vld1_s8(data + index); + int16x8_t input_s16 = vmovl_s8(input_s8); + return vaddq_s16(input_s16, vdupq_n_s16(offset)); +} + +int32x4_t ClacScaledInput(int32x4_t input, int32x4_t left_shift_result_vec, int32x4_t input_multiplier_vec, + int32x4_t right_shift_vec) { + int32x4_t shifted_input = vmulq_s32(input, left_shift_result_vec); + shifted_input = vqrdmulhq_s32(shifted_input, input_multiplier_vec); + const int32x4_t fixup = vshrq_n_s32(vandq_s32(shifted_input, right_shift_vec), 31); + return vrshlq_s32(vqaddq_s32(shifted_input, fixup), right_shift_vec); +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/common_func_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/common_func_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..870cfea4464e9411f9d6919bac59c97cb3f657ae --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/common_func_int8.h @@ -0,0 +1,95 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_COMMON_FUNC_H_ +#define NNACL_INT8_COMMON_FUNC_H_ + +#include +#ifdef ENABLE_NEON +#include +#endif +#include "nnacl/op_base.h" +#include "nnacl/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void PostFuncInt8C4(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc, size_t plane, size_t stride, + int32_t multiplier, int32_t left_shift, int32_t right_shift, int32_t zp, int32_t mini, + int32_t maxi); +#ifdef ENABLE_ARM +void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels, + int output_channel, int input_step, int8_t input_zp); +void ConvDwInt8PostAlign4PerChannel(int8_t *dst, int32_t *buffer, int channel4, int32_t output_zp, + const int32_t *out_multiplier, const int32_t *left_shift, + const int32_t *right_shift, int32_t acc_min, int32_t acc_max); +void ConvDwInt8PostAlign4(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier, + int32_t left_shift, int32_t right_shift, int32_t acc_min, int32_t acc_max); +void IndirectGemmInt16to32_8x4(int32_t *dst, const int16_t *src, const int16_t *weight, size_t ksize, size_t ic8, + size_t oc4, size_t offset); +void ConvDwInt8Center(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t height, + size_t width, size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, + size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, const int8_t *in_zp, + const int32_t *out_zp, const int32_t *out_multiplier, const int32_t *left_shift, + const int32_t *right_shift, const int32_t *acc_min, const int32_t *acc_max); +void DeconvDwInt8Center(int32_t *dst, const int16_t *src, const int16_t *weight, size_t height, size_t width, + size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, + size_t in_sw_step, size_t in_kh_step, size_t in_kw_step); +void DeconvDwInt8Post(int8_t *dst, int32_t *output_buffer, const int32_t *bias, int block_channel, int pixel_nums, + int out_multiplier, int left_shift, int right_shift, int32_t out_zp, int32_t acc_min, + int32_t acc_max); +int16x8_t LoadAndAddOffset(const int8_t *data, int index, int offset); +int32x4_t ClacScaledInput(int32x4_t input, int32x4_t left_shift_result_vec, int32x4_t input_multiplier_vec, + int32x4_t right_shift_vec); +#endif + +#ifdef ENABLE_ARM32 +void ConvDw3x3Int8BorderPixel(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height, + int width, int in_kh_step, int in_kw_step, int channel, int8_t in_zp, int32_t out_zp, + const int32_t *out_multiplier, const int32_t *left_shift, const int32_t *right_shift, + int32_t acc_min, int32_t acc_max, size_t per_channel); +#endif + +#ifdef ENABLE_ARM64 +void PostFuncInt8C4Neon64(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc4div, size_t oc4res, + size_t plane, size_t stride, int32_t multiplier, int32_t left_shift, int32_t right_shift, + int32_t zp, int32_t mini, int32_t maxi); +void ConvDw3x3Int8Neon64(int8_t *output, const int8_t *input, const int16_t *weight, const int32_t *bias, + int input_col_size, int input_row_size, int channel, int output_h, int output_w, int8_t in_zp, + int32_t out_zp, const int32_t *out_multiplier, const int32_t *left_shift, + const int32_t *right_shift, int32_t acc_min, int32_t acc_max, size_t per_channel); +void ConvDw3x3Int8Stride2(int8_t *output, const int8_t *input, const int16_t *weight, const int32_t *bias, + int input_col_size, int input_row_size, int channel, int output_h, int output_w, int8_t in_zp, + int32_t out_zp, const int32_t *out_multiplier, const int32_t *left_shift, + const int32_t *right_shift, int32_t acc_min, int32_t acc_max, size_t per_channel); +void ConvDw3x3Int8Corner(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t in_kh_step, + size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, const int32_t *out_multiplier, + const int32_t *left_shift, const int32_t *right_shift, size_t acc_min, size_t acc_max, + size_t per_channel); +void ConvDw3x3Int8Vertical(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, + size_t in_kh_step, size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, + const int32_t *out_multiplier, const int32_t *left_shift, const int32_t *right_shift, + size_t acc_min, size_t acc_max, size_t per_channel); +void ConvDw3x3Int8Horizontal(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, + size_t in_kh_step, size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, + const int32_t *out_multiplier, const int32_t *left_shift, const int32_t *right_shift, + size_t acc_min, size_t acc_max, size_t per_channel); +#endif +#ifdef __cplusplus +} +#endif + +#endif /* NNACL_FP32_COMMON_FUNC_H_ */ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/concat_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/concat_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..2bd02f8d00c3399479f55d01ab9ee6fc48936b64 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/concat_int8.c @@ -0,0 +1,57 @@ +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/concat_int8.h" +#include +#include +#include +#include "nnacl/concat_parameter.h" + +void Int8Concat(int8_t **inputs, int8_t *output, const ConcatParameter *para, int axis, int64_t real_dst_count, + int task_id, int input_num, int64_t count_unit, int64_t after_axis_size, int **input_shapes, + const int32_t *output_shape) { + float output_scale = para->quant_arg_.out_args_.scale_; + const float output_inverse_scale = 1.f / output_scale; + int out_copy_size = output_shape[axis] * after_axis_size; + QuantArg *input_quant = para->quant_arg_.in_args_; + int output_zp = para->quant_arg_.out_args_.zp_; + int8_t max_int8 = para->quant_arg_.output_activation_max_; + int8_t min_int8 = para->quant_arg_.output_activation_min_; + int64_t start = task_id * count_unit; + int64_t end = start + real_dst_count; + output += start * out_copy_size; + + for (int k = start; k < end; k++) { + for (int i = 0; i < input_num; i++) { + const int32_t *input_shape = input_shapes[i]; + int64_t in_copy_size = input_shape[axis] * after_axis_size; + const int8_t *input_ptr = inputs[i] + k * in_copy_size; + if (fabs(input_quant[i].scale_ - output_scale) <= FLT_EPSILON && input_quant[i].zp_ == output_zp) { + memcpy(output, input_ptr, in_copy_size); + } else { + float scale = input_quant[i].scale_ * output_inverse_scale; + float bias = -input_quant[i].zp_ * scale; + for (int j = 0; j < in_copy_size; j++) { + int32_t output_tmp = round(input_ptr[j] * scale + bias) + output_zp; + output_tmp = output_tmp > min_int8 ? output_tmp : min_int8; + output_tmp = output_tmp < max_int8 ? output_tmp : max_int8; + output[j] = (int8_t)output_tmp; + } + } + output += in_copy_size; + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/concat_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/concat_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..3d3d1e6ed9cd52c8125ab9efe5bec1420e363652 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/concat_int8.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_CONCAT_INT8_H_ +#define NNACL_INT8_CONCAT_INT8_H_ + +#include "nnacl/op_base.h" +#include "nnacl/concat_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void Int8Concat(int8_t **inputs, int8_t *output, const ConcatParameter *para, int axis, int64_t real_dst_count, + int task_id, int input_num, int64_t count_unit, int64_t after_axis_size, int **input_shapes, + const int32_t *output_shape); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_CONCAT_INT8_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/conv1x1_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/conv1x1_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..8822fb6221052f7ad8fe0b90672ecb7774fc58b4 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/conv1x1_int8.c @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/conv1x1_int8.h" + +void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, + const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift, + int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_DP_FUNC matmul_func, + const int32_t *filter_zp) { + int is_per_oc = (int)conv_param->conv_quant_arg_.filter_arg_num_ != 1; + matmul_func(packed_input, packed_weight, dst, row, col, deep4, conv_param->output_channel_, input_sum, bias, + left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], is_per_oc, + filter_zp); + return; +} + +void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, + const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift, + int32_t *multiplier, ConvParameter *conv_param, const int32_t *filter_zp) { + int is_per_oc = (int)conv_param->conv_quant_arg_.filter_arg_num_ != 1; + MatmulInt8Opt(packed_input, packed_weight, dst, row, col, deep16, input_sum, bias, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], + conv_param->conv_quant_arg_.output_quant_args_[0].zp_, multiplier, left_shift, right_shift, + conv_param->output_channel_, is_per_oc, filter_zp); + return; +} diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/ms/ms_model.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/conv1x1_int8.h similarity index 35% rename from mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/ms/ms_model.h rename to mindspore-lite/ops/kernel/cpu/nnacl/int8/conv1x1_int8.h index ab47e7e4e455ba0a8c17ee3d3620893064576937..5066c684bab7b2ef64632d51ce2a9d8780696bbc 100644 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/ms/ms_model.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/conv1x1_int8.h @@ -13,40 +13,34 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_SESSION_SESSION_H -#define MINDSPORE_CCSRC_SESSION_SESSION_H - -#include -#include -#include -#include -#include - -#include "backend/common/session/session_basic.h" -#include "ir/anf.h" -#include "include/api/status.h" -#include "cxx_api/model/model_impl.h" - -namespace mindspore { -class MsModel : public ModelImpl { - public: - MsModel() {} - ~MsModel() = default; - - Status Build() override; - Status Resize(const std::vector &inputs, const std::vector> &dims) override; - - std::vector GetInputs() override; - std::vector GetOutputs() override; - - bool CheckDeviceSupport(mindspore::DeviceType device_type) override; - bool CheckModelSupport(enum ModelType model_type) override; - - private: - std::shared_ptr GenerateGraphCell(const std::vector> &dims); - uint32_t GetDeviceID() const; - - std::map> dynamic_size_graph_map_; -}; -} // namespace mindspore -#endif // MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H +#ifndef NNACL_INT8_CONV1X1_INT8_H_ +#define NNACL_INT8_CONV1X1_INT8_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "nnacl/pack.h" +#include "nnacl/op_base.h" +#include "nnacl/common_func.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/int8/quantize.h" +#include "nnacl/matmul_parameter.h" +#include "nnacl/int8/matmul_int8.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, + const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift, + int32_t *multiplier, ConvParameter *conv_param, const int32_t *filter_zp); +void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, + const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift, + int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_DP_FUNC matmul_func, + const int32_t *filter_zp); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_CONV1X1_INT8_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/conv3x3_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/conv3x3_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..7004f1c6ca9035bb1a90479a899cfd23acf88dea --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/conv3x3_int8.c @@ -0,0 +1,902 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/conv3x3_int8.h" + +void Conv3x3Int8InputUnit(const int16_t *tmp_data, int16_t *trans_input_data, size_t step, int input_zp) { +#ifdef ENABLE_ARM + int16x8_t zp = vdupq_n_s16(input_zp); + + int16x8_t d00 = vsubq_s16(vld1q_s16(tmp_data), zp); + int16x8_t d01 = vsubq_s16(vld1q_s16(tmp_data + 8), zp); + int16x8_t d02 = vsubq_s16(vld1q_s16(tmp_data + 2 * 8), zp); + int16x8_t d03 = vsubq_s16(vld1q_s16(tmp_data + 3 * 8), zp); + + int16x8_t d10 = vsubq_s16(vld1q_s16(tmp_data + 4 * 8), zp); + int16x8_t d11 = vsubq_s16(vld1q_s16(tmp_data + 5 * 8), zp); + int16x8_t d12 = vsubq_s16(vld1q_s16(tmp_data + 6 * 8), zp); + int16x8_t d13 = vsubq_s16(vld1q_s16(tmp_data + 7 * 8), zp); + + int16x8_t d20 = vsubq_s16(vld1q_s16(tmp_data + 8 * 8), zp); + int16x8_t d21 = vsubq_s16(vld1q_s16(tmp_data + 9 * 8), zp); + int16x8_t d22 = vsubq_s16(vld1q_s16(tmp_data + 10 * 8), zp); + int16x8_t d23 = vsubq_s16(vld1q_s16(tmp_data + 11 * 8), zp); + + int16x8_t d30 = vsubq_s16(vld1q_s16(tmp_data + 12 * 8), zp); + int16x8_t d31 = vsubq_s16(vld1q_s16(tmp_data + 13 * 8), zp); + int16x8_t d32 = vsubq_s16(vld1q_s16(tmp_data + 14 * 8), zp); + int16x8_t d33 = vsubq_s16(vld1q_s16(tmp_data + 15 * 8), zp); + + int16x8_t t00 = vsubq_s16(d00, d20); + int16x8_t t01 = vsubq_s16(d01, d21); + int16x8_t t02 = vsubq_s16(d02, d22); + int16x8_t t03 = vsubq_s16(d03, d23); + + int16x8_t t10 = vaddq_s16(d10, d20); + int16x8_t t11 = vaddq_s16(d11, d21); + int16x8_t t12 = vaddq_s16(d12, d22); + int16x8_t t13 = vaddq_s16(d13, d23); + + int16x8_t t20 = vsubq_s16(d20, d10); + int16x8_t t21 = vsubq_s16(d21, d11); + int16x8_t t22 = vsubq_s16(d22, d12); + int16x8_t t23 = vsubq_s16(d23, d13); + + int16x8_t t30 = vsubq_s16(d10, d30); + int16x8_t t31 = vsubq_s16(d11, d31); + int16x8_t t32 = vsubq_s16(d12, d32); + int16x8_t t33 = vsubq_s16(d13, d33); + + int16x8_t m00 = vsubq_s16(t00, t02); + int16x8_t m01 = vaddq_s16(t01, t02); + int16x8_t m02 = vsubq_s16(t02, t01); + int16x8_t m03 = vsubq_s16(t01, t03); + + int16x8_t m10 = vsubq_s16(t10, t12); + int16x8_t m11 = vaddq_s16(t11, t12); + int16x8_t m12 = vsubq_s16(t12, t11); + int16x8_t m13 = vsubq_s16(t11, t13); + + int16x8_t m20 = vsubq_s16(t20, t22); + int16x8_t m21 = vaddq_s16(t21, t22); + int16x8_t m22 = vsubq_s16(t22, t21); + int16x8_t m23 = vsubq_s16(t21, t23); + + int16x8_t m30 = vsubq_s16(t30, t32); + int16x8_t m31 = vaddq_s16(t31, t32); + int16x8_t m32 = vsubq_s16(t32, t31); + int16x8_t m33 = vsubq_s16(t31, t33); + + vst1q_s16(trans_input_data, m00); + vst1q_s16(trans_input_data + step, m01); + vst1q_s16(trans_input_data + 2 * step, m02); + vst1q_s16(trans_input_data + 3 * step, m03); + + vst1q_s16(trans_input_data + 4 * step, m10); + vst1q_s16(trans_input_data + 5 * step, m11); + vst1q_s16(trans_input_data + 6 * step, m12); + vst1q_s16(trans_input_data + 7 * step, m13); + + vst1q_s16(trans_input_data + 8 * step, m20); + vst1q_s16(trans_input_data + 9 * step, m21); + vst1q_s16(trans_input_data + 10 * step, m22); + vst1q_s16(trans_input_data + 11 * step, m23); + + vst1q_s16(trans_input_data + 12 * step, m30); + vst1q_s16(trans_input_data + 13 * step, m31); + vst1q_s16(trans_input_data + 14 * step, m32); + vst1q_s16(trans_input_data + 15 * step, m33); +#else + for (int i = 0; i < C8NUM; i++) { + const int16_t *local_ptr = tmp_data + i; + int16_t d00 = local_ptr[0] - input_zp; + int16_t d01 = (local_ptr + C8NUM)[0] - input_zp; + int16_t d02 = (local_ptr + 2 * C8NUM)[0] - input_zp; + int16_t d03 = (local_ptr + 3 * C8NUM)[0] - input_zp; + + int16_t d10 = (local_ptr + 4 * C8NUM)[0] - input_zp; + int16_t d11 = (local_ptr + 5 * C8NUM)[0] - input_zp; + int16_t d12 = (local_ptr + 6 * C8NUM)[0] - input_zp; + int16_t d13 = (local_ptr + 7 * C8NUM)[0] - input_zp; + + int16_t d20 = (local_ptr + 8 * C8NUM)[0] - input_zp; + int16_t d21 = (local_ptr + 9 * C8NUM)[0] - input_zp; + int16_t d22 = (local_ptr + 10 * C8NUM)[0] - input_zp; + int16_t d23 = (local_ptr + 11 * C8NUM)[0] - input_zp; + + int16_t d30 = (local_ptr + 12 * C8NUM)[0] - input_zp; + int16_t d31 = (local_ptr + 13 * C8NUM)[0] - input_zp; + int16_t d32 = (local_ptr + 14 * C8NUM)[0] - input_zp; + int16_t d33 = (local_ptr + 15 * C8NUM)[0] - input_zp; + + int16_t t00 = d00 - d20; + int16_t t01 = d01 - d21; + int16_t t02 = d02 - d22; + int16_t t03 = d03 - d23; + + int16_t t10 = d10 + d20; + int16_t t11 = d11 + d21; + int16_t t12 = d12 + d22; + int16_t t13 = d13 + d23; + + int16_t t20 = d20 - d10; + int16_t t21 = d21 - d11; + int16_t t22 = d22 - d12; + int16_t t23 = d23 - d13; + + int16_t t30 = d10 - d30; + int16_t t31 = d11 - d31; + int16_t t32 = d12 - d32; + int16_t t33 = d13 - d33; + + int16_t m00 = t00 - t02; + int16_t m01 = t01 + t02; + int16_t m02 = t02 - t01; + int16_t m03 = t01 - t03; + + int16_t m10 = t10 - t12; + int16_t m11 = t11 + t12; + int16_t m12 = t12 - t11; + int16_t m13 = t11 - t13; + + int16_t m20 = t20 - t22; + int16_t m21 = t21 + t22; + int16_t m22 = t22 - t21; + int16_t m23 = t21 - t23; + + int16_t m30 = t30 - t32; + int16_t m31 = t31 + t32; + int16_t m32 = t32 - t31; + int16_t m33 = t31 - t33; + + (trans_input_data + i)[0] = m00; + (trans_input_data + i + step)[0] = m01; + (trans_input_data + i + 2 * step)[0] = m02; + (trans_input_data + i + 3 * step)[0] = m03; + + (trans_input_data + i + 4 * step)[0] = m10; + (trans_input_data + i + 5 * step)[0] = m11; + (trans_input_data + i + 6 * step)[0] = m12; + (trans_input_data + i + 7 * step)[0] = m13; + + (trans_input_data + i + 8 * step)[0] = m20; + (trans_input_data + i + 9 * step)[0] = m21; + (trans_input_data + i + 10 * step)[0] = m22; + (trans_input_data + i + 11 * step)[0] = m23; + + (trans_input_data + i + 12 * step)[0] = m30; + (trans_input_data + i + 13 * step)[0] = m31; + (trans_input_data + i + 14 * step)[0] = m32; + (trans_input_data + i + 15 * step)[0] = m33; + } +#endif +} + +void Conv3x3Int8FilterTransform(const int16_t *weight_data, int16_t *trans_weight, int iC8, int output_channel, + int kernel_plane) { + const int input_unit = 4; + int dst_step = iC8 * C8NUM * C4NUM; + for (int o = 0; o < output_channel; o++) { + int oc4_block_num = o / C4NUM; + int oc4_block_rem = o % C4NUM; + int src_oc_offset = o * iC8 * C8NUM * kernel_plane; + int dst_oc_offset = oc4_block_num * C4NUM * iC8 * C8NUM * input_unit * input_unit + oc4_block_rem; + for (int i = 0; i < iC8; i++) { + const int16_t *src_ic8_ptr = weight_data + src_oc_offset + i * kernel_plane * C8NUM; + int16_t *dst_ic8_ptr = trans_weight + dst_oc_offset + i * C4NUM * C8NUM; +#ifdef ENABLE_ARM + int16x8_t g00 = vld1q_s16(src_ic8_ptr); + int16x8_t g01 = vld1q_s16(src_ic8_ptr + 8); + int16x8_t g02 = vld1q_s16(src_ic8_ptr + 2 * 8); + int16x8_t g10 = vld1q_s16(src_ic8_ptr + 3 * 8); + int16x8_t g11 = vld1q_s16(src_ic8_ptr + 4 * 8); + int16x8_t g12 = vld1q_s16(src_ic8_ptr + 5 * 8); + int16x8_t g20 = vld1q_s16(src_ic8_ptr + 6 * 8); + int16x8_t g21 = vld1q_s16(src_ic8_ptr + 7 * 8); + int16x8_t g22 = vld1q_s16(src_ic8_ptr + 8 * 8); + + int16x8_t dst00 = vmulq_n_s16(g00, 2); + int16x8_t dst01 = vmulq_n_s16(g01, 2); + int16x8_t dst02 = vmulq_n_s16(g02, 2); + + int16x8_t dst10 = vaddq_s16(vaddq_s16(g00, g10), g20); + int16x8_t dst11 = vaddq_s16(vaddq_s16(g01, g11), g21); + int16x8_t dst12 = vaddq_s16(vaddq_s16(g02, g12), g22); + + int16x8_t dst20 = vaddq_s16(vsubq_s16(g00, g10), g20); + int16x8_t dst21 = vaddq_s16(vsubq_s16(g01, g11), g21); + int16x8_t dst22 = vaddq_s16(vsubq_s16(g02, g12), g22); + + int16x8_t dst30 = vmulq_n_s16(g20, 2); + int16x8_t dst31 = vmulq_n_s16(g21, 2); + int16x8_t dst32 = vmulq_n_s16(g22, 2); + + int16x8_t m00 = vmulq_n_s16(dst00, 2); + int16x8_t m01 = vaddq_s16(vaddq_s16(dst00, dst01), dst02); + int16x8_t m02 = vaddq_s16(vsubq_s16(dst00, dst01), dst02); + int16x8_t m03 = vmulq_n_s16(dst02, 2); + + int16x8_t m10 = vmulq_n_s16(dst10, 2); + int16x8_t m11 = vaddq_s16(vaddq_s16(dst10, dst11), dst12); + int16x8_t m12 = vaddq_s16(vsubq_s16(dst10, dst11), dst12); + int16x8_t m13 = vmulq_n_s16(dst12, 2); + + int16x8_t m20 = vmulq_n_s16(dst20, 2); + int16x8_t m21 = vaddq_s16(vaddq_s16(dst20, dst21), dst22); + int16x8_t m22 = vaddq_s16(vsubq_s16(dst20, dst21), dst22); + int16x8_t m23 = vmulq_n_s16(dst22, 2); + + int16x8_t m30 = vmulq_n_s16(dst30, 2); + int16x8_t m31 = vaddq_s16(vaddq_s16(dst30, dst31), dst32); + int16x8_t m32 = vaddq_s16(vsubq_s16(dst30, dst31), dst32); + int16x8_t m33 = vmulq_n_s16(dst32, 2); + + dst_ic8_ptr[0] = m00[0]; + dst_ic8_ptr[4] = m00[1]; + dst_ic8_ptr[8] = m00[2]; + dst_ic8_ptr[12] = m00[3]; + dst_ic8_ptr[16] = m00[4]; + dst_ic8_ptr[20] = m00[5]; + dst_ic8_ptr[24] = m00[6]; + dst_ic8_ptr[28] = m00[7]; + + dst_ic8_ptr[0 + dst_step] = m01[0]; + dst_ic8_ptr[4 + dst_step] = m01[1]; + dst_ic8_ptr[8 + dst_step] = m01[2]; + dst_ic8_ptr[12 + dst_step] = m01[3]; + dst_ic8_ptr[16 + dst_step] = m01[4]; + dst_ic8_ptr[20 + dst_step] = m01[5]; + dst_ic8_ptr[24 + dst_step] = m01[6]; + dst_ic8_ptr[28 + dst_step] = m01[7]; + + dst_ic8_ptr[0 + 2 * dst_step] = m02[0]; + dst_ic8_ptr[4 + 2 * dst_step] = m02[1]; + dst_ic8_ptr[8 + 2 * dst_step] = m02[2]; + dst_ic8_ptr[12 + 2 * dst_step] = m02[3]; + dst_ic8_ptr[16 + 2 * dst_step] = m02[4]; + dst_ic8_ptr[20 + 2 * dst_step] = m02[5]; + dst_ic8_ptr[24 + 2 * dst_step] = m02[6]; + dst_ic8_ptr[28 + 2 * dst_step] = m02[7]; + + dst_ic8_ptr[0 + 3 * dst_step] = m03[0]; + dst_ic8_ptr[4 + 3 * dst_step] = m03[1]; + dst_ic8_ptr[8 + 3 * dst_step] = m03[2]; + dst_ic8_ptr[12 + 3 * dst_step] = m03[3]; + dst_ic8_ptr[16 + 3 * dst_step] = m03[4]; + dst_ic8_ptr[20 + 3 * dst_step] = m03[5]; + dst_ic8_ptr[24 + 3 * dst_step] = m03[6]; + dst_ic8_ptr[28 + 3 * dst_step] = m03[7]; + + dst_ic8_ptr[0 + 4 * dst_step] = m10[0]; + dst_ic8_ptr[4 + 4 * dst_step] = m10[1]; + dst_ic8_ptr[8 + 4 * dst_step] = m10[2]; + dst_ic8_ptr[12 + 4 * dst_step] = m10[3]; + dst_ic8_ptr[16 + 4 * dst_step] = m10[4]; + dst_ic8_ptr[20 + 4 * dst_step] = m10[5]; + dst_ic8_ptr[24 + 4 * dst_step] = m10[6]; + dst_ic8_ptr[28 + 4 * dst_step] = m10[7]; + + dst_ic8_ptr[0 + 5 * dst_step] = m11[0]; + dst_ic8_ptr[4 + 5 * dst_step] = m11[1]; + dst_ic8_ptr[8 + 5 * dst_step] = m11[2]; + dst_ic8_ptr[12 + 5 * dst_step] = m11[3]; + dst_ic8_ptr[16 + 5 * dst_step] = m11[4]; + dst_ic8_ptr[20 + 5 * dst_step] = m11[5]; + dst_ic8_ptr[24 + 5 * dst_step] = m11[6]; + dst_ic8_ptr[28 + 5 * dst_step] = m11[7]; + + dst_ic8_ptr[0 + 6 * dst_step] = m12[0]; + dst_ic8_ptr[4 + 6 * dst_step] = m12[1]; + dst_ic8_ptr[8 + 6 * dst_step] = m12[2]; + dst_ic8_ptr[12 + 6 * dst_step] = m12[3]; + dst_ic8_ptr[16 + 6 * dst_step] = m12[4]; + dst_ic8_ptr[20 + 6 * dst_step] = m12[5]; + dst_ic8_ptr[24 + 6 * dst_step] = m12[6]; + dst_ic8_ptr[28 + 6 * dst_step] = m12[7]; + + dst_ic8_ptr[0 + 7 * dst_step] = m13[0]; + dst_ic8_ptr[4 + 7 * dst_step] = m13[1]; + dst_ic8_ptr[8 + 7 * dst_step] = m13[2]; + dst_ic8_ptr[12 + 7 * dst_step] = m13[3]; + dst_ic8_ptr[16 + 7 * dst_step] = m13[4]; + dst_ic8_ptr[20 + 7 * dst_step] = m13[5]; + dst_ic8_ptr[24 + 7 * dst_step] = m13[6]; + dst_ic8_ptr[28 + 7 * dst_step] = m13[7]; + + dst_ic8_ptr[0 + 8 * dst_step] = m20[0]; + dst_ic8_ptr[4 + 8 * dst_step] = m20[1]; + dst_ic8_ptr[8 + 8 * dst_step] = m20[2]; + dst_ic8_ptr[12 + 8 * dst_step] = m20[3]; + dst_ic8_ptr[16 + 8 * dst_step] = m20[4]; + dst_ic8_ptr[20 + 8 * dst_step] = m20[5]; + dst_ic8_ptr[24 + 8 * dst_step] = m20[6]; + dst_ic8_ptr[28 + 8 * dst_step] = m20[7]; + + dst_ic8_ptr[0 + 9 * dst_step] = m21[0]; + dst_ic8_ptr[4 + 9 * dst_step] = m21[1]; + dst_ic8_ptr[8 + 9 * dst_step] = m21[2]; + dst_ic8_ptr[12 + 9 * dst_step] = m21[3]; + dst_ic8_ptr[16 + 9 * dst_step] = m21[4]; + dst_ic8_ptr[20 + 9 * dst_step] = m21[5]; + dst_ic8_ptr[24 + 9 * dst_step] = m21[6]; + dst_ic8_ptr[28 + 9 * dst_step] = m21[7]; + + dst_ic8_ptr[0 + 10 * dst_step] = m22[0]; + dst_ic8_ptr[4 + 10 * dst_step] = m22[1]; + dst_ic8_ptr[8 + 10 * dst_step] = m22[2]; + dst_ic8_ptr[12 + 10 * dst_step] = m22[3]; + dst_ic8_ptr[16 + 10 * dst_step] = m22[4]; + dst_ic8_ptr[20 + 10 * dst_step] = m22[5]; + dst_ic8_ptr[24 + 10 * dst_step] = m22[6]; + dst_ic8_ptr[28 + 10 * dst_step] = m22[7]; + + dst_ic8_ptr[0 + 11 * dst_step] = m23[0]; + dst_ic8_ptr[4 + 11 * dst_step] = m23[1]; + dst_ic8_ptr[8 + 11 * dst_step] = m23[2]; + dst_ic8_ptr[12 + 11 * dst_step] = m23[3]; + dst_ic8_ptr[16 + 11 * dst_step] = m23[4]; + dst_ic8_ptr[20 + 11 * dst_step] = m23[5]; + dst_ic8_ptr[24 + 11 * dst_step] = m23[6]; + dst_ic8_ptr[28 + 11 * dst_step] = m23[7]; + + dst_ic8_ptr[0 + 12 * dst_step] = m30[0]; + dst_ic8_ptr[4 + 12 * dst_step] = m30[1]; + dst_ic8_ptr[8 + 12 * dst_step] = m30[2]; + dst_ic8_ptr[12 + 12 * dst_step] = m30[3]; + dst_ic8_ptr[16 + 12 * dst_step] = m30[4]; + dst_ic8_ptr[20 + 12 * dst_step] = m30[5]; + dst_ic8_ptr[24 + 12 * dst_step] = m30[6]; + dst_ic8_ptr[28 + 12 * dst_step] = m30[7]; + + dst_ic8_ptr[0 + 13 * dst_step] = m31[0]; + dst_ic8_ptr[4 + 13 * dst_step] = m31[1]; + dst_ic8_ptr[8 + 13 * dst_step] = m31[2]; + dst_ic8_ptr[12 + 13 * dst_step] = m31[3]; + dst_ic8_ptr[16 + 13 * dst_step] = m31[4]; + dst_ic8_ptr[20 + 13 * dst_step] = m31[5]; + dst_ic8_ptr[24 + 13 * dst_step] = m31[6]; + dst_ic8_ptr[28 + 13 * dst_step] = m31[7]; + + dst_ic8_ptr[0 + 14 * dst_step] = m32[0]; + dst_ic8_ptr[4 + 14 * dst_step] = m32[1]; + dst_ic8_ptr[8 + 14 * dst_step] = m32[2]; + dst_ic8_ptr[12 + 14 * dst_step] = m32[3]; + dst_ic8_ptr[16 + 14 * dst_step] = m32[4]; + dst_ic8_ptr[20 + 14 * dst_step] = m32[5]; + dst_ic8_ptr[24 + 14 * dst_step] = m32[6]; + dst_ic8_ptr[28 + 14 * dst_step] = m32[7]; + + dst_ic8_ptr[0 + 15 * dst_step] = m33[0]; + dst_ic8_ptr[4 + 15 * dst_step] = m33[1]; + dst_ic8_ptr[8 + 15 * dst_step] = m33[2]; + dst_ic8_ptr[12 + 15 * dst_step] = m33[3]; + dst_ic8_ptr[16 + 15 * dst_step] = m33[4]; + dst_ic8_ptr[20 + 15 * dst_step] = m33[5]; + dst_ic8_ptr[24 + 15 * dst_step] = m33[6]; + dst_ic8_ptr[28 + 15 * dst_step] = m33[7]; +#else + for (int j = 0; j < C8NUM; j++) { + const int16_t *local_ptr = src_ic8_ptr + j; + int16_t dst00 = local_ptr[0] * 2; + int16_t dst01 = (local_ptr + 8)[0] * 2; + int16_t dst02 = (local_ptr + 16)[0] * 2; + + int16_t dst10 = local_ptr[0] + (local_ptr + 24)[0] + (local_ptr + 48)[0]; + int16_t dst11 = (local_ptr + 8)[0] + (local_ptr + 32)[0] + (local_ptr + 56)[0]; + int16_t dst12 = (local_ptr + 16)[0] + (local_ptr + 40)[0] + (local_ptr + 64)[0]; + + int16_t dst20 = local_ptr[0] - (local_ptr + 24)[0] + (local_ptr + 48)[0]; + int16_t dst21 = (local_ptr + 8)[0] - (local_ptr + 32)[0] + (local_ptr + 56)[0]; + int16_t dst22 = (local_ptr + 16)[0] - (local_ptr + 40)[0] + (local_ptr + 64)[0]; + + int16_t dst30 = (local_ptr + 48)[0] * 2; + int16_t dst31 = (local_ptr + 56)[0] * 2; + int16_t dst32 = (local_ptr + 64)[0] * 2; + + int16_t m00 = dst00 * 2; + int16_t m01 = dst00 + dst01 + dst02; + int16_t m02 = dst00 - dst01 + dst02; + int16_t m03 = dst02 * 2; + + int16_t m10 = dst10 * 2; + int16_t m11 = dst10 + dst11 + dst12; + int16_t m12 = dst10 - dst11 + dst12; + int16_t m13 = dst12 * 2; + + int16_t m20 = dst20 * 2; + int16_t m21 = dst20 + dst21 + dst22; + int16_t m22 = dst20 - dst21 + dst22; + int16_t m23 = dst22 * 2; + + int16_t m30 = dst30 * 2; + int16_t m31 = dst30 + dst31 + dst32; + int16_t m32 = dst30 - dst31 + dst32; + int16_t m33 = dst32 * 2; + + *(dst_ic8_ptr + j * 4) = m00; + *(dst_ic8_ptr + j * 4 + dst_step) = m01; + *(dst_ic8_ptr + j * 4 + 2 * dst_step) = m02; + *(dst_ic8_ptr + j * 4 + 3 * dst_step) = m03; + + *(dst_ic8_ptr + j * 4 + 4 * dst_step) = m10; + *(dst_ic8_ptr + j * 4 + 5 * dst_step) = m11; + *(dst_ic8_ptr + j * 4 + 6 * dst_step) = m12; + *(dst_ic8_ptr + j * 4 + 7 * dst_step) = m13; + + *(dst_ic8_ptr + j * 4 + 8 * dst_step) = m20; + *(dst_ic8_ptr + j * 4 + 9 * dst_step) = m21; + *(dst_ic8_ptr + j * 4 + 10 * dst_step) = m22; + *(dst_ic8_ptr + j * 4 + 11 * dst_step) = m23; + + *(dst_ic8_ptr + j * 4 + 12 * dst_step) = m30; + *(dst_ic8_ptr + j * 4 + 13 * dst_step) = m31; + *(dst_ic8_ptr + j * 4 + 14 * dst_step) = m32; + *(dst_ic8_ptr + j * 4 + 15 * dst_step) = m33; + } +#endif + } + } +} + +void Conv3x3Int8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, int8_t *output_data, bool h_not_bound, + bool w_not_bound, int output_w, int real_num, int oc_start, + const ConvParameter *conv_param) { + int32_t *left_shift = conv_param->conv_quant_arg_.left_shift_; + int32_t *right_shift = conv_param->conv_quant_arg_.right_shift_; + int32_t *quant_multiplier = conv_param->conv_quant_arg_.quant_multiplier_; + int output_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_; + int out_min = conv_param->conv_quant_arg_.out_act_min_[0]; + int out_max = conv_param->conv_quant_arg_.out_act_max_[0]; + +#ifdef ENABLE_ARM + int32x4_t bias_ptr = vld1q_s32(bias_data); + + int32x4_t s00 = vld1q_s32(gemm_out); + int32x4_t s01 = vld1q_s32(gemm_out + 4); + int32x4_t s02 = vld1q_s32(gemm_out + 8); + int32x4_t s03 = vld1q_s32(gemm_out + 12); + + int32x4_t s10 = vld1q_s32(gemm_out + 16); + int32x4_t s11 = vld1q_s32(gemm_out + 20); + int32x4_t s12 = vld1q_s32(gemm_out + 24); + int32x4_t s13 = vld1q_s32(gemm_out + 28); + + int32x4_t s20 = vld1q_s32(gemm_out + 32); + int32x4_t s21 = vld1q_s32(gemm_out + 36); + int32x4_t s22 = vld1q_s32(gemm_out + 40); + int32x4_t s23 = vld1q_s32(gemm_out + 44); + + int32x4_t s30 = vld1q_s32(gemm_out + 48); + int32x4_t s31 = vld1q_s32(gemm_out + 52); + int32x4_t s32 = vld1q_s32(gemm_out + 56); + int32x4_t s33 = vld1q_s32(gemm_out + 60); + + int32x4_t t00 = vshrq_n_s32(vaddq_s32(vaddq_s32(s00, s10), s20), 1); + int32x4_t t01 = vshrq_n_s32(vaddq_s32(vaddq_s32(s01, s11), s21), 1); + int32x4_t t02 = vshrq_n_s32(vaddq_s32(vaddq_s32(s02, s12), s22), 1); + int32x4_t t03 = vshrq_n_s32(vaddq_s32(vaddq_s32(s03, s13), s23), 1); + + int32x4_t t10 = vshrq_n_s32(vsubq_s32(vsubq_s32(s10, s20), s30), 1); + int32x4_t t11 = vshrq_n_s32(vsubq_s32(vsubq_s32(s11, s21), s31), 1); + int32x4_t t12 = vshrq_n_s32(vsubq_s32(vsubq_s32(s12, s22), s32), 1); + int32x4_t t13 = vshrq_n_s32(vsubq_s32(vsubq_s32(s13, s23), s33), 1); + + int32x4_t d00 = vaddq_s32(vshrq_n_s32(vaddq_s32(vaddq_s32(t00, t01), t02), 1), bias_ptr); + int32x4_t d01 = vaddq_s32(vshrq_n_s32(vsubq_s32(vsubq_s32(t01, t02), t03), 1), bias_ptr); + + int32x4_t d10 = vaddq_s32(vshrq_n_s32(vaddq_s32(vaddq_s32(t10, t11), t12), 1), bias_ptr); + int32x4_t d11 = vaddq_s32(vshrq_n_s32(vsubq_s32(vsubq_s32(t11, t12), t13), 1), bias_ptr); + + int32x4_t out_multiplier; + int32x4_t ls; + int32x4_t rs; + if ((conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) { + out_multiplier = vld1q_s32(quant_multiplier + oc_start); + ls = vld1q_s32(left_shift + oc_start); + rs = vld1q_s32(right_shift + oc_start); + } else { + out_multiplier = vdupq_n_s32(quant_multiplier[0]); + ls = vdupq_n_s32(left_shift[0]); + rs = vdupq_n_s32(right_shift[0]); + } + int32x4_t out_zp = vdupq_n_s32(output_zp); + int32x4_t output_min = vdupq_n_s32(out_min); + int32x4_t output_max = vdupq_n_s32(out_max); + + d00 = vqshlq_s32(d00, ls); + d00 = vqrdmulhq_s32(d00, out_multiplier); + int32x4_t carry = vandq_s32(d00, rs); + carry = vshrq_n_s32(carry, 31); + d00 = vqaddq_s32(d00, carry); + d00 = vqrshlq_s32(d00, rs); + d00 = vaddq_s32(d00, out_zp); + d00 = vmaxq_s32(d00, output_min); + d00 = vminq_s32(d00, output_max); + + d01 = vqshlq_s32(d01, ls); + d01 = vqrdmulhq_s32(d01, out_multiplier); + carry = vandq_s32(d01, rs); + carry = vshrq_n_s32(carry, 31); + d01 = vqaddq_s32(d01, carry); + d01 = vqrshlq_s32(d01, rs); + d01 = vaddq_s32(d01, out_zp); + d01 = vmaxq_s32(d01, output_min); + d01 = vminq_s32(d01, output_max); + + d10 = vqshlq_s32(d10, ls); + d10 = vqrdmulhq_s32(d10, out_multiplier); + carry = vandq_s32(d10, rs); + carry = vshrq_n_s32(carry, 31); + d10 = vqaddq_s32(d10, carry); + d10 = vqrshlq_s32(d10, rs); + d10 = vaddq_s32(d10, out_zp); + d10 = vmaxq_s32(d10, output_min); + d10 = vminq_s32(d10, output_max); + + d11 = vqshlq_s32(d11, ls); + d11 = vqrdmulhq_s32(d11, out_multiplier); + carry = vandq_s32(d11, rs); + carry = vshrq_n_s32(carry, 31); + d11 = vqaddq_s32(d11, carry); + d11 = vqrshlq_s32(d11, rs); + d11 = vaddq_s32(d11, out_zp); + d11 = vmaxq_s32(d11, output_min); + d11 = vminq_s32(d11, output_max); + + (output_data)[0] = (int8_t)d00[0]; + (output_data + 1)[0] = (int8_t)d00[1]; + (output_data + 2)[0] = (int8_t)d00[2]; + (output_data + 3)[0] = (int8_t)d00[3]; + + if (w_not_bound) { + *(output_data + 4) = (int8_t)d01[0]; + *(output_data + 5) = (int8_t)d01[1]; + *(output_data + 6) = (int8_t)d01[2]; + *(output_data + 7) = (int8_t)d01[3]; + } + if (h_not_bound) { + *(output_data + output_w * 4) = (int8_t)d10[0]; + *(output_data + output_w * 4 + 1) = (int8_t)d10[1]; + *(output_data + output_w * 4 + 2) = (int8_t)d10[2]; + *(output_data + output_w * 4 + 3) = (int8_t)d10[3]; + if (w_not_bound) { + *(output_data + output_w * 4 + 4) = (int8_t)d11[0]; + *(output_data + output_w * 4 + 5) = (int8_t)d11[1]; + *(output_data + output_w * 4 + 6) = (int8_t)d11[2]; + *(output_data + output_w * 4 + 7) = (int8_t)d11[3]; + } + } +#else + if ((conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) { + for (int i = 0; i < C4NUM; i++) { + const int32_t *local_ptr = gemm_out + i; + const int32_t *bias_ptr = bias_data + i; + + int32_t s00 = local_ptr[0]; + int32_t s01 = (local_ptr + 4)[0]; + int32_t s02 = (local_ptr + 8)[0]; + int32_t s03 = (local_ptr + 12)[0]; + + int32_t s10 = (local_ptr + 16)[0]; + int32_t s11 = (local_ptr + 20)[0]; + int32_t s12 = (local_ptr + 24)[0]; + int32_t s13 = (local_ptr + 28)[0]; + + int32_t s20 = (local_ptr + 32)[0]; + int32_t s21 = (local_ptr + 36)[0]; + int32_t s22 = (local_ptr + 40)[0]; + int32_t s23 = (local_ptr + 44)[0]; + + int32_t s30 = (local_ptr + 48)[0]; + int32_t s31 = (local_ptr + 52)[0]; + int32_t s32 = (local_ptr + 56)[0]; + int32_t s33 = (local_ptr + 60)[0]; + + int32_t t00 = (s00 + s10 + s20) / 2; + int32_t t01 = (s01 + s11 + s21) / 2; + int32_t t02 = (s02 + s12 + s22) / 2; + int32_t t03 = (s03 + s13 + s23) / 2; + + int32_t t10 = (s10 - s20 - s30) / 2; + int32_t t11 = (s11 - s21 - s31) / 2; + int32_t t12 = (s12 - s22 - s32) / 2; + int32_t t13 = (s13 - s23 - s33) / 2; + + int32_t d00 = (t00 + t01 + t02) / 2 + bias_ptr[0]; + int32_t d01 = (t01 - t02 - t03) / 2 + bias_ptr[0]; + + int32_t d10 = (t10 + t11 + t12) / 2 + bias_ptr[0]; + int32_t d11 = (t11 - t12 - t13) / 2 + bias_ptr[0]; + + int oc_index = oc_start + i; + d00 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d00 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]), + -right_shift[oc_index]); + d00 += output_zp; + d00 = d00 > out_min ? d00 : out_min; + d00 = d00 < out_max ? d00 : out_max; + + d01 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d01 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]), + -right_shift[oc_index]); + d01 += output_zp; + d01 = d01 > out_min ? d01 : out_min; + d01 = d01 < out_max ? d01 : out_max; + + d10 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d10 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]), + -right_shift[oc_index]); + d10 += output_zp; + d10 = d10 > out_min ? d10 : out_min; + d10 = d10 < out_max ? d10 : out_max; + + d11 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d11 * (1 << (unsigned int)left_shift[oc_index]), quant_multiplier[oc_index]), + -right_shift[oc_index]); + d11 += output_zp; + d11 = d11 > out_min ? d11 : out_min; + d11 = d11 < out_max ? d11 : out_max; + + (output_data + i)[0] = (int8_t)d00; + if (w_not_bound) { + (output_data + i + C4NUM)[0] = (int8_t)d01; + } + if (h_not_bound) { + (output_data + i + output_w * C4NUM)[0] = (int8_t)d10; + if (w_not_bound) { + (output_data + i + output_w * C4NUM + C4NUM)[0] = (int8_t)d11; + } + } + } + } else { + for (int i = 0; i < C4NUM; i++) { + const int32_t *local_ptr = gemm_out + i; + const int32_t *bias_ptr = bias_data + i; + + int32_t s00 = local_ptr[0]; + int32_t s01 = (local_ptr + 4)[0]; + int32_t s02 = (local_ptr + 8)[0]; + int32_t s03 = (local_ptr + 12)[0]; + + int32_t s10 = (local_ptr + 16)[0]; + int32_t s11 = (local_ptr + 20)[0]; + int32_t s12 = (local_ptr + 24)[0]; + int32_t s13 = (local_ptr + 28)[0]; + + int32_t s20 = (local_ptr + 32)[0]; + int32_t s21 = (local_ptr + 36)[0]; + int32_t s22 = (local_ptr + 40)[0]; + int32_t s23 = (local_ptr + 44)[0]; + + int32_t s30 = (local_ptr + 48)[0]; + int32_t s31 = (local_ptr + 52)[0]; + int32_t s32 = (local_ptr + 56)[0]; + int32_t s33 = (local_ptr + 60)[0]; + + int32_t t00 = (s00 + s10 + s20) / 2; + int32_t t01 = (s01 + s11 + s21) / 2; + int32_t t02 = (s02 + s12 + s22) / 2; + int32_t t03 = (s03 + s13 + s23) / 2; + + int32_t t10 = (s10 - s20 - s30) / 2; + int32_t t11 = (s11 - s21 - s31) / 2; + int32_t t12 = (s12 - s22 - s32) / 2; + int32_t t13 = (s13 - s23 - s33) / 2; + + int32_t d00 = (t00 + t01 + t02) / 2 + bias_ptr[0]; + int32_t d01 = (t01 - t02 - t03) / 2 + bias_ptr[0]; + + int32_t d10 = (t10 + t11 + t12) / 2 + bias_ptr[0]; + int32_t d11 = (t11 - t12 - t13) / 2 + bias_ptr[0]; + + d00 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d00 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]), + -right_shift[0]); + d00 += output_zp; + d00 = d00 > out_min ? d00 : out_min; + d00 = d00 < out_max ? d00 : out_max; + + d01 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d01 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]), + -right_shift[0]); + d01 += output_zp; + d01 = d01 > out_min ? d01 : out_min; + d01 = d01 < out_max ? d01 : out_max; + + d10 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d10 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]), + -right_shift[0]); + d10 += output_zp; + d10 = d10 > out_min ? d10 : out_min; + d10 = d10 < out_max ? d10 : out_max; + + d11 = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(d11 * (1 << (unsigned int)left_shift[0]), quant_multiplier[0]), + -right_shift[0]); + d11 += output_zp; + d11 = d11 > out_min ? d11 : out_min; + d11 = d11 < out_max ? d11 : out_max; + + (output_data + i)[0] = (int8_t)d00; + if (w_not_bound) { + (output_data + i + C4NUM)[0] = (int8_t)d01; + } + if (h_not_bound) { + (output_data + i + output_w * C4NUM)[0] = (int8_t)d10; + if (w_not_bound) { + (output_data + i + output_w * C4NUM + C4NUM)[0] = (int8_t)d11; + } + } + } + } +#endif +} + +void Conv3x3Int8OutputTransform(const int32_t *gemm_out, int8_t *out_data, const int32_t *bias_data, int start_index, + int real_cal_num, int out_w_block, const ConvParameter *conv_param) { + int output_channel = conv_param->output_channel_; + int output_w = conv_param->output_w_; + int output_h = conv_param->output_h_; + const int oc4 = UP_DIV(output_channel, C4NUM); + const int input_unit = 4; + if (out_w_block == 0) { + return; + } + for (int i = 0; i < real_cal_num; i++) { + int out_w_index = (start_index + i) % out_w_block; + int out_h_index = (start_index + i) / out_w_block; + int src_tile_offset = i * oc4 * C4NUM * input_unit * input_unit; + int dst_tile_offset = C4NUM * (out_w_index * OUPUT_UNIT + out_h_index * OUPUT_UNIT * output_w); + + for (int j = 0; j < oc4; j++) { + int src_oc4_offset = src_tile_offset + j * input_unit * input_unit * C4NUM; + int dst_oc4_offset = dst_tile_offset + j * C4NUM * output_h * output_w; + const int32_t *src_ptr = gemm_out + src_oc4_offset; + const int32_t *bias_ptr = bias_data + j * C4NUM; + int8_t *dst_ptr = out_data + dst_oc4_offset; + + // output transform + int real_num = (output_channel - j * C4NUM) < C4NUM ? (output_channel - j * C4NUM) : C4NUM; + bool w_not_bound = out_w_index * OUPUT_UNIT + 1 < output_w; + bool h_not_bound = out_h_index * OUPUT_UNIT + 1 < output_h; + Conv3x3Int8OutputUnit(src_ptr, bias_ptr, dst_ptr, h_not_bound, w_not_bound, output_w, real_num, j * C4NUM, + conv_param); + } + } +} + +void Conv3x3Int8InputTransform(const int16_t *input_data, int16_t *trans_input, int16_t *tmp_data, int start_index, + int real_cal_num, int out_w_block, const ConvParameter *conv_param) { + // input data format : nhwc + int input_channel = conv_param->input_channel_; + int input_width = conv_param->input_w_; + int input_height = conv_param->input_h_; + int pad_w = conv_param->pad_l_; + int pad_h = conv_param->pad_u_; + ConvQuantArg quant_arg = conv_param->conv_quant_arg_; + int input_zp = quant_arg.input_quant_args_[0].zp_; + const int ic8 = UP_DIV(input_channel, C8NUM); + const int input_unit = 4; + if (out_w_block == 0) { + return; + } + for (int cal_id = 0; cal_id < real_cal_num; cal_id++) { + int x_id = start_index + cal_id; + int origin_x = (x_id % out_w_block) * OUPUT_UNIT - pad_w; + int origin_y = (x_id / out_w_block) * OUPUT_UNIT - pad_h; + int real_x_start = origin_x > 0 ? 0 : -origin_x; + int real_x_end = (origin_x + input_unit) < input_width ? input_unit : (input_width - origin_x); + int real_y_start = origin_y > 0 ? 0 : -origin_y; + int real_y_end = (origin_y + input_unit) < input_height ? input_unit : (input_height - origin_y); + + int src_plane_offset = C8NUM * (origin_y * input_width + origin_x); + int dst_plane_offset = cal_id * C8NUM; + for (int ic = 0; ic < ic8; ic++) { + // copy data from origin input to tmp buffer + for (int i = 0; i < input_unit * input_unit * TILE_NUM; i++) tmp_data[i] = input_zp; + + int src_c8_offset = src_plane_offset + ic * C8NUM * input_height * input_width; + for (int j = real_y_start; j < real_y_end; j++) { + const int16_t *src = input_data + src_c8_offset + C8NUM * (j * input_width + real_x_start); + int16_t *dst = tmp_data + C8NUM * (C4NUM * j + real_x_start); + memcpy(dst, src, (size_t)(real_x_end - real_x_start) * C8NUM * sizeof(int16_t)); + } + // input transform + int dst_ic8_offset = dst_plane_offset + ic * TILE_NUM * C8NUM; + size_t dst_step = (size_t)ic8 * C8NUM * TILE_NUM; + int16_t *trans_input_ptr = trans_input + dst_ic8_offset; + Conv3x3Int8InputUnit(tmp_data, trans_input_ptr, dst_step, input_zp); + } + } +} + +void Conv3x3Int8Gemm(int32_t *dst, const int16_t *src, const int16_t *weight, int oc, int ic8, size_t real_cal_num) { + int oc4 = UP_DIV(oc, C4NUM); +#ifdef ENABLE_ARM + IndirectGemmInt16to32_8x4(dst, src, weight, 16, ic8, oc4, (size_t)oc4 * 4 * 16 * sizeof(int32_t)); +#else + const int input_unit_square = 16; + for (int c = 0; c < oc4; c++) { + int filter_oc_offset = c * input_unit_square * ic8 * C8NUM * C4NUM; + int dst_oc_offset = c * input_unit_square * C4NUM; + for (int n = 0; n < real_cal_num; n++) { + int src_tile_offset = n * C8NUM; + int dst_tile_offset = dst_oc_offset + n * oc4 * C4NUM * input_unit_square; + for (int i = 0; i < 4; i++) { + int filter_h_offset = filter_oc_offset + i * 4 * ic8 * C8NUM * C4NUM; + int src_h_offset = src_tile_offset + i * C8NUM * ic8 * C8NUM * C4NUM; + int dst_h_offset = dst_tile_offset + i * 4 * 4; + for (int m = 0; m < 4; m++) { + int filter_w_offset = filter_h_offset + m * 4 * C8NUM * ic8; + int src_w_offset = src_h_offset + m * 8 * ic8 * C8NUM; + int dst_w_offset = dst_h_offset + m * C4NUM; + + int32_t acc[4] = {0}; + for (int z = 0; z < 4; z++) { + int filter_offset = filter_w_offset + z; + for (int j = 0; j < ic8; j++) { + int filter_c8_offset = filter_offset + j * 4 * 8; + int src_c8_offset = src_w_offset + j * 8 * 8; + + for (int k = 0; k < 8; k++) { + const int16_t *w_ptr = weight + filter_c8_offset + k * 4; + const int16_t *input_ptr = src + src_c8_offset + k; + acc[z] += w_ptr[0] * input_ptr[0]; + } + } + (dst + dst_w_offset + z)[0] = acc[z]; + } + } + } + } + } +#endif +} + +// int8 convolution 3x3 +void Conv3x3Int8(const int16_t *input_data, const int16_t *transed_weight, const int32_t *bias_data, + int8_t *output_data, int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, + int8_t *tmp_out, int task_id, const ConvParameter *conv_param) { + int ic8 = UP_DIV(conv_param->input_channel_, C8NUM); + int out_w_block = UP_DIV(conv_param->output_w_, OUPUT_UNIT); + int out_h_block = UP_DIV(conv_param->output_h_, OUPUT_UNIT); + int output_count = out_w_block * out_h_block; + NNACL_CHECK_ZERO_RETURN(TILE_NUM); + int output_tile_count = UP_DIV(output_count, TILE_NUM); + int oc4 = UP_DIV(conv_param->output_channel_, C4NUM); + int tile_buffer_offset = TILE_NUM * 16 * ic8 * C8NUM; + const int block_unit_buffer_offset = 16 * C8NUM; + int tmp_dst_buffer_offset = TILE_NUM * 16 * oc4 * C4NUM; + + for (int batch = 0; batch < conv_param->input_batch_; batch++) { + int in_batch_offset = batch * ic8 * C8NUM * conv_param->input_h_ * conv_param->input_w_; + int tmp_out_batch_offset = batch * oc4 * C4NUM * conv_param->output_w_ * conv_param->output_h_; + for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) { + int start_index = thread_id * TILE_NUM; + int real_cal_num = (output_count - start_index) < TILE_NUM ? (output_count - start_index) : TILE_NUM; + + Conv3x3Int8InputTransform(input_data + in_batch_offset, tile_buffer + task_id * tile_buffer_offset, + block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num, + out_w_block, conv_param); + + Conv3x3Int8Gemm(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tile_buffer + task_id * tile_buffer_offset, + transed_weight, conv_param->output_channel_, ic8, real_cal_num); + + Conv3x3Int8OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tmp_out + tmp_out_batch_offset, + bias_data, start_index, real_cal_num, out_w_block, conv_param); + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/conv3x3_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/conv3x3_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..fb6ad7c5ed7600fa345c93dbc3cd8deeec624b17 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/conv3x3_int8.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_INT8_CONV_INT8_H_ +#define NNACL_INT8_CONV_INT8_H_ + +#include +#ifdef ENABLE_NEON +#include +#endif +#include "nnacl/pack.h" +#include "nnacl/op_base.h" +#include "nnacl/common_func.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/int8/fixed_point.h" +#include "nnacl/int8/quantize.h" +#include "nnacl/matmul_parameter.h" +#include "nnacl/int8/matmul_int8.h" +#include "nnacl/int8/common_func_int8.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void Conv3x3Int8FilterTransform(const int16_t *weight_data, int16_t *trans_weight, int iC8, int output_channel, + int kernel_plane); + +void Conv3x3Int8(const int16_t *input_data, const int16_t *transed_weight, const int32_t *bias_data, + int8_t *output_data, int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, + int8_t *tmp_out, int task_id, const ConvParameter *conv_param); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_CONV_INT8_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/conv_depthwise_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/conv_depthwise_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..f0b368844917fffbb18acfd581d05966c988fc21 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/conv_depthwise_int8.c @@ -0,0 +1,825 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/conv_depthwise_int8.h" +#include +#include "nnacl/int8/fixed_point.h" +#include "nnacl/int8/common_func_int8.h" + +/*conv depthwise int8 begin*/ +#ifndef ENABLE_ARM +void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels, + int output_channel, int input_step, int8_t input_zp) { + for (int i = 0; i < num_pixels; i++) { + for (int c = 0; c < output_channel; c++) { + const int16_t input = input_ptr[c] - input_zp; + *output_ptr++ += input * weight_ptr[c]; + } + input_ptr += input_step; + } +} +#endif + +void ConvDwInt8Post(int8_t *dst, int32_t *buffer, int output_w, int channel, int32_t output_zp, + const int32_t *out_multiplier, const int32_t *left_shift, const int32_t *right_shift, + int32_t acc_min, int32_t acc_max, bool per_channel) { + if (per_channel) { + // support perchannel + for (int w = 0; w < output_w; w++) { + int channel4 = 0; +#ifdef ENABLE_ARM + channel4 = channel / 4 * 4; + ConvDwInt8PostAlign4PerChannel(dst, buffer, channel4, output_zp, out_multiplier, left_shift, right_shift, acc_min, + acc_max); +#endif + for (int c = channel4; c < channel; c++) { + buffer[c] = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(buffer[c] * (1 << (unsigned int)left_shift[c]), out_multiplier[c]), + -right_shift[c]); + buffer[c] += output_zp; + buffer[c] = MSMAX(buffer[c], acc_min); + buffer[c] = MSMIN(buffer[c], acc_max); + dst[c] = (buffer[c]); + } + buffer += channel; + dst += channel; + } + } else { + int num_pixels = output_w * channel; + int align_num = 0; +#ifdef ENABLE_ARM + align_num = num_pixels / 4 * 4; + ConvDwInt8PostAlign4(dst, buffer, align_num, output_zp, out_multiplier[0], left_shift[0], right_shift[0], acc_min, + acc_max); +#endif + for (int i = align_num; i < num_pixels; i++) { + buffer[i] = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(buffer[i] * (1 << (unsigned int)left_shift[0]), out_multiplier[0]), + -right_shift[0]); + buffer[i] += output_zp; + buffer[i] = MSMAX(buffer[i], acc_min); + buffer[i] = MSMIN(buffer[i], acc_max); + dst[i] = (buffer[i]); + } + } +} + +void ConvDwInt8(int8_t *output_data, int32_t *row_buffer, const int8_t *input_data, const int16_t *weight_data, + const int32_t *bias_data, const ConvParameter *conv_param, int task_id) { + int step_h = UP_DIV(conv_param->output_h_, conv_param->thread_num_); + int start_h = step_h * task_id; + int end_h = MSMIN(start_h + step_h, conv_param->output_h_); + + bool filter_per_channel = conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL; + int32_t *out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_; + int32_t *left_shift = conv_param->conv_quant_arg_.left_shift_; + int32_t *right_shift = conv_param->conv_quant_arg_.right_shift_; + + int intput_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_; + int output_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_; + int acc_min = conv_param->conv_quant_arg_.out_act_min_[0]; + int acc_max = conv_param->conv_quant_arg_.out_act_max_[0]; + + for (int b = 0; b < conv_param->output_batch_; b++) { + const int8_t *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_; + int8_t *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_; + for (int oh = start_h; oh < end_h; oh++) { + int8_t *dst_data = dst + oh * conv_param->output_w_ * conv_param->output_channel_; + + int ih_origin = oh * conv_param->stride_h_ - conv_param->pad_u_; + int start_kh = MSMAX(0, UP_DIV(-ih_origin, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih_origin, conv_param->dilation_h_)); + + // init acc + for (int ow = 0; ow < conv_param->output_w_; ow++) { + memcpy(row_buffer + ow * conv_param->output_channel_, bias_data, conv_param->output_channel_ * sizeof(int32_t)); + } + for (int kh = start_kh; kh < end_kh; kh++) { + int ih = ih_origin + conv_param->dilation_h_ * kh; + + const int8_t *src_kh = src + ih * conv_param->input_w_ * conv_param->input_channel_; + const int16_t *weight_kh = weight_data + kh * conv_param->kernel_w_ * conv_param->output_channel_; + + int in_sw_step = conv_param->stride_w_ * conv_param->input_channel_; + for (int kw = 0; kw < conv_param->kernel_w_; kw++) { + int out_w_start = MSMAX( + 0, (conv_param->pad_l_ - conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) / conv_param->stride_w_); + int out_w_end = MSMIN(conv_param->output_w_, (conv_param->input_w_ + conv_param->pad_l_ - + conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) / + conv_param->stride_w_); + + int32_t *acc_w = row_buffer + out_w_start * conv_param->output_channel_; + int iw_origin = (out_w_start * conv_param->stride_w_) - conv_param->pad_l_ + conv_param->dilation_w_ * kw; + + const int8_t *src_kw = src_kh + iw_origin * conv_param->input_channel_; + int num_pixels = out_w_end - out_w_start; + + ConvDwInt8Row(acc_w, src_kw, weight_kh, num_pixels, conv_param->output_channel_, in_sw_step, intput_zp); + weight_kh += conv_param->output_channel_; + } + } + // post func, acc int32 -> dst int8 + ConvDwInt8Post(dst_data, row_buffer, conv_param->output_w_, conv_param->output_channel_, output_zp, + out_multiplier, left_shift, right_shift, acc_min, acc_max, filter_per_channel); + } + } +} +/*conv depthwise int8 end*/ + +/*conv depthwise 3x3 int8 begin*/ +void ConvDw3x3Int8InitBuffer(int8_t *buffer, const int8_t *input, const ConvParameter *conv_param, int block_input_h, + int block_input_w) { + for (int h = 0; h < block_input_h; h++) { + const int8_t *src = input; + for (int w = 0; w < block_input_w; w++) { + memcpy(buffer, src, 64); + src += conv_param->input_channel_; + buffer += 64; + } + input += conv_param->input_w_ * conv_param->input_channel_; + } +} + +void ConvDw3x3Int8Window(int8_t *output, const int8_t *buffer, const int16_t *weight, const int32_t *bias, int col_size, + int row_size, int channel, int output_h, int output_w, int8_t in_zp, int32_t out_zp, + const int32_t *out_multiplier, const int32_t *left_shift, const int32_t *right_shift, + int32_t acc_min, int32_t acc_max, int stride, bool per_channel) { + for (int w = 0; w < output_w; w++) { + int tmp_buffer[C8NUM]; + for (int i = 0; i < C8NUM; i++) { + tmp_buffer[i] = 0; + } + int8_t *output_tmp = output; + const int8_t *src_kh = buffer; + const int16_t *weight_kh = weight; + for (int kh = 0; kh < 3; kh++) { + const int8_t *src_kw = src_kh; + const int16_t *weight_kw = weight_kh; + for (int kw = 0; kw < 3; kw++) { + for (int c = 0; c < 8; c++) { + tmp_buffer[c] += (src_kw[c] - in_zp) * weight_kw[c]; + } + src_kw += col_size; + weight_kw += channel; + } + src_kh += row_size; + weight_kh += 3 * channel; + } + if (per_channel) { + for (int c = 0; c < C8NUM; c++) { + tmp_buffer[c] += bias[c]; + tmp_buffer[c] = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left_shift[c]), out_multiplier[c]), + -right_shift[c]); + tmp_buffer[c] += out_zp; + tmp_buffer[c] = MSMAX(tmp_buffer[c], acc_min); + tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max); + *output_tmp++ = (tmp_buffer[c]); + } + } else { + for (int c = 0; c < C8NUM; c++) { + tmp_buffer[c] += bias[c]; + tmp_buffer[c] = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left_shift[0]), out_multiplier[0]), + -right_shift[0]); + tmp_buffer[c] += out_zp; + tmp_buffer[c] = MSMAX(tmp_buffer[c], acc_min); + tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max); + *output_tmp++ = (tmp_buffer[c]); + } + } + output += channel; + buffer += col_size * stride; + } +} + +void ConvDw3x3Int8Block(int8_t *output, const int8_t *buffer, const int16_t *weight, const int32_t *bias, int start_c, + int end_c, int col_size, int row_size, int channel, int output_h, int output_w, int8_t in_zp, + int32_t out_zp, const int32_t *out_multiplier, const int32_t *left_shift, + const int32_t *right_shift, int32_t acc_min, int32_t acc_max, int stride, bool per_channel) { + for (; start_c <= end_c - 8; start_c += 8) { +#ifdef ENABLE_ARM64 + if (stride == 1) { + ConvDw3x3Int8Neon64(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, in_zp, out_zp, + out_multiplier, left_shift, right_shift, acc_min, acc_max, per_channel); + } else { + ConvDw3x3Int8Stride2(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, in_zp, out_zp, + out_multiplier, left_shift, right_shift, acc_min, acc_max, per_channel); + } + +#else + ConvDw3x3Int8Window(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, in_zp, out_zp, + out_multiplier, left_shift, right_shift, acc_min, acc_max, stride, per_channel); +#endif + output += 8; + buffer += 8; + weight += 8; + bias += 8; + if (per_channel) { + out_multiplier += 8; + left_shift += 8; + right_shift += 8; + } + } +} + +void ConvDw3x3Int8Row(int8_t *output, int8_t *buffer, const int8_t *input, const int16_t *weight, const int32_t *bias, + const ConvParameter *conv_param, int start_w, int end_w, int block_output_h, int block_output_w, + int block_input_h, int block_input_w) { + bool filter_per_channel = conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL; + int32_t *out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_; + int32_t *left_shift = conv_param->conv_quant_arg_.left_shift_; + int32_t *right_shift = conv_param->conv_quant_arg_.right_shift_; + int in_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_; + int out_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_; + int acc_min = conv_param->conv_quant_arg_.out_act_min_[0]; + int acc_max = conv_param->conv_quant_arg_.out_act_max_[0]; + + const int ih_offset = 64 * block_input_w; + int w = start_w; + if (conv_param->output_channel_ > 64 || (conv_param->output_channel_ < 64 && conv_param->input_w_ > 150)) { + for (; w <= end_w - block_output_w; w += block_output_w) { + int8_t *output_ptr = output; + const int8_t *input_ptr = input; + const int16_t *weight_ptr = weight; + const int32_t *bias_ptr = bias; + int32_t *out_multiplier_ptr = out_multiplier; + int32_t *left_shift_ptr = left_shift; + int32_t *right_shift_ptr = right_shift; + int c = 0; + for (; c <= conv_param->output_channel_ - 64; c += 64) { + ConvDw3x3Int8InitBuffer(buffer, input_ptr, conv_param, block_input_h, block_input_w); + ConvDw3x3Int8Block(output_ptr, buffer, weight_ptr, bias_ptr, 0, 64, 64, ih_offset, conv_param->input_channel_, + block_output_h, block_output_w, in_zp, out_zp, out_multiplier_ptr, left_shift_ptr, + right_shift_ptr, acc_min, acc_max, conv_param->stride_h_, filter_per_channel); + output_ptr += 64; + input_ptr += 64; + weight_ptr += 64; + bias_ptr += 64; + if (filter_per_channel) { + out_multiplier_ptr += 64; + left_shift_ptr += 64; + right_shift_ptr += 64; + } + } + // left channel + ConvDw3x3Int8Block(output_ptr, input_ptr, weight_ptr, bias_ptr, c, conv_param->input_channel_, + conv_param->input_channel_, conv_param->input_w_ * conv_param->input_channel_, + conv_param->input_channel_, block_output_h, block_output_w, in_zp, out_zp, out_multiplier_ptr, + left_shift_ptr, right_shift_ptr, acc_min, acc_max, conv_param->stride_h_, filter_per_channel); + output += block_output_w * conv_param->input_channel_; + input += conv_param->stride_w_ * block_output_w * conv_param->input_channel_; + } + } + // left width + int left_width = end_w - w; + if (left_width > 0) { + ConvDw3x3Int8Block(output, input, weight, bias, 0, conv_param->input_channel_, conv_param->input_channel_, + conv_param->input_w_ * conv_param->input_channel_, conv_param->input_channel_, block_output_h, + left_width, in_zp, out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max, + conv_param->stride_h_, filter_per_channel); + } +} + +void ConvDw3x3Int8(int8_t *output_data, int8_t *buffer, const int8_t *input_data, const int16_t *weight_data, + const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, + int task_id) { + int output_h = sliding->bottom_ - sliding->top_; + int step_oh = UP_DIV(output_h, conv_param->thread_num_); + int start_oh = step_oh * task_id + sliding->top_; + int end_oh = MSMIN(start_oh + step_oh, sliding->bottom_); + int start_ow = sliding->left_; + int end_ow = sliding->right_; + + const int block_output_h = 1; + int block_output_w = conv_param->stride_w_ == 1 ? 30 : 14; + const int block_input_h = 3; + int block_input_w = conv_param->stride_w_ * (block_output_w - 1) + 3; + + for (int b = 0; b < conv_param->output_batch_; b++) { + int start_ih = start_oh * conv_param->stride_h_ - conv_param->pad_u_; + int start_iw = start_ow * conv_param->stride_w_ - conv_param->pad_l_; + const int8_t *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_ + + start_ih * conv_param->input_w_ * conv_param->input_channel_ + + start_iw * conv_param->input_channel_; + int8_t *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_ + + start_oh * conv_param->output_w_ * conv_param->output_channel_ + + start_ow * conv_param->output_channel_; + + for (int oh = start_oh; oh < end_oh; oh++) { + ConvDw3x3Int8Row(dst, buffer, src, weight_data, bias_data, conv_param, start_ow, end_ow, block_output_h, + block_output_w, block_input_h, block_input_w); + src += conv_param->stride_h_ * conv_param->input_w_ * conv_param->input_channel_; + dst += conv_param->output_w_ * conv_param->output_channel_; + } + } +} + +#ifndef ENABLE_ARM32 +void ConvDw3x3Int8BorderPixel(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height, + int width, int in_kh_step, int in_kw_step, int channel, int8_t in_zp, int32_t out_zp, + const int32_t *out_multiplier, const int32_t *left_shift, const int32_t *right_shift, + const int32_t acc_min, const int32_t acc_max, bool per_channel) { + for (int c = 0; c < channel; c += 8) { + int tmp_buffer[8]; + for (int i = 0; i < 8; i++) { + tmp_buffer[i] = 0; + } + const int8_t *src_kh = src; + const int16_t *weight_kh = weight; + for (int kh = 0; kh < height; kh++) { + const int8_t *src_kw = src_kh; + const int16_t *weight_kw = weight_kh; + for (int kw = 0; kw < width; kw++) { + for (int i = 0; i < 8; i++) { + tmp_buffer[i] += (src_kw[c + i] - in_zp) * weight_kw[c + i]; + } + src_kw += in_kw_step; + weight_kw += channel; + } // kernel_w loop + src_kh += in_kh_step; + weight_kh += 3 * channel; + } // kernel_h loop + if (per_channel) { + for (int i = 0; i < 8; i++) { + tmp_buffer[i] += bias[c + i]; + tmp_buffer[i] = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(tmp_buffer[i] * (1 << (unsigned int)left_shift[i]), out_multiplier[i]), + -right_shift[i]); + tmp_buffer[i] += out_zp; + tmp_buffer[i] = MSMAX(tmp_buffer[i], acc_min); + tmp_buffer[i] = MSMIN(tmp_buffer[i], acc_max); + dst[i] = (tmp_buffer[i]); + } + left_shift += 8; + right_shift += 8; + out_multiplier += 8; + } else { + for (int i = 0; i < 8; i++) { + tmp_buffer[i] += bias[c + i]; + tmp_buffer[i] = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(tmp_buffer[i] * (1 << (unsigned int)left_shift[0]), out_multiplier[0]), + -right_shift[0]); + tmp_buffer[i] += out_zp; + tmp_buffer[i] = MSMAX(tmp_buffer[i], acc_min); + tmp_buffer[i] = MSMIN(tmp_buffer[i], acc_max); + dst[i] = (tmp_buffer[i]); + } + } + dst += 8; + } +} +#endif + +#ifndef ENABLE_ARM64 +void ConvDw3x3Int8Corner(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int in_kh_step, + int in_kw_step, int channel, int8_t in_zp, int32_t out_zp, const int32_t *out_multiplier, + const int32_t *left_shift, const int32_t *right_shift, int32_t acc_min, int32_t acc_max, + bool per_channel) { + ConvDw3x3Int8BorderPixel(dst, src, weight, bias, 2, 2, in_kh_step, in_kw_step, channel, in_zp, out_zp, out_multiplier, + left_shift, right_shift, acc_min, acc_max, per_channel); +} + +void ConvDw3x3Int8Vertical(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int in_kh_step, + int in_kw_step, int channel, int8_t in_zp, int32_t out_zp, const int32_t *out_multiplier, + const int32_t *left_shift, const int32_t *right_shift, int32_t acc_min, int32_t acc_max, + bool per_channel) { + ConvDw3x3Int8BorderPixel(dst, src, weight, bias, 2, 3, in_kh_step, in_kw_step, channel, in_zp, out_zp, out_multiplier, + left_shift, right_shift, acc_min, acc_max, per_channel); +} + +void ConvDw3x3Int8Horizontal(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int in_kh_step, + int in_kw_step, int channel, int8_t in_zp, int32_t out_zp, const int32_t *out_multiplier, + const int32_t *left_shift, const int32_t *right_shift, int32_t acc_min, int32_t acc_max, + bool per_channel) { + ConvDw3x3Int8BorderPixel(dst, src, weight, bias, 3, 2, in_kh_step, in_kw_step, channel, in_zp, out_zp, out_multiplier, + left_shift, right_shift, acc_min, acc_max, per_channel); +} +#endif + +void ConvDw3x3Int8Pad(int8_t *output_data, const int8_t *input_data, const int16_t *weight_data, + const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding) { + bool filter_per_channel = conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL; + int32_t *out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_; + int32_t *left_shift = conv_param->conv_quant_arg_.left_shift_; + int32_t *right_shift = conv_param->conv_quant_arg_.right_shift_; + int in_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_; + int out_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_; + int acc_min = conv_param->conv_quant_arg_.out_act_min_[0]; + int acc_max = conv_param->conv_quant_arg_.out_act_max_[0]; + int input_row_size = conv_param->input_w_ * conv_param->input_channel_; + int weight_row_size = conv_param->kernel_w_ * conv_param->input_channel_; + int output_row_size = conv_param->output_w_ * conv_param->output_channel_; + int in_kh_step = sliding->in_kh_step_; + int in_kw_step = sliding->in_kw_step_; + + // top + for (int b = 0; b < conv_param->output_batch_; b++) { + const int8_t *input_batch = + input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_; + int8_t *output_batch = + output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_; + + const int8_t *input = input_batch; + const int16_t *weight = weight_data + weight_row_size + conv_param->input_channel_; + int8_t *output = output_batch; + ConvDw3x3Int8Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, + out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max, filter_per_channel); + input += (conv_param->stride_w_ - 1) * conv_param->input_channel_; + weight = weight_data + weight_row_size; + output += conv_param->output_channel_; + for (int out_w = sliding->left_; out_w < sliding->right_; out_w++) { + ConvDw3x3Int8Vertical(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, + out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max, filter_per_channel); + input += conv_param->stride_w_ * conv_param->input_channel_; + output += conv_param->output_channel_; + } + ConvDw3x3Int8Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, + out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max, filter_per_channel); + + // left + input = input_batch + (conv_param->stride_h_ - 1) * input_row_size; + weight = weight_data + conv_param->input_channel_; + output = output_batch + output_row_size; + for (int out_h = sliding->top_; out_h < sliding->bottom_; out_h++) { + ConvDw3x3Int8Horizontal(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, + in_zp, out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max, + filter_per_channel); + input += conv_param->stride_h_ * input_row_size; + output += output_row_size; + } + + // right + input = input_batch + (conv_param->input_w_ - 2) * conv_param->input_channel_ + + (conv_param->stride_h_ - 1) * input_row_size; + weight = weight_data; + output = output_batch + output_row_size + (conv_param->output_w_ - 1) * conv_param->output_channel_; + for (int out_h = sliding->top_; out_h < sliding->bottom_; out_h++) { + ConvDw3x3Int8Horizontal(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, + in_zp, out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max, + filter_per_channel); + input += conv_param->stride_h_ * input_row_size; + output += output_row_size; + } + + // bottom + input = input_batch + (conv_param->input_h_ - 2) * input_row_size; + weight = weight_data + conv_param->input_channel_; + output = output_batch + (conv_param->output_h_ - 1) * output_row_size; + ConvDw3x3Int8Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, + out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max, filter_per_channel); + input += conv_param->stride_w_ == 1 ? 0 : conv_param->input_channel_; + weight = weight_data; + output += conv_param->output_channel_; + for (int out_w = sliding->left_; out_w < sliding->right_; out_w++) { + ConvDw3x3Int8Vertical(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, + out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max, filter_per_channel); + input += conv_param->stride_w_ * conv_param->input_channel_; + output += conv_param->output_channel_; + } + ConvDw3x3Int8Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, in_zp, + out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max, filter_per_channel); + } +} +/*conv depthwise 3x3 int8 end*/ + +/*conv depthwise sliding window perchannel int8 begin*/ +void ConvDwInt8BorderPixel(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height, + int width, int in_kh_step, int in_kw_step, int kernel_w, const int8_t *input_zp, + const int32_t *out_zp, const int32_t *out_multiplier, const int32_t *left_shift, + const int32_t *right_shift, const int32_t *acc_min, const int32_t *acc_max) { + int tmp_buffer[C8NUM]; + for (int i = 0; i < C8NUM; i++) { + tmp_buffer[i] = 0; + } + const int8_t *src_kh = src; + const int16_t *weight_kh = weight; + for (int kh = 0; kh < height; kh++) { + const int8_t *src_kw = src_kh; + const int16_t *weight_kw = weight_kh; + for (int kw = 0; kw < width; kw++) { + for (int c = 0; c < C8NUM; c++) { + tmp_buffer[c] += (src_kw[c] - input_zp[c]) * weight_kw[c]; + } + src_kw += in_kw_step; + weight_kw += C8NUM; + } // kernel_w loop + src_kh += in_kh_step; + weight_kh += kernel_w * C8NUM; + } // kernel_h loop + + for (int c = 0; c < C8NUM; c++) { + tmp_buffer[c] += bias[c]; + tmp_buffer[c] = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left_shift[c]), out_multiplier[c]), + -right_shift[c]); + tmp_buffer[c] += out_zp[c]; + tmp_buffer[c] = MSMAX(tmp_buffer[c], acc_min[c]); + tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max[c]); + dst[c] = (tmp_buffer[c]); + } +} + +void ConvDwInt8Border(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int top, int bottom, + int left, int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding, + const int8_t *in_zp, const int32_t *out_zp, const int32_t *out_multiplier, + const int32_t *left_shift, const int32_t *right_shift, const int32_t *acc_min, + const int32_t *acc_max) { + int8_t *dst_h = dst + top * sliding->out_h_step_; + for (int oh = top; oh < bottom; oh++) { + int ih = oh * conv_param->stride_h_ - conv_param->pad_u_; + int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_)); + const int8_t *src_h = src + ih * sliding->in_h_step_; + + int8_t *dst_kernel = dst_h + left * sliding->block_channel_; + for (int ow = left; ow < right; ow++) { + int iw = ow * conv_param->stride_w_ - conv_param->pad_l_; + int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_)); + int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_)); + const int8_t *src_w = src_h + iw * sliding->block_channel_; + + const int8_t *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_; + const int16_t *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C8NUM; + + ConvDwInt8BorderPixel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_, in_zp, out_zp, + out_multiplier, left_shift, right_shift, acc_min, acc_max); + + dst_kernel += sliding->block_channel_; + } // width loop + dst_h += sliding->out_h_step_; + } // height loop +} + +#ifndef ENABLE_ARM +void ConvDwInt8Center(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height, int width, + int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step, int in_sw_step, + int in_kh_step, int in_kw_step, const int8_t *in_zp, const int32_t *out_zp, + const int32_t *out_multiplier, const int32_t *left_shift, const int32_t *right_shift, + const int32_t *acc_min, const int32_t *acc_max) { + int tmp_buffer[C8NUM]; + int8_t *dst_h = dst; + const int8_t *src_h = src; + for (int oh = 0; oh < height; oh++) { + int8_t *dst_w = dst_h; + const int8_t *src_w = src_h; + for (int ow = 0; ow < width; ow++) { + const int8_t *src_kh = src_w; + const int16_t *weight_kh = weight; + + for (int i = 0; i < C8NUM; i++) { + tmp_buffer[i] = 0; + } + for (int kh = 0; kh < kernel_h; kh++) { + const int8_t *src_kw = src_kh; + const int16_t *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++) { + for (int c = 0; c < C8NUM; c++) { + tmp_buffer[c] += (src_kw[c] - in_zp[c]) * weight_kw[c]; + } + src_kw += in_kw_step; + weight_kw += C8NUM; + } // kernel_w loop + src_kh += in_kh_step; + weight_kh += kernel_w * C8NUM; + } // kernel_h loop + // add bias relu + for (int c = 0; c < C8NUM; c++) { + tmp_buffer[c] += bias[c]; + tmp_buffer[c] = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left_shift[c]), out_multiplier[c]), + -right_shift[c]); + tmp_buffer[c] += out_zp[c]; + tmp_buffer[c] = MSMAX(tmp_buffer[c], acc_min[c]); + tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max[c]); + dst_w[c] = (tmp_buffer[c]); + } + dst_w += block_channel; + src_w += in_sw_step; + } // dst_width loop + dst_h += out_h_step; + src_h += in_sh_step; + } // dst_height loop +} +#endif + +void ConvDwInt8SW(int8_t *output_data, const int8_t *input_data, const int16_t *weight_data, const int32_t *bias_data, + const int8_t *input_zp, const int32_t *output_zp, const ConvParameter *conv_param, + const SlidingWindowParam *sliding, int task_id) { + NNACL_CHECK_ZERO_RETURN(conv_param->dilation_h_); + NNACL_CHECK_ZERO_RETURN(conv_param->dilation_w_); + const int8_t *src = input_data; + int8_t *dst = output_data; + for (int b = 0; b < conv_param->output_batch_; b++) { + for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) { + const int8_t *src_data = src + oc * C8NUM; + int8_t *dst_data = dst + oc * C8NUM; + const int16_t *weight = weight_data + oc * sliding->kernel_step_; + const int32_t *bias = bias_data + oc * C8NUM; + + int32_t *out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_ + oc * C8NUM; + int32_t *left_shift = conv_param->conv_quant_arg_.left_shift_ + oc * C8NUM; + int32_t *right_shift = conv_param->conv_quant_arg_.right_shift_ + oc * C8NUM; + int32_t *acc_min = conv_param->conv_quant_arg_.out_act_min_ + oc * C8NUM; + int32_t *acc_max = conv_param->conv_quant_arg_.out_act_max_ + oc * C8NUM; + const int8_t *in_zp = input_zp + oc * C8NUM; + const int32_t *out_zp = output_zp + oc * C8NUM; + + ConvDwInt8Border(dst_data, src_data, weight, bias, 0, sliding->top_, 0, conv_param->output_w_, conv_param, + sliding, in_zp, out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max); + ConvDwInt8Border(dst_data, src_data, weight, bias, sliding->bottom_, conv_param->output_h_, 0, + conv_param->output_w_, conv_param, sliding, in_zp, out_zp, out_multiplier, left_shift, + right_shift, acc_min, acc_max); + ConvDwInt8Border(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, 0, sliding->left_, conv_param, + sliding, in_zp, out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max); + ConvDwInt8Border(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, sliding->right_, + conv_param->output_w_, conv_param, sliding, in_zp, out_zp, out_multiplier, left_shift, + right_shift, acc_min, acc_max); + + if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) { + int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_; + int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_; + const int8_t *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_; + int8_t *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_; + ConvDwInt8Center(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, + conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_, + sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_, in_zp, + out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max); + } + } // output C8 loop + src += sliding->in_step_; + dst += sliding->out_step_; + } // batch loop + // output nhwc8 +} +/*conv depthwise sliding window perchannel int8 end*/ + +/*deconv depthwise int8 begin*/ +void DeconvDwInt8BorderPixel(int32_t *dst, const int16_t *src, const int16_t *weight, int height, int width, + int in_kh_step, int in_kw_step, int kernel_w) { + int32_t *dst_kh = dst; + const int16_t *weight_kh = weight; + for (int kh = 0; kh < height; kh++) { + int32_t *dst_kw = dst_kh; + const int16_t *weight_kw = weight_kh; + for (int kw = 0; kw < width; kw++) { + for (int c = 0; c < C4NUM; c++) { + dst_kw[c] += src[c] * weight_kw[c]; + } + dst_kw += in_kw_step; + weight_kw += C4NUM; + } // kernel_w loop + dst_kh += in_kh_step; + weight_kh += kernel_w * C4NUM; + } // kernel_h loop +} + +void DeconvDwInt8Border(int32_t *dst, const int16_t *src, const int16_t *weight, int top, int bottom, int left, + int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding) { + const int16_t *src_h = src + top * sliding->out_h_step_; + for (int ih = top; ih < bottom; ih++) { + int oh = ih * conv_param->stride_h_ - conv_param->pad_u_; + int start_kh = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_)); + int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_)); + int32_t *dst_h = dst + oh * sliding->in_h_step_; + + const int16_t *src_kernel = src_h + left * sliding->block_channel_; + for (int iw = left; iw < right; iw++) { + int ow = iw * conv_param->stride_w_ - conv_param->pad_l_; + int start_kw = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_)); + int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_)); + int32_t *dst_w = dst_h + ow * C4NUM; + + const int16_t *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM; + int32_t *dst_kernel = dst_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_; + + DeconvDwInt8BorderPixel(dst_kernel, src_kernel, weight_kernel, end_kh - start_kh, end_kw - start_kw, + sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_); + src_kernel += sliding->block_channel_; + } // width loop + src_h += sliding->out_h_step_; + } // height loop +} + +#ifndef ENABLE_ARM +void DeconvDwInt8Center(int32_t *dst, const int16_t *src, const int16_t *weight, int height, int width, int kernel_h, + int kernel_w, int out_h_step, int block_channel, int in_sh_step, int in_sw_step, int in_kh_step, + int in_kw_step) { + int32_t *dst_h = dst; + const int16_t *src_h = src; + for (int oh = 0; oh < height; oh++) { + int32_t *dst_w = dst_h; + const int16_t *src_w = src_h; + for (int ow = 0; ow < width; ow++) { + int32_t *dst_kh = dst_w; + const int16_t *weight_kh = weight; + for (int kh = 0; kh < kernel_h; kh++) { + int32_t *dst_kw = dst_kh; + const int16_t *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++) { + for (int c = 0; c < C4NUM; c++) { + dst_kw[c] += src_w[c] * weight_kw[c]; + } + dst_kw += in_kw_step; + weight_kw += C4NUM; + } // kernel_w loop + dst_kh += in_kh_step; + weight_kh += kernel_w * C4NUM; + } // kernel_h loop + dst_w += in_sw_step; + src_w += block_channel; + } // dst_width loop + dst_h += in_sh_step; + src_h += out_h_step; + } // dst_height loop +} +#endif + +#ifndef ENABLE_ARM +void DeconvDwInt8Post(int8_t *dst, int32_t *output_buffer, const int32_t *bias, int block_channel, int pixel_nums, + int out_multiplier, int left_shift, int right_shift, int32_t out_zp, int32_t acc_min, + int32_t acc_max) { + int8_t *dst_k = dst; + int32_t *buffer_k = output_buffer; + for (int k = 0; k < pixel_nums; k++) { + for (int c = 0; c < C4NUM; c++) { + buffer_k[c] += bias[c]; + buffer_k[c] = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(buffer_k[c] * (1 << (unsigned int)left_shift), out_multiplier), -right_shift); + buffer_k[c] += out_zp; + buffer_k[c] = MSMAX(buffer_k[c], acc_min); + buffer_k[c] = MSMIN(buffer_k[c], acc_max); + dst_k[c] = (buffer_k[c]); + } + dst_k += block_channel; + buffer_k += C4NUM; + } +} +#endif + +void DeconvDwInt8(int8_t *output_data, int32_t *output_buffer, const int16_t *input_data, const int16_t *weight_data, + const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, + int task_id) { + const int16_t *src = input_data; + int8_t *dst = output_data; + int buffer_size = conv_param->output_h_ * conv_param->output_w_ * C4NUM; + for (int b = 0; b < conv_param->output_batch_; b++) { + for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) { + memset(output_buffer, 0, buffer_size * sizeof(int32_t)); + const int16_t *src_data = src + oc * C4NUM; + const int16_t *weight = weight_data + oc * sliding->kernel_step_; + const int32_t *bias = bias_data + oc * C4NUM; + int8_t *dst_data = dst + oc * C4NUM; + DeconvDwInt8Border(output_buffer, src_data, weight, 0, sliding->top_, 0, conv_param->input_w_, conv_param, + sliding); + DeconvDwInt8Border(output_buffer, src_data, weight, sliding->bottom_, conv_param->input_h_, 0, + conv_param->input_w_, conv_param, sliding); + DeconvDwInt8Border(output_buffer, src_data, weight, sliding->top_, sliding->bottom_, 0, sliding->left_, + conv_param, sliding); + DeconvDwInt8Border(output_buffer, src_data, weight, sliding->top_, sliding->bottom_, sliding->right_, + conv_param->input_w_, conv_param, sliding); + + if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) { + int oh_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_; + int oh_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_; + int32_t *out_t = output_buffer + oh_h_start * sliding->in_h_step_ + oh_w_start * sliding->block_channel_; + const int16_t *in_t = + src_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_; +#ifdef ENABLE_ARM + DeconvDwInt8Center(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, + conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(int16_t), + sliding->block_channel_ * sizeof(int16_t), sliding->in_sh_step_ * sizeof(int32_t), + sliding->in_sw_step_ * sizeof(int32_t), sliding->in_kh_step_ * sizeof(int32_t), + sliding->in_kw_step_ * sizeof(int32_t)); +#else + DeconvDwInt8Center(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, + conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_, + sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_); +#endif + } + DeconvDwInt8Post(dst_data, output_buffer, bias, sliding->block_channel_, + conv_param->output_h_ * conv_param->output_w_, conv_param->conv_quant_arg_.quant_multiplier_[0], + conv_param->conv_quant_arg_.left_shift_[0], conv_param->conv_quant_arg_.right_shift_[0], + conv_param->conv_quant_arg_.output_quant_args_[0].zp_, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0]); + } // output C4 loop + src += sliding->out_step_; + dst += sliding->in_step_; + } // batch loop + // output nhwc4 +} +/*deconv depthwise int8 end*/ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/conv_depthwise_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/conv_depthwise_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..cab506b644c3ca8daa97a89c001469ef7e4f3744 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/conv_depthwise_int8.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_CONV_DEPTHWISE_H_ +#define NNACL_INT8_CONV_DEPTHWISE_H_ + +#include "nnacl/conv_parameter.h" +#include "nnacl/fp32/conv_depthwise_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void ConvDwInt8(int8_t *output_data, int32_t *output_row, const int8_t *input_data, const int16_t *weight_data, + const int32_t *bias_data, const ConvParameter *conv_param, int task_id); + +void ConvDw3x3Int8Pad(int8_t *output_data, const int8_t *input_data, const int16_t *weight_data, + const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding); + +void ConvDw3x3Int8(int8_t *output_data, int8_t *buffer, const int8_t *input_data, const int16_t *weight_data, + const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, + int task_id); + +void ConvDwInt8SW(int8_t *output_data, const int8_t *input_data, const int16_t *weight_data, const int32_t *bias_data, + const int8_t *input_zp, const int32_t *output_zp, const ConvParameter *conv_param, + const SlidingWindowParam *sliding, int task_id); + +void DeconvDwInt8(int8_t *output_data, int32_t *output_buffer, const int16_t *input_data, const int16_t *weight_data, + const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, + int task_id); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_CONV_DEPTHWISE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/conv_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/conv_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..6510a9aa8d2b9e0c9cd1dd14c6f99c675459a0df --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/conv_int8.c @@ -0,0 +1,913 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/conv_int8.h" + +#ifdef ENABLE_ARM32 +void PackInputSum16x4PerChannelArm32(const int8_t *input_value, int32_t *input_sum, const int32_t *filter_zp_ptr, + size_t plane_size, size_t input_channel, size_t output_channel) { + size_t hw4 = UP_ROUND(plane_size, C4NUM); + size_t ic16 = UP_ROUND(input_channel, C16NUM); + +#ifdef ENABLE_ARM32 + size_t oc_div2 = output_channel / C2NUM * C2NUM; + size_t oc_res2 = output_channel - oc_div2; + size_t inputsun_stride = hw4 * C2NUM * 4 - C4NUM * C2NUM * 4; + PreSum4x16Int8Peroc(input_value, input_sum, filter_zp_ptr, hw4, ic16, oc_div2, oc_res2, inputsun_stride); +#else + for (int ri = 0; ri < plane_size; ri++) { + int ri4div = ri / C4NUM, ri4mod = ri % C4NUM; + for (int ci = 0; ci < output_channel; ci++) { + int32_t tmp_sum_value = 0; + int ci2div = ci / C2NUM, ci2mod = ci % C2NUM; + int32_t filter_zp = filter_zp_ptr[ci]; + for (int di = 0; di < input_channel; di++) { + size_t di16div = di / C16NUM, di16mod = di % C16NUM; + int src_index = ri4div * C4NUM * ic16 + di16div * C16NUM * C4NUM + ri4mod * C16NUM + di16mod; + tmp_sum_value += input_value[src_index]; + } + int dst_index = ci2div * C2NUM * hw4 + ri * C2NUM + ci2mod; + input_sum[dst_index] = tmp_sum_value * filter_zp; + } + } +#endif + return; +} +#endif + +void PackInputSum16x4PerChannel(const int8_t *input_value, int32_t *input_sum, const int32_t *filter_zp_ptr, + size_t plane_size, size_t input_channel, size_t output_channel) { + size_t hw4 = UP_ROUND(plane_size, C4NUM); + size_t ic16 = UP_ROUND(input_channel, C16NUM); +#ifdef ENABLE_ARM64 + size_t oc_div4 = output_channel / C4NUM * C4NUM; + size_t oc_res4 = output_channel - oc_div4; + size_t inputsun_stride = hw4 * C4NUM * 4 - C4NUM * C4NUM * 4; + PreSum4x16Int8Peroc(input_value, input_sum, filter_zp_ptr, hw4, ic16, oc_div4, oc_res4, inputsun_stride); +#else + + for (int ri = 0; ri < plane_size; ri++) { + int ri4div = ri / C4NUM, ri4mod = ri % C4NUM; + for (int ci = 0; ci < output_channel; ci++) { + int32_t tmp_sum_value = 0; + int ci4div = ci / C4NUM, ci4mod = ci % C4NUM; + int32_t filter_zp = filter_zp_ptr[ci]; + for (int di = 0; di < input_channel; di++) { + size_t di16div = di / C16NUM, di16mod = di % C16NUM; + int src_index = ri4div * C4NUM * ic16 + di16div * C16NUM * C4NUM + ri4mod * C16NUM + di16mod; + tmp_sum_value += input_value[src_index]; + } + int dst_index = ci4div * C4NUM * hw4 + ri * C4NUM + ci4mod; + input_sum[dst_index] = tmp_sum_value * filter_zp; + } + } +#endif + return; +} + +void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel, + size_t output_channel, size_t plane_size, const int32_t *filter_zp, size_t inputsum_stride) { + int ic4 = UP_ROUND(input_channel, C4NUM); + int oc8 = UP_ROUND(output_channel, C8NUM); + int hw8 = UP_ROUND(plane_size, C8NUM); + size_t hw_8div = plane_size / C8NUM * C8NUM; + size_t oc_8div = output_channel / C8NUM * C8NUM; + size_t oc_8res = output_channel - oc_8div; + size_t ic_4div = input_channel / C4NUM * C4NUM; + + const int8_t *src_r = src_input; + int8_t *pack_r = packed_input; + int32_t *input_sum_r = input_sum; + + for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) { + const int8_t *src_ic = src_r; + int8_t *pack_ic = pack_r; + int32_t *input_sum_oc = input_sum_r; +#ifdef ENABLE_ARM64 + size_t src_stride = input_channel; + size_t ic_4res = input_channel - ic_4div; + size_t input_sum_stride = inputsum_stride * 4 - C8NUM * C8NUM * 4; + asm volatile( + "dup v16.4s, wzr \n" + "dup v17.4s, wzr \n" + + "mov x10, %[src_ic] \n" + "mov x11, %[pack_ic] \n" + + "mov x0, #0 \n" + "1: \n" + "cmp x0, %[ic_4div] \n" + "add x0, x0, #4\n" + "mov x12, x10 \n" + "add x10, x10, #4\n" + "blt 2f \n" + "cmp %[ic_4res], #0\n" + "beq 6f \n" + "cmp %[ic_4res], #1\n" + "beq 3f \n" + "cmp %[ic_4res], #2\n" + "beq 4f \n" + "cmp %[ic_4res], #3\n" + "beq 5f \n" + + "2: \n" + "ld1 {v0.s}[0], [x12], %[src_stride]\n" + "ld1 {v0.s}[1], [x12], %[src_stride]\n" + "ld1 {v0.s}[2], [x12], %[src_stride]\n" + "ld1 {v0.s}[3], [x12], %[src_stride]\n" + "ld1 {v1.s}[0], [x12], %[src_stride]\n" + "ld1 {v1.s}[1], [x12], %[src_stride]\n" + "ld1 {v1.s}[2], [x12], %[src_stride]\n" + "ld1 {v1.s}[3], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "b 1b \n" + + "3: \n" /* col res 1 */ + "dup v0.4s, wzr \n" + "dup v1.4s, wzr \n" + + "ld1 {v0.b}[0], [x12], %[src_stride]\n" + "ld1 {v0.b}[4], [x12], %[src_stride]\n" + "ld1 {v0.b}[8], [x12], %[src_stride]\n" + "ld1 {v0.b}[12], [x12], %[src_stride]\n" + "ld1 {v1.b}[0], [x12], %[src_stride]\n" + "ld1 {v1.b}[4], [x12], %[src_stride]\n" + "ld1 {v1.b}[8], [x12], %[src_stride]\n" + "ld1 {v1.b}[12], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "b 6f \n" + + "4: \n" /* col res 2 */ + "dup v0.4s, wzr \n" + "dup v1.4s, wzr \n" + + "ld1 {v0.h}[0], [x12], %[src_stride]\n" + "ld1 {v0.h}[2], [x12], %[src_stride]\n" + "ld1 {v0.h}[4], [x12], %[src_stride]\n" + "ld1 {v0.h}[6], [x12], %[src_stride]\n" + "ld1 {v1.h}[0], [x12], %[src_stride]\n" + "ld1 {v1.h}[2], [x12], %[src_stride]\n" + "ld1 {v1.h}[4], [x12], %[src_stride]\n" + "ld1 {v1.h}[6], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "b 6f \n" + + "5: \n" /* col res 3 */ + "dup v0.4s, wzr \n" + "dup v1.4s, wzr \n" + "add x13, x12, #2 \n" + + "ld1 {v0.h}[0], [x12], %[src_stride]\n" + "ld1 {v0.b}[2], [x13], %[src_stride]\n" + "ld1 {v0.h}[2], [x12], %[src_stride]\n" + "ld1 {v0.b}[6], [x13], %[src_stride]\n" + "ld1 {v0.h}[4], [x12], %[src_stride]\n" + "ld1 {v0.b}[10], [x13], %[src_stride]\n" + "ld1 {v0.h}[6], [x12], %[src_stride]\n" + "ld1 {v0.b}[14], [x13], %[src_stride]\n" + "ld1 {v1.h}[0], [x12], %[src_stride]\n" + "ld1 {v1.b}[2], [x13], %[src_stride]\n" + "ld1 {v1.h}[2], [x12], %[src_stride]\n" + "ld1 {v1.b}[6], [x13], %[src_stride]\n" + "ld1 {v1.h}[4], [x12], %[src_stride]\n" + "ld1 {v1.b}[10], [x13], %[src_stride]\n" + "ld1 {v1.h}[6], [x12], %[src_stride]\n" + "ld1 {v1.b}[14], [x13], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "b 6f \n" + + "6: \n" + "dup v0.4s, v16.s[0] \n" + "dup v1.4s, v16.s[1] \n" + "dup v2.4s, v16.s[2] \n" + "dup v3.4s, v16.s[3] \n" + "dup v4.4s, v17.s[0] \n" + "dup v5.4s, v17.s[1] \n" + "dup v6.4s, v17.s[2] \n" + "dup v7.4s, v17.s[3] \n" + "mov x4, #0 \n" + "mov x10, %[filter_zp] \n" + "mov x11, %[input_sum_oc] \n" + + "7: \n" + "cmp x4, %[oc_8div] \n" + "beq 8f \n" + "add x4, x4, #8\n" + "ld1 {v16.4s}, [x10], #16\n" + "ld1 {v17.4s}, [x10], #16\n" + + "mul v18.4s, v16.4s, v0.4s \n" + "mul v19.4s, v17.4s, v0.4s \n" + "st1 {v18.4s}, [x11], #16 \n" + "st1 {v19.4s}, [x11], #16 \n" + + "mul v20.4s, v16.4s, v1.4s \n" + "mul v21.4s, v17.4s, v1.4s \n" + "st1 {v20.4s}, [x11], #16 \n" + "st1 {v21.4s}, [x11], #16 \n" + + "mul v22.4s, v16.4s, v2.4s \n" + "mul v23.4s, v17.4s, v2.4s \n" + "st1 {v22.4s}, [x11], #16 \n" + "st1 {v23.4s}, [x11], #16 \n" + + "mul v24.4s, v16.4s, v3.4s \n" + "mul v25.4s, v17.4s, v3.4s \n" + "st1 {v24.4s}, [x11], #16 \n" + "st1 {v25.4s}, [x11], #16 \n" + + "mul v18.4s, v16.4s, v4.4s \n" + "mul v19.4s, v17.4s, v4.4s \n" + "st1 {v18.4s}, [x11], #16 \n" + "st1 {v19.4s}, [x11], #16 \n" + + "mul v20.4s, v16.4s, v5.4s \n" + "mul v21.4s, v17.4s, v5.4s \n" + "st1 {v20.4s}, [x11], #16 \n" + "st1 {v21.4s}, [x11], #16 \n" + + "mul v22.4s, v16.4s, v6.4s \n" + "mul v23.4s, v17.4s, v6.4s \n" + "st1 {v22.4s}, [x11], #16 \n" + "st1 {v23.4s}, [x11], #16 \n" + + "mul v24.4s, v16.4s, v7.4s \n" + "mul v25.4s, v17.4s, v7.4s \n" + "st1 {v24.4s}, [x11], #16 \n" + "st1 {v25.4s}, [x11], #16 \n" + + "add x11, x11, %[input_sum_stride] \n" + "b 7b \n" + + "8: \n" + "cmp %[oc_8res], #0\n" + "beq 17f \n" + + "dup v16.4s, wzr \n" + "dup v17.4s, wzr \n" + "cmp %[oc_8res], #1\n" + "beq 9f \n" + "cmp %[oc_8res], #2\n" + "beq 10f \n" + "cmp %[oc_8res], #3\n" + "beq 11f \n" + "cmp %[oc_8res], #4\n" + "beq 12f \n" + "cmp %[oc_8res], #5\n" + "beq 13f \n" + "cmp %[oc_8res], #6\n" + "beq 14f \n" + "cmp %[oc_8res], #7\n" + "beq 15f \n" + + "9: \n" + "ld1 {v16.s}[0], [x10] \n" + "b 16f \n" + + "10: \n" + "ld1 {v16.d}[0], [x10] \n" + "b 16f \n" + + "11: \n" + "ld1 {v16.d}[0], [x10] \n" + "add x10, x10, #8 \n" + "ld1 {v16.s}[2], [x10] \n" + "b 16f \n" + + "12: \n" + "ld1 {v16.4s}, [x10] \n" + "b 16f \n" + + "13: \n" + "ld1 {v16.4s}, [x10], #16\n" + "ld1 {v17.s}[0], [x10] \n" + "b 16f \n" + + "14: \n" + "ld1 {v16.4s}, [x10], #16\n" + "ld1 {v17.d}[0], [x10] \n" + "b 16f \n" + + "15: \n" + "ld1 {v16.4s}, [x10], #16\n" + "ld1 {v17.d}[0], [x10] \n" + "add x10, x10, #8 \n" + "ld1 {v17.s}[2], [x10] \n" + "b 16f \n" + + "16: \n" + "mul v18.4s, v16.4s, v0.4s \n" + "mul v19.4s, v17.4s, v0.4s \n" + "mul v20.4s, v16.4s, v1.4s \n" + "mul v21.4s, v17.4s, v1.4s \n" + "mul v22.4s, v16.4s, v2.4s \n" + "mul v23.4s, v17.4s, v2.4s \n" + "mul v24.4s, v16.4s, v3.4s \n" + "mul v25.4s, v17.4s, v3.4s \n" + "st1 {v18.4s}, [x11], #16 \n" + "st1 {v19.4s}, [x11], #16 \n" + "st1 {v20.4s}, [x11], #16 \n" + "st1 {v21.4s}, [x11], #16 \n" + "st1 {v22.4s}, [x11], #16 \n" + "st1 {v23.4s}, [x11], #16 \n" + "st1 {v24.4s}, [x11], #16 \n" + "st1 {v25.4s}, [x11], #16 \n" + + "mul v18.4s, v16.4s, v4.4s \n" + "mul v19.4s, v17.4s, v4.4s \n" + "mul v20.4s, v16.4s, v5.4s \n" + "mul v21.4s, v17.4s, v5.4s \n" + "mul v22.4s, v16.4s, v6.4s \n" + "mul v23.4s, v17.4s, v6.4s \n" + "mul v24.4s, v16.4s, v7.4s \n" + "mul v25.4s, v17.4s, v7.4s \n" + "st1 {v18.4s}, [x11], #16 \n" + "st1 {v19.4s}, [x11], #16 \n" + "st1 {v20.4s}, [x11], #16 \n" + "st1 {v21.4s}, [x11], #16 \n" + "st1 {v22.4s}, [x11], #16 \n" + "st1 {v23.4s}, [x11], #16 \n" + "st1 {v24.4s}, [x11], #16 \n" + "st1 {v25.4s}, [x11], #16 \n" + + "17: \n" + + : + : [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ filter_zp ] "r"(filter_zp), + [ input_sum_oc ] "r"(input_sum_oc), [ input_sum_stride ] "r"(input_sum_stride), [ src_stride ] "r"(src_stride), + [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ oc_8div ] "r"(oc_8div), [ oc_8res ] "r"(oc_8res) + : "x0", "x1", "x4", "x9", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", + "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25"); +#else + int32_t tmp_sum_value[8] = {0}; + for (int ici = 0; ici < ic_4div; ici += C4NUM) { + for (int i = 0; i < C8NUM; i++) { + tmp_sum_value[i] += src_ic[0 + i * input_channel]; + tmp_sum_value[i] += src_ic[1 + i * input_channel]; + tmp_sum_value[i] += src_ic[2 + i * input_channel]; + tmp_sum_value[i] += src_ic[3 + i * input_channel]; + pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel]; + pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel]; + pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel]; + pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel]; + } + src_ic += C4NUM; + pack_ic += C4NUM * C8NUM; + } + for (int ici = ic_4div; ici < input_channel; ici += 1) { + for (int i = 0; i < C8NUM; i++) { + tmp_sum_value[i] += src_ic[i * input_channel]; + pack_ic[i * C4NUM] = src_ic[i * input_channel]; + } + src_ic += 1; + pack_ic += 1; + } + + for (int ici = input_channel; ici < ic4; ici += 1) { + for (int i = 0; i < C8NUM; i++) { + pack_ic[i * C4NUM] = 0; + } + pack_ic += 1; + } + + for (int oci = 0; oci < oc_8div; oci += C8NUM) { + for (int ri = 0; ri < C8NUM; ri++) { + input_sum_oc[ri * C8NUM + 0] = tmp_sum_value[ri] * filter_zp[oci + 0]; + input_sum_oc[ri * C8NUM + 1] = tmp_sum_value[ri] * filter_zp[oci + 1]; + input_sum_oc[ri * C8NUM + 2] = tmp_sum_value[ri] * filter_zp[oci + 2]; + input_sum_oc[ri * C8NUM + 3] = tmp_sum_value[ri] * filter_zp[oci + 3]; + input_sum_oc[ri * C8NUM + 4] = tmp_sum_value[ri] * filter_zp[oci + 4]; + input_sum_oc[ri * C8NUM + 5] = tmp_sum_value[ri] * filter_zp[oci + 5]; + input_sum_oc[ri * C8NUM + 6] = tmp_sum_value[ri] * filter_zp[oci + 6]; + input_sum_oc[ri * C8NUM + 7] = tmp_sum_value[ri] * filter_zp[oci + 7]; + } + input_sum_oc += inputsum_stride; + } + if (oc_8div != output_channel) { + for (int oci = 0; oci < oc_8res; oci += 1) { + for (int ri = 0; ri < C8NUM; ri++) { + input_sum_oc[ri * C8NUM + oci] = tmp_sum_value[ri] * filter_zp[oc_8div + oci]; + } + } + for (int oci = oc_8res; oci < C8NUM; oci += 1) { + for (int ri = 0; ri < C8NUM; ri++) { + input_sum_oc[ri * C8NUM + oci] = 0; + } + } + } /* oc8 res done */ +#endif + src_r += input_channel * C8NUM; + pack_r += ic4 * C8NUM; + input_sum_r += C8NUM * C8NUM; + } + + if (hw_8div != plane_size) { + memset(pack_r, 0, C8NUM * ic4); + for (int hwi = hw_8div; hwi < plane_size; hwi += 1) { + int32_t *input_sum_oc = input_sum_r; + int32_t tmp_sum_value = 0; + const int8_t *src_ic = src_r; + int8_t *pack_ic = pack_r; + for (int ici = 0; ici < ic_4div; ici += C4NUM) { + tmp_sum_value += src_ic[0]; + tmp_sum_value += src_ic[1]; + tmp_sum_value += src_ic[2]; + tmp_sum_value += src_ic[3]; + pack_ic[0] = src_ic[0]; + pack_ic[1] = src_ic[1]; + pack_ic[2] = src_ic[2]; + pack_ic[3] = src_ic[3]; + src_ic += C4NUM; + pack_ic += C4NUM * C8NUM; + } + for (int ici = ic_4div; ici < input_channel; ici += 1) { + tmp_sum_value += src_ic[0]; + pack_ic[0] = src_ic[0]; + src_ic += 1; + pack_ic += 1; + } + + for (int oci = 0; oci < oc_8div; oci += C8NUM) { + for (int curoi = 0; curoi < C8NUM; curoi++) { + input_sum_oc[curoi] = tmp_sum_value * filter_zp[oci + curoi]; + } + input_sum_oc += inputsum_stride; + } + if (oc_8div != output_channel) { + for (int oci = 0; oci < oc_8res; oci += 1) { + input_sum_oc[oci] = tmp_sum_value * filter_zp[oc_8div + oci]; + } + for (int oci = oc_8res; oci < C8NUM; oci += 1) { + input_sum_oc[oci] = 0; + } + } /* oc8 res done */ + + src_r += input_channel; + pack_r += C4NUM; + input_sum_r += C8NUM; + } + + for (int hwi = plane_size; hwi < hw8; hwi++) { + for (int oc = 0; oc < oc8; oc++) { + int oc8div = oc / C8NUM, oc8res = oc % C8NUM; + input_sum[oc8div * inputsum_stride + hwi * C8NUM + oc8res] = 0; + } + } + } + return; +} + +void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel, + size_t plane_size, const ConvParameter *conv_param) { + int ic4 = UP_ROUND(input_channel, C4NUM); + size_t hw_8div = plane_size / C8NUM * C8NUM; + size_t ic_4div = input_channel / C4NUM * C4NUM; + int32_t filter_zp = conv_param->conv_quant_arg_.filter_quant_args_[0].zp_; + + const int8_t *src_r = src_input; + int8_t *pack_r = packed_input; + /* per layer */ + for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) { + const int8_t *src_ic = src_r; + int8_t *pack_ic = pack_r; + int32_t *input_sum_r = input_sum + hwi; +#ifdef ENABLE_ARM64 + size_t src_stride = input_channel; + size_t ic_4res = input_channel - ic_4div; + asm volatile( + "dup v16.4s, wzr \n" + "dup v17.4s, wzr \n" + "mov x14, %[input_sum_r] \n" + "dup v20.4s, %w[filter_zp] \n" + + "mov x10, %[src_ic] \n" + "mov x11, %[pack_ic] \n" + + "mov x0, #0 \n" + "1: \n" + "cmp x0, %[ic_4div] \n" + "add x0, x0, #4\n" + "mov x12, x10 \n" + "add x10, x10, #4\n" + "blt 2f \n" + "cmp %[ic_4res], #0\n" + "beq 6f \n" + "cmp %[ic_4res], #1\n" + "beq 3f \n" + "cmp %[ic_4res], #2\n" + "beq 4f \n" + "cmp %[ic_4res], #3\n" + "beq 5f \n" + + "2: \n" + "ld1 {v0.s}[0], [x12], %[src_stride]\n" + "ld1 {v0.s}[1], [x12], %[src_stride]\n" + "ld1 {v0.s}[2], [x12], %[src_stride]\n" + "ld1 {v0.s}[3], [x12], %[src_stride]\n" + "ld1 {v1.s}[0], [x12], %[src_stride]\n" + "ld1 {v1.s}[1], [x12], %[src_stride]\n" + "ld1 {v1.s}[2], [x12], %[src_stride]\n" + "ld1 {v1.s}[3], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "b 1b \n" + + "3: \n" /* col res 1 */ + "dup v0.4s, wzr \n" + "dup v1.4s, wzr \n" + + "ld1 {v0.b}[0], [x12], %[src_stride]\n" + "ld1 {v0.b}[4], [x12], %[src_stride]\n" + "ld1 {v0.b}[8], [x12], %[src_stride]\n" + "ld1 {v0.b}[12], [x12], %[src_stride]\n" + "ld1 {v1.b}[0], [x12], %[src_stride]\n" + "ld1 {v1.b}[4], [x12], %[src_stride]\n" + "ld1 {v1.b}[8], [x12], %[src_stride]\n" + "ld1 {v1.b}[12], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "b 6f \n" + + "4: \n" /* col res 2 */ + "dup v0.4s, wzr \n" + "dup v1.4s, wzr \n" + + "ld1 {v0.h}[0], [x12], %[src_stride]\n" + "ld1 {v0.h}[2], [x12], %[src_stride]\n" + "ld1 {v0.h}[4], [x12], %[src_stride]\n" + "ld1 {v0.h}[6], [x12], %[src_stride]\n" + "ld1 {v1.h}[0], [x12], %[src_stride]\n" + "ld1 {v1.h}[2], [x12], %[src_stride]\n" + "ld1 {v1.h}[4], [x12], %[src_stride]\n" + "ld1 {v1.h}[6], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "b 6f \n" + + "5: \n" /* col res 3 */ + "dup v0.4s, wzr \n" + "dup v1.4s, wzr \n" + "add x13, x12, #2 \n" + + "ld1 {v0.h}[0], [x12], %[src_stride]\n" + "ld1 {v0.b}[2], [x13], %[src_stride]\n" + "ld1 {v0.h}[2], [x12], %[src_stride]\n" + "ld1 {v0.b}[6], [x13], %[src_stride]\n" + "ld1 {v0.h}[4], [x12], %[src_stride]\n" + "ld1 {v0.b}[10], [x13], %[src_stride]\n" + "ld1 {v0.h}[6], [x12], %[src_stride]\n" + "ld1 {v0.b}[14], [x13], %[src_stride]\n" + "ld1 {v1.h}[0], [x12], %[src_stride]\n" + "ld1 {v1.b}[2], [x13], %[src_stride]\n" + "ld1 {v1.h}[2], [x12], %[src_stride]\n" + "ld1 {v1.b}[6], [x13], %[src_stride]\n" + "ld1 {v1.h}[4], [x12], %[src_stride]\n" + "ld1 {v1.b}[10], [x13], %[src_stride]\n" + "ld1 {v1.h}[6], [x12], %[src_stride]\n" + "ld1 {v1.b}[14], [x13], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + "add v16.4s, v16.4s, v0.4s \n" + "add v17.4s, v17.4s, v1.4s \n" + "b 6f \n" + + "6: \n" + "mul v16.4s, v16.4s, v20.4s \n" + "mul v17.4s, v17.4s, v20.4s \n" + + "st1 {v16.4s}, [x14], #16 \n" + "st1 {v17.4s}, [x14], #16 \n" + + : + : [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ input_sum_r ] "r"(input_sum_r), + [ src_stride ] "r"(src_stride), [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ filter_zp ] "r"(filter_zp) + : "x0", "x1", "x10", "x11", "x12", "x13", "x14", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", + "v20"); +#else + int32_t tmp_sum_value[8] = {0}; + for (int ici = 0; ici < ic_4div; ici += C4NUM) { + for (int i = 0; i < C8NUM; i++) { + tmp_sum_value[i] += src_ic[0 + i * input_channel]; + tmp_sum_value[i] += src_ic[1 + i * input_channel]; + tmp_sum_value[i] += src_ic[2 + i * input_channel]; + tmp_sum_value[i] += src_ic[3 + i * input_channel]; + pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel]; + pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel]; + pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel]; + pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel]; + } + src_ic += C4NUM; + pack_ic += C4NUM * C8NUM; + } + for (int ici = ic_4div; ici < input_channel; ici += 1) { + for (int i = 0; i < C8NUM; i++) { + tmp_sum_value[i] += src_ic[i * input_channel]; + pack_ic[i * C4NUM] = src_ic[i * input_channel]; + } + src_ic += 1; + pack_ic += 1; + } + + for (int ici = input_channel; ici < ic4; ici += 1) { + for (int i = 0; i < C8NUM; i++) { + pack_ic[i * C4NUM] = 0; + } + pack_ic += 1; + } + + for (int i = 0; i < C8NUM; i++) { + input_sum_r[i] = tmp_sum_value[i] * filter_zp; + } +#endif + src_r += input_channel * C8NUM; + pack_r += ic4 * C8NUM; + } + + if (hw_8div != plane_size) { + memset(pack_r, 0, C8NUM * ic4); + for (int hwi = hw_8div; hwi < plane_size; hwi += 1) { + int32_t tmp_sum_value = 0; + const int8_t *src_ic = src_r; + int8_t *pack_ic = pack_r; + for (int ici = 0; ici < ic_4div; ici += C4NUM) { + tmp_sum_value += src_ic[0]; + tmp_sum_value += src_ic[1]; + tmp_sum_value += src_ic[2]; + tmp_sum_value += src_ic[3]; + pack_ic[0] = src_ic[0]; + pack_ic[1] = src_ic[1]; + pack_ic[2] = src_ic[2]; + pack_ic[3] = src_ic[3]; + src_ic += C4NUM; + pack_ic += C4NUM * C8NUM; + } + for (int ici = ic_4div; ici < input_channel; ici += 1) { + tmp_sum_value += src_ic[0]; + pack_ic[0] = src_ic[0]; + src_ic += 1; + pack_ic += 1; + } + input_sum[hwi] = tmp_sum_value * filter_zp; + src_r += input_channel; + pack_r += C4NUM; + } + for (int hwi = plane_size; hwi < UP_ROUND(plane_size, C8NUM); hwi++) { + input_sum[hwi] = 0; + } + } + return; +} + +void PackInputSum16x4Int8(const int8_t *input, int32_t *input_sum, const int32_t *filter_zp, + const ConvParameter *conv_param) { + size_t hw = conv_param->output_h_ * conv_param->output_w_; + size_t hw4 = UP_ROUND(hw, C4NUM); + size_t ic16 = UP_ROUND(conv_param->input_channel_, C16NUM); + if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) { + PackInputSum16x4PerLayer(input, input_sum, conv_param->conv_quant_arg_.filter_quant_args_[0].zp_, hw4, ic16); + } else { +#ifdef ENABLE_ARM32 + PackInputSum16x4PerChannelArm32(input, input_sum, filter_zp, hw, conv_param->input_channel_, + conv_param->output_channel_); +#else + PackInputSum16x4PerChannel(input, input_sum, filter_zp, hw, conv_param->input_channel_, + conv_param->output_channel_); +#endif + } + return; +} + +void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int real_cal_num, + int block_index, const int32_t *filter_zp, int32_t *input_sum, + const ConvParameter *conv_param, bool per_channel, bool is_optimize) { + // input format : nhwc + int kernel_h = conv_param->kernel_h_; + int kernel_w = conv_param->kernel_w_; + int stride_h = conv_param->stride_h_; + int stride_w = conv_param->stride_w_; + int pad_h = conv_param->pad_u_; + int pad_w = conv_param->pad_l_; + int dilation_h = conv_param->dilation_h_; + int dilation_w = conv_param->dilation_w_; + int in_channel = conv_param->input_channel_; + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int out_w = conv_param->output_w_; + int kernel_plane = kernel_h * kernel_w; + NNACL_CHECK_ZERO_RETURN(out_w); + NNACL_CHECK_ZERO_RETURN(dilation_h); + NNACL_CHECK_ZERO_RETURN(dilation_w); + for (int i = 0; i < real_cal_num; i++) { + int block_start = block_index + i; + int input_h = block_start / out_w * stride_h - pad_h; + int input_w = block_start % out_w * stride_w - pad_w; + int input_stride = input_h * in_w * in_channel + input_w * in_channel; + int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h)); + int kh_e = MSMIN(kernel_h, UP_DIV(in_h - input_h, dilation_h)); + int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w)); + int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w)); + if (kw_e <= kw_s || kh_e <= kh_s) { + continue; + } + if (dilation_w == 1 && dilation_h == 1) { + for (int j = kh_s; j < kh_e; j++) { + int input_y_stride = j * in_w * in_channel + input_stride; + int input_x_stride = input_y_stride + kw_s * in_channel; + int input_plane_offset = (j * kernel_w + kw_s) * in_channel + i * in_channel * kernel_plane; + memcpy(matmul_input + input_plane_offset, input_data + input_x_stride, (kw_e - kw_s) * in_channel); + } // kernel_h loop + } else { + for (int j = kh_s; j < kh_e; j++) { + int input_y_stride = j * dilation_h * in_w * in_channel + input_stride; + for (int k = kw_s; k < kw_e; ++k) { + int input_x_stride = input_y_stride + k * dilation_w * in_channel; + int input_plane_offset = (j * kernel_w + k) * in_channel + i * in_channel * kernel_plane; + memcpy(matmul_input + input_plane_offset, input_data + input_x_stride, in_channel); + } + } // kernel_h loop + } + } // tile num loop + int deep = kernel_plane * in_channel; + if (is_optimize) { + if (per_channel) { + Conv1x1PreOptPeroc(matmul_input, packed_input, input_sum, deep, conv_param->output_channel_, real_cal_num, + filter_zp, C8NUM * C8NUM); + } else { + Conv1x1PreOptPert(matmul_input, packed_input, input_sum, deep, real_cal_num, conv_param); + } + } else { + RowMajor2Row16x4MajorInt8(matmul_input, packed_input, real_cal_num, deep); + if (per_channel) { +#ifdef ENABLE_ARM32 + PackInputSum16x4PerChannelArm32(packed_input, input_sum, filter_zp, real_cal_num, deep, + conv_param->output_channel_); +#else + PackInputSum16x4PerChannel(packed_input, input_sum, filter_zp, real_cal_num, deep, conv_param->output_channel_); +#endif + } else { + size_t hw4 = UP_ROUND(real_cal_num, C4NUM); + size_t ic16 = UP_ROUND(deep, C16NUM); + PackInputSum16x4PerLayer(packed_input, input_sum, conv_param->conv_quant_arg_.filter_quant_args_[0].zp_, hw4, + ic16); + } + } +} + +void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int8_t *packed_weight, + const int32_t *bias_data, int8_t *output_data, int32_t *filter_zp, int32_t *input_sum, int task_id, + ConvParameter *conv_param, MATMUL_OPT_R_FUNC matmul_func, bool is_optimize) { + int in_channel = conv_param->input_channel_; + int out_channel = conv_param->output_channel_; + int tile_n = conv_param->tile_num_; + int output_count = conv_param->output_h_ * conv_param->output_w_; + NNACL_CHECK_ZERO_RETURN(tile_n); + int output_tile_count = UP_DIV(output_count, tile_n); + int kernel_plane = conv_param->kernel_h_ * conv_param->kernel_w_; + int unit_size; + int input_sum_offset; + int up_round_oc; +#ifdef ENABLE_ARM32 + up_round_oc = UP_ROUND(out_channel, C2NUM); + unit_size = UP_ROUND(kernel_plane * in_channel, C16NUM); +#else + if (is_optimize) { + up_round_oc = UP_ROUND(out_channel, C8NUM); + unit_size = UP_ROUND(kernel_plane * in_channel, C4NUM); + } else { + up_round_oc = UP_ROUND(out_channel, C4NUM); + unit_size = UP_ROUND(kernel_plane * in_channel, C16NUM); + } +#endif + bool per_channel = false; + if (conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL) { + input_sum_offset = tile_n * up_round_oc; + per_channel = true; + } else { + input_sum_offset = tile_n; + per_channel = false; + } + + for (int b = 0; b < conv_param->input_batch_; b++) { + int in_batch_offset = b * in_channel * conv_param->input_h_ * conv_param->input_w_; + int out_batch_offset = b * out_channel * conv_param->output_h_ * conv_param->output_w_; + for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) { + int start_index = thread_id * tile_n; + int real_cal_num = (output_count - start_index) < tile_n ? (output_count - start_index) : tile_n; + int32_t *tmp_input_sum = input_sum + task_id * input_sum_offset; + int8_t *gemm_input = packed_input + task_id * unit_size * tile_n; + int8_t *matmul = matmul_input + task_id * kernel_plane * in_channel * tile_n; + memset(matmul, conv_param->conv_quant_arg_.input_quant_args_[0].zp_, kernel_plane * in_channel * tile_n); + Im2ColPackUnitInt8Opt(input_data + in_batch_offset, gemm_input, matmul, real_cal_num, start_index, filter_zp, + tmp_input_sum, conv_param, per_channel, is_optimize); + + int out_offset = thread_id * tile_n * out_channel + out_batch_offset; + int8_t *gemm_output = output_data + out_offset; +#ifdef ENABLE_ARM32 + MatmulInt8Neon32( + gemm_input, packed_weight, gemm_output, real_cal_num, out_channel, unit_size, tmp_input_sum, bias_data, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], + conv_param->conv_quant_arg_.output_quant_args_[0].zp_, conv_param->conv_quant_arg_.quant_multiplier_, + conv_param->conv_quant_arg_.left_shift_, conv_param->conv_quant_arg_.right_shift_, out_channel, per_channel); +#elif ENABLE_ARM64 + if (is_optimize) { + matmul_func(gemm_input, packed_weight, gemm_output, real_cal_num, out_channel, unit_size, out_channel, + tmp_input_sum, bias_data, conv_param->conv_quant_arg_.left_shift_, + conv_param->conv_quant_arg_.right_shift_, conv_param->conv_quant_arg_.quant_multiplier_, + conv_param->conv_quant_arg_.output_quant_args_[0].zp_, conv_param->conv_quant_arg_.out_act_min_[0], + conv_param->conv_quant_arg_.out_act_max_[0], per_channel); + } else { + MatmulInt8Neon64(gemm_input, packed_weight, gemm_output, UP_ROUND(real_cal_num, C4NUM), + UP_ROUND(out_channel, C4NUM), unit_size, tmp_input_sum, bias_data, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], + conv_param->conv_quant_arg_.output_quant_args_[0].zp_, + conv_param->conv_quant_arg_.quant_multiplier_, conv_param->conv_quant_arg_.left_shift_, + conv_param->conv_quant_arg_.right_shift_, real_cal_num, out_channel, out_channel, per_channel); + } +#else + MatMulInt8_8x8_r( + gemm_input, packed_weight, gemm_output, real_cal_num, out_channel, unit_size, out_channel, tmp_input_sum, + bias_data, conv_param->conv_quant_arg_.left_shift_, conv_param->conv_quant_arg_.right_shift_, + conv_param->conv_quant_arg_.quant_multiplier_, conv_param->conv_quant_arg_.output_quant_args_[0].zp_, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], per_channel); +#endif + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/conv_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/conv_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..94fc4ded500b118fd0dfc348d082667b2b092940 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/conv_int8.h @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_INT8_CONV_INT8_H_ +#define NNACL_INT8_CONV_INT8_H_ + +#include +#ifdef ENABLE_NEON +#include +#endif +#include "nnacl/pack.h" +#include "nnacl/op_base.h" +#include "nnacl/common_func.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/int8/quantize.h" +#include "nnacl/matmul_parameter.h" +#include "nnacl/int8/matmul_int8.h" +#include "nnacl/int8/common_func_int8.h" + +#ifdef __cplusplus +extern "C" { +#endif +// int8 conv common +void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int8_t *packed_weight, + const int32_t *bias_data, int8_t *output_data, int32_t *filter_zp, int32_t *input_sum, int task_id, + ConvParameter *conv_param, MATMUL_OPT_R_FUNC matmul_func, bool is_optimize); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_CONV_INT8_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/crop_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/crop_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..b1a1469bc6a50deebb01467281627055d78f1706 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/crop_int8.c @@ -0,0 +1,236 @@ +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/crop_parameter.h" +#include +#include +#include +#include "nnacl/int8/crop_int8.h" + +void Int8Crop1D(const int8_t *input, int8_t *output, int *output_shape, int64_t *in_offset, int task_id, + int thread_count, const CropQuantArg *quant) { + const int out_batch = output_shape[0]; + int64_t task_id_stride = thread_count > 1 ? UP_DIV(out_batch, thread_count) : out_batch; + if (task_id_stride <= 0) { + return; + } + + float in_scale = quant->in_args_.scale_; + int32_t in_zp = quant->in_args_.zp_; + float out_scale = quant->out_args_.scale_; + int32_t out_zp = quant->out_args_.zp_; + float scale = in_scale / out_scale; + float bias = -in_zp * scale; + + int n = task_id * task_id_stride; + if (n >= out_batch) { + return; + } + const int8_t *in_ptr = input + n + in_offset[0]; + int8_t *out_ptr = output + n; + int64_t out_dist_stride = MSMIN(out_batch - task_id * task_id_stride, task_id_stride); + if (fabs(in_scale - out_scale) <= FLT_EPSILON && in_zp == out_zp) { + memcpy(out_ptr, in_ptr, sizeof(int8_t) * out_dist_stride); + } else { + for (int i = 0; i < out_dist_stride; i++) { + int32_t output_tmp = round(in_ptr[i] * scale + bias) + out_zp; + if (output_tmp > quant->output_activation_max_) { + out_ptr[i] = quant->output_activation_max_; + } else if (output_tmp < quant->output_activation_min_) { + out_ptr[i] = quant->output_activation_min_; + } else { + out_ptr[i] = (int8_t)output_tmp; + } + } + } + return; +} + +void Int8Crop2D(const int8_t *input, int8_t *output, int *input_shape, int *output_shape, int64_t *in_offset, + int task_id, int thread_count, const CropQuantArg *quant) { + const int in_height = input_shape[1]; + const int out_batch = output_shape[0]; + const int out_height = output_shape[1]; + int64_t task_id_stride = thread_count > 1 ? UP_DIV(out_height, thread_count) : out_height; + if (task_id_stride <= 0) { + return; + } + + float in_scale = quant->in_args_.scale_; + int32_t in_zp = quant->in_args_.zp_; + float out_scale = quant->out_args_.scale_; + int32_t out_zp = quant->out_args_.zp_; + float scale = in_scale / out_scale; + float bias = -in_zp * scale; + + for (int n = 0; n < out_batch; n++) { + int h = task_id * task_id_stride; + if (h >= out_height) { + return; + } + const int8_t *in_ptr = input + (n + in_offset[0]) * in_height + h + in_offset[1]; + int8_t *out_ptr = output + n * out_height + h; + int64_t out_dist_stride = MSMIN(out_height - task_id * task_id_stride, task_id_stride); + if (fabs(in_scale - out_scale) <= FLT_EPSILON && in_zp == out_zp) { + memcpy(out_ptr, in_ptr, sizeof(int8_t) * out_dist_stride); + } else { + for (int i = 0; i < out_dist_stride; i++) { + int32_t output_tmp = round(in_ptr[i] * scale + bias) + out_zp; + if (output_tmp > quant->output_activation_max_) { + out_ptr[i] = quant->output_activation_max_; + } else if (output_tmp < quant->output_activation_min_) { + out_ptr[i] = quant->output_activation_min_; + } else { + out_ptr[i] = (int8_t)output_tmp; + } + } + } + } + return; +} + +void Int8Crop3D(const int8_t *input, int8_t *output, int *input_shape, int *output_shape, int64_t *in_offset, + int task_id, int thread_count, const CropQuantArg *quant) { + const int in_height = input_shape[1]; + const int in_width = input_shape[2]; + + const int out_batch = output_shape[0]; + const int out_height = output_shape[1]; + const int out_width = output_shape[2]; + + int64_t task_id_stride = thread_count > 1 ? UP_DIV(out_height, thread_count) : out_height; + if (task_id_stride <= 0) { + return; + } + + const int in_stride_h = in_width; + const int in_stride_n = in_stride_h * in_height; + + const int out_stride_h = out_width; + const int out_stride_n = out_stride_h * out_height; + + float in_scale = quant->in_args_.scale_; + int32_t in_zp = quant->in_args_.zp_; + float out_scale = quant->out_args_.scale_; + int32_t out_zp = quant->out_args_.zp_; + float scale = in_scale / out_scale; + float bias = -in_zp * scale; + + for (int n = 0; n < out_batch; n++) { + for (int t = 0; t < task_id_stride; t++) { + int h = t + task_id * task_id_stride; + if (h >= out_height) { + break; + } + const int8_t *in_ptr = input + (n + in_offset[0]) * in_stride_n + (h + in_offset[1]) * in_stride_h + in_offset[2]; + int8_t *out_ptr = output + n * out_stride_n + h * out_stride_h; + if (fabs(in_scale - out_scale) <= FLT_EPSILON && in_zp == out_zp) { + memcpy(out_ptr, in_ptr, sizeof(int8_t) * out_width); + } else { + for (int i = 0; i < out_width; i++) { + int32_t output_tmp = round(in_ptr[i] * scale + bias) + out_zp; + if (output_tmp > quant->output_activation_max_) { + out_ptr[i] = quant->output_activation_max_; + } else if (output_tmp < quant->output_activation_min_) { + out_ptr[i] = quant->output_activation_min_; + } else { + out_ptr[i] = (int8_t)output_tmp; + } + } + } + } + } + return; +} + +void Int8Crop4D(const int8_t *input, int8_t *output, int *input_shape, int *output_shape, int64_t *in_offset, + int task_id, int thread_count, const CropQuantArg *quant) { + const int in_height = input_shape[1]; + const int in_width = input_shape[2]; + const int in_channel = input_shape[3]; + + const int out_batch = output_shape[0]; + const int out_height = output_shape[1]; + const int out_width = output_shape[2]; + const int out_channel = output_shape[3]; + + int64_t task_id_stride = thread_count > 1 ? UP_DIV(out_height, thread_count) : out_height; + if (task_id_stride <= 0) { + return; + } + + const int in_stride_w = in_channel; + const int in_stride_h = in_channel * in_width; + const int in_stride_n = in_stride_h * in_height; + + const int out_stride_w = out_channel; + const int out_stride_h = out_channel * out_width; + const int out_stride_n = out_stride_h * out_height; + + float in_scale = quant->in_args_.scale_; + int32_t in_zp = quant->in_args_.zp_; + float out_scale = quant->out_args_.scale_; + int32_t out_zp = quant->out_args_.zp_; + float scale = in_scale / out_scale; + float bias = -in_zp * scale; + + for (int n = 0; n < out_batch; n++) { + for (int t = 0; t < task_id_stride; t++) { + int h = t + task_id * task_id_stride; + if (h >= out_height) { + break; + } + for (int w = 0; w < out_width; w++) { + const int8_t *in_ptr = input + (n + in_offset[0]) * in_stride_n + (h + in_offset[1]) * in_stride_h + + (w + in_offset[2]) * in_stride_w + in_offset[3]; + int8_t *out_ptr = output + n * out_stride_n + h * out_stride_h + w * out_stride_w; + if (fabs(in_scale - out_scale) <= FLT_EPSILON && in_zp == out_zp) { + memcpy(out_ptr, in_ptr, sizeof(int8_t) * out_channel); + } else { + for (int i = 0; i < out_channel; i++) { + int32_t output_tmp = round(in_ptr[i] * scale + bias) + out_zp; + if (output_tmp > quant->output_activation_max_) { + out_ptr[i] = quant->output_activation_max_; + } else if (output_tmp < quant->output_activation_min_) { + out_ptr[i] = quant->output_activation_min_; + } else { + out_ptr[i] = (int8_t)output_tmp; + } + } + } + } + } + } + return; +} + +void Int8Crop(const int8_t *input, int8_t *output, int *input_shape, int *output_shape, int64_t *in_offset, + int input_dim, int task_id, int thread_count, const CropQuantArg *quant) { + switch (input_dim) { + case 1: + Int8Crop1D(input, output, output_shape, in_offset, task_id, thread_count, quant); + break; + case 2: + Int8Crop2D(input, output, input_shape, output_shape, in_offset, task_id, thread_count, quant); + break; + case 3: + Int8Crop3D(input, output, input_shape, output_shape, in_offset, task_id, thread_count, quant); + break; + case 4: + Int8Crop4D(input, output, input_shape, output_shape, in_offset, task_id, thread_count, quant); + break; + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/crop_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/crop_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..06128bb60263f899d79659ef70b80d6ae78cdf13 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/crop_int8.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_CROP_INT8_H_ +#define NNACL_INT8_CROP_INT8_H_ +#include "nnacl/op_base.h" +#include "nnacl/crop_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void Int8Crop(const int8_t *input, int8_t *output, int *input_shape, int *output_shape, int64_t *in_offset, + int input_dim, int task_id, int thread_count, const CropQuantArg *quant); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_CROP_INT8_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/deconv_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/deconv_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..19882a7c5e2929dc5a80bf88668d64a14aa5f0cb --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/deconv_int8.c @@ -0,0 +1,150 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/deconv_int8.h" +#include "nnacl/int8/matmul_int8.h" +#include "nnacl/int8/common_func_int8.h" +int DeConvPostInt8C4(const int32_t *src, const int32_t *bias, int32_t *tmp, int8_t *out, int output_channel, + const ConvParameter *conv_param) { + /* row4x4-major(ih*iw x oc*kh*kw) -> row4-major(oh*ow x oc) */ + int input_plane = conv_param->input_w_ * conv_param->input_h_; + int kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_; + int output_plane = conv_param->output_w_ * conv_param->output_h_; + int oc4 = UP_DIV(output_channel, C4NUM); + int in_plane4 = UP_ROUND(input_plane, C4NUM); + + int src_iw_stride = C4NUM; + int src_ih_stride = conv_param->input_w_ * C4NUM; + int src_kw_stride = in_plane4 * C4NUM; + int src_kh_stride = in_plane4 * conv_param->kernel_w_ * C4NUM; + int dst_oh_stride = conv_param->output_w_ * C4NUM; + int dst_ow_stride = C4NUM; + int dst_kh_stride = conv_param->dilation_h_ * conv_param->output_w_ * C4NUM; + int dst_kw_stride = conv_param->dilation_w_ * C4NUM; + + for (int c = 0; c < oc4; c++) { + int32_t *dst_ptr = tmp + c * output_plane * C4NUM; + const int32_t *src_ptr = src + c * in_plane4 * kernel_plane * C4NUM; + memset(dst_ptr, 0, (size_t)output_plane * C4NUM * sizeof(int32_t)); + + for (int ih = 0; ih < conv_param->input_h_; ih++) { + for (int iw = 0; iw < conv_param->input_w_; iw++) { + int oh = ih * conv_param->stride_h_ - conv_param->pad_u_; + int ow = iw * conv_param->stride_w_ - conv_param->pad_l_; + + int kh_start = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_)); + int kh_end = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_)); + int kw_start = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_)); + int kw_end = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_)); + for (int kh = kh_start; kh < kh_end; kh++) { + for (int kw = kw_start; kw < kw_end; kw++) { + int src_index = ih * src_ih_stride + iw * src_iw_stride + kh * src_kh_stride + kw * src_kw_stride; + int dst_index = oh * dst_oh_stride + ow * dst_ow_stride + kh * dst_kh_stride + kw * dst_kw_stride; + int32_t *tmp_dst = dst_ptr + dst_index; + const int32_t *tmp_src = src_ptr + src_index; +#ifndef ENABLE_ARM64 + for (int i = 0; i < C4NUM; i++) { + tmp_dst[i] += tmp_src[i]; + } +#else + asm volatile( + "mov x0, %[tmp_src] \n" + "mov x1, %[tmp_dst] \n" + + "ld1 {v0.4s}, [x0] \n" + "ld1 {v1.4s}, [x1] \n" + + "add v0.4s, v0.4s, v1.4s \n" + + "st1 {v0.4s}, [x1] \n" + + : + : [ tmp_src ] "r"(tmp_src), [ tmp_dst ] "r"(tmp_dst) + : "x0", "x1", "v0", "v1"); +#endif + } /*kw*/ + } /*kh*/ + } /*iw*/ + } /*ih*/ + } /*oc*/ + + PostFuncInt8C4(tmp, bias, out, output_channel, (size_t)output_plane, conv_param->output_channel_, + conv_param->conv_quant_arg_.quant_multiplier_[0], conv_param->conv_quant_arg_.left_shift_[0], + conv_param->conv_quant_arg_.right_shift_[0], conv_param->conv_quant_arg_.output_quant_args_[0].zp_, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0]); + return NNACL_OK; +} + +void DeConvWeightTransInt8(const int8_t *src, int8_t *dst, int input_channel, int output_channel, int plane, + bool support_optimize_) { + /* optimize normal -> same layout */ + int ic16 = UP_ROUND(input_channel, C16NUM); + int oc4 = UP_ROUND(output_channel, C4NUM); + for (int ic = 0; ic < input_channel; ic++) { + int ic16div = ic / C16NUM, ic16mod = ic % C16NUM; + for (int oc = 0; oc < output_channel; oc++) { + int oc4div = oc / C4NUM, oc4mod = oc % C4NUM; + for (int hw = 0; hw < plane; hw++) { + int src_index = ic * output_channel * plane + hw * output_channel + oc; + int dst_index = hw * ic16 * oc4 + oc4div * ic16 * C4NUM + ic16div * C16NUM * C4NUM + oc4mod * C16NUM + ic16mod; + dst[dst_index] = src[src_index]; + } + } + } + return; +} + +void DeConvPackWeightSum(const int8_t *weight, int32_t *weight_sum, int32_t input_zp, int32_t filter_zp, int deep, + int col4, bool suppport_opt) { + int deep16 = UP_ROUND(deep, C16NUM); + int32_t zp_sum = filter_zp * input_zp * deep; + for (int c = 0; c < col4; c++) { + int c4div = c / C4NUM, c4mod = c % C4NUM; + int32_t value = 0; + for (int r = 0; r < deep; r++) { + int r16div = r / C16NUM, r16mod = r % C16NUM; + int src_index = c4div * deep16 * C4NUM + r16div * C4NUM * C16NUM + c4mod * C16NUM + r16mod; + value += weight[src_index]; + } + weight_sum[c] = zp_sum - value * input_zp; + } + return; +} + +void DeConvPackInputSum(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16, + bool suppport_opt) { + /* optimize normal -> same layout */ + PackInputSum16x4PerLayer(src, dst, filter_zp, row4, col16); + return; +} + +int DeConvInt8(const int8_t *input, const int8_t *weight, int32_t *output, const int32_t *weight_sum, + const int32_t *input_sum, size_t act_row, size_t act_col, size_t act_deep, + const ConvParameter *conv_param, MATMUL_OPT_R4_FUNC matmul_func) { + if (matmul_func != NULL) { + matmul_func(input, weight, output, act_row, act_col, act_deep, input_sum, weight_sum); + } else { + MatMulInt8_16x4(input, weight, output, act_row, act_col, act_deep, input_sum, weight_sum); + } + return NNACL_OK; +} + +int DeConvPostInt8(const int32_t *src, const int32_t *bias, int32_t *tmp, int8_t *out, int output_channel, + ConvParameter *conv_param, bool support_optimize) { + /* optimize normal -> same layout (C4) */ + int error_code = DeConvPostInt8C4(src, bias, tmp, out, output_channel, conv_param); + return error_code; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/deconv_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/deconv_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..203620d16ea3eedaf4dac2bfa50e71efcd7773c2 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/deconv_int8.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_INT8_DECONV_H_ +#define NNACL_INT8_DECONV_H_ + +#include +#include "nnacl/pack.h" +#include "nnacl/op_base.h" +#include "nnacl/errorcode.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/common_func.h" +#include "nnacl/int8/matmul_int8.h" + +#ifdef __cplusplus +extern "C" { +#endif +void DeConvPackWeightSum(const int8_t *weight, int32_t *weight_sum, int32_t input_zp, int32_t filter_zp, int deep, + int col4, bool suppport_opt); +void DeConvPackInputSum(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16, + bool suppport_opt); +void DeConvWeightTransInt8(const int8_t *src, int8_t *dst, int input_channel, int output_channel, int plane, + bool support_optimize_); + +int DeConvInt8(const int8_t *input, const int8_t *weight, int32_t *output, const int32_t *weight_sum, + const int32_t *input_sum, size_t act_row, size_t act_col, size_t act_deep, + const ConvParameter *conv_param, MATMUL_OPT_R4_FUNC matmul_func); +int DeConvPostInt8(const int32_t *src, const int32_t *bias, int32_t *tmp, int8_t *out, int output_channel, + ConvParameter *conv_param, bool support_optimize); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_DECONV_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/depth_to_space_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/depth_to_space_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..3b54c2930de6918fbe9748d22ecbb631d968953e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/depth_to_space_int8.c @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/int8/depth_to_space_int8.h" +#include + +void DepthToSpaceForNHWCInt8(const int8_t *input, int8_t *output, const int32_t *in_shape, DepthToSpaceArgs *param, + QuantArg *in_quant_arg, QuantArg *out_quant_arg) { + int32_t block_size = param->block_size_; + int32_t in_shape_dim2 = in_shape[2]; + int32_t in_shape_dim1 = in_shape[1]; + int64_t copy_size = block_size * param->out_stride_dim2_; + const float output_inverse_scale = 1.f / out_quant_arg->scale_; + float scale = in_quant_arg->scale_ * output_inverse_scale; + float bias = -in_quant_arg->zp_ * scale; + int32_t output_zp = out_quant_arg->zp_; + for (int i = 0; i < in_shape[0]; ++i) { + int64_t in_offset_n = i * param->in_stride_dim0_; + int64_t out_offset_n = i * param->out_stride_dim0_; + for (int j = 0; j < in_shape_dim1; ++j) { + int64_t in_offset_h = in_offset_n + j * param->in_stride_dim1_; + int64_t out_offset_h = out_offset_n + j * block_size * param->out_stride_dim1_; + for (int k = 0; k < in_shape_dim2; ++k) { + int64_t in_offset_w = in_offset_h + k * param->in_stride_dim2_; + int64_t out_offset_w = out_offset_h + k * block_size * param->out_stride_dim2_; + for (int l = 0; l < block_size; ++l) { + int64_t out_offset = out_offset_w + l * param->out_stride_dim1_; + int64_t in_offset = in_offset_w + l * block_size * param->out_stride_dim2_; + for (int m = 0; m < copy_size; ++m) { + int32_t output_tmp = round(input[in_offset + m] * scale + bias) + output_zp; + output_tmp = output_tmp > 127 ? 127 : output_tmp; + output_tmp = output_tmp < -128 ? -128 : output_tmp; + output[out_offset + m] = output_tmp; + } + } + } + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/depth_to_space_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/depth_to_space_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..3063634b1936b6743fde570a0e1a9a639d6bd1a9 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/depth_to_space_int8.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_INT8_DEPTH_TO_SPACE_INT8_H_ +#define NNACL_INT8_DEPTH_TO_SPACE_INT8_H_ + +#include "nnacl/depth_to_space_parameter.h" +#include "nnacl/int8/quantize.h" +#include "nnacl/kernel/depth_to_space.h" + +#ifdef __cplusplus +extern "C" { +#endif +void DepthToSpaceForNHWCInt8(const int8_t *input, int8_t *output, const int32_t *in_shape, DepthToSpaceArgs *param, + QuantArg *in_quant_arg, QuantArg *out_quant_arg); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_DEPTH_TO_SPACE_INT8_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/div_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/div_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..144758b5474e681ec206e5adc3679fdc3d73ba11 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/div_int8.c @@ -0,0 +1,67 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/div_int8.h" + +int DivInt8(const int8_t *input0_data, const int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, + const DivQuantArg *para) { + int index = 0; + for (; index < real_dst_count; ++index) { + const int32_t input0_val = para->in0_args_.zp_ + input0_data[index]; + const int32_t input1_val = para->in1_args_.zp_ + input1_data[index]; + if (input1_val == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + + int recip_shift; + const int32_t input1_inv = (input1_val > 0) ? ComputerReciprocal(input1_val, 31, &recip_shift) + : -ComputerReciprocal(-input1_val, 31, &recip_shift); + const int leading_bits = CountLeadingSignBits(input0_val); + const int32_t raw_data = + SaturatingRoundingDoublingHighMul(input0_val * (1 << (unsigned int)leading_bits), input1_inv); + const int total_shift = para->output_shift_ - recip_shift - leading_bits; + const int32_t raw_output = + RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(raw_data, para->output_multiplier_), -total_shift) + + para->out_args_.zp_; + output_data[index] = (int8_t)MSMAX(para->output_activation_min_, MSMIN(raw_output, para->output_activation_max_)); + } + return NNACL_OK; +} + +int DivScalarInt8(const int8_t *input0_data, const int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, + const DivQuantArg *para) { + int index = 0; + const int32_t input1_val = para->in1_args_.zp_ + *input1_data; + if (input1_val == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + int recip_shift; + const int32_t input1_inv = (input1_val > 0) ? ComputerReciprocal(input1_val, 31, &recip_shift) + : -ComputerReciprocal(-input1_val, 31, &recip_shift); + for (; index < real_dst_count; ++index) { + const int32_t input0_val = para->in0_args_.zp_ + input0_data[index]; + + const int leading_bits = CountLeadingSignBits(input0_val); + const int32_t raw_data = + SaturatingRoundingDoublingHighMul(input0_val * (1 << (unsigned int)leading_bits), input1_inv); + const int total_shift = para->output_shift_ - recip_shift - leading_bits; + const int32_t raw_output = + RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(raw_data, para->output_multiplier_), -total_shift) + + para->out_args_.zp_; + output_data[index] = (int8_t)MSMAX(para->output_activation_min_, MSMIN(raw_output, para->output_activation_max_)); + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/div_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/div_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..26507047f2ef12bf58f9de2023f6f7dd6d80473f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/div_int8.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_DIV_INT8_H_ +#define NNACL_INT8_DIV_INT8_H_ + +#include "nnacl/op_base.h" +#include "nnacl/errorcode.h" +#include "nnacl/int8/quantize.h" +#include "nnacl/int8/fixed_point.h" + +#ifdef __cplusplus +extern "C" { +#endif +int DivInt8(const int8_t *input0_data, const int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, + const DivQuantArg *para); + +int DivScalarInt8(const int8_t *input0_data, const int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, + const DivQuantArg *para); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_DIV_INT8_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/dynamic_gather_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/dynamic_gather_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..c376c66a01ccafe6a164fb3a269323e0c5d1efea --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/dynamic_gather_int8.c @@ -0,0 +1,76 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ +#include "nnacl/int8/dynamic_gather_int8.h" +#include "nnacl/op_base.h" + +void DynamicGather(const int8_t *input, int outer_size, int inner_size, int limit, const int32_t *indices, + int indices_element_size, float *output, const float *scale_in, const int32_t *zp_in) { + for (int m = 0; m < outer_size; ++m) { + const int8_t *int8_in_m = input + inner_size * m * limit; + float *int8_out_m = output + inner_size * m * indices_element_size; + for (int i = 0; i < indices_element_size; ++i) { + int index = indices[i]; + index = index < 0 ? index + limit : index; + const float scale = scale_in[index]; + const int zp = zp_in[index]; + float *out = int8_out_m + i * inner_size; + const int8_t *src = int8_in_m + index * inner_size; +#ifndef ENABLE_ARM64 + for (int j = 0; j < inner_size; ++j) { + out[j] = (src[j] - zp) * scale; + } +#else + int count_16 = DOWN_ROUND(inner_size, C16NUM); + DynamicGatherArm64(src, out, count_16, zp, scale); + for (int j = count_16; j < inner_size; ++j) { + out[j] = (src[j] - zp) * scale; + } +#endif + } + } + return; +} + +#ifdef ENABLE_FP16 +void DynamicGatherForFp16(const int8_t *input, int outer_size, int inner_size, int limit, const int32_t *indices, + int indices_element_size, float16_t *output, const float *scale_in, const int32_t *zp_in) { + for (int m = 0; m < outer_size; ++m) { + const int8_t *int8_in_m = input + inner_size * m * limit; + float16_t *int8_out_m = output + inner_size * m * indices_element_size; + for (int i = 0; i < indices_element_size; ++i) { + int index = indices[i]; + index = index < 0 ? index + limit : index; + const float scale = scale_in[index]; + const int zp = zp_in[index]; + float16_t *out = int8_out_m + i * inner_size; + const int8_t *src = int8_in_m + index * inner_size; +#ifndef ENABLE_ARM64 + for (int j = 0; j < inner_size; ++j) { + out[j] = (float16_t)(src[j] - zp) * scale; + } +#else + int count_16 = DOWN_ROUND(inner_size, C16NUM); + DynamicGatherArm64ForFp16(src, out, count_16, zp, scale); + for (int j = count_16; j < inner_size; ++j) { + out[j] = (float16_t)((src[j] - zp) * scale); + } +#endif + } + } + return; +} +#endif diff --git a/mindspore-lite/src/extendrt/kernel/cuda/unique.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/dynamic_gather_int8.h similarity index 40% rename from mindspore-lite/src/extendrt/kernel/cuda/unique.h rename to mindspore-lite/ops/kernel/cpu/nnacl/int8/dynamic_gather_int8.h index 9c217ebdebec852776c6378d5286052cea0a2f7f..fc3853d074372bbb1485967f8d122153ff365903 100644 --- a/mindspore-lite/src/extendrt/kernel/cuda/unique.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/dynamic_gather_int8.h @@ -14,27 +14,27 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_CUDA_UNIQUE_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_CUDA_UNIQUE_H_ +#ifndef NNACL_INT8_DYNAMIC_GATHER_INT8_H_ +#define NNACL_INT8_DYNAMIC_GATHER_INT8_H_ -#include -#include -#include "src/extendrt/kernel/cuda/cuda_kernel.h" -#include "cuda_impl/cuda_class/unique_helper.h" +#include "nnacl/op_base.h" +#include "nnacl/int8/quantize.h" -namespace mindspore::kernel { -class UniqueCudaKernel : public CudaKernel { - public: - UniqueCudaKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs, const lite::InnerContext *ctx) - : CudaKernel(parameter, inputs, outputs, ctx) {} - ~UniqueCudaKernel() override = default; - int Prepare() override; - int PostProcess() override; - int Run() override; +#ifdef __cplusplus +extern "C" { +#endif +void DynamicGather(const int8_t *input, int outer_size, int inner_size, int limit, const int32_t *indices, + int indices_element_size, float *output, const float *scale_in, const int32_t *zp_in); +void DynamicGatherArm64(const int8_t *src, float *output, int count_16, int zp, float scale); + +#ifdef ENABLE_FP16 +void DynamicGatherForFp16(const int8_t *input, int outer_size, int inner_size, int limit, const int32_t *indices, + int indices_element_size, float16_t *output, const float *scale_in, const int32_t *zp_in); +void DynamicGatherArm64ForFp16(const int8_t *src, float16_t *output, int count_16, int zp, float scale); +#endif - private: - std::shared_ptr> unique_helper_{nullptr}; -}; -} // namespace mindspore::kernel +#ifdef __cplusplus +} #endif + +#endif // NNACL_INT8_DYNAMIC_GATHER_INT8_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/dynamic_matmul_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/dynamic_matmul_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..c30a2fe25153f28da8b0fc8eddaeb5c4adaa2ba2 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/dynamic_matmul_int8.c @@ -0,0 +1,420 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/dynamic_matmul_int8.h" +#include "nnacl/int8/fixed_point.h" + +void DynamicMatmul4x4x16AIWI(const int8_t *a, const int8_t *b, float *out, size_t deep4, const float *multi_scales, + const float *bias, size_t row, size_t col, size_t stride, const int32_t *a_sums, + const int32_t *b_sums, int64_t a_zp, int64_t b_zp_sum, int64_t act_type, int64_t mode) { + /* * + * row4x4-major * row4x16-major => (int8)row-major + * support activation per-layer symmetric && weight per-layer/per-channel symmetric + * */ + for (int r = 0; r < row; r++) { + int64_t s2 = a_sums[r]; + for (int c = 0; c < col; c++) { + int r4div = r / C4NUM, r4mod = r % C4NUM; + int c16div = c / C16NUM, c16mod = c % C16NUM; + int32_t s1 = 0; + for (int d = 0; d < deep4; d++) { + int d4div = d / C4NUM, d4mod = d % C4NUM; + size_t ai = r4div * deep4 * C4NUM + d4div * C4NUM * C4NUM + r4mod * C4NUM + d4mod; + size_t bi = c16div * deep4 * C16NUM + d4div * C4NUM * C16NUM + c16mod * C4NUM + d4mod; + s1 += a[ai] * b[bi]; + } + int64_t s3 = b_sums[c] * a_zp; + int64_t s4 = a_zp * b_zp_sum; + size_t ci = r * stride / sizeof(float) + c; + int scale_offset = mode == 0 ? 0 : (mode == 1 ? c : (mode == C2NUM ? r : r * C16NUM + c)); + out[ci] = multi_scales[scale_offset] * (s1 - s2 - s3 + s4); + if (bias != NULL) { + out[ci] += bias[c]; + } + if (act_type == ActType_Relu) { + out[ci] = MSMAX(0, out[ci]); + } else if (act_type == ActType_Relu6) { + out[ci] = MSMAX(0, out[ci]); + out[ci] = MSMIN(C6NUM, out[ci]); + } + } + } + return; +} + +#ifdef ENABLE_FP16 +void DynamicMatmul4x4x16AIWIForFp16(const int8_t *a, const int8_t *b, float16_t *out, size_t deep4, + const float *multi_scales, const float16_t *bias, size_t row, size_t col, + size_t stride, const int32_t *a_sums, const int32_t *b_sums, int64_t a_zp, + int64_t b_zp_sum, int64_t act_type, int64_t mode) { + /* * + * row4x4-major * row4x16-major => (int8)row-major + * support activation per-layer symmetric && weight per-layer/per-channel symmetric + * */ + for (int r = 0; r < row; r++) { + int64_t s2 = a_sums[r]; + for (int c = 0; c < col; c++) { + int r4div = r / C4NUM, r4mod = r % C4NUM; + int c16div = c / C16NUM, c16mod = c % C16NUM; + int32_t s1 = 0; + for (int d = 0; d < deep4; d++) { + int d4div = d / C4NUM, d4mod = d % C4NUM; + size_t ai = r4div * deep4 * C4NUM + d4div * C4NUM * C4NUM + r4mod * C4NUM + d4mod; + size_t bi = c16div * deep4 * C16NUM + d4div * C4NUM * C16NUM + c16mod * C4NUM + d4mod; + s1 += a[ai] * b[bi]; + } + int64_t s3 = b_sums[c] * a_zp; + int64_t s4 = a_zp * b_zp_sum; + size_t ci = r * stride / sizeof(float16_t) + c; + int scale_offset = mode == 0 ? 0 : (mode == 1 ? c : (mode == C2NUM ? r : r * C16NUM + c)); + out[ci] = multi_scales[scale_offset] * (s1 - s2 - s3 + s4); + if (bias != NULL) { + out[ci] += bias[c]; + } + if (act_type == ActType_Relu) { + out[ci] = MSMAX(0, out[ci]); + } else if (act_type == ActType_Relu6) { + out[ci] = MSMAX(0, out[ci]); + out[ci] = MSMIN(C6NUM, out[ci]); + } + } + } + return; +} +#endif + +void DynamicMatmul4x16x4AIWI(const int8_t *a, const int8_t *b, const float *bias, float *dst, int row, int col, + int deep, int deep16, size_t stride, int input_zp, const float *input_scale, + const float *filter_scale, int filter_zp, bool input_per_channel, bool filter_per_channel, + int64_t act_type) { + /* * + * row4x16-major * row16x4-major => (int8)row-major + * support activation per-layer symmetric && weight per-layer/per-channel symmetric + * */ + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int r4div = r / C4NUM, r4mod = r % C4NUM; + int c4div = c / C4NUM, c4mod = c % C4NUM; + int32_t value = 0; + int32_t s0 = 0; + int32_t s1 = 0; + int32_t s2 = 0; + int32_t s3 = 0; + for (int d = 0; d < deep; d++) { + int d16div = d / C16NUM, d16mod = d % C16NUM; + size_t ai = r4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + r4mod * C16NUM + d16mod; + size_t bi = c4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + c4mod * C16NUM + d16mod; + s0 += a[ai] * b[bi]; + s1 += filter_zp * a[ai]; + s2 += input_zp * b[bi]; + s3 += input_zp * filter_zp; + } + value = s0 - s1 - s2 + s3; + int input_quant_index = input_per_channel ? r : 0; + int filter_quant_index = filter_per_channel ? c : 0; + float multi_scale = input_scale[input_quant_index] * filter_scale[filter_quant_index]; + size_t ci = r * stride + c; + dst[ci] = multi_scale * value; + if (bias != NULL) { + dst[ci] += bias[c]; + } + if (act_type == ActType_Relu) { + dst[ci] = MSMAX(0, dst[ci]); + } else if (act_type == ActType_Relu6) { + dst[ci] = MSMAX(0, dst[ci]); + dst[ci] = MSMIN(C6NUM, dst[ci]); + } + } + } + return; +} + +#ifdef ENABLE_ARM64 +void PackInput4x4Asm(const int8_t *src_ic, int8_t *pack_ic, size_t ic_4div, size_t input_channel) { + size_t src_stride = input_channel; + size_t ic_4res = input_channel - ic_4div; + asm volatile( + "dup v2.4s, wzr \n" + + "mov x10, %[src_ic] \n" + "mov x11, %[pack_ic] \n" + + "mov x15, #0 \n" + "1: \n" + "cmp x15, %[ic_4div] \n" + "add x15, x15, #4\n" + "mov x12, x10 \n" + "add x10, x10, #4\n" + "blt 2f \n" + "cmp %[ic_4res], #0\n" + "beq 6f \n" + "cmp %[ic_4res], #1\n" + "beq 3f \n" + "cmp %[ic_4res], #2\n" + "beq 4f \n" + "cmp %[ic_4res], #3\n" + "beq 5f \n" + + "2: \n" + "ld1 {v0.s}[0], [x12], %[src_stride]\n" + "ld1 {v0.s}[1], [x12], %[src_stride]\n" + "ld1 {v0.s}[2], [x12], %[src_stride]\n" + "ld1 {v0.s}[3], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + + "b 1b \n" + + "3: \n" /* ic res 1 */ + "dup v0.4s, wzr \n" + + "ld1 {v0.b}[0], [x12], %[src_stride]\n" + "ld1 {v0.b}[4], [x12], %[src_stride]\n" + "ld1 {v0.b}[8], [x12], %[src_stride]\n" + "ld1 {v0.b}[12], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + + "b 6f \n" + + "4: \n" /* ic res 2 */ + "dup v0.4s, wzr \n" + + "ld1 {v0.h}[0], [x12], %[src_stride]\n" + "ld1 {v0.h}[2], [x12], %[src_stride]\n" + "ld1 {v0.h}[4], [x12], %[src_stride]\n" + "ld1 {v0.h}[6], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + + "b 6f \n" + + "5: \n" /* ic res 3 */ + "dup v0.4s, wzr \n" + "add x13, x12, #2 \n" + + "ld1 {v0.h}[0], [x12], %[src_stride]\n" + "ld1 {v0.b}[2], [x13], %[src_stride]\n" + "ld1 {v0.h}[2], [x12], %[src_stride]\n" + "ld1 {v0.b}[6], [x13], %[src_stride]\n" + "ld1 {v0.h}[4], [x12], %[src_stride]\n" + "ld1 {v0.b}[10], [x13], %[src_stride]\n" + "ld1 {v0.h}[6], [x12], %[src_stride]\n" + "ld1 {v0.b}[14], [x13], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + + "b 6f \n" + + "6: \n" + + : + : [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ src_stride ] "r"(src_stride), [ ic_4div ] "r"(ic_4div), + [ ic_4res ] "r"(ic_4res) + : "x10", "x11", "x12", "x13", "x14", "x15", "v0", "v1", "v2", "v3"); +} +#endif + +void PackInput4x4(const int8_t *src_input, int8_t *packed_input, size_t input_channel, size_t plane_size) { + int ic4 = UP_ROUND(input_channel, C4NUM); + size_t hw_4div = plane_size / C4NUM * C4NUM; + size_t ic_4div = input_channel / C4NUM * C4NUM; + + const int8_t *src_r = src_input; + int8_t *pack_r = packed_input; + /* per layer */ + for (int hwi = 0; hwi < hw_4div; hwi += C4NUM) { + const int8_t *src_ic = src_r; + int8_t *pack_ic = pack_r; +#ifdef ENABLE_ARM64 + PackInput4x4Asm(src_ic, pack_ic, ic_4div, input_channel); +#else + for (int ici = 0; ici < ic_4div; ici += C4NUM) { + for (size_t i = 0; i < C4NUM; i++) { + pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel]; + pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel]; + pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel]; + pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel]; + } + src_ic += C4NUM; + pack_ic += C4NUM * C4NUM; + } + for (int ici = ic_4div; ici < input_channel; ici += 1) { + for (int i = 0; i < C4NUM; i++) { + pack_ic[i * C4NUM] = src_ic[i * input_channel]; + } + src_ic += 1; + pack_ic += 1; + } + + for (int ici = input_channel; ici < ic4; ici += 1) { + for (int i = 0; i < C4NUM; i++) { + pack_ic[i * C4NUM] = 0; + } + pack_ic += 1; + } +#endif + src_r += input_channel * C4NUM; + pack_r += ic4 * C4NUM; + } + + if (hw_4div != plane_size) { + memset(pack_r, 0, C4NUM * ic4); + for (int hwi = hw_4div; hwi < plane_size; hwi += 1) { + const int8_t *src_ic = src_r; + int8_t *pack_ic = pack_r; + for (int ici = 0; ici < ic_4div; ici += C4NUM) { + pack_ic[0] = src_ic[0]; + pack_ic[1] = src_ic[1]; + pack_ic[2] = src_ic[2]; + pack_ic[3] = src_ic[3]; + src_ic += C4NUM; + pack_ic += C4NUM * C4NUM; + } + src_r += input_channel; + pack_r += C4NUM; + } + } + return; +} + +// For matmul input a transpose case +void PackInput2Col4x4(const int8_t *src_input, int8_t *packed_input, int row, int col, int row_stride) { + const int row_tile = C4NUM; + int row_align = UP_ROUND(row, row_tile); + int row_div = row / row_tile * row_tile; + const int row_res = row - row_div; + + const int col_tile = C4NUM; + int col_div = col / col_tile * col_tile; + const int col_res = col - col_div; + + const int8_t *src_ic = NULL; + int8_t *packed_ic = NULL; + for (int c = 0; c < col_div; c += C4NUM) { + int r = 0; + src_ic = src_input + c; + packed_ic = packed_input + c * row_align; +#ifdef ENABLE_ARM64 + size_t row_stride_int64 = row_stride; + asm volatile( + "mov w10, %w[row]\n" + "mov x11, %[src_ic]\n" + "mov x12, %[packed_ic]\n" + "cmp w10, wzr\n" + "beq 1f\n" + "2:\n" + "subs w10, w10, #4\n" + "ld1 {v0.s}[0], [x11], %[row_stride]\n" + "ld1 {v1.s}[0], [x11], %[row_stride]\n" + "ld1 {v0.s}[1], [x11], %[row_stride]\n" + "ld1 {v1.s}[1], [x11], %[row_stride]\n" + "zip1 v2.8b, v0.8b, v1.8b\n" + "zip2 v3.8b, v0.8b, v1.8b\n" + "zip1 v4.4h, v2.4h, v3.4h\n" + "zip2 v5.4h, v2.4h, v3.4h\n" + "st1 {v4.4h, v5.4h}, [x12], #16\n" + + "bgt 2b\n" + "1:\n" + + : + : [ src_ic ] "r"(src_ic), [ packed_ic ] "r"(packed_ic), [ row ] "r"(row_div), [ row_stride ] "r"(row_stride_int64) + : "memory", "w10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12"); + packed_ic += C4NUM * row_div; + src_ic += row_div * row_stride; +#else + for (; r < row_div; r += C4NUM) { + for (int i = 0; i < row_tile; i++) { + packed_ic[0 * row_tile + i] = src_ic[i * row_stride + 0]; + packed_ic[1 * row_tile + i] = src_ic[i * row_stride + 1]; + packed_ic[2 * row_tile + i] = src_ic[i * row_stride + 2]; + packed_ic[3 * row_tile + i] = src_ic[i * row_stride + 3]; + } + packed_ic += C16NUM; + src_ic += row_tile * row_stride; + } +#endif + for (r = 0; r < row_res; ++r) { + for (int i = 0; i < C4NUM; ++i) { + packed_ic[i * row_tile + r] = src_ic[r * row_stride + i]; + } + } + } + if (col_res == 0) { + return; + } + src_ic = src_input + col_div; + packed_ic = packed_input + row_align * col_div; + for (int r = 0; r < row_div; r += row_tile) { + for (int i = 0; i < col_res; ++i) { + packed_ic[i * row_tile + 0] = src_ic[0 * row_stride + i]; + packed_ic[i * row_tile + 1] = src_ic[1 * row_stride + i]; + packed_ic[i * row_tile + 2] = src_ic[2 * row_stride + i]; + packed_ic[i * row_tile + 3] = src_ic[3 * row_stride + i]; + } + src_ic += row_tile * row_stride; + packed_ic += row_tile * col_tile; + } + + for (int r = 0; r < row_res; ++r) { + for (int c = 0; c < col_res; ++c) { + packed_ic[c * row_tile + r] = src_ic[r * row_stride + c]; + } + } +} + +void CalcWeightSums(const int8_t *weight, int row, int col, int32_t *dst, DataOrder order) { + if (order == RowMajor) { + for (int c = 0; c < col; ++c) { + int sum = 0; + for (int r = 0; r < row; ++r) { + sum += weight[r * col + c]; + } + dst[c] = sum; + } + } else { + for (int c = 0; c < col; ++c) { + int sum = 0; + for (int r = 0; r < row; ++r) { + sum += weight[c * row + r]; + } + dst[c] = sum; + } + } + return; +} + +void CalcPartWeightSums(const int8_t *weight, int row, int stride, int cur_col, int32_t *dst, DataOrder order) { + if (order == RowMajor) { + for (int c = 0; c < cur_col; ++c) { + int sum = 0; + for (int r = 0; r < row; ++r) { + sum += weight[r * stride + c]; + } + dst[c] = sum; + } + } else { + for (int c = 0; c < cur_col; ++c) { + int sum = 0; + for (int r = 0; r < row; ++r) { + sum += weight[c * row + r]; + } + dst[c] = sum; + } + } + return; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/dynamic_matmul_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/dynamic_matmul_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..6b46b7e73dea2c7e31ce61434fcc13ad472e7990 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/dynamic_matmul_int8.h @@ -0,0 +1,74 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_DYNAMIC_MATMUL_H_ +#define NNACL_INT8_DYNAMIC_MATMUL_H_ + +#include +#include "nnacl/op_base.h" +#include "nnacl/matmul_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void PackInput2Col4x4(const int8_t *src_input, int8_t *packed_input, int row, int col, int row_stride); +void PackInput4x4(const int8_t *src_input, int8_t *packed_input, size_t input_channel, size_t plane_size); +void DynamicMatmul4x16x4AIWI(const int8_t *a, const int8_t *b, const float *bias, float *dst, int row, int col, + int deep, int deep16, size_t stride, int input_zp, const float *input_scale, + const float *filter_scale, int filter_zp, bool input_per_channel, bool filter_per_channel, + int64_t act_type); +void CalcWeightSums(const int8_t *weight, int row, int col, int32_t *dst, DataOrder order); +void CalcPartWeightSums(const int8_t *weight, int row, int stride, int cur_col, int32_t *dst, DataOrder order); +#if defined(ENABLE_ARM64) && !defined(USE_AOS_GCC_TOOLCHAIN) +/* + * mode is used to distinguish different quantization scenarios, whose value is 0-3. + * 0: TensorByTensor, 1: TensorByChannel, 2: ChannelByTensor, 3: ChannelByChannel. + */ +void DynamicMatmulSdot4x4x16AIWI(const int8_t *a, const int8_t *b, float *out, size_t deep4, const float *multi_scales, + const float *bias, size_t row, size_t col, size_t stride, const int32_t *a_sums, + const int32_t *b_sums, int64_t a_zp, int64_t b_zp_sum, int64_t act_type, int64_t mode); +#endif +/* + * mode is used to distinguish different quantization scenarios, whose value is 0-3. + * 0: TensorByTensor, 1: TensorByChannel, 2: ChannelByTensor, 3: ChannelByChannel. + */ +void DynamicMatmul4x4x16AIWI(const int8_t *a, const int8_t *b, float *out, size_t deep4, const float *multi_scales, + const float *bias, size_t row, size_t col, size_t stride, const int32_t *a_sums, + const int32_t *b_sums, int64_t a_zp, int64_t b_zp_sum, int64_t act_type, int64_t mode); +#ifdef ENABLE_FP16 +/* + * mode is used to distinguish different quantization scenarios, whose value is 0-3. + * 0: TensorByTensor, 1: TensorByChannel, 2: ChannelByTensor, 3: ChannelByChannel. + */ +void DynamicMatmul4x4x16AIWIForFp16(const int8_t *a, const int8_t *b, float16_t *out, size_t deep4, + const float *multi_scales, const float16_t *bias, size_t row, size_t col, + size_t stride, const int32_t *a_sums, const int32_t *b_sums, int64_t a_zp, + int64_t b_zp_sum, int64_t act_type, int64_t mode); +/* + * mode is used to distinguish different quantization scenarios, whose value is 0-3. + * 0: TensorByTensor, 1: TensorByChannel, 2: ChannelByTensor, 3: ChannelByChannel. + */ +void DynamicMatmulSdot4x4x16AIWIForFp16(const int8_t *a, const int8_t *b, float16_t *out, size_t deep4, + const float *multi_scales, const float16_t *bias, size_t row, size_t col, + size_t stride, const int32_t *a_sums, const int32_t *b_sums, int64_t a_zp, + int64_t b_zp_sum, int64_t act_type, int64_t mode); +#endif + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_DYNAMIC_MATMUL_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/dynamic_quant_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/dynamic_quant_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..c1041ddfcda62e87130dc7d7253729c3d3c54bb4 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/dynamic_quant_int8.c @@ -0,0 +1,91 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/dynamic_quant_int8.h" + +void CalculateMinMaxFp32(const float *data, int count, float *real_min, float *real_max) { + if (count == 0) { + return; + } +#ifndef ENABLE_ARM64 + for (int i = 0; i < count; ++i) { + *real_min = data[i] < *real_min ? data[i] : *real_min; + *real_max = data[i] > *real_max ? data[i] : *real_max; + } +#else + // avoid to compile optimize. + volatile int count_4 = DOWN_ROUND(count, C4NUM); + asm volatile( + "mov x4, %[data]\n" // reload data + "mov w5, %w[count_4]\n" // reload count + "ld1 {v31.4s}, [x4]\n" // min + "ld1 {v30.4s}, [x4], #16\n" // max + "subs w5, w5, #4\n" + "ble 1f\n" + + "0:\n" + "ld1 {v0.4s}, [x4], #16\n" + "fmin v31.4s, v31.4s, v0.4s\n" + "fmax v30.4s, v30.4s, v0.4s\n" + "subs w5, w5, #4\n" + "bgt 0b\n" + + "1:\n" + "fminv s6, v31.4s\n" + "fmaxv s7, v30.4s\n" + + "str s6, [%[real_min]]\n" + "str s7, [%[real_max]]\n" + + : + : [ data ] "r"(data), [ count_4 ] "r"(count_4), [ real_min ] "r"(real_min), [ real_max ] "r"(real_max) + : "x4", "w5", "s6", "s7", "v0", "v30", "v31"); + for (int i = count_4; i < count; ++i) { + *real_min = data[i] < *real_min ? data[i] : *real_min; + *real_max = data[i] > *real_max ? data[i] : *real_max; + } +#endif +} + +void CalculateChannelRowMinMax(const float *data, int count, float *real_min, float *real_max, int row_length) { + if (row_length == 0) { + return; + } + int channel_total = count / row_length; + for (int i = 0; i < channel_total; i++) { + CalculateMinMaxFp32(data + i * row_length, row_length, real_min + i, real_max + i); + } +} + +void CalculateChannelColMinMax(const float *data, int count, float *real_min, float *real_max, int row_length) { + if (row_length == 0) { + return; + } + int row_total = count / row_length; + for (int r = 0; r < row_total; r++) { + const float *data_current = data + r * row_length; + for (int c = 0; c < row_length; c++) { + float *real_min_channel = real_min + c; + float *real_max_channel = real_max + c; + if (data_current[c] < *real_min_channel) { + *real_min_channel = data_current[c]; + } + if (data_current[c] > *real_max_channel) { + *real_max_channel = data_current[c]; + } + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/dynamic_quant_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/dynamic_quant_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..c7a93b419c1179aceb4542ab5ba30fd6604eba3f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/dynamic_quant_int8.h @@ -0,0 +1,34 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_DYNAMIC_QUANT_INT8_H_ +#define NNACL_INT8_DYNAMIC_QUANT_INT8_H_ + +#include "nnacl/op_base.h" +#include "nnacl/pow_parameter.h" +#include "nnacl/int8/quantize.h" + +#ifdef __cplusplus +extern "C" { +#endif +void CalculateMinMaxFp32(const float *data, int count, float *real_min, float *real_max); +void CalculateChannelRowMinMax(const float *data, int count, float *real_min, float *real_max, int row_length); +void CalculateChannelColMinMax(const float *data, int count, float *real_min, float *real_max, int row_length); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_DYNAMIC_QUANT_INT8_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/fixed_point.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/fixed_point.c new file mode 100644 index 0000000000000000000000000000000000000000..a100286eca6bc13ecddb827ee097dfa9a7ce83c2 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/fixed_point.c @@ -0,0 +1,276 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/fixed_point.h" + +#define C31NUM 31 + +// returns the high-32 bits of a * b with rounding +// assume that a and b is divided by 2^31, who fall into [-1, 1] +// so the mantissa of a * b is (a / 2^31) * (b / 2^31) * 2^31= (a * b) / 2^31 +// actually we compute 2 * a * b / 2^32 +// and take 32 bits of mantissa for rounding +int SaturatingRoundingDoublingHighMul(int a, int b) { + if (a == INT_MIN && b == INT_MIN) { + return INT_MAX; + } + int64_t ab = ((int64_t)a) * ((int64_t)b); + int64_t rounding = ab >= 0 ? (1ll << 30) : (1ll - (1ll << 30)); + // do not apply right shift to potential negetive values + int ab_mantissa = (int)((ab + rounding) / (1ll << 31)); + return ab_mantissa; +} + +int16_t SaturatingRoundingDoublingHighMulInt16(int16_t a, int16_t b) { + if (a == SHRT_MIN && b == SHRT_MIN) { + return SHRT_MAX; + } + int32_t ab = ((int32_t)a) * ((int32_t)b); + int16_t rounding = ab >= 0 ? (1ll << 14) : (1ll - (1ll << 14)); + return (int16_t)((ab + rounding) / (1ll << 15)); +} + +// division by a 2^exponent with rounding +// or arithmetic right shift with rounding +int RoundingDivideByPOT(int x, int exponent) { + if (exponent > C31NUM) { + exponent = C31NUM; + } + const int mask = (1ll << exponent) - 1; + const int remainder = x & mask; + const int threshold = (mask >> 1) + (x < 0 ? 1 : 0); + return (x >> exponent) + (remainder > threshold ? 1 : 0); +} + +int UpwardRounding(int x, int exponent) { + const int32_t rounding_offset = (exponent > 0) ? (1 << (exponent - 1)) : 0; + if (x > INT32_MAX - rounding_offset) { + return 1 << (31 - exponent); + } + return (x + rounding_offset) >> exponent; +} + +int MultiplyByQuantizedMultiplier(int32_t value, int32_t multiplier, int32_t left_shift, int32_t right_shift) { + return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(value * (1 << left_shift), multiplier), -right_shift); +} + +int MultiplyByQuantizedMultiplierWithUpwardRounding(int32_t value, int32_t multiplier, int32_t left_shift, + int32_t right_shift) { + return UpwardRounding(SaturatingRoundingDoublingHighMul(value * (1 << left_shift), multiplier), -right_shift); +} + +int MultiplyByMultiplierAndRightShift(int32_t value, int32_t multiplier, int32_t right_shift) { + return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(value, multiplier), right_shift); +} + +int FractionsBits(int integer_bits) { return 8 * (int)(sizeof(int32_t)) - 1 - integer_bits; } + +int FixedPoint_One(int integer_bits, int fractions_bits) { + return (integer_bits == 0 ? INT32_MAX : ((1) << (uint32_t)(integer_bits == 0 ? 0 : fractions_bits))); +} + +int RoundingHalfSum(int32_t a, int32_t b) { + int64_t sum = (int64_t)a + (int64_t)b; + return (int32_t)((sum + (sum > 0 ? 1 : -1)) / 2); +} + +int32_t BitAnd(int32_t a, int32_t b) { return (uint32_t)a & (uint32_t)b; } + +int32_t BitOr(int32_t a, int32_t b) { return (uint32_t)a | (uint32_t)b; } + +int32_t BitXor(int32_t a, int32_t b) { return (uint32_t)a ^ (uint32_t)b; } + +int32_t BitNot(int32_t a) { return ~(uint32_t)a; } + +int BitsSelect(int mask, int bound, int val) { return BitXor(BitAnd(mask, bound), BitAnd(BitNot(mask), val)); } + +int ConstantPOT(int fractional_bits, int exponent) { return (1 << (uint32_t)(fractional_bits + exponent)); } + +int32_t MaskIfNonZero(int32_t a) { return a ? BitNot(0) : 0; } + +int32_t MaskIfZero(int32_t a) { return MaskIfNonZero(!a); } + +int32_t MaskIfLessThan(int32_t a, int32_t b) { return MaskIfNonZero((a < b)); } + +uint32_t CountLeadingZeroBits(uint32_t x) { + if (x == 0) { + return 8 * sizeof(uint32_t) - 1; + } +#if defined(__GUNC__) + return __builtin_clz(x); +#else + const uint32_t leading_positive = (uint32_t)(1) << (8 * sizeof(uint32_t) - 1); + uint32_t leading_zeros = 0; + while (x < leading_positive) { + x <<= 1; + leading_zeros++; + } + return leading_zeros; +#endif +} + +uint32_t CountLeadingSignBits(int32_t x) { + if (x == 0) { + return 8 * sizeof(int32_t) - 1; + } +#if defined(__GUNC__) && !defined(__clang__) + return __builtin_clrsb(x); +#else + return x >= 0 ? CountLeadingZeroBits((uint32_t)x) - 1 : x != INT32_MIN ? CountLeadingZeroBits(2 * (uint32_t)(-x)) : 0; +#endif +} + +int SaturatingRoundingMultiplyByPOT(int32_t x, int exponent) { + if (exponent > 0) { + const int min = INT32_MIN; + const int max = INT32_MAX; + const int scalar_int_bits = 8 * (int)(sizeof(int32_t)); + const int threshold = ((1 << (uint32_t)(scalar_int_bits - 1 - exponent)) - 1); + const int positive_mask = x > threshold ? BitNot(0) : 0; + const int negative_mask = x < -threshold ? BitNot(0) : 0; + int result = x * ((int32_t)(1) << (uint32_t)exponent); + result = BitsSelect(positive_mask, max, result); + result = BitsSelect(negative_mask, min, result); + return result; + } else if (exponent < 0) { + return RoundingDivideByPOT(x, -exponent); + } else { + return x; + } +} + +int32_t Rescale(int x, int integer_bits_src, int integer_bits_dst) { + int exponent = integer_bits_src - integer_bits_dst; + return SaturatingRoundingMultiplyByPOT(x, exponent); +} + +int32_t reciprocal_on_interval_between_0_1(int32_t a) { + int one = FixedPoint_One(0, FractionsBits(0)); + int half_sum = RoundingHalfSum(a, one); + const int constant_48_over_17 = 1515870810; + const int constant_neg_32_over_17 = -1010580540; + int x = constant_48_over_17 + SaturatingRoundingDoublingHighMul(half_sum, constant_neg_32_over_17); + for (int i = 0; i < 3; i++) { + int half_sum_times_x = SaturatingRoundingDoublingHighMul(half_sum, x); + int one_minus_half_sum_times_x = FixedPoint_One(2, FractionsBits(2)) - half_sum_times_x; + x = x + Rescale(SaturatingRoundingDoublingHighMul(x, one_minus_half_sum_times_x), 2 + 2, 2); + } + return Rescale(x, 2 - 1, 0); +} + +int32_t ComputerReciprocal(int32_t x, uint32_t x_digits, int32_t *recip_shift) { + uint32_t leading_zreos_plus_one = CountLeadingZeroBits((uint32_t)x); + *recip_shift = x_digits - leading_zreos_plus_one; + const int32_t shifted_minus_one = (int32_t)(((uint32_t)x << leading_zreos_plus_one) - ((uint32_t)(1) << 31)); + const int32_t shifted_scaled = reciprocal_on_interval_between_0_1(shifted_minus_one); + return shifted_scaled; +} + +int exp_on_interval_values(int a) { + const int constant_neg_1_over_8 = 1895147668; + const int constant_1_over_3 = 715827883; + int fractional_bits = FractionsBits(0); + int x = a + ConstantPOT(fractional_bits, -3); + int x2 = SaturatingRoundingDoublingHighMul(x, x); + int x3 = SaturatingRoundingDoublingHighMul(x2, x); + int x4 = SaturatingRoundingDoublingHighMul(x2, x2); + int x4_over_4 = SaturatingRoundingMultiplyByPOT(x4, -2); + int x4_over_24_plus_x3_over_6_plus_x2_over_2 = + SaturatingRoundingMultiplyByPOT((SaturatingRoundingDoublingHighMul((x4_over_4 + x3), constant_1_over_3) + x2), -1); + return constant_neg_1_over_8 + + SaturatingRoundingDoublingHighMul(constant_neg_1_over_8, (x + x4_over_24_plus_x3_over_6_plus_x2_over_2)); +} + +void exp_barrel_shifter(int exponent, int muliplier, int integer_bits, int fractional_bits, int remainder, + int32_t *result) { + if (integer_bits > exponent) { + int total_shift = integer_bits > exponent ? fractional_bits + exponent : 0; + *result = BitsSelect(MaskIfNonZero(BitAnd(remainder, (1 << (uint32_t)total_shift))), + SaturatingRoundingDoublingHighMul(*result, muliplier), *result); + } +} + +int exp_on_negative_values(int a, const int integer_bits) { + int fractional_bits = FractionsBits(integer_bits); + const int one_quarter = ConstantPOT(fractional_bits, -2); + int a_mod_quarter_minus_one_quarter = ((unsigned)(a) & (one_quarter - 1)) - one_quarter; + int result = exp_on_interval_values(Rescale(a_mod_quarter_minus_one_quarter, integer_bits, 0)); + int remainder = a_mod_quarter_minus_one_quarter - a; + + exp_barrel_shifter(-2, 1672461947, integer_bits, fractional_bits, remainder, &result); + exp_barrel_shifter(-1, 1302514674, integer_bits, fractional_bits, remainder, &result); + exp_barrel_shifter(+0, 790015084, integer_bits, fractional_bits, remainder, &result); + exp_barrel_shifter(+1, 290630308, integer_bits, fractional_bits, remainder, &result); + exp_barrel_shifter(+2, 39332535, integer_bits, fractional_bits, remainder, &result); + exp_barrel_shifter(+3, 720401, integer_bits, fractional_bits, remainder, &result); + exp_barrel_shifter(+4, 242, integer_bits, fractional_bits, remainder, &result); + + int clamp_bits = integer_bits > 5 ? 36 - integer_bits : 0; + if (integer_bits > 5) { + const int clamp = -(1 << (uint32_t)clamp_bits); + result = BitsSelect(MaskIfLessThan(a, clamp), 0, result); + } + result = BitsSelect(MaskIfZero(a), FixedPoint_One(0, fractional_bits), result); + return result; +} + +void GetSqrtQuantMultiplierExp(int32_t input, int reverse_shift, int32_t *multiplier, int32_t *shift) { + if (input <= 1) { + *multiplier = INT_MAX; + *shift = 0; + } + *shift = 11; + while (input >= (1 << 29)) { + input /= 4; + ++*shift; + } + uint32_t max_left_shift_bits = CountLeadingSignBits(input); + if (max_left_shift_bits < 2) { + return; + } + uint32_t left_shift_bit_pairs = max_left_shift_bits / 2 - 1; + *shift -= left_shift_bit_pairs; + input <<= 2 * left_shift_bit_pairs; + int32_t fixedpoint_f3_input = input >> 1; // sign: 1 bit, integer: 3 bit, fractional: 28 bit + int32_t fp_f3_half_input = SaturatingRoundingMultiplyByPOT(fixedpoint_f3_input, -1); + int32_t fp_f3_half_three = (1 << 28) + (1 << 27); + int32_t tmp = (1 << 28); // one + for (int i = 0; i < 5; i++) { + int32_t tmp3 = Rescale(SaturatingRoundingDoublingHighMul(tmp, SaturatingRoundingDoublingHighMul(tmp, tmp)), 9, 3); + tmp = Rescale(SaturatingRoundingDoublingHighMul(fp_f3_half_three, tmp) - + SaturatingRoundingDoublingHighMul(fp_f3_half_input, tmp3), + 6, 3); + } + const int32_t fp_f0_half_sqrt_2 = 1518500250; // sqrt(2) / 2 + tmp = SaturatingRoundingDoublingHighMul(tmp, fp_f0_half_sqrt_2); + *multiplier = tmp; + if (*shift < 0) { + *multiplier <<= -*shift; + *shift = 0; + } + *shift *= reverse_shift; +} + +#ifdef ENABLE_NEON +int32x4_t RoundingDivideByPOTInt32x4(int32x4_t x, int exponent) { + const int32x4_t shift_vec = vdupq_n_s32(-exponent); + const int32x4_t fixup = vshrq_n_s32(vandq_s32(x, shift_vec), 31); + const int32x4_t fixed_up_x = vqaddq_s32(x, fixup); + return vrshlq_s32(fixed_up_x, shift_vec); +} + +int32x4_t SaturatingRoundingDoublingHighMulInt32x4(int32x4_t a, int32x4_t b) { return vqrdmulhq_s32(a, b); } +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/fixed_point.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/fixed_point.h new file mode 100644 index 0000000000000000000000000000000000000000..503a5e1d672a7e226d7cf1c71dbac8cd9724d863 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/fixed_point.h @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_QUANTIZATION_FIXED_POINT_H_ +#define NNACL_QUANTIZATION_FIXED_POINT_H_ + +#include +#include +#ifdef ENABLE_NEON +#include +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +// returns the high-32 bits of a * b with rounding +// assume that a and b is divided by 2^31, who fall into [-1, 1] +// so the mantissa of a * b is (a / 2^31) * (b / 2^31) * 2^31= (a * b) / 2^31 +// actually we compute 2 * a * b / 2^32 +// and take 32 bits of mantissa for rounding +int SaturatingRoundingDoublingHighMul(int a, int b); + +int16_t SaturatingRoundingDoublingHighMulInt16(int16_t a, int16_t b); + +// division by a 2^exponent with rounding +// or arithmetic right shift with rounding +int RoundingDivideByPOT(int x, int exponent); + +int UpwardRounding(int x, int exponent); + +int MultiplyByQuantizedMultiplier(int32_t value, int32_t multiplier, int32_t left_shift, int32_t right_shift); + +int MultiplyByQuantizedMultiplierWithUpwardRounding(int32_t value, int32_t multiplier, int32_t left_shift, + int32_t right_shift); + +int MultiplyByMultiplierAndRightShift(int32_t value, int32_t multiplier, int32_t right_shift); + +int SaturatingRoundingMultiplyByPOT(int32_t x, int exponent); + +int32_t Rescale(int x, int kIntegerBitsSrc, int kIntegerBitsDst); + +uint32_t CountLeadingSignBits(int32_t x); + +int32_t ComputerReciprocal(int32_t x, uint32_t x_digits, int32_t *recip_shift); + +int exp_on_negative_values(int a, const int tIntegerBits); + +void GetSqrtQuantMultiplierExp(int32_t input, int reverse_shift, int32_t *multiplier, int32_t *shift); + +#ifdef __cplusplus +} +#endif + +#ifdef ENABLE_NEON +int32x4_t RoundingDivideByPOTInt32x4(int32x4_t x, int exponent); + +int32x4_t SaturatingRoundingDoublingHighMulInt32x4(int32x4_t a, int32x4_t b); +#endif + +#endif // NNACL_QUANTIZATION_FIXED_POINT_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/gatherNd_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/gatherNd_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..3a5fefe5ca1b3fa7363964c50383cb002c247bd1 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/gatherNd_int8.c @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/gatherNd_int8.h" +#include +#include "nnacl/errorcode.h" + +int GatherNdInt8(int8_t *input, int8_t *output, const int32_t *in_offset, int area, int count, GatherQuantArg param) { + double alpha = param.alpha_; + int z1 = param.zp_in_; + int z2 = param.zp_out_; + for (int i = 0; i < count; ++i) { + for (int j = 0; j < area; ++j) { + int32_t tmp = round(alpha * (input[in_offset[i] + j] - z1)) + z2; + tmp = tmp > 127 ? 127 : tmp; + tmp = tmp < -128 ? -128 : tmp; + output[area * i + j] = (int8_t)tmp; + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/gatherNd_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/gatherNd_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..1ef8ee2dd603472f56cabd5469990ab5ef3c939e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/gatherNd_int8.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_GATHERND_INT8_H_ +#define NNACL_INT8_GATHERND_INT8_H_ + +#include "nnacl/op_base.h" +#include "nnacl/int8/quantize.h" + +#ifdef __cplusplus +extern "C" { +#endif +int GatherNdInt8(int8_t *in_data, int8_t *out_data, const int32_t *in_offset, int area, int count, + GatherQuantArg param); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_GATHERND_INT8_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/gather_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/gather_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..2964893419724de3ae94651a9506cbef66307de9 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/gather_int8.c @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ +#include "nnacl/int8/gather_int8.h" +#include "nnacl/op_base.h" +#include "nnacl/int8/quantize.h" +#include "nnacl/errorcode.h" + +int GatherInt8Int32Index(const int8_t *in_data, int8_t *out_data, int outer_size, int inner_size, int limit, + const int32_t *indices, int indices_element_size, GatherQuantArg para) { + double alpha = para.alpha_; + int z1 = para.zp_in_; + int z2 = para.zp_out_; + int i, m, j; + for (m = 0; m < outer_size; ++m) { + const int8_t *inputm = in_data + inner_size * m * limit; + int8_t *outputm = out_data + inner_size * m * indices_element_size; + for (i = 0; i < indices_element_size; ++i) { + if (indices[i] < 0 || indices[i] > limit) { + return NNACL_ERR; + } + for (j = 0; j < inner_size; ++j) { + int32_t tmp = round(alpha * (inputm[indices[i] * inner_size + j] - z1)) + z2; + tmp = tmp > 127 ? 127 : tmp; + tmp = tmp < -128 ? -128 : tmp; + outputm[i * inner_size + j] = (int8_t)tmp; + } + } + } + return NNACL_OK; +} + +int GatherInt8Int64Index(const int8_t *in_data, int8_t *out_data, int outer_size, int inner_size, int limit, + const int64_t *indices, int indices_element_size, GatherQuantArg para) { + double alpha = para.alpha_; + int z1 = para.zp_in_; + int z2 = para.zp_out_; + int i, m, j; + for (m = 0; m < outer_size; ++m) { + const int8_t *inputm = in_data + inner_size * m * limit; + int8_t *outputm = out_data + inner_size * m * indices_element_size; + for (i = 0; i < indices_element_size; ++i) { + if (indices[i] < 0 || indices[i] > limit) { + return NNACL_ERR; + } + for (j = 0; j < inner_size; ++j) { + int32_t tmp = round(alpha * (inputm[indices[i] * inner_size + j] - z1)) + z2; + tmp = tmp > 127 ? 127 : tmp; + tmp = tmp < -128 ? -128 : tmp; + outputm[i * inner_size + j] = (int8_t)tmp; + } + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/gather_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/gather_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..e5b3d12abe30912459643a38160f9ff464a2a444 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/gather_int8.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_GATHER_INT8_H_ +#define NNACL_INT8_GATHER_INT8_H_ + +#include "nnacl/op_base.h" +#include "nnacl/int8/quantize.h" + +#ifdef __cplusplus +extern "C" { +#endif +int GatherInt8Int32Index(const int8_t *in_data, int8_t *out_data, int outer_size, int inner_size, int limit, + const int32_t *indices, int indices_element_size, GatherQuantArg para); + +int GatherInt8Int64Index(const int8_t *in_data, int8_t *out_data, int outer_size, int inner_size, int limit, + const int64_t *indices, int indices_element_size, GatherQuantArg para); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_GATHER_INT8_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/hswish_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/hswish_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..b3814ac00675af03ec1823942ca51e9698b86f76 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/hswish_int8.c @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/hswish_int8.h" + +int16_t SaturatingLeftShift(int16_t value, int shift_num) { + int32_t result = (int32_t)value * (1 << shift_num); + return MSMAX(MSMIN(result, SHRT_MAX), SHRT_MIN); +} + +int HSwishInt8(const int8_t *src, int length, int8_t *dst, const HswishQuantArg *arg) { + for (int i = 0; i < length; i++) { + const int16_t input_value = src[i] - arg->input_zp; + const int16_t input_value_scale = input_value * (1 << 7); + const int16_t input_value_on_preshift_output_scale = + SaturatingRoundingDoublingHighMulInt16(input_value_scale, arg->output_multiplier_fixedpoint_int16); + int16_t relu6_value = input_value_scale; + if (arg->relu6_multiplier_exponent > 0) { + relu6_value = SaturatingLeftShift(relu6_value, arg->relu6_multiplier_exponent - 1); + } + relu6_value = SaturatingRoundingDoublingHighMulInt16(relu6_value, arg->relu6_multiplier_fixedpoint_int16); + + if (arg->relu6_multiplier_exponent > 0) { + relu6_value = SaturatingLeftShift(relu6_value, 1); + } + if (arg->relu6_multiplier_exponent < 0) { + relu6_value = RoundingDivideByPOT(relu6_value, -arg->relu6_multiplier_exponent); + } + relu6_value = (size_t)(relu6_value + (1 << 15)) >> 1; + const int16_t preshift_output_value = + SaturatingRoundingDoublingHighMulInt16(relu6_value, input_value_on_preshift_output_scale); + + int16_t output = RoundingDivideByPOT(preshift_output_value, -arg->output_multiplier_exponent); + output += arg->output_zp; + output = MSMIN(output, 127); + output = MSMAX(output, -128); + dst[i] = (int8_t)output; + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/hswish_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/hswish_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..a2cd5c2115c6e10b3eeb9bf2b0fa72f6edefb4b9 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/hswish_int8.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_INT8_HSWISH_INT8_H_ +#define NNACL_INT8_HSWISH_INT8_H_ + +#include +#include "nnacl/op_base.h" +#include "nnacl/errorcode.h" +#include "nnacl/int8/fixed_point.h" + +typedef struct HswishQuantArg { + double input_scale; + int32_t input_zp; + double output_scale; + int32_t output_zp; + int16_t relu6_multiplier_fixedpoint_int16; + int32_t relu6_multiplier_exponent; + int16_t output_multiplier_fixedpoint_int16; + int32_t output_multiplier_exponent; +} HswishQuantArg; + +#ifdef __cplusplus +extern "C" { +#endif +int HSwishInt8(const int8_t *src, int length, int8_t *dst, const HswishQuantArg *arg); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_HSWISH_INT8_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/l2_norm_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/l2_norm_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..3a2265409230318b59d03d6445f97b64e6304775 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/l2_norm_int8.c @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/int8/l2_norm_int8.h" +#include +#include "nnacl/int8/fixed_point.h" +#include "nnacl/errorcode.h" + +int L2NormalizationInt8(const int8_t *input_data, int8_t *output_data, const L2NormParameter *param, + const L2NormQuantArg *quant_param, const int begin, const int end) { + const int inner_size = param->shape_[param->shape_num_ - 1]; + + for (int i = begin; i < end; ++i) { + int32_t square_sum = 0; + for (int j = 0; j < inner_size; ++j) { + int32_t in = input_data[i * inner_size + j] - quant_param->in_.zp_; + square_sum += in * in; + } + int32_t multiplier; + int32_t shift; + GetSqrtQuantMultiplierExp(square_sum, -1, &multiplier, &shift); + for (int k = 0; k < inner_size; ++k) { + int32_t in = input_data[i * inner_size + k] - quant_param->in_.zp_; + int32_t out = RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(in * (1 << 7), multiplier), -shift); + output_data[i * inner_size + k] = MSMIN(127, MSMAX(-128, out)); + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/l2_norm_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/l2_norm_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..9786f3d917ff65db326d213583ba23bdd0f38f20 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/l2_norm_int8.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_INT8_L2_NORM_INT8_H_ +#define NNACL_INT8_L2_NORM_INT8_H_ + +#include "nnacl/l2_norm_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int L2NormalizationInt8(const int8_t *input_data, int8_t *output_data, const L2NormParameter *param, + const L2NormQuantArg *quant_param, const int begin, const int end); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_L2_NORM_INT8_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/layer_norm_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/layer_norm_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..3f3aaf667a204e22d3466e39f2620545d6feedee --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/layer_norm_int8.c @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/layer_norm_int8.h" + +void LayerNormGammaAndBetaInt8(int8_t *dst, const int8_t *src, const float *gamma_data, const float *beta_data, + const LayerNormQuantArg *quant, int num, const float mean, const float deno) { + for (int i = 0; i < num; i++) { + float fp32_src = (src[i] - quant->in_zp_) * quant->in_scale_; + float fp32_dst = (fp32_src - mean) * deno; + fp32_dst = fp32_dst * gamma_data[i] + beta_data[i]; + int32_t int32_dst = (int32_t)round(fp32_dst * 1.0 / quant->out_scale_ + quant->out_zp_); + dst[i] = (int8_t)MSMAX(MSMIN(int32_dst, 127), -128); + } +} + +/* + * origin : (x-mean) / sqrt(variance + epsilon) * gamma + beta + * quant : (x-mean) / sqrt(sum(x * x) - mean * mean) * gamma + beta + * + * */ +int LayerNormInt8(const int8_t *src_data, const float *gamma_data, const float *beta_data, int8_t *dst_data, + const LayerNormComputeParam *param, const LayerNormQuantArg *quant, int task_id, int thread_num) { + if (src_data == NULL || dst_data == NULL || gamma_data == NULL || beta_data == NULL) { + return NNACL_NULL_PTR; + } + NNACL_CHECK_ZERO_RETURN_ERR(param->params_inner_size_); + NNACL_CHECK_ZERO_RETURN_ERR(param->params_outer_size_); + + int step = UP_DIV(param->norm_outer_size_, thread_num); + int thread_end = NNACL_MIN((task_id + 1) * step, param->norm_outer_size_); + for (int i = task_id * step; i < thread_end; i++) { + const int8_t *src_norm = src_data + i * param->norm_inner_size_; + int8_t *dst_norm = dst_data + i * param->norm_inner_size_; + float mean = 0.0f; + float square_mean = 0.0f; + for (int j = 0; j < param->norm_inner_size_; j++) { + float float_src = (src_norm[j] - quant->in_zp_) * quant->in_scale_; + mean += float_src; + square_mean += float_src * float_src; + } + mean /= (float)param->norm_inner_size_; + square_mean /= (float)param->norm_inner_size_; + const float deno = 1 / sqrtf(square_mean - mean * mean + param->epsilon_); + + if (param->norm_outer_size_ <= param->params_outer_size_) { + for (int x = 0; x < param->norm_inner_size_ / param->params_inner_size_; x++) { + const int8_t *src_param = src_norm + x * param->params_inner_size_; + int8_t *dst_param = dst_norm + x * param->params_inner_size_; + LayerNormGammaAndBetaInt8(dst_param, src_param, gamma_data, beta_data, quant, param->norm_inner_size_, mean, + deno); + } + } else { + int x = i / param->params_outer_size_; + const float *gamma = gamma_data + x * param->norm_inner_size_; + const float *beta = beta_data + x * param->norm_inner_size_; + LayerNormGammaAndBetaInt8(dst_norm, src_norm, gamma, beta, quant, param->norm_inner_size_, mean, deno); + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/layer_norm_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/layer_norm_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..48a502add3e032fcb5b4d495056f0559e0e3fd1b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/layer_norm_int8.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_INT8_LAYER_NORM_H_ +#define NNACL_INT8_LAYER_NORM_H_ + +#include "nnacl/errorcode.h" +#include "nnacl/layer_norm_parameter.h" +#include "nnacl/int8/fixed_point.h" +#include "nnacl/kernel/layer_norm.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int LayerNormInt8(const int8_t *src_data, const float *gamma_data, const float *beta_data, int8_t *dst_data, + const LayerNormComputeParam *param, const LayerNormQuantArg *quant, int task_id, int thread_num); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_LAYER_NORM_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/leaky_relu_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/leaky_relu_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..44b367de1df7a5870caa709ed8cbc69586b34f72 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/leaky_relu_int8.c @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/leaky_relu_int8.h" +#include "nnacl/errorcode.h" + +int DoLeakReluInt8(const int8_t *inputs, int8_t *output_ptr, const LeakyReluQuantArg *quant_prelu_parm, int task_id) { + if (quant_prelu_parm == NULL) { + return NNACL_NULL_PTR; + } + float output_scale = quant_prelu_parm->out_args_.scale_; + int output_zp = quant_prelu_parm->out_args_.zp_; + const float output_inverse_scale = 1.f / output_scale; + + float scale = quant_prelu_parm->in_args_.scale_ * output_inverse_scale; + float bias = -quant_prelu_parm->in_args_.zp_ * scale; + for (int j = task_id; j < quant_prelu_parm->element_num; j += quant_prelu_parm->thread_num_) { + if (inputs[j] <= 0) { + int32_t output_tmp = round(inputs[j] * quant_prelu_parm->slope_ * scale + bias) + output_zp; + if (output_tmp > 127) { + output_ptr[j] = 127; + } else if (output_tmp < -128) { + output_ptr[j] = -128; + } else { + output_ptr[j] = (int8_t)output_tmp; + } + } else { + int32_t output_tmp = round(inputs[j] * scale + bias) + output_zp; + if (output_tmp > 127) { + output_ptr[j] = 127; + } else if (output_tmp < -128) { + output_ptr[j] = -128; + } else { + output_ptr[j] = (int8_t)output_tmp; + } + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/leaky_relu_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/leaky_relu_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..b917ccd6b4f280de283f5cad2a6edf5f6acb8f5f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/leaky_relu_int8.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_PRELU_INT8_H_ +#define NNACL_INT8_PRELU_INT8_H_ + +#include "nnacl/op_base.h" +#include "nnacl/int8/quantize.h" + +#ifdef __cplusplus +extern "C" { +#endif +int DoLeakReluInt8(const int8_t *inputs, int8_t *output_ptr, const LeakyReluQuantArg *quant_Prelu_parm, int task_id); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_PRELU_INT8_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/matmul_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/matmul_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..0647bb46d88839d87f6f9eb41625cf366523e074 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/matmul_int8.c @@ -0,0 +1,839 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/matmul_int8.h" +#include "nnacl/int8/fixed_point.h" + +void RowMajor2Row2x16MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col) { + int col16 = UP_ROUND(col, C16NUM); + for (int r = 0; r < row; r++) { + int rd2 = r / C2NUM; + int rm2 = r % C2NUM; + for (int c = 0; c < col; c++) { + int cd16 = c / C16NUM; + int cm16 = c % C16NUM; + int dst_index = rd2 * col16 * C2NUM + cd16 * C2NUM * C16NUM + rm2 * C16NUM + cm16; + int src_index = r * col + c; + dst_ptr[dst_index] = src_ptr[src_index]; + } + } +} + +void RowMajor2Row4x4MajorInt8(const int8_t *src, int8_t *dst, int row, int col) { + int row_div = row / C4NUM * C4NUM; + int col_4 = UP_ROUND(col, C4NUM); + int col_div = col / C4NUM * C4NUM; + + const int8_t *src_r4 = src; + int8_t *packed_r4 = dst; + const int8_t *src_c4 = NULL; + int8_t *packed_c4 = NULL; + for (int r = 0; r < row_div; r += C4NUM) { + src_c4 = src_r4; + packed_c4 = packed_r4; + + for (int c = 0; c < col_div; c += C4NUM) { + for (int i = 0; i < C4NUM; i++) { + packed_c4[i * C4NUM + 0] = src_c4[i * col + 0]; + packed_c4[i * C4NUM + 1] = src_c4[i * col + 1]; + packed_c4[i * C4NUM + 2] = src_c4[i * col + 2]; + packed_c4[i * C4NUM + 3] = src_c4[i * col + 3]; + } + src_c4 += C4NUM; + packed_c4 += C16NUM; + } + + if (col == col_div) { + continue; + } + memset(packed_c4, 0, C16NUM * sizeof(int8_t)); + for (int i = 0; i < C4NUM; ++i) { + for (int c = 0; c < col - col_div; ++c) { + packed_c4[i * C4NUM + c] = src_c4[i * col + c]; + } + } + src_r4 += C4NUM * col; + packed_r4 += C4NUM * col_4; + } + + if (row == row_div) { + return; + } + memset(packed_r4, 0, C4NUM * col_4); + src_c4 = src_r4; + packed_c4 = packed_r4; + for (int c = 0; c < col_div; c += C4NUM) { + for (int i = 0; i < row - row_div; ++i) { + packed_c4[i * C4NUM + 0] = src_c4[i * col + 0]; + packed_c4[i * C4NUM + 1] = src_c4[i * col + 1]; + packed_c4[i * C4NUM + 2] = src_c4[i * col + 2]; + packed_c4[i * C4NUM + 3] = src_c4[i * col + 3]; + } + src_c4 += C4NUM; + packed_c4 += C16NUM; + } + if (col == col_div) { + return; + } + for (int i = 0; i < row - row_div; ++i) { + for (int c = 0; c < col - col_div; ++c) { + packed_c4[i * C4NUM + c] = src_c4[i * col + c]; + } + } +} + +void RowMajor2Col16x2MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col) { + int row16 = UP_ROUND(row, C16NUM); + int stride = C16NUM * C2NUM; + for (int r = 0; r < row; ++r) { + for (int c = 0; c < col; ++c) { + int stride_idx = c / C2NUM * (row16 / C16NUM) + r / C16NUM; + int dst_idx = stride * stride_idx + c % C2NUM * C16NUM + r % C16NUM; + int src_idx = r * col + c; + dst_ptr[dst_idx] = src_ptr[src_idx]; + } + } +} + +void RowMajor2Row8x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col) { + int col4 = UP_ROUND(col, C4NUM); + for (int r = 0; r < row; r++) { + int rd8 = r / C8NUM; + int rm8 = r % C8NUM; + for (int c = 0; c < col; c++) { + int cd4 = c / C4NUM; + int cm4 = c % C4NUM; + int dst_index = rd8 * col4 * C8NUM + cd4 * C8NUM * C4NUM + rm8 * C4NUM + cm4; + int src_index = r * col + c; + dst_ptr[dst_index] = src_ptr[src_index]; + } + } +} + +void MatrixPack4x16UnitInt8(const int8_t *src, int8_t *dst, int row, int col, int stride) { + for (int r = 0; r < row; r++) { + const int8_t *src_r = src + r * stride; + int8_t *dst_r = dst + r * C16NUM; + memcpy(dst_r, src_r, col * sizeof(int8_t)); + } + return; +} + +void MatrixEmptyInt8(int8_t *dst, int row, int col) { + for (int r = 0; r < row; r++) { + int8_t *dst_r = dst + r * C16NUM; + memset(dst_r, 0, col * sizeof(int8_t)); + } + return; +} + +void RowMajor2Row4x16MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col) { + int col4 = UP_ROUND(col, C4NUM); + for (int r = 0; r < row; r++) { + int rd16 = r / C16NUM; + int rm16 = r % C16NUM; + for (int c = 0; c < col; c++) { + int cd4 = c / C4NUM; + int cm4 = c % C4NUM; + int dst_index = rd16 * col4 * C16NUM + cd4 * C16NUM * C4NUM + rm16 * C4NUM + cm4; + int src_index = r * col + c; + dst_ptr[dst_index] = src_ptr[src_index]; + } + } +} + +void RowMajor2Row16x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col) { + /* Row-major to row16x4-major (block row-major) */ + int col16 = UP_ROUND(col, C16NUM); + int row_4div = row / C4NUM * C4NUM; + int row_4res = row - row_4div; + int col_16div = col / C16NUM * C16NUM; + int col_16res = col - col_16div; + int8_t *src_r = (int8_t *)src_ptr; + int8_t *dst_r = (int8_t *)dst_ptr; + + for (int ri = 0; ri < row_4div; ri += C4NUM) { + for (int ci = 0; ci < col_16div; ci += C16NUM) { + size_t col_offset = (size_t)col; + int8_t *src_c = src_r + ci; + int8_t *dst_c = dst_r + ci * C4NUM; +#ifdef ENABLE_ARM64 + asm volatile( + "mov x10, %[src_c] \n" + "mov x11, %[dst_c] \n" + + "ld1 {v0.16b}, [x10], %[col_offset]\n" + "ld1 {v1.16b}, [x10], %[col_offset]\n" + "ld1 {v2.16b}, [x10], %[col_offset]\n" + "ld1 {v3.16b}, [x10], %[col_offset]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + "st1 {v2.16b}, [x11], #16\n" + "st1 {v3.16b}, [x11], #16\n" + + : + : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ col_offset ] "r"(col_offset) + : "x10", "x11", "v0", "v1", "v2", "v3"); +#elif ENABLE_ARM32 + asm volatile( + "mov r0, %[src_c] \n" + "mov r1, %[dst_c] \n" + "mov r2, %[col_offset] \n" + "mov r3, #16 \n" + + "vld1.8 {q0}, [r0], r2 \n" + "vld1.8 {q1}, [r0], r2 \n" + "vld1.8 {q2}, [r0], r2 \n" + "vld1.8 {q3}, [r0], r2 \n" + + "vst1.32 {d0, d1}, [r1], r3 \n" + "vst1.32 {d2, d3}, [r1], r3 \n" + "vst1.32 {d4, d5}, [r1], r3 \n" + "vst1.32 {d6, d7}, [r1], r3 \n" + + : + : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ col_offset ] "r"(col_offset) + : "r0", "r1", "r2", "r3", "q0", "q1", "q2", "q3"); +#else + MatrixPack4x16UnitInt8(src_c, dst_c, C4NUM, C16NUM, col_offset); +#endif + } + + if (col != col_16div) { + MatrixPack4x16UnitInt8(src_r + col_16div, dst_r + col_16div * C4NUM, C4NUM, col_16res, col); + MatrixEmptyInt8(dst_r + col_16div * C4NUM + col_16res, C4NUM, C16NUM - col_16res); + } + src_r += C4NUM * col; + dst_r += C4NUM * col16; + } + + if (row != row_4div) { + memset(dst_r, 0, C4NUM * col16); + + for (int ci = 0; ci < col_16div; ci += C16NUM) { + MatrixPack4x16UnitInt8(src_r + ci, dst_r + ci * C4NUM, row_4res, C16NUM, col); + } + + if (col != col_16div) { + MatrixPack4x16UnitInt8(src_r + col_16div, dst_r + col_16div * C4NUM, row_4res, col_16res, col); + } + } + return; +} + +void MatMulInt8_16x4(const int8_t *a, const int8_t *b, int32_t *dst, int row_4, int col_4, int deep_16, + const int32_t *input_sum, const int32_t *bias) { + /* row4x16-major * row16x4-major => row4x4-major */ + for (int r = 0; r < row_4; r++) { + for (int c = 0; c < col_4; c++) { + int r4div = r / C4NUM, r4mod = r % C4NUM; + int c4div = c / C4NUM, c4mod = c % C4NUM; + int64_t ci = c4div * row_4 * C4NUM + r * C4NUM + c4mod; + int32_t value = 0; + for (int d = 0; d < deep_16; d++) { + int d16div = d / C16NUM, d16mod = d % C16NUM; + int64_t ai = r4div * deep_16 * C4NUM + d16div * C4NUM * C16NUM + r4mod * C16NUM + d16mod; + int64_t bi = c4div * deep_16 * C4NUM + d16div * C4NUM * C16NUM + c4mod * C16NUM + d16mod; + value = value + a[ai] * b[bi]; + } + value -= input_sum[r]; + value += bias[c]; + ((int32_t *)dst)[ci] = value; + } + } + return; +} + +void MatMulInt8_4x2_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_16, + size_t stride, const int32_t *input_sum, const int32_t *bias, const int32_t *left_shift, + const int32_t *right_shift, const int32_t *multiplier, int32_t output_zp, int32_t mini, + int32_t maxi, bool peroc) { + /* support per-layer && weight per-channel */ + /* row4x16-major * row16x2-major => (int8)row-major*/ + for (size_t r = 0; r < row; r++) { + for (size_t c = 0; c < col; c++) { + size_t r4div = r / C4NUM, r4mod = r % C4NUM; + size_t c2div = c / C2NUM, c2mod = c % C2NUM; + size_t ci = r * stride + c; + int32_t value = 0; + for (size_t d = 0; d < deep_16; d++) { + size_t d16div = d / C16NUM, d16mod = d % C16NUM; + size_t ai = r4div * deep_16 * C4NUM + d16div * C4NUM * C16NUM + r4mod * C16NUM + d16mod; + size_t bi = c2div * deep_16 * C2NUM + d16div * C2NUM * C16NUM + c2mod * C16NUM + d16mod; + value = value + a[ai] * b[bi]; + } + int32_t cur_input_sum = + peroc ? input_sum[c2div * UP_ROUND(row, C4NUM) * C2NUM + r * C2NUM + c2mod] : input_sum[r]; + value -= cur_input_sum; + value += bias[c]; + int32_t cur_left_shift = peroc ? left_shift[c] : left_shift[0]; + int32_t cur_right_shift = peroc ? right_shift[c] : right_shift[0]; + int32_t cur_multiplier = peroc ? multiplier[c] : multiplier[0]; + value = MultiplyByQuantizedMultiplier(value, cur_multiplier, cur_left_shift, cur_right_shift) + output_zp; + value = MSMIN(maxi, value); + value = MSMAX(mini, value); + dst[ci] = (int8_t)value; + } + } + return; +} + +#ifndef ENABLE_ARM +void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int32_t *a_sums, + const int32_t *bias, int mini, int maxi, int out_zp, const int32_t *multiplier, + const int32_t *left_shift, const int32_t *right_shift, size_t stride, size_t filter_peroc, + const int32_t *filter_zp) { + /* + * row4x16-major * row16x4-major => (int8)row-major + * support per-layer && weight per-channel + * a_sums is perT : input_row_sum * filter_zp + * perOc : input_row_sum + * */ + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int r4div = r / C4NUM, r4mod = r % C4NUM; + int c4div = c / C4NUM, c4mod = c % C4NUM; + int64_t ci = r * stride + c; + int32_t value = 0; + for (int d = 0; d < deep16; d++) { + int d16div = d / C16NUM, d16mod = d % C16NUM; + int64_t ai = r4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + r4mod * C16NUM + d16mod; + int64_t bi = c4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + c4mod * C16NUM + d16mod; + value = value + a[ai] * b[bi]; + } + int32_t cur_input_sum = filter_peroc ? a_sums[r] * filter_zp[c] : a_sums[r]; + value -= cur_input_sum; + value += bias[c]; + int32_t cur_left_shift = filter_peroc ? left_shift[c] : left_shift[0]; + int32_t cur_right_shift = filter_peroc ? right_shift[c] : right_shift[0]; + int32_t cur_multiplier = filter_peroc ? multiplier[c] : multiplier[0]; + value = MultiplyByQuantizedMultiplier(value, cur_multiplier, cur_left_shift, cur_right_shift) + out_zp; + value = MSMIN(maxi, value); + value = MSMAX(mini, value); + dst[ci] = (int8_t)value; + } + } + return; +} +#endif +void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, + size_t stride, const int32_t *input_sum, const int32_t *bias, const int32_t *left_shift, + const int32_t *right_shift, const int32_t *multiplier, int32_t output_zp, int32_t mini, + int32_t maxi, size_t per_channel) { + /* row8x4-major * row4x8-major => (int8)row-major */ + for (size_t r = 0; r < row; r++) { + for (size_t c = 0; c < col; c++) { + size_t r8div = r / C8NUM, r8mod = r % C8NUM; + size_t c8div = c / C8NUM, c8mod = c % C8NUM; + size_t ci = r * stride + c; + int32_t value = 0; + for (size_t d = 0; d < deep_4; d++) { + size_t d4div = d / C4NUM, d4mod = d % C4NUM; + size_t ai = r8div * deep_4 * C8NUM + d4div * C8NUM * C4NUM + r8mod * C4NUM + d4mod; + size_t bi = c8div * deep_4 * C8NUM + d4div * C8NUM * C4NUM + c8mod * C4NUM + d4mod; + value = value + a[ai] * b[bi]; + } + int32_t cur_input_sum = + per_channel ? input_sum[c8div * UP_ROUND(row, C8NUM) * C8NUM + r * C8NUM + c8mod] : input_sum[r]; + value -= cur_input_sum; + value += bias[c]; + int32_t cur_left_shift = per_channel ? left_shift[c] : left_shift[0]; + int32_t cur_right_shift = per_channel ? right_shift[c] : right_shift[0]; + int32_t cur_multiplier = per_channel ? multiplier[c] : multiplier[0]; + value = MultiplyByQuantizedMultiplier(value, cur_multiplier, cur_left_shift, cur_right_shift) + output_zp; + value = MSMIN(maxi, value); + value = MSMAX(mini, value); + dst[ci] = (int8_t)value; + } + } + return; +} + +void MatMulInt8_4x16_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, + size_t stride, const int32_t *input_sum, const int32_t *bias, const int32_t *left_shift, + const int32_t *right_shift, const int32_t *multiplier, int32_t output_zp, int32_t mini, + int32_t maxi, size_t per_channel, const int32_t *filter_zp) { + /* row4x4-major * row4x16-major => (int8)row-major */ + for (size_t r = 0; r < row; r++) { + for (size_t c = 0; c < col; c++) { + size_t r4div = r / C4NUM, r4mod = r % C4NUM; + size_t c16div = c / C16NUM, c16mod = c % C16NUM; + size_t ci = r * stride + c; + int32_t value = 0; + for (size_t d = 0; d < deep_4; d++) { + size_t d4div = d / C4NUM, d4mod = d % C4NUM; + size_t ai = r4div * deep_4 * C4NUM + d4div * C4NUM * C4NUM + r4mod * C4NUM + d4mod; + size_t bi = c16div * deep_4 * C16NUM + d4div * C16NUM * C4NUM + c16mod * C4NUM + d4mod; + value = value + a[ai] * b[bi]; + } + int32_t cur_input_sum = per_channel ? input_sum[r] * filter_zp[c] : input_sum[r]; + value -= cur_input_sum; + value += bias[c]; + int32_t cur_left_shift = per_channel ? left_shift[c] : left_shift[0]; + int32_t cur_right_shift = per_channel ? right_shift[c] : right_shift[0]; + int32_t cur_multiplier = per_channel ? multiplier[c] : multiplier[0]; + value = MultiplyByQuantizedMultiplier(value, cur_multiplier, cur_left_shift, cur_right_shift) + output_zp; + value = MSMIN(maxi, value); + value = MSMAX(mini, value); + dst[ci] = (int8_t)value; + } + } + return; +} + +#ifdef ENABLE_ARM64 +void PackInput4x4AndInputSumPert_arm64(const int8_t *src_ic, int8_t *pack_ic, int32_t *input_sum_r, size_t src_stride, + size_t ic_4div, size_t ic_4res, int32_t filter_zp) { + asm volatile( + "dup v2.4s, wzr \n" + "mov x14, %[input_sum_r] \n" + "dup v3.4s, %w[filter_zp] \n" + + "mov x10, %[src_ic] \n" + "mov x11, %[pack_ic] \n" + + "mov x15, #0 \n" + "1: \n" + "cmp x15, %[ic_4div] \n" + "add x15, x15, #4\n" + "mov x12, x10 \n" + "add x10, x10, #4\n" + "blt 2f \n" + "cmp %[ic_4res], #0\n" + "beq 6f \n" + "cmp %[ic_4res], #1\n" + "beq 3f \n" + "cmp %[ic_4res], #2\n" + "beq 4f \n" + "cmp %[ic_4res], #3\n" + "beq 5f \n" + + "2: \n" + "ld1 {v0.s}[0], [x12], %[src_stride]\n" + "ld1 {v0.s}[1], [x12], %[src_stride]\n" + "ld1 {v0.s}[2], [x12], %[src_stride]\n" + "ld1 {v0.s}[3], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + + "saddlp v1.8h, v0.16b \n" + "saddlp v0.4s, v1.8h \n" + "add v2.4s, v2.4s, v0.4s \n" + "b 1b \n" + + "3: \n" /* ic res 1 */ + "dup v0.4s, wzr \n" + + "ld1 {v0.b}[0], [x12], %[src_stride]\n" + "ld1 {v0.b}[4], [x12], %[src_stride]\n" + "ld1 {v0.b}[8], [x12], %[src_stride]\n" + "ld1 {v0.b}[12], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "saddlp v1.8h, v0.16b \n" + "saddlp v0.4s, v1.8h \n" + "add v2.4s, v2.4s, v0.4s \n" + "b 6f \n" + + "4: \n" /* ic res 2 */ + "dup v0.4s, wzr \n" + + "ld1 {v0.h}[0], [x12], %[src_stride]\n" + "ld1 {v0.h}[2], [x12], %[src_stride]\n" + "ld1 {v0.h}[4], [x12], %[src_stride]\n" + "ld1 {v0.h}[6], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "saddlp v1.8h, v0.16b \n" + "saddlp v0.4s, v1.8h \n" + "add v2.4s, v2.4s, v0.4s \n" + "b 6f \n" + + "5: \n" /* ic res 3 */ + "dup v0.4s, wzr \n" + "add x13, x12, #2 \n" + + "ld1 {v0.h}[0], [x12], %[src_stride]\n" + "ld1 {v0.b}[2], [x13], %[src_stride]\n" + "ld1 {v0.h}[2], [x12], %[src_stride]\n" + "ld1 {v0.b}[6], [x13], %[src_stride]\n" + "ld1 {v0.h}[4], [x12], %[src_stride]\n" + "ld1 {v0.b}[10], [x13], %[src_stride]\n" + "ld1 {v0.h}[6], [x12], %[src_stride]\n" + "ld1 {v0.b}[14], [x13], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "saddlp v1.8h, v0.16b \n" + "saddlp v0.4s, v1.8h \n" + "add v2.4s, v2.4s, v0.4s \n" + "b 6f \n" + + "6: \n" + "mul v2.4s, v2.4s, v3.4s \n" + + "st1 {v2.4s}, [x14], #16 \n" + + : + : [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ input_sum_r ] "r"(input_sum_r), + [ src_stride ] "r"(src_stride), [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ filter_zp ] "r"(filter_zp) + : "x10", "x11", "x12", "x13", "x14", "x15", "v0", "v1", "v2", "v3"); + return; +} +#endif +void PackInput4x4AndInputSumPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, + size_t input_channel, size_t plane_size, int32_t filter_zp) { + size_t ic4 = UP_ROUND(input_channel, C4NUM); + size_t hw4 = UP_ROUND(plane_size, C4NUM); + size_t hw_4div = plane_size / C4NUM * C4NUM; + size_t ic_4div = input_channel / C4NUM * C4NUM; + + const int8_t *src_r = src_input; + int8_t *pack_r = packed_input; + /* per layer */ + for (size_t hwi = 0; hwi < hw_4div; hwi += C4NUM) { + const int8_t *src_ic = src_r; + int8_t *pack_ic = pack_r; + int32_t *input_sum_r = input_sum + hwi; +#ifdef ENABLE_ARM64 + size_t src_stride = input_channel; + size_t ic_4res = input_channel - ic_4div; + PackInput4x4AndInputSumPert_arm64(src_ic, pack_ic, input_sum_r, src_stride, ic_4div, ic_4res, filter_zp); +#else + int32_t tmp_sum_value[4] = {0}; + for (size_t ici = 0; ici < ic_4div; ici += C4NUM) { + for (size_t i = 0; i < C4NUM; i++) { + tmp_sum_value[i] += src_ic[0 + i * input_channel]; + tmp_sum_value[i] += src_ic[1 + i * input_channel]; + tmp_sum_value[i] += src_ic[2 + i * input_channel]; + tmp_sum_value[i] += src_ic[3 + i * input_channel]; + pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel]; + pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel]; + pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel]; + pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel]; + } + src_ic += C4NUM; + pack_ic += C4NUM * C4NUM; + } + for (size_t ici = ic_4div; ici < input_channel; ici += 1) { + for (size_t i = 0; i < C4NUM; i++) { + tmp_sum_value[i] += src_ic[i * input_channel]; + pack_ic[i * C4NUM] = src_ic[i * input_channel]; + } + src_ic += 1; + pack_ic += 1; + } + + for (size_t ici = input_channel; ici < ic4; ici += 1) { + for (size_t i = 0; i < C4NUM; i++) { + pack_ic[i * C4NUM] = 0; + } + pack_ic += 1; + } + + for (size_t i = 0; i < C4NUM; i++) { + input_sum_r[i] = tmp_sum_value[i] * filter_zp; + } +#endif + src_r += input_channel * C4NUM; + pack_r += ic4 * C4NUM; + } + + if (hw_4div != plane_size) { + (void)memset(pack_r, 0, C4NUM * ic4); + for (size_t hwi = hw_4div; hwi < plane_size; hwi += 1) { + int32_t tmp_sum_value = 0; + const int8_t *src_ic = src_r; + int8_t *pack_ic = pack_r; + for (size_t ici = 0; ici < ic_4div; ici += C4NUM) { + tmp_sum_value += src_ic[0]; + tmp_sum_value += src_ic[1]; + tmp_sum_value += src_ic[2]; + tmp_sum_value += src_ic[3]; + pack_ic[0] = src_ic[0]; + pack_ic[1] = src_ic[1]; + pack_ic[2] = src_ic[2]; + pack_ic[3] = src_ic[3]; + src_ic += C4NUM; + pack_ic += C4NUM * C4NUM; + } + for (size_t ici = ic_4div; ici < input_channel; ici += 1) { + tmp_sum_value += src_ic[0]; + pack_ic[0] = src_ic[0]; + src_ic += 1; + pack_ic += 1; + } + input_sum[hwi] = tmp_sum_value * filter_zp; + src_r += input_channel; + pack_r += C4NUM; + } + for (size_t hwi = plane_size; hwi < hw4; hwi++) { + input_sum[hwi] = 0; + } + } + return; +} + +#ifdef ENABLE_ARM64 +void PackInput2Col4x4AndInputSumPert_arm64(const int8_t *src_ic, int8_t *packed_ic, int32_t *input_sum, int row, + size_t row_stride, int32_t filter_zp) { + asm volatile( + "ld1 {v12.s}[0], [%[input_sum]]\n" + "mov w10, %w[row]\n" + "mov x11, %[src_ic]\n" + "mov x12, %[packed_ic]\n" + "sxtl v6.8h, v12.8b\n" + "sxtl v12.4s, v6.4h\n" + "cmp w10, wzr\n" + "beq 1f\n" + "2:\n" + "subs w10, w10, #4\n" + "ld1 {v0.s}[0], [x11], %[row_stride]\n" + "ld1 {v1.s}[0], [x11], %[row_stride]\n" + "ld1 {v0.s}[1], [x11], %[row_stride]\n" + "ld1 {v1.s}[1], [x11], %[row_stride]\n" + "zip1 v2.8b, v0.8b, v1.8b\n" + "zip2 v3.8b, v0.8b, v1.8b\n" + "zip1 v4.4h, v2.4h, v3.4h\n" + "zip2 v5.4h, v2.4h, v3.4h\n" + "st1 {v4.4h, v5.4h}, [x12], #16\n" + + "sxtl v6.8h, v0.8b\n" + "sxtl v7.4s, v6.4h\n" + "sxtl2 v8.4s, v6.8h\n" + "sxtl v9.8h, v1.8b\n" + "sxtl v10.4s, v9.4h\n" + "sxtl2 v11.4s, v9.8h\n" + "add v10.4s, v10.4s, v7.4s\n" + "add v10.4s, v10.4s, v8.4s\n" + "add v10.4s, v10.4s, v10.4s\n" + "add v10.4s, v10.4s, v11.4s\n" + "bgt 2b\n" + "1:\n" + + : + : [ src_ic ] "r"(src_ic), [ packed_ic ] "r"(packed_ic), [ input_sum ] "r"(input_sum), [ row ] "r"(row), + [ row_stride ] "r"(row_stride), [ filter_zp ] "r"(filter_zp) + : "memory", "w10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12"); + + return; +} +#endif + +// For matmul input a transpose case +void PackInput2Col4x4AndInputSumPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, int row, + int col, int row_stride, int32_t filter_zp) { + const int row_tile = C4NUM; + int row_align = UP_ROUND(row, row_tile); + int row_div = row / row_tile * row_tile; + const int row_res = row - row_div; + + const int col_tile = C4NUM; + int col_div = col / col_tile * col_tile; + const int col_res = col - col_div; + + const int8_t *src_ic = NULL; + int8_t *packed_ic = NULL; + int32_t *tmp_sum = NULL; + for (int c = 0; c < col_div; c += C4NUM) { + int r = 0; + src_ic = src_input + c; + packed_ic = packed_input + c * row_align; + tmp_sum = input_sum + c; +#ifdef ENABLE_ARM64 + PackInput2Col4x4AndInputSumPert_arm64(src_ic, packed_ic, tmp_sum, row_div, row_stride, filter_zp); + packed_ic += C4NUM * row_div; + src_ic += row_div * row_stride; +#else + for (; r < row_div; r += C4NUM) { + for (int i = 0; i < row_tile; i++) { + packed_ic[0 * row_tile + i] = src_ic[i * row_stride + 0]; + packed_ic[1 * row_tile + i] = src_ic[i * row_stride + 1]; + packed_ic[2 * row_tile + i] = src_ic[i * row_stride + 2]; + packed_ic[3 * row_tile + i] = src_ic[i * row_stride + 3]; + + tmp_sum[0] += src_ic[i * row_stride + 0]; + tmp_sum[1] += src_ic[i * row_stride + 1]; + tmp_sum[2] += src_ic[i * row_stride + 2]; + tmp_sum[3] += src_ic[i * row_stride + 3]; + } + packed_ic += C16NUM; + src_ic += row_tile * row_stride; + } +#endif + + for (r = 0; r < row_res; ++r) { + for (int i = 0; i < C4NUM; ++i) { + packed_ic[i * row_tile + r] = src_ic[r * row_stride + i]; + tmp_sum[i] += src_ic[r * row_stride + i]; + } + } + } + if (col_res == 0) { + for (int i = 0; i < col; ++i) { + input_sum[i] *= filter_zp; + } + return; + } + src_ic = src_input + col_div; + packed_ic = packed_input + row_align * col_div; + tmp_sum = input_sum + col_div; + for (int r = 0; r < row_div; r += row_tile) { + for (int i = 0; i < col_res; ++i) { + packed_ic[i * row_tile + 0] = src_ic[0 * row_stride + i]; + packed_ic[i * row_tile + 1] = src_ic[1 * row_stride + i]; + packed_ic[i * row_tile + 2] = src_ic[2 * row_stride + i]; + packed_ic[i * row_tile + 3] = src_ic[3 * row_stride + i]; + + tmp_sum[i] += src_ic[0 * row_stride + i]; + tmp_sum[i] += src_ic[1 * row_stride + i]; + tmp_sum[i] += src_ic[2 * row_stride + i]; + tmp_sum[i] += src_ic[3 * row_stride + i]; + } + src_ic += row_tile * row_stride; + packed_ic += row_tile * col_tile; + } + + for (int r = 0; r < row_res; ++r) { + for (int c = 0; c < col_res; ++c) { + packed_ic[c * row_tile + r] = src_ic[r * row_stride + c]; + tmp_sum[c] += src_ic[r * row_stride + c]; + } + } + + for (int i = 0; i < col; ++i) { + input_sum[i] *= filter_zp; + } +} + +void RowMajor2Col16x4MajorInt8(const int8_t *src, int8_t *dst, int row, int col) { + int row_16 = UP_ROUND(row, C16NUM); + int stride = sizeof(int8_t) * 16 * 4; + for (int r = 0; r < row_16; ++r) { + for (int c = 0; c < col; ++c) { + int stride_idx = c / 4 * (row_16 / 16) + r / 16; + if (r >= row) { + dst[stride * stride_idx + c % 4 * 16 + r % 16] = 0; + } else { + int src_idx = r * col + c; + dst[stride * stride_idx + c % 4 * 16 + r % 16] = src[src_idx]; + } + } + } +} + +void RowMajor2Col4x4MajorInt8(const int8_t *src, int row, int col, int8_t *dst) { + int row_4 = UP_ROUND(row, C4NUM); + int stride = C4NUM * C4NUM; + for (int r = 0; r < row_4; ++r) { + for (int c = 0; c < col; ++c) { + int stride_idx = c / C4NUM * (row_4 / C4NUM) + r / C4NUM; + if (r >= row) { + dst[stride * stride_idx + c % C4NUM * C4NUM + r % C4NUM] = 0; + } else { + int src_idx = r * col + c; + dst[stride * stride_idx + c % C4NUM * C4NUM + r % C4NUM] = src[src_idx]; + } + } + } +} + +void RowMajor2Col4x16MajorPartInt8(const int8_t *src, int8_t *dst, int row, int col, int cur_oc) { + int row_4 = UP_ROUND(row, C4NUM); + int stride = C16NUM * C4NUM; + for (int r = 0; r < row_4; ++r) { + for (int c = 0; c < cur_oc; ++c) { + int stride_idx = c / C16NUM * (row_4 / C4NUM) + r / C4NUM; + if (r >= row) { + dst[stride * stride_idx + c % C16NUM * C4NUM + r % C4NUM] = 0; + } else { + int src_idx = r * col + c; + dst[stride * stride_idx + c % C16NUM * C4NUM + r % C4NUM] = src[src_idx]; + } + } + } +} + +void RowMajor2Col4x16MajorInt8(const int8_t *src, int8_t *dst, int row, int col) { + int row_4 = UP_ROUND(row, C4NUM); + int stride = C16NUM * C4NUM; + for (int r = 0; r < row_4; ++r) { + for (int c = 0; c < col; ++c) { + int stride_idx = c / C16NUM * (row_4 / C4NUM) + r / C4NUM; + if (r >= row) { + dst[stride * stride_idx + c % C16NUM * C4NUM + r % C4NUM] = 0; + } else { + int src_idx = r * col + c; + dst[stride * stride_idx + c % C16NUM * C4NUM + r % C4NUM] = src[src_idx]; + } + } + } +} + +void CalcInputSums(const int8_t *input, int row, int col, int weight_zp, int32_t *dst, DataOrder order) { + for (int r = 0; r < row; ++r) { + int sum = 0; + for (int c = 0; c < col; ++c) { + if (order == RowMajor) { + sum += input[r * col + c]; + } else { + sum += input[c * row + r]; + } + } + sum *= weight_zp; + dst[r] = sum; + } +} + +// dst: bias + depth*input_zp*weight_zp - input_zp*weight_col_sums +void CalcWeightBiasSums(const int8_t *weight, int row, int col, int input_zp, const int32_t *weight_zp_ptr, + const int32_t *bias, int32_t *dst, DataOrder order, bool filter_per_channel) { + for (int c = 0; c < col; ++c) { + int sum = 0; + for (int r = 0; r < row; ++r) { + if (order == RowMajor) { + sum += weight[r * col + c]; + } else { + sum += weight[c * row + r]; + } + } + int weight_zp = filter_per_channel ? weight_zp_ptr[c] : weight_zp_ptr[0]; + dst[c] = row * input_zp * weight_zp - input_zp * sum; + if (bias != NULL) { + dst[c] += bias[c]; + } + } +} + +void CalcPartWeightBiasSums(const int8_t *weight, int row, int stride, int cur_col, int input_zp, + const int32_t *weight_zp_ptr, const int32_t *bias, int32_t *dst, DataOrder order, + bool filter_per_channel) { + for (int c = 0; c < cur_col; ++c) { + int sum = 0; + for (int r = 0; r < row; ++r) { + if (order == RowMajor) { + sum += weight[r * stride + c]; + } else { + sum += weight[c * row + r]; + } + } + int weight_zp = filter_per_channel ? weight_zp_ptr[c] : weight_zp_ptr[0]; + dst[c] = row * input_zp * weight_zp - input_zp * sum; + if (bias != NULL) { + dst[c] += bias[c]; + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/matmul_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/matmul_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..844109012770417fcae92e84f00a285d8d7c18d2 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/matmul_int8.h @@ -0,0 +1,93 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_MATMUL_H_ +#define NNACL_INT8_MATMUL_H_ + +#include +#include "nnacl/op_base.h" +#include "nnacl/matmul_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +/* 4x16 16x4 -> 4x4 */ +/* sdot 4x4 4x16 -> 4x16 */ +/* matmul */ +void MatMulInt8_16x4(const int8_t *a, const int8_t *b, int32_t *dst, int row_4, int col_4, int deep_16, + const int32_t *input_sum, const int32_t *bias); +void RowMajor2Row16x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col); +void RowMajor2Col16x4MajorInt8(const int8_t *src, int8_t *dst, int row, int col); +void RowMajor2Col4x16MajorInt8(const int8_t *src, int8_t *dst, int row, int col); +void RowMajor2Col4x16MajorPartInt8(const int8_t *src, int8_t *dst, int row, int col, int cur_oc); +void PackInput2Col4x4AndInputSumPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, int row, + int col, int row_stride, int32_t filter_zp); +void CalcInputSums(const int8_t *input, int row, int col, int weight_zp, int32_t *dst, DataOrder order); +void CalcWeightBiasSums(const int8_t *weight, int row, int col, int input_zp, const int32_t *weight_zp_ptr, + const int32_t *bias, int32_t *dst, DataOrder order, bool filter_per_channel); +void CalcPartWeightBiasSums(const int8_t *weight, int row, int stride, int cur_col, int input_zp, + const int32_t *weight_zp_ptr, const int32_t *bias, int32_t *dst, DataOrder order, + bool filter_per_channel); +void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int32_t *a_sums, + const int32_t *bias, int act_min, int act_max, int out_zp, const int32_t *multiplier, + const int32_t *left_shift, const int32_t *right_shift, size_t stride, size_t filter_peroc, + const int32_t *filter_zp); +/* 8x4 4x8 -> 8x8 */ +/* optimize conv */ +void RowMajor2Row8x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col); +void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, + size_t stride, const int32_t *input_sum, const int32_t *bias, const int32_t *left_shift, + const int32_t *right_shift, const int32_t *multiplier, int32_t output_zp, int32_t mini, + int32_t maxi, size_t per_channel); + +/* 4x16 16x2 -> 4x2 */ +/* arm32 conv1x1 */ +void RowMajor2Row2x16MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col); +void RowMajor2Col16x2MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col); +void MatMulInt8_4x2_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_16, + size_t stride, const int32_t *input_sum, const int32_t *bias, const int32_t *left_shift, + const int32_t *right_shift, const int32_t *multiplier, int32_t output_zp, int32_t mini, + int32_t maxi, bool peroc); + +/* 4x4 4x16 -> 4x16 */ +/* optimize conv1x1 */ +void RowMajor2Row4x16MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col); +void PackInput4x4AndInputSumPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, + size_t input_channel, size_t plane_size, int32_t filter_zp); +void MatMulInt8_4x16_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, + size_t stride, const int32_t *input_sum, const int32_t *bias, const int32_t *left_shift, + const int32_t *right_shift, const int32_t *multiplier, int32_t output_zp, int32_t mini, + int32_t maxi, size_t per_channel, const int32_t *filter_zp); + +#ifdef ENABLE_ARM64 +void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, + const int32_t *a_sums, const int32_t *bias, int act_min, int act_max, int out_zp, + int32_t *multiplier, int32_t *left_shift, int32_t *right_shift, int row, int col, int stride, + int filter_peroc); + +void MatMulR4Int8Neon64(const int8_t *a, const int8_t *b, int32_t *dst, int row4, int col4, int deep16, + const int32_t *input_sum, const int32_t *bias); +#endif +#ifdef ENABLE_ARM32 +void MatmulInt8Neon32(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, + const int32_t *input_sums, const int32_t *weight_bias, int act_min, int act_max, int out_zp, + int32_t *multiplier, int32_t *left_shift, int32_t *right_shift, int stride, int per_channel); +#endif +#ifdef __cplusplus +} +#endif + +#endif // LITE_SRC_BACKEND_ARM_NNACL_INT8_MATMUL_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/mul_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/mul_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..8b0dd4a15077342ee4b4fcafb3ecc6ffeb0a2c2e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/mul_int8.c @@ -0,0 +1,238 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/mul_int8.h" + +#ifdef ENABLE_NEON +int16x4_t ClacSumHalfWordMul(int16x4_t scaled_input0, int16x4_t scaled_input1, int32x4_t left_shift_out_vec, + int32x4_t right_shift_out_vec, int32x4_t output_multiplier_vec) { + int32x4_t input_scale = vmull_s16(scaled_input0, scaled_input1); + int32x4_t raw_sum = vqrdmulhq_s32(vmulq_s32(input_scale, left_shift_out_vec), output_multiplier_vec); + const int32x4_t fixup = vshrq_n_s32(vandq_s32(raw_sum, right_shift_out_vec), 31); + const int32x4_t fixed_up_x = vqaddq_s32(raw_sum, fixup); + raw_sum = vrshlq_s32(fixed_up_x, right_shift_out_vec); + return vqmovn_s32(raw_sum); +} + +void MulInt8NEON(const int8_t *input0_data, const int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, + const MulQuantArg *quant_arg, int32_t *index) { + int32x4_t output_multiplier_vec = vdupq_n_s32(quant_arg->output_multiplier_); + int32x4_t left_shift_out_vec = vdupq_n_s32(1 << (size_t)quant_arg->shift_left_); + int32x4_t right_shift_out_vec = vdupq_n_s32(-quant_arg->shift_right_); + int16x8_t out_zp_vec = vdupq_n_s16(quant_arg->out_quant_arg_.zp_); + int8x16_t out_min_vec = vdupq_n_s8(quant_arg->output_activation_min_); + int8x16_t out_max_vec = vdupq_n_s8(quant_arg->output_activation_max_); + int8x8_t out_min_vec_s8 = vdup_n_s8(quant_arg->output_activation_min_); + int8x8_t out_max_vec_s8 = vdup_n_s8(quant_arg->output_activation_max_); + + for (; (*index) <= real_dst_count - 16; (*index) += 16) { + int16x8_t zp1_vec = vdupq_n_s16(quant_arg->in_quant_args_[0].zp_); + int16x8_t zp2_vec = vdupq_n_s16(quant_arg->in_quant_args_[1].zp_); + int8x16_t input0_vec = vld1q_s8(input0_data + *index); + int8x16_t input1_vec = vld1q_s8(input1_data + *index); + int16x8_t input0_low = vmovl_s8(vget_low_s8(input0_vec)); + int16x8_t input0_high = vmovl_s8(vget_high_s8(input0_vec)); + int16x8_t input1_low = vmovl_s8(vget_low_s8(input1_vec)); + int16x8_t input1_high = vmovl_s8(vget_high_s8(input1_vec)); + input0_low = vaddq_s16(input0_low, zp1_vec); + input0_high = vaddq_s16(input0_high, zp1_vec); + input1_low = vaddq_s16(input1_low, zp2_vec); + input1_high = vaddq_s16(input1_high, zp2_vec); + + int16x4_t input0_low_low = vget_low_s16(input0_low); + int16x4_t input0_low_high = vget_high_s16(input0_low); + int16x4_t input0_high_low = vget_low_s16(input0_high); + int16x4_t input0_high_high = vget_high_s16(input0_high); + int16x4_t input1_low_low = vget_low_s16(input1_low); + int16x4_t input1_low_high = vget_high_s16(input1_low); + int16x4_t input1_high_low = vget_low_s16(input1_high); + int16x4_t input1_high_high = vget_high_s16(input1_high); + + int16x4_t sum_low_low = ClacSumHalfWordMul(input0_low_low, input1_low_low, left_shift_out_vec, right_shift_out_vec, + output_multiplier_vec); + int16x4_t sum_low_high = ClacSumHalfWordMul(input0_low_high, input1_low_high, left_shift_out_vec, + right_shift_out_vec, output_multiplier_vec); + int16x4_t sum_high_low = ClacSumHalfWordMul(input0_high_low, input1_high_low, left_shift_out_vec, + right_shift_out_vec, output_multiplier_vec); + int16x4_t sum_high_high = ClacSumHalfWordMul(input0_high_high, input1_high_high, left_shift_out_vec, + right_shift_out_vec, output_multiplier_vec); + + int16x8_t res_s16 = vaddq_s16(vcombine_s16(sum_low_low, sum_low_high), out_zp_vec); + int16x8_t res_s162 = vaddq_s16(vcombine_s16(sum_high_low, sum_high_high), out_zp_vec); + int8x8_t res_u8_n0 = vqmovn_s16(res_s16); + int8x8_t res_u8_n1 = vqmovn_s16(res_s162); + int8x16_t res_s8 = vcombine_s8(res_u8_n0, res_u8_n1); + res_s8 = vminq_s8(res_s8, out_max_vec); + res_s8 = vmaxq_s8(res_s8, out_min_vec); + vst1q_s8(output_data, res_s8); + output_data += 16; + } + for (; (*index) <= real_dst_count - 8; (*index) += 8) { + int16x8_t input0_val = LoadAndAddOffset(input0_data, *index, quant_arg->in_quant_args_[0].zp_); + int16x8_t input1_val = LoadAndAddOffset(input1_data, *index, quant_arg->in_quant_args_[1].zp_); + + int16x4_t input0_low = vget_low_s16(input0_val); + int16x4_t input0_high = vget_high_s16(input0_val); + int16x4_t input1_low = vget_low_s16(input1_val); + int16x4_t input1_high = vget_high_s16(input1_val); + + int16x4_t sum_low = + ClacSumHalfWordMul(input0_low, input1_low, left_shift_out_vec, right_shift_out_vec, output_multiplier_vec); + int16x4_t sum_high = + ClacSumHalfWordMul(input0_high, input1_high, left_shift_out_vec, right_shift_out_vec, output_multiplier_vec); + + int16x8_t res_s16 = vaddq_s16(vcombine_s16(sum_low, sum_high), out_zp_vec); + int8x8_t res_u8_n0 = vqmovn_s16(res_s16); + res_u8_n0 = vmin_s8(res_u8_n0, out_max_vec_s8); + res_u8_n0 = vmax_s8(res_u8_n0, out_min_vec_s8); + vst1_s8(output_data, res_u8_n0); + output_data += 8; + } +} +#endif + +void FastMul(const int8_t *input0_data, const int8_t *input1_data, int8_t *output_data, int depth, + int64_t real_dst_count, bool input1_broad, const MulQuantArg *quant_arg) { + // input0 need broadcast + int32_t zp1 = quant_arg->in_quant_args_[0].zp_; + int32_t zp2 = quant_arg->in_quant_args_[1].zp_; + if (input1_broad) { + zp1 = quant_arg->in_quant_args_[1].zp_; + zp2 = quant_arg->in_quant_args_[0].zp_; + } +#ifdef ENABLE_NENO + int32x4_t output_multiplier_vec = vdupq_n_s32(quant_arg->output_multiplier_); + int32x4_t left_shift_out_vec = vdupq_n_s32(1 << (size_t)quant_arg->shift_left_); + int32x4_t right_shift_out_vec = vdupq_n_s32(-quant_arg->shift_right_); + int16x8_t out_zp_vec = vdupq_n_s16(quant_arg->out_quant_arg_.zp_); + int8x16_t out_min_vec = vdupq_n_s8(quant_arg->output_activation_min_); + int8x16_t out_max_vec = vdupq_n_s8(quant_arg->output_activation_max_); + int8x8_t out_min_vec_s8 = vdup_n_s8(quant_arg->output_activation_min_); + int8x8_t out_max_vec_s8 = vdup_n_s8(quant_arg->output_activation_max_); + int16x8_t zp1_vec = vdupq_n_s16(zp1); + int16x8_t zp2_vec = vdupq_n_s16(zp2); +#endif + for (int index = 0; index < real_dst_count; ++index) { + int j = 0; +#ifdef ENABLE_NENO + for (; j <= depth - 16; j += 16) { + int8x16_t input0_vec = vld1q_s8(input0_data + j); + int8x16_t input1_vec = vld1q_s8(input1_data); + int16x8_t input0_low = vmovl_s8(vget_low_s8(input0_vec)); + int16x8_t input0_high = vmovl_s8(vget_high_s8(input0_vec)); + int16x8_t input1_low = vmovl_s8(vget_low_s8(input1_vec)); + int16x8_t input1_high = vmovl_s8(vget_high_s8(input1_vec)); + input0_low = vaddq_s16(input0_low, zp1_vec); + input0_high = vaddq_s16(input0_high, zp1_vec); + input1_low = vaddq_s16(input1_low, zp2_vec); + input1_high = vaddq_s16(input1_high, zp2_vec); + + int16x4_t input0_low_low = vget_low_s16(input0_low); + int16x4_t input0_low_high = vget_high_s16(input0_low); + int16x4_t input0_high_low = vget_low_s16(input0_high); + int16x4_t input0_high_high = vget_high_s16(input0_high); + int16x4_t input1_low_low = vget_low_s16(input1_low); + int16x4_t input1_low_high = vget_high_s16(input1_low); + int16x4_t input1_high_low = vget_low_s16(input1_high); + int16x4_t input1_high_high = vget_high_s16(input1_high); + + int16x4_t sum_low_low = ClacSumHalfWordMul(input0_low_low, input1_low_low, left_shift_out_vec, + right_shift_out_vec, output_multiplier_vec); + int16x4_t sum_low_high = ClacSumHalfWordMul(input0_low_high, input1_low_high, left_shift_out_vec, + right_shift_out_vec, output_multiplier_vec); + int16x4_t sum_high_low = ClacSumHalfWordMul(input0_high_low, input1_high_low, left_shift_out_vec, + right_shift_out_vec, output_multiplier_vec); + int16x4_t sum_high_high = ClacSumHalfWordMul(input0_high_high, input1_high_high, left_shift_out_vec, + right_shift_out_vec, output_multiplier_vec); + + int16x8_t res_s16 = vaddq_s16(vcombine_s16(sum_low_low, sum_low_high), out_zp_vec); + int16x8_t res_s162 = vaddq_s16(vcombine_s16(sum_high_low, sum_high_high), out_zp_vec); + int8x8_t res_u8_n0 = vqmovn_s16(res_s16); + int8x8_t res_u8_n1 = vqmovn_s16(res_s162); + int8x16_t res_s8 = vcombine_s8(res_u8_n0, res_u8_n1); + res_s8 = vminq_s8(res_s8, out_max_vec); + res_s8 = vmaxq_s8(res_s8, out_min_vec); + vst1q_s8(output_data, res_s8); + input1_data += 16; + output_data += 16; + } + for (; j <= depth - 8; j += 8) { + int8x8_t input0_vec = vld1_s8(input0_data + j); + int8x8_t input1_vec = vld1_s8(input1_data); + int16x8_t input0_val = vmovl_s8(input0_vec); + int16x8_t input1_val = vmovl_s8(input1_vec); + input0_val = vaddq_s16(input0_val, zp1_vec); + input1_val = vaddq_s16(input1_val, zp2_vec); + + int16x4_t input0_low = vget_low_s16(input0_val); + int16x4_t input0_high = vget_high_s16(input0_val); + int16x4_t input1_low = vget_low_s16(input1_val); + int16x4_t input1_high = vget_high_s16(input1_val); + + int16x4_t sum_low = + ClacSumHalfWordMul(input0_low, input1_low, left_shift_out_vec, right_shift_out_vec, output_multiplier_vec); + int16x4_t sum_high = + ClacSumHalfWordMul(input0_high, input1_high, left_shift_out_vec, right_shift_out_vec, output_multiplier_vec); + + int16x8_t res_s16 = vaddq_s16(vcombine_s16(sum_low, sum_high), out_zp_vec); + int8x8_t res_u8_n0 = vqmovn_s16(res_s16); + res_u8_n0 = vmin_s8(res_u8_n0, out_max_vec_s8); + res_u8_n0 = vmax_s8(res_u8_n0, out_min_vec_s8); + vst1_s8(output_data, res_u8_n0); + input1_data += 8; + output_data += 8; + } +#endif + for (; j < depth; ++j) { + const int32_t input0_val = zp1 + input0_data[j]; + const int32_t input1_val = zp2 + input1_data[0]; + int32_t mul_result = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(input0_val * input1_val * (1 << (size_t)quant_arg->shift_left_), + quant_arg->output_multiplier_), + quant_arg->shift_right_); + + mul_result += quant_arg->out_quant_arg_.zp_; + mul_result = mul_result < quant_arg->output_activation_max_ ? mul_result : quant_arg->output_activation_max_; + mul_result = mul_result > quant_arg->output_activation_min_ ? mul_result : quant_arg->output_activation_min_; + output_data[0] = (int8_t)mul_result; + input1_data++; + output_data++; + } + } + return; +} + +void Mul(const int8_t *input0_data, const int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, + const MulQuantArg *quant_arg) { + int index = 0; +#ifdef ENABLE_NEON + MulInt8NEON(input0_data, input1_data, output_data, real_dst_count, quant_arg, &index); +#endif + for (; index < real_dst_count; ++index) { + const int32_t input0_val = quant_arg->in_quant_args_[0].zp_ + input0_data[index]; + const int32_t input1_val = quant_arg->in_quant_args_[1].zp_ + input1_data[index]; + int32_t mul_result = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(input0_val * input1_val * (1 << (size_t)quant_arg->shift_left_), + quant_arg->output_multiplier_), + quant_arg->shift_right_); + + mul_result += quant_arg->out_quant_arg_.zp_; + mul_result = mul_result < quant_arg->output_activation_max_ ? mul_result : quant_arg->output_activation_max_; + mul_result = mul_result > quant_arg->output_activation_min_ ? mul_result : quant_arg->output_activation_min_; + output_data[index] = (int8_t)mul_result; + } + return; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/mul_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/mul_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..d5d0b5b596c9b03ac3d054e9e8adea5c99451391 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/mul_int8.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_MUL_INT8_H_ +#define NNACL_INT8_MUL_INT8_H_ + +#include "nnacl/op_base.h" +#include "nnacl/mul_parameter.h" +#include "nnacl/int8/common_func_int8.h" +#include "nnacl/int8/fixed_point.h" +#ifdef ENABLE_NEON +#include +#endif + +#ifdef __cplusplus +extern "C" { +#endif +void Mul(const int8_t *input0_data, const int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, + const MulQuantArg *quant_arg); +void FastMul(const int8_t *input0_data, const int8_t *input1_data, int8_t *output_data, int depth, + int64_t real_dst_count, bool input1_broad, const MulQuantArg *quant_arg); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_MUL_INT8_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/pack_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/pack_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..2032cff6ccbfe501eb9468e495cc77a0359c3173 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/pack_int8.c @@ -0,0 +1,454 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/pack_int8.h" + +void PackInputToC8Int8(const int8_t *input_data, int16_t *packed_input, const ConvParameter *conv_param) { + int in_batch = conv_param->input_batch_; + int in_channel = conv_param->input_channel_; + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + int ic8_round = UP_ROUND(in_channel, C8NUM); + int ic8 = in_channel / C8NUM * C8NUM; + int in_plane = in_h * in_w; + + for (int b = 0; b < in_batch; b++) { + int src_batch_offset = b * in_channel * in_plane; + int dst_batch_offset = b * ic8_round * in_plane; + for (int k = 0; k < in_plane; k++) { + int src_plane_offset = src_batch_offset + k * in_channel; + int dst_plane_offset = dst_batch_offset + k * C8NUM; + for (int i = 0; i < ic8; i += 8) { + int src_c_offset = src_plane_offset + i; + int dst_c_offset = dst_plane_offset + i * in_plane; +#ifdef ENABLE_ARM + vst1q_s16(packed_input + dst_c_offset, vmovl_s8(vld1_s8(input_data + src_c_offset))); +#else + for (int j = 0; j < C8NUM; ++j) { + (packed_input + dst_c_offset)[j] = (int16_t)(input_data + src_c_offset)[j]; + } +#endif + } // ic8_minus loop + int res_c = in_channel - ic8; + int tmp_ic_offset = ic8 * in_plane; + for (int l = 0; l < res_c; ++l) { + int src_c_offset = src_plane_offset + ic8 + l; + int dst_c_offset = dst_plane_offset + tmp_ic_offset + l; + (packed_input + dst_c_offset)[0] = (int16_t)(input_data + src_c_offset)[0]; + } // res ic loop + int res2 = ic8_round - in_channel; + for (int l = 0; l < res2; ++l) { + int dst_c_offset = dst_plane_offset + tmp_ic_offset + res_c + l; + (packed_input + dst_c_offset)[0] = 0; + } // res ic loop + } // kh * kw loop + } +} + +void PackWeightToC8Int8(const int8_t *origin_weight_data, int16_t *packed_weight_data, + const ConvParameter *conv_param) { + // origin weight format : ohwi + int input_channel = conv_param->input_channel_; + int ic8 = input_channel / C8NUM * C8NUM; + int ic8_round = UP_ROUND(input_channel, C8NUM); + int output_channel = conv_param->output_channel_; + QuantArg *filter_zp = conv_param->conv_quant_arg_.filter_quant_args_; + int kernel_plane = conv_param->kernel_h_ * conv_param->kernel_w_; + + for (int k = 0; k < kernel_plane; k++) { + int src_kernel_offset = k * input_channel; + int dst_kernel_offset = k * C8NUM; + for (int o = 0; o < output_channel; o++) { + int32_t zp; + if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) { + zp = filter_zp[0].zp_; + } else { + zp = filter_zp[o].zp_; + } + int src_oc_offset = src_kernel_offset + o * kernel_plane * input_channel; + int dst_oc_offset = dst_kernel_offset + o * ic8_round * kernel_plane; + int i = 0; + for (; i < ic8; i += C8NUM) { + int src_ic_offset = src_oc_offset + i; + int dst_ic_offset = dst_oc_offset + i * kernel_plane; +#ifdef ENABLE_ARM64 + int8x8_t src_s8 = vld1_s8(origin_weight_data + src_ic_offset); + int16x8_t src_s16 = vmovl_s8(src_s8); + int16x4_t src1_s16 = vget_low_s16(src_s16); + int16x4_t src2_s16 = vget_high_s16(src_s16); + int32x4_t src1_s32 = vmovl_s16(src1_s16); + int32x4_t src2_s32 = vmovl_s16(src2_s16); + int32x4_t zp_s32 = vdupq_n_s32(zp); + int32x4_t dst1_s32 = vsubq_s32(src1_s32, zp_s32); + int32x4_t dst2_s32 = vsubq_s32(src2_s32, zp_s32); + int16x4_t dst1_s16 = vqmovn_s32(dst1_s32); + int16x4_t dst2_s16 = vqmovn_s32(dst2_s32); + vst1_s16(packed_weight_data + dst_ic_offset, dst1_s16); + vst1_s16(packed_weight_data + dst_ic_offset + 4, dst2_s16); +#else + for (int ci = 0; ci < C8NUM; ++ci) { + (packed_weight_data + dst_ic_offset + ci)[0] = (int16_t)((origin_weight_data + src_ic_offset + ci)[0] - zp); + } +#endif + } + dst_oc_offset += ic8 * kernel_plane; + for (; i < input_channel; i++) { + int c8_block_rem = i % C8NUM; + int src_ic_offset = src_oc_offset + i; + int dst_ic_offset = dst_oc_offset + c8_block_rem; + (packed_weight_data + dst_ic_offset)[0] = (int16_t)((origin_weight_data + src_ic_offset)[0] - zp); + } + } + } +} + +void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16) { + /* normal matmul : 4x16 * 16x4 -> 4x4 */ +#ifdef ENABLE_ARM + PreSum4x16Int8Pert(src, dst, row4, col16, filter_zp); +#else + for (size_t r = 0; r < row4; r++) { + int32_t tmp_value = 0; + for (size_t c = 0; c < col16; c++) { + int r4div = r / C4NUM, r4mod = r % C4NUM, c16div = c / C16NUM, c16mod = c % C16NUM; + int src_index = r4div * C4NUM * col16 + c16div * C16NUM * C4NUM + r4mod * C16NUM + c16mod; + tmp_value += src[src_index]; + } + dst[r] = tmp_value * filter_zp; + } +#endif + return; +} +void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter *conv_param) { + int input_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_; + int ic4 = UP_DIV(conv_param->input_channel_, C4NUM); + int unit = conv_param->input_h_ * conv_param->input_w_; + + for (int b = 0; b < conv_param->input_batch_; b++) { + const int8_t *src_b = src + b * unit * conv_param->input_channel_; + int16_t *dst_b = dst + b * unit * ic4 * C4NUM; + for (int k = 0; k < unit; k++) { + const int8_t *src_k = src_b + k * conv_param->input_channel_; + int16_t *dst_k = dst_b + k * ic4 * C4NUM; + for (int c = 0; c < conv_param->input_channel_; c++) { + dst_k[c] = (int16_t)(src_k[c] - input_zp); + } + } + } +} + +void PackDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel, + const ConvQuantArg *quant_qrg) { + int weight_zp = quant_qrg->filter_quant_args_[0].zp_; + for (int c = 0; c < channel; c++) { + if (quant_qrg->per_channel_ & FILTER_PER_CHANNEL) { + weight_zp = quant_qrg->filter_quant_args_[c].zp_; + } + int c8_block_num = c / C8NUM; + int c8_block_rem = c % C8NUM; + const int8_t *src_c = origin_weight + c * plane; + int16_t *dst_c = packed_weight_ + c8_block_num * plane * C8NUM; + for (int k = 0; k < plane; k++) { + const int8_t *src_kernel = src_c + k; + int16_t *dst_kernel = dst_c + C8NUM * k + c8_block_rem; + *dst_kernel = (int16_t)(src_kernel[0] - weight_zp); + } + } +} + +void PackDeconvDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel, + const ConvQuantArg *quant_qrg) { + int weight_zp = quant_qrg->filter_quant_args_[0].zp_; + for (int c = 0; c < channel; c++) { + if (quant_qrg->per_channel_ & FILTER_PER_CHANNEL) { + weight_zp = quant_qrg->filter_quant_args_[c].zp_; + } + int c4_block_num = c / C4NUM; + int c4_block_rem = c % C4NUM; + const int8_t *src_c = origin_weight + c * plane; + int16_t *dst_c = packed_weight_ + c4_block_num * plane * C4NUM; + for (int k = 0; k < plane; k++) { + const int8_t *src_kernel = src_c + k; + int16_t *dst_kernel = dst_c + C4NUM * k + c4_block_rem; + *dst_kernel = (int16_t)(src_kernel[0] - weight_zp); + } + } +} +void PackNHWCToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + int c4_channel = c4 * C4NUM; + int nhwc4_batch_unit_offset = c4 * C4NUM * plane; + int ic_remainder_ = channel % C4NUM; + if (ic_remainder_ != 0) { + int nhwc4_batch_offset = 0; + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + for (int i = 0; i < plane; i++) { + int8_t *dst_per_plane = (int8_t *)dst + nhwc4_batch_offset + i * c4_channel; + memcpy(dst_per_plane, (int8_t *)src + batch_offset + i * channel, channel); + for (int j = channel; j < c4_channel; ++j) { + dst_per_plane[j] = 0; + } + } + nhwc4_batch_offset += nhwc4_batch_unit_offset; + } + } else { + size_t ori_input_size = batch * plane * channel; + memcpy((int8_t *)dst, (int8_t *)src, ori_input_size); + } +} + +void PackNHWC4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + int nhwc4_batch_unit_offset = c4 * C4NUM * plane; + int ic_remainder_ = channel % C4NUM; + if (ic_remainder_ != 0) { + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + int nhwc4_batch_offset = b * nhwc4_batch_unit_offset; + for (int i = 0; i < plane; i++) { + memcpy((int8_t *)dst + batch_offset + i * channel, (int8_t *)src + nhwc4_batch_offset + i * c4 * C4NUM, + channel); + } + } + } else { + size_t ori_input_size = batch * plane * channel; + memcpy((int8_t *)dst, (int8_t *)src, ori_input_size); + } +} + +void PackNHWCToNHWC8Int8(const void *src, void *dst, int batch, int plane, int channel) { + int c8 = UP_DIV(channel, C8NUM); + int nhwc8_batch_unit_offset = c8 * C8NUM * plane; + int ic_remainder_ = channel % C8NUM; + if (ic_remainder_ != 0) { + int nhwc8_batch_offset = 0; + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + for (int i = 0; i < plane; i++) { + memcpy((int8_t *)dst + nhwc8_batch_offset + i * c8 * C8NUM, (int8_t *)src + batch_offset + i * channel, + channel); + } + nhwc8_batch_offset += nhwc8_batch_unit_offset; + } + } else { + size_t ori_input_size = batch * plane * channel; + memcpy((int8_t *)dst, (int8_t *)src, ori_input_size); + } +} + +void PackNHWC8ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) { + int c8 = UP_DIV(channel, C8NUM); + int nhwc8_batch_unit_offset = c8 * C8NUM * plane; + int ic_remainder_ = channel % C8NUM; + if (ic_remainder_ != 0) { + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + int nhwc8_batch_offset = b * nhwc8_batch_unit_offset; + for (int i = 0; i < plane; i++) { + memcpy((int8_t *)dst + batch_offset + i * channel, (int8_t *)src + nhwc8_batch_offset + i * c8 * C8NUM, + channel); + } + } + } else { + size_t ori_input_size = batch * plane * channel; + memcpy((int8_t *)dst, (int8_t *)src, ori_input_size); + } +} + +void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) { + int c4 = UP_DIV(channel, C4NUM); + for (int b = 0; b < batch; b++) { + int src_offset = b * plane * c4 * C4NUM; + int dst_offset = b * plane * channel; + for (int k = 0; k < plane; k++) { + int src_kernel_offset = src_offset + k * C4NUM; + int dst_kernel_offset = dst_offset + k * channel; + for (int c = 0; c < c4 - 1; c++) { + int src_c_offset = src_kernel_offset + c * plane * C4NUM; + int dst_c_offset = dst_kernel_offset + c * C4NUM; + ((int8_t *)dst + dst_c_offset)[0] = ((int8_t *)src + src_c_offset)[0]; + ((int8_t *)dst + dst_c_offset)[1] = ((int8_t *)src + src_c_offset)[1]; + ((int8_t *)dst + dst_c_offset)[2] = ((int8_t *)src + src_c_offset)[2]; + ((int8_t *)dst + dst_c_offset)[3] = ((int8_t *)src + src_c_offset)[3]; + } + // res part + int res_c = channel - (c4 - 1) * C4NUM; + for (int i = 0; i < res_c; i++) { + int src_res_c_offset = src_kernel_offset + (c4 - 1) * C4NUM * plane + i; + int dst_res_c_offset = dst_kernel_offset + (c4 - 1) * C4NUM + i; + ((int8_t *)dst + dst_res_c_offset)[0] = ((int8_t *)src + src_res_c_offset)[0]; + } + } + } +} + +void PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) { + for (int n = 0; n < batch; n++) { + for (int c = 0; c < channel; c++) { + for (int hw = 0; hw < plane; hw++) { + int nhwc_index = n * channel * plane + hw * channel + c; + int nchw_index = n * channel * plane + c * plane + hw; + ((int8_t *)(dst))[nhwc_index] = ((const int8_t *)(src))[nchw_index]; + } + } + } + return; +} + +void PackNHWCToNCHWInt8(const void *src, void *dst, int batches, int plane, int channel) { + int hw8 = plane / C8NUM * C8NUM; + int c8 = channel / C8NUM * C8NUM; + int batch = plane * channel; + for (int n = 0; n < batches; n++) { + const int8_t *src_batch = (const int8_t *)src + n * batch; + int8_t *dst_batch = (int8_t *)dst + n * batch; + int hw = 0; + for (; hw < hw8; hw += C8NUM) { + int c = 0; + for (; c < c8; c += C8NUM) { + const int8_t *src_ptr = src_batch + hw * channel + c; + int8_t *dst_ptr = dst_batch + c * plane + hw; +#ifdef ENABLE_ARM64 + size_t srcStride = channel * sizeof(int8_t); + size_t dstStride = plane * sizeof(int8_t); + asm volatile( + "mov x10, %[src_ptr]\n" + "mov x11, %[dst_ptr]\n" + + "ld1 {v0.8b}, [x10], %[srcStride]\n" + "ld1 {v1.8b}, [x10], %[srcStride]\n" + "ld1 {v2.8b}, [x10], %[srcStride]\n" + "ld1 {v3.8b}, [x10], %[srcStride]\n" + + "trn1 v4.8b, v0.8b, v1.8b\n" + "trn2 v5.8b, v0.8b, v1.8b\n" + "trn1 v6.8b, v2.8b, v3.8b\n" + "trn2 v7.8b, v2.8b, v3.8b\n" + + "ld1 {v0.8b}, [x10], %[srcStride]\n" + "ld1 {v1.8b}, [x10], %[srcStride]\n" + "ld1 {v2.8b}, [x10], %[srcStride]\n" + "ld1 {v3.8b}, [x10], %[srcStride]\n" + + "trn1 v8.4h, v4.4h, v6.4h\n" + "trn2 v9.4h, v4.4h, v6.4h\n" + "trn1 v10.4h, v5.4h, v7.4h\n" + "trn2 v11.4h, v5.4h, v7.4h\n" + + "trn1 v4.8b, v0.8b, v1.8b\n" + "trn2 v5.8b, v0.8b, v1.8b\n" + "trn1 v6.8b, v2.8b, v3.8b\n" + "trn2 v7.8b, v2.8b, v3.8b\n" + + "trn1 v12.4h, v4.4h, v6.4h\n" + "trn2 v13.4h, v4.4h, v6.4h\n" + "trn1 v14.4h, v5.4h, v7.4h\n" + "trn2 v15.4h, v5.4h, v7.4h\n" + + "trn1 v0.2s, v8.2s, v12.2s\n" + "trn2 v4.2s, v8.2s, v12.2s\n" + "trn1 v1.2s, v10.2s, v14.2s\n" + "trn2 v5.2s, v10.2s, v14.2s\n" + "trn1 v2.2s, v9.2s, v13.2s\n" + "trn2 v6.2s, v9.2s, v13.2s\n" + "trn1 v3.2s, v11.2s, v15.2s\n" + "trn2 v7.2s, v11.2s, v15.2s\n" + + "st1 {v0.8b}, [x11], %[dstStride]\n" + "st1 {v1.8b}, [x11], %[dstStride]\n" + "st1 {v2.8b}, [x11], %[dstStride]\n" + "st1 {v3.8b}, [x11], %[dstStride]\n" + "st1 {v4.8b}, [x11], %[dstStride]\n" + "st1 {v5.8b}, [x11], %[dstStride]\n" + "st1 {v6.8b}, [x11], %[dstStride]\n" + "st1 {v7.8b}, [x11], %[dstStride]\n" + : + : + [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31"); +#elif ENABLE_ARM32 + size_t srcStride = channel * sizeof(int8_t); + size_t dstStride = plane * sizeof(int8_t); + asm volatile( + "mov r10, %[src_ptr]\n" + "mov r12, %[dst_ptr]\n" + + "vld1.8 {d0}, [r10], %[srcStride]\n" + "vld1.8 {d1}, [r10], %[srcStride]\n" + "vld1.8 {d2}, [r10], %[srcStride]\n" + "vld1.8 {d3}, [r10], %[srcStride]\n" + "vld1.8 {d4}, [r10], %[srcStride]\n" + "vld1.8 {d5}, [r10], %[srcStride]\n" + "vld1.8 {d6}, [r10], %[srcStride]\n" + "vld1.8 {d7}, [r10], %[srcStride]\n" + + "vtrn.8 d0, d1\n" + "vtrn.8 d2, d3\n" + "vtrn.8 d4, d5\n" + "vtrn.8 d6, d7\n" + + "vtrn.16 d0, d2\n" + "vtrn.16 d1, d3\n" + "vtrn.16 d4, d6\n" + "vtrn.16 d5, d7\n" + + "vtrn.32 d0, d4\n" + "vtrn.32 d1, d5\n" + "vtrn.32 d2, d6\n" + "vtrn.32 d3, d7\n" + + "vst1.8 {d0}, [r12], %[dstStride]\n" + "vst1.8 {d1}, [r12], %[dstStride]\n" + "vst1.8 {d2}, [r12], %[dstStride]\n" + "vst1.8 {d3}, [r12], %[dstStride]\n" + "vst1.8 {d4}, [r12], %[dstStride]\n" + "vst1.8 {d5}, [r12], %[dstStride]\n" + "vst1.8 {d6}, [r12], %[dstStride]\n" + "vst1.8 {d7}, [r12], %[dstStride]\n" + : + : + [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride) + : "r10", "r12", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15"); +#else + for (int tr = 0; tr < C8NUM; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + dst_ptr[tc * plane + tr] = src_ptr[tr * channel + tc]; + } + } +#endif + } + for (; c < channel; c++) { + const int8_t *src_ptr = src_batch + hw * channel + c; + int8_t *dst_ptr = dst_batch + c * plane + hw; + for (size_t i = 0; i < C8NUM; i++) { + dst_ptr[i] = src_ptr[i * channel]; + } + } + } + for (; hw < plane; hw++) { + const int8_t *src_ptr = src_batch + hw * channel; + int8_t *dst_ptr = dst_batch + hw; + for (size_t i = 0; i < channel; i++) { + dst_ptr[i * plane] = src_ptr[i]; + } + } + } + return; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/pack_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/pack_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..1542c5797fca33a28c1008886cd9295380abd1c9 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/pack_int8.h @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_PACK_INT8_H_ +#define NNACL_INT8_PACK_INT8_H_ + +#include +#include "nnacl/op_base.h" +#include "nnacl/int8/matmul_int8.h" +#include "nnacl/conv_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void PackNHWCToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel); +void PackNHWC4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel); +void PackNHWCToNHWC8Int8(const void *src, void *dst, int batch, int plane, int channel); +void PackNHWC8ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel); +void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel); +void PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel); +void PackNHWCToNCHWInt8(const void *src, void *dst, int batch, int plane, int channel); + +void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16); +void PackInputToC8Int8(const int8_t *input_data, int16_t *packed_input, const ConvParameter *conv_param); +void PackWeightToC8Int8(const int8_t *origin_weight_data, int16_t *packed_weight_data, const ConvParameter *conv_param); +#ifdef ENABLE_ARM +void PreSum4x16Int8Pert(const int8_t *src, int32_t *sum, size_t row4, size_t col16, int32_t filter_zp); +void PreSum4x16Int8Peroc(const int8_t *src, int32_t *sum, const int32_t *zp, size_t hw4, size_t ic16, int32_t oc_div, + size_t oc_res, size_t stride); +#endif + +void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter *conv_param); +void PackDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel, + const ConvQuantArg *quant_qrg); +void PackDeconvDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel, + const ConvQuantArg *quant_qrg); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_PAD_INT8_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/pad_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/pad_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..d33f36cf77ee6fae14b2725a45e0b315bb2d014a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/pad_int8.c @@ -0,0 +1,75 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/pad_int8.h" +#include "nnacl/common_func.h" +#include "nnacl/errorcode.h" + +int PadConstant4D(const int8_t *in_data, int8_t *out_data, const int32_t *in_dims, const int32_t *out_dims, + const int32_t *paddings, const int tid, const int thread_num) { + if (thread_num == 0) { + return NNACL_ERR; + } + int32_t copy_size = in_dims[3]; + for (int n = 0; n < in_dims[0]; n++) { + for (int h = tid; h < in_dims[1]; h += thread_num) { + for (int w = 0; w < in_dims[2]; w++) { + const int8_t *in = in_data + Offset(in_dims, n, h, w, 0); + int8_t *out = out_data + Offset(out_dims, n + paddings[0], h + paddings[2], w + paddings[4], paddings[6]); + memcpy(out, in, (size_t)copy_size * sizeof(int8_t)); + } + } + } + return NNACL_OK; +} + +int TransOut2InputDimIndexInt8(int out_dim_index, int left_pad, int in_dim, int offset) { + if (out_dim_index < left_pad) { + // left pad + const int index_sum = left_pad + offset - 1; + return MSMAX(index_sum - out_dim_index, offset); + } + out_dim_index -= left_pad; + if (out_dim_index < in_dim) { + return out_dim_index; + } + // right pad + out_dim_index -= in_dim; + const int index_sum = in_dim - 1 - offset; + return MSMAX(index_sum - out_dim_index, 0); +} + +int GetInputFlattenIndexInt8(int out_flatten_index, const int32_t *input_shape, int mirror_offset, + const int *in_strides, const int *out_strides, const int *paddings) { + int in_flatten_index = 0; + int i; + for (i = 0; i < COMM_SHAPE_SIZE; ++i) { + int left_pad = paddings[i * 2]; + NNACL_CHECK_ZERO_RETURN_ERR(out_strides[i]); + int out_dim_index = out_flatten_index / out_strides[i]; + out_flatten_index %= out_strides[i]; + int in_dim_index = TransOut2InputDimIndexInt8(out_dim_index, left_pad, input_shape[i], mirror_offset); + in_flatten_index += in_dim_index * in_strides[i]; + } + return in_flatten_index; +} + +void MirrorPadInt8(const int8_t *in, int8_t *out, const int32_t *input_shape, int mirror_offset, const int *in_strides, + const int *out_strides, const int *paddings, int begin, int end) { + for (int i = begin; i < end; ++i) { + out[i] = in[GetInputFlattenIndexInt8(i, input_shape, mirror_offset, in_strides, out_strides, paddings)]; + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/pad_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/pad_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..9d19e6176768194ba41ee590cd33d5135aa4a2c0 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/pad_int8.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_PAD_INT8_H_ +#define NNACL_INT8_PAD_INT8_H_ + +#include +#include "nnacl/op_base.h" +#include "nnacl/pad_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int PadConstant4D(const int8_t *in_data, int8_t *out_data, const int32_t *in_dims, const int32_t *out_dims, + const int32_t *paddings, const int tid, const int thread_num); +void MirrorPadInt8(const int8_t *in, int8_t *out, const int32_t *input_shape, int mirror_offset, const int *in_strides, + const int *out_strides, const int *paddings, int begin, int end); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_PAD_INT8_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/pooling_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/pooling_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..915708d4cafaab3ac7ffba969486d9df92a75a8d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/pooling_int8.c @@ -0,0 +1,516 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/pooling_int8.h" +#include "nnacl/common_func.h" +#include "nnacl/errorcode.h" + +int AvgPoolingInt8(const int8_t *input_ptr, int8_t *output_ptr, const PoolingParameter *pooling_param, + PoolingComputeParam *compute_args, QuantArg **quant_args) { + int stride_w = pooling_param->stride_w_; + int stride_h = pooling_param->stride_h_; + int pad_w = pooling_param->pad_l_; + int pad_h = pooling_param->pad_u_; + int win_w = pooling_param->window_w_; + int win_h = pooling_param->window_h_; + int channel = compute_args->input_channel_; + int in_w = compute_args->input_w_; + int in_h = compute_args->input_h_; + int output_w = compute_args->output_w_; + int output_h = compute_args->output_h_; + int output_batch = compute_args->output_batch_; + int out_plane = output_w * output_h; + float input_scale = quant_args[0][0].scale_; + int input_zp = quant_args[0][0].zp_; + float output_scale = quant_args[1][0].scale_; + int output_zp = quant_args[1][0].zp_; + double real_multiplier = input_scale / output_scale; + const int8_t out_min = INT8_MIN; + const int8_t out_max = INT8_MAX; + + for (int batch = 0; batch < output_batch; batch++) { + int in_batch_offset = batch * in_h * in_w * channel; + int out_batch_offset = batch * output_h * output_w * channel; + for (int i = 0; i < out_plane; i++) { + int out_w_index = i % output_w; + int out_h_index = i / output_w; + int in_w_index = out_w_index * stride_w - pad_w; + int in_h_index = out_h_index * stride_h - pad_h; + int out_plane_offset = out_batch_offset + i * channel; + for (int j = 0; j < channel; j++) { + int in_channel_offset = in_batch_offset + j; + int out_channel_offset = out_plane_offset + j; + int16_t tmp_avg = 0; + int real_count = 0; + for (int h = 0; h < win_h; h++) { + for (int w = 0; w < win_w; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; + tmp_avg += *(input_ptr + in_offset); + ++real_count; + } + } // win_w loop + } // win_h loop + if (real_count == 0) { + return NNACL_ERR; + } + int16_t tmp_out = round((float)tmp_avg / (float)real_count); + tmp_out = (int8_t)(round((tmp_out - input_zp) * real_multiplier) + output_zp); + int8_t real_out = tmp_out < out_min ? out_min : tmp_out; + real_out = real_out > out_max ? out_max : real_out; + *(output_ptr + out_channel_offset) = real_out; + } // in_channel loop + } // out_plane loop + } // out_batch loop + return NNACL_OK; +} + +int AvgPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, const PoolingParameter *pooling_param, + PoolingComputeParam *compute_args, QuantArg **quant_args, int task_id, int thread_num) { + int win_w = pooling_param->window_w_; + int win_h = pooling_param->window_h_; + int channel = compute_args->input_channel_; + int c16 = channel / C16NUM; + int in_w = compute_args->input_w_; + int output_w = compute_args->output_w_; + int out_plane = output_w * compute_args->output_h_; + int out_tile_count = UP_DIV(out_plane, TILE_NUM); + thread_num = out_tile_count < thread_num ? out_tile_count : thread_num; + int input_zp = quant_args[0][0].zp_; + int output_zp = quant_args[1][0].zp_; + double real_multiplier = quant_args[0][0].scale_ / quant_args[1][0].scale_; + const int8_t out_min = INT8_MIN; + const int8_t out_max = INT8_MAX; + NNACL_CHECK_ZERO_RETURN_ERR(output_w); + for (int batch = 0; batch < compute_args->output_batch_; batch++) { + int in_batch_offset = batch * compute_args->input_h_ * in_w * channel; + int out_batch_offset = batch * compute_args->output_h_ * output_w * channel; + for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { + int cal_start_index = thread_id * TILE_NUM; + int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index); + for (int i = 0; i < real_cal_num; i++) { + int index = cal_start_index + i; + int out_w_index = index % output_w; + int out_h_index = index / output_w; + int in_w_index = out_w_index * pooling_param->stride_w_ - pooling_param->pad_l_; + int in_h_index = out_h_index * pooling_param->stride_h_ - pooling_param->pad_u_; + int out_plane_offset = out_batch_offset + index * channel; + int input_stride = (in_h_index * in_w + in_w_index) * channel; + int kw_s = MSMAX(0, -in_w_index); + int kw_e = MSMIN(win_w, in_w - in_w_index); + int kh_s = MSMAX(0, -in_h_index); + int kh_e = MSMIN(win_h, compute_args->input_h_ - in_h_index); + int real_count = (kw_e - kw_s) * (kh_e - kh_s); + if (real_count == 0) { + return NNACL_ERR; + } + + // 16 channels + for (int j = 0; j < c16; j++) { +#ifdef ENABLE_NEON + int16x8_t tmp_avg[2]; + tmp_avg[0] = vmovq_n_s16(0); + tmp_avg[1] = vmovq_n_s16(0); +#else + int16_t tmp_avg[16]; + int16_t real_out[16]; + for (int m = 0; m < C16NUM; ++m) { + tmp_avg[m] = 0; + } +#endif + int in_channel_offset = in_batch_offset + j * C16NUM; + int out_channel_offset = out_plane_offset + j * C16NUM; + + for (int h = kh_s; h < kh_e; h++) { + for (int w = kw_s; w < kw_e; w++) { + int in_offset = in_channel_offset + input_stride + (h * in_w + w) * channel; +#ifdef ENABLE_NEON + int8x16_t in_ptr = vld1q_s8(input_ptr + in_offset); + int8x8_t in_data1 = vget_low_s8(in_ptr); + int8x8_t in_data2 = vget_high_s8(in_ptr); + int16x8_t data1 = vmovl_s8(in_data1); + int16x8_t data2 = vmovl_s8(in_data2); + tmp_avg[0] = vaddq_s16(tmp_avg[0], data1); + tmp_avg[1] = vaddq_s16(tmp_avg[1], data2); +#else + for (int k = 0; k < C16NUM; ++k) { + tmp_avg[k] += input_ptr[in_offset + k]; + } +#endif + } // win_w loop + } // win_h loop +#ifdef ENABLE_NEON + int16_t tmp_data[8]; + int16_t tmp_out[8]; + int16_t tmp_data1[8]; + int16_t tmp_out1[8]; + for (int l = 0; l < C8NUM; l++) { + tmp_data[l] = tmp_avg[0][l] + 128 * real_count; + tmp_out[l] = (tmp_data[l] + real_count / 2) / real_count; + tmp_out[l] -= 128; + tmp_out[l] = round((tmp_out[l] - input_zp) * real_multiplier) + output_zp; + } + for (int l = 0; l < C8NUM; l++) { + tmp_data1[l] = tmp_avg[1][l] + 128 * real_count; + tmp_out1[l] = (tmp_data1[l] + real_count / 2) / real_count; + tmp_out1[l] -= 128; + tmp_out1[l] = round((tmp_out1[l] - input_zp) * real_multiplier) + output_zp; + } + int8x8_t real_out[2]; + int8x8_t output_min = vdup_n_s8(out_min); + int8x8_t output_max = vdup_n_s8(out_max); + real_out[0] = vqmovn_s16(vld1q_s16(tmp_out)); + real_out[0] = vmin_s8(real_out[0], output_max); + real_out[0] = vmax_s8(real_out[0], output_min); + vst1_s8(output_ptr + out_channel_offset, real_out[0]); + real_out[1] = vqmovn_s16(vld1q_s16(tmp_out1)); + real_out[1] = vmin_s8(real_out[1], output_max); + real_out[1] = vmax_s8(real_out[1], output_min); + vst1_s8(output_ptr + out_channel_offset + 8, real_out[1]); +#else + for (int l = 0; l < C16NUM; ++l) { + int16_t tmp_data = tmp_avg[l] + 128 * real_count; + real_out[l] = (tmp_data + real_count / 2) / real_count - 128; + real_out[l] = (int8_t)(round((real_out[l] - input_zp) * real_multiplier) + output_zp); + real_out[l] = real_out[l] < out_min ? out_min : real_out[l]; + real_out[l] = real_out[l] > out_max ? out_max : real_out[l]; + *(output_ptr + out_channel_offset + l) = (int8_t)real_out[l]; + } +#endif + } + + // 8 channels + int channel_16_res = channel - c16 * C16NUM; + int c8 = channel_16_res / C8NUM; + int in_c16_offset = in_batch_offset + c16 * C16NUM; + int out_c16_offset = out_plane_offset + c16 * C16NUM; + for (int j = 0; j < c8; j++) { +#ifdef ENABLE_NEON + int16x8_t tmp_avg = vmovq_n_s16(0); +#else + int16_t tmp_avg[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + int16_t real_out[8]; +#endif + int in_channel_offset = in_c16_offset + j * C8NUM; + int out_channel_offset = out_c16_offset + j * C8NUM; + for (int h = kh_s; h < kh_e; h++) { + for (int w = kw_s; w < kw_e; w++) { + int in_offset = in_channel_offset + input_stride + (h * in_w + w) * channel; +#ifdef ENABLE_NEON + int8x8_t in_ptr = vld1_s8(input_ptr + in_offset); + int16x8_t data = vmovl_s8(in_ptr); + tmp_avg = vaddq_s16(tmp_avg, data); +#else + for (int k = 0; k < C8NUM; ++k) { + tmp_avg[k] += input_ptr[in_offset + k]; + } +#endif + } // win_w loop + } // win_h loop +#ifdef ENABLE_NEON + int16_t tmp_data[8]; + int16_t tmp_out[8]; + for (int l = 0; l < C8NUM; l++) { + tmp_data[l] = tmp_avg[l] + 128 * real_count; + tmp_out[l] = (tmp_data[l] + real_count / 2) / real_count; + tmp_out[l] -= 128; + tmp_out[l] = round((tmp_out[l] - input_zp) * real_multiplier) + output_zp; + } + int8x8_t real_out; + int8x8_t output_min = vdup_n_s8(out_min); + int8x8_t output_max = vdup_n_s8(out_max); + real_out = vqmovn_s16(vld1q_s16(tmp_out)); + real_out = vmin_s8(real_out, output_max); + real_out = vmax_s8(real_out, output_min); + vst1_s8(output_ptr + out_channel_offset, real_out); +#else + for (int l = 0; l < C8NUM; ++l) { + int16_t tmp_data = tmp_avg[l] + 128 * real_count; + real_out[l] = (tmp_data + real_count / 2) / real_count - 128; + real_out[l] = (int8_t)(round((real_out[l] - input_zp) * real_multiplier) + output_zp); + real_out[l] = real_out[l] < out_min ? out_min : real_out[l]; + real_out[l] = real_out[l] > out_max ? out_max : real_out[l]; + *(output_ptr + out_channel_offset + l) = (int8_t)real_out[l]; + } +#endif + } + + // less than 8 channel + int channel_8_res = channel_16_res - c8 * C8NUM; + int in_c8_offset = in_c16_offset + c8 * C8NUM; + int out_c8_offset = out_c16_offset + c8 * C8NUM; + for (int k = 0; k < channel_8_res; k++) { + int in_channel_offset = in_c8_offset + k; + int out_channel_offset = out_c8_offset + k; + int16_t tmp_avg = 0; + for (int h = kh_s; h < kh_e; h++) { + for (int w = kw_s; w < kw_e; w++) { + int in_offset = in_channel_offset + input_stride + (h * in_w + w) * channel; + tmp_avg += input_ptr[in_offset]; + } // win_w loop + } // win_h loop + int16_t tmp_out = round((float)tmp_avg / (float)real_count + 128) - 128; + tmp_out = (int8_t)(round((tmp_out - input_zp) * real_multiplier) + output_zp); + int16_t real_out = tmp_out < out_min ? out_min : tmp_out; + real_out = real_out > out_max ? out_max : real_out; + *(output_ptr + out_channel_offset) = (int8_t)real_out; + } // channel_res loop + } // out_plane loop + } // out_batch loop + } + return NNACL_OK; +} + +void MaxPoolingInt8(const int8_t *input_ptr, int8_t *output_ptr, const PoolingParameter *pooling_param, + PoolingComputeParam *compute_args, QuantArg **quant_args) { + int stride_w = pooling_param->stride_w_; + int stride_h = pooling_param->stride_h_; + int pad_w = pooling_param->pad_l_; + int pad_h = pooling_param->pad_u_; + int win_w = pooling_param->window_w_; + int win_h = pooling_param->window_h_; + int channel = compute_args->input_channel_; + int in_w = compute_args->input_w_; + int in_h = compute_args->input_h_; + int output_w = compute_args->output_w_; + int output_h = compute_args->output_h_; + int output_batch = compute_args->output_batch_; + int out_plane = output_w * output_h; + // input channel is equal to output channel + float input_scale = quant_args[0][0].scale_; + int input_zp = quant_args[0][0].zp_; + float output_scale = quant_args[1][0].scale_; + int output_zp = quant_args[1][0].zp_; + double real_multiplier = input_scale / output_scale; + + for (int batch = 0; batch < output_batch; batch++) { + int in_batch_offset = batch * in_h * in_w * channel; + int out_batch_offset = batch * output_h * output_w * channel; + for (int i = 0; i < out_plane; i++) { + int out_w_index = i % output_w; + int out_h_index = i / output_w; + int in_w_index = out_w_index * stride_w - pad_w; + int in_h_index = out_h_index * stride_h - pad_h; + int out_plane_offset = out_batch_offset + i * channel; + for (int j = 0; j < channel; j++) { + int in_channel_offset = in_batch_offset + j; + int out_channel_offset = out_plane_offset + j; + int8_t tmp_max = INT8_MIN; + for (int h = 0; h < win_h; h++) { + for (int w = 0; w < win_w; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; + tmp_max = MaxInt8(tmp_max, *(input_ptr + in_offset)); + } + } // win_w loop + } // win_h loop + *(output_ptr + out_channel_offset) = (int8_t)(round((tmp_max - input_zp) * real_multiplier) + output_zp); + } // in_channel loop + } // out_plane loop + } // out_batch loop +} + +void MaxPoolingWithQuantInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, + PoolingComputeParam *compute_args, QuantArg **quant_args, int task_id, int thread_num) { + int channel = compute_args->input_channel_; + int in_w = compute_args->input_w_; + int in_h = compute_args->input_h_; + int output_w = compute_args->output_w_; + int out_plane = output_w * compute_args->output_h_; + int out_tile_count = UP_DIV(out_plane, TILE_NUM); + thread_num = out_tile_count < thread_num ? out_tile_count : thread_num; + int c16 = UP_DIV(channel, 16); + // input channel is equal to output channel + float input_scale = quant_args[0][0].scale_; + int input_zp = quant_args[0][0].zp_; + float output_scale = quant_args[1][0].scale_; + int output_zp = quant_args[1][0].zp_; + double real_multiplier = input_scale / output_scale; + + NNACL_CHECK_ZERO_RETURN(output_w); + for (int batch = 0; batch < compute_args->output_batch_; batch++) { + int in_batch_offset = batch * in_h * in_w * channel; + int out_batch_offset = batch * compute_args->output_h_ * output_w * channel; + for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { + int cal_start_index = thread_id * TILE_NUM; + int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index); + for (int i = 0; i < real_cal_num; i++) { + int index = cal_start_index + i; + int out_w_index = index % output_w; + int out_h_index = index / output_w; + int in_w_index = out_w_index * pooling_param->stride_w_ - pooling_param->pad_l_; + int in_h_index = out_h_index * pooling_param->stride_h_ - pooling_param->pad_u_; + int out_plane_offset = out_batch_offset + index * channel; + for (int j = 0; j < c16 - 1; j++) { + int in_channel_offset = in_batch_offset + j * 16; + int out_channel_offset = out_plane_offset + j * 16; +#ifdef ENABLE_NEON + int8x16_t tmp_max = vdupq_n_s8(INT8_MIN); +#else + int8_t tmp_max[16]; + for (int m = 0; m < C16NUM; ++m) { + tmp_max[m] = INT8_MIN; + } +#endif + for (int h = 0; h < pooling_param->window_h_; h++) { + for (int w = 0; w < pooling_param->window_w_; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || + (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; +#ifdef ENABLE_NEON + tmp_max = vmaxq_s8(tmp_max, vld1q_s8(input_ptr + in_offset)); +#else + for (int k = 0; k < C16NUM; ++k) { + tmp_max[k] = MaxInt8(tmp_max[k], *(input_ptr + in_offset + k)); + } +#endif + } + } // win_w loop + } // win_h loop +#ifdef ENABLE_NEON + for (int l = 0; l < C16NUM; ++l) { + tmp_max[l] = (int8_t)(round((tmp_max[l] - input_zp) * real_multiplier) + output_zp); + } + vst1q_s8(output_ptr + out_channel_offset, tmp_max); +#else + for (int l = 0; l < C16NUM; ++l) { + *(output_ptr + out_channel_offset + l) = + (int8_t)(round((tmp_max[l] - input_zp) * real_multiplier) + output_zp); + } +#endif + } // in_channel loop + + // res channel + int channel_s = (c16 - 1) * 16; + for (int k = channel_s; k < channel; k++) { + int in_channel_offset = in_batch_offset + k; + int out_channel_offset = out_plane_offset + k; + int8_t tmp_max = INT8_MIN; + for (int h = 0; h < pooling_param->window_h_; h++) { + for (int w = 0; w < pooling_param->window_w_; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || + (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; + tmp_max = MaxInt8(tmp_max, *(input_ptr + in_offset)); + } + } // win_w loop + } // win_h loop + *(output_ptr + out_channel_offset) = (int8_t)(round((tmp_max - input_zp) * real_multiplier) + output_zp); + } // channel_res loop + } // out_plane loop + } // out_batch loop + } +} + +void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, + PoolingComputeParam *compute_args, int task_id, int thread_num) { + int channel = compute_args->input_channel_; + int in_w = compute_args->input_w_; + int output_w = compute_args->output_w_; + int out_plane = output_w * compute_args->output_h_; + int out_tile_count = UP_DIV(out_plane, TILE_NUM); + thread_num = MSMIN(out_tile_count, thread_num); + int8_t out_array[MAX_MAXPOOL_SIZE]; + + NNACL_CHECK_ZERO_RETURN(output_w); + for (int batch = 0; batch < compute_args->output_batch_; batch++) { + int in_batch_offset = batch * compute_args->input_h_ * in_w * channel; + int out_batch_offset = batch * compute_args->output_h_ * output_w * channel; + for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { + int cal_start_index = thread_id * TILE_NUM; + int real_cal_num = out_plane - cal_start_index; + real_cal_num = MSMIN(real_cal_num, TILE_NUM); + for (int i = 0; i < real_cal_num; i++) { + int index = cal_start_index + i; + int out_w_index = index % output_w; + int out_h_index = index / output_w; + int in_w_index = out_w_index * pooling_param->stride_w_ - pooling_param->pad_l_; + int in_h_index = out_h_index * pooling_param->stride_h_ - pooling_param->pad_u_; + const int ky_s = 0 > (-in_h_index) ? 0 : (-in_h_index); + int ky_e = MSMIN(compute_args->window_h_, compute_args->input_h_ - in_h_index); + const int kx_s = 0 > (-in_w_index) ? 0 : (-in_w_index); + int kx_e = MSMIN(compute_args->window_w_, in_w - in_w_index); + int input_stride = (in_h_index * in_w + in_w_index) * channel + in_batch_offset; + int out_plane_offset = out_batch_offset + index * channel; + + int c = 0; + for (; c < channel; c += MAX_MAXPOOL_SIZE) { + int real_channel = channel - c; + real_channel = MSMIN(real_channel, MAX_MAXPOOL_SIZE); + memset(out_array, INT8_MIN, real_channel); + int8_t *out_data = output_ptr + out_plane_offset + c; + for (int h = ky_s; h < ky_e; ++h) { + int in_h_offset = input_stride + h * in_w * channel + c; + for (int w = kx_s; w < kx_e; ++w) { + const int8_t *in_data = input_ptr + in_h_offset + w * channel; + int j = 0; +#ifdef ENABLE_NEON + const int8_t *tmp_in_data = in_data; + int c16 = real_channel / 16 * 16; + int c8 = real_channel / 8 * 8; + for (; j < c16; j += 16) { + int8x16_t ori_in = vld1q_s8(tmp_in_data); + int8x16_t out_array16 = vld1q_s8(out_array + j); + tmp_in_data += 16; + out_array16 = vmaxq_s8(ori_in, out_array16); + vst1q_s8(out_array + j, out_array16); + } // 16 channel loop + + for (; j < c8; j += 8) { + int8x8_t ori_in = vld1_s8(tmp_in_data); + int8x8_t out_array8 = vld1_s8(out_array + j); + tmp_in_data += 8; + out_array8 = vmax_s8(ori_in, out_array8); + vst1_s8(out_array + j, out_array8); + } // 8 channel loop +#endif + for (; j < real_channel; ++j) { + out_array[j] = out_array[j] > in_data[j] ? out_array[j] : in_data[j]; + } + } // kw loop + } // kh loop + + int j = 0; +#ifdef ENABLE_NEON + int c16 = real_channel / 16 * 16; + int c8 = real_channel / 8 * 8; + int8_t *tmp_out_data = out_data; + for (; j < c16; j += 16) { + vst1q_s8(tmp_out_data, vld1q_s8(out_array + j)); + tmp_out_data += 16; + } // 16 channel loop + + for (; j < c8; j += 8) { + vst1_s8(tmp_out_data, vld1_s8(out_array + j)); + tmp_out_data += 8; + } // 8 channel loop +#endif + for (; j < real_channel; ++j) { + out_data[j] = out_array[j]; + } + } // 256 channel loop + } // out_plane loop + } // out_batch loop + } +} diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/conv3d_tensorrt.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/pooling_int8.h similarity index 32% rename from mindspore-lite/src/extendrt/delegate/tensorrt/op/conv3d_tensorrt.h rename to mindspore-lite/ops/kernel/cpu/nnacl/int8/pooling_int8.h index f4b1e96dcc2a1ea80ce2944f452a37e33a156cb4..a2d318de6e2e0c4a33b299d2e0a8a1c7b9c68d7f 100644 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/conv3d_tensorrt.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/pooling_int8.h @@ -1,5 +1,5 @@ /** - * Copyright 2021-2022 Huawei Technologies Co., Ltd + * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,32 +13,38 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_CONV3D_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_CONV3D_TENSORRT_H_ -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" -#include "infer/conv3d.h" - -namespace mindspore::lite { -class Conv3DTensorRT : public TensorRTOp { - public: - Conv3DTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~Conv3DTensorRT() override; - - int AddInnerOp(TensorRTContext *ctx) override; - - bool IsWeightInputHanledInner() const override { return true; } - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; - - private: - void SetAttributes(const std::shared_ptr &conv_op, nvinfer1::IConvolutionLayer *current_layer_); -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_CONV3D_TENSORRT_H_ + +#ifndef NNACL_INT8_POOLING_H_ +#define NNACL_INT8_POOLING_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include "nnacl/op_base.h" +#include "nnacl/fp32/pooling_fp32.h" +#include "nnacl/kernel/pooling.h" + +#ifdef __cplusplus +extern "C" { +#endif +#define MAX_MAXPOOL_SIZE 256 + +int AvgPoolingInt8(const int8_t *input_ptr, int8_t *output_ptr, const PoolingParameter *pooling_param, + PoolingComputeParam *compute_args, QuantArg **quant_args); + +int AvgPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, const PoolingParameter *pooling_param, + PoolingComputeParam *compute_args, QuantArg **quant_args, int task_id, int thread_num); + +void MaxPoolingInt8(const int8_t *input_ptr, int8_t *output_ptr, const PoolingParameter *pooling_param, + PoolingComputeParam *compute_args, QuantArg **quant_args); + +void MaxPoolingWithQuantInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, + PoolingComputeParam *compute_args, QuantArg **quant_args, int task_id, int thread_num); + +void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, + PoolingComputeParam *compute_args, int task_id, int thread_num); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_POOLING_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/power_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/power_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..1fa16faaee6944e9140b36167590721c41d5e2aa --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/power_int8.c @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/power_int8.h" + +int PowerInt8(const int8_t *input, const int8_t *exp_ptr, int8_t *output, int count, const PowQuantArg *args, + bool broadcast, int scale, int shift) { + double input_scale = args->in_args_.scale_; + int input_zp = args->in_args_.zp_; + double output_scale = args->out_args_.scale_; + int output_zp = args->out_args_.zp_; + int act_min = args->output_activation_min_; + int act_max = args->output_activation_max_; + double exp_scale = args->exp_args_.scale_; + int exp_zp = args->exp_args_.zp_; + + if (broadcast) { + float exp_val = exp_scale * (exp_ptr[0] - exp_zp); + for (int i = 0; i < count; ++i) { + float input_val = input_scale * (input[i] - input_zp); + float output_val = pow(scale * input_val + shift, exp_val); + int32_t output_scaled = round(output_val / output_scale) + output_zp; + output[i] = (int8_t)MSMAX(act_min, MSMIN(output_scaled, act_max)); + } + } else { + for (int i = 0; i < count; ++i) { + float input_val = input_scale * (input[i] - input_zp); + float exp_val = exp_scale * (exp_ptr[i] - exp_zp); + float output_val = pow(scale * input_val + shift, exp_val); + int32_t output_scaled = round(output_val / output_scale) + output_zp; + output[i] = (int8_t)MSMAX(act_min, MSMIN(output_scaled, act_max)); + } + } + return 0; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/power_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/power_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..a569f4c49f065a0f0ddb65c9b98e6efc006dadb0 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/power_int8.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_POWER_INT8_H_ +#define NNACL_INT8_POWER_INT8_H_ + +#include "nnacl/op_base.h" +#include "nnacl/pow_parameter.h" +#include "nnacl/int8/quantize.h" + +#ifdef __cplusplus +extern "C" { +#endif +int PowerInt8(const int8_t *input, const int8_t *exp_ptr, int8_t *output, int count, const PowQuantArg *args, + bool broadcast, int scale, int shift); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_POWER_INT8_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/quant_dtype_cast_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/quant_dtype_cast_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..4b1546b16eacf2fea8f7c75166f15a79119ffd0d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/quant_dtype_cast_int8.c @@ -0,0 +1,437 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "nnacl/int8/quant_dtype_cast_int8.h" +#include "nnacl/errorcode.h" +#ifdef ENABLE_ARM +#include +#endif + +#ifdef ENABLE_ARM64 +inline void Int8ToFp32_arm64(const int8_t *quant_values, float *dst, float scale, int32_t zp, int size) { + asm volatile( + "mov w8, %w[size]\n" + "cmp w8, #0\n" + "beq 2f\n" + + "dup v20.4s, %w[zp32]\n" + "dup v21.4s, %w[scale]\n" + + "cmp w8, #16\n" + "blt 1f\n" + + "0:\n" + "subs w8, w8, #16\n" + "ld1 {v7.16b}, [%[quant_values]], #16\n" + + "sxtl v8.8h, v7.8b\n" + "sxtl2 v9.8h, v7.16b\n" + + "sxtl v0.4s, v8.4h\n" + "sxtl2 v1.4s, v8.8h\n" + "sxtl v2.4s, v9.4h\n" + "sxtl2 v3.4s, v9.8h\n" + "sub v0.4s, v0.4s, v20.4s\n" + "sub v1.4s, v1.4s, v20.4s\n" + "sub v2.4s, v2.4s, v20.4s\n" + "sub v3.4s, v3.4s, v20.4s\n" + "scvtf v4.4s, v0.4s\n" + "scvtf v5.4s, v1.4s\n" + "scvtf v6.4s, v2.4s\n" + "scvtf v7.4s, v3.4s\n" + + "fmul v0.4s, v4.4s, v21.4s\n" + "fmul v1.4s, v5.4s, v21.4s\n" + "fmul v2.4s, v6.4s, v21.4s\n" + "fmul v3.4s, v7.4s, v21.4s\n" + + "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[dst]], #64\n" + "beq 2f\n" + "cmp w8, #16\n" + "bge 0b\n" + + "1:\n" + "ldrsb w9, [%[quant_values]], #1\n" + + "subs w8, w8, #1\n" + "sub w9, w9, %w[zp32]\n" + "scvtf s9, w9\n" + + "fmul s9, s9, s21\n" + "str s9, [%[dst]], #4\n" + "bne 1b\n" + + "2:\n" + + : + : [ quant_values ] "r"(quant_values), [ dst ] "r"(dst), [ scale ] "r"(scale), [ zp32 ] "r"(zp), [ size ] "r"(size) + : "w8", "w9", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v20", "v21"); +} +#endif + +int DoDequantizeInt8ToFp32(const int8_t *quant_values, float *real_values, float scale, int32_t zp, int size) { + if (quant_values == NULL || real_values == NULL) { + return NNACL_PARAM_INVALID; + } + +#ifdef ENABLE_ARM64 + Int8ToFp32_arm64(quant_values, real_values, scale, zp, size); +#else + for (int i = 0; i < size; i++) { + real_values[i] = (quant_values[i] - zp) * scale; + } +#endif + return NNACL_OK; +} + +#ifdef ENABLE_ARM64 +inline void Fp32ToInt8_arm64(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size, + int32_t min_value, int32_t max_value) { + float ivs = 1.0f / scale; + + asm volatile( + "mov w8, %w[size]\n" + "cmp w8, #0\n" + "beq 2f\n" + + "dup v12.4s, %w[ivs]\n" + "dup v13.4s, %w[min_value]\n" + "dup v14.4s, %w[max_value]\n" + "cmp w8, #16\n" + "blt 1f\n" + "0:\n" + "subs w8, w8, #16\n" + "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[real_values]], #64\n" + "dup v8.4s, %w[zp]\n" + "dup v9.4s, %w[zp]\n" + "dup v10.4s, %w[zp]\n" + "dup v11.4s, %w[zp]\n" + "scvtf v4.4s, v8.4s\n" + "scvtf v5.4s, v9.4s\n" + "scvtf v6.4s, v10.4s\n" + "scvtf v7.4s, v11.4s\n" + "fmla v4.4s, v0.4s, v12.4s\n" + "fmla v5.4s, v1.4s, v12.4s\n" + "fmla v6.4s, v2.4s, v12.4s\n" + "fmla v7.4s, v3.4s, v12.4s\n" + + "fcvtas v0.4s, v4.4s\n" + "fcvtas v1.4s, v5.4s\n" + "fcvtas v2.4s, v6.4s\n" + "fcvtas v3.4s, v7.4s\n" + "smax v0.4s, v0.4s, v13.4s\n" + "smax v1.4s, v1.4s, v13.4s\n" + "smax v2.4s, v2.4s, v13.4s\n" + "smax v3.4s, v3.4s, v13.4s\n" + "smin v0.4s, v0.4s, v14.4s\n" + "smin v1.4s, v1.4s, v14.4s\n" + "smin v2.4s, v2.4s, v14.4s\n" + "smin v3.4s, v3.4s, v14.4s\n" + + "sqxtn v4.4h, v0.4s\n" + "sqxtn2 v4.8h, v1.4s\n" + "sqxtn v5.4h, v2.4s\n" + "sqxtn2 v5.8h, v3.4s\n" + "sqxtn v6.8b, v4.8h\n" + "sqxtn2 v6.16b, v5.8h\n" + "st1 {v6.16b}, [%[quant_values]], #16\n" + + "beq 2f\n" + "cmp w8, #16\n" + "bge 0b\n" + + "1:\n" + "scvtf s0, %w[zp]\n" + "subs w8, w8, #1\n" + "ldr s4, [%[real_values]], #4\n" + "fmul s4, s4, s12\n" + "fadd s0, s0, s4\n" + "fcvtas s0, s0\n" + "smax v0.4s, v0.4s, v13.4s\n" + "smin v0.4s, v0.4s, v14.4s\n" + "sqxtn v1.4h, v0.4s\n" + "sqxtn v0.8b, v1.8h\n" + "st1 {v0.b}[0], [%[quant_values]], #1\n" + + "bne 1b\n" + + "2:\n" + : + : [ quant_values ] "r"(quant_values), [ real_values ] "r"(real_values), [ scale ] "r"(scale), [ zp ] "r"(zp), + [ size ] "r"(size), [ ivs ] "r"(ivs), [ min_value ] "r"(min_value), [ max_value ] "r"(max_value) + : "w8", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14"); +} +#endif + +int DoQuantizeFp32ToInt8(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size, + int32_t min_value, int32_t max_value) { + if (quant_values == NULL || real_values == NULL) { + return NNACL_PARAM_INVALID; + } +#ifdef ENABLE_ARM64 + Fp32ToInt8_arm64(real_values, quant_values, scale, zp, size, min_value, max_value); +#else + const float inverse_scale = 1.0f / scale; + for (int i = 0; i < size; ++i) { + if (real_values[i] == INFINITY) { + quant_values[i] = max_value; + } else if (real_values[i] == -INFINITY) { + quant_values[i] = min_value; + } else { + int temp = round(real_values[i] * inverse_scale + zp); + temp = temp < max_value ? temp : max_value; + temp = temp > min_value ? temp : min_value; + quant_values[i] = (int8_t)temp; + } + } +#endif + return NNACL_OK; +} + +#ifdef ENABLE_ARM64 +inline void Fp32ToInt8Perchannel_arm64(const float *real_values, int8_t *quant_values, float *scales, int32_t *zps, + int size, int row_length, int32_t min_value, int32_t max_value) { + volatile float ivs[size]; + for (int i = 0; i < size; i++) { + volatile int channel_index = i / row_length; + ivs[i] = 1.0f / scales[channel_index]; + } + volatile int32_t zp = zps[0]; + + asm volatile( + "mov w8, %w[size]\n" + "cmp w8, #0\n" + "beq 2f\n" + + "mov x4, %[ivs]\n" // reload ivs + "dup v13.4s, %w[min_value]\n" + "dup v14.4s, %w[max_value]\n" + "cmp w8, #16\n" + "blt 1f\n" + "0:\n" + "subs w8, w8, #16\n" + "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[real_values]], #64\n" + "dup v8.4s, %w[zp]\n" + "dup v9.4s, %w[zp]\n" + "dup v10.4s, %w[zp]\n" + "dup v11.4s, %w[zp]\n" + "scvtf v4.4s, v8.4s\n" + "scvtf v5.4s, v9.4s\n" + "scvtf v6.4s, v10.4s\n" + "scvtf v7.4s, v11.4s\n" + "ld1 {v12.4s}, [x4], #16\n" + "fmla v4.4s, v0.4s, v12.4s\n" + "ld1 {v12.4s}, [x4], #16\n" + "fmla v5.4s, v1.4s, v12.4s\n" + "ld1 {v12.4s}, [x4], #16\n" + "fmla v6.4s, v2.4s, v12.4s\n" + "ld1 {v12.4s}, [x4], #16\n" + "fmla v7.4s, v3.4s, v12.4s\n" + + "fcvtas v0.4s, v4.4s\n" + "fcvtas v1.4s, v5.4s\n" + "fcvtas v2.4s, v6.4s\n" + "fcvtas v3.4s, v7.4s\n" + "smax v0.4s, v0.4s, v13.4s\n" + "smax v1.4s, v1.4s, v13.4s\n" + "smax v2.4s, v2.4s, v13.4s\n" + "smax v3.4s, v3.4s, v13.4s\n" + "smin v0.4s, v0.4s, v14.4s\n" + "smin v1.4s, v1.4s, v14.4s\n" + "smin v2.4s, v2.4s, v14.4s\n" + "smin v3.4s, v3.4s, v14.4s\n" + + "sqxtn v4.4h, v0.4s\n" + "sqxtn2 v4.8h, v1.4s\n" + "sqxtn v5.4h, v2.4s\n" + "sqxtn2 v5.8h, v3.4s\n" + "sqxtn v6.8b, v4.8h\n" + "sqxtn2 v6.16b, v5.8h\n" + "st1 {v6.16b}, [%[quant_values]], #16\n" + + "beq 2f\n" + "cmp w8, #16\n" + "bge 0b\n" + + "1:\n" + "scvtf s0, %w[zp]\n" + "subs w8, w8, #1\n" + "ldr s4, [%[real_values]], #4\n" + "fmul s4, s4, s12\n" + "fadd s0, s0, s4\n" + "fcvtas s0, s0\n" + "smax v0.4s, v0.4s, v13.4s\n" + "smin v0.4s, v0.4s, v14.4s\n" + "sqxtn v1.4h, v0.4s\n" + "sqxtn v0.8b, v1.8h\n" + "st1 {v0.b}[0], [%[quant_values]], #1\n" + + "bne 1b\n" + + "2:\n" + : + : [ quant_values ] "r"(quant_values), [ real_values ] "r"(real_values), [ scales ] "r"(scales), [ zp ] "r"(zp), + [ size ] "r"(size), [ row_length ] "r"(row_length), [ ivs ] "r"(ivs), [ min_value ] "r"(min_value), + [ max_value ] "r"(max_value) + : "w8", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "x4"); +} +#endif + +int DoChannelRowFp32ToInt8(const float *real_values, int8_t *quant_values, float *scale, int32_t *zp, int size, + int row_length, int32_t min_value, int32_t max_value) { + if (quant_values == NULL || real_values == NULL || scale == NULL || zp == NULL || row_length == 0) { + return NNACL_PARAM_INVALID; + } +#ifdef ENABLE_ARM64 + Fp32ToInt8Perchannel_arm64(real_values, quant_values, scale, zp, size, row_length, min_value, max_value); +#else + for (int i = 0; i < size; ++i) { + int channel_index = i / row_length; + const float inverse_scale = 1.0f / scale[channel_index]; + if (real_values[i] == INFINITY) { + quant_values[i] = max_value; + } else if (real_values[i] == -INFINITY) { + quant_values[i] = min_value; + } else { + int temp = round(real_values[i] * inverse_scale + zp[channel_index]); + temp = temp < max_value ? temp : max_value; + temp = temp > min_value ? temp : min_value; + quant_values[i] = (int8_t)temp; + } + } +#endif + return NNACL_OK; +} + +int DoChannelColFp32ToInt8(const float *real_values, int8_t *quant_values, float *scale, int32_t *zp, int size, + int row_length, int32_t min_value, int32_t max_value) { + if (quant_values == NULL || real_values == NULL || scale == NULL || zp == NULL || row_length == 0) { + return NNACL_PARAM_INVALID; + } + int row_total = size / row_length; + for (int r = 0; r < row_total; r++) { + const float *real_current = real_values + r * row_length; + int8_t *quant_current = quant_values + r * row_length; + for (int c = 0; c < row_length; c++) { + const float inverse_scale = 1.0f / scale[c]; + if (real_current[c] == INFINITY) { + quant_current[c] = max_value; + } else if (real_current[c] == -INFINITY) { + quant_current[c] = min_value; + } else { + int temp = round(real_current[c] * inverse_scale + zp[c]); + temp = temp < max_value ? temp : max_value; + temp = temp > min_value ? temp : min_value; + quant_current[c] = (int8_t)temp; + } + } + } + return NNACL_OK; +} + +int DoQuantizeFp32ToInt8FromUint8Source(const float *real_values, int8_t *quant_values, float scale, int32_t zp, + int size, int32_t min_value, int32_t max_value) { + if (quant_values == NULL || real_values == NULL) { + return NNACL_PARAM_INVALID; + } + + zp += 128; + const float inverse_scale = 1.0f / scale; + for (int i = 0; i < size; ++i) { + if (real_values[i] == INFINITY) { + quant_values[i] = max_value; + } else if (real_values[i] == -INFINITY) { + quant_values[i] = min_value; + } else { + int temp = round(real_values[i] * inverse_scale + zp); + temp -= 128; + temp = temp < 127 ? temp : 127; + temp = temp > -128 ? temp : -128; + quant_values[i] = (int8_t)temp; + } + } + return NNACL_OK; +} + +int DoDequantizeUInt8ToFp32(const uint8_t *quant_values, float *real_values, float scale, int32_t zp, int size) { + if (quant_values == NULL || real_values == NULL) { + return NNACL_PARAM_INVALID; + } + + for (int i = 0; i < size; ++i) { + real_values[i] = (float)((int)quant_values[i] - zp) * scale; + } + return NNACL_OK; +} + +int DoQuantizeFp32ToUInt8(const float *real_values, uint8_t *quant_values, float scale, int32_t zp, int size) { + if (quant_values == NULL || real_values == NULL) { + return NNACL_PARAM_INVALID; + } + + for (int i = 0; i < size; ++i) { + if (isinf(real_values[i])) { + quant_values[i] = 255; + } else { + float temp = (float)round(real_values[i] * 1.0 / scale + zp); + if (temp > 255) { + quant_values[i] = 255; + } else if (temp < 0) { + quant_values[i] = 0; + } else { + quant_values[i] = (uint8_t)temp; + } + } + } + return NNACL_OK; +} + +int Int8ToUInt8(const int8_t *quant_values, uint8_t *real_values, int size) { + if (quant_values == NULL || real_values == NULL) { + return NNACL_PARAM_INVALID; + } + + for (int i = 0; i < size; ++i) { + int temp = quant_values[i] + 128; + if (temp > 255) { + real_values[i] = (uint8_t)255; + } else if (temp < 0) { + real_values[i] = 0; + } else { + real_values[i] = (uint8_t)temp; + } + } + return NNACL_OK; +} + +int UInt8ToInt8(const uint8_t *real_values, int8_t *quant_values, int size) { + if (quant_values == NULL || real_values == NULL) { + return NNACL_PARAM_INVALID; + } + + for (int i = 0; i < size; ++i) { + int temp = (int)real_values[i] - 128; + if (temp > 127) { + quant_values[i] = 127; + } else if (temp < -128) { + quant_values[i] = -128; + } else { + quant_values[i] = (int8_t)temp; + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/quant_dtype_cast_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/quant_dtype_cast_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..2065afebd43f0b6c0d3ee906aa9dee071caf6954 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/quant_dtype_cast_int8.h @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_QUANTDTYPECAST_H_ +#define NNACL_INT8_QUANTDTYPECAST_H_ + +#include "nnacl/op_base.h" + +typedef struct QuantDTypeCastParameter { + OpParameter op_parameter_; + int32_t srcT; + int32_t dstT; + int32_t axis; +} QuantDTypeCastParameter; + +#ifdef __cplusplus +extern "C" { +#endif +int DoDequantizeInt8ToFp32(const int8_t *quant_values, float *real_values, float scale, int32_t zp, int size); +int DoQuantizeFp32ToInt8(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size, + int32_t min_value, int32_t max_value); +int DoChannelRowFp32ToInt8(const float *real_values, int8_t *quant_values, float *scale, int32_t *zp, int size, + int row_length, int32_t min_value, int32_t max_value); +int DoChannelColFp32ToInt8(const float *real_values, int8_t *quant_values, float *scale, int32_t *zp, int size, + int row_length, int32_t min_value, int32_t max_value); +int DoQuantizeFp32ToInt8FromUint8Source(const float *real_values, int8_t *quant_values, float scale, int32_t zp, + int size, int32_t min_value, int32_t max_value); +#ifdef ENABLE_ARM64 +void Fp32ToInt8_arm64(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size, + int32_t min_value, int32_t max_value); +void Int8ToFp32_arm64(const int8_t *quant_values, float *dst, float scale, int32_t zp, int size); +void Fp32ToInt8Perchannel_arm64(const float *real_values, int8_t *quant_values, float *scales, int32_t *zps, int size, + int row_length, int32_t min_value, int32_t max_value); +#endif +int DoDequantizeUInt8ToFp32(const uint8_t *quant_values, float *real_values, float scale, int32_t zp, int size); +int DoQuantizeFp32ToUInt8(const float *real_values, uint8_t *quant_values, float scale, int32_t zp, int size); +int Int8ToUInt8(const int8_t *quant_values, uint8_t *real_values, int size); +int UInt8ToInt8(const uint8_t *real_values, int8_t *quant_values, int size); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_QUANTDTYPECAST_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/quantize.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/quantize.c new file mode 100644 index 0000000000000000000000000000000000000000..dfbfe72aba27e390f8c7abac8b14027704954f9e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/quantize.c @@ -0,0 +1,161 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/quantize.h" +#include + +const uint64_t dSignMask = 1ull << 63; +const uint64_t dExponentMask = 0x7ffull << 52; +const uint64_t dFractionMask = (1ull << 52) - 1; +const int dExponentBias = 1022; +const int dMantissaBits = 52; +const int dInfiniteExponent = 0x7ff; +const double dNormalizer = 0x1p54; +const int dNormalizerBias = 54; +const int iMantissaBits = 31; + +void QuantizeMultiplierSmallerThanOne(double double_multiplier, int32_t *quantized_multiplier, int32_t *right_shift) { + if (quantized_multiplier == NULL || right_shift == NULL) { + return; + } + int shift = 0; + QuantizeMultiplier(double_multiplier, quantized_multiplier, &shift); + *right_shift = -shift; +} + +void QuantizeRoundParameterWithDoublePrecision(double double_multiplier, int32_t *quantized_multiplier, + int32_t *left_shift, int32_t *right_shift) { + int shift = 0; + QuantizeMultiplierSmallerThanOne(double_multiplier, quantized_multiplier, &shift); + shift = -shift; + if (shift < 0) { + *left_shift = 0; + *right_shift = shift; + } else { + *left_shift = shift; + *right_shift = 0; + } +} + +void QuantizeRoundParameterWithSinglePrecision(double double_multiplier, int32_t *quantized_multiplier, + int32_t *left_shift, int32_t *right_shift) { + int shift = 0; + const uint32_t scale_bits = (uint32_t)(double_multiplier); + /* multiplier is in[0x40000000, 0x7FFFFF80] range */ + *quantized_multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7); + if (quantized_multiplier[0] < INT32_C(0x40000000) || quantized_multiplier[0] > INT32_C(0x7FFFFF80)) { + return; + } + /* shift is in [0, 31] range */ + shift = 127 + 31 - 32 - ((uint32_t)(double_multiplier) >> 23); + shift = -shift; + if (shift < 0) { + *left_shift = 0; + *right_shift = shift; + } else { + *left_shift = shift; + *right_shift = 0; + } +} + +uint8_t QuantizeToUint8(float real_value, float scale, int32_t zp) { return round(real_value / scale + zp); } + +int32_t QuantizeToInt8(float real_value, float scale, int32_t zp) { return round(real_value / scale + zp); } + +void CalculateActivationRangeQuantized(bool is_relu, bool is_relu6, int32_t zp, float scale, int32_t *mini, + int32_t *maxi) { + int32_t min = INT8_MIN; + int32_t max = INT8_MAX; + int32_t quantized_zero = QuantizeToInt8(0, scale, zp); + int32_t quantized_six = QuantizeToInt8(6, scale, zp); + if (is_relu) { + min = min > quantized_zero ? min : quantized_zero; + } else if (is_relu6) { + min = min > quantized_zero ? min : quantized_zero; + max = max < quantized_six ? max : quantized_six; + } else { + // do nothing + } + *mini = min; + *maxi = max; +} + +// quantize from float to int8 +void Quantize(const float *input_data, int length, float scale, int zero_point, int8_t *output_data) { + for (int i = 0; i < length; ++i) { + int q = (int)round(input_data[i] / scale + zero_point); + q = q > SCHAR_MAX ? SCHAR_MAX : q; + q = q < SCHAR_MIN ? SCHAR_MIN : q; + output_data[i] = (int8_t)q; + } +} + +// dequantize from int8 to float +void Dequantize(const int8_t *input_data, int length, float scale, int zero_point, float *output_data) { + for (int i = 0; i < length; ++i) { + output_data[i] = scale * (input_data[i] - zero_point); + } +} + +void QuantizeMultiplier(double double_multiplier, int32_t *quantized_multiplier, int32_t *shift) { + if (quantized_multiplier == NULL || shift == NULL) { + return; + } + // we split a floating number into two parts: exponent and fraction + // since fraction is stored as int32, only 31 bits of mantissa is remained + union { + double d; + uint64_t ul; + } dul; + dul.d = double_multiplier; + if (!(dul.ul & (~dSignMask))) { + // multiplier is 0 + *quantized_multiplier = 0; + *shift = 0; + return; + } + int exponent = (int)((dul.ul & dExponentMask) >> dMantissaBits); + if (exponent == dInfiniteExponent) { + // multiplier is inf or NaN + *shift = 0; + if (!(dul.ul & dFractionMask)) { + // inf + *quantized_multiplier = (dul.ul & dSignMask) ? INT_MIN : INT_MAX; + } else { + // NaN + *quantized_multiplier = 0; + } + return; + } + if (exponent == 0) { + // multiplier is a subnormal number + dul.d *= dNormalizer; + exponent = (int)((dul.ul & dExponentMask) >> dMantissaBits); + *shift = exponent - dExponentBias - dNormalizerBias; + } else { + *shift = exponent - dExponentBias; + } + uint64_t fraction = dul.ul & dFractionMask; + fraction += (1ull << dMantissaBits); + uint64_t rounded = ((fraction >> (dMantissaBits - iMantissaBits)) + 1ull) >> 1; + // we get 31 rounded bits now + if (rounded == (1ull << iMantissaBits)) { + // rounding may cause a carry + rounded >>= 1; + ++*shift; + } + *quantized_multiplier = (dul.ul & dSignMask) ? (-(int32_t)(rounded)) : (int32_t)(rounded); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/quantize.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/quantize.h new file mode 100644 index 0000000000000000000000000000000000000000..dc65186ed9eda093c46535e0204882e43786e081 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/quantize.h @@ -0,0 +1,222 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_QUANTIZATION_QUANTIZE_H_ +#define NNACL_QUANTIZATION_QUANTIZE_H_ + +#include +#include "nnacl/op_base.h" + +#define INPUT_PER_CHANNEL 0b001 +#define FILTER_PER_CHANNEL 0b010 +#define OUTPUT_PER_CHANNEL 0b100 + +typedef struct ConvQuantArg { + RoundingMode round_mode_; + CalFixedMultiplierMode quant_multiplier_mode_; + QuantArg *input_quant_args_; + QuantArg *filter_quant_args_; + QuantArg *output_quant_args_; + double *real_multiplier_; + int32_t *left_shift_; + int32_t *right_shift_; + int32_t *quant_multiplier_; + int32_t *out_act_min_; + int32_t *out_act_max_; + size_t input_arg_num_; + size_t filter_arg_num_; + size_t output_arg_num_; + uint8_t per_channel_; +} ConvQuantArg; + +typedef struct ConcatQuantArg { + QuantArg *in_args_; + QuantArg out_args_; + int8_t output_activation_min_; + int8_t output_activation_max_; +} ConcatQuantArg; + +typedef struct PreluQuantArg { + int32_t *input_sizes_; + int output_size_; + int32_t **input_shapes_; + int32_t *output_shape_; + size_t input_num_; + size_t output_dim_; + float alpha_; + QuantArg in_args_; + QuantArg out_args_; + int output_activation_min_; + int output_activation_max_; + QuantArg *in_quant_args_; + QuantArg out_quant_args_; +} PreluQuantArg; + +typedef struct CropQuantArg { + QuantArg in_args_; + QuantArg out_args_; + int output_activation_min_; + int output_activation_max_; +} CropQuantArg; + +typedef struct ArithSelfQuantArg { + QuantArg in_args_; + QuantArg out_args_; + int output_activation_min_; + int output_activation_max_; + int output_multiplier_; + int shift_left_; + int shift_right_; +} ArithSelfQuantArg; + +typedef struct GatherQuantArg { + double alpha_; + int zp_in_; + int zp_out_; +} GatherQuantArg; + +typedef struct DynamicGatherQuantArg { + float *scale_in_; + int32_t *zp_in_; +} DynamicGatherQuantArg; + +typedef struct SoftmaxQuantArg { + QuantArg in_quant_args_; + QuantArg out_quant_arg_; + int output_activation_min_; + int output_activation_max_; + int output_multiplier_; + int shift_left_; + int shift_right_; +} SoftmaxQuantArg; + +typedef struct SubQuantArg { + QuantArg in0_args_; + QuantArg in1_args_; + QuantArg out_args_; + int output_activation_min_; + int output_activation_max_; + int input0_multiplier_; + int input1_multiplier_; + int output_multiplier_; + int input0_shift_; + int input1_shift_; + int output_shift_; + int left_shift_result0_; + int left_shift_result1_; + int right_shift0_; + int right_shift1_; + int left_shift_out_; + int right_shift_out_; +} SubQuantArg; + +typedef struct ArithmeticQuantArg { + QuantArg in0_args_; + QuantArg in1_args_; + QuantArg out_args_; +} ArithmeticQuantArg; + +typedef struct DivQuantArg { + QuantArg in0_args_; + QuantArg in1_args_; + QuantArg out_args_; + int output_activation_min_; + int output_activation_max_; + int output_multiplier_; + int output_shift_; +} DivQuantArg; + +typedef struct ReduceQuantArg { + double in_scale_; + int32_t in_zp_; + double out_scale_; + int32_t out_zp_; + int32_t in_out_multiplier_; + int in_out_left_shift_; + int in_out_right_shift_; + int32_t mean_multiplier_; + int mean_left_shift_; + int mean_right_shift_; + int32_t prod_multiplier_; + int prod_left_shift_; + int prod_right_shift_; + int32_t sum_square_multiplier_; + int sum_square_left_shift_; + int sum_square_right_shift_; +} ReduceQuantArg; + +typedef struct LeakyReluQuantArg { + QuantArg in_args_; + QuantArg out_args_; + float slope_; + int input_dim_; + int element_num; + int thread_num_; +} LeakyReluQuantArg; + +typedef struct ResizeQuantArg { + int32_t ratio_x_; + int32_t ratio_y_; + int32_t *x_axis_index_; + int32_t *x_axis_lower_; + int32_t *x_axis_upper_; + int32_t *y_axis_index_; + int32_t *y_axis_lower_; + int32_t *y_axis_upper_; +} ResizeQuantArg; + +typedef struct ResizeFloatScaleQuantArg { + float ratio_x_; + float ratio_y_; + float *x_axis_index_; + int32_t *x_axis_lower_; + int32_t *x_axis_upper_; + float *y_axis_index_; + int32_t *y_axis_lower_; + int32_t *y_axis_upper_; +} ResizeFloatScaleQuantArg; + +#ifdef __cplusplus +extern "C" { +#endif + +void QuantizeMultiplier(double double_multiplier, int32_t *quantized_multiplier, int32_t *shift); + +void QuantizeMultiplierSmallerThanOne(double double_multiplier, int32_t *quantized_multiplier, int32_t *right_shift); + +void QuantizeRoundParameterWithDoublePrecision(double double_multiplier, int32_t *quantized_multiplier, + int32_t *left_shift, int32_t *right_shift); + +void QuantizeRoundParameterWithSinglePrecision(double double_multiplier, int32_t *quantized_multiplier, + int32_t *left_shift, int32_t *right_shift); + +uint8_t QuantizeToUint8(float real_value, float scale, int32_t zp); + +int32_t QuantizeToInt8(float real_value, float scale, int32_t zp); + +void CalculateActivationRangeQuantized(bool is_relu, bool is_relu6, int32_t zp, float scale, int32_t *mini, + int32_t *maxi); +// quantize from float to int8 +void Quantize(const float *input_data, int length, float scale, int zero_point, int8_t *output_data); + +// dequantize from int8 to float +void Dequantize(const int8_t *input_data, int length, float scale, int zero_point, float *output_data); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_QUANTIZATION_QUANTIZE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/reduce_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/reduce_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..e40e45be404abf496894ae2ac09de59b38a2007a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/reduce_int8.c @@ -0,0 +1,597 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "nnacl/int8/reduce_int8.h" +#include "nnacl/errorcode.h" +#include "nnacl/int8/fixed_point.h" +#include "nnacl/common_func.h" + +int ReduceMeanN(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} +int ReduceMeanH(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} +int ReduceMeanW(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} +int ReduceMeanC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} +int ReduceMeanNH(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} +int ReduceMeanNW(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} +int ReduceMeanNC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} +int ReduceMeanHW(int n, int plane, int count, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg, + int32_t bias) { + int stride = plane * UP_ROUND(c, C4NUM); + for (int batch = 0; batch < n; ++batch) { + int8_t *in_ptr = in_data + batch * stride; + int8_t *out_ptr = out_data + batch * c; + for (int i = 0; i < count; ++i) { + int32_t sum_array = 0; + int j = 0; +#ifdef ENABLE_ARM64 + for (; j < plane; j += 16) { + int8x16_t in_data_vec = vld1q_s8(in_ptr); + sum_array += vaddlvq_s8(in_data_vec); + in_ptr += 16; + } + for (; j < plane; j += 8) { + int8x8_t in_data_vec = vld1_s8(in_ptr); + sum_array += vaddlv_s8(in_data_vec); + in_ptr += 8; + } + for (; j < plane; j += 4) { + int32x4_t in_data_vec; + in_data_vec[0] = in_ptr[0]; + in_data_vec[1] = in_ptr[1]; + in_data_vec[2] = in_ptr[2]; + in_data_vec[3] = in_ptr[3]; + sum_array += vaddvq_s32(in_data_vec); + in_ptr += 4; + } +#elif ENABLE_ARM32 + int32x4_t accum = vmovq_n_s32(0); + for (; j < plane; j += 16) { + int32x4_t in_data_vec1; + int32x4_t in_data_vec2; + int32x4_t in_data_vec3; + int32x4_t in_data_vec4; + in_data_vec1[0] = in_ptr[0]; + in_data_vec1[1] = in_ptr[1]; + in_data_vec1[2] = in_ptr[2]; + in_data_vec1[3] = in_ptr[3]; + in_data_vec2[0] = in_ptr[4]; + in_data_vec2[1] = in_ptr[5]; + in_data_vec2[2] = in_ptr[6]; + in_data_vec2[3] = in_ptr[7]; + in_data_vec3[0] = in_ptr[8]; + in_data_vec3[1] = in_ptr[9]; + in_data_vec3[2] = in_ptr[10]; + in_data_vec3[3] = in_ptr[11]; + in_data_vec4[0] = in_ptr[12]; + in_data_vec4[1] = in_ptr[13]; + in_data_vec4[2] = in_ptr[14]; + in_data_vec4[3] = in_ptr[15]; + accum = vaddq_s32(accum, in_data_vec1); + accum = vaddq_s32(accum, in_data_vec2); + accum = vaddq_s32(accum, in_data_vec3); + accum = vaddq_s32(accum, in_data_vec4); + in_ptr += 16; + } + for (; j < plane; j += 8) { + int32x4_t in_data_vec1; + int32x4_t in_data_vec2; + in_data_vec1[0] = in_ptr[0]; + in_data_vec1[1] = in_ptr[1]; + in_data_vec1[2] = in_ptr[2]; + in_data_vec1[3] = in_ptr[3]; + in_data_vec2[0] = in_ptr[4]; + in_data_vec2[1] = in_ptr[5]; + in_data_vec2[2] = in_ptr[6]; + in_data_vec2[3] = in_ptr[7]; + accum = vaddq_s32(accum, in_data_vec1); + accum = vaddq_s32(accum, in_data_vec2); + in_ptr += 8; + } + for (; j < plane; j += 4) { + int32x4_t in_data_vec; + in_data_vec[0] = in_ptr[0]; + in_data_vec[1] = in_ptr[1]; + in_data_vec[2] = in_ptr[2]; + in_data_vec[3] = in_ptr[3]; + accum = vaddq_s32(accum, in_data_vec); + in_ptr += 4; + } + sum_array += accum[0]; + sum_array += accum[1]; + sum_array += accum[2]; + sum_array += accum[3]; +#endif + for (; j < plane; j++) { + sum_array += in_ptr[0]; + in_ptr++; + } + int32_t mean = + RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(sum_array * (1 << (unsigned int)quant_arg.left_shift_), + quant_arg.multiplier_), + quant_arg.right_shift_); + mean += bias; + *out_ptr++ = MSMAX(MSMIN(mean, INT8_MAX), INT8_MIN); + } + } + return NNACL_OK; +} + +int ReduceMeanHC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} + +int ReduceMeanWC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} + +int ReduceMeanNHW(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} + +int ReduceMeanNHC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} + +int ReduceMeanNWC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} + +int ReduceMeanHWC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} + +int ReduceMeanNHWC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg) { + return NNACL_OK; +} + +// Get x such that (x-zp_in) * scale_in = mean +// Assuming reduce n axes, this works for first n-1 reduce. One call for one reduce. +int ReduceMeanInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const int32_t *outer_src = src_data + j * axis_size * inner_size; + int32_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const int32_t *inner_src = outer_src + k; + int32_t *inner_dst = outer_dst + k; + int32_t sum = 0; + for (i = 0; i < axis_size; i++) { + int32_t tmp = inner_src[i * inner_size] - quant->in_zp_; + if (isAddOverflow(sum, tmp)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + sum += tmp; + } + int32_t mean = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(sum * (1 << (unsigned int)quant->mean_left_shift_), quant->mean_multiplier_), + quant->mean_right_shift_); + if (isAddOverflow(mean, quant->in_zp_)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + *inner_dst = mean + quant->in_zp_; + } + } + return NNACL_OK; +} + +// suppose reduce n axes, this works for last reduce axis. +// get y such that (y-zp_out) * scale_out = mean(x-zp_in)*scale_in +int ReduceMeanLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const int32_t *outer_src = src_data + j * axis_size * inner_size; + int8_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const int32_t *inner_src = outer_src + k; + int8_t *inner_dst = outer_dst + k; + int32_t sum = 0; + for (i = 0; i < axis_size; i++) { + int32_t tmp = inner_src[i * inner_size] - quant->in_zp_; + if (isAddOverflow(tmp, sum)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + sum += tmp; + } + int32_t mean = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(sum * (1 << (unsigned int)quant->mean_left_shift_), quant->mean_multiplier_), + quant->mean_right_shift_); + // trans to output scale + int32_t mean_scaled = + RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(mean * (1 << (unsigned int)quant->in_out_left_shift_), + quant->in_out_multiplier_), + quant->in_out_right_shift_); + if (isAddOverflow(mean_scaled, quant->out_zp_)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + mean = mean_scaled + quant->out_zp_; + + *inner_dst = MSMAX(MSMIN(mean, INT8_MAX), INT8_MIN); + } + } + return NNACL_OK; +} + +// Get x such that (x-zp_in) * scale_in = sum(item-zp_in)*scale_in +// Assuming reduce n axes, this works for first n-1 reduce. One call for one reduce. +int ReduceSumInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const int32_t *outer_src = src_data + j * axis_size * inner_size; + int32_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const int32_t *inner_src = outer_src + k; + int32_t *inner_dst = outer_dst + k; + int32_t sum = 0; + for (i = 0; i < axis_size; i++) { + int32_t tmp = inner_src[i * inner_size] - quant->in_zp_; + if (isAddOverflow(tmp, sum)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + sum += tmp; + } + + if (isAddOverflow(quant->in_zp_, sum)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + *inner_dst = sum + quant->in_zp_; + } + } + return NNACL_OK; +} + +// suppose reduce n axes, this works for last reduce axis. +// get y such that (y-zp_out) * scale_out = sum(item-zp_in)*scale_in +int ReduceSumLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const int32_t *outer_src = src_data + j * axis_size * inner_size; + int8_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const int32_t *inner_src = outer_src + k; + int8_t *inner_dst = outer_dst + k; + int32_t sum = 0; + for (i = 0; i < axis_size; i++) { + int32_t tmp = inner_src[i * inner_size] - quant->in_zp_; + if (isAddOverflow(tmp, sum)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + sum += tmp; + } + int32_t sum_scaled = + RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(sum * (1 << (unsigned int)quant->in_out_left_shift_), + quant->in_out_multiplier_), + quant->in_out_right_shift_); + if (isAddOverflow(sum_scaled, quant->out_zp_)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + sum = sum_scaled + quant->out_zp_; + if (sum > INT8_MAX) { + *inner_dst = INT8_MAX; + } else if (sum < INT8_MIN) { + *inner_dst = INT8_MIN; + } else { + *inner_dst = (int8_t)sum; + } + } + } + return NNACL_OK; +} + +int ReduceMaxLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const int32_t *outer_src = src_data + j * axis_size * inner_size; + int8_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const int32_t *inner_src = outer_src + k; + int8_t *inner_dst = outer_dst + k; + int32_t tmp = INT8_MIN; + for (i = 0; i < axis_size; i++) { + tmp = tmp > inner_src[i * inner_size] ? tmp : inner_src[i * inner_size]; + } + int32_t tmp_scaled = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul((tmp - quant->in_zp_) * (1 << (unsigned int)quant->in_out_left_shift_), + quant->in_out_multiplier_), + quant->in_out_right_shift_); + if (isAddOverflow(tmp_scaled, quant->out_zp_)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + tmp = tmp_scaled + quant->out_zp_; + if (tmp > INT8_MAX) { + *inner_dst = INT8_MAX; + } else if (tmp < INT8_MIN) { + *inner_dst = INT8_MIN; + } else { + *inner_dst = (int8_t)tmp; + } + } + } + return NNACL_OK; +} + +int ReduceMaxInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const int32_t *outer_src = src_data + j * axis_size * inner_size; + int32_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const int32_t *inner_src = outer_src + k; + int32_t *inner_dst = outer_dst + k; + int32_t tmp = INT8_MIN; + for (i = 0; i < axis_size; i++) { + tmp = tmp > inner_src[i * inner_size] ? tmp : inner_src[i * inner_size]; + } + + *inner_dst = tmp; + } + } + return NNACL_OK; +} + +int ReduceMinLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + int i, j, k; + const int base_offset = 20; + for (j = tid; j < outer_size; j += thread_num) { + const int32_t *outer_src = src_data + j * axis_size * inner_size; + int8_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const int32_t *inner_src = outer_src + k; + int8_t *inner_dst = outer_dst + k; + int32_t tmp = INT8_MAX; + for (i = 0; i < axis_size; i++) { + tmp = tmp < inner_src[i * inner_size] ? tmp : inner_src[i * inner_size]; + } + int32_t tmp_scaled = + RoundingDivideByPOT(SaturatingRoundingDoublingHighMul( + (tmp - quant->in_zp_) * (1 << ((unsigned int)quant->in_out_left_shift_ + base_offset)), + quant->in_out_multiplier_), + quant->in_out_right_shift_ + base_offset); + if (isAddOverflow(tmp_scaled, quant->out_zp_)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + tmp = tmp_scaled + quant->out_zp_; + if (tmp > INT8_MAX) { + *inner_dst = INT8_MAX; + } else if (tmp < INT8_MIN) { + *inner_dst = INT8_MIN; + } else { + *inner_dst = (int8_t)tmp; + } + } + } + return NNACL_OK; +} + +int ReduceMinInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const int32_t *outer_src = src_data + j * axis_size * inner_size; + int32_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const int32_t *inner_src = outer_src + k; + int32_t *inner_dst = outer_dst + k; + int32_t tmp = INT8_MAX; + for (i = 0; i < axis_size; i++) { + tmp = tmp < inner_src[i * inner_size] ? tmp : inner_src[i * inner_size]; + } + *inner_dst = tmp; + } + } + return NNACL_OK; +} + +int ReduceProdLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const int32_t *outer_src = src_data + j * axis_size * inner_size; + int8_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const int32_t *inner_src = outer_src + k; + int8_t *inner_dst = outer_dst + k; + int32_t prod = 1; + for (i = 0; i < axis_size; i++) { + int32_t tmp = inner_src[i * inner_size] - quant->in_zp_; + if (isMulOverflow(prod, tmp)) { + return NNACL_ERRCODE_MUL_OVERFLOW; + } + prod *= tmp; + } + prod = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(prod * (1 << (unsigned int)quant->prod_left_shift_), quant->prod_multiplier_), + quant->prod_right_shift_); + int32_t prod_scaled = + RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(prod * (1 << (unsigned int)quant->in_out_left_shift_), + quant->in_out_multiplier_), + quant->in_out_right_shift_); + if (isAddOverflow(prod_scaled, quant->out_zp_)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + prod = prod_scaled + quant->out_zp_; + if (prod > INT8_MAX) { + *inner_dst = INT8_MAX; + } else if (prod < INT8_MIN) { + *inner_dst = INT8_MIN; + } else { + *inner_dst = (int8_t)prod; + } + } + } + return NNACL_OK; +} + +int ReduceProdInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const int32_t *outer_src = src_data + j * axis_size * inner_size; + int32_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const int32_t *inner_src = outer_src + k; + int32_t *inner_dst = outer_dst + k; + int32_t prod = 1; + for (i = 0; i < axis_size; i++) { + int32_t tmp = inner_src[i * inner_size] - quant->in_zp_; + if (isMulOverflow(prod, tmp)) { + return NNACL_ERRCODE_MUL_OVERFLOW; + } + prod *= tmp; + } + prod = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(prod * (1 << (unsigned int)quant->prod_left_shift_), quant->prod_multiplier_), + quant->prod_right_shift_); + if (isAddOverflow(prod, quant->in_zp_)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + *inner_dst = prod + quant->in_zp_; + } + } + return NNACL_OK; +} + +int ReduceSumSquareLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const int32_t *outer_src = src_data + j * axis_size * inner_size; + int8_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const int32_t *inner_src = outer_src + k; + int8_t *inner_dst = outer_dst + k; + int32_t sum = 0; + for (i = 0; i < axis_size; i++) { + int32_t tmp; + if (isMulOverflow(inner_src[i * inner_size] - quant->in_zp_, inner_src[i * inner_size] - quant->in_zp_)) { + return NNACL_ERRCODE_MUL_OVERFLOW; + } + tmp = (inner_src[i * inner_size] - quant->in_zp_) * (inner_src[i * inner_size] - quant->in_zp_); + if (isAddOverflow(sum, tmp)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + sum += tmp; + } + int32_t sum_scaled = + RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(sum * (1 << (unsigned int)quant->sum_square_left_shift_), + quant->sum_square_multiplier_), + quant->sum_square_right_shift_); + if (isAddOverflow(sum_scaled, quant->out_zp_)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + sum = sum_scaled + quant->out_zp_; + + if (sum > INT8_MAX) { + *inner_dst = INT8_MAX; + } else if (sum < INT8_MIN) { + *inner_dst = INT8_MIN; + } else { + *inner_dst = (int8_t)sum; + } + } + } + return NNACL_OK; +} + +int ReduceSumSquareInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + int i, j, k; + for (j = tid; j < outer_size; j += thread_num) { + const int32_t *outer_src = src_data + j * axis_size * inner_size; + int32_t *outer_dst = dst_data + j * inner_size; + for (k = 0; k < inner_size; k++) { + const int32_t *inner_src = outer_src + k; + int32_t *inner_dst = outer_dst + k; + int32_t sum = 0; + for (i = 0; i < axis_size; i++) { + int32_t tmp; + if (isMulOverflow(inner_src[i * inner_size] - quant->in_zp_, inner_src[i * inner_size] - quant->in_zp_)) { + return NNACL_ERRCODE_MUL_OVERFLOW; + } + tmp = (inner_src[i * inner_size] - quant->in_zp_) * (inner_src[i * inner_size] - quant->in_zp_); + if (isAddOverflow(sum, tmp)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + sum += tmp; + } + sum = + RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(sum * (1 << (unsigned int)quant->sum_square_left_shift_), + quant->sum_square_multiplier_), + quant->sum_square_right_shift_); + if (isAddOverflow(sum, quant->in_zp_)) { + return NNACL_ERRCODE_ADD_OVERFLOW; + } + *inner_dst = sum + quant->in_zp_; + } + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/reduce_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/reduce_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..ec1c7d38a8024753a509100dbbe856cc9f0105ef --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/reduce_int8.h @@ -0,0 +1,70 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_REDUCE_INT8_H_ +#define NNACL_INT8_REDUCE_INT8_H_ + +#include "nnacl/int8/quantize.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int ReduceMeanN(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); +int ReduceMeanH(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); +int ReduceMeanW(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); +int ReduceMeanC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); +int ReduceMeanNH(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); +int ReduceMeanNW(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); +int ReduceMeanNC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); +int ReduceMeanHW(int n, int plane, int count, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg, + int32_t bias); +int ReduceMeanHC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); +int ReduceMeanWC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); +int ReduceMeanNHW(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); +int ReduceMeanNHC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); +int ReduceMeanNWC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); +int ReduceMeanHWC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); +int ReduceMeanNHWC(int n, int h, int w, int c, int8_t *in_data, int8_t *out_data, QuantMulArg quant_arg); + +int ReduceMeanInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num); +int ReduceMeanLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num); +int ReduceSumInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num); +int ReduceSumLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num); +int ReduceMaxInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num); +int ReduceMaxLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num); +int ReduceMinInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num); +int ReduceMinLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num); +int ReduceProdLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num); +int ReduceProdInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num); +int ReduceSumSquareLastAxis(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int8_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num); +int ReduceSumSquareInt8(const int outer_size, const int inner_size, const int axis_size, const int32_t *src_data, + int32_t *dst_data, const ReduceQuantArg *quant, const int tid, const int thread_num); +#ifdef __cplusplus +} +#endif +#endif // NNACL_INT8_REDUCE_INT8_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/relux_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/relux_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..230e2bdd4060affcecb9db0ac5daa25f4e2763b7 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/relux_int8.c @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/int8/relux_int8.h" + +void ReluXInt8(const int8_t *src, int length, int8_t *dst, const ReluXQuantArg *arg) { + for (int i = 0; i < length; ++i) { + if (src[i] <= arg->input_arg.zp_) { + dst[i] = arg->output_arg.zp_; + continue; + } + const int32_t input_val = src[i] - arg->input_arg.zp_; + const int32_t scaled_input = SaturatingRoundingDoublingHighMul(input_val, arg->input_multiplier_); + const int32_t shifted_input = RoundingDivideByPOT(scaled_input * (1U << arg->left_shift_), -arg->right_shift_); + const int32_t output = shifted_input + arg->output_arg.zp_; + dst[i] = (int8_t)MSMIN(output, arg->quantized_output_max); + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/relux_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/relux_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..7a3fa793bb605c5487e3cd2d38fc3598bd0dca7a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/relux_int8.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_INT8_RELU_INT8_H_ +#define NNACL_INT8_RELU_INT8_H_ + +#include +#include "nnacl/op_base.h" +#include "nnacl/errorcode.h" +#include "nnacl/int8/fixed_point.h" +#include "nnacl/int8/quantize.h" + +typedef struct ReluXQuantArg { + QuantArg input_arg; + QuantArg output_arg; + int input_multiplier_; + int left_shift_; + int right_shift_; + int quantized_output_min; + int quantized_output_max; +} ReluXQuantArg; + +#ifdef __cplusplus +extern "C" { +#endif +void ReluXInt8(const int8_t *src, int length, int8_t *dst, const ReluXQuantArg *arg); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_RELU_INT8_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/reshape_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/reshape_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..aeef34034f09afac4483ec748a9aa4eeda9c8051 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/reshape_int8.c @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/reshape_int8.h" +#include "nnacl/reshape_parameter.h" +#include + +void Int8Reshape(const int8_t *input_ptr, int8_t *output_ptr, int64_t real_dst_count, ReshapeQuantArg para) { + if (para.in_args_.scale_ == para.out_args_.scale_ && para.in_args_.zp_ == para.out_args_.zp_) { + memcpy(output_ptr, input_ptr, real_dst_count); + } else { + const float output_inverse_scale = 1.f / para.out_args_.scale_; + float scale = para.in_args_.scale_ * output_inverse_scale; + float bias = -para.in_args_.zp_ * scale; + int32_t output_zp = para.out_args_.zp_; + for (int i = 0; i < real_dst_count; i++) { + int32_t output_tmp = round(input_ptr[i] * scale + bias) + output_zp; + if (output_tmp > para.output_activation_max_) { + output_ptr[i] = para.output_activation_max_; + } else if (output_tmp < para.output_activation_min_) { + output_ptr[i] = para.output_activation_min_; + } else { + output_ptr[i] = (int8_t)output_tmp; + } + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/reshape_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/reshape_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..e6ba4ad47b60bc89ac4a148c647d7f4217829ab8 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/reshape_int8.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_RESHAHPE_INT8_H_ +#define NNACL_INT8_RESHAHPE_INT8_H_ + +#include +#include "nnacl/op_base.h" +#include "nnacl/reshape_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +void Int8Reshape(const int8_t *input_ptr, int8_t *output_ptr, int64_t real_dst_count, ReshapeQuantArg para); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_RESHAHPE_INT8_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/resize_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/resize_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..29927bb0501dcbaf2b133727f304bce958dc977a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/resize_int8.c @@ -0,0 +1,233 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "nnacl/int8/resize_int8.h" +#include "nnacl/common_func.h" +#include "nnacl/int8/fixed_point.h" +#include "nnacl/errorcode.h" + +int ResizeBilinearInt8(const int8_t *input_ptr, int8_t *output_ptr, int batch, int in_h, int in_w, int out_h, int out_w, + int channel, int index, int count, ResizeQuantArg quant_arg) { + if (out_w == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + int in_plane = in_h * in_w; + int out_plane = out_h * out_w; + for (int n = 0; n < batch; n++) { + const int8_t *in_b_ptr = input_ptr + n * in_plane * channel; + int8_t *out_b_ptr = output_ptr + n * out_plane * channel; + for (int t = 0; t < count; t++) { + int ori_out_h = (index + t) / out_w; + int ori_out_w = (index + t) % out_w; + int32_t x_lower_value = quant_arg.x_axis_lower_[ori_out_w]; + int32_t x_upper_value = quant_arg.x_axis_upper_[ori_out_w]; + int32_t y_lower_value = quant_arg.y_axis_lower_[ori_out_h]; + int32_t y_upper_value = quant_arg.y_axis_upper_[ori_out_h]; + int32_t weight_x = quant_arg.x_axis_index_[ori_out_w] - (1 << 10) * x_lower_value; + int32_t one_minus_weight_x = (1 << 10) - weight_x; + int32_t weight_y = quant_arg.y_axis_index_[ori_out_h] - (1 << 10) * y_lower_value; + int32_t one_minus_weight_y = (1 << 10) - weight_y; + int64_t left_bottom_coef = (int64_t)(one_minus_weight_x * one_minus_weight_y); + int64_t left_top_coef = (int64_t)(weight_y * one_minus_weight_x); + int64_t right_bottom_coef = (int64_t)(weight_x * one_minus_weight_y); + int64_t right_top_coef = (int64_t)(weight_x * weight_y); + int input_lb_index = (y_lower_value * in_w + x_lower_value) * channel; + int input_lt_index = (y_upper_value * in_w + x_lower_value) * channel; + int input_rb_index = (y_lower_value * in_w + x_upper_value) * channel; + int input_rt_index = (y_upper_value * in_w + x_upper_value) * channel; + int c = 0; + for (; c < channel; c++) { + int64_t out_left_bottom = left_bottom_coef * in_b_ptr[input_lb_index]; + int64_t out_left_top = left_top_coef * in_b_ptr[input_lt_index]; + int64_t out_right_bottom = right_bottom_coef * in_b_ptr[input_rb_index]; + int64_t out_right_top = right_top_coef * in_b_ptr[input_rt_index]; + int64_t out_value = out_left_bottom + out_left_top + out_right_bottom + out_right_top; + out_b_ptr[0] = (int8_t)((out_value + (1 << 19)) / (1 << 20)); + input_lb_index++; + input_lt_index++; + input_rb_index++; + input_rt_index++; + out_b_ptr++; + } + } + } + return NNACL_OK; +} + +int ResizeBilinearWithFloatScaleInt8(const int8_t *input_ptr, int8_t *output_ptr, int batch, int in_h, int in_w, + int out_h, int out_w, int channel, int index, int count, + ResizeFloatScaleQuantArg quant_arg) { + if (out_w == 0) { + return NNACL_ERRCODE_DIVISOR_ZERO; + } + int in_plane = in_h * in_w; + int out_plane = out_h * out_w; + for (int n = 0; n < batch; n++) { + const int8_t *in_b_ptr = input_ptr + n * in_plane * channel; + int8_t *out_b_ptr = output_ptr + n * out_plane * channel; + for (int t = 0; t < count; t++) { + int ori_out_h = (index + t) / out_w; + int ori_out_w = (index + t) % out_w; + int32_t x_lower_value = quant_arg.x_axis_lower_[ori_out_w]; + int32_t x_upper_value = quant_arg.x_axis_upper_[ori_out_w]; + int32_t y_lower_value = quant_arg.y_axis_lower_[ori_out_h]; + int32_t y_upper_value = quant_arg.y_axis_upper_[ori_out_h]; + float weight_x = quant_arg.x_axis_index_[ori_out_w] - x_lower_value; + const float one_minus_weight_x = 1 - weight_x; + float weight_y = quant_arg.y_axis_index_[ori_out_h] - y_lower_value; + const float one_minus_weight_y = 1 - weight_y; + float left_bottom_coef = one_minus_weight_x * one_minus_weight_y; + float left_top_coef = weight_y * one_minus_weight_x; + float right_bottom_coef = weight_x * one_minus_weight_y; + float right_top_coef = weight_x * weight_y; + int input_lb_index = (y_lower_value * in_w + x_lower_value) * channel; + int input_lt_index = (y_upper_value * in_w + x_lower_value) * channel; + int input_rb_index = (y_lower_value * in_w + x_upper_value) * channel; + int input_rt_index = (y_upper_value * in_w + x_upper_value) * channel; + int c = 0; +#ifdef ENABLE_ARM + for (; c <= channel - 4; c += 4) { + float32x4_t in_lb; + in_lb[0] = (float)in_b_ptr[input_lb_index]; + in_lb[1] = (float)in_b_ptr[input_lb_index + 1]; + in_lb[2] = (float)in_b_ptr[input_lb_index + 2]; + in_lb[3] = (float)in_b_ptr[input_lb_index + 3]; + float32x4_t out_left_bottom = vmulq_n_f32(in_lb, left_bottom_coef); + float32x4_t in_lt; + in_lt[0] = (float)in_b_ptr[input_lt_index]; + in_lt[1] = (float)in_b_ptr[input_lt_index + 1]; + in_lt[2] = (float)in_b_ptr[input_lt_index + 2]; + in_lt[3] = (float)in_b_ptr[input_lt_index + 3]; + float32x4_t out_left_top = vmulq_n_f32(in_lt, left_top_coef); + float32x4_t in_rb; + in_rb[0] = (float)in_b_ptr[input_rb_index]; + in_rb[1] = (float)in_b_ptr[input_rb_index + 1]; + in_rb[2] = (float)in_b_ptr[input_rb_index + 2]; + in_rb[3] = (float)in_b_ptr[input_rb_index + 3]; + float32x4_t out_right_bottom = vmulq_n_f32(in_rb, right_bottom_coef); + float32x4_t in_rt; + in_rt[0] = (float)in_b_ptr[input_rt_index]; + in_rt[1] = (float)in_b_ptr[input_rt_index + 1]; + in_rt[2] = (float)in_b_ptr[input_rt_index + 2]; + in_rt[3] = (float)in_b_ptr[input_rt_index + 3]; + float32x4_t out_right_top = vmulq_n_f32(in_rt, right_top_coef); + float32x4_t out_value1 = vaddq_f32(out_left_bottom, out_left_top); + float32x4_t out_value2 = vaddq_f32(out_right_top, out_right_bottom); + float32x4_t out_value = vaddq_f32(out_value1, out_value2); + out_b_ptr[0] = (int8_t)(out_value[0]); + out_b_ptr[1] = (int8_t)(out_value[1]); + out_b_ptr[2] = (int8_t)(out_value[2]); + out_b_ptr[3] = (int8_t)(out_value[3]); + input_lb_index += 4; + input_lt_index += 4; + input_rb_index += 4; + input_rt_index += 4; + out_b_ptr += 4; + } +#endif + for (; c < channel; c++) { + float out_left_bottom = left_bottom_coef * in_b_ptr[input_lb_index]; + float out_left_top = left_top_coef * in_b_ptr[input_lt_index]; + float out_right_bottom = right_bottom_coef * in_b_ptr[input_rb_index]; + float out_right_top = right_top_coef * in_b_ptr[input_rt_index]; + float out_value = out_left_bottom + out_left_top + out_right_bottom + out_right_top; + out_b_ptr[0] = (int8_t)(out_value); + input_lb_index++; + input_lt_index++; + input_rb_index++; + input_rt_index++; + out_b_ptr++; + } + } + } + return NNACL_OK; +} + +int ResizeNearestNeighborInt8Simple(const int8_t *input_data, int8_t *output_data, const int32_t *input_shape, + const int32_t *output_shape, const bool align_corners, int tid, int thread_num) { + int batch, y, x, c; + c = output_shape[3]; + int in_h, in_w, new_height, new_width; + in_h = input_shape[1]; + in_w = input_shape[2]; + new_height = output_shape[1]; + new_width = output_shape[2]; + + for (batch = 0; batch < output_shape[0]; batch++) { + for (y = tid; y < output_shape[1]; y += thread_num) { + int input_y = 0; + ComputeNearestNeighborInt(y, in_h, new_height, align_corners, &input_y); + for (x = 0; x < output_shape[2]; x++) { + int input_x = 0; + ComputeNearestNeighborInt(x, in_w, new_width, align_corners, &input_x); + int in_offset = Offset(input_shape, batch, input_y, input_x, 0); + int out_offset = Offset(output_shape, batch, y, x, 0); + memcpy(output_data + out_offset, input_data + in_offset, c * sizeof(int8_t)); + } + } + } + + return NNACL_OK; +} + +void ComputeNearestNeighborInt(const int32_t pos, const int in_size, const int32_t new_size, const bool align_corners, + int32_t *nearest) { + if (new_size == 0) { + return; + } + *nearest = (in_size * pos) / new_size; + if (align_corners && new_size != 1) { + *nearest = ((in_size - 1) * pos + (new_size - 1) / 2) / (new_size - 1); + } + *nearest = *nearest < in_size ? *nearest : in_size - 1; +} + +int ResizeNearestNeighborInt8(const int8_t *input_data, int8_t *output_data, const int32_t *input_shape, + const int32_t *output_shape, const bool align_corners, const QuantMulArg *multiplier, + const QuantArg *quant_in, const QuantArg *quant_out, int tid, int thread_num) { + const int base_offset = 20; + int32_t batch, y, x, c; + int32_t in_h, in_w, new_height, new_width; + in_h = input_shape[1]; + in_w = input_shape[2]; + new_height = output_shape[1]; + new_width = output_shape[2]; + + for (batch = 0; batch < output_shape[0]; batch++) { + for (y = tid; y < output_shape[1]; y += thread_num) { + int input_y = 0; + ComputeNearestNeighborInt(y, in_h, new_height, align_corners, &input_y); + for (x = 0; x < output_shape[2]; x++) { + int input_x = 0; + ComputeNearestNeighborInt(x, in_w, new_width, align_corners, &input_x); + for (c = 0; c < output_shape[3]; c++) { + int in_offset = Offset(input_shape, batch, input_y, input_x, c); + int out_offset = Offset(output_shape, batch, y, x, c); + + int32_t out_value = MultiplyByQuantizedMultiplier( + input_data[in_offset] - quant_in->zp_, multiplier->multiplier_, + multiplier->left_shift_ + base_offset, multiplier->right_shift_ - base_offset) + + quant_out->zp_; + out_value = out_value > INT8_MAX ? INT8_MAX : out_value; + out_value = out_value < INT8_MIN ? INT8_MIN : out_value; + output_data[out_offset] = (int8_t)out_value; + } + } + } + } + + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/resize_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/resize_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..31b088efb470ece430b4a8938f25a293e134e920 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/resize_int8.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_INT8_RESIZE_H_ +#define NNACL_INT8_RESIZE_H_ + +#ifdef ENABLE_NEON +#include +#endif +#include +#include "nnacl/op_base.h" +#include "nnacl/int8/quantize.h" +#include "nnacl/resize_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int ResizeBilinearInt8(const int8_t *input_ptr, int8_t *output_ptr, int batch, int in_h, int in_w, int out_h, int out_w, + int channel, int index, int count, ResizeQuantArg quant_arg); + +int ResizeBilinearWithFloatScaleInt8(const int8_t *input_ptr, int8_t *output_ptr, int batch, int in_h, int in_w, + int out_h, int out_w, int channel, int index, int count, + ResizeFloatScaleQuantArg quant_arg); + +int ResizeNearestNeighborInt8Simple(const int8_t *input_data, int8_t *output_data, const int32_t *input_shape, + const int32_t *output_shape, const bool align_corners, int tid, int thread_num); + +int ResizeNearestNeighborInt8(const int8_t *input_data, int8_t *output_data, const int32_t *input_shape, + const int32_t *output_shape, const bool align_corners, const QuantMulArg *multiplier, + const QuantArg *quant_in, const QuantArg *quant_out, int tid, int thread_num); + +void ComputeNearestNeighborInt(const int32_t pos, const int in_size, const int32_t new_size, const bool align_corners, + int32_t *nearest); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_RESIZE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/scale_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/scale_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..438d88fe363976ed2de37e381049de4851fa5008 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/scale_int8.c @@ -0,0 +1,164 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/scale_int8.h" +#include "nnacl/int8/fixed_point.h" + +#ifdef ENABLE_NEON +int16x4_t ClacSumHalfWordMul2(int32x4_t scaled_input0, int32x4_t scaled_input1, int32x4_t left_shift_out_vec, + int32x4_t output_multiplier_vec, const ScaleQuantParameter *scale_param) { + int32x4_t input_scale = vmulq_s32(scaled_input0, scaled_input1); + int32x4_t raw_sum = RoundingDivideByPOTInt32x4( + SaturatingRoundingDoublingHighMulInt32x4(vmulq_s32(input_scale, left_shift_out_vec), output_multiplier_vec), + scale_param->scale_mul_arg_.right_shift_); + raw_sum = vaddq_s32(raw_sum, vdupq_n_s32(scale_param->output_zp_)); + raw_sum = vmaxq_s32(raw_sum, vdupq_n_s32(scale_param->output_activation_min_)); + raw_sum = vminq_s32(raw_sum, vdupq_n_s32(scale_param->output_activation_max_)); + return vqmovn_s32(raw_sum); +} + +int16x4_t ClacSumHalfWordMul3(int32x4_t scaled_input0, int32x4_t scaled_input1, int32x4_t scaled_input2, + const ScaleQuantParameter *scale_param) { + int32x4_t output_multiplier_vec = vdupq_n_s32(scale_param->scale_mul_arg_.multiplier_); + int32x4_t output_multiplier_vec2 = vdupq_n_s32(scale_param->offset_mul_arg_.multiplier_); + int32x4_t left_shift_out_vec = vdupq_n_s32(1 << (size_t)(scale_param->scale_mul_arg_.left_shift_)); + int32x4_t left_shift_out_vec2 = vdupq_n_s32(1 << (size_t)(scale_param->offset_mul_arg_.left_shift_)); + int32x4_t input_scale = vmulq_s32(scaled_input0, scaled_input1); + int32x4_t raw_sum = RoundingDivideByPOTInt32x4( + SaturatingRoundingDoublingHighMulInt32x4(vmulq_s32(input_scale, left_shift_out_vec), output_multiplier_vec), + scale_param->scale_mul_arg_.right_shift_); + int32x4_t raw_sum2 = RoundingDivideByPOTInt32x4( + SaturatingRoundingDoublingHighMulInt32x4(vmulq_s32(scaled_input2, left_shift_out_vec2), output_multiplier_vec2), + scale_param->offset_mul_arg_.right_shift_); + raw_sum = vaddq_s32(raw_sum, vdupq_n_s32(scale_param->output_zp_)); + raw_sum = vaddq_s32(raw_sum, raw_sum2); + raw_sum = vmaxq_s32(raw_sum, vdupq_n_s32(scale_param->output_activation_min_)); + raw_sum = vminq_s32(raw_sum, vdupq_n_s32(scale_param->output_activation_max_)); + return vqmovn_s32(raw_sum); +} +#endif + +void DoScaleInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, const ScaleQuantParameter *scale_param, + int real_dst_count) { + int index = 0; +#ifdef ENABLE_NEON + int32x4_t output_multiplier_vec = vdupq_n_s32(scale_param->scale_mul_arg_.multiplier_); + int32x4_t left_shift_out_vec = vdupq_n_s32(1 << scale_param->scale_mul_arg_.left_shift_); + + for (; index <= real_dst_count - 8; index += 8) { + int8x8_t input_s8 = vld1_s8(in_data + index); + int16x8_t input_s16 = vmovl_s8(input_s8); + int16x8_t input0_val = vaddq_s16(input_s16, vdupq_n_s16(scale_param->input_zp_)); + + int8x8_t input1_s8 = vld1_s8(scale + index); + int16x8_t input1_s16 = vmovl_s8(input1_s8); + int16x8_t input1_val = vaddq_s16(input1_s16, vdupq_n_s16(scale_param->scale_zp_)); + + int32x4_t input0_low = vmovl_s16(vget_low_s16(input0_val)); + int32x4_t input0_high = vmovl_s16(vget_high_s16(input0_val)); + int32x4_t input1_low = vmovl_s16(vget_low_s16(input1_val)); + int32x4_t input1_high = vmovl_s16(vget_high_s16(input1_val)); + + int16x4_t sum_low = + ClacSumHalfWordMul2(input0_low, input1_low, left_shift_out_vec, output_multiplier_vec, scale_param); + int16x4_t sum_high = + ClacSumHalfWordMul2(input0_high, input1_high, left_shift_out_vec, output_multiplier_vec, scale_param); + + int16x8_t res_s16 = vcombine_s16(sum_low, sum_high); + int8x8_t res_u8_n0 = vqmovn_s16(res_s16); + vst1_s8(out_data, res_u8_n0); + out_data += 8; + } +#endif + for (; index < real_dst_count; ++index) { + const int32_t input0_val = scale_param->input_zp_ + in_data[index]; + const int32_t input1_val = scale_param->scale_zp_ + scale[index]; + int32_t mul_result = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(input0_val * input1_val * (1 << scale_param->scale_mul_arg_.left_shift_), + scale_param->scale_mul_arg_.multiplier_), + scale_param->scale_mul_arg_.right_shift_); + + mul_result += scale_param->output_zp_; + + if (mul_result > scale_param->output_activation_max_) { + out_data[index] = scale_param->output_activation_max_; + } else if (mul_result < scale_param->output_activation_min_) { + out_data[index] = scale_param->output_activation_min_; + } else { + out_data[index] = (int8_t)mul_result; + } + } + return; +} + +void DoScaleWithBiasInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, const int8_t *offset, + const ScaleQuantParameter *scale_param, int real_dst_count) { + int index = 0; +#ifdef ENABLE_NEON + for (; index <= real_dst_count - 8; index += 8) { + int8x8_t input_s8 = vld1_s8(in_data + index); + int16x8_t input_s16 = vmovl_s8(input_s8); + int16x8_t input0_val = vaddq_s16(input_s16, vdupq_n_s16(scale_param->input_zp_)); + + int8x8_t input1_s8 = vld1_s8(scale + index); + int16x8_t input1_s16 = vmovl_s8(input1_s8); + int16x8_t input1_val = vaddq_s16(input1_s16, vdupq_n_s16(scale_param->scale_zp_)); + + int8x8_t input2_s8 = vld1_s8(offset + index); + int16x8_t input2_s16 = vmovl_s8(input2_s8); + int16x8_t input2_val = vaddq_s16(input2_s16, vdupq_n_s16(scale_param->offset_zp_)); + + int32x4_t input0_low = vmovl_s16(vget_low_s16(input0_val)); + int32x4_t input0_high = vmovl_s16(vget_high_s16(input0_val)); + int32x4_t input1_low = vmovl_s16(vget_low_s16(input1_val)); + int32x4_t input1_high = vmovl_s16(vget_high_s16(input1_val)); + int32x4_t input2_low = vmovl_s16(vget_low_s16(input2_val)); + int32x4_t input2_high = vmovl_s16(vget_high_s16(input2_val)); + + int16x4_t sum_low = ClacSumHalfWordMul3(input0_low, input1_low, input2_low, scale_param); + int16x4_t sum_high = ClacSumHalfWordMul3(input0_high, input1_high, input2_high, scale_param); + + int16x8_t res_s16 = vcombine_s16(sum_low, sum_high); + int8x8_t res_u8_n0 = vqmovn_s16(res_s16); + vst1_s8(out_data, res_u8_n0); + out_data += 8; + } +#endif + for (; index < real_dst_count; ++index) { + const int32_t input0_val = in_data[index] - scale_param->input_zp_; + const int32_t input1_val = scale[index] - scale_param->scale_zp_; + int32_t mul_result = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(input0_val * input1_val * (1 << scale_param->scale_mul_arg_.left_shift_), + scale_param->scale_mul_arg_.multiplier_), + scale_param->scale_mul_arg_.right_shift_); + int tmp_bias = offset[index] - scale_param->offset_zp_; + int bias = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(tmp_bias * (1 << (unsigned int)scale_param->offset_mul_arg_.left_shift_), + scale_param->offset_mul_arg_.multiplier_), + scale_param->offset_mul_arg_.right_shift_); + + mul_result += bias + scale_param->output_zp_; + + if (mul_result > scale_param->output_activation_max_) { + out_data[index] = scale_param->output_activation_max_; + } else if (mul_result < scale_param->output_activation_min_) { + out_data[index] = scale_param->output_activation_min_; + } else { + out_data[index] = (int8_t)mul_result; + } + } + return; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/scale_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/scale_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..14bedf43bd8ed02f1151c64d656709f243a3a20d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/scale_int8.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_SCALE_INT8_H_ +#define NNACL_SCALE_INT8_H_ + +#include "nnacl/op_base.h" +#include "nnacl/scale_parameter.h" +#include "nnacl/nnacl_common.h" + +#ifdef __cplusplus +extern "C" { +#endif +void DoScaleInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, const ScaleQuantParameter *scale_param, + int real_dst_count); +void DoScaleWithBiasInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, const int8_t *offset, + const ScaleQuantParameter *scale_param, int real_dst_count); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_SCALE_INT8_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/sigmoid_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/sigmoid_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..cf157567b4f0f86c34ffcb5010c79fd2ca394c87 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/sigmoid_int8.c @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/sigmoid_int8.h" + +int SigmoidInt8(const int8_t *src, int length, int8_t *dst, int8_t *table) { + for (int i = 0; i < length; i++) { + const int8_t input_value = src[i]; + uint8_t index = (uint8_t)input_value; + dst[i] = table[index]; + } + return 0; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/sigmoid_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/sigmoid_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..33b73cfff79d5898d919206b7375f41343abd0a4 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/sigmoid_int8.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_INT8_SIGMOID_INT8_H_ +#define NNACL_INT8_SIGMOID_INT8_H_ + +#include +#include "nnacl/op_base.h" +#include "nnacl/errorcode.h" +#include "nnacl/int8/fixed_point.h" + +#ifdef __cplusplus +extern "C" { +#endif +int SigmoidInt8(const int8_t *src, int length, int8_t *dst, int8_t *table); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_SIGMOID_INT8_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/slice_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/slice_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..50223bfc3ee001789169ddfdf2e9b796b0b9ad9e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/slice_int8.c @@ -0,0 +1,97 @@ +/** + * Copyright 2020-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/slice_int8.h" +#include +#include +#include "nnacl/errorcode.h" + +int SliceInt8(const int8_t *input, int8_t *output, const SliceStruct *param, const SliceQuantArg *quant_arg, + int thread_id, int thread_num) { + double input_scale = quant_arg->in_args_.scale_; + int input_zp = quant_arg->in_args_.zp_; + double output_scale = quant_arg->out_args_.scale_; + int output_zp = quant_arg->out_args_.zp_; + const int base_offset = 20; + int act_min = quant_arg->output_activation_min_; + int act_max = quant_arg->output_activation_max_; + + size_t out_stride[8]; + out_stride[7] = 1; + for (int i = 6; i >= 0; --i) { + out_stride[i] = out_stride[i + 1] * param->size_[i + 1]; + } + + int count_per_thread = UP_DIV(param->size_[5], thread_num); + size_t thread_begin = thread_id * count_per_thread; + size_t thread_end = MSMIN(param->size_[5], thread_begin + count_per_thread); + int unit_size = param->size_[7] * sizeof(int8_t); + size_t in_stride[8]; + in_stride[7] = 1; + for (int i = 6; i >= 0; --i) { + in_stride[i] = param->shape_[i + 1] * in_stride[i + 1]; + } + int i, j, k, l, n, h, w, c; + + int equal_quant = 0; + if (fabs(input_scale - output_scale) <= FLT_EPSILON && input_zp == output_zp) { + equal_quant = 1; + } + + for (i = 0; i < param->size_[0]; ++i) { + size_t out_offset0 = i * out_stride[0]; + size_t in_offset0 = (i + param->begin_[0]) * in_stride[0] + param->begin_[7]; + for (j = 0; j < param->size_[1]; ++j) { + size_t out_offset1 = j * out_stride[1] + out_offset0; + size_t in_offset1 = (j + param->begin_[1]) * in_stride[1] + in_offset0; + for (k = 0; k < param->size_[2]; ++k) { + size_t out_offset2 = k * out_stride[2] + out_offset1; + size_t in_offset2 = (k + param->begin_[2]) * in_stride[2] + in_offset1; + for (l = 0; l < param->size_[3]; ++l) { + size_t out_offset3 = l * out_stride[3] + out_offset2; + size_t in_offset3 = (l + param->begin_[3]) * in_stride[3] + in_offset2; + for (n = 0; n < param->size_[4]; ++n) { + size_t out_offset4 = n * out_stride[4] + out_offset3; + size_t in_offset4 = (n + param->begin_[4]) * in_stride[4] + in_offset3; + for (h = thread_begin; h < thread_end; ++h) { + size_t out_offset5 = h * out_stride[5] + out_offset4; + size_t in_offset5 = (h + param->begin_[5]) * in_stride[5] + in_offset4; + for (w = 0; w < param->size_[6]; ++w) { + size_t out_offset = w * out_stride[6] + out_offset5; + size_t in_offset = (w + param->begin_[6]) * in_stride[6] + in_offset5; + if (equal_quant == 1) { + memcpy(output + out_offset, input + in_offset, unit_size); + } else { + for (c = 0; c < param->size_[7]; ++c) { + int32_t output_val = + MultiplyByQuantizedMultiplier(input[in_offset + c] - input_zp, quant_arg->multiplier_.multiplier_, + quant_arg->multiplier_.left_shift_ + base_offset, + quant_arg->multiplier_.right_shift_ - base_offset) + + output_zp; + output_val = MSMAX(INT8_MIN, MSMIN(output_val, INT8_MAX)); + output[c + out_offset] = (int8_t)MSMAX(act_min, MSMIN(output_val, act_max)); + } + } + } + } + } + } + } + } + } + + return 0; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/slice_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/slice_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..e19ce3627d97523eafd212ce2244ba731c725db3 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/slice_int8.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_INT8_SLICE_INT8_H_ +#define NNACL_INT8_SLICE_INT8_H_ + +#include +#include +#include "nnacl/op_base.h" +#include "nnacl/slice_parameter.h" +#include "nnacl/int8/fixed_point.h" +#include "nnacl/kernel/slice.h" + +#ifdef __cplusplus +extern "C" { +#endif +int SliceInt8(const int8_t *input, int8_t *output, const SliceStruct *param, const SliceQuantArg *quant_arg, + int thread_id, int thread_num); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_SLICE_INT8_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/softmax_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/softmax_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..eaa8bc49c7a3d4bd3b2ac1b0011b64dac89726bc --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/softmax_int8.c @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/softmax_int8.h" +#include "nnacl/errorcode.h" + +int SoftmaxInt8(const int8_t *input_ptr, int8_t *output_ptr, int count, int32_t *exp_data, int32_t *sum_data, + const int32_t *input_shape, int n_dim, int32_t axis, const SoftmaxQuantArg *quant_param) { + int axis_shape_size = input_shape[axis]; + int inner_size = 1; + if (n_dim > DIMENSION_5D) { + return NNACL_ERR; + } + for (int i = axis + 1; i < n_dim; i++) { + inner_size *= input_shape[i]; + } + + for (int o = 0; o < count; o++) { + int outter_offset = o * axis_shape_size * inner_size; + + for (int c = 0; c < inner_size; c++) { + int8_t max_row = quant_param->output_activation_min_; + for (int i = 0; i < axis_shape_size; ++i) { + int axis_offset = outter_offset + c + i * inner_size; + max_row = MSMAX(max_row, input_ptr[axis_offset]); + } + + int32_t exp_sum = 0; + for (int i = 0; i < axis_shape_size; ++i) { + int axis_offset = outter_offset + c + i * inner_size; + const int32_t input_val = input_ptr[axis_offset] - max_row; + const int32_t input_scaled = SaturatingRoundingDoublingHighMul( + input_val * (1 << (unsigned int)quant_param->shift_left_), quant_param->output_multiplier_); + int exp_val = exp_on_negative_values(input_scaled, 5); + exp_data[axis_offset] = exp_val; + exp_sum = exp_sum + Rescale(exp_val, 0, 12); + } + sum_data[c] = exp_sum; + } + for (int i = 0; i < axis_shape_size; ++i) { + int axis_offset = outter_offset + i * inner_size; + for (int c = 0; c < inner_size; ++c) { + int num_bits_over_unit; + int shifted_scale = ComputerReciprocal(sum_data[c], 12, &num_bits_over_unit); + int unsat_output = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(shifted_scale, exp_data[axis_offset + c]), num_bits_over_unit + 31 - 8); + + int raw_output = unsat_output + quant_param->output_activation_min_; + output_ptr[axis_offset + c] = + (int8_t)MSMAX(quant_param->output_activation_min_, MSMIN(raw_output, quant_param->output_activation_max_)); + } + } + } + return 0; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/softmax_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/softmax_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..ae90ef5ba16d7d817f01330cd4af74747980d7c3 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/softmax_int8.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_SOFTMAX_INT8_H_ +#define NNACL_INT8_SOFTMAX_INT8_H_ + +#include +#include "nnacl/op_base.h" +#include "nnacl/softmax_parameter.h" +#include "nnacl/int8/fixed_point.h" +#include "nnacl/int8/quantize.h" + +#ifdef __cplusplus +extern "C" { +#endif +int SoftmaxInt8(const int8_t *input_ptr, int8_t *output_ptr, int count, int32_t *exp_data, int32_t *sum_data, + const int32_t *input_shape, int n_dim, int32_t axis, const SoftmaxQuantArg *quant_param); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_SOFTMAX_INT8_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/space_to_batch_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/space_to_batch_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..76eb668341dc11bbc0d4bf1959d1512a045d3153 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/space_to_batch_int8.c @@ -0,0 +1,88 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/int8/space_to_batch_int8.h" +#include "nnacl/common_func.h" + +void DoSpaceToBatchNHWCInt8(const int8_t *input, int8_t *output, const int32_t *block_sizes, const int32_t *in_shape, + const int32_t *out_shape) { + int out_dim0 = out_shape[0]; + int out_dim1 = out_shape[1]; + int out_dim2 = out_shape[2]; + int copy_num = out_shape[3]; + int block_w = block_sizes[1]; + int block_h = block_sizes[0]; + int in_strides[4] = {0}; + ComputeStrides(in_shape, in_strides, 4); + int out_strides[4] = {0}; + ComputeStrides(out_shape, out_strides, 4); + size_t copy_size = copy_num * sizeof(int8_t); + size_t out_offset = 0; + + NNACL_CHECK_ZERO_RETURN(in_shape[0]); + NNACL_CHECK_ZERO_RETURN(block_w); + for (int n = 0; n < out_dim0; ++n) { + int in_n = n % in_shape[0]; + int32_t stride_w = (n / in_shape[0]) % block_w; + int32_t stride_h = (n / in_shape[0]) / block_w; + size_t in_offset0 = in_n * in_strides[0]; + for (int h = 0; h < out_dim1; ++h) { + size_t in_offset1 = in_offset0 + (h * block_h + stride_h) * in_strides[1]; + for (int w = 0; w < out_dim2; ++w) { + size_t in_offset2 = in_offset1 + (w * block_w + stride_w) * in_strides[2]; + memcpy(output + out_offset, input + in_offset2, copy_size); + out_offset += copy_num; + } + } + } +} + +void DoSpaceToBatchPaddingNHWCInt8(const int8_t *input, int8_t *output, SpaceToBatchParameter *param, int32_t zp) { + int block_shape_h = param->block_sizes_[0]; + int block_shape_w = param->m_ == 2 ? param->block_sizes_[1] : 1; + int in_b = param->input_shape_[0]; + int in_h = param->input_shape_[1]; + int in_w = param->input_shape_[2]; + int channel = param->input_shape_[3]; + int out_h = param->output_shape_[1]; + int out_w = param->output_shape_[2]; + int pad_t = param->paddings_[0]; + int pad_l = param->m_ == 2 ? param->paddings_[2] : 0; + + NNACL_CHECK_ZERO_RETURN(in_b); + NNACL_CHECK_ZERO_RETURN(block_shape_w); + for (int i = 0; i < param->output_shape_[0]; ++i) { + int in_batch = i % in_b; + int offset_w = (i / in_b) % block_shape_w; + int offset_h = (i / in_b) / block_shape_w; + int in_b_offset = in_batch * in_h * in_w * channel; + int out_b_offset = i * out_h * out_w * channel; + for (int j = 0; j < out_h; ++j) { + int out_h_offset = out_b_offset + j * out_w * channel; + for (int k = 0; k < out_w; ++k) { + int8_t *out_ptr = output + out_h_offset + k * channel; + int index_h = j * block_shape_h + offset_h; + int index_w = k * block_shape_w + offset_w; + if (index_h < pad_t || index_h >= (pad_t + in_h) || index_w < pad_l || index_w >= (pad_l + in_w)) { + memset(out_ptr, zp, channel * sizeof(int8_t)); + } else { + int in_plane_offset = in_b_offset + ((index_h - pad_t) * in_w + (index_w - pad_l)) * channel; + const int8_t *in_ptr = input + in_plane_offset; + memcpy(out_ptr, in_ptr, channel * sizeof(int8_t)); + } + } + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/space_to_batch_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/space_to_batch_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..f2185017cb76f7e84ae46965cecd8f83c368ca7a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/space_to_batch_int8.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_INT8_SPACE_TO_BATCH_INT8_H_ +#define NNACL_INT8_SPACE_TO_BATCH_INT8_H_ + +#include "nnacl/op_base.h" +#include "nnacl/fp32/space_to_batch_fp32.h" + +#ifdef __cplusplus +extern "C" { +#endif +void DoSpaceToBatchNHWCInt8(const int8_t *input, int8_t *output, const int32_t *block_sizes, const int32_t *in_shape, + const int32_t *out_shape); +void DoSpaceToBatchPaddingNHWCInt8(const int8_t *input, int8_t *output, SpaceToBatchParameter *param, int32_t zp); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_SPACE_TO_BATCH_INT8_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/split_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/split_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..d1e3a1ed0dcaed56a30434df7c884a39f6f3397a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/split_int8.c @@ -0,0 +1,75 @@ +/** + * Copyright 2019-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/split_int8.h" +#include +#include +#include +#include "nnacl/split_parameter.h" +#include "nnacl/errorcode.h" + +int Int8DoSplit(const int8_t *in_data, int8_t **out_data, const int32_t *input_shape, int offset, int num_unit, + const SplitParameter *param) { + if (in_data == NULL || out_data == NULL) { + return NNACL_ERR; + } + const int num_split = param->num_split_; + const int32_t *split_sizes = param->split_sizes_; + const int32_t *strides = param->strides_; + const int split_dim = param->split_dim_; + int in_stride = strides[split_dim]; + + int stride_per_split = in_stride * input_shape[split_dim]; + int split_which = offset % num_split; + int split_times = offset / num_split; + const int8_t *src = in_data + split_times * stride_per_split; + for (int i = 0; i < split_which; i++) { + src += split_sizes[i] * in_stride; + } + + const QuantArg in_quant_arg = param->quant_arg_.in_args_; + float in_scale = in_quant_arg.scale_; + int32_t in_zp = in_quant_arg.zp_; + const QuantArg *out_quant_arg = param->quant_arg_.out_args_; + + for (int i = offset; i < offset + num_unit; i++) { + split_which = i % num_split; + split_times = i / num_split; + int copy_size = split_sizes[split_which] * in_stride; + int8_t *dst = out_data[split_which] + split_times * copy_size; + float out_scale = out_quant_arg[split_which].scale_; + int32_t out_zp = out_quant_arg[split_which].zp_; + if (fabs(in_scale - out_scale) <= FLT_EPSILON && in_zp == out_zp) { + (void)memcpy(dst, src, copy_size * sizeof(int8_t)); + } else { + float scale = in_scale / out_scale; + float bias = -in_zp * scale; + for (int j = 0; j < copy_size; j++) { + int32_t output_tmp = round(src[j] * scale + bias) + out_zp; + if (output_tmp > param->quant_arg_.output_activation_max_) { + dst[j] = param->quant_arg_.output_activation_max_; + } else if (output_tmp < param->quant_arg_.output_activation_min_) { + dst[j] = param->quant_arg_.output_activation_min_; + } else { + dst[j] = (int8_t)output_tmp; + } + } + } + src += copy_size; + } + + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/split_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/split_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..06203b55cf48239d4ccbb4696a2df4a0d5dc8409 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/split_int8.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_SPLIT_INT8_H_ +#define NNACL_INT8_SPLIT_INT8_H_ + +#include +#include "nnacl/op_base.h" +#include "nnacl/split_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int Int8DoSplit(const int8_t *in_data, int8_t **out_data, const int32_t *input_shape, int offset, int num_unit, + const SplitParameter *split_param); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_SPLIT_INT8_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/squeeze_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/squeeze_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..6af4c843fe99791d94bb0ff3121fd35d068d70e5 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/squeeze_int8.c @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/squeeze_int8.h" + +void SqueezeInt8(const int8_t *input_ptr, int8_t *output_ptr, const SqueezeQuantArg *quant_Squeeze_parm, int num, + int task_id, int thread_count) { + float output_scale = quant_Squeeze_parm->out_quant_args_->scale_; + const float output_inverse_scale = 1.f / output_scale; + QuantArg *input_quant = quant_Squeeze_parm->in_quant_args_; + int output_zp = quant_Squeeze_parm->out_quant_args_->zp_; + + const int i = 0; + for (int j = task_id; j < num; j += thread_count) { + float scale = input_quant[i].scale_ * output_inverse_scale; + float bias = -input_quant[i].zp_ * scale; + int32_t output_tmp = round(input_ptr[j] * scale + bias) + output_zp; + if (output_tmp > 127) { + output_ptr[j] = 127; + } else if (output_tmp < -128) { + output_ptr[j] = -128; + } else { + output_ptr[j] = (int8_t)output_tmp; + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/squeeze_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/squeeze_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..60abe7164dcf5539119c5b469c101633fd445a91 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/squeeze_int8.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_SQUEEZE_INT8_H_ +#define NNACL_INT8_SQUEEZE_INT8_H_ + +#include "nnacl/squeeze_parameter.h" +#include "nnacl/int8/quantize.h" + +#ifdef __cplusplus +extern "C" { +#endif +void SqueezeInt8(const int8_t *input_ptr, int8_t *output_ptr, const SqueezeQuantArg *quant_Squeeze_parm, int num, + int task_id, int thread_count); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_SQUEEZE_INT8_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/sub_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/sub_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..bb05bde0f7a668b328b085b5de0a8120547140aa --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/sub_int8.c @@ -0,0 +1,105 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/sub_int8.h" +#ifdef ENABLE_NEON +#include +#include "nnacl/int8/common_func_int8.h" +#endif +#include "nnacl/int8/fixed_point.h" + +#ifdef ENABLE_NEON + +int16x4_t DoClacSumHalfWord(int32x4_t scaled_input0, int32x4_t scaled_input1, int32x4_t left_shift_out_vec, + int32x4_t output_multiplier_vec, const SubQuantArg *para) { + int32x4_t raw_data = vsubq_s32(scaled_input0, scaled_input1); + + raw_data = RoundingDivideByPOTInt32x4(vqrdmulhq_s32(vmulq_s32(raw_data, left_shift_out_vec), output_multiplier_vec), + para->right_shift_out_); + raw_data = vaddq_s32(raw_data, vdupq_n_s32(para->out_args_.zp_)); + raw_data = vmaxq_s32(raw_data, vdupq_n_s32(para->output_activation_min_)); + raw_data = vminq_s32(raw_data, vdupq_n_s32(para->output_activation_max_)); + return vqmovn_s32(raw_data); +} + +void SubInt8NEON(const int8_t *input0_data, const int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, + const SubQuantArg *para, int32_t *index) { + int32x4_t left_shift_result0_vec = vdupq_n_s32(para->left_shift_result0_); + int32x4_t left_shift_result1_vec = vdupq_n_s32(para->left_shift_result1_); + int32x4_t input0_multiplier_vec = vdupq_n_s32(para->input0_multiplier_); + int32x4_t input1_multiplier_vec = vdupq_n_s32(para->input1_multiplier_); + int32x4_t output_multiplier_vec = vdupq_n_s32(para->output_multiplier_); + int32x4_t left_shift_out_vec = vdupq_n_s32((1 << (size_t)para->left_shift_out_)); + int32x4_t right_shift0_vec = vdupq_n_s32(-para->right_shift0_); + int32x4_t right_shift1_vec = vdupq_n_s32(-para->right_shift1_); + + for (; (*index) <= real_dst_count - 8; (*index) += 8) { + int16x8_t input0_val = LoadAndAddOffset(input0_data, *index, para->in0_args_.zp_); + int16x8_t input1_val = LoadAndAddOffset(input1_data, *index, para->in1_args_.zp_); + + int32x4_t input0_low = vmovl_s16(vget_low_s16(input0_val)); + int32x4_t input0_high = vmovl_s16(vget_high_s16(input0_val)); + int32x4_t input1_low = vmovl_s16(vget_low_s16(input1_val)); + int32x4_t input1_high = vmovl_s16(vget_high_s16(input1_val)); + + int32x4_t scaled_input0_low = + ClacScaledInput(input0_low, left_shift_result0_vec, input0_multiplier_vec, right_shift0_vec); + int32x4_t scaled_input0_high = + ClacScaledInput(input0_high, left_shift_result0_vec, input0_multiplier_vec, right_shift0_vec); + int32x4_t scaled_input1_low = + ClacScaledInput(input1_low, left_shift_result1_vec, input1_multiplier_vec, right_shift1_vec); + int32x4_t scaled_input1_high = + ClacScaledInput(input1_high, left_shift_result1_vec, input1_multiplier_vec, right_shift1_vec); + + int16x4_t sum_low = + DoClacSumHalfWord(scaled_input0_low, scaled_input1_low, left_shift_out_vec, output_multiplier_vec, para); + int16x4_t sum_high = + DoClacSumHalfWord(scaled_input0_high, scaled_input1_high, left_shift_out_vec, output_multiplier_vec, para); + + int16x8_t res_s16 = vcombine_s16(sum_low, sum_high); + int8x8_t res_u8_n0 = vqmovn_s16(res_s16); + vst1_s8(output_data + *index, res_u8_n0); + } +} +#endif + +int SubInt8(const int8_t *input0_data, const int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, + const SubQuantArg *para) { + int index = 0; +#ifdef ENABLE_NEON + SubInt8NEON(input0_data, input1_data, output_data, real_dst_count, para, &index); +#endif + for (; index < real_dst_count; ++index) { + const int32_t input0_val = para->in0_args_.zp_ + input0_data[index]; + const int32_t input1_val = para->in1_args_.zp_ + input1_data[index]; + const int32_t shifted_input0_val = input0_val * para->left_shift_result0_; + const int32_t shifted_input1_val = input1_val * para->left_shift_result1_; + const int32_t scaled_input0_val = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(shifted_input0_val, para->input0_multiplier_), para->right_shift0_); + const int32_t scaled_input1_val = RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(shifted_input1_val, para->input1_multiplier_), para->right_shift1_); + + const int32_t raw_data = scaled_input0_val - scaled_input1_val; + const int32_t raw_output = + RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(raw_data * (1 << (unsigned int)para->left_shift_out_), + para->output_multiplier_), + para->right_shift_out_) + + para->out_args_.zp_; + + output_data[index] = (int8_t)MSMAX(para->output_activation_min_, MSMIN(raw_output, para->output_activation_max_)); + } + return 0; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/sub_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/sub_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..637dce2fc4bafaa43acbfa1078842468c38ca980 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/sub_int8.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_SUB_INT8_H_ +#define NNACL_INT8_SUB_INT8_H_ + +#include "nnacl/op_base.h" +#include "nnacl/int8/quantize.h" + +#ifdef __cplusplus +extern "C" { +#endif +int SubInt8(const int8_t *input0_data, const int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, + const SubQuantArg *para); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_SUB_INT8_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/tanh_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/tanh_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..585d0a4db08ab3ac7f56f5fb8776a764cc447807 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/tanh_int8.c @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/tanh_int8.h" +#ifdef ENABLE_NEON +#include +#endif + +void TanhInt8(const int8_t *input_ptr, int8_t *output_ptr, int size, const TanhQuantParameter *quant) { + for (int i = 0; i < size; ++i) { + float fp32_src = (input_ptr[i] - quant->in_zp_) * quant->in_scale_; + float fp32_dst = TanhOpt(fp32_src); + int32_t int32_dst = (int32_t)round(fp32_dst * 1.0 / quant->out_scale_ + quant->out_zp_); + output_ptr[i] = (int8_t)MSMAX(MSMIN(int32_dst, 127), -128); + } + return; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/tanh_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/tanh_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..0af3145f3e20d2c0a6b90602951f05a0bf522520 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/tanh_int8.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_TANH_INT8_H_ +#define NNACL_INT8_TANH_INT8_H_ + +#include "nnacl/op_base.h" +#include "nnacl/int8/quantize.h" +#include "nnacl/int8/fixed_point.h" +#include "nnacl/int8/quant_dtype_cast_int8.h" +#include "nnacl/fp32/activation_fp32.h" + +typedef struct TanhQuantParameter { + int32_t in_zp_; + int32_t out_zp_; + double in_scale_; + double out_scale_; +} TanhQuantParameter; + +#ifdef __cplusplus +extern "C" { +#endif + +void TanhInt8(const int8_t *input_ptr, int8_t *output_ptr, int size, const TanhQuantParameter *quant); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_TANH_INT8_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/topk_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/topk_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..9ebcc121ca502cf0e6d1c5d949ad4e1830914190 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/topk_int8.c @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/topk_int8.h" + +int DescendCmpInt8(const void *a, const void *b) { + return ((const TopkNodeInt8 *)b)->element - ((const TopkNodeInt8 *)a)->element; +} + +int AscendCmpInt8(const void *a, const void *b) { + return ((const TopkNodeInt8 *)a)->element - ((const TopkNodeInt8 *)b)->element; +} + +void TopkInt8(int8_t *input_data, int8_t *output_data, int32_t *output_index, TopkParameter *parameter) { + int dim_size = parameter->dim_size_; + int outer_loop_num = parameter->outer_loop_num_; + int inner_loop_num = parameter->inner_loop_num_; + int k = parameter->k_; + TopkNode *top_map = (TopkNode *)parameter->topk_node_list_; + + int8_t *cur_input_data = (int8_t *)input_data; + int8_t *cur_output_data = (int8_t *)output_data; + int32_t *cur_output_index = output_index; + for (int i = 0; i < outer_loop_num; i++) { + int in_offset = i * dim_size * inner_loop_num; + int out_offset = i * k * inner_loop_num; + for (int j = 0; j < inner_loop_num; j++) { + for (int m = 0; m < dim_size; m++) { + int offset = in_offset + m * inner_loop_num + j; + top_map[m].element = *(cur_input_data + offset); + top_map[m].index = m; + } + qsort(top_map, dim_size, sizeof(top_map[0]), DescendCmpInt8); + if (!parameter->sorted_) { + qsort(top_map, k, sizeof(top_map[0]), AscendCmpInt8); + } + for (int m = 0; m < k; m++) { + int offset = out_offset + m * inner_loop_num + j; + cur_output_data[offset] = top_map[m].element; + cur_output_index[offset] = top_map[m].index; + } + } + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/topk_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/topk_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..e6a47530176e4df892272a286441b38a4763d174 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/topk_int8.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_TOPK_INT8_H_ +#define NNACL_INT8_TOPK_INT8_H_ + +#include "nnacl/op_base.h" +#include "nnacl/fp32/topk_fp32.h" + +typedef struct TopkNodeInt8 { + int8_t element; + int32_t index; +} TopkNodeInt8; + +#ifdef __cplusplus +extern "C" { +#endif +void TopkInt8(int8_t *input_data, int8_t *output_data, int32_t *output_index, TopkParameter *parameter); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_TOPK_INT8_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/transpose_int8.c b/mindspore-lite/ops/kernel/cpu/nnacl/int8/transpose_int8.c new file mode 100644 index 0000000000000000000000000000000000000000..a6bf6e3a95acb85701ae7443f4f51a956dfd10cb --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/transpose_int8.c @@ -0,0 +1,257 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/transpose_int8.h" +void TransposeDim2Int8(const int8_t *in_data, int8_t *out_data, const int32_t *strides, const int32_t *out_strides, + const int32_t *perm, const int32_t *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * output1; + int stride0_i = i * 1 * stride0; + for (int j = 0; j < output1; ++j) { + out_data[out_stride0_i + j] = in_data[stride0_i + j * stride1]; + } + } + return; +} + +void TransposeDim3Int8(const int8_t *in_data, int8_t *out_data, const int32_t *strides, const int32_t *out_strides, + const int32_t *perm, const int32_t *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; ++j) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + for (int k = 0; k < output2; ++k) { + out_data[out_stride0_i + out_stride1_j + k] = in_data[stride0_i + stride1_j + k * stride2]; + } + } + } +} + +void TransposeDim4Int8(const int8_t *in_data, int8_t *out_data, const int32_t *strides, const int32_t *out_strides, + const int32_t *perm, const int32_t *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int stride3 = strides[perm[3]]; + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int out_stride2 = out_strides[2]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + const int output3 = output_shape[3]; + + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; ++j) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + for (int k = 0; k < output2; ++k) { + int out_stride2_k = k * out_stride2; + int stride2_k = k * stride2; + for (int m = 0; m < output3; ++m) { + out_data[out_stride0_i + out_stride1_j + out_stride2_k + m] = + in_data[stride0_i + stride1_j + stride2_k + m * stride3]; + } + } + } + } +} + +void TransposeDim5Int8(const int8_t *in_data, int8_t *out_data, const int32_t *strides, const int32_t *out_strides, + const int32_t *perm, const int32_t *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int stride3 = strides[perm[3]]; + const int stride4 = strides[perm[4]]; + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int out_stride2 = out_strides[2]; + const int out_stride3 = out_strides[3]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + const int output3 = output_shape[3]; + const int output4 = output_shape[4]; + + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; ++j) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + for (int k = 0; k < output2; ++k) { + int out_stride2_k = k * out_stride2; + int stride2_k = k * stride2; + for (int m = 0; m < output3; ++m) { + int out_stride3_m = m * out_stride3; + int stride3_m = m * stride3; + for (int n = 0; n < output4; ++n) { + out_data[out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + n] = + in_data[stride0_i + stride1_j + stride2_k + stride3_m + n * stride4]; + } + } + } + } + } +} + +void TransposeDim6Int8(const int8_t *in_data, int8_t *out_data, const int32_t *strides, const int32_t *out_strides, + const int32_t *perm, const int32_t *output_shape) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int stride3 = strides[perm[3]]; + const int stride4 = strides[perm[4]]; + const int stride5 = strides[perm[5]]; + + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int out_stride2 = out_strides[2]; + const int out_stride3 = out_strides[3]; + const int out_stride4 = out_strides[4]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + const int output3 = output_shape[3]; + const int output4 = output_shape[4]; + const int output5 = output_shape[5]; + + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; ++j) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + for (int k = 0; k < output2; ++k) { + int out_stride2_k = k * out_stride2; + int stride2_k = k * stride2; + for (int m = 0; m < output3; ++m) { + int out_stride3_m = m * out_stride3; + int stride3_m = m * stride3; + for (int n = 0; n < output4; ++n) { + int out_stride4_n = n * out_stride4; + int stride4_n = n * stride4; + for (int p = 0; p < output5; ++p) { + out_data[out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + out_stride4_n + p] = + in_data[stride0_i + stride1_j + stride2_k + stride3_m + stride4_n + p * stride5]; + } + } + } + } + } + } +} + +int DoTransposeInt8(const int8_t *in_data, int8_t *out_data, const int32_t *output_shape, + const TransposeParameter *transpose_param) { + NNACL_CHECK_NULL_RETURN_ERR(in_data); + NNACL_CHECK_NULL_RETURN_ERR(out_data); + NNACL_CHECK_NULL_RETURN_ERR(output_shape); + NNACL_CHECK_NULL_RETURN_ERR(transpose_param); + + const int32_t *perm = transpose_param->perm_; + const int32_t *strides = transpose_param->strides_; + const int32_t *out_strides = transpose_param->out_strides_; + const int num_axes = transpose_param->num_axes_; + + // check if transpose is needed + bool needTranspose = false; + for (int i = 1; i < num_axes; i++) { + if (perm[i] - perm[i - 1] != 1) { + needTranspose = true; + break; + } + } + + if (!needTranspose) { + (void)memcpy(out_data, in_data, transpose_param->data_num_ * sizeof(int8_t)); + return NNACL_OK; + } + + switch (num_axes) { + case 2: + TransposeDim2Int8(in_data, out_data, strides, out_strides, perm, output_shape); + break; + case 3: + TransposeDim3Int8(in_data, out_data, strides, out_strides, perm, output_shape); + break; + case 4: + TransposeDim4Int8(in_data, out_data, strides, out_strides, perm, output_shape); + break; + case 5: + TransposeDim5Int8(in_data, out_data, strides, out_strides, perm, output_shape); + break; + case 6: + TransposeDim6Int8(in_data, out_data, strides, out_strides, perm, output_shape); + break; + default: + return NNACL_ERR; + } + + return NNACL_OK; +} + +void TransposeDimsInt8(const int8_t *in_data, int8_t *out_data, const int32_t *output_shape, + const TransposeParameter *transpose_param, int task_id, int thread_num) { + NNACL_CHECK_NULL_RETURN_VOID(in_data); + NNACL_CHECK_NULL_RETURN_VOID(out_data); + NNACL_CHECK_NULL_RETURN_VOID(output_shape); + NNACL_CHECK_NULL_RETURN_VOID(transpose_param); + NNACL_CHECK_ZERO_RETURN(thread_num); + const int32_t *perm = transpose_param->perm_; + const int32_t *strides = transpose_param->strides_; + const int32_t *out_strides = transpose_param->out_strides_; + int num_axes = transpose_param->num_axes_; + size_t data_size = (size_t)((*out_strides) * output_shape[0]); + size_t offset_size = UP_DIV(data_size, thread_num); + size_t task_offset = offset_size * task_id; + size_t count = data_size - task_offset; + if (data_size < task_offset) { + return; + } + count = MSMIN(offset_size, count); + for (size_t idx = task_offset; idx < task_offset + count; ++idx) { + int pos = (int)idx; + int output_idx = 0; + int input_idx = 0; + for (int i = 0; i < num_axes; ++i) { + NNACL_CHECK_ZERO_RETURN(*(out_strides + i)); + int position = pos / *(out_strides + i); + int out_stride = i < num_axes - 1 ? out_strides[i] : 1; + output_idx += (position * out_stride); + input_idx += (position * strides[perm[i]]); + pos -= position * (*(out_strides + i)); + } + out_data[output_idx] = in_data[input_idx]; + } +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/transpose_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/transpose_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..f4f0ffc589850fb2e21a9b1f4ae3f9fede37916d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/transpose_int8.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_TRANSPOSE_INT8_H_ +#define NNACL_INT8_TRANSPOSE_INT8_H_ + +#include +#include "nnacl/transpose_parameter.h" +#include "nnacl/errorcode.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int DoTransposeInt8(const int8_t *in_data, int8_t *out_data, const int32_t *output_shape, + const TransposeParameter *transpose_param); +void TransposeDimsInt8(const int8_t *in_data, int8_t *out_data, const int32_t *output_shape, + const TransposeParameter *transpose_param, int task_id, int thread_num); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_TRANSPOSE_INT8_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_plugin_impl.cc b/mindspore-lite/ops/kernel/cpu/nnacl/int8/unsqueeze_int8.c similarity index 43% rename from mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_plugin_impl.cc rename to mindspore-lite/ops/kernel/cpu/nnacl/int8/unsqueeze_int8.c index 9d907b5b1a3d3e200c4b7c37888e48d1a27e2d72..3ef52e2db04b8a33649d5f18c79032553cb269d6 100644 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_plugin_impl.cc +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/unsqueeze_int8.c @@ -1,5 +1,5 @@ /** - * Copyright 2019-2021 Huawei Technologies Co., Ltd + * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,30 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "extendrt/delegate/tensorrt/tensorrt_plugin_impl.h" -#ifdef LITE_CUDA_DISTRIBUTION -#include "extendrt/delegate/tensorrt/distribution/distribution_base.h" -#include "plugin/device/gpu/hal/device/distribution/collective_wrapper.h" -#endif -namespace mindspore::lite { -int TensorRTPluginImpl::GetGPUGroupSize() const { -#ifdef LITE_CUDA_DISTRIBUTION - return GetGroupSize(NCCL_WORLD_GROUP); -#else - return 1; -#endif -} +#include "nnacl/int8/unsqueeze_int8.h" +#include "nnacl/unsqueeze_parameter.h" +#include "nnacl/errorcode.h" -int TensorRTPluginImpl::GetRankID() const { -#ifdef LITE_CUDA_DISTRIBUTION - return GetRankIDByGroup(NCCL_WORLD_GROUP); -#else - return 0; -#endif -} -} // namespace mindspore::lite +int Int8Unsqueeze(const int8_t *input_ptr, int8_t *output_ptr, const UnSqueezeParameter *para_, size_t data_size, + int task_id) { + float output_scale = para_->quant_arg.out_quant_args_.scale_; + NNACL_CHECK_ZERO_RETURN_ERR(output_scale); + int8_t output_zp = para_->quant_arg.out_quant_args_.zp_; + float input_scale = para_->quant_arg.in_quant_args_.scale_; + int8_t input_zp = para_->quant_arg.in_quant_args_.zp_; -mindspore::lite::TensorRTExecutorPluginImplBase *CreateTensorRTPluginImpl() { - return new mindspore::lite::TensorRTPluginImpl(); + for (int i = task_id; i < (int)data_size; i += para_->thread_count_) { + output_ptr[i] = output_zp + round(1 / output_scale * input_scale * (input_ptr[i] - input_zp)); + } + return 0; } diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/int8/unsqueeze_int8.h b/mindspore-lite/ops/kernel/cpu/nnacl/int8/unsqueeze_int8.h new file mode 100644 index 0000000000000000000000000000000000000000..11657de0dd32716272307a9be5f903751a2bfdc3 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/int8/unsqueeze_int8.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INT8_UNSQUEEZE_INT8_H_ +#define NNACL_INT8_UNSQUEEZE_INT8_H_ + +#include +#include "nnacl/op_base.h" +#include "nnacl/unsqueeze_parameter.h" + +#ifdef __cplusplus +extern "C" { +#endif +int Int8Unsqueeze(const int8_t *input_ptr, int8_t *output_ptr, const UnSqueezeParameter *para_, size_t data_size, + int task_id); +#ifdef __cplusplus +} +#endif + +#endif // NNACL_INT8_UNSQUEEZE_INT8_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/avx/DeconvMatMulAvx.c b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/avx/DeconvMatMulAvx.c new file mode 100644 index 0000000000000000000000000000000000000000..a0ca009206851f5b737048a887312704dd5b8c1b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/avx/DeconvMatMulAvx.c @@ -0,0 +1,188 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_AVX +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/fp32/matmul_fp32.h" +#include "nnacl/op_base.h" + +void Deconv4X8AvxKernel(const float *src, const float *weight, float *dst, int col, int row, int depth, int stride) { + __m256 res1 = _mm256_setzero_ps(); + __m256 res4 = _mm256_setzero_ps(); + __m256 res7 = _mm256_setzero_ps(); + __m256 res10 = _mm256_setzero_ps(); + + for (int d = 0; d < depth; ++d) { + __m256 w0 = _mm256_loadu_ps(weight); + __m256 tmp = _mm256_set1_ps(*src); + __m256 tmp1 = _mm256_set1_ps(*(src + C1NUM)); + weight += C8NUM; + __m256 tmp2 = _mm256_set1_ps(*(src + C2NUM)); + __m256 tmp3 = _mm256_set1_ps(*(src + C3NUM)); + res1 = _mm256_fmadd_ps(tmp, w0, res1); + res4 = _mm256_fmadd_ps(tmp1, w0, res4); + src += C4NUM; + res7 = _mm256_fmadd_ps(tmp2, w0, res7); + res10 = _mm256_fmadd_ps(tmp3, w0, res10); + } + // write + _mm256_storeu_ps(dst, res1); + _mm256_storeu_ps(dst + C8NUM, res4); + _mm256_storeu_ps(dst + C16NUM, res7); + _mm256_storeu_ps(dst + C24NUM, res10); +} + +void Deconv4X16AvxKernel(const float *src, const float *weight, float *dst, int col, int row, int depth, int stride) { + __m256 res1 = _mm256_setzero_ps(); + __m256 res2 = _mm256_setzero_ps(); + __m256 res4 = _mm256_setzero_ps(); + __m256 res5 = _mm256_setzero_ps(); + __m256 res7 = _mm256_setzero_ps(); + __m256 res8 = _mm256_setzero_ps(); + __m256 res10 = _mm256_setzero_ps(); + __m256 res11 = _mm256_setzero_ps(); + + for (int d = 0; d < depth; ++d) { + __m256 w0 = _mm256_loadu_ps(weight); + __m256 w1 = _mm256_loadu_ps(weight + C8NUM); + weight += C16NUM; + __m256 tmp = _mm256_set1_ps(*src); + __m256 tmp1 = _mm256_set1_ps(*(src + C1NUM)); + __m256 tmp2 = _mm256_set1_ps(*(src + C2NUM)); + __m256 tmp3 = _mm256_set1_ps(*(src + C3NUM)); + res1 = _mm256_fmadd_ps(tmp, w0, res1); + res2 = _mm256_fmadd_ps(tmp, w1, res2); + src += C4NUM; + res4 = _mm256_fmadd_ps(tmp1, w0, res4); + res5 = _mm256_fmadd_ps(tmp1, w1, res5); + res7 = _mm256_fmadd_ps(tmp2, w0, res7); + res8 = _mm256_fmadd_ps(tmp2, w1, res8); + res10 = _mm256_fmadd_ps(tmp3, w0, res10); + res11 = _mm256_fmadd_ps(tmp3, w1, res11); + } + // write + _mm256_storeu_ps(dst, res1); + _mm256_storeu_ps(dst + C8NUM, res4); + _mm256_storeu_ps(dst + C16NUM, res7); + _mm256_storeu_ps(dst + C24NUM, res10); + + _mm256_storeu_ps(dst + stride, res2); + _mm256_storeu_ps(dst + stride + C8NUM, res5); + _mm256_storeu_ps(dst + stride + C16NUM, res8); + _mm256_storeu_ps(dst + stride + C24NUM, res11); +} + +void Deconv4X24AvxKernel(const float *src, const float *weight, float *dst, int col, int row, int depth, int stride) { + __m256 res1 = _mm256_setzero_ps(); + __m256 res2 = _mm256_setzero_ps(); + __m256 res3 = _mm256_setzero_ps(); + __m256 res4 = _mm256_setzero_ps(); + __m256 res5 = _mm256_setzero_ps(); + __m256 res6 = _mm256_setzero_ps(); + __m256 res7 = _mm256_setzero_ps(); + __m256 res8 = _mm256_setzero_ps(); + __m256 res9 = _mm256_setzero_ps(); + __m256 res10 = _mm256_setzero_ps(); + __m256 res11 = _mm256_setzero_ps(); + __m256 res12 = _mm256_setzero_ps(); + + for (int d = 0; d < depth; ++d) { + __m256 w0 = _mm256_loadu_ps(weight); + __m256 w1 = _mm256_loadu_ps(weight + C8NUM); + __m256 w2 = _mm256_loadu_ps(weight + C16NUM); + __m256 tmp = _mm256_set1_ps(*src); + res1 = _mm256_fmadd_ps(tmp, w0, res1); + res2 = _mm256_fmadd_ps(tmp, w1, res2); + res3 = _mm256_fmadd_ps(tmp, w2, res3); + tmp = _mm256_set1_ps(*(src + C1NUM)); + res4 = _mm256_fmadd_ps(tmp, w0, res4); + res5 = _mm256_fmadd_ps(tmp, w1, res5); + res6 = _mm256_fmadd_ps(tmp, w2, res6); + tmp = _mm256_set1_ps(*(src + C2NUM)); + res7 = _mm256_fmadd_ps(tmp, w0, res7); + res8 = _mm256_fmadd_ps(tmp, w1, res8); + res9 = _mm256_fmadd_ps(tmp, w2, res9); + tmp = _mm256_set1_ps(*(src + C3NUM)); + res10 = _mm256_fmadd_ps(tmp, w0, res10); + res11 = _mm256_fmadd_ps(tmp, w1, res11); + res12 = _mm256_fmadd_ps(tmp, w2, res12); + weight += C24NUM; + src += C4NUM; + } + // write + _mm256_storeu_ps(dst, res1); + _mm256_storeu_ps(dst + C8NUM, res4); + _mm256_storeu_ps(dst + C16NUM, res7); + _mm256_storeu_ps(dst + C24NUM, res10); + + _mm256_storeu_ps(dst + stride, res2); + _mm256_storeu_ps(dst + stride + C8NUM, res5); + _mm256_storeu_ps(dst + stride + C16NUM, res8); + _mm256_storeu_ps(dst + stride + C24NUM, res11); + + _mm256_storeu_ps(dst + C2NUM * stride, res3); + _mm256_storeu_ps(dst + C2NUM * stride + C8NUM, res6); + _mm256_storeu_ps(dst + C2NUM * stride + C16NUM, res9); + _mm256_storeu_ps(dst + C2NUM * stride + C24NUM, res12); +} + +void DeconvMatmulAvx(const float *a, const float *b, float *c, int depth, int row, int col, const int plane) { + NNACL_CHECK_ZERO_RETURN(plane); + int col_num = 0; + int col_block = UP_DIV(col / plane, C8NUM); + DeconvAvxKernel kernel[3] = {Deconv4X8AvxKernel, Deconv4X16AvxKernel, Deconv4X24AvxKernel}; + for (int col_tmp = 0; col_tmp < col_block; col_tmp += col_num) { + col_num = MSMIN(C3NUM, col_block - col_tmp); + for (int p = 0; p < plane; ++p) { + for (int r = 0; r < row; r += C4NUM) { + kernel[col_num - 1](a + r * depth, b + (col_tmp * plane + p * col_num) * C8NUM * depth, + c + (col_tmp * plane + p) * C8NUM * row + r * C8NUM, col_num, C4NUM, depth, + row * C8NUM * plane); + } + } + } +} + +#ifdef ENABLE_DEBUG +void DeconvColXRowAvxKernel(const float *src, const float *weight, float *dst, int col, int row, int depth, + int stride) { + __m256 res[C12NUM]; + __m256 w[C3NUM]; + for (int i = 0; i < C12NUM; ++i) { + res[i] = _mm256_setzero_ps(); + } + for (int d = 0; d < depth; ++d) { + for (int c = 0; c < col; ++c) { + w[c] = _mm256_loadu_ps(weight + c * C8NUM); + } + weight += col * C8NUM; + for (int r = 0; r < row; ++r) { // C4NUm + __m256 tmp = _mm256_set1_ps(*src); + for (int c = 0; c < col; ++c) { // 3 * C8NUM + res[r * col + c] = _mm256_fmadd_ps(tmp, w[c], res[r * col + c]); + } + src += 1; + } + } + // write + for (int i = 0; i < col; ++i) { + for (int j = 0; j < row; ++j) { + _mm256_storeu_ps(dst + j * C8NUM, res[j * col + i]); + } + dst += stride; + } +} +#endif +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/avx/PostFuncBiasReluC8.c b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/avx/PostFuncBiasReluC8.c new file mode 100644 index 0000000000000000000000000000000000000000..6ce0eb6f05979e4d11d1049b6dd20cd67b1ef98e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/avx/PostFuncBiasReluC8.c @@ -0,0 +1,352 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_AVX +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/fp32/common_func_fp32.h" +#include "nnacl/intrinsics/avx/common_utils.h" + +void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod, + size_t plane_size, size_t stride, size_t relu_type) { + stride /= sizeof(float); + int loop_c8 = 0; + size_t src_stride = plane_size * C8NUM; + for (; loop_c8 <= (int)(oc8div)-C32NUM; loop_c8 += C32NUM) { + size_t plane_size_tmp = plane_size; + float *dst_c8 = dst + loop_c8; + __m256 bias1 = _mm256_setzero_ps(); + __m256 bias2 = _mm256_setzero_ps(); + __m256 bias3 = _mm256_setzero_ps(); + __m256 bias4 = _mm256_setzero_ps(); + if (bias != NULL) { + bias1 = _mm256_loadu_ps(bias); + bias2 = _mm256_loadu_ps(bias + C8NUM); + bias3 = _mm256_loadu_ps(bias + C16NUM); + bias4 = _mm256_loadu_ps(bias + C24NUM); + bias += C32NUM; + } + for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + C8NUM); + __m256 src5 = _mm256_loadu_ps(src + src_stride); + __m256 src6 = _mm256_loadu_ps(src + src_stride + C8NUM); + __m256 src9 = _mm256_loadu_ps(src + src_stride * C2NUM); + __m256 src10 = _mm256_loadu_ps(src + src_stride * C2NUM + C8NUM); + __m256 src13 = _mm256_loadu_ps(src + src_stride * C3NUM); + __m256 src14 = _mm256_loadu_ps(src + src_stride * C3NUM + C8NUM); + + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias1); + src5 = _mm256_add_ps(src5, bias2); + src6 = _mm256_add_ps(src6, bias2); + src9 = _mm256_add_ps(src9, bias3); + src10 = _mm256_add_ps(src10, bias3); + src13 = _mm256_add_ps(src13, bias4); + src14 = _mm256_add_ps(src14, bias4); + + ActBlock8Avx(&src1, &src2, &src5, &src6, &src9, &src10, &src13, &src14, relu_type); + + _mm256_storeu_ps(dst_c8, src1); + _mm256_storeu_ps(dst_c8 + C8NUM, src5); + _mm256_storeu_ps(dst_c8 + C16NUM, src9); + _mm256_storeu_ps(dst_c8 + C24NUM, src13); + dst_c8 += stride; + _mm256_storeu_ps(dst_c8, src2); + _mm256_storeu_ps(dst_c8 + C8NUM, src6); + _mm256_storeu_ps(dst_c8 + C16NUM, src10); + _mm256_storeu_ps(dst_c8 + C24NUM, src14); + dst_c8 += stride; + + __m256 src3 = _mm256_loadu_ps(src + C16NUM); + __m256 src4 = _mm256_loadu_ps(src + C24NUM); + __m256 src7 = _mm256_loadu_ps(src + src_stride + C16NUM); + __m256 src8 = _mm256_loadu_ps(src + src_stride + C24NUM); + __m256 src11 = _mm256_loadu_ps(src + src_stride * C2NUM + C16NUM); + __m256 src12 = _mm256_loadu_ps(src + src_stride * C2NUM + C24NUM); + __m256 src15 = _mm256_loadu_ps(src + src_stride * C3NUM + C16NUM); + __m256 src16 = _mm256_loadu_ps(src + src_stride * C3NUM + C24NUM); + src += C32NUM; + src3 = _mm256_add_ps(src3, bias1); + src4 = _mm256_add_ps(src4, bias1); + src7 = _mm256_add_ps(src7, bias2); + src8 = _mm256_add_ps(src8, bias2); + src11 = _mm256_add_ps(src11, bias3); + src12 = _mm256_add_ps(src12, bias3); + src15 = _mm256_add_ps(src15, bias4); + src16 = _mm256_add_ps(src16, bias4); + + ActBlock8Avx(&src3, &src4, &src7, &src8, &src11, &src12, &src15, &src16, relu_type); + + _mm256_storeu_ps(dst_c8, src3); + _mm256_storeu_ps(dst_c8 + C8NUM, src7); + _mm256_storeu_ps(dst_c8 + C16NUM, src11); + _mm256_storeu_ps(dst_c8 + C24NUM, src15); + dst_c8 += stride; + _mm256_storeu_ps(dst_c8, src4); + _mm256_storeu_ps(dst_c8 + C8NUM, src8); + _mm256_storeu_ps(dst_c8 + C16NUM, src12); + _mm256_storeu_ps(dst_c8 + C24NUM, src16); + dst_c8 += stride; + } + for (; plane_size_tmp > 0; plane_size_tmp -= 1) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + src_stride); + __m256 src3 = _mm256_loadu_ps(src + src_stride * C2NUM); + __m256 src4 = _mm256_loadu_ps(src + src_stride * C3NUM); + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias2); + src3 = _mm256_add_ps(src3, bias3); + src4 = _mm256_add_ps(src4, bias4); + + ActBlock4Avx(&src1, &src2, &src3, &src4, relu_type == 1, relu_type == C3NUM); + + _mm256_storeu_ps(dst_c8, src1); + _mm256_storeu_ps(dst_c8 + C8NUM, src2); + _mm256_storeu_ps(dst_c8 + C16NUM, src3); + _mm256_storeu_ps(dst_c8 + C24NUM, src4); + dst_c8 += stride; + src += C8NUM; + } + src += C3NUM * src_stride; + } + for (; loop_c8 <= (int)(oc8div)-C24NUM; loop_c8 += C24NUM) { + size_t plane_size_tmp = plane_size; + float *dst_c8 = dst + loop_c8; + __m256 bias1 = _mm256_setzero_ps(); + __m256 bias2 = _mm256_setzero_ps(); + __m256 bias3 = _mm256_setzero_ps(); + if (bias != NULL) { + bias1 = _mm256_loadu_ps(bias); + bias2 = _mm256_loadu_ps(bias + C8NUM); + bias3 = _mm256_loadu_ps(bias + C16NUM); + bias += C24NUM; + } + for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + C8NUM); + __m256 src3 = _mm256_loadu_ps(src + C16NUM); + __m256 src4 = _mm256_loadu_ps(src + C24NUM); + __m256 src5 = _mm256_loadu_ps(src + src_stride); + __m256 src6 = _mm256_loadu_ps(src + src_stride + C8NUM); + __m256 src7 = _mm256_loadu_ps(src + src_stride + C16NUM); + __m256 src8 = _mm256_loadu_ps(src + src_stride + C24NUM); + __m256 src9 = _mm256_loadu_ps(src + src_stride * C2NUM); + __m256 src10 = _mm256_loadu_ps(src + src_stride * C2NUM + C8NUM); + __m256 src11 = _mm256_loadu_ps(src + src_stride * C2NUM + C16NUM); + __m256 src12 = _mm256_loadu_ps(src + src_stride * C2NUM + C24NUM); + src += C32NUM; + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias1); + src3 = _mm256_add_ps(src3, bias1); + src4 = _mm256_add_ps(src4, bias1); + src5 = _mm256_add_ps(src5, bias2); + src6 = _mm256_add_ps(src6, bias2); + src7 = _mm256_add_ps(src7, bias2); + src8 = _mm256_add_ps(src8, bias2); + src9 = _mm256_add_ps(src9, bias3); + src10 = _mm256_add_ps(src10, bias3); + src11 = _mm256_add_ps(src11, bias3); + src12 = _mm256_add_ps(src12, bias3); + + ActBlock12Avx(&src1, &src2, &src3, &src4, &src5, &src6, &src7, &src8, &src9, &src10, &src11, &src12, + relu_type == 1, relu_type == C3NUM); + + _mm256_storeu_ps(dst_c8, src1); + _mm256_storeu_ps(dst_c8 + C8NUM, src5); + _mm256_storeu_ps(dst_c8 + C16NUM, src9); + dst_c8 += stride; + _mm256_storeu_ps(dst_c8, src2); + _mm256_storeu_ps(dst_c8 + C8NUM, src6); + _mm256_storeu_ps(dst_c8 + C16NUM, src10); + dst_c8 += stride; + _mm256_storeu_ps(dst_c8, src3); + _mm256_storeu_ps(dst_c8 + C8NUM, src7); + _mm256_storeu_ps(dst_c8 + C16NUM, src11); + dst_c8 += stride; + _mm256_storeu_ps(dst_c8, src4); + _mm256_storeu_ps(dst_c8 + C8NUM, src8); + _mm256_storeu_ps(dst_c8 + C16NUM, src12); + dst_c8 += stride; + } + for (; plane_size_tmp > 0; plane_size_tmp -= 1) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + src_stride); + __m256 src3 = _mm256_loadu_ps(src + src_stride * C2NUM); + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias2); + src3 = _mm256_add_ps(src3, bias3); + + ActBlock1Avx(&src1, relu_type == 1, relu_type == C3NUM); + ActBlock1Avx(&src2, relu_type == 1, relu_type == C3NUM); + ActBlock1Avx(&src3, relu_type == 1, relu_type == C3NUM); + + _mm256_storeu_ps(dst_c8, src1); + _mm256_storeu_ps(dst_c8 + C8NUM, src2); + _mm256_storeu_ps(dst_c8 + C16NUM, src3); + dst_c8 += stride; + src += C8NUM; + } + src += C2NUM * src_stride; + } + for (; loop_c8 <= (int)(oc8div)-C16NUM; loop_c8 += C16NUM) { + size_t plane_size_tmp = plane_size; + float *dst_c8 = dst + loop_c8; + __m256 bias1 = _mm256_setzero_ps(); + __m256 bias2 = _mm256_setzero_ps(); + if (bias != NULL) { + bias1 = _mm256_loadu_ps(bias); + bias2 = _mm256_loadu_ps(bias + C8NUM); + bias += C16NUM; + } + for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + C8NUM); + __m256 src3 = _mm256_loadu_ps(src + C16NUM); + __m256 src4 = _mm256_loadu_ps(src + C24NUM); + __m256 src5 = _mm256_loadu_ps(src + src_stride); + __m256 src6 = _mm256_loadu_ps(src + src_stride + C8NUM); + __m256 src7 = _mm256_loadu_ps(src + src_stride + C16NUM); + __m256 src8 = _mm256_loadu_ps(src + src_stride + C24NUM); + src += C32NUM; + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias1); + src3 = _mm256_add_ps(src3, bias1); + src4 = _mm256_add_ps(src4, bias1); + src5 = _mm256_add_ps(src5, bias2); + src6 = _mm256_add_ps(src6, bias2); + src7 = _mm256_add_ps(src7, bias2); + src8 = _mm256_add_ps(src8, bias2); + + ActBlock8Avx(&src1, &src2, &src3, &src4, &src5, &src6, &src7, &src8, relu_type); + + _mm256_storeu_ps(dst_c8, src1); + _mm256_storeu_ps(dst_c8 + C8NUM, src5); + dst_c8 += stride; + _mm256_storeu_ps(dst_c8, src2); + _mm256_storeu_ps(dst_c8 + C8NUM, src6); + dst_c8 += stride; + _mm256_storeu_ps(dst_c8, src3); + _mm256_storeu_ps(dst_c8 + C8NUM, src7); + dst_c8 += stride; + _mm256_storeu_ps(dst_c8, src4); + _mm256_storeu_ps(dst_c8 + C8NUM, src8); + dst_c8 += stride; + } + for (; plane_size_tmp > 0; plane_size_tmp -= 1) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + src_stride); + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias2); + + ActBlock2Avx(&src1, &src2, relu_type == 1, relu_type == C3NUM); + + _mm256_storeu_ps(dst_c8, src1); + _mm256_storeu_ps(dst_c8 + C8NUM, src2); + dst_c8 += stride; + src += C8NUM; + } + src += src_stride; + } + for (; loop_c8 < (int)(oc8div); loop_c8 += C8NUM) { + size_t plane_size_tmp = plane_size; + float *dst_c8 = dst + loop_c8; + __m256 bias1 = _mm256_setzero_ps(); + if (bias != NULL) { + bias1 = _mm256_loadu_ps(bias); + bias += C8NUM; + } + for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + C8NUM); + __m256 src3 = _mm256_loadu_ps(src + C16NUM); + __m256 src4 = _mm256_loadu_ps(src + C24NUM); + src += C32NUM; + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias1); + src3 = _mm256_add_ps(src3, bias1); + src4 = _mm256_add_ps(src4, bias1); + + ActBlock4Avx(&src1, &src2, &src3, &src4, relu_type == 1, relu_type == C3NUM); + + _mm256_storeu_ps(dst_c8, src1); + dst_c8 += stride; + _mm256_storeu_ps(dst_c8, src2); + dst_c8 += stride; + _mm256_storeu_ps(dst_c8, src3); + dst_c8 += stride; + _mm256_storeu_ps(dst_c8, src4); + dst_c8 += stride; + } + for (; plane_size_tmp > 0; plane_size_tmp -= 1) { + __m256 src1 = _mm256_loadu_ps(src); + src1 = _mm256_add_ps(src1, bias1); + + ActBlock1Avx(&src1, relu_type == 1, relu_type == C3NUM); + + _mm256_storeu_ps(dst_c8, src1); + dst_c8 += stride; + src += C8NUM; + } + } + if (oc8mod == 0) { + return; + } + __m256 bias1 = _mm256_setzero_ps(); + if (bias != NULL) { + bias1 = _mm256_loadu_ps(bias); + bias += C8NUM; + } + float *dst_c1 = dst + oc8div; + for (size_t plane_size_tmp = plane_size; plane_size_tmp > 0; plane_size_tmp -= 1, src += C8NUM, dst_c1 += stride) { + __m256 src1 = _mm256_loadu_ps(src); + src1 = _mm256_add_ps(src1, bias1); + + ActBlock1Avx(&src1, relu_type == 1, relu_type == C3NUM); + __m128 src_high = _mm256_extractf128_ps(src1, 1); + + switch (oc8mod) { + case 1: + dst_c1[0] = _mm256_cvtss_f32(src1); + break; + case C2NUM: + _mm_storel_pi((__m64 *)(dst_c1), _mm256_castps256_ps128(src1)); + break; + case C3NUM: + _mm_storel_pi((__m64 *)(dst_c1), _mm256_castps256_ps128(src1)); + dst_c1[C2NUM] = MS_F32X8_GETI(src1, C2NUM); + break; + case C4NUM: + _mm_storeu_ps(dst_c1, _mm256_castps256_ps128(src1)); + break; + case C5NUM: + _mm_storeu_ps(dst_c1, _mm256_castps256_ps128(src1)); + _mm_store_ss(dst_c1 + C4NUM, src_high); + break; + case C6NUM: + _mm_storeu_ps(dst_c1, _mm256_castps256_ps128(src1)); + _mm_storel_pi((__m64 *)(dst_c1 + C4NUM), src_high); + break; + case C7NUM: + _mm_storeu_ps(dst_c1, _mm256_castps256_ps128(src1)); + _mm_storel_pi((__m64 *)(dst_c1 + C4NUM), src_high); + dst_c1[C6NUM] = MS_F32X4_GETI(src_high, C2NUM); + break; + default: + break; + } + } +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/avx/TiledC8MatMulFp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/avx/TiledC8MatMulFp32.c new file mode 100644 index 0000000000000000000000000000000000000000..31f2d7160462a69082b1a33f20076508e8c9319b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/avx/TiledC8MatMulFp32.c @@ -0,0 +1,244 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_AVX +#ifdef _MSC_VER +#include +#else +#include +#endif +#include "nnacl/fp32/common_func_fp32.h" + +void TiledC8MatmulFp32(float *dst, const float *src, const float *weight, size_t cal_num, size_t ic8, size_t oc8) { + const float *src_tmp = src; + for (int i = 0; i < oc8; ++i) { + src = src_tmp; + register __m256 dst1 asm("ymm0") = _mm256_setzero_ps(); + register __m256 dst2 asm("ymm1") = _mm256_setzero_ps(); + register __m256 dst3 asm("ymm2") = _mm256_setzero_ps(); + register __m256 dst4 asm("ymm3") = _mm256_setzero_ps(); + register __m256 dst5 asm("ymm4") = _mm256_setzero_ps(); + register __m256 dst6 asm("ymm5") = _mm256_setzero_ps(); + register __m256 dst7 asm("ymm6") = _mm256_setzero_ps(); + register __m256 dst8 asm("ymm7") = _mm256_setzero_ps(); + for (size_t ic8_tmp = 0; ic8_tmp < ic8; ++ic8_tmp) { +#ifndef ENABLE_DEBUG + asm volatile( + // 1 + "vmovups (%1), %%ymm8\n" + + "vbroadcastss (%0), %%ymm9\n" + "vbroadcastss 32(%0), %%ymm10\n" + "vbroadcastss 64(%0), %%ymm11\n" + "vbroadcastss 96(%0), %%ymm12\n" + "vbroadcastss 128(%0), %%ymm13\n" + "vbroadcastss 160(%0), %%ymm14\n" + + "vfmadd231ps %%ymm9, %%ymm8, %%ymm0\n" + "vfmadd231ps %%ymm10, %%ymm8, %%ymm1\n" + "vfmadd231ps %%ymm11, %%ymm8, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm8, %%ymm3\n" + "vfmadd231ps %%ymm13, %%ymm8, %%ymm4\n" + "vfmadd231ps %%ymm14, %%ymm8, %%ymm5\n" + + "vbroadcastss 192(%0), %%ymm9\n" + "vbroadcastss 224(%0), %%ymm10\n" + "vfmadd231ps %%ymm9, %%ymm8, %%ymm6\n" + "vfmadd231ps %%ymm10, %%ymm8, %%ymm7\n" + + // 2 + "vmovups 32(%1), %%ymm15\n" + + "vbroadcastss 4(%0), %%ymm11\n" + "vbroadcastss 36(%0), %%ymm12\n" + "vbroadcastss 68(%0), %%ymm13\n" + "vbroadcastss 100(%0), %%ymm14\n" + "vbroadcastss 132(%0), %%ymm9\n" + "vbroadcastss 164(%0), %%ymm10\n" + + "vfmadd231ps %%ymm11, %%ymm15, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm1\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm2\n" + "vfmadd231ps %%ymm14, %%ymm15, %%ymm3\n" + "vfmadd231ps %%ymm9, %%ymm15, %%ymm4\n" + "vfmadd231ps %%ymm10, %%ymm15, %%ymm5\n" + + "vbroadcastss 196(%0), %%ymm11\n" + "vbroadcastss 228(%0), %%ymm12\n" + "vfmadd231ps %%ymm11, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm7\n" + + // 3 + "vmovups 64(%1), %%ymm8\n" + + "vbroadcastss 8(%0), %%ymm13\n" + "vbroadcastss 40(%0), %%ymm14\n" + "vbroadcastss 72(%0), %%ymm9\n" + "vbroadcastss 104(%0), %%ymm10\n" + "vbroadcastss 136(%0), %%ymm11\n" + "vbroadcastss 168(%0), %%ymm12\n" + + "vfmadd231ps %%ymm13, %%ymm8, %%ymm0\n" + "vfmadd231ps %%ymm14, %%ymm8, %%ymm1\n" + "vfmadd231ps %%ymm9, %%ymm8, %%ymm2\n" + "vfmadd231ps %%ymm10, %%ymm8, %%ymm3\n" + "vfmadd231ps %%ymm11, %%ymm8, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm8, %%ymm5\n" + + "vbroadcastss 200(%0), %%ymm13\n" + "vbroadcastss 232(%0), %%ymm14\n" + "vfmadd231ps %%ymm13, %%ymm8, %%ymm6\n" + "vfmadd231ps %%ymm14, %%ymm8, %%ymm7\n" + + // 4 + "vmovups 96(%1), %%ymm15\n" + + "vbroadcastss 12(%0), %%ymm9\n" + "vbroadcastss 44(%0), %%ymm10\n" + "vbroadcastss 76(%0), %%ymm11\n" + "vbroadcastss 108(%0), %%ymm12\n" + "vbroadcastss 140(%0), %%ymm13\n" + "vbroadcastss 172(%0), %%ymm14\n" + + "vfmadd231ps %%ymm9, %%ymm15, %%ymm0\n" + "vfmadd231ps %%ymm10, %%ymm15, %%ymm1\n" + "vfmadd231ps %%ymm11, %%ymm15, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm3\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm4\n" + "vfmadd231ps %%ymm14, %%ymm15, %%ymm5\n" + + "vbroadcastss 204(%0), %%ymm9\n" + "vbroadcastss 236(%0), %%ymm10\n" + "vfmadd231ps %%ymm9, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm10, %%ymm15, %%ymm7\n" + + // 5 + "vmovups 128(%1), %%ymm8\n" + + "vbroadcastss 16(%0), %%ymm11\n" + "vbroadcastss 48(%0), %%ymm12\n" + "vbroadcastss 80(%0), %%ymm13\n" + "vbroadcastss 112(%0), %%ymm14\n" + "vbroadcastss 144(%0), %%ymm9\n" + "vbroadcastss 176(%0), %%ymm10\n" + + "vfmadd231ps %%ymm11, %%ymm8, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm8, %%ymm1\n" + "vfmadd231ps %%ymm13, %%ymm8, %%ymm2\n" + "vfmadd231ps %%ymm14, %%ymm8, %%ymm3\n" + "vfmadd231ps %%ymm9, %%ymm8, %%ymm4\n" + "vfmadd231ps %%ymm10, %%ymm8, %%ymm5\n" + + "vbroadcastss 208(%0), %%ymm11\n" + "vbroadcastss 240(%0), %%ymm12\n" + "vfmadd231ps %%ymm11, %%ymm8, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm8, %%ymm7\n" + + // 6 + "vmovups 160(%1), %%ymm15\n" + + "vbroadcastss 20(%0), %%ymm13\n" + "vbroadcastss 52(%0), %%ymm14\n" + "vbroadcastss 84(%0), %%ymm9\n" + "vbroadcastss 116(%0), %%ymm10\n" + "vbroadcastss 148(%0), %%ymm11\n" + "vbroadcastss 180(%0), %%ymm12\n" + + "vfmadd231ps %%ymm13, %%ymm15, %%ymm0\n" + "vfmadd231ps %%ymm14, %%ymm15, %%ymm1\n" + "vfmadd231ps %%ymm9, %%ymm15, %%ymm2\n" + "vfmadd231ps %%ymm10, %%ymm15, %%ymm3\n" + "vfmadd231ps %%ymm11, %%ymm15, %%ymm4\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm5\n" + + "vbroadcastss 212(%0), %%ymm13\n" + "vbroadcastss 244(%0), %%ymm14\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm14, %%ymm15, %%ymm7\n" + + // 7 + "vmovups 192(%1), %%ymm8\n" + + "vbroadcastss 24(%0), %%ymm9\n" + "vbroadcastss 56(%0), %%ymm10\n" + "vbroadcastss 88(%0), %%ymm11\n" + "vbroadcastss 120(%0), %%ymm12\n" + "vbroadcastss 152(%0), %%ymm13\n" + "vbroadcastss 184(%0), %%ymm14\n" + + "vfmadd231ps %%ymm9, %%ymm8, %%ymm0\n" + "vfmadd231ps %%ymm10, %%ymm8, %%ymm1\n" + "vfmadd231ps %%ymm11, %%ymm8, %%ymm2\n" + "vfmadd231ps %%ymm12, %%ymm8, %%ymm3\n" + "vfmadd231ps %%ymm13, %%ymm8, %%ymm4\n" + "vfmadd231ps %%ymm14, %%ymm8, %%ymm5\n" + + "vbroadcastss 216(%0), %%ymm9\n" + "vbroadcastss 248(%0), %%ymm10\n" + "vfmadd231ps %%ymm9, %%ymm8, %%ymm6\n" + "vfmadd231ps %%ymm10, %%ymm8, %%ymm7\n" + + // 8 + "vmovups 224(%1), %%ymm15\n" + + "vbroadcastss 28(%0), %%ymm11\n" + "vbroadcastss 60(%0), %%ymm12\n" + "vbroadcastss 92(%0), %%ymm13\n" + "vbroadcastss 124(%0), %%ymm14\n" + "vbroadcastss 156(%0), %%ymm9\n" + "vbroadcastss 188(%0), %%ymm10\n" + + "vfmadd231ps %%ymm11, %%ymm15, %%ymm0\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm1\n" + "vfmadd231ps %%ymm13, %%ymm15, %%ymm2\n" + "vfmadd231ps %%ymm14, %%ymm15, %%ymm3\n" + "vfmadd231ps %%ymm9, %%ymm15, %%ymm4\n" + "vfmadd231ps %%ymm10, %%ymm15, %%ymm5\n" + + "vbroadcastss 220(%0), %%ymm11\n" + "vbroadcastss 252(%0), %%ymm12\n" + "vfmadd231ps %%ymm11, %%ymm15, %%ymm6\n" + "vfmadd231ps %%ymm12, %%ymm15, %%ymm7\n" + : + : "r"(src), "r"(weight) + : "memory"); +#else + for (int j = 0; j < C8NUM; ++j) { + __m256 weight_data = _mm256_loadu_ps(weight + j * C8NUM); + dst1 = _mm256_fmadd_ps(weight_data, _mm256_set1_ps(*(src + j)), dst1); + dst2 = _mm256_fmadd_ps(weight_data, _mm256_set1_ps(*(src + j + C8NUM)), dst2); + dst3 = _mm256_fmadd_ps(weight_data, _mm256_set1_ps(*(src + j + C16NUM)), dst3); + dst4 = _mm256_fmadd_ps(weight_data, _mm256_set1_ps(*(src + j + C24NUM)), dst4); + dst5 = _mm256_fmadd_ps(weight_data, _mm256_set1_ps(*(src + j + C32NUM)), dst5); + dst6 = _mm256_fmadd_ps(weight_data, _mm256_set1_ps(*(src + j + C40NUM)), dst6); + dst7 = _mm256_fmadd_ps(weight_data, _mm256_set1_ps(*(src + j + C48NUM)), dst7); + dst8 = _mm256_fmadd_ps(weight_data, _mm256_set1_ps(*(src + j + C56NUM)), dst8); + } +#endif + src += C64NUM; + weight += C64NUM; + } + _mm256_storeu_ps(dst, dst1); + _mm256_storeu_ps(dst + C8NUM, dst2); + _mm256_storeu_ps(dst + C16NUM, dst3); + _mm256_storeu_ps(dst + C24NUM, dst4); + _mm256_storeu_ps(dst + C32NUM, dst5); + _mm256_storeu_ps(dst + C40NUM, dst6); + _mm256_storeu_ps(dst + C48NUM, dst7); + _mm256_storeu_ps(dst + C56NUM, dst8); + dst += cal_num; + } +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/avx/WinogradPostFuncBiasReluC8.c b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/avx/WinogradPostFuncBiasReluC8.c new file mode 100644 index 0000000000000000000000000000000000000000..c476a5d4af543e8222e7224f2c2df2640c9cf21d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/avx/WinogradPostFuncBiasReluC8.c @@ -0,0 +1,357 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_AVX +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/fp32/common_func_fp32.h" +#include "nnacl/intrinsics/avx/common_utils.h" + +void WinogradPostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod, + size_t plane_size, size_t plane_stride, size_t relu_type) { + size_t stride = oc8div + oc8mod; + plane_stride /= sizeof(float); + int loop_c8 = 0; + size_t src_stride = plane_size * C8NUM + plane_stride; + for (; loop_c8 <= (int)(oc8div)-C32NUM; loop_c8 += C32NUM) { + size_t plane_size_tmp = plane_size; + float *dst_c8 = dst + loop_c8; + __m256 bias1 = _mm256_setzero_ps(); + __m256 bias2 = _mm256_setzero_ps(); + __m256 bias3 = _mm256_setzero_ps(); + __m256 bias4 = _mm256_setzero_ps(); + if (bias != NULL) { + bias1 = _mm256_loadu_ps(bias); + bias2 = _mm256_loadu_ps(bias + C8NUM); + bias3 = _mm256_loadu_ps(bias + C16NUM); + bias4 = _mm256_loadu_ps(bias + C24NUM); + bias += C32NUM; + } + for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + C8NUM); + __m256 src5 = _mm256_loadu_ps(src + src_stride); + __m256 src6 = _mm256_loadu_ps(src + src_stride + C8NUM); + __m256 src9 = _mm256_loadu_ps(src + src_stride * C2NUM); + __m256 src10 = _mm256_loadu_ps(src + src_stride * C2NUM + C8NUM); + __m256 src13 = _mm256_loadu_ps(src + src_stride * C3NUM); + __m256 src14 = _mm256_loadu_ps(src + src_stride * C3NUM + C8NUM); + + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias1); + src5 = _mm256_add_ps(src5, bias2); + src6 = _mm256_add_ps(src6, bias2); + src9 = _mm256_add_ps(src9, bias3); + src10 = _mm256_add_ps(src10, bias3); + src13 = _mm256_add_ps(src13, bias4); + src14 = _mm256_add_ps(src14, bias4); + + ActBlock8Avx(&src1, &src2, &src5, &src6, &src9, &src10, &src13, &src14, relu_type); + + _mm256_stream_ps(dst_c8, src1); + _mm256_stream_ps(dst_c8 + C8NUM, src5); + _mm256_stream_ps(dst_c8 + C16NUM, src9); + _mm256_stream_ps(dst_c8 + C24NUM, src13); + dst_c8 += stride; + _mm256_stream_ps(dst_c8, src2); + _mm256_stream_ps(dst_c8 + C8NUM, src6); + _mm256_stream_ps(dst_c8 + C16NUM, src10); + _mm256_stream_ps(dst_c8 + C24NUM, src14); + dst_c8 += stride; + + __m256 src3 = _mm256_loadu_ps(src + C16NUM); + __m256 src4 = _mm256_loadu_ps(src + C24NUM); + __m256 src7 = _mm256_loadu_ps(src + src_stride + C16NUM); + __m256 src8 = _mm256_loadu_ps(src + src_stride + C24NUM); + __m256 src11 = _mm256_loadu_ps(src + src_stride * C2NUM + C16NUM); + __m256 src12 = _mm256_loadu_ps(src + src_stride * C2NUM + C24NUM); + __m256 src15 = _mm256_loadu_ps(src + src_stride * C3NUM + C16NUM); + __m256 src16 = _mm256_loadu_ps(src + src_stride * C3NUM + C24NUM); + src3 = _mm256_add_ps(src3, bias1); + src4 = _mm256_add_ps(src4, bias1); + src7 = _mm256_add_ps(src7, bias2); + src8 = _mm256_add_ps(src8, bias2); + src11 = _mm256_add_ps(src11, bias3); + src12 = _mm256_add_ps(src12, bias3); + src15 = _mm256_add_ps(src15, bias4); + src16 = _mm256_add_ps(src16, bias4); + + ActBlock8Avx(&src3, &src4, &src7, &src8, &src11, &src12, &src15, &src16, relu_type); + + _mm256_stream_ps(dst_c8, src3); + _mm256_stream_ps(dst_c8 + C8NUM, src7); + _mm256_stream_ps(dst_c8 + C16NUM, src11); + _mm256_stream_ps(dst_c8 + C24NUM, src15); + dst_c8 += stride; + _mm256_stream_ps(dst_c8, src4); + _mm256_stream_ps(dst_c8 + C8NUM, src8); + _mm256_stream_ps(dst_c8 + C16NUM, src12); + _mm256_stream_ps(dst_c8 + C24NUM, src16); + dst_c8 += stride; + src += C32NUM; + } + for (; plane_size_tmp > 0; plane_size_tmp -= 1) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + src_stride); + __m256 src3 = _mm256_loadu_ps(src + src_stride * C2NUM); + __m256 src4 = _mm256_loadu_ps(src + src_stride * C3NUM); + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias2); + src3 = _mm256_add_ps(src3, bias3); + src4 = _mm256_add_ps(src4, bias4); + + ActBlock4Avx(&src1, &src2, &src3, &src4, relu_type == 1, relu_type == C3NUM); + + _mm256_stream_ps(dst_c8, src1); + _mm256_stream_ps(dst_c8 + C8NUM, src2); + _mm256_stream_ps(dst_c8 + C16NUM, src3); + _mm256_stream_ps(dst_c8 + C24NUM, src4); + dst_c8 += stride; + src += C8NUM; + } + src += plane_stride; + src += C3NUM * src_stride; + } + for (; loop_c8 <= (int)(oc8div)-C24NUM; loop_c8 += C24NUM) { + size_t plane_size_tmp = plane_size; + float *dst_c8 = dst + loop_c8; + __m256 bias1 = _mm256_setzero_ps(); + __m256 bias2 = _mm256_setzero_ps(); + __m256 bias3 = _mm256_setzero_ps(); + if (bias != NULL) { + bias1 = _mm256_loadu_ps(bias); + bias2 = _mm256_loadu_ps(bias + C8NUM); + bias3 = _mm256_loadu_ps(bias + C16NUM); + bias += C24NUM; + } + for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + C8NUM); + __m256 src3 = _mm256_loadu_ps(src + C16NUM); + __m256 src4 = _mm256_loadu_ps(src + C24NUM); + __m256 src5 = _mm256_loadu_ps(src + src_stride); + __m256 src6 = _mm256_loadu_ps(src + src_stride + C8NUM); + __m256 src7 = _mm256_loadu_ps(src + src_stride + C16NUM); + __m256 src8 = _mm256_loadu_ps(src + src_stride + C24NUM); + __m256 src9 = _mm256_loadu_ps(src + src_stride * C2NUM); + __m256 src10 = _mm256_loadu_ps(src + src_stride * C2NUM + C8NUM); + __m256 src11 = _mm256_loadu_ps(src + src_stride * C2NUM + C16NUM); + __m256 src12 = _mm256_loadu_ps(src + src_stride * C2NUM + C24NUM); + src += C32NUM; + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias1); + src3 = _mm256_add_ps(src3, bias1); + src4 = _mm256_add_ps(src4, bias1); + src5 = _mm256_add_ps(src5, bias2); + src6 = _mm256_add_ps(src6, bias2); + src7 = _mm256_add_ps(src7, bias2); + src8 = _mm256_add_ps(src8, bias2); + src9 = _mm256_add_ps(src9, bias3); + src10 = _mm256_add_ps(src10, bias3); + src11 = _mm256_add_ps(src11, bias3); + src12 = _mm256_add_ps(src12, bias3); + + ActBlock12Avx(&src1, &src2, &src3, &src4, &src5, &src6, &src7, &src8, &src9, &src10, &src11, &src12, + relu_type == 1, relu_type == C3NUM); + + _mm256_stream_ps(dst_c8, src1); + _mm256_stream_ps(dst_c8 + C8NUM, src5); + _mm256_stream_ps(dst_c8 + C16NUM, src9); + dst_c8 += stride; + _mm256_stream_ps(dst_c8, src2); + _mm256_stream_ps(dst_c8 + C8NUM, src6); + _mm256_stream_ps(dst_c8 + C16NUM, src10); + dst_c8 += stride; + _mm256_stream_ps(dst_c8, src3); + _mm256_stream_ps(dst_c8 + C8NUM, src7); + _mm256_stream_ps(dst_c8 + C16NUM, src11); + dst_c8 += stride; + _mm256_stream_ps(dst_c8, src4); + _mm256_stream_ps(dst_c8 + C8NUM, src8); + _mm256_stream_ps(dst_c8 + C16NUM, src12); + dst_c8 += stride; + } + for (; plane_size_tmp > 0; plane_size_tmp -= 1) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + src_stride); + __m256 src3 = _mm256_loadu_ps(src + src_stride * C2NUM); + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias2); + src3 = _mm256_add_ps(src3, bias3); + + ActBlock1Avx(&src1, relu_type == 1, relu_type == C3NUM); + ActBlock1Avx(&src2, relu_type == 1, relu_type == C3NUM); + ActBlock1Avx(&src3, relu_type == 1, relu_type == C3NUM); + + _mm256_stream_ps(dst_c8, src1); + _mm256_stream_ps(dst_c8 + C8NUM, src2); + _mm256_stream_ps(dst_c8 + C16NUM, src3); + dst_c8 += stride; + src += C8NUM; + } + src += plane_stride; + src += C2NUM * src_stride; + } + for (; loop_c8 <= (int)(oc8div)-C16NUM; loop_c8 += C16NUM) { + size_t plane_size_tmp = plane_size; + float *dst_c8 = dst + loop_c8; + __m256 bias1 = _mm256_setzero_ps(); + __m256 bias2 = _mm256_setzero_ps(); + if (bias != NULL) { + bias1 = _mm256_loadu_ps(bias); + bias2 = _mm256_loadu_ps(bias + C8NUM); + bias += C16NUM; + } + for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + C8NUM); + __m256 src3 = _mm256_loadu_ps(src + C16NUM); + __m256 src4 = _mm256_loadu_ps(src + C24NUM); + __m256 src5 = _mm256_loadu_ps(src + src_stride); + __m256 src6 = _mm256_loadu_ps(src + src_stride + C8NUM); + __m256 src7 = _mm256_loadu_ps(src + src_stride + C16NUM); + __m256 src8 = _mm256_loadu_ps(src + src_stride + C24NUM); + src += C32NUM; + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias1); + src3 = _mm256_add_ps(src3, bias1); + src4 = _mm256_add_ps(src4, bias1); + src5 = _mm256_add_ps(src5, bias2); + src6 = _mm256_add_ps(src6, bias2); + src7 = _mm256_add_ps(src7, bias2); + src8 = _mm256_add_ps(src8, bias2); + + ActBlock8Avx(&src1, &src2, &src3, &src4, &src5, &src6, &src7, &src8, relu_type); + + _mm256_stream_ps(dst_c8, src1); + _mm256_stream_ps(dst_c8 + C8NUM, src5); + dst_c8 += stride; + _mm256_stream_ps(dst_c8, src2); + _mm256_stream_ps(dst_c8 + C8NUM, src6); + dst_c8 += stride; + _mm256_stream_ps(dst_c8, src3); + _mm256_stream_ps(dst_c8 + C8NUM, src7); + dst_c8 += stride; + _mm256_stream_ps(dst_c8, src4); + _mm256_stream_ps(dst_c8 + C8NUM, src8); + dst_c8 += stride; + } + for (; plane_size_tmp > 0; plane_size_tmp -= 1) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + src_stride); + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias2); + + ActBlock2Avx(&src1, &src2, relu_type == 1, relu_type == C3NUM); + + _mm256_stream_ps(dst_c8, src1); + _mm256_stream_ps(dst_c8 + C8NUM, src2); + dst_c8 += stride; + src += C8NUM; + } + src += plane_stride; + src += src_stride; + } + for (; loop_c8 < (int)(oc8div); loop_c8 += C8NUM) { + size_t plane_size_tmp = plane_size; + float *dst_c8 = dst + loop_c8; + __m256 bias1 = _mm256_setzero_ps(); + if (bias != NULL) { + bias1 = _mm256_loadu_ps(bias); + bias += C8NUM; + } + for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) { + __m256 src1 = _mm256_loadu_ps(src); + __m256 src2 = _mm256_loadu_ps(src + C8NUM); + __m256 src3 = _mm256_loadu_ps(src + C16NUM); + __m256 src4 = _mm256_loadu_ps(src + C24NUM); + src += C32NUM; + src1 = _mm256_add_ps(src1, bias1); + src2 = _mm256_add_ps(src2, bias1); + src3 = _mm256_add_ps(src3, bias1); + src4 = _mm256_add_ps(src4, bias1); + + ActBlock4Avx(&src1, &src2, &src3, &src4, relu_type == 1, relu_type == C3NUM); + + _mm256_stream_ps(dst_c8, src1); + dst_c8 += stride; + _mm256_stream_ps(dst_c8, src2); + dst_c8 += stride; + _mm256_stream_ps(dst_c8, src3); + dst_c8 += stride; + _mm256_stream_ps(dst_c8, src4); + dst_c8 += stride; + } + for (; plane_size_tmp > 0; plane_size_tmp -= 1) { + __m256 src1 = _mm256_loadu_ps(src); + src1 = _mm256_add_ps(src1, bias1); + + ActBlock1Avx(&src1, relu_type == 1, relu_type == C3NUM); + + _mm256_stream_ps(dst_c8, src1); + dst_c8 += stride; + src += C8NUM; + } + src += plane_stride; + } + if (oc8mod == 0) { + return; + } + __m256 bias1 = _mm256_setzero_ps(); + if (bias != NULL) { + bias1 = _mm256_loadu_ps(bias); + bias += C8NUM; + } + float *dst_c1 = dst + oc8div; + for (size_t plane_size_tmp = plane_size; plane_size_tmp > 0; plane_size_tmp -= 1, src += C8NUM, dst_c1 += stride) { + __m256 src1 = _mm256_loadu_ps(src); + src1 = _mm256_add_ps(src1, bias1); + + ActBlock1Avx(&src1, relu_type == 1, relu_type == C3NUM); + __m128 src_high = _mm256_extractf128_ps(src1, 1); + + switch (oc8mod) { + case 1: + dst_c1[0] = _mm256_cvtss_f32(src1); + break; + case C2NUM: + _mm_storel_pi((__m64 *)(dst_c1), _mm256_castps256_ps128(src1)); + break; + case C3NUM: + _mm_storel_pi((__m64 *)(dst_c1), _mm256_castps256_ps128(src1)); + dst_c1[C2NUM] = MS_F32X8_GETI(src1, C2NUM); + break; + case C4NUM: + _mm_storeu_ps(dst_c1, _mm256_castps256_ps128(src1)); + break; + case C5NUM: + _mm_storeu_ps(dst_c1, _mm256_castps256_ps128(src1)); + _mm_store_ss(dst_c1 + C4NUM, src_high); + break; + case C6NUM: + _mm_storeu_ps(dst_c1, _mm256_castps256_ps128(src1)); + _mm_storel_pi((__m64 *)(dst_c1 + C4NUM), src_high); + break; + case C7NUM: + _mm_storeu_ps(dst_c1, _mm256_castps256_ps128(src1)); + _mm_storel_pi((__m64 *)(dst_c1 + C4NUM), src_high); + dst_c1[C6NUM] = MS_F32X4_GETI(src_high, C2NUM); + break; + default: + break; + } + } +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/avx/WinogradTransAvx.c b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/avx/WinogradTransAvx.c new file mode 100644 index 0000000000000000000000000000000000000000..820f415c4c2f31d5c6853584711de8bce93b9af5 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/avx/WinogradTransAvx.c @@ -0,0 +1,355 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_AVX +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/fp32/common_func_fp32.h" + +void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) { + size_t len_c8 = length * C8NUM; + size_t S_step = length * w * C8NUM; + for (int h1 = 0; h1 < h; ++h1) { + const float *SW = S; + memset(M, 0, len_c8 * w * sizeof(float)); + for (int w_tmp = w; w_tmp > 0; --w_tmp) { + const float *SK = SW; + const float *BK = B; + int k_tmp = k; + for (; k_tmp >= C8NUM; k_tmp -= C8NUM) { + __m256 k1 = _mm256_set1_ps(*BK); + __m256 k2 = _mm256_set1_ps(*(BK + h)); + __m256 k3 = _mm256_set1_ps(*(BK + C2NUM * h)); + __m256 k4 = _mm256_set1_ps(*(BK + C3NUM * h)); + __m256 k5 = _mm256_set1_ps(*(BK + C4NUM * h)); + __m256 k6 = _mm256_set1_ps(*(BK + C5NUM * h)); + __m256 k7 = _mm256_set1_ps(*(BK + C6NUM * h)); + __m256 k8 = _mm256_set1_ps(*(BK + C7NUM * h)); + BK += C8NUM * h; + for (int len_tmp = length; len_tmp > 0; --len_tmp, M += C8NUM, SK += C8NUM) { + __m256 M1 = _mm256_loadu_ps(M); + __m256 M2 = _mm256_set1_ps(0.0f); + __m256 s1 = _mm256_loadu_ps(SK); + M1 = _mm256_fmadd_ps(s1, k1, M1); + __m256 s2 = _mm256_loadu_ps(SK + S_step); + M2 = _mm256_fmadd_ps(s2, k2, M2); + __m256 s3 = _mm256_loadu_ps(SK + C2NUM * S_step); + M1 = _mm256_fmadd_ps(s3, k3, M1); + __m256 s4 = _mm256_loadu_ps(SK + C3NUM * S_step); + M2 = _mm256_fmadd_ps(s4, k4, M2); + __m256 s5 = _mm256_loadu_ps(SK + C4NUM * S_step); + M1 = _mm256_fmadd_ps(s5, k5, M1); + __m256 s6 = _mm256_loadu_ps(SK + C5NUM * S_step); + M2 = _mm256_fmadd_ps(s6, k6, M2); + __m256 s7 = _mm256_loadu_ps(SK + C6NUM * S_step); + M1 = _mm256_fmadd_ps(s7, k7, M1); + __m256 s8 = _mm256_loadu_ps(SK + C7NUM * S_step); + M2 = _mm256_fmadd_ps(s8, k8, M2); + M1 = _mm256_add_ps(M1, M2); + _mm256_storeu_ps(M, M1); + } + M -= len_c8; + SK += C8NUM * S_step - len_c8; + } + for (; k_tmp >= C4NUM; k_tmp -= C4NUM) { + __m256 k1 = _mm256_set1_ps(*BK); + __m256 k2 = _mm256_set1_ps(*(BK + h)); + __m256 k3 = _mm256_set1_ps(*(BK + C2NUM * h)); + __m256 k4 = _mm256_set1_ps(*(BK + C3NUM * h)); + BK += C4NUM * h; + int len_tmp = length; + for (; len_tmp >= C2NUM; len_tmp -= C2NUM, SK += C16NUM, M += C16NUM) { + __m256 M1 = _mm256_loadu_ps(M); + __m256 M2 = _mm256_set1_ps(0.0f); + __m256 M3 = _mm256_loadu_ps(M + C8NUM); + __m256 M4 = _mm256_set1_ps(0.0f); + __m256 s1 = _mm256_loadu_ps(SK); + __m256 s11 = _mm256_loadu_ps(SK + C8NUM); + __m256 s2 = _mm256_loadu_ps(SK + S_step); + __m256 s22 = _mm256_loadu_ps(SK + S_step + C8NUM); + M1 = _mm256_fmadd_ps(s1, k1, M1); + M2 = _mm256_fmadd_ps(s2, k2, M2); + M3 = _mm256_fmadd_ps(s11, k1, M3); + M4 = _mm256_fmadd_ps(s22, k2, M4); + __m256 s3 = _mm256_loadu_ps(SK + C2NUM * S_step); + __m256 s33 = _mm256_loadu_ps(SK + C2NUM * S_step + C8NUM); + __m256 s4 = _mm256_loadu_ps(SK + C3NUM * S_step); + __m256 s44 = _mm256_loadu_ps(SK + C3NUM * S_step + C8NUM); + M1 = _mm256_fmadd_ps(s3, k3, M1); + M2 = _mm256_fmadd_ps(s4, k4, M2); + M3 = _mm256_fmadd_ps(s33, k3, M3); + M4 = _mm256_fmadd_ps(s44, k4, M4); + M1 = _mm256_add_ps(M1, M2); + M4 = _mm256_add_ps(M3, M4); + _mm256_storeu_ps(M, M1); + _mm256_storeu_ps(M + C8NUM, M4); + } + for (; len_tmp > 0; len_tmp--, SK += C8NUM, M += C8NUM) { + __m256 M1 = _mm256_loadu_ps(M); + __m256 M2 = _mm256_set1_ps(0.0f); + __m256 s1 = _mm256_loadu_ps(SK); + M1 = _mm256_fmadd_ps(s1, k1, M1); + __m256 s2 = _mm256_loadu_ps(SK + S_step); + M2 = _mm256_fmadd_ps(s2, k2, M2); + __m256 s3 = _mm256_loadu_ps(SK + C2NUM * S_step); + M1 = _mm256_fmadd_ps(s3, k3, M1); + __m256 s4 = _mm256_loadu_ps(SK + C3NUM * S_step); + M2 = _mm256_fmadd_ps(s4, k4, M2); + M1 = _mm256_add_ps(M1, M2); + _mm256_storeu_ps(M, M1); + } + M -= len_c8; + SK += C4NUM * S_step - len_c8; + } + for (; k_tmp >= C3NUM; k_tmp -= C3NUM) { + __m256 k1 = _mm256_set1_ps(*BK); + __m256 k2 = _mm256_set1_ps(*(BK + h)); + __m256 k3 = _mm256_set1_ps(*(BK + C2NUM * h)); + BK += C3NUM * h; + int len_tmp = length; + for (; len_tmp >= C3NUM; len_tmp -= C3NUM, SK += C24NUM, M += C24NUM) { + __m256 M1 = _mm256_loadu_ps(M); + __m256 M2 = _mm256_set1_ps(0.0f); + __m256 M3 = _mm256_loadu_ps(M + C8NUM); + __m256 M4 = _mm256_set1_ps(0.0f); + __m256 M5 = _mm256_loadu_ps(M + C16NUM); + __m256 M6 = _mm256_set1_ps(0.0f); + __m256 s1 = _mm256_loadu_ps(SK); + __m256 s2 = _mm256_loadu_ps(SK + S_step); + __m256 s11 = _mm256_loadu_ps(SK + C8NUM); + __m256 s22 = _mm256_loadu_ps(SK + S_step + C8NUM); + __m256 s111 = _mm256_loadu_ps(SK + C16NUM); + __m256 s222 = _mm256_loadu_ps(SK + S_step + C16NUM); + M1 = _mm256_fmadd_ps(s1, k1, M1); + M2 = _mm256_fmadd_ps(s2, k2, M2); + M3 = _mm256_fmadd_ps(s11, k1, M3); + M4 = _mm256_fmadd_ps(s22, k2, M4); + M5 = _mm256_fmadd_ps(s111, k1, M5); + M6 = _mm256_fmadd_ps(s222, k2, M6); + __m256 s3 = _mm256_loadu_ps(SK + C2NUM * S_step); + __m256 s33 = _mm256_loadu_ps(SK + C2NUM * S_step + C8NUM); + __m256 s333 = _mm256_loadu_ps(SK + C2NUM * S_step + C16NUM); + M1 = _mm256_fmadd_ps(s3, k3, M1); + M3 = _mm256_fmadd_ps(s33, k3, M3); + M5 = _mm256_fmadd_ps(s333, k3, M5); + M1 = _mm256_add_ps(M1, M2); + M3 = _mm256_add_ps(M3, M4); + M5 = _mm256_add_ps(M5, M6); + _mm256_storeu_ps(M, M1); + _mm256_storeu_ps(M + C8NUM, M3); + _mm256_storeu_ps(M + C16NUM, M5); + } + for (; len_tmp > 0; len_tmp--, SK += C8NUM, M += C8NUM) { + __m256 M1 = _mm256_loadu_ps(M); + __m256 M2 = _mm256_set1_ps(0.0f); + __m256 s1 = _mm256_loadu_ps(SK); + M1 = _mm256_fmadd_ps(s1, k1, M1); + __m256 s2 = _mm256_loadu_ps(SK + S_step); + M2 = _mm256_fmadd_ps(s2, k2, M2); + __m256 s3 = _mm256_loadu_ps(SK + C2NUM * S_step); + M1 = _mm256_fmadd_ps(s3, k3, M1); + M1 = _mm256_add_ps(M1, M2); + _mm256_storeu_ps(M, M1); + } + M -= len_c8; + SK += C3NUM * S_step - len_c8; + } + for (; k_tmp > 0; k_tmp -= 1) { + __m256 k1 = _mm256_set1_ps(*BK); + BK += h; + for (int len_tmp = length; len_tmp > 0; --len_tmp, SK += C8NUM, M += C8NUM) { + __m256 M1 = _mm256_loadu_ps(M); + __m256 s0 = _mm256_loadu_ps(SK); + M1 = _mm256_fmadd_ps(s0, k1, M1); + _mm256_storeu_ps(M, M1); + } + M -= len_c8; + SK += S_step - len_c8; + } + SW += len_c8; + M += len_c8; + } + B += 1; + } +} + +void WinogradTransRight(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) { + size_t len_c8 = length * C8NUM, k_step = len_c8 * k; + for (int h1 = 0; h1 < h; ++h1, S += k_step) { + const float *BW = B; + memset(M, 0, len_c8 * w * sizeof(float)); + for (int ww = 0; ww < w; ++ww, BW += 1, M += len_c8) { + const float *SK = S, *BK = BW; + int k_tmp = k; + for (; k_tmp >= C8NUM; k_tmp -= C8NUM, M -= len_c8) { + __m256 k1 = _mm256_set1_ps(*BK); + __m256 k2 = _mm256_set1_ps(*(BK + h)); + __m256 k3 = _mm256_set1_ps(*(BK + C2NUM * h)); + __m256 k4 = _mm256_set1_ps(*(BK + C3NUM * h)); + __m256 k5 = _mm256_set1_ps(*(BK + C4NUM * h)); + __m256 k6 = _mm256_set1_ps(*(BK + C5NUM * h)); + __m256 k7 = _mm256_set1_ps(*(BK + C6NUM * h)); + __m256 k8 = _mm256_set1_ps(*(BK + C7NUM * h)); + BK += C8NUM * h; + const float *S2 = SK + len_c8, *S3 = S2 + len_c8; + const float *S4 = S3 + len_c8, *S5 = S4 + len_c8; + const float *S6 = S5 + len_c8, *S7 = S6 + len_c8, *S8 = S7 + len_c8; + for (int len_tmp = length; len_tmp > 0; --len_tmp, M += C8NUM, SK += C8NUM, S2 += C8NUM, S3 += C8NUM, + S4 += C8NUM, S5 += C8NUM, S6 += C8NUM, S7 += C8NUM, S8 += C8NUM) { + __m256 M1 = _mm256_loadu_ps(M); + __m256 M2 = _mm256_set1_ps(0.0f); + __m256 s1 = _mm256_loadu_ps(SK); + M1 = _mm256_fmadd_ps(s1, k1, M1); + __m256 s2 = _mm256_loadu_ps(S2); + M2 = _mm256_fmadd_ps(s2, k2, M2); + __m256 s3 = _mm256_loadu_ps(S3); + M1 = _mm256_fmadd_ps(s3, k3, M1); + __m256 s4 = _mm256_loadu_ps(S4); + M2 = _mm256_fmadd_ps(s4, k4, M2); + __m256 s5 = _mm256_loadu_ps(S5); + M1 = _mm256_fmadd_ps(s5, k5, M1); + __m256 s6 = _mm256_loadu_ps(S6); + M2 = _mm256_fmadd_ps(s6, k6, M2); + __m256 s7 = _mm256_loadu_ps(S7); + M1 = _mm256_fmadd_ps(s7, k7, M1); + __m256 s8 = _mm256_loadu_ps(S8); + M2 = _mm256_fmadd_ps(s8, k8, M2); + M1 = _mm256_add_ps(M1, M2); + _mm256_storeu_ps(M, M1); + } + SK = S7; + } + for (; k_tmp >= C4NUM; k_tmp -= C4NUM, M -= len_c8) { + __m256 k1 = _mm256_set1_ps(*BK); + __m256 k2 = _mm256_set1_ps(*(BK + h)); + __m256 k3 = _mm256_set1_ps(*(BK + C2NUM * h)); + __m256 k4 = _mm256_set1_ps(*(BK + C3NUM * h)); + BK += C4NUM * h; + const float *S2 = SK + len_c8; + const float *S3 = S2 + len_c8; + const float *S4 = S3 + len_c8; + int len_tmp = length; + for (; len_tmp >= C2NUM; + len_tmp -= C2NUM, M += C16NUM, SK += C16NUM, S2 += C16NUM, S3 += C16NUM, S4 += C16NUM) { + __m256 M1 = _mm256_loadu_ps(M); + __m256 M2 = _mm256_set1_ps(0.0f); + __m256 M3 = _mm256_loadu_ps(M + C8NUM); + __m256 M4 = _mm256_set1_ps(0.0f); + __m256 s1 = _mm256_loadu_ps(SK); + __m256 s2 = _mm256_loadu_ps(S2); + __m256 s11 = _mm256_loadu_ps(SK + C8NUM); + __m256 s22 = _mm256_loadu_ps(S2 + C8NUM); + M1 = _mm256_fmadd_ps(s1, k1, M1); + M2 = _mm256_fmadd_ps(s2, k2, M2); + M3 = _mm256_fmadd_ps(s11, k1, M3); + M4 = _mm256_fmadd_ps(s22, k2, M4); + + __m256 s3 = _mm256_loadu_ps(S3); + __m256 s4 = _mm256_loadu_ps(S4); + __m256 s33 = _mm256_loadu_ps(S3 + C8NUM); + __m256 s44 = _mm256_loadu_ps(S4 + C8NUM); + M1 = _mm256_fmadd_ps(s3, k3, M1); + M2 = _mm256_fmadd_ps(s4, k4, M2); + M3 = _mm256_fmadd_ps(s33, k3, M3); + M4 = _mm256_fmadd_ps(s44, k4, M4); + + M1 = _mm256_add_ps(M1, M2); + M3 = _mm256_add_ps(M3, M4); + _mm256_storeu_ps(M, M1); + _mm256_storeu_ps(M + C8NUM, M3); + } + for (; len_tmp > 0; len_tmp--, M += C8NUM, SK += C8NUM, S2 += C8NUM, S3 += C8NUM, S4 += C8NUM) { + __m256 M1 = _mm256_loadu_ps(M); + __m256 M2 = _mm256_set1_ps(0.0f); + __m256 s1 = _mm256_loadu_ps(SK); + M1 = _mm256_fmadd_ps(s1, k1, M1); + __m256 s2 = _mm256_loadu_ps(S2); + M2 = _mm256_fmadd_ps(s2, k2, M2); + __m256 s3 = _mm256_loadu_ps(S3); + M1 = _mm256_fmadd_ps(s3, k3, M1); + __m256 s4 = _mm256_loadu_ps(S4); + M2 = _mm256_fmadd_ps(s4, k4, M2); + M1 = _mm256_add_ps(M1, M2); + _mm256_storeu_ps(M, M1); + } + SK = S4; + } + for (; k_tmp >= C3NUM; k_tmp -= C3NUM, M -= len_c8) { + __m256 k1 = _mm256_set1_ps(*BK); + __m256 k2 = _mm256_set1_ps(*(BK + h)); + __m256 k3 = _mm256_set1_ps(*(BK + C2NUM * h)); + BK += C3NUM * h; + const float *S2 = SK + len_c8; + const float *S3 = S2 + len_c8; + int len_tmp = length; + for (; len_tmp >= C3NUM; len_tmp -= C3NUM, M += C24NUM, SK += C24NUM, S2 += C24NUM, S3 += C24NUM) { + __m256 M1 = _mm256_loadu_ps(M); + __m256 M2 = _mm256_set1_ps(0.0f); + __m256 M3 = _mm256_loadu_ps(M + C8NUM); + __m256 M4 = _mm256_set1_ps(0.0f); + __m256 M5 = _mm256_loadu_ps(M + C16NUM); + __m256 M6 = _mm256_set1_ps(0.0f); + __m256 s1 = _mm256_loadu_ps(SK); + __m256 s2 = _mm256_loadu_ps(S2); + __m256 s11 = _mm256_loadu_ps(SK + C8NUM); + __m256 s22 = _mm256_loadu_ps(S2 + C8NUM); + __m256 s111 = _mm256_loadu_ps(SK + C16NUM); + __m256 s222 = _mm256_loadu_ps(S2 + C16NUM); + M1 = _mm256_fmadd_ps(s1, k1, M1); + M2 = _mm256_fmadd_ps(s2, k2, M2); + M3 = _mm256_fmadd_ps(s11, k1, M3); + M4 = _mm256_fmadd_ps(s22, k2, M4); + M5 = _mm256_fmadd_ps(s111, k1, M5); + M6 = _mm256_fmadd_ps(s222, k2, M6); + __m256 s3 = _mm256_loadu_ps(S3); + __m256 s33 = _mm256_loadu_ps(S3 + C8NUM); + __m256 s333 = _mm256_loadu_ps(S3 + C16NUM); + M1 = _mm256_fmadd_ps(s3, k3, M1); + M3 = _mm256_fmadd_ps(s33, k3, M3); + M5 = _mm256_fmadd_ps(s333, k3, M5); + M1 = _mm256_add_ps(M1, M2); + M3 = _mm256_add_ps(M3, M4); + M5 = _mm256_add_ps(M6, M5); + _mm256_storeu_ps(M, M1); + _mm256_storeu_ps(M + C8NUM, M3); + _mm256_storeu_ps(M + C16NUM, M5); + } + for (; len_tmp > 0; len_tmp--, M += C8NUM, SK += C8NUM, S2 += C8NUM, S3 += C8NUM) { + __m256 M1 = _mm256_loadu_ps(M); + __m256 M2 = _mm256_set1_ps(0.0f); + __m256 s1 = _mm256_loadu_ps(SK); + M1 = _mm256_fmadd_ps(s1, k1, M1); + __m256 s2 = _mm256_loadu_ps(S2); + M2 = _mm256_fmadd_ps(s2, k2, M2); + __m256 s3 = _mm256_loadu_ps(S3); + M1 = _mm256_fmadd_ps(s3, k3, M1); + M1 = _mm256_add_ps(M1, M2); + _mm256_storeu_ps(M, M1); + } + SK = S3; + } + for (; k_tmp >= 1; k_tmp -= 1, M -= len_c8) { + __m256 k1 = _mm256_set1_ps(*BK); + BK += h; + for (int len_tmp = length; len_tmp > 0; --len_tmp, M += C8NUM, SK += C8NUM) { + __m256 M1 = _mm256_loadu_ps(M); + __m256 s1 = _mm256_loadu_ps(SK); + M1 = _mm256_fmadd_ps(s1, k1, M1); + _mm256_storeu_ps(M, M1); + } + } + } + } +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/avx/common_utils.c b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/avx/common_utils.c new file mode 100644 index 0000000000000000000000000000000000000000..7b16f1239bbade662592e03d5be8c0c7489a5155 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/avx/common_utils.c @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/intrinsics/avx/common_utils.h" +#include + +__m128i _mm_adds_epi32(__m128i a, __m128i b) { + __m128i int_min = _mm_set1_epi32(0x80000000); + __m128i int_max = _mm_set1_epi32(0x7FFFFFFF); + + const __m128i res = _mm_add_epi32(a, b); + const __m128i sign_and = _mm_and_si128(a, b); + const __m128i sign_or = _mm_or_si128(a, b); + + const __m128i min_sat_mask = _mm_andnot_si128(res, sign_and); + const __m128i max_sat_mask = _mm_andnot_si128(sign_or, res); + const __m128 res_temp = + _mm_blendv_ps(_mm_castsi128_ps(res), _mm_castsi128_ps(int_min), _mm_castsi128_ps(min_sat_mask)); + return _mm_castps_si128(_mm_blendv_ps(res_temp, _mm_castsi128_ps(int_max), _mm_castsi128_ps(max_sat_mask))); +} + +__m128i _mm_rshr_epi32(__m128i a, int shift) { + const __m128i vmask = _mm_cmpgt_epi32(_mm_setzero_si128(), a); + const __m128i vabs_a = _mm_sub_epi32(_mm_xor_si128(a, vmask), vmask); + const __m128i tmp_res = _mm_srli_epi32(vabs_a, shift); + return _mm_xor_si128(tmp_res, vmask); +} + +__m128i _mm_qrdmulh_epi32(__m128i a, __m128i b) { + const __m128i tmp_a_lo = _mm_unpacklo_epi32(a, _mm_setzero_si128()); + const __m128i tmp_a_hi = _mm_unpackhi_epi32(a, _mm_setzero_si128()); + const __m256i tmp_a_256 = _mm256_set_m128i(tmp_a_hi, tmp_a_lo); + const __m128i tmp_b_lo = _mm_unpacklo_epi32(b, _mm_setzero_si128()); + const __m128i tmp_b_hi = _mm_unpackhi_epi32(b, _mm_setzero_si128()); + const __m256i tmp_b_256 = _mm256_set_m128i(tmp_b_hi, tmp_b_lo); + __m256i tmp_out = _mm256_mul_epi32(tmp_a_256, tmp_b_256); + tmp_out = _mm256_add_epi64(tmp_out, _mm256_set1_epi64x(1ll << 30)); + const __m256i vmask = _mm256_cmpgt_epi64(_mm256_setzero_si256(), tmp_out); + const __m256i vabs_tmp_out = _mm256_sub_epi64(_mm256_xor_si256(tmp_out, vmask), vmask); + tmp_out = _mm256_srli_epi64(vabs_tmp_out, 31); + const __m256i vtmp_out = _mm256_sub_epi64(_mm256_xor_si256(tmp_out, vmask), vmask); + const int32_t max_32bit = (1ll << 31) - 1; + const int32_t min_32bit = -(1ll << 31); + int64_t *tmp_out_ptr = (int64_t *)(&vtmp_out); + int32_t r1 = tmp_out_ptr[0] > max_32bit ? max_32bit : tmp_out_ptr[0]; + r1 = r1 < min_32bit ? min_32bit : r1; + int32_t r2 = tmp_out_ptr[1] > max_32bit ? max_32bit : tmp_out_ptr[1]; + r2 = r2 < min_32bit ? min_32bit : r2; + int32_t r3 = tmp_out_ptr[2] > max_32bit ? max_32bit : tmp_out_ptr[2]; + r3 = r3 < min_32bit ? min_32bit : r3; + int32_t r4 = tmp_out_ptr[3] > max_32bit ? max_32bit : tmp_out_ptr[3]; + r4 = r4 < min_32bit ? min_32bit : r4; + return _mm_set_epi32(r4, r3, r2, r1); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/avx/common_utils.h b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/avx/common_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..0b80a83a1440fb06a6cd1ea9ab9a97a669a138d6 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/avx/common_utils.h @@ -0,0 +1,157 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_X86_64_AVX_COMMON_UTILS_H_ +#define MINDSPORE_NNACL_X86_64_AVX_COMMON_UTILS_H_ + +#ifdef _MSC_VER +#include +#else +#include +#endif + +#ifdef __cplusplus +extern "C" { +#endif +#ifdef __GNUC__ +#if __GNUC__ < 8 +#define _mm256_set_m128i(xmm1, xmm2) \ + _mm256_permute2f128_si256(_mm256_castsi128_si256(xmm1), _mm256_castsi128_si256(xmm2), 2) +#define _mm256_set_m128f(xmm1, xmm2) \ + _mm256_permute2f128_ps(_mm256_castps128_ps256(xmm1), _mm256_castps128_ps256(xmm2), 2) +#endif +#endif + +#define AVX_ACT_RELU 1 +#define AVX_ACT_RELU6 3 + +// Signed saturating Add +__m128i _mm_adds_epi32(__m128i a, __m128i b); + +// Signed rounding shift right +__m128i _mm_rshr_epi32(__m128i a, int shift); + +// Signed saturating Rounding Doubling Multiply return High half +__m128i _mm_qrdmulh_epi32(__m128i a, __m128i b); + +static inline void ActBlock1Avx(__m256 *v1, size_t relu, size_t relu6) { + __m256 zero_ma = _mm256_setzero_ps(); + __m256 relu6_ma = _mm256_set1_ps(6.0f); + if (relu || relu6) { + *v1 = _mm256_max_ps(zero_ma, *v1); + } + if (relu6) { + *v1 = _mm256_min_ps(relu6_ma, *v1); + } +} + +static inline void ActBlock2Avx(__m256 *v1, __m256 *v2, size_t relu, size_t relu6) { + __m256 zero_ma = _mm256_setzero_ps(); + __m256 relu6_ma = _mm256_set1_ps(6.0f); + if (relu || relu6) { + *v1 = _mm256_max_ps(zero_ma, *v1); + *v2 = _mm256_max_ps(zero_ma, *v2); + } + if (relu6) { + *v1 = _mm256_min_ps(relu6_ma, *v1); + *v2 = _mm256_min_ps(relu6_ma, *v2); + } +} + +static inline void ActBlock4Avx(__m256 *v1, __m256 *v2, __m256 *v3, __m256 *v4, size_t relu, size_t relu6) { + __m256 zero_ma = _mm256_setzero_ps(); + __m256 relu6_ma = _mm256_set1_ps(6.0f); + if (relu || relu6) { + *v1 = _mm256_max_ps(zero_ma, *v1); + *v2 = _mm256_max_ps(zero_ma, *v2); + *v3 = _mm256_max_ps(zero_ma, *v3); + *v4 = _mm256_max_ps(zero_ma, *v4); + } + if (relu6) { + *v1 = _mm256_min_ps(relu6_ma, *v1); + *v2 = _mm256_min_ps(relu6_ma, *v2); + *v3 = _mm256_min_ps(relu6_ma, *v3); + *v4 = _mm256_min_ps(relu6_ma, *v4); + } +} + +static inline void ActBlock8Avx(__m256 *v1, __m256 *v2, __m256 *v3, __m256 *v4, __m256 *v5, __m256 *v6, __m256 *v7, + __m256 *v8, size_t relu_type) { + __m256 relu6 = _mm256_set1_ps(6.0); + __m256 zero = _mm256_setzero_ps(); + switch (relu_type) { + case AVX_ACT_RELU6: + *v1 = _mm256_min_ps(*v1, relu6); + *v2 = _mm256_min_ps(*v2, relu6); + *v3 = _mm256_min_ps(*v3, relu6); + *v4 = _mm256_min_ps(*v4, relu6); + *v5 = _mm256_min_ps(*v5, relu6); + *v6 = _mm256_min_ps(*v6, relu6); + *v7 = _mm256_min_ps(*v7, relu6); + *v8 = _mm256_min_ps(*v8, relu6); + case AVX_ACT_RELU: + *v1 = _mm256_max_ps(*v1, zero); + *v2 = _mm256_max_ps(*v2, zero); + *v3 = _mm256_max_ps(*v3, zero); + *v4 = _mm256_max_ps(*v4, zero); + *v5 = _mm256_max_ps(*v5, zero); + *v6 = _mm256_max_ps(*v6, zero); + *v7 = _mm256_max_ps(*v7, zero); + *v8 = _mm256_max_ps(*v8, zero); + break; + default: + break; + } +} + +static inline void ActBlock12Avx(__m256 *v1, __m256 *v2, __m256 *v3, __m256 *v4, __m256 *v5, __m256 *v6, __m256 *v7, + __m256 *v8, __m256 *v9, __m256 *v10, __m256 *v11, __m256 *v12, size_t relu, + size_t relu6) { + if (relu || relu6) { + __m256 zero_ma = _mm256_setzero_ps(); + *v1 = _mm256_max_ps(zero_ma, *v1); + *v2 = _mm256_max_ps(zero_ma, *v2); + *v3 = _mm256_max_ps(zero_ma, *v3); + *v4 = _mm256_max_ps(zero_ma, *v4); + *v5 = _mm256_max_ps(zero_ma, *v5); + *v6 = _mm256_max_ps(zero_ma, *v6); + *v7 = _mm256_max_ps(zero_ma, *v7); + *v8 = _mm256_max_ps(zero_ma, *v8); + *v9 = _mm256_max_ps(zero_ma, *v9); + *v10 = _mm256_max_ps(zero_ma, *v10); + *v11 = _mm256_max_ps(zero_ma, *v11); + *v12 = _mm256_max_ps(zero_ma, *v12); + } + if (relu6) { + __m256 relu6_ma = _mm256_set1_ps(6.0f); + *v1 = _mm256_min_ps(relu6_ma, *v1); + *v2 = _mm256_min_ps(relu6_ma, *v2); + *v3 = _mm256_min_ps(relu6_ma, *v3); + *v4 = _mm256_min_ps(relu6_ma, *v4); + *v5 = _mm256_min_ps(relu6_ma, *v5); + *v6 = _mm256_min_ps(relu6_ma, *v6); + *v7 = _mm256_min_ps(relu6_ma, *v7); + *v8 = _mm256_min_ps(relu6_ma, *v8); + *v9 = _mm256_min_ps(relu6_ma, *v9); + *v10 = _mm256_min_ps(relu6_ma, *v10); + *v11 = _mm256_min_ps(relu6_ma, *v11); + *v12 = _mm256_min_ps(relu6_ma, *v12); + } +} + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_X86_64_AVX_COMMON_UTILS_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/ms_simd_avx512_instructions.h b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/ms_simd_avx512_instructions.h new file mode 100644 index 0000000000000000000000000000000000000000..5918725b66bc561cf11a3bbc29347c4bf0ee678e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/ms_simd_avx512_instructions.h @@ -0,0 +1,446 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_AVX512_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ +#define NNACL_AVX512_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ +#include +#include + +#ifdef _MSC_VER +#include +#define MS_F32X16_GETI(src, i) src.m512_f32[i] +#define MS512_F32_GETI(src, i) src.m512_f32[i] +#else +#include +#define MS_F32X16_GETI(src, i) src[i] +#define MS512_F32_GETI(src, i) src[i] +#endif + +#pragma GCC push_options +#pragma GCC target("avx512f") + +#define PI 3.1415926f +#define LN2 0.693147f + +#define MS_FLOAT32X16 __m512 +#define MS_FLOAT512_F32 __m512 +#define MS_INT32X16 __m512i +#define MS_INT512_EPI32 __m512i +#define MS_MASK512_TYPE __mmask16 +#define MS_LD512_F32 _mm512_loadu_ps +#define MS_LD512_EPI32(src) _mm512_loadu_si512((__m512i const *)(src)) +#define MS_LD512_HALF_EPI32(src) _mm256_loadu_si256((__m256i const *)(src)) +#define MS_ADD512_F32 _mm512_add_ps +#define MS_ADD512_EPI32 _mm512_add_epi32 +#define MS_MOV512_F32 _mm512_set1_ps +#define MS_MOV512_EPI32 _mm512_set1_epi32 +#define MS_MOV512_VAL0_F32 _mm512_setzero_ps() +#define MS_MLA512_F32(src1, src2, src3) _mm512_fmadd_ps(src2, src3, src1) +#define MS_ST512_F32 _mm512_storeu_ps +#define MS_ST512_EPI32(src1, src2) _mm512_storeu_si512((__m512i *)(src1), src2) +#define MS_ST512_HALF_EPI32(src1, src2) _mm256_storeu_si256((__m256i *)(src1), src2) +#define MS_SUB512_F32 _mm512_sub_ps +#define MS_SUB512_EPI32 _mm512_sub_epi32 +#define MS_MAX512_F32 _mm512_max_ps +#define MS_MAX512_EPI32 _mm512_max_epi32 +#define MS_MIN512_F32 _mm512_min_ps +#define MS_MIN512_EPI32 _mm512_min_epi32 +#define MS_SQRT512_F32 _mm512_sqrt_ps +#define MS_RSQRT512_F32 _mm512_rsqrt14_ps +#define MS_SIN512_F32 _mm512_sin_ps +#define MS_ERF512_F32 _mm512_erf_ps +#define MS_ABS512_F32 _mm512_abs_ps +#define MS_ABS512_EPI32 _mm512_abs_epi32 + +#define MS_ROUND512_F32(src) \ + _mm512_add_round_ps(src, _mm512_set1_ps(0.0f), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC) +#define MS_FLOOR512_F32 _mm512_floor_ps +#define MS_CEIL512_F32 _mm512_ceil_ps +#define MS_MUL512_F32(src1, src2) _mm512_mul_ps(src1, src2) +#define MS_MUL512_EPI32(src1, src2) _mm512_mullo_epi32(src1, src2) +#define MS_FMADD512_F32(src1, src2, src3) _mm512_fmadd_ps(src1, src2, src3) +#define MS_FMSUB512_F32(src1, src2, src3) _mm512_fmsub_ps(src1, src2, src3) +#define MS_FSMUL512_F32(src1, src2, src3) _mm512_fnmadd_ps(src3, src2, src1) // src1 - src2 * src3 +#define MS_DIV512_F32(src1, src2) _mm512_div_ps(src1, src2) +#define MS_MUL512_N_F32(src1, src2) _mm512_mul_ps(src1, _mm512_set1_ps(src2)) +#define MS_MUL512_N_EPI32(src1, src2) _mm512_mullo_epi32(src1, _mm512_set1_epi32(src2)) +#define MS_DIV512_N_F32(src1, src2) _mm512_div_ps(src1, _mm512_set1_ps(src2)) +#define MS_SLLI512_EPI32(src1, src2) _mm512_slli_epi32(src1, src2) +#define MS_CVT512PS_EPI32(src) _mm512_cvttps_epi32(src) +#define MS_CVT512EPI32_PS(src) _mm512_cvtepi32_ps(src) // truncate float to int +#define MS_CMP512_F32(src1, src2, src3) _mm512_cmp_ps_mask(src1, src2, src3) +#define MS_CMPGT512_F32(src1, src2) _mm512_cmp_ps_mask(src1, src2, 30) +#define MS_CMPLE512_F32(src1, src2) _mm512_cmp_ps_mask(src1, src2, 18) +#define MS_CMPLT512_F32(src1, src2) _mm512_cmp_ps_mask(src1, src2, 17) +#define MS_CMPGT512_EPI32(src1, src2) _mm512_cmpgt_epi32(src1, src2) +#define MS_BLEND512_F32(src1, src2, mask) _mm512_mask_blend_ps(mask, src1, src2) +#define MS_BLEND512_EPI32(src1, src2, mask) _mm512_mask_blend_epi32(mask, src1, src2) +#define MS_CAST512_F32_S32(src) _mm512_castsi512_ps(src) +#define MS_REDUCE_ADD512_F32(src) _mm512_reduce_add_ps(src) +#define MS_GET_MAX512_F32(src) _mm512_reduce_max_ps(src) +#define MS_GET_MIN512_F32(src) _mm512_reduce_min_ps(src) +#define MS_GET_SUM512_F32(src) _mm512_reduce_add_ps(src) +#define MS_AND512_MASK(src1, src2) _mm512_kand(src1, src2) + +#define MS512_SRLI_EPI32(src1, src2) _mm512_srli_epi32(src1, src2) +#define MS512_AND_EPI32(src1, src2) _mm512_and_si512(src1, src2) +#define MS512_CASTPS_EPI32(src) _mm512_castps_si512(src) +#define MS_OR512_EPI32(src1, src2) _mm512_or_epi32(src1, src2) +#define MS_AND512_EPI32(src1, src2) _mm512_and_epi32(src1, src2) +#define MS_AND512_F32(src1, src2) \ + _mm512_castsi512_ps(_mm512_and_si512(_mm512_castps_si512(src1), _mm512_castps_si512(src2))) + +static inline MS_FLOAT512_F32 SIMD_SIGN512_F32(MS_FLOAT512_F32 src) { + MS_FLOAT512_F32 abs_src = MS_ABS512_F32(src); + MS_FLOAT512_F32 sign = MS_DIV512_F32(abs_src, src); + return sign; +} + +#define SIMD_SIGNABS512_F32(src, abs_src) MS_DIV512_F32(abs_src, src) + +static inline MS_FLOAT512_F32 MS_OR512_F32(MS_FLOAT512_F32 src1, MS_FLOAT512_F32 src2) { + MS_FLOAT512_F32 result = MS_CAST512_F32_S32(MS_OR512_EPI32(MS512_CASTPS_EPI32(src1), MS512_CASTPS_EPI32(src2))); + return result; +} + +static inline MS_FLOAT512_F32 MS512_ANDNOT_F32(MS_FLOAT512_F32 src1, MS_FLOAT512_F32 src2) { + MS_FLOAT512_F32 result = MS_CAST512_F32_S32(MS_AND512_EPI32(~MS512_CASTPS_EPI32(src1), MS512_CASTPS_EPI32(src2))); + return result; +} + +static inline MS_FLOAT512_F32 MS_AND512_MASK_F32(MS_MASK512_TYPE mask, MS_FLOAT512_F32 value) { + /* mask = T ? value ; 0 */ + MS_FLOAT512_F32 zeros = _mm512_set1_ps(0.0f); + return _mm512_mask_blend_ps(mask, zeros, value); +} + +static inline MS_FLOAT32X16 MS_POW512_F32(MS_FLOAT32X16 src1, MS_FLOAT32X16 src2) { + MS_FLOAT32X16 dst; + MS512_F32_GETI(dst, 0) = powf(MS512_F32_GETI(src1, 0), MS512_F32_GETI(src2, 0)); + MS512_F32_GETI(dst, 1) = powf(MS512_F32_GETI(src1, 1), MS512_F32_GETI(src2, 1)); + MS512_F32_GETI(dst, 2) = powf(MS512_F32_GETI(src1, 2), MS512_F32_GETI(src2, 2)); + MS512_F32_GETI(dst, 3) = powf(MS512_F32_GETI(src1, 3), MS512_F32_GETI(src2, 3)); + MS512_F32_GETI(dst, 4) = powf(MS512_F32_GETI(src1, 4), MS512_F32_GETI(src2, 4)); + MS512_F32_GETI(dst, 5) = powf(MS512_F32_GETI(src1, 5), MS512_F32_GETI(src2, 5)); + MS512_F32_GETI(dst, 6) = powf(MS512_F32_GETI(src1, 6), MS512_F32_GETI(src2, 6)); + MS512_F32_GETI(dst, 7) = powf(MS512_F32_GETI(src1, 7), MS512_F32_GETI(src2, 7)); + MS512_F32_GETI(dst, 8) = powf(MS512_F32_GETI(src1, 8), MS512_F32_GETI(src2, 8)); + MS512_F32_GETI(dst, 9) = powf(MS512_F32_GETI(src1, 9), MS512_F32_GETI(src2, 9)); + MS512_F32_GETI(dst, 10) = powf(MS512_F32_GETI(src1, 10), MS512_F32_GETI(src2, 10)); + MS512_F32_GETI(dst, 11) = powf(MS512_F32_GETI(src1, 11), MS512_F32_GETI(src2, 11)); + MS512_F32_GETI(dst, 12) = powf(MS512_F32_GETI(src1, 12), MS512_F32_GETI(src2, 12)); + MS512_F32_GETI(dst, 13) = powf(MS512_F32_GETI(src1, 13), MS512_F32_GETI(src2, 13)); + MS512_F32_GETI(dst, 14) = powf(MS512_F32_GETI(src1, 14), MS512_F32_GETI(src2, 14)); + MS512_F32_GETI(dst, 15) = powf(MS512_F32_GETI(src1, 15), MS512_F32_GETI(src2, 15)); + return dst; +} + +static inline MS_FLOAT32X16 MS_COS512_F32(MS_FLOAT32X16 src) { + static const MS_FLOAT32X16 pi = {PI, PI, PI, PI, PI, PI, PI, PI, PI, PI, PI, PI, PI, PI, PI, PI}; + static const MS_FLOAT32X16 pi2_neg = {-2 * PI, -2 * PI, -2 * PI, -2 * PI, -2 * PI, -2 * PI, -2 * PI, -2 * PI, + -2 * PI, -2 * PI, -2 * PI, -2 * PI, -2 * PI, -2 * PI, -2 * PI, -2 * PI}; + static const MS_FLOAT32X16 div_pi2 = {1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI), + 1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI), + 1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI), + 1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI)}; + MS_FLOAT512_F32 src_abs = MS_ABS512_F32(src); + MS_FLOAT512_F32 src_cycle = + MS_ADD512_F32(MS_MUL512_F32(MS_FLOOR512_F32(MS_MUL512_F32(MS_ADD512_F32(src_abs, pi), div_pi2)), pi2_neg), src_abs); + static const MS_FLOAT512_F32 data0 = {1.0f / 90, 1.0f / 90, 1.0f / 90, 1.0f / 90, 1.0f / 90, 1.0f / 90, + 1.0f / 90, 1.0f / 90, 1.0f / 90, 1.0f / 90, 1.0f / 90, 1.0f / 90, + 1.0f / 90, 1.0f / 90, 1.0f / 90, 1.0f / 90}; + static const MS_FLOAT512_F32 data1 = {1.0f / 56, 1.0f / 56, 1.0f / 56, 1.0f / 56, 1.0f / 56, 1.0f / 56, + 1.0f / 56, 1.0f / 56, 1.0f / 56, 1.0f / 56, 1.0f / 56, 1.0f / 56, + 1.0f / 56, 1.0f / 56, 1.0f / 56, 1.0f / 56}; + static const MS_FLOAT512_F32 data2 = {1.0f / 30, 1.0f / 30, 1.0f / 30, 1.0f / 30, 1.0f / 30, 1.0f / 30, + 1.0f / 30, 1.0f / 30, 1.0f / 30, 1.0f / 30, 1.0f / 30, 1.0f / 30, + 1.0f / 30, 1.0f / 30, 1.0f / 30, 1.0f / 30}; + static const MS_FLOAT512_F32 data3 = {1.0f / 12, 1.0f / 12, 1.0f / 12, 1.0f / 12, 1.0f / 12, 1.0f / 12, + 1.0f / 12, 1.0f / 12, 1.0f / 12, 1.0f / 12, 1.0f / 12, 1.0f / 12, + 1.0f / 12, 1.0f / 12, 1.0f / 12, 1.0f / 12}; + static const MS_FLOAT512_F32 data4 = {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, + 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}; + static const MS_FLOAT32X16 neg = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f}; + static const MS_FLOAT32X16 pos = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + + MS_FLOAT32X16 square = MS_MUL512_F32(src_cycle, src_cycle); + + MS_FLOAT32X16 tmp = + MS_MUL512_F32(MS_MUL512_F32(MS_ADD512_F32(MS_MUL512_F32(MS_MUL512_F32(neg, square), data0), pos), square), data1); + MS_FLOAT32X16 tmp1 = MS_MUL512_F32(MS_MUL512_F32(MS_ADD512_F32(tmp, neg), square), data2); + MS_FLOAT512_F32 res = MS_ADD512_F32( + MS_MUL512_F32( + MS_MUL512_F32(MS_ADD512_F32(MS_MUL512_F32(MS_MUL512_F32(MS_ADD512_F32(tmp1, pos), square), data3), neg), square), + data4), + pos); + return res; +} + +static inline MS_FLOAT32X16 MS512_LOG_F32(MS_FLOAT32X16 src) { + const MS_INT512_EPI32 gFloatExpMask = MS_MOV512_EPI32(0xffULL << 23); + const MS_INT512_EPI32 gFloatExp0 = MS_MOV512_EPI32(127ULL << 23); + const MS_INT512_EPI32 gExpNormalizer = MS_MOV512_EPI32(127); + static const MS_FLOAT512_F32 data0 = {1.0f / 11, 1.0f / 11, 1.0f / 11, 1.0f / 11, 1.0f / 11, 1.0f / 11, + 1.0f / 11, 1.0f / 11, 1.0f / 11, 1.0f / 11, 1.0f / 11, 1.0f / 11, + 1.0f / 11, 1.0f / 11, 1.0f / 11, 1.0f / 11}; + static const MS_FLOAT512_F32 data1 = {1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9, + 1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9}; + static const MS_FLOAT512_F32 data2 = {1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7, + 1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7}; + static const MS_FLOAT512_F32 data3 = {0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, + 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f}; + static const MS_FLOAT512_F32 data4 = {1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, + 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3}; + static const MS_FLOAT512_F32 data5 = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + static const MS_FLOAT512_F32 data6 = {2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, + 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f}; + static const MS_FLOAT32X16 neg = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f}; + static const MS_FLOAT32X16 pos = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + static const MS_FLOAT32X16 ln2 = {LN2, LN2, LN2, LN2, LN2, LN2, LN2, LN2, LN2, LN2, LN2, LN2, LN2, LN2, LN2, LN2}; + + const MS_INT512_EPI32 exps32 = MS512_SRLI_EPI32(MS512_AND_EPI32(gFloatExpMask, MS512_CASTPS_EPI32(src)), 23); + const MS_INT512_EPI32 normExps = MS_SUB512_EPI32(exps32, gExpNormalizer); + const MS_FLOAT32X16 expsPD = MS_CVT512EPI32_PS(normExps); + const MS_FLOAT32X16 y = + MS_OR512_F32(MS_CAST512_F32_S32(gFloatExp0), MS512_ANDNOT_F32(MS_CAST512_F32_S32(gFloatExpMask), src)); + MS_FLOAT32X16 div = MS_DIV512_F32(MS_ADD512_F32(y, neg), MS_ADD512_F32(y, pos)); + MS_FLOAT32X16 square = MS_MUL512_F32(div, div); + + MS_FLOAT32X16 tmp = MS_ADD512_F32( + MS_MUL512_F32(MS_ADD512_F32(MS_MUL512_F32(square, MS_ADD512_F32(MS_MUL512_F32(square, data0), data1)), data2), + square), + data3); + MS_FLOAT32X16 tmp1 = MS_MUL512_F32(square, MS_ADD512_F32(MS_MUL512_F32(square, tmp), data4)); + MS_FLOAT32X16 res = + MS_ADD512_F32(MS_MUL512_F32(ln2, expsPD), MS_MUL512_F32(MS_MUL512_F32(div, MS_ADD512_F32(tmp1, data5)), data6)); + MS_MASK512_TYPE mask = MS_CMP512_F32(src, MS_MOV512_F32(0.0f), _CMP_EQ_OQ); + res = MS_BLEND512_F32(res, MS_MOV512_F32(-INFINITY), mask); + mask = MS_CMP512_F32(src, MS_MOV512_F32(INFINITY), _CMP_EQ_OQ); + res = MS_BLEND512_F32(res, MS_MOV512_F32(INFINITY), mask); + mask = MS_CMPLT512_F32(src, MS_MOV512_F32(0.0f)); + res = MS_BLEND512_F32(res, MS_MOV512_F32(NAN), mask); + mask = MS_CMP512_F32(src, MS_MOV512_F32(0.0f), _CMP_UNORD_Q); + res = MS_BLEND512_F32(res, MS_MOV512_F32(NAN), mask); + return res; +} + +#define MS_DIV512_EPI32(src1, src2) \ + _mm512_cvttps_epi32(MS_DIV512_F32(_mm512_cvtepi32_ps(src1), _mm512_cvtepi32_ps(src2))) + +#define MS512_INT16_TO_FLOAT16(src) _mm512_cvtepi16_ph(src) +#define MS512_FLOAT16_TO_INT16(src) _mm512_cvttph_epi16(src) + +#define MS512_INT32_TO_FLOAT16(src) _mm512_cvtepi32_ph(src) +#define MS512_FLOAT16_TO_INT32(src) _mm512_cvttph_epi32(src) + +#define MS512_INT32_TO_FLOAT32(src) _mm512_cvtepi32_ps(src) +#define MS512_FLOAT32_TO_INT32(src) _mm512_cvttps_epi32(src) +#define MS512_FLOAT16_TO_FLOAT32(src) _mm512_cvtph_ps(src) +#define MS512_FLOAT32_TO_FLOAT16(src1, src2) _mm512_cvtps_ph(src1, src2) + +#define MS512_INT64_TO_FLOAT32(src) _mm512_cvtepi64_ps(src) +#define MS512_FLOAT32_TO_INT64(src) _mm512_cvttps_epi64(src) + +#define MS512_INT64_TO_FLOAT16(src) _mm512_cvtepi64_ph(src) +#define MS512_FLOAT16_TO_INT64(src) _mm512_cvttph_epi64(src) + +#define MS512_INT32_TO_FLOAT64(src) _mm512_cvtepi32_pd(src) +#define MS512_FLOAT64_TO_INT32(src) _mm512_cvttpd_epi32(src) + +#define MS512_INT64_TO_FLOAT64(src) _mm512_cvtepi64_pd(src) +#define MS512_FLOAT64_TO_INT64(src) _mm512_cvttpd_epi64(src) + +#define MS512_INT16_TO_INT32(src) _mm512_cvtepi16_epi32(src) +#define MS512_INT16_TO_INT64(src) _mm512_cvtepi16_epi64(src) +#define MS512_INT32_TO_INT16(src) _mm512_cvtepi32_epi16(src) +#define MS512_INT32_TO_INT64(src) _mm512_cvtepi32_epi64(src) +#define MS512_INT64_TO_INT16(src) _mm512_cvtepi64_epi16(src) +#define MS512_INT64_TO_INT32(src) _mm512_cvtepi64_epi32(src) + +static inline MS_FLOAT32X16 simd_exp512_f32(MS_FLOAT32X16 input) { + static MS_FLOAT32X16 maxv = {88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f, + 88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f, + 88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f, + 88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f}; + static MS_FLOAT32X16 minv = {-87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f, + -87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f, + -87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f, + -87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f}; + static MS_FLOAT32X16 param[] = { + {0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, + 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f}, + {1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, + 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120}, + {1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, + 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24}, + {1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, + 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6}, + {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}, + {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, + {1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, + 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, + 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, + 1.44269504088896341f}, + {2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f}}; + + input = MS_MAX512_F32(minv, MS_MIN512_F32(input, maxv)); + MS_INT32X16 integer = MS_CVT512PS_EPI32(MS_FLOOR512_F32(MS_FMADD512_F32(input, param[6], param[4]))); + MS_FLOAT32X16 decimal = MS_SUB512_F32(input, MS_MUL512_F32(MS_CVT512EPI32_PS(integer), param[0])); + MS_INT32X16 int_exp = MS_SLLI512_EPI32(MS_ADD512_EPI32(integer, MS_MOV512_EPI32(126)), 23); + MS_FLOAT32X16 tmp = MS_FMADD512_F32(decimal, MS_FMADD512_F32(decimal, param[1], param[2]), param[3]); + tmp = MS_FMADD512_F32(decimal, MS_FMADD512_F32(decimal, tmp, param[4]), param[5]); + MS_FLOAT32X16 decimal_exp = MS_FMADD512_F32(decimal, tmp, param[5]); + return MS_MUL512_F32(param[7], MS_MUL512_F32(decimal_exp, MS_CAST512_F32_S32(int_exp))); +} + +static inline MS_FLOAT32X16 simd_hexp512_f32(MS_FLOAT32X16 src) { + MS_FLOAT32X16 dst; + MS512_F32_GETI(dst, 0) = exp(MS512_F32_GETI(src, 0)); + MS512_F32_GETI(dst, 1) = exp(MS512_F32_GETI(src, 1)); + MS512_F32_GETI(dst, 2) = exp(MS512_F32_GETI(src, 2)); + MS512_F32_GETI(dst, 3) = exp(MS512_F32_GETI(src, 3)); + MS512_F32_GETI(dst, 4) = exp(MS512_F32_GETI(src, 4)); + MS512_F32_GETI(dst, 5) = exp(MS512_F32_GETI(src, 5)); + MS512_F32_GETI(dst, 6) = exp(MS512_F32_GETI(src, 6)); + MS512_F32_GETI(dst, 7) = exp(MS512_F32_GETI(src, 7)); + MS512_F32_GETI(dst, 8) = exp(MS512_F32_GETI(src, 8)); + MS512_F32_GETI(dst, 9) = exp(MS512_F32_GETI(src, 9)); + MS512_F32_GETI(dst, 10) = exp(MS512_F32_GETI(src, 10)); + MS512_F32_GETI(dst, 11) = exp(MS512_F32_GETI(src, 11)); + MS512_F32_GETI(dst, 12) = exp(MS512_F32_GETI(src, 12)); + MS512_F32_GETI(dst, 13) = exp(MS512_F32_GETI(src, 13)); + MS512_F32_GETI(dst, 14) = exp(MS512_F32_GETI(src, 14)); + MS512_F32_GETI(dst, 15) = exp(MS512_F32_GETI(src, 15)); + return dst; +} + +static inline void simd_exp512(MS_FLOAT32X16 input, float *dst) { + MS_FLOAT32X16 res = simd_exp512_f32(input); + MS_ST512_F32(dst, res); +} + +static inline MS_FLOAT32X16 MS_TANHX16_F32(MS_FLOAT32X16 src) { + static const MS_FLOAT32X16 data0 = {378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, + 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f}; + static const MS_FLOAT32X16 data1 = {17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, + 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f}; + static const MS_FLOAT32X16 data2 = {135135.0f, 135135.0f, 135135.0f, 135135.0f, 135135.0f, 135135.0f, + 135135.0f, 135135.0f, 135135.0f, 135135.0f, 135135.0f, 135135.0f, + 135135.0f, 135135.0f, 135135.0f, 135135.0f}; + static const MS_FLOAT32X16 data3 = {28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, + 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f}; + static const MS_FLOAT32X16 data4 = {3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, + 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f}; + static const MS_FLOAT32X16 data5 = {62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, + 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f}; + static const MS_FLOAT32X16 neg = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f}; + static const MS_FLOAT32X16 pos = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + MS_FLOAT32X16 square = MS_MUL512_F32(src, src); + MS_FLOAT32X16 a = + MS_MUL512_F32(MS_FMADD512_F32(MS_FMADD512_F32(MS_ADD512_F32(square, data0), square, data1), square, data2), src); + MS_FLOAT32X16 b = + MS_FMADD512_F32(MS_FMADD512_F32(MS_FMADD512_F32(data3, square, data4), square, data5), square, data2); + MS_FLOAT32X16 res = MS_DIV512_F32(a, b); + MS_FLOAT32X16 up_limit = MS_MOV512_F32(5.0f); + MS_FLOAT32X16 down_limit = MS_MOV512_F32(-5.0f); + MS_MASK512_TYPE up_mask = MS_CMPGT512_F32(src, up_limit); + MS_MASK512_TYPE down_mask = MS_CMPLT512_F32(src, down_limit); + res = MS_BLEND512_F32(res, pos, up_mask); + res = MS_BLEND512_F32(res, neg, down_mask); + return res; +} + +#define MS_TANH512_F32 MS_TANHX16_F32 + +static inline MS_FLOAT32X16 MS512_ERF_F32(MS_FLOAT32X16 src) { + MS_FLOAT32X16 dst; + MS_F32X16_GETI(dst, 0) = erff(MS_F32X16_GETI(src, 0)); + MS_F32X16_GETI(dst, 1) = erff(MS_F32X16_GETI(src, 1)); + MS_F32X16_GETI(dst, 2) = erff(MS_F32X16_GETI(src, 2)); + MS_F32X16_GETI(dst, 3) = erff(MS_F32X16_GETI(src, 3)); + MS_F32X16_GETI(dst, 4) = erff(MS_F32X16_GETI(src, 4)); + MS_F32X16_GETI(dst, 5) = erff(MS_F32X16_GETI(src, 5)); + MS_F32X16_GETI(dst, 6) = erff(MS_F32X16_GETI(src, 6)); + MS_F32X16_GETI(dst, 7) = erff(MS_F32X16_GETI(src, 7)); + MS_F32X16_GETI(dst, 8) = erff(MS_F32X16_GETI(src, 8)); + MS_F32X16_GETI(dst, 9) = erff(MS_F32X16_GETI(src, 9)); + MS_F32X16_GETI(dst, 10) = erff(MS_F32X16_GETI(src, 10)); + MS_F32X16_GETI(dst, 11) = erff(MS_F32X16_GETI(src, 11)); + MS_F32X16_GETI(dst, 12) = erff(MS_F32X16_GETI(src, 12)); + MS_F32X16_GETI(dst, 13) = erff(MS_F32X16_GETI(src, 13)); + MS_F32X16_GETI(dst, 14) = erff(MS_F32X16_GETI(src, 14)); + MS_F32X16_GETI(dst, 15) = erff(MS_F32X16_GETI(src, 15)); + return dst; +} + +#define MS_LOAD512X8_F32(src, input_ptr, num) \ + MS_FLOAT32X16 src##1 = MS_LD512_F32(input_ptr); \ + MS_FLOAT32X16 src##2 = MS_LD512_F32(input_ptr + 1 * num); \ + MS_FLOAT32X16 src##3 = MS_LD512_F32(input_ptr + 2 * num); \ + MS_FLOAT32X16 src##4 = MS_LD512_F32(input_ptr + 3 * num); \ + MS_FLOAT32X16 src##5 = MS_LD512_F32(input_ptr + 4 * num); \ + MS_FLOAT32X16 src##6 = MS_LD512_F32(input_ptr + 5 * num); \ + MS_FLOAT32X16 src##7 = MS_LD512_F32(input_ptr + 6 * num); \ + MS_FLOAT32X16 src##8 = MS_LD512_F32(input_ptr + 7 * num); + +#define MS_LOAD512X4_F32(src, input_ptr, num) \ + MS_FLOAT32X16 src##1 = MS_LD512_F32(input_ptr); \ + MS_FLOAT32X16 src##2 = MS_LD512_F32(input_ptr + 1 * num); \ + MS_FLOAT32X16 src##3 = MS_LD512_F32(input_ptr + 2 * num); \ + MS_FLOAT32X16 src##4 = MS_LD512_F32(input_ptr + 3 * num); + +#define MS_FMADD512X8_F32(src, weight, dst) \ + dst##1 = MS_MLA512_F32(dst##1, src##1, weight); \ + dst##2 = MS_MLA512_F32(dst##2, src##2, weight); \ + dst##3 = MS_MLA512_F32(dst##3, src##3, weight); \ + dst##4 = MS_MLA512_F32(dst##4, src##4, weight); \ + dst##5 = MS_MLA512_F32(dst##5, src##5, weight); \ + dst##6 = MS_MLA512_F32(dst##6, src##6, weight); \ + dst##7 = MS_MLA512_F32(dst##7, src##7, weight); \ + dst##8 = MS_MLA512_F32(dst##8, src##8, weight); + +#define MS_FMADD512X4_F32(src, weight, dst) \ + dst##1 = MS_MLA512_F32(src##1, weight, dst##1); \ + dst##2 = MS_MLA512_F32(src##2, weight, dst##2); \ + dst##3 = MS_MLA512_F32(src##3, weight, dst##3); \ + dst##4 = MS_MLA512_F32(src##4, weight, dst##4); + +#define MS_SET_ZERO512X8_F32(dst) \ + MS_FLOAT32X16 dst##1 = _mm512_setzero_ps(); \ + MS_FLOAT32X16 dst##2 = _mm512_setzero_ps(); \ + MS_FLOAT32X16 dst##3 = _mm512_setzero_ps(); \ + MS_FLOAT32X16 dst##4 = _mm512_setzero_ps(); \ + MS_FLOAT32X16 dst##5 = _mm512_setzero_ps(); \ + MS_FLOAT32X16 dst##6 = _mm512_setzero_ps(); \ + MS_FLOAT32X16 dst##7 = _mm512_setzero_ps(); \ + MS_FLOAT32X16 dst##8 = _mm512_setzero_ps(); + +#define MS_SET_ZERO512X4_F32(dst) \ + MS_FLOAT32X16 dst##1 = _mm512_setzero_ps(); \ + MS_FLOAT32X16 dst##2 = _mm512_setzero_ps(); \ + MS_FLOAT32X16 dst##3 = _mm512_setzero_ps(); \ + MS_FLOAT32X16 dst##4 = _mm512_setzero_ps(); + +#pragma GCC pop_options + +#endif // NNACL_AVX512_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/ms_simd_avx_instructions.h b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/ms_simd_avx_instructions.h new file mode 100644 index 0000000000000000000000000000000000000000..2b3647d84423417b94a61351471a8f50639567a0 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/ms_simd_avx_instructions.h @@ -0,0 +1,440 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_AVX_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ +#define NNACL_AVX_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ +#include + +#ifdef _MSC_VER +#include +#define MS_F32X8_GETI(src, i) src.m256_f32[i] +#define MS256_F32_GETI(src, i) src.m256_f32[i] +#else +#include +#define MS_F32X8_GETI(src, i) src[i] +#define MS256_F32_GETI(src, i) src[i] +#endif + +#define PI 3.1415926f +#define LN2 0.693147f + +#define MS_FLOAT32X8 __m256 +#define MS_FLOAT256_F32 __m256 +#define MS_INT32X8 __m256i +#define MS_INT256_EPI32 __m256i +#define MS_MASK256_TYPE MS_FLOAT32X8 +#define MS_LD256_F32 _mm256_loadu_ps +#define MS_LD256_EPI32(src) _mm256_loadu_si256((__m256i const *)(src)) +#define MS_ADD256_F32 _mm256_add_ps +#define MS_ADD256_EPI32 _mm256_add_epi32 +#define MS_MOV256_F32 _mm256_set1_ps +#define MS_MOV256_EPI32 _mm256_set1_epi32 +#define MS_MOV256_VAL0_F32 _mm256_setzero_ps() +#define MS_MLA256_F32(src1, src2, src3) _mm256_fmadd_ps(src2, src3, src1) +#define MS_ST256_F32 _mm256_storeu_ps +#define MS_ST256_EPI32(src1, src2) _mm256_storeu_si256((__m256i *)(src1), src2) +#define MS_SUB256_F32 _mm256_sub_ps +#define MS_SUB256_EPI32 _mm256_sub_epi32 +#define MS_MAX256_F32 _mm256_max_ps +#define MS_MAX256_EPI32 _mm256_max_epi32 +#define MS_MIN256_F32 _mm256_min_ps +#define MS_MIN256_EPI32 _mm256_min_epi32 +#define MS_SQRT256_F32 _mm256_sqrt_ps +#define MS_RSQRT256_F32 _mm256_rsqrt_ps +#define MS_SIN256_F32 _mm256_sin_ps +#define MS_ERF256_F32 _mm256_erf_ps +#define MS_ROUND256_F32(src) _mm256_round_ps(src, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC) +#define MS_FLOOR256_F32 _mm256_floor_ps +#define MS_CEIL256_F32 _mm256_ceil_ps +#define MS_MUL256_F32(src1, src2) _mm256_mul_ps(src1, src2) +#define MS_MUL256_EPI32(src1, src2) _mm256_mullo_epi32(src1, src2) +#define MS_FMADD256_F32(src1, src2, src3) _mm256_fmadd_ps(src1, src2, src3) +#define MS_FMSUB256_F32(src1, src2, src3) _mm256_fmsub_ps(src1, src2, src3) +#define MS_FSMUL256_F32(src1, src2, src3) _mm256_fnmadd_ps(src3, src2, src1) // src1 - src2 * src3 +#define MS_DIV256_F32(src1, src2) _mm256_div_ps(src1, src2) +#define MS_MUL256_N_F32(src1, src2) _mm256_mul_ps(src1, _mm256_set1_ps(src2)) +#define MS_MUL256_N_EPI32(src1, src2) _mm256_mullo_epi32(src1, _mm256_set1_epi32(src2)) +#define MS_DIV256_N_F32(src1, src2) _mm256_div_ps(src1, _mm256_set1_ps(src2)) +#define MS_SLLI256_EPI32(src1, src2) _mm256_slli_epi32(src1, src2) +#define MS_CVT256PS_EPI32(src) _mm256_cvttps_epi32(src) +#define MS_CVT256EPI32_PS(src) _mm256_cvtepi32_ps(src) // truncate float to int +#define MS_CMP256_F32(src1, src2, src3) _mm256_cmp_ps(src1, src2, src3) +#define MS_CMPGT256_F32(src1, src2) _mm256_cmp_ps(src1, src2, 30) +#define MS_CMPLE256_F32(src1, src2) _mm256_cmp_ps(src1, src2, 18) +#define MS_CMPLT256_F32(src1, src2) _mm256_cmp_ps(src1, src2, 17) +#define MS_CMPGT256_EPI32(src1, src2) _mm256_cmpgt_epi32(src1, src2) +#define MS_BLEND256_F32(src1, src2, src3) _mm256_blendv_ps(src1, src2, src3) +#define MS_BLEND256_EPI32(src1, src2, src3) _mm256_blendv_epi8(src1, src2, src3) +#define MS_CAST256_F32_S32(src) _mm256_castsi256_ps(src) +#define MS_AND256_MASK(src1, src2) _mm256_and_ps(src1, src2) +#define MS_OR256_F32(src1, src2) _mm256_or_ps(src1, src2) +#define MS_AND256_MASK_F32(src1, src2) _mm256_and_ps(src1, src2) +#define MS_AND256_F32(src1, src2) _mm256_and_ps(src1, src2) + +#define MS256_ANDNOT_F32(src1, src2) _mm256_andnot_ps(src1, src2) +#define MS256_SRLI_EPI32(src1, src2) _mm256_srli_epi32(src1, src2) +#define MS256_AND_EPI32(src1, src2) _mm256_and_si256(src1, src2) +#define MS256_CASTPS_EPI32(src) _mm256_castps_si256(src) + +static inline MS_FLOAT32X8 MS_POW256_F32(MS_FLOAT32X8 src1, MS_FLOAT32X8 src2) { + MS_FLOAT32X8 dst; + MS_F32X8_GETI(dst, 0) = powf(MS_F32X8_GETI(src1, 0), MS_F32X8_GETI(src2, 0)); + MS_F32X8_GETI(dst, 1) = powf(MS_F32X8_GETI(src1, 1), MS_F32X8_GETI(src2, 1)); + MS_F32X8_GETI(dst, 2) = powf(MS_F32X8_GETI(src1, 2), MS_F32X8_GETI(src2, 2)); + MS_F32X8_GETI(dst, 3) = powf(MS_F32X8_GETI(src1, 3), MS_F32X8_GETI(src2, 3)); + MS_F32X8_GETI(dst, 4) = powf(MS_F32X8_GETI(src1, 4), MS_F32X8_GETI(src2, 4)); + MS_F32X8_GETI(dst, 5) = powf(MS_F32X8_GETI(src1, 5), MS_F32X8_GETI(src2, 5)); + MS_F32X8_GETI(dst, 6) = powf(MS_F32X8_GETI(src1, 6), MS_F32X8_GETI(src2, 6)); + MS_F32X8_GETI(dst, 7) = powf(MS_F32X8_GETI(src1, 7), MS_F32X8_GETI(src2, 7)); + return dst; +} + +static inline MS_FLOAT32X8 MS_ABS256_F32(MS_FLOAT32X8 src) { + MS_FLOAT32X8 dst; + MS_F32X8_GETI(dst, 0) = fabsf(MS_F32X8_GETI(src, 0)); + MS_F32X8_GETI(dst, 1) = fabsf(MS_F32X8_GETI(src, 1)); + MS_F32X8_GETI(dst, 2) = fabsf(MS_F32X8_GETI(src, 2)); + MS_F32X8_GETI(dst, 3) = fabsf(MS_F32X8_GETI(src, 3)); + MS_F32X8_GETI(dst, 4) = fabsf(MS_F32X8_GETI(src, 4)); + MS_F32X8_GETI(dst, 5) = fabsf(MS_F32X8_GETI(src, 5)); + MS_F32X8_GETI(dst, 6) = fabsf(MS_F32X8_GETI(src, 6)); + MS_F32X8_GETI(dst, 7) = fabsf(MS_F32X8_GETI(src, 7)); + return dst; +} + +static inline MS_FLOAT256_F32 SIMD_SIGN256_F32(MS_FLOAT256_F32 src) { + MS_FLOAT256_F32 abs_src = MS_ABS256_F32(src); + MS_FLOAT256_F32 sign = MS_DIV256_F32(abs_src, src); + return sign; +} + +#define SIMD_SIGNABS256_F32(src, abs_src) MS_DIV256_F32(abs_src, src) + +static inline MS_FLOAT32X8 MS_COS256_F32(MS_FLOAT32X8 src) { + static const MS_FLOAT32X8 pi = {PI, PI, PI, PI, PI, PI, PI, PI}; + static const MS_FLOAT32X8 pi2_neg = { + -2 * PI, -2 * PI, -2 * PI, -2 * PI, -2 * PI, -2 * PI, -2 * PI, -2 * PI, + }; + static const MS_FLOAT32X8 div_pi2 = {1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI), + 1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI)}; + MS_FLOAT256_F32 src_abs = MS_ABS256_F32(src); + MS_FLOAT256_F32 src_cycle = + MS_ADD256_F32(MS_MUL256_F32(MS_FLOOR256_F32(MS_MUL256_F32(MS_ADD256_F32(src_abs, pi), div_pi2)), pi2_neg), src_abs); + + static const MS_FLOAT256_F32 data0 = {1.0f / 90, 1.0f / 90, 1.0f / 90, 1.0f / 90, + 1.0f / 90, 1.0f / 90, 1.0f / 90, 1.0f / 90}; + static const MS_FLOAT256_F32 data1 = {1.0f / 56, 1.0f / 56, 1.0f / 56, 1.0f / 56, + 1.0f / 56, 1.0f / 56, 1.0f / 56, 1.0f / 56}; + static const MS_FLOAT256_F32 data2 = {1.0f / 30, 1.0f / 30, 1.0f / 30, 1.0f / 30, + 1.0f / 30, 1.0f / 30, 1.0f / 30, 1.0f / 30}; + static const MS_FLOAT256_F32 data3 = {1.0f / 12, 1.0f / 12, 1.0f / 12, 1.0f / 12, + 1.0f / 12, 1.0f / 12, 1.0f / 12, 1.0f / 12}; + static const MS_FLOAT256_F32 data4 = {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}; + static const MS_FLOAT32X8 neg = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f}; + static const MS_FLOAT32X8 pos = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + + MS_FLOAT32X8 square = MS_MUL256_F32(src_cycle, src_cycle); + + MS_FLOAT32X8 tmp = + MS_MUL256_F32(MS_MUL256_F32(MS_ADD256_F32(MS_MUL256_F32(MS_MUL256_F32(neg, square), data0), pos), square), data1); + MS_FLOAT32X8 tmp1 = MS_MUL256_F32(MS_MUL256_F32(MS_ADD256_F32(tmp, neg), square), data2); + MS_FLOAT256_F32 res = MS_ADD256_F32( + MS_MUL256_F32( + MS_MUL256_F32(MS_ADD256_F32(MS_MUL256_F32(MS_MUL256_F32(MS_ADD256_F32(tmp1, pos), square), data3), neg), square), + data4), + pos); + return res; +} + +static inline MS_FLOAT32X8 MS256_LOG_F32(MS_FLOAT32X8 src) { + const MS_INT256_EPI32 gFloatExpMask = MS_MOV256_EPI32(0xffULL << 23); + const MS_INT256_EPI32 gFloatExp0 = MS_MOV256_EPI32(127ULL << 23); + const MS_INT256_EPI32 gExpNormalizer = MS_MOV256_EPI32(127); + static const MS_FLOAT256_F32 data0 = {1.0f / 11, 1.0f / 11, 1.0f / 11, 1.0f / 11, + 1.0f / 11, 1.0f / 11, 1.0f / 11, 1.0f / 11}; + static const MS_FLOAT256_F32 data1 = {1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9}; + static const MS_FLOAT256_F32 data2 = {1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7}; + static const MS_FLOAT256_F32 data3 = {0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f}; + static const MS_FLOAT256_F32 data4 = {1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3}; + static const MS_FLOAT256_F32 data5 = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + static const MS_FLOAT256_F32 data6 = {2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f}; + static const MS_FLOAT32X8 neg = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f}; + static const MS_FLOAT32X8 pos = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + static const MS_FLOAT32X8 ln2 = {LN2, LN2, LN2, LN2, LN2, LN2, LN2, LN2}; + + const MS_INT256_EPI32 exps32 = MS256_SRLI_EPI32(MS256_AND_EPI32(gFloatExpMask, MS256_CASTPS_EPI32(src)), 23); + const MS_INT256_EPI32 normExps = MS_SUB256_EPI32(exps32, gExpNormalizer); + const MS_FLOAT32X8 expsPD = MS_CVT256EPI32_PS(normExps); + const MS_FLOAT32X8 y = + MS_OR256_F32(MS_CAST256_F32_S32(gFloatExp0), MS256_ANDNOT_F32(MS_CAST256_F32_S32(gFloatExpMask), src)); + MS_FLOAT32X8 div = MS_DIV256_F32(MS_ADD256_F32(y, neg), MS_ADD256_F32(y, pos)); + MS_FLOAT32X8 square = MS_MUL256_F32(div, div); + + MS_FLOAT32X8 tmp = MS_ADD256_F32( + MS_MUL256_F32(MS_ADD256_F32(MS_MUL256_F32(square, MS_ADD256_F32(MS_MUL256_F32(square, data0), data1)), data2), + square), + data3); + MS_FLOAT32X8 tmp1 = MS_MUL256_F32(square, MS_ADD256_F32(MS_MUL256_F32(square, tmp), data4)); + MS_FLOAT32X8 res = + MS_ADD256_F32(MS_MUL256_F32(ln2, expsPD), MS_MUL256_F32(MS_MUL256_F32(div, MS_ADD256_F32(tmp1, data5)), data6)); + MS_FLOAT32X8 mask = MS_CMP256_F32(src, MS_MOV256_F32(0.0f), _CMP_EQ_OQ); + res = MS_BLEND256_F32(res, MS_MOV256_F32(-INFINITY), mask); + mask = MS_CMP256_F32(src, MS_MOV256_F32(INFINITY), _CMP_EQ_OQ); + res = MS_BLEND256_F32(res, MS_MOV256_F32(INFINITY), mask); + mask = MS_OR256_F32(MS_CMPLT256_F32(src, MS_MOV256_F32(0.0f)), MS_CMP256_F32(src, MS_MOV256_F32(0.0f), _CMP_UNORD_Q)); + res = MS_BLEND256_F32(res, MS_MOV256_F32(NAN), mask); + return res; +} + +static inline float MS_GET_MAX256_F32(__m256 src) { + float result = MS_F32X8_GETI(src, 0); + for (int i = 1; i < 8; i++) { // avx block num : 8 + result = fmaxf(result, MS_F32X8_GETI(src, i)); + } + return result; +} + +static inline float MS_GET_SUM256_F32(__m256 src) { + float result = MS_F32X8_GETI(src, 0); + for (int i = 1; i < 8; i++) { // avx block num : 8 + result = result + MS_F32X8_GETI(src, i); + } + return result; +} + +#define MS_DIV256_EPI32(src1, src2) \ + _mm256_cvttps_epi32(MS_DIV256_F32(_mm256_cvtepi32_ps(src1), _mm256_cvtepi32_ps(src2))) + +#define MS256_INT16_TO_FLOAT16(src) _mm256_cvtepi16_ph(src) +#define MS256_FLOAT16_TO_INT16(src) _mm256_cvttph_epi16(src) + +#define MS256_INT32_TO_FLOAT16(src) _mm256_cvtepi32_ph(src) +#define MS256_FLOAT16_TO_INT32(src) _mm256_cvttph_epi32(src) + +#define MS256_INT32_TO_FLOAT32(src) _mm256_cvtepi32_ps(src) +#define MS256_FLOAT32_TO_INT32(src) _mm256_cvttps_epi32(src) + +#define MS256_INT64_TO_FLOAT32(src) _mm256_cvtepi64_ps(src) +#define MS256_FLOAT32_TO_INT64(src) _mm256_cvttps_epi64(src) + +#define MS256_INT64_TO_FLOAT16(src) _mm256_cvtepi64_ph(src) +#define MS256_FLOAT16_TO_INT64(src) _mm256_cvttph_epi64(src) + +#define MS256_INT32_TO_FLOAT64(src) _mm256_cvtepi32_pd(src) +#define MS256_FLOAT64_TO_INT32(src) _mm256_cvttpd_epi32(src) + +#define MS256_INT64_TO_FLOAT64(src) _mm256_cvtepi64_pd(src) +#define MS256_FLOAT64_TO_INT64(src) _mm256_cvttpd_epi64(src) + +#define MS256_INT16_TO_INT32(src) _mm256_cvtepi16_epi32(src) +#define MS256_INT16_TO_INT64(src) _mm256_cvtepi16_epi64(src) +#define MS256_INT32_TO_INT16(src) _mm256_cvtepi32_epi16(src) +#define MS256_INT32_TO_INT64(src) _mm256_cvtepi32_epi64(src) +#define MS256_INT64_TO_INT16(src) _mm256_cvtepi64_epi16(src) +#define MS256_INT64_TO_INT32(src) _mm256_cvtepi64_epi32(src) + +static inline MS_FLOAT32X8 MS_SQRTFX8_F32(MS_FLOAT32X8 src) { + MS_FLOAT32X8 dst; + MS_F32X8_GETI(dst, 0) = sqrtf(MS_F32X8_GETI(src, 0)); + MS_F32X8_GETI(dst, 1) = sqrtf(MS_F32X8_GETI(src, 1)); + MS_F32X8_GETI(dst, 2) = sqrtf(MS_F32X8_GETI(src, 2)); + MS_F32X8_GETI(dst, 3) = sqrtf(MS_F32X8_GETI(src, 3)); + MS_F32X8_GETI(dst, 4) = sqrtf(MS_F32X8_GETI(src, 4)); + MS_F32X8_GETI(dst, 5) = sqrtf(MS_F32X8_GETI(src, 5)); + MS_F32X8_GETI(dst, 6) = sqrtf(MS_F32X8_GETI(src, 6)); + MS_F32X8_GETI(dst, 7) = sqrtf(MS_F32X8_GETI(src, 7)); + return dst; +} + +#define MS_LOAD256X4_F32(src, input_ptr, num) \ + MS_FLOAT32X8 src##1 = MS_LD256_F32(input_ptr + 0 * num); \ + MS_FLOAT32X8 src##2 = MS_LD256_F32(input_ptr + 1 * num); \ + MS_FLOAT32X8 src##3 = MS_LD256_F32(input_ptr + 2 * num); \ + MS_FLOAT32X8 src##4 = MS_LD256_F32(input_ptr + 3 * num); + +#define MS_LOAD256X8_F32(src, input_ptr, num) \ + MS_FLOAT32X8 src##1 = MS_LD256_F32(input_ptr + 0 * num); \ + MS_FLOAT32X8 src##2 = MS_LD256_F32(input_ptr + 1 * num); \ + MS_FLOAT32X8 src##3 = MS_LD256_F32(input_ptr + 2 * num); \ + MS_FLOAT32X8 src##4 = MS_LD256_F32(input_ptr + 3 * num); \ + MS_FLOAT32X8 src##5 = MS_LD256_F32(input_ptr + 4 * num); \ + MS_FLOAT32X8 src##6 = MS_LD256_F32(input_ptr + 5 * num); \ + MS_FLOAT32X8 src##7 = MS_LD256_F32(input_ptr + 6 * num); \ + MS_FLOAT32X8 src##8 = MS_LD256_F32(input_ptr + 7 * num); + +#define MS_LOAD256X16_F32(src, input_ptr, num) \ + MS_FLOAT32X8 src##1 = MS_LD256_F32(input_ptr + 0 * num); \ + MS_FLOAT32X8 src##2 = MS_LD256_F32(input_ptr + 1 * num); \ + MS_FLOAT32X8 src##3 = MS_LD256_F32(input_ptr + 2 * num); \ + MS_FLOAT32X8 src##4 = MS_LD256_F32(input_ptr + 3 * num); \ + MS_FLOAT32X8 src##5 = MS_LD256_F32(input_ptr + 4 * num); \ + MS_FLOAT32X8 src##6 = MS_LD256_F32(input_ptr + 5 * num); \ + MS_FLOAT32X8 src##7 = MS_LD256_F32(input_ptr + 6 * num); \ + MS_FLOAT32X8 src##8 = MS_LD256_F32(input_ptr + 7 * num); \ + MS_FLOAT32X8 src##9 = MS_LD256_F32(input_ptr + 8 * num); \ + MS_FLOAT32X8 src##10 = MS_LD256_F32(input_ptr + 9 * num); \ + MS_FLOAT32X8 src##11 = MS_LD256_F32(input_ptr + 10 * num); \ + MS_FLOAT32X8 src##12 = MS_LD256_F32(input_ptr + 11 * num); \ + MS_FLOAT32X8 src##13 = MS_LD256_F32(input_ptr + 12 * num); \ + MS_FLOAT32X8 src##14 = MS_LD256_F32(input_ptr + 13 * num); \ + MS_FLOAT32X8 src##15 = MS_LD256_F32(input_ptr + 14 * num); \ + MS_FLOAT32X8 src##16 = MS_LD256_F32(input_ptr + 15 * num); + +#define STORE256X8_F32(output_ptr, num, dst) \ + MS_ST256_F32(output_ptr + 0 * num, dst##1); \ + MS_ST256_F32(output_ptr + 1 * num, dst##2); \ + MS_ST256_F32(output_ptr + 2 * num, dst##3); \ + MS_ST256_F32(output_ptr + 3 * num, dst##4); \ + MS_ST256_F32(output_ptr + 4 * num, dst##5); \ + MS_ST256_F32(output_ptr + 5 * num, dst##6); \ + MS_ST256_F32(output_ptr + 6 * num, dst##7); \ + MS_ST256_F32(output_ptr + 7 * num, dst##8); + +#define STORE256X16_F32(output_ptr, num, dst) \ + MS_ST256_F32(output_ptr + 0 * num, dst##1); \ + MS_ST256_F32(output_ptr + 1 * num, dst##2); \ + MS_ST256_F32(output_ptr + 2 * num, dst##3); \ + MS_ST256_F32(output_ptr + 3 * num, dst##4); \ + MS_ST256_F32(output_ptr + 4 * num, dst##5); \ + MS_ST256_F32(output_ptr + 5 * num, dst##6); \ + MS_ST256_F32(output_ptr + 6 * num, dst##7); \ + MS_ST256_F32(output_ptr + 7 * num, dst##8); \ + MS_ST256_F32(output_ptr + 8 * num, dst##9); \ + MS_ST256_F32(output_ptr + 9 * num, dst##10); \ + MS_ST256_F32(output_ptr + 10 * num, dst##11); \ + MS_ST256_F32(output_ptr + 11 * num, dst##12); \ + MS_ST256_F32(output_ptr + 12 * num, dst##13); \ + MS_ST256_F32(output_ptr + 13 * num, dst##14); \ + MS_ST256_F32(output_ptr + 14 * num, dst##15); \ + MS_ST256_F32(output_ptr + 15 * num, dst##16); + +static inline MS_FLOAT32X8 simd_exp256_f32(MS_FLOAT32X8 input) { + static MS_FLOAT32X8 maxv = {88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f, + 88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f}; + static MS_FLOAT32X8 minv = {-87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f, + -87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f}; + static MS_FLOAT32X8 param[] = { + {0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f}, + {1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120}, + {1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24}, + {1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6}, + {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}, + {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, + {1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, + 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f}, + {2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f}}; + input = MS_MAX256_F32(minv, MS_MIN256_F32(input, maxv)); + MS_INT32X8 integer = MS_CVT256PS_EPI32(MS_FLOOR256_F32(MS_FMADD256_F32(input, param[6], param[4]))); + MS_FLOAT32X8 decimal = MS_SUB256_F32(input, MS_MUL256_F32(MS_CVT256EPI32_PS(integer), param[0])); + MS_INT32X8 int_exp = MS_SLLI256_EPI32(MS_ADD256_EPI32(integer, MS_MOV256_EPI32(126)), 23); + MS_FLOAT32X8 tmp = MS_FMADD256_F32(decimal, MS_FMADD256_F32(decimal, param[1], param[2]), param[3]); + tmp = MS_FMADD256_F32(decimal, MS_FMADD256_F32(decimal, tmp, param[4]), param[5]); + MS_FLOAT32X8 decimal_exp = MS_FMADD256_F32(decimal, tmp, param[5]); + return MS_MUL256_F32(param[7], MS_MUL256_F32(decimal_exp, MS_CAST256_F32_S32(int_exp))); +} + +static inline MS_FLOAT32X8 simd_hexp256_f32(MS_FLOAT32X8 src) { + MS_FLOAT32X8 dst; + MS_F32X8_GETI(dst, 0) = exp(MS_F32X8_GETI(src, 0)); + MS_F32X8_GETI(dst, 1) = exp(MS_F32X8_GETI(src, 1)); + MS_F32X8_GETI(dst, 2) = exp(MS_F32X8_GETI(src, 2)); + MS_F32X8_GETI(dst, 3) = exp(MS_F32X8_GETI(src, 3)); + MS_F32X8_GETI(dst, 4) = exp(MS_F32X8_GETI(src, 4)); + MS_F32X8_GETI(dst, 5) = exp(MS_F32X8_GETI(src, 5)); + MS_F32X8_GETI(dst, 6) = exp(MS_F32X8_GETI(src, 6)); + MS_F32X8_GETI(dst, 7) = exp(MS_F32X8_GETI(src, 7)); + return dst; +} + +static inline void simd_exp256(MS_FLOAT32X8 input, float *dst) { + MS_FLOAT32X8 res = simd_exp256_f32(input); + MS_ST256_F32(dst, res); +} + +static inline MS_FLOAT32X8 MS_TANHX8_F32(MS_FLOAT32X8 src) { + static const MS_FLOAT32X8 data0 = {378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f}; + static const MS_FLOAT32X8 data1 = {17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f}; + static const MS_FLOAT32X8 data2 = {135135.0f, 135135.0f, 135135.0f, 135135.0f, + 135135.0f, 135135.0f, 135135.0f, 135135.0f}; + static const MS_FLOAT32X8 data3 = {28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f, 28.0f}; + static const MS_FLOAT32X8 data4 = {3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f, 3150.0f}; + static const MS_FLOAT32X8 data5 = {62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f, 62370.0f}; + static const MS_FLOAT32X8 neg = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f}; + static const MS_FLOAT32X8 pos = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + MS_FLOAT32X8 square = MS_MUL256_F32(src, src); + MS_FLOAT32X8 a = + MS_MUL256_F32(MS_FMADD256_F32(MS_FMADD256_F32(MS_ADD256_F32(square, data0), square, data1), square, data2), src); + MS_FLOAT32X8 b = + MS_FMADD256_F32(MS_FMADD256_F32(MS_FMADD256_F32(data3, square, data4), square, data5), square, data2); + MS_FLOAT32X8 res = MS_DIV256_F32(a, b); + MS_FLOAT32X8 up_limit = MS_MOV256_F32(5.0f); + MS_FLOAT32X8 down_limit = MS_MOV256_F32(-5.0f); + MS_FLOAT32X8 up_mask = MS_CMPGT256_F32(src, up_limit); + MS_FLOAT32X8 down_mask = MS_CMPLT256_F32(src, down_limit); + res = MS_BLEND256_F32(res, pos, up_mask); + res = MS_BLEND256_F32(res, neg, down_mask); + return res; +} + +#define MS_TANH256_F32 MS_TANHX8_F32 + +static inline MS_FLOAT32X8 MS256_ERF_F32(MS_FLOAT32X8 src) { + MS_FLOAT32X8 dst; + MS_F32X8_GETI(dst, 0) = erff(MS_F32X8_GETI(src, 0)); + MS_F32X8_GETI(dst, 1) = erff(MS_F32X8_GETI(src, 1)); + MS_F32X8_GETI(dst, 2) = erff(MS_F32X8_GETI(src, 2)); + MS_F32X8_GETI(dst, 3) = erff(MS_F32X8_GETI(src, 3)); + MS_F32X8_GETI(dst, 4) = erff(MS_F32X8_GETI(src, 4)); + MS_F32X8_GETI(dst, 5) = erff(MS_F32X8_GETI(src, 5)); + MS_F32X8_GETI(dst, 6) = erff(MS_F32X8_GETI(src, 6)); + MS_F32X8_GETI(dst, 7) = erff(MS_F32X8_GETI(src, 7)); + return dst; +} + +#define MS_FMADD256X8_F32(src, weight, dst) \ + dst##1 = MS_MLA256_F32(dst##1, src##1, weight); \ + dst##2 = MS_MLA256_F32(dst##2, src##2, weight); \ + dst##3 = MS_MLA256_F32(dst##3, src##3, weight); \ + dst##4 = MS_MLA256_F32(dst##4, src##4, weight); \ + dst##5 = MS_MLA256_F32(dst##5, src##5, weight); \ + dst##6 = MS_MLA256_F32(dst##6, src##6, weight); \ + dst##7 = MS_MLA256_F32(dst##7, src##7, weight); \ + dst##8 = MS_MLA256_F32(dst##8, src##8, weight); + +#define MS_SET_ZERO256X8_F32(dst) \ + MS_FLOAT32X8 dst##1 = _mm256_setzero_ps(); \ + MS_FLOAT32X8 dst##2 = _mm256_setzero_ps(); \ + MS_FLOAT32X8 dst##3 = _mm256_setzero_ps(); \ + MS_FLOAT32X8 dst##4 = _mm256_setzero_ps(); \ + MS_FLOAT32X8 dst##5 = _mm256_setzero_ps(); \ + MS_FLOAT32X8 dst##6 = _mm256_setzero_ps(); \ + MS_FLOAT32X8 dst##7 = _mm256_setzero_ps(); \ + MS_FLOAT32X8 dst##8 = _mm256_setzero_ps(); + +#define MS_FMADD256X4_F32(src, weight, dst) \ + dst##1 = MS_MLA256_F32(dst##1, src##1, weight); \ + dst##2 = MS_MLA256_F32(dst##2, src##2, weight); \ + dst##3 = MS_MLA256_F32(dst##3, src##3, weight); \ + dst##4 = MS_MLA256_F32(dst##4, src##4, weight); + +#define MS_SET_ZERO256X4_F32(dst) \ + MS_FLOAT32X8 dst##1 = _mm256_setzero_ps(); \ + MS_FLOAT32X8 dst##2 = _mm256_setzero_ps(); \ + MS_FLOAT32X8 dst##3 = _mm256_setzero_ps(); \ + MS_FLOAT32X8 dst##4 = _mm256_setzero_ps(); + +#define MS_REDUCE_ADD256_F32(src) (src = _mm256_hadd_ps(src, src), src = _mm256_hadd_ps(src, src), src[0] + src[4]); +#endif // NNACL_AVX_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/ms_simd_cpu_info.c b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/ms_simd_cpu_info.c new file mode 100644 index 0000000000000000000000000000000000000000..277c3d169545d7d1f0999771945f7d4be1ce42d6 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/ms_simd_cpu_info.c @@ -0,0 +1,141 @@ + + +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/intrinsics/ms_simd_cpu_info.h" +#include +#include +#include +#include "nnacl/errorcode.h" + +typedef unsigned int DWORD; +struct X86CpuInfoContext { + bool fma_flag_; + bool sse4_1_flag_; + bool avx2_flag_; + bool avx512_flag_; +}; + +static struct X86CpuInfoContext g_x86_cpu_info_context_; + +inline const bool X86_Fma_Support(void) { return g_x86_cpu_info_context_.fma_flag_; } + +inline const bool X86_Sse_Support(void) { +#ifdef ENABLE_SSE + return g_x86_cpu_info_context_.sse4_1_flag_; +#else + return false; +#endif +} + +inline const bool X86_Avx_Support(void) { +#ifdef ENABLE_AVX + return g_x86_cpu_info_context_.avx2_flag_; +#else + return false; +#endif +} + +inline const bool X86_Avx512_Support(void) { +#ifdef ENABLE_AVX512 + return g_x86_cpu_info_context_.avx512_flag_; +#else + return false; +#endif +} + +void ExecuteCpuIdCmd(DWORD cmd_code, DWORD *eax_data, DWORD *ebx_data, DWORD *ecx_data, DWORD *edx_data) { + DWORD deax, debx, decx, dedx; + asm volatile( + "movl %4, %%eax;\n" + "movl $0, %%ecx;\n" + "cpuid;\n" + "movl %%eax, %0;\n" + "movl %%ebx, %1;\n" + "movl %%ecx, %2;\n" + "movl %%edx, %3;\n" + : "=r"(deax), "=r"(debx), "=r"(decx), "=r"(dedx) + : "r"(cmd_code) + : "%eax", "%ebx", "%ecx", "%edx"); + + *eax_data = deax; + *ebx_data = debx; + *ecx_data = decx; + *edx_data = dedx; +} + +bool IsIntelX86Platform(void) { + DWORD eax_data, ebx_data, ecx_data, edx_data; + + const int vid_info_size = 13; + char *vid_info = malloc(sizeof(char) * vid_info_size); + if (vid_info == NULL) { + return false; + } + memset(vid_info, 0, vid_info_size); + + ExecuteCpuIdCmd(0, &eax_data, &ebx_data, &ecx_data, &edx_data); // eax = 0, execute cpuid to get vid info + + memcpy(vid_info, &ebx_data, 4); // Copy the first 4 characters to the array[0:3] + memcpy(vid_info + 4, &edx_data, 4); // Copy the middle 4 characters to the array[4:8] + memcpy(vid_info + 8, &ecx_data, 4); // Copy the last 4 characters to the array[8:12] + + int x86_intel_flag = (strcmp(vid_info, "GenuineIntel") == 0 || strcmp(vid_info, "AuthenticAMD") == 0) ? 1 : 0; + + free(vid_info); + return x86_intel_flag; +} + +int IntelX86CpuInfoInit(void) { + if (!IsIntelX86Platform()) { + return NNACL_ERR; + } + DWORD eax_data, ebx_data, ecx_data, edx_data; + ExecuteCpuIdCmd(1, &eax_data, &ebx_data, &ecx_data, &edx_data); // eax = 1, execute cpuid to get sse/fma flag + g_x86_cpu_info_context_.sse4_1_flag_ = (ecx_data & (1 << 19)) == 0 ? false : true; // sse flag is ecx 19 bit + g_x86_cpu_info_context_.fma_flag_ = (ecx_data & (1 << 12)) == 0 ? false : true; // fma flag is ecx 12 bit + + ExecuteCpuIdCmd(7, &eax_data, &ebx_data, &ecx_data, &edx_data); // eax = 7, execute cpuid to get avx2/avx512 flag + g_x86_cpu_info_context_.avx2_flag_ = (ebx_data & (1 << 5)) == 0 ? false : true; // avx2 flag is ecx 5 bit + g_x86_cpu_info_context_.avx512_flag_ = (ebx_data & (1 << 16)) == 0 ? false : true; // avx512 flag is ecx 16 bit + + return NNACL_OK; +} + +X86CpuInfoErrorCodeEnum IntelX86InstructionSetSupportCheck(void) { + if (IntelX86CpuInfoInit() != NNACL_OK) { + return X86CPUINFO_PLATFORM_ERR; + } +#if defined(ENABLE_AVX512) && !defined(AVX512_HARDWARE_SELF_AWARENESS) + if (!X86_Avx512_Support()) { + return X86CPUINFO_AVX512_ERR; + } +#endif + +#ifdef ENABLE_AVX + if (!X86_Avx_Support()) { + return X86CPUINFO_AVX_ERR; + } +#endif + +#ifdef ENABLE_SSE + if (!X86_Sse_Support()) { + return X86CPUINFO_SSE_ERR; + } +#endif + return X86CPUINFO_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/ms_simd_cpu_info.h b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/ms_simd_cpu_info.h new file mode 100644 index 0000000000000000000000000000000000000000..cec5ef130e3fd4391343bd8049493ae4df4400e7 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/ms_simd_cpu_info.h @@ -0,0 +1,61 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_MS_SIMD_CPU_INFO_H_ +#define NNACL_MS_SIMD_CPU_INFO_H_ + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#ifdef ENABLE_AVX512 +#define AVX512_HARDWARE_SELF_AWARENESS +#endif + +#if defined(AVX512_HARDWARE_SELF_AWARENESS) +#define AVX512_HARDWARE_SELF_AWARENESS_BEGIN if (X86_Avx512_Support()) { +#define AVX512_HARDWARE_SELF_AWARENESS_END } +#else +#define AVX512_HARDWARE_SELF_AWARENESS_BEGIN +#define AVX512_HARDWARE_SELF_AWARENESS_END +#endif + +typedef enum X86CpuInfoErrorCodeEnum { + X86CPUINFO_OK = 0, + X86CPUINFO_PLATFORM_ERR = 1, + X86CPUINFO_AVX512_ERR, + X86CPUINFO_AVX_ERR, + X86CPUINFO_SSE_ERR, + X86CPUINFO_END = 9999 +} X86CpuInfoErrorCodeEnum; + +const bool X86_Fma_Support(void); +const bool X86_Sse_Support(void); +const bool X86_Avx_Support(void); +const bool X86_Avx512_Support(void); + +bool IsIntelX86Platform(void); +X86CpuInfoErrorCodeEnum IntelX86InstructionSetSupportCheck(void); + +int IntelX86CpuInfoInit(void); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_AVX_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/ms_simd_instructions.h b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/ms_simd_instructions.h new file mode 100644 index 0000000000000000000000000000000000000000..c8673ccd15d9e6d1a3ac5a6157ba9a46bdb15339 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/ms_simd_instructions.h @@ -0,0 +1,563 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ +#define NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ +#include +#include "nnacl/intrinsics/ms_simd_cpu_info.h" + +#ifdef ENABLE_AVX512 +#include "nnacl/intrinsics/ms_simd_avx512_instructions.h" +#endif + +#ifdef ENABLE_AVX +#include "nnacl/intrinsics/ms_simd_avx_instructions.h" +#endif + +#ifdef ENABLE_SSE +#include "nnacl/intrinsics/ms_simd_sse_instructions.h" +#endif + +#ifdef ENABLE_ARM +#include "nnacl/intrinsics/ms_simd_neon_instructions.h" +#endif + +#define MS_SIMD_AVX512_INSTRUCTION(instruction, suffix) instruction##512##suffix +#define MS_SIMD_AVX_INSTRUCTION(instruction, suffix) instruction##256##suffix +#define MS_SIMD_SSE_INSTRUCTION(instruction, suffix) instruction##128##suffix +#define MS_SIMD_NEON_INSTRUCTION(instruction, suffix) instruction##128##suffix + +#define MS_SIMD_INSTRUCTION_F32(instruction) MS_SIMD_INSTRUCTION(instruction, _F32) +#define MS_SIMD_INSTRUCTION_EPI32(instruction) MS_SIMD_INSTRUCTION(instruction, _EPI32) +#define MS_SIMD_INSTRUCTION_MASK(instruction) MS_SIMD_INSTRUCTION(instruction, _MASK) + +// define (float/int) data +#define SIMD_F32 MS_SIMD_INSTRUCTION_F32(MS_FLOAT) +#define SIMD_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_INT) +#define SIMD_MASK MS_SIMD_INSTRUCTION(MS_MASK, _TYPE) + +// read scaler data +#define SIMD_F32_GETI MS_SIMD_INSTRUCTION(MS, _F32_GETI) + +// move (float/int) data +#define SIMD_MOV_F32 MS_SIMD_INSTRUCTION_F32(MS_MOV) +#define SIMD_MOV_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_MOV) +#define SIMD_SET0_F32 MS_SIMD_INSTRUCTION(MS_MOV, _VAL0_F32) + +// load (float/int) data +#define SIMD_LD_F32 MS_SIMD_INSTRUCTION_F32(MS_LD) +#define SIMD_LD_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_LD) +#define SIMD_LD_HALF_EPI32 MS_SIMD_INSTRUCTION(MS_LD, _HALF_EPI32) + +// load 4 (float/int) data +#define SIMD_LDX4_F32 MS_SIMD_INSTRUCTION(MS_LOAD, X4_F32) +#define SIMD_LDX4_EPI32 MS_SIMD_INSTRUCTION(MS_LOAD, X4_EPI32) + +// stored (float/int) data +#define SIMD_ST_F32 MS_SIMD_INSTRUCTION_F32(MS_ST) +#define SIMD_ST_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_ST) +#define SIMD_ST_HALF_EPI32 MS_SIMD_INSTRUCTION(MS_ST, _HALF_EPI32) + +// sign +#define SIMD_SIGN_F32 MS_SIMD_INSTRUCTION_F32(SIMD_SIGN) +#define SIMD_SIGNABS_F32 MS_SIMD_INSTRUCTION_F32(SIMD_SIGNABS) + +// add (float/int) op +#define SIMD_ADD_F32 MS_SIMD_INSTRUCTION_F32(MS_ADD) +#define SIMD_ADD_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_ADD) +#define SIMD_ADD_N_F32(val1, val2) MS_EXPAND(SIMD_ADD_F32(val1, SIMD_MOV_F32(val2))) +#define SIMD_ADD_N_EPI32(val1, val2) MS_EXPAND(SIMD_ADD_EPI32(val1, SIMD_MOV_EPI32(val2))) + +// sub (float/int) op +#define SIMD_SUB_F32 MS_SIMD_INSTRUCTION_F32(MS_SUB) +#define SIMD_SUB_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_SUB) +#define SIMD_SUB_N_F32(val1, val2) MS_EXPAND(SIMD_SUB_F32(val1, SIMD_MOV_F32(val2))) +#define SIMD_SUB_N_EPI32(val1, val2) MS_EXPAND(SIMD_SUB_EPI32(val1, SIMD_MOV_EPI32(val2))) + +// div (float/int) op +#define SIMD_DIV_F32 MS_SIMD_INSTRUCTION_F32(MS_DIV) +#define SIMD_DIV_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_DIV) +#define SIMD_DIV_N_F32(val1, val2) MS_EXPAND(SIMD_DIV_F32(val1, SIMD_MOV_F32(val2))) +#define SIMD_DIV_N_EPI32(val1, val2) MS_EXPAND(SIMD_DIV_EPI32(val1, SIMD_MOV_EPI32(val2))) + +// sqrt (float) op +#define SIMD_SQRT_F32 MS_SIMD_INSTRUCTION_F32(MS_SQRT) + +// rsqrt (float) op +#define SIMD_RSQRT_F32 MS_SIMD_INSTRUCTION_F32(MS_RSQRT) + +// log (float) op +#define SIMD_LOG_F32 MS_SIMD_INSTRUCTION(MS, _LOG_F32) + +// cos (float) op +#define SIMD_COS_F32 MS_SIMD_INSTRUCTION_F32(MS_COS) + +// sin (float) op +#define SIMD_SIN_F32 MS_SIMD_INSTRUCTION_F32(MS_SIN) + +// erf (float) op +#define SIMD_ERF_F32 MS_SIMD_INSTRUCTION(MS, _ERF_F32) + +// abs (float) op +#define SIMD_ABS_F32 MS_SIMD_INSTRUCTION_F32(MS_ABS) +#define SIMD_ABS_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_ABS) + +// round (float) op +#define SIMD_ROUND_F32 MS_SIMD_INSTRUCTION_F32(MS_ROUND) + +// ceil (float) op +#define SIMD_CEIL_F32 MS_SIMD_INSTRUCTION_F32(MS_CEIL) + +// floor (float) op +#define SIMD_FLOOR_F32 MS_SIMD_INSTRUCTION_F32(MS_FLOOR) + +// tanh (float) op +#define SIMD_TANH_F32 MS_SIMD_INSTRUCTION_F32(MS_TANH) + +// min (float/int) op +#define SIMD_MIN_F32 MS_SIMD_INSTRUCTION_F32(MS_MIN) +#define SIMD_MIN_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_MIN) +#define SIMD_MIN_N_F32(val1, val2) MS_EXPAND(SIMD_MIN_F32(val1, SIMD_MOV_F32(val2))) +#define SIMD_MIN_N_EPI32(val1, val2) MS_EXPAND(SIMD_MIN_EPI32(val1, SIMD_MOV_EPI32(val2))) + +// max (float/int) op +#define SIMD_MAX_F32 MS_SIMD_INSTRUCTION_F32(MS_MAX) +#define SIMD_MAX_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_MAX) +#define SIMD_MAX_N_F32(val1, val2) MS_EXPAND(SIMD_MAX_F32(val1, SIMD_MOV_F32(val2))) +#define SIMD_MAX_N_EPI32(val1, val2) MS_EXPAND(SIMD_MAX_EPI32(val1, SIMD_MOV_EPI32(val2))) + +// get max (float/int) op +#define SIMD_GET_MAX_F32 MS_SIMD_INSTRUCTION_F32(MS_GET_MAX) +#define SIMD_GET_MAX_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_GET_MAX) + +// get max (float/int) op +#define SIMD_GET_SUM_F32 MS_SIMD_INSTRUCTION_F32(MS_GET_SUM) +#define SIMD_REDUCE_ADD_F32 MS_SIMD_INSTRUCTION(MS_REDUCE_ADD, _F32) + +// clamp (float/int) op +#define SIMD_CLAMP_F32(val, min_val, max_val) SIMD_MIN_F32(SIMD_MAX_F32(val, min_val), max_val) +#define SIMD_CLAMP_EPI32(val, min_val, max_val) SIMD_MIN_EPI32(SIMD_MAX_EPI32(val, min_val), max_val) +#define SIMD_CLAMP_N_F32(val, min_val, max_val) \ + SIMD_MIN_F32(SIMD_MAX_F32(val, SIMD_MOV_F32(min_val)), SIMD_MOV_F32(max_val)) +#define SIMD_CLAMP_N_EPI32(val, min_val, max_val) \ + SIMD_MIN_EPI32(SIMD_MAX_EPI32(val, SIMD_MOV_EPI32(min_val)), SIMD_MOV_EPI32(max_val)) + +// mul (float/int) op +#define SIMD_MUL_F32 MS_SIMD_INSTRUCTION_F32(MS_MUL) +#define SIMD_MUL_EPI32 MS_SIMD_INSTRUCTION_EPI32(MS_MUL) +#define SIMD_MUL_N_F32(val1, val2) MS_EXPAND(SIMD_MUL_F32(val1, SIMD_MOV_F32(val2))) +#define SIMD_MUL_N_EPI32(val1, val2) MS_EXPAND(SIMD_MUL_EPI32(val1, SIMD_MOV_EPI32(val2))) + +// pow (float) op +#define SIMD_POW_F32 MS_SIMD_INSTRUCTION_F32(MS_POW) + +// fma (float/int) op +#define SIMD_FMADD_F32 MS_SIMD_INSTRUCTION_F32(MS_FMADD) + +// fms (float/int) op +#define SIMD_FMSUB_F32 MS_SIMD_INSTRUCTION_F32(MS_FMSUB) + +// fsm (float) op +#define MS_FSMUL_F32 MS_SIMD_INSTRUCTION_F32(MS_FSMUL) + +// square (float/int) op +#define SIMD_MUL_SQUARE_F32(val1) SIMD_MUL_F32(val1, val1) +#define SIMD_MUL_SQUARE_EPI32(val1) SIMD_MUL_EPI32(val1, val1) + +// exp (float) op +#define SIMD_EXP_ST_F32 MS_SIMD_INSTRUCTION(simd_exp, ) +#define SIMD_EXP_F32 MS_SIMD_INSTRUCTION(simd_exp, _f32) +// exp (float) high precision but a little slow op. +#define SIMD_HEXP_F32 MS_SIMD_INSTRUCTION(simd_hexp, _f32) + +// cmp (float/int) op +#define SIMD_CMPLT_F32 MS_SIMD_INSTRUCTION_F32(MS_CMPLT) +#define SIMD_CMPLE_F32 MS_SIMD_INSTRUCTION_F32(MS_CMPLE) +#define SIMD_CMPGT_F32 MS_SIMD_INSTRUCTION_F32(MS_CMPGT) +#define SIMD_BLEND_F32 MS_SIMD_INSTRUCTION_F32(MS_BLEND) + +// cast data +#define MS_CAST_F32_S32 MS_SIMD_INSTRUCTION(MS_CAST, _F32_S32) + +// logical op +#define SIMD_AND_MASK MS_SIMD_INSTRUCTION_MASK(MS_AND) +#define SIMD_OR_F32 MS_SIMD_INSTRUCTION_F32(MS_OR) +#define SIMD_AND_MASK_F32 MS_SIMD_INSTRUCTION(MS_AND, _MASK_F32) +#define SIMD_AND_F32 MS_SIMD_INSTRUCTION_F32(MS_AND) + +#define SIMD_GETSIGN_F32(src) \ + SIMD_OR_F32(SIMD_AND_F32(src, MS_CAST_F32_S32(SIMD_MOV_EPI32(0x80000000))), \ + MS_CAST_F32_S32(SIMD_MOV_EPI32(0x3F800000))) + +// int32/float mutual conversion +#define SIMD_EPI32_TO_F32 MS_SIMD_INSTRUCTION(MS, _INT32_TO_FLOAT32) +#define SIMD_F32_TO_EPI32 MS_SIMD_INSTRUCTION(MS, _FLOAT32_TO_INT32) +#define SIMD_F16_TO_F32 MS_SIMD_INSTRUCTION(MS, _FLOAT16_TO_FLOAT32) +#define SIMD_F32_TO_F16 MS_SIMD_INSTRUCTION(MS, _FLOAT32_TO_FLOAT16) + +// enable avx512 +#if defined(ENABLE_AVX512) +#define SIMD_RUN_AVX512(function, index, ...) \ + do { \ + AVX512_HARDWARE_SELF_AWARENESS_BEGIN \ + index = function##AVX512(index, __VA_ARGS__); \ + AVX512_HARDWARE_SELF_AWARENESS_END \ + } while (0) +#else +#define SIMD_RUN_AVX512(function, index, ...) +#endif + +// enable avx256 +#if defined(ENABLE_AVX) +#define SIMD_RUN_AVX(function, index, ...) index = function##AVX(index, __VA_ARGS__) +#else +#define SIMD_RUN_AVX(function, index, ...) +#endif + +// enable sse +#if defined(ENABLE_SSE) +#define SIMD_RUN_SSE(function, index, ...) index = function##SSE(index, __VA_ARGS__) +#else +#define SIMD_RUN_SSE(function, index, ...) +#endif + +// enable neon +#if defined(ENABLE_NEON) +#define SIMD_RUN_NEON(function, index, ...) index = function##NEON(index, __VA_ARGS__) +#else +#define SIMD_RUN_NEON(function, index, ...) +#endif + +#define SIMD_RUN_NO_SCALAR(function, index, ...) \ + do { \ + SIMD_RUN_AVX512(function, index, __VA_ARGS__); \ + SIMD_RUN_AVX(function, index, __VA_ARGS__); \ + SIMD_RUN_SSE(function, index, __VA_ARGS__); \ + SIMD_RUN_NEON(function, index, __VA_ARGS__); \ + } while (0) + +#define SIMD_RUN_X86_NO_SCALAR(function, index, ...) \ + do { \ + SIMD_RUN_AVX512(function, index, __VA_ARGS__); \ + SIMD_RUN_AVX(function, index, __VA_ARGS__); \ + SIMD_RUN_SSE(function, index, __VA_ARGS__); \ + } while (0) + +#define SIMD512_BLOCK16 32 // SIMD : 512 = 16 x 32 +#define SIMD256_BLOCK16 16 // SIMD : 256 = 16 x 16 +#define SIMD128_BLOCK16 8 // SIMD : 128 = 16 x 8 + +#define SIMD512_BLOCK32 16 // SIMD : 512 = 32 x 16 +#define SIMD256_BLOCK32 8 // SIMD : 256 = 32 x 8 +#define SIMD128_BLOCK32 4 // SIMD : 128 = 32 x 4 + +#define SIMD512_BLOCK64 8 // SIMD : 512 = 64 x 8 +#define SIMD256_BLOCK64 4 // SIMD : 256 = 64 x 4 +#define SIMD128_BLOCK64 2 // SIMD : 128 = 64 x 2 + +#define MS_EXPAND(...) __VA_ARGS__ + +// Scaler +#define MS_FLOAT32X1 float +#define MS_INT32X1 int +#define MS_MOV32_F32(value) (value) +#define MS_MOV32_EPI32(value) (value) +#define MS_LD32_F32(address) (*(address)) +#define MS_LD32_EPI32(address) (*(address)) +#define MS_ST32_F32(address, value) (*(address) = (value)) +#define MS_ST32_EPI32(address, value) (*(address) = (value)) +#define MS_ADD32_F32(value1, value2) ((value1) + (value2)) +#define MS_ADD32_EPI32(value1, value2) ((value1) + (value2)) +#define MS_SUB32_F32(value1, value2) ((value1) - (value2)) +#define MS_SUB32_EPI32(value1, value2) ((value1) - (value2)) +#define MS_MUL32_F32(value1, value2) ((value1) * (value2)) +#define MS_MUL32_EPI32(value1, value2) ((value1) * (value2)) +#define MS_DIV32_F32(value1, value2) ((value1) / (value2)) +#define MS_DIV32_EPI32(value1, value2) ((value1) / (value2)) +#define MS_MIN32_F32(value1, value2) (fmin((value1), (value2))) +#define MS_MIN32_EPI32(value1, value2) ((value1) < (value2) ? (value1) : (value2)) +#define MS_MAX32_F32(value1, value2) (fmax((value1), (value2))) +#define MS_MAX32_EPI32(value1, value2) ((value1) > (value2) ? (value1) : (value2)) +#define MS_SQRT32_F32(value) (sqrt(value)) + +static inline float simd_exp32_f32(float data) { + typedef union { + float f; + int i; + } fi; + static float param[] = {0.693147f, 1.0f / 120, 1.0f / 24, 1.0f / 6, 1.0f / 2, 1.0f}; // Approximate calculation param +#ifdef _WIN32 + if (data < -88.0f) { + return 0.0f; + } else if (data > 88.0f) { + return 1.6516363e+38; // e^88 = 1.6516363e+38 + } +#else + data = + MS_MAX32_F32(-87.3365478515625f, MS_MIN32_F32(88.72283935546875f, data)); // clamp(logf(FLT_MIN), logf(FLT_MAX)) +#endif + int integer = floor(data * 1.44269504088896341f + 0.5f); + float decimal = data - integer * param[0]; + fi int_exp; + const int shift = 23; + const int bias = 126; + const float factor = 2; + // 2^n * exp(r) should be counted 2 * 2^(n - 1) * exp(r), + // because n may be 128, and it is not representable by fp32. + int_exp.i = (integer + bias) << shift; // integer num 2^(n - 1) approximate calculation : ((x - 1) + 127) << 23 + // Approximate calculation + const float decimal_exp = + 1.0f + decimal * (1.0f + decimal * (0.5f + decimal * (param[3] + decimal * (param[2] + decimal * param[1])))); + return factor * int_exp.f * decimal_exp; +} + +// exp(x) = exp(n * ln(2) + r) = 2^n * exp(r) = 2 * 2^(n - 1) * exp(r) +static inline void simd_exp32(float src, float *dst) { + typedef union { + float f; + int i; + } fi; + static float param[] = {0.693147f, 1.0f / 120, 1.0f / 24, 1.0f / 6, 1.0f / 2, 1.0f}; // log(2.0f) + src = MS_MAX32_F32(-87.3365478515625f, MS_MIN32_F32(88.72283935546875f, src)); // clamp(logf(FLT_MIN), logf(FLT_MAX)) + int integer = floor(src * 1.44269504088896341f + 0.5f); + float decimal = src - integer * param[0]; + fi int_exp; + const int shift = 23; + const int bias = 126; + const float factor = 2; + // 2^n * exp(r) should be counted 2 * 2^(n - 1) * exp(r), + // because n may be 128, and it is not representable by fp32. + int_exp.i = (integer + bias) << shift; // integer num 2^(n - 1) approximate calculation : ((x - 1) + 127) << 23 + const float decimal_exp = + 1.0f + decimal * (1.0f + decimal * (0.5f + decimal * (param[3] + decimal * (param[2] + decimal * param[1])))); + *dst = factor * int_exp.f * decimal_exp; +} + +// define (float/int) data +#define MS_FLOAT_32xN(byte_num) MS_FLOAT32##X##byte_num +#define MS_INT_32xN(byte_num) MS_INT32##X##byte_num + +// move (float/int) data +#define MS_MOVN_F32(byte_num, ...) MS_EXPAND(MS_MOV##byte_num##_F32(__VA_ARGS__)) +#define MS_MOVN_EPI32(byte_num, ...) MS_EXPAND(MS_MOV##byte_num##_EPI32(__VA_ARGS__)) + +// load (float/int) data +#define MS_LD_F32(bit_num, ...) MS_EXPAND(MS_LD##bit_num##_F32(__VA_ARGS__)) +#define MS_LD_EPI32(bit_num, ...) MS_EXPAND(MS_LD##bit_num##_EPI32(__VA_ARGS__)) + +// load 4 (float/int) data +#define MS_LDX4_F32(bit_num, ...) MS_EXPAND(MS_LOAD##bit_num##X4_F32(__VA_ARGS__)) +#define MS_LDX4_EPI32(bit_num, ...) MS_EXPAND(MS_LOAD##bit_num##X4_EPI32(__VA_ARGS__)) + +// stored (float/int) data +#define MS_ST_F32(bit_num, ...) MS_EXPAND(MS_ST##bit_num##_F32(__VA_ARGS__)) +#define MS_ST_EPI32(bit_num, ...) MS_EXPAND(MS_ST##bit_num##_EPI32(__VA_ARGS__)) + +// add (float/int) op +#define MS_ADD_F32(bit_num, ...) MS_EXPAND(MS_ADD##bit_num##_F32(__VA_ARGS__)) +#define MS_ADD_EPI32(bit_num, ...) MS_EXPAND(MS_ADD##bit_num##_EPI32(__VA_ARGS__)) +#define MS_ADD_N_F32(bit_num, val1, val2) MS_EXPAND(MS_ADD##bit_num##_F32(val1, MS_MOV##bit_num##_F32(val2))) +#define MS_ADD_N_EPI32(bit_num, val1, val2) MS_EXPAND(MS_ADD##bit_num##_EPI32(val1, MS_MOV##bit_num##_F32(val2))) + +// sub (float/int) op +#define MS_SUB_F32(bit_num, ...) MS_EXPAND(MS_SUB##bit_num##_F32(__VA_ARGS__)) +#define MS_SUB_EPI32(bit_num, ...) MS_EXPAND(MS_SUB##bit_num##_EPI32(__VA_ARGS__)) +#define MS_SUB_N_F32(bit_num, val1, val2) MS_EXPAND(MS_SUB##bit_num##_F32(val1, MS_MOV##bit_num##_F32(val2))) +#define MS_SUB_N_EPI32(bit_num, val1, val2) MS_EXPAND(MS_SUB##bit_num##_EPI32(val1, MS_MOV##bit_num##_F32(val2))) + +// div (float/int) op +#define MS_DIV_F32(bit_num, ...) MS_EXPAND(MS_DIV##bit_num##_F32(__VA_ARGS__)) +#define MS_DIV_EPI32(bit_num, ...) MS_EXPAND(MS_DIV##bit_num##_EPI32(__VA_ARGS__)) +#define MS_DIV_N_F32(bit_num, val1, val2) MS_EXPAND(MS_DIV##bit_num##_F32(val1, MS_MOV##bit_num##_F32(val2))) +#define MS_DIV_N_EPI32(bit_num, val1, val2) MS_EXPAND(MS_DIV##bit_num##_EPI32(val1, MS_MOV##bit_num##_EPI32(val2))) + +// sqrt (float) op +#define MS_SQRT_F32(bit_num, ...) MS_EXPAND(MS_SQRT##bit_num##_F32(__VA_ARGS__)) + +// rsqrt (float) op +#define MS_RSQRT_F32(bit_num, ...) MS_EXPAND(MS_RSQRT##bit_num##_F32(__VA_ARGS__)) + +// log (float) op +#define MS_LOG_F32(bit_num, ...) MS_EXPAND(MS_LOG##bit_num##_F32(__VA_ARGS__)) + +// cos (float) op +#define MS_COS_F32(bit_num, ...) MS_EXPAND(MS_COS##bit_num##_F32(__VA_ARGS__)) + +// sin (float) op +#define MS_SIN_F32(bit_num, ...) MS_EXPAND(MS_SIN##bit_num##_F32(__VA_ARGS__)) + +// erf (float) op +#define MS_ERF_F32(bit_num, ...) MS_EXPAND(MS_ERF##bit_num##_F32(__VA_ARGS__)) + +// log (float) op +#define MS_ABS_F32(bit_num, ...) MS_EXPAND(MS_ABS##bit_num##_F32(__VA_ARGS__)) + +// round (float) op +#define MS_ROUND_F32(bit_num, ...) MS_EXPAND(MS_ROUND##bit_num##_F32(__VA_ARGS__)) + +// ceil (float) op +#define MS_CEIL_F32(bit_num, ...) MS_EXPAND(MS_CEIL##bit_num##_F32(__VA_ARGS__)) + +// floor (float) op +#define MS_FLOOR_F32(bit_num, ...) MS_EXPAND(MS_FLOOR##bit_num##_F32(__VA_ARGS__)) + +// min (float/int) op +#define MS_MIN_F32(bit_num, ...) MS_EXPAND(MS_MIN##bit_num##_F32(__VA_ARGS__)) +#define MS_MIN_EPI32(bit_num, ...) MS_EXPAND(MS_MIN##bit_num##_EPI32(__VA_ARGS__)) +#define MS_MIN_N_F32(bit_num, val, n) MS_MIN_F32(bit_num, val, MS_MOVN_F32(bit_num, n)) +#define MS_MIN_N_EPI32(bit_num, val, n) MS_MIN_EPI32(bit_num, val, MS_MOVN_EPI32(bit_num, n)) + +// max (float/int) op +#define MS_MAX_F32(bit_num, ...) MS_EXPAND(MS_MAX##bit_num##_F32(__VA_ARGS__)) +#define MS_MAX_EPI32(bit_num, ...) MS_EXPAND(MS_MAX##bit_num##_EPI32(__VA_ARGS__)) + +// get max (float/int) op +#define MS_GET_MAX_F32(bit_num, ...) MS_EXPAND(MS_GET_MAX##bit_num##_F32(__VA_ARGS__)) +#define MS_GET_MAX_EPI32(bit_num, ...) MS_EXPAND(MS_GET_MAX##bit_num##_EPI32(__VA_ARGS__)) + +// get max (float/int) op +#define MS_GET_SUM_F32(bit_num, ...) MS_EXPAND(MS_GET_SUM##bit_num##_F32(__VA_ARGS__)) + +// max n (float/int) op +#define MS_MAX_N_F32(bit_num, val, n) MS_MAX_F32(bit_num, val, MS_MOVN_F32(bit_num, n)) +#define MS_MAX_N_EPI32(bit_num, val, n) MS_MAX_EPI32(bit_num, val, MS_MOVN_EPI32(bit_num, n)) +#define MS_CLAMP_F32(bit_num, val, min_val, max_val) MS_MIN_F32(bit_num, MS_MAX_F32(bit_num, val, min_val), max_val) +#define MS_CLAMP_EPI32(bit_num, val, min_val, max_val) \ + MS_MIN_EPI32(bit_num, MS_MAX_EPI32(bit_num, val, min_val), max_val) + +// clamp n (float/int) op +#define MS_CLAMP_N_F32(bit_num, val, min_val, max_val) \ + MS_MIN_F32(bit_num, MS_MAX_F32(bit_num, val, MS_MOV##bit_num##_F32(min_val)), MS_MOV##bit_num##_F32(max_val)) +#define MS_CLAMP_N_EPI32(bit_num, val, min_val, max_val) \ + MS_MIN_EPI32(bit_num, MS_MAX_EPI32(bit_num, val, MS_MOV##bit_num##_EPI32(min_val)), MS_MOV##bit_num##_EPI32(max_val)) + +// mul (float/int) op +#define MS_MUL_F32(bit_num, ...) MS_EXPAND(MS_MUL##bit_num##_F32(__VA_ARGS__)) +#define MS_MUL_EPI32(bit_num, ...) MS_EXPAND(MS_MUL##bit_num##_EPI32(__VA_ARGS__)) +#define MS_MUL_N_F32(bit_num, val1, val2) MS_EXPAND(MS_MUL##bit_num##_F32(val1, MS_MOV##bit_num##_F32(val2))) +#define MS_MUL_N_EPI32(bit_num, val1, val2) MS_EXPAND(MS_MUL##bit_num##_EPI32(val1, MS_MOV##bit_num##_EPI32(val2))) + +// fma (float/int) op +#define MS_FMADD_F32(bit_num, ...) MS_EXPAND(MS_FMADD##bit_num##_F32(__VA_ARGS__)) +#define MS_FMADD_N_F32(bit_num, val1, val2) MS_EXPAND(MS_FMADD##bit_num##_F32(val1, MS_MOV##bit_num##_F32(val2))) + +// fms (float/int) op +#define MS_FMSUB_F32(bit_num, ...) MS_EXPAND(MS_FMSUB##bit_num##_F32(__VA_ARGS__)) +#define MS_FMSUB_N_F32(bit_num, val1, val2) MS_EXPAND(MS_FMSUB##bit_num##_F32(val1, MS_MOV##bit_num##_F32(val2))) + +// square (float/int) op +#define MS_MUL_SQUARE_F32(bit_num, val) MS_EXPAND((MS_MUL##bit_num##_F32(val, val))) +#define MS_MUL_SQUARE_EPI32(bit_num, val) MS_EXPAND((MS_MUL##bit_num##_EPI32(val, val))) + +// exp (float) op +#define MS_EXP_ST_F32(bit_num, ...) MS_EXPAND((simd_exp##bit_num(__VA_ARGS__))) +#define MS_EXP_F32(bit_num, ...) MS_EXPAND((simd_exp##bit_num##_f32(__VA_ARGS__))) + +#define MS_CMPLT_F32(bit_num, ...) MS_EXPAND((MS_CMPLT##bit_num##_F32(__VA_ARGS__))) +#define MS_CMPLE_F32(bit_num, ...) MS_EXPAND((MS_CMPLE##bit_num##_F32(__VA_ARGS__))) +#define MS_CMPGT_F32(bit_num, ...) MS_EXPAND((MS_CMPGT##bit_num##_F32(__VA_ARGS__))) +#define MS_BLEND_F32(bit_num, ...) MS_EXPAND((MS_BLEND##bit_num##_F32(__VA_ARGS__))) + +#define MS_INT16_TO_FLOAT16(bit_num, ...) MS_EXPAND((MS##bit_num##_INT16_TO_FLOAT16(__VA_ARGS__))) +#define MS_FLOAT16_TO_INT16(bit_num, ...) MS_EXPAND((MS##bit_num##_FLOAT16_TO_INT16(__VA_ARGS__))) + +#define MS_INT32_TO_FLOAT16(bit_num, ...) MS_EXPAND((MS##bit_num##_INT32_TO_FLOAT16(__VA_ARGS__))) +#define MS_FLOAT16_TO_INT32(bit_num, ...) MS_EXPAND((MS##bit_num##_FLOAT16_TO_INT32(__VA_ARGS__))) + +#define MS_INT32_TO_FLOAT32(bit_num, ...) MS_EXPAND((MS##bit_num##_INT32_TO_FLOAT32(__VA_ARGS__))) +#define MS_FLOAT32_TO_INT32(bit_num, ...) MS_EXPAND((MS##bit_num##_FLOAT32_TO_INT32(__VA_ARGS__))) + +#define MS_INT64_TO_FLOAT32(bit_num, ...) MS_EXPAND((MS##bit_num##_INT64_TO_FLOAT32(__VA_ARGS__))) +#define MS_FLOAT32_TO_INT64(bit_num, ...) MS_EXPAND((MS##bit_num##_FLOAT32_TO_INT64(__VA_ARGS__))) + +#define MS_INT64_TO_FLOAT16(bit_num, ...) MS_EXPAND((MS##bit_num##_INT64_TO_FLOAT16(__VA_ARGS__))) +#define MS_FLOAT16_TO_INT64(bit_num, ...) MS_EXPAND((MS##bit_num##_FLOAT16_TO_INT64(__VA_ARGS__))) + +#define MS_INT32_TO_FLOAT64(bit_num, ...) MS_EXPAND((MS##bit_num##_INT32_TO_FLOAT64(__VA_ARGS__))) +#define MS_FLOAT64_TO_INT32(bit_num, ...) MS_EXPAND((MS##bit_num##_FLOAT64_TO_INT32(__VA_ARGS__))) + +#define MS_INT64_TO_FLOAT64(bit_num, ...) MS_EXPAND((MS##bit_num##_INT64_TO_FLOAT64(__VA_ARGS__))) +#define MS_FLOAT64_TO_INT64(bit_num, ...) MS_EXPAND((MS##bit_num##_FLOAT64_TO_INT64(__VA_ARGS__))) + +// enable avx512 +#if defined(ENABLE_AVX512) +#define MS_SIMD_RUN_AVX512(function, ...) MS_EXPAND(function(512, 16, __VA_ARGS__)) +#else +#define MS_SIMD_RUN_AVX512(function, ...) +#endif + +// enable avx256 +#if defined(ENABLE_AVX) +#define MS_SIMD_RUN_AVX(function, ...) MS_EXPAND(function(256, 8, __VA_ARGS__)) +#else +#define MS_SIMD_RUN_AVX(function, ...) +#endif + +// enable sse +#if defined(ENABLE_SSE) +#define MS_SIMD_RUN_SSE(function, ...) MS_EXPAND(function(128, 4, __VA_ARGS__)) +#else +#define MS_SIMD_RUN_SSE(function, ...) +#endif + +// enable neon +#if defined(ENABLE_NEON) +#define MS_SIMD_RUN_NEON(function, ...) MS_EXPAND(function(128, 4, __VA_ARGS__)) +#else +#define MS_SIMD_RUN_NEON(function, ...) +#endif + +// enable neon/sse +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) +#define MS_SIMD_RUN_SSEORNEON128(function, ...) MS_EXPAND(function(128, 4, __VA_ARGS__)) +#else +#define MS_SIMD_RUN_SSEORNEON128(function, ...) +#endif + +// scalar (c style data) +#define MS_SIMD_RUN_SCALAR(function, ...) MS_EXPAND(function(32, 1, __VA_ARGS__)) + +#define MS_SIMD_RUN(function, ...) \ + do { \ + MS_SIMD_RUN_AVX512(function, __VA_ARGS__); \ + MS_SIMD_RUN_AVX(function, __VA_ARGS__); \ + MS_SIMD_RUN_SSEORNEON128(function, __VA_ARGS__); \ + MS_SIMD_RUN_SCALAR(function, __VA_ARGS__); \ + } while (0) + +#define MS_SIMD_RUN_NO_SCALAR(function, ...) \ + do { \ + MS_SIMD_RUN_AVX512(function, __VA_ARGS__); \ + MS_SIMD_RUN_AVX(function, __VA_ARGS__); \ + MS_SIMD_RUN_SSEORNEON128(function, __VA_ARGS__); \ + } while (0) + +#define MS_SIMD_RUN_X86(function, ...) \ + do { \ + MS_SIMD_RUN_AVX512(function, __VA_ARGS__); \ + MS_SIMD_RUN_AVX(function, __VA_ARGS__); \ + MS_SIMD_RUN_SSE(function, __VA_ARGS__); \ + MS_SIMD_RUN_SCALAR(function, __VA_ARGS__); \ + } while (0) + +#define MS_SIMD_RUN_X86_NO_SCALAR(function, ...) \ + do { \ + MS_SIMD_RUN_AVX512(function, __VA_ARGS__); \ + MS_SIMD_RUN_AVX(function, __VA_ARGS__); \ + MS_SIMD_RUN_SSE(function, __VA_ARGS__); \ + } while (0) + +#endif // NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/ms_simd_instructions_fp16.h b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/ms_simd_instructions_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..94ed4b89b1339d4ecd00a4f35f93fc35a8e10ed2 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/ms_simd_instructions_fp16.h @@ -0,0 +1,162 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_FP16_H_ +#define NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_FP16_H_ +#include +#include "nnacl/intrinsics/ms_simd_instructions.h" + +#if defined(ENABLE_ARM82_A32) +static inline float16x8_t ms_vdivq_f16(float16x8_t in1, float16x8_t in2) { + float16x8_t dst; + asm volatile( + "vrecpe.f16 q14, %3\n" + "vrecps.f16 q15, %3, q14\n" + "vmul.f16 q14, q15, q14\n" + "vrecps.f16 q15, %3, q14\n" + "vmul.f16 q14, q15, q14\n" + "vmul.f16 %0, %2, q14\n" + : "=w"(dst) + : "0"(dst), "w"(in1), "w"(in2) + : "q14", "q15"); + return dst; +} + +static inline float16x4_t ms_vdiv_f16(float16x4_t in1, float16x4_t in2) { + float16x4_t dst; + asm volatile( + "vrecpe.f16 d14, %3\n" + "vrecps.f16 d16, %3, d14\n" + "vmul.f16 d14, d16, d14\n" + "vrecps.f16 d16, %3, d14\n" + "vmul.f16 d14, d16, d14\n" + "vmul.f16 %0, %2, d14\n" + : "=w"(dst) + : "0"(dst), "w"(in1), "w"(in2) + : "d14", "d16"); + return dst; +} + +static inline float ms_vaddvq_f32(float32x4_t in) { + // is not support in arm82 aarch32 and there is no assembly instruction to process all the data + return in[0] + in[1] + in[2] + in[3]; +} + +static inline float16_t ms_vmaxvq_f16(float16x8_t in) { + // is not support in arm82 aarch32 and there is no assembly instruction to process all the data + float16_t dst = in[0]; + for (int i = 1; i < 8; ++i) { + dst = dst > in[i] ? dst : in[i]; + } + return dst; +} + +static inline float32x4_t ms_vcvt_f32_f16(float16x4_t in) { + float32x4_t dst; + asm volatile("vcvt.f32.f16 %0, %2\n" : "=w"(dst) : "0"(dst), "w"(in) :); + return dst; +} + +static inline float16x4_t ms_vcvt_f16_f32(float32x4_t in) { + float16x4_t dst; + asm volatile("vcvt.f16.f32 %0, %2\n" : "=w"(dst) : "0"(dst), "w"(in) :); + return dst; +} + +#define MS_CVT_F32_F16(src) ms_vcvt_f32_f16(src) +#define MS_CVT_F16_F32(src) ms_vcvt_f16_f32(src) +#define MS_DIV_F16(src1, src2) ms_vdiv_f16(src1, src2) +#define MS_DIVQ_F16(src1, src2) ms_vdivq_f16(src1, src2) +#define MS_FMAQ_N_F16(src1, src2, src3) vfmaq_f16(src1, src2, vdupq_n_f16(src3)) +#define MS_MAXVQ_F16(src) ms_vmaxvq_f16(src) +#define MS_ADDVQ_F32(src) ms_vaddvq_f32(src) +#else +#define MS_CVT_F32_F16(src) vcvt_f32_f16(src) +#define MS_CVT_F16_F32(src) vcvt_f16_f32(src) +#define MS_DIV_F16(src1, src2) vdiv_f16(src1, src2) +#define MS_DIVQ_F16(src1, src2) vdivq_f16(src1, src2) +#define MS_FMAQ_N_F16(src1, src2, src3) vfmaq_n_f16(src1, src2, src3) +#define MS_MAXVQ_F16(src) vmaxvq_f16(src) +#define MS_ADDVQ_F32(src) vaddvq_f32(src) +#endif + +#define MS_FLOAT16X8 float16x8_t +#define MS_FLOAT16X4 float16x4_t +#define MS_FLOAT16X4X4 float16x4x4_t +#define MS_FLOAT16X4X2 float16x4x2_t +#define MS_MOVQ_F16 vmovq_n_f16 +#define MS_STQ_F16(ptr, val) vst1q_f16(ptr, val) +#define MS_ST_F16 vst1_f16 +#define MS_ST2_F16 vst2_f16 +#define MS_ST4_F16 vst4_f16 +#define MS_MINQ_F16 vminq_f16 +#define MS_MAXQ_F16 vmaxq_f16 +#define MS_LDQ_F16(ptr) vld1q_f16(ptr) +#define MS_LD_F16(ptr) vld1_f16(ptr) +#define MS_ADDQ_F16 vaddq_f16 +#define MS_SUBQ_F16 vsubq_f16 +#define MS_MULQ_F16 vmulq_f16 +#define MS_FMAQ_F16 vfmaq_f16 +#define MS_MULQ_N_F16(vector, scalar) vmulq_n_f16(vector, scalar) +#define MS_CMPGTQ_F16(src1, src2) vcgtq_f16(src1, src2) + +static inline float16x8_t MS_TANHX8_F16(float16x8_t src) { + float32x4_t src_low = MS_CVT_F32_F16(vget_low_f16(src)); + float32x4_t src_high = MS_CVT_F32_F16(vget_high_f16(src)); + return vcombine_f16(MS_CVT_F16_F32(MS_TANHX4_F32(src_low)), MS_CVT_F16_F32(MS_TANHX4_F32(src_high))); +} + +static inline float16x8_t MS_ERFX8_F16(float16x8_t src) { + float16x8_t dst; + dst[0] = erff(src[0]); + dst[1] = erff(src[1]); + dst[2] = erff(src[2]); + dst[3] = erff(src[3]); + dst[4] = erff(src[4]); + dst[5] = erff(src[5]); + dst[6] = erff(src[6]); + dst[7] = erff(src[7]); + return dst; +} + +static inline float16x8_t MS_SQRTFX8_F16(float16x8_t src) { + float16x8_t dst; + dst[0] = sqrtf(src[0]); + dst[1] = sqrtf(src[1]); + dst[2] = sqrtf(src[2]); + dst[3] = sqrtf(src[3]); + dst[4] = sqrtf(src[4]); + dst[5] = sqrtf(src[5]); + dst[6] = sqrtf(src[6]); + dst[7] = sqrtf(src[7]); + return dst; +} + +static inline float16x4_t MS_SQRTFX4_F16(float16x4_t src) { + float16x4_t dst; + dst[0] = sqrtf(src[0]); + dst[1] = sqrtf(src[1]); + dst[2] = sqrtf(src[2]); + dst[3] = sqrtf(src[3]); + return dst; +} + +static inline float32x4_t MS_VMLAL_F16(float16x4_t x, float16x4_t dy, float32x4_t sum) { + float32x4_t x_fp32 = MS_CVT_F32_F16(x); + float32x4_t dy_fp32 = MS_CVT_F32_F16(dy); + return vmlaq_f32(sum, x_fp32, dy_fp32); +} + +#endif // NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_FP16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/ms_simd_neon_instructions.h b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/ms_simd_neon_instructions.h new file mode 100644 index 0000000000000000000000000000000000000000..53333c7f3338a541f3f4d4e6333e14b55d8a99c7 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/ms_simd_neon_instructions.h @@ -0,0 +1,362 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_NEON_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ +#define NNACL_NEON_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ +#include +#include + +#include + +#define MS_F32X4_GETI(src, i) src[i] +#define MS128_F32_GETI(src, i) src[i] +#define MS_FLOAT32X4 float32x4_t +#define MS_FLOAT32X4X2 float32x4x2_t +#define MS_FLOAT32X4X4 float32x4x4_t +#define MS_FLOAT128_F32 float32x4_t +#define MS_INT32X4 int32x4_t +#define MS_INT128_EPI32 int32x4_t +#define MS_UINT32X4 uint32x4_t +#define MS_MASK128_TYPE MS_UINT32X4 +#define MS_LDQ_F32 vld1q_f32 +#define MS_LD128_F32 vld1q_f32 +#define MS_LDQ_EPI32 vld1q_s32 +#define MS_LD128_EPI32 vld1q_s32 +#define MS_ADDQ_F32 vaddq_f32 +#define MS_ADD128_F32 vaddq_f32 +#define MS_ADDQ_EPI32 vaddq_s32 +#define MS_ADD128_EPI32 vaddq_s32 +#define MS_MOVQ_F32 vmovq_n_f32 +#define MS_MOV128_F32 vmovq_n_f32 +#define MS_MOVQ_EPI32 vmovq_n_s32 +#define MS_MOV128_VAL0_F32 vmovq_n_f32(0.0f) +#define MS_MOV128_EPI32 vmovq_n_s32 +#define MS_SUBQ_F32 vsubq_f32 +#define MS_SUB128_F32 vsubq_f32 +#define MS_SUB128_EPI32 vsubq_s32 +#define MS_MLAQ_F32(src1, src2, src3) vmlaq_f32(src1, src2, src3) +#define MS_STQ_F32 vst1q_f32 +#define MS_ST128_F32 vst1q_f32 +#define MS_STQ_EPI32 vst1q_s32 +#define MS_ST128_EPI32 vst1q_s32 +#define MS_MAXQ_F32 vmaxq_f32 +#define MS_MAXQ_EPI32 vmaxq_s32 +#define MS_MAX128_F32 vmaxq_f32 +#define MS_MAX128_EPI32 vmaxq_s32 +#define MS_MINQ_F32 vminq_f32 +#define MS_MINQ_EPI32 vminq_s32 +#define MS_MULQ_F32(src1, src2) vmulq_f32(src1, src2) +#define MS_MULQ_EPI32(src1, src2) vmulq_s32(src1, src2) +#define MS_MIN128_F32 vminq_f32 +#define MS_MIN128_EPI32 vminq_s32 +#define MS_MUL128_F32(src1, src2) vmulq_f32(src1, src2) +#define MS_MUL128_EPI32(src1, src2) vmulq_s32(src1, src2) +#define MS_FMADD128_F32(src1, src2, src3) vmlaq_f32(src3, src1, src2) +#define MS_FSMUL128_F32(src1, src2, src3) vmlsq_f32(src1, src2, src3) +#define MS_FMSUB128_EPI32(src1, src2, src3) vmlsq_s32(src3, src1, src2) +#ifdef ENABLE_ARM64 +#define MS_DIVQ_F32(src1, src2) vdivq_f32(src1, src2) +#define MS_DIV128_F32(src1, src2) vdivq_f32(src1, src2) +#else +static inline float32x4_t vrecp(float32x4_t v) { + float32x4_t r = vrecpeq_f32(v); + r = vmulq_f32(vrecpsq_f32(v, r), r); + r = vmulq_f32(vrecpsq_f32(v, r), r); + return r; +} +#define MS_DIVQ_F32(src1, src2) vmulq_f32(src1, vrecp(src2)) +#define MS_DIV128_F32(src1, src2) vmulq_f32(src1, vrecp(src2)) +#endif +#define MS_MULQ_N_F32(src1, src2) vmulq_n_f32(src1, src2) +#define MS_MULQ_N_EPI32(src1, src2) vmulq_n_s32(src1, src2) +#define MS_DIVQ_N_F32(src1, src2) vdivq_n_f32(src1, src2) +#define MS_SLLIQ_EPI32(src1, src2) vshlq_s32(src1, vmovq_n_s32(src2)) +#define MS_CVTQPS_EPI32(src) vcvtq_s32_f32(src) +#define MS_CVTQEPI32_PS(src) vcvtq_f32_s32(src) +#define MS_CMPLEQ_F32(src1, src2) vcleq_f32(src1, src2) +#define MS_CMPGTQ_F32(src1, src2) vcgtq_f32(src1, src2) +#define MS_CMPGTQ_EPI32(src1, src2) vcgtq_s32(src1, src2) +#define MS_CMPLE128_F32(src1, src2) vcleq_f32(src1, src2) +#define MS_CMPLT128_F32(src1, src2) vcltq_f32(src1, src2) +#define MS_CMPGT128_F32(src1, src2) vcgtq_f32(src1, src2) +#define MS_CMPGT128_EPI32(src1, src2) vcgtq_s32(src1, src2) +// Note: Compared with X86, the vbslq_f32 parameters are the opposite with _mm_blendv_f32 +#define MS_BLENDQ_F32(src1, src2, src3) vbslq_f32(src3, src2, src1) +#define MS_BLENDQ_EPI32(src1, src2, src3) vbslq_s32(src3, src2, src1) +#define MS_BLEND128_F32(src1, src2, src3) vbslq_f32(src3, src2, src1) +#define MS_BLEND128_EPI32(src1, src2, src3) vbslq_s32(src3, src2, src1) +#define MS_CAST128_F32_S32(src) vreinterpretq_f32_s32(src) +#define MS_AND128_MASK(src1, src2) vandq_u32(src1, src2) +#define MS_AND128_F32(src1, src2) \ + vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(src1), vreinterpretq_u32_f32(src2))) +#define MS_OR128_F32(src1, src2) \ + vreinterpretq_f32_u32(vorrq_u32(vreinterpretq_u32_f32(src1), vreinterpretq_u32_f32(src2))) +#define MS_CAST128_U32_F32(src) vreinterpretq_u32_f32(src) +#define MS_CAST128_F32_U32(src) vreinterpretq_f32_u32(src) +#define MS_OR128_MASK(src1, src2) vorrq_u32(src1, src2) + +#ifdef ENABLE_ARM64 +#define MS_GET_MAX128_F32 vmaxvq_f32 +static inline float MS_GET_SUM128_F32(MS_FLOAT32X4 src) { return vaddvq_f32(src); } +#else +static inline float MS_GET_MAX128_F32(MS_FLOAT32X4 src) { + float result = MS_F32X4_GETI(src, 0); + for (int i = 1; i < 4; i++) { // neon block num : 4 + result = fmaxf(result, MS_F32X4_GETI(src, i)); + } + return result; +} + +static inline float MS_GET_SUM128_F32(MS_FLOAT32X4 src) { + float result = MS_F32X4_GETI(src, 0); + for (int i = 1; i < 4; i++) { // neon block num : 4 + result = result + MS_F32X4_GETI(src, i); + } + return result; +} +#endif + +static inline MS_FLOAT32X4 MS_AND128_MASK_F32(MS_UINT32X4 src1, MS_FLOAT32X4 src2) { + MS_FLOAT32X4 result; + result[0] = (src1[0] == 0) ? 0.0f : src2[0]; + result[1] = (src1[1] == 0) ? 0.0f : src2[1]; + result[2] = (src1[2] == 0) ? 0.0f : src2[2]; + result[3] = (src1[3] == 0) ? 0.0f : src2[3]; + return result; +} + +static inline int32x4_t MS_DIV128_EPI32(int32x4_t src1, int32x4_t src2) { + int32x4_t result; + result[0] = src1[0] / src2[0]; // C0 : 0 + result[1] = src1[1] / src2[1]; // C1 : 1 + result[2] = src1[2] / src2[2]; // C2 : 2 + result[3] = src1[3] / src2[3]; // C3 : 3 + return result; +} + +#define MS128_INT32_TO_FLOAT32(src) vcvtq_f32_s32(src) +#define MS128_FLOAT32_TO_INT32(src) vcvtq_s32_f32(src) + +static inline MS_FLOAT32X4 MS_POW128_F32(MS_FLOAT32X4 src1, MS_FLOAT32X4 src2) { + MS_FLOAT32X4 dst; + MS_F32X4_GETI(dst, 0) = powf(MS_F32X4_GETI(src1, 0), MS_F32X4_GETI(src2, 0)); + MS_F32X4_GETI(dst, 1) = powf(MS_F32X4_GETI(src1, 1), MS_F32X4_GETI(src2, 1)); + MS_F32X4_GETI(dst, 2) = powf(MS_F32X4_GETI(src1, 2), MS_F32X4_GETI(src2, 2)); + MS_F32X4_GETI(dst, 3) = powf(MS_F32X4_GETI(src1, 3), MS_F32X4_GETI(src2, 3)); + return dst; +} + +static inline MS_FLOAT32X4 MS_ABS128_F32(MS_FLOAT32X4 src) { + MS_FLOAT32X4 dst; + MS_F32X4_GETI(dst, 0) = fabsf(MS_F32X4_GETI(src, 0)); + MS_F32X4_GETI(dst, 1) = fabsf(MS_F32X4_GETI(src, 1)); + MS_F32X4_GETI(dst, 2) = fabsf(MS_F32X4_GETI(src, 2)); + MS_F32X4_GETI(dst, 3) = fabsf(MS_F32X4_GETI(src, 3)); + return dst; +} + +static inline MS_FLOAT32X4 MS128_LOG_F32(MS_FLOAT32X4 src) { + MS_FLOAT32X4 dst; + MS_F32X4_GETI(dst, 0) = logf(MS_F32X4_GETI(src, 0)); + MS_F32X4_GETI(dst, 1) = logf(MS_F32X4_GETI(src, 1)); + MS_F32X4_GETI(dst, 2) = logf(MS_F32X4_GETI(src, 2)); + MS_F32X4_GETI(dst, 3) = logf(MS_F32X4_GETI(src, 3)); + return dst; +} + +static inline MS_FLOAT32X4 MS_SQRTFX4_F32(MS_FLOAT32X4 src) { + MS_FLOAT32X4 dst; + MS_F32X4_GETI(dst, 0) = sqrtf(MS_F32X4_GETI(src, 0)); + MS_F32X4_GETI(dst, 1) = sqrtf(MS_F32X4_GETI(src, 1)); + MS_F32X4_GETI(dst, 2) = sqrtf(MS_F32X4_GETI(src, 2)); + MS_F32X4_GETI(dst, 3) = sqrtf(MS_F32X4_GETI(src, 3)); + return dst; +} + +static inline MS_FLOAT32X4 MS_SQRT128_F32(MS_FLOAT32X4 src) { + MS_FLOAT32X4 dst; + MS_F32X4_GETI(dst, 0) = sqrtf(MS_F32X4_GETI(src, 0)); + MS_F32X4_GETI(dst, 1) = sqrtf(MS_F32X4_GETI(src, 1)); + MS_F32X4_GETI(dst, 2) = sqrtf(MS_F32X4_GETI(src, 2)); + MS_F32X4_GETI(dst, 3) = sqrtf(MS_F32X4_GETI(src, 3)); + return dst; +} +#define MS_RSQRT128_F32 vrsqrteq_f32 + +#define LOAD128X8_F32(src, input_ptr, num) \ + MS_FLOAT32X4 src##1 = MS_LDQ_F32(input_ptr + 0 * num); \ + MS_FLOAT32X4 src##2 = MS_LDQ_F32(input_ptr + 1 * num); \ + MS_FLOAT32X4 src##3 = MS_LDQ_F32(input_ptr + 2 * num); \ + MS_FLOAT32X4 src##4 = MS_LDQ_F32(input_ptr + 3 * num); \ + MS_FLOAT32X4 src##5 = MS_LDQ_F32(input_ptr + 4 * num); \ + MS_FLOAT32X4 src##6 = MS_LDQ_F32(input_ptr + 5 * num); \ + MS_FLOAT32X4 src##7 = MS_LDQ_F32(input_ptr + 6 * num); \ + MS_FLOAT32X4 src##8 = MS_LDQ_F32(input_ptr + 7 * num); + +#define STORE128X8_F32(output_ptr, num, dst) \ + MS_STQ_F32(output_ptr + 0 * num, dst##1); \ + MS_STQ_F32(output_ptr + 1 * num, dst##2); \ + MS_STQ_F32(output_ptr + 2 * num, dst##3); \ + MS_STQ_F32(output_ptr + 3 * num, dst##4); \ + MS_STQ_F32(output_ptr + 4 * num, dst##5); \ + MS_STQ_F32(output_ptr + 5 * num, dst##6); \ + MS_STQ_F32(output_ptr + 6 * num, dst##7); \ + MS_STQ_F32(output_ptr + 7 * num, dst##8); + +static inline MS_FLOAT32X4 VexpFp32(MS_FLOAT32X4 input) { + static MS_FLOAT32X4 param[] = { + {0.693147f, 0.693147f, 0.693147f, 0.693147f}, + {1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120}, + {1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24}, + {1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6}, + {0.5f, 0.5f, 0.5f, 0.5f}, + {1.0f, 1.0f, 1.0f, 1.0f}, + {1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f}, + {2.0f, 2.0f, 2.0f, 2.0f}}; + static MS_FLOAT32X4 negative_flag = {-0.0f, -0.0f, -0.0f, -0.0f}; + + MS_INT32X4 integer = + MS_CVTQPS_EPI32(MS_FMADD128_F32(input, param[6], MS_OR128_F32(MS_AND128_F32(input, negative_flag), param[4]))); + MS_FLOAT32X4 decimal = MS_SUBQ_F32(input, MS_MULQ_F32(MS_CVTQEPI32_PS(integer), param[0])); + MS_INT32X4 int_exp = MS_SLLIQ_EPI32(MS_ADDQ_EPI32(integer, MS_MOVQ_EPI32(126)), 23); + MS_FLOAT32X4 tmp = MS_MULQ_F32(decimal, (MS_ADDQ_F32(param[2], MS_MULQ_F32(decimal, param[1])))); + tmp = MS_MULQ_F32(decimal, MS_ADDQ_F32(param[4], MS_MULQ_F32(decimal, MS_ADDQ_F32(param[3], tmp)))); + MS_FLOAT32X4 decimal_exp = MS_ADDQ_F32(param[5], MS_MULQ_F32(decimal, MS_ADDQ_F32(param[5], tmp))); + return MS_MULQ_F32(param[7], MS_MULQ_F32(decimal_exp, MS_CAST128_F32_S32(int_exp))); +} + +static inline void simd_exp128(MS_FLOAT32X4 input, float *dst) { + static MS_FLOAT32X4 maxv = {88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f}; + static MS_FLOAT32X4 minv = {-87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f}; + input = MS_MAXQ_F32(minv, MS_MINQ_F32(input, maxv)); + MS_STQ_F32(dst, VexpFp32(input)); +} + +static inline MS_FLOAT32X4 simd_exp128_f32(MS_FLOAT32X4 input) { + static MS_FLOAT32X4 maxv = {88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f}; + static MS_FLOAT32X4 minv = {-87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f}; + input = MS_MAXQ_F32(minv, MS_MINQ_F32(input, maxv)); + return VexpFp32(input); +} + +static inline MS_FLOAT32X4 simd_hexp128_f32(MS_FLOAT32X4 src) { + MS_FLOAT32X4 dst; + MS_F32X4_GETI(dst, 0) = exp(MS_F32X4_GETI(src, 0)); + MS_F32X4_GETI(dst, 1) = exp(MS_F32X4_GETI(src, 1)); + MS_F32X4_GETI(dst, 2) = exp(MS_F32X4_GETI(src, 2)); + MS_F32X4_GETI(dst, 3) = exp(MS_F32X4_GETI(src, 3)); + return dst; +} + +static inline MS_FLOAT32X4 MS_TANHX4_F32(MS_FLOAT32X4 src) { + static const MS_FLOAT32X4 data0 = {378.0f, 378.0f, 378.0f, 378.0f}; + static const MS_FLOAT32X4 data1 = {17325.0f, 17325.0f, 17325.0f, 17325.0f}; + static const MS_FLOAT32X4 data2 = {135135.0f, 135135.0f, 135135.0f, 135135.0f}; + static const MS_FLOAT32X4 data3 = {28.0f, 28.0f, 28.0f, 28.0f}; + static const MS_FLOAT32X4 data4 = {3150.0f, 3150.0f, 3150.0f, 3150.0f}; + static const MS_FLOAT32X4 data5 = {62370.0f, 62370.0f, 62370.0f, 62370.0f}; + static const MS_FLOAT32X4 neg = {-1.0f, -1.0f, -1.0f, -1.0f}; + static const MS_FLOAT32X4 pos = {1.0f, 1.0f, 1.0f, 1.0f}; + static const MS_FLOAT32X4 up_limit = {5.0f, 5.0f, 5.0f, 5.0f}; + static const MS_FLOAT32X4 down_limit = {-5.0f, -5.0f, -5.0f, -5.0f}; + + MS_UINT32X4 up_mask = MS_CMPGTQ_F32(src, up_limit); + MS_UINT32X4 down_mask = MS_CMPGTQ_F32(down_limit, src); + + MS_FLOAT32X4 square = MS_MULQ_F32(src, src); + MS_FLOAT32X4 a = MS_MULQ_F32( + MS_ADDQ_F32(MS_MULQ_F32(MS_ADDQ_F32(MS_MULQ_F32(MS_ADDQ_F32(square, data0), square), data1), square), data2), src); + MS_FLOAT32X4 b = MS_ADDQ_F32( + MS_MULQ_F32(MS_ADDQ_F32(MS_MULQ_F32(MS_ADDQ_F32(MS_MULQ_F32(data3, square), data4), square), data5), square), + data2); + + MS_FLOAT32X4 tanh_value = MS_DIVQ_F32(a, b); + MS_FLOAT32X4 res = MS_BLENDQ_F32(tanh_value, pos, up_mask); + res = MS_BLENDQ_F32(res, neg, down_mask); + return res; +} + +static inline MS_FLOAT128_F32 SIMD_SIGN128_F32(MS_FLOAT128_F32 src) { + MS_FLOAT128_F32 abs_src = MS_ABS128_F32(src); + MS_FLOAT128_F32 src_tmp = MS_OR128_F32(src, MS_MOV128_F32(1.0f)); + MS_FLOAT128_F32 sign = MS_DIV128_F32(abs_src, src_tmp); + return sign; +} + +static inline MS_FLOAT128_F32 SIMD_SIGNABS128_F32(MS_FLOAT128_F32 src, MS_FLOAT128_F32 abs_src) { + MS_FLOAT128_F32 src_tmp = MS_OR128_F32(src, MS_MOV128_F32(1.0f)); + return MS_DIV128_F32(abs_src, src_tmp); +} + +#define MS_TANH128_F32 MS_TANHX4_F32 + +static inline MS_FLOAT32X4 MS128_ERF_F32(MS_FLOAT32X4 src) { + MS_FLOAT32X4 dst; + MS_F32X4_GETI(dst, 0) = erff(MS_F32X4_GETI(src, 0)); + MS_F32X4_GETI(dst, 1) = erff(MS_F32X4_GETI(src, 1)); + MS_F32X4_GETI(dst, 2) = erff(MS_F32X4_GETI(src, 2)); + MS_F32X4_GETI(dst, 3) = erff(MS_F32X4_GETI(src, 3)); + return dst; +} + +#define MS_FMADD128X8_F32(src, weight, dst) \ + dst##1 = MS_MLAQ_F32(src##1, weight, dst##1); \ + dst##2 = MS_MLAQ_F32(src##2, weight, dst##2); \ + dst##3 = MS_MLAQ_F32(src##3, weight, dst##3); \ + dst##4 = MS_MLAQ_F32(src##4, weight, dst##4); \ + dst##5 = MS_MLAQ_F32(src##5, weight, dst##5); \ + dst##6 = MS_MLAQ_F32(src##6, weight, dst##6); \ + dst##7 = MS_MLAQ_F32(src##7, weight, dst##7); \ + dst##8 = MS_MLAQ_F32(src##8, weight, dst##8); + +#define MS_LOAD128X4_F32(src, input_ptr, num) \ + MS_FLOAT32X4 src##1 = MS_LDQ_F32(input_ptr + 0 * num); \ + MS_FLOAT32X4 src##2 = MS_LDQ_F32(input_ptr + 1 * num); \ + MS_FLOAT32X4 src##3 = MS_LDQ_F32(input_ptr + 2 * num); \ + MS_FLOAT32X4 src##4 = MS_LDQ_F32(input_ptr + 3 * num); + +#define MS_FMADD128X4_F32(src, weight, dst) \ + dst##1 = MS_MLAQ_F32(src##1, weight, dst##1); \ + dst##2 = MS_MLAQ_F32(src##2, weight, dst##2); \ + dst##3 = MS_MLAQ_F32(src##3, weight, dst##3); \ + dst##4 = MS_MLAQ_F32(src##4, weight, dst##4); + +#define MS_LOAD128X8_F32(src, input_ptr, num) \ + MS_FLOAT32X4 src##1 = MS_LDQ_F32(input_ptr + 0 * num); \ + MS_FLOAT32X4 src##2 = MS_LDQ_F32(input_ptr + 1 * num); \ + MS_FLOAT32X4 src##3 = MS_LDQ_F32(input_ptr + 2 * num); \ + MS_FLOAT32X4 src##4 = MS_LDQ_F32(input_ptr + 3 * num); \ + MS_FLOAT32X4 src##5 = MS_LDQ_F32(input_ptr + 4 * num); \ + MS_FLOAT32X4 src##6 = MS_LDQ_F32(input_ptr + 5 * num); \ + MS_FLOAT32X4 src##7 = MS_LDQ_F32(input_ptr + 6 * num); \ + MS_FLOAT32X4 src##8 = MS_LDQ_F32(input_ptr + 7 * num); + +#define MS_SET_ZERO128X8_F32(dst) \ + MS_FLOAT32X4 dst##1 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##2 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##3 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##4 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##5 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##6 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##7 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##8 = MS_MOVQ_F32(0.0f); + +#define MS_SET_ZERO128X4_F32(dst) \ + MS_FLOAT32X4 dst##1 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##2 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##3 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##4 = MS_MOVQ_F32(0.0f); +#endif // NNACL_NEON_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/ms_simd_sse_instructions.h b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/ms_simd_sse_instructions.h new file mode 100644 index 0000000000000000000000000000000000000000..6eb07e256818385c221f939c10a4346ff6675ffc --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/ms_simd_sse_instructions.h @@ -0,0 +1,403 @@ +/** + * Copyright 2021-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_SSE_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ +#define NNACL_SSE_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ +#include + +#ifdef _MSC_VER +#include +#define MS_F32X4_GETI(src, i) src.m128_f32[i] +#define MS128_F32_GETI(src, i) src.m128_f32[i] +#else +#include +#define MS_F32X4_GETI(src, i) src[i] +#define MS128_F32_GETI(src, i) src[i] +#endif + +#define PI 3.1415926f +#define LN2 0.693147f + +#define MS_FLOAT32X4 __m128 +#define MS_FLOAT128_F32 __m128 +#define MS_INT32X4 __m128i +#define MS_INT128_EPI32 __m128i +#define MS_MASK128_TYPE MS_FLOAT32X4 +#define MS_LDQ_F32 _mm_loadu_ps +#define MS_LD128_F32 _mm_loadu_ps +#define MS_LDQ_EPI32(src) _mm_loadu_si128((__m128i const *)(src)) +#define MS_LD128_EPI32(src) _mm_loadu_si128((__m128i const *)(src)) +#define MS_ADDQ_F32 _mm_add_ps +#define MS_ADD128_F32 _mm_add_ps +#define MS_ADDQ_EPI32 _mm_add_epi32 +#define MS_ADD128_EPI32 _mm_add_epi32 +#define MS_MOVQ_F32 _mm_set1_ps +#define MS_MOV128_F32 _mm_set1_ps +#define MS_MOVQ_EPI32 _mm_set1_epi32 +#define MS_MOV128_EPI32 _mm_set1_epi32 +#define MS_MOV128_VAL0_F32 _mm_setzero_ps() +#define MS_MLAQ_F32(src1, src2, src3) _mm_add_ps(src1, _mm_mul_ps(src2, src3)) +#define MS_STQ_F32 _mm_storeu_ps +#define MS_ST128_F32 _mm_storeu_ps +#define MS_STQ_EPI32(src1, src2) _mm_storeu_si128((__m128i *)(src1), src2) +#define MS_ST128_EPI32(src1, src2) _mm_storeu_si128((__m128i *)(src1), src2) +#define MS_SUBQ_F32 _mm_sub_ps +#define MS_SUB128_F32 _mm_sub_ps +#define MS_SUB128_EPI32 _mm_sub_epi32 +#define MS_MAXQ_F32 _mm_max_ps +#define MS_MAXQ_EPI32 _mm_max_epi32 +#define MS_MAX128_F32 _mm_max_ps +#define MS_MAX128_EPI32 _mm_max_epi32 +#define MS_MINQ_F32 _mm_min_ps +#define MS_MINQ_EPI32 _mm_min_epi32 +#define MS_SQRT128_F32 _mm_sqrt_ps +#define MS_RSQRT128_F32 _mm_rsqrt_ps +#define MS_SIN128_F32 _mm_sin_ps +#define MS_ERF128_F32 _mm_erf_ps +#define MS_ROUND128_F32(src) _mm_round_ps(src, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC) +#define MS_FLOOR128_F32 _mm_floor_ps +#define MS_CEIL128_F32 _mm_ceil_ps +#define MS_MULQ_F32(src1, src2) _mm_mul_ps(src1, src2) +#define MS_MULQ_EPI32(src1, src2) _mm_mullo_epi32(src1, src2) +#define MS_MIN128_F32 _mm_min_ps +#define MS_MIN128_EPI32 _mm_min_epi32 +#define MS_MUL128_F32(src1, src2) _mm_mul_ps(src1, src2) +#define MS_MUL128_EPI32(src1, src2) _mm_mullo_epi32(src1, src2) +#define MS_DIVQ_F32(src1, src2) _mm_div_ps(src1, src2) +#define MS_DIV128_F32(src1, src2) _mm_div_ps(src1, src2) +#define MS_MULQ_N_F32(src1, src2) _mm_mul_ps(src1, _mm_set1_ps(src2)) +#define MS_MULQ_N_EPI32(src1, src2) _mm_mullo_epi32(src1, _mm_set1_epi32(src2)) +#define MS_DIVQ_N_F32(src1, src2) _mm_div_ps(src1, _mm_set1_ps(src2)) +#define MS_SLLIQ_EPI32(src1, src2) _mm_slli_epi32(src1, src2) +#define MS_CVTQPS_EPI32(src) _mm_cvttps_epi32(src) // truncate float to int +#define MS_CVTQEPI32_PS(src) _mm_cvtepi32_ps(src) +#define MS_CMPLEQ_F32(src1, src2) _mm_cmple_ps(src1, src2) +#define MS_CMPGTQ_F32(src1, src2) _mm_cmpgt_ps(src1, src2) +#define MS_CMPGTQ_EPI32(src1, src2) _mm_cmpgt_epi32(src1, src2) +#define MS_BLENDQ_F32(src1, src2, src3) _mm_blendv_ps(src1, src2, src3) +#define MS_BLENDQ_EPI32(src1, src2, src3) _mm_blendv_epi8(src1, src2, src3) +#define MS_CMPLT128_F32(src1, src2) _mm_cmplt_ps(src1, src2) +#define MS_CMPLE128_F32(src1, src2) _mm_cmple_ps(src1, src2) +#define MS_CMPGT128_F32(src1, src2) _mm_cmpgt_ps(src1, src2) +#define MS_CMPEQ128_F32(src1, src2) _mm_cmpeq_ps(src1, src2) +#define MS_CMPUNORD128_F32(src1, src2) _mm_cmpunord_ps(src1, src2) +#define MS_CMPGT128_EPI32(src1, src2) _mm_cmpgt_epi32(src1, src2) +#define MS_BLEND128_F32(src1, src2, src3) _mm_blendv_ps(src1, src2, src3) +#define MS_BLEND128_EPI32(src1, src2, src3) _mm_blendv_epi8(src1, src2, src3) +#define MS_CAST128_F32_S32(src) _mm_castsi128_ps(src) +#define MS_DIV128_EPI32(src1, src2) _mm_cvttps_epi32(MS_DIV128_F32(_mm_cvtepi32_ps(src1), _mm_cvtepi32_ps(src2))) +#define MS_AND128_MASK(src1, src2) _mm_and_ps(src1, src2) +#define MS_OR128_F32(src1, src2) _mm_or_ps(src1, src2) +#define MS_AND128_MASK_F32(src1, src2) _mm_and_ps(src1, src2) +#define MS_AND128_F32(src1, src2) _mm_and_ps(src1, src2) + +#define MS128_ANDNOT_F32(src1, src2) _mm_andnot_ps(src1, src2) +#define MS128_SRLI_EPI32(src1, src2) _mm_srli_epi32(src1, src2) +#define MS128_AND_EPI32(src1, src2) _mm_and_si128(src1, src2) +#define MS128_CASTPS_EPI32(src) _mm_castps_si128(src) +#define MS_CVT128EPI32_PS(src) _mm_cvtepi32_ps(src) +#define MS_CAST128_F32_S32(src) _mm_castsi128_ps(src) + +static inline MS_FLOAT32X4 MS_POW128_F32(MS_FLOAT32X4 src1, MS_FLOAT32X4 src2) { + MS_FLOAT32X4 dst; + MS_F32X4_GETI(dst, 0) = powf(MS_F32X4_GETI(src1, 0), MS_F32X4_GETI(src2, 0)); + MS_F32X4_GETI(dst, 1) = powf(MS_F32X4_GETI(src1, 1), MS_F32X4_GETI(src2, 1)); + MS_F32X4_GETI(dst, 2) = powf(MS_F32X4_GETI(src1, 2), MS_F32X4_GETI(src2, 2)); + MS_F32X4_GETI(dst, 3) = powf(MS_F32X4_GETI(src1, 3), MS_F32X4_GETI(src2, 3)); + return dst; +} + +#ifdef ENABLE_AVX // only enable sse, dont support fma instruction. +#define MS_FMADD128_F32(src1, src2, src3) _mm_fmadd_ps(src1, src2, src3) +#define MS_FMSUB128_F32(src1, src2, src3) _mm_fmsub_ps(src1, src2, src3) +#define MS_FSMUL128_F32(src1, src2, src3) _mm_fnmadd_ps(src3, src2, src1) +#else +#define MS_FMADD128_F32(src1, src2, src3) _mm_add_ps(_mm_mul_ps(src1, src2), src3) +#define MS_FMSUB128_F32(src1, src2, src3) _mm_sub_ps(_mm_mul_ps(src1, src2), src3) +#define MS_FSMUL128_F32(src1, src2, src3) _mm_sub_ps(src1, _mm_mul_ps(src2, src3)) +#endif + +#define MS128_INT16_TO_FLOAT16(src) _mm_cvtepi16_ph(src) +#define MS128_FLOAT16_TO_INT16(src) _mm_cvttph_epi16(src) + +#define MS128_INT32_TO_FLOAT16(src) _mm_cvtepi32_ph(src) +#define MS128_FLOAT16_TO_INT32(src) _mm_cvttph_epi32(src) + +#define MS128_INT32_TO_FLOAT32(src) _mm_cvtepi32_ps(src) +#define MS128_FLOAT32_TO_INT32(src) _mm_cvttps_epi32(src) + +#define MS128_INT64_TO_FLOAT32(src) _mm_cvtepi64_ps(src) +#define MS128_FLOAT32_TO_INT64(src) _mm_cvttps_epi64(src) + +#define MS128_INT64_TO_FLOAT16(src) _mm_cvtepi64_ph(src) +#define MS128_FLOAT16_TO_INT64(src) _mm_cvttph_epi64(src) + +#define MS128_INT32_TO_FLOAT64(src) _mm_cvtepi32_pd(src) +#define MS128_FLOAT64_TO_INT32(src) _mm_cvttpd_epi32(src) + +#define MS128_INT64_TO_FLOAT64(src) _mm_cvtepi64_pd(src) +#define MS128_FLOAT64_TO_INT64(src) _mm_cvttpd_epi64(src) + +#define MS128_INT16_TO_INT32(src) _mm128_cvtepi16_epi32(src) +#define MS128_INT16_TO_INT64(src) _mm128_cvtepi16_epi64(src) +#define MS128_INT32_TO_INT16(src) _mm128_cvtepi32_epi16(src) +#define MS128_INT32_TO_INT64(src) _mm128_cvtepi32_epi64(src) +#define MS128_INT64_TO_INT16(src) _mm128_cvtepi64_epi16(src) +#define MS128_INT64_TO_INT32(src) _mm128_cvtepi64_epi32(src) + +static inline MS_FLOAT32X4 MS_ABS128_F32(MS_FLOAT32X4 src) { + MS_FLOAT32X4 dst; + MS_F32X4_GETI(dst, 0) = fabsf(MS_F32X4_GETI(src, 0)); + MS_F32X4_GETI(dst, 1) = fabsf(MS_F32X4_GETI(src, 1)); + MS_F32X4_GETI(dst, 2) = fabsf(MS_F32X4_GETI(src, 2)); + MS_F32X4_GETI(dst, 3) = fabsf(MS_F32X4_GETI(src, 3)); + return dst; +} + +static inline MS_FLOAT128_F32 SIMD_SIGN128_F32(MS_FLOAT128_F32 src) { + MS_FLOAT128_F32 abs_src = MS_ABS128_F32(src); + MS_FLOAT128_F32 sign = MS_DIV128_F32(abs_src, src); + return sign; +} + +#define SIMD_SIGNABS128_F32(src, abs_src) MS_DIV128_F32(abs_src, src) + +static inline MS_FLOAT32X4 MS_COS128_F32(MS_FLOAT32X4 src) { + static const MS_FLOAT32X4 pi = {PI, PI, PI, PI}; + static const MS_FLOAT32X4 pi2_neg = {-2 * PI, -2 * PI, -2 * PI, -2 * PI}; + static const MS_FLOAT32X4 div_pi2 = {1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI), 1.0f / (2 * PI)}; + MS_FLOAT128_F32 src_abs = MS_ABS128_F32(src); + MS_FLOAT128_F32 src_cycle = + MS_ADD128_F32(MS_MUL128_F32(MS_FLOOR128_F32(MS_MUL128_F32(MS_ADD128_F32(src_abs, pi), div_pi2)), pi2_neg), src_abs); + static const MS_FLOAT128_F32 data0 = {1.0f / 90, 1.0f / 90, 1.0f / 90, 1.0f / 90}; + static const MS_FLOAT128_F32 data1 = {1.0f / 56, 1.0f / 56, 1.0f / 56, 1.0f / 56}; + static const MS_FLOAT128_F32 data2 = {1.0f / 30, 1.0f / 30, 1.0f / 30, 1.0f / 30}; + static const MS_FLOAT128_F32 data3 = {1.0f / 12, 1.0f / 12, 1.0f / 12, 1.0f / 12}; + static const MS_FLOAT128_F32 data4 = {0.5f, 0.5f, 0.5f, 0.5f}; + static const MS_FLOAT32X4 neg = {-1.0f, -1.0f, -1.0f, -1.0f}; + static const MS_FLOAT32X4 pos = {1.0f, 1.0f, 1.0f, 1.0f}; + MS_FLOAT32X4 square = MS_MUL128_F32(src_cycle, src_cycle); + + MS_FLOAT32X4 tmp = + MS_MUL128_F32(MS_MUL128_F32(MS_ADD128_F32(MS_MUL128_F32(MS_MUL128_F32(neg, square), data0), pos), square), data1); + MS_FLOAT32X4 tmp1 = MS_MUL128_F32(MS_MUL128_F32(MS_ADD128_F32(tmp, neg), square), data2); + MS_FLOAT128_F32 res = MS_ADD128_F32( + MS_MUL128_F32( + MS_MUL128_F32(MS_ADD128_F32(MS_MUL128_F32(MS_MUL128_F32(MS_ADD128_F32(tmp1, pos), square), data3), neg), square), + data4), + pos); + return res; +} + +static inline MS_FLOAT32X4 MS128_LOG_F32(MS_FLOAT32X4 src) { + const MS_INT128_EPI32 gFloatExpMask = MS_MOV128_EPI32(0xffULL << 23); + const MS_INT128_EPI32 gFloatExp0 = MS_MOV128_EPI32(127ULL << 23); + const MS_INT128_EPI32 gExpNormalizer = MS_MOV128_EPI32(127); + static const MS_FLOAT128_F32 data0 = {1.0f / 11, 1.0f / 11, 1.0f / 11, 1.0f / 11}; + static const MS_FLOAT128_F32 data1 = {1.0f / 9, 1.0f / 9, 1.0f / 9, 1.0f / 9}; + static const MS_FLOAT128_F32 data2 = {1.0f / 7, 1.0f / 7, 1.0f / 7, 1.0f / 7}; + static const MS_FLOAT128_F32 data3 = {0.2f, 0.2f, 0.2f, 0.2f}; + static const MS_FLOAT128_F32 data4 = {1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3}; + static const MS_FLOAT128_F32 data5 = {1.0f, 1.0f, 1.0f, 1.0f}; + static const MS_FLOAT128_F32 data6 = {2.0f, 2.0f, 2.0f, 2.0f}; + static const MS_FLOAT32X4 neg = {-1.0f, -1.0f, -1.0f, -1.0f}; + static const MS_FLOAT32X4 pos = {1.0f, 1.0f, 1.0f, 1.0f}; + static const MS_FLOAT32X4 ln2 = {LN2, LN2, LN2, LN2}; + + const MS_INT128_EPI32 exps32 = MS128_SRLI_EPI32(MS128_AND_EPI32(gFloatExpMask, MS128_CASTPS_EPI32(src)), 23); + const MS_INT128_EPI32 normExps = MS_SUB128_EPI32(exps32, gExpNormalizer); + const MS_FLOAT32X4 expsPD = MS_CVT128EPI32_PS(normExps); + const MS_FLOAT32X4 y = + MS_OR128_F32(MS_CAST128_F32_S32(gFloatExp0), MS128_ANDNOT_F32(MS_CAST128_F32_S32(gFloatExpMask), src)); + MS_FLOAT32X4 div = MS_DIV128_F32(MS_ADD128_F32(y, neg), MS_ADD128_F32(y, pos)); + MS_FLOAT32X4 square = MS_MUL128_F32(div, div); + + MS_FLOAT32X4 tmp = MS_ADD128_F32( + MS_MUL128_F32(MS_ADD128_F32(MS_MUL128_F32(square, MS_ADD128_F32(MS_MUL128_F32(square, data0), data1)), data2), + square), + data3); + MS_FLOAT32X4 tmp1 = MS_MUL128_F32(square, MS_ADD128_F32(MS_MUL128_F32(square, tmp), data4)); + MS_FLOAT32X4 res = + MS_ADD128_F32(MS_MUL128_F32(ln2, expsPD), MS_MUL128_F32(MS_MUL128_F32(div, MS_ADD128_F32(tmp1, data5)), data6)); + MS_FLOAT32X4 mask = MS_CMPEQ128_F32(src, MS_MOV128_F32(0.0f)); + res = MS_BLEND128_F32(res, MS_MOV128_F32(-INFINITY), mask); + mask = MS_CMPEQ128_F32(src, MS_MOV128_F32(INFINITY)); + res = MS_BLEND128_F32(res, MS_MOV128_F32(INFINITY), mask); + mask = MS_OR128_F32(MS_CMPLT128_F32(src, MS_MOV128_F32(0.0f)), MS_CMPUNORD128_F32(src, MS_MOV128_F32(0.0f))); + res = MS_BLEND128_F32(res, MS_MOV128_F32(NAN), mask); + return res; +} + +static inline MS_FLOAT32X4 MS_SQRTFX4_F32(MS_FLOAT32X4 src) { + MS_FLOAT32X4 dst; + MS_F32X4_GETI(dst, 0) = sqrtf(MS_F32X4_GETI(src, 0)); + MS_F32X4_GETI(dst, 1) = sqrtf(MS_F32X4_GETI(src, 1)); + MS_F32X4_GETI(dst, 2) = sqrtf(MS_F32X4_GETI(src, 2)); + MS_F32X4_GETI(dst, 3) = sqrtf(MS_F32X4_GETI(src, 3)); + return dst; +} + +static inline float MS_GET_MAX128_F32(__m128 src) { + float result = MS_F32X4_GETI(src, 0); + for (int i = 1; i < 4; i++) { // sse block num : 4 + result = fmaxf(result, MS_F32X4_GETI(src, i)); + } + return result; +} + +static inline float MS_GET_SUM128_F32(__m128 src) { + float result = MS_F32X4_GETI(src, 0); + for (int i = 1; i < 4; i++) { // sse block num : 4 + result = result + MS_F32X4_GETI(src, i); + } + return result; +} + +#define STORE128X8_F32(output_ptr, num, dst) \ + MS_STQ_F32(output_ptr + 0 * num, dst##1); \ + MS_STQ_F32(output_ptr + 1 * num, dst##2); \ + MS_STQ_F32(output_ptr + 2 * num, dst##3); \ + MS_STQ_F32(output_ptr + 3 * num, dst##4); \ + MS_STQ_F32(output_ptr + 4 * num, dst##5); \ + MS_STQ_F32(output_ptr + 5 * num, dst##6); \ + MS_STQ_F32(output_ptr + 6 * num, dst##7); \ + MS_STQ_F32(output_ptr + 7 * num, dst##8); + +static inline MS_FLOAT32X4 VexpFp32(MS_FLOAT32X4 input) { + static MS_FLOAT32X4 param[] = { + {0.693147f, 0.693147f, 0.693147f, 0.693147f}, + {1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120}, + {1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24}, + {1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6}, + {0.5f, 0.5f, 0.5f, 0.5f}, + {1.0f, 1.0f, 1.0f, 1.0f}, + {1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f}, + {2.0f, 2.0f, 2.0f, 2.0f}}; + MS_INT32X4 integer = MS_CVTQPS_EPI32(MS_FLOOR128_F32(MS_FMADD128_F32(input, param[6], param[4]))); + MS_FLOAT32X4 decimal = MS_SUBQ_F32(input, MS_MULQ_F32(MS_CVTQEPI32_PS(integer), param[0])); + MS_INT32X4 int_exp = MS_SLLIQ_EPI32(MS_ADDQ_EPI32(integer, MS_MOVQ_EPI32(126)), 23); + MS_FLOAT32X4 tmp = MS_MULQ_F32(decimal, (MS_ADDQ_F32(param[2], MS_MULQ_F32(decimal, param[1])))); + tmp = MS_MULQ_F32(decimal, MS_ADDQ_F32(param[4], MS_MULQ_F32(decimal, MS_ADDQ_F32(param[3], tmp)))); + MS_FLOAT32X4 decimal_exp = MS_ADDQ_F32(param[5], MS_MULQ_F32(decimal, MS_ADDQ_F32(param[5], tmp))); + return MS_MULQ_F32(param[7], MS_MULQ_F32(decimal_exp, MS_CAST128_F32_S32(int_exp))); +} + +static inline void simd_exp128(MS_FLOAT32X4 input, float *dst) { + static MS_FLOAT32X4 maxv = {88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f}; + static MS_FLOAT32X4 minv = {-87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f}; + input = MS_MAXQ_F32(minv, MS_MINQ_F32(input, maxv)); + MS_STQ_F32(dst, VexpFp32(input)); +} + +static inline MS_FLOAT32X4 simd_exp128_f32(MS_FLOAT32X4 input) { + static MS_FLOAT32X4 maxv = {88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f}; + static MS_FLOAT32X4 minv = {-87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f}; + input = MS_MAXQ_F32(minv, MS_MINQ_F32(input, maxv)); + return VexpFp32(input); +} + +static inline MS_FLOAT32X4 simd_hexp128_f32(MS_FLOAT32X4 src) { + MS_FLOAT32X4 dst; + MS_F32X4_GETI(dst, 0) = exp(MS_F32X4_GETI(src, 0)); + MS_F32X4_GETI(dst, 1) = exp(MS_F32X4_GETI(src, 1)); + MS_F32X4_GETI(dst, 2) = exp(MS_F32X4_GETI(src, 2)); + MS_F32X4_GETI(dst, 3) = exp(MS_F32X4_GETI(src, 3)); + return dst; +} + +static inline MS_FLOAT32X4 MS_TANHX4_F32(MS_FLOAT32X4 src) { + static const MS_FLOAT32X4 data0 = {378.0f, 378.0f, 378.0f, 378.0f}; + static const MS_FLOAT32X4 data1 = {17325.0f, 17325.0f, 17325.0f, 17325.0f}; + static const MS_FLOAT32X4 data2 = {135135.0f, 135135.0f, 135135.0f, 135135.0f}; + static const MS_FLOAT32X4 data3 = {28.0f, 28.0f, 28.0f, 28.0f}; + static const MS_FLOAT32X4 data4 = {3150.0f, 3150.0f, 3150.0f, 3150.0f}; + static const MS_FLOAT32X4 data5 = {62370.0f, 62370.0f, 62370.0f, 62370.0f}; + static const MS_FLOAT32X4 neg = {-1.0f, -1.0f, -1.0f, -1.0f}; + static const MS_FLOAT32X4 pos = {1.0f, 1.0f, 1.0f, 1.0f}; + MS_FLOAT32X4 square = MS_MULQ_F32(src, src); + MS_FLOAT32X4 a = + MS_MUL128_F32(MS_FMADD128_F32(MS_FMADD128_F32(MS_ADD128_F32(square, data0), square, data1), square, data2), src); + MS_FLOAT32X4 b = + MS_FMADD128_F32(MS_FMADD128_F32(MS_FMADD128_F32(data3, square, data4), square, data5), square, data2); + MS_FLOAT32X4 res = MS_DIVQ_F32(a, b); + MS_FLOAT32X4 up_limit = MS_MOV128_F32(5.0f); + MS_FLOAT32X4 down_limit = MS_MOV128_F32(-5.0f); + MS_FLOAT32X4 up_mask = MS_CMPGT128_F32(src, up_limit); + MS_FLOAT32X4 down_mask = MS_CMPLT128_F32(src, down_limit); + res = MS_BLEND128_F32(res, pos, up_mask); + res = MS_BLEND128_F32(res, neg, down_mask); + return res; +} + +#define MS_TANH128_F32 MS_TANHX4_F32 + +static inline MS_FLOAT32X4 MS128_ERF_F32(MS_FLOAT32X4 src) { + MS_FLOAT32X4 dst; + MS_F32X4_GETI(dst, 0) = erff(MS_F32X4_GETI(src, 0)); + MS_F32X4_GETI(dst, 1) = erff(MS_F32X4_GETI(src, 1)); + MS_F32X4_GETI(dst, 2) = erff(MS_F32X4_GETI(src, 2)); + MS_F32X4_GETI(dst, 3) = erff(MS_F32X4_GETI(src, 3)); + return dst; +} + +#define MS_FMADD128X8_F32(src, weight, dst) \ + dst##1 = MS_MLAQ_F32(src##1, weight, dst##1); \ + dst##2 = MS_MLAQ_F32(src##2, weight, dst##2); \ + dst##3 = MS_MLAQ_F32(src##3, weight, dst##3); \ + dst##4 = MS_MLAQ_F32(src##4, weight, dst##4); \ + dst##5 = MS_MLAQ_F32(src##5, weight, dst##5); \ + dst##6 = MS_MLAQ_F32(src##6, weight, dst##6); \ + dst##7 = MS_MLAQ_F32(src##7, weight, dst##7); \ + dst##8 = MS_MLAQ_F32(src##8, weight, dst##8); + +#define MS_LOAD128X4_F32(src, input_ptr, num) \ + MS_FLOAT32X4 src##1 = MS_LDQ_F32(input_ptr + 0 * num); \ + MS_FLOAT32X4 src##2 = MS_LDQ_F32(input_ptr + 1 * num); \ + MS_FLOAT32X4 src##3 = MS_LDQ_F32(input_ptr + 2 * num); \ + MS_FLOAT32X4 src##4 = MS_LDQ_F32(input_ptr + 3 * num); + +#define MS_FMADD128X4_F32(src, weight, dst) \ + dst##1 = MS_MLAQ_F32(src##1, weight, dst##1); \ + dst##2 = MS_MLAQ_F32(src##2, weight, dst##2); \ + dst##3 = MS_MLAQ_F32(src##3, weight, dst##3); \ + dst##4 = MS_MLAQ_F32(src##4, weight, dst##4); + +#define MS_LOAD128X8_F32(src, input_ptr, num) \ + MS_FLOAT32X4 src##1 = MS_LDQ_F32(input_ptr + 0 * num); \ + MS_FLOAT32X4 src##2 = MS_LDQ_F32(input_ptr + 1 * num); \ + MS_FLOAT32X4 src##3 = MS_LDQ_F32(input_ptr + 2 * num); \ + MS_FLOAT32X4 src##4 = MS_LDQ_F32(input_ptr + 3 * num); \ + MS_FLOAT32X4 src##5 = MS_LDQ_F32(input_ptr + 4 * num); \ + MS_FLOAT32X4 src##6 = MS_LDQ_F32(input_ptr + 5 * num); \ + MS_FLOAT32X4 src##7 = MS_LDQ_F32(input_ptr + 6 * num); \ + MS_FLOAT32X4 src##8 = MS_LDQ_F32(input_ptr + 7 * num); + +#define MS_SET_ZERO128X8_F32(dst) \ + MS_FLOAT32X4 dst##1 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##2 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##3 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##4 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##5 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##6 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##7 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##8 = MS_MOVQ_F32(0.0f); + +#define MS_SET_ZERO128X4_F32(dst) \ + MS_FLOAT32X4 dst##1 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##2 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##3 = MS_MOVQ_F32(0.0f); \ + MS_FLOAT32X4 dst##4 = MS_MOVQ_F32(0.0f); +#endif // NNACL_SSE_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/sse/ConvDwFp32IndirectRow.c b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/sse/ConvDwFp32IndirectRow.c new file mode 100644 index 0000000000000000000000000000000000000000..d0a25ab1b7dbedc1ff2db17ffd64dc66e8e09070 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/sse/ConvDwFp32IndirectRow.c @@ -0,0 +1,120 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_AVX +#ifdef _MSC_VER +#include +#else +#include +#endif +#include "nnacl/fp32/conv_depthwise_fp32.h" + +#define INPUT_SIZE 25 + +void ConvDwFp32Avx5x5(float *output, float **input, const float *weights, const float *bias, size_t channels, + size_t output_width, size_t input_stride, size_t relu, size_t relu6) { + input_stride /= sizeof(float *); + size_t c8 = UP_DIV(channels, C8NUM) * C8NUM; + size_t c8_mod = channels % C8NUM; + float *in[INPUT_SIZE]; + for (int i = 0; i < output_width; ++i) { + for (int k = 0; k < INPUT_SIZE; k++) { + in[k] = input[k]; + } + input += input_stride; + size_t c = c8; + const float *w = weights; + const float *bias1 = bias; + for (; c >= C8NUM; c -= C8NUM) { + __m256 out1 = _mm256_loadu_ps(bias1); + bias1 += 8; + for (int k = 0; k < INPUT_SIZE; k += 5) { + __m256 in1 = _mm256_loadu_ps(in[k]); + __m256 w1 = _mm256_loadu_ps(w); + __m256 in2 = _mm256_loadu_ps(in[k + 1]); + __m256 w2 = _mm256_loadu_ps(w + 8); + out1 = _mm256_fmadd_ps(in1, w1, out1); + __m256 in3 = _mm256_loadu_ps(in[k + 2]); + __m256 w3 = _mm256_loadu_ps(w + 16); + out1 = _mm256_fmadd_ps(in2, w2, out1); + __m256 in4 = _mm256_loadu_ps(in[k + 3]); + __m256 w4 = _mm256_loadu_ps(w + 24); + out1 = _mm256_fmadd_ps(in3, w3, out1); + __m256 in5 = _mm256_loadu_ps(in[k + 4]); + __m256 w5 = _mm256_loadu_ps(w + 32); + out1 = _mm256_fmadd_ps(in4, w4, out1); + w += 40; + in[k] += C8NUM; + in[k + 1] += C8NUM; + in[k + 2] += C8NUM; + in[k + 3] += C8NUM; + in[k + 4] += C8NUM; + out1 = _mm256_fmadd_ps(in5, w5, out1); + } + if (relu6 != 0) { + __m256 relu6_data = _mm256_set1_ps(6.0); + out1 = _mm256_min_ps(out1, relu6_data); + } + if (relu != 0 || relu6 != 0) { + __m256 zero = _mm256_setzero_ps(); + out1 = _mm256_max_ps(out1, zero); + } + if (c > C8NUM || c8_mod == 0) { + _mm256_storeu_ps(output, out1); + output += C8NUM; + } else { + __m128 tmp; + switch (c8_mod) { + case 1: + _mm_store_ss(output, _mm256_castps256_ps128(out1)); + break; + case 2: + _mm_storel_pi((__m64 *)output, _mm256_castps256_ps128(out1)); + break; + case 3: + tmp = _mm256_castps256_ps128(out1); + _mm_storel_pi((__m64 *)output, tmp); + tmp = _mm_unpackhi_ps(tmp, tmp); + _mm_store_ss(output + 2, tmp); + break; + case 4: + _mm_storeu_ps(output, _mm256_castps256_ps128(out1)); + break; + case 5: + _mm_storeu_ps(output, _mm256_castps256_ps128(out1)); + _mm_store_ss(output + 4, _mm256_extractf128_ps(out1, 1)); + break; + case 6: + _mm_storeu_ps(output, _mm256_castps256_ps128(out1)); + _mm_storel_pi((__m64 *)(output + 4), _mm256_extractf128_ps(out1, 1)); + break; + case 7: + _mm_storeu_ps(output, _mm256_castps256_ps128(out1)); + tmp = _mm256_extractf128_ps(out1, 1); + _mm_storel_pi((__m64 *)(output + 4), tmp); + tmp = _mm_unpackhi_ps(tmp, tmp); + _mm_store_ss(output + 6, tmp); + break; + default: + _mm256_storeu_ps(output, out1); + break; + } + output += c8_mod; + } + } + } +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/sse/ConvDwFp32Row_sse.c b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/sse/ConvDwFp32Row_sse.c new file mode 100644 index 0000000000000000000000000000000000000000..d5194d381f6d0282fd781b13fbf7e1b636374e1e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/sse/ConvDwFp32Row_sse.c @@ -0,0 +1,86 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#if defined(ENABLE_SSE) && !defined(ENABLE_AVX) +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/fp32/common_func_fp32.h" + +void ConvDwFp32Row(float *output_ptr, const float *input_ptr, const float *weight_ptr, size_t num_pixels, + size_t output_channel, size_t input_step) { + size_t out_c16 = DOWN_DIV(output_channel, C16NUM) * C16NUM; + size_t out_c8 = DOWN_DIV(output_channel, C8NUM) * C8NUM; + size_t out_c4 = DOWN_DIV(output_channel, C4NUM) * C4NUM; + for (int i = 0; i < num_pixels; i++) { + const float *weight_tmp = weight_ptr; + const float *input_tmp = input_ptr; + size_t out_c = 0; + for (; out_c < out_c16; out_c += C16NUM) { + __m128 dst1 = _mm_loadu_ps(output_ptr); + __m128 dst2 = _mm_loadu_ps(output_ptr + 4); + __m128 dst3 = _mm_loadu_ps(output_ptr + 8); + __m128 dst4 = _mm_loadu_ps(output_ptr + 12); + __m128 w1 = _mm_loadu_ps(weight_tmp); + __m128 w2 = _mm_loadu_ps(weight_tmp + 4); + __m128 w3 = _mm_loadu_ps(weight_tmp + 8); + __m128 w4 = _mm_loadu_ps(weight_tmp + 12); + __m128 in1 = _mm_loadu_ps(input_tmp); + __m128 in2 = _mm_loadu_ps(input_tmp + 4); + __m128 in3 = _mm_loadu_ps(input_tmp + 8); + __m128 in4 = _mm_loadu_ps(input_tmp + 12); + dst1 = MS_MLAQ_F32(dst1, w1, in1); + dst2 = MS_MLAQ_F32(dst2, w2, in2); + dst3 = MS_MLAQ_F32(dst3, w3, in3); + dst4 = MS_MLAQ_F32(dst4, w4, in4); + _mm_storeu_ps(output_ptr, dst1); + _mm_storeu_ps(output_ptr + 4, dst2); + _mm_storeu_ps(output_ptr + 8, dst3); + _mm_storeu_ps(output_ptr + 12, dst4); + output_ptr += 16; + input_tmp += 16; + weight_tmp += 16; + } + for (; out_c < out_c8; out_c += C8NUM) { + __m128 dst1 = _mm_loadu_ps(output_ptr); + __m128 dst2 = _mm_loadu_ps(output_ptr + 4); + __m128 w1 = _mm_loadu_ps(weight_tmp); + __m128 w2 = _mm_loadu_ps(weight_tmp + 4); + __m128 in1 = _mm_loadu_ps(input_tmp); + __m128 in2 = _mm_loadu_ps(input_tmp + 4); + dst1 = MS_MLAQ_F32(dst1, w1, in1); + dst2 = MS_MLAQ_F32(dst2, w2, in2); + _mm_storeu_ps(output_ptr, dst1); + _mm_storeu_ps(output_ptr + 4, dst2); + output_ptr += 8; + input_tmp += 8; + weight_tmp += 8; + } + for (; out_c < out_c4; out_c += C4NUM) { + __m128 dst1 = _mm_loadu_ps(output_ptr); + __m128 w1 = _mm_loadu_ps(weight_tmp); + __m128 in1 = _mm_loadu_ps(input_tmp); + dst1 = MS_MLAQ_F32(dst1, w1, in1); + _mm_storeu_ps(output_ptr, dst1); + output_ptr += 4; + input_tmp += 4; + weight_tmp += 4; + } + for (; out_c < output_channel; out_c++) { + *output_ptr++ += weight_ptr[out_c] * input_ptr[out_c]; + } + input_ptr += input_step; + } +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/sse/DepthwiseFp32_Sse.c b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/sse/DepthwiseFp32_Sse.c new file mode 100644 index 0000000000000000000000000000000000000000..777d626a6af9fad7afd8552ebf52eb2b51605e03 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/sse/DepthwiseFp32_Sse.c @@ -0,0 +1,327 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_SSE +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/fp32/conv_depthwise_fp32.h" +#include "nnacl/intrinsics/sse/sse_common.h" + +#ifndef ENABLE_AVX +void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, + size_t in_kh_step, size_t in_kw_step, size_t kernel_w_step, size_t relu, size_t relu6) { + in_kh_step /= sizeof(float); + in_kw_step /= sizeof(float); + kernel_w_step /= sizeof(float); + + const float *src_kh = src; + const float *weight_kh = weight; + __m128 dst_ma = _mm_setzero_ps(); + + for (int kh = 0; kh < height; kh++) { + const float *src_kw = src_kh; + const float *weight_kw = weight_kh; + + int c1 = 0; + int c4 = DOWN_DIV(width, C4NUM) * C4NUM; + int c2 = DOWN_DIV(width, C2NUM) * C2NUM; + // c4 loop + for (; c1 < c4; c1 += C4NUM) { + __m128 src_ma1 = _mm_loadu_ps(src_kw); + __m128 src_ma2 = _mm_loadu_ps(src_kw + in_kw_step); + __m128 src_ma3 = _mm_loadu_ps(src_kw + 2 * in_kw_step); + __m128 src_ma4 = _mm_loadu_ps(src_kw + 3 * in_kw_step); + + __m128 weight_ma1 = _mm_loadu_ps(weight_kw); + __m128 weight_ma2 = _mm_loadu_ps(weight_kw + C4NUM); + __m128 weight_ma3 = _mm_loadu_ps(weight_kw + 2 * C4NUM); + __m128 weight_ma4 = _mm_loadu_ps(weight_kw + 3 * C4NUM); + + __m128 mul_ma1 = _mm_mul_ps(src_ma1, weight_ma1); + __m128 mul_ma2 = _mm_mul_ps(src_ma2, weight_ma2); + __m128 mul_ma3 = _mm_mul_ps(src_ma3, weight_ma3); + __m128 mul_ma4 = _mm_mul_ps(src_ma4, weight_ma4); + + dst_ma = _mm_add_ps(dst_ma, mul_ma1); + dst_ma = _mm_add_ps(dst_ma, mul_ma2); + dst_ma = _mm_add_ps(dst_ma, mul_ma3); + dst_ma = _mm_add_ps(dst_ma, mul_ma4); + + src_kw += in_kw_step * 4; + weight_kw += C4NUM * 4; + } + + // c2 loop + for (; c1 < c2; c1 += C2NUM) { + __m128 src_ma1 = _mm_loadu_ps(src_kw); + __m128 src_ma2 = _mm_loadu_ps(src_kw + in_kw_step); + __m128 weight_ma1 = _mm_loadu_ps(weight_kw); + __m128 weight_ma2 = _mm_loadu_ps(weight_kw + C4NUM); + __m128 mul_ma1 = _mm_mul_ps(src_ma1, weight_ma1); + __m128 mul_ma2 = _mm_mul_ps(src_ma2, weight_ma2); + dst_ma = _mm_add_ps(dst_ma, mul_ma1); + dst_ma = _mm_add_ps(dst_ma, mul_ma2); + + src_kw += in_kw_step * 2; + weight_kw += C4NUM * 2; + } + + // remaining + for (; c1 < width; ++c1) { + __m128 src_ma1 = _mm_loadu_ps(src_kw); + __m128 weight_ma1 = _mm_loadu_ps(weight_kw); + __m128 mul_ma1 = _mm_mul_ps(src_ma1, weight_ma1); + dst_ma = _mm_add_ps(dst_ma, mul_ma1); + + src_kw += in_kw_step; + weight_kw += C4NUM; + } + + src_kh += in_kh_step; + weight_kh += kernel_w_step; + } + + __m128 bias_ma = _mm_loadu_ps(bias); + dst_ma = _mm_add_ps(dst_ma, bias_ma); + __m128 zero_ma = _mm_setzero_ps(); + if (relu || relu6) { + dst_ma = _mm_max_ps(zero_ma, dst_ma); + if (relu6) { + __m128 const_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f); + dst_ma = _mm_min_ps(const_ma, dst_ma); + } + } + _mm_storeu_ps(dst, dst_ma); +} +#endif + +void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, + size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, + size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, size_t relu, size_t relu6) { + out_h_step /= sizeof(float); + block_channel /= sizeof(float); + in_sh_step /= sizeof(float); + in_sw_step /= sizeof(float); + in_kh_step /= sizeof(float); + in_kw_step /= sizeof(float); + + float *dst_h = dst; + const float *src_h = src; + for (int oh = 0; oh < height; oh++) { + float *dst_w = dst_h; + const float *src_w = src_h; + int c4 = DOWN_DIV(width, C4NUM) * C4NUM; + int c2 = DOWN_DIV(width, C2NUM) * C2NUM; + int c1 = 0; + // c4 loop + for (; c1 < c4; c1 += C4NUM, dst_w += C4NUM * block_channel, src_w += C4NUM * in_sw_step) { + const float *src_kh = src_w, *weight_kh = weight; + __m128 dst_w_ma1 = _mm_setzero_ps(); + __m128 dst_w_ma2 = _mm_setzero_ps(); + __m128 dst_w_ma3 = _mm_setzero_ps(); + __m128 dst_w_ma4 = _mm_setzero_ps(); + + for (int kh = 0; kh < kernel_h; kh++, src_kh += in_kh_step, weight_kh += kernel_w * C4NUM) { + const float *src_kw = src_kh, *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++, src_kw += in_kw_step, weight_kw += C4NUM) { + __m128 src_kw_ma1 = _mm_loadu_ps(src_kw); + __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw); + __m128 tmp_ma1 = _mm_mul_ps(src_kw_ma1, weight_kw_ma1); + dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1); + + __m128 src_kw_ma2 = _mm_loadu_ps(src_kw + in_sw_step); + __m128 weight_kw_ma2 = _mm_loadu_ps(weight_kw); + __m128 tmp_ma2 = _mm_mul_ps(src_kw_ma2, weight_kw_ma2); + dst_w_ma2 = _mm_add_ps(dst_w_ma2, tmp_ma2); + + __m128 src_kw_ma3 = _mm_loadu_ps(src_kw + 2 * in_sw_step); + __m128 weight_kw_ma3 = _mm_loadu_ps(weight_kw); + __m128 tmp_ma3 = _mm_mul_ps(src_kw_ma3, weight_kw_ma3); + dst_w_ma3 = _mm_add_ps(dst_w_ma3, tmp_ma3); + + __m128 src_kw_ma4 = _mm_loadu_ps(src_kw + 3 * in_sw_step); + __m128 weight_kw_ma4 = _mm_loadu_ps(weight_kw); + __m128 tmp_ma4 = _mm_mul_ps(src_kw_ma4, weight_kw_ma4); + dst_w_ma4 = _mm_add_ps(dst_w_ma4, tmp_ma4); + } // kernel_w loop + } // kernel_h loop + + // add bias relu + __m128 bias_ma = _mm_loadu_ps(bias); + dst_w_ma1 = _mm_add_ps(dst_w_ma1, bias_ma); + dst_w_ma2 = _mm_add_ps(dst_w_ma2, bias_ma); + dst_w_ma3 = _mm_add_ps(dst_w_ma3, bias_ma); + dst_w_ma4 = _mm_add_ps(dst_w_ma4, bias_ma); + + ActBlock4(&dst_w_ma1, &dst_w_ma2, &dst_w_ma3, &dst_w_ma4, relu, relu6); + + _mm_storeu_ps(dst_w, dst_w_ma1); + _mm_storeu_ps(dst_w + block_channel, dst_w_ma2); + _mm_storeu_ps(dst_w + 2 * block_channel, dst_w_ma3); + _mm_storeu_ps(dst_w + 3 * block_channel, dst_w_ma4); + } // dst_width loop + + // c2 loop + for (; c1 < c2; c1 += C2NUM, dst_w += C2NUM * block_channel, src_w += C2NUM * in_sw_step) { + const float *src_kh = src_w, *weight_kh = weight; + __m128 dst_w_ma1 = _mm_setzero_ps(); + __m128 dst_w_ma2 = _mm_setzero_ps(); + + for (int kh = 0; kh < kernel_h; kh++, src_kh += in_kh_step, weight_kh += kernel_w * C4NUM) { + const float *src_kw = src_kh, *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++, src_kw += in_kw_step, weight_kw += C4NUM) { + __m128 src_kw_ma1 = _mm_loadu_ps(src_kw); + __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw); + __m128 tmp_ma1 = _mm_mul_ps(src_kw_ma1, weight_kw_ma1); + dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1); + + __m128 src_kw_ma2 = _mm_loadu_ps(src_kw + in_sw_step); + __m128 weight_kw_ma2 = _mm_loadu_ps(weight_kw); + __m128 tmp_ma2 = _mm_mul_ps(src_kw_ma2, weight_kw_ma2); + dst_w_ma2 = _mm_add_ps(dst_w_ma2, tmp_ma2); + } // kernel_w loop + } // kernel_h loop + // add bias relu + __m128 bias_ma = _mm_loadu_ps(bias); + dst_w_ma1 = _mm_add_ps(dst_w_ma1, bias_ma); + dst_w_ma2 = _mm_add_ps(dst_w_ma2, bias_ma); + + ActBlock2(&dst_w_ma1, &dst_w_ma2, relu, relu6); + + _mm_storeu_ps(dst_w, dst_w_ma1); + _mm_storeu_ps(dst_w + block_channel, dst_w_ma2); + } + + // remaining + for (; c1 < width; c1++, dst_w += block_channel, src_w += in_sw_step) { + const float *src_kh = src_w, *weight_kh = weight; + __m128 dst_w_ma1 = _mm_setzero_ps(); + for (int kh = 0; kh < kernel_h; kh++, src_kh += in_kh_step, weight_kh += kernel_w * C4NUM) { + const float *src_kw = src_kh, *weight_kw = weight_kh; + for (int kw = 0; kw < kernel_w; kw++, src_kw += in_kw_step, weight_kw += C4NUM) { + __m128 src_kw_ma1 = _mm_loadu_ps(src_kw); + __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw); + __m128 tmp_ma1 = _mm_mul_ps(src_kw_ma1, weight_kw_ma1); + dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1); + } // kernel_w loop + } // kernel_h loop + + // add bias relu + __m128 bias_ma = _mm_loadu_ps(bias); + dst_w_ma1 = _mm_add_ps(dst_w_ma1, bias_ma); + ActBlock1(&dst_w_ma1, relu, relu6); + _mm_storeu_ps(dst_w, dst_w_ma1); + } + dst_h += out_h_step; + src_h += in_sh_step; + } // dst_height loop +} + +void DeconvDwFp32Center(float *dst, const float *src, const float *weight, size_t height, size_t width, size_t kernel_h, + size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, + size_t in_kh_step, size_t in_kw_step) { + out_h_step /= sizeof(float); + block_channel /= sizeof(float); + in_sh_step /= sizeof(float); + in_sw_step /= sizeof(float); + in_kh_step /= sizeof(float); + in_kw_step /= sizeof(float); + + float *dst_h = dst; + const float *src_h = src; + for (int oh = 0; oh < height; oh++) { + float *dst_w = dst_h; + const float *src_w = src_h; + for (int ow = 0; ow < width; ow++) { + float *dst_kh = dst_w; + const float *weight_kh = weight; + __m128 src_w_ma = _mm_loadu_ps(src_w); + for (int kh = 0; kh < kernel_h; kh++) { + float *dst_kw = dst_kh; + const float *weight_kw = weight_kh; + + int c4 = DOWN_DIV(kernel_w, C4NUM) * C4NUM; + int c2 = DOWN_DIV(kernel_w, C2NUM) * C2NUM; + int c1 = 0; + // c4 loop + for (; c1 < c4; c1 += C4NUM) { + __m128 dst_w_ma1 = _mm_loadu_ps(dst_kw); + __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw); + __m128 tmp_ma1 = _mm_mul_ps(src_w_ma, weight_kw_ma1); + dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1); + _mm_storeu_ps(dst_kw, dst_w_ma1); + + __m128 dst_w_ma2 = _mm_loadu_ps(dst_kw + in_kw_step); + __m128 weight_kw_ma2 = _mm_loadu_ps(weight_kw + C4NUM); + __m128 tmp_ma2 = _mm_mul_ps(src_w_ma, weight_kw_ma2); + dst_w_ma2 = _mm_add_ps(dst_w_ma2, tmp_ma2); + _mm_storeu_ps(dst_kw + in_kw_step, dst_w_ma2); + + __m128 dst_w_ma3 = _mm_loadu_ps(dst_kw + 2 * in_kw_step); + __m128 weight_kw_ma3 = _mm_loadu_ps(weight_kw + 2 * C4NUM); + __m128 tmp_ma3 = _mm_mul_ps(src_w_ma, weight_kw_ma3); + dst_w_ma3 = _mm_add_ps(dst_w_ma3, tmp_ma3); + _mm_storeu_ps(dst_kw + 2 * in_kw_step, dst_w_ma3); + + __m128 dst_w_ma4 = _mm_loadu_ps(dst_kw + 3 * in_kw_step); + __m128 weight_kw_ma4 = _mm_loadu_ps(weight_kw + 3 * C4NUM); + __m128 tmp_ma4 = _mm_mul_ps(src_w_ma, weight_kw_ma4); + dst_w_ma4 = _mm_add_ps(dst_w_ma4, tmp_ma4); + _mm_storeu_ps(dst_kw + 3 * in_kw_step, dst_w_ma4); + + dst_kw += 4 * in_kw_step; + weight_kw += 4 * C4NUM; + } + // c2 loop + for (; c1 < c2; c1 += C2NUM) { + __m128 dst_w_ma1 = _mm_loadu_ps(dst_kw); + __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw); + __m128 tmp_ma1 = _mm_mul_ps(src_w_ma, weight_kw_ma1); + dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1); + _mm_storeu_ps(dst_kw, dst_w_ma1); + + __m128 dst_w_ma2 = _mm_loadu_ps(dst_kw + in_kw_step); + __m128 weight_kw_ma2 = _mm_loadu_ps(weight_kw + C4NUM); + __m128 tmp_ma2 = _mm_mul_ps(src_w_ma, weight_kw_ma2); + dst_w_ma2 = _mm_add_ps(dst_w_ma2, tmp_ma2); + _mm_storeu_ps(dst_kw + in_kw_step, dst_w_ma2); + + dst_kw += 2 * in_kw_step; + weight_kw += 2 * C4NUM; + } + // remaining + for (; c1 < kernel_w; ++c1) { + __m128 dst_w_ma1 = _mm_loadu_ps(dst_kw); + __m128 weight_kw_ma1 = _mm_loadu_ps(weight_kw); + __m128 tmp_ma1 = _mm_mul_ps(src_w_ma, weight_kw_ma1); + dst_w_ma1 = _mm_add_ps(dst_w_ma1, tmp_ma1); + _mm_storeu_ps(dst_kw, dst_w_ma1); + + dst_kw += in_kw_step; + weight_kw += C4NUM; + } // kernel_w loop + + dst_kh += in_kh_step; + weight_kh += kernel_w * C4NUM; + } // kernel_h loop + dst_w += in_sw_step; + src_w += block_channel; + } // dst_width loop + dst_h += in_sh_step; + src_h += out_h_step; + } // dst_height loop +} + +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/sse/MatMul_Sse.c b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/sse/MatMul_Sse.c new file mode 100644 index 0000000000000000000000000000000000000000..bc46eff2ce7108782087c4275557668134f5494d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/sse/MatMul_Sse.c @@ -0,0 +1,243 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_SSE +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/fp32/matmul_fp32.h" +#include "nnacl/op_base.h" +#include "nnacl/matmul_parameter.h" +#include "nnacl/intrinsics/sse/sse_common.h" +#include "nnacl/base/minimal_filtering_generator.h" + +void MatrixMultiplyWinograd(const float *matix_a, const float *matrix_b, float *matrix_c, int m, int k, int n, + int in_channel, int c4_channel) { + const float *src1 = matix_a; + int c16 = DOWN_DIV(in_channel, C16NUM) * C16NUM; + int c8 = DOWN_DIV(in_channel, C8NUM) * C8NUM; + for (int i = 0; i < m; ++i) { + const float *src1_n = src1; + const float *src2_n = matrix_b; + for (int j = 0; j < n; ++j) { + const float *src1_j = src1_n; + int y = 0; + // 16 channel + for (; y < c16; y += C16NUM) { + __m128 dst1 = _mm_setzero_ps(); + __m128 dst2 = _mm_setzero_ps(); + __m128 dst3 = _mm_setzero_ps(); + __m128 dst4 = _mm_setzero_ps(); + const float *src2_y = src2_n; + for (int z = 0; z < k; ++z) { + __m128 ma1 = _mm_loadu_ps(src1_j); + __m128 ma2 = _mm_loadu_ps(src1_j + 4); + __m128 ma3 = _mm_loadu_ps(src1_j + 8); + __m128 ma4 = _mm_loadu_ps(src1_j + 12); + + __m128 mb = _mm_load_ps1(src2_y); + __m128 tmp1 = _mm_mul_ps(ma1, mb); + __m128 tmp2 = _mm_mul_ps(ma2, mb); + __m128 tmp3 = _mm_mul_ps(ma3, mb); + __m128 tmp4 = _mm_mul_ps(ma4, mb); + dst1 = _mm_add_ps(dst1, tmp1); + dst2 = _mm_add_ps(dst2, tmp2); + dst3 = _mm_add_ps(dst3, tmp3); + dst4 = _mm_add_ps(dst4, tmp4); + src1_j += in_channel; + src2_y += n; + } + _mm_storeu_ps(matrix_c, dst1); + _mm_storeu_ps(matrix_c + 4, dst2); + _mm_storeu_ps(matrix_c + 8, dst3); + _mm_storeu_ps(matrix_c + 12, dst4); + src1_j -= in_channel * k; + src1_j += C16NUM; + matrix_c += C16NUM; + } + // 8 channel + for (; y < c8; y += C8NUM) { + __m128 dst1 = _mm_setzero_ps(); + __m128 dst2 = _mm_setzero_ps(); + const float *src2_y = src2_n; + for (int z = 0; z < k; ++z) { + __m128 ma1 = _mm_loadu_ps(src1_j); + __m128 ma2 = _mm_loadu_ps(src1_j + 4); + + __m128 mb = _mm_load_ps1(src2_y); + __m128 tmp1 = _mm_mul_ps(ma1, mb); + __m128 tmp2 = _mm_mul_ps(ma2, mb); + dst1 = _mm_add_ps(dst1, tmp1); + dst2 = _mm_add_ps(dst2, tmp2); + src1_j += in_channel; + src2_y += n; + } + _mm_storeu_ps(matrix_c, dst1); + _mm_storeu_ps(matrix_c + 4, dst2); + src1_j -= in_channel * k; + src1_j += C8NUM; + matrix_c += C8NUM; + } + // remain chann + for (; y < in_channel; ++y) { + float tmp = 0; + for (int z = 0; z < k; ++z) { + tmp += matix_a[z * in_channel + y + i * in_channel * k] * matrix_b[j + z * n]; + } + *matrix_c++ = tmp; + } + src2_n += 1; + } + src1 += k * in_channel; + } +} + +void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, + int col, int stride, int write_mode) { + int C8Steps = row * C8NUM, WinoSteps1 = stride * col, WinoSteps2 = stride * C8NUM; + for (int r = row; r > 0; r -= C4NUM) { + const float *srcb_d = b, *bias_d = bias; + float *dst = NULL; + for (int cc = col; cc > 0; cc -= C8NUM) { + if (write_mode != 0) { // writec8 + dst = c; + } + const float *srca_d = a; + __m128 dst1 = _mm_setzero_ps(), dst2 = _mm_setzero_ps(), dst3 = _mm_setzero_ps(), dst4 = _mm_setzero_ps(); + __m128 dst5 = _mm_setzero_ps(), dst6 = _mm_setzero_ps(), dst7 = _mm_setzero_ps(), dst8 = _mm_setzero_ps(); + for (int d = depth; d > 0; --d) { + __m128 b1 = _mm_loadu_ps(srcb_d), b2 = _mm_loadu_ps(srcb_d + 4); + __m128 a1 = _mm_load_ps1(srca_d), a2 = _mm_load_ps1(srca_d + 1); + __m128 tmp1 = _mm_mul_ps(b1, a1), tmp2 = _mm_mul_ps(b2, a1); + __m128 tmp3 = _mm_mul_ps(b1, a2), tmp4 = _mm_mul_ps(b2, a2); + a1 = _mm_load_ps1(srca_d + 2); + dst1 = _mm_add_ps(dst1, tmp1), dst2 = _mm_add_ps(dst2, tmp2); + a2 = _mm_load_ps1(srca_d + 3); + dst3 = _mm_add_ps(dst3, tmp3), dst4 = _mm_add_ps(dst4, tmp4); + tmp1 = _mm_mul_ps(b1, a1), tmp2 = _mm_mul_ps(b2, a1); + tmp3 = _mm_mul_ps(b1, a2), tmp4 = _mm_mul_ps(b2, a2); + dst5 = _mm_add_ps(dst5, tmp1), dst6 = _mm_add_ps(dst6, tmp2); + dst7 = _mm_add_ps(dst7, tmp3), dst8 = _mm_add_ps(dst8, tmp4); + srcb_d += C8NUM, srca_d += C4NUM; + } + + if (bias != NULL) { + DoBiasBlock8(bias_d, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8); + bias_d += C8NUM; + } + + ActBlock8(&dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, act_type); + + if (write_mode == OutType_TileC8) { // WriteWino + c = dst + WinoSteps2; + _mm_storeu_ps(dst, dst1), _mm_storeu_ps(dst + 4, dst2); + dst += WinoSteps1; + _mm_storeu_ps(dst, dst3), _mm_storeu_ps(dst + 4, dst4); + dst += WinoSteps1; + _mm_storeu_ps(dst, dst5), _mm_storeu_ps(dst + 4, dst6); + dst += WinoSteps1; + _mm_storeu_ps(dst, dst7), _mm_storeu_ps(dst + 4, dst8); + } else if (write_mode == OutType_C8) { // WriteC8 + _mm_storeu_ps(c, dst1), _mm_storeu_ps(c + 4, dst2); + _mm_storeu_ps(c + 8, dst3), _mm_storeu_ps(c + 12, dst4); + _mm_storeu_ps(c + 16, dst5), _mm_storeu_ps(c + 20, dst6); + _mm_storeu_ps(c + 24, dst7), _mm_storeu_ps(c + 28, dst8); + c += C8Steps; + } else { + switch (cc) { + case 1: // write1 + c = dst + 1; + WriteCol1(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 1, r); + break; + case 2: // write2 + c = dst + 2; + WriteCol2Opt(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, r); + break; + case 3: // write3 + c = dst + 3; + _mm_store_ss(dst, dst1); + dst1 = _mm_shuffle_ps(dst1, dst1, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 1, dst1); + dst1 = _mm_shuffle_ps(dst1, dst1, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 2, dst1); + WriteCol3(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 3, r); + break; + case 4: // write4 + c = dst + 4; + WriteCol4(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 4, r); + break; + case 5: // write5 + c = dst + 5; + WriteCol5(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 5, r); + break; + case 6: // write6 + c = dst + 6; + WriteCol6(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 6, r); + break; + case 7: // write7 + c = dst + 7; + WriteCol7(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 7, r); + break; + default: // write8 + c = dst + C8NUM; + WriteCol8(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 8, r); + break; + } + } + if (cc <= C8NUM) break; // write end + } + a += C4NUM * depth; + if (write_mode == OutType_C8) c += 32; + if (write_mode == OutType_TileC8) c = dst + WinoSteps2; + if (write_mode == OutType_Nhwc) c = dst - col; + if (r <= C4NUM) break; + } +} + +void DeconvMatmulFloatSse(const float *a, const float *b, float *c, int depth, int row, int col) { + for (int col_tmp = col; col_tmp > 0; col_tmp -= C8NUM) { + const float *srca_d = a; + float *dst = c; + for (int r = row; r > 0; r -= C4NUM) { + const float *srcb_d = b; + __m128 dst1 = _mm_setzero_ps(), dst2 = _mm_setzero_ps(); + __m128 dst3 = _mm_setzero_ps(), dst4 = _mm_setzero_ps(); + __m128 dst5 = _mm_setzero_ps(), dst6 = _mm_setzero_ps(); + __m128 dst7 = _mm_setzero_ps(), dst8 = _mm_setzero_ps(); + for (int d = 0; d < depth; d++) { + __m128 b1 = _mm_loadu_ps(srcb_d), b2 = _mm_loadu_ps(srcb_d + 4); + __m128 a1 = _mm_load_ps1(srca_d), a2 = _mm_load_ps1(srca_d + 1); + __m128 tmp1 = _mm_mul_ps(b1, a1), tmp2 = _mm_mul_ps(b2, a1); + __m128 tmp3 = _mm_mul_ps(b1, a2), tmp4 = _mm_mul_ps(b2, a2); + a1 = _mm_load_ps1(srca_d + 2); + dst1 = _mm_add_ps(dst1, tmp1), dst2 = _mm_add_ps(dst2, tmp2); + a2 = _mm_load_ps1(srca_d + 3); + dst3 = _mm_add_ps(dst3, tmp3), dst4 = _mm_add_ps(dst4, tmp4); + tmp1 = _mm_mul_ps(b1, a1), tmp2 = _mm_mul_ps(b2, a1); + tmp3 = _mm_mul_ps(b1, a2), tmp4 = _mm_mul_ps(b2, a2); + dst5 = _mm_add_ps(dst5, tmp1), dst6 = _mm_add_ps(dst6, tmp2); + dst7 = _mm_add_ps(dst7, tmp3), dst8 = _mm_add_ps(dst8, tmp4); + srcb_d += C8NUM, srca_d += C4NUM; + } + _mm_storeu_ps(dst, dst1), _mm_storeu_ps(dst + 4, dst2); + _mm_storeu_ps(dst + 8, dst3), _mm_storeu_ps(dst + 12, dst4); + _mm_storeu_ps(dst + 16, dst5), _mm_storeu_ps(dst + 20, dst6); + _mm_storeu_ps(dst + 24, dst7), _mm_storeu_ps(dst + 28, dst8); + dst += 32; + c = dst; + } + b += depth * C8NUM; + } +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/sse/PostFuncBiasReluC8.c b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/sse/PostFuncBiasReluC8.c new file mode 100644 index 0000000000000000000000000000000000000000..c18d6228a83e574994e33fadec30fc5608ebc593 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/sse/PostFuncBiasReluC8.c @@ -0,0 +1,131 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#if defined(ENABLE_SSE) && !defined(ENABLE_AVX) +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/fp32/common_func_fp32.h" +#include "nnacl/intrinsics/sse/sse_common.h" + +void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod, + size_t plane_size, size_t stride, size_t relu_type) { + stride /= sizeof(float); + for (int loop_c8 = 0; loop_c8 != oc8div; loop_c8 += C8NUM) { + size_t plane_size_tmp = plane_size; + float *dst_c8 = dst + loop_c8; + __m128 bias1 = _mm_setzero_ps(), bias2 = _mm_setzero_ps(); + if (bias != NULL) { + bias1 = _mm_loadu_ps(bias); + bias2 = _mm_loadu_ps(bias + 4); + bias += 8; + } + for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM, src += 32) { + __m128 src1 = _mm_loadu_ps(src); + __m128 src2 = _mm_loadu_ps(src + 4); + __m128 src3 = _mm_loadu_ps(src + 8); + __m128 src4 = _mm_loadu_ps(src + 12); + __m128 src5 = _mm_loadu_ps(src + 16); + __m128 src6 = _mm_loadu_ps(src + 20); + __m128 src7 = _mm_loadu_ps(src + 24); + __m128 src8 = _mm_loadu_ps(src + 28); + src1 = _mm_add_ps(src1, bias1); + src2 = _mm_add_ps(src2, bias2); + src3 = _mm_add_ps(src3, bias1); + src4 = _mm_add_ps(src4, bias2); + src5 = _mm_add_ps(src5, bias1); + src6 = _mm_add_ps(src6, bias2); + src7 = _mm_add_ps(src7, bias1); + src8 = _mm_add_ps(src8, bias2); + + ActBlock8(&src1, &src2, &src3, &src4, &src5, &src6, &src7, &src8, relu_type); + + _mm_storeu_ps(dst_c8, src1); + _mm_storeu_ps(dst_c8 + 4, src2); + dst_c8 += stride; + _mm_storeu_ps(dst_c8, src3); + _mm_storeu_ps(dst_c8 + 4, src4); + dst_c8 += stride; + _mm_storeu_ps(dst_c8, src5); + _mm_storeu_ps(dst_c8 + 4, src6); + dst_c8 += stride; + _mm_storeu_ps(dst_c8, src7); + _mm_storeu_ps(dst_c8 + 4, src8); + dst_c8 += stride; + } + for (; plane_size_tmp > 0; plane_size_tmp -= 1, src += 8, dst_c8 += stride) { + __m128 src1 = _mm_loadu_ps(src); + __m128 src2 = _mm_loadu_ps(src + 4); + src1 = _mm_add_ps(src1, bias1); + src2 = _mm_add_ps(src2, bias2); + + ActBlock2(&src1, &src2, relu_type == 1, relu_type == 3); + + _mm_storeu_ps(dst_c8, src1); + _mm_storeu_ps(dst_c8 + 4, src2); + } + } + + if (oc8mod == 0) return; + + __m128 bias1 = _mm_setzero_ps(); + __m128 bias2 = _mm_setzero_ps(); + if (bias != NULL) { + bias1 = _mm_loadu_ps(bias); + bias2 = _mm_loadu_ps(bias + 4); + bias += 8; + } + float *dst_c1 = dst + oc8div; + for (size_t plane_size_tmp = plane_size; plane_size_tmp > 0; plane_size_tmp -= 1, src += 8, dst_c1 += stride) { + __m128 src1 = _mm_loadu_ps(src); + __m128 src2 = _mm_loadu_ps(src + 4); + src1 = _mm_add_ps(src1, bias1); + src2 = _mm_add_ps(src2, bias2); + + ActBlock2(&src1, &src2, relu_type == 1, relu_type == 3); + + switch (oc8mod) { + case 1: + _mm_store_ss(dst_c1, src1); + break; + case 2: + _mm_storel_pi((__m64 *)(dst_c1), src1); + break; + case 3: + _mm_storel_pi((__m64 *)(dst_c1), src1); + src1 = _mm_unpackhi_ps(src1, src1); + _mm_store_ss(dst_c1 + 2, src1); + break; + case 4: + _mm_storeu_ps(dst_c1, src1); + break; + case 5: + _mm_storeu_ps(dst_c1, src1); + _mm_store_ss(dst_c1 + 4, src2); + break; + case 6: + _mm_storeu_ps(dst_c1, src1); + _mm_storel_pi((__m64 *)(dst_c1 + 4), src2); + break; + case 7: + _mm_storeu_ps(dst_c1, src1); + _mm_storel_pi((__m64 *)(dst_c1 + 4), src2); + src2 = _mm_unpackhi_ps(src2, src2); + _mm_store_ss(dst_c1 + 6, src2); + break; + default: + break; + } + } +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/sse/TiledC4MatMulFp32.c b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/sse/TiledC4MatMulFp32.c new file mode 100644 index 0000000000000000000000000000000000000000..eb1c8efb277dad8fc087d54d84f199352bc5dec3 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/sse/TiledC4MatMulFp32.c @@ -0,0 +1,161 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#if defined(ENABLE_SSE) && !defined(ENABLE_AVX) +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/fp32/common_func_fp32.h" + +static inline void TiledC4MatmulFp32_Transfer(__m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, + const __m128 weight, const float v1, const float v2, const float v3, + const float v4) { + *dst1 = _mm_add_ps(*dst1, _mm_mul_ps(weight, _mm_set_ps1(v1))); + *dst2 = _mm_add_ps(*dst2, _mm_mul_ps(weight, _mm_set_ps1(v2))); + *dst3 = _mm_add_ps(*dst3, _mm_mul_ps(weight, _mm_set_ps1(v3))); + *dst4 = _mm_add_ps(*dst4, _mm_mul_ps(weight, _mm_set_ps1(v4))); +} + +static inline void TiledC4MatmulFp32_LoadData(__m128 *src1, __m128 *src2, __m128 *src3, __m128 *src4, + const float *src) { + *src1 = _mm_loadu_ps(src); + *src2 = _mm_loadu_ps(src + 4); + *src3 = _mm_loadu_ps(src + 8); + *src4 = _mm_loadu_ps(src + 12); +} + +void TiledC4MatmulFp32(float *dst, const float *src, const float *weight, size_t cal_num, size_t ic4, size_t oc4) { + const float *src_tmp = src; + for (int i = 0; i < oc4; ++i) { + float *dst_tmp = dst; + src = src_tmp; + size_t ic4_tmp = ic4 - 1; + __m128 src1 = _mm_loadu_ps(src); + __m128 src2 = _mm_loadu_ps(src + 4); + __m128 src3 = _mm_loadu_ps(src + 8); + __m128 src4 = _mm_loadu_ps(src + 12); + src += 16; + __m128 weight_data[4]; + weight_data[0] = _mm_loadu_ps(weight); + weight_data[1] = _mm_loadu_ps(weight + 4); + weight_data[2] = _mm_loadu_ps(weight + 8); + weight_data[3] = _mm_loadu_ps(weight + 12); + weight += 16; + __m128 dst1 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src1, 0))); + __m128 dst2 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src2, 0))); + __m128 dst3 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src3, 0))); + __m128 dst4 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src4, 0))); + for (int j = 1; j < 4; ++j) { + TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[j], MS_F32X4_GETI(src1, j), + MS_F32X4_GETI(src2, j), MS_F32X4_GETI(src3, j), MS_F32X4_GETI(src4, j)); + } + TiledC4MatmulFp32_LoadData(&src1, &src2, &src3, &src4, src); + src += 16; + __m128 dst5 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src1, 0))); + __m128 dst6 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src2, 0))); + __m128 dst7 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src3, 0))); + __m128 dst8 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src4, 0))); + for (int j = 1; j < 4; ++j) { + TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[j], MS_F32X4_GETI(src1, j), + MS_F32X4_GETI(src2, j), MS_F32X4_GETI(src3, j), MS_F32X4_GETI(src4, j)); + } + if (ic4_tmp != 0) { + ic4_tmp -= 1; + TiledC4MatmulFp32_LoadData(&src1, &src2, &src3, &src4, src); + src += 16; + weight_data[0] = _mm_loadu_ps(weight); + weight_data[1] = _mm_loadu_ps(weight + 4); + weight += 8; + + dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src1, 0)))); + dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src2, 0)))); + for (; ic4_tmp != 0; ic4_tmp -= 1) { + dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src3, 0)))); + dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src4, 0)))); + + TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[1], MS_F32X4_GETI(src1, 1), + MS_F32X4_GETI(src2, 1), MS_F32X4_GETI(src3, 1), MS_F32X4_GETI(src4, 1)); + + weight_data[2] = _mm_loadu_ps(weight); + weight_data[3] = _mm_loadu_ps(weight + 4); + weight += 8; + + TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[2], MS_F32X4_GETI(src1, 2), + MS_F32X4_GETI(src2, 2), MS_F32X4_GETI(src3, 2), MS_F32X4_GETI(src4, 2)); + + dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[3], _mm_set_ps1(MS_F32X4_GETI(src1, 3)))); + dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[3], _mm_set_ps1(MS_F32X4_GETI(src2, 3)))); + src1 = _mm_loadu_ps(src); + src2 = _mm_loadu_ps(src + 4); + dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[3], _mm_set_ps1(MS_F32X4_GETI(src3, 3)))); + dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[3], _mm_set_ps1(MS_F32X4_GETI(src4, 3)))); + src3 = _mm_loadu_ps(src + 8); + src4 = _mm_loadu_ps(src + 12); + src += 16; + + TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[0], MS_F32X4_GETI(src1, 0), + MS_F32X4_GETI(src2, 0), MS_F32X4_GETI(src3, 0), MS_F32X4_GETI(src4, 0)); + + TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[1], MS_F32X4_GETI(src1, 1), + MS_F32X4_GETI(src2, 1), MS_F32X4_GETI(src3, 1), MS_F32X4_GETI(src4, 1)); + + TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[2], MS_F32X4_GETI(src1, 2), + MS_F32X4_GETI(src2, 2), MS_F32X4_GETI(src3, 2), MS_F32X4_GETI(src4, 2)); + + weight_data[0] = _mm_loadu_ps(weight); + weight_data[1] = _mm_loadu_ps(weight + 4); + weight += 8; + + TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[3], MS_F32X4_GETI(src1, 3), + MS_F32X4_GETI(src2, 3), MS_F32X4_GETI(src3, 3), MS_F32X4_GETI(src4, 3)); + TiledC4MatmulFp32_LoadData(&src1, &src2, &src3, &src4, src); + src += 16; + + dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src1, 0)))); + dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src2, 0)))); + } + dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src3, 0)))); + dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src4, 0)))); + + TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[1], MS_F32X4_GETI(src1, 1), + MS_F32X4_GETI(src2, 1), MS_F32X4_GETI(src3, 1), MS_F32X4_GETI(src4, 1)); + + weight_data[2] = _mm_loadu_ps(weight); + weight_data[3] = _mm_loadu_ps(weight + 4); + weight += 8; + + TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[2], MS_F32X4_GETI(src1, 2), + MS_F32X4_GETI(src2, 2), MS_F32X4_GETI(src3, 2), MS_F32X4_GETI(src4, 2)); + + TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[3], MS_F32X4_GETI(src1, 3), + MS_F32X4_GETI(src2, 3), MS_F32X4_GETI(src3, 3), MS_F32X4_GETI(src4, 3)); + + TiledC4MatmulFp32_LoadData(&src1, &src2, &src3, &src4, src); + src += 16; + for (int j = 0; j < 4; ++j) { + TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[j], MS_F32X4_GETI(src1, j), + MS_F32X4_GETI(src2, j), MS_F32X4_GETI(src3, j), MS_F32X4_GETI(src4, j)); + } + } + _mm_storeu_ps(dst, dst1); + _mm_storeu_ps(dst + 4, dst2); + _mm_storeu_ps(dst + 8, dst3); + _mm_storeu_ps(dst + 12, dst4); + _mm_storeu_ps(dst + 16, dst5); + _mm_storeu_ps(dst + 20, dst6); + _mm_storeu_ps(dst + 24, dst7); + _mm_storeu_ps(dst + 28, dst8); + dst = dst_tmp + cal_num; + } +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/sse/WinogradPostFuncBiasReluC4.c b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/sse/WinogradPostFuncBiasReluC4.c new file mode 100644 index 0000000000000000000000000000000000000000..33a5715326df101f895ca920fc647df4a36d096c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/sse/WinogradPostFuncBiasReluC4.c @@ -0,0 +1,349 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#if defined(ENABLE_SSE) && !defined(ENABLE_AVX) +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/fp32/common_func_fp32.h" +#include "nnacl/intrinsics/sse/sse_common.h" + +void WinogradPostFuncBiasReluC4(float *dst, const float *src, const float *bias, size_t oc4div, size_t oc4mod, + size_t plane_size, size_t plane_stride, size_t relu_type) { + size_t stride = oc4div + oc4mod; + plane_stride /= sizeof(float); + int loop_c4 = 0; + size_t src_stride = plane_size * C4NUM + plane_stride; + for (; loop_c4 <= (int)(oc4div)-C16NUM; loop_c4 += C16NUM) { + size_t plane_size_tmp = plane_size; + float *dst_c4 = dst + loop_c4; + __m128 bias1 = _mm_setzero_ps(); + __m128 bias2 = _mm_setzero_ps(); + __m128 bias3 = _mm_setzero_ps(); + __m128 bias4 = _mm_setzero_ps(); + if (bias != NULL) { + bias1 = _mm_loadu_ps(bias); + bias2 = _mm_loadu_ps(bias + C4NUM); + bias3 = _mm_loadu_ps(bias + C8NUM); + bias4 = _mm_loadu_ps(bias + C12NUM); + bias += C16NUM; + } + for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) { + __m128 src1 = _mm_loadu_ps(src); + __m128 src2 = _mm_loadu_ps(src + C4NUM); + __m128 src5 = _mm_loadu_ps(src + src_stride); + __m128 src6 = _mm_loadu_ps(src + src_stride + C4NUM); + __m128 src9 = _mm_loadu_ps(src + src_stride * C2NUM); + __m128 src10 = _mm_loadu_ps(src + src_stride * C2NUM + C4NUM); + __m128 src13 = _mm_loadu_ps(src + src_stride * C3NUM); + __m128 src14 = _mm_loadu_ps(src + src_stride * C3NUM + C4NUM); + + src1 = _mm_add_ps(src1, bias1); + src2 = _mm_add_ps(src2, bias1); + src5 = _mm_add_ps(src5, bias2); + src6 = _mm_add_ps(src6, bias2); + src9 = _mm_add_ps(src9, bias3); + src10 = _mm_add_ps(src10, bias3); + src13 = _mm_add_ps(src13, bias4); + src14 = _mm_add_ps(src14, bias4); + + ActBlock8(&src1, &src2, &src5, &src6, &src9, &src10, &src13, &src14, relu_type); + + _mm_storeu_ps(dst_c4, src1); + _mm_storeu_ps(dst_c4 + C4NUM, src5); + _mm_storeu_ps(dst_c4 + C8NUM, src9); + _mm_storeu_ps(dst_c4 + C12NUM, src13); + dst_c4 += stride; + _mm_storeu_ps(dst_c4, src2); + _mm_storeu_ps(dst_c4 + C4NUM, src6); + _mm_storeu_ps(dst_c4 + C8NUM, src10); + _mm_storeu_ps(dst_c4 + C12NUM, src14); + dst_c4 += stride; + + __m128 src3 = _mm_loadu_ps(src + C8NUM); + __m128 src4 = _mm_loadu_ps(src + C12NUM); + __m128 src7 = _mm_loadu_ps(src + src_stride + C8NUM); + __m128 src8 = _mm_loadu_ps(src + src_stride + C12NUM); + __m128 src11 = _mm_loadu_ps(src + src_stride * C2NUM + C8NUM); + __m128 src12 = _mm_loadu_ps(src + src_stride * C2NUM + C12NUM); + __m128 src15 = _mm_loadu_ps(src + src_stride * C3NUM + C8NUM); + __m128 src16 = _mm_loadu_ps(src + src_stride * C3NUM + C12NUM); + src3 = _mm_add_ps(src3, bias1); + src4 = _mm_add_ps(src4, bias1); + src7 = _mm_add_ps(src7, bias2); + src8 = _mm_add_ps(src8, bias2); + src11 = _mm_add_ps(src11, bias3); + src12 = _mm_add_ps(src12, bias3); + src15 = _mm_add_ps(src15, bias4); + src16 = _mm_add_ps(src16, bias4); + + ActBlock8(&src3, &src4, &src7, &src8, &src11, &src12, &src15, &src16, relu_type); + + _mm_storeu_ps(dst_c4, src3); + _mm_storeu_ps(dst_c4 + C4NUM, src7); + _mm_storeu_ps(dst_c4 + C8NUM, src11); + _mm_storeu_ps(dst_c4 + C12NUM, src15); + dst_c4 += stride; + _mm_storeu_ps(dst_c4, src4); + _mm_storeu_ps(dst_c4 + C4NUM, src8); + _mm_storeu_ps(dst_c4 + C8NUM, src12); + _mm_storeu_ps(dst_c4 + C12NUM, src16); + dst_c4 += stride; + src += C16NUM; + } + for (; plane_size_tmp > 0; plane_size_tmp -= 1) { + __m128 src1 = _mm_loadu_ps(src); + __m128 src2 = _mm_loadu_ps(src + src_stride); + __m128 src3 = _mm_loadu_ps(src + src_stride * C2NUM); + __m128 src4 = _mm_loadu_ps(src + src_stride * C3NUM); + src1 = _mm_add_ps(src1, bias1); + src2 = _mm_add_ps(src2, bias2); + src3 = _mm_add_ps(src3, bias3); + src4 = _mm_add_ps(src4, bias4); + + ActBlock4(&src1, &src2, &src3, &src4, relu_type == 1, relu_type == C3NUM); + + _mm_storeu_ps(dst_c4, src1); + _mm_storeu_ps(dst_c4 + C4NUM, src2); + _mm_storeu_ps(dst_c4 + C8NUM, src3); + _mm_storeu_ps(dst_c4 + C12NUM, src4); + dst_c4 += stride; + src += C4NUM; + } + src += plane_stride; + src += C3NUM * src_stride; + } + for (; loop_c4 <= (int)(oc4div)-C12NUM; loop_c4 += C12NUM) { + size_t plane_size_tmp = plane_size; + float *dst_c4 = dst + loop_c4; + __m128 bias1 = _mm_setzero_ps(); + __m128 bias2 = _mm_setzero_ps(); + __m128 bias3 = _mm_setzero_ps(); + if (bias != NULL) { + bias1 = _mm_loadu_ps(bias); + bias2 = _mm_loadu_ps(bias + C4NUM); + bias3 = _mm_loadu_ps(bias + C8NUM); + bias += C12NUM; + } + for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) { + __m128 src1 = _mm_loadu_ps(src); + __m128 src2 = _mm_loadu_ps(src + C4NUM); + __m128 src3 = _mm_loadu_ps(src + C8NUM); + __m128 src4 = _mm_loadu_ps(src + C12NUM); + __m128 src5 = _mm_loadu_ps(src + src_stride); + __m128 src6 = _mm_loadu_ps(src + src_stride + C4NUM); + __m128 src7 = _mm_loadu_ps(src + src_stride + C8NUM); + __m128 src8 = _mm_loadu_ps(src + src_stride + C12NUM); + __m128 src9 = _mm_loadu_ps(src + src_stride * C2NUM); + __m128 src10 = _mm_loadu_ps(src + src_stride * C2NUM + C4NUM); + __m128 src11 = _mm_loadu_ps(src + src_stride * C2NUM + C8NUM); + __m128 src12 = _mm_loadu_ps(src + src_stride * C2NUM + C12NUM); + src += C16NUM; + src1 = _mm_add_ps(src1, bias1); + src2 = _mm_add_ps(src2, bias1); + src3 = _mm_add_ps(src3, bias1); + src4 = _mm_add_ps(src4, bias1); + src5 = _mm_add_ps(src5, bias2); + src6 = _mm_add_ps(src6, bias2); + src7 = _mm_add_ps(src7, bias2); + src8 = _mm_add_ps(src8, bias2); + src9 = _mm_add_ps(src9, bias3); + src10 = _mm_add_ps(src10, bias3); + src11 = _mm_add_ps(src11, bias3); + src12 = _mm_add_ps(src12, bias3); + + ActBlock12(&src1, &src2, &src3, &src4, &src5, &src6, &src7, &src8, &src9, &src10, &src11, &src12, relu_type == 1, + relu_type == C3NUM); + + _mm_storeu_ps(dst_c4, src1); + _mm_storeu_ps(dst_c4 + C4NUM, src5); + _mm_storeu_ps(dst_c4 + C8NUM, src9); + dst_c4 += stride; + _mm_storeu_ps(dst_c4, src2); + _mm_storeu_ps(dst_c4 + C4NUM, src6); + _mm_storeu_ps(dst_c4 + C8NUM, src10); + dst_c4 += stride; + _mm_storeu_ps(dst_c4, src3); + _mm_storeu_ps(dst_c4 + C4NUM, src7); + _mm_storeu_ps(dst_c4 + C8NUM, src11); + dst_c4 += stride; + _mm_storeu_ps(dst_c4, src4); + _mm_storeu_ps(dst_c4 + C4NUM, src8); + _mm_storeu_ps(dst_c4 + C8NUM, src12); + dst_c4 += stride; + } + for (; plane_size_tmp > 0; plane_size_tmp -= 1) { + __m128 src1 = _mm_loadu_ps(src); + __m128 src2 = _mm_loadu_ps(src + src_stride); + __m128 src3 = _mm_loadu_ps(src + src_stride * C2NUM); + src1 = _mm_add_ps(src1, bias1); + src2 = _mm_add_ps(src2, bias2); + src3 = _mm_add_ps(src3, bias3); + + ActBlock1(&src1, relu_type == 1, relu_type == C3NUM); + ActBlock1(&src2, relu_type == 1, relu_type == C3NUM); + ActBlock1(&src3, relu_type == 1, relu_type == C3NUM); + + _mm_storeu_ps(dst_c4, src1); + _mm_storeu_ps(dst_c4 + C4NUM, src2); + _mm_storeu_ps(dst_c4 + C8NUM, src3); + dst_c4 += stride; + src += C4NUM; + } + src += plane_stride; + src += C2NUM * src_stride; + } + + for (; loop_c4 <= (int)(oc4div)-C8NUM; loop_c4 += C8NUM) { + size_t plane_size_tmp = plane_size; + float *dst_c4 = dst + loop_c4; + __m128 bias1 = _mm_setzero_ps(); + __m128 bias2 = _mm_setzero_ps(); + if (bias != NULL) { + bias1 = _mm_loadu_ps(bias); + bias2 = _mm_loadu_ps(bias + C4NUM); + bias += C8NUM; + } + for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) { + __m128 src1 = _mm_loadu_ps(src); + __m128 src2 = _mm_loadu_ps(src + C4NUM); + __m128 src3 = _mm_loadu_ps(src + C8NUM); + __m128 src4 = _mm_loadu_ps(src + C12NUM); + __m128 src5 = _mm_loadu_ps(src + src_stride); + __m128 src6 = _mm_loadu_ps(src + src_stride + C4NUM); + __m128 src7 = _mm_loadu_ps(src + src_stride + C8NUM); + __m128 src8 = _mm_loadu_ps(src + src_stride + C12NUM); + src += C16NUM; + src1 = _mm_add_ps(src1, bias1); + src2 = _mm_add_ps(src2, bias1); + src3 = _mm_add_ps(src3, bias1); + src4 = _mm_add_ps(src4, bias1); + src5 = _mm_add_ps(src5, bias2); + src6 = _mm_add_ps(src6, bias2); + src7 = _mm_add_ps(src7, bias2); + src8 = _mm_add_ps(src8, bias2); + + ActBlock8(&src1, &src2, &src3, &src4, &src5, &src6, &src7, &src8, relu_type); + + _mm_storeu_ps(dst_c4, src1); + _mm_storeu_ps(dst_c4 + C4NUM, src5); + dst_c4 += stride; + _mm_storeu_ps(dst_c4, src2); + _mm_storeu_ps(dst_c4 + C4NUM, src6); + dst_c4 += stride; + _mm_storeu_ps(dst_c4, src3); + _mm_storeu_ps(dst_c4 + C4NUM, src7); + dst_c4 += stride; + _mm_storeu_ps(dst_c4, src4); + _mm_storeu_ps(dst_c4 + C4NUM, src8); + dst_c4 += stride; + } + for (; plane_size_tmp > 0; plane_size_tmp -= 1) { + __m128 src1 = _mm_loadu_ps(src); + __m128 src2 = _mm_loadu_ps(src + src_stride); + src1 = _mm_add_ps(src1, bias1); + src2 = _mm_add_ps(src2, bias2); + + ActBlock1(&src1, relu_type == 1, relu_type == C3NUM); + ActBlock1(&src2, relu_type == 1, relu_type == C3NUM); + + _mm_storeu_ps(dst_c4, src1); + _mm_storeu_ps(dst_c4 + C4NUM, src2); + dst_c4 += stride; + src += C4NUM; + } + src += plane_stride; + src += src_stride; + } + for (; loop_c4 < (int)(oc4div); loop_c4 += C4NUM) { + size_t plane_size_tmp = plane_size; + float *dst_c4 = dst + loop_c4; + __m128 bias1 = _mm_setzero_ps(); + if (bias != NULL) { + bias1 = _mm_loadu_ps(bias); + bias += C4NUM; + } + for (; plane_size_tmp >= C4NUM; plane_size_tmp -= C4NUM) { + __m128 src1 = _mm_loadu_ps(src); + __m128 src2 = _mm_loadu_ps(src + C4NUM); + __m128 src3 = _mm_loadu_ps(src + 8); + __m128 src4 = _mm_loadu_ps(src + 12); + src += C16NUM; + src1 = _mm_add_ps(src1, bias1); + src2 = _mm_add_ps(src2, bias1); + src3 = _mm_add_ps(src3, bias1); + src4 = _mm_add_ps(src4, bias1); + + ActBlock4(&src1, &src2, &src3, &src4, relu_type == 1, relu_type == C3NUM); + + _mm_storeu_ps(dst_c4, src1); + dst_c4 += stride; + _mm_storeu_ps(dst_c4, src2); + dst_c4 += stride; + _mm_storeu_ps(dst_c4, src3); + dst_c4 += stride; + _mm_storeu_ps(dst_c4, src4); + dst_c4 += stride; + } + for (; plane_size_tmp > 0; plane_size_tmp -= 1) { + __m128 src1 = _mm_loadu_ps(src); + src1 = _mm_add_ps(src1, bias1); + + ActBlock1(&src1, relu_type == 1, relu_type == C3NUM); + + _mm_storeu_ps(dst_c4, src1); + dst_c4 += stride; + src += C4NUM; + } + src += plane_stride; + } + if (oc4mod == 0) { + return; + } + __m128 bias1 = _mm_setzero_ps(); + if (bias != NULL) { + bias1 = _mm_loadu_ps(bias); + bias += C4NUM; + } + float *dst_c1 = dst + oc4div; + for (size_t plane_size_tmp = plane_size; plane_size_tmp > 0; plane_size_tmp -= 1) { + __m128 src1 = _mm_loadu_ps(src); + src += C4NUM; + src1 = _mm_add_ps(src1, bias1); + + ActBlock1(&src1, relu_type == 1, relu_type == C3NUM); + + switch (oc4mod) { + case 1: + _mm_store_ss(dst_c1, src1); + dst_c1 += stride; + break; + case C2NUM: + _mm_storel_pi((__m64 *)(dst_c1), src1); + dst_c1 += stride; + break; + case C3NUM: + _mm_storel_pi((__m64 *)(dst_c1), src1); + src1 = _mm_unpackhi_ps(src1, src1); + _mm_store_ss(dst_c1 + C2NUM, src1); + dst_c1 += stride; + break; + case C4NUM: + _mm_storeu_ps(dst_c1, src1); + dst_c1 += stride; + break; + } + } +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/sse/WinogradTrans.c b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/sse/WinogradTrans.c new file mode 100644 index 0000000000000000000000000000000000000000..adba5508c74df5fa4f023cfd2bb38fde4420fc97 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/sse/WinogradTrans.c @@ -0,0 +1,376 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#if defined(ENABLE_SSE) && !defined(ENABLE_AVX) +#include "nnacl/intrinsics/ms_simd_instructions.h" +#include "nnacl/fp32/common_func_fp32.h" + +void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) { + size_t len_c4 = length * 4; + size_t S_step = length * w * 4; + for (int h1 = 0; h1 < h; ++h1) { + const float *SW = S; + memset(M, 0, len_c4 * w * sizeof(float)); + for (int w_tmp = w; w_tmp > 0; --w_tmp) { + const float *SK = SW; + const float *BK = B; + int k_tmp = k; + for (; k_tmp >= 7; k_tmp -= 7) { + __m128 k1 = _mm_load_ps1(BK); + __m128 k2 = _mm_load_ps1(BK + h); + __m128 k3 = _mm_load_ps1(BK + 2 * h); + __m128 k4 = _mm_load_ps1(BK + 3 * h); + __m128 k5 = _mm_load_ps1(BK + 4 * h); + __m128 k6 = _mm_load_ps1(BK + 5 * h); + __m128 k7 = _mm_load_ps1(BK + 6 * h); + BK += 7 * h; + for (int len_tmp = length; len_tmp > 0; --len_tmp, M += 4, SK += 4) { +#ifdef ENABLE_AVX + __m128 M1 = _mm_loadu_ps(M); + __m128 M2 = _mm_set1_ps(0.0f); + __m128 s1 = _mm_loadu_ps(SK); + M1 = _mm_fmadd_ps(s1, k1, M1); + __m128 s2 = _mm_loadu_ps(SK + S_step); + M2 = _mm_fmadd_ps(s2, k2, M2); + __m128 s3 = _mm_loadu_ps(SK + 2 * S_step); + M1 = _mm_fmadd_ps(s3, k3, M1); + __m128 s4 = _mm_loadu_ps(SK + 3 * S_step); + M2 = _mm_fmadd_ps(s4, k4, M2); + __m128 s5 = _mm_loadu_ps(SK + 4 * S_step); + M1 = _mm_fmadd_ps(s5, k5, M1); + __m128 s6 = _mm_loadu_ps(SK + 5 * S_step); + M2 = _mm_fmadd_ps(s6, k6, M2); + __m128 s7 = _mm_loadu_ps(SK + 6 * S_step); + M1 = _mm_fmadd_ps(s7, k7, M1); + M1 = _mm_add_ps(M1, M2); + _mm_storeu_ps(M, M1); +#else + __m128 M1 = _mm_loadu_ps(M); + __m128 s0 = _mm_loadu_ps(SK); + M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); + __m128 s1 = _mm_loadu_ps(SK + S_step); + s1 = _mm_mul_ps(s1, k2); + __m128 s3 = _mm_loadu_ps(SK + 2 * S_step); + M1 = _mm_add_ps(M1, _mm_mul_ps(s3, k3)); + __m128 s4 = _mm_loadu_ps(SK + 3 * S_step); + s1 = _mm_add_ps(s1, _mm_mul_ps(s4, k4)); + __m128 s5 = _mm_loadu_ps(SK + 4 * S_step); + M1 = _mm_add_ps(M1, _mm_mul_ps(s5, k5)); + __m128 s6 = _mm_loadu_ps(SK + 5 * S_step); + s1 = _mm_add_ps(s1, _mm_mul_ps(s6, k6)); + __m128 s7 = _mm_loadu_ps(SK + 6 * S_step); + M1 = _mm_add_ps(M1, _mm_mul_ps(s7, k7)); + M1 = _mm_add_ps(M1, s1); + _mm_storeu_ps(M, M1); +#endif + } + M -= len_c4; + SK += 7 * S_step - len_c4; + } + for (; k_tmp >= 4; k_tmp -= 4) { + __m128 k1 = _mm_load_ps1(BK); + __m128 k2 = _mm_load_ps1(BK + h); + __m128 k3 = _mm_load_ps1(BK + 2 * h); + __m128 k4 = _mm_load_ps1(BK + 3 * h); + BK += 4 * h; + int len_tmp = length; +#ifdef ENABLE_AVX + for (; len_tmp >= C2NUM; len_tmp -= C2NUM, SK += C8NUM, M += C8NUM) { + __m128 M1 = _mm_loadu_ps(M); + __m128 M2 = _mm_loadu_ps(M + C4NUM); + __m128 s1 = _mm_loadu_ps(SK); + __m128 s11 = _mm_loadu_ps(SK + C4NUM); + M1 = _mm_fmadd_ps(s1, k1, M1); + M2 = _mm_fmadd_ps(s11, k1, M2); + __m128 s2 = _mm_loadu_ps(SK + S_step); + __m128 s22 = _mm_loadu_ps(SK + S_step + C4NUM); + M1 = _mm_fmadd_ps(s2, k2, M1); + M2 = _mm_fmadd_ps(s22, k2, M2); + __m128 s3 = _mm_loadu_ps(SK + 2 * S_step); + __m128 s33 = _mm_loadu_ps(SK + 2 * S_step + C4NUM); + M1 = _mm_fmadd_ps(s3, k3, M1); + M2 = _mm_fmadd_ps(s33, k3, M2); + __m128 s4 = _mm_loadu_ps(SK + 3 * S_step); + __m128 s44 = _mm_loadu_ps(SK + 3 * S_step + C4NUM); + M1 = _mm_fmadd_ps(s4, k4, M1); + M2 = _mm_fmadd_ps(s44, k4, M2); + _mm_storeu_ps(M, M1); + _mm_storeu_ps(M + C4NUM, M2); + } +#endif + for (; len_tmp > 0; --len_tmp, SK += 4, M += 4) { +#ifdef ENABLE_AVX + __m128 M1 = _mm_loadu_ps(M); + __m128 M2 = _mm_set1_ps(0.0f); + __m128 s1 = _mm_loadu_ps(SK); + M1 = _mm_fmadd_ps(s1, k1, M1); + __m128 s2 = _mm_loadu_ps(SK + S_step); + M2 = _mm_fmadd_ps(s2, k2, M2); + __m128 s3 = _mm_loadu_ps(SK + 2 * S_step); + M1 = _mm_fmadd_ps(s3, k3, M1); + __m128 s4 = _mm_loadu_ps(SK + 3 * S_step); + M2 = _mm_fmadd_ps(s4, k4, M2); + M1 = _mm_add_ps(M1, M2); + _mm_storeu_ps(M, M1); +#else + __m128 M1 = _mm_loadu_ps(M); + __m128 s0 = _mm_loadu_ps(SK); + M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); + __m128 s1 = _mm_loadu_ps(SK + S_step); + s1 = _mm_mul_ps(s1, k2); + __m128 s3 = _mm_loadu_ps(SK + 2 * S_step); + M1 = _mm_add_ps(M1, _mm_mul_ps(s3, k3)); + __m128 s4 = _mm_loadu_ps(SK + 3 * S_step); + s1 = _mm_add_ps(s1, _mm_mul_ps(s4, k4)); + M1 = _mm_add_ps(M1, s1); + _mm_storeu_ps(M, M1); +#endif + } + M -= len_c4; + SK += 4 * S_step - len_c4; + } + for (; k_tmp >= 3; k_tmp -= 3) { + __m128 k1 = _mm_load_ps1(BK); + __m128 k2 = _mm_load_ps1(BK + h); + __m128 k3 = _mm_load_ps1(BK + 2 * h); + BK += 3 * h; + for (int len_tmp = length; len_tmp > 0; --len_tmp, SK += 4, M += 4) { +#ifdef ENABLE_AVX + __m128 M1 = _mm_loadu_ps(M); + __m128 M2 = _mm_set1_ps(0.0f); + __m128 s1 = _mm_loadu_ps(SK); + M1 = _mm_fmadd_ps(s1, k1, M1); + __m128 s2 = _mm_loadu_ps(SK + S_step); + M2 = _mm_fmadd_ps(s2, k2, M2); + __m128 s3 = _mm_loadu_ps(SK + 2 * S_step); + M1 = _mm_fmadd_ps(s3, k3, M1); + M1 = _mm_add_ps(M1, M2); + _mm_storeu_ps(M, M1); +#else + __m128 M1 = _mm_loadu_ps(M); + __m128 s0 = _mm_loadu_ps(SK); + M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); + __m128 s1 = _mm_loadu_ps(SK + S_step); + s1 = _mm_mul_ps(s1, k2); + __m128 s3 = _mm_loadu_ps(SK + 2 * S_step); + M1 = _mm_add_ps(M1, _mm_mul_ps(s3, k3)); + M1 = _mm_add_ps(M1, s1); + _mm_storeu_ps(M, M1); +#endif + } + M -= len_c4; + SK += 3 * S_step - len_c4; + } + for (; k_tmp > 0; k_tmp -= 1) { + __m128 k1 = _mm_load_ps1(BK); + BK += h; + for (int len_tmp = length; len_tmp > 0; --len_tmp, SK += 4, M += 4) { + __m128 M1 = _mm_loadu_ps(M); + __m128 s0 = _mm_loadu_ps(SK); +#ifdef ENABLE_AVX + M1 = _mm_fmadd_ps(s0, k1, M1); +#else + M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); +#endif + _mm_storeu_ps(M, M1); + } + M -= len_c4; + SK += S_step - len_c4; + } + SW += len_c4; + M += len_c4; + } + B += 1; + } +} + +void WinogradTransRight(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) { + size_t len_c4 = length * 4, k_step = len_c4 * k; + for (int h1 = 0; h1 < h; ++h1, S += k_step) { + const float *BW = B; + memset(M, 0, len_c4 * w * sizeof(float)); + for (int ww = 0; ww < w; ++ww, BW += 1, M += len_c4) { + const float *SK = S, *BK = BW; + int k_tmp = k; + for (; k_tmp >= 7; k_tmp -= 7, M -= len_c4) { + __m128 k1 = _mm_load_ps1(BK); + __m128 k2 = _mm_load_ps1(BK + h); + __m128 k3 = _mm_load_ps1(BK + 2 * h); + __m128 k4 = _mm_load_ps1(BK + 3 * h); + __m128 k5 = _mm_load_ps1(BK + 4 * h); + __m128 k6 = _mm_load_ps1(BK + 5 * h); + __m128 k7 = _mm_load_ps1(BK + 6 * h); + BK += 7 * h; + const float *S2 = SK + len_c4, *S3 = S2 + len_c4; + const float *S4 = S3 + len_c4, *S5 = S4 + len_c4; + const float *S6 = S5 + len_c4, *S7 = S6 + len_c4; + for (int len_tmp = length; len_tmp > 0; + --len_tmp, M += 4, SK += 4, S2 += 4, S3 += 4, S4 += 4, S5 += 4, S6 += 4, S7 += 4) { +#ifdef ENABLE_AVX + __m128 M1 = _mm_loadu_ps(M); + __m128 M2 = _mm_set1_ps(0.0f); + __m128 s1 = _mm_loadu_ps(SK); + M1 = _mm_fmadd_ps(s1, k1, M1); + __m128 s2 = _mm_loadu_ps(S2); + M2 = _mm_fmadd_ps(s2, k2, M2); + __m128 s3 = _mm_loadu_ps(S3); + M1 = _mm_fmadd_ps(s3, k3, M1); + __m128 s4 = _mm_loadu_ps(S4); + M2 = _mm_fmadd_ps(s4, k4, M2); + __m128 s5 = _mm_loadu_ps(S5); + M1 = _mm_fmadd_ps(s5, k5, M1); + __m128 s6 = _mm_loadu_ps(S6); + M2 = _mm_fmadd_ps(s6, k6, M2); + __m128 s7 = _mm_loadu_ps(S7); + M1 = _mm_fmadd_ps(s7, k7, M1); + M1 = _mm_add_ps(M1, M2); + _mm_storeu_ps(M, M1); +#else + __m128 M1 = _mm_loadu_ps(M); + __m128 s0 = _mm_loadu_ps(SK); + M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); + __m128 s1 = _mm_loadu_ps(S2); + s1 = _mm_mul_ps(s1, k2); + __m128 s3 = _mm_loadu_ps(S3); + M1 = _mm_add_ps(M1, _mm_mul_ps(s3, k3)); + __m128 s4 = _mm_loadu_ps(S4); + s1 = _mm_add_ps(s1, _mm_mul_ps(s4, k4)); + __m128 s5 = _mm_loadu_ps(S5); + M1 = _mm_add_ps(M1, _mm_mul_ps(s5, k5)); + __m128 s6 = _mm_loadu_ps(S6); + s1 = _mm_add_ps(s1, _mm_mul_ps(s6, k6)); + __m128 s7 = _mm_loadu_ps(S7); + M1 = _mm_add_ps(M1, _mm_mul_ps(s7, k7)); + M1 = _mm_add_ps(M1, s1); + _mm_storeu_ps(M, M1); +#endif + } + SK = S7; + } + for (; k_tmp >= 4; k_tmp -= 4, M -= len_c4) { + __m128 k1 = _mm_load_ps1(BK); + __m128 k2 = _mm_load_ps1(BK + h); + __m128 k3 = _mm_load_ps1(BK + 2 * h); + __m128 k4 = _mm_load_ps1(BK + 3 * h); + BK += 4 * h; + const float *S2 = SK + len_c4; + const float *S3 = S2 + len_c4; + const float *S4 = S3 + len_c4; + int len_tmp = length; +#ifdef ENABLE_AVX + for (; len_tmp >= C2NUM; len_tmp -= C2NUM, M += C8NUM, SK += C8NUM, S2 += C8NUM, S3 += C8NUM, S4 += C8NUM) { + __m128 M1 = _mm_loadu_ps(M); + __m128 M2 = _mm_loadu_ps(M + C4NUM); + __m128 s1 = _mm_loadu_ps(SK); + __m128 s11 = _mm_loadu_ps(SK + C4NUM); + M1 = _mm_fmadd_ps(s1, k1, M1); + M2 = _mm_fmadd_ps(s11, k1, M2); + __m128 s2 = _mm_loadu_ps(S2); + __m128 s22 = _mm_loadu_ps(S2 + C4NUM); + M1 = _mm_fmadd_ps(s2, k2, M1); + M2 = _mm_fmadd_ps(s22, k2, M2); + __m128 s3 = _mm_loadu_ps(S3); + __m128 s33 = _mm_loadu_ps(S3 + C4NUM); + M1 = _mm_fmadd_ps(s3, k3, M1); + M2 = _mm_fmadd_ps(s33, k3, M2); + __m128 s4 = _mm_loadu_ps(S4); + __m128 s44 = _mm_loadu_ps(S4 + C4NUM); + M1 = _mm_fmadd_ps(s4, k4, M1); + M2 = _mm_fmadd_ps(s44, k4, M2); + _mm_storeu_ps(M, M1); + _mm_storeu_ps(M + C4NUM, M2); + } +#endif + for (; len_tmp > 0; --len_tmp, M += 4, SK += 4, S2 += 4, S3 += 4, S4 += 4) { +#ifdef ENABLE_AVX + __m128 M1 = _mm_loadu_ps(M); + __m128 M2 = _mm_set1_ps(0.0f); + __m128 s1 = _mm_loadu_ps(SK); + M1 = _mm_fmadd_ps(s1, k1, M1); + __m128 s2 = _mm_loadu_ps(S2); + M2 = _mm_fmadd_ps(s2, k2, M2); + __m128 s3 = _mm_loadu_ps(S3); + M1 = _mm_fmadd_ps(s3, k3, M1); + __m128 s4 = _mm_loadu_ps(S4); + M2 = _mm_fmadd_ps(s4, k4, M2); + M1 = _mm_add_ps(M1, M2); + _mm_storeu_ps(M, M1); +#else + __m128 M1 = _mm_loadu_ps(M); + __m128 s0 = _mm_loadu_ps(SK); + M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); + __m128 s1 = _mm_loadu_ps(S2); + s1 = _mm_mul_ps(s1, k2); + __m128 s3 = _mm_loadu_ps(S3); + M1 = _mm_add_ps(M1, _mm_mul_ps(s3, k3)); + __m128 s4 = _mm_loadu_ps(S4); + s1 = _mm_add_ps(s1, _mm_mul_ps(s4, k4)); + M1 = _mm_add_ps(M1, s1); + _mm_storeu_ps(M, M1); +#endif + } + SK = S4; + } + for (; k_tmp >= 3; k_tmp -= 3, M -= len_c4) { + __m128 k1 = _mm_load_ps1(BK); + __m128 k2 = _mm_load_ps1(BK + h); + __m128 k3 = _mm_load_ps1(BK + 2 * h); + BK += 3 * h; + const float *S2 = SK + len_c4; + const float *S3 = S2 + len_c4; + for (int len_tmp = length; len_tmp > 0; --len_tmp, M += 4, SK += 4, S2 += 4, S3 += 4) { +#ifdef ENABLE_AVX + __m128 M1 = _mm_loadu_ps(M); + __m128 M2 = _mm_set1_ps(0.0f); + __m128 s0 = _mm_loadu_ps(SK); + M1 = _mm_fmadd_ps(s0, k1, M1); + __m128 s1 = _mm_loadu_ps(S2); + M2 = _mm_fmadd_ps(s1, k2, M2); + __m128 s3 = _mm_loadu_ps(S3); + M1 = _mm_fmadd_ps(s3, k3, M1); + M1 = _mm_add_ps(M1, M2); + _mm_storeu_ps(M, M1); +#else + __m128 M1 = _mm_loadu_ps(M); + __m128 s0 = _mm_loadu_ps(SK); + M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); + __m128 s1 = _mm_loadu_ps(S2); + s1 = _mm_mul_ps(s1, k2); + __m128 s3 = _mm_loadu_ps(S3); + M1 = _mm_add_ps(M1, _mm_mul_ps(s3, k3)); + M1 = _mm_add_ps(M1, s1); + _mm_storeu_ps(M, M1); +#endif + } + SK = S3; + } + for (; k_tmp >= 1; k_tmp -= 1, M -= len_c4) { + __m128 k1 = _mm_load_ps1(BK); + BK += h; + for (int len_tmp = length; len_tmp > 0; --len_tmp, M += 4, SK += 4) { + __m128 M1 = _mm_loadu_ps(M); + __m128 s0 = _mm_loadu_ps(SK); +#ifdef ENABLE_AVX + M1 = _mm_fmadd_ps(s0, k1, M1); +#else + M1 = _mm_add_ps(M1, _mm_mul_ps(s0, k1)); +#endif + _mm_storeu_ps(M, M1); + } + } + } + } +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/sse/sse_common.h b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/sse/sse_common.h new file mode 100644 index 0000000000000000000000000000000000000000..5885954a77e715eec6396f009c4b87e0767fc933 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/intrinsics/sse/sse_common.h @@ -0,0 +1,390 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_NNACL_INTRINSICS_SSE_SSE_COMMON_H_ +#define MINDSPORE_LITE_NNACL_INTRINSICS_SSE_SSE_COMMON_H_ + +#define SSE_ROW_NUM_1 1 +#define SSE_ROW_NUM_2 2 +#define SSE_ROW_NUM_3 3 + +#define SSE_INDEX_1 1 +#define SSE_INDEX_2 2 +#define SSE_INDEX_3 3 +#define SSE_INDEX_4 4 +#define SSE_INDEX_5 5 +#define SSE_INDEX_6 6 + +#define SSE_SHUFFLE_0321 (_MM_SHUFFLE(0, 3, 2, 1)) + +#define SSE_ACT_RELU 1 +#define SSE_ACT_RELU6 3 + +static inline void ActBlock1(__m128 *v1, size_t relu, size_t relu6) { + __m128 zero_ma = _mm_setzero_ps(); + if (relu || relu6) { + *v1 = _mm_max_ps(zero_ma, *v1); + } + if (relu6) { + __m128 relu6_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f); + *v1 = _mm_min_ps(relu6_ma, *v1); + } +} + +static inline void ActBlock2(__m128 *v1, __m128 *v2, size_t relu, size_t relu6) { + __m128 zero_ma = _mm_setzero_ps(); + if (relu || relu6) { + *v1 = _mm_max_ps(zero_ma, *v1); + *v2 = _mm_max_ps(zero_ma, *v2); + } + if (relu6) { + __m128 relu6_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f); + *v1 = _mm_min_ps(relu6_ma, *v1); + *v2 = _mm_min_ps(relu6_ma, *v2); + } +} + +static inline void ActBlock4(__m128 *v1, __m128 *v2, __m128 *v3, __m128 *v4, size_t relu, size_t relu6) { + __m128 zero_ma = _mm_setzero_ps(); + if (relu || relu6) { + *v1 = _mm_max_ps(zero_ma, *v1); + *v2 = _mm_max_ps(zero_ma, *v2); + *v3 = _mm_max_ps(zero_ma, *v3); + *v4 = _mm_max_ps(zero_ma, *v4); + } + if (relu6) { + __m128 relu6_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f); + *v1 = _mm_min_ps(relu6_ma, *v1); + *v2 = _mm_min_ps(relu6_ma, *v2); + *v3 = _mm_min_ps(relu6_ma, *v3); + *v4 = _mm_min_ps(relu6_ma, *v4); + } +} + +static inline void ActBlock12(__m128 *v1, __m128 *v2, __m128 *v3, __m128 *v4, __m128 *v5, __m128 *v6, __m128 *v7, + __m128 *v8, __m128 *v9, __m128 *v10, __m128 *v11, __m128 *v12, size_t relu, + size_t relu6) { + if (relu || relu6) { + __m128 zero_ma = _mm_setzero_ps(); + *v1 = _mm_max_ps(zero_ma, *v1); + *v2 = _mm_max_ps(zero_ma, *v2); + *v3 = _mm_max_ps(zero_ma, *v3); + *v4 = _mm_max_ps(zero_ma, *v4); + *v5 = _mm_max_ps(zero_ma, *v5); + *v6 = _mm_max_ps(zero_ma, *v6); + *v7 = _mm_max_ps(zero_ma, *v7); + *v8 = _mm_max_ps(zero_ma, *v8); + *v9 = _mm_max_ps(zero_ma, *v9); + *v10 = _mm_max_ps(zero_ma, *v10); + *v11 = _mm_max_ps(zero_ma, *v11); + *v12 = _mm_max_ps(zero_ma, *v12); + } + if (relu6) { + __m128 relu6_ma = _mm_set_ps(6.0f, 6.0f, 6.0f, 6.0f); + *v1 = _mm_min_ps(relu6_ma, *v1); + *v2 = _mm_min_ps(relu6_ma, *v2); + *v3 = _mm_min_ps(relu6_ma, *v3); + *v4 = _mm_min_ps(relu6_ma, *v4); + *v5 = _mm_min_ps(relu6_ma, *v5); + *v6 = _mm_min_ps(relu6_ma, *v6); + *v7 = _mm_min_ps(relu6_ma, *v7); + *v8 = _mm_min_ps(relu6_ma, *v8); + *v9 = _mm_min_ps(relu6_ma, *v9); + *v10 = _mm_min_ps(relu6_ma, *v10); + *v11 = _mm_min_ps(relu6_ma, *v11); + *v12 = _mm_min_ps(relu6_ma, *v12); + } +} + +static inline void ActBlock8(__m128 *v1, __m128 *v2, __m128 *v3, __m128 *v4, __m128 *v5, __m128 *v6, __m128 *v7, + __m128 *v8, size_t relu_type) { + __m128 relu6 = _mm_set_ps1(6.0); + __m128 zero = _mm_setzero_ps(); + switch (relu_type) { + case SSE_ACT_RELU6: + *v1 = _mm_min_ps(*v1, relu6); + *v2 = _mm_min_ps(*v2, relu6); + *v3 = _mm_min_ps(*v3, relu6); + *v4 = _mm_min_ps(*v4, relu6); + *v5 = _mm_min_ps(*v5, relu6); + *v6 = _mm_min_ps(*v6, relu6); + *v7 = _mm_min_ps(*v7, relu6); + *v8 = _mm_min_ps(*v8, relu6); + case SSE_ACT_RELU: + *v1 = _mm_max_ps(*v1, zero); + *v2 = _mm_max_ps(*v2, zero); + *v3 = _mm_max_ps(*v3, zero); + *v4 = _mm_max_ps(*v4, zero); + *v5 = _mm_max_ps(*v5, zero); + *v6 = _mm_max_ps(*v6, zero); + *v7 = _mm_max_ps(*v7, zero); + *v8 = _mm_max_ps(*v8, zero); + default: + break; + } +} + +static inline void WriteCol1(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, + __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) { + _mm_store_ss(*dst, *dst1); + if (r > SSE_ROW_NUM_1) { + *dst += stride; + _mm_store_ss(*dst, *dst3); + } + if (r > SSE_ROW_NUM_2) { + *dst += stride; + _mm_store_ss(*dst, *dst5); + } + if (r > SSE_ROW_NUM_3) { + *dst += stride; + _mm_store_ss(*dst, *dst7); + *dst += stride; + *dst += extra_stride; + } +} + +static inline void WriteCol2(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, + __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int r) { + _mm_store_ss(*dst, *dst1); + *dst1 = _mm_shuffle_ps(*dst1, *dst1, SSE_SHUFFLE_0321); + _mm_store_ss(*dst, *dst1); + if (r > SSE_ROW_NUM_1) { + *dst += stride; + _mm_store_ss(*dst, *dst3); + *dst3 = _mm_shuffle_ps(*dst3, *dst3, SSE_SHUFFLE_0321); + _mm_store_ss(*dst, *dst3); + } + if (r > SSE_ROW_NUM_2) { + *dst += stride; + _mm_store_ss(*dst, *dst5); + *dst5 = _mm_shuffle_ps(*dst5, *dst5, SSE_SHUFFLE_0321); + _mm_store_ss(*dst, *dst5); + } + if (r > SSE_ROW_NUM_3) { + *dst += stride; + _mm_store_ss(*dst, *dst7); + *dst7 = _mm_shuffle_ps(*dst7, *dst7, SSE_SHUFFLE_0321); + _mm_store_ss(*dst, *dst7); + } +} + +static inline void WriteCol2Opt(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, + __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int r) { + _mm_store_ss(*dst, *dst1); + *dst1 = _mm_shuffle_ps(*dst1, *dst1, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_1, *dst1); + if (r > SSE_ROW_NUM_1) { + *dst += stride; + _mm_store_ss(*dst, *dst3); + *dst3 = _mm_shuffle_ps(*dst3, *dst3, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_1, *dst3); + } + if (r > SSE_ROW_NUM_2) { + *dst += stride; + _mm_store_ss(*dst, *dst5); + *dst5 = _mm_shuffle_ps(*dst5, *dst5, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_1, *dst5); + } + if (r > SSE_ROW_NUM_3) { + *dst += stride; + _mm_store_ss(*dst, *dst7); + *dst7 = _mm_shuffle_ps(*dst7, *dst7, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_1, *dst7); + *dst += stride; + *dst += SSE_INDEX_2; + } +} + +static inline void WriteCol3(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, + __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) { + if (r > SSE_ROW_NUM_1) { + *dst += stride; + _mm_store_ss(*dst, *dst3); + *dst3 = _mm_shuffle_ps(*dst3, *dst3, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_1, *dst3); + *dst3 = _mm_shuffle_ps(*dst3, *dst3, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_2, *dst3); + } + if (r > SSE_ROW_NUM_2) { + *dst += stride; + _mm_store_ss(*dst, *dst5); + *dst5 = _mm_shuffle_ps(*dst5, *dst5, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_1, *dst5); + *dst5 = _mm_shuffle_ps(*dst5, *dst5, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_2, *dst5); + } + if (r > SSE_ROW_NUM_3) { + *dst += stride; + _mm_store_ss(*dst, *dst7); + *dst7 = _mm_shuffle_ps(*dst7, *dst7, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_1, *dst7); + *dst7 = _mm_shuffle_ps(*dst7, *dst7, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_2, *dst7); + *dst += stride; + *dst += extra_stride; + } +} + +static inline void WriteCol4(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, + __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) { + _mm_storeu_ps(*dst, *dst1); + if (r > SSE_ROW_NUM_1) { + *dst += stride; + _mm_storeu_ps(*dst, *dst3); + } + if (r > SSE_ROW_NUM_2) { + *dst += stride; + _mm_storeu_ps(*dst, *dst5); + } + if (r > SSE_ROW_NUM_3) { + *dst += stride; + _mm_storeu_ps(*dst, *dst7); + *dst += stride; + *dst += extra_stride; + } +} + +static inline void WriteCol5(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, + __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) { + _mm_storeu_ps(*dst, *dst1); + _mm_store_ss(*dst + SSE_INDEX_4, *dst2); + if (r > SSE_ROW_NUM_1) { + *dst += stride; + _mm_storeu_ps(*dst, *dst3); + _mm_store_ss(*dst + SSE_INDEX_4, *dst4); + } + if (r > SSE_ROW_NUM_2) { + *dst += stride; + _mm_storeu_ps(*dst, *dst5); + _mm_store_ss(*dst + SSE_INDEX_4, *dst6); + } + if (r > SSE_ROW_NUM_3) { + *dst += stride; + _mm_storeu_ps(*dst, *dst7); + _mm_store_ss(*dst + SSE_INDEX_4, *dst8); + *dst += stride; + *dst += extra_stride; + } +} + +static inline void WriteCol6(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, + __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) { + _mm_storeu_ps(*dst, *dst1); + _mm_store_ss(*dst + SSE_INDEX_4, *dst2); + *dst2 = _mm_shuffle_ps(*dst2, *dst2, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_5, *dst2); + if (r > SSE_ROW_NUM_1) { + *dst += stride; + _mm_storeu_ps(*dst, *dst3); + _mm_store_ss(*dst + SSE_INDEX_4, *dst4); + *dst4 = _mm_shuffle_ps(*dst4, *dst4, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_5, *dst4); + } + if (r > SSE_ROW_NUM_2) { + *dst += stride; + _mm_storeu_ps(*dst, *dst5); + _mm_store_ss(*dst + SSE_INDEX_4, *dst6); + *dst6 = _mm_shuffle_ps(*dst6, *dst6, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_5, *dst6); + } + if (r > SSE_ROW_NUM_3) { + *dst += stride; + _mm_storeu_ps(*dst, *dst7); + _mm_store_ss(*dst + SSE_INDEX_4, *dst8); + *dst8 = _mm_shuffle_ps(*dst8, *dst8, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_5, *dst8); + *dst += stride; + *dst += extra_stride; + } +} + +static inline void WriteCol7(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, + __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) { + _mm_storeu_ps(*dst, *dst1); + _mm_store_ss(*dst + SSE_INDEX_4, *dst2); + *dst2 = _mm_shuffle_ps(*dst2, *dst2, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_5, *dst2); + *dst2 = _mm_shuffle_ps(*dst2, *dst2, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_6, *dst2); + if (r > SSE_ROW_NUM_1) { + *dst += stride; + _mm_storeu_ps(*dst, *dst3); + _mm_store_ss(*dst + SSE_INDEX_4, *dst4); + *dst4 = _mm_shuffle_ps(*dst4, *dst4, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_5, *dst4); + *dst4 = _mm_shuffle_ps(*dst4, *dst4, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_6, *dst4); + } + if (r > SSE_ROW_NUM_2) { + *dst += stride; + _mm_storeu_ps(*dst, *dst5); + _mm_store_ss(*dst + SSE_INDEX_4, *dst6); + *dst6 = _mm_shuffle_ps(*dst6, *dst6, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_5, *dst6); + *dst6 = _mm_shuffle_ps(*dst6, *dst6, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_6, *dst6); + } + if (r > SSE_ROW_NUM_3) { + *dst += stride; + _mm_storeu_ps(*dst, *dst7); + _mm_store_ss(*dst + SSE_INDEX_4, *dst8); + *dst8 = _mm_shuffle_ps(*dst8, *dst8, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_5, *dst8); + *dst8 = _mm_shuffle_ps(*dst8, *dst8, SSE_SHUFFLE_0321); + _mm_store_ss(*dst + SSE_INDEX_6, *dst8); + *dst += stride; + *dst += extra_stride; + } +} + +static inline void WriteCol8(float **dst, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, __m128 *dst5, + __m128 *dst6, __m128 *dst7, __m128 *dst8, int stride, int extra_stride, int r) { + _mm_storeu_ps(*dst, *dst1); + _mm_storeu_ps(*dst + SSE_INDEX_4, *dst2); + if (r > SSE_ROW_NUM_1) { + *dst += stride; + _mm_storeu_ps(*dst, *dst3); + _mm_storeu_ps(*dst + SSE_INDEX_4, *dst4); + } + if (r > SSE_ROW_NUM_2) { + *dst += stride; + _mm_storeu_ps(*dst, *dst5); + _mm_storeu_ps(*dst + SSE_INDEX_4, *dst6); + } + if (r > SSE_ROW_NUM_3) { + *dst += stride; + _mm_storeu_ps(*dst, *dst7); + _mm_storeu_ps(*dst + SSE_INDEX_4, *dst8); + *dst += stride; + *dst += extra_stride; + } +} + +static inline void DoBiasBlock8(const float *bias_ptr, __m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4, + __m128 *dst5, __m128 *dst6, __m128 *dst7, __m128 *dst8) { + __m128 bias1 = _mm_loadu_ps(bias_ptr); + __m128 bias2 = _mm_loadu_ps(bias_ptr + C4NUM); + *dst1 = _mm_add_ps(*dst1, bias1); + *dst2 = _mm_add_ps(*dst2, bias2); + *dst3 = _mm_add_ps(*dst3, bias1); + *dst4 = _mm_add_ps(*dst4, bias2); + *dst5 = _mm_add_ps(*dst5, bias1); + *dst6 = _mm_add_ps(*dst6, bias2); + *dst7 = _mm_add_ps(*dst7, bias1); + *dst8 = _mm_add_ps(*dst8, bias2); +} + +#endif // MINDSPORE_LITE_NNACL_INTRINSICS_SSE_SSE_COMMON_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel.c new file mode 100644 index 0000000000000000000000000000000000000000..8623557f0ab74b90b9ee9bc208c05eec988db25a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel.c @@ -0,0 +1,124 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel.h" +#include "nnacl/tensor_c.h" +#include "nnacl/op_base.h" +#include "nnacl/kernel/init_exec_env.h" + +static KernelCreator g_kernelCreatorRegistry[PrimType_MAX][16]; + +void RegKernelCreator(int opType, int dataType, KernelCreator creator) { + g_kernelCreatorRegistry[opType][REGIST_DT(dataType)] = creator; +} + +void Init_MSC_VER_kernels(void) { +#ifdef _MSC_VER + /* VS env do not support automatic register + * register here first time */ + static bool inited = false; + if (inited == false) { + init_vs_kernels(g_kernelCreatorRegistry); + inited = true; + } +#endif + return; +} + +bool checkOpValid(int opType) { + if (opType < PrimType_MIN || opType >= PrimType_MAX) { + return false; + } + return true; +} + +bool SupportKernelC(int opType, int dataType) { + Init_MSC_VER_kernels(); + const int length = 16; + if (REGIST_DT(dataType) < 0 || REGIST_DT(dataType) >= length) { + return false; + } + if (!checkOpValid(opType)) { + return false; + } + KernelCreator creator = g_kernelCreatorRegistry[opType][REGIST_DT(dataType)]; + return creator != NULL; +} + +int DefaultThreadUpdate(int32_t type, int64_t load, int64_t store, int64_t unit, int thread) { + return thread > 0 ? thread : 1; +} + +int NNACLKernelInferShape(struct KernelBase *self) { return NNACL_ERR; } + +int NNACLCheckKernelBase(KernelBase *kernel_base) { + CheckExecEnv(kernel_base); + + if (kernel_base->param_ == NULL) { + return NNACL_ERR; + } + + if (kernel_base->thread_nr_ <= 0 || kernel_base->thread_nr_ > MAX_THREAD_NUM) { + return NNACL_ERR; + } + + if (kernel_base->in_size_ == 0 || kernel_base->in_ == NULL) { + return NNACL_ERR; + } + if (kernel_base->out_size_ == 0 || kernel_base->out_ == NULL) { + return NNACL_ERR; + } + return NNACL_OK; +} + +KernelBase *CreateKernel(OpParameter *param, TensorC **ins, size_t in_size, TensorC **outs, size_t out_size, + int data_type, ExecEnv *env) { + Init_MSC_VER_kernels(); + if (param == NULL) { + return NULL; + } + if (!checkOpValid(param->type_)) { + return NULL; + } + + KernelCreator creator = g_kernelCreatorRegistry[param->type_][REGIST_DT(data_type)]; + if (creator == NULL) { + return NULL; + } + + KernelBase *kernel_base = creator(param, data_type); + if (kernel_base == NULL) { + return NULL; + } + + kernel_base->InferShape = NNACLKernelInferShape; + kernel_base->UpdateThread = DefaultThreadUpdate; + kernel_base->env_ = env; + kernel_base->param_ = param; + kernel_base->thread_nr_ = param->thread_num_; + kernel_base->train_session_ = param->is_train_session_; + kernel_base->in_ = ins; + kernel_base->in_size_ = in_size; + kernel_base->out_ = outs; + kernel_base->out_size_ = out_size; + + int ret = NNACLCheckKernelBase(kernel_base); + if (ret != NNACL_OK) { + return NULL; + } + + return kernel_base; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..d96c14634e42b25ba3da3c79c13c2ab97a5abe64 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel.h @@ -0,0 +1,69 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_H_ +#define NNACL_KERNEL_H_ + +#include "nnacl/op_base.h" +#include "nnacl/infer/common_infer.h" + +typedef struct ExecEnv { + void *allocator_; + void *thread_pool_; + void *(*Alloc)(void *allocator, size_t sz); + void (*Free)(void *allocator, void *ptr); + int (*ParallelLaunch)(void *thread_pool, void *task, void *param, int task_num); +} ExecEnv; + +typedef struct KernelBase { + int (*Release)(struct KernelBase *self); + int (*Prepare)(struct KernelBase *self); + int (*Compute)(struct KernelBase *self); + int (*Resize)(struct KernelBase *self); + int (*InferShape)(struct KernelBase *self); + int (*UpdateThread)(int32_t type, int64_t load, int64_t store, int64_t unit, int thread); + OpParameter *param_; + int thread_nr_; + ExecEnv *env_; + TensorC **in_; + size_t in_size_; + TensorC **out_; + size_t out_size_; + bool train_session_; + void *workspace_; /* only used in train */ + int work_size_; /* only used in train */ +} KernelBase; + +#ifdef _MSC_VER +#define REG_KERNEL_CREATOR(op, data_type, func) +#else +#define REG_KERNEL_CREATOR(op, data_type, func) \ + __attribute__((constructor(102))) void Reg##op##data_type##Creator() { RegKernelCreator(op, data_type, func); } +#endif + +#define REGIST_DT(DataType) (DataType - kNumberTypeBegin - 1) +typedef KernelBase *(*KernelCreator)(OpParameter *param, int data_type); +void RegKernelCreator(int opType, int dataType, KernelCreator func); + +#ifdef __cplusplus +extern "C" { +#endif +KernelBase *CreateKernel(OpParameter *param, TensorC **ins, size_t in_size, TensorC **outs, size_t out_size, + int data_type, ExecEnv *env); +bool SupportKernelC(int opType, int dataType); +#ifdef __cplusplus +} +#endif +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/activation.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/activation.c new file mode 100644 index 0000000000000000000000000000000000000000..f1433c12c57a261c4ec8d0e17b92cca64eb118e3 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/activation.c @@ -0,0 +1,194 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/activation.h" +#include "nnacl/activation_parameter.h" +#include "nnacl/op_base.h" +#include "nnacl/fp32/activation_fp32.h" +#include "nnacl/kernel/default_kernel_base.h" +#include "nnacl/tensor_c_utils.h" +#ifdef ENABLE_FP16 +#include "nnacl/fp16/activation_fp16.h" +#endif + +typedef struct ActivationStruct { + KernelBase base; + int data_type_; + ActType act_type_; +} ActivationStruct; + +int ActivationResize(struct KernelBase *self) { + ActivationStruct *activation = (ActivationStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(activation); + self->thread_nr_ = self->UpdateThread(TC_TYPE(PrimType_Activation, activation->act_type_), 1, 1, + NNACLGetElementNum(self->out_[0]), self->thread_nr_); + return NNACL_OK; +} + +int activation_fp32_run(ActivationStruct *activation, int task_id, int count, int stride) { + ActivationParameter *param = (ActivationParameter *)activation->base.param_; + float *input = activation->base.in_[0]->data_; + float *output = activation->base.out_[0]->data_; + NNACL_CHECK_NULL_RETURN_ERR(input); + NNACL_CHECK_NULL_RETURN_ERR(output); + + switch (activation->act_type_) { + case ActType_Relu: + return Fp32Relu(input + task_id * stride, count, output + task_id * stride); + case ActType_Relu6: + return Fp32Relu6(input + task_id * stride, count, output + task_id * stride); + case ActType_LeakyRelu: + return LRelu(input + task_id * stride, count, output + task_id * stride, param->alpha_); + case ActType_Sigmoid: + return Sigmoid(input + task_id * stride, count, output + task_id * stride); + case ActType_Tanh: + return Tanh(input + task_id * stride, count, output + task_id * stride); + case ActType_Swish: + return Swish(input + task_id * stride, count, output + task_id * stride); + case ActType_HSwish: + return HSwish(input + task_id * stride, count, output + task_id * stride); + case ActType_HSigmoid: + return HSigmoid(input + task_id * stride, count, output + task_id * stride); + case ActType_HardTanh: + return HardTanh(input + task_id * stride, count, output + task_id * stride, param->min_val_, param->max_val_); + case ActType_Gelu: + return Gelu(input + task_id * stride, count, output + task_id * stride, param->approximate_); + case ActType_Softplus: + return Softplus(input + task_id * stride, count, output + task_id * stride); + case ActType_Elu: + return Elu(input + task_id * stride, count, output + task_id * stride, param->alpha_); + default: + return NNACL_ACTIVATION_TYPE_INVALID; + } +} + +int activation_int32_run(ActivationStruct *activation, int task_id, int count, int stride) { + int32_t *input = activation->base.in_[0]->data_; + int32_t *output = activation->base.out_[0]->data_; + NNACL_CHECK_NULL_RETURN_ERR(input); + NNACL_CHECK_NULL_RETURN_ERR(output); + + switch (activation->act_type_) { + case ActType_Relu: + return Int32Relu(input + task_id * stride, count, output + task_id * stride); + default: + return NNACL_ACTIVATION_TYPE_INVALID; + } +} + +int activation_fp16_run(ActivationStruct *activation, int task_id, int count, int stride) { +#ifdef ENABLE_FP16 + ActivationParameter *param = (ActivationParameter *)activation->base.param_; + float16_t *input = activation->base.in_[0]->data_; + float16_t *output = activation->base.out_[0]->data_; + NNACL_CHECK_NULL_RETURN_ERR(input); + NNACL_CHECK_NULL_RETURN_ERR(output); + + switch (activation->act_type_) { + case ActType_Relu: + return ReluFp16(input + stride * task_id, output + stride * task_id, count); + case ActType_Relu6: + return Relu6Fp16(input + stride * task_id, output + stride * task_id, count); + case ActType_LeakyRelu: + return LReluFp16(input + stride * task_id, output + stride * task_id, count, param->alpha_); + case ActType_Sigmoid: + return SigmoidFp16(input + stride * task_id, output + stride * task_id, count); + case ActType_Tanh: + return TanhFp16(input + stride * task_id, output + stride * task_id, count); + case ActType_HSwish: + return HSwishFp16(input + stride * task_id, output + stride * task_id, count); + case ActType_Swish: + return SwishFp16(input + stride * task_id, output + stride * task_id, count); + case ActType_HSigmoid: + return HSigmoidFp16(input + stride * task_id, output + stride * task_id, count); + case ActType_HardTanh: + return HardTanhFp16(input + stride * task_id, count, output + stride * task_id, param->min_val_, param->max_val_); + case ActType_Gelu: + return GeluFp16(input + stride * task_id, count, output + stride * task_id, true); + case ActType_Softplus: + return SoftplusFp16(input + stride * task_id, count, output + stride * task_id); + case ActType_Elu: + return EluFp16(input + stride * task_id, count, output + stride * task_id, param->alpha_); + default: + return NNACL_ACTIVATION_TYPE_INVALID; + } +#endif + return NNACL_DISABLE_FP16; +} + +int ActivationImpl(void *cdata, int task_id, float l, float r) { + ActivationStruct *activation = (ActivationStruct *)cdata; + + int ele_num = NNACLGetElementNum(activation->base.in_[0]); + NNACL_CHECK_ZERO_RETURN_ERR(activation->base.thread_nr_); + int stride = UP_DIV(ele_num, activation->base.thread_nr_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(stride, task_id, NNACL_ERR); + int count = MSMIN(stride, ele_num - stride * task_id); + if (count <= 0) { + return NNACL_OK; + } + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(stride, task_id), NNACL_ERR); + + switch (activation->data_type_) { + case kNumberTypeFloat32: + return activation_fp32_run(activation, task_id, count, stride); + case kNumberTypeFloat16: + return activation_fp16_run(activation, task_id, count, stride); + case kNumberTypeInt32: + return activation_int32_run(activation, task_id, count, stride); + default: + return NNACL_ACTIVATION_TYPE_INVALID; + } +} + +int ActivationCompute(struct KernelBase *self) { + return self->env_->ParallelLaunch(self->env_->thread_pool_, ActivationImpl, self, self->thread_nr_); +} + +KernelBase *CreateActivation(OpParameter *param, int data_type) { + ActivationParameter *act = (ActivationParameter *)(param); + + int type = act->type_; + if (data_type == kNumberTypeInt32) { + if (type != ActType_Relu) { + return NULL; + } + } + + if (data_type == kNumberTypeFloat32 || data_type == kNumberTypeFloat16) { + if (type != ActType_Relu && type != ActType_Relu6 && type != ActType_LeakyRelu && type != ActType_Sigmoid && + type != ActType_Tanh && type != ActType_HSwish && type != ActType_Swish && type != ActType_HardTanh && + type != ActType_Gelu && type != ActType_HSigmoid && type != ActType_Softplus && type != ActType_Elu) { + return NULL; + } + } + + ActivationStruct *activation = (ActivationStruct *)malloc(sizeof(ActivationStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(activation); + memset(activation, 0, sizeof(ActivationStruct)); + + activation->data_type_ = data_type; + activation->act_type_ = act->type_; + activation->base.Prepare = DefaultPrepare1In1Out; + activation->base.Release = DefaultRelease; + activation->base.Resize = ActivationResize; + activation->base.Compute = ActivationCompute; + return (KernelBase *)activation; +} + +REG_KERNEL_CREATOR(PrimType_Activation, kNumberTypeFloat32, CreateActivation) +REG_KERNEL_CREATOR(PrimType_Activation, kNumberTypeFloat16, CreateActivation) +REG_KERNEL_CREATOR(PrimType_Activation, kNumberTypeUInt32, CreateActivation) diff --git a/mindspore-lite/src/extendrt/graph_compiler/type.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/activation.h similarity index 68% rename from mindspore-lite/src/extendrt/graph_compiler/type.h rename to mindspore-lite/ops/kernel/cpu/nnacl/kernel/activation.h index fdd6c1a491be3e8ba891ede5542d94de1dfcb521..3169869e06205721d379b32b2e326c6a808bed7e 100644 --- a/mindspore-lite/src/extendrt/graph_compiler/type.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/activation.h @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_LITE_EXTENDRT_GRAPH_COMIPLER_TYPE_H_ -#define MINDSPORE_LITE_EXTENDRT_GRAPH_COMIPLER_TYPE_H_ +#ifndef NNACL_KERNEL_ACTIVATION_H_ +#define NNACL_KERNEL_ACTIVATION_H_ -namespace mindspore { -enum GraphCompilerType { kDefaultCompiler = 0, kNoneCompiler }; -} // namespace mindspore -#endif // MINDSPORE_LITE_EXTENDRT_GRAPH_COMIPLER_TYPE_H_ +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +KernelBase *CreateActivation(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_ACTIVATION_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/addn.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/addn.c new file mode 100644 index 0000000000000000000000000000000000000000..d1f2a0040cd91049f1c4444afdad3e70623bba98 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/addn.c @@ -0,0 +1,144 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/addn.h" +#include "nnacl/fp32/add_fp32.h" +#include "nnacl/tensor_c_utils.h" +#include "nnacl/kernel/default_kernel_base.h" +#ifdef ENABLE_FP16 +#include "nnacl/fp16/arithmetic_fp16.h" +#endif + +int AddNLaunch(void *cdata, int task_id, float l, float r) { + AddNStruct *addn = (AddNStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(addn); + + int count_per_thread = UP_DIV(addn->elements_num_, addn->base_.thread_nr_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(task_id, count_per_thread, NNACL_ERR); + int count = MSMIN(count_per_thread, addn->elements_num_ - task_id * count_per_thread); + int stride = count_per_thread * task_id; + +#ifdef ENABLE_FP16 + if (addn->data_type_ == kNumberTypeFloat16) { + return ElementAddFp16((float16_t *)addn->in1_addr_ + stride, (float16_t *)addn->in2_addr_ + stride, + (float16_t *)addn->out_addr_ + stride, count); + } +#endif + return ElementAdd((float *)addn->in1_addr_ + stride, (float *)addn->in2_addr_ + stride, + (float *)addn->out_addr_ + stride, count); +} + +void AddNCompute(AddNStruct *addn, bool same_shape, bool first_scalar) { +#ifdef ENABLE_FP16 + if (addn->data_type_ == kNumberTypeFloat16) { + if (same_shape) { + ElementAddFp16((float16_t *)addn->in1_addr_, (float16_t *)addn->in2_addr_, (float16_t *)addn->out_addr_, + addn->elements_num_); + } else { + ElementOptAddFp16((float16_t *)addn->in1_addr_, (float16_t *)addn->in2_addr_, (float16_t *)addn->out_addr_, + addn->elements_num_, first_scalar); + } + return; + } +#endif + + if (same_shape) { + ElementAdd((float *)addn->in1_addr_, (float *)addn->in2_addr_, (float *)addn->out_addr_, addn->elements_num_); + } else { + ElementOptAdd((float *)addn->in1_addr_, (float *)addn->in2_addr_, (float *)addn->out_addr_, addn->elements_num_, + first_scalar); + } + return; +} + +int AddNComputeNoParallel(AddNStruct *addn) { + TensorC *in0_tensor = addn->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in0_tensor); + TensorC *in1_tensor = addn->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in1_tensor); + AddNCompute(addn, NNACLIsShapeSame(in0_tensor, in1_tensor), NNACLGetElementNum(in0_tensor) == 1); + + for (size_t i = Index2; i < addn->base_.in_size_; i++) { + TensorC *in_tensor = addn->base_.in_[i]; + NNACL_CHECK_NULL_RETURN_ERR(in_tensor); + addn->in1_addr_ = in_tensor->data_; + addn->in2_addr_ = addn->out_addr_; + AddNCompute(addn, NNACLIsShapeSame(in_tensor, addn->base_.out_[OUTPUT_INDEX]), NNACLGetElementNum(in_tensor) == 1); + } + return NNACL_OK; +} + +int AddnResize(struct KernelBase *self) { + AddNStruct *addn = (AddNStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(addn); + + TensorC *out_tensor = self->out_[OUTPUT_INDEX]; + addn->elements_num_ = NNACLGetElementNum(out_tensor); + return NNACL_OK; +} + +int AddnCompute(struct KernelBase *self) { + AddNStruct *addn = (AddNStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(addn); + + addn->in1_addr_ = self->in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(addn->in1_addr_); + addn->in2_addr_ = self->in_[SECOND_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(addn->in2_addr_); + addn->out_addr_ = self->out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(addn->out_addr_); + + if (addn->elements_num_ < self->thread_nr_) { + return AddNComputeNoParallel(addn); + } + + for (int i = 0; i < self->in_size_; i++) { + TensorC *in_tensor = self->in_[i]; + if (!NNACLIsShapeSame(in_tensor, self->out_[OUTPUT_INDEX])) { + return NNACL_ADDN_SHAPE_UNMATCH; + } + } + + int ret = self->env_->ParallelLaunch(self->env_->thread_pool_, AddNLaunch, self, self->thread_nr_); + if (ret != NNACL_OK) { + return ret; + } + + for (size_t i = Index2; i < self->in_size_; ++i) { + addn->in1_addr_ = self->in_[i]->data_; + NNACL_CHECK_NULL_RETURN_ERR(addn->in1_addr_); + addn->in2_addr_ = addn->out_addr_; + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, AddNLaunch, self, self->thread_nr_); + if (ret != NNACL_OK) { + return ret; + } + } + return NNACL_OK; +} + +KernelBase *CreateAddN(OpParameter *param, int data_type) { + AddNStruct *addn = (AddNStruct *)malloc(sizeof(AddNStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(addn); + addn->data_type_ = data_type; + addn->base_.Prepare = DefaultPrepare1In1Out; + addn->base_.Resize = AddnResize; + addn->base_.Release = DefaultRelease; + addn->base_.Compute = AddnCompute; + return (KernelBase *)addn; +} + +REG_KERNEL_CREATOR(PrimType_AddN, kNumberTypeFloat16, CreateAddN) +REG_KERNEL_CREATOR(PrimType_AddN, kNumberTypeFloat32, CreateAddN) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/addn.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/addn.h new file mode 100644 index 0000000000000000000000000000000000000000..553a6eaa13fdfae27a26e09fdce939fd001e998c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/addn.h @@ -0,0 +1,35 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_ADDN_H_ +#define NNACL_KERNEL_ADDN_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct AddNStruct { + KernelBase base_; + int data_type_; + int elements_num_; + void *in1_addr_; + void *in2_addr_; + void *out_addr_; +} AddNStruct; + +KernelBase *CreateAddN(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_ADDN_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/arg_min_max.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/arg_min_max.c new file mode 100644 index 0000000000000000000000000000000000000000..3c12d85099730ff4215510855f32a67772aa6c9a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/arg_min_max.c @@ -0,0 +1,127 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/arg_min_max.h" +#include "nnacl/kernel/default_kernel_base.h" +#include "nnacl/arg_min_max_parameter.h" +#include "nnacl/nnacl_common.h" +#include "nnacl/fp32/arg_min_max_fp32.h" +#ifdef ENABLE_FP16 +#include "nnacl/fp16/arg_min_max_fp16.h" +#endif + +int ArgMinMaxPrepare(KernelBase *self) { + ArgMinMaxStruct *arg_min_max = (ArgMinMaxStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(arg_min_max); + ArgMinMaxParameter *param = (ArgMinMaxParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + + arg_min_max->arg_elements_alloc_ = param->topk_ > Num1 || param->keep_dims_; + arg_min_max->compute_.topk_ = param->topk_; + arg_min_max->compute_.axis_ = param->axis_; + arg_min_max->compute_.keep_dims_ = param->keep_dims_; + arg_min_max->compute_.out_value_ = param->out_value_; + arg_min_max->compute_.get_max_ = self->param_->type_ == PrimType_ArgMinFusion ? false : true; + return NNACL_OK; +} + +int ArgMinMaxResize(KernelBase *self) { + ArgMinMaxStruct *arg_min_max = (ArgMinMaxStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(arg_min_max); + ArgMinMaxComputeParam *compute = &arg_min_max->compute_; + + TensorC *input_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + ComputeStrides(input_tensor->shape_, compute->in_strides_, input_tensor->shape_size_); + + TensorC *output_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + ComputeStrides(output_tensor->shape_, compute->out_strides_, output_tensor->shape_size_); + + compute->dims_size_ = (int)input_tensor->shape_size_; + compute->axis_ = compute->axis_ < 0 ? compute->axis_ + compute->dims_size_ : compute->axis_; + NNACL_CHECK_FALSE(compute->topk_ <= 0, NNACL_ARG_MIN_MAX_AXIS_INVALID); + NNACL_CHECK_FALSE(compute->topk_ > input_tensor->shape_[compute->axis_], NNACL_ARG_MIN_MAX_AXIS_INVALID); + return NNACL_OK; +} + +int ArgMinMaxCompute(KernelBase *self) { + ArgMinMaxStruct *arg_min_max = (ArgMinMaxStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(arg_min_max); + ArgMinMaxParameter *param = (ArgMinMaxParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + TensorC *in_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in_tensor); + void *in_data = in_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(in_data); + TensorC *out_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(out_tensor); + void *out_data = out_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(out_data); + + void *out_value = NULL; + if (self->out_size_ == TWO_TENSOR) { + out_value = self->out_[Index1]->data_; + NNACL_CHECK_NULL_RETURN_ERR(out_value); + } + + if (arg_min_max->arg_elements_alloc_) { + int arg_size = in_tensor->shape_[arg_min_max->compute_.axis_] * (int)sizeof(ArgElement); + NNACL_CHECK_MALLOC_SIZE(arg_size); + arg_min_max->compute_.arg_elements_ = (ArgElement *)self->env_->Alloc(self->env_->allocator_, arg_size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(arg_min_max->compute_.arg_elements_); + } + + int ret = NNACL_OK; + int *in_shape = in_tensor->shape_; + if (in_tensor->data_type_ == kNumberTypeFloat32) { + ArgMinMaxFp32((float *)in_data, out_data, (float *)out_value, in_shape, &arg_min_max->compute_); +#ifdef ENABLE_FP16 + } else if (in_tensor->data_type_ == kNumberTypeFloat16) { + ArgMinMaxFp16((float16_t *)in_data, out_data, (float16_t *)out_value, in_shape, &arg_min_max->compute_); +#endif + } else if (in_tensor->data_type_ == kNumberTypeInt32) { + ArgMinMaxInt32((int32_t *)in_data, out_data, (int32_t *)out_value, in_shape, &arg_min_max->compute_); + } else { + ret = NNACL_UNSUPPORTED_DATA_TYPE; + } + + if (arg_min_max->arg_elements_alloc_) { + self->env_->Free(self->env_->allocator_, arg_min_max->compute_.arg_elements_); + arg_min_max->compute_.arg_elements_ = NULL; + } + return ret; +} + +KernelBase *CreateArgMinMax(OpParameter *param, int data_type) { + ArgMinMaxStruct *arg_min_max = (ArgMinMaxStruct *)malloc(sizeof(ArgMinMaxStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(arg_min_max); + memset(arg_min_max, 0, sizeof(ArgMinMaxStruct)); + + arg_min_max->base_.Prepare = ArgMinMaxPrepare; + arg_min_max->base_.Resize = ArgMinMaxResize; + arg_min_max->base_.Release = DefaultRelease; + arg_min_max->base_.Compute = ArgMinMaxCompute; + return (KernelBase *)arg_min_max; +} + +REG_KERNEL_CREATOR(PrimType_ArgMinFusion, kNumberTypeInt32, CreateArgMinMax) +REG_KERNEL_CREATOR(PrimType_ArgMinFusion, kNumberTypeFloat16, CreateArgMinMax) +REG_KERNEL_CREATOR(PrimType_ArgMinFusion, kNumberTypeFloat32, CreateArgMinMax) + +REG_KERNEL_CREATOR(PrimType_ArgMaxFusion, kNumberTypeInt32, CreateArgMinMax) +REG_KERNEL_CREATOR(PrimType_ArgMaxFusion, kNumberTypeFloat16, CreateArgMinMax) +REG_KERNEL_CREATOR(PrimType_ArgMaxFusion, kNumberTypeFloat32, CreateArgMinMax) diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_less_kernel.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/arg_min_max.h similarity index 37% rename from mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_less_kernel.h rename to mindspore-lite/ops/kernel/cpu/nnacl/kernel/arg_min_max.h index f166a043e87061251f405a10f020903dc4eb6294..755e3d2df10ee399602d8cff2dce0aa3cc321100 100644 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_less_kernel.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/arg_min_max.h @@ -14,31 +14,50 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_NATIVE_LESS_KERNEL_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_NATIVE_LESS_KERNEL_H_ - -#include -#include -#include -#include "extendrt/delegate/ascend_native/ascend_native_base_kernel.h" - -namespace mindspore::kernel { -class AscendNativeLessKernel : public AscendNativeBaseKernel { - public: - AscendNativeLessKernel(const std::vector &inputs, const std::vector &outputs, - InferPrimitive prim, std::shared_ptr ctx, const void *stream, - std::string name) - : AscendNativeBaseKernel(inputs, outputs, prim, ctx, stream, name) {} - - int InferShape() override; - - int Prepare() override; - - int Run() override; - - int ReSize() override; - - private: -}; -} // namespace mindspore::kernel -#endif // MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_NATIVE_LESS_KERNEL_H_ +#ifndef NNACL_KERNEL_ARG_MIN_MAX_H_ +#define NNACL_KERNEL_ARG_MIN_MAX_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" +#ifdef ENABLE_ARM64 +#include +#endif + +typedef struct ArgElement { + uint32_t index_; + union ArgData { + int8_t i8_data_; + int32_t i_data_; + float f_data_; +#ifdef ENABLE_ARM +#if (!SUPPORT_NNIE) || (defined SUPPORT_34XX) + float16_t f16_data_; +#endif +#endif + } data_; +} ArgElement; + +typedef int (*COMPARE_FUNCTION)(const void *a, const void *b); + +typedef struct ArgMinMaxComputeParam { + int32_t axis_; + int32_t dims_size_; + int32_t topk_; + bool get_max_; + bool keep_dims_; + bool out_value_; + int32_t in_strides_[COMM_SHAPE_SIZE]; + int32_t out_strides_[COMM_SHAPE_SIZE]; + ArgElement *arg_elements_; +} ArgMinMaxComputeParam; + +typedef struct ArgMinMaxStruct { + KernelBase base_; + ArgMinMaxComputeParam compute_; + bool arg_elements_alloc_; +} ArgMinMaxStruct; + +KernelBase *CreateArgMinMax(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_ARG_MIN_MAX_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/arithmetic.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/arithmetic.c new file mode 100644 index 0000000000000000000000000000000000000000..d63d484ba299646b874f426eabf6f971adfd44f9 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/arithmetic.c @@ -0,0 +1,698 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either arithmeticress or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/arithmetic.h" +#include "nnacl/op_base.h" +#include "nnacl/nnacl_common.h" +#include "nnacl/fp32/arithmetic_fp32.h" +#include "nnacl/fp32/mul_fp32.h" +#include "nnacl/tensor_c_utils.h" +#ifdef ENABLE_FP16 +#include "nnacl/fp16/arithmetic_fp16.h" +#endif + +void InitArithmeticRunFunction(KernelBase *self) { + ArithmeticStruct *arithmetic = (ArithmeticStruct *)self; + + ArithmeticFuncions fun_table[] = { + {PrimType_MulFusion, ActType_Relu, ElementMulRelu, ElementMulReluInt, NULL, ElementOptMulRelu, ElementOptMulReluInt, + NULL}, + {PrimType_MulFusion, ActType_Relu6, ElementMulRelu6, ElementMulRelu6Int, NULL, ElementOptMulRelu6, + ElementOptMulRelu6Int, NULL}, + {PrimType_MulFusion, ActType_No, ElementMul, ElementMulInt, NULL, ElementOptMul, ElementOptMulInt, NULL}, + {PrimType_AddFusion, ActType_Relu, ElementAddRelu, NULL, NULL, ElementOptAddRelu, NULL, NULL}, + {PrimType_AddFusion, ActType_Relu6, ElementAddRelu6, NULL, NULL, ElementOptAddRelu6, NULL, NULL}, + {PrimType_AddFusion, ActType_No, ElementAdd, ElementAddInt, NULL, ElementOptAdd, ElementOptAddInt, NULL}, + {PrimType_SubFusion, ActType_Relu, ElementSubRelu, NULL, NULL, ElementOptSubRelu, NULL, NULL}, + {PrimType_SubFusion, ActType_Relu6, ElementSubRelu6, NULL, NULL, ElementOptSubRelu6, NULL, NULL}, + {PrimType_SubFusion, ActType_No, ElementSub, ElementSubInt, NULL, ElementOptSub, ElementOptSubInt, NULL}, + {PrimType_DivFusion, ActType_Relu, ElementDivRelu, NULL, NULL, ElementOptDivRelu, NULL, NULL}, + {PrimType_DivFusion, ActType_Relu6, ElementDivRelu6, NULL, NULL, ElementOptDivRelu6, NULL, NULL}, + {PrimType_DivFusion, ActType_No, ElementDiv, NULL, NULL, ElementOptDiv, ElementOptDivInt, NULL}, + {PrimType_RealDiv, ActType_Relu, ElementDivRelu, NULL, NULL, ElementOptDivRelu, NULL, NULL}, + {PrimType_RealDiv, ActType_Relu6, ElementDivRelu6, NULL, NULL, ElementOptDivRelu6, NULL, NULL}, + {PrimType_RealDiv, ActType_No, ElementDiv, NULL, NULL, ElementOptDiv, ElementOptDivInt, NULL}, + {PrimType_LogicalAnd, ActType_No, ElementLogicalAnd, ElementLogicalAndInt, ElementLogicalAndBool, + ElementOptLogicalAnd, ElementOptLogicalAndInt, ElementOptLogicalAndBool}, + {PrimType_LogicalOr, ActType_No, ElementLogicalOr, NULL, ElementLogicalOrBool, NULL, NULL, ElementOptLogicalOrBool}, + {PrimType_Maximum, ActType_No, ElementMaximum, ElementMaximumInt, NULL, ElementOptMaximum, ElementOptMaximumInt, + NULL}, + {PrimType_Minimum, ActType_No, ElementMinimum, ElementMinimumInt, NULL, ElementOptMinimum, ElementOptMinimumInt, + NULL}, + {PrimType_FloorMod, ActType_No, ElementFloorMod, ElementFloorModInt, NULL, ElementOptFloorMod, + ElementOptFloorModInt, NULL}, + {PrimType_FloorDiv, ActType_No, ElementFloorDiv, ElementFloorDivInt, NULL, ElementOptFloorDiv, + ElementOptFloorDivInt, NULL}, + {PrimType_Mod, ActType_No, ElementMod, ElementModInt, NULL, ElementOptMod, ElementOptModInt, NULL}, + {PrimType_SquaredDifference, ActType_No, ElementSquaredDifference, NULL, NULL, ElementOptSquaredDifference, NULL, + NULL}}; + + size_t length = sizeof(fun_table) / sizeof(ArithmeticFuncions); + for (size_t i = 0; i < length; i++) { + if (fun_table[i].primitive_type_ == arithmetic->primitive_type_ && + fun_table[i].activation_type_ == ((ArithmeticParameter *)(arithmetic->base_.param_))->activation_type_) { + arithmetic->functions_ = fun_table[i]; + return; + } + } +} + +int ArithmeticRelease(struct KernelBase *self) { + ArithmeticStruct *arithmetic = (ArithmeticStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(arithmetic); + for (int i = 0; i < TWO_TENSOR; i++) { + if (arithmetic->broadcast_buffer_[i] != NULL) { + self->env_->Free(self->env_->allocator_, arithmetic->broadcast_buffer_[i]); + arithmetic->broadcast_buffer_[i] = NULL; + } + } + + for (int i = 0; i < arithmetic->block_boundary_infos_size_; i++) { + if (arithmetic->block_boundary_infos_[i].a_offset_ != NULL) { + self->env_->Free(self->env_->allocator_, arithmetic->block_boundary_infos_[i].a_offset_); + arithmetic->block_boundary_infos_[i].a_offset_ = NULL; + } + if (arithmetic->block_boundary_infos_[i].b_offset_ != NULL) { + self->env_->Free(self->env_->allocator_, arithmetic->block_boundary_infos_[i].b_offset_); + arithmetic->block_boundary_infos_[i].b_offset_ = NULL; + } + } + arithmetic->block_boundary_infos_size_ = 0; + + if (arithmetic->a_matrix_.batch_post_sum_ != NULL) { + self->env_->Free(self->env_->allocator_, arithmetic->a_matrix_.batch_post_sum_); + arithmetic->a_matrix_.batch_post_sum_ = NULL; + } + + if (arithmetic->b_matrix_.batch_post_sum_ != NULL) { + self->env_->Free(self->env_->allocator_, arithmetic->b_matrix_.batch_post_sum_); + arithmetic->b_matrix_.batch_post_sum_ = NULL; + } + + if (arithmetic->c_matrix_.batch_post_sum_ != NULL) { + self->env_->Free(self->env_->allocator_, arithmetic->c_matrix_.batch_post_sum_); + arithmetic->c_matrix_.batch_post_sum_ = NULL; + } + return NNACL_OK; +} + +void ArithmeticComputeOffset(ArithmeticStruct *arithmetic, int task_id) { + ArithmeticBlockBoundaryInfo *block_info = &arithmetic->block_boundary_infos_[task_id]; + block_info->init_offset_ = true; + + int64_t b_start = block_info->batch_begin_; + int64_t b_end = block_info->batch_end_; + int64_t s_end = block_info->size_end_; + if (s_end != 0) { + ++b_end; + } + int offset_index = 0; + for (; b_start < b_end; ++b_start) { + int64_t delta = b_start; + int64_t a_offset = 0; + int64_t b_offset = 0; + for (int j = 0; j <= arithmetic->batch_tail_dim_; ++j) { + if (j > 0) { + delta = delta % arithmetic->c_matrix_.batch_post_sum_[j]; + } + if (j < arithmetic->batch_tail_dim_) { + a_offset += (delta / arithmetic->c_matrix_.batch_post_sum_[j + 1] * arithmetic->a_matrix_.shape_[j] / + arithmetic->c_matrix_.shape_[j]) * + arithmetic->a_matrix_.batch_post_sum_[j + 1]; + b_offset += (delta / arithmetic->c_matrix_.batch_post_sum_[j + 1] * arithmetic->b_matrix_.shape_[j] / + arithmetic->c_matrix_.shape_[j]) * + arithmetic->b_matrix_.batch_post_sum_[j + 1]; + } else { + a_offset += (delta * arithmetic->a_matrix_.shape_[j] / arithmetic->c_matrix_.shape_[j]); + b_offset += (delta * arithmetic->b_matrix_.shape_[j] / arithmetic->c_matrix_.shape_[j]); + } + } + block_info->a_offset_[offset_index] = a_offset * arithmetic->a_matrix_.inner_size_ * arithmetic->in_data_size_; + block_info->b_offset_[offset_index] = b_offset * arithmetic->b_matrix_.inner_size_ * arithmetic->in_data_size_; + offset_index++; + } +} + +int ArithmeticDoExecute(KernelBase *base, const void *input0, const void *input1, void *output, int64_t size) { + ArithmeticStruct *arithmetic = (ArithmeticStruct *)base; + int data_type = arithmetic->base_.in_[FIRST_INPUT]->data_type_; + NNACL_CHECK_NULL_RETURN_ERR(input0); + NNACL_CHECK_NULL_RETURN_ERR(input1); + + if (data_type == kNumberTypeFloat32) { + if (arithmetic->scalar_opt_) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic->functions_.optimzie_f32_); + return arithmetic->functions_.optimzie_f32_((const float *)input0, (const float *)input1, (float *)output, size, + arithmetic->in_elements_num0_ == 1); + } else { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic->functions_.compute_f32_); + return arithmetic->functions_.compute_f32_((const float *)input0, (const float *)input1, (float *)output, size); + } + } + + if (data_type == kNumberTypeBool) { + if (arithmetic->scalar_opt_) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic->functions_.optimzie_bool_); + return arithmetic->functions_.optimzie_bool_((const bool *)input0, (const bool *)input1, (bool *)output, size, + arithmetic->in_elements_num0_ == 1); + } else { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic->functions_.compute_bool_); + return arithmetic->functions_.compute_bool_((const bool *)input0, (const bool *)input1, (bool *)output, size); + } + } + + if (data_type == kNumberTypeInt32) { + if (arithmetic->scalar_opt_) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic->functions_.optimzie_int_); + return arithmetic->functions_.optimzie_int_((const int *)input0, (const int *)input1, (int *)output, size, + arithmetic->in_elements_num0_ == 1); + } else { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic->functions_.compute_int_); + return arithmetic->functions_.compute_int_((const int *)input0, (const int *)input1, (int *)output, size); + } + } + + return NNACL_UNSUPPORTED_DATA_TYPE; +} + +int ArithmeticRun(void *cdata, int task_id, float l, float r) { + ArithmeticStruct *arithmetic = (ArithmeticStruct *)cdata; + NNACL_CHECK_FALSE(task_id < 0, NNACL_ERR); + NNACL_CHECK_FALSE(task_id >= arithmetic->block_boundary_infos_size_, NNACL_ERR); + + if (arithmetic->block_boundary_infos_[task_id].init_offset_ == false) { + ArithmeticComputeOffset(arithmetic, task_id); + } + + ArithmeticBlockBoundaryInfo *block_info = &arithmetic->block_boundary_infos_[task_id]; + int64_t b_start = block_info->batch_begin_; + int64_t s_start = block_info->size_begin_; + int64_t s_end = block_info->size_end_; + int64_t index_start = 0; + int64_t index_end = block_info->batch_end_ - b_start; + uint8_t *a_ptr = (uint8_t *)(arithmetic->a_matrix_.data_) + block_info->a_offset_[index_start]; + uint8_t *b_ptr = (uint8_t *)(arithmetic->b_matrix_.data_) + block_info->b_offset_[index_start]; + uint8_t *c_ptr = (uint8_t *)(arithmetic->c_matrix_.data_) + + (b_start * arithmetic->c_matrix_.inner_size_ + s_start) * arithmetic->out_data_size_; + if (arithmetic->a_matrix_.inner_size_ > 1) { + a_ptr += s_start * arithmetic->in_data_size_; + } + if (arithmetic->b_matrix_.inner_size_ > 1) { + b_ptr += s_start * arithmetic->in_data_size_; + } + + if (index_start == index_end) { + return arithmetic->execute_((KernelBase *)arithmetic, a_ptr, b_ptr, c_ptr, s_end - s_start); + } + + int64_t size = arithmetic->c_matrix_.inner_size_ - s_start; + int ret = arithmetic->execute_((KernelBase *)arithmetic, a_ptr, b_ptr, c_ptr, size); + if (ret != NNACL_OK) { + return ret; + } + + ++index_start; + c_ptr += size * arithmetic->out_data_size_; + int64_t c_stride = arithmetic->c_matrix_.inner_size_ * arithmetic->out_data_size_; + for (; index_start < index_end; ++index_start) { + a_ptr = (uint8_t *)(arithmetic->a_matrix_.data_) + block_info->a_offset_[index_start]; + b_ptr = (uint8_t *)(arithmetic->b_matrix_.data_) + block_info->b_offset_[index_start]; + ret = arithmetic->execute_((KernelBase *)arithmetic, a_ptr, b_ptr, c_ptr, arithmetic->c_matrix_.inner_size_); + if (ret != NNACL_OK) { + return ret; + } + c_ptr += c_stride; + } + if (s_end == 0) { + return NNACL_OK; + } + a_ptr = (uint8_t *)(arithmetic->a_matrix_.data_) + block_info->a_offset_[index_start]; + b_ptr = (uint8_t *)(arithmetic->b_matrix_.data_) + block_info->b_offset_[index_start]; + return arithmetic->execute_((KernelBase *)arithmetic, a_ptr, b_ptr, c_ptr, s_end); +} + +void ResetArithmeticMatric(KernelBase *base, ArithmeticMatrixInfo *matrix) { + matrix->is_valid_ = false; + matrix->data_ = NULL; + matrix->inner_size_ = 1; + matrix->shape_size_ = 0; + + if (matrix->batch_post_sum_ != NULL) { + base->env_->Free(base->env_->allocator_, matrix->batch_post_sum_); + matrix->batch_post_sum_ = NULL; + } +} + +int UpdateArithmeticParameter(ArithmeticStruct *arithmetic) { + NNACL_CHECK_TRUE_RET(arithmetic->a_matrix_.shape_size_ == arithmetic->b_matrix_.shape_size_, + NNACL_ARITHMETIC_SHAPE_INVALID); + + arithmetic->ndim_ = arithmetic->a_matrix_.shape_size_; + ResetArithmeticMatric(&arithmetic->base_, &arithmetic->c_matrix_); + + for (size_t i = 0; i < arithmetic->ndim_; ++i) { + NNACL_CHECK_TRUE_RET(arithmetic->a_matrix_.shape_[i] <= INT_MAX, NNACL_ARITHMETIC_SHAPE_INVALID); + NNACL_CHECK_TRUE_RET(arithmetic->b_matrix_.shape_[i] <= INT_MAX, NNACL_ARITHMETIC_SHAPE_INVALID); + arithmetic->in_shape0_[i] = arithmetic->a_matrix_.shape_[i]; + arithmetic->in_shape1_[i] = arithmetic->b_matrix_.shape_[i]; + arithmetic->out_shape_[i] = MSMAX(arithmetic->in_shape0_[i], arithmetic->in_shape1_[i]); + arithmetic->c_matrix_.shape_[arithmetic->c_matrix_.shape_size_++] = + MSMAX(arithmetic->a_matrix_.shape_[i], arithmetic->b_matrix_.shape_[i]); + } + return NNACL_OK; +} + +int OptimizeArithmeticShape(ArithmeticStruct *arithmetic) { + ArithmeticMatrixInfo *a = &arithmetic->a_matrix_; + ArithmeticMatrixInfo *b = &arithmetic->b_matrix_; + arithmetic->ndim_ = a->shape_size_ >= b->shape_size_ ? a->shape_size_ : b->shape_size_; + + int shape0[MAX_LEN] = {0}; + int shape1[MAX_LEN] = {0}; + /* init a & b shape */ + int i = 0; + for (; i < arithmetic->ndim_; ++i) { + shape0[i] = 1; + shape1[i] = 1; + } + + /* init matrix shape dim */ + int a_matrix_size = arithmetic->ndim_ - a->shape_size_; + for (i = a_matrix_size; i < arithmetic->ndim_; i++) { + shape0[i] = a->shape_[i - a_matrix_size]; + } + + int b_matrix_size = arithmetic->ndim_ - b->shape_size_; + for (i = b_matrix_size; i < arithmetic->ndim_; i++) { + shape1[i] = b->shape_[i - b_matrix_size]; + } + + /* horizontal shape dims */ + int shape0_temp[MAX_LEN] = {0}; + int shape1_temp[MAX_LEN] = {0}; + int shape_temp_size = 0; + for (i = 0; i < arithmetic->ndim_;) { // horizontal comparison, merge the part of continuous 1. + shape0_temp[shape_temp_size] = shape0[i]; + shape1_temp[shape_temp_size] = shape1[i]; + shape_temp_size++; + if (shape0[i] != 1 && shape1[i] != 1) { + ++i; + continue; + } + + size_t j0 = i; + while (j0 < arithmetic->ndim_ && shape0[j0] == 1) { + ++j0; + } + size_t j1 = i; + while (j1 < arithmetic->ndim_ && shape1[j1] == 1) { + ++j1; + } + size_t j = MSMAX(j0, j1); + while ((++i) < j) { + shape0_temp[shape_temp_size - 1] *= shape0[i]; + shape1_temp[shape_temp_size - 1] *= shape1[i]; + } + } + + arithmetic->a_matrix_.shape_size_ = 0; + arithmetic->b_matrix_.shape_size_ = 0; + + for (i = 0; i < shape_temp_size;) { // vertical comparison, merge the part of continuous equation. + if (shape0_temp[i] == 1 && shape1_temp[i] == 1) { + ++i; + continue; + } + shape0[arithmetic->a_matrix_.shape_size_++] = shape0_temp[i]; + shape1[arithmetic->b_matrix_.shape_size_++] = shape1_temp[i]; + if (shape0_temp[i] != shape1_temp[i]) { + ++i; + continue; + } + while ((++i) < shape_temp_size) { + if (shape0_temp[i] != shape1_temp[i]) { + break; + } + shape0[arithmetic->a_matrix_.shape_size_ - 1] *= shape0_temp[i]; + shape1[arithmetic->b_matrix_.shape_size_ - 1] *= shape1_temp[i]; + } + } + + memcpy(arithmetic->a_matrix_.shape_, shape0, arithmetic->a_matrix_.shape_size_ * sizeof(int)); + memcpy(arithmetic->b_matrix_.shape_, shape1, arithmetic->b_matrix_.shape_size_ * sizeof(int)); + + return UpdateArithmeticParameter(arithmetic); +} + +int ResetArithmeticStatus(ArithmeticStruct *arithmetic) { + ResetArithmeticMatric(&arithmetic->base_, &arithmetic->a_matrix_); + ResetArithmeticMatric(&arithmetic->base_, &arithmetic->b_matrix_); + ResetArithmeticMatric(&arithmetic->base_, &arithmetic->c_matrix_); + + arithmetic->a_matrix_.shape_size_ = arithmetic->base_.in_[FIRST_INPUT]->shape_size_; + memcpy(arithmetic->a_matrix_.shape_, arithmetic->base_.in_[FIRST_INPUT]->shape_, + arithmetic->a_matrix_.shape_size_ * sizeof(int)); + arithmetic->b_matrix_.shape_size_ = arithmetic->base_.in_[SECOND_INPUT]->shape_size_; + memcpy(arithmetic->b_matrix_.shape_, arithmetic->base_.in_[SECOND_INPUT]->shape_, + arithmetic->b_matrix_.shape_size_ * sizeof(int)); + + return OptimizeArithmeticShape(arithmetic); +} + +void ArithmeticDoBroadcast(ArithmeticStruct *arithmetic, void *in_data, void *out_data, int input_index) { + int *in_shape = input_index == FIRST_INPUT ? arithmetic->in_shape0_ : arithmetic->in_shape1_; + int *in_stride = input_index == FIRST_INPUT ? arithmetic->in_strides0_ : arithmetic->in_strides1_; + int *multiples = input_index == FIRST_INPUT ? arithmetic->multiples0_ : arithmetic->multiples1_; + return arithmetic->tile_function_(in_data, out_data, 0, arithmetic->ndim_, in_shape, in_stride, + arithmetic->out_strides_, multiples); +} + +int ArithmeticBroadCastConstTensor(ArithmeticStruct *arithmetic) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic); + + CalcStructMultiplesAndStrides(arithmetic); + +#ifdef MSLITE_ENABLE_CLOUD_INFERENCE + bool prefer_explicit_broadcast = false; +#else + bool prefer_explicit_broadcast = arithmetic->ndim_ != 1; +#endif + prefer_explicit_broadcast = + prefer_explicit_broadcast && (arithmetic->base_.in_[FIRST_INPUT]->data_type_ != kNumberTypeBool); + + bool exist_broadcast_ = false; + int buffer_size = NNACLGetElementNum(arithmetic->base_.out_[OUTPUT_INDEX]) * arithmetic->in_data_size_; + if (arithmetic->a_matrix_.is_const_) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic->base_.in_[FIRST_INPUT]->data_); + if (arithmetic->in_elements_num0_ != arithmetic->out_elements_num_ && prefer_explicit_broadcast) { + exist_broadcast_ = true; + + arithmetic->a_matrix_.data_ = arithmetic->base_.env_->Alloc(arithmetic->base_.env_->allocator_, buffer_size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(arithmetic->a_matrix_.data_); + arithmetic->broadcast_buffer_[Index0] = arithmetic->a_matrix_.data_; + + ArithmeticDoBroadcast(arithmetic, arithmetic->base_.in_[FIRST_INPUT]->data_, arithmetic->a_matrix_.data_, Index0); + arithmetic->in_elements_num0_ = arithmetic->out_elements_num_; + + // shape must be equal to out + for (size_t i = 0; i < arithmetic->ndim_; ++i) { + arithmetic->in_shape0_[i] = arithmetic->out_shape_[i]; + arithmetic->in_strides0_[i] = arithmetic->out_strides_[i]; + } + memcpy(arithmetic->a_matrix_.shape_, arithmetic->c_matrix_.shape_, arithmetic->ndim_ * sizeof(int)); + arithmetic->a_matrix_.is_valid_ = true; + } + } + + if (arithmetic->b_matrix_.is_const_) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic->base_.in_[SECOND_INPUT]->data_); + if (arithmetic->in_elements_num1_ != arithmetic->out_elements_num_ && prefer_explicit_broadcast) { + exist_broadcast_ = true; + + arithmetic->b_matrix_.data_ = arithmetic->base_.env_->Alloc(arithmetic->base_.env_->allocator_, buffer_size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(arithmetic->b_matrix_.data_); + arithmetic->broadcast_buffer_[Index1] = arithmetic->b_matrix_.data_; + + ArithmeticDoBroadcast(arithmetic, arithmetic->base_.in_[Index1]->data_, arithmetic->b_matrix_.data_, Index1); + arithmetic->in_elements_num1_ = arithmetic->out_elements_num_; + // shape must be equal to out + for (size_t i = 0; i < arithmetic->ndim_; ++i) { + arithmetic->in_shape1_[i] = arithmetic->out_shape_[i]; + arithmetic->in_strides1_[i] = arithmetic->out_strides_[i]; + } + + memcpy(arithmetic->b_matrix_.shape_, arithmetic->c_matrix_.shape_, arithmetic->ndim_ * sizeof(int)); + arithmetic->b_matrix_.is_valid_ = true; + } + } + if (!exist_broadcast_) { + return NNACL_OK; + } + return OptimizeArithmeticShape(arithmetic); +} + +int ArithmeticComputeOfflineInfo(ArithmeticStruct *arithmetic) { + int bread_pos = -1; + int last_dim = arithmetic->a_matrix_.shape_size_ - 1; + for (int i = last_dim; i >= 0; --i) { + if (arithmetic->a_matrix_.shape_[i] != arithmetic->b_matrix_.shape_[i]) { + bread_pos = i; + break; + } + } + arithmetic->batch_tail_dim_ = bread_pos; + if (bread_pos == last_dim && arithmetic->batch_tail_dim_ >= 0) { + --arithmetic->batch_tail_dim_; + } + + for (int i = last_dim; i > arithmetic->batch_tail_dim_; --i) { + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(arithmetic->a_matrix_.inner_size_, arithmetic->a_matrix_.shape_[i], NNACL_ERR); + arithmetic->a_matrix_.inner_size_ *= arithmetic->a_matrix_.shape_[i]; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(arithmetic->b_matrix_.inner_size_, arithmetic->b_matrix_.shape_[i], NNACL_ERR); + arithmetic->b_matrix_.inner_size_ *= arithmetic->b_matrix_.shape_[i]; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(arithmetic->c_matrix_.inner_size_, arithmetic->c_matrix_.shape_[i], NNACL_ERR); + arithmetic->c_matrix_.inner_size_ *= arithmetic->c_matrix_.shape_[i]; + } + + arithmetic->a_matrix_.batch_post_sum_ = arithmetic->base_.env_->Alloc( + arithmetic->base_.env_->allocator_, (arithmetic->a_matrix_.shape_size_ + 1) * sizeof(int)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(arithmetic->a_matrix_.batch_post_sum_); + for (int i = 0; i < arithmetic->a_matrix_.shape_size_ + 1; i++) { + arithmetic->a_matrix_.batch_post_sum_[i] = 1; + } + + arithmetic->b_matrix_.batch_post_sum_ = arithmetic->base_.env_->Alloc( + arithmetic->base_.env_->allocator_, (arithmetic->b_matrix_.shape_size_ + 1) * sizeof(int)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(arithmetic->b_matrix_.batch_post_sum_); + for (int i = 0; i < arithmetic->b_matrix_.shape_size_ + 1; i++) { + arithmetic->b_matrix_.batch_post_sum_[i] = 1; + } + + arithmetic->c_matrix_.batch_post_sum_ = arithmetic->base_.env_->Alloc( + arithmetic->base_.env_->allocator_, (arithmetic->c_matrix_.shape_size_ + 1) * sizeof(int)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(arithmetic->c_matrix_.batch_post_sum_); + for (int i = 0; i < arithmetic->c_matrix_.shape_size_ + 1; i++) { + arithmetic->c_matrix_.batch_post_sum_[i] = 1; + } + + for (int i = arithmetic->batch_tail_dim_; i >= 0; --i) { + if (i == arithmetic->batch_tail_dim_) { + arithmetic->a_matrix_.batch_post_sum_[i] = arithmetic->a_matrix_.shape_[i]; + arithmetic->b_matrix_.batch_post_sum_[i] = arithmetic->b_matrix_.shape_[i]; + arithmetic->c_matrix_.batch_post_sum_[i] = arithmetic->c_matrix_.shape_[i]; + } else { + arithmetic->a_matrix_.batch_post_sum_[i] = + arithmetic->a_matrix_.shape_[i] * arithmetic->a_matrix_.batch_post_sum_[i + 1]; + arithmetic->b_matrix_.batch_post_sum_[i] = + arithmetic->b_matrix_.shape_[i] * arithmetic->b_matrix_.batch_post_sum_[i + 1]; + arithmetic->c_matrix_.batch_post_sum_[i] = + arithmetic->c_matrix_.shape_[i] * arithmetic->c_matrix_.batch_post_sum_[i + 1]; + } + } + + arithmetic->scalar_opt_ = false; + if (arithmetic->a_matrix_.inner_size_ == 1) { + arithmetic->in_elements_num0_ = 1; + arithmetic->scalar_opt_ = true; + } + if (arithmetic->b_matrix_.inner_size_ == 1) { + arithmetic->in_elements_num1_ = 1; + arithmetic->scalar_opt_ = true; + } + return NNACL_OK; +} + +int ArithmeticChooseThreadCuttingStrategy(ArithmeticStruct *arithmetic) { + int total_num = NNACLGetElementNum(arithmetic->base_.out_[OUTPUT_INDEX]); + arithmetic->base_.thread_nr_ = + arithmetic->base_.UpdateThread(TC_TYPE(arithmetic->primitive_type_, arithmetic->functions_.activation_type_), 1, 1, + total_num, arithmetic->base_.thread_nr_); + + int64_t block_size = UP_DIV(total_num, arithmetic->base_.thread_nr_); + int64_t split_point = 0; + while (split_point < total_num) { + int64_t start = split_point; + int64_t end = start + block_size; + if (end > total_num) { + end = total_num; + } + ArithmeticBlockBoundaryInfo block_boundary_info; + block_boundary_info.size_begin_ = start % arithmetic->c_matrix_.inner_size_; + block_boundary_info.size_end_ = end % arithmetic->c_matrix_.inner_size_; + block_boundary_info.batch_begin_ = start / arithmetic->c_matrix_.inner_size_; + block_boundary_info.batch_end_ = end / arithmetic->c_matrix_.inner_size_; + block_boundary_info.init_offset_ = false; + + int max_offset_size = block_boundary_info.batch_end_ - block_boundary_info.batch_begin_ + TWO_TENSOR; + block_boundary_info.a_offset_ = + (int *)arithmetic->base_.env_->Alloc(arithmetic->base_.env_->allocator_, max_offset_size * sizeof(int)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(block_boundary_info.a_offset_); + block_boundary_info.b_offset_ = + (int *)arithmetic->base_.env_->Alloc(arithmetic->base_.env_->allocator_, max_offset_size * sizeof(int)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(block_boundary_info.b_offset_); + + arithmetic->block_boundary_infos_[arithmetic->block_boundary_infos_size_++] = block_boundary_info; + split_point = end; + } + + arithmetic->base_.thread_nr_ = arithmetic->block_boundary_infos_size_; + return NNACL_OK; +} + +int ArithmeticResize(struct KernelBase *self) { + ArithmeticStruct *arithmetic = (ArithmeticStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(arithmetic); + + ArithmeticRelease(&arithmetic->base_); + + NNACL_CHECK_TRUE_RET(arithmetic->in_data_size_ != 0, NNACL_UNSUPPORTED_DATA_TYPE); + NNACL_CHECK_TRUE_RET(arithmetic->out_data_size_ != 0, NNACL_UNSUPPORTED_DATA_TYPE); + arithmetic->in_elements_num0_ = NNACLGetElementNum(self->in_[FIRST_INPUT]); + arithmetic->in_elements_num1_ = NNACLGetElementNum(self->in_[SECOND_INPUT]); + arithmetic->out_elements_num_ = NNACLGetElementNum(self->in_[OUTPUT_INDEX]); + + int ret = ResetArithmeticStatus(arithmetic); + if (ret != NNACL_OK) { + return ret; + } + + ret = ArithmeticBroadCastConstTensor(arithmetic); + if (ret != NNACL_OK) { + return ret; + } + + ret = ArithmeticComputeOfflineInfo(arithmetic); + if (ret != NNACL_OK) { + return ret; + } + + return ArithmeticChooseThreadCuttingStrategy(arithmetic); +} + +int ArithmeticPrepare(struct KernelBase *self) { + ArithmeticStruct *arithmetic = (ArithmeticStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(arithmetic); + + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR); + + NNACL_CHECK_FALSE(self->in_[FIRST_INPUT]->data_type_ < kNumberTypeBegin, NNACL_ERR); + NNACL_CHECK_FALSE(self->in_[FIRST_INPUT]->data_type_ > kNumberTypeEnd, NNACL_ERR); + NNACL_CHECK_FALSE(self->in_[SECOND_INPUT]->data_type_ > kNumberTypeEnd, NNACL_ERR); + NNACL_CHECK_FALSE(self->in_[SECOND_INPUT]->data_type_ > kNumberTypeEnd, NNACL_ERR); + + if (self->param_->quant_type_ != Quant_None) { + return NNACL_ERR; + } + + arithmetic->primitive_type_ = self->param_->type_; + if (self->param_->type_ == PrimType_Eltwise) { + switch (((ArithmeticParameter *)(self->param_))->eltwise_mode_) { + case Eltwise_PROD: + arithmetic->primitive_type_ = PrimType_MulFusion; + break; + case Eltwise_SUM: + arithmetic->primitive_type_ = PrimType_AddFusion; + break; + case Eltwise_MAXIMUM: + arithmetic->primitive_type_ = PrimType_Maximum; + break; + default: + return NNACL_ELTWISE_INVALID_MOD; + } + } + arithmetic->init_function_(self); + + arithmetic->a_matrix_.is_const_ = NNACLIsConst(self->in_[FIRST_INPUT]); + arithmetic->b_matrix_.is_const_ = NNACLIsConst(self->in_[SECOND_INPUT]); + return NNACL_OK; +} + +int ArithmeticCompute(struct KernelBase *self) { + ArithmeticStruct *arithmetic = (ArithmeticStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(arithmetic); + NNACL_CHECK_FALSE(self->in_[FIRST_INPUT]->data_type_ != self->in_[SECOND_INPUT]->data_type_, + NNACL_ARITHMETIC_DATA_TYPE_UNMATCH); + + if (self->train_session_) { + arithmetic->in_data_size_ = DataTypeCSize(self->in_[FIRST_INPUT]->data_type_); + } + + if (false == arithmetic->a_matrix_.is_valid_) { + arithmetic->a_matrix_.data_ = self->in_[FIRST_INPUT]->data_; + } + NNACL_CHECK_NULL_RETURN_ERR(arithmetic->a_matrix_.data_); + + if (false == arithmetic->b_matrix_.is_valid_) { + arithmetic->b_matrix_.data_ = self->in_[SECOND_INPUT]->data_; + } + NNACL_CHECK_NULL_RETURN_ERR(arithmetic->b_matrix_.data_); + + arithmetic->c_matrix_.data_ = self->out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(arithmetic->c_matrix_.data_); + + return self->env_->ParallelLaunch(self->env_->thread_pool_, ArithmeticRun, self, self->thread_nr_); +} + +KernelBase *CreateArithmetic(OpParameter *param, int data_type) { + ArithmeticStruct *arithmetic = (ArithmeticStruct *)malloc(sizeof(ArithmeticStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(arithmetic); + memset(arithmetic, 0, sizeof(ArithmeticStruct)); + arithmetic->in_data_size_ = DataTypeCSize(data_type); + arithmetic->out_data_size_ = DataTypeCSize(data_type); + arithmetic->block_boundary_infos_size_ = 0; + arithmetic->a_matrix_.batch_post_sum_ = NULL; + arithmetic->b_matrix_.batch_post_sum_ = NULL; + arithmetic->c_matrix_.batch_post_sum_ = NULL; + arithmetic->broadcast_buffer_[FIRST_INPUT] = NULL; + arithmetic->broadcast_buffer_[SECOND_INPUT] = NULL; + arithmetic->tile_function_ = TileOneDimensionFp32; + arithmetic->init_function_ = InitArithmeticRunFunction; + arithmetic->execute_ = ArithmeticDoExecute; + arithmetic->base_.Prepare = ArithmeticPrepare; + arithmetic->base_.Resize = ArithmeticResize; + arithmetic->base_.Release = ArithmeticRelease; + arithmetic->base_.Compute = ArithmeticCompute; + return (KernelBase *)arithmetic; +} + +REG_KERNEL_CREATOR(PrimType_MulFusion, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_MulFusion, kNumberTypeInt32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_AddFusion, kNumberTypeBool, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_AddFusion, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_AddFusion, kNumberTypeInt32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_SubFusion, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_SubFusion, kNumberTypeInt32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_DivFusion, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_RealDiv, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_Mod, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_Mod, kNumberTypeInt32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_LogicalAnd, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_LogicalAnd, kNumberTypeBool, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_LogicalAnd, kNumberTypeInt32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_LogicalOr, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_LogicalOr, kNumberTypeBool, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_Maximum, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_Minimum, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_Maximum, kNumberTypeInt32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_Minimum, kNumberTypeInt32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_FloorDiv, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_FloorMod, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_FloorDiv, kNumberTypeInt32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_FloorMod, kNumberTypeInt32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_SquaredDifference, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_Eltwise, kNumberTypeFloat32, CreateArithmetic) +REG_KERNEL_CREATOR(PrimType_DivFusion, kNumberTypeInt32, CreateArithmetic) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/arithmetic.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/arithmetic.h new file mode 100644 index 0000000000000000000000000000000000000000..82e03d41497d52da47a54d7b798f2ee4e040b386 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/arithmetic.h @@ -0,0 +1,97 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_ARITHMETIC_H_ +#define NNACL_KERNEL_ARITHMETIC_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" +#include "nnacl/arithmetic_parameter.h" + +typedef struct ArithmeticFuncions { + int primitive_type_; + int activation_type_; + int (*compute_f32_)(const float *in1, const float *in2, float *out, int ele); + int (*compute_int_)(const int *in1, const int *in2, int *out, int ele); + int (*compute_bool_)(const bool *in1, const bool *in2, bool *out, int ele); + int (*optimzie_f32_)(const float *in1, const float *in2, float *out, int ele, bool scalar); + int (*optimzie_int_)(const int *in1, const int *in2, int *out, int ele, bool scalar); + int (*optimzie_bool_)(const bool *in1, const bool *in2, bool *out, int ele, bool scalar); +} ArithmeticFuncions; + +typedef struct ArithmeticMatrixInfo { + bool is_const_; + bool is_valid_; + void *data_; + int64_t inner_size_; + int shape_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int shape_size_; + int *batch_post_sum_; /* shape size + 1 */ +} ArithmeticMatrixInfo; + +typedef struct ArithmeticBlockBoundaryInfo { + int batch_begin_; + int batch_end_; + int size_begin_; // start-offset under the begin batch + int size_end_; // end-num under the ending batch + int *a_offset_; + int *b_offset_; + bool init_offset_; +} ArithmeticBlockBoundaryInfo; + +typedef struct ArithmeticStruct { + KernelBase base_; + bool scalar_opt_; + int primitive_type_; + int ndim_; + int in_data_size_; + int out_data_size_; + int batch_tail_dim_; + + ArithmeticMatrixInfo a_matrix_; + ArithmeticMatrixInfo b_matrix_; + ArithmeticMatrixInfo c_matrix_; + ArithmeticFuncions functions_; + + void *broadcast_buffer_[TWO_TENSOR]; + int block_boundary_infos_size_; + ArithmeticBlockBoundaryInfo block_boundary_infos_[MAX_THREAD_NUM]; + + int in_shape0_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int in_elements_num0_; + int in_shape1_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int in_elements_num1_; + int out_shape_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int out_elements_num_; + int in_strides0_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int in_strides1_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int out_strides_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int multiples0_[ARITHMETIC_SUPPORT_DIMS_NUM]; + int multiples1_[ARITHMETIC_SUPPORT_DIMS_NUM]; + + void (*tile_function_)(const void *inPtr, void *outPtr, int dim, size_t ndim, const int *inShape, + const int *inStrides, const int *outStrides, const int *multiple); + int (*execute_)(KernelBase *base, const void *input0, const void *input1, void *output, int64_t size); + void (*init_function_)(KernelBase *base); +} ArithmeticStruct; + +KernelBase *CreateArithmetic(OpParameter *param, int data_type); +int ArithmeticPrepare(struct KernelBase *self); +int ArithmeticRelease(struct KernelBase *self); +int ArithmeticCompute(struct KernelBase *self); +int ArithmeticResize(struct KernelBase *self); + +#endif // NNACL_KERNEL_ARITHMETIC_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/arithmetic_compare.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/arithmetic_compare.c new file mode 100644 index 0000000000000000000000000000000000000000..4cb3daaa06cf1b045120e4fa1575f2bee2553724 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/arithmetic_compare.c @@ -0,0 +1,166 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/arithmetic_compare.h" +#include "nnacl/kernel/arithmetic.h" +#include "nnacl/fp32/arithmetic_fp32.h" +#include "nnacl/fp32/arithmetic_compare_fp32.h" + +typedef struct ArithmeticCompareFuncions { + int primitive_type_; + int (*compute_f32_)(const float *input0, const float *input1, uint8_t *output, int element_size); + int (*compute_i32_)(const int *input0, const int *input1, uint8_t *output, int element_size); + int (*optimize_f32)(const float *input0, const float *input1, uint8_t *output, int element_size, bool first_scalar); + int (*optimize_i32)(const int *input0, const int *input1, uint8_t *output, int element_size, bool first_scalar); + int (*compute_i64)(const int64_t *input0, const int64_t *input1, uint8_t *output, int element_size); + int (*optimize_i64)(const int64_t *input0, const int64_t *input1, uint8_t *output, int element_size, + bool first_scalar); + int (*compute_bool)(const bool *input0, const bool *input1, uint8_t *output, int element_size); +} ArithmeticCompareFuncions; + +typedef struct ArithmeticCompareStruct { + ArithmeticStruct arithmetic_; + ArithmeticCompareFuncions functions_; +} ArithmeticCompareStruct; + +void InitArithmeticCompareRunFunction(KernelBase *self) { + ArithmeticCompareStruct *arithmetic_compare = (ArithmeticCompareStruct *)self; + NNACL_CHECK_NULL_RETURN_VOID(arithmetic_compare); + + ArithmeticCompareFuncions fun_table[] = { + {PrimType_Equal, ElementEqualFp32, ElementEqualInt32, ElementOptEqualFp32, ElementOptEqualInt32, NULL, NULL, + ElementEqualBool}, + {PrimType_NotEqual, ElementNotEqualFp32, ElementNotEqualInt32, ElementOptNotEqualFp32, ElementOptNotEqualInt32, + ElementNotEqualInt64, ElementOptNotEqualInt64, NULL}, + {PrimType_Less, ElementLessFp32, ElementLessInt32, ElementOptLessFp32, ElementOptLessInt32, NULL, NULL, NULL}, + {PrimType_LessEqual, ElementLessEqualFp32, ElementLessEqualInt32, ElementOptLessEqualFp32, ElementOptLessEqualInt32, + NULL, NULL, NULL}, + {PrimType_Greater, ElementGreaterFp32, ElementGreaterInt32, ElementOptGreaterFp32, ElementOptGreaterInt32, NULL, + NULL, NULL}, + {PrimType_GreaterEqual, ElementGreaterEqualFp32, ElementGreaterEqualInt32, ElementOptGreaterEqualFp32, + ElementOptGreaterEqualInt32, NULL, NULL, NULL}}; + + size_t length = sizeof(fun_table) / sizeof(ArithmeticCompareFuncions); + for (size_t i = 0; i < length; i++) { + if (fun_table[i].primitive_type_ == arithmetic_compare->arithmetic_.primitive_type_) { + arithmetic_compare->functions_ = fun_table[i]; + return; + } + } +} + +int ArithmeticCompareExecute(KernelBase *base, const void *input0, const void *input1, void *output, int64_t size) { + ArithmeticCompareStruct *arithmetic_compare = (ArithmeticCompareStruct *)base; + NNACL_CHECK_NULL_RETURN_ERR(input0); + NNACL_CHECK_NULL_RETURN_ERR(input1); + + int data_type = base->in_[FIRST_INPUT]->data_type_; + bool first_scalar = arithmetic_compare->arithmetic_.in_elements_num0_ == 1; + + if (data_type == kNumberTypeFloat32) { + if (arithmetic_compare->arithmetic_.scalar_opt_) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_compare->functions_.optimize_f32); + return arithmetic_compare->functions_.optimize_f32((const float *)input0, (const float *)input1, + (uint8_t *)output, size, first_scalar); + } else { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_compare->functions_.compute_f32_); + return arithmetic_compare->functions_.compute_f32_((const float *)input0, (const float *)input1, + (uint8_t *)output, size); + } + } + + if (data_type == kNumberTypeInt || data_type == kNumberTypeInt32) { + if (arithmetic_compare->arithmetic_.scalar_opt_) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_compare->functions_.optimize_i32); + return arithmetic_compare->functions_.optimize_i32((const int *)input0, (const int *)input1, (uint8_t *)output, + size, first_scalar); + } else { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_compare->functions_.compute_i32_); + return arithmetic_compare->functions_.compute_i32_((const int *)input0, (const int *)input1, (uint8_t *)output, + size); + } + } + + if (data_type == kNumberTypeInt64) { + if (arithmetic_compare->arithmetic_.scalar_opt_) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_compare->functions_.optimize_i64); + return arithmetic_compare->functions_.optimize_i64((const int64_t *)input0, (const int64_t *)input1, + (uint8_t *)output, size, first_scalar); + } else { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_compare->functions_.compute_i64); + return arithmetic_compare->functions_.compute_i64((const int64_t *)input0, (const int64_t *)input1, + (uint8_t *)output, size); + } + } + if (data_type == kNumberTypeBool) { + if (!arithmetic_compare->arithmetic_.scalar_opt_) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_compare->functions_.compute_bool); + return arithmetic_compare->functions_.compute_bool((const bool *)input0, (const bool *)input1, (uint8_t *)output, + size); + } else { + return NNACL_UNSUPPORTED_DATA_TYPE; + } + } + + return NNACL_UNSUPPORTED_DATA_TYPE; +} + +int ArithmeticCompareResize(KernelBase *self) { + ArithmeticStruct *arithmetic = (ArithmeticStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(arithmetic); + arithmetic->in_data_size_ = DataTypeCSize(self->in_[FIRST_INPUT]->data_type_); + arithmetic->out_data_size_ = DataTypeCSize(self->out_[OUTPUT_INDEX]->data_type_); + return ArithmeticResize(self); +} + +KernelBase *CreateArithmeticCompare(OpParameter *param, int data_type) { + ArithmeticCompareStruct *arithmetic_compare = (ArithmeticCompareStruct *)malloc(sizeof(ArithmeticCompareStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(arithmetic_compare); + memset(arithmetic_compare, 0, sizeof(ArithmeticCompareStruct)); + + ArithmeticStruct *arithmetic = (ArithmeticStruct *)arithmetic_compare; + arithmetic->in_data_size_ = DataTypeCSize(data_type); + arithmetic->out_data_size_ = DataTypeCSize(data_type); + arithmetic->block_boundary_infos_size_ = 0; + arithmetic->a_matrix_.batch_post_sum_ = NULL; + arithmetic->b_matrix_.batch_post_sum_ = NULL; + arithmetic->c_matrix_.batch_post_sum_ = NULL; + arithmetic->broadcast_buffer_[FIRST_INPUT] = NULL; + arithmetic->broadcast_buffer_[SECOND_INPUT] = NULL; + arithmetic->tile_function_ = TileOneDimensionFp32; + arithmetic->init_function_ = InitArithmeticCompareRunFunction; + arithmetic->execute_ = ArithmeticCompareExecute; + arithmetic->base_.Prepare = ArithmeticPrepare; + arithmetic->base_.Resize = ArithmeticCompareResize; + arithmetic->base_.Release = ArithmeticRelease; + arithmetic->base_.Compute = ArithmeticCompute; + return (KernelBase *)arithmetic_compare; +} + +REG_KERNEL_CREATOR(PrimType_Equal, kNumberTypeFloat32, CreateArithmeticCompare) +REG_KERNEL_CREATOR(PrimType_Equal, kNumberTypeBool, CreateArithmeticCompare) +REG_KERNEL_CREATOR(PrimType_Equal, kNumberTypeInt32, CreateArithmeticCompare) +REG_KERNEL_CREATOR(PrimType_NotEqual, kNumberTypeFloat32, CreateArithmeticCompare) +REG_KERNEL_CREATOR(PrimType_NotEqual, kNumberTypeInt32, CreateArithmeticCompare) +REG_KERNEL_CREATOR(PrimType_NotEqual, kNumberTypeInt64, CreateArithmeticCompare) +REG_KERNEL_CREATOR(PrimType_Less, kNumberTypeFloat32, CreateArithmeticCompare) +REG_KERNEL_CREATOR(PrimType_Less, kNumberTypeInt32, CreateArithmeticCompare) +REG_KERNEL_CREATOR(PrimType_LessEqual, kNumberTypeFloat32, CreateArithmeticCompare) +REG_KERNEL_CREATOR(PrimType_LessEqual, kNumberTypeInt32, CreateArithmeticCompare) +REG_KERNEL_CREATOR(PrimType_Greater, kNumberTypeFloat32, CreateArithmeticCompare) +REG_KERNEL_CREATOR(PrimType_Greater, kNumberTypeInt32, CreateArithmeticCompare) +REG_KERNEL_CREATOR(PrimType_GreaterEqual, kNumberTypeFloat32, CreateArithmeticCompare) +REG_KERNEL_CREATOR(PrimType_GreaterEqual, kNumberTypeInt32, CreateArithmeticCompare) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/arithmetic_compare.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/arithmetic_compare.h new file mode 100644 index 0000000000000000000000000000000000000000..8cf9dfba7682bd17f93cc155ad7bf0b87df2fd66 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/arithmetic_compare.h @@ -0,0 +1,26 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_ARITHMETIC_COMPARE_H_ +#define NNACL_KERNEL_ARITHMETIC_COMPARE_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +KernelBase *CreateArithmeticCompare(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_ARITHMETIC_COMPARE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/arithmetic_self.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/arithmetic_self.c new file mode 100644 index 0000000000000000000000000000000000000000..a24f6a41ff7eb00f780a22d39d0b008e6d7b6a04 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/arithmetic_self.c @@ -0,0 +1,199 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/arithmetic_self.h" +#include "nnacl/fp32/arithmetic_self_fp32.h" +#include "nnacl/kernel/default_kernel_base.h" +#include "nnacl/tensor_c_utils.h" +#ifdef ENABLE_FP16 +#include "nnacl/fp16/arithmetic_self_fp16.h" +#endif + +void ArithmeticSelfGetArithmeticSelfFunction(ArithmeticSelfStruct *arithmetic_self, int primitive_type) { + ArithmeticSelfFunction type_func_table[] = { + {PrimType_Abs, ElementAbs, NULL, ElementAbsInt, NULL}, + {PrimType_Cos, ElementCos, NULL, NULL, NULL}, + {PrimType_Log, ElementLog, NULL, NULL, NULL}, + {PrimType_Log1p, ElementLog1p, NULL, NULL, NULL}, + {PrimType_Square, ElementSquare, NULL, NULL, NULL}, + {PrimType_Sqrt, ElementSqrt, NULL, NULL, NULL}, + {PrimType_Rsqrt, ElementRsqrt, NULL, NULL, NULL}, + {PrimType_Sin, ElementSin, NULL, NULL, NULL}, + {PrimType_LogicalNot, ElementLogicalNot, ElementLogicalNotBool, NULL, NULL}, + {PrimType_Floor, ElementFloor, NULL, NULL, NULL}, + {PrimType_Ceil, ElementCeil, NULL, NULL, NULL}, + {PrimType_Round, ElementRound, NULL, NULL, NULL}, + {PrimType_Neg, ElementNegative, NULL, ElementNegativeInt, NULL}, + {PrimType_Reciprocal, ElementReciprocal, NULL, NULL, NULL}, + {PrimType_Erf, ElementErf, NULL, NULL, NULL}, + {PrimType_IsFinite, NULL, NULL, NULL, ElementIsFinite}}; + for (size_t i = 0; i < sizeof(type_func_table) / sizeof(ArithmeticSelfFunction); i++) { + if (type_func_table[i].primitive_type_ == primitive_type) { + arithmetic_self->function_ = type_func_table[i]; + return; + } + } +} + +void ArithmeticSelfGetArithmeticSelfF16Function(ArithmeticSelfStruct *arithmetic_self, int primitive_type) { +#ifdef ENABLE_FP16 + ArithmeticSelfF16Function type_func_table[] = {{PrimType_Abs, ElementAbsFp16}, + {PrimType_Cos, ElementCosFp16}, + {PrimType_Log, ElementLogFp16}, + {PrimType_Square, ElementSquareFp16}, + {PrimType_Sqrt, ElementSqrtFp16}, + {PrimType_Rsqrt, ElementRsqrtFp16}, + {PrimType_Sin, ElementSinFp16}, + {PrimType_LogicalNot, ElementLogicalNotFp16}, + {PrimType_Floor, ElementFloorFp16}, + {PrimType_Ceil, ElementCeilFp16}, + {PrimType_Round, ElementRoundFp16}, + {PrimType_Neg, ElementNegativeFp16}, + {PrimType_Reciprocal, ElementReciprocalFp16}, + {PrimType_Erf, ElementErfFp16}}; + for (size_t i = 0; i < sizeof(type_func_table) / sizeof(ArithmeticSelfF16Function); i++) { + if (type_func_table[i].primitive_type_ == primitive_type) { + arithmetic_self->f16_function_ = type_func_table[i]; + return; + } + } +#endif + arithmetic_self->f16_function_.primitive_type_ = primitive_type; + return; +} + +int ArithmeticSelfExecute(ArithmeticSelfStruct *arithmetic_self, int task_id) { + int elements_num = NNACLGetElementNum(arithmetic_self->base_.in_[FIRST_INPUT]); + NNACL_CHECK_TRUE_RET(arithmetic_self->base_.thread_nr_, NNACL_ERR); + int stride = UP_DIV(elements_num, arithmetic_self->base_.thread_nr_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(task_id, stride, NNACL_ERR); + int offset = task_id * stride; + int count = NNACL_MIN(stride, elements_num - offset); + if (count <= 0) { + return NNACL_OK; + } + + void *in_data = arithmetic_self->base_.in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(in_data); + void *out_data = arithmetic_self->base_.out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(out_data); + int in_data_type = arithmetic_self->base_.in_[FIRST_INPUT]->data_type_; + int out_data_type = arithmetic_self->base_.out_[OUTPUT_INDEX]->data_type_; + + if (in_data_type == kNumberTypeFloat32 && out_data_type == kNumberTypeBool) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_self->function_.func_float_bool_); + return arithmetic_self->function_.func_float_bool_((float *)in_data + offset, (bool *)out_data + offset, count); + } + + if (in_data_type == kNumberTypeFloat32) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_self->function_.func_); + return arithmetic_self->function_.func_((float *)in_data + offset, (float *)out_data + offset, count); + } + + if (in_data_type == kNumberTypeBool) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_self->function_.func_bool_); + return arithmetic_self->function_.func_bool_((bool *)in_data + offset, (bool *)out_data + offset, count); + } + + if (in_data_type == kNumberTypeInt32) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_self->function_.func_int_); + return arithmetic_self->function_.func_int_((int32_t *)in_data + offset, (int32_t *)out_data + offset, count); + } + +#ifdef ENABLE_FP16 + if (in_data_type == kNumberTypeFloat16) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_self->f16_function_.func_); + return arithmetic_self->f16_function_.func_((float16_t *)in_data + offset, (float16_t *)out_data + offset, count); + } +#endif + return NNACL_ARITHMETIC_SELF_DATA_TYPE_UNSUPPORT; +} + +int ArithmeticSelfRun(void *cdata, int task_id, float l, float r) { + ArithmeticSelfStruct *arithmetic_self = (ArithmeticSelfStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_self); + return ArithmeticSelfExecute(arithmetic_self, task_id); +} + +int ArithmeticSelfResize(KernelBase *self) { + ArithmeticSelfStruct *arithmetic_self = (ArithmeticSelfStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_self); + self->thread_nr_ = arithmetic_self->base_.UpdateThread( + TC_PTYPE(arithmetic_self->op_type_), 1, 1, NNACLGetElementNum(self->out_[OUTPUT_INDEX]), self->thread_nr_); + return NNACL_OK; +} + +int ArithmeticSelfCompute(KernelBase *self) { + return self->env_->ParallelLaunch(self->env_->thread_pool_, ArithmeticSelfRun, self, self->thread_nr_); +} + +int ArithmeticSelfPrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ != ONE_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_size_ != ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_[OUTPUT_INDEX]->category_ == ConstTensor, NNACL_OUTPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_[OUTPUT_INDEX]->category_ == ConstScalar, NNACL_OUTPUT_TENSOR_ERROR); + return NNACL_OK; +} + +KernelBase *CreateArithmeticSelf(OpParameter *param, int data_type) { + ArithmeticSelfStruct *arithmetic_self = (ArithmeticSelfStruct *)malloc(sizeof(ArithmeticSelfStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(arithmetic_self); + ArithmeticSelfGetArithmeticSelfFunction(arithmetic_self, param->type_); + ArithmeticSelfGetArithmeticSelfF16Function(arithmetic_self, param->type_); + arithmetic_self->op_type_ = param->type_; + arithmetic_self->base_.Prepare = ArithmeticSelfPrepare; + arithmetic_self->base_.Resize = ArithmeticSelfResize; + arithmetic_self->base_.Release = DefaultRelease; + arithmetic_self->base_.Compute = ArithmeticSelfCompute; + return (KernelBase *)arithmetic_self; +} + +REG_KERNEL_CREATOR(PrimType_LogicalNot, kNumberTypeBool, CreateArithmeticSelf) + +REG_KERNEL_CREATOR(PrimType_Abs, kNumberTypeInt32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Neg, kNumberTypeInt32, CreateArithmeticSelf) + +REG_KERNEL_CREATOR(PrimType_Abs, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Ceil, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Cos, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Erf, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Floor, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Log, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Log1p, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_LogicalNot, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Neg, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Reciprocal, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Round, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Rsqrt, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Square, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Sqrt, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Sin, kNumberTypeFloat32, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_IsFinite, kNumberTypeFloat32, CreateArithmeticSelf) + +REG_KERNEL_CREATOR(PrimType_Abs, kNumberTypeFloat16, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Cos, kNumberTypeFloat16, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Log, kNumberTypeFloat16, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Square, kNumberTypeFloat16, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Sqrt, kNumberTypeFloat16, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Rsqrt, kNumberTypeFloat16, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Sin, kNumberTypeFloat16, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_LogicalNot, kNumberTypeFloat16, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Floor, kNumberTypeFloat16, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Ceil, kNumberTypeFloat16, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Round, kNumberTypeFloat16, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Neg, kNumberTypeFloat16, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Reciprocal, kNumberTypeFloat16, CreateArithmeticSelf) +REG_KERNEL_CREATOR(PrimType_Erf, kNumberTypeFloat16, CreateArithmeticSelf) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/arithmetic_self.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/arithmetic_self.h new file mode 100644 index 0000000000000000000000000000000000000000..be25dff5bd830ae5f20d0428c809bf6ef17f500f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/arithmetic_self.h @@ -0,0 +1,48 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_ARITHMETIC_SELF_H_ +#define NNACL_KERNEL_ARITHMETIC_SELF_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct ArithmeticSelfFunction { + int primitive_type_; + int (*func_)(const float *input, float *output, const int element_size); + int (*func_bool_)(const bool *input, bool *output, const int element_size); + int (*func_int_)(const int *input, int *output, const int element_size); + int (*func_float_bool_)(const float *input, bool *output, const int element_size); +} ArithmeticSelfFunction; + +typedef struct ArithmeticSelfF16Function { + int primitive_type_; +#ifdef ENABLE_FP16 + int (*func_)(const float16_t *input, float16_t *output, int element_size); +#endif +} ArithmeticSelfF16Function; + +typedef struct ArithmeticSelfStruct { + KernelBase base_; + int op_type_; + ArithmeticSelfFunction function_; + ArithmeticSelfF16Function f16_function_; +} ArithmeticSelfStruct; + +KernelBase *CreateArithmeticSelf(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_ARITHMETIC_SELF_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/batch_norm.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/batch_norm.c new file mode 100644 index 0000000000000000000000000000000000000000..48ba9f1e086cd839868a13a514440b6dc2be8d33 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/batch_norm.c @@ -0,0 +1,134 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/batch_norm.h" +#include "nnacl/tensor_c.h" +#include "nnacl/op_base.h" +#include "nnacl/tensor_c_utils.h" +#include "nnacl/batchnorm_parameter.h" +#include "nnacl/fp32/batchnorm_fp32.h" +#ifdef ENABLE_FP16 +#include "nnacl/fp16/batchnorm_fp16.h" +#endif + +int BatchNormFillParam(BatchNormStruct *batch_norm) { + TensorC *input_tensor = batch_norm->base_.in_[FIRST_INPUT]; + int in_channel = input_tensor->shape_[input_tensor->shape_size_ - 1]; + + TensorC *mean_tensor = batch_norm->base_.in_[SECOND_INPUT]; + int mean_channel = mean_tensor->shape_[mean_tensor->shape_size_ - 1]; + + TensorC *var_tensor = batch_norm->base_.in_[SECOND_INPUT]; + int var_channel = mean_tensor->shape_[var_tensor->shape_size_ - 1]; + + if (in_channel != mean_channel || in_channel != var_channel) { + return NNACL_BATCH_NORM_CHANNEL_SHAPE_INVALID; + } + + batch_norm->channel_ = in_channel; + batch_norm->unit_ = 1; + for (size_t i = 0; i < input_tensor->shape_size_ - 1; i++) { + batch_norm->unit_ *= input_tensor->shape_[i]; + } + if (batch_norm->momentum_ < 0.0f) { + batch_norm->momentum_ = 0.0f; + } + return NNACL_OK; +} + +int BatchNormRun(void *cdata, int task_id, float l, float r) { + BatchNormStruct *bn = (BatchNormStruct *)cdata; + void *in_data = bn->base_.in_[FIRST_INPUT]->data_; + void *out_data = bn->base_.out_[OUTPUT_INDEX]->data_; + if (bn->data_type_ == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + BatchNormFp16((float16_t *)in_data, (float16_t *)bn->mean_, (float16_t *)bn->variance_, bn, task_id, + bn->base_.thread_nr_, (float16_t *)out_data); +#endif + } else { + BatchNormFp32((float *)in_data, (float *)bn->mean_, (float *)bn->variance_, bn, task_id, bn->base_.thread_nr_, + (float *)out_data); + } + return NNACL_OK; +} + +int BatchNormReSize(KernelBase *self) { + BatchNormStruct *batch_norm = (BatchNormStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(batch_norm); + + int ret = BatchNormFillParam(batch_norm); + if (ret != NNACL_OK) { + return ret; + } + + (void)batch_norm->base_.Release(self); + + batch_norm->mean_ = self->env_->Alloc(self->env_->allocator_, NNACLGetSize(self->in_[SECOND_INPUT])); + batch_norm->variance_ = self->env_->Alloc(self->env_->allocator_, NNACLGetSize(self->in_[THIRD_INPUT])); + if (batch_norm->mean_ == NULL || batch_norm->variance_ == NULL) { + (void)batch_norm->base_.Release(self); + return NNACL_ERR; + } + + (void)memcpy(batch_norm->mean_, self->in_[SECOND_INPUT]->data_, NNACLGetSize(self->in_[SECOND_INPUT])); + (void)memcpy(batch_norm->variance_, self->in_[THIRD_INPUT]->data_, NNACLGetSize(self->in_[THIRD_INPUT])); + return NNACL_OK; +} + +int BatchNormRelease(KernelBase *self) { + BatchNormStruct *batch_norm = (BatchNormStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(batch_norm); + + if (batch_norm->mean_ != NULL) { + self->env_->Free(self->env_->allocator_, batch_norm->mean_); + batch_norm->mean_ = NULL; + } + if (batch_norm->variance_ != NULL) { + self->env_->Free(self->env_->allocator_, batch_norm->variance_); + batch_norm->variance_ = NULL; + } + return NNACL_OK; +} + +int BatchNormPrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < THREE_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR); + + BatchNormStruct *batch_norm = (BatchNormStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(batch_norm); + batch_norm->momentum_ = -1.0f; + batch_norm->epsilon_ = ((BatchNormParameter *)self->param_)->epsilon_; + return NNACL_OK; +} + +int BatchNormCompute(KernelBase *self) { + return self->env_->ParallelLaunch(self->env_->thread_pool_, BatchNormRun, self, self->thread_nr_); +} + +KernelBase *CreateBatchNorm(OpParameter *param, int data_type) { + BatchNormStruct *batch_norm = (BatchNormStruct *)malloc(sizeof(BatchNormStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(batch_norm); + memset(batch_norm, 0, sizeof(BatchNormStruct)); + batch_norm->data_type_ = data_type; + batch_norm->base_.Prepare = BatchNormPrepare; + batch_norm->base_.Resize = BatchNormReSize; + batch_norm->base_.Release = BatchNormRelease; + batch_norm->base_.Compute = BatchNormCompute; + return (KernelBase *)batch_norm; +} + +REG_KERNEL_CREATOR(PrimType_BatchNorm, kNumberTypeFloat16, CreateBatchNorm) +REG_KERNEL_CREATOR(PrimType_BatchNorm, kNumberTypeFloat32, CreateBatchNorm) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/batch_norm.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/batch_norm.h new file mode 100644 index 0000000000000000000000000000000000000000..23b7d54467ecc53c48d69984eb0b249c447cc0da --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/batch_norm.h @@ -0,0 +1,38 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_BATCH_NORM_H_ +#define NNACL_KERNEL_BATCH_NORM_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct BatchNormStruct { + KernelBase base_; + int data_type_; + void *mean_; + void *variance_; + float momentum_; + int unit_; + int channel_; + float epsilon_; +} BatchNormStruct; + +KernelBase *CreateBatchNorm(OpParameter *param, int data_type); +int BatchNormRelease(KernelBase *self); +int BatchNormFillParam(BatchNormStruct *batch_norm); + +#endif // NNACL_KERNEL_BATCH_NORM_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/batch_to_space.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/batch_to_space.c new file mode 100644 index 0000000000000000000000000000000000000000..0489d7181b5df3914eff4d8413ad628e36b5d9c4 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/batch_to_space.c @@ -0,0 +1,114 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/batch_to_space.h" +#include "nnacl/kernel/default_kernel_base.h" +#include "nnacl/base/batch_to_space_base.h" +#include "nnacl/nnacl_common.h" +#include "nnacl/tensor_c_utils.h" +#include "nnacl/batch_to_space_parameter.h" + +int BatchToSpaceProcessInput(BatchToSpaceStruct *batch_to_space) { + TensorC *block_shape = batch_to_space->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(block_shape); + NNACL_CHECK_NULL_RETURN_ERR(block_shape->data_); + TensorC *crop = batch_to_space->base_.in_[THIRD_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(crop); + NNACL_CHECK_NULL_RETURN_ERR(crop->data_); + + if (NNACLGetElementNum(block_shape) < BATCH_TO_SPACE_BLOCK_SHAPE_SIZE) { + return NNACL_BATCH_TO_SPACE_BLOCK_SHAPE_INVALID; + } + if (NNACLGetElementNum(crop) < COMM_SHAPE_SIZE) { + return NNACL_BATCH_TO_SPACE_CROP_INVALID; + } + + int32_t *block_shape_data = (int32_t *)block_shape->data_; + for (int i = 0; i < BATCH_TO_SPACE_BLOCK_SHAPE_SIZE; ++i) { + batch_to_space->block_shape_[i] = block_shape_data[i]; + } + + int32_t *crops_data = (int32_t *)crop->data_; + batch_to_space->no_crop_ = true; + for (int i = 0; i < COMM_SHAPE_SIZE; ++i) { + batch_to_space->crops_[i] = crops_data[i]; + if (batch_to_space->crops_[i] != 0) { + batch_to_space->no_crop_ = false; + } + } + return NNACL_OK; +} + +int BatchToSpaceCompute(KernelBase *self) { + BatchToSpaceStruct *batch_to_space = (BatchToSpaceStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(batch_to_space); + TensorC *input = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + TensorC *output = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output); + + size_t data_size = DataTypeCSize(input->data_type_); + if (self->in_size_ == Num1) { + if (batch_to_space->no_crop_) { + BatchToSpaceNoCropForNHWC(input->data_, output->data_, input->shape_, output->shape_[Index0], + batch_to_space->block_shape_, data_size); + } else { + BatchToSpaceForNHWC(input->data_, output->data_, input->shape_, output->shape_[Index0], + batch_to_space->block_shape_, batch_to_space->crops_, data_size); + } + } + + if (self->in_size_ == Num3) { + int ret = BatchToSpaceProcessInput(batch_to_space); + if (ret != NNACL_OK) { + return ret; + } + if (batch_to_space->no_crop_) { + BatchToSpaceNoCropForNHWC(input->data_, output->data_, input->shape_, output->shape_[Index0], + batch_to_space->block_shape_, data_size); + } else { + BatchToSpaceForNHWC(input->data_, output->data_, input->shape_, output->shape_[Index0], + batch_to_space->block_shape_, batch_to_space->crops_, data_size); + } + } + return NNACL_OK; +} + +int BatchToSpaceResize(KernelBase *self) { + NNACL_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]); + NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]); + NNACL_CHECK_TRUE_RET(self->in_[FIRST_INPUT]->shape_size_ == COMM_SHAPE_SIZE, NNACL_ERR); + return NNACL_OK; +} + +KernelBase *CreateBatchToSpace(OpParameter *param, int data_type) { + BatchToSpaceStruct *batch_to_space = (BatchToSpaceStruct *)malloc(sizeof(BatchToSpaceStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(batch_to_space); + memset(batch_to_space, 0, sizeof(BatchToSpaceStruct)); + BatchToSpaceParameter *bts_param = (BatchToSpaceParameter *)param; + memcpy(batch_to_space->crops_, bts_param->crops_, sizeof(int32_t) * COMM_SHAPE_SIZE); + memcpy(batch_to_space->block_shape_, bts_param->block_shape_, sizeof(int32_t) * BATCH_TO_SPACE_BLOCK_SHAPE_SIZE); + batch_to_space->base_.Prepare = DefaultPrepare1In1Out; + batch_to_space->base_.Resize = BatchToSpaceResize; + batch_to_space->base_.Release = DefaultRelease; + batch_to_space->base_.Compute = BatchToSpaceCompute; + return (KernelBase *)batch_to_space; +} + +REG_KERNEL_CREATOR(PrimType_BatchToSpace, kNumberTypeFloat16, CreateBatchToSpace) +REG_KERNEL_CREATOR(PrimType_BatchToSpace, kNumberTypeFloat32, CreateBatchToSpace) +REG_KERNEL_CREATOR(PrimType_BatchToSpaceND, kNumberTypeFloat16, CreateBatchToSpace) +REG_KERNEL_CREATOR(PrimType_BatchToSpaceND, kNumberTypeFloat32, CreateBatchToSpace) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/batch_to_space.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/batch_to_space.h new file mode 100644 index 0000000000000000000000000000000000000000..233aa14eb612d1fcf5b31acc541b20d5816ad781 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/batch_to_space.h @@ -0,0 +1,33 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_BATCH_TO_SPACE_H_ +#define NNACL_KERNEL_BATCH_TO_SPACE_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" +#include "nnacl/batch_to_space_parameter.h" + +typedef struct BatchToSpaceStruct { + KernelBase base_; + bool no_crop_; + int32_t crops_[COMM_SHAPE_SIZE]; + int32_t block_shape_[BATCH_TO_SPACE_BLOCK_SHAPE_SIZE]; +} BatchToSpaceStruct; + +KernelBase *CreateBatchToSpace(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_BATCH_TO_SPACE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/biasadd.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/biasadd.c new file mode 100644 index 0000000000000000000000000000000000000000..5076ecb3ea5bc73b6fc9ccedb28b5286f066384b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/biasadd.c @@ -0,0 +1,131 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/biasadd.h" +#include "nnacl/tensor_c.h" +#include "nnacl/op_base.h" +#include "nnacl/fp32/bias_add.h" +#include "nnacl/kernel/default_kernel_base.h" + +#define BIAS_ADD_PER_UNIT_LOAD_NUM 2 +#define BIAS_ADD_PER_UNIT_STORE_NUM 1 +#define SPLIT_POINTS_SIZE 32 + +typedef struct BiasAddStruct { + KernelBase base_; + int64_t inner_num_; + int64_t outer_num_; + int64_t total_num_; + bool batch_priority_; + int64_t split_points_[SPLIT_POINTS_SIZE]; + int split_pionts_size_; +} BiasAddStruct; + +int ChooseBiasThreadCuttingStrategy(KernelBase *self) { + BiasAddStruct *bias_add = (BiasAddStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(bias_add); + self->thread_nr_ = self->UpdateThread(TC_PTYPE(PrimType_BiasAdd), BIAS_ADD_PER_UNIT_LOAD_NUM, + BIAS_ADD_PER_UNIT_STORE_NUM, bias_add->total_num_, self->thread_nr_); + if (self->thread_nr_ > SPLIT_POINTS_SIZE) { + self->thread_nr_ = SPLIT_POINTS_SIZE; + } + + bias_add->split_pionts_size_ = 0; + int64_t block_size = 1; + block_size = bias_add->total_num_ / self->thread_nr_; + int64_t remain_data = bias_add->total_num_ - block_size * self->thread_nr_; + int64_t split_point = 0; + while (split_point < bias_add->total_num_) { + bias_add->split_points_[bias_add->split_pionts_size_++] = split_point; + split_point += block_size; + if (remain_data > 0) { + ++split_point; + --remain_data; + } + } + self->thread_nr_ = bias_add->split_pionts_size_; + if (bias_add->inner_num_ >= C64NUM && block_size / bias_add->inner_num_ >= C6NUM) { + bias_add->batch_priority_ = true; + } else { + bias_add->batch_priority_ = false; + } + return NNACL_OK; +} + +int BiasRun(void *cdata, int task_id, float l, float r) { + BiasAddStruct *bias_add = (BiasAddStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(bias_add); + + float *input = (float *)(bias_add->base_.in_[FIRST_INPUT]->data_); + NNACL_CHECK_NULL_RETURN_ERR(input); + float *bias = (float *)(bias_add->base_.in_[SECOND_INPUT]->data_); + NNACL_CHECK_NULL_RETURN_ERR(bias); + float *output = (float *)(bias_add->base_.out_[FIRST_INPUT]->data_); + NNACL_CHECK_NULL_RETURN_ERR(output); + + int64_t block_start = bias_add->split_points_[task_id]; + int64_t block_end = bias_add->total_num_; + if ((task_id + 1) < bias_add->split_pionts_size_) { + block_end = bias_add->split_points_[task_id + 1]; + } + BiasAddOpt(input, bias, output, block_start, block_end, bias_add->inner_num_, bias_add->batch_priority_); + return NNACL_OK; +} + +int BiasAddResize(struct KernelBase *self) { + BiasAddStruct *bias_add = (BiasAddStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(bias_add); + + TensorC *in_tensor = self->in_[FIRST_INPUT]; + TensorC *add_tensor = self->in_[SECOND_INPUT]; + NNACL_CHECK_FALSE(in_tensor->shape_size_ == 0, NNACL_ERR); + NNACL_CHECK_FALSE(add_tensor->shape_size_ == 0, NNACL_ERR); + NNACL_CHECK_FALSE(in_tensor->shape_size_ < add_tensor->shape_size_, NNACL_ERR); + + size_t dim_offset = in_tensor->shape_size_ - add_tensor->shape_size_; + bias_add->inner_num_ = 1; + for (size_t i = 0; i < add_tensor->shape_size_; ++i) { + NNACL_CHECK_FALSE(in_tensor->shape_[i + dim_offset] != add_tensor->shape_[i], NNACL_BIAS_ADD_SHAPE_NOT_MATCH); + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(in_tensor->shape_[i], bias_add->inner_num_), NNACL_BIAS_ADD_SHAPE_OVERFLOW); + bias_add->inner_num_ *= add_tensor->shape_[i]; + } + + bias_add->outer_num_ = 1; + for (size_t i = 0; i < dim_offset; ++i) { + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(in_tensor->shape_[i], bias_add->outer_num_), NNACL_BIAS_ADD_SHAPE_OVERFLOW); + bias_add->outer_num_ *= in_tensor->shape_[i]; + } + + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(bias_add->inner_num_, bias_add->outer_num_), NNACL_BIAS_ADD_SHAPE_OVERFLOW); + bias_add->total_num_ = bias_add->inner_num_ * bias_add->outer_num_; + return ChooseBiasThreadCuttingStrategy(self); +} + +int BiasAddCompute(struct KernelBase *self) { + return self->env_->ParallelLaunch(self->env_->thread_pool_, BiasRun, self, self->thread_nr_); +} + +KernelBase *CreateBiasAdd(OpParameter *param, int data_type) { + BiasAddStruct *bias_add = (BiasAddStruct *)malloc(sizeof(BiasAddStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(bias_add); + bias_add->base_.Prepare = DefaultPrepare2In1Out; + bias_add->base_.Resize = BiasAddResize; + bias_add->base_.Release = DefaultRelease; + bias_add->base_.Compute = BiasAddCompute; + return (KernelBase *)bias_add; +} + +REG_KERNEL_CREATOR(PrimType_BiasAdd, kNumberTypeFloat32, CreateBiasAdd) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/biasadd.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/biasadd.h new file mode 100644 index 0000000000000000000000000000000000000000..8b3253e8d2530c944ebc04243717f1201de9ea47 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/biasadd.h @@ -0,0 +1,25 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_BIASADD_H_ +#define NNACL_KERNEL_BIASADD_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +KernelBase *CreateBiasAdd(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_BIASADD_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/cast.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/cast.c new file mode 100644 index 0000000000000000000000000000000000000000..7348c992b59daa4398a2fc5fd73a4838541b25a2 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/cast.c @@ -0,0 +1,209 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/cast.h" +#include "nnacl/op_base.h" +#include "nnacl/base/cast_base.h" +#include "nnacl/nnacl_common.h" +#include "nnacl/tensor_c_utils.h" + +#ifdef ENABLE_FP16 +#include "nnacl/fp16/cast_fp16.h" +#endif + +int CastToFp32(const TensorC *input, TensorC *output, int offset, int data_num) { + int input_data_type = input->data_type_; + float *output_data = (float *)output->data_; + switch (input_data_type) { + case kNumberTypeBool: + BoolToFloat32((const bool *)(input->data_) + offset, output_data + offset, data_num); + break; + case kNumberTypeUInt8: + Uint8ToFloat32((const uint8_t *)(input->data_) + offset, output_data + offset, data_num); + break; + case kNumberTypeInt32: + Int32ToFloat32((const int32_t *)(input->data_) + offset, output_data + offset, data_num); + break; + case kNumberTypeFloat16: +#ifdef ENABLE_FP16 + Fp16ToFloat32((const float16_t *)(input->data_) + offset, output_data + offset, data_num); +#else + Fp16ToFloat32((const uint16_t *)(input->data_) + offset, output_data + offset, data_num); +#endif + break; + case kNumberTypeInt64: + Int64ToFloat32((const int64_t *)(input->data_) + offset, output_data + offset, data_num); + break; + default: + return NNACL_ERR; + } + return NNACL_OK; +} + +int CastToFp16(const TensorC *input, TensorC *output, int offset, int data_num) { + int input_data_type = input->data_type_; +#ifdef ENABLE_FP16 + float16_t *output_data = (float16_t *)output->data_; + switch (input_data_type) { + case kNumberTypeFloat32: + Float32ToFp16((const float *)(input->data_) + offset, output_data + offset, data_num); + break; + case kNumberTypeInt64: + Int64ToFp16((const int64_t *)(input->data_) + offset, output_data + offset, data_num); + break; + case kNumberTypeInt32: + Int32ToFp16((const int32_t *)(input->data_) + offset, output_data + offset, data_num); + break; + case kNumberTypeBool: + BoolToFp16((const bool *)(input->data_) + offset, output_data + offset, data_num); + break; + case kNumberTypeUInt8: + Uint8ToFp16((const uint8_t *)(input->data_) + offset, output_data + offset, data_num); + break; + default: + return NNACL_ERR; + } +#else + if (input_data_type == kNumberTypeFloat32) { + Float32ToFp16((const float *)(input->data_) + offset, (uint16_t *)(output->data_) + offset, data_num); + } else { + return NNACL_ERR; + } +#endif + return NNACL_OK; +} + +int CastToOthers(const TensorC *input, TensorC *output, int offset, int data_num) { + int input_data_type = input->data_type_; + int output_data_type = output->data_type_; + void *output_data = output->data_; + if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeInt64) { + Float32ToInt64((const float *)(input->data_) + offset, (int64_t *)(output_data) + offset, data_num); + } else if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeInt32) { + Float32ToInt32((const float *)(input->data_) + offset, (int32_t *)(output_data) + offset, data_num); + } else if (input_data_type == kNumberTypeInt32 && output_data_type == kNumberTypeInt64) { + Int32ToInt64((const int32_t *)(input->data_) + offset, (int64_t *)(output_data) + offset, data_num); + } else if (input_data_type == kNumberTypeInt64 && output_data_type == kNumberTypeInt32) { + Int64ToInt32((const int64_t *)(input->data_) + offset, (int32_t *)(output_data) + offset, data_num); + } else if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeInt16) { + Float32ToInt16((const float *)(input->data_) + offset, (int16_t *)(output_data) + offset, data_num); + } else if (input_data_type == kNumberTypeBool && output_data_type == kNumberTypeInt32) { + BoolToInt32((const bool *)(input->data_) + offset, (int32_t *)(output_data) + offset, data_num); + } else if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeBool) { + Float32ToBool((const float *)(input->data_) + offset, (bool *)(output_data) + offset, data_num); + } else if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeUInt8) { + Float32ToUint8((const float *)(input->data_) + offset, (uint8_t *)(output_data) + offset, data_num); + } else { + return NNACL_ERR; + } + return NNACL_OK; +} + +int CastLaunch(void *cdata, int task_id, float l, float r) { + CastStruct *cast = (CastStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(cast); + + NNACL_CHECK_FALSE(cast->base_.in_size_ < ONE_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(cast->base_.out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + + TensorC *in = cast->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in); + NNACL_CHECK_NULL_RETURN_ERR(in->data_); + TensorC *out = cast->base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(out); + NNACL_CHECK_NULL_RETURN_ERR(out->data_); + + int stride = cast->stride_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(task_id, stride, NNACL_ERR); + + int data_num = MSMIN(stride, cast->data_num_ - task_id * stride); + if (data_num <= 0) { + return NNACL_OK; + } + + int offset = task_id * stride; + int input_data_type = in->data_type_; + int output_data_type = out->data_type_; + if (input_data_type == output_data_type) { + size_t datalen = DataTypeCSize((TypeIdC)input_data_type); + memcpy((int8_t *)(out->data_) + offset * datalen, (int8_t *)(in->data_) + offset * datalen, data_num * datalen); + return NNACL_OK; + } + + if (output_data_type == kNumberTypeFloat32) { + return CastToFp32(in, out, offset, data_num); + } else if (output_data_type == kNumberTypeFloat16) { + return CastToFp16(in, out, offset, data_num); + } else { + return CastToOthers(in, out, offset, data_num); + } + return NNACL_OK; +} + +int cast_prepare(struct KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < ONE_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + return NNACL_OK; +} + +// Kernel resize input shape +int cast_resize(struct KernelBase *self) { + CastStruct *cast = (CastStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(cast); + NNACL_CHECK_FALSE(self->in_size_ < ONE_TENSOR, NNACL_INPUT_TENSOR_ERROR); + TensorC *in_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in_tensor); + int data_num = NNACLGetElementNum(in_tensor); + if (data_num == 0) { + return NNACL_OK; + } + + cast->data_num_ = data_num; + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + // update thread num + cast->base_.thread_nr_ = cast->base_.UpdateThread( + TC_PTYPE(PrimType_Cast), 1, 1, NNACLGetElementNum(cast->base_.out_[FIRST_INPUT]), cast->base_.thread_nr_); + cast->stride_ = UP_DIV(data_num, cast->base_.thread_nr_); + return NNACL_OK; +} + +int cast_release(struct KernelBase *self) { return NNACL_OK; } + +// Cast Op Compute +int cast_compute(struct KernelBase *self) { + CastStruct *cast = (CastStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(cast); + if (cast->data_num_ == 0) { + return NNACL_OK; + } + + return self->env_->ParallelLaunch(self->env_->thread_pool_, CastLaunch, self, self->thread_nr_); +} + +KernelBase *CreateCast(OpParameter *param, int data_type) { + CastStruct *cast = (CastStruct *)malloc(sizeof(CastStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(cast); + memset(cast, 0, sizeof(CastStruct)); + cast->base_.Prepare = cast_prepare; + cast->base_.Resize = cast_resize; + cast->base_.Release = cast_release; + cast->base_.Compute = cast_compute; + cast->stride_ = 0; + cast->data_num_ = 0; + return (KernelBase *)cast; +} + +// todo register kernel diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/cast.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/cast.h new file mode 100644 index 0000000000000000000000000000000000000000..5d91e00b50d87c5a587f018cd106ab320515647d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/cast.h @@ -0,0 +1,32 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CAST_H_ +#define NNACL_KERNEL_CAST_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct CastStruct { + KernelBase base_; + int stride_; + int data_num_; +} CastStruct; + +KernelBase *CreateCast(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_CAST_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/clip.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/clip.c new file mode 100644 index 0000000000000000000000000000000000000000..d2400ead96a6e947ec17bc432387e6d509112f43 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/clip.c @@ -0,0 +1,123 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/clip.h" +#include "nnacl/op_base.h" +#include "nnacl/clip_parameter.h" +#include "nnacl/fp32/activation_fp32.h" +#include "nnacl/kernel/default_kernel_base.h" +#include "nnacl/tensor_c_utils.h" + +int GetClipMinMaxValue(TensorC *tensor, float *data) { + NNACL_CHECK_NULL_RETURN_ERR(tensor); + switch (tensor->data_type_) { + case kNumberTypeFloat: + case kNumberTypeFloat32: + *data = *((float *)tensor->data_); + break; + case kNumberTypeInt: + case kNumberTypeInt32: + *data = *((int *)tensor->data_); + break; + default: + return NNACL_CLIP_DATA_TYPE_INVALID; + } + return NNACL_OK; +} + +int ClipResize(struct KernelBase *self) { + ClipStruct *clip = (ClipStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(clip); + clip->base_.thread_nr_ = clip->base_.UpdateThread( + TC_PTYPE(PrimType_Clip), 1, 1, NNACLGetElementNum(clip->base_.out_[FIRST_INPUT]), clip->base_.thread_nr_); + + clip->length_ = NNACLGetElementNum(clip->base_.in_[FIRST_INPUT]); + clip->stride_ = UP_DIV(clip->length_, clip->base_.thread_nr_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(clip->stride_, clip->base_.thread_nr_, NNACL_ERR); + return NNACL_OK; +} + +int ClipImpl(void *cdata, int task_id, float l, float r) { + ClipStruct *clip = (ClipStruct *)cdata; + void *in = clip->base_.in_[FIRST_INPUT]->data_; + void *out = clip->base_.out_[FIRST_INPUT]->data_; + + int stride = clip->stride_ * task_id; + int count = NNACL_MIN(clip->stride_, clip->length_ - stride); + if (count <= 0) { + return NNACL_OK; + } + + switch (clip->base_.in_[FIRST_INPUT]->data_type_) { + case kNumberTypeFloat: + case kNumberTypeFloat32: { + return Fp32Clip((float *)in + stride, count, (float *)out + stride, clip->min_val_, clip->max_val_); + } break; + case kNumberTypeInt: + case kNumberTypeInt32: { + return Int32Clip((int *)in + stride, count, (int *)out + stride, (int)clip->min_val_, (int)clip->max_val_); + } break; + default: + return NNACL_CLIP_DATA_TYPE_INVALID; + } + return NNACL_OK; +} + +int ClipCompute(struct KernelBase *self) { + ClipStruct *clip = (ClipStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(clip); + ClipParameter *param = (ClipParameter *)clip->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + clip->min_val_ = param->min_val_; + clip->max_val_ = param->max_val_; + + int ret = NNACL_OK; + if (clip->base_.in_size_ > ONE_TENSOR) { + TensorC *min_tensor = clip->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(min_tensor); + NNACL_CHECK_NULL_RETURN_ERR(min_tensor->data_); + ret = GetClipMinMaxValue(min_tensor, &(clip->min_val_)); + } + if (clip->base_.in_size_ > TWO_TENSOR) { + TensorC *max_tensor = clip->base_.in_[THIRD_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(max_tensor); + NNACL_CHECK_NULL_RETURN_ERR(max_tensor->data_); + ret = GetClipMinMaxValue(max_tensor, &(clip->max_val_)); + } + if (ret != NNACL_OK) { + return ret; + } + if (clip->min_val_ >= clip->max_val_) { + return NNACL_CLIP_MINMAX_VALUE_INVALID; + } + + return self->env_->ParallelLaunch(self->env_->thread_pool_, ClipImpl, clip, clip->base_.thread_nr_); +} + +KernelBase *CreateClip(OpParameter *param, int data_type) { + ClipStruct *clip = (ClipStruct *)malloc(sizeof(ClipStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(clip); + clip->base_.Prepare = DefaultPrepare1In1Out; + clip->base_.Resize = ClipResize; + clip->base_.Release = DefaultRelease; + clip->base_.Compute = ClipCompute; + return (KernelBase *)clip; +} + +REG_KERNEL_CREATOR(PrimType_Clip, kNumberTypeFloat, CreateClip) +REG_KERNEL_CREATOR(PrimType_Clip, kNumberTypeFloat32, CreateClip) +REG_KERNEL_CREATOR(PrimType_Clip, kNumberTypeInt, CreateClip) +REG_KERNEL_CREATOR(PrimType_Clip, kNumberTypeInt32, CreateClip) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/clip.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/clip.h new file mode 100644 index 0000000000000000000000000000000000000000..3718eb45de29f76bc23cb32c1ef9d9ee6c0cf036 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/clip.h @@ -0,0 +1,34 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CLIP_H_ +#define NNACL_KERNEL_CLIP_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct ClipStruct { + KernelBase base_; + float min_val_; + float max_val_; + int length_; + int stride_; +} ClipStruct; + +KernelBase *CreateClip(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_CLIP_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/concat.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/concat.c new file mode 100644 index 0000000000000000000000000000000000000000..65d8e523f0a5d5d4b9e4be38ba4e5bd9ece5d318 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/concat.c @@ -0,0 +1,287 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/concat.h" +#include "nnacl/concat_parameter.h" +#include "nnacl/tensor_c.h" +#include "nnacl/op_base.h" +#include "nnacl/nnacl_common.h" +#include "nnacl/tensor_c_utils.h" + +#define kConcatMinCostPerThread 16384 + +int DoConcat(ConcatStruct *concat, int task_id) { + NNACL_CHECK_FALSE(task_id < 0, NNACL_ERR); + NNACL_CHECK_FALSE(task_id > concat->block_size_, NNACL_ERR); + + int all_bytes = NNACLGetSize(concat->base_.out_[FIRST_INPUT]); + int64_t start = concat->block_splits_[task_id]; + int64_t end = task_id < (concat->block_size_ - 1) ? concat->block_splits_[task_id + 1] : all_bytes; + int64_t start_row = start / concat->inner_sizes_[concat->base_.in_size_]; + int64_t end_row = end / concat->inner_sizes_[concat->base_.in_size_]; + + size_t src_buf_size = concat->base_.in_size_ * sizeof(uint8_t *); + NNACL_CHECK_MALLOC_SIZE(src_buf_size); + uint8_t **src = (uint8_t **)concat->base_.env_->Alloc(concat->base_.env_->allocator_, src_buf_size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(src); + for (size_t i = 0; i < concat->base_.in_size_; ++i) { + if (concat->is_with_data_[i]) { + src[i] = concat->inputs_ptr_[i] + start_row * concat->inner_sizes_[i]; + } + } + uint8_t *out = concat->output_ + start; + + int input_index = concat->block_boundary_infos_[task_id].begin_input_; + int end_index = concat->block_boundary_infos_[task_id].end_input_; + if (start_row == end_row) { + if (input_index == end_index) { + memcpy(out, src[input_index] + concat->block_boundary_infos_[task_id].begin_point_, + concat->block_boundary_infos_[task_id].end_point_ - concat->block_boundary_infos_[task_id].begin_point_); + concat->base_.env_->Free(concat->base_.env_->allocator_, src); + return NNACL_OK; + } + int64_t size = concat->inner_sizes_[input_index] - concat->block_boundary_infos_[task_id].begin_point_; + memcpy(out, src[input_index] + concat->block_boundary_infos_[task_id].begin_point_, size); + out += size; + ++input_index; + for (; input_index < end_index; ++input_index) { + memcpy(out, src[input_index], concat->inner_sizes_[input_index]); + out += concat->inner_sizes_[input_index]; + } + memcpy(out, src[input_index], concat->block_boundary_infos_[task_id].end_point_); + concat->base_.env_->Free(concat->base_.env_->allocator_, src); + return NNACL_OK; + } + for (int i = 0; i < input_index; ++i) { + src[i] += concat->inner_sizes_[i]; + } + int64_t size = concat->inner_sizes_[input_index] - concat->block_boundary_infos_[task_id].begin_point_; + memcpy(out, src[input_index] + concat->block_boundary_infos_[task_id].begin_point_, size); + src[input_index] += concat->inner_sizes_[input_index]; + out += size; + ++input_index; + for (; input_index < concat->base_.in_size_; ++input_index) { + memcpy(out, src[input_index], concat->inner_sizes_[input_index]); + src[input_index] += concat->inner_sizes_[input_index]; + out += concat->inner_sizes_[input_index]; + } + ++start_row; + for (; start_row < end_row; ++start_row) { + for (input_index = 0; input_index < concat->base_.in_size_; ++input_index) { + memcpy(out, src[input_index], concat->inner_sizes_[input_index]); + src[input_index] += concat->inner_sizes_[input_index]; + out += concat->inner_sizes_[input_index]; + } + } + for (input_index = 0; input_index < end_index; ++input_index) { + memcpy(out, src[input_index], concat->inner_sizes_[input_index]); + out += concat->inner_sizes_[input_index]; + } + memcpy(out, src[end_index], concat->block_boundary_infos_[task_id].end_point_); + + concat->base_.env_->Free(concat->base_.env_->allocator_, src); + return NNACL_OK; +} + +int ConcatRun(void *cdata, int task_id, float l, float r) { + ConcatStruct *concat = (ConcatStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(concat); + return DoConcat(concat, task_id); +} + +int InitConcatDynamicStatus(ConcatStruct *concat) { + ConcatParameter *param = (ConcatParameter *)concat->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + + size_t i = 0; + int64_t output_inner_size = 0; + for (; i < concat->base_.in_size_; i++) { + TensorC *t = concat->base_.in_[i]; + NNACL_CHECK_FALSE(param->axis_ >= t->shape_size_, NNACL_CONCAT_AXIS_INVALID); + int64_t outer_size = 1; + for (int j = 0; j < param->axis_; ++j) { + outer_size *= t->shape_[j]; + } + int inner_size = DataTypeCSize(concat->data_type_); + NNACL_CHECK_TRUE_RET(inner_size > 0, NNACL_UNSUPPORTED_DATA_TYPE); + + for (int j = param->axis_; j < t->shape_size_; ++j) { + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(inner_size, t->shape_[j], NNACL_CONCAT_SHAPE_INVALID); + inner_size *= t->shape_[j]; + } + if (i == 0) { + concat->outer_size_ = outer_size; + } else { + NNACL_CHECK_TRUE_RET(concat->outer_size_ == outer_size, NNACL_CONCAT_SHAPE_INVALID); + } + if (inner_size == 0) { + concat->is_with_data_[i] = false; + concat->inner_sizes_[i] = inner_size; + continue; + } + concat->is_with_data_[i] = true; + concat->inner_sizes_[i] = inner_size; + output_inner_size += inner_size; + } + concat->inner_sizes_[i] = output_inner_size; + return NNACL_OK; +} + +void ComputeConcatUnitBoundary(ConcatStruct *concat, int64_t *pre_sum, int offset, int *input, int64_t *point) { + size_t index = 0; + for (; index < concat->base_.in_size_; ++index) { + if (offset < pre_sum[index]) { + break; + } + } + *input = index; + *point = concat->inner_sizes_[index] - (pre_sum[index] - offset); +} + +int ChooseConcatThreadCuttingStrategy(ConcatStruct *concat) { + NNACL_CHECK_TRUE_RET(concat->base_.thread_nr_ > 0, NNACL_ERR); + + int all_bytes = NNACLGetSize(concat->base_.out_[FIRST_INPUT]); + int64_t thread_count = MSMAX(1, MSMIN(all_bytes / kConcatMinCostPerThread, concat->base_.thread_nr_)); + + NNACL_CHECK_ZERO_RETURN_ERR(thread_count); + int64_t block_size = all_bytes / thread_count; + int64_t remain_byte = all_bytes - block_size * thread_count; + int64_t *pre_sum = + (int64_t *)concat->base_.env_->Alloc(concat->base_.env_->allocator_, concat->base_.in_size_ * sizeof(int64_t)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(pre_sum); + int64_t init_sum = 0; + for (size_t i = 0; i < concat->base_.in_size_; ++i) { + init_sum += concat->inner_sizes_[i]; + pre_sum[i] = init_sum; + } + + concat->block_size_ = 0; + + int64_t block_spilt = 0; + while (block_spilt < all_bytes) { + concat->block_splits_[concat->block_size_] = block_spilt; + block_spilt += block_size; + if (remain_byte > 0) { + ++block_spilt; + --remain_byte; + } + int64_t start = concat->block_splits_[concat->block_size_]; + int64_t end = block_spilt > all_bytes ? all_bytes : block_spilt; + int64_t start_offset = start - DOWN_ROUND(start, concat->inner_sizes_[concat->base_.in_size_]); + int64_t end_offset = end - DOWN_ROUND(end, concat->inner_sizes_[concat->base_.in_size_]); + ConcatBlockBoundaryInfo block_boundary_info; + ComputeConcatUnitBoundary(concat, pre_sum, start_offset, &block_boundary_info.begin_input_, + &block_boundary_info.begin_point_); + ComputeConcatUnitBoundary(concat, pre_sum, end_offset, &block_boundary_info.end_input_, + &block_boundary_info.end_point_); + concat->block_boundary_infos_[concat->block_size_] = block_boundary_info; + concat->block_size_++; + } + + concat->base_.thread_nr_ = concat->block_size_; + concat->base_.env_->Free(concat->base_.env_->allocator_, pre_sum); + return NNACL_OK; +} + +int ConcatResize(KernelBase *self) { + ConcatStruct *concat = (ConcatStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(concat); + ConcatParameter *param = (ConcatParameter *)concat->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + + param->axis_ = param->axis_ >= 0 ? param->axis_ : self->in_[FIRST_INPUT]->shape_size_ + param->axis_; + NNACL_CHECK_FALSE(param->axis_ < 0, NNACL_CONCAT_AXIS_INVALID); + NNACL_CHECK_FALSE(param->axis_ >= self->in_[FIRST_INPUT]->shape_size_, NNACL_CONCAT_AXIS_INVALID); + + int ret = InitConcatDynamicStatus(concat); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + + if (concat->outer_size_ == 0 || concat->inner_sizes_[self->in_size_] == 0) { + return NNACL_OK; + } + + return ChooseConcatThreadCuttingStrategy(concat); +} + +int ConcatPepare(KernelBase *self) { + ConcatStruct *concat = (ConcatStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(concat); + + concat->inputs_ptr_ = self->env_->Alloc(self->env_->allocator_, self->in_size_ * sizeof(uint8_t *)); + NNACL_CHECK_NULL_RETURN_ERR(concat->inputs_ptr_); + concat->is_with_data_ = self->env_->Alloc(self->env_->allocator_, self->in_size_ * sizeof(bool)); + NNACL_CHECK_NULL_RETURN_ERR(concat->is_with_data_); + concat->inner_sizes_ = + self->env_->Alloc(self->env_->allocator_, (self->in_size_ + self->out_size_) * sizeof(int64_t)); + NNACL_CHECK_NULL_RETURN_ERR(concat->inner_sizes_); + + return NNACL_OK; +} + +int ConcatRelease(KernelBase *self) { + ConcatStruct *concat = (ConcatStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(concat); + if (concat->inputs_ptr_ != NULL) { + self->env_->Free(self->env_->allocator_, concat->inputs_ptr_); + } + if (concat->is_with_data_ != NULL) { + self->env_->Free(self->env_->allocator_, concat->is_with_data_); + } + if (concat->inner_sizes_ != NULL) { + self->env_->Free(self->env_->allocator_, concat->inner_sizes_); + } + return NNACL_OK; +} + +int ConcatCompute(KernelBase *self) { + ConcatStruct *concat = (ConcatStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(concat); + if (concat->outer_size_ == 0 || concat->inner_sizes_[self->in_size_] == 0) { + return NNACL_OK; + } + + for (size_t i = 0; i < self->in_size_; ++i) { + if (!concat->is_with_data_[i]) { + continue; + } + NNACL_CHECK_NULL_RETURN_ERR(self->in_[i]->data_); + concat->inputs_ptr_[i] = self->in_[i]->data_; + } + + concat->output_ = self->out_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(concat->output_); + return self->env_->ParallelLaunch(self->env_->thread_pool_, ConcatRun, self, self->thread_nr_); +} + +KernelBase *CreateConcat(OpParameter *param, int data_type) { + ConcatStruct *concat = (ConcatStruct *)malloc(sizeof(ConcatStruct)); + NNACL_CHECK_NULL_RETURN_NULL(concat); + memset(concat, 0, sizeof(ConcatStruct)); + concat->data_type_ = kNumberTypeFloat32; + concat->inner_sizes_ = NULL; + concat->inputs_ptr_ = NULL; + concat->is_with_data_ = NULL; + concat->base_.Prepare = ConcatPepare; + concat->base_.Resize = ConcatResize; + concat->base_.Release = ConcatRelease; + concat->base_.Compute = ConcatCompute; + return (KernelBase *)concat; +} + +REG_KERNEL_CREATOR(PrimType_Concat, kNumberTypeBool, CreateConcat) +REG_KERNEL_CREATOR(PrimType_Concat, kNumberTypeInt32, CreateConcat) +REG_KERNEL_CREATOR(PrimType_Concat, kNumberTypeFloat32, CreateConcat) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/concat.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/concat.h new file mode 100644 index 0000000000000000000000000000000000000000..4886251479f86872502e5919a1767a052f0c48a9 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/concat.h @@ -0,0 +1,52 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CONCAT_H_ +#define NNACL_KERNEL_CONCAT_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct ConcatBlockBoundaryInfo { + int begin_input_; // input-index of upper boundary + int end_input_; // input-index of lower boundary. + int64_t begin_point_; // offset of begin-input. + int64_t end_point_; // required size of end-input. +} ConcatBlockBoundaryInfo; + +typedef struct ConcatStruct { + KernelBase base_; + int64_t outer_size_; + uint8_t *output_; + TypeIdC data_type_; + + bool *is_with_data_; /* size = in_tensor_size */ + uint8_t **inputs_ptr_; /* size = in_tensor_size */ + int64_t *inner_sizes_; // byte-inner-size (including axis) of each input and the last one is output's. + + ConcatBlockBoundaryInfo block_boundary_infos_[MAX_THREAD_NUM]; /* dynamic block size */ + int64_t block_splits_[MAX_THREAD_NUM]; /* dynamic block size */ + size_t block_size_; /* dynamic block size = actual thread number */ +} ConcatStruct; + +KernelBase *CreateConcat(OpParameter *param, int data_type); +int DoConcat(ConcatStruct *concat, int task_id); +int ConcatPepare(KernelBase *self); +int ConcatRelease(KernelBase *self); +int ConcatResize(KernelBase *self); + +#endif // NNACL_KERNEL_CONCAT_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_1x1.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_1x1.c new file mode 100644 index 0000000000000000000000000000000000000000..eda9b6d52c7576fad9fd909b4fbb28f235c6ef75 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_1x1.c @@ -0,0 +1,365 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either convolutionress or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/convolution_1x1.h" +#include "nnacl/fp32/pack_fp32.h" +#include "nnacl/base/conv1x1_base.h" +#include "nnacl/fp32/matmul_fp32.h" + +int Conv1x1Run(void *cdata, int task_id, float l, float r) { + Convolution1x1Struct *conv_1x1 = (Convolution1x1Struct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(conv_1x1); + MatMulParameter *matmul = &conv_1x1->matmul_param_; + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(task_id, conv_1x1->thread_stride_, NNACL_ERR); + int total_thead_stride_ = task_id * conv_1x1->thread_stride_; + int res_stride = matmul->col_ - total_thead_stride_; + int cur_oc = MSMIN(conv_1x1->thread_stride_, res_stride); + if (cur_oc <= 0) { + return NNACL_OK; + } + + TensorC *out_tensor = conv_1x1->conv_.base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(out_tensor); + float *bias = conv_1x1->conv_.bias_data_ == NULL + ? NULL + : (float *)conv_1x1->conv_.bias_data_ + conv_1x1->thread_stride_ * task_id; + float *weight = (float *)conv_1x1->conv_.packed_weight_ + total_thead_stride_ * matmul->deep_; + + if (out_tensor->format_ == Format_NC4HW4) { + MatMulOpt(conv_1x1->pack_input_, weight, conv_1x1->output_ptr_ + total_thead_stride_ * matmul->row_, bias, + matmul->act_type_, matmul->deep_, matmul->row_, cur_oc, matmul->row_, OutType_NC4HW4); + } else { + MatMulOpt(conv_1x1->pack_input_, weight, conv_1x1->output_ptr_ + total_thead_stride_, bias, matmul->act_type_, + matmul->deep_, matmul->row_, cur_oc, matmul->col_, OutType_Nhwc); + } + return NNACL_OK; +} + +void Conv1x1PackMatmulInput(const float *src_ptr, float *dst_ptr, int row, int col) { +#ifdef ENABLE_AVX + RowMajor2Col6Major(src_ptr, dst_ptr, row, col); +#elif defined(ENABLE_SSE) + RowMajor2Col4Major(src_ptr, dst_ptr, row, col); +#else + RowMajor2Col12Major(src_ptr, dst_ptr, row, col); +#endif +} + +int Conv1x1RunHw(void *cdata, int task_id, float l, float r) { + Convolution1x1Struct *conv_1x1 = (Convolution1x1Struct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(conv_1x1); + MatMulParameter *matmul = &conv_1x1->matmul_param_; + TensorC *output_tensor = conv_1x1->conv_.base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(task_id, conv_1x1->thread_stride_, NNACL_ERR); + int total_thead_stride_ = task_id * conv_1x1->thread_stride_; + int res_stride = matmul->row_ - total_thead_stride_; + int cur_hw_ = MSMIN(conv_1x1->thread_stride_, res_stride); + if (cur_hw_ <= 0) { + return NNACL_OK; + } + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_thead_stride_, matmul->deep_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(task_id, conv_1x1->row_tile_, NNACL_ERR); + int total_row_tile_ = task_id * conv_1x1->row_tile_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_row_tile_, matmul->deep_, NNACL_ERR); + float *thread_input_ptr = conv_1x1->input_ptr_ + total_thead_stride_ * matmul->deep_; + float *thread_pack_input = conv_1x1->pack_input_ + total_row_tile_ * matmul->deep_; + float *thread_output_ptr = NULL; + if (output_tensor->format_ != Format_NC4HW4) { + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_thead_stride_, matmul->col_, NNACL_ERR); + thread_output_ptr = conv_1x1->output_ptr_ + total_thead_stride_ * matmul->col_; + } else { + int col_min = MSMIN(matmul->col_, C4NUM); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_thead_stride_, col_min, NNACL_ERR); + thread_output_ptr = conv_1x1->output_ptr_ + total_thead_stride_ * col_min; + } + float *cur_intput = thread_input_ptr; + float *cur_output = thread_output_ptr; + float *bias = (float *)conv_1x1->conv_.bias_data_; + for (int i = 0; i < cur_hw_; i += conv_1x1->row_tile_) { + int cur_rows = (cur_hw_ - i >= conv_1x1->row_tile_) ? conv_1x1->row_tile_ : (cur_hw_ - i); + Conv1x1PackMatmulInput(cur_intput, thread_pack_input, cur_rows, matmul->deep_); + if (output_tensor->format_ == Format_NC4HW4) { + MatMulOpt(thread_pack_input, (float *)conv_1x1->conv_.packed_weight_, cur_output, bias, matmul->act_type_, + matmul->deep_, cur_rows, matmul->col_, matmul->row_, OutType_NC4HW4); + cur_output += conv_1x1->row_tile_ * MSMIN(matmul->col_, C4NUM); + } else { + MatMulOpt(thread_pack_input, (float *)conv_1x1->conv_.packed_weight_, cur_output, bias, matmul->act_type_, + matmul->deep_, cur_rows, matmul->col_, matmul->col_, OutType_Nhwc); + cur_output += conv_1x1->row_tile_ * matmul->col_; + } + cur_intput += conv_1x1->row_tile_ * matmul->deep_; + } + + return NNACL_OK; +} + +void Conv1x1PackWeight(ConvolutionBaseStruct *conv) { + TensorC *filter_tensor = conv->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_VOID(filter_tensor); + ConvComputeParam *compute = &conv->compute_; + NNACL_CHECK_NULL_RETURN_VOID(compute); + + if (compute->in_c_ <= 0 || compute->out_c_ <= 0) { + return; + } + + void *origin_weight = conv->base_.train_session_ ? filter_tensor->data_ : conv->origin_weight_; + NNACL_CHECK_NULL_RETURN_VOID(origin_weight); + +#ifdef ENABLE_AVX + RowMajor2Col16Major((float *)origin_weight, (float *)conv->packed_weight_, compute->out_c_, compute->in_c_); +#elif defined(ENABLE_ARM32) + RowMajor2Col4Major((float *)origin_weight, (float *)conv->packed_weight_, compute->out_c_, compute->in_c_); +#else + RowMajor2Col8Major((float *)origin_weight, (float *)conv->packed_weight_, compute->out_c_, compute->in_c_); +#endif +} + +int Conv1x1MallocWeightBiasData(ConvolutionBaseStruct *conv) { + Convolution1x1Struct *conv_1x1 = (Convolution1x1Struct *)conv; + NNACL_CHECK_NULL_RETURN_ERR(conv_1x1); + + int size = conv->compute_.in_c_ * UP_ROUND(conv->compute_.out_c_, conv_1x1->col_tile_) * sizeof(float); + if (!conv->base_.train_session_) { + conv->packed_weight_ = ConvBaseGetConvPackWeightData(conv, size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->packed_weight_); + } + + if (conv->base_.in_size_ == THREE_TENSOR) { + size = UP_ROUND(conv->compute_.out_c_, conv_1x1->col_tile_) * sizeof(float); + conv->bias_data_ = conv->base_.env_->Alloc(conv->base_.env_->allocator_, size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->bias_data_); + memset(conv->bias_data_, 0, size); + } + return NNACL_OK; +} + +void Conv1x1FreeTmpBuffer(Convolution1x1Struct *conv_1x1) { + if (conv_1x1->pre_trans_input_ && conv_1x1->input_ptr_ != NULL) { + conv_1x1->conv_.base_.env_->Free(conv_1x1->conv_.base_.env_->allocator_, conv_1x1->input_ptr_); + conv_1x1->input_ptr_ = NULL; + } + return; +} + +int InitConv1x1MatmulParam(Convolution1x1Struct *conv_1x1) { + ConvParameter *conv_param = (ConvParameter *)conv_1x1->conv_.base_.param_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_param->output_h_, conv_param->output_w_, NNACL_ERR); + conv_1x1->matmul_param_.row_ = conv_param->output_h_ * conv_param->output_w_; + conv_1x1->matmul_param_.col_ = conv_param->output_channel_; + conv_1x1->matmul_param_.deep_ = conv_param->input_channel_; + conv_1x1->matmul_param_.row_align_ = UP_ROUND(conv_1x1->matmul_param_.row_, conv_1x1->row_tile_); + conv_1x1->matmul_param_.col_align_ = UP_ROUND(conv_1x1->matmul_param_.col_, conv_1x1->col_tile_); + conv_1x1->matmul_param_.act_type_ = conv_param->act_type_; + return NNACL_OK; +} + +int InitConv1x1Param(Convolution1x1Struct *conv_1x1) { + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_1x1->row_tile_, conv_1x1->conv_.base_.thread_nr_, NNACL_ERR); + if ((conv_1x1->matmul_param_.row_ > (conv_1x1->row_tile_ * conv_1x1->conv_.base_.thread_nr_)) && + (conv_1x1->matmul_param_.row_ > conv_1x1->matmul_param_.col_)) { + conv_1x1->multi_thread_by_hw_ = true; + conv_1x1->conv_.base_.thread_nr_ = + MSMIN(conv_1x1->conv_.base_.thread_nr_, UP_DIV(conv_1x1->matmul_param_.row_, conv_1x1->row_tile_)); + if (conv_1x1->conv_.base_.thread_nr_ <= 0) { + return NNACL_ERR; + } + conv_1x1->thread_stride_ = + UP_DIV(UP_DIV(conv_1x1->matmul_param_.row_, conv_1x1->row_tile_), conv_1x1->conv_.base_.thread_nr_) * + conv_1x1->row_tile_; + } else { + conv_1x1->multi_thread_by_hw_ = false; + conv_1x1->conv_.base_.thread_nr_ = + MSMIN(conv_1x1->conv_.base_.thread_nr_, UP_DIV(conv_1x1->matmul_param_.col_, conv_1x1->col_tile_)); + if (conv_1x1->conv_.base_.thread_nr_ <= 0) { + return NNACL_ERR; + } + conv_1x1->thread_stride_ = + UP_DIV(UP_DIV(conv_1x1->matmul_param_.col_, conv_1x1->col_tile_), conv_1x1->conv_.base_.thread_nr_) * + conv_1x1->col_tile_; + } + + ConvParameter *conv_param = (ConvParameter *)conv_1x1->conv_.base_.param_; + conv_1x1->pre_trans_input_ = + (conv_param->pad_u_ != 0 || conv_param->pad_l_ != 0 || conv_param->stride_h_ != 1 || conv_param->stride_w_ != 1); + if (conv_1x1->pre_trans_input_) { + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_1x1->matmul_param_.row_, conv_1x1->matmul_param_.deep_, NNACL_ERR); + conv_1x1->input_ptr_ = (float *)(conv_1x1->conv_.base_.env_->Alloc( + conv_1x1->conv_.base_.env_->allocator_, + conv_1x1->matmul_param_.row_ * conv_1x1->matmul_param_.deep_ * sizeof(float))); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_1x1->input_ptr_); + memset(conv_1x1->input_ptr_, 0, conv_1x1->matmul_param_.row_ * conv_1x1->matmul_param_.deep_ * sizeof(float)); + } + + return NNACL_OK; +} + +int Convolution1x1Resize(KernelBase *self) { + Convolution1x1Struct *conv_1x1 = (Convolution1x1Struct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_1x1); + + Conv1x1FreeTmpBuffer(conv_1x1); + int error_code = ConvBasePrepare(&conv_1x1->conv_); + if (error_code != NNACL_OK) { + return error_code; + } + + error_code = InitConv1x1MatmulParam(conv_1x1); + if (error_code != NNACL_OK) { + return error_code; + } + + error_code = InitConv1x1Param(conv_1x1); + if (error_code != NNACL_OK) { + return error_code; + } + + return NNACL_OK; +} + +int Convolution1x1Prepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + + Convolution1x1Struct *conv_1x1 = (Convolution1x1Struct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_1x1); + +#ifdef ENABLE_AVX + conv_1x1->row_tile_ = C6NUM; + conv_1x1->col_tile_ = C16NUM; +#elif defined(ENABLE_SSE) + conv_1x1->row_tile_ = C4NUM; + conv_1x1->col_tile_ = C8NUM; +#elif defined(ENABLE_ARM32) + conv_1x1->row_tile_ = C12NUM; + conv_1x1->col_tile_ = C4NUM; +#else + conv_1x1->row_tile_ = C12NUM; + conv_1x1->col_tile_ = C8NUM; +#endif + + if (self->train_session_) { + int output_tile_size = UP_ROUND(conv_1x1->conv_.compute_.out_c_, conv_1x1->col_tile_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_1x1->conv_.compute_.in_c_, output_tile_size, NNACL_ERR); + size_t size = conv_1x1->conv_.compute_.in_c_ * output_tile_size * sizeof(float); + conv_1x1->conv_.base_.work_size_ = size; + } + + int error_code = ConvBaseInitConvWeightBias(&conv_1x1->conv_); + if (error_code != NNACL_OK) { + return error_code; + } + return NNACL_OK; +} + +int Convolution1x1Release(KernelBase *self) { + Convolution1x1Struct *conv_1x1 = (Convolution1x1Struct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_1x1); + Conv1x1FreeTmpBuffer(conv_1x1); + ConvBaseRelease(&conv_1x1->conv_); + return NNACL_OK; +} + +int Convolution1x1Compute(KernelBase *self) { + Convolution1x1Struct *conv_1x1 = (Convolution1x1Struct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_1x1); + ConvParameter *conv_param = (ConvParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + + TensorC *input_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + TensorC *output_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + + float *src_in = (float *)input_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(src_in); + float *src_out = (float *)output_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(src_out); + + int pack_input_size = 0; + if (conv_1x1->multi_thread_by_hw_) { + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_1x1->conv_.base_.thread_nr_, conv_1x1->row_tile_, NNACL_ERR); + int total_row_tile_ = conv_1x1->conv_.base_.thread_nr_ * conv_1x1->row_tile_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_row_tile_, conv_1x1->matmul_param_.deep_, NNACL_ERR); + pack_input_size = total_row_tile_ * conv_1x1->matmul_param_.deep_; + } else { + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_1x1->matmul_param_.row_align_, conv_1x1->matmul_param_.deep_, NNACL_ERR); + pack_input_size = conv_1x1->matmul_param_.row_align_ * conv_1x1->matmul_param_.deep_; + } + conv_1x1->pack_input_ = + (float *)conv_1x1->conv_.base_.env_->Alloc(conv_1x1->conv_.base_.env_->allocator_, pack_input_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_1x1->pack_input_); + + int ret = ConvBaseRepackWeight(&conv_1x1->conv_); + if (ret != NNACL_OK) { + return ret; + } + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_1x1->matmul_param_.row_, conv_1x1->matmul_param_.col_, NNACL_ERR); + int matmul_size = conv_1x1->matmul_param_.row_ * conv_1x1->matmul_param_.col_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_param->input_batch_ - 1, matmul_size, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_param->input_h_, conv_param->input_w_, NNACL_ERR); + int conv_input_hw = conv_param->input_h_ * conv_param->input_w_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_input_hw, conv_param->input_channel_, NNACL_ERR); + int conv_input_bhw = conv_input_hw * conv_param->input_channel_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_param->input_batch_ - 1, conv_input_bhw, NNACL_ERR); + for (int batch_index = 0; batch_index < conv_param->input_batch_; batch_index++) { + conv_1x1->output_ptr_ = src_out + batch_index * matmul_size; + float *tmp_in = src_in + batch_index * conv_input_bhw; + if (conv_1x1->pre_trans_input_) { + Conv1x1InputPack(tmp_in, conv_1x1->input_ptr_, conv_param, sizeof(float)); + } else { + conv_1x1->input_ptr_ = tmp_in; + } + if (conv_1x1->multi_thread_by_hw_) { + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, Conv1x1RunHw, self, self->thread_nr_); + } else { + Conv1x1PackMatmulInput(conv_1x1->input_ptr_, conv_1x1->pack_input_, conv_1x1->matmul_param_.row_, + conv_1x1->matmul_param_.deep_); + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, Conv1x1Run, self, self->thread_nr_); + } + if (ret != NNACL_OK) { + break; + } + } + + if (conv_1x1->pack_input_ != NULL) { + self->env_->Free(self->env_->allocator_, conv_1x1->pack_input_); + conv_1x1->pack_input_ = NULL; + } + return ret; +} + +ConvolutionBaseStruct *CreateConvolution1x1(ConvParameter *conv_param) { + Convolution1x1Struct *conv1x1 = (Convolution1x1Struct *)malloc(sizeof(Convolution1x1Struct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv1x1); + memset(conv1x1, 0, sizeof(Convolution1x1Struct)); + + conv1x1->conv_.is_sharing_pack_ = false; + conv1x1->conv_.malloc_weight_bias_ = Conv1x1MallocWeightBiasData; + conv1x1->conv_.pack_weight_ = Conv1x1PackWeight; + + conv1x1->conv_.base_.Resize = Convolution1x1Resize; + conv1x1->conv_.base_.Prepare = Convolution1x1Prepare; + conv1x1->conv_.base_.Release = Convolution1x1Release; + conv1x1->conv_.base_.Compute = Convolution1x1Compute; + + return (ConvolutionBaseStruct *)conv1x1; +} diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_plugin_impl.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_1x1.h similarity index 47% rename from mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_plugin_impl.h rename to mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_1x1.h index ece05c05c59b2afd5cda101daa2ab244e656fd47..c7e7d31b9b3e702b3f4b0454fc50bea3bcdc6901 100644 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_plugin_impl.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_1x1.h @@ -13,21 +13,30 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_ASCEND_NATIVE_PLUGIN_IMPL_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_ASCEND_NATIVE_PLUGIN_IMPL_H_ -#include -#include "include/api/status.h" -#include "src/common/log_adapter.h" -#include "extendrt/delegate/plugin/ascend_native_executor_plugin.h" +#ifndef NNACL_KERNEL_CONVOLLUTION_1X1_H_ +#define NNACL_KERNEL_CONVOLLUTION_1X1_H_ -namespace mindspore { -class AscendNativeExecutorPluginImpl : public lite::AscendNativeExecutorPluginImplBase { - public: - AscendNativeExecutorPluginImpl() = default; - virtual ~AscendNativeExecutorPluginImpl() = default; -}; +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" +#include "nnacl/kernel/convolution_base.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/matmul_parameter.h" -extern "C" MS_API AscendNativeExecutorPluginImpl *CreateAscendNativeExecutorPluginImpl(); -} // namespace mindspore -#endif // MINDSPORE_LITE_SRC_EXTENDRT_ASCEND_GE_GE_PLUGIN_IMPL_H_ +typedef struct Convolution1x1Struct { + ConvolutionBaseStruct conv_; + MatMulParameter matmul_param_; + int row_tile_; + int col_tile_; + bool pre_trans_input_; + float *input_ptr_; + float *output_ptr_; + float *pack_input_; + bool multi_thread_by_hw_; + int thread_stride_; +} Convolution1x1Struct; + +ConvolutionBaseStruct *CreateConvolution1x1(ConvParameter *conv_param); + +#endif // NNACL_KERNEL_CONVOLLUTION_1X1_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_base.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_base.c new file mode 100644 index 0000000000000000000000000000000000000000..53fff731daa8d4ee4f3e98e59a1dea6852ace8b1 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_base.c @@ -0,0 +1,209 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either convolutionress or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/convolution_base.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/tensor_c_utils.h" + +int ConvBaseUpdateParamInfo(ConvComputeParam *compute, ConvParameter *conv_param) { + compute->stride_h_ = conv_param->stride_h_; + compute->stride_w_ = conv_param->stride_w_; + compute->dilation_h_ = conv_param->dilation_h_; + compute->dilation_w_ = conv_param->dilation_w_; + compute->pad_u_ = conv_param->pad_u_; + compute->pad_d_ = conv_param->pad_d_; + compute->pad_l_ = conv_param->pad_l_; + compute->pad_r_ = conv_param->pad_r_; + + compute->in_c_ = conv_param->input_channel_; + compute->out_c_ = conv_param->output_channel_; + + compute->kernel_h_ = conv_param->kernel_h_; + compute->kernel_w_ = conv_param->kernel_w_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->kernel_h_, compute->kernel_w_, NNACL_ERR); + compute->kernel_hw_ = compute->kernel_h_ * compute->kernel_w_; + + return NNACL_OK; +} + +int ConvBaseUpdateComputeInfo(ConvolutionBaseStruct *conv) { + NNACL_CHECK_NULL_RETURN_ERR(conv); + ConvParameter *conv_param = (ConvParameter *)conv->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + TensorC *input = conv->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + TensorC *output = conv->base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output); + + conv_param->input_batch_ = NNACLGetBatch(input); + conv_param->input_h_ = NNACLGetHeight(input); + conv_param->input_w_ = NNACLGetWidth(input); + conv_param->input_channel_ = NNACLGetChannel(input); + conv_param->output_batch_ = NNACLGetBatch(output); + conv_param->output_h_ = NNACLGetHeight(output); + conv_param->output_w_ = NNACLGetWidth(output); + conv_param->output_channel_ = NNACLGetChannel(output); + + ConvComputeParam *compute = &conv->compute_; + compute->in_n_ = NNACLGetBatch(input); + compute->in_h_ = NNACLGetHeight(input); + compute->in_w_ = NNACLGetWidth(input); + compute->in_c_ = NNACLGetChannel(input); + NNACL_CHECK_FALSE(compute->in_c_ != conv_param->input_channel_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->in_h_, compute->in_w_, NNACL_ERR); + compute->in_hw_ = compute->in_h_ * compute->in_w_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->in_hw_, compute->in_n_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->in_hw_ * compute->in_n_, compute->in_c_, NNACL_ERR); + + compute->out_n_ = NNACLGetBatch(output); + compute->out_h_ = NNACLGetHeight(output); + compute->out_w_ = NNACLGetWidth(output); + compute->out_c_ = NNACLGetChannel(output); + NNACL_CHECK_FALSE(compute->out_c_ != conv_param->output_channel_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->out_h_, compute->out_w_, NNACL_ERR); + compute->out_hw_ = compute->out_h_ * compute->out_w_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->out_hw_, compute->out_n_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->out_hw_ * compute->out_n_, compute->out_c_, NNACL_ERR); + + return ConvBaseUpdateParamInfo(compute, conv_param); +} + +void ConvBaseRelease(ConvolutionBaseStruct *conv) { + if (!conv->base_.train_session_) { + if (!conv->is_sharing_pack_) { + conv->base_.env_->Free(conv->base_.env_->allocator_, conv->packed_weight_); + } else { + conv->free_sharing_weight_(conv->shaing_manager_, conv->packed_weight_); + } + conv->packed_weight_ = NULL; + } + + if (conv->bias_data_ != NULL) { + conv->base_.env_->Free(conv->base_.env_->allocator_, conv->bias_data_); + conv->bias_data_ = NULL; + } +} + +int ConvBasePrepare(ConvolutionBaseStruct *conv) { + NNACL_CHECK_FALSE(conv->base_.in_size_ < TWO_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(conv->base_.out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + + conv->out_format_ = conv->base_.out_[OUTPUT_INDEX]->format_; + return ConvBaseUpdateComputeInfo(conv); +} + +void ConvBaseUpdateOriginWeightAndBias(ConvolutionBaseStruct *conv) { + NNACL_CHECK_NULL_RETURN_VOID(conv); + + if (conv->base_.in_[SECOND_INPUT]->data_ != NULL) { + conv->origin_weight_ = conv->base_.in_[SECOND_INPUT]->data_; + } + + if (conv->base_.in_size_ == THREE_TENSOR && conv->base_.in_[THIRD_INPUT]->data_ != NULL) { + conv->origin_bias_ = conv->base_.in_[THIRD_INPUT]->data_; + } +} + +int ConvBaseInitConvWeightBias(ConvolutionBaseStruct *conv) { + if (conv->base_.train_session_) { + ConvBaseUpdateOriginWeightAndBias(conv); + } + + /* check weight shape done */ + if (!CheckInferShapeDone(&conv->base_.in_[SECOND_INPUT], ONE_TENSOR, NULL, 0)) { + return NNACL_OK; + } + + int ret = conv->malloc_weight_bias_(conv); + if (ret != NNACL_OK) { + return ret; + } + + if ((conv->base_.in_size_ == THREE_TENSOR) && (conv->origin_bias_ != NULL)) { + memcpy(conv->bias_data_, conv->origin_bias_, NNACLGetSize(conv->base_.in_[THIRD_INPUT])); + } + + if (!conv->base_.train_session_) { + if (conv->weight_is_packed_) { + return NNACL_OK; + } + if (conv->origin_weight_ != NULL) { + conv->pack_weight_(conv); + } else { + conv->is_repack_ = true; + } + } + return NNACL_OK; +} + +int ConvBaseCheckResizeValid(ConvolutionBaseStruct *conv) { + // ===============check in channel================= // + TensorC *input_tensor = conv->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + int resize_in_channel = NNACLGetChannel(input_tensor); + TensorC *filter_tensor = conv->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(filter_tensor); + int filter_in_channel = NNACLGetChannel(filter_tensor); + if (filter_in_channel != resize_in_channel) { + return NNACL_CONVOLUTION_INPUT_CHANNEL_UNMATCH; + } + return NNACL_OK; +} + +void *ConvBaseGetConvPackWeightData(ConvolutionBaseStruct *conv, int data_size) { + TensorC *weight_tensor = conv->base_.in_[SECOND_INPUT]; + bool const_fit = weight_tensor->category_ != ConstTensor && weight_tensor->category_ != ConstScalar; + bool group_fit = ((ConvParameter *)conv->base_.param_)->group_ > 1; + bool sharing_fit = conv->get_sharing_weight_ == NULL; + + void *data = NULL; + if (sharing_fit || const_fit || group_fit) { + if (data_size <= 0) { + return NULL; + } + data = conv->base_.env_->Alloc(conv->base_.env_->allocator_, data_size); + conv->weight_is_packed_ = false; + conv->is_sharing_pack_ = false; + } else { + data = conv->get_sharing_weight_(conv->shaing_manager_, weight_tensor->data_, data_size, &conv->weight_is_packed_); + } + return data; +} + +int ConvBaseRepackWeight(ConvolutionBaseStruct *conv) { + NNACL_CHECK_NULL_RETURN_ERR(conv); + + conv->origin_weight_ = conv->origin_weight_ != NULL ? conv->origin_weight_ : conv->base_.in_[SECOND_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(conv->origin_weight_); + + if (conv->packed_weight_ == NULL) { + int ret = ConvBaseInitConvWeightBias(conv); + if (ret != NNACL_OK) { + return ret; + } + } + + if (conv->is_repack_ || conv->base_.train_session_) { + if (conv->base_.train_session_) { + conv->packed_weight_ = (float *)conv->base_.workspace_; + memset(conv->packed_weight_, 0, conv->base_.work_size_); + } else { + conv->is_repack_ = false; + } + conv->pack_weight_(conv); + } + return NNACL_OK; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_base.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_base.h new file mode 100644 index 0000000000000000000000000000000000000000..704b66091c44c2de4f2636bfadaa620f70b90fe1 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_base.h @@ -0,0 +1,63 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CONVOLLUTION_BASE_H_ +#define NNACL_KERNEL_CONVOLLUTION_BASE_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/tensor_c_utils.h" + +#define ConvMinBlock 1 + +typedef struct ConvolutionBaseStruct { + KernelBase base_; + ConvComputeParam compute_; + bool weight_is_packed_; + bool is_repack_; + bool infershape_done_; + bool use_batch_cut_flag_; + FormatC out_format_; + + void *packed_weight_; + void *bias_data_; + void *origin_weight_; // do not Free + void *origin_bias_; // do not Free + + void (*init_global_variable_)(struct ConvolutionBaseStruct *conv_im2col); + int (*malloc_weight_bias_)(struct ConvolutionBaseStruct *conv_base); + void (*pack_weight_)(struct ConvolutionBaseStruct *conv_base); + int (*run_impl_)(struct ConvolutionBaseStruct *conv, int task_id); + + bool is_sharing_pack_; + void *shaing_manager_; + void (*free_sharing_weight_)(void *manager, void *tensor_data); + void *(*get_sharing_weight_)(void *manager, const void *tensor_data, const size_t size, bool *is_packed); +} ConvolutionBaseStruct; + +int ConvBaseUpdateParamInfo(ConvComputeParam *compute, ConvParameter *conv_param); +int ConvBaseUpdateComputeInfo(ConvolutionBaseStruct *conv); +void ConvBaseRelease(ConvolutionBaseStruct *conv); +int ConvBaseCheckResizeValid(ConvolutionBaseStruct *conv); +int ConvBasePrepare(ConvolutionBaseStruct *conv); +int ConvBaseInitConvWeightBias(ConvolutionBaseStruct *conv); +int ConvBaseRepackWeight(ConvolutionBaseStruct *conv); +void ConvBaseUpdateOriginWeightAndBias(ConvolutionBaseStruct *conv); +void *ConvBaseGetConvPackWeightData(ConvolutionBaseStruct *conv, int data_size); + +#endif // NNACL_KERNEL_CONVOLLUTION_BASE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_delegate.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_delegate.c new file mode 100644 index 0000000000000000000000000000000000000000..d2fd88995714e3bd8b9865ee4c5f66682bb1c885 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_delegate.c @@ -0,0 +1,357 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either convolutionress or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/convolution_delegate.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/tensor_c_utils.h" +#include "nnacl/base/conv_common_base.h" +#include "nnacl/kernel/group_convolution.h" +#include "nnacl/kernel/convolution_depthwise.h" +#include "nnacl/kernel/convolution_1x1.h" +#include "nnacl/kernel/convolution_im2col.h" +#include "nnacl/kernel/convolution_winograd.h" +#include "nnacl/fp32/conv_winograd_fp32.h" +#include "nnacl/kernel/convolution_depthwise_sw.h" +#ifdef ENABLE_AVX +#include "nnacl/kernel/convolution_sw_1x1.h" +#include "nnacl/kernel/convolution_sw_avx.h" +#include "nnacl/kernel/convolution_depthwise_sw_avx.h" +#endif +#ifdef ENABLE_ARM64 +#include "nnacl/kernel/convolution_depthwise_indirect.h" +#include "nnacl/kernel/convolution_sw_arm64.h" +#include "nnacl/fp32/conv_sw_arm64_fp32.h" +#endif +#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX)) +#include "nnacl/kernel/convolution_depthwise_3x3.h" +#include "nnacl/fp32/conv_depthwise_fp32.h" +#endif + +#define MaxDwConvSWSize 32 + +float *ConvolutionDelegateCopyData(const TensorC *tensor) { + NNACL_CHECK_NULL_RETURN_NULL(tensor); + NNACL_CHECK_NULL_RETURN_NULL(tensor->data_); + + float *data = (float *)malloc(NNACLGetSize(tensor)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(data); + + (void)memcpy(data, tensor->data_, NNACLGetSize(tensor)); + return data; +} + +int ConvolutionDelegateGetWeightData(ConvolutionDelegateStruct *convolution_delegate) { + if (convolution_delegate->conv_.base_.in_[SECOND_INPUT]->data_ == NULL) { + return NNACL_OK; + } + if (convolution_delegate->conv_.infershape_done_) { + convolution_delegate->origin_weight_ = convolution_delegate->conv_.base_.in_[SECOND_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(convolution_delegate->origin_weight_); + convolution_delegate->need_free_weight_ = false; + return NNACL_OK; + } + convolution_delegate->origin_weight_ = + ConvolutionDelegateCopyData(convolution_delegate->conv_.base_.in_[SECOND_INPUT]); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(convolution_delegate->origin_weight_); + convolution_delegate->need_free_weight_ = true; + return NNACL_OK; +} + +int ConvolutionDelegateGetBiasData(ConvolutionDelegateStruct *convolution_delegate) { + if (convolution_delegate->conv_.base_.in_size_ != THREE_TENSOR) { + convolution_delegate->origin_bias_ = NULL; + convolution_delegate->need_free_bias_ = false; + return NNACL_OK; + } + + if (convolution_delegate->conv_.infershape_done_) { + NNACL_CHECK_NULL_RETURN_ERR(convolution_delegate->conv_.base_.in_[THIRD_INPUT]); + convolution_delegate->origin_bias_ = convolution_delegate->conv_.base_.in_[THIRD_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(convolution_delegate->origin_bias_); + convolution_delegate->need_free_bias_ = false; + return NNACL_OK; + } + + convolution_delegate->origin_bias_ = ConvolutionDelegateCopyData(convolution_delegate->conv_.base_.in_[THIRD_INPUT]); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(convolution_delegate->origin_bias_); + convolution_delegate->need_free_bias_ = true; + return NNACL_OK; +} + +int ConvolutionDelegateGetWeightAndBias(ConvolutionDelegateStruct *convolution_delegate) { + int ret = ConvolutionDelegateGetWeightData(convolution_delegate); + if (ret != NNACL_OK) { + return ret; + } + + return ConvolutionDelegateGetBiasData(convolution_delegate); +} + +ConvolutionBaseStruct *ConvolutionDelegateConvNC4KernelSelect(ConvolutionDelegateStruct *convolution_delegate) { + /* runtime nc4hw4 pass + * arm64: conv1x1 conv_Im2col support nc4 + * Avx: conv_Im2col support nc4 + * */ + ConvParameter *conv_param = (ConvParameter *)convolution_delegate->conv_.base_.param_; + NNACL_CHECK_NULL_RETURN_NULL(conv_param); + +#ifdef ENABLE_ARM64 + if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) { + ConvolutionBaseStruct *conv1x1 = CreateConvolution1x1(conv_param); + return conv1x1; + } +#endif + +#if defined(ENABLE_ARM64) || defined(ENABLE_AVX) + ConvolutionBaseStruct *conv_im2col = CreateConvolutionIm2Col(&convolution_delegate->conv_.base_, conv_param); + return conv_im2col; +#endif + + return NULL; +} + +ConvolutionBaseStruct *ConvolutionDelegateConvNHWCKernelSelect(ConvolutionDelegateStruct *convolution_delegate) { + ConvParameter *conv_param = (ConvParameter *)convolution_delegate->conv_.base_.param_; + NNACL_CHECK_NULL_RETURN_NULL(conv_param); + + ConvolutionBaseStruct *conv = NULL; + + int out_unit; + if (CheckIfUseWinograd(&out_unit, conv_param)) { + conv = CreateConvolutionWinograd(conv_param, out_unit); + } + +#ifdef ENABLE_AVX + if (conv == NULL && CheckAvxUseSW1x1Conv(conv_param)) { + conv = CreateConvolutionSW1x1(conv_param, convolution_delegate->input_const_, convolution_delegate->weight_const_); + } + + if (conv == NULL && CheckAvxUseSWConv(conv_param, convolution_delegate->conv_.base_.thread_nr_)) { + conv = CreateConvolutionSWAVX(conv_param); + } +#endif + +#ifdef ENABLE_ARM64 + if (conv == NULL && CheckArm64UseSWConv(conv_param)) { + conv = CreateConvolutionSWARM64(conv_param); + } +#endif + + if (conv == NULL) { + if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) { + conv = CreateConvolution1x1(conv_param); + } else { + conv = CreateConvolutionIm2Col(&convolution_delegate->conv_.base_, conv_param); + } + } + return conv; +} + +ConvolutionBaseStruct *ConvolutionDelegateConvolutionSelect(ConvolutionDelegateStruct *convolution_delegate) { + ConvolutionBaseStruct *conv; + if (convolution_delegate->conv_.base_.out_[OUTPUT_INDEX]->format_ == Format_NC4HW4) { + conv = ConvolutionDelegateConvNC4KernelSelect(convolution_delegate); + } else { + conv = ConvolutionDelegateConvNHWCKernelSelect(convolution_delegate); + } + if (conv == NULL) { + return NULL; + } + + conv->base_.InferShape = convolution_delegate->conv_.base_.InferShape; + conv->base_.UpdateThread = convolution_delegate->conv_.base_.UpdateThread; + conv->base_.env_ = convolution_delegate->conv_.base_.env_; + conv->base_.param_ = convolution_delegate->conv_.base_.param_; + conv->base_.thread_nr_ = convolution_delegate->conv_.base_.thread_nr_; + conv->base_.train_session_ = convolution_delegate->conv_.base_.train_session_; + conv->base_.in_ = convolution_delegate->conv_.base_.in_; + conv->base_.in_size_ = convolution_delegate->conv_.base_.in_size_; + conv->base_.out_ = convolution_delegate->conv_.base_.out_; + conv->base_.out_size_ = convolution_delegate->conv_.base_.out_size_; + + conv->infershape_done_ = convolution_delegate->conv_.infershape_done_; + conv->shaing_manager_ = convolution_delegate->conv_.shaing_manager_; + conv->get_sharing_weight_ = convolution_delegate->conv_.get_sharing_weight_; + conv->free_sharing_weight_ = convolution_delegate->conv_.free_sharing_weight_; + conv->is_sharing_pack_ = convolution_delegate->conv_.is_sharing_pack_; + + conv->origin_weight_ = convolution_delegate->origin_weight_; + conv->origin_bias_ = convolution_delegate->origin_bias_; + return conv; +} + +void ConvolutionDelegateFreeCopiedData(ConvolutionDelegateStruct *convolution_delegate) { + if (convolution_delegate->origin_weight_ != NULL && convolution_delegate->need_free_weight_) { + free(convolution_delegate->origin_weight_); + } + convolution_delegate->origin_weight_ = NULL; + convolution_delegate->conv_.origin_weight_ = NULL; + convolution_delegate->need_free_weight_ = false; + + if (convolution_delegate->origin_bias_ != NULL && convolution_delegate->need_free_bias_) { + free(convolution_delegate->origin_bias_); + } + convolution_delegate->origin_bias_ = NULL; + convolution_delegate->conv_.origin_bias_ = NULL; + convolution_delegate->need_free_bias_ = false; +} + +int ConvolutionDelegateResize(struct KernelBase *self) { + ConvolutionDelegateStruct *convolution_delegate = (ConvolutionDelegateStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(convolution_delegate); + + if (convolution_delegate->convolution_ == NULL) { + convolution_delegate->convolution_ = ConvolutionDelegateConvolutionSelect(convolution_delegate); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(convolution_delegate->convolution_); + (void)ConvBaseUpdateComputeInfo(convolution_delegate->convolution_); + int ret = convolution_delegate->convolution_->base_.Prepare(&convolution_delegate->convolution_->base_); + if (ret != NNACL_OK) { + return ret; + } + } + + (void)ConvBaseUpdateComputeInfo(convolution_delegate->convolution_); + int ret = convolution_delegate->convolution_->base_.Resize(&convolution_delegate->convolution_->base_); + if (ret != NNACL_OK) { + return ret; + } + + ConvolutionDelegateFreeCopiedData(convolution_delegate); + return NNACL_OK; +} + +int ConvolutionDelegatePrepare(struct KernelBase *self) { + ConvolutionDelegateStruct *convolution_delegate = (ConvolutionDelegateStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(convolution_delegate); + + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + NNACL_CHECK_NULL_RETURN_ERR(self->in_[SECOND_INPUT]); + + NNACL_CHECK_FALSE(self->in_[SECOND_INPUT]->data_type_ != kNumberTypeFloat32 && + self->in_[SECOND_INPUT]->data_type_ != kNumberTypeFloat16, + NNACL_CONVOLUTION_WEIGHT_DATATYPE_INVALID); + NNACL_CHECK_FALSE(self->in_size_ == THREE_TENSOR && self->in_[THIRD_INPUT] != NULL && + self->in_[THIRD_INPUT]->data_type_ != kNumberTypeFloat32, + NNACL_CONVOLUTION_BIAS_DATATYPE_INVALID); + + convolution_delegate->input_const_ = NNACLIsConst(self->in_[FIRST_INPUT]) && !self->train_session_; + convolution_delegate->weight_const_ = NNACLIsConst(self->in_[SECOND_INPUT]) && !self->train_session_; + + return ConvolutionDelegateGetWeightAndBias(convolution_delegate); +} + +int ConvolutionDelegateRelease(struct KernelBase *self) { + ConvolutionDelegateStruct *convolution_delegate = (ConvolutionDelegateStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(convolution_delegate); + if (convolution_delegate->convolution_ != NULL) { + (void)convolution_delegate->convolution_->base_.Release(&convolution_delegate->convolution_->base_); + free(convolution_delegate->convolution_); + convolution_delegate->convolution_ = NULL; + } + return NNACL_OK; +} + +int ConvolutionDelegateCompute(struct KernelBase *self) { + ConvolutionDelegateStruct *convolution_delegate = (ConvolutionDelegateStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(convolution_delegate); + NNACL_CHECK_NULL_RETURN_ERR(convolution_delegate->convolution_); + + convolution_delegate->convolution_->base_.workspace_ = convolution_delegate->conv_.base_.workspace_; + return convolution_delegate->convolution_->base_.Compute(&convolution_delegate->convolution_->base_); +} + +KernelBase *CreateConvlutionDelegate(ConvParameter *conv_param) { + ConvolutionDelegateStruct *delegate = (ConvolutionDelegateStruct *)malloc(sizeof(ConvolutionDelegateStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(delegate); + memset(delegate, 0, sizeof(ConvolutionDelegateStruct)); + delegate->conv_.base_.Prepare = ConvolutionDelegatePrepare; + delegate->conv_.base_.Resize = ConvolutionDelegateResize; + delegate->conv_.base_.Release = ConvolutionDelegateRelease; + delegate->conv_.base_.Compute = ConvolutionDelegateCompute; + return (KernelBase *)delegate; +} + +KernelBase *CreateConvolutionDepthwise(ConvParameter *conv_param) { + NNACL_CHECK_NULL_RETURN_NULL(conv_param); + KernelBase *kernel = NULL; + + if (conv_param->dynamic_shape_) { + kernel = CreateConvDw(conv_param); + if (kernel != NULL) { + return kernel; + } + } + +#ifdef ENABLE_AVX + kernel = CreateConvDwSWAVX(conv_param); + if (kernel != NULL) { + return kernel; + } +#endif + +#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX)) + if (CheckConvDw1DWinograd(conv_param, conv_param->thread_num_)) { + kernel = CreateConvDw3x3(conv_param); + if (kernel != NULL) { + return kernel; + } + } +#endif + +#ifdef ENABLE_ARM64 + if (CheckConvDwUseIndirectBuffer(conv_param)) { + kernel = CreateConvDwIndirect(conv_param); + if (kernel != NULL) { + return kernel; + } + } +#endif + + if (conv_param->input_channel_ < MaxDwConvSWSize) { + kernel = CreateConvDwSW(conv_param); + if (kernel != NULL) { + return kernel; + } + } + + kernel = CreateConvDw(conv_param); + return kernel; +} + +KernelBase *CreateConv2DFusion(OpParameter *param, int data_type) { + ConvParameter *conv_param = (ConvParameter *)param; + conv_param->thread_num_ = param->thread_num_; + KernelBase *kernel; + if (conv_param->group_ == 1) { + kernel = CreateConvlutionDelegate(conv_param); + } else if (conv_param->group_ == conv_param->input_channel_ && conv_param->group_ == conv_param->output_channel_) { + kernel = CreateConvolutionDepthwise(conv_param); + } else { + kernel = CreateGroupConvolution(conv_param, data_type); + } + + if (kernel == NULL) { + return NULL; + } + + ConvolutionBaseStruct *conv = (ConvolutionBaseStruct *)kernel; + (void)ConvBaseUpdateParamInfo(&conv->compute_, conv_param); + + return kernel; +} + +REG_KERNEL_CREATOR(PrimType_Conv2DFusion, kNumberTypeFloat32, CreateConv2DFusion) diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/ops/ascend_native_stub.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_delegate.h similarity index 43% rename from mindspore-lite/src/extendrt/delegate/ascend_native/ops/ascend_native_stub.h rename to mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_delegate.h index f676b160d070c939333bfd550a3e2863aabe7435..9f84f99635754f91b45346d81f69b23a85da9796 100644 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/ops/ascend_native_stub.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_delegate.h @@ -13,27 +13,27 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#ifndef NNACL_KERNEL_CONVOLUTION_DELEGATE_H_ +#define NNACL_KERNEL_CONVOLUTION_DELEGATE_H_ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_OPS_ASCEND_NATIVE_STUB_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_OPS_ASCEND_NATIVE_STUB_H_ -#include -#include -#include -#include +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" +#include "nnacl/kernel/convolution_base.h" +#include "nnacl/conv_parameter.h" -#include "ops/base_operator.h" +typedef struct ConvolutionDelegateStruct { + ConvolutionBaseStruct conv_; /* used for Conv2dFusion basic operator */ + ConvolutionBaseStruct *convolution_; /* real running conv */ + float *origin_weight_; + float *origin_bias_; + bool need_free_weight_; + bool need_free_bias_; + bool input_const_; + bool weight_const_; +} ConvolutionDelegateStruct; -namespace mindspore { -namespace ops { -constexpr auto kNameAscendNativeStub = "AscendNativeStub"; -/// \brief Custom defined user-defined operator prototype. -class MIND_API AscendNativeStub : public BaseOperator { - public: - MIND_API_BASE_MEMBER(AscendNativeStub); - /// \brief Constructor. - AscendNativeStub() : BaseOperator(kNameAscendNativeStub) {} -}; -} // namespace ops -} // namespace mindspore +KernelBase *CreateConvlutionDelegate(ConvParameter *conv_param); +KernelBase *CreateConv2DFusion(OpParameter *param, int data_type); -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_OPS_ASCEND_NATIVE_STUB_H_ +#endif // NNACL_KERNEL_CONVOLUTION_DELEGATE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_depthwise.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_depthwise.c new file mode 100644 index 0000000000000000000000000000000000000000..3ee70f2db759413328a9ae6adf1a5d7d48e13aed --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_depthwise.c @@ -0,0 +1,229 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either convolutionress or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/convolution_depthwise.h" +#include "nnacl/tensor_c_utils.h" +#include "nnacl/base/conv_common_base.h" +#include "nnacl/fp32/conv_depthwise_fp32.h" +#include "nnacl/fp32/pack_fp32.h" +#ifdef ENABLE_AVX512 +#include "nnacl/intrinsics/ms_simd_cpu_info.h" +#endif +#include "nnacl/fp32/conv_depthwise_avx_fp32.h" + +int ConvDwRun(void *cdata, int task_id, float l, float r) { + ConvolutionDepthwiseStruct *conv_dw = (ConvolutionDepthwiseStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + + ConvParameter *conv_param = (ConvParameter *)conv_dw->conv_.base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + +#ifdef ENABLE_AVX512 + if (X86_Avx512_Support()) { + return ConvDwAVX512(conv_dw->output_ptr_, conv_dw->input_ptr_, (float *)conv_dw->conv_.packed_weight_, + (float *)conv_dw->conv_.bias_data_, conv_param, task_id, &conv_dw->dw_param_); + } else { + return ConvDwAVX(conv_dw->output_ptr_, conv_dw->input_ptr_, (float *)conv_dw->conv_.packed_weight_, + (float *)conv_dw->conv_.bias_data_, conv_param, task_id, &conv_dw->dw_param_); + } +#endif + +#ifdef ENABLE_AVX + return ConvDwAVX(conv_dw->output_ptr_, conv_dw->input_ptr_, (float *)conv_dw->conv_.packed_weight_, + (float *)conv_dw->conv_.bias_data_, conv_param, task_id, &conv_dw->dw_param_); +#endif + + return ConvDw(conv_dw->output_ptr_, conv_dw->input_ptr_, (float *)conv_dw->conv_.packed_weight_, + (float *)conv_dw->conv_.bias_data_, conv_param, task_id); +} + +void ConvDwReleaseParam(ConvolutionDepthwiseStruct *conv_dw) { + ExecEnv *env = conv_dw->conv_.base_.env_; + NNACL_CHECK_NULL_RETURN_VOID(env); + + if (conv_dw->dw_param_.num_pixels_ != NULL) { + env->Free(env->allocator_, conv_dw->dw_param_.num_pixels_); + conv_dw->dw_param_.num_pixels_ = NULL; + } + if (conv_dw->dw_param_.out_w_start_ != NULL) { + env->Free(env->allocator_, conv_dw->dw_param_.out_w_start_); + conv_dw->dw_param_.out_w_start_ = NULL; + } + if (conv_dw->dw_param_.out_w_end_ != NULL) { + env->Free(env->allocator_, conv_dw->dw_param_.out_w_end_); + conv_dw->dw_param_.out_w_end_ = NULL; + } +} + +void ConvDwPackWeight(ConvolutionBaseStruct *conv) { + void *origin_data = conv->base_.in_[SECOND_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_VOID(origin_data); + PackWeightKHWToHWKFp32(origin_data, conv->packed_weight_, conv->compute_.kernel_hw_, conv->compute_.out_c_); +} + +int ConvDwMallocWeightBiasData(ConvolutionBaseStruct *conv) { + TensorC *weight_tensor = conv->base_.in_[SECOND_INPUT]; + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(weight_tensor); + + int pack_weight_size = conv->compute_.kernel_hw_ * conv->compute_.out_c_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(pack_weight_size, sizeof(float), NNACL_ERR); + + if (!conv->base_.train_session_) { + NNACL_CHECK_MALLOC_SIZE(pack_weight_size * sizeof(float)); + conv->packed_weight_ = ConvBaseGetConvPackWeightData(conv, pack_weight_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->packed_weight_); + } + + NNACL_CHECK_MALLOC_SIZE(conv->compute_.out_c_ * sizeof(float)); + if (conv->bias_data_ == NULL) { + conv->bias_data_ = conv->base_.env_->Alloc(conv->base_.env_->allocator_, conv->compute_.out_c_ * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->bias_data_); + } + memset(conv->bias_data_, 0, conv->compute_.out_c_ * sizeof(float)); + return NNACL_OK; +} + +int ConvDwInitConvDwCalcInfo(ConvolutionDepthwiseStruct *conv_dw) { + ExecEnv *env = conv_dw->conv_.base_.env_; + NNACL_CHECK_NULL_RETURN_ERR(env); + ConvComputeParam *compute = &conv_dw->conv_.compute_; + NNACL_CHECK_NULL_RETURN_ERR(compute); + + ConvDwReleaseParam(conv_dw); + + conv_dw->dw_param_.num_pixels_ = env->Alloc(env->allocator_, compute->kernel_w_ * sizeof(int)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_dw->dw_param_.num_pixels_); + + conv_dw->dw_param_.out_w_start_ = env->Alloc(env->allocator_, compute->kernel_w_ * sizeof(int)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_dw->dw_param_.out_w_start_); + + conv_dw->dw_param_.out_w_end_ = env->Alloc(env->allocator_, compute->kernel_w_ * sizeof(int)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_dw->dw_param_.out_w_end_); + + int *num_pixels = (int *)(conv_dw->dw_param_.num_pixels_); + int *out_w_start = (int *)(conv_dw->dw_param_.out_w_start_); + int *out_w_end = (int *)(conv_dw->dw_param_.out_w_end_); + conv_dw->dw_param_.first_calc_kw_ = -1; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->dilation_w_, (compute->kernel_w_ - 1), NNACL_ERR); + for (int kw = 0; kw < compute->kernel_w_; kw++) { + out_w_start[kw] = + NNACL_MAX(0, (compute->pad_l_ - compute->dilation_w_ * kw + compute->stride_w_ - 1) / compute->stride_w_); + + out_w_end[kw] = NNACL_MIN( + (compute->in_w_ + compute->pad_l_ - compute->dilation_w_ * kw + compute->stride_w_ - 1) / compute->stride_w_, + compute->out_w_); + + num_pixels[kw] = out_w_end[kw] - out_w_start[kw]; + if (conv_dw->dw_param_.first_calc_kw_ == -1 && out_w_start[kw] == 0 && num_pixels[kw] == compute->out_w_) { + conv_dw->dw_param_.first_calc_kw_ = kw; + } + } + return NNACL_OK; +} + +int ConvolutionDepthwisePrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + + ConvolutionDepthwiseStruct *conv_dw = (ConvolutionDepthwiseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + + ConvBaseUpdateOriginWeightAndBias(&conv_dw->conv_); + + if (self->train_session_) { + TensorC *weight_tensor = self->in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(weight_tensor); + NNACL_CHECK_TRUE_RET(weight_tensor->shape_size_ == DIMENSION_4D, NNACL_CONVOLUTION_WEIGHT_SHAPE_INVALID); + + int weight_size_hw = NNACLGetHeight(weight_tensor) * NNACLGetWidth(weight_tensor); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(NNACLGetBatch(weight_tensor), weight_size_hw, NNACL_ERR); + int pack_weight_size = NNACLGetBatch(weight_tensor) * weight_size_hw; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(pack_weight_size, sizeof(float), NNACL_ERR); + self->work_size_ = pack_weight_size * sizeof(float); + } + + return ConvBaseInitConvWeightBias(&conv_dw->conv_); +} + +int ConvolutionDepthwiseCompute(KernelBase *self) { + ConvolutionDepthwiseStruct *conv_dw = (ConvolutionDepthwiseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + + int ret = ConvBaseRepackWeight(&conv_dw->conv_); + if (ret != NNACL_OK) { + return ret; + } + + TensorC *input_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + conv_dw->input_ptr_ = (float *)input_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw->input_ptr_); + + TensorC *output_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + conv_dw->output_ptr_ = (float *)output_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw->output_ptr_); + + NNACL_CHECK_NULL_RETURN_ERR(conv_dw->dw_param_.num_pixels_); + NNACL_CHECK_NULL_RETURN_ERR(conv_dw->dw_param_.out_w_start_); + NNACL_CHECK_NULL_RETURN_ERR(conv_dw->dw_param_.out_w_end_); + + return self->env_->ParallelLaunch(self->env_->thread_pool_, ConvDwRun, self, self->thread_nr_); +} + +int ConvolutionDepthwiseResize(KernelBase *self) { + ConvolutionDepthwiseStruct *conv_dw = (ConvolutionDepthwiseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + + int ret = ConvBasePrepare(&conv_dw->conv_); + if (ret != NNACL_OK) { + return ret; + } + + self->thread_nr_ = NNACL_MIN(self->thread_nr_, conv_dw->conv_.compute_.out_h_); + NNACL_CHECK_ZERO_RETURN_ERR(self->thread_nr_); + + ret = ConvDwInitConvDwCalcInfo(conv_dw); + if (ret != NNACL_OK) { + return ret; + } + + return NNACL_OK; +} + +int ConvolutionDepthwiseRelease(KernelBase *self) { + ConvolutionDepthwiseStruct *conv_dw = (ConvolutionDepthwiseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + + ConvDwReleaseParam(conv_dw); + + ConvBaseRelease(&conv_dw->conv_); + return NNACL_OK; +} + +KernelBase *CreateConvDw(ConvParameter *conv) { + ConvolutionDepthwiseStruct *conv_dw = (ConvolutionDepthwiseStruct *)malloc(sizeof(ConvolutionDepthwiseStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv_dw); + memset(conv_dw, 0, sizeof(ConvolutionDepthwiseStruct)); + + conv_dw->conv_.pack_weight_ = ConvDwPackWeight; + conv_dw->conv_.malloc_weight_bias_ = ConvDwMallocWeightBiasData; + conv_dw->conv_.base_.Prepare = ConvolutionDepthwisePrepare; + conv_dw->conv_.base_.Compute = ConvolutionDepthwiseCompute; + conv_dw->conv_.base_.Resize = ConvolutionDepthwiseResize; + conv_dw->conv_.base_.Release = ConvolutionDepthwiseRelease; + return (KernelBase *)conv_dw; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_depthwise.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_depthwise.h new file mode 100644 index 0000000000000000000000000000000000000000..5a1c0b5206d24e7dc3e359f8abcb165bd7d4c48f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_depthwise.h @@ -0,0 +1,36 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CONVOLLUTION_DEPTHWISE_H_ +#define NNACL_KERNEL_CONVOLLUTION_DEPTHWISE_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/kernel/convolution_base.h" + +typedef struct ConvolutionDepthwiseStruct { + ConvolutionBaseStruct conv_; + ConvDwCalcParam dw_param_; + float *input_ptr_; + float *output_ptr_; +} ConvolutionDepthwiseStruct; + +int ConvolutionDepthwiseRelease(KernelBase *self); +KernelBase *CreateConvDw(ConvParameter *conv); + +#endif // NNACL_KERNEL_CONVOLLUTION_DEPTHWISE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_depthwise_3x3.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_depthwise_3x3.c new file mode 100644 index 0000000000000000000000000000000000000000..e3fe38f143d50e931c5b8382033e215d57eb3ada --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_depthwise_3x3.c @@ -0,0 +1,154 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either convolutionress or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX)) +#include "nnacl/kernel/convolution_depthwise_3x3.h" +#include "nnacl/kernel/convolution_base.h" +#include "nnacl/fp32/conv_depthwise_fp32.h" +#include "nnacl/fp32/pack_fp32.h" + +int ConvDw3x3Run(void *cdata, int task_id, float l, float r) { + ConvolutionDepthwise3x3Struct *conv_dw = (ConvolutionDepthwise3x3Struct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + + int units = UP_DIV(conv_dw->conv_.compute_.out_w_, C2NUM); // F(2, 3) contains 2 conv units + int c4 = UP_ROUND(conv_dw->conv_.compute_.in_c_, C4NUM); + int c12c4_units = C12NUM * c4 * units; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(c12c4_units, task_id, NNACL_ERR); + float *buffer = conv_dw->buffer_ + c12c4_units * task_id; + NNACL_CHECK_ZERO_RETURN_ERR(conv_dw->conv_.base_.thread_nr_); + + int step_oh = UP_DIV(conv_dw->conv_.compute_.out_h_, conv_dw->conv_.base_.thread_nr_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(step_oh, task_id, NNACL_ERR); + int start_oh = step_oh * task_id; + int end_oh = MSMIN(start_oh + step_oh, conv_dw->conv_.compute_.out_h_); + + ConvParameter *conv_param = (ConvParameter *)conv_dw->conv_.base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + ConvDw3x3(conv_dw->output_ptr_, buffer, conv_dw->input_ptr_, (float *)conv_dw->conv_.packed_weight_, + (float *)conv_dw->conv_.bias_data_, conv_param, start_oh, end_oh); + return NNACL_OK; +} + +void ConvDw3x3PackWeight(ConvolutionBaseStruct *conv) { + void *origin_weight = (conv->base_.train_session_) ? conv->base_.in_[SECOND_INPUT]->data_ : conv->origin_weight_; + NNACL_CHECK_NULL_RETURN_VOID(origin_weight); + PackWeightConvDw3x3Fp32((float *)origin_weight, (float *)conv->packed_weight_, conv->compute_.out_c_); +} + +int ConvDw3x3MallocWeightBiasData(ConvolutionBaseStruct *conv) { + int c4 = UP_ROUND(conv->compute_.out_c_, C4NUM); + if (!conv->base_.train_session_) { + if (conv->packed_weight_ == NULL) { + int pack_weight_size = c4 * C12NUM; + NNACL_CHECK_MALLOC_SIZE(pack_weight_size * sizeof(float)); + conv->packed_weight_ = ConvBaseGetConvPackWeightData(conv, pack_weight_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->packed_weight_); + } + } + + if (conv->bias_data_ == NULL) { + NNACL_CHECK_MALLOC_SIZE(c4 * sizeof(float)); + conv->bias_data_ = conv->base_.env_->Alloc(conv->base_.env_->allocator_, c4 * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->bias_data_); + } + memset(conv->bias_data_, 0, c4 * sizeof(float)); + return NNACL_OK; +} + +int ConvolutionDepthwise3x3Resize(KernelBase *self) { + ConvolutionBaseStruct *conv = (ConvolutionBaseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv); + int ret = ConvBasePrepare(conv); + if (ret != NNACL_OK) { + return ret; + } + self->thread_nr_ = NNACL_MIN(self->thread_nr_, conv->compute_.out_h_); + return NNACL_OK; +} + +int ConvolutionDepthwise3x3Prepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + + ConvolutionDepthwise3x3Struct *conv_dw = (ConvolutionDepthwise3x3Struct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + + ConvBaseUpdateOriginWeightAndBias(&conv_dw->conv_); + + if (self->train_session_) { + int c4 = UP_ROUND(conv_dw->conv_.compute_.out_c_, C4NUM); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(c4, C12NUM, NNACL_ERR); + int pack_weight_size = c4 * C12NUM; + self->work_size_ = pack_weight_size * sizeof(float); + } + + return ConvBaseInitConvWeightBias(&conv_dw->conv_); +} + +int ConvolutionDepthwise3x3Compute(KernelBase *self) { + ConvolutionDepthwise3x3Struct *conv_dw = (ConvolutionDepthwise3x3Struct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + + int units = UP_DIV(conv_dw->conv_.compute_.out_w_, C2NUM); // F(2, 3) contains 2 conv units + int c4 = UP_ROUND(conv_dw->conv_.compute_.in_c_, C4NUM); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(C12NUM, c4, NNACL_ERR); + int c12c4 = C12NUM * c4; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(c12c4, units, NNACL_ERR); + int c12c4_units = c12c4 * units; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(c12c4_units, self->thread_nr_, NNACL_ERR); + int buffer_size = c12c4_units * self->thread_nr_; + + conv_dw->buffer_ = self->env_->Alloc(self->env_->allocator_, buffer_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_dw->buffer_); + + int ret = ConvBaseRepackWeight(&conv_dw->conv_); + if (ret != NNACL_OK) { + self->env_->Free(self->env_->allocator_, conv_dw->buffer_); + return ret; + } + + conv_dw->input_ptr_ = self->in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw->input_ptr_); + conv_dw->output_ptr_ = self->out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw->output_ptr_); + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, ConvDw3x3Run, self, self->thread_nr_); + self->env_->Free(self->env_->allocator_, conv_dw->buffer_); + return ret; +} + +int ConvolutionDepthwise3x3Release(KernelBase *self) { + ConvolutionBaseStruct *conv = (ConvolutionBaseStruct *)self; + ConvBaseRelease(conv); + return NNACL_OK; +} + +KernelBase *CreateConvDw3x3(ConvParameter *conv_param) { + ConvolutionDepthwise3x3Struct *conv_dw = + (ConvolutionDepthwise3x3Struct *)malloc(sizeof(ConvolutionDepthwise3x3Struct)); + NNACL_CHECK_NULL_RETURN_NULL(conv_dw); + memset(conv_dw, 0, sizeof(ConvolutionDepthwise3x3Struct)); + conv_dw->conv_.pack_weight_ = ConvDw3x3PackWeight; + conv_dw->conv_.malloc_weight_bias_ = ConvDw3x3MallocWeightBiasData; + conv_dw->conv_.base_.Resize = ConvolutionDepthwise3x3Resize; + conv_dw->conv_.base_.Prepare = ConvolutionDepthwise3x3Prepare; + conv_dw->conv_.base_.Compute = ConvolutionDepthwise3x3Compute; + conv_dw->conv_.base_.Release = ConvolutionDepthwise3x3Release; + + return (KernelBase *)conv_dw; +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_depthwise_3x3.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_depthwise_3x3.h new file mode 100644 index 0000000000000000000000000000000000000000..1b9b26c229852c2ae112e037285d3dd885173df1 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_depthwise_3x3.h @@ -0,0 +1,37 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CONVOLLUTION_DEPTHWISE_3X3_H_ +#define NNACL_KERNEL_CONVOLLUTION_DEPTHWISE_3X3_H_ + +#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX)) +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/kernel/convolution_base.h" + +typedef struct ConvolutionDepthwise3x3Struct { + ConvolutionBaseStruct conv_; + float *buffer_; + float *input_ptr_; + float *output_ptr_; +} ConvolutionDepthwise3x3Struct; + +KernelBase *CreateConvDw3x3(ConvParameter *conv_param); +#endif + +#endif // NNACL_KERNEL_CONVOLLUTION_DEPTHWISE_3X3_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_depthwise_indirect.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_depthwise_indirect.c new file mode 100644 index 0000000000000000000000000000000000000000..346cde052988672f252751833ec495ddc97380ed --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_depthwise_indirect.c @@ -0,0 +1,227 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either convolutionress or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/convolution_depthwise_indirect.h" +#include "nnacl/kernel/convolution_base.h" +#include "nnacl/fp32/conv_depthwise_fp32.h" +#include "nnacl/fp32/pack_fp32.h" + +int ConvDwIndirectMallocIndirectBuffer(ConvolutionDepthwiseIndirectStruct *conv_dw) { + ConvComputeParam *compute = &conv_dw->conv_.compute_; + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(compute); + + conv_dw->step_w_ = compute->dilation_w_ == 1 ? compute->stride_w_ : compute->kernel_w_; + int step_w_2d = conv_dw->step_w_ * compute->kernel_h_; + conv_dw->step_h_ = (compute->kernel_h_ * compute->kernel_w_) + (compute->out_w_ - 1) * step_w_2d; + int step_h_2d = compute->out_h_ * conv_dw->step_h_; + int buffer_size = compute->out_n_ * step_h_2d; + + ExecEnv *env = conv_dw->conv_.base_.env_; + conv_dw->indirect_buffer_ = (float **)(env->Alloc(env->allocator_, buffer_size * sizeof(float *))); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_dw->indirect_buffer_); + return NNACL_OK; +} + +int ConvDwIndirectRun(void *cdata, int task_id, float l, float r) { + ConvolutionDepthwiseIndirectStruct *conv_dw = (ConvolutionDepthwiseIndirectStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + ConvParameter *conv_param = (ConvParameter *)conv_dw->conv_.base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + + ConvDwIndirection(conv_dw->output_ptr_, conv_dw->indirect_buffer_, (float *)conv_dw->conv_.packed_weight_, + (float *)conv_dw->conv_.bias_data_, conv_dw->zero_ptr_, conv_param, task_id); + return NNACL_OK; +} + +int ConvDwIndirectMallocPackedInput(ConvolutionDepthwiseIndirectStruct *conv_dw) { + int IC_DIV = UP_DIV(conv_dw->conv_.compute_.in_c_, conv_dw->div_flag_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_dw->conv_.compute_.in_n_, conv_dw->conv_.compute_.in_hw_, NNACL_ERR); + int conv_input_bhw = conv_dw->conv_.compute_.in_n_ * conv_dw->conv_.compute_.in_hw_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_input_bhw, conv_dw->div_flag_ * IC_DIV, NNACL_ERR); + int pack_input_size = conv_input_bhw * conv_dw->div_flag_ * IC_DIV; + conv_dw->packed_input_ = + conv_dw->conv_.base_.env_->Alloc(conv_dw->conv_.base_.env_->allocator_, pack_input_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_dw->packed_input_); + return NNACL_OK; +} + +void ConvDwIndirectPackWeight(ConvolutionBaseStruct *conv) { + TensorC *weight_tensor = conv->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_VOID(weight_tensor); + void *origin_weight = (conv->base_.train_session_) ? weight_tensor->data_ : conv->origin_weight_; + NNACL_CHECK_NULL_RETURN_VOID(origin_weight); + +#ifdef ENABLE_AVX + PackDepthwiseIndirectWeightC8Fp32(origin_weight, conv->packed_weight_, conv->compute_.kernel_h_, + conv->compute_.kernel_w_, conv->compute_.out_c_); +#else + PackDepthwiseIndirectWeightC4Fp32(origin_weight, conv->packed_weight_, conv->compute_.kernel_h_, + conv->compute_.kernel_w_, conv->compute_.out_c_); +#endif +} + +int ConvDwIndirectMallocWeightBiasData(ConvolutionBaseStruct *conv) { + ConvolutionDepthwiseIndirectStruct *conv_dw = (ConvolutionDepthwiseIndirectStruct *)conv; + ExecEnv *env = conv->base_.env_; + + int batch_flag = UP_DIV(conv->compute_.out_c_, conv_dw->div_flag_); + int pack_weight_size = conv_dw->div_flag_ * batch_flag * conv->compute_.kernel_hw_; + if (!conv->base_.train_session_) { + NNACL_CHECK_MALLOC_SIZE(pack_weight_size * sizeof(float)); + conv->packed_weight_ = ConvBaseGetConvPackWeightData(conv, pack_weight_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->packed_weight_); + } + + // malloc zero ptr + NNACL_CHECK_MALLOC_SIZE(batch_flag * conv_dw->div_flag_ * sizeof(float)); + conv_dw->zero_ptr_ = (float *)env->Alloc(env->allocator_, batch_flag * conv_dw->div_flag_ * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_dw->zero_ptr_); + memset(conv_dw->zero_ptr_, 0, batch_flag * conv_dw->div_flag_ * sizeof(float)); + + // malloc bias ptr + if (conv->bias_data_ == NULL) { + conv->bias_data_ = env->Alloc(env->allocator_, batch_flag * conv_dw->div_flag_ * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->bias_data_); + } + memset(conv->bias_data_, 0, batch_flag * conv_dw->div_flag_ * sizeof(float)); + return NNACL_OK; +} + +int ConvolutionDepthwiseIndirectCompute(KernelBase *self) { + ConvolutionDepthwiseIndirectStruct *conv_dw = (ConvolutionDepthwiseIndirectStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + TensorC *input_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + void *input_ptr = input_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_ptr); + + if (conv_dw->conv_.compute_.in_c_ % conv_dw->div_flag_ != 0) { + int ret = ConvDwIndirectMallocPackedInput(conv_dw); + if (ret != NNACL_OK) { + return ret; + } +#ifdef ENABLE_AVX + PackNHWCToNHWC8Fp32(input_ptr, conv_dw->packed_input_, conv_dw->conv_.compute_.in_n_, + conv_dw->conv_.compute_.in_hw_, conv_dw->conv_.compute_.in_c_); +#else + PackNHWCToNHWC4Fp32(input_ptr, conv_dw->packed_input_, conv_dw->conv_.compute_.in_n_, + conv_dw->conv_.compute_.in_hw_, conv_dw->conv_.compute_.in_c_); +#endif + } else { + conv_dw->packed_input_ = input_ptr; + } + + int ret = ConvBaseRepackWeight(&conv_dw->conv_); + if (ret != NNACL_OK) { + return ret; + } + + TensorC *output_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + conv_dw->output_ptr_ = output_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw->output_ptr_); + ConvParameter *conv_param = (ConvParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + + ConvDwInitIndirection(conv_dw->indirect_buffer_, conv_dw->packed_input_, conv_dw->zero_ptr_, conv_param, + conv_dw->step_h_, conv_dw->step_w_); + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, ConvDwIndirectRun, self, self->thread_nr_); + + if (conv_dw->conv_.compute_.in_c_ % conv_dw->div_flag_ != 0) { + self->env_->Free(self->env_->allocator_, conv_dw->packed_input_); + } + return ret; +} +int ConvolutionDepthwiseIndirectResize(KernelBase *self) { + ConvolutionDepthwiseIndirectStruct *conv_dw = (ConvolutionDepthwiseIndirectStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + + if (conv_dw->indirect_buffer_ != NULL) { + self->env_->Free(self->env_->allocator_, conv_dw->indirect_buffer_); + conv_dw->indirect_buffer_ = NULL; + } + + int ret = ConvBasePrepare(&conv_dw->conv_); + if (ret != NNACL_OK) { + return ret; + } + + ret = ConvDwIndirectMallocIndirectBuffer(conv_dw); + if (ret != NNACL_OK) { + return ret; + } + + self->thread_nr_ = NNACL_MIN(self->thread_nr_, conv_dw->conv_.compute_.out_h_); + NNACL_CHECK_ZERO_RETURN_ERR(self->thread_nr_); + return NNACL_OK; +} + +int ConvolutionDepthwiseIndirectPrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + ConvolutionDepthwiseIndirectStruct *conv_dw = (ConvolutionDepthwiseIndirectStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + + ConvBaseUpdateOriginWeightAndBias(&conv_dw->conv_); + + if (self->train_session_) { + int batch_flag = UP_DIV(conv_dw->conv_.compute_.out_c_, conv_dw->div_flag_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_dw->div_flag_ * batch_flag, conv_dw->conv_.compute_.kernel_hw_, NNACL_ERR); + int pack_weight_size = conv_dw->div_flag_ * batch_flag * conv_dw->conv_.compute_.kernel_hw_; + self->work_size_ = pack_weight_size * sizeof(float); + } + + return ConvBaseInitConvWeightBias(&conv_dw->conv_); +} + +int ConvolutionDepthwiseIndirectRelease(KernelBase *self) { + ConvolutionDepthwiseIndirectStruct *conv_dw = (ConvolutionDepthwiseIndirectStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + + if (conv_dw->zero_ptr_ != NULL) { + self->env_->Free(self->env_->allocator_, conv_dw->zero_ptr_); + conv_dw->zero_ptr_ = NULL; + } + if (conv_dw->indirect_buffer_ != NULL) { + self->env_->Free(self->env_->allocator_, conv_dw->indirect_buffer_); + conv_dw->indirect_buffer_ = NULL; + } + ConvBaseRelease(&conv_dw->conv_); + return NNACL_OK; +} + +KernelBase *CreateConvDwIndirect(ConvParameter *conv_param) { + ConvolutionDepthwiseIndirectStruct *conv_dw = + (ConvolutionDepthwiseIndirectStruct *)malloc(sizeof(ConvolutionDepthwiseIndirectStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv_dw); + memset(conv_dw, 0, sizeof(ConvolutionDepthwiseIndirectStruct)); + +#ifdef ENABLE_AVX + conv_dw->div_flag_ = C8NUM; +#else + conv_dw->div_flag_ = C4NUM; +#endif + conv_dw->conv_.pack_weight_ = ConvDwIndirectPackWeight; + conv_dw->conv_.malloc_weight_bias_ = ConvDwIndirectMallocWeightBiasData; + + conv_dw->conv_.base_.Compute = ConvolutionDepthwiseIndirectCompute; + conv_dw->conv_.base_.Resize = ConvolutionDepthwiseIndirectResize; + conv_dw->conv_.base_.Prepare = ConvolutionDepthwiseIndirectPrepare; + conv_dw->conv_.base_.Release = ConvolutionDepthwiseIndirectRelease; + + return (KernelBase *)conv_dw; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_depthwise_indirect.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_depthwise_indirect.h new file mode 100644 index 0000000000000000000000000000000000000000..d896a6c62f54abf1b7102e14eb6b763b04d116fe --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_depthwise_indirect.h @@ -0,0 +1,39 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CONVOLLUTION_DEPTHWISE_INDIRECT_H_ +#define NNACL_KERNEL_CONVOLLUTION_DEPTHWISE_INDIRECT_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/kernel/convolution_base.h" + +typedef struct ConvolutionDepthwiseIndirectStruct { + ConvolutionBaseStruct conv_; + int div_flag_; + int step_w_; + int step_h_; + float *zero_ptr_; + float *output_ptr_; + float *packed_input_; + float **indirect_buffer_; +} ConvolutionDepthwiseIndirectStruct; + +KernelBase *CreateConvDwIndirect(ConvParameter *conv_param); + +#endif // NNACL_KERNEL_CONVOLLUTION_DEPTHWISE_INDIRECT_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_depthwise_sw.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_depthwise_sw.c new file mode 100644 index 0000000000000000000000000000000000000000..91f72808850c758d58d1324b1f717aa4236941b0 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_depthwise_sw.c @@ -0,0 +1,200 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either convolutionress or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/convolution_depthwise_sw.h" +#include "nnacl/kernel/convolution_base.h" +#include "nnacl/fp32/conv_depthwise_fp32.h" +#include "nnacl/fp32/pack_fp32.h" + +int ConvDwSWMallocWeightBiasData(ConvolutionBaseStruct *conv) { + int OC4 = UP_DIV(conv->compute_.out_c_, C4NUM); + int pack_weight_size = C4NUM * OC4 * conv->compute_.kernel_hw_; + if (!conv->base_.train_session_) { + NNACL_CHECK_MALLOC_SIZE(pack_weight_size * sizeof(float)); + conv->packed_weight_ = ConvBaseGetConvPackWeightData(conv, pack_weight_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->packed_weight_); + } + + int malloc_size = NNACL_MAX(conv->compute_.out_c_, C4NUM * OC4); + if (conv->bias_data_ == NULL) { + NNACL_CHECK_MALLOC_SIZE(malloc_size * sizeof(float)); + conv->bias_data_ = conv->base_.env_->Alloc(conv->base_.env_->allocator_, malloc_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->bias_data_); + } + memset(conv->bias_data_, 0, malloc_size * sizeof(float)); + conv->base_.thread_nr_ = NNACL_MIN(conv->base_.thread_nr_, OC4); + return NNACL_OK; +} + +int ConvDwSWInitPackedInputOutput(ConvolutionDepthwiseSWStruct *conv_dw) { + if (conv_dw->conv_.compute_.in_c_ % C4NUM == 0) { + conv_dw->need_align_ = false; + return NNACL_OK; + } + + conv_dw->need_align_ = true; + int IC4 = UP_DIV(conv_dw->conv_.compute_.in_c_, C4NUM); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_dw->conv_.compute_.in_n_, conv_dw->conv_.compute_.in_hw_, NNACL_ERR); + int conv_input_bhw = conv_dw->conv_.compute_.in_n_ * conv_dw->conv_.compute_.in_hw_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_input_bhw, C4NUM * IC4, NNACL_ERR); + int pack_input_size = conv_input_bhw * C4NUM * IC4; + NNACL_CHECK_MALLOC_SIZE(pack_input_size * sizeof(float)); + conv_dw->packed_input_ = + (float *)conv_dw->conv_.base_.env_->Alloc(conv_dw->conv_.base_.env_->allocator_, pack_input_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_dw->packed_input_); + + int OC4 = UP_DIV(conv_dw->conv_.compute_.out_c_, C4NUM); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_dw->conv_.compute_.out_n_, conv_dw->conv_.compute_.out_hw_, NNACL_ERR); + int output_bhw = conv_dw->conv_.compute_.out_n_ * conv_dw->conv_.compute_.out_hw_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(output_bhw, C4NUM * OC4, NNACL_ERR); + int pack_output_size = output_bhw * C4NUM * OC4; + NNACL_CHECK_MALLOC_SIZE(pack_output_size * sizeof(float)); + conv_dw->packed_output_ = + (float *)conv_dw->conv_.base_.env_->Alloc(conv_dw->conv_.base_.env_->allocator_, pack_output_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_dw->packed_output_); + return NNACL_OK; +} + +int ConvDwSWRun(void *cdata, int task_id, float l, float r) { + ConvolutionDepthwiseSWStruct *conv_dw = (ConvolutionDepthwiseSWStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + ConvParameter *conv_param = (ConvParameter *)conv_dw->conv_.base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + + ConvDwSWFp32(conv_dw->packed_output_, conv_dw->packed_input_, (float *)conv_dw->conv_.packed_weight_, + (float *)conv_dw->conv_.bias_data_, conv_param, &conv_dw->sliding_, task_id); + return NNACL_OK; +} + +void ConvDwSWFreePackedInputOutput(ConvolutionDepthwiseSWStruct *conv_dw) { + if (conv_dw->need_align_) { + conv_dw->conv_.base_.env_->Free(conv_dw->conv_.base_.env_->allocator_, conv_dw->packed_input_); + conv_dw->packed_input_ = NULL; + conv_dw->conv_.base_.env_->Free(conv_dw->conv_.base_.env_->allocator_, conv_dw->packed_output_); + conv_dw->packed_output_ = NULL; + } +} + +void ConvDwSWPackWeight(ConvolutionBaseStruct *conv) { + void *origin_weight = (conv->base_.train_session_) ? conv->base_.in_[SECOND_INPUT]->data_ : conv->origin_weight_; + NNACL_CHECK_NULL_RETURN_VOID(origin_weight); + PackNCHWToNC4HW4Fp32(origin_weight, conv->packed_weight_, 1, conv->compute_.kernel_hw_, conv->compute_.out_c_); +} + +int ConvolutionDepthwiseSWResize(KernelBase *self) { + ConvolutionDepthwiseSWStruct *conv_dw = (ConvolutionDepthwiseSWStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + int ret = ConvBasePrepare(&conv_dw->conv_); + if (ret != NNACL_OK) { + return ret; + } + + ConvParameter *conv_param = (ConvParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + InitSlidingParamConvDw(&conv_dw->sliding_, conv_param, C4NUM); + + self->thread_nr_ = NNACL_MIN(self->thread_nr_, conv_dw->conv_.compute_.out_h_); + NNACL_CHECK_ZERO_RETURN_ERR(self->thread_nr_); + return NNACL_OK; +} + +int ConvolutionDepthwiseSWCompute(KernelBase *self) { + ConvolutionDepthwiseSWStruct *conv_dw = (ConvolutionDepthwiseSWStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + + int ret = ConvDwSWInitPackedInputOutput(conv_dw); + if (ret != NNACL_OK) { + ConvDwSWFreePackedInputOutput(conv_dw); + return ret; + } + + ret = ConvBaseRepackWeight(&conv_dw->conv_); + if (ret != NNACL_OK) { + ConvDwSWFreePackedInputOutput(conv_dw); + return ret; + } + + TensorC *input_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + float *input_ptr = (float *)input_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_ptr); + if (conv_dw->need_align_) { + PackNHWCToNHWC4Fp32(input_ptr, conv_dw->packed_input_, conv_dw->conv_.compute_.in_n_, + conv_dw->conv_.compute_.in_hw_, conv_dw->conv_.compute_.in_c_); + } else { + conv_dw->packed_input_ = input_ptr; + } + + TensorC *output_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + float *output_ptr = (float *)output_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_ptr); + if (!conv_dw->need_align_) { + conv_dw->packed_output_ = output_ptr; + } + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, ConvDwSWRun, self, self->thread_nr_); + + if (conv_dw->need_align_) { + PackNHWCXToNHWCFp32(conv_dw->packed_output_, output_ptr, conv_dw->conv_.compute_.out_n_, + conv_dw->conv_.compute_.out_hw_, conv_dw->conv_.compute_.out_c_, C4NUM); + } + + ConvDwSWFreePackedInputOutput(conv_dw); + return ret; +} + +int ConvolutionDdepthwiseSWPrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + + ConvolutionDepthwiseSWStruct *conv_dw = (ConvolutionDepthwiseSWStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + ConvParameter *conv_param = (ConvParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + + ConvBaseUpdateOriginWeightAndBias(&conv_dw->conv_); + + if (self->train_session_) { + int OC4 = UP_DIV(conv_dw->conv_.compute_.out_c_, C4NUM); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(C4NUM * OC4, conv_dw->conv_.compute_.kernel_hw_, NNACL_ERR); + int pack_weight_size = C4NUM * OC4 * conv_dw->conv_.compute_.kernel_hw_; + self->work_size_ = pack_weight_size * sizeof(float); + } + + return ConvBaseInitConvWeightBias(&conv_dw->conv_); +} + +int ConvolutionDepthwiseSWRelease(KernelBase *self) { + ConvolutionBaseStruct *conv = (ConvolutionBaseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv); + ConvBaseRelease(conv); + return NNACL_OK; +} + +KernelBase *CreateConvDwSW(ConvParameter *conv_param) { + ConvolutionDepthwiseSWStruct *conv_dw = (ConvolutionDepthwiseSWStruct *)malloc(sizeof(ConvolutionDepthwiseSWStruct)); + NNACL_CHECK_NULL_RETURN_NULL(conv_dw); + memset(conv_dw, 0, sizeof(ConvolutionDepthwiseSWStruct)); + + conv_dw->conv_.malloc_weight_bias_ = ConvDwSWMallocWeightBiasData; + conv_dw->conv_.pack_weight_ = ConvDwSWPackWeight; + conv_dw->conv_.base_.Resize = ConvolutionDepthwiseSWResize; + conv_dw->conv_.base_.Compute = ConvolutionDepthwiseSWCompute; + conv_dw->conv_.base_.Prepare = ConvolutionDdepthwiseSWPrepare; + conv_dw->conv_.base_.Release = ConvolutionDepthwiseSWRelease; + return (KernelBase *)conv_dw; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_depthwise_sw.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_depthwise_sw.h new file mode 100644 index 0000000000000000000000000000000000000000..99babb6a8a1783c5571b753288abcb7efe0bc001 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_depthwise_sw.h @@ -0,0 +1,36 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CONVOLLUTION_DEPTHWISE_SW_H_ +#define NNACL_KERNEL_CONVOLLUTION_DEPTHWISE_SW_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/kernel/convolution_base.h" + +typedef struct ConvolutionDepthwiseSWStruct { + ConvolutionBaseStruct conv_; + SlidingWindowParam sliding_; + float *packed_input_; + float *packed_output_; + bool need_align_; +} ConvolutionDepthwiseSWStruct; + +KernelBase *CreateConvDwSW(ConvParameter *conv_param); + +#endif // NNACL_KERNEL_CONVOLLUTION_DEPTHWISE_SW_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_depthwise_sw_avx.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_depthwise_sw_avx.c new file mode 100644 index 0000000000000000000000000000000000000000..31ed850625a751c077f875b2af0f700d51def6ee --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_depthwise_sw_avx.c @@ -0,0 +1,216 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either convolutionress or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_AVX +#include "nnacl/kernel/convolution_depthwise_sw_avx.h" +#include "nnacl/kernel/convolution_base.h" +#include "nnacl/fp32/conv_depthwise_fp32.h" +#include "nnacl/fp32/pack_fp32.h" +#include "nnacl/tensor_c.h" + +int ConvDwSWAVXInitPackedInputOutput(ConvolutionDepthwiseSWAVXStruct *conv_dw) { + conv_dw->input_need_align_ = (conv_dw->conv_.compute_.in_c_ % conv_dw->oc_tile_ != 0); + conv_dw->output_need_align_ = (conv_dw->conv_.compute_.out_c_ % conv_dw->oc_tile_ != 0); + + ExecEnv *env = conv_dw->conv_.base_.env_; + NNACL_CHECK_NULL_RETURN_ERR(env); + + if (conv_dw->input_need_align_) { + int ic_algin = UP_DIV(conv_dw->conv_.compute_.in_c_, conv_dw->oc_tile_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_dw->conv_.compute_.in_n_, conv_dw->conv_.compute_.in_hw_, NNACL_ERR); + int input_bhw = conv_dw->conv_.compute_.in_n_ * conv_dw->conv_.compute_.in_hw_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(input_bhw, conv_dw->oc_tile_ * ic_algin, NNACL_ERR); + int pack_input_size = input_bhw * conv_dw->oc_tile_ * ic_algin; + conv_dw->packed_input_ = (float *)env->Alloc(env->allocator_, pack_input_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_dw->packed_input_); + } + + if (conv_dw->output_need_align_) { + int oc_algin = UP_DIV(conv_dw->conv_.compute_.out_c_, conv_dw->oc_tile_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_dw->conv_.compute_.out_n_, conv_dw->conv_.compute_.out_hw_, NNACL_ERR); + int output_bhw = conv_dw->conv_.compute_.out_n_ * conv_dw->conv_.compute_.out_hw_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(output_bhw, conv_dw->oc_tile_ * oc_algin, NNACL_ERR); + int pack_output_size = output_bhw * conv_dw->oc_tile_ * oc_algin; + conv_dw->packed_output_ = (float *)env->Alloc(env->allocator_, pack_output_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_dw->packed_output_); + } + + return NNACL_OK; +} + +void ConvDwSWAVXPackWeight(ConvolutionBaseStruct *conv) { + ConvolutionDepthwiseSWAVXStruct *conv_dw = (ConvolutionDepthwiseSWAVXStruct *)conv; + NNACL_CHECK_NULL_RETURN_VOID(conv_dw); + + int oc_algin = UP_DIV(conv->compute_.out_c_, conv_dw->oc_tile_); + void *origin_weight = conv->base_.train_session_ ? conv->base_.in_[SECOND_INPUT]->data_ : conv->origin_weight_; + NNACL_CHECK_NULL_RETURN_VOID(origin_weight); + + PackNHWCToNXHWCXFp32(conv->compute_.kernel_h_, conv->compute_.kernel_w_, conv->compute_.out_c_, oc_algin, 1, + (float *)conv->packed_weight_, (float *)conv->origin_weight_); +} + +int ConvDwSWAVXMallocWeightBiasData(ConvolutionBaseStruct *conv) { + ConvolutionDepthwiseSWAVXStruct *conv_dw = (ConvolutionDepthwiseSWAVXStruct *)conv; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + + int oc_algin = UP_DIV(conv->compute_.out_c_, conv_dw->oc_tile_); + int pack_weight_size = oc_algin * conv_dw->oc_tile_ * conv->compute_.kernel_hw_; + + if (!conv->base_.train_session_) { + NNACL_CHECK_MALLOC_SIZE(pack_weight_size * sizeof(float)); + conv->packed_weight_ = ConvBaseGetConvPackWeightData(conv, pack_weight_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->packed_weight_); + } + + if (conv->base_.in_size_ == THREE_TENSOR) { + int bias_size = oc_algin * conv_dw->oc_tile_; + NNACL_CHECK_MALLOC_SIZE(bias_size * sizeof(float)); + conv->bias_data_ = conv->base_.env_->Alloc(conv->base_.env_->allocator_, bias_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->bias_data_); + memset(conv->bias_data_, 0, bias_size * sizeof(float)); + } + return NNACL_OK; +} + +int ConvDwSWAvxRun(void *cdata, int task_id, float l, float r) { + ConvolutionDepthwiseSWAVXStruct *conv_dw = (ConvolutionDepthwiseSWAVXStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + ConvParameter *conv_param = (ConvParameter *)conv_dw->conv_.base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + + DepthwiseSWAvxFp32(conv_dw->packed_output_, conv_dw->packed_input_, (float *)conv_dw->conv_.packed_weight_, + (float *)conv_dw->conv_.bias_data_, conv_param, &conv_dw->sliding_param_, task_id); + return NNACL_OK; +} + +void ConvDwSWAVXFreePackedInputOutput(ConvolutionDepthwiseSWAVXStruct *conv_dw) { + if (conv_dw->input_need_align_) { + conv_dw->conv_.base_.env_->Free(conv_dw->conv_.base_.env_->allocator_, conv_dw->packed_input_); + conv_dw->packed_input_ = NULL; + conv_dw->input_need_align_ = false; + } + if (conv_dw->output_need_align_) { + conv_dw->conv_.base_.env_->Free(conv_dw->conv_.base_.env_->allocator_, conv_dw->packed_output_); + conv_dw->packed_output_ = NULL; + conv_dw->output_need_align_ = false; + } +} + +int ConvolutionDepthwiseSWAVXCompute(KernelBase *self) { + ConvolutionDepthwiseSWAVXStruct *conv_dw = (ConvolutionDepthwiseSWAVXStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + + int ret = ConvDwSWAVXInitPackedInputOutput(conv_dw); + if (ret != NNACL_OK) { + ConvDwSWAVXFreePackedInputOutput(conv_dw); + return ret; + } + + ret = ConvBaseRepackWeight(&conv_dw->conv_); + if (ret != NNACL_OK) { + ConvDwSWAVXFreePackedInputOutput(conv_dw); + return ret; + } + + TensorC *input_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + float *input_ptr = (float *)input_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_ptr); + + if (conv_dw->input_need_align_) { + PackNHWCToNHWCXFp32(input_ptr, conv_dw->packed_input_, conv_dw->conv_.compute_.in_n_, + conv_dw->conv_.compute_.in_hw_, conv_dw->conv_.compute_.in_c_, conv_dw->oc_tile_); + } else { + conv_dw->packed_input_ = input_ptr; + } + + TensorC *output_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + float *output_ptr = (float *)output_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_ptr); + + if (!conv_dw->output_need_align_) { + conv_dw->packed_output_ = output_ptr; + } + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, ConvDwSWAvxRun, self, self->thread_nr_); + + if (conv_dw->output_need_align_) { + PackNHWCXToNHWCFp32(conv_dw->packed_output_, output_ptr, conv_dw->conv_.compute_.out_n_, + conv_dw->conv_.compute_.out_hw_, conv_dw->conv_.compute_.out_c_, conv_dw->oc_tile_); + } + + ConvDwSWAVXFreePackedInputOutput(conv_dw); + return ret; +} + +int ConvolutionDepthwiseSWAVXPrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + + ConvolutionDepthwiseSWAVXStruct *conv_dw = (ConvolutionDepthwiseSWAVXStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + ConvParameter *conv_param = (ConvParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + + conv_dw->oc_tile_ = C8NUM; + ConvBaseUpdateOriginWeightAndBias(&conv_dw->conv_); + + if (self->train_session_) { + int oc_algin = UP_DIV(conv_dw->conv_.compute_.out_c_, conv_dw->oc_tile_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(oc_algin * conv_dw->oc_tile_, conv_dw->conv_.compute_.kernel_hw_, NNACL_ERR); + int pack_weight_size = oc_algin * conv_dw->oc_tile_ * conv_dw->conv_.compute_.kernel_hw_; + self->work_size_ = pack_weight_size * sizeof(float); + } + + return ConvBaseInitConvWeightBias(&conv_dw->conv_); +} + +int ConvolutionDepthwiseSWAVXResize(KernelBase *self) { + ConvolutionDepthwiseSWAVXStruct *conv_dw = (ConvolutionDepthwiseSWAVXStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_dw); + ConvParameter *conv_param = (ConvParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + + ConvBasePrepare(&conv_dw->conv_); + + InitSlidingParamConvDw(&conv_dw->sliding_param_, conv_param, conv_dw->oc_tile_); + return NNACL_OK; +} + +int ConvolutionDepthwiseSWAVXRelease(KernelBase *self) { + ConvolutionBaseStruct *conv = (ConvolutionBaseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv); + ConvBaseRelease(conv); + return NNACL_OK; +} + +KernelBase *CreateConvDwSWAVX(ConvParameter *conv_param) { + ConvolutionDepthwiseSWAVXStruct *conv_dw = + (ConvolutionDepthwiseSWAVXStruct *)malloc(sizeof(ConvolutionDepthwiseSWAVXStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv_dw); + memset(conv_dw, 0, sizeof(ConvolutionDepthwiseSWAVXStruct)); + + conv_dw->conv_.pack_weight_ = ConvDwSWAVXPackWeight; + conv_dw->conv_.malloc_weight_bias_ = ConvDwSWAVXMallocWeightBiasData; + + conv_dw->conv_.base_.Prepare = ConvolutionDepthwiseSWAVXPrepare; + conv_dw->conv_.base_.Compute = ConvolutionDepthwiseSWAVXCompute; + conv_dw->conv_.base_.Resize = ConvolutionDepthwiseSWAVXResize; + conv_dw->conv_.base_.Release = ConvolutionDepthwiseSWAVXRelease; + return (KernelBase *)conv_dw; +} +#endif diff --git a/mindspore-lite/src/extendrt/graph_compiler/compile_option.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_depthwise_sw_avx.h similarity index 47% rename from mindspore-lite/src/extendrt/graph_compiler/compile_option.h rename to mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_depthwise_sw_avx.h index d0f09657c4520ceb93fd3d21196ad0dbf61b8210..deb49f57d6ee4440505908d4f3749a4621580abf 100644 --- a/mindspore-lite/src/extendrt/graph_compiler/compile_option.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_depthwise_sw_avx.h @@ -14,23 +14,27 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_GRAPH_COMPILER_COMPILE_OPTION_H -#define MINDSPORE_LITE_SRC_EXTENDRT_GRAPH_COMPILER_COMPILE_OPTION_H +#ifndef NNACL_KERNEL_CONVOLLUTION_DEPTHWISE_SW_AVX_H_ +#define NNACL_KERNEL_CONVOLLUTION_DEPTHWISE_SW_AVX_H_ -#include -#include -#include "mindapi/base/format.h" -#include "mindapi/base/type_id.h" -#include "src/extendrt/kernel/kernel_spec_infos.h" +#ifdef ENABLE_AVX +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/kernel/convolution_base.h" -namespace mindspore::lite { -struct CompileOption { - Format graph_format{Format::NCHW}; - Format graph_input_format{Format::NCHW}; - std::string backend{kernel::kBackendCPU}; - TypeId datatype{kNumberTypeFloat32}; -}; +typedef struct ConvolutionDepthwiseSWAVXStruct { + ConvolutionBaseStruct conv_; + SlidingWindowParam sliding_param_; + int oc_tile_; + float *packed_input_; + float *packed_output_; + bool input_need_align_; + bool output_need_align_; +} ConvolutionDepthwiseSWAVXStruct; -using CompileOptionPtr = std::shared_ptr; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_GRAPH_COMPILER_COMPILE_OPTION_H +KernelBase *CreateConvDwSWAVX(ConvParameter *conv_param); +#endif + +#endif // NNACL_KERNEL_CONVOLLUTION_DEPTHWISE_SW_AVX_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col.c new file mode 100644 index 0000000000000000000000000000000000000000..b19807f088a03d6b8e438278ff467c9ec20e5f88 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col.c @@ -0,0 +1,81 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either convolutionress or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/convolution_im2col.h" +#include "nnacl/kernel/convolution_im2col_base.h" +#ifdef ENABLE_ARM32 +#include "nnacl/kernel/convolution_im2col_arm32.h" +#endif +#ifdef ENABLE_ARM64 +#include "nnacl/kernel/convolution_im2col_arm64.h" +#endif +#ifdef ENABLE_SSE +#include "nnacl/kernel/convolution_im2col_sse.h" +#endif +#ifdef ENABLE_AVX +#include "nnacl/kernel/convolution_im2col_avx.h" +#endif +#ifdef ENABLE_AVX512 +#include "nnacl/intrinsics/ms_simd_cpu_info.h" +#include "nnacl/kernel/convolution_im2col_avx512.h" +#endif + +ConvolutionBaseStruct *CreateConvolutionIm2Col(KernelBase *base, ConvParameter *conv_param) { + ConvolutionBaseStruct *kernel = NULL; + +#ifdef ENABLE_AVX512 + FormatC out_format = base->out_[OUTPUT_INDEX]->format_; + if (out_format != Format_NC4HW4) { + AVX512_HARDWARE_SELF_AWARENESS_BEGIN; + kernel = CreateConvIm2ColAVX512(conv_param); + if (kernel != NULL) { + return kernel; + } + AVX512_HARDWARE_SELF_AWARENESS_END; + } +#endif + +#ifdef ENABLE_AVX + kernel = CreateConvIm2ColAVX(conv_param); + if (kernel != NULL) { + return kernel; + } +#endif + +#ifdef ENABLE_SSE + kernel = CreateConvIm2ColSSE(conv_param); + if (kernel != NULL) { + return kernel; + } +#endif + +#ifdef ENABLE_ARM64 + kernel = CreateConvIm2ColARM64(conv_param); + if (kernel != NULL) { + return kernel; + } +#endif + +#ifdef ENABLE_ARM32 + kernel = CreateConvIm2ColARM32(conv_param); + if (kernel != NULL) { + return kernel; + } +#endif + + kernel = CreateConvIm2ColBase(conv_param); + return kernel; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col.h new file mode 100644 index 0000000000000000000000000000000000000000..3c09230bb987fa066c5d7bd36f77c7418a356afd --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col.h @@ -0,0 +1,28 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CONVOLLUTION_IM2COL_H_ +#define NNACL_KERNEL_CONVOLLUTION_IM2COL_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/kernel/convolution_base.h" + +ConvolutionBaseStruct *CreateConvolutionIm2Col(KernelBase *base, ConvParameter *conv_param); + +#endif // NNACL_KERNEL_CONVOLLUTION_IM2COL_H_ diff --git a/mindspore-lite/tools/graph_kernel/converter/akg/cpu_kernel_builder.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col_arm32.c similarity index 30% rename from mindspore-lite/tools/graph_kernel/converter/akg/cpu_kernel_builder.h rename to mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col_arm32.c index c8e1d6410e73763c87b956feb9ed48aa02b40de3..c9b6b98e472527c194ab3f20f76a32ee41d53550 100644 --- a/mindspore-lite/tools/graph_kernel/converter/akg/cpu_kernel_builder.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col_arm32.c @@ -9,25 +9,37 @@ * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either convolutionress or implied. * See the License for the specific language governing permissions and * limitations under the License. */ +#ifdef ENABLE_ARM32 +#include "nnacl/kernel/convolution_im2col_arm32.h" +#include "nnacl/fp32/pack_fp32.h" -#ifndef MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_AKG_CPU_KERNEL_BUILDER_H_ -#define MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_AKG_CPU_KERNEL_BUILDER_H_ +void ConvIm2ColARM32InitGlobalVariable(ConvolutionBaseStruct *conv) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv; + conv_im2col->oc_tile_ = C4NUM; + conv_im2col->row_tile_ = C12NUM; + conv_im2col->row_major_to_col_nmajor_ = RowMajor2Col4Major; +} -#include -#include -#include "tools/graph_kernel/converter/akg/akg_kernel_builder.h" +ConvolutionBaseStruct *CreateConvIm2ColARM32(ConvParameter *conv_param) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)malloc(sizeof(ConvolutionIm2ColBaseStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv_im2col); + memset(conv_im2col, 0, sizeof(ConvolutionIm2ColBaseStruct)); -namespace mindspore::graphkernel { -class CpuKernelBuilder : public AkgKernelBuilder { - public: - bool CompileJsonsInAnfnodes(const AnfNodePtrList &node_list) override; - AnfNodePtr CreateCustomOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode) override; - bool GenerateAkgKernelNodes(const FuncGraphPtr &func_graph, const AnfNodePtr &custom_node, - const CNodePtr &old_cnode) override; -}; -} // namespace mindspore::graphkernel -#endif // MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_AKG_CPU_KERNEL_BUILDER_H_ + conv_im2col->init_tmp_buffer_ = ConvIm2ColBaseInitTmpBuffer; + conv_im2col->conv_.malloc_weight_bias_ = ConvIm2ColBaseMallocWeightBiasData; + conv_im2col->conv_.init_global_variable_ = ConvIm2ColARM32InitGlobalVariable; + conv_im2col->conv_.run_impl_ = ConvIm2ColBaseRunImpl; + conv_im2col->conv_.pack_weight_ = ConvIm2ColBasePackWeight; + + conv_im2col->conv_.base_.Compute = ConvolutionIm2colBaseCompute; + conv_im2col->conv_.base_.Prepare = ConvolutionIm2colBasePrepare; + conv_im2col->conv_.base_.Resize = ConvolutionIm2colBaseResize; + conv_im2col->conv_.base_.Release = ConvolutionIm2colBaseRelease; + + return (ConvolutionBaseStruct *)conv_im2col; +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col_arm32.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col_arm32.h new file mode 100644 index 0000000000000000000000000000000000000000..26bcb00e6c0b4308e7aeb1488f3e028f12c113bf --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col_arm32.h @@ -0,0 +1,30 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CONVOLLUTION_IM2COL_ARM32_H_ +#define NNACL_KERNEL_CONVOLLUTION_IM2COL_ARM32_H_ + +#ifdef ENABLE_ARM32 +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/kernel/convolution_im2col_base.h" + +ConvolutionBaseStruct *CreateConvIm2ColARM32(ConvParameter *conv_param); +#endif + +#endif // NNACL_KERNEL_CONVOLLUTION_IM2COL_ARM32_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col_arm64.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col_arm64.c new file mode 100644 index 0000000000000000000000000000000000000000..d4d3445e70897d4e6763c97247fe81ddf071e80e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col_arm64.c @@ -0,0 +1,72 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either convolutionress or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_ARM64 +#include "nnacl/kernel/convolution_im2col_arm64.h" +#include "nnacl/fp32/pack_fp32.h" +#include "nnacl/fp32/conv_common_fp32.h" + +void ConvIm2ColARM64InitGlobalVariable(ConvolutionBaseStruct *conv) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv; + conv_im2col->oc_tile_ = C8NUM; + conv_im2col->row_tile_ = C12NUM; + conv_im2col->row_major_to_col_nmajor_ = RowMajor2Col8Major; +} + +int ConvIm2ColARM64RunImpl(struct ConvolutionBaseStruct *conv, int task_id) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv; + NNACL_CHECK_NULL_RETURN_ERR(conv_im2col); + float *ori_input_data = (float *)conv->base_.in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(ori_input_data); + ConvParameter *conv_param = (ConvParameter *)conv->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + + if (conv->out_format_ != Format_NC4HW4) { + if (conv->use_batch_cut_flag_) { + ConvFp32CutByBatch(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_, + (float *)conv->bias_data_, conv_im2col->col_major_input_, conv_im2col->tmp_output_, task_id, + conv_param); + } else { + ConvFp32(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_, (float *)conv->bias_data_, + conv_im2col->col_major_input_, conv_im2col->tmp_output_, task_id, conv_param); + } + } else { + ConvFp32OutNC4HW4(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_, + (float *)conv->bias_data_, conv_im2col->col_major_input_, conv_im2col->tmp_output_, task_id, + conv_param); + } + return NNACL_OK; +} + +ConvolutionBaseStruct *CreateConvIm2ColARM64(ConvParameter *conv_param) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)malloc(sizeof(ConvolutionIm2ColBaseStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv_im2col); + memset(conv_im2col, 0, sizeof(ConvolutionIm2ColBaseStruct)); + + conv_im2col->init_tmp_buffer_ = ConvIm2ColBaseInitTmpBuffer; + conv_im2col->conv_.malloc_weight_bias_ = ConvIm2ColBaseMallocWeightBiasData; + conv_im2col->conv_.init_global_variable_ = ConvIm2ColARM64InitGlobalVariable; + conv_im2col->conv_.run_impl_ = ConvIm2ColARM64RunImpl; + conv_im2col->conv_.pack_weight_ = ConvIm2ColBasePackWeight; + + conv_im2col->conv_.base_.Compute = ConvolutionIm2colBaseCompute; + conv_im2col->conv_.base_.Prepare = ConvolutionIm2colBasePrepare; + conv_im2col->conv_.base_.Resize = ConvolutionIm2colBaseResize; + conv_im2col->conv_.base_.Release = ConvolutionIm2colBaseRelease; + + return (ConvolutionBaseStruct *)conv_im2col; +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col_arm64.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col_arm64.h new file mode 100644 index 0000000000000000000000000000000000000000..0a0035c673714fda2a20cade95acbf885dca879f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col_arm64.h @@ -0,0 +1,29 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CONVOLLUTION_IM2COL_ARM64_H_ +#define NNACL_KERNEL_CONVOLLUTION_IM2COL_ARM64_H_ + +#ifdef ENABLE_ARM64 +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/kernel/convolution_im2col_base.h" + +ConvolutionBaseStruct *CreateConvIm2ColARM64(ConvParameter *conv_param); +#endif +#endif // NNACL_KERNEL_CONVOLLUTION_IM2COL_ARM64_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col_avx.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col_avx.c new file mode 100644 index 0000000000000000000000000000000000000000..3d12bb6dea647658d98fa1af5b043a45f7542bf5 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col_avx.c @@ -0,0 +1,151 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either convolutionress or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_AVX +#include "nnacl/kernel/convolution_im2col_avx.h" +#include "nnacl/fp32/pack_fp32.h" +#include "nnacl/fp32/conv_common_fp32.h" + +void ConvIm2ColAVXInitGlobalVariable(ConvolutionBaseStruct *conv) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv; + conv_im2col->oc_tile_ = C16NUM; + conv_im2col->row_tile_ = C6NUM; + conv_im2col->row_major_to_col_nmajor_ = RowMajor2Col16Major; +} + +int ConvIm2ColAVXInitTmpBuffer(ConvolutionIm2ColBaseStruct *conv_im2col) { + int kernel_chw = conv_im2col->conv_.compute_.kernel_hw_ * conv_im2col->conv_.compute_.in_c_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(kernel_chw, conv_im2col->conv_.base_.thread_nr_, NNACL_ERR); + int total_kernel_chw = kernel_chw * conv_im2col->conv_.base_.thread_nr_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_kernel_chw, conv_im2col->row_tile_, NNACL_ERR); + int unit_size = total_kernel_chw * conv_im2col->row_tile_; + + ExecEnv *env = conv_im2col->conv_.base_.env_; + NNACL_CHECK_NULL_RETURN_ERR(env); + + if (conv_im2col->packed_input_ != NULL) { + env->Free(env->allocator_, conv_im2col->packed_input_); + conv_im2col->packed_input_ = NULL; + } + conv_im2col->packed_input_ = env->Alloc(env->allocator_, unit_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_im2col->packed_input_); + + if (conv_im2col->col_major_input_ != NULL) { + env->Free(env->allocator_, conv_im2col->col_major_input_); + conv_im2col->col_major_input_ = NULL; + } + conv_im2col->col_major_input_ = env->Alloc(env->allocator_, unit_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_im2col->col_major_input_); + + conv_im2col->output_need_align_ = + conv_im2col->conv_.compute_.out_c_ % conv_im2col->oc_tile_ != 0 && conv_im2col->conv_.out_format_ == Format_NC4HW4; + if (conv_im2col->output_need_align_) { + int oc_algin = UP_DIV(conv_im2col->conv_.compute_.out_c_, conv_im2col->oc_tile_); + int output_bhw = conv_im2col->conv_.compute_.out_n_ * conv_im2col->conv_.compute_.out_hw_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(output_bhw, conv_im2col->oc_tile_ * oc_algin, NNACL_ERR); + int pack_output_size = output_bhw * conv_im2col->oc_tile_ * oc_algin; + + if (conv_im2col->tmp_output_ != NULL) { + env->Free(env->allocator_, conv_im2col->tmp_output_); + conv_im2col->tmp_output_ = NULL; + } + conv_im2col->tmp_output_ = env->Alloc(env->allocator_, pack_output_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_im2col->tmp_output_); + } + return NNACL_OK; +} + +int ConvIm2ColAVXRunImpl(struct ConvolutionBaseStruct *conv, int task_id) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv; + NNACL_CHECK_NULL_RETURN_ERR(conv_im2col); + ConvParameter *conv_param = (ConvParameter *)conv->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + float *ori_input_data = conv->base_.in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(ori_input_data); + + if (conv->out_format_ != Format_NC4HW4) { + if (conv->use_batch_cut_flag_) { + ConvFp32CutByBatch(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_, + (float *)conv->bias_data_, conv_im2col->col_major_input_, conv_im2col->tmp_output_, task_id, + conv_param); + } else { + ConvFp32(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_, (float *)conv->bias_data_, + conv_im2col->col_major_input_, conv_im2col->tmp_output_, task_id, conv_param); + } + } else { + ConvFp32OutNC4HW4(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_, + (float *)conv->bias_data_, conv_im2col->col_major_input_, conv_im2col->tmp_output_, task_id, + conv_param); + } + return NNACL_OK; +} + +int ConvolutionIm2colAvxCompute(KernelBase *self) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_im2col); + + int ret = conv_im2col->init_tmp_buffer_(conv_im2col); + if (ret != NNACL_OK) { + ConvIm2ColBaseFreeTmpBuffer(conv_im2col); + return ret; + } + + float *output_addr = (float *)self->out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_addr); + if (!conv_im2col->output_need_align_) { + conv_im2col->tmp_output_ = output_addr; + } + + ret = ConvBaseRepackWeight(&conv_im2col->conv_); + if (ret != NNACL_OK) { + ConvIm2ColBaseFreeTmpBuffer(conv_im2col); + return ret; + } + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, ConvIm2ColBaseImpl, self, self->thread_nr_); + + if (conv_im2col->output_need_align_) { + PackNC8HW8AlignedToNC8HW8NotAlignedFp32(conv_im2col->tmp_output_, output_addr, conv_im2col->conv_.compute_.out_n_, + conv_im2col->conv_.compute_.out_w_ * conv_im2col->conv_.compute_.out_h_, + conv_im2col->conv_.compute_.out_c_); + } else { + conv_im2col->tmp_output_ = NULL; + } + + ConvIm2ColBaseFreeTmpBuffer(conv_im2col); + return ret; +} + +ConvolutionBaseStruct *CreateConvIm2ColAVX(ConvParameter *conv_param) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)malloc(sizeof(ConvolutionIm2ColBaseStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv_im2col); + memset(conv_im2col, 0, sizeof(ConvolutionIm2ColBaseStruct)); + + conv_im2col->init_tmp_buffer_ = ConvIm2ColAVXInitTmpBuffer; + + conv_im2col->conv_.malloc_weight_bias_ = ConvIm2ColBaseMallocWeightBiasData; + conv_im2col->conv_.init_global_variable_ = ConvIm2ColAVXInitGlobalVariable; + conv_im2col->conv_.run_impl_ = ConvIm2ColAVXRunImpl; + conv_im2col->conv_.pack_weight_ = ConvIm2ColBasePackWeight; + + conv_im2col->conv_.base_.Compute = ConvolutionIm2colAvxCompute; + conv_im2col->conv_.base_.Prepare = ConvolutionIm2colBasePrepare; + conv_im2col->conv_.base_.Resize = ConvolutionIm2colBaseResize; + conv_im2col->conv_.base_.Release = ConvolutionIm2colBaseRelease; + + return (ConvolutionBaseStruct *)conv_im2col; +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col_avx.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col_avx.h new file mode 100644 index 0000000000000000000000000000000000000000..cbaeed45028e0b90ef9849408da96fe69fdf0098 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col_avx.h @@ -0,0 +1,29 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CONVOLLUTION_IM2COL_AVX_H_ +#define NNACL_KERNEL_CONVOLLUTION_IM2COL_AVX_H_ + +#ifdef ENABLE_AVX +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/kernel/convolution_im2col_base.h" + +ConvolutionBaseStruct *CreateConvIm2ColAVX(ConvParameter *conv_param); +#endif +#endif // NNACL_KERNEL_CONVOLLUTION_IM2COL_AVX_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col_avx512.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col_avx512.c new file mode 100644 index 0000000000000000000000000000000000000000..ecf241252eb6766fd0715eea742de9a50a3e4f7c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col_avx512.c @@ -0,0 +1,146 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either convolutionress or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_AVX512 +#include "nnacl/kernel/convolution_im2col_avx512.h" +#include "nnacl/fp32/conv_im2col_avx512_fp32.h" +#include "nnacl/fp32/pack_fp32.h" +#include "nnacl/tensor_c.h" + +void ConvIm2ColAVX512InitGlobalVariable(ConvolutionBaseStruct *conv) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv; + conv_im2col->oc_tile_ = C16NUM; + conv_im2col->row_tile_ = + MSMIN(UP_DIV(conv_im2col->conv_.compute_.out_hw_, conv_im2col->conv_.base_.thread_nr_), C150NUM); + conv_im2col->row_major_to_col_nmajor_ = RowMajor2Col64Major; +} + +int ConvIm2ColAVX512InitTmpBuffer(struct ConvolutionIm2ColBaseStruct *conv_im2col) { + ExecEnv *env = conv_im2col->conv_.base_.env_; + NNACL_CHECK_NULL_RETURN_ERR(env); + ConvComputeParam *compute = &conv_im2col->conv_.compute_; + NNACL_CHECK_NULL_RETURN_ERR(compute); + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->kernel_hw_, compute->in_c_, NNACL_ERR); + int kernel_chw = compute->kernel_hw_ * compute->in_c_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(kernel_chw, conv_im2col->conv_.base_.thread_nr_, NNACL_ERR); + int total_kernel_chw = kernel_chw * conv_im2col->conv_.base_.thread_nr_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_kernel_chw, conv_im2col->row_tile_, NNACL_ERR); + size_t unit_size = total_kernel_chw * conv_im2col->row_tile_; + + if (conv_im2col->packed_input_ != NULL) { + env->Free(env->allocator_, conv_im2col->packed_input_); + conv_im2col->packed_input_ = NULL; + } + conv_im2col->packed_input_ = env->Alloc(env->allocator_, unit_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_im2col->packed_input_); + + conv_im2col->output_need_align_ = compute->out_c_ % conv_im2col->oc_tile_ != 0; + if (conv_im2col->output_need_align_) { + if (conv_im2col->tmp_output_ != NULL) { + env->Free(env->allocator_, conv_im2col->tmp_output_); + conv_im2col->tmp_output_ = NULL; + } + + // avx512 need to malloc dst aligned to C16NUM + int oc_algin = UP_ROUND(compute->out_c_, conv_im2col->oc_tile_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->out_n_, compute->out_hw_, NNACL_ERR); + int output_bhw = compute->out_n_ * compute->out_hw_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(output_bhw, oc_algin, NNACL_ERR); + size_t pack_output_size = output_bhw * compute->out_w_ * oc_algin; + + conv_im2col->tmp_output_ = env->Alloc(env->allocator_, pack_output_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_im2col->tmp_output_); + } + + return NNACL_OK; +} + +int ConvIm2ColAVX512RunImpl(struct ConvolutionBaseStruct *conv, int task_id) { + if (conv->out_format_ == Format_NC4HW4) { + return NNACL_CONVOLUTION_AVX512_UNSUPPORT_FORMAT; + } + + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv; + NNACL_CHECK_NULL_RETURN_ERR(conv_im2col); + ConvParameter *conv_param = (ConvParameter *)conv->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + float *ori_input_data = (float *)conv->base_.in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(ori_input_data); + + if (conv->use_batch_cut_flag_) { + ConvIm2ColAVX512Fp32CutByBatch(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_, + (float *)conv->bias_data_, conv_im2col->tmp_output_, task_id, conv_param, + conv_im2col->row_tile_); + } else { + ConvIm2ColAVX512Fp32(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_, + (float *)conv->bias_data_, conv_im2col->tmp_output_, task_id, conv_param, + conv_im2col->row_tile_); + } + return NNACL_OK; +} + +int ConvolutionIm2colAvx512Compute(KernelBase *self) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)self; + int ret = conv_im2col->init_tmp_buffer_(conv_im2col); + if (ret != NNACL_OK) { + ConvIm2ColBaseFreeTmpBuffer(conv_im2col); + return ret; + } + + float *output_addr = (float *)self->out_[OUTPUT_INDEX]->data_; + if (!conv_im2col->output_need_align_) { + conv_im2col->tmp_output_ = output_addr; + } + + ret = ConvBaseRepackWeight(&conv_im2col->conv_); + if (ret != NNACL_OK) { + ConvIm2ColBaseFreeTmpBuffer(conv_im2col); + return ret; + } + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, ConvIm2ColBaseImpl, self, self->thread_nr_); + + if (conv_im2col->output_need_align_) { + PackNHWCXToNHWCFp32(conv_im2col->tmp_output_, output_addr, conv_im2col->conv_.compute_.out_n_, + conv_im2col->conv_.compute_.out_hw_, conv_im2col->conv_.compute_.out_c_, conv_im2col->oc_tile_); + } else { + conv_im2col->tmp_output_ = NULL; + } + + ConvIm2ColBaseFreeTmpBuffer(conv_im2col); + return ret; +} + +ConvolutionBaseStruct *CreateConvIm2ColAVX512(ConvParameter *conv_param) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)malloc(sizeof(ConvolutionIm2ColBaseStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv_im2col); + memset(conv_im2col, 0, sizeof(ConvolutionIm2ColBaseStruct)); + + conv_im2col->init_tmp_buffer_ = ConvIm2ColAVX512InitTmpBuffer; + conv_im2col->conv_.malloc_weight_bias_ = ConvIm2ColBaseMallocWeightBiasData; + conv_im2col->conv_.init_global_variable_ = ConvIm2ColAVX512InitGlobalVariable; + conv_im2col->conv_.run_impl_ = ConvIm2ColAVX512RunImpl; + conv_im2col->conv_.pack_weight_ = ConvIm2ColBasePackWeight; + + conv_im2col->conv_.base_.Compute = ConvolutionIm2colAvx512Compute; + conv_im2col->conv_.base_.Prepare = ConvolutionIm2colBasePrepare; + conv_im2col->conv_.base_.Resize = ConvolutionIm2colBaseResize; + conv_im2col->conv_.base_.Release = ConvolutionIm2colBaseRelease; + + return (ConvolutionBaseStruct *)conv_im2col; +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col_avx512.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col_avx512.h new file mode 100644 index 0000000000000000000000000000000000000000..f111f6b72ec1de4e67469cee8584593385fc5954 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col_avx512.h @@ -0,0 +1,29 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CONVOLLUTION_IM2COL_AVX512_H_ +#define NNACL_KERNEL_CONVOLLUTION_IM2COL_AVX512_H_ + +#ifdef ENABLE_AVX512 +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/kernel/convolution_im2col_base.h" + +ConvolutionBaseStruct *CreateConvIm2ColAVX512(ConvParameter *conv_param); +#endif +#endif // NNACL_KERNEL_CONVOLLUTION_IM2COL_AVX512_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col_base.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col_base.c new file mode 100644 index 0000000000000000000000000000000000000000..3a319fbc04a3fa751d9107ff6553a88118f1be98 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col_base.c @@ -0,0 +1,246 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either convolutionress or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/convolution_im2col_base.h" +#include "nnacl/kernel/convolution_base.h" +#include "nnacl/fp32/pack_fp32.h" +#include "nnacl/fp32/conv_common_fp32.h" + +int ConvIm2ColBaseImpl(void *cdata, int task_id, float l, float r) { + ConvolutionBaseStruct *conv = (ConvolutionBaseStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(cdata); + return conv->run_impl_(conv, task_id); +} + +int ConvIm2ColBaseRunImpl(ConvolutionBaseStruct *conv, int task_id) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv; + + float *ori_input_data = (float *)conv->base_.in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(ori_input_data); + ConvParameter *conv_param = (ConvParameter *)conv->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + + if (conv->use_batch_cut_flag_) { + ConvFp32CutByBatch(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_, + (float *)conv->bias_data_, conv_im2col->col_major_input_, conv_im2col->tmp_output_, task_id, + conv_param); + } else { + ConvFp32(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_, (float *)conv->bias_data_, + conv_im2col->col_major_input_, conv_im2col->tmp_output_, task_id, conv_param); + } + return NNACL_OK; +} + +int ConvIm2ColBaseMallocWeightBiasData(ConvolutionBaseStruct *conv) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv; + NNACL_CHECK_NULL_RETURN_ERR(conv_im2col); + + size_t oc_block_num = UP_ROUND(conv->compute_.out_c_, conv_im2col->oc_tile_); + size_t pack_weight_size = oc_block_num * conv->compute_.in_c_ * conv->compute_.kernel_hw_; + if (!conv->base_.train_session_) { + NNACL_CHECK_MALLOC_SIZE(pack_weight_size * sizeof(float)); + conv->packed_weight_ = ConvBaseGetConvPackWeightData(conv, pack_weight_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->packed_weight_); + } + + if (conv->bias_data_ == NULL) { + NNACL_CHECK_MALLOC_SIZE(oc_block_num * sizeof(float)); + conv->bias_data_ = conv->base_.env_->Alloc(conv->base_.env_->allocator_, oc_block_num * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->bias_data_); + } + memset(conv->bias_data_, 0, oc_block_num * sizeof(float)); + return NNACL_OK; +} + +int ConvIm2ColBaseUpdateThreadNumProcess(KernelBase *self, int32_t kernel_type, int64_t per_unit_load_num, + int64_t per_unit_store_num, int64_t unit_num) { +#ifdef DYNAMIC_THREAD_DISTRIBUTE + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_im2col); + + if (conv_im2col->conv_.compute_.in_n_ % self->thread_nr_ == 0) { + conv_im2col->conv_.use_batch_cut_flag_ = true; + return NNACL_OK; + } else { + conv_im2col->conv_.use_batch_cut_flag_ = false; + } + + int update_thread = UP_DIV(UP_DIV(conv_im2col->conv_.compute_.out_hw_, conv_im2col->row_tile_), ConvMinBlock); + self->thread_nr_ = NNACL_MIN(self->thread_nr_, update_thread); +#else + self->thread_nr_ = self->thread_nr_ > 0 ? self->thread_nr_ : 1; +#endif + return NNACL_OK; +} + +void ConvIm2ColBaseFreeTmpBuffer(ConvolutionIm2ColBaseStruct *conv_im2col) { + ExecEnv *env = conv_im2col->conv_.base_.env_; + NNACL_CHECK_NULL_RETURN_VOID(env); + + if (conv_im2col->packed_input_ != NULL) { + env->Free(env->allocator_, conv_im2col->packed_input_); + conv_im2col->packed_input_ = NULL; + } + if (conv_im2col->col_major_input_ != NULL) { + env->Free(env->allocator_, conv_im2col->col_major_input_); + conv_im2col->col_major_input_ = NULL; + } + if (conv_im2col->output_need_align_ && conv_im2col->tmp_output_ != NULL) { + env->Free(env->allocator_, conv_im2col->tmp_output_); + conv_im2col->tmp_output_ = NULL; + conv_im2col->output_need_align_ = false; + } +} + +int ConvIm2ColBaseInitTmpBuffer(ConvolutionIm2ColBaseStruct *conv_im2col) { + ConvolutionBaseStruct *conv = (ConvolutionBaseStruct *)conv_im2col; + TensorC *out_tensor = conv_im2col->conv_.base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(out_tensor); + NNACL_CHECK_NULL_RETURN_ERR(out_tensor->data_); + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv->compute_.kernel_hw_, conv->compute_.in_c_, NNACL_ERR); + int kernel_chw = conv->compute_.kernel_hw_ * conv->compute_.in_c_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(kernel_chw, conv->base_.thread_nr_, NNACL_ERR); + int total_kernel_chw = kernel_chw * conv->base_.thread_nr_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_kernel_chw, conv_im2col->row_tile_, NNACL_ERR); + int unit_size = total_kernel_chw * conv_im2col->row_tile_; + + if (conv_im2col->packed_input_ != NULL) { + conv->base_.env_->Free(conv->base_.env_->allocator_, conv_im2col->packed_input_); + conv_im2col->packed_input_ = NULL; + } + conv_im2col->packed_input_ = + (float *)conv->base_.env_->Alloc(conv->base_.env_->allocator_, unit_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_im2col->packed_input_); + + if (conv_im2col->col_major_input_ != NULL) { + conv->base_.env_->Free(conv->base_.env_->allocator_, conv_im2col->col_major_input_); + conv_im2col->col_major_input_ = NULL; + } + conv_im2col->col_major_input_ = + (float *)conv->base_.env_->Alloc(conv->base_.env_->allocator_, unit_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_im2col->col_major_input_); + + return NNACL_OK; +} + +void ConvIm2ColBasePackWeight(ConvolutionBaseStruct *conv) { + void *origin_weight = (conv->base_.train_session_) ? conv->base_.in_[SECOND_INPUT]->data_ : conv->origin_weight_; + NNACL_CHECK_NULL_RETURN_VOID(origin_weight); + + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv; + NNACL_CHECK_NULL_RETURN_VOID(conv_im2col->row_major_to_col_nmajor_); + conv_im2col->row_major_to_col_nmajor_((float *)origin_weight, (float *)conv->packed_weight_, conv->compute_.out_c_, + conv->compute_.in_c_ * conv->compute_.kernel_hw_); +} + +void ConvIm2ColBaseInitGlobalVariable(ConvolutionBaseStruct *conv) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv; + conv_im2col->oc_tile_ = C8NUM; + conv_im2col->row_tile_ = C12NUM; + conv_im2col->row_major_to_col_nmajor_ = RowMajor2Col8Major; +} + +int ConvolutionIm2colBaseRelease(KernelBase *self) { + ConvolutionBaseStruct *conv = (ConvolutionBaseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv); + ConvBaseRelease(conv); + return NNACL_OK; +} + +int ConvolutionIm2colBaseCompute(KernelBase *self) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_im2col); + + int ret = conv_im2col->init_tmp_buffer_(conv_im2col); + if (ret != NNACL_OK) { + ConvIm2ColBaseFreeTmpBuffer(conv_im2col); + return ret; + } + + float *output_addr = self->out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_addr); + if (!conv_im2col->output_need_align_) { + conv_im2col->tmp_output_ = output_addr; + } + + ret = ConvBaseRepackWeight(&conv_im2col->conv_); + if (ret != NNACL_OK) { + ConvIm2ColBaseFreeTmpBuffer(conv_im2col); + return ret; + } + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, ConvIm2ColBaseImpl, self, self->thread_nr_); + ConvIm2ColBaseFreeTmpBuffer(conv_im2col); + return ret; +} + +int ConvolutionIm2colBaseResize(KernelBase *self) { + ConvolutionBaseStruct *conv = (ConvolutionBaseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv); + + int ret = ConvBaseCheckResizeValid(conv); + if (ret != NNACL_OK) { + return ret; + } + + ret = ConvBasePrepare(conv); + if (ret != NNACL_OK) { + return ret; + } + + return ConvIm2ColBaseUpdateThreadNumProcess(self, TC_PTYPE(PrimType_Conv2DFusion), 0, 0, 0); +} + +int ConvolutionIm2colBasePrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_im2col); + + conv_im2col->conv_.init_global_variable_(&conv_im2col->conv_); + + if (self->train_session_) { + int oc_block_num = UP_ROUND(conv_im2col->conv_.compute_.out_c_, conv_im2col->oc_tile_); + int kernel_chw = conv_im2col->conv_.compute_.in_c_ * conv_im2col->conv_.compute_.kernel_hw_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(oc_block_num, kernel_chw, NNACL_ERR); + int pack_weight_size = oc_block_num * kernel_chw; + self->work_size_ = pack_weight_size * sizeof(float); + } + + return ConvBaseInitConvWeightBias(&conv_im2col->conv_); +} + +ConvolutionBaseStruct *CreateConvIm2ColBase(ConvParameter *conv_param) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)malloc(sizeof(ConvolutionIm2ColBaseStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv_im2col); + memset(conv_im2col, 0, sizeof(ConvolutionIm2ColBaseStruct)); + + conv_im2col->init_tmp_buffer_ = ConvIm2ColBaseInitTmpBuffer; + + conv_im2col->conv_.malloc_weight_bias_ = ConvIm2ColBaseMallocWeightBiasData; + conv_im2col->conv_.run_impl_ = ConvIm2ColBaseRunImpl; + conv_im2col->conv_.pack_weight_ = ConvIm2ColBasePackWeight; + conv_im2col->conv_.init_global_variable_ = ConvIm2ColBaseInitGlobalVariable; + + conv_im2col->conv_.base_.Compute = ConvolutionIm2colBaseCompute; + conv_im2col->conv_.base_.Prepare = ConvolutionIm2colBasePrepare; + conv_im2col->conv_.base_.Resize = ConvolutionIm2colBaseResize; + conv_im2col->conv_.base_.Release = ConvolutionIm2colBaseRelease; + + return (ConvolutionBaseStruct *)conv_im2col; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col_base.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col_base.h new file mode 100644 index 0000000000000000000000000000000000000000..251f0bd243519272de11c3684bd1f0f45e227182 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col_base.h @@ -0,0 +1,52 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CONVOLLUTION_IM2COL_BASE_H_ +#define NNACL_KERNEL_CONVOLLUTION_IM2COL_BASE_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/kernel/convolution_base.h" + +typedef struct ConvolutionIm2ColBaseStruct { + ConvolutionBaseStruct conv_; + int oc_tile_; + int row_tile_; + + float *tmp_output_; + float *packed_input_; + float *col_major_input_; + bool output_need_align_; + + void (*row_major_to_col_nmajor_)(const float *src_ptr, float *dst_ptr, int row, int col); + int (*init_tmp_buffer_)(struct ConvolutionIm2ColBaseStruct *conv_im2col); +} ConvolutionIm2ColBaseStruct; + +int ConvIm2ColBaseMallocWeightBiasData(ConvolutionBaseStruct *conv); +int ConvIm2ColBaseInitTmpBuffer(ConvolutionIm2ColBaseStruct *conv_im2col); +int ConvIm2ColBaseImpl(void *cdata, int task_id, float l, float r); +void ConvIm2ColBaseFreeTmpBuffer(ConvolutionIm2ColBaseStruct *conv_im2col); +void ConvIm2ColBasePackWeight(ConvolutionBaseStruct *conv); +int ConvIm2ColBaseRunImpl(ConvolutionBaseStruct *conv, int task_id); +int ConvolutionIm2colBaseCompute(KernelBase *self); +int ConvolutionIm2colBasePrepare(KernelBase *self); +int ConvolutionIm2colBaseResize(KernelBase *self); +int ConvolutionIm2colBaseRelease(KernelBase *self); +ConvolutionBaseStruct *CreateConvIm2ColBase(ConvParameter *conv_param); + +#endif // NNACL_KERNEL_CONVOLLUTION_IM2COL_BASE_H_ diff --git a/mindspore-lite/tools/graph_kernel/converter/akg/gpu_kernel_builder.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col_sse.c similarity index 30% rename from mindspore-lite/tools/graph_kernel/converter/akg/gpu_kernel_builder.h rename to mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col_sse.c index e01dceac82578b6fbf05c0db0487a0d7ee1fcb3c..6926585a84fa5157cc2b46813fafbea56dbc06d6 100644 --- a/mindspore-lite/tools/graph_kernel/converter/akg/gpu_kernel_builder.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col_sse.c @@ -9,27 +9,39 @@ * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either convolutionress or implied. * See the License for the specific language governing permissions and * limitations under the License. */ +#ifdef ENABLE_SSE +#include "nnacl/kernel/convolution_im2col_sse.h" +#include "nnacl/fp32/pack_fp32.h" +#include "nnacl/tensor_c.h" -#ifndef MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_AKG_GPU_KERNEL_BUILDER_H_ -#define MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_AKG_GPU_KERNEL_BUILDER_H_ -#include -#include -#include -#include "tools/graph_kernel/converter/akg/akg_kernel_builder.h" +void ConvIm2ColSSEInitGlobalVariable(ConvolutionBaseStruct *conv) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv; + conv_im2col->oc_tile_ = C8NUM; + conv_im2col->row_tile_ = C4NUM; + conv_im2col->row_major_to_col_nmajor_ = RowMajor2Col8Major; +} -namespace mindspore::graphkernel { -class GpuKernelBuilder : public AkgKernelBuilder { - public: - bool CompileJsonsInAnfnodes(const AnfNodePtrList &node_list) override; - AnfNodePtr CreateCustomOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode) override; - std::vector ReadThreadBlockFromJson(const std::string &dir_name); +ConvolutionBaseStruct *CreateConvIm2ColSSE(ConvParameter *conv_param) { + ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)malloc(sizeof(ConvolutionIm2ColBaseStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv_im2col); + memset(conv_im2col, 0, sizeof(ConvolutionIm2ColBaseStruct)); - private: - std::map node_info_map_; -}; -} // namespace mindspore::graphkernel -#endif // MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_AKG_GPU_KERNEL_BUILDER_H_ + conv_im2col->init_tmp_buffer_ = ConvIm2ColBaseInitTmpBuffer; + + conv_im2col->conv_.malloc_weight_bias_ = ConvIm2ColBaseMallocWeightBiasData; + conv_im2col->conv_.run_impl_ = ConvIm2ColBaseRunImpl; + conv_im2col->conv_.pack_weight_ = ConvIm2ColBasePackWeight; + conv_im2col->conv_.init_global_variable_ = ConvIm2ColSSEInitGlobalVariable; + + conv_im2col->conv_.base_.Compute = ConvolutionIm2colBaseCompute; + conv_im2col->conv_.base_.Prepare = ConvolutionIm2colBasePrepare; + conv_im2col->conv_.base_.Resize = ConvolutionIm2colBaseResize; + conv_im2col->conv_.base_.Release = ConvolutionIm2colBaseRelease; + + return (ConvolutionBaseStruct *)conv_im2col; +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col_sse.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col_sse.h new file mode 100644 index 0000000000000000000000000000000000000000..08948bcb1f978eb334e39f464d3ac59924ab2491 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_im2col_sse.h @@ -0,0 +1,29 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CONVOLLUTION_IM2COL_SSE_H_ +#define NNACL_KERNEL_CONVOLLUTION_IM2COL_SSE_H_ + +#ifdef ENABLE_SSE +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/kernel/convolution_im2col_base.h" + +ConvolutionBaseStruct *CreateConvIm2ColSSE(ConvParameter *conv_param); +#endif +#endif // NNACL_KERNEL_CONVOLLUTION_IM2COL_SSE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_slidewindow.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_slidewindow.c new file mode 100644 index 0000000000000000000000000000000000000000..bde8068af07effd6bb7ead1e04f15ccdfa0c8732 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_slidewindow.c @@ -0,0 +1,227 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either convolutionress or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#if defined(ENABLE_AVX) || defined(ENABLE_ARM64) +#include "nnacl/kernel/convolution_slidewindow.h" +#include "nnacl/fp32/conv_depthwise_fp32.h" +#include "nnacl/fp32/pack_fp32.h" +#include "nnacl/tensor_c.h" +#include "nnacl/tensor_c_utils.h" + +int ConvSWInitTmpBuffer(ConvolutionSWStruct *conv_sw) { + TensorC *input_tensor = conv_sw->conv_.base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + float *input_data = (float *)input_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_data); + ConvComputeParam *compute = &conv_sw->conv_.compute_; + NNACL_CHECK_NULL_RETURN_ERR(compute); + + if (conv_sw->ic_res_ != 0 && compute->kernel_h_ == 1 && compute->kernel_w_ == 1) { + int ic_block_num = UP_DIV(compute->in_c_, conv_sw->in_tile_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->in_n_, compute->in_hw_, NNACL_ERR); + int input_bhw = compute->in_n_ * conv_sw->conv_.compute_.in_hw_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(input_bhw, ic_block_num * conv_sw->in_tile_, NNACL_ERR); + + conv_sw->input_data_ = (float *)conv_sw->conv_.base_.env_->Alloc( + conv_sw->conv_.base_.env_->allocator_, input_bhw * ic_block_num * conv_sw->in_tile_ * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_sw->input_data_); + + PackNHWCToNHWCXFp32(input_data, conv_sw->input_data_, compute->in_n_, compute->in_hw_, compute->in_c_, + conv_sw->oc_tile_); + } else { + conv_sw->input_data_ = input_data; + } + + float *out_data = (float *)conv_sw->conv_.base_.out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(out_data); + if (conv_sw->oc_res_ == 0) { // not need to malloc dst + conv_sw->output_data_ = out_data; + } else { // need to malloc dst to align block + int oc_block_num = UP_DIV(compute->out_c_, conv_sw->oc_tile_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->out_n_, compute->out_hw_, NNACL_ERR); + int output_bhw = compute->out_n_ * compute->out_hw_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(output_bhw, oc_block_num * conv_sw->oc_tile_, NNACL_ERR); + conv_sw->output_data_ = (float *)conv_sw->conv_.base_.env_->Alloc( + conv_sw->conv_.base_.env_->allocator_, output_bhw * oc_block_num * conv_sw->oc_tile_ * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_sw->output_data_); + } + + return NNACL_OK; +} + +void ConvSWFreeTmpBuffer(ConvolutionSWStruct *conv_sw) { + ConvParameter *conv_param = (ConvParameter *)conv_sw->conv_.base_.param_; + NNACL_CHECK_NULL_RETURN_VOID(conv_param); + + if (conv_sw->output_data_ != NULL && conv_sw->oc_res_ != 0) { + conv_sw->conv_.base_.env_->Free(conv_sw->conv_.base_.env_->allocator_, conv_sw->output_data_); + conv_sw->output_data_ = NULL; + } + if (conv_sw->input_data_ != NULL && conv_sw->ic_res_ != 0 && conv_param->kernel_w_ == 1 && + conv_param->kernel_h_ == 1) { + conv_sw->conv_.base_.env_->Free(conv_sw->conv_.base_.env_->allocator_, conv_sw->input_data_); + conv_sw->input_data_ = NULL; + } +} + +void ConvSWPackWeight(ConvolutionBaseStruct *conv) { + ConvolutionSWStruct *conv_sw = (ConvolutionSWStruct *)conv; + NNACL_CHECK_NULL_RETURN_VOID(conv_sw); + TensorC *filter_tensor = conv->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_VOID(filter_tensor); + + int input_channel = NNACLGetChannel(filter_tensor); + int output_channel = NNACLGetBatch(filter_tensor); + int kernel_h = NNACLGetHeight(filter_tensor); + int kernel_w = NNACLGetWidth(filter_tensor); + + int oc_block_num = UP_DIV(output_channel, conv_sw->oc_tile_); + void *origin_weight = (conv->base_.train_session_) ? filter_tensor->data_ : conv->origin_weight_; + NNACL_CHECK_NULL_RETURN_VOID(origin_weight); + PackNHWCToNXHWCXFp32(kernel_h, kernel_w, output_channel, oc_block_num, input_channel, (float *)conv->packed_weight_, + (float *)origin_weight); +} + +int ConvSWMallocWeightBiasData(ConvolutionBaseStruct *conv) { + ConvolutionSWStruct *conv_sw = (ConvolutionSWStruct *)conv; + NNACL_CHECK_NULL_RETURN_ERR(conv_sw); + ConvParameter *conv_param = (ConvParameter *)conv->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + TensorC *filter_tensor = conv->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(filter_tensor); + + int input_channel = NNACLGetChannel(filter_tensor); + int output_channel = NNACLGetBatch(filter_tensor); + int kernel_h = NNACLGetHeight(filter_tensor); + int kernel_w = NNACLGetWidth(filter_tensor); + + conv_param->input_channel_ = input_channel; + conv_param->output_channel_ = output_channel; + int kernel_plane = kernel_h * kernel_w; + int oc_block_num = UP_DIV(output_channel, conv_sw->oc_tile_); + int pack_weight_size = oc_block_num * conv_sw->oc_tile_ * input_channel * kernel_plane; + if (!conv_sw->conv_.base_.train_session_) { + conv_sw->conv_.packed_weight_ = ConvBaseGetConvPackWeightData(conv, pack_weight_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_sw->conv_.packed_weight_); + } + + if (conv_sw->conv_.base_.in_size_ == THREE_TENSOR) { + int malloc_size = oc_block_num * conv_sw->oc_tile_ * sizeof(float); + conv->bias_data_ = conv->base_.env_->Alloc(conv->base_.env_->allocator_, malloc_size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->bias_data_); + memset(conv->bias_data_, 0, oc_block_num * conv_sw->oc_tile_ * sizeof(float)); + } + return NNACL_OK; +} + +int ConvSWImpl(void *cdata, int task_id, float l, float r) { + ConvolutionSWStruct *conv_sw = (ConvolutionSWStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(conv_sw); + return conv_sw->conv_.run_impl_(&conv_sw->conv_, task_id); +} + +int ConvolutionSWCompute(KernelBase *self) { + ConvolutionSWStruct *conv_sw = (ConvolutionSWStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_sw); + + int ret = ConvSWInitTmpBuffer(conv_sw); + if (ret != NNACL_OK) { + ConvSWFreeTmpBuffer(conv_sw); + return ret; + } + + ret = ConvBaseRepackWeight(&conv_sw->conv_); + if (ret != NNACL_OK) { + ConvSWFreeTmpBuffer(conv_sw); + return ret; + } + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, ConvSWImpl, self, self->thread_nr_); + if (ret != NNACL_OK) { + ConvSWFreeTmpBuffer(conv_sw); + return ret; + } + + if (conv_sw->oc_res_ != 0) { + ConvParameter *conv_param = (ConvParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + float *out_data = (float *)self->out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(out_data); + PackNHWCXToNHWCFp32(conv_sw->output_data_, out_data, conv_param->output_batch_, + conv_param->output_h_ * conv_param->output_w_, conv_param->output_channel_, conv_sw->oc_tile_); + } + + ConvSWFreeTmpBuffer(conv_sw); + return NNACL_OK; +} + +int ConvolutionSWRelease(KernelBase *self) { + ConvolutionBaseStruct *conv = (ConvolutionBaseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv); + ConvBaseRelease(conv); + return NNACL_OK; +} + +int ConvolutionSWResize(KernelBase *self) { + ConvolutionSWStruct *conv_sw = (ConvolutionSWStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_sw); + ConvParameter *conv_param = (ConvParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + + int ret = ConvBaseCheckResizeValid(&conv_sw->conv_); + if (ret != NNACL_OK) { + return ret; + } + + ret = ConvBasePrepare(&conv_sw->conv_); + if (ret != NNACL_OK) { + return ret; + } + + InitSlidingParamConv(&conv_sw->sw_param_, conv_param, conv_sw->in_tile_, conv_sw->oc_tile_); + return NNACL_OK; +} + +int ConvolutionSWPrepare(KernelBase *self) { + ConvolutionSWStruct *conv_sw = (ConvolutionSWStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv_sw); + + conv_sw->conv_.init_global_variable_(&conv_sw->conv_); + + if (self->train_session_) { + TensorC *filter_tensor = self->in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(filter_tensor); + NNACL_CHECK_FALSE(filter_tensor->shape_size_ != DIMENSION_4D, NNACL_CONVOLUTION_WEIGHT_SHAPE_INVALID); + + int input_channel = NNACLGetChannel(filter_tensor); + int output_channel = NNACLGetBatch(filter_tensor); + int kernel_h = NNACLGetHeight(filter_tensor); + int kernel_w = NNACLGetWidth(filter_tensor); + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(kernel_h, kernel_w, NNACL_ERR); + int kernel_hw = kernel_h * kernel_w; + int oc_block_num = UP_DIV(output_channel, conv_sw->oc_tile_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(input_channel, kernel_hw, NNACL_ERR); + int kernel_chw = input_channel * kernel_hw; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(oc_block_num * conv_sw->oc_tile_, kernel_chw, NNACL_ERR); + int pack_weight_size = oc_block_num * conv_sw->oc_tile_ * kernel_chw; + + conv_sw->conv_.base_.work_size_ = pack_weight_size * sizeof(float); + } + + return ConvBaseInitConvWeightBias(&conv_sw->conv_); +} +#endif diff --git a/mindspore-lite/tools/graph_kernel/converter/akg/ascend_kernel_builder.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_slidewindow.h similarity index 42% rename from mindspore-lite/tools/graph_kernel/converter/akg/ascend_kernel_builder.h rename to mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_slidewindow.h index 36931e6809e020e17564041c1f2e830cceb31e4c..cc1b5f03ec306107b56633167009ee4c9c2dba60 100644 --- a/mindspore-lite/tools/graph_kernel/converter/akg/ascend_kernel_builder.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_slidewindow.h @@ -14,21 +14,33 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_AKG_ASCEND_KERNEL_BUILDER_H_ -#define MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_AKG_ASCEND_KERNEL_BUILDER_H_ -#include -#include -#include "tools/graph_kernel/converter/akg/akg_kernel_builder.h" +#ifndef NNACL_KERNEL_CONVOLLUTION_SLIDEWINDOW_H_ +#define NNACL_KERNEL_CONVOLLUTION_SLIDEWINDOW_H_ -namespace mindspore::graphkernel { -class AscendKernelBuilder : public AkgKernelBuilder { - public: - bool CompileJsonsInAnfnodes(const AnfNodePtrList &node_list) override; - AnfNodePtr CreateCustomOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode) override; +#if defined(ENABLE_AVX) || defined(ENABLE_ARM64) +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/kernel/convolution_base.h" +#include "nnacl/matmul_parameter.h" - private: - std::string dir_path_; - std::map node_info_map_; -}; -} // namespace mindspore::graphkernel -#endif // MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_AKG_ASCEND_KERNEL_BUILDER_H_ +typedef struct ConvolutionSWStruct { + ConvolutionBaseStruct conv_; + SlidingWindowParam sw_param_; + int oc_tile_; + int in_tile_; + int oc_res_; + int ic_res_; + float *output_data_; + float *input_data_; +} ConvolutionSWStruct; + +int ConvolutionSWPrepare(KernelBase *self); +int ConvolutionSWCompute(KernelBase *self); +int ConvolutionSWResize(KernelBase *self); +int ConvolutionSWRelease(KernelBase *self); +void ConvSWPackWeight(ConvolutionBaseStruct *conv); +int ConvSWMallocWeightBiasData(ConvolutionBaseStruct *conv); +#endif +#endif // NNACL_KERNEL_CONVOLLUTION_SLIDEWINDOW_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_sw_1x1.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_sw_1x1.c new file mode 100644 index 0000000000000000000000000000000000000000..3461fb6336b851821ee8afe7c9183908609ea96f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_sw_1x1.c @@ -0,0 +1,152 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either convolutionress or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_AVX +#include "nnacl/kernel/convolution_sw_1x1.h" +#include "nnacl/kernel/matmul_base.h" +#include "nnacl/kernel/matmul_create.h" + +int MatmulConv1x1Prelare(ConvolutionSW1x1Struct *sw_1x1) { + sw_1x1->matmul_->batch_ = 1; + sw_1x1->matmul_->a_batch_ = 1; + sw_1x1->matmul_->b_batch_ = 1; + + sw_1x1->matmul_->compute_.deep_ = sw_1x1->conv_.compute_.in_c_; + sw_1x1->matmul_->compute_.col_ = sw_1x1->conv_.compute_.out_c_; + sw_1x1->matmul_->compute_.row_ = sw_1x1->conv_.compute_.in_hw_ * sw_1x1->conv_.compute_.in_n_; + + return sw_1x1->matmul_->base_.Prepare(&sw_1x1->matmul_->base_); +} + +int MatmulConv1x1Resize(ConvolutionSW1x1Struct *sw_1x1) { + sw_1x1->matmul_->compute_.deep_ = sw_1x1->conv_.compute_.in_c_; + sw_1x1->matmul_->compute_.col_ = sw_1x1->conv_.compute_.out_c_; + sw_1x1->matmul_->compute_.row_ = sw_1x1->conv_.compute_.in_hw_ * sw_1x1->conv_.compute_.in_n_; + + MatmulBaseFreeBatchOffset(sw_1x1->matmul_); + int ret = MatmulBaseMallocBatchOffset(sw_1x1->matmul_); + if (ret != NNACL_OK) { + return ret; + } + + return sw_1x1->matmul_->base_.Resize(&sw_1x1->matmul_->base_); +} + +void UpdateTensorInfo(KernelBase *self, ConvolutionSW1x1Struct *sw_1x1) { + sw_1x1->matmul_->base_.in_ = self->in_; + sw_1x1->matmul_->base_.in_size_ = self->in_size_; + sw_1x1->matmul_->base_.out_ = self->out_; + sw_1x1->matmul_->base_.out_size_ = self->out_size_; + sw_1x1->matmul_->base_.workspace_ = self->workspace_; +} + +int ConvolutionSW1x1Compute(KernelBase *self) { + ConvolutionSW1x1Struct *sw_1x1 = (ConvolutionSW1x1Struct *)self; + NNACL_CHECK_NULL_RETURN_ERR(sw_1x1); + NNACL_CHECK_NULL_RETURN_ERR(sw_1x1->matmul_); + + UpdateTensorInfo(self, sw_1x1); + return sw_1x1->matmul_->base_.Compute(&sw_1x1->matmul_->base_); +} + +int ConvolutionSW1x1Resize(KernelBase *self) { + ConvolutionSW1x1Struct *sw_1x1 = (ConvolutionSW1x1Struct *)self; + NNACL_CHECK_NULL_RETURN_ERR(sw_1x1); + NNACL_CHECK_NULL_RETURN_ERR(sw_1x1->matmul_); + + UpdateTensorInfo(self, sw_1x1); + return MatmulConv1x1Resize(sw_1x1); +} + +int ConvolutionSW1x1Prepare(KernelBase *self) { + ConvolutionSW1x1Struct *sw_1x1 = (ConvolutionSW1x1Struct *)self; + NNACL_CHECK_NULL_RETURN_ERR(sw_1x1); + NNACL_CHECK_NULL_RETURN_ERR(sw_1x1->matmul_); + + sw_1x1->matmul_->matrix_b_.origin_ptr_ = sw_1x1->conv_.origin_weight_; + sw_1x1->matmul_->matrix_b_.origin_need_free_ = false; + sw_1x1->matmul_->matrix_c_.origin_ptr_ = sw_1x1->conv_.origin_bias_; + sw_1x1->matmul_->matrix_c_.origin_need_free_ = false; + + sw_1x1->matmul_->infer_shape_ = sw_1x1->conv_.infershape_done_; + sw_1x1->matmul_->base_.train_session_ = self->train_session_; + sw_1x1->matmul_->base_.thread_nr_ = self->thread_nr_; + sw_1x1->matmul_->base_.env_ = self->env_; + + UpdateTensorInfo(self, sw_1x1); + return MatmulConv1x1Prelare(sw_1x1); +} + +int ConvolutionSW1x1Release(KernelBase *self) { + ConvolutionSW1x1Struct *sw_1x1 = (ConvolutionSW1x1Struct *)self; + NNACL_CHECK_NULL_RETURN_ERR(sw_1x1); + + if (sw_1x1->matmul_ != NULL) { + sw_1x1->matmul_->matrix_b_.origin_ptr_ = NULL; + sw_1x1->matmul_->matrix_c_.origin_ptr_ = NULL; + + (void)sw_1x1->matmul_->base_.Release(&sw_1x1->matmul_->base_); + + if (sw_1x1->matmul_->base_.param_ != NULL) { + free(sw_1x1->matmul_->base_.param_); + sw_1x1->matmul_->base_.param_ = NULL; + } + + free(sw_1x1->matmul_); + sw_1x1->matmul_ = NULL; + } + + ConvBaseRelease(&sw_1x1->conv_); + return NNACL_OK; +} + +ConvolutionBaseStruct *CreateConvolutionSW1x1(ConvParameter *conv_param, bool input_const, bool weight_const) { + ConvolutionSW1x1Struct *sw_1x1 = (ConvolutionSW1x1Struct *)malloc(sizeof(ConvolutionSW1x1Struct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(sw_1x1); + memset(sw_1x1, 0, sizeof(ConvolutionSW1x1Struct)); + + sw_1x1->conv_.is_sharing_pack_ = false; + sw_1x1->conv_.base_.Compute = ConvolutionSW1x1Compute; + sw_1x1->conv_.base_.Resize = ConvolutionSW1x1Resize; + sw_1x1->conv_.base_.Prepare = ConvolutionSW1x1Prepare; + sw_1x1->conv_.base_.Release = ConvolutionSW1x1Release; + + OpParameter *param = (OpParameter *)malloc(sizeof(MatMulParameter)); + if (param == NULL) { + free(sw_1x1); + return NULL; + } + MatMulParameter *matmul_param = (MatMulParameter *)param; + matmul_param->op_parameter_ = conv_param->op_parameter_; + matmul_param->act_type_ = conv_param->act_type_; + matmul_param->a_transpose_ = false; + matmul_param->b_transpose_ = true; + + KernelBase *matmul = CreateMatmulKernel(); + if (matmul == NULL) { + free(sw_1x1); + free(param); + return NULL; + } + + ((MatmulStruct *)matmul)->is_sharing_pack_ = false; + ((MatmulStruct *)matmul)->a_const_ = input_const; + ((MatmulStruct *)matmul)->b_const_ = weight_const; + ((MatmulStruct *)matmul)->base_.param_ = param; + sw_1x1->matmul_ = (MatmulStruct *)matmul; + return (ConvolutionBaseStruct *)sw_1x1; +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_sw_1x1.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_sw_1x1.h new file mode 100644 index 0000000000000000000000000000000000000000..90e9d3e2d03f0d9f93450804c12e64377ed1bc7f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_sw_1x1.h @@ -0,0 +1,36 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CONVOLLUTION_SW_1X1_H_ +#define NNACL_KERNEL_CONVOLLUTION_SW_1X1_H_ + +#ifdef ENABLE_AVX +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/kernel/convolution_base.h" +#include "nnacl/matmul_parameter.h" +#include "nnacl/kernel/matmul_struct.h" + +typedef struct ConvolutionSW1x1Struct { + ConvolutionBaseStruct conv_; + MatmulStruct *matmul_; +} ConvolutionSW1x1Struct; + +ConvolutionBaseStruct *CreateConvolutionSW1x1(ConvParameter *conv_param, bool input_const, bool weight_const); +#endif +#endif // NNACL_KERNEL_CONVOLLUTION_SW_1X1_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_sw_arm64.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_sw_arm64.c new file mode 100644 index 0000000000000000000000000000000000000000..283945b29dd3a21a7d0e795958b12efd91d73193 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_sw_arm64.c @@ -0,0 +1,59 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either convolutionress or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_ARM64 +#include "nnacl/kernel/convolution_sw_arm64.h" +#include "nnacl/kernel/convolution_slidewindow.h" +#include "nnacl/fp32/conv_sw_arm64_fp32.h" + +void ConvSWARM64InitGlobalVariable(ConvolutionBaseStruct *conv) { + ConvolutionSWStruct *conv_sw = (ConvolutionSWStruct *)conv; + NNACL_CHECK_NULL_RETURN_VOID(conv_sw); + ConvParameter *conv_param = (ConvParameter *)conv->base_.param_; + NNACL_CHECK_NULL_RETURN_VOID(conv_param); + + conv_sw->oc_tile_ = C8NUM; + conv_sw->oc_res_ = conv_param->output_channel_ % conv_sw->oc_tile_; +} + +int ConvSWARM64RunImpl(ConvolutionBaseStruct *conv, int task_id) { + ConvolutionSWStruct *conv_sw = (ConvolutionSWStruct *)conv; + NNACL_CHECK_NULL_RETURN_ERR(conv_sw); + ConvParameter *conv_param = (ConvParameter *)conv->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + ConvSWArm64Fp32(conv_sw->input_data_, (float *)conv->packed_weight_, (float *)conv->bias_data_, conv_sw->output_data_, + task_id, conv_param, &conv_sw->sw_param_); + return NNACL_OK; +} + +ConvolutionBaseStruct *CreateConvolutionSWARM64(ConvParameter *conv_param) { + ConvolutionSWStruct *sw = (ConvolutionSWStruct *)malloc(sizeof(ConvolutionSWStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(sw); + memset(sw, 0, sizeof(ConvolutionSWStruct)); + + sw->conv_.run_impl_ = ConvSWARM64RunImpl; + sw->conv_.init_global_variable_ = ConvSWARM64InitGlobalVariable; + sw->conv_.pack_weight_ = ConvSWPackWeight; + sw->conv_.malloc_weight_bias_ = ConvSWMallocWeightBiasData; + + sw->conv_.base_.Compute = ConvolutionSWCompute; + sw->conv_.base_.Prepare = ConvolutionSWPrepare; + sw->conv_.base_.Release = ConvolutionSWRelease; + sw->conv_.base_.Resize = ConvolutionSWResize; + + return (ConvolutionBaseStruct *)sw; +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_sw_arm64.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_sw_arm64.h new file mode 100644 index 0000000000000000000000000000000000000000..31af4bcaa5413ac4bd5fb373c89dda795d698f2b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_sw_arm64.h @@ -0,0 +1,28 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CONVOLLUTION_SW_ARM64_H_ +#define NNACL_KERNEL_CONVOLLUTION_SW_ARM64_H_ +#ifdef ENABLE_ARM64 +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/kernel/convolution_base.h" + +ConvolutionBaseStruct *CreateConvolutionSWARM64(ConvParameter *conv_param); +#endif +#endif // NNACL_KERNEL_CONVOLLUTION_SW_ARM64_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_sw_avx.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_sw_avx.c new file mode 100644 index 0000000000000000000000000000000000000000..b9d66d1abfc42109933b4d5c2b15b07852e1a3f0 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_sw_avx.c @@ -0,0 +1,71 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either convolutionress or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_AVX +#include "nnacl/kernel/convolution_sw_avx.h" +#include "nnacl/kernel/convolution_slidewindow.h" +#include "nnacl/fp32/conv_1x1_avx_fp32.h" +#include "nnacl/fp32/conv_sw_avx_fp32.h" + +void ConvSWAVXInitGlobalVariable(ConvolutionBaseStruct *conv) { + ConvolutionSWStruct *conv_sw = (ConvolutionSWStruct *)conv; + NNACL_CHECK_NULL_RETURN_VOID(conv_sw); + ConvParameter *conv_param = (ConvParameter *)conv->base_.param_; + NNACL_CHECK_NULL_RETURN_VOID(conv_param); + + conv_sw->oc_tile_ = C8NUM; + conv_sw->oc_res_ = conv_param->output_channel_ % conv_sw->oc_tile_; + if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) { + // 1x1 conv is aligned to C8NUM + conv_sw->in_tile_ = C8NUM; + conv_sw->ic_res_ = conv_param->input_channel_ % conv_sw->in_tile_; + } +} + +int ConvSWAVXRunImpl(ConvolutionBaseStruct *conv, int task_id) { + ConvolutionSWStruct *conv_sw = (ConvolutionSWStruct *)conv; + NNACL_CHECK_NULL_RETURN_ERR(conv_sw); + ConvParameter *conv_param = (ConvParameter *)conv->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + + if (conv_param->kernel_w_ == 1 && conv_param->kernel_h_ == 1) { + Conv1x1SWAVXFp32(conv_sw->input_data_, (float *)conv->packed_weight_, (float *)conv->bias_data_, + conv_sw->output_data_, task_id, conv_param, &conv_sw->sw_param_); + } else { + ConvSWAVXFp32(conv_sw->input_data_, (float *)conv->packed_weight_, (float *)conv->bias_data_, conv_sw->output_data_, + task_id, conv_param, &conv_sw->sw_param_); + } + return NNACL_OK; +} + +ConvolutionBaseStruct *CreateConvolutionSWAVX(ConvParameter *conv_param) { + ConvolutionSWStruct *sw = (ConvolutionSWStruct *)malloc(sizeof(ConvolutionSWStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(sw); + memset(sw, 0, sizeof(ConvolutionSWStruct)); + + sw->conv_.run_impl_ = ConvSWAVXRunImpl; + sw->conv_.init_global_variable_ = ConvSWAVXInitGlobalVariable; + sw->conv_.pack_weight_ = ConvSWPackWeight; + sw->conv_.malloc_weight_bias_ = ConvSWMallocWeightBiasData; + + sw->conv_.base_.Compute = ConvolutionSWCompute; + sw->conv_.base_.Prepare = ConvolutionSWPrepare; + sw->conv_.base_.Release = ConvolutionSWRelease; + sw->conv_.base_.Resize = ConvolutionSWResize; + + return (ConvolutionBaseStruct *)sw; +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_sw_avx.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_sw_avx.h new file mode 100644 index 0000000000000000000000000000000000000000..22dc3f9b7aa6b33c97c93a08eea0c2dd328a45aa --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_sw_avx.h @@ -0,0 +1,28 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CONVOLLUTION_SW_AVX_H_ +#define NNACL_KERNEL_CONVOLLUTION_SW_AVX_H_ +#ifdef ENABLE_AVX +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/kernel/convolution_base.h" + +ConvolutionBaseStruct *CreateConvolutionSWAVX(ConvParameter *conv_param); +#endif +#endif // NNACL_KERNEL_CONVOLLUTION_SW_AVX_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_winograd.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_winograd.c new file mode 100644 index 0000000000000000000000000000000000000000..c0d226559b9f355aae0f34b10ff9ee410992fa69 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_winograd.c @@ -0,0 +1,76 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either convolutionress or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/convolution_winograd.h" +#include "nnacl/kernel/convolution_winograd_base.h" +#ifdef ENABLE_AVX +#include "nnacl/kernel/convolution_winograd_avx.h" +#endif +#ifdef ENABLE_SSE +#include "nnacl/kernel/convolution_winograd_sse.h" +#endif +#ifdef ENABLE_ARM64 +#include "nnacl/kernel/convolution_winograd_arm64.h" +#endif +#ifdef ENABLE_ARM32 +#include "nnacl/kernel/convolution_winograd_arm32.h" +#endif + +ConvolutionWinogradBaseStruct *SelectConvolutionWinograd(ConvParameter *conv_param) { + ConvolutionWinogradBaseStruct *kernel = NULL; + +#ifdef ENABLE_AVX + kernel = CreateConvWinogradAVX(conv_param); + if (kernel != NULL) { + return kernel; + } +#endif + +#ifdef ENABLE_SSE + kernel = CreateConvWinogradSSE(conv_param); + if (kernel != NULL) { + return kernel; + } +#endif + +#ifdef ENABLE_ARM64 + kernel = CreateConvWinogradARM64(conv_param); + if (kernel != NULL) { + return kernel; + } +#endif + +#ifdef ENABLE_ARM32 + kernel = CreateConvWinogradARM32(conv_param); + if (kernel != NULL) { + return kernel; + } +#endif + + kernel = CreateConvWinogradBase(conv_param); + return kernel; +} + +ConvolutionBaseStruct *CreateConvolutionWinograd(ConvParameter *conv_param, int out_unit) { + ConvolutionWinogradBaseStruct *kernel = SelectConvolutionWinograd(conv_param); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(kernel); + + kernel->output_unit_ = out_unit; + kernel->conv_.malloc_weight_bias_ = ConvWinoBaseMallocWeightBiasData; + kernel->conv_.run_impl_ = ConvWinoBaseRunImpl; + kernel->conv_.pack_weight_ = ConvWinoBasePackWeight; + return (ConvolutionBaseStruct *)kernel; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_winograd.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_winograd.h new file mode 100644 index 0000000000000000000000000000000000000000..9d2356b29b707af204adcb707a91ff0fc729adfb --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_winograd.h @@ -0,0 +1,32 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CONVOLLUTION_WINOGRAD_H_ +#define NNACL_KERNEL_CONVOLLUTION_WINOGRAD_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/kernel/convolution_base.h" + +typedef struct ConvolutionWinogradStruct { + ConvolutionBaseStruct conv_; +} ConvolutionWinogradStruct; + +ConvolutionBaseStruct *CreateConvolutionWinograd(ConvParameter *conv_param, int out_uint); + +#endif // NNACL_KERNEL_CONVOLLUTION_WINOGRAD_H_ diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_impl/query.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_winograd_arm32.c similarity index 36% rename from mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_impl/query.h rename to mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_winograd_arm32.c index a208adc01990dd8ec24387db62451a7238ff719a..f806e58fa8a39240e6a6efbf487dfeefca9aeeca 100644 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_impl/query.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_winograd_arm32.c @@ -14,24 +14,29 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_ASCEND_NATIVE_IMPL_QUERY_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_ASCEND_NATIVE_IMPL_QUERY_H_ +#ifdef ENABLE_ARM32 +#include "nnacl/kernel/convolution_winograd_arm32.h" -#include -#include "extendrt/delegate/ascend_native/ascend_native_impl/encoder.h" +void ConvWinoARM32InitGlobalVariable(ConvolutionBaseStruct *conv) { + ConvolutionWinogradBaseStruct *winograd = (ConvolutionWinogradBaseStruct *)conv; + winograd->oc_block_ = C8NUM; + winograd->tmp_data_tile_ = C4NUM; + winograd->tile_num_ = C12NUM; +} -namespace mindspore::ascend_native { +ConvolutionWinogradBaseStruct *CreateConvWinogradARM32(ConvParameter *conv_param) { + ConvolutionWinogradBaseStruct *winograd = + (ConvolutionWinogradBaseStruct *)malloc(sizeof(ConvolutionWinogradBaseStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(winograd); + memset(winograd, 0, sizeof(ConvolutionWinogradBaseStruct)); -class AscendNativeQuery : public AscendNativeEncoder { - public: - AscendNativeQuery() : AscendNativeEncoder(false) {} - virtual ~AscendNativeQuery() {} - void QEmbedding(std::vector *ins, std::vector *outs, void *ws, EncoderParams *p, void *q); - void Attn(std::vector *ins, std::vector *outs, void *ws, EncoderParams *p, void *q) override; - void HeadPangu(std::vector *ins, std::vector *outs, void *ws, EncoderParams *p, void *q); - size_t GetWorkspaceSize(const EncoderParams &p) override; - void Forward(std::vector *ins, std::vector *outs, void *ws, EncoderParams *p, void *q) override; -}; + winograd->config_input_output_ = ConvWinoBaseConfigInputOutput; + winograd->conv_.init_global_variable_ = ConvWinoARM32InitGlobalVariable; -} // namespace mindspore::ascend_native -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_ASCEND_NATIVE_IMPL_QUERY_H_ + winograd->conv_.base_.Prepare = ConvolutionWinogradBasePrepare; + winograd->conv_.base_.Resize = ConvolutionWinogradBaseResize; + winograd->conv_.base_.Release = ConvolutionWinogradBaseRelease; + winograd->conv_.base_.Compute = ConvolutionWinogradBaseCompute; + return winograd; +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_winograd_arm32.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_winograd_arm32.h new file mode 100644 index 0000000000000000000000000000000000000000..9cfa4fe26a489e645d8eb531cd8c426a2275e2ac --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_winograd_arm32.h @@ -0,0 +1,30 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CONVOLLUTION_WINOGRAD_ARM32_H_ +#define NNACL_KERNEL_CONVOLLUTION_WINOGRAD_ARM32_H_ + +#ifdef ENABLE_ARM32 +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/kernel/convolution_winograd_base.h" + +ConvolutionWinogradBaseStruct *CreateConvWinogradARM32(ConvParameter *conv_param); +#endif + +#endif // NNACL_KERNEL_CONVOLLUTION_WINOGRAD_ARM32_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_winograd_arm64.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_winograd_arm64.c new file mode 100644 index 0000000000000000000000000000000000000000..7c0b323c8697c4fd2fab2afe6697e10e093e67a5 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_winograd_arm64.c @@ -0,0 +1,60 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either convolutionress or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_ARM64 +#include "nnacl/kernel/convolution_winograd_arm64.h" +#include "nnacl/kernel/convolution_winograd_base.h" + +void ConvWinoARM64InitGlobalVariable(ConvolutionBaseStruct *conv) { + ConvolutionWinogradBaseStruct *winograd = (ConvolutionWinogradBaseStruct *)conv; + winograd->oc_block_ = C8NUM; + winograd->tmp_data_tile_ = C4NUM; + winograd->tile_num_ = C12NUM; +} + +int ConvWinoARM64ConfigInputOutput(ConvolutionWinogradBaseStruct *winograd) { + winograd->transfer_functions_.in_func_ = GetInputTransFunc(winograd->input_unit_); + NNACL_CHECK_NULL_RETURN_ERR(winograd->transfer_functions_.in_func_); + + winograd->transfer_functions_.in_step_func_ = GetInputTransStepFunc(winograd->input_unit_); + NNACL_CHECK_NULL_RETURN_ERR(winograd->transfer_functions_.in_func_); + + winograd->transfer_functions_.in_pack_func_ = GetInputTransPackFunc(winograd->input_unit_); + NNACL_CHECK_NULL_RETURN_ERR(winograd->transfer_functions_.in_func_); + + ActType act_type = ((ConvParameter *)winograd->conv_.base_.param_)->act_type_; + winograd->transfer_functions_.out_func_ = GetOutputTransFunc(winograd->input_unit_, winograd->output_unit_, act_type); + NNACL_CHECK_NULL_RETURN_ERR(winograd->transfer_functions_.in_func_); + + return NNACL_OK; +} + +ConvolutionWinogradBaseStruct *CreateConvWinogradARM64(ConvParameter *conv_param) { + ConvolutionWinogradBaseStruct *winograd = + (ConvolutionWinogradBaseStruct *)malloc(sizeof(ConvolutionWinogradBaseStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(winograd); + memset(winograd, 0, sizeof(ConvolutionWinogradBaseStruct)); + + winograd->config_input_output_ = ConvWinoARM64ConfigInputOutput; + winograd->conv_.init_global_variable_ = ConvWinoARM64InitGlobalVariable; + + winograd->conv_.base_.Prepare = ConvolutionWinogradBasePrepare; + winograd->conv_.base_.Resize = ConvolutionWinogradBaseResize; + winograd->conv_.base_.Release = ConvolutionWinogradBaseRelease; + winograd->conv_.base_.Compute = ConvolutionWinogradBaseCompute; + return winograd; +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_winograd_arm64.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_winograd_arm64.h new file mode 100644 index 0000000000000000000000000000000000000000..3bc6da866ed5634db042671951b29ea8b4e53494 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_winograd_arm64.h @@ -0,0 +1,30 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CONVOLLUTION_WINOGRAD_ARM64_H_ +#define NNACL_KERNEL_CONVOLLUTION_WINOGRAD_ARM64_H_ + +#ifdef ENABLE_ARM64 +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/kernel/convolution_winograd_base.h" + +ConvolutionWinogradBaseStruct *CreateConvWinogradARM64(ConvParameter *conv_param); +#endif + +#endif // NNACL_KERNEL_CONVOLLUTION_WINOGRAD_ARM64_H_ diff --git a/mindspore-lite/src/extendrt/kernel/nnacl/nnacl_lib.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_winograd_avx.c similarity index 31% rename from mindspore-lite/src/extendrt/kernel/nnacl/nnacl_lib.h rename to mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_winograd_avx.c index d94ff993ed6e0498027ca5bcc24d6695cfe58f6f..ce54664beb0839437dd777a0f9e1bdd460c9c105 100644 --- a/mindspore-lite/src/extendrt/kernel/nnacl/nnacl_lib.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_winograd_avx.c @@ -9,35 +9,35 @@ * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either convolutionress or implied. * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_NNACL_NNACL_LIB_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_NNACL_NNACL_LIB_H_ +#ifdef ENABLE_AVX +#include "nnacl/kernel/convolution_winograd_avx.h" +#include "nnacl/kernel/convolution_winograd_base.h" -#include -#include -#include "src/extendrt/kernel/kernel_lib.h" -#include "src/litert/kernel/cpu/nnacl/nnacl_kernel.h" -#include "src/common/tensor_util.h" -#include "src/extendrt/kernel/kernel_spec_infos.h" +void ConvWinoAVXInitGlobalVariable(ConvolutionBaseStruct *conv) { + ConvolutionWinogradBaseStruct *winograd = (ConvolutionWinogradBaseStruct *)conv; + winograd->oc_block_ = C16NUM; + winograd->tmp_data_tile_ = C8NUM; + winograd->tile_num_ = C12NUM; +} -namespace mindspore { -namespace kernel { -class NNACLLib : public KernelLib { - public: - NNACLLib() : KernelLib(kNNACLLibName, kBackendCPU) {} +ConvolutionWinogradBaseStruct *CreateConvWinogradAVX(ConvParameter *conv_param) { + ConvolutionWinogradBaseStruct *winograd = + (ConvolutionWinogradBaseStruct *)malloc(sizeof(ConvolutionWinogradBaseStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(winograd); + memset(winograd, 0, sizeof(ConvolutionWinogradBaseStruct)); - bool Support(const PrimitiveType &op_type, const KernelAttr &dt, const std::string &backend, - const Format &format = DEFAULT_FORMAT) const override; - BaseKernel *CreateKernel(const KernelSpec &spec, const std::vector &inputs, - const std::vector &outputs, const InferContext *ctx) const override; + winograd->config_input_output_ = ConvWinoBaseConfigInputOutput; + winograd->conv_.init_global_variable_ = ConvWinoAVXInitGlobalVariable; - InferKernel *CreateKernelExec(const KernelSpec &spec, const std::vector &inputs, - const std::vector &outputs, const InferContext *ctx) const override; -}; -} // namespace kernel -} // namespace mindspore + winograd->conv_.base_.Prepare = ConvolutionWinogradBasePrepare; + winograd->conv_.base_.Resize = ConvolutionWinogradBaseResize; + winograd->conv_.base_.Release = ConvolutionWinogradBaseRelease; + winograd->conv_.base_.Compute = ConvolutionWinogradBaseCompute; + return winograd; +} #endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_winograd_avx.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_winograd_avx.h new file mode 100644 index 0000000000000000000000000000000000000000..f44376ceb0bbc5eaa5882e8425d6c5570b8dddd9 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_winograd_avx.h @@ -0,0 +1,30 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CONVOLLUTION_WINOGRAD_AVX_H_ +#define NNACL_KERNEL_CONVOLLUTION_WINOGRAD_AVX_H_ + +#ifdef ENABLE_AVX +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/kernel/convolution_winograd_base.h" + +ConvolutionWinogradBaseStruct *CreateConvWinogradAVX(ConvParameter *conv_param); +#endif + +#endif // NNACL_KERNEL_CONVOLLUTION_WINOGRAD_AVX_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_winograd_base.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_winograd_base.c new file mode 100644 index 0000000000000000000000000000000000000000..343fb2b8a7a7667bf64f116a799e36e9c981a4cf --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_winograd_base.c @@ -0,0 +1,320 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either convolutionress or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/convolution_winograd_base.h" +#include "nnacl/base/minimal_filtering_generator.h" +#include "nnacl/fp32/winograd_transform.h" +#include "nnacl/fp32/conv_winograd_fp32.h" + +int ConvWinoBaseMallocWeightBiasData(ConvolutionBaseStruct *conv) { + ConvolutionWinogradBaseStruct *winograd = (ConvolutionWinogradBaseStruct *)conv; + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(winograd); + + // set data + size_t trans_matrix_data_size = winograd->input_unit_ * winograd->input_unit_ * conv->compute_.in_c_ * + UP_ROUND(conv->compute_.out_c_, winograd->oc_block_) * sizeof(float); + if (!conv->base_.train_session_) { + if (conv->packed_weight_ == NULL) { + NNACL_CHECK_MALLOC_SIZE(trans_matrix_data_size); + conv->packed_weight_ = ConvBaseGetConvPackWeightData(conv, trans_matrix_data_size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->packed_weight_); + } + } + + float matrix_a[CONVOLUTION_WINOGRAD_MATRIX_SIZE]; + float matrix_at[CONVOLUTION_WINOGRAD_MATRIX_SIZE]; + float matrix_b[CONVOLUTION_WINOGRAD_MATRIX_SIZE]; + float matrix_bt[CONVOLUTION_WINOGRAD_MATRIX_SIZE]; + float coef = 1.0f; + if (winograd->input_unit_ == CONVOLUTION_WINOGRAD_INPUT_UNIT_SIZE) { + coef = 0.5f; + } + int ret = CookToomFilter(matrix_a, matrix_at, matrix_b, matrix_bt, winograd->matrix_g_, winograd->matrix_gt_, coef, + winograd->output_unit_, winograd->kernel_unit_); + if (ret != NNACL_OK) { + return ret; + } + + // init bias + size_t new_bias_size = UP_ROUND(conv->compute_.out_c_, C4NUM) * sizeof(float); + if (conv->bias_data_ == NULL) { + NNACL_CHECK_MALLOC_SIZE(new_bias_size); + conv->bias_data_ = conv->base_.env_->Alloc(conv->base_.env_->allocator_, new_bias_size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->bias_data_); + } + memset(conv->bias_data_, 0, new_bias_size); + return NNACL_OK; +} + +void ConvWinoBaseFreeTmpBuffer(ConvolutionWinogradBaseStruct *winograd) { + ExecEnv *env = winograd->conv_.base_.env_; + NNACL_CHECK_NULL_RETURN_VOID(env); + + if (winograd->trans_input_ != NULL) { + env->Free(env->allocator_, winograd->trans_input_); + winograd->trans_input_ = NULL; + } + if (winograd->tmp_data_ != NULL) { + env->Free(env->allocator_, winograd->tmp_data_); + winograd->tmp_data_ = NULL; + } + if (winograd->gemm_out_ != NULL) { + env->Free(env->allocator_, winograd->gemm_out_); + winograd->gemm_out_ = NULL; + } + if (winograd->col_buffer_ != NULL) { + env->Free(env->allocator_, winograd->col_buffer_); + winograd->col_buffer_ = NULL; + } + if (winograd->opt_input_trans_ != NULL) { + env->Free(env->allocator_, winograd->opt_input_trans_); + winograd->opt_input_trans_ = NULL; + } +} + +void ConvWinoBaseInitGlobalVariable(ConvolutionBaseStruct *conv) { + ConvolutionWinogradBaseStruct *winograd = (ConvolutionWinogradBaseStruct *)conv; + winograd->oc_block_ = C8NUM; + winograd->tmp_data_tile_ = C4NUM; + winograd->tile_num_ = C12NUM; +} + +int ConvWinoBaseWinogradFilterTransform(ConvolutionWinogradBaseStruct *winograd, const float *weight_data) { + NNACL_CHECK_ZERO_RETURN_ERR(winograd->oc_block_); + return WinogradWeightTransform(weight_data, (float *)winograd->conv_.packed_weight_, winograd->matrix_g_, + winograd->matrix_gt_, winograd->oc_block_, winograd->input_unit_, + winograd->kernel_unit_, winograd->conv_.compute_.in_c_, + winograd->conv_.compute_.out_c_, true); +} + +void ConvWinoBasePackWeight(ConvolutionBaseStruct *conv) { + ConvolutionWinogradBaseStruct *winograd = (ConvolutionWinogradBaseStruct *)conv; + NNACL_CHECK_NULL_RETURN_VOID(winograd); + TensorC *weight_tensor = conv->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_VOID(weight_tensor); + void *origin_weight = (conv->base_.train_session_) ? weight_tensor->data_ : conv->origin_weight_; + NNACL_CHECK_NULL_RETURN_VOID(origin_weight); + ConvWinoBaseWinogradFilterTransform(winograd, (float *)origin_weight); +} + +int ConvolutionWinogradBasePrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + + ConvolutionWinogradBaseStruct *winograd = (ConvolutionWinogradBaseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(winograd); + + winograd->conv_.init_global_variable_(&winograd->conv_); + + winograd->kernel_unit_ = winograd->conv_.compute_.kernel_h_; + winograd->input_unit_ = winograd->output_unit_ + winograd->kernel_unit_ - 1; + + if (self->train_session_) { + TensorC *filter_tensor = self->in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(filter_tensor); + NNACL_CHECK_FALSE(filter_tensor->shape_size_ != DIMENSION_4D, NNACL_CONVOLUTION_WEIGHT_SHAPE_INVALID); + + int input_plane = winograd->input_unit_ * winograd->input_unit_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(input_plane, winograd->conv_.compute_.in_c_, NNACL_ERR); + int in_chw = input_plane * winograd->conv_.compute_.in_c_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(in_chw, UP_ROUND(winograd->conv_.compute_.out_c_, winograd->oc_block_), NNACL_ERR); + int trans_matrix_data_size = + in_chw * UP_ROUND(winograd->conv_.compute_.out_c_, winograd->oc_block_) * sizeof(float); + self->work_size_ = trans_matrix_data_size; + } + + return ConvBaseInitConvWeightBias(&winograd->conv_); +} + +int ConvoWinoBaseUpdateThreadNumProcess(ConvolutionWinogradBaseStruct *winograd) { + if (winograd->conv_.compute_.in_n_ % winograd->conv_.base_.thread_nr_ == 0) { + winograd->conv_.use_batch_cut_flag_ = true; + return NNACL_OK; + } else { + winograd->conv_.use_batch_cut_flag_ = false; + } + + int update_thread = UP_DIV(UP_DIV(winograd->conv_.compute_.out_hw_, C12NUM), ConvMinBlock); + winograd->conv_.base_.thread_nr_ = NNACL_MIN(update_thread, winograd->conv_.base_.thread_nr_); + return NNACL_OK; +} + +int ConvoWinoBaseUpdateThread(ConvolutionWinogradBaseStruct *winograd) { +#ifdef DYNAMIC_THREAD_DISTRIBUTE + ConvoWinoBaseUpdateThreadNumProcess(winograd); +#else + KernelBase *base = &winograd->conv_.base_; + base->thread_nr_ = base->UpdateThread(TC_PTYPE(PrimType_Conv2DFusion), 0, 0, 0, base->thread_nr_); +#endif + return NNACL_OK; +} + +int ConvWinoBaseConfigInputOutput(ConvolutionWinogradBaseStruct *winograd) { + winograd->transfer_functions_.in_func_ = GetInputTransFunc(winograd->input_unit_); + NNACL_CHECK_NULL_RETURN_ERR(winograd->transfer_functions_.in_func_); + + ConvParameter *conv_param = (ConvParameter *)winograd->conv_.base_.param_; + winograd->transfer_functions_.out_func_ = + GetOutputTransFunc(winograd->input_unit_, winograd->output_unit_, conv_param->act_type_); + NNACL_CHECK_NULL_RETURN_ERR(winograd->transfer_functions_.out_func_); + + return NNACL_OK; +} + +int ConvoWinoBaseInitTmpBuffer(ConvolutionWinogradBaseStruct *winograd) { + ExecEnv *env = winograd->conv_.base_.env_; + NNACL_CHECK_NULL_RETURN_ERR(env); + + int thread_input_plane = winograd->conv_.base_.thread_nr_ * winograd->input_unit_ * winograd->input_unit_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(winograd->tile_num_, thread_input_plane, NNACL_ERR); + int total_thread_input_plane = winograd->tile_num_ * thread_input_plane; + size_t tile_buffer_size = total_thread_input_plane * winograd->conv_.compute_.in_c_ * sizeof(float); + winograd->trans_input_ = (float *)env->Alloc(env->allocator_, tile_buffer_size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(winograd->trans_input_); + + int oc8 = UP_ROUND(winograd->conv_.compute_.out_c_, C8NUM); + winograd->gemm_out_ = env->Alloc(env->allocator_, total_thread_input_plane * oc8 * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(winograd->gemm_out_); + + winograd->tmp_data_ = env->Alloc(env->allocator_, winograd->tmp_data_tile_ * thread_input_plane * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(winograd->tmp_data_); + + winograd->col_buffer_ = env->Alloc(env->allocator_, winograd->conv_.base_.thread_nr_ * winograd->tile_num_ * + winograd->conv_.compute_.in_c_ * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(winograd->col_buffer_); + + int tile = UP_ROUND(winograd->conv_.compute_.in_c_, winograd->tmp_data_tile_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_thread_input_plane, tile, NNACL_ERR); + winograd->opt_input_trans_ = env->Alloc(env->allocator_, total_thread_input_plane * tile * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(winograd->opt_input_trans_); + + winograd->tmp_buffer_address_list_[Index0] = winograd->trans_input_; + winograd->tmp_buffer_address_list_[Index1] = winograd->gemm_out_; + winograd->tmp_buffer_address_list_[Index2] = winograd->tmp_data_; + winograd->tmp_buffer_address_list_[Index3] = winograd->col_buffer_; + winograd->tmp_buffer_address_list_[Index4] = winograd->opt_input_trans_; + return NNACL_OK; +} + +int ConvWinoBaseRunImpl(ConvolutionBaseStruct *conv, int task_id) { + ConvolutionWinogradBaseStruct *winograd = (ConvolutionWinogradBaseStruct *)conv; + NNACL_CHECK_NULL_RETURN_ERR(winograd); + ConvParameter *conv_param = (ConvParameter *)conv->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + + TensorC *input_tensor = conv->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + float *input_data = (float *)input_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_data); + + TensorC *output_tensor = conv->base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + float *output_data = (float *)output_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_data); + + if (conv->use_batch_cut_flag_) { + ConvWinogardFp32CutByBatch(input_data, (float *)conv->packed_weight_, (float *)conv->bias_data_, output_data, + winograd->tmp_buffer_address_list_, task_id, conv_param, winograd->transfer_functions_); + } else { + ConvWinogardFp32(input_data, (float *)conv->packed_weight_, (float *)conv->bias_data_, output_data, + winograd->tmp_buffer_address_list_, task_id, conv_param, winograd->transfer_functions_); + } + + return NNACL_OK; +} + +int ConvWinoImpl(void *cdata, int task_id, float l, float r) { + ConvolutionBaseStruct *conv = (ConvolutionBaseStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(conv); + return conv->run_impl_(conv, task_id); +} + +void ConvWinoBaseUpdateParam(ConvParameter *param, ConvolutionWinogradBaseStruct *winograd) { + param->input_unit_ = winograd->input_unit_; + param->output_unit_ = winograd->output_unit_; +} + +int ConvolutionWinogradBaseResize(KernelBase *self) { + ConvolutionWinogradBaseStruct *winograd = (ConvolutionWinogradBaseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(winograd); + + int ret = ConvBaseCheckResizeValid(&winograd->conv_); + if (ret != NNACL_OK) { + return ret; + } + + ret = ConvBasePrepare(&winograd->conv_); + if (ret != NNACL_OK) { + return ret; + } + + ret = ConvoWinoBaseUpdateThread(winograd); + if (ret != NNACL_OK) { + return ret; + } + + ret = winograd->config_input_output_(winograd); + if (ret != NNACL_OK) { + return ret; + } + + ConvWinoBaseUpdateParam((ConvParameter *)self->param_, winograd); + return NNACL_OK; +} + +int ConvolutionWinogradBaseCompute(KernelBase *self) { + ConvolutionWinogradBaseStruct *winograd = (ConvolutionWinogradBaseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(winograd); + + int ret = ConvoWinoBaseInitTmpBuffer(winograd); + if (ret != NNACL_OK) { + ConvWinoBaseFreeTmpBuffer(winograd); + return ret; + } + + ret = ConvBaseRepackWeight(&winograd->conv_); + if (ret != NNACL_OK) { + ConvWinoBaseFreeTmpBuffer(winograd); + return ret; + } + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, ConvWinoImpl, self, self->thread_nr_); + ConvWinoBaseFreeTmpBuffer(winograd); + return ret; +} + +int ConvolutionWinogradBaseRelease(KernelBase *self) { + ConvolutionBaseStruct *conv = (ConvolutionBaseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(conv); + ConvBaseRelease(conv); + return NNACL_OK; +} + +ConvolutionWinogradBaseStruct *CreateConvWinogradBase(ConvParameter *conv_param) { + ConvolutionWinogradBaseStruct *winograd = + (ConvolutionWinogradBaseStruct *)malloc(sizeof(ConvolutionWinogradBaseStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(winograd); + memset(winograd, 0, sizeof(ConvolutionWinogradBaseStruct)); + + winograd->config_input_output_ = ConvWinoBaseConfigInputOutput; + winograd->conv_.init_global_variable_ = ConvWinoBaseInitGlobalVariable; + + winograd->conv_.base_.Prepare = ConvolutionWinogradBasePrepare; + winograd->conv_.base_.Resize = ConvolutionWinogradBaseResize; + winograd->conv_.base_.Release = ConvolutionWinogradBaseRelease; + winograd->conv_.base_.Compute = ConvolutionWinogradBaseCompute; + return (ConvolutionWinogradBaseStruct *)winograd; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_winograd_base.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_winograd_base.h new file mode 100644 index 0000000000000000000000000000000000000000..d714b001c783003d2a5cd280b0be3f45076f3568 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_winograd_base.h @@ -0,0 +1,65 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CONVOLLUTION_WINOGRAD_BASE_H_ +#define NNACL_KERNEL_CONVOLLUTION_WINOGRAD_BASE_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/kernel/convolution_base.h" +#include "nnacl/fp32/winograd_utils.h" + +#define CONVOLUTION_WINOGRAD_MATRIX_SIZE 64 +#define CONVOLUTION_WINOGRAD_TMP_BUFFER_SIZE 5 +#define CONVOLUTION_WINOGRAD_INPUT_UNIT_SIZE 8 + +typedef float *TmpBufferAddress; + +typedef struct ConvolutionWinogradBaseStruct { + ConvolutionBaseStruct conv_; + + int kernel_unit_; + int input_unit_; + int output_unit_; + int oc_block_; + int tile_num_; + int tmp_data_tile_; + float *tmp_data_; + float *trans_input_; + float *gemm_out_; + float *col_buffer_; + float *opt_input_trans_; + float matrix_g_[CONVOLUTION_WINOGRAD_MATRIX_SIZE]; + float matrix_gt_[CONVOLUTION_WINOGRAD_MATRIX_SIZE]; + TmpBufferAddress tmp_buffer_address_list_[CONVOLUTION_WINOGRAD_TMP_BUFFER_SIZE]; + TransFuncList transfer_functions_; + + int (*config_input_output_)(struct ConvolutionWinogradBaseStruct *winograd); +} ConvolutionWinogradBaseStruct; + +void ConvWinoBasePackWeight(ConvolutionBaseStruct *conv); +int ConvWinoBaseConfigInputOutput(ConvolutionWinogradBaseStruct *winograd); +int ConvWinoBaseRunImpl(ConvolutionBaseStruct *conv, int task_id); +int ConvWinoBaseMallocWeightBiasData(ConvolutionBaseStruct *conv); +int ConvolutionWinogradBasePrepare(KernelBase *self); +int ConvolutionWinogradBaseResize(KernelBase *self); +int ConvolutionWinogradBaseRelease(KernelBase *self); +int ConvolutionWinogradBaseCompute(KernelBase *self); +ConvolutionWinogradBaseStruct *CreateConvWinogradBase(ConvParameter *conv_param); + +#endif // NNACL_KERNEL_CONVOLLUTION_WINOGRAD_BASE_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/softmax_tensorrt.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_winograd_sse.c similarity index 31% rename from mindspore-lite/src/extendrt/delegate/tensorrt/op/softmax_tensorrt.h rename to mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_winograd_sse.c index 0d6c475c7d077df3a9be28c4a4c7903395ae6dce..db7bd463e42a5c16fb497ffc9c5d0bafaf2e09b2 100644 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/softmax_tensorrt.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_winograd_sse.c @@ -1,5 +1,5 @@ /** - * Copyright 2021-2022 Huawei Technologies Co., Ltd + * Copyright 2023 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -9,36 +9,36 @@ * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either convolutionress or implied. * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_SOFTMAX_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_SOFTMAX_TENSORRT_H_ -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" -#include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -namespace mindspore::lite { -class SoftMaxTensorRT : public TensorRTOp { - public: - SoftMaxTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} +#ifdef ENABLE_SSE +#include "nnacl/kernel/convolution_winograd_sse.h" +#include "nnacl/kernel/convolution_winograd_base.h" - ~SoftMaxTensorRT() override = default; +void ConvWinoSSEInitGlobalVariable(ConvolutionBaseStruct *conv) { + ConvolutionWinogradBaseStruct *winograd = (ConvolutionWinogradBaseStruct *)conv; + winograd->oc_block_ = C8NUM; + winograd->tmp_data_tile_ = C4NUM; + winograd->tile_num_ = C12NUM; +} - int AddInnerOp(TensorRTContext *ctx) override; +ConvolutionWinogradBaseStruct *CreateConvWinogradSSE(ConvParameter *conv_param) { + ConvolutionWinogradBaseStruct *winograd = + (ConvolutionWinogradBaseStruct *)malloc(sizeof(ConvolutionWinogradBaseStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(winograd); + memset(winograd, 0, sizeof(ConvolutionWinogradBaseStruct)); - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; + winograd->config_input_output_ = ConvWinoBaseConfigInputOutput; + winograd->conv_.init_global_variable_ = ConvWinoSSEInitGlobalVariable; - private: - nvinfer1::ISoftMaxLayer *AddSoftMaxOp(TensorRTContext *ctx); + winograd->conv_.base_.Prepare = ConvolutionWinogradBasePrepare; + winograd->conv_.base_.Resize = ConvolutionWinogradBaseResize; + winograd->conv_.base_.Release = ConvolutionWinogradBaseRelease; + winograd->conv_.base_.Compute = ConvolutionWinogradBaseCompute; - std::vector axis_val_; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_SOFTMAX_TENSORRT_H_ + return winograd; +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_winograd_sse.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_winograd_sse.h new file mode 100644 index 0000000000000000000000000000000000000000..2c30de9a59a2fbce805b781f89c4d73d0941a997 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/convolution_winograd_sse.h @@ -0,0 +1,30 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CONVOLLUTION_WINOGRAD_SSE_H_ +#define NNACL_KERNEL_CONVOLLUTION_WINOGRAD_SSE_H_ + +#ifdef ENABLE_SSE +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/kernel/convolution_winograd_base.h" + +ConvolutionWinogradBaseStruct *CreateConvWinogradSSE(ConvParameter *conv_param); +#endif + +#endif // NNACL_KERNEL_CONVOLLUTION_WINOGRAD_SSE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/crop.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/crop.c new file mode 100644 index 0000000000000000000000000000000000000000..8f107e82132af224cb8a51a73b6ba7e8a99dde18 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/crop.c @@ -0,0 +1,96 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/crop.h" +#include "nnacl/base/crop_base.h" +#include "nnacl/fp32/crop_fp32.h" +#include "nnacl/kernel/default_kernel_base.h" +#ifdef ENABLE_FP16 +#include "nnacl/fp16/crop_fp16.h" +#endif + +int CropLaunch(void *cdata, int task_id, float l, float r) { + CropStruct *crop = (CropStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(crop); + + TensorC *in = crop->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in); + TensorC *out = crop->base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(out); + +#ifdef ENABLE_FP16 + if (in->data_type_ == kNumberTypeFloat16) { + Fp16Crop((float16_t *)in->data_, (float16_t *)out->data_, in->shape_, out->shape_, crop->in_offset_, + in->shape_size_, task_id, crop->base_.thread_nr_); + return NNACL_OK; + } +#endif + + CropParameter *crop_param = (CropParameter *)crop->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(crop_param); + Crop4D((float *)in->data_, (float *)out->data_, in->shape_, out->shape_, crop_param, task_id, crop->base_.thread_nr_); + return NNACL_OK; +} + +int CropResize(struct KernelBase *self) { + TensorC *in_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in_tensor); + TensorC *out_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(out_tensor); + NNACL_CHECK_FALSE(out_tensor->shape_size_ <= Num1, NNACL_OUTPUT_TENSOR_ERROR); + + CropStruct *crop = (CropStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(crop); + CropParameter *crop_param = (CropParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(crop_param); + + return CropPadOffset(in_tensor->shape_size_, crop_param, crop->in_offset_); +} + +int CropCompute(struct KernelBase *self) { + TensorC *in_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in_tensor); + TensorC *out_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(out_tensor); + CropParameter *crop_param = (CropParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(crop_param); + + if (in_tensor->data_type_ != kNumberTypeFloat16 && out_tensor->shape_[Index1] < self->thread_nr_) { + float *input_data = (float *)in_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_data); + float *output_data = (float *)out_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_data); + Crop4DNoParallel(input_data, output_data, in_tensor->shape_, out_tensor->shape_, crop_param); + return NNACL_OK; + } + + return self->env_->ParallelLaunch(self->env_->thread_pool_, CropLaunch, self, self->thread_nr_); +} + +KernelBase *CreateCrop(OpParameter *param, int data_type) { + CropStruct *crop = (CropStruct *)malloc(sizeof(CropStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(crop); + memset(crop, 0, sizeof(CropStruct)); + crop->base_.Prepare = DefaultPrepare1In1Out; + crop->base_.Resize = CropResize; + crop->base_.Release = DefaultRelease; + crop->base_.Compute = CropCompute; + return (KernelBase *)crop; +} + +REG_KERNEL_CREATOR(PrimType_Crop, kNumberTypeInt32, CreateCrop) +REG_KERNEL_CREATOR(PrimType_Crop, kNumberTypeFloat32, CreateCrop) +REG_KERNEL_CREATOR(PrimType_Crop, kNumberTypeFloat16, CreateCrop) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/crop.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/crop.h new file mode 100644 index 0000000000000000000000000000000000000000..1427b0c5f4ed52d670b632a70f3613fe2b0bf19e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/crop.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CROP_H_ +#define NNACL_KERNEL_CROP_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct CropStruct { + KernelBase base_; + int64_t in_offset_[COMM_SHAPE_SIZE]; +} CropStruct; + +KernelBase *CreateCrop(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_CROP_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/crop_and_resize.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/crop_and_resize.c new file mode 100644 index 0000000000000000000000000000000000000000..fceaa3fe6ee01e37c64ce01572b1e5938ee6f583 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/crop_and_resize.c @@ -0,0 +1,190 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/crop_and_resize.h" +#include "nnacl/kernel/default_kernel_base.h" +#include "nnacl/fp32/resize_fp32.h" +#include "nnacl/tensor_c_utils.h" + +int CropAndResizeMallocTmpBuffer(CropAndResizeStruct *crop_and_resize) { + TensorC *input_tensor = crop_and_resize->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + TensorC *output_tensor = crop_and_resize->base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + ExecEnv *env = crop_and_resize->base_.env_; + NNACL_CHECK_NULL_RETURN_ERR(env); + + // Malloc buffer to save coordinate. + // For mode CROP_AND_RESIZE, different output batches require different cache coordinates. + crop_and_resize->batch_ = NNACLGetBatch(output_tensor); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(crop_and_resize->new_height_, crop_and_resize->batch_, NNACL_ERR); + int height_size = crop_and_resize->new_height_ * crop_and_resize->batch_; + NNACL_CHECK_MALLOC_SIZE(height_size); + crop_and_resize->y_bottoms_ = (int *)env->Alloc(env->allocator_, height_size * sizeof(int)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(crop_and_resize->y_bottoms_); + crop_and_resize->y_tops_ = (int *)env->Alloc(env->allocator_, height_size * sizeof(int)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(crop_and_resize->y_tops_); + crop_and_resize->y_bottom_weights_ = (float *)env->Alloc(env->allocator_, height_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(crop_and_resize->y_bottom_weights_); + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(crop_and_resize->new_width_, crop_and_resize->batch_, NNACL_ERR); + int width_size = crop_and_resize->new_width_ * crop_and_resize->batch_; + NNACL_CHECK_MALLOC_SIZE(width_size); + crop_and_resize->x_lefts_ = (int *)env->Alloc(env->allocator_, width_size * sizeof(int)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(crop_and_resize->x_lefts_); + crop_and_resize->x_rights_ = (int *)env->Alloc(env->allocator_, width_size * sizeof(int)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(crop_and_resize->x_rights_); + crop_and_resize->x_left_weights_ = (float *)env->Alloc(env->allocator_, width_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(crop_and_resize->x_left_weights_); + + int c = NNACLGetChannel(input_tensor); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(crop_and_resize->new_width_, c, NNACL_ERR); + int new_wc = crop_and_resize->new_width_ * c; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(new_wc, crop_and_resize->mapped_point_num_, NNACL_ERR); + int total_point_num = new_wc * crop_and_resize->mapped_point_num_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_point_num, crop_and_resize->base_.thread_nr_, NNACL_ERR); + int line_buffer_size = total_point_num * crop_and_resize->base_.thread_nr_ * sizeof(float); + crop_and_resize->line_buffer_ = (float *)env->Alloc(env->allocator_, line_buffer_size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(crop_and_resize->line_buffer_); + return NNACL_OK; +} + +void CropAndResizeFreeTmpBuffer(CropAndResizeStruct *crop_and_resize) { + ExecEnv *env = crop_and_resize->base_.env_; + NNACL_CHECK_NULL_RETURN_VOID(env); + env->Free(env->allocator_, crop_and_resize->y_bottoms_); + env->Free(env->allocator_, crop_and_resize->y_tops_); + env->Free(env->allocator_, crop_and_resize->y_bottom_weights_); + env->Free(env->allocator_, crop_and_resize->x_lefts_); + env->Free(env->allocator_, crop_and_resize->x_rights_); + env->Free(env->allocator_, crop_and_resize->x_left_weights_); + env->Free(env->allocator_, crop_and_resize->line_buffer_); + crop_and_resize->y_bottoms_ = NULL; + crop_and_resize->y_tops_ = NULL; + crop_and_resize->y_bottom_weights_ = NULL; + crop_and_resize->x_lefts_ = NULL; + crop_and_resize->x_rights_ = NULL; + crop_and_resize->x_left_weights_ = NULL; + crop_and_resize->line_buffer_ = NULL; +} + +int CropAndResizeImpl(void *cdata, int task_id, float l, float r) { + CropAndResizeStruct *crop_and_resize = (CropAndResizeStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(crop_and_resize); + + TensorC *input = crop_and_resize->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + TensorC *boxes = crop_and_resize->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(boxes); + TensorC *box_idx = crop_and_resize->base_.in_[THIRD_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(box_idx); + TensorC *output = crop_and_resize->base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output); + + int unit = UP_DIV(crop_and_resize->new_height_, crop_and_resize->base_.thread_nr_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(unit, task_id, NNACL_ERR); + int h_begin = unit * task_id; + int h_end = MSMIN(h_begin + unit, crop_and_resize->new_height_); + if (h_end <= h_begin) { + return NNACL_OK; + } + + float extrapolation_value = ((CropAndResizeParameter *)crop_and_resize->base_.param_)->extrapolation_value_; + int c = input->shape_[kNHWC_C]; + float *line0 = crop_and_resize->line_buffer_ + crop_and_resize->new_width_ * c * 2 * task_id; + float *line1 = line0 + crop_and_resize->new_width_ * c; + + return CropAndResizeBilinear((float *)input->data_, (float *)output->data_, (int32_t *)box_idx->data_, + (float *)boxes->data_, extrapolation_value, input->shape_, output->shape_, + crop_and_resize->y_bottoms_, crop_and_resize->y_tops_, crop_and_resize->x_lefts_, + crop_and_resize->x_rights_, crop_and_resize->y_bottom_weights_, + crop_and_resize->x_left_weights_, line0, line1, h_begin, h_end); +} + +int CropAndResizeCompute(struct KernelBase *self) { + CropAndResizeStruct *crop_and_resize = (CropAndResizeStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(crop_and_resize); + + // In Prepare() stage, in_tensor[0] may be of fp16 data type in fp16 mode, so move type checks here. + TensorC *input_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + TensorC *boxes_tensor = self->in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(boxes_tensor); + TensorC *boxidx_tensor = self->in_[THIRD_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(boxidx_tensor); + TensorC *output_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + + int ret = CropAndResizeMallocTmpBuffer(crop_and_resize); + if (ret != NNACL_OK) { + CropAndResizeFreeTmpBuffer(crop_and_resize); + return ret; + } + + float *boxes = (float *)boxes_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(boxes); + int32_t *box_idx = (int32_t *)boxidx_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(box_idx); + + if (CheckCropAndResizeBoxIdx(box_idx, boxes_tensor->shape_[Index0], NNACLGetBatch(input_tensor)) != NNACL_OK) { + return NNACL_CROP_AND_RESIZE_BOX_IDX_INVALID; + } + + ret = PrepareCropAndResizeBilinear(input_tensor->shape_, boxes, box_idx, output_tensor->shape_, + crop_and_resize->y_bottoms_, crop_and_resize->y_tops_, crop_and_resize->x_lefts_, + crop_and_resize->x_rights_, crop_and_resize->y_bottom_weights_, + crop_and_resize->x_left_weights_); + if (ret != NNACL_OK) { + CropAndResizeFreeTmpBuffer(crop_and_resize); + return ret; + } + + int error_code = self->env_->ParallelLaunch(self->env_->thread_pool_, CropAndResizeImpl, self, self->thread_nr_); + CropAndResizeFreeTmpBuffer(crop_and_resize); + return error_code; +} + +int CropAndResizeResize(KernelBase *self) { + CropAndResizeStruct *crop_and_resize = (CropAndResizeStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(crop_and_resize); + TensorC *output = self->out_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(output); + crop_and_resize->new_height_ = output->shape_[Index1]; + crop_and_resize->new_width_ = output->shape_[Index2]; + return NNACL_OK; +} + +int CropAndResizePrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < THREE_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR); + NNACL_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]); + NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]); + return NNACL_OK; +} + +KernelBase *CreateCropAndResize(OpParameter *param, int data_type) { + CropAndResizeStruct *crop_and_resize = (CropAndResizeStruct *)malloc(sizeof(CropAndResizeStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(crop_and_resize); + memset(crop_and_resize, 0, sizeof(CropAndResizeStruct)); + crop_and_resize->mapped_point_num_ = Num2; + crop_and_resize->base_.Prepare = CropAndResizePrepare; + crop_and_resize->base_.Resize = CropAndResizeResize; + crop_and_resize->base_.Compute = CropAndResizeCompute; + crop_and_resize->base_.Release = DefaultRelease; + return (KernelBase *)crop_and_resize; +} + +REG_KERNEL_CREATOR(PrimType_CropAndResize, kNumberTypeFloat32, CreateCropAndResize) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/crop_and_resize.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/crop_and_resize.h new file mode 100644 index 0000000000000000000000000000000000000000..50856ad9e78693abf0dfaa165cafbbd7e2c16212 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/crop_and_resize.h @@ -0,0 +1,41 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_CROP_AND_RESIZE_H_ +#define NNACL_KERNEL_CROP_AND_RESIZE_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct { + KernelBase base_; + int mapped_point_num_; + int batch_; + int new_height_; + int new_width_; + int *y_tops_; + int *y_bottoms_; + int *x_lefts_; + int *x_rights_; + float *y_bottom_weights_; + float *x_left_weights_; + float *line_buffer_; +} CropAndResizeStruct; + +KernelBase *CreateCropAndResize(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_CROP_AND_RESIZE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/deconvolution.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/deconvolution.c new file mode 100644 index 0000000000000000000000000000000000000000..eed33fe36188bbc504aa804b8be392b0c2a0fa05 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/deconvolution.c @@ -0,0 +1,337 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/deconvolution.h" +#include "nnacl/tensor_c_utils.h" +#include "nnacl/kernel/deconvolution_winograd.h" +#include "nnacl/kernel/deconvolution_depthwise.h" +#include "nnacl/fp32/pack_fp32.h" +#include "nnacl/fp32/deconv_fp32.h" +#include "nnacl/fp32/matmul_fp32.h" +#include "nnacl/fp32/matmul_avx_fp32.h" +#include "nnacl/kernel/default_kernel_base.h" + +int DeConvMallocWeightBiasData(ConvolutionBaseStruct *conv) { + int output_aligned_size = UP_ROUND(conv->compute_.out_c_, C8NUM) * sizeof(float); + size_t pack_weight_size = conv->compute_.in_c_ * conv->compute_.kernel_hw_ * output_aligned_size; + if (!conv->base_.train_session_) { + conv->packed_weight_ = conv->base_.env_->Alloc(conv->base_.env_->allocator_, pack_weight_size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->packed_weight_); + } + if (conv->bias_data_ == NULL) { + conv->bias_data_ = conv->base_.env_->Alloc(conv->base_.env_->allocator_, output_aligned_size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv->bias_data_); + } + memset(conv->bias_data_, 0, output_aligned_size); + return NNACL_OK; +} + +void DeConvPackWeight(ConvolutionBaseStruct *conv) { + TensorC *weight_tensor = conv->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_VOID(weight_tensor); + void *weight_data = weight_tensor->data_ == NULL ? conv->origin_weight_ : weight_tensor->data_; + NNACL_CHECK_NULL_RETURN_VOID(weight_data); + +#ifdef ENABLE_AVX + PackNHWCToCXHWNXFp32((float *)weight_data, (float *)conv->packed_weight_, conv->compute_.in_c_, + conv->compute_.kernel_hw_, conv->compute_.out_c_); +#else + PackNHWCToC8HWN8Fp32((float *)weight_data, (float *)conv->packed_weight_, conv->compute_.in_c_, + conv->compute_.kernel_hw_, conv->compute_.out_c_); +#endif +} + +int DeConvInitParam(DeConvStruct *deconv) { + ConvComputeParam *compute = &deconv->conv_.compute_; + deconv->matmul_.row_ = compute->in_hw_; + deconv->matmul_.deep_ = compute->in_c_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->out_c_, compute->kernel_hw_, NNACL_ERR); + deconv->matmul_.col_ = compute->out_c_ * compute->kernel_hw_; + deconv->matmul_.row_align_ = UP_ROUND(deconv->matmul_.row_, deconv->matmul_.row_tile_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(UP_ROUND(compute->out_c_, C8NUM), compute->kernel_hw_, NNACL_ERR); + deconv->matmul_.col_align_ = UP_ROUND(compute->out_c_, C8NUM) * compute->kernel_hw_; + + deconv->conv_.base_.thread_nr_ = NNACL_MIN(deconv->conv_.base_.thread_nr_, UP_DIV(compute->out_c_, C8NUM)); + NNACL_CHECK_ZERO_RETURN_ERR(deconv->conv_.base_.thread_nr_); +#ifdef ENABLE_AVX + deconv->thread_stride_ = UP_DIV(UP_DIV(compute->out_c_, C8NUM * C3NUM), deconv->conv_.base_.thread_nr_) * C3NUM; +#else + deconv->thread_stride_ = UP_DIV(UP_DIV(compute->out_c_, C8NUM), deconv->conv_.base_.thread_nr_); +#endif + return NNACL_OK; +} + +int DeConvRun(void *cdata, int task_id, float l, float r) { + DeConvStruct *deconv = (DeConvStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(deconv); + + int total_thead_stride_ = task_id * deconv->thread_stride_; + int res_stride = UP_DIV(deconv->conv_.compute_.out_c_, C8NUM) - total_thead_stride_; + int oc = NNACL_MIN(deconv->thread_stride_, res_stride); + int cur_stride = deconv->thread_stride_ * C8NUM; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_thead_stride_, C8NUM, NNACL_ERR); + int total_thead_stride_c8 = total_thead_stride_ * C8NUM; + res_stride = deconv->conv_.compute_.out_c_ - total_thead_stride_c8; + int oc_res = NNACL_MIN(cur_stride, res_stride); + if (oc <= 0 || oc_res <= 0) { + return NNACL_OK; + } + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_thead_stride_c8, deconv->conv_.compute_.kernel_hw_, NNACL_ERR); + int plane_thead_stride_c8 = total_thead_stride_c8 * deconv->conv_.compute_.kernel_hw_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(plane_thead_stride_c8, deconv->matmul_.row_align_, NNACL_ERR); + int row_c8 = plane_thead_stride_c8 * deconv->matmul_.row_align_; + float *tmp_buffer = deconv->tmp_buffer_ + row_c8; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(plane_thead_stride_c8, deconv->matmul_.deep_, NNACL_ERR); + int deep_c8 = plane_thead_stride_c8 * deconv->matmul_.deep_; + +#ifdef ENABLE_AVX + DeconvMatmulAvx(deconv->pack_input_, (float *)deconv->conv_.packed_weight_ + deep_c8, tmp_buffer, + deconv->matmul_.deep_, deconv->matmul_.row_align_, oc * C8NUM * deconv->conv_.compute_.kernel_hw_, + deconv->conv_.compute_.kernel_hw_); +#elif ENABLE_SSE + DeconvMatmulFloatSse(deconv->pack_input_, (float *)deconv->conv_.packed_weight_ + deep_c8, tmp_buffer, + deconv->matmul_.deep_, deconv->matmul_.row_align_, + oc * C8NUM * deconv->conv_.compute_.kernel_hw_); +#else + MatMulOpt(deconv->pack_input_, (float *)deconv->conv_.packed_weight_ + deep_c8, tmp_buffer, NULL, ActType_No, + deconv->matmul_.deep_, deconv->matmul_.row_align_, oc * C8NUM * deconv->conv_.compute_.kernel_hw_, + deconv->matmul_.col_, OutType_C8); +#endif + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_thead_stride_c8, deconv->conv_.compute_.out_hw_, NNACL_OK); + DeConvPostFp32C8(tmp_buffer, deconv->pack_output_ + total_thead_stride_c8 * deconv->conv_.compute_.out_hw_, + (float *)deconv->conv_.bias_data_ + total_thead_stride_c8, + deconv->output_ptr_ + total_thead_stride_c8, oc_res, (ConvParameter *)deconv->conv_.base_.param_); + return NNACL_OK; +} + +void DeConvFreeRunBuf(DeConvStruct *deconv) { + ExecEnv *env = deconv->conv_.base_.env_; + NNACL_CHECK_NULL_RETURN_VOID(env); + + if (deconv->pack_output_ != NULL) { + env->Free(env->allocator_, deconv->pack_output_); + deconv->pack_output_ = NULL; + } + if (deconv->tmp_buffer_ != NULL) { + env->Free(env->allocator_, deconv->tmp_buffer_); + deconv->tmp_buffer_ = NULL; + } + if (deconv->pack_input_ != NULL) { + env->Free(env->allocator_, deconv->pack_input_); + deconv->pack_input_ = NULL; + } +} + +int DeConvInitRunBuf(DeConvStruct *deconv) { + ExecEnv *env = deconv->conv_.base_.env_; + NNACL_CHECK_NULL_RETURN_ERR(env); + + int pack_output_size = UP_ROUND(deconv->conv_.compute_.out_c_, C8NUM) * deconv->conv_.compute_.out_hw_; + deconv->pack_output_ = (float *)env->Alloc(env->allocator_, pack_output_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(deconv->pack_output_); + + int tmp_buffer_size = deconv->matmul_.row_align_ * deconv->matmul_.col_align_; + deconv->tmp_buffer_ = (float *)env->Alloc(env->allocator_, tmp_buffer_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(deconv->tmp_buffer_); + + int pack_input_size = deconv->matmul_.row_align_ * deconv->matmul_.deep_; + deconv->pack_input_ = (float *)env->Alloc(env->allocator_, pack_input_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(deconv->pack_input_); + + return NNACL_OK; +} + +int DeConvCheckvResizeValid(ConvolutionBaseStruct *conv) { + // ===============check in channel================= // + TensorC *input_tensor = conv->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + TensorC *filter_tensor = conv->base_.in_[SECOND_INPUT]; + + int resize_out_channel = NNACLGetChannel(input_tensor); + int filter_out_channel = NNACLGetBatch(filter_tensor); + if (filter_out_channel != resize_out_channel) { + return NNACL_DECONV_RESIZE_OC_INVALID; + } + return NNACL_OK; +} + +int DeConvResize(KernelBase *self) { + DeConvStruct *deconv = (DeConvStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(deconv); + + (void)ConvBaseUpdateComputeInfo(&deconv->conv_); + + int ret = DeConvCheckvResizeValid(&deconv->conv_); + if (ret != NNACL_OK) { + return ret; + } + + ret = ConvBasePrepare(&deconv->conv_); + if (ret != NNACL_OK) { + return ret; + } + + ret = DeConvInitParam(deconv); + if (ret != NNACL_OK) { + return ret; + } + + return NNACL_OK; +} + +int DeConvCompute(KernelBase *self) { + DeConvStruct *deconv = (DeConvStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(deconv); + + int error_code = ConvBaseRepackWeight(&deconv->conv_); + if (error_code != NNACL_OK) { + return error_code; + } + + error_code = DeConvInitRunBuf(deconv); + if (error_code != NNACL_OK) { + DeConvFreeRunBuf(deconv); + return error_code; + } + + float *src_in = (float *)self->in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(src_in); + float *src_out = (float *)self->out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(src_out); + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(deconv->conv_.compute_.in_n_ - 1, deconv->conv_.compute_.in_c_, NNACL_ERR); + int input_bc = (deconv->conv_.compute_.in_n_ - 1) * deconv->conv_.compute_.in_c_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(deconv->conv_.compute_.in_hw_, input_bc, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(deconv->conv_.compute_.out_hw_, input_bc, NNACL_ERR); + for (int batch_index = 0; batch_index < deconv->conv_.compute_.in_n_; batch_index++) { + deconv->input_ptr_ = src_in + batch_index * deconv->conv_.compute_.in_hw_ * deconv->conv_.compute_.in_c_; + deconv->output_ptr_ = src_out + batch_index * deconv->conv_.compute_.out_hw_ * deconv->conv_.compute_.out_c_; + +#if defined(ENABLE_ARM32) || defined(ENABLE_SSE) + RowMajor2Col4Major(deconv->input_ptr_, deconv->pack_input_, deconv->matmul_.row_, deconv->matmul_.deep_); +#else + RowMajor2Col12Major(deconv->input_ptr_, deconv->pack_input_, deconv->matmul_.row_, deconv->matmul_.deep_); +#endif + + error_code = self->env_->ParallelLaunch(self->env_->thread_pool_, DeConvRun, self, self->thread_nr_); + if (error_code != NNACL_OK) { + DeConvFreeRunBuf(deconv); + return error_code; + } + } + + DeConvFreeRunBuf(deconv); + return NNACL_OK; +} + +int DeConvPrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR); + DeConvStruct *deconv = (DeConvStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(deconv); + ConvParameter *param = (ConvParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + + // There could be weight dataType casting before Prepare, thus weight update is required. + ConvBaseUpdateOriginWeightAndBias(&deconv->conv_); + +#if defined(ENABLE_ARM32) || defined(ENABLE_AVX) || defined(ENABLE_SSE) + deconv->matmul_.row_tile_ = C4NUM; +#else + deconv->matmul_.row_tile_ = C12NUM; +#endif + + if (self->train_session_) { + int output_aligned_size = UP_ROUND(deconv->conv_.compute_.out_c_, C8NUM); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(deconv->conv_.compute_.in_c_, deconv->conv_.compute_.kernel_hw_, NNACL_ERR); + int kernel_chw = deconv->conv_.compute_.in_c_ * deconv->conv_.compute_.kernel_hw_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(kernel_chw, output_aligned_size, NNACL_ERR); + size_t pack_weight_size = kernel_chw * output_aligned_size * sizeof(float); + self->work_size_ = pack_weight_size; + } + + if (self->in_[SECOND_INPUT]->data_ != NULL) { + int error_code = ConvBaseInitConvWeightBias(&deconv->conv_); + if (error_code != NNACL_OK) { + return error_code; + } + } else { + deconv->conv_.is_repack_ = true; + } + + return NNACL_OK; +} + +ConvolutionBaseStruct *CreateDeConv(ConvParameter *param) { + DeConvStruct *deconv = (DeConvStruct *)malloc(sizeof(DeConvStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(deconv); + memset(deconv, 0, sizeof(DeConvStruct)); + deconv->conv_.malloc_weight_bias_ = DeConvMallocWeightBiasData; + deconv->conv_.pack_weight_ = DeConvPackWeight; + deconv->conv_.base_.Prepare = DeConvPrepare; + deconv->conv_.base_.Resize = DeConvResize; + deconv->conv_.base_.Release = DefaultRelease; + deconv->conv_.base_.Compute = DeConvCompute; + return &deconv->conv_; +} + +ConvolutionBaseStruct *SelectDeConv(ConvParameter *conv_param) { +#ifndef _WIN32 +#ifndef ENABLE_MCU + bool param_winograd_fit = (conv_param->stride_h_ > 1 || conv_param->stride_w_ > 1) && + (conv_param->dilation_w_ == 1 && conv_param->dilation_h_ == 1); + +#ifdef ENABLE_AVX + bool in_size_winograd_fit = conv_param->input_w_ * conv_param->input_h_ >= NNACL_DECONV_WINOGRAD_HW_MAX; + bool size_winograd_fit = (conv_param->kernel_w_ / conv_param->stride_w_ >= C2NUM || + conv_param->kernel_h_ / conv_param->stride_h_ >= C2NUM || conv_param->output_channel_ == 1); +#else + bool in_size_winograd_fit = true; + bool size_winograd_fit = + (conv_param->kernel_w_ / conv_param->stride_w_ > C2NUM || conv_param->kernel_h_ / conv_param->stride_h_ > C2NUM); +#endif + + if (param_winograd_fit && size_winograd_fit && in_size_winograd_fit) { + ConvolutionBaseStruct *kernel = CreateDeConvWinograd(conv_param); + if (kernel != NULL) { + return kernel; + } + } +#endif +#endif + + return CreateDeConv(conv_param); +} + +KernelBase *CreateConvolutionTranspose(OpParameter *param, int data_type) { + ConvParameter *conv_param = (ConvParameter *)param; + NNACL_CHECK_NULL_RETURN_NULL(conv_param); + + ConvolutionBaseStruct *conv = NULL; + if (conv_param->group_ == 1 && conv_param->input_channel_ == 1 && conv_param->output_channel_ == 1) { + conv = CreateDeConvDw(conv_param); + } else if (conv_param->group_ == 1) { + conv = SelectDeConv(conv_param); + } else if (conv_param->group_ == conv_param->input_channel_ && conv_param->group_ == conv_param->output_channel_) { + conv = CreateDeConvDw(conv_param); + } + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv); + ConvBaseUpdateParamInfo(&conv->compute_, conv_param); + return &conv->base_; +} + +REG_KERNEL_CREATOR(PrimType_Conv2dTransposeFusion, kNumberTypeFloat32, CreateConvolutionTranspose) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/deconvolution.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/deconvolution.h new file mode 100644 index 0000000000000000000000000000000000000000..c5c7871d3bf9a409b9768eef51d0ee262dd1ba58 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/deconvolution.h @@ -0,0 +1,39 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_DECONVOLUTION_H_ +#define NNACL_KERNEL_DECONVOLUTION_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" +#include "nnacl/kernel/convolution_base.h" +#include "nnacl/kernel/matmul_struct.h" + +typedef struct DeConvStruct { + ConvolutionBaseStruct conv_; + MatmulComputeParam matmul_; + int thread_stride_; + float *pack_input_; + float *pack_output_; + float *tmp_buffer_; + float *input_ptr_; + float *output_ptr_; +} DeConvStruct; + +int DeConvCheckvResizeValid(ConvolutionBaseStruct *conv); +KernelBase *CreateConvolutionTranspose(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_DECONVOLUTION_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/deconvolution_depthwise.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/deconvolution_depthwise.c new file mode 100644 index 0000000000000000000000000000000000000000..f5ad38d5476296b79785809fa7811432019a84df --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/deconvolution_depthwise.c @@ -0,0 +1,233 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/deconvolution_depthwise.h" +#include "nnacl/kernel/default_kernel_base.h" +#include "nnacl/fp32/conv_depthwise_fp32.h" +#include "nnacl/fp32/pack_fp32.h" +#include "nnacl/kernel/deconvolution.h" + +int DeConvDwInitPackedInputOutput(DeConvDwStruct *deconv_dw) { + if (!deconv_dw->need_align_) { + return NNACL_OK; + } + ExecEnv *env = deconv_dw->conv_.base_.env_; + ConvComputeParam *compute = &deconv_dw->conv_.compute_; + + int ic4 = UP_ROUND(compute->in_c_, compute->tile_num_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->in_n_, compute->in_hw_, NNACL_ERR); + int input_bhw = compute->in_n_ * compute->in_hw_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(input_bhw, ic4, NNACL_ERR); + int pack_input_size = input_bhw * ic4; + deconv_dw->packed_input_ = (float *)env->Alloc(env->allocator_, pack_input_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(deconv_dw->packed_input_); + + int oc4 = UP_ROUND(compute->out_c_, compute->tile_num_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->out_n_, compute->out_hw_, NNACL_ERR); + int output_bhw = compute->out_n_ * compute->out_hw_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(output_bhw, oc4, NNACL_ERR); + int pack_output_size = output_bhw * oc4; + deconv_dw->packed_output_ = (float *)env->Alloc(env->allocator_, pack_output_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(deconv_dw->packed_output_); + memset(deconv_dw->packed_output_, 0, pack_output_size * sizeof(float)); + + return NNACL_OK; +} + +int DeconvDwRun(void *cdata, int task_id, float l, float r) { + DeConvDwStruct *deconv_dw = (DeConvDwStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(deconv_dw); + DeconvDwSWFp32(deconv_dw->packed_output_, deconv_dw->packed_input_, (float *)deconv_dw->conv_.packed_weight_, + (float *)deconv_dw->conv_.bias_data_, (ConvParameter *)deconv_dw->conv_.base_.param_, + &deconv_dw->sliding_, task_id); + return NNACL_OK; +} + +int DeConvDwMallocWeightBiasData(ConvolutionBaseStruct *conv) { + DeConvDwStruct *deconv_dw = (DeConvDwStruct *)conv; + NNACL_CHECK_NULL_RETURN_ERR(deconv_dw); + + int oc4 = UP_ROUND(conv->compute_.out_c_, conv->compute_.tile_num_); + if (!conv->base_.train_session_) { + int pack_weight_size = oc4 * conv->compute_.kernel_hw_; + NNACL_CHECK_MALLOC_SIZE(pack_weight_size); + deconv_dw->conv_.packed_weight_ = ConvBaseGetConvPackWeightData(conv, pack_weight_size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(deconv_dw->conv_.packed_weight_); + } + + if (deconv_dw->conv_.bias_data_ == NULL) { + NNACL_CHECK_MALLOC_SIZE(oc4 * sizeof(float)); + deconv_dw->conv_.bias_data_ = conv->base_.env_->Alloc(conv->base_.env_->allocator_, oc4 * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(deconv_dw->conv_.bias_data_); + } + memset(deconv_dw->conv_.bias_data_, 0, oc4 * sizeof(float)); + return NNACL_OK; +} + +void DeConvDwPackWeight(ConvolutionBaseStruct *conv) { + void *origin_weight = conv->base_.train_session_ ? conv->base_.in_[SECOND_INPUT]->data_ : conv->origin_weight_; + NNACL_CHECK_NULL_RETURN_VOID(origin_weight); + PackNCHWToNC4HW4Fp32(origin_weight, conv->packed_weight_, 1, conv->compute_.kernel_hw_, conv->compute_.out_c_); +} + +void DeConvDwFreePackedInputOutput(DeConvDwStruct *deconv_dw) { + if (deconv_dw->need_align_) { + ExecEnv *env = deconv_dw->conv_.base_.env_; + + env->Free(env->allocator_, deconv_dw->packed_input_); + deconv_dw->packed_input_ = NULL; + env->Free(env->allocator_, deconv_dw->packed_output_); + deconv_dw->packed_output_ = NULL; + } +} + +int DeConvDwPrepare(KernelBase *self) { + DeConvDwStruct *deconv_dw = (DeConvDwStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(deconv_dw); + ConvComputeParam *compute = &deconv_dw->conv_.compute_; + deconv_dw->conv_.compute_.tile_num_ = C4NUM; + + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR); + + NNACL_CHECK_FALSE(compute->in_c_ != compute->out_c_, NNACL_DECONVOLUTION_DEPTHWISE_CHANNEL_INVALID); + NNACL_CHECK_FALSE(compute->dilation_h_ != Num1, NNACL_DECONVOLUTION_DEPTHWISE_DILATION_INVALID); + NNACL_CHECK_FALSE(compute->dilation_w_ != Num1, NNACL_DECONVOLUTION_DEPTHWISE_DILATION_INVALID); + + ConvBaseUpdateOriginWeightAndBias(&deconv_dw->conv_); + + if (self->train_session_) { + int oc4 = UP_ROUND(compute->out_c_, compute->tile_num_); + int pack_weight_size = oc4 * compute->kernel_hw_; + self->work_size_ = pack_weight_size; + } + + int ret = ConvBaseInitConvWeightBias(&deconv_dw->conv_); + if (ret != NNACL_OK) { + return ret; + } + + NNACL_CHECK_NULL_RETURN_ERR(deconv_dw->conv_.packed_weight_); + NNACL_CHECK_NULL_RETURN_ERR(deconv_dw->conv_.bias_data_); + return NNACL_OK; +} + +void DeConvDwUpdateParam(ConvolutionBaseStruct *conv) { + TensorC *input = conv->base_.in_[FIRST_INPUT]; + TensorC *output = conv->base_.out_[OUTPUT_INDEX]; + + ConvParameter *conv_param = (ConvParameter *)conv->base_.param_; + conv_param->thread_num_ = conv->base_.thread_nr_; + conv_param->input_batch_ = NNACLGetBatch(output); + conv_param->input_h_ = NNACLGetHeight(output); + conv_param->input_w_ = NNACLGetWidth(output); + conv_param->input_channel_ = NNACLGetChannel(output); + conv_param->output_batch_ = NNACLGetBatch(input); + conv_param->output_h_ = NNACLGetHeight(input); + conv_param->output_w_ = NNACLGetWidth(input); + conv_param->output_channel_ = NNACLGetChannel(input); + + ConvComputeParam *compute = &conv->compute_; + compute->in_n_ = NNACLGetBatch(output); + compute->in_h_ = NNACLGetHeight(output); + compute->in_w_ = NNACLGetWidth(output); + compute->in_c_ = NNACLGetChannel(output); + compute->out_n_ = NNACLGetBatch(input); + compute->out_h_ = NNACLGetHeight(input); + compute->out_w_ = NNACLGetWidth(input); + compute->out_c_ = NNACLGetChannel(input); +} + +int DeConvDwResize(KernelBase *self) { + DeConvDwStruct *deconv_dw = (DeConvDwStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(deconv_dw); + + (void)ConvBaseUpdateComputeInfo(&deconv_dw->conv_); + + int ret = DeConvCheckvResizeValid(&deconv_dw->conv_); + if (ret != NNACL_OK) { + return ret; + } + + int tile_num = deconv_dw->conv_.compute_.tile_num_; + DeConvDwUpdateParam(&deconv_dw->conv_); + (void)InitSlidingParamConvDw(&deconv_dw->sliding_, (ConvParameter *)self->param_, tile_num); + self->thread_nr_ = NNACL_MIN(self->thread_nr_, UP_DIV(deconv_dw->conv_.compute_.out_c_, tile_num)); + deconv_dw->need_align_ = deconv_dw->conv_.compute_.in_c_ % tile_num != 0; + + ret = ConvBasePrepare(&deconv_dw->conv_); + if (ret != NNACL_OK) { + return ret; + } + return NNACL_OK; +} + +int DeConvDwCompute(KernelBase *self) { + DeConvDwStruct *deconv_dw = (DeConvDwStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(deconv_dw); + ConvComputeParam *compute = &deconv_dw->conv_.compute_; + + TensorC *in_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in_tensor); + float *in_data = (float *)in_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(in_data); + TensorC *out_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(out_tensor); + float *out_data = (float *)out_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(out_data); + + int ret = ConvBaseRepackWeight(&deconv_dw->conv_); + if (ret != NNACL_OK) { + return ret; + } + + ret = DeConvDwInitPackedInputOutput(deconv_dw); + if (ret != NNACL_OK) { + DeConvDwFreePackedInputOutput(deconv_dw); + return ret; + } + + if (deconv_dw->need_align_) { + PackNHWCToNHWC4Fp32(in_data, deconv_dw->packed_input_, compute->in_n_, compute->in_hw_, compute->in_c_); + } else { + deconv_dw->packed_input_ = in_data; + deconv_dw->packed_output_ = out_data; + memset(deconv_dw->packed_output_, 0, NNACLGetSize(out_tensor)); + } + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, DeconvDwRun, self, self->thread_nr_); + + if (deconv_dw->need_align_) { + PackNHWCXToNHWCFp32(deconv_dw->packed_output_, out_data, compute->out_n_, compute->out_hw_, compute->out_c_, + compute->tile_num_); + } + DeConvDwFreePackedInputOutput(deconv_dw); + return ret; +} + +ConvolutionBaseStruct *CreateDeConvDw(ConvParameter *param) { + DeConvDwStruct *deconv_dw = (DeConvDwStruct *)malloc(sizeof(DeConvDwStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(deconv_dw); + memset(deconv_dw, 0, sizeof(DeConvDwStruct)); + + deconv_dw->conv_.pack_weight_ = DeConvDwPackWeight; + deconv_dw->conv_.malloc_weight_bias_ = DeConvDwMallocWeightBiasData; + deconv_dw->conv_.base_.Prepare = DeConvDwPrepare; + deconv_dw->conv_.base_.Resize = DeConvDwResize; + deconv_dw->conv_.base_.Release = DefaultRelease; + deconv_dw->conv_.base_.Compute = DeConvDwCompute; + return &deconv_dw->conv_; +} diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/ops/ascend_native_stub.cc b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/deconvolution_depthwise.h similarity index 54% rename from mindspore-lite/src/extendrt/delegate/ascend_native/ops/ascend_native_stub.cc rename to mindspore-lite/ops/kernel/cpu/nnacl/kernel/deconvolution_depthwise.h index 74ea18710ec3bfda840fbd0d7c4e372542ad266f..ec4c568cda52ced61075ac789a7085eeb09069b7 100644 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/ops/ascend_native_stub.cc +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/deconvolution_depthwise.h @@ -13,20 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#ifndef NNACL_KERNEL_DECONVOLUTION_DEPTHWISE_H_ +#define NNACL_KERNEL_DECONVOLUTION_DEPTHWISE_H_ -#include "extendrt/delegate/ascend_native/ops/ascend_native_stub.h" -#include "mindapi/base/shared_ptr.h" -#include "mindapi/ir/common.h" -#include "mindapi/ir/value.h" -#include "mindspore/ops/op_def/op_name.h" -#include "ops/primitive_c.h" -#include "src/common/log_adapter.h" -#include "mindapi/helper.h" +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" +#include "nnacl/kernel/convolution_base.h" -namespace mindspore { -namespace ops { -MIND_API_OPERATOR_IMPL(AscendNativeStub, BaseOperator); +typedef struct DeConvDwStruct { + ConvolutionBaseStruct conv_; + SlidingWindowParam sliding_; + bool need_align_; + float *packed_input_; + float *packed_output_; +} DeConvDwStruct; -REGISTER_PRIMITIVE_C(kNameAscendNativeStub, AscendNativeStub); -} // namespace ops -} // namespace mindspore +ConvolutionBaseStruct *CreateDeConvDw(ConvParameter *param); + +#endif // NNACL_KERNEL_DECONVOLUTION_DEPTHWISE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/deconvolution_winograd.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/deconvolution_winograd.c new file mode 100644 index 0000000000000000000000000000000000000000..745826b9d09caeeaa7550ddfc1dafb80c87c1964 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/deconvolution_winograd.c @@ -0,0 +1,551 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef _WIN32 +#ifndef ENABLE_MCU +#include "nnacl/kernel/deconvolution_winograd.h" +#include "nnacl/infer/common_infer.h" +#include "nnacl/fp32/deconv_winograd_fp32.h" +#include "nnacl/fp32/pack_fp32.h" +#include "nnacl/kernel/deconvolution.h" + +void DeConvWinogradFreeResizeBuf(DeConvWinogradStruct *deconv) { + DeConvParam *param = &deconv->param_; + + for (int i = 0; i < param->compute_size_; i++) { + DeConvComputeUnit *unit = ¶m->compute_units_[i]; + if (unit->tmp_buffer_ != NULL) { + free(unit->tmp_buffer_); + unit->tmp_buffer_ = NULL; + } + + if (unit->use_winograd_) { + if (unit->winograd_.b_buffer_ != NULL) { + free(unit->winograd_.b_buffer_); + unit->winograd_.b_buffer_ = NULL; + } + } + } + + for (int i = 0; i < DECONV_WINOGRAD_BUFFER_COUNT; i++) { + DeConvWgABuffer *wg = ¶m->a_buffer_[i]; + if (wg->buf_init_) { + if (wg->dest_buffer_ != NULL) { + free(wg->dest_buffer_); + wg->dest_buffer_ = NULL; + } + if (wg->middle_buffer_ != NULL) { + free(wg->middle_buffer_); + wg->middle_buffer_ = NULL; + } + } + wg->buf_init_ = false; + } + + if (deconv->tile_input_ != NULL) { + free(deconv->tile_input_); + deconv->tile_input_ = NULL; + } +} + +void DeConvWinogradFreeDeconvParam(DeConvWinogradStruct *deconv) { + DeConvParam *param = &deconv->param_; + + for (int i = 0; i < param->compute_size_; i++) { + DeConvComputeUnit *unit = ¶m->compute_units_[i]; + + if (unit->weight_ != NULL) { + free(unit->weight_); + unit->weight_ = NULL; + } + + if (unit->use_winograd_) { + if (unit->winograd_.AT_ != NULL) { + free(unit->winograd_.AT_); + unit->winograd_.AT_ = NULL; + } + if (unit->winograd_.BT_ != NULL) { + free(unit->winograd_.BT_); + unit->winograd_.BT_ = NULL; + } + } + } + + if (param->compute_units_ != NULL) { + free(param->compute_units_); + param->compute_units_ = NULL; + } +} + +int DeConvWinogradInitParameter(DeConvWinogradStruct *deconv) { + DeConvParam *param = &deconv->param_; + ConvComputeParam *compute = &deconv->conv_.compute_; + + int thread_num = deconv->conv_.base_.thread_nr_; + NNACL_CHECK_ZERO_RETURN_ERR(thread_num); + + param->input_plane_ = compute->in_hw_; + param->output_plane_ = compute->out_hw_; + + param->in_tile_w_count_ = UP_DIV(compute->in_w_, WINOGRAD_DEFAULT_UNIT); + NNACL_CHECK_ZERO_RETURN_ERR(param->in_tile_w_count_); + param->in_tile_h_count_ = UP_DIV(compute->in_h_, WINOGRAD_DEFAULT_UNIT); + NNACL_CHECK_ZERO_RETURN_ERR(param->in_tile_h_count_); + param->in_tile_count_ = UP_DIV(param->in_tile_w_count_ * param->in_tile_h_count_, WINOGRAD_DEFAULT_TILE); + + deconv->conv_.base_.thread_nr_ = NNACL_MAX(1, deconv->conv_.base_.thread_nr_); + deconv->conv_.base_.thread_nr_ = NNACL_MIN(deconv->conv_.base_.thread_nr_, param->in_tile_count_); + + deconv->thread_num_hw_ = NNACL_MIN(deconv->conv_.base_.thread_nr_, compute->out_hw_); + NNACL_CHECK_ZERO_RETURN_ERR(deconv->thread_num_hw_); + deconv->thread_stride_hw_ = UP_DIV(compute->out_hw_, deconv->thread_num_hw_); + + int total_ic_up = WINOGRAD_DEFAULT_UNIT * WINOGRAD_DEFAULT_UNIT * WINOGRAD_DEFAULT_TILE * param->ic_up_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(deconv->conv_.base_.thread_nr_, total_ic_up, NNACL_ERR); + int size = deconv->conv_.base_.thread_nr_ * total_ic_up; + NNACL_CHECK_MALLOC_SIZE(size * sizeof(float)); + deconv->tile_input_ = (float *)malloc(size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(deconv->tile_input_); + (void)memset(deconv->tile_input_, 0, size * sizeof(float)); + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW((WINOGRAD_DEFAULT_UNIT - 1), compute->stride_w_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW((WINOGRAD_DEFAULT_UNIT - 1), compute->stride_h_, NNACL_ERR); + param->out_tile_w_ = (WINOGRAD_DEFAULT_UNIT - 1) * compute->stride_w_ + compute->kernel_w_; + param->out_tile_h_ = (WINOGRAD_DEFAULT_UNIT - 1) * compute->stride_h_ + compute->kernel_h_; + + for (int i = 0; i < param->compute_size_; i++) { + DeConvComputeUnit *unit = ¶m->compute_units_[i]; + if (unit->use_winograd_) { + if (!param->a_buffer_[unit->winograd_.kh_].buf_init_) { + param->a_buffer_[unit->winograd_.kh_].buf_init_ = true; + size = unit->winograd_.kh_ * unit->winograd_.kw_ * WINOGRAD_DEFAULT_TILE * param->ic_up_; + + param->a_buffer_[unit->winograd_.kh_].middle_buffer_ = malloc(thread_num * size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(param->a_buffer_[unit->winograd_.kh_].middle_buffer_); + + param->a_buffer_[unit->winograd_.kh_].dest_buffer_ = malloc(thread_num * size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(param->a_buffer_[unit->winograd_.kh_].dest_buffer_); + } + + size = unit->winograd_.kh_ * unit->winograd_.kw_ * param->oc_up_ * WINOGRAD_DEFAULT_TILE; + unit->winograd_.b_buffer_ = malloc(thread_num * size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(unit->winograd_.b_buffer_); + + size = unit->winograd_.kh_ * unit->winograd_.kw_ * param->oc_div_ * WINOGRAD_DEFAULT_TILE * compute->tile_num_; + unit->tmp_buffer_ = malloc(thread_num * size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(unit->tmp_buffer_); + } else { + size = param->oc_div_ * unit->w_size_ * unit->h_size_ * WINOGRAD_DEFAULT_TILE * compute->tile_num_; + unit->tmp_buffer_ = malloc(thread_num * size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(unit->tmp_buffer_); + } + } + + return NNACL_OK; +} + +int DeConvWgFp32Run(void *cdata, int task_id, float l, float r) { + DeConvWinogradStruct *deconv = (DeConvWinogradStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(deconv); + ConvParameter *conv_param = (ConvParameter *)deconv->conv_.base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + DeConvParam *param = &deconv->param_; + ConvComputeParam *compute = &deconv->conv_.compute_; + + for (int tile_index = task_id; tile_index < param->in_tile_count_; tile_index += deconv->conv_.base_.thread_nr_) { + int size = WINOGRAD_DEFAULT_UNIT * WINOGRAD_DEFAULT_UNIT * WINOGRAD_DEFAULT_TILE * param->ic_up_; + float *tile_in = deconv->tile_input_ + task_id * size; + size = param->out_tile_w_ * param->out_tile_h_ * WINOGRAD_DEFAULT_TILE * param->oc_div_ * compute->tile_num_; + float *tile_out = deconv->tile_output_ + task_id * size; + (void)memset(tile_out, 0, size * sizeof(float)); + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(tile_index, WINOGRAD_DEFAULT_TILE, NNACL_ERR); + int start_index = tile_index * WINOGRAD_DEFAULT_TILE; + int cal_count = NNACL_MIN(WINOGRAD_DEFAULT_TILE, param->in_tile_w_count_ * param->in_tile_h_count_ - start_index); + + int ret = DeconvWg(deconv->nhwc_input_, tile_in, tile_out, start_index, cal_count, conv_param, param, task_id); + if (ret != NNACL_OK) { + return ret; + } + + (void)pthread_mutex_lock(&deconv->lock_); + (void)DeconvWgPost(tile_out, deconv->nc4hw4_output_, conv_param, param, cal_count, tile_index); + (void)pthread_mutex_unlock(&deconv->lock_); + } + return NNACL_OK; +} + +int DeConvWgPostFp32Run(void *cdata, int task_id, float l, float r) { + DeConvWinogradStruct *deconv = (DeConvWinogradStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(deconv); + ConvComputeParam *compute = &deconv->conv_.compute_; + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(task_id, deconv->thread_stride_hw_, NNACL_ERR); + int output_stride_plane = task_id * deconv->thread_stride_hw_; + int rest_plane = compute->out_hw_ - output_stride_plane; + int current_plane = MSMIN(rest_plane, deconv->thread_stride_hw_); + if (current_plane <= 0) { + return NNACL_OK; + } + + ActType act = ((ConvParameter *)deconv->conv_.base_.param_)->act_type_; + float *bias = (float *)deconv->conv_.bias_data_; + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(output_stride_plane, deconv->conv_.compute_.tile_num_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(output_stride_plane, deconv->conv_.compute_.out_c_, NNACL_ERR); + WinogradPostConvFuncFp32CX(deconv->nc4hw4_output_ + output_stride_plane * compute->tile_num_, + deconv->nhwc_output_ + output_stride_plane * compute->out_c_, bias, compute->out_c_, + current_plane, compute->out_hw_, act); + return NNACL_OK; +} + +int DeConvWinogradInitComputeParam(DeConvWinogradStruct *deconv) { + deconv->valid_weight_shape_ = CheckShaleValid(&deconv->conv_.base_.in_[SECOND_INPUT], Num1); + if (deconv->valid_weight_shape_ == false) { + return NNACL_OK; + } + + ConvComputeParam *compute = &deconv->conv_.compute_; + DeConvParam *param = &deconv->param_; + + param->kernel_plane_ = compute->kernel_hw_; + param->ic_div_ = UP_DIV(compute->in_c_, compute->tile_num_); + param->oc_div_ = UP_DIV(compute->out_c_, compute->tile_num_); + param->ic_up_ = param->ic_div_ * compute->tile_num_; + param->oc_up_ = param->oc_div_ * compute->tile_num_; + + param->compute_size_ = 0; + for (int si_h = 0; si_h < compute->stride_h_; si_h++) { + for (int si_w = 0; si_w < compute->stride_w_; si_w++) { + if (si_h < compute->kernel_h_ && si_w < compute->kernel_w_) { + param->compute_size_++; + } + } + } + + size_t size = (size_t)param->compute_size_ * sizeof(DeConvComputeUnit); + param->compute_units_ = (DeConvComputeUnit *)(malloc(size)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(param->compute_units_); + + int cur_count = 0; + for (int si_h = 0; si_h < compute->stride_h_; si_h++) { + if (si_h >= compute->kernel_h_) { + continue; + } + for (int si_w = 0; si_w < compute->stride_w_; si_w++) { + if (si_w >= compute->kernel_w_) { + continue; + } + + int h_size = 1 + (compute->kernel_h_ - si_h - 1) / compute->stride_h_; + int w_size = 1 + (compute->kernel_w_ - si_w - 1) / compute->stride_w_; + + DeConvComputeUnit unit; + unit.winograd_.AT_ = NULL; + unit.winograd_.BT_ = NULL; + + unit.h_start_ = si_h; + unit.w_start_ = si_w; + unit.h_size_ = h_size; + unit.w_size_ = w_size; + + unit.use_winograd_ = false; + if (h_size == w_size) { + unit.winograd_.k_ = unit.h_size_; + unit.winograd_.i_ = WINOGRAD_DEFAULT_UNIT; + unit.winograd_.o_ = WINOGRAD_DEFAULT_UNIT + unit.h_size_ - 1; + unit.winograd_.kh_ = unit.h_size_ + WINOGRAD_DEFAULT_UNIT - 1; + unit.winograd_.kw_ = unit.w_size_ + WINOGRAD_DEFAULT_UNIT - 1; + unit.use_winograd_ = unit.winograd_.kh_ < WINOGRAD_MAX_COUNT && unit.winograd_.kw_ < WINOGRAD_MAX_COUNT; + } + if (unit.use_winograd_) { + unit.winograd_.b_buffer_ = NULL; + unit.weight_ = malloc(unit.winograd_.kh_ * unit.winograd_.kw_ * param->oc_up_ * param->ic_up_ * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(unit.weight_); + } else { + unit.weight_ = malloc(h_size * w_size * param->ic_up_ * param->oc_up_ * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(unit.weight_); + } + unit.tmp_buffer_ = NULL; + param->compute_units_[cur_count] = unit; + cur_count++; + } + } + return NNACL_OK; +} + +int DeConvWinogradInitDataParam(DeConvWinogradStruct *deconv) { + TensorC *weight_tensor = deconv->conv_.base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(weight_tensor); + float *nhwc_weight = weight_tensor->data_; + if (nhwc_weight == NULL) { + deconv->conv_.is_repack_ = true; + return NNACL_OK; + } + + DeConvParam *param = &deconv->param_; + + /* unit data : weight & winograd data */ + for (int i = 0; i < param->compute_size_; i++) { + DeConvComputeUnit *unit = ¶m->compute_units_[i]; + int ret = PackDeConvWgDataFp32(nhwc_weight, unit, (ConvParameter *)deconv->conv_.base_.param_, param); + if (ret != NNACL_OK) { + return ret; + } + } + + /* bias */ + ExecEnv *env = deconv->conv_.base_.env_; + NNACL_CHECK_NULL_RETURN_ERR(env); + if (deconv->conv_.bias_data_ != NULL) { + env->Free(env->allocator_, deconv->conv_.bias_data_); + deconv->conv_.bias_data_ = NULL; + } + deconv->conv_.bias_data_ = env->Alloc(env->allocator_, param->oc_up_ * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(deconv->conv_.bias_data_); + (void)memset(deconv->conv_.bias_data_, 0, param->oc_up_ * sizeof(float)); + + if (deconv->conv_.base_.in_size_ == THREE_TENSOR) { + TensorC *bias_tensor = deconv->conv_.base_.in_[THIRD_INPUT]; + if (bias_tensor->shape_size_ == Num1 && NNACLGetElementNum(bias_tensor) == deconv->conv_.compute_.out_c_) { + (void)memcpy(deconv->conv_.bias_data_, bias_tensor->data_, deconv->conv_.compute_.out_c_ * sizeof(float)); + } + } + return NNACL_OK; +} + +int DeConvWinogradInitRunBuf(DeConvWinogradStruct *deconv) { + ExecEnv *env = deconv->conv_.base_.env_; + + int size = deconv->param_.oc_up_ * deconv->conv_.compute_.out_hw_; + deconv->nc4hw4_output_ = (float *)env->Alloc(env->allocator_, size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(deconv->nc4hw4_output_); + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(deconv->param_.out_tile_w_, deconv->param_.out_tile_h_, NNACL_ERR); + int out_tile_hw = deconv->param_.out_tile_w_ * deconv->param_.out_tile_h_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(deconv->conv_.base_.thread_nr_, out_tile_hw, NNACL_ERR); + int total_out_tile_hw = deconv->conv_.base_.thread_nr_ * out_tile_hw; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(WINOGRAD_DEFAULT_TILE, deconv->param_.oc_up_, NNACL_ERR); + int tile_oc_up = WINOGRAD_DEFAULT_TILE * deconv->param_.oc_up_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_out_tile_hw, tile_oc_up, NNACL_ERR); + size = total_out_tile_hw * tile_oc_up; + deconv->tile_output_ = (float *)env->Alloc(env->allocator_, size * sizeof(float)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(deconv->tile_output_); + + return NNACL_OK; +} + +void DeConvWinogradFreeRunBuf(DeConvWinogradStruct *deconv) { + ExecEnv *env = deconv->conv_.base_.env_; + + if (deconv->nc4hw4_output_ != NULL) { + env->Free(env->allocator_, deconv->nc4hw4_output_); + deconv->nc4hw4_output_ = NULL; + } + + if (deconv->tile_output_ != NULL) { + env->Free(env->allocator_, deconv->tile_output_); + deconv->tile_output_ = NULL; + } +} + +int InitTrainComputeInit(DeConvWinogradStruct *deconv) { + if (!deconv->valid_weight_shape_) { + int ret = DeConvWinogradInitComputeParam(deconv); + if (ret != NNACL_OK) { + DeConvWinogradFreeRunBuf(deconv); + return ret; + } + if (!deconv->valid_weight_shape_ || DeConvWinogradInitParameter(deconv) != NNACL_OK) { + DeConvWinogradFreeRunBuf(deconv); + return NNACL_DECONVOLUTION_DEPTHWISE_INVALID_WEIGHT_SHAPE; + } + } + + if (deconv->conv_.is_repack_ && DeConvWinogradInitDataParam(deconv) != NNACL_OK) { + DeConvWinogradFreeRunBuf(deconv); + return NNACL_DECONVOLUTION_DEPTHWISE_INVALID_WEIGHT_REPACK; + } + + return NNACL_OK; +} + +int DeConvWinogradPrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR); + + DeConvWinogradStruct *deconv = (DeConvWinogradStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(deconv); + ConvComputeParam *compute = &deconv->conv_.compute_; + NNACL_CHECK_FALSE(compute->dilation_h_ != Num1, NNACL_DECONVOLUTION_DEPTHWISE_DILATION_INVALID); + NNACL_CHECK_FALSE(compute->dilation_w_ != Num1, NNACL_DECONVOLUTION_DEPTHWISE_DILATION_INVALID); + NNACL_CHECK_FALSE(compute->stride_h_ == Num0, NNACL_DECONVOLUTION_DEPTHWISE_STRIDE_INVALID); + NNACL_CHECK_FALSE(compute->stride_w_ == Num0, NNACL_DECONVOLUTION_DEPTHWISE_STRIDE_INVALID); + +#ifdef ENABLE_AVX + compute->tile_num_ = C8NUM; +#else + compute->tile_num_ = C4NUM; +#endif + + ConvBaseUpdateOriginWeightAndBias(&deconv->conv_); + + int ret = DeConvWinogradInitComputeParam(deconv); + if (ret != NNACL_OK) { + return ret; + } + + if (deconv->valid_weight_shape_) { + ret = DeConvWinogradInitDataParam(deconv); + if (ret != NNACL_OK) { + return ret; + } + } + + // when input data is const tensor, save data in kernel + TensorC *input_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + if (NNACLIsConst(input_tensor)) { + deconv->origin_input_ = (float *)malloc(NNACLGetSize(input_tensor)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(deconv->origin_input_); + (void)memcpy(deconv->origin_input_, input_tensor->data_, NNACLGetSize(input_tensor)); + } + return NNACL_OK; +} + +int DeConvWinogradResize(KernelBase *self) { + DeConvWinogradStruct *deconv = (DeConvWinogradStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(deconv); + (void)ConvBaseUpdateComputeInfo(&deconv->conv_); + + int ret = DeConvCheckvResizeValid(&deconv->conv_); + if (ret != NNACL_OK) { + return ret; + } + + DeConvWinogradFreeResizeBuf(deconv); + + ret = ConvBasePrepare(&deconv->conv_); + if (ret != NNACL_OK) { + return ret; + } + + if (!deconv->valid_weight_shape_) { + ret = DeConvWinogradInitComputeParam(deconv); + if (ret != NNACL_OK) { + return ret; + } + if (!deconv->valid_weight_shape_) { + return NNACL_OK; + } + ret = DeConvWinogradInitDataParam(deconv); + if (ret != NNACL_OK) { + return ret; + } + } + + ret = DeConvWinogradInitParameter(deconv); + if (ret != NNACL_OK) { + return ret; + } + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(deconv->conv_.compute_.out_hw_, deconv->conv_.compute_.out_c_, NNACL_ERR); + int output_chw = deconv->conv_.compute_.out_hw_ * deconv->conv_.compute_.out_c_; + if (output_chw <= kDeconvWinogradMaxPixel) { + self->thread_nr_ = NNACL_MIN(self->thread_nr_, Num3); + } + return NNACL_OK; +} + +int DeConvWinogradRelease(KernelBase *self) { + DeConvWinogradStruct *deconv = (DeConvWinogradStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(deconv); + + DeConvWinogradFreeResizeBuf(deconv); + DeConvWinogradFreeDeconvParam(deconv); + + if (deconv->origin_input_ != NULL) { + free(deconv->origin_input_); + deconv->origin_input_ = NULL; + } + return NNACL_OK; +} + +int DeConvWinogradCompute(KernelBase *self) { + DeConvWinogradStruct *deconv = (DeConvWinogradStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(deconv); + DeConvParam *param = &deconv->param_; + ConvComputeParam *compute_ = &deconv->conv_.compute_; + + int ret = DeConvWinogradInitRunBuf(deconv); + if (ret != NNACL_OK) { + DeConvWinogradFreeRunBuf(deconv); + return ret; + } + + ret = InitTrainComputeInit(deconv); + if (ret != NNACL_OK) { + DeConvWinogradFreeRunBuf(deconv); + return ret; + } + + TensorC *in_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in_tensor); + float *src_in = deconv->origin_input_ != NULL ? deconv->origin_input_ : (float *)in_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(src_in); + TensorC *output_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + float *src_out = (float *)output_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(src_out); + + int input_chw = compute_->in_hw_ * compute_->in_c_; + int output_chw = compute_->out_hw_ * compute_->out_c_; + for (int batch_index = 0; batch_index < compute_->in_n_; batch_index++) { + deconv->nhwc_input_ = src_in + batch_index * input_chw; + deconv->nhwc_output_ = src_out + batch_index * output_chw; + + (void)memset(deconv->nc4hw4_output_, 0, compute_->out_hw_ * param->oc_div_ * compute_->tile_num_ * sizeof(float)); + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, DeConvWgFp32Run, self, self->thread_nr_); + if (ret != NNACL_OK) { + DeConvWinogradFreeRunBuf(deconv); + return ret; + } + + /* post bias activate and nhwc */ + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, DeConvWgPostFp32Run, self, self->thread_nr_); + if (ret != NNACL_OK) { + DeConvWinogradFreeRunBuf(deconv); + return ret; + } + } + + DeConvWinogradFreeRunBuf(deconv); + return NNACL_OK; +} + +ConvolutionBaseStruct *CreateDeConvWinograd(ConvParameter *param) { + DeConvWinogradStruct *deconv_winograd = (DeConvWinogradStruct *)malloc(sizeof(DeConvWinogradStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(deconv_winograd); + memset(deconv_winograd, 0, sizeof(DeConvWinogradStruct)); + + deconv_winograd->conv_.base_.Prepare = DeConvWinogradPrepare; + deconv_winograd->conv_.base_.Resize = DeConvWinogradResize; + deconv_winograd->conv_.base_.Release = DeConvWinogradRelease; + deconv_winograd->conv_.base_.Compute = DeConvWinogradCompute; + return &deconv_winograd->conv_; +} +#endif +#endif diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_layernorm_kernel.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/deconvolution_winograd.h similarity index 39% rename from mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_layernorm_kernel.h rename to mindspore-lite/ops/kernel/cpu/nnacl/kernel/deconvolution_winograd.h index 6608f355fa176a48ce14fbce9798ccce0d8f9660..1262e1fb461e864c268d27054a87931890a0e476 100644 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_layernorm_kernel.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/deconvolution_winograd.h @@ -13,29 +13,40 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#ifndef NNACL_KERNEL_DECONVOLUTION_WINOGRAD_H_ +#define NNACL_KERNEL_DECONVOLUTION_WINOGRAD_H_ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_NATIVE_LAYERNORM_KERNEL_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_NATIVE_LAYERNORM_KERNEL_H_ +#ifndef _WIN32 +#ifndef ENABLE_MCU +#include +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" +#include "nnacl/kernel/convolution_base.h" -#include -#include -#include -#include "extendrt/delegate/ascend_native/ascend_native_base_kernel.h" +#define kDeconvWinogradMaxPixel 3145728 +#define WINOGRAD_DEFAULT_UNIT 3 +#define WINOGRAD_DEFAULT_TILE 8 +#define WINOGRAD_MAX_COUNT 8 -namespace mindspore::kernel { -class AscendNativeLayernormKernel : public AscendNativeBaseKernel { - public: - AscendNativeLayernormKernel(const std::vector &inputs, const std::vector &outputs, - InferPrimitive prim, const InferContext *ctx, const void *stream, std::string name) - : AscendNativeBaseKernel(inputs, outputs, prim, ctx, stream, name) {} +typedef struct DeConvWinogradStruct { + ConvolutionBaseStruct conv_; + DeConvParam param_; + pthread_mutex_t lock_; + int thread_num_hw_; + int thread_stride_hw_; + float *nhwc_input_; + float *nhwc_output_; + float *tile_input_; + float *tile_output_; + float *origin_input_; + float *nc4hw4_output_; + bool valid_weight_shape_; +} DeConvWinogradStruct; - int InferShape() override; +#define NNACL_DECONV_WINOGRAD_HW_MAX 2000 - int Prepare() override; - - int Run() override; - - int ReSize() override; -}; -} // namespace mindspore::kernel -#endif // MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_NATIVE_LAYERNORM_KERNEL_H_ +ConvolutionBaseStruct *CreateDeConvWinograd(ConvParameter *param); +#endif +#endif +#endif // NNACL_KERNEL_DECONVOLUTION_WINOGRAD_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/default_kernel_base.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/default_kernel_base.c new file mode 100644 index 0000000000000000000000000000000000000000..6c32c4c4d95dcdf794301c94caa4b832b1fd89d7 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/default_kernel_base.c @@ -0,0 +1,55 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/default_kernel_base.h" + +int DefaultPrepare3In1Out(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < THREE_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR); + return NNACL_OK; +} + +int DefaultPrepare3In2Out(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < THREE_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < TWO_TENSOR, NNACL_ERR); + return NNACL_OK; +} + +int DefaultPrepare1In2Out(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < ONE_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < TWO_TENSOR, NNACL_ERR); + return NNACL_OK; +} + +int DefaultPrepare1In1Out(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < ONE_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR); + return NNACL_OK; +} + +int DefaultPrepare2In1Out(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR); + return NNACL_OK; +} + +int DefaultResize(KernelBase *self) { + NNACL_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]); + NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]); + return NNACL_OK; +} + +int DefaultRelease(KernelBase *self) { return NNACL_OK; } diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/default_kernel_base.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/default_kernel_base.h new file mode 100644 index 0000000000000000000000000000000000000000..abc339321ea1be371defc515cee01eb34005cf57 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/default_kernel_base.h @@ -0,0 +1,32 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_DEFAULT_KERNEL_BASE_H_ +#define NNACL_KERNEL_DEFAULT_KERNEL_BASE_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +int DefaultPrepare3In2Out(KernelBase *self); +int DefaultPrepare1In1Out(KernelBase *self); +int DefaultPrepare2In1Out(KernelBase *self); +int DefaultPrepare1In2Out(KernelBase *self); +int DefaultPrepare3In1Out(KernelBase *self); +int DefaultResize(KernelBase *self); +int DefaultRelease(KernelBase *self); + +#endif // NNACL_KERNEL_DEFAULT_KERNEL_BASE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/depth_to_space.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/depth_to_space.c new file mode 100644 index 0000000000000000000000000000000000000000..d3e4ac8ef47f1428c66040b7e5c1cf6cfa316ddf --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/depth_to_space.c @@ -0,0 +1,80 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/depth_to_space.h" +#include "nnacl/kernel/default_kernel_base.h" +#include "nnacl/tensor_c_utils.h" +#include "nnacl/nnacl_common.h" +#include "nnacl/depth_to_space_parameter.h" +#include "nnacl/base/depth_to_space_base.h" + +int DepthToSpaceResize(KernelBase *self) { + DepthToSpaceStruct *depth_to_space = (DepthToSpaceStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(depth_to_space); + DepthToSpaceArgs *args = &depth_to_space->args_; + + TensorC *input = self->in_[FIRST_INPUT]; + int32_t in_strides[DIMENSION_4D] = {0}; + ComputeStrides(input->shape_, in_strides, input->shape_size_); + args->in_stride_dim0_ = in_strides[Index0]; + args->in_stride_dim1_ = in_strides[Index1]; + args->in_stride_dim2_ = in_strides[Index2]; + + TensorC *output = self->out_[OUTPUT_INDEX]; + int32_t out_strides[DIMENSION_4D] = {0}; + ComputeStrides(output->shape_, out_strides, output->shape_size_); + args->out_stride_dim0_ = out_strides[Index0]; + args->out_stride_dim1_ = out_strides[Index1]; + args->out_stride_dim2_ = out_strides[Index2]; + return NNACL_OK; +} + +int DepthToSpaceCompute(KernelBase *self) { + DepthToSpaceStruct *depth_to_space = (DepthToSpaceStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(depth_to_space); + int mode = ((DepthToSpaceParameter *)self->param_)->mode_; + + TensorC *input = self->in_[FIRST_INPUT]; + TensorC *output = self->out_[OUTPUT_INDEX]; + + if (mode == 0) { + // RCD + DepthToSpaceForNHWC(input->data_, output->data_, input->shape_, &depth_to_space->args_); + } else if (mode == 1) { + // CRD + DepthToSpaceCRDForNHWC(input->data_, output->data_, input->shape_, &depth_to_space->args_); + } else { + return NNACL_DEPTH_TO_SPACE_INVALID_MODE; + } + return NNACL_OK; +} + +KernelBase *CreateDepthToSpace(OpParameter *param, int data_type) { + DepthToSpaceStruct *depth_to_space = (DepthToSpaceStruct *)malloc(sizeof(DepthToSpaceStruct)); + NNACL_CHECK_NULL_RETURN_NULL(depth_to_space); + memset(depth_to_space, 0, sizeof(DepthToSpaceStruct)); + + depth_to_space->args_.data_type_size_ = DataTypeCSize(data_type); + depth_to_space->args_.block_size_ = ((DepthToSpaceParameter *)param)->block_size_; + depth_to_space->base_.Release = DefaultRelease; + depth_to_space->base_.Prepare = DefaultPrepare1In1Out; + depth_to_space->base_.Resize = DepthToSpaceResize; + depth_to_space->base_.Compute = DepthToSpaceCompute; + return (KernelBase *)depth_to_space; +} + +REG_KERNEL_CREATOR(PrimType_DepthToSpace, kNumberTypeFloat32, CreateDepthToSpace) +REG_KERNEL_CREATOR(PrimType_DepthToSpace, kNumberTypeFloat16, CreateDepthToSpace) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/depth_to_space.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/depth_to_space.h new file mode 100644 index 0000000000000000000000000000000000000000..a98a3930a5fda93c026806afc5e83d15b5ba7af7 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/depth_to_space.h @@ -0,0 +1,42 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_DEPTH_TO_SPACE_H_ +#define NNACL_KERNEL_DEPTH_TO_SPACE_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct DepthToSpaceArgs { + int32_t in_stride_dim0_; + int32_t in_stride_dim1_; + int32_t in_stride_dim2_; + int32_t out_stride_dim0_; + int32_t out_stride_dim1_; + int32_t out_stride_dim2_; + uint8_t data_type_size_; + int32_t block_size_; +} DepthToSpaceArgs; + +typedef struct DepthToSpaceStruct { + KernelBase base_; + DepthToSpaceArgs args_; +} DepthToSpaceStruct; + +KernelBase *CreateDepthToSpace(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_DEPTH_TO_SPACE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/exp.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/exp.c new file mode 100644 index 0000000000000000000000000000000000000000..895f3fcb9a6bc0c1105ae3683678838b73f576fb --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/exp.c @@ -0,0 +1,86 @@ +/** + * Copyright 2022-2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/exp.h" +#include +#include "nnacl/exp_parameter.h" +#include "nnacl/op_base.h" +#include "nnacl/fp32/exp_fp32.h" +#include "nnacl/kernel/default_kernel_base.h" +#include "nnacl/tensor_c_utils.h" +#ifdef ENABLE_FP16 +#include "nnacl/fp16/exp_fp16.h" +#endif + +int ExpRunImpl(void *cdata, int task_id, float l, float r) { + ExpStruct *exp = (ExpStruct *)cdata; + return exp->Exp(exp->base_.in_[FIRST_INPUT]->data_, exp->base_.out_[OUTPUT_INDEX]->data_, exp, task_id); +} + +int ExpResize(struct KernelBase *self) { + ExpStruct *exp = (ExpStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(exp); + ExpParameter *param = (ExpParameter *)exp->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + exp->element_num_ = NNACLGetElementNum(exp->base_.in_[FIRST_INPUT]); + return NNACL_OK; +} + +int ExpPrepare(struct KernelBase *self) { + ExpStruct *exp = (ExpStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(exp); + ExpParameter *param = (ExpParameter *)exp->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + NNACL_CHECK_FALSE(self->in_size_ < 1 || self->out_size_ < 1, NNACL_TENSOR_SIZE_INVALID); + + float log_base = (param->base_ == -1) ? 1 : logf(param->base_); + float epsilon = 0.000001; + exp->in_scale_ = param->scale_ * log_base; + if (param->shift_ == 0) { + exp->out_scale_ = 1; + } else { + if (fabs(log_base - 1) < epsilon) { + exp->out_scale_ = expf(param->shift_); + } else { + exp->out_scale_ = powf(param->base_, param->shift_); + } + } + + return NNACL_OK; +} + +int ExpCompute(struct KernelBase *self) { + return self->env_->ParallelLaunch(self->env_->thread_pool_, ExpRunImpl, self, self->thread_nr_); +} + +KernelBase *CreateExp(OpParameter *param, int data_type) { + ExpStruct *exp = (ExpStruct *)malloc(sizeof(ExpStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(exp); + exp->base_.Prepare = ExpPrepare; + exp->base_.Resize = ExpResize; + exp->base_.Release = DefaultRelease; + exp->base_.Compute = ExpCompute; + exp->Exp = ExpFusionFp32; +#ifdef ENABLE_FP16 + if (data_type == kNumberTypeFloat16) { + exp->Exp = ExpFusionFp16; + } +#endif + return (KernelBase *)exp; +} + +REG_KERNEL_CREATOR(PrimType_ExpFusion, kNumberTypeFloat32, CreateExp) +REG_KERNEL_CREATOR(PrimType_ExpFusion, kNumberTypeFloat16, CreateExp) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/exp.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/exp.h new file mode 100644 index 0000000000000000000000000000000000000000..41ecfc46eae05b05f8ec89ea019e3d83d40175aa --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/exp.h @@ -0,0 +1,33 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_EXP_H_ +#define NNACL_KERNEL_EXP_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct ExpStruct { + KernelBase base_; + float in_scale_; + float out_scale_; + int element_num_; + int (*Exp)(const void *in, void *out, const struct ExpStruct *exp, int task_id); +} ExpStruct; + +KernelBase *CreateExp(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_EXP_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/f16/arithmetic_compare_f16.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/f16/arithmetic_compare_f16.c new file mode 100644 index 0000000000000000000000000000000000000000..694a063b0d5f231f2cd2b92c6d5cddeaceec03ed --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/f16/arithmetic_compare_f16.c @@ -0,0 +1,110 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/f16/arithmetic_compare_f16.h" +#include "nnacl/kernel/f16/arithmetic_f16.h" +#include "nnacl/fp16/arithmetic_fp16.h" + +typedef struct ArithmeticCompareF16Funcions { + int primitive_type_; + int activation_type_; + int (*compute_)(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size); + int (*optimzie_)(const float16_t *input0, const float16_t *input1, uint8_t *output, int element_size, + bool first_scalar); +} ArithmeticCompareF16Funcions; + +typedef struct ArithmeticCompareF16Struct { + ArithmeticF16Struct arithmetic_f16_; + ArithmeticCompareF16Funcions functions_; +} ArithmeticCompareF16Struct; + +void InitArithmeticCompareF16RunFunction(KernelBase *base) { + ArithmeticCompareF16Struct *arithmetic_compare_f16 = (ArithmeticCompareF16Struct *)base; + ArithmeticParameter *arithmetic_param = (ArithmeticParameter *)base->param_; + + ArithmeticCompareF16Funcions arithmetic_cp_fun_table_fp16[] = { + {PrimType_NotEqual, ActType_No, ElementNotEqualFp16, ElementOptNotEqualFp16}, + {PrimType_Equal, ActType_No, ElementEqualFp16, ElementOptEqualFp16}, + {PrimType_Less, ActType_No, ElementLessFp16, ElementOptLessFp16}, + {PrimType_LessEqual, ActType_No, ElementLessEqualFp16, ElementOptLessEqualFp16}, + {PrimType_Greater, ActType_No, ElementGreaterFp16, ElementOptGreaterFp16}, + {PrimType_GreaterEqual, ActType_No, ElementGreaterEqualFp16, ElementOptGreaterEqualFp16}}; + + size_t length = sizeof(arithmetic_cp_fun_table_fp16) / sizeof(ArithmeticCompareF16Funcions); + for (size_t i = 0; i < length; i++) { + if (arithmetic_cp_fun_table_fp16[i].primitive_type_ == + arithmetic_compare_f16->arithmetic_f16_.arithmetic_.primitive_type_ && + arithmetic_cp_fun_table_fp16[i].activation_type_ == arithmetic_param->activation_type_) { + arithmetic_compare_f16->functions_ = arithmetic_cp_fun_table_fp16[i]; + return; + } + } +} + +int ArithmeticCompareF16DoExecute(KernelBase *base, const void *input0, const void *input1, void *output, + int64_t size) { + ArithmeticCompareF16Struct *arithmetic_compare_f16 = (ArithmeticCompareF16Struct *)base; + + if (arithmetic_compare_f16->arithmetic_f16_.arithmetic_.scalar_opt_) { + bool first_scalar = arithmetic_compare_f16->arithmetic_f16_.arithmetic_.in_elements_num0_ == 1; + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_compare_f16->functions_.optimzie_); + return arithmetic_compare_f16->functions_.optimzie_((const float16_t *)input0, (const float16_t *)input1, + (uint8_t *)output, size, first_scalar); + } + + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_compare_f16->functions_.compute_); + return arithmetic_compare_f16->functions_.compute_((const float16_t *)input0, (const float16_t *)input1, + (uint8_t *)output, size); +} +int ArithmeticCompareF16Compute(KernelBase *self) { + ArithmeticStruct *arithmetic = (ArithmeticStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(arithmetic); + arithmetic->in_data_size_ = DataTypeCSize(self->in_[FIRST_INPUT]->data_type_); + arithmetic->out_data_size_ = DataTypeCSize(self->out_[OUTPUT_INDEX]->data_type_); + return ArithmeticF16Compute(self); +} + +KernelBase *CreateArithmeticCompareF16(OpParameter *param, int data_type) { + ArithmeticCompareF16Struct *arithmetic_compare_f16 = + (ArithmeticCompareF16Struct *)malloc(sizeof(ArithmeticCompareF16Struct)); + NNACL_CHECK_NULL_RETURN_NULL(arithmetic_compare_f16); + memset(arithmetic_compare_f16, 0, sizeof(ArithmeticF16Struct)); + + ArithmeticStruct *arithmetic = &arithmetic_compare_f16->arithmetic_f16_.arithmetic_; + arithmetic->block_boundary_infos_size_ = 0; + arithmetic->a_matrix_.batch_post_sum_ = NULL; + arithmetic->b_matrix_.batch_post_sum_ = NULL; + arithmetic->c_matrix_.batch_post_sum_ = NULL; + arithmetic->broadcast_buffer_[FIRST_INPUT] = NULL; + arithmetic->broadcast_buffer_[SECOND_INPUT] = NULL; + arithmetic->base_.Prepare = ArithmeticPrepare; + arithmetic->base_.Resize = ArithmeticF16Resize; + arithmetic->base_.Release = ArithmeticRelease; + arithmetic->base_.Compute = ArithmeticCompareF16Compute; + + arithmetic->execute_ = ArithmeticCompareF16DoExecute; + arithmetic->tile_function_ = TileOneDimensionFp16; + arithmetic->init_function_ = InitArithmeticCompareF16RunFunction; + + return (KernelBase *)arithmetic_compare_f16; +} + +REG_KERNEL_CREATOR(PrimType_NotEqual, kNumberTypeFloat16, CreateArithmeticCompareF16) +REG_KERNEL_CREATOR(PrimType_Equal, kNumberTypeFloat16, CreateArithmeticCompareF16) +REG_KERNEL_CREATOR(PrimType_Less, kNumberTypeFloat16, CreateArithmeticCompareF16) +REG_KERNEL_CREATOR(PrimType_LessEqual, kNumberTypeFloat16, CreateArithmeticCompareF16) +REG_KERNEL_CREATOR(PrimType_Greater, kNumberTypeFloat16, CreateArithmeticCompareF16) +REG_KERNEL_CREATOR(PrimType_GreaterEqual, kNumberTypeFloat16, CreateArithmeticCompareF16) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/f16/arithmetic_compare_f16.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/f16/arithmetic_compare_f16.h new file mode 100644 index 0000000000000000000000000000000000000000..57f05a92a75b302e3cdceda821aba233c0f43d40 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/f16/arithmetic_compare_f16.h @@ -0,0 +1,26 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_F16_ARITHMETIC_COMPARE_F16_H_ +#define NNACL_KERNEL_F16_ARITHMETIC_COMPARE_F16_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +KernelBase *CreateArithmeticCompareF16(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_F16_ARITHMETIC_COMPARE_F16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/f16/arithmetic_f16.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/f16/arithmetic_f16.c new file mode 100644 index 0000000000000000000000000000000000000000..5788fe04626425a8fab099e3b0cf939afbb3186d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/f16/arithmetic_f16.c @@ -0,0 +1,195 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/f16/arithmetic_f16.h" +#include "nnacl/fp16/cast_fp16.h" +#include "nnacl/fp16/arithmetic_fp16.h" +#include "nnacl/fp16/utils_fp16.h" +#include "nnacl/tensor_c_utils.h" + +void InitArithmeticF16RunFunction(KernelBase *base) { + ArithmeticF16Struct *arithmetic_f16 = (ArithmeticF16Struct *)base; + + ArithmeticF16Funcions f16_fun_table[] = { + {PrimType_MulFusion, ActType_Relu, ElementMulReluFp16, ElementOptMulReluFp16}, + {PrimType_MulFusion, ActType_Relu6, ElementMulRelu6Fp16, ElementOptMulRelu6Fp16}, + {PrimType_MulFusion, ActType_No, ElementMulFp16, ElementOptMulFp16}, + {PrimType_AddFusion, ActType_Relu, ElementAddReluFp16, ElementOptAddReluFp16}, + {PrimType_AddFusion, ActType_Relu6, ElementAddRelu6Fp16, ElementOptAddRelu6Fp16}, + {PrimType_AddFusion, ActType_No, ElementAddFp16, ElementOptAddFp16}, + {PrimType_SubFusion, ActType_Relu, ElementSubReluFp16, ElementOptSubReluFp16}, + {PrimType_SubFusion, ActType_Relu6, ElementSubRelu6Fp16, ElementOptSubRelu6Fp16}, + {PrimType_SubFusion, ActType_No, ElementSubFp16, ElementOptSubFp16}, + {PrimType_DivFusion, ActType_Relu, ElementDivReluFp16, ElementOptDivReluFp16}, + {PrimType_DivFusion, ActType_Relu6, ElementDivRelu6Fp16, ElementOptDivRelu6Fp16}, + {PrimType_DivFusion, ActType_No, ElementDivFp16, ElementOptDivFp16}, + {PrimType_RealDiv, ActType_Relu, ElementDivReluFp16, ElementOptDivReluFp16}, + {PrimType_RealDiv, ActType_Relu6, ElementDivRelu6Fp16, ElementOptDivRelu6Fp16}, + {PrimType_RealDiv, ActType_No, ElementDivFp16, ElementOptDivFp16}, + {PrimType_FloorMod, ActType_No, ElementFloorModFp16, ElementOptFloorModFp16}, + {PrimType_FloorDiv, ActType_No, ElementFloorDivFp16, ElementOptFloorDivFp16}, + {PrimType_LogicalAnd, ActType_No, ElementLogicalAndFp16, ElementOptLogicalAndFp16}, + {PrimType_LogicalOr, ActType_No, ElementLogicalOrFp16, ElementOptLogicalOrFp16}, + {PrimType_SquaredDifference, ActType_No, ElementSquaredDifferenceFp16, ElementOptSquaredDifferenceFp16}, + {PrimType_Maximum, ActType_No, ElementMaximumFp16, ElementOptMaximumFp16}, + {PrimType_Minimum, ActType_No, ElementMinimumFp16, ElementOptMinimumFp16}}; + + size_t length = sizeof(f16_fun_table) / sizeof(ArithmeticF16Funcions); + for (size_t i = 0; i < length; i++) { + if (f16_fun_table[i].primitive_type_ == arithmetic_f16->arithmetic_.primitive_type_ && + f16_fun_table[i].activation_type_ == + ((ArithmeticParameter *)(arithmetic_f16->arithmetic_.base_.param_))->activation_type_) { + arithmetic_f16->functions_ = f16_fun_table[i]; + return; + } + } +} + +int ArithmeticF16DoExecute(KernelBase *base, const void *input0, const void *input1, void *output, int64_t size) { + ArithmeticF16Struct *arithmetic_f16 = (ArithmeticF16Struct *)base; + + if (arithmetic_f16->arithmetic_.scalar_opt_) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_f16->functions_.optimzie_); + return arithmetic_f16->functions_.optimzie_((const float16_t *)input0, (const float16_t *)input1, + (float16_t *)output, size, + arithmetic_f16->arithmetic_.in_elements_num0_ == 1); + } + + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_f16->functions_.compute_); + return arithmetic_f16->functions_.compute_((const float16_t *)input0, (const float16_t *)input1, (float16_t *)output, + size); +} + +int ArithmeticF16Resize(KernelBase *self) { + ArithmeticF16Struct *arithmetic_f16 = (ArithmeticF16Struct *)self; + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_f16); + ArithmeticStruct *arithmetic = (ArithmeticStruct *)self; + + arithmetic->in_data_size_ = sizeof(float16_t); + arithmetic->out_data_size_ = sizeof(float16_t); + if (arithmetic->in_elements_num1_ != 1 && arithmetic->in_elements_num0_ != 1) { + if (arithmetic->a_matrix_.is_const_ && self->in_[FIRST_INPUT]->data_type_ == kNumberTypeFloat32) { + TensorC *t = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(t->data_); + void *f32_data = t->data_; + t->data_type_ = kNumberTypeFloat16; + t->data_ = self->env_->Alloc(self->env_->allocator_, NNACLGetSize(t)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]->data_); + Float32ToFloat16((float *)(f32_data), (float16_t *)(t->data_), NNACLGetElementNum(t)); + self->env_->Free(self->env_->allocator_, f32_data); + } + if (arithmetic->b_matrix_.is_const_ && self->in_[SECOND_INPUT]->data_type_ == kNumberTypeFloat32) { + TensorC *t = self->in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(t->data_); + void *f32_data = t->data_; + t->data_type_ = kNumberTypeFloat16; + t->data_ = self->env_->Alloc(self->env_->allocator_, NNACLGetSize(t)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]->data_); + Float32ToFloat16((float *)(f32_data), (float16_t *)(t->data_), NNACLGetElementNum(t)); + self->env_->Free(self->env_->allocator_, f32_data); + } + } + return ArithmeticResize(self); +} + +void FreeArithmeticF16Buffers(ArithmeticF16Struct *arithmetic_f16) { + for (int i = 0; i < THREE_TENSOR; i++) { + if (arithmetic_f16->tmp_buffer_[i] != NULL) { + arithmetic_f16->arithmetic_.base_.env_->Free(arithmetic_f16->arithmetic_.base_.env_->allocator_, + arithmetic_f16->tmp_buffer_[i]); + arithmetic_f16->tmp_buffer_[i] = NULL; + } + } +} + +int ArithmeticF16Compute(KernelBase *self) { + ArithmeticF16Struct *arithmetic_f16 = (ArithmeticF16Struct *)self; + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_f16); + + int in0_data_type = self->in_[FIRST_INPUT]->data_type_; + int in1_data_type = self->in_[SECOND_INPUT]->data_type_; + int out_data_type = self->out_[OUTPUT_INDEX]->data_type_; + + NNACL_CHECK_FALSE(in0_data_type != kNumberTypeFloat32 && in0_data_type != kNumberTypeFloat16, + NNACL_UNSUPPORTED_DATA_TYPE); + NNACL_CHECK_FALSE(in1_data_type != kNumberTypeFloat16 && in1_data_type != kNumberTypeFloat32, + NNACL_UNSUPPORTED_DATA_TYPE); + + if (!arithmetic_f16->arithmetic_.a_matrix_.is_valid_) { + arithmetic_f16->arithmetic_.a_matrix_.data_ = GetOrAllocFp16Data(self->in_[FIRST_INPUT], self->env_, true); + arithmetic_f16->tmp_buffer_[FIRST_INPUT] = + in0_data_type == kNumberTypeFloat16 ? NULL : arithmetic_f16->arithmetic_.a_matrix_.data_; + } + + if (!arithmetic_f16->arithmetic_.b_matrix_.is_valid_) { + arithmetic_f16->arithmetic_.b_matrix_.data_ = GetOrAllocFp16Data(self->in_[SECOND_INPUT], self->env_, true); + arithmetic_f16->tmp_buffer_[SECOND_INPUT] = + in1_data_type == kNumberTypeFloat16 ? NULL : arithmetic_f16->arithmetic_.b_matrix_.data_; + } + + arithmetic_f16->arithmetic_.c_matrix_.data_ = GetOrAllocFp16Data(self->out_[OUTPUT_INDEX], self->env_, false); + arithmetic_f16->tmp_buffer_[THIRD_INPUT] = + out_data_type == kNumberTypeFloat16 ? NULL : arithmetic_f16->arithmetic_.c_matrix_.data_; + + int ret = ArithmeticCompute(self); + if (ret == NNACL_OK && out_data_type == kNumberTypeFloat32) { + NNACL_CHECK_NULL_RETURN_ERR(arithmetic_f16->arithmetic_.c_matrix_.data_); + NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]->data_); + Float16ToFloat32((float16_t *)(arithmetic_f16->arithmetic_.c_matrix_.data_), + (float *)(self->out_[OUTPUT_INDEX]->data_), NNACLGetElementNum(self->out_[OUTPUT_INDEX])); + } + + FreeArithmeticF16Buffers(arithmetic_f16); + return NNACL_OK; +} + +KernelBase *CreateArithmeticF16(OpParameter *param, int data_type) { + ArithmeticF16Struct *arithmetic_f16 = (ArithmeticF16Struct *)malloc(sizeof(ArithmeticF16Struct)); + NNACL_CHECK_NULL_RETURN_NULL(arithmetic_f16); + memset(arithmetic_f16, 0, sizeof(ArithmeticF16Struct)); + + ArithmeticStruct *arithmetic = &arithmetic_f16->arithmetic_; + arithmetic->block_boundary_infos_size_ = 0; + arithmetic->a_matrix_.batch_post_sum_ = NULL; + arithmetic->b_matrix_.batch_post_sum_ = NULL; + arithmetic->c_matrix_.batch_post_sum_ = NULL; + arithmetic->broadcast_buffer_[FIRST_INPUT] = NULL; + arithmetic->broadcast_buffer_[SECOND_INPUT] = NULL; + arithmetic->base_.Prepare = ArithmeticPrepare; + arithmetic->base_.Resize = ArithmeticF16Resize; + arithmetic->base_.Release = ArithmeticRelease; + arithmetic->base_.Compute = ArithmeticF16Compute; + + arithmetic->execute_ = ArithmeticF16DoExecute; + arithmetic->tile_function_ = TileOneDimensionFp16; + arithmetic->init_function_ = InitArithmeticF16RunFunction; + + return (KernelBase *)arithmetic_f16; +} + +REG_KERNEL_CREATOR(PrimType_MulFusion, kNumberTypeFloat16, CreateArithmeticF16) +REG_KERNEL_CREATOR(PrimType_AddFusion, kNumberTypeFloat16, CreateArithmeticF16) +REG_KERNEL_CREATOR(PrimType_SubFusion, kNumberTypeFloat16, CreateArithmeticF16) +REG_KERNEL_CREATOR(PrimType_DivFusion, kNumberTypeFloat16, CreateArithmeticF16) +REG_KERNEL_CREATOR(PrimType_FloorMod, kNumberTypeFloat16, CreateArithmeticF16) +REG_KERNEL_CREATOR(PrimType_FloorDiv, kNumberTypeFloat16, CreateArithmeticF16) +REG_KERNEL_CREATOR(PrimType_LogicalAnd, kNumberTypeFloat16, CreateArithmeticF16) +REG_KERNEL_CREATOR(PrimType_LogicalOr, kNumberTypeFloat16, CreateArithmeticF16) +REG_KERNEL_CREATOR(PrimType_Maximum, kNumberTypeFloat16, CreateArithmeticF16) +REG_KERNEL_CREATOR(PrimType_Minimum, kNumberTypeFloat16, CreateArithmeticF16) +REG_KERNEL_CREATOR(PrimType_Eltwise, kNumberTypeFloat16, CreateArithmeticF16) +REG_KERNEL_CREATOR(PrimType_RealDiv, kNumberTypeFloat16, CreateArithmeticF16) +REG_KERNEL_CREATOR(PrimType_SquaredDifference, kNumberTypeFloat16, CreateArithmeticF16) diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_add_kernel.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/f16/arithmetic_f16.h similarity index 40% rename from mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_add_kernel.h rename to mindspore-lite/ops/kernel/cpu/nnacl/kernel/f16/arithmetic_f16.h index 725caa2f95d67cae3f47d00c90e34619ea2113e7..928fe6034237dc21b2c7a153ce8767dcff9fc766 100644 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_add_kernel.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/f16/arithmetic_f16.h @@ -14,30 +14,29 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_NATIVE_ADD_KERNEL_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_NATIVE_ADD_KERNEL_H_ - -#include -#include -#include -#include "extendrt/delegate/ascend_native/ascend_native_base_kernel.h" - -namespace mindspore::kernel { -class AscendNativeAddKernel : public AscendNativeBaseKernel { - public: - AscendNativeAddKernel(const std::vector &inputs, const std::vector &outputs, - InferPrimitive prim, const InferContext *ctx, const void *stream, std::string name) - : AscendNativeBaseKernel(inputs, outputs, prim, ctx, stream, name) {} - - int InferShape() override; - - int Prepare() override; - - int Run() override; - - int ReSize() override; - - private: -}; -} // namespace mindspore::kernel -#endif // MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_NATIVE_ADD_KERNEL_H_ +#ifndef NNACL_KERNEL_F16_ARITHMETIC_F16_H_ +#define NNACL_KERNEL_F16_ARITHMETIC_F16_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" +#include "nnacl/kernel/arithmetic.h" + +typedef struct ArithmeticF16Funcions { + int primitive_type_; + int activation_type_; + int (*compute_)(const float16_t *in1, const float16_t *in2, float16_t *out, int ele); + int (*optimzie_)(const float16_t *in1, const float16_t *in2, float16_t *out, int ele, bool first_scalar); +} ArithmeticF16Funcions; + +typedef struct ArithmeticF16Struct { + ArithmeticStruct arithmetic_; + ArithmeticF16Funcions functions_; + void *tmp_buffer_[THREE_TENSOR]; /* in_size + out_size */ +} ArithmeticF16Struct; + +KernelBase *CreateArithmeticF16(OpParameter *param, int data_type); +int ArithmeticF16Resize(KernelBase *self); +int ArithmeticF16Compute(KernelBase *self); + +#endif // MINDSPORE_NNACL_KERNEL_F16_ARITHMETIC_F16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/f16/concat_f16.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/f16/concat_f16.c new file mode 100644 index 0000000000000000000000000000000000000000..34df7d5dabc5587c6666ecacfd8bbaf6feb3cf9c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/f16/concat_f16.c @@ -0,0 +1,132 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/f16/concat_f16.h" +#include "nnacl/kernel/concat.h" +#include "nnacl/fp16/cast_fp16.h" +#include "nnacl/fp16/utils_fp16.h" +#include "nnacl/tensor_c_utils.h" + +typedef struct ConcatF16Struct { + ConcatStruct concat_; + void **tmp_buffer_; /* in_size + out_size */ +} ConcatF16Struct; + +int ConcatEnsureFp16InputsAndOutput(ConcatF16Struct *concat_f16) { + ConcatStruct *concat = &concat_f16->concat_; + + int tmp_buffer_size = (concat->base_.in_size_ + concat->base_.out_size_) * sizeof(float16_t *); + concat_f16->tmp_buffer_ = concat->base_.env_->Alloc(concat->base_.env_->allocator_, tmp_buffer_size); + NNACL_CHECK_NULL_RETURN_ERR(concat_f16->tmp_buffer_); + memset(concat_f16->tmp_buffer_, 0, tmp_buffer_size); + + for (size_t i = 0; i < concat->base_.in_size_; ++i) { + if (!concat->is_with_data_[i]) { + continue; + } + + concat->inputs_ptr_[i] = GetOrAllocFp16Data(concat->base_.in_[i], concat->base_.env_, true); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(concat->inputs_ptr_[i]); + if (concat->base_.in_[i]->data_type_ == kNumberTypeFloat32 || + concat->base_.in_[i]->data_type_ == kNumberTypeFloat) { + concat_f16->tmp_buffer_[i] = concat->inputs_ptr_[i]; + } + } + + concat->output_ = GetOrAllocFp16Data(concat->base_.out_[OUTPUT_INDEX], concat->base_.env_, false); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(concat->output_); + if (concat->base_.out_[OUTPUT_INDEX]->data_type_ == kNumberTypeFloat32 || + concat->base_.out_[OUTPUT_INDEX]->data_type_ == kNumberTypeFloat) { + concat_f16->tmp_buffer_[concat->base_.in_size_] = concat->output_; + } + return NNACL_OK; +} + +int ConcatFp16Run(void *cdata, int task_id, float l, float r) { + ConcatF16Struct *concat_f16 = (ConcatF16Struct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(concat_f16); + ConcatStruct *concat = &concat_f16->concat_; + return DoConcat(concat, task_id); +} + +void ConcatF16FreeTmpBuffer(ConcatF16Struct *concat_f16) { + if (concat_f16->tmp_buffer_ != NULL) { + /* free tmp_buffer_[i] */ + for (int i = 0; i < (concat_f16->concat_.base_.in_size_ + concat_f16->concat_.base_.out_size_); i++) { + if (concat_f16->tmp_buffer_[i] != NULL) { + concat_f16->concat_.base_.env_->Free(concat_f16->concat_.base_.env_->allocator_, concat_f16->tmp_buffer_[i]); + } + concat_f16->tmp_buffer_[i] = NULL; + } + + /* free tmp_buffer_ */ + concat_f16->concat_.base_.env_->Free(concat_f16->concat_.base_.env_->allocator_, concat_f16->tmp_buffer_); + concat_f16->tmp_buffer_ = NULL; + } +} + +int ConcatF16Compute(KernelBase *self) { + ConcatF16Struct *concat_f16 = (ConcatF16Struct *)self; + NNACL_CHECK_NULL_RETURN_ERR(concat_f16); + ConcatStruct *concat = &concat_f16->concat_; + + if (concat->outer_size_ == 0 || concat->inner_sizes_[self->in_size_] == 0) { + return NNACL_OK; + } + + int ret = ConcatEnsureFp16InputsAndOutput(concat_f16); + if (ret != NNACL_OK) { + ConcatF16FreeTmpBuffer(concat_f16); + return ret; + } + + NNACL_CHECK_NULL_RETURN_ERR(concat->output_); + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, ConcatFp16Run, self, self->thread_nr_); + if (ret == NNACL_OK) { + TensorC *output_tensor = concat->base_.out_[FIRST_INPUT]; + if (output_tensor->data_type_ == kNumberTypeFloat32 || output_tensor->data_type_ == kNumberTypeFloat) { + float *output = concat->base_.out_[FIRST_INPUT]->data_; + if (output == NULL) { + ret = NNACL_CONCAT_F16_OUTPUT_DATA_INVALID; + } else { + Float16ToFloat32((float16_t *)concat->output_, output, NNACLGetElementNum(output_tensor)); + } + } + } + + ConcatF16FreeTmpBuffer(concat_f16); + return ret; +} + +KernelBase *CreateConcatF16(OpParameter *param, int data_type) { + ConcatF16Struct *concat_f16 = (ConcatF16Struct *)malloc(sizeof(ConcatF16Struct)); + NNACL_CHECK_NULL_RETURN_NULL(concat_f16); + memset(concat_f16, 0, sizeof(ConcatF16Struct)); + + ConcatStruct *concat = &concat_f16->concat_; + concat->data_type_ = kNumberTypeFloat16; + concat->inner_sizes_ = NULL; + concat->inputs_ptr_ = NULL; + concat->is_with_data_ = NULL; + concat->base_.Prepare = ConcatPepare; + concat->base_.Resize = ConcatResize; + concat->base_.Release = ConcatRelease; + concat->base_.Compute = ConcatF16Compute; + concat_f16->tmp_buffer_ = NULL; + return (KernelBase *)concat; +} + +REG_KERNEL_CREATOR(PrimType_Concat, kNumberTypeFloat16, CreateConcatF16) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/f16/concat_f16.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/f16/concat_f16.h new file mode 100644 index 0000000000000000000000000000000000000000..11b96ad5448b63d4bbc5c9765c48f11e2974c0a6 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/f16/concat_f16.h @@ -0,0 +1,25 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_KERNEL_F16_CONCAT_F16_H_ +#define MINDSPORE_NNACL_KERNEL_F16_CONCAT_F16_H_ +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +KernelBase *CreateConcatF16(OpParameter *param, int data_type); + +#endif // MINDSPORE_NNACL_KERNEL_F16_CONCAT_F16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/f16/reduce_f16.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/f16/reduce_f16.c new file mode 100644 index 0000000000000000000000000000000000000000..836f8d6d5f1b87f8765f95a0c806fba1ad621fee --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/f16/reduce_f16.c @@ -0,0 +1,118 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/f16/reduce_f16.h" +#include "nnacl/fp16/reduce_fp16.h" +#include "nnacl/tensor_c_utils.h" +#include "nnacl/kernel/default_kernel_base.h" + +typedef struct ReduceF16Compute { + int type_; + int (*f16_reducer_)(const int outer_size, const int inner_size, const int axis_size, const float16_t *src_data, + float16_t *dst_data, const int tid, const int thread_num); +} ReduceF16Compute; + +typedef struct ReduceF16Struct { + ReduceStruct reduce_; + ReduceF16Compute compute_; +} ReduceF16Struct; + +int CallReduceF16Unit(KernelBase *base, int task_id) { + ReduceF16Struct *reduce_f16 = (ReduceF16Struct *)base; + NNACL_CHECK_NULL_RETURN_ERR(reduce_f16->reduce_.src_data_); + NNACL_CHECK_NULL_RETURN_ERR(reduce_f16->reduce_.src_data_); + NNACL_CHECK_NULL_RETURN_ERR(reduce_f16->compute_.f16_reducer_); + + return reduce_f16->compute_.f16_reducer_(reduce_f16->reduce_.outer_size_, reduce_f16->reduce_.inner_size_, + reduce_f16->reduce_.axis_size_, + (const float16_t *)reduce_f16->reduce_.src_data_, + (float16_t *)reduce_f16->reduce_.dst_data_, task_id, base->thread_nr_); +} + +void InitialReduceF16KernelList(KernelBase *base) { + ReduceF16Struct *reduce_f16 = (ReduceF16Struct *)base; + ReduceParameter *param = (ReduceParameter *)(base->param_); + + ReduceF16Compute func_list[] = {{Reduce_Sum, ReduceSumFp16}, {Reduce_Mean, ReduceMeanFp16}, + {Reduce_Max, ReduceMaxFp16}, {Reduce_Min, ReduceMinFp16}, + {Reduce_Prod, ReduceProdFp16}, {Reduce_SumSquare, ReduceSumFp16}, + {Reduce_ASum, ReduceSumFp16}, {Reduce_L2, ReduceL2NormFp16}}; + + size_t list_len = sizeof(func_list) / sizeof(ReduceF16Compute); + for (size_t i = 0; i < list_len; ++i) { + if (param->mode_ == func_list[i].type_) { + reduce_f16->compute_ = func_list[i]; + return; + } + } +} + +void HandleReduceF16ASumAndSumSquare(KernelBase *base) { + TensorC *in_tensor = base->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_VOID(in_tensor); + float16_t *data = (float16_t *)in_tensor->data_; + NNACL_CHECK_NULL_RETURN_VOID(data); + + int num = NNACLGetElementNum(in_tensor); + + if (((ReduceParameter *)base->param_)->mode_ == Reduce_ASum) { + for (int i = 0; i < num; ++i) { + if (data[i] < 0.0f) { + data[i] = 0.0f - data[i]; + } + } + } + + if (((ReduceParameter *)base->param_)->mode_ == Reduce_SumSquare) { + for (int i = 0; i < num; ++i) { + data[i] = data[i] * data[i]; + } + return; + } +} + +int CalculateReduceF16CoeffOutput(KernelBase *base) { + TensorC *out_tensor = base->out_[OUTPUT_INDEX]; + int num = NNACLGetElementNum(out_tensor); + + float16_t *out_data = (float16_t *)out_tensor->data_; + for (int i = 0; i < num; ++i) { + out_data[i] *= ((ReduceParameter *)base->param_)->coeff; + } + return NNACL_OK; +} + +KernelBase *CreateReduceF16(OpParameter *param, int data_type) { + ReduceF16Struct *reduce_f16 = (ReduceF16Struct *)malloc(sizeof(ReduceF16Struct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(reduce_f16); + memset(reduce_f16, 0, sizeof(ReduceF16Struct)); + + ReduceStruct *reduce = &reduce_f16->reduce_; + reduce->data_type_ = data_type; + reduce->base_.Release = DefaultRelease; + reduce->base_.Prepare = ReducePrepare; + reduce->base_.Resize = ReduceResize; + reduce->base_.Compute = ReduceCompute; + + reduce->handle_sum_square_ = HandleReduceF16ASumAndSumSquare; + reduce->calculate_coeff_ = CalculateReduceF16CoeffOutput; + reduce->init_kernel_list_ = InitialReduceF16KernelList; + reduce->call_uint_ = CallReduceF16Unit; + + return (KernelBase *)reduce_f16; +} + +REG_KERNEL_CREATOR(PrimType_ReduceFusion, kNumberTypeFloat16, CreateReduceF16) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/f16/reduce_f16.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/f16/reduce_f16.h new file mode 100644 index 0000000000000000000000000000000000000000..2cf844c61fd4342cb1adb25b042237ccaf5baa3c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/f16/reduce_f16.h @@ -0,0 +1,27 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_F16_REDUCE_F16_H_ +#define NNACL_KERNEL_F16_REDUCE_F16_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" +#include "nnacl/kernel/reduce.h" + +KernelBase *CreateReduceF16(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_F16_REDUCE_F16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/f16/stack_f16.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/f16/stack_f16.c new file mode 100644 index 0000000000000000000000000000000000000000..ce6fbb3e4dc19ad15297ae9bbb818abba37200b3 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/f16/stack_f16.c @@ -0,0 +1,96 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/f16/stack_f16.h" +#include "nnacl/fp16/cast_fp16.h" +#include "nnacl/tensor_c_utils.h" + +void *StackF16InitBuffer(KernelBase *base, TensorC *t, bool init) { + if (init == false) { + return t->data_; + } + + int ele_num = NNACLGetElementNum(t); + void *f16_buffer = base->env_->Alloc(base->env_->allocator_, ele_num * sizeof(float16_t)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(f16_buffer); + Float32ToFloat16(t->data_, f16_buffer, ele_num); + return f16_buffer; +} + +int StackF16InitMallocFlags(StackF16Struct *stack_f16) { + KernelBase *base = (KernelBase *)stack_f16; + stack_f16->init_ = base->env_->Alloc(base->env_->allocator_, (base->in_size_ + base->out_size_) * sizeof(bool)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(stack_f16->init_); + + for (size_t i = 0; i < base->in_size_; ++i) { + stack_f16->init_[i] = base->in_[i]->data_type_ == kNumberTypeFloat32; + stack_f16->stack_.buffers_[i] = StackF16InitBuffer(base, base->in_[i], stack_f16->init_[i]); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(stack_f16->stack_.buffers_[i]); + } + stack_f16->init_[base->in_size_] = base->out_[OUTPUT_INDEX]->data_type_ == kNumberTypeFloat32; + stack_f16->stack_.buffers_[base->in_size_] = + StackF16InitBuffer(base, base->out_[OUTPUT_INDEX], stack_f16->init_[base->in_size_]); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(stack_f16->stack_.buffers_[base->in_size_]); + return NNACL_OK; +} + +void StackF16FreeBuffer(StackF16Struct *stack_f16) { + if (stack_f16->init_[stack_f16->stack_.base_.in_size_]) { + /* output transfer */ + Float16ToFloat32((float16_t *)stack_f16->stack_.buffers_[stack_f16->stack_.base_.in_size_], + (float *)stack_f16->stack_.base_.out_[OUTPUT_INDEX]->data_, + NNACLGetElementNum(stack_f16->stack_.base_.out_[OUTPUT_INDEX])); + } + + for (size_t i = 0; i < (stack_f16->stack_.base_.in_size_ + stack_f16->stack_.base_.out_size_); ++i) { + if (stack_f16->init_[i]) { + stack_f16->stack_.base_.env_->Free(stack_f16->stack_.base_.env_->allocator_, stack_f16->stack_.buffers_[i]); + } + stack_f16->stack_.buffers_[i] = NULL; + } + + stack_f16->stack_.base_.env_->Free(stack_f16->stack_.base_.env_->allocator_, stack_f16->init_); + stack_f16->init_ = NULL; +} + +int StackF16Compute(KernelBase *self) { + StackF16Struct *stack_f16 = (StackF16Struct *)self; + NNACL_CHECK_NULL_RETURN_ERR(stack_f16); + + int ret = StackF16InitMallocFlags(stack_f16); + if (ret != NNACL_OK) { + return ret; + } + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, StackRun, self, self->thread_nr_); + StackF16FreeBuffer(stack_f16); + return ret; +} + +KernelBase *CreateStackF16(OpParameter *param, int data_type) { + StackF16Struct *stack_f16 = (StackF16Struct *)malloc(sizeof(StackF16Struct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(stack_f16); + StackStruct *stack = &stack_f16->stack_; + stack->buffers_ = NULL; + stack->data_type_ = data_type; + stack->base_.Release = StackRelease; + stack->base_.Prepare = StackPrepare; + stack->base_.Resize = StackResize; + stack->base_.Compute = StackF16Compute; + return (KernelBase *)stack; +} + +REG_KERNEL_CREATOR(PrimType_Stack, kNumberTypeFloat16, CreateStackF16) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/f16/stack_f16.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/f16/stack_f16.h new file mode 100644 index 0000000000000000000000000000000000000000..53a438b4c22b02caed514005c996d3bc4b99a08b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/f16/stack_f16.h @@ -0,0 +1,32 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_F16_STACK_F16_H_ +#define NNACL_KERNEL_F16_STACK_F16_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" +#include "nnacl/kernel/stack.h" + +typedef struct StackF16Struct { + StackStruct stack_; + bool *init_; +} StackF16Struct; + +KernelBase *CreateStackF16(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_F16_STACK_F16_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/fill.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/fill.c new file mode 100644 index 0000000000000000000000000000000000000000..86acc90582ea259bca3224ff4ba6c2ac45ac25be --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/fill.c @@ -0,0 +1,102 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/fill.h" +#include "nnacl/fill_parameter.h" +#include "nnacl/op_base.h" +#include "nnacl/nnacl_common.h" +#include "nnacl/tensor_c_utils.h" +#include "nnacl/base/fill_base.h" +#include "nnacl/kernel/default_kernel_base.h" +#ifdef ENABLE_FP16 +#include "nnacl/fp16/fill_fp16.h" +#endif + +int FillResize(struct KernelBase *self) { + FillStruct *fill = (FillStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(fill); + fill->base_.thread_nr_ = fill->base_.UpdateThread( + TC_PTYPE(PrimType_Fill), 0, 1, NNACLGetSize(fill->base_.out_[OUTPUT_INDEX]), fill->base_.thread_nr_); + + NNACL_CHECK_NULL_RETURN_ERR(fill->base_.out_[OUTPUT_INDEX]); + fill->data_size_ = (int)NNACLGetElementNum(fill->base_.out_[OUTPUT_INDEX]); + fill->thread_sz_count_ = MSMIN(fill->base_.thread_nr_, fill->data_size_); + if (fill->thread_sz_count_ != 0) { + fill->thread_sz_stride_ = UP_DIV(fill->data_size_, fill->thread_sz_count_); + } + return NNACL_OK; +} + +int FillImpl(void *cdata, int task_id, float l, float r) { + FillStruct *fill = (FillStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(fill); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(task_id, fill->thread_sz_stride_, NNACL_ERR); + int size = MSMIN(fill->thread_sz_stride_, fill->data_size_ - task_id * fill->thread_sz_stride_); + NNACL_CHECK_FALSE(size <= 0, NNACL_OK); + int offset = task_id * fill->thread_sz_stride_; + int ret = NNACL_OK; + switch (fill->base_.in_[FIRST_INPUT]->data_type_) { +#ifdef ENABLE_FP16 + case kNumberTypeFloat16: + ret = FillFp16((float16_t *)fill->out_ptr_ + offset, size, ((float16_t *)fill->src_data_)[FIRST_INPUT]); + break; +#endif + case kNumberTypeFloat32: + ret = FillFp32((float *)fill->out_ptr_ + offset, size, ((float *)fill->src_data_)[FIRST_INPUT]); + break; + case kNumberTypeInt32: + ret = FillInt32((int *)fill->out_ptr_ + offset, size, ((int *)fill->src_data_)[FIRST_INPUT]); + break; + case kNumberTypeBool: + ret = FillBool((bool *)fill->out_ptr_ + offset, size, ((bool *)fill->src_data_)[FIRST_INPUT]); + break; + default: + return NNACL_FILL_DATA_TYPE_INVALID; + } + return ret; +} + +int FillCompute(struct KernelBase *self) { + FillStruct *fill = (FillStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(fill); + + fill->src_data_ = (void *)fill->base_.in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(fill->src_data_); + fill->out_ptr_ = (void *)fill->base_.out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(fill->out_ptr_); + + return self->env_->ParallelLaunch(self->env_->thread_pool_, FillImpl, fill, fill->base_.thread_nr_); +} + +KernelBase *CreateFill(OpParameter *param, int data_type) { + FillStruct *fill = (FillStruct *)malloc(sizeof(FillStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(fill); + fill->base_.Prepare = DefaultPrepare2In1Out; + fill->base_.Resize = FillResize; + fill->base_.Release = DefaultRelease; + fill->base_.Compute = FillCompute; + return (KernelBase *)fill; +} + +REG_KERNEL_CREATOR(PrimType_Fill, kNumberTypeBool, CreateFill); +REG_KERNEL_CREATOR(PrimType_Fill, kNumberTypeInt32, CreateFill); +REG_KERNEL_CREATOR(PrimType_Fill, kNumberTypeFloat32, CreateFill); +REG_KERNEL_CREATOR(PrimType_Fill, kNumberTypeFloat16, CreateFill); + +REG_KERNEL_CREATOR(PrimType_FillV2, kNumberTypeBool, CreateFill); +REG_KERNEL_CREATOR(PrimType_FillV2, kNumberTypeInt32, CreateFill); +REG_KERNEL_CREATOR(PrimType_FillV2, kNumberTypeFloat32, CreateFill); +REG_KERNEL_CREATOR(PrimType_FillV2, kNumberTypeFloat16, CreateFill); diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/fill.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/fill.h new file mode 100644 index 0000000000000000000000000000000000000000..15f5cb790e98860fecbc12035252ed63de6683f1 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/fill.h @@ -0,0 +1,36 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_FILL_H_ +#define NNACL_KERNEL_FILL_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct FillStruct { + KernelBase base_; + int thread_sz_count_; + int thread_sz_stride_; + int data_size_; + void *src_data_; + void *out_ptr_; + int thread_count_; +} FillStruct; + +KernelBase *CreateFill(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_FILL_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/fullconnection.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/fullconnection.c new file mode 100644 index 0000000000000000000000000000000000000000..a358e91cf058ba1f95d313dc7d268d4ff9519f9b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/fullconnection.c @@ -0,0 +1,81 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/fullconnection.h" +#include "nnacl/kernel/matmul_base.h" +#include "nnacl/kernel/matmul_create.h" + +int FullConnectionPrepare(KernelBase *self) { + MatmulStruct *matmul = (MatmulStruct *)self; + + NNACL_CHECK_FALSE(self->in_size_ < C2NUM, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < C1NUM, NNACL_ERR); + + if (matmul->a_const_ || matmul->infer_shape_) { + int *a_shape = self->in_[FIRST_INPUT]->shape_; + matmul->compute_.row_ = a_shape[0]; + matmul->compute_.deep_ = a_shape[1]; + } + + if (matmul->b_const_ || matmul->infer_shape_) { + int *b_shape = self->in_[SECOND_INPUT]->shape_; + matmul->compute_.col_ = b_shape[0]; + matmul->compute_.deep_ = b_shape[1]; + } + + matmul->batch_ = 1; + matmul->a_batch_ = 1; + matmul->b_batch_ = 1; + + MatMulParameter *param = (MatMulParameter *)matmul->base_.param_; + param->a_transpose_ = false; + param->b_transpose_ = true; + + int ret = MatmulBaseMallocBatchOffset(matmul); + if (ret != NNACL_OK) { + return ret; + } + + return MatmulBasePrepare(self); +} + +int FullConnectionResize(KernelBase *self) { + MatmulStruct *matmul = (MatmulStruct *)self; + NNACL_CHECK_TRUE_RET(self->out_[0]->shape_size_ > 0, NNACL_ERR); + + int row = 1; + for (size_t i = 0; i < self->out_[0]->shape_size_ - 1; ++i) { + row *= (self->out_[OUTPUT_INDEX]->shape_)[i]; + } + matmul->compute_.row_ = row; + matmul->compute_.col_ = (self->out_[OUTPUT_INDEX]->shape_)[self->out_[0]->shape_size_ - 1]; + matmul->compute_.deep_ = self->in_[SECOND_INPUT]->shape_[SECOND_INPUT]; + + return MatmulBaseResize(self); +} + +KernelBase *CreateFullconnection(OpParameter *param, int data_type) { + KernelBase *kernel = NULL; + if (data_type == kNumberTypeFloat32) { + kernel = CreateMatmulKernel(); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(kernel); + kernel->Prepare = FullConnectionPrepare; + kernel->Resize = FullConnectionResize; + } + return kernel; +} + +REG_KERNEL_CREATOR(PrimType_FullConnection, kNumberTypeFloat32, CreateFullconnection); diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/fullconnection.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/fullconnection.h new file mode 100644 index 0000000000000000000000000000000000000000..678c047df4c3a3a3e912d87d0a7760714bf1e924 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/fullconnection.h @@ -0,0 +1,25 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_FULLCONNECTION_H_ +#define NNACL_KERNEL_FULLCONNECTION_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +KernelBase *CreateFullconnection(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_FULLCONNECTION_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/fused_batch_norm.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/fused_batch_norm.c new file mode 100644 index 0000000000000000000000000000000000000000..ef19be66048a29cbcaab5e4837ff4e88c4756472 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/fused_batch_norm.c @@ -0,0 +1,327 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/fused_batch_norm.h" +#include +#include "nnacl/op_base.h" +#include "nnacl/kernel/default_kernel_base.h" +#include "nnacl/tensor_c_utils.h" +#include "nnacl/batchnorm_parameter.h" +#include "nnacl/fp32/batchnorm_fp32.h" +#include "nnacl/fp32/scale_fp32.h" +#ifdef ENABLE_FP16 +#include "nnacl/fp16/scale_fp16.h" +#include "nnacl/fp16/batchnorm_fp16.h" +#endif + +int FusedBatchNormInitScaleParam(FusedBatchNormStruct *fused_batch_norm) { + ScaleStruct *scale = &fused_batch_norm->scale_param_; + scale->base_.thread_nr_ = fused_batch_norm->bn_.base_.thread_nr_; + + scale->axis_ = kNHWC_C; + TensorC *in_tensor = fused_batch_norm->bn_.base_.in_[FIRST_INPUT]; + if (in_tensor->shape_size_ != DIMENSION_4D) { + return NNACL_FUSED_BATCH_NORM_NO_CHANGE; + } + + scale->outer_size_ = 1; + for (int i = 0; i < scale->axis_; i++) { + scale->outer_size_ *= in_tensor->shape_[i]; + } + scale->axis_size_ = in_tensor->shape_[Index3]; + scale->inner_size_ = 1; + return NNACL_OK; +} + +void FusedBatchNormCalculateScaleF32(FusedBatchNormStruct *fbn, const void *scale_data, const void *bias_data, + const void *mean_data, const void *var_data, float eps, int kernel_num) { + float *fp32_scale_origin = (float *)scale_data; + float *fp32_var_origin = (float *)var_data; + float *fp32_bias_origin = (float *)bias_data; + float *fp32_mean_origin = (float *)mean_data; + + float *fp32_scale = (float *)fbn->scale_; + for (int i = 0; i < kernel_num; i++) { + fp32_scale[i] = fp32_scale_origin[i] / sqrtf(fp32_var_origin[i] + eps); + } + + float *fp32_offset = (float *)fbn->offset_; + for (int i = 0; i < kernel_num; i++) { + fp32_offset[i] = fp32_bias_origin[i] - fp32_mean_origin[i] * fp32_scale[i]; + } +} + +void FusedBatchNormCalculateScaleF16(FusedBatchNormStruct *fbn, const void *scale_data, const void *bias_data, + const void *mean_data, const void *var_data, float eps, int kernel_num) { +#ifdef ENABLE_FP16 + float16_t *fp16_scale_origin = (float16_t *)scale_data; + float16_t *fp16_var_origin = (float16_t *)var_data; + float16_t *fp16_bias_origin = (float16_t *)bias_data; + float16_t *fp16_mean_origin = (float16_t *)mean_data; + + float16_t *fp16_scale = (float16_t *)fbn->scale_; + for (int i = 0; i < kernel_num; i++) { + fp16_scale[i] = fp16_scale_origin[i] / sqrtf(fp16_var_origin[i] + eps); + } + + float16_t *fp16_offset = (float16_t *)fbn->offset_; + for (int i = 0; i < kernel_num; i++) { + fp16_offset[i] = fp16_bias_origin[i] - fp16_mean_origin[i] * fp16_scale[i]; + } +#endif +} + +void FusedBatchNormRunFp16(FusedBatchNormStruct *fused_batch_norm, int task_id) { +#ifdef ENABLE_FP16 + void *in_data = fused_batch_norm->bn_.base_.in_[FIRST_INPUT]->data_; + void *out_data = fused_batch_norm->bn_.base_.out_[OUTPUT_INDEX]->data_; + + if (fused_batch_norm->is_scale_) { + DoScaleFp16((float16_t *)in_data, (float16_t *)out_data, (float16_t *)fused_batch_norm->scale_, + (float16_t *)fused_batch_norm->offset_, task_id, &fused_batch_norm->scale_param_); + } else { + FusedBatchNormFp16((float16_t *)in_data, (float16_t *)fused_batch_norm->scale_, + (float16_t *)fused_batch_norm->offset_, (float16_t *)fused_batch_norm->bn_.mean_, + (float16_t *)fused_batch_norm->bn_.variance_, &fused_batch_norm->bn_, task_id, + fused_batch_norm->bn_.base_.thread_nr_, (float16_t *)out_data); + } +#endif +} + +int FusedBatchNormBatchnorm2Scale(FusedBatchNormStruct *fused_batch_norm, const void *scale_data, const void *bias_data, + const void *mean_data, const void *var_data, float eps, int kernel_num) { + int ret = FusedBatchNormInitScaleParam(fused_batch_norm); + if (ret != NNACL_OK) { + return ret; + } + + ExecEnv *env = fused_batch_norm->bn_.base_.env_; + TensorC *scale_tensor = fused_batch_norm->bn_.base_.in_[SECOND_INPUT]; + fused_batch_norm->scale_ = env->Alloc(env->allocator_, NNACLGetSize(scale_tensor)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(fused_batch_norm->scale_); + TensorC *offset_tensor = fused_batch_norm->bn_.base_.in_[THIRD_INPUT]; + fused_batch_norm->offset_ = env->Alloc(env->allocator_, NNACLGetSize(offset_tensor)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(fused_batch_norm->offset_); + + // new scale: -scale / sqrt(variance + eps) + // new bias: -scale * mean / sqrt(variance + eps) + bias + if (fused_batch_norm->bn_.data_type_ == kNumberTypeFloat16) { + FusedBatchNormCalculateScaleF16(fused_batch_norm, scale_data, bias_data, mean_data, var_data, eps, kernel_num); + } else { + FusedBatchNormCalculateScaleF32(fused_batch_norm, scale_data, bias_data, mean_data, var_data, eps, kernel_num); + } + + fused_batch_norm->is_scale_ = true; + return NNACL_OK; +} + +int FusedBatchNormInitConstTensor(FusedBatchNormStruct *fused_batch_norm) { + TensorC *scale_tensor = fused_batch_norm->bn_.base_.in_[SECOND_INPUT]; + TensorC *offset_tensor = fused_batch_norm->bn_.base_.in_[THIRD_INPUT]; + TensorC *mean_tensor = fused_batch_norm->bn_.base_.in_[FOURTH_INPUT]; + TensorC *variance_tensor = fused_batch_norm->bn_.base_.in_[FIFTH_INPUT]; + + if (!fused_batch_norm->bn_.base_.train_session_) { + int ret = FusedBatchNormBatchnorm2Scale( + fused_batch_norm, (float *)scale_tensor->data_, (float *)offset_tensor->data_, (float *)mean_tensor->data_, + (float *)variance_tensor->data_, fused_batch_norm->bn_.epsilon_, NNACLGetElementNum(scale_tensor)); + if (ret == NNACL_OK) { + return NNACL_OK; + } else { + fused_batch_norm->bn_.base_.Release(&fused_batch_norm->bn_.base_); + if (ret != NNACL_FUSED_BATCH_NORM_NO_CHANGE) { + return NNACL_FUSED_BATCH_NORM_TO_SCALE_FAILED; + } + } + } + + ExecEnv *env = fused_batch_norm->bn_.base_.env_; + fused_batch_norm->scale_ = env->Alloc(env->allocator_, NNACLGetSize(scale_tensor)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(fused_batch_norm->scale_); + (void)memcpy(fused_batch_norm->scale_, scale_tensor->data_, NNACLGetSize(scale_tensor)); + fused_batch_norm->offset_ = env->Alloc(env->allocator_, NNACLGetSize(offset_tensor)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(fused_batch_norm->offset_); + (void)memcpy(fused_batch_norm->offset_, offset_tensor->data_, NNACLGetSize(offset_tensor)); + fused_batch_norm->bn_.mean_ = env->Alloc(env->allocator_, NNACLGetSize(mean_tensor)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(fused_batch_norm->bn_.mean_); + (void)memcpy(fused_batch_norm->bn_.mean_, mean_tensor->data_, NNACLGetSize(mean_tensor)); + fused_batch_norm->bn_.variance_ = env->Alloc(env->allocator_, NNACLGetSize(variance_tensor)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(fused_batch_norm->bn_.variance_); + (void)memcpy(fused_batch_norm->bn_.variance_, variance_tensor->data_, NNACLGetSize(variance_tensor)); + return NNACL_OK; +} + +void FusedBatchNormRunFp32(FusedBatchNormStruct *fused_batch_norm, int task_id) { + void *in_data = fused_batch_norm->bn_.base_.in_[FIRST_INPUT]->data_; + void *out_data = fused_batch_norm->bn_.base_.out_[OUTPUT_INDEX]->data_; + + if (fused_batch_norm->is_scale_) { + DoScale((float *)in_data, (float *)out_data, (float *)fused_batch_norm->scale_, (float *)fused_batch_norm->offset_, + task_id, &fused_batch_norm->scale_param_); + } else { + FusedBatchNormFp32((float *)in_data, (float *)fused_batch_norm->scale_, (float *)fused_batch_norm->offset_, + (float *)fused_batch_norm->bn_.mean_, (float *)fused_batch_norm->bn_.variance_, + &fused_batch_norm->bn_, task_id, fused_batch_norm->bn_.base_.thread_nr_, (float *)out_data); + } +} + +int FusedBatchNormRun(void *cdata, int task_id, float l, float r) { + FusedBatchNormStruct *fused_batch_norm = (FusedBatchNormStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(fused_batch_norm); + if (fused_batch_norm->bn_.data_type_ == kNumberTypeFloat16) { + FusedBatchNormRunFp16(fused_batch_norm, task_id); + } else if (fused_batch_norm->bn_.data_type_ == kNumberTypeFloat32) { + FusedBatchNormRunFp32(fused_batch_norm, task_id); + } + return NNACL_OK; +} + +int FusedBatchNormTrainComputeInit(FusedBatchNormStruct *fbn) { + if (fbn->bn_.base_.out_size_ < Num5) { + return NNACL_OK; + } + + TensorC *out_scale = fbn->bn_.base_.out_[SECOND_INPUT]; + TensorC *out_offset = fbn->bn_.base_.out_[THIRD_INPUT]; + TensorC *out_mean = fbn->bn_.base_.out_[FOURTH_INPUT]; + TensorC *out_var = fbn->bn_.base_.out_[FIFTH_INPUT]; + + void *current_mean = fbn->bn_.mean_; + void *current_var = fbn->bn_.variance_; + + bool schema_trained = ((BatchNormParameter *)fbn->bn_.base_.param_)->is_training_; + if (fbn->train_mode_ && schema_trained && fbn->bn_.base_.in_size_ >= Num5) { + TensorC *in_tensor = fbn->bn_.base_.in_[FIRST_INPUT]; + TensorC *scale_tensor = fbn->bn_.base_.in_[SECOND_INPUT]; + TensorC *offset_tensor = fbn->bn_.base_.in_[THIRD_INPUT]; + TensorC *mean_tensor = fbn->bn_.base_.in_[FOURTH_INPUT]; + TensorC *var_tensor = fbn->bn_.base_.in_[FIFTH_INPUT]; + if (in_tensor->data_ == NULL || scale_tensor->data_ == NULL || offset_tensor->data_ == NULL || + mean_tensor->data_ == NULL || var_tensor->data_ == NULL) { + return NNACL_FUSED_BATCH_TRAIN_DATA_INVALID; + } + + memset(current_mean, 0, NNACLGetSize(mean_tensor)); + memset(current_var, 0, NNACLGetSize(var_tensor)); + + bool isBatch2d = true; + if (fbn->bn_.base_.in_[FIRST_INPUT]->shape_size_ == Num2) isBatch2d = false; + + if (fbn->bn_.data_type_ == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + FusedBatchNormFp16MeanVar((float16_t *)in_tensor->data_, (float16_t *)current_mean, current_var, &fbn->bn_, + (float16_t *)mean_tensor->data_, (float16_t *)var_tensor->data_); +#endif + } else { + FusedBatchNormFp32MeanVar((float *)in_tensor->data_, (float *)current_mean, current_var, &fbn->bn_, + (float *)mean_tensor->data_, (float *)var_tensor->data_, isBatch2d); + } + + (void)memcpy(out_scale->data_, scale_tensor->data_, NNACLGetSize(out_scale)); + (void)memcpy(out_offset->data_, offset_tensor->data_, NNACLGetSize(out_offset)); + (void)memcpy(out_mean->data_, current_mean, NNACLGetSize(out_mean)); + (void)memcpy(out_var->data_, current_var, NNACLGetSize(out_var)); + + // Copy to local variables + (void)memcpy(fbn->scale_, scale_tensor->data_, NNACLGetSize(scale_tensor)); + (void)memcpy(fbn->offset_, offset_tensor->data_, NNACLGetSize(offset_tensor)); + + fbn->trained_ = true; // trained at least once + return NNACL_OK; + } + + if (fbn->bn_.base_.train_session_) { + (void)memcpy(out_scale->data_, fbn->scale_, NNACLGetSize(out_scale)); + (void)memcpy(out_offset->data_, fbn->offset_, NNACLGetSize(out_offset)); + (void)memcpy(out_mean->data_, current_mean, NNACLGetSize(out_mean)); + (void)memcpy(out_var->data_, current_var, NNACLGetSize(out_var)); + } + + return NNACL_OK; +} + +int FusedBatchNormCompute(KernelBase *self) { + FusedBatchNormStruct *fused_batch_norm = (FusedBatchNormStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(fused_batch_norm); + + int ret = FusedBatchNormTrainComputeInit(fused_batch_norm); + if (ret != NNACL_OK) { + return ret; + } + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, FusedBatchNormRun, self, self->thread_nr_); + if (ret != NNACL_OK) { + return ret; + } + return NNACL_OK; +} + +int FusedBatchNormReSize(KernelBase *self) { + FusedBatchNormStruct *fused_batch_norm = (FusedBatchNormStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(fused_batch_norm); + + int ret = BatchNormFillParam(&fused_batch_norm->bn_); + if (ret != NNACL_OK) { + return ret; + } + + (void)self->Release(self); + + return FusedBatchNormInitConstTensor(fused_batch_norm); +} + +int FusedBatchNormPrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < FIVE_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR); + + FusedBatchNormStruct *fused_batch_norm = (FusedBatchNormStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(fused_batch_norm); + fused_batch_norm->bn_.momentum_ = ((BatchNormParameter *)self->param_)->momentum_; + fused_batch_norm->bn_.epsilon_ = ((BatchNormParameter *)self->param_)->epsilon_; + return NNACL_OK; +} + +int FusedBatchNormRelease(KernelBase *self) { + FusedBatchNormStruct *fused_batch_norm = (FusedBatchNormStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(fused_batch_norm); + + (void)BatchNormRelease(&fused_batch_norm->bn_.base_); + + if (fused_batch_norm->scale_ != NULL) { + self->env_->Free(self->env_->allocator_, fused_batch_norm->scale_); + fused_batch_norm->scale_ = NULL; + } + if (fused_batch_norm->offset_ != NULL) { + self->env_->Free(self->env_->allocator_, fused_batch_norm->offset_); + fused_batch_norm->offset_ = NULL; + } + return NNACL_OK; +} + +KernelBase *CreateFusedBatchNorm(OpParameter *param, int data_type) { + FusedBatchNormStruct *fused_batch_norm = (FusedBatchNormStruct *)malloc(sizeof(FusedBatchNormStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(fused_batch_norm); + memset(fused_batch_norm, 0, sizeof(FusedBatchNormStruct)); + fused_batch_norm->bn_.data_type_ = data_type; + fused_batch_norm->bn_.base_.Prepare = FusedBatchNormPrepare; + fused_batch_norm->bn_.base_.Resize = FusedBatchNormReSize; + fused_batch_norm->bn_.base_.Release = FusedBatchNormRelease; + fused_batch_norm->bn_.base_.Compute = FusedBatchNormCompute; + return (KernelBase *)fused_batch_norm; +} + +REG_KERNEL_CREATOR(PrimType_FusedBatchNorm, kNumberTypeFloat16, CreateFusedBatchNorm) +REG_KERNEL_CREATOR(PrimType_FusedBatchNorm, kNumberTypeFloat32, CreateFusedBatchNorm) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/fused_batch_norm.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/fused_batch_norm.h new file mode 100644 index 0000000000000000000000000000000000000000..e703c554ea00f4ddf8f05c7191f06949c6bc643e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/fused_batch_norm.h @@ -0,0 +1,37 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_FUSED_BATCH_NORM_H_ +#define NNACL_KERNEL_FUSED_BATCH_NORM_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" +#include "nnacl/kernel/batch_norm.h" +#include "nnacl/kernel/scale.h" + +typedef struct FusedBatchNormStruct { + BatchNormStruct bn_; + ScaleStruct scale_param_; + void *scale_; + void *offset_; + bool is_scale_; + bool trained_; + bool train_mode_; +} FusedBatchNormStruct; + +KernelBase *CreateFusedBatchNorm(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_FUSED_BATCH_NORM_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/gather.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/gather.c new file mode 100644 index 0000000000000000000000000000000000000000..b72058dd72d609025a24c4f3073a1f366c6e9bb4 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/gather.c @@ -0,0 +1,241 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/gather.h" +#include "nnacl/nnacl_common.h" +#include "nnacl/op_base.h" +#include "nnacl/kernel/default_kernel_base.h" +#include "nnacl/tensor_c_utils.h" + +#define kGatherMinCostPerThread 16384 + +void GatherHandleCopy(GatherStruct *gather, int8_t **int8_in, int8_t **int8_out, int begin, int end, + int byte_in_stride) { + for (; begin < end; ++begin) { + int index = gather->indices_data_[begin]; + index = (index < 0 ? index + gather->limit_ : index); + if (index < 0 || index >= gather->limit_) { + memset(*int8_out, 0, gather->byte_inner_size_); + } else { + memcpy(*int8_out, *int8_in + index * gather->byte_inner_size_, gather->byte_inner_size_); + } + *int8_out += gather->byte_inner_size_; + } + *int8_in += byte_in_stride; +} + +int GatherRun(void *cdata, int task_id, float l, float r) { + GatherStruct *gather = (GatherStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(gather); + NNACL_CHECK_FALSE(task_id < 0, NNACL_ERR); + NNACL_CHECK_FALSE(task_id >= gather->block_infos_size_, NNACL_ERR); + + int8_t *int8_in = (int8_t *)(gather->base_.in_[FIRST_INPUT]->data_); + NNACL_CHECK_NULL_RETURN_ERR(int8_in); + int8_t *int8_out = (int8_t *)(gather->base_.out_[OUTPUT_INDEX]->data_); + NNACL_CHECK_NULL_RETURN_ERR(int8_out); + int begin_batch = gather->block_infos_[task_id].begin_batch_; + int begin_index = gather->block_infos_[task_id].begin_index_; + int end_batch = gather->block_infos_[task_id].end_batch_; + int end_index = gather->block_infos_[task_id].end_index_; + int64_t byte_in_stride = gather->limit_ * gather->byte_inner_size_; + int8_in += begin_batch * byte_in_stride; + int8_out += begin_batch * gather->indices_size_ * gather->byte_inner_size_ + begin_index * gather->byte_inner_size_; + if (begin_batch == end_batch) { + GatherHandleCopy(gather, &int8_in, &int8_out, begin_index, end_index, byte_in_stride); + return NNACL_OK; + } + GatherHandleCopy(gather, &int8_in, &int8_out, begin_index, gather->indices_size_, byte_in_stride); + ++begin_batch; + for (; begin_batch < end_batch; ++begin_batch) { + GatherHandleCopy(gather, &int8_in, &int8_out, 0, gather->indices_size_, byte_in_stride); + } + GatherHandleCopy(gather, &int8_in, &int8_out, 0, end_index, byte_in_stride); + return NNACL_OK; +} + +int AssignGatherIndicesData(GatherStruct *gather, bool is_indices_int32) { + TensorC *indices_tensor = gather->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(indices_tensor->data_); + + if (is_indices_int32) { + gather->indices_data_ = (int *)(indices_tensor->data_); + return NNACL_OK; + } + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(gather->indices_size_, (int)(sizeof(int)), NNACL_ERR); + gather->indices_data_ = + (int *)(gather->base_.env_->Alloc(gather->base_.env_->allocator_, gather->indices_size_ * sizeof(int))); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(gather->indices_data_); + + switch (indices_tensor->data_type_) { + case kNumberTypeInt64: + for (int i = 0; i < gather->indices_size_; i++) { + gather->indices_data_[i] = (int)((int64_t *)indices_tensor->data_)[i]; + } + break; + case kNumberTypeFloat: + case kNumberTypeFloat32: + for (int i = 0; i < gather->indices_size_; i++) { + gather->indices_data_[i] = (int)((float *)indices_tensor->data_)[i]; + } + break; + case kNumberTypeBool: + for (int i = 0; i < gather->indices_size_; i++) { + gather->indices_data_[i] = (int)((bool *)indices_tensor->data_)[i]; + } + break; + default: + return NNACL_UNSUPPORTED_DATA_TYPE; + } + return NNACL_OK; +} + +int InitGatherDynamicStatus(GatherStruct *gather) { + int *in_shape = gather->base_.in_[FIRST_INPUT]->shape_; + int in_rank = (int)gather->base_.in_[FIRST_INPUT]->shape_size_; + NNACL_CHECK_TRUE_RET(gather->axis_ >= 0 && gather->axis_ < in_rank, NNACL_GATHER_AXIS_INVALID); + gather->limit_ = in_shape[gather->axis_]; + gather->outer_size_ = 1; + for (int i = 0; i < gather->axis_; ++i) { + gather->outer_size_ *= in_shape[i]; + } + gather->byte_inner_size_ = (int)DataTypeCSize(gather->base_.out_[OUTPUT_INDEX]->data_type_); + for (int i = gather->axis_ + 1; i < in_rank; ++i) { + gather->byte_inner_size_ *= in_shape[i]; + } + gather->indices_size_ = NNACLGetElementNum(gather->base_.in_[SECOND_INPUT]); + return NNACL_OK; +} + +void GatherUpdateThreadNumProcess(GatherStruct *gather) { + int all_bytes = NNACLGetSize(gather->base_.out_[OUTPUT_INDEX]); + if (all_bytes <= kGatherMinCostPerThread) { + gather->base_.thread_nr_ = 1; + return; + } + + gather->base_.thread_nr_ = + gather->base_.UpdateThread(TC_PTYPE(PrimType_Gather), 0, gather->byte_inner_size_, + NNACLGetSize(gather->base_.out_[OUTPUT_INDEX]), gather->base_.thread_nr_); + return; +} + +int ChooseGatherThreadCuttingStrategy(GatherStruct *gather) { + gather->block_infos_size_ = 0; + if (gather->outer_size_ == 0 || gather->indices_size_ == 0 || gather->byte_inner_size_ == 0) { + return NNACL_OK; + } + GatherUpdateThreadNumProcess(gather); + if (gather->base_.thread_nr_ > GATHER_BLOCK_INFOS_SIZE) { + gather->base_.thread_nr_ = GATHER_BLOCK_INFOS_SIZE; + } + + if (gather->base_.thread_nr_ == 1) { + gather->block_infos_[gather->block_infos_size_].begin_batch_ = 0; + gather->block_infos_[gather->block_infos_size_].begin_index_ = 0; + gather->block_infos_[gather->block_infos_size_].end_batch_ = gather->outer_size_; + gather->block_infos_[gather->block_infos_size_].end_index_ = 0; + gather->block_infos_size_++; + } else { + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(gather->outer_size_, gather->indices_size_, NNACL_ERR); + int total_block = gather->outer_size_ * gather->indices_size_; + int block_size = total_block / gather->base_.thread_nr_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(block_size, gather->base_.thread_nr_, NNACL_ERR); + int remain_block = total_block - block_size * gather->base_.thread_nr_; + int start = 0; + while (start < total_block) { + GatherBlockBoundaryInfo block_boundary_info; + block_boundary_info.begin_batch_ = start / gather->indices_size_; + block_boundary_info.begin_index_ = start % gather->indices_size_; + start += block_size; + if (remain_block > 0) { + ++start; + --remain_block; + } + if (start >= total_block) { + start = total_block; + } + block_boundary_info.end_batch_ = start / gather->indices_size_; + block_boundary_info.end_index_ = start % gather->indices_size_; + gather->block_infos_[gather->block_infos_size_++] = block_boundary_info; + } + gather->base_.thread_nr_ = gather->block_infos_size_; + } + + return NNACL_OK; +} + +int GatherResize(KernelBase *self) { + GatherStruct *gather = (GatherStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(gather); + + int status = InitGatherDynamicStatus(gather); + NNACL_CHECK_FALSE(status != NNACL_OK, status); + + return ChooseGatherThreadCuttingStrategy(gather); +} + +int GatherPrepare(struct KernelBase *self) { + GatherStruct *gather = (GatherStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(gather); + NNACL_CHECK_FALSE(self->in_size_ < THREE_TENSOR, NNACL_GATHER_INPUT_TENSOR_INVALID); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_GATHER_OUTPUT_TENSOR_INVALID); + NNACL_CHECK_NULL_RETURN_ERR(self->in_[THIRD_INPUT]); + NNACL_CHECK_NULL_RETURN_ERR(self->in_[THIRD_INPUT]->data_); + gather->axis_ = *((int *)self->in_[THIRD_INPUT]->data_); + return NNACL_OK; +} + +int GatherCompute(struct KernelBase *self) { + GatherStruct *gather = (GatherStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(gather); + + if (gather->outer_size_ == 0 || gather->indices_size_ == 0 || gather->byte_inner_size_ == 0) { + return NNACL_OK; + } + + bool is_indices_int32 = self->in_[SECOND_INPUT]->data_type_ == kNumberTypeInt32; + int ret = AssignGatherIndicesData(gather, is_indices_int32); + if (ret != NNACL_OK) { + return ret; + } + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, GatherRun, gather, gather->base_.thread_nr_); + + if (!is_indices_int32) { + self->env_->Free(self->env_->allocator_, gather->indices_data_); + gather->indices_data_ = NULL; + } + return ret; +} + +KernelBase *CreateGather(OpParameter *param, int data_type) { + GatherStruct *gather = (GatherStruct *)malloc(sizeof(GatherStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(gather); + gather->indices_data_ = NULL; + gather->block_infos_size_ = 0; + gather->base_.Prepare = GatherPrepare; + gather->base_.Resize = GatherResize; + gather->base_.Release = DefaultRelease; + gather->base_.Compute = GatherCompute; + return (KernelBase *)gather; +} + +REG_KERNEL_CREATOR(PrimType_Gather, kNumberTypeFloat16, CreateGather) +REG_KERNEL_CREATOR(PrimType_Gather, kNumberTypeFloat32, CreateGather) +REG_KERNEL_CREATOR(PrimType_Gather, kNumberTypeInt32, CreateGather) +REG_KERNEL_CREATOR(PrimType_Gather, kNumberTypeBool, CreateGather) diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/fse_decode.cuh b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/gather.h similarity index 46% rename from mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/fse_decode.cuh rename to mindspore-lite/ops/kernel/cpu/nnacl/kernel/gather.h index 0d95eeb3c3d81ff629fe723bebeb3a4f1c7113b6..e58a0750affcc91610eb9f0c3ee10a0e8c1902be 100644 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/fse_decode.cuh +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/gather.h @@ -13,17 +13,34 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#ifndef NNACL_KERNEL_GATHER_H_ +#define NNACL_KERNEL_GATHER_H_ -#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_CDUA_IMPL_FSE_DECODE_IMPL_H_ -#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_CDUA_IMPL_FSE_DECODE_IMPL_H_ +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" -#include -#include "tools/converter/quantizer/fse_chunk_end.h" +#define GATHER_BLOCK_INFOS_SIZE 32 -template -void FSE_Decode(const uint64_t *chunks, const uint16_t *states_table, const uint8_t *bit_count_table, - const uint16_t *symbol_table, const uint64_t *ptable, int ptable_size, const T *centroids, - uint64_t output_size, T *output, const uint32_t &device_id, uint64_t current_chunk_input, - bool use_curr_chunk, cudaStream_t cuda_stream); +typedef struct GatherBlockBoundaryInfo { + int64_t begin_batch_; + int64_t begin_index_; + int64_t end_batch_; + int64_t end_index_; +} GatherBlockBoundaryInfo; -#endif // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_CDUA_IMPL_FSE_DECODE_IMPL_H_ +typedef struct GatherStruct { + KernelBase base_; + int axis_; + int limit_; + int outer_size_; + int indices_size_; + int byte_inner_size_; + int block_infos_size_; + int *indices_data_; + GatherBlockBoundaryInfo block_infos_[GATHER_BLOCK_INFOS_SIZE]; +} GatherStruct; + +KernelBase *CreateGather(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_GATHER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/gather_d.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/gather_d.c new file mode 100644 index 0000000000000000000000000000000000000000..62887c8b6252cd92e7a4b33ee1079f299ce72bda --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/gather_d.c @@ -0,0 +1,124 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either gather_dress or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/gather_d.h" +#include "nnacl/gather_parameter.h" +#include "nnacl/tensor_c_utils.h" +#include "nnacl/op_base.h" +#include "nnacl/base/gather_d_base.h" +#include "nnacl/kernel/default_kernel_base.h" + +typedef struct GatherDStru { + KernelBase base; +} GatherDStru; + +int GatherDPrepare(struct KernelBase *self) { + GatherDStru *gather_d = (GatherDStru *)self; + NNACL_CHECK_NULL_RETURN_ERR(gather_d); + GatherParameter *param = (GatherParameter *)gather_d->base.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + NNACL_CHECK_FALSE(self->in_size_ < kInputSize2 || self->out_size_ < 1, NNACL_TENSOR_SIZE_INVALID); + + param->axis_ = ((int *)(gather_d->base.in_[1]->data_))[0]; + return NNACL_OK; +} + +int GatherDResize(struct KernelBase *self) { + GatherDStru *gather_d = (GatherDStru *)self; + NNACL_CHECK_NULL_RETURN_ERR(gather_d); + GatherParameter *param = (GatherParameter *)gather_d->base.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + int input_rank = (int)gather_d->base.in_[0]->shape_size_; + NNACL_CHECK_FALSE(param->axis_ >= input_rank || param->axis_ < -input_rank, NNACL_GATHER_D_AXIS_INVALID); + + if (param->axis_ < 0) { + param->axis_ = param->axis_ + input_rank; + } + return NNACL_OK; +} + +int GatherDCompute(struct KernelBase *self) { + GatherDStru *gather_d_stru = (GatherDStru *)self; + NNACL_CHECK_NULL_RETURN_ERR(gather_d_stru); + GatherParameter *param = (GatherParameter *)gather_d_stru->base.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + + TensorC *input = gather_d_stru->base.in_[0]; + NNACL_CHECK_NULL_RETURN_ERR(input); + TensorC *output = gather_d_stru->base.out_[0]; + NNACL_CHECK_NULL_RETURN_ERR(output); + const void *input_data = input->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_data); + const void *index_data = gather_d_stru->base.in_[2]->data_; + NNACL_CHECK_NULL_RETURN_ERR(index_data); + void *output_data = output->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_data); + + size_t input_shape[MAX_SHAPE_SIZE]; + for (size_t i = 0; i < input->shape_size_; i++) { + input_shape[i] = input->shape_[i]; + } + size_t output_shape[MAX_SHAPE_SIZE]; + for (size_t i = 0; i < output->shape_size_; i++) { + output_shape[i] = output->shape_[i]; + } + + int input_dtype = input->data_type_; + int index_dtype = gather_d_stru->base.in_[THIRD_INPUT]->data_type_; + int status = NNACL_ERR; + if (index_dtype == kNumberTypeInt32) { + if (input_dtype == kNumberTypeFloat32) { + status = GATHER_D(float, int32_t, (float *)output_data, (float *)input_data, (int32_t *)index_data, input_shape, + input->shape_size_, output_shape, output->shape_size_, param->axis_); + } else if (input_dtype == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + status = GATHER_D(float16_t, int32_t, (float16_t *)output_data, (float16_t *)input_data, (int32_t *)index_data, + input_shape, input->shape_size_, output_shape, output->shape_size_, param->axis_); +#endif + } else if (input_dtype == kNumberTypeInt32) { + status = GATHER_D(int32_t, int32_t, (int32_t *)output_data, (int32_t *)input_data, (int32_t *)index_data, + input_shape, input->shape_size_, output_shape, output->shape_size_, param->axis_); + } + } else if (index_dtype == kNumberTypeInt64) { + if (input_dtype == kNumberTypeFloat32) { + status = GATHER_D(float, int64_t, (float *)output_data, (float *)input_data, (int64_t *)index_data, input_shape, + input->shape_size_, output_shape, output->shape_size_, param->axis_); + } else if (input_dtype == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + status = GATHER_D(float16_t, int64_t, (float16_t *)output_data, (float16_t *)input_data, (int64_t *)index_data, + input_shape, input->shape_size_, output_shape, output->shape_size_, param->axis_); +#endif + } else if (input_dtype == kNumberTypeInt32) { + status = GATHER_D(int32_t, int64_t, (int32_t *)output_data, (int32_t *)input_data, (int64_t *)index_data, + input_shape, input->shape_size_, output_shape, output->shape_size_, param->axis_); + } + } + return status; +} + +KernelBase *CreateGatherD(OpParameter *param, int data_type) { + GatherDStru *gather_d = (GatherDStru *)malloc(sizeof(GatherDStru)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(gather_d); + gather_d->base.Prepare = GatherDPrepare; + gather_d->base.Resize = GatherDResize; + gather_d->base.Release = DefaultRelease; + gather_d->base.Compute = GatherDCompute; + return (KernelBase *)gather_d; +} + +REG_KERNEL_CREATOR(PrimType_GatherD, kNumberTypeFloat32, CreateGatherD); +REG_KERNEL_CREATOR(PrimType_GatherD, kNumberTypeInt32, CreateGatherD); +REG_KERNEL_CREATOR(PrimType_GatherD, kNumberTypeFloat16, CreateGatherD); diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/gather_d.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/gather_d.h new file mode 100644 index 0000000000000000000000000000000000000000..2dfbf3ba346c6a2102397dc98665679ab748dae1 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/gather_d.h @@ -0,0 +1,25 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_GATHER_D_H_ +#define NNACL_KERNEL_GATHER_D_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +KernelBase *CreateGatherD(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_GATHER_D_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/gather_nd.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/gather_nd.c new file mode 100644 index 0000000000000000000000000000000000000000..0c78bb3fe3b5d7d6dadb0833733cda3404aab987 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/gather_nd.c @@ -0,0 +1,168 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/gather_nd.h" +#include "nnacl/fp32/gatherNd_fp32.h" +#include "nnacl/kernel/default_kernel_base.h" +#include "nnacl/nnacl_common.h" + +int GatherNdInitOffset(GatherNdStruct *gather_nd) { + TensorC *input_tensor = gather_nd->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + TensorC *indices_tensor = gather_nd->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(indices_tensor); + + if (indices_tensor->shape_size_ < 1) { + return NNACL_GATHER_ND_INDICES_RANK_INVALID; + } + + int in_rank = input_tensor->shape_size_; + int idx_lastshape = indices_tensor->shape_[indices_tensor->shape_size_ - 1]; + if (idx_lastshape > in_rank) { + return NNACL_GATHER_ND_INDICES_SHAPE_INVALID; + } + + gather_nd->area_ = 1; + for (int i = idx_lastshape; i < input_tensor->shape_size_; ++i) { + gather_nd->area_ *= input_tensor->shape_[i]; + } + + int in_stride[MAX_SHAPE_SIZE] = {0}; + in_stride[in_rank - 1] = 1; + for (int i = in_rank - 2; i >= 0; --i) { + in_stride[i] = input_tensor->shape_[i + 1] * in_stride[i + 1]; + } + + int idx_stride = idx_lastshape; + (void)memset(gather_nd->in_offset_, 0, gather_nd->count_ * sizeof(int)); + + if (indices_tensor->data_type_ == kNumberTypeInt || indices_tensor->data_type_ == kNumberTypeInt32) { + int32_t *indices_ptr = (int32_t *)indices_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(indices_ptr); + for (int j = 0; j < gather_nd->count_; ++j) { + for (int k = 0; k < idx_lastshape; ++k) { + gather_nd->in_offset_[j] += indices_ptr[j * idx_stride + k] * in_stride[k]; + } + } + } else if (indices_tensor->data_type_ == kNumberTypeInt64) { + int64_t *indices_ptr = (int64_t *)indices_tensor->data_; + for (int j = 0; j < gather_nd->count_; ++j) { + for (int k = 0; k < idx_lastshape; ++k) { + gather_nd->in_offset_[j] += indices_ptr[j * idx_stride + k] * in_stride[k]; + } + } + } else { + return NNACL_GATHER_ND_INDICES_DATA_TYPE_INVALID; + } + + return NNACL_OK; +} + +int GatherNdRun(void *cdata, int task_id, float l, float r) { + GatherNdStruct *gather_nd = (GatherNdStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(gather_nd); + TensorC *input = gather_nd->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(task_id, gather_nd->thread_stride_, NNACL_ERR); + int count = NNACL_MIN(gather_nd->thread_stride_, gather_nd->count_ - task_id * gather_nd->thread_stride_); + if (count <= 0) { + return NNACL_OK; + } + + int offset = task_id * gather_nd->thread_stride_; + int dtype_len = DataTypeCSize(input->data_type_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(offset, gather_nd->area_, NNACL_ERR); + int8_t *out_ptr = (int8_t *)gather_nd->out_ptr_ + offset * gather_nd->area_ * dtype_len; + return GatherNd(gather_nd->in_ptr_, out_ptr, gather_nd->in_offset_ + offset, gather_nd->area_, count, dtype_len); +} + +int GatherNdCompute(KernelBase *self) { + GatherNdStruct *gather_nd = (GatherNdStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(gather_nd); + + TensorC *input = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + gather_nd->in_ptr_ = input->data_; + NNACL_CHECK_NULL_RETURN_ERR(gather_nd->in_ptr_); + + TensorC *output = self->out_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(output); + gather_nd->out_ptr_ = output->data_; + NNACL_CHECK_NULL_RETURN_ERR(gather_nd->out_ptr_); + + int ret = GatherNdInitOffset(gather_nd); + if (ret != NNACL_OK) { + return ret; + } + + return self->env_->ParallelLaunch(self->env_->thread_pool_, GatherNdRun, self, self->thread_nr_); +} + +int GatherNdRelease(KernelBase *self) { + GatherNdStruct *gather_nd = (GatherNdStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(gather_nd); + if (gather_nd->in_offset_ != NULL) { + self->env_->Free(self->env_->allocator_, gather_nd->in_offset_); + gather_nd->in_offset_ = NULL; + } + return NNACL_OK; +} + +int GatherNdResize(KernelBase *self) { + (void)self->Release; + GatherNdStruct *gather_nd = (GatherNdStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(gather_nd); + TensorC *indices_tensor = self->in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(indices_tensor); + + gather_nd->count_ = 1; + for (int i = 0; i < indices_tensor->shape_size_ - 1; ++i) { + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(gather_nd->count_, indices_tensor->shape_[i], NNACL_ERR); + gather_nd->count_ *= indices_tensor->shape_[i]; + } + + int min_count = INT32_MAX / sizeof(int); + if (gather_nd->count_ >= min_count) { + return NNACL_GATHER_ND_COUNT_INVALID; + } + + gather_nd->in_offset_ = self->env_->Alloc(self->env_->allocator_, gather_nd->count_ * sizeof(int)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(gather_nd->in_offset_); + + gather_nd->base_.thread_nr_ = NNACL_MIN(gather_nd->base_.thread_nr_, gather_nd->count_); + if (gather_nd->base_.thread_nr_ != 0) { + gather_nd->thread_stride_ = UP_DIV(gather_nd->count_, gather_nd->base_.thread_nr_); + } + return NNACL_OK; +} + +KernelBase *CreateGatherNd(OpParameter *param, int data_type) { + GatherNdStruct *gather_nd = (GatherNdStruct *)malloc(sizeof(GatherNdStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(gather_nd); + memset(gather_nd, 0, sizeof(GatherNdStruct)); + + gather_nd->base_.Prepare = DefaultPrepare2In1Out; + gather_nd->base_.Resize = GatherNdResize; + gather_nd->base_.Compute = GatherNdCompute; + gather_nd->base_.Release = GatherNdRelease; + return (KernelBase *)gather_nd; +} + +REG_KERNEL_CREATOR(PrimType_GatherNd, kNumberTypeBool, CreateGatherNd); +REG_KERNEL_CREATOR(PrimType_GatherNd, kNumberTypeInt32, CreateGatherNd); +REG_KERNEL_CREATOR(PrimType_GatherNd, kNumberTypeFloat32, CreateGatherNd); +REG_KERNEL_CREATOR(PrimType_GatherNd, kNumberTypeFloat16, CreateGatherNd); diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/gather_nd.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/gather_nd.h new file mode 100644 index 0000000000000000000000000000000000000000..0398f9cae9a61723fac043c3930754033fc0d383 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/gather_nd.h @@ -0,0 +1,35 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_GATHER_ND_H_ +#define NNACL_KERNEL_GATHER_ND_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct { + KernelBase base_; + int *in_offset_; + int count_; + int area_; + int thread_stride_; + void *in_ptr_; + void *out_ptr_; +} GatherNdStruct; + +KernelBase *CreateGatherNd(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_GATHER_ND_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/group_convolution.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/group_convolution.c new file mode 100644 index 0000000000000000000000000000000000000000..b85cc66d6a09c52e57d7d805fcf506e1b9b00944 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/group_convolution.c @@ -0,0 +1,419 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/group_convolution.h" +#include "nnacl/kernel/convolution_delegate.h" +#include "nnacl/base/conv_common_base.h" +#include "nnacl/tensor_c_utils.h" + +int GroupConvBasePrepare(GroupConvolutionStruct *group_conv) { + for (int i = 0; i < group_conv->group_; ++i) { + KernelBase *sub_conv = group_conv->group_convs_[i]; + NNACL_CHECK_NULL_RETURN_ERR(sub_conv); + int ret = sub_conv->Prepare(sub_conv); + if (ret != NNACL_OK) { + return ret; + } + } + return NNACL_OK; +} + +int GroupConvCreatorNewInputTensor(GroupConvolutionStruct *group_conv, KernelBase *new_conv) { + TensorC *in_tensor = (TensorC *)malloc(sizeof(TensorC)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(in_tensor); + in_tensor->format_ = Format_NHWC; + in_tensor->category_ = VarTensor; + in_tensor->data_type_ = group_conv->data_type_; + in_tensor->shape_size_ = DIMENSION_4D; + in_tensor->shape_[Index0] = INVALID_SHAPE; + new_conv->in_[FIRST_INPUT] = in_tensor; + return NNACL_OK; +} + +int GroupConvCreatorNewOutputTensor(GroupConvolutionStruct *group_conv, KernelBase *new_conv) { + TensorC *out_tensor = (TensorC *)malloc(sizeof(TensorC)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(out_tensor); + out_tensor->format_ = Format_NHWC; + out_tensor->category_ = VarTensor; + out_tensor->data_type_ = group_conv->data_type_; + out_tensor->shape_size_ = DIMENSION_4D; + out_tensor->shape_[Index0] = INVALID_SHAPE; + new_conv->out_[OUTPUT_INDEX] = out_tensor; + return NNACL_OK; +} + +TensorC *CreateConstTensor(const TensorC *tensor, const int *shape, const int shape_size, const int index) { + NNACL_CHECK_NULL_RETURN_NULL(tensor->data_); + + TensorC *new_tensor = (TensorC *)malloc(sizeof(TensorC)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(new_tensor); + new_tensor->data_type_ = tensor->data_type_; + new_tensor->format_ = Format_NHWC; + new_tensor->category_ = ConstTensor; + new_tensor->shape_size_ = shape_size; + memcpy(new_tensor->shape_, shape, shape_size * sizeof(int)); + + int size = NNACLGetSize(new_tensor); + if (size <= 0) { + free(new_tensor); + return NULL; + } + + void *data = malloc(size); + if (data == NULL) { + free(new_tensor); + return NULL; + } + new_tensor->data_ = data; + + uint8_t *new_tensor_data = (uint8_t *)tensor->data_ + index * size; + memcpy(new_tensor->data_, new_tensor_data, size); + return new_tensor; +} + +int GroupConvCreatorNewConstTensor(GroupConvolutionStruct *group_conv, KernelBase *new_conv, int group_id) { + TensorC *origin_weight = group_conv->conv_base_.base_.in_[SECOND_INPUT]; + int shape[] = {group_conv->sub_out_c_, NNACLGetHeight(origin_weight), NNACLGetWidth(origin_weight), + group_conv->sub_in_c_}; + TensorC *weight_tensor = CreateConstTensor(origin_weight, shape, DIMENSION_4D, group_id); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(weight_tensor); + new_conv->in_[SECOND_INPUT] = weight_tensor; + + if (group_conv->conv_base_.base_.in_size_ == THREE_TENSOR) { + TensorC *bias_weight = group_conv->conv_base_.base_.in_[THIRD_INPUT]; + TensorC *bias_tensor = CreateConstTensor(bias_weight, &group_conv->sub_out_c_, DIMENSION_1D, group_id); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(bias_tensor); + new_conv->in_[THIRD_INPUT] = bias_tensor; + } + return NNACL_OK; +} + +int GroupConvCreatorSetShapeOfTensors(GroupConvolutionStruct *group_conv) { + ConvParameter *origin_conv_param = (ConvParameter *)group_conv->conv_base_.base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(origin_conv_param); + ConvParameter *new_conv_param = &group_conv->new_conv_param_; + NNACL_CHECK_NULL_RETURN_ERR(new_conv_param); + memcpy(new_conv_param, origin_conv_param, sizeof(ConvParameter)); + + TensorC *weight_tensor = group_conv->conv_base_.base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(weight_tensor); + NNACL_CHECK_FALSE(origin_conv_param->group_ == 0, NNACL_GROUP_CONVOLUTION_GROUP_INVALID); + NNACL_CHECK_FALSE(weight_tensor->shape_size_ != DIMENSION_4D, NNACL_CONVOLUTION_WEIGHT_SHAPE_INVALID); + NNACL_CHECK_FALSE(origin_conv_param->kernel_h_ != NNACLGetHeight(weight_tensor), + NNACL_CONVOLUTION_WEIGHT_SHAPE_INVALID); + NNACL_CHECK_FALSE(origin_conv_param->kernel_w_ != NNACLGetWidth(weight_tensor), + NNACL_CONVOLUTION_WEIGHT_SHAPE_INVALID); + + ConvComputeParam *compute = &group_conv->conv_base_.compute_; + group_conv->ori_in_c_ = compute->in_c_; + group_conv->ori_out_c_ = compute->out_c_; + group_conv->sub_in_c_ = compute->in_c_ / group_conv->group_; + group_conv->sub_out_c_ = compute->out_c_ / group_conv->group_; + + new_conv_param->input_channel_ = group_conv->sub_in_c_; + new_conv_param->output_channel_ = group_conv->sub_out_c_; + new_conv_param->group_ = origin_conv_param->group_; + + return NNACL_OK; +} + +int GroupConvSetSubConvInfo(GroupConvolutionStruct *group_conv, KernelBase *new_conv, int group_id) { + NNACL_CHECK_NULL_RETURN_ERR(group_conv); + NNACL_CHECK_NULL_RETURN_ERR(new_conv); + + ConvolutionBaseStruct *sub_conv = (ConvolutionBaseStruct *)new_conv; + (void)ConvBaseUpdateParamInfo(&sub_conv->compute_, &group_conv->new_conv_param_); + + sub_conv->infershape_done_ = group_conv->conv_base_.infershape_done_; + sub_conv->shaing_manager_ = group_conv->conv_base_.shaing_manager_; + sub_conv->get_sharing_weight_ = group_conv->conv_base_.get_sharing_weight_; + sub_conv->free_sharing_weight_ = group_conv->conv_base_.free_sharing_weight_; + sub_conv->is_sharing_pack_ = group_conv->conv_base_.is_sharing_pack_; + + new_conv->env_ = group_conv->conv_base_.base_.env_; + new_conv->param_ = &group_conv->new_conv_param_.op_parameter_; + new_conv->thread_nr_ = group_conv->conv_base_.base_.thread_nr_; + new_conv->train_session_ = group_conv->conv_base_.base_.train_session_; + new_conv->UpdateThread = group_conv->conv_base_.base_.UpdateThread; + new_conv->in_size_ = group_conv->conv_base_.base_.in_size_; + new_conv->out_size_ = group_conv->conv_base_.base_.out_size_; + + new_conv->in_ = (TensorC **)malloc(new_conv->in_size_ * sizeof(TensorC *)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(new_conv->in_); + memset(new_conv->in_, 0, new_conv->in_size_ * sizeof(TensorC *)); + new_conv->out_ = (TensorC **)malloc(new_conv->out_size_ * sizeof(TensorC *)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(new_conv->out_); + memset(new_conv->out_, 0, new_conv->out_size_ * sizeof(TensorC *)); + + // create new input for each group + int ret = GroupConvCreatorNewInputTensor(group_conv, new_conv); + if (ret != NNACL_OK) { + group_conv->conv_base_.base_.Release((KernelBase *)group_conv); + return ret; + } + + // const tensor + ret = GroupConvCreatorNewConstTensor(group_conv, new_conv, group_id); + if (ret != NNACL_OK) { + group_conv->conv_base_.base_.Release((KernelBase *)group_conv); + return ret; + } + + // create new output tensor + ret = GroupConvCreatorNewOutputTensor(group_conv, new_conv); + if (ret != NNACL_OK) { + group_conv->conv_base_.base_.Release((KernelBase *)group_conv); + return ret; + } + return NNACL_OK; +} + +int GroupConvConcatOutputRun(void *cdata, int task_id, float l, float r) { + NNACL_CHECK_NULL_RETURN_ERR(cdata); + GroupConvolutionStruct *group_conv = (GroupConvolutionStruct *)cdata; + + int plane_step = UP_DIV(group_conv->conv_base_.compute_.out_hw_, group_conv->conv_base_.base_.thread_nr_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(plane_step, task_id, NNACL_ERR); + int begin_plane = plane_step * task_id; + int end_plane = NNACL_MIN(group_conv->conv_base_.compute_.out_hw_, plane_step * (task_id + 1)); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(begin_plane, group_conv->sub_out_c_, NNACL_ERR); + float *src_ptr = group_conv->sub_out_src_ + begin_plane * group_conv->sub_out_c_; + float *dst_ptr = group_conv->sub_out_dst_ + begin_plane * group_conv->ori_out_c_; + for (int i = begin_plane; i < end_plane; ++i) { + (void)memcpy(dst_ptr, src_ptr, group_conv->sub_out_c_ * sizeof(float)); + src_ptr += group_conv->sub_out_c_; + dst_ptr += group_conv->ori_out_c_; + } + return NNACL_OK; +} + +int GroupConvPostConcat(GroupConvolutionStruct *group_conv, int group_id) { + group_conv->sub_out_src_ = (float *)group_conv->group_convs_[group_id]->out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(group_conv->sub_out_src_); + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(group_id, group_conv->sub_out_c_, NNACL_ERR); + group_conv->sub_out_dst_ = (float *)(group_conv->origin_output_data_) + group_id * group_conv->sub_out_c_; + NNACL_CHECK_NULL_RETURN_ERR(group_conv->sub_out_dst_); + + return group_conv->conv_base_.base_.env_->ParallelLaunch(group_conv->conv_base_.base_.env_->thread_pool_, + GroupConvConcatOutputRun, group_conv, + group_conv->conv_base_.base_.thread_nr_); +} + +int GroupConvSeparateInputRun(void *cdata, int task_id, float l, float r) { + NNACL_CHECK_NULL_RETURN_ERR(cdata); + GroupConvolutionStruct *group_conv = (GroupConvolutionStruct *)cdata; + + int plane_step = UP_DIV(group_conv->conv_base_.compute_.in_hw_, group_conv->conv_base_.base_.thread_nr_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(plane_step, task_id, NNACL_ERR); + int begin_plane = plane_step * task_id; + int end_plane = NNACL_MIN(group_conv->conv_base_.compute_.in_hw_, plane_step * (task_id + 1)); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(begin_plane, group_conv->ori_in_c_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(begin_plane, group_conv->sub_in_c_, NNACL_ERR); + float *src_ptr = group_conv->sub_in_src_ + begin_plane * group_conv->ori_in_c_; + float *dst_ptr = group_conv->sub_in_dst_ + begin_plane * group_conv->sub_in_c_; + for (int i = begin_plane; i < end_plane; ++i) { + (void)memcpy(dst_ptr, src_ptr, group_conv->sub_in_c_ * sizeof(float)); + src_ptr += group_conv->ori_in_c_; + dst_ptr += group_conv->sub_in_c_; + } + + return NNACL_OK; +} + +int GroupConvSeparateInput(GroupConvolutionStruct *group_conv, int group_id) { + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(group_id, group_conv->sub_in_c_, NNACL_ERR); + + group_conv->sub_in_src_ = (float *)(group_conv->origin_input_data_) + group_id * group_conv->sub_in_c_; + NNACL_CHECK_NULL_RETURN_ERR(group_conv->sub_in_src_); + group_conv->sub_in_dst_ = (float *)(group_conv->group_convs_[group_id]->in_[FIRST_INPUT]->data_); + NNACL_CHECK_NULL_RETURN_ERR(group_conv->sub_in_dst_); + + return group_conv->conv_base_.base_.env_->ParallelLaunch(group_conv->conv_base_.base_.env_->thread_pool_, + GroupConvSeparateInputRun, group_conv, + group_conv->conv_base_.base_.thread_nr_); +} + +void GroupConvUpdateShape(GroupConvolutionStruct *group_conv) { + for (int i = 0; i < group_conv->group_; i++) { + TensorC *in_tensor = group_conv->conv_base_.base_.in_[FIRST_INPUT]; + int in_shape[] = {NNACLGetBatch(in_tensor), NNACLGetHeight(in_tensor), NNACLGetWidth(in_tensor), + group_conv->sub_in_c_}; + memcpy(group_conv->group_convs_[i]->in_[FIRST_INPUT]->shape_, in_shape, DIMENSION_4D * sizeof(float)); + + TensorC *out_tensor = group_conv->conv_base_.base_.out_[OUTPUT_INDEX]; + int out_shape[] = {NNACLGetBatch(out_tensor), NNACLGetHeight(out_tensor), NNACLGetWidth(out_tensor), + group_conv->sub_out_c_}; + memcpy(group_conv->group_convs_[i]->out_[OUTPUT_INDEX]->shape_, out_shape, DIMENSION_4D * sizeof(float)); + } + return; +} + +int GroupConvolutionResize(KernelBase *self) { + GroupConvolutionStruct *group_conv = (GroupConvolutionStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(group_conv); + + (void)ConvBaseUpdateComputeInfo(&group_conv->conv_base_); + self->thread_nr_ = NNACL_MIN(NNACL_MAX(1, self->thread_nr_), group_conv->conv_base_.compute_.in_hw_); + self->thread_nr_ = NNACL_MIN(NNACL_MAX(1, self->thread_nr_), group_conv->conv_base_.compute_.in_hw_); + + GroupConvUpdateShape(group_conv); + + for (int i = 0; i < group_conv->group_; ++i) { + group_conv->group_convs_[i]->thread_nr_ = self->thread_nr_; + int ret = group_conv->group_convs_[i]->Resize(group_conv->group_convs_[i]); + if (ret != NNACL_OK) { + return ret; + } + } + return NNACL_OK; +} + +int GroupConvolutionCompute(KernelBase *self) { + GroupConvolutionStruct *group_conv = (GroupConvolutionStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(group_conv); + + group_conv->origin_input_data_ = self->in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(group_conv->origin_input_data_); + group_conv->origin_output_data_ = self->out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(group_conv->origin_output_data_); + + for (int i = 0; i < group_conv->group_; ++i) { + // first, malloc data for sub_kernel's tensors. + TensorC *sub_kernel_in_tensor = group_conv->group_convs_[i]->in_[FIRST_INPUT]; + sub_kernel_in_tensor->data_ = self->env_->Alloc(self->env_->allocator_, NNACLGetSize(sub_kernel_in_tensor)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(sub_kernel_in_tensor->data_); + + TensorC *sub_kernel_out_tensor = group_conv->group_convs_[i]->out_[OUTPUT_INDEX]; + sub_kernel_out_tensor->data_ = self->env_->Alloc(self->env_->allocator_, NNACLGetSize(sub_kernel_out_tensor)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(sub_kernel_out_tensor->data_); + + // second, separate group conv input into several parts. This step must be in runtime stage. + int ret = GroupConvSeparateInput(group_conv, i); + if (ret != NNACL_OK) { + return ret; + } + + // sun kernels run + ret = group_conv->group_convs_[i]->Compute(group_conv->group_convs_[i]); + if (ret != NNACL_OK) { + return ret; + } + + // post process, concat all outputs of sub-kernels into one output + ret = GroupConvPostConcat(group_conv, i); + if (ret != NNACL_OK) { + return ret; + } + + // Free data + self->env_->Free(self->env_->allocator_, sub_kernel_in_tensor->data_); + sub_kernel_in_tensor->data_ = NULL; + self->env_->Free(self->env_->allocator_, sub_kernel_out_tensor->data_); + sub_kernel_out_tensor->data_ = NULL; + } + return NNACL_OK; +} + +int GroupConvolutionPrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_size_ != ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + GroupConvolutionStruct *group_conv = (GroupConvolutionStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(group_conv); + NNACL_CHECK_FALSE(group_conv->group_ == 0, NNACL_GROUP_CONVOLUTION_GROUP_INVALID); + + GroupConvCreatorSetShapeOfTensors(group_conv); + + group_conv->group_convs_ = (KernelBase **)malloc(group_conv->group_ * sizeof(KernelBase *)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(group_conv->group_convs_); + memset(group_conv->group_convs_, 0, group_conv->group_ * sizeof(KernelBase *)); + + for (int i = 0; i < group_conv->group_; ++i) { + KernelBase *new_conv = CreateConvlutionDelegate(&group_conv->new_conv_param_); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(new_conv); + group_conv->group_convs_[i] = new_conv; + + int ret = GroupConvSetSubConvInfo(group_conv, new_conv, i); + if (ret != NNACL_OK) { + return ret; + } + } + return GroupConvBasePrepare(group_conv); +} + +void GroupConvReleaseSubConv(KernelBase *current_conv) { + (void)current_conv->Release(current_conv); + + if (current_conv->in_ != NULL) { + for (int j = 0; j < current_conv->in_size_; j++) { + if (NNACLIsConst(current_conv->in_[j])) { + free(current_conv->in_[j]->data_); + current_conv->in_[j]->data_ = NULL; + } + if (current_conv->in_[j] != NULL) { + free(current_conv->in_[j]); + current_conv->in_[j] = NULL; + } + } + free(current_conv->in_); + current_conv->in_ = NULL; + } + + if (current_conv->out_ != NULL) { + for (int j = 0; j < current_conv->out_size_; j++) { + if (current_conv->out_[j] != NULL) { + free(current_conv->out_[j]); + current_conv->out_[j] = NULL; + } + } + free(current_conv->out_); + current_conv->out_ = NULL; + } +} + +int GroupConvolutionRelease(KernelBase *self) { + GroupConvolutionStruct *group_conv = (GroupConvolutionStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(group_conv); + ConvParameter *conv_param = (ConvParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(conv_param); + + if (group_conv->group_convs_ != NULL) { + for (int i = 0; i < conv_param->group_; i++) { + if (group_conv->group_convs_[i] != NULL) { + GroupConvReleaseSubConv(group_conv->group_convs_[i]); + free(group_conv->group_convs_[i]); + group_conv->group_convs_[i] = NULL; + } + } + free(group_conv->group_convs_); + group_conv->group_convs_ = NULL; + } + return NNACL_OK; +} + +KernelBase *CreateGroupConvolution(ConvParameter *conv_param, TypeIdC data_type) { + GroupConvolutionStruct *group_conv = (GroupConvolutionStruct *)malloc(sizeof(GroupConvolutionStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(group_conv); + memset(group_conv, 0, sizeof(GroupConvolutionStruct)); + + group_conv->data_type_ = data_type; + group_conv->group_ = conv_param->group_; + group_conv->conv_base_.base_.Compute = GroupConvolutionCompute; + group_conv->conv_base_.base_.Resize = GroupConvolutionResize; + group_conv->conv_base_.base_.Prepare = GroupConvolutionPrepare; + group_conv->conv_base_.base_.Release = GroupConvolutionRelease; + return (KernelBase *)group_conv; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/group_convolution.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/group_convolution.h new file mode 100644 index 0000000000000000000000000000000000000000..de4789dd1d1e6314dc877dc10c8bc662f85d542a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/group_convolution.h @@ -0,0 +1,49 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_GROUP_CONVOLUTION_H_ +#define NNACL_KERNEL_GROUP_CONVOLUTION_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" +#include "nnacl/conv_parameter.h" +#include "nnacl/kernel/convolution_base.h" + +typedef struct GroupConvolutionStruct { + ConvolutionBaseStruct conv_base_; + KernelBase **group_convs_; + ConvParameter new_conv_param_; + TypeIdC data_type_; + int group_; + + void *origin_input_data_; + void *origin_output_data_; + + float *sub_in_src_; + float *sub_in_dst_; + float *sub_out_src_; + float *sub_out_dst_; + + int sub_in_c_; + int ori_in_c_; + int sub_out_c_; + int ori_out_c_; +} GroupConvolutionStruct; + +KernelBase *CreateGroupConvolution(ConvParameter *conv_param, TypeIdC data_type); + +#endif // NNACL_KERNEL_GROUP_CONVOLUTION_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/group_norm.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/group_norm.c new file mode 100644 index 0000000000000000000000000000000000000000..2f6f148caca4ac1a2cef586789cc0387099daf36 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/group_norm.c @@ -0,0 +1,122 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/group_norm.h" +#include "nnacl/fp32/group_norm_fp32.h" +#include "nnacl/group_norm_parameter.h" +#include "nnacl/op_base.h" +#include "nnacl/errorcode.h" +#include "nnacl/tensor_c.h" +#include "nnacl/tensor_c_utils.h" + +int GroupNormResize(struct KernelBase *self) { + GroupNormStru *groupnorm = (GroupNormStru *)self; + NNACL_CHECK_NULL_RETURN_ERR(groupnorm); + GroupNormParameter *param = (GroupNormParameter *)groupnorm->base.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + NNACL_CHECK_FALSE(self->in_size_ < kInputSize2, NNACL_TENSOR_SIZE_INVALID); + NNACL_CHECK_FALSE(self->out_size_ < 1, NNACL_TENSOR_SIZE_INVALID); + + self->Release(self); + + TensorC *in0 = self->in_[0]; + NNACL_CHECK_FALSE(in0->shape_size_ < C1NUM, NNACL_GROUP_NORM_SHAPE_SIZE_INVALID); + NNACL_CHECK_FALSE(in0->format_ != Format_NCHW, NNACL_GROUP_NORM_FORMAT_INVALID); + + param->unit_ = NNACLGetHeight(in0) * NNACLGetWidth(in0); + param->batch_ = NNACLGetBatch(in0); + param->channel_ = NNACLGetChannel(in0); + return self->Prepare(self); +} + +int GroupNormPrepare(struct KernelBase *self) { + GroupNormStru *groupnorm = (GroupNormStru *)self; + NNACL_CHECK_NULL_RETURN_ERR(groupnorm); + GroupNormParameter *param = (GroupNormParameter *)groupnorm->base.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + NNACL_CHECK_FALSE(param->num_groups_ < 0, NNACL_GROUP_NORM_NUM_GROUPS_INVALID); + NNACL_CHECK_FALSE(param->channel_ % param->num_groups_, NNACL_GROUP_NORM_NUM_GROUPS_INVALID); + NNACL_CHECK_FALSE(param->num_groups_ == 0, NNACL_GROUP_NORM_NUM_GROUPS_INVALID); + + size_t mean_var_elem_num = param->num_groups_; + param->mean_ = malloc(mean_var_elem_num * sizeof(float)); + param->variance_ = malloc(mean_var_elem_num * sizeof(float)); + if (param->mean_ == NULL || param->variance_ == NULL) { + self->Release(self); + return NNACL_MALLOC_BUFFER_FAILED; + } + return NNACL_OK; +} + +int GroupNormRelease(struct KernelBase *self) { + GroupNormStru *groupnorm = (GroupNormStru *)self; + NNACL_CHECK_NULL_RETURN_ERR(groupnorm); + GroupNormParameter *param = (GroupNormParameter *)groupnorm->base.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + + if (param->mean_ != NULL) { + free(param->mean_); + param->mean_ = NULL; + } + if (param->variance_ != NULL) { + free(param->variance_); + param->variance_ = NULL; + } + + return NNACL_OK; +} + +int GroupNormImpl(void *param, int task_id, float l, float r) { + NNACL_CHECK_NULL_RETURN_ERR(param); + GroupNormStru *groupnorm_stru = (GroupNormStru *)param; + GroupNormParameter *groupnorm_param = (GroupNormParameter *)groupnorm_stru->base.param_; + NNACL_CHECK_NULL_RETURN_ERR(groupnorm_param); + + const void *input_data = groupnorm_stru->base.in_[0]->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_data); + const void *scale_data = groupnorm_stru->base.in_[C1NUM]->data_; + NNACL_CHECK_NULL_RETURN_ERR(scale_data); + const void *offset_data = groupnorm_stru->base.in_[C2NUM]->data_; + NNACL_CHECK_NULL_RETURN_ERR(offset_data); + void *output_data = groupnorm_stru->base.out_[0]->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_data); + + NNACL_CHECK_NULL_RETURN_ERR(groupnorm_param->mean_); + NNACL_CHECK_NULL_RETURN_ERR(groupnorm_param->variance_); + + int ret = GroupNormFp32(input_data, scale_data, offset_data, groupnorm_param->mean_, groupnorm_param->variance_, + groupnorm_param, task_id, output_data); + + return ret; +} + +int GroupNormCompute(struct KernelBase *self) { + return self->env_->ParallelLaunch(self->env_->thread_pool_, GroupNormImpl, self, self->param_->thread_num_); +} + +KernelBase *CreateGroupNorm(OpParameter *param, int data_type) { + GroupNormStru *groupnorm = (GroupNormStru *)malloc(sizeof(GroupNormStru)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(groupnorm); + + groupnorm->base.Prepare = GroupNormPrepare; + groupnorm->base.Resize = GroupNormResize; + groupnorm->base.Release = GroupNormRelease; + groupnorm->base.Compute = GroupNormCompute; + + return (void *)groupnorm; +} + +REG_KERNEL_CREATOR(PrimType_GroupNormFusion, kNumberTypeFloat32, CreateGroupNorm); diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/group_norm.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/group_norm.h new file mode 100644 index 0000000000000000000000000000000000000000..37b861f9c2421d03027a8f35307031ebbc1d3a90 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/group_norm.h @@ -0,0 +1,31 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_GROUP_NORM_H_ +#define NNACL_KERNEL_GROUP_NORM_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/group_norm_parameter.h" +#include "nnacl/kernel.h" + +typedef struct GroupNormStru { + KernelBase base; +} GroupNormStru; + +KernelBase *CreateGroupNorm(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_GROUP_NORM_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/init_exec_env.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/init_exec_env.c new file mode 100644 index 0000000000000000000000000000000000000000..053992a2fb3fc340393579d2bafbe3b1b746a265 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/init_exec_env.c @@ -0,0 +1,51 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/init_exec_env.h" + +#define NNACLMaxAllocSize (2000 * 1024 * 1024) +ExecEnv nnacl_default_env; + +void *NNACLDefaultAlloc(void *allocator, size_t sz) { + if (sz == 0 || sz > NNACLMaxAllocSize) { + return NULL; + } + return malloc(sz); +} + +void NNACLDefaultFree(void *allocator, void *ptr) { return free(ptr); } + +int NNACLDefaultParallelLunch(void *threadPool, void *task, void *param, int taskNr) { + int (*function)(void *cdata, int task_id, float l, float r) = task; + int ret = 0; + for (int i = 0; i < taskNr; i++) { + ret += function(param, i, 0, 1); + } + return ret == NNACL_OK ? NNACL_OK : NNACL_ERR; +} + +void InitDefaultExecEnv(void) { + nnacl_default_env.Free = NNACLDefaultFree; + nnacl_default_env.Alloc = NNACLDefaultAlloc; + nnacl_default_env.ParallelLaunch = NNACLDefaultParallelLunch; +} + +void CheckExecEnv(KernelBase *base) { + if (base->env_ == NULL) { + base->env_ = &nnacl_default_env; + } + return; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/init_exec_env.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/init_exec_env.h new file mode 100644 index 0000000000000000000000000000000000000000..1f33d3f48a7ac9d139c8a2c12eb901697f61c9ac --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/init_exec_env.h @@ -0,0 +1,27 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_INIT_EXEC_ENV_H_ +#define NNACL_KERNEL_INIT_EXEC_ENV_H_ + +#include "nnacl/kernel.h" + +#ifndef _MSC_VER +__attribute__((constructor(103))) void InitDefaultExecEnv(void); +#endif + +void CheckExecEnv(KernelBase *base); + +#endif // NNACL_KERNEL_INIT_EXEC_ENV_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/init_vs_kernels.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/init_vs_kernels.c new file mode 100644 index 0000000000000000000000000000000000000000..e400154567b560a74551d203a8ef1831d5e51837 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/init_vs_kernels.c @@ -0,0 +1,357 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/init_vs_kernels.h" +#include "nnacl/kernel/activation.h" +#include "nnacl/kernel/arithmetic.h" +#include "nnacl/kernel/arithmetic_compare.h" +#include "nnacl/kernel/arithmetic_self.h" +#include "nnacl/kernel/arg_min_max.h" +#include "nnacl/kernel/addn.h" +#include "nnacl/kernel/biasadd.h" +#include "nnacl/kernel/batch_norm.h" +#include "nnacl/kernel/clip.h" +#include "nnacl/kernel/concat.h" +#include "nnacl/kernel/crop.h" +#include "nnacl/kernel/crop_and_resize.h" +#include "nnacl/kernel/exp.h" +#include "nnacl/kernel/depth_to_space.h" +#include "nnacl/kernel/fill.h" +#include "nnacl/kernel/fused_batch_norm.h" +#include "nnacl/kernel/fullconnection.h" +#include "nnacl/kernel/gather.h" +#include "nnacl/kernel/gather_d.h" +#include "nnacl/kernel/gather_nd.h" +#include "nnacl/kernel/group_norm.h" +#include "nnacl/kernel/log_softmax.h" +#include "nnacl/kernel/local_response_norm.h" +#include "nnacl/kernel/layer_norm.h" +#include "nnacl/kernel/matmul.h" +#include "nnacl/kernel/non_max_suppression.h" +#include "nnacl/kernel/non_zero.h" +#include "nnacl/kernel/nllloss.h" +#include "nnacl/kernel/prior_box.h" +#include "nnacl/kernel/prelu.h" +#include "nnacl/kernel/pad.h" +#include "nnacl/kernel/pow.h" +#include "nnacl/kernel/reshape.h" +#include "nnacl/kernel/reverse.h" +#include "nnacl/kernel/range.h" +#include "nnacl/kernel/rank.h" +#include "nnacl/kernel/scale.h" +#include "nnacl/kernel/shape.h" +#include "nnacl/kernel/reduce.h" +#include "nnacl/kernel/ragged_range.h" +#include "nnacl/kernel/stack.h" +#include "nnacl/kernel/strided_slice.h" +#include "nnacl/kernel/softmax.h" +#include "nnacl/kernel/size.h" +#include "nnacl/kernel/splice.h" +#include "nnacl/kernel/tile.h" +#include "nnacl/kernel/tril.h" +#include "nnacl/kernel/triu.h" +#include "nnacl/kernel/transpose.h" +#include "nnacl/kernel/slice.h" +#include "nnacl/kernel/unique.h" +#ifdef ENABLE_FP16 +#include "nnacl/kernel/f16/arithmetic_f16.h" +#include "nnacl/kernel/f16/arithmetic_compare_f16.h" +#include "nnacl/kernel/f16/concat_f16.h" +#include "nnacl/kernel/f16/reduce_f16.h" +#include "nnacl/kernel/f16/stack_f16.h" +#endif + +void InitVSKernelF16(KernelCreator **creators) { +#ifdef ENABLE_FP16 + creators[PrimType_Abs][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_Activation][REGIST_DT(kNumberTypeFloat16)] = CreateActivation; + creators[PrimType_AddFusion][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticF16; + creators[PrimType_AddN][REGIST_DT(kNumberTypeFloat16)] = CreateAddN; + creators[PrimType_ArgMinFusion][REGIST_DT(kNumberTypeFloat16)] = CreateArgMinMax; + creators[PrimType_ArgMaxFusion][REGIST_DT(kNumberTypeFloat16)] = CreateArgMinMax; + creators[PrimType_BatchNorm][REGIST_DT(kNumberTypeFloat16)] = CreateBatchNorm; + creators[PrimType_Ceil][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_Concat][REGIST_DT(kNumberTypeFloat16)] = CreateConcatF16; + creators[PrimType_Cos][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_Crop][REGIST_DT(kNumberTypeFloat16)] = CreateCrop; + creators[PrimType_DivFusion][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticF16; + creators[PrimType_DepthToSpace][REGIST_DT(kNumberTypeFloat16)] = CreateDepthToSpace; + creators[PrimType_Eltwise][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticF16; + creators[PrimType_Erf][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_Equal][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticCompareF16; + creators[PrimType_ExpandDims][REGIST_DT(kNumberTypeFloat16)] = CreateReshape; + creators[PrimType_Fill][REGIST_DT(kNumberTypeFloat16)] = CreateFill; + creators[PrimType_Flatten][REGIST_DT(kNumberTypeFloat16)] = CreateReshape; + creators[PrimType_FlattenGrad][REGIST_DT(kNumberTypeFloat16)] = CreateReshape; + creators[PrimType_Floor][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_FloorMod][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticF16; + creators[PrimType_FloorDiv][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticF16; + creators[PrimType_FusedBatchNorm][REGIST_DT(kNumberTypeFloat16)] = CreateFusedBatchNorm; + creators[PrimType_Gather][REGIST_DT(kNumberTypeFloat16)] = CreateGather; + creators[PrimType_GatherD][REGIST_DT(kNumberTypeFloat16)] = CreateGatherD; + creators[PrimType_GatherNd][REGIST_DT(kNumberTypeFloat16)] = CreateGatherNd; + creators[PrimType_Greater][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticCompareF16; + creators[PrimType_GreaterEqual][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticCompareF16; + creators[PrimType_Less][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticCompareF16; + creators[PrimType_LessEqual][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticCompareF16; + creators[PrimType_LayerNormFusion][REGIST_DT(kNumberTypeFloat16)] = CreateLayerNorm; + creators[PrimType_LogicalAnd][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticF16; + creators[PrimType_LogicalOr][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticF16; + creators[PrimType_Log][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_LogSoftmax][REGIST_DT(kNumberTypeFloat16)] = CreateLogSoftmax; + creators[PrimType_LogicalNot][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_Maximum][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticF16; + creators[PrimType_Minimum][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticF16; + creators[PrimType_MulFusion][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticF16; + creators[PrimType_Neg][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_NotEqual][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticCompareF16; + creators[PrimType_PadFusion][REGIST_DT(kNumberTypeFloat16)] = CreatePad; + creators[PrimType_PReLUFusion][REGIST_DT(kNumberTypeFloat16)] = CreatePRelu; + creators[PrimType_PowFusion][REGIST_DT(kNumberTypeFloat16)] = CreatePow; + creators[PrimType_Reshape][REGIST_DT(kNumberTypeFloat16)] = CreateReshape; + creators[PrimType_RealDiv][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticF16; + creators[PrimType_ReduceFusion][REGIST_DT(kNumberTypeFloat16)] = CreateReduceF16; + creators[PrimType_Rsqrt][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_Round][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_Reciprocal][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_ScaleFusion][REGIST_DT(kNumberTypeFloat16)] = CreateScale; + creators[PrimType_Shape][REGIST_DT(kNumberTypeFloat16)] = CreateShape; + creators[PrimType_Softmax][REGIST_DT(kNumberTypeFloat16)] = CreateSoftmax; + creators[PrimType_Stack][REGIST_DT(kNumberTypeFloat16)] = CreateStackF16; + creators[PrimType_StridedSlice][REGIST_DT(kNumberTypeFloat16)] = CreateStridedSlice; + creators[PrimType_Squeeze][REGIST_DT(kNumberTypeFloat16)] = CreateReshape; + creators[PrimType_SubFusion][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticF16; + creators[PrimType_SquaredDifference][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticF16; + creators[PrimType_Splice][REGIST_DT(kNumberTypeFloat16)] = CreateSplice; + creators[PrimType_Sin][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_Size][REGIST_DT(kNumberTypeFloat16)] = CreateSize; + creators[PrimType_SliceFusion][REGIST_DT(kNumberTypeFloat16)] = CreateSlice; + creators[PrimType_Square][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_Sqrt][REGIST_DT(kNumberTypeFloat16)] = CreateArithmeticSelf; + creators[PrimType_TileFusion][REGIST_DT(kNumberTypeFloat16)] = CreateTile; + creators[PrimType_Triu][REGIST_DT(kNumberTypeFloat16)] = CreateTriu; + creators[PrimType_Tril][REGIST_DT(kNumberTypeFloat16)] = CreateTril; + creators[PrimType_Transpose][REGIST_DT(kNumberTypeFloat16)] = CreateTranspose; + creators[PrimType_Unsqueeze][REGIST_DT(kNumberTypeFloat16)] = CreateReshape; + creators[PrimType_Unique][REGIST_DT(kNumberTypeFloat16)] = CreateUnique; +#endif +} + +void InitVSKernelA(KernelCreator **creators) { + creators[PrimType_Abs][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_Abs][REGIST_DT(kNumberTypeInt32)] = CreateArithmeticSelf; + creators[PrimType_Activation][REGIST_DT(kNumberTypeFloat32)] = CreateActivation; + creators[PrimType_Activation][REGIST_DT(kNumberTypeUInt32)] = CreateActivation; + creators[PrimType_AddFusion][REGIST_DT(kNumberTypeBool)] = CreateArithmetic; + creators[PrimType_AddFusion][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic; + creators[PrimType_AddFusion][REGIST_DT(kNumberTypeInt32)] = CreateArithmetic; + creators[PrimType_AddN][REGIST_DT(kNumberTypeFloat32)] = CreateAddN; + creators[PrimType_ArgMinFusion][REGIST_DT(kNumberTypeInt32)] = CreateArgMinMax; + creators[PrimType_ArgMinFusion][REGIST_DT(kNumberTypeFloat32)] = CreateArgMinMax; + creators[PrimType_ArgMaxFusion][REGIST_DT(kNumberTypeInt32)] = CreateArgMinMax; + creators[PrimType_ArgMaxFusion][REGIST_DT(kNumberTypeFloat32)] = CreateArgMinMax; + creators[PrimType_BiasAdd][REGIST_DT(kNumberTypeFloat32)] = CreateBiasAdd; + creators[PrimType_BatchNorm][REGIST_DT(kNumberTypeFloat32)] = CreateBatchNorm; + creators[PrimType_Ceil][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_Cos][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_Clip][REGIST_DT(kNumberTypeFloat)] = CreateClip; + creators[PrimType_Clip][REGIST_DT(kNumberTypeFloat32)] = CreateClip; + creators[PrimType_Clip][REGIST_DT(kNumberTypeInt)] = CreateClip; + creators[PrimType_Clip][REGIST_DT(kNumberTypeInt32)] = CreateClip; + creators[PrimType_Concat][REGIST_DT(kNumberTypeBool)] = CreateConcat; + creators[PrimType_Concat][REGIST_DT(kNumberTypeInt32)] = CreateConcat; + creators[PrimType_Concat][REGIST_DT(kNumberTypeFloat32)] = CreateConcat; + creators[PrimType_Crop][REGIST_DT(kNumberTypeInt32)] = CreateCrop; + creators[PrimType_Crop][REGIST_DT(kNumberTypeFloat32)] = CreateCrop; + creators[PrimType_CropAndResize][REGIST_DT(kNumberTypeFloat32)] = CreateCropAndResize; + creators[PrimType_DepthToSpace][REGIST_DT(kNumberTypeFloat32)] = CreateDepthToSpace; + creators[PrimType_DivFusion][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic; + creators[PrimType_DivFusion][REGIST_DT(kNumberTypeInt32)] = CreateArithmetic; + creators[PrimType_Eltwise][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic; + creators[PrimType_Equal][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticCompare; + creators[PrimType_Equal][REGIST_DT(kNumberTypeInt32)] = CreateArithmeticCompare; + creators[PrimType_Erf][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_ExpFusion][REGIST_DT(kNumberTypeFloat32)] = CreateExp; + creators[PrimType_ExpandDims][REGIST_DT(kNumberTypeInt32)] = CreateReshape; + creators[PrimType_ExpandDims][REGIST_DT(kNumberTypeFloat32)] = CreateReshape; + creators[PrimType_ExpandDims][REGIST_DT(kNumberTypeBool)] = CreateReshape; + creators[PrimType_ExpandDims][REGIST_DT(kNumberTypeInt8)] = CreateReshape; + creators[PrimType_Fill][REGIST_DT(kNumberTypeBool)] = CreateFill; + creators[PrimType_Fill][REGIST_DT(kNumberTypeInt32)] = CreateFill; + creators[PrimType_Fill][REGIST_DT(kNumberTypeFloat32)] = CreateFill; + creators[PrimType_Flatten][REGIST_DT(kNumberTypeInt32)] = CreateReshape; + creators[PrimType_Flatten][REGIST_DT(kNumberTypeFloat32)] = CreateReshape; + creators[PrimType_FlattenGrad][REGIST_DT(kNumberTypeFloat32)] = CreateReshape; + creators[PrimType_Floor][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_FloorDiv][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic; + creators[PrimType_FloorDiv][REGIST_DT(kNumberTypeInt32)] = CreateArithmetic; + creators[PrimType_FloorMod][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic; + creators[PrimType_FloorMod][REGIST_DT(kNumberTypeInt32)] = CreateArithmetic; + creators[PrimType_FullConnection][REGIST_DT(kNumberTypeFloat32)] = CreateFullconnection; + creators[PrimType_FusedBatchNorm][REGIST_DT(kNumberTypeFloat32)] = CreateFusedBatchNorm; + creators[PrimType_Gather][REGIST_DT(kNumberTypeFloat32)] = CreateGather; + creators[PrimType_Gather][REGIST_DT(kNumberTypeInt32)] = CreateGather; + creators[PrimType_Gather][REGIST_DT(kNumberTypeBool)] = CreateGather; + creators[PrimType_GatherD][REGIST_DT(kNumberTypeFloat32)] = CreateGatherD; + creators[PrimType_GatherD][REGIST_DT(kNumberTypeInt32)] = CreateGatherD; + creators[PrimType_GatherNd][REGIST_DT(kNumberTypeBool)] = CreateGatherNd; + creators[PrimType_GatherNd][REGIST_DT(kNumberTypeInt32)] = CreateGatherNd; + creators[PrimType_GatherNd][REGIST_DT(kNumberTypeFloat32)] = CreateGatherNd; + creators[PrimType_Greater][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticCompare; + creators[PrimType_Greater][REGIST_DT(kNumberTypeInt32)] = CreateArithmeticCompare; + creators[PrimType_GreaterEqual][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticCompare; + creators[PrimType_GreaterEqual][REGIST_DT(kNumberTypeInt32)] = CreateArithmeticCompare; + creators[PrimType_GroupNormFusion][REGIST_DT(kNumberTypeFloat32)] = CreateGroupNorm; +} + +void InitVSKernelI(KernelCreator **creators) { + creators[PrimType_IsFinite][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_LayerNormFusion][REGIST_DT(kNumberTypeFloat32)] = CreateLayerNorm; + creators[PrimType_Less][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticCompare; + creators[PrimType_Less][REGIST_DT(kNumberTypeInt32)] = CreateArithmeticCompare; + creators[PrimType_LessEqual][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticCompare; + creators[PrimType_LessEqual][REGIST_DT(kNumberTypeInt32)] = CreateArithmeticCompare; + creators[PrimType_LogicalAnd][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic; + creators[PrimType_LogicalAnd][REGIST_DT(kNumberTypeBool)] = CreateArithmetic; + creators[PrimType_LogicalAnd][REGIST_DT(kNumberTypeInt32)] = CreateArithmetic; + creators[PrimType_LogicalOr][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic; + creators[PrimType_LogicalOr][REGIST_DT(kNumberTypeBool)] = CreateArithmetic; + creators[PrimType_Log][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_LogSoftmax][REGIST_DT(kNumberTypeFloat32)] = CreateLogSoftmax; + creators[PrimType_Log1p][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_LogicalNot][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_LogicalNot][REGIST_DT(kNumberTypeBool)] = CreateArithmeticSelf; + creators[PrimType_LRN][REGIST_DT(kNumberTypeFloat32)] = CreateLocalResponseNorm; + creators[PrimType_Maximum][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic; + creators[PrimType_Maximum][REGIST_DT(kNumberTypeInt32)] = CreateArithmetic; + creators[PrimType_MatMulFusion][REGIST_DT(kNumberTypeFloat32)] = CreateMatmul; + creators[PrimType_Mod][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic; + creators[PrimType_Mod][REGIST_DT(kNumberTypeInt32)] = CreateArithmetic; + creators[PrimType_MulFusion][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic; + creators[PrimType_MulFusion][REGIST_DT(kNumberTypeInt32)] = CreateArithmetic; + creators[PrimType_Minimum][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic; + creators[PrimType_Minimum][REGIST_DT(kNumberTypeInt32)] = CreateArithmetic; + creators[PrimType_NLLLoss][REGIST_DT(kNumberTypeFloat32)] = CreateNLLLoss; + creators[PrimType_Neg][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_Neg][REGIST_DT(kNumberTypeInt32)] = CreateArithmeticSelf; + creators[PrimType_NotEqual][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticCompare; + creators[PrimType_NotEqual][REGIST_DT(kNumberTypeInt32)] = CreateArithmeticCompare; + creators[PrimType_NotEqual][REGIST_DT(kNumberTypeInt64)] = CreateArithmeticCompare; + creators[PrimType_NonZero][REGIST_DT(kNumberTypeBool)] = CreateNonZero; + creators[PrimType_NonMaxSuppression][REGIST_DT(kNumberTypeFloat32)] = CreateNonMaxSuppression; + creators[PrimType_PadFusion][REGIST_DT(kNumberTypeFloat32)] = CreatePad; + creators[PrimType_PriorBox][REGIST_DT(kNumberTypeFloat32)] = CreatePriorBox; + creators[PrimType_PriorBox][REGIST_DT(kNumberTypeInt8)] = CreatePriorBox; + creators[PrimType_PowFusion][REGIST_DT(kNumberTypeFloat32)] = CreatePow; + creators[PrimType_PReLUFusion][REGIST_DT(kNumberTypeFloat32)] = CreatePRelu; +} + +void InitVSKernelR(KernelCreator **creators) { + creators[PrimType_RaggedRange][REGIST_DT(kNumberTypeInt32)] = CreateRaggedRange; + creators[PrimType_RaggedRange][REGIST_DT(kNumberTypeFloat32)] = CreateRaggedRange; + creators[PrimType_Range][REGIST_DT(kNumberTypeFloat32)] = CreateRange; + creators[PrimType_Range][REGIST_DT(kNumberTypeInt32)] = CreateRange; + creators[PrimType_Range][REGIST_DT(kNumberTypeFloat16)] = CreateRange; + creators[PrimType_Rank][REGIST_DT(kNumberTypeFloat32)] = CreateRank; + creators[PrimType_Rank][REGIST_DT(kNumberTypeFloat32)] = CreateRank; + creators[PrimType_Reshape][REGIST_DT(kNumberTypeInt32)] = CreateReshape; + creators[PrimType_Reshape][REGIST_DT(kNumberTypeFloat32)] = CreateReshape; + creators[PrimType_Reshape][REGIST_DT(kNumberTypeBool)] = CreateReshape; + creators[PrimType_RealDiv][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic; + creators[PrimType_ReduceFusion][REGIST_DT(kNumberTypeBool)] = CreateReduce; + creators[PrimType_ReduceFusion][REGIST_DT(kNumberTypeInt32)] = CreateReduce; + creators[PrimType_ReduceFusion][REGIST_DT(kNumberTypeFloat32)] = CreateReduce; + creators[PrimType_Reciprocal][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_ReverseV2][REGIST_DT(kNumberTypeInt32)] = CreateReverse; + creators[PrimType_ReverseV2][REGIST_DT(kNumberTypeFloat32)] = CreateReverse; + creators[PrimType_Round][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_Rsqrt][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_ScaleFusion][REGIST_DT(kNumberTypeFloat32)] = CreateScale; + creators[PrimType_Shape][REGIST_DT(kNumberTypeInt32)] = CreateShape; + creators[PrimType_Shape][REGIST_DT(kNumberTypeBool)] = CreateShape; + creators[PrimType_Shape][REGIST_DT(kNumberTypeFloat32)] = CreateShape; + creators[PrimType_Shape][REGIST_DT(kNumberTypeInt8)] = CreateShape; + creators[PrimType_Shape][REGIST_DT(kNumberTypeUInt8)] = CreateShape; + creators[PrimType_Shape][REGIST_DT(kNumberTypeInt64)] = CreateShape; + creators[PrimType_Softmax][REGIST_DT(kNumberTypeFloat32)] = CreateSoftmax; + creators[PrimType_SquaredDifference][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic; + creators[PrimType_Stack][REGIST_DT(kNumberTypeFloat32)] = CreateStack; + creators[PrimType_Stack][REGIST_DT(kNumberTypeInt32)] = CreateStack; + creators[PrimType_StridedSlice][REGIST_DT(kNumberTypeFloat32)] = CreateStridedSlice; + creators[PrimType_StridedSlice][REGIST_DT(kNumberTypeInt64)] = CreateStridedSlice; + creators[PrimType_StridedSlice][REGIST_DT(kNumberTypeInt32)] = CreateStridedSlice; + creators[PrimType_StridedSlice][REGIST_DT(kNumberTypeInt8)] = CreateStridedSlice; + creators[PrimType_Square][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_Sqrt][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_Sin][REGIST_DT(kNumberTypeFloat32)] = CreateArithmeticSelf; + creators[PrimType_Size][REGIST_DT(kNumberTypeInt32)] = CreateSize; + creators[PrimType_Size][REGIST_DT(kNumberTypeFloat32)] = CreateSize; + creators[PrimType_SliceFusion][REGIST_DT(kNumberTypeInt32)] = CreateSlice; + creators[PrimType_SliceFusion][REGIST_DT(kNumberTypeFloat32)] = CreateSlice; + creators[PrimType_SubFusion][REGIST_DT(kNumberTypeFloat32)] = CreateArithmetic; + creators[PrimType_SubFusion][REGIST_DT(kNumberTypeInt32)] = CreateArithmetic; + creators[PrimType_Squeeze][REGIST_DT(kNumberTypeFloat32)] = CreateReshape; + creators[PrimType_Squeeze][REGIST_DT(kNumberTypeInt32)] = CreateReshape; + creators[PrimType_Squeeze][REGIST_DT(kNumberTypeBool)] = CreateReshape; + creators[PrimType_Splice][REGIST_DT(kNumberTypeFloat32)] = CreateSplice; + creators[PrimType_TileFusion][REGIST_DT(kNumberTypeInt32)] = CreateTile; + creators[PrimType_TileFusion][REGIST_DT(kNumberTypeFloat32)] = CreateTile; + creators[PrimType_TileFusion][REGIST_DT(kNumberTypeBool)] = CreateTile; + creators[PrimType_TileFusion][REGIST_DT(kNumberTypeUInt8)] = CreateTile; + creators[PrimType_Triu][REGIST_DT(kNumberTypeDouble)] = CreateTriu; + creators[PrimType_Triu][REGIST_DT(kNumberTypeFloat)] = CreateTriu; + creators[PrimType_Triu][REGIST_DT(kNumberTypeFloat64)] = CreateTriu; + creators[PrimType_Triu][REGIST_DT(kNumberTypeFloat32)] = CreateTriu; + creators[PrimType_Triu][REGIST_DT(kNumberTypeInt)] = CreateTriu; + creators[PrimType_Triu][REGIST_DT(kNumberTypeInt64)] = CreateTriu; + creators[PrimType_Triu][REGIST_DT(kNumberTypeInt32)] = CreateTriu; + creators[PrimType_Triu][REGIST_DT(kNumberTypeInt16)] = CreateTriu; + creators[PrimType_Triu][REGIST_DT(kNumberTypeInt8)] = CreateTriu; + creators[PrimType_Triu][REGIST_DT(kNumberTypeUInt64)] = CreateTriu; + creators[PrimType_Triu][REGIST_DT(kNumberTypeUInt32)] = CreateTriu; + creators[PrimType_Triu][REGIST_DT(kNumberTypeUInt16)] = CreateTriu; + creators[PrimType_Triu][REGIST_DT(kNumberTypeUInt8)] = CreateTriu; + creators[PrimType_Triu][REGIST_DT(kNumberTypeBool)] = CreateTriu; + creators[PrimType_Tril][REGIST_DT(kNumberTypeDouble)] = CreateTril; + creators[PrimType_Tril][REGIST_DT(kNumberTypeFloat)] = CreateTril; + creators[PrimType_Tril][REGIST_DT(kNumberTypeFloat64)] = CreateTril; + creators[PrimType_Tril][REGIST_DT(kNumberTypeFloat32)] = CreateTril; + creators[PrimType_Tril][REGIST_DT(kNumberTypeInt)] = CreateTril; + creators[PrimType_Tril][REGIST_DT(kNumberTypeInt64)] = CreateTril; + creators[PrimType_Tril][REGIST_DT(kNumberTypeInt32)] = CreateTril; + creators[PrimType_Tril][REGIST_DT(kNumberTypeInt16)] = CreateTril; + creators[PrimType_Tril][REGIST_DT(kNumberTypeInt8)] = CreateTril; + creators[PrimType_Tril][REGIST_DT(kNumberTypeUInt64)] = CreateTril; + creators[PrimType_Tril][REGIST_DT(kNumberTypeUInt32)] = CreateTril; + creators[PrimType_Tril][REGIST_DT(kNumberTypeUInt16)] = CreateTril; + creators[PrimType_Tril][REGIST_DT(kNumberTypeUInt8)] = CreateTril; + creators[PrimType_Tril][REGIST_DT(kNumberTypeBool)] = CreateTril; + creators[PrimType_Transpose][REGIST_DT(kNumberTypeFloat32)] = CreateTranspose; + creators[PrimType_Transpose][REGIST_DT(kNumberTypeInt32)] = CreateTranspose; + creators[PrimType_Unsqueeze][REGIST_DT(kNumberTypeFloat32)] = CreateReshape; + creators[PrimType_Unsqueeze][REGIST_DT(kNumberTypeInt32)] = CreateReshape; + creators[PrimType_Unsqueeze][REGIST_DT(kNumberTypeInt64)] = CreateReshape; + creators[PrimType_Unsqueeze][REGIST_DT(kNumberTypeBool)] = CreateReshape; + creators[PrimType_Unique][REGIST_DT(kNumberTypeInt32)] = CreateUnique; + creators[PrimType_Unique][REGIST_DT(kNumberTypeFloat32)] = CreateUnique; +} + +void init_vs_kernels(KernelCreator **creators) { + InitVSKernelA(creators); + InitVSKernelI(creators); + InitVSKernelR(creators); + InitVSKernelF16(creators); +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/init_vs_kernels.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/init_vs_kernels.h new file mode 100644 index 0000000000000000000000000000000000000000..676ff1ca0be1ee09304166fcde7a02e63589378e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/init_vs_kernels.h @@ -0,0 +1,20 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_INIT_VS_KERNELS_H_ +#define NNACL_KERNEL_INIT_VS_KERNELS_H_ +#include "nnacl/kernel.h" +void init_vs_kernels(KernelCreator **creators); +#endif // NNACL_KERNEL_INIT_VS_KERNELS_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/layer_norm.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/layer_norm.c new file mode 100644 index 0000000000000000000000000000000000000000..1dac590f49e83d11165e4e2e6852ebe6e896da8c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/layer_norm.c @@ -0,0 +1,130 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either log_softmaxress or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/layer_norm.h" +#include "nnacl/kernel/default_kernel_base.h" +#include "nnacl/tensor_c_utils.h" +#include "nnacl/fp32/layer_norm_fp32.h" +#ifdef ENABLE_FP16 +#include "nnacl/fp16/layer_norm_fp16.h" +#endif + +int LayerNormRun(void *cdata, int task_id, float l, float r) { + LayerNormStruct *ln = (LayerNormStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(ln); + if (ln->data_type_ == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + return LayerNormFp16(ln->src_data_, ln->gamma_data_, ln->beta_data_, ln->dst_data_, ln->mean_data_, ln->var_data_, + &ln->compute_, task_id, ln->base_.thread_nr_); +#endif + } + return LayerNorm(ln->src_data_, ln->gamma_data_, ln->beta_data_, ln->dst_data_, ln->mean_data_, ln->var_data_, + &ln->compute_, task_id, ln->base_.thread_nr_); +} + +int LayerNormResize(KernelBase *self) { + LayerNormStruct *layer_norm = (LayerNormStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(layer_norm); + LayerNormComputeParam *compute = &layer_norm->compute_; + + TensorC *input = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + + if (compute->begin_norm_axis_ < 0) { + compute->begin_norm_axis_ = compute->begin_norm_axis_ + (int)input->shape_size_; + } + + if (compute->begin_params_axis_ < 0) { + compute->begin_params_axis_ = compute->begin_params_axis_ + (int)input->shape_size_; + } + + compute->norm_outer_size_ = 1; + for (int i = 0; i < compute->begin_norm_axis_; ++i) { + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->norm_outer_size_, input->shape_[i], NNACL_ERR); + compute->norm_outer_size_ *= input->shape_[i]; + } + + compute->norm_inner_size_ = 1; + for (size_t i = compute->begin_norm_axis_; i < input->shape_size_; ++i) { + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->norm_inner_size_, input->shape_[i], NNACL_ERR); + compute->norm_inner_size_ *= input->shape_[i]; + } + + compute->params_outer_size_ = 1; + for (int i = 0; i < compute->begin_params_axis_; ++i) { + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->params_outer_size_, input->shape_[i], NNACL_ERR); + compute->params_outer_size_ *= input->shape_[i]; + } + + compute->params_inner_size_ = 1; + for (size_t i = compute->begin_params_axis_; i < input->shape_size_; ++i) { + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->params_inner_size_, input->shape_[i], NNACL_ERR); + compute->params_inner_size_ *= input->shape_[i]; + } + + int out_num = NNACLGetElementNum(self->out_[OUTPUT_INDEX]); + self->thread_nr_ = self->UpdateThread(TC_PTYPE(PrimType_LayerNormFusion), compute->norm_inner_size_, + compute->norm_inner_size_, out_num, self->thread_nr_); + self->thread_nr_ = NNACL_MIN(compute->norm_outer_size_, self->thread_nr_); + return NNACL_OK; +} + +int LayerNormCompute(KernelBase *self) { + LayerNormStruct *layer_norm = (LayerNormStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(layer_norm); + + layer_norm->src_data_ = self->in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(layer_norm->src_data_); + layer_norm->gamma_data_ = self->in_[SECOND_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(layer_norm->gamma_data_); + layer_norm->beta_data_ = self->in_[THIRD_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(layer_norm->beta_data_); + layer_norm->dst_data_ = self->out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(layer_norm->dst_data_); + + if (layer_norm->base_.out_size_ == THREE_TENSOR) { + layer_norm->mean_data_ = self->out_[Index1]->data_; + NNACL_CHECK_NULL_RETURN_ERR(layer_norm->mean_data_); + layer_norm->var_data_ = self->out_[Index2]->data_; + NNACL_CHECK_NULL_RETURN_ERR(layer_norm->var_data_); + } else if (layer_norm->base_.out_size_ != ONE_TENSOR) { + return NNACL_LAYER_NORM_OUTPUT_NUM_INVALID; + } + + return self->env_->ParallelLaunch(self->env_->thread_pool_, LayerNormRun, self, self->thread_nr_); +} + +KernelBase *CreateLayerNorm(OpParameter *param, int data_type) { + LayerNormStruct *layer_norm = (LayerNormStruct *)malloc(sizeof(LayerNormStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(layer_norm); + memset(layer_norm, 0, sizeof(LayerNormStruct)); + layer_norm->data_type_ = data_type; + + LayerNormParameter *layer_norm_param = (LayerNormParameter *)param; + layer_norm->compute_.epsilon_ = layer_norm_param->epsilon_; + layer_norm->compute_.elementwise_affine_ = layer_norm_param->elementwise_affine_; + layer_norm->compute_.begin_norm_axis_ = layer_norm_param->begin_norm_axis_; + layer_norm->compute_.begin_params_axis_ = layer_norm_param->begin_params_axis_; + + layer_norm->base_.Prepare = DefaultPrepare3In1Out; + layer_norm->base_.Release = DefaultRelease; + layer_norm->base_.Resize = LayerNormResize; + layer_norm->base_.Compute = LayerNormCompute; + return (KernelBase *)layer_norm; +} + +REG_KERNEL_CREATOR(PrimType_LayerNormFusion, kNumberTypeFloat16, CreateLayerNorm) +REG_KERNEL_CREATOR(PrimType_LayerNormFusion, kNumberTypeFloat32, CreateLayerNorm) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/layer_norm.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/layer_norm.h new file mode 100644 index 0000000000000000000000000000000000000000..3437ffeb165565ea9511cff6250f3a37d4628ac8 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/layer_norm.h @@ -0,0 +1,49 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_LAYER_NORM_H_ +#define NNACL_KERNEL_LAYER_NORM_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct LayerNormComputeParam { + float epsilon_; + bool elementwise_affine_; + int begin_norm_axis_; + int begin_params_axis_; + int norm_inner_size_; + int norm_outer_size_; + int params_inner_size_; + int params_outer_size_; +} LayerNormComputeParam; + +typedef struct LayerNormStruct { + KernelBase base_; + LayerNormComputeParam compute_; + int data_type_; + void *src_data_; + void *dst_data_; + void *gamma_data_; + void *beta_data_; + void *mean_data_; + void *var_data_; +} LayerNormStruct; + +KernelBase *CreateLayerNorm(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_LAYER_NORM_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/local_response_norm.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/local_response_norm.c new file mode 100644 index 0000000000000000000000000000000000000000..5b2a1109798c93e91ef1f24156f074c2c4b94bf9 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/local_response_norm.c @@ -0,0 +1,77 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either log_softmaxress or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/local_response_norm.h" +#include "nnacl/kernel/default_kernel_base.h" +#include "nnacl/fp32/local_response_norm_fp32.h" +#include "nnacl/tensor_c_utils.h" + +int LocalResponseNormRun(void *cdata, int task_id, float l, float r) { + LocalResponseNormStruct *lrn = (LocalResponseNormStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(lrn); + LocalResponseNormParameter *param = (LocalResponseNormParameter *)lrn->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + + TensorC *input = lrn->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + TensorC *output = lrn->base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output); + NNACL_CHECK_FALSE(input->shape_size_ != DIMENSION_4D, NNACL_LOCAL_RESPONSE_NORM_SHAPE_INVALID); + NNACL_CHECK_FALSE(param->depth_radius_ <= 0, NNACL_LOCAL_RESPONSE_NORM_DEPTH_RADIUS_INVALID); + + float *input_ptr = (float *)input->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_ptr); + float *output_ptr = (float *)output->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_ptr); + + int batch = NNACLGetBatch(input); + int height = NNACLGetHeight(input); + int width = NNACLGetWidth(input); + int channel = NNACLGetChannel(input); + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(batch, width, NNACL_ERR); + int size_bw = batch * width; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(size_bw, height, NNACL_ERR); + int outer_size = size_bw * height; + int stride = UP_DIV(outer_size, lrn->base_.thread_nr_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(stride, task_id, NNACL_ERR); + int start = stride * task_id; + int count = MSMIN(stride, outer_size - start); + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(start, channel, NNACL_ERR); + input_ptr += start * channel; + output_ptr += start * channel; + + return LocalResponseNorm(input_ptr, count, channel, output_ptr, param); +} + +int LrnCompute(KernelBase *self) { + return self->env_->ParallelLaunch(self->env_->thread_pool_, LocalResponseNormRun, self, self->thread_nr_); +} + +KernelBase *CreateLocalResponseNorm(OpParameter *param, int data_type) { + LocalResponseNormStruct *lrn = (LocalResponseNormStruct *)malloc(sizeof(LocalResponseNormStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(lrn); + memset(lrn, 0, sizeof(LocalResponseNormStruct)); + + lrn->base_.Prepare = DefaultPrepare1In1Out; + lrn->base_.Release = DefaultRelease; + lrn->base_.Resize = DefaultResize; + lrn->base_.Compute = LrnCompute; + return (KernelBase *)lrn; +} + +REG_KERNEL_CREATOR(PrimType_LRN, kNumberTypeFloat32, CreateLocalResponseNorm) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/local_response_norm.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/local_response_norm.h new file mode 100644 index 0000000000000000000000000000000000000000..015d0bdb83fcf1e8836c32370f45262c81086e1d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/local_response_norm.h @@ -0,0 +1,30 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_LOCAL_RESPONSE_NORM_H_ +#define NNACL_KERNEL_LOCAL_RESPONSE_NORM_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct LocalResponseNormStruct { + KernelBase base_; +} LocalResponseNormStruct; + +KernelBase *CreateLocalResponseNorm(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_LOG_SOFTMAX_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/log_softmax.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/log_softmax.c new file mode 100644 index 0000000000000000000000000000000000000000..d48b6830df922abe46a092a07f67f12fdb4aa2ba --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/log_softmax.c @@ -0,0 +1,120 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either log_softmaxress or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/log_softmax.h" +#include "nnacl/common_func.h" +#include "nnacl/kernel/default_kernel_base.h" +#include "nnacl/fp32/log_softmax_fp32.h" +#ifdef ENABLE_FP16 +#include "nnacl/fp16/log_softmax_fp16.h" +#endif + +int LogSoftmaxLastAxisRun(void *cdata, int task_id, float l, float r) { + LogSoftmaxStruct *log_softmax = (LogSoftmaxStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(log_softmax); + + TensorC *in = log_softmax->softmax_.base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in); + void *input_ptr = in->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_ptr); + void *output_ptr = log_softmax->softmax_.base_.out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_ptr); + void *tmp_ptr = log_softmax->softmax_.sum_data_; + NNACL_CHECK_NULL_RETURN_ERR(tmp_ptr); + + int unit = UP_DIV(log_softmax->softmax_.out_plane_size_, log_softmax->softmax_.base_.thread_nr_); + int begin = task_id * unit; + int end = MSMIN(begin + unit, log_softmax->softmax_.out_plane_size_); + int channel = in->shape_[log_softmax->softmax_.axis_]; + int offset = begin * channel; + +#ifdef ENABLE_FP16 + if (log_softmax->softmax_.data_type_ == kNumberTypeFloat16) { + LogSoftmaxLastAxisFp16((const float16_t *)input_ptr + offset, (float16_t *)output_ptr + offset, + (float16_t *)tmp_ptr + offset, end - begin, channel); + return NNACL_OK; + } +#endif + LogSoftmaxLastAxis((const float *)input_ptr + offset, (float *)output_ptr + offset, (float *)tmp_ptr + offset, + end - begin, channel); + return NNACL_OK; +} + +int LogSoftmaxResize(struct KernelBase *self) { + LogSoftmaxStruct *log_softmax = (LogSoftmaxStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(log_softmax); + + int ret = InitSoftmaxParam(&log_softmax->softmax_); + if (ret != NNACL_OK) { + return ret; + } + + if (log_softmax->softmax_.in_plane_size_ == 1 && log_softmax->softmax_.sum_data_ == NULL) { + TensorC *in = log_softmax->softmax_.base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in); + SoftmaxStruct *softmax = &log_softmax->softmax_; + + int sum_data_size = softmax->in_plane_size_ * softmax->out_plane_size_ * in->shape_[softmax->axis_]; + softmax->sum_data_ = self->env_->Alloc(self->env_->allocator_, sum_data_size * DataTypeCSize(softmax->data_type_)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(softmax->sum_data_); + } + return NNACL_OK; +} + +int LogSoftmaxCompute(struct KernelBase *self) { + LogSoftmaxStruct *log_softmax = (LogSoftmaxStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(log_softmax); + + if (log_softmax->softmax_.in_plane_size_ == 1) { + return self->env_->ParallelLaunch(self->env_->thread_pool_, LogSoftmaxLastAxisRun, self, self->thread_nr_); + } + + TensorC *in = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in); + void *input_ptr = in->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_ptr); + void *output_ptr = self->out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_ptr); + NNACL_CHECK_NULL_RETURN_ERR(log_softmax->softmax_.sum_data_); + +#ifdef ENABLE_FP16 + if (log_softmax->softmax_.data_type_ == kNumberTypeFloat16) { + LogSoftmaxFp16((const float16_t *)input_ptr, (float16_t *)output_ptr, (float16_t *)log_softmax->softmax_.sum_data_, + in->shape_, in->shape_size_, log_softmax->softmax_.axis_); + return NNACL_OK; + } +#endif + LogSoftmax((const float *)input_ptr, (float *)output_ptr, (float *)log_softmax->softmax_.sum_data_, in->shape_, + in->shape_size_, log_softmax->softmax_.axis_); + return NNACL_OK; +} + +KernelBase *CreateLogSoftmax(OpParameter *param, int data_type) { + LogSoftmaxStruct *log_softmax = (LogSoftmaxStruct *)malloc(sizeof(LogSoftmaxStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(log_softmax); + memset(log_softmax, 0, sizeof(LogSoftmaxStruct)); + + log_softmax->softmax_.sum_data_ = NULL; + log_softmax->softmax_.data_type_ = data_type; + log_softmax->softmax_.base_.Prepare = DefaultPrepare1In1Out; + log_softmax->softmax_.base_.Release = SoftmaxRelease; + log_softmax->softmax_.base_.Resize = LogSoftmaxResize; + log_softmax->softmax_.base_.Compute = LogSoftmaxCompute; + return (KernelBase *)log_softmax; +} + +REG_KERNEL_CREATOR(PrimType_LogSoftmax, kNumberTypeFloat32, CreateLogSoftmax) +REG_KERNEL_CREATOR(PrimType_LogSoftmax, kNumberTypeFloat16, CreateLogSoftmax) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/log_softmax.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/log_softmax.h new file mode 100644 index 0000000000000000000000000000000000000000..b4fc79321baed6aa64ef042bbe5d541951b50f7b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/log_softmax.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_LOG_SOFTMAX_H_ +#define NNACL_KERNEL_LOG_SOFTMAX_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" +#include "nnacl/kernel/softmax.h" + +typedef struct LogSoftmaxStruct { + SoftmaxStruct softmax_; +} LogSoftmaxStruct; + +KernelBase *CreateLogSoftmax(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_LOG_SOFTMAX_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul.c new file mode 100644 index 0000000000000000000000000000000000000000..76ca1f535c4a8b08805fd3e61c8bcca2e9c27ff9 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul.c @@ -0,0 +1,176 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/matmul.h" +#include "nnacl/kernel/matmul_base.h" +#include "nnacl/kernel/matmul_create.h" + +void MatmulInitShapeA(MatmulStruct *matmul) { + int *a_shape = matmul->base_.in_[kInputIndex]->shape_; + size_t a_shape_size = matmul->base_.in_[kInputIndex]->shape_size_; + int batch = 1; + NNACL_CHECK_TRUE_RET_VOID(a_shape_size >= C2NUM); + for (size_t i = 0; i < a_shape_size - C2NUM; ++i) { + batch *= a_shape[i]; + } + matmul->a_batch_ = batch; + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + matmul->compute_.row_ = param->a_transpose_ ? a_shape[a_shape_size - 1] : a_shape[a_shape_size - C2NUM]; + matmul->compute_.deep_ = param->a_transpose_ ? a_shape[a_shape_size - C2NUM] : a_shape[a_shape_size - 1]; +} + +void MatmulInitShapeB(MatmulStruct *matmul) { + int *b_shape = matmul->base_.in_[kWeightIndex]->shape_; + size_t b_shape_size = matmul->base_.in_[kWeightIndex]->shape_size_; + int batch = 1; + NNACL_CHECK_TRUE_RET_VOID(b_shape_size >= C2NUM); + for (size_t i = 0; i < b_shape_size - C2NUM; ++i) { + batch *= b_shape[i]; + } + matmul->b_batch_ = batch; + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + matmul->compute_.col_ = param->b_transpose_ ? b_shape[b_shape_size - C2NUM] : b_shape[b_shape_size - 1]; + matmul->compute_.deep_ = param->b_transpose_ ? b_shape[b_shape_size - 1] : b_shape[b_shape_size - C2NUM]; +} + +int MatmulInitBroadcastParams(MatmulStruct *matmul) { + TensorC *a = matmul->base_.in_[FIRST_INPUT]; + TensorC *b = matmul->base_.in_[SECOND_INPUT]; + + int max_dim_size = (int)NNACL_MAX(a->shape_size_, b->shape_size_); + max_dim_size = NNACL_MAX(max_dim_size, COMM_SHAPE_SIZE); + + int a_shape[MAX_SHAPE_SIZE] = {0}; + int index = max_dim_size - 1; + for (int i = (int)a->shape_size_ - 1; i >= 0; i--) { + a_shape[index--] = a->shape_[i]; + } + for (; index >= 0;) { + a_shape[index--] = 1; + } + + int b_shape[MAX_SHAPE_SIZE] = {0}; + index = max_dim_size - 1; + for (int i = (int)b->shape_size_ - 1; i >= 0; i--) { + b_shape[index--] = b->shape_[i]; + } + for (; index >= 0;) { + b_shape[index--] = 1; + } + + int batch_sizes[MAX_SHAPE_SIZE] = {0}; + int a_batch_sizes[MAX_SHAPE_SIZE] = {0}; + int b_batch_sizes[MAX_SHAPE_SIZE] = {0}; + for (int i = max_dim_size - Num3; i >= 0; --i) { + if (max_dim_size - Num3 == i) { + batch_sizes[i] = NNACL_MAX(a_shape[i], b_shape[i]); + a_batch_sizes[i] = a_shape[i]; + b_batch_sizes[i] = b_shape[i]; + } else { + batch_sizes[i] = batch_sizes[i + 1] * NNACL_MAX(a_shape[i], b_shape[i]); + a_batch_sizes[i] = a_batch_sizes[i + 1] * a_shape[i]; + b_batch_sizes[i] = b_batch_sizes[i + 1] * b_shape[i]; + } + } + + int out_batch = 1; + for (int i = 0; i < max_dim_size - Num2; ++i) { + int max_v = NNACL_MAX(a_shape[i], b_shape[i]); + int min_v = NNACL_MIN(a_shape[i], b_shape[i]) > 0 ? NNACL_MIN(a_shape[i], b_shape[i]) : 1; + out_batch *= max_v; + if ((max_v != min_v) && ((max_v % min_v) != 0)) { + return NNACL_ERR; + } + } + matmul->batch_ = out_batch; + + MatmulBaseFreeBatchOffset(matmul); + int ret = MatmulBaseMallocBatchOffset(matmul); + if (ret != NNACL_OK) { + return ret; + } + + for (int i = 0; i < matmul->batch_; ++i) { + int delta = i; + int a_offset = 0; + int b_offset = 0; + for (int j = 0; j < max_dim_size - Num2; ++j) { + if (j > 0) { + delta = delta % batch_sizes[j]; + } + if (j >= (MAX_SHAPE_SIZE - 1)) { + return NNACL_ERR; + } + if (j < (max_dim_size - Num3)) { + a_offset += + (delta / batch_sizes[j + 1] * a_shape[j] / NNACL_MAX(a_shape[j], b_shape[j])) * a_batch_sizes[j + 1]; + b_offset += + (delta / batch_sizes[j + 1] * b_shape[j] / NNACL_MAX(a_shape[j], b_shape[j])) * b_batch_sizes[j + 1]; + } else { + a_offset += (delta * a_shape[j] / NNACL_MAX(a_shape[j], b_shape[j])); + b_offset += (delta * b_shape[j] / NNACL_MAX(a_shape[j], b_shape[j])); + } + } + matmul->a_offset_[i] = a_offset; + matmul->b_offset_[i] = b_offset; + } + return NNACL_OK; +} + +int MatmulPrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < C2NUM, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < C1NUM, NNACL_ERR); + + MatmulStruct *matmul = (MatmulStruct *)self; + if (matmul->a_const_ || matmul->infer_shape_) { + MatmulInitShapeA(matmul); + } + + if (matmul->b_const_ || matmul->infer_shape_) { + MatmulInitShapeB(matmul); + } + + return MatmulBasePrepare(self); +} + +int MatmulResize(KernelBase *self) { + MatmulStruct *matmul = (MatmulStruct *)self; + MatmulInitShapeA(matmul); + MatmulInitShapeB(matmul); + + int ret = MatmulInitBroadcastParams(matmul); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + return MatmulBaseResize(self); +} + +int MatmulRelease(KernelBase *self) { + MatmulBaseFreeBatchOffset((MatmulStruct *)self); + return MatmulBaseRelease(self); +} + +KernelBase *CreateMatmul(OpParameter *param, int data_type) { + KernelBase *kernel = NULL; + if (data_type == kNumberTypeFloat32) { + kernel = CreateMatmulKernel(); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(kernel); + kernel->Prepare = MatmulPrepare; + kernel->Resize = MatmulResize; + kernel->Release = MatmulRelease; + } + return kernel; +} + +REG_KERNEL_CREATOR(PrimType_MatMulFusion, kNumberTypeFloat32, CreateMatmul); diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul.h new file mode 100644 index 0000000000000000000000000000000000000000..2565fa91ef116772675fb6643c180ef45ff74d61 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul.h @@ -0,0 +1,25 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_MATMUL_H_ +#define NNACL_KERNEL_MATMUL_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +KernelBase *CreateMatmul(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_MATMUL_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_arm32.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_arm32.c new file mode 100644 index 0000000000000000000000000000000000000000..b707fffdd6b066c2748d8558310ba8a07d78df77 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_arm32.c @@ -0,0 +1,110 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_ARM32 +#include "nnacl/kernel/matmul_arm32.h" +#include "nnacl/kernel/matmul_base.h" +#include "nnacl/fp32/pack_fp32.h" +#include "nnacl/fp32/matmul_fp32.h" + +void MatmulARM32InitGlobalVariable(MatmulStruct *matmul) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + matmul->matrix_a_.need_pack_ = true; + matmul->matrix_b_.need_pack_ = true; + matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2Row12MajorParallel : RowMajor2Col12MajorParallel; + matmul->matrix_b_pack_fun_ = param->b_transpose_ ? RowMajor2Col4MajorParallel : RowMajor2Row4MajorParallel; + matmul->compute_.row_tile_ = C12NUM; + matmul->compute_.col_tile_ = C4NUM; + matmul->compute_.col_min_unit_ = C4NUM; +} + +int MatmulARM32ParallelRunByBatch(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + MatmulComputeParam *compute = &matmul->compute_; + ActType act = param->act_type_; + + int start_batch = task_id * compute->batch_stride_; + int end_batch = MSMIN(matmul->batch_, start_batch + compute->batch_stride_); + int func_flag = 0; + if (compute->row_ == 1) { + func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM; + } + + for (int index = start_batch; index < end_batch; ++index) { + const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * compute->row_align_ * compute->deep_; + const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * compute->deep_ * compute->col_align_; + float *c = matmul->output_data_ + index * compute->row_ * compute->col_step_; + + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_; + if (func_flag == 0) { + MatMulOpt(a, b, c, bias, act, compute->deep_, compute->row_, compute->col_step_, compute->col_, OutType_Nhwc); + } else if (func_flag == C1NUM) { + MatVecMulFp32Block4(a, b, c, bias, act, compute->deep_, compute->col_step_); + } else { + MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_step_); + } + } + return NNACL_OK; +} + +int MatmulARM32ParallelRunByOC(MatmulStruct *matmul, int task_id) { + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + MatmulComputeParam *compute = &matmul->compute_; + ActType act = param->act_type_; + + int start_oc = matmul->split_points_[task_id]; + int end_oc = compute->col_step_; + if (task_id < (matmul->base_.thread_nr_ - 1)) { + end_oc = matmul->split_points_[task_id + 1]; + } + int compute_oc = end_oc - start_oc; + if (compute_oc <= 0) { + return NNACL_OK; + } + int func_flag = 0; + if (compute->row_ == 1) { + func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM; + } + int b_stride = func_flag == C2NUM ? start_oc : start_oc * compute->deep_; + + for (int i = 0; i < matmul->batch_; ++i) { + float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * compute->row_align_ * compute->deep_; + float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * compute->deep_ * compute->col_align_ + b_stride; + float *c = matmul->output_data_ + i * compute->row_ * compute->col_step_ + start_oc; + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc; + + if (func_flag == 0) { + MatMulOpt(a, b, c, bias, act, compute->deep_, compute->row_, compute_oc, compute->col_, OutType_Nhwc); + } else if (func_flag == C1NUM) { + MatVecMulFp32Block4(a, b, c, bias, act, compute->deep_, compute_oc); + } else { + MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_step_); + } + } + return NNACL_OK; +} + +KernelBase *CreateMatmulARM32() { + MatmulStruct *matmul = (MatmulStruct *)CreateMatmulBase(); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(matmul); + matmul->matmul_type_ = kNotImplemented; + matmul->init_global_varibale_ = MatmulARM32InitGlobalVariable; + matmul->parallel_run_by_batch_ = MatmulARM32ParallelRunByBatch; + matmul->parallel_run_by_oc_ = MatmulARM32ParallelRunByOC; + return (KernelBase *)matmul; +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_arm32.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_arm32.h new file mode 100644 index 0000000000000000000000000000000000000000..61d617c6efa97c37df834f5ab408ef2ed753c0e6 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_arm32.h @@ -0,0 +1,28 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_MATMUL_ARM32_H_ +#define NNACL_KERNEL_MATMUL_ARM32_H_ + +#ifdef ENABLE_ARM32 +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +KernelBase *CreateMatmulARM32(); + +#endif +#endif // NNACL_KERNEL_MATMUL_ARM32_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_arm64.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_arm64.c new file mode 100644 index 0000000000000000000000000000000000000000..67beb48fe42eb0f59730ed478aa086db5d211ece --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_arm64.c @@ -0,0 +1,214 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_ARM64 +#include "nnacl/kernel/matmul_arm64.h" +#include "nnacl/kernel/matmul_base.h" +#include "nnacl/fp32/matmul_fp32.h" +#include "nnacl/fp32/pack_fp32.h" +#include "nnacl/fp32/pack_fp32_opt.h" + +typedef struct MatrixAPack { + int64_t points_[MAX_THREAD_NUM]; + int64_t unit_num_; + int thread_; + int deep_; + int row_; + int col_; + MatrixInfo *matrix_a_; + float *src_ptr_; + bool a_transpose_; +} MatrixAPack; + +int MatmulARM64PackMatrixAImplOptPack(void *cdata, int task_id, float l, float r) { + MatrixAPack *pack = (MatrixAPack *)cdata; + int64_t start = pack->points_[task_id]; + int64_t end = pack->unit_num_; + if (task_id < pack->thread_ - 1) { + end = pack->points_[task_id + 1]; + } + + if (pack->a_transpose_) { + RowMajor2Row12MajorOpt(pack->src_ptr_, pack->matrix_a_->pack_ptr_, pack->deep_, pack->row_, start, end); + } else { + RowMajor2Col12MajorOpt(pack->src_ptr_, pack->matrix_a_->pack_ptr_, pack->row_, pack->deep_, start, end); + } + return NNACL_OK; +} + +int MatmulARM64PackMatrixAImplOpt(MatmulStruct *matmul) { + int64_t kPackAMinUnitNum = 1 << 13; + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + float *src_ptr = matmul->matrix_a_.origin_ptr_ != NULL ? matmul->matrix_a_.origin_ptr_ + : (float *)(matmul->base_.in_[FIRST_INPUT]->data_); + NNACL_CHECK_TRUE_RET(src_ptr != NULL, NNACL_ERR); + NNACL_CHECK_TRUE_RET(matmul->matrix_a_.pack_ptr_ != NULL, NNACL_ERR); + + MatrixAPack pack; + pack.src_ptr_ = src_ptr; + pack.matrix_a_ = &matmul->matrix_a_; + pack.deep_ = matmul->compute_.deep_; + pack.col_ = matmul->compute_.col_; + pack.row_ = matmul->compute_.row_; + pack.a_transpose_ = param->a_transpose_; + pack.unit_num_ = 0; + pack.unit_num_ = matmul->a_batch_ * UP_DIV(matmul->compute_.row_, C12NUM) * matmul->compute_.deep_; + pack.thread_ = MSMIN(matmul->base_.thread_nr_, UP_DIV(pack.unit_num_, kPackAMinUnitNum)); + if (pack.thread_ < 1) { + pack.thread_ = 1; + } + int64_t block_size = pack.unit_num_ / pack.thread_; + int64_t remain_size = pack.unit_num_ - block_size * pack.thread_; + int64_t start = 0; + size_t count = 0; + while (start < pack.unit_num_) { + pack.points_[count++] = start; + start += block_size; + if (remain_size > 0) { + ++start; + --remain_size; + } + } + pack.thread_ = count; + + if (pack.thread_ == 1) { + return MatmulARM64PackMatrixAImplOptPack(&pack, 0, 0, 1); + } + return matmul->base_.env_->ParallelLaunch(matmul->base_.env_->thread_pool_, MatmulARM64PackMatrixAImplOptPack, &pack, + pack.thread_); +} + +bool MatmulARM64CheckThreadCuttingByRow(MatmulStruct *matmul) { + if (matmul->b_batch_ != C1NUM) { + return false; + } + if (matmul->batch_ >= matmul->base_.thread_nr_ || matmul->compute_.col_ == 1) { + matmul->compute_.row_min_unit_ = C4NUM; + return true; + } + return false; +} +void MatmulARM64InitGlobalVariable(MatmulStruct *matmul) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + matmul->pack_opt_ = true; + matmul->compute_.row_tile_ = C12NUM; + matmul->compute_.col_tile_ = C8NUM; + matmul->compute_.col_min_unit_ = C8NUM; + matmul->matrix_a_.need_pack_ = true; + matmul->matrix_b_.need_pack_ = !matmul->weight_is_packed_; + matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2Row12MajorParallel : RowMajor2Col12MajorParallel; + matmul->matrix_b_pack_fun_ = param->b_transpose_ ? RowMajor2Col8MajorParallel : RowMajor2Row8MajorParallel; +} + +int MatmulARM64ParallelRunByBatch(MatmulStruct *matmul, int task_id) { + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + MatmulComputeParam *compute = &matmul->compute_; + ActType act = param->act_type_; + + int start_batch = task_id * compute->batch_stride_; + int end_batch = MSMIN(matmul->batch_, start_batch + compute->batch_stride_); + int func_flag = 0; + if (compute->row_ == 1) { + func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM; + } + + for (int index = start_batch; index < end_batch; ++index) { + const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * compute->row_align_ * compute->deep_; + const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * compute->deep_ * compute->col_align_; + float *c = matmul->output_data_ + index * compute->row_ * compute->col_step_; + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_; + + if (func_flag == 0) { + MatMulOpt(a, b, c, bias, act, compute->deep_, compute->row_, compute->col_step_, compute->col_, OutType_Nhwc); + } else if (func_flag == C1NUM) { + MatVecMulPackFp32(a, b, c, bias, act, compute->deep_, compute->col_step_); + } else { + MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_step_); + } + } + return NNACL_OK; +} + +int MatmulARM64ParallelRunByRow(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + + int start_row = matmul->split_points_[task_id]; + int end_row = matmul->compute_.row_num_; + if (task_id < (matmul->base_.thread_nr_ - 1)) { + end_row = matmul->split_points_[task_id + 1]; + } + int row_num = end_row - start_row; + if (row_num <= 0) { + return NNACL_OK; + } + GemmIsNotPackByRow(matmul->matrix_a_.pack_ptr_, matmul->matrix_b_.pack_ptr_, matmul->output_data_, + matmul->matrix_c_.pack_ptr_, start_row, end_row, matmul->compute_.deep_, param->act_type_); + return NNACL_OK; +} + +int MatmulARM64ParallelRunByOC(MatmulStruct *matmul, int task_id) { + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + MatmulComputeParam *compute = &matmul->compute_; + ActType act = param->act_type_; + + int start_oc = matmul->split_points_[task_id]; + int end_oc = compute->col_step_; + if (task_id < (matmul->base_.thread_nr_ - 1)) { + end_oc = matmul->split_points_[task_id + 1]; + } + int compute_oc = end_oc - start_oc; + if (compute_oc <= 0) { + return NNACL_OK; + } + int func_flag = 0; + if (compute->row_ == 1) { + func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM; + } + int b_stride = func_flag == C2NUM ? start_oc : start_oc * compute->deep_; + + for (int i = 0; i < matmul->batch_; ++i) { + float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * compute->row_align_ * compute->deep_; + float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * compute->deep_ * compute->col_align_ + b_stride; + float *c = matmul->output_data_ + i * compute->row_ * compute->col_step_ + start_oc; + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc; + + if (func_flag == 0) { + MatMulOpt(a, b, c, bias, act, compute->deep_, compute->row_, compute_oc, compute->col_, OutType_Nhwc); + } else if (func_flag == C1NUM) { + MatVecMulPackFp32(a, b, c, bias, act, compute->deep_, compute_oc); + } else { + MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_step_); + } + } + return NNACL_OK; +} + +KernelBase *CreateMatmulARM64() { + MatmulStruct *matmul = (MatmulStruct *)CreateMatmulBase(); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(matmul); + matmul->matmul_type_ = kMatmulFp32Arm64Cpu; + matmul->check_thread_cutting_by_row_ = MatmulARM64CheckThreadCuttingByRow; + matmul->init_global_varibale_ = MatmulARM64InitGlobalVariable; + matmul->parallel_run_by_oc_ = MatmulARM64ParallelRunByOC; + matmul->parallel_run_by_row_ = MatmulARM64ParallelRunByRow; + matmul->parallel_run_by_batch_ = MatmulARM64ParallelRunByBatch; + matmul->pack_matrix_a_impl_opt_ = MatmulARM64PackMatrixAImplOpt; + return (KernelBase *)matmul; +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_arm64.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_arm64.h new file mode 100644 index 0000000000000000000000000000000000000000..9b7b547e7b70993bce7668ee7ce22995ee593606 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_arm64.h @@ -0,0 +1,28 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_MATMUL_ARM64_H_ +#define NNACL_KERNEL_MATMUL_ARM64_H_ + +#ifdef ENABLE_ARM64 +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +KernelBase *CreateMatmulARM64(); + +#endif +#endif // NNACL_KERNEL_MATMUL_ARM64_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_avx.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_avx.c new file mode 100644 index 0000000000000000000000000000000000000000..98cf7fed7558b218b5a5210402a1a16d4e7e4d03 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_avx.c @@ -0,0 +1,169 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_AVX +#include "nnacl/kernel/matmul_avx.h" +#include "nnacl/kernel/matmul_base.h" +#include "nnacl/fp32/matmul_fp32.h" +#include "nnacl/fp32/pack_fp32.h" + +void MatmulAVXInitGlobalVariable(MatmulStruct *matmul) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + matmul->compute_.row_tile_ = C1NUM; + matmul->compute_.col_tile_ = C8NUM; + matmul->compute_.col_min_unit_ = C32NUM; + matmul->out_need_aligned_ = true; + matmul->matrix_b_.need_pack_ = true; + matmul->matrix_a_.need_pack_ = param->a_transpose_; + matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel; + matmul->matrix_b_pack_fun_ = param->b_transpose_ ? RowMajor2Col32MajorParallel : RowMajor2Row32MajorParallel; +} + +int MatmulAVXParallelRunByBatch(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)matmul->base_.param_; + MatmulComputeParam *compute = (MatmulComputeParam *)&matmul->compute_; + + int start_batch = task_id * compute->batch_stride_; + int end_batch = MSMIN(matmul->batch_, start_batch + compute->batch_stride_); + int func_flag = 0; + if (matmul->compute_.row_ == 1) { + func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM; + } + + ActType act = param->act_type_; + for (int index = start_batch; index < end_batch; ++index) { + const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * compute->row_align_ * compute->deep_; + const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * compute->deep_ * compute->col_align_; + float *c = matmul->output_data_ + index * compute->row_ * compute->col_step_; + + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_; + if (func_flag == 0) { + MatMulAvxFp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_align_, compute->row_); + } else if (func_flag == C1NUM) { + MatVecMulAvxFp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_align_); + } else { + MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_step_); + } + } + return NNACL_OK; +} + +int MatmulAVXParallelRunByRow(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + MatmulComputeParam *compute = (MatmulComputeParam *)&matmul->compute_; + + int start_row = matmul->split_points_[task_id]; + int end_row = compute->row_num_; + if (task_id < (matmul->base_.thread_nr_ - 1)) { + end_row = matmul->split_points_[task_id + 1]; + } + int row_num = end_row - start_row; + if (row_num <= 0) { + return NNACL_OK; + } + const float *input = matmul->matrix_a_.pack_ptr_ + start_row * compute->deep_; + float *output = matmul->output_data_ + start_row * compute->col_align_; + if (compute->col_ == 1) { + float bias = 0; + if (matmul->matrix_c_.pack_ptr_ != NULL) { + bias = matmul->matrix_c_.pack_ptr_[0]; + } + matmul->gemm_not_pack_fun_(input, matmul->matrix_b_.pack_ptr_, output, &bias, row_num, compute->deep_, + param->act_type_); + } else { + MatMulAvxFp32(input, matmul->matrix_b_.pack_ptr_, output, matmul->matrix_c_.pack_ptr_, param->act_type_, + compute->deep_, compute->col_align_, compute->col_align_, row_num); + } + return NNACL_OK; +} + +int MatmulAVXParallelRunByOC(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + MatmulComputeParam *compute = (MatmulComputeParam *)&matmul->compute_; + ActType act = param->act_type_; + + int start_oc = matmul->split_points_[task_id]; + int end_oc = compute->col_step_; + if (task_id < (matmul->base_.thread_nr_ - 1)) { + end_oc = matmul->split_points_[task_id + 1]; + } + int compute_oc = end_oc - start_oc; + if (compute_oc <= 0) { + return NNACL_OK; + } + int func_flag = 0; + if (compute->row_ == 1) { + func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM; + } + int b_stride = func_flag == C2NUM ? start_oc : start_oc * compute->deep_; + + for (int i = 0; i < matmul->batch_; ++i) { + float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * compute->row_align_ * compute->deep_; + float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * compute->deep_ * compute->col_align_ + b_stride; + float *c = matmul->output_data_ + i * compute->row_ * compute->col_step_ + start_oc; + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc; + + if (func_flag == 0) { + MatMulAvxFp32(a, b, c, bias, param->act_type_, compute->deep_, compute_oc, compute->col_align_, compute->row_); + } else if (func_flag == C1NUM) { + MatVecMulAvxFp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_align_); + } else { + MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_step_); + } + } + return NNACL_OK; +} + +bool MatmulAVXCheckThreadCuttingByRow(MatmulStruct *matmul) { + if (matmul->b_batch_ != C1NUM) { + return false; + } + if (matmul->compute_.row_num_ < matmul->base_.thread_nr_) { + return false; + } + if (matmul->compute_.col_ == 1) { + matmul->compute_.row_min_unit_ = C4NUM; + return true; + } + if (matmul->compute_.row_ == 1 && !matmul->b_const_ && matmul->compute_.col_ <= C128NUM) { + return false; + } + matmul->compute_.row_min_unit_ = C3NUM; + if (matmul->compute_.col_step_ < C16NUM) { + matmul->compute_.row_min_unit_ = C8NUM; + } else if (matmul->compute_.col_step_ < C24NUM) { + matmul->compute_.row_min_unit_ = C6NUM; + } else if (matmul->compute_.col_step_ < C32NUM) { + matmul->compute_.row_min_unit_ = C4NUM; + } + return MSMIN(matmul->compute_.row_num_ / matmul->compute_.row_min_unit_, matmul->base_.thread_nr_) > + MSMIN(matmul->compute_.col_step_ / matmul->compute_.col_min_unit_, matmul->base_.thread_nr_); +} + +KernelBase *CreateMatmulAVX() { + MatmulStruct *matmul = (MatmulStruct *)CreateMatmulBase(); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(matmul); + matmul->matmul_type_ = kNotImplemented; + matmul->init_global_varibale_ = MatmulAVXInitGlobalVariable; + matmul->parallel_run_by_batch_ = MatmulAVXParallelRunByBatch; + matmul->parallel_run_by_row_ = MatmulAVXParallelRunByRow; + matmul->parallel_run_by_oc_ = MatmulAVXParallelRunByOC; + matmul->check_thread_cutting_by_row_ = MatmulAVXCheckThreadCuttingByRow; + return (KernelBase *)matmul; +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_avx.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_avx.h new file mode 100644 index 0000000000000000000000000000000000000000..7ccdc01a3c69980e7c4c25a5069a27a732787ffa --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_avx.h @@ -0,0 +1,28 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_MATMUL_AVX_H_ +#define NNACL_KERNEL_MATMUL_AVX_H_ + +#ifdef ENABLE_AVX +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +KernelBase *CreateMatmulAVX(); + +#endif +#endif // NNACL_KERNEL_MATMUL_AVX_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_avx512.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_avx512.c new file mode 100644 index 0000000000000000000000000000000000000000..60660932bc87048ae57809f17320cb71e1f01fef --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_avx512.c @@ -0,0 +1,708 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_AVX512 +#include "nnacl/kernel/matmul_avx512.h" +#include "nnacl/kernel/matmul_base.h" +#include "nnacl/fp32/pack_fp32.h" +#include "nnacl/fp32/matmul_fp32.h" +#include "nnacl/fp32/matmul_avx512_fp32.h" +#include "nnacl/fp32/matmul_avx512_mask_fp32.h" + +#define MIN_CALC_COST 24576 /* 1 x 6 x 64x 64 */ + +void MatmulAVX512BatchRowThreadCut(MatmulStruct *matmul) { + // BatchCut + matmul->compute_.batch_stride_ = DOWN_DIV(matmul->batch_, matmul->base_.thread_nr_); + + // RowCut + int row_step = MSMAX(matmul->compute_.row_ / matmul->base_.thread_nr_, matmul->compute_.row_min_unit_); + int row_remaining = matmul->compute_.row_ - row_step * matmul->base_.thread_nr_; + + matmul->row_split_points_size_ = 0; + int row_split_point = 0; + while (row_split_point < matmul->compute_.row_) { + matmul->row_split_points_[matmul->row_split_points_size_++] = row_split_point; + row_split_point += row_step; + if (row_remaining > 0) { + ++row_split_point; + --row_remaining; + } + } + matmul->row_split_points_[matmul->row_split_points_size_] = matmul->compute_.row_; + if (matmul->compute_.batch_stride_ == 0) { + matmul->base_.thread_nr_ = matmul->row_split_points_size_; + } +} + +void MatmulAVX512BatchColThreadCut(MatmulStruct *matmul) { + // BatchCut + matmul->compute_.batch_stride_ = DOWN_DIV(matmul->batch_, matmul->base_.thread_nr_); + + // ColCut + int total_col_unit = UP_DIV(matmul->compute_.col_align_, matmul->compute_.col_min_unit_); + int thread_num_tmp = NNACL_MIN(matmul->base_.thread_nr_, total_col_unit); + int block_col_unit = UP_DIV(total_col_unit, thread_num_tmp); + int split_point = 0; + matmul->col_split_points_size_ = 0; + while (split_point < total_col_unit) { + matmul->col_split_points_[matmul->col_split_points_size_++] = split_point * matmul->compute_.col_min_unit_; + split_point += block_col_unit; + } + if (matmul->compute_.batch_stride_ == 0) { + matmul->base_.thread_nr_ = matmul->col_split_points_size_; + } +} + +void MatmulAVX512BatchColRowSliceThreadCut(MatmulStruct *matmul) { + // BatchCut + matmul->compute_.batch_stride_ = DOWN_DIV(matmul->batch_, matmul->base_.thread_nr_); + + int row_s = 0; + int row_e = matmul->compute_.row_; + int col_s = 0; + int col_e = matmul->compute_.col_; + + // ColCut + int total_col_unit = UP_DIV(matmul->compute_.col_align_, matmul->compute_.col_min_unit_); + matmul->compute_.block_col_unit_ = DOWN_DIV(total_col_unit, matmul->base_.thread_nr_); + matmul->col_split_points_size_ = 0; + matmul->col_split_points_[matmul->col_split_points_size_++] = 0; + if (matmul->compute_.block_col_unit_ > 0) { + int col_split_point = 0; + for (int i = 0; i < matmul->base_.thread_nr_; i++) { + MatmulSlice matmul_slice; + matmul_slice.row_s_ = row_s; + matmul_slice.row_e_ = row_e; + matmul_slice.col_s_ = col_split_point * matmul->compute_.col_min_unit_; + col_split_point += matmul->compute_.block_col_unit_; + col_s = NNACL_MIN(col_split_point * matmul->compute_.col_min_unit_, matmul->compute_.col_step_); + matmul_slice.col_e_ = col_s; + matmul->matmul_slice_set_[i][matmul->matmul_slice_count_[i]++] = matmul_slice; + } + } + if (col_e - col_s <= 0) { + return; + } + + // RowColCut + int row_thread = 0; + int less_col_align = UP_ROUND(col_e - col_s, C16NUM); + bool use_colrowcut_flag = ((less_col_align / C64NUM) * C64NUM) == less_col_align; + bool use_rowcut_flag = matmul->compute_.row_ >= C6NUM * matmul->base_.thread_nr_ || col_e - col_s <= C64NUM; + if (use_rowcut_flag && !use_colrowcut_flag) { + int row_step = MSMAX(matmul->compute_.row_ / matmul->base_.thread_nr_, matmul->compute_.row_min_unit_); + int row_remaining = matmul->compute_.row_ - row_step * matmul->base_.thread_nr_; + int row_split_point = 0; + + for (row_thread = 0; row_thread < matmul->base_.thread_nr_ && row_split_point < matmul->compute_.row_; + row_thread++) { + MatmulSlice matmul_slice; + matmul_slice.row_s_ = row_split_point; + + row_split_point += row_step; + if (row_remaining > 0) { + ++row_split_point; + --row_remaining; + } + + matmul_slice.row_e_ = row_split_point; + matmul_slice.col_s_ = col_s; + matmul_slice.col_e_ = col_e; + matmul->matmul_slice_set_[row_thread][matmul->matmul_slice_count_[row_thread]++] = matmul_slice; + } + } else { + int col_num = UP_DIV(col_e - col_s, C64NUM); + int row_num = NNACL_MIN(UP_DIV(matmul->base_.thread_nr_, col_num), (row_e - row_s)); + int tile_remaining = MSMAX(col_num * row_num - matmul->base_.thread_nr_, 0); + + NNACL_CHECK_ZERO_RETURN(row_num); + int row_step = (row_e - row_s) / row_num; + int row_remaining_tmp = (row_e - row_s) - row_step * row_num; + + int row_step_cut2 = (row_num == 1) ? row_step : (row_e - row_s) / (row_num - 1); + int row_remaining_cut2_tmp = (row_e - row_s) - row_step_cut2 * (row_num - 1); + + MatmulSlice matmul_slice; + for (int c = 0; c < col_num; c++) { + matmul_slice.col_s_ = col_s + c * C64NUM; + matmul_slice.col_e_ = NNACL_MIN(col_s + (c + 1) * C64NUM, matmul->compute_.col_); + int row_split_point = 0; + int row_remaining = row_remaining_tmp; + int row_remaining_cut2 = row_remaining_cut2_tmp; + if (c < col_num - tile_remaining) { + for (int r = 0; r < row_num; r++) { + matmul_slice.row_s_ = row_split_point; + row_split_point += row_step; + if (row_remaining > 0) { + ++row_split_point; + --row_remaining; + } + matmul_slice.row_e_ = NNACL_MIN(row_split_point, matmul->compute_.row_); + matmul->matmul_slice_set_[row_thread][matmul->matmul_slice_count_[row_thread]++] = matmul_slice; + row_thread++; + } + } else { + for (int r = 0; r < row_num - 1; r++) { + matmul_slice.row_s_ = row_split_point; + row_split_point += row_step_cut2; + if (row_remaining_cut2 > 0) { + ++row_split_point; + --row_remaining_cut2; + } + matmul_slice.row_e_ = NNACL_MIN(row_split_point, matmul->compute_.row_); + matmul->matmul_slice_set_[row_thread][matmul->matmul_slice_count_[row_thread]++] = matmul_slice; + row_thread++; + } + } + } + } + if ((matmul->compute_.batch_stride_ == 0) && (matmul->compute_.block_col_unit_ == 0)) { + matmul->base_.thread_nr_ = row_thread; + } +} + +void MatmulAVX512GetThreadCuttingPolicy(MatmulStruct *matmul) { + size_t total_cost = (size_t)(matmul->batch_) * (size_t)(matmul->compute_.row_) * (size_t)(matmul->compute_.col_) * + (size_t)(matmul->compute_.deep_); + + // Thread Update + matmul->base_.thread_nr_ = MSMAX(NNACL_MIN((int)(total_cost / MIN_CALC_COST), matmul->base_.thread_nr_), C1NUM); + + if (matmul->compute_.deep_ < C128NUM) { + return MatmulBaseGetThreadCuttingPolicy(matmul); + } + + for (int i = 0; i < SPLIT_COUNT; i++) { + matmul->matmul_slice_count_[i] = 0; + } + if (matmul->compute_.col_ == 1 && !matmul->a_const_) { + MatmulAVX512BatchRowThreadCut(matmul); + if (matmul->compute_.deep_ == 1) { + matmul->gemm_not_pack_fun_ = GemmIsNotPack; + } else { + matmul->gemm_not_pack_fun_ = GemmIsNotPackOptimize; + } + matmul->parallel_run_ = matmul->parallel_run_by_gepdot_; + } else if (matmul->compute_.row_ == 1 && !matmul->b_const_ && matmul->compute_.col_ <= C128NUM) { + MatmulAVX512BatchColThreadCut(matmul); + if (matmul->compute_.deep_ == 1) { + matmul->parallel_run_ = matmul->parallel_run_by_row1_deep1_gepdot_; + if (matmul->matrix_c_.pack_ptr_ != NULL) { + matmul->gemm_not_pack_fun_ = Row1Deep1GemmIsNotPack; + } else { + matmul->gemm_not_pack_fun_ = Row1Deep1NoBiasGemmIsNotPack; + } + return; + } + matmul->parallel_run_ = matmul->parallel_run_by_gepm_; + } else { + MatmulAVX512BatchColRowSliceThreadCut(matmul); + matmul->parallel_run_ = matmul->parallel_run_by_batch_col_row_gemm_; + } + return; +} + +bool MatmulAVX512CheckThreadCuttingByRow(MatmulStruct *matmul) { + if (matmul->b_batch_ != C1NUM) { + return false; + } + if (matmul->compute_.row_num_ < matmul->base_.thread_nr_) { + return false; + } + if (matmul->compute_.col_ == 1) { + matmul->compute_.row_min_unit_ = C8NUM; + return true; + } + if (matmul->compute_.row_ == 1 && !matmul->b_const_ && matmul->compute_.col_ <= C128NUM) { + return false; + } + matmul->compute_.row_min_unit_ = C6NUM; + if (matmul->compute_.col_step_ < C48NUM) { + matmul->compute_.row_min_unit_ = C12NUM; + } else if (matmul->compute_.col_step_ < C64NUM) { + matmul->compute_.row_min_unit_ = C8NUM; + } + return NNACL_MIN(matmul->compute_.row_num_ / matmul->compute_.row_min_unit_, matmul->base_.thread_nr_) > + NNACL_MIN(matmul->compute_.col_step_ / matmul->compute_.col_min_unit_, matmul->base_.thread_nr_); +} +void MatmulAVX512InitGlobalVariable(MatmulStruct *matmul) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel; + matmul->matrix_b_pack_fun_ = param->b_transpose_ ? RowMajor2Col64MajorParallel : RowMajor2Row64MajorParallel; + matmul->matrix_a_.need_pack_ = param->a_transpose_; + matmul->matrix_b_.need_pack_ = true; + matmul->compute_.row_tile_ = C1NUM; + matmul->compute_.col_tile_ = C16NUM; + matmul->compute_.col_min_unit_ = C64NUM; + + if (matmul->compute_.row_ == 1) { + if (!matmul->b_const_ && matmul->compute_.col_ <= C128NUM) { + matmul->out_need_aligned_ = true; + } + } else if (matmul->compute_.col_ == 1) { + matmul->out_need_aligned_ = true; + } else { + matmul->out_need_aligned_ = false; + } + + if (matmul->compute_.deep_ >= C128NUM) { + matmul->out_need_aligned_ = false; + } +} +int MatmulAVX512InitParameter(MatmulStruct *matmul) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + MatmulComputeParam *compute = &matmul->compute_; + + if (compute->deep_ < C128NUM) { + return MatmulBaseInitParameter(matmul); + } + + matmul->init_global_varibale_(matmul); + if (compute->col_ == 1 && !matmul->a_const_) { + matmul->out_need_aligned_ = false; + compute->row_tile_ = 1; + compute->col_tile_ = 1; + matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel; + matmul->matrix_b_pack_fun_ = param->b_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel; + matmul->matrix_a_.need_pack_ = param->a_transpose_ && compute->row_ != 1; + matmul->matrix_b_.need_pack_ = false; + matmul->pack_opt_ = false; + } else if (compute->row_ == 1 && !matmul->b_const_ && compute->col_ <= C128NUM) { + matmul->out_need_aligned_ = false; + compute->row_tile_ = 1; + compute->col_tile_ = 1; + matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel; + matmul->matrix_b_pack_fun_ = param->b_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel; + matmul->matrix_a_.need_pack_ = false; + matmul->matrix_b_.need_pack_ = param->b_transpose_; + matmul->pack_opt_ = false; + } + compute->row_align_ = UP_ROUND(compute->row_, compute->row_tile_); + compute->col_align_ = UP_ROUND(compute->col_, compute->col_tile_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_, compute->row_align_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_ * compute->row_align_, compute->deep_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_, compute->col_align_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_ * compute->col_align_, compute->deep_, NNACL_ERR); + int a_pack_size = matmul->a_batch_ * compute->row_align_ * compute->deep_; + int b_pack_size = matmul->b_batch_ * compute->col_align_ * compute->deep_; + if ((matmul->matrix_a_.has_packed_ && matmul->matrix_a_.pack_size_ != a_pack_size) || + (matmul->matrix_b_.has_packed_ && matmul->matrix_b_.pack_size_ != b_pack_size)) { + return NNACL_ERR; + } + matmul->matrix_a_.pack_size_ = a_pack_size; + matmul->matrix_b_.pack_size_ = b_pack_size; + compute->row_align_ = UP_ROUND(compute->row_, compute->row_tile_); + matmul->out_need_aligned_ = (matmul->out_need_aligned_ && ((compute->col_ % compute->col_tile_) != 0)); + compute->col_step_ = matmul->out_need_aligned_ ? compute->col_align_ : compute->col_; + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(matmul->a_batch_, compute->row_), NNACL_ERR); + compute->row_num_ = matmul->a_batch_ * compute->row_; + return NNACL_OK; +} + +int MatmulAVX512ParallelRunByRow(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + + int start_row = matmul->split_points_[task_id]; + int end_row = matmul->compute_.row_num_; + if (task_id < (matmul->base_.thread_nr_ - 1)) { + end_row = matmul->split_points_[task_id + 1]; + } + int row_num = end_row - start_row; + if (row_num <= 0) { + return NNACL_OK; + } + const float *input = matmul->matrix_a_.pack_ptr_ + start_row * matmul->compute_.deep_; + float *output = matmul->output_data_ + start_row * matmul->compute_.col_step_; + if (matmul->compute_.col_ == 1) { + float bias = 0; + if (matmul->matrix_c_.pack_ptr_ != NULL) { + bias = matmul->matrix_c_.pack_ptr_[0]; + } + matmul->gemm_not_pack_fun_(input, matmul->matrix_b_.pack_ptr_, output, &bias, row_num, matmul->compute_.deep_, + param->act_type_); + } else { + if (matmul->out_need_aligned_) { + MatMulAvx512Fp32(input, matmul->matrix_b_.pack_ptr_, output, matmul->matrix_c_.pack_ptr_, param->act_type_, + matmul->compute_.deep_, matmul->compute_.col_align_, matmul->compute_.col_align_, row_num); + } else { + MatMulMaskAvx512Fp32(input, matmul->matrix_b_.pack_ptr_, output, matmul->matrix_c_.pack_ptr_, param->act_type_, + matmul->compute_.deep_, matmul->compute_.col_, matmul->compute_.col_, row_num); + } + } + return NNACL_OK; +} + +int MatmulAVX512ParallelRunByOC(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + MatmulComputeParam *compute = &matmul->compute_; + ActType act = param->act_type_; + + int start_oc = matmul->split_points_[task_id]; + int end_oc = compute->col_step_; + if (task_id < (matmul->base_.thread_nr_ - 1)) { + end_oc = matmul->split_points_[task_id + 1]; + } + int compute_oc = end_oc - start_oc; + if (compute_oc <= 0) { + return NNACL_OK; + } + int func_flag = 0; + if (matmul->compute_.row_ == 1) { + func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM; + } + int b_stride = func_flag == C2NUM ? start_oc : start_oc * compute->deep_; + for (int i = 0; i < matmul->batch_; ++i) { + float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * compute->row_align_ * compute->deep_; + float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * compute->deep_ * compute->col_align_ + b_stride; + float *c = matmul->output_data_ + i * compute->row_ * compute->col_step_ + start_oc; + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc; + + if (func_flag == 0) { + if (matmul->out_need_aligned_) { + MatMulAvx512Fp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_align_, compute->row_); + } else { + MatMulMaskAvx512Fp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_, compute->row_); + } + } else if (func_flag == C1NUM) { + if (matmul->out_need_aligned_) { + MatVecMulAvx512Fp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_align_); + } else { + MatVecMulMaskAvx512Fp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_); + } + } else { + MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_step_); + } + } + + return NNACL_OK; +} + +int MatmulAVX512ParallelRunByBatch(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + MatmulComputeParam *compute = &matmul->compute_; + ActType act = param->act_type_; + + int start_batch = task_id * compute->batch_stride_; + int end_batch = NNACL_MIN(matmul->batch_, start_batch + compute->batch_stride_); + int func_flag = 0; + if (compute->row_ == 1) { + func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM; + } + + for (int index = start_batch; index < end_batch; ++index) { + const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * compute->row_align_ * compute->deep_; + const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * compute->deep_ * compute->col_align_; + float *c = matmul->output_data_ + index * compute->row_ * compute->col_step_; + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_; + + if (func_flag == 0) { + if (matmul->out_need_aligned_) { + MatMulAvx512Fp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_align_, compute->row_); + } else { + MatMulMaskAvx512Fp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_, compute->row_); + } + } else if (func_flag == C1NUM) { + if (matmul->out_need_aligned_) { + MatVecMulAvx512Fp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_align_); + } else { + MatVecMulMaskAvx512Fp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_); + } + } else { + MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_step_); + } + } + return NNACL_OK; +} + +int MatmulAVX512ParallelRunByGEPM(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + + int a_plane_size = matmul->compute_.row_align_ * matmul->compute_.deep_; + int b_plane_size = matmul->compute_.deep_ * matmul->compute_.col_align_; + int c_plane_size = matmul->compute_.row_ * matmul->compute_.col_step_; + int matrix_col = matmul->compute_.col_step_; + int matrix_deep = matmul->compute_.deep_; + + // by BatchCut + int start_batch = task_id * matmul->compute_.batch_stride_; + int end_batch = NNACL_MIN(matmul->batch_, start_batch + matmul->compute_.batch_stride_); + + for (int index = start_batch; index < end_batch; ++index) { + const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * a_plane_size; + const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * b_plane_size; + float *c = matmul->output_data_ + index * c_plane_size; + + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_; + MatVecMulNoPackFp32(a, b, c, bias, param->act_type_, matrix_deep, matrix_col, matrix_col); + } + + // by ColCut + int col_split_points_size = matmul->col_split_points_size_; + if (task_id < col_split_points_size) { + int start_oc = matmul->col_split_points_[task_id]; + int end_oc = matrix_col; + if (task_id < (col_split_points_size - 1)) { + end_oc = matmul->col_split_points_[task_id + 1]; + } + int compute_oc = end_oc - start_oc; + if (compute_oc <= 0) { + return NNACL_OK; + } + + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc; + for (int i = matmul->base_.thread_nr_ * matmul->compute_.batch_stride_; i < matmul->batch_; ++i) { + float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * a_plane_size; + float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * b_plane_size + start_oc; + float *c = matmul->output_data_ + i * c_plane_size + start_oc; + MatVecMulNoPackFp32(a, b, c, bias, param->act_type_, matrix_deep, compute_oc, matrix_col); + } + } + return NNACL_OK; +} +int MatmulAVX512ParallelRunByGEMM(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + + int a_plane_size = matmul->compute_.row_align_ * matmul->compute_.deep_; + int b_plane_size = matmul->compute_.deep_ * matmul->compute_.col_align_; + int c_plane_size = matmul->compute_.row_ * matmul->compute_.col_step_; + int matrix_row = matmul->compute_.row_; + int matrix_col = matmul->compute_.col_step_; + int matrix_deep = matmul->compute_.deep_; + + // by BatchCut + int start_batch = task_id * matmul->compute_.batch_stride_; + int end_batch = start_batch + matmul->compute_.batch_stride_; + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_; + for (int index = start_batch; index < end_batch; ++index) { + const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * a_plane_size; + const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * b_plane_size; + float *c = matmul->output_data_ + index * c_plane_size; + MatMulMaskAvx512Fp32(a, b, c, bias, param->act_type_, matrix_deep, matrix_col, matrix_col, matrix_row); + } + + // by ColCut + int col_split_points_size = matmul->col_split_points_size_; + if (task_id < col_split_points_size) { + int start_oc = matmul->col_split_points_[task_id]; + int end_oc = matmul->col_split_points_[task_id + 1]; + int compute_oc = end_oc - start_oc; + + bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc; + if (compute_oc > 0) { + for (int i = matmul->base_.thread_nr_ * matmul->compute_.batch_stride_; i < matmul->batch_; ++i) { + float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * a_plane_size; + float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * b_plane_size + start_oc * matrix_deep; + float *c = matmul->output_data_ + i * c_plane_size + start_oc; + MatMulMaskAvx512Fp32(a, b, c, bias, param->act_type_, matrix_deep, compute_oc, matrix_col, matrix_row); + } + } + } + + // by RowCut + int start_oc = matmul->col_split_points_[col_split_points_size]; + int end_oc = matrix_col; + int compute_oc = end_oc - start_oc; + if (compute_oc <= 0) { + return NNACL_OK; + } + + int row_split_points_size = matmul->row_split_points_size_; + if (task_id >= row_split_points_size) { + return NNACL_OK; + } + int start_row = matmul->row_split_points_[task_id]; + int end_row = matmul->row_split_points_[task_id + 1]; + int row_num = end_row - start_row; + if (row_num <= 0) { + return NNACL_OK; + } + + bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc; + for (int i = matmul->base_.thread_nr_ * matmul->compute_.batch_stride_; i < matmul->batch_; ++i) { + float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * a_plane_size + start_row * matrix_deep; + float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * b_plane_size + start_oc * matrix_deep; + float *c = matmul->output_data_ + i * c_plane_size + start_row * matrix_col + start_oc; + MatMulMaskAvx512Fp32(a, b, c, bias, param->act_type_, matrix_deep, compute_oc, matrix_col, row_num); + } + + return NNACL_OK; +} + +int MatmulAVX512ParallelRunByGEPDOT(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + MatmulComputeParam *compute = &matmul->compute_; + + // by BatchCut + int start_batch = task_id * compute->batch_stride_; + int end_batch = start_batch + compute->batch_stride_; + float bias = 0; + if (matmul->matrix_c_.pack_ptr_ != NULL) { + bias = matmul->matrix_c_.pack_ptr_[0]; + } + int a_stride = compute->row_ * compute->deep_; + int b_stride = compute->deep_ * compute->col_; + + for (int index = start_batch; index < end_batch; ++index) { + const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * a_stride; + const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * b_stride; + float *c = matmul->output_data_ + index * compute->row_ * compute->col_; + matmul->gemm_not_pack_fun_(a, b, c, &bias, compute->row_, compute->deep_, param->act_type_); + } + + // by RowCut + int split_points_size = matmul->row_split_points_size_; + if (task_id >= split_points_size) { + return NNACL_OK; + } + for (int index = matmul->base_.thread_nr_ * compute->batch_stride_; index < matmul->batch_; ++index) { + int start_row = matmul->row_split_points_[task_id]; + int end_row = matmul->row_split_points_[task_id + 1]; + int row_num = end_row - start_row; + if (row_num <= 0) { + continue; + } + const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * a_stride + start_row * compute->deep_; + const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * b_stride; + float *c = matmul->output_data_ + index * compute->row_ * compute->col_ + start_row * compute->col_step_; + matmul->gemm_not_pack_fun_(a, b, c, &bias, row_num, compute->deep_, param->act_type_); + } + + return NNACL_OK; +} + +int MatmulAVX512ParallelRunByRow1Deep1GEPDOT(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + + int a_plane_size = matmul->compute_.row_align_ * matmul->compute_.deep_; + int b_plane_size = matmul->compute_.deep_ * matmul->compute_.col_align_; + int c_plane_size = matmul->compute_.row_ * matmul->compute_.col_step_; + int matrix_col = matmul->compute_.col_step_; + int matrix_deep = matmul->compute_.deep_; + + // by BatchCut + int start_batch = task_id * matmul->compute_.batch_stride_; + int end_batch = NNACL_MIN(matmul->batch_, start_batch + matmul->compute_.batch_stride_); + + for (int index = start_batch; index < end_batch; ++index) { + const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * a_plane_size; + const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * b_plane_size; + float *c = matmul->output_data_ + index * c_plane_size; + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_; + matmul->gemm_not_pack_fun_(a, b, c, bias, matrix_col, matrix_deep, param->act_type_); + } + + // by ColCut + int col_split_points_size = matmul->col_split_points_size_; + if (task_id < col_split_points_size) { + int start_oc = matmul->col_split_points_[task_id]; + int end_oc = matrix_col; + if (task_id < (col_split_points_size - 1)) { + end_oc = matmul->col_split_points_[task_id + 1]; + } + int compute_oc = end_oc - start_oc; + if (compute_oc <= 0) { + return NNACL_OK; + } + + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc; + for (int i = matmul->base_.thread_nr_ * matmul->compute_.batch_stride_; i < matmul->batch_; ++i) { + float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * a_plane_size; + float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * b_plane_size + start_oc; + float *c = matmul->output_data_ + i * c_plane_size + start_oc; + matmul->gemm_not_pack_fun_(a, b, c, bias, compute_oc, matrix_deep, param->act_type_); + } + } + return NNACL_OK; +} + +int MatmulAVX512ParallelRunByBatchColRowGEMM(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + + int a_plane_size = matmul->compute_.row_align_ * matmul->compute_.deep_; + int b_plane_size = matmul->compute_.deep_ * matmul->compute_.col_align_; + int c_plane_size = matmul->compute_.row_ * matmul->compute_.col_step_; + int matrix_row = matmul->compute_.row_; + int matrix_col = matmul->compute_.col_step_; + int matrix_deep = matmul->compute_.deep_; + + // by BatchCut + int start_batch = task_id * matmul->compute_.batch_stride_; + int end_batch = start_batch + matmul->compute_.batch_stride_; + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_; + for (int index = start_batch; index < end_batch; ++index) { + const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * a_plane_size; + const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * b_plane_size; + float *c = matmul->output_data_ + index * c_plane_size; + MatMulMaskAvx512Fp32(a, b, c, bias, param->act_type_, matrix_deep, matrix_col, matrix_col, matrix_row); + } + + MatmulSlice *matmul_slices = matmul->matmul_slice_set_[task_id]; + int slice_count = matmul->matmul_slice_count_[task_id]; + for (int s = 0; s < slice_count; s++) { + MatmulSlice matmul_slice = matmul_slices[s]; + + int start_oc = matmul_slice.col_s_; + int end_oc = matmul_slice.col_e_; + int compute_oc = end_oc - start_oc; + if (compute_oc <= 0) { + return NNACL_OK; + } + + int start_row = matmul_slice.row_s_; + int end_row = matmul_slice.row_e_; + int row_num = end_row - start_row; + if (row_num <= 0) { + return NNACL_OK; + } + + bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc; + for (int i = matmul->base_.thread_nr_ * matmul->compute_.batch_stride_; i < matmul->batch_; ++i) { + float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * a_plane_size + start_row * matrix_deep; + float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * b_plane_size + start_oc * matrix_deep; + float *c = matmul->output_data_ + i * c_plane_size + start_row * matrix_col + start_oc; + MatMulMaskAvx512Fp32(a, b, c, bias, param->act_type_, matrix_deep, compute_oc, matrix_col, row_num); + } + } + return NNACL_OK; +} + +KernelBase *CreateMatmulAVX512() { + MatmulStruct *matmul = (MatmulStruct *)CreateMatmulBase(); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(matmul); + matmul->matmul_type_ = kNotImplemented; + matmul->check_thread_cutting_by_row_ = MatmulAVX512CheckThreadCuttingByRow; + matmul->get_thread_cutting_policy_ = MatmulAVX512GetThreadCuttingPolicy; + matmul->init_parameter_ = MatmulAVX512InitParameter; + matmul->init_global_varibale_ = MatmulAVX512InitGlobalVariable; + matmul->parallel_run_by_oc_ = MatmulAVX512ParallelRunByOC; + matmul->parallel_run_by_row_ = MatmulAVX512ParallelRunByRow; + matmul->parallel_run_by_batch_ = MatmulAVX512ParallelRunByBatch; + matmul->parallel_run_by_gemm_ = MatmulAVX512ParallelRunByGEMM; + matmul->parallel_run_by_gepm_ = MatmulAVX512ParallelRunByGEPM; + matmul->parallel_run_by_gepdot_ = MatmulAVX512ParallelRunByGEPDOT; + matmul->parallel_run_by_batch_col_row_gemm_ = MatmulAVX512ParallelRunByBatchColRowGEMM; + matmul->parallel_run_by_row1_deep1_gepdot_ = MatmulAVX512ParallelRunByRow1Deep1GEPDOT; + return (KernelBase *)matmul; +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_avx512.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_avx512.h new file mode 100644 index 0000000000000000000000000000000000000000..4ce212b4f88bb09af25afde17b911b1af70aef15 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_avx512.h @@ -0,0 +1,27 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_MATMUL_AVX512_H_ +#define NNACL_KERNEL_MATMUL_AVX512_H_ +#ifdef ENABLE_AVX512 +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +KernelBase *CreateMatmulAVX512(); + +#endif +#endif // NNACL_KERNEL_MATMUL_AVX512_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_base.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_base.c new file mode 100644 index 0000000000000000000000000000000000000000..672d6d5b3fceb133d22c21e2ddeeee323c57a8fc --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_base.c @@ -0,0 +1,676 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/matmul_base.h" +#include "nnacl/fp32/pack_fp32.h" +#include "nnacl/fp32/matmul_fp32.h" +#include "nnacl/tensor_c_utils.h" +#include "nnacl/op_base.h" + +#define kNumDeepThreshold 512 + +int MatmulFp32Run(void *cdata, int task_id, float l, float r) { + NNACL_CHECK_NULL_RETURN_ERR(cdata); + MatmulStruct *matmul = (MatmulStruct *)cdata; + return matmul->parallel_run_(matmul, task_id); +} + +void MatmulBaseFreeBatchOffset(MatmulStruct *matmul) { + if (matmul->a_offset_ != NULL) { + free(matmul->a_offset_); + matmul->a_offset_ = NULL; + } + if (matmul->b_offset_ != NULL) { + free(matmul->b_offset_); + matmul->b_offset_ = NULL; + } +} + +int MatmulBaseMallocBatchOffset(MatmulStruct *matmul) { + matmul->a_offset_ = malloc(matmul->batch_ * sizeof(int)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(matmul->a_offset_); + memset(matmul->a_offset_, 0, matmul->batch_ * sizeof(int)); + + matmul->b_offset_ = malloc(matmul->batch_ * sizeof(int)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(matmul->b_offset_); + memset(matmul->b_offset_, 0, matmul->batch_ * sizeof(int)); + return NNACL_OK; +} + +int MatmulBasePackMatrixBParallelRunByBatch(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + MatmulComputeParam *compute = &matmul->compute_; + + int start = task_id * compute->pack_b_stride_; + if (param->b_transpose_) { + int end = NNACL_MIN(matmul->compute_.col_, start + compute->pack_b_stride_); + matmul->matrix_b_pack_fun_(matmul->pack_b_src_, matmul->pack_b_dst_, compute->col_, compute->deep_, start, end); + } else { + int end = NNACL_MIN(matmul->compute_.deep_, start + compute->pack_b_stride_); + matmul->matrix_b_pack_fun_(matmul->pack_b_src_, matmul->pack_b_dst_, compute->deep_, compute->col_, start, end); + } + return NNACL_OK; +} + +int MatmulFp32PackMatrixBRun(void *cdata, int task_id, float l, float r) { + NNACL_CHECK_NULL_RETURN_ERR(cdata); + MatmulStruct *matmul = (MatmulStruct *)cdata; + return MatmulBasePackMatrixBParallelRunByBatch(matmul, task_id); +} + +bool MatmulBaseCheckRowOptimalConditions(MatmulStruct *matmul) { + return matmul->compute_.row_ == 1 && + !(matmul->support_mul_batch_cut_by_row_ && (matmul->a_batch_ > 1 && matmul->b_batch_ == 1)); +} + +int MatmulBaseInitParameter(MatmulStruct *matmul) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + MatmulComputeParam *compute = &matmul->compute_; + + matmul->init_global_varibale_(matmul); + if (MatmulBaseCheckRowOptimalConditions(matmul)) { + compute->row_tile_ = 1; + matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel; + matmul->matrix_a_.need_pack_ = false; + matmul->pack_opt_ = false; + if (!matmul->b_const_ && compute->col_ <= C128NUM) { + compute->col_tile_ = 1; + matmul->out_need_aligned_ = false; + matmul->matrix_b_pack_fun_ = param->b_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel; + matmul->matrix_b_.need_pack_ = param->b_transpose_; + } + } + if (compute->col_ == 1 && !matmul->a_const_) { + matmul->out_need_aligned_ = false; + compute->row_tile_ = 1; + compute->col_tile_ = 1; + matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel; + matmul->matrix_b_pack_fun_ = param->b_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel; + matmul->matrix_a_.need_pack_ = param->a_transpose_ && compute->row_ != 1; + matmul->matrix_b_.need_pack_ = false; + matmul->pack_opt_ = false; + } + compute->row_align_ = UP_ROUND(compute->row_, compute->row_tile_); + compute->col_align_ = UP_ROUND(compute->col_, compute->col_tile_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_, compute->row_align_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_ * compute->row_align_, compute->deep_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_, compute->col_align_, NNACL_ERR); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_ * compute->col_align_, compute->deep_, NNACL_ERR); + int a_pack_size = matmul->a_batch_ * compute->row_align_ * compute->deep_; + int b_pack_size = matmul->b_batch_ * compute->col_align_ * compute->deep_; + if ((matmul->matrix_a_.has_packed_ && matmul->matrix_a_.pack_size_ != a_pack_size) || + (matmul->matrix_b_.has_packed_ && matmul->matrix_b_.pack_size_ != b_pack_size)) { + return NNACL_ERR; + } + matmul->matrix_a_.pack_size_ = a_pack_size; + matmul->matrix_b_.pack_size_ = b_pack_size; + compute->row_align_ = UP_ROUND(compute->row_, compute->row_tile_); + matmul->out_need_aligned_ = (matmul->out_need_aligned_ && ((compute->col_ % compute->col_tile_) != 0)); + compute->col_step_ = matmul->out_need_aligned_ ? compute->col_align_ : compute->col_; + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(matmul->a_batch_, compute->row_), NNACL_ERR); + compute->row_num_ = matmul->a_batch_ * compute->row_; + return NNACL_OK; +} + +int MatmulBasePackMatrixAImplOpt(MatmulStruct *matmul) { return NNACL_ERR; } + +int MatmulBasePackMatrixAImpl(MatmulStruct *matmul) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + float *src_ptr = (matmul->matrix_a_.origin_ptr_ != NULL) ? (matmul->matrix_a_.origin_ptr_) + : (float *)(matmul->base_.in_[FIRST_INPUT]->data_); + NNACL_CHECK_TRUE_RET(src_ptr != NULL, NNACL_ERR); + NNACL_CHECK_TRUE_RET(matmul->matrix_a_.pack_ptr_ != NULL, NNACL_ERR); + NNACL_CHECK_TRUE_RET(matmul->matrix_a_pack_fun_ != NULL, NNACL_ERR); + for (int i = 0; i < matmul->a_batch_; i++) { + const float *src = src_ptr + i * matmul->compute_.deep_ * matmul->compute_.row_; + float *dst = matmul->matrix_a_.pack_ptr_ + i * matmul->compute_.deep_ * matmul->compute_.row_align_; + if (param->a_transpose_) { + matmul->matrix_a_pack_fun_(src, dst, matmul->compute_.deep_, matmul->compute_.row_, 0, matmul->compute_.deep_); + } else { + matmul->matrix_a_pack_fun_(src, dst, matmul->compute_.row_, matmul->compute_.deep_, 0, matmul->compute_.row_); + } + } + return NNACL_OK; +} + +int MatmulBasePackMatrixBImpl(MatmulStruct *matmul) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + + float *src_ptr = matmul->matrix_b_.origin_ptr_ != NULL ? matmul->matrix_b_.origin_ptr_ + : (float *)matmul->base_.in_[SECOND_INPUT]->data_; + NNACL_CHECK_TRUE_RET(src_ptr != NULL, NNACL_ERR); + NNACL_CHECK_TRUE_RET(matmul->matrix_b_.pack_ptr_ != NULL, NNACL_ERR); + NNACL_CHECK_TRUE_RET(matmul->matrix_b_pack_fun_ != NULL, NNACL_ERR); + + for (int i = 0; i < matmul->b_batch_; i++) { + if (param->b_transpose_) { + matmul->compute_.pack_b_stride_ = UP_DIV(matmul->compute_.col_, matmul->base_.thread_nr_); + } else { + matmul->compute_.pack_b_stride_ = UP_DIV(matmul->compute_.deep_, matmul->base_.thread_nr_); + } + matmul->pack_b_src_ = src_ptr + i * matmul->compute_.deep_ * matmul->compute_.col_; + matmul->pack_b_dst_ = matmul->matrix_b_.pack_ptr_ + i * matmul->compute_.deep_ * matmul->compute_.col_align_; + int ret = matmul->base_.env_->ParallelLaunch(matmul->base_.env_->thread_pool_, MatmulFp32PackMatrixBRun, matmul, + matmul->base_.thread_nr_); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + } + return NNACL_OK; +} + +int MatmulBasePackMatrixA(MatmulStruct *matmul) { + if (!matmul->a_const_) { + if (!matmul->matrix_a_.need_pack_) { + matmul->matrix_a_.pack_ptr_ = (float *)matmul->base_.in_[0]->data_; + return NNACL_OK; + } + if (matmul->base_.train_session_) { + matmul->matrix_a_.pack_ptr_ = (float *)(matmul->base_.workspace_); + } else { + matmul->matrix_a_.pack_ptr_ = (float *)(matmul->base_.env_->Alloc(matmul->base_.env_->allocator_, + matmul->matrix_a_.pack_size_ * sizeof(float))); + } + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(matmul->matrix_a_.pack_ptr_); + } else { + bool is_packed = false; + void *data = NULL; + size_t data_size = (size_t)(matmul->matrix_a_.pack_size_) * sizeof(float); + if (matmul->is_sharing_pack_) { + TensorC *a_matrix = matmul->base_.in_[FIRST_INPUT]; + data = matmul->get_sharing_weight_(matmul->shaing_manager_, a_matrix->data_, data_size, &is_packed); + } else { + data = malloc(data_size); + } + matmul->matrix_a_.pack_ptr_ = (float *)data; + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(matmul->matrix_a_.pack_ptr_); + if (is_packed) { + return NNACL_OK; + } + } + if (matmul->pack_opt_) { + /* valid in arm64 */ + return matmul->pack_matrix_a_impl_opt_(matmul); + } + return matmul->pack_matrix_a_impl_(matmul); +} + +int MatmulBasePackMatrixB(MatmulStruct *matmul) { + if (!matmul->b_const_) { + if (!matmul->matrix_b_.need_pack_) { + matmul->matrix_b_.pack_ptr_ = (float *)matmul->base_.in_[SECOND_INPUT]->data_; + return NNACL_OK; + } + if (matmul->base_.train_session_) { + matmul->matrix_b_.pack_ptr_ = (float *)(matmul->base_.workspace_) + matmul->matrix_a_.pack_size_; + } else { + matmul->matrix_b_.pack_ptr_ = (float *)(matmul->base_.env_->Alloc(matmul->base_.env_->allocator_, + matmul->matrix_b_.pack_size_ * sizeof(float))); + } + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(matmul->matrix_b_.pack_ptr_); + } else { + if (!matmul->matrix_b_.need_pack_ && matmul->weight_is_packed_) { + matmul->matrix_b_.pack_ptr_ = (float *)matmul->base_.in_[SECOND_INPUT]->data_; + return NNACL_OK; + } + bool is_packed = false; + void *data = NULL; + size_t data_size = (size_t)(matmul->matrix_b_.pack_size_) * sizeof(float); + if (matmul->is_sharing_pack_) { + TensorC *b_matrix = matmul->base_.in_[SECOND_INPUT]; + data = matmul->get_sharing_weight_(matmul->shaing_manager_, b_matrix->data_, data_size, &is_packed); + } else { + data = malloc(data_size); + } + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(data); + matmul->matrix_b_.pack_ptr_ = (float *)data; + if (is_packed) { + return NNACL_OK; + } + } + return matmul->pack_matrix_b_impl_(matmul); +} + +int MatmulBaseBackupConstMatrix(MatmulStruct *matmul, MatrixInfo *matrix_info, int index) { + NNACL_CHECK_TRUE_RET(index < (int)matmul->base_.in_size_, NNACL_ERR); + size_t backup_size = (size_t)NNACLGetElementNum(matmul->base_.in_[index]) * sizeof(float); + NNACL_CHECK_TRUE_RET(backup_size > 0, NNACL_ERR); + matrix_info->origin_ptr_ = (float *)(matmul->base_.env_->Alloc(matmul->base_.env_->allocator_, backup_size)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(matrix_info->origin_ptr_); + void *src_ptr = matmul->base_.in_[index]->data_; + NNACL_CHECK_NULL_RETURN_ERR(src_ptr); + (void)memcpy(matrix_info->origin_ptr_, src_ptr, backup_size); + matrix_info->origin_need_free_ = true; + return NNACL_OK; +} + +int MatmulBaseParallelRunByRow(MatmulStruct *matmul, int task_id) { return NNACL_ERR; } + +int MatmulBaseParallelRunByBatch(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + MatmulComputeParam *compute = &matmul->compute_; + + int start_batch = task_id * compute->batch_stride_; + int end_batch = MSMIN(matmul->batch_, start_batch + compute->batch_stride_); + int func_flag = 0; + if (compute->row_ == 1) { + func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM; + } + + for (int index = start_batch; index < end_batch; ++index) { + const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * compute->row_align_ * compute->deep_; + const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * compute->deep_ * compute->col_align_; + float *c = matmul->output_data_ + index * compute->row_ * compute->col_step_; + + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_; + if (func_flag == 0) { + MatMulOpt(a, b, c, bias, param->act_type_, compute->deep_, compute->row_, compute->col_step_, compute->col_, + OutType_Nhwc); + } else if (func_flag == C1NUM) { + MatVecMulFp32Block8(a, b, c, bias, param->act_type_, compute->deep_, compute->col_step_); + } else { + MatVecMulNoPackFp32(a, b, c, bias, param->act_type_, compute->deep_, compute->col_step_, compute->col_step_); + } + } + return NNACL_OK; +} + +int MatmulBaseParallelRunIsNotPackByBatch(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + int start_batch = task_id * matmul->compute_.batch_stride_; + int end_batch = MSMIN(matmul->batch_, start_batch + matmul->compute_.batch_stride_); + float bias = 0; + if (matmul->matrix_c_.pack_ptr_ != NULL) { + bias = matmul->matrix_c_.pack_ptr_[0]; + } + for (int index = start_batch; index < end_batch; ++index) { + float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * matmul->compute_.row_ * matmul->compute_.deep_; + float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * matmul->compute_.deep_ * matmul->compute_.col_; + float *c = matmul->output_data_ + index * matmul->compute_.row_ * matmul->compute_.col_; + matmul->gemm_not_pack_fun_(a, b, c, &bias, matmul->compute_.row_, matmul->compute_.deep_, param->act_type_); + } + return NNACL_OK; +} + +void MatmulBaseGetThreadCuttingInfoByRow(MatmulStruct *matmul) { + int row_step = NNACL_MAX(matmul->compute_.row_num_ / matmul->base_.thread_nr_, matmul->compute_.row_min_unit_); + int row_remaining = matmul->compute_.row_num_ - row_step * matmul->base_.thread_nr_; + + int split_point = 0; + int count = 0; + while (split_point < matmul->compute_.row_num_) { + matmul->split_points_[count++] = split_point; + split_point += row_step; + if (row_remaining > 0) { + ++split_point; + --row_remaining; + } + } + matmul->base_.thread_nr_ = count; +} + +int MatmulBaseParallelRunByOC(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + MatmulComputeParam *compute = &matmul->compute_; + ActType act = param->act_type_; + + int start_oc = matmul->split_points_[task_id]; + int end_oc = compute->col_step_; + if (task_id < (matmul->base_.thread_nr_ - 1)) { + end_oc = matmul->split_points_[task_id + 1]; + } + int compute_oc = end_oc - start_oc; + if (compute_oc <= 0) { + return NNACL_OK; + } + + int func_flag = 0; + if (compute->row_ == 1) { + func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM; + } + int b_stride = func_flag == C2NUM ? start_oc : start_oc * compute->deep_; + + for (int i = 0; i < matmul->batch_; ++i) { + float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * compute->row_align_ * compute->deep_; + float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * compute->deep_ * compute->col_align_ + b_stride; + float *c = matmul->output_data_ + i * compute->row_ * compute->col_step_ + start_oc; + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc; + + if (func_flag == 0) { + MatMulOpt(a, b, c, bias, act, compute->deep_, compute->row_, compute_oc, compute->col_, OutType_Nhwc); + } else if (func_flag == C1NUM) { + MatVecMulFp32Block8(a, b, c, bias, act, compute->deep_, compute_oc); + } else { + MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_step_); + } + } + return NNACL_OK; +} + +void MatmulBaseGetThreadCuttingPolicy(MatmulStruct *matmul) { + if (matmul->compute_.deep_ < kNumDeepThreshold) { + if (matmul->model_thread_nr_ != -1) { + matmul->base_.thread_nr_ = matmul->model_thread_nr_; + } + } + + if ((matmul->a_batch_ >= matmul->base_.thread_nr_ && + (matmul->b_batch_ == matmul->a_batch_ || !matmul->support_mul_batch_cut_by_row_)) || + matmul->compute_.col_ == 1) { + matmul->compute_.batch_stride_ = UP_DIV(matmul->batch_, matmul->base_.thread_nr_); + matmul->parallel_run_ = matmul->parallel_run_by_batch_; + if (matmul->compute_.col_ != 1 || matmul->a_const_) { + return; + } + + matmul->parallel_run_ = matmul->parallel_run_not_pack_by_batch_; + if (matmul->compute_.deep_ == 1) { + matmul->gemm_not_pack_fun_ = GemmIsNotPack; + } else { + matmul->gemm_not_pack_fun_ = GemmIsNotPackOptimize; + if (matmul->check_thread_cutting_by_row_(matmul)) { + matmul->parallel_run_ = matmul->parallel_run_by_row_; + matmul->get_thread_cutting_info_by_row_(matmul); + } + } + return; + } else if ((matmul->a_batch_ >= matmul->base_.thread_nr_ && matmul->b_batch_ == 1) || + matmul->check_thread_cutting_by_row_(matmul)) { + matmul->parallel_run_ = matmul->parallel_run_by_row_; + matmul->get_thread_cutting_info_by_row_(matmul); + } else { + int total_col_unit = UP_DIV(matmul->compute_.col_align_, matmul->compute_.col_min_unit_); + matmul->base_.thread_nr_ = MSMIN(matmul->base_.thread_nr_, total_col_unit); + int block_col_unit = UP_DIV(total_col_unit, matmul->base_.thread_nr_); + + int count = 0; + int split_point = 0; + while (split_point < total_col_unit) { + matmul->split_points_[count++] = (split_point * matmul->compute_.col_min_unit_); + split_point += block_col_unit; + } + matmul->base_.thread_nr_ = count; + matmul->parallel_run_ = matmul->parallel_run_by_oc_; + } + return; +} + +int MatmulBasePackBiasMatrix(MatmulStruct *matmul) { + if (matmul->base_.in_size_ != FOURTH_INPUT) { + return NNACL_OK; + } + if (matmul->matrix_c_.has_packed_) { + NNACL_CHECK_FALSE(matmul->matrix_c_.pack_size_ < matmul->compute_.col_align_, NNACL_ERR); + return NNACL_OK; + } + TensorC *bias_tensor = matmul->base_.in_[THIRD_INPUT]; + float *bias_src = matmul->matrix_c_.origin_ptr_ != NULL ? matmul->matrix_c_.origin_ptr_ : (float *)bias_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(bias_src); + + int bias_num = NNACLGetElementNum(bias_tensor); + NNACL_CHECK_TRUE_RET(bias_num > 0 && matmul->compute_.col_align_ >= bias_num, NNACL_ERR); + + matmul->matrix_c_.pack_size_ = matmul->compute_.col_align_; + if (matmul->matrix_c_.pack_ptr_ == NULL) { + matmul->matrix_c_.pack_ptr_ = (float *)(malloc(matmul->matrix_c_.pack_size_ * sizeof(float))); + } + NNACL_CHECK_NULL_RETURN_ERR(matmul->matrix_c_.pack_ptr_); + + if (bias_num == 1) { + for (int i = 0; i < matmul->matrix_c_.pack_size_; ++i) { + matmul->matrix_c_.pack_ptr_[i] = bias_src[0]; + } + } else { + (void)memcpy(matmul->matrix_c_.pack_ptr_, bias_src, bias_num * sizeof(float)); + (void)memset(matmul->matrix_c_.pack_ptr_ + bias_num, 0, (matmul->matrix_c_.pack_size_ - bias_num) * sizeof(float)); + } + if (matmul->matrix_c_.origin_need_free_) { + matmul->base_.env_->Free(matmul->base_.env_->allocator_, matmul->matrix_c_.origin_ptr_); + matmul->matrix_c_.origin_ptr_ = NULL; + matmul->matrix_c_.origin_need_free_ = false; + } + return NNACL_OK; +} + +int MatmulBaseInitTmpOutBuffer(MatmulStruct *matmul) { + if (matmul->out_need_aligned_) { + if (matmul->output_data_ != NULL) { + matmul->base_.env_->Free(matmul->base_.env_->allocator_, matmul->output_data_); + } + // avx need to malloc dst aligned to C8NUM + // avx512 need to malloc dst aligned to C16NUM + int out_channel = matmul->compute_.col_; + NNACL_CHECK_ZERO_RETURN_ERR(matmul->compute_.col_tile_); + int oc_block_num = UP_DIV(out_channel, matmul->compute_.col_tile_); + int ele_num = matmul->batch_ * matmul->compute_.row_ * oc_block_num * matmul->compute_.col_tile_; + int data_size = ele_num * (int)sizeof(float); + matmul->output_data_ = (float *)(matmul->base_.env_->Alloc(matmul->base_.env_->allocator_, data_size)); + NNACL_CHECK_NULL_RETURN_ERR(matmul->output_data_); + } + return NNACL_OK; +} + +void MatmulBaseInitGlobalVariable(MatmulStruct *matmul) { + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + matmul->matrix_a_.need_pack_ = true; + matmul->matrix_b_.need_pack_ = !matmul->weight_is_packed_; + matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2Row12MajorParallel : RowMajor2Col12MajorParallel; + matmul->matrix_b_pack_fun_ = param->b_transpose_ ? RowMajor2Col8MajorParallel : RowMajor2Row8MajorParallel; + matmul->compute_.row_tile_ = C12NUM; + matmul->compute_.col_tile_ = C8NUM; + matmul->compute_.col_min_unit_ = C8NUM; + return; +} + +bool MatmulBaseCheckThreadCuttingByRow() { return false; } + +void MatmulBaseFreePackedMatrixA(KernelBase *self) { + MatmulStruct *matmul = (MatmulStruct *)self; + if (matmul->matrix_a_.need_pack_ && !matmul->base_.train_session_ && matmul->matrix_a_.pack_ptr_ != NULL) { + self->env_->Free(self->env_->allocator_, matmul->matrix_a_.pack_ptr_); + } + matmul->matrix_a_.pack_ptr_ = NULL; +} + +void MatmulBaseFreePackedMatrixB(KernelBase *self) { + MatmulStruct *matmul = (MatmulStruct *)self; + if (matmul->matrix_b_.need_pack_ && !matmul->base_.train_session_ && matmul->matrix_b_.pack_ptr_ != NULL) { + matmul->base_.env_->Free(matmul->base_.env_->allocator_, matmul->matrix_b_.pack_ptr_); + } + matmul->matrix_b_.pack_ptr_ = NULL; +} + +int MatmulBaseResize(KernelBase *self) { + MatmulStruct *matmul = (MatmulStruct *)self; + + int ret = matmul->init_parameter_(matmul); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + if (self->train_session_) { + self->work_size_ = (matmul->matrix_a_.pack_size_ + matmul->matrix_b_.pack_size_) * (int)sizeof(float); + } + + matmul->get_thread_cutting_policy_(matmul); + if (!matmul->matrix_c_.has_packed_) { + ret = MatmulBasePackBiasMatrix(matmul); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + if (!matmul->bias_need_repack_) { + matmul->matrix_c_.has_packed_ = true; + } + } + ret = MatmulBaseInitTmpOutBuffer(matmul); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + + return NNACL_OK; +} + +int MatmulBaseRelease(struct KernelBase *self) { + MatmulStruct *matmul = (MatmulStruct *)self; + MatmulBaseFreeBatchOffset(matmul); + + if (matmul->out_need_aligned_ && matmul->output_data_ != NULL) { + matmul->base_.env_->Free(matmul->base_.env_->allocator_, matmul->output_data_); + matmul->output_data_ = NULL; + } + if (matmul->matrix_c_.pack_ptr_ != NULL) { + free(matmul->matrix_c_.pack_ptr_); + matmul->matrix_c_.pack_ptr_ = NULL; + } + if (matmul->a_const_) { + if (matmul->is_sharing_pack_) { + matmul->free_sharing_weight_(matmul->shaing_manager_, matmul->matrix_a_.pack_ptr_); + } else { + free(matmul->matrix_a_.pack_ptr_); + } + } + if (matmul->b_const_) { + if (!matmul->matrix_b_.need_pack_ && matmul->weight_is_packed_) { + return NNACL_OK; + } + if (matmul->is_sharing_pack_) { + matmul->free_sharing_weight_(matmul->shaing_manager_, matmul->matrix_b_.pack_ptr_); + } else { + free(matmul->matrix_b_.pack_ptr_); + } + } + return NNACL_OK; +} + +int MatmulBasePrepare(struct KernelBase *self) { + MatmulStruct *matmul = (MatmulStruct *)self; + + NNACL_CHECK_FALSE(matmul->base_.in_size_ < C2NUM, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(matmul->base_.out_size_ < 1, NNACL_OUTPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(matmul->base_.in_[FIRST_INPUT]->data_type_ != kNumberTypeFloat32, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(matmul->base_.in_[SECOND_INPUT]->data_type_ != kNumberTypeFloat32, NNACL_INPUT_TENSOR_ERROR); + + if (matmul->base_.in_size_ == THREE_TENSOR) { + NNACL_CHECK_TRUE_RET(matmul->base_.in_[THIRD_INPUT]->data_type_ == kNumberTypeFloat32, NNACL_MATMUL_BIAS_INVALID); + } + + MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_); + NNACL_CHECK_FALSE( + param->act_type_ != ActType_No && param->act_type_ != ActType_Relu && param->act_type_ != ActType_Relu6, + NNACL_MATMUL_ACT_TYPE_INVALID); + + int ret = matmul->init_parameter_(matmul); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + + if (matmul->a_const_) { + ret = MatmulBasePackMatrixA(matmul); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + matmul->matrix_a_.has_packed_ = true; + } + if (matmul->b_const_) { + ret = MatmulBasePackMatrixB(matmul); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + matmul->matrix_b_.has_packed_ = true; + } + + if (matmul->base_.in_size_ == THREE_TENSOR) { + /* deal with const bias */ + bool bias_const = NNACLIsConst(self->in_[THIRD_INPUT]); + if (!matmul->infer_shape_ && bias_const && !matmul->base_.train_session_ && matmul->matrix_c_.origin_ptr_ == NULL) { + ret = MatmulBaseBackupConstMatrix(matmul, &matmul->matrix_c_, THIRD_INPUT); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + } + } + return NNACL_OK; +} + +int MatmulBaseCompute(struct KernelBase *self) { + MatmulStruct *matmul = (MatmulStruct *)self; + + float *out_data = (float *)(matmul->base_.out_[FIRST_INPUT]->data_); + NNACL_CHECK_FALSE(out_data == NULL, NNACL_ERR); + if (!matmul->out_need_aligned_) { + matmul->output_data_ = out_data; + } + + if (!matmul->a_const_) { + int ret = MatmulBasePackMatrixA(matmul); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + } + if (!matmul->b_const_) { + int ret = MatmulBasePackMatrixB(matmul); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + } + + NNACL_CHECK_NULL_RETURN_ERR(matmul->matrix_a_.pack_ptr_); + NNACL_CHECK_NULL_RETURN_ERR(matmul->matrix_b_.pack_ptr_); + + int ret = self->env_->ParallelLaunch(self->env_->thread_pool_, MatmulFp32Run, self, self->thread_nr_); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + + if (matmul->out_need_aligned_) { + PackNHWCXToNHWCFp32(matmul->output_data_, out_data, matmul->batch_, matmul->compute_.row_, matmul->compute_.col_, + matmul->compute_.col_tile_); + } else { + matmul->output_data_ = NULL; + } + if (!matmul->a_const_) { + MatmulBaseFreePackedMatrixA(self); + } + + if (!matmul->b_const_) { + MatmulBaseFreePackedMatrixB(self); + } + return NNACL_OK; +} + +void InitMatrixInfo(MatrixInfo *info) { + info->need_pack_ = false; + info->has_packed_ = false; + info->origin_need_free_ = false; + info->pack_size_ = -1; + info->origin_ptr_ = NULL; + info->pack_ptr_ = NULL; +} + +KernelBase *CreateMatmulBase() { + MatmulStruct *matmul = (MatmulStruct *)malloc(sizeof(MatmulStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(matmul); + memset(matmul, 0, sizeof(MatmulStruct)); + matmul->base_.Prepare = MatmulBasePrepare; + matmul->base_.Resize = MatmulBaseResize; + matmul->base_.Release = MatmulBaseRelease; + matmul->base_.Compute = MatmulBaseCompute; + InitMatrixInfo(&(matmul->matrix_a_)); + InitMatrixInfo(&(matmul->matrix_b_)); + InitMatrixInfo(&(matmul->matrix_c_)); + matmul->is_sharing_pack_ = false; + matmul->pack_opt_ = false; + matmul->a_const_ = false; + matmul->b_const_ = false; + matmul->bias_need_repack_ = false; + matmul->out_need_aligned_ = false; + matmul->a_offset_ = NULL; + matmul->b_offset_ = NULL; + matmul->model_thread_nr_ = -1; + matmul->support_mul_batch_cut_by_row_ = false; + matmul->matmul_type_ = kMatmulFp32BaseCpu; + matmul->get_thread_cutting_policy_ = MatmulBaseGetThreadCuttingPolicy; + matmul->check_thread_cutting_by_row_ = MatmulBaseCheckThreadCuttingByRow; + matmul->get_thread_cutting_info_by_row_ = MatmulBaseGetThreadCuttingInfoByRow; + matmul->init_parameter_ = MatmulBaseInitParameter; + matmul->init_global_varibale_ = MatmulBaseInitGlobalVariable; + matmul->pack_matrix_a_impl_opt_ = MatmulBasePackMatrixAImplOpt; + matmul->pack_matrix_a_impl_ = MatmulBasePackMatrixAImpl; + matmul->pack_matrix_b_impl_ = MatmulBasePackMatrixBImpl; + matmul->parallel_run_by_batch_ = MatmulBaseParallelRunByBatch; + matmul->parallel_run_not_pack_by_batch_ = MatmulBaseParallelRunIsNotPackByBatch; + matmul->parallel_run_by_oc_ = MatmulBaseParallelRunByOC; + matmul->parallel_run_by_row_ = MatmulBaseParallelRunByRow; + return (KernelBase *)matmul; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_base.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_base.h new file mode 100644 index 0000000000000000000000000000000000000000..7eb3559ceed557f897a15dda4c613f116f9d24e4 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_base.h @@ -0,0 +1,35 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_MATMUL_BASE_H_ +#define NNACL_KERNEL_MATMUL_BASE_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/matmul_parameter.h" +#include "nnacl/kernel/matmul_struct.h" + +void MatmulBaseGetThreadCuttingPolicy(MatmulStruct *matmul); +void MatmulBaseFreeBatchOffset(MatmulStruct *matmul); +int MatmulBaseMallocBatchOffset(MatmulStruct *matmul); +int MatmulBaseInitParameter(MatmulStruct *matmul); +int MatmulBasePrepare(KernelBase *self); +int MatmulBaseResize(KernelBase *self); +int MatmulBaseRelease(KernelBase *self); + +KernelBase *CreateMatmulBase(); + +#endif // NNACL_KERNEL_MATMUL_BASE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_create.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_create.c new file mode 100644 index 0000000000000000000000000000000000000000..9d3d3f1820ff45780cf34a6089ba5876959f94a9 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_create.c @@ -0,0 +1,82 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/matmul_create.h" +#include "nnacl/kernel/matmul_base.h" +#if defined(ENABLE_AVX512) +#include "nnacl/kernel/matmul_avx512.h" +#include "nnacl/intrinsics/ms_simd_cpu_info.h" +#endif + +#if defined(ENABLE_AVX) +#include "nnacl/kernel/matmul_avx.h" +#endif + +#if defined(ENABLE_SSE) +#include "nnacl/kernel/matmul_sse.h" +#endif + +#if defined(ENABLE_ARM32) +#include "nnacl/kernel/matmul_arm32.h" +#endif + +#if defined(ENABLE_ARM64) +#include "nnacl/kernel/matmul_arm64.h" +#endif + +KernelBase *CreateMatmulKernel() { + KernelBase *matmul = NULL; + +#if defined(ENABLE_AVX512) + AVX512_HARDWARE_SELF_AWARENESS_BEGIN + matmul = CreateMatmulAVX512(); + if (matmul != NULL) { + return matmul; + } + AVX512_HARDWARE_SELF_AWARENESS_END +#endif + +#if defined(ENABLE_AVX) + matmul = CreateMatmulAVX(); + if (matmul != NULL) { + return matmul; + } +#endif + +#if defined(ENABLE_SSE) + matmul = CreateMatmulSSE(); + if (matmul != NULL) { + return matmul; + } +#endif + +#if defined(ENABLE_ARM64) + matmul = CreateMatmulARM64(); + if (matmul != NULL) { + return matmul; + } +#endif + +#if defined(ENABLE_ARM32) + matmul = CreateMatmulARM32(); + if (matmul != NULL) { + return matmul; + } +#endif + + matmul = CreateMatmulBase(); + return matmul; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_create.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_create.h new file mode 100644 index 0000000000000000000000000000000000000000..6ccdf08f8797c0242af8897ad45a4379a9b0d2b5 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_create.h @@ -0,0 +1,24 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_MATMUL_CREATE_H_ +#define NNACL_KERNEL_MATMUL_CREATE_H_ + +#include "nnacl/op_base.h" +#include "nnacl/kernel.h" + +KernelBase *CreateMatmulKernel(); + +#endif // NNACL_KERNEL_MATMUL_CREATE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_sse.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_sse.c new file mode 100644 index 0000000000000000000000000000000000000000..d023ebf5c681ddd35b9ad934fe231d04aaa43c35 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_sse.c @@ -0,0 +1,110 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_SSE +#include "nnacl/kernel/matmul_sse.h" +#include "nnacl/kernel/matmul_base.h" +#include "nnacl/fp32/matmul_fp32.h" +#include "nnacl/fp32/pack_fp32.h" + +void MatmulSSEInitGlobalVariable(MatmulStruct *matmul) { + MatMulParameter *param = (MatMulParameter *)matmul->base_.param_; + matmul->matrix_a_.need_pack_ = true; + matmul->matrix_b_.need_pack_ = true; + matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2Row4MajorParallel : RowMajor2Col4MajorParallel; + matmul->matrix_b_pack_fun_ = param->b_transpose_ ? RowMajor2Col8MajorParallel : RowMajor2Row8MajorParallel; + matmul->compute_.row_tile_ = C4NUM; + matmul->compute_.col_tile_ = C8NUM; + matmul->compute_.col_min_unit_ = C8NUM; +} + +int MatmulSSEParallelRunByBatch(MatmulStruct *matmul, int task_id) { + MatMulParameter *param = (MatMulParameter *)matmul->base_.param_; + MatmulComputeParam *compute = &matmul->compute_; + ActType act = param->act_type_; + + int start_batch = task_id * compute->batch_stride_; + int end_batch = MSMIN(matmul->batch_, start_batch + compute->batch_stride_); + int func_flag = 0; + if (compute->row_ == 1) { + func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM; + } + + for (int index = start_batch; index < end_batch; ++index) { + const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * compute->row_align_ * compute->deep_; + const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * compute->deep_ * compute->col_align_; + float *c = matmul->output_data_ + index * compute->row_ * compute->col_step_; + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_; + + if (func_flag == 0) { + MatMulOpt(a, b, c, bias, act, compute->deep_, compute->row_, compute->col_step_, compute->col_, OutType_Nhwc); + } else if (func_flag == C1NUM) { + MatVecMulFp32Block8(a, b, c, bias, act, compute->deep_, compute->col_step_); + } else { + MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_step_); + } + } + return NNACL_OK; +} + +int MatmulSSEParallelRunByOC(MatmulStruct *matmul, int task_id) { + NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR); + MatMulParameter *param = (MatMulParameter *)matmul->base_.param_; + MatmulComputeParam *compute = &matmul->compute_; + ActType act = param->act_type_; + + int start_oc = matmul->split_points_[task_id]; + int end_oc = compute->col_step_; + if (task_id < (matmul->base_.thread_nr_ - 1)) { + end_oc = matmul->split_points_[task_id + 1]; + } + int compute_oc = end_oc - start_oc; + if (compute_oc <= 0) { + return NNACL_OK; + } + int func_flag = 0; + if (compute->row_ == 1) { + func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM; + } + int b_stride = func_flag == C2NUM ? start_oc : start_oc * compute->deep_; + + for (int i = 0; i < matmul->batch_; ++i) { + float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * compute->row_align_ * compute->deep_; + float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * compute->deep_ * compute->col_align_ + b_stride; + float *c = matmul->output_data_ + i * compute->row_ * compute->col_step_ + start_oc; + float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc; + + if (func_flag == 0) { + MatMulOpt(a, b, c, bias, act, compute->deep_, compute->row_, compute_oc, compute->col_, OutType_Nhwc); + } else if (func_flag == C1NUM) { + MatVecMulFp32Block8(a, b, c, bias, act, compute->deep_, compute_oc); + } else { + MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_step_); + } + } + return NNACL_OK; +} + +KernelBase *CreateMatmulSSE() { + MatmulStruct *matmul = (MatmulStruct *)CreateMatmulBase(); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(matmul); + matmul->matmul_type_ = kNotImplemented; + matmul->init_global_varibale_ = MatmulSSEInitGlobalVariable; + matmul->parallel_run_by_oc_ = MatmulSSEParallelRunByOC; + matmul->parallel_run_by_batch_ = MatmulSSEParallelRunByBatch; + return (KernelBase *)matmul; +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_sse.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_sse.h new file mode 100644 index 0000000000000000000000000000000000000000..7b1b3332508044edef44f84390282441262579b2 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_sse.h @@ -0,0 +1,27 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_MATMUL_SSE_H_ +#define NNACL_KERNEL_MATMUL_SSE_H_ +#ifdef ENABLE_SSE +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +KernelBase *CreateMatmulSSE(); + +#endif +#endif // NNACL_KERNEL_MATMUL_SSE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_struct.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_struct.h new file mode 100644 index 0000000000000000000000000000000000000000..120166add378d4fec26ce8e586061a74379ce031 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/matmul_struct.h @@ -0,0 +1,133 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_KERNEL_MATMUL_STRUCT_H_ +#define NNACL_KERNEL_MATMUL_STRUCT_H_ + +#include "nnacl/kernel.h" +#include "nnacl/matmul_parameter.h" + +#define SPLIT_COUNT MAX_THREAD_NUM + +typedef struct MatrixInfo { + bool need_pack_; + bool has_packed_; // only valid for constant, only do once throughout the process. + bool origin_need_free_; // true when failing to infer shape, false in conv1x1 free in convolution delegate + int pack_size_; + float *origin_ptr_; // only valid for constant, which is synchronized with the 'has_origin'. + float *pack_ptr_; +} MatrixInfo; + +typedef struct MatmulSlice { + int row_s_; + int row_e_; + int col_s_; + int col_e_; +} MatmulSlice; + +typedef struct MatmulComputeParam { + int row_; + int col_; + int deep_; + int row_align_; + int col_align_; + int deep_align_; + int row_num_; + int col_tile_; + int row_tile_; + int col_step_; + int row_min_unit_; + int col_min_unit_; + int batch_stride_; + int pack_b_stride_; + int block_col_unit_; +} MatmulComputeParam; + +typedef struct MatmulStruct { + KernelBase base_; + MatmulComputeParam compute_; + MatmulType matmul_type_; + + /* model pool optimize */ + int model_thread_nr_; + + /* batch-matmul broadcast */ + int batch_; + int a_batch_; + int b_batch_; + int *a_offset_; /* batch_ size */ + int *b_offset_; /* batch_ size */ + + int split_points_[SPLIT_COUNT]; + + float *output_data_; + float *pack_b_src_; + float *pack_b_dst_; + + bool a_const_; + bool b_const_; + bool bias_need_repack_; + bool infer_shape_; + bool pack_opt_; + bool is_sharing_pack_; + bool out_need_aligned_; + bool weight_is_packed_; + bool support_mul_batch_cut_by_row_; + + MatrixInfo matrix_a_; + MatrixInfo matrix_b_; + MatrixInfo matrix_c_; + + void (*matrix_a_pack_fun_)(const float *src_ptr, float *dst_ptr, int row, int col, int start_row, int end_row); + void (*matrix_b_pack_fun_)(const float *src_ptr, float *dst_ptr, int row, int col, int start_row, int end_row); + + int (*pack_matrix_a_impl_opt_)(struct MatmulStruct *matmul); + int (*pack_matrix_a_impl_)(struct MatmulStruct *matmul); + int (*pack_matrix_b_impl_)(struct MatmulStruct *matmul); + + int (*init_parameter_)(struct MatmulStruct *matmul); + void (*init_global_varibale_)(struct MatmulStruct *matmul); + + bool (*check_thread_cutting_by_row_)(struct MatmulStruct *matmul); + void (*get_thread_cutting_policy_)(struct MatmulStruct *matmul); + void (*get_thread_cutting_info_by_row_)(struct MatmulStruct *matmul); + + void *shaing_manager_; + void *(*get_sharing_weight_)(void *manager, const void *tensor_data, const size_t size, bool *is_packed); + void (*free_sharing_weight_)(void *manager, void *tensor_data); + + void (*gemm_not_pack_fun_)(const float *a, const float *b, float *c, const float *bias, int m, int k, int act_type); + + int (*parallel_run_)(struct MatmulStruct *matmul, int task_id); + int (*parallel_run_by_row_)(struct MatmulStruct *matmul, int task_id); + int (*parallel_run_by_oc_)(struct MatmulStruct *matmul, int task_id); + int (*parallel_run_by_batch_)(struct MatmulStruct *matmul, int task_id); + int (*parallel_run_not_pack_by_batch_)(struct MatmulStruct *matmul, int task_id); + + /* optimize for avx512 */ + int col_split_points_size_; + int row_split_points_size_; + int col_split_points_[SPLIT_COUNT]; + int row_split_points_[SPLIT_COUNT]; + int matmul_slice_count_[SPLIT_COUNT]; + MatmulSlice matmul_slice_set_[SPLIT_COUNT][SPLIT_COUNT]; + int (*parallel_run_by_gemm_)(struct MatmulStruct *matmul, int task_id); + int (*parallel_run_by_gepm_)(struct MatmulStruct *matmul, int task_id); + int (*parallel_run_by_gepdot_)(struct MatmulStruct *matmul, int task_id); + int (*parallel_run_by_batch_col_row_gemm_)(struct MatmulStruct *matmul, int task_id); + int (*parallel_run_by_row1_deep1_gepdot_)(struct MatmulStruct *matmul, int task_id); +} MatmulStruct; + +#endif // NNACL_KERNEL_MATMUL_STRUCT_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/nllloss.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/nllloss.c new file mode 100644 index 0000000000000000000000000000000000000000..60d9101c6cd2cbec0824254087f74c14c730d899 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/nllloss.c @@ -0,0 +1,63 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/nllloss.h" +#include "nnacl/kernel/default_kernel_base.h" +#include "nnacl/fp32/nllloss_fp32.h" +#include "nnacl/nllloss_parameter.h" + +int NlllossCompute(KernelBase *self) { + NLLLossStruct *nllloss = (NLLLossStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(nllloss); + float *logits = self->in_[Index0]->data_; + NNACL_CHECK_NULL_RETURN_ERR(logits); + int *labels = self->in_[Index1]->data_; + NNACL_CHECK_NULL_RETURN_ERR(labels); + float *weight = self->in_[Index2]->data_; + NNACL_CHECK_NULL_RETURN_ERR(weight); + + float *loss = self->out_[Index0]->data_; + NNACL_CHECK_NULL_RETURN_ERR(loss); + float *total_weight = self->out_[Index1]->data_; + NNACL_CHECK_NULL_RETURN_ERR(total_weight); + + ReductionType reduction_type = ((NLLLossParameter *)self->param_)->reduction_type_; + return NLLLoss(logits, labels, weight, loss, total_weight, nllloss, reduction_type); +} + +int NlllossPrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < THREE_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < TWO_TENSOR, NNACL_ERR); + NLLLossStruct *nllloss = (NLLLossStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(nllloss); + TensorC *logits_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(logits_tensor); + nllloss->batch_ = logits_tensor->shape_[Index0]; + nllloss->class_num_ = logits_tensor->shape_[Index1]; + return NNACL_OK; +} + +KernelBase *CreateNLLLoss(OpParameter *param, int data_type) { + NLLLossStruct *nllloss = (NLLLossStruct *)malloc(sizeof(NLLLossStruct)); + NNACL_CHECK_NULL_RETURN_NULL(nllloss); + nllloss->base_.Release = DefaultRelease; + nllloss->base_.Prepare = NlllossPrepare; + nllloss->base_.Resize = DefaultResize; + nllloss->base_.Compute = NlllossCompute; + return (KernelBase *)nllloss; +} + +REG_KERNEL_CREATOR(PrimType_NLLLoss, kNumberTypeFloat32, CreateNLLLoss) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/nllloss.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/nllloss.h new file mode 100644 index 0000000000000000000000000000000000000000..5ddfb486c533e6a8b2464524bf1089af0bfde531 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/nllloss.h @@ -0,0 +1,32 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_NLLLOSS_H_ +#define NNACL_KERNEL_NLLLOSS_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct { + KernelBase base_; + int batch_; + int class_num_; +} NLLLossStruct; + +KernelBase *CreateNLLLoss(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_NLLLOSS_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/non_max_suppression.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/non_max_suppression.c new file mode 100644 index 0000000000000000000000000000000000000000..1dc0736e5fa3160471b38d4d5f009f1f1777000b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/non_max_suppression.c @@ -0,0 +1,126 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/non_max_suppression.h" +#include "nnacl/kernel/default_kernel_base.h" +#include "nnacl/non_max_suppression_parameter.h" +#include "nnacl/fp32/non_max_suppression_fp32.h" + +void NonMaxSuppressioExpandDims(int *dst_shape, int *origin_shape, size_t size) { + int i = 0; + for (; i < size; i++) { + dst_shape[i] = 1; + } + for (; i < Num3; i++) { + dst_shape[i] = origin_shape[i - size]; + } +} + +void NonMaxSuppressionGetParams(NonMaxSuppressionStruct *nm_suppression) { + // optional input order: max_output_per_class, iou_threshold, score_threshold + nm_suppression->max_output_per_class_ = 0; + if (nm_suppression->base_.in_size_ >= Num3) { + TensorC *max_output_tensor = nm_suppression->base_.in_[Index3]; + if (max_output_tensor != NULL && max_output_tensor->data_ != NULL) { + nm_suppression->max_output_per_class_ = *(int *)(max_output_tensor->data_); + } + } + + nm_suppression->iou_threshold_ = 0.0f; + if (nm_suppression->base_.in_size_ >= Num4) { + TensorC *iou_threshold_tensor = nm_suppression->base_.in_[Index4]; + if (iou_threshold_tensor != NULL && iou_threshold_tensor->data_ != NULL) { + nm_suppression->iou_threshold_ = *(float *)(iou_threshold_tensor->data_); + } + } + + nm_suppression->score_threshold_ = 0.0f; + if (nm_suppression->base_.in_size_ >= Num5) { + TensorC *score_threshold_tensor = nm_suppression->base_.in_[Index5]; + if (score_threshold_tensor != NULL && score_threshold_tensor->data_ != NULL) { + nm_suppression->score_threshold_ = *(float *)(score_threshold_tensor->data_); + } + } +} + +int NonMaxSuppressionCompute(KernelBase *self) { + NonMaxSuppressionStruct *nm_suppression = (NonMaxSuppressionStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(nm_suppression); + + NonMaxSuppressionGetParams(nm_suppression); + + TensorC *box_tensor = self->in_[Index0]; + NNACL_CHECK_NULL_RETURN_ERR(box_tensor); + int box_dims[Num3] = {0}; // batch, box_num, 4 + bool simple_out = false; + if (box_tensor->shape_size_ != Num3) { + NonMaxSuppressioExpandDims(box_dims, box_tensor->shape_, Num3 - box_tensor->shape_size_); + simple_out = true; + } + if (box_dims[Index2] != Num4) { + return NNACL_NON_MAX_SUPPRESSION_BOX_DIMS_INVALID; + } + + TensorC *score_tensor = self->in_[Index1]; + NNACL_CHECK_NULL_RETURN_ERR(score_tensor); + int score_dims[Num3] = {0}; // batch, class, box_num + if (score_tensor->shape_size_ != Num3) { + NonMaxSuppressioExpandDims(score_dims, score_tensor->shape_, Num3 - score_tensor->shape_size_); + } + if (score_dims[Index0] != box_dims[Index0]) { + return NNACL_NON_MAX_SUPPRESSION_BOX_DIMS_SCORE_UNMATCH; + } + if (score_dims[Index2] != box_dims[Index1]) { + return NNACL_NON_MAX_SUPPRESSION_DIMENSION_SPATIAL_UNMATCH; + } + if (nm_suppression->base_.out_[OUTPUT_INDEX]->data_ != NULL) { + /* output shape and data set in compute */ + return NNACL_NON_MAX_SUPPRESSION_UNSUPPORT_DEFINE_DATA; + } + return NonMaxSuppressionSelecte(nm_suppression, simple_out, score_dims); +} + +int NonMaxSuppressionPrepare(KernelBase *self) { + NonMaxSuppressionStruct *nm_suppression = (NonMaxSuppressionStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(nm_suppression); + + // boxes, scores, max_output_boxes, iou_threshold, score_threshold + if (self->in_size_ < Num2 || self->in_size_ > Num5 || self->out_size_ != Num1) { + return NNACL_NON_MAX_SUPPRESSION_TENSOR_SIZE_INVALID; + } + + NNACL_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]); + NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]); + NMSParameter *nmparam = (NMSParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(nmparam); + if (nmparam->center_point_box_ != 0 && nmparam->center_point_box_ != 1) { + return NNACL_NON_MAX_SUPPRESSION_PARAM_INVALID; + } + nm_suppression->center_point_box_ = nmparam->center_point_box_; + return NNACL_OK; +} + +KernelBase *CreateNonMaxSuppression(OpParameter *param, int data_type) { + NonMaxSuppressionStruct *non_max_suppression = (NonMaxSuppressionStruct *)malloc(sizeof(NonMaxSuppressionStruct)); + NNACL_CHECK_NULL_RETURN_NULL(non_max_suppression); + non_max_suppression->base_.Release = DefaultRelease; + non_max_suppression->base_.Resize = DefaultResize; + non_max_suppression->base_.Prepare = NonMaxSuppressionPrepare; + non_max_suppression->base_.Compute = NonMaxSuppressionCompute; + return (KernelBase *)non_max_suppression; +} + +REG_KERNEL_CREATOR(PrimType_NonMaxSuppression, kNumberTypeFloat32, CreateNonMaxSuppression) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/non_max_suppression.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/non_max_suppression.h new file mode 100644 index 0000000000000000000000000000000000000000..5c9988ec56e2ed61f13c772ec96a79cd3fe6f6ed --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/non_max_suppression.h @@ -0,0 +1,34 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_NON_MAX_SUPPRESSION_H_ +#define NNACL_KERNEL_NON_MAX_SUPPRESSION_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct { + KernelBase base_; + int center_point_box_; + int max_output_per_class_; + float iou_threshold_; + float score_threshold_; +} NonMaxSuppressionStruct; + +KernelBase *CreateNonMaxSuppression(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_NON_MAX_SUPPRESSION_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/non_zero.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/non_zero.c new file mode 100644 index 0000000000000000000000000000000000000000..693ac2f4957d0e8b24aeed96d2ce4c11fa4dfa52 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/non_zero.c @@ -0,0 +1,69 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/non_zero.h" +#include "nnacl/kernel/default_kernel_base.h" +#include "nnacl/tensor_c_utils.h" + +int NonZeroCompute(KernelBase *self) { + TensorC *input = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + TensorC *output = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output); + + NNACL_CHECK_FALSE(input->shape_size_ != DIMENSION_2D, NNACL_NON_ZERO_SHAPE_INVALID); + + bool *input_data = (bool *)input->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_data); + int *output_data = (int *)output->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_data); + + int non_zero_nums = output->shape_[Index1]; + int non_zero_count = 0; + + int *coordiate_values = (int *)self->env_->Alloc(self->env_->allocator_, input->shape_size_ * sizeof(int)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(coordiate_values); + + for (int i = 0; i < NNACLGetElementNum(input); i += 1) { + if (input_data[i]) { + for (size_t j = 0; j < input->shape_size_; j++) { + output_data[non_zero_count + (int)j * non_zero_nums] = coordiate_values[j]; + } + non_zero_count++; + } + for (size_t idx = input->shape_size_; idx >= 1; --idx) { + if (coordiate_values[idx - 1] != input->shape_[idx - 1] - 1) { + coordiate_values[idx - 1] = coordiate_values[idx - 1] + 1; + break; + } + coordiate_values[idx - 1] = 0; + } + } + + return NNACL_OK; +} + +KernelBase *CreateNonZero(OpParameter *param, int data_type) { + NonZeroStruct *non_zero = (NonZeroStruct *)malloc(sizeof(NonZeroStruct)); + NNACL_CHECK_NULL_RETURN_NULL(non_zero); + non_zero->base_.Release = DefaultRelease; + non_zero->base_.Prepare = DefaultPrepare2In1Out; + non_zero->base_.Resize = DefaultResize; + non_zero->base_.Compute = NonZeroCompute; + return (KernelBase *)non_zero; +} + +REG_KERNEL_CREATOR(PrimType_NonZero, kNumberTypeBool, CreateNonZero) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/non_zero.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/non_zero.h new file mode 100644 index 0000000000000000000000000000000000000000..a7b97ab498bb9a1a41a8aeb0bdf46076daf272a0 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/non_zero.h @@ -0,0 +1,30 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_NON_ZERO_H_ +#define NNACL_KERNEL_NON_ZERO_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct { + KernelBase base_; +} NonZeroStruct; + +KernelBase *CreateNonZero(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_NON_ZERO_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/one_hot.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/one_hot.c new file mode 100644 index 0000000000000000000000000000000000000000..b51dd0bbb903190defa08aa2eb9257812ebafcc1 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/one_hot.c @@ -0,0 +1,193 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/one_hot.h" +#include "nnacl/kernel/default_kernel_base.h" +#include "nnacl/one_hot_parameter.h" +#include "nnacl/tensor_c_utils.h" +#include "nnacl/fp32/one_hot_fp32.h" +#ifdef ENABLE_FP16 +#include "nnacl/fp16/one_hot_fp16.h" +#endif + +int OneHotRun(void *cdata, int task_id, float l, float r) { + OneHotStruct *one_hot = (OneHotStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(one_hot); + + int *indices_data = (int *)one_hot->base_.in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(indices_data); + + TensorC *output_tensor = one_hot->base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + void *output_data = one_hot->base_.out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_data); + + if (output_tensor->data_type_ == kNumberTypeFloat32) { + return OneHotToFp32(indices_data, one_hot->on_value_, one_hot->off_value_, (float *)output_data, one_hot, task_id, + one_hot->base_.thread_nr_); +#ifdef ENABLE_FP16 + } else if (output_tensor->data_type_ == kNumberTypeFloat16) { + return OneHotToFp16(indices_data, (float16_t)one_hot->on_value_, (float16_t)one_hot->off_value_, + (float16_t *)output_data, one_hot, task_id, one_hot->base_.thread_nr_); +#endif + } + + return NNACL_UNSUPPORTED_DATA_TYPE; +} + +int OneHotInitOnOffValueForFourInputs(OneHotStruct *one_hot) { + TensorC *on_value_tensor = one_hot->base_.in_[THIRD_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(on_value_tensor); + void *on_value_data = on_value_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(on_value_data); + if (on_value_tensor->data_type_ == kNumberTypeFloat32) { + one_hot->on_value_ = *((float *)on_value_data); +#if defined(ENABLE_ARM) && defined(ENABLE_FP16) + } else if (on_value_tensor->data_type_ == kNumberTypeFloat16) { + one_hot->on_value_ = *((float16_t *)on_value_data); +#endif + } else { + return NNACL_ONE_HOR_ON_VALUE_TENSOR_DATA_TYPE_INVALID; + } + + TensorC *off_value_tensor = one_hot->base_.in_[FOURTH_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(off_value_tensor); + void *off_value_data = off_value_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(off_value_data); + if (on_value_tensor->data_type_ == kNumberTypeFloat32) { + one_hot->off_value_ = *((float *)off_value_data); +#if defined(ENABLE_ARM) && defined(ENABLE_FP16) + } else if (on_value_tensor->data_type_ == kNumberTypeFloat16) { + one_hot->off_value_ = *((float16_t *)off_value_data); +#endif + } else { + return NNACL_ONE_HOR_OFF_VALUE_TENSOR_DATA_TYPE_INVALID; + } + + return NNACL_OK; +} + +int OneHotInitOnOffValueForThreeInputs(OneHotStruct *one_hot) { + TensorC *value_tensor = one_hot->base_.in_[THIRD_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(value_tensor); + void *value_data = value_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(value_data); + + if (value_tensor->data_type_ == kNumberTypeFloat32) { + one_hot->off_value_ = ((float *)value_data)[Index0]; + one_hot->on_value_ = ((float *)value_data)[Index1]; +#if defined(ENABLE_ARM) && defined(ENABLE_FP16) + } else if (value_tensor->data_type_ == kNumberTypeFloat16) { + one_hot->off_value_ = ((float16_t *)value_data)[Index0]; + one_hot->on_value_ = ((float16_t *)value_data)[Index1]; +#endif + } else { + return NNACL_ONE_HOR_ON_OFF_VALUE_TENSOR_DATA_TYPE_INVALID; + } + return NNACL_OK; +} + +int OneHotInitParamsAndOnOffValue(OneHotStruct *one_hot) { + TensorC *depth_tensor = one_hot->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(depth_tensor); + + if (depth_tensor->data_type_ == kNumberTypeInt32) { + const int *depth = (int *)depth_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(depth); + one_hot->depth_ = *depth; + } else { + return NNACL_ONE_HOR_DEPTH_TENSOR_DATA_TYPE_INVALID; + } + + if (one_hot->base_.in_size_ == FOUR_TENSOR) { + // 4 inputs: indices, depth, on_value, off_value + one_hot->support_neg_index_ = false; + int ret = OneHotInitOnOffValueForFourInputs(one_hot); + if (ret != NNACL_OK) { + return ret; + } + } else { + // 3 inputs: indices, depth, off_on_value + one_hot->support_neg_index_ = true; + int ret = OneHotInitOnOffValueForThreeInputs(one_hot); + if (ret != NNACL_OK) { + return ret; + } + } + return NNACL_OK; +} + +int OneHotCompute(KernelBase *self) { + OneHotStruct *one_hot = (OneHotStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(one_hot); + int ret = OneHotInitParamsAndOnOffValue(one_hot); + if (ret != NNACL_OK) { + return ret; + } + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, OneHotRun, self, self->thread_nr_); + if (ret != NNACL_OK) { + return ret; + } + + return NNACL_OK; +} + +int OneHotPrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ != FOUR_TENSOR && self->in_size_ != THREE_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_size_ != ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + TypeIdC data_type = self->in_[FIRST_INPUT]->data_type_; + NNACL_CHECK_FALSE(data_type != kNumberTypeInt32 && data_type != kNumberTypeInt64, NNACL_OUTPUT_TENSOR_ERROR); + return NNACL_OK; +} + +int OneHotResize(KernelBase *self) { + OneHotStruct *one_hot = (OneHotStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(one_hot); + + TensorC *indices = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(indices); + + int origin_axis = ((OneHotParameter *)self->param_)->axis_; + one_hot->axis_ = origin_axis < 0 ? origin_axis + (int)indices->shape_size_ + 1 : origin_axis; + NNACL_CHECK_FALSE(one_hot->axis_ < 0 && one_hot->axis_ > (int)indices->shape_size_, NNACL_ONE_HOT_AXIS_INVALID); + + one_hot->outer_size_ = 1; + for (int i = 0; i < one_hot->axis_; i++) { + one_hot->outer_size_ *= indices->shape_[i]; + } + if (one_hot->outer_size_ == 0) { + return NNACL_ONE_HOT_OUTER_SIZE_INVALID; + } + one_hot->inner_size_ = NNACLGetElementNum(indices) / one_hot->outer_size_; + NNACL_CHECK_FALSE(one_hot->inner_size_ <= 0, NNACL_ONE_HOT_INNER_SIZE_INVALID); + + self->thread_nr_ = self->UpdateThread(TC_PTYPE(PrimType_OneHot), one_hot->inner_size_, one_hot->outer_size_, + NNACLGetElementNum(self->out_[OUTPUT_INDEX]), self->thread_nr_); + return NNACL_OK; +} + +KernelBase *CreateOneHot(OpParameter *param, int data_type) { + OneHotStruct *one_hot = (OneHotStruct *)malloc(sizeof(OneHotStruct)); + NNACL_CHECK_NULL_RETURN_NULL(one_hot); + one_hot->base_.Release = DefaultRelease; + one_hot->base_.Prepare = OneHotPrepare; + one_hot->base_.Resize = OneHotResize; + one_hot->base_.Compute = OneHotCompute; + return (KernelBase *)one_hot; +} + +REG_KERNEL_CREATOR(PrimType_OneHot, kNumberTypeInt32, CreateOneHot) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/one_hot.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/one_hot.h new file mode 100644 index 0000000000000000000000000000000000000000..d9defb49111cd13d2e5a006b076abc3453a0759e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/one_hot.h @@ -0,0 +1,37 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_ONE_HOT_H_ +#define NNACL_KERNEL_ONE_HOT_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct { + KernelBase base_; + int axis_; + int depth_; + int outer_size_; + int inner_size_; + bool support_neg_index_; + float on_value_; + float off_value_; +} OneHotStruct; + +KernelBase *CreateOneHot(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_ONE_HOT_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/ones_like.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/ones_like.c new file mode 100644 index 0000000000000000000000000000000000000000..19840fb2f488e8d9eb358164e4b4dd8b91228c3e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/ones_like.c @@ -0,0 +1,67 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/ones_like.h" +#include "nnacl/kernel/default_kernel_base.h" +#include "nnacl/tensor_c_utils.h" + +#define ApproximateOnesLike(output, data_size) \ + for (size_t i = 0; i < data_size; ++i) { \ + output[i] = 1; \ + } + +int OnesLikeCompute(KernelBase *self) { + NNACL_CHECK_NULL_RETURN_ERR(self); + TensorC *output_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + void *output_ptr = output_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_ptr); + size_t num = (size_t)NNACLGetElementNum(output_tensor); + + if (output_tensor->data_type_ == kNumberTypeFloat32) { + float *output = (float *)output_ptr; + ApproximateOnesLike(output, num); + return NNACL_OK; + } +#ifdef ENABLE_FP16 + if (output_tensor->data_type_ == kNumberTypeFloat16) { + float16_t *output = (float16_t *)output_ptr; + ApproximateOnesLike(output, num); + return NNACL_OK; + } +#endif + if (output_tensor->data_type_ == kNumberTypeInt32) { + int *output = (int *)output_ptr; + ApproximateOnesLike(output, num); + return NNACL_OK; + } + return NNACL_UNSUPPORTED_DATA_TYPE; +} + +KernelBase *CreateOnesLike(OpParameter *param, int data_type) { + OnesLikeStruct *ones_like = (OnesLikeStruct *)malloc(sizeof(OnesLikeStruct)); + NNACL_CHECK_NULL_RETURN_NULL(ones_like); + ones_like->data_type_ = data_type; + ones_like->base_.Release = DefaultRelease; + ones_like->base_.Prepare = DefaultPrepare1In1Out; + ones_like->base_.Resize = DefaultResize; + ones_like->base_.Compute = OnesLikeCompute; + return (KernelBase *)ones_like; +} + +REG_KERNEL_CREATOR(PrimType_OnesLike, kNumberTypeInt32, CreateOnesLike) +REG_KERNEL_CREATOR(PrimType_OnesLike, kNumberTypeFloat32, CreateOnesLike) +REG_KERNEL_CREATOR(PrimType_OnesLike, kNumberTypeFloat16, CreateOnesLike) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/ones_like.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/ones_like.h new file mode 100644 index 0000000000000000000000000000000000000000..d332d4ffcb59c2ef71d36b3188ab68c01069a6c3 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/ones_like.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_ONES_LIKE_H_ +#define NNACL_KERNEL_ONES_LIKE_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct OnesLikeStruct { + KernelBase base_; + int data_type_; +} OnesLikeStruct; + +KernelBase *CreateOnesLike(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_ONES_LIKE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/pad.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/pad.c new file mode 100644 index 0000000000000000000000000000000000000000..3657eb233516d5cf11dfdeb7b921349a1bfb9118 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/pad.c @@ -0,0 +1,406 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/pad.h" +#include "nnacl/kernel/default_kernel_base.h" +#include "nnacl/tensor_c_utils.h" +#include "nnacl/nnacl_common.h" +#include "nnacl/common_func.h" +#ifdef ENABLE_FP16 +#include "nnacl/fp16/pad_fp16.h" +#endif +#include "nnacl/fp32/pad_fp32.h" + +int PadInitMirrorPadBlock(PadStruct *pad) { + int left_pads[DEFAULT_PAD_NDIMS] = {0}; + for (size_t i = 0; i < DEFAULT_PAD_NDIMS; ++i) { + left_pads[i] = pad->paddings_[Num2 * i]; + } + + int input_separate_dims[DEFAULT_PAD_NDIMS] = {0}; + int output_separate_dims[DEFAULT_PAD_NDIMS] = {0}; + int separate_offset[DEFAULT_PAD_NDIMS] = {0}; + int separate_size = 0; + + /* init separate dims */ + for (size_t i = 0; i < DEFAULT_PAD_NDIMS; ++i) { + input_separate_dims[separate_size] = pad->in_[i]; + output_separate_dims[separate_size] = pad->out_[i]; + separate_offset[separate_size] = left_pads[i]; + separate_size++; + } + + /* init separate stride */ + int output_separate_stride[DEFAULT_PAD_NDIMS] = {0}; + (void)GetStride(output_separate_stride, output_separate_dims, separate_size); + int remain_stride_size = 0; + int remain_size = 1; + int right_pads[DEFAULT_PAD_NDIMS] = {0}; + for (size_t i = 0; i < DEFAULT_PAD_NDIMS; i++) { + right_pads[i] = output_separate_dims[i] - input_separate_dims[i] - separate_offset[i]; + } + + /* init pad region */ + int pad_region[DEFAULT_PAD_NDIMS] = {0}; + int pad_region_size = 0; + for (int i = remain_stride_size; i < separate_size; ++i) { + int r = 1; + r = (separate_offset[i] > 0) ? (r + 1) : r; + r = (right_pads[i] > 0) ? (r + 1) : r; + pad_region[pad_region_size++] = r; + } + int pad_region_stride[DEFAULT_PAD_NDIMS] = {0}; + int region_size = GetStride(pad_region_stride, pad_region, pad_region_size); + + /* init mirror block info */ + int max_block_size = remain_size * region_size * sizeof(MirrorPadBlock); + pad->mirror_pad_block_ = (MirrorPadBlock *)pad->base_.env_->Alloc(pad->base_.env_->allocator_, max_block_size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(pad->mirror_pad_block_); + + // 0: center, 1: left, 2: right + int pad_cord[DEFAULT_PAD_NDIMS] = {0}; + + for (int pos = 0; pos < remain_size; ++pos) { + const int dst_basic_offset = 0; + for (int index = 1; index < region_size; ++index) { + int dst_offset = dst_basic_offset; + int value = index; + for (size_t i = 0; i < pad_region_size && pad_region_stride[i] != 0; ++i) { + NNACL_CHECK_ZERO_RETURN_ERR(pad_region_stride[i]); + pad_cord[i] = value / pad_region_stride[i]; + value = value % pad_region_stride[i]; + } + MirrorPadBlock block; + const int size_offset = DEFAULT_PAD_NDIMS - pad_region_size; + for (size_t i = 0; i < pad_region_size; ++i) { + int di = size_offset + i; + int si = remain_stride_size + i; + if (di >= DEFAULT_PAD_NDIMS) { + continue; + } + switch (pad_cord[i]) { + case Num0: + dst_offset += separate_offset[si] * output_separate_stride[si]; + block.size_[di] = input_separate_dims[si]; + block.out_stride_[di] = output_separate_stride[si]; + break; + case Num2: + dst_offset += (separate_offset[si] + input_separate_dims[si]) * output_separate_stride[si]; + block.size_[di] = right_pads[si]; + block.out_stride_[di] = output_separate_stride[si]; + break; + case Num1: + if (separate_offset[si] > 0) { + block.size_[di] = separate_offset[si]; + block.out_stride_[di] = output_separate_stride[si]; + } else { + dst_offset += (separate_offset[si] + input_separate_dims[si]) * output_separate_stride[si]; + block.size_[di] = right_pads[si]; + block.out_stride_[di] = output_separate_stride[si]; + } + break; + default: + break; + } + } + block.out_offset_ = dst_offset; + pad->mirror_pad_block_[pad->mirror_pad_block_size_++] = block; + } + } + return NNACL_OK; +} + +int PadExtendDims(int *dims, const int *origin_dims, int max_dim, int origin_dim, int init_value) { + NNACL_CHECK_NULL_RETURN_ERR(dims); + NNACL_CHECK_NULL_RETURN_ERR(origin_dims); + for (int i = 0; i < max_dim - origin_dim; ++i) { + dims[i] = init_value; + } + for (int i = max_dim - origin_dim; i < max_dim; ++i) { + dims[i] = origin_dims[i - (max_dim - origin_dim)]; + } + return NNACL_OK; +} + +int PadImpl(void *cdata, int task_id, float l, float r) { + PadStruct *pad = (PadStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(pad); + void *input = pad->base_.in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(input); + void *output = pad->base_.out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(output); + + if (pad->data_type_ == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + PadFp16(input, output, pad->in_, pad->out_, pad->paddings_, task_id, pad->base_.thread_nr_); +#endif + } else { + Pad((float *)input, (float *)output, pad->in_, pad->out_, pad->paddings_, task_id, pad->base_.thread_nr_); + } + return NNACL_OK; +} + +int PadFastMirrorRunImpl(PadStruct *pad, int task_id) { + void *in = pad->base_.in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(in); + void *out = pad->base_.out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(out); + + /* copy center part */ + if (pad->data_type_ == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + PadFp16((float16_t *)in, (float16_t *)out, pad->in_, pad->out_, pad->paddings_, task_id, pad->base_.thread_nr_); +#endif + } else { + Pad((float *)in, (float *)out, pad->in_, pad->out_, pad->paddings_, task_id, pad->base_.thread_nr_); + } + + /* calculate region part */ + for (int i = task_id; i < pad->mirror_pad_block_size_; i += pad->base_.thread_nr_) { + MirrorPadBlock *block = &pad->mirror_pad_block_[i]; + for (int a = 0; a < block->size_[FIRST_INPUT]; a++) { + int out_a_index = block->out_offset_ + a * block->out_stride_[FIRST_INPUT]; + for (int b = 0; b < block->size_[SECOND_INPUT]; b++) { + int out_b_index = out_a_index + b * block->out_stride_[SECOND_INPUT]; + for (int c = 0; c < block->size_[THIRD_INPUT]; ++c) { + int out_c_index = out_b_index + c * block->out_stride_[THIRD_INPUT]; + for (int d = 0; d < block->size_[FOURTH_INPUT]; ++d) { + int out_d_index = out_c_index + d * block->out_stride_[FOURTH_INPUT]; + for (int e = 0; e < block->size_[FIFTH_INPUT]; ++e) { + int start_index = out_d_index + e * block->out_stride_[FIFTH_INPUT]; + int end_index = start_index + block->size_[SIXTH_INPUT]; + if (pad->data_type_ == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + MirrorPadFp16(in, out, pad->in_, pad->in_strides_, pad->out_strides_, pad->paddings_, + pad->mirror_offset_, start_index, end_index); +#endif + } else { + MirrorPad(in, out, pad->in_, pad->in_strides_, pad->out_strides_, pad->paddings_, pad->mirror_offset_, + start_index, end_index); + } + } + } + } + } + } + } + return NNACL_OK; +} + +int MirrorPadImpl(void *cdata, int task_id, float l, float r) { + PadStruct *pad = (PadStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(pad); + + /* Fast Mirror pad */ + if (pad->mirror_pad_block_size_ != 0) { + return PadFastMirrorRunImpl(pad, task_id); + } + + TensorC *input = pad->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + void *input_data = input->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_data); + TensorC *output = pad->base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output); + void *output_data = output->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_data); + + /* Common Mirror pad */ + int unit = UP_DIV(NNACLGetElementNum(output), pad->base_.thread_nr_); + int begin = unit * task_id; + int end = NNACL_MIN(begin + unit, NNACLGetElementNum(output)); + if (pad->data_type_ == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + MirrorPadFp16((float16_t *)input_data, (float16_t *)output_data, pad->in_, pad->in_strides_, pad->out_strides_, + pad->paddings_, pad->mirror_offset_, begin, end); +#endif + } else { + MirrorPad((float *)input_data, (float *)output_data, pad->in_, pad->in_strides_, pad->out_strides_, pad->paddings_, + pad->mirror_offset_, begin, end); + } + return NNACL_OK; +} + +int PadCheckPaddings(const int *paddings, int length, const int *input_shape, int mode) { + NNACL_CHECK_NULL_RETURN_ERR(paddings); + NNACL_CHECK_NULL_RETURN_ERR(input_shape); + int offset = mode == PaddingMode_Symmetric ? 0 : 1; + for (int i = 0; i < length; ++i) { + int max_valid = input_shape[i] - offset; + if (paddings[i * Num2] > max_valid) { + return NNACL_PAD_MIRROR_PAD_SIZE_INVALID; + } + if (paddings[i * Num2 + 1] > max_valid) { + return NNACL_PAD_MIRROR_PAD_SIZE_INVALID; + } + } + return NNACL_OK; +} + +int PadCopyPaddingFromInput(PadStruct *pad) { + TensorC *input_tensor = pad->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + TensorC *padding_tensor = pad->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(padding_tensor); + int *padding_data = padding_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(padding_data); + + (void)PadExtendDims(pad->in_, input_tensor->shape_, DEFAULT_PAD_NDIMS, input_tensor->shape_size_, 1); + (void)PadExtendDims(pad->paddings_, padding_data, MAX_PAD_SIZE, NNACLGetElementNum(padding_tensor), 0); + pad->paddings_size_ = MAX_PAD_SIZE; + + return NNACL_OK; +} + +void PadCalculateStrides(PadStruct *pad) { + pad->in_strides_[DEFAULT_PAD_NDIMS - 1] = 1; + for (int i = DEFAULT_PAD_NDIMS - Num2; i >= 0; --i) { + pad->in_strides_[i] = pad->in_[i + 1] * pad->in_strides_[i + 1]; + } + for (int i = 0; i < DEFAULT_PAD_NDIMS; ++i) { + pad->out_[i] = pad->in_[i] + pad->paddings_[i * Num2] + pad->paddings_[i * Num2 + 1]; + } + pad->out_strides_[DEFAULT_PAD_NDIMS - 1] = 1; + for (int i = DEFAULT_PAD_NDIMS - Num2; i >= 0; --i) { + pad->out_strides_[i] = pad->out_[i + 1] * pad->out_strides_[i + 1]; + } +} + +int PadHandleMirrorPad(PadStruct *pad) { + pad->mirror_offset_ = pad->pad_mode_ == PaddingMode_Reflect ? 1 : 0; + (void)PadCheckPaddings(pad->paddings_, DEFAULT_PAD_NDIMS, pad->in_, pad->pad_mode_); + PadCalculateStrides(pad); + return PadInitMirrorPadBlock(pad); +} + +int PadCompute(KernelBase *self) { + PadStruct *pad = (PadStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(pad); + + if (self->in_size_ == THREE_TENSOR) { + TensorC *pad_value_tensor = self->in_[THIRD_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(pad_value_tensor); + NNACL_CHECK_FALSE(NNACLGetElementNum(pad_value_tensor) != 1, NNACL_PAD_PADDING_VALID_INVALID); + void *pad_valud = pad_value_tensor->data_; + if (pad->data_type_ == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + pad->constant_value_ = ((float16_t *)pad_valud)[Index0]; +#endif + } else { + pad->constant_value_ = ((float *)pad_valud)[Index0]; + } + } + + int ret = PadCopyPaddingFromInput(pad); + if (ret != NNACL_OK) { + return ret; + } + + if (pad->pad_mode_ == PaddingMode_Constant) { + TensorC *output = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output); + size_t output_size = NNACLGetElementNum(output); + void *output_data = output->data_; + if (fabsf(pad->constant_value_ - 0.0f) < 1e-5) { + memset(output_data, 0, output_size * (int)DataTypeCSize(pad->data_type_)); + } else { + for (size_t i = 0; i < output_size; ++i) { + if (pad->data_type_ == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + ((float16_t *)output_data)[i] = pad->constant_value_; +#endif + } else { + ((float *)output_data)[i] = pad->constant_value_; + } + } + } + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, PadImpl, self, self->thread_nr_); + return ret; + } + + /* not constant pad mod using mirror pad algorithm */ + ret = PadHandleMirrorPad(pad); + if (ret != NNACL_OK) { + return ret; + } + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, MirrorPadImpl, self, self->thread_nr_); + + self->env_->Free(self->env_->allocator_, pad->mirror_pad_block_); + pad->mirror_pad_block_ = NULL; + pad->mirror_pad_block_size_ = 0; + return ret; +} + +int PadResize(KernelBase *self) { + PadStruct *pad = (PadStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(pad); + TensorC *input = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + TensorC *padding = self->in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(padding); + TensorC *output = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output); + + int rank = input->shape_size_; + NNACL_CHECK_FALSE(input->shape_size_ > DEFAULT_PAD_NDIMS, NNACL_PAD_SHAPE_INVALID); + NNACL_CHECK_FALSE(NNACLGetElementNum(padding) != rank + rank, NNACL_PAD_SHAPE_INVALID); + + if (pad->pad_mode_ == PaddingMode_Constant) { + (void)PadExtendDims(pad->in_, input->shape_, DEFAULT_PAD_NDIMS, rank, 1); + (void)PadExtendDims(pad->out_, output->shape_, DEFAULT_PAD_NDIMS, rank, 1); + + if (pad->paddings_size_ < MAX_PAD_SIZE) { + int ori_paddings[MAX_PAD_SIZE]; + memcpy(ori_paddings, pad->paddings_, MAX_PAD_SIZE * sizeof(int)); + (void)PadExtendDims(pad->paddings_, ori_paddings, MAX_PAD_SIZE, pad->paddings_size_, 0); + pad->paddings_size_ = MAX_PAD_SIZE; + } + } + return NNACL_OK; +} + +int PadPrepare(KernelBase *self) { + NNACL_CHECK_TRUE_RET(self->in_size_ == TWO_TENSOR || self->in_size_ == THREE_TENSOR, NNACL_ERR); + NNACL_CHECK_TRUE_RET(self->out_size_ == ONE_TENSOR, NNACL_ERR); + TensorC *input = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + NNACL_CHECK_FALSE(input->data_type_ != kNumberTypeFloat32 && input->data_type_ != kNumberTypeFloat16, NNACL_ERR); + return NNACL_OK; +} + +KernelBase *CreatePad(OpParameter *param, int data_type) { + PadStruct *pad = (PadStruct *)malloc(sizeof(PadStruct)); + NNACL_CHECK_NULL_RETURN_NULL(pad); + memset(pad, 0, sizeof(PadStruct)); + + pad->data_type_ = data_type; + + PadParameter *pad_param = (PadParameter *)param; + pad->pad_mode_ = pad_param->pad_mode_; + pad->constant_value_ = pad_param->constant_value_; + pad->paddings_size_ = pad_param->padding_length; + memcpy(pad->paddings_, pad_param->paddings_, MAX_PAD_SIZE * sizeof(int)); + + pad->base_.Release = DefaultRelease; + pad->base_.Prepare = PadPrepare; + pad->base_.Resize = PadResize; + pad->base_.Compute = PadCompute; + return (KernelBase *)pad; +} + +REG_KERNEL_CREATOR(PrimType_PadFusion, kNumberTypeFloat32, CreatePad) +REG_KERNEL_CREATOR(PrimType_PadFusion, kNumberTypeFloat16, CreatePad) diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/delegate_allocator.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/pad.h similarity index 42% rename from mindspore-lite/src/extendrt/delegate/ascend_native/delegate_allocator.h rename to mindspore-lite/ops/kernel/cpu/nnacl/kernel/pad.h index d560665bc9d519790872d6f761bff60558d0c68a..dc6f476baac640dcbb913abc2cf9ca12bdeb7490 100644 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/delegate_allocator.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/pad.h @@ -1,3 +1,4 @@ + /** * Copyright 2023 Huawei Technologies Co., Ltd * @@ -13,26 +14,38 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_DELEGATE_ALLOCATOR_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_DELEGATE_ALLOCATOR_H_ -#include "include/api/allocator.h" +#ifndef NNACL_KERNEL_PAD_H_ +#define NNACL_KERNEL_PAD_H_ + +#include +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" +#include "nnacl/pad_parameter.h" -namespace mindspore { -class DelegateAllocator : public Allocator { - public: - explicit DelegateAllocator(void *stream) : stream_(stream) {} - void *Malloc(size_t size) override; - void Free(void *ptr) override; - int RefCount(void *ptr) override { return 0; }; - int SetRefCount(void *ptr, int ref_count) override { return 0; } - int IncRefCount(void *ptr, int ref_count) override { return 0; } - int DecRefCount(void *ptr, int ref_count) override { return 0; } +typedef struct MirrorPadBlock { + int out_offset_; + int out_stride_[DEFAULT_PAD_NDIMS]; + int size_[DEFAULT_PAD_NDIMS]; +} MirrorPadBlock; - private: - void *stream_{nullptr}; -}; +typedef struct PadStruct { + KernelBase base_; + int data_type_; + int mirror_offset_; + float constant_value_; + int pad_mode_; + int paddings_[MAX_PAD_SIZE]; + int paddings_size_; + int in_[DEFAULT_PAD_NDIMS]; + int out_[DEFAULT_PAD_NDIMS]; + int in_strides_[DEFAULT_PAD_NDIMS]; + int out_strides_[DEFAULT_PAD_NDIMS]; + MirrorPadBlock *mirror_pad_block_; + int mirror_pad_block_size_; +} PadStruct; -} // namespace mindspore +KernelBase *CreatePad(OpParameter *param, int data_type); -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_DELEGATE_ALLOCATOR_H_ +#endif // NNACL_KERNEL_PAD_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/pooling.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/pooling.c new file mode 100644 index 0000000000000000000000000000000000000000..17d5f2edb046f73399b59966e442ee730d34e228 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/pooling.c @@ -0,0 +1,159 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/pooling.h" +#include +#include "nnacl/pooling_parameter.h" +#include "nnacl/fp32/pooling_fp32.h" +#include "nnacl/tensor_c_utils.h" +#include "nnacl/kernel/default_kernel_base.h" +#ifdef ENABLE_FP16 +#include "nnacl/fp16/pooling_fp16.h" +#endif + +int PoolingF16RunImpl(PoolingStruct *pooling, int task_id) { +#ifdef ENABLE_FP16 + PoolingParameter *param = (PoolingParameter *)pooling->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + + float16_t *input_ptr = (float16_t *)pooling->base_.in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_ptr); + float16_t *output_ptr = (float16_t *)pooling->base_.out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_ptr); + + if (param->pool_mode_ == PoolMode_MaxPool) { + MaxPoolingFp16(input_ptr, output_ptr, param, &pooling->compute_, task_id, pooling->base_.thread_nr_); + return NNACL_OK; + } else { + return AvgPoolingFp16(input_ptr, output_ptr, param, &pooling->compute_, task_id, pooling->base_.thread_nr_); + } +#endif + return NNACL_DISABLE_FP16; +} + +int PoolingRunImpl(PoolingStruct *pooling, int task_id) { + PoolingParameter *param = (PoolingParameter *)pooling->base_.param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + + TensorC *input_tensor = pooling->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + float *input_ptr = (float *)input_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_ptr); + float *output_ptr = (float *)pooling->base_.out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_ptr); + + if (input_tensor->format_ == Format_NC4HW4) { + if (param->pool_mode_ == PoolMode_MaxPool) { + return MaxPoolingFromNC4HW4ToNHWC(input_ptr, output_ptr, param, &pooling->compute_, task_id, + pooling->base_.thread_nr_); + } else { + return AvgPoolingFromNC4HW4ToNHWC(input_ptr, output_ptr, param, &pooling->compute_, task_id, + pooling->base_.thread_nr_); + } + } else if (input_tensor->format_ == Format_NHWC) { + if (param->pool_mode_ == PoolMode_MaxPool) { + return MaxPooling(input_ptr, output_ptr, param, &pooling->compute_, task_id, pooling->base_.thread_nr_); + } else { + return AvgPooling(input_ptr, output_ptr, param, &pooling->compute_, task_id, pooling->base_.thread_nr_); + } + } + + return NNACL_UNSUPPORTED_FORMAT; +} + +int PoolingImpl(void *cdata, int task_id, float l, float r) { + PoolingStruct *pooling = (PoolingStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(cdata); + if (pooling->data_type_ == kNumberTypeFloat16) { + return PoolingF16RunImpl(pooling, task_id); + } else if (pooling->data_type_ == kNumberTypeFloat32) { + return PoolingRunImpl(pooling, task_id); + } + return NNACL_UNSUPPORTED_DATA_TYPE; +} + +int PoolingCompute(KernelBase *self) { + return self->env_->ParallelLaunch(self->env_->thread_pool_, PoolingImpl, self, self->thread_nr_); +} + +int PoolingResize(KernelBase *self) { + PoolingStruct *pooling = (PoolingStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(pooling); + TensorC *in_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in_tensor); + TensorC *out_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(out_tensor); + + PoolingComputeParam *compute = &pooling->compute_; + PoolingParameter *param = (PoolingParameter *)self->param_; + + compute->input_batch_ = NNACLGetBatch(in_tensor); + compute->input_channel_ = NNACLGetChannel(in_tensor); + compute->input_h_ = NNACLGetHeight(in_tensor); + compute->input_w_ = NNACLGetWidth(in_tensor); + compute->output_batch_ = NNACLGetBatch(out_tensor); + compute->output_channel_ = NNACLGetChannel(out_tensor); + compute->output_h_ = NNACLGetHeight(out_tensor); + compute->output_w_ = NNACLGetWidth(out_tensor); + compute->window_h_ = param->window_h_; + compute->window_w_ = param->window_w_; + if (param->global_) { + compute->window_h_ = compute->input_h_; + compute->window_w_ = compute->input_w_; + } + return NNACL_OK; +} + +int PoolingPrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < 1, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < 1, NNACL_ERR); + + PoolingStruct *pooling = (PoolingStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(pooling); + PoolingParameter *param = (PoolingParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + + float minf = pooling->data_type_ == kNumberTypeFloat32 ? -FLT_MAX : -FLT16_MAX; + float maxf = pooling->data_type_ == kNumberTypeFloat32 ? FLT_MAX : FLT16_MAX; + + if (param->act_type_ == ActType_Relu) { + minf = 0.f; + } else if (param->act_type_ == ActType_Relu6) { + minf = 0.f; + maxf = 6.f; + } + pooling->compute_.minf = minf; + pooling->compute_.maxf = maxf; + + return NNACL_OK; +} + +KernelBase *CreatePooling(OpParameter *param, int data_type) { + PoolingStruct *pooling = (PoolingStruct *)malloc(sizeof(PoolingStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(pooling); + memset(pooling, 0, sizeof(PoolingStruct)); + pooling->data_type_ = data_type; + pooling->base_.Release = DefaultRelease; + pooling->base_.Prepare = PoolingPrepare; + pooling->base_.Resize = PoolingResize; + pooling->base_.Compute = PoolingCompute; + return (KernelBase *)pooling; +} + +REG_KERNEL_CREATOR(PrimType_AvgPoolFusion, kNumberTypeFloat16, CreatePooling) +REG_KERNEL_CREATOR(PrimType_MaxPoolFusion, kNumberTypeFloat16, CreatePooling) +REG_KERNEL_CREATOR(PrimType_AvgPoolFusion, kNumberTypeFloat32, CreatePooling) +REG_KERNEL_CREATOR(PrimType_MaxPoolFusion, kNumberTypeFloat32, CreatePooling) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/pooling.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/pooling.h new file mode 100644 index 0000000000000000000000000000000000000000..356d0f95d7b355d8ee7c2e917d70fbde565fa33d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/pooling.h @@ -0,0 +1,54 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_POOLING_H_ +#define NNACL_KERNEL_POOLING_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct PoolingComputeParam { + int input_w_; + int input_h_; + int input_batch_; + int input_channel_; + int output_w_; + int output_h_; + int output_batch_; + int output_channel_; + int window_w_; + int window_h_; + float minf; + float maxf; +} PoolingComputeParam; + +typedef struct Pooling3DComputeParam { + PoolingComputeParam pooling_compute_param_; + int input_d_; + int output_d_; + int window_d_; +} Pooling3DComputeParam; + +typedef struct PoolingStruct { + KernelBase base_; + PoolingComputeParam compute_; + int data_type_; +} PoolingStruct; + +KernelBase *CreatePooling(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_POOLING_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/pow.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/pow.c new file mode 100644 index 0000000000000000000000000000000000000000..3a081855332e7b4ea8c280000348947834a14f75 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/pow.c @@ -0,0 +1,79 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/pow.h" +#include "nnacl/kernel/default_kernel_base.h" +#include "nnacl/tensor_c_utils.h" +#include "nnacl/nnacl_common.h" +#include "nnacl/fp32/power_fp32.h" +#ifdef ENABLE_FP16 +#include "nnacl/fp16/power_fp16.h" +#endif + +int PowImpl(void *cdata, int task_id, float l, float r) { + PowStruct *pow = (PowStruct *)cdata; + TensorC *input0 = pow->base_.in_[FIRST_INPUT]; + TensorC *input1 = pow->base_.in_[SECOND_INPUT]; + TensorC *output = pow->base_.out_[OUTPUT_INDEX]; + + int size = NNACLGetElementNum(input0); + int stride = UP_DIV(size, pow->base_.thread_nr_); + int len = MSMIN(stride, size - stride * task_id); + if (len <= 0) { + return NNACL_OK; + } + bool broadcast = !ShapeEqual(input0->shape_, input0->shape_size_, input1->shape_, input1->shape_size_); + float scale = ((PowParameter *)pow->base_.param_)->scale_; + float shift = ((PowParameter *)pow->base_.param_)->shift_; + int task_stride = stride * task_id; + + uint8_t *exp_addr = (uint8_t *)input1->data_; + void *cur_exp = NULL; + if (broadcast) { + cur_exp = exp_addr; + } else { + cur_exp = exp_addr + task_stride * DataTypeCSize(pow->data_type_); + } + + if (pow->data_type_ == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + return PowerFp16((float16_t *)input0->data_ + task_stride, (float16_t *)cur_exp, + (float16_t *)output->data_ + task_stride, len, scale, shift, broadcast); +#endif + } else if (pow->data_type_ == kNumberTypeFloat32) { + return Power((float *)input0->data_ + task_stride, (float *)cur_exp, (float *)output->data_ + task_stride, len, + scale, shift, broadcast); + } + return NNACL_POW_INVALID_DATA_TYPE; +} + +int PowCompute(KernelBase *self) { + return self->env_->ParallelLaunch(self->env_->thread_pool_, PowImpl, self, self->thread_nr_); +} + +KernelBase *CreatePow(OpParameter *param, int data_type) { + PowStruct *pow = (PowStruct *)malloc(sizeof(PowStruct)); + NNACL_CHECK_NULL_RETURN_NULL(pow); + pow->data_type_ = data_type; + pow->base_.Release = DefaultRelease; + pow->base_.Prepare = DefaultPrepare2In1Out; + pow->base_.Resize = DefaultResize; + pow->base_.Compute = PowCompute; + return (KernelBase *)pow; +} + +REG_KERNEL_CREATOR(PrimType_PowFusion, kNumberTypeFloat32, CreatePow) +REG_KERNEL_CREATOR(PrimType_PowFusion, kNumberTypeFloat16, CreatePow) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/pow.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/pow.h new file mode 100644 index 0000000000000000000000000000000000000000..e87f6d694576fe3f87fb81a109085c5fc67b63f5 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/pow.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_POW_H_ +#define NNACL_KERNEL_POW_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct PowStruct { + KernelBase base_; + int data_type_; +} PowStruct; + +KernelBase *CreatePow(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_POW_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/prelu.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/prelu.c new file mode 100644 index 0000000000000000000000000000000000000000..3fffd33fbe38e50495b9d85e561900cd49fa656b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/prelu.c @@ -0,0 +1,111 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/prelu.h" +#include "nnacl/kernel/default_kernel_base.h" +#include "nnacl/tensor_c_utils.h" +#include "nnacl/fp32/prelu_fp32.h" +#ifdef ENABLE_FP16 +#include "nnacl/fp16/prelu_fp16.h" +#endif + +int PReluRun(void *cdata, int task_id, float l, float r) { + PReluStruct *prelu = (PReluStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(prelu); + + int thread_num = prelu->base_.thread_nr_; + int num = prelu->channel_shared_ ? prelu->input_num_ : prelu->input_num_ / prelu->channel_num_; + int step = UP_DIV(num, thread_num); + int start = task_id * step; + int end = MSMIN(start + step, num); + + void *in_data = prelu->base_.in_[FIRST_INPUT]->data_; + void *out_data = prelu->base_.out_[OUTPUT_INDEX]->data_; + void *slope_data = prelu->base_.in_[SECOND_INPUT]->data_; + + if (prelu->data_type_ == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + if (prelu->channel_shared_) { + PReluShareChannelFp16((float16_t *)in_data, (float16_t *)out_data, ((float16_t *)slope_data)[0], start, end); + } else { + PReluFp16((float16_t *)in_data, (float16_t *)out_data, (float16_t *)slope_data, start, end, prelu->channel_num_); + } +#endif + } else { + if (prelu->channel_shared_) { + PReluShareChannel((float *)in_data, (float *)out_data, ((float *)slope_data)[0], start, end); + } else { + PRelu((float *)in_data, (float *)out_data, (float *)slope_data, start, end, prelu->channel_num_); + } + } + return NNACL_OK; +} + +int PReluPrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR); + return NNACL_OK; +} + +int PReluResize(KernelBase *self) { + PReluStruct *prelu = (PReluStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(prelu); + TensorC *input = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + prelu->input_num_ = NNACLGetElementNum(input); + prelu->channel_num_ = NNACLGetChannel(input); + return NNACL_OK; +} + +int PReluCompute(KernelBase *self) { + NNACL_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]); + NNACL_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]->data_); + NNACL_CHECK_NULL_RETURN_ERR(self->in_[SECOND_INPUT]); + NNACL_CHECK_NULL_RETURN_ERR(self->in_[SECOND_INPUT]->data_); + NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]); + NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]->data_); + PReluStruct *prelu = (PReluStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(prelu); + TensorC *input = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + TensorC *slope = self->in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(slope); + + int slope_num = NNACLGetElementNum(slope); + if (slope_num == Num1) { + prelu->channel_shared_ = true; + } else if (slope_num == NNACLGetChannel(input)) { + prelu->channel_shared_ = false; + } else { + return NNACL_PRELU_SLOPE_NUM_INVALID; + } + return self->env_->ParallelLaunch(self->env_->thread_pool_, PReluRun, self, self->thread_nr_); +} + +KernelBase *CreatePRelu(OpParameter *param, int data_type) { + PReluStruct *prelu = (PReluStruct *)malloc(sizeof(PReluStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(prelu); + memset(prelu, 0, sizeof(PReluStruct)); + prelu->data_type_ = data_type; + prelu->base_.Prepare = PReluPrepare; + prelu->base_.Resize = PReluResize; + prelu->base_.Compute = PReluCompute; + prelu->base_.Release = DefaultRelease; + return (KernelBase *)prelu; +} + +REG_KERNEL_CREATOR(PrimType_PReLUFusion, kNumberTypeFloat16, CreatePRelu) +REG_KERNEL_CREATOR(PrimType_PReLUFusion, kNumberTypeFloat32, CreatePRelu) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/prelu.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/prelu.h new file mode 100644 index 0000000000000000000000000000000000000000..f622f52adb4aa99c7ea3e2f3320e9c9e7e113554 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/prelu.h @@ -0,0 +1,34 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_PRELU_H_ +#define NNACL_KERNEL_PRELU_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct PReluStruct { + KernelBase base_; + int data_type_; + int input_num_; + int channel_num_; + bool channel_shared_; +} PReluStruct; + +KernelBase *CreatePRelu(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_PRELU_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/prior_box.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/prior_box.c new file mode 100644 index 0000000000000000000000000000000000000000..8380882ae240f05e040557921dd789d24aff9e3b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/prior_box.c @@ -0,0 +1,190 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/prior_box.h" +#include +#include "nnacl/kernel/default_kernel_base.h" +#include "nnacl/fp32/prior_box_fp32.h" +#include "nnacl/tensor_c_utils.h" + +int PriorBoxInitOutput(PriorBoxStruct *prior_box, const PriorBoxParameter *param, const float *different_aspect_ratios, + int different_aspect_ratios_size) { + for (int i = 0; i < prior_box->fmap_h_; i++) { + float cy = i + param->offset; + for (int j = 0; j < prior_box->fmap_w_; j++) { + float cx = j + param->offset; + for (int32_t k = 0; k < param->min_sizes_size; k++) { + float min = param->min_sizes[k]; + prior_box->output_[prior_box->output_size_++] = (cx - min / prior_box->step_w_ * 0.5f) / prior_box->fmap_w_; + prior_box->output_[prior_box->output_size_++] = (cy - min / prior_box->step_h_ * 0.5f) / prior_box->fmap_h_; + prior_box->output_[prior_box->output_size_++] = (cx + min / prior_box->step_w_ * 0.5f) / prior_box->fmap_w_; + prior_box->output_[prior_box->output_size_++] = (cy + min / prior_box->step_h_ * 0.5f) / prior_box->fmap_h_; + + if (param->max_sizes_size > 0) { + float max = param->max_sizes[k]; + NNACL_CHECK_FALSE(min * max <= 0, NNACL_PRIOR_BOX_VALUE_INVALID); + float prime = sqrt(min * max); + prior_box->output_[prior_box->output_size_++] = (cx - prime / prior_box->step_w_ * 0.5f) / prior_box->fmap_w_; + prior_box->output_[prior_box->output_size_++] = (cy - prime / prior_box->step_h_ * 0.5f) / prior_box->fmap_h_; + prior_box->output_[prior_box->output_size_++] = (cx + prime / prior_box->step_w_ * 0.5f) / prior_box->fmap_w_; + prior_box->output_[prior_box->output_size_++] = (cy + prime / prior_box->step_h_ * 0.5f) / prior_box->fmap_h_; + } + + for (int m = 0; m < different_aspect_ratios_size; m++) { + float v = different_aspect_ratios[m]; + if (fabs(v - 1.0f) < 1e-6) { + continue; + } + NNACL_CHECK_FALSE(v <= 0, NNACL_PRIOR_BOX_VALUE_INVALID); + float as_square_root = sqrt(v); + NNACL_CHECK_FALSE(as_square_root <= 0, NNACL_PRIOR_BOX_VALUE_INVALID); + float box_w = min * as_square_root; + float box_h = min / as_square_root; + prior_box->output_[prior_box->output_size_++] = (cx - box_w / prior_box->step_w_ * 0.5f) / prior_box->fmap_w_; + prior_box->output_[prior_box->output_size_++] = (cy - box_h / prior_box->step_h_ * 0.5f) / prior_box->fmap_h_; + prior_box->output_[prior_box->output_size_++] = (cx + box_w / prior_box->step_w_ * 0.5f) / prior_box->fmap_w_; + prior_box->output_[prior_box->output_size_++] = (cy + box_h / prior_box->step_h_ * 0.5f) / prior_box->fmap_h_; + } + } + } + } + return NNACL_OK; +} + +int RunPriorBox(void *cdata, int task_id, float l, float r) { + PriorBoxStruct *prior_box = (PriorBoxStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(prior_box); + TensorC *output_tensor = prior_box->base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + float *output_data = output_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_data); + return PriorBox(prior_box->output_, output_data, NNACLGetSize(output_tensor), task_id, prior_box->base_.thread_nr_); +} + +int PriorBoxRelease(KernelBase *self) { + PriorBoxStruct *prior_box = (PriorBoxStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(prior_box); + if (prior_box->output_ != NULL) { + self->env_->Free(self->env_->allocator_, prior_box->output_); + prior_box->output_ = NULL; + prior_box->output_size_ = 0; + } + return NNACL_OK; +} + +int PriorBoxResize(KernelBase *self) { + PriorBoxStruct *prior_box = (PriorBoxStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(prior_box); + PriorBoxParameter *param = (PriorBoxParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + TensorC *input0_tensor = prior_box->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input0_tensor); + TensorC *input1_tensor = prior_box->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input1_tensor); + TensorC *output_tensor = prior_box->base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + + prior_box->fmap_w_ = NNACLGetWidth(input0_tensor); + NNACL_CHECK_ZERO_RETURN_ERR(prior_box->fmap_w_); + prior_box->fmap_h_ = NNACLGetHeight(input1_tensor); + NNACL_CHECK_ZERO_RETURN_ERR(prior_box->fmap_h_); + const int image_w = param->image_size_w > 0 ? param->image_size_w : NNACLGetWidth(input1_tensor); + const int image_h = param->image_size_h > 0 ? param->image_size_h : NNACLGetHeight(input1_tensor); + + prior_box->step_w_ = param->step_w > 0.0f ? param->step_w : (float)(image_w) / prior_box->fmap_w_; + prior_box->step_h_ = param->step_h > 0.0f ? param->step_h : (float)(image_h) / prior_box->fmap_h_; + + float *different_aspect_ratios = + (float *)self->env_->Alloc(self->env_->allocator_, param->aspect_ratios_size * sizeof(float) * Num2); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(different_aspect_ratios); + different_aspect_ratios[Index0] = 1.0f; + int different_aspect_ratios_size = 1; + + float *aspect_ratios = param->aspect_ratios; + for (int32_t i = 0; i < param->aspect_ratios_size; i++) { + float ratio = aspect_ratios[i]; + + bool exist = false; + for (int k = 0; k < different_aspect_ratios_size; k++) { + if (fabs(ratio - different_aspect_ratios[k]) < 1e-6) { + exist = true; + } + } + + if (!exist) { + different_aspect_ratios[different_aspect_ratios_size++] = ratio; + if (param->flip) { + NNACL_CHECK_FALSE(fabs(ratio) <= 1e-5, NNACL_PRIOR_BOX_RATIO_INVALID); + different_aspect_ratios[different_aspect_ratios_size++] = 1.0f / ratio; + } + } + } + + PriorBoxRelease(self); + int size = Num4 + Num4 + different_aspect_ratios_size; + size = size * prior_box->fmap_h_ * prior_box->fmap_w_ * param->min_sizes_size; + size = size + UP_ROUND(NNACLGetHeight(output_tensor), COMM_SHAPE_SIZE); + size = size * sizeof(float); + NNACL_CHECK_MALLOC_SIZE(size); + prior_box->output_ = (float *)self->env_->Alloc(self->env_->allocator_, size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(prior_box->output_); + prior_box->output_size_ = 0; + + int ret = PriorBoxInitOutput(prior_box, param, different_aspect_ratios, different_aspect_ratios_size); + if (ret != NNACL_OK) { + return ret; + } + + // do clip + if (param->clip) { + for (int i = 0; i < prior_box->output_size_; i++) { + float item = prior_box->output_[i]; + if (item > 1.0f) { + item = 1.0f; + } + if (item < 0.0f) { + item = 0.0f; + } + } + } + + // variance + for (int i = 0; i < NNACLGetHeight(output_tensor) / COMM_SHAPE_SIZE; i++) { + for (int j = 0; j < COMM_SHAPE_SIZE; j++) { + prior_box->output_[prior_box->output_size_++] = param->variances[j]; + } + } + return NNACL_OK; +} + +int PriorBoxCompute(KernelBase *self) { + return self->env_->ParallelLaunch(self->env_->thread_pool_, RunPriorBox, self, self->thread_nr_); +} + +KernelBase *CreatePriorBox(OpParameter *param, int data_type) { + PriorBoxStruct *prior_box = (PriorBoxStruct *)malloc(sizeof(PriorBoxStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(prior_box); + memset(prior_box, 0, sizeof(PriorBoxStruct)); + + prior_box->base_.Prepare = DefaultPrepare2In1Out; + prior_box->base_.Resize = PriorBoxResize; + prior_box->base_.Release = PriorBoxRelease; + prior_box->base_.Compute = PriorBoxCompute; + return (KernelBase *)prior_box; +} + +REG_KERNEL_CREATOR(PrimType_PriorBox, kNumberTypeFloat32, CreatePriorBox) +REG_KERNEL_CREATOR(PrimType_PriorBox, kNumberTypeInt8, CreatePriorBox) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/prior_box.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/prior_box.h new file mode 100644 index 0000000000000000000000000000000000000000..06a6bd34bbb44344db488f34663c985af91e9d77 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/prior_box.h @@ -0,0 +1,36 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_PRIOR_BOX_H_ +#define NNACL_KERNEL_PRIOR_BOX_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct PriorBoxStruct { + KernelBase base_; + float *output_; + int output_size_; + int fmap_h_; + int fmap_w_; + float step_h_; + float step_w_; +} PriorBoxStruct; + +KernelBase *CreatePriorBox(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_PRIOR_BOX_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/ragged_range.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/ragged_range.c new file mode 100644 index 0000000000000000000000000000000000000000..25c66f0bcfecc4eda08435df2cf4761786d5648a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/ragged_range.c @@ -0,0 +1,74 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/ragged_range.h" +#include "nnacl/kernel/default_kernel_base.h" +#include "nnacl/fp32/ragged_range_fp32.h" +#ifdef ENABLE_FP16 +#include "nnacl/fp16/ragged_range_fp16.h" +#endif + +int RaggedRangeCompute(KernelBase *self) { + RaggedRangeStruct *ragged_range = (RaggedRangeStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(ragged_range); + + TensorC *input0 = self->in_[Index0]; + TensorC *input1 = self->in_[Index1]; + TensorC *input2 = self->in_[Index2]; + TensorC *output0 = self->out_[Index0]; + TensorC *output1 = self->out_[Index1]; + + if (input0->data_type_ == kNumberTypeFloat32) { + RaggedRangeFp32((float *)input0->data_, (float *)input1->data_, (float *)input2->data_, (int *)output0->data_, + (float *)output1->data_, ragged_range); + } else if (input0->data_type_ == kNumberTypeInt32) { + RaggedRangeInt((int *)input0->data_, (int *)input1->data_, (int *)input2->data_, (int *)output0->data_, + (int *)output1->data_, ragged_range); + } else if (input0->data_type_ == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + RaggedRangeFp16((float16_t *)input0->data_, (float16_t *)input1->data_, (float16_t *)input2->data_, + (int *)output0->data_, (float16_t *)output1->data_, ragged_range); +#endif + } else { + return NNACL_UNSUPPORTED_DATA_TYPE; + } + return NNACL_OK; +} + +int RaggedRangeResize(KernelBase *self) { + RaggedRangeStruct *ragged_range = (RaggedRangeStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(ragged_range); + + ragged_range->rows_ = self->out_[OUTPUT_INDEX]->shape_[Index0] - 1; + ragged_range->starts_is_scalar_ = self->in_[FIRST_INPUT]->shape_size_ == 0; + ragged_range->limits_is_scalar_ = self->in_[SECOND_INPUT]->shape_size_ == 0; + ragged_range->deltas_is_scalar_ = self->in_[THIRD_INPUT]->shape_size_ == 0; + return NNACL_OK; +} + +KernelBase *CreateRaggedRange(OpParameter *param, int data_type) { + RaggedRangeStruct *ragged_range = (RaggedRangeStruct *)malloc(sizeof(RaggedRangeStruct)); + NNACL_CHECK_NULL_RETURN_NULL(ragged_range); + ragged_range->base_.Release = DefaultRelease; + ragged_range->base_.Prepare = DefaultPrepare3In2Out; + ragged_range->base_.Resize = RaggedRangeResize; + ragged_range->base_.Compute = RaggedRangeCompute; + return (KernelBase *)ragged_range; +} + +REG_KERNEL_CREATOR(PrimType_RaggedRange, kNumberTypeInt32, CreateRaggedRange) +REG_KERNEL_CREATOR(PrimType_RaggedRange, kNumberTypeFloat16, CreateRaggedRange) +REG_KERNEL_CREATOR(PrimType_RaggedRange, kNumberTypeFloat32, CreateRaggedRange) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/ragged_range.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/ragged_range.h new file mode 100644 index 0000000000000000000000000000000000000000..881cc7966fef0b57ebe272b367baa34688738e89 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/ragged_range.h @@ -0,0 +1,35 @@ + +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_RAGGED_RANGE_H_ +#define NNACL_KERNEL_RAGGED_RANGE_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct RaggedRangeStruct { + KernelBase base_; + int rows_; + bool starts_is_scalar_; + bool limits_is_scalar_; + bool deltas_is_scalar_; +} RaggedRangeStruct; + +KernelBase *CreateRaggedRange(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_RAGGED_RANGE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/range.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/range.c new file mode 100644 index 0000000000000000000000000000000000000000..99579d1c6f7fd8993cbaa64c844cd02284d314aa --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/range.c @@ -0,0 +1,74 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/range.h" +#include "nnacl/kernel/default_kernel_base.h" +#include "nnacl/range_parameter.h" +#include "nnacl/fp32/range_fp32.h" +#include "nnacl/tensor_c_utils.h" +#ifdef ENABLE_FP16 +#include "nnacl/fp16/range_fp16.h" +#endif + +int RangeCompute(KernelBase *self) { + NNACL_CHECK_NULL_RETURN_ERR(self); + + TensorC *input = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + TensorC *output = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output); + int output_num = NNACLGetElementNum(output); + + if (self->in_size_ == THREE_TENSOR) { + TensorC *delta = self->in_[THIRD_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(delta); + + if (input->data_type_ == kNumberTypeFloat32) { + Range((float *)output->data_, *(float *)input->data_, *(float *)delta->data_, output_num); + } else if (input->data_type_ == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + RangeFp16((float16_t *)output->data_, *(float16_t *)input->data_, *(float16_t *)delta->data_, output_num); +#endif + } else if (input->data_type_ == kNumberTypeInt32) { + RangeInt((int *)output->data_, *(int *)input->data_, *(int *)delta->data_, output_num); + } else { + return NNACL_UNSUPPORTED_DATA_TYPE; + } + } else { + if (input->data_type_ == kNumberTypeInt32) { + RangeParameter *param = (RangeParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + RangeInt((int *)output->data_, param->start_, param->delta_, output_num); + } else { + return NNACL_UNSUPPORTED_DATA_TYPE; + } + } + return NNACL_OK; +} + +KernelBase *CreateRange(OpParameter *param, int data_type) { + RangeStruct *range = (RangeStruct *)malloc(sizeof(RangeStruct)); + NNACL_CHECK_NULL_RETURN_NULL(range); + range->base_.Release = DefaultRelease; + range->base_.Prepare = DefaultPrepare1In1Out; + range->base_.Resize = DefaultResize; + range->base_.Compute = RangeCompute; + return (KernelBase *)range; +} + +REG_KERNEL_CREATOR(PrimType_Range, kNumberTypeInt32, CreateRange) +REG_KERNEL_CREATOR(PrimType_Range, kNumberTypeFloat32, CreateRange) +REG_KERNEL_CREATOR(PrimType_Range, kNumberTypeFloat16, CreateRange) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/range.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/range.h new file mode 100644 index 0000000000000000000000000000000000000000..d919c8c98ac594c4a5774848c95fa7762c4f7e3c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/range.h @@ -0,0 +1,31 @@ + +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_RANGE_H_ +#define NNACL_KERNEL_RANGE_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct RangeStruct { + KernelBase base_; +} RangeStruct; + +KernelBase *CreateRange(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_RANGE_H_ diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_less_kernel.cc b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/rank.c similarity index 40% rename from mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_less_kernel.cc rename to mindspore-lite/ops/kernel/cpu/nnacl/kernel/rank.c index 620123fff3df787b95030e330d2039ae9f982813..aabfb228239e727138938fb1d42afe2dee8de4de 100644 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_less_kernel.cc +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/rank.c @@ -14,31 +14,31 @@ * limitations under the License. */ -#include "extendrt/delegate/ascend_native/ascend_native_less_kernel.h" -#include "extendrt/delegate/ascend_native/ascend_native_kernel_registry.h" -#include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "abstract/ops/primitive_infer_map.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h" - -namespace mindspore::kernel { -using mindspore::ops::kNameLess; - -int AscendNativeLessKernel::InferShape() { - if (out_tensors_[0]->shape().size() == 0) { - if (in_tensors_[0] != nullptr) out_tensors_[0]->set_shape(in_tensors_[0]->shape()); +#include "nnacl/kernel/rank.h" +#include "nnacl/kernel/default_kernel_base.h" + +int RankCompute(KernelBase *self) { + size_t rank = self->in_[FIRST_INPUT]->shape_size_; + void *output_data = self->out_[OUTPUT_INDEX]->data_; + if (self->in_[FIRST_INPUT]->data_type_ == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + *(float16_t *)output_data = (float16_t)rank; +#endif + } else { + *(float *)output_data = (float)rank; } - return kSuccess; + return NNACL_OK; } -int AscendNativeLessKernel::Prepare() { return kSuccess; } - -int AscendNativeLessKernel::Run() { - MS_LOG(INFO) << "AscendNativeLessKernel::Execute"; - - return kSuccess; +KernelBase *CreateRank(OpParameter *param, int data_type) { + RankStruct *rank = (RankStruct *)malloc(sizeof(RankStruct)); + NNACL_CHECK_NULL_RETURN_NULL(rank); + rank->base_.Release = DefaultRelease; + rank->base_.Prepare = DefaultPrepare1In1Out; + rank->base_.Resize = DefaultResize; + rank->base_.Compute = RankCompute; + return (KernelBase *)rank; } -int AscendNativeLessKernel::ReSize() { return kSuccess; } - -REGISTER_ASCEND_NATIVE_CREATOR(kNameLess, AscendNativeLessKernel) -} // namespace mindspore::kernel +REG_KERNEL_CREATOR(PrimType_Rank, kNumberTypeFloat32, CreateRank) +REG_KERNEL_CREATOR(PrimType_Rank, kNumberTypeFloat16, CreateRank) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/rank.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/rank.h new file mode 100644 index 0000000000000000000000000000000000000000..f2717edd99200e79bb8de45f2f3c7f92832ce161 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/rank.h @@ -0,0 +1,31 @@ + +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_RANK_H_ +#define NNACL_KERNEL_RANK_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct RankStruct { + KernelBase base_; +} RankStruct; + +KernelBase *CreateRank(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_RANK_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/reduce.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/reduce.c new file mode 100644 index 0000000000000000000000000000000000000000..42b2199ea4d1e4f64b292f307490c810a177a6da --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/reduce.c @@ -0,0 +1,434 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/reduce.h" +#include +#include "nnacl/fp32/reduce_fp32.h" +#include "nnacl/kernel/reshape.h" +#include "nnacl/nnacl_common.h" +#include "nnacl/tensor_c_utils.h" +#include "nnacl/kernel/default_kernel_base.h" + +void InitialReduceKernelList(KernelBase *base) { + ReduceStruct *reduce = (ReduceStruct *)base; + ReduceParameter *param = (ReduceParameter *)(base->param_); + + ReduceKernelList func_list[] = {{Reduce_Sum, ReduceSum, IntReduceSum, NULL, ReduceSumByLastAxis}, + {Reduce_Mean, ReduceMean, IntReduceMean, NULL, NULL}, + {Reduce_Max, ReduceMax, IntReduceMax, NULL, ReduceMaxByLastAxis}, + {Reduce_Min, ReduceMin, IntReduceMin, NULL, NULL}, + {Reduce_Prod, ReduceProd, IntReduceProd, NULL, NULL}, + {Reduce_SumSquare, ReduceSum, IntReduceSum, NULL, NULL}, + {Reduce_ASum, ReduceSum, IntReduceSum, NULL, NULL}, + {Reduce_All, NULL, NULL, ReduceAll, NULL}, + {Reduce_L2, ReduceL2Norm, NULL, NULL, NULL}}; + + size_t list_len = sizeof(func_list) / sizeof(ReduceKernelList); + for (size_t i = 0; i < list_len; ++i) { + if (param->mode_ == func_list[i].type_) { + reduce->compute_ = func_list[i]; + return; + } + } +} + +int CallReduceUnit(KernelBase *base, int task_id) { + ReduceStruct *reduce = (ReduceStruct *)base; + NNACL_CHECK_NULL_RETURN_ERR(reduce->src_data_); + NNACL_CHECK_NULL_RETURN_ERR(reduce->dst_data_); + + if (reduce->data_type_ == kNumberTypeFloat32) { + if (reduce->inner_size_ == 1 && reduce->compute_.float_last_axis_func_ != NULL) { + return reduce->compute_.float_last_axis_func_(reduce->outer_size_, reduce->inner_size_, reduce->axis_size_, + (float *)(reduce->src_data_), (float *)(reduce->dst_data_), task_id, + reduce->base_.thread_nr_); + } else { + NNACL_CHECK_NULL_RETURN_ERR(reduce->compute_.float_function_); + return reduce->compute_.float_function_(reduce->outer_size_, reduce->inner_size_, reduce->axis_size_, + (float *)(reduce->src_data_), (float *)(reduce->dst_data_), task_id, + reduce->base_.thread_nr_); + } + } + + if (reduce->data_type_ == kNumberTypeBool) { + NNACL_CHECK_NULL_RETURN_ERR(reduce->compute_.bool_function_); + return reduce->compute_.bool_function_(reduce->outer_size_, reduce->inner_size_, reduce->axis_size_, + (bool *)(reduce->src_data_), (bool *)(reduce->dst_data_), task_id, + reduce->base_.thread_nr_); + } + + if (reduce->data_type_ == kNumberTypeInt32) { + NNACL_CHECK_NULL_RETURN_ERR(reduce->compute_.int_function_); + return reduce->compute_.int_function_(reduce->outer_size_, reduce->inner_size_, reduce->axis_size_, + (int *)(reduce->src_data_), (int *)(reduce->dst_data_), task_id, + reduce->base_.thread_nr_); + } + + return NNACL_REDUCE_UNSUPPORTED_DATA_TYPE; +} + +int ReduceImpl(void *cdata, int task_id, float l, float r) { + NNACL_CHECK_NULL_RETURN_ERR(cdata); + ReduceStruct *reduce = (ReduceStruct *)cdata; + return reduce->call_uint_((KernelBase *)reduce, task_id); +} + +int CopyReduceyInputToOutput(ReduceStruct *reduce) { + int total_num = NNACLGetElementNum(reduce->base_.in_[FIRST_INPUT]); + NNACL_CHECK_FALSE(total_num == 0, NNACL_REDUCE_INPUT_SHAPE_SIZE_INVALID); + int block_num = UP_DIV(total_num, reduce->base_.thread_nr_); + int tmp_thread_num = UP_DIV(total_num, block_num); + NNACL_CHECK_FALSE(tmp_thread_num == 0, NNACL_REDUCE_INPUT_SHAPE_SIZE_INVALID); + + ReshapeStruct reshape_struct; + reshape_struct.base_.in_ = reduce->base_.in_; + reshape_struct.base_.out_ = reduce->base_.out_; + reshape_struct.block_num_ = block_num; + reshape_struct.total_num_ = total_num; + reshape_struct.base_.thread_nr_ = tmp_thread_num; + return reduce->base_.env_->ParallelLaunch(reduce->base_.env_->thread_pool_, ParallelReshape, &reshape_struct, + tmp_thread_num); +} + +int MallocReduceTmpBuffer(ReduceStruct *reduce) { + // Clean pointers in data_buffer for free condition checking in FreeReduceTmpBuffer. + memset(reduce->data_buffers_, 0, reduce->data_buffers_size_ * sizeof(void *)); + + for (int i = 0; i < reduce->data_buffers_size_; i++) { + reduce->data_buffers_[i] = reduce->base_.env_->Alloc( + reduce->base_.env_->allocator_, reduce->data_buffer_sizes_[i] * DataTypeCSize(reduce->data_type_)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(reduce->data_buffers_[i]); + } + return NNACL_OK; +} + +void FreeReduceTmpBuffer(ReduceStruct *reduce) { + for (int i = 0; i < reduce->data_buffers_size_; i++) { + if (reduce->data_buffers_[i] != NULL) { + reduce->base_.env_->Free(reduce->base_.env_->allocator_, reduce->data_buffers_[i]); + } + reduce->data_buffers_[i] = NULL; + } +} + +int CalculateReduceCoeffOutput(KernelBase *base) { + ReduceStruct *reduce = (ReduceStruct *)base; + + if (reduce->data_type_ != kNumberTypeFloat32) { + return NNACL_REDUCE_COEFF_DATA_TYPE_INVALID; + } + TensorC *out_tensor = reduce->base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(out_tensor); + NNACL_CHECK_NULL_RETURN_ERR(out_tensor->data_); + int num = NNACLGetElementNum(out_tensor); + + float *out_data = (float *)out_tensor->data_; + for (int i = 0; i < num; ++i) { + out_data[i] *= ((ReduceParameter *)reduce->base_.param_)->coeff; + } + return NNACL_OK; +} + +void HandleReduceASumAndSumSquare(KernelBase *base) { + ReduceStruct *reduce = (ReduceStruct *)base; + if (reduce->data_type_ == kNumberTypeInt32 || reduce->data_type_ == kNumberTypeBool) { + return; + } + + TensorC *in_tensor = base->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_VOID(in_tensor); + float *data = (float *)in_tensor->data_; + NNACL_CHECK_NULL_RETURN_VOID(data); + + int num = NNACLGetElementNum(in_tensor); + + if (((ReduceParameter *)base->param_)->mode_ == Reduce_ASum) { + for (int i = 0; i < num; ++i) { + if (data[i] < 0.0f) { + data[i] = 0.0f - data[i]; + } + } + } + + if (((ReduceParameter *)base->param_)->mode_ == Reduce_SumSquare) { + for (int i = 0; i < num; ++i) { + data[i] = data[i] * data[i]; + } + return; + } +} + +int ReduceCheckInputsOutputs(ReduceStruct *reduce) { + NNACL_CHECK_FALSE(reduce->base_.in_size_ < ONE_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(reduce->base_.out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR); + + for (size_t i = 0; i < reduce->base_.in_size_; i++) { + NNACL_CHECK_NULL_RETURN_ERR(reduce->base_.in_[i]); + } + for (size_t i = 0; i < reduce->base_.out_size_; i++) { + NNACL_CHECK_NULL_RETURN_ERR(reduce->base_.out_[i]); + } + TensorC *input_tensor = reduce->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + if (reduce->base_.in_size_ > ONE_TENSOR) { + TensorC *axes_tensor = reduce->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(axes_tensor); + NNACL_CHECK_FALSE(axes_tensor->data_type_ != kNumberTypeInt && axes_tensor->data_type_ != kNumberTypeInt32 && + axes_tensor->data_type_ != kNumberTypeInt64, + NNACL_REDUCE_AXES_TENSOR_ERROR); + } + return NNACL_OK; +} + +int ReduceCommonPrepare(ReduceStruct *reduce) { + int ret = ReduceCheckInputsOutputs(reduce); + if (ret != NNACL_OK) { + return ret; + } + + if (reduce->base_.in_size_ == ONE_TENSOR) { + reduce->num_axes_ = 0; + return NNACL_OK; + } + + TensorC *axes_tensor = reduce->base_.in_[SECOND_INPUT]; + reduce->num_axes_ = NNACLGetElementNum(axes_tensor); + if (axes_tensor->data_ != NULL && (reduce->num_axes_ <= 0 || reduce->num_axes_ > MAX_SHAPE_SIZE)) { + return NNACL_REDUCE_AXES_TENSOR_ERROR; + } + if (axes_tensor->data_ == NULL) { + reduce->num_axes_ = reduce->base_.in_[FIRST_INPUT]->shape_size_; + for (int i = 0; i < reduce->num_axes_; i++) { + reduce->axes_[i] = i; + } + } else { + if (axes_tensor->data_type_ == kNumberTypeInt32 || axes_tensor->data_type_ == kNumberTypeInt) { + NNACL_CHECK_FALSE(NNACLGetSize(axes_tensor) == 0, NNACL_REDUCE_AXES_TENSOR_ERROR); + (void)memcpy(reduce->axes_, axes_tensor->data_, NNACLGetSize(axes_tensor)); + } else { + int64_t *axes_data = axes_tensor->data_; + for (size_t i = 0; i < reduce->num_axes_; i++) { + reduce->axes_[i] = (int32_t)axes_data[i]; + } + } + } + + return NNACL_OK; +} + +int CheckReduceParameters(ReduceStruct *reduce) { + int input_shape_size = reduce->base_.in_[FIRST_INPUT]->shape_size_; + NNACL_CHECK_FALSE(reduce->num_axes_ > input_shape_size, NNACL_REDUCE_INPUT_SHAPE_SIZE_INVALID); + + for (int i = 0; i < reduce->num_axes_; i++) { + NNACL_CHECK_FALSE(reduce->axes_[i] < -input_shape_size, NNACL_REDUCE_INPUT_SHAPE_SIZE_INVALID); + NNACL_CHECK_FALSE(reduce->axes_[i] >= input_shape_size, NNACL_REDUCE_INPUT_SHAPE_SIZE_INVALID); + + if (reduce->axes_[i] < 0) { + reduce->axes_[i] += input_shape_size; + } + } + + if (((ReduceParameter *)reduce->base_.param_)->reduce_to_end_) { + // actual num of axes to reduce + reduce->num_axes_ = (int)(input_shape_size)-reduce->axes_[0]; + for (int i = 1; i < reduce->num_axes_; ++i) { + reduce->axes_[i] = reduce->axes_[0] + i; + } + } + + if (reduce->num_axes_ == 0) { + for (int i = 0; i < input_shape_size; i++) { + reduce->axes_[i] = i; + } + reduce->num_axes_ = input_shape_size; + } + return NNACL_OK; +} + +void ReduceCalculateInnerOuterSize(ReduceStruct *reduce) { + TensorC *input_tensor = reduce->base_.in_[FIRST_INPUT]; + int tmp_input_shape[MAX_SHAPE_SIZE]; + memcpy(tmp_input_shape, input_tensor->shape_, MAX_SHAPE_SIZE * sizeof(int)); + reduce->offset_size_ = 0; + + for (int i = 0; i < reduce->num_axes_; ++i) { + int axis = reduce->axes_[i]; + int outer_size = 1; + for (int j = 0; j < axis; j++) { + outer_size *= tmp_input_shape[j]; + } + reduce->outer_sizes_[reduce->offset_size_] = outer_size; + + int inner_size = 1; + for (int k = axis + 1; k < input_tensor->shape_size_; k++) { + inner_size *= tmp_input_shape[k]; + } + reduce->inner_sizes_[reduce->offset_size_] = inner_size; + reduce->axis_sizes_[reduce->offset_size_] = tmp_input_shape[axis]; + + reduce->offset_size_++; + tmp_input_shape[axis] = 1; + } +} + +void ReduceCalculateTmpBufferSize(ReduceStruct *reduce) { + reduce->data_buffers_size_ = 0; + + TensorC *input_tensor = reduce->base_.in_[FIRST_INPUT]; + int tmp_input_shape[MAX_SHAPE_SIZE]; + memcpy(tmp_input_shape, input_tensor->shape_, MAX_SHAPE_SIZE * sizeof(int)); + // calculate size of buffer to malloc for each reducing axis + for (int i = 0; i < reduce->num_axes_ - 1; i++) { + int axis = reduce->axes_[i]; + size_t size = 1; + for (size_t j = 0; j < input_tensor->shape_size_; j++) { + if (axis != (int)(j)) { + size *= (size_t)(tmp_input_shape[j]); + } + } + reduce->data_buffer_sizes_[reduce->data_buffers_size_++] = size; + tmp_input_shape[axis] = 1; + } +} + +void ReduceDecideIfOnlyCopy(ReduceStruct *reduce) { + ReduceModeC can_not_copy[] = {Reduce_SumSquare, Reduce_ASum, Reduce_All, Reduce_L2}; + for (int i = 0; i < sizeof(can_not_copy) / sizeof(ReduceModeC); i++) { + if (can_not_copy[i] == ((ReduceParameter *)reduce->base_.param_)->mode_) { + reduce->only_copy_ = false; + return; + } + } + + int *in_shape = reduce->base_.in_[FIRST_INPUT]->shape_; + + for (int i = 0; i < reduce->num_axes_; i++) { + int axis = reduce->axes_[i]; + if (in_shape[axis] != 1) { + reduce->only_copy_ = false; + return; + } + } + reduce->only_copy_ = true; + return; +} + +int ReducePrepare(struct KernelBase *self) { + NNACL_CHECK_NULL_RETURN_ERR(self); + ReduceStruct *reduce = (ReduceStruct *)self; + + NNACL_CHECK_FALSE(self->in_size_ < ONE_TENSOR, ONE_TENSOR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, ONE_TENSOR); + + int ret = ReduceCommonPrepare(reduce); + if (ret != NNACL_OK) { + return ret; + } + + reduce->init_kernel_list_(self); + return NNACL_OK; +} + +int ReduceResize(struct KernelBase *self) { + NNACL_CHECK_NULL_RETURN_ERR(self); + ReduceStruct *reduce = (ReduceStruct *)self; + + int ret = CheckReduceParameters(reduce); + if (ret != NNACL_OK) { + return ret; + } + + ReduceDecideIfOnlyCopy(reduce); + ReduceCalculateTmpBufferSize(reduce); + ReduceCalculateInnerOuterSize(reduce); + + if (reduce->num_axes_ == 1) { + self->thread_nr_ = self->UpdateThread( + TC_TYPE(PrimType_ReduceFusion, ((ReduceParameter *)reduce->base_.param_)->mode_), + reduce->inner_sizes_[Index0] * reduce->axis_sizes_[Index0], + reduce->inner_sizes_[Index0] * reduce->axis_sizes_[Index0], reduce->outer_sizes_[Index0], self->thread_nr_); + } else { + self->thread_nr_ = self->UpdateThread(TC_TYPE(PrimType_ReduceFusion, Reduce_Max + 1), 0, 0, + NNACLGetElementNum(self->out_[OUTPUT_INDEX]), self->thread_nr_); + } + return NNACL_OK; +} + +int ReduceCompute(struct KernelBase *self) { + NNACL_CHECK_NULL_RETURN_ERR(self); + ReduceStruct *reduce = (ReduceStruct *)self; + NNACL_CHECK_FALSE(self->in_[FIRST_INPUT]->data_type_ != reduce->data_type_, NNACL_ERR); + + if (reduce->only_copy_) { + return CopyReduceyInputToOutput(reduce); + } + + int ret = MallocReduceTmpBuffer(reduce); + if (ret != NNACL_OK) { + FreeReduceTmpBuffer(reduce); + return ret; + } + + reduce->src_data_ = self->in_[FIRST_INPUT]->data_; + reduce->handle_sum_square_(self); + for (int i = 0; i < reduce->num_axes_; i++) { + if (i != (reduce->num_axes_ - 1)) { + reduce->dst_data_ = reduce->data_buffers_[i]; + } else { + reduce->dst_data_ = self->out_[FIRST_INPUT]->data_; + } + reduce->outer_size_ = reduce->outer_sizes_[i]; + reduce->inner_size_ = reduce->inner_sizes_[i]; + reduce->axis_size_ = reduce->axis_sizes_[i]; + NNACL_CHECK_FALSE(reduce->axis_size_ == 0, NNACL_REDUCE_AXIS_SIZE_ERROR); + + ret = self->env_->ParallelLaunch(self->env_->thread_pool_, ReduceImpl, self, self->thread_nr_); + if (ret != NNACL_OK) { + FreeReduceTmpBuffer(reduce); + return ret; + } + reduce->src_data_ = reduce->dst_data_; + } + + ReduceParameter *param = (ReduceParameter *)reduce->base_.param_; + if (param->reduce_to_end_ && fabsf(param->coeff) > 1e-5) { + ret = reduce->calculate_coeff_(self); + } + + FreeReduceTmpBuffer(reduce); + return ret; +} + +KernelBase *CreateReduce(OpParameter *param, int data_type) { + ReduceStruct *reduce = (ReduceStruct *)malloc(sizeof(ReduceStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(reduce); + memset(reduce, 0, sizeof(ReduceStruct)); + reduce->data_type_ = data_type; + reduce->base_.Release = DefaultRelease; + reduce->base_.Prepare = ReducePrepare; + reduce->base_.Resize = ReduceResize; + reduce->base_.Compute = ReduceCompute; + reduce->handle_sum_square_ = HandleReduceASumAndSumSquare; + reduce->calculate_coeff_ = CalculateReduceCoeffOutput; + reduce->init_kernel_list_ = InitialReduceKernelList; + reduce->call_uint_ = CallReduceUnit; + return (KernelBase *)reduce; +} + +REG_KERNEL_CREATOR(PrimType_ReduceFusion, kNumberTypeBool, CreateReduce) +REG_KERNEL_CREATOR(PrimType_ReduceFusion, kNumberTypeInt32, CreateReduce) +REG_KERNEL_CREATOR(PrimType_ReduceFusion, kNumberTypeFloat32, CreateReduce) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/reduce.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/reduce.h new file mode 100644 index 0000000000000000000000000000000000000000..2c0a5a5f78e0f3e86162af9d081205b7c0227b47 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/reduce.h @@ -0,0 +1,72 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_REDUCE_H_ +#define NNACL_KERNEL_REDUCE_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct ReduceKernelList { + int type_; + int (*float_function_)(const int outer_size, const int inner_size, const int axis_size, const float *src_data, + float *dst_data, const int tid, const int thread_num); + int (*int_function_)(const int outer_size, const int inner_size, const int axis_size, const int *src_data, + int *dst_data, const int tid, const int thread_num); + int (*bool_function_)(const int outer_size, const int inner_size, const int axis_size, const bool *src_data, + bool *dst_data, const int tid, const int thread_num); + int (*float_last_axis_func_)(const int outer_size, const int inner_size, const int axis_size, const float *src_data, + float *dst_data, const int tid, const int thread_num); +} ReduceKernelList; + +typedef struct ReduceStruct { + KernelBase base_; + bool only_copy_; + int num_axes_; + TypeIdC data_type_; + int axes_[MAX_SHAPE_SIZE]; + + void *data_buffers_[MAX_SHAPE_SIZE]; + size_t data_buffer_sizes_[MAX_SHAPE_SIZE]; + int data_buffers_size_; + ReduceModeC mode_; + + int outer_sizes_[MAX_SHAPE_SIZE]; + int inner_sizes_[MAX_SHAPE_SIZE]; + int axis_sizes_[MAX_SHAPE_SIZE]; + int offset_size_; + + int outer_size_; + int inner_size_; + int axis_size_; + + void *src_data_; + void *dst_data_; + ReduceKernelList compute_; + + void (*handle_sum_square_)(KernelBase *base); + void (*init_kernel_list_)(KernelBase *base); + int (*calculate_coeff_)(KernelBase *base); + int (*call_uint_)(KernelBase *base, int task_id); +} ReduceStruct; + +KernelBase *CreateReduce(OpParameter *param, int data_type); +int ReducePrepare(KernelBase *self); +int ReduceResize(KernelBase *self); +int ReduceCompute(KernelBase *self); + +#endif // NNACL_KERNEL_RESHAPE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/reshape.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/reshape.c new file mode 100644 index 0000000000000000000000000000000000000000..975f30891093072fdb90dbcb8e90a7f0b7ade5dd --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/reshape.c @@ -0,0 +1,96 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/reshape.h" +#include "nnacl/kernel/default_kernel_base.h" +#include "nnacl/tensor_c_utils.h" +#include "nnacl/nnacl_common.h" + +int kMinCostPerThread = 16384; + +int ParallelReshape(void *param, int task_id, float l, float r) { + NNACL_CHECK_NULL_RETURN_ERR(param); + ReshapeStruct *reshape = (ReshapeStruct *)param; + + int data_size = (int)DataTypeCSize(reshape->base_.in_[0]->data_type_); + uint8_t *in_start = (uint8_t *)(reshape->base_.in_[0]->data_) + task_id * reshape->block_num_ * data_size; + uint8_t *out_start = (uint8_t *)(reshape->base_.out_[0]->data_) + task_id * reshape->block_num_ * data_size; + int copy_num = reshape->block_num_; + if (task_id == (reshape->base_.thread_nr_ - 1)) { + copy_num = reshape->total_num_ - task_id * reshape->block_num_; + } + (void)memcpy(out_start, in_start, copy_num * data_size); + return NNACL_OK; +} + +int ReshapeResize(struct KernelBase *self) { + NNACL_CHECK_NULL_RETURN_ERR(self); + ReshapeStruct *reshape = (ReshapeStruct *)self; + reshape->total_num_ = NNACLGetElementNum(self->in_[0]); + if (reshape->total_num_ == 0) { + return NNACL_OK; + } + + self->thread_nr_ = MSMIN(self->thread_nr_, UP_DIV(reshape->total_num_, kMinCostPerThread)); + if (self->thread_nr_ < 1) { + self->thread_nr_ = 1; + } + NNACL_CHECK_ZERO_RETURN_ERR(self->thread_nr_); + reshape->block_num_ = UP_DIV(reshape->total_num_, self->thread_nr_); + NNACL_CHECK_ZERO_RETURN_ERR(reshape->block_num_); + self->thread_nr_ = UP_DIV(reshape->total_num_, reshape->block_num_); + + return NNACL_OK; +} + +int ReshapeCompute(struct KernelBase *self) { + return self->env_->ParallelLaunch(self->env_->thread_pool_, ParallelReshape, self, self->thread_nr_); +} + +KernelBase *CreateReshape(OpParameter *param, int data_type) { + ReshapeStruct *reshape = (ReshapeStruct *)malloc(sizeof(ReshapeStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(reshape); + reshape->base_.Release = DefaultRelease; + reshape->base_.Prepare = DefaultPrepare1In1Out; + reshape->base_.Resize = ReshapeResize; + reshape->base_.Compute = ReshapeCompute; + return (KernelBase *)reshape; +} + +REG_KERNEL_CREATOR(PrimType_Reshape, kNumberTypeInt32, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Reshape, kNumberTypeFloat32, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Reshape, kNumberTypeFloat16, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Reshape, kNumberTypeBool, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Flatten, kNumberTypeInt32, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Flatten, kNumberTypeFloat16, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Flatten, kNumberTypeFloat32, CreateReshape) +REG_KERNEL_CREATOR(PrimType_FlattenGrad, kNumberTypeFloat16, CreateReshape) +REG_KERNEL_CREATOR(PrimType_FlattenGrad, kNumberTypeFloat32, CreateReshape) +REG_KERNEL_CREATOR(PrimType_ExpandDims, kNumberTypeInt32, CreateReshape) +REG_KERNEL_CREATOR(PrimType_ExpandDims, kNumberTypeFloat16, CreateReshape) +REG_KERNEL_CREATOR(PrimType_ExpandDims, kNumberTypeFloat32, CreateReshape) +REG_KERNEL_CREATOR(PrimType_ExpandDims, kNumberTypeBool, CreateReshape) +REG_KERNEL_CREATOR(PrimType_ExpandDims, kNumberTypeInt8, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Squeeze, kNumberTypeFloat32, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Squeeze, kNumberTypeFloat16, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Squeeze, kNumberTypeInt32, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Squeeze, kNumberTypeBool, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Unsqueeze, kNumberTypeFloat16, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Unsqueeze, kNumberTypeFloat32, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Unsqueeze, kNumberTypeUInt8, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Unsqueeze, kNumberTypeInt32, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Unsqueeze, kNumberTypeInt64, CreateReshape) +REG_KERNEL_CREATOR(PrimType_Unsqueeze, kNumberTypeBool, CreateReshape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/reshape.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/reshape.h new file mode 100644 index 0000000000000000000000000000000000000000..acfee1d53683d6541e6de1bc7d254ed63d375ff7 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/reshape.h @@ -0,0 +1,34 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_RESHAPE_H_ +#define NNACL_KERNEL_RESHAPE_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct ReshapeStruct { + KernelBase base_; + int block_num_; + int total_num_; +} ReshapeStruct; + +int ParallelReshape(void *param, int task_id, float l, float r); + +KernelBase *CreateReshape(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_RESHAPE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/reverse.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/reverse.c new file mode 100644 index 0000000000000000000000000000000000000000..34147036ab2d7b561a4aed6eeac2a5a18cacc201 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/reverse.c @@ -0,0 +1,166 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/reverse.h" +#include "nnacl/tensor_c_utils.h" +#include "nnacl/reverse_parameter.h" +#include "nnacl/fp32/reverse_fp32.h" + +int ReverseStride(TensorC *input, int index) { + int stride = 1; + for (int i = index + 1; i < (int)input->shape_size_; i++) { + stride *= input->shape_[i]; + } + return stride; +} + +int ReverseRun(void *cdata, int task_id, float l, float r) { + ReverseStruct *reverse = (ReverseStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(reverse); + + int offset = task_id * reverse->thread_stride_; + int count = NNACL_MIN(reverse->thread_stride_, reverse->data_size_ - offset); + if (count <= 0) { + return NNACL_OK; + } + + float *in_ptr = (float *)reverse->base_.in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(in_ptr); + float *out_ptr = (float *)reverse->base_.out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(out_ptr); + return Reverse(in_ptr + offset, out_ptr, reverse->thread_stride_, reverse->tmp_ + offset); +} + +int ReverseUpdateAxisInfo(ReverseStruct *reverse) { + ReverseParameter *reverse_param = (ReverseParameter *)reverse->base_.param_; + int in_shape_len = reverse->base_.in_[FIRST_INPUT]->shape_size_; + for (int i = 0; i < reverse_param->num_axis_; ++i) { + if (reverse_param->axis_[i] < 0) { + reverse_param->axis_[i] += in_shape_len; + } + if (reverse_param->axis_[i] < 0 || reverse_param->axis_[i] >= in_shape_len) { + return NNACL_REVERSE_AXIS_VALUE_INVALID; + } + } + return NNACL_OK; +} + +int ReverseCompute(KernelBase *self) { + return self->env_->ParallelLaunch(self->env_->thread_pool_, ReverseRun, self, self->thread_nr_); +} + +int ReversePrepare(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < ONE_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR); + ReverseStruct *reverse = (ReverseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(reverse); + if (((ReverseParameter *)self->param_)->num_axis_ < Num1) { + return NNACL_REVERSE_AXIS_INVALID; + } + return NNACL_OK; +} + +int ReverseRelease(KernelBase *self) { + ReverseStruct *reverse = (ReverseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(reverse); + if (reverse->tmp_ != NULL) { + self->env_->Free(self->env_->allocator_, reverse->tmp_); + reverse->tmp_ = NULL; + } + return NNACL_OK; +} + +int ReverseResize(KernelBase *self) { + ReverseStruct *reverse = (ReverseStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(reverse); + + TensorC *input = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + TensorC *output = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output); + + // trans negative to positive axis + int ret = ReverseUpdateAxisInfo(reverse); + if (ret != NNACL_OK) { + return ret; + } + + reverse->data_size_ = NNACLGetElementNum(input); + if (NNACLGetElementNum(output) != reverse->data_size_) { + return NNACL_REVERSE_DATA_SIZE_INVALID; + } + + self->thread_nr_ = NNACL_MIN(self->thread_nr_, reverse->data_size_); + NNACL_CHECK_ZERO_RETURN_ERR(self->thread_nr_); + reverse->thread_stride_ = UP_DIV(reverse->data_size_, self->thread_nr_); + + ReverseParameter *reverse_param = (ReverseParameter *)self->param_; + if (reverse_param->num_axis_ > input->shape_size_) { + return NNACL_REVERSE_NUM_AXIS_INVALID; + } + if (input->shape_size_ > REVERSE_SHAPE_MAX_SIZE) { + return NNACL_REVERSE_NUM_AXIS_INVALID; + } + + (void)self->Release(self); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(reverse->data_size_, sizeof(int), NNACL_ERR); + reverse->tmp_ = (int *)self->env_->Alloc(self->env_->allocator_, reverse->data_size_ * sizeof(int)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(reverse->tmp_); + memset(reverse->tmp_, 0, reverse->data_size_ * sizeof(int)); + + for (int i = 0; i < reverse_param->num_axis_; i++) { + int axis = reverse_param->axis_[i]; + int stride = ReverseStride(input, axis); + reverse->strides_[i] = stride; + reverse->in_count_[i] = input->shape_[axis]; + reverse->out_count_[i] = 1; + for (int j = 0; j < axis; j++) { + reverse->out_count_[i] *= input->shape_[j]; + } + } + + int out; + int in; + int C; + int m; + for (int i = 0; i < reverse->data_size_; ++i) { + int tmp = i; + for (int j = 0; j < reverse_param->num_axis_; ++j) { + C = reverse->in_count_[j]; + out = tmp / (C * reverse->strides_[j]); + in = tmp / reverse->strides_[j] - out * C; + m = tmp % reverse->strides_[j]; + tmp = out * C * reverse->strides_[j] + reverse->strides_[j] * (C - 1 - in) + m; + } + reverse->tmp_[i] = tmp; + } + + return NNACL_OK; +} + +KernelBase *CreateReverse(OpParameter *param, int data_type) { + ReverseStruct *reverse = (ReverseStruct *)malloc(sizeof(ReverseStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(reverse); + memset(reverse, 0, sizeof(ReverseStruct)); + reverse->base_.Release = ReverseRelease; + reverse->base_.Prepare = ReversePrepare; + reverse->base_.Resize = ReverseResize; + reverse->base_.Compute = ReverseCompute; + return (KernelBase *)reverse; +} + +REG_KERNEL_CREATOR(PrimType_ReverseV2, kNumberTypeFloat32, CreateReverse) +REG_KERNEL_CREATOR(PrimType_ReverseV2, kNumberTypeInt32, CreateReverse) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/reverse.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/reverse.h new file mode 100644 index 0000000000000000000000000000000000000000..88d001732e2dcfdab0e874a336262acd47a21b2c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/reverse.h @@ -0,0 +1,36 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_REVERSE_H_ +#define NNACL_KERNEL_REVERSE_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct { + KernelBase base_; + int thread_stride_; + int data_size_; + int *tmp_; + int strides_[COMM_SHAPE_SIZE]; + int in_count_[COMM_SHAPE_SIZE]; + int out_count_[COMM_SHAPE_SIZE]; +} ReverseStruct; + +KernelBase *CreateReverse(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_REVERSE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/scale.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/scale.c new file mode 100644 index 0000000000000000000000000000000000000000..afcf7a8eba53988c9b3013729fd296678c039b5d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/scale.c @@ -0,0 +1,333 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/scale.h" +#include "nnacl/common_func.h" +#include "nnacl/scale_parameter.h" +#include "nnacl/fp32/scale_fp32.h" +#include "nnacl/tensor_c_utils.h" +#ifdef ENABLE_FP16 +#include "nnacl/fp16/utils_fp16.h" +#include "nnacl/fp16/scale_fp16.h" +#endif + +int ScaleRunF16(ScaleStruct *scale, int task_id, ActType act_type) { +#ifdef ENABLE_FP16 + switch (act_type) { + case ActType_Relu6: + DoScaleRelu6Fp16((const float16_t *)scale->input_, (float16_t *)scale->output_, (const float16_t *)scale->scale_, + (const float16_t *)scale->offset_, task_id, scale); + break; + case ActType_Relu: + Fp16DoScaleRelu((const float16_t *)scale->input_, (float16_t *)scale->output_, (const float16_t *)scale->scale_, + (const float16_t *)scale->offset_, task_id, scale); + break; + case ActType_No: + DoScaleFp16((const float16_t *)scale->input_, (float16_t *)scale->output_, (const float16_t *)scale->scale_, + (const float16_t *)scale->offset_, task_id, scale); + break; + default: + return NNACL_ERR; + } + return NNACL_OK; +#endif + return NNACL_DISABLE_FP16; +} + +int ScaleInitInputDataType(ScaleStruct *scale) { + if (scale->data_type_ == kNumberTypeFloat32) { + return NNACL_OK; + } + +#ifdef ENABLE_FP16 + TensorC *scale_tensor = scale->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(scale_tensor); + if (scale_tensor->data_type_ != kNumberTypeFloat16 && scale->malloc_scale_ == false) { + scale->malloc_scale_ = true; + scale->scale_ = GetOrAllocFp16Data(scale_tensor, scale->base_.env_, true); + } else { + scale->malloc_scale_ = false; + scale->scale_ = NULL; + } + + if (scale->base_.in_size_ == TWO_TENSOR) { + /* already done in prepare */ + return NNACL_OK; + } + + TensorC *offset_tensor = scale->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(scale_tensor); + if (offset_tensor->data_type_ != kNumberTypeFloat16 && scale->malloc_scale_ == false) { + scale->malloc_offset_ = true; + scale->offset_ = GetOrAllocFp16Data(offset_tensor, scale->base_.env_, true); + } else { + scale->malloc_offset_ = false; + scale->offset_ = NULL; + } + return NNACL_OK; +#endif + return NNACL_DISABLE_FP16; +} + +int ScaleRunF32(ScaleStruct *scale, int task_id, ActType act_type) { + switch (act_type) { + case ActType_Relu6: + DoScaleRelu6((const float *)scale->input_, (float *)scale->output_, (const float *)scale->scale_, + (const float *)scale->offset_, task_id, scale); + break; + case ActType_Relu: + DoScaleRelu((const float *)scale->input_, (float *)scale->output_, (const float *)scale->scale_, + (const float *)scale->offset_, task_id, scale); + break; + case ActType_No: + DoScale((const float *)scale->input_, (float *)scale->output_, (const float *)scale->scale_, + (const float *)scale->offset_, task_id, scale); + break; + default: + return NNACL_SCALE_UNSUPPORT_ACT_TYPE; + } + return NNACL_OK; +} + +int ScaleRun(void *cdata, int task_id, float l, float r) { + ScaleStruct *scale = (ScaleStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(scale); + ActType act_type = ((ScaleParameter *)scale->base_.param_)->activation_type_; + if (scale->data_type_ == kNumberTypeFloat16) { + return ScaleRunF16(scale, task_id, act_type); + } else if (scale->data_type_ == kNumberTypeFloat32) { + return ScaleRunF32(scale, task_id, act_type); + } + return NNACL_UNSUPPORTED_DATA_TYPE; +} + +int ScaleCalculateParameter(ScaleStruct *scale) { + TensorC *input_tensor = scale->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + TensorC *scale_tensor = scale->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(scale_tensor); + TensorC *output_tensor = scale->base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + + scale->outer_size_ = 1; + scale->axis_size_ = 1; + scale->inner_size_ = 1; + for (int i = 0; i < scale->axis_; i++) { + scale->outer_size_ *= input_tensor->shape_[i]; + } + for (size_t i = 0; i < scale_tensor->shape_size_; i++) { + scale->axis_size_ *= input_tensor->shape_[i + scale->axis_]; + } + for (size_t i = scale->axis_ + scale_tensor->shape_size_; i < input_tensor->shape_size_; i++) { + scale->inner_size_ *= input_tensor->shape_[i]; + } + + scale->base_.thread_nr_ = MSMIN(scale->base_.thread_nr_, scale->outer_size_); + NNACL_CHECK_ZERO_RETURN_ERR(scale->base_.thread_nr_); + + return NNACL_OK; +} + +int ScaleInitScaleOffset(ScaleStruct *scale) { + TensorC *scale_tensor = scale->base_.in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(scale_tensor); + int data_type_size = DataTypeCSize(scale->data_type_); + + if (scale->base_.in_size_ == TWO_TENSOR) { + scale->malloc_offset_ = true; + int malloc_size = NNACLGetElementNum(scale_tensor) * data_type_size; + NNACL_CHECK_MALLOC_SIZE(malloc_size); + scale->offset_ = scale->base_.env_->Alloc(scale->base_.env_->allocator_, malloc_size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(scale->offset_); + memset(scale->offset_, 0, malloc_size); + } + + if (scale->data_type_ == kNumberTypeFloat16) { + /* handle fp16 scale and offset in compute */ + return NNACL_OK; + } + + if (scale_tensor->data_ != NULL) { + scale->malloc_scale_ = true; + int malloc_size = NNACLGetElementNum(scale_tensor) * data_type_size; + NNACL_CHECK_MALLOC_SIZE(malloc_size); + scale->scale_ = scale->base_.env_->Alloc(scale->base_.env_->allocator_, malloc_size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(scale->scale_); + (void)memcpy(scale->scale_, scale_tensor->data_, malloc_size); + } else { + scale->malloc_scale_ = false; + scale->scale_ = NULL; + } + + if (scale->base_.in_size_ == TWO_TENSOR) { + return NNACL_OK; + } + NNACL_CHECK_FALSE(scale->base_.in_size_ != THREE_TENSOR, NNACL_SCALE_INPUT_NUM_INVALID); + + TensorC *offset_tensor = scale->base_.in_[THIRD_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(offset_tensor); + if (offset_tensor->data_ != NULL) { + scale->malloc_offset_ = true; + int malloc_size = NNACLGetElementNum(offset_tensor) * data_type_size; + NNACL_CHECK_MALLOC_SIZE(malloc_size); + scale->offset_ = scale->base_.env_->Alloc(scale->base_.env_->allocator_, malloc_size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(scale->scale_); + (void)memcpy(scale->offset_, offset_tensor->data_, malloc_size); + } else { + scale->malloc_offset_ = false; + scale->offset_ = NULL; + } + + return NNACL_OK; +} + +int ScaleCheckInputsOutputs(KernelBase *self) { + NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_INPUT_TENSOR_ERROR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_INPUT_TENSOR_ERROR); + + for (size_t i = 0; i < self->in_size_; i++) { + TensorC *input_tensor = self->in_[i]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + if (input_tensor->data_type_ != kNumberTypeFloat32 && input_tensor->data_type_ != kNumberTypeFloat16) { + return NNACL_UNSUPPORTED_DATA_TYPE; + } + } + + TensorC *output_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + if (output_tensor->data_type_ != kNumberTypeFloat32 && output_tensor->data_type_ != kNumberTypeFloat16) { + return NNACL_UNSUPPORTED_DATA_TYPE; + } + return NNACL_OK; +} + +int ScaleRelease(struct KernelBase *self) { + ScaleStruct *scale = (ScaleStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(scale); + + if (scale->malloc_scale_ && scale->scale_ != NULL) { + self->env_->Free(self->env_->allocator_, scale->scale_); + scale->scale_ = NULL; + scale->malloc_scale_ = false; + } + + if (scale->malloc_offset_ && scale->offset_ != NULL) { + self->env_->Free(self->env_->allocator_, scale->offset_); + scale->offset_ = NULL; + scale->malloc_offset_ = false; + } + return NNACL_OK; +} + +int ScaleResize(struct KernelBase *self) { + ScaleStruct *scale = (ScaleStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(scale); + + TensorC *input_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + TensorC *scale_tensor = self->in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(scale_tensor); + + int origin_axis = ((ScaleParameter *)self->param_)->axis_; + scale->axis_ = origin_axis < 0 ? origin_axis + input_tensor->shape_size_ : origin_axis; + + for (size_t i = 0; i < scale_tensor->shape_size_; i++) { + if (i + scale->axis_ >= input_tensor->shape_size_) { + return NNACL_SCALE_AXIS_AND_SHAPE_UNMATCH; + } + if (input_tensor->shape_[i + scale->axis_] != scale_tensor->shape_[i]) { + return NNACL_SCALE_SCALE_SHAPE_UNMATCH; + } + } + + int ret = ScaleCalculateParameter(scale); + if (ret != NNACL_OK) { + return ret; + } + return NNACL_OK; +} + +int ScaleCompute(struct KernelBase *self) { + ScaleStruct *scale = (ScaleStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(scale); + + TensorC *input_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input_tensor); + scale->input_ = input_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(scale->input_); + + TensorC *output_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output_tensor); + scale->output_ = output_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(scale->output_); + + int ret = ScaleInitInputDataType(scale); + if (ret != NNACL_OK) { + return ret; + } + + if (!scale->malloc_scale_) { + TensorC *scale_tensor = self->in_[SECOND_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(scale_tensor); + scale->scale_ = scale_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(scale->scale_); + } + + if (!scale->malloc_offset_) { + TensorC *offset_tensor = self->in_[THIRD_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(offset_tensor); + scale->offset_ = offset_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(scale->offset_); + } + + return self->env_->ParallelLaunch(self->env_->thread_pool_, ScaleRun, self, self->thread_nr_); +} + +int ScalePrepare(struct KernelBase *self) { + ScaleStruct *scale = (ScaleStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(scale); + + int ret = ScaleCheckInputsOutputs(self); + if (ret != NNACL_OK) { + return ret; + } + + ret = ScaleInitScaleOffset(scale); + if (ret != NNACL_OK) { + return ret; + } + + return NNACL_OK; +} + +KernelBase *CreateScale(OpParameter *param, int data_type) { + ScaleStruct *scale = (ScaleStruct *)malloc(sizeof(ScaleStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(scale); + memset(scale, 0, sizeof(ScaleStruct)); + scale->data_type_ = data_type; + scale->scale_ = NULL; + scale->offset_ = NULL; + scale->malloc_scale_ = false; + scale->malloc_offset_ = false; + scale->base_.Prepare = ScalePrepare; + scale->base_.Resize = ScaleResize; + scale->base_.Compute = ScaleCompute; + scale->base_.Release = ScaleRelease; + return (KernelBase *)scale; +} + +REG_KERNEL_CREATOR(PrimType_ScaleFusion, kNumberTypeFloat16, CreateScale) +REG_KERNEL_CREATOR(PrimType_ScaleFusion, kNumberTypeFloat32, CreateScale) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/scale.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/scale.h new file mode 100644 index 0000000000000000000000000000000000000000..988d44ddae78829c4d138fdfeb94d4caeb96bf53 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/scale.h @@ -0,0 +1,41 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_SCALE_H_ +#define NNACL_KERNEL_SCALE_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct ScaleStruct { + KernelBase base_; + int axis_; + int data_type_; + int axis_size_; + int outer_size_; + int inner_size_; + bool malloc_scale_; + bool malloc_offset_; + void *scale_; + void *offset_; + void *input_; + void *output_; +} ScaleStruct; + +KernelBase *CreateScale(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_SCALE_H_ diff --git a/mindspore-lite/src/extendrt/kernel/base_kernel.cc b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/shape.c similarity index 30% rename from mindspore-lite/src/extendrt/kernel/base_kernel.cc rename to mindspore-lite/ops/kernel/cpu/nnacl/kernel/shape.c index de6b2af5870c3e9a22a8e0f35ccc88f6e38ff547..52875501c2f1d7320c90bb21c691f1b4b2079db1 100644 --- a/mindspore-lite/src/extendrt/kernel/base_kernel.cc +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/shape.c @@ -14,40 +14,38 @@ * limitations under the License. */ -#include -#include "src/extendrt/kernel/base_kernel.h" -#include "src/litert/cxx_api/tensor/tensor_impl.h" +#include "nnacl/kernel/shape.h" +#include "nnacl/kernel/default_kernel_base.h" -namespace mindspore::kernel { -const std::vector &BaseKernel::inputs() { - if (inputs_.empty()) { - std::transform(in_tensors_.begin(), in_tensors_.end(), std::back_inserter(inputs_), - [](lite::Tensor *tensor) { return mindspore::MSTensor(std::make_shared(tensor)); }); - } - return inputs_; +int ShapeCompute(struct KernelBase *self) { + ShapeStruct *shape = (ShapeStruct *)self; + memcpy(self->out_[OUTPUT_INDEX]->data_, self->in_[FIRST_INPUT]->shape_, shape->shape_size_); + return NNACL_OK; } -const std::vector &BaseKernel::outputs() { - if (outputs_.empty()) { - std::transform(out_tensors_.begin(), out_tensors_.end(), std::back_inserter(outputs_), - [](lite::Tensor *tensor) { return mindspore::MSTensor(std::make_shared(tensor)); }); - } - return outputs_; +int ShapeResize(KernelBase *self) { + NNACL_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]); + NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]); + ShapeStruct *shape = (ShapeStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(shape); + shape->shape_size_ = self->in_[FIRST_INPUT]->shape_size_ * sizeof(int); + return NNACL_OK; } -void BaseKernel::set_in_tensor(lite::Tensor *in_tensor, size_t index) { - if (index >= in_tensors_.size()) { - MS_LOG(ERROR) << "index: " << index << " larger than in_tensors size: " << in_tensors_.size(); - return; - } - this->in_tensors_[index] = in_tensor; +KernelBase *CreateShape(OpParameter *param, int data_type) { + ShapeStruct *shape = (ShapeStruct *)malloc(sizeof(ShapeStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(shape); + shape->base_.Release = DefaultRelease; + shape->base_.Prepare = DefaultPrepare1In1Out; + shape->base_.Resize = ShapeResize; + shape->base_.Compute = ShapeCompute; + return (KernelBase *)shape; } -void BaseKernel::set_out_tensor(lite::Tensor *out_tensor, size_t index) { - if (index >= out_tensors_.size()) { - MS_LOG(ERROR) << "index: " << index << " larger than out_tensors size: " << out_tensors_.size(); - return; - } - this->out_tensors_[index] = out_tensor; -} -} // namespace mindspore::kernel +REG_KERNEL_CREATOR(PrimType_Shape, kNumberTypeInt32, CreateShape) +REG_KERNEL_CREATOR(PrimType_Shape, kNumberTypeBool, CreateShape) +REG_KERNEL_CREATOR(PrimType_Shape, kNumberTypeFloat16, CreateShape) +REG_KERNEL_CREATOR(PrimType_Shape, kNumberTypeFloat32, CreateShape) +REG_KERNEL_CREATOR(PrimType_Shape, kNumberTypeInt8, CreateShape) +REG_KERNEL_CREATOR(PrimType_Shape, kNumberTypeUInt8, CreateShape) +REG_KERNEL_CREATOR(PrimType_Shape, kNumberTypeInt64, CreateShape) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/shape.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/shape.h new file mode 100644 index 0000000000000000000000000000000000000000..6d84ee47b1b27e624583f3d57c4004771bb41220 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/shape.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_SHAPE_H_ +#define NNACL_KERNEL_SHAPE_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct ShapeStruct { + KernelBase base_; + int shape_size_; +} ShapeStruct; + +KernelBase *CreateShape(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_SHAPE_H_ diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/stub_kernel.cc b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/size.c similarity index 37% rename from mindspore-lite/src/extendrt/delegate/ascend_native/stub_kernel.cc rename to mindspore-lite/ops/kernel/cpu/nnacl/kernel/size.c index 0cd0b3f3dd7f907a235d35034ef562c96d8e8202..4b3217e508e645c1858fe653aa3c8e7bd9754982 100644 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/stub_kernel.cc +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/size.c @@ -14,31 +14,31 @@ * limitations under the License. */ -#include "extendrt/delegate/ascend_native/stub_kernel.h" -#include "extendrt/delegate/ascend_native/ascend_native_kernel_registry.h" -#include "extendrt/delegate/ascend_native/ops/ascend_native_stub.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" +#include "nnacl/kernel/size.h" +#include "nnacl/kernel/default_kernel_base.h" +#include "nnacl/tensor_c_utils.h" -namespace mindspore::kernel { -using mindspore::ops::kNameAscendNativeStub; -int AscendNativeStubKernel::Prepare() { - for (size_t i = 0; i < out_tensors_.size(); i++) { - if (out_tensors_[i]->shape().size() == 0) { - if (in_tensors_[i] != nullptr) { - std::vector shape; - for (size_t j = 0; j < in_tensors_[i]->shape().size(); j++) { - shape.push_back(in_tensors_[i]->shape()[j]); - } - out_tensors_[i]->set_shape(shape); - } - } - } - return kSuccess; +int SizeCompute(KernelBase *self) { + TensorC *in_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in_tensor); + TensorC *out_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(out_tensor); + int *out_data = (int *)out_tensor->data_; + NNACL_CHECK_NULL_RETURN_ERR(out_data); + out_data[Index0] = NNACLGetElementNum(in_tensor); + return NNACL_OK; } -int AscendNativeStubKernel::Execute() { - MS_LOG(INFO) << "AscendNativeStubKernel::Execute - " << this->get_name(); - return kSuccess; +KernelBase *CreateSize(OpParameter *param, int data_type) { + SizeStruct *size = (SizeStruct *)malloc(sizeof(SizeStruct)); + NNACL_CHECK_NULL_RETURN_NULL(size); + size->base_.Release = DefaultRelease; + size->base_.Prepare = DefaultPrepare1In1Out; + size->base_.Resize = DefaultResize; + size->base_.Compute = SizeCompute; + return (KernelBase *)size; } -REGISTER_ASCEND_NATIVE_CREATOR(kNameAscendNativeStub, AscendNativeStubKernel) -} // namespace mindspore::kernel + +REG_KERNEL_CREATOR(PrimType_Size, kNumberTypeInt32, CreateSize) +REG_KERNEL_CREATOR(PrimType_Size, kNumberTypeFloat32, CreateSize) +REG_KERNEL_CREATOR(PrimType_Size, kNumberTypeFloat16, CreateSize) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/size.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/size.h new file mode 100644 index 0000000000000000000000000000000000000000..1766d2db9ba894a5dba4b36361efd5def98fa0f7 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/size.h @@ -0,0 +1,30 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_SIZE_H_ +#define NNACL_KERNEL_SIZE_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct SizeStruct { + KernelBase base_; +} SizeStruct; + +KernelBase *CreateSize(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_SIZE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/slice.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/slice.c new file mode 100644 index 0000000000000000000000000000000000000000..cb9280c407238b3499f2d63ad81037781c65ee6d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/slice.c @@ -0,0 +1,76 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/slice.h" +#include "nnacl/kernel/default_kernel_base.h" +#include "nnacl/tensor_c_utils.h" +#include "nnacl/base/slice_base.h" +#include "nnacl/nnacl_common.h" + +int SliceLaunch(void *cdata, int task_id, float l, float r) { + SliceStruct *slice = (SliceStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(slice); + void *in_data = slice->base_.in_[FIRST_INPUT]->data_; + void *out_data = slice->base_.out_[OUTPUT_INDEX]->data_; + DoSlice(in_data, out_data, slice, task_id, slice->base_.thread_nr_, slice->data_type_size_); + return NNACL_OK; +} + +int SliceResize(KernelBase *self) { + SliceStruct *slice = (SliceStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(slice); + + InitSliceStruct(slice, self->in_[Index0], self->in_[Index1], self->in_[Index2]); + + if (slice->param_length_ < DIMENSION_8D) { + PadSliceParameterTo8D(slice); + } + return NNACL_OK; +} + +int SliceCompute(KernelBase *self) { + TensorC *in_tensor = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in_tensor); + TensorC *out_tensor = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(out_tensor); + + SliceStruct *slice = (SliceStruct *)self; + if (slice->size_[Index5] < self->thread_nr_) { + DoSliceNoParallel(in_tensor->data_, out_tensor->data_, slice, slice->data_type_size_); + return NNACL_OK; + } + + int ret = self->env_->ParallelLaunch(self->env_->thread_pool_, SliceLaunch, self, self->thread_nr_); + if (ret != NNACL_OK) { + return ret; + } + return NNACL_OK; +} + +KernelBase *CreateSlice(OpParameter *param, int data_type) { + SliceStruct *slice = (SliceStruct *)malloc(sizeof(SliceStruct)); + NNACL_CHECK_NULL_RETURN_NULL(slice); + slice->data_type_size_ = DataTypeCSize(data_type); + slice->base_.Release = DefaultRelease; + slice->base_.Prepare = DefaultPrepare3In1Out; + slice->base_.Resize = SliceResize; + slice->base_.Compute = SliceCompute; + return (KernelBase *)slice; +} + +REG_KERNEL_CREATOR(PrimType_SliceFusion, kNumberTypeInt32, CreateSlice) +REG_KERNEL_CREATOR(PrimType_SliceFusion, kNumberTypeFloat32, CreateSlice) +REG_KERNEL_CREATOR(PrimType_SliceFusion, kNumberTypeFloat16, CreateSlice) diff --git a/mindspore-lite/src/extendrt/kernel/default/cnode_infer_manager.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/slice.h similarity index 56% rename from mindspore-lite/src/extendrt/kernel/default/cnode_infer_manager.h rename to mindspore-lite/ops/kernel/cpu/nnacl/kernel/slice.h index c704debf9deb87026a203fc9390334a04753fb48..83f4b47de3cd8f4ad939f0c55bb1077ce83222d7 100644 --- a/mindspore-lite/src/extendrt/kernel/default/cnode_infer_manager.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/slice.h @@ -14,15 +14,23 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_EXTENDRT_KERNEL_DEFAULT_CNODE_INFER_MANAGER_H_ -#define MINDSPORE_LITE_EXTENDRT_KERNEL_DEFAULT_CNODE_INFER_MANAGER_H_ -#include -#include "ir/anf.h" -#include "src/litert/inner_context.h" +#ifndef NNACL_KERNEL_SLICE_H_ +#define NNACL_KERNEL_SLICE_H_ -namespace mindspore { -namespace kernel { -int CNodeInferShape(const CNodePtr &cnode, const std::vector &outputs); -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_LITE_EXTENDRT_KERNEL_DEFAULT_CNODE_INFER_MANAGER_H_ +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct SliceStruct { + KernelBase base_; + int data_type_size_; + int32_t begin_[DIMENSION_8D]; + int32_t size_[DIMENSION_8D]; + int32_t shape_[DIMENSION_8D]; + int32_t end_[DIMENSION_8D]; + int32_t param_length_; +} SliceStruct; + +KernelBase *CreateSlice(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_SLICE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/softmax.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/softmax.c new file mode 100644 index 0000000000000000000000000000000000000000..68ebb66d6429d099d1acc389380841ba4d8a12d6 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/softmax.c @@ -0,0 +1,157 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/softmax.h" +#include "nnacl/nnacl_common.h" +#include "nnacl/fp32/softmax_fp32.h" +#include "nnacl/kernel/default_kernel_base.h" +#include "nnacl/tensor_c_utils.h" +#ifdef ENABLE_FP16 +#include "nnacl/fp16/softmax_fp16.h" +#endif + +int SoftmaxLastAxisRun(void *cdata, int task_id, float l, float r) { + SoftmaxStruct *softmax = (SoftmaxStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(softmax); + + NNACL_CHECK_ZERO_RETURN_ERR(softmax->base_.thread_nr_); + int unit = UP_DIV(softmax->out_plane_size_, softmax->base_.thread_nr_); + + int *in_shape = softmax->base_.in_[FIRST_INPUT]->shape_; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(task_id, unit, NNACL_ERR); + int begin = task_id * unit; + int end = MSMIN(begin + unit, softmax->out_plane_size_); + int channel = in_shape[softmax->axis_]; + + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(begin, channel, NNACL_ERR); + int offset = begin * channel; + + void *input_ptr = softmax->base_.in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_ptr); + void *output_ptr = softmax->base_.out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_ptr); + +#ifdef ENABLE_FP16 + if (softmax->data_type_ == kNumberTypeFloat16) { + SoftmaxLastAxisFp16((float16_t *)input_ptr + offset, (float16_t *)output_ptr + offset, end - begin, channel); + return NNACL_OK; + } +#endif + return SoftmaxLastAxis((float *)input_ptr + offset, (float *)output_ptr + offset, end - begin, channel); +} + +int SoftmaxRelease(struct KernelBase *self) { + SoftmaxStruct *softmax = (SoftmaxStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(softmax); + if (softmax->sum_data_ != NULL) { + self->env_->Free(self->env_->allocator_, softmax->sum_data_); + } + softmax->sum_data_ = NULL; + return NNACL_OK; +} + +int InitSoftmaxParam(SoftmaxStruct *softmax) { + TensorC *in_tensor = softmax->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in_tensor); + int *in_shape = in_tensor->shape_; + + softmax->n_dim_ = (int)in_tensor->shape_size_; + int origin_axis = ((SoftmaxParameter *)softmax->base_.param_)->axis_; + softmax->axis_ = origin_axis == -1 ? origin_axis + softmax->n_dim_ : origin_axis; + + NNACL_CHECK_TRUE_RET(softmax->axis_ >= 0, NNACL_SOFTMAX_AXIS_INVALID); + NNACL_CHECK_TRUE_RET(softmax->axis_ < (int)in_tensor->shape_size_, NNACL_SOFTMAX_AXIS_INVALID); + + int out_plane_size = 1; + for (int i = 0; i < softmax->axis_; ++i) { + out_plane_size *= in_shape[i]; + } + int in_plane_size = 1; + for (int i = softmax->axis_ + 1; i < softmax->n_dim_; i++) { + in_plane_size *= in_shape[i]; + } + + ExecEnv *env = softmax->base_.env_; + NNACL_CHECK_NULL_RETURN_ERR(env); + + softmax->in_plane_size_ = in_plane_size; + softmax->out_plane_size_ = out_plane_size; + + (void)softmax->base_.Release(&softmax->base_); + if (softmax->in_plane_size_ > 1) { + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(out_plane_size, in_plane_size, NNACL_ERR); + int sum_data_size = out_plane_size * in_plane_size; + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(sum_data_size, (int)DataTypeCSize(softmax->data_type_), NNACL_ERR); + softmax->sum_data_ = env->Alloc(env->allocator_, sum_data_size * DataTypeCSize(softmax->data_type_)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(softmax->sum_data_); + } + return NNACL_OK; +} + +int SoftmaxResize(struct KernelBase *self) { + SoftmaxStruct *softmax = (SoftmaxStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(softmax); + InitSoftmaxParam(softmax); + + TensorC *in_tensor = self->in_[FIRST_INPUT]; + int *in_shape = in_tensor->shape_; + + self->thread_nr_ = self->UpdateThread(TC_PTYPE(PrimType_Softmax), in_shape[softmax->axis_], in_shape[softmax->axis_], + NNACLGetElementNum(self->out_[OUTPUT_INDEX]), self->thread_nr_); + return NNACL_OK; +} + +int SoftmaxCompute(struct KernelBase *self) { + SoftmaxStruct *softmax = (SoftmaxStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(softmax); + + if (softmax->in_plane_size_ == 1) { + return self->env_->ParallelLaunch(self->env_->thread_pool_, SoftmaxLastAxisRun, softmax, self->thread_nr_); + } + + void *input_ptr = self->in_[FIRST_INPUT]->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_ptr); + void *output_ptr = self->out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_ptr); + NNACL_CHECK_NULL_RETURN_ERR(softmax->sum_data_); +#ifdef ENABLE_FP16 + if (softmax->data_type_ == kNumberTypeFloat16) { + SoftmaxFp16((float16_t *)input_ptr, (float16_t *)output_ptr, (float16_t *)softmax->sum_data_, softmax->axis_, + softmax->n_dim_, self->in_[FIRST_INPUT]->shape_); + return NNACL_OK; + } +#endif + Softmax((float *)input_ptr, (float *)output_ptr, (float *)softmax->sum_data_, softmax->axis_, softmax->n_dim_, + self->in_[FIRST_INPUT]->shape_); + return NNACL_OK; +} + +KernelBase *CreateSoftmax(OpParameter *param, int data_type) { + SoftmaxStruct *softmax = (SoftmaxStruct *)malloc(sizeof(SoftmaxStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(softmax); + memset(softmax, 0, sizeof(SoftmaxStruct)); + + softmax->sum_data_ = NULL; + softmax->data_type_ = data_type; + softmax->base_.Release = SoftmaxRelease; + softmax->base_.Prepare = DefaultPrepare1In1Out; + softmax->base_.Resize = SoftmaxResize; + softmax->base_.Compute = SoftmaxCompute; + return (KernelBase *)softmax; +} + +REG_KERNEL_CREATOR(PrimType_Softmax, kNumberTypeFloat16, CreateSoftmax) +REG_KERNEL_CREATOR(PrimType_Softmax, kNumberTypeFloat32, CreateSoftmax) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/softmax.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/softmax.h new file mode 100644 index 0000000000000000000000000000000000000000..dd6eec291fc6df5100e4e7343ddc7e3d42fefa7a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/softmax.h @@ -0,0 +1,39 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_SOFTMAX_H_ +#define NNACL_KERNEL_SOFTMAX_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct SoftmaxStruct { + KernelBase base_; + int axis_; + int n_dim_; + int in_plane_size_; + int out_plane_size_; + void *sum_data_; + TypeIdC data_type_; + int unit_; +} SoftmaxStruct; + +int InitSoftmaxParam(SoftmaxStruct *softmax); +int SoftmaxRelease(struct KernelBase *self); +KernelBase *CreateSoftmax(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_SOFTMAX_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/splice.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/splice.c new file mode 100644 index 0000000000000000000000000000000000000000..2b74d789faf11c54a01b9a410a9167e55cb5d8ae --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/splice.c @@ -0,0 +1,79 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/splice.h" +#include "nnacl/kernel/default_kernel_base.h" +#include "nnacl/splice_parameter.h" +#include "nnacl/fp32/splice_fp32.h" +#ifdef ENABLE_FP16 +#include "nnacl/fp16/splice_fp16.h" +#endif + +int SpliceCompute(struct KernelBase *self) { + TensorC *input = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + TensorC *output = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output); + + NNACL_CHECK_FALSE(input->shape_size_ != output->shape_size_, NNACL_SPLICE_SHAPE_INVALID); + NNACL_CHECK_FALSE(input->shape_size_ != DIMENSION_3D, NNACL_SPLICE_SHAPE_INVALID); + NNACL_CHECK_FALSE(output->shape_size_ != DIMENSION_3D, NNACL_SPLICE_SHAPE_INVALID); + + SpliceParameter *param = (SpliceParameter *)self->param_; + NNACL_CHECK_NULL_RETURN_ERR(param); + + int src_row = input->shape_[Index1]; + int src_col = input->shape_[Index2]; + int dst_row = output->shape_[Index1]; + int dst_col = output->shape_[Index2]; + + NNACL_CHECK_FALSE(src_col * param->context_dim_ != dst_col, NNACL_SPLICE_SHAPE_INVALID); + NNACL_CHECK_FALSE(param->context_dim_ * dst_row != param->forward_indexes_dim_, NNACL_SPLICE_SHAPE_INVALID); + + for (int i = 0; i < param->forward_indexes_dim_; ++i) { + if (param->forward_indexes_[i] >= src_row) { + return NNACL_SPLICE_SHAPE_INVALID; + } + } + + void *input_data = input->data_; + NNACL_CHECK_NULL_RETURN_ERR(input_data); + void *output_data = output->data_; + NNACL_CHECK_NULL_RETURN_ERR(output_data); + +#ifdef ENABLE_FP16 + if (input->data_type_ == kNumberTypeFloat16) { + SpliceFp16((float16_t *)input_data, src_row, src_col, param, (float16_t *)output_data, dst_row, dst_col); + return NNACL_OK; + } +#endif + + SpliceFp32((float *)input_data, src_row, src_col, param, (float *)output_data, dst_row, dst_col); + return NNACL_OK; +} + +KernelBase *CreateSplice(OpParameter *param, int data_type) { + SpliceStruct *splice = (SpliceStruct *)malloc(sizeof(SpliceStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(splice); + splice->base_.Release = DefaultRelease; + splice->base_.Prepare = DefaultPrepare1In1Out; + splice->base_.Resize = DefaultResize; + splice->base_.Compute = SpliceCompute; + return (KernelBase *)splice; +} + +REG_KERNEL_CREATOR(PrimType_Splice, kNumberTypeFloat32, CreateSplice) +REG_KERNEL_CREATOR(PrimType_Splice, kNumberTypeFloat16, CreateSplice) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/splice.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/splice.h new file mode 100644 index 0000000000000000000000000000000000000000..306274acdc68254a077c5d2b8edd13e854265957 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/splice.h @@ -0,0 +1,30 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_SPLICE_H_ +#define NNACL_KERNEL_SPLICE_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct SpliceStruct { + KernelBase base_; +} SpliceStruct; + +KernelBase *CreateSplice(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_SPLICE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/stack.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/stack.c new file mode 100644 index 0000000000000000000000000000000000000000..673463f97ca14c9e8cd697b9f35580c299188e20 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/stack.c @@ -0,0 +1,138 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/stack.h" +#include "nnacl/op_base.h" +#include "nnacl/stack_parameter.h" +#include "nnacl/nnacl_common.h" +#include "nnacl/base/stack_base.h" +#include "nnacl/tensor_c_utils.h" + +static inline int GetCopyNum(const int *in_shape, int axis, int n_dim) { + int copy_num = 1; + if (axis > 0) { + for (int j = n_dim - 1; j > axis - 1; j--) { + copy_num *= in_shape[j]; + } + } else { + for (int i = 0; i < n_dim; ++i) { + copy_num *= in_shape[i]; + } + } + return copy_num; +} + +static inline int GetOuterSize(const int *in_shape, int axis) { + int outer_size = 1; + for (int i = 0; i < axis; ++i) { + outer_size *= in_shape[i]; + } + return outer_size; +} + +int StackRelease(KernelBase *self) { + StackStruct *stack = (StackStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(stack); + if (stack->buffers_ != NULL) { + self->env_->Free(self->env_->allocator_, stack->buffers_); + stack->buffers_ = NULL; + } + return NNACL_OK; +} + +int StackPrepare(KernelBase *self) { + StackStruct *stack = (StackStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(stack); + NNACL_CHECK_FALSE(self->in_size_ < ONE_TENSOR, NNACL_ERR); + NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR); + stack->buffers_ = + (void **)self->env_->Alloc(self->env_->allocator_, (self->in_size_ + self->out_size_) * sizeof(void *)); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(stack->buffers_); + return NNACL_OK; +} + +int StackResize(KernelBase *self) { + StackStruct *stack = (StackStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(stack); + TensorC *input = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + + int origin_axis = ((StackParameter *)self->param_)->axis_; + stack->axis_ = origin_axis < 0 ? origin_axis + (int)input->shape_size_ + 1 : origin_axis; + + if (self->in_size_ == 1) { + NNACL_CHECK_FALSE(NNACLGetElementNum(input) <= 0, NNACL_STACK_TENSOR_SHAPE_INVALID); + stack->copy_size_ = (size_t)NNACLGetElementNum(input) * DataTypeCSize(stack->data_type_); + stack->outer_size_ = 1; + } else { + NNACL_CHECK_FALSE((int)input->shape_size_ < stack->axis_, NNACL_STACK_TENSOR_SHAPE_INVALID); + size_t copy_num = (size_t)GetCopyNum(input->shape_, stack->axis_, input->shape_size_); + stack->copy_size_ = copy_num * DataTypeCSize(stack->data_type_); + stack->outer_size_ = GetOuterSize(input->shape_, stack->axis_); + } + + self->thread_nr_ = self->UpdateThread(TC_PTYPE(PrimType_Stack), stack->copy_size_, stack->copy_size_, + NNACLGetElementNum(self->out_[OUTPUT_INDEX]), self->thread_nr_); + self->thread_nr_ = NNACL_MIN(UP_DIV(stack->outer_size_, NNACL_STACK_STEP), self->thread_nr_); + return NNACL_OK; +} + +int StackRun(void *cdata, int task_id, float l, float r) { + StackStruct *stack = (StackStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(stack); + + NNACL_CHECK_TRUE_RET(stack->base_.thread_nr_ != 0, NNACL_ERR); + int step = UP_DIV(stack->outer_size_, stack->base_.thread_nr_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(task_id, step, NNACL_ERR); + int start = task_id * step; + int end = NNACL_MIN(start + step, stack->outer_size_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(stack->base_.in_size_ * (size_t)start, stack->copy_size_, NNACL_ERR); + + void *output_data = (void *)(stack->base_.out_[OUTPUT_INDEX]->data_); + NNACL_CHECK_NULL_RETURN_ERR(output_data); + uint8_t *output = (uint8_t *)output_data + stack->base_.in_size_ * (size_t)start * stack->copy_size_; + + Stack(stack->buffers_, (void *)output, stack->base_.in_size_, stack->copy_size_, start, end); + return NNACL_OK; +} + +int StackCompute(KernelBase *self) { + StackStruct *stack = (StackStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(stack); + + for (size_t i = 0; i < self->in_size_; ++i) { + stack->buffers_[i] = self->in_[i]->data_; + NNACL_CHECK_NULL_RETURN_ERR(stack->buffers_[i]); + } + stack->buffers_[self->in_size_] = self->out_[OUTPUT_INDEX]->data_; + NNACL_CHECK_NULL_RETURN_ERR(stack->buffers_[self->in_size_]); + return self->env_->ParallelLaunch(self->env_->thread_pool_, StackRun, self, self->thread_nr_); +} + +KernelBase *CreateStack(OpParameter *param, int data_type) { + StackStruct *stack = (StackStruct *)malloc(sizeof(StackStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(stack); + stack->buffers_ = NULL; + stack->data_type_ = data_type; + stack->base_.Release = StackRelease; + stack->base_.Prepare = StackPrepare; + stack->base_.Resize = StackResize; + stack->base_.Compute = StackCompute; + return (KernelBase *)stack; +} + +REG_KERNEL_CREATOR(PrimType_Stack, kNumberTypeFloat32, CreateStack) +REG_KERNEL_CREATOR(PrimType_Stack, kNumberTypeInt32, CreateStack) diff --git a/mindspore-lite/src/extendrt/kernel/kernel_spec_infos.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/stack.h similarity index 51% rename from mindspore-lite/src/extendrt/kernel/kernel_spec_infos.h rename to mindspore-lite/ops/kernel/cpu/nnacl/kernel/stack.h index 925a0f655e53f06e0763df0aa27a15f853404c66..77e8fa6fa6e84b0052e83535d76cbedce540d0b9 100644 --- a/mindspore-lite/src/extendrt/kernel/kernel_spec_infos.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/stack.h @@ -14,20 +14,28 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_KERNEL_LIB_INFOS_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_KERNEL_LIB_INFOS_H_ +#ifndef NNACL_KERNEL_STACK_H_ +#define NNACL_KERNEL_STACK_H_ -namespace mindspore::kernel { -// kernel library names -constexpr char kNNACLLibName[] = "NNACL"; -constexpr char kAclKernelLibName[] = "Acl"; -constexpr char kDefaultKernelLibName[] = "Default"; +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" -// backend -constexpr char kBackendCPU[] = "CPU"; -constexpr char kBackendGPU[] = "GPU"; -constexpr char kBackendNPU[] = "NPU"; -constexpr char kBackendAscend[] = "Ascend"; -} // namespace mindspore::kernel +#define NNACL_STACK_STEP 64 -#endif // MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_KERNEL_LIB_INFOS_H_ +typedef struct StackStruct { + KernelBase base_; + TypeIdC data_type_; + int axis_; + int outer_size_; + size_t copy_size_; + void **buffers_; +} StackStruct; + +KernelBase *CreateStack(OpParameter *param, int data_type); +int StackRun(void *cdata, int task_id, float l, float r); +int StackRelease(KernelBase *self); +int StackPrepare(KernelBase *self); +int StackResize(KernelBase *self); + +#endif // NNACL_KERNEL_STACK_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/strided_slice.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/strided_slice.c new file mode 100644 index 0000000000000000000000000000000000000000..951eee0ad0e2d7dc1d7d6816255bc71293d4e5a4 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/strided_slice.c @@ -0,0 +1,278 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/strided_slice.h" +#include "nnacl/strided_slice_parameter.h" +#include "nnacl/nnacl_common.h" +#include "nnacl/op_base.h" +#include "nnacl/fp32/strided_slice_fp32.h" +#include "nnacl/kernel/reshape.h" +#include "nnacl/kernel/default_kernel_base.h" +#include "nnacl/tensor_c_utils.h" + +#define MinStridedSlicePerThread 16384 + +int StridedSliceFaseRun(void *cdata, int task_id, float l, float r) { + StridedSliceStruct *strided_slice = (StridedSliceStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(strided_slice); + + uint8_t *input_data = strided_slice->base_.in_[FIRST_INPUT]->data_; + uint8_t *output_data = strided_slice->base_.out_[OUTPUT_INDEX]->data_; + int *in_shape = strided_slice->base_.in_[FIRST_INPUT]->shape_; + int *out_shape = strided_slice->base_.out_[OUTPUT_INDEX]->shape_; + int begin_index = strided_slice->begins_[strided_slice->split_axis_]; + int caled_num = task_id * strided_slice->cal_num_per_thread_; + int64_t inner_size = (int64_t)strided_slice->inner_size_; + + if (strided_slice->parallel_on_outer_) { + uint8_t *cur_in_ptr = input_data + (caled_num * in_shape[strided_slice->split_axis_] + begin_index) * inner_size; + uint8_t *cur_out_ptr = output_data + caled_num * out_shape[strided_slice->split_axis_] * inner_size; + int cur_outer = (int)strided_slice->outer_ - caled_num; + if (cur_outer <= 0) { + return NNACL_OK; + } + if (cur_outer > strided_slice->cal_num_per_thread_) { + cur_outer = strided_slice->cal_num_per_thread_; + } + FastStride(cur_in_ptr, cur_out_ptr, out_shape[strided_slice->split_axis_], + strided_slice->strides_[strided_slice->split_axis_], cur_outer, strided_slice->inner_size_, + (size_t)in_shape[strided_slice->split_axis_] * strided_slice->inner_size_); + return NNACL_OK; + } + + if (strided_slice->parallel_on_split_axis_) { + uint8_t *cur_in_ptr = + input_data + (caled_num * strided_slice->strides_[strided_slice->split_axis_] + begin_index) * inner_size; + uint8_t *cur_out_ptr = output_data + caled_num * inner_size; + int cal_axis_num = out_shape[strided_slice->split_axis_] - caled_num; + if (cal_axis_num <= 0) { + return NNACL_OK; + } + if (cal_axis_num > strided_slice->cal_num_per_thread_) { + cal_axis_num = strided_slice->cal_num_per_thread_; + } + FastStride(cur_in_ptr, cur_out_ptr, (uint32_t)cal_axis_num, strided_slice->strides_[strided_slice->split_axis_], 1, + strided_slice->inner_size_, 0); + return NNACL_OK; + } + + return NNACL_STRIDED_SLICE_INVALID_PARALLEL_MOD; +} + +int StridedSliceFastRun(StridedSliceStruct *strided_slice) { + // Update length of inner size, because data type of tensor may be changed + // from float32 to float16 during fp16 sub-graph partition process. + size_t data_type_size = DataTypeCSize(strided_slice->base_.in_[FIRST_INPUT]->data_type_); + NNACL_CHECK_FALSE(data_type_size == 0, NNACL_STRIDED_SLICE_UNSUPPORTED_DATA_TYPE); + strided_slice->inner_size_ = strided_slice->inner_ * data_type_size; + + NNACL_CHECK_NULL_RETURN_ERR(strided_slice->base_.in_[FIRST_INPUT]->data_); + NNACL_CHECK_NULL_RETURN_ERR(strided_slice->base_.in_[OUTPUT_INDEX]->data_); + return strided_slice->base_.env_->ParallelLaunch(strided_slice->base_.env_->thread_pool_, StridedSliceFaseRun, + strided_slice, strided_slice->base_.thread_nr_); +} + +bool StridedSliceMatchInOutShapeEqualPattern(StridedSliceStruct *strided_slice) { + for (int i = 0; i < MAX_SHAPE_SIZE; i++) { + if (strided_slice->strides_[i] < 0) { + return false; + } + } + + TensorC *in_tensor = strided_slice->base_.in_[FIRST_INPUT]; + TensorC *out_tensor = strided_slice->base_.out_[OUTPUT_INDEX]; + + if (in_tensor->data_type_ != out_tensor->data_type_) { + return false; + } + + if (in_tensor->shape_size_ != out_tensor->shape_size_) { + return false; + } + + if (in_tensor->shape_size_ < ONE_TENSOR) { + return false; + } + + for (size_t i = 0; i < in_tensor->shape_size_; ++i) { + if (in_tensor->shape_[i] != out_tensor->shape_[i]) { + return false; + } + if (in_tensor->shape_[i] == -1) { + return false; + } + } + return true; +} + +int StridedSliceSoftCopyInputToOutput(StridedSliceStruct *strided_slice) { + TensorC *in_tensor = strided_slice->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in_tensor); + NNACL_CHECK_NULL_RETURN_ERR(in_tensor->data_); + TensorC *out_tensor = strided_slice->base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(out_tensor); + NNACL_CHECK_NULL_RETURN_ERR(out_tensor->data_); + + int total_num = NNACLGetElementNum(in_tensor); + NNACL_CHECK_FALSE(total_num == 0, NNACL_STRIDED_SLICE_INVALID_DATA_SIZE); + + strided_slice->base_.thread_nr_ = + NNACL_MIN(strided_slice->base_.thread_nr_, UP_DIV(total_num, MinStridedSlicePerThread)); + if (strided_slice->base_.thread_nr_ < 1) { + strided_slice->base_.thread_nr_ = 1; + } + + int block_num = UP_DIV(total_num, strided_slice->base_.thread_nr_); + strided_slice->base_.thread_nr_ = UP_DIV(total_num, block_num); + + if (in_tensor->data_ != out_tensor->data_) { + if (strided_slice->base_.thread_nr_ == 1) { + (void)memcpy(out_tensor->data_, in_tensor->data_, total_num * (int)DataTypeCSize(in_tensor->data_type_)); + return NNACL_OK; + } + ReshapeStruct reshape; + reshape.base_.in_ = strided_slice->base_.in_; + reshape.base_.out_ = strided_slice->base_.out_; + reshape.block_num_ = block_num; + reshape.total_num_ = total_num; + reshape.base_.thread_nr_ = strided_slice->base_.thread_nr_; + return strided_slice->base_.env_->ParallelLaunch(strided_slice->base_.env_->thread_pool_, ParallelReshape, &reshape, + strided_slice->base_.thread_nr_); + } + return NNACL_OK; +} + +bool StridedSliceMatchFastPattern(StridedSliceStruct *strided_slice) { + // This function is seeking if that the number of only one dimension + // is different between input and output. If so, we can do some trick. + // Example 1: + // input shape info: [1, 80, 46, 40] + // output shape info: [1, 80, 20, 40] + // Example 2: + // input shape info: [1, 46, 40] + // output shape info: [1, 20, 40] + TensorC *in_tensor = strided_slice->base_.in_[FIRST_INPUT]; + TensorC *out_tensor = strided_slice->base_.out_[OUTPUT_INDEX]; + if (in_tensor->shape_size_ != out_tensor->shape_size_) { + return false; + } + + int axis_list[MAX_SHAPE_SIZE]; + int axis_list_size = 0; + for (size_t i = 0; i < in_tensor->shape_size_; i++) { + if (in_tensor->shape_[i] != out_tensor->shape_[i]) { + axis_list[axis_list_size++] = (int)i; + } + } + if (axis_list_size == 1) { + strided_slice->split_axis_ = axis_list[Index0]; + return true; + } + return false; +} + +void StridedSliceInitFastRunParam(StridedSliceStruct *strided_slice) { + TensorC *input_tenspr = strided_slice->base_.in_[FIRST_INPUT]; + int *in_shape = input_tenspr->shape_; + int *out_shape = strided_slice->base_.out_[OUTPUT_INDEX]->shape_; + + // reset && cal inner, outer + strided_slice->outer_ = 1; + strided_slice->inner_ = 1; + for (int i = 0; i < strided_slice->split_axis_; ++i) { + strided_slice->outer_ *= (size_t)in_shape[i]; + } + for (size_t i = (size_t)strided_slice->split_axis_ + 1; i < input_tenspr->shape_size_; i++) { + strided_slice->inner_ *= (size_t)in_shape[i]; + } + + if (strided_slice->outer_ == 1) { + strided_slice->parallel_on_split_axis_ = true; + strided_slice->parallel_on_outer_ = false; + } else { + strided_slice->parallel_on_split_axis_ = false; + strided_slice->parallel_on_outer_ = true; + } + + strided_slice->base_.thread_nr_ = strided_slice->base_.UpdateThread( + TC_TYPE(PrimType_StridedSlice, strided_slice->parallel_on_outer_), 1, 1, + NNACLGetElementNum(strided_slice->base_.out_[OUTPUT_INDEX]), strided_slice->base_.thread_nr_); + + strided_slice->cal_num_per_thread_ = + strided_slice->parallel_on_split_axis_ + ? UP_DIV(out_shape[strided_slice->split_axis_], strided_slice->base_.thread_nr_) + : UP_DIV((int)strided_slice->outer_, strided_slice->base_.thread_nr_); +} + +int StridedSliceResize(KernelBase *self) { + StridedSliceStruct *strided_slice = (StridedSliceStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(strided_slice); + + NNACL_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]); + NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]); + NNACL_CHECK_FALSE(self->in_[FIRST_INPUT]->shape_size_ > MAX_SHAPE_SIZE, NNACL_STRIDED_SLICE_INVALID_SHAPE_SIZE); + + StridedSliceParameter *param = (StridedSliceParameter *)self->param_; + memcpy(strided_slice->begins_, param->begins_, MAX_SHAPE_SIZE * sizeof(int)); + memcpy(strided_slice->ends_, param->ends_, MAX_SHAPE_SIZE * sizeof(int)); + memcpy(strided_slice->in_shape_, param->in_shape_, MAX_SHAPE_SIZE * sizeof(int)); + memcpy(strided_slice->strides_, param->strides_, MAX_SHAPE_SIZE * sizeof(int)); + strided_slice->in_shape_size_ = param->in_shape_length_; + + strided_slice->soft_copy_mode_ = StridedSliceMatchInOutShapeEqualPattern(strided_slice); + strided_slice->fast_run_ = StridedSliceMatchFastPattern(strided_slice); + if (strided_slice->fast_run_) { + StridedSliceInitFastRunParam(strided_slice); + } + + if (strided_slice->soft_copy_mode_ == false && strided_slice->fast_run_ == false) { + return PadStridedSliceParameterTo8D(strided_slice); + } + + return NNACL_OK; +} + +int StridedSliceCompute(KernelBase *self) { + StridedSliceStruct *strided_slice = (StridedSliceStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(strided_slice); + + if (strided_slice->soft_copy_mode_) { + return StridedSliceSoftCopyInputToOutput(strided_slice); + } + if (strided_slice->fast_run_) { + return StridedSliceFastRun(strided_slice); + } + + return DoStridedSliceIn8D(self->in_[FIRST_INPUT]->data_, self->out_[OUTPUT_INDEX]->data_, strided_slice); +} + +KernelBase *CreateStridedSlice(OpParameter *param, int data_type) { + StridedSliceStruct *strided_slice = (StridedSliceStruct *)malloc(sizeof(StridedSliceStruct)); + NNACL_CHECK_NULL_RETURN_NULL(strided_slice); + strided_slice->data_type_ = data_type; + strided_slice->base_.Release = DefaultRelease; + strided_slice->base_.Prepare = DefaultPrepare1In1Out; + strided_slice->base_.Resize = StridedSliceResize; + strided_slice->base_.Compute = StridedSliceCompute; + return (KernelBase *)strided_slice; +} + +REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeFloat32, CreateStridedSlice) +REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeFloat16, CreateStridedSlice) +REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeInt64, CreateStridedSlice) +REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeInt32, CreateStridedSlice) +REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeInt8, CreateStridedSlice) +REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeBool, CreateStridedSlice) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/strided_slice.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/strided_slice.h new file mode 100644 index 0000000000000000000000000000000000000000..5e4246ad117f8e49a9862d35246c7d0270f61502 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/strided_slice.h @@ -0,0 +1,47 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_STRIDED_SLICE_H_ +#define NNACL_KERNEL_STRIDED_SLICE_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct StridedSliceStruct { + KernelBase base_; + TypeIdC data_type_; + bool fast_run_; + bool soft_copy_mode_; + bool parallel_on_outer_; + bool parallel_on_split_axis_; + + int split_axis_; + int in_shape_size_; + int begins_[MAX_SHAPE_SIZE]; + int ends_[MAX_SHAPE_SIZE]; + int strides_[MAX_SHAPE_SIZE]; + int in_shape_[MAX_SHAPE_SIZE]; + + size_t inner_; + size_t outer_; + size_t inner_size_; + int cal_num_per_thread_; +} StridedSliceStruct; + +KernelBase *CreateStridedSlice(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_STRIDED_SLICE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/tile.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/tile.c new file mode 100644 index 0000000000000000000000000000000000000000..d1b35da224a7ba4ddd2f9da5a99bb4dcab679523 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/tile.c @@ -0,0 +1,182 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/tile.h" +#include "nnacl/tile_parameter.h" +#include "nnacl/tensor_c_utils.h" +#include "nnacl/nnacl_common.h" +#include "nnacl/op_base.h" +#include "nnacl/base/tile_base.h" +#include "nnacl/kernel/default_kernel_base.h" + +#define kDoubleInputsSize 2 + +int TileDoubleInputScenes(TileStruct *tile) { + TensorC *t = tile->base_.in_[SECOND_INPUT]; + if (t->data_ == NULL) { + tile->resize_done_ = false; + return NNACL_OK; + } + + NNACL_CHECK_FALSE(NNACLGetElementNum(t) > (int)tile->base_.in_[FIRST_INPUT]->shape_size_, + NNACL_TILE_SECOND_INPUT_NUM_INVALID); + NNACL_CHECK_FALSE(t->data_type_ != kNumberTypeInt && t->data_type_ != kNumberTypeInt32, + NNACL_TILE_SECOND_INPUT_DATA_TYPE_INVALID); + + int *input1_addr = (int *)(t->data_); + for (int i = 0; i < NNACLGetElementNum(t); ++i) { + NNACL_CHECK_FALSE(input1_addr[i] <= 0, NNACL_TILE_SECOND_INPUT_VALUE_INVALID); + tile->dims_[i] = i; + tile->multiples_[i] = input1_addr[i]; + } + return NNACL_OK; +} + +int SimpleTileImpl(TileStruct *tile, int task_id) { + NNACL_CHECK_ZERO_RETURN_ERR(tile->base_.thread_nr_); + size_t unit = UP_DIV(tile->fast_outer_size_, (size_t)tile->base_.thread_nr_); + if (unit == 0 && task_id > 0) { + return NNACL_OK; + } + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(unit, (size_t)task_id), NNACL_ERR); + size_t begin = unit * (size_t)(task_id); + size_t end = MSMIN(begin + unit, tile->fast_outer_size_); + TileSimple(tile->input_addr_, tile->output_addr_, begin, end, tile); + return NNACL_OK; +} + +int SimpleTile(void *cdata, int task_id, float l, float r) { + TileStruct *tile = (TileStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(tile); + return SimpleTileImpl(tile, task_id); +} + +int TileFillOneDimTileParam(TileStruct *tile) { + // check if tile exact one dim + int large_one_multiple_count = 0; + int multiple = 0; + int mul_index = 0; + + for (int i = 0; i < tile->in_dim_; ++i) { + if (tile->multiples_[i] > 1) { + large_one_multiple_count++; + multiple = tile->multiples_[i]; + mul_index = i; + } + } + tile->one_dim_tile_ = large_one_multiple_count == 1; + if (tile->one_dim_tile_) { + tile->fast_multiple_ = (size_t)multiple; + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(tile->in_shape_[mul_index], tile->in_strides_[mul_index]), NNACL_ERR); + tile->fast_stride_ = (size_t)(tile->in_shape_[mul_index] * tile->in_strides_[mul_index]); + NNACL_CHECK_FALSE(tile->fast_stride_ < 1, NNACL_TILE_INPUT_SHAPE_INVALID); + tile->fast_outer_size_ = (size_t)NNACLGetElementNum(tile->base_.in_[FIRST_INPUT]) / tile->fast_stride_; + } + tile->resize_done_ = true; + return NNACL_OK; +} + +int TileResize(struct KernelBase *self) { + TileStruct *tile = (TileStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(tile); + TileParameter *param = (TileParameter *)(self->param_); + NNACL_CHECK_NULL_RETURN_ERR(tile); + + tile->dims_size_ = param->dims_size_; + for (int i = 0; i < MAX_SHAPE_SIZE; i++) { + tile->dims_[i] = param->dims_[i]; + tile->multiples_[i] = param->multiples_[i]; + } + + if (self->in_size_ == kDoubleInputsSize) { + int ret = TileDoubleInputScenes(tile); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + } + + TensorC *input = self->in_[0]; + TensorC *output = self->out_[0]; + NNACL_CHECK_NULL_RETURN_ERR(input); + NNACL_CHECK_NULL_RETURN_ERR(output); + + tile->in_dim_ = (int)input->shape_size_; + NNACL_CHECK_TRUE_RET(tile->in_dim_ > 0 && tile->in_dim_ <= MAX_SHAPE_SIZE, NNACL_TILE_INPUT_SHAPE_INVALID); + NNACL_CHECK_FALSE((int)output->shape_size_ < tile->in_dim_, NNACL_TILE_INPUT_SHAPE_INVALID); + + for (int i = 0; i < tile->in_dim_; ++i) { + tile->in_shape_[i] = input->shape_[i]; + tile->out_shape_[i] = output->shape_[i]; + } + + ComputeStrides(tile->in_shape_, tile->in_strides_, tile->in_dim_); + ComputeStrides(tile->out_shape_, tile->out_strides_, tile->in_dim_); + + for (size_t i = 0; i < tile->dims_size_; i++) { + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(tile->multiples_[i], tile->in_shape_[i]), NNACL_ERRCODE_MUL_OVERFLOW); + int ele_num = tile->multiples_[i] * tile->in_shape_[i] - 1; + NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(tile->out_strides_[i], ele_num), NNACL_ERRCODE_MUL_OVERFLOW); + } + + int ret = TileFillOneDimTileParam(tile); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + + if (tile->one_dim_tile_) { + self->thread_nr_ = + self->UpdateThread(TC_TYPE(PrimType_TileFusion, 0), 0, 0, tile->fast_outer_size_, self->thread_nr_); + } + return NNACL_OK; +} + +int TileCompute(struct KernelBase *self) { + TileStruct *tile = (TileStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(tile); + tile->input_addr_ = (uint8_t *)(self->in_[FIRST_INPUT]->data_); + NNACL_CHECK_NULL_RETURN_ERR(tile->input_addr_); + tile->output_addr_ = (uint8_t *)(self->out_[OUTPUT_INDEX]->data_); + NNACL_CHECK_NULL_RETURN_ERR(tile->output_addr_); + + if (!tile->resize_done_) { + int ret = TileResize(self); + NNACL_CHECK_FALSE(ret != NNACL_OK, ret); + NNACL_CHECK_FALSE(tile->resize_done_ == false, NNACL_TILE_RESIZE_IN_RUNTIME_FAILED); + } + + tile->data_size_ = DataTypeCSize(self->in_[FIRST_INPUT]->data_type_); + NNACL_CHECK_TRUE_RET(tile->data_size_ > 0, NNACL_UNSUPPORTED_DATA_TYPE); + + if (tile->one_dim_tile_) { + return self->env_->ParallelLaunch(self->env_->thread_pool_, SimpleTile, self, self->thread_nr_); + } + + Tile(tile->input_addr_, tile->output_addr_, tile); + return NNACL_OK; +} + +KernelBase *CreateTile(OpParameter *param, int data_type) { + TileStruct *tile = (TileStruct *)malloc(sizeof(TileStruct)); + NNACL_CHECK_NULL_RETURN_NULL(tile); + tile->resize_done_ = false; + tile->base_.Release = DefaultRelease; + tile->base_.Prepare = DefaultPrepare1In1Out; + tile->base_.Resize = TileResize; + tile->base_.Compute = TileCompute; + return (KernelBase *)tile; +} + +REG_KERNEL_CREATOR(PrimType_TileFusion, kNumberTypeInt32, CreateTile) +REG_KERNEL_CREATOR(PrimType_TileFusion, kNumberTypeFloat32, CreateTile) +REG_KERNEL_CREATOR(PrimType_TileFusion, kNumberTypeFloat16, CreateTile) +REG_KERNEL_CREATOR(PrimType_TileFusion, kNumberTypeBool, CreateTile) +REG_KERNEL_CREATOR(PrimType_TileFusion, kNumberTypeUInt8, CreateTile) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/tile.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/tile.h new file mode 100644 index 0000000000000000000000000000000000000000..b3b0418115725c01be8a3f38d3fb122738071fac --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/tile.h @@ -0,0 +1,48 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_TILE_H_ +#define NNACL_KERNEL_TILE_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct TileStruct { + KernelBase base_; + bool one_dim_tile_; + bool resize_done_; + int dims_[MAX_SHAPE_SIZE]; + size_t dims_size_; + uint8_t *input_addr_; + uint8_t *output_addr_; + + int multiples_[MAX_SHAPE_SIZE]; + int in_shape_[MAX_SHAPE_SIZE]; + int out_shape_[MAX_SHAPE_SIZE]; + int in_strides_[MAX_SHAPE_SIZE]; + int out_strides_[MAX_SHAPE_SIZE]; + + int in_dim_; + size_t data_size_; + size_t fast_outer_size_; + size_t fast_stride_; + size_t fast_multiple_; +} TileStruct; + +KernelBase *CreateTile(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_TILE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/transpose.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/transpose.c new file mode 100644 index 0000000000000000000000000000000000000000..8e3b069d3a105b9ed828ee1f9edd5fe3fccfc383 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/transpose.c @@ -0,0 +1,358 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/transpose.h" +#include "nnacl/fp32/transpose_fp32.h" +#include "nnacl/fp32/pack_fp32.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel/default_kernel_base.h" +#include "nnacl/tensor_c_utils.h" +#ifdef ENABLE_FP16 +#include "nnacl/fp16/pack_fp16.h" +#include "nnacl/fp16/transpose_fp16.h" +#endif + +/* opt perm: { 0, 2, 1 } */ +#define OPT_PERM_0 0 +#define OPT_PERM_1 2 +#define OPT_PERM_2 1 + +int TransposeComputeinMultiThread(TransposeStruct *transpose, int task_id) { + void *in = transpose->base_.in_[FIRST_INPUT]->data_; + void *out = transpose->base_.out_[OUTPUT_INDEX]->data_; + + if (transpose->opt_run_) { + transpose->nhwc2nchw_(in, out, transpose->opt_perm_[FIRST_INPUT], transpose->opt_perm_[SECOND_INPUT], + transpose->opt_perm_[THIRD_INPUT], task_id, transpose->base_.thread_nr_); + } else { + transpose->optimize_(in, out, transpose->out_shape_, transpose->perm_, transpose->strides_, transpose->out_strides_, + transpose->num_axes_, task_id, transpose->base_.thread_nr_); + } + return NNACL_OK; +} + +int TransposeComputeinSingleThread(TransposeStruct *transpose) { + if (transpose->opt_run_ || transpose->num_axes_ > DIMENSION_6D) { + return TransposeComputeinMultiThread(transpose, 0); + } + + void *in = transpose->base_.in_[FIRST_INPUT]->data_; + void *out = transpose->base_.out_[OUTPUT_INDEX]->data_; + return transpose->compute_(in, out, transpose->out_shape_, transpose->perm_, transpose->strides_, + transpose->out_strides_, transpose->data_num_, transpose->num_axes_); +} + +int ResetTransposeStatus(TransposeStruct *transpose) { + transpose->num_axes_ = 0; + if (transpose->base_.in_size_ == C2NUM) { + transpose->num_axes_ = NNACLGetElementNum(transpose->base_.in_[SECOND_INPUT]); + transpose->perm_size_ = transpose->base_.in_[SECOND_INPUT]->shape_[0]; + } + + TensorC *in_tensor = transpose->base_.in_[FIRST_INPUT]; + if (in_tensor->shape_size_ > MAX_TRANSPOSE_DIM_SIZE) { + return NNACL_TRANSPOSE_INSHAPE_OUT_OF_RANGE; + } + + int trans_nd[MAX_TRANSPOSE_DIM_SIZE] = {0, 2, 1}; + int *perm_data; + if ((int)in_tensor->shape_size_ != transpose->num_axes_) { + perm_data = trans_nd; + if (in_tensor->shape_size_ == Num3 && transpose->num_axes_ == Num4) { + transpose->num_axes_ = Num3; + } + if (transpose->num_axes_ == 0) { + for (size_t i = 0; i < in_tensor->shape_size_; ++i) { + trans_nd[i] = (int)in_tensor->shape_size_ - 1 - (int)i; + } + transpose->num_axes_ = (int)in_tensor->shape_size_; + } + } else { + NNACL_CHECK_TRUE_RET(transpose->base_.in_size_ == TWO_TENSOR, NNACL_TRANSPOSE_INPUT_TENSOR_NUM_INVALID); + TensorC *perm_tensor = transpose->base_.in_[SECOND_INPUT]; + if (perm_tensor->data_type_ != kNumberTypeInt32) { + return NNACL_TRANSPOSE_PERM_TENSOR_INVALID; + } + perm_data = (int *)(perm_tensor->data_); + NNACL_CHECK_NULL_RETURN_ERR(perm_data); + int ele_num = NNACLGetElementNum(perm_tensor); + for (int i = 0; i < ele_num; i++) { + for (int j = 0; j < ele_num; j++) { + if (i == perm_data[j]) { + break; + } + if (j == ele_num - 1) { + return NNACL_TRANSPOSE_PERM_TENSOR_VALUE_INVALID; + } + } + } + } + + NNACL_CHECK_TRUE_RET(transpose->num_axes_ <= MAX_TRANSPOSE_DIM_SIZE, NNACL_TRANSPOSE_PERM_DIMS_INVALID); + for (int i = 0; i < transpose->num_axes_; ++i) { + transpose->perm_[i] = perm_data[i]; + } + return NNACL_OK; +} + +void TransposeFreeSegments(int **segments, int segments_size) { + for (int i = 0; i < segments_size; i++) { + if (segments[i] != NULL) { + free(segments[i]); + segments[i] = NULL; + } + } +} + +int TransposeOptimizeShape(TransposeStruct *transpose) { + TensorC *in_tensor = transpose->base_.in_[FIRST_INPUT]; + int *in_shape = in_tensor->shape_; + + // first step, delete dimension where value is 1. + int in_shape_temp[MAX_TRANSPOSE_DIM_SIZE] = {0}; + int in_shape_temp_size = 0; + int perm_diff[MAX_TRANSPOSE_DIM_SIZE] = {0}; + for (size_t i = 0; i < in_tensor->shape_size_; ++i) { + if (in_shape[i] != 1) { + in_shape_temp[in_shape_temp_size++] = in_shape[i]; + continue; + } + for (size_t j = 0; j < in_tensor->shape_size_; ++j) { + if (transpose->perm_[j] < (int)(i)) { + continue; + } + if (transpose->perm_[j] == (int)(i)) { + perm_diff[j] = (int)(i) + 1; + } else { + perm_diff[j] += 1; + } + } + } + + int perm_temp[MAX_TRANSPOSE_DIM_SIZE] = {0}; + int perm_temp_size = 0; + for (size_t i = 0; i < in_tensor->shape_size_; ++i) { + int diff = transpose->perm_[i] - perm_diff[i]; + if (diff < 0) { + continue; + } + perm_temp[perm_temp_size++] = diff; + } + + NNACL_CHECK_TRUE_RET(in_shape_temp_size == perm_temp_size, NNACL_TRANSPOSE_PERM_DELETE_DIMENSION_FAILED); + + // second step, fuse continuous dimension.; + int axis_num = in_shape_temp_size; + int *segments[MAX_TRANSPOSE_DIM_SIZE]; + int segment_sizes[MAX_TRANSPOSE_DIM_SIZE]; + int segments_size = 0; + for (int i = 0; i < axis_num;) { + int segment[MAX_TRANSPOSE_DIM_SIZE]; + int segment_size = 0; + segment[segment_size++] = perm_temp[i]; + ++i; + for (; i < axis_num; ++i) { + if (perm_temp[i] - 1 != perm_temp[i - 1]) { + break; + } + segment[segment_size++] = perm_temp[i]; + } + + segments[segments_size] = malloc(segment_size * sizeof(int)); + if (segments[segments_size] == NULL) { + TransposeFreeSegments(segments, segments_size); + return NNACL_NULL_PTR; + } + memcpy(segments[segments_size], segment, segment_size * sizeof(int)); + segment_sizes[segments_size] = segment_size; + segments_size++; + } + + transpose->in_shape_size_ = segments_size; + transpose->perm_size_ = segments_size; + for (int i = 0; i < segments_size; i++) { + transpose->in_shape_[i] = 1; + transpose->perm_[i] = 0; + } + for (int i = 0; i < segments_size; ++i) { + for (int j = 0; j < segments_size; ++j) { + transpose->perm_[i] += (segments[j][FIRST_INPUT] < segments[i][FIRST_INPUT] ? 1 : 0); + } + for (int k = 0; k < segment_sizes[i]; ++k) { + transpose->in_shape_[transpose->perm_[i]] *= in_shape_temp[segments[i][k]]; + } + } + TransposeFreeSegments(segments, segments_size); + return NNACL_OK; +} + +void SetTransposeOptInfo(TransposeStruct *transpose) { + // now perm is [1, 0] or [0, 2, 1] + if (transpose->perm_size_ == C2NUM) { + transpose->opt_perm_[FIRST_INPUT] = 1; + transpose->opt_perm_[SECOND_INPUT] = transpose->in_shape_[FIRST_INPUT]; + transpose->opt_perm_[THIRD_INPUT] = transpose->in_shape_[transpose->in_shape_size_ - 1]; + } else { + transpose->opt_perm_[FIRST_INPUT] = transpose->in_shape_[FIRST_INPUT]; + transpose->opt_perm_[SECOND_INPUT] = transpose->in_shape_[SECOND_INPUT]; + transpose->opt_perm_[THIRD_INPUT] = transpose->in_shape_[transpose->in_shape_size_ - 1]; + } +} + +bool TransposeOpt(TransposeStruct *transpose) { + if (transpose->perm_size_ == DIMENSION_2D) { + return true; + } + if (transpose->perm_size_ == DIMENSION_3D && transpose->perm_[FIRST_INPUT] == OPT_PERM_0 && + transpose->perm_[SECOND_INPUT] == OPT_PERM_1 && transpose->perm_[THIRD_INPUT] == OPT_PERM_2) { + return true; + } + return false; +} + +int TransposeComputeOfflineInfo(TransposeStruct *transpose) { + transpose->num_axes_ = transpose->in_shape_size_; + NNACL_CHECK_TRUE_RET(transpose->num_axes_ >= DIMENSION_3D, NNACL_TRANSPOSE_INSHAPE_OUT_OF_RANGE); + + for (int i = 0; i < transpose->num_axes_; ++i) { + transpose->out_shape_[i] = transpose->in_shape_[transpose->perm_[i]]; + } + transpose->strides_[transpose->num_axes_ - 1] = 1; + transpose->out_strides_[transpose->num_axes_ - 1] = 1; + transpose->data_num_ = NNACLGetElementNum(transpose->base_.in_[FIRST_INPUT]); + for (int i = transpose->num_axes_ - 2; i >= 0; i--) { + transpose->strides_[i] = transpose->in_shape_[i + 1] * transpose->strides_[i + 1]; + transpose->out_strides_[i] = transpose->out_shape_[i + 1] * transpose->out_strides_[i + 1]; + } + return NNACL_OK; +} + +int TransposeCopyInputToOutput(TransposeStruct *transpose) { + TensorC *in_tensor = transpose->base_.in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(in_tensor); + NNACL_CHECK_NULL_RETURN_ERR(in_tensor->data_); + TensorC *out_tensor = transpose->base_.out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(out_tensor); + NNACL_CHECK_NULL_RETURN_ERR(out_tensor->data_); + + NNACL_CHECK_FALSE(NNACLGetSize(in_tensor) == 0, NNACL_TRANSPOSE_INPUT_TENSOR_VALUD_INVALID); + if (in_tensor->data_ != out_tensor->data_) { + (void)memcpy(out_tensor->data_, in_tensor->data_, NNACLGetSize(in_tensor)); + } + return NNACL_OK; +} + +int TransposeImpl(void *cdata, int task_id, float l, float r) { + NNACL_CHECK_NULL_RETURN_ERR(cdata); + TransposeStruct *transpose = (TransposeStruct *)cdata; + return TransposeComputeinMultiThread(transpose, task_id); +} + +int TransposeCompute(struct KernelBase *self) { + TransposeStruct *transpose = (TransposeStruct *)self; + if (!transpose->is_valid_) { + return TransposeCopyInputToOutput(transpose); + } + NNACL_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]); + NNACL_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]->data_); + NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]); + NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]->data_); + if (self->thread_nr_ == 1) { + return TransposeComputeinSingleThread(transpose); + } + return self->env_->ParallelLaunch(self->env_->thread_pool_, TransposeImpl, self, self->thread_nr_); +} + +int TransposeResize(struct KernelBase *self) { + TransposeStruct *transpose = (TransposeStruct *)self; + int ret = ResetTransposeStatus(transpose); + if (ret != NNACL_OK) { + return ret; + } + transpose->is_valid_ = (int)transpose->base_.in_[FIRST_INPUT]->shape_size_ == transpose->num_axes_ && + (int)transpose->base_.in_[FIRST_INPUT]->shape_size_ == transpose->perm_size_; + if (!transpose->is_valid_) { + return NNACL_OK; + } + + ret = TransposeOptimizeShape(transpose); + if (ret != NNACL_OK) { + return ret; + } + + transpose->is_valid_ = transpose->perm_size_ > DIMENSION_1D; + if (!transpose->is_valid_) { + return NNACL_OK; + } + + transpose->opt_run_ = TransposeOpt(transpose); + if (transpose->opt_run_) { + SetTransposeOptInfo(transpose); + return NNACL_OK; + } + + ret = TransposeComputeOfflineInfo(transpose); + if (ret != NNACL_OK) { + return ret; + } + + self->thread_nr_ = (!transpose->opt_run_ && transpose->num_axes_ <= DIMENSION_6D) ? 1 : self->thread_nr_; + return NNACL_OK; +} + +int TransposePrepare(struct KernelBase *self) { + int ret = DefaultPrepare1In1Out(self); + if (ret != NNACL_OK) { + return ret; + } + TransposeStruct *transpose = (TransposeStruct *)self; + TransposeParameter *param = (TransposeParameter *)transpose->base_.param_; + if (param->perm_size_ > INT32_MAX) { + return NNACL_TRANSPOSE_PERM_DIMS_INVALID; + } + transpose->perm_size_ = (int)param->perm_size_; + for (int i = 0; i < transpose->perm_size_; i++) { + transpose->perm_[i] = param->perm_[i]; + } + return NNACL_OK; +} + +KernelBase *CreateTranspose(OpParameter *param, int data_type) { + TransposeStruct *transpose = (TransposeStruct *)malloc(sizeof(TransposeStruct)); + NNACL_MALLOC_CHECK_NULL_RETURN_NULL(transpose); + transpose->nhwc2nchw_ = PackNHWCToNCHWFp32; + transpose->optimize_ = TransposeDimsFp32; + transpose->compute_ = DoTransposeFp32; + transpose->base_.Release = DefaultRelease; + transpose->base_.Prepare = TransposePrepare; + transpose->base_.Resize = TransposeResize; + transpose->base_.Compute = TransposeCompute; + if (data_type == kNumberTypeFloat16) { +#ifdef ENABLE_FP16 + transpose->nhwc2nchw_ = PackNHWCToNCHWFp16; + transpose->optimize_ = TransposeDimsFp16; + transpose->compute_ = DoTransposeFp16; +#else + free(transpose); + return NULL; +#endif + } + return (KernelBase *)transpose; +} + +REG_KERNEL_CREATOR(PrimType_Transpose, kNumberTypeFloat32, CreateTranspose) +REG_KERNEL_CREATOR(PrimType_Transpose, kNumberTypeFloat16, CreateTranspose) +REG_KERNEL_CREATOR(PrimType_Transpose, kNumberTypeInt32, CreateTranspose) diff --git a/mindspore-lite/src/extendrt/graph_executor/default_executor.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/transpose.h similarity index 33% rename from mindspore-lite/src/extendrt/graph_executor/default_executor.h rename to mindspore-lite/ops/kernel/cpu/nnacl/kernel/transpose.h index 0af7eadad23934468cb50ef11a8dda08b1f81f3e..63f69f47f301f6bebf3cc96d9a68598021a9fcdd 100644 --- a/mindspore-lite/src/extendrt/graph_executor/default_executor.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/transpose.h @@ -14,44 +14,36 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DEFAULT_EXECUTOR_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DEFAULT_EXECUTOR_H_ - -#include -#include -#include - -#include "infer/executor.h" -#include "infer/execution_plan.h" -#include "litert/executor.h" - -namespace mindspore { -/** - * DefaultExecutor: Execute Kernel one by one, good for acl single kernel graph - */ -class DefaultExecutor : public mindspore::infer::abstract::Executor { - public: - DefaultExecutor(); - explicit DefaultExecutor(const std::string &name, std::shared_ptr execution_plan); - virtual ~DefaultExecutor() = default; - - const std::string &Name() override { return name_; } - - Status Prepare() override; - - Status Execute() override; - - int Resize(const std::vector &inputs, - const std::vector> &dims) override; - - private: - bool Init(); - - private: - std::string name_; - std::shared_ptr execution_plan_; - bool inited_ = false; -}; -} // namespace mindspore - -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DEFAULT_EXECUTOR_H_ +#ifndef NNACL_KERNEL_TRANSPOSE_H_ +#define NNACL_KERNEL_TRANSPOSE_H_ + +#include "nnacl/op_base.h" +#include "nnacl/kernel.h" +#include "nnacl/transpose_parameter.h" + +typedef struct TransposeStruct { + KernelBase base_; + bool is_valid_; + int num_axes_; + int data_num_; + int perm_[MAX_TRANSPOSE_DIM_SIZE]; + int perm_size_; + int in_shape_[MAX_TRANSPOSE_DIM_SIZE]; /* after shape optimize */ + int in_shape_size_; + int out_shape_[MAX_TRANSPOSE_DIM_SIZE]; + int strides_[MAX_TRANSPOSE_DIM_SIZE]; + int out_strides_[MAX_TRANSPOSE_DIM_SIZE]; + + int opt_perm_[PERM_NUM_THREE]; // only valid when opt_run_ is true + bool opt_run_; // only true when perm is [1, 0] or [0, 2, 1] + + int (*compute_)(const void *src, void *dst, const int *out_shape, int *perm, int *strides, int *out_strides, + int data_size, int num_axes); + void (*nhwc2nchw_)(const void *src, void *dst, int b, int hw, int c, int task_id, int thread); + void (*optimize_)(const void *src, void *dst, const int *out_shape, int *perm, int *strides, int *out_strides, + int num_axes, int task_id, int thread); +} TransposeStruct; + +KernelBase *CreateTranspose(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_TRANSPOSE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/tril.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/tril.c new file mode 100644 index 0000000000000000000000000000000000000000..67cd89af74b7fcd10999d189c3fa11e6e815e40b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/tril.c @@ -0,0 +1,89 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/tril.h" +#include "nnacl/kernel/default_kernel_base.h" +#include "nnacl/common_func.h" +#include "nnacl/fp32/triu_tril_fp32.h" + +int TrilCompute(KernelBase *self) { + TrilStruct *tril = (TrilStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(tril); + + int ret = TriuTrilGetKValue(self, &tril->k_); + if (ret != NNACL_OK) { + return ret; + } + + int64_t mul, height, width; + ret = TriuTrilGetCalculateNum(self, &mul, &height, &width); + if (ret != NNACL_OK) { + return ret; + } + + void *src_data = self->in_[FIRST_INPUT]->data_; + void *dst_data = self->out_[OUTPUT_INDEX]->data_; + int type_size = DataTypeCSize(self->in_[FIRST_INPUT]->data_type_); + NNACL_CHECK_ZERO_RETURN_ERR(type_size); + + switch (type_size) { + case sizeof(int64_t): { + TrilByte8(src_data, dst_data, tril->k_, height, width, mul); + break; + } + case sizeof(int32_t): { + TrilByte4(src_data, dst_data, tril->k_, height, width, mul); + break; + } + case sizeof(int16_t): { + TrilByte2(src_data, dst_data, tril->k_, height, width, mul); + break; + } + case sizeof(int8_t): { + TrilByte1(src_data, dst_data, tril->k_, height, width, mul); + break; + } + default: + return NNACL_UNSUPPORTED_DATA_TYPE; + } + return NNACL_OK; +} + +KernelBase *CreateTril(OpParameter *param, int data_type) { + TrilStruct *tril = (TrilStruct *)malloc(sizeof(TrilStruct)); + NNACL_CHECK_NULL_RETURN_NULL(tril); + tril->base_.Release = DefaultRelease; + tril->base_.Prepare = DefaultPrepare1In1Out; + tril->base_.Resize = DefaultResize; + tril->base_.Compute = TrilCompute; + return (KernelBase *)tril; +} + +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeDouble, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeFloat, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeFloat64, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeFloat32, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeFloat16, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeInt, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeInt64, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeInt32, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeInt16, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeInt8, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeUInt64, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeUInt32, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeUInt16, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeUInt8, CreateTril) +REG_KERNEL_CREATOR(PrimType_Tril, kNumberTypeBool, CreateTril) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/tril.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/tril.h new file mode 100644 index 0000000000000000000000000000000000000000..57e36a650067665e09dc5320bdc344baa55d873e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/tril.h @@ -0,0 +1,32 @@ + +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_TRIL_H_ +#define NNACL_KERNEL_TRIL_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct TrilStruct { + KernelBase base_; + int64_t k_; +} TrilStruct; + +KernelBase *CreateTril(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_TRIL_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/triu.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/triu.c new file mode 100644 index 0000000000000000000000000000000000000000..c88a13f24be253fec1e3715add243c3b2d02ff0c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/triu.c @@ -0,0 +1,89 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/triu.h" +#include "nnacl/kernel/default_kernel_base.h" +#include "nnacl/common_func.h" +#include "nnacl/fp32/triu_tril_fp32.h" + +int TriuCompute(KernelBase *self) { + TriuStruct *triu = (TriuStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(triu); + + void *src_data = self->in_[FIRST_INPUT]->data_; + void *dst_data = self->out_[OUTPUT_INDEX]->data_; + int type_size = DataTypeCSize(self->in_[FIRST_INPUT]->data_type_); + NNACL_CHECK_ZERO_RETURN_ERR(type_size); + + int ret = TriuTrilGetKValue(self, &triu->k_); + if (ret != NNACL_OK) { + return ret; + } + + int64_t mul, height, width; + ret = TriuTrilGetCalculateNum(self, &mul, &height, &width); + if (ret != NNACL_OK) { + return ret; + } + + switch (type_size) { + case sizeof(int64_t): { + TriuByte8(src_data, dst_data, triu->k_, height, width, mul); + break; + } + case sizeof(int32_t): { + TriuByte4(src_data, dst_data, triu->k_, height, width, mul); + break; + } + case sizeof(int16_t): { + TriuByte2(src_data, dst_data, triu->k_, height, width, mul); + break; + } + case sizeof(int8_t): { + TriuByte1(src_data, dst_data, triu->k_, height, width, mul); + break; + } + default: + return NNACL_UNSUPPORTED_DATA_TYPE; + } + return NNACL_OK; +} + +KernelBase *CreateTriu(OpParameter *param, int data_type) { + TriuStruct *triu = (TriuStruct *)malloc(sizeof(TriuStruct)); + NNACL_CHECK_NULL_RETURN_NULL(triu); + triu->base_.Release = DefaultRelease; + triu->base_.Prepare = DefaultPrepare1In1Out; + triu->base_.Resize = DefaultResize; + triu->base_.Compute = TriuCompute; + return (KernelBase *)triu; +} + +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeDouble, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeFloat, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeFloat64, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeFloat32, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeFloat16, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeInt, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeInt64, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeInt32, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeInt16, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeInt8, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeUInt64, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeUInt32, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeUInt16, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeUInt8, CreateTriu) +REG_KERNEL_CREATOR(PrimType_Triu, kNumberTypeBool, CreateTriu) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/triu.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/triu.h new file mode 100644 index 0000000000000000000000000000000000000000..9dc13b90d842d5749fe2b99b97a7b25c197740f0 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/triu.h @@ -0,0 +1,32 @@ + +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_TRIU_H_ +#define NNACL_KERNEL_TRIU_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct TriuStruct { + KernelBase base_; + int64_t k_; +} TriuStruct; + +KernelBase *CreateTriu(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_TRIU_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/unique.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/unique.c new file mode 100644 index 0000000000000000000000000000000000000000..20431111f915d69f9a2a78192db05fe6fa1e9955 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/unique.c @@ -0,0 +1,66 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/unique.h" +#include "nnacl/kernel/default_kernel_base.h" +#include "nnacl/fp32/unique_fp32.h" +#include "nnacl/tensor_c_utils.h" +#ifdef ENABLE_FP16 +#include "nnacl/fp16/unique_fp16.h" +#endif + +int UniqueCompute(KernelBase *self) { + TensorC *input = self->in_[FIRST_INPUT]; + NNACL_CHECK_NULL_RETURN_ERR(input); + TensorC *output0 = self->out_[Index0]; + NNACL_CHECK_NULL_RETURN_ERR(output0); + TensorC *output1 = self->out_[Index1]; + NNACL_CHECK_NULL_RETURN_ERR(output1); + + int num = NNACLGetElementNum(input); + int output0_len = 0; + +#ifdef ENABLE_FP16 + if (input->data_type_ == kNumberTypeFloat16) { + UniqueFp16((float16_t *)input->data_, num, (float16_t *)output0->data_, &output0_len, (int *)output1->data_); + } +#endif + if (input->data_type_ == kNumberTypeInt32) { + UniqueInt((int *)input->data_, num, (int *)output0->data_, &output0_len, (int *)output1->data_); + } + if (input->data_type_ == kNumberTypeFloat32) { + Unique((float *)input->data_, num, (float *)output0->data_, &output0_len, (int *)output1->data_); + } + + output0->shape_changed_ = (output0->shape_[output0->shape_size_ - 1] != output0_len); + output0->shape_[output0->shape_size_ - 1] = output0_len; + return NNACL_OK; +} + +KernelBase *CreateUnique(OpParameter *param, int data_type) { + UniqueStruct *unique = (UniqueStruct *)malloc(sizeof(UniqueStruct)); + NNACL_CHECK_NULL_RETURN_NULL(unique); + unique->data_type_ = data_type; + unique->base_.Release = DefaultRelease; + unique->base_.Prepare = DefaultPrepare1In2Out; + unique->base_.Resize = DefaultResize; + unique->base_.Compute = UniqueCompute; + return (KernelBase *)unique; +} + +REG_KERNEL_CREATOR(PrimType_Unique, kNumberTypeInt32, CreateUnique) +REG_KERNEL_CREATOR(PrimType_Unique, kNumberTypeFloat32, CreateUnique) +REG_KERNEL_CREATOR(PrimType_Unique, kNumberTypeFloat16, CreateUnique) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/unique.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/unique.h new file mode 100644 index 0000000000000000000000000000000000000000..a0c01ded63b5535e12608db9296f76931f704bfe --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/unique.h @@ -0,0 +1,32 @@ + +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_UNIQUE_H_ +#define NNACL_KERNEL_UNIQUE_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct UniqueStruct { + KernelBase base_; + int data_type_; +} UniqueStruct; + +KernelBase *CreateUnique(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_UNIQUE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/where.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/where.c new file mode 100644 index 0000000000000000000000000000000000000000..442f35f6c8e698d8fb1daf56adc6aad87470b9ec --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/where.c @@ -0,0 +1,298 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/where.h" +#include "nnacl/kernel/default_kernel_base.h" +#include "nnacl/common_func.h" +#include "nnacl/tensor_c_utils.h" +#include "nnacl/fp32/where_fp32.h" +#ifdef ENABLE_FP16 +#include "nnacl/fp16/where_fp16.h" +#endif +#include "nnacl/base/broadcast_to.h" + +int WhereExcuteFp16(WhereStruct *where, int task_id) { +#ifdef ENABLE_FP16 + WhereWithTripleInputsFp16((float16_t *)where->x_, (float16_t *)where->y_, (float16_t *)where->output_, &where->args_, + task_id, where->base_.thread_nr_); +#endif + return NNACL_OK; +} + +int WhereExcute(WhereStruct *where, int task_id) { + WhereWithTripleInputs((float *)where->x_, (float *)where->y_, (float *)where->output_, &where->args_, task_id, + where->base_.thread_nr_); + return NNACL_OK; +} + +int WhereRun(void *cdata, int task_id, float l, float r) { + WhereStruct *where = (WhereStruct *)cdata; + NNACL_CHECK_NULL_RETURN_ERR(where); + + NNACL_CHECK_NULL_RETURN_ERR(where->x_); + NNACL_CHECK_NULL_RETURN_ERR(where->y_); + NNACL_CHECK_NULL_RETURN_ERR(where->output_); + NNACL_CHECK_NULL_RETURN_ERR(where->args_.condition_); + + if (where->data_type_ == kNumberTypeFloat16) { + return WhereExcuteFp16(where, task_id); + } + return WhereExcute(where, task_id); +} + +int WhereRunWithSingleInput(WhereStruct *where) { + TensorC *input = where->base_.in_[FIRST_INPUT]; + int32_t *int32_condition = NULL; + float *fp32_condition = NULL; + bool *bool_condition = NULL; + switch (where->data_type_) { + case kNumberTypeInt32: + int32_condition = (int32_t *)input->data_; + NNACL_CHECK_NULL_RETURN_ERR(int32_condition); + break; + case kNumberTypeFloat32: + fp32_condition = (float *)input->data_; + NNACL_CHECK_NULL_RETURN_ERR(fp32_condition); + break; + case kNumberTypeBool: + bool_condition = (bool *)input->data_; + NNACL_CHECK_NULL_RETURN_ERR(bool_condition); + break; + default: + return NNACL_WHERE_CONDITION_DATA_TYPE_ERROR; + } + WhereArgs *where_args = &where->args_; + where_args->condition_num_ = NNACLGetElementNum(input); + where_args->rank_ = input->shape_size_; + int strides[MAX_SHAPE_SIZE]; + ComputeStrides(input->shape_, strides, where_args->rank_); + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(where_args->condition_num_, where_args->rank_, NNACL_ERR); + int data_num_int = where_args->condition_num_ * where_args->rank_; + NNACL_CHECK_TRUE_RET(data_num_int >= 0, NNACL_ERR); + size_t result_size = (size_t)data_num_int * sizeof(int32_t); + int32_t *result = where->base_.env_->Alloc(where->base_.env_->allocator_, result_size); + NNACL_MALLOC_CHECK_NULL_RETURN_ERR(result); + + int result_index = 0; + int true_num = 0; + for (int index = 0; index < where_args->condition_num_; index++) { + bool condition = false; + switch (where->data_type_) { + case kNumberTypeInt32: + condition = (bool)int32_condition[index]; + break; + case kNumberTypeFloat32: + condition = (bool)fp32_condition[index]; + break; + case kNumberTypeBool: + condition = (bool)bool_condition[index]; + break; + default: + return NNACL_WHERE_CONDITION_DATA_TYPE_ERROR; + } + if (condition) { + true_num++; + int dim = index; + for (int j = 0; j < where_args->rank_; j++) { + NNACL_CHECK_ZERO_RETURN_ERR(strides[j]); + result[result_index++] = dim / strides[j]; + dim %= strides[j]; + } + } + } + + TensorC *output = where->base_.out_[OUTPUT_INDEX]; + if (output->data_ != NULL) { + /* the data should be nullptr */ + where->base_.env_->Free(where->base_.env_->allocator_, output->data_); + } + int output_shape[] = {true_num, where_args->rank_}; + output->shape_changed_ = ShapeEqual(output->shape_, output->shape_size_, output_shape, Num2); + output->shape_size_ = Num2; + memcpy(output->shape_, output_shape, Num2 * sizeof(int)); + + if (true_num > 0) { + output->data_ = result; + } + return NNACL_OK; +} + +int WhereBroadCastForInput(WhereStruct *where, TensorC *condition, TensorC *x, TensorC *y, + void **condition_broadcast_buf, void **x_broadcast_buf, void **y_broadcast_buf, + TensorC *output) { + size_t broad_cast_buf_size = NNACLGetElementNum(output); + if (output->data_type_ == kNumberTypeFloat32) { + broad_cast_buf_size *= sizeof(float); + } else { + return NNACL_WHERE_BROAD_CAST_FAILED; + } + BroadcastShapeInfo condition_info; + condition_info.input_shape_size_ = condition->shape_size_; + condition_info.output_shape_size_ = output->shape_size_; + memcpy(condition_info.input_shape_, condition->shape_, condition->shape_size_ * sizeof(int)); + memcpy(condition_info.output_shape_, output->shape_, output->shape_size_ * sizeof(int)); + + BroadcastShapeInfo x_info; + x_info.input_shape_size_ = x->shape_size_; + x_info.output_shape_size_ = output->shape_size_; + memcpy(x_info.input_shape_, x->shape_, x->shape_size_ * sizeof(int)); + memcpy(x_info.output_shape_, output->shape_, output->shape_size_ * sizeof(int)); + + BroadcastShapeInfo y_info; + y_info.input_shape_size_ = y->shape_size_; + y_info.output_shape_size_ = output->shape_size_; + memcpy(y_info.input_shape_, y->shape_, y->shape_size_ * sizeof(int)); + memcpy(y_info.output_shape_, output->shape_, output->shape_size_ * sizeof(int)); + + *condition_broadcast_buf = where->base_.env_->Alloc(where->base_.env_->allocator_, broad_cast_buf_size); + if (*condition_broadcast_buf == NULL) { + return NNACL_WHERE_BROAD_CAST_FAILED; + } + BroadcastToSize8(condition->data_, &condition_info, *condition_broadcast_buf); + + *x_broadcast_buf = where->base_.env_->Alloc(where->base_.env_->allocator_, broad_cast_buf_size); + if (*x_broadcast_buf == NULL) { + where->base_.env_->Free(where->base_.env_->allocator_, *condition_broadcast_buf); + return NNACL_WHERE_BROAD_CAST_FAILED; + } + BroadcastToSize32(x->data_, &x_info, *x_broadcast_buf); + + *y_broadcast_buf = where->base_.env_->Alloc(where->base_.env_->allocator_, broad_cast_buf_size); + if (*y_broadcast_buf == NULL) { + where->base_.env_->Free(where->base_.env_->allocator_, *condition_broadcast_buf); + where->base_.env_->Free(where->base_.env_->allocator_, *x_broadcast_buf); + return NNACL_WHERE_BROAD_CAST_FAILED; + } + BroadcastToSize32(y->data_, &y_info, *y_broadcast_buf); + return NNACL_OK; +} + +int WhereRunWithTripleInputs(WhereStruct *where) { + TensorC *condition = where->base_.in_[Index0]; + NNACL_CHECK_NULL_RETURN_ERR(condition); + TensorC *x = where->base_.in_[Index1]; + NNACL_CHECK_NULL_RETURN_ERR(x); + TensorC *y = where->base_.in_[Index2]; + NNACL_CHECK_NULL_RETURN_ERR(y); + TensorC *output = where->base_.out_[Index0]; + NNACL_CHECK_NULL_RETURN_ERR(output); + + int condition_nums = NNACLGetElementNum(condition); + int x_num = NNACLGetElementNum(x); + int y_num = NNACLGetElementNum(y); + int out_num = NNACLGetElementNum(output); + int num_max = condition_nums > x_num ? condition_nums : (x_num > y_num ? x_num : y_num); + + where->x_ = x->data_; + where->y_ = y->data_; + where->output_ = output->data_; + + WhereArgs *args = &where->args_; + args->condition_ = (bool *)condition->data_; + args->condition_num_ = condition_nums; + args->x_num_ = x_num; + args->y_num_ = y_num; + args->max_num_ = num_max; + + void *condition_broadcast_buf = NULL; + void *x_broadcast_buf = NULL; + void *y_broadcast_buf = NULL; + + if (out_num < num_max) { + return NNACL_WHERE_INVALID_OUT_NUM; + } + if (((condition_nums != 1) && (condition_nums != num_max)) || ((x_num != 1) && (x_num != num_max)) || + ((y_num != 1) && (y_num != num_max))) { + if (condition_nums != NNACLGetElementNum(y) && condition->shape_size_ != y->shape_size_) { + int ret = WhereBroadCastForInput(where, condition, x, y, &condition_broadcast_buf, &x_broadcast_buf, + &y_broadcast_buf, output); + if (ret != NNACL_OK) { + return NNACL_WHERE_BROAD_CAST_FAILED; + } + int max_num = NNACLGetElementNum(output); + args->condition_ = (bool *)condition_broadcast_buf; + where->x_ = x_broadcast_buf; + where->y_ = y_broadcast_buf; + where->output_ = output->data_; + args->condition_num_ = max_num; + args->x_num_ = max_num; + args->y_num_ = max_num; + args->max_num_ = max_num; + } else { + /* The length of three inputs are not equal to 1 or length of output, which is unacceptable */ + return NNACL_WHERE_CONDITION_NUM_INVALID; + } + } + if (num_max <= 0) { + /* Error, inputs' length are zero */ + return NNACL_WHERE_NUM_MAX_INVALID; + } + int ret = + where->base_.env_->ParallelLaunch(where->base_.env_->thread_pool_, WhereRun, where, where->base_.thread_nr_); + if (condition_broadcast_buf != NULL) { + where->base_.env_->Free(where->base_.env_->allocator_, condition_broadcast_buf); + condition_broadcast_buf = NULL; + } + if (x_broadcast_buf != NULL) { + where->base_.env_->Free(where->base_.env_->allocator_, x_broadcast_buf); + x_broadcast_buf = NULL; + } + if (y_broadcast_buf != NULL) { + where->base_.env_->Free(where->base_.env_->allocator_, y_broadcast_buf); + y_broadcast_buf = NULL; + } + return ret; +} + +int WhereCompute(KernelBase *self) { + WhereStruct *where = (WhereStruct *)self; + NNACL_CHECK_NULL_RETURN_ERR(where); + + int ret = NNACL_ERR; + if (self->in_size_ == Num1) { + ret = WhereRunWithSingleInput(where); + } else if (self->in_size_ == Num3) { + ret = WhereRunWithTripleInputs(where); + } else { + ret = NNACL_WHERE_INPUT_NUM_INVALID; + } + return ret; +} + +int WherePrepare(KernelBase *self) { + NNACL_CHECK_TRUE_RET(self->in_size_ == Num1 || self->in_size_ == Num3, NNACL_WHERE_INPUT_NUM_INVALID); + NNACL_CHECK_TRUE_RET(self->out_size_ == Num1, NNACL_OUTPUT_TENSOR_ERROR); + NNACL_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]); + NNACL_CHECK_NULL_RETURN_ERR(self->out_[OUTPUT_INDEX]); + return NNACL_OK; +} + +KernelBase *CreateWhere(OpParameter *param, int data_type) { + WhereStruct *where = (WhereStruct *)malloc(sizeof(WhereStruct)); + NNACL_CHECK_NULL_RETURN_NULL(where); + memset(where, 0, sizeof(WhereStruct)); + where->data_type_ = data_type; + where->base_.Prepare = WherePrepare; + where->base_.Compute = WhereCompute; + where->base_.Resize = DefaultResize; + where->base_.Release = DefaultRelease; + return (KernelBase *)where; +} + +REG_KERNEL_CREATOR(PrimType_Where, kNumberTypeBool, CreateWhere) +REG_KERNEL_CREATOR(PrimType_Where, kNumberTypeInt32, CreateWhere) +REG_KERNEL_CREATOR(PrimType_Where, kNumberTypeFloat16, CreateWhere) +REG_KERNEL_CREATOR(PrimType_Where, kNumberTypeFloat32, CreateWhere) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/where.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/where.h new file mode 100644 index 0000000000000000000000000000000000000000..381fed5214b62b0d9328840ec6ba62a348e4df6e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/where.h @@ -0,0 +1,44 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_WHERE_H_ +#define NNACL_KERNEL_WHERE_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct WhereArgs { + int condition_num_; + int x_num_; + int y_num_; + int max_num_; + int rank_; + bool *condition_; +} WhereArgs; + +typedef struct WhereStruct { + KernelBase base_; + WhereArgs args_; + int data_type_; + void *x_; + void *y_; + void *output_; +} WhereStruct; + +KernelBase *CreateWhere(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_WHERE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/zeros_like.c b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/zeros_like.c new file mode 100644 index 0000000000000000000000000000000000000000..288475153d833eff4ae29086cf4c8a548e9a8223 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/zeros_like.c @@ -0,0 +1,43 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/kernel/zeros_like.h" +#include "nnacl/kernel/default_kernel_base.h" +#include "nnacl/common_func.h" +#include "nnacl/tensor_c_utils.h" + +int ZerosLikeCompute(KernelBase *self) { + NNACL_CHECK_NULL_RETURN_ERR(self); + TensorC *output = self->out_[OUTPUT_INDEX]; + NNACL_CHECK_NULL_RETURN_ERR(output); + NNACL_CHECK_NULL_RETURN_ERR(output->data_); + (void)memset(output->data_, 0, NNACLGetSize(output)); + return NNACL_OK; +} + +KernelBase *CreateZerosLike(OpParameter *param, int data_type) { + ZerosLikeStruct *zeros_like = (ZerosLikeStruct *)malloc(sizeof(ZerosLikeStruct)); + NNACL_CHECK_NULL_RETURN_NULL(zeros_like); + zeros_like->base_.Release = DefaultRelease; + zeros_like->base_.Prepare = DefaultPrepare1In1Out; + zeros_like->base_.Resize = DefaultResize; + zeros_like->base_.Compute = ZerosLikeCompute; + return (KernelBase *)zeros_like; +} + +REG_KERNEL_CREATOR(PrimType_ZerosLike, kNumberTypeInt32, CreateZerosLike) +REG_KERNEL_CREATOR(PrimType_ZerosLike, kNumberTypeFloat32, CreateZerosLike) +REG_KERNEL_CREATOR(PrimType_ZerosLike, kNumberTypeFloat16, CreateZerosLike) diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/kernel/zeros_like.h b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/zeros_like.h new file mode 100644 index 0000000000000000000000000000000000000000..36ece1ec59336774e7074d1b07d60f66ba2a5110 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/kernel/zeros_like.h @@ -0,0 +1,31 @@ + +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_KERNEL_ZEROS_LIKE_H_ +#define NNACL_KERNEL_ZEROS_LIKE_H_ + +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" +#include "nnacl/kernel.h" + +typedef struct ZerosLikeStruct { + KernelBase base_; +} ZerosLikeStruct; + +KernelBase *CreateZerosLike(OpParameter *param, int data_type); + +#endif // NNACL_KERNEL_ZEROS_LIKE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/l2_norm_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/l2_norm_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..f5f1f39cc9cc786791c087cbd73911d5c207e1d3 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/l2_norm_parameter.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_L2NORM_PARAMETER_H_ +#define NNACL_L2NORM_PARAMETER_H_ + +#include "nnacl/op_base.h" +#include "nnacl/int8/quantize.h" + +typedef struct L2NormParameter { + // Primitive parameter + OpParameter op_parameter_; + float epsilon_; + int axis_[MAX_SHAPE_SIZE]; + // shape correlative + size_t axis_num_; + int data_num_; + int *shape_; + size_t shape_num_; + // other parameter + ActType act_type_; +} L2NormParameter; + +typedef struct { + QuantArg in_; + QuantArg out_; +} L2NormQuantArg; + +#endif // NNACL_L2NORM_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/layer_norm_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/layer_norm_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..da8cc09715bfb500bd2a67362cbb0082f03f919f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/layer_norm_parameter.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_LAYER_NORM_PARAMETER_H_ +#define NNACL_LAYER_NORM_PARAMETER_H_ + +#include "nnacl/op_base.h" +#include "nnacl/int8/quantize.h" + +typedef struct LayerNormParameter { + OpParameter op_parameter_; + float epsilon_; + bool elementwise_affine_; + int begin_norm_axis_; + int begin_params_axis_; +} LayerNormParameter; + +typedef struct LayerNormQuantArg { + int32_t in_zp_; + int32_t out_zp_; + double in_scale_; + double out_scale_; +} LayerNormQuantArg; + +#endif // NNACL_LAYER_NORM_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/local_response_norm_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/local_response_norm_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..b39bace17f05bc5cf4f7a0c12fa51f8253cbc051 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/local_response_norm_parameter.h @@ -0,0 +1,31 @@ + +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_LRM_PARAMETER_H_ +#define NNACL_LRM_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct LocalResponseNormParameter { + OpParameter op_parameter_; + int depth_radius_; + float bias_; + float alpha_; + float beta_; +} LocalResponseNormParameter; + +#endif // NNACL_LRM_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/lsh_projection_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/lsh_projection_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..a85b45234b9b708cbbc5580b4c381e5ef7947ac7 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/lsh_projection_parameter.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_LSH_PROJECTION_PARAMETER_H_ +#define NNACL_LSH_PROJECTION_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct LshProjectionParameter { + // Primitive parameter + OpParameter op_parameter_; + // shape correlative + int hash_shape_[2]; + // other parameter + int lsh_type_; + int feature_num_; + char **hash_buffs_; + size_t hash_buff_size_; + int64_t thread_stride_; +} LshProjectionParameter; + +#endif // NNACL_LSH_PROJECTION_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/lstm_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/lstm_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..5baf10fa509c78bab8480d434c766fd84217875c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/lstm_parameter.h @@ -0,0 +1,44 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_LSTM_PARAMETER_H_ +#define NNACL_LSTM_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct LstmParameter { + // Primitive parameter + OpParameter op_parameter_; + // shape correlative + int input_size_; + int hidden_size_; + int project_size_; + int output_size_; + int seq_len_; + int batch_; + // other parameter + int output_step_; + bool bidirectional_; + float zoneout_cell_; + float zoneout_hidden_; + int input_row_align_; + int input_col_align_; + int state_row_align_; + int state_col_align_; + int proj_col_align_; + bool has_bias_; +} LstmParameter; + +#endif // NNACL_LSTM_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/matmul_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/matmul_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..3d66e0a8ceed0ded81b0a6e2f01fd6bdcf80d900 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/matmul_parameter.h @@ -0,0 +1,96 @@ +/** + * Copyright 2020-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_MATMUL_H_ +#define NNACL_MATMUL_H_ + +#include "nnacl/op_base.h" + +typedef void (*MATMUL_OPT_R4_FUNC)(const int8_t *a, const int8_t *b, int32_t *dst, int row_4, int col_4, int deep_16, + const int32_t *input_sum, const int32_t *bias); + +typedef void (*MATMUL_OPT_R_FUNC)(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, + size_t stride, const int32_t *input_sum, const int32_t *bias, + const int32_t *left_shift, const int32_t *right_shift, const int32_t *multiplier, + int32_t output_zp, int32_t mini, int32_t maxi, size_t per_channel); + +typedef void (*MATMUL_OPT_DP_FUNC)(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, + size_t stride, const int32_t *input_sum, const int32_t *bias, + const int32_t *left_shift, const int32_t *right_shift, const int32_t *multiplier, + int32_t output_zp, int32_t mini, int32_t maxi, size_t per_channel, + const int32_t *filter_zp); + +typedef enum OutType { OutType_C8 = 0, OutType_Nhwc = 1, OutType_TileC8 = 2, OutType_NC4HW4 = 3 } OutType; + +typedef enum MatmulType { + // reserve 0 for base op + kNotImplemented = 0, + kMatmulInt8Cpu, + kMatmulDynamicInt8Cpu, + kMatmulDynamicSdotInt8Cpu, + kMatmulFp32BaseCpu, + kMatmulFp32Arm64Cpu, +} MatmulType; + +typedef struct MatMulParameter { + // Primitive parameter + OpParameter op_parameter_; + bool has_bias_; + bool use_axis_; + bool a_transpose_; /* false : row-major */ + bool b_transpose_; /* true : col-major */ + ActType act_type_; + + // other parameter + int row_; + int col_; + int row_4_; + int row_16_; + int row_align_; + int col_8_; + int col_align_; + int deep_; + int deep_4_; + int deep_16_; + int deep_align_; + int batch; + bool a_const_; + bool b_const_; + int axis_; + MatmulType matmul_type_; +} MatMulParameter; + +typedef struct MatmulQuantParameter { + QuantArg input_; + QuantArg weight_; + QuantArg output_; + int32_t out_act_min_; + int32_t out_act_max_; + float *filter_scale_; + int32_t *filter_zp_; + int32_t *left_shift_; + int32_t *right_shift_; + int32_t *quant_multiplier_; +} MatmulQuantParameter; + +typedef struct MatmulDynamicQuantParameter { + float *input_scale_; + int32_t *input_zp_; + float *filter_scale_; + int32_t *filter_zp_; +} MatmulDynamicQuantParameter; + +#endif // NNACL_MATMUL_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/mul_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/mul_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..ef391eefa5cb20dbb3df412dd2a85a92005ac786 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/mul_parameter.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_MUL_PARAMETER_H_ +#define NNACL_MUL_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct MulQuantArg { + QuantArg in_quant_args_[2]; + QuantArg out_quant_arg_; + int output_multiplier_; + int output_activation_min_; + int output_activation_max_; + int shift_left_; + int shift_right_; +} MulQuantArg; + +#endif // NNACL_MUL_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/nllloss_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/nllloss_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..c11f5bce2f82126ef9e88af3fe95c1853f5f1121 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/nllloss_parameter.h @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_NLLLOSS_PARAMETER_H_ +#define NNACL_NLLLOSS_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct NLLLossParameter { + OpParameter op_parameter_; + ReductionType reduction_type_; +} NLLLossParameter; + +#endif // NNACL_NLLLOSS_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/nnacl_common.c b/mindspore-lite/ops/kernel/cpu/nnacl/nnacl_common.c new file mode 100644 index 0000000000000000000000000000000000000000..af65fda137f79bc1a7fa0292086580addc75ec9c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/nnacl_common.c @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/nnacl_common.h" + +typedef union float32_bits { + unsigned int u; + float f; +} float32_bits; + +float ShortToFloat32(uint16_t src_value) { + const float32_bits magic = {113 << 23}; + const unsigned int shifted_exp = 0x7c00 << 13; + float32_bits o; + + o.u = (src_value & 0x7fff) << 13; + unsigned int exp = shifted_exp & o.u; + o.u += (127 - 15) << 23; + + if (exp == shifted_exp) { + o.u += (128 - 16) << 23; + } else if (exp == 0) { + o.u += 1 << 23; + o.f -= magic.f; + } + + o.u |= (src_value & 0x8000) << 16; + return o.f; +} + +uint16_t Float32ToShort(float src_value) { + float32_bits src_value_bits; + src_value_bits.f = src_value; + uint16_t res = 0; + // mantissa + res += (src_value_bits.u >> 13); + // exponent + res += (src_value_bits.u >> 13) & 0x3fc00; + res -= (127 - 15) << 13; + + // sign + res |= (src_value_bits.u & 0x80000000) >> 16; + return res; +} diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/nnacl_common.h b/mindspore-lite/ops/kernel/cpu/nnacl/nnacl_common.h new file mode 100644 index 0000000000000000000000000000000000000000..44617603f51f3815b6dccf932324e13f8ae7aab8 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/nnacl_common.h @@ -0,0 +1,109 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_NNACL_COMMON_H_ +#define NNACL_NNACL_COMMON_H_ + +#include +#include "nnacl/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +static inline size_t DataTypeCSize(TypeIdC type) { + switch (type) { + case kNumberTypeFloat64: + return sizeof(double); + case kNumberTypeFloat: + case kNumberTypeFloat32: + return sizeof(float); + case kNumberTypeInt8: + return sizeof(int8_t); + case kNumberTypeUInt8: + return sizeof(uint8_t); + case kNumberTypeFloat16: + case kNumberTypeInt16: + return sizeof(int16_t); + case kNumberTypeInt32: + return sizeof(int32_t); + case kNumberTypeInt64: + return sizeof(int64_t); + case kNumberTypeUInt16: + return sizeof(uint16_t); + case kNumberTypeUInt32: + return sizeof(uint32_t); + case kNumberTypeUInt64: + return sizeof(uint64_t); + case kNumberTypeComplex64: + return sizeof(float) + sizeof(float); + case kNumberTypeComplex128: + return sizeof(double) + sizeof(double); + case kNumberTypeBool: + return sizeof(bool); + case kObjectTypeString: + return sizeof(char); + case kObjectTypeTensorType: + return 0; + case kMetaTypeTypeType: + return sizeof(int); + default: + return 0; + } +} + +static inline void ComputeStrides(const int *shape, int *strides, const int ndim) { + int stride = 1; + for (int i = ndim - 1; i >= 0; i--) { + strides[i] = stride; + stride *= shape[i]; + } +} + +static inline void ComputeAxisDims(const int *shape, int shape_size, int axis, int *out_count, int *axis_count, + int *in_count) { + *out_count = 1; + *in_count = 1; + for (int i = 0; i < shape_size; i++) { + if (i < axis) { + *out_count = (*out_count) * shape[i]; + } + if (i == axis) { + *axis_count = shape[axis]; + } + if (i > axis) { + *in_count = (*in_count) * shape[i]; + } + } +} + +static const unsigned int FP32_BIT_SIZE = 32; +static const unsigned int FP32_EXPONENT_BIAS = 127; +static const unsigned int FP32_SIGNIFICAND = 23; +static const unsigned int FP32_EXPONENT_MAX = 255; +static const unsigned int FP16_BIT_SIZE = 16; +static const unsigned int FP16_EXPONENT_BIAS = 15; +static const unsigned int FP16_SIGNIFICAND = 10; +static const int FP16_EXPONENT_MAX = 30; +static const int FP16_EXPONENT_MIN = -10; +static const int FP16_SHIFT = 13; +float ShortToFloat32(uint16_t src_value); +uint16_t Float32ToShort(float src_value); + +#ifdef __cplusplus +} +#endif +#endif // NNACL_NNACL_COMMON_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/nnacl_utils.c b/mindspore-lite/ops/kernel/cpu/nnacl/nnacl_utils.c new file mode 100644 index 0000000000000000000000000000000000000000..f0f81b26ac5a78363c3bc073b31bb1963b1dd8b9 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/nnacl_utils.c @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/nnacl_utils.h" +#if defined(__ANDROID__) || defined(MS_COMPILE_OHOS) +#include +#endif + +#if defined(__ANDROID__) || defined(MS_COMPILE_OHOS) +uint32_t getHwCap(int hwcap_type) { + uint32_t ret = getauxval(hwcap_type); + return ret; +} +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/nnacl_utils.h b/mindspore-lite/ops/kernel/cpu/nnacl/nnacl_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..dfc198787c9c62476b87388bf68a70f382f0bd64 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/nnacl_utils.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_NNACL_UTILS_H_ +#define NNACL_NNACL_UTILS_H_ + +#include +#ifdef __cplusplus +extern "C" { +#endif + +#if defined(__arm__) || defined(__aarch64__) +uint32_t getHwCap(int hwcap_type); +#endif + +#ifdef DEBUG +#include +#define NNACL_ASSERT(f) assert(f) +#else +#define NNACL_ASSERT(f) ((void)0) +#endif + +#ifdef __cplusplus +} +#endif +#endif // NNACL_NNACL_UTILS_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/non_max_suppression_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/non_max_suppression_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..eb069530a1a074545c3321ee85c47065e63a28b7 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/non_max_suppression_parameter.h @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_NON_MAX_SUPPRESSION_PARAMETER_H_ +#define NNACL_NON_MAX_SUPPRESSION_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct NMSParameter { + // Primitive parameter + OpParameter op_parameter_; + int center_point_box_; +} NMSParameter; + +#endif // NNACL_NON_MAX_SUPPRESSION_PARAMETER_H_ diff --git a/mindspore-lite/examples/quick_start_ios/mindspore-lite/SceneDelegate.h b/mindspore-lite/ops/kernel/cpu/nnacl/one_hot_parameter.h similarity index 72% rename from mindspore-lite/examples/quick_start_ios/mindspore-lite/SceneDelegate.h rename to mindspore-lite/ops/kernel/cpu/nnacl/one_hot_parameter.h index 23411a6705fcab6b8fc5f00ba9bb6d4946d964b3..94c359c35b6059521a928cab36085faf5d402886 100644 --- a/mindspore-lite/examples/quick_start_ios/mindspore-lite/SceneDelegate.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/one_hot_parameter.h @@ -14,11 +14,13 @@ * limitations under the License. */ -#import +#ifndef NNACL_ONE_HOT_PARAMETER_H_ +#define NNACL_ONE_HOT_PARAMETER_H_ +#include "nnacl/op_base.h" -@interface SceneDelegate : UIResponder - -@property(strong, nonatomic)UIWindow * window; - -@end +typedef struct OneHotParameter { + OpParameter op_parameter_; + int axis_; +} OneHotParameter; +#endif // NNACL_ONE_HOT_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/op_base.h b/mindspore-lite/ops/kernel/cpu/nnacl/op_base.h new file mode 100644 index 0000000000000000000000000000000000000000..e1a9c40c3f23ff26242c85c1ad3f9317080d72f4 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/op_base.h @@ -0,0 +1,802 @@ +/** + * Copyright 2020-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_OP_BASE_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_OP_BASE_H_ + +#include +#include +#include +#include +#include +#include +#ifdef ENABLE_ARM +#include +#endif + +#define C0NUM 0 +#define C1NUM 1 +#define C2NUM 2 +#define C3NUM 3 +#define C4NUM 4 +#define C5NUM 5 +#define C6NUM 6 +#define C7NUM 7 +#define C8NUM 8 +#define C9NUM 9 +#define C10NUM 10 +#define C11NUM 11 +#define C12NUM 12 +#define C13NUM 13 +#define C14NUM 14 +#define C15NUM 15 +#define C16NUM 16 +#define C17NUM 17 +#define C18NUM 18 +#define C19NUM 19 +#define C20NUM 20 +#define C21NUM 21 +#define C22NUM 22 +#define C23NUM 23 +#define C24NUM 24 +#define C28NUM 28 +#define C32NUM 32 +#define C36NUM 36 +#define C40NUM 40 +#define C44NUM 44 +#define C48NUM 48 +#define C56NUM 56 +#define C64NUM 64 +#define C128NUM 128 +#define C150NUM 150 +#define C256NUM 256 +#define C512NUM 512 +#define C1500NUM 1500 +#define TILE_NUM 8 +#define MAX_SPLIT_NUM 2048 + +#define FP16_DATA_TYPE_LEN 2 + +#ifndef MS_UNLIKELY +#ifdef _MSC_VER +#define MS_UNLIKELY(x) (x) +#else +#define MS_UNLIKELY(x) __builtin_expect(!!(x), 0) +#endif +#endif + +#ifndef MS_LIKELY +#ifdef _MSC_VER +#define MS_LIKELY(x) (x) +#else +#define MS_LIKELY(x) __builtin_expect(!!(x), 1) +#endif +#endif + +#define NNACL_MIN(x, y) ((x) < (y) ? (x) : (y)) +#define NNACL_MAX(x, y) ((x) > (y) ? (x) : (y)) + +#define MSMIN(x, y) ((x) < (y) ? (x) : (y)) +#define MSMAX(x, y) ((x) > (y) ? (x) : (y)) +#define MSCEIL(x) (int)((x) + (((x) - (int)(x)) > 0 ? 1 : 0)) + +#define UP_DIV(x, y) (((x) + (y) - (1)) / (y)) +#define UP_ROUND(x, y) (((x) + (y) - (1)) / (y) * (y)) +#define DOWN_DIV(x, y) ((x) / (y)) +#define DOWN_ROUND(x, y) ((x) / (y) * (y)) + +#define MSVALID(left, x, right) (MSMIN((MSMAX(left, x)), right)) +#define SIZE_MUL_OVERFLOW(x, y) (((x) == 0) ? false : (SIZE_MAX / (x)) < (y)) +#define INT_MUL_OVERFLOW(x, y) \ + (((x) == 0) ? false \ + : ((x) > 0 ? (((y) >= 0) ? (INT_MAX / (x)) < (y) : (INT_MAX / (x)) < (-1 * (y))) \ + : (((y) >= 0) ? (INT_MAX / (x)) > (-1 * (y)) : (INT_MAX / (x)) > (y)))) + +#define INT_MUL_OVERFLOW_THRESHOLD(x, y, threshold) \ + (((x) == 0) ? false \ + : ((x) > 0 ? (((y) >= 0) ? ((threshold) / (x)) < (y) : ((threshold) / (x)) < (-1 * (y))) \ + : (((y) >= 0) ? ((threshold) / (x)) > (-1 * (y)) : ((threshold) / (x)) > (y)))) + +#define INT_ADD_OVERFLOW(x, y) (INT_MAX - (x)) < (y) + +#define INT_ADD_OVERFLOW_THRESHOLD(x, y, threshold) ((threshold) - (x)) < (y) + +#define MALLOC_MAX_SIZE (2000 * 1024 * 1024) + +#define COMM_SHAPE_SIZE 4 +#define MAX_SHAPE_SIZE 8 + +#define OUTPUT_INDEX 0 +#define FIRST_INPUT 0 +#define SECOND_INPUT 1 +#define THIRD_INPUT 2 +#define FOURTH_INPUT 3 +#define FIFTH_INPUT 4 +#define SIXTH_INPUT 5 +#define SEVENTH_INPUT 6 +#define EIGHTH_INPUT 7 +#define NINTH_INPUT 8 + +#define ONE_TENSOR 1 +#define TWO_TENSOR 2 +#define THREE_TENSOR 3 +#define FOUR_TENSOR 4 +#define FIVE_TENSOR 5 + +#define Index0 0 +#define Index1 1 +#define Index2 2 +#define Index3 3 +#define Index4 4 +#define Index5 5 +#define Index6 6 +#define Index7 7 +#define Index8 8 +#define Index9 9 + +#define Num0 0 +#define Num1 1 +#define Num2 2 +#define Num3 3 +#define Num4 4 +#define Num5 5 +#define Num6 6 +#define Num7 7 +#define Num8 8 +#define Num9 9 + +#define DIMENSION_0D 0 +#define DIMENSION_1D 1 +#define DIMENSION_2D 2 +#define DIMENSION_3D 3 +#define DIMENSION_4D 4 +#define DIMENSION_5D 5 +#define DIMENSION_6D 6 +#define DIMENSION_7D 7 +#define DIMENSION_8D 8 +#define DIMENSION_9D 9 +#define DIMENSION_10D 10 +#define DIMENSION_11D 11 +#define kInputIndex 0 +#define kWeightIndex 1 +#define kBiasIndex 2 +#define kOutputIndex 0 +#define kNHWC_N 0 +#define kNHWC_H 1 +#define kNHWC_W 2 +#define kNHWC_C 3 +#define kNCHW_N 0 +#define kNCHW_C 1 +#define kNCHW_H 2 +#define kNCHW_W 3 +#define kHWCN_C 2 +#define kHWNC_N 2 +#define kHWCN_N 3 +#define kNDHWC_N 0 +#define kNDHWC_D 1 +#define kNDHWC_H 2 +#define kNDHWC_W 3 +#define kNDHWC_C 4 +#define kInputSize1 2 +#define kInputSize2 3 +#define MAX_AXIS_SIZE 6 +#define MAX_LEN 256 +#define MAX_THREAD_NUM 64 +#define FLT16_MAX 65504 +#define kDefaulLiteMaxSpinCount 300000 +#define kDefaulLiteMinSpinCount 1 +#define kDefaulLiteIosSpinCount 1 +#define DEFAULT_GROUP_NAME_LEN 101 +#define kValueThreshold6 6 + +#define INVALID_SHAPE -1 + +#define CLARGSINDEX0 0 +#define CLARGSINDEX1 1 +#define CLARGSINDEX2 2 +#define CLARGSINDEX3 3 +#define CLARGSINDEX4 4 +#define CLARGSINDEX5 5 +#define CLARGSINDEX6 6 +#define CLARGSINDEX7 7 +#define CLARGSINDEX8 8 +#define CLARGSINDEX9 9 + +#define CLIDX_X 0 +#define CLIDX_Y 1 +#define CLIDX_Z 2 +#define CLIDX_W 3 + +#define RELU6_MIN_VAL 0 +#define RELU6_MAX_VAL 6 + +/* index for primitive_type & activation_type */ +#define TC_PTYPE(primitive_type) (primitive_type << 16) +#define TC_ATYPE(activation_type) (activation_type) +#define TC_TYPE(primitive_type, activation_type) (TC_PTYPE(primitive_type) + TC_ATYPE(activation_type)) + +#define NNACL_MALLOC_CHECK_NULL_RETURN_ERR(ptr) \ + do { \ + if ((ptr) == NULL) { \ + return NNACL_NULL_PTR; \ + } \ + } while (0) + +#define NNACL_MALLOC_CHECK_NULL_RETURN_NULL(ptr) \ + do { \ + if ((ptr) == NULL) { \ + return NULL; \ + } \ + } while (0) + +#if ENABLE_HIGH_PERFORMANCE +#define NNACL_CHECK_TRUE_RET(value, errcode) +#define NNACL_CHECK_TRUE_RET_VOID(value) +#define NNACL_CHECK_FALSE(value, errcode) +#define NNACL_CHECK_INT_MUL_NOT_OVERFLOW(value1, value2, errcode) +#define NNACL_CHECK_INT_ADD_NOT_OVERFLOW(value1, value2, errcode) + +#define NNACL_CHECK_ZERO_RETURN_ERR(val) +#define NNACL_CHECK_ZERO_RETURN(val) +#define NNACL_CHECK_NULL_RETURN_ERR(ptr) +#define NNACL_CHECK_NULL_RETURN_VOID(ptr) +#define NNACL_CHECK_NULL_RETURN_NULL(ptr) +#define NNACL_CHECK_MALLOC_SIZE(val) +#else +#define NNACL_CHECK_TRUE_RET(value, errcode) \ + do { \ + if (!(value)) { \ + return errcode; \ + } \ + } while (0) + +#define NNACL_CHECK_TRUE_RET_VOID(value) \ + do { \ + if (!(value)) { \ + return; \ + } \ + } while (0) + +// Check whether value is false, if not return 'errcode' +#define NNACL_CHECK_FALSE(value, errcode) \ + do { \ + if ((value)) { \ + return errcode; \ + } \ + } while (0) + +#define NNACL_CHECK_INT_MUL_NOT_OVERFLOW(value1, value2, errcode) \ + NNACL_CHECK_TRUE_RET(!(INT_MUL_OVERFLOW(value1, value2)), errcode) +#define NNACL_CHECK_INT_ADD_NOT_OVERFLOW(value1, value2, errcode) \ + NNACL_CHECK_TRUE_RET(!(INT_ADD_OVERFLOW(value1, value2)), errcode) +#define NNACL_CHECK_MALLOC_SIZE(malloc_size) \ + NNACL_CHECK_FALSE((malloc_size) > MALLOC_MAX_SIZE, NNACL_MALLOC_SIZE_INVALID) + +#define NNACL_CHECK_ZERO_RETURN_ERR(val) \ + do { \ + if ((val) == 0) { \ + return NNACL_ERR; \ + } \ + } while (0) + +#define NNACL_CHECK_ZERO_RETURN(val) \ + do { \ + if ((val) == 0) { \ + return; \ + } \ + } while (0) + +#define NNACL_CHECK_NULL_RETURN_ERR(ptr) \ + do { \ + if ((ptr) == NULL) { \ + return NNACL_NULL_PTR; \ + } \ + } while (0) + +#define NNACL_CHECK_NULL_RETURN_VOID(ptr) \ + do { \ + if ((ptr) == NULL) { \ + return; \ + } \ + } while (0) + +#define NNACL_CHECK_NULL_RETURN_NULL(ptr) \ + do { \ + if ((ptr) == NULL) { \ + return NULL; \ + } \ + } while (0) +#endif + +enum PrimType { + PrimType_NONE = 0, + PrimType_Abs = 1, + PrimType_Activation = 2, + PrimType_ActivationGrad = 3, + PrimType_Adam = 4, + PrimType_AddFusion = 5, + PrimType_AdderFusion = 6, + PrimType_AddGrad = 7, + PrimType_AddN = 8, + PrimType_All = 9, + PrimType_ApplyMomentum = 10, + PrimType_ArgMaxFusion = 11, + PrimType_ArgMinFusion = 12, + PrimType_Assert = 13, + PrimType_Assign = 14, + PrimType_AssignAdd = 15, + PrimType_AudioSpectrogram = 16, + PrimType_AvgPoolFusion = 17, + PrimType_AvgPoolGrad = 18, + PrimType_BatchNorm = 19, + PrimType_BatchNormGrad = 20, + PrimType_BatchToSpace = 21, + PrimType_BatchToSpaceND = 22, + PrimType_BiasAdd = 23, + PrimType_BinaryCrossEntropy = 24, + PrimType_BinaryCrossEntropyGrad = 25, + PrimType_BiasAddGrad = 26, + PrimType_BroadcastTo = 27, + PrimType_Cast = 28, + PrimType_Ceil = 29, + PrimType_Clip = 30, + PrimType_Concat = 31, + PrimType_Attention = 32, + PrimType_Conv2DBackpropFilterFusion = 33, + PrimType_Conv2DBackpropInputFusion = 34, + PrimType_Conv2DFusion = 35, + PrimType_Conv2dTransposeFusion = 36, + PrimType_Cos = 37, + PrimType_ConstantOfShape = 38, + PrimType_Crop = 39, + PrimType_CustomExtractFeatures = 40, + PrimType_CustomNormalize = 41, + PrimType_CustomPredict = 42, + PrimType_DeConv2DGradFilter = 43, + PrimType_Depend = 44, + PrimType_DepthToSpace = 45, + PrimType_DetectionPostProcess = 46, + PrimType_DivFusion = 47, + PrimType_DivGrad = 48, + PrimType_Dropout = 49, + PrimType_DropoutGrad = 50, + PrimType_Elu = 51, + PrimType_Eltwise = 52, + PrimType_Equal = 53, + PrimType_EmbeddingLookupFusion = 54, + PrimType_ExpFusion = 55, + PrimType_ExpandDims = 56, + PrimType_FakeQuantWithMinMaxVars = 57, + PrimType_FakeQuantWithMinMaxVarsPerChannel = 58, + PrimType_FftReal = 59, + PrimType_FftImag = 60, + PrimType_Flatten = 61, + PrimType_FlattenGrad = 62, + PrimType_Floor = 63, + PrimType_FloorDiv = 64, + PrimType_FloorMod = 65, + PrimType_Fill = 66, + PrimType_FullConnection = 67, + PrimType_FusedBatchNorm = 68, + PrimType_Gather = 69, + PrimType_GatherNd = 70, + PrimType_Greater = 71, + PrimType_GreaterEqual = 72, + PrimType_HashtableLookup = 73, + PrimType_InstanceNorm = 74, + PrimType_LayerNormFusion = 75, + PrimType_LeakyRelu = 76, + PrimType_Less = 77, + PrimType_LessEqual = 78, + PrimType_Log = 79, + PrimType_LogGrad = 80, + PrimType_LogicalAnd = 81, + PrimType_LogicalNot = 82, + PrimType_LogicalOr = 83, + PrimType_LpNormalization = 84, + PrimType_LRN = 85, + PrimType_LshProjection = 86, + PrimType_LSTM = 87, + PrimType_L2NormalizeFusion = 88, + PrimType_MatMulFusion = 89, + PrimType_Maximum = 90, + PrimType_MaximumGrad = 91, + PrimType_MaxPoolFusion = 92, + PrimType_MaxPoolGrad = 93, + PrimType_SwitchLayer = 94, + PrimType_Mfcc = 95, + PrimType_Minimum = 96, + PrimType_MinimumGrad = 97, + PrimType_Mod = 98, + PrimType_MulFusion = 99, + PrimType_MulGrad = 100, + PrimType_Neg = 101, + PrimType_NegGrad = 102, + PrimType_NotEqual = 103, + PrimType_NonMaxSuppression = 104, + PrimType_OneHot = 105, + PrimType_OnesLike = 106, + PrimType_PadFusion = 107, + PrimType_PartialFusion = 108, + PrimType_PowerGrad = 109, + PrimType_PowFusion = 110, + PrimType_PriorBox = 111, + PrimType_PReLUFusion = 112, + PrimType_QuantDTypeCast = 113, + PrimType_Rank = 114, + PrimType_Range = 115, + PrimType_Reciprocal = 116, + PrimType_RealDiv = 117, + PrimType_ReduceFusion = 118, + PrimType_Reshape = 119, + PrimType_Resize = 120, + PrimType_ReverseSequence = 121, + PrimType_ReverseV2 = 122, + PrimType_Rfft = 123, + PrimType_ROIPooling = 124, + PrimType_Round = 125, + PrimType_Rsqrt = 126, + PrimType_ScaleFusion = 127, + PrimType_ScatterNd = 128, + PrimType_SGD = 129, + PrimType_Shape = 130, + PrimType_SigmoidCrossEntropyWithLogits = 131, + PrimType_SigmoidCrossEntropyWithLogitsGrad = 132, + PrimType_Sin = 133, + PrimType_SkipGram = 134, + PrimType_SliceFusion = 135, + PrimType_SmoothL1Loss = 136, + PrimType_SmoothL1LossGrad = 137, + PrimType_Softmax = 138, + PrimType_SoftmaxCrossEntropyWithLogits = 139, + PrimType_SpaceToBatch = 140, + PrimType_SpaceToBatchND = 141, + PrimType_SpaceToDepth = 142, + PrimType_SparseSoftmaxCrossEntropyWithLogits = 143, + PrimType_SparseToDense = 144, + PrimType_Split = 145, + PrimType_Sqrt = 146, + PrimType_Squeeze = 147, + PrimType_Square = 148, + PrimType_SquaredDifference = 149, + PrimType_Stack = 150, + PrimType_StridedSlice = 151, + PrimType_SubFusion = 152, + PrimType_SubGrad = 153, + PrimType_Switch = 154, + PrimType_TensorListFromTensor = 155, + PrimType_TensorListGetItem = 156, + PrimType_TensorListReserve = 157, + PrimType_TensorListSetItem = 158, + PrimType_TensorListStack = 159, + PrimType_TileFusion = 160, + PrimType_TopKFusion = 161, + PrimType_Transpose = 162, + PrimType_Unique = 163, + PrimType_UnsortedSegmentSum = 164, + PrimType_Unsqueeze = 165, + PrimType_Unstack = 166, + PrimType_LSTMGrad = 167, + PrimType_Where = 168, + PrimType_ZerosLike = 169, + PrimType_Select = 170, + PrimType_ScatterNdUpdate = 171, + PrimType_GRU = 172, + PrimType_NonZero = 173, + PrimType_InvertPermutation = 174, + PrimType_Size = 175, + PrimType_RandomStandardNormal = 176, + PrimType_CropAndResize = 177, + PrimType_Erf = 178, + PrimType_StridedSliceGrad = 179, + PrimType_IsFinite = 180, + PrimType_LinSpace = 181, + PrimType_UniformReal = 182, + PrimType_AbsGrad = 183, + PrimType_RsqrtGrad = 184, + PrimType_SqrtGrad = 185, + PrimType_LayerNormGrad = 186, + PrimType_ResizeGrad = 187, + PrimType_Splice = 188, + PrimType_LogSoftmax = 189, + PrimType_Call = 190, + PrimType_Custom = 191, + PrimType_CumSum = 192, + PrimType_SplitWithOverlap = 193, + PrimType_GenOP = 194, + PrimType_RaggedRange = 195, + PrimType_GLU = 196, + PrimType_TensorArray = 197, + PrimType_TensorArrayRead = 198, + PrimType_TensorArrayWrite = 199, + PrimType_Affine = 200, + PrimType_AllGather = 201, + PrimType_ReduceScatter = 202, + PrimType_DynamicQuant = 203, + PrimType_LSTMGradData = 204, + PrimType_LSTMGradWeight = 205, + PrimType_RandomNormal = 206, + PrimType_NLLLoss = 207, + PrimType_NLLLossGrad = 208, + PrimType_FormatTranspose = 209, + PrimType_GatherD = 210, + PrimType_GroupNormFusion = 211, + PrimType_Log1p = 212, + PrimType_TensorScatterAdd = 213, + PrimType_SparseFillEmptyRows = 214, + PrimType_SparseReshape = 215, + PrimType_SparseSegmentSum = 216, + PrimType_ScatterElements = 217, + PrimType_Triu = 218, + PrimType_Tril = 219, + PrimType_AdamWeightDecay = 220, + PrimType_FillV2 = 221, + PrimType_MIN = PrimType_NONE, + PrimType_MAX = PrimType_FillV2 + 1, + + // inner operators. + PrimType_Inner_ToFormat = 10000, + PrimType_Inner_GltextureToOpencl = 10001, + PrimType_Inner_Identity = 10002, + PrimType_Inner_ShapeFusion = 10003, + PrimType_Inner_GraphKernel = 10004, + PrimType_Inner_SplitReduceConcatFusion = 10005, + PrimType_Inner_EncoderLayer = 10006, + PrimType_Inner_FseDecode = 10007, + PrimType_Inner_DecoderLayer = 10008, + PrimType_Inner_UsePastEmbedding = 10009, + PrimType_Inner_CustomGru = 10010, + PrimType_Inner_CastGatherReduceFusion = 10011, + PrimType_Inner_ReduceConcatFusion = 10012, + PrimType_Inner_AclCustomOp = 10013, + PrimType_Inner_CustomMaskedFill = 10014, + PrimType_Inner_CustomTensorScatterMax = 10015, + PrimType_Inner_CustomIsInf = 10016, + PrimType_Inner_Conv3D = 10017, + PrimType_Inner_GridSampler = 10018, + PrimType_InnerOpMax, + PrimType_InnerOpMin = PrimType_Inner_ToFormat +}; + +typedef enum FormatC { + DEFAULT_FORMAT = -1, + Format_NCHW = 0, + Format_NHWC = 1, + Format_NHWC4 = 2, + Format_HWKC = 3, + Format_HWCK = 4, + Format_KCHW = 5, + Format_CKHW = 6, + Format_KHWC = 7, + Format_CHWK = 8, + Format_HW = 9, + Format_HW4 = 10, + Format_NC = 11, + Format_NC4 = 12, + Format_NC4HW4 = 13, + Format_NONE = 14, // The origin Format_NUM_OF_FORMAT can't be used. + Format_NCDHW = 15, + Format_NWC = 16, + Format_NCW = 17, + Format_NDHWC = 18, + Format_NC8HW8 = 19, + Format_NC16HW16 = 20, + Format_MAX, + Format_MIN = Format_NCHW +} FormatC; + +typedef enum TypeIdC { + kTypeUnknown = 0, + kMetaTypeBegin = kTypeUnknown, + kMetaTypeType, // Type + kMetaTypeAny, + kMetaTypeObject, + kMetaTypeTypeType, // TypeType + kMetaTypeProblem, + kMetaTypeExternal, + kMetaTypeNone, + kMetaTypeNull, + kMetaTypeEllipsis, + kMetaTypeEnd, + // + // Object types + // + kObjectTypeBegin = kMetaTypeEnd, + kObjectTypeNumber, + kObjectTypeString, + kObjectTypeList, + kObjectTypeTuple, + kObjectTypeSlice, + kObjectTypeKeyword, + kObjectTypeTensorType, + kObjectTypeRowTensorType, + kObjectTypeCOOTensorType, + kObjectTypeUndeterminedType, + kObjectTypeClass, + kObjectTypeDictionary, + kObjectTypeFunction, + kObjectTypeJTagged, + kObjectTypeSymbolicKeyType, + kObjectTypeEnvType, + kObjectTypeRefKey, + kObjectTypeRef, + kObjectTypeEnd, + // + // Number Types + // + kNumberTypeBegin = kObjectTypeEnd, + kNumberTypeBool, + kNumberTypeInt, + kNumberTypeInt8, + kNumberTypeInt16, + kNumberTypeInt32, + kNumberTypeInt64, + kNumberTypeUInt, + kNumberTypeUInt8, + kNumberTypeUInt16, + kNumberTypeUInt32, + kNumberTypeUInt64, + kNumberTypeFloat, + kNumberTypeFloat16, + kNumberTypeFloat32, + kNumberTypeFloat64, + kNumberTypeDouble, + kNumberTypeComplex, + kNumberTypeComplex64, + kNumberTypeComplex128, + kNumberTypeInt4, + kNumberTypeGLUInt, + kNumberTypeEnd, +} TypeIdC; + +typedef enum DataOrder { + RowMajor, + ColMajor, +} DataOrder; + +typedef struct OpParameter { + char name_[100]; + int type_; + int thread_num_; + int quant_type_; + bool is_train_session_; + bool is_zero_shape_; + void (*destroy_func_)(struct OpParameter *param); +} OpParameter; + +typedef struct QuantArg { + float scale_; + int32_t zp_; +} QuantArg; + +typedef struct QuantMulArg { + int32_t multiplier_; + int left_shift_; + int right_shift_; +} QuantMulArg; + +typedef enum ReductionType { Reduction_Sum, Reduction_Mean, Reduction_None } ReductionType; +typedef enum ActType { + ActType_No = 0, + ActType_Relu = 1, + ActType_Sigmoid = 2, + ActType_Relu6 = 3, + ActType_Elu = 4, + ActType_LeakyRelu = 5, + ActType_Abs = 6, + ActType_Relu1 = 7, + ActType_Softsign = 8, + ActType_Softplus = 9, + ActType_Tanh = 10, + ActType_Selu = 11, + ActType_HSwish = 12, + ActType_HSigmoid = 13, + ActType_ThresholdRelu = 14, + ActType_Linear = 15, + ActType_HardTanh = 16, + ActType_Sign = 17, + ActType_Swish = 18, + ActType_Gelu = 19, + ActType_FastGelu = 20, + ActType_Unknown = 21 +} ActType; +typedef enum PadType { Pad_pad, Pad_same, Pad_valid } PadType; +typedef enum EltwiseType { Eltwise_PROD, Eltwise_SUM, Eltwise_MAXIMUM, Eltwise_UNKNOWN } EltwiseType; +typedef enum RoundingMode { Rounding_No, Rounding_Away_from_zero, Rounding_Up } RoundingMode; + +typedef enum PaddingModeC { + PaddingMode_Constant, + PaddingMode_Reflect, + PaddingMode_Symmetric, + PaddingMode_Mode_Reserved, +} PaddingModeC; + +typedef enum ElementwiseModeC { + Elementwise_Not = 0, + Elementwise_Per_Channel = 1, + Elementwise_Per_Num = 2 +} ElementwiseModeC; + +typedef enum QuantTypeC { + Quant_None = 0, + Quant_AwareTraining = 1, + Quant_WeightQuant = 2, + Quant_PostTraining = 3, + Quant_QuantWeight = 4, + Quant_QuantAll = 5, + Quant_QuantDynamic = 6, + Quant_Min = Quant_None, + Quant_Max = Quant_QuantDynamic +} QuantTypeC; + +typedef enum TensorCategoryC { + VarTensor, // common tensor + ConstTensor, // const tensor + ConstScalar, // const scalar + GraphInput, + GraphOutput +} TensorCategoryC; + +typedef enum ReduceModeC { + Reduce_Mean = 0, + Reduce_Max = 1, + Reduce_Min = 2, + Reduce_Prod = 3, + Reduce_Sum = 4, + Reduce_SumSquare = 5, + Reduce_ASum = 6, + Reduce_All = 7, + Reduce_L2 = 8, + Reduce_MIN = Reduce_Mean, + Reduce_MAX = Reduce_L2 +} ReduceModeC; + +typedef enum CalFixedMultiplierMode { + Method_No, + Method_SinglePrecision, + Method_DoublePrecision +} CalFixedMultiplierMode; + +#define VA_ARG_TUPLE_LEN 2 +static inline void offset_to_index_init(int offset, int cnt, ...) { + va_list valist; + va_start(valist, cnt); + int start = offset; + for (int i = 0; i < cnt; i += VA_ARG_TUPLE_LEN) { + int *x = va_arg(valist, int *); + int X = va_arg(valist, int); + + *x = start % X; + start = start / X; + } + va_end(valist); +} + +static inline void offset_to_index_step(int cnt, ...) { + va_list valist; + int flag = 1; + va_start(valist, cnt); + for (int i = 0; i < cnt; i += VA_ARG_TUPLE_LEN) { + int *x = va_arg(valist, int *); + int X = va_arg(valist, int); + if (flag) { + *x = (++*x != X) ? (flag = 0, *x) : (flag = 1, 0); + } + } + va_end(valist); +} + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_OP_BASE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/op_simd_header_file.h.in b/mindspore-lite/ops/kernel/cpu/nnacl/op_simd_header_file.h.in new file mode 100644 index 0000000000000000000000000000000000000000..4013d9f76ed8b00f45c78f7f986c10ef5e3daae6 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/op_simd_header_file.h.in @@ -0,0 +1,36 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_@OP_NAME_UPPER@_SIMD_H_ +#define NNACL_@OP_NAME_UPPER@_SIMD_H_ + +#include "nnacl/intrinsics/ms_simd_instructions.h" +#ifdef ENABLE_AVX512 +#include "nnacl/avx512/@OP_NAME_LOWER@_avx512.h" +#endif + +#ifdef ENABLE_AVX +#include "nnacl/avx/@OP_NAME_LOWER@_avx.h" +#endif + +#ifdef ENABLE_SSE +#include "nnacl/sse/@OP_NAME_LOWER@_sse.h" +#endif + +#ifdef ENABLE_ARM +#include "nnacl/neon/@OP_NAME_LOWER@_neon.h" +#endif + +#endif diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/optimize/CMakeLists.txt b/mindspore-lite/ops/kernel/cpu/nnacl/optimize/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..9b190de7416bca9aaa43ce53092d26386122da82 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/optimize/CMakeLists.txt @@ -0,0 +1,62 @@ +project(optimize) + +set(NNACL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/..) +include_directories(NNACL_DIR) + +########################### optimized files ########################### +file(GLOB FP16_C_SRC ${NNACL_DIR}/fp16/*.c ${NNACL_DIR}/kernel/f16/*.c) +if(PLATFORM_ARM32) + file(GLOB FP16_NEON_SRC ${NNACL_DIR}/assembly/arm82_aarch32_fp16/*.S) +else() + file(GLOB FP16_NEON_SRC ${NNACL_DIR}/assembly/fp16/*.S) + file(GLOB SDOT_SRC ${NNACL_DIR}/assembly/opt/*.S) + set_property(SOURCE ${SDOT_SRC} PROPERTY LANGUAGE C) +endif() + +set_property(SOURCE ${FP16_C_SRC} PROPERTY LANGUAGE C) +set_property(SOURCE ${FP16_NEON_SRC} PROPERTY LANGUAGE C) + +if(APPLE) + set_source_files_properties(${SDOT_SRC} PROPERTIES COMPILE_FLAGS "-x assembler-with-cpp") + set_source_files_properties(${FP16_NEON_SRC} PROPERTIES COMPILE_FLAGS "-x assembler-with-cpp") +endif() +########################### share library build ######################## +list(APPEND FP16_FILES ${FP16_C_SRC}) +list(APPEND FP16_FILES ${FP16_NEON_SRC}) + +if(SUPPORT_TRAIN) + file(GLOB FP16_TRAIN_SRC ${NNACL_DIR}/fp16_grad/*.c) + list(APPEND FP16_FILES ${FP16_TRAIN_SRC}) +endif() +if(NOT MSVC) +string(REPLACE "-fvisibility=hidden" "-fvisibility=default" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") +endif() +if(MACHINE_LINUX_ARM64) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+fp16") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+fp16") +elseif(NOT PLATFORM_ARM32 AND NOT TARGET_HIMIX AND (NOT (TARGET_AOS_ARM AND TOOLCHAIN_NAME STREQUAL "gcc"))) + list(APPEND SDOT_FILES ${SDOT_SRC}) + add_library(nnacl_optimize_mid OBJECT ${SDOT_FILES}) + add_dependencies(nnacl_optimize_mid fbs_src) + if(NOT TARGET_MIX210) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+dotprod+fp16") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+dotprod+fp16") + endif() +endif() + +if(MSLITE_ENABLE_FP16) + add_library(nnacl_fp16_mid OBJECT ${FP16_FILES}) + add_dependencies(nnacl_fp16_mid fbs_src) + if(PLATFORM_ARM32) + target_compile_options(nnacl_fp16_mid PRIVATE -march=armv8.2-a+fp16 -mfpu=neon-fp-armv8 -mfloat-abi=softfp) + endif() + if(TARGET_AOS_ARM) + if(TOOLCHAIN_NAME STREQUAL "gcc") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+simd+fp16 -mtune=cortex-a72") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+simd+fp16 -mtune=cortex-a72") + else() + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+simd+dotprod+fp16 -mtune=cortex-a72") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+simd+dotprod+fp16 -mtune=cortex-a72") + endif() + endif() +endif() \ No newline at end of file diff --git a/mindspore-lite/examples/quick_start_ios/mindspore-lite/ViewController.h b/mindspore-lite/ops/kernel/cpu/nnacl/pack.h similarity index 75% rename from mindspore-lite/examples/quick_start_ios/mindspore-lite/ViewController.h rename to mindspore-lite/ops/kernel/cpu/nnacl/pack.h index 5607769063439e793c3206213b0288b32a0c3298..633b3db1bd4c4a92b535ab67a04753ff685929c7 100644 --- a/mindspore-lite/examples/quick_start_ios/mindspore-lite/ViewController.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/pack.h @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,10 +14,10 @@ * limitations under the License. */ -#import +#ifndef NNACL_PACK_H_ +#define NNACL_PACK_H_ -@interface ViewController : UIViewController - - -@end +#include "nnacl/fp32/pack_fp32.h" +#include "nnacl/int8/pack_int8.h" +#endif // NNACL_PACK_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/pad_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/pad_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..657d191e1ec41d0a6367344e09c167a3829d95a6 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/pad_parameter.h @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_PAD_PARAMETER_H_ +#define NNACL_PAD_PARAMETER_H_ + +#include "nnacl/op_base.h" + +#define MAX_PAD_SIZE 12 +#define DEFAULT_PAD_NDIMS 6 + +typedef struct PadQuantArg { + QuantArg *in_quant_args_; + QuantArg *out_quanr_args_; + int8_t *constant_value_; +} PadQuantArg; + +typedef struct PadParameter { + OpParameter op_parameter_; + int paddings_[MAX_PAD_SIZE]; + int pad_mode_; + float constant_value_; + int padding_length; +} PadParameter; + +#endif // NNACL_PAD_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/partial_fusion_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/partial_fusion_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..de1d92b5422a2e2070a67acfe1727f7cb9506096 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/partial_fusion_parameter.h @@ -0,0 +1,29 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_PARTIAL_FUSION_H_ +#define NNACL_PARTIAL_FUSION_H_ + +#include "nnacl/op_base.h" +#include "nnacl/common_func.h" +#include "nnacl/nnacl_utils.h" + +typedef struct PartialParameter { + OpParameter op_parameter_; + int sub_graph_index_; +} PartialParameter; + +#endif // NNACL_ARTITHMETIC_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/logical.cuh b/mindspore-lite/ops/kernel/cpu/nnacl/pooling_parameter.h similarity index 37% rename from mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/logical.cuh rename to mindspore-lite/ops/kernel/cpu/nnacl/pooling_parameter.h index 4596f6bc457b6828be0b21dae42badaf31995989..3b87a151a300bbd55e680cc784638626fccb083e 100644 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/logical.cuh +++ b/mindspore-lite/ops/kernel/cpu/nnacl/pooling_parameter.h @@ -1,5 +1,5 @@ /** - * Copyright 2022 Huawei Technologies Co., Ltd + * Copyright 2020-2023 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,23 +13,43 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#ifndef NNACL_POOLING_PARAMETER_H_ +#define NNACL_POOLING_PARAMETER_H_ -#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_CDUA_IMPL_LOGICAL_H_ -#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_CDUA_IMPL_LOGICAL_H_ +#include "nnacl/op_base.h" -template -void LogicalAnd(const T *input1, const T *input2, T *output, int element_cnt, cudaStream_t stream); +typedef enum PoolMode { PoolMode_No, PoolMode_MaxPool, PoolMode_AvgPool } PoolMode; -template -void LogicalOr(const T *input1, const T *input2, T *output, int element_cnt, cudaStream_t stream); +typedef enum RoundType { RoundType_No, RoundType_Ceil, RoundType_Floor } RoundType; -template -void LogicalNot(const T *input1, T *output, int element_cnt, cudaStream_t stream); +typedef struct PoolingParameter { + OpParameter op_parameter_; + PoolMode pool_mode_; + RoundType round_type_; + PadType pad_mode_; + ActType act_type_; + int avg_mode_; + bool global_; + int window_w_; + int window_h_; + int stride_w_; + int stride_h_; + int pad_u_; + int pad_d_; + int pad_l_; + int pad_r_; +} PoolingParameter; -template -void GreaterOrEqual(const T *input1, const T *input2, T *output, int element_cnt, cudaStream_t stream); +typedef struct Pooling3DParameter { + PoolingParameter pooling_parameter_; + int window_d_; + int stride_d_; + int input_d_; + int output_d_; + int pad_f_; // front + int pad_b_; // back + bool count_include_pad_; + int divisor_override_; +} Pooling3DParameter; -template -void LessOrEqual(const T *input1, const T *input2, T *output, int element_cnt, cudaStream_t stream); - -#endif // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_CDUA_IMPL_LOGICAL_H_ +#endif // NNACL_POOLING_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/pow_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/pow_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..52dfd4875d64640a05270426f49252b58e7ba372 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/pow_parameter.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_POW_PARAMETER_H_ +#define NNACL_POW_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct PowQuantArg { + QuantArg in_args_; + QuantArg exp_args_; + QuantArg out_args_; + int output_activation_min_; + int output_activation_max_; +} PowQuantArg; + +typedef struct PowParameter { + OpParameter op_parameter_; + float power_; + float scale_; + float shift_; +} PowParameter; + +#endif // NNACL_POW_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/predict_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/predict_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..384b37dea826477d0fcd1cd86df8ae19e0f4da86 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/predict_parameter.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_PREDICT_PARAMETER_H_ +#define NNACL_PREDICT_PARAMETER_H_ + +#include "nnacl/op_base.h" +typedef struct { + // Primitive parameter + OpParameter op_parameter_; + // other parameter + int output_num; + float weight_threshold; +} PredictParameter; + +typedef struct { + int label; + float weight; +} LabelInfo; +#endif // NNACL_PREDICT_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/prelu_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/prelu_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..6ff47548db55b1531fa16d92c2d3523a98f04c70 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/prelu_parameter.h @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_PRELU_PARAMETER_H_ +#define NNACL_PRELU_PARAMETER_H_ + +#include "nnacl/op_base.h" +typedef struct PReluParameter { + OpParameter op_parameter_; + bool channel_shared_; +} PReluParameter; + +#endif // NNACL_PRELU_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/prior_box_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/prior_box_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..b372dba18680c4a74a7bfad0fd756c82b9b0314a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/prior_box_parameter.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_PRIOR_BOX_PARAMETER_H_ +#define NNACL_PRIOR_BOX_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct PriorBoxParameter { + OpParameter op_parameter_; + int32_t min_sizes_size; + int32_t min_sizes[MAX_SHAPE_SIZE]; + int32_t max_sizes_size; + int32_t max_sizes[MAX_SHAPE_SIZE]; + int32_t aspect_ratios_size; + float aspect_ratios[MAX_SHAPE_SIZE]; + float variances[COMM_SHAPE_SIZE]; + int32_t image_size_w; + int32_t image_size_h; + float step_w; + float step_h; + bool clip; + bool flip; + float offset; +} PriorBoxParameter; + +#endif // NNACL_PRIOR_BOX_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/random_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/random_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..7030c17724e217365cb7bbf69b0473fab2bdd6ec --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/random_parameter.h @@ -0,0 +1,34 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_RNADOM_PARAMETER_H_ +#define NNACL_RNADOM_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct RandomParam { + OpParameter op_parameter_; + int seed_; + int seed2_; +} RandomParam; + +typedef struct RandomNormalParam { + OpParameter op_parameter_; + float seed_; + float mean_; + float scale_; +} RandomNormalParam; + +#endif // NNACL_RNADOM_STANDARD_NORMAL_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/range_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/range_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..65243a1565f6a6f2e609c6a71dd2baa079c9bae9 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/range_parameter.h @@ -0,0 +1,29 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_RANGE_PARAMETER_H_ +#define NNACL_RANGE_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct RangeParameter { + OpParameter op_parameter_; + int dtype_; + int start_; + int limit_; + int delta_; +} RangeParameter; + +#endif // NNACL_RANGE_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/reduce_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/reduce_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..898fd1aa31e6520a32974a7ba45deab810b3504b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/reduce_parameter.h @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_REDUCE_PARAMETER_H_ +#define NNACL_REDUCE_PARAMETER_H_ +#include "nnacl/op_base.h" + +typedef struct ReduceParameter { + OpParameter op_parameter_; + bool keep_dims_; + int mode_; + bool reduce_to_end_; + float coeff; +} ReduceParameter; + +#endif // NNACL_REDUCE_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/reduce_scatter_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/reduce_scatter_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..15ad6568ec1a6567fd3bb379b86aaa8fe74b2b5d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/reduce_scatter_parameter.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_REDUCE_SCATTER_PARAMETER_H_ +#define NNACL_REDUCE_SCATTER_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct ReduceScatterParameter { + // primitive parameter + OpParameter op_parameter_; + char group_[DEFAULT_GROUP_NAME_LEN]; + int mode_; + + // other parameter + int rank_size_; +} ReduceScatterParameter; +#endif // NNACL_REDUCE_SCATTER_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/reshape_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/reshape_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..d0d5ec0030f2f138d6d25608b3a79d9deb113f69 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/reshape_parameter.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_RESHAHPE_PARAMETER_H_ +#define NNACL_RESHAHPE_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct ReshapeQuantArg { + QuantArg in_args_; + QuantArg out_args_; + int output_activation_min_; + int output_activation_max_; +} ReshapeQuantArg; + +typedef struct ReshapeParameter { + // primitive parameter + OpParameter op_parameter_; + int shape_dim_; + int shape_[MAX_SHAPE_SIZE]; + + // other parameter + ReshapeQuantArg quant_para_; + int thread_count_; +} ReshapeParameter; + +#endif // NNACL_RESHAHPE_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/resize_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/resize_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..356e487ad473095fc1696a252f693c0230609718 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/resize_parameter.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_RESIZE_PARAMETER_H_ +#define NNACL_RESIZE_PARAMETER_H_ + +#include "nnacl/op_base.h" +typedef struct ResizeParameter { + // primitive parameter + OpParameter op_parameter_; + int method_; + int64_t new_height_; + int64_t new_width_; + int coordinate_transform_mode_; + float cubic_coeff_; + bool preserve_aspect_ratio_; +} ResizeParameter; + +typedef struct CropAndResizeParameter { + // primitive parameter + OpParameter op_parameter_; + int method_; + float extrapolation_value_; +} CropAndResizeParameter; +#endif // NNACL_RESIZE_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/reverse_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/reverse_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..b148ab5dec9b8f885a02c4ca14060f71b5b6a534 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/reverse_parameter.h @@ -0,0 +1,30 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_REVERSE_PARAMETER_H_ +#define NNACL_REVERSE_PARAMETER_H_ + +#include "nnacl/op_base.h" + +#define REVERSE_SHAPE_MAX_SIZE 4 + +typedef struct ReverseParameter { + OpParameter op_parameter_; + int axis_[REVERSE_SHAPE_MAX_SIZE]; + int num_axis_; +} ReverseParameter; + +#endif // NNACL_REVERSE_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/reverse_sequence_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/reverse_sequence_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..0b767078a8b87bd990d9d9b10160bb6c5a3d4c4b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/reverse_sequence_parameter.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_REVERSE_SEQUENCE_PARAMETER_H_ +#define NNACL_REVERSE_SEQUENCE_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct ReverseSequenceParameter { + // primitive parameter + OpParameter op_parameter_; + int seq_axis_; + int batch_axis_; + + // shape correlative + int input_shape0_[5]; + int output_shape_[5]; + int input_stride_[5]; + int output_stride_[5]; + + // other parameter + int ndim_; + int outer_count_; + int outer_stride_; + int inner_count_; + int inner_stride_; + int copy_byte_size_; + int total_data_size_; + bool is_seq_length_int32_; +} ReverseSequenceParameter; + +#endif // NNACL_REVERSE_SEQUENCE_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/scale_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/scale_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..c960c2821acb2be3399bededf078196a45aba2a2 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/scale_parameter.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_SCALE_H_ +#define NNACL_SCALE_H_ + +#include "nnacl/op_base.h" + +typedef struct ScaleParameter { + OpParameter op_parameter_; + int axis_; + int activation_type_; +} ScaleParameter; + +typedef struct ScaleQuantParameter { + QuantMulArg scale_mul_arg_; + QuantMulArg offset_mul_arg_; + int input_zp_; + int scale_zp_; + int offset_zp_; + int output_zp_; + int output_activation_min_; + int output_activation_max_; +} ScaleQuantParameter; + +#endif // NNACL_SCALE_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/scatter_elements_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/scatter_elements_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..437aba689992ab53d88f188c911ee5e7d0c8468d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/scatter_elements_parameter.h @@ -0,0 +1,25 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_SCATTER_ELEMENTS_PARAMETER_H_ +#define NNACL_SCATTER_ELEMENTS_PARAMETER_H_ + +#include "nnacl/op_base.h" +typedef struct ScatterElementsParameter { + OpParameter op_parameter_; + int axis_; +} ScatterElementsParameter; + +#endif // NNACL_SCATTER_ELEMENTS_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/scatter_nd_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/scatter_nd_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..7ab4a1558a66862943705c3c2c69a7fbc5c8722d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/scatter_nd_parameter.h @@ -0,0 +1,29 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_SCATTER_ND_PARAMETER_H_ +#define NNACL_SCATTER_ND_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct ScatterNDParameter { + OpParameter op_parameter; + int num_unit; + int unit_size; + int data_type_len; +} ScatterNDParameter; + +#endif // NNACL_SCATTER_ND_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/sequence_unstack_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/sequence_unstack_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..3aa3073d7fa8a58daa2802ee64df348b53084e5c --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/sequence_unstack_parameter.h @@ -0,0 +1,34 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_NNACL_SEQUENCE_UNSTACK_PARAMETER_H_ +#define MINDSPORE_NNACL_SEQUENCE_UNSTACK_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct SequenceUnstackParameter { + // primitive parameter + OpParameter op_parameter_; + int num_; + int axis_; + + // other parameter + int pre_dims_; + int axis_dim_; + int after_dims_; +} SequenceUnstackParameter; + +#endif // MINDSPORE_NNACL_SEQUENCE_UNSTACK_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/sigmoid_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/sigmoid_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..0f12da9f65e9a31b038f871b7f06392917feddf0 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/sigmoid_parameter.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_SIGMOID_PARAMETER_H_ +#define NNACL_SIGMOID_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct SigmoidParameter { + // primitive parameter + OpParameter op_parameter_; + + // shape correlative + const int *in_shape_; + const int *out_shape_; + + // other parameter + SigmoidQuantArg quant_arg; + double alpha_; + int thread_count_; + int64_t offset_[MAX_SHAPE_SIZE]; + int64_t in_offset_[MAX_SHAPE_SIZE]; + int64_t axis_; + int input_dim_; + int element_num; +} SigmoidParameter; + +#endif // NNACL_SIGMOID_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/skip_gram_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/skip_gram_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..d0fa9e34e7fd2f85e31525fcd2fe6932faa7d96b --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/skip_gram_parameter.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_SKIP_GRAM_PARAMETER_H_ +#define NNACL_SKIP_GRAM_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct SkipGramParameter { + // primitive parameter + OpParameter op_parameter_; + bool include_all_ngrams; + int max_skip_size; + int ngram_size; +} SkipGramParameter; + +#endif // NNACL_SKIP_GRAM_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/slice_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/slice_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..8baff64634327a9d4466ee67ec1e029f2e0066f2 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/slice_parameter.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_SLICE_PARAMETER_H_ +#define NNACL_SLICE_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct SliceQuantArg { + QuantArg in_args_; + QuantArg out_args_; + int output_activation_min_; + int output_activation_max_; + QuantMulArg multiplier_; +} SliceQuantArg; + +typedef struct SliceParameter { + OpParameter op_parameter_; + int32_t axis_[DIMENSION_8D]; +} SliceParameter; + +#endif // NNACL_SLICE_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/softmax_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/softmax_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..75db75288b972d9a8606e6cdf46d758ce1f6715a --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/softmax_parameter.h @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_SOFTMAX_PARAMETER_H_ +#define NNACL_SOFTMAX_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct SoftmaxParameter { + OpParameter op_parameter_; + int32_t axis_; +} SoftmaxParameter; + +#endif // NNACL_SOFTMAX_PARAMETER_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/cast.cuh b/mindspore-lite/ops/kernel/cpu/nnacl/space_to_depth_parameter.h similarity index 56% rename from mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/cast.cuh rename to mindspore-lite/ops/kernel/cpu/nnacl/space_to_depth_parameter.h index 59d7ab8279315bad6c5c924c1552c756a4adcd71..43f800aa98d379acd98064ee9a18a83550b79ab9 100644 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/cast.cuh +++ b/mindspore-lite/ops/kernel/cpu/nnacl/space_to_depth_parameter.h @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,11 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#ifndef LITE_SRC_BACKEND_ARM_NNACL_SPACE_TO_DEPTH_PARAMETER_H_ +#define LITE_SRC_BACKEND_ARM_NNACL_SPACE_TO_DEPTH_PARAMETER_H_ +#include "nnacl/op_base.h" -#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_CDUA_IMPL_CAST_H_ -#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_CDUA_IMPL_CAST_H_ +typedef struct SpaceToDepthParameter { + // primitive parameter + OpParameter op_parameter_; + int32_t block_size_; + int32_t date_type_len; +} SpaceToDepthParameter; -template -void Cast(const int input_size, const S *input_addr, T *output_addr, cudaStream_t stream); - -#endif // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_CDUA_IMPL_CAST_H_ +#endif // LITE_SRC_BACKEND_ARM_NNACL_SPACE_TO_DEPTH_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/sparse_to_dense_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/sparse_to_dense_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..e92e36ab2cb27171eec7aa66264d892618ff061e --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/sparse_to_dense_parameter.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_SPARSE_TO_DENSE_PARAMETER_H_ +#define NNACL_SPARSE_TO_DENSE_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct SparseToDenseParameter { + // primitive parameter + OpParameter op_parameter_; + bool validate_indices_; + bool is_scalar; + int index_num; + int output_num; + int output_stride[DIMENSION_4D]; +} SparseToDenseParameter; + +#endif // NNACL_SPARSE_TO_DENSE_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/splice_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/splice_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..190bba4b5cfa3c068fa7b71652548201272a0a56 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/splice_parameter.h @@ -0,0 +1,28 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_SPLICE_PARAMETER_H_ +#define NNACL_SPLICE_PARAMETER_H_ +#include "nnacl/op_base.h" +typedef struct SpliceParameter { + OpParameter op_parameter_; + int context_dim_; + int forward_indexes_dim_; + int *context_; + int *forward_indexes_; + int output_dim_; +} SpliceParameter; +#endif // NNACL_SPLICE_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/split_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/split_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..01b8485220e74d898e084b4d21d3676d435e6aaa --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/split_parameter.h @@ -0,0 +1,63 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_SPLIT_PARAMETER_H_ +#define NNACL_SPLIT_PARAMETER_H_ + +#include "nnacl/op_base.h" + +#define SPLIT_STRIDES_SIZE 32 +#define SPLIT_MAX_SLICE_NUM 10 + +typedef struct SplitQuantArg { + QuantArg in_args_; + QuantArg out_args_[20]; + int output_activation_min_; + int output_activation_max_; +} SplitQuantArg; + +typedef struct SplitParameter { + // primitive parameter + OpParameter op_parameter_; + int num_split_; + int *split_sizes_; + int split_dim_; + + // shape correlative + int strides_[SPLIT_STRIDES_SIZE]; + + // other parameter + SplitQuantArg quant_arg_; + int n_dims_; + int split_count_; +} SplitParameter; + +typedef struct SplitWithOverlapParameter { + OpParameter op_parameter_; + int num_split_; + int split_dim_; + int ratio_[SPLIT_MAX_SLICE_NUM]; + int extend_top_[SPLIT_MAX_SLICE_NUM]; + int extend_bottom_[SPLIT_MAX_SLICE_NUM]; + + // other parameter + int element_bytes_; + int split_dim_size_; + int outer_total_dim_; + int inner_stride_; +} SplitWithOverlapParameter; + +#endif // NNACL_SPLIT_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/squeeze_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/squeeze_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..e60e46c74b4085f6c300aa2a33e9635e6bef64d5 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/squeeze_parameter.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_SQUEEZE_PARAMETER_H_ +#define NNACL_SQUEEZE_PARAMETER_H_ +#include "nnacl/op_base.h" +#include "nnacl/int8/quantize.h" + +#define SQUEEZE_OFFSET_MAX_SIZE 4 + +typedef struct SqueezeQuantArg { + QuantArg *in_quant_args_; + QuantArg *out_quant_args_; +} SqueezeQuantArg; + +typedef struct SqueezeParameter { + // primitive parameter + OpParameter op_parameter_; + int axis_[8]; + size_t axis_size_; + + // shape correlative + const int *in_shape_; + const int *out_shape_; + int offset_size_; + int64_t offset_[SQUEEZE_OFFSET_MAX_SIZE]; + int64_t in_offset_[SQUEEZE_OFFSET_MAX_SIZE]; + int input_dim_; + // other parameter + SqueezeQuantArg quant_arg; +} SqueezeParameter; + +#endif // NNACL_SQUEEZE_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/stack_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/stack_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..9ed327e55710d7af97540ad57544203b0928ce01 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/stack_parameter.h @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_STACK_PARAMETER_H_ +#define NNACL_STACK_PARAMETER_H_ + +#include "nnacl/op_base.h" +typedef struct StackParameter { + // primitive parameter + OpParameter op_parameter_; + int32_t axis_; +} StackParameter; + +#endif // NNACL_STACK_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/strided_slice_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/strided_slice_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..27b3a5d9cfa828a62341bc312668aa146d1d2c24 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/strided_slice_parameter.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_STRIDED_SLICE_PARAMETER_H_ +#define NNACL_STRIDED_SLICE_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct StridedSliceParameter { + // primitive parameter + OpParameter op_parameter_; + int begins_[MAX_SHAPE_SIZE]; + int ends_[MAX_SHAPE_SIZE]; + int strides_[MAX_SHAPE_SIZE]; + int isScale; + + // shape correlative + int in_shape_length_; + int in_shape_[MAX_SHAPE_SIZE]; + + // other parameter + int num_axes_; + TypeIdC data_type; + int begins_mask_; + int ends_mask_; + int ellipsisMask_; + int newAxisMask_; + int shrinkAxisMask_; +} StridedSliceParameter; + +#endif // NNACL_STRIDED_SLICE_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/tensor_array_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/tensor_array_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..a95384066698f32eb65c6e3b277033fa74b3a854 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/tensor_array_parameter.h @@ -0,0 +1,29 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_TENSOR_ARRAY_PARAMETER_H_ +#define NNACL_TENSOR_ARRAY_PARAMETER_H_ +#include "nnacl/op_base.h" + +typedef struct TensorArrayParameter { + OpParameter op_parameter_; + bool dynamic_size_; + bool identical_element_shapes_; + int element_shape_[MAX_SHAPE_SIZE]; + int element_shape_size_; + int data_type_; +} TensorArrayParameter; + +#endif // NNACL_TENSOR_ARRAY_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/tensor_c.h b/mindspore-lite/ops/kernel/cpu/nnacl/tensor_c.h new file mode 100644 index 0000000000000000000000000000000000000000..8f40c52d4506f999b429ef8f3b2fd42227e963c2 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/tensor_c.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_TENSOR_C_H_ +#define NNACL_TENSOR_C_H_ +#include "nnacl/op_base.h" + +typedef struct TensorC { + bool shape_changed_; + int data_type_; + int format_; + int category_; + void *data_; + size_t shape_size_; + int shape_[MAX_SHAPE_SIZE]; + char *name_; // only used in micro now. +} TensorC; + +#endif // NNACL_TENSOR_C_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/tensor_c_utils.c b/mindspore-lite/ops/kernel/cpu/nnacl/tensor_c_utils.c new file mode 100644 index 0000000000000000000000000000000000000000..01e9dfbf06c2ed9fd784c3e54b013632d904cd33 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/tensor_c_utils.c @@ -0,0 +1,439 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use tensor file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/tensor_c_utils.h" +#include "nnacl/nnacl_common.h" + +int CheckAugmentNull(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + const OpParameter *parameter) { + NNACL_CHECK_NULL_RETURN_ERR(inputs); + NNACL_CHECK_NULL_RETURN_ERR(outputs); + for (size_t i = 0; i < inputs_size; i++) { + if (inputs[i] == NULL) { + return NNACL_NULL_PTR; + } + } + for (size_t i = 0; i < outputs_size; i++) { + if (outputs[i] == NULL) { + return NNACL_NULL_PTR; + } + } + if (parameter == NULL) { + return NNACL_NULL_PTR; + } + return NNACL_OK; +} + +int CheckAugmentNullSize(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + const OpParameter *parameter, size_t inputs_size_obj, size_t outputs_size_obj) { + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret == NNACL_NULL_PTR) { + return NNACL_NULL_PTR; + } + if (inputs_size != inputs_size_obj || outputs_size != outputs_size_obj) { + return NNACL_INPUT_TENSOR_ERROR; + } + return NNACL_OK; +} + +int CheckAugmentNullSizeInputTwo(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, + size_t outputs_size, const OpParameter *parameter, size_t inputs_size_obj_0, + size_t inputs_size_obj_1, size_t outputs_size_obj) { + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret == NNACL_NULL_PTR) { + return NNACL_NULL_PTR; + } + if ((inputs_size != inputs_size_obj_0 && inputs_size != inputs_size_obj_1) || outputs_size != outputs_size_obj) { + return NNACL_INPUT_TENSOR_ERROR; + } + return NNACL_OK; +} + +int CheckAugmentNullInputSize(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + const OpParameter *parameter, size_t inputs_size_obj) { + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret == NNACL_NULL_PTR) { + return NNACL_NULL_PTR; + } + if (inputs_size != inputs_size_obj) { + return NNACL_INPUT_TENSOR_ERROR; + } + return NNACL_OK; +} + +int CheckAugmentNullOutputSize(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + const OpParameter *parameter, size_t outputs_size_obj) { + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret == NNACL_NULL_PTR) { + return NNACL_NULL_PTR; + } + if (outputs_size != outputs_size_obj) { + return NNACL_INPUT_TENSOR_ERROR; + } + return NNACL_OK; +} + +int CheckAugmentWithMinSize(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + const OpParameter *parameter, size_t inputs_size_obj, size_t outputs_size_obj) { + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret == NNACL_NULL_PTR) { + return NNACL_NULL_PTR; + } + if (inputs_size < inputs_size_obj || outputs_size < outputs_size_obj) { + return NNACL_INPUT_TENSOR_ERROR; + } + return NNACL_OK; +} + +void SetShapeTensor(TensorC *dst, const TensorC *src) { + for (size_t i = 0; i < src->shape_size_; i++) { + dst->shape_[i] = src->shape_[i]; + } + dst->shape_size_ = src->shape_size_; +} + +void SetShapeArray(TensorC *dst, const int *src, size_t src_size) { + for (size_t i = 0; i < src_size && i < MAX_SHAPE_SIZE; i++) { + dst->shape_[i] = src[i]; + } + dst->shape_size_ = src_size; +} + +void SetDataTypeFormat(TensorC *dst, const TensorC *src) { + dst->format_ = src->format_; + dst->data_type_ = src->data_type_; +} + +int NNACLGetBatch(const TensorC *tensor) { + if (tensor->shape_size_ != DIMENSION_4D && tensor->shape_size_ != DIMENSION_2D) { + return -1; + } + switch (tensor->format_) { + case Format_NHWC: + case Format_NHWC4: + case Format_NCHW: + case Format_NC4HW4: + case Format_NC8HW8: + case Format_KCHW: + case Format_KHWC: + case Format_NC: + case Format_NC4: + return tensor->shape_[kNHWC_N]; + case Format_HWCK: + case Format_CHWK: + if (tensor->shape_size_ != DIMENSION_4D) { + return -1; + } + return tensor->shape_[kHWCN_N]; + case Format_HWKC: + if (tensor->shape_size_ != DIMENSION_4D) { + return -1; + } + return tensor->shape_[kHWNC_N]; + case Format_CKHW: + return tensor->shape_[1]; + default: + return -1; + } +} +int NNACLGetHeight(const TensorC *tensor) { + if (tensor->shape_size_ != DIMENSION_4D && tensor->shape_size_ != DIMENSION_2D) { + return -1; + } + switch (tensor->format_) { + case Format_NCHW: + case Format_KCHW: + case Format_CKHW: + case Format_NC4HW4: + case Format_NC8HW8: + if (tensor->shape_size_ != DIMENSION_4D) { + return -1; + } + return tensor->shape_[kNCHW_H]; + case Format_NHWC: + case Format_NHWC4: + case Format_KHWC: + case Format_CHWK: + return tensor->shape_[kNHWC_H]; + case Format_HWCK: + case Format_HWKC: + case Format_HW: + case Format_HW4: + return tensor->shape_[0]; + default: + return -1; + } +} +int NNACLGetWidth(const TensorC *tensor) { + if (tensor->shape_size_ != DIMENSION_4D && tensor->shape_size_ != DIMENSION_2D) { + return -1; + } + switch (tensor->format_) { + case Format_NCHW: + case Format_KCHW: + case Format_CKHW: + case Format_NC4HW4: + case Format_NC8HW8: + if (tensor->shape_size_ != DIMENSION_4D) { + return -1; + } + return tensor->shape_[kNCHW_W]; + case Format_KHWC: + case Format_NHWC: + case Format_NHWC4: + case Format_CHWK: + if (tensor->shape_size_ != DIMENSION_4D) { + return -1; + } + return tensor->shape_[kNHWC_W]; + case Format_HWCK: + case Format_HWKC: + case Format_HW: + case Format_HW4: + return tensor->shape_[1]; + default: + return -1; + } +} +int NNACLGetChannel(const TensorC *tensor) { + if (tensor->shape_size_ != DIMENSION_4D && tensor->shape_size_ != DIMENSION_2D) { + return -1; + } + switch (tensor->format_) { + case Format_NCHW: + case Format_KCHW: + case Format_NC: + case Format_NC4: + case Format_NC4HW4: + case Format_NC8HW8: + return tensor->shape_[kNCHW_C]; + case Format_HWCK: + if (tensor->shape_size_ != DIMENSION_4D) { + return -1; + } + return tensor->shape_[kHWCN_C]; + case Format_HWKC: + case Format_NHWC: + case Format_NHWC4: + case Format_KHWC: + if (tensor->shape_size_ != DIMENSION_4D) { + return -1; + } + return tensor->shape_[kNHWC_C]; + case Format_CKHW: + case Format_CHWK: + return tensor->shape_[0]; + default: + return -1; + } +} + +void NNACLSetBatch(TensorC *tensor, int batch) { + if (tensor->shape_size_ != DIMENSION_4D && tensor->shape_size_ != DIMENSION_2D) { + return; + } + switch (tensor->format_) { + case Format_NHWC: + case Format_NHWC4: + case Format_NCHW: + case Format_NC4HW4: + case Format_NC8HW8: + case Format_KCHW: + case Format_KHWC: + case Format_NC: + case Format_NC4: + tensor->shape_[kNHWC_N] = batch; + return; + case Format_HWCK: + case Format_CHWK: + if (tensor->shape_size_ != DIMENSION_4D) { + return; + } + tensor->shape_[kHWCN_N] = batch; + return; + case Format_HWKC: + if (tensor->shape_size_ != DIMENSION_4D) { + return; + } + tensor->shape_[kHWNC_N] = batch; + return; + case Format_CKHW: + tensor->shape_[1] = batch; + return; + default: + return; + } +} + +void NNACLSetHeight(TensorC *tensor, int height) { + if (tensor->shape_size_ != DIMENSION_4D && tensor->shape_size_ != DIMENSION_2D) { + return; + } + switch (tensor->format_) { + case Format_NCHW: + case Format_KCHW: + case Format_CKHW: + case Format_NC4HW4: + case Format_NC8HW8: + if (tensor->shape_size_ != DIMENSION_4D) { + return; + } + tensor->shape_[kNCHW_H] = height; + return; + case Format_NHWC: + case Format_NHWC4: + case Format_KHWC: + case Format_CHWK: + tensor->shape_[kNHWC_H] = height; + return; + case Format_HWCK: + case Format_HWKC: + case Format_HW: + case Format_HW4: + tensor->shape_[0] = height; + return; + default: + return; + } +} + +void NNACLSetWidth(TensorC *tensor, int width) { + if (tensor->shape_size_ != DIMENSION_4D && tensor->shape_size_ != DIMENSION_2D) { + return; + } + switch (tensor->format_) { + case Format_NCHW: + case Format_KCHW: + case Format_CKHW: + case Format_NC4HW4: + case Format_NC8HW8: + if (tensor->shape_size_ != DIMENSION_4D) { + return; + } + tensor->shape_[kNCHW_W] = width; + return; + case Format_KHWC: + case Format_NHWC: + case Format_NHWC4: + case Format_CHWK: + if (tensor->shape_size_ != DIMENSION_4D) { + return; + } + tensor->shape_[kNHWC_W] = width; + return; + case Format_HWCK: + case Format_HWKC: + case Format_HW: + case Format_HW4: + tensor->shape_[1] = width; + return; + default: + return; + } +} + +void NNACLSetChannel(TensorC *tensor, int channel) { + if (tensor->shape_size_ != DIMENSION_4D && tensor->shape_size_ != DIMENSION_2D) { + return; + } + switch (tensor->format_) { + case Format_NCHW: + case Format_KCHW: + case Format_NC: + case Format_NC4: + case Format_NC4HW4: + case Format_NC8HW8: + tensor->shape_[kNCHW_C] = channel; + return; + case Format_HWCK: + if (tensor->shape_size_ != DIMENSION_4D) { + return; + } + tensor->shape_[kHWCN_C] = channel; + return; + case Format_HWKC: + case Format_NHWC: + case Format_NHWC4: + case Format_KHWC: + if (tensor->shape_size_ != DIMENSION_4D) { + return; + } + tensor->shape_[kNHWC_C] = channel; + return; + case Format_CKHW: + case Format_CHWK: + tensor->shape_[0] = channel; + return; + default: + return; + } +} + +int NNACLGetSize(const TensorC *tensor) { + int element_num = NNACLGetElementNum(tensor); + int data_type_size = (int)DataTypeCSize(tensor->data_type_); + return element_num * data_type_size; +} + +int NNACLGetElementNum(const TensorC *tensor) { + if (tensor == NULL) { + return -1; + } + if (tensor->shape_size_ == 0) { + return 1; // scalar mode + } + int res = 1; + for (size_t i = 0; i < tensor->shape_size_; i++) { + NNACL_CHECK_INT_MUL_NOT_OVERFLOW(res, tensor->shape_[i], NNACL_ERRCODE_MUL_OVERFLOW); + res = res * tensor->shape_[i]; + } + + int c = NNACLGetChannel(tensor); + if (c == 0) { + return res; + } + if (tensor->format_ == Format_NC4HW4) { + res = res / c * UP_ROUND(c, C4NUM); + } + if (tensor->format_ == Format_NC8HW8) { + res = res / c * UP_ROUND(c, C8NUM); + } + return res; +} + +int NNACLGetDimensionSize(const TensorC *tensor, const size_t index) { + int dim_size = -1; + if (index < tensor->shape_size_) { + dim_size = tensor->shape_[index]; + } + return dim_size; +} + +bool NNACLIsShapeSame(const TensorC *tensor1, const TensorC *tensor2) { + if (tensor1->shape_size_ != tensor2->shape_size_) { + return false; + } + for (size_t i = 0; i < tensor1->shape_size_; i++) { + if (tensor1->shape_[i] != tensor2->shape_[i]) { + return false; + } + } + return true; +} + +bool NNACLIsConst(const TensorC *tensor) { + return (tensor->category_ == ConstTensor || tensor->category_ == ConstScalar) && tensor->data_ != NULL; +} diff --git a/mindspore-lite/src/extendrt/kernel/cuda/batchtospace.h b/mindspore-lite/ops/kernel/cpu/nnacl/tensor_c_utils.h similarity index 40% rename from mindspore-lite/src/extendrt/kernel/cuda/batchtospace.h rename to mindspore-lite/ops/kernel/cpu/nnacl/tensor_c_utils.h index 8c2fa4e622f4e8a26274a62cd46ebc2d2f91435e..2087f68764b2cbc77c237aa75a41902928cf0e5e 100644 --- a/mindspore-lite/src/extendrt/kernel/cuda/batchtospace.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/tensor_c_utils.h @@ -14,26 +14,34 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_CUDA_BATCHTOSPACE_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_CUDA_BATCHTOSPACE_H_ +#ifndef NNACL_TENSORC_UTILS_H_ +#define NNACL_TENSORC_UTILS_H_ -#include -#include -#include "src/extendrt/kernel/cuda/cuda_kernel.h" -#include "cuda_impl/cuda_class/batchtospace_helper.h" +#include +#include "nnacl/errorcode.h" +#include "nnacl/op_base.h" +#include "nnacl/tensor_c.h" -namespace mindspore::kernel { -class BatchtoSpaceCudaKernel : public CudaKernel { - public: - BatchtoSpaceCudaKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs, const lite::InnerContext *ctx) - : CudaKernel(parameter, inputs, outputs, ctx) {} - ~BatchtoSpaceCudaKernel() override = default; - int Prepare() override; - int Run() override; +#ifdef __cplusplus +extern "C" { +#endif + +int NNACLGetBatch(const TensorC *tensor); +int NNACLGetHeight(const TensorC *tensor); +int NNACLGetWidth(const TensorC *tensor); +int NNACLGetChannel(const TensorC *tensor); +void NNACLSetBatch(TensorC *tensor, int batch); +void NNACLSetHeight(TensorC *tensor, int height); +void NNACLSetWidth(TensorC *tensor, int width); +void NNACLSetChannel(TensorC *tensor, int channel); +int NNACLGetElementNum(const TensorC *tensor); +int NNACLGetSize(const TensorC *tensor); +int NNACLGetDimensionSize(const TensorC *tensor, const size_t index); +bool NNACLIsShapeSame(const TensorC *tensor1, const TensorC *tensor2); +bool NNACLIsConst(const TensorC *tensor); - private: - std::shared_ptr> batch_to_space_helper_{nullptr}; -}; -} // namespace mindspore::kernel +#ifdef __cplusplus +} #endif + +#endif // NNACL_TENSORC_UTILS_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/tensorlist_c.h b/mindspore-lite/ops/kernel/cpu/nnacl/tensorlist_c.h new file mode 100644 index 0000000000000000000000000000000000000000..1b3c83d8ad5309021f4bc9856a24dfa74a65fb0d --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/tensorlist_c.h @@ -0,0 +1,41 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_TENSORLIST_C_H_ +#define NNACL_TENSORLIST_C_H_ + +#include "nnacl/tensor_c.h" + +typedef struct vvector { + int **shape_; // value of shapes + int *shape_size_; // size of shape + size_t size_; // number of shapes +} vvector; + +typedef struct TensorListC { + bool shape_changed_; + int data_type_; + int format_; + int shape_value_; + int tensors_data_type_; // element_data_type_, keep same as c++ + int max_elements_num_; + TensorC **tensors_; + size_t element_num_; + size_t element_shape_size_; + int element_shape_[MAX_SHAPE_SIZE]; +} TensorListC; + +#endif // NNACL_TENSORLIST_C_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/tensorlist_c_utils.c b/mindspore-lite/ops/kernel/cpu/nnacl/tensorlist_c_utils.c new file mode 100644 index 0000000000000000000000000000000000000000..f3e865d1f50c4881e8fdfb002614fa01bde76c74 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/tensorlist_c_utils.c @@ -0,0 +1,82 @@ +/** + * Copyright 2022-2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use tensor file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/tensorlist_c_utils.h" + +int MallocTensorListData(TensorListC *tensor_list, TypeIdC dtype, const vvector *tensor_shape) { + // This function will create a new tensors_ + // Your must to set shape(param2: tensor_shape) and data_type_(tensors_data_type_ = param1: dtype) of each tensor in + // tensors_. After that, you need to call function:MallocData to malloc data buf of each tensor in tensors_. + + if (tensor_list->element_num_ == 0) { + return NNACL_OK; + } + if (((size_t)(tensor_list->element_num_)) != tensor_shape->size_) { + return NNACL_ERR; + } + tensor_list->tensors_data_type_ = dtype; + void *addr = malloc(tensor_list->element_num_ * sizeof(void *) + + tensor_list->element_num_ * sizeof(TensorC)); // free in infer_manager + if (addr == NULL) { + free(tensor_list->tensors_); + return NNACL_NULL_PTR; + } + memset(addr, 0, tensor_list->element_num_ * sizeof(void *) + tensor_list->element_num_ * sizeof(TensorC)); + tensor_list->tensors_ = (TensorC **)addr; + TensorC *tensors = (TensorC *)(tensor_list->tensors_ + tensor_list->element_num_); + for (size_t i = 0; i < tensor_list->element_num_; ++i) { + TensorC *tensor = tensors + i; + tensor_list->tensors_[i] = tensor; + tensor->format_ = Format_NHWC; + tensor->data_type_ = dtype; + ShapeSet(tensor->shape_, &(tensor->shape_size_), tensor_shape->shape_[i], (size_t)tensor_shape->shape_size_[i]); + } + return NNACL_OK; +} + +int TensorListMergeShape(int *element_shape, size_t *element_shape_size, const int *tmp, size_t tmp_size) { + if (*element_shape_size >= 255 || element_shape[0] == -1) { + ShapeSet(element_shape, element_shape_size, tmp, tmp_size); + return NNACL_OK; + } + if (*element_shape_size != tmp_size) { + return NNACL_ERR; + } + for (size_t j = 0; j < tmp_size; ++j) { + if (element_shape[j] >= 0 && tmp[j] >= 0 && element_shape[j] != tmp[j]) { + return NNACL_ERR; + } + element_shape[j] = element_shape[j] >= 0 ? element_shape[j] : tmp[j]; + } + return NNACL_OK; +} + +bool TensorListIsFullyDefined(const int *shape, size_t shape_size) { + for (size_t i = 0; i < shape_size; ++i) { + if (shape[i] < 0) { + return false; + } + } + return true; +} + +bool InferFlagTensorList(TensorC *tensorc) { + TensorListC *input_tensor_list = (TensorListC *)tensorc; + if (input_tensor_list->shape_value_ == -1) { + return false; + } + return true; +} diff --git a/mindspore-lite/tools/graph_kernel/converter/kernel_builder.h b/mindspore-lite/ops/kernel/cpu/nnacl/tensorlist_c_utils.h similarity index 49% rename from mindspore-lite/tools/graph_kernel/converter/kernel_builder.h rename to mindspore-lite/ops/kernel/cpu/nnacl/tensorlist_c_utils.h index 7845c8151fcb72cb9ba56934f81ce79b77aac708..54a2acdb0fc458570bc4ee53cacff361b782a378 100644 --- a/mindspore-lite/tools/graph_kernel/converter/kernel_builder.h +++ b/mindspore-lite/ops/kernel/cpu/nnacl/tensorlist_c_utils.h @@ -14,18 +14,25 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_KERNEL_BUILDER_H_ -#define MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_KERNEL_BUILDER_H_ -#include "ir/anf.h" -#include "ir/func_graph.h" -#include "include/backend/optimizer/pass.h" +#ifndef NNACL_TENSORLIST_C_UTILS_H_ +#define NNACL_TENSORLIST_C_UTILS_H_ -namespace mindspore::graphkernel { -class KernelBuilder : public opt::Pass { - public: - KernelBuilder() : Pass("akg_kernel_builder") {} - ~KernelBuilder() override = default; - bool Run(const FuncGraphPtr &func_graph) override; -}; -} // namespace mindspore::graphkernel -#endif // MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_KERNEL_BUILDER_H_ +#include +#include "nnacl/op_base.h" +#include "nnacl/tensorlist_c.h" +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int MallocTensorListData(TensorListC *tensor_list, TypeIdC dtype, const vvector *tensor_shape); +int TensorListMergeShape(int *element_shape, size_t *element_shape_size, const int *tmp, size_t tmp_size); +bool TensorListIsFullyDefined(const int *shape, size_t shape_size); +bool InferFlagTensorList(TensorC *tensor_list); + +#ifdef __cplusplus +} +#endif + +#endif // NNACL_TENSORLIST_C_UTILS_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/tensorlist_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/tensorlist_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..d9273793b6c234016cabfa7025d92e9526b00004 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/tensorlist_parameter.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_TENSORLIST_PARAMETER_H_ +#define NNACL_TENSORLIST_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct TensorListParameter { + // primitive parameter + OpParameter op_parameter_; + int shape_type_; + int element_dtype_; + + // other parameter + int num_element_; +} TensorListParameter; + +#endif // NNACL_ARG_TENSORLIST_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/tile_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/tile_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..a82118e500c4552626e320167558e1ff2eb3ea52 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/tile_parameter.h @@ -0,0 +1,28 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_TILE_PARAMETER_H_ +#define NNACL_TILE_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct TileParameter { + OpParameter op_parameter_; + size_t dims_size_; + int dims_[MAX_SHAPE_SIZE]; + int multiples_[MAX_SHAPE_SIZE]; +} TileParameter; + +#endif // NNACL_TILE_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/transpose_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/transpose_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..14420437c7748f9225f9dba793434fa877c5068f --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/transpose_parameter.h @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_TRANSPOSE_PARAMETER_H_ +#define NNACL_TRANSPOSE_PARAMETER_H_ + +#include "nnacl/op_base.h" + +// MAX_TRANSPOSE_SERIAL_SIZE = 64 * 3 * 512 * 512 +#define MAX_TRANSPOSE_SERIAL_SIZE 50331648 +#define MAX_TRANSPOSE_DIM_SIZE 20 +#define PERM_NUM_THREE 3 +#define PERM_NUM_FOUR 4 + +typedef struct TransposeParameter { + // primitive parameter + OpParameter op_parameter_; + int perm_[MAX_TRANSPOSE_DIM_SIZE]; + size_t perm_size_; + bool conjugate_; + + // shape correlative + int strides_[MAX_TRANSPOSE_DIM_SIZE]; + int out_strides_[MAX_TRANSPOSE_DIM_SIZE]; + + // other parameter + int num_axes_; + int data_num_; +} TransposeParameter; + +#endif // NNACL_TRANSPOSE_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/triu_tril_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/triu_tril_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..8267395b6323e6b0545e54ba513516d76df5f677 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/triu_tril_parameter.h @@ -0,0 +1,31 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_TRIU_TRIL_PARAMETER_H_ +#define NNACL_TRIU_TRIL_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct TriuParameter { + // Primitive parameter + OpParameter op_parameter_; +} TriuParameter; + +typedef struct TrilParameter { + // Primitive parameter + OpParameter op_parameter_; +} TrilParameter; + +#endif // NNACL_TRIU_TRIL_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/unsqueeze_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/unsqueeze_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..66a8fca69543115b404575c503e28e8f410ff7af --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/unsqueeze_parameter.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_UNSQUEEZE_PARAMETER_H_ +#define NNACL_UNSQUEEZE_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct UnSqueezeQuantArg { + int *output_shape_; + float alpha; + int axis_; + size_t input_num_; + QuantArg in_quant_args_; + QuantArg out_quant_args_; +} UnSqueezeQuantArg; + +typedef struct UnSqueezeParameter { + // primitive parameter + OpParameter op_parameter_; + int dims_[COMM_SHAPE_SIZE]; + int num_dim_; + + // shape correlative + const int *in_shape_; + const int *out_shape_; + int64_t offset_[COMM_SHAPE_SIZE]; + int64_t axis_; + + // other parameter + UnSqueezeQuantArg quant_arg; + int thread_count_; +} UnSqueezeParameter; + +#endif // NNACL_UNSQUEEZE_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/unstack_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/unstack_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..669f938fa45340ea52e864097277477fb86918d2 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/unstack_parameter.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_UNSTACK_PARAMETER_H_ +#define NNACL_UNSTACK_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct UnstackParameter { + // primitive parameter + OpParameter op_parameter_; + int num_; + int axis_; + + // other parameter + int pre_dims_; + int axis_dim_; + int after_dims_; +} UnstackParameter; + +#endif // NNACL_UNSTACK_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/upsample_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/upsample_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..0b74da673fe807b64a9f6ec58861fb81501c97c4 --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/upsample_parameter.h @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef NNACL_UPSAMPLE_PARAMETER_H_ +#define NNACL_UPSAMPLE_PARAMETER_H_ + +#include "nnacl/op_base.h" +typedef struct { + // primitive parameter + OpParameter op_parameter_; + + // other parameter + int method_; // 0 for bilinear; 1 for nearest +} UpsampleParameter; + +#endif // NNACL_UPSAMPLE_PARAMETER_H_ diff --git a/mindspore-lite/ops/kernel/cpu/nnacl/where_parameter.h b/mindspore-lite/ops/kernel/cpu/nnacl/where_parameter.h new file mode 100644 index 0000000000000000000000000000000000000000..149a7d2a09f561008cb191e4b3184a9d46636bda --- /dev/null +++ b/mindspore-lite/ops/kernel/cpu/nnacl/where_parameter.h @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef NNACL_WHERE_PARAMETER_H_ +#define NNACL_WHERE_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct WhereParameter { + OpParameter op_parameter_; +} WhereParameter; + +#endif // NNACL_WHERE_PARAMETER_H_ diff --git a/mindspore-lite/providers/siteai/CMakeLists.txt b/mindspore-lite/providers/siteai/CMakeLists.txt index dccb9942a6eaa4f05c78935da7b2f6d1fb55a301..9730ebba765161cb804ed0e4b3650e3c5868ff23 100644 --- a/mindspore-lite/providers/siteai/CMakeLists.txt +++ b/mindspore-lite/providers/siteai/CMakeLists.txt @@ -11,5 +11,7 @@ set(MSLITE_DEPS_CMSIS off CACHE INTERNAL "setting MSLITE_DEPS_CMSIS value") ##enable prune simplest cloud inferenc set(MSLITE_SIMPLEST_CLOUD_INFERENCE on CACHE INTERNAL "setting MSLITE_SIMPLEST_CLOUD_INFERENCE value") -set(MSLITE_MINDDATA_IMPLEMENT "off" CACHE INTERNAL "setting MSLITE_MINDDATA_IMPLEMENT value") +if(NOT (MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)) + set(MSLITE_MINDDATA_IMPLEMENT "off" CACHE INTERNAL "setting MSLITE_MINDDATA_IMPLEMENT value") +endif() set(MSLITE_ENABLE_OPENCV on CACHE INTERNAL "setting MSLITE_ENABLE_OPENCV value") \ No newline at end of file diff --git a/mindspore-lite/python/CMakeLists.txt b/mindspore-lite/python/CMakeLists.txt index 14747406db2807d228fc9ec2e9086994b18abb73..50e3cd1208153c05041e670a1fccf798b2d970f9 100644 --- a/mindspore-lite/python/CMakeLists.txt +++ b/mindspore-lite/python/CMakeLists.txt @@ -1,7 +1,6 @@ cmake_minimum_required(VERSION 3.12) project(MindSpore_Lite_Python_API) -# set(CMAKE_VERBOSE_MAKEFILE on) set(PYBIND11_CPP_STANDARD -std=c++17) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-function -Wl,-rpath,$ORIGIN/") @@ -36,7 +35,7 @@ if(Python3_FOUND) pybind11_add_module(_c_lite_wrapper NO_EXTRAS ${PY_SRC_LIST}) if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE) - target_link_libraries(_c_lite_wrapper PRIVATE -Wl,--whole-archive mindspore-extendrt + target_link_libraries(_c_lite_wrapper PRIVATE -Wl,--whole-archive mindspore-extendrt lite_src_common_mid -Wl,--no-whole-archive -Wl,-z,relro,-z,now,-z,noexecstack -fstack-protector-all -s) endif() if(MSLITE_ENABLE_CONVERTER OR MSLITE_ENABLE_RUNTIME_GLOG) diff --git a/mindspore-lite/python/api/__init__.py b/mindspore-lite/python/api/__init__.py index e53d9a766a6c1373c44c43f277b5020a6d705577..a45138b1ec67cfe7d6789744abf0e2a342ea846d 100644 --- a/mindspore-lite/python/api/__init__.py +++ b/mindspore-lite/python/api/__init__.py @@ -27,9 +27,6 @@ from mindspore_lite.converter import FmkType, Converter from mindspore_lite.model import ModelType, Model, ModelParallelRunner, ModelGroup, ModelGroupFlag from mindspore_lite.tensor import DataType, Format, Tensor from mindspore_lite.lite_split import split_network, split_ir -from mindspore_lite.llm_engine import LLMReq, LLMEngineStatus, LLMRole, LLMEngine, LLMClusterInfo, LLMStatusCode -from mindspore_lite.llm_engine import LLMException, LLMKVCacheNotExist, LLMWaitProcessTimeOut, LLMRepeatRequest -from mindspore_lite.llm_engine import LLMRequestAlreadyCompleted, LLMEngineFinalized, LLMParamInvalid, LLMNotYetLink from mindspore_lite import lite_infer @@ -84,4 +81,3 @@ __all__.extend(converter.__all__) __all__.extend(model.__all__) __all__.extend(tensor.__all__) __all__.extend(lite_split.__all__) -__all__.extend(llm_engine.__all__) diff --git a/mindspore-lite/python/api/context.py b/mindspore-lite/python/api/context.py index 144c389736c199435e5b41536ff2fce07a1c2267..ffc33adbed68c03ef3397bc3f4e05705bae5ec25 100644 --- a/mindspore-lite/python/api/context.py +++ b/mindspore-lite/python/api/context.py @@ -95,7 +95,6 @@ class Context: >>> print(context) target: ['cpu']. >>> # testcase 2 about context's attribute parallel based on server inference package - >>> # (export MSLITE_ENABLE_SERVER_INFERENCE=on before compile lite or use cloud inference package) >>> import mindspore_lite as mslite >>> context = mslite.Context() >>> context.target = ["cpu"] @@ -774,7 +773,7 @@ class _Parallel: self._runner_config = _c_lite_wrapper.RunnerConfigBind() else: raise RuntimeError(f"parallel init failed, If you want to set parallel, you need to build" - f"MindSpore Lite serving package by export MSLITE_ENABLE_SERVER_INFERENCE=on.") + f"MindSpore Lite serving package by export MSLITE_ENABLE_CLOUD_INFERENCE=on.") if context is not None: self._runner_config.set_context(context._inner_context) diff --git a/mindspore-lite/python/api/llm_engine.py b/mindspore-lite/python/api/llm_engine.py deleted file mode 100644 index f2a0e695f98f36ea0ccaa1284f7d97807948541e..0000000000000000000000000000000000000000 --- a/mindspore-lite/python/api/llm_engine.py +++ /dev/null @@ -1,1027 +0,0 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -LLMEngin interface -""" -__all__ = ['LLMReq', 'LLMEngineStatus', 'LLMRole', 'LLMEngine', 'LLMStatusCode', 'LLMException'] - -import os -import sys -import threading -from enum import Enum -from typing import Union, List, Tuple, Dict -from mindspore_lite._checkparam import check_isinstance, check_uint32_number_range, check_uint64_number_range -from mindspore_lite.tensor import Tensor -from mindspore_lite.lib._c_lite_wrapper import LLMEngine_, LLMReq_, LLMRole_, StatusCode, LLMClusterInfo_ -from mindspore_lite.model import set_env - - -class LLMReq: - """ - LLMEngine request, used to represent a multi round inference task. - """ - - def __init__(self, prompt_cluster_id: int, req_id: int, prompt_length: int): - check_uint64_number_range("prompt_cluster_id", prompt_cluster_id) - check_uint64_number_range("req_id", req_id) - check_uint64_number_range("prompt_length", prompt_length) - self.llm_request_ = LLMReq_() - self.llm_request_.prompt_cluster_id = prompt_cluster_id - self.llm_request_.req_id = req_id - self.llm_request_.prompt_length = prompt_length - - _llm_req_id = 0 - _llm_req_id_lock = threading.Lock() - - @staticmethod - def next_req_id(): - with LLMReq._llm_req_id_lock: - new_req_id = LLMReq._llm_req_id - LLMReq._llm_req_id += 1 - return new_req_id - - @property - def req_id(self): - """Get request id of this inference task""" - return self.llm_request_.req_id - - @req_id.setter - def req_id(self, req_id: int): - """Set request id of this inference task""" - check_uint64_number_range("req_id", req_id) - self.llm_request_.req_id = req_id - - @property - def prompt_length(self): - """Set prompt length of this inference task""" - return self.llm_request_.prompt_length - - @prompt_length.setter - def prompt_length(self, prompt_length: int): - """Get prompt length of this inference task""" - check_uint64_number_range("prompt_length", prompt_length) - self.llm_request_.prompt_length = prompt_length - - @property - def prompt_cluster_id(self): - """Get prompt cluster id of this inference task in LLMEngine""" - return self.llm_request_.prompt_cluster_id - - @prompt_cluster_id.setter - def prompt_cluster_id(self, prompt_cluster_id: int): - """Set prompt cluster id of this inference task in LLMEngine""" - check_uint64_number_range("prompt_cluster_id", prompt_cluster_id) - self.llm_request_.prompt_cluster_id = prompt_cluster_id - - @property - def decoder_cluster_id(self): - """Get decoder cluster id of this inference task in LLMEngine""" - return self.llm_request_.decoder_cluster_id - - @decoder_cluster_id.setter - def decoder_cluster_id(self, decoder_cluster_id: int): - """Set decoder cluster id of this inference task in LLMEngine""" - check_uint64_number_range("decoder_cluster_id", decoder_cluster_id) - self.llm_request_.decoder_cluster_id = decoder_cluster_id - - @property - def prefix_id(self): - """Get decoder prefix id of this inference task in LLMEngine""" - return self.llm_request_.prefix_id - - @prefix_id.setter - def prefix_id(self, prefix_id: int): - """Set decoder prefix id of this inference task in LLMEngine""" - check_uint64_number_range("prefix_id", prefix_id) - self.llm_request_.prefix_id = prefix_id - - @property - def sequence_length(self): - """Get decoder sequence length of this inference task in LLMEngine""" - return self.llm_request_.sequence_length - - @sequence_length.setter - def sequence_length(self, sequence_length: int): - """Set decoder sequence length of this inference task in LLMEngine""" - check_uint64_number_range("sequence_length", sequence_length) - self.llm_request_.sequence_length = sequence_length - - -class LLMEngineStatus: - """ - LLMEngine Status, which can be got from LLEngine.fetch_status. - """ - - def __init__(self, status): - self.status_ = status - - @property - def empty_max_prompt_kv(self): - """Get empty count of prompt KV cache of this LLMEngine object""" - return self.status_.empty_max_prompt_kv - - @property - def num_free_blocks(self): - """Get number of free blocks PagedAttention""" - return self.status_.num_free_blocks - - @property - def num_total_blocks(self): - """Get number of total blocks PagedAttention""" - return self.status_.num_total_blocks - - @property - def block_size(self): - """Get block size of PagedAttention""" - return self.status_.block_size - - -class LLMStatusCode(Enum): - """ - LLM Error Code - """ - LLM_SUCCESS = StatusCode.kSuccess - LLM_WAIT_PROC_TIMEOUT = StatusCode.kLiteLLMWaitProcessTimeOut - LLM_KV_CACHE_NOT_EXIST = StatusCode.kLiteLLMKVCacheNotExist - LLM_REPEAT_REQUEST = StatusCode.kLiteLLMRepeatRequest - LLM_REQUEST_ALREADY_COMPLETED = StatusCode.kLiteLLMRequestAlreadyCompleted - LLM_PARAM_INVALID = StatusCode.kLiteParamInvalid - LLM_ENGINE_FINALIZED = StatusCode.kLiteLLMEngineFinalized - LLM_NOT_YET_LINK = StatusCode.kLiteLLMNotYetLink - LLM_ALREADY_LINK = StatusCode.kLiteLLMAlreadyLink - LLM_LINK_FAILED = StatusCode.kLiteLLMLinkFailed - LLM_UNLINK_FAILED = StatusCode.kLiteLLMUnlinkFailed - LLM_NOTIFY_PROMPT_UNLINK_FAILED = StatusCode.kLiteLLMNofiryPromptUnlinkFailed - LLM_CLUSTER_NUM_EXCEED_LIMIT = StatusCode.kLiteLLMClusterNumExceedLimit - LLM_PROCESSING_LINK = StatusCode.kLiteLLMProcessingLink - LLM_DEVICE_OUT_OF_MEMORY = StatusCode.kLiteLLMOutOfMemory - LLM_PREFIX_ALREADY_EXIST = StatusCode.kLiteLLMPrefixAlreadyExist - LLM_PREFIX_NOT_EXIST = StatusCode.kLiteLLMPrefixNotExist - LLM_SEQ_LEN_OVER_LIMIT = StatusCode.kLiteLLMSeqLenOverLimit - LLM_NO_FREE_BLOCK = StatusCode.kLiteLLMNoFreeBlock - LLM_BLOCKS_OUT_OF_MEMORY = StatusCode.kLiteLLMBlockOutOfMemory - - -class LLMRole(Enum): - """ - Role of LLMEngine. When LLMEngine accelerates inference performance through KVCache, the generation process includes - one full inference and n incremental inference, involving both full and incremental models. When the full and - incremental models are deployed on different nodes, the role of the node where the full models are located is - ``Prompt``, and the role of the node where the incremental models are located is ``Decoder``. - """ - Prompt = 0 - Decoder = 1 - - -class LLMException(RuntimeError): - """ - Base Error class for LLM - """ - def __init__(self, *args: object): - super().__init__(*args) - self._status_code = LLMStatusCode.LLM_SUCCESS - - @property - def statusCode(self): - """ - LLMException status code property - """ - return self._status_code - - def StatusCode(self): - """ - get LLMException status code - """ - return self._status_code - - - -class LLMKVCacheNotExist(LLMException): - """ - Key & Value cache does not exist in Prompt cluster specified by parameter LLMReq.prompt_cluster_id, and the - LLM request may have been released in Prompt cluster by calling method LLMEngine.complete_request. - Raised in LLMEngine.predict. - """ - def __init__(self, *args: object): - super().__init__(*args) - self._status_code = LLMStatusCode.LLM_KV_CACHE_NOT_EXIST - - - -class LLMWaitProcessTimeOut(LLMException): - """ - Request waiting for processing timed out. Raised in LLMEngine.predict. - """ - def __init__(self, *args: object): - super().__init__(*args) - self._status_code = LLMStatusCode.LLM_WAIT_PROC_TIMEOUT - - -class LLMRepeatRequest(LLMException): - """ - Request repeated . Raised in LLMEngine.predict. - """ - def __init__(self, *args: object): - super().__init__(*args) - self._status_code = LLMStatusCode.LLM_REPEAT_REQUEST - - -class LLMRequestAlreadyCompleted(LLMException): - """ - Request has already completed. Raised in LLMEngine.predict. - """ - def __init__(self, *args: object): - super().__init__(*args) - self._status_code = LLMStatusCode.LLM_REQUEST_ALREADY_COMPLETED - - -class LLMEngineFinalized(LLMException): - """ - LLMEngine has finalized. Raised in LLMEngine.predict. - """ - def __init__(self, *args: object): - super().__init__(*args) - self._status_code = LLMStatusCode.LLM_ENGINE_FINALIZED - - -class LLMParamInvalid(LLMException): - """ - Parameters invalid. Raised in LLMEngine.predict. - """ - def __init__(self, *args: object): - super().__init__(*args) - self._status_code = LLMStatusCode.LLM_PARAM_INVALID - - - -class LLMNotYetLink(LLMException): - """ - Decoder cluster has no link with prompt. - """ - def __init__(self, *args: object): - super().__init__(*args) - self._status_code = LLMStatusCode.LLM_NOT_YET_LINK - - - -class LLMOutOfMemory(LLMException): - """ - Device out of memory. - """ - def __init__(self, *args: object): - super().__init__(*args) - self._status_code = LLMStatusCode.LLM_DEVICE_OUT_OF_MEMORY - - - -class LLMPrefixAlreadyExist(LLMException): - """ - Prefix has already existed. - """ - def __init__(self, *args: object): - super().__init__(*args) - self._status_code = LLMStatusCode.LLM_PREFIX_ALREADY_EXIST - - - -class LLMPrefixNotExist(LLMException): - """ - Prefix does not exist. - """ - def __init__(self, *args: object): - super().__init__(*args) - self._status_code = LLMStatusCode.LLM_PREFIX_NOT_EXIST - - - -class LLMSeqLenOverLimit(LLMException): - """ - Sequence length exceed limit. - """ - def __init__(self, *args: object): - super().__init__(*args) - self._status_code = LLMStatusCode.LLM_SEQ_LEN_OVER_LIMIT - - - -class LLMNoFreeBlocks(LLMException): - """ - No free block. - """ - def __init__(self, *args: object): - super().__init__(*args) - self._status_code = LLMStatusCode.LLM_NO_FREE_BLOCK - - - -class LLMBlockOutOfMemory(LLMException): - """ - Block is out of memory. - """ - def __init__(self, *args: object): - super().__init__(*args) - self._status_code = LLMStatusCode.LLM_BLOCKS_OUT_OF_MEMORY - - - -class LLMClusterInfo: - """ - The `LLMClusterInfo` class defines a MindSpore Lite's LLMEngine cluster, used to link and unlink clusters. - - Args: - remote_role (LLMRole): Role of remote LLMEngine object. - remote_cluster_id (int): Cluster id of remote LLMEngine object. - - Raises: - TypeError: `remote_role` is not a LLMRole. - TypeError: `remote_cluster_id` is not an int. - - Examples: - >>> import mindspore_lite as mslite - >>> remote_cluster_id = 1 - >>> cluster = mslite.LLMClusterInfo(mslite.LLMRole.Prompt, remote_cluster_id) - >>> cluster.append_local_ip_info(("192.168.1.1", 2222)) - >>> cluster.append_remote_ip_info(("192.168.2.1", 2222)) - >>> local_cluster_id = 0 - >>> llm_engine = mslite.LLMEngine(mslite.LLMRole.Decoder, local_cluster_id) - >>> # ... llm_engine.init - >>> llm_engine.link_clusters([cluster]) - """ - def __init__(self, remote_role: LLMRole, remote_cluster_id: int): - check_uint64_number_range("remote_cluster_id", remote_cluster_id) - check_isinstance("remote_role", remote_role, LLMRole) - self.llm_cluster_ = LLMClusterInfo_() - self.llm_cluster_.remote_cluster_id = remote_cluster_id - remote_role_type_int = 0 if remote_role == LLMRole.Prompt else 1 # 0: Prompt, 1: Decoder - self.llm_cluster_.remote_role_type = remote_role_type_int - - @property - def remote_role(self): - """Get remote role of this LLMClusterInfo object""" - remote_role_type_int = self.llm_cluster_.remote_role - return LLMRole.Prompt if remote_role_type_int == 0 else LLMRole.Decoder # 0: Prompt, 1: Decoder - - @remote_role.setter - def remote_role(self, remote_role): - """Set remote role of this LLMClusterInfo object""" - check_isinstance("remote_role", remote_role, LLMRole) - remote_role_type_int = 0 if remote_role == LLMRole.Prompt else 1 # 0: Prompt, 1: Decoder - self.llm_cluster_.remote_role_type = remote_role_type_int - - @property - def remote_cluster_id(self): - """Get remote cluster id of this LLMClusterInfo object""" - return self.llm_cluster_.remote_cluster_id - - @remote_cluster_id.setter - def remote_cluster_id(self, remote_cluster_id): - """Set remote cluster id of this LLMClusterInfo object""" - check_uint64_number_range("remote_cluster_id", remote_cluster_id) - self.llm_cluster_.remote_cluster_id = remote_cluster_id - - def append_local_ip_info(self, address): - """ - Append local ip info. - - Args: - address: ip address, in format ('xxx.xxx.xxx.xxx', xxx) or (xxx, xxx). - - Raises: - TypeError: `address` format or type is invalid. - ValueError: `address` value is invalid. - - Examples: - >>> import mindspore_lite as mslite - >>> cluster_id = 1 - >>> cluster = mslite.LLMClusterInfo(mslite.LLMRole.Prompt, 0) - >>> cluster.append_local_ip_info(("192.168.1.1", 2222)) - """ - ip, port = LLMClusterInfo._trans_address(address) - self.llm_cluster_.append_local_ip_info(ip, port) - - @property - def local_ip_infos(self): - """Get all local ip infos of this LLMClusterInfo object""" - return tuple((ip, port) for ip, port in self.llm_cluster_.get_local_ip_infos()) - - def append_remote_ip_info(self, address): - """ - Append remote ip info. - - Args: - address: ip address, in format ('xxx.xxx.xxx.xxx', xxx) or (xxx, xxx). - - Raises: - TypeError: `address` format or type is invalid. - ValueError: `address` value is invalid. - - Examples: - >>> import mindspore_lite as mslite - >>> cluster_id = 1 - >>> cluster = mslite.LLMClusterInfo(mslite.LLMRole.Prompt, 0) - >>> cluster.append_remote_ip_info(("192.168.1.1", 2222)) - """ - ip, port = LLMClusterInfo._trans_address(address) - self.llm_cluster_.append_remote_ip_info(ip, port) - - @property - def remote_ip_infos(self): - """Get all remote ip infos of this LLMClusterInfo object""" - return tuple((ip, port) for ip, port in self.llm_cluster_.get_remote_ip_infos()) - - @staticmethod - def _trans_address(address): - """Transfer address from str format 'xxx.xxx.xxx.xxx' to int""" - if not isinstance(address, tuple): - raise TypeError(f"address must be in format of ('xxx.xxx.xxx.xxx', xxx) or (xxx, xxx), but got {address}") - if len(address) != 2: - raise TypeError(f"address must be in format of ('xxx.xxx.xxx.xxx', xxx) or (xxx, xxx), but got {address}") - ip, port = address - if not isinstance(ip, (str, int)) or not isinstance(port, int): - raise TypeError(f"address must be in format of ('xxx.xxx.xxx.xxx', xxx) or (xxx, xxx), but got {address}") - if isinstance(ip, int) and (ip < 0 or ip > pow(2, 32) - 1): - raise ValueError(f"address ip should in range [0,{pow(2, 32) - 1}], but got {ip}") - if port < 0 or port > 65535: - raise ValueError(f"address port should in range [0,65535], but got {port}") - if isinstance(ip, str): - try: - if "." not in ip: # format ("[0-9]+", xxx) - ip = int(ip) - return ip, port - except ValueError: - raise ValueError( - f"address must be in format of ('xxx.xxx.xxx.xxx', xxx) or (xxx, xxx), but got {address}") - try: - import socket - ip = socket.inet_aton(ip) - ip = int.from_bytes(ip, byteorder=sys.byteorder) - except OSError: - raise ValueError( - f"address must be in format of ('xxx.xxx.xxx.xxx', xxx) or (xxx, xxx), but got {address}") - return ip, port - - -def _handle_llm_status(status, func_name, other_info): - """Handle LLM error code""" - status_code = status.StatusCode() - if status_code != StatusCode.kSuccess: - if not isinstance(other_info, str): - other_info = other_info() - error_code_map = { - StatusCode.kLiteLLMWaitProcessTimeOut: - LLMWaitProcessTimeOut(f"{func_name} failed: Waiting for processing timeout, {other_info}"), - StatusCode.kLiteLLMKVCacheNotExist: - LLMKVCacheNotExist(f"{func_name} failed: KV Cache not exist, {other_info}."), - StatusCode.kLiteLLMRepeatRequest: LLMRepeatRequest(f"{func_name} failed: Repeat request, {other_info}."), - StatusCode.kLiteLLMRequestAlreadyCompleted: - LLMRequestAlreadyCompleted(f"{func_name} failed: Request has already completed, {other_info}."), - StatusCode.kLiteLLMEngineFinalized: - LLMEngineFinalized(f"{func_name} failed: LLMEngine has finalized, {other_info}."), - StatusCode.kLiteParamInvalid: LLMParamInvalid(f"{func_name} failed: Parameters invalid, {other_info}."), - StatusCode.kLiteLLMNotYetLink: - LLMNotYetLink(f"{func_name} failed: Decoder cluster is no link with prompt, {other_info}."), - StatusCode.kLiteLLMOutOfMemory: LLMOutOfMemory(f"{func_name} failed: Device out of memory, {other_info}."), - StatusCode.kLiteLLMPrefixAlreadyExist: - LLMPrefixAlreadyExist(f"{func_name} failed: Prefix has already existed, {other_info}."), - StatusCode.kLiteLLMPrefixNotExist: - LLMPrefixNotExist(f"{func_name} failed: Prefix does not exist, {other_info}."), - StatusCode.kLiteLLMSeqLenOverLimit: - LLMSeqLenOverLimit(f"{func_name} failed: Sequence length exceed limit, {other_info}."), - StatusCode.kLiteLLMNoFreeBlock: LLMNoFreeBlocks(f"{func_name} failed: No free block, {other_info}."), - StatusCode.kLiteLLMBlockOutOfMemory: - LLMBlockOutOfMemory(f"{func_name} failed: NBlock is out of memory, {other_info}."), - } - if status_code in error_code_map: - raise error_code_map[status_code] - raise RuntimeError(f"{func_name} failed, {other_info}.") - - -def _llm_req_str(llm_req): - return "{" + f"llm_req: {llm_req.req_id}, prompt_cluster_id: {llm_req.prompt_cluster_id}, " \ - f"decoder_cluster_id: {llm_req.decoder_cluster_id}, prefix_id: {llm_req.prefix_id}, " \ - f"prompt_length: {llm_req.prompt_length}" + "}" - - -class LLMModel: - """ - The `LLMModel` class defines one model of MindSpore Lite's LLMEngine, used to schedule and execute inference - request. LLMModel object should be created from LLMEngine.add_model. - - Examples: - >>> import mindspore_lite as mslite - >>> cluster_id = 1 - >>> llm_engine = mslite.LLMEngine(mslite.LLMRole.Prompt, cluster_id) - >>> model_paths = [os.path.join(model_dir, f"device_${rank}") for rank in range(4)] - >>> options = {} - >>> llm_model = llm_engine.add_mode(model_paths, options) # return LLMModel object - >>> llm_engine.init() - >>> llm_req = mslite.LLMReq(llm_engine.cluster_id, mslite.LLMReq.next_req_id(), prompt_length=1024) - >>> inputs = [mslite.Tensor(np_input) for np_input in np_inputs] - >>> outputs = llm_model.predit(llm_req, inputs) - >>> for output in outputs: - >>> print(f"output is {output.get_data_to_numpy()}") - >>> llm_engine.complete(llm_req) - """ - def __init__(self, model_obj, batch_mode): - self.model_ = model_obj # inited by LLMEngine - self.batch_mode_ = batch_mode - self.inited_ = False - - def predict(self, llm_req: Union[LLMReq, List[LLMReq], Tuple[LLMReq]], inputs: Union[Tuple[Tensor], List[Tensor]]): - """ - Schedule and execute inference request. - - Args: - llm_req (Union[LLMReq, list[LLMReq], Tuple[LLMReq]]): Request of LLMEngine. - inputs (Union[Tuple[Tensor], List[Tensor]]): A list that includes all input Tensors in order. - - Returns: - list[Tensor], the output Tensor list of the model. - - Raises: - TypeError: `llm_req` is not a LLMReq. - TypeError: `inputs` is not a list. - TypeError: `inputs` is a list, but the elements are not Tensor. - RuntimeError: schedule and execute inference request failed. - RuntimeError: this LLMEngine object has not been inited. - LLMKVCacheNotExist: Key & Value cache does not exist in Prompt cluster specified by - `llm_req.prompt_cluster_id`, and the LLM request may have been released in Prompt cluster - by calling method LLMEngine.complete_request. - LLMWaitProcessTimeOut: Request waiting for processing timed out. - LLMRepeatRequest: Repeat request. - LLMRequestAlreadyCompleted: Request has already completed. - LLMEngineFinalized: LLMEngine has finalized. - LLMParamInvalid: Parameters invalid. - """ - if not self.inited_: - raise RuntimeError(f"LLMEngine is not inited or init failed") - if not isinstance(inputs, (tuple, list)): - raise TypeError(f"inputs must be list/tuple of Tensor, but got {type(inputs)}.") - if not isinstance(llm_req, (list, tuple, LLMReq)): - raise TypeError(f"llm_req must be instance of LLMReq or list/tuple of LLMReq, but got {type(llm_req)}.") - if self.batch_mode_ == "manual": - if not isinstance(llm_req, (list, tuple)): - raise TypeError(f"llm_req must be list/tuple of LLMReq when batch_mode is \"manual\"," - f" but got {type(llm_req)}.") - for i, item in enumerate(llm_req): - if not isinstance(item, LLMReq): - raise TypeError(f"llm_req element must be LLMReq when batch_mode is \"manual\"," - f" but got {type(item)} at index {i}.") - else: - if not isinstance(llm_req, LLMReq): - raise TypeError(f"llm_req must be LLMReq when batch_mode is \"auto\", but got {type(llm_req)}.") - - _inputs = [] - for i, element in enumerate(inputs): - if not isinstance(element, Tensor): - raise TypeError(f"inputs element must be Tensor, but got {type(element)} at index {i}.") - # pylint: disable=protected-access - _inputs.append(element._tensor) - # pylint: disable=protected-access - if self.batch_mode_ == "manual": - llm_req_list = [item.llm_request_ for item in llm_req] - outputs, status = self.model_.predict_batch(llm_req_list, _inputs) - else: - outputs, status = self.model_.predict(llm_req.llm_request_, _inputs) - - def get_info(): - if isinstance(llm_req, LLMReq): - req_infos = _llm_req_str(llm_req) - else: - req_infos = [_llm_req_str(llm) for llm in llm_req] - - input_infos = [(item.shape, item.dtype) for item in inputs] - info = f"llm_req {req_infos}, inputs {input_infos}" - return info - - _handle_llm_status(status, "predict", get_info) - if not outputs: - raise RuntimeError(f"predict failed, {get_info()}.") - predict_outputs = [Tensor(output) for output in outputs] - return predict_outputs - - def pull_kv(self, llm_req: LLMReq): - """ - For Decoder LLMEngine, fetch KVCache from Prompt LLMEngine specified by llm_req.prompt_cluster and - llm_req.req_id. - - Args: - llm_req (LLMReq): Request of LLMEngine. - - Raises: - TypeError: `llm_req` is not a LLMReq. - RuntimeError: this LLMEngine object has not been inited. - RuntimeError: Failed to pull KVCache. - LLMKVCacheNotExist: Key & Value cache does not exist in Prompt cluster specified by - `llm_req.prompt_cluster_id`, and the LLM request may have been released in Prompt cluster - by calling method LLMEngine.complete_request. - LLMParamInvalid: Parameters invalid. - """ - if not self.inited_: - raise RuntimeError(f"LLMEngine is not inited or init failed") - if self.batch_mode_ != "manual": - raise RuntimeError(f"LLMEngine.pull_kv is only support when batch_mode is \"manual\"") - check_isinstance("llm_req", llm_req, LLMReq) - # pylint: disable=protected-access - status = self.model_.pull_kv(llm_req.llm_request_) - _handle_llm_status(status, "pull_kv", "llm_req " + _llm_req_str(llm_req)) - - def merge_kv(self, llm_req: LLMReq, batch_index: int, batch_id: int = 0): - """ - For Decoder LLMEngine, merge KVCache of LLMReq specified by `llm_req.req_id` into `batch_index` slot. - Args: - llm_req (LLMReq): Request of LLMEngine. - batch_index (int): Request batch index. - batch_id (int): Request pipline index for ping pong pipeline. - - Raises: - TypeError: `llm_req` is not a LLMReq. - RuntimeError: this LLMEngine object has not been inited. - RuntimeError: Failed to merge KVCache. - LLMParamInvalid: Parameters invalid. - """ - if not self.inited_: - raise RuntimeError(f"LLMEngine is not inited or init failed") - if self.batch_mode_ != "manual": - raise RuntimeError(f"LLMEngine.merge_kv is only support when batch_mode is \"manual\"") - check_isinstance("llm_req", llm_req, LLMReq) - check_uint32_number_range("batch_index", batch_index) - check_uint32_number_range("batch_id", batch_id) - # pylint: disable=protected-access - status = self.model_.merge_kv(llm_req.llm_request_, batch_index, batch_id) - _handle_llm_status(status, "merge_kv", "llm_req " + _llm_req_str(llm_req)) - - def preload_prompt_prefix(self, llm_req: LLMReq, inputs: Union[Tuple[Tensor], List[Tensor]]): - """ - Preload prompt inference common prefix. - - Args: - llm_req (LLMReq): Request of LLMEngine. - inputs (Union[Tuple[Tensor], List[Tensor]]): A list that includes all input Tensors in order. - - Raises: - TypeError: `llm_req` is not a LLMReq. - TypeError: `inputs` is not a list. - TypeError: `inputs` is a list, but the elements are not Tensor. - RuntimeError: preload prompt prefix inference request failed. - RuntimeError: this LLMEngine object has not been inited. - LLMParamInvalid: Parameters invalid. - """ - if not self.inited_: - raise RuntimeError(f"LLMEngine is not inited or init failed") - if not isinstance(inputs, (tuple, list)): - raise TypeError(f"inputs must be list/tuple of Tensor, but got {type(inputs)}.") - check_isinstance("llm_req", llm_req, LLMReq) - _inputs = [] - for i, element in enumerate(inputs): - if not isinstance(element, Tensor): - raise TypeError(f"inputs element must be Tensor, but got {type(element)} at index {i}.") - # pylint: disable=protected-access - _inputs.append(element._tensor) - # pylint: disable=protected-access - ret = self.model_.preload_prompt_prefix(llm_req.llm_request_, _inputs) - _handle_llm_status(ret, "preload_prompt_prefix", "llm_req " + _llm_req_str(llm_req)) - - def release_prompt_prefix(self, llm_req: LLMReq): - """ - Release the memory space used by prompt inference common prefix. - - Args: - llm_req (LLMReq): Request of LLMEngine. - - Raises: - TypeError: `llm_req` is not a LLMReq. - RuntimeError: this LLMEngine object has not been inited. - LLMParamInvalid: Parameters invalid. - """ - if not self.inited_: - raise RuntimeError(f"LLMEngine is not inited or init failed") - check_isinstance("llm_req", llm_req, LLMReq) - # pylint: disable=protected-access - ret = self.model_.release_prompt_prefix(llm_req.llm_request_) - _handle_llm_status(ret, "release_prompt_prefix", "llm_req " + _llm_req_str(llm_req)) - - def get_inputs(self) -> List[Tensor]: - """ - Get inputs of this LLMModel. - - Returns: - Tuple[Tensor], the input Tensor list of the model. - - Examples: - >>> import mindspore_lite as mslite - >>> cluster_id = 1 - >>> llm_engine = mslite.LLMEngine(mslite.LLMRole.Prompt, cluster_id) - >>> model_paths = [os.path.join(model_dir, f"device_${rank}") for rank in range(4)] - >>> options = {} - >>> llm_model = llm_engine.add_mode(model_paths, options) # return LLMModel object - >>> inputs = llm_model.get_inputs() - >>> for i in range(len(inputs)): - ... print(f"Input name {inputs[i].name}, dtype {inputs[i].dtype}, shape: {inputs[i].shape}") - """ - if not self.model_: - raise RuntimeError(f"LLMModel is invalid, please return LLMModel from LLMEngine.add_model.") - inputs = [] - for _tensor in self.model_.get_inputs(): - inputs.append(Tensor(_tensor)) - return inputs - - -class LLMEngine: - """ - The `LLMEngine` class defines a MindSpore Lite's LLMEngine, used to load and manage Large Language Mode, - and schedule and execute inference request. - - Args: - role (LLMRole): Role of this LLMEngine object. - cluster_id (int): Cluster id of this LLMEngine object. - batch_mode (str): Controls whether the request batching is "auto" formed by the framework or "manual"ly - by the user. Option is "auto" or "manual", default "auto". - - Raises: - TypeError: `role` is not a LLMRole. - TypeError: `cluster_id` is not an int. - - Examples: - >>> import mindspore_lite as mslite - >>> cluster_id = 1 - >>> llm_engine = mslite.LLMEngine(mslite.LLMRole.Prompt, cluster_id) - >>> model_paths = [os.path.join(model_dir, f"device_${rank}") for rank in range(4)] - >>> options = {} - >>> llm_model = llm_engine.add_mode(model_paths, options) # return LLMModel object - >>> llm_engine.init() - >>> llm_req = mslite.LLMReq(llm_engine.cluster_id, mslite.LLMReq.next_req_id(), prompt_length=1024) - >>> inputs = [mslite.Tensor(np_input) for np_input in np_inputs] - >>> outputs = llm_model.predit(llm_req, inputs) - >>> for output in outputs: - >>> print(f"output is {output.get_data_to_numpy()}") - >>> llm_engine.complete(llm_req) - """ - - def __init__(self, role: LLMRole, cluster_id: int, batch_mode="auto"): - check_isinstance("role", role, LLMRole) - check_uint64_number_range("cluster_id", cluster_id) - check_isinstance("batch_mode", batch_mode, str) - if batch_mode != "auto" and batch_mode != "manual": - raise ValueError(f"batch_mode should be str \"auto\" or \"manual\", but got {batch_mode}") - self.role_ = role - self.cluster_id_ = cluster_id - self.batch_mode_ = batch_mode - self.models_ = [] - self.inited_ = False - role_inner = LLMRole_.Prompt if self.role == LLMRole.Prompt else LLMRole_.Decoder - self.engine_ = LLMEngine_(role_inner, self.cluster_id, self.batch_mode) - - @property - def cluster_id(self): - """Get cluster id set to this LLMEngine object""" - return self.cluster_id_ - - @property - def role(self): - """Get LLM role set to this LLMEngine object""" - return self.role_ - - @property - def batch_mode(self): - """Get batch mode of this LLMEngine object""" - return self.batch_mode_ - - def add_model(self, model_paths: Union[Tuple[str], List[str]], options: Dict[str, str], - postprocess_model_path=None) -> LLMModel: - """ - Add model to LLMEngine. - - Args: - model_paths (Union[Tuple[str], List[str]]): List or tuple of model path. - options (Dict[str, str]): Other init options of this LLMEngine object. - postprocess_model_path (Union[str, None]): Postprocess model path, default None. - - Raises: - TypeError: `model_paths` is not a list and tuple. - TypeError: `model_paths` is a list or tuple, but the elements are not str. - TypeError: `options` is not a dict. - RuntimeError: add model failed. - """ - if self.inited_: - raise RuntimeError(f"Cannot add model for LLMEngine: LLMEngine has been inited") - if not isinstance(model_paths, (list, tuple)): - raise TypeError(f"model_paths must be tuple/list of str, but got {type(model_paths)}.") - for i, model_path in enumerate(model_paths): - if not isinstance(model_path, str): - raise TypeError(f"model_paths element must be str, but got {type(model_path)} at index {i}.") - if not os.path.exists(model_path): - raise RuntimeError(f"model_paths {model_path} at index {i} does not exist!") - check_isinstance("options", options, dict) - for key, value in options.items(): - if not isinstance(key, str): - raise TypeError(f"options key must be str, but got {type(key)}.") - if not isinstance(value, str): - raise TypeError(f"options value must be str, but got {type(value)}.") - if postprocess_model_path is not None: - if not isinstance(postprocess_model_path, str): - raise TypeError( - f"postprocess_model_path must be None or str, but got {type(postprocess_model_path)}.") - if not os.path.exists(postprocess_model_path): - raise RuntimeError(f"postprocess_model_path {postprocess_model_path} does not" - f" exist!") - else: - postprocess_model_path = "" - - ret, llm_model_inner = self.engine_.add_model(model_paths, options, postprocess_model_path) - status_code = ret.StatusCode() - if status_code == StatusCode.kLiteParamInvalid: - raise LLMParamInvalid("Parameters invalid") - if not ret.IsOk(): - role_str = 'Prompt' if self.role == LLMRole.Prompt else 'Decoder' - raise RuntimeError( - f"Failed to add_model, model paths {model_paths}, options {options}, postprocess path" - f" {postprocess_model_path}, role {role_str}, cluster id {self.cluster_id}") - llm_model = LLMModel(llm_model_inner, self.batch_mode_) - self.models_.append(llm_model) - return llm_model - - @set_env - def init(self, options: Dict[str, str]): - """ - Init LLMEngine. - - Args: - options (Dict[str, str]): init options of this LLMEngine object. - - Raises: - TypeError: `options` is not a dict. - RuntimeError: init LLMEngine failed. - """ - if self.inited_: - raise RuntimeError(f"LLMEngine has been inited") - if not self.models_: - raise RuntimeError(f"At least one group of models need to be added through LLMEngine.add_model before call" - f" LLMEngine.init.") - check_isinstance("options", options, dict) - for key, value in options.items(): - if not isinstance(key, str): - raise TypeError(f"options key must be str, but got {type(key)}.") - if not isinstance(value, str): - raise TypeError(f"options value must be str, but got {type(value)}.") - ret = self.engine_.init(options) - status_code = ret.StatusCode() - if status_code == StatusCode.kLiteParamInvalid: - raise LLMParamInvalid("Parameters invalid") - if not ret.IsOk(): - role_str = 'Prompt' if self.role == LLMRole.Prompt else 'Decoder' - raise RuntimeError(f"Failed to init LLMEngine, role {role_str}, cluster id {self.cluster_id}," - f" options {options}") - self.inited_ = True - for model in self.models_: - model.inited_ = True - - def complete_request(self, llm_req: LLMReq): - """ - Complete inference request. - - Args: - llm_req (LLMReq): Request of LLMEngine. - - Raises: - TypeError: `llm_req` is not a LLMReq. - RuntimeError: this LLMEngine object has not been inited. - """ - if not self.inited_: - raise RuntimeError(f"LLMEngine is not inited or init failed") - check_isinstance("llm_req", llm_req, LLMReq) - ret = self.engine_.complete_request(llm_req.llm_request_) - _handle_llm_status(ret, "complete_request", "llm_req " + _llm_req_str(llm_req)) - - def finalize(self): - """ - Finalize LLMEngine. - """ - if not self.inited_: - print(f"LLMEngine is not inited or init failed", flush=True) - return - self.engine_.finalize() - - def fetch_status(self): - """ - Get LLMEngine status. - - Returns: - LLMEngineStatus, LLMEngine status. - - Raises: - RuntimeError: this LLMEngine object has not been inited. - """ - if not self.inited_: - raise RuntimeError(f"LLMEngine is not inited or init failed") - status = self.engine_.fetch_status() - return LLMEngineStatus(status) - - def link_clusters(self, clusters: Union[List[LLMClusterInfo], Tuple[LLMClusterInfo]], timeout=-1): - """ - Link clusters. - - Args: - clusters (Union[List[LLMClusterInfo], Tuple[LLMClusterInfo]]): clusters. - timeout (int): timeout in seconds. - - Raises: - TypeError: `clusters` is not list/tuple of LLMClusterInfo. - RuntimeError: LLMEngine is not inited or init failed. - - Returns: - Status, tuple[Status], Whether all clusters link normally, and the link status of each cluster. - - Examples: - >>> import mindspore_lite as mslite - >>> cluster_id = 1 - >>> llm_engine = mslite.LLMEngine(mslite.LLMRole.Prompt, cluster_id) - >>> model_paths = [os.path.join(model_dir, f"device_${rank}") for rank in range(4)] - >>> options = {} - >>> llm_engine.init(model_paths, options) - >>> cluster = mslite.LLMClusterInfo(mslite.LLMRole.Prompt, 0) - >>> cluster.append_local_ip_info(("192.168.1.1", 2222)) - >>> cluster.append_remote_ip_info(("192.168.2.1", 2222)) - >>> cluster2 = mslite.LLMClusterInfo(mslite.LLMRole.Prompt, 1) - >>> cluster2.append_local_ip_info(("192.168.3.1", 2222)) - >>> cluster2.append_remote_ip_info(("192.168.4.2", 2222)) - >>> ret, rets = llm_engine.link_clusters((cluster, cluster2)) - >>> if not ret.IsOk(): - >>> for ret_item in rets: - >>> if not ret_item.IsOk(): - >>> # do something - """ - if not self.inited_: - raise RuntimeError(f"LLMEngine is not inited or init failed") - if not isinstance(clusters, (tuple, list)): - raise TypeError(f"clusters must be list/tuple of LLMClusterInfo, but got {type(clusters)}.") - check_isinstance("timeout", timeout, int) - for i, element in enumerate(clusters): - if not isinstance(element, LLMClusterInfo): - raise TypeError(f"clusters element must be LLMClusterInfo, but got {type(element)} at index {i}.") - clusters_inners = [item.llm_cluster_ for item in clusters] - ret, rets = self.engine_.link_clusters(clusters_inners, timeout) - if not rets: - _handle_llm_status(ret, "link_clusters", "") - return ret, rets - - def unlink_clusters(self, clusters: Union[List[LLMClusterInfo], Tuple[LLMClusterInfo]], timeout=-1): - """ - Unlink clusters. - - Args: - clusters (Union[List[LLMClusterInfo], Tuple[LLMClusterInfo]]): clusters. - timeout (int): LLMEngine is not inited or init failed. - - Raises: - TypeError: `clusters` is not list/tuple of LLMClusterInfo. - RuntimeError: Some error occurred. - - Returns: - Status, tuple[Status], Whether all clusters unlink normally, and the unlink status of each cluster. - - Examples: - >>> import mindspore_lite as mslite - >>> cluster_id = 1 - >>> llm_engine = mslite.LLMEngine(mslite.LLMRole.Prompt, cluster_id) - >>> model_paths = [os.path.join(model_dir, f"device_${rank}") for rank in range(4)] - >>> options = {} - >>> llm_engine.init(model_paths, options) - >>> cluster = mslite.LLMClusterInfo(mslite.LLMRole.Prompt, 0) - >>> cluster.append_local_ip_info(("192.168.1.1", 2222)) - >>> cluster.append_remote_ip_info(("192.168.2.1", 2222)) - >>> cluster2 = mslite.LLMClusterInfo(mslite.LLMRole.Prompt, 1) - >>> cluster2.append_local_ip_info(("192.168.3.1", 2222)) - >>> cluster2.append_remote_ip_info(("192.168.4.2", 2222)) - >>> ret, rets = llm_engine.unlink_clusters((cluster, cluster2)) - >>> if not ret.IsOk(): - >>> for ret_item in rets: - >>> if not ret_item.IsOk(): - >>> # do something - """ - if not self.inited_: - raise RuntimeError(f"LLMEngine is not inited or init failed") - if not isinstance(clusters, (tuple, list)): - raise TypeError(f"clusters must be list/tuple of LLMClusterInfo, but got {type(clusters)}.") - check_isinstance("timeout", timeout, int) - for i, element in enumerate(clusters): - if not isinstance(element, LLMClusterInfo): - raise TypeError(f"clusters element must be LLMClusterInfo, but got {type(element)} at index {i}.") - clusters_inners = [item.llm_cluster_ for item in clusters] - ret, rets = self.engine_.unlink_clusters(clusters_inners, timeout) - if not rets: - _handle_llm_status(ret, "unlink_clusters", "") - raise RuntimeError(f"Failed to call unlink_clusters") - return ret, rets diff --git a/mindspore-lite/python/api/model.py b/mindspore-lite/python/api/model.py index 733ce389a1cdab35f36d7ddeda6a7499bcbe071c..d32baf0a805be765a516212338bd30aacf5b7398 100644 --- a/mindspore-lite/python/api/model.py +++ b/mindspore-lite/python/api/model.py @@ -504,7 +504,7 @@ class ModelParallelRunner: Examples: >>> # Use case: serving inference. - >>> # precondition 1: Building MindSpore Lite serving package by export MSLITE_ENABLE_SERVER_INFERENCE=on. + >>> # precondition 1: Building MindSpore Lite serving package by export MSLITE_ENABLE_CLOUD_INFERENCE=on. >>> # precondition 2: install wheel package of MindSpore Lite built by precondition 1. >>> import mindspore_lite as mslite >>> model_parallel_runner = mslite.ModelParallelRunner() @@ -517,7 +517,7 @@ class ModelParallelRunner: self._model = _c_lite_wrapper.ModelParallelRunnerBind() else: raise RuntimeError(f"ModelParallelRunner init failed, If you want to use it, you need to build" - f"MindSpore Lite serving package by export MSLITE_ENABLE_SERVER_INFERENCE=on.") + f"MindSpore Lite serving package by export MSLITE_ENABLE_CLOUD_INFERENCE=on.") self.model_path_ = "" def __str__(self): @@ -541,7 +541,7 @@ class ModelParallelRunner: Examples: >>> # Use case: serving inference. - >>> # precondition 1: Building MindSpore Lite serving package by export MSLITE_ENABLE_SERVER_INFERENCE=on. + >>> # precondition 1: Building MindSpore Lite serving package by export MSLITE_ENABLE_CLOUD_INFERENCE=on. >>> # precondition 2: install wheel package of MindSpore Lite built by precondition 1. >>> import mindspore_lite as mslite >>> context = mslite.Context() @@ -576,7 +576,7 @@ class ModelParallelRunner: Examples: >>> # Use case: serving inference. - >>> # precondition 1: Building MindSpore Lite serving package by export MSLITE_ENABLE_SERVER_INFERENCE=on. + >>> # precondition 1: Building MindSpore Lite serving package by export MSLITE_ENABLE_CLOUD_INFERENCE=on. >>> # precondition 2: install wheel package of MindSpore Lite built by precondition 1. >>> import mindspore_lite as mslite >>> context = mslite.Context() @@ -611,7 +611,7 @@ class ModelParallelRunner: Examples: >>> # Use case: serving inference. >>> # Precondition 1: Download MindSpore Lite serving package or building MindSpore Lite serving package by - >>> # export MSLITE_ENABLE_SERVER_INFERENCE=on. + >>> # export MSLITE_ENABLE_CLOUD_INFERENCE=on. >>> # Precondition 2: Install wheel package of MindSpore Lite built by precondition 1. >>> # The result can be find in the tutorial of runtime_parallel_python. >>> import time diff --git a/mindspore-lite/python/setup.py b/mindspore-lite/python/setup.py index 7f682fb61352c70c71d7c15bee3d798bea0cdc6f..eef877ab500574a73bfe083ac3252000a726da76 100644 --- a/mindspore-lite/python/setup.py +++ b/mindspore-lite/python/setup.py @@ -30,12 +30,12 @@ def _read_file(filename): return f.read() -def is_enable_akg(): - """check if enable akg""" - enable_akg = os.getenv('ENABLE_AKG') - if enable_akg is not None and re.match('[Oo][Nn]', enable_akg) is not None: - return True - return False +# def is_enable_akg(): +# """check if enable akg""" +# enable_akg = os.getenv('ENABLE_AKG') +# if enable_akg is not None and re.match('[Oo][Nn]', enable_akg) is not None: +# return True +# return False def _get_package_data(): @@ -55,10 +55,10 @@ def _get_package_data(): pkg_data.extend(custom_ops_data) if os.getenv('MSLITE_ENABLE_CLOUD_INFERENCE') == "on": pkg_data.append('lite_infer.py') - if is_enable_akg(): - akg_data = ['akg/*.so*', 'akg/*.cuh', 'akg/config/*', 'akg/composite/*', 'akg/include/*', 'akg/include/*/*', - 'akg/include/*/*/*', 'akg/include/*/*/*/*'] - pkg_data.extend(akg_data) + # if is_enable_akg(): + # akg_data = ['akg/*.so*', 'akg/*.cuh', 'akg/config/*', 'akg/composite/*', 'akg/include/*', 'akg/include/*/*', + # 'akg/include/*/*/*', 'akg/include/*/*/*/*'] + # pkg_data.extend(akg_data) return pkg_data diff --git a/mindspore-lite/python/src/llm_engine_pybind.cc b/mindspore-lite/python/src/llm_engine_pybind.cc deleted file mode 100644 index bef25284f6c67edf90cca23ddd3e0e50e3797e71..0000000000000000000000000000000000000000 --- a/mindspore-lite/python/src/llm_engine_pybind.cc +++ /dev/null @@ -1,217 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "include/api/types.h" -#include "pybind11/pybind11.h" -#include "pybind11/stl.h" -#include "extendrt/cxx_api/llm_engine/llm_engine.h" -#include "src/common/log_adapter.h" -#include "python/src/common_pybind.h" - -namespace mindspore::lite { -namespace py = pybind11; - -std::pair, Status> PyLLMModelPredict(LLMModel *llm_model, const LLMReq &req, - const std::vector &inputs_ptr) { - if (llm_model == nullptr) { - MS_LOG(ERROR) << "Model object cannot be nullptr"; - return {}; - } - std::vector inputs = MSTensorPtrToMSTensor(inputs_ptr); - std::vector outputs; - auto status = llm_model->Predict(req, inputs, &outputs); - if (!status.IsOk()) { - return {{}, status}; - } - return {MSTensorToMSTensorPtr(outputs), status}; -} - -std::pair, Status> PyLLMModelPredictBatch(LLMModel *llm_model, const std::vector &req, - const std::vector &inputs_ptr) { - if (llm_model == nullptr) { - MS_LOG(ERROR) << "Model object cannot be nullptr"; - return {}; - } - std::vector inputs = MSTensorPtrToMSTensor(inputs_ptr); - std::vector outputs; - auto status = llm_model->Predict(req, inputs, &outputs); - if (!status.IsOk()) { - return {{}, status}; - } - return {MSTensorToMSTensorPtr(outputs), status}; -} - -Status PyLLMModelPreloadPromptPrefix(LLMModel *llm_model, const LLMReq &req, - const std::vector &inputs_ptr) { - if (llm_model == nullptr) { - MS_LOG(ERROR) << "Model object cannot be nullptr"; - return kLiteError; - } - std::vector inputs = MSTensorPtrToMSTensor(inputs_ptr); - return llm_model->PreloadPromptPrefix(req, inputs); -} - -std::pair> PyLLMEngineLinkClusters(LLMEngine *llm_engine, - const std::vector &clusters, - int32_t timeout) { - if (llm_engine == nullptr) { - MS_LOG(ERROR) << "LLMEngine object cannot be nullptr"; - return {kLiteError, {}}; - } - std::vector rets; - auto ret = llm_engine->LinkClusters(clusters, &rets, timeout); - return {ret, rets}; -} - -std::pair> PyLLMEngineUnlinkClusters(LLMEngine *llm_engine, - const std::vector &clusters, - int32_t timeout) { - if (llm_engine == nullptr) { - MS_LOG(ERROR) << "LLMEngine object cannot be nullptr"; - return {kLiteError, {}}; - } - std::vector rets; - auto ret = llm_engine->UnlinkClusters(clusters, &rets, timeout); - return {ret, rets}; -} - -void PyLLMClusterAppendLocalIpInfo(LLMClusterInfo *cluster_info, uint32_t ip, uint16_t port) { - if (cluster_info == nullptr) { - MS_LOG(ERROR) << "LLMClusterInfo object cannot be nullptr"; - return; - } - LLMIpInfo ip_info; - ip_info.ip = ip; - ip_info.port = port; - cluster_info->local_ip_infos.push_back(ip_info); -} - -std::vector> PyLLMClusterGetLocalIpInfo(LLMClusterInfo *cluster_info) { - if (cluster_info == nullptr) { - MS_LOG(ERROR) << "LLMClusterInfo object cannot be nullptr"; - return {}; - } - std::vector> ip_infos; - auto &local_ip_infos = cluster_info->local_ip_infos; - (void)std::transform(local_ip_infos.begin(), local_ip_infos.end(), std::back_inserter(ip_infos), - [](auto &item) { return std::make_pair(item.ip, item.port); }); - return ip_infos; -} - -void PyLLMClusterAppendRemoteIpInfo(LLMClusterInfo *cluster_info, uint32_t ip, uint16_t port) { - if (cluster_info == nullptr) { - MS_LOG(ERROR) << "LLMClusterInfo object cannot be nullptr"; - return; - } - LLMIpInfo ip_info; - ip_info.ip = ip; - ip_info.port = port; - cluster_info->remote_ip_infos.push_back(ip_info); -} - -std::vector> PyLLMClusterGetRemoteIpInfo(LLMClusterInfo *cluster_info) { - if (cluster_info == nullptr) { - MS_LOG(ERROR) << "LLMClusterInfo object cannot be nullptr"; - return {}; - } - std::vector> ip_infos; - auto &remote_ip_infos = cluster_info->remote_ip_infos; - (void)std::transform(remote_ip_infos.begin(), remote_ip_infos.end(), std::back_inserter(ip_infos), - [](auto &item) { return std::make_pair(item.ip, item.port); }); - return ip_infos; -} - -std::pair> PyLLMEngineAddModel(LLMEngine *llm_engine, - const std::vector &model_paths, - const std::map &options, - const std::string &postprocess_model_path) { - if (llm_engine == nullptr) { - MS_LOG(ERROR) << "LLMClusterInfo object cannot be nullptr"; - return {kLiteError, nullptr}; - } - auto llm_model = std::make_shared(); - if (llm_model == nullptr) { - MS_LOG(ERROR) << "Failed to create LLMModel object"; - return {kLiteError, nullptr}; - } - auto status = llm_engine->AddModel(llm_model.get(), model_paths, options, postprocess_model_path); - return {status, llm_model}; -} - -std::vector PyLLMModelGetInputs(LLMModel *model) { - if (model == nullptr) { - MS_LOG(ERROR) << "model object cannot be nullptr"; - return {}; - } - return MSTensorToMSTensorPtr(model->GetInputs()); -} - -void LLMEnginePyBind(const py::module &m) { - (void)py::enum_(m, "LLMRole_", py::arithmetic()) - .value("Prompt", LLMRole::kLLMRolePrompt) - .value("Decoder", LLMRole::kLLMRoleDecoder); - - py::class_(m, "LLMReq_") - .def(py::init<>()) - .def_readwrite("req_id", &LLMReq::req_id) - .def_readwrite("prompt_length", &LLMReq::prompt_length) - .def_readwrite("prompt_cluster_id", &LLMReq::prompt_cluster_id) - .def_readwrite("decoder_cluster_id", &LLMReq::decoder_cluster_id) - .def_readwrite("prefix_id", &LLMReq::prefix_id) - .def_readwrite("sequence_length", &LLMReq::sequence_length); - - py::class_(m, "LLMClusterInfo_") - .def(py::init<>()) - .def_readwrite("remote_cluster_id", &LLMClusterInfo::remote_cluster_id) - .def_readwrite("remote_role_type", &LLMClusterInfo::remote_role_type) - .def("append_local_ip_info", &PyLLMClusterAppendLocalIpInfo) - .def("append_remote_ip_info", &PyLLMClusterAppendRemoteIpInfo) - .def("get_local_ip_infos", &PyLLMClusterGetLocalIpInfo) - .def("get_remote_ip_infos", &PyLLMClusterGetRemoteIpInfo); - - py::class_(m, "LLMTensorInfo_") - .def(py::init<>()) - .def_readwrite("name", &LLMTensorInfo::name) - .def_readwrite("shape", &LLMTensorInfo::shape) - .def_readwrite("dtype", &LLMTensorInfo::dtype); - - py::class_(m, "LLMEngineStatus_") - .def(py::init<>()) - .def_readwrite("empty_max_prompt_kv", &LLMEngineStatus::empty_max_prompt_kv) - .def_readwrite("num_free_blocks", &LLMEngineStatus::num_free_blocks) - .def_readwrite("num_total_blocks", &LLMEngineStatus::num_total_blocks) - .def_readwrite("block_size", &LLMEngineStatus::block_size); - - (void)py::class_>(m, "LLMModel_") - .def(py::init<>()) - .def("predict", &PyLLMModelPredict, py::call_guard()) - .def("predict_batch", &PyLLMModelPredictBatch, py::call_guard()) - .def("preload_prompt_prefix", &PyLLMModelPreloadPromptPrefix, py::call_guard()) - .def("release_prompt_prefix", &LLMModel::ReleasePromptPrefix, py::call_guard()) - .def("pull_kv", &LLMModel::PullKV, py::call_guard()) - .def("merge_kv", &LLMModel::MergeKV, py::call_guard()) - .def("get_inputs", &PyLLMModelGetInputs); - - (void)py::class_>(m, "LLMEngine_") - .def(py::init()) - .def("add_model", &PyLLMEngineAddModel, py::call_guard()) - .def("init", &LLMEngine::Init, py::call_guard()) - .def("finalize", &LLMEngine::Finalize, py::call_guard()) - .def("fetch_status", &LLMEngine::FetchStatus, py::call_guard()) - .def("link_clusters", &PyLLMEngineLinkClusters, py::call_guard()) - .def("unlink_clusters", &PyLLMEngineUnlinkClusters, py::call_guard()) - .def("complete_request", &LLMEngine::CompleteRequest, py::call_guard()); -} -} // namespace mindspore::lite diff --git a/mindspore-lite/python/src/model_pybind.cc b/mindspore-lite/python/src/model_pybind.cc index 036a56c189bf438353d45e0b15e39ce5164b308d..f7e69a8c11c318789598b8de9ca21caa3ee6bc28 100644 --- a/mindspore-lite/python/src/model_pybind.cc +++ b/mindspore-lite/python/src/model_pybind.cc @@ -227,7 +227,7 @@ void ModelPyBind(const py::module &m) { [](Model &model, const std::string &tensor_name) { return model.GetOutputByTensorName(tensor_name); }); } -#ifdef PARALLEL_INFERENCE +#ifdef ENABLE_CLOUD_INFERENCE std::vector PyModelParallelRunnerPredict(ModelParallelRunner *runner, const std::vector &inputs_ptr, const std::vector &outputs_ptr, @@ -266,7 +266,7 @@ std::vector PyModelParallelRunnerGetOutputs(ModelParallelRunner *ru #endif void ModelParallelRunnerPyBind(const py::module &m) { -#ifdef PARALLEL_INFERENCE +#ifdef MSLITE_ENABLE_CLOUD_INFERENCE (void)py::class_>(m, "RunnerConfigBind") .def(py::init<>()) .def("set_config_info", py::overload_cast &>( diff --git a/mindspore-lite/python/src/pybind_module.cc b/mindspore-lite/python/src/pybind_module.cc index 808635a7c62b2f482cf8f4c68325ea0d7ad407f4..9c8bcacafe01ad12414662993049ac4ba909501d 100644 --- a/mindspore-lite/python/src/pybind_module.cc +++ b/mindspore-lite/python/src/pybind_module.cc @@ -14,7 +14,6 @@ * limitations under the License. */ #include "include/api/types.h" -#include "src/extendrt/session/single_op_session.h" #include "pybind11/pybind11.h" #include "pybind11/numpy.h" #include "pybind11/stl.h" @@ -31,7 +30,7 @@ void LiteInferPyBind(const py::module &m); void ModelParallelRunnerPyBind(const py::module &m); void ModelGroupPyBind(const py::module &m); void TensorPyBind(const py::module &m); -void LLMEnginePyBind(const py::module &m); +// void LLMEnginePyBind(const py::module &m); std::shared_ptr create_tensor(DataType data_type, const std::vector &shape, const std::string &device_type, int device_id); std::shared_ptr create_tensor_by_tensor(const MSTensor &tensor, const std::string &device_type, @@ -51,13 +50,15 @@ PYBIND11_MODULE(_c_lite_wrapper, m) { ModelParallelRunnerPyBind(m); ModelGroupPyBind(m); TensorPyBind(m); - LLMEnginePyBind(m); + // LLMEnginePyBind(m); m.def("create_tensor", &create_tensor); m.def("create_tensor_by_tensor", &create_tensor_by_tensor); m.def("create_tensor_by_numpy", &create_tensor_by_numpy); // call aclFinalize manually before exit. (void)py::module::import("atexit").attr("register")( - py::cpp_function{[&]() -> void { SingleOpInferSession::AscendFinalize(); }}); + py::cpp_function{[&]() -> void { + return; + }}); } } // namespace mindspore::lite diff --git a/mindspore-lite/python/src/tensor_numpy_impl.h b/mindspore-lite/python/src/tensor_numpy_impl.h index ea1ba70b66c6136aecc3667a7a5a075a299f5917..f1a6b7f26e82bd82024434cf6a0b8eb30f400b35 100644 --- a/mindspore-lite/python/src/tensor_numpy_impl.h +++ b/mindspore-lite/python/src/tensor_numpy_impl.h @@ -29,7 +29,7 @@ #include "pybind11/pybind11.h" #include "pybind11/numpy.h" #ifdef ENABLE_CLOUD_INFERENCE -#include "extendrt/kernel/ascend/plugin/ascend_allocator_plugin.h" +#include "src/extendrt/delegate/ascend_acl/ascend_allocator_plugin.h" #endif namespace py = pybind11; @@ -49,7 +49,7 @@ class TensorNumpyImpl : public MutableTensorImpl { } if (device_data_ != nullptr) { MS_LOG(INFO) << "free device data in tensor numpy impl."; - kernel::AscendAllocatorPlugin::GetInstance().Free(device_data_, device_id_); + AscendAllocatorPlugin::GetInstance().Free(device_data_, device_id_); } } const std::vector &Shape() const override { return ms_shape_; } @@ -84,7 +84,7 @@ class TensorNumpyImpl : public MutableTensorImpl { void SetDeviceData(void *data) override { #ifdef ENABLE_CLOUD_INFERENCE if (device_data_ != nullptr) { - kernel::AscendAllocatorPlugin::GetInstance().Free(device_data_, device_id_); + AscendAllocatorPlugin::GetInstance().Free(device_data_, device_id_); } device_data_ = data; return; diff --git a/mindspore-lite/python/src/tensor_pybind.cc b/mindspore-lite/python/src/tensor_pybind.cc index f86672dfde9ce6075c0ce82d90318c7f7f399c40..8bf2452d99f69a7091b4cda091040199e86857cc 100644 --- a/mindspore-lite/python/src/tensor_pybind.cc +++ b/mindspore-lite/python/src/tensor_pybind.cc @@ -28,7 +28,7 @@ #include "numpy/arrayobject.h" #include "pybind11/stl.h" #ifdef ENABLE_CLOUD_INFERENCE -#include "extendrt/kernel/ascend/plugin/ascend_allocator_plugin.h" +#include "src/extendrt/delegate/ascend_acl/ascend_allocator_plugin.h" #endif namespace mindspore::lite { namespace { @@ -182,14 +182,14 @@ MSTensorPtr create_tensor_by_numpy(const py::array &input, const std::string &de } #ifdef ENABLE_CLOUD_INFERENCE if (device_type == "ascend") { - kernel::AscendAllocatorPlugin::GetInstance().Register(); - device_id = device_id == -1 ? kernel::AscendAllocatorPlugin::GetInstance().GetCurrentDeviceId() : device_id; - auto device_data = kernel::AscendAllocatorPlugin::GetInstance().Malloc(data_size, device_id); + AscendAllocatorPlugin::GetInstance().Register(); + device_id = device_id == -1 ? AscendAllocatorPlugin::GetInstance().GetCurrentDeviceId() : device_id; + auto device_data = AscendAllocatorPlugin::GetInstance().Malloc(data_size, device_id); if (device_data == nullptr) { MS_LOG(ERROR) << "Malloc device data for numpy tensor failed."; return nullptr; } - auto status = kernel::AscendAllocatorPlugin::GetInstance().CopyHostDataToDevice(numpy_tensor->MutableData(), + auto status = AscendAllocatorPlugin::GetInstance().CopyHostDataToDevice(numpy_tensor->MutableData(), device_data, data_size); if (status != kSuccess) { MS_LOG(ERROR) << "tensor has device data, then copy host data to device failed."; @@ -260,7 +260,7 @@ bool SetTensorNumpyData(const MSTensorPtr &tensor_ptr, const py::array &input) { #ifdef ENABLE_CLOUD_INFERENCE if (tensor.GetDeviceData() != nullptr) { MS_LOG(INFO) << "device tensor data ptr is not nullptr, need copy host data to device data."; - auto status = kernel::AscendAllocatorPlugin::GetInstance().CopyHostDataToDevice( + auto status = AscendAllocatorPlugin::GetInstance().CopyHostDataToDevice( py_buffer_info.ptr, tensor.GetDeviceData(), tensor.DataSize()); if (status != kSuccess) { MS_LOG(ERROR) << "tensor has device data, then copy host data to device failed."; @@ -292,7 +292,7 @@ py::buffer_info GetPyBufferInfo(const MSTensorPtr &tensor) { if (device_data != nullptr) { MS_LOG(INFO) << "need copy host data to device."; // device data is not nullptr, data in device, need copy device data to host. - auto status = kernel::AscendAllocatorPlugin::GetInstance().CopyDeviceDataToHost( + auto status = AscendAllocatorPlugin::GetInstance().CopyDeviceDataToHost( device_data, tensor->MutableData(), tensor->DataSize(), tensor->GetDeviceId()); if (status != kSuccess) { MS_LOG(ERROR) << "tensor has device data, then copy device data to host failed."; diff --git a/mindspore-lite/src/CMakeLists.txt b/mindspore-lite/src/CMakeLists.txt index ca9cbb3b65b519059e4d3b138bee4d00dafd9cbb..cd6e5cb9184736b0204579bbb92f6208dac238f7 100644 --- a/mindspore-lite/src/CMakeLists.txt +++ b/mindspore-lite/src/CMakeLists.txt @@ -97,7 +97,8 @@ set(API_SRC ${CXX_API_SRCS} ${C_API_SRCS}) if(NOT MSLITE_ENABLE_RUNTIME_CONVERT) set(API_SRC ${API_SRC} ${CORE_DIR}/utils/status.cc) endif() -if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full") +if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full" AND NOT + (MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)) file(GLOB CXX_API_TRAIN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/litert/cxx_api/train/model.cc ${CMAKE_CURRENT_SOURCE_DIR}/litert/cxx_api/train/model_impl.cc @@ -151,13 +152,6 @@ set(LITE_SRC file(GLOB FORMAT_PASS_SRC "${CMAKE_CURRENT_SOURCE_DIR}/litert/pass/format_pass/*.cc") set(LITE_SRC ${LITE_SRC} ${FORMAT_PASS_SRC}) -if(NOT ANDROID_NDK_TOOLCHAIN_INCLUDED) - set(LITE_SRC - ${LITE_SRC} - ${CMAKE_CURRENT_SOURCE_DIR}/litert/kernel/ascend/plugin/ascend_kernel_plugin.cc - ) -endif() - set(MODEL_LOADER_FRAMEWORK_SRC ${MODEL_LOADER_FRAMEWORK_SRC} ${CMAKE_CURRENT_SOURCE_DIR}/extendrt/mindir_loader/model_loader.cc @@ -212,7 +206,7 @@ if(MSLITE_ENABLE_RUNTIME_GLOG) string(REPLACE "-fno-rtti" "" CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) if(NOT MSLITE_ENABLE_RUNTIME_CONVERT AND NOT MSLITE_ENABLE_KERNEL_EXECUTOR) set(LITE_SRC ${LITE_SRC} - ${CORE_DIR}/utils/log_adapter.cc) + ${CORE_DIR}/utils/log_adapter.cc) endif() endif() @@ -274,10 +268,11 @@ if(MSLITE_GPU_BACKEND STREQUAL cuda) ${CUDA_RUNTIME_SRC} ) endif() -if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full") - set(TRAIN_SRC_WITH_MD - ${CMAKE_CURRENT_SOURCE_DIR}/train/train_loop.cc - ) +if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full" AND NOT + (MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)) + set(TRAIN_SRC_WITH_MD + ${CMAKE_CURRENT_SOURCE_DIR}/train/train_loop.cc + ) endif() set(TRAIN_SRC @@ -351,18 +346,9 @@ elseif(TARGET_AOS_ARM) ) else() set(LITE_SRC ${LITE_SRC} - ${CORE_DIR}/mindrt/src/thread/core_affinity.cc - ${CORE_DIR}/mindrt/src/thread/threadpool.cc - ) -endif() - -if(MSLITE_ENABLE_GRAPH_KERNEL) - file(GLOB_RECURSE GRAPH_KERNEL_SRC - ${TOOLS_DIR}/graph_kernel/common/*.cc - ${TOOLS_DIR}/graph_kernel/runtime/*.cc - ${OPS_DIR}/kernel/cpu/akg/akg_kernle_loader.cc - ) - set(LITE_SRC ${LITE_SRC} ${GRAPH_KERNEL_SRC}) + ${CORE_DIR}/mindrt/src/thread/core_affinity.cc + ${CORE_DIR}/mindrt/src/thread/threadpool.cc + ) endif() if(NOT MSLITE_ENABLE_COREML) @@ -373,9 +359,6 @@ endif() add_subdirectory(litert/kernel/cpu) add_subdirectory(common) -if(MSLITE_EXPORT_COMPUTE_IR) - add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/common/draw) -endif() add_library(lite_src_mid OBJECT ${LITE_SRC}) add_dependencies(lite_src_mid lite_src_common_mid fbs_src fbs_inner_src) @@ -443,16 +426,8 @@ if(MSVC) set_target_properties(mindspore-lite_static PROPERTIES PREFIX lib) endif() -# if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE) -# target_link_libraries(mindspore-lite cpu_kernel_mid nnacl_mid cpu_ops_mid mindspore_core mindspore::protobuf) -# target_link_libraries(mindspore-lite_static cpu_kernel_mid nnacl_mid cpu_ops_mid mindspore_core mindspore::protobuf) -# else() target_link_libraries(mindspore-lite cpu_kernel_mid nnacl_mid cpu_ops_mid) target_link_libraries(mindspore-lite_static cpu_kernel_mid nnacl_mid cpu_ops_mid) -# endif() - -# target_link_libraries(mindspore-lite mindspore-extendrt) -# target_link_libraries(mindspore-lite_static mindspore-extendrt) if(SUPPORT_TRAIN) target_link_libraries(mindspore-lite train_cpu_kernel_mid) @@ -489,7 +464,8 @@ if(TARGET_OHOS) target_link_libraries(mindspore-lite_static hilog_ndk.z.so) endif() -if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "lite") +if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "lite" AND NOT + (MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)) target_link_libraries(mindspore-lite minddata_eager_mid minddata-lite) target_link_libraries(mindspore-lite_static minddata_eager_mid) endif() @@ -502,20 +478,22 @@ if(SUPPORT_TRAIN) target_link_libraries(mindspore-lite-train lite_src_train_common_mid securec) set_target_properties(mindspore-lite-train PROPERTIES OUTPUT_NAME "mindspore-lite-train") set_target_properties(mindspore-lite-train PROPERTIES CLEAN_DIRECT_OUTPUT 1) - if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full") - target_link_libraries(mindspore-lite-train minddata-lite mindspore-lite) + if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full" AND NOT + (MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)) + target_link_libraries(mindspore-lite-train minddata-lite mindspore-lite) else() - target_link_libraries(mindspore-lite-train mindspore-lite) + target_link_libraries(mindspore-lite-train mindspore-lite) endif() add_library(mindspore-lite-train_static STATIC $) target_link_libraries(mindspore-lite-train_static lite_src_train_common_mid) set_target_properties(mindspore-lite-train_static PROPERTIES OUTPUT_NAME "mindspore-lite-train") set_target_properties(mindspore-lite-train_static PROPERTIES CLEAN_DIRECT_OUTPUT 1) - if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full") - target_link_libraries(mindspore-lite-train_static minddata-lite mindspore-lite) + if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full" AND NOT + (MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)) + target_link_libraries(mindspore-lite-train_static minddata-lite mindspore-lite) else() - target_link_libraries(mindspore-lite-train_static mindspore-lite) + target_link_libraries(mindspore-lite-train_static mindspore-lite) endif() endif() @@ -543,12 +521,12 @@ endif() if(MSLITE_ENABLE_RUNTIME_CONVERT) target_link_libraries(mindspore-lite quantizer_mid fusion_mid proto_mid graph_pass_mid preprocess_mid - cpu_kernel_mid ccsrc_src_mid converter_src_mid lite_exporter_mid + cpu_kernel_mid converter_src_mid lite_exporter_mid config_parser_mid mslite_converter_plugin mindspore_core mindspore_ops coder_mid mindir_serializer_mid mindspore::protobuf ${SECUREC_LIBRARY}) target_link_libraries(mindspore-lite_static quantizer_mid fusion_mid proto_mid graph_pass_mid preprocess_mid - cpu_kernel_mid ccsrc_src_mid converter_src_mid lite_exporter_mid + cpu_kernel_mid converter_src_mid lite_exporter_mid config_parser_mid mslite_converter_plugin mindspore_core mindspore_ops coder_mid mindir_serializer_mid mindspore::protobuf ${SECUREC_LIBRARY}) target_link_libraries(mindspore-lite diff --git a/mindspore-lite/src/common/CMakeLists.txt b/mindspore-lite/src/common/CMakeLists.txt index c013142e36193f33d4c4aa413d97fb7b77ef64ad..3548a83802a497ad28ae00372570f55c39d43532 100644 --- a/mindspore-lite/src/common/CMakeLists.txt +++ b/mindspore-lite/src/common/CMakeLists.txt @@ -14,20 +14,13 @@ set(LITE_SRC_COMMON_MID_SRC ${CMAKE_CURRENT_SOURCE_DIR}/config_infos.cc ${CMAKE_CURRENT_SOURCE_DIR}/helper/external_tensor/file_helper.cc ${CMAKE_CURRENT_SOURCE_DIR}/helper/external_tensor/memory_helper.cc + ${CMAKE_CURRENT_SOURCE_DIR}/crypto.cc ) if(NOT MSLITE_ENABLE_RUNTIME_CONVERT OR MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE) set(LITE_SRC_COMMON_MID_SRC ${LITE_SRC_COMMON_MID_SRC} - ${CMAKE_CURRENT_SOURCE_DIR}/config_file.cc - ) -endif() - -if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE) - set(LITE_SRC_COMMON_MID_SRC - ${LITE_SRC_COMMON_MID_SRC} - ${CMAKE_CURRENT_SOURCE_DIR}/crypto.cc - ) + ${CMAKE_CURRENT_SOURCE_DIR}/config_file.cc) endif() if(NOT ANDROID_NDK_TOOLCHAIN_INCLUDED) @@ -60,12 +53,12 @@ if(MSLITE_ENABLE_STRING_KERNEL) ) endif() -if(MSLITE_ENABLE_GRAPH_KERNEL) - set(LITE_SRC_COMMON_MID_SRC - ${LITE_SRC_COMMON_MID_SRC} - ${CMAKE_CURRENT_SOURCE_DIR}/dynamic_library_loader.cc - ) -endif() +#if(MSLITE_ENABLE_GRAPH_KERNEL) +# set(LITE_SRC_COMMON_MID_SRC +# ${LITE_SRC_COMMON_MID_SRC} +# ${CMAKE_CURRENT_SOURCE_DIR}/dynamic_library_loader.cc +# ) +#endif() if(MSVC) set(LITE_SRC_COMMON_MID_SRC diff --git a/mindspore-lite/src/common/crypto.cc b/mindspore-lite/src/common/crypto.cc index addc354e8599a9f4381e9e6ed4f4385ace820438..8d605a6b019265c03d5b7706c46f579e4f01b41b 100644 --- a/mindspore-lite/src/common/crypto.cc +++ b/mindspore-lite/src/common/crypto.cc @@ -29,7 +29,6 @@ #endif #include "src/common/log_adapter.h" #include "src/common/log_util.h" -#include "securec/include/securec.h" #ifndef SECUREC_MEM_MAX_LEN #define SECUREC_MEM_MAX_LEN 0x7fffffffUL @@ -73,7 +72,7 @@ void IntToByte(std::vector *byteArray, int32_t n) { auto ptr = reinterpret_cast(&n); (*byteArray).assign(ptr, ptr + sizeof(int32_t)); } - +} // namespace bool ParseEncryptData(const Byte *encrypt_data, size_t encrypt_len, std::vector *iv, std::vector *cipher_data) { // encrypt_data is organized in order to iv_len, iv, cipher_len, cipher_data @@ -112,7 +111,6 @@ bool ParseMode(const std::string &mode, std::string *alg_mode, std::string *work *work_mode = results[2]; return true; } -} // namespace int InitCipherCtxAES(EVP_CIPHER_CTX *ctx, const EVP_CIPHER *(*funcPtr)(), const std::string &work_mode, const Byte *key, const Byte *iv, int iv_len, bool is_encrypt) { @@ -274,7 +272,6 @@ EVP_CIPHER_CTX *GetEvpCipherCtx(const std::string &alg_mode, const std::string & return ctx; } -namespace { bool BlockEncrypt(Byte *encrypt_data, size_t *encrypt_data_len, const std::vector &plain_data, const Byte *key, int32_t key_len, const std::string &enc_mode, unsigned char *tag) { size_t encrypt_data_buf_len = *encrypt_data_len; @@ -366,7 +363,6 @@ bool BlockEncrypt(Byte *encrypt_data, size_t *encrypt_data_len, const std::vecto *encrypt_data_len += sizeof(int32_t); return true; } -} // namespace bool BlockDecrypt(Byte *plain_data, int32_t *plain_len, const Byte *encrypt_data, size_t encrypt_len, const Byte *key, int32_t key_len, const std::string &dec_mode, unsigned char *tag) { diff --git a/mindspore-lite/src/common/draw/CMakeLists.txt b/mindspore-lite/src/common/draw/CMakeLists.txt deleted file mode 100644 index c645cc64332c9df8f1ca7ec44a79a0d08489ac7c..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/common/draw/CMakeLists.txt +++ /dev/null @@ -1,5 +0,0 @@ -add_library(mindspore_lite_drawer - ${CMAKE_CURRENT_SOURCE_DIR}/drawer.cc - ${CMAKE_CURRENT_SOURCE_DIR}/graphviz_graph.cc - ${CMAKE_CURRENT_SOURCE_DIR}/graphviz_graph_builder.cc - ) diff --git a/mindspore-lite/src/common/draw/adapter_graph.h b/mindspore-lite/src/common/draw/adapter_graph.h deleted file mode 100644 index 4ed58ea5f2559900e8724f8ba0b1906a78009e95..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/common/draw/adapter_graph.h +++ /dev/null @@ -1,55 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_COMMON_DRAW_ADAPTER_GRAPH_H_ -#define MINDSPORE_LITE_SRC_COMMON_DRAW_ADAPTER_GRAPH_H_ - -#include -#include -#include -#include -#include "src/common/log_adapter.h" -#include "src/common/draw/graphviz_graph.h" -#include "src/tensor.h" -#include "include/errorcode.h" - -namespace mindspore::lite { -class AdapterNode { - public: - virtual ~AdapterNode() = default; - virtual std::string GetName() const = 0; - virtual std::vector GetInputs() const = 0; - virtual Tensor *GetInput(const size_t &index) const = 0; - virtual size_t InputSize() const = 0; - virtual std::vector GetOutputs() const = 0; - virtual Tensor *GetOutput(const size_t &index) const = 0; - virtual size_t OutputSize() const = 0; - virtual bool IsHighlight() const { return false; } -}; - -class AdapterGraph { - public: - virtual ~AdapterGraph() = default; - virtual std::string GetName() const = 0; - virtual std::vector GetNodes() const = 0; - virtual std::vector GetInputs() const = 0; - virtual size_t InputSize() const = 0; - virtual std::vector GetOutputs() const = 0; - virtual size_t OutputSize() const = 0; -}; -} // namespace mindspore::lite - -#endif // MINDSPORE_LITE_SRC_COMMON_DRAW_ADAPTER_GRAPH_H_ diff --git a/mindspore-lite/src/common/draw/adapter_graphs/compile_result_adapter_graph.h b/mindspore-lite/src/common/draw/adapter_graphs/compile_result_adapter_graph.h deleted file mode 100644 index 72b32cb99ed3a1557c1b1e2e67e3366427cbb1ac..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/common/draw/adapter_graphs/compile_result_adapter_graph.h +++ /dev/null @@ -1,106 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifdef ENABLE_DRAW -#ifndef MINDSPORE_LITE_SRC_COMMON_DRAW_ADAPTER_GRAPHS_COMPILE_RESULT_ADAPTER_GRAPH_H_ -#define MINDSPORE_LITE_SRC_COMMON_DRAW_ADAPTER_GRAPHS_COMPILE_RESULT_ADAPTER_GRAPH_H_ - -#include -#include -#include -#include -#include -#include -#include "src/common/log_adapter.h" -#include "src/common/draw/adapter_graph.h" -#include "src/common/draw/graphviz_graph_builder.h" -#include "include/errorcode.h" -#include "src/extendrt/graph_compiler/compile_result.h" - -namespace mindspore::lite { -class CompileNodeAdapterNode : public AdapterNode { - public: - explicit CompileNodeAdapterNode(CompileNodePtr node) : node_(std::move(node)) {} - - std::string GetName() const override { return node_->GetName(); } - std::vector GetInputs() const override { return node_->GetInputs(); } - Tensor *GetInput(const size_t &index) const override { - if (index >= InputSize()) { - return nullptr; - } - return node_->GetInput(index); - } - size_t InputSize() const override { return node_->InputSize(); } - std::vector GetOutputs() const override { return node_->GetOutputs(); } - Tensor *GetOutput(const size_t &index) const override { - if (index >= OutputSize()) { - return nullptr; - } - return node_->GetOutput(index); - } - size_t OutputSize() const override { return node_->OutputSize(); } - - private: - const CompileNodePtr node_; -}; - -class CompileResultAdapterGraph : public AdapterGraph { - public: - static std::shared_ptr Create(const CompileResult *graph) { - auto adapter_graph = std::make_shared(graph); - for (const auto &node : graph->GetNodes()) { - adapter_graph->nodes_.emplace_back(new CompileNodeAdapterNode(node)); - } - return adapter_graph; - } - - explicit CompileResultAdapterGraph(const CompileResult *graph) : graph_(graph) {} - ~CompileResultAdapterGraph() override { - for (auto node : nodes_) { - delete node; - } - nodes_.clear(); - } - std::string GetName() const override { return "CompileResult"; } - std::vector GetNodes() const override { return nodes_; } - std::vector GetInputs() const override { return graph_->GetInputs(); } - size_t InputSize() const override { return graph_->InputSize(); } - std::vector GetOutputs() const override { return graph_->GetOutputs(); } - size_t OutputSize() const override { return graph_->OutputSize(); } - - private: - const CompileResult *graph_; - std::vector nodes_; -}; - -std::shared_ptr CreateGVGraph(const CompileResult *graph) { - auto adapter_graph = CompileResultAdapterGraph::Create(graph); - if (adapter_graph == nullptr) { - MS_LOG(ERROR) << "Create CompileResultAdapterGraph failed."; - return nullptr; - } - GVGraphBuilder builder; - auto gv_graph = builder.Build(adapter_graph); - if (gv_graph == nullptr) { - MS_LOG(ERROR) << "Build gv_graph failed."; - return nullptr; - } - return gv_graph; -} -} // namespace mindspore::lite - -#endif // MINDSPORE_LITE_SRC_COMMON_DRAW_ADAPTER_GRAPHS_COMPILE_RESULT_ADAPTER_GRAPH_H_ -#endif diff --git a/mindspore-lite/src/common/draw/adapter_graphs/sub_graph_kernel_adapter_graph.h b/mindspore-lite/src/common/draw/adapter_graphs/sub_graph_kernel_adapter_graph.h deleted file mode 100644 index c92c9a2d68fab02fcdbc626f092743d8d8f8768a..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/common/draw/adapter_graphs/sub_graph_kernel_adapter_graph.h +++ /dev/null @@ -1,125 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifdef ENABLE_DRAW -#ifndef MINDSPORE_LITE_SRC_COMMON_DRAW_ADAPTER_GRAPHS_SUB_GRAPH_KERNEL_ADAPTER_GRAPH_H_ -#define MINDSPORE_LITE_SRC_COMMON_DRAW_ADAPTER_GRAPHS_SUB_GRAPH_KERNEL_ADAPTER_GRAPH_H_ - -#include -#include -#include -#include -#include -#include -#include "src/common/log_adapter.h" -#include "src/common/draw/adapter_graph.h" -#include "src/common/draw/graphviz_graph_builder.h" -#include "include/errorcode.h" -#include "src/litert/kernel_exec_util.h" -#include "src/executor/kernel_exec.h" -#include "src/executor/sub_graph_kernel.h" -#include "src/common/draw/adapter_graphs/drawer_mark_filter.h" - -namespace mindspore::lite { -class KernelExecAdapterNode : public AdapterNode { - public: - explicit KernelExecAdapterNode(const kernel::KernelExec *kernel, MarkFilter mark_filter = nullptr) - : kernel_(kernel), filter_(std::move(mark_filter)) {} - - std::string GetName() const override { return kernel_->name(); } - std::vector GetInputs() const override { return kernel_->in_tensors(); } - Tensor *GetInput(const size_t &index) const override { - if (index >= InputSize()) { - return nullptr; - } - return kernel_->in_tensors()[index]; - } - size_t InputSize() const override { return kernel_->in_tensors().size(); } - std::vector GetOutputs() const override { return kernel_->out_tensors(); } - Tensor *GetOutput(const size_t &index) const override { - if (index >= OutputSize()) { - return nullptr; - } - return kernel_->out_tensors()[index]; - } - size_t OutputSize() const override { return kernel_->out_tensors().size(); } - - bool IsHighlight() const override { - if (filter_ == nullptr) { - return false; - } - return filter_(*kernel_); - } - - private: - const kernel::KernelExec *kernel_; - const MarkFilter filter_; -}; - -class SubGraphKernelAdapterGraph : public AdapterGraph { - public: - static std::shared_ptr Create(const kernel::SubGraphKernel *graph, - const MarkFilter &mark_filter = nullptr) { - auto adapter_graph = std::make_shared(graph); - auto nodes = graph->immutable_nodes(); - auto ret = kernel::KernelExecUtil::TopologicalSortNodes(&nodes, graph->in_nodes()); - if (ret != RET_OK) { - MS_LOG(ERROR) << "TopologicalSortNodes failed!"; - return nullptr; - } - for (auto node : nodes) { - adapter_graph->nodes_.emplace_back(new KernelExecAdapterNode(node, mark_filter)); - } - return adapter_graph; - } - - explicit SubGraphKernelAdapterGraph(const kernel::SubGraphKernel *graph) : graph_(graph) {} - ~SubGraphKernelAdapterGraph() override { - for (auto node : nodes_) { - delete node; - } - nodes_.clear(); - } - std::string GetName() const override { return graph_->name(); } - std::vector GetNodes() const override { return nodes_; } - std::vector GetInputs() const override { return graph_->in_tensors(); } - size_t InputSize() const override { return graph_->in_tensors().size(); } - std::vector GetOutputs() const override { return graph_->out_tensors(); } - size_t OutputSize() const override { return graph_->out_tensors().size(); } - - private: - const kernel::SubGraphKernel *graph_; - std::vector nodes_; -}; - -std::shared_ptr CreateGVGraph(const kernel::SubGraphKernel *graph, const MarkFilter &mark_filter = nullptr) { - auto adapter_graph = SubGraphKernelAdapterGraph::Create(graph, mark_filter); - if (adapter_graph == nullptr) { - MS_LOG(ERROR) << "Create SubGraphKernelAdapterGraph failed."; - return nullptr; - } - GVGraphBuilder builder; - auto gv_graph = builder.Build(adapter_graph); - if (gv_graph == nullptr) { - MS_LOG(ERROR) << "Build gv_graph failed."; - return nullptr; - } - return gv_graph; -} -} // namespace mindspore::lite - -#endif // MINDSPORE_LITE_SRC_COMMON_DRAW_ADAPTER_GRAPHS_SUB_GRAPH_KERNEL_ADAPTER_GRAPH_H_ -#endif diff --git a/mindspore-lite/src/common/draw/drawer.cc b/mindspore-lite/src/common/draw/drawer.cc deleted file mode 100644 index bf5313b257a7cdd6af09036d989a316f2ad95689..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/common/draw/drawer.cc +++ /dev/null @@ -1,109 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/common/draw/drawer.h" -#ifdef ENABLE_DRAW -#include -#include -#include -#include "src/common/file_utils.h" -#include "src/common/draw/adapter_graphs/sub_graph_kernel_adapter_graph.h" -#include "src/common/draw/adapter_graphs/compile_result_adapter_graph.h" -#endif - -namespace mindspore::lite { -constexpr char kDefaultDrawDIR[] = "./graphs"; -#ifdef ENABLE_DRAW -inline void Drawer::Reset() { count_ = 0; } - -void Drawer::Init() { - auto ret = CreateDir(kDefaultDrawDIR); - if (ret != RET_OK) { - MS_LOG(WARNING) << "Create draw directory failed, disable draw."; - enabled_ = false; - } - if (enabled_) { - base_dir_ = RealPath(kDefaultDrawDIR); - if (base_dir_.empty()) { - MS_LOG(WARNING) << kDefaultDrawDIR << " is invalid: " << base_dir_ << ", disable draw."; - enabled_ = false; - } - } - Reset(); -} - -std::string Drawer::GetNextFileName(const std::string &name) { - std::ostringstream oss; - oss << std::setw(3) << std::setfill('0') << count_++ << '-' << name << ".dot"; - return oss.str(); -} - -inline bool Drawer::SaveDotFile(const std::string &dot_name, const std::string &dot_content) { - auto fname = GetNextFileName(dot_name); - auto write_path = lite::WriteStrToFile(this->base_dir_, fname, dot_content); - if (write_path.empty()) { - MS_LOG(ERROR) << "Save dot-file failed, path: " << this->base_dir_ << ", fname: " << fname; - return false; - } else { - MS_LOG(INFO) << "Save dot-file successfully, path: " << write_path; - return true; - } -} - -void Drawer::Draw(const kernel::SubGraphKernel *graph, const std::string &name) { - if (!enabled_) { - return; - } - auto gv_graph = lite::CreateGVGraph(graph); - if (gv_graph == nullptr) { - MS_LOG(ERROR) << "Create gv_graph failed."; - return; - } - (void)SaveDotFile(name, gv_graph->Code()); -} -#ifdef ENABLE_CLOUD_INFERENCE -void Drawer::Draw(const CompileResult *graph, const std::string &name) { - if (!enabled_) { - return; - } - auto gv_graph = lite::CreateGVGraph(graph); - if (gv_graph == nullptr) { - MS_LOG(ERROR) << "Create gv_graph failed."; - return; - } - (void)SaveDotFile(name, gv_graph->Code()); -} -#endif -#else -#define WARNLOG \ - MS_LOG(WARNING) << "Drawer is not enabled, please set env 'export MSLITE_EXPORT_COMPUTE_IR=on; export " \ - << kDrawDIREnvKey << "=/path/to/draw_dir' to enable drawer." - -inline void Drawer::Reset() { WARNLOG; } - -void Drawer::Init() { WARNLOG; } - -inline bool Drawer::SaveDotFile(const std::string &dot_name, const std::string &dot_content) { - WARNLOG; - return false; -} - -void Drawer::Draw(const kernel::SubGraphKernel *graph, const std::string &name) { WARNLOG; } -#ifdef ENABLE_CLOUD_INFERENCE -void Drawer::Draw(const CompileResult *graph, const std::string &name) { WARNLOG; } -#endif -#endif -} // namespace mindspore::lite diff --git a/mindspore-lite/src/common/draw/graphviz_graph.cc b/mindspore-lite/src/common/draw/graphviz_graph.cc deleted file mode 100644 index 7ad0bcfa49fa0e1c61d9422b4a6e0b2cd2bdce49..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/common/draw/graphviz_graph.cc +++ /dev/null @@ -1,266 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/common/draw/graphviz_graph.h" -#include -#include -#include -#include - -namespace mindspore::lite { -std::string Edge::From() const { return from_->name(); } -std::string Edge::name() const { return this->name_; } - -void Edge::AppendOutput(const GVNode *to, size_t port) { - tos_.emplace_back(to); - to_ports_.emplace_back(port); -} - -std::string Edge::Code() const { - std::ostringstream oss; - if (from_->type() == kNodeTypeCNode) { - oss << from_->prefix() << from_->name() << ":O" << from_port_ << " -> "; - } else { - oss << from_->prefix() << from_->name() << " -> "; - } - auto from_str = oss.str(); - oss.str(""); - for (size_t i = 0; i < tos_.size(); i++) { - auto to = tos_[i]; - if (to->type() == kNodeTypeCNode) { - oss << from_str << to->prefix() << to->name() << ":I" << to_ports_[i] << " [label=\"" << info_ << "\"];"; - } else { - oss << from_str << to->prefix() << to->name() << " [label=\"" << info_ << "\"];"; - } - } - return oss.str(); -} - -GVNode *GVNode::CreateCNode(const std::string &id, const std::string &label, size_t input_size, - const std::vector &output_names, const std::vector &output_infos, - bool highlight) { - auto node = new (std::nothrow) GVNode(id, label, kNodeTypeCNode, input_size, output_names.size(), highlight); - if (node == nullptr) { - MS_LOG(ERROR) << "new GVNode failed!"; - return nullptr; - } - node->prefix_ = "Node_"; - node->shape_ = "plaintext"; - node->color_ = "cornsilk"; - node->Init(output_names, output_infos); - return node; -} - -GVNode *GVNode::CreateInput(const std::string &id, const std::vector &output_names, - const std::vector &output_infos, bool highlight) { - auto node = new (std::nothrow) GVNode(id, id, kNodeTypeInput, 0, output_names.size(), highlight); - if (node == nullptr) { - MS_LOG(ERROR) << "new GVNode failed!"; - return nullptr; - } - node->prefix_ = "Input_"; - node->shape_ = "egg"; - node->Init(output_names, output_infos); - return node; -} - -GVNode *GVNode::CreateOutput(const std::string &id, size_t input_size, bool highlight) { - auto node = new (std::nothrow) GVNode(id, id, kNodeTypeOutput, input_size, 0, highlight); - if (node == nullptr) { - MS_LOG(ERROR) << "new GVNode failed!"; - return nullptr; - } - node->prefix_ = "Output_"; - node->shape_ = "egg"; - node->Init({}, {}); - return node; -} - -GVNode *GVNode::CreateWeight(const std::string &id, const std::string &label, - const std::vector &output_names, const std::vector &output_infos, - bool highlight) { - auto node = new (std::nothrow) GVNode(id, label, kNodeTypeWeight, 0, output_names.size(), highlight); - if (node == nullptr) { - MS_LOG(ERROR) << "new GVNode failed!"; - return nullptr; - } - node->prefix_ = "Weight_"; - node->shape_ = "octagon"; - node->color_ = "paleturquoise"; - node->Init(output_names, output_infos); - return node; -} - -GVNode::~GVNode() { - for (auto output : outputs_) { - delete output; - } - outputs_.clear(); -} - -void GVNode::Init(const std::vector &output_names, const std::vector &output_infos) { - inputs_.reserve(input_size_); - outputs_.reserve(output_size_); - if (output_names.size() != output_size_) { - MS_LOG(ERROR) << "GVNode init failed! output_names size " << output_names.size() - << ", output_size_ = " << output_size_; - return; - } - for (size_t i = 0; i < output_size_; i++) { - auto edge = new (std::nothrow) Edge(output_names[i], this, i, output_infos[i]); - if (edge == nullptr) { - MS_LOG(ERROR) << "GVNode init failed! New Edge failed, please check whether memory is enough!"; - return; - } - - this->outputs_.emplace_back(edge); - } -} - -size_t GVNode::FindCols() const { - auto max = std::max(input_size_, output_size_); - auto min = std::min(input_size_, output_size_); - if (min == 0 || max == 0) { - return 1; - } - size_t ret = max; - while (ret <= input_size_ * output_size_) { - if (ret % min == 0) { - break; - } - ret++; - } - while (ret <= input_size_ * output_size_) { - if (ret % max == 0) { - break; - } - ret += min; - } - return ret; -} - -std::string GVNode::Code() const { - std::ostringstream oss; - if (type_ == kNodeTypeCNode) { - auto bgcolor = highlight_ ? "red" : color_; - oss << "\t" - << "\t" - << "\t" - << "\t"; - auto indent = oss.str(); - oss.str(""); - auto cols = FindCols(); - oss << "<" << std::endl; - oss << indent << ""; - auto input_cols = input_size_ == 0 ? 0 : cols / input_size_; - for (size_t i = 0; i < input_size_; i++) { - oss << ""; - } - oss << "" << std::endl; - oss << indent << "" << std::endl; - oss << indent << ""; - auto output_cols = output_size_ == 0 ? 0 : cols / output_size_; - for (size_t i = 0; i < output_size_; i++) { - oss << ""; - } - oss << "" << std::endl; - oss << indent << "
I" << i << "
" << label_ - << "
O" << i << "
>"; - } else { - oss << "\"" << label_ << "\""; - } - auto label = oss.str(); - oss.str(""); - oss << prefix_ << id_ << " [shape=" << shape_; - oss << ", label=" << label; - if (type_ != kNodeTypeCNode) { - oss << ", style=filled, fillcolor=" << color_; - } - oss << "];"; - return oss.str(); -} - -GVGraph::~GVGraph() { - for (auto *node : nodes_) { - delete node; - } - nodes_.clear(); -} - -void GVGraph::AppendNode(GVNode *node) { - if (node == nullptr) { - return; - } - nodes_.emplace_back(node); - node_map_[node->name()] = node; -} - -int GVGraph::Link(const std::string &from_name, size_t from_port, const std::string &to_name, size_t to_port) { - auto from = node_map_.find(from_name); - if (from == node_map_.end()) { - MS_LOG(ERROR) << "Node " << from_name << " is not belong to this graph."; - return RET_ERROR; - } - - if (from->second == nullptr) { - MS_LOG(ERROR) << "from node is null!"; - return RET_ERROR; - } - if (from_port >= from->second->output_size()) { - MS_LOG(ERROR) << "`from_port`(" << from_port << ") out of range of node(" << from_name - << ")'s output ports number: " << from->second->output_size(); - return RET_ERROR; - } - auto to = node_map_.find(to_name); - if (to == node_map_.end()) { - MS_LOG(ERROR) << "Node " << to_name << " is not belong to this graph."; - return RET_ERROR; - } - if (to->second == nullptr) { - MS_LOG(ERROR) << "to node is null!"; - return RET_ERROR; - } - if (to_port >= to->second->input_size()) { - MS_LOG(ERROR) << "`to_port`(" << to_port << ") out of range of node(" << to_name - << ")'s input ports number: " << to->second->input_size(); - return RET_ERROR; - } - if (to_port < to->second->size()) { - MS_LOG(ERROR) << "node(" << to_name << ")'s " << to_port << "th input port already link to " - << to->second->inputs()[to_port]->From(); - return RET_ERROR; - } - auto edge = from->second->outputs()[from_port]; - edge->AppendOutput(to->second, to_port); - to->second->AppendInput(edge); - return RET_OK; -} - -std::string GVGraph::Code() const { - std::ostringstream oss; - oss << "digraph " << name_ << " {" << std::endl; - for (auto node : nodes_) { - oss << node->Code() << std::endl; - } - for (auto node : nodes_) { - for (auto output : node->outputs()) { - oss << output->Code() << std::endl; - } - } - oss << "}"; - return oss.str(); -} -} // namespace mindspore::lite diff --git a/mindspore-lite/src/common/draw/graphviz_graph.h b/mindspore-lite/src/common/draw/graphviz_graph.h deleted file mode 100644 index bdcbe7a82ca59f228ed218a5657e427d10fccafe..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/common/draw/graphviz_graph.h +++ /dev/null @@ -1,118 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_COMMON_DRAW_GRAPHVIZ_GRAPH_H_ -#define MINDSPORE_LITE_SRC_COMMON_DRAW_GRAPHVIZ_GRAPH_H_ - -#include -#include -#include -#include -#include "src/common/log_adapter.h" -#include "src/tensor.h" -#include "include/errorcode.h" - -namespace mindspore::lite { -constexpr int kNodeTypeCNode = 0; -constexpr int kNodeTypeInput = 1; -constexpr int kNodeTypeOutput = 2; -constexpr int kNodeTypeWeight = 3; -class GVNode; - -class Edge { - public: - Edge(std::string name, const GVNode *from, const size_t &from_port, std::string info) - : name_(std::move(name)), from_(from), from_port_(from_port), info_(std::move(info)) {} - - std::string From() const; - std::string name() const; - void AppendOutput(const GVNode *to, size_t port); - std::string Code() const; - - private: - std::string name_; - const GVNode *from_{nullptr}; - const size_t from_port_{}; - std::vector tos_{}; - std::vector to_ports_{}; - std::string info_; -}; - -class GVNode { - public: - static GVNode *CreateCNode(const std::string &id, const std::string &label, size_t input_size, - const std::vector &output_names, const std::vector &output_infos, - bool highlight = false); - static GVNode *CreateInput(const std::string &id, const std::vector &output_names, - const std::vector &output_infos, bool highlight = false); - static GVNode *CreateOutput(const std::string &id, size_t input_size, bool highlight = false); - static GVNode *CreateWeight(const std::string &id, const std::string &label, - const std::vector &output_names, - const std::vector &output_infos, bool highlight = false); - virtual ~GVNode(); - - int type() const { return this->type_; } - std::string prefix() const { return this->prefix_; } - std::string name() const { return this->id_; } - size_t input_size() const { return input_size_; } - size_t output_size() const { return output_size_; } - const std::vector &inputs() const { return inputs_; } - const std::vector &outputs() const { return outputs_; } - void AppendInput(Edge *edge) { this->inputs_.emplace_back(edge); } - std::string Code() const; - - protected: - GVNode(std::string id, std::string label, int type, size_t input_size, size_t output_size, bool highlight = false) - : id_(std::move(id)), - label_(std::move(label)), - type_(type), - input_size_(input_size), - output_size_(output_size), - highlight_(highlight) {} - void Init(const std::vector &output_names, const std::vector &output_infos); - size_t FindCols() const; - - private: - std::string id_; - std::string label_; - int type_; - std::string prefix_; - std::string color_ = "white"; - size_t input_size_{0}; - size_t output_size_{0}; - std::string shape_; - bool highlight_{false}; - std::vector inputs_{}; - std::vector outputs_{}; -}; - -class GVGraph { - public: - explicit GVGraph(std::string name) : name_{std::move(name)} {}; - virtual ~GVGraph(); - - void AppendNode(GVNode *node); - int Link(const std::string &from_name, size_t from_port, const std::string &to_name, size_t to_port); - std::string Code() const; - - private: - std::string name_; - std::vector nodes_; - std::unordered_map node_map_; -}; -} // namespace mindspore::lite - -#endif // MINDSPORE_LITE_SRC_COMMON_DRAW_GRAPHVIZ_GRAPH_H_ diff --git a/mindspore-lite/src/common/draw/graphviz_graph_builder.cc b/mindspore-lite/src/common/draw/graphviz_graph_builder.cc deleted file mode 100644 index 8a004eb05d6a315088ccac8636b68c508de03b3d..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/common/draw/graphviz_graph_builder.cc +++ /dev/null @@ -1,261 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/common/draw/graphviz_graph_builder.h" -#include -#include -#include "src/common/draw/adapter_graph.h" -#include "ir/dtype.h" - -namespace mindspore::lite { -namespace { -inline void StrReplace(std::string *str) { - replace(str->begin(), str->end(), '/', '_'); - replace(str->begin(), str->end(), '-', '_'); -} - -inline void ShortName(std::string *str) { - auto pos = str->rfind('/'); - if (pos == std::string::npos) { - return; - } - *str = str->substr(pos + 1); -} - -inline std::string GetNodeId(const AdapterNode &node) { - auto name = node.GetName(); - StrReplace(&name); - return name; -} - -inline std::string GetNodeLabel(const AdapterNode &node) { - auto name = node.GetName(); - ShortName(&name); - StrReplace(&name); - return name; -} - -inline std::string GetTensorId(const lite::Tensor &tensor) { - auto name = tensor.tensor_name(); - StrReplace(&name); - return name; -} - -inline std::string GetTensorInfo(const lite::Tensor &tensor) { - auto tensor_info = FormatEnumToString(tensor.format()); - tensor_info += ", "; - tensor_info += TypeIdToString(tensor.data_type()); - tensor_info += ", "; - tensor_info += lite::ShapeVectorToStr(tensor.shape()); - return tensor_info; -} -} // namespace - -std::shared_ptr GVGraphBuilder::Build(const std::shared_ptr &graph) { - gv_graph_ = std::make_shared(graph->GetName()); - // graph inputs - for (auto in_tensor : graph->GetInputs()) { - this->AppendGraphInputNode(*in_tensor); - } - // nodes - for (const auto *node : graph->GetNodes()) { - auto node_id = GetNodeId(*node); - auto node_label = GetNodeLabel(*node); - for (size_t i = 0; i < node->InputSize(); i++) { - auto in_tensor = node->GetInput(i); - if (GetBelongingGVNode(in_tensor).first == nullptr) { - if (!in_tensor->IsConst()) { - MS_LOG(WARNING) << "The " << i << "th input of " << node->GetName() - << " is neither a const tensor nor an output of other node. Treat it as a weight node."; - } - auto tensor_id = node_id + "_in_" + std::to_string(i); - auto tensor_label = node_label + "_in_" + std::to_string(i); - AppendWeightNode(*in_tensor, tensor_id, tensor_label); - } - } - auto ret = this->AppendComputeNode(*node); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Create and append gv_node for " << node->GetName() << " failed."; - return nullptr; - } - } - // graph outputs - auto ret = this->AppendGraphOutputNode(graph->GetOutputs()); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Create and append graph return node failed"; - return nullptr; - } - return this->gv_graph_; -} - -void GVGraphBuilder::AppendGraphInputNode(const lite::Tensor &tensor) { - auto tensor_id = GetTensorId(tensor); - auto gv_node = lite::GVNode::CreateInput(tensor_id, {tensor_id}, {GetTensorInfo(tensor)}); - MS_ASSERT(gv_node != nullptr); - gv_graph_->AppendNode(gv_node); - gv_node_out_tensor_map_[&tensor] = std::make_pair(gv_node, 0); -} - -namespace { -template -std::string BufferToString(const T *buffer, size_t size) { - MS_ASSERT(buffer != nullptr); - constexpr size_t print_pre_number = 3; - constexpr size_t print_post_number = 3; - constexpr size_t print_period_number = 2; - if (size <= print_pre_number + print_post_number + print_period_number) { - std::ostringstream oss; - for (size_t i = 0; i < size; i++) { - if (i == 0) { - oss << buffer[i]; - } else { - oss << ", " << buffer[i]; - } - } - return oss.str(); - } - - size_t index = 0; - std::ostringstream oss; - for (size_t i = 0; i < print_pre_number; i++, index++) { - if (index == 0) { - oss << buffer[index]; - } else { - oss << ", " << buffer[index]; - } - } - oss << "..."; - for (size_t i = 0; i < print_post_number; i++, index++) { - oss << ", " << buffer[index]; - } - return oss.str(); -} - -std::string TensorDataString(const lite::Tensor &tensor) { - if (tensor.shape().size() != 1 || tensor.shape()[0] <= 0 || tensor.data() == nullptr) { - return ""; - } - auto data_size = static_cast(tensor.shape()[0]); - - std::ostringstream oss; - oss << "\n["; - if (tensor.data_type() == kNumberTypeInt || tensor.data_type() == kNumberTypeInt32) { - auto data = reinterpret_cast(tensor.data()); - oss << BufferToString(data, data_size); - } else if (tensor.data_type() == kNumberTypeInt64) { - auto data = reinterpret_cast(tensor.data()); - oss << BufferToString(data, data_size); - } else { - return ""; - } - oss << "]"; - return oss.str(); -} -} // namespace - -void GVGraphBuilder::AppendWeightNode(const lite::Tensor &tensor, const std::string &id, const std::string &label) { - auto gv_node = lite::GVNode::CreateWeight(id, label + TensorDataString(tensor), {id}, {GetTensorInfo(tensor)}); - MS_ASSERT(gv_node != nullptr); - gv_graph_->AppendNode(gv_node); - AppendOutTensorMap(&tensor, gv_node, 0); -} - -int GVGraphBuilder::AppendComputeNode(const AdapterNode &node) { - auto gv_node = CreateComputeNode(node); - if (gv_node == nullptr) { - MS_LOG(ERROR) << "Create gv_node for " << node.GetName() << " failed."; - return RET_ERROR; - } - gv_graph_->AppendNode(gv_node); - for (size_t i = 0; i < node.OutputSize(); i++) { - AppendOutTensorMap(node.GetOutput(i), gv_node, i); - } - auto ret = LinkNodes(node, *gv_node); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Link inputs for " << node.GetName() << " failed."; - return RET_ERROR; - } - return RET_OK; -} - -int GVGraphBuilder::AppendGraphOutputNode(const std::vector &out_tensors) { - auto out_tensor_size = out_tensors.size(); - auto gv_node = lite::GVNode::CreateOutput("return", out_tensor_size); - if (gv_node == nullptr) { - MS_LOG(ERROR) << "create output node failed!"; - return RET_ERROR; - } - gv_graph_->AppendNode(gv_node); - for (size_t i = 0; i < out_tensors.size(); i++) { - auto out_tensor = out_tensors[i]; - auto pair = this->GetBelongingGVNode(out_tensor); - if (pair.first == nullptr) { - MS_LOG(ERROR) << "Can not find graph output tensor source: " << out_tensor->tensor_name(); - return RET_ERROR; - } - auto link_ret = gv_graph_->Link(pair.first->name(), pair.second, gv_node->name(), i); - if (link_ret != RET_OK) { - MS_LOG(ERROR) << "Link " << i << "th input tensor of return failed."; - return RET_ERROR; - } - } - return RET_OK; -} - -GVNode *GVGraphBuilder::CreateComputeNode(const AdapterNode &node) { - auto node_id = GetNodeId(node); - auto node_label = GetNodeLabel(node); - std::vector output_names; - std::vector output_infos; - for (auto out_tensor : node.GetOutputs()) { - output_names.emplace_back(GetTensorId(*out_tensor)); - output_infos.emplace_back(GetTensorInfo(*out_tensor)); - } - auto *gv_node = - lite::GVNode::CreateCNode(node_id, node_label, node.InputSize(), output_names, output_infos, node.IsHighlight()); - MS_ASSERT(gv_node != nullptr); - return gv_node; -} - -void GVGraphBuilder::AppendOutTensorMap(const lite::Tensor *tensor, lite::GVNode *node, size_t out_index) { - gv_node_out_tensor_map_[tensor] = std::make_pair(node, out_index); -} - -std::pair GVGraphBuilder::GetBelongingGVNode(const lite::Tensor *tensor) const { - auto iter = gv_node_out_tensor_map_.find(tensor); - if (iter == gv_node_out_tensor_map_.end()) { - return {}; - } else { - return iter->second; - } -} -int GVGraphBuilder::LinkNodes(const AdapterNode &node, const GVNode &gv_node) { - for (size_t i = 0; i < node.InputSize(); i++) { - auto in_tensor = node.GetInput(i); - auto pair = this->GetBelongingGVNode(in_tensor); - if (pair.first == nullptr) { - MS_LOG(ERROR) << "Can not find input tensor source: " << in_tensor->tensor_name(); - return RET_ERROR; - } - auto link_ret = gv_graph_->Link(pair.first->name(), pair.second, gv_node.name(), i); - if (link_ret != RET_OK) { - MS_LOG(ERROR) << "Link " << i << "th input tensor of " << node.GetName() << " failed."; - return RET_ERROR; - } - } - return RET_OK; -} -} // namespace mindspore::lite diff --git a/mindspore-lite/src/executor/CMakeLists.txt b/mindspore-lite/src/executor/CMakeLists.txt index 4f210e1f08a605767b67428dda14f5de9740dc4f..47124efc280bde465f07eeff5ead7cc7cbe0d2a4 100644 --- a/mindspore-lite/src/executor/CMakeLists.txt +++ b/mindspore-lite/src/executor/CMakeLists.txt @@ -20,7 +20,6 @@ set(LITE_SRC ${LITE_DIR}/src/litert/infer_manager.cc ${LITE_DIR}/src/litert/runtime_shape_fusion_pass.cc ${LITE_DIR}/src/litert/runtime_pass.cc - # ${LITE_DIR}/src/litert/pass/runtime_ncx_pass.cc ${LITE_DIR}/src/litert/schema_tensor_wrapper.cc ${LITE_DIR}/src/tensor.cc ${LITE_DIR}/src/tensorlist.cc @@ -97,14 +96,6 @@ if(MSLITE_ENABLE_CONTROLFLOW) set(LITE_SRC ${LITE_SRC} ${CONTROL_FLOW_KERNEL_SRC}) endif() -if(MSLITE_ENABLE_RUNTIME_GLOG) - if(NOT MSLITE_ENABLE_RUNTIME_CONVERT AND NOT MSLITE_ENABLE_KERNEL_EXECUTOR - AND NOT (MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)) - set(LITE_SRC ${LITE_SRC} - ${CORE_DIR}/utils/log_adapter.cc) - endif() -endif() - if(MSLITE_ENABLE_RUNTIME_CONVERT) file(GLOB RUNTIME_CONVERT_SRC ${LITE_DIR}/src/common/ops/ops_def.cc @@ -196,21 +187,16 @@ if(MSLITE_ENABLE_MINDRT) ) set(LITE_SRC ${LITE_SRC} ${CONTROL_FLOW_ACTOR_SRC}) endif() -else() - set(LITE_SRC ${LITE_SRC} - ${CORE_DIR}/mindrt/src/thread/core_affinity.cc - ${CORE_DIR}/mindrt/src/thread/threadpool.cc - ) endif() -if(MSLITE_ENABLE_GRAPH_KERNEL) - file(GLOB_RECURSE GRAPH_KERNEL_SRC - ${TOOLS_DIR}/graph_kernel/common/*.cc - ${TOOLS_DIR}/graph_kernel/runtime/*.cc - ${OPS_DIR}/kernel/cpu/akg/akg_kernel_loader.cc - ) - set(LITE_SRC ${LITE_SRC} ${GRAPH_KERNEL_SRC}) -endif() +#if(MSLITE_ENABLE_GRAPH_KERNEL) +# file(GLOB_RECURSE GRAPH_KERNEL_SRC +# ${TOOLS_DIR}/graph_kernel/common/*.cc +# ${TOOLS_DIR}/graph_kernel/runtime/*.cc +# ${OPS_DIR}/kernel/cpu/akg/akg_kernel_loader.cc +# ) +# set(LITE_SRC ${LITE_SRC} ${GRAPH_KERNEL_SRC}) +#endif() if(NOT MSLITE_ENABLE_COREML) set(LITE_SRC ${LITE_SRC} @@ -223,16 +209,15 @@ set(MSLITE_GE_LITERT_SRC ${LITE_DIR}/src/extendrt/delegate/graph_executor/litert/litert_plugin_impl.cc ) -#set(LITE_SRC ${LITE_SRC} ${MSLITE_GE_LITERT_SRC}) - if(NOT MSLITE_SIMPLEST_CLOUD_INFERENCE) add_library(unified_runtime_lite_src_mid OBJECT ${LITE_SRC}) add_dependencies(unified_runtime_lite_src_mid fbs_src fbs_inner_src) add_dependencies(unified_runtime_lite_src_mid lite_src_common_mid) + target_link_libraries(unified_runtime_lite_src_mid mindspore_core) add_library(lite-unified-executor SHARED $) - target_link_libraries(lite-unified-executor lite_src_common_mid) + target_link_libraries(lite-unified-executor lite_src_common_mid mindspore_core) add_dependencies(lite-unified-executor mindspore_converter) target_link_libraries(lite-unified-executor mindspore_converter) @@ -256,7 +241,8 @@ if(NOT MSLITE_SIMPLEST_CLOUD_INFERENCE) AND NOT TARGET_MIX210 AND NOT TARGET_OHOS_LITE AND NOT MACHINE_LINUX_ARM64) target_link_libraries(lite-unified-executor log) endif() - if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "lite") + if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "lite" AND NOT + (MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)) target_link_libraries(lite-unified-executor minddata_eager_mid minddata-lite) endif() @@ -273,7 +259,7 @@ if(NOT MSLITE_SIMPLEST_CLOUD_INFERENCE) if(MSLITE_ENABLE_RUNTIME_CONVERT) target_link_libraries(lite-unified-executor quantizer_mid fusion_mid proto_mid graph_pass_mid preprocess_mid - cpu_kernel_mid ccsrc_src_mid converter_src_mid lite_exporter_mid + cpu_kernel_mid converter_src_mid lite_exporter_mid config_parser_mid mslite_converter_plugin mindspore_core mindspore_ops coder_mid mindir_serializer_mid mindspore::protobuf ${SECUREC_LIBRARY}) target_link_libraries(lite-unified-executor diff --git a/mindspore-lite/src/extendrt/CMakeLists.txt b/mindspore-lite/src/extendrt/CMakeLists.txt index 72c54233b1d2d810f99850b8b64bc4daabc43aec..ad44830d260ef84f76acab11f0c97e7773041667 100644 --- a/mindspore-lite/src/extendrt/CMakeLists.txt +++ b/mindspore-lite/src/extendrt/CMakeLists.txt @@ -15,75 +15,26 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE) add_compile_definitions(ENABLE_CLOUD_INFERENCE) remove_definitions(-DBUILD_LITE_INFERENCE) - set(MSLITE_KERNEL_PLUGIN - ${MSLITE_KERNEL_PLUGIN} - ${CMAKE_CURRENT_SOURCE_DIR}/kernel/ascend/plugin/ascend_kernel_plugin.cc) - - set(MSLITE_EXTEND_NNACL_KERNEL_LIB_SRC - ${CMAKE_CURRENT_SOURCE_DIR}/kernel/nnacl/nnacl_lib.cc - ) - - set(MSLITE_GRAPH_KERNEL_FLAGS_SRC - ${CCSRC_DIR}/backend/common/graph_kernel/graph_kernel_flags.cc - ) - - set(MSLITE_EXTEND_DEFAULT_KERNEL_LIB_SRC - ${CMAKE_CURRENT_SOURCE_DIR}/kernel/default/cnode_infer_manager.cc - ${CMAKE_CURRENT_SOURCE_DIR}/kernel/default/kernel_mod_kernel.cc - ${CMAKE_CURRENT_SOURCE_DIR}/kernel/default/default_kernel_lib.cc) - - file(GLOB DELEGATE_OPS ${CMAKE_CURRENT_SOURCE_DIR}/delegate/ops/*.cc) - set(MSLITE_EXTEND_RUNTIME_SRC - ${MSLITE_KERNEL_PLUGIN} - ${MSLITE_EXTEND_NNACL_KERNEL_LIB_SRC} - ${MSLITE_GRAPH_KERNEL_FLAGS_SRC} - ${MSLITE_EXTEND_DEFAULT_KERNEL_LIB_SRC} - ${CMAKE_CURRENT_SOURCE_DIR}/kernel/kernel_selector/kernel_selector.cc ${CMAKE_CURRENT_SOURCE_DIR}/subgraph_kernel.cc ${CMAKE_CURRENT_SOURCE_DIR}/numa_adapter.cc ${CMAKE_CURRENT_SOURCE_DIR}/model_manager.cc - ${CMAKE_CURRENT_SOURCE_DIR}/kernel/cpu/less_test_kernel_mod.cc - ${CMAKE_CURRENT_SOURCE_DIR}/kernel/cpu/transpose_kernel_mod.cc - ${CMAKE_CURRENT_SOURCE_DIR}/kernel/base_kernel.cc ${CMAKE_CURRENT_SOURCE_DIR}/infer_session.cc - ${CMAKE_CURRENT_SOURCE_DIR}/session/single_op_session.cc - ${CMAKE_CURRENT_SOURCE_DIR}/session/memory_offload_session.cc ${CMAKE_CURRENT_SOURCE_DIR}/session/delegate_session.cc - ${CMAKE_CURRENT_SOURCE_DIR}/session/default_session.cc ${CMAKE_CURRENT_SOURCE_DIR}/session/factory.cc - ${CMAKE_CURRENT_SOURCE_DIR}/memory_offload/infer_strategy_builder.cc - ${CMAKE_CURRENT_SOURCE_DIR}/infer_device_address.cc - ${CMAKE_CURRENT_SOURCE_DIR}/utils/kernel_build_utils.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/tensor_utils.cc - ${CMAKE_CURRENT_SOURCE_DIR}/utils/runtime_utils.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/serialization.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/func_graph_utils.cc ${CMAKE_CURRENT_SOURCE_DIR}/delegate/comm_group_info.cc ${CMAKE_CURRENT_SOURCE_DIR}/delegate/factory.cc - ${CMAKE_CURRENT_SOURCE_DIR}/delegate/plugin/tensorrt_executor_plugin.cc ${CMAKE_CURRENT_SOURCE_DIR}/delegate/plugin/litert_executor_plugin.cc ${CMAKE_CURRENT_SOURCE_DIR}/delegate/plugin/ascend_ge_executor_plugin.cc - ${CMAKE_CURRENT_SOURCE_DIR}/delegate/plugin/ascend_native_executor_plugin.cc - ${CMAKE_CURRENT_SOURCE_DIR}/delegate/tensorrt/distribution/distribution_base.cc + ${CMAKE_CURRENT_SOURCE_DIR}/delegate/plugin/ascend_acl_executor_plugin.cc ${CMAKE_CURRENT_SOURCE_DIR}/delegate_graph_executor.cc ${CMAKE_CURRENT_SOURCE_DIR}/delegate/graph_executor/litert/func_graph_reuse_manager.cc - ${CMAKE_CURRENT_SOURCE_DIR}/graph_compiler/factory.cc - ${CMAKE_CURRENT_SOURCE_DIR}/graph_compiler/default_graph_compiler.cc - ${CMAKE_CURRENT_SOURCE_DIR}/graph_runtime/factory.cc - ${CMAKE_CURRENT_SOURCE_DIR}/graph_runtime/default_graph_runtime.cc - ${CMAKE_CURRENT_SOURCE_DIR}/graph_executor/factory.cc - ${CMAKE_CURRENT_SOURCE_DIR}/graph_executor/mindrt_graph_executor.cc - ${CMAKE_CURRENT_SOURCE_DIR}/graph_executor/default_executor.cc ${CMAKE_CURRENT_SOURCE_DIR}/execution_flow.cc ${CMAKE_CURRENT_SOURCE_DIR}/execution_plan.cc ${CMAKE_CURRENT_SOURCE_DIR}/../infer/primitive_type.cc - ${CMAKE_CURRENT_SOURCE_DIR}/graph_compiler/compile_result.cc - ${CMAKE_CURRENT_SOURCE_DIR}/graph_compiler/single_graph_scheduler.cc - ${CMAKE_CURRENT_SOURCE_DIR}/graph_compiler/compile_result_builder.cc - ${CMAKE_CURRENT_SOURCE_DIR}/graph_compiler/anfnode_tensor_adapter.cc - ${CMAKE_CURRENT_SOURCE_DIR}/graph_compiler/infershape_helper.cc - ${DELEGATE_OPS} ) if(MSLITE_ENABLE_BFC_MEMORY) @@ -92,150 +43,43 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE) ${CMAKE_CURRENT_SOURCE_DIR}/dynamic_mem_manager.cc ) endif() - if(MSLITE_ENABLE_ACL) - set(MSLITE_EXTEND_RUNTIME_SRC ${MSLITE_EXTEND_RUNTIME_SRC} - ${CMAKE_CURRENT_SOURCE_DIR}/../train/opt_allocator.cc - ${CMAKE_CURRENT_SOURCE_DIR}/kernel/base_kernel.cc - ) - if(MSLITE_ASCEND_TARGET) - set(ASCEND_NATIVE_PLUGIN - ${CMAKE_CURRENT_SOURCE_DIR}/session/ascend_native_session.cc - ${CMAKE_CURRENT_SOURCE_DIR}/kernel/ascend_native/ascend_native_composite_kernel.cc - ${CMAKE_CURRENT_SOURCE_DIR}/kernel/ascend_native/ascend_native_copy_kernel.cc - ${CMAKE_CURRENT_SOURCE_DIR}/delegate/ascend_native/ascend_native_registration_factory.cc - ${CMAKE_CURRENT_SOURCE_DIR}/delegate/ascend_native/ascend_native_add_kernel.cc - ${CMAKE_CURRENT_SOURCE_DIR}/delegate/ascend_native/ascend_native_matmul_kernel.cc - ${CMAKE_CURRENT_SOURCE_DIR}/delegate/ascend_native/ascend_native_layernorm_kernel.cc - ${CMAKE_CURRENT_SOURCE_DIR}/delegate/ascend_native/ascend_native_gather_kernel.cc - ${CMAKE_CURRENT_SOURCE_DIR}/delegate/ascend_native/ascend_native_encoder_kernel.cc - ${CCSRC_DIR}/plugin/res_manager/ascend/hccl_adapter/hccl_adapter.cc - ${CCSRC_DIR}/utils/config_manager.cc - ${CMAKE_CURRENT_SOURCE_DIR}/mock/ge_mock.cc - ${CMAKE_CURRENT_SOURCE_DIR}/mock/transform_mock.cc - ) - add_library(ascend_native_plugin SHARED ${ASCEND_NATIVE_PLUGIN}) - find_library(ge_graph libgraph.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - target_link_libraries(ascend_native_plugin ${ge_graph}) - target_include_directories(ascend_native_plugin PRIVATE ${CMAKE_BINARY_DIR}/proto/ge) - add_subdirectory(delegate/ascend_native) - if(TARGET ascend_native_kernels_impl) - set(ASCEND_NATIVE_KERNELS_IMPL ascend_native_kernels_impl) - endif() - target_link_libraries(ascend_native_plugin ${ASCEND_NATIVE_KERNELS_IMPL} ascend_native_mid) - set(ASCEND_TOOLKIT_PLUGIN_PATH ${ASCEND_TOOLKIT_RUNTIME_PATH}/plugin/opskernel) - include_directories(${CCSRC_DIR}/plugin/res_manager/ascend/hccl_adapter/) - add_subdirectory(${CCSRC_DIR}/plugin/res_manager/ascend/hccl_adapter/plugin build) - endif() - endif() - include(${LITE_DIR}/cmake/ccsrc_extendrt.cmake) include(${TOP_DIR}/cmake/external_libs/pocketfft.cmake) - set_property(SOURCE ${MSLITE_EXTEND_RUNTIME_SRC} PROPERTY COMPILE_DEFINITIONS LOG_HDR_FILE_REL_PATH="mindspore-lite/../mindspore/mindspore/core/include/utils/log_adapter.h" SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) - add_library(mindspore-extendrt SHARED ${MSLITE_EXTEND_RUNTIME_SRC} ${MSLITE_EXTEND_CPU_RUNTIME_SRC}) add_dependencies(mindspore-extendrt lite_src_common_mid) target_link_libraries(mindspore-extendrt lite_src_common_mid) add_dependencies(mindspore-extendrt fbs_src fbs_inner_src) - add_dependencies(mindspore-extendrt mindspore-infer-anfalgo) - add_dependencies(mindspore-extendrt mindspore-kernel-graph) add_subdirectory(cxx_api) add_subdirectory(delegate/graph_executor/litert) add_library(mindspore-extendrt_static STATIC ${MSLITE_EXTEND_RUNTIME_SRC}) add_dependencies(mindspore-extendrt_static lite_src_common_mid) target_link_libraries(mindspore-extendrt_static lite_src_common_mid) add_dependencies(mindspore-extendrt_static fbs_src fbs_inner_src) - add_dependencies(mindspore-extendrt_static mindspore-infer-anfalgo) - add_dependencies(mindspore-extendrt_static mindspore-kernel-graph) - - add_subdirectory(${CCSRC_DIR}/backend/common/pass common_pass) - add_subdirectory(${CCSRC_DIR}/backend/operator backend_operator) - add_subdirectory(${CCSRC_DIR}/backend/common/optimizer mindspore_ccsrc_backend_cmmon_optimizer) target_link_libraries(mindspore-extendrt mindspore_infer_shared_lib_obj) - target_link_libraries(mindspore-extendrt mindspore-infer-anfalgo - mindspore-kernel-graph _mindspore_backend_common_optimizer_obj - _mindspore_backend_common_pass_obj) target_link_libraries(mindspore-extendrt mindspore_core mindspore_ops mindspore::protobuf) target_link_libraries(mindspore-extendrt_static mindspore_infer_shared_lib_obj) - target_link_libraries(mindspore-extendrt_static mindspore-infer-anfalgo - mindspore-kernel-graph _mindspore_backend_common_optimizer_obj - _mindspore_backend_common_pass_obj _mindspore_backend_operator_obj) target_link_libraries(mindspore-extendrt_static mindspore_core mindspore_ops mindspore::protobuf) add_dependencies(mindspore-extendrt_static msplugin-ge-litert) target_link_libraries(mindspore-extendrt_static msplugin-ge-litert) add_subdirectory(${LITE_DIR}/src/executor unified_executor) - if(NOT PLATFORM_ARM) - add_dependencies(mindspore-extendrt _mindspore_cpu_kernel_mod_depend_obj - mindspore-lite-proto) - target_link_libraries(mindspore-extendrt _mindspore_cpu_kernel_mod_depend_obj - mindspore-lite-proto) - add_dependencies(mindspore-extendrt_static _mindspore_cpu_kernel_mod_depend_obj - mindspore-lite-proto) - target_link_libraries(mindspore-extendrt_static _mindspore_cpu_kernel_mod_depend_obj - mindspore-lite-proto) - if(MSLITE_DEPS_MKLDNN) - add_dependencies(mindspore-extendrt mindspore::dnnl) - target_link_libraries(mindspore-extendrt mindspore::dnnl) - add_dependencies(mindspore-extendrt_static mindspore::dnnl) - target_link_libraries(mindspore-extendrt_static mindspore::dnnl) - endif() - - if(MSLITE_DEPS_MKLDNN) - set(CPU_KERNEL_OBJECT_COUNT 0) - add_subdirectory(${OPS_DIR}/kernel/cpu lite_kernel_mod) - foreach(number RANGE 1 ${CPU_KERNEL_OBJECT_COUNT}) - target_link_libraries(mindspore-extendrt _mindspore_ops_cpu_kernel_obj) - target_link_libraries(mindspore-extendrt_static _mindspore_ops_cpu_kernel_obj) - endforeach() - endif() - - endif() - if(NOT WIN32) target_link_libraries(mindspore-extendrt dl) target_link_libraries(mindspore-extendrt_static dl) endif() if(MSLITE_ENABLE_ACL) - add_subdirectory(kernel/ascend) add_subdirectory(delegate/ascend_ge) + add_subdirectory(delegate/ascend_acl) endif() - - if(SUPPORT_CUDA) - set(CUDA_PATH $ENV{CUDA_HOME}) - set(ENABLE_GPU on) - add_definitions(-DENABLE_GPU) - set(CUDA_VERSION 11.1) - include_directories(${CUDA_PATH}) - include_directories(${CUDA_PATH}/include) - find_package(CUDA) - add_subdirectory(kernel/cuda) - list(APPEND CUDA_NVCC_FLAGS -arch=sm_53 --expt-relaxed-constexpr) - target_link_libraries(mindspore-extendrt cuda_lite_kernel_mid cuda_ops) - target_link_libraries(mindspore-extendrt_static cuda_lite_kernel_mid cuda_ops) - if(CMAKE_BUILD_TYPE STREQUAL "Debug") - set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} -g -G -lineinfo) - endif() - endif() - - if(SUPPORT_TENSORRT) - add_definitions(-DSUPPORT_TENSORRT) - add_subdirectory(delegate/tensorrt) - endif() - if(MSLITE_ENABLE_CONVERTER) add_subdirectory(convert) target_link_libraries(mindspore-extendrt -Wl,--no-as-needed mindspore_converter) endif() - - if(MSLITE_EXPORT_COMPUTE_IR) - target_link_libraries(mindspore-extendrt mindspore_lite_drawer) - target_link_libraries(mindspore-extendrt_static mindspore_lite_drawer) - endif() else() set(MSLITE_EXTEND_RUNTIME_SRC ${MODEL_LOADER_FRAMEWORK_SRC}) add_library(mindspore-extendrt OBJECT ${MSLITE_EXTEND_RUNTIME_SRC}) diff --git a/mindspore-lite/src/extendrt/convert/CMakeLists.txt b/mindspore-lite/src/extendrt/convert/CMakeLists.txt index d4b4c8dfa04a1a37c2055c56b4aad978c4aa9d7b..1a22f5a216602b2dce0bf84e22f4d99d9807935c 100644 --- a/mindspore-lite/src/extendrt/convert/CMakeLists.txt +++ b/mindspore-lite/src/extendrt/convert/CMakeLists.txt @@ -1,5 +1,5 @@ include_directories(${TOP_DIR}) -include_directories(${TOP_DIR}/mindspore-lite) +include_directories(${TOP_DIR}/mindspor-lite) file(GLOB RUNTIME_CONVERT_SRC ${CMAKE_CURRENT_SOURCE_DIR}/*.cc) diff --git a/mindspore-lite/src/extendrt/cxx_api/CMakeLists.txt b/mindspore-lite/src/extendrt/cxx_api/CMakeLists.txt index 8bf9aa78c1b143edcd42d349ba4df66b5b634677..ccd89ad35f5eeef0cc64f2cf6aa846f66827f59d 100644 --- a/mindspore-lite/src/extendrt/cxx_api/CMakeLists.txt +++ b/mindspore-lite/src/extendrt/cxx_api/CMakeLists.txt @@ -12,9 +12,6 @@ if(MODE_ASCEND_ACL) string(STRIP "${PY3_LIBG}" PY3_LIBG) message("Python3 general library = " ${PY3_LIBG}) endif() -# build mslite_shared_lib -include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc) -include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc/minddata/dataset) if(ENABLE_D OR ENABLE_ACL) # build 910 and 310 code into one distro, files needed for 310 mode @@ -56,30 +53,9 @@ set(MSLIB_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR}/model_pool/model_parallel_runner.cc ${CMAKE_CURRENT_SOURCE_DIR}/model_pool/model_parallel_runner_impl.cc ${CMAKE_CURRENT_SOURCE_DIR}/model_pool/resource_manager.cc - ${CMAKE_CURRENT_SOURCE_DIR}/llm_engine/llm_engine.cc - ${CMAKE_CURRENT_SOURCE_DIR}/llm_engine/llm_engine_impl.cc ${API_MS_INFER_SRC} ${API_ACL_SRC} ${API_OPS_SRC} ) - -if(ENABLE_D) - list(APPEND MSLIB_INFER_SRC - "${CMAKE_SOURCE_DIR}/../ccsrc/frontend/parallel/strategy_checkpoint/strategy_checkpoint_info.cc" - "${CMAKE_SOURCE_DIR}/../ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc" - "${CMAKE_SOURCE_DIR}/../ccsrc/frontend/parallel/group_manager.cc" - "${CMAKE_SOURCE_DIR}/../ccsrc/frontend/parallel/device_manager.cc" - "${CMAKE_SOURCE_DIR}/../ccsrc/frontend/parallel/device_matrix.cc" - "${CMAKE_SOURCE_DIR}/../ccsrc/frontend/parallel/tensor_layout/array.cc" - "${CMAKE_SOURCE_DIR}/../ccsrc/frontend/parallel/tensor_layout/map.cc" - "${CMAKE_SOURCE_DIR}/../ccsrc/frontend/parallel/tensor_layout/arrangement.cc" - "${CMAKE_SOURCE_DIR}/../ccsrc/frontend/parallel/tensor_layout/shape_util.cc" - "${CMAKE_SOURCE_DIR}/../ccsrc/frontend/parallel/tensor_layout/tensor_layout.cc") -endif() - add_library(mindspore_infer_shared_lib_obj OBJECT ${MSLIB_INFER_SRC}) add_dependencies(mindspore_infer_shared_lib_obj fbs_inner_src) - -if(MSLITE_ENABLE_ACL) -add_subdirectory(llm_engine) -endif() diff --git a/mindspore-lite/src/extendrt/cxx_api/context.cc b/mindspore-lite/src/extendrt/cxx_api/context.cc index 758ae520ed2b3120e9928a3669ae4a5b28526949..af4d2a05eda70f481065b8316971aae977a03d8b 100644 --- a/mindspore-lite/src/extendrt/cxx_api/context.cc +++ b/mindspore-lite/src/extendrt/cxx_api/context.cc @@ -21,7 +21,6 @@ #include "include/lite_types.h" #include "src/litert/inner_allocator.h" #include "src/common/log_adapter.h" -#include "src/extendrt/delegate/tensorrt/distribution/distribution_base.h" #include "src/extendrt/delegate_graph_executor.h" namespace mindspore { @@ -286,108 +285,6 @@ bool CPUDeviceInfo::GetEnableFP16() const { return GetValue(data_, kModelOptionCpuEnableFP16); } -void GPUDeviceInfo::SetEnableFP16(bool is_fp16) { - if (data_ == nullptr) { - MS_LOG(ERROR) << "Invalid context."; - return; - } - if (is_fp16) { - data_->params[kModelOptionGPUPrecisionMode] = std::string("preferred_fp16"); - } else { - data_->params[kModelOptionGPUPrecisionMode] = std::string("enforce_fp32"); - } - data_->params[kModelOptionGPUEnableFP16] = is_fp16; -} - -bool GPUDeviceInfo::GetEnableFP16() const { - if (data_ == nullptr) { - MS_LOG(ERROR) << "Invalid context."; - return false; - } - return GetValue(data_, kModelOptionGPUEnableFP16); -} - -void GPUDeviceInfo::SetPrecisionMode(const std::vector &precision_mode) { - if (data_ == nullptr) { - MS_LOG(ERROR) << "Invalid context."; - return; - } - if (precision_mode == StringToChar("enforce_fp32")) { - data_->params[kModelOptionGPUEnableFP16] = false; - } else if (precision_mode == StringToChar("preferred_fp16")) { - data_->params[kModelOptionGPUEnableFP16] = true; - } else { - MS_LOG(ERROR) << "GPU only support mode enforce_fp32 and preferred_fp16. Now the precision mode is fp32 as default"; - return; - } - data_->params[kModelOptionGPUPrecisionMode] = CharToString(precision_mode); -} - -std::vector GPUDeviceInfo::GetPrecisionModeChar() const { - if (data_ == nullptr) { - MS_LOG(ERROR) << "Invalid context."; - return std::vector(); - } - const std::string &ref = GetValue(data_, kModelOptionGPUPrecisionMode); - return StringToChar(ref); -} - -void GPUDeviceInfo::SetEnableGLTexture(bool is_enable_gl_texture) { - MS_LOG(ERROR) << "Unsupported Feature."; - return; -} - -bool GPUDeviceInfo::GetEnableGLTexture() const { - MS_LOG(ERROR) << "Unsupported Feature."; - return false; -} - -void GPUDeviceInfo::SetGLContext(void *gl_context) { - MS_LOG(ERROR) << "Unsupported Feature."; - return; -} - -void *GPUDeviceInfo::GetGLContext() const { - MS_LOG(ERROR) << "Unsupported Feature."; - return nullptr; -} - -void GPUDeviceInfo::SetGLDisplay(void *gl_display) { - MS_LOG(ERROR) << "Unsupported Feature."; - return; -} - -void *GPUDeviceInfo::GetGLDisplay() const { - MS_LOG(ERROR) << "Unsupported Feature."; - return nullptr; -} - -void GPUDeviceInfo::SetDeviceID(uint32_t device_id) { - if (data_ == nullptr) { - MS_LOG(ERROR) << "Invalid context."; - return; - } - data_->params[kModelOptionGPUDeviceID] = device_id; -} - -uint32_t GPUDeviceInfo::GetDeviceID() const { - if (data_ == nullptr) { - MS_LOG(ERROR) << "Invalid context."; - return 0; - } - return GetValue(data_, kModelOptionGPUDeviceID); -} - -int GPUDeviceInfo::GetRankID() const { - data_->params[kModelOptionGPURankID] = lite::GetRankID(); - return GetValue(data_, kModelOptionGPURankID); -} - -int GPUDeviceInfo::GetGroupSize() const { - data_->params[kModelOptionGPUGroupSize] = lite::GetGPUGroupSize(); - return GetValue(data_, kModelOptionGPUGroupSize); -} - void AscendDeviceInfo::SetDeviceID(uint32_t device_id) { if (data_ == nullptr) { MS_LOG(ERROR) << "Invalid context."; diff --git a/mindspore-lite/src/extendrt/cxx_api/graph/graph_data.h b/mindspore-lite/src/extendrt/cxx_api/graph/graph_data.h index a852de107b2a1bbab270936281e9f4c75e9220a0..1b0ad53a6469721938bf1a372d5b328798a0b3b1 100644 --- a/mindspore-lite/src/extendrt/cxx_api/graph/graph_data.h +++ b/mindspore-lite/src/extendrt/cxx_api/graph/graph_data.h @@ -21,11 +21,12 @@ #include #include "include/api/graph.h" #include "include/api/types.h" -#include "include/dataset/execute.h" #include "ir/func_graph.h" namespace mindspore { - +namespace dataset { +class Execute; +} class Graph::GraphData { public: GraphData(); diff --git a/mindspore-lite/src/extendrt/cxx_api/llm_engine/CMakeLists.txt b/mindspore-lite/src/extendrt/cxx_api/llm_engine/CMakeLists.txt deleted file mode 100644 index 273ba207fb22df2e250f9f6dbdea71591d9c9bd2..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/cxx_api/llm_engine/CMakeLists.txt +++ /dev/null @@ -1,21 +0,0 @@ -include_directories(${CCSRC_DIR}) - -file(STRINGS "${TOP_DIR}/version.txt" MSVERSION) -add_definitions(-DMSVERSION=\"${MSVERSION}\") -add_compile_definitions(ENABLE_SECURITY) - -file(GLOB LLM_ENGINE_PLUGIN_SRC - ${CMAKE_CURRENT_SOURCE_DIR}/llm_engine_plugin.cc - ${CCSRC_DIR}/utils/config_manager.cc - ) - -add_library(llm_engine_plugin SHARED ${LLM_ENGINE_PLUGIN_SRC}) - -find_library(llm_engine libllm_engine.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) -find_library(ge_graph libgraph.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) -find_library(acl libascendcl.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - -target_link_libraries(llm_engine_plugin ${llm_engine} ${ge_graph} ${acl} - mindspore_converter mindspore_core mindspore_ops mindspore_graph_ir) - -target_link_libraries(llm_engine_plugin mindspore-extendrt) diff --git a/mindspore-lite/src/extendrt/cxx_api/llm_engine/llm_engine.cc b/mindspore-lite/src/extendrt/cxx_api/llm_engine/llm_engine.cc deleted file mode 100644 index ea627b676417b5ffdd1cf2a1fa5fe613ac21a827..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/cxx_api/llm_engine/llm_engine.cc +++ /dev/null @@ -1,157 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "extendrt/cxx_api/llm_engine/llm_engine.h" -#include "src/extendrt/cxx_api/dlutils.h" -#include "src/extendrt/cxx_api/file_utils.h" -#include "src/common/common.h" -#include "src/extendrt/cxx_api/llm_engine/llm_engine_impl.h" - -namespace mindspore { -LLMEngine::LLMEngine(LLMRole role, uint64_t cluster_id, const VecChar &batch_mode) { - impl_ = std::make_shared(role, cluster_id, CharToString(batch_mode)); -} - -Status LLMEngine::AddModelInner(mindspore::LLMModel *llm_model, const std::vector &model_paths_c, - const std::map &options_c, const VecChar &postprocess_model_path_c) { - if (impl_ == nullptr) { - MS_LOG(ERROR) << "LLMEngine impl is nullptr"; - return kLiteError; - } - if (llm_model == nullptr || llm_model->impl_ == nullptr) { - MS_LOG(ERROR) << "Failed to add model, input argument llm_model is nullptr"; - return kLiteError; - } - auto model_paths = VectorCharToString(model_paths_c); - auto options = MapVectorCharToString(options_c); - auto postprocess_model_path = CharToString(postprocess_model_path_c); - uint64_t model_id = 0; - auto status = impl_->AddModel(model_paths, options, postprocess_model_path, &model_id); - if (status != kSuccess) { - return status; - } - llm_model->impl_->SetModelId(model_id); - llm_model->impl_->SetLLMEngine(impl_); - return kSuccess; -} - -Status LLMEngine::InitInner(const std::map &options) { - if (impl_ == nullptr) { - MS_LOG(ERROR) << "LLMEngine impl is nullptr"; - return kLiteError; - } - return impl_->Init(MapVectorCharToString(options)); -} - -void LLMEngine::Finalize() { - if (impl_ == nullptr) { - MS_LOG(ERROR) << "LLMEngine impl is nullptr"; - return; - } - impl_->Finalize(); -} - -Status LLMEngine::LinkClusters(const std::vector &clusters, std::vector *rets, - int32_t timeout) { - if (impl_ == nullptr) { - MS_LOG(ERROR) << "LLMEngine impl is nullptr"; - return kLiteError; - } - return impl_->LinkClusters(clusters, rets, timeout); -} - -Status LLMEngine::UnlinkClusters(const std::vector &clusters, std::vector *rets, - int32_t timeout) { - if (impl_ == nullptr) { - MS_LOG(ERROR) << "LLMEngine impl is nullptr"; - return kLiteError; - } - return impl_->UnlinkClusters(clusters, rets, timeout); -} - -LLMEngineStatus LLMEngine::FetchStatus() { - if (impl_ == nullptr) { - MS_LOG(ERROR) << "LLMEngine impl is nullptr"; - return LLMEngineStatus(); - } - return impl_->FetchStatus(); -} - -Status LLMEngine::CompleteRequest(const LLMReq &req) { - if (impl_ == nullptr) { - MS_LOG(ERROR) << "LLMEngine impl is nullptr"; - return kLiteError; - } - return impl_->CompleteRequest(req); -} - -LLMModel::LLMModel() { impl_ = std::make_shared(); } - -Status LLMModel::Predict(const LLMReq &req, const std::vector &inputs, std::vector *outputs) { - if (impl_ == nullptr || impl_->GetLLMEngine() == nullptr) { - MS_LOG(ERROR) << "LLMEngine impl is nullptr"; - return kLiteError; - } - return impl_->GetLLMEngine()->Predict(req, inputs, outputs, impl_->GetModelId()); -} - -Status LLMModel::Predict(const std::vector &req, const std::vector &inputs, - std::vector *outputs) { - if (impl_ == nullptr || impl_->GetLLMEngine() == nullptr) { - MS_LOG(ERROR) << "LLMEngine impl is nullptr"; - return kLiteError; - } - return impl_->GetLLMEngine()->Predict(req, inputs, outputs, impl_->GetModelId()); -} -Status LLMModel::PreloadPromptPrefix(const LLMReq &req, const std::vector &inputs) { - if (impl_ == nullptr || impl_->GetLLMEngine() == nullptr) { - MS_LOG(ERROR) << "LLMEngine impl is nullptr"; - return kLiteError; - } - return impl_->GetLLMEngine()->PreloadPromptPrefix(req, inputs, impl_->GetModelId()); -} - -Status LLMModel::ReleasePromptPrefix(const LLMReq &req) { - if (impl_ == nullptr || impl_->GetLLMEngine() == nullptr) { - MS_LOG(ERROR) << "LLMEngine impl is nullptr"; - return kLiteError; - } - return impl_->GetLLMEngine()->ReleasePromptPrefix(req, impl_->GetModelId()); -} - -Status LLMModel::PullKV(const LLMReq &req) { - if (impl_ == nullptr || impl_->GetLLMEngine() == nullptr) { - MS_LOG(ERROR) << "LLMEngine impl is nullptr"; - return kLiteError; - } - return impl_->GetLLMEngine()->PullKV(req, impl_->GetModelId()); -} - -Status LLMModel::MergeKV(const LLMReq &req, uint32_t batch_index, uint32_t batch_id) { - if (impl_ == nullptr || impl_->GetLLMEngine() == nullptr) { - MS_LOG(ERROR) << "LLMEngine impl is nullptr"; - return kLiteError; - } - return impl_->GetLLMEngine()->MergeKV(req, batch_index, batch_id, impl_->GetModelId()); -} - -std::vector LLMModel::GetInputs() { - if (impl_ == nullptr || impl_->GetLLMEngine() == nullptr) { - MS_LOG(ERROR) << "LLMEngine impl is nullptr"; - return {}; - } - return impl_->GetLLMEngine()->GetInputs(impl_->GetModelId()); -} -} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/cxx_api/llm_engine/llm_engine.h b/mindspore-lite/src/extendrt/cxx_api/llm_engine/llm_engine.h deleted file mode 100644 index 6296678de4c7a29db138931c7dd0eb846cfaedaf..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/cxx_api/llm_engine/llm_engine.h +++ /dev/null @@ -1,115 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_INCLUDE_API_LLM_ENGINE_H_ -#define MINDSPORE_INCLUDE_API_LLM_ENGINE_H_ -#include -#include -#include -#include -#include -#include "include/api/types.h" -#include "include/api/status.h" - -namespace mindspore { -struct LLMReq { - uint64_t req_id = UINT64_MAX; - uint64_t prompt_length = 0; - uint64_t prompt_cluster_id = 0; - uint64_t decoder_cluster_id = 0; - uint64_t prefix_id = UINT64_MAX; - uint64_t sequence_length = 0; -}; - -struct LLMIpInfo { - uint32_t ip; - uint16_t port; -}; - -struct LLMClusterInfo { - uint64_t remote_cluster_id; - int32_t remote_role_type; - std::vector local_ip_infos; - std::vector remote_ip_infos; -}; - -struct LLMEngineStatus { - uint64_t empty_max_prompt_kv = 0; - int32_t num_free_blocks = 0; - int32_t num_total_blocks = 0; - int32_t block_size = 0; -}; - -enum LLMRole { - kLLMRolePrompt = 0, - kLLMRoleDecoder = 1, -}; - -struct LLMTensorInfo { - std::string name; - std::vector shape; - DataType dtype; -}; - -class LLMModel; -class LLMEngineImpl; -class MS_API LLMEngine { - public: - LLMEngine(LLMRole role, uint64_t cluster_id, const std::string &batch_mode = "auto") - : LLMEngine(role, cluster_id, StringToChar(batch_mode)) {} - ~LLMEngine() = default; - Status AddModel(LLMModel *llm_model, const std::vector &model_paths, - const std::map &options, const std::string &postprocess_model_path = "") { - return AddModelInner(llm_model, VectorStringToChar(model_paths), MapStringToVectorChar(options), - StringToChar(postprocess_model_path)); - } - Status Init(const std::map &options) { return InitInner(MapStringToVectorChar(options)); } - - void Finalize(); - LLMEngineStatus FetchStatus(); - Status LinkClusters(const std::vector &clusters, std::vector *rets, int32_t timeout = -1); - Status UnlinkClusters(const std::vector &clusters, std::vector *rets, int32_t timeout = -1); - Status CompleteRequest(const LLMReq &req); - - private: - std::shared_ptr impl_ = nullptr; - friend class LLMModel; - - LLMEngine(LLMRole role, uint64_t cluster_id, const std::vector &batch_mode); - Status AddModelInner(LLMModel *llm_model, const std::vector &model_paths, - const std::map &options, const VecChar &postprocess_model_path); - Status InitInner(const std::map &options); -}; - -class LLMModelImpl; -class MS_API LLMModel { - public: - LLMModel(); - ~LLMModel() = default; - Status Predict(const LLMReq &req, const std::vector &inputs, std::vector *outputs); - Status Predict(const std::vector &req, const std::vector &inputs, std::vector *outputs); - LLMEngineStatus FetchStatus(); - Status PreloadPromptPrefix(const LLMReq &req, const std::vector &inputs); - Status ReleasePromptPrefix(const LLMReq &req); - Status PullKV(const LLMReq &req); - Status MergeKV(const LLMReq &req, uint32_t batch_index, uint32_t batch_id); - std::vector GetInputs(); - - private: - std::shared_ptr impl_ = nullptr; - friend class LLMEngine; -}; -} // namespace mindspore -#endif // MINDSPORE_INCLUDE_API_LLM_ENGINE_H_ diff --git a/mindspore-lite/src/extendrt/cxx_api/llm_engine/llm_engine_impl.cc b/mindspore-lite/src/extendrt/cxx_api/llm_engine/llm_engine_impl.cc deleted file mode 100644 index 47b5a928c3acfb2d04eab7b3537245d53b5330eb..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/cxx_api/llm_engine/llm_engine_impl.cc +++ /dev/null @@ -1,367 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "extendrt/cxx_api/llm_engine/llm_engine_impl.h" -#include -#include "src/extendrt/cxx_api/dlutils.h" -#include "src/extendrt/cxx_api/file_utils.h" -#include "load_mindir/load_model.h" -#include "extendrt/cxx_api/llm_engine/llm_engine_plugin.h" -#include "src/common/common.h" -#include "tools/common/custom_ascend_utils.h" -#include "src/extendrt/utils/func_graph_utils.h" -#include "src/extendrt/utils/tensor_default_impl.h" - -namespace mindspore { -namespace { -constexpr auto kLLMEnginePluginSoName = "libllm_engine_plugin.so"; -constexpr auto kLLMEngineCreatePluginFuncName = "CreateLLMEnginePlugin"; -} // namespace - -bool LLEnginePluginLoader::Register() { - if (create_plugin_func_ != nullptr) { - return kSuccess; - } - std::string plugin_path; - auto ret = DLSoPath({"libmindspore-lite.so", "_c_lite"}, kLLMEnginePluginSoName, &plugin_path); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Get real path of " << kLLMEnginePluginSoName << " failed."; - return false; - } - MS_LOG(INFO) << "Find LLMEngine plugin so success, path = " << plugin_path; - void *function = nullptr; - ret = DLSoOpen(plugin_path, kLLMEngineCreatePluginFuncName, &handle_, &function); - if (ret != kSuccess) { - MS_LOG(ERROR) << "DLSoOpen failed, so path: " << plugin_path << ", err: " << ret.ToString(); - return false; - } - create_plugin_func_ = reinterpret_cast(function); - if (create_plugin_func_ == nullptr) { - MS_LOG(ERROR) << "Cast " << kLLMEngineCreatePluginFuncName << " failed."; - return false; - } - MS_LOG(INFO) << "Register LLMEngine plugin success."; - return true; -} - -std::shared_ptr LLEnginePluginLoader::CreatePlugin(LLMRole role, uint64_t cluster_id, - const std::string &batch_mode) { - if (!Register()) { - MS_LOG(ERROR) << "Failed to register " << kLLMEnginePluginSoName; - return nullptr; - } - if (create_plugin_func_ == nullptr) { - MS_LOG(ERROR) << "Create plugin func is nullptr"; - return nullptr; - } - return std::shared_ptr(create_plugin_func_(role, cluster_id, batch_mode)); -} - -LLMEngineImpl::LLMEngineImpl(LLMRole role, uint64_t cluster_id, const std::string &batch_mode) - : role_(role), cluster_id_(cluster_id), batch_mode_(batch_mode) {} - -Status LLMEngineImpl::InitPlugin() { - if (plugin_ != nullptr) { - return kSuccess; - } - plugin_ = LLEnginePluginLoader::Instance().CreatePlugin(role_, cluster_id_, batch_mode_); - if (plugin_ == nullptr) { - MS_LOG(ERROR) << "Failed to create LLMEngine plugin"; - return kLiteError; - } - return kSuccess; -} - -Status LLMEngineImpl::Init(const std::map &options) { - if (inited_) { - MS_LOG(ERROR) << "LLMEngine has been inited or inited failed"; - return kLiteError; - } - auto status = InitPlugin(); - if (status != kSuccess) { - MS_LOG(ERROR) << "Failed to init LLMEngine plugin"; - return status; - } - status = plugin_->Init(options); - inited_ = true; - if (status != kSuccess) { - MS_LOG(ERROR) << "Failed to init LLMEngine"; - return status; - } - return kSuccess; -} - -void LLMEngineImpl::Finalize() { - if (plugin_ == nullptr) { - MS_LOG(INFO) << "LLMEngine plugin has not been created"; - return; - } - plugin_->Finalize(); -} - -Status LLMEngineImpl::AddModel(const std::vector &model_paths, - const std::map &options, - const std::string &postprocess_model_path, uint64_t *model_id) { - if (inited_) { - MS_LOG(ERROR) << "LLMEngine has been inited or inited failed"; - return kLiteError; - } - auto status = InitPlugin(); - if (status != kSuccess) { - MS_LOG(ERROR) << "Failed to init LLMEngine plugin"; - return status; - } - std::vector infos; - std::set names; - for (auto &model_path : model_paths) { - LLMEngineModelInfo model_info; - if (LoadAndGetModelInfo(model_path, &model_info) != kSuccess) { - MS_LOG(ERROR) << "Failed to ge graph info, mindir " << model_path; - return kLiteError; - } - infos.push_back(model_info); - names.emplace(model_info.name); - } - if (names.size() != infos.size()) { - for (size_t i = 0; i < infos.size(); i++) { - infos[i].name += "_U" + std::to_string(i); // make unique name - } - } - LLMEngineModelInfo postprocess_model_info; - if (!postprocess_model_path.empty()) { - if (LoadAndGetModelInfo(postprocess_model_path, &postprocess_model_info) != kSuccess) { - MS_LOG(ERROR) << "Failed to ge graph info, mindir " << postprocess_model_path; - return kLiteError; - } - } - status = plugin_->AddModel(infos, options, postprocess_model_info, model_id); - if (status != kSuccess) { - MS_LOG(ERROR) << "Failed to Add Model"; - return status; - } - if (!infos.empty()) { - std::vector input_infos; - auto &info = infos[0]; - for (size_t i = 0; i < info.input_names.size(); i++) { - LLMTensorInfo tensor_info; - tensor_info.name = info.input_names[i]; - tensor_info.shape = info.input_shapes[i]; - tensor_info.dtype = static_cast(info.input_dtypes[i]); - input_infos.push_back(tensor_info); - } - model_infos_[*model_id] = input_infos; - } - return status; -} - -Status LLMEngineImpl::Predict(const LLMReq &req, const std::vector &inputs, std::vector *outputs, - uint64_t model_id) { - if (plugin_ == nullptr) { - MS_LOG(ERROR) << "LLMEngine plugin has not been created"; - return kLiteError; - } - return plugin_->Predict(req, inputs, outputs, model_id); -} - -Status LLMEngineImpl::Predict(const std::vector &req, const std::vector &inputs, - std::vector *outputs, uint64_t model_id) { - if (plugin_ == nullptr) { - MS_LOG(ERROR) << "LLMEngine plugin has not been created"; - return kLiteError; - } - return plugin_->Predict(req, inputs, outputs, model_id); -} - -Status LLMEngineImpl::CompleteRequest(const LLMReq &req) { - if (plugin_ == nullptr) { - MS_LOG(ERROR) << "LLMEngine plugin has not been created"; - return kLiteError; - } - return plugin_->CompleteRequest(req); -} - -LLMEngineStatus LLMEngineImpl::FetchStatus() { - if (plugin_ == nullptr) { - MS_LOG(ERROR) << "LLMEngine plugin has not been created"; - return LLMEngineStatus(); - } - return plugin_->FetchStatus(); -} - -Status LLMEngineImpl::PreloadPromptPrefix(const LLMReq &req, const std::vector &inputs, uint64_t model_id) { - if (plugin_ == nullptr) { - MS_LOG(ERROR) << "LLMEngine plugin has not been created"; - return kLiteError; - } - return plugin_->PreloadPromptPrefix(req, inputs, model_id); -} - -Status LLMEngineImpl::ReleasePromptPrefix(const LLMReq &req, uint64_t model_id) { - if (plugin_ == nullptr) { - MS_LOG(ERROR) << "LLMEngine plugin has not been created"; - return kLiteError; - } - return plugin_->ReleasePromptPrefix(req, model_id); -} - -Status LLMEngineImpl::PullKV(const LLMReq &req, uint64_t model_id) { - if (plugin_ == nullptr) { - MS_LOG(ERROR) << "LLMEngine plugin has not been created"; - return kLiteError; - } - return plugin_->PullKV(req, model_id); -} - -Status LLMEngineImpl::MergeKV(const LLMReq &req, uint32_t batch_index, uint32_t batch_id, uint64_t model_id) { - if (plugin_ == nullptr) { - MS_LOG(ERROR) << "LLMEngine plugin has not been created"; - return kLiteError; - } - return plugin_->MergeKV(req, batch_index, batch_id, model_id); -} - -std::vector LLMEngineImpl::GetInputs(uint64_t model_id) { - if (plugin_ == nullptr) { - MS_LOG(ERROR) << "LLMEngine plugin has not been created"; - return {}; - } - auto it = model_infos_.find(model_id); - if (it == model_infos_.end()) { - MS_LOG(ERROR) << "Cannot find model info for model " << model_id; - return {}; - } - auto input_infos = it->second; - std::vector tensors; - for (auto &item : input_infos) { - auto tensor_impl = std::make_shared(item.name, item.dtype, item.shape); - tensors.push_back(MSTensor(tensor_impl)); - } - return tensors; -} - -Status LLMEngineImpl::LinkClusters(const std::vector &clusters, std::vector *rets, - int32_t timeout) { - if (plugin_ == nullptr) { - MS_LOG(ERROR) << "LLMEngine plugin has not been created"; - return kLiteError; - } - return plugin_->LinkClusters(clusters, rets, timeout); -} - -Status LLMEngineImpl::UnlinkClusters(const std::vector &clusters, std::vector *rets, - int32_t timeout) { - if (plugin_ == nullptr) { - MS_LOG(ERROR) << "LLMEngine plugin has not been created"; - return kLiteError; - } - return plugin_->UnlinkClusters(clusters, rets, timeout); -} - -Status LLMEngineImpl::GetModelInfo(const FuncGraphPtr &func_graph, LLMEngineModelInfo *model_info) { - if (func_graph == nullptr || model_info == nullptr) { - return kLiteNullptr; - } - if (!CustomAscendUtils::IsCustomFuncGraph(func_graph)) { - MS_LOG(ERROR) << "LLMEngine model should be converted to offline compiled model by mindspore_lite.Converter"; - return kLiteError; - } - std::map attr_map; - std::vector> ref_datas; - DynKVCacheSaveInfo kv_info; - auto ret = CustomAscendUtils::ParseCustomFuncGraph(func_graph, &model_info->om_data, &model_info->name, &attr_map, - &ref_datas, &kv_info); - if (!ret) { - MS_LOG(ERROR) << "Failed to parse custom func graph"; - return kLiteError; - } - for (auto &item : func_graph->get_inputs()) { - auto shape = FuncGraphUtils::GetTensorShape({item, 0}); - auto type_id = FuncGraphUtils::GetTensorDataType({item, 0}); - model_info->input_shapes.push_back(shape); - model_info->input_dtypes.push_back(static_cast(type_id)); - model_info->input_names.push_back(item->fullname_with_scope()); - } - for (auto &item : ref_datas) { - auto &tensor = item.second; - auto ref_shape = - SetKVCacheShape(kv_info.batch_size_dyn, kv_info.seq_length_dyn, kv_info.kv_cache_layout, tensor->shape_c()); - model_info->ref_input_shapes.push_back(ref_shape); - model_info->ref_input_dtypes.push_back(static_cast(tensor->data_type())); - } - auto attr_it = attr_map.find(lite::kNameAttrWeightDir); - if (attr_it == attr_map.end()) { - MS_LOG(ERROR) << "Failed to attr " << lite::kNameAttrWeightDir; - return kLiteError; - } - auto &attr_val = attr_it->second; - if (!attr_val->isa()) { - MS_LOG(ERROR) << "Failed to attr " << lite::kNameAttrWeightDir << ", attr type is " << attr_val->type_name(); - return kLiteError; - } - model_info->weight_dir = GetValue(attr_it->second); - MS_LOG(INFO) << "Get graph attr " << lite::kNameAttrWeightDir << " " << model_info->weight_dir; - std::vector outputs; - if (!FuncGraphUtils::GetFuncGraphOutputs(func_graph, &outputs)) { - MS_LOG(ERROR) << "Failed to get func graph outputs"; - return kLiteError; - } - model_info->output_count = outputs.size(); - return kSuccess; -} - -FuncGraphPtr LLMEngineImpl::LoadMindIR(const std::string &model_path) { - if (model_path.empty()) { - MS_LOG(ERROR) << "Model path cannot be empty"; - return nullptr; - } - auto buffer = ReadFile(model_path); - if (buffer.Data() == nullptr || buffer.DataSize() == 0) { - MS_LOG(ERROR) << "Failed to read buffer from model file: " << model_path; - return nullptr; - } - std::string weight_path = "./"; - if (model_path.find("/") != std::string::npos) { - weight_path = model_path.substr(0, model_path.rfind("/")); - } - MindIRLoader mindir_loader(true, nullptr, 0, kDecModeAesGcm, false); - auto func_graph = mindir_loader.LoadMindIR(buffer.Data(), buffer.DataSize(), weight_path); - if (func_graph == nullptr) { - MS_LOG(ERROR) << "Failed to load MindIR model, please check the validity of the model: " << model_path; - return nullptr; - } - return func_graph; -} - -Status LLMEngineImpl::LoadAndGetModelInfo(const std::string &model_path, LLMEngineModelInfo *model_info_ptr) { - auto func_graph = LoadMindIR(model_path); - if (func_graph == nullptr) { - MS_LOG(ERROR) << "Failed to load mindir " << model_path; - return kLiteError; - } - LLMEngineModelInfo &model_info = *model_info_ptr; - if (GetModelInfo(func_graph, &model_info) != kSuccess) { - MS_LOG(ERROR) << "Failed to ge graph info, mindir " << model_path; - return kLiteError; - } - // relative weight path - if (!model_info.weight_dir.empty() && model_info.weight_dir[0] != '/') { - if (model_path.find("/") != std::string::npos) { - model_info.weight_dir = model_path.substr(0, model_path.rfind("/") + 1) + model_info.weight_dir; - MS_LOG(INFO) << "Update " << model_path << " weight dir to " << model_info.weight_dir; - } - } - return kSuccess; -} -} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/cxx_api/llm_engine/llm_engine_impl.h b/mindspore-lite/src/extendrt/cxx_api/llm_engine/llm_engine_impl.h deleted file mode 100644 index fb7fffd18854d28a61adfad7b13e6f58d7b06283..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/cxx_api/llm_engine/llm_engine_impl.h +++ /dev/null @@ -1,100 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_CXX_API_LLM_ENGINE_IMPL_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_CXX_API_LLM_ENGINE_IMPL_H_ -#include -#include -#include -#include -#include -#include "include/api/types.h" -#include "include/api/status.h" -#include "extendrt/cxx_api/llm_engine/llm_engine.h" -#include "base/base.h" -#include "extendrt/cxx_api/llm_engine/llm_engine_plugin.h" - -namespace mindspore { -class MS_API LLMEngineImpl { - public: - LLMEngineImpl(LLMRole role, uint64_t cluster_id, const std::string &batch_mode = "auto"); - ~LLMEngineImpl() = default; - Status AddModel(const std::vector &model_paths, const std::map &options, - const std::string &postprocess_model_path, uint64_t *model_id); - Status Init(const std::map &options); - - void Finalize(); - Status Predict(const LLMReq &req, const std::vector &inputs, std::vector *outputs, - uint64_t model_id); - Status Predict(const std::vector &req, const std::vector &inputs, std::vector *outputs, - uint64_t model_id); - Status CompleteRequest(const LLMReq &req); - Status PreloadPromptPrefix(const LLMReq &req, const std::vector &inputs, uint64_t model_id); - Status ReleasePromptPrefix(const LLMReq &req, uint64_t model_id); - Status PullKV(const LLMReq &req, uint64_t model_id); - Status MergeKV(const LLMReq &req, uint32_t batch_index, uint32_t batch_id, uint64_t model_id); - std::vector GetInputs(uint64_t model_id); - - LLMEngineStatus FetchStatus(); - Status LinkClusters(const std::vector &clusters, std::vector *rets, int32_t timeout = -1); - Status UnlinkClusters(const std::vector &clusters, std::vector *rets, int32_t timeout = -1); - - private: - LLMRole role_; - uint64_t cluster_id_; - std::string batch_mode_; - std::map> model_infos_; - bool inited_ = false; - std::shared_ptr plugin_ = nullptr; - - Status GetModelInfo(const FuncGraphPtr &func_graph, LLMEngineModelInfo *model_info); - Status LoadAndGetModelInfo(const std::string &model_path, LLMEngineModelInfo *model_info_ptr); - FuncGraphPtr LoadMindIR(const std::string &model_path); - Status InitPlugin(); -}; - -typedef LLMEnginePluginBase *(*CreateLLMEnginePluginFunc)(LLMRole, uint64_t, const std::string &); - -class LLEnginePluginLoader { - public: - static LLEnginePluginLoader &Instance() { - static LLEnginePluginLoader instance; - return instance; - } - std::shared_ptr CreatePlugin(LLMRole role, uint64_t cluster_id, const std::string &batch_mode); - - private: - void *handle_ = nullptr; - CreateLLMEnginePluginFunc create_plugin_func_ = nullptr; - bool Register(); -}; - -class LLMModelImpl { - public: - LLMModelImpl() = default; - ~LLMModelImpl() = default; - - void SetModelId(uint64_t model_id) { model_id_ = model_id; } - uint64_t GetModelId() const { return model_id_; } - - void SetLLMEngine(const std::shared_ptr &llm_engine) { engine_impl_ = llm_engine; } - std::shared_ptr GetLLMEngine() const { return engine_impl_; } - - private: - uint64_t model_id_ = 0; - std::shared_ptr engine_impl_ = nullptr; -}; -} // namespace mindspore -#endif // MINDSPORE_LITE_SRC_EXTENDRT_CXX_API_LLM_ENGINE_IMPL_H_ diff --git a/mindspore-lite/src/extendrt/cxx_api/llm_engine/llm_engine_plugin.cc b/mindspore-lite/src/extendrt/cxx_api/llm_engine/llm_engine_plugin.cc deleted file mode 100644 index 9876fcbaedafb0ebbea25488393b8a207e037c9c..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/cxx_api/llm_engine/llm_engine_plugin.cc +++ /dev/null @@ -1,974 +0,0 @@ -/** - * Copyright 2023-2024 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "extendrt/cxx_api/llm_engine/llm_engine_plugin.h" -#include -#include "src/extendrt/cxx_api/dlutils.h" -#include "load_mindir/load_model.h" -#include "mindspore/ccsrc/plugin/res_manager/ascend/op_adapter/transform_util.h" -#include "src/extendrt/utils/tensor_utils.h" -#include "src/common/common.h" -#include "src/common/utils.h" -#include "ge/llm_engine.h" -#include "external/ge_common/ge_api_error_codes.h" -#include "ge/llm_error_codes.h" - -namespace mindspore { -struct LLMModelInfo { - std::vector model_infos; - std::map options; - LLMEngineModelInfo postprocess_model; -}; - -class LLMEnginePlugin : public LLMEnginePluginBase { - public: - LLMEnginePlugin(LLMRole role, uint64_t cluster_id, const std::string &batch_mode) - : LLMEnginePluginBase(role, cluster_id, batch_mode), llm_engine_(std::make_shared(cluster_id_)) {} - ~LLMEnginePlugin(); - Status AddModel(const std::vector &model_infos, const std::map &options, - const LLMEngineModelInfo &postprocess_model, uint64_t *model_id) override; - Status Init(const std::map &options) override; - void Finalize() override; - LLMEngineStatus FetchStatus() override; - - Status Predict(const LLMReq &req, const std::vector &inputs, std::vector *outputs, - uint64_t model_id) override; - - Status Predict(const std::vector &req, const std::vector &inputs, std::vector *outputs, - uint64_t model_id) override; - Status CompleteRequest(const LLMReq &req) override; - Status PreloadPromptPrefix(const LLMReq &req, const std::vector &inputs, uint64_t model_id) override; - Status ReleasePromptPrefix(const LLMReq &req, uint64_t model_id) override; - - Status PullKV(const LLMReq &req, uint64_t model_id) override; - Status MergeKV(const LLMReq &req, uint32_t batch_index, uint32_t batch_id, uint64_t model_id) override; - - Status LinkClusters(const std::vector &, std::vector *rets, int32_t timeout) override; - Status UnlinkClusters(const std::vector &, std::vector *rets, int32_t timeout) override; - - private: - std::shared_ptr<::llm::LLMEngine> llm_engine_ = nullptr; - bool finalized_ = false; - bool inited_ = false; - std::map model_infos_; - int32_t num_total_blocks_ = 0; // For PagedAttention - int32_t block_size_ = 0; // For PagedAttention - int32_t max_seq_len_ = 0; // For PagedAttention - - MSTensor ConvertGeTensorNoCopy(::ge::Tensor *ge_tensor_ptr); - Status Run(const llm::LLMReq &req, const std::vector<::ge::Tensor> &ge_inputs, std::vector<::ge::Tensor> *ge_outputs, - uint64_t model_id); - Status Run(const std::vector &req, const std::vector<::ge::Tensor> &ge_inputs, - std::vector<::ge::Tensor> *ge_outputs, uint64_t model_id); - Status CheckModelInfos(const std::vector &model_infos); - void InitInputOptions(const LLMEngineModelInfo &model_info, bool postprocess, - std::map *options); - static void TransLLMReq(const LLMReq &req, llm::LLMReq *llm_req); - static void TransLLMClusterInfos(const std::vector &clusters, - std::vector *llm_clusters); - Status MSTensorToGeTensor(const std::vector &inputs, std::vector<::ge::Tensor> *ge_inputs); - Status OnGeStatus(ge::Status ge_status, const std::string &func_s, const std::string &phase); - void SetPagedAttentionOptions(std::map *options); -}; - -LLMEnginePluginBase *CreateLLMEnginePlugin(LLMRole role, uint64_t cluster_id, const std::string &batch_mode) { - return new LLMEnginePlugin(role, cluster_id, batch_mode); -} - -LLMEnginePlugin::~LLMEnginePlugin() { LLMEnginePlugin::Finalize(); } - -Status LLMEnginePlugin::CheckModelInfos(const std::vector &model_infos) { - for (size_t i = 1; i < model_infos.size(); i++) { - if (model_infos[i].input_shapes != model_infos[0].input_shapes) { - MS_LOG(ERROR) << "Model " << i << " input shapes " << model_infos[i].input_shapes << " != that " - << model_infos[0].input_shapes << " of model 0"; - return kLiteError; - } - if (model_infos[i].input_dtypes != model_infos[0].input_dtypes) { - MS_LOG(ERROR) << "Model " << i << " input dtypes " << model_infos[i].input_dtypes << " != that " - << model_infos[0].input_dtypes << " of model 0"; - return kLiteError; - } - if (model_infos[i].ref_input_shapes != model_infos[0].ref_input_shapes) { - MS_LOG(ERROR) << "Model " << i << " ref data input shapes " << model_infos[i].ref_input_shapes << " != that " - << model_infos[0].ref_input_shapes << " of model 0"; - return kLiteError; - } - if (model_infos[i].ref_input_dtypes != model_infos[0].ref_input_dtypes) { - MS_LOG(ERROR) << "Model " << i << " ref data input dtypes " << model_infos[i].ref_input_dtypes << " != that " - << model_infos[0].ref_input_dtypes << " of model 0"; - return kLiteError; - } - } - return kSuccess; -} - -void LLMEnginePlugin::InitInputOptions(const LLMEngineModelInfo &model_info, bool postprocess, - std::map *options_ptr) { - auto shape_as_string = [](const ShapeVector &shape) { - std::string str; - for (size_t i = 0; i < shape.size(); i++) { - str += std::to_string(shape[i]); - if (i + 1 < shape.size()) { - str += ","; - } - } - return str; - }; - auto dtype_as_string = [](TypeId type_id) { - auto ge_type = device::ascend::TransformUtil::ConvertDataType(type_id); - return std::to_string(static_cast(ge_type)); - }; - std::string input_shapes; - std::string input_dtypes; - std::string ref_input_shapes; - std::string ref_input_dtypes; - for (auto &item : model_info.input_shapes) { - input_shapes += shape_as_string(item) + ";"; - } - for (auto &item : model_info.input_dtypes) { - input_dtypes += dtype_as_string(item) + ";"; - } - for (auto &item : model_info.ref_input_shapes) { - ref_input_shapes += shape_as_string(item) + ";"; - } - for (auto &item : model_info.ref_input_dtypes) { - ref_input_dtypes += dtype_as_string(item) + ";"; - } - auto &options = *options_ptr; - auto erase_comma = [](const std::string &str) { return str.empty() ? str : str.substr(0, str.size() - 1); }; - if (!postprocess) { - options["llm.InputShapes"] = erase_comma(input_shapes); - options["llm.InputDtypes"] = erase_comma(input_dtypes); - options["llm.RefInputShapes"] = erase_comma(ref_input_shapes); - options["llm.RefInputDtypes"] = erase_comma(ref_input_dtypes); - options["llm.OutputNums"] = std::to_string(model_info.output_count); - } else { - options["llm.PostProcessInputShapes"] = erase_comma(input_shapes); - options["llm.PostProcessInputDtypes"] = erase_comma(input_dtypes); - options["llm.PostProcessOutputNums"] = std::to_string(model_info.output_count); - options["llm.PostProcessOmCachePath"] = model_info.weight_dir; - } -} - -using ErrorCodeMap = std::unordered_map>; - -static ErrorCodeMap error_map = { - {ge::GRAPH_SUCCESS, [](const std::string &func_s, const std::string &phase) { - MS_LOG(INFO) << "End call llm::LLMEngine::" << func_s; - return kSuccess;} - }, - {ge::LLM_WAIT_PROC_TIMEOUT, [](const std::string &func_s, const std::string &phase) { - MS_LOG(WARNING) << "Failed to call llm::LLMEngine::" << func_s << " " - << ", " << phase << " status: LLM_WAIT_PROC_TIMEOUT"; - return kLiteLLMWaitProcessTimeOut;} - }, - {ge::LLM_KV_CACHE_NOT_EXIST, [](const std::string &func_s, const std::string &phase) { - MS_LOG(WARNING) << "Failed to call llm::LLMEngine::" << func_s - << " " << phase << " status: LLM_KV_CACHE_NOT_EXIST"; - return kLiteLLMKVCacheNotExist;} - }, - {ge::LLM_REPEAT_REQUEST, [](const std::string &func_s, const std::string &phase) { - MS_LOG(ERROR) << "Failed to call llm::LLMEngine::" << func_s << " " << phase << " status: LLM_REPEAT_REQUEST"; - return kLiteLLMRepeatRequest;} - }, - {ge::LLM_REQUEST_ALREADY_COMPLETED, [](const std::string &func_s, const std::string &phase) { - MS_LOG(ERROR) << "Failed to call llm::LLMEngine::" << func_s << " " << phase - << " receive LLM_REQUEST_ALREADY_COMPLETED"; - return kLiteLLMRequestAlreadyCompleted;} - }, - {ge::LLM_ENGINE_FINALIZED, [](const std::string &func_s, const std::string &phase) { - MS_LOG(ERROR) << "Failed to call llm::LLMEngine::" << func_s << " " << phase << " status: LLM_ENGINE_FINALIZED"; - return kLiteLLMRequestAlreadyCompleted;} - }, - {ge::LLM_PARAM_INVALID, [](const std::string &func_s, const std::string &phase) { - MS_LOG(ERROR) << "Failed to call llm::LLMEngine::" << func_s << " " << phase << " status: LLM_PARAM_INVALID"; - return kLiteParamInvalid;} - }, - {ge::LLM_NOT_YET_LINK, [](const std::string &func_s, const std::string &phase) { - MS_LOG(ERROR) << "Failed to call llm::LLMEngine::" << func_s << " " << phase << " status: LLM_NOT_YET_LINK"; - return kLiteLLMNotYetLink;} - }, - {ge::LLM_ALREADY_LINK, [](const std::string &func_s, const std::string &phase) { - MS_LOG(ERROR) << "Failed to call llm::LLMEngine::" << func_s << " " << phase << " status: LLM_ALREADY_LINK"; - return kLiteLLMAlreadyLink;} - }, - {ge::LLM_LINK_FAILED, [](const std::string &func_s, const std::string &phase) { - MS_LOG(ERROR) << "Failed to call llm::LLMEngine::" << func_s << " " << phase << " status: LLM_LINK_FAILED"; - return kLiteLLMLinkFailed;} - }, - {ge::LLM_UNLINK_FAILED, [](const std::string &func_s, const std::string &phase) { - MS_LOG(ERROR) << "Failed to call llm::LLMEngine::" << func_s << " " << phase << " status: LLM_UNLINK_FAILED"; - return kLiteLLMUnlinkFailed;} - }, - {ge::LLM_NOTIFY_PROMPT_UNLINK_FAILED, [](const std::string &func_s, const std::string &phase) { - MS_LOG(ERROR) << "Failed to call llm::LLMEngine::" << func_s << " " - << phase << " status: LLM_NOTIFY_PROMPT_UNLINK_FAILED"; - return kLiteLLMNofiryPromptUnlinkFailed;} - }, - {ge::LLM_CLUSTER_NUM_EXCEED_LIMIT, [](const std::string &func_s, const std::string &phase) { - MS_LOG(ERROR) << "Failed to call llm::LLMEngine::" << func_s << " " - << phase << " status: LLM_CLUSTER_NUM_EXCEED_LIMIT"; - return kLiteLLMClusterNumExceedLimit;} - }, - {ge::LLM_PROCESSING_LINK, [](const std::string &func_s, const std::string &phase) { - MS_LOG(ERROR) << "Failed to call llm::LLMEngine::" << func_s << " " << phase << " status: LLM_PROCESSING_LINK"; - return kLiteLLMProcessingLink;} - }, - {ge::LLM_DEVICE_OUT_OF_MEMORY, [](const std::string &func_s, const std::string &phase) { - MS_LOG(ERROR) << "Failed to call llm::LLMEngine::" << func_s << " " - << phase << " status: LLM_DEVICE_OUT_OF_MEMORY"; - return kLiteLLMOutOfMemory;} - }, - {ge::LLM_PREFIX_ALREADY_EXIST, [](const std::string &func_s, const std::string &phase) { - MS_LOG(ERROR) << "Failed to call llm::LLMEngine::" << func_s << " " - << phase << " status: LLM_PREFIX_ALREADY_EXIST"; - return kLiteLLMPrefixAlreadyExist;} - }, - {ge::LLM_PREFIX_NOT_EXIST, [](const std::string &func_s, const std::string &phase) { - MS_LOG(ERROR) << "Failed to call llm::LLMEngine::" << func_s << " " - << phase << " status: LLM_PREFIX_NOT_EXIST"; - return kLiteLLMPrefixNotExist;} - }, - {ge::LLM_SEQ_LEN_OVER_LIMIT, [](const std::string &func_s, const std::string &phase) { - MS_LOG(ERROR) << "Failed to call llm::LLMEngine::" << func_s << " " - << phase << " status: LLM_SEQ_LEN_OVER_LIMIT"; - return kLiteLLMSeqLenOverLimit;} - }, - {ge::LLM_NO_FREE_BLOCK, [](const std::string &func_s, const std::string &phase) { - MS_LOG(ERROR) << "Failed to call llm::LLMEngine::" << func_s << " " - << phase << " status: LLM_NO_FREE_BLOCK"; - return kLiteLLMNoFreeBlock;} - }, - {ge::LLM_BLOCKS_OUT_OF_MEMORY, [](const std::string &func_s, const std::string &phase) { - MS_LOG(ERROR) << "Failed to call llm::LLMEngine::" << func_s << " " - << phase << " status: LLM_BLOCKS_OUT_OF_MEMORY"; - return kLiteLLMBlockOutOfMemory;} - } -}; - -Status LLMEnginePlugin::OnGeStatus(ge::Status ge_status, const std::string &func_s, const std::string &phase) { - Status lite_status; - if (error_map.count(ge_status) == 1) { - lite_status = error_map[ge_status](func_s, phase); - } else { - MS_LOG(ERROR) << "Failed to call llm::LLMEngine::" << func_s << " " << phase << " status: " << ge_status; - lite_status = kLiteError; - } - return lite_status; -} - -void LLMEnginePlugin::SetPagedAttentionOptions(std::map *options_ptr) { - auto &options = *options_ptr; - auto it = options.find("llm.EnablePagedAttention"); - if (it == options.end()) { - return; - } - auto option = lite::StringTolower(it->second); - if (option != "true" && option != "1") { - return; - } - if (model_infos_.size() != 1) { - MS_LOG(WARNING) << "PagedAttention is not supported when model count " << model_infos_.size() << " is not 1"; - return; - } - auto &model_infos = model_infos_.begin()->second.model_infos; - if (model_infos.empty()) { - MS_LOG(WARNING) << "Model count is model paths is 0"; - return; - } - if (model_infos[0].ref_input_shapes.empty()) { - MS_LOG(WARNING) << "Not found RefData input in model " << model_infos[0].name; - return; - } - auto kv_shape = model_infos[0].ref_input_shapes[0]; - if (kv_shape.size() != kShape3dDims) { - MS_LOG(WARNING) << "KVCache shape " << kv_shape << " is not 3D likes [num_blocks, block_size, hidden_size]"; - return; - } - num_total_blocks_ = kv_shape[kDim0]; - it = options.find("llm.PagedAttentionBlocksNum"); - if (it != options.end()) { - int64_t option_num_total_blocks = 0; - lite::ConvertStrToInt(it->second, &option_num_total_blocks); - MS_LOG(INFO) << "Found llm.PagedAttentionBlocksNum in model config, value: " << it->second; - if (num_total_blocks_ != option_num_total_blocks) { - MS_LOG(WARNING) << "llm.PagedAttentionBlocksNum got from option " << it->second << " != KVCache dim 0 value " - << num_total_blocks_ << ", KVCache shape " << kv_shape; - num_total_blocks_ = option_num_total_blocks; - } - } else { - options["llm.PagedAttentionBlocksNum"] = std::to_string(num_total_blocks_); - MS_LOG(INFO) << "Set model option llm.PagedAttentionBlocksNum to value: " << num_total_blocks_ << ", KVCache shape " - << kv_shape; - } - block_size_ = kv_shape[kDim1]; - it = options.find("llm.PagedAttentionBlockSize"); - if (it != options.end()) { - int64_t option_block_size = 0; - lite::ConvertStrToInt(it->second, &option_block_size); - MS_LOG(INFO) << "Found llm.PagedAttentionBlockSize in model config, value: " << it->second; - if (block_size_ != option_block_size) { - MS_LOG(WARNING) << "llm.PagedAttentionBlockSize got from option " << it->second << " != KVCache dim 1 value " - << block_size_ << ", KVCache shape " << kv_shape; - block_size_ = option_block_size; - } - } else { - options["llm.PagedAttentionBlockSize"] = std::to_string(block_size_); - MS_LOG(INFO) << "Set model option llm.PagedAttentionBlockSize to value: " << block_size_ << ", KVCache shape " - << kv_shape; - } - if (model_infos[0].input_shapes.empty()) { - MS_LOG(WARNING) << "Not found Data input in model " << model_infos[0].name; - return; - } - auto block_table_shape = *model_infos[0].input_shapes.rbegin(); - if (block_table_shape.size() != kShape2dDims) { - MS_LOG(WARNING) << "Block table(last input) shape " << block_table_shape - << " is not 2D likes [batch_size, max_block_num_per_seq]"; - return; - } - max_seq_len_ = block_table_shape[kDim1] * block_size_; - it = options.find("llm.PagedAttentionMaxSeqLen"); - if (it != options.end()) { - int64_t option_max_seq_len = 0; - lite::ConvertStrToInt(it->second, &option_max_seq_len); - MS_LOG(INFO) << "Found llm.PagedAttentionMaxSeqLen in model config, value: " << it->second; - if (max_seq_len_ != option_max_seq_len) { - MS_LOG(WARNING) << "llm.PagedAttentionMaxSeqLen got from option " << it->second << " != " << max_seq_len_ - << " calculated from block_table(last input) shape " << block_table_shape << " and block_size " - << block_size_; - max_seq_len_ = option_max_seq_len; - } - } else { - options["llm.PagedAttentionMaxSeqLen"] = std::to_string(max_seq_len_); - MS_LOG(INFO) << "Set model option llm.PagedAttentionMaxSeqLen to value: " << max_seq_len_ - << ", block_table(last input) shape " << block_table_shape << ", block_size " << block_size_; - } -} - -Status LLMEnginePlugin::AddModel(const std::vector &model_infos, - const std::map &options_i, - const LLMEngineModelInfo &postprocess_model, uint64_t *model_id) { - if (model_infos.empty()) { - MS_LOG(ERROR) << "Model infos cannot be empty"; - return kLiteError; - } - if (model_id == nullptr) { - MS_LOG(ERROR) << "Input argument model_id is nullptr"; - return kLiteError; - } - if (finalized_) { - MS_LOG(ERROR) << "LLMEngine has been finalized"; - return kLiteLLMEngineFinalized; - } - if (inited_) { - MS_LOG(ERROR) << "LLMEngine has been inited"; - return kLiteError; - } - if (llm_engine_ == nullptr) { - MS_LOG(ERROR) << "LLMEngine object is nullptr"; - return kLiteError; - } - MS_LOG(INFO) << "LLMEngine AddLLMModel begin"; - auto options = options_i; - if (CheckModelInfos(model_infos) != kSuccess) { - return kLiteError; - } - InitInputOptions(model_infos[0], false, &options); - auto option_it = options.find("llm.OmCachePath"); - if (option_it == options.end()) { - std::string cache_path; - for (size_t i = 0; i < model_infos.size(); i++) { - cache_path += model_infos[i].weight_dir; - if (i + 1 < model_infos.size()) { - cache_path += ";"; - } - } - MS_LOG(INFO) << "Add option llm.OmCachePath to " << cache_path; - options["llm.OmCachePath"] = cache_path; - } - std::vector model_buffers; - for (auto &item : model_infos) { - ge::ModelBufferData buff; - buff.data = std::shared_ptr(reinterpret_cast(item.om_data->data_c()), [](uint8_t *) {}); - buff.length = item.om_data->Size(); - model_buffers.push_back(buff); - MS_LOG(INFO) << "Inference model " << item.name << ", model buffer size " << item.om_data->Size(); - } - std::map> model_buffers_map; - model_buffers_map["inference"] = model_buffers; - if (postprocess_model.om_data != nullptr) { - InitInputOptions(postprocess_model, true, &options); - - ge::ModelBufferData postprocess_buff; - postprocess_buff.data = - std::shared_ptr(reinterpret_cast(postprocess_model.om_data->data_c()), [](uint8_t *) {}); - postprocess_buff.length = postprocess_model.om_data->Size(); - MS_LOG(INFO) << "Postprocess model " << postprocess_model.name << ", model buffer size " - << postprocess_model.om_data->Size(); - model_buffers_map["postprocess"] = {postprocess_buff}; - } - std::map model_options; - for (auto &option : options) { - model_options[ge::AscendString(option.first.c_str())] = ge::AscendString(option.second.c_str()); - MS_LOG(INFO) << "AddLLMModel option " << option.first << " = " << option.second; - } - MS_LOG(INFO) << "Start to call llm::LLMEngine::LLMEngineInitializeV2"; - auto ge_status = llm_engine_->AddLLMModel(model_buffers_map, model_options, *model_id); - if (ge_status != ge::GRAPH_SUCCESS) { - return OnGeStatus(ge_status, "AddLLMModel", "return"); - } - LLMModelInfo info; - info.model_infos = model_infos; - info.postprocess_model = postprocess_model; - info.options = options; - model_infos_[*model_id] = info; - MS_LOG(INFO) << "LLMEngine AddLLMModel end"; - return kSuccess; -} - -Status LLMEnginePlugin::Init(const std::map &options_i) { - if (finalized_) { - MS_LOG(ERROR) << "LLMEngine has been finalized"; - return kLiteLLMEngineFinalized; - } - if (inited_) { - MS_LOG(ERROR) << "LLMEngine has been inited"; - return kLiteError; - } - if (llm_engine_ == nullptr) { - MS_LOG(ERROR) << "LLMEngine object is nullptr"; - return kLiteError; - } - MS_LOG(INFO) << "LLMEngine Init begin"; - auto options = options_i; - options["llm.Role"] = role_ == LLMRole::kLLMRolePrompt ? "Prompt" : "Decoder"; - options["llm.batch_mode"] = batch_mode_; - SetPagedAttentionOptions(&options); - std::map init_options; - for (auto &option : options) { - init_options[ge::AscendString(option.first.c_str())] = ge::AscendString(option.second.c_str()); - MS_LOG(INFO) << "LLMEngineInitializeV2 option " << option.first << " = " << option.second; - } - MS_LOG(INFO) << "Start to call llm::LLMEngine::LLMEngineInitialize"; - auto ge_status = llm_engine_->LLMEngineInitializeV2({}, init_options); - if (ge_status != ge::GRAPH_SUCCESS) { - return OnGeStatus(ge_status, "LLMEngineInitializeV2", "return"); - } - model_infos_.clear(); - inited_ = true; - MS_LOG(INFO) << "LLMEngine Init end"; - return kSuccess; -} - -void LLMEnginePlugin::Finalize() { - if (llm_engine_ != nullptr) { - MS_LOG(INFO) << "Start to call LLMEngineFinalize"; - auto ge_status = llm_engine_->LLMEngineFinalize(); - llm_engine_ = nullptr; - finalized_ = true; - if (ge_status != ge::GRAPH_SUCCESS) { - MS_LOG(ERROR) << "Failed to call LLMEngineFinalize, status: " << ge_status; - return; - } - MS_LOG(INFO) << "End to call LLMEngineFinalize"; - } -} - -Status LLMEnginePlugin::Run(const llm::LLMReq &llm_req, const std::vector<::ge::Tensor> &ge_inputs, - std::vector<::ge::Tensor> *outputs, uint64_t model_id) { - auto time_start = std::chrono::system_clock::now(); - - if (role_ == kLLMRolePrompt) { - MS_LOG(INFO) << "Start to call llm::LLMEngine::RunPrompt"; - auto ge_status = llm_engine_->RunPrompt(llm_req, ge_inputs, *outputs, model_id); - if (ge_status != ge::GRAPH_SUCCESS) { - return OnGeStatus(ge_status, "RunPrompt", "return"); - } - } else { - if (model_id != 0) { - MS_LOG(ERROR) << "Decoder only support manual mode when there are more than one LLM Model, current mode " - << batch_mode_; - return kLiteError; - } - MS_LOG(INFO) << "Start to call llm::LLMEngine::RunDecoder"; - auto ge_status = llm_engine_->RunDecoder(llm_req, ge_inputs, *outputs); - if (ge_status != ge::GRAPH_SUCCESS) { - return OnGeStatus(ge_status, "RunDecoder", "return"); - } - } - auto time_cost = - std::chrono::duration_cast(std::chrono::system_clock::now() - time_start).count(); - MS_LOG(INFO) << "Call LLMEngine RunPrompt or RunDecoder Success in " << time_cost << " us, role " - << (role_ == LLMRole::kLLMRolePrompt ? "Prompt" : "Decoder") << ", outputs num is: " << outputs->size(); - return kSuccess; -} - -Status LLMEnginePlugin::Run(const std::vector &llm_req, const std::vector<::ge::Tensor> &ge_inputs, - std::vector<::ge::Tensor> *outputs, uint64_t model_id) { - auto time_start = std::chrono::system_clock::now(); - - if (role_ == kLLMRolePrompt) { - MS_LOG(INFO) << "Start to call llm::LLMEngine::RunPrompt"; - auto ge_status = llm_engine_->RunPrompt(llm_req, ge_inputs, *outputs, model_id); - if (ge_status != ge::GRAPH_SUCCESS) { - return OnGeStatus(ge_status, "RunPrompt", "return"); - } - } else { - MS_LOG(INFO) << "Start to call llm::LLMEngine::RunDecoder"; - auto ge_status = llm_engine_->RunDecoder(llm_req, ge_inputs, *outputs, model_id); - if (ge_status != ge::GRAPH_SUCCESS) { - return OnGeStatus(ge_status, "RunDecoder", "return"); - } - } - auto time_cost = - std::chrono::duration_cast(std::chrono::system_clock::now() - time_start).count(); - MS_LOG(INFO) << "Call LLMEngine RunPrompt or RunDecoder Success in " << time_cost << " us, role " - << (role_ == LLMRole::kLLMRolePrompt ? "Prompt" : "Decoder") << ", outputs num is: " << outputs->size(); - return kSuccess; -} - -void LLMEnginePlugin::TransLLMReq(const LLMReq &req, llm::LLMReq *llm_req_ptr) { - if (llm_req_ptr == nullptr) { - MS_LOG(ERROR) << "Input argument llm_req_ptr is nullptr"; - return; - } - llm::LLMReq &llm_req = *llm_req_ptr; - llm_req.SetReqId(req.req_id); - llm_req.SetPromptLength(req.prompt_length); - llm_req.SetPromptClusterId(req.prompt_cluster_id); - llm_req.SetDecoderClusterId(req.decoder_cluster_id); - llm_req.SetPrefixId(req.prefix_id); - llm_req.SetSequenceLen(req.sequence_length); -} - -void LLMEnginePlugin::TransLLMClusterInfos(const std::vector &clusters, - std::vector *llm_clusters_ptr) { - if (llm_clusters_ptr == nullptr) { - MS_LOG(ERROR) << "Input argument llm_clusters_ptr is nullptr"; - return; - } - auto &llm_clusters = *llm_clusters_ptr; - for (auto &cluster : clusters) { - llm::ClusterInfo llm_cluster; - llm_cluster.remote_cluster_id = cluster.remote_cluster_id; - llm_cluster.remote_role_type = cluster.remote_role_type; - for (auto &item : cluster.local_ip_infos) { - llm::IpInfo llm_ip_info; - llm_ip_info.ip = item.ip; - llm_ip_info.port = item.port; - llm_cluster.local_ip_infos.push_back(llm_ip_info); - } - for (auto &item : cluster.remote_ip_infos) { - llm::IpInfo llm_ip_info; - llm_ip_info.ip = item.ip; - llm_ip_info.port = item.port; - llm_cluster.remote_ip_infos.push_back(llm_ip_info); - } - llm_clusters.push_back(llm_cluster); - } -} - -Status LLMEnginePlugin::MSTensorToGeTensor(const std::vector &inputs, std::vector<::ge::Tensor> *ge_inputs) { - for (size_t i = 0; i < inputs.size(); i++) { - auto &input = inputs[i]; - MS_LOG(INFO) << "Input " << i << " shape " << input.Shape() << ", datatype " << input.DataType(); - // create ge tensor - auto desc = device::ascend::TransformUtil::GetGeTensorDesc(input.Shape(), static_cast(input.DataType()), - kOpFormat_NCHW); - if (desc == nullptr) { - MS_LOG(ERROR) << "Failed to get Tensor Desc"; - return kLiteError; - } - ge::Tensor tensor(*desc); - auto data = reinterpret_cast(const_cast(input.Data().get())); - auto ret = tensor.SetData(data, input.DataSize(), [](uint8_t *) -> void {}); - if (ret != ge::GRAPH_SUCCESS) { - MS_LOG(ERROR) << "Failed to call ge::Tensor SetData(uint8_t*, size, DeleteFunc), data size " << input.DataSize(); - return kLiteError; - } - ge_inputs->emplace_back(tensor); - } - return kSuccess; -} - -Status LLMEnginePlugin::Predict(const LLMReq &req, const std::vector &inputs, std::vector *outputs, - uint64_t model_id) { - if (outputs == nullptr) { - MS_LOG(ERROR) << "Input argument outputs is nullptr"; - return kLiteError; - } - if (finalized_) { - MS_LOG(ERROR) << "LLMEngine has been finalized"; - return kLiteLLMEngineFinalized; - } - if (!inited_) { - MS_LOG(ERROR) << "LLMEngine has not been inited or inited failed"; - return kLiteError; - } - if (llm_engine_ == nullptr) { - MS_LOG(ERROR) << "LLMEngine object is nullptr"; - return kLiteError; - } - std::vector<::ge::Tensor> ge_inputs; - auto ret = MSTensorToGeTensor(inputs, &ge_inputs); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Failed to transform MSTensor to Ge Tensor"; - return ret; - } - llm::LLMReq llm_req; - TransLLMReq(req, &llm_req); - MS_LOG(INFO) << "Start to call predict, req_id " << llm_req.GetReqId() << ", prompt_length " - << llm_req.GetPromptLength() << ", prompt_cluster_id: " << llm_req.GetPromptClusterId() - << ", decoder_cluster_id: " << llm_req.GetDecoderClusterId() << ", prefix id " << llm_req.GetPrefixId(); - std::vector<::ge::Tensor> ge_outputs; - ret = Run(llm_req, ge_inputs, &ge_outputs, model_id); - if (ret != kSuccess) { - return ret; - } - for (size_t i = 0; i < ge_outputs.size(); i++) { - auto &ge_tensor = ge_outputs[i]; - auto ms_tensor = ConvertGeTensorNoCopy(&ge_tensor); - if (ms_tensor == nullptr) { - MS_LOG(ERROR) << "Failed to converter output " << i << " GE Tensor to ME Tensor"; - return kLiteError; - } - MS_LOG(INFO) << "Output " << i << " shape " << ms_tensor.Shape() << ", datatype " << ms_tensor.DataType(); - outputs->push_back(ms_tensor); - } - return kSuccess; -} - -Status LLMEnginePlugin::Predict(const std::vector &req, const std::vector &inputs, - std::vector *outputs, uint64_t model_id) { - if (outputs == nullptr) { - MS_LOG(ERROR) << "Input argument outputs is nullptr"; - return kLiteError; - } - if (finalized_) { - MS_LOG(ERROR) << "LLMEngine has been finalized"; - return kLiteLLMEngineFinalized; - } - if (!inited_) { - MS_LOG(ERROR) << "LLMEngine has not been inited or inited failed"; - return kLiteError; - } - if (llm_engine_ == nullptr) { - MS_LOG(ERROR) << "LLMEngine object is nullptr"; - return kLiteError; - } - std::vector<::ge::Tensor> ge_inputs; - auto ret = MSTensorToGeTensor(inputs, &ge_inputs); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Failed to transform MSTensor to Ge Tensor"; - return ret; - } - MS_LOG(INFO) << "Start to call predict, requests: "; - std::vector llm_reqs; - (void)std::transform(req.begin(), req.end(), std::back_inserter(llm_reqs), [](const LLMReq &item) { - llm::LLMReq llm_req; - TransLLMReq(item, &llm_req); - MS_LOG(INFO) << "req_id " << llm_req.GetReqId() << ", prompt_length " << llm_req.GetPromptLength() - << ", prompt_cluster_id: " << llm_req.GetPromptClusterId() - << ", decoder_cluster_id: " << llm_req.GetDecoderClusterId() << ", prefix id " - << llm_req.GetPrefixId(); - return llm_req; - }); - std::vector<::ge::Tensor> ge_outputs; - ret = Run(llm_reqs, ge_inputs, &ge_outputs, model_id); - if (ret != kSuccess) { - return ret; - } - for (size_t i = 0; i < ge_outputs.size(); i++) { - auto &ge_tensor = ge_outputs[i]; - auto ms_tensor = ConvertGeTensorNoCopy(&ge_tensor); - if (ms_tensor == nullptr) { - MS_LOG(ERROR) << "Failed to converter output " << i << " GE Tensor to ME Tensor"; - return kLiteError; - } - MS_LOG(INFO) << "Output " << i << " shape " << ms_tensor.Shape() << ", datatype " << ms_tensor.DataType(); - outputs->push_back(ms_tensor); - } - return kSuccess; -} - -Status LLMEnginePlugin::PullKV(const LLMReq &req, uint64_t model_id) { - if (!inited_) { - MS_LOG(ERROR) << "LLMEngine has not been inited or inited failed"; - return kLiteError; - } - if (llm_engine_ == nullptr) { - MS_LOG(ERROR) << "LLMEngine object is nullptr"; - return kLiteError; - } - llm::LLMReq llm_req; - TransLLMReq(req, &llm_req); - MS_LOG(INFO) << "Start to call PullKv, req_id " << llm_req.GetReqId() << ", prompt_length " - << llm_req.GetPromptLength() << ", prompt_cluster_id: " << llm_req.GetPromptClusterId() - << ", decoder_cluster_id: " << llm_req.GetDecoderClusterId() << ", prefix_id " << llm_req.GetPrefixId() - << ", model_id " << model_id; - auto ge_ret = llm_engine_->PullKv(llm_req, model_id); - return OnGeStatus(ge_ret, "PullKv", "return"); -} - -Status LLMEnginePlugin::MergeKV(const LLMReq &req, uint32_t batch_index, uint32_t batch_id, uint64_t model_id) { - if (!inited_) { - MS_LOG(ERROR) << "LLMEngine has not been inited or inited failed"; - return kLiteError; - } - if (llm_engine_ == nullptr) { - MS_LOG(ERROR) << "LLMEngine object is nullptr"; - return kLiteError; - } - MS_LOG(INFO) << "Start to call MergeKV, req_id " << req.req_id << ", batch_index " << batch_index << ", batch_id " - << batch_id << ", model_id " << model_id; - auto ge_ret = llm_engine_->MergeKv(req.req_id, batch_index, batch_id, model_id); - return OnGeStatus(ge_ret, "MergeKV", "return"); -} - -Status LLMEnginePlugin::CompleteRequest(const LLMReq &req) { - if (finalized_) { - MS_LOG(ERROR) << "LLMEngine has been finalized"; - return kLiteLLMEngineFinalized; - } - if (!inited_) { - MS_LOG(ERROR) << "LLMEngine has not been inited or inited failed"; - return kLiteError; - } - if (llm_engine_ == nullptr) { - MS_LOG(ERROR) << "LLMEngine object is nullptr"; - return kLiteError; - } - MS_LOG(INFO) << "Start to call llm::LLMEngine::LLMReqComplete, req_id " << req.req_id << ", prompt_length " - << req.prompt_length << ", prompt_cluster_id: " << req.prompt_cluster_id; - llm::LLMReq llm_req; - TransLLMReq(req, &llm_req); - auto ge_ret = llm_engine_->LLMReqComplete(llm_req); - return OnGeStatus(ge_ret, "LLMReqComplete", "return"); -} - -LLMEngineStatus LLMEnginePlugin::FetchStatus() { - if (finalized_) { - MS_LOG(ERROR) << "LLMEngine has been finalized"; - return LLMEngineStatus(); - } - if (!inited_) { - MS_LOG(ERROR) << "LLMEngine has not been inited or inited failed"; - return LLMEngineStatus(); - } - if (llm_engine_ == nullptr) { - MS_LOG(ERROR) << "LLMEngine object is nullptr"; - return LLMEngineStatus(); - } - LLMEngineStatus status; - auto llm_status = llm_engine_->FetchLLMEngineStatus(); - status.empty_max_prompt_kv = llm_status.empty_max_prompt_kv; - status.num_free_blocks = llm_status.num_free_blocks; - status.num_total_blocks = num_total_blocks_; - status.block_size = block_size_; - return status; -} - -Status LLMEnginePlugin::PreloadPromptPrefix(const LLMReq &req, const std::vector &inputs, uint64_t model_id) { - if (finalized_) { - MS_LOG(ERROR) << "LLMEngine has been finalized"; - return kLiteLLMEngineFinalized; - } - if (!inited_) { - MS_LOG(ERROR) << "LLMEngine has not been inited or inited failed"; - return kLiteError; - } - if (llm_engine_ == nullptr) { - MS_LOG(ERROR) << "LLMEngine object is nullptr"; - return kLiteError; - } - MS_LOG(INFO) << "Start to call llm::LLMEngine::PreloadPromptPrefix, req_id " << req.req_id << ", prompt_length " - << req.prompt_length << ", prompt_cluster_id: " << req.prompt_cluster_id << ", prefix_id " - << req.prefix_id << ", model_id " << model_id; - std::vector<::ge::Tensor> ge_inputs; - auto ret = MSTensorToGeTensor(inputs, &ge_inputs); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Failed to transform MSTensor to Ge Tensor"; - return ret; - } - llm::LLMReq llm_req; - TransLLMReq(req, &llm_req); - auto ge_ret = llm_engine_->PreloadPromptPrefix(llm_req, ge_inputs, model_id); - return OnGeStatus(ge_ret, "PreloadPromptPrefix", "return"); -} - -Status LLMEnginePlugin::ReleasePromptPrefix(const LLMReq &req, uint64_t model_id) { - if (finalized_) { - MS_LOG(ERROR) << "LLMEngine has been finalized"; - return kLiteLLMEngineFinalized; - } - if (!inited_) { - MS_LOG(ERROR) << "LLMEngine has not been inited or inited failed"; - return kLiteError; - } - if (llm_engine_ == nullptr) { - MS_LOG(ERROR) << "LLMEngine object is nullptr"; - return kLiteError; - } - MS_LOG(INFO) << "Start to call llm::LLMEngine::ReleasePromptPrefix, req_id " << req.req_id << ", prompt_length " - << req.prompt_length << ", prompt_cluster_id: " << req.prompt_cluster_id << ", prefix_id " - << req.prefix_id << ", model_id " << model_id; - llm::LLMReq llm_req; - TransLLMReq(req, &llm_req); - auto ge_ret = llm_engine_->ReleasePromptPrefix(llm_req, model_id); - return OnGeStatus(ge_ret, "ReleasePromptPrefix", "return"); -} - -Status LLMEnginePlugin::LinkClusters(const std::vector &clusters, std::vector *rets, - int32_t timeout) { - if (finalized_) { - MS_LOG(ERROR) << "LLMEngine has been finalized"; - return kLiteLLMEngineFinalized; - } - if (!inited_) { - MS_LOG(ERROR) << "LLMEngine has not been inited or inited failed"; - return kLiteError; - } - if (llm_engine_ == nullptr) { - MS_LOG(ERROR) << "LLMEngine object is nullptr"; - return kLiteError; - } - if (rets == nullptr) { - MS_LOG(ERROR) << "Input argument rets is nullptr"; - return kLiteError; - } - MS_LOG(INFO) << "Start to call llm::LLMEngine::LinkClusters, cluster size " << clusters.size(); - std::function ip_info_as_str = [](const LLMIpInfo &info) { - return std::to_string(info.ip) + ":" + std::to_string(info.port); - }; - std::vector llm_clusters; - for (size_t i = 0; i < clusters.size(); i++) { - auto &cluster = clusters[i]; - MS_LOG(INFO) << "Cluster " << i << ", remote_cluster_id " << cluster.remote_cluster_id << ", remote_role_type " - << cluster.remote_role_type; - MS_LOG(INFO) << "local ip infos: " << lite::VectorToStr(cluster.local_ip_infos, ip_info_as_str); - MS_LOG(INFO) << "remote ip infos: " << lite::VectorToStr(cluster.remote_ip_infos, ip_info_as_str); - } - TransLLMClusterInfos(clusters, &llm_clusters); - std::vector ge_rets; - auto ret = llm_engine_->LinkClusters(llm_clusters, ge_rets, timeout); - if (!ge_rets.empty() && llm_clusters.size() != ge_rets.size()) { - MS_LOG(ERROR) << "Cluster info size " << llm_clusters.size() << "!=" - << " LinkClusters rets size " << ge_rets.size(); - return kLiteError; - } - for (size_t i = 0; i < ge_rets.size(); i++) { - auto ge_ret = ge_rets[i]; - if (ge_ret != ge::GRAPH_SUCCESS) { - rets->push_back(OnGeStatus(ge_ret, "LinkClusters", "return")); - auto &cluster = clusters[i]; - MS_LOG(ERROR) << "Cluster " << i << " error occur, ge error code " << ge_ret << ", remote_cluster_id " - << cluster.remote_cluster_id << ", remote_role_type " << cluster.remote_role_type - << ", local ip infos: " << lite::VectorToStr(cluster.local_ip_infos, ip_info_as_str) - << "remote ip infos: " << lite::VectorToStr(cluster.remote_ip_infos, ip_info_as_str); - } else { - rets->push_back(kSuccess); - } - } - return OnGeStatus(ret, "LinkClusters", "return"); -} - -Status LLMEnginePlugin::UnlinkClusters(const std::vector &clusters, std::vector *rets, - int32_t timeout) { - if (finalized_) { - MS_LOG(ERROR) << "LLMEngine has been finalized"; - return kLiteLLMEngineFinalized; - } - if (!inited_) { - MS_LOG(ERROR) << "LLMEngine has not been inited or inited failed"; - return kLiteError; - } - if (llm_engine_ == nullptr) { - MS_LOG(ERROR) << "LLMEngine object is nullptr"; - return kLiteError; - } - if (rets == nullptr) { - MS_LOG(ERROR) << "Input argument rets is nullptr"; - return kLiteError; - } - MS_LOG(INFO) << "Start to call llm::LLMEngine::UnlinkClusters, cluster size " << clusters.size(); - std::function ip_info_as_str = [](const LLMIpInfo &info) { - return std::to_string(info.ip) + ":" + std::to_string(info.port); - }; - std::vector llm_clusters; - for (size_t i = 0; i < clusters.size(); i++) { - auto &cluster = clusters[i]; - MS_LOG(INFO) << "Cluster " << i << ", remote_cluster_id " << cluster.remote_cluster_id << ", remote_role_type " - << cluster.remote_role_type; - MS_LOG(INFO) << "local ip infos: " << lite::VectorToStr(cluster.local_ip_infos, ip_info_as_str); - MS_LOG(INFO) << "remote ip infos: " << lite::VectorToStr(cluster.remote_ip_infos, ip_info_as_str); - } - TransLLMClusterInfos(clusters, &llm_clusters); - std::vector ge_rets; - auto ret = llm_engine_->UnlinkClusters(llm_clusters, ge_rets, timeout); - if (!ge_rets.empty() && llm_clusters.size() != ge_rets.size()) { - MS_LOG(ERROR) << "Cluster info size " << llm_clusters.size() << "!=" - << " UnlinkClusters rets size " << ge_rets.size(); - return kLiteError; - } - for (size_t i = 0; i < ge_rets.size(); i++) { - auto ge_ret = ge_rets[i]; - if (ge_ret != ge::GRAPH_SUCCESS) { - rets->push_back(OnGeStatus(ge_ret, "UnlinkClusters", "return")); - auto &cluster = clusters[i]; - MS_LOG(ERROR) << "Cluster " << i << " error occur, ge error code " << ge_ret << ", remote_cluster_id " - << cluster.remote_cluster_id << ", remote_role_type " << cluster.remote_role_type - << ", local ip infos: " << lite::VectorToStr(cluster.local_ip_infos, ip_info_as_str) - << "remote ip infos: " << lite::VectorToStr(cluster.remote_ip_infos, ip_info_as_str); - } else { - rets->push_back(kSuccess); - } - } - return OnGeStatus(ret, "UnlinkClusters", "return"); -} - -MSTensor LLMEnginePlugin::ConvertGeTensorNoCopy(::ge::Tensor *ge_tensor_ptr) { - auto &ge_tensor = *ge_tensor_ptr; - auto ge_tensor_desc = ge_tensor.GetTensorDesc(); - auto me_shape = device::ascend::TransformUtil::ConvertGeShape(ge_tensor_desc.GetShape()); - if (ge_tensor_desc.GetPlacement() != ::ge::kPlacementHost) { - MS_LOG(ERROR) << "It is not supported that graph output data's placement is device now."; - return MSTensor(nullptr); - } - auto &&ge_data_uni = ge_tensor.ResetData(); - auto deleter = ge_data_uni.get_deleter(); - auto ge_data = ge_data_uni.release(); - if (ge_data == nullptr) { - MS_LOG(ERROR) << "Ge data cannot be nullptr"; - return MSTensor(nullptr); - } - constexpr int64_t kTensorAlignBytes = 64; - if (reinterpret_cast(ge_data) % kTensorAlignBytes != 0) { - MS_LOG(ERROR) << "Skip zero-copy ge tensor " << reinterpret_cast(ge_data) - << ", bytes not aligned with expected."; - return MSTensor(nullptr); - } - int64_t elem_num = 1; - for (size_t i = 0; i < me_shape.size(); ++i) { - elem_num *= me_shape[i]; - } - auto tensor_data = std::make_shared(ge_data, elem_num, ge_tensor.GetSize(), me_shape.size(), deleter); - auto type_id = device::ascend::TransformUtil::ConvertGeDataType(ge_tensor_desc.GetDataType()); - auto tensor = std::make_shared(type_id, me_shape, tensor_data); - auto tensor_impl = std::make_shared(tensor); - return MSTensor(tensor_impl); -} -} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/cxx_api/llm_engine/llm_engine_plugin.h b/mindspore-lite/src/extendrt/cxx_api/llm_engine/llm_engine_plugin.h deleted file mode 100644 index 87fcb6249316029ba73e9dcf9b9dc553b951c8b5..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/cxx_api/llm_engine/llm_engine_plugin.h +++ /dev/null @@ -1,82 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_CXX_API_LLM_ENGINE_PLUGIN_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_CXX_API_LLM_ENGINE_PLUGIN_H_ -#include -#include -#include -#include -#include -#include -#include "include/api/context.h" -#include "include/api/model.h" -#include "include/api/graph.h" -#include "include/api/serialization.h" -#include "extendrt/cxx_api/graph/graph_data.h" -#include "include/common/utils/utils.h" -#include "ir/func_graph.h" -#include "extendrt/infer_session.h" -#include "src/common/config_infos.h" -#include "mindapi/ir/common.h" -#include "extendrt/cxx_api/llm_engine/llm_engine.h" - -namespace mindspore { -struct LLMEngineModelInfo { - std::string name; - tensor::TensorPtr om_data = nullptr; - std::vector input_names; - std::vector input_shapes; - std::vector input_dtypes; - std::vector ref_input_shapes; - std::vector ref_input_dtypes; - size_t output_count = 0; - std::string weight_dir; -}; - -class LLMEnginePluginBase { - public: - LLMEnginePluginBase(LLMRole role, uint64_t cluster_id, const std::string &batch_mode) - : role_(role), cluster_id_(cluster_id), batch_mode_(batch_mode) {} - virtual ~LLMEnginePluginBase() = default; - - virtual Status AddModel(const std::vector &model_infos, - const std::map &options, - const LLMEngineModelInfo &postprocess_model, uint64_t *model_id) = 0; - virtual Status Init(const std::map &options) = 0; - virtual void Finalize() = 0; - virtual Status Predict(const LLMReq &req, const std::vector &inputs, std::vector *outputs, - uint64_t model_id) = 0; - virtual Status Predict(const std::vector &req, const std::vector &inputs, - std::vector *outputs, uint64_t model_id) = 0; - virtual Status CompleteRequest(const LLMReq &req) = 0; - virtual Status PreloadPromptPrefix(const LLMReq &req, const std::vector &inputs, uint64_t model_id) = 0; - virtual Status ReleasePromptPrefix(const LLMReq &req, uint64_t model_id) = 0; - virtual Status PullKV(const LLMReq &req, uint64_t model_id) = 0; - virtual Status MergeKV(const LLMReq &req, uint32_t batch_index, uint32_t batch_id, uint64_t model_id) = 0; - - virtual LLMEngineStatus FetchStatus() = 0; - virtual Status LinkClusters(const std::vector &, std::vector *rets, int32_t timeout) = 0; - virtual Status UnlinkClusters(const std::vector &, std::vector *rets, int32_t timeout) = 0; - - protected: - LLMRole role_ = kLLMRolePrompt; - uint64_t cluster_id_ = 0; - std::string batch_mode_; -}; - -extern "C" MS_API LLMEnginePluginBase *CreateLLMEnginePlugin(LLMRole, uint64_t, const std::string &); -} // namespace mindspore -#endif // MINDSPORE_LITE_SRC_EXTENDRT_CXX_API_LLM_ENGINE_PLUGIN_H_ diff --git a/mindspore-lite/src/extendrt/cxx_api/model/model_impl.cc b/mindspore-lite/src/extendrt/cxx_api/model/model_impl.cc index 749b33eaa50001a8191eff13cf06deeda6bd6885..4d1eb6e8ad99cf4060fbd48d222bbf87e04a8237 100644 --- a/mindspore-lite/src/extendrt/cxx_api/model/model_impl.cc +++ b/mindspore-lite/src/extendrt/cxx_api/model/model_impl.cc @@ -39,8 +39,7 @@ #include "mindapi/base/base.h" #include "src/extendrt/delegate/graph_executor/litert/func_graph_reuse_manager.h" #include "load_mindir/load_model.h" -#include "src/extendrt/delegate/plugin/tensorrt_executor_plugin.h" -#include "src/extendrt/kernel/ascend/plugin/ascend_kernel_plugin.h" +#include "src/extendrt/delegate/ascend_acl/ascend_allocator_plugin.h" #include "utils/ms_utils_secure.h" #include "infer/custom.h" #include "infer/return.h" @@ -51,65 +50,11 @@ namespace mindspore { namespace { const char *const kExecutionPlan = "execution_plan"; -const char *const kDataFlowGraphType = "data_flow"; -const char *const kDataFlowGraphName = "data_flow_graph"; constexpr size_t kMaxSectionNum = 100; constexpr size_t kMaxConfigNumPerSection = 1000; std::shared_mutex g_model_converter_lock; std::mutex g_load_mindir_lock; -FuncGraphPtr CreateFuncGraphFromDataFlow(const void *model_data, size_t data_size) { - auto func_graph = std::make_shared(); - if (func_graph == nullptr) { - MS_LOG(ERROR) << "The func_graph is nullptr."; - return nullptr; - } - func_graph->set_attr(kAttrFuncType, MakeValue(kDataFlowGraphType)); - - // Create custom node with the dataFlow graph. - auto param = func_graph->add_parameter(); - MS_CHECK_TRUE_RET(param != nullptr, nullptr); - param->set_name(kDataFlowGraphName); - auto type_ptr = TypeIdToType(kNumberTypeUInt8); - MS_CHECK_TRUE_RET(type_ptr != nullptr, nullptr); - ShapeVector shape = {static_cast(data_size)}; - auto param_tensor = std::make_shared(kNumberTypeUInt8, shape); - MS_CHECK_TRUE_RET(param_tensor != nullptr, nullptr); - if (param_tensor->Size() != data_size) { - MS_LOG(ERROR) << "The data size of param value is not equal to the data size: " << data_size; - return nullptr; - } - auto tensor_data = param_tensor->data_c(); - MS_CHECK_TRUE_RET(tensor_data != nullptr, nullptr); - if (common::huge_memcpy(reinterpret_cast(tensor_data), param_tensor->Size(), - reinterpret_cast(const_cast(model_data)), data_size) != EOK) { - MS_LOG(ERROR) << "Memcpy dataflow graph data failed."; - return nullptr; - } - param->set_default_param(param_tensor); - auto abstract_tensor = std::make_shared(type_ptr, shape); - MS_CHECK_TRUE_RET(abstract_tensor != nullptr, nullptr); - param->set_abstract(abstract_tensor); - - auto custom_prim = std::make_shared(); - MS_CHECK_TRUE_RET(custom_prim != nullptr, nullptr); - custom_prim->set_type(kDataFlowGraphType); - auto custom_prim_c = custom_prim->GetPrim(); - MS_CHECK_TRUE_RET(custom_prim_c != nullptr, nullptr); - CNodePtr custom_cnode = func_graph->NewCNode(custom_prim_c, {param}); - MS_CHECK_TRUE_RET(custom_cnode != nullptr, nullptr); - custom_cnode->set_fullname_with_scope("Custom_" + std::string(kDataFlowGraphName)); - auto return_prim = std::make_shared(); - MS_CHECK_TRUE_RET(return_prim != nullptr, nullptr); - auto return_prim_c = return_prim->GetPrim(); - MS_CHECK_TRUE_RET(return_prim_c != nullptr, nullptr); - auto return_cnode = func_graph->NewCNode(return_prim_c, {custom_cnode}); - MS_CHECK_TRUE_RET(return_cnode != nullptr, nullptr); - return_cnode->set_fullname_with_scope("Return"); - func_graph->set_return(return_cnode); - return func_graph; -} - std::unordered_map kStr2FormatMap{{"DEFAULT_FORMAT", mindspore::Format::DEFAULT_FORMAT}, {"NCHW", mindspore::Format::NCHW}, {"NHWC", mindspore::Format::NHWC}, @@ -489,25 +434,16 @@ Status ModelImpl::BuildByBufferImpl(const void *model_buff, size_t model_size, M return session_->CompileGraph(func_graph, nullptr, 0, &graph_id_); } - if (model_type != ModelType::kDataFlow) { - func_graph = LoadGraphByBufferImpl(model_buff, model_size, model_type, model_context, model_path); - if (func_graph == nullptr) { - MS_LOG(ERROR) << "Failed to load MindIR model, please check the validity of the model: " << model_path; - return kLiteError; - } - // convert and optimize func graph to infer - ret = ConvertGraphOnline(func_graph, model_context); - if (ret != kSuccess) { - MS_LOG(ERROR) << "convert graph failed!ret = " << ret; - return ret; - } - } else { - // new a func graph contains a custom node, which is the data-flow graph. - func_graph = CreateFuncGraphFromDataFlow(model_buff, model_size); - if (func_graph == nullptr) { - MS_LOG(ERROR) << "Create func graph failed from data flow graph!"; - return kLiteError; - } + func_graph = LoadGraphByBufferImpl(model_buff, model_size, model_type, model_context, model_path); + if (func_graph == nullptr) { + MS_LOG(ERROR) << "Failed to load MindIR model, please check the validity of the model: " << model_path; + return kLiteError; + } + // convert and optimize func graph to infer + ret = ConvertGraphOnline(func_graph, model_context); + if (ret != kSuccess) { + MS_LOG(ERROR) << "convert graph failed!ret = " << ret; + return ret; } ret = session_->CompileGraph(func_graph, nullptr, 0, &graph_id_); if (ret != kSuccess) { @@ -577,25 +513,16 @@ Status ModelImpl::BuildByBufferImpl(const void *model_data, size_t model_size, M return session_->CompileGraph(func_graph, nullptr, 0, &graph_id_); } - if (model_type != ModelType::kDataFlow) { - func_graph = LoadGraphByBufferImpl(model_data, model_size, model_type, model_context, model_path, cryptoInfo); - if (func_graph == nullptr) { - MS_LOG(ERROR) << "Failed to load MindIR model, please check the validity of the model: " << model_path; - return kLiteError; - } - // convert and optimize func graph to infer - ret = ConvertGraphOnline(func_graph, model_context); - if (ret != kSuccess) { - MS_LOG(ERROR) << "convert graph failed!"; - return ret; - } - } else { - // new a func graph contains a custom node, which is the data-flow graph. - func_graph = CreateFuncGraphFromDataFlow(model_data, model_size); - if (func_graph == nullptr) { - MS_LOG(ERROR) << "Create func graph failed from data flow graph!"; - return kLiteError; - } + func_graph = LoadGraphByBufferImpl(model_data, model_size, model_type, model_context, model_path, cryptoInfo); + if (func_graph == nullptr) { + MS_LOG(ERROR) << "Failed to load MindIR model, please check the validity of the model: " << model_path; + return kLiteError; + } + // convert and optimize func graph to infer + ret = ConvertGraphOnline(func_graph, model_context); + if (ret != kSuccess) { + MS_LOG(ERROR) << "convert graph failed!"; + return ret; } ret = session_->CompileGraph(func_graph, nullptr, 0, &graph_id_); if (ret != kSuccess) { @@ -755,8 +682,8 @@ Status ModelImpl::Resize(const std::vector &inputs, const std::vector< MS_LOG(ERROR) << "The size of inputs is incorrect."; return kLiteInputParamInvalid; } - std::vector resize_inputs = TensorUtils::MSTensorToTensor(inputs); - return session_->Resize(graph_id_, resize_inputs, dims); + // std::vector resize_inputs = TensorUtils::MSTensorToTensor(inputs); + return session_->Resize(graph_id_, inputs, dims); } std::vector ModelImpl::GetInputs() { @@ -823,13 +750,14 @@ MSTensor ModelImpl::GetOutputByTensorName(const std::string &name) { } Status ModelImpl::UpdateWeights(const std::vector> &weights) { - MS_CHECK_TRUE_MSG(session_ != nullptr, kLiteError, "Session is null, please build model first!"); - size_t weights_size = weights.size(); - std::vector> new_weights(weights_size); - for (size_t i = 0; i < weights_size; ++i) { - new_weights[i] = TensorUtils::MSTensorToTensorPtr(weights[i]); - } - return session_->UpdateWeights(new_weights); + // MS_CHECK_TRUE_MSG(session_ != nullptr, kLiteError, "Session is null, please build model first!"); + // size_t weights_size = weights.size(); + // std::vector> new_weights(weights_size); + // for (size_t i = 0; i < weights_size; ++i) { + // new_weights[i] = TensorUtils::MSTensorToTensorPtr(weights[i]); + // } + // return session_->UpdateWeights(weights); + return kSuccess; } Status ModelImpl::Predict(const std::vector &inputs, std::vector *outputs, @@ -840,41 +768,15 @@ Status ModelImpl::Predict(const std::vector &inputs, std::vector graph_inputs = TensorUtils::MSTensorToTensor(inputs); - std::vector graph_outputs; - std::vector org_graph_outputs; - if (!outputs->empty()) { - graph_outputs = TensorUtils::MSTensorToTensor(*outputs); - org_graph_outputs = graph_outputs; - } - auto ret = session_->RunGraph(graph_id_, graph_inputs, &graph_outputs, before, after); + auto ret = session_->RunGraph(graph_id_, inputs, outputs, before, after); if (ret != kSuccess) { MS_LOG(ERROR) << "ModelImpl::Predict RunGraph failed!ret = " << ret; return ret; } - bool output_remain = false; - if (!org_graph_outputs.empty() && org_graph_outputs.size() == graph_outputs.size()) { - output_remain = true; - for (size_t i = 0; i < org_graph_outputs.size(); i++) { - if (org_graph_outputs[i].data_ptr() != graph_outputs[i].data_ptr() || - org_graph_outputs[i].device_address() != graph_outputs[i].device_address()) { - output_remain = false; - break; - } - } - } - if (!output_remain) { - auto session_outputs = session_->GetOutputNames(graph_id_); - if (session_outputs.empty() || session_outputs.size() != graph_outputs.size()) { - MS_LOG(ERROR) << "output name is wrong."; - return kLiteError; - } - *outputs = TensorUtils::TensorToMSTensor(graph_outputs, session_outputs); - } auto session_outputs = session_->GetOutputs(graph_id_); - if (graph_outputs.size() != session_outputs.size()) { + if (outputs->size() != session_outputs.size()) { MS_LOG(ERROR) << "Outputs count get from session " << session_outputs.size() << " != outputs count of RunGraph " - << graph_outputs.size(); + << outputs->size(); return kCoreFailed; } for (size_t i = 0; i < session_outputs.size(); i++) { @@ -1076,13 +978,6 @@ bool ModelImpl::CheckModelSupport(DeviceType device_type, ModelType model_type) if (model_type != kMindIR) { return false; } - - if (device_type == kGPU) { - return lite::TensorRTExecutorPlugin::GetInstance().TryRegister().IsOk(); - } - if (device_type == kAscend) { - return kernel::AscendKernelPlugin::TryRegister().IsOk(); - } return false; } diff --git a/mindspore-lite/src/extendrt/cxx_api/model/model_impl.h b/mindspore-lite/src/extendrt/cxx_api/model/model_impl.h index 2cc6076dfd2fce1a608449b7255623db93d68236..4c9802b8051083ca9dc025f1ea5cfd672c835e41 100644 --- a/mindspore-lite/src/extendrt/cxx_api/model/model_impl.h +++ b/mindspore-lite/src/extendrt/cxx_api/model/model_impl.h @@ -63,9 +63,9 @@ class ModelImpl { /// /// \param[in] model_data Define the buffer read from a model file. /// \param[in] data_size Define bytes number of model buffer. - /// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kMindIR_Lite, - /// ModelType::kDataFlow. Only ModelType::kMindIR is valid for Lite. \param[in] model_context Define the context used - /// to store options during execution. + /// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kMindIR_Lite. + /// Only ModelType::kMindIR is valid for Lite. \param[in] model_context Define the context used to store options + /// during execution. /// /// \return Status. Status Build(const void *model_data, size_t data_size, ModelType model_type, @@ -75,8 +75,8 @@ class ModelImpl { /// /// \param[in] model_data Define the buffer of the loaded mindir_graph /// \param[in] data_size Define bytes number of model buffer. - /// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kMindIR_Lite, - /// ModelType::kDataFlow. Only ModelType::kMindIR is valid for Lite. + /// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kMindIR_Lite. + /// Only ModelType::kMindIR is valid for Lite. /// \param[in] model_context Define the context used to store options during execution. /// \param[in] model_path Define the path of weight file /// \param[in] cryptoInfo Define the decryption information @@ -89,9 +89,9 @@ class ModelImpl { /// \brief Build a model from model file path so that it can run on a device. /// /// \param[in] model_path Define the path of a model file. - /// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kMindIR_Lite, - /// ModelType::kDataFlow. Only ModelType::kMindIR is valid for Lite. \param[in] model_context Define the context used - /// to store options during execution. + /// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kMindIR_Lite. + /// Only ModelType::kMindIR is valid for Lite. \param[in] model_context Define the context used to store options + /// during execution. /// /// \return Status. Status Build(const std::string &model_path, ModelType model_type, const std::shared_ptr &model_context); @@ -225,10 +225,9 @@ class ModelImpl { /// /// \param[in] model_data Define the buffer read from a model file. /// \param[in] data_size Define bytes number of model buffer. - /// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kMindIR_Lite, - /// ModelType::kDataFlow. Only ModelType::kMindIR is valid for Lite. \param[in] model_context Define the context used - /// to store options during execution. \param[in] model_path Define the model_path, this param is used for net and - /// weight divided case. + /// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kMindIR_Lite. + /// Only ModelType::kMindIR is valid for Lite. \param[in] model_context Define the context used to store options + /// during execution. \param[in] model_path Define the model_path, this param is used for net and weight divided case. /// /// \return value of config as string type. Status BuildByBufferImpl(const void *model_data, size_t data_size, ModelType model_type, @@ -237,10 +236,9 @@ class ModelImpl { /// /// \param[in] model_data Define the buffer read from a model file. /// \param[in] data_size Define bytes number of model buffer. - /// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kMindIR_Lite, - /// ModelType::kDataFlow. Only ModelType::kMindIR is valid for Lite. \param[in] model_context Define the context used - /// to store options during execution. \param[in] model_path Define the model_path, this param is used for net and - /// weight divided case. + /// \param[in] model_type Define The type of model file. Options: ModelType::kMindIR, ModelType::kMindIR_Lite. + /// Only ModelType::kMindIR is valid for Lite. \param[in] model_context Define the context used to store options + /// during execution. \param[in] model_path Define the model_path, this param is used for net and weight divided case. /// \param[in] cryptoInfo Define the decryption information /// /// \return value of config as string type. diff --git a/mindspore-lite/src/extendrt/cxx_api/model_pool/model_pool.cc b/mindspore-lite/src/extendrt/cxx_api/model_pool/model_pool.cc index 985de8f1068545fd639d8557c0d27d4b00522126..048d15d05f35ad35b20a8bc58af2732c0fbe7911 100644 --- a/mindspore-lite/src/extendrt/cxx_api/model_pool/model_pool.cc +++ b/mindspore-lite/src/extendrt/cxx_api/model_pool/model_pool.cc @@ -27,7 +27,7 @@ #include "src/litert/pack_weight_manager.h" #include "src/extendrt/numa_adapter.h" #include "src/common/common.h" -#if defined(PARALLEL_INFERENCE) && defined(ENABLE_MINDRT) +#if defined(ENABLE_CLOUD_INFERENCE) && defined(ENABLE_MINDRT) #include "thread/parallel_thread_pool_manager.h" #endif #include "src/common/config_file.h" @@ -53,8 +53,8 @@ int ModelPool::GetDefaultThreadNum(int worker_num) { int default_thread_num = -1; if (can_use_core_num_ <= kNumPhysicalCoreThreshold) { default_thread_num = can_use_core_num_ >= kDefaultWorkerNumPerPhysicalCpu - ? can_use_core_num_ / kDefaultWorkerNumPerPhysicalCpu - : can_use_core_num_; + ? can_use_core_num_ / kDefaultWorkerNumPerPhysicalCpu + : can_use_core_num_; } else { default_thread_num = kDefaultThreadsNum; } @@ -807,13 +807,9 @@ Status ModelPool::CanUseAllPhysicalResources() { size_t percentage; auto can_use_cores = ResourceManager::GetInstance()->ParseCpuCoreList(&percentage); if (can_use_cores.empty()) { - MS_LOG(WARNING) << "parse cpu files failed, | can use core list: " << can_use_cores - << " | percentage: " << percentage; - can_use_all_physical_core_ = false; - bind_core_available_ = false; - all_core_num_ = lite::GetCoreNum(); - can_use_core_num_ = all_core_num_; - return kSuccess; + MS_LOG(ERROR) << "parse cpu files failed, | can use core list: " << can_use_cores + << " | percentage: " << percentage; + return kLiteError; } if (percentage == can_use_cores.size()) { MS_LOG(INFO) << "percentage: " << percentage << "can_use_cores size: " << can_use_cores.size(); diff --git a/mindspore-lite/src/extendrt/cxx_api/model_pool/resource_manager.cc b/mindspore-lite/src/extendrt/cxx_api/model_pool/resource_manager.cc index a14f96b14fc94a9ccf433a767fa59f847ff88c13..00dd71a38da89d4382e80cc1ab139c31b681fb67 100644 --- a/mindspore-lite/src/extendrt/cxx_api/model_pool/resource_manager.cc +++ b/mindspore-lite/src/extendrt/cxx_api/model_pool/resource_manager.cc @@ -91,7 +91,7 @@ std::vector ResourceManager::ParseCpuCoreList(size_t *can_use_core_num) { period_file.close(); auto period = std::atoi(period_line.c_str()); if (period == 0) { - MS_LOG(WARNING) << "parse cpu.cfs_period_us file failed, can not use all core!"; + MS_LOG(ERROR) << "read cpu.cfs_period_us file failed."; *can_use_core_num = 0; return {}; } diff --git a/mindspore-lite/src/extendrt/kernel/ascend/CMakeLists.txt b/mindspore-lite/src/extendrt/delegate/ascend_acl/CMakeLists.txt similarity index 66% rename from mindspore-lite/src/extendrt/kernel/ascend/CMakeLists.txt rename to mindspore-lite/src/extendrt/delegate/ascend_acl/CMakeLists.txt index de8d1552b3dae376c14ccb150170ed3a3e7940d0..3aea17aa0eb366f172f366d708a154c00f0ca983 100644 --- a/mindspore-lite/src/extendrt/kernel/ascend/CMakeLists.txt +++ b/mindspore-lite/src/extendrt/delegate/ascend_acl/CMakeLists.txt @@ -1,5 +1,5 @@ -include_directories(${TOP_DIR}/mindspore) -include_directories(${TOP_DIR}/mindspore-lite/src) +include_directories(${TOP_DIR}/mindspore/mindspore) +# include_directories(${TOP_DIR}/mindspore-lite/src) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ORIGIN/") find_library(ge_graph libgraph.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) @@ -7,16 +7,10 @@ find_library(acl libascendcl.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUN find_library(acl_cblas libacl_cblas.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) file(GLOB_RECURSE ASCEND_SRC ${CMAKE_CURRENT_SOURCE_DIR} - "src/*.cc" - "api/*.cc" - "model/*.cc" - "profiling/*.cc" + "*.cc" ) -set(ASCEND_SRC ${ASCEND_SRC} ${TOP_DIR}/mindspore-lite/src/litert/kernel/ascend/src/acl_mem_manager.cc - ${CMAKE_CURRENT_SOURCE_DIR}/model/acl_allocator.cc - ) -file(GLOB_RECURSE ACL_SRC ${CMAKE_CURRENT_SOURCE_DIR} - "../acl/*.cc") +#list(APPEND ASCEND_SRC $ +# $ $) set_property(SOURCE ${ASCEND_SRC} PROPERTY COMPILE_DEFINITIONS LOG_HDR_FILE_REL_PATH="mindspore-lite/../mindspore/mindspore/core/include/utils/log_adapter.h" @@ -26,16 +20,17 @@ set_property(SOURCE ${ACL_SRC} PROPERTY COMPILE_DEFINITIONS LOG_HDR_FILE_REL_PATH="mindspore-lite/../mindspore/mindspore/core/include/utils/log_adapter.h" SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) -add_library(ascend_kernel_plugin_mid OBJECT ${ASCEND_SRC} ${ACL_SRC}) -add_library(ascend_kernel_plugin SHARED $) -add_dependencies(ascend_kernel_plugin fbs_inner_src) -add_dependencies(ascend_kernel_plugin mindspore-extendrt) -target_link_libraries(ascend_kernel_plugin mindspore-extendrt _mindspore_ascend_symbol_obj) + +add_library(ascend_acl_plugin_mid OBJECT ${ASCEND_SRC} ${ACL_SRC}) +add_library(ascend_acl_plugin SHARED $) +add_dependencies(ascend_acl_plugin fbs_inner_src) +add_dependencies(ascend_acl_plugin mindspore-extendrt) +target_link_libraries(ascend_acl_plugin mindspore-extendrt _mindspore_ascend_symbol_obj) if("${MSLITE_REGISTRY_DEVICE}" STREQUAL "SD3403" AND PLATFORM_ARM64) find_library(acl_retr libacl_retr.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) find_library(acl_runtime libruntime.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - target_link_libraries(ascend_kernel_plugin ${ge_graph} ${acl} ${acl_retr} ${acl_cblas} ${acl_runtime}) + target_link_libraries(ascend_acl_plugin ${ge_graph} ${acl} ${acl_retr} ${acl_cblas} ${acl_runtime}) else() find_library(acl_dvpp libacl_dvpp.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) find_library(ge_compiler libge_compiler.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) @@ -45,6 +40,8 @@ else() find_library(libaicore_utils libaicore_utils.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) find_library(libaicpu_engine_common libaicpu_engine_common.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - target_link_libraries(ascend_kernel_plugin ${ge_graph} ${ge_compiler} ${acl_cblas} ${acl_dvpp} ${acl_runtime} + target_link_libraries(ascend_acl_plugin ${ge_graph} ${ge_compiler} ${acl_cblas} ${acl_dvpp} ${acl_runtime} ${libplatform} ${libcompress} ${libopskernel} ${libaicore_utils} ${libaicpu_engine_common} ${acl}) endif() +# target_link_libraries(ascend_acl_plugin mindspore-extendrt) + diff --git a/mindspore-lite/src/extendrt/kernel/ascend/model/acl_allocator.cc b/mindspore-lite/src/extendrt/delegate/ascend_acl/acl_allocator.cc similarity index 98% rename from mindspore-lite/src/extendrt/kernel/ascend/model/acl_allocator.cc rename to mindspore-lite/src/extendrt/delegate/ascend_acl/acl_allocator.cc index ce54011b209fe928bb512dce832f5ed8181e1cfb..ccb4fba912bee6f2b1db2bebf7ddafb1b3e897e1 100644 --- a/mindspore-lite/src/extendrt/kernel/ascend/model/acl_allocator.cc +++ b/mindspore-lite/src/extendrt/delegate/ascend_acl/acl_allocator.cc @@ -14,14 +14,13 @@ * limitations under the License. */ -#include "src/extendrt/kernel/ascend/model/acl_allocator.h" +#include "src/extendrt/delegate/ascend_acl/acl_allocator.h" #include #include "src/common/log_adapter.h" #include "plugin/res_manager/ascend/symbol_interface/acl_rt_symbol.h" #include "plugin/res_manager/ascend/symbol_interface/symbol_utils.h" -namespace mindspore::kernel { -namespace acl { +namespace mindspore { AclAllocator *CreateAclAllocator() { MS_LOG(INFO) << "LoadAscendApiSymbols for MindSpore lite."; device::ascend::LoadAscendApiSymbols(); @@ -290,5 +289,4 @@ Status AclAllocator::CopyDeviceDataToDevice(void *src_device_data, void *dst_dev } return kSuccess; } -} // namespace acl -} // namespace mindspore::kernel +} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/kernel/ascend/model/acl_allocator.h b/mindspore-lite/src/extendrt/delegate/ascend_acl/acl_allocator.h similarity index 86% rename from mindspore-lite/src/extendrt/kernel/ascend/model/acl_allocator.h rename to mindspore-lite/src/extendrt/delegate/ascend_acl/acl_allocator.h index 5177392d628b91534b8b4147a1505e843f2b5107..7b98fe6ede05959f860edbe1d72eee01732700e0 100644 --- a/mindspore-lite/src/extendrt/kernel/ascend/model/acl_allocator.h +++ b/mindspore-lite/src/extendrt/delegate/ascend_acl/acl_allocator.h @@ -14,18 +14,17 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_ACL_ALLOCATOR_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_ACL_ALLOCATOR_H_ +#ifndef DELEGATE_MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_ACL_ALLOCATOR_H_ +#define DELEGATE_MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_ACL_ALLOCATOR_H_ #include #include #include #include "acl/acl_base.h" #include "acl/acl_rt.h" #include "include/api/status.h" -#include "src/extendrt/kernel/ascend/plugin/ascend_allocator_plugin.h" +#include "src/extendrt/delegate/ascend_acl/ascend_allocator_plugin.h" -namespace mindspore::kernel { -namespace acl { +namespace mindspore { class AclAllocator : public AscendAllocatorPluginImpl { public: AclAllocator() = default; @@ -58,6 +57,5 @@ class AclAllocator : public AscendAllocatorPluginImpl { extern "C" MS_API AclAllocator *CreateAclAllocator(); -} // namespace acl -} // namespace mindspore::kernel +} // namespace mindspore #endif // MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_ACL_ALLOCATOR_H_ diff --git a/mindspore-lite/src/extendrt/kernel/ascend/model/acl_env_guard.cc b/mindspore-lite/src/extendrt/delegate/ascend_acl/acl_env_guard.cc similarity index 96% rename from mindspore-lite/src/extendrt/kernel/ascend/model/acl_env_guard.cc rename to mindspore-lite/src/extendrt/delegate/ascend_acl/acl_env_guard.cc index 3f334b8823426de5bc7fb555f8baf1309aaa03b8..07ebe8942df5b909ecfe40ff90be9aa9de3e9b63 100644 --- a/mindspore-lite/src/extendrt/kernel/ascend/model/acl_env_guard.cc +++ b/mindspore-lite/src/extendrt/delegate/ascend_acl/acl_env_guard.cc @@ -14,14 +14,13 @@ * limitations under the License. */ -#include "extendrt/kernel/ascend/model/acl_env_guard.h" -#include "extendrt/kernel/ascend/model/model_infer.h" +#include "extendrt/delegate/ascend_acl/acl_env_guard.h" +#include "extendrt/delegate/ascend_acl/model_infer.h" #include "common/log_adapter.h" #include "plugin/res_manager/ascend/symbol_interface/acl_symbol.h" #include "plugin/res_manager/ascend/symbol_interface/symbol_utils.h" -namespace mindspore::kernel { -namespace acl { +namespace mindspore { std::shared_ptr AclEnvGuard::global_acl_env_ = nullptr; std::vector> AclEnvGuard::model_infers_ = {}; std::mutex AclEnvGuard::global_acl_env_mutex_; @@ -193,5 +192,4 @@ int32_t AclEnvGuard::GetModelNum() { return model_infers_.size(); } -} // namespace acl -} // namespace mindspore::kernel +} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/kernel/ascend/model/acl_env_guard.h b/mindspore-lite/src/extendrt/delegate/ascend_acl/acl_env_guard.h similarity index 88% rename from mindspore-lite/src/extendrt/kernel/ascend/model/acl_env_guard.h rename to mindspore-lite/src/extendrt/delegate/ascend_acl/acl_env_guard.h index 62e92956f8177e78698a78032555d80e4be61aac..3613a7ef1b8cc17a7696fb9c2707e41d7629f0d3 100644 --- a/mindspore-lite/src/extendrt/kernel/ascend/model/acl_env_guard.h +++ b/mindspore-lite/src/extendrt/delegate/ascend_acl/acl_env_guard.h @@ -14,16 +14,15 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_MODEL_ACL_ENV_GUARD_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_MODEL_ACL_ENV_GUARD_H_ +#ifndef delegate_MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_MODEL_ACL_ENV_GUARD_H_ +#define delegate_MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_MODEL_ACL_ENV_GUARD_H_ #include #include #include #include "acl/acl_base.h" -namespace mindspore::kernel { -namespace acl { +namespace mindspore { class AclInitAdapter { public: static AclInitAdapter &GetInstance(); @@ -61,7 +60,6 @@ class AclEnvGuard { aclError errno_; }; -} // namespace acl -} // namespace mindspore::kernel +} // namespace mindspore #endif // MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_MODEL_ACL_ENV_GUARD_H_ diff --git a/mindspore-lite/src/extendrt/delegate/ascend_acl/acl_graph_executor.cc b/mindspore-lite/src/extendrt/delegate/ascend_acl/acl_graph_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..73076c821bfe06395d5ef5f48d55e7c17ccf04da --- /dev/null +++ b/mindspore-lite/src/extendrt/delegate/ascend_acl/acl_graph_executor.cc @@ -0,0 +1,219 @@ +/** + * Copyright 2022-2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "extendrt/delegate/ascend_acl/acl_graph_executor.h" +#include "extendrt/delegate/ascend_acl/ascend_allocator_plugin.h" +#include "extendrt/session/lite_graph_executor.h" +#include "extendrt/delegate/factory.h" +#include "extendrt/utils/func_graph_utils.h" + +#include "plugin/res_manager/ascend/symbol_interface/acl_base_symbol.h" +#include "plugin/res_manager/ascend/symbol_interface/acl_mdl_symbol.h" +#include "plugin/res_manager/ascend/symbol_interface/acl_rt_symbol.h" +#include "plugin/res_manager/ascend/symbol_interface/acl_symbol.h" +#include "plugin/res_manager/ascend/symbol_interface/symbol_utils.h" +namespace mindspore { +namespace { +constexpr auto kProviderAcl = "litert"; +} // namespace + +Status AclGraphExecutor::Init() { + auto device_list = context_->MutableDeviceInfo(); + for (const auto &device_info : device_list) { + if (device_info == nullptr) { + MS_LOG(ERROR) << "Device info get from Context cannot be nullptr"; + return kLiteError; + } + if (device_info->GetDeviceType() == DeviceType::kAscend) { + bool is_registered = AscendAllocatorPlugin::GetInstance().Register(); + if (!is_registered) { + MS_LOG(ERROR) << "AscendAllocatorPlugin failed to register, cannot do acl memory operations"; + return kLiteError; + } + auto ascend_device_info = device_info->Cast(); + if (ascend_device_info == nullptr) { + MS_LOG(ERROR) << "Failed to cast device info to AscendDeviceInfo"; + return kLiteError; + } + return kSuccess; + } + } + return kSuccess; +} + +std::shared_ptr AclGraphExecutor::GenAclOptions() { + auto acl_options_ptr = std::make_shared(); + if (acl_options_ptr == nullptr) { + MS_LOG(ERROR) << "Acl options make shared failed."; + return nullptr; + } + auto profiling_path_val = primitive_->GetAttr(lite::kProfilingPathKey); + if (profiling_path_val != nullptr) { + auto val = GetValue(profiling_path_val); + acl_options_ptr->profiling_path = val; + } + auto dump_path_val = primitive_->GetAttr(lite::kDumpPathKey); + if (dump_path_val != nullptr) { + auto val = GetValue(dump_path_val); + acl_options_ptr->dump_path = val; + } + auto inner_calc_workspace_size = primitive_->GetAttr(lite::kInnerCalcWorkspaceSize); + if (inner_calc_workspace_size != nullptr) { + auto val = GetValue(inner_calc_workspace_size); + acl_options_ptr->multi_model_sharing_mem_prepare = val; + // is_multi_model_sharing_mem_prepare_ = true; + } + auto inner_sharing_workspace = primitive_->GetAttr(lite::kInnerSharingWorkspace); + if (inner_sharing_workspace != nullptr) { + auto val = GetValue(inner_sharing_workspace); + acl_options_ptr->multi_model_sharing_mem = val; + } + auto inner_model_path = primitive_->GetAttr(lite::kInnerModelPath); + if (inner_model_path != nullptr) { + auto val = GetValue(inner_model_path); + acl_options_ptr->model_path = val; + } + auto workspace_key = primitive_->GetAttr(lite::kInnerWorkspace); + if (workspace_key != nullptr) { + auto val = GetValue(workspace_key); + acl_options_ptr->share_workspace = val; + } + auto weightspace_key = primitive_->GetAttr(lite::kInnerWeightspace); + if (weightspace_key != nullptr) { + auto val = GetValue(weightspace_key); + acl_options_ptr->share_weightspace = val; + } + auto weightspace_workspace_key = primitive_->GetAttr(lite::kInnerWeightspaceWorkspace); + if (weightspace_workspace_key != nullptr) { + auto val = GetValue(weightspace_workspace_key); + acl_options_ptr->share_weightspace_workspace = val; + } + auto bundle_model = primitive_->GetAttr(lite::kBundleModel); + if (bundle_model != nullptr) { + auto val = GetValue(bundle_model); + acl_options_ptr->is_bundle_model = val; + } + acl_options_ptr->device_id = static_cast(0); + return acl_options_ptr; +} + +bool AclGraphExecutor::CompileGraph(const FuncGraphPtr &graph, const std::map &compile_options, + uint32_t *graph_id) { + // Get whether the current model is a bundle model for LORA. + if (graph->get_attr(lite::kBundleModel) != nullptr) { + config_info_["inner_common"][lite::kBundleModel] = "true"; + } + auto nodes = graph->TopoSort(graph->get_return()); + if (nodes.empty()) { + MS_LOG(ERROR) << "There are no nodes in the graph"; + return false; + } + void *om_data = nullptr; + size_t om_data_size = 0; + size_t cnode_count = 0; + BaseOperatorPtr op; + for (const auto &node : nodes) { + auto cnode = node->cast(); + if (!cnode || !AnfUtils::IsRealKernel(cnode)) { + continue; + } + std::string kernel_name = common::AnfAlgo::GetCNodeName(cnode); + if (kernel_name != lite::kNameCustomAscend) { + MS_LOG(ERROR) << "Only support " << lite::kNameCustomAscend << ", but got " << kernel_name << ", node " + << cnode->fullname_with_scope(); + return false; + } + cnode_count += 1; + if (cnode_count > 1) { + MS_LOG(ERROR) << "Only support one " << lite::kNameCustomAscend << " node, but got " << kernel_name << ", node " + << cnode->fullname_with_scope(); + return false; + } + std::vector inputs; + std::vector outputs; + FuncGraphUtils::GetCNodeInputsOutputs(cnode, &inputs, &outputs); + // for (size_t i = 0; i < inputs.size(); i++) { + auto &input = inputs[inputs.size() - 1]; + auto tensor_data = FuncGraphUtils::GetConstNodeValue(input.first); + om_data_size = tensor_data->Size(); + om_data = tensor_data->data_c(); + (void)FuncGraphUtils::GetCNodeOperator(cnode, &op); + } + if (om_data == nullptr || op == nullptr) { + MS_LOG(ERROR) << "om data is nullptr."; + return false; + } + // todo + primitive_ = op->GetPrim(); + auto acl_options = GenAclOptions(); + if (acl_options == nullptr) { + MS_LOG(ERROR) << "Generate acl options failed."; + return false; + } + + model_infer_ = std::make_shared(acl_options); + if (model_infer_ == nullptr) { + MS_LOG(ERROR) << "Create ModelInfer failed."; + return false; + } + if (!model_infer_->Init()) { + MS_LOG(ERROR) << "Model infer init failed."; + return false; + } + if (!model_infer_->Load(om_data, om_data_size)) { + MS_LOG(ERROR) << "Load om data failed."; + return false; + } + AclEnvGuard::AddModel(model_infer_); + return true; +} + +bool AclGraphExecutor::Resize(uint32_t graph_id, const std::vector &inputs, + const std::vector> &dims) { + (void)model_infer_->Resize(dims); + return true; +} + +std::vector AclGraphExecutor::GetOutputInfos(uint32_t graph_id) { + auto output_infos = graph_outputs_.find(graph_id) != graph_outputs_.end() ? graph_outputs_.at(graph_id) + : std::vector(); + return output_infos; +} + +bool AclGraphExecutor::RunGraph(uint32_t graph_id, const std::vector &inputs, + std::vector *output, + const std::map &compile_options) { + auto ret = model_infer_->Inference(inputs, output); + if (!ret) { + MS_LOG(ERROR) << "model infer failed."; + return false; + } + graph_outputs_[graph_id] = *output; + return true; +} + +static std::shared_ptr AclGraphExecutorCreator(const std::shared_ptr &ctx, + const ConfigInfos &config_infos) { + auto acl_executor = std::make_shared(ctx, config_infos); + if (acl_executor == nullptr && acl_executor->Init() != kSuccess) { + MS_LOG(ERROR) << "Failed to init GeGraphExecutor"; + return nullptr; + } + return acl_executor; +} + +REG_DELEGATE(kAscend, kProviderAcl, AclGraphExecutorCreator) +} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/graph_compiler.h b/mindspore-lite/src/extendrt/delegate/ascend_acl/acl_graph_executor.h similarity index 30% rename from mindspore-lite/src/extendrt/graph_compiler.h rename to mindspore-lite/src/extendrt/delegate/ascend_acl/acl_graph_executor.h index 08c0449b5719fee6ce0bd2c14723b7f12eaf59b1..baeba6e05a70eec268e25867deefd538769ceda6 100644 --- a/mindspore-lite/src/extendrt/graph_compiler.h +++ b/mindspore-lite/src/extendrt/delegate/ascend_acl/acl_graph_executor.h @@ -1,5 +1,5 @@ /** - * Copyright 2019-2021 Huawei Technologies Co., Ltd + * Copyright 2024 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,45 +13,57 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_LITE_EXTENDRT_GRAPH_COMPILER_H_ -#define MINDSPORE_LITE_EXTENDRT_GRAPH_COMPILER_H_ +#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_ACL_ACL_GRAPH_EXECUTOR_H_ +#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_ACL_ACL_GRAPH_EXECUTOR_H_ + +#include #include #include #include -#include + #include "include/api/context.h" -#include "include/api/model.h" -#include "include/api/graph.h" #include "include/api/status.h" -#include "include/api/kernel.h" -#include "include/common/utils/utils.h" -#include "ir/func_graph.h" -#include "src/extendrt/graph_scheduler.h" +#include "extendrt/session/lite_graph_executor.h" +#include "common/config_infos.h" +#include "src/common/common.h" +#include "extendrt/delegate/ascend_acl/model_infer.h" namespace mindspore { -namespace infer { -using GraphId = uint32_t; -struct CompileOptions { - int optimize_level_; -}; -struct CompileResult {}; - -class GraphCompiler : public std::enable_shared_from_this { +class AclGraphExecutor : public LiteGraphExecutor { public: - explicit GraphCompiler(const CompileOptions &opts); - virtual ~GraphCompiler(); - ExcutionPlan Compile(FuncGraphPtr graph); - ExcutionPlan Compile(GraphSegmentPtr segment); - - protected: - ExcutionPlan Schedule(const CompileResult &); - GraphId CompileSegment(const GraphSegmentPtr segment); - CompileResult LinkSegment(); - - protected: - GraphScheduler scheduler_; - CompileOptions options_; + AclGraphExecutor(const std::shared_ptr &context, const ConfigInfos &config_info) { + context_ = context; + config_info_ = config_info; + } + ~AclGraphExecutor() {} + + bool CompileGraph(const FuncGraphPtr &graph, const std::map &compile_options, + uint32_t *graph_id) override; + bool RunGraph(uint32_t graph_id, const std::vector &inputs, + std::vector *outputs, const std::map &compile_options) override; + + bool Resize(uint32_t graph_id, const std::vector &inputs, + const std::vector &dims) override; + + std::vector GetOutputInfos(uint32_t graph_id) override; + + Status Init(); + + void Finalize() override { + AclEnvGuard::Finalize(); + } + + private: + Status BuildCustomAscendKernel(const CNodePtr &cnode); + std::shared_ptr GenAclOptions(); + + private: + std::shared_ptr context_ = nullptr; + ConfigInfos config_info_; + std::shared_ptr model_infer_; + std::shared_ptr primitive_ = nullptr; + std::map> graph_outputs_; }; -} // namespace infer + } // namespace mindspore -#endif +#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_ACL_ACL_GRAPH_EXECUTOR_H_ diff --git a/mindspore-lite/src/extendrt/delegate/ascend_acl/acl_mem_manager.cc b/mindspore-lite/src/extendrt/delegate/ascend_acl/acl_mem_manager.cc new file mode 100644 index 0000000000000000000000000000000000000000..ba95b775dd196414d318b3d77a20c670ce6860eb --- /dev/null +++ b/mindspore-lite/src/extendrt/delegate/ascend_acl/acl_mem_manager.cc @@ -0,0 +1,271 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "extendrt/delegate/ascend_acl/acl_mem_manager.h" +#include +#include +#include +#include +#include +#include "src/common/log_adapter.h" +#include "plugin/res_manager/ascend/symbol_interface/acl_rt_symbol.h" +#include "plugin/res_manager/ascend/symbol_interface/symbol_utils.h" + +namespace mindspore { +STATUS AclMemManager::UpdateWorkspace(size_t work_size, size_t weight_size, int32_t device_id) { + auto it = work_mem_info_map_.find(device_id); + if (it == work_mem_info_map_.end()) { + AclModelMemInfo new_work_mem = {nullptr, 0}; + work_mem_info_map_.insert(std::make_pair(device_id, std::make_pair(new_work_mem, false))); + } else if (it->second.second == true) { + MS_LOG(ERROR) << "Device " << device_id << " has alloc memory!"; + return lite::RET_ERROR; + } + + it = work_mem_info_map_.find(device_id); + if (it == work_mem_info_map_.end()) { + MS_LOG(ERROR) << "Get mem failed!"; + return lite::RET_ERROR; + } + + if (work_size > it->second.first.mem_size) { + it->second.first.mem_size = work_size; + MS_LOG(DEBUG) << "Update work_size = " << it->second.first.mem_size << " successful."; + } + + if (weight_size > weight_mem_info_.mem_size) { + weight_mem_info_.mem_size = weight_size; + MS_LOG(DEBUG) << "Update weight_size = " << weight_size << " successful."; + } + return lite::RET_OK; +} + +STATUS AclMemManager::UpdateWorkspace(size_t work_size, int32_t device_id) { + auto it = work_mem_info_map_.find(device_id); + if (it == work_mem_info_map_.end()) { + AclModelMemInfo new_work_mem = {nullptr, 0}; + work_mem_info_map_.insert(std::make_pair(device_id, std::make_pair(new_work_mem, false))); + } else if (it->second.second == true) { + MS_LOG(ERROR) << "Device " << device_id << " has alloc memory!"; + return lite::RET_ERROR; + } + MS_LOG(DEBUG) << "Get device success."; + it = work_mem_info_map_.find(device_id); + if (it == work_mem_info_map_.end()) { + MS_LOG(ERROR) << "Get mem failed!"; + return lite::RET_ERROR; + } + MS_LOG(DEBUG) << "Begin record work size."; + if (work_size > it->second.first.mem_size) { + it->second.first.mem_size = work_size; + MS_LOG(DEBUG) << "Update work_size = " << it->second.first.mem_size << " successful."; + } + return lite::RET_OK; +} + +STATUS AclMemManager::UpdateWeightspace(std::string model_path, size_t weight_size, int32_t device_id) { + if (weight_mem_info_map_.find(device_id) == weight_mem_info_map_.end()) { + AclModelMemInfo new_weight_mem = {nullptr, weight_size}; + MemShareInfo mem_share_info; + mem_share_info.device_id = device_id; + mem_share_info.model_path = ""; + mem_share_info.mem_info = new_weight_mem; + mem_share_info.allocated = false; + std::map inner_map; + inner_map.insert(std::make_pair(model_path, mem_share_info)); + weight_mem_info_map_.insert(std::make_pair(device_id, inner_map)); + } else if (weight_mem_info_map_.at(device_id).find(model_path) == weight_mem_info_map_.at(device_id).end()) { + AclModelMemInfo new_weight_mem = {nullptr, weight_size}; + MemShareInfo mem_share_info; + mem_share_info.device_id = device_id; + mem_share_info.model_path = ""; + mem_share_info.mem_info = new_weight_mem; + mem_share_info.allocated = false; + weight_mem_info_map_.at(device_id).insert(std::make_pair(model_path, mem_share_info)); + } + return lite::RET_OK; +} + +STATUS AclMemManager::GetModelWorkMem(AclModelMemInfo *acl_work_mem_info, int32_t device_id) { + std::unique_lock acl_mtx(acl_mem_alloc_mutex_); + + auto it = work_mem_info_map_.find(device_id); + if (it == work_mem_info_map_.end()) { + MS_LOG(ERROR) << "Get work mem failed!"; + return lite::RET_ERROR; + } + it->second.second = true; + + if (it->second.first.mem_addr == nullptr) { + if (it->second.first.mem_size == 0) { + return lite::RET_ERROR; + } + auto acl_ret = + CALL_ASCEND_API(aclrtMalloc, &(it->second.first.mem_addr), it->second.first.mem_size, ACL_MEM_MALLOC_HUGE_FIRST); + if (acl_ret != ACL_SUCCESS) { + MS_LOG(ERROR) << "Call aclrtMalloc failed, err_code = " << acl_ret; + return lite::RET_ERROR; + } + MS_LOG(DEBUG) << "Malloc max work size is " << it->second.first.mem_size; + } + *acl_work_mem_info = it->second.first; + return lite::RET_OK; +} + +STATUS AclMemManager::GetModelWorkMem(void **work_ptr, int32_t device_id) { + MS_CHECK_TRUE_MSG(work_ptr != nullptr, lite::RET_NULL_PTR, "work_ptr is nullptr!"); + std::unique_lock acl_mtx(acl_mem_alloc_mutex_); + + auto it = work_mem_info_map_.find(device_id); + if (it == work_mem_info_map_.end()) { + MS_LOG(ERROR) << "Get work mem failed!"; + return lite::RET_ERROR; + } + it->second.second = true; + MS_LOG(DEBUG) << "Get device id success."; + if (it->second.first.mem_addr == nullptr) { + if (it->second.first.mem_size == 0) { + return lite::RET_ERROR; + } + MS_LOG(DEBUG) << "Begin alloc mem addr."; + auto acl_ret = + CALL_ASCEND_API(aclrtMalloc, &(it->second.first.mem_addr), it->second.first.mem_size, ACL_MEM_MALLOC_HUGE_FIRST); + if (acl_ret != ACL_SUCCESS) { + MS_LOG(ERROR) << "Call aclrtMalloc failed, err_code = " << acl_ret; + return lite::RET_ERROR; + } + MS_LOG(DEBUG) << "Malloc work mem success, max work size is " << it->second.first.mem_size; + } + *work_ptr = it->second.first.mem_addr; + return lite::RET_OK; +} + +STATUS AclMemManager::GetModelWeightMem(AclModelMemInfo *acl_weight_mem_info) { + std::unique_lock acl_mtx(acl_mem_alloc_mutex_); + if (weight_mem_info_.mem_addr == nullptr) { + if (weight_mem_info_.mem_size == 0) { + return lite::RET_ERROR; + } + auto acl_ret = + CALL_ASCEND_API(aclrtMalloc, &weight_mem_info_.mem_addr, weight_mem_info_.mem_size, ACL_MEM_MALLOC_HUGE_FIRST); + if (acl_ret != ACL_SUCCESS) { + MS_LOG(ERROR) << "Call aclrtMalloc failed, err_code = " << acl_ret; + return lite::RET_ERROR; + } + MS_LOG(DEBUG) << "Malloc max weight size is " << weight_mem_info_.mem_size; + } + *acl_weight_mem_info = weight_mem_info_; + return lite::RET_OK; +} + +STATUS AclMemManager::GetModelWeightMem(void **weight_ptr, std::string model_path, int32_t device_id) { + MS_CHECK_TRUE_MSG(weight_ptr != nullptr, lite::RET_NULL_PTR, "weight_ptr is nullptr!"); + std::unique_lock acl_mtx(acl_mem_alloc_mutex_); + if (weight_mem_info_map_.find(device_id) == weight_mem_info_map_.end()) { + MS_LOG(ERROR) << "Can't get weight mem of device " << device_id << "!"; + return lite::RET_ERROR; + } + if (weight_mem_info_map_.at(device_id).find(model_path) == weight_mem_info_map_.at(device_id).end()) { + MS_LOG(ERROR) << "Can't get weight mem of device " << device_id << " of model path " << model_path << "!"; + return lite::RET_ERROR; + } + auto &share_mem_info = weight_mem_info_map_.at(device_id).at(model_path); + + if (share_mem_info.mem_info.mem_addr == nullptr) { + if (share_mem_info.mem_info.mem_size == 0) { + MS_LOG(ERROR) << "Weight size if 0!"; + return lite::RET_ERROR; + } + auto acl_ret = CALL_ASCEND_API(aclrtMalloc, &(share_mem_info.mem_info.mem_addr), share_mem_info.mem_info.mem_size, + ACL_MEM_MALLOC_HUGE_FIRST); + if (acl_ret != ACL_SUCCESS) { + MS_LOG(ERROR) << "Call aclrtMalloc failed, err_code : " << acl_ret << "!"; + return lite::RET_ERROR; + } + MS_LOG(DEBUG) << "Malloc weight size is " << share_mem_info.mem_info.mem_size << "!"; + } + *weight_ptr = share_mem_info.mem_info.mem_addr; + return lite::RET_OK; +} + +void AclMemManager::Lock(int32_t device_id) { + acl_execute_mutex_.lock(); + if (device_lock_map_.find(device_id) == device_lock_map_.end()) { + device_lock_map_.emplace(std::piecewise_construct, std::forward_as_tuple(device_id), std::forward_as_tuple()); + } + acl_execute_mutex_.unlock(); + return device_lock_map_.at(device_id).lock(); +} + +void AclMemManager::Unlock(int32_t device_id) { + acl_execute_mutex_.lock(); + if (device_lock_map_.find(device_id) == device_lock_map_.end()) { + device_lock_map_.emplace(std::piecewise_construct, std::forward_as_tuple(device_id), std::forward_as_tuple()); + } + acl_execute_mutex_.unlock(); + return device_lock_map_.at(device_id).unlock(); +} + +void AclMemManager::ReleaseDeviceMem(int32_t device_id, std::string model_path) { + for (auto &device_id_iter : work_mem_info_map_) { + if (device_id_iter.first != device_id) { + continue; + } + if (device_id_iter.second.first.mem_addr != nullptr) { + (void)CALL_ASCEND_API(aclrtFree, device_id_iter.second.first.mem_addr); + device_id_iter.second.first.mem_addr = nullptr; + } + } + for (auto &device_id_iter : weight_mem_info_map_) { + if (device_id_iter.first != device_id) { + continue; + } + for (auto &model_path_iter : device_id_iter.second) { + if (model_path_iter.first != model_path) { + continue; + } + if (model_path_iter.second.mem_info.mem_addr != nullptr) { + (void)CALL_ASCEND_API(aclrtFree, model_path_iter.second.mem_info.mem_addr); + model_path_iter.second.mem_info.mem_addr = nullptr; + } + } + } +} + +AclMemManager::~AclMemManager() { + for (auto &mem_info_pair : work_mem_info_map_) { + if (mem_info_pair.second.first.mem_addr != nullptr) { + (void)CALL_ASCEND_API(aclrtFree, mem_info_pair.second.first.mem_addr); + mem_info_pair.second.first.mem_addr = nullptr; + mem_info_pair.second.first.mem_size = 0; + } + } + if (weight_mem_info_.mem_addr != nullptr) { + (void)CALL_ASCEND_API(aclrtFree, weight_mem_info_.mem_addr); + weight_mem_info_.mem_addr = nullptr; + weight_mem_info_.mem_size = 0; + } + for (auto &device_id_iter : weight_mem_info_map_) { + for (auto &model_path_iter : device_id_iter.second) { + if (model_path_iter.second.mem_info.mem_addr != nullptr) { + (void)CALL_ASCEND_API(aclrtFree, model_path_iter.second.mem_info.mem_addr); + model_path_iter.second.mem_info.mem_addr = nullptr; + model_path_iter.second.mem_info.mem_size = 0; + } + } + } +} +} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/delegate/ascend_acl/acl_mem_manager.h b/mindspore-lite/src/extendrt/delegate/ascend_acl/acl_mem_manager.h new file mode 100644 index 0000000000000000000000000000000000000000..25a928cc3708ae438f75ebc87142dfc61bf5305f --- /dev/null +++ b/mindspore-lite/src/extendrt/delegate/ascend_acl/acl_mem_manager.h @@ -0,0 +1,78 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef delegate_MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ASCEND_SRC_ACL_MEM_MANAGER_H_ +#define delegate_MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ASCEND_SRC_ACL_MEM_MANAGER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "include/errorcode.h" + +namespace mindspore { +using mindspore::lite::STATUS; + +struct AclModelMemInfo { + void *mem_addr = nullptr; + size_t mem_size = 0; +}; + +struct MemShareInfo { + int32_t device_id; + std::thread::id thread_id; + std::string model_path; + AclModelMemInfo mem_info; + bool allocated; +}; +class AclMemManager { + public: + AclMemManager() {} + ~AclMemManager(); + + AclMemManager(const AclMemManager &) = delete; + AclMemManager &operator=(const AclMemManager &) = delete; + + static AclMemManager &GetInstance() { + static AclMemManager instance; + return instance; + } + STATUS UpdateWorkspace(size_t work_size, size_t weight_size, int32_t device_id); + STATUS UpdateWorkspace(size_t work_size, int32_t device_id); + STATUS UpdateWeightspace(std::string model_path, size_t weight_size, int32_t device_id); + STATUS GetModelWorkMem(AclModelMemInfo *acl_work_mem_info, int32_t device_id); + STATUS GetModelWorkMem(void **work_ptr, int32_t device_id); + STATUS GetModelWeightMem(AclModelMemInfo *acl_weight_mem_info); + STATUS GetModelWeightMem(void **weight_ptr, std::string model_path, int32_t device_id); + void ReleaseDeviceMem(int32_t device_id, std::string model_path); + void Lock() { return acl_execute_mutex_.lock(); } + void Unlock() { return acl_execute_mutex_.unlock(); } + void Lock(int32_t device_id); + void Unlock(int32_t device_id); + + private: + std::mutex acl_mem_alloc_mutex_; + std::mutex acl_execute_mutex_; + std::map device_lock_map_; + std::map> work_mem_info_map_; + std::map> weight_mem_info_map_; + AclModelMemInfo weight_mem_info_ = {nullptr, 0}; +}; +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ASCEND_SRC_ACL_MEM_MANAGER_H_ diff --git a/mindspore-lite/src/extendrt/kernel/ascend/options/acl_model_options.h b/mindspore-lite/src/extendrt/delegate/ascend_acl/acl_model_options.h similarity index 82% rename from mindspore-lite/src/extendrt/kernel/ascend/options/acl_model_options.h rename to mindspore-lite/src/extendrt/delegate/ascend_acl/acl_model_options.h index 5bd3db3f8532de9b191d06af16e6bf582eb49d06..39fb19d679f4fa4a222d9f5626ee50a6bc205a45 100644 --- a/mindspore-lite/src/extendrt/kernel/ascend/options/acl_model_options.h +++ b/mindspore-lite/src/extendrt/delegate/ascend_acl/acl_model_options.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_SRC_ACL_MODEL_OPTIONS_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_SRC_ACL_MODEL_OPTIONS_H_ +#ifndef delegate_MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_SRC_ACL_MODEL_OPTIONS_H_ +#define delegate_MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_SRC_ACL_MODEL_OPTIONS_H_ #include #include @@ -25,8 +25,7 @@ #include "mindapi/base/format.h" #include "acl/acl_mdl.h" -namespace mindspore::kernel { -namespace acl { +namespace mindspore { struct AclModelOptions { int32_t device_id; std::string dump_path; @@ -49,7 +48,5 @@ struct AclDynamicShapeOptions { std::vector> input_shapes; }; -using AclModelOptionsPtr = std::shared_ptr; -} // namespace acl -} // namespace mindspore::kernel +} // namespace mindspore #endif // MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_SRC_ACL_MODEL_OPTIONS_H_ diff --git a/mindspore-lite/src/extendrt/graph_runtime/factory.cc b/mindspore-lite/src/extendrt/delegate/ascend_acl/acl_plugin_impl.cc similarity index 43% rename from mindspore-lite/src/extendrt/graph_runtime/factory.cc rename to mindspore-lite/src/extendrt/delegate/ascend_acl/acl_plugin_impl.cc index d96e41c5db269d0ea6c8436f702502d9e97c5c79..92956be2a17dbff7526188e079a466e0d52f6578 100644 --- a/mindspore-lite/src/extendrt/graph_runtime/factory.cc +++ b/mindspore-lite/src/extendrt/delegate/ascend_acl/acl_plugin_impl.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019-2021 Huawei Technologies Co., Ltd + * Copyright 2022-2023 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,29 +13,28 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "extendrt/graph_runtime/factory.h" -#include + #include +#include "extendrt/delegate/ascend_acl/acl_plugin_impl.h" namespace mindspore { -GraphRuntimeRegistry &GraphRuntimeRegistry::GetInstance() { - static GraphRuntimeRegistry instance; - return instance; -} - -void GraphRuntimeRegistry::RegRuntime(const mindspore::GraphRuntimeType &type, const GraphRuntimeRegFunc &creator) { - if (creator == nullptr) { - return; +std::shared_ptr AscendAclExecutorPluginImpl::InitAclGraphExecutor( + const std::shared_ptr &context, const ConfigInfos &config_infos) { + if (context == nullptr) { + MS_LOG(ERROR) << "Parameter context cannot be nullptr"; + return nullptr; } - graph_runtime_map_[type] = creator; -} - -std::shared_ptr GraphRuntimeRegistry::GetRuntime( - const mindspore::GraphRuntimeType &type) { - auto it = graph_runtime_map_.find(type); - if (it == graph_runtime_map_.end()) { + auto acl_graph_executor = std::make_shared(context, config_infos); + if (acl_graph_executor == nullptr) { + MS_LOG(ERROR) << "Failed to create GeGraphExecutor"; + return nullptr; + } + if (!acl_graph_executor->Init()) { + MS_LOG(ERROR) << "Failed to init ge graph executor"; return nullptr; } - return it->second(); + return acl_graph_executor; } + +AscendAclExecutorPluginImpl *CreateAscendAclExecutorPluginImpl() { return new AscendAclExecutorPluginImpl(); } } // namespace mindspore diff --git a/mindspore-lite/tools/graph_kernel/converter/graph_kernel_optimization.h b/mindspore-lite/src/extendrt/delegate/ascend_acl/acl_plugin_impl.h similarity index 35% rename from mindspore-lite/tools/graph_kernel/converter/graph_kernel_optimization.h rename to mindspore-lite/src/extendrt/delegate/ascend_acl/acl_plugin_impl.h index b220a539bcdeab07fce44cc6e7281e660e2757a5..2a169e23efcd936b0b08c9a6d9b2c79d774929cb 100644 --- a/mindspore-lite/tools/graph_kernel/converter/graph_kernel_optimization.h +++ b/mindspore-lite/src/extendrt/delegate/ascend_acl/acl_plugin_impl.h @@ -13,46 +13,33 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_GRAPH_KERNEL_OPTIMIZATION_H_ -#define MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_GRAPH_KERNEL_OPTIMIZATION_H_ +#ifndef MINDSPORE_LITE_SRC_EXTENDRT_ASCEND_ACL_ACL_PLUGIN_IMPL_H_ +#define MINDSPORE_LITE_SRC_EXTENDRT_ASCEND_ACL_ACL_PLUGIN_IMPL_H_ #include -#include "ir/anf.h" -#include "ir/func_graph.h" -#include "include/errorcode.h" -#include "tools/converter/cxx_api/converter_para.h" -#include "tools/graph_kernel/converter/graph_kernel_pass_manager_lite.h" +#include "include/api/status.h" +#include "src/common/log_adapter.h" +#include "extendrt/delegate/plugin/ascend_acl_executor_plugin.h" +// #include "extendrt/delegate/ascend_ge/ge_device_context.h" +#include "extendrt/delegate/ascend_acl/acl_graph_executor.h" namespace mindspore { -namespace graphkernel { -class GraphKernelOptimizer { +class AscendAclExecutorPluginImpl : public lite::AscendAclExecutorPluginImplBase { public: - explicit GraphKernelOptimizer(const std::shared_ptr ¶m) : converter_param_(param) {} - ~GraphKernelOptimizer() = default; - void Run(const FuncGraphPtr &func_graph); + AscendAclExecutorPluginImpl() = default; + ~AscendAclExecutorPluginImpl() = default; +// Status AdaptGraph(FuncGraphPtr graph) const override; +// bool AoeTuning(const FuncGraphPtr &graph, const std::shared_ptr &context, +// const ConfigInfos &config_infos) override; - private: - void Init() const; - // Pre-process - GkPassManagerPtr PreProcess() const; - // Cluster kernels - GkPassManagerPtr Cluster() const; - // Optimize 1 - GkPassManagerPtr HighLevelOpt1() const; - // Split kernels - GkPassManagerPtr Split() const; - // Build akg kernel - GkPassManagerPtr BuildKernel() const; - // Post-process - GkPassManagerPtr PostProcess() const; - - std::shared_ptr converter_param_; +// bool OfflineBuildGraph(const FuncGraphPtr &graph, const std::shared_ptr &context, +// const ConfigInfos &config_infos) override; - bool is_cpu{false}; - bool is_ascend{false}; + private: + std::shared_ptr InitAclGraphExecutor(const std::shared_ptr &context, + const ConfigInfos &config_infos); }; -} // namespace graphkernel -lite::STATUS GraphKernelOptimize(const FuncGraphPtr &func_graph, const std::shared_ptr ¶m); +extern "C" MS_API AscendAclExecutorPluginImpl *CreateAscendAclExecutorPluginImpl(); } // namespace mindspore -#endif // MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_GRAPH_KERNEL_OPTIMIZATION_H_ +#endif // MINDSPORE_LITE_SRC_EXTENDRT_ASCEND_ACL_ACL_PLUGIN_IMPL_H_ diff --git a/mindspore-lite/src/extendrt/kernel/ascend/plugin/ascend_allocator_plugin.cc b/mindspore-lite/src/extendrt/delegate/ascend_acl/ascend_allocator_plugin.cc similarity index 97% rename from mindspore-lite/src/extendrt/kernel/ascend/plugin/ascend_allocator_plugin.cc rename to mindspore-lite/src/extendrt/delegate/ascend_acl/ascend_allocator_plugin.cc index 5c0ec6995129203e9184f4f07d14e04c6d225317..689a31333af0dd52aa7f6c5df105f63fd97395b3 100644 --- a/mindspore-lite/src/extendrt/kernel/ascend/plugin/ascend_allocator_plugin.cc +++ b/mindspore-lite/src/extendrt/delegate/ascend_acl/ascend_allocator_plugin.cc @@ -14,15 +14,15 @@ * limitations under the License. */ -#include "src/extendrt/kernel/ascend/plugin/ascend_allocator_plugin.h" +#include "src/extendrt/delegate/ascend_acl/ascend_allocator_plugin.h" #include #if !defined(_WIN32) #include "src/extendrt/cxx_api/dlutils.h" #endif -namespace mindspore::kernel { +namespace mindspore { namespace { -constexpr auto kAscendkernelPluginSoNmae = "libascend_kernel_plugin.so"; +constexpr auto kAscendkernelPluginSoNmae = "libascend_acl_plugin.so"; constexpr auto kFunCreateAscendAllocatorPluginImpl = "CreateAclAllocator"; #if !defined(_WIN32) std::mutex mutex_; @@ -236,4 +236,4 @@ Status AscendAllocatorPlugin::CopyHostDataToDevice(void *host_data, void *device #endif return kSuccess; } -} // namespace mindspore::kernel +} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/kernel/ascend/plugin/ascend_allocator_plugin.h b/mindspore-lite/src/extendrt/delegate/ascend_acl/ascend_allocator_plugin.h similarity index 91% rename from mindspore-lite/src/extendrt/kernel/ascend/plugin/ascend_allocator_plugin.h rename to mindspore-lite/src/extendrt/delegate/ascend_acl/ascend_allocator_plugin.h index 6e74d6c5ff8e689a116028fb35ae6e7b8968ad32..c61d35a98afbeb32ac67132b246ad4fc2ec701c9 100644 --- a/mindspore-lite/src/extendrt/kernel/ascend/plugin/ascend_allocator_plugin.h +++ b/mindspore-lite/src/extendrt/delegate/ascend_acl/ascend_allocator_plugin.h @@ -14,12 +14,12 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_ASCEND_ALLOCATOR_PLUGIN_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_ASCEND_ALLOCATOR_PLUGIN_H_ +#ifndef delegate_MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_ASCEND_ALLOCATOR_PLUGIN_H_ +#define delegate_MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_ASCEND_ALLOCATOR_PLUGIN_H_ #include #include #include "include/api/status.h" -namespace mindspore::kernel { +namespace mindspore { class AscendAllocatorPluginImpl { public: AscendAllocatorPluginImpl() = default; @@ -61,5 +61,5 @@ class MS_API AscendAllocatorPlugin { bool is_registered_ = false; std::shared_ptr ascend_allocator_plugin_impl_ = nullptr; }; -} // namespace mindspore::kernel +} // namespace mindspore #endif diff --git a/mindspore-lite/src/extendrt/kernel/ascend/model/dyn_shape_process.cc b/mindspore-lite/src/extendrt/delegate/ascend_acl/dyn_shape_process.cc similarity index 97% rename from mindspore-lite/src/extendrt/kernel/ascend/model/dyn_shape_process.cc rename to mindspore-lite/src/extendrt/delegate/ascend_acl/dyn_shape_process.cc index d3f531f67484d46086b332791631d5762898716c..17cab81dda8a30cb30e8dd48a93712cc9ba72413 100644 --- a/mindspore-lite/src/extendrt/kernel/ascend/model/dyn_shape_process.cc +++ b/mindspore-lite/src/extendrt/delegate/ascend_acl/dyn_shape_process.cc @@ -14,13 +14,13 @@ * limitations under the License. */ -#include "extendrt/kernel/ascend/model/dyn_shape_process.h" +#include "extendrt/delegate/ascend_acl/dyn_shape_process.h" #include #include "mindspore/ops/kernel/cpu/nnacl/op_base.h" #include "include/errorcode.h" +#include "src/common/log.h" -namespace mindspore::kernel { -namespace acl { +namespace mindspore { namespace { constexpr auto kInputDimNum = 4; constexpr auto kNHWCNIdx = 0; @@ -281,10 +281,9 @@ bool DynShapeProcess::GetRealImageSize(const std::vector &new_shape << " is invalid, please check device info of context."; return false; } - *height_p = LongToInt(height); - *width_p = LongToInt(width); - MS_LOG(DEBUG) << "Current height " << height << " width " << width; + *height_p = static_cast(height); + *width_p = static_cast(width); + MS_LOG(ERROR) << "Current height " << height << " width " << width; return true; } -} // namespace acl -} // namespace mindspore::kernel +} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/kernel/ascend/model/dyn_shape_process.h b/mindspore-lite/src/extendrt/delegate/ascend_acl/dyn_shape_process.h similarity index 86% rename from mindspore-lite/src/extendrt/kernel/ascend/model/dyn_shape_process.h rename to mindspore-lite/src/extendrt/delegate/ascend_acl/dyn_shape_process.h index 621a18bfb945d075d310cc04d15c8d3ab7b10035..858d2f6cb392a91d0348d342fc8c4aa230c09a43 100644 --- a/mindspore-lite/src/extendrt/kernel/ascend/model/dyn_shape_process.h +++ b/mindspore-lite/src/extendrt/delegate/ascend_acl/dyn_shape_process.h @@ -20,15 +20,16 @@ #include #include #include -#include "extendrt/kernel/ascend/options/acl_model_options.h" -#include "common/kernel.h" +#include "extendrt/delegate/ascend_acl/acl_model_options.h" #include "include/api/types.h" #include "acl/acl.h" #include "acl/acl_mdl.h" #include "acl/acl_rt.h" +using ShapeValueDType = int64_t; +using ShapeVector = std::vector; +using ShapeArray = std::vector; -namespace mindspore::kernel { -namespace acl { +namespace mindspore { class DynShapeProcess { public: bool Init(const AclDynamicShapeOptions &options); @@ -48,7 +49,6 @@ class DynShapeProcess { size_t input_data_idx_ = 0; }; -using DynShapeProcPtr = std::shared_ptr; -} // namespace acl -} // namespace mindspore::kernel +// using DynShapeProcPtr = std::shared_ptr; +} // namespace mindspore #endif // MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_MODEL_DYN_SHAPE_PROCESS_H diff --git a/mindspore-lite/src/extendrt/kernel/ascend/model/model_infer.cc b/mindspore-lite/src/extendrt/delegate/ascend_acl/model_infer.cc similarity index 79% rename from mindspore-lite/src/extendrt/kernel/ascend/model/model_infer.cc rename to mindspore-lite/src/extendrt/delegate/ascend_acl/model_infer.cc index 9c1537e68588259bc05a2d2b846dbd96f918c4b2..2b0a76e2930df01b4878918997abf3eb0d68764f 100644 --- a/mindspore-lite/src/extendrt/kernel/ascend/model/model_infer.cc +++ b/mindspore-lite/src/extendrt/delegate/ascend_acl/model_infer.cc @@ -14,17 +14,16 @@ * limitations under the License. */ -#include "extendrt/kernel/ascend/model/model_infer.h" +#include "extendrt/delegate/ascend_acl/model_infer.h" #include "common/log_adapter.h" #include "plugin/res_manager/ascend/symbol_interface/acl_rt_symbol.h" #include "plugin/res_manager/ascend/symbol_interface/symbol_utils.h" -namespace mindspore::kernel { -namespace acl { +namespace mindspore { namespace { std::mutex g_context_mutex; } -ModelInfer::ModelInfer(const AclModelOptionsPtr &options) +ModelInfer::ModelInfer(const std::shared_ptr &options) : init_flag_(false), device_type_("AscendCL"), context_(nullptr), @@ -62,22 +61,23 @@ bool ModelInfer::Init() { } MS_LOG(INFO) << "Open device " << device_id << " success."; - std::string overflow_mode = common::GetEnv("MS_ASCEND_CHECK_OVERFLOW_MODE"); - if (overflow_mode == "INFNAN_MODE") { - auto mode = aclrtFloatOverflowMode::ACL_RT_OVERFLOW_MODE_INFNAN; - ret = CALL_ASCEND_API(aclrtSetDeviceSatMode, mode); - if (ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "Set INFNAN mode failed"; - return false; - } - } else if (overflow_mode == "SATURATION_MODE") { - auto mode = aclrtFloatOverflowMode::ACL_RT_OVERFLOW_MODE_SATURATION; - ret = CALL_ASCEND_API(aclrtSetDeviceSatMode, mode); - if (ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "Set SATURATION mode failed"; - return false; - } - } + // todo + // std::string overflow_mode = common::GetEnv("MS_ASCEND_CHECK_OVERFLOW_MODE"); + // if (overflow_mode == "INFNAN_MODE") { + // auto mode = aclrtFloatOverflowMode::ACL_RT_OVERFLOW_MODE_INFNAN; + // ret = CALL_ASCEND_API(aclrtSetDeviceSatMode, mode); + // if (ret != ACL_SUCCESS) { + // MS_LOG(ERROR) << "Set INFNAN mode failed"; + // return false; + // } + // } else if (overflow_mode == "SATURATION_MODE") { + // auto mode = aclrtFloatOverflowMode::ACL_RT_OVERFLOW_MODE_SATURATION; + // ret = CALL_ASCEND_API(aclrtSetDeviceSatMode, mode); + // if (ret != ACL_SUCCESS) { + // MS_LOG(ERROR) << "Set SATURATION mode failed"; + // return false; + // } + // } ret = CALL_ASCEND_API(aclrtGetCurrentContext, &context_); if (ret != ACL_SUCCESS) { @@ -171,7 +171,7 @@ bool ModelInfer::Load(const void *om_data, size_t om_data_size) { return true; } -bool ModelInfer::Inference(const std::vector &inputs, const std::vector &outputs) { +bool ModelInfer::Inference(const std::vector &inputs, std::vector *outputs) { aclError rt_ret = CALL_ASCEND_API(aclrtSetCurrentContext, context_); if (rt_ret != ACL_SUCCESS) { MS_LOG(ERROR) << "Set the ascend device context failed, ret = " << rt_ret; @@ -180,12 +180,12 @@ bool ModelInfer::Inference(const std::vector &inputs, const std: auto ret = model_process_.PredictFromHost(inputs, outputs); if (!ret) { MS_LOG(ERROR) << "Model predict failed"; - return ret; + return false; } return true; } -bool ModelInfer::UpdateWeights(const std::vector &inputs) { +bool ModelInfer::UpdateWeights(const std::vector &inputs) { aclError rt_ret = CALL_ASCEND_API(aclrtSetCurrentContext, context_); if (rt_ret != ACL_SUCCESS) { MS_LOG(ERROR) << "Set the ascend device context failed, ret = " << rt_ret; @@ -193,26 +193,30 @@ bool ModelInfer::UpdateWeights(const std::vector &inputs) { } auto ret = model_process_.UpdateWeights(inputs); if (!ret) { - MS_LOG(ERROR) << "Update weights failed!"; - return ret; + MS_LOG(ERROR) << "Model UpdateWeights failed"; + return false; } return true; } std::vector ModelInfer::GetInputFormat() { return model_process_.GetInputFormat(); } -const std::vector ModelInfer::GetOutputShape() { return model_process_.GetOutputShape(); } -const std::vector ModelInfer::GetInputShape() { return model_process_.GetInputShape(); } +const std::vector> ModelInfer::GetOutputShape() { return model_process_.GetOutputShape(); } +const std::vector> ModelInfer::GetInputShape() { return model_process_.GetInputShape(); } const std::vector ModelInfer::GetInputDataType() { return model_process_.GetInputDataType(); } const std::vector ModelInfer::GetOutputDataType() { return model_process_.GetOutputDataType(); } std::vector ModelInfer::GetOutputFormat() { return model_process_.GetOutputFormat(); } -bool ModelInfer::Resize(const std::vector &new_shapes) { +bool ModelInfer::Resize(const std::vector> &new_shapes) { aclError rt_ret = CALL_ASCEND_API(aclrtSetCurrentContext, context_); if (rt_ret != ACL_SUCCESS) { MS_LOG(ERROR) << "Set the ascend device context failed, ret = " << rt_ret; return false; } - return model_process_.Resize(new_shapes); + auto ret = model_process_.Resize(new_shapes); + if (!ret) { + MS_LOG(ERROR) << "model resize failed."; + return false; + } + return true; } -} // namespace acl -} // namespace mindspore::kernel +} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/kernel/ascend/model/model_infer.h b/mindspore-lite/src/extendrt/delegate/ascend_acl/model_infer.h similarity index 58% rename from mindspore-lite/src/extendrt/kernel/ascend/model/model_infer.h rename to mindspore-lite/src/extendrt/delegate/ascend_acl/model_infer.h index cecb9c33e94d0336813043857cd5a9b8de1df026..8fcd91dee6a0cc9365f0a585c12ad60e417b06a1 100644 --- a/mindspore-lite/src/extendrt/kernel/ascend/model/model_infer.h +++ b/mindspore-lite/src/extendrt/delegate/ascend_acl/model_infer.h @@ -14,56 +14,53 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_MODEL_MODEL_INFER_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_MODEL_MODEL_INFER_H_ +#ifndef DELEGATE_MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_MODEL_MODEL_INFER_H_ +#define DELEGATE_MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_MODEL_MODEL_INFER_H_ #include #include #include #include #include -#include "extendrt/kernel/ascend/model/model_process.h" -#include "extendrt/kernel/ascend/model/acl_env_guard.h" -#include "extendrt/kernel/ascend/options/acl_model_options.h" #include "include/api/types.h" #include "include/errorcode.h" -#include "extendrt/kernel/ascend/profiling/profiling.h" -namespace mindspore::kernel { -namespace acl { -using mindspore::lite::STATUS; +#include "extendrt/delegate/ascend_acl/model_process.h" +#include "extendrt/delegate/ascend_acl/acl_env_guard.h" +#include "extendrt/delegate/ascend_acl/acl_model_options.h" +#include "extendrt/delegate/ascend_acl/profiling.h" +#include "mindspore/core/include/mindapi/base/type_id.h" +namespace mindspore { +// using mindspore::lite::STATUS; class ModelInfer { public: - explicit ModelInfer(const AclModelOptionsPtr &options); + explicit ModelInfer(const std::shared_ptr &options); ~ModelInfer() = default; bool Init(); bool Finalize(bool process_ends = false); bool Load(const void *om_data, size_t om_data_size); - bool Inference(const std::vector &inputs, const std::vector &outputs); - bool UpdateWeights(const std::vector &inputs); + bool Inference(const std::vector &inputs, std::vector *outputs); + bool UpdateWeights(const std::vector &inputs); std::vector GetInputFormat(); - const std::vector GetOutputShape(); - const std::vector GetInputShape(); + const std::vector> GetOutputShape(); + const std::vector> GetInputShape(); const std::vector GetInputDataType(); const std::vector GetOutputDataType(); std::vector GetOutputFormat(); - bool Resize(const std::vector &new_shapes); + bool Resize(const std::vector> &new_shapes); private: bool init_flag_; std::string device_type_; aclrtContext context_; aclrtStream stream_; - AclModelOptionsPtr options_; + std::shared_ptr options_; ModelProcess model_process_; Profiling profiling_; std::shared_ptr acl_env_; }; - -using ModelInferPtr = std::shared_ptr; -} // namespace acl -} // namespace mindspore::kernel +} // namespace mindspore #endif // MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_MODEL_MODEL_INFER_H_ diff --git a/mindspore-lite/src/extendrt/kernel/ascend/model/model_process.cc b/mindspore-lite/src/extendrt/delegate/ascend_acl/model_process.cc similarity index 82% rename from mindspore-lite/src/extendrt/kernel/ascend/model/model_process.cc rename to mindspore-lite/src/extendrt/delegate/ascend_acl/model_process.cc index 24e460e7433b3e0ad91e6cfcd803b1cf2873c190..aba46ca78a13019a94ed06619e438467b3882a25 100644 --- a/mindspore-lite/src/extendrt/kernel/ascend/model/model_process.cc +++ b/mindspore-lite/src/extendrt/delegate/ascend_acl/model_process.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "extendrt/kernel/ascend/model/model_process.h" +#include "extendrt/delegate/ascend_acl/model_process.h" #include #include #include @@ -23,16 +23,14 @@ #include "common/log_adapter.h" #include "src/common/utils.h" #include "src/common/log_util.h" -#include "src/litert/kernel/ascend/src/acl_mem_manager.h" -#include "src/extendrt/kernel/ascend/model/acl_allocator.h" +#include "src/extendrt/delegate/ascend_acl/acl_allocator.h" #include "plugin/res_manager/ascend/symbol_interface/acl_base_symbol.h" #include "plugin/res_manager/ascend/symbol_interface/acl_mdl_symbol.h" #include "plugin/res_manager/ascend/symbol_interface/acl_rt_symbol.h" #include "plugin/res_manager/ascend/symbol_interface/acl_symbol.h" #include "plugin/res_manager/ascend/symbol_interface/symbol_utils.h" -namespace mindspore::kernel { -namespace acl { +namespace mindspore { namespace { constexpr size_t kBatchSizeNum = 1; constexpr size_t kImageSizeHwNum = 2; @@ -109,11 +107,11 @@ aclError ModelProcess::AclrtMemcpy(void *dst, size_t destMax, const void *src, s (kUSecondInSecond * static_cast(end_time.tv_sec) + static_cast(end_time.tv_usec)) - (kUSecondInSecond * static_cast(start_time.tv_sec) + static_cast(start_time.tv_usec)); if (kind == ACL_MEMCPY_DEVICE_TO_HOST) { - MS_LOG(DEBUG) << "Device to Host copy in " << cost << " us"; + MS_LOG(ERROR) << "Device to Host copy in " << cost << " us"; } else if (kind == ACL_MEMCPY_HOST_TO_DEVICE) { - MS_LOG(DEBUG) << "Host to Device copy in " << cost << " us"; + MS_LOG(ERROR) << "Host to Device copy in " << cost << " us"; } else if (kind == ACL_MEMCPY_DEVICE_TO_DEVICE) { - MS_LOG(DEBUG) << "Device to Device copy in " << cost << " us"; + MS_LOG(ERROR) << "Device to Device copy in " << cost << " us"; } } return ret; @@ -1006,7 +1004,7 @@ bool ModelProcess::ResizeDynamicBatchAndImageSize(const std::vector return true; } -bool ModelProcess::CheckInputTensors(const std::vector &input_tensors) { +bool ModelProcess::CheckInputTensors(const std::vector &input_tensors) { if (data_input_num_ != input_tensors.size()) { MS_LOG(ERROR) << "Expect input size to be " << data_input_num_ << ", but got " << input_tensors.size(); return false; @@ -1014,29 +1012,29 @@ bool ModelProcess::CheckInputTensors(const std::vector &input_te for (size_t i = 0; i < input_tensors.size(); ++i) { auto &tensor = input_tensors[i]; auto &info = input_infos_[i]; - if (tensor->GetShapeVector() != info.dims) { + if (tensor.Shape() != info.dims) { MS_LOG(WARNING) << "Note: input " << i << " shape not match, required " << ShapeToString(info.dims) << ", given " - << ShapeToString(tensor->GetShapeVector()) << "." + << ShapeToString(tensor.Shape()) << "." << "Please check input shape has been modified by DVPP method."; } - if (tensor->dtype_id() != TransToDataType(info.data_type)) { + if (static_cast(tensor.DataType()) != TransToDataType(info.data_type)) { MS_LOG(ERROR) << "Note: input " << i << " data type not match, required " << static_cast(TransToDataType(info.data_type)) << ", given " - << static_cast(tensor->dtype_id()); + << static_cast(tensor.DataType()); return false; } - auto device_data = tensor->GetData(); - auto host_data = tensor->GetHostData(); - if (device_data != nullptr && device_data->addr != nullptr) { - if (!is_dynamic_input_ && !is_dynamic_shape_range_ && device_data->size != info.buffer_size) { + void *device_data_addr = static_cast(tensor).GetDeviceData(); + auto host_data_addr = tensor.Data().get(); + if (device_data_addr != nullptr) { + if (!is_dynamic_input_ && !is_dynamic_shape_range_ && tensor.DataSize() != info.buffer_size) { MS_LOG(ERROR) << "Input " << i << " data size not match, required size " << info.buffer_size << ", given count " - << device_data->size; + << tensor.DataSize(); return false; } - } else if (host_data != nullptr && host_data->addr != nullptr) { - if (!is_dynamic_input_ && !is_dynamic_shape_range_ && host_data->size != info.buffer_size) { + } else if (host_data_addr != nullptr) { + if (!is_dynamic_input_ && !is_dynamic_shape_range_ && tensor.DataSize() != info.buffer_size) { MS_LOG(ERROR) << "Input " << i << " data size not match, required size " << info.buffer_size << ", given count " - << host_data->size; + << tensor.DataSize(); return false; } } else { @@ -1047,42 +1045,42 @@ bool ModelProcess::CheckInputTensors(const std::vector &input_te return true; } -bool ModelProcess::CheckOutputTensors(const std::vector &outputs) { - if (outputs.size() != output_infos_.size()) { +bool ModelProcess::CheckOutputTensors(const std::vector *outputs) { + if (outputs->size() != output_infos_.size()) { MS_LOG(ERROR) << "Actual tensor count not match, required count " << output_infos_.size() << ", given count " - << outputs.size(); + << outputs->size(); return false; } if (is_dynamic_output_) { MS_LOG(INFO) << "This Model has dynamic output shape."; return true; } - for (size_t i = 0; i < outputs.size(); ++i) { - auto &tensor = outputs[i]; + for (size_t i = 0; i < outputs->size(); ++i) { + auto &tensor = outputs->at(i); auto &info = output_infos_[i]; - if (tensor->GetShapeVector() != info.dims) { + if (tensor.Shape() != info.dims) { MS_LOG(WARNING) << "Note: output " << i << " shape not match, required " << ShapeToString(info.dims) << ", given " - << ShapeToString(tensor->GetShapeVector()) << "." + << ShapeToString(tensor.Shape()) << "." << "Please check output shape."; } - if (tensor->dtype_id() != TransToDataType(info.data_type)) { + if (static_cast(tensor.DataType()) != TransToDataType(info.data_type)) { MS_LOG(ERROR) << "Note: output " << i << " data type not match, required " << static_cast(TransToDataType(info.data_type)) << ", given " - << static_cast(tensor->dtype_id()); + << static_cast(static_cast(tensor.DataType())); return false; } - auto device_data = tensor->GetData(); - auto host_data = tensor->GetHostData(); - if (device_data != nullptr && device_data->addr != nullptr) { - if (device_data->size != info.buffer_size) { + auto device_data_addr = static_cast(tensor).GetDeviceData(); + auto host_data_addr = tensor.Data().get(); + if (device_data_addr != nullptr) { + if (tensor.DataSize() != info.buffer_size) { MS_LOG(ERROR) << "Output " << i << " device data size not match, required size " << info.buffer_size - << ", given count " << tensor->GetData()->size; + << ", given count " << tensor.DataSize(); return false; } - } else if (host_data != nullptr && host_data->addr != nullptr) { - if (host_data->size != info.buffer_size) { + } else if (host_data_addr != nullptr) { + if (tensor.DataSize() != info.buffer_size) { MS_LOG(ERROR) << "Output " << i << " host data size not match, required size " << info.buffer_size - << ", given count " << tensor->GetData()->size; + << ", given count " << tensor.DataSize(); return false; } } else { @@ -1093,7 +1091,7 @@ bool ModelProcess::CheckOutputTensors(const std::vector &outputs return true; } -bool ModelProcess::CheckAndInitInput(const std::vector &inputs) { +bool ModelProcess::CheckAndInitInput(const std::vector &inputs) { // check inputs if (!CheckInputTensors(inputs)) { MS_LOG(ERROR) << "Check input tensor failed."; @@ -1105,16 +1103,16 @@ bool ModelProcess::CheckAndInitInput(const std::vector &inputs) auto &info = input_infos_[i]; auto input = inputs[i]; void *input_buffer = nullptr; - auto device_data = input->GetData(); - auto host_data = input->GetHostData(); - if (device_data && device_data->addr) { - auto input_device_id = input->device_id(); - if (input_device_id == IntToUint(device_id_)) { - input_buffer = device_data->addr; + auto device_data_addr = static_cast(input).GetDeviceData(); + auto host_data_addr = const_cast(input.Data().get()); + if (device_data_addr != nullptr) { + auto input_device_id = input.GetDeviceId(); + if (input_device_id == static_cast(device_id_)) { + input_buffer = device_data_addr; } else { // memcpy device data from src device to current device. - auto data_copy_size = inputs[i]->size(); - if (AscendAllocatorPlugin::GetInstance().CopyDeviceDataToDevice(device_data->addr, info.device_data, + auto data_copy_size = input.DataSize(); + if (AscendAllocatorPlugin::GetInstance().CopyDeviceDataToDevice(device_data_addr, info.device_data, data_copy_size, info.buffer_size, input_device_id, device_id_) != kSuccess) { MS_LOG(ERROR) << "Copy input data from device to current device failed."; @@ -1123,8 +1121,8 @@ bool ModelProcess::CheckAndInitInput(const std::vector &inputs) input_buffer = info.device_data; } } else { - auto data = host_data->addr; - auto size = host_data->size; + auto data = host_data_addr; + auto size = input.DataSize(); if (!is_run_on_device_) { ret = AclrtMemcpy(info.device_data, info.buffer_size, data, size, ACL_MEMCPY_HOST_TO_DEVICE); if (ret != ACL_SUCCESS) { @@ -1145,68 +1143,40 @@ bool ModelProcess::CheckAndInitInput(const std::vector &inputs) ret = aclUpdateDataBuffer(data_buffer, input_buffer, info.buffer_size); if (ret != ACL_SUCCESS) { MS_LOG(ERROR) << "Failed to update Data Buffer of input " << i << ", buffer size: " << info.buffer_size - << ", input shape: " << input->GetShapeVector(); + << ", input shape: " << input.Shape(); return false; } } return true; } -void ModelProcess::CheckAndInitDynOutputDeviceBuf(const KernelTensor *output, const AclTensorInfo &output_info, +void ModelProcess::CheckAndInitDynOutputDeviceBuf(const MSTensor output, const AclTensorInfo &output_info, void **output_device_buffer, size_t *output_buf_size, size_t output_idx) { - auto device_data = output->GetData(); - auto host_data = output->GetHostData(); - if ((host_data == nullptr) || (dyn_out_sys_buf_addr_.find(host_data->addr) != dyn_out_sys_buf_addr_.end()) || - (host_data->size == 0)) { - MS_LOG(DEBUG) << "host_data->addr: " << host_data->addr + auto device_data_addr = static_cast(output).GetDeviceData(); + auto host_data_addr = const_cast(output.Data().get()); + if ((device_data_addr == nullptr) || (dyn_out_sys_buf_addr_.find(host_data_addr) != dyn_out_sys_buf_addr_.end()) || + (output.DataSize() == 0)) { + MS_LOG(DEBUG) << "host_data->addr: " << host_data_addr << ", user not defined dynamic output buffer on host, using system defined buffer"; user_defined_output_buf_[output_idx] = false; } if (user_defined_output_buf_[output_idx]) { *output_device_buffer = output_info.device_data; - auto addr = (host_data != nullptr) ? host_data->addr : device_data->addr; - auto size = (host_data != nullptr) ? host_data->size : device_data->size; + auto addr = (host_data_addr != nullptr) ? host_data_addr : device_data_addr; + auto size = output.DataSize(); *output_buf_size = size; MS_LOG(DEBUG) << "found user buffer with addr: " << addr << " with size: " << size << ". init output device addr: " << output_info.device_data; } } -bool ModelProcess::CheckAndInitOutput(const std::vector &outputs) { - // check outputs - if (!CheckOutputTensors(outputs)) { - MS_LOG(ERROR) << "Check output tensor failed."; - return false; - } +bool ModelProcess::CheckAndInitOutput(const std::vector *outputs) { aclError ret; - // copy outputs - for (size_t i = 0; i < outputs.size(); ++i) { - auto &info = output_infos_[i]; - auto output = outputs[i]; + for (size_t i = 0; i < output_infos_.size(); ++i) { void *output_device_buffer = nullptr; - auto device_data = output->GetData(); - auto host_data = output->GetHostData(); - auto output_device_id = output->device_id(); - auto output_device_buffer_size = info.buffer_size; - bool is_dynamic = is_dynamic_input_ || is_dynamic_shape_range_ || is_dynamic_output_; - if (device_data && device_data->addr) { - output_device_buffer = (output_device_id == IntToUint(device_id_)) ? device_data->addr : info.device_data; - if (is_dynamic) { - output_device_buffer_size = device_data->size; // device data buffer size is needed for memory alloc - } - MS_LOG(DEBUG) << "user defined output device data addr: " << output_device_buffer - << ", with size: " << output_device_buffer_size; - } else if (host_data && host_data->addr && is_run_on_device_) { - output_device_buffer = host_data->addr; - } else { - output_device_buffer = info.device_data; - if (is_dynamic) { - output_device_buffer = nullptr; // in dynamic output shape, setting nullptr allows acl to alloc memory - output_device_buffer_size = 0; - CheckAndInitDynOutputDeviceBuf(output, info, &output_device_buffer, &output_device_buffer_size, i); - } - } + output_device_buffer = nullptr; // in dynamic output shape, setting nullptr allows acl to alloc memory + auto output_device_buffer_size = 0; auto data_buffer = CALL_ASCEND_API(aclmdlGetDatasetBuffer, outputs_, i); if (data_buffer == nullptr) { MS_LOG(ERROR) << "Failed to get dataset buffer of output " << i; @@ -1214,81 +1184,38 @@ bool ModelProcess::CheckAndInitOutput(const std::vector &outputs } ret = CALL_ASCEND_API(aclUpdateDataBuffer, data_buffer, output_device_buffer, output_device_buffer_size); if (ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "Failed to update Data Buffer of output " << i << ", buffer size: " << info.buffer_size - << ", output shape: " << output->GetShapeVector(); return false; } } return true; } -bool ModelProcess::ResetDynamicOutputTensor(const std::vector &outputs) { +bool ModelProcess::ResetDynamicOutputTensor(const std::vector *outputs) { dyn_out_sys_buf_addr_.clear(); for (size_t i = 0; i < output_infos_.size(); ++i) { - auto &output = outputs[i]; auto &output_info = output_infos_[i]; // get actual output tensor info aclTensorDesc *tensor_info = CALL_ASCEND_API(aclmdlGetDatasetTensorDesc, outputs_, i); size_t output_desc_size = CALL_ASCEND_API(aclGetTensorDescSize, tensor_info); - if (output_desc_size == 0) { - MS_LOG(ERROR) << "dynamic output size from acl inference result is 0, please check graph or inputs"; - return false; - } aclDataBuffer *data_buffer = CALL_ASCEND_API(aclmdlGetDatasetBuffer, outputs_, i); void *acl_device_data = CALL_ASCEND_API(aclGetDataBufferAddr, data_buffer); - - // update host address and size - auto host_data = output->GetHostData(); - auto device_data = output->GetData(); - if (device_data && device_data->addr) { - MS_LOG(DEBUG) << "data on device, no need to update system allocated buffer"; - auto output_device_id = output->device_id(); - output->SetHostData(nullptr); - output->SetData(std::make_shared(acl_device_data, output_desc_size)); - if (output_device_id != IntToUint(device_id_)) { - MS_LOG(DEBUG) << "output across device, tensor on device " << output_device_id << " with addr " - << device_data->addr << ", infer on device " << device_id_ << " with addr " << acl_device_data; - output->SetData(std::make_shared(device_data->addr, output_desc_size)); - output_info.cur_device_data = acl_device_data; - } - } else { - if (!user_defined_output_buf_[i]) { - // data_buf_ptr is passed to tensor ref data and will be freed in destructor - void *data_buf_ptr = kernel::AscendAllocatorPlugin::GetInstance().MallocHost(output_desc_size); - output->SetHostData(std::make_shared(data_buf_ptr, output_desc_size)); - output->SetData(nullptr); - (void)dyn_out_sys_buf_addr_.insert(output->GetHostData()->addr); - MS_LOG(DEBUG) << "no user provided output buffer, memory alloc by system with addr: " - << output->GetHostData()->addr << ", size: " << output_desc_size; - } else { - if (host_data == nullptr) { - MS_LOG(ERROR) << "critical error! found user defined buffer nullptr"; - return false; - } - MS_LOG(DEBUG) << "found user provided buffer addr: " << host_data->addr << ", size: " << host_data->size - << " no need to update system allocated buffer"; - } - } - - // update acl tensor info size_t dim_nums = CALL_ASCEND_API(aclGetTensorDescNumDims, tensor_info); ShapeVector shape; for (size_t j = 0; j < dim_nums; ++j) { int64_t shape_j = aclGetTensorDescDim(tensor_info, j); shape.emplace_back(shape_j); } - output->SetShapeVector(shape); output_info.device_data = acl_device_data; output_info.cur_device_data = acl_device_data; output_info.buffer_size = output_desc_size; output_info.malloc_buffer_size = output_desc_size; + output_info.dims = shape; } return true; } -bool ModelProcess::PredictFromHost(const std::vector &inputs, - const std::vector &outputs) { +bool ModelProcess::PredictFromHost(const std::vector &inputs, const std::vector *outputs) { if (!loaded_) { MS_LOG(ERROR) << "Model has not been loaded"; return false; @@ -1297,13 +1224,6 @@ bool ModelProcess::PredictFromHost(const std::vector &inputs, MS_LOG(ERROR) << "Check or init input failed"; return false; } - for (size_t i = 0; i < outputs.size(); i++) { - if (user_defined_output_buf_.size() < outputs.size()) { - user_defined_output_buf_.push_back(true); - } else { - user_defined_output_buf_[i] = true; - } - } if (!CheckAndInitOutput(outputs)) { MS_LOG(ERROR) << "Check output tensor failed"; return false; @@ -1345,20 +1265,26 @@ bool ModelProcess::PredictFromHost(const std::vector &inputs, if (!ret) { return false; } + } else { + for (size_t i = 0; i < output_infos_.size(); ++i) { + auto &output_info = output_infos_[i]; + aclDataBuffer *data_buffer = CALL_ASCEND_API(aclmdlGetDatasetBuffer, outputs_, i); + void *acl_device_data = CALL_ASCEND_API(aclGetDataBufferAddr, data_buffer); + output_info.device_data = acl_device_data; + output_info.cur_device_data = acl_device_data; + } } if (!GetOutputs(outputs)) { MS_LOG(ERROR) << "Build outputs failed"; return false; } // The device_data is malloced by acl, user need to free the addr - if (is_dynamic_output_) { - FreeResourceOutput(&output_infos_, outputs); - } - MS_LOG(INFO) << "Execute model success"; + FreeResourceOutput(&output_infos_, outputs); + return true; } -bool ModelProcess::CreateWeightsInput(const std::vector &kernel_inputs) { +bool ModelProcess::CreateWeightsInput(const std::vector &kernel_inputs) { MS_CHECK_TRUE_MSG(weight_inputs_ != nullptr, false, "Weight inputs is nullptr!"); MS_CHECK_TRUE_MSG(model_weight_desc_ != nullptr, false, "Weight desc is nullptr!"); size_t input_size = CALL_ASCEND_API(aclmdlGetNumInputs, model_weight_desc_); @@ -1368,19 +1294,19 @@ bool ModelProcess::CreateWeightsInput(const std::vector &kernel_ return false; } for (size_t i = 0; i < kernel_inputs.size(); ++i) { - auto kernel_input = kernel_inputs[i]; + auto kernel_input = kernel_inputs.at(i); auto &info = weight_input_infos_[i]; void *input_buffer = nullptr; - auto device_data = kernel_input->GetData(); - auto host_data = kernel_input->GetHostData(); - if (device_data && device_data->addr) { - auto input_device_id = kernel_input->device_id(); - if (input_device_id == IntToUint(device_id_)) { - input_buffer = device_data->addr; + auto device_data_addr = kernel_input.GetDeviceData(); + auto host_data_addr = kernel_input.Data().get(); + if (device_data_addr != nullptr) { + auto input_device_id = kernel_input.GetDeviceId(); + if (input_device_id == static_cast(device_id_)) { + input_buffer = device_data_addr; } else { // memcpy device data from src device to current device. - auto data_copy_size = kernel_input->size(); - if (AscendAllocatorPlugin::GetInstance().CopyDeviceDataToDevice(device_data->addr, info.device_data, + auto data_copy_size = kernel_input.DataSize(); + if (AscendAllocatorPlugin::GetInstance().CopyDeviceDataToDevice(device_data_addr, info.device_data, data_copy_size, info.buffer_size, input_device_id, device_id_) != kSuccess) { MS_LOG(ERROR) << "Copy input data from device to current device failed!"; @@ -1389,8 +1315,8 @@ bool ModelProcess::CreateWeightsInput(const std::vector &kernel_ input_buffer = info.device_data; } } else { - auto data = host_data->addr; - auto size = host_data->size; + auto data = host_data_addr; + auto size = kernel_input.DataSize(); if (size != info.buffer_size) { MS_LOG(ERROR) << "Buffer size: " << info.buffer_size << "!=" << "input size :" << size << ", current only support data type fp16!"; @@ -1409,7 +1335,7 @@ bool ModelProcess::CreateWeightsInput(const std::vector &kernel_ } input_buffer = info.device_data; } else { - input_buffer = data; + input_buffer = const_cast(data); } } auto data_buffer = CALL_ASCEND_API(aclmdlGetDatasetBuffer, weight_inputs_, i); @@ -1419,8 +1345,7 @@ bool ModelProcess::CreateWeightsInput(const std::vector &kernel_ } auto acl_ret = CALL_ASCEND_API(aclUpdateDataBuffer, data_buffer, input_buffer, info.buffer_size); if (acl_ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "Failed to update Data Buffer of input " << i << ", buffer size: " << info.buffer_size - << ", input shape: " << kernel_input->GetShapeVector() << "!"; + MS_LOG(ERROR) << "Failed to update Data Buffer of input "; return false; } } @@ -1460,7 +1385,7 @@ void ModelProcess::DestoryUpdateWeightBuffer() { } // for update weights -bool ModelProcess::InitUpdateWeightBuffer(const std::vector &kernel_inputs) { +bool ModelProcess::InitUpdateWeightBuffer(const std::vector &kernel_inputs) { weight_inputs_ = CALL_ASCEND_API(aclmdlCreateDataset); if (weight_inputs_ == nullptr) { MS_LOG(ERROR) << "Create input dataset failed"; @@ -1485,15 +1410,14 @@ bool ModelProcess::InitUpdateWeightBuffer(const std::vector &ker } for (size_t i = 0; i < input_size; ++i) { auto kernel_input = kernel_inputs[i]; - auto shape = kernel_input->GetShapeVector(); + auto shape = kernel_input.Shape(); aclDataType data_type = CALL_ASCEND_API(aclmdlGetInputDataType, model_weight_desc_, i); size_t type_size = 0; if (!GetSizeByDtype(data_type, &type_size)) { MS_LOG(ERROR) << "Get size of data type :" << data_type << " failed!"; return false; } - size_t tensor_size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); - size_t buffer_size = type_size * tensor_size; + auto buffer_size = kernel_input.DataSize(); void *data_mem_buffer = nullptr; if (!CreateDataBuffer(&data_mem_buffer, buffer_size, weight_inputs_)) { MS_LOG(ERROR) << "Add input data buffer failed, buffer size " << buffer_size << "!"; @@ -1520,7 +1444,7 @@ bool ModelProcess::InitUpdateWeightBuffer(const std::vector &ker return true; } -bool ModelProcess::UpdateWeights(const std::vector &kernel_weights) { +bool ModelProcess::UpdateWeights(const std::vector &kernel_weights) { if (!loaded_) { MS_LOG(ERROR) << "Model has not been loaded!"; return false; @@ -1571,15 +1495,15 @@ void ModelProcess::FreeResourceInput(std::vector acl_tensor_info) } void ModelProcess::FreeResourceOutput(std::vector *acl_tensor_info, - const std::vector &outputs) { + const std::vector *outputs) { for (size_t i = 0; i < acl_tensor_info->size(); i++) { auto &item = (*acl_tensor_info)[i]; - auto &output = outputs[i]; - auto device_data = output->GetData(); - if ((device_data && device_data->addr) || user_defined_output_buf_[i]) { - MS_LOG(DEBUG) << "found data managed by the user, skipping resource release"; - continue; - } +// auto &output = outputs->at(i); +// auto device_data_addr = static_cast(output).GetDeviceData(); +// if (device_data_addr || user_defined_output_buf_[i]) { +// MS_LOG(DEBUG) << "found data managed by the user, skipping resource release"; +// continue; +// } if (item.device_data != nullptr) { MS_LOG(DEBUG) << "freeing device buffer at addr: " << item.device_data; if (!is_run_on_device_) { @@ -1598,42 +1522,31 @@ void ModelProcess::FreeResourceOutput(std::vector *acl_tensor_inf } } -bool ModelProcess::GetOutputs(const std::vector &outputs) { - aclrtMemcpyKind kind = is_run_on_device_ ? ACL_MEMCPY_HOST_TO_HOST : ACL_MEMCPY_DEVICE_TO_HOST; +bool ModelProcess::GetOutputs(const std::vector *outputs) { + std::vector new_outputs; + aclrtMemcpyKind kind = ACL_MEMCPY_DEVICE_TO_HOST; + for (size_t i = 0; i < output_infos_.size(); ++i) { - auto &output = outputs[i]; auto &output_info = output_infos_[i]; - if (output_info.cur_device_data == nullptr) { - MS_LOG(WARNING) << "Output device add is nullptr."; - continue; - } - auto host_data = output->GetHostData(); - auto output_device_id = output->device_id(); - if (host_data && host_data->addr && !is_run_on_device_) { - if (host_data->size != output_info.buffer_size) { - MS_LOG(ERROR) << "Specified output host data size " << host_data->size << " != execute output data size " - << output_info.buffer_size << ", output shape: " << output_info.dims; - return false; - } - MS_LOG(DEBUG) << "copying to host with addr: " << host_data->addr << " with size: " << output_info.buffer_size; - auto ret = - AclrtMemcpy(host_data->addr, host_data->size, output_info.cur_device_data, output_info.buffer_size, kind); - if (ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "Memcpy output " << i << " from " << (is_run_on_device_ ? "host" : "device") - << " to host failed, memory size " << output_info.buffer_size << ", ret: " << ret; - return false; - } - } else if (output_device_id != IntToUint(device_id_)) { - // memcpy output data from current device to output device. - if (AscendAllocatorPlugin::GetInstance().CopyDeviceDataToDevice( - output_info.cur_device_data, output->GetData()->addr, output->size(), output_info.buffer_size, device_id_, - output_device_id) != kSuccess) { - MS_LOG(ERROR) << "Copy output data from device to current device failed."; - return false; - } + auto host_data = malloc(output_info.buffer_size); + auto ret = + AclrtMemcpy(host_data, output_info.buffer_size, output_info.cur_device_data, output_info.buffer_size, kind); + if (ret != ACL_SUCCESS) { + MS_LOG(ERROR) << "Memcpy output " << i << " from " << (is_run_on_device_ ? "host" : "device") + << " to host failed, memory size " << output_info.buffer_size << ", ret: " << ret; + return false; } + auto output = + mindspore::MSTensor::CreateTensor(output_info.name, static_cast(TransToDataType(output_info.data_type)), + output_info.dims, host_data, output_info.buffer_size); + + free(host_data); + host_data = nullptr; + new_outputs.push_back(*output); + delete output; } + const_cast *>(outputs)->clear(); + const_cast *>(outputs)->insert(outputs->end(), new_outputs.begin(), new_outputs.end()); return true; } -} // namespace acl -} // namespace mindspore::kernel +} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/kernel/ascend/model/model_process.h b/mindspore-lite/src/extendrt/delegate/ascend_acl/model_process.h similarity index 75% rename from mindspore-lite/src/extendrt/kernel/ascend/model/model_process.h rename to mindspore-lite/src/extendrt/delegate/ascend_acl/model_process.h index 079eeaa6fef86324140682e1dd17a68ad4cf5519..563cc12b1dbe77072c29b4cfd28f08cf79abf3b2 100644 --- a/mindspore-lite/src/extendrt/kernel/ascend/model/model_process.h +++ b/mindspore-lite/src/extendrt/delegate/ascend_acl/model_process.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_MODEL_MODEL_PROCESS_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_MODEL_MODEL_PROCESS_H_ +#ifndef delegate_MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_MODEL_MODEL_PROCESS_H_ +#define delegate_MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_MODEL_MODEL_PROCESS_H_ #include #include @@ -29,13 +29,12 @@ #include "acl/acl_rt.h" #include "include/api/types.h" #include "include/errorcode.h" -#include "common/kernel.h" -#include "extendrt/kernel/ascend/options/acl_model_options.h" -#include "extendrt/kernel/ascend/model/dyn_shape_process.h" -#include "src/litert/kernel/ascend/src/acl_mem_manager.h" +#include "extendrt/delegate/ascend_acl/acl_model_options.h" +#include "extendrt/delegate/ascend_acl/dyn_shape_process.h" +#include "extendrt/delegate/ascend_acl/acl_mem_manager.h" +#include "mindspore/core/include/mindapi/base/type_id.h" -namespace mindspore::kernel { -namespace acl { +namespace mindspore { struct AclTensorInfo { void *cur_device_data; void *device_data; @@ -50,13 +49,14 @@ struct AclTensorInfo { class ModelProcess { public: - explicit ModelProcess(const AclModelOptionsPtr &options) : options_(options), device_id_(options->device_id) {} + explicit ModelProcess(const std::shared_ptr &options) + : options_(options), device_id_(options->device_id) {} ~ModelProcess(); bool Load(const void *om_data, size_t om_data_size); bool UnLoad(); - bool PredictFromHost(const std::vector &inputs, const std::vector &outputs); - bool UpdateWeights(const std::vector &inputs); + bool PredictFromHost(const std::vector &inputs, const std::vector *outputs); + bool UpdateWeights(const std::vector &inputs); // override this method to avoid request/reply data copy void SetIsDevice(bool is_device) { is_run_on_device_ = is_device; } @@ -82,14 +82,14 @@ class ModelProcess { void DestroyOutputsBuffer(); bool CreateDataBuffer(void **data_mem_buffer, size_t buffer_size, aclmdlDataset *dataset); - bool CheckAndInitInput(const std::vector &inputs); - bool CheckAndInitOutput(const std::vector &outputs); - void CheckAndInitDynOutputDeviceBuf(const KernelTensor *output, const AclTensorInfo &output_info, + bool CheckAndInitInput(const std::vector &inputs); + bool CheckAndInitOutput(const std::vector *outputs); + void CheckAndInitDynOutputDeviceBuf(const MSTensor output, const AclTensorInfo &output_info, void **output_device_buffer, size_t *output_buf_size, size_t output_idx); - bool CheckInputTensors(const std::vector &inputs); - bool CheckOutputTensors(const std::vector &outputs); + bool CheckInputTensors(const std::vector &inputs); + bool CheckOutputTensors(const std::vector *outputs); bool CheckAndSetDynFlag(); - bool GetOutputs(const std::vector &outputs); + bool GetOutputs(const std::vector *outputs); bool ResetInputSize(const std::vector &new_shapes); bool ResetOutputSize(); @@ -97,22 +97,22 @@ class ModelProcess { bool IsDynamicBatchSize(); bool IsDynamicImageSize(); bool IsDynamicDims(); - bool ResetDynamicOutputTensor(const std::vector &outputs); + bool ResetDynamicOutputTensor(const std::vector *outputs); bool ResizeDynamicInputShape(const std::vector &new_shapes); bool ResizeDynamicInputShapeRange(const std::vector &new_shapes); bool ResizeDynamicBatchAndImageSize(const std::vector &new_shapes); void FreeResourceInput(std::vector acl_tensor_info); - void FreeResourceOutput(std::vector *acl_tensor_info, const std::vector &outputs); + void FreeResourceOutput(std::vector *acl_tensor_info, const std::vector *outputs); aclError AclrtMemcpy(void *dst, size_t destMax, const void *src, size_t count, aclrtMemcpyKind kind); bool PrepareMutiModelShare(const void *om_data, size_t om_data_size); - bool InitUpdateWeightBuffer(const std::vector &kernel_inputs); - bool CreateWeightsInput(const std::vector &inputs); + bool InitUpdateWeightBuffer(const std::vector &kernel_inputs); + bool CreateWeightsInput(const std::vector &inputs); void DestoryUpdateWeightBuffer(); bool ShareWeightspaceProcess(const size_t &work_size); bool ShareWorkspaceProcess(const size_t &work_size, const size_t &weight_size); bool ShareWorkspaceAndWeightspaceProcess(const size_t &work_size); - AclModelOptionsPtr options_; + std::shared_ptr options_; uint32_t model_id_ = UINT32_MAX; // if run one device(AICPU), there is no need to alloc device memory and copy inputs to(/outputs from) device bool is_run_on_device_ = false; @@ -148,6 +148,5 @@ class ModelProcess { uint32_t infer_id_ = 0; uint32_t update_id_ = 0; }; -} // namespace acl -} // namespace mindspore::kernel +} // namespace mindspore #endif // MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_MODEL_MODEL_PROCESS_H_ diff --git a/mindspore-lite/src/extendrt/kernel/ascend/profiling/profiling.cc b/mindspore-lite/src/extendrt/delegate/ascend_acl/profiling.cc similarity index 97% rename from mindspore-lite/src/extendrt/kernel/ascend/profiling/profiling.cc rename to mindspore-lite/src/extendrt/delegate/ascend_acl/profiling.cc index f6ce41ef5d9b2a0bb3e39e5021c1095146e0384b..70f83f1556e7e8fb4434a08b83bcd7f076d89294 100644 --- a/mindspore-lite/src/extendrt/kernel/ascend/profiling/profiling.cc +++ b/mindspore-lite/src/extendrt/delegate/ascend_acl/profiling.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "extendrt/kernel/ascend/profiling/profiling.h" +#include "extendrt/delegate/ascend_acl/profiling.h" #include #include #include @@ -23,8 +23,7 @@ #include "plugin/res_manager/ascend/symbol_interface/acl_prof_symbol.h" #include "plugin/res_manager/ascend/symbol_interface/symbol_utils.h" -namespace mindspore::kernel { -namespace acl { +namespace mindspore { namespace { std::map kAicMetrics{{"ArithmeticUtilization", ACL_AICORE_ARITHMETIC_UTILIZATION}, {"PipeUtilization", ACL_AICORE_PIPE_UTILIZATION}, @@ -150,5 +149,4 @@ bool Profiling::StopProfiling(const aclrtStream &stream) { } return true; } -} // namespace acl -} // namespace mindspore::kernel +} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/kernel/ascend/profiling/profiling.h b/mindspore-lite/src/extendrt/delegate/ascend_acl/profiling.h similarity index 83% rename from mindspore-lite/src/extendrt/kernel/ascend/profiling/profiling.h rename to mindspore-lite/src/extendrt/delegate/ascend_acl/profiling.h index 7f1d7a58d10a1ac32a52c8bd33fb8754b56e694f..b18226c24e0a44eaced6c1d629f347a89bfe00fc 100644 --- a/mindspore-lite/src/extendrt/kernel/ascend/profiling/profiling.h +++ b/mindspore-lite/src/extendrt/delegate/ascend_acl/profiling.h @@ -14,15 +14,15 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_PROFILING_ASCEND_PROFILING_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_PROFILING_ASCEND_PROFILING_H_ +#ifndef delegate_MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_PROFILING_ASCEND_PROFILING_H_ +#define delegate_MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_PROFILING_ASCEND_PROFILING_H_ #include #include #include "acl/acl_prof.h" -namespace mindspore::kernel { -namespace acl { +namespace mindspore { + class Profiling { public: Profiling() = default; @@ -41,6 +41,6 @@ class Profiling { aclprofAicoreMetrics aic_metrics_{ACL_AICORE_PIPE_UTILIZATION}; nlohmann::json profiling_json_; }; -} // namespace acl -} // namespace mindspore::kernel + +} // namespace mindspore #endif // MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_PROFILING_ASCEND_PROFILING_H_ diff --git a/mindspore-lite/src/extendrt/delegate/ascend_ge/CMakeLists.txt b/mindspore-lite/src/extendrt/delegate/ascend_ge/CMakeLists.txt index 1c3a1201e593e0e4b595153fdc7e7375ff3f1c17..90c1ae2295fe1d796262bdf0b26158c3272763bb 100644 --- a/mindspore-lite/src/extendrt/delegate/ascend_ge/CMakeLists.txt +++ b/mindspore-lite/src/extendrt/delegate/ascend_ge/CMakeLists.txt @@ -5,26 +5,16 @@ file(STRINGS "${TOP_DIR}/version.txt" MSVERSION) add_definitions(-DMSVERSION=\"${MSVERSION}\") add_compile_definitions(ENABLE_SECURITY) -if(MSLITE_ENABLE_CONVERTER AND MSLITE_ENABLE_GRAPH_KERNEL) - add_compile_definitions(MSLITE_ENABLE_GRAPH_KERNEL) -endif() - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ORIGIN/") -#link_directories(${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - file(GLOB GE_EXECUTOR_SRC - ${CCSRC_DIR}/utils/ms_device_shape_transfer.cc - ${CCSRC_DIR}/utils/config_manager.cc - ${CMAKE_CURRENT_SOURCE_DIR}/*.cc - ${TOP_DIR}/mindspore-lite/tools/converter/adapter/acl/mapper/*.cc - ${TOP_DIR}/mindspore-lite/tools/converter/adapter/acl/common/utils.cc - ) + ${CMAKE_CURRENT_SOURCE_DIR}/*.cc + ${TOP_DIR}/mindspore-lite/tools/converter/adapter/acl/mapper/*.cc + ${TOP_DIR}/mindspore-lite/tools/converter/adapter/acl/common/utils.cc + ) set_property(SOURCE ${GE_EXECUTOR_SRC} PROPERTY COMPILE_DEFINITIONS LOG_HDR_FILE_REL_PATH="mindspore-lite/../mindspore/mindspore/core/include/utils/log_adapter.h" SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) - - add_library(ascend_ge_plugin SHARED ${GE_EXECUTOR_SRC}) find_library(ge_graph libgraph.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) @@ -38,13 +28,13 @@ find_library(libcompress libcompress.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOO find_library(libopskernel libopskernel.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) find_library(libaicore_utils libaicore_utils.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) find_library(libaicpu_engine_common libaicpu_engine_common.so ${ASCEND_CANN_RUNTIME_PATH} - ${ASCEND_TOOLKIT_RUNTIME_PATH}) + ${ASCEND_TOOLKIT_RUNTIME_PATH}) find_library(ge_runner libge_runner.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) target_link_libraries(ascend_ge_plugin ${ge_graph} ${ge_compiler} ${acl_retr} ${acl_cblas} ${acl_dvpp} - ${acl_runtime} ${libplatform} ${libcompress} ${libopskernel} ${libaicore_utils} - ${libaicpu_engine_common} ${acl} ${ge_runner} mindspore_converter mindspore_core mindspore_ops - mindspore_graph_ir) + ${acl_runtime} ${libplatform} ${libcompress} ${libopskernel} ${libaicore_utils} + ${libaicpu_engine_common} ${acl} ${ge_runner} mindspore_converter mindspore_core mindspore_ops + mindspore_graph_ir) target_link_libraries(ascend_ge_plugin mindspore-extendrt) set_target_properties(ascend_ge_plugin PROPERTIES INSTALL_RPATH "$ORIGIN") diff --git a/mindspore-lite/src/extendrt/delegate/ascend_ge/ge_device_context.cc b/mindspore-lite/src/extendrt/delegate/ascend_ge/ge_device_context.cc index e14a8c952131559ed855246f4e5270b007764fe6..44c671068852b4a936da18d325585128a3a32727 100644 --- a/mindspore-lite/src/extendrt/delegate/ascend_ge/ge_device_context.cc +++ b/mindspore-lite/src/extendrt/delegate/ascend_ge/ge_device_context.cc @@ -28,7 +28,6 @@ #include "common/config_infos.h" #include "common/common.h" #include "extendrt/delegate/comm_group_info.h" -#include "extendrt/delegate/ascend_ge/ge_utils.h" #include "backend/common/session/executor.h" #include "plugin/res_manager/ascend/symbol_interface/acl_rt_symbol.h" #include "plugin/res_manager/ascend/symbol_interface/symbol_utils.h" @@ -215,15 +214,10 @@ Status GeDeviceContext::Initialize(const std::shared_ptr &context, cons MS_LOG(ERROR) << "Failed to Init GE"; return status; } - auto ascend_soc_version = GetSocVersion(); - if (ascend_soc_version != "Ascend310") { - status = InitHccl(context, config_info); - if (status != kSuccess) { - MS_LOG(ERROR) << "Failed to Init HCCL"; - return status; - } - } else { - MS_LOG(INFO) << "Ascend310 does not support hccl now, no need to init."; + status = InitHccl(context, config_info); + if (status != kSuccess) { + MS_LOG(ERROR) << "Failed to Init HCCL"; + return status; } return kSuccess; } diff --git a/mindspore-lite/src/extendrt/delegate/ascend_ge/ge_graph_executor.cc b/mindspore-lite/src/extendrt/delegate/ascend_ge/ge_graph_executor.cc index a188aded17e52239b10bb9bdba7dd3104dc4d265..f1f52a8d293e344d8cd4547251431294e970947b 100644 --- a/mindspore-lite/src/extendrt/delegate/ascend_ge/ge_graph_executor.cc +++ b/mindspore-lite/src/extendrt/delegate/ascend_ge/ge_graph_executor.cc @@ -18,35 +18,20 @@ #include #include #include -#include "mindspore/ops/op_def/framework_ops.h" #include "extendrt/delegate/factory.h" #include "include/common/utils/scoped_long_running.h" -#include "include/api/context.h" -#include "include/api/status.h" -#include "backend/ge_backend/graph_ir/utils.h" -#include "common/device_type.h" #include "include/common/utils/ms_device_shape_transfer.h" #include "src/common/common.h" #include "src/common/file_utils.h" #include "cxx_api/acl_utils.h" -#include "utils/ms_utils_secure.h" #include "tools/common/graph_util.h" #include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/graph/remove_load_pass.h" #include "src/extendrt/utils/func_graph_utils.h" -#include "plugin/res_manager/ascend/op_adapter/transform_util.h" -#include "flow_graph/data_flow.h" -#ifdef MSLITE_ENABLE_GRAPH_KERNEL -#include "tools/graph_kernel/converter/graph_kernel_optimization.h" -#endif -#include "src/extendrt/utils/tensor_utils.h" #include "external/ge_common/ge_api_error_codes.h" #include "src/extendrt/delegate/ascend_ge/aoe_api_tune_process.h" -#include "extendrt/session/lite_graph_executor.h" #include "extendrt/delegate/ascend_ge/ge_utils.h" #include "extendrt/delegate/ascend_ge/ge_dynamic_utils.h" -#include "mindspore/ops/op_def/lite_ops.h" -#include "mindspore/ops/op_def/nn_optimizer_ops.h" #include "tools/common/string_util.h" #include "src/extendrt/cxx_api/file_utils.h" #include "mindspore/ops/infer/custom.h" @@ -54,10 +39,6 @@ #include "op_proto/inc/array_ops.h" #include "op_proto/inc/elewise_calculation_ops.h" #include "tools/optimizer/graph/attr_to_args_pass.h" -#include "mindspore/ops/op_def/nn_ops.h" -#include -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_p.h" namespace mindspore { namespace { @@ -65,37 +46,16 @@ constexpr auto kProviderGe = "ge"; constexpr auto kDump = "dump"; constexpr auto kDumpMode = "dump_mode"; constexpr auto kProfiling = "profiler"; -constexpr auto kDataFlowGraphType = "data_flow"; -constexpr auto kCustomInputSize = 2; -constexpr auto kGraphKernelParam = "graph_kernel_param"; constexpr auto kUnkonwnSessionId = -1; constexpr auto kRefModeNone = "none"; constexpr auto kRefModeVariable = "variable"; constexpr auto kRefModeAll = "all"; -constexpr float kNumMicrosecondToMillisecond = 1000.0; constexpr size_t kAlignRefData = 32; size_t ALIGN_UP_REF_DATA(size_t size) { return ((size + kMemAlignSize + kAlignRefData - 1) / kMemAlignSize) * kMemAlignSize; } -#ifdef MSLITE_ENABLE_GRAPH_KERNEL -std::shared_ptr ParseGraphKernelConfigs(const ConfigInfos &maps) { - if (maps.find(kGraphKernelParam) == maps.end()) { - return nullptr; - } - auto param = std::make_shared(); - const auto &gk_map = maps.at(kGraphKernelParam); - std::stringstream oss; - for (const auto &item : gk_map) { - oss << "--" << item.first << "=" << item.second << " "; - } - param->device = GetSocVersion(); - param->graphKernelParam.graph_kernel_flags = oss.str(); - return param; -} -#endif - backend::ge_backend::DfGraphPtr GenExampleGraph(const std::string &name) { MS_LOG(INFO) << "Gen fake graph name is " << name; auto graph = std::make_shared(name); @@ -159,40 +119,9 @@ bool UpdateOmCacheIdxFile(const std::string &idx_file_name) { std::atomic_uint32_t GeGraphExecutor::global_graph_idx_ = 0; uint32_t GeGraphExecutor::GetNextGraphIdx() { return global_graph_idx_++; } -backend::ge_backend::DfGraphPtr GetDataFlowGraph(const FuncGraphPtr &anf_graph, - const std::map &ge_options) { - MS_EXCEPTION_IF_NULL(anf_graph); - auto return_node = anf_graph->get_return(); - MS_EXCEPTION_IF_NULL(return_node); - auto nodes = anf_graph->TopoSort(return_node); - auto itr = std::find_if(nodes.begin(), nodes.end(), [&](const AnfNodePtr &node) { - return node != nullptr && node->isa() && opt::CheckPrimitiveType(node, prim::kPrimCustom); - }); - if (itr == nodes.end()) { - MS_LOG(ERROR) << "The dataflow graph is invalid."; - return nullptr; - } - auto custom_cnode = (*itr)->cast(); - MS_EXCEPTION_IF_NULL(custom_cnode); - if (custom_cnode->size() != kCustomInputSize) { - MS_LOG(ERROR) << "The input of dataflow custom node is not 2."; - return nullptr; - } - auto tensor = FuncGraphUtils::GetConstNodeValue(custom_cnode->input(1)); - MS_EXCEPTION_IF_NULL(tensor); - auto data = tensor->data_c(); - MS_EXCEPTION_IF_NULL(data); - auto flow_graph = reinterpret_cast(data); - MS_EXCEPTION_IF_NULL(flow_graph); - auto df_graph = std::make_shared(flow_graph->ToGeGraph()); - return df_graph; -} GeGraphExecutor::~GeGraphExecutor() { if (ge_session_) { - for (auto graph_id : init_graph_id_list_) { - ge_session_->RemoveGraph(graph_id); - } for (auto graph_id : compute_graph_id_list_) { ge_session_->RemoveGraph(graph_id); auto session_context = GeSessionManager::GetGeSessionContext(session_id_); @@ -202,8 +131,6 @@ GeGraphExecutor::~GeGraphExecutor() { } ge_session_ = nullptr; GeSessionManager::TryReleaseGeSessionContext(session_id_); - enable_update_weight_ = false; - update_weight_ptr_ = nullptr; } } @@ -418,11 +345,11 @@ bool GeGraphExecutor::InitMaxShapeParam() { return true; } -bool GeGraphExecutor::InitRealShapeParam(const std::vector &inputs) { +bool GeGraphExecutor::InitRealShapeParam(const std::vector &inputs) { if (!dyn_kv_cache_info_.dynamic_kv_cache) { return true; } - auto input_0_shape = inputs[0].shape_c(); + auto input_0_shape = inputs[0].Shape(); if (input_0_shape.size() != kShape2dDims) { MS_LOG(ERROR) << "Expected input 0 shape to be [bs, seq_length], but got " << input_0_shape; return false; @@ -486,21 +413,6 @@ bool GeGraphExecutor::Init() { cache_mode_ = model_cache_mode; MS_LOG(INFO) << "Set set model cache mode " << model_cache_mode; } - std::string variable_weights_list; - (void)GetConfigOption(lite::kAscendContextSection, "variable_weights_list", &variable_weights_list); - if (!variable_weights_list.empty()) { - update_weight_ptr_ = std::make_shared(); - if (update_weight_ptr_ == nullptr) { - MS_LOG(ERROR) << "init update weight ptr failed."; - return false; - } - if (!update_weight_ptr_->ParseUpdateWeightConfig(variable_weights_list)) { - MS_LOG(ERROR) << "ParseUpdateWeightConfig failed."; - update_weight_ptr_ = nullptr; - return false; - } - enable_update_weight_ = true; - } return true; } @@ -845,7 +757,7 @@ bool GeGraphExecutor::InitRefDataList(const std::vectorshape_c(); - ref_data_info.dtype = tensor->data_type(); + ref_data_info.dtype = static_cast(tensor->data_type_c()); ref_data_info.host_data = item.second; MS_LOG(INFO) << "Init ref data info[" << ref_data_infos_.size() << "] :" << ref_data_info.name << ", dtype:" << ref_data_info.dtype << ", shape:" << ref_data_info.shape; @@ -908,11 +820,12 @@ bool GeGraphExecutor::InitRefDataDeviceTensor() { for (size_t i = 0; i < ref_data_infos_.size(); i++) { auto &item = ref_data_infos_[i]; auto tensor = item.host_data; - item.size = tensor->Size(); + item.size = tensor->DataSize(); item.host_data = nullptr; // release host memory ShapeVector ref_data_shape = tensor->shape_c(); SetRefShape(&ref_data_shape, true, item.name); - auto desc = device::ascend::TransformUtil::GetGeTensorDesc(ref_data_shape, tensor->data_type(), kOpFormat_NCHW); + auto desc = device::ascend::TransformUtil::GetGeTensorDesc( + ref_data_shape, static_cast(tensor->data_type_c()), kOpFormat_NCHW); if (desc == nullptr) { MS_LOG(ERROR) << "Failed to get Tensor Desc"; return false; @@ -939,7 +852,7 @@ bool GeGraphExecutor::InitRefDataDeviceTensor() { } } else { item.offset = ref_data_total_size; - ref_data_total_size += ALIGN_UP_REF_DATA(tensor->Size()); + ref_data_total_size += ALIGN_UP_REF_DATA(tensor->DataSize()); new_param_tensor_map[item.name] = tensor; } } @@ -955,7 +868,7 @@ bool GeGraphExecutor::InitRefDataDeviceTensor() { } auto &tensor_val = it->second; auto dst_addr = device_memory + item.offset; - if (!memory_manager_->MemcpyHost2Device(dst_addr, item.size, tensor_val->data_c(), tensor_val->Size())) { + if (!memory_manager_->MemcpyHost2Device(dst_addr, item.size, tensor_val->data_c(), tensor_val->DataSize())) { MS_LOG(ERROR) << "Failed to memory copy host data to device"; return false; } @@ -1079,10 +992,6 @@ bool GeGraphExecutor::InitRefDataContext(const FuncGraphPtr &func_graph, } backend::ge_backend::DfGraphPtr GeGraphExecutor::CreateFakeGraph(const std::map &ge_options) { - if (enable_update_weight_) { - MS_LOG(INFO) << "Enable update weight, skip create small ge graph"; - return nullptr; - } if (build_cache_dir_.empty()) { MS_LOG(INFO) << "Option model_cache_mode " << cache_mode_ << " is not mem_opt and not load offline model or " << kGeGraphCompilerCacheDir << " is empty, skip create small ge graph"; @@ -1107,69 +1016,9 @@ backend::ge_backend::DfGraphPtr GeGraphExecutor::CreateFakeGraph(const std::map< return df_graph; } -bool GeGraphExecutor::UpdateWeights(const std::vector>> &weights) { - auto time1 = lite::GetTimeUs(); - if (init_graph_id_list_.empty()) { - MS_LOG(ERROR) << "init graph id list is empty."; - return false; - } - uint32_t init_graph_id = init_graph_id_list_[0]; - MS_LOG(INFO) << "init_graph_id: " << init_graph_id; - if (update_weight_ptr_ == nullptr) { - MS_LOG(ERROR) << "please init update weight class by build model."; - return false; - } - std::vector>> new_weight_tensors; - auto ret = update_weight_ptr_->UpdateConstantTensorData(weights, &new_weight_tensors); - if (!ret) { - MS_LOG(ERROR) << "update weight failed."; - return false; - } - MS_LOG(DEBUG) << "ExecInitGraph start."; - auto time2 = lite::GetTimeUs(); - MS_LOG(INFO) << "update weight prepare time: " << (time2 - time1) / kNumMicrosecondToMillisecond << " ms"; - - // cppcheck-suppress cppcheckError - for (size_t i = 0; i < new_weight_tensors.size(); i++) { - std::vector<::ge::Tensor> ge_inputs; - // cppcheck-suppress cppcheckError - for (size_t j = 0; j < new_weight_tensors[i].size(); j++) { - auto &input = new_weight_tensors[i][j]; - auto ge_tensor = device::ascend::TransformUtil::ConvertTensor(input, kOpFormat_NCHW, false); - if (ge_tensor == nullptr) { - MS_LOG(ERROR) << "Failed to converter input " << i << " ME Tensor to GE Tensor"; - return false; - } - ge_inputs.emplace_back(*ge_tensor); - } - std::vector<::ge::Tensor> ge_outputs; - auto ge_status = ge_session_->RunGraph(init_graph_id, ge_inputs, ge_outputs); - if (ge_status != ge::GRAPH_SUCCESS) { - MS_LOG(ERROR) << "Exec init graph failed, graph id " << init_graph_id; - return false; - } - } - auto time3 = lite::GetTimeUs(); - MS_LOG(INFO) << "update weight run init graph time: " << (time3 - time2) / kNumMicrosecondToMillisecond << " ms"; - return true; -} - backend::ge_backend::DfGraphPtr GeGraphExecutor::CreateGeGraphOnline( const FuncGraphPtr &anf_graph, std::map *ge_options_ptr) { MS_CHECK_TRUE_RET(ge_options_ptr != nullptr, nullptr); - std::vector extra_variables_names = {}; - if (enable_update_weight_ && update_weight_ptr_ != nullptr) { - auto ret = update_weight_ptr_->CreateAddOpNodeForGraph(anf_graph); - if (!ret) { - MS_LOG(ERROR) << "CreateAddOpNodeForGraph failed."; - return nullptr; - } - extra_variables_names = update_weight_ptr_->GetVariableParamsName(anf_graph); - if (extra_variables_names.empty()) { - MS_LOG(WARNING) << "GetVariableParamsName failed."; - return nullptr; - } - } backend::ge_backend::TensorOrderMap params_vals; GetParams(anf_graph, ¶ms_vals); backend::ge_backend::SetDynRefDataFunc dyn_ref_data_func = nullptr; @@ -1180,9 +1029,8 @@ backend::ge_backend::DfGraphPtr GeGraphExecutor::CreateGeGraphOnline( }; } - MS_LOG(INFO) << "extra_variables_names size: " << extra_variables_names.size(); - auto converter = std::make_shared(anf_graph, "", ref_mode_flag_, - extra_variables_names, dyn_ref_data_func); + auto converter = std::make_shared(anf_graph, "", ref_mode_flag_); + // TODO(yefeng): backend::ge_backend::BuildGraph backend::ge_backend::BuildGraph(graph_name_, converter, params_vals); auto err_code = backend::ge_backend::ErrCode(converter); if (err_code != 0) { @@ -1197,24 +1045,13 @@ backend::ge_backend::DfGraphPtr GeGraphExecutor::CreateGeGraphOnline( MS_LOG(ERROR) << "Failed to add init graph, graph name " << anf_graph->ToString(); return nullptr; } - if (enable_update_weight_ && update_weight_ptr_ != nullptr) { - init_graph_id_list_.push_back(init_graph_id); - } auto init_data_names = converter->GetInitDataNames(); - if (enable_update_weight_ && update_weight_ptr_ != nullptr) { - if (!update_weight_ptr_->SetInitDataNames(init_data_names)) { - MS_LOG(ERROR) << "set init data name failed."; - return nullptr; - } - } // copy init weight to device if (!RunGeInitGraph(init_graph_id, init_data_names, params_vals)) { MS_LOG(ERROR) << "Failed to run init graph for " << anf_graph->ToString(); return nullptr; } - if (!enable_update_weight_) { - ge_session_->RemoveGraph(init_graph_id); - } + ge_session_->RemoveGraph(init_graph_id); } else { MS_LOG(INFO) << "There is no init graph for graph " << anf_graph->ToString(); } @@ -1302,25 +1139,6 @@ backend::ge_backend::DfGraphPtr GeGraphExecutor::CompileGraphCommon( MS_LOG(ERROR) << "Input param ge_options_ptr is nullptr."; return nullptr; } - -#ifdef MSLITE_ENABLE_GRAPH_KERNEL - auto param = ParseGraphKernelConfigs(config_infos_); - if (param != nullptr) { - auto rank_id = common::GetEnv("RANK_ID"); - if (rank_id.empty()) { - auto ascend_device_info = GeUtils::GetAscendDeviceInfo(context_); - if (ascend_device_info != nullptr) { - auto rank_id_value = ascend_device_info->GetRankID(); - common::SetEnv("RANK_ID", std::to_string(rank_id_value).c_str()); - } - } - if (GraphKernelOptimize(anf_graph, param) != lite::RET_OK) { - MS_LOG(ERROR) << "Run graphkernel optimization failed."; - return nullptr; - } - } -#endif - auto remove_load_pass = std::make_shared(); remove_load_pass->Run(anf_graph); @@ -1344,13 +1162,7 @@ backend::ge_backend::DfGraphPtr GeGraphExecutor::CompileGraphCommon( } backend::ge_backend::DfGraphPtr df_graph = nullptr; - auto func_type = anf_graph->get_attr(kAttrFuncType); - is_data_flow_graph_ = func_type != nullptr && GetValue(func_type) == kDataFlowGraphType; - if (!is_data_flow_graph_) { - df_graph = CreateGeGraphOnline(anf_graph, ge_options_ptr); - } else { - df_graph = GetDataFlowGraph(anf_graph, *ge_options_ptr); - } + df_graph = CreateGeGraphOnline(anf_graph, ge_options_ptr); return df_graph; } @@ -1376,7 +1188,7 @@ bool GeGraphExecutor::CompileGraph(const FuncGraphPtr &anf_graph, const std::map return false; } } - std::vector orig_output; + std::vector orig_output; std::vector output_names; FuncGraphUtils::GetFuncGraphOutputsInfo(anf_graph, &orig_output, &output_names); original_graph_outputs_[*graph_id] = orig_output; @@ -1390,7 +1202,7 @@ bool GeGraphExecutor::GetOneRealInputs(const FuncGraphPtr &anf_graph, std::vecto MS_LOG(ERROR) << "Failed to get one real input shape"; return false; } - std::vector inputs; + std::vector inputs; std::vector input_names; FuncGraphUtils::GetFuncGraphInputsInfo(anf_graph, &inputs, &input_names); if (!input_shapes_configs.empty() && input_shapes_configs.size() != inputs.size()) { @@ -1411,13 +1223,13 @@ bool GeGraphExecutor::GetOneRealInputs(const FuncGraphPtr &anf_graph, std::vecto MS_LOG(ERROR) << "Cannot find input " << input_name << " in input_shape " << input_shape_str; return false; } - input = std::make_shared(input->data_type(), it->second); - } else if (GeDynamicUtils::IsDynamicInputShapes({input->shape_c()})) { - MS_LOG(ERROR) << "Input " << i << " is dynamic shape " << input->shape_c() + input->SetShape(it->second); + } else if (GeDynamicUtils::IsDynamicInputShapes({input->Shape()})) { + MS_LOG(ERROR) << "Input " << i << " is dynamic shape " << input->Shape() << ", but there is no input shape specified in AscendDeviceInfo or config file"; return false; } - MS_LOG(INFO) << "Input " << i << " shape " << input->shape_c() << ", datatype " << input->data_type(); + MS_LOG(INFO) << "Input " << i << " shape " << input->Shape() << ", datatype " << input->DataType(); auto ge_tensor = device::ascend::TransformUtil::ConvertTensor(input, kOpFormat_NCHW); if (ge_tensor == nullptr) { MS_LOG(ERROR) << "Failed to converter input " << i << " ME Tensor to GE Tensor"; @@ -1524,24 +1336,7 @@ bool GeGraphExecutor::RunGeGraphAsync(uint32_t graph_id, const std::vector<::ge: return is_finished; } -bool GeGraphExecutor::RunDataFlowGraphAsync(uint32_t graph_id, const std::vector<::ge::Tensor> &inputs, - std::vector<::ge::Tensor> *outputs) { - ge::DataFlowInfo data_flow_info; - int time_out = 3000; // set the timeout to 3000s. - auto ret = ge_session_->FeedDataFlowGraph(graph_id, inputs, data_flow_info, time_out); - if (ret != ge::SUCCESS) { - MS_LOG(ERROR) << "Feed input data failed."; - return false; - } - ret = ge_session_->FetchDataFlowGraph(graph_id, *outputs, data_flow_info, time_out); - if (ret != ge::SUCCESS) { - MS_LOG(ERROR) << "Fetch output data failed."; - return false; - } - return true; -} - -bool GeGraphExecutor::InitInputDataTensor(const std::vector &inputs, +bool GeGraphExecutor::InitInputDataTensor(const std::vector &inputs, std::vector<::ge::Tensor> *ge_inputs, std::vector<::ge::Tensor> *ge_outputs) { if (inputs_buffer_infos_.size() != inputs.size()) { MS_LOG(ERROR) << "Input data info size " << inputs_buffer_infos_.size() << " != inputs size " << inputs.size(); @@ -1553,20 +1348,21 @@ bool GeGraphExecutor::InitInputDataTensor(const std::vector &inp } for (size_t i = 0; i < inputs.size(); i++) { auto &input = inputs[i]; - MS_LOG(INFO) << "Input " << i << " shape " << tensor::ShapeToString(input.shape_c()) << ", datatype " - << input.data_type(); - auto tensor_size = input.Size(); + MS_LOG(INFO) << "Input " << i << " shape " << tensor::ShapeToString(input.Shape()) << ", datatype " + << input.DataType(); + auto tensor_size = input.DataSize(); auto &input_info = inputs_buffer_infos_[i]; if (input_info.max_size < tensor_size) { MS_LOG(ERROR) << "Input " << i << " data size invalid, graph size " << input_info.max_size << ", given size " << tensor_size; return false; } - if (!memory_manager_->MemcpyHost2Device(input_info.device_addr, input_info.max_size, input.data_c(), tensor_size)) { + if (!memory_manager_->MemcpyHost2Device(input_info.device_addr, input_info.max_size, input.Data().get(), + tensor_size)) { return false; } - SetGeTensorShape(&input_info.ge_tensor, input.shape_c()); + SetGeTensorShape(&input_info.ge_tensor, input.Shape()); ge_inputs->push_back(input_info.ge_tensor); } for (auto &item : ref_data_infos_) { @@ -1628,8 +1424,8 @@ bool GeGraphExecutor::BuildGraphRefMode(const FuncGraphPtr &anf_graph, uint32_t return true; } -bool GeGraphExecutor::RunGraphRefMode(uint32_t graph_id, const std::vector &inputs, - std::vector *outputs) { +bool GeGraphExecutor::RunGraphRefMode(uint32_t graph_id, const std::vector &inputs, + std::vector *outputs) { MS_LOG(INFO) << "RunGraphRefMode begin"; std::vector<::ge::Tensor> ge_inputs; std::vector<::ge::Tensor> ge_outputs; @@ -1653,7 +1449,7 @@ bool GeGraphExecutor::RunGraphRefMode(uint32_t graph_id, const std::vector *outputs, +bool GeGraphExecutor::SyncDeviceOutputsToHost(std::vector *outputs, std::vector<::ge::Tensor> *ge_outputs) { UpdateOutputShapeInfo(ge_outputs); @@ -1667,18 +1463,19 @@ bool GeGraphExecutor::SyncDeviceOutputsToHost(std::vector *outpu for (size_t i = 0; i < output_size; ++i) { auto &output_info = outputs_buffer_infos_[i]; auto &output = (*outputs)[i]; - if (output.Size() < output_info.max_size) { - MS_LOG(EXCEPTION) << "Output node " << i << "'s mem size " << output.Size() + if (output.DataSize() < output_info.max_size) { + MS_LOG(EXCEPTION) << "Output node " << i << "'s mem size " << output.DataSize() << " is less than actual output size " << output_info.max_size; } - if ((*outputs)[i].data_c() == nullptr) { + if ((*outputs)[i].Data() == nullptr) { MS_LOG(ERROR) << "Output data ptr is nullptr."; return false; } - auto mem_ret = memory_manager_->MemcpyDevice2Host(reinterpret_cast(output.data_c()), output.Size(), - output_info.device_addr, output_info.max_size); + auto mem_ret = + memory_manager_->MemcpyDevice2Host(reinterpret_cast(output.MutableData()), output.DataSize(), + output_info.device_addr, output_info.max_size); if (!mem_ret) { - MS_LOG(ERROR) << "Failed to copy output data, dst size: " << output.Size() + MS_LOG(ERROR) << "Failed to copy output data, dst size: " << output.DataSize() << ", src size: " << output_info.max_size; return false; } @@ -1688,12 +1485,13 @@ bool GeGraphExecutor::SyncDeviceOutputsToHost(std::vector *outpu } else { for (size_t i = 0; i < output_size; i++) { auto &output_info = outputs_buffer_infos_[i]; - tensor::Tensor ms_tensor(output_info.dtype, output_info.shape); + auto ms_tensor = + *MSTensor::CreateTensor("", static_cast(output_info.dtype), output_info.shape, nullptr, 0); auto mem_ret = - memory_manager_->MemcpyDevice2Host(reinterpret_cast(ms_tensor.data_c()), ms_tensor.Size(), + memory_manager_->MemcpyDevice2Host(reinterpret_cast(ms_tensor.MutableData()), ms_tensor.DataSize(), output_info.device_addr, output_info.max_size); if (!mem_ret) { - MS_LOG(ERROR) << "Failed to copy output data, dst size: " << ms_tensor.Size() + MS_LOG(ERROR) << "Failed to copy output data, dst size: " << ms_tensor.DataSize() << ", src size: " << output_info.max_size; return false; } @@ -1737,40 +1535,41 @@ bool GeGraphExecutor::RunGraphWithStreamAsync(uint32_t graph_id, void *stream, c return true; } -bool GeGraphExecutor::RunGraph(uint32_t graph_id, const std::vector &inputs, - std::vector *outputs, +bool GeGraphExecutor::RunGraph(uint32_t graph_id, const std::vector &inputs, std::vector *outputs, const std::map & /* compile_options */) { if (outputs == nullptr) { MS_LOG(ERROR) << " Input param is nullptr."; return false; } - MS_LOG(INFO) << "Run ge graph [" << graph_id << "] with " << inputs.size() << " inputs"; + + MS_LOG(INFO) << "Run ge graph [" << graph_id << "] with " << inputs.size() << " ms_tensor_inputs"; for (size_t i = 0; i < inputs.size(); i++) { - auto &input = inputs[i]; - MS_LOG(INFO) << "Input " << i << " shape " << input.shape_c() << ", datatype " << input.data_type(); + auto &ms_tensor_input = inputs[i]; + ms_tensor_input.DataType(); + MS_LOG(INFO) << "Input " << i << " shape " << ms_tensor_input.Shape() << ", datatype " + << ms_tensor_input.DataType(); } - if (ref_mode_flag_ != backend::ge_backend::RefModeFlag::kRefModeNone) { + MS_LOG(ERROR) << "RunGraphRefMode"; return RunGraphRefMode(graph_id, inputs, outputs); } std::vector<::ge::Tensor> ge_inputs; for (size_t i = 0; i < inputs.size(); i++) { - auto &input = inputs[i]; + auto &ms_tensor_input = inputs[i]; auto ge_tensor = - device::ascend::TransformUtil::ConvertTensor(std::make_shared(input), kOpFormat_NCHW, false); + device::ascend::TransformUtil::ConvertTensor(std::make_shared(ms_tensor_input), kOpFormat_NCHW, false); if (ge_tensor == nullptr) { MS_LOG(ERROR) << "Failed to converter input " << i << " ME Tensor to GE Tensor"; return false; } ge_inputs.emplace_back(*ge_tensor); } - for (auto &item : ref_data_infos_) { - ge_inputs.emplace_back(item.ge_tensor); - } + auto ref_data_infos = ref_data_infos_; + std::transform(ref_data_infos.begin(), ref_data_infos.end(), std::back_inserter(ge_inputs), + [](const auto &item) { return item.ge_tensor; }); std::vector<::ge::Tensor> ge_outputs; auto time_start = std::chrono::system_clock::now(); - auto ret = !is_data_flow_graph_ ? RunGeGraphAsync(graph_id, ge_inputs, &ge_outputs) - : RunDataFlowGraphAsync(graph_id, ge_inputs, &ge_outputs); + auto ret = RunGeGraphAsync(graph_id, ge_inputs, &ge_outputs); if (!ret) { MS_LOG(ERROR) << "Exec compute graph failed, graph id " << graph_id; return false; @@ -1789,18 +1588,18 @@ bool GeGraphExecutor::RunGraph(uint32_t graph_id, const std::vectorsize(); ++i) { const auto &tensor = ge_outputs[i]; auto &output = (*outputs)[i]; - if (output.Size() < LongToSize(UlongToLong(tensor.GetSize()))) { - MS_LOG(EXCEPTION) << "Output node " << i << "'s mem size " << output.Size() + if (output.DataSize() < LongToSize(UlongToLong(tensor.GetSize()))) { + MS_LOG(EXCEPTION) << "Output node " << i << "'s mem size " << output.DataSize() << " is less than actual output size " << tensor.GetSize(); } - if ((*outputs)[i].data_c() == nullptr) { + if ((*outputs)[i].Data() == nullptr) { MS_LOG(ERROR) << "Output data ptr is nullptr."; return false; } - auto mem_ret = common::huge_memcpy(reinterpret_cast(output.data_c()), output.Size(), tensor.GetData(), - tensor.GetSize()); + auto mem_ret = common::huge_memcpy(static_cast(output.MutableData()), output.DataSize(), + tensor.GetData(), tensor.GetSize()); if (mem_ret != EOK) { - MS_LOG(ERROR) << "Failed to copy output data, dst size: " << output.Size() + MS_LOG(ERROR) << "Failed to copy output data, dst size: " << output.DataSize() << ", src size: " << tensor.GetSize(); return false; } @@ -1813,8 +1612,8 @@ bool GeGraphExecutor::RunGraph(uint32_t graph_id, const std::vectorshape_c()) << ", datatype " - << ms_tensor->data_type(); + MS_LOG(INFO) << "Output " << i << " shape " << tensor::ShapeToString(ms_tensor->Shape()) << ", datatype " + << ms_tensor->DataType(); outputs->push_back(*ms_tensor); } } @@ -1824,12 +1623,12 @@ bool GeGraphExecutor::RunGraph(uint32_t graph_id, const std::vector GeGraphExecutor::GetInputInfos(uint32_t graph_id) { +std::vector GeGraphExecutor::GetInputInfos(uint32_t graph_id) { return graph_inputs_.find(graph_id) != graph_inputs_.end() ? graph_inputs_.at(graph_id) - : std::vector(); + : std::vector(); } -tensor::TensorPtr GeGraphExecutor::ConvertGeTensorNoCopy(::ge::Tensor *ge_tensor_ptr, uint32_t graph_id, size_t idx) { +MSTensorPtr GeGraphExecutor::ConvertGeTensorNoCopy(::ge::Tensor *ge_tensor_ptr, uint32_t graph_id, size_t idx) { MS_CHECK_TRUE_RET(ge_tensor_ptr != nullptr, nullptr); auto &ge_tensor = *ge_tensor_ptr; auto ge_tensor_desc = ge_tensor.GetTensorDesc(); @@ -1843,7 +1642,7 @@ tensor::TensorPtr GeGraphExecutor::ConvertGeTensorNoCopy(::ge::Tensor *ge_tensor MS_LOG(ERROR) << "Graph output index is out of range."; return nullptr; } - TypeId type_id = static_cast(original_outputs[idx]->data_type_c()); + auto type_id = static_cast(original_outputs[idx]->DataType()); if (type_id == kTypeUnknown) { MS_LOG(ERROR) << "Could not convert Ge Tensor because of unsupported data type: " << static_cast(ge_tensor_desc.GetDataType()); @@ -1874,13 +1673,13 @@ tensor::TensorPtr GeGraphExecutor::ConvertGeTensorNoCopy(::ge::Tensor *ge_tensor MS_LOG(ERROR) << "Output datatype error! Output tensor size from GE RunGraph does not match."; return nullptr; } - auto tensor_data = std::make_shared(ge_data, elem_num, ge_tensor.GetSize(), me_shape.size(), deleter); - return std::make_shared(type_id, me_shape, tensor_data); + return std::shared_ptr( + MSTensor::CreateTensor("", static_cast(type_id), me_shape, ge_data, ge_tensor.GetSize())); } -std::vector GeGraphExecutor::GetOutputInfos(uint32_t graph_id) { +std::vector GeGraphExecutor::GetOutputInfos(uint32_t graph_id) { return graph_outputs_.find(graph_id) != graph_outputs_.end() ? graph_outputs_.at(graph_id) - : std::vector(); + : std::vector(); } bool GeGraphExecutor::CreateAsCustomFuncGraph(const FuncGraphPtr &func_graph, diff --git a/mindspore-lite/src/extendrt/delegate/ascend_ge/ge_graph_executor.h b/mindspore-lite/src/extendrt/delegate/ascend_ge/ge_graph_executor.h index 920ed0af924e7293514052f0a3687cbce0bfe773..071ffc6d290cdc2834a23d38f34b1e0031b0a6ee 100644 --- a/mindspore-lite/src/extendrt/delegate/ascend_ge/ge_graph_executor.h +++ b/mindspore-lite/src/extendrt/delegate/ascend_ge/ge_graph_executor.h @@ -32,7 +32,6 @@ #include "extendrt/delegate/ascend_ge/ge_device_context.h" #include "extendrt/delegate/ascend_ge/ge_memory_manager.h" #include "extendrt/delegate/ascend_ge/ge_context_manager.h" -#include "extendrt/delegate/ascend_ge/update_weight.h" #include "src/common/common.h" namespace mindspore { @@ -90,24 +89,21 @@ class GeGraphExecutor : public LiteGraphExecutor { bool CompileGraph(const FuncGraphPtr &graph, const std::map &compile_options, uint32_t *graph_id) override; - bool RunGraph(uint32_t graph_id, const std::vector &inputs, std::vector *outputs, + bool RunGraph(uint32_t graph_id, const std::vector &inputs, std::vector *outputs, const std::map &compile_options) override; - bool Resize(uint32_t graph_id, const std::vector &inputs, + bool Resize(uint32_t graph_id, const std::vector &inputs, const std::vector &dims) override { return true; } - std::vector GetInputInfos(uint32_t graph_id) override; - std::vector GetOutputInfos(uint32_t graph_id) override; + std::vector GetInputInfos(uint32_t graph_id) override; + std::vector GetOutputInfos(uint32_t graph_id) override; bool Init(); bool AoeTuning(const FuncGraphPtr &graph); bool OfflineBuildGraph(const FuncGraphPtr &graph); - bool UpdateWeights(const std::vector>> &weights) override; private: - std::shared_ptr update_weight_ptr_ = nullptr; - bool enable_update_weight_ = false; const std::shared_ptr context_; ConfigInfos config_infos_; std::shared_ptr ge_session_ = nullptr; @@ -129,10 +125,9 @@ class GeGraphExecutor : public LiteGraphExecutor { std::string build_cache_dir_; std::string build_cache_relative_dir_; - std::map> graph_inputs_; - std::map> graph_outputs_; - std::map> original_graph_outputs_; - bool is_data_flow_graph_ = false; + std::map> graph_inputs_; + std::map> graph_outputs_; + std::map> original_graph_outputs_; DynKVCacheInfo dyn_kv_cache_info_; std::shared_ptr GetAscendDeviceInfo(); @@ -150,7 +145,7 @@ class GeGraphExecutor : public LiteGraphExecutor { uint32_t *graph_id); bool RunGeInitGraph(uint32_t init_graph_id, const std::vector &init_data_names, const backend::ge_backend::TensorOrderMap ¶ms_vals); - tensor::TensorPtr ConvertGeTensorNoCopy(::ge::Tensor *ge_tensor_ptr, uint32_t graph_id, size_t idx); + MSTensorPtr ConvertGeTensorNoCopy(::ge::Tensor *ge_tensor_ptr, uint32_t graph_id, size_t idx); bool RunGraphWithStreamAsync(uint32_t graph_id, void *stream, const std::vector &inputs, std::vector *outputs); @@ -162,14 +157,14 @@ class GeGraphExecutor : public LiteGraphExecutor { bool InitConstantFeatureDeviceMemory(uint32_t graph_id); bool InitInOutDeviceBuffer(const std::string &name, const ShapeVector &shape, TypeId dtype, InOutBufferInfo *buffer_info); - bool InitInputDataTensor(const std::vector &inputs, std::vector<::ge::Tensor> *ge_inputs, + bool InitInputDataTensor(const std::vector &inputs, std::vector<::ge::Tensor> *ge_inputs, std::vector<::ge::Tensor> *ge_outputs); bool InitMemoryContextManager(); bool BuildGraphRefMode(const FuncGraphPtr &anf_graph, uint32_t graph_id); - bool RunGraphRefMode(uint32_t graph_id, const std::vector &inputs, - std::vector *outputs); - bool SyncDeviceOutputsToHost(std::vector *outputs, std::vector<::ge::Tensor> *ge_outputs); + bool RunGraphRefMode(uint32_t graph_id, const std::vector &inputs, + std::vector *outputs); + bool SyncDeviceOutputsToHost(std::vector *outputs, std::vector<::ge::Tensor> *ge_outputs); bool UpdateInputShapeOption(const FuncGraphPtr &func_graph, const std::vector> &ref_data_tensors, @@ -179,8 +174,6 @@ class GeGraphExecutor : public LiteGraphExecutor { static uint32_t GetNextGraphIdx(); bool RunGeGraphAsync(uint32_t graph_id, const std::vector<::ge::Tensor> &inputs, std::vector<::ge::Tensor> *outputs); - bool RunDataFlowGraphAsync(uint32_t graph_id, const std::vector<::ge::Tensor> &inputs, - std::vector<::ge::Tensor> *outputs); backend::ge_backend::DfGraphPtr CompileGraphCommon(const FuncGraphPtr &graph, std::map *ge_options_ptr); @@ -204,7 +197,7 @@ class GeGraphExecutor : public LiteGraphExecutor { bool SetGeTensorShape(GeTensor *ge_tensor, ShapeVector shape); void UpdateOutputShapeInfo(std::vector<::ge::Tensor> *ge_outputs); bool InitRefModeConfig(); - bool InitRealShapeParam(const std::vector &inputs); + bool InitRealShapeParam(const std::vector &inputs); bool CheckRefDataInfo(); bool InitMaxShapeParam(); void SetRefShape(std::vector *ref_shape, bool dyn, std::string tensor_name); diff --git a/mindspore-lite/src/extendrt/delegate/ascend_ge/update_weight.cc b/mindspore-lite/src/extendrt/delegate/ascend_ge/update_weight.cc deleted file mode 100644 index ef4eb884d84947244600e3ded552dedba03ea507..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/ascend_ge/update_weight.cc +++ /dev/null @@ -1,252 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "extendrt/delegate/ascend_ge/update_weight.h" -#include -#include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "mindspore/ops/op_def/math_ops.h" -#include "tools/common/string_util.h" -#include "tools/optimizer/common/gllo_utils.h" -#include "ir/manager.h" -#include "tools/common/tensor_util.h" -#include "mindspore/ops/op_def/conv_pool_ops.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_b.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" -namespace mindspore { -namespace { -constexpr float kNumMicrosecondToMillisecond = 1000.0; -constexpr size_t kInputSize3 = 3; -constexpr size_t kConstantWeightShapeSize = 2; -constexpr size_t kConstantConvWeightShapeSize = 4; -constexpr size_t kInputIndex2 = 2; -constexpr const char *kUpdateWeightTensorNameSuffix = "_add_param"; -constexpr const char *kUpdateWeightAddNodeNameSuffix = "_add_cnode"; -constexpr std::size_t kUpdateWeightTensorNameSuffixSize = 10; -constexpr size_t kConvWeightSize = 4; -constexpr size_t kConvWeightShape0 = 0; -constexpr size_t kConvWeightShape1 = 1; -constexpr size_t kConvWeightShape2 = 2; -constexpr size_t kConvWeightShape3 = 3; -} // namespace - -bool UpdateWeight::IsMatchName(const std::string &cnode_name, const std::string ¶m_name) { - if (find(constant_cnode_name_.begin(), constant_cnode_name_.end(), cnode_name) != constant_cnode_name_.end()) { - MS_LOG(DEBUG) << "cnode name: " << cnode_name << ", param name: " << param_name; - return true; - } - return false; -} - -bool UpdateWeight::ParseUpdateWeightConfig(const std::string &names_str) { - MS_LOG(DEBUG) << "names str: " << names_str; - constant_cnode_name_ = mindspore::lite::SplitStringToVector(names_str, ','); - if (constant_cnode_name_.empty()) { - MS_LOG(ERROR) << "split name is empty, name str is: " << names_str; - return false; - } - return true; -} - -std::vector UpdateWeight::GetVariableParamsName(const FuncGraphPtr &anf_graph) { - return new_weight_param_name_; -} - -bool UpdateWeight::SetInitDataNames(const std::vector &init_data_names) { - if (init_data_names.empty()) { - MS_LOG(ERROR) << "init_data_names is empty."; - return false; - } - init_data_names_ = init_data_names; - return true; -} - -bool UpdateWeight::UpdateConstantTensorData(const std::vector>> &weights, - std::vector>> *new_weights) { - // sort by init data name. - if (new_weights == nullptr) { - MS_LOG(ERROR) << "new_weight_tensors is nullptr."; - return false; - } - auto time1 = lite::GetTimeUs(); - for (auto &weight : weights) { - std::vector> new_weight_tensors; - std::map> weights_pairs; - for (auto tensor : weight) { - MS_CHECK_TRUE_RET(tensor != nullptr, false); - weights_pairs[tensor->name()] = tensor; - } - for (auto &init_data_name : init_data_names_) { - auto size = init_data_name.size(); - if (size < kUpdateWeightTensorNameSuffixSize) { - MS_LOG(ERROR) << "can not find init data name: " << init_data_name; - return false; - } - size_t last_slash_pos = init_data_name.find_last_of('/'); - auto name = init_data_name.substr(0, last_slash_pos); - if (weights_pairs.find(name) == weights_pairs.end()) { - MS_LOG(ERROR) << "can not find init data name in user update weight tensors."; - return false; - } - auto weight_tensor = weights_pairs[name]; - weight_tensor->set_name(init_data_name); - new_weight_tensors.push_back(weight_tensor); - } - new_weights->push_back(new_weight_tensors); - } - auto time2 = lite::GetTimeUs(); - MS_LOG(INFO) << "Calculate update tensor time: " << (time2 - time1) / kNumMicrosecondToMillisecond << " ms"; - return true; -} - -ParameterPtr UpdateWeight::BuildFloatVec4DParameterNode(const FuncGraphPtr &anf_graph, ShapeVector weight_shape, - const std::string &node_name) { - if (weight_shape.size() != kConvWeightSize) { - MS_LOG(ERROR) << "weight_shape size is not 4, weight_shape size:" << weight_shape.size() << "!"; - return nullptr; - } - MS_CHECK_TRUE_RET(anf_graph != nullptr, nullptr); - auto param_node = anf_graph->add_parameter(); - MS_CHECK_TRUE_RET(param_node != nullptr, nullptr); - param_node->set_name(node_name); - auto weight_length = weight_shape[kConvWeightShape0] * weight_shape[kConvWeightShape1] * - weight_shape[kConvWeightShape2] * weight_shape[kConvWeightShape3]; - std::vector data_1d(weight_length, 0); - auto size = data_1d.size() * sizeof(float); - std::vector shape_vector = { - static_cast(weight_shape[kConvWeightShape0]), static_cast(weight_shape[kConvWeightShape1]), - static_cast(weight_shape[kConvWeightShape2]), static_cast(weight_shape[kConvWeightShape3])}; - auto tensor_info = lite::CreateTensorInfo(data_1d.data(), size, shape_vector, kNumberTypeFloat32); - if (tensor_info == nullptr) { - MS_LOG(ERROR) << "Create tensor info failed!"; - return nullptr; - } - auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info); - if (status != RET_OK) { - MS_LOG(ERROR) << "init parameter from tensor info failed!"; - return nullptr; - } - return param_node; -} - -bool JudgeNodeType(const AnfNodePtr &node) { - return !mindspore::opt::CheckPrimitiveType(node, mindspore::prim::kPrimConv2D) && - !mindspore::opt::CheckPrimitiveType(node, mindspore::prim::kPrimMatMulV2) && - !mindspore::opt::CheckPrimitiveType(node, mindspore::prim::kPrimMatMul) && - !mindspore::opt::CheckPrimitiveType(node, mindspore::prim::kPrimBatchMatMul); -} - -bool UpdateWeight::CreateAddOpNodeForGraph(const FuncGraphPtr &anf_graph) { - MS_CHECK_TRUE_RET(anf_graph != nullptr, false); - if (constant_cnode_name_.empty()) { - MS_LOG(ERROR) << "constant_cnode_name_ is empty, user not set config file for update weight!"; - return false; - } - auto node_list = TopoSort(anf_graph->get_return()); - for (auto &node : node_list) { - MS_CHECK_TRUE_RET(node != nullptr, false); - if (!utils::isa(node)) { - continue; - } - auto cnode = utils::cast(node); - MS_CHECK_TRUE_RET(cnode != nullptr, false); - size_t last_slash_pos = cnode->fullname_with_scope().find_last_of('/'); - string search_key = ""; - if (last_slash_pos != std::string::npos) { - search_key = cnode->fullname_with_scope().substr(0, last_slash_pos); - } else { - MS_LOG(INFO) << "Find last slash failed! Cnode name:" << cnode->fullname_with_scope() << "!"; - } - if (find(constant_cnode_name_.begin(), constant_cnode_name_.end(), search_key) == constant_cnode_name_.end()) { - continue; - } else if (JudgeNodeType(node)) { - continue; - } - if (cnode->size() < kInputSize3) { - MS_LOG(ERROR) << "cnode input size less " << kInputSize3; - return false; - } - auto weight = cnode->input(kInputIndex2); - MS_CHECK_TRUE_RET(weight != nullptr, false); - - // create Add node - auto add_prim = std::make_shared(); - if (add_prim == nullptr) { - MS_LOG(ERROR) << "create add prim failed."; - return false; - } - auto add_prim_c = add_prim->GetPrim(); - MS_CHECK_TRUE_RET(add_prim_c != nullptr, false); - if (!utils::isa(weight)) { - MS_LOG(ERROR) << "matmul weight is not constant, can not update weight."; - return false; - } - auto weight_param = weight->cast(); - MS_CHECK_TRUE_RET(weight_param != nullptr, false); - auto value = weight_param->default_param(); - MS_CHECK_TRUE_RET(value != nullptr, false); - auto weight_tensor = value->cast>(); - MS_CHECK_TRUE_RET(weight_tensor != nullptr, false); - auto weight_shape = weight_tensor->shape(); - AnfNodePtr add_param_node = nullptr; - if (weight_shape.size() == kConstantWeightShapeSize) { - std::vector> add_param_data(weight_shape[0], std::vector(weight_shape[1], 0)); - add_param_node = opt::BuildFloatVec2DParameterNode(anf_graph, add_param_data, - cnode->fullname_with_scope() + kUpdateWeightTensorNameSuffix); - if (add_param_node == nullptr) { - MS_LOG(ERROR) << "create param node failed!"; - return false; - } - } else if (weight_shape.size() == kConstantConvWeightShapeSize) { - add_param_node = BuildFloatVec4DParameterNode(anf_graph, weight_shape, - cnode->fullname_with_scope() + kUpdateWeightTensorNameSuffix); - if (add_param_node == nullptr) { - MS_LOG(ERROR) << "create param node failed!"; - return false; - } - } else { - MS_LOG(ERROR) << "now only support 2 dims matmul and 4 dims conv constant weight!" - << "weight_shape:" << weight_shape.size() << "node name:" << cnode->fullname_with_scope() << "!"; - return false; - } - - if (add_param_node == nullptr) { - MS_LOG(ERROR) << "create param node failed!"; - return false; - } - new_weight_param_name_.push_back(cnode->fullname_with_scope() + "_add_param"); - auto inputs = {weight, add_param_node}; - auto add_cnode = anf_graph->NewCNode(add_prim_c, inputs); - if (add_cnode == nullptr) { - MS_LOG(ERROR) << "new add node failed."; - return false; - } - add_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + kUpdateWeightAddNodeNameSuffix); - if (node->abstract() != nullptr) { - add_cnode->set_abstract(node->abstract()->Clone()); - } - auto manager = Manage(anf_graph); - (void)manager->Replace(weight, add_cnode); - } - if (new_weight_param_name_.size() != constant_cnode_name_.size()) { - MS_LOG(ERROR) << "init data name size is not equal user config file name size, new_weight_param_name_: " - << new_weight_param_name_.size() << ", constant_cnode_name_ size: " << constant_cnode_name_.size(); - } - MS_LOG(INFO) << "new_weight_param_name_ size: " << new_weight_param_name_.size() - << ", constant_cnode_name_ size: " << constant_cnode_name_.size(); - return true; -} -} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/delegate/ascend_ge/update_weight.h b/mindspore-lite/src/extendrt/delegate/ascend_ge/update_weight.h deleted file mode 100644 index d149eb6723166e7a462f32e37846243b3fec1660..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/ascend_ge/update_weight.h +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_GE_UPDATE_WEIGHTS_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_GE_UPDATE_WEIGHTS_H_ -#include -#include -#include -#include -#include "backend/ge_backend/graph_ir/utils.h" -namespace mindspore { -class UpdateWeight { - public: - UpdateWeight() = default; - ~UpdateWeight() = default; - - bool IsMatchName(const std::string &cnode_name, const std::string ¶m_name); - bool ParseUpdateWeightConfig(const std::string &config_path); - std::vector GetVariableParamsName(const FuncGraphPtr &anf_graph); - bool SetInitDataNames(const std::vector &init_data_names); - bool CreateAddOpNodeForGraph(const FuncGraphPtr &anf_graph); - bool UpdateConstantTensorData(const std::vector>> &weights, - std::vector>> *new_weights); - ParameterPtr BuildFloatVec4DParameterNode(const FuncGraphPtr &anf_graph, ShapeVector weight_shape, - const std::string &node_name); - - private: - /* note: - * cnode_name == user_config_file_name - * add_weight_name == cnode_name + "_add_param" - * - * init_data_names_ : need update weight tensor name, set by ge graph executor - * constant_cnode_name_: user_config_file_name - * new_weight_param_name_: add parameter node name - */ - std::vector new_weight_param_name_; - std::vector constant_cnode_name_; // equal matmul node name, user config file name - std::vector init_data_names_; -}; -} // namespace mindspore -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_GE_UPDATE_WEIGHTS_H_ diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/CMakeLists.txt b/mindspore-lite/src/extendrt/delegate/ascend_native/CMakeLists.txt deleted file mode 100644 index 2b4b6eee59b2000d2416780a0626f1b0314fa2f2..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/CMakeLists.txt +++ /dev/null @@ -1,11 +0,0 @@ -add_library(ascend_native_mid OBJECT - ${CMAKE_CURRENT_SOURCE_DIR}/delegate.cc - ${CMAKE_CURRENT_SOURCE_DIR}/delegate_allocator.cc - ${CMAKE_CURRENT_SOURCE_DIR}/sub_graph_helper.cc - ${CMAKE_CURRENT_SOURCE_DIR}/stub_kernel.cc - ${CMAKE_CURRENT_SOURCE_DIR}/ascend_native_plugin_impl.cc - ${CMAKE_CURRENT_SOURCE_DIR}/ops/ascend_native_composite.cc - ${CMAKE_CURRENT_SOURCE_DIR}/ops/ascend_native_stub.cc - ) - -add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ascend_native_impl) diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_add_kernel.cc b/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_add_kernel.cc deleted file mode 100644 index c57e0c1bbb9111d3d5c9bc96f09af0a0fc1c766c..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_add_kernel.cc +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "extendrt/delegate/ascend_native/ascend_native_add_kernel.h" -#include "extendrt/delegate/ascend_native/ascend_native_kernel_registry.h" -#include "extendrt/delegate/ascend_native/ascend_native_impl/add.h" -#include "infer/cxx_api/add_fusion.h" -#include "abstract/ops/primitive_infer_map.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" - -namespace mindspore::kernel { -using mindspore::ops::kNameAddFusion; - -int AscendNativeAddKernel::InferShape() { - if (out_tensors_[0]->shape().size() == 0) { - if (in_tensors_[0] != nullptr) { - out_tensors_[0]->set_shape(in_tensors_[0]->shape()); - } - } - return kSuccess; -} - -int AscendNativeAddKernel::Prepare() { return kSuccess; } - -int AscendNativeAddKernel::Run() { - MS_LOG(INFO) << "AscendNativeAddKernel::Execute"; - const std::vector &in_tensors = this->in_tensors(); - auto aBufSize = in_tensors[0]->ElementsNum(); - ascend_native::AddFp16(in_tensors[0]->device_data(), in_tensors[1]->device_data(), out_tensors()[0]->device_data(), - aBufSize, const_cast(get_stream())); - return kSuccess; -} - -int AscendNativeAddKernel::ReSize() { return kSuccess; } -REGISTER_ASCEND_NATIVE_CREATOR(kNameAddFusion, AscendNativeAddKernel) -} // namespace mindspore::kernel diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_base_kernel.h b/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_base_kernel.h deleted file mode 100644 index 83229d90bfa8c7762756e91b19719aff557b197f..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_base_kernel.h +++ /dev/null @@ -1,92 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_NATIVE_BASE_KERNEL_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_NATIVE_BASE_KERNEL_H_ - -#include -#include -#include -#include -#include "extendrt/delegate/type.h" -#include "extendrt/delegate/ascend_native/ascend_native_impl/utils.h" -#include "extendrt/kernel/base_kernel.h" -#include "ops/base_operator.h" -#include "mindspore/ops/op_def/op_name.h" - -namespace mindspore { -namespace common { -using KernelWithIndex = std::pair; -} // namespace common -struct KernelWithIndexAndTensor { - KernelWithIndexAndTensor() = default; - KernelWithIndexAndTensor(common::KernelWithIndex kernel_index, kernel::InferTensor *tensor_info) - : kernel_index(kernel_index), tensor_info(tensor_info) {} - - common::KernelWithIndex kernel_index; - kernel::InferTensor *tensor_info{nullptr}; -}; -namespace kernel { -class AscendNativeBaseKernel : public BaseKernel { - public: - AscendNativeBaseKernel() = delete; - - AscendNativeBaseKernel(const AscendNativeBaseKernel &kernel) = delete; - - AscendNativeBaseKernel(AscendNativeBaseKernel &&other) = delete; - - AscendNativeBaseKernel &operator=(const AscendNativeBaseKernel &kernel) = delete; - - AscendNativeBaseKernel &operator=(AscendNativeBaseKernel &&src) = delete; - - AscendNativeBaseKernel(InferPrimitive prim, const InferContext *ctx, const void *stream, std::string name) - : BaseKernel(prim, ctx), stream_(stream), name_(name) {} - - AscendNativeBaseKernel(const std::vector &inputs, const std::vector &outputs, - InferPrimitive prim, const InferContext *ctx, const void *stream, std::string name) - : BaseKernel(prim, inputs, outputs, ctx), stream_(stream), name_(name) {} - - template - std::shared_ptr AsOps() { - return std::make_shared(primitive_.base_operator->GetPrim()); - } - - void set_stream(const void *stream) { stream_ = stream; } - const void *get_stream() { return stream_; } - const std::string get_name() const { return name_; } - void set_name(std::string name) { name_ = name; } - bool InferShapeDone() const override { return true; } - int InferShape() override { return mindspore::lite::RET_OK; } - int PreProcess() override { return mindspore::lite::RET_OK; } - int PostProcess() override { return mindspore::lite::RET_OK; } - virtual bool IsWeightInputHanledInner() const { return false; } - virtual bool isFormatAndTypeSupport(int index, TypeId type, Format fmt) { return true; } - virtual size_t get_workspace_size() const { return 0; } - - void *get_workspace() const { return ws_ptr_; } - void set_workspace(void *ws_ptr) { ws_ptr_ = ws_ptr; } - - protected: - const void *stream_ = nullptr; - std::string name_; - FuncGraphPtr func_graph_; - - private: - void *ws_ptr_ = nullptr; -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_NATIVE_BASE_KERNEL_H_ diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_encoder_kernel.cc b/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_encoder_kernel.cc deleted file mode 100644 index 710938cf2f755e6c76fa1c50e9830c35b5bfd105..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_encoder_kernel.cc +++ /dev/null @@ -1,222 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "extendrt/delegate/ascend_native/ascend_native_encoder_kernel.h" -#include "extendrt/delegate/ascend_native/ascend_native_kernel_registry.h" -#include "infer/encoder_layer.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_e.h" - -namespace mindspore::kernel { -using mindspore::ops::kNameEncoderLayer; - -std::vector AscendNativeEncoderKernel::GetOutputDimensions() { - std::vector dims; - if (param_.is_query_) { - dims.push_back(param_.B); - dims.push_back(param_.embedding_param.get_vcobalary_size()); - } else if (param_.is_last_norm_) { - dims.push_back(param_.B * param_.MAX_N); - dims.push_back(param_.D); - } else { - dims.push_back(param_.B); - dims.push_back(param_.MAX_N); - dims.push_back(param_.D); - } - return dims; -} - -int AscendNativeEncoderKernel::InferShape() { - out_tensors_[0]->set_shape(GetOutputDimensions()); - out_tensors_[0]->set_data_type(TypeId::kNumberTypeFloat16); - return kSuccess; -} - -int AscendNativeEncoderKernel::InitEncoderParam() { - // get encoder primitive - auto encoder_op = AsOps(); - if (encoder_op == nullptr) { - MS_LOG(ERROR) << "convert to primitive encoder failed for " << get_name(); - return kLiteError; - } - int idx = 1; // idx 0 is from tensor or input_ids - if (encoder_op->get_use_past()) { - // setup k, v cache - param_.attn_param.set_v_cache(in_tensors_.at(idx++)->device_data()); - param_.attn_param.set_k_cache(in_tensors_.at(idx++)->device_data()); - } - // setup normalization 1 parameters - param_.norm1.set_eps(encoder_op->get_eps_layernorm1()); - param_.norm1.set_gamma(in_tensors_.at(idx++)->device_data()); - param_.norm1.set_beta(in_tensors_.at(idx++)->device_data()); - // setup comm param - param_.comm_param.set_rank_id(0); - param_.comm_param.set_rank_num(1); -#ifdef LITE_ASCEND_DISTRIBUTION_TBD_FLAG - uint32_t rank_id = GetDeviceInfo(context_); - param_.comm_param.set_rank_id(rank_id); - param_.comm_param.set_rank_num(Num4); -#endif - // setup attention param - param_.attn_param.set_head_number(encoder_op->get_head_num()); - param_.attn_param.set_head_size(encoder_op->get_head_size()); - param_.attn_param.set_hidden_dim(param_.attn_param.get_head_number() * param_.attn_param.get_head_size() * - param_.comm_param.get_rank_num()); - param_.H = encoder_op->get_head_num(); - param_.HS = encoder_op->get_head_size(); - param_.D = param_.H * param_.HS * param_.comm_param.get_rank_num(); - param_.is_query_ = encoder_op->get_query_layer(); - param_.attn_param.set_is_cross(encoder_op->get_query_layer()); - param_.attn_param.set_scale(encoder_op->get_scale()); - if (param_.attn_param.get_is_cross()) idx++; // skip posotion ids - param_.attn_param.set_qkv_weight(in_tensors_.at(idx++)->device_data()); - if (param_.attn_param.get_is_cross()) { - param_.attn_param.set_kv_weight(in_tensors_.at(idx++)->device_data()); - } - param_.attn_param.set_qkv_bias(in_tensors_.at(idx++)->device_data()); - param_.attn_param.set_q_seq_len(in_tensors_.at(C0NUM)->shape()[C1NUM]); - param_.attn_param.set_kv_seq_len(in_tensors_.at(C0NUM)->shape()[C1NUM]); - param_.MAX_N = in_tensors_.at(C0NUM)->shape()[C1NUM]; - param_.B = in_tensors_.at(in_tensors_.size() - Num2)->shape()[C0NUM]; - idx++; // skip mask - param_.attn_param.set_projection_weight(in_tensors_.at(idx++)->device_data()); - param_.attn_param.set_projection_bias(in_tensors_.at(idx++)->device_data()); - param_.norm2.set_eps(encoder_op->get_eps_layernorm2()); - param_.norm2.set_gamma(in_tensors_.at(idx++)->device_data()); - param_.norm2.set_beta(in_tensors_.at(idx++)->device_data()); - param_.is_moe_ = encoder_op->get_moe(); - if (param_.is_moe_) { - param_.moe.set_expert_number(encoder_op->get_expert_num()); - param_.E = encoder_op->get_expert_num(); - } else { - param_.ffn_param.set_mapping_weight(in_tensors_.at(idx++)->device_data()); - param_.ffn_param.set_mapping_bias(in_tensors_.at(idx++)->device_data()); - param_.ffn_param.set_projection_weight(in_tensors_.at(idx++)->device_data()); - param_.ffn_param.set_projection_bias(in_tensors_.at(idx++)->device_data()); - } - param_.ffn_param.set_ffn_hidden_size(encoder_op->get_ffn_hidden_size()); - param_.HFFN = encoder_op->get_ffn_hidden_size(); - // setup normalization 3 parameters - if exist - param_.is_last_norm_ = encoder_op->get_layer_norm(); - if (param_.is_last_norm_) { - param_.norm3.set_eps(encoder_op->get_eps_layernorm3()); - param_.norm3.set_gamma(in_tensors_.at(idx++)->device_data()); - param_.norm3.set_beta(in_tensors_.at(idx++)->device_data()); - } - // setup query layer - if (param_.is_query_) { - auto t = in_tensors_.at(idx++); - param_.embedding_param.set_vcobalary_size(t->shape().at(C0NUM)); - param_.embedding_param.set_word_embedding(t->device_data()); - param_.embedding_param.set_top_query_embedding(in_tensors_.at(idx++)->device_data()); - } - param_.is_embedding_ = encoder_op->get_embedding_layer(); - // setup embedding - if (param_.is_embedding_) { - auto t = in_tensors_.at(idx++); - param_.embedding_param.set_vcobalary_size(t->shape().at(C0NUM)); - param_.embedding_param.set_word_embedding(t->device_data()); - param_.embedding_param.set_position_embedding(in_tensors_.at(idx++)->device_data()); - } - param_.capacity = 0; // capacity of tokens per expert - return kSuccess; -} - -int AscendNativeEncoderKernel::Prepare() { - auto ret = InitEncoderParam(); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Ascend native encoder kernel InitEncoderParam failed."; - return kLiteError; - } - if (param_.is_query_) { - encoder_driver_ = std::make_shared(); - } else if (param_.is_last_norm_) { - encoder_driver_ = std::make_shared(); - } else { - encoder_driver_ = std::make_shared(param_.is_embedding_); - } - build_driver_input_const_tensors(); - return kSuccess; -} - -void AscendNativeEncoderKernel::build_driver_input_const_tensors() { - driver_input_tensors_.at(ENCODER_LN1_GAMMA_IDX) = param_.norm1.get_gamma(); - driver_input_tensors_.at(ENCODER_LN1_BETA_IDX) = param_.norm1.get_beta(); - if (param_.is_query_) { - driver_input_tensors_.at(ENCODER_DENSE_Q_IDX) = param_.attn_param.get_qkv_weight(); - driver_input_tensors_.at(ENCODER_DENSE_KV_CONCAT_IDX) = param_.attn_param.get_kv_weight(); - } else { - driver_input_tensors_.at(ENCODER_DENSE_CONCAT_IDX) = param_.attn_param.get_qkv_weight(); - } - driver_input_tensors_.at(ENCODER_K_CACHE_IDX) = param_.attn_param.get_k_cache(); - driver_input_tensors_.at(ENCODER_V_CACHE_IDX) = param_.attn_param.get_v_cache(); - driver_input_tensors_.at(ENCODER_DENSE_BIAS_IDX) = param_.attn_param.get_qkv_bias(); - driver_input_tensors_.at(ENCODER_PROJECTION_IDX) = param_.attn_param.get_projection_weight(); - driver_input_tensors_.at(ENCODER_PROJECTION_BIAS_IDX) = param_.attn_param.get_projection_bias(); - driver_input_tensors_.at(ENCODER_LN2_GAMMA_IDX) = param_.norm2.get_gamma(); - driver_input_tensors_.at(ENCODER_LN2_BETA_IDX) = param_.norm2.get_beta(); - driver_input_tensors_.at(ENCODER_FFN_OUT_IDX) = param_.ffn_param.get_mapping_weight(); - driver_input_tensors_.at(ENCODER_FFN_OUT_BIAS_IDX) = param_.ffn_param.get_mapping_bias(); - driver_input_tensors_.at(ENCODER_FFN_PROJ_IDX) = param_.ffn_param.get_projection_weight(); - driver_input_tensors_.at(ENCODER_FFN_PROJ_BIAS_IDX) = param_.ffn_param.get_projection_bias(); - - driver_input_tensors_.at(ENCODER_V_EMBEDDING_IDX) = param_.embedding_param.get_word_embedding(); - driver_input_tensors_.at(ENCODER_P_EMBEDDING_IDX) = param_.embedding_param.get_position_embedding(); - driver_input_tensors_.at(ENCODER_QUERY_EMBEDDING_IDX) = param_.embedding_param.get_top_query_embedding(); - - driver_input_tensors_.at(ENCODER_LN3_GAMMA_IDX) = param_.norm3.get_gamma(); - driver_input_tensors_.at(ENCODER_LN3_BETA_IDX) = param_.norm3.get_beta(); -} - -int AscendNativeEncoderKernel::Run() { - const std::vector &inputs = in_tensors(); - - if (param_.is_embedding_) { - driver_input_tensors_.at(ENCODER_INPUT_IDS_IDX) = inputs.at(0)->device_data(); - } else { - driver_input_tensors_.at(ENCODER_INPUT_IDX) = inputs.at(0)->device_data(); - } - if (param_.is_query_) { - driver_output_tensors_.at(HEAD_OUTPUT_IDX) = out_tensors_.at(0)->device_data(); - } else if (param_.is_last_norm_) { - driver_output_tensors_.at(NORM_OUTPUT_IDX) = out_tensors_.at(0)->device_data(); - } else { - driver_output_tensors_.at(ENCODER_OUTPUT_IDX) = out_tensors_.at(0)->device_data(); - } - driver_input_tensors_.at(ENCODER_BATCH_VALID_LENGTH_IDX) = inputs.at(inputs.size() - C1NUM)->device_data(); - driver_input_tensors_.at(ENCODER_POS_IDS_IDX) = inputs.at(inputs.size() - C2NUM)->device_data(); - int token = 0; - if (param_.is_embedding_) - ascend_native::PreapreVSL(driver_input_tensors_.at(ENCODER_BATCH_VALID_LENGTH_IDX), param_.B, &token, - const_cast(stream_)); - else - ascend_native::GetTokenNum(driver_input_tensors_.at(ENCODER_BATCH_VALID_LENGTH_IDX), param_.B, &token, - const_cast(stream_)); - param_.num_of_tokens = token; - static bool inc = false; - if (inc) { - param_.num_of_tokens = 1 * param_.B; - } - if (param_.is_query_ && !inc) inc = true; - void *ws = get_workspace(); - - encoder_driver_->Forward(&driver_input_tensors_, &driver_output_tensors_, ws, ¶m_, const_cast(stream_)); - return kSuccess; -} - -size_t AscendNativeEncoderKernel::get_workspace_size() const { return encoder_driver_->GetWorkspaceSize(param_); } - -REGISTER_ASCEND_NATIVE_CREATOR(kNameEncoderLayer, AscendNativeEncoderKernel) -} // namespace mindspore::kernel diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_encoder_kernel.h b/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_encoder_kernel.h deleted file mode 100644 index 718155443999226d4cd3559110f7caec6a9e59cb..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_encoder_kernel.h +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_NATIVE_ENCODER_KERNEL_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_NATIVE_ENCODER_KERNEL_H_ - -#include -#include -#include -#include -#include "extendrt/delegate/ascend_native/ascend_native_base_kernel.h" -#include "extendrt/utils/func_graph_utils.h" -#include "extendrt/delegate/ascend_native/ascend_native_impl/encoder.h" -#include "extendrt/delegate/ascend_native/ascend_native_impl/query.h" - -namespace mindspore::kernel { -class AscendNativeEncoderKernel : public AscendNativeBaseKernel { - public: - AscendNativeEncoderKernel(const std::vector &inputs, const std::vector &outputs, - InferPrimitive prim, const InferContext *ctx, const void *stream, std::string name) - : AscendNativeBaseKernel(inputs, outputs, prim, ctx, stream, name), - driver_input_tensors_(ENCODER_LAST_IDX), - driver_output_tensors_(ENCODER_OUTPUT_LAST_IDX) {} - virtual ~AscendNativeEncoderKernel() {} - - int Prepare() override; - - int Run() override; - - size_t get_workspace_size() const override; - - int InferShape() override; - - private: - void build_driver_input_const_tensors(); - int InitEncoderParam(); - std::vector GetOutputDimensions(); - ascend_native::EncoderParams param_; - std::shared_ptr encoder_driver_; - std::vector driver_input_tensors_; - std::vector driver_output_tensors_; -}; -} // namespace mindspore::kernel -#endif // MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_NATIVE_ENCODER_KERNEL_H_ diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_gather_kernel.cc b/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_gather_kernel.cc deleted file mode 100644 index 5cccd0e50f0813d57cedfbc705b6429093a7130c..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_gather_kernel.cc +++ /dev/null @@ -1,79 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "extendrt/delegate/ascend_native/ascend_native_gather_kernel.h" -#include -#include -#include "extendrt/delegate/ascend_native/ascend_native_kernel_registry.h" -#include "extendrt/delegate/ascend_native/ascend_native_impl/gather.h" -#include "infer/ops_func_impl/gather.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_g.h" - -namespace mindspore::kernel { -using mindspore::ops::kNameGather; - -int AscendNativeGatherKernel::InferShape() { - auto shape_input = in_tensors_.at(FIRST_INPUT)->shape(); - auto shape_indices = in_tensors_.at(SECOND_INPUT)->shape(); - int32_t axis = 0; - auto axis_tensor = reinterpret_cast(in_tensors_.at(THIRD_INPUT)->device_data()); - ascend_native::CopyDeviceFp32ToHostFp32(axis_tensor, &axis, 2, const_cast(this->get_stream())); - std::vector out_shape; - for (auto i = 0; i < axis; i++) { - out_shape.push_back(shape_input[i]); - } - for (size_t i = 0; i < shape_indices.size(); i++) { - out_shape.push_back(shape_indices[i]); - } - for (size_t i = axis + 1; i < shape_input.size(); i++) { - out_shape.push_back(shape_input[i]); - } - out_tensors()[0]->set_shape(out_shape); - return kSuccess; -} - -int AscendNativeGatherKernel::Prepare() { return kSuccess; } - -int AscendNativeGatherKernel::Run() { - MS_LOG(INFO) << "AscendNativeGatherKernel::Execute"; - const std::vector &in_tensors = this->in_tensors(); - if (in_tensors.size() != THREE_TENSOR) { - MS_LOG(ERROR) << "AscendNativeGatherKernel inputs number should be 3, instead got " << in_tensors.size(); - return kLiteError; - } - int64_t axis = 0; - auto axis_tensor = reinterpret_cast(in_tensors.at(THIRD_INPUT)->device_data()); - ascend_native::CopyDeviceFp32ToHostFp32(axis_tensor, &axis, 2, const_cast(this->get_stream())); - auto shape = in_tensors.at(FIRST_INPUT)->shape(); - size_t num_tiles = 1, m = 1; - for (auto i = 0; i < axis; i++) { - num_tiles *= shape.at(i); - } - for (size_t i = axis + 1; i < shape.size(); i++) { - m *= shape.at(i); - } - - ascend_native::GatherFp16(out_tensors().at(FIRST_INPUT)->device_data(), in_tensors.at(FIRST_INPUT)->device_data(), - reinterpret_cast(in_tensors.at(SECOND_INPUT)->device_data()), - in_tensors.at(SECOND_INPUT)->shape().at(0), m, num_tiles, shape.at(axis), - const_cast(get_stream())); - return kSuccess; -} - -int AscendNativeGatherKernel::ReSize() { return kSuccess; } - -REGISTER_ASCEND_NATIVE_CREATOR(kNameGather, AscendNativeGatherKernel) -} // namespace mindspore::kernel diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_gather_kernel.h b/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_gather_kernel.h deleted file mode 100644 index 4ed72dd0f82e4c7e21603f17aad213ee7336a762..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_gather_kernel.h +++ /dev/null @@ -1,38 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_NATIVE_GATHER_KERNEL_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_NATIVE_GATHER_KERNEL_H_ - -#include -#include -#include -#include "extendrt/delegate/ascend_native/ascend_native_base_kernel.h" -#include "extendrt/utils/func_graph_utils.h" - -namespace mindspore::kernel { -class AscendNativeGatherKernel : public AscendNativeBaseKernel { - public: - AscendNativeGatherKernel(const std::vector &inputs, const std::vector &outputs, - InferPrimitive prim, const InferContext *ctx, const void *stream, std::string name) - : AscendNativeBaseKernel(inputs, outputs, prim, ctx, stream, name) {} - int InferShape() override; - int Prepare() override; - int Run() override; - int ReSize() override; -}; -} // namespace mindspore::kernel -#endif // MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_NATIVE_GATHER_KERNEL_H_ diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_impl/CMakeLists.txt b/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_impl/CMakeLists.txt deleted file mode 100644 index 7a9e35f46bf79e87376b373a7cf5c6ceea000332..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_impl/CMakeLists.txt +++ /dev/null @@ -1,25 +0,0 @@ -## cross compile -set(CMAKE_CROSSCOMPILING ON) -set(CMAKE_C_COMPILER clang) -set(CMAKE_CXX_COMPILER clang++) -if(DEFINED ENV{MSLITE_ASCEND_TARGET}) - set(MSLITE_ASCEND_TARGET $ENV{MSLITE_ASCEND_TARGET}) -endif() -# find ascend native compiler -find_file(ASCEND_NATIVE clang++ REQUIRED) -get_filename_component(ASCEND_NATIVE_DIR ${ASCEND_NATIVE} DIRECTORY) -set(HMMI_LIB_DIR ${ASCEND_NATIVE_DIR}/../lib) -get_filename_component(HMMI_LIB ${HMMI_LIB_DIR} ABSOLUTE) -set(CMAKE_CXX_FLAGS "-fPIC -fsycl -fsycl-targets=${MSLITE_ASCEND_TARGET} \ - -fplugin=${HMMI_LIB}/HmmiTiling.so -D_GLIBCXX_USE_CXX11_ABI=1") -if(ENABLE_FAST_HASH_TABLE) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DENABLE_FAST_HASH_TABLE=1") -endif() -if("${CMAKE_BUILD_TYPE}" STREQUAL "Debug") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g") -endif() - -set(CMAKE_CXX_LINK clang++) -if(DEFINED SRC_LIST) - add_library(ascend_native_kernels_impl SHARED ${SRC_LIST}) -endif() diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_impl/add.h b/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_impl/add.h deleted file mode 100644 index 0050d7e91c89c7fadd47bd10118e6379bc7b8071..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_impl/add.h +++ /dev/null @@ -1,22 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_ASCEND_NATIVE_IMPL_ADD_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_ASCEND_NATIVE_IMPL_ADD_H_ -namespace mindspore::ascend_native { -void AddFp32(void *x1, void *x2, void *y, uint64_t elem_num, void *q); -void AddFp16(void *x1, void *x2, void *y, uint64_t elem_num, void *q); -} // namespace mindspore::ascend_native -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_ASCEND_NATIVE_IMPL_ADD_H_ diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_impl/encoder.h b/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_impl/encoder.h deleted file mode 100644 index 42df1854cad86dc8b494f7a4d71f289e74b76a24..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_impl/encoder.h +++ /dev/null @@ -1,331 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either convolutionress or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_ASCEND_NATIVE_IMPL_ENCODER_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_ASCEND_NATIVE_IMPL_ENCODER_H_ - -#include - -namespace mindspore::ascend_native { -#define ENCODER_INPUT_IDX 0 -#define ENCODER_LN1_GAMMA_IDX 1 -#define ENCODER_LN1_BETA_IDX 2 -#define ENCODER_DENSE_CONCAT_IDX 3 -#define ENCODER_DENSE_Q_IDX 4 -#define ENCODER_DENSE_KV_CONCAT_IDX 5 -#define ENCODER_DENSE_BIAS_IDX 6 -#define ENCODER_PROJECTION_IDX 7 -#define ENCODER_PROJECTION_BIAS_IDX 8 -#define ENCODER_LN2_GAMMA_IDX 9 -#define ENCODER_LN2_BETA_IDX 10 -#define ENCODER_FFN_OUT_IDX 11 -#define ENCODER_FFN_OUT_BIAS_IDX 12 -#define ENCODER_FFN_PROJ_IDX 13 -#define ENCODER_FFN_PROJ_BIAS_IDX 14 -#define ENCODER_INPUT_IDS_IDX 15 -#define ENCODER_BATCH_VALID_LENGTH_IDX 16 -#define ENCODER_V_EMBEDDING_IDX 18 -#define ENCODER_P_EMBEDDING_IDX 19 -#define ENCODER_QUERY_EMBEDDING_IDX 20 -#define ENCODER_K_CACHE_IDX 21 -#define ENCODER_V_CACHE_IDX 22 -#define ENCODER_POS_IDS_IDX 23 -#define ENCODER_LN3_GAMMA_IDX 24 -#define ENCODER_LN3_BETA_IDX 25 -#define ENCODER_Q_IDX 26 -#define ENCODER_LAST_IDX 27 - -#define ENCODER_OUTPUT_IDX 0 -#define HEAD_OUTPUT_IDX 1 -#define NORM_OUTPUT_IDX 2 -#define ENCODER_OUTPUT_LAST_IDX 3 - -// worspace -// we don't get ENCODER_INDEX_OFFSET_IDX from the MS, we must prepare it -// it's size is min(sizeof(int) * batch size, 32) -// [ | ENCODER_INDEX_OFFSET_IDX ] - -typedef int (*allGatherFuncT)(const void *in, void *out, size_t size, int data_type, void *stream); -typedef int (*allReduceSumFuncT)(const void *in, void *out, size_t size, int data_type, void *stream); - -class CommParam { - public: - CommParam() : rank_id_(1), rank_num_(0) {} - CommParam(uint32_t rank_id, uint32_t rank_num) : rank_id_(rank_id), rank_num_(rank_num) {} - - void set_all_gather(allGatherFuncT allGather) { allGather_ = allGather; } - void set_all_reduce_sum(allReduceSumFuncT allReduceSum) { allReduceSum_ = allReduceSum; } - int allReduceSum(const void *in, void *out, size_t size, int data_type, void *stream) { - if (allReduceSum_ == nullptr) return false; - return allReduceSum_(in, out, size, data_type, stream); - } - int allGather(const void *in, void *out, size_t size, int data_type, void *stream) { - if (allGather_ == nullptr) return false; - return allGather_(in, out, size, data_type, stream); - } - void set_rank_id(uint32_t rank_id) { rank_id_ = rank_id; } - uint32_t get_rank_id() { return rank_id_; } - - void set_rank_num(uint32_t rank_num) { rank_num_ = rank_num; } - uint32_t get_rank_num() { return rank_num_; } - - private: - uint32_t rank_id_; - uint32_t rank_num_; - allGatherFuncT allGather_; - allReduceSumFuncT allReduceSum_; -}; - -class NormParam { - public: - NormParam() : gamma_(nullptr), beta_(nullptr), eps_(1e-5f) {} - NormParam(void *gamma, void *beta, float eps) : gamma_(gamma), beta_(beta), eps_(eps) {} - - void set_gamma(void *gamma) { gamma_ = gamma; } - void *get_gamma() { return gamma_; } - void set_beta(void *beta) { beta_ = beta; } - void *get_beta() { return beta_; } - void set_eps(float eps) { eps_ = eps; } - float get_eps() { return eps_; } - - private: - void *gamma_{nullptr}; - void *beta_{nullptr}; - float eps_; -}; - -class FfnParam { - public: - FfnParam() - : ffn_hidden_size_(0), - projection_weight_(nullptr), - projection_bias_(nullptr), - mapping_weight_(nullptr), - mapping_bias_(nullptr) {} - - void set_projection_weight(void *projection_weight) { projection_weight_ = projection_weight; } - void *get_projection_weight() { return projection_weight_; } - void set_projection_bias(void *projection_bias) { projection_bias_ = projection_bias; } - void *get_projection_bias() { return projection_bias_; } - void set_mapping_weight(void *mapping_weight) { mapping_weight_ = mapping_weight; } - void *get_mapping_weight() { return mapping_weight_; } - void set_mapping_bias(void *mapping_bias) { mapping_bias_ = mapping_bias; } - void *get_mapping_bias() { return mapping_bias_; } - void set_ffn_hidden_size(size_t ffn_hidden_size) { ffn_hidden_size_ = ffn_hidden_size; } - size_t get_ffn_hidden_size() { return ffn_hidden_size_; } - - private: - size_t ffn_hidden_size_; - void *projection_weight_{nullptr}; - void *projection_bias_{nullptr}; - void *mapping_weight_{nullptr}; - void *mapping_bias_{nullptr}; -}; - -class MoeParam : public FfnParam { - public: - MoeParam() : FfnParam(), expert_number_(1), moe_id_(0), expert_ids_(nullptr), capacity_factor_(1.1f) {} - - void set_expert_number(uint32_t expert_number) { expert_number_ = expert_number; } - uint32_t get_expert_number() { return expert_number_; } - void set_moe_id(uint32_t moe_id) { moe_id_ = moe_id; } - uint32_t get_moe_id() { return moe_id_; } - void set_expert_ids(int *expert_ids) { expert_ids_ = expert_ids; } - int *get_expert_ids() { return expert_ids_; } - void set_expert_capacity_factor(float capacity_factor) { capacity_factor_ = capacity_factor; } - float get_expert_capacity_factor() { return capacity_factor_; } - - private: - uint32_t expert_number_; - uint32_t moe_id_; - int *expert_ids_{nullptr}; - float capacity_factor_; -}; - -class AttentionParam { - public: - AttentionParam() - : is_cross_(false), - head_number_(0), - head_size_(0), - hidden_dim_(0), - q_seq_len_(0), - kv_seq_len_(0), - scale_(0.0f), - qkv_weight_(nullptr), - qkv_bias_(nullptr), - k_cache_(nullptr), - v_cache_(nullptr), - projection_weight_(nullptr), - projection_bias_(nullptr) {} - void set_is_cross(size_t is_cross) { is_cross_ = is_cross; } - bool get_is_cross() { return is_cross_; } - void set_head_number(size_t head_number) { head_number_ = head_number; } - size_t get_head_number() { return head_number_; } - void set_head_size(size_t head_size) { head_size_ = head_size; } - size_t get_head_size() { return head_size_; } - void set_q_seq_len(size_t q_seq_len) { q_seq_len_ = q_seq_len; } - size_t get_q_seq_len() { return q_seq_len_; } - void set_kv_seq_len(size_t kv_seq_len) { kv_seq_len_ = kv_seq_len; } - size_t get_kv_seq_len() { return kv_seq_len_; } - void set_hidden_dim(size_t hidden_dim) { hidden_dim_ = hidden_dim; } - size_t get_hidden_dim() { return hidden_dim_; } - void set_scale(float scale) { scale_ = scale; } - float get_scale() { return scale_; } - void set_projection_weight(void *projection_weight) { projection_weight_ = projection_weight; } - void *get_projection_weight() { return projection_weight_; } - void set_projection_bias(void *projection_bias) { projection_bias_ = projection_bias; } - void *get_projection_bias() { return projection_bias_; } - void set_qkv_weight(void *qkv_weight) { qkv_weight_ = qkv_weight; } - void *get_qkv_weight() { return qkv_weight_; } - void set_qkv_bias(void *qkv_bias) { qkv_bias_ = qkv_bias; } - void *get_qkv_bias() { return qkv_bias_; } - void set_kv_weight(void *kv_weight) { kv_weight_ = kv_weight; } - void *get_kv_weight() { return kv_weight_; } - void set_k_cache(void *k_cache) { k_cache_ = k_cache; } - void *get_k_cache() { return k_cache_; } - void set_v_cache(void *v_cache) { v_cache_ = v_cache; } - void *get_v_cache() { return v_cache_; } - - private: - bool is_cross_; - size_t head_number_; - size_t head_size_; - size_t hidden_dim_; - size_t q_seq_len_; - size_t kv_seq_len_; - float scale_; - void *qkv_weight_{nullptr}; - void *qkv_bias_{nullptr}; - void *kv_weight_{nullptr}; - void *k_cache_{nullptr}; - void *v_cache_{nullptr}; - void *projection_weight_{nullptr}; - void *projection_bias_{nullptr}; -}; - -class VslParam { - public: - VslParam() : token_number_(0), padding_offset_(nullptr), seq_len_q_(nullptr), seq_len_kv_(nullptr) {} - - void set_padding_offset(int *padding_offset) { padding_offset_ = padding_offset; } - int *get_padding_offset() { return padding_offset_; } - - void set_seq_len_q(int *seq_len_q) { seq_len_q_ = seq_len_q; } - int *get_seq_len_q() { return seq_len_q_; } - - void set_seq_len_kv(int *seq_len_kv) { seq_len_kv_ = seq_len_kv; } - int *get_seq_len_kv() { return seq_len_kv_; } - - void set_token_number(size_t token_number) { token_number_ = token_number; } - size_t get_token_number() { return token_number_; } - - private: - size_t token_number_; - int *padding_offset_{nullptr}; - int *seq_len_q_{nullptr}; - int *seq_len_kv_{nullptr}; -}; - -class EmbeddingParam { - public: - EmbeddingParam() - : vcobalary_size_(0), word_embedding_(nullptr), position_embedding_(nullptr), top_query_embedding_(nullptr) {} - - void set_vcobalary_size(size_t vcobalary_size) { vcobalary_size_ = vcobalary_size; } - size_t get_vcobalary_size() { return vcobalary_size_; } - - void set_word_embedding(void *word_embedding) { word_embedding_ = word_embedding; } - void *get_word_embedding() { return word_embedding_; } - - void set_position_embedding(void *position_embedding) { position_embedding_ = position_embedding; } - void *get_position_embedding() { return position_embedding_; } - - void set_top_query_embedding(void *top_query_embedding) { top_query_embedding_ = top_query_embedding; } - void *get_top_query_embedding() { return top_query_embedding_; } - - private: - size_t vcobalary_size_; - void *word_embedding_{nullptr}; - void *position_embedding_{nullptr}; - void *top_query_embedding_{nullptr}; -}; - -class EncoderParams { - public: - EncoderParams() {} - size_t B; // batch size - size_t D; // embedding size - size_t E; // num of experts - size_t H; // num of heads - size_t HS; // head size - size_t HFFN; // hidden size (of FFN) - size_t MAX_N; // max num of tokens - size_t vocabulary_size_; // max num of tokens - - size_t capacity; // capacity of tokens per expert - size_t num_of_tokens; // total num of tokens (in all batches) - size_t vocabulary_size; - // embedding parameters - EmbeddingParam embedding_param; - // normalization parameters - NormParam norm1; - NormParam norm2; - NormParam norm3; - // Attention param - AttentionParam attn_param; - // moe param - MoeParam moe; - // ffn param - FfnParam ffn_param; - // communication parameters - CommParam comm_param; - // vsl parameters - VslParam vsl_param; - - bool is_moe_; - bool is_query_; - bool is_embedding_; - bool is_last_norm_; -}; - -class AscendNativeEncoder { - public: - explicit AscendNativeEncoder(bool embedding = false) : embedding_(embedding) {} - virtual ~AscendNativeEncoder() {} - virtual void FFN(std::vector *ins, std::vector *outs, void *ws, EncoderParams *p, void *q); - virtual void MoE(std::vector *ins, std::vector *outs, void *ws, EncoderParams *p, void *q); - virtual void LN(void *dst_norm, void *src, void *g, void *b, float epsilon, EncoderParams *p, void *q); - virtual void Attn(std::vector *ins, std::vector *outs, void *ws, EncoderParams *p, void *q); - virtual void Forward(std::vector *ins, std::vector *outs, void *ws, EncoderParams *p, void *q); - virtual void Embed(std::vector *ins, std::vector *outs, void *ws, EncoderParams *p, void *q); - virtual size_t GetWorkspaceSize(const EncoderParams &p); - - protected: - virtual void Prepare(std::vector *ins, std::vector *outs, void *ws, EncoderParams *p, void *q); - bool embedding_; -}; - -class AscendNativeEncoderFuseLastNorm : public AscendNativeEncoder { - public: - AscendNativeEncoderFuseLastNorm() : AscendNativeEncoder(false) {} - virtual ~AscendNativeEncoderFuseLastNorm() {} - size_t GetWorkspaceSize(const EncoderParams &p) override; - void Forward(std::vector *ins, std::vector *outs, void *ws, EncoderParams *p, void *q) override; -}; - -} // namespace mindspore::ascend_native -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_ASCEND_NATIVE_IMPL_ENCODER_H_ diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_impl/encoder_utils.h b/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_impl/encoder_utils.h deleted file mode 100644 index 0dd467264d7e5f383527b8980f1d02f860978cf6..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_impl/encoder_utils.h +++ /dev/null @@ -1,24 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_ASCEND_NATIVE_IMPL_ENCODER_UTILS_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_ASCEND_NATIVE_IMPL_ENCODER_UTILS_H_ -namespace mindspore::ascend_native { -void AttentionSoftmaxFp16(void *output, void *input, int num_of_tiles, int tile_size, int *batches, int *batch_offsets, - int heads, bool incr, void *q); -template -void QueryGather(void *input, void *output, int32_t *batches, size_t B, size_t D, int num_of_elements, void *q); -} // namespace mindspore::ascend_native -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_ASCEND_NATIVE_IMPL_ENCODER_UTILS_H_ diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_impl/gather.h b/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_impl/gather.h deleted file mode 100644 index 1ddd54fecce64b898296f6082ac4124926f2c7d9..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_impl/gather.h +++ /dev/null @@ -1,24 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_ASCEND_NATIVE_IMPL_GATHER_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_ASCEND_NATIVE_IMPL_GATHER_H_ -namespace mindspore::ascend_native { -void GatherFp32(void *y1, void *x1, int32_t *indices, int num_of_indices, size_t len_to_copy, int num_of_tiles, - size_t tile_size, void *handle); -void GatherFp16(void *y1, void *x1, int32_t *indices, int num_of_indices, size_t len_to_copy, int num_of_tiles, - size_t tile_size, void *handle); -} // namespace mindspore::ascend_native -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_ASCEND_NATIVE_IMPL_GATHER_H_ diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_impl/gemm.h b/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_impl/gemm.h deleted file mode 100644 index 32bdfec34fb9e32f96925f90b6f69b96db9a5ee5..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_impl/gemm.h +++ /dev/null @@ -1,32 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_ASCEND_NATIVE_IMPL_GEMM_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_ASCEND_NATIVE_IMPL_GEMM_H_ - -#include - -namespace mindspore::ascend_native { -void GemmFp32(void *queue, size_t m_size, size_t n_size, size_t k_size, float ALPHA, void *d_hA, size_t lda, void *d_hB, - size_t ldb, float BETA, void *d_fC, size_t ldc, size_t core_num); - -void GemmFp16(void *queue, bool ta, bool tb, size_t m_size, size_t n_size, size_t k_size, float ALPHA, void *d_hA, - size_t lda, void *d_hB, size_t ldb, float BETA, void *d_fC, size_t ldc, size_t core_num); - -void BGemmFp16(void *queue, bool ta, bool tb, size_t m_size, size_t n_size, size_t k_size, float ALPHA, void *d_hA, - size_t lda, void *d_hB, size_t ldb, float BETA, void *d_fC, size_t ldc, int repeats, size_t core_num); -} // namespace mindspore::ascend_native -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_ASCEND_NATIVE_IMPL_GEMM_H_ diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_impl/layernorm.h b/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_impl/layernorm.h deleted file mode 100644 index 7af28567eb280e5ef0adf1fa529afb3344d94e7d..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_impl/layernorm.h +++ /dev/null @@ -1,22 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_ASCEND_NATIVE_IMPL_LAYERNORM_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_ASCEND_NATIVE_IMPL_LAYERNORM_H_ -namespace mindspore::ascend_native { -void LayerNormFp32(void *output, void *input, void *gamma, void *beta, uint64_t m, uint64_t n, float epsilon, void *q); -void LayerNormFp16(void *output, void *input, void *gamma, void *beta, uint64_t m, uint64_t n, float epsilon, void *q); -} // namespace mindspore::ascend_native -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_ASCEND_NATIVE_IMPL_LAYERNORM_H_ diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_impl/utils.h b/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_impl/utils.h deleted file mode 100644 index eff83cfc3fc0c9c010815e25ece728e24d1a30dd..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_impl/utils.h +++ /dev/null @@ -1,62 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_ASCEND_NATIVE_IMPL_UTILS_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_ASCEND_NATIVE_IMPL_UTILS_H_ - -#include - -#define ASCEND_UB_BUFFER_SIZE 241280 -#define ASCEND_HALF_UB_BUFFER_SIZE 120640 -#define ASCEND_THIRD_UB_BUFFER_SIZE 80432 -#define ASCEND_CROSS_BUF_SIZE 256 -#define ASCEND_UB_NUM_ELEM (ASCEND_UB_BUFFER_SIZE / sizeof(T)) -// 120640 for FP16 -#define ASCEND_CROSS_NUM_ELEM (ASCEND_CROSS_BUF_SIZE / sizeof(T)) -// 128 for FP16 -#define MAX_CROSS_ELEMS (ASCEND_UB_NUM_ELEM / (ASCEND_CROSS_NUM_ELEM)) -#define ACEND_LB_UB_BUFFER_SIZE = 25600; -// 938 for FP16 -#define CeilDiv(x, y) (((x) + (y)-1) / (y)) - -namespace mindspore::ascend_native { -void *CreateStream(); -void CopyHostFp32ToDeviceFp16(void *src, void **dst_ptr, size_t elem_num, void *stream); -void CopyHostFp32ToDeviceFp32(void *src, void **dst_ptr, size_t elem_num, void *stream); -void CopyHostFp16ToDeviceFp16(void *src, void **dst_ptr, size_t elem_num, void *stream); -void CopyHostFp16ToDeviceFp32(void *src, void **dst_ptr, size_t elem_num, void *stream); -void CopyDeviceFp16ToHostFp32(void *src, void *dst, size_t elem_num, void *stream); -void CopyDeviceFp32ToHostFp32(void *src, void *dst, size_t elem_num, void *stream); -void CopyDeviceFp16ToHostFp16(void *src, void *dst, size_t elem_num, void *stream); -void CopyDeviceFp32ToHostFp16(void *src, void *dst, size_t elem_num, void *stream); -void *MallocDevice(size_t size, void *stream); -void *MallocCopy(void *src, size_t size, void *stream); -void FreeDevice(void *ptr, void *stream); -void SyncDevice(void *stream); -void PrintFp16(void *x, size_t elem_num, void *stream); -void PrintFp32(void *x, size_t elem_num, void *stream); -void PrintInt32(void *x, size_t elem_num, void *stream); -void printChecksumFp16(void *x, size_t elem_num, void *stream); -void printChecksumFp32(void *x, size_t elem_num, void *stream); -void printChecksumInt32(void *x, size_t elem_num, void *stream); -void PreapreVSL(void *batch_valid_len, int batch_size, int *token_num, void *q); -void PrintInfo(void *stream); -template -void printVector(void *x, int elem_num, void *q); -void GetTokenNum(void *batch_valid_len, int batch_size, int *token_num, void *q); - -} // namespace mindspore::ascend_native -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_ASCEND_NATIVE_IMPL_UTILS_H_ diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_kernel_registry.h b/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_kernel_registry.h deleted file mode 100644 index b197a5246390ef28c8fa1cdbee483a0ed9eaae24..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_kernel_registry.h +++ /dev/null @@ -1,47 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_BISHENG_BISHENG_KERNEL_REGISTRY_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_BISHENG_BISHENG_KERNEL_REGISTRY_H_ - -#include -#include -#include -#include "extendrt/delegate/type.h" -#include "extendrt/delegate/ascend_native/ascend_native_registration_factory.h" -#include "extendrt/delegate/ascend_native/ascend_native_base_kernel.h" - -namespace mindspore::kernel { -template -AscendNativeBaseKernel *GetAscendNativeKernelOp(const std::vector &inputs, - const std::vector &outputs, InferPrimitive prim, - const InferContext *ctx, const void *stream, std::string name) { - auto *op = new (std::nothrow) T(inputs, outputs, prim, ctx, stream, name); - if (op == nullptr) { - MS_LOG(WARNING) << "Ascend op is nullptr."; - return nullptr; - } - return op; -} -typedef AscendNativeBaseKernel *(*AscendNativeKernelOp)(const std::vector &inputs, - const std::vector &outputs, InferPrimitive prim, - const InferContext *ctx, const void *stream, std::string name); - -#define REGISTER_ASCEND_NATIVE_CREATOR(KEY, ASCEND_NATIVE_KERNEL_OP) \ - REGISTER_CLASS_CREATOR(std::string, KEY, AscendNativeKernelOp, GetAscendNativeKernelOp); - -using AscendNativeRegistrationFactory = AutoRegistrationFactory; -} // namespace mindspore::kernel -#endif // MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_BISHENG_BISHENG_KERNEL_REGISTRY_H_ diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_layernorm_kernel.cc b/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_layernorm_kernel.cc deleted file mode 100644 index 41ca113c2980fa379cea87b60cb011062ff85406..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_layernorm_kernel.cc +++ /dev/null @@ -1,64 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "extendrt/delegate/ascend_native/ascend_native_layernorm_kernel.h" -#include "extendrt/delegate/ascend_native/ascend_native_kernel_registry.h" -#include "extendrt/delegate/ascend_native/ascend_native_impl/layernorm.h" -#include "infer/cxx_api/layer_norm_fusion.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h" -namespace mindspore::kernel { -using mindspore::ops::kNameLayerNormFusion; - -int AscendNativeLayernormKernel::InferShape() { - for (size_t i = 0; i < out_tensors_.size(); i++) { - if (out_tensors_[i]->shape().size() == 0) { - if (in_tensors_[i] != nullptr) { - out_tensors_[i]->set_shape(in_tensors_[i]->shape()); - } - } - } - return kSuccess; -} - -int AscendNativeLayernormKernel::Prepare() { return kSuccess; } - -int AscendNativeLayernormKernel::Run() { - MS_LOG(INFO) << "AscendNativeLayernormKernel::Execute"; - const std::vector &in_tensors = this->in_tensors(); - if (in_tensors.size() != THREE_TENSOR) { - MS_LOG(ERROR) << "AscendNativeGatherKernel inputs number should be 3, instead got " << in_tensors.size(); - return kLiteError; - } - auto shape = in_tensors[0]->shape(); - uint64_t m = 1; - for (unsigned int i = 0; i < shape.size() - 1; i++) { - m *= shape.at(i); - } - uint64_t n = shape.at(shape.size() - 1); - if (out_tensors_[0]->device_data() == nullptr) { - out_tensors_[0]->set_device_data(ascend_native::MallocDevice(out_tensors_[0]->Size(), const_cast(stream_))); - } - ascend_native::LayerNormFp32(out_tensors()[0]->device_data(), in_tensors.at(FIRST_INPUT)->device_data(), - in_tensors.at(SECOND_INPUT)->device_data(), in_tensors_.at(THIRD_INPUT)->device_data(), - m, n, 1e-5f, const_cast(get_stream())); - - return kSuccess; -} - -int AscendNativeLayernormKernel::ReSize() { return kSuccess; } - -REGISTER_ASCEND_NATIVE_CREATOR(kNameLayerNormFusion, AscendNativeLayernormKernel) -} // namespace mindspore::kernel diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_matmul_kernel.cc b/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_matmul_kernel.cc deleted file mode 100644 index 4787bb1370c4bef806c372cf4c8f7ddac85b636e..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_matmul_kernel.cc +++ /dev/null @@ -1,84 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include "extendrt/delegate/ascend_native/ascend_native_matmul_kernel.h" -#include "extendrt/delegate/ascend_native/ascend_native_kernel_registry.h" -#include "extendrt/delegate/ascend_native/ascend_native_impl/gemm.h" -#include "infer/cxx_api/mat_mul_fusion.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" - -namespace mindspore::kernel { -using mindspore::ops::kNameMatMulFusion; - -int AscendNativeMatmulKernel::InferShape() { - if (in_tensors_[0] != nullptr && in_tensors_[1] != nullptr) { - std::vector shape; - shape.push_back(in_tensors_[0]->shape()[0]); - shape.push_back(in_tensors_[1]->shape()[1]); - out_tensors_[0]->set_shape(shape); - } - return kSuccess; -} - -int AscendNativeMatmulKernel::Prepare() { return kSuccess; } - -int AscendNativeMatmulKernel::Run() { - MS_LOG(INFO) << "AscendNativeMatmulKernel::Execute"; - const std::vector &in_tensors = this->in_tensors(); - const std::vector &out_tensors = this->out_tensors(); - - auto primitive = AsOps(); - if (primitive == nullptr) { - MS_LOG(ERROR) << "convert to primitive matmul failed for " << get_name(); - return kLiteError; - } - bool transpose_a = primitive->get_transpose_a(); - bool transpose_b = primitive->get_transpose_b(); - auto shape_a = in_tensors.at(FIRST_INPUT)->shape(); - auto shape_b = in_tensors.at(SECOND_INPUT)->shape(); - if (shape_a.size() != shape_b.size() || shape_a.size() < 2) { - std::cout << "AscendNativeBatchMatMulKernel::Execute Error -- tensors have different dims or too short\n"; - return kLiteInputTensorError; - } - size_t tiles = 1; - for (size_t i = 0; i < shape_a.size() - 2; i++) { - if (shape_a.at(i) != shape_b.at(i)) { - std::cout << "AscendNativeBatchMatMulKernel::Execute Error -- tensors have different shapes\n"; - return kLiteInputTensorError; - } - tiles *= shape_a.at(i); - } - auto zeroth_mm_dim = shape_a.size() - 2; - auto m = static_cast(shape_a.at(zeroth_mm_dim)); - auto k = static_cast(shape_a.at(zeroth_mm_dim + 1)); - auto n = static_cast(shape_b.at(zeroth_mm_dim + 1)); - - ascend_native::BGemmFp16(const_cast(get_stream()), transpose_a, transpose_b, m, n, k, 1.0f, - in_tensors[0]->device_data(), k, in_tensors[1]->device_data(), n, 0.0f, - out_tensors[0]->device_data(), n, tiles, 1); - - return kSuccess; -} - -int AscendNativeMatmulKernel::ReSize() { - if (in_tensors_[0]->shape()[1] != in_tensors_[1]->shape()[0]) { - MS_LOG(ERROR) << "matmul ReSize failed"; - return lite::RET_ERROR; - } - return lite::RET_OK; -} -REGISTER_ASCEND_NATIVE_CREATOR(kNameMatMulFusion, AscendNativeMatmulKernel) -} // namespace mindspore::kernel diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_matmul_kernel.h b/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_matmul_kernel.h deleted file mode 100644 index 08a1e68015102a993eca4f6dc6c610b9928f7aea..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_matmul_kernel.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_NATIVE_MATMUL_KERNEL_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_NATIVE_MATMUL_KERNEL_H_ - -#include -#include -#include -#include "extendrt/delegate/ascend_native/ascend_native_base_kernel.h" - -namespace mindspore::kernel { -class AscendNativeMatmulKernel : public AscendNativeBaseKernel { - public: - // AscendNativeMatmulKernel() = delete; - - AscendNativeMatmulKernel(const std::vector &inputs, const std::vector &outputs, - InferPrimitive prim, const InferContext *ctx, const void *stream, std::string name) - : AscendNativeBaseKernel(inputs, outputs, prim, ctx, stream, name) {} - - int InferShape() override; - - int Prepare() override; - - int Run() override; - - int ReSize() override; -}; -} // namespace mindspore::kernel -#endif // MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_NATIVE_MATMUL_KERNEL_H_ diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_registration_factory.h b/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_registration_factory.h deleted file mode 100644 index ac87e7ff0f92f2ac76e1626b359e9c3b2cb5df14..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/ascend_native_registration_factory.h +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_AUTO_REGISTRATION_FACTORY_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_AUTO_REGISTRATION_FACTORY_H_ - -#include - -namespace mindspore::kernel { -template -class AutoRegistrationFactory { - public: - struct AutoRegister { - AutoRegister(KeyType k, CreatorType creator) { - AutoRegistrationFactory::Get().Insert(k, creator); - } - }; - static AutoRegistrationFactory &Get(); - bool HasKey(KeyType k) const { return key2creator_.find(k) != key2creator_.end(); } - CreatorType GetCreator(KeyType k) { return key2creator_[k]; } - - private: - bool Insert(KeyType k, CreatorType creator) { - if (HasKey(k)) { - return false; - } - return key2creator_.emplace(k, creator).second; - } - std::unordered_map key2creator_; -}; - -#define AUTO_REGISTRATION_FACTORY_JOIN(a, b) a##b - -#define AUTO_REGISTRATION_FACTORY_UNIQUE_NAME_JOIN(a, b) AUTO_REGISTRATION_FACTORY_JOIN(a, b) - -#define AUTO_REGISTRATION_FACTORY_UNIQUE_NAME AUTO_REGISTRATION_FACTORY_UNIQUE_NAME_JOIN(g_, __COUNTER__) - -#define REGISTER_CLASS_CREATOR(KeyType, k, CreatorType, creator) \ - static AutoRegistrationFactory::AutoRegister AUTO_REGISTRATION_FACTORY_UNIQUE_NAME(k, creator); -} // namespace mindspore::kernel - -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_AUTO_REGISTRATION_FACTORY_H_ diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/delegate.cc b/mindspore-lite/src/extendrt/delegate/ascend_native/delegate.cc deleted file mode 100644 index 643ef5d4a1bd36081b91490c6d5e778bf9bbcc34..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/delegate.cc +++ /dev/null @@ -1,396 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "extendrt/delegate/ascend_native/delegate.h" -#include -#include -#include -#include -#include -#include -#include "extendrt/delegate/ascend_native/ascend_native_kernel_registry.h" -#include "extendrt/delegate/ascend_native/ops/ascend_native_composite.h" -#include "extendrt/delegate/ops/copy.h" -#include "extendrt/kernel/ascend_native/ascend_native_composite_kernel.h" -#include "extendrt/delegate/ascend_native/ascend_native_impl/utils.h" -#include "extendrt/delegate/ascend_native/ops/ascend_native_stub.h" -#include "extendrt/delegate/factory.h" -#include "include/common/utils/convert_utils.h" - -#include "infer/encoder_layer.h" -#include "infer/cxx_api/mul_fusion.h" -#include "infer/cxx_api/add_fusion.h" -#include "infer/cxx_api/mat_mul_fusion.h" -#include "infer/use_past_embedding.h" -#include "infer/ops_func_impl/gather.h" -#include "infer/reshape.h" -#include "infer/not_equal.h" -#include "infer/tuple_get_item.h" -#include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_e.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_g.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" - -namespace mindspore { - -constexpr auto kAscendNativeProvider = "ascend_native"; -namespace { -static inline kernel::InferTensor *anfTensorToTensorInfo(const common::KernelWithIndex &tensor_id) { - auto [prev_node, index] = tensor_id; - auto data_type = FuncGraphUtils::GetTensorDataType(tensor_id); - auto tensor_val = FuncGraphUtils::GetConstNodeValue(prev_node); - auto shape = FuncGraphUtils::GetTensorShape(tensor_id); - auto name = FuncGraphUtils::GetTensorName(tensor_id); - constexpr auto tensorrt_format = mindspore::Format::NCHW; - const void *data = nullptr; - size_t data_len = 0; - if (tensor_val) { - data = tensor_val->data_c(); - data_len = tensor_val->Size(); - shape = tensor_val->shape_c(); - } else { - if (data_type == DataType::kObjectTypeTuple) { - auto tuple_abs = prev_node->abstract()->cast(); - auto abs = tuple_abs->elements().at(index); - data_type = static_cast(abs->BuildType()->type_id()); - auto base_shape = abs->BuildShape(); - auto shape_ptr = base_shape->cast(); - if (shape_ptr != nullptr) { - shape = shape_ptr->shape(); - } - } - } - auto format = tensorrt_format; - std::vector t_shape; - t_shape.resize(shape.size()); - std::transform(shape.begin(), shape.end(), t_shape.begin(), [](int64_t x) { return static_cast(x); }); - auto t = kernel::InferTensor::CreateTensor(name, static_cast(data_type), t_shape, data, data_len); - t->set_format(format); - return t; -} - -static inline BaseOperatorPtr CreateOperatorByCNode(const CNodePtr &cnode) { - auto prim = GetValueNode(cnode->input(0)); - MS_EXCEPTION_IF_NULL(prim); - auto kernel_name = prim->name(); - // Create PrimtiveC from map and create BaseOperator. - ops::PrimitiveCPtr primc_ptr = nullptr; - static auto primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap(); - if (primc_fns.find(kernel_name) != primc_fns.end()) { - primc_ptr = primc_fns[kernel_name](); - (void)primc_ptr->SetAttrs(prim->attrs()); - } - MS_EXCEPTION_IF_NULL(primc_ptr); - static auto operator_fns = ops::OperatorRegister::GetInstance().GetOperatorMap(); - if (operator_fns.find(kernel_name) == operator_fns.end()) { - MS_LOG(EXCEPTION) << "Cannot create BaseOperator for " << kernel_name; - } - auto base_operator = operator_fns[kernel_name](primc_ptr); - MS_EXCEPTION_IF_NULL(base_operator); - return base_operator; -} -} // namespace - -bool AscendNativeDelegate::IsSupport(const CNodePtr &cnode) { - bool ret = false; - auto prim = GetValueNode(cnode->input(0)); - std::set ops = {ops::kNameEncoderLayer, ops::kNameAddFusion, ops::kNameMatMulFusion, ops::kNameGather, - ops::kNameTupleGetItem}; - if (ops.find(prim->name()) != ops.end()) { - if (prim->name() == ops::kNameMatMulFusion) { - auto base_op = CreateOperatorByCNode(cnode); - if (base_op == nullptr) { - MS_LOG(WARNING) << "no op found for " << cnode->fullname_with_scope(); - return false; - } - auto primitive = std::make_shared(base_op->GetPrim()); - if (primitive == nullptr) { - MS_LOG(WARNING) << "cannot create primitive for MatMulFusion"; - return false; - } - bool act = primitive->get_activation_type(); - if ((act == ActivationType::NO_ACTIVATION) && (cnode->size() == Num3)) { - ret = true; - } - } else if (prim->name() == ops::kNameAddFusion) { - auto shape1 = mindspore::BaseShapeToShape(cnode->input(1)->Shape()); - auto shape2 = mindspore::BaseShapeToShape(cnode->input(2)->Shape()); - auto in1 = std::reduce(shape1.begin(), shape1.end(), 1.0, std::multiplies()); - auto in2 = std::reduce(shape2.begin(), shape2.end(), 1.0, std::multiplies()); - if (in1 == in2) ret = true; - } else { - ret = true; - } - } - return ret; -} - -void AscendNativeDelegate::ReplaceNodes(const std::shared_ptr &graph) { - auto nodes = TopoSort(graph->get_return()); - // for all the nodes in the graph, call the delegate isDelegateNode and CreateKernel interface to create kernels - helper_ = std::make_shared(graph); - for (auto &node : nodes) { - if (!node->isa()) { - continue; - } - auto cnode = node->cast(); - if (!IsSupport(cnode)) continue; - // consider tuple only if parent is supported. - auto prim = GetCNodePrimitive(cnode); - if (prim->name() == ops::kNameTupleGetItem) { - auto parent = cnode->input(1); - if (parent->isa() && !IsSupport(parent->cast())) continue; - } - // check all node inputs belong to same subgraph - int sbg = helper_->CheckAllInputInSameSg(cnode); - // if yes add node to subgraph - if (sbg >= 0) { - helper_->AddToSubGraph(sbg, cnode); - } else { - helper_->AddSubGraph(cnode); - } - } - for (int i = 0; i < helper_->SubGroupNum(); i++) { - ReplaceSubGraph(graph, i); - } - helper_->FixAllNodes(nodes); -} - -void AscendNativeDelegate::ReplaceSubGraph(const std::shared_ptr &graph, int idx) { - auto composite_prim = std::make_shared(); - if (composite_prim == nullptr) { - MS_LOG(ERROR) << "failed to create custom node"; - return; - } - composite_prim->Init(idx); - auto composite_prim_c = composite_prim->GetPrim(); - CNodePtr composite_node = graph->NewCNode(composite_prim_c, helper_->GetSbgInputs(idx)); - composite_node->set_fullname_with_scope("composite_" + std::to_string(idx)); - helper_->SetCNode(idx, composite_node); - // abstract handled later -} - -bool AscendNativeDelegate::IsDelegateNode(const std::shared_ptr &node) { - auto cnode = node->cast(); - auto copy_prim = std::make_shared(ops::kNameCopy); - auto composite_prim = std::make_shared(ops::kNameAscendNativeComposite); - if ((cnode != nullptr) && (IsPrimitiveCNode(cnode, composite_prim) || IsPrimitiveCNode(cnode, copy_prim))) - return true; - return false; -} - -void AscendNativeDelegate::CreateInputKernelTensors(const CNodePtr &cnode, - std::vector *input_tensors, - std::shared_ptr allocator) { - input_tensors->clear(); - auto input_nodes = FuncGraphUtils::GetNodeInputs(cnode); - for (auto &tensor_id : input_nodes) { - auto it = std::find_if(kernel_list_.begin(), kernel_list_.end(), - [&tensor_id](const KernelWithIndexAndTensor &k) { return k.kernel_index == tensor_id; }); - // tensor already created - use the same tensor - if (it != kernel_list_.end()) { - input_tensors->push_back(it->tensor_info); - } else { - auto tensor_info = anfTensorToTensorInfo(tensor_id); - if (tensor_info == nullptr) { - MS_LOG(ERROR) << "failed to get tensor info"; - return; - } - input_tensors->push_back(tensor_info); - kernel_list_.push_back(KernelWithIndexAndTensor(tensor_id, tensor_info)); - } - } -} - -void AscendNativeDelegate::CreateOutputKernelTensors(const CNodePtr &cnode, - std::vector *output_tensors, - std::shared_ptr allocator) { - output_tensors->clear(); - auto output_num = AnfUtils::GetOutputTensorNum(cnode); - for (size_t output_idx = 0; output_idx < output_num; ++output_idx) { - common::KernelWithIndex tensor_id = {cnode, output_idx}; - auto it = std::find_if(kernel_list_.begin(), kernel_list_.end(), - [&tensor_id](const KernelWithIndexAndTensor &k) { return k.kernel_index == tensor_id; }); - if (it != kernel_list_.end()) { - output_tensors->push_back(it->tensor_info); - } else { - auto tensor_info = anfTensorToTensorInfo(tensor_id); - output_tensors->push_back(tensor_info); - kernel_list_.push_back(KernelWithIndexAndTensor(tensor_id, tensor_info)); - } - } -} - -std::shared_ptr AscendNativeDelegate::CreateKernel(const std::shared_ptr &node) { - // step I - Convert to cnode - if (!node->isa()) { - MS_LOG(ERROR) << "AscendNativeDelegate::CreateKernel not a cnode"; - return nullptr; - } - auto cnode = node->cast(); - if (cnode == nullptr) { - MS_LOG(ERROR) << "AscendNativeDelegate::CreateKernel cnode is nullptr"; - return nullptr; - } - auto stream = ascend_native::CreateStream(); - auto allocator = std::make_shared(stream); - // step II - Prepare kernel attributes - std::vector input_tensors; - CreateInputKernelTensors(cnode, &input_tensors, allocator); - std::vector output_tensors; - CreateOutputKernelTensors(cnode, &output_tensors, allocator); - kernel::InferPrimitive primitive; - primitive.base_operator = CreateOperatorByCNode(cnode); - primitive.cnode = cnode; - auto kernel_name = cnode->fullname_with_scope(); - auto node_type = primitive.base_operator->name(); - - if (node_type == ops::kNameAscendNativeComposite) { - auto kernel = std::make_shared(input_tensors, output_tensors, primitive, - ascend_native_ctx_.get(), stream, kernel_name); - int idx = static_cast(GetValue(primitive.base_operator->GetAttr("group"))); - auto func_graph = helper_->GetSbg(idx)->func_graph(); - kernel->set_func_graph(func_graph); - kernel->set_stream(stream); - return kernel; - } else { - // create base kernel for debug - auto orig_node_type = node_type; - // step III - Create Ascend native Kernel - auto &plugin_factory = kernel::AscendNativeRegistrationFactory::Get(); - if (node_type != ops::kNameCopy) { - node_type = ops::kNameAscendNativeStub; - } - if (plugin_factory.HasKey(node_type)) { - kernel::AscendNativeBaseKernel *ascend_native_op = plugin_factory.GetCreator(node_type)( - input_tensors, output_tensors, primitive, ascend_native_ctx_.get(), stream, node_type); - if (ascend_native_op == nullptr) { - return nullptr; - } - auto ker = std::shared_ptr(ascend_native_op); - if (!ker->IsWeightInputHanledInner()) { - auto in_tensors = ker->in_tensors(); - for (auto &t : in_tensors) { - if (t->IsConst() && t->device_data() != nullptr) { - t->set_device_data(ascend_native::MallocCopy(t->data(), t->Size(), const_cast(stream))); - } - } - } - if (node_type == "AscendNativeStub") { - ker->set_name(orig_node_type); - } else { - ker->set_name(kernel_name); - } - ker->set_stream(stream); - return ker; - } else { - MS_LOG(WARNING) << "Unsupported op type for ascend native. kernel name:" << kernel_name << " type:" << node_type; - return nullptr; - } - } -} - -void AscendNativeDelegate::CopyTensors(InferTensor *t_src, InferTensor *t_dst, const void *stream) const { - auto dst = t_dst->device_data(); - auto elem = t_src->Size(); - bool t_is_float = (t_src->data_type() == kNumberTypeFloat || t_src->data_type() == kNumberTypeFloat32); - if (t_is_float) { - ascend_native::CopyHostFp32ToDeviceFp16(t_src->data(), &dst, elem, const_cast(stream)); - } else { - int elem_size = mindspore::lite::DataTypeSize(t_src->data_type()); - switch (elem_size) { - case Num4: - ascend_native::CopyHostFp32ToDeviceFp32(t_src->data(), &dst, elem, const_cast(stream)); - break; - case Num2: - ascend_native::CopyHostFp16ToDeviceFp16(t_src->data(), &dst, elem, const_cast(stream)); - break; - case Num1: - ascend_native::CopyHostFp16ToDeviceFp16(t_src->data(), &dst, elem / 2, const_cast(stream)); - break; - default: - MS_LOG(ERROR) << "no supported size " << elem_size; - } - } - t_dst->set_device_data(dst); -} - -std::shared_ptr AscendNativeDelegate::CreateKernel(const kernel::KernelSpec &spec, - const std::vector &inputs, - const std::vector &outputs, - const InferContext *ctx) const { - // step I - Convert to cnode - auto cnode = spec.cnode; - if (cnode == nullptr) { - MS_LOG(ERROR) << "AscendNativeDelegate::CreateKernel cnode is nullptr"; - return nullptr; - } - // step II - Prepare kernel attributes - auto kernel_name = cnode->fullname_with_scope(); - auto stream = ascend_native::CreateStream(); - if (stream == nullptr) { - MS_LOG(ERROR) << "fail to create stream for kernel " << kernel_name; - return nullptr; - } - kernel::InferPrimitive primitive; - primitive.base_operator = spec.primitive; - primitive.cnode = cnode; - auto node_type = primitive.base_operator->name(); - if (node_type == ops::kNameAscendNativeComposite) { - auto kernel = std::make_shared(inputs, outputs, primitive, - ascend_native_ctx_.get(), stream, kernel_name); - int idx = static_cast(GetValue(primitive.base_operator->GetAttr("group"))); - auto func_graph = helper_->GetSbg(idx)->func_graph(); - kernel->set_func_graph(func_graph); - kernel->set_stream(stream); - return kernel; - } else { - // step III - Create Ascend native Kernel - auto &plugin_factory = kernel::AscendNativeRegistrationFactory::Get(); - if (plugin_factory.HasKey(node_type)) { - kernel::AscendNativeBaseKernel *ascend_native_op = - plugin_factory.GetCreator(node_type)(inputs, outputs, primitive, ascend_native_ctx_.get(), stream, node_type); - if (ascend_native_op == nullptr) { - return nullptr; - } - auto ker = std::shared_ptr(ascend_native_op); - if (!ker->IsWeightInputHanledInner()) { - auto in_tensors = ker->in_tensors(); - for (auto &t : in_tensors) { - if (t->IsConst() && t->device_data() != nullptr) { - CopyTensors(t, t, stream); - } - } - } - ker->set_name(kernel_name); - ker->set_stream(stream); - return ker; - } else { - MS_LOG(ERROR) << "Unsupported op type for ascend native. kernel name:" << kernel_name << " type:" << node_type; - return nullptr; - } - } -} - -ExtendDelegate *AscendDelegateCreator(const std::shared_ptr &, const ConfigInfos &) { - return &AscendNativeDelegate::Instance(); -} -REG_DELEGATE(kAscend, kAscendNativeProvider, AscendDelegateCreator); - -} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/delegate.h b/mindspore-lite/src/extendrt/delegate/ascend_native/delegate.h deleted file mode 100644 index 1752e05c9a7ff1c9c7757d3670b13842f363307f..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/delegate.h +++ /dev/null @@ -1,68 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_DELEGATE_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_DELEGATE_H_ - -#include -#include -#include -#include "extendrt/delegate/type.h" -#include "src/extendrt/utils/func_graph_utils.h" -#include "extendrt/delegate/ascend_native/ascend_native_base_kernel.h" -#include "extendrt/delegate/ascend_native/sub_graph_helper.h" -#include "extendrt/delegate/ascend_native/delegate_allocator.h" - -namespace mindspore { -class AscendNativeDelegate : public ExtendDelegate { - public: - static AscendNativeDelegate &Instance() { - static AscendNativeDelegate instance; - return instance; - } - AscendNativeDelegate() = default; - virtual ~AscendNativeDelegate() = default; - - void ReplaceNodes(const std::shared_ptr &graph) override; - - bool IsDelegateNode(const std::shared_ptr &node) override; - - std::shared_ptr CreateKernel(const std::shared_ptr &node) override; - std::shared_ptr CreateKernel(const kernel::KernelSpec &spec, - const std::vector &inputs, - const std::vector &outputs, - const InferContext *ctx) const override; - - void set_ascend_native_ctx(std::shared_ptr ascend_native_ctx) { - this->ascend_native_ctx_ = ascend_native_ctx; - } - - private: - void CreateInputKernelTensors(const CNodePtr &cnode, std::vector *input_tensors, - std::shared_ptr allocator); - void CreateOutputKernelTensors(const CNodePtr &cnode, std::vector *output_tensors, - std::shared_ptr allocator); - bool IsSupport(const CNodePtr &cnode); - void ReplaceSubGraph(const std::shared_ptr &graph, int idx); - std::vector kernel_list_; - std::shared_ptr ascend_native_ctx_ = nullptr; - void DrawGraph(const std::string &file_name, const std::shared_ptr &graph); - void CopyTensors(InferTensor *t_src, InferTensor *t_dst, const void *stream) const; - std::shared_ptr helper_; -}; - -} // namespace mindspore - -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_DELEGATE_H_ diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/ops/ascend_native_composite.h b/mindspore-lite/src/extendrt/delegate/ascend_native/ops/ascend_native_composite.h deleted file mode 100644 index 14bda0454b829cbd33251a0285109d3153d8634b..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/ops/ascend_native_composite.h +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_OPS_ASCEND_NATIVE_COMPOSITE_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_OPS_ASCEND_NATIVE_COMPOSITE_H_ -#include -#include -#include -#include -#include "ops/base_operator.h" - -namespace mindspore { -namespace ops { -constexpr auto kNameAscendNativeComposite = "AscendNativeComposite"; -/// \brief Custom defined user-defined operator prototype. -class MIND_API AscendNativeComposite : public BaseOperator { - public: - MIND_API_BASE_MEMBER(AscendNativeComposite); - /// \brief Constructor. - AscendNativeComposite() : BaseOperator(kNameAscendNativeComposite) {} - - void Init(int64_t group); - - /// \brief Method to set type attribute. - /// - /// \param[in] type Define the concrete type of the custom op, which is used to distinguish different custom op. - void set_group(int64_t group); - int64_t get_group() const; -}; -} // namespace ops -} // namespace mindspore - -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_OPS_ASCEND_NATIVE_COMPOSITE_H_ diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/stub_kernel.h b/mindspore-lite/src/extendrt/delegate/ascend_native/stub_kernel.h deleted file mode 100644 index 896c6e7a729203df12e7af9319db4479d4e2c718..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/stub_kernel.h +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_STUB_KERNEL_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_STUB_KERNEL_H_ - -#include -#include -#include -#include "extendrt/delegate/ascend_native/ascend_native_base_kernel.h" - -namespace mindspore::kernel { -class AscendNativeStubKernel : public AscendNativeBaseKernel { - public: - // AscendNativeStubKernel() = delete; - - AscendNativeStubKernel(const std::vector &inputs, const std::vector &outputs, - InferPrimitive prim, const InferContext *ctx, const void *stream, std::string name) - : AscendNativeBaseKernel(inputs, outputs, prim, ctx, stream, name) {} - - int Prepare() override; - - int Execute() override; - int Run() override { return 0; } -}; -} // namespace mindspore::kernel -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_STUB_KERNEL_H_ diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/sub_graph_helper.cc b/mindspore-lite/src/extendrt/delegate/ascend_native/sub_graph_helper.cc deleted file mode 100644 index af43a023c6e39caa96e48964ad8048453e41b5ec..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/sub_graph_helper.cc +++ /dev/null @@ -1,597 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "extendrt/delegate/ascend_native/sub_graph_helper.h" -#include -#include -#include "extendrt/utils/func_graph_utils.h" -#include "mindspore/ops/op_def/op_name.h" -#include "infer/tuple_get_item.h" -#include "infer/make_tuple.h" -#include "extendrt/delegate/ascend_native/ops/ascend_native_composite.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" - -namespace mindspore { -AnSubGraph::AnSubGraph(int index) : index_{index} { func_graph_ = std::make_shared(); } -void AnSubGraph::Add(const CNodePtr &cnode) { - auto prim = GetValueNode(cnode->input(0)); - if (prim->name() == ops::kNameAscendNativeComposite) { - int idx = static_cast(GetValue(prim->GetAttr(ops::kGroup))); - MS_LOG(ERROR) << "cannot add composite in graph #" << index_ << "from " << idx; - return; - } - cnode->set_func_graph(func_graph_); - func_graph_->AddNode(cnode); -} - -int AnSubGraph::Size() const { return func_graph_->nodes().size(); } - -void AnSubGraph::AddInput(const AnfNodePtr &node) { - if (input_set_.find(node) == input_set_.end()) { - input_set_.insert(node); - inputs_.push_back(node); - } -} - -void AnSubGraph::AddOutput(const AnfNodePtr &node) { - if (output_set_.find(node) == output_set_.end()) { - output_set_.insert(node); - outputs_.push_back(node); - } -} - -int AnSubGraph::GetOutputId(const CNodePtr &cnode) const { - auto it = std::find(outputs_.begin(), outputs_.end(), cnode); - if (it != outputs_.end()) { - return it - outputs_.begin(); - } - return -1; -} - -CNodePtr AnSubGraph::CreateTuple() { - auto tuple_prim = std::make_shared(); - auto tuple_prim_c = tuple_prim->GetPrim(); - CNodePtr tuple_cnode = func_graph_->NewCNode(tuple_prim_c, {outputs_}); - tuple_cnode->set_fullname_with_scope("composite_" + std::to_string(index_) + "/make_tuple"); - return tuple_cnode; -} - -void AnSubGraph::FixGroup(SubGraphHelperPtr helper) { - // handle inputs - std::vector inputs; - for (size_t i = 1; i < cnode_->size(); i++) { - auto input = cnode_->input(i); - auto is_ginput = helper->IsGraphInput(input); - int group = helper->FindSubGraph(input); - if (group >= 0) { - input = helper->GetCNode(group); - cnode_->set_input(i, input); - } else { - CNodePtr cin; - auto prim = GetCNodePrimitive(input); - bool is_copy = (prim != nullptr) && (prim->name() == ops::kNameCopy); - if ((input->isa() && !is_copy) || is_ginput) { - auto connect_node = helper->CreateGetItemAndCopyUnique(input, 0, cin, ops::Copy::CopyFormatType::HOST_DEVICE); - cnode_->set_input(i, connect_node); - } - } - if (input->isa() || is_ginput) { - auto para = std::make_shared(func_graph_); - const std::string name = input->fullname_with_scope() + "/input_" + std::to_string(i - 1); - para->set_name(name); - para->debug_info()->set_name(name); - para->set_abstract(cnode_->input(i)->abstract()); - inputs.push_back(para); - // replace all function node input to user func_graph inputs - for (const auto &node : func_graph_->nodes()) { - if (node->isa()) { - const auto &cnode = node->cast(); - for (size_t j = 1; j < cnode->size(); j++) { - const auto &node_in = cnode->input(j); - if (node_in == input) { - cnode->set_input(j, para); - } - } - } - } - } - } - func_graph_->set_parameters(inputs); - // handle outputs - int outCount = GetOutputsCount(); - if (outCount < 1) { - MS_LOG(ERROR) << "composite don't have outputs"; - } - if (GetOutputsCount() > 1) { - auto tuple = CreateTuple(); - func_graph_->set_output(tuple); - } else { - func_graph_->set_output(outputs_.at(0)); - } -} - -void AnSubGraph::DumpNode(const AnfNodePtr &node) { - std::cout << node->fullname_with_scope() << " "; - if (node->isa()) { - const auto &cnode = node->cast(); - const auto &prim = GetValueNode(cnode->input(0)); - std::cout << prim->name(); - } - std::cout << std::endl; -} - -void AnSubGraph::Dump() { - int count = 0; - std::cout << "graph have " << func_graph_->get_inputs().size() << " inputs" << std::endl; - for (const auto &in : func_graph_->get_inputs()) { - DumpNode(in); - } - std::cout << "graph have " << func_graph_->nodes().size() << " nodes" << std::endl; - for (const auto &node : func_graph_->nodes()) { - std::cout << "node #" << count << std::endl; - DumpNode(node); - if (node->isa()) { - const auto &cnode = node->cast(); - std::cout << "node " << count << " have " << cnode->size() - 1 << " inputs" << std::endl; - for (size_t i = 1; i < cnode->size(); i++) { - const auto &input = cnode->input(i); - DumpNode(input); - } - } - count++; - } -} - -void AnSubGraph::SetAbstract() { - if (outputs_.size() == 1) { - auto node = outputs_.at(0); - if (node->isa()) { - const auto cnode = node->cast(); - cnode_->set_abstract(cnode->abstract()); - } - } else { - AbstractBasePtrList abstract_list; - for (const auto &output : outputs_) { - auto abstract = output->abstract(); - if (abstract == nullptr) { - MS_LOG(ERROR) << "Create tensor abstract for " << output->fullname_with_scope() << " failed"; - return; - } - auto data_type = abstract->BuildType()->type_id(); - while (data_type == TypeId::kObjectTypeTuple) { - MS_LOG(WARNING) << "got tuple as output of " << output->fullname_with_scope() << " in composite #" << index_; - auto tuple_abs = abstract->cast(); - abstract = tuple_abs->elements().at(0); - if (abstract == nullptr) { - MS_LOG(ERROR) << "Create tensor abstract failed in loop for " << output->fullname_with_scope(); - return; - } - data_type = abstract->BuildType()->type_id(); - } - abstract_list.emplace_back(abstract); - } - auto abstract_tuple = std::make_shared(abstract_list); - if (abstract_tuple == nullptr) { - MS_LOG(ERROR) << "create abstract_tuple failed"; - return; - } - cnode_->set_abstract(abstract_tuple); - } -} - -void SubGraphHelper::FixGroups() { - for (const auto &sbg : sg_v_) { - sbg->FixGroup(shared_from_this()); - } -} - -int SubGraphHelper::CheckAllInputInSameSg(const CNodePtr &cnode) { - int prev_id = -1; - int subg_id = sg_v_.size() - 1; - for (size_t i = 1; i < cnode->size(); i++) { - auto &input = cnode->input(i); - if (!input->isa()) { - continue; - } - subg_id = FindSubGraph(input); - if ((subg_id == -1) || ((prev_id != -1) && (subg_id != prev_id))) { - break; - } - prev_id = subg_id; - } - if ((prev_id == -1) || (subg_id == prev_id)) return subg_id; - return -1; -} - -int SubGraphHelper::GetOutputsCount(int group) { - auto &subg = sg_v_.at(group); - return subg->GetOutputsCount(); -} - -void SubGraphHelper::AddToSubGraph(int index, const CNodePtr &node, bool update) { - auto &subg = sg_v_.at(index); - subg->Add(node); - // add as inputs all non cnode and cnodes not in subgraph - if (update) { - for (size_t i = 1; i < node->size(); i++) { - const auto &input = node->input(i); - if (!input->isa()) { - subg->AddInput(input); - } else { - auto cnode = input->cast(); - int group = FindSubGraph(cnode); - if (group < 0) { - subg->AddInput(input); - } else { - // if input not in the same of input group, its an output of group - if (group != index) { - subg->AddInput(cnode); - } - } - } - } - map_[node] = index; - } -} - -void SubGraphHelper::AddSubGraph(const CNodePtr &node) { - auto index = sg_v_.size(); - auto subg = std::make_shared(index); - sg_v_.push_back(subg); - AddToSubGraph(index, node); -} - -int SubGraphHelper::FindSubGraph(const AnfNodePtr &node) const { - auto const &cnode = node->cast(); - auto it = map_.find(cnode); - if (it != map_.end()) { - return it->second; - } - return -1; -} - -void SubGraphHelper::SetCNode(int idx, const CNodePtr &cnode) { - auto sbg = GetSbg(idx); - sbg->set_cnode(cnode); -} - -const CNodePtr &SubGraphHelper::GetCNode(int idx) const { - auto sbg = GetSbg(idx); - return sbg->cnode(); -} - -void SubGraphHelper::FixOutput() { - const auto &output = func_graph_->output(); - if (output->isa()) { - int group = FindSubGraph(output); - if (group >= 0) { - func_graph_->set_output(GetCNode(group)); - } - } -} - -CNodePtr SubGraphHelper::CreateGetItem(const AnfNodePtr &node, int id, const CNodePtr &input) { - auto tuple_get_item_prim = std::make_shared(); - auto get_item_value = NewValueNode(MakeValue(id)); - if (tuple_get_item_prim == nullptr || get_item_value == nullptr) { - MS_LOG(ERROR) << "NewValueNode is nullptr"; - return nullptr; - } - auto tuple_get_item_prim_c = tuple_get_item_prim->GetPrim(); - MS_ASSERT(tuple_get_item_prim_c != nullptr); - CNodePtr get_item_cnode = func_graph_->NewCNode(tuple_get_item_prim_c, {node, get_item_value}); - if (get_item_cnode == nullptr) { - MS_LOG(ERROR) << "cannot create a new node for value " << id; - return nullptr; - } - get_item_cnode->set_abstract(input->abstract()); - get_item_cnode->set_fullname_with_scope(input->fullname_with_scope() + "/output_getitem_" + std::to_string(id)); - return get_item_cnode; -} - -int SubGraphHelper::GetOutputId(int group, const CNodePtr &input) const { - const auto subg = GetSbg(group); - return subg->GetOutputId(input); -} - -CNodePtr SubGraphHelper::CreateCopyNode(const AnfNodePtr &input, ops::Copy::CopyFormatType type) { - auto copy_prim = std::make_shared(); - if (copy_prim == nullptr) { - MS_LOG(ERROR) << "NewValueNode is nullptr"; - return nullptr; - } - copy_prim->set_copy_format(type); - auto copy_prim_c = copy_prim->GetPrim(); - MS_ASSERT(copy_prim_c != nullptr); - CNodePtr copy_cnode = func_graph_->NewCNode(copy_prim_c, {input}); - if (copy_cnode == nullptr) { - MS_LOG(ERROR) << "cannot create copy node "; - return nullptr; - } - copy_cnode->set_abstract(input->abstract()); - copy_cnode->set_fullname_with_scope(input->fullname_with_scope() + "/copy"); - return copy_cnode; -} - -bool SubGraphHelper::IsGraphInput(const AnfNodePtr &node) const { - const auto &inputs = func_graph_->get_inputs(); - auto it = std::find(inputs.begin(), inputs.end(), node); - if (it != inputs.end()) { - return true; - } - return false; -} - -void SubGraphHelper::SetOutputsAndAbstract(const AnfNodePtrList &nodes) { - for (const auto &node : nodes) { - if (node->isa()) { - auto cnode = node->cast(); - int group = FindSubGraph(cnode); - for (const auto &input : cnode->inputs()) { - if (input->isa()) { - auto cinput = input->cast(); - int in_group = FindSubGraph(cinput); - if ((in_group >= 0) && (in_group != group)) { - AddSubGraphOutput(in_group, cinput); - } - } - } - } - } - for (const auto &sbg : sg_v_) { - sbg->SetAbstract(); - } -} - -AnfNodePtr SubGraphHelper::CreateGetItemAndCopyUnique(const AnfNodePtr &node, int id, const CNodePtr &cinput, - ops::Copy::CopyFormatType type) { - auto pair = std::make_pair(id, node); - auto connect_node = node; - if (connection_map_.find(pair) != connection_map_.end()) { - connect_node = connection_map_.at(pair); - } else { - if (cinput != nullptr) { - auto get_item = CreateGetItem(connect_node, id, cinput); - if (get_item == nullptr) { - MS_LOG(ERROR) << "could not create get_item"; - return nullptr; - } - connect_node = get_item; - } - if (type != ops::Copy::CopyFormatType::NONE) { - auto copy_node = CreateCopyNode(connect_node, type); - if (copy_node == nullptr) { - MS_LOG(ERROR) << "could not create copy_node"; - return nullptr; - } - connect_node = copy_node; - } - (connection_map_)[pair] = connect_node; - } - return connect_node; -} - -void SubGraphHelper::UpdateInput(const CNodePtr &cnode, int index, const AnfNodePtr &input) const { - int group = FindSubGraph(cnode); - if (group >= 0) { - auto cnode_group = GetCNode(group); - auto prev_input = cnode->input(index); - for (size_t i = 1; i < cnode_group->size(); i++) { - if (cnode_group->input(i) == prev_input) { - // update group input - cnode_group->set_input(i, input); - break; - } - } - } - cnode->set_input(index, input); -} - -void SubGraphHelper::FixAllNodes(const AnfNodePtrList &nodes) { - // set up outputs - SetOutputsAndAbstract(nodes); - for (const auto &node : nodes) { - if (node->isa()) { - int cnode_group = FindSubGraph(node); - auto cnode = node->cast(); - for (size_t i = 1; i < cnode->size(); i++) { - auto const &input = cnode->input(i); - if (input->isa()) { - auto cinput = input->cast(); - int in_group = FindSubGraph(input); - ops::Copy::CopyFormatType oper = ops::Copy::CopyFormatType::NONE; - if (cnode_group < 0) { - oper = ops::Copy::CopyFormatType::DEVICE_HOST; - } - if ((in_group >= 0) && (in_group != cnode_group)) { - auto in_cnode = GetCNode(in_group); - if (GetOutputsCount(in_group) > 1) { - int id = GetOutputId(in_group, cinput); - if (id < 0) { - MS_LOG(ERROR) << "cannot find input " << input->fullname_with_scope() << " in group " << in_group - << "output list"; - return; - } - auto connect_node = CreateGetItemAndCopyUnique(in_cnode, id, cinput, oper); - if (connect_node == nullptr) { - MS_LOG(ERROR) << "could not create nodes"; - return; - } - UpdateInput(cnode, i, connect_node); - } else { - CNodePtr in; - auto connect_node = CreateGetItemAndCopyUnique(in_cnode, 0, in, oper); - UpdateInput(cnode, i, connect_node); - } - } - } - } - } - } - FixOutput(); - FixGroups(); -} - -void SubGraphHelper::DrawConnction(const AnfNodePtr &in_node, bool src_composite, int src_idx, const AnfNodePtr &node, - bool dst_composite, int dst_idx, std::ostream &out) const { - constexpr std::string_view quote{"\""}; - if (!src_composite && !dst_composite) { - out << quote << in_node->fullname_with_scope() << quote << "->" << quote << node->fullname_with_scope() << quote - << std::endl; - } else if (src_composite && !dst_composite) { - auto src_name = sg_v_[src_idx]->func_graph()->nodes().front()->fullname_with_scope(); - out << quote << src_name << quote << "->" << quote << node->fullname_with_scope() << quote << "[ltail=cluster_" - << src_idx << "]" << std::endl; - } else if (src_composite && dst_composite) { - auto src_name = sg_v_[src_idx]->func_graph()->nodes().front()->fullname_with_scope(); - auto dst_name = sg_v_[dst_idx]->func_graph()->nodes().front()->fullname_with_scope(); - out << quote << src_name << quote << "->" << quote << dst_name << quote << "[ltail=cluster_" << src_idx - << " lhead=cluster_" << dst_idx << "]" << std::endl; - } else { - auto dst_name = sg_v_[dst_idx]->func_graph()->nodes().front()->fullname_with_scope(); - out << quote << in_node->fullname_with_scope() << quote << "->" << quote << dst_name << quote << "[lhead=cluster_" - << dst_idx << "]" << std::endl; - } -} - -void SubGraphHelper::DrawGraph(const FuncGraphPtr &graph, std::ostream &out, bool recursive) const { - constexpr std::string_view quote{"\""}; - auto is_composite = [](const AnfNodePtr &node, int *idx) { - if (node->isa()) { - auto prim = GetCNodePrimitive(node); - if (prim->name() == ops::kNameAscendNativeComposite) { - *idx = static_cast(GetValue(prim->GetAttr(ops::kGroup))); - return true; - } - } - return false; - }; - auto nodes = TopoSort(graph->get_return()); - for (const auto &node : nodes) { - if (node->isa()) { - auto prim = GetCNodePrimitive(node); - std::string node_name = prim->name(); - int idx; - std::string color; - if (node_name == ops::kNameCopy) { - auto value = prim->GetAttr(ops::kCopyFormat); - if (value == nullptr) { - MS_LOG(ERROR) << "value returned null"; - return; - } - auto type = static_cast(GetValue(value)); - switch (type) { - case ops::Copy::CopyFormatType::DEVICE_HOST: - color = "red"; - break; - case ops::Copy::CopyFormatType::HOST_DEVICE: - color = "green"; - break; - default: - color = ""; - break; - } - } - if (!is_composite(node, &idx)) { - out << quote << node->fullname_with_scope() << quote << "[label=" << quote << node_name << quote; - if (!color.empty()) { - out << " color=" << color; - } - out << " ]" << std::endl; - } - } - } - for (const auto &node : nodes) { - int dst_idx; - if (node->isa()) { - bool dst_composite = false; - if (is_composite(node, &dst_idx) && recursive) { - out << "subgraph cluster_" << dst_idx << " {" << std::endl; - out << "label=\"composite #" << dst_idx << "\"" << std::endl; - out << "style=rounded" << std::endl; - DrawGraph(sg_v_[dst_idx]->func_graph(), out, recursive); - out << "}" << std::endl; - dst_composite = true; - } - auto cnode = node->cast(); - for (size_t i = 1; i < cnode->size(); i++) { - auto &in_node = cnode->input(i); - if (in_node->isa() || IsGraphInput(in_node)) { - bool src_composite = false; - int src_idx; - if (is_composite(in_node, &src_idx)) { - src_composite = true; - } - DrawConnction(in_node, src_composite, src_idx, node, dst_composite, dst_idx, out); - } - } - } - } -} - -void SubGraphHelper::DrawGraph(const std::string &file_name, const FuncGraphPtr &graph, bool recursive) const { - std::ofstream out(file_name); - out << "digraph ascend {" << std::endl; - out << "compound=true" << std::endl; - for (const auto &in : func_graph_->get_inputs()) { - out << "\"" << in->fullname_with_scope() << "\"[shape=box]" << std::endl; - } - DrawGraph(graph, out, recursive); - out << "}\n"; - out.close(); -} - -void SubGraphHelper::DumpNode(std::ofstream &out, const AnfNodePtr &node) const { - out << node->fullname_with_scope() << " typeid=" << node->tid() << " "; - if (node->isa()) { - const auto &cnode = node->cast(); - const auto &prim = GetValueNode(cnode->input(0)); - out << prim->name(); - } - if (node->isa()) { - out << "node is valueNode"; - } - - out << std::endl; -} - -void SubGraphHelper::Dump(std::string file_name) const { - std::ofstream out(file_name); - int count = 0; - out << "graph have " << func_graph_->get_inputs().size() << " inputs" << std::endl; - for (const auto &in : func_graph_->get_inputs()) { - DumpNode(out, in); - } - auto nodes = TopoSort(func_graph_->get_return()); - out << "graph have " << nodes.size() << " nodes" << std::endl; - for (const auto &node : nodes) { - out << "node #" << count << std::endl; - DumpNode(out, node); - if (node->isa()) { - const auto &cnode = node->cast(); - out << "node " << count << " have " << cnode->size() - 1 << " inputs" << std::endl; - for (size_t i = 1; i < cnode->size(); i++) { - const auto &input = cnode->input(i); - DumpNode(out, input); - } - } - count++; - } - out.close(); -} - -} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/delegate/ascend_native/sub_graph_helper.h b/mindspore-lite/src/extendrt/delegate/ascend_native/sub_graph_helper.h deleted file mode 100644 index ea9ec5ff6cd0edcbb4f8a933d5087ae1d1d78665..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/ascend_native/sub_graph_helper.h +++ /dev/null @@ -1,108 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_SUB_GRAPH_HELPER_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_SUB_GRAPH_HELPER_H_ - -#include -#include -#include -#include -#include -#include -#include -#include "ir/anf.h" -#include "mindapi/base/type_id.h" -#include "extendrt/delegate/ops/copy.h" - -namespace mindspore { -class SubGraphHelper; -using SubGraphHelperPtr = std::shared_ptr; - -class AnSubGraph { - public: - AnSubGraph() = delete; - explicit AnSubGraph(int index); - void Add(const CNodePtr &cnode); - int Size() const; - void AddInput(const AnfNodePtr &node); - void AddOutput(const AnfNodePtr &node); - const std::vector &inputs() const { return inputs_; } - void set_cnode(const CNodePtr &cnode) { cnode_ = cnode; } - const CNodePtr &cnode() const { return cnode_; } - int GetOutputId(const CNodePtr &cnode) const; - int GetOutputsCount() const { return outputs_.size(); } - void FixGroup(SubGraphHelperPtr helper); - void Dump(); - const FuncGraphPtr &func_graph() { return func_graph_; } - void SetAbstract(); - - private: - int index_{0}; - CNodePtr CreateTuple(); - void DumpNode(const AnfNodePtr &node); - FuncGraphPtr func_graph_; - CNodePtr cnode_; - std::vector inputs_; - std::unordered_set input_set_; - std::vector outputs_; - std::unordered_set output_set_; -}; - -class SubGraphHelper : public std::enable_shared_from_this { - using SubGraphPtr = std::shared_ptr; - - public: - SubGraphHelper() = delete; - explicit SubGraphHelper(const FuncGraphPtr &graph) : func_graph_{graph} {} - int CheckAllInputInSameSg(const CNodePtr &cnode); - int FindSubGraph(const AnfNodePtr &node) const; - void AddSubGraph(const CNodePtr &node); - void AddToSubGraph(int index, const CNodePtr &node, bool update = true); - int SubGroupNum() const { return sg_v_.size(); } - const std::vector &GetSbgInputs(int idx) { return GetSbg(idx)->inputs(); } - SubGraphPtr &CreateSubGraph(); - const CNodePtr &GetCNode(int idx) const; - void DrawGraph(const std::string &file_name, const FuncGraphPtr &graph, bool recursive = false) const; - void SetCNode(int idx, const CNodePtr &cnode); - void FixAllNodes(const AnfNodePtrList &nodes); - const SubGraphPtr &GetSbg(int i) const { return sg_v_[i]; } - AnfNodePtr CreateGetItemAndCopyUnique(const AnfNodePtr &node, int id, const CNodePtr &cinput, - ops::Copy::CopyFormatType type); - bool IsGraphInput(const AnfNodePtr &node) const; - void Dump(std::string file_name) const; - - private: - int GetOutputsCount(int group); - int GetOutputId(int group, const CNodePtr &input) const; - void AddSubGraphOutput(int group, const CNodePtr &cnode) { GetSbg(group)->AddOutput(cnode); } - CNodePtr CreateGetItem(const AnfNodePtr &node, int id, const CNodePtr &input); - CNodePtr CreateCopyNode(const AnfNodePtr &input, ops::Copy::CopyFormatType type); - void SetOutputsAndAbstract(const AnfNodePtrList &nodes); - void UpdateInput(const CNodePtr &cnode, int index, const AnfNodePtr &input) const; - void FixOutput(); - void FixGroups(); - void DrawGraph(const FuncGraphPtr &graph, std::ostream &out, bool recursive) const; - void DrawConnction(const AnfNodePtr &in_node, bool src_composite, int src_idx, const AnfNodePtr &node, - bool dst_composite, int dst_idx, std::ostream &out) const; - void DumpNode(std::ofstream &out, const AnfNodePtr &node) const; - std::vector sg_v_; - std::unordered_map map_; - std::map, AnfNodePtr> connection_map_; // ->node - const FuncGraphPtr &func_graph_; -}; -}; // namespace mindspore -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_ASCEND_NATIVE_SUB_GRAPH_HELPER_H_ diff --git a/mindspore-lite/src/extendrt/delegate/comm_group_info.cc b/mindspore-lite/src/extendrt/delegate/comm_group_info.cc index fbdf67223a61cf3ace18b7600956a68a7ab92dd5..2468ca3d8b66d403649263d520f331936912c060 100644 --- a/mindspore-lite/src/extendrt/delegate/comm_group_info.cc +++ b/mindspore-lite/src/extendrt/delegate/comm_group_info.cc @@ -26,12 +26,46 @@ #include "include/common/debug/common.h" #include "utils/file_utils.h" namespace mindspore::lite { +std::optional CommGroupInfo::CreatePrefixPath(const std::string &input_path, bool support_relative_path) { + std::optional prefix_path; + std::optional file_name; + FileUtils::SplitDirAndFileName(input_path, &prefix_path, &file_name); + if (!file_name.has_value()) { + MS_LOG(ERROR) << "Cannot get file_name from: " << input_path; + return std::nullopt; + } + auto file_name_str = file_name.value(); +#if defined(SYSTEM_ENV_POSIX) + if (file_name_str.length() > NAME_MAX) { +MS_LOG(ERROR) << "The length of file name: " << file_name_str.length() << " exceeds limit: " << NAME_MAX; +return std::nullopt; +} +#endif + + std::string prefix_path_str; + if (prefix_path.has_value()) { + auto create_prefix_path = FileUtils::CreateNotExistDirs(prefix_path.value(), support_relative_path); + if (!create_prefix_path.has_value()) { + return std::nullopt; + } + prefix_path_str = create_prefix_path.value(); + } else { + auto pwd_path = FileUtils::GetRealPath("./"); + if (!pwd_path.has_value()) { + MS_LOG(ERROR) << "Cannot get pwd path"; + return std::nullopt; + } + prefix_path_str = pwd_path.value(); + } + return std::string(prefix_path_str + "/" + file_name_str); +} + bool CommGroupInfo::CheckPath(const std::string path) const { if (path.size() > PATH_MAX) { MS_LOG(ERROR) << "The checkpoit path " << path << " is too long"; return false; } - auto realpath = Common::CreatePrefixPath(path, true); + auto realpath = CreatePrefixPath(path, true); if (!realpath.has_value()) { MS_LOG(ERROR) << "Get real path failed, path=" << path; return false; diff --git a/mindspore-lite/src/extendrt/delegate/comm_group_info.h b/mindspore-lite/src/extendrt/delegate/comm_group_info.h index d0bf5bab7f5c9a3c2f719101dcdc6c54b83e4175..68716442556bef13fa4dff3e0c4d1eace4ade805 100644 --- a/mindspore-lite/src/extendrt/delegate/comm_group_info.h +++ b/mindspore-lite/src/extendrt/delegate/comm_group_info.h @@ -32,6 +32,7 @@ class CommGroupInfo { CommGroupInfo() {} ~CommGroupInfo() = default; bool LoadGroupInfo(const std::string &file, GroupInfoMap *group_info_map) const; + static std::optional CreatePrefixPath(const std::string &input_path, bool support_relative_path); private: bool CheckPointExit(const std::string path) const; diff --git a/mindspore-lite/src/extendrt/delegate/delegate_utils.h b/mindspore-lite/src/extendrt/delegate/delegate_utils.h deleted file mode 100644 index bbd54131ef10bdeeb6762745c2d775adc33f9a9d..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/delegate_utils.h +++ /dev/null @@ -1,120 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_DELEGATE_UTILS_H_ -#define MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_DELEGATE_UTILS_H_ -#include -#include -#include -#include "src/common/log_adapter.h" -#include "include/errorcode.h" -#include "base/base.h" -#include "src/extendrt/delegate/tensorrt/tensor_info.h" - -namespace mindspore::lite { -bool IsSubGraphInputTensor(const std::vector &inputs, const TensorInfo &input); - -template -std::vector FindPreOps(T *cur_op, std::vector all_ops) { - std::vector in_ops; - for (auto in_tensor : cur_op->inputs()) { - for (auto op : all_ops) { - if (std::find(op->outputs().begin(), op->outputs().end(), in_tensor) != op->outputs().end()) { - in_ops.push_back(op); - } - } - } - return in_ops; -} - -template -std::vector FindNextOps(T *cur_op, std::vector all_ops) { - std::vector out_ops; - for (auto out_tensor : cur_op->outputs()) { - for (auto op : all_ops) { - if (std::find(op->inputs().begin(), op->inputs().end(), out_tensor) != op->inputs().end()) { - out_ops.push_back(op); - } - } - } - return out_ops; -} - -template -void FindPreNextOps(std::vector all_ops) { - std::map> in_tensor_op; - std::map> out_tensor_op; - for (auto op : all_ops) { - for (auto in_tensor : op->inputs()) { - in_tensor_op[in_tensor].insert(op); - } - for (auto out_tensor : op->outputs()) { - out_tensor_op[out_tensor].insert(op); - } - } - for (auto op : all_ops) { - std::set in_ops_set; - for (auto in_tensor : op->inputs()) { - auto in_ops = out_tensor_op[in_tensor]; - in_ops_set.insert(in_ops.begin(), in_ops.end()); - } - std::vector in_ops_vec; - in_ops_vec.assign(in_ops_set.begin(), in_ops_set.end()); - op->set_in_ops(in_ops_vec); - - std::set out_ops_set; - for (auto out_tensor : op->outputs()) { - auto out_ops = in_tensor_op[out_tensor]; - out_ops_set.insert(out_ops.begin(), out_ops.end()); - } - std::vector out_ops_vec; - out_ops_vec.assign(out_ops_set.begin(), out_ops_set.end()); - op->set_out_ops(out_ops_vec); - } -} - -template -int GetGraphInOutOps(const std::vector &inputs, const std::vector &outputs, - std::vector *in_ops, std::vector *out_ops, const std::vector &all_ops) { - for (auto in_tensor : inputs) { - for (auto op : all_ops) { - if (std::find(op->inputs().begin(), op->inputs().end(), in_tensor) != op->inputs().end() && - std::find(in_ops->begin(), in_ops->end(), op) == in_ops->end()) { - in_ops->push_back(op); - } - } - } - if (in_ops->empty()) { - MS_LOG(ERROR) << "Can't find the input ops for npu sub graph."; - return RET_ERROR; - } - - for (auto out_tensor : outputs) { - for (auto op : all_ops) { - if (std::find(op->outputs().begin(), op->outputs().end(), out_tensor) != op->outputs().end() && - std::find(out_ops->begin(), out_ops->end(), op) == out_ops->end()) { - out_ops->push_back(op); - } - } - } - if (out_ops->empty()) { - MS_LOG(ERROR) << "Can't find the output ops for npu sub graph."; - return RET_ERROR; - } - return RET_OK; -} -} // namespace mindspore::lite - -#endif // MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_DELEGATE_UTILS_H_ diff --git a/mindspore-lite/src/extendrt/delegate/factory.cc b/mindspore-lite/src/extendrt/delegate/factory.cc index 822e6b2501db3500181a919bec35ada27ceeef6e..fb5b378096dcee617412e0d8bb014535f6db1c84 100644 --- a/mindspore-lite/src/extendrt/delegate/factory.cc +++ b/mindspore-lite/src/extendrt/delegate/factory.cc @@ -15,10 +15,8 @@ */ #include "src/extendrt/delegate/factory.h" -#include "src/extendrt/delegate/type.h" namespace mindspore { -using mindspore::ExtendDelegate; template DelegateRegistry &DelegateRegistry::GetInstance() { @@ -64,7 +62,6 @@ T DelegateRegistry::GetDelegate(const mindspore::DeviceType &device_type, con return (*(creator_it->second))(ctx, config_infos); } -template class DelegateRegistry; template class DelegateRegistry>; } // namespace mindspore diff --git a/mindspore-lite/src/extendrt/delegate/factory.h b/mindspore-lite/src/extendrt/delegate/factory.h index 6c9b24f04277389d4bebf3c77f5db1dfb119f0df..2505ecfc71edd930cbb3b95ce4c26b58be9f9adb 100644 --- a/mindspore-lite/src/extendrt/delegate/factory.h +++ b/mindspore-lite/src/extendrt/delegate/factory.h @@ -25,12 +25,8 @@ #include "src/extendrt/delegate_graph_executor.h" #include "include/api/context.h" #include "src/common/config_infos.h" -#include "extendrt/session/lite_graph_executor.h" namespace mindspore { -using mindspore::LiteGraphExecutor; -// (zhaizhiqiang): Wrap graph executor as delegate. -// typedef std::shared_ptr (*DelegateCreator)(const std::shared_ptr &); template using DelegateCreator = std::function &, const ConfigInfos &)>; diff --git a/mindspore-lite/src/extendrt/delegate/graph_executor/litert/CMakeLists.txt b/mindspore-lite/src/extendrt/delegate/graph_executor/litert/CMakeLists.txt index d28fd55a5cb63d168d263574a628f3c1c96c0285..0a4e6243ec991e3d5ed7f6aa71d68b0b013afe11 100644 --- a/mindspore-lite/src/extendrt/delegate/graph_executor/litert/CMakeLists.txt +++ b/mindspore-lite/src/extendrt/delegate/graph_executor/litert/CMakeLists.txt @@ -1,5 +1,4 @@ if(NOT MSLITE_ENABLE_RUNTIME_CONVERT) - set(API_SRC ${API_SRC} ${CORE_DIR}/utils/status.cc) set(LITE_SRC ${LITE_SRC} ${LITE_DIR}/src/common/config_file.cc) endif() @@ -25,7 +24,6 @@ set(LITE_SRC ${LITE_DIR}/src/litert/infer_manager.cc ${LITE_DIR}/src/litert/runtime_shape_fusion_pass.cc ${LITE_DIR}/src/litert/runtime_pass.cc - # ${LITE_DIR}/src/litert/pass/runtime_ncx_pass.cc ${LITE_DIR}/src/litert/schema_tensor_wrapper.cc ${LITE_DIR}/src/tensor.cc ${LITE_DIR}/src/tensorlist.cc @@ -108,14 +106,6 @@ if(MSLITE_ENABLE_EXPERIMENTAL_KERNEL) set(LITE_SRC ${LITE_SRC} ${EXPERIMENT_SRC}) endif() -if(MSLITE_ENABLE_RUNTIME_GLOG) - if(NOT MSLITE_ENABLE_RUNTIME_CONVERT AND NOT MSLITE_ENABLE_KERNEL_EXECUTOR - AND NOT (MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)) - set(LITE_SRC ${LITE_SRC} - ${CORE_DIR}/utils/log_adapter.cc) - endif() -endif() - if(MSLITE_ENABLE_RUNTIME_CONVERT) file(GLOB RUNTIME_CONVERT_SRC ${LITE_DIR}/src/common/ops/ops_def.cc @@ -207,20 +197,6 @@ if(MSLITE_ENABLE_MINDRT) ) set(LITE_SRC ${LITE_SRC} ${CONTROL_FLOW_ACTOR_SRC}) endif() -else() - set(LITE_SRC ${LITE_SRC} - ${CORE_DIR}/mindrt/src/thread/core_affinity.cc - ${CORE_DIR}/mindrt/src/thread/threadpool.cc - ) -endif() - -if(MSLITE_ENABLE_GRAPH_KERNEL) - file(GLOB_RECURSE GRAPH_KERNEL_SRC - ${TOOLS_DIR}/graph_kernel/common/*.cc - ${TOOLS_DIR}/graph_kernel/runtime/*.cc - ${OPS_DIR}/kernel/cpu/akg/akg_kernel_loader.cc - ) - set(LITE_SRC ${LITE_SRC} ${GRAPH_KERNEL_SRC}) endif() if(NOT MSLITE_ENABLE_COREML) @@ -245,7 +221,7 @@ add_dependencies(msplugin_infer_lite_src_mid fbs_src fbs_inner_src) add_dependencies(msplugin_infer_lite_src_mid lite_src_common_mid) add_library(msplugin-ge-litert SHARED $) -target_link_libraries(msplugin-ge-litert lite_src_common_mid) +target_link_libraries(msplugin-ge-litert lite_src_common_mid mindspore_core) add_dependencies(msplugin-ge-litert mindspore_converter) target_link_libraries(msplugin-ge-litert mindspore_converter) @@ -271,7 +247,8 @@ if(PLATFORM_ARM32 OR PLATFORM_ARM64 AND NOT TARGET_HIMIX AND NOT TARGET_MIX210 AND NOT TARGET_OHOS_LITE AND NOT MACHINE_LINUX_ARM64) target_link_libraries(msplugin-ge-litert log) endif() -if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "lite") +if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "lite" AND NOT + (MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)) target_link_libraries(msplugin-ge-litert minddata_eager_mid minddata-lite) endif() @@ -288,7 +265,7 @@ endif() if(MSLITE_ENABLE_RUNTIME_CONVERT) target_link_libraries(msplugin-ge-litert quantizer_mid fusion_mid proto_mid graph_pass_mid preprocess_mid - cpu_kernel_mid ccsrc_src_mid converter_src_mid lite_exporter_mid + cpu_kernel_mid converter_src_mid lite_exporter_mid config_parser_mid mslite_converter_plugin mindspore_core mindspore_ops coder_mid mindir_serializer_mid mindspore::protobuf ${SECUREC_LIBRARY}) target_link_libraries(msplugin-ge-litert diff --git a/mindspore-lite/src/extendrt/delegate/graph_executor/litert/func_graph_reuse_manager.cc b/mindspore-lite/src/extendrt/delegate/graph_executor/litert/func_graph_reuse_manager.cc index 9059f20e6de5f246f2c99f66bb1235a4d96eb233..4c85f7524dae6a621768e1f5f0868aa0af702b9b 100644 --- a/mindspore-lite/src/extendrt/delegate/graph_executor/litert/func_graph_reuse_manager.cc +++ b/mindspore-lite/src/extendrt/delegate/graph_executor/litert/func_graph_reuse_manager.cc @@ -103,9 +103,8 @@ Status FuncGraphReuseManager::StoreFbModelBuf(void *model_buf, size_t data_size, } Status FuncGraphReuseManager::GetInOut(std::map> config_info, - std::vector *in_tensor, - std::vector *out_tensor, std::vector *in_name, - std::vector *out_name) { + std::vector *in_tensor, std::vector *out_tensor, + std::vector *in_name, std::vector *out_name) { std::unique_lock l(mtx_manager_); auto id = config_info.find(mindspore::lite::kInnerModelParallelRunnerSection); if (id != config_info.end()) { @@ -131,9 +130,8 @@ Status FuncGraphReuseManager::GetInOut(std::map> config_info, - std::vector in_tensor, - std::vector out_tensor, std::vector in_name, - std::vector out_name) { + std::vector in_tensor, std::vector out_tensor, + std::vector in_name, std::vector out_name) { std::unique_lock l(mtx_manager_); auto id = config_info.find(lite::kInnerModelParallelRunnerSection); if (id != config_info.end()) { diff --git a/mindspore-lite/src/extendrt/delegate/graph_executor/litert/func_graph_reuse_manager.h b/mindspore-lite/src/extendrt/delegate/graph_executor/litert/func_graph_reuse_manager.h index 207d6d3de019d6e53f1d91a29284a42e5b9b968b..49245b1c3bd8d006b3223c88b29394103d183522 100644 --- a/mindspore-lite/src/extendrt/delegate/graph_executor/litert/func_graph_reuse_manager.h +++ b/mindspore-lite/src/extendrt/delegate/graph_executor/litert/func_graph_reuse_manager.h @@ -46,10 +46,10 @@ class FuncGraphReuseManager { std::map> config_info); Status GetInOut(std::map> config_info, - std::vector *in_tensor, std::vector *out_tensor, + std::vector *in_tensor, std::vector *out_tensor, std::vector *in_name, std::vector *out_name); Status StoreInOut(std::map> config_info, - std::vector in_tensor, std::vector out_tensor, + std::vector in_tensor, std::vector out_tensor, std::vector in_name, std::vector out_name); void ReleaseSharedFuncGraph(std::map> config_info); @@ -63,8 +63,8 @@ class FuncGraphReuseManager { std::unordered_map all_func_graphs_; std::unordered_map all_fb_model_buf_; std::unordered_map> all_infer_helpers_; - std::unordered_map> all_in_tensors_; - std::unordered_map> all_out_tensors_; + std::unordered_map> all_in_tensors_; + std::unordered_map> all_out_tensors_; std::unordered_map> all_in_names_; std::unordered_map> all_out_names_; }; diff --git a/mindspore-lite/src/extendrt/delegate/graph_executor/litert/graph_executor.cc b/mindspore-lite/src/extendrt/delegate/graph_executor/litert/graph_executor.cc index aea764267619d3573afab710c1046457e11d4ee7..02a3fb969f6e98f53e7f59cd5c9caf88bf7eaf3e 100644 --- a/mindspore-lite/src/extendrt/delegate/graph_executor/litert/graph_executor.cc +++ b/mindspore-lite/src/extendrt/delegate/graph_executor/litert/graph_executor.cc @@ -158,8 +158,7 @@ bool LiteRTGraphExecutor::CompileGraph(const FuncGraphPtr &graph, const std::map return true; } -bool LiteRTGraphExecutor::RunGraph(uint32_t, const std::vector &inputs, - std::vector *outputs, +bool LiteRTGraphExecutor::RunGraph(uint32_t, const std::vector &inputs, std::vector *outputs, const std::map &compile_options) { MS_LOG(INFO) << "LiteRTGraphExecutor::RunGraph with input and outputs"; MS_EXCEPTION_IF_NULL(outputs); @@ -177,44 +176,36 @@ bool LiteRTGraphExecutor::RunGraph(uint32_t, const std::vector & for (size_t i = 0; i < inputs.size(); i++) { auto input = input_tensors.at(i); auto &user_input = inputs.at(i); - if (user_input.data_type() != input->data_type()) { + if (static_cast(user_input.DataType()) != input->data_type()) { ResetTensorData(old_data, input_tensors); - MS_LOG(EXCEPTION) << "Tensor " << user_input.id() << " has a different data type from input" - << input->tensor_name() << "."; } - if (user_input.data_c() == nullptr) { + if (user_input.Data().get() == nullptr) { ResetTensorData(old_data, input_tensors); - MS_LOG(EXCEPTION) << "Tensor " << user_input.id() << " has no data."; } old_data.push_back(input->data()); if (input->data_type() == kObjectTypeString) { #ifndef STRING_KERNEL_CLIP - std::vector shape = - TruncateShape(user_input.shape_c(), input->data_type(), user_input.DataSize(), false); - if (shape.empty() && !(user_input.shape_c().empty())) { + std::vector shape = TruncateShape(user_input.Shape(), input->data_type(), user_input.DataSize(), false); + if (shape.empty() && !(user_input.Shape().empty())) { ResetTensorData(old_data, input_tensors); - MS_LOG(EXCEPTION) << "Input dims of tensor " << user_input.id() << " is invalid."; } input->set_shape(shape); - input->set_data(user_input.data_c(), false); + input->set_data(const_cast(user_input.Data().get()), false); #else MS_LOG(ERROR) << unsupport_string_tensor_log; return kLiteError; #endif } else { - if (user_input.data_c() != input->data()) { - if (input->Size() != user_input.Size()) { + if (user_input.Data().get() != input->data()) { + if (input->Size() != user_input.DataSize()) { ResetTensorData(old_data, input_tensors); #ifndef ENABLE_LITE_ACL - MS_LOG(EXCEPTION) << "Tensor " << user_input.id() << " has wrong data size."; #else - MS_LOG(WARNING) << "Please check tensor " << user_input.id() - << " has been modified data size by DVPP method."; std::vector truncate_shape = {static_cast(user_input.DataSize())}; input->set_shape(truncate_shape); #endif } - input->set_data(user_input.data_c(), false); + input->set_data(const_cast(user_input.Data().get()), false); } } } @@ -250,12 +241,11 @@ bool LiteRTGraphExecutor::RunGraph(uint32_t, const std::vector & MS_LOG(DEBUG) << "Empty outputs."; return false; } - outputs->clear(); - *outputs = TensorUtils::MSTensorToTensor(res); + *outputs = res; return true; } -bool LiteRTGraphExecutor::Resize(uint32_t, const std::vector &inputs, +bool LiteRTGraphExecutor::Resize(uint32_t, const std::vector &inputs, const std::vector> &dims) { auto input_tensors = lite_session_->GetInputs(); if (input_tensors.empty()) { @@ -278,36 +268,33 @@ bool LiteRTGraphExecutor::Resize(uint32_t, const std::vector &in return true; } -std::vector LiteRTGraphExecutor::GetInputInfos(uint32_t) { +std::vector LiteRTGraphExecutor::GetInputInfos(uint32_t) { if (lite_session_ == nullptr) { MS_LOG(ERROR) << "Session is null."; return {}; } auto inputs = lite_session_->GetInputs(); - std::vector input_tensors; + std::vector res; + res.resize(inputs.size()); for (size_t i = 0; i < inputs.size(); ++i) { - auto type_id = inputs[i]->data_type(); - auto shape = inputs[i]->shape(); - std::vector lite_shape; - std::transform(shape.begin(), shape.end(), std::back_inserter(lite_shape), - [](int c) { return static_cast(c); }); - auto tmp = tensor::Tensor(type_id, lite_shape); - tmp.set_name(inputs[i]->tensor_name()); - input_tensors.push_back(tmp); - } - return input_tensors; + auto impl = std::make_shared(inputs[i]); + if (impl == nullptr || impl->lite_tensor() == nullptr) { + MS_LOG(ERROR) << "impl is nullptr."; + return {}; + } + auto tensor = MSTensor(impl); + if (tensor == nullptr) { + MS_LOG(ERROR) << "create tensor failed."; + return {}; + } + res[i] = tensor; + } + return res; } -std::vector LiteRTGraphExecutor::GetOutputInfos(uint32_t) { +std::vector LiteRTGraphExecutor::GetOutputInfos(uint32_t) { auto outputs = GetLiteSessionOutputs(); - std::vector output_tensors; - for (size_t i = 0; i < outputs.size(); ++i) { - auto type_id = static_cast(outputs[i].DataType()); - auto tmp = tensor::Tensor(type_id, outputs[i].Shape()); - tmp.set_name(outputs[i].Name()); - output_tensors.push_back(tmp); - } - return output_tensors; + return outputs; } void LiteRTGraphExecutor::ResetTensorData(std::vector old_data, const std::vector &tensors) { diff --git a/mindspore-lite/src/extendrt/delegate/graph_executor/litert/graph_executor.h b/mindspore-lite/src/extendrt/delegate/graph_executor/litert/graph_executor.h index b46fb6d633d425a8f784a7758d96a17669e7a90a..981353ff62d11f638a70beee92e5a46f13d98e43 100644 --- a/mindspore-lite/src/extendrt/delegate/graph_executor/litert/graph_executor.h +++ b/mindspore-lite/src/extendrt/delegate/graph_executor/litert/graph_executor.h @@ -29,7 +29,6 @@ #include "src/litert/lite_session.h" #include "src/common/helper/infer_helpers.h" #include "src/common/config_infos.h" -#include "src/extendrt/session/lite_graph_executor.h" namespace mindspore { class LiteRTGraphExecutor : public LiteGraphExecutor { @@ -47,13 +46,13 @@ class LiteRTGraphExecutor : public LiteGraphExecutor { uint32_t *graph_id) override; bool CompileGraph(const void *model_data, size_t data_size, const std::map &compile_options, uint32_t *graph_id) override; - bool RunGraph(uint32_t graph_id, const std::vector &inputs, std::vector *outputs, + bool RunGraph(uint32_t graph_id, const std::vector &inputs, std::vector *outputs, const std::map &compile_options) override; - bool Resize(uint32_t graph_id, const std::vector &inputs, + bool Resize(uint32_t graph_id, const std::vector &inputs, const std::vector &dims) override; - std::vector GetInputInfos(uint32_t graph_id) override; - std::vector GetOutputInfos(uint32_t graph_id) override; + std::vector GetInputInfos(uint32_t graph_id) override; + std::vector GetOutputInfos(uint32_t graph_id) override; std::shared_ptr CreateLiteSession(const std::shared_ptr &context, const ConfigInfos &config_infos); diff --git a/mindspore-lite/src/extendrt/delegate/ops/copy.cc b/mindspore-lite/src/extendrt/delegate/ops/copy.cc deleted file mode 100644 index b7a12cf29ba1bd45c386fad97773a3b71c0e20df..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/ops/copy.cc +++ /dev/null @@ -1,84 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "extendrt/delegate/ops/copy.h" -#include "mindapi/base/shared_ptr.h" -#include "mindapi/ir/common.h" -#include "mindapi/ir/value.h" -#include "utils/check_convert_utils.h" -#include "mindspore/ops/op_def/op_name.h" -#include "ops/primitive_c.h" -#include "src/common/log_adapter.h" -#include "mindapi/helper.h" -#include "src/common/utils.h" -#include "abstract/ops/op_infer.h" -#include "abstract/ops/primitive_infer_map.h" - -namespace mindspore { -namespace ops { -MIND_API_OPERATOR_IMPL(Copy, BaseOperator); - -void Copy::set_copy_format(CopyFormatType format) { this->AddAttr(kCopyFormat, api::MakeValue(format)); } - -int Copy::get_copy_format() const { - auto value_ptr = GetAttr(kCopyFormat); - return static_cast(GetValue(value_ptr)); -} - -class CopyInfer : public abstract::OpInferBase { - public: - BaseShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) const override; - - TypePtr InferType(const PrimitivePtr &primitive, const std::vector &input_args) const override; - AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive, - const std::vector &input_args) const override; -}; - -BaseShapePtr CopyInfer::InferShape(const PrimitivePtr &primitive, - const std::vector &input_args) const { - return input_args[kInputIndex0]->GetShape(); -} - -TypePtr CopyInfer::InferType(const PrimitivePtr &primitive, const std::vector &input_args) const { - auto format_ptr = primitive->GetAttr(kCopyFormat); - auto oper = static_cast(GetValue(format_ptr)); - TypePtr in_type = input_args[kInputIndex0]->GetType(); - TypePtr res = in_type; - if ((in_type == kFloat32) && (oper == Copy::CopyFormatType::HOST_DEVICE)) { - res = kFloat16; - } - if ((in_type == kFloat16) && (oper == Copy::CopyFormatType::DEVICE_HOST)) { - res = kFloat32; - } - return res; -} - -AbstractBasePtr CopyInfer::InferShapeAndType(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args) const { - MS_EXCEPTION_IF_NULL(primitive); - auto prim_name = primitive->name(); - const int64_t input_num = 1; - CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name); - MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]); - auto type = InferType(primitive, input_args); - auto shape = InferShape(primitive, input_args); - return abstract::MakeAbstract(shape, type); -} - -GVAR_DEF(PrimitivePtr, kPrimCopy, std::make_shared(kNameCopy)); -REGISTER_PRIMITIVE_OP_INFER_IMPL(Copy, kPrimCopy, CopyInfer, false); -} // namespace ops -} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/delegate/ops/copy.h b/mindspore-lite/src/extendrt/delegate/ops/copy.h deleted file mode 100644 index b07ba9a1460aa9d3bf01a2ba413f866f41ffc030..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/ops/copy.h +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_OPS_COPY_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_OPS_COPY_H_ -#include -#include -#include -#include - -#include "ops/base_operator.h" -#include "include/api/data_type.h" - -namespace mindspore { -namespace ops { -constexpr auto kNameCopy = "Copy"; -constexpr auto kCopyFormat = "copy_format"; - -/// \brief Custom defined user-defined operator prototype. -class MIND_API Copy : public BaseOperator { - public: - enum CopyFormatType : int { - NONE = 0, - HOST_DEVICE, - DEVICE_HOST, - }; - - MIND_API_BASE_MEMBER(Copy); - /// \brief Constructor. - Copy() : BaseOperator(kNameCopy) {} - void set_copy_format(CopyFormatType format); - int get_copy_format() const; -}; -} // namespace ops -} // namespace mindspore -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_OPS_COPY_H_ diff --git a/mindspore-lite/src/extendrt/delegate/parameter_cache/cache_algorithm.h b/mindspore-lite/src/extendrt/delegate/parameter_cache/cache_algorithm.h deleted file mode 100644 index ac50e7485270e1b5b16922dd6238fb6aee3e64dd..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/parameter_cache/cache_algorithm.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_PARAMETER_CACHE_CACHE_ALGORITHM_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_PARAMETER_CACHE_CACHE_ALGORITHM_H_ - -#include -#include "include/api/status.h" - -namespace mindspore { -namespace cache { -struct CacheNoe { - CacheNoe(int _index, int _frequency, int _value) : key(_index), frequency(_frequency), value(_value) {} - int key; // host input index - int frequency; - int value; // cache index -}; - -class CacheAlgorithm { - public: - virtual ~CacheAlgorithm() {} - virtual int Get(int key) = 0; - virtual void Put(int key, int value) = 0; - virtual Status Init(size_t cache_size, int min_host_index, int max_host_index) = 0; - virtual Status CheckCacheHit(const int *batch_ids, const size_t batch_ids_len, int *cache_index, - std::vector *need_swap_indies, std::vector *need_swap_indies_cache_index) = 0; -}; -} // namespace cache -} // namespace mindspore -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_PARAMETER_CACHE_CACHE_ALGORITHM_H_ diff --git a/mindspore-lite/src/extendrt/delegate/parameter_cache/cache_mem_base.h b/mindspore-lite/src/extendrt/delegate/parameter_cache/cache_mem_base.h deleted file mode 100644 index d43919edca459a32b1883176bfda80ab54b1b36b..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/parameter_cache/cache_mem_base.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_PARAMETER_CACHE_CACHE_MEM_BASE_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_PARAMETER_CACHE_CACHE_MEM_BASE_H_ -#include -#include - -namespace mindspore { -namespace cache { -class CacheMemBase { - public: - CacheMemBase() = default; - virtual ~CacheMemBase() = default; - virtual bool InitDevice(uint32_t device_id, const void *context) = 0; - virtual void *MallocMemory(size_t size) = 0; - virtual void FreeMemory(void *buf) = 0; - virtual bool SynchronizeStream() = 0; - virtual bool CopyHostMemToDevice(void *dst, const void *src, size_t size) = 0; - virtual bool CopyDeviceMemToHost(void *dst, const void *src, size_t size) = 0; - virtual bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, - size_t cache_vocab_size, size_t embedding_size, size_t swap_out_size) = 0; - virtual bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, - size_t cache_vocab_size, size_t embedding_size, size_t swap_in_size) = 0; -}; -} // namespace cache -} // namespace mindspore -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_PARAMETER_CACHE_CACHE_MEM_BASE_H_ diff --git a/mindspore-lite/src/extendrt/delegate/parameter_cache/embedding_cache.cc b/mindspore-lite/src/extendrt/delegate/parameter_cache/embedding_cache.cc deleted file mode 100644 index 71dda4e098b6d21db29a56e1d4b3e34d566b4d44..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/parameter_cache/embedding_cache.cc +++ /dev/null @@ -1,238 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "src/extendrt/delegate/parameter_cache/embedding_cache.h" -#include -#include -#include -#include -#include -#include -#include "src/common/log_adapter.h" -#include "include/errorcode.h" -#include "src/extendrt/delegate/parameter_cache/gpu/gpu_cache_mem.h" -#include "src/extendrt/delegate/parameter_cache/lfu_cache.h" -#include "src/extendrt/delegate/parameter_cache/factory_mgr_base.h" -#include "utils/convert_utils_base.h" - -namespace { -constexpr size_t kEmbeddingTensorShapeSize = 2; -} -namespace mindspore { -namespace cache { -void LookUpTableTask(size_t indices_lens, size_t first_dim_size, const char *input_addr, const int *indices_addr, - char *output_addr, size_t embedding_len, int min_host_index) { - for (size_t i = 0; i < indices_lens; ++i) { - int index = indices_addr[i] - min_host_index; - if (index >= 0 && index < static_cast(first_dim_size)) { - size_t pos = index * embedding_len; - std::memcpy(output_addr, input_addr + pos, embedding_len); - } else { - memset(output_addr, 0, embedding_len); - } - output_addr += embedding_len; - } -} - -EmbeddingCache::~EmbeddingCache() { - if (hash_swap_value_device_addr_ != nullptr) { - device_cache_->FreeMemory(hash_swap_value_device_addr_); - hash_swap_value_device_addr_ = nullptr; - } - if (hash_swap_value_addr_ != nullptr) { - free(hash_swap_value_addr_); - hash_swap_value_addr_ = nullptr; - } - if (hash_swap_index_addr_ != nullptr) { - device_cache_->FreeMemory(hash_swap_index_addr_); - hash_swap_index_addr_ = nullptr; - } -} - -Status EmbeddingCache::Init(mindspore::MSTensor host_cache_tensor, mindspore::MSTensor device_tensor) { - MS_ASSERT(device_tensor.Shape().size() == kEmbeddingTensorShapeSize); - MS_ASSERT(host_cache_tensor.Shape().size() == kEmbeddingTensorShapeSize); - MS_ASSERT(device_tensor.DataType() == host_cache_tensor.DataType()); - MS_ASSERT(host_cache_tensor.Data() != nullptr); - - if (device_tensor.Shape()[1] != host_cache_tensor.Shape()[1]) { - MS_LOG(ERROR) << device_tensor.Name() << " embedding_size is invalid, device size is " << device_tensor.Shape()[1] - << ", host size is " << host_cache_tensor.Shape()[1]; - return kLiteError; - } - if (SizeToInt(host_cache_size_) != host_cache_tensor.Shape()[0]) { - MS_LOG(ERROR) << device_tensor.Name() << " host_cache_size is invalid, host_cache_size" - << host_cache_tensor.Shape()[0] << ", index begin:" << min_host_index_ - << ", index end:" << max_host_index_ << "rank_group_size_ num:" << rank_group_size_ - << ", rank id:" << rank_id_ << ", vocab_size_:" << vocab_size_; - return kLiteError; - } - - data_type_ = device_tensor.DataType(); - switch (data_type_) { - case DataType::kNumberTypeFloat32: - sizeof_data_type_ = sizeof(float); - break; - default: - MS_LOG(ERROR) << device_tensor.Name() << " unsupported data type " << static_cast(data_type_); - return kLiteError; - } - host_addr_ = host_cache_tensor.MutableData(); - embedding_size_ = device_tensor.Shape()[1]; - device_start_index_ = device_cache_size_ * rank_id_; - // host cache tensor is device tensor - if (device_tensor.Shape()[0] == host_cache_tensor.Shape()[0]) { - device_start_index_ = min_host_index_; - } - return kSuccess; -} - -Status EmbeddingCache::MallocCacheMemory() { - auto hash_swap_value_size = embedding_size_ * batch_elements_ * sizeof_data_type_; - hash_swap_value_device_addr_ = device_cache_->MallocMemory(hash_swap_value_size); - if (hash_swap_value_device_addr_ == nullptr) { - MS_LOG(ERROR) << "malloc hash_swap_value_device failed, malloc size " << hash_swap_value_size; - return kLiteMemoryFailed; - } - - hash_swap_value_addr_ = malloc(hash_swap_value_size); - if (hash_swap_value_addr_ == nullptr) { - MS_LOG(ERROR) << "malloc hash_swap_value failed, malloc size " << hash_swap_value_size; - return kLiteMemoryFailed; - } - - // data type of index - hash_swap_index_addr_ = static_cast(device_cache_->MallocMemory(batch_elements_ * sizeof(int))); - if (hash_swap_index_addr_ == nullptr) { - MS_LOG(ERROR) << "malloc hash_swap_index failed, malloc size " << batch_elements_ * sizeof(int); - return kLiteMemoryFailed; - } - return kSuccess; -} - -Status EmbeddingCache::Init(uint32_t device_id, const void *context, mindspore::MSTensor host_cache_tensor, - mindspore::MSTensor device_tensor) { - auto ret = Init(host_cache_tensor, device_tensor); - if (ret != kSuccess) { - return ret; - } - cache_ = lite::FactoryManagerBase::Instance().GetProduct("lfu"); - if (cache_ == nullptr) { - MS_LOG(ERROR) << "malloc LFUCacheAlgorithm failed"; - return kLiteMemoryFailed; - } - ret = cache_->Init(device_cache_size_, min_host_index_, max_host_index_); - if (ret != kSuccess) { - MS_LOG(ERROR) << "init cache failed," << ret.CodeAsString(); - return kLiteError; - } - - device_cache_ = lite::FactoryManagerBase::Instance().GetProduct("gpu"); - if (device_cache_ == nullptr) { - MS_LOG(ERROR) << "get cache failed"; - return kLiteMemoryFailed; - } - if (!device_cache_->InitDevice(device_id, context)) { - MS_LOG(ERROR) << "init device failed"; - return kLiteError; - } - ret = MallocCacheMemory(); - if (ret != kSuccess) { - return ret; - } - - MS_LOG(INFO) << "init succ, rank_group_size_ num:" << rank_group_size_ << ", rank id:" << rank_id_ - << ", vocab_size_:" << vocab_size_ << ", host_cache_size_:" << host_cache_size_ - << ", device_cache_size_:" << device_cache_size_ << ", embedding_size_:" << embedding_size_ - << ", batch_elements_:" << batch_elements_ << ", index begin:" << min_host_index_ - << ", index end:" << max_host_index_; - return kSuccess; -} - -Status EmbeddingCache::SetHostCacheAddr(void *addr, size_t size) { - if (sizeof_data_type_ * host_cache_size_ * embedding_size_ != size) { - return kLiteParamInvalid; - } - host_addr_ = addr; - - // copy part of host mem to device - auto ret = - device_cache_->CopyHostMemToDevice(device_addr_, addr, sizeof_data_type_ * device_cache_size_ * embedding_size_); - if (!ret) { - MS_LOG(ERROR) << "CopyHostMemToDevice failed, copy size " - << sizeof_data_type_ * device_cache_size_ * embedding_size_; - return kLiteMemoryFailed; - } - - // init cache - auto index_num = device_cache_size_; - for (size_t i = 0; i < index_num; i++) { - cache_->Put(min_host_index_ + i, i); - } - - return kSuccess; -} - -Status EmbeddingCache::SetDeviceCacheAddr(void *device_mem_addr, size_t size) { - if (sizeof_data_type_ * device_cache_size_ * embedding_size_ != size) { - return kLiteParamInvalid; - } - - device_addr_ = device_mem_addr; - SetHostCacheAddr(host_addr_, sizeof_data_type_ * host_cache_size_ * embedding_size_); - - return kSuccess; -} - -Status EmbeddingCache::CheckCacheHit(const int *batch_ids, const size_t batch_ids_len, int *cache_index) { - std::vector need_swap_indies; - std::vector need_swap_indies_cache_index; - auto ret = - cache_->CheckCacheHit(batch_ids, batch_ids_len, cache_index, &need_swap_indies, &need_swap_indies_cache_index); - if (ret != kSuccess) { - MS_LOG(ERROR) << "CheckCacheHit failed"; - return ret; - } - auto swap_indices_size = need_swap_indies.size(); - if (swap_indices_size > 0) { - LookUpTableTask(swap_indices_size, host_cache_size_, static_cast(host_addr_), need_swap_indies.data(), - static_cast(hash_swap_value_addr_), embedding_size_ * sizeof_data_type_, min_host_index_); - - auto device_cache_ret = device_cache_->CopyHostMemToDevice(hash_swap_value_device_addr_, hash_swap_value_addr_, - swap_indices_size * embedding_size_ * sizeof_data_type_); - if (!device_cache_ret) { - MS_LOG(ERROR) << "copy swap value to device failed"; - return kLiteMemoryFailed; - } - - device_cache_ret = device_cache_->CopyHostMemToDevice(hash_swap_index_addr_, need_swap_indies_cache_index.data(), - swap_indices_size * sizeof(int)); - if (!device_cache_ret) { - MS_LOG(ERROR) << "copy swap indies to device failed"; - return kLiteMemoryFailed; - } - - device_cache_ret = device_cache_->HashSwapIn(device_addr_, hash_swap_value_device_addr_, hash_swap_index_addr_, - device_cache_size_, embedding_size_, swap_indices_size); - if (!device_cache_ret) { - MS_LOG(ERROR) << "HashSwapIn failed"; - return kLiteMemoryFailed; - } - } - - return kSuccess; -} -} // namespace cache -} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/delegate/parameter_cache/embedding_cache.h b/mindspore-lite/src/extendrt/delegate/parameter_cache/embedding_cache.h deleted file mode 100644 index e4bab4176694d208b0958412226b228e58376191..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/parameter_cache/embedding_cache.h +++ /dev/null @@ -1,89 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_PARAMETER_CACHE_EMBEDDING_CACHE_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_PARAMETER_CACHE_EMBEDDING_CACHE_H_ -#include -#include -#include -#include "include/api/status.h" -#include "include/api/data_type.h" -#include "src/common/log_adapter.h" -#include "src/extendrt/delegate/parameter_cache/cache_algorithm.h" -#include "src/extendrt/delegate/parameter_cache/cache_mem_base.h" - -namespace mindspore { -namespace cache { -class EmbeddingCache { - public: - EmbeddingCache(size_t vocab_size, size_t device_cache_size, size_t batch_elements, int rank_id, int rank_group_size) - : vocab_size_(vocab_size), - device_cache_size_(device_cache_size), - batch_elements_(batch_elements), - rank_id_(rank_id), - rank_group_size_(rank_group_size) { - MS_ASSERT(rank_group_size_ != 0); - auto local_shard_size = static_cast(std::ceil(static_cast(vocab_size_) / rank_group_size_)); - min_host_index_ = local_shard_size * rank_id_; - max_host_index_ = std::min(min_host_index_ + local_shard_size, static_cast(vocab_size_)); - host_cache_size_ = max_host_index_ - min_host_index_; - - MS_LOG(INFO) << "rank_group_size_ num:" << rank_group_size_ << ", rank id:" << rank_id_ - << ", vocab_size_:" << vocab_size_ << ", host_cache_size_:" << host_cache_size_ - << ", index begin:" << min_host_index_ << ", index end:" << max_host_index_; - } - - ~EmbeddingCache(); - Status Init(uint32_t device_id, const void *context, mindspore::MSTensor host_cache_tensor, - mindspore::MSTensor device_tensor); - Status SetHostCacheAddr(void *addr, size_t size); - Status SetDeviceCacheAddr(void *host_mem_addr, size_t size); - Status CheckCacheHit(const int *batch_ids, const size_t batch_ids_len, int *hash_index); - size_t GetDeviceStartIndex() { return device_start_index_; } - - private: - Status Init(mindspore::MSTensor host_cache_tensor, mindspore::MSTensor device_tensor); - Status MallocCacheMemory(); - - private: - std::shared_ptr device_cache_{nullptr}; - std::shared_ptr cache_{nullptr}; - - size_t vocab_size_{0}; // total size - size_t host_cache_size_{0}; // local host size - size_t device_cache_size_{0}; // local device cache size - size_t device_start_index_{0}; - size_t embedding_size_{0}; - size_t batch_elements_{0}; - - DataType data_type_{DataType::kNumberTypeFloat32}; - size_t sizeof_data_type_{0}; - - void *device_addr_{nullptr}; // hash_info.device_address.addr - void *host_addr_{nullptr}; - - int *hash_swap_index_addr_; // embedding_device_cache_->hash_swap_index_addr_ - void *hash_swap_value_addr_; - void *hash_swap_value_device_addr_; - - int rank_id_; - int rank_group_size_; - int min_host_index_{0}; - int max_host_index_{0}; -}; -} // namespace cache -} // namespace mindspore -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_PARAMETER_CACHE_EMBEDDING_CACHE_H_ diff --git a/mindspore-lite/src/extendrt/delegate/parameter_cache/embedding_cache_manager.cc b/mindspore-lite/src/extendrt/delegate/parameter_cache/embedding_cache_manager.cc deleted file mode 100644 index 7e7eaa3fe0d0830a321bbc7eca6fa60815d9e4fd..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/parameter_cache/embedding_cache_manager.cc +++ /dev/null @@ -1,194 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "src/extendrt/delegate/parameter_cache/embedding_cache_manager.h" -#include -#include -#include -#include "src/common/log_adapter.h" -#include "include/errorcode.h" - -namespace { -constexpr size_t kGatherInputsSize = 3; -} -namespace mindspore { -namespace cache { -Status EmbeddingCacheManager::Init(const std::string &cache_model_path, size_t vocab_size, size_t device_cache_size) { - if (cache_model_path.empty() || vocab_size == 0 || device_cache_size >= vocab_size) { - MS_LOG(INFO) << "no cache model , vocab_size " << vocab_size << ", device_cache_size " << device_cache_size; - return kSuccess; - } - - host_cache_model_ = std::make_shared(); - if (host_cache_model_ == nullptr) { - MS_LOG(ERROR) << "HostCacheModel malloc failed"; - return kLiteMemoryFailed; - } - auto ret = host_cache_model_->LoadCache(cache_model_path); - if (ret != kSuccess) { - MS_LOG(ERROR) << "load cache failed"; - return ret; - } - vocab_size_ = vocab_size; - device_cache_size_ = device_cache_size; - - MS_LOG(INFO) << "cache manager init succ, cache model" << cache_model_path << " , vocab_size " << vocab_size - << ", device_cache_size " << device_cache_size; - return ret; -} - -Status EmbeddingCacheManager::Init(DelegateModel *model, size_t vocab_size, - size_t device_cache_size) { - if (model == nullptr || vocab_size == 0 || device_cache_size >= vocab_size) { - MS_LOG(INFO) << "no cache model , vocab_size " << vocab_size << ", device_cache_size " << device_cache_size; - return kSuccess; - } - - host_cache_model_ = std::make_shared(); - if (host_cache_model_ == nullptr) { - MS_LOG(ERROR) << "HostCacheModel malloc failed"; - return kLiteMemoryFailed; - } - auto ret = host_cache_model_->LoadCache(model); - if (ret != kSuccess) { - MS_LOG(ERROR) << "load cache failed"; - return ret; - } - vocab_size_ = vocab_size; - device_cache_size_ = device_cache_size; - - MS_LOG(INFO) << "cache manager init succ, vocab_size " << vocab_size << ", device_cache_size " << device_cache_size; - return ret; -} - -bool EmbeddingCacheManager::CheckIsCacheKernel(kernel::Kernel *kernel) { - if (host_cache_model_ == nullptr) { - return false; - } - return host_cache_model_->CheckIsCacheKernel(kernel); -} - -Status EmbeddingCacheManager::InitCacheKernel(kernel::Kernel *kernel, uint32_t device_id, const void *context) { - if (host_cache_model_ == nullptr) { - MS_LOG(ERROR) << "cache model is nullptr, kernel " << kernel->name() << " init cache failed"; - return kLiteError; - } - auto host_cache_tensor = host_cache_model_->GetHostCacheTensor(kernel); - if (host_cache_tensor == nullptr) { - MS_LOG(ERROR) << kernel->name() << ": invalid cache kernel"; - return kLiteError; - } - - // only support embedding cache - if (kernel->type() != schema::PrimitiveType_Gather) { - MS_LOG(ERROR) << kernel->name() << " is not embedding kernel"; - return kLiteError; - } - MS_ASSERT(kernel->size() == kGatherInputsSize); - auto device_tensor = kernel->inputs()[0]; - size_t batch_elements = kernel->inputs()[1].ElementNum(); - auto cache = - std::make_shared(vocab_size_, device_cache_size_, batch_elements, rank_id_, rank_group_size_); - if (cache == nullptr) { - MS_LOG(ERROR) << kernel->name() << ": malloc EmbeddingCache failed"; - return kLiteError; - } - - auto ret = cache->Init(device_id, context, host_cache_tensor, device_tensor); - if (ret != kSuccess) { - MS_LOG(ERROR) << kernel->name() << ": EmbeddingCache init failed"; - return kLiteError; - } - - caches_[device_tensor.Name()] = cache; - MS_LOG(INFO) << kernel->name() << " is cache kernel, input tensor " << kernel->inputs()[1].Name() << ", cache tensor " - << device_tensor.Name(); - - return kSuccess; -} - -bool EmbeddingCacheManager::IsCacheTensor(mindspore::MSTensor tensor) { - if (host_cache_model_ == nullptr) { - return false; - } - auto cache = caches_.find(tensor.Name()); - if (cache != caches_.end()) { - return true; - } - return false; -} - -std::vector EmbeddingCacheManager::GetCacheShape(mindspore::MSTensor tensor) { - std::vector shape = tensor.Shape(); - if (shape.size() > 0 && IsCacheTensor(tensor)) { - shape[0] = device_cache_size_; - } - return shape; -} - -size_t EmbeddingCacheManager::GetCacheDataSize(mindspore::MSTensor tensor) { - auto data_size = tensor.DataSize(); - auto &shape = tensor.Shape(); - if (shape.size() > 0 && IsCacheTensor(tensor) && shape[0] > 0) { - data_size = data_size * device_cache_size_ / shape[0]; - } - return data_size; -} - -Status EmbeddingCacheManager::SetDeviceCacheAddr(const std::string &tensor_name, void *device_mem_addr, size_t size) { - auto cache_iter = caches_.find(tensor_name); - if (cache_iter == caches_.end() || cache_iter->second == nullptr) { - MS_LOG(ERROR) << "not find cache, " << tensor_name; - return kLiteError; - } - auto cache = cache_iter->second; - return cache->SetDeviceCacheAddr(device_mem_addr, size); -} - -// device_addr is model input device addr -int EmbeddingCacheManager::CacheHandle(const std::string &tensor_name, mindspore::MSTensor model_input_tensor, - void *model_input_device_addr) { - auto cache_iter = caches_.find(tensor_name); - if (cache_iter == caches_.end()) { - MS_LOG(ERROR) << "not find cache, " << tensor_name; - return lite::RET_ERROR; - } - auto cache = cache_iter->second; - hash_indices_.resize(model_input_tensor.ElementNum()); - auto ret = cache->CheckCacheHit(static_cast(model_input_tensor.MutableData()), hash_indices_.size(), - hash_indices_.data()); - if (ret != kSuccess) { - MS_LOG(ERROR) << "CheckCacheHit failed, " << model_input_tensor.Name(); - return lite::RET_ERROR; - } - - for (size_t i = 0; i < hash_indices_.size(); i++) { - if (hash_indices_[i] != -1) { - hash_indices_[i] += cache->GetDeviceStartIndex(); - } - } - - auto cuda_ret = cudaMemcpy(model_input_device_addr, hash_indices_.data(), hash_indices_.size() * sizeof(int), - cudaMemcpyHostToDevice); - if (cuda_ret != cudaSuccess) { - MS_LOG(ERROR) << "copy mem failed, " << model_input_tensor.Name(); - return lite::RET_ERROR; - } - MS_LOG(INFO) << "cache handle succ, " << model_input_tensor.Name() << "," << tensor_name; - - return lite::RET_OK; -} -} // namespace cache -} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/delegate/parameter_cache/embedding_cache_manager.h b/mindspore-lite/src/extendrt/delegate/parameter_cache/embedding_cache_manager.h deleted file mode 100644 index d74acb3b898d9fb9f9e0c36343b7c8c6d236022c..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/parameter_cache/embedding_cache_manager.h +++ /dev/null @@ -1,60 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_PARAMETER_CACHE_EMBEDDING_CACHE_MANAGER_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_PARAMETER_CACHE_EMBEDDING_CACHE_MANAGER_H_ -#include -#include -#include -#include -#include "include/api/kernel.h" -#include "include/api/status.h" -#include "include/api/data_type.h" -#include "src/extendrt/delegate/parameter_cache/embedding_cache.h" -#include "src/extendrt/delegate/parameter_cache/load_host_cache_model.h" -#include "src/extendrt/delegate/tensorrt/distribution/distribution_base.h" - -namespace mindspore { -namespace cache { -class EmbeddingCacheManager { - public: - EmbeddingCacheManager() { - rank_id_ = lite::GetRankID(); - rank_group_size_ = lite::GetGPUGroupSize(); - } - Status Init(const std::string &cache_model_path, size_t vocab_size, size_t device_cache_size); - Status Init(DelegateModel *model, size_t vocab_size, size_t device_cache_size); - bool CheckIsCacheKernel(kernel::Kernel *kernel); - Status InitCacheKernel(kernel::Kernel *kernel, uint32_t device_id, const void *context); - bool IsCacheTensor(mindspore::MSTensor tensor); - int CacheHandle(const std::string &tensor_name, mindspore::MSTensor model_input_tensor, void *device_addr); - Status SetDeviceCacheAddr(const std::string &tensor_name, void *device_mem_addr, size_t size); - std::vector GetCacheShape(mindspore::MSTensor tensor); - size_t GetCacheDataSize(mindspore::MSTensor tensor); - - private: - std::map> caches_; - std::vector hash_indices_; - int rank_id_{0}; - int rank_group_size_{1}; - - std::shared_ptr host_cache_model_; - size_t vocab_size_; - size_t device_cache_size_; -}; -} // namespace cache -} // namespace mindspore -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_PARAMETER_CACHE_EMBEDDING_CACHE_MANAGER_H_ diff --git a/mindspore-lite/src/extendrt/delegate/parameter_cache/factory_mgr_base.h b/mindspore-lite/src/extendrt/delegate/parameter_cache/factory_mgr_base.h deleted file mode 100644 index 5a96cf3ea88aca628b334b056ec0269a10b430a4..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/parameter_cache/factory_mgr_base.h +++ /dev/null @@ -1,81 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_PARAMETER_CACHE_FACTORY_MGR_BASE_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_PARAMETER_CACHE_FACTORY_MGR_BASE_H_ -#include -#include -#include "include/api/status.h" - -namespace mindspore { -namespace lite { -template -class ProcductRegistrar { - public: - virtual std::shared_ptr Create() = 0; - - protected: - ProcductRegistrar() {} - virtual ~ProcductRegistrar() {} - - private: - ProcductRegistrar(const ProcductRegistrar &); - const ProcductRegistrar &operator=(const ProcductRegistrar &); -}; - -template -class FactoryManagerBase { - public: - static FactoryManagerBase &Instance() { - static FactoryManagerBase instance; - return instance; - } - void RegProduct(const KEY &key, ProcductRegistrar *registrar) { registrars[key] = registrar; } - - std::shared_ptr GetProduct(const KEY &key) { - auto registrar_iter = registrars.find(key); - if (registrar_iter != registrars.end()) { - if (registrar_iter->second != nullptr) { - return registrar_iter->second->Create(); - } - } - return nullptr; - } - - private: - FactoryManagerBase() = default; - ~FactoryManagerBase() = default; - FactoryManagerBase(const FactoryManagerBase &); - const FactoryManagerBase &operator=(const FactoryManagerBase &); - - private: - std::map *> registrars; -}; - -template -class CommonProcductRegistrar : public ProcductRegistrar { - public: - explicit CommonProcductRegistrar(const KEY &key) { - FactoryManagerBase::Instance().RegProduct(key, this); - } - std::shared_ptr Create() { return std::make_shared(); } -}; - -#define RET_COMMON_PRODUCT_REGISTRAR(KEY, PRODUCT, PRODUCT_IMPL, key, name) \ - static mindspore::lite::CommonProcductRegistrar g_commonProcductRegistrar##name(key); -} // namespace lite -} // namespace mindspore -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_PARAMETER_CACHE_FACTORY_MGR_BASE_H_ diff --git a/mindspore-lite/src/extendrt/delegate/parameter_cache/gpu/gpu_cache_mem.cc b/mindspore-lite/src/extendrt/delegate/parameter_cache/gpu/gpu_cache_mem.cc deleted file mode 100644 index 6adf8d27ba7f6e21fa5d08ffc70625a7035f128e..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/parameter_cache/gpu/gpu_cache_mem.cc +++ /dev/null @@ -1,158 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/parameter_cache/gpu/gpu_cache_mem.h" -#include -#include -#include "src/extendrt/delegate/tensorrt/cuda_impl/hash.cuh" -#include "plugin/res_manager/gpu/device/cuda_driver.h" -#include "src/common/log_adapter.h" -#include "src/extendrt/delegate/parameter_cache/factory_mgr_base.h" -namespace mindspore { -namespace cache { -namespace gpu { -RET_COMMON_PRODUCT_REGISTRAR(std::string, cache::CacheMemBase, cache::gpu::GPUCacheMem, "gpu", GPUCacheMem); -bool GPUCacheMem::InitDevice(uint32_t device_id, const void *context) { - auto cuda_ret = cudaSetDevice(static_cast(device_id)); - if (cuda_ret != cudaSuccess) { - MS_LOG(ERROR) << "Failed to set device id " << device_id << ", cuda_ret " << cuda_ret << " " - << cudaGetErrorString(cuda_ret); - return false; - } - if (context != nullptr) { - stream_ = *(reinterpret_cast(context)); - return true; - } - - cuda_ret = cudaStreamCreate(&stream_); - if (cuda_ret != cudaSuccess) { - MS_LOG(ERROR) << "Cuda create stream failed, cuda_ret " << cuda_ret << " " << cudaGetErrorString(cuda_ret); - return false; - } - - return true; -} - -void *GPUCacheMem::MallocMemory(size_t size) { - void *device_ptr = nullptr; - auto cuda_ret = cudaMalloc(&device_ptr, size); - if (cuda_ret != cudaSuccess) { - MS_LOG(ERROR) << "Cuda Malloc failed for size:" << size << ", cuda_ret " << cuda_ret << " " - << cudaGetErrorString(cuda_ret); - return nullptr; - } - MS_LOG(DEBUG) << "cudaMalloc size: " << size; - return device_ptr; -} - -void GPUCacheMem::FreeMemory(void *device_addr) { - auto cuda_ret = cudaFree(device_addr); - if (cuda_ret != cudaSuccess && cuda_ret != cudaErrorCudartUnloading) { - MS_LOG(WARNING) << "free cuda memory failed, " - << ", cuda_ret " << cuda_ret << " " << cudaGetErrorString(cuda_ret); - } -} - -bool GPUCacheMem::SynchronizeStream() { - auto cuda_ret = cudaStreamSynchronize(stream_); - if (cuda_ret != cudaSuccess) { - MS_LOG(ERROR) << "Cuda sync stream failed, cuda_ret " << cuda_ret << " " << cudaGetErrorString(cuda_ret); - return false; - } - - return true; -} - -bool GPUCacheMem::CopyHostMemToDevice(void *dst, const void *src, size_t size) { - if (dst == nullptr) { - MS_LOG(ERROR) << "dst is nullptr"; - return false; - } - if (src == nullptr) { - MS_LOG(ERROR) << "src is nullptr"; - return false; - } - - auto cuda_ret = cudaMemcpyAsync(dst, src, size, cudaMemcpyHostToDevice, stream_); - if (cuda_ret != cudaSuccess) { - MS_LOG(ERROR) << "Cuda memcpy failed, cuda_ret " << cuda_ret << " " << cudaGetErrorString(cuda_ret); - return false; - } - - return true; -} - -bool GPUCacheMem::CopyDeviceMemToHost(void *dst, const void *src, size_t size) { - if (dst == nullptr) { - MS_LOG(ERROR) << "dst is nullptr"; - return false; - } - if (src == nullptr) { - MS_LOG(ERROR) << "src is nullptr"; - return false; - } - - auto cuda_ret = cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToHost, stream_); - if (cuda_ret != cudaSuccess) { - MS_LOG(ERROR) << "Cuda memcpy failed, cuda_ret " << cuda_ret << " " << cudaGetErrorString(cuda_ret); - return false; - } - - return true; -} - -bool GPUCacheMem::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t, - size_t embedding_size, size_t swap_out_size) { - if (hash_table_addr == nullptr) { - MS_LOG(ERROR) << "hash_table_addr is nullptr"; - return false; - } - if (swap_out_value_addr == nullptr) { - MS_LOG(ERROR) << "swap_out_value_addr is nullptr"; - return false; - } - if (swap_out_index_addr == nullptr) { - MS_LOG(ERROR) << "swap_out_index_addr is nullptr"; - return false; - } - - DoHashSwapOut(reinterpret_cast(hash_table_addr), reinterpret_cast(swap_out_value_addr), - reinterpret_cast(swap_out_index_addr), swap_out_size, embedding_size, stream_); - return true; -} - -bool GPUCacheMem::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t, - size_t embedding_size, size_t swap_in_size) { - if (hash_table_addr == nullptr) { - MS_LOG(ERROR) << "hash_table_addr is nullptr"; - return false; - } - if (swap_in_value_addr == nullptr) { - MS_LOG(ERROR) << "swap_in_value_addr is nullptr"; - return false; - } - if (swap_in_index_addr == nullptr) { - MS_LOG(ERROR) << "swap_in_index_addr is nullptr"; - return false; - } - - DoHashSwapIn(reinterpret_cast(hash_table_addr), reinterpret_cast(swap_in_value_addr), - reinterpret_cast(swap_in_index_addr), swap_in_size, embedding_size, stream_); - return true; -} -} // namespace gpu -} // namespace cache -} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/delegate/parameter_cache/gpu/gpu_cache_mem.h b/mindspore-lite/src/extendrt/delegate/parameter_cache/gpu/gpu_cache_mem.h deleted file mode 100644 index 15b11febcf51077dd7289884961d19531bb18a4b..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/parameter_cache/gpu/gpu_cache_mem.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_PARAMETER_CACHE_GPU_GPU_CACHE_MEM_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_PARAMETER_CACHE_GPU_GPU_CACHE_MEM_H_ - -#include -#include -#include "src/extendrt/delegate/parameter_cache/cache_mem_base.h" - -namespace mindspore { -namespace cache { -namespace gpu { -class GPUCacheMem : public cache::CacheMemBase { - public: - GPUCacheMem() = default; - ~GPUCacheMem() override = default; - bool InitDevice(uint32_t device_id, const void *context) override; - void *MallocMemory(size_t size) override; - void FreeMemory(void *buf) override; - bool SynchronizeStream() override; - bool CopyHostMemToDevice(void *dst, const void *src, size_t size) override; - bool CopyDeviceMemToHost(void *dst, const void *src, size_t size) override; - bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t cache_vocab_size, - size_t embedding_size, size_t swap_out_size) override; - bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t cache_vocab_size, - size_t embedding_size, size_t swap_in_size) override; - - private: - cudaStream_t stream_; -}; -} // namespace gpu -} // namespace cache -} // namespace mindspore -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_PARAMETER_CACHE_GPU_GPU_CACHE_MEM_H_ diff --git a/mindspore-lite/src/extendrt/delegate/parameter_cache/lfu_cache.cc b/mindspore-lite/src/extendrt/delegate/parameter_cache/lfu_cache.cc deleted file mode 100644 index fb81278ddc36b69cc1414d280573fb5e8aa41229..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/parameter_cache/lfu_cache.cc +++ /dev/null @@ -1,243 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include -#include "src/common/log_adapter.h" -#include "src/extendrt/delegate/parameter_cache/lfu_cache.h" -#include "src/extendrt/delegate/parameter_cache/factory_mgr_base.h" -namespace mindspore { -namespace cache { -RET_COMMON_PRODUCT_REGISTRAR(std::string, cache::CacheAlgorithm, cache::LFUCacheAlgorithm, "lfu", LFUCacheAlgorithm); - -LFUCacheAlgorithm::~LFUCacheAlgorithm() { - for (auto iter : key_table_) { - delete *(iter.second); - } - key_table_.clear(); - frequency_table_.clear(); -} - -Status LFUCacheAlgorithm::Init(size_t cache_size, int min_host_index, int max_host_index) { - if (cache_size == 0 || min_host_index < 0 || max_host_index <= 0) { - return kLiteParamInvalid; - } - cache_size_ = cache_size; - min_host_index_ = min_host_index; - max_host_index_ = max_host_index; - return kSuccess; -} - -CacheNoe *LFUCacheAlgorithm::GetNode(int key) { - auto key_table_iter = key_table_.find(key); - if (key_table_iter == key_table_.end()) { - return nullptr; - } - auto node_iter = key_table_iter->second; - auto node = *node_iter; - - auto node_list_iter = frequency_table_.find(key); - if (node_list_iter == frequency_table_.end()) { - return nullptr; - } - auto &node_list = node_list_iter->second; - node_list.erase(node_iter); - - if (node_list.empty()) { - frequency_table_.erase(node_list_iter); - } - - node->frequency += 1; - frequency_table_[node->frequency].emplace_front(node); - key_table_[key] = frequency_table_[node->frequency].begin(); - return node; -} - -int LFUCacheAlgorithm::Get(int key) { - auto node = GetNode(key); - if (node != nullptr) { - return node->value; - } - return -1; -} - -void LFUCacheAlgorithm::Put(int key, int value) { - auto node = GetNode(key); - if (node != nullptr) { - node->value = value; - return; - } - - if (cache_size_ == 0) { - return; - } - - CacheNoe *add_node = nullptr; - if (key_table_.size() == cache_size_) { - add_node = frequency_table_.begin()->second.back(); - key_table_.erase(add_node->key); - frequency_table_.begin()->second.pop_back(); - if (frequency_table_.begin()->second.size() == 0) { - frequency_table_.erase(frequency_table_.begin()->first); - } - add_node->value = value; - add_node->key = key; - add_node->frequency = 1; - } else { - add_node = new CacheNoe(key, 1, value); - if (add_node == nullptr) { - return; - } - } - - frequency_table_[1].emplace_front(add_node); - key_table_[key] = frequency_table_[1].begin(); -} - -void LFUCacheAlgorithm::GetHitNodesAndSwapIndex(const int *batch_ids, const size_t batch_ids_len, int *cache_index, - std::unordered_map *hit_index_nodes, - std::unordered_map> *need_swap_map) { - // 找到没有命中和命中的index - for (size_t i = 0; i < batch_ids_len; i++) { - auto key = batch_ids[i]; - if (key < min_host_index_ || key >= max_host_index_) { - cache_index[i] = -1; - // out range - continue; - } - - auto hit_iter = hit_index_nodes->find(key); - if (hit_iter != hit_index_nodes->end()) { - auto node = hit_iter->second; - node->frequency += 1; - cache_index[i] = node->value; - continue; - } - - auto swap_iter = need_swap_map->find(key); - if (swap_iter != need_swap_map->end()) { - swap_iter->second.push_back(i); - continue; - } - - auto node_iter_iter = key_table_.find(key); - if (node_iter_iter == key_table_.end()) { - (*need_swap_map)[key].push_back(i); - continue; - } - auto node_iter = node_iter_iter->second; - auto node = *node_iter; - - auto node_list_iter = frequency_table_.find(node->frequency); - if (node_list_iter == frequency_table_.end()) { - continue; - } - auto &node_list = node_list_iter->second; - node_list.erase(node_iter); - - if (node_list.empty()) { - frequency_table_.erase(node_list_iter); - } - // hit - node->frequency += 1; - cache_index[i] = node->value; - (*hit_index_nodes)[key] = node; - } - return; -} - -std::list LFUCacheAlgorithm::GetSwapNodes(const std::unordered_map> &need_swap_map) { - std::list need_swap_nodes; - auto swap_size = need_swap_map.size(); - - while (swap_size > 0 && !frequency_table_.empty()) { - auto node_list_iter = frequency_table_.begin(); - if (node_list_iter->second.size() > swap_size) { - auto iter = node_list_iter->second.begin(); - std::advance(iter, swap_size); - need_swap_nodes.splice(need_swap_nodes.end(), node_list_iter->second, node_list_iter->second.begin(), iter); - swap_size = 0; - } else { - swap_size -= node_list_iter->second.size(); - need_swap_nodes.splice(need_swap_nodes.end(), node_list_iter->second); - frequency_table_.erase(node_list_iter); - } - } - return need_swap_nodes; -} - -Status LFUCacheAlgorithm::CheckCacheHit(const int *batch_ids, const size_t batch_ids_len, int *cache_index, - std::vector *need_swap_indies, - std::vector *need_swap_indies_cache_index) { - if (batch_ids == nullptr) { - MS_LOG(ERROR) << "batch_ids is nullptr"; - return kLiteNullptr; - } - if (cache_index == nullptr) { - MS_LOG(ERROR) << "cache_index is nullptr"; - return kLiteNullptr; - } - std::unordered_map> need_swap_map; - std::unordered_map hit_index_nodes; - GetHitNodesAndSwapIndex(batch_ids, batch_ids_len, cache_index, &hit_index_nodes, &need_swap_map); - - // get need_swap_indies.size() least recently used node - std::list need_swap_nodes = GetSwapNodes(need_swap_map); - - // 更新老节点的值 - { - if (need_swap_map.size() != need_swap_nodes.size()) { - MS_LOG(ERROR) << " need_swap_map.size() " << need_swap_map.size() << " != need_swap_nodes.size() " - << need_swap_nodes.size(); - return kLiteError; - } - need_swap_indies_cache_index->reserve(need_swap_map.size()); - auto need_swap_map_iter = need_swap_map.begin(); - for (auto iter = need_swap_nodes.begin(); - iter != need_swap_nodes.end() && need_swap_map_iter != need_swap_map.end(); iter++, need_swap_map_iter++) { - auto node = *iter; - key_table_.erase(node->key); - node->key = need_swap_map_iter->first; - node->frequency = 1; - for (auto index : need_swap_map_iter->second) { - cache_index[index] = node->value; - } - need_swap_indies->push_back(need_swap_map_iter->first); - need_swap_indies_cache_index->push_back(node->value); - MS_LOG(INFO) << "device index " << node->value << ",for host index " << need_swap_map_iter->first; - key_table_[(*iter)->key] = iter; - } - - auto node_list_iter = frequency_table_.begin(); - if (node_list_iter->second.size() > 0) { - auto iter = node_list_iter->second.begin(); - if ((*iter)->frequency == 1) { - node_list_iter->second.splice(node_list_iter->second.begin(), need_swap_nodes); - } else { - frequency_table_[1] = need_swap_nodes; - } - } else { - frequency_table_[1] = need_swap_nodes; - } - } - for (auto node_iter : hit_index_nodes) { - auto node = node_iter.second; - frequency_table_[node->frequency].emplace_front(node); - key_table_[node->key] = frequency_table_[node->frequency].begin(); - } - return kSuccess; -} -} // namespace cache -} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/delegate/parameter_cache/lfu_cache.h b/mindspore-lite/src/extendrt/delegate/parameter_cache/lfu_cache.h deleted file mode 100644 index 889978ef2a33b2c624f58871d25ef7043f98634e..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/parameter_cache/lfu_cache.h +++ /dev/null @@ -1,55 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_PARAMETER_CACHE_LFU_CACHE_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_PARAMETER_CACHE_LFU_CACHE_H_ - -#include -#include -#include -#include -#include "include/api/status.h" -#include "src/extendrt/delegate/parameter_cache/cache_algorithm.h" -namespace mindspore { -namespace cache { -class LFUCacheAlgorithm : public CacheAlgorithm { - public: - LFUCacheAlgorithm() {} - ~LFUCacheAlgorithm() override; - - int Get(int key) override; - void Put(int key, int value) override; - Status Init(size_t cache_size, int min_host_index, int max_host_index) override; - Status CheckCacheHit(const int *batch_ids, const size_t batch_ids_len, int *cache_index, - std::vector *need_swap_indies, std::vector *need_swap_indies_cache_index) override; - - private: - CacheNoe *GetNode(int key); - void GetHitNodesAndSwapIndex(const int *batch_ids, const size_t batch_ids_len, int *cache_index, - std::unordered_map *hit_index_nodes, - std::unordered_map> *need_swap_map); - std::list GetSwapNodes(const std::unordered_map> &need_swap_map); - - std::unordered_map::iterator> key_table_; - std::map> frequency_table_; - size_t cache_size_{0}; - - int min_host_index_{0}; - int max_host_index_{1}; -}; -} // namespace cache -} // namespace mindspore -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_PARAMETER_CACHE_LFU_CACHE_H_ diff --git a/mindspore-lite/src/extendrt/delegate/parameter_cache/load_host_cache_model.cc b/mindspore-lite/src/extendrt/delegate/parameter_cache/load_host_cache_model.cc deleted file mode 100644 index 4fefbe062f88d92b894ff0bad8a12e63eafb0476..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/parameter_cache/load_host_cache_model.cc +++ /dev/null @@ -1,149 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include "src/extendrt/delegate/parameter_cache/load_host_cache_model.h" -#include "src/common/log_adapter.h" -#include "src/common/common.h" -#include "include/errorcode.h" -#include "src/common/file_utils.h" - -namespace { -constexpr size_t kGatherInputsSize = 3; -} -namespace mindspore { -namespace cache { -HostCacheModel::~HostCacheModel() { - if (cache_model_ != nullptr) { - delete cache_model_; - cache_model_ = nullptr; - } -} -MSTensor *SchemaTensorToMSTensor(lite::SchemaTensorWrapper *schema_tensor_wrapper, - mindspore::schema::Tensor *schema_tensor) { - MS_CHECK_TRUE_RET(schema_tensor_wrapper != nullptr && schema_tensor != nullptr, nullptr); - std::vector shape; - for (size_t j = 0; j < schema_tensor->dims()->size(); j++) { - shape.push_back(schema_tensor->dims()->data()[j]); - } - std::string tensor_name; - if (schema_tensor->name() != nullptr) { - tensor_name = schema_tensor->name()->str(); - } - return MSTensor::CreateRefTensor(tensor_name, (DataType)schema_tensor->dataType(), shape, - schema_tensor_wrapper->data(), schema_tensor_wrapper->length()); -} - -Status HostCacheModel::LoadCache(const std::string &model_path) { - cache_model_ = lite::LiteImportFromPath(model_path.c_str()); - if (cache_model_ == nullptr) { - MS_LOG(ERROR) << "Import model failed"; - return kLiteGraphFileError; - } - - auto allTensors = cache_model_->graph_.all_tensors_; - for (auto node : cache_model_->graph_.all_nodes_) { - // only support embedding cache - if (node == nullptr || node->node_type_ != schema::PrimitiveType_Gather) { - continue; - } - - auto input_index = node->input_indices_[0]; - if (input_index > allTensors.size() - 1) { - MS_LOG(ERROR) << "invalid kernel input, input_index " << input_index << ",allTensors.size() " - << allTensors.size(); - return kLiteOutOfTensorRange; - } - auto schema_tensor_wrapper = cache_model_->GetSchemaTensor(input_index); - if (schema_tensor_wrapper == nullptr) { - MS_LOG(ERROR) << "invalid kernel input, input_index " << input_index; - return kLiteOutOfTensorRange; - } - - auto schema_tensor = allTensors[input_index]; - if (schema_tensor != nullptr && schema_tensor_wrapper->data() != nullptr) { - auto tensor = SchemaTensorToMSTensor(schema_tensor_wrapper, schema_tensor); - if (tensor == nullptr) { - return kLiteMemoryFailed; - } - cache_tensor_[tensor->Name()] = *tensor; - MS_LOG(INFO) << tensor->Name() << " is cache tensor, and the node is [" << node->name_ << "]"; - delete tensor; - } - } - return kSuccess; -} - -size_t GetVocabSize(kernel::Kernel *kernel) { - size_t vocab_size = 0; - auto cache_config = kernel->GetConfig(lite::kMSCacheSection); - auto vocab_size_iter = cache_config.find(lite::kMSCacheVocabSizeKey); - if (vocab_size_iter == cache_config.end()) { - return vocab_size; - } - - auto vocab_size_opt = lite::GenericParseValue(vocab_size_iter->second); - if (!vocab_size_opt.IsNone()) { - vocab_size = vocab_size_opt.Get(); - } - return vocab_size; -} - -Status HostCacheModel::LoadCache(DelegateModel *model) { - KernelIter from, end; - for (KernelIter iter = model->BeginKernelIterator(); iter != model->EndKernelIterator(); iter++) { - kernel::Kernel *kernel = *iter; - // only support embedding cache - if (kernel->type() != schema::PrimitiveType_Gather) { - continue; - } - MS_ASSERT(kernel->size() == kGatherInputsSize); - auto tensor = kernel->inputs()[0]; - if (tensor.Data() == nullptr) { - continue; - } - - size_t vocab_size = GetVocabSize(kernel); - if (vocab_size == 0) { - continue; - } - - cache_tensor_[tensor.Name()] = tensor; - } - return mindspore::kSuccess; -} - -bool HostCacheModel::CheckIsCacheKernel(kernel::Kernel *kernel) { - if (GetHostCacheTensor(kernel) == nullptr) { - return false; - } - return true; -} - -MSTensor HostCacheModel::GetHostCacheTensor(kernel::Kernel *kernel) { - if (kernel != nullptr && kernel->size() > 0) { - auto iter = cache_tensor_.find(kernel->inputs()[0].Name()); - if (iter != cache_tensor_.end()) { - return iter->second; - } - } - return MSTensor(nullptr); -} -} // namespace cache -} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/delegate/parameter_cache/load_host_cache_model.h b/mindspore-lite/src/extendrt/delegate/parameter_cache/load_host_cache_model.h deleted file mode 100644 index 8340993e5b958e6c2d3d3ac39bf6a3af97dbc4fb..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/parameter_cache/load_host_cache_model.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_PARAMETER_CACHE_LOAD_HOST_CACHE_MODEL_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_PARAMETER_CACHE_LOAD_HOST_CACHE_MODEL_H_ - -#include -#include -#include "include/api/status.h" -#include "include/api/data_type.h" -#include "include/api/types.h" -#include "include/api/kernel.h" -#include "include/api/delegate.h" -#include "src/litert/lite_model.h" - -namespace mindspore { -namespace cache { -class HostCacheModel { - public: - HostCacheModel() = default; - ~HostCacheModel(); - Status LoadCache(const std::string &model_path); - Status LoadCache(DelegateModel *model); - bool CheckIsCacheKernel(kernel::Kernel *kernel); - MSTensor GetHostCacheTensor(kernel::Kernel *kernel); - - private: - std::map cache_tensor_; - mindspore::lite::LiteModel *cache_model_{nullptr}; - char *model_buf_{nullptr}; - size_t model_size_; -}; -} // namespace cache -} // namespace mindspore -#endif // MINDSPORE_LITE_EMBEDDING_CACHE_H_ diff --git a/mindspore-lite/src/extendrt/delegate/plugin/ascend_native_executor_plugin.cc b/mindspore-lite/src/extendrt/delegate/plugin/ascend_acl_executor_plugin.cc similarity index 47% rename from mindspore-lite/src/extendrt/delegate/plugin/ascend_native_executor_plugin.cc rename to mindspore-lite/src/extendrt/delegate/plugin/ascend_acl_executor_plugin.cc index 332fc6a6e20a61e7aaa08def14754bc167ab5016..b64551168f0630c9c1d631f0ed2af346d8dde98d 100644 --- a/mindspore-lite/src/extendrt/delegate/plugin/ascend_native_executor_plugin.cc +++ b/mindspore-lite/src/extendrt/delegate/plugin/ascend_acl_executor_plugin.cc @@ -1,5 +1,5 @@ /** - * Copyright 2023 Huawei Technologies Co., Ltd + * Copyright 2022-2023 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "extendrt/delegate/plugin/ascend_native_executor_plugin.h" +#include "extendrt/delegate/plugin/ascend_acl_executor_plugin.h" #include #include "src/common/log_adapter.h" #if !defined(_WIN32) @@ -22,56 +22,55 @@ namespace mindspore::lite { namespace { -constexpr auto kAscendNativePluginSoName = "libascend_native_plugin.so"; -constexpr auto kFunCreateAscendNativePluginImpl = "CreateAscendNativeExecutorPluginImpl"; +constexpr auto kAscendAclPluginSoName = "libascend_acl_plugin.so"; +constexpr auto kFunCreateAscendAclPluginImpl = "CreateAscendAclExecutorPluginImpl"; } // namespace -AscendNativeExecutorPlugin::AscendNativeExecutorPlugin() = default; -AscendNativeExecutorPlugin::~AscendNativeExecutorPlugin() { +AscendAclExecutorPlugin::AscendAclExecutorPlugin() = default; +AscendAclExecutorPlugin::~AscendAclExecutorPlugin() { #if !defined(_WIN32) - MS_LOG(DEBUG) << "~AscendNativeExecutorPlugin() begin."; - ascend_native_plugin_impl_ = nullptr; + MS_LOG(DEBUG) << "~AscendAclExecutorPlugin() begin."; + acl_plugin_impl_ = nullptr; DLSoClose(handle_); - MS_LOG(DEBUG) << "~AscendNativeExecutorPlugin() end."; + MS_LOG(DEBUG) << "~AscendAclExecutorPlugin() end."; #endif } -AscendNativeExecutorPlugin &AscendNativeExecutorPlugin::GetInstance() { - static AscendNativeExecutorPlugin instance; +AscendAclExecutorPlugin &AscendAclExecutorPlugin::GetInstance() { + static AscendAclExecutorPlugin instance; return instance; } -bool AscendNativeExecutorPlugin::Register() { +bool AscendAclExecutorPlugin::Register() { #if !defined(_WIN32) if (is_registered_) { return true; } auto ret = - DLSoPath({"libmindspore-lite.so", "_c_lite", "tools/converter/lib"}, kAscendNativePluginSoName, &plugin_path_); + DLSoPath({"libmindspore-lite.so", "_c_lite", "tools/converter/lib"}, kAscendAclPluginSoName, &plugin_path_); if (ret != kSuccess) { - MS_LOG(ERROR) << "Get real path of " << kAscendNativePluginSoName << " failed."; + MS_LOG(ERROR) << "Get real path of " << kAscendAclPluginSoName << " failed."; return false; } - MS_LOG(INFO) << "Find ascend ge plugin so success, path = " << plugin_path_; + MS_LOG(INFO) << "Find ascend acl plugin so success, path = " << plugin_path_; void *function = nullptr; - ret = DLSoOpen(plugin_path_, kFunCreateAscendNativePluginImpl, &handle_, &function); + ret = DLSoOpen(plugin_path_, kFunCreateAscendAclPluginImpl, &handle_, &function); if (ret != kSuccess) { MS_LOG(ERROR) << "DLSoOpen failed, so path: " << plugin_path_ << ", err: " << ret.ToString(); return false; } - auto create_plugin_impl_func = reinterpret_cast(function); + auto create_plugin_impl_func = reinterpret_cast(function); if (create_plugin_impl_func == nullptr) { - MS_LOG(ERROR) << "Cast " << kFunCreateAscendNativePluginImpl << " failed."; + MS_LOG(ERROR) << "Cast " << kFunCreateAscendAclPluginImpl << " failed."; return false; } - ascend_native_plugin_impl_ = std::shared_ptr(create_plugin_impl_func()); - if (ascend_native_plugin_impl_ == nullptr) { - MS_LOG(ERROR) << "Create Ascend native plugin implement failed."; + acl_plugin_impl_ = std::shared_ptr(create_plugin_impl_func()); + if (acl_plugin_impl_ == nullptr) { + MS_LOG(ERROR) << "Create Ascend acl plugin implement failed."; return false; } is_registered_ = true; - MS_LOG(INFO) << "Register Ascend native plugin success."; + MS_LOG(INFO) << "Register Ascend acl plugin success."; #endif return true; } - } // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/plugin/ascend_native_executor_plugin.h b/mindspore-lite/src/extendrt/delegate/plugin/ascend_acl_executor_plugin.h similarity index 52% rename from mindspore-lite/src/extendrt/delegate/plugin/ascend_native_executor_plugin.h rename to mindspore-lite/src/extendrt/delegate/plugin/ascend_acl_executor_plugin.h index 90d10a324af00515350f9738e86e4f0f60242516..58816fa88d89bcb428a76eca9efc436decf4b2c6 100644 --- a/mindspore-lite/src/extendrt/delegate/plugin/ascend_native_executor_plugin.h +++ b/mindspore-lite/src/extendrt/delegate/plugin/ascend_acl_executor_plugin.h @@ -1,5 +1,5 @@ /** - * Copyright 2023 Huawei Technologies Co., Ltd + * Copyright 2022-2023 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,38 +13,34 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_PLUGIN_ASCEND_NATIVE_EXECUTOR_PLUGIN_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_PLUGIN_ASCEND_NATIVE_EXECUTOR_PLUGIN_H_ +#ifndef MINDSPORE_LITE_SRC_EXTENDRT_PLUGIN_ASCEND_ACL_EXECUTOR_PLUGIN_H_ +#define MINDSPORE_LITE_SRC_EXTENDRT_PLUGIN_ASCEND_ACL_EXECUTOR_PLUGIN_H_ #include #include #include "include/api/context.h" #include "include/api/status.h" -#include "src/common/log_adapter.h" -#include "mindapi/base/macros.h" -#include "base/base.h" -#include "common/config_infos.h" namespace mindspore::lite { -class AscendNativeExecutorPluginImplBase { +class AscendAclExecutorPluginImplBase { public: - AscendNativeExecutorPluginImplBase() = default; - virtual ~AscendNativeExecutorPluginImplBase() = default; + AscendAclExecutorPluginImplBase() = default; + virtual ~AscendAclExecutorPluginImplBase() = default; }; -class MS_API AscendNativeExecutorPlugin { +class MS_API AscendAclExecutorPlugin { public: - static AscendNativeExecutorPlugin &GetInstance(); + static AscendAclExecutorPlugin &GetInstance(); bool Register(); private: - AscendNativeExecutorPlugin(); - ~AscendNativeExecutorPlugin(); + AscendAclExecutorPlugin(); + ~AscendAclExecutorPlugin(); std::string plugin_path_; void *handle_ = nullptr; bool is_registered_ = false; - std::shared_ptr ascend_native_plugin_impl_ = nullptr; + std::shared_ptr acl_plugin_impl_ = nullptr; }; } // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_PLUGIN_ASCEND_NATIVE_EXECUTOR_PLUGIN_H_ +#endif // MINDSPORE_LITE_SRC_EXTENDRT_PLUGIN_ASCEND_ACL_EXECUTOR_PLUGIN_H_ diff --git a/mindspore-lite/src/extendrt/delegate/plugin/tensorrt_executor_plugin.cc b/mindspore-lite/src/extendrt/delegate/plugin/tensorrt_executor_plugin.cc deleted file mode 100644 index 1cab6163131c875fc18e1232b8f2afc60daa6315..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/plugin/tensorrt_executor_plugin.cc +++ /dev/null @@ -1,99 +0,0 @@ -/** - * Copyright 2019-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "extendrt/delegate/plugin/tensorrt_executor_plugin.h" -#include -#include "src/common/log_adapter.h" -#if !defined(_WIN32) -#include "extendrt/cxx_api/dlutils.h" -#endif - -namespace mindspore::lite { -namespace { -constexpr auto kTensorRtPluginSoName = "libtensorrt_plugin.so"; -constexpr auto kFunCreateTRTPluginImp = "CreateTensorRTPluginImpl"; -} // namespace -TensorRTExecutorPlugin::TensorRTExecutorPlugin() = default; -TensorRTExecutorPlugin::~TensorRTExecutorPlugin() { -#if !defined(_WIN32) - MS_LOG(DEBUG) << "~TensorRTExecutorPlugin() begin."; - DLSoClose(handle_); - MS_LOG(DEBUG) << "~TensorRTExecutorPlugin() end."; -#endif -} - -TensorRTExecutorPlugin &TensorRTExecutorPlugin::GetInstance() { - static TensorRTExecutorPlugin instance; - return instance; -} - -bool TensorRTExecutorPlugin::Register() { - auto status = TryRegister(); - if (status.IsError()) { - MS_LOG(ERROR) << status.ToString(); - return false; - } - MS_LOG(INFO) << "Register tensorrt plugin success."; - return true; -} - -Status TensorRTExecutorPlugin::TryRegister() { -#if !defined(_WIN32) - if (is_registered_) { - return kSuccess; - } - std::string plugin_path; - auto ret = DLSoPath({"libmindspore-lite.so", "_c_lite"}, kTensorRtPluginSoName, &plugin_path); - if (ret != kSuccess) { - return {kLiteError, std::string("Get real path of ") + kTensorRtPluginSoName + " failed."}; - } - void *function = nullptr; - ret = DLSoOpen(plugin_path, kFunCreateTRTPluginImp, &handle_, &function); - if (ret != kSuccess) { - return {kLiteError, "DLSoOpen failed, so path: " + plugin_path}; - } - auto create_kernel_func = reinterpret_cast(function); - if (create_kernel_func == nullptr) { - return {kLiteError, std::string("Cast ") + kFunCreateTRTPluginImp + " failed."}; - } - auto plugin_impl = create_kernel_func(); - if (plugin_impl == nullptr) { - return {kLiteError, "Create custom TensorRT kernel failed."}; - } - group_size_ = plugin_impl->GetGPUGroupSize(); - rank_id_ = plugin_impl->GetRankID(); - is_registered_ = true; -#endif - return kSuccess; -} - -int TensorRTExecutorPlugin::GetGPUGroupSize() { -#ifdef SUPPORT_TENSORRT - if (!is_registered_) { - Register(); - } -#endif - return group_size_; -} - -int TensorRTExecutorPlugin::GetRankID() { -#ifdef SUPPORT_TENSORRT - if (!is_registered_) { - Register(); - } -#endif - return rank_id_; -} -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/plugin/tensorrt_executor_plugin.h b/mindspore-lite/src/extendrt/delegate/plugin/tensorrt_executor_plugin.h deleted file mode 100644 index 823252e260d7d80dff1ddc9aa5c0b541af5e3cc2..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/plugin/tensorrt_executor_plugin.h +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2019-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_TENSORRT_EXECUTOR_PLUGIN_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_TENSORRT_EXECUTOR_PLUGIN_H_ -#include "include/api/status.h" -#include "src/common/log_adapter.h" -#include "mindapi/base/macros.h" - -namespace mindspore::lite { -class MS_API TensorRTExecutorPlugin { - public: - static TensorRTExecutorPlugin &GetInstance(); - bool Register(); - Status TryRegister(); - - int GetGPUGroupSize(); - int GetRankID(); - - private: - TensorRTExecutorPlugin(); - ~TensorRTExecutorPlugin(); - - void *handle_ = nullptr; - bool is_registered_ = false; - int group_size_ = 1; - int rank_id_ = 0; -}; - -class TensorRTExecutorPluginImplBase { - public: - TensorRTExecutorPluginImplBase() = default; - virtual ~TensorRTExecutorPluginImplBase() = default; - virtual int GetGPUGroupSize() const = 0; - virtual int GetRankID() const = 0; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_TENSORRT_EXECUTOR_PLUGIN_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/CMakeLists.txt b/mindspore-lite/src/extendrt/delegate/tensorrt/CMakeLists.txt deleted file mode 100644 index 452b434288339734c3342f27cdbb8f5200935b41..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/CMakeLists.txt +++ /dev/null @@ -1,109 +0,0 @@ -set(CUDA_PATH $ENV{CUDA_HOME}) -include_directories(${OPS_DIR}/kernel/gpu) -set(CUDA_VERSION 11.1) -set(CUDA_LIB_PATH ${CUDA_PATH}/lib64) -include_directories(${CUDA_PATH}) -include_directories(${CUDA_PATH}/include) -find_package(CUDA) - -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ORIGIN/") -add_compile_definitions(GPU_TENSORRT) -set(TENSORRT_PATH $ENV{TENSORRT_PATH}) -set(TENSORRT_LIB_PATH ${TENSORRT_PATH}/lib) -include_directories(${TENSORRT_PATH}/include) - -include_directories(${OPS_DIR}/kernel/cpu) -include_directories(${CCSRC_DIR}/../) -include_directories(${OPS_DIR}/kernel/gpu/cuda_impl/cuda_ops) - -if(DEFINED ENV{MS_ENABLE_CUDA_DISTRIBUTION}) - set(MS_ENABLE_CUDA_DISTRIBUTION $ENV{MS_ENABLE_CUDA_DISTRIBUTION}) -else() - set(MS_ENABLE_CUDA_DISTRIBUTION "off") -endif() - -set(NCCL_MPI_SRC_STUB - ${CMAKE_CURRENT_SOURCE_DIR}/distribution/distribution_collective.cc - ${CMAKE_CURRENT_SOURCE_DIR}/distribution/distribution_base.cc - ) - -# nccl mpi -if(MS_ENABLE_CUDA_DISTRIBUTION STREQUAL "on") - message("enable cuda gpu distribution collective") - file(GLOB NCCL_MPI_SRC LIST_DIRECTORIES false - ${CMAKE_CURRENT_SOURCE_DIR}/distribution/*.cc - ${CCSRC_DIR}/plugin/device/gpu/hal/device/distribution/collective_wrapper.cc - ${CCSRC_DIR}/plugin/device/gpu/hal/device/distribution/mpi_wrapper.cc - ${CCSRC_DIR}/plugin/device/gpu/hal/device/distribution/nccl_wrapper.cc - ) - list(REMOVE_ITEM NCCL_MPI_SRC ${NCCL_MPI_SRC_STUB}) - - add_compile_definitions(LITE_CUDA_DISTRIBUTION) - include(${TOP_DIR}/cmake/external_libs/ompi.cmake) - include(${TOP_DIR}/cmake/external_libs/nccl.cmake) - - add_library(gpu_distribution_collective OBJECT ${NCCL_MPI_SRC}) - add_library(mindspore::nccl ALIAS nccl::nccl) - add_library(mindspore::ompi ALIAS ompi::mpi) - target_link_libraries(gpu_distribution_collective PRIVATE mindspore::ompi mindspore::nccl) -else() - add_library(gpu_distribution_collective OBJECT ${NCCL_MPI_SRC_STUB}) -endif() -add_dependencies(gpu_distribution_collective fbs_src) - -file(GLOB TENSORRT_RUNTIME_SRC LIST_DIRECTORIES false - ${CMAKE_CURRENT_SOURCE_DIR}/*.cc - ${CMAKE_CURRENT_SOURCE_DIR}/op/*.cc - ${CMAKE_CURRENT_SOURCE_DIR}/optimizer/*.cc - ${CMAKE_CURRENT_SOURCE_DIR}/cuda_impl/*.cc - ${CMAKE_CURRENT_SOURCE_DIR}/../../../extendrt/delegate/delegate_utils.cc - ${OPS_DIR}/kernel/gpu/cuda_impl/cuda_ops/cuda_device_info.cc - ${OPS_DIR}/kernel/cpu/nnacl/nnacl_common.c - ${TOP_DIR}/mindspore-lite/src/common/file_utils.cc - ) - -# include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../parameter_cache) - -#set(TENSORRT_RUNTIME_SRC -# ${TENSORRT_RUNTIME_SRC} -# ${CMAKE_CURRENT_SOURCE_DIR}/../parameter_cache/embedding_cache_manager.cc -# ${CMAKE_CURRENT_SOURCE_DIR}/../parameter_cache/load_host_cache_model.cc -# ${CMAKE_CURRENT_SOURCE_DIR}/../parameter_cache/lfu_cache.cc -# ${CMAKE_CURRENT_SOURCE_DIR}/../parameter_cache/embedding_cache.cc -# ${CMAKE_CURRENT_SOURCE_DIR}/../parameter_cache/gpu/gpu_cache_mem.cc -# ) - -link_libraries(${CUDA_LIB_PATH}/libcudnn.so) -link_libraries(${CUDA_LIB_PATH}/libcublasLt.so) - -add_library(libcudart SHARED IMPORTED) -set_target_properties(libcudart PROPERTIES IMPORTED_LOCATION ${CUDA_LIB_PATH}/libcudart.so) - -add_library(libnvinfer SHARED IMPORTED) -set_target_properties(libnvinfer PROPERTIES IMPORTED_LOCATION ${TENSORRT_LIB_PATH}/libnvinfer.so) - -add_library(libcublas SHARED IMPORTED) -set_target_properties(libcublas PROPERTIES IMPORTED_LOCATION ${CUDA_LIB_PATH}/libcublas.so) -add_library(tensorrt_plugin SHARED ${TENSORRT_RUNTIME_SRC}) - -add_dependencies(tensorrt_plugin fbs_src) - -target_link_libraries( - tensorrt_plugin - libcudart - libcublas - libnvinfer -) -if(SUPPORT_TENSORRT AND (MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)) - add_library(libcuda SHARED IMPORTED) - set_target_properties(libcuda PROPERTIES IMPORTED_LOCATION ${CUDA_LIB_PATH}/stubs/libcuda.so) - target_link_libraries( - tensorrt_plugin - libcuda - ) -endif() - -add_subdirectory(cuda_impl) - -target_link_libraries(tensorrt_plugin cuda_kernel_mid gpu_distribution_collective) -target_link_libraries(tensorrt_plugin mindspore-extendrt mindspore_core mindspore_ops) diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/CMakeLists.txt b/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/CMakeLists.txt deleted file mode 100644 index 280296aa2bb54f540ce1b3e2118af30e015fa050..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/CMakeLists.txt +++ /dev/null @@ -1,25 +0,0 @@ -#cuda -find_package(CUDA) -add_compile_definitions(ENABLE_GPU) -file(GLOB_RECURSE CUDA_KERNEL_SRC - ${CMAKE_CURRENT_SOURCE_DIR}/*.cu - ${OPS_DIR}/kernel/gpu/cuda_impl/cuda_ops/gatherd.cu - ${OPS_DIR}/kernel/gpu/cuda_impl/cuda_ops/swish_impl.cu - ${OPS_DIR}/kernel/gpu/cuda_impl/cuda_ops/cumsum_impl.cu - ${OPS_DIR}/kernel/gpu/cuda_impl/cuda_ops/batchtospace_impl.cu - ${OPS_DIR}/kernel/gpu/cuda_impl/cuda_ops/spacetobatch_impl.cu - ${OPS_DIR}/kernel/gpu/cuda_impl/cuda_ops/depthtospace_impl.cu - ${OPS_DIR}/kernel/gpu/cuda_impl/cuda_ops/select_impl.cu - ${OPS_DIR}/kernel/gpu/cuda_impl/cuda_ops/maxpool_with_argmax_impl.cu - ${OPS_DIR}/kernel/gpu/cuda_impl/cuda_ops/roi_align_impl.cu - ${OPS_DIR}/kernel/gpu/cuda_impl/cuda_ops/nms_with_mask_impl.cu - ${OPS_DIR}/kernel/gpu/cuda_impl/cuda_ops/boundingbox_decode_impl.cu - ${OPS_DIR}/kernel/gpu/cuda_impl/cuda_ops/where_impl.cu - ${OPS_DIR}/kernel/gpu/cuda_impl/cuda_ops/one_hot_impl.cu - ${OPS_DIR}/kernel/gpu/cuda_impl/cuda_ops/tensor_scatter_arithmetic.cu - ) - -set_source_files_properties(${CUDA_KERNEL_SRC} PROPERTIES CUDA_SOURCE_PROPERTY_FORMAT OBJ) -SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGES} -fPIC") -SET(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-arch=sm_53) -cuda_add_library(cuda_kernel_mid STATIC ${CUDA_KERNEL_SRC}) diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/activation.cu b/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/activation.cu deleted file mode 100644 index 29c4424dea0ccf02d343f92dec298126208a6607..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/activation.cu +++ /dev/null @@ -1,56 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/cuda_impl/activation.cuh" -#include -#include -#include "src/extendrt/delegate/tensorrt/cuda_impl/cuda_helper.h" - -template -__global__ void SigmoidKernel(const T *input1, T *output, int element_cnt) { - for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < element_cnt; pos += blockDim.x * gridDim.x) { - output[pos] = static_cast(1) / (static_cast(1) + exp(-input1[pos])); - } -} - -template -__global__ void GeluKernel(const T *input_addr, T *output_addr, int size) { - // formula: - // gelu(x) = 0.5 * x * (1.0 + tanh(y)) - // tanh(y) = 2 / (1 + exp(-2y)) - 1) - // y = sqrt(2/pi) * (x + 0.044715 * x^3) - for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - float x = input_addr[pos]; - float tanh_res = tanh(0.7978845608f * (x + 0.044715f * x * x * x)); - output_addr[pos] = 0.5f * x * (1.0f + tanh_res); - } -} - -template -void Sigmoid(const T *input1, T *output, int element_cnt, cudaStream_t stream) { - SigmoidKernel<<>>(input1, output, element_cnt); - return; -} - -template -void Gelu(const T *input1, T *output, int element_cnt, cudaStream_t stream) { - GeluKernel<<>>(input1, output, element_cnt); - return; -} - -template void Sigmoid(const float *input1, float *output, int element_cnt, cudaStream_t stream); - -template void Gelu(const float *input1, float *output, int element_cnt, cudaStream_t stream); diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/cast.cu b/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/cast.cu deleted file mode 100644 index 438b527782b7d3665738d618ccbb23b818b56012..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/cast.cu +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/cuda_impl/cast.cuh" -#include "src/extendrt/delegate/tensorrt/cuda_impl/cuda_helper.h" - -// Generic cast -template -__device__ __forceinline__ void CastBase(const S *input_addr, T *output_addr) { - *output_addr = static_cast((*input_addr)); -} - -template -__global__ void CastKernel(const int input_size, const S *input_addr, T *output_addr) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < input_size; pos += blockDim.x * gridDim.x) { - CastBase(input_addr + pos, output_addr + pos); - } -} - -template -void Cast(const int input_size, const S *input_addr, T *output_addr, cudaStream_t stream) { - CastKernel<<>>(input_size, input_addr, output_addr); -} - -template void Cast(const int input_size, const int8_t *input_addr, int8_t *output_addr, cudaStream_t stream); -template void Cast(const int input_size, const int8_t *input_addr, int32_t *output_addr, cudaStream_t stream); -template void Cast(const int input_size, const int8_t *input_addr, float *output_addr, cudaStream_t stream); - -template void Cast(const int input_size, const int32_t *input_addr, int8_t *output_addr, cudaStream_t stream); -template void Cast(const int input_size, const int32_t *input_addr, int32_t *output_addr, cudaStream_t stream); -template void Cast(const int input_size, const int32_t *input_addr, float *output_addr, cudaStream_t stream); -template void Cast(const int input_size, const int32_t *input_addr, bool *output_addr, cudaStream_t stream); - -template void Cast(const int input_size, const float *input_addr, int8_t *output_addr, cudaStream_t stream); -template void Cast(const int input_size, const float *input_addr, int32_t *output_addr, cudaStream_t stream); -template void Cast(const int input_size, const float *input_addr, float *output_addr, cudaStream_t stream); diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/cublas_utils.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/cublas_utils.cc deleted file mode 100644 index 9a5a98fd0462491304102c66894b226ea3ccd6dd..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/cublas_utils.cc +++ /dev/null @@ -1,168 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/cuda_impl/cublas_utils.h" - -namespace mindspore::lite { -void Cublas2DTranspose(const float *in_addr, float *out_addr, const int *params, cublasHandle_t cublas_handle) { - const int m = params[0]; - const int n = params[1]; - const float alpha = 1.0f; - const float beta = 0.0f; - CUBLAS_CHECK_VOID( - cublasSgeam(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, m, n, &alpha, in_addr, n, &beta, out_addr, m, out_addr, m)); -} - -void CublasMM1Batch(const void *a_addr, const void *b_addr, void *c_addr, const int *params, - const cublasOperation_t *operations, const cudaDataType *data_types, cublasHandle_t cublas_handle) { - const int m = params[0]; - const int n = params[1]; - const int k = params[2]; - cublasOperation_t trans_a = operations[0]; - cublasOperation_t trans_b = operations[1]; - const int lda = (trans_a == CUBLAS_OP_N) ? k : m; - const int ldb = (trans_b == CUBLAS_OP_N) ? n : k; - const int ldc = n; - cudaDataType type_a = data_types[0]; - cudaDataType type_b = data_types[1]; - cudaDataType type_c = data_types[2]; - cudaDataType compute_type = data_types[3]; - const float alpha = 1.0f; - const float beta = 0.0f; - CUBLAS_CHECK_VOID(cublasGemmEx(cublas_handle, trans_b, trans_a, n, m, k, &alpha, b_addr, type_b, ldb, a_addr, type_a, - lda, &beta, c_addr, type_c, ldc, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); -} - -void CublasMMBatched(void **a_addrs, void **b_addrs, void **c_addrs, const int *params, - const cublasOperation_t *operations, const cudaDataType *data_types, - cublasHandle_t cublas_handle) { - cublasOperation_t trans_a = operations[0]; - cublasOperation_t trans_b = operations[1]; - const int m = params[0]; - const int n = params[1]; - const int k = params[2]; - const int batch = params[3]; - const int lda = (trans_a == CUBLAS_OP_N) ? k : m; - const int ldb = (trans_b == CUBLAS_OP_N) ? n : k; - const int ldc = n; - cudaDataType type_a = data_types[0]; - cudaDataType type_b = data_types[1]; - cudaDataType type_c = data_types[2]; - cudaDataType compute_type = data_types[3]; - const float alpha = 1.0f; - const float beta = 0.0f; - CUBLAS_CHECK_VOID(cublasGemmBatchedEx(cublas_handle, trans_b, trans_a, n, m, k, &alpha, b_addrs, type_b, ldb, a_addrs, - type_a, lda, &beta, c_addrs, type_c, ldc, batch, compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); -} - -void CublasGemmWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params, const int *lds, - const cublasOperation_t *operations, const cudaDataType *data_types, void *alpha, void *beta, - cublasHandle_t cublas_handle, cublasGemmAlgo_t algo) { - const int m = params[0]; - const int n = params[1]; - const int k = params[2]; - cublasOperation_t trans_a = operations[0]; - cublasOperation_t trans_b = operations[1]; - const int lda = lds[0]; - const int ldb = lds[1]; - const int ldc = lds[2]; - cudaDataType type_a = data_types[0]; - cudaDataType type_b = data_types[1]; - cudaDataType type_c = data_types[2]; - cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; - if ((type_a == CUDA_R_16F) && (type_b == CUDA_R_16F) && (type_c == CUDA_R_16F)) { - compute_type = CUBLAS_COMPUTE_16F; - } - CUBLAS_CHECK_VOID(cublasGemmEx(cublas_handle, trans_a, trans_b, m, n, k, alpha, a_addr, type_a, lda, b_addr, type_b, - ldb, beta, c_addr, type_c, ldc, compute_type, algo)); -} - -void CublasGemmStridedBatchedWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params, - const int *lds, const cublasOperation_t *operations, const int *strides, - const cudaDataType *data_types, void *alpha, void *beta, int batch, - cublasHandle_t cublas_handle, cublasGemmAlgo_t algo) { - const int m = params[0]; - const int n = params[1]; - const int k = params[2]; - cublasOperation_t trans_a = operations[0]; - cublasOperation_t trans_b = operations[1]; - const int lda = lds[0]; - const int ldb = lds[1]; - const int ldc = lds[2]; - cudaDataType type_a = data_types[0]; - cudaDataType type_b = data_types[1]; - cudaDataType type_c = data_types[2]; - cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; - if ((type_a == CUDA_R_16F) && (type_b == CUDA_R_16F) && (type_c == CUDA_R_16F)) { - compute_type = CUBLAS_COMPUTE_16F; - } - const int stride_a = strides[0]; - const int stride_b = strides[1]; - const int stride_c = strides[2]; - - CUBLAS_CHECK_VOID(cublasGemmStridedBatchedEx(cublas_handle, trans_a, trans_b, m, n, k, alpha, a_addr, type_a, lda, - stride_a, b_addr, type_b, ldb, stride_b, beta, c_addr, type_c, ldc, - stride_c, batch, compute_type, algo)); -} - -void CublasLtGemmWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params, const int *lds, - const cublasOperation_t *operations, const cudaDataType *data_types, void *alpha, void *beta, - const void *bias, cudaStream_t stream, cublasLtHandle_t cublaslt_handle) { - cublasOperation_t trans_a = operations[0]; - cublasOperation_t trans_b = operations[1]; - cudaDataType type_a = data_types[0]; - cudaDataType type_b = data_types[1]; - cudaDataType type_c = data_types[2]; - const int m = params[0]; - const int n = params[1]; - const int k = params[2]; - const int lda = lds[0]; - const int ldb = lds[1]; - const int ldc = lds[2]; - - cublasLtMatrixLayout_t mat_a_desc = NULL; - cublasLtMatrixLayoutCreate(&mat_a_desc, type_a, (trans_a == CUBLAS_OP_N) ? m : k, (trans_a == CUBLAS_OP_N) ? k : m, - lda); - cublasLtMatrixLayout_t mat_b_desc = NULL; - cublasLtMatrixLayoutCreate(&mat_b_desc, type_b, (trans_b == CUBLAS_OP_N) ? k : n, (trans_b == CUBLAS_OP_N) ? n : k, - ldb); - cublasLtMatrixLayout_t mat_c_desc = NULL; - cublasLtMatrixLayoutCreate(&mat_c_desc, type_c, m, n, ldc); - - cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; - if ((type_a == CUDA_R_16F) && (type_b == CUDA_R_16F) && (type_c == CUDA_R_16F)) { - compute_type = CUBLAS_COMPUTE_16F; - } - - cublasLtMatmulDesc_t mat_operation_desc = NULL; - cublasLtMatmulDescCreate(&mat_operation_desc, compute_type, type_a); - cublasLtMatmulDescSetAttribute(mat_operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_a, sizeof(cublasOperation_t)); - cublasLtMatmulDescSetAttribute(mat_operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_b, sizeof(cublasOperation_t)); - if (bias != nullptr) { - cublasLtEpilogue_t epi = CUBLASLT_EPILOGUE_BIAS; - cublasLtMatmulDescSetAttribute(mat_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epi, sizeof(cublasLtEpilogue_t)); - cublasLtMatmulDescSetAttribute(mat_operation_desc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(const void *)); - } - - cublasLtMatmul(cublaslt_handle, mat_operation_desc, alpha, a_addr, mat_a_desc, b_addr, mat_b_desc, beta, c_addr, - mat_c_desc, c_addr, mat_c_desc, NULL, NULL, 0, stream); - cublasLtMatrixLayoutDestroy(mat_a_desc); - cublasLtMatrixLayoutDestroy(mat_b_desc); - cublasLtMatrixLayoutDestroy(mat_c_desc); - cublasLtMatmulDescDestroy(mat_operation_desc); -} -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/cublas_utils.h b/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/cublas_utils.h deleted file mode 100644 index 528daf8fafbd1a2ecc2ddf89ec2d092655fe5613..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/cublas_utils.h +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_CUDA_IMPL_CUBLAS_UTILS_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_CUDA_IMPL_CUBLAS_UTILS_H_ - -#include -#include -#include "src/extendrt/delegate/tensorrt/cuda_impl/cuda_helper.h" -#include "src/common/log_util.h" - -// cublas API error checking -#define CUBLAS_CHECK_VOID(err) \ - do { \ - cublasStatus_t cublas_err = (err); \ - if (cublas_err != CUBLAS_STATUS_SUCCESS) { \ - MS_LOG(ERROR) << "cublas error " << cublas_err; \ - return; \ - } \ - } while (0) - -#define CUBLAS_CHECK(err) \ - do { \ - cublasStatus_t cublas_err = (err); \ - if (cublas_err != CUBLAS_STATUS_SUCCESS) { \ - MS_LOG(ERROR) << "cublas error " << cublas_err; \ - return -1; \ - } \ - } while (0) - -namespace mindspore::lite { -// a: m * n -// params order: m, n -void Cublas2DTranspose(const float *in_addr, float *out_addr, const int *params, cublasHandle_t cublas_handle); - -// a: m * k, b: k * n, c: m * n -// params order: m, n, k -// operations order: trans_a, trans_b -// data_types: type_a, type_b, type_c, compute type -void CublasMM1Batch(const void *a_addr, const void *b_addr, void *c_addr, const int *params, - const cublasOperation_t *operations, const cudaDataType *data_types, cublasHandle_t cublas_handle); - -// a: batch * m * k, b: batch * k * n, c: batch * m * n -// params order: m, n, k, batch -// operations order: trans_a, trans_b -// data_types: type_a, type_b, type_c, compute type -void CublasMMBatched(void **a_addrs, void **b_addrs, void **c_addrs, const int *params, - const cublasOperation_t *operations, const cudaDataType *data_types, cublasHandle_t cublas_handle); - -void CublasGemmWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params, const int *lds, - const cublasOperation_t *operations, const cudaDataType *data_types, void *alpha, void *beta, - cublasHandle_t cublas_handle, cublasGemmAlgo_t = CUBLAS_GEMM_DEFAULT_TENSOR_OP); -void CublasGemmStridedBatchedWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params, - const int *lds, const cublasOperation_t *operations, const int *strides, - const cudaDataType *data_types, void *alpha, void *beta, int batch, - cublasHandle_t cublas_handle, cublasGemmAlgo_t = CUBLAS_GEMM_DEFAULT_TENSOR_OP); - -void CublasLtGemmWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params, const int *lds, - const cublasOperation_t *operations, const cudaDataType *data_types, void *alpha, void *beta, - const void *bias, cudaStream_t stream, cublasLtHandle_t cublaslt_handle); -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_CUDA_IMPL_CUBLAS_UTILS_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/cuda_helper.h b/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/cuda_helper.h deleted file mode 100644 index 8bafbaa52614761ec4a3d37ff8d5e089333d005e..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/cuda_helper.h +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_CUDA_IMPL_CUDA_HELPER_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_CUDA_IMPL_CUDA_HELPER_H_ - -#include -#include - -class CudaHelper { - public: - int GetThreadNum() const; - int GetThreadNum(const int block_size) const; - int GetBlocksNum(const int total_threads) const; - int GetBlocksNum(const int total_threads, const int block_size) const; - static CudaHelper &GetInstance(); - - private: - CudaHelper(); - ~CudaHelper() = default; - CudaHelper(const CudaHelper &) = delete; - CudaHelper &operator=(const CudaHelper &) = delete; - - int max_blocks_; - int threads_per_block_; -}; - -#ifndef GET_BLOCKS -#define GET_BLOCKS(total_threads) CudaHelper::GetInstance().GetBlocksNum(total_threads) -#endif -#define GET_BLOCKS_CAL(total_threads, block_size) CudaHelper::GetInstance().GetBlocksNum(total_threads, block_size) -#ifndef GET_THREADS -#define GET_THREADS CudaHelper::GetInstance().GetThreadNum() -#endif -#define GET_THREADS_CAL(block_size) CudaHelper::GetInstance().GetThreadNum(block_size) - -#define CUDA_CHECK(ret) \ - do { \ - cudaError_t cuda_ret = (ret); \ - if ((cuda_ret) != cudaSuccess) { \ - return -1; \ - } \ - } while (0) - -#define CUDA_CHECK_VOID(ret) \ - do { \ - cudaError_t cuda_ret = (ret); \ - if ((cuda_ret) != cudaSuccess) { \ - return; \ - } \ - } while (0) - -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_CUDA_IMPL_CUDA_HELPER_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/cudnn_utils.h b/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/cudnn_utils.h deleted file mode 100644 index 3e65d7bc5275df59a962e5dd7ecb4b987aed8603..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/cudnn_utils.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_CUDA_IMPL_CUDNN_UTILS_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_CUDA_IMPL_CUDNN_UTILS_H_ - -#include -#include -#include "src/extendrt/delegate/tensorrt/cuda_impl/cuda_helper.h" -#include "src/common/log_util.h" - -#define CUDNN_CHECK_VOID(err) \ - do { \ - cudnnStatus_t cudnn_err = (err); \ - if (cudnn_err != CUDNN_STATUS_SUCCESS) { \ - MS_LOG(ERROR) << "cudnn error " << cudnnGetErrorString(cudnn_err); \ - return; \ - } \ - } while (0) - -#define CUDNN_CHECK(err) \ - do { \ - cudnnStatus_t cudnn_err = (err); \ - if (cudnn_err != CUDNN_STATUS_SUCCESS) { \ - MS_LOG(ERROR) << "cudnn error " << cudnnGetErrorString(cudnn_err); \ - return -1; \ - } \ - } while (0) -namespace mindspore::lite { -cudnnDataType_t ConvertCudnnDataType(nvinfer1::DataType trt_datatype); - -int CudnnActivation(cudnnHandle_t handle, cudnnActivationDescriptor_t activation_desc, - const cudnnTensorDescriptor_t x_esc, const void *x, const cudnnTensorDescriptor_t y_dsc, void *y); -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_CUDA_IMPL_CUDNN_UTILS_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/equal.cuh b/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/equal.cuh deleted file mode 100644 index 69551308a9781f5693a92d333f9fd9536f9f5d51..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/equal.cuh +++ /dev/null @@ -1,23 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_CDUA_IMPL_EQUAL_H_ -#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_CDUA_IMPL_EQUAL_H_ - -template -void Equal(const T *input1, const T *input2, T *output, int element_cnt, cudaStream_t stream); - -#endif // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_CDUA_IMPL_EQUAL_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/fse_decode.cu b/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/fse_decode.cu deleted file mode 100644 index e5150eda5c8d97771a6d6ad239b7b318fc8342e0..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/fse_decode.cu +++ /dev/null @@ -1,95 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "src/extendrt/delegate/tensorrt/cuda_impl/fse_decode.cuh" -#include -#include -#include -#include "nnacl/op_base.h" - -__device__ __forceinline__ uint64_t Pop(const uint64_t *chunks, uint64_t *curr_chunk, uint8_t bit_count, - int32_t *curr_bit_count, int32_t *curr_chunk_index) { - const int kMaxBitCount = 64; - uint64_t right = *curr_chunk >> static_cast(kMaxBitCount - *curr_bit_count); - uint64_t res = right & ((1u << bit_count) - 1); - *curr_bit_count -= static_cast(bit_count); - if (*curr_bit_count > 0) { - return res; - } - if (*curr_bit_count == 0) { - if (*curr_chunk_index > -1) { - *curr_bit_count = kMaxBitCount; - *curr_chunk = chunks[(*curr_chunk_index)--]; - } - return res; - } - *curr_bit_count += static_cast(bit_count); - *curr_chunk = chunks[(*curr_chunk_index)--]; - right |= (*curr_chunk & ((1u << (static_cast(bit_count) - *curr_bit_count)) - 1)) << *curr_bit_count; - *curr_bit_count = kMaxBitCount - (static_cast(bit_count) - *curr_bit_count); - return right; -} - -template -__global__ void FSE_Decode_kernel(const uint64_t *chunks, const uint16_t *states_table, const uint8_t *bit_count_table, - const uint16_t *symbol_table, const uint64_t *ptable, int ptable_size, - const T *centroids, uint64_t out_size, T *output, const uint64_t current_chunk_input, - bool use_curr_chunk) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= ptable_size) { - return; - } - int32_t curr_chunk_index = static_cast(ptable[idx] >> 32); - uint16_t state = static_cast((ptable[idx] >> 16) & 0xffff); - int32_t curr_bit_count = static_cast(ptable[idx] & 0xffff); - if (curr_bit_count == 0) { - curr_chunk_index--; - curr_bit_count = 64; - } - uint64_t curr_chunk = - ((idx == ptable_size - 1) && (use_curr_chunk)) ? current_chunk_input : chunks[curr_chunk_index + 1]; - const uint64_t output_offset = out_size * idx / ptable_size; - uint64_t out_count = out_size * (idx + 1) / ptable_size - output_offset; - T *output_ptr = output + output_offset; - while ((curr_chunk_index >= 0) || (bit_count_table[state] == 0) || (curr_bit_count > 0)) { - if (out_count == 0) { - break; - } - output_ptr[--out_count] = centroids[symbol_table[state]]; - state = states_table[state] + Pop(chunks, &curr_chunk, bit_count_table[state], &curr_bit_count, &curr_chunk_index); - } -} - -template -void FSE_Decode(const uint64_t *chunks, const uint16_t *states_table, const uint8_t *bit_count_table, - const uint16_t *symbol_table, const uint64_t *ptable, int ptable_size, const T *centroids, - uint64_t out_size, T *output, const uint32_t &device_id, uint64_t current_chunk_input, - bool use_curr_chunk, cudaStream_t cuda_stream) { - const int kThreads = 256; - const int kBlocks = UP_DIV(ptable_size, kThreads); - FSE_Decode_kernel<<>>(chunks, states_table, bit_count_table, symbol_table, ptable, - ptable_size, centroids, out_size, output, - current_chunk_input, use_curr_chunk); -} - -template void FSE_Decode(const uint64_t *chunks, const uint16_t *states_table, const uint8_t *bit_count_table, - const uint16_t *symbol_table, const uint64_t *ptable, int ptable_size, const float *centroids, - uint64_t out_size, float *output, const uint32_t &device_id, uint64_t current_chunk_input, - bool use_curr_chunk, cudaStream_t cuda_stream); - -template void FSE_Decode(const uint64_t *chunks, const uint16_t *states_table, const uint8_t *bit_count_table, - const uint16_t *symbol_table, const uint64_t *ptable, int ptable_size, const half *centroids, - uint64_t out_size, half *output, const uint32_t &device_id, uint64_t current_chunk_input, - bool use_curr_chunk, cudaStream_t cuda_stream); diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/hash.cu b/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/hash.cu deleted file mode 100755 index a2f19576abf8a76b7ed8555b0cc6cf1e30432011..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/hash.cu +++ /dev/null @@ -1,64 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/cuda_impl/hash.cuh" -#include "src/extendrt/delegate/tensorrt/cuda_impl/cuda_helper.h" - -template -__global__ void HashSwapOut(const T *hash_table, T *swap_out_value, const int *swap_out_index, const int index_size, - const int hash_dim) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < index_size; i += blockDim.x * gridDim.x) { - int hash_index = swap_out_index[i]; - for (int j = 0; j < hash_dim; j++) { - swap_out_value[i * hash_dim + j] = hash_table[hash_index * hash_dim + j]; - } - } - return; -} - -template -__global__ void HashSwapIn(T *hash_table, const T *swap_in_value, const int *swap_in_index, const int index_size, - const int hash_dim) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < index_size; i += blockDim.x * gridDim.x) { - int hash_index = swap_in_index[i]; - for (int j = 0; j < hash_dim; j++) { - hash_table[hash_index * hash_dim + j] = swap_in_value[i * hash_dim + j]; - } - } - return; -} - -template -void DoHashSwapOut(const T *hash_table, T *swap_out_value, const int *swap_out_index, const int index_size, - const int hash_dim, cudaStream_t cuda_stream) { - HashSwapOut<<>>(hash_table, swap_out_value, swap_out_index, - index_size, hash_dim); - return; -} - -template -void DoHashSwapIn(T *hash_table, const T *swap_in_value, const int *swap_in_index, const int index_size, - const int hash_dim, cudaStream_t cuda_stream) { - HashSwapIn<<>>(hash_table, swap_in_value, swap_in_index, - index_size, hash_dim); - return; -} - -template void DoHashSwapOut(const float *hash_table, float *swap_out_value, const int *swap_out_index, - const int index_size, const int hash_dim, cudaStream_t cuda_stream); - -template void DoHashSwapIn(float *hash_table, const float *swap_in_value, const int *swap_in_index, - const int index_size, const int hash_dim, cudaStream_t cuda_stream); diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/logical.cu b/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/logical.cu deleted file mode 100644 index 997b39975796ca3224ca0f5b5efde9bb361fe6a2..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/logical.cu +++ /dev/null @@ -1,97 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/cuda_impl/logical.cuh" -#include "src/extendrt/delegate/tensorrt/cuda_impl/cuda_helper.h" - -template -__global__ void LogicalNotKernel(const T *input1, T *output, int element_cnt) { - for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < element_cnt; pos += blockDim.x * gridDim.x) { - output[pos] = static_cast(input1[pos] == 0); - } -} - -template -__global__ void LogicalAndKernel(const T *input_addr1, const T *input_addr2, T *output, int size) { - for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - output[pos] = input_addr1[pos] * input_addr2[pos]; - } -} - -template -__global__ void LogicalOrKernel(const T *input_addr1, const T *input_addr2, T *output, int size) { - for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - T sum = input_addr1[pos] + input_addr2[pos]; - output[pos] = static_cast(sum > 0); - } -} - -template -__global__ void GreaterOrEqualKernal(const T *input1, const T *input2, T *output, int element_cnt) { - for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < element_cnt; pos += blockDim.x * gridDim.x) { - output[pos] = (input1[pos] >= input2[pos]); - } -} - -template -__global__ void LessOrEqualKernal(const T *input1, const T *input2, T *output, int element_cnt) { - for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < element_cnt; pos += blockDim.x * gridDim.x) { - output[pos] = (input1[pos] <= input2[pos]); - } -} - -template -void LogicalNot(const T *input1, T *output, int element_cnt, cudaStream_t stream) { - LogicalNotKernel<<>>(input1, output, element_cnt); -} - -template -void LogicalAnd(const T *input1, const T *input2, T *output, int element_cnt, cudaStream_t stream) { - LogicalAndKernel<<>>(input1, input2, output, element_cnt); -} - -template -void LogicalOr(const T *input1, const T *input2, T *output, int element_cnt, cudaStream_t stream) { - LogicalOrKernel<<>>(input1, input2, output, element_cnt); -} - -template -void GreaterOrEqual(const T *input1, const T *input2, T *output, int element_cnt, cudaStream_t stream) { - GreaterOrEqualKernal<<>>(input1, input2, output, element_cnt); -} - -template -void LessOrEqual(const T *input1, const T *input2, T *output, int element_cnt, cudaStream_t stream) { - LessOrEqualKernal<<>>(input1, input2, output, element_cnt); -} - -template void GreaterOrEqual(const float *input1, const float *input2, float *output, int element_cnt, - cudaStream_t stream); - -template void GreaterOrEqual(const int *input1, const int *input2, int *output, int element_cnt, cudaStream_t stream); - -template void LessOrEqual(const float *input1, const float *input2, float *output, int element_cnt, - cudaStream_t stream); - -template void LessOrEqual(const int *input1, const int *input2, int *output, int element_cnt, cudaStream_t stream); - -template void LogicalNot(const int32_t *input1, int32_t *output, int element_cnt, cudaStream_t stream); - -template void LogicalAnd(const int32_t *input1, const int32_t *input2, int32_t *output, int element_cnt, - cudaStream_t stream); - -template void LogicalOr(const int32_t *input1, const int32_t *input2, int32_t *output, int element_cnt, - cudaStream_t stream); diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/normalize.cu b/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/normalize.cu deleted file mode 100644 index b3b24bcb983e0c12b461282fd16b415af6daba66..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/cuda_impl/normalize.cu +++ /dev/null @@ -1,98 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/cuda_impl/normalize.cuh" -#include -#include -#include "src/extendrt/delegate/tensorrt/cuda_impl/cuda_helper.h" -#include "src/extendrt/delegate/tensorrt/cuda_impl/utils.cuh" - -template -__global__ void NormalizeKernel(const T *input, const T *gamma, const T *beta, T *output, size_t n, float epsilion, - int dim_before_axis) { - const int tid = threadIdx.x; - const int bid = blockIdx.x; - const int block_loop = (dim_before_axis - 1) / gridDim.x + 1; - const int element_cnt = dim_before_axis * n; - - __shared__ float s_mean[2048]; - __shared__ float s_variance[2048]; - float sum = 0.0f; - float variance = 0.0f; - - for (int block = 0; block < block_loop; block++) { - float local_sum = 0.0f; - int mean_index = bid + block * gridDim.x; - int num_index = bid * n + block * gridDim.x * blockDim.x; - for (int i = tid; i < n; i += blockDim.x) { - if (num_index + i >= element_cnt) { - break; - } - local_sum += static_cast(input[num_index + i]); - } - sum = blockReduceSum(local_sum); - if (tid == 0) { - s_mean[mean_index] = sum / n; - } - } - __syncthreads(); - - for (int block = 0; block < block_loop; block++) { - float local_var_sum = 0.0f; - int var_index = bid + block * gridDim.x; - int num_index = bid * n + block * gridDim.x * blockDim.x; - for (int i = tid; i < n; i += blockDim.x) { - if (num_index + i >= element_cnt) { - break; - } - float diff = static_cast(input[num_index + i]) - s_mean[var_index]; - local_var_sum += diff * diff; - } - variance = blockReduceSum(local_var_sum); - if (tid == 0) { - s_variance[var_index] = rsqrtf(variance / n + epsilion); - } - } - __syncthreads(); - for (int block = 0; block < block_loop; block++) { - int var_index = bid + block * gridDim.x; - int num_index = bid * n + block * gridDim.x * blockDim.x; - for (int i = tid; i < n; i += blockDim.x) { - if (num_index + i >= element_cnt) { - break; - } - float beta_val = (beta == nullptr) ? 0.0f : static_cast(beta[i]); - output[num_index + i] = - static_cast(((static_cast(input[num_index + i]) - s_mean[var_index]) * s_variance[var_index]) * - static_cast(gamma[i]) + - beta_val); - } - } -} - -template -void Normalize(const T *input, const T *gamma, const T *beta, T *output, size_t dim_at_axis, float epsilion, - int element_cnt, cudaStream_t stream) { - int thread_num = GET_THREADS_CAL(dim_at_axis); - int block_num = GET_BLOCKS_CAL(element_cnt, thread_num); - int dim_before_axis = element_cnt / dim_at_axis; - NormalizeKernel<<>>(input, gamma, beta, output, dim_at_axis, epsilion, - dim_before_axis); - return; -} - -template void Normalize(const float *input, const float *gamma, const float *beta, float *output, size_t dim_at_axis, - float epsilion, int element_cnt, cudaStream_t stream); diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/distribution/distribution_base.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/distribution/distribution_base.cc deleted file mode 100644 index 865eb3afcdf1f9ee0260da66ae2ce9970936735e..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/distribution/distribution_base.cc +++ /dev/null @@ -1,24 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/distribution/distribution_base.h" -#include "src/extendrt/delegate/plugin/tensorrt_executor_plugin.h" - -namespace mindspore::lite { -int GetGPUGroupSize() { return TensorRTExecutorPlugin::GetInstance().GetGPUGroupSize(); } - -int GetRankID() { return TensorRTExecutorPlugin::GetInstance().GetRankID(); } -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/distribution/distribution_collective.h b/mindspore-lite/src/extendrt/delegate/tensorrt/distribution/distribution_collective.h deleted file mode 100644 index 79c7546a1e4d4ab136c9b8bdc8126d76b3bfb768..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/distribution/distribution_collective.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_DISTRIBUTION_DISTRIBUTION_COLLECTIVE_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_DISTRIBUTION_DISTRIBUTION_COLLECTIVE_H_ - -#include -#include "NvInfer.h" -#include "mindapi/base/types.h" -#include "src/extendrt/delegate/tensorrt/distribution/distribution_base.h" - -namespace mindspore::lite { -class DistributionCollective { - public: - DistributionCollective(DistributionCollective const &) = delete; - - DistributionCollective &operator=(const DistributionCollective &) = delete; - - static DistributionCollective &instance(); - - int ReduceScatterWrapper(const void *input_addr, void *output_addr, size_t count, nvinfer1::DataType data_type, - ReduceMode reduce_type, cudaStream_t stream, const std::string &group); - - int AllReduceWrapper(const void *input_addr, void *output_addr, size_t count, nvinfer1::DataType data_type, - ReduceMode reduce_type, cudaStream_t stream, const std::string &group); - - int AllGatherWrapper(const void *input_addr, void *output_addr, size_t count, nvinfer1::DataType data_type, - cudaStream_t stream, const std::string &group_name); - - private: - DistributionCollective(); - - ~DistributionCollective() = default; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_DISTRIBUTION_DISTRIBUTION_COLLECTIVE_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/distribution/distribution_collective_impl.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/distribution/distribution_collective_impl.cc deleted file mode 100644 index 36e136966d8aac607af51cd8199f740923240bff..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/distribution/distribution_collective_impl.cc +++ /dev/null @@ -1,91 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/distribution/distribution_collective.h" -#include -#include -#include -#include "plugin/device/gpu/hal/device/distribution/collective_wrapper.h" -#include "src/extendrt/delegate/tensorrt/distribution/distribution_utils.h" -#include "src/extendrt/delegate/tensorrt/distribution/distribution_base.h" - -namespace mindspore::lite { -DistributionCollective::DistributionCollective() { - InitMPI(); - InitNCCLComm(); -} - -DistributionCollective &DistributionCollective::instance() { - static DistributionCollective instance; - return instance; -} - -int DistributionCollective::ReduceScatterWrapper(const void *input_addr, void *output_addr, size_t count, - nvinfer1::DataType data_type, ReduceMode reduce_type, - cudaStream_t stream, const std::string &group) { - int rank_id = GetRankID(); - MS_LOG(DEBUG) << "ReduceScatter on rank: " << rank_id; - ncclResult_t ret = ReduceScatter(input_addr, output_addr, count, ConvertNCCLDataType(data_type), - ConvertNCCLReduceMode(reduce_type), stream, group); - if (ret != ncclSuccess) { - MS_LOG(ERROR) << "ReduceScatter failed: " << static_cast(ret); - return RET_ERROR; - } - auto cuda_ret = cudaStreamSynchronize(stream); - if (cuda_ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaStreamSynchronize failed: " << static_cast(cuda_ret); - return RET_ERROR; - } - return RET_OK; -} - -int DistributionCollective::AllReduceWrapper(const void *input_addr, void *output_addr, size_t count, - nvinfer1::DataType data_type, ReduceMode reduce_type, cudaStream_t stream, - const std::string &group) { - int rank_id = GetRankID(); - MS_LOG(DEBUG) << "AllReduce on rank: " << rank_id; - ncclResult_t ret = AllReduce(input_addr, output_addr, count, ConvertNCCLDataType(data_type), - ConvertNCCLReduceMode(reduce_type), stream, group); - if (ret != ncclSuccess) { - MS_LOG(ERROR) << "AllReduce failed: " << static_cast(ret); - return RET_ERROR; - } - auto cuda_ret = cudaStreamSynchronize(stream); - if (cuda_ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaStreamSynchronize failed: " << static_cast(cuda_ret); - return RET_ERROR; - } - return RET_OK; -} - -int DistributionCollective::AllGatherWrapper(const void *input_addr, void *output_addr, size_t count, - nvinfer1::DataType data_type, cudaStream_t stream, - const std::string &group_name) { - int rank_id = GetRankID(); - MS_LOG(DEBUG) << "AllGather on rank: " << rank_id; - ncclResult_t ret = AllGather(input_addr, output_addr, count, ConvertNCCLDataType(data_type), stream, group_name); - if (ret != ncclSuccess) { - MS_LOG(ERROR) << "AllGather failed: " << static_cast(ret); - return RET_ERROR; - } - auto cuda_ret = cudaStreamSynchronize(stream); - if (cuda_ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaStreamSynchronize failed: " << static_cast(cuda_ret); - return RET_ERROR; - } - return RET_OK; -} -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/distribution/distribution_utils.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/distribution/distribution_utils.cc deleted file mode 100644 index bed0648b0f9dde4cc5eb56f3d0c46e3a256e5aa1..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/distribution/distribution_utils.cc +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/distribution/distribution_utils.h" -#include -#include "src/common/log_adapter.h" - -namespace mindspore::lite { -ncclDataType_t ConvertNCCLDataType(nvinfer1::DataType type_id) { - std::unordered_map data_type_map = { - {nvinfer1::DataType::kINT8, ncclInt8}, - {nvinfer1::DataType::kINT32, ncclInt32}, - {nvinfer1::DataType::kFLOAT, ncclFloat32}, - {nvinfer1::DataType::kHALF, ncclHalf}, - }; - auto iter = data_type_map.find(type_id); - ncclDataType_t data_type; - if (iter != data_type_map.end()) { - data_type = iter->second; - } else { - data_type = ncclFloat32; - MS_LOG(WARNING) << "invalid data_type for NCCL, need check: " << static_cast(type_id); - } - return data_type; -} - -ncclRedOp_t ConvertNCCLReduceMode(ReduceMode mode) { - std::unordered_map reduce_ops_ = { - // higher version support mean {schema::ReduceMode::ReduceMode_ReduceMean, ncclAvg}, - {ReduceMode::Reduce_Max, ncclMax}, - {ReduceMode::Reduce_Min, ncclMin}, - {ReduceMode::Reduce_Prod, ncclProd}, - {ReduceMode::Reduce_Sum, ncclSum}, - }; - auto iter = reduce_ops_.find(mode); - ncclRedOp_t nccl_mode; - if (iter != reduce_ops_.end()) { - nccl_mode = iter->second; - } else { - nccl_mode = ncclSum; - MS_LOG(WARNING) << "invalid reduce for NCCL, need check: " << static_cast(mode); - } - return nccl_mode; -} -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/distribution/distribution_utils.h b/mindspore-lite/src/extendrt/delegate/tensorrt/distribution/distribution_utils.h deleted file mode 100644 index 9c1f732773e316e165bf300ac0fb35f817fdd9f9..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/distribution/distribution_utils.h +++ /dev/null @@ -1,32 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_DISTRIBUTION_DISTRIBUTION_UTILS_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_DISTRIBUTION_DISTRIBUTION_UTILS_H_ - -#include -#include "include/errorcode.h" -#include "NvInfer.h" -#include "mindapi/base/types.h" - -using mindspore::lite::RET_ERROR; -using mindspore::lite::RET_OK; - -namespace mindspore::lite { -ncclDataType_t ConvertNCCLDataType(nvinfer1::DataType type_id); - -ncclRedOp_t ConvertNCCLReduceMode(ReduceMode mode); -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_DISTRIBUTION_DISTRIBUTION_UTILS_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/activation_opt_plugin.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/activation_opt_plugin.cc deleted file mode 100644 index 09a87fba29591045873f2b8b04502d69c4968fb8..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/activation_opt_plugin.cc +++ /dev/null @@ -1,116 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/activation_opt_plugin.h" -#include -#include -#include -#include -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "NvInferRuntimeCommon.h" -#include "src/extendrt/delegate/tensorrt/cuda_impl/activation.cuh" -#include "kernel/gpu/cuda_impl/cuda_ops/swish_impl.cuh" - -namespace mindspore::lite { -REGISTER_TENSORRT_PLUGIN(ActivationOptPluginCreater); -template class TensorRTPluginCreater; -template -nvinfer1::PluginFieldCollection TensorRTPluginCreater::field_collection_{}; -template -std::vector TensorRTPluginCreater::fields_; - -int ActivationOptPlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workspace, cudaStream_t stream) noexcept { - return RunCudaActivation(inputDesc, inputs, outputs, stream); -} - -bool ActivationOptPlugin::needResize(const int *current_dims, const int *last_dims) { - for (int i = 0; i < infer_dims_cnt_; i++) { - if (current_dims[i] != last_dims[i]) { - return true; - } - } - return false; -} - -int ActivationOptPlugin::RunCuDNNActivation(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, - void *const *outputs, cudaStream_t stream) { - if (needResize(infer_dims_, inputDesc[0].dims.d)) { - if (input_desc_ != nullptr) { - CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc_)); - input_desc_ = nullptr; - } - CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc_)); - for (int i = 0; i < inputDesc[0].dims.nbDims; i++) { - infer_dims_[i] = inputDesc[0].dims.d[i]; - } - CUDNN_CHECK(cudnnSetTensorNdDescriptor(input_desc_, ConvertCudnnDataType(inputDesc[0].type), infer_dims_cnt_, - infer_dims_, infer_stride_)); - } - CHECK_NULL_RETURN(cudnn_handle_); - CHECK_NULL_RETURN(activation_desc_); - CHECK_NULL_RETURN(input_desc_); - CUDNN_CHECK(cudnnSetStream(cudnn_handle_, stream)); - auto ret = CudnnActivation(cudnn_handle_, activation_desc_, input_desc_, inputs[0], input_desc_, outputs[0]); - if (ret != RET_OK) { - MS_LOG(ERROR) << "cudnn activation func call failed " << layer_name_; - return ret; - } - return RET_OK; -} - -int ActivationOptPlugin::RunCudaActivation(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, - void *const *outputs, cudaStream_t stream) { - switch (activation_type_) { - case (ActivationType::SIGMOID): { - Sigmoid(static_cast(inputs[0]), static_cast(outputs[0]), GetDimsVolume(inputDesc[0].dims), - stream); - break; - } - case (ActivationType::GELU): { - Gelu(static_cast(inputs[0]), static_cast(outputs[0]), GetDimsVolume(inputDesc[0].dims), - stream); - break; - } - case (ActivationType::SWISH): { - Swish(GetDimsVolume(inputDesc[0].dims), static_cast(inputs[0]), static_cast(outputs[0]), - stream, device_id_); - break; - } - default: { - MS_LOG(ERROR) << "invalid activation type: " << static_cast(activation_type_); - return RET_ERROR; - } - } - return RET_OK; -} - -nvinfer1::IPluginV2DynamicExt *ActivationOptPlugin::clone() const noexcept { - auto *plugin = new ActivationOptPlugin(*this); - plugin->setPluginNamespace(name_space_.c_str()); - return plugin; -} - -size_t ActivationOptPlugin::getSerializationSize() const noexcept { return sizeof(ActivationType); } - -void ActivationOptPlugin::serialize(void *buffer) const noexcept { - SerializeValue(&buffer, &activation_type_, sizeof(ActivationType)); -} -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/activation_opt_plugin.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/activation_opt_plugin.h deleted file mode 100644 index ffeb6b68189d06af792e8470f3542554380ae13f..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/activation_opt_plugin.h +++ /dev/null @@ -1,73 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_ACTIVATION_OPT_PLUGIN_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_ACTIVATION_OPT_PLUGIN_H_ - -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" -#include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h" -#include "src/extendrt/delegate/tensorrt/cuda_impl/cudnn_utils.h" -#include "infer/cxx_api/activation.h" - -namespace mindspore::lite { -constexpr auto ACTIVATION_OPT_PLUGIN_NAME{"ActivationOptPlugin"}; -class ActivationOptPlugin : public TensorRTPlugin { - public: - ActivationOptPlugin(const std::string name, ActivationType activation_type, uint32_t device_id) - : TensorRTPlugin(name, std::string(ACTIVATION_OPT_PLUGIN_NAME), device_id), activation_type_(activation_type) {} - - ActivationOptPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - : TensorRTPlugin(std::string(name), std::string(ACTIVATION_OPT_PLUGIN_NAME)) { - const nvinfer1::PluginField *fields = fc->fields; - activation_type_ = static_cast(fields[0].data)[0]; - } - - ActivationOptPlugin(const char *name, const void *serialData, size_t serialLength) - : TensorRTPlugin(std::string(name), std::string(ACTIVATION_OPT_PLUGIN_NAME)) { - DeserializeValue(&serialData, &serialLength, &activation_type_, sizeof(ActivationType)); - } - - ActivationOptPlugin() = delete; - - nvinfer1::IPluginV2DynamicExt *clone() const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override; - size_t getSerializationSize() const noexcept override; - void serialize(void *buffer) const noexcept override; - - private: - bool needResize(const int *current_dims, const int *last_dims); - int RunCudaActivation(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, void *const *outputs, - cudaStream_t stream); - int RunCuDNNActivation(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, void *const *outputs, - cudaStream_t stream); - const std::string layer_name_; - std::string name_space_; - ActivationType activation_type_; - cudnnHandle_t cudnn_handle_{nullptr}; - cudnnActivationDescriptor_t activation_desc_{nullptr}; - cudnnTensorDescriptor_t input_desc_{nullptr}; - int infer_dims_[5]{1, 1, 1, 1, 1}; - int infer_stride_[5]{1, 1, 1, 1, 1}; - int infer_dims_cnt_{0}; -}; -class ActivationOptPluginCreater : public TensorRTPluginCreater { - public: - ActivationOptPluginCreater() : TensorRTPluginCreater(std::string(ACTIVATION_OPT_PLUGIN_NAME)) {} -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_ACTIVATION_OPT_PLUGIN_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/activation_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/activation_tensorrt.cc deleted file mode 100644 index 98c5639d812670697602a8618c65603f7e48ef1b..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/activation_tensorrt.cc +++ /dev/null @@ -1,259 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/activation_tensorrt.h" -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/op/cast_tensorrt.h" -#include "src/extendrt/delegate/tensorrt/op/activation_opt_plugin.h" -#include "infer/cxx_api/activation.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" - -namespace mindspore::lite { -namespace { -bool HasCustomActivationPlugin(ActivationType type) { - std::unordered_set plugin_activation = {ActivationType::SIGMOID, ActivationType::GELU, - ActivationType::SWISH}; - return plugin_activation.find(type) != plugin_activation.end(); -} -} // namespace - -int ActivationTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (!IsShapeKnown()) { - MS_LOG(ERROR) << "Unsupported input tensor unknown shape: " << op_name_; - return RET_ERROR; - } - if (in_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); - return RET_ERROR; - } - auto activation_op = AsOps(); - if (activation_op == nullptr) { - MS_LOG(ERROR) << "op convert failed"; - return RET_ERROR; - } - ActivationType activation_type = ActivationType::NO_ACTIVATION; - if (activation_op->HasAttr(ops::kActivationType)) { - activation_type = activation_op->get_activation_type(); - } - if (activation_type == ActivationType::HSWISH) { - return RET_OK; - } - auto activation_params_opt = TryConvertActivationType(activation_type); - if (!activation_params_opt) { - MS_LOG(ERROR) << "Unsupported op action type for TensorRT: " << activation_type; - return RET_ERROR; - } - return RET_OK; -} -int ActivationTensorRT::AddInnerOp(TensorRTContext *ctx) { - if (ctx->network() == nullptr) { - MS_LOG(ERROR) << "network is invalid"; - return RET_ERROR; - } - auto activation_op = AsOps(); - if (activation_op == nullptr) { - MS_LOG(ERROR) << "op convert failed"; - return RET_ERROR; - } - float alpha = activation_op->get_alpha(); - nvinfer1::ITensor *activation_input = input(ctx, 0).trt_tensor_; - if (input(ctx, 0).trt_tensor_->getType() == nvinfer1::DataType::kINT32) { - activation_input = TRTTensorCast(ctx, input(ctx, 0).trt_tensor_, nvinfer1::DataType::kFLOAT, op_name_ + "_cast_in"); - } - - auto runtime_precision_mode = runtime_->GetRuntimePrecisionMode(); - ActivationType activation_type = ActivationType::NO_ACTIVATION; - if (activation_op->HasAttr(ops::kActivationType)) { - activation_type = activation_op->get_activation_type(); - } - auto activation_layer = ActivationTensorRT::AddActivation( - ctx, activation_type, alpha, std::isfinite(activation_op->get_min_val()) ? activation_op->get_min_val() : FLT_MIN, - std::isfinite(activation_op->get_max_val()) ? activation_op->get_max_val() : FLT_MAX, activation_input, op_name_, - device_id_, quant_type_, runtime_precision_mode); - if (activation_layer == nullptr) { - MS_LOG(ERROR) << "add activation op failed for TensorRT."; - return RET_ERROR; - } - - activation_layer->setName(op_name_.c_str()); - // cast to origin type - nvinfer1::ITensor *out_tensor = activation_layer->getOutput(0); - if (out_tensor->getType() != ConvertDataType(out_tensors_[0].DataType())) { - out_tensor = TRTTensorCast(ctx, activation_layer->getOutput(0), ConvertDataType(out_tensors_[0].DataType()), - op_name_ + "_cast_out"); - } - ctx->RegisterTensor(ITensorHelper{out_tensor, input(ctx, 0).format_, input(ctx, 0).same_format_}, - out_tensors_[0].Name()); - this->layer_ = activation_layer; - return RET_OK; -} - -nvinfer1::ILayer *ActivationTensorRT::AddHSwishActivation(TensorRTContext *ctx, nvinfer1::ITensor *trt_in_tensor, - const std::string &op_name) { - if (trt_in_tensor->getDimensions().nbDims <= 0) { - MS_LOG(ERROR) << "Invalid input dims count " << trt_in_tensor->getDimensions().nbDims << ", " << op_name; - return nullptr; - } - size_t dims_size = mindspore::IntToSize(trt_in_tensor->getDimensions().nbDims); - static const float add_3_const = 3.0f; - auto add_input1 = - ConvertScalarToITensor(ctx, dims_size, &add_3_const, DataType::kNumberTypeFloat32, op_name + "_add3"); - if (add_input1 == nullptr) { - MS_LOG(ERROR) << "Failed to add const input 3 for hard swish: " << op_name; - return nullptr; - } - auto add_3 = ctx->network()->addElementWise(*trt_in_tensor, *add_input1, nvinfer1::ElementWiseOperation::kSUM); - if (add_3 == nullptr) { - MS_LOG(ERROR) << "Failed to add layer x+3 for hard swish: " << op_name; - return nullptr; - } - add_3->setName((op_name + "_add3").c_str()); - auto add_3_output = add_3->getOutput(0); - if (add_3_output == nullptr) { - MS_LOG(ERROR) << "Failed to get output of layer x+3 for hard swish: " << op_name; - return nullptr; - } - auto relu6 = ctx->network()->addActivation(*add_3_output, nvinfer1::ActivationType::kCLIP); - if (relu6 == nullptr) { - MS_LOG(ERROR) << "Failed to add layer relu6 for hard swish: " << op_name; - return nullptr; - } - relu6->setAlpha(0.0f); - relu6->setBeta(6.0f); - relu6->setName((op_name + "_relu6").c_str()); - auto relu6_output = relu6->getOutput(0); - if (relu6_output == nullptr) { - MS_LOG(ERROR) << "Failed to get output of layer relu6 for hard swish: " << op_name; - return nullptr; - } - auto mul = ctx->network()->addElementWise(*trt_in_tensor, *relu6_output, nvinfer1::ElementWiseOperation::kPROD); - if (mul == nullptr) { - MS_LOG(ERROR) << "Failed to add layer mul for hard swish: " << op_name; - return nullptr; - } - mul->setName((op_name + "_mul").c_str()); - auto mul_output = mul->getOutput(0); - if (mul_output == nullptr) { - MS_LOG(ERROR) << "Failed to get output of layer mul for hard swish: " << op_name; - return nullptr; - } - static const float div_6_const = 6.0f; - auto div_input1 = - ConvertScalarToITensor(ctx, dims_size, &div_6_const, DataType::kNumberTypeFloat32, op_name + "_div6"); - if (div_input1 == nullptr) { - MS_LOG(ERROR) << "Failed to add const input 6 for hard swish: " << op_name; - return nullptr; - } - auto real_div = ctx->network()->addElementWise(*mul_output, *div_input1, nvinfer1::ElementWiseOperation::kDIV); - if (real_div == nullptr) { - MS_LOG(ERROR) << "Failed to add layer real div for hard swish: " << op_name; - return nullptr; - } - return real_div; -} - -nvinfer1::ILayer *ActivationTensorRT::AddGeluActivation(TensorRTContext *ctx, nvinfer1::ITensor *trt_in_tensor, - const std::string &op_name) { - if (trt_in_tensor->getDimensions().nbDims <= 0) { - MS_LOG(ERROR) << "Invalid input dims count " << trt_in_tensor->getDimensions().nbDims << ", " << op_name; - return nullptr; - } - auto expand_dims = [](TensorRTContext *ctx, nvinfer1::ITensor *tensor, int nbdims) { - while (tensor->getDimensions().nbDims != nbdims) { - tensor = ExpandDim(ctx, tensor, 0); - } - return tensor; - }; - int nbdims = trt_in_tensor->getDimensions().nbDims; - auto const_three = expand_dims(ctx, ctx->ConvertTo1DTensor(3.f), nbdims); - auto p3 = - ctx->network()->addElementWise(*trt_in_tensor, *const_three, nvinfer1::ElementWiseOperation::kPOW)->getOutput(0); - auto gelu_p1 = expand_dims(ctx, ctx->ConvertTo1DTensor(0.044715f), nbdims); - auto prod1 = ctx->network()->addElementWise(*p3, *gelu_p1, nvinfer1::ElementWiseOperation::kPROD)->getOutput(0); - auto sum = ctx->network()->addElementWise(*prod1, *trt_in_tensor, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0); - auto gelu_p2 = expand_dims(ctx, ctx->ConvertTo1DTensor(0.7978845608f), nbdims); - auto prod2 = ctx->network()->addElementWise(*sum, *gelu_p2, nvinfer1::ElementWiseOperation::kPROD)->getOutput(0); - auto tanh = ctx->network()->addActivation(*prod2, nvinfer1::ActivationType::kTANH)->getOutput(0); - auto const_one = expand_dims(ctx, ctx->ConvertTo1DTensor(1.f), nbdims); - auto sum2 = ctx->network()->addElementWise(*const_one, *tanh, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0); - auto prod3 = - ctx->network()->addElementWise(*sum2, *trt_in_tensor, nvinfer1::ElementWiseOperation::kPROD)->getOutput(0); - auto gelu_p3 = expand_dims(ctx, ctx->ConvertTo1DTensor(0.5f), nbdims); - return ctx->network()->addElementWise(*prod3, *gelu_p3, nvinfer1::ElementWiseOperation::kPROD); -} - -nvinfer1::ILayer *ActivationTensorRT::AddActivation(TensorRTContext *ctx, ActivationType activation_type, float alpha, - float min_value, float max_value, nvinfer1::ITensor *trt_in_tensor, - const std::string &op_name, uint32_t device_id, - schema::QuantType quant_type, - RuntimePrecisionMode runtime_precision_mode) { - if (activation_type == ActivationType::HSWISH) { - return AddHSwishActivation(ctx, trt_in_tensor, op_name); - } - if (activation_type == ActivationType::GELU) { - return AddGeluActivation(ctx, trt_in_tensor, op_name); - } - // Just some action_code correct, unfind code is set to default relu. need double check. - auto action_param_opt = TryConvertActivationType(activation_type); - if (!action_param_opt) { - MS_LOG(ERROR) << "Unsupported op action type for TensorRT: " << activation_type; - return nullptr; - } - auto action_param = action_param_opt.value(); - nvinfer1::IActivationLayer *activation_layer = - ctx->network()->addActivation(*trt_in_tensor, action_param.activation_type); - if (activation_layer == nullptr) { - MS_LOG(ERROR) << "add activation op failed for TensorRT."; - return nullptr; - } - - if (activation_type == ActivationType::HARD_TANH) { - activation_layer->setAlpha(min_value); - activation_layer->setBeta(max_value); - return activation_layer; - } - - if (activation_type == ActivationType::SWISH) { - auto sigmoid_tensor = activation_layer->getOutput(0); - nvinfer1::ElementWiseOperation element_wise_op_ = nvinfer1::ElementWiseOperation::kPROD; - nvinfer1::IElementWiseLayer *swish_layer = - ctx->network()->addElementWise(*sigmoid_tensor, *trt_in_tensor, element_wise_op_); - if (swish_layer == nullptr) { - MS_LOG(ERROR) << "add activation op failed for TensorRT."; - return nullptr; - } - return swish_layer; - } - - if (action_param.has_alpha) { - activation_layer->setAlpha(alpha); - } - - if (action_param.has_beta) { - activation_layer->setBeta(action_param.beta); - } - - return activation_layer; -} -REGISTER_TENSORRT_CREATOR(ops::kNameActivation, ActivationTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/activation_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/activation_tensorrt.h deleted file mode 100644 index d62d19fbe04f1b006fc4d4b57249e95ac407a09b..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/activation_tensorrt.h +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_ACTIVATION_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_ACTIVATION_TENSORRT_H_ -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" - -namespace mindspore::lite { -class ActivationTensorRT : public TensorRTOp { - public: - ActivationTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~ActivationTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; - - static nvinfer1::ILayer *AddActivation( - TensorRTContext *ctx, ActivationType activation_type, float alpha, float min_value, float max_value, - nvinfer1::ITensor *trt_in_tensor, const std::string &op_name, uint32_t device_id = 0, - schema::QuantType quant_type = schema::QuantType_QUANT_NONE, - RuntimePrecisionMode runtime_precision_mode = RuntimePrecisionMode::RuntimePrecisionMode_FP32); - - private: - static nvinfer1::ILayer *AddHSwishActivation(TensorRTContext *ctx, nvinfer1::ITensor *trt_in_tensor, - const std::string &op_name); - static nvinfer1::ILayer *AddGeluActivation(TensorRTContext *ctx, nvinfer1::ITensor *trt_in_tensor, - const std::string &op_name); -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_ACTIVATION_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/addn_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/addn_tensorrt.cc deleted file mode 100644 index 930eb5abcbadbacac543db90e9068e3d47aff6ee..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/addn_tensorrt.cc +++ /dev/null @@ -1,56 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/addn_tensorrt.h" -#include -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" - -namespace mindspore::lite { -int AddNTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() <= 1) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size() << " : " << op_name_; - return RET_ERROR; - } - - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size() << " : " << op_name_; - return RET_ERROR; - } - return RET_OK; -} - -int AddNTensorRT::AddInnerOp(TensorRTContext *ctx) { - auto *add_layer = ctx->network()->addElementWise(*input(ctx, 0).trt_tensor_, *input(ctx, 1).trt_tensor_, - nvinfer1::ElementWiseOperation::kSUM); - if (add_layer == nullptr) { - MS_LOG(ERROR) << "addElementWise failed for TensorRT : " << op_name_; - return RET_ERROR; - } - - nvinfer1::ITensor *out_tensor = add_layer->getOutput(0); - for (size_t i = 2; i < in_tensors_.size(); ++i) { - add_layer = - ctx->network()->addElementWise(*out_tensor, *input(ctx, i).trt_tensor_, nvinfer1::ElementWiseOperation::kSUM); - } - this->layer_ = add_layer; - ctx->RegisterTensor(ITensorHelper{out_tensor, Format::NCHW, true}, out_tensors_[0].Name()); - return RET_OK; -} -REGISTER_TENSORRT_CREATOR(ops::kNameAddN, AddNTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/addn_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/addn_tensorrt.h deleted file mode 100644 index 97d92bb24854f2717cebd6cb4a534479a1dd1e7e..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/addn_tensorrt.h +++ /dev/null @@ -1,38 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_ADDN_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_ADDN_TENSORRT_H_ -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" - -namespace mindspore::lite { -class AddNTensorRT : public TensorRTOp { - public: - AddNTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~AddNTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_ADDN_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/akg_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/akg_tensorrt.cc deleted file mode 100644 index 123db81d340e3b5bcb847568a01e3a92814be905..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/akg_tensorrt.cc +++ /dev/null @@ -1,334 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/akg_tensorrt.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "NvInferRuntimeCommon.h" -#include "infer/attention.h" -#include "infer/custom.h" -#include "tools/graph_kernel/common/utils.h" - -namespace mindspore::lite { -namespace { -std::vector SplitString(const std::string &raw_str, char delimiter) { - std::vector res; - std::string::size_type last_pos = 0; - auto cur_pos = raw_str.find(delimiter); - while (cur_pos != std::string::npos) { - (void)res.emplace_back(raw_str.substr(last_pos, cur_pos - last_pos)); - cur_pos++; - last_pos = cur_pos; - cur_pos = raw_str.find(delimiter, cur_pos); - } - if (last_pos < raw_str.size()) { - (void)res.emplace_back(raw_str.substr(last_pos, raw_str.size() - last_pos + 1)); - } - return res; -} - -int GetCustomShape(const std::string &attr, std::vector> *shapes) { - auto split_shape_str = SplitString(attr, ','); - for (size_t i = 0; i < split_shape_str.size(); i++) { - size_t dim = std::stoul(split_shape_str[i]); - if (i + dim >= split_shape_str.size()) { - MS_LOG(ERROR) << "Shape string is invalid. The shape dim is " << dim << ", but only " - << split_shape_str.size() - i << " values follow."; - return RET_ERROR; - } - std::vector shape; - for (size_t j = i + 1; j <= i + dim; j++) { - shape.push_back(std::stoi(split_shape_str[j])); - } - i += dim; - shapes->push_back(shape); - } - return RET_OK; -} -} // namespace - -std::string ReadFileToString(std::string filename) { - std::ifstream ifile(filename.c_str()); - std::ostringstream buf; - char ch; - while (buf && ifile.get(ch)) { - buf.put(ch); - } - return buf.str(); -} - -int AkgTensorRT::AddInnerOp(TensorRTContext *ctx) { - if (ctx == nullptr || ctx->network() == nullptr) { - MS_LOG(ERROR) << "context or network is invalid"; - return RET_ERROR; - } - - auto akg_op = AsOps(); - if (akg_op == nullptr) { - MS_LOG(ERROR) << "op action convert failed"; - return RET_ERROR; - } - - auto attr_map = akg_op->get_attr(); - AkgParamT params; - auto res = memset_s(¶ms, sizeof(params), 0, sizeof(params)); - if (res != EOK) { - MS_LOG(ERROR) << "memset_s output was truncated or an error occurred."; - return RET_ERROR; - } - std::string ptx_path_str = - std::string(reinterpret_cast(attr_map["ptx_path"].data()), attr_map["ptx_path"].size()); - params.ptx_path_len = ptx_path_str.length(); - auto res0 = snprintf_s(params.ptx_path, LEN_LIMIT, LEN_LIMIT - 1, "%s", ptx_path_str.c_str()); - if (res0 < 0) { - MS_LOG(ERROR) << "snprintf_s encountered an encoding error or a runtime-constraint violation."; - } else if (res0 >= static_cast(LEN_LIMIT)) { - MS_LOG(ERROR) << "snprintf_s output was truncated."; - } - - std::string kernel_name_str = - std::string(reinterpret_cast(attr_map["kernel_name"].data()), attr_map["kernel_name"].size()); - kernel_name_str += "_kernel0"; - auto res1 = snprintf_s(params.kernel_name, LEN_LIMIT, LEN_LIMIT - 1, "%s", kernel_name_str.c_str()); - if (res1 < 0) { - MS_LOG(ERROR) << "snprintf_s encountered an encoding error or a runtime-constraint violation."; - } else if (res1 >= static_cast(LEN_LIMIT)) { - MS_LOG(ERROR) << "snprintf_s output was truncated."; - } - std::string outputs_shape_str(reinterpret_cast(attr_map["outputs_shape"].data()), - attr_map["outputs_shape"].size()); - - std::vector> outputs_shape; - (void)GetCustomShape(outputs_shape_str, &outputs_shape); - size_t idx = 0; - size_t num_output = 0; - for (auto shp : outputs_shape) { - for (auto v : shp) { - params.output_shapes[idx] = v; - idx += 1; - } - params.output_shapes_separators[num_output] = idx; - num_output += 1; - } - - std::string outputs_type_str(reinterpret_cast(attr_map["outputs_type"].data()), - attr_map["outputs_type"].size()); - std::vector outputs_type = SplitString(outputs_type_str, ','); - params.output_types_len = outputs_type.size(); - for (size_t i = 0; i < outputs_type.size(); i++) { - params.output_types[i] = std::stoul(outputs_type[i]); - } - - params.bx = - std::stoul(std::string(reinterpret_cast(attr_map["GridDimX"].data()), attr_map["GridDimX"].size())); - params.by = - std::stoul(std::string(reinterpret_cast(attr_map["GridDimY"].data()), attr_map["GridDimY"].size())); - params.bz = - std::stoul(std::string(reinterpret_cast(attr_map["GridDimZ"].data()), attr_map["GridDimZ"].size())); - params.tx = - std::stoul(std::string(reinterpret_cast(attr_map["BlockDimX"].data()), attr_map["BlockDimX"].size())); - params.ty = - std::stoul(std::string(reinterpret_cast(attr_map["BlockDimY"].data()), attr_map["BlockDimY"].size())); - params.tz = - std::stoul(std::string(reinterpret_cast(attr_map["BlockDimZ"].data()), attr_map["BlockDimZ"].size())); - - nvinfer1::ITensor *input_tensor = input(ctx, 0).trt_tensor_; - auto plugin = std::make_shared(input_tensor->getName(), params, device_id_); - const size_t input_number = inputs().size(); - nvinfer1::ITensor *inputTensors[input_number]; - for (size_t i = 0; i < input_number; i++) { - inputTensors[i] = input(ctx, i).trt_tensor_; - } - nvinfer1::IPluginV2Layer *akg_layer = ctx->network()->addPluginV2(inputTensors, input_number, *plugin); - if (akg_layer == nullptr) { - MS_LOG(ERROR) << "add akg op failed for TensorRT."; - return RET_ERROR; - } - akg_layer->setName((op_name_ + "plugin_akg").c_str()); // should be set as Fused_MatMul_Add_XXX_fusion?? - const size_t output_number = outputs().size(); - for (size_t i = 0; i < output_number; i++) { - nvinfer1::ITensor *out_tensor = akg_layer->getOutput(i); - ctx->RegisterTensor(ITensorHelper{out_tensor, Format::DEFAULT_FORMAT, false, true}, out_tensors_[i].Name()); - } - this->layer_ = akg_layer; - return RET_OK; -} - -int AkgTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - dynamic_shape_params_.support_dynamic_ = false; - dynamic_shape_params_.support_hw_dynamic_ = false; - return RET_OK; -} - -// PLUGIN of Akg Layer -REGISTER_TENSORRT_PLUGIN(AkgPluginCreater); -template class TensorRTPluginCreater; -template -nvinfer1::PluginFieldCollection TensorRTPluginCreater::field_collection_{}; -template -std::vector TensorRTPluginCreater::fields_; - -CUresult AkgPlugin::GetFunction(CUfunction *func) { - CUmodule module; - CUjit_option options[1]; - options[0] = CU_JIT_MAX_REGISTERS; - void *values[1]; - int total_threads = params_.tx * params_.ty * params_.tz; - int total_warps = std::ceil(static_cast(total_threads) / static_cast(WARP_SIZE)); - int limit_warps = (total_warps + WARP_ALLOC_GRAN - 1) / WARP_ALLOC_GRAN * WARP_ALLOC_GRAN; - int total_register_unit_nums = MAX_REGISTER_PER_THREAD_BLOCK / REGISTER_UNIT_IN_WARP; - int register_unit_nums_per_warp = total_register_unit_nums / limit_warps; - int64_t register_nums = (register_unit_nums_per_warp * REGISTER_UNIT_IN_WARP) / WARP_SIZE; - values[0] = reinterpret_cast(register_nums); - - std::string ptx_path_str = std::string(params_.ptx_path); - CUresult result = cuModuleLoadDataEx(&module, ReadFileToString(ptx_path_str).c_str(), 1, options, values); - if (result != CUDA_SUCCESS) { - const char *msg = nullptr; - cuGetErrorName(result, &msg); - MS_LOG(ERROR) << "cuModuleLoadDataEx failed. Kernel name: << " << params_.kernel_name << ". Error message: " << msg; - return result; - } - result = cuModuleGetFunction(func, module, params_.kernel_name); - if (result != CUDA_SUCCESS) { - const char *msg = nullptr; - cuGetErrorName(result, &msg); - MS_LOG(ERROR) << "cuModuleGetFunction failed. Kernel name: << " << params_.kernel_name - << ". Error message: " << msg; - return result; - } - return result; -} - -bool AkgPlugin::Launch(const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) { - if (stream == 0) { - MS_LOG(ERROR) << "stream should not be nullptr. Kernel name: " << params_.kernel_name; - return false; - } - std::vector runtimeargs; - void *addr_ptrs[num_of_inputs_ + num_of_outputs_]; - void *addr[num_of_inputs_ + num_of_outputs_]; - for (size_t i = 0; i < num_of_inputs_; i++) { - addr[i] = const_cast(inputs[i]); - addr_ptrs[i] = &(addr[i]); - runtimeargs.push_back(addr_ptrs[i]); - } - for (size_t i = 0; i < num_of_outputs_; i++) { - addr[i + num_of_inputs_] = const_cast(outputs[i]); - addr_ptrs[i + num_of_inputs_] = &(addr[i + num_of_inputs_]); - runtimeargs.push_back(addr_ptrs[i + num_of_inputs_]); - } - - CUresult result = cuLaunchKernel(kernel_addr_, params_.bx, params_.by, params_.bz, params_.tx, params_.ty, params_.tz, - 0, stream, reinterpret_cast(&runtimeargs[0]), 0); - if (result != CUDA_SUCCESS) { - const char *msg = nullptr; - cuGetErrorName(result, &msg); - MS_LOG(ERROR) << "Launch kernel failed. Kernel name: " << params_.kernel_name - << ". cuLaunchKernel error message: " << msg; - return false; - } - return true; -} - -int AkgPlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept { - (void)Launch(inputs, outputs, workspace, stream); - return RET_OK; -} - -bool AkgPlugin::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs, - int nbOutputs) noexcept { - return true; -} - -nvinfer1::DimsExprs AkgPlugin::getOutputDimensions(int index, const nvinfer1::DimsExprs *inputs, int nbInputDims, - nvinfer1::IExprBuilder &exprBuilder) noexcept { - nvinfer1::DimsExprs dims; - size_t start = 0; - size_t end = 0; - if (index != 0) { - start = params_.output_shapes_separators[index - 1]; - } - end = params_.output_shapes_separators[index]; - dims.nbDims = end - start; - for (size_t i = start; i < end; i++) { - dims.d[i - start] = exprBuilder.constant(params_.output_shapes[i]); - } - return dims; -} - -nvinfer1::DataType AkgPlugin::getOutputDataType(int index, const nvinfer1::DataType *inputTypes, int nbInputs) const - noexcept { - if (params_.output_types[index] == size_t(DataType::kNumberTypeFloat32)) { - return nvinfer1::DataType::kFLOAT; - } else if (params_.output_types[index] == size_t(DataType::kNumberTypeFloat16)) { - return nvinfer1::DataType::kHALF; - } else if (params_.output_types[index] == size_t(DataType::kNumberTypeInt32)) { - return nvinfer1::DataType::kINT32; - } else if (params_.output_types[index] == size_t(DataType::kNumberTypeInt8)) { - return nvinfer1::DataType::kINT8; - } else if (params_.output_types[index] == size_t(DataType::kNumberTypeBool)) { - return nvinfer1::DataType::kBOOL; - } else { - MS_EXCEPTION(TypeError); - } -} - -nvinfer1::IPluginV2DynamicExt *AkgPlugin::clone() const noexcept { - auto *plugin = new AkgPlugin(*this); - if (plugin == nullptr) { - MS_LOG(ERROR) << "plugin is null"; - return nullptr; - } - plugin->setPluginNamespace(name_space_.c_str()); - return plugin; -} - -void AkgPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *out, int nbOutputs) noexcept { - num_of_inputs_ = nbInputs; - num_of_outputs_ = nbOutputs; -} -int AkgPlugin::initialize() noexcept { - CUresult result = GetFunction(&kernel_addr_); - if (result != CUDA_SUCCESS) { - const char *msg = nullptr; - cuGetErrorName(result, &msg); - MS_EXCEPTION(RuntimeError) << "Get function " << params_.kernel_name << " failed. Error message: " << msg; - } - return 0; -} - -void AkgPlugin::terminate() noexcept {} - -size_t AkgPlugin::getSerializationSize() const noexcept { return sizeof(AkgParamT); } - -void AkgPlugin::serialize(void *buffer) const noexcept { SerializeValue(&buffer, ¶ms_, sizeof(AkgParamT)); } - -REGISTER_TENSORRT_CREATOR("CustomAkgGpu", AkgTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/akg_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/akg_tensorrt.h deleted file mode 100644 index 44a411e826174baeedd52194b0f6a4573799870b..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/akg_tensorrt.h +++ /dev/null @@ -1,119 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_AKG_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_AKG_TENSORRT_H_ - -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" -#include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h" - -namespace mindspore::lite { -const int MAX_REGISTER_PER_THREAD_BLOCK = 65536; -const int REGISTER_UNIT_IN_WARP = 256; -const int WARP_SIZE = 32; -const int WARP_ALLOC_GRAN = 4; -const int LEN_LIMIT = 200; - -typedef struct { - char ptx_path[LEN_LIMIT]; - char kernel_name[LEN_LIMIT]; - size_t ptx_path_len; - uint32_t output_shapes[LEN_LIMIT]; - uint32_t output_shapes_separators[LEN_LIMIT]; - size_t output_types_len; - size_t output_types[LEN_LIMIT]; - size_t bx; - size_t by; - size_t bz; - size_t tx; - size_t ty; - size_t tz; -} AkgParamT; -class AkgTensorRT : public TensorRTOp { - public: - AkgTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) { - dynamic_shape_params_.support_dynamic_ = false; - dynamic_shape_params_.support_hw_dynamic_ = false; - } - - ~AkgTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; -}; - -constexpr auto AKG_PLUGIN_NAME{"AkgPlugin"}; -class AkgPlugin : public TensorRTPlugin { - public: - AkgPlugin(const std::string name, AkgParamT params, uint32_t device_id) - : TensorRTPlugin(name, std::string(AKG_PLUGIN_NAME), device_id), params_(params) {} - - AkgPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - : TensorRTPlugin(std::string(name), std::string(AKG_PLUGIN_NAME)) { - const nvinfer1::PluginField *fields = fc->fields; - params_ = static_cast(fields[0].data)[0]; - } - - AkgPlugin(const char *name, const void *serialData, size_t serialLength) - : TensorRTPlugin(std::string(name), std::string(AKG_PLUGIN_NAME)) { - DeserializeValue(&serialData, &serialLength, ¶ms_, sizeof(AkgParamT)); - } - - AkgPlugin() = delete; - - ~AkgPlugin() override {} - - CUresult GetFunction(CUfunction *func); - bool Launch(const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream); - - nvinfer1::IPluginV2DynamicExt *clone() const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override; - size_t getSerializationSize() const noexcept override; - void serialize(void *buffer) const noexcept override; - nvinfer1::DimsExprs getOutputDimensions(int index, const nvinfer1::DimsExprs *inputs, int nbInputDims, - nvinfer1::IExprBuilder &exprBuilder) noexcept override; - nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, int nbInputs) const - noexcept override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *out, int nbOutputs) noexcept override; - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs, - int nbOutputs) noexcept override; - void terminate() noexcept override; - int initialize() noexcept override; - - private: - const std::string layer_name_; - std::string name_space_; - mutable AkgParamT params_; - CUfunction kernel_addr_{nullptr}; - size_t num_of_inputs_ = 0; - size_t num_of_outputs_ = 0; - size_t num_of_workspace_ = 0; -}; -class AkgPluginCreater : public TensorRTPluginCreater { - public: - AkgPluginCreater() : TensorRTPluginCreater(std::string(AKG_PLUGIN_NAME)) {} -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_AKG_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/allgather_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/allgather_tensorrt.cc deleted file mode 100644 index bd4e5844e5557048c7d4e30c86ac5d77abd63d2e..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/allgather_tensorrt.cc +++ /dev/null @@ -1,114 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/allgather_tensorrt.h" -#include -#include "NvInferRuntimeCommon.h" -#include "infer/all_gather.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" - -namespace mindspore::lite { -REGISTER_TENSORRT_PLUGIN(AllGatherPluginCreater); -template class TensorRTPluginCreater; -template -nvinfer1::PluginFieldCollection TensorRTPluginCreater::field_collection_{}; -template -std::vector TensorRTPluginCreater::fields_; - -int AllGatherTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { -#ifndef LITE_CUDA_DISTRIBUTION - MS_LOG(ERROR) - << "Unsupported package for gpu distribution feature, please recompile with MS_ENABLE_CUDA_DISTRIBUTION set to on."; - return RET_ERROR; -#else - if (!IsShapeKnown()) { - MS_LOG(ERROR) << "Unsupported input tensor unknown shape: " << op_name_; - return RET_ERROR; - } - if (in_tensors.size() != 1) { - MS_LOG(ERROR) << "invalid input tensor size: " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "invalid output tensor size: " << out_tensors.size(); - return RET_ERROR; - } - dynamic_shape_params_.support_hw_dynamic_ = false; - return RET_OK; -#endif -} - -int AllGatherTensorRT::AddInnerOp(TensorRTContext *ctx) { - nvinfer1::ITensor *inputTensors[] = {input(ctx, 0).trt_tensor_}; - auto allgather_op = AsOps(); - if (allgather_op == nullptr) { - MS_LOG(ERROR) << "convert failed for " << op_name_; - return RET_ERROR; - } - int rank = GetGPUGroupSize(); - auto plugin = std::make_shared(op_name_, rank, device_id_); - MS_LOG(INFO) << op_name_ << " group size: " << rank << ", rank id: " << GetRankID(); - nvinfer1::IPluginV2Layer *allgather_layer = ctx->network()->addPluginV2(inputTensors, 1, *plugin); - if (allgather_layer == nullptr) { - MS_LOG(ERROR) << "create AllGather layer failed for: " << op_name_; - return RET_ERROR; - } - nvinfer1::ITensor *allgather_out = allgather_layer->getOutput(0); - allgather_layer->setName(op_name_.c_str()); - ctx->RegisterTensor(ITensorHelper{allgather_out, input(ctx, 0).format_, input(ctx, 0).same_format_}, - out_tensors_[0].Name()); - this->layer_ = allgather_layer; - return RET_OK; -} - -// AllGatherPlugin -int AllGatherPlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, - cudaStream_t stream) noexcept { - MS_LOG(INFO) << "all gather run at rank id: " << GetRankID() << " stream: " << stream; - nvinfer1::Dims input_dims = inputDesc[0].dims; - int send_element_cnt = std::accumulate(input_dims.d, input_dims.d + input_dims.nbDims, 1, std::multiplies()); - const void *input = inputs[0]; - void *output = outputs[0]; - auto ret = DistributionCollective::instance().AllGatherWrapper(input, output, send_element_cnt, inputDesc->type, - stream, NCCL_WORLD_GROUP); - if (ret != RET_OK) { - MS_LOG(ERROR) << "AllGather nccl run failed for " << layer_name_; - return ret; - } - return RET_OK; -} - -nvinfer1::IPluginV2DynamicExt *AllGatherPlugin::clone() const noexcept { - auto *plugin = new AllGatherPlugin(*this); - plugin->setPluginNamespace(name_space_.c_str()); - return plugin; -} - -nvinfer1::DimsExprs AllGatherPlugin::getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, - int nbInputs, nvinfer1::IExprBuilder &exprBuilder) noexcept { - nvinfer1::DimsExprs out_dims{}; - out_dims.nbDims = inputs->nbDims; - auto rank_dim = exprBuilder.constant(rank_); - out_dims.d[0] = exprBuilder.operation(nvinfer1::DimensionOperation::kPROD, *inputs->d[0], *rank_dim); - for (int i = 1; i < inputs->nbDims; i++) { - out_dims.d[i] = inputs->d[i]; - } - return out_dims; -} -REGISTER_TENSORRT_CREATOR(ops::kNameAllGather, AllGatherTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/allgather_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/allgather_tensorrt.h deleted file mode 100644 index 667f095f484f2a11f995719ca48794efb58544fd..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/allgather_tensorrt.h +++ /dev/null @@ -1,74 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_ALLGATHER_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_ALLGATHER_TENSORRT_H_ -#include -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" -#include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h" -#include "src/extendrt/delegate/tensorrt/distribution/distribution_collective.h" - -namespace mindspore::lite { -constexpr auto ALLGATHER_PLUGIN_NAME{"AllGatherPlugin"}; -class AllGatherTensorRT : public TensorRTOp { - public: - AllGatherTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~AllGatherTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; -}; - -class AllGatherPlugin : public TensorRTPlugin { - public: - AllGatherPlugin(const std::string name, int rank, uint32_t device_id) - : TensorRTPlugin(name, std::string(ALLGATHER_PLUGIN_NAME), device_id), rank_(rank) {} - - AllGatherPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - : TensorRTPlugin(std::string(name), std::string(ALLGATHER_PLUGIN_NAME)) { - const nvinfer1::PluginField *fields = fc->fields; - rank_ = static_cast(fields[0].data)[0]; - } - - AllGatherPlugin(const char *name, const void *serialData, size_t serialLength) - : TensorRTPlugin(std::string(name), std::string(ALLGATHER_PLUGIN_NAME)) { - DeserializeValue(&serialData, &serialLength, &rank_, sizeof(int)); - } - - AllGatherPlugin() = delete; - - nvinfer1::IPluginV2DynamicExt *clone() const noexcept override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, - nvinfer1::IExprBuilder &exprBuilder) noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override; - - private: - int rank_{0}; -}; -class AllGatherPluginCreater : public TensorRTPluginCreater { - public: - AllGatherPluginCreater() : TensorRTPluginCreater(std::string(ALLGATHER_PLUGIN_NAME)) {} -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_ALLGATHER_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/allreduce_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/allreduce_tensorrt.cc deleted file mode 100644 index 59bab91f93a37b11ede139ea4e34c19390fd1f9a..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/allreduce_tensorrt.cc +++ /dev/null @@ -1,118 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/allreduce_tensorrt.h" -#include -#include -#include "NvInferRuntimeCommon.h" -#include "infer/all_reduce.h" - -namespace mindspore::lite { -REGISTER_TENSORRT_PLUGIN(AllReducePluginCreater); -template class TensorRTPluginCreater; -template -nvinfer1::PluginFieldCollection TensorRTPluginCreater::field_collection_{}; -template -std::vector TensorRTPluginCreater::fields_; - -int AllReduceTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { -#ifndef LITE_CUDA_DISTRIBUTION - MS_LOG(ERROR) - << "Unsupported package for gpu distribution feature, please recompile with MS_ENABLE_CUDA_DISTRIBUTION set to on."; - return RET_ERROR; -#else - if (in_tensors.size() != 1) { - MS_LOG(ERROR) << "invalid input tensor size: " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "invalid output tensor size: " << out_tensors.size(); - return RET_ERROR; - } - dynamic_shape_params_.support_hw_dynamic_ = false; - return RET_OK; -#endif -} - -int AllReduceTensorRT::AddInnerOp(TensorRTContext *ctx) { - nvinfer1::ITensor *inputTensors[] = {input(ctx, 0).trt_tensor_}; - auto reduce_op = AsOps(); - if (reduce_op == nullptr) { - MS_LOG(ERROR) << "convert failed for " << op_name_; - return RET_ERROR; - } - auto reduce_mode = ReduceMode::Reduce_Sum; - auto rank = GetGPUGroupSize(); - auto plugin = std::make_shared(op_name_, reduce_mode, rank, device_id_); - MS_LOG(INFO) << op_name_ << " group size: " << rank << ", rank id: " << GetRankID(); - nvinfer1::IPluginV2Layer *allreduce_layer = ctx->network()->addPluginV2(inputTensors, 1, *plugin); - if (allreduce_layer == nullptr) { - MS_LOG(ERROR) << "create AllReduce layer failed for: " << op_name_; - return RET_ERROR; - } - nvinfer1::ITensor *allreduce_out = allreduce_layer->getOutput(0); - allreduce_layer->setName(op_name_.c_str()); - ctx->RegisterTensor(ITensorHelper{allreduce_out, input(ctx, 0).format_, input(ctx, 0).same_format_}, - out_tensors_[0].Name()); - this->layer_ = allreduce_layer; - return RET_OK; -} - -// AllReducePlugin -int AllReducePlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, - cudaStream_t stream) noexcept { - MS_LOG(INFO) << "AllReduce run at rank id: " << GetRankID() << " stream: " << stream; - nvinfer1::Dims output_dims = outputDesc[0].dims; - int recieve_element_cnt = - std::accumulate(output_dims.d, output_dims.d + output_dims.nbDims, 1, std::multiplies()); - const void *input = inputs[0]; - void *output = outputs[0]; - auto data_type = inputDesc->type; - auto ret = DistributionCollective::instance().AllReduceWrapper(input, output, recieve_element_cnt, data_type, - red_mode_, stream, NCCL_WORLD_GROUP); - if (ret != RET_OK) { - MS_LOG(ERROR) << "AllReduce nccl run failed for " << layer_name_; - return ret; - } - return RET_OK; -} - -nvinfer1::IPluginV2DynamicExt *AllReducePlugin::clone() const noexcept { - auto *plugin = new AllReducePlugin(*this); - plugin->setPluginNamespace(name_space_.c_str()); - return plugin; -} - -nvinfer1::DimsExprs AllReducePlugin::getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, - int nbInputs, nvinfer1::IExprBuilder &exprBuilder) noexcept { - nvinfer1::DimsExprs out_dims{}; - out_dims.nbDims = inputs->nbDims; - for (int i = 0; i < inputs->nbDims; i++) { - out_dims.d[i] = inputs->d[i]; - } - return out_dims; -} - -size_t AllReducePlugin::getSerializationSize() const noexcept { return sizeof(ReduceMode); } - -void AllReducePlugin::serialize(void *buffer) const noexcept { - SerializeValue(&buffer, &red_mode_, sizeof(ReduceMode)); -} - -REGISTER_TENSORRT_CREATOR(ops::kAllReduce, AllReduceTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/allreduce_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/allreduce_tensorrt.h deleted file mode 100644 index 25029ef31cbe4376ac7441fd15aa4a738e65d3c5..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/allreduce_tensorrt.h +++ /dev/null @@ -1,82 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_ALLREDUCE_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_ALLREDUCE_TENSORRT_H_ -#include -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" -#include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h" -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "src/extendrt/delegate/tensorrt/distribution/distribution_collective.h" - -namespace mindspore::lite { -constexpr auto ALLREDUCE_PLUGIN_NAME{"AllReducePlugin"}; -class AllReduceTensorRT : public TensorRTOp { - public: - AllReduceTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~AllReduceTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; -}; - -class AllReducePlugin : public TensorRTPlugin { - public: - AllReducePlugin(const std::string name, ReduceMode red_mode, int rank, uint32_t device_id) - : TensorRTPlugin(name, std::string(ALLREDUCE_PLUGIN_NAME), device_id), red_mode_(red_mode), rank_(rank) {} - - AllReducePlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - : TensorRTPlugin(std::string(name), std::string(ALLREDUCE_PLUGIN_NAME)) { - const nvinfer1::PluginField *fields = fc->fields; - red_mode_ = static_cast(fields[0].data)[0]; - rank_ = static_cast(fields[1].data)[0]; - } - - AllReducePlugin(const char *name, const void *serialData, size_t serialLength) - : TensorRTPlugin(std::string(name), std::string(ALLREDUCE_PLUGIN_NAME)) { - DeserializeValue(&serialData, &serialLength, &red_mode_, sizeof(ReduceMode)); - DeserializeValue(&serialData, &serialLength, &rank_, sizeof(int)); - } - - AllReducePlugin() = delete; - - // IPluginV2DynamicExt Methods - nvinfer1::IPluginV2DynamicExt *clone() const noexcept override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, - nvinfer1::IExprBuilder &exprBuilder) noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override; - - size_t getSerializationSize() const noexcept override; - void serialize(void *buffer) const noexcept override; - - private: - ReduceMode red_mode_; - int rank_{0}; -}; -class AllReducePluginCreater : public TensorRTPluginCreater { - public: - AllReducePluginCreater() : TensorRTPluginCreater(std::string(ALLREDUCE_PLUGIN_NAME)) {} -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_ALLREDUCE_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/batchnorm_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/batchnorm_tensorrt.cc deleted file mode 100644 index 7f39e25a490934009405e1c1204afda5f70eaeae..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/batchnorm_tensorrt.cc +++ /dev/null @@ -1,97 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/batchnorm_tensorrt.h" -#include -#include -#include "infer/fused_batch_norm.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_f.h" - -namespace mindspore::lite { -int BatchNormTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != INPUT_SIZE5 && in_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - return RET_OK; -} - -int BatchNormTensorRT::AddInnerOp(TensorRTContext *ctx) { - CHECK_NULL_RETURN(ctx->network()); - auto norm_op = AsOps(); - CHECK_NULL_RETURN(norm_op); - epsilon_ = norm_op->get_epsilon(); - - ITensorHelper norm_input; - int ret = PreprocessInputs2SameDim(ctx, input(ctx, 0), &norm_input); - if (ret != RET_OK || norm_input.trt_tensor_ == nullptr) { - MS_LOG(ERROR) << "PreprocessInputs2SameDim norm_input failed for " << op_name_; - return RET_ERROR; - } - auto expect_shape = ConvertMSShape(norm_input.trt_tensor_->getDimensions()); - gamma_ = ConvertTensorWithExpandDims(ctx, in_tensors_[1], expect_shape, op_name_ + in_tensors_[1].Name()); - CHECK_NULL_RETURN(gamma_); - beta_ = - ConvertTensorWithExpandDims(ctx, in_tensors_[BETA_INDEX], expect_shape, op_name_ + in_tensors_[BETA_INDEX].Name()); - CHECK_NULL_RETURN(beta_); - mean_ = - ConvertTensorWithExpandDims(ctx, in_tensors_[MEAN_INDEX], expect_shape, op_name_ + in_tensors_[MEAN_INDEX].Name()); - CHECK_NULL_RETURN(mean_); - var_ = - ConvertTensorWithExpandDims(ctx, in_tensors_[VAR_INDEX], expect_shape, op_name_ + in_tensors_[VAR_INDEX].Name()); - CHECK_NULL_RETURN(var_); - - return RunAsTrtOps(ctx, norm_input); -} - -int BatchNormTensorRT::RunAsTrtOps(TensorRTContext *ctx, ITensorHelper norm_input) { - // var + min epsilon - auto const_epsilon = ConvertScalarToITensor(ctx, norm_input.trt_tensor_->getDimensions().nbDims, &epsilon_, - DataType::kNumberTypeFloat32, op_name_ + "_epsilion"); - CHECK_NULL_RETURN(const_epsilon); - auto var_epsilon = - ctx->network()->addElementWise(*var_, *const_epsilon, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0); - CHECK_NULL_RETURN(var_epsilon); - - // standard deviation - auto std_dev = ctx->network()->addUnary(*var_epsilon, nvinfer1::UnaryOperation::kSQRT)->getOutput(0); - CHECK_NULL_RETURN(std_dev); - - auto scale = ctx->network()->addElementWise(*gamma_, *std_dev, nvinfer1::ElementWiseOperation::kDIV)->getOutput(0); - CHECK_NULL_RETURN(scale); - - auto mean_mul_scale = - ctx->network()->addElementWise(*mean_, *scale, nvinfer1::ElementWiseOperation::kPROD)->getOutput(0); - CHECK_NULL_RETURN(mean_mul_scale); - - auto bias = - ctx->network()->addElementWise(*beta_, *mean_mul_scale, nvinfer1::ElementWiseOperation::kSUB)->getOutput(0); - CHECK_NULL_RETURN(bias); - - // scale with bias - auto scale_layer = - ctx->network()->addElementWise(*norm_input.trt_tensor_, *scale, nvinfer1::ElementWiseOperation::kPROD); - this->layer_ = scale_layer; - auto scale_out = scale_layer->getOutput(0); - CHECK_NULL_RETURN(scale_out); - auto beta_out = ctx->network()->addElementWise(*scale_out, *bias, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0); - CHECK_NULL_RETURN(beta_out); - ctx->RegisterTensor(ITensorHelper{beta_out, Format::NCHW, true}, out_tensors_[0].Name()); - return RET_OK; -} -REGISTER_TENSORRT_CREATOR(ops::kNameFusedBatchNorm, BatchNormTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/batchnorm_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/batchnorm_tensorrt.h deleted file mode 100644 index 7c7b97544fe65d326b12264a9f3966a21d5c25f9..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/batchnorm_tensorrt.h +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_BATCH_NORM_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_BATCH_NORM_TENSORRT_H_ -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" - -namespace mindspore::lite { -constexpr int BETA_INDEX = 2; -constexpr int MEAN_INDEX = 3; -constexpr int VAR_INDEX = 4; - -class BatchNormTensorRT : public TensorRTOp { - public: - BatchNormTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~BatchNormTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; - - private: - int RunAsTrtOps(TensorRTContext *ctx, ITensorHelper helper); - - float epsilon_{0.0f}; - nvinfer1::ITensor *gamma_{nullptr}; - nvinfer1::ITensor *beta_{nullptr}; - nvinfer1::ITensor *mean_{nullptr}; - nvinfer1::ITensor *var_{nullptr}; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_BATCH_NORM_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/batchtospace_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/batchtospace_tensorrt.cc deleted file mode 100644 index e8d53516a217f449bcf53a7ea4af745c831f710e..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/batchtospace_tensorrt.cc +++ /dev/null @@ -1,158 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/batchtospace_tensorrt.h" -#include -#include -#include -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "NvInferRuntimeCommon.h" -#include "kernel/gpu/cuda_impl/cuda_ops/batchtospace_impl.cuh" -#include "infer/batch_to_space.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_b.h" - -namespace mindspore::lite { -int BatchToSpaceTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != INPUT_SIZE3) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - - if (out_tensors.size() < 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); - return RET_ERROR; - } - return RET_OK; -} - -int BatchToSpaceTensorRT::AddInnerOp(TensorRTContext *ctx) { - nvinfer1::ITensor *input_tensor = input(ctx, 0).trt_tensor_; - auto block_size_vec = ConvertTensorAsIntVector(in_tensors_[1]); - constexpr size_t block_input_elem_count = 2; - if (block_size_vec.size() != block_input_elem_count) { - MS_LOG(ERROR) << "Failed to get block input, block size " << block_size_vec.size() << ", node: " << op_name_; - return RET_ERROR; - } - int bh = block_size_vec[0]; - int bw = block_size_vec[1]; - if (bh != bw) { - MS_LOG(ERROR) << "block_h not equal block_w " << op_name_; - return RET_ERROR; - } - auto pad_vec = ConvertTensorAsIntVector(in_tensors_[INPUT_SIZE2]); - constexpr size_t pad_input_elem_count = 4; - if (pad_vec.size() != pad_input_elem_count) { - MS_LOG(ERROR) << "Failed to get pad input, pad size " << pad_vec.size() << ", node: " << op_name_; - return RET_ERROR; - } - int ph0 = pad_vec[0]; - int ph1 = pad_vec[1]; - int pw0 = pad_vec[INPUT_SIZE2]; - int pw1 = pad_vec[INPUT_SIZE3]; - - auto plugin = std::make_shared(input_tensor->getName(), bh, ph0, ph1, pw0, pw1, device_id_); - if (plugin == nullptr) { - MS_LOG(ERROR) << "add batchtospace plugin failed for" << op_name_; - return RET_ERROR; - } - nvinfer1::ITensor *inputTensors[] = {input_tensor}; - nvinfer1::IPluginV2Layer *space2batch_opt_layer = ctx->network()->addPluginV2(inputTensors, 1, *plugin); - if (space2batch_opt_layer == nullptr) { - MS_LOG(ERROR) << "add batchtospace op failed for TensorRT."; - return RET_ERROR; - } - space2batch_opt_layer->setName(op_name_.c_str()); - nvinfer1::ITensor *out_tensor = space2batch_opt_layer->getOutput(0); - ctx->RegisterTensor(ITensorHelper{out_tensor, Format::NCHW, true}, out_tensors_[0].Name()); - this->layer_ = space2batch_opt_layer; - return RET_OK; -} - -REGISTER_TENSORRT_PLUGIN(BatchToSpacePluginCreater); -template class TensorRTPluginCreater; -template -nvinfer1::PluginFieldCollection TensorRTPluginCreater::field_collection_{}; -template -std::vector TensorRTPluginCreater::fields_; - -int BatchToSpacePlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workspace, cudaStream_t stream) noexcept { - return RunCudaBatchToSpace(inputDesc, inputs, outputs, stream); -} - -int BatchToSpacePlugin::RunCudaBatchToSpace(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, - void *const *outputs, cudaStream_t stream) { - nvinfer1::Dims input_dims = inputDesc[0].dims; - int in = input_dims.d[0]; - int ic = input_dims.d[1]; - int ih = input_dims.d[2]; - int iw = input_dims.d[3]; - int on = in / (bh_ * bh_); - int oc = ic; - int oh = ih * bh_ - ph0_ - ph1_; - int ow = iw * bh_ - pw0_ - pw1_; - - int size = on * oc * oh * ow; - - CalBatchToSpace(size, static_cast(inputs[0]), in, ih, iw, ic, on, oh, ow, oc, ph0_, ph1_, pw0_, - pw1_, bh_, static_cast(outputs[0]), device_id_, stream); - return RET_OK; -} - -nvinfer1::IPluginV2DynamicExt *BatchToSpacePlugin::clone() const noexcept { - auto *plugin = new (std::nothrow) BatchToSpacePlugin(*this); - if (plugin == nullptr) { - MS_LOG(ERROR) << "new plugin failed!"; - return nullptr; - } - plugin->setPluginNamespace(name_space_.c_str()); - return plugin; -} - -size_t BatchToSpacePlugin::getSerializationSize() const noexcept { return sizeof(int) * 5; } - -nvinfer1::DimsExprs BatchToSpacePlugin::getOutputDimensions(int index, const nvinfer1::DimsExprs *inputs, - int nbInputDims, - nvinfer1::IExprBuilder &exprBuilder) noexcept { - nvinfer1::DimsExprs dims; - dims.nbDims = inputs[0].nbDims; - auto bh_mul_bh = exprBuilder.constant(bh_ * bh_); - dims.d[0] = exprBuilder.operation(nvinfer1::DimensionOperation::kFLOOR_DIV, *inputs[0].d[0], *bh_mul_bh); - auto bh = exprBuilder.constant(bh_); - dims.d[1] = inputs[0].d[1]; - auto ph_sum = exprBuilder.constant(ph0_ + ph1_); - auto pod0 = exprBuilder.operation(nvinfer1::DimensionOperation::kPROD, *inputs[0].d[INPUT_SIZE2], *bh); - dims.d[INPUT_SIZE2] = exprBuilder.operation(nvinfer1::DimensionOperation::kSUB, *pod0, *ph_sum); - auto pw_sum = exprBuilder.constant(pw0_ + pw1_); - auto pod1 = exprBuilder.operation(nvinfer1::DimensionOperation::kPROD, *inputs[0].d[INPUT_SIZE3], *bh); - dims.d[INPUT_SIZE3] = exprBuilder.operation(nvinfer1::DimensionOperation::kSUB, *pod1, *pw_sum); - return dims; -} - -void BatchToSpacePlugin::serialize(void *buffer) const noexcept { - SerializeValue(&buffer, &bh_, sizeof(int)); - SerializeValue(&buffer, &ph0_, sizeof(int)); - SerializeValue(&buffer, &ph1_, sizeof(int)); - SerializeValue(&buffer, &pw0_, sizeof(int)); - SerializeValue(&buffer, &pw1_, sizeof(int)); -} -REGISTER_TENSORRT_CREATOR(ops::kNameBatchToSpace, BatchToSpaceTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/batchtospace_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/batchtospace_tensorrt.h deleted file mode 100644 index 1b5dc12e6806534635879ed01a65723100ce3b65..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/batchtospace_tensorrt.h +++ /dev/null @@ -1,101 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_BATCHTOSPACETENSORRT_PLUGIN_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_BATCHTOSPACETENSORRT_PLUGIN_H_ - -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" -#include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h" - -namespace mindspore::lite { -class BatchToSpaceTensorRT : public TensorRTOp { - public: - BatchToSpaceTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~BatchToSpaceTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; -}; - -constexpr auto BATCHTOSPACETENSORRT_PLUGIN_NAME{"BatchToSpacePlugin"}; -class BatchToSpacePlugin : public TensorRTPlugin { - public: - BatchToSpacePlugin(const std::string name, int bh, int ph0, int ph1, int pw0, int pw1, uint32_t device_id) - : TensorRTPlugin(name, std::string(BATCHTOSPACETENSORRT_PLUGIN_NAME), device_id), - bh_(bh), - ph0_(ph0), - ph1_(ph1), - pw0_(pw0), - pw1_(pw1) {} - - BatchToSpacePlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - : TensorRTPlugin(std::string(name), std::string(BATCHTOSPACETENSORRT_PLUGIN_NAME)) { - const nvinfer1::PluginField *fields = fc->fields; - bh_ = static_cast(fields[0].data)[0]; - ph0_ = static_cast(fields[1].data)[0]; - ph1_ = static_cast(fields[2].data)[0]; - pw0_ = static_cast(fields[3].data)[0]; - pw1_ = static_cast(fields[4].data)[0]; - } - - BatchToSpacePlugin(const char *name, const void *serialData, size_t serialLength) - : TensorRTPlugin(std::string(name), std::string(BATCHTOSPACETENSORRT_PLUGIN_NAME)) { - DeserializeValue(&serialData, &serialLength, &bh_, sizeof(int)); - DeserializeValue(&serialData, &serialLength, &ph0_, sizeof(int)); - DeserializeValue(&serialData, &serialLength, &ph1_, sizeof(int)); - DeserializeValue(&serialData, &serialLength, &pw0_, sizeof(int)); - DeserializeValue(&serialData, &serialLength, &pw1_, sizeof(int)); - } - - BatchToSpacePlugin() = delete; - - nvinfer1::IPluginV2DynamicExt *clone() const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override; - size_t getSerializationSize() const noexcept override; - void serialize(void *buffer) const noexcept override; - nvinfer1::DimsExprs getOutputDimensions(int index, const nvinfer1::DimsExprs *inputs, int nbInputDims, - nvinfer1::IExprBuilder &exprBuilder) noexcept override; - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs, - int nbOutputs) noexcept { - return tensorsDesc[pos].type == nvinfer1::DataType::kFLOAT && - tensorsDesc[pos].format == nvinfer1::TensorFormat::kLINEAR; - } - - private: - int RunCudaBatchToSpace(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, void *const *outputs, - cudaStream_t stream); - int bh_; - int ph0_; - int ph1_; - int pw0_; - int pw1_; - const std::string layer_name_; - std::string name_space_; -}; -class BatchToSpacePluginCreater : public TensorRTPluginCreater { - public: - BatchToSpacePluginCreater() : TensorRTPluginCreater(std::string(BATCHTOSPACETENSORRT_PLUGIN_NAME)) {} -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_BATCHTOSPACETENSORRT_PLUGIN_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/bounding_box_decode_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/bounding_box_decode_tensorrt.cc deleted file mode 100644 index cb94c53207c788483ded4ced095f872e3fd52f61..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/bounding_box_decode_tensorrt.cc +++ /dev/null @@ -1,137 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/bounding_box_decode_tensorrt.h" -#include -#include -#include -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "NvInferRuntimeCommon.h" -#include "kernel/gpu/cuda_impl/cuda_ops/boundingbox_decode_impl.cuh" -#include "infer/bounding_box_decode.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_b.h" - -namespace mindspore::lite { -int BoundingBoxDecodeTensorRT::IsSupport(const BaseOperatorPtr &base_operator, - const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != INPUT_SIZE2) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); - return RET_ERROR; - } - return RET_OK; -} - -int BoundingBoxDecodeTensorRT::AddInnerOp(TensorRTContext *ctx) { - auto op = AsOps(); - auto wh_ratio_clip_attr = op->GetAttr("wh_ratio_clip"); - float wh_ratio_clip = GetValue(wh_ratio_clip_attr); - - auto max_shape_attr = op->GetAttr("max_shape"); - std::vector max_shape = GetValue>(max_shape_attr); - if (max_shape.size() != INPUT_SIZE2) { - MS_LOG(ERROR) << "max_shape size not equal 2 for " << op_name_; - } - std::vector max_shape_32(INPUT_SIZE2); - max_shape_32[0] = max_shape[0]; - max_shape_32[1] = max_shape[1]; - - auto means_attr = op->GetAttr("means"); - std::vector means = GetValue>(means_attr); - auto stds_attr = op->GetAttr("stds"); - std::vector stds = GetValue>(stds_attr); - - auto plugin = std::make_shared(op_name_, means, stds, max_shape_32, wh_ratio_clip); - if (plugin == nullptr) { - MS_LOG(ERROR) << "create ActivationOptPlugin failed for " << op_name_; - return RET_ERROR; - } - auto in_tensor1 = input(ctx, 0).trt_tensor_; - auto in_tensor2 = input(ctx, 1).trt_tensor_; - if (in_tensors_[0].DataType() == DataType::kNumberTypeFloat16) { - in_tensor1 = TRTTensorCast(ctx, in_tensor1, nvinfer1::DataType::kFLOAT, op_name_ + "_cast_in_0"); - } - if (in_tensors_[1].DataType() == DataType::kNumberTypeFloat16) { - in_tensor2 = TRTTensorCast(ctx, in_tensor2, nvinfer1::DataType::kFLOAT, op_name_ + "_cast_in_1"); - } - nvinfer1::ITensor *inputTensors[] = {in_tensor1, in_tensor2}; - nvinfer1::IPluginV2Layer *layer = ctx->network()->addPluginV2(inputTensors, INPUT_SIZE2, *plugin); - this->layer_ = layer; - nvinfer1::ITensor *op_out_tensor = layer->getOutput(0); - if (op_out_tensor == nullptr) { - MS_LOG(ERROR) << "addElementWise out tensor is nullptr."; - return RET_ERROR; - } - ctx->RegisterTensor(ITensorHelper{op_out_tensor, input(ctx, 0).format_, input(ctx, 0).same_format_}, - out_tensors_[0].Name()); - return RET_OK; -} - -REGISTER_TENSORRT_PLUGIN(BoundingBoxDecodePluginCreater); -template class TensorRTPluginCreater; -template -nvinfer1::PluginFieldCollection TensorRTPluginCreater::field_collection_{}; -template -std::vector TensorRTPluginCreater::fields_; - -int BoundingBoxDecodePlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workspace, cudaStream_t stream) noexcept { - return RunCudaBoundingBoxDecode(inputDesc, inputs, outputs, stream); -} - -int BoundingBoxDecodePlugin::RunCudaBoundingBoxDecode(const nvinfer1::PluginTensorDesc *inputDesc, - const void *const *inputs, void *const *outputs, - cudaStream_t stream) { - BoundingBoxDecode(GetDimsVolume(inputDesc[0].dims), static_cast(inputs[0]), - static_cast(inputs[1]), static_cast(outputs[0]), means_[0], - means_[1], means_[INPUT_SIZE2], means_[INPUT_SIZE3], stds_[0], stds_[1], stds_[INPUT_SIZE2], - stds_[INPUT_SIZE3], max_shape_[0], max_shape_[1], wh_ratio_clip_, device_id_, stream); - return RET_OK; -} - -nvinfer1::IPluginV2DynamicExt *BoundingBoxDecodePlugin::clone() const noexcept { - auto *plugin = new BoundingBoxDecodePlugin(*this); - plugin->setPluginNamespace(name_space_.c_str()); - return plugin; -} - -size_t BoundingBoxDecodePlugin::getSerializationSize() const noexcept { - return sizeof(float) * (INPUT_SIZE4 + INPUT_SIZE4 + 1) + sizeof(int) * INPUT_SIZE2; -} - -bool BoundingBoxDecodePlugin::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, - int nbInputs, int nbOutputs) noexcept { - return tensorsDesc[pos].type == nvinfer1::DataType::kFLOAT && - tensorsDesc[pos].format == nvinfer1::TensorFormat::kLINEAR; -} - -void BoundingBoxDecodePlugin::serialize(void *buffer) const noexcept { - SerializeValue(&buffer, &means_[0], sizeof(float) * INPUT_SIZE4); - SerializeValue(&buffer, &stds_[0], sizeof(float) * INPUT_SIZE4); - SerializeValue(&buffer, &max_shape_[0], sizeof(int) * INPUT_SIZE2); - SerializeValue(&buffer, &wh_ratio_clip_, sizeof(float)); -} - -REGISTER_TENSORRT_CREATOR(ops::kNameBoundingBoxDecode, BoundingBoxDecodeTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/bounding_box_decode_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/bounding_box_decode_tensorrt.h deleted file mode 100644 index 051a2f9a84d0eb25c37053caad5f200bb7d3544a..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/bounding_box_decode_tensorrt.h +++ /dev/null @@ -1,115 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_BOUNDING_BOX_DECODE_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_BOUNDING_BOX_DECODE_TENSORRT_H_ - -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" -#include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h" - -namespace mindspore::lite { -class BoundingBoxDecodeTensorRT : public TensorRTOp { - public: - BoundingBoxDecodeTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~BoundingBoxDecodeTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; -}; - -constexpr auto BOUNDING_BOX_DECODE_PLUGIN_NAME{"BoundingBoxDecodePlugin"}; -class BoundingBoxDecodePlugin : public TensorRTPlugin { - public: - BoundingBoxDecodePlugin(const std::string name, const std::vector &means, const std::vector &stds, - const std::vector &max_shape, float wh_ratio_clip) - : TensorRTPlugin(name, std::string(BOUNDING_BOX_DECODE_PLUGIN_NAME)), - means_(means), - stds_(stds), - max_shape_(max_shape), - wh_ratio_clip_(wh_ratio_clip) {} - - BoundingBoxDecodePlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - : TensorRTPlugin(std::string(name), std::string(BOUNDING_BOX_DECODE_PLUGIN_NAME)) { - const nvinfer1::PluginField *fields = fc->fields; - means_.resize(INPUT_SIZE4); - means_[0] = static_cast(fields[0].data)[0]; - means_[1] = static_cast(fields[1].data)[0]; - means_[INPUT_SIZE2] = static_cast(fields[INPUT_SIZE2].data)[0]; - means_[INPUT_SIZE3] = static_cast(fields[INPUT_SIZE3].data)[0]; - stds_.resize(INPUT_SIZE4); - stds_[0] = static_cast(fields[INPUT_SIZE4].data)[0]; - stds_[1] = static_cast(fields[INPUT_SIZE5].data)[0]; - stds_[INPUT_SIZE2] = static_cast(fields[INPUT_SIZE6].data)[0]; - stds_[INPUT_SIZE3] = static_cast(fields[INPUT_SIZE7].data)[0]; - max_shape_.resize(INPUT_SIZE2); - max_shape_[0] = static_cast(fields[INPUT_SIZE8].data)[0]; - max_shape_[1] = static_cast(fields[INPUT_SIZE9].data)[0]; - wh_ratio_clip_ = static_cast(fields[INPUT_SIZE10].data)[0]; - } - - BoundingBoxDecodePlugin(const char *name, const void *serialData, size_t serialLength) - : TensorRTPlugin(std::string(name), std::string(BOUNDING_BOX_DECODE_PLUGIN_NAME)) { - DeserializeValue(&serialData, &serialLength, &means_[0], sizeof(float)); - DeserializeValue(&serialData, &serialLength, &means_[1], sizeof(float)); - DeserializeValue(&serialData, &serialLength, &means_[INPUT_SIZE2], sizeof(float)); - DeserializeValue(&serialData, &serialLength, &means_[INPUT_SIZE3], sizeof(float)); - DeserializeValue(&serialData, &serialLength, &stds_[0], sizeof(float)); - DeserializeValue(&serialData, &serialLength, &stds_[1], sizeof(float)); - DeserializeValue(&serialData, &serialLength, &stds_[INPUT_SIZE2], sizeof(float)); - DeserializeValue(&serialData, &serialLength, &stds_[INPUT_SIZE3], sizeof(float)); - DeserializeValue(&serialData, &serialLength, &max_shape_[0], sizeof(int)); - DeserializeValue(&serialData, &serialLength, &max_shape_[1], sizeof(int)); - DeserializeValue(&serialData, &serialLength, &wh_ratio_clip_, sizeof(float)); - } - - BoundingBoxDecodePlugin() = delete; - - nvinfer1::IPluginV2DynamicExt *clone() const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override; - size_t getSerializationSize() const noexcept override; - void serialize(void *buffer) const noexcept override; - int getNbOutputs() const noexcept override { return 1; } - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs, - int nbOutputs) noexcept override; - nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, int nbInputs) const - noexcept override { - return nvinfer1::DataType::kFLOAT; - } - - private: - int RunCudaBoundingBoxDecode(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, - void *const *outputs, cudaStream_t stream); - const std::string layer_name_; - std::string name_space_; - std::vector means_; - std::vector stds_; - std::vector max_shape_; - float wh_ratio_clip_; -}; -class BoundingBoxDecodePluginCreater : public TensorRTPluginCreater { - public: - BoundingBoxDecodePluginCreater() : TensorRTPluginCreater(std::string(BOUNDING_BOX_DECODE_PLUGIN_NAME)) {} -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_BOUNDING_BOX_DECODE_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/cast_plugin.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/cast_plugin.cc deleted file mode 100644 index bafb9f60be492ba9a6759ae244e087a2a517f295..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/cast_plugin.cc +++ /dev/null @@ -1,78 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/cast_plugin.h" -#include -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/cuda_impl/cast.cuh" - -namespace mindspore::lite { -REGISTER_TENSORRT_PLUGIN(CastPluginCreater); -template class TensorRTPluginCreater; -template -nvinfer1::PluginFieldCollection TensorRTPluginCreater::field_collection_{}; -template -std::vector TensorRTPluginCreater::fields_; - -int CastPlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, - cudaStream_t stream) noexcept { - nvinfer1::Dims input_dims = inputDesc[0].dims; - int element_cnt = std::accumulate(input_dims.d, input_dims.d + input_dims.nbDims, 1, std::multiplies()); - - if (inputDesc->type == outputDesc->type) { - int element_size = (outputDesc->type == nvinfer1::DataType::kFLOAT) - ? sizeof(float) - : ((outputDesc->type == nvinfer1::DataType::kINT32) ? sizeof(int) : 0); - auto cuda_ret = cudaMemcpy(outputs[0], inputs[0], element_cnt * element_size, cudaMemcpyDeviceToDevice); - if (cuda_ret != cudaSuccess) { - MS_LOG(ERROR) << "copy mem failed for " << layer_name_; - return RET_ERROR; - } - return RET_OK; - } - if (inputDesc->type == nvinfer1::DataType::kINT32 && dest_datatype_ == nvinfer1::DataType::kFLOAT) { - auto input = static_cast(inputs[0]); - auto output = static_cast(outputs[0]); - Cast(element_cnt, input, output, stream); - } else if (inputDesc->type == nvinfer1::DataType::kFLOAT && dest_datatype_ == nvinfer1::DataType::kINT32) { - auto input = static_cast(inputs[0]); - auto output = static_cast(outputs[0]); - Cast(element_cnt, input, output, stream); - } else { - MS_LOG(ERROR) << "unsupported data type cast " << layer_name_; - } - return RET_OK; -} - -nvinfer1::IPluginV2DynamicExt *CastPlugin::clone() const noexcept { - auto *plugin = new CastPlugin(*this); - plugin->setPluginNamespace(name_space_.c_str()); - return plugin; -} - -nvinfer1::DataType CastPlugin::getOutputDataType(int, const nvinfer1::DataType *, int) const noexcept { - return dest_datatype_; -} - -size_t CastPlugin::getSerializationSize() const noexcept { return sizeof(nvinfer1::DataType); } - -void CastPlugin::serialize(void *buffer) const noexcept { - SerializeValue(&buffer, &dest_datatype_, sizeof(nvinfer1::DataType)); -} -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/cast_plugin.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/cast_plugin.h deleted file mode 100644 index 61c1b3a4a8073f21a5cbc3112dbb442cd2969746..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/cast_plugin.h +++ /dev/null @@ -1,60 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_CAST_PLUGIN_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_CAST_PLUGIN_H_ -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h" - -namespace mindspore::lite { -constexpr auto CAST_PLUGIN_NAME{"CastPluginCreater"}; -class CastPlugin : public TensorRTPlugin { - public: - CastPlugin(const std::string &name, nvinfer1::DataType dest_datatype, uint32_t device_id = 0) - : TensorRTPlugin(name, std::string(CAST_PLUGIN_NAME), device_id), dest_datatype_(dest_datatype) {} - - CastPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - : TensorRTPlugin(std::string(name), std::string(CAST_PLUGIN_NAME)) { - const nvinfer1::PluginField *fields = fc->fields; - dest_datatype_ = static_cast(fields[0].data)[0]; - } - - CastPlugin(const char *name, const void *serialData, size_t serialLength) - : TensorRTPlugin(std::string(name), std::string(CAST_PLUGIN_NAME)) { - DeserializeValue(&serialData, &serialLength, &dest_datatype_, sizeof(nvinfer1::DataType)); - } - - CastPlugin() = delete; - - nvinfer1::IPluginV2DynamicExt *clone() const noexcept override; - - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override; - - nvinfer1::DataType getOutputDataType(int, const nvinfer1::DataType *, int) const noexcept override; - - size_t getSerializationSize() const noexcept override; - void serialize(void *buffer) const noexcept override; - - private: - nvinfer1::DataType dest_datatype_; -}; -class CastPluginCreater : public TensorRTPluginCreater { - public: - CastPluginCreater() : TensorRTPluginCreater(std::string(CAST_PLUGIN_NAME)) {} -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_CAST_PLUGIN_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/cast_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/cast_tensorrt.cc deleted file mode 100644 index d43534b73cb43f91c776eb9080efee222a0988ca..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/cast_tensorrt.cc +++ /dev/null @@ -1,80 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/cast_tensorrt.h" -#include -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/op/cast_plugin.h" -#include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" - -namespace mindspore::lite { -int CastTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != INPUT_SIZE2) { - MS_LOG(ERROR) << "invalid input tensor size: " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "invalid output tensor size: " << out_tensors.size(); - return RET_ERROR; - } - return RET_OK; -} - -int CastTensorRT::AddInnerOp(TensorRTContext *ctx) { - // cast to type tensor - auto type_tensor = in_tensors_[1]; - if (!type_tensor.IsConst()) { - MS_LOG(ERROR) << "unknown cast type of " << op_name_; - return RET_ERROR; - } - auto type_vec = ConvertTensorAsIntVector(type_tensor); - if (type_vec.size() != 1) { - MS_LOG(ERROR) << "Failed to get type input, type size " << type_vec.size() << ", node: " << op_name_; - return RET_ERROR; - } - DataType data_type = static_cast(type_vec[0]); - MS_LOG(DEBUG) << op_name_ << " cast to data type(43 float): " << type_vec[0]; - nvinfer1::DataType dest_datatype = ConvertDataType(data_type); - auto trt_tensor = input(ctx, 0).trt_tensor_; - -#if TRT_VERSION_GE(7, 2) - dest_datatype = (dest_datatype == nvinfer1::DataType::kBOOL ? nvinfer1::DataType::kINT32 : dest_datatype); - auto cast_layer = ctx->network()->addIdentity(*trt_tensor); -#else - auto plugin = std::make_shared(op_name_, dest_datatype); - nvinfer1::ITensor *inputTensors[] = {trt_tensor}; - nvinfer1::IPluginV2Layer *cast_layer = ctx->network()->addPluginV2(inputTensors, 1, *plugin); -#endif - if (cast_layer == nullptr) { - MS_LOG(ERROR) << "create cast layer failed for: " << op_name_; - return RET_ERROR; - } -#if TRT_VERSION_GE(7, 2) - cast_layer->setOutputType(0, dest_datatype); -#endif - cast_layer->setName(op_name_.c_str()); - nvinfer1::ITensor *cast_out = cast_layer->getOutput(0); - ctx->RegisterTensor(ITensorHelper{cast_out, input(ctx, 0).format_, input(ctx, 0).same_format_}, - out_tensors_[0].Name()); - this->layer_ = cast_layer; - return RET_OK; -} -REGISTER_TENSORRT_CREATOR(ops::kNameCast, CastTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/cast_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/cast_tensorrt.h deleted file mode 100644 index 1b75dffb455a30bac867ac6711ea83824f115481..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/cast_tensorrt.h +++ /dev/null @@ -1,44 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_CAST_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_CAST_TENSORRT_H_ -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" -#include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h" -#include "src/extendrt/delegate/tensorrt/cuda_impl/cast.cuh" - -namespace mindspore::lite { -class CastTensorRT : public TensorRTOp { - public: - CastTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~CastTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - bool IsWeightInputHanledInner() const override { return true; } - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; - - private: - // CastTensorRT -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_CAST_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/concate_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/concate_tensorrt.cc deleted file mode 100644 index e420675b81e7327bb64c17bcc0adf66f09ae4a6d..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/concate_tensorrt.cc +++ /dev/null @@ -1,160 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/concate_tensorrt.h" -#include -#include -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" - -namespace mindspore::lite { -ConcateTensorRT::ConcateTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, const std::string &name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) { - if (type_ == ops::kNameConcat) { - axis_ = AsOps()->get_axis(); - } else { - axis_ = AsOps()->get_axis(); - } -} - -int ConcateTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (type_ != ops::kNameStack && type_ != ops::kNameConcat) { - MS_LOG(ERROR) << "Unsupported op :" << op_name_ << " , type: " << type_; - return RET_ERROR; - } - if (in_tensors.size() == 0) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); - return RET_ERROR; - } - - return RET_OK; -} -int ConcateTensorRT::CheckParams(TensorRTContext *ctx) { - if (ctx == nullptr || ctx->network() == nullptr) { - MS_LOG(ERROR) << "context or network is invalid"; - return RET_ERROR; - } - int input_nbDims = input(ctx, 0).trt_tensor_->getDimensions().nbDims; - if (axis_ == -1) { - axis_ = input_nbDims - (type_ == ops::kNameConcat); - } - if (axis_ < 0 || axis_ > input_nbDims || (axis_ == input_nbDims && type_ != ops::kNameStack)) { - MS_LOG(ERROR) << "concate_op valid axis : " << axis_ << " , input dims : " << input_nbDims; - return RET_ERROR; - } - - if (SizeToInt(in_tensors_.size()) != ReadyInputsNumber(ctx)) { - MS_LOG(ERROR) << "concate_op in tensor is invalid, trt tensor has " << in_tensors_.size() - << ", but origin ms tensor has " << in_tensors_.size(); - return RET_ERROR; - } - return RET_OK; -} - -int ConcateTensorRT::AddInnerOp(TensorRTContext *ctx) { - if (CheckParams(ctx) != RET_OK) { - MS_LOG(ERROR) << "Check input tensors failed: " << op_name_; - return RET_ERROR; - } - if (type_ == ops::kNameConcat && in_tensors_.size() == 1) { - auto output = ctx->network()->addIdentity(*input(ctx, 0).trt_tensor_)->getOutput(0); - ctx->RegisterTensor(ITensorHelper{output}, out_tensors_[0].Name()); - return RET_OK; - } - - nvinfer1::ITensor *trt_input_tensors[in_tensors_.size()]; - int ret = PreProcessInputs(ctx, trt_input_tensors); - if (ret != RET_OK) { - MS_LOG(ERROR) << "PreProcessInputs failed for " << op_name_; - return ret; - } - - bool has_rank_0 = false; - for (size_t i = 0; i < in_tensors_.size(); ++i) { - if (!input(ctx, i).is_tensor) { - has_rank_0 = true; - break; - } - } - if (type_ == ops::kNameStack && !has_rank_0) { - for (size_t i = 0; i < in_tensors_.size(); ++i) { - auto shuffle_layer = ctx->network()->addShuffle(*trt_input_tensors[i]); - if (shuffle_layer == nullptr) { - MS_LOG(ERROR) << "addShuffle failed for TensorRT."; - return RET_ERROR; - } - bool has_rank_n = (trt_input_tensors[i]->getDimensions().nbDims > 1); - if (has_rank_n) { - auto shuffer_dims_opt = UnsqueezeDims(trt_input_tensors[i]->getDimensions(), axis_, 1); - if (!shuffer_dims_opt) { - MS_LOG(ERROR) << "UnsqueezeDims failed."; - return RET_ERROR; - } - shuffle_layer->setReshapeDimensions(shuffer_dims_opt.value()); - trt_input_tensors[i] = shuffle_layer->getOutput(0); - } - } - } - nvinfer1::IConcatenationLayer *concate_layer = - ctx->network()->addConcatenation(trt_input_tensors, static_cast(in_tensors_.size())); - if (concate_layer == nullptr) { - MS_LOG(ERROR) << "addConcatenation failed for TensorRT."; - return RET_ERROR; - } - - if (axis_ != RET_INVALID_OP_ATTR) { - concate_layer->setAxis(axis_); - } - concate_layer->setName(op_name_.c_str()); - auto concat_output = concate_layer->getOutput(0); - ctx->RegisterTensor(ITensorHelper{concat_output, NCHW, true}, out_tensors_[0].Name()); - this->layer_ = concate_layer; - return RET_OK; -} - -int ConcateTensorRT::PreProcessInputs(TensorRTContext *ctx, nvinfer1::ITensor *trt_input_tensors[]) { - int input_nbDims = input(ctx, 0).trt_tensor_->getDimensions().nbDims; - out_format_ = input(ctx, 0).format_; - same_format_ = input(ctx, 0).same_format_; - - for (size_t i = 0; i < in_tensors_.size(); i++) { - if (input(ctx, i).trt_tensor_->getDimensions().nbDims != input_nbDims) { - MS_LOG(ERROR) << "dims of inputs is invalid for " << in_tensors_[i].Name() - << " input dim size : " << input(ctx, i).trt_tensor_->getDimensions().nbDims - << " ms input dims size : " << input_nbDims; - return RET_ERROR; - } - // keep origin format if all input format are the same - if (input_nbDims == DIMENSION_4D && input(ctx, i).format_ != out_format_) { - out_format_ = Format::NCHW; - } - } - - for (size_t i = 0; i < in_tensors_.size(); i++) { - trt_input_tensors[i] = input(ctx, i).trt_tensor_; - MS_LOG(DEBUG) << "concate input " << GetTensorFormat(input(ctx, i)); - } - return RET_OK; -} -REGISTER_TENSORRT_CREATOR(ops::kNameConcat, ConcateTensorRT) -REGISTER_TENSORRT_CREATOR(ops::kNameStack, ConcateTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/concate_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/concate_tensorrt.h deleted file mode 100644 index 06aac81b267d8c2bf2a8f7fdf2b26aac7a1c7157..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/concate_tensorrt.h +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2021-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_CONCATE_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_CONCATE_TENSORRT_H_ -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" -#include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "infer/stack.h" - -namespace mindspore::lite { -class ConcateTensorRT : public TensorRTOp { - public: - ConcateTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, const std::string &name); - - ~ConcateTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; - - private: - int PreProcessInputs(TensorRTContext *ctx, nvinfer1::ITensor *trt_input_tensors[]); - int CheckParams(TensorRTContext *ctx); - - Format out_format_{Format::NCHW}; - bool same_format_{true}; - int axis_; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_CONCATE_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/constantofshape_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/constantofshape_tensorrt.cc deleted file mode 100644 index de8d35d3ef3563a5c262c0d5ee5e751fdc83447a..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/constantofshape_tensorrt.cc +++ /dev/null @@ -1,88 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/constantofshape_tensorrt.h" -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "infer/constant_of_shape.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" - -namespace mindspore::lite { -int ConstantOfShapeTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size() << " : " << op_name_; - return RET_ERROR; - } - - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size() << " : " << op_name_; - return RET_ERROR; - } - auto constofshape_op = AsOps(); - if (constofshape_op == nullptr) { - MS_LOG(ERROR) << "Failed to as operator ConstantOfShape: " << op_name_; - return RET_ERROR; - } - DataType data_type = static_cast(constofshape_op->get_data_type()); - if (data_type != DataType::kNumberTypeInt32 && data_type != DataType::kNumberTypeFloat32) { - MS_LOG(ERROR) << "Unsupported data type for " << op_name_; - return RET_ERROR; - } - return RET_OK; -} - -int ConstantOfShapeTensorRT::AddInnerOp(TensorRTContext *ctx) { - auto constofshape_op = AsOps(); - if (constofshape_op == nullptr) { - MS_LOG(ERROR) << "Failed to as operator ConstantOfShape: " << op_name_; - return RET_ERROR; - } - auto &&value_vector = constofshape_op->get_value(); - nvinfer1::ITensor *value_tensor; - if (static_cast(constofshape_op->get_data_type()) == DataType::kNumberTypeInt32) { - auto value = static_cast(*value_vector.begin()); - value_tensor = ctx->ConvertTo1DTensor(value); - } else { - auto value = *value_vector.begin(); - value_tensor = ctx->ConvertTo1DTensor(value); - } - CHECK_NULL_RETURN(value_tensor); - - auto unsqueeze_layer = ctx->network()->addShuffle(*value_tensor); - CHECK_NULL_RETURN(unsqueeze_layer); - - auto shape = input(ctx, 0).trt_tensor_; - int rank = shape->getDimensions().d[0]; - nvinfer1::Dims unsqueeze{rank}; - std::fill(unsqueeze.d, unsqueeze.d + rank, 1); - unsqueeze_layer->setReshapeDimensions(unsqueeze); - unsqueeze_layer->setZeroIsPlaceholder(false); - value_tensor = unsqueeze_layer->getOutput(0); - CHECK_NULL_RETURN(value_tensor); - - auto out_tensor = Broadcast(ctx, value_tensor, shape); - if (static_cast(constofshape_op->get_data_type()) == DataType::kNumberTypeInt32) { - out_tensor = TRTTensorCast(ctx, out_tensor, nvinfer1::DataType::kINT32, op_name_ + "_cast_out"); - } - - auto output_helper = ITensorHelper{out_tensor, Format::NCHW, true}; - ctx->RegisterTensor(output_helper, out_tensors_[0].Name()); - MS_LOG(DEBUG) << "output " << GetTensorFormat(output_helper); - this->layer_ = unsqueeze_layer; - return RET_OK; -} -REGISTER_TENSORRT_CREATOR(ops::kNameConstantOfShape, ConstantOfShapeTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/constantofshape_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/constantofshape_tensorrt.h deleted file mode 100644 index 03c2c390f24dd8b47a9872f66604b323ca17844f..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/constantofshape_tensorrt.h +++ /dev/null @@ -1,38 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_CONSTANT_OF_SHAPE_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_CONSTANT_OF_SHAPE_TENSORRT_H_ -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" - -namespace mindspore::lite { -class ConstantOfShapeTensorRT : public TensorRTOp { - public: - ConstantOfShapeTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~ConstantOfShapeTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_CONSTANT_OF_SHAPE_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/conv2dbackpropinput_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/conv2dbackpropinput_tensorrt.cc deleted file mode 100644 index 608c3f4e00bdf31e279b31a06f13b9e4de580f0c..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/conv2dbackpropinput_tensorrt.cc +++ /dev/null @@ -1,147 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/conv2dbackpropinput_tensorrt.h" -#include -#include "nnacl/pack.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" - -namespace mindspore::lite { -int Conv2dBackpropInputTensorRT::IsSupport(const BaseOperatorPtr &base_operator, - const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != INPUT_SIZE3) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); - return RET_ERROR; - } - return RET_OK; -} -int Conv2dBackpropInputTensorRT::AddInnerOp(TensorRTContext *ctx) { - if (ctx == nullptr || ctx->network() == nullptr) { - MS_LOG(ERROR) << "context or network is invalid"; - return RET_ERROR; - } - auto deconv_op = AsOps(); - if (deconv_op == nullptr) { - MS_LOG(ERROR) << "op action convert failed"; - return RET_ERROR; - } - nvinfer1::ITensor *deconv_input = input(ctx, 0).trt_tensor_; - - // transpose weight - const auto &weight_tensor = in_tensors_[1]; - nvinfer1::Weights kernelWeights = lite::ConvertWeight(weight_tensor); - - // deconv basic params - int nbOutputMaps = weight_tensor.Shape()[1]; - if (nbOutputMaps <= 0) { - MS_LOG(ERROR) << "out_channel is invalid"; - return RET_ERROR; - } - - auto kernel_size = deconv_op->get_kernel_size(); - if (kernel_size.empty()) { - MS_LOG(ERROR) << "kernel_size is null"; - return RET_ERROR; - } - nvinfer1::Dims kernelSize = lite::ConvertCudaDims(std::vector(kernel_size.begin(), kernel_size.end())); - if (kernelSize.nbDims == -1) { - MS_LOG(ERROR) << "ConvertCudaDims failed for " << op_name_; - return RET_ERROR; - } - // bias - nvinfer1::Weights biasWeights{}; - biasWeights.type = ConvertDataType(weight_tensor.DataType()); - biasWeights.count = 0; - biasWeights.values = nullptr; - - nvinfer1::IDeconvolutionLayer *deconv_layer = - ctx->network()->addDeconvolutionNd(*deconv_input, nbOutputMaps, kernelSize, kernelWeights, biasWeights); - - if (deconv_layer == nullptr) { - MS_LOG(ERROR) << "DeconvolutionLayer failed"; - return RET_ERROR; - } - deconv_layer->setName((op_name_ + "_deconv").c_str()); - this->layer_ = deconv_layer; - // set extra params - SetAttributes(deconv_op, deconv_layer); - - nvinfer1::ITensor *out_tensor = deconv_layer->getOutput(0); - ctx->RegisterTensor(ITensorHelper{out_tensor, Format::NCHW, true}, out_tensors_[0].Name()); - return RET_OK; -} - -void Conv2dBackpropInputTensorRT::SetAttributes(const std::shared_ptr &ms_op, - nvinfer1::IDeconvolutionLayer *decon_layer) { - // kernel_size - auto kernel_size = ms_op->get_kernel_size(); - if (!kernel_size.empty()) { - auto kernel_size_val = std::vector(kernel_size.begin(), kernel_size.end()); - nvinfer1::Dims kernel_size_dims = lite::ConvertCudaDims(kernel_size_val); - if (kernel_size_dims.nbDims == -1) { - MS_LOG(ERROR) << "ConvertCudaDims failed for " << op_name_; - return; - } - decon_layer->setKernelSizeNd(kernel_size_dims); - } - - // nbOutputMaps - int nbOutputMaps = in_tensors_[1].Shape()[1]; - decon_layer->setNbOutputMaps(nbOutputMaps); - - // stride - auto stride = ms_op->get_stride(); - if (!stride.empty()) { - auto stride_val = std::vector(stride.begin() + INPUT_SIZE2, stride.end()); - nvinfer1::Dims stride_dims = lite::ConvertCudaDims(stride_val); - if (stride_dims.nbDims == -1) { - MS_LOG(ERROR) << "ConvertCudaDims failed for " << op_name_; - return; - } - decon_layer->setStrideNd(stride_dims); - } - - // nbGroups - int32_t nbGroups = static_cast(ms_op->get_group()); - decon_layer->setNbGroups(nbGroups); - - // padding - PadMode pad_mode = ms_op->get_pad_mode(); - if (pad_mode == PadMode::SAME) { - decon_layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER); - } else { - auto padding = ms_op->get_pad_list(); - auto padding_val = std::vector(padding.begin(), padding.end()); - nvinfer1::Dims dims_pre{}; - dims_pre.nbDims = DIMENSION_2D; - dims_pre.d[0] = padding_val[0]; // up - dims_pre.d[1] = padding_val[INPUT_SIZE2]; // left - decon_layer->setPrePadding(dims_pre); - nvinfer1::Dims dims_post{}; - dims_post.nbDims = DIMENSION_2D; - dims_post.d[0] = padding_val[1]; - dims_post.d[1] = padding_val[INPUT_SIZE3]; - decon_layer->setPostPadding(dims_post); - } -} - -REGISTER_TENSORRT_CREATOR(ops::kNameConv2DBackpropInputFusion, Conv2dBackpropInputTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/conv2dbackpropinput_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/conv2dbackpropinput_tensorrt.h deleted file mode 100644 index da5b1c91990655d7d788eb56562282fb9529f8b6..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/conv2dbackpropinput_tensorrt.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_CONV2DBACKPROPINPUT_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_CONV2DBACKPROPINPUT_TENSORRT_H_ -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" -#include "infer/cxx_api/conv2d_backprop_input_fusion.h" - -namespace mindspore::lite { -class Conv2dBackpropInputTensorRT : public TensorRTOp { - public: - Conv2dBackpropInputTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, const std::string &name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~Conv2dBackpropInputTensorRT() = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; - - private: - void SetAttributes(const std::shared_ptr &conv_op, - nvinfer1::IDeconvolutionLayer *decon_layer); -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_CONV2DBACKPROPINPUT_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/conv3d_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/conv3d_tensorrt.cc deleted file mode 100644 index ad02eabe89387fbb806c51a9974374561095f2a8..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/conv3d_tensorrt.cc +++ /dev/null @@ -1,150 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/conv3d_tensorrt.h" -#include -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" - -namespace mindspore::lite { -constexpr int BIAS_INDEX = 2; -int Conv3DTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != INPUT_SIZE2) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); - return RET_ERROR; - } - return RET_OK; -} - -int Conv3DTensorRT::AddInnerOp(TensorRTContext *ctx) { - if (ctx == nullptr || ctx->network() == nullptr) { - MS_LOG(ERROR) << "context or network is invalid"; - return RET_ERROR; - } - auto conv_op = AsOps(); - if (conv_op == nullptr) { - MS_LOG(ERROR) << "op action convert failed"; - return RET_ERROR; - } - - nvinfer1::ITensor *conv_input = input(ctx, 0).trt_tensor_; - - // transpose weight - const auto &weight_tensor = in_tensors_[1]; - nvinfer1::Weights kernelWeights = lite::ConvertWeight(weight_tensor); - - // conv - int nbOutputMaps = weight_tensor.Shape()[0]; - if (nbOutputMaps <= 0) { - MS_LOG(ERROR) << "out_channel is invalid"; - return RET_ERROR; - } - - auto kernel_size = conv_op->get_kernel_size(); - if (kernel_size.empty()) { - MS_LOG(ERROR) << "kernel_size is null"; - return RET_ERROR; - } - nvinfer1::Dims kernelSize = lite::ConvertCudaDims(std::vector(kernel_size.begin(), kernel_size.end())); - if (kernelSize.nbDims == -1) { - MS_LOG(ERROR) << "ConvertCudaDims failed for " << op_name_; - return RET_ERROR; - } - // bias not support - nvinfer1::Weights biasWeights{}; - - nvinfer1::IConvolutionLayer *conv_layer = - ctx->network()->addConvolutionNd(*conv_input, nbOutputMaps, kernelSize, kernelWeights, biasWeights); - - if (conv_layer == nullptr) { - MS_LOG(ERROR) << "ConvolutionLayer failed"; - return RET_ERROR; - } - conv_layer->setName((op_name_ + "_conv").c_str()); - this->layer_ = conv_layer; - - // add params - SetAttributes(conv_op, conv_layer); - - // add activation - auto out_tensor = conv_layer->getOutput(0); - ctx->RegisterTensor(ITensorHelper{out_tensor, Format::NCHW, true}, out_tensors_[0].Name()); - return RET_OK; -} - -void Conv3DTensorRT::SetAttributes(const std::shared_ptr &conv_op, - nvinfer1::IConvolutionLayer *conv_layer) { - auto stride = conv_op->get_stride(); - if (!stride.empty()) { - auto stride_val = std::vector(stride.begin() + 2, stride.end()); - auto dims = ConvertCudaDims(stride_val); - if (dims.nbDims == -1) { - MS_LOG(ERROR) << "ConvertCudaDims failed for " << op_name_; - return; - } - conv_layer->setStrideNd(dims); - } - - auto dilation = conv_op->get_dilation(); - if (!dilation.empty()) { - auto dilation_val = std::vector(dilation.begin() + 2, dilation.end()); - auto dims = ConvertCudaDims(dilation_val); - if (dims.nbDims == -1) { - MS_LOG(ERROR) << "ConvertCudaDims failed for " << op_name_; - return; - } - conv_layer->setDilationNd(dims); - } - int nbGroups = conv_op->get_group(); - if (nbGroups > 0) { - conv_layer->setNbGroups(nbGroups); - } - - PadMode pad_mode = conv_op->get_pad_mode(); - if (pad_mode == PadMode::SAME) { - conv_layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER); - } else { - std::vector padding; - padding = conv_op->get_pad(); - if (padding.size() == DIMENSION_6D) { - auto padding_val = std::vector(padding.begin(), padding.end()); - if (padding_val[0] != padding_val[1] || padding_val[DIMENSION_2D] != padding_val[DIMENSION_3D] || - padding_val[DIMENSION_4D] != padding_val[DIMENSION_5D]) { - MS_LOG(WARNING) << op_name_ << " has different up and down padding value"; - nvinfer1::Dims pre_dims{INPUT_SIZE3, {padding_val[0], padding_val[DIMENSION_2D], padding_val[DIMENSION_4D]}}; - conv_layer->setPrePadding(pre_dims); - nvinfer1::Dims post_dims{INPUT_SIZE3, {padding_val[1], padding_val[DIMENSION_3D], padding_val[DIMENSION_5D]}}; - conv_layer->setPostPadding(post_dims); - } else { - nvinfer1::Dims dims{INPUT_SIZE3, {padding_val[0], padding_val[DIMENSION_2D], padding_val[DIMENSION_4D]}}; - conv_layer->setPaddingNd(dims); - } - } else if (padding.empty()) { - nvinfer1::Dims3 dims; - conv_layer->setPaddingNd(dims); - } else { - MS_LOG(WARNING) << "pad list is invalid for " << op_name_; - } - } -} - -Conv3DTensorRT::~Conv3DTensorRT() {} -REGISTER_TENSORRT_CREATOR(ops::kNameConv3D, Conv3DTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/convolution_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/convolution_tensorrt.cc deleted file mode 100644 index 6cc8b7b3d134f32a6532df5535fc19b7a88c5881..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/convolution_tensorrt.cc +++ /dev/null @@ -1,191 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/convolution_tensorrt.h" -#include -#include "src/extendrt/delegate/tensorrt/op/activation_tensorrt.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" - -namespace mindspore::lite { -constexpr int BIAS_INDEX = 2; - -int ConvolutionTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (!IsShapeKnown()) { - MS_LOG(ERROR) << "Unsupported input tensor unknown shape: " << op_name_; - return RET_ERROR; - } - if (in_tensors.size() != INPUT_SIZE2 && in_tensors.size() != INPUT_SIZE3) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); - return RET_ERROR; - } - if (in_tensors[0].format() != Format::NHWC && in_tensors[0].format() != Format::NCHW) { - MS_LOG(ERROR) << "Unsupported input tensor format of " << in_tensors[0].format(); - return RET_ERROR; - } - return RET_OK; -} - -int ConvolutionTensorRT::AddInnerOp(TensorRTContext *ctx) { - if (ctx == nullptr || ctx->network() == nullptr) { - MS_LOG(ERROR) << "context or network is invalid"; - return RET_ERROR; - } - auto conv_op = AsOps(); - if (conv_op == nullptr) { - MS_LOG(ERROR) << "op action convert failed"; - return RET_ERROR; - } - - nvinfer1::ITensor *conv_input = input(ctx, 0).trt_tensor_; - - // transpose weight - const auto &weight_tensor = in_tensors_[1]; - nvinfer1::Weights kernelWeights = lite::ConvertWeight(weight_tensor); - - // conv - int nbOutputMaps = weight_tensor.Shape()[0]; - if (nbOutputMaps <= 0) { - MS_LOG(ERROR) << "out_channel is invalid"; - return RET_ERROR; - } - - auto kernel_size = conv_op->get_kernel_size(); - if (kernel_size.empty()) { - MS_LOG(ERROR) << "kernel_size is null"; - return RET_ERROR; - } - nvinfer1::Dims kernelSize = lite::ConvertCudaDims(std::vector(kernel_size.begin(), kernel_size.end())); - if (kernelSize.nbDims == -1) { - MS_LOG(ERROR) << "ConvertCudaDims failed for " << op_name_; - return RET_ERROR; - } - // bias - nvinfer1::Weights biasWeights{}; - if (in_tensors_.size() >= INPUT_SIZE3) { - biasWeights = lite::ConvertWeight(in_tensors_[BIAS_INDEX]); - } else { - biasWeights.type = ConvertDataType(weight_tensor.DataType()); - biasWeights.count = 0; - biasWeights.values = nullptr; - } - - nvinfer1::IConvolutionLayer *conv_layer = - ctx->network()->addConvolutionNd(*conv_input, nbOutputMaps, kernelSize, kernelWeights, biasWeights); - - if (conv_layer == nullptr) { - MS_LOG(ERROR) << "ConvolutionLayer failed"; - return RET_ERROR; - } - conv_layer->setName((op_name_ + "_conv").c_str()); - this->layer_ = conv_layer; - - // add params - SetAttributes(conv_op, conv_layer); - - // add activation - nvinfer1::ILayer *activation_layer = nullptr; - ActivationType activation_type = ActivationType::NO_ACTIVATION; - if (conv_op->HasAttr(ops::kActivationType)) { - activation_type = conv_op->get_activation_type(); - } - if (activation_type == ActivationType::NO_ACTIVATION) { - activation_layer = conv_layer; - } else { - activation_layer = - ActivationTensorRT::AddActivation(ctx, activation_type, 0, 0, 0, conv_layer->getOutput(0), op_name_, device_id_); - if (activation_layer == nullptr) { - MS_LOG(ERROR) << "addActivation for conv failed"; - return RET_ERROR; - } - activation_layer->setName((op_name_ + "_activation").c_str()); - } - auto out_tensor = activation_layer->getOutput(0); - ctx->RegisterTensor(ITensorHelper{out_tensor, Format::NCHW, true}, out_tensors_[0].Name()); - return RET_OK; -} - -void ConvolutionTensorRT::SetAttributes(const std::shared_ptr &conv_op, - nvinfer1::IConvolutionLayer *conv_layer) { - auto stride = conv_op->get_stride(); - if (!stride.empty()) { - auto stride_val = std::vector(stride.begin(), stride.end()); - auto dims = ConvertCudaDims(stride_val); - if (dims.nbDims == -1) { - MS_LOG(ERROR) << "ConvertCudaDims failed for " << op_name_; - return; - } - conv_layer->setStrideNd(dims); - } - - auto dilation = conv_op->get_dilation(); - if (!dilation.empty()) { - auto dilation_val = std::vector(dilation.begin(), dilation.end()); - auto dims = ConvertCudaDims(dilation_val); - if (dims.nbDims == -1) { - MS_LOG(ERROR) << "ConvertCudaDims failed for " << op_name_; - return; - } - conv_layer->setDilationNd(dims); - } - int nbGroups = conv_op->get_group(); - if (nbGroups > 0) { - conv_layer->setNbGroups(nbGroups); - } - - PadMode pad_mode = conv_op->get_pad_mode(); - if (pad_mode == PadMode::SAME) { - conv_layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER); - } else { - std::vector padding; - if (conv_op->HasAttr(ops::kPadList)) { - padding = conv_op->get_pad_list(); - } else if (conv_op->HasAttr(ops::kPad)) { - padding = conv_op->get_pad(); - } - if (padding.size() == DIMENSION_4D) { - auto padding_val = std::vector(padding.begin(), padding.end()); - if (padding_val[0] != padding_val[1] || padding_val[DIMENSION_2D] != padding_val[DIMENSION_3D]) { - MS_LOG(WARNING) << op_name_ << " has different up and down padding value"; - nvinfer1::Dims2 pre_dims(padding_val[0], padding_val[DIMENSION_2D]); - conv_layer->setPrePadding(pre_dims); - nvinfer1::Dims2 post_dims(padding_val[1], padding_val[DIMENSION_3D]); - conv_layer->setPostPadding(post_dims); - } else { - nvinfer1::Dims2 dims(padding_val[0], padding_val[DIMENSION_2D]); - conv_layer->setPaddingNd(dims); - } - } else if (padding.empty()) { - nvinfer1::Dims2 dims; - conv_layer->setPaddingNd(dims); - } else { - MS_LOG(WARNING) << "pad list is invalid for " << op_name_; - } - } -} - -ConvolutionTensorRT::~ConvolutionTensorRT() { - if (pack_weight_ != nullptr) { - free(pack_weight_); - pack_weight_ = nullptr; - } -} -REGISTER_TENSORRT_CREATOR(ops::kNameConv2DFusion, ConvolutionTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/convolution_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/convolution_tensorrt.h deleted file mode 100644 index da709f8019a45cf52c0ee498fe2823db4aade527..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/convolution_tensorrt.h +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_CONVOLUTION_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_CONVOLUTION_TENSORRT_H_ -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" -#include "infer/cxx_api/conv2d_fusion.h" - -namespace mindspore::lite { -class ConvolutionTensorRT : public TensorRTOp { - public: - ConvolutionTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~ConvolutionTensorRT() override; - - int AddInnerOp(TensorRTContext *ctx) override; - - bool IsWeightInputHanledInner() const override { return true; } - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; - - private: - void SetAttributes(const std::shared_ptr &conv_op, nvinfer1::IConvolutionLayer *current_layer_); - - void *pack_weight_{nullptr}; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_CONVOLUTION_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/cumsum_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/cumsum_tensorrt.cc deleted file mode 100644 index 0a216592f0e04b31ff6f0fe3e5e951fa81ccc723..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/cumsum_tensorrt.cc +++ /dev/null @@ -1,127 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/cumsum_tensorrt.h" -#include -#include -#include -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "NvInferRuntimeCommon.h" -#include "kernel/gpu/cuda_impl/cuda_ops/cumsum_impl.cuh" -#include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" - -namespace mindspore::lite { -int CumsumTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != INPUT_SIZE2) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - - if (out_tensors.size() < 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); - return RET_ERROR; - } - return RET_OK; -} - -int CumsumTensorRT::AddInnerOp(TensorRTContext *ctx) { - ITensorHelper input_helper; - int ret = PreprocessInputs2SameDim(ctx, input(ctx, 0), &input_helper); - if (ret != RET_OK || input_helper.trt_tensor_ == nullptr) { - MS_LOG(ERROR) << "PreprocessInputs2SameDim input tensor failed for " << op_name_; - return ret; - } - auto axis_vec = ConvertTensorAsIntVector(in_tensors_[1]); - if (axis_vec.size() != 1) { - MS_LOG(ERROR) << "Failed to get axis input, axis size " << axis_vec.size() << ", node: " << op_name_; - return RET_ERROR; - } - int axis = axis_vec[0]; - auto cumsum_op = AsOps(); - bool exclusive = cumsum_op->get_exclusive(); - bool reverse = cumsum_op->get_reverse(); - auto plugin = - std::make_shared(input_helper.trt_tensor_->getName(), axis, exclusive, reverse, device_id_); - nvinfer1::ITensor *inputTensors[] = {input_helper.trt_tensor_}; - nvinfer1::IPluginV2Layer *cumsum_opt_layer = ctx->network()->addPluginV2(inputTensors, 1, *plugin); - if (cumsum_opt_layer == nullptr) { - MS_LOG(ERROR) << "add cumsum op failed for TensorRT."; - return RET_ERROR; - } - cumsum_opt_layer->setName(op_name_.c_str()); - nvinfer1::ITensor *out_tensor = cumsum_opt_layer->getOutput(0); - ctx->RegisterTensor(ITensorHelper{out_tensor, input_helper.format_, input_helper.same_format_}, - out_tensors_[0].Name()); - this->layer_ = cumsum_opt_layer; - return RET_OK; -} - -REGISTER_TENSORRT_PLUGIN(CumsumPluginCreater); -template class TensorRTPluginCreater; -template -nvinfer1::PluginFieldCollection TensorRTPluginCreater::field_collection_{}; -template -std::vector TensorRTPluginCreater::fields_; - -int CumsumPlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, - cudaStream_t stream) noexcept { - return RunCudaCumsum(inputDesc, inputs, outputs, stream); -} - -int CumsumPlugin::RunCudaCumsum(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, - void *const *outputs, cudaStream_t stream) { - auto &dims = inputDesc[0].dims; - size_t out_dim = 1; - for (int i = 0; i < axis_; ++i) { - out_dim *= dims.d[i]; - } - size_t in_dim = 1; - for (int i = axis_ + 1; i < dims.nbDims; ++i) { - in_dim *= dims.d[i]; - } - size_t axis_dim = dims.d[axis_]; - size_t stride = axis_dim * in_dim; - size_t stride2 = in_dim; - CumSum(static_cast(inputs[0]), static_cast(outputs[0]), nullptr, out_dim, - axis_dim, in_dim, stride, stride2, exclusive_, reverse_, device_id_, stream); - return RET_OK; -} - -nvinfer1::IPluginV2DynamicExt *CumsumPlugin::clone() const noexcept { - auto *plugin = new (std::nothrow) CumsumPlugin(*this); - if (plugin == nullptr) { - MS_LOG(ERROR) << "new plugin failed!"; - return nullptr; - } - plugin->setPluginNamespace(name_space_.c_str()); - return plugin; -} - -size_t CumsumPlugin::getSerializationSize() const noexcept { return sizeof(int) + 2 * sizeof(bool); } - -void CumsumPlugin::serialize(void *buffer) const noexcept { - SerializeValue(&buffer, &axis_, sizeof(int)); - SerializeValue(&buffer, &exclusive_, sizeof(bool)); - SerializeValue(&buffer, &reverse_, sizeof(bool)); -} -REGISTER_TENSORRT_CREATOR(ops::kNameCumSum, CumsumTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/cumsum_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/cumsum_tensorrt.h deleted file mode 100644 index b6441b5dab579976b787332db2c482e6e3e25175..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/cumsum_tensorrt.h +++ /dev/null @@ -1,86 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_CUMSUM_PLUGIN_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_CUMSUM_PLUGIN_H_ - -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" -#include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h" - -namespace mindspore::lite { -class CumsumTensorRT : public TensorRTOp { - public: - CumsumTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~CumsumTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; -}; - -constexpr auto CUMSUM_PLUGIN_NAME{"CumsumPlugin"}; -class CumsumPlugin : public TensorRTPlugin { - public: - CumsumPlugin(const std::string name, int axis, bool exclusive, bool reverse, uint32_t device_id) - : TensorRTPlugin(name, std::string(CUMSUM_PLUGIN_NAME), device_id), - axis_(axis), - exclusive_(exclusive), - reverse_(reverse) {} - - CumsumPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - : TensorRTPlugin(std::string(name), std::string(CUMSUM_PLUGIN_NAME)) { - const nvinfer1::PluginField *fields = fc->fields; - axis_ = static_cast(fields[0].data)[0]; - exclusive_ = static_cast(fields[1].data)[0]; - reverse_ = static_cast(fields[2].data)[0]; - } - - CumsumPlugin(const char *name, const void *serialData, size_t serialLength) - : TensorRTPlugin(std::string(name), std::string(CUMSUM_PLUGIN_NAME)) { - DeserializeValue(&serialData, &serialLength, &axis_, sizeof(int)); - DeserializeValue(&serialData, &serialLength, &exclusive_, sizeof(bool)); - DeserializeValue(&serialData, &serialLength, &reverse_, sizeof(bool)); - } - - CumsumPlugin() = delete; - - nvinfer1::IPluginV2DynamicExt *clone() const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override; - size_t getSerializationSize() const noexcept override; - void serialize(void *buffer) const noexcept override; - - private: - int RunCudaCumsum(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, void *const *outputs, - cudaStream_t stream); - int axis_; - bool exclusive_; - bool reverse_; - const std::string layer_name_; - std::string name_space_; -}; -class CumsumPluginCreater : public TensorRTPluginCreater { - public: - CumsumPluginCreater() : TensorRTPluginCreater(std::string(CUMSUM_PLUGIN_NAME)) {} -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_CUMSUM_PLUGIN_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/deconv3d_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/deconv3d_tensorrt.cc deleted file mode 100644 index 412913d0e8c76040d9d15e1741f11ac8198e41a8..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/deconv3d_tensorrt.cc +++ /dev/null @@ -1,165 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/deconv3d_tensorrt.h" -#include -#include "nnacl/pack.h" -#include "infer/cxx_api/conv2d_transpose_fusion.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" - -namespace mindspore::lite { -int Deconv3dTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != INPUT_SIZE2 && in_tensors.size() != INPUT_SIZE3) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); - return RET_ERROR; - } - return RET_OK; -} -int Deconv3dTensorRT::AddInnerOp(TensorRTContext *ctx) { - if (ctx == nullptr || ctx->network() == nullptr) { - MS_LOG(ERROR) << "context or network is invalid"; - return RET_ERROR; - } - auto deconv_op = AsOps(); - if (deconv_op == nullptr) { - MS_LOG(ERROR) << "op action convert failed"; - return RET_ERROR; - } - nvinfer1::ITensor *deconv_input = input(ctx, 0).trt_tensor_; - - // transpose weight - const auto &weight_tensor = in_tensors_[1]; - nvinfer1::Weights kernelWeights = lite::ConvertWeight(weight_tensor); - - // deconv basic params - int nbOutputMaps = weight_tensor.Shape()[0]; - if (nbOutputMaps <= 0) { - MS_LOG(ERROR) << "out_channel is invalid"; - return RET_ERROR; - } - - auto kernel_size = deconv_op->get_kernel_size(); - if (kernel_size.empty()) { - MS_LOG(ERROR) << "kernel_size is null"; - return RET_ERROR; - } - nvinfer1::Dims kernelSize = lite::ConvertCudaDims(std::vector(kernel_size.begin(), kernel_size.end())); - if (kernelSize.nbDims == -1) { - MS_LOG(ERROR) << "ConvertCudaDims failed for " << op_name_; - return RET_ERROR; - } - // bias - nvinfer1::Weights biasWeights{}; - - nvinfer1::IDeconvolutionLayer *deconv_layer = - ctx->network()->addDeconvolutionNd(*deconv_input, nbOutputMaps, kernelSize, kernelWeights, biasWeights); - - if (deconv_layer == nullptr) { - MS_LOG(ERROR) << "Deconv3dLayer failed"; - return RET_ERROR; - } - deconv_layer->setName((op_name_ + "_deconv").c_str()); - // set extra params - SetAttributes(deconv_op, deconv_layer); - - nvinfer1::ITensor *out_tensor = deconv_layer->getOutput(0); - ctx->RegisterTensor(ITensorHelper{out_tensor, Format::NCHW, true}, out_tensors_[0].Name()); - this->layer_ = deconv_layer; - return RET_OK; -} - -void Deconv3dTensorRT::SetAttributes(const std::shared_ptr &ms_op, - nvinfer1::IDeconvolutionLayer *decon_layer) { - // kernel_size - auto kernel_size = ms_op->get_kernel_size(); - if (!kernel_size.empty()) { - auto kernel_size_val = std::vector(kernel_size.begin(), kernel_size.end()); - nvinfer1::Dims kernel_size_dims = lite::ConvertCudaDims(kernel_size_val); - if (kernel_size_dims.nbDims == -1) { - MS_LOG(ERROR) << "ConvertCudaDims failed for " << op_name_; - return; - } - decon_layer->setKernelSizeNd(kernel_size_dims); - } - - // stride - auto stride = ms_op->get_stride(); - if (!stride.empty()) { - auto stride_val = std::vector(stride.begin() + INPUT_SIZE2, stride.end()); - nvinfer1::Dims strides = lite::ConvertCudaDims(stride_val); - if (strides.nbDims == -1) { - MS_LOG(ERROR) << "ConvertCudaDims failed for " << op_name_; - return; - } - decon_layer->setStrideNd(strides); - } - - // nbOutputMaps - int nbOutputMaps = in_tensors_[1].Shape()[1]; - decon_layer->setNbOutputMaps(nbOutputMaps); - - // nbGroups - int32_t nbGroups = static_cast(ms_op->get_group()); - decon_layer->setNbGroups(nbGroups); - - // padding - PadMode pad_mode = ms_op->get_pad_mode(); - if (pad_mode == PadMode::SAME) { - decon_layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER); - } else { - std::vector padding; - if (ms_op->HasAttr(ops::kPadList)) { - auto pad_list_ptr = ms_op->GetAttr(ops::kPadList); - padding = GetValue>(pad_list_ptr); - } - std::vector out_pad; - if (ms_op->HasAttr(ops::kOutputPadding)) { - auto pad_list_ptr = ms_op->GetAttr(ops::kOutputPadding); - out_pad = GetValue>(pad_list_ptr); - } - if (padding.empty() || out_pad.empty()) { - MS_LOG(WARNING) << "on pad value of " << op_name_; - return; - } - auto padding_val = std::vector(padding.begin(), padding.end()); - auto out_pad_val = std::vector(out_pad.begin() + INPUT_SIZE2, out_pad.end()); // h, w - if (out_pad_val.size() != DIMENSION_3D || padding_val.size() != DIMENSION_6D) { - MS_LOG(ERROR) << "invalid size of pad " << op_name_; - return; - } - nvinfer1::Dims dims_pre{}; - dims_pre.nbDims = DIMENSION_3D; - dims_pre.d[0] = padding_val[0]; - dims_pre.d[1] = padding_val[INPUT_SIZE2]; - dims_pre.d[INPUT_SIZE2] = padding_val[INPUT_SIZE4]; - decon_layer->setPrePadding(dims_pre); - nvinfer1::Dims dims_post{}; - dims_post.nbDims = DIMENSION_3D; - dims_post.d[0] = padding_val[1] - out_pad_val[0]; - dims_post.d[1] = padding_val[INPUT_SIZE3] - out_pad_val[1]; - dims_post.d[INPUT_SIZE2] = padding_val[INPUT_SIZE5] - out_pad_val[INPUT_SIZE2]; - decon_layer->setPostPadding(dims_post); - } -} - -Deconv3dTensorRT::~Deconv3dTensorRT() {} -REGISTER_TENSORRT_CREATOR(ops::kNameConv3DTranspose, Deconv3dTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/deconv3d_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/deconv3d_tensorrt.h deleted file mode 100644 index d33a02916f7e01ba0e97e61db4d8d4e62c2f359a..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/deconv3d_tensorrt.h +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_CONV3DTRANSPOSE_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_CONV3DTRANSPOSE_TENSORRT_H_ -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" -#include "infer/conv3d_transpose.h" - -namespace mindspore::lite { -class Deconv3dTensorRT : public TensorRTOp { - public: - Deconv3dTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, const std::string &name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~Deconv3dTensorRT() override; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; - - private: - void SetAttributes(const std::shared_ptr &conv_op, nvinfer1::IDeconvolutionLayer *decon_layer); -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_CONV3DTRANSPOSE_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/deconvolution_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/deconvolution_tensorrt.cc deleted file mode 100644 index 5856f63de313d4e3bf8d9ce81bb184fde9fd4399..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/deconvolution_tensorrt.cc +++ /dev/null @@ -1,192 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/deconvolution_tensorrt.h" -#include -#include "src/extendrt/delegate/tensorrt/op/activation_tensorrt.h" -#include "nnacl/pack.h" -#include "infer/cxx_api/conv2d_transpose_fusion.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" - -namespace mindspore::lite { -int DeconvolutionTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != INPUT_SIZE2 && in_tensors.size() != INPUT_SIZE3) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); - return RET_ERROR; - } - if (in_tensors[0].format() != Format::NHWC && in_tensors[0].format() != Format::NCHW) { - MS_LOG(ERROR) << "Unsupported input tensor format of " << in_tensors[0].format(); - return RET_ERROR; - } - return RET_OK; -} -int DeconvolutionTensorRT::AddInnerOp(TensorRTContext *ctx) { - if (ctx == nullptr || ctx->network() == nullptr) { - MS_LOG(ERROR) << "context or network is invalid"; - return RET_ERROR; - } - auto deconv_op = AsOps(); - if (deconv_op == nullptr) { - MS_LOG(ERROR) << "op action convert failed"; - return RET_ERROR; - } - nvinfer1::ITensor *deconv_input = input(ctx, 0).trt_tensor_; - - // transpose weight - const auto &weight_tensor = in_tensors_[1]; - nvinfer1::Weights kernelWeights = lite::ConvertWeight(weight_tensor); - - // deconv basic params - int nbOutputMaps = weight_tensor.Shape()[0]; - if (nbOutputMaps <= 0) { - MS_LOG(ERROR) << "out_channel is invalid"; - return RET_ERROR; - } - - auto kernel_size = deconv_op->get_kernel_size(); - if (kernel_size.empty()) { - MS_LOG(ERROR) << "kernel_size is null"; - return RET_ERROR; - } - nvinfer1::Dims kernelSize = lite::ConvertCudaDims(std::vector(kernel_size.begin(), kernel_size.end())); - if (kernelSize.nbDims == -1) { - MS_LOG(ERROR) << "ConvertCudaDims failed for " << op_name_; - return RET_ERROR; - } - // bias - nvinfer1::Weights biasWeights{}; - if (in_tensors_.size() >= INPUT_SIZE3) { - biasWeights = lite::ConvertWeight(in_tensors_[INPUT_SIZE3 - 1]); - } else { - biasWeights.type = ConvertDataType(weight_tensor.DataType()); - biasWeights.count = 0; - biasWeights.values = nullptr; - } - - nvinfer1::IDeconvolutionLayer *deconv_layer = - ctx->network()->addDeconvolutionNd(*deconv_input, nbOutputMaps, kernelSize, kernelWeights, biasWeights); - - if (deconv_layer == nullptr) { - MS_LOG(ERROR) << "DeconvolutionLayer failed"; - return RET_ERROR; - } - deconv_layer->setName((op_name_ + "_deconv").c_str()); - this->layer_ = deconv_layer; - // set extra params - SetAttributes(deconv_op, deconv_layer); - - // add activation - nvinfer1::ILayer *activation_layer = nullptr; - ActivationType activation_type = ActivationType::NO_ACTIVATION; - if (deconv_op->HasAttr(ops::kActivationType)) { - activation_type = deconv_op->get_activation_type(); - } - if (activation_type == ActivationType::NO_ACTIVATION) { - activation_layer = deconv_layer; - } else { - activation_layer = ActivationTensorRT::AddActivation(ctx, activation_type, 0, 0, 0, deconv_layer->getOutput(0), - op_name_, device_id_); - if (activation_layer == nullptr) { - MS_LOG(ERROR) << "addActivation for conv failed"; - return RET_ERROR; - } - activation_layer->setName((op_name_ + "_activation").c_str()); - } - nvinfer1::ITensor *out_tensor = activation_layer->getOutput(0); - ctx->RegisterTensor(ITensorHelper{out_tensor, Format::NCHW, true}, out_tensors_[0].Name()); - return RET_OK; -} - -void DeconvolutionTensorRT::SetAttributes(const std::shared_ptr &ms_op, - nvinfer1::IDeconvolutionLayer *decon_layer) { - // kernel_size - auto kernel_size = ms_op->get_kernel_size(); - if (!kernel_size.empty()) { - auto kernel_size_val = std::vector(kernel_size.begin(), kernel_size.end()); - nvinfer1::Dims kernel_size_dims = lite::ConvertCudaDims(kernel_size_val); - if (kernel_size_dims.nbDims == -1) { - MS_LOG(ERROR) << "ConvertCudaDims failed for " << op_name_; - return; - } - decon_layer->setKernelSizeNd(kernel_size_dims); - } - - // nbOutputMaps - int nbOutputMaps = in_tensors_[1].Shape()[1]; - decon_layer->setNbOutputMaps(nbOutputMaps); - - // stride - auto stride = ms_op->get_stride(); - if (!stride.empty()) { - auto stride_val = std::vector(stride.begin(), stride.end()); - nvinfer1::Dims stride_dims = lite::ConvertCudaDims(stride_val); - if (stride_dims.nbDims == -1) { - MS_LOG(ERROR) << "ConvertCudaDims failed for " << op_name_; - return; - } - decon_layer->setStrideNd(stride_dims); - } - - // nbGroups - int32_t nbGroups = static_cast(ms_op->get_group()); - decon_layer->setNbGroups(nbGroups); - - // padding - PadMode pad_mode = ms_op->get_pad_mode(); - if (pad_mode == PadMode::SAME) { - decon_layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER); - } else { - auto padding = ms_op->get_pad_list(); - std::vector out_pad(DIMENSION_2D, 0); - if (ms_op->HasAttr(ops::kOutputPaddings)) { - out_pad = ms_op->get_output_paddings(); - } - if (padding.empty() || out_pad.empty()) { - MS_LOG(WARNING) << "on pad value of " << op_name_; - return; - } - auto padding_val = std::vector(padding.begin(), padding.end()); - auto out_pad_val = std::vector(out_pad.begin(), out_pad.end()); // h, w - if (out_pad_val.size() != DIMENSION_2D || padding_val.size() != DIMENSION_4D) { - MS_LOG(ERROR) << "invalid size of pad " << op_name_; - return; - } - nvinfer1::Dims dims_pre{}; - dims_pre.nbDims = DIMENSION_2D; - dims_pre.d[0] = padding_val[0]; // up - dims_pre.d[1] = padding_val[2]; // left - decon_layer->setPrePadding(dims_pre); - nvinfer1::Dims dims_post{}; - dims_post.nbDims = DIMENSION_2D; - dims_post.d[0] = padding_val[1] - out_pad_val[0]; // down - dims_post.d[1] = padding_val[3] - out_pad_val[1]; // right - decon_layer->setPostPadding(dims_post); - } -} - -DeconvolutionTensorRT::~DeconvolutionTensorRT() { - if (pack_weight_ != nullptr) { - free(pack_weight_); - pack_weight_ = nullptr; - } -} -REGISTER_TENSORRT_CREATOR(ops::kNameConv2dTransposeFusion, DeconvolutionTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/depthtospace_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/depthtospace_tensorrt.cc deleted file mode 100644 index aaa3bf7cc915cbc715aa8b3c05761b33ed6432af..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/depthtospace_tensorrt.cc +++ /dev/null @@ -1,131 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/depthtospace_tensorrt.h" -#include -#include -#include -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "NvInferRuntimeCommon.h" -#include "kernel/gpu/cuda_impl/cuda_ops/depthtospace_impl.cuh" -#include "infer/depth_to_space.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_d.h" - -namespace mindspore::lite { -int DepthToSpaceTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - - if (out_tensors.size() < 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); - return RET_ERROR; - } - return RET_OK; -} - -int DepthToSpaceTensorRT::AddInnerOp(TensorRTContext *ctx) { - nvinfer1::ITensor *input_tensor = input(ctx, 0).trt_tensor_; - auto op = AsOps(); - int block_size = op->get_block_size(); - - auto plugin = std::make_shared(input_tensor->getName(), block_size, device_id_); - if (plugin == nullptr) { - MS_LOG(ERROR) << "add depthtospace plugin failed for" << op_name_; - return RET_ERROR; - } - nvinfer1::ITensor *inputTensors[] = {input_tensor}; - nvinfer1::IPluginV2Layer *layer = ctx->network()->addPluginV2(inputTensors, 1, *plugin); - if (layer == nullptr) { - MS_LOG(ERROR) << "add depthtospace op failed for TensorRT."; - return RET_ERROR; - } - layer->setName(op_name_.c_str()); - nvinfer1::ITensor *out_tensor = layer->getOutput(0); - ctx->RegisterTensor(ITensorHelper{out_tensor, Format::NCHW, true}, out_tensors_[0].Name()); - this->layer_ = layer; - return RET_OK; -} - -REGISTER_TENSORRT_PLUGIN(DepthToSpacePluginCreater); -template class TensorRTPluginCreater; -template -nvinfer1::PluginFieldCollection TensorRTPluginCreater::field_collection_{}; -template -std::vector TensorRTPluginCreater::fields_; - -int DepthToSpacePlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workspace, cudaStream_t stream) noexcept { - return RunCudaDepthToSpace(inputDesc, inputs, outputs, stream); -} - -int DepthToSpacePlugin::RunCudaDepthToSpace(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, - void *const *outputs, cudaStream_t stream) { - nvinfer1::Dims input_dims = inputDesc[0].dims; - int in = input_dims.d[0]; - int ic = input_dims.d[1]; - int ih = input_dims.d[2]; - int iw = input_dims.d[3]; - int on = in; - int oc = ic / block_size_ / block_size_; - int oh = ih * block_size_; - int ow = iw * block_size_; - - int size = on * oc * oh * ow; - - CalDepthToSpace(size, static_cast(inputs[0]), in, ic, ih, iw, on, oc, oh, ow, block_size_, - static_cast(outputs[0]), device_id_, stream); - return RET_OK; -} - -nvinfer1::IPluginV2DynamicExt *DepthToSpacePlugin::clone() const noexcept { - auto *plugin = new (std::nothrow) DepthToSpacePlugin(*this); - if (plugin == nullptr) { - MS_LOG(ERROR) << "new plugin failed!"; - return nullptr; - } - plugin->setPluginNamespace(name_space_.c_str()); - return plugin; -} - -size_t DepthToSpacePlugin::getSerializationSize() const noexcept { return sizeof(int); } - -nvinfer1::DimsExprs DepthToSpacePlugin::getOutputDimensions(int index, const nvinfer1::DimsExprs *inputs, - int nbInputDims, - nvinfer1::IExprBuilder &exprBuilder) noexcept { - nvinfer1::DimsExprs dims; - dims.nbDims = inputs[0].nbDims; - dims.d[0] = inputs[0].d[0]; - dims.d[1] = inputs[0].d[1]; - auto block_size_sqrt = exprBuilder.constant(block_size_ * block_size_); - dims.d[1] = exprBuilder.operation(nvinfer1::DimensionOperation::kFLOOR_DIV, *inputs[0].d[1], *block_size_sqrt); - auto block_size = exprBuilder.constant(block_size_); - dims.d[INPUT_SIZE2] = - exprBuilder.operation(nvinfer1::DimensionOperation::kPROD, *inputs[0].d[INPUT_SIZE2], *block_size); - dims.d[INPUT_SIZE3] = - exprBuilder.operation(nvinfer1::DimensionOperation::kPROD, *inputs[0].d[INPUT_SIZE3], *block_size); - return dims; -} - -void DepthToSpacePlugin::serialize(void *buffer) const noexcept { SerializeValue(&buffer, &block_size_, sizeof(int)); } -REGISTER_TENSORRT_CREATOR(ops::kNameDepthToSpace, DepthToSpaceTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/depthtospace_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/depthtospace_tensorrt.h deleted file mode 100644 index e483a9ca477851ccb8cd1206bbd8e5be5587d85c..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/depthtospace_tensorrt.h +++ /dev/null @@ -1,79 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_DEPTHTOSPACETENSORRT_PLUGIN_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_DEPTHTOSPACETENSORRT_PLUGIN_H_ - -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" -#include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h" - -namespace mindspore::lite { -class DepthToSpaceTensorRT : public TensorRTOp { - public: - DepthToSpaceTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~DepthToSpaceTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; -}; - -constexpr auto DEPTHTOSPACETENSORRT_PLUGIN_NAME{"DepthToSpacePlugin"}; -class DepthToSpacePlugin : public TensorRTPlugin { - public: - DepthToSpacePlugin(const std::string name, int block_size, uint32_t device_id) - : TensorRTPlugin(name, std::string(DEPTHTOSPACETENSORRT_PLUGIN_NAME), device_id), block_size_(block_size) {} - - DepthToSpacePlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - : TensorRTPlugin(std::string(name), std::string(DEPTHTOSPACETENSORRT_PLUGIN_NAME)) { - const nvinfer1::PluginField *fields = fc->fields; - block_size_ = static_cast(fields[0].data)[0]; - } - - DepthToSpacePlugin(const char *name, const void *serialData, size_t serialLength) - : TensorRTPlugin(std::string(name), std::string(DEPTHTOSPACETENSORRT_PLUGIN_NAME)) { - DeserializeValue(&serialData, &serialLength, &block_size_, sizeof(int)); - } - - DepthToSpacePlugin() = delete; - - nvinfer1::IPluginV2DynamicExt *clone() const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override; - size_t getSerializationSize() const noexcept override; - void serialize(void *buffer) const noexcept override; - nvinfer1::DimsExprs getOutputDimensions(int index, const nvinfer1::DimsExprs *inputs, int nbInputDims, - nvinfer1::IExprBuilder &exprBuilder) noexcept override; - - private: - int RunCudaDepthToSpace(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, void *const *outputs, - cudaStream_t stream); - int block_size_; - const std::string layer_name_; - std::string name_space_; -}; -class DepthToSpacePluginCreater : public TensorRTPluginCreater { - public: - DepthToSpacePluginCreater() : TensorRTPluginCreater(std::string(DEPTHTOSPACETENSORRT_PLUGIN_NAME)) {} -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_DEPTHTOSPACETENSORRT_PLUGIN_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/elementwise_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/elementwise_tensorrt.cc deleted file mode 100644 index a475281ce90c309d8e522818e2bde9fcf18ae1d5..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/elementwise_tensorrt.cc +++ /dev/null @@ -1,378 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/elementwise_tensorrt.h" -#include -#include -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "src/extendrt/delegate/tensorrt/op/activation_tensorrt.h" -#include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "infer/cxx_api/sub_fusion.h" -#include "infer/cxx_api/div_fusion.h" -#include "infer/cxx_api/reduce_fusion.h" -#include "infer/cxx_api/pow_fusion.h" -#include "infer/cxx_api/add_fusion.h" -#include "infer/cxx_api/mul_fusion.h" -#include "infer/eltwise.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_b.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_d.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_e.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_f.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_g.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_n.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_p.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" - -namespace mindspore::lite { -namespace { -std::unordered_map NOT_BOOL_PRIM2NV_ELEM_OP = { -#if TRT_VERSION_GE(7, 2) - {ops::kNameLess, nvinfer1::ElementWiseOperation::kLESS}, - {ops::kNameGreater, nvinfer1::ElementWiseOperation::kGREATER}, -#endif - {ops::kNameAddFusion, nvinfer1::ElementWiseOperation::kSUM}, - {ops::kNamePowFusion, nvinfer1::ElementWiseOperation::kPOW}, - {ops::kNameDivFusion, nvinfer1::ElementWiseOperation::kDIV}, - {ops::kNameRealDiv, nvinfer1::ElementWiseOperation::kDIV}, - {ops::kNameFloorDiv, nvinfer1::ElementWiseOperation::kFLOOR_DIV}, - {ops::kNameSubFusion, nvinfer1::ElementWiseOperation::kSUB}, - {ops::kNameMulFusion, nvinfer1::ElementWiseOperation::kPROD}, - {ops::kNameMinimum, nvinfer1::ElementWiseOperation::kMIN}, - {ops::kNameMaximum, nvinfer1::ElementWiseOperation::kMAX}, - {ops::kNameBiasAdd, nvinfer1::ElementWiseOperation::kSUM}, -#if TRT_VERSION_GE(7, 2) - {ops::kNameEqual, nvinfer1::ElementWiseOperation::kEQUAL}, - {ops::kNameNotEqual, nvinfer1::ElementWiseOperation::kEQUAL}, -#endif -}; -} // namespace - -int ElementWiseTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != INPUT_SIZE2) { - MS_LOG(ERROR) << "invalid input tensort size: " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "invalid output tensort size: " << out_tensors.size(); - return RET_ERROR; - } - - bool is_not_bool_arith = NOT_BOOL_PRIM2NV_ELEM_OP.find(type_) != NOT_BOOL_PRIM2NV_ELEM_OP.end(); - if (is_not_bool_arith) { - if (std::any_of(in_tensors.begin(), in_tensors.end(), - [](const TensorInfo &tensor) { return tensor.DataType() == DataType::kNumberTypeBool; })) { - MS_LOG(ERROR) << "invalid input type for : " << op_name_; - return RET_ERROR; - } - element_wise_op_ = NOT_BOOL_PRIM2NV_ELEM_OP[type_]; - } - if (!is_not_bool_arith) { - // PrimitiveType_Eltwise - auto eltwise_op = AsOps(); - if (eltwise_op == nullptr) { - MS_LOG(ERROR) << "convert to Eltwise failed: " << op_name_; - return RET_ERROR; - } - EltwiseMode eltwiseMode = eltwise_op->get_mode(); - std::map eltwise_modes = { - {EltwiseMode::SUM, nvinfer1::ElementWiseOperation::kSUM}, - {EltwiseMode::PROD, nvinfer1::ElementWiseOperation::kPROD}, - {EltwiseMode::MAXIMUM, nvinfer1::ElementWiseOperation::kMAX}, - }; - auto iter_mode = eltwise_modes.find(eltwiseMode); - if (iter_mode != eltwise_modes.end()) { - element_wise_op_ = iter_mode->second; - } else { - MS_LOG(ERROR) << "unsupported type for ElementWise op" << op_name_; - return RET_ERROR; - } - } - return RET_OK; -} - -void ElementWiseTensorRT::LogicalOpChangeInputType(TensorRTContext *ctx, ITensorHelper *x_input, - ITensorHelper *y_input) { - if (type_ == ops::kNameGreater || type_ == ops::kNameLess) { - if (x_input->trt_tensor_->getType() != nvinfer1::DataType::kINT32) { - x_input->trt_tensor_ = - TRTTensorCast(ctx, x_input->trt_tensor_, nvinfer1::DataType::kINT32, op_name_ + "_input_cast_to_int_0"); - } - if (y_input->trt_tensor_->getType() != nvinfer1::DataType::kINT32) { - y_input->trt_tensor_ = - TRTTensorCast(ctx, y_input->trt_tensor_, nvinfer1::DataType::kINT32, op_name_ + "_input_cast_to_int_1"); - } - } -} - -int ElementWiseTensorRT::AddInnerOp(TensorRTContext *ctx) { - ITensorHelper x_input; - ITensorHelper y_input; - int ret = PreprocessInputTensors(ctx, &x_input, &y_input); - if (ret != RET_OK) { - MS_LOG(ERROR) << "PreprocessInputTensors failed."; - return RET_ERROR; - } - nvinfer1::IElementWiseLayer *cal_layer; - if (type_ == ops::kNameFloorMod) { - cal_layer = AddFoorMod(ctx, x_input.trt_tensor_, y_input.trt_tensor_); - } else { - cal_layer = ctx->network()->addElementWise(*x_input.trt_tensor_, *y_input.trt_tensor_, element_wise_op_); - } - - if (cal_layer == nullptr) { - MS_LOG(ERROR) << "addElementWise failed for TensorRT."; - return RET_ERROR; - } - cal_layer->setName(op_name_.c_str()); - this->layer_ = cal_layer; - ctx->RegisterLayer(cal_layer, op_name_); - - nvinfer1::ITensor *op_out_tensor = cal_layer->getOutput(0); - if (op_out_tensor == nullptr) { - MS_LOG(ERROR) << "addElementWise out tensor is nullptr."; - return RET_ERROR; - } - // add activation - nvinfer1::ITensor *activation_out_tensor = AddActivation(ctx, op_out_tensor); - op_out_tensor = (activation_out_tensor == nullptr) ? op_out_tensor : activation_out_tensor; - - // scale and shift - if (type_ == ops::kNamePowFusion) { - auto pow_op = AsOps(); - if (pow_op == nullptr) { - MS_LOG(ERROR) << "PowFusion convert failed."; - return RET_ERROR; - } - float scale = pow_op->get_scale(); - float shift = pow_op->get_shift(); - if (abs(scale - 1) >= 1.0e-05 || abs(shift - 0) >= 1.0e-05) { - MS_LOG(WARNING) << "deal with scale and shift for pow op"; - } - } -#if TRT_VERSION_GE(7, 2) - if (type_ == ops::kNameNotEqual) { - op_out_tensor = ctx->network()->addUnary(*op_out_tensor, nvinfer1::UnaryOperation::kNOT)->getOutput(0); - } - std::unordered_set bool_producer_ops = {ops::kNameNotEqual, ops::kNameEqual, ops::kNameGreater, - ops::kNameLess}; - if (bool_producer_ops.find(type_) != bool_producer_ops.end()) { - auto cast_layer = ctx->network()->addIdentity(*op_out_tensor); - if (cast_layer == nullptr) { - MS_LOG(ERROR) << "create cast layer failed for: " << op_name_; - return RET_ERROR; - } - cast_layer->setOutputType(0, nvinfer1::DataType::kINT32); - op_out_tensor = cast_layer->getOutput(0); - MS_LOG(INFO) << "bool result cast to int32" << op_name_; - } -#endif - auto is_tensor = x_input.is_tensor || y_input.is_tensor; - auto output_helper = ITensorHelper{op_out_tensor, x_input.format_, x_input.same_format_, is_tensor}; - ctx->RegisterTensor(output_helper, out_tensors_[0].Name()); - MS_LOG(DEBUG) << "output " << GetTensorFormat(output_helper); - return RET_OK; -} - -int ElementWiseTensorRT::PreprocessInputTensors(TensorRTContext *ctx, ITensorHelper *x_input, ITensorHelper *y_input) { - if (HasConst()) { - int ret = AddConstTensor(ctx); - if (ret != RET_OK) { - return ret; - } - } - *x_input = input(ctx, 0); - *y_input = input(ctx, 1); - if (in_tensors_[0].DataType() != in_tensors_[1].DataType()) { - MS_LOG(INFO) << "trt op elementwise layer not support different input data type, cast to higher one"; - auto higher_index = in_tensors_[0].DataType() > in_tensors_[1].DataType() ? 0 : 1; - auto highter_trt_tensor = input(ctx, higher_index).trt_tensor_; - auto cast_layer = ctx->network()->addIdentity(*input(ctx, 1 - higher_index).trt_tensor_); - CHECK_NULL_RETURN(cast_layer); - cast_layer->setOutputType(0, highter_trt_tensor->getType()); - auto cast_output = cast_layer->getOutput(0); - CHECK_NULL_RETURN(cast_output); - ctx->RegisterTensor( - ITensorHelper{cast_output, input(ctx, higher_index).format_, input(ctx, higher_index).same_format_}, - out_tensors_[0].Name()); - cast_layer->setName((op_name_ + "_cast").c_str()); - if (higher_index != 0) { - x_input->trt_tensor_ = cast_output; - } else { - y_input->trt_tensor_ = cast_output; - } - } - - MS_LOG(DEBUG) << "after transpose " << GetTensorFormat(*x_input); - MS_LOG(DEBUG) << "after transpose " << GetTensorFormat(*y_input); - if (BroadcastInputTensors(ctx, x_input, y_input) != RET_OK) { - return RET_ERROR; - } - - while (x_input->trt_tensor_->getDimensions().nbDims < y_input->trt_tensor_->getDimensions().nbDims) { - x_input->trt_tensor_ = ExpandDim(ctx, x_input->trt_tensor_, 0); - } - while (x_input->trt_tensor_->getDimensions().nbDims > y_input->trt_tensor_->getDimensions().nbDims) { - y_input->trt_tensor_ = ExpandDim(ctx, y_input->trt_tensor_, 0); - } - return RET_OK; -} - -int ElementWiseTensorRT::BroadcastInputTensors(TensorRTContext *ctx, ITensorHelper *x_input, ITensorHelper *y_input) { - if (GetDimsVolume(x_input->trt_tensor_->getDimensions()) == GetDimsVolume(y_input->trt_tensor_->getDimensions()) && - x_input->trt_tensor_->getDimensions().nbDims != y_input->trt_tensor_->getDimensions().nbDims) { - bool x_large = x_input->trt_tensor_->getDimensions().nbDims > y_input->trt_tensor_->getDimensions().nbDims; - auto input_tensor = x_large ? y_input : x_input; - auto output_dim = x_large ? x_input->trt_tensor_->getDimensions() : y_input->trt_tensor_->getDimensions(); - auto reshape_layer = ctx->network()->addShuffle(*input_tensor->trt_tensor_); - if (reshape_layer == nullptr) { - MS_LOG(ERROR) << "add reshape failed for " << op_name_; - return RET_ERROR; - } - reshape_layer->setReshapeDimensions(output_dim); - input_tensor->trt_tensor_ = reshape_layer->getOutput(0); - return RET_OK; - } else if (GetDimsVolume(x_input->trt_tensor_->getDimensions()) != - GetDimsVolume(y_input->trt_tensor_->getDimensions()) && - x_input->trt_tensor_->getDimensions().nbDims != y_input->trt_tensor_->getDimensions().nbDims) { - bool x_large = x_input->trt_tensor_->getDimensions().nbDims > y_input->trt_tensor_->getDimensions().nbDims; - auto input_tensor = x_large ? y_input : x_input; - auto output_dim = x_large ? x_input->trt_tensor_->getDimensions() : y_input->trt_tensor_->getDimensions(); - nvinfer1::Dims in_tensor_dims = input_tensor->trt_tensor_->getDimensions(); - while (in_tensor_dims.nbDims < output_dim.nbDims) { - input_tensor->trt_tensor_ = ExpandDim(ctx, input_tensor->trt_tensor_, 0); - in_tensor_dims = input_tensor->trt_tensor_->getDimensions(); - } - return RET_OK; - } else { - return RET_OK; - } -} - -nvinfer1::IElementWiseLayer *ElementWiseTensorRT::AddFoorMod(TensorRTContext *ctx, nvinfer1::ITensor *x0_trt, - nvinfer1::ITensor *x1_trt) { - nvinfer1::IElementWiseLayer *layer_0 = - ctx->network()->addElementWise(*x0_trt, *x1_trt, nvinfer1::ElementWiseOperation::kFLOOR_DIV); - layer_0->setName((op_name_ + "_floor_div").c_str()); - auto result_0 = layer_0->getOutput(0); - - nvinfer1::IElementWiseLayer *layer_1 = - ctx->network()->addElementWise(*result_0, *x1_trt, nvinfer1::ElementWiseOperation::kPROD); - layer_1->setName((op_name_ + "_prod").c_str()); - auto result_1 = layer_1->getOutput(0); - - nvinfer1::IElementWiseLayer *layer_2 = - ctx->network()->addElementWise(*x0_trt, *result_1, nvinfer1::ElementWiseOperation::kSUB); - layer_2->setName((op_name_ + "_sub").c_str()); - - return layer_2; -} - -nvinfer1::ITensor *ElementWiseTensorRT::AddActivation(TensorRTContext *ctx, nvinfer1::ITensor *in_tensor) { - ActivationType activation = ActivationType::NO_ACTIVATION; - if (type_ == ops::kNameAddFusion) { - auto sum_op = AsOps(); - if (sum_op == nullptr) { - MS_LOG(ERROR) << "AddFusion convert failed."; - return nullptr; - } - if (sum_op->HasAttr(ops::kActivationType)) { - activation = sum_op->get_activation_type(); - } - } else if (type_ == ops::kNameDivFusion) { - auto div_op = AsOps(); - if (div_op == nullptr) { - MS_LOG(ERROR) << "DivFusion convert failed."; - return nullptr; - } - if (div_op->HasAttr(ops::kActivationType)) { - activation = div_op->get_activation_type(); - } - } else if (type_ == ops::kNameSubFusion) { - auto sub_op = AsOps(); - if (sub_op == nullptr) { - MS_LOG(ERROR) << "SubFusion convert failed."; - return nullptr; - } - if (sub_op->HasAttr(ops::kActivationType)) { - activation = sub_op->get_activation_type(); - } - } else if (type_ == ops::kNameMulFusion) { - auto mul_op = AsOps(); - if (mul_op == nullptr) { - MS_LOG(ERROR) << "MulFusion convert failed."; - return nullptr; - } - if (mul_op->HasAttr(ops::kActivationType)) { - activation = mul_op->get_activation_type(); - } - } else { - MS_LOG(DEBUG) << "no activation need for: " << op_name_; - } - nvinfer1::ITensor *activation_out_tensor = nullptr; - if (activation != ActivationType::NO_ACTIVATION) { - auto activation_layer = - ActivationTensorRT::AddActivation(ctx, activation, 0, 0, 0, in_tensor, op_name_, device_id_); - if (activation_layer == nullptr) { - MS_LOG(ERROR) << "addActivation for element wise failed"; - return nullptr; - } - activation_layer->setName((op_name_ + "_activation").c_str()); - activation_out_tensor = activation_layer->getOutput(0); - } - return activation_out_tensor; -} - -int ElementWiseTensorRT::AddConstTensor(TensorRTContext *ctx) { - int const_tensor_index = in_tensors_[0].IsConst() ? 0 : 1; - if (in_tensors_[0].IsConst() && in_tensors_[1].IsConst()) { - auto large_size_index = in_tensors_[0].ElementNum() >= in_tensors_[1].ElementNum() ? 0 : 1; - const_tensor_index = 1 - large_size_index; - } - auto expect_shape = ConvertMSShape(input(ctx, 1 - const_tensor_index).trt_tensor_->getDimensions()); - auto &const_tensor = in_tensors_[const_tensor_index]; - nvinfer1::ITensor *constant_input = ConvertConstantTensorWithDims(ctx, const_tensor, expect_shape, op_name_); - CHECK_NULL_RETURN(constant_input); - auto const_shape = const_tensor.Shape(); - auto is_scalar = const_shape.empty(); - auto const_helper = ITensorHelper{constant_input, input(ctx, 1 - const_tensor_index).format_, true, !is_scalar}; - ctx->RegisterTensor(const_helper, const_tensor.Name()); - return RET_OK; -} - -REGISTER_TENSORRT_CREATOR(ops::kNameSubFusion, ElementWiseTensorRT) -REGISTER_TENSORRT_CREATOR(ops::kNameDivFusion, ElementWiseTensorRT) -REGISTER_TENSORRT_CREATOR(ops::kNameRealDiv, ElementWiseTensorRT) -REGISTER_TENSORRT_CREATOR(ops::kNamePowFusion, ElementWiseTensorRT) -REGISTER_TENSORRT_CREATOR(ops::kNameAddFusion, ElementWiseTensorRT) -REGISTER_TENSORRT_CREATOR(ops::kNameMulFusion, ElementWiseTensorRT) -REGISTER_TENSORRT_CREATOR(ops::kNameEltwise, ElementWiseTensorRT) -REGISTER_TENSORRT_CREATOR(ops::kNameMinimum, ElementWiseTensorRT) -REGISTER_TENSORRT_CREATOR(ops::kNameMaximum, ElementWiseTensorRT) -REGISTER_TENSORRT_CREATOR(ops::kNameBiasAdd, ElementWiseTensorRT) -REGISTER_TENSORRT_CREATOR(ops::kNameFloorMod, ElementWiseTensorRT) -REGISTER_TENSORRT_CREATOR(ops::kNameFloorDiv, ElementWiseTensorRT) -#if TRT_VERSION_GE(7, 2) -REGISTER_TENSORRT_CREATOR(ops::kNameNotEqual, ElementWiseTensorRT) -REGISTER_TENSORRT_CREATOR(ops::kNameEqual, ElementWiseTensorRT) -REGISTER_TENSORRT_CREATOR(ops::kNameLess, ElementWiseTensorRT) -REGISTER_TENSORRT_CREATOR(ops::kNameGreater, ElementWiseTensorRT) -#endif -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/elementwise_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/elementwise_tensorrt.h deleted file mode 100644 index 95e99d2d4cfa49789d1726a24ee309c113f6a113..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/elementwise_tensorrt.h +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2020-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_ELEMENTWISE_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_ELEMENTWISE_TENSORRT_H_ -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" - -namespace mindspore::lite { -class ElementWiseTensorRT : public TensorRTOp { - public: - ElementWiseTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~ElementWiseTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; - - private: - nvinfer1::ITensor *AddActivation(TensorRTContext *ctx, nvinfer1::ITensor *in_tensor); - - nvinfer1::IElementWiseLayer *AddFoorMod(TensorRTContext *ctx, nvinfer1::ITensor *x0_trt, nvinfer1::ITensor *x1_trt); - void LogicalOpChangeInputType(TensorRTContext *ctx, ITensorHelper *x_input, ITensorHelper *y_input); - - int AddConstTensor(TensorRTContext *ctx); - - int BroadcastInputTensors(TensorRTContext *ctx, ITensorHelper *x_input, ITensorHelper *y_input); - - bool SameTensor(nvinfer1::ITensor *trt_tensor, TensorInfo *ms_tensor); - - int PreprocessInputTensors(TensorRTContext *ctx, ITensorHelper *x_input, ITensorHelper *y_input); - - nvinfer1::ElementWiseOperation element_wise_op_; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_ELEMENTWISE_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/equal_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/equal_tensorrt.cc deleted file mode 100644 index 16568daa8d265cbbd571bb06447d81c34c42791d..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/equal_tensorrt.cc +++ /dev/null @@ -1,97 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/equal_tensorrt.h" -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "NvInferRuntimeCommon.h" -#include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_e.h" - -namespace mindspore::lite { -REGISTER_TENSORRT_PLUGIN(EqualPluginCreater); -template class TensorRTPluginCreater; -template -nvinfer1::PluginFieldCollection TensorRTPluginCreater::field_collection_{}; -template -std::vector TensorRTPluginCreater::fields_; - -int EqualTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (!IsShapeKnown()) { - MS_LOG(ERROR) << "Unsupported input tensor unknown shape: " << op_name_; - return RET_ERROR; - } - if (in_tensors.size() != INPUT_SIZE2) { - MS_LOG(ERROR) << "invalid input tensor size: " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "invalid output tensor size: " << out_tensors.size(); - return RET_ERROR; - } - return RET_OK; -} - -int EqualTensorRT::AddInnerOp(TensorRTContext *ctx) { - nvinfer1::ITensor *inputTensors[] = {input(ctx, 0).trt_tensor_, input(ctx, 1).trt_tensor_}; - auto plugin = std::make_shared(op_name_, device_id_); - nvinfer1::IPluginV2Layer *equal_layer = ctx->network()->addPluginV2(inputTensors, INPUT_SIZE2, *plugin); - if (equal_layer == nullptr) { - MS_LOG(ERROR) << "create equal layer failed for: " << op_name_; - return RET_ERROR; - } - layer_ = equal_layer; - nvinfer1::ITensor *equal_out = equal_layer->getOutput(0); - equal_layer->setName(op_name_.c_str()); - ctx->RegisterTensor(ITensorHelper{equal_out, input(ctx, 0).format_, input(ctx, 0).same_format_}, - out_tensors_[0].Name()); - return RET_OK; -} - -int EqualPlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, - cudaStream_t stream) noexcept { - nvinfer1::Dims input_dims = inputDesc[0].dims; - int element_cnt = std::accumulate(input_dims.d, input_dims.d + input_dims.nbDims, 1, std::multiplies()); - - if (inputDesc->type == nvinfer1::DataType::kINT32) { - const int *input1 = static_cast(inputs[0]); - const int *input2 = static_cast(inputs[1]); - int *output = static_cast(outputs[0]); - Equal(input1, input2, output, element_cnt, stream); - } else if (inputDesc->type == nvinfer1::DataType::kFLOAT) { - const float *input1 = static_cast(inputs[0]); - const float *input2 = static_cast(inputs[1]); - float *output = static_cast(outputs[0]); - Equal(input1, input2, output, element_cnt, stream); - } else { - MS_LOG(ERROR) << "unsupported equal data type"; - } - return RET_OK; -} - -nvinfer1::IPluginV2DynamicExt *EqualPlugin::clone() const noexcept { - auto *plugin = new EqualPlugin(*this); - plugin->setPluginNamespace(name_space_.c_str()); - return plugin; -} -#if TRT_VERSION_LS(7, 2) -REGISTER_TENSORRT_CREATOR(ops::kNameEqual, EqualTensorRT) -#endif -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/equal_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/equal_tensorrt.h deleted file mode 100644 index 46736aec08fa2dd5aac8af3cca887ed8fbc00699..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/equal_tensorrt.h +++ /dev/null @@ -1,62 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_EQUAL_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_EQUAL_TENSORRT_H_ -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" -#include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h" -#include "src/extendrt/delegate/tensorrt/cuda_impl/equal.cuh" - -namespace mindspore::lite { -constexpr auto EQUAL_PLUGIN_NAME{"EqualPlugin"}; -class EqualTensorRT : public TensorRTOp { - public: - EqualTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~EqualTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; -}; - -class EqualPlugin : public TensorRTPlugin { - public: - EqualPlugin(const std::string name, uint32_t device_id) - : TensorRTPlugin(name, std::string(EQUAL_PLUGIN_NAME), device_id) {} - - EqualPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - : TensorRTPlugin(std::string(name), std::string(EQUAL_PLUGIN_NAME)) {} - - EqualPlugin(const char *name, const void *serialData, size_t serialLength) - : TensorRTPlugin(std::string(name), std::string(EQUAL_PLUGIN_NAME)) {} - - EqualPlugin() = delete; - - nvinfer1::IPluginV2DynamicExt *clone() const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override; -}; -class EqualPluginCreater : public TensorRTPluginCreater { - public: - EqualPluginCreater() : TensorRTPluginCreater(std::string(EQUAL_PLUGIN_NAME)) {} -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_EQUAL_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/fastgelu_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/fastgelu_tensorrt.cc deleted file mode 100644 index 8f5ac6cf86444057f376939b859ba6ca17c0f046..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/fastgelu_tensorrt.cc +++ /dev/null @@ -1,132 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include "src/extendrt/delegate/tensorrt/op/fastgelu_tensorrt.h" -#include "src/extendrt/delegate/tensorrt/op/cast_tensorrt.h" -#include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_f.h" - -namespace mindspore::lite { -// FastGelu is defined as "0.5x * (1 + tanh[sqrt(2.0/Pi) * (x + 0.044715 * x^3)])" in paper "Gaussian Error Linear Units -// (GELUs)" by Dan Hendrycks, Kevin Gimpel, 2016. -constexpr float FASTGELU_PARAM1 = 3.f; -constexpr float FASTGELU_PARAM2 = 0.044715f; -constexpr float FASTGELU_PARAM3 = 0.7978845608f; // sqrt(2.0/Pi) -constexpr float FASTGELU_PARAM4 = 1.f; -constexpr float FASTGELU_PARAM5 = 0.5f; - -int FastGeluTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - return RET_OK; -} - -int FastGeluTensorRT::AddInnerOp(TensorRTContext *ctx) { - if (ctx == nullptr || ctx->network() == nullptr) { - MS_LOG(ERROR) << "context or network is invalid"; - return RET_ERROR; - } - int input_nbdims = input(ctx, 0).trt_tensor_->getDimensions().nbDims; - if (input_nbdims < 0) { - MS_LOG(ERROR) << "input dims should not be less than 0 for " << op_name_; - return RET_ERROR; - } - int ret = RunAsTrtOps(ctx); - if (ret != RET_OK) { - MS_LOG(ERROR) << "add layer failed for " << op_name_; - return ret; - } - return ret; -} - -int FastGeluTensorRT::RunAsTrtOps(TensorRTContext *ctx) { - if (ctx == nullptr || ctx->network() == nullptr) { - MS_LOG(ERROR) << "context or network is invalid for " << op_name_; - return RET_ERROR; - } - auto trt_in_tensor = input(ctx, 0).trt_tensor_; - if (trt_in_tensor->getDimensions().nbDims <= 0) { - MS_LOG(ERROR) << "Invalid input dims count " << trt_in_tensor->getDimensions().nbDims << ", " << op_name_; - return RET_ERROR; - } - auto expand_dims = [](TensorRTContext *ctx, nvinfer1::ITensor *tensor, int nbdims) { - while (tensor->getDimensions().nbDims != nbdims) { - tensor = ExpandDim(ctx, tensor, 0); - } - return tensor; - }; - int nbdims = trt_in_tensor->getDimensions().nbDims; - auto const_three = expand_dims(ctx, ctx->ConvertTo1DTensor(FASTGELU_PARAM1), nbdims); - CHECK_NULL_RETURN(const_three); - auto p3 = - ctx->network()->addElementWise(*trt_in_tensor, *const_three, nvinfer1::ElementWiseOperation::kPOW)->getOutput(0); - CHECK_NULL_RETURN(p3); - auto gelu_p1 = expand_dims(ctx, ctx->ConvertTo1DTensor(FASTGELU_PARAM2), nbdims); - CHECK_NULL_RETURN(gelu_p1); - auto prod1 = ctx->network()->addElementWise(*p3, *gelu_p1, nvinfer1::ElementWiseOperation::kPROD)->getOutput(0); - CHECK_NULL_RETURN(prod1); - auto sum = ctx->network()->addElementWise(*prod1, *trt_in_tensor, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0); - CHECK_NULL_RETURN(sum); - auto gelu_p2 = expand_dims(ctx, ctx->ConvertTo1DTensor(FASTGELU_PARAM3), nbdims); - CHECK_NULL_RETURN(gelu_p2); - auto prod2 = ctx->network()->addElementWise(*sum, *gelu_p2, nvinfer1::ElementWiseOperation::kPROD)->getOutput(0); - CHECK_NULL_RETURN(prod2); - auto tanh = ctx->network()->addActivation(*prod2, nvinfer1::ActivationType::kTANH)->getOutput(0); - CHECK_NULL_RETURN(tanh); - auto const_one = expand_dims(ctx, ctx->ConvertTo1DTensor(FASTGELU_PARAM4), nbdims); - CHECK_NULL_RETURN(const_one); - auto sum2 = ctx->network()->addElementWise(*const_one, *tanh, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0); - CHECK_NULL_RETURN(sum2); - auto prod3 = - ctx->network()->addElementWise(*sum2, *trt_in_tensor, nvinfer1::ElementWiseOperation::kPROD)->getOutput(0); - CHECK_NULL_RETURN(prod3); - auto gelu_p3 = expand_dims(ctx, ctx->ConvertTo1DTensor(FASTGELU_PARAM5), nbdims); - CHECK_NULL_RETURN(gelu_p3); - auto fastgelu_layer = ctx->network()->addElementWise(*prod3, *gelu_p3, nvinfer1::ElementWiseOperation::kPROD); - if (fastgelu_layer == nullptr) { - MS_LOG(ERROR) << "add fastgelu op failed for TensorRT."; - return RET_ERROR; - } - nvinfer1::ITensor *out_tensor = fastgelu_layer->getOutput(0); - if (out_tensor == nullptr) { - MS_LOG(ERROR) << "add fastgelu op failed for TensorRT."; - return RET_ERROR; - } - - // cast to origin type - if (out_tensor->getType() != ConvertDataType(out_tensors_[0].DataType())) { - out_tensor = TRTTensorCast(ctx, fastgelu_layer->getOutput(0), ConvertDataType(out_tensors_[0].DataType()), - op_name_ + "_cast_out"); - } - ctx->RegisterTensor(ITensorHelper{out_tensor, input(ctx, 0).format_, input(ctx, 0).same_format_}, - out_tensors_[0].Name()); - fastgelu_layer->setName(op_name_.c_str()); - this->layer_ = fastgelu_layer; - ctx->RegisterLayer(fastgelu_layer, op_name_); - return RET_OK; -} -REGISTER_TENSORRT_CREATOR(ops::kNameFastGeLU, FastGeluTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/fastgelu_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/fastgelu_tensorrt.h deleted file mode 100644 index 02b9896d1d59fcabec01f3ce06f350e87e234b19..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/fastgelu_tensorrt.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_FASTGELU_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_FASTGELU_TENSORRT_H_ -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" - -namespace mindspore::lite { -class FastGeluTensorRT : public TensorRTOp { - public: - FastGeluTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, const std::string &name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~FastGeluTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; - - private: - int RunAsTrtOps(TensorRTContext *ctx); -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_FASTGELU_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/fill_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/fill_tensorrt.cc deleted file mode 100644 index ae5f54341b55900b9f7523e7b1acd642ff4263f3..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/fill_tensorrt.cc +++ /dev/null @@ -1,80 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/fill_tensorrt.h" -#include -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "infer/fill.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_f.h" - -namespace mindspore::lite { -int FillTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { -#if TRT_VERSION_GE(8, 2) - if (in_tensors.size() != INPUT_SIZE2) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size() << " : " << op_name_; - return RET_ERROR; - } - - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size() << " : " << op_name_; - return RET_ERROR; - } - return RET_OK; -#else - MS_LOG(WARNING) << "low TensorRT version don't support fill op, please upgrade TensorRT version to 8.2 or higher"; - return RET_ERROR; -#endif -} - -int FillTensorRT::AddInnerOp(TensorRTContext *ctx) { -#if TRT_VERSION_GE(8, 2) - ITensorHelper fill_input; - nvinfer1::FillOperation op = nvinfer1::FillOperation::kLINSPACE; - auto *fill_layer = ctx->network()->addFill({}, op); - if (fill_layer == nullptr) { - MS_LOG(ERROR) << "addFill failed for TensorRT : " << op_name_; - return RET_ERROR; - } - fill_layer->setInput(0, *input(ctx, 1).trt_tensor_); - nvinfer1::ITensor *alpha_tensor = nullptr; - if (!in_tensors_[0].IsConst()) { - alpha_tensor = input(ctx, 0).trt_tensor_; - } else { - alpha_tensor = - ConvertScalarToITensor(ctx, 0, in_tensors_[0].Data(), in_tensors_[0].DataType(), op_name_ + "_alpha"); - } - fill_layer->setInput(1, *alpha_tensor); - int nbdims = input(ctx, 1).trt_tensor_->getDimensions().d[0]; - nvinfer1::ITensor *beta_tensor = nullptr; - if (in_tensors_[0].DataType() == DataType::kNumberTypeInt32) { - beta_tensor = ctx->ConvertTo1DTensor(std::vector(nbdims, 0)); - } else { - beta_tensor = ctx->ConvertTo1DTensor(std::vector(nbdims, 0.f)); - } - fill_layer->setInput(INPUT_SIZE2, *beta_tensor); - - nvinfer1::ITensor *out_tensor = fill_layer->getOutput(0); - ctx->RegisterTensor(ITensorHelper{out_tensor, Format::NCHW, true}, out_tensors_[0].Name()); - this->layer_ = fill_layer; - return RET_OK; -#else - MS_LOG(WARNING) << "low TensorRT version don't support fill op, please upgrade TensorRT version to 8.2 or higher"; - return RET_ERROR; -#endif -} -REGISTER_TENSORRT_CREATOR(ops::kNameFill, FillTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/fill_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/fill_tensorrt.h deleted file mode 100644 index d228fa62f0b78e9b3e7400bb17b34d0f75c885bb..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/fill_tensorrt.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_FILL_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_FILL_TENSORRT_H_ -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" - -namespace mindspore::lite { -class FillTensorRT : public TensorRTOp { - public: - FillTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~FillTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; - - private: - std::vector zeros_; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_FILL_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/fse_decoder_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/fse_decoder_tensorrt.cc deleted file mode 100644 index f4af141fa290ac9e1043000a9674fdc33e0b89f8..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/fse_decoder_tensorrt.cc +++ /dev/null @@ -1,288 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/fse_decoder_tensorrt.h" -#include -#include -#include -#include -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "NvInferRuntimeCommon.h" -#include "src/extendrt/delegate/tensorrt/cuda_impl/fse_decode.cuh" -#include "infer/fse_decode.h" -#include "tools/converter/quantizer/fse_chunk_end.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_f.h" - -namespace mindspore::lite { -namespace { -constexpr std::size_t kTwo = 2; -constexpr std::size_t kThree = 3; -constexpr std::size_t kFour = 4; -constexpr std::size_t kSix = 6; -constexpr std::size_t kInputSize = 7; -} // namespace - -template -bool isValidData(const T *data, size_t data_size, int value) { - for (size_t i = 0; i < data_size; i++) { - if (data[i] >= value) { - return false; - } - } - return true; -} - -bool FseDecoderTensorRT::IsChunkEndDataValid() { - auto bit_count_buff = reinterpret_cast(in_tensors_[kTwo].Data()); - auto ptable_buff = reinterpret_cast(in_tensors_[kSix].Data()); - auto ptable_size = in_tensors_[kSix].ElementNum(); - auto chunk_size = in_tensors_[0].ElementNum(); - auto state_size = in_tensors_[1].ElementNum(); - - for (size_t i = 0; i < static_cast(ptable_size); i++) { - mindspore::lite::quant::ChunkEndData ptable_data(ptable_buff[i]); - if (ptable_data.state >= state_size) { - MS_LOG(ERROR) << "ERROR: ptable[" << i << "].state: " << ptable_data.state; - return false; - } - if (ptable_data.bit_count != static_cast(bit_count_buff[ptable_data.state])) { - MS_LOG(ERROR) << "ERROR: ptable[" << i << "].bit_count: " << ptable_data.bit_count << ", bit_count_buff[" - << ptable_data.state << "]: " << static_cast(bit_count_buff[ptable_data.state]); - return false; - } - uint64_t chunk_index = ptable_data.bs_position / (CHAR_BIT * sizeof(uint64_t)); - if (chunk_index >= (uint64_t)chunk_size) { - MS_LOG(ERROR) << "ERROR: ptable[" << i << "].bs_position: " << ptable_data.bs_position - << "chunk_size:" << chunk_size << "ptable_size:" << ptable_size; - return false; - } - } - return true; -} - -// FSEDecode TensorRT op -int FseDecoderTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != kInputSize) { - MS_LOG(ERROR) << "Unsupported number of inputs, size is " << in_tensors.size(); - return RET_ERROR; - } - return RET_OK; -} - -int FseDecoderTensorRT::AddInnerOp(TensorRTContext *ctx) { - if (ctx == nullptr || ctx->network() == nullptr) { - MS_LOG(ERROR) << "context or network is invalid"; - return RET_ERROR; - } - - auto fse_decoder_op = AsOps(); - if (fse_decoder_op == nullptr) { - MS_LOG(ERROR) << "op action convert failed"; - return RET_ERROR; - } - - // Check tables validity - MS_ASSERT(in_tensors_[1].ElementNum() == in_tensors_[kThree].ElementNum()); - MS_ASSERT(in_tensors_[kTwo].ElementNum() == in_tensors_[kThree].ElementNum()); - - if (!isValidData(reinterpret_cast(in_tensors_[1].Data()), in_tensors_[1].ElementNum(), - in_tensors_[1].ElementNum()) || - !isValidData(reinterpret_cast(in_tensors_[kTwo].Data()), in_tensors_[kTwo].ElementNum(), 64) || - !isValidData(reinterpret_cast(in_tensors_[kThree].Data()), in_tensors_[kThree].ElementNum(), - in_tensors_[kFour].ElementNum())) { - MS_LOG(ERROR) << "Invalid data in tables"; - return RET_ERROR; - } - - MS_ASSERT(IsChunkEndDataValid()); - - uint64_t curr_chunk_idx = static_cast(fse_decoder_op->get_curr_chunk_index()); - int64_t dst_type = fse_decoder_op->get_dst_t(); - uint64_t curr_bit_count = static_cast(fse_decoder_op->get_curr_bit_count()); - uint64_t table_log = static_cast(fse_decoder_op->get_table_log()); - uint64_t curr_chunk = static_cast(fse_decoder_op->get_curr_chunk()); - const int input_number = inputs().size(); - auto output = outputs().at(0); - auto output_shape = output.Shape(); - - // Convert tensors to int32 for TensorRT - nvinfer1::Dims dims{}; - dims.nbDims = 1; - size_t start_const = C0NUM; - size_t end_const = C4NUM; - for (size_t i = 0; i < in_tensors_.size(); i++) { - auto in_tensor = input(ctx, i); - if ((i >= start_const && i < end_const) || (i == kSix)) { - auto size = inputs().at(i).DataSize(); - dims.d[0] = size / sizeof(int32_t); - nvinfer1::IConstantLayer *constant_tensor; - nvinfer1::Weights weights{nvinfer1::DataType::kINT32, inputs().at(i).Data(), dims.d[0]}; - constant_tensor = ctx->network()->addConstant(dims, weights); - ctx->RegisterLayer(constant_tensor, inputs().at(i).Name() + "_" + op_name_); - in_tensor.trt_tensor_ = constant_tensor->getOutput(0); - ctx->RegisterTensor(in_tensor, inputs().at(i).Name()); - } else { - in_tensor.trt_tensor_ = lite::ConvertConstantTensor(ctx, inputs().at(i), op_name_); - ctx->RegisterTensor(in_tensor, inputs().at(i).Name()); - } - } - nvinfer1::ITensor *input_tensor = input(ctx, 0).trt_tensor_; - auto plugin = std::make_shared(input_tensor->getName(), curr_chunk_idx, dst_type, curr_bit_count, - table_log, curr_chunk, output_shape, device_id_); - nvinfer1::ITensor *inputTensors[input_number]; - for (int i = 0; i < input_number; i++) { - inputTensors[i] = input(ctx, i).trt_tensor_; - } - nvinfer1::IPluginV2Layer *fse_decoder_layer = ctx->network()->addPluginV2(inputTensors, input_number, *plugin); - if (fse_decoder_layer == nullptr) { - MS_LOG(ERROR) << "add fse decoder op failed for TensorRT."; - return RET_ERROR; - } - fse_decoder_layer->setName((op_name_ + "plugin_fse_decoder").c_str()); - nvinfer1::ITensor *fse_decoder_tensor = fse_decoder_layer->getOutput(0); - ctx->RegisterTensor(ITensorHelper{fse_decoder_tensor, Format::NCHW, true}, out_tensors_[0].Name()); - this->layer_ = fse_decoder_layer; - return RET_OK; -} - -// PLUGIN of FSE Decode Layer -REGISTER_TENSORRT_PLUGIN(FseDecoderPluginCreater); -template class TensorRTPluginCreater; -template -nvinfer1::PluginFieldCollection TensorRTPluginCreater::field_collection_{}; -template -std::vector TensorRTPluginCreater::fields_; - -int FseDecoderPlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, - cudaStream_t stream) noexcept { - if (dst_type_ == mindspore::kNumberTypeFloat16) { - return RunFseDecoder(inputDesc, outputDesc, inputs, outputs, workspace, stream); - } else { - return RunFseDecoder(inputDesc, outputDesc, inputs, outputs, workspace, stream); - } -} - -template -std::unique_ptr getTensor(T *tensor, int size) { - using non_const_T = std::remove_const_t; - auto buff = std::make_unique(size); - cudaMemcpy(reinterpret_cast(buff.get()), tensor, size * sizeof(T), cudaMemcpyDeviceToHost); - return buff; -} - -template -int FseDecoderPlugin::RunFseDecoder(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workspace, cudaStream_t stream) { - auto chunks = reinterpret_cast(inputs[0]); - auto states_table = reinterpret_cast(inputs[1]); - auto bit_count_table = reinterpret_cast(inputs[kTwo]); - auto symbol_table = reinterpret_cast(inputs[kThree]); - auto centroids = reinterpret_cast(inputs[kFour]); - auto ptable = reinterpret_cast(inputs[kSix]); - auto out = reinterpret_cast(outputs[0]); - nvinfer1::Dims output_dims = outputDesc[0].dims; - nvinfer1::Dims input_dims = inputDesc[kSix].dims; - auto out_size = std::accumulate(output_dims.d, output_dims.d + output_dims.nbDims, 1, std::multiplies()); - int ptable_size = (std::accumulate(input_dims.d, input_dims.d + input_dims.nbDims, 1, std::multiplies())); - ptable_size = ptable_size * sizeof(int32_t) / - sizeof(uint64_t); // transform to original size due to previous conversion (from int32 to uint64) - bool use_curr_chunk = (curr_bit_count_ > table_log_); - FSE_Decode(chunks, states_table, bit_count_table, symbol_table, ptable, ptable_size, centroids, out_size, out, - device_id_, curr_chunk_, use_curr_chunk, stream); - return 0; -} - -bool FseDecoderPlugin::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs, - int nbOutputs) noexcept { - bool format = (tensorsDesc[pos].format == nvinfer1::TensorFormat::kLINEAR); - auto type = tensorsDesc[pos].type; - if (pos == nbInputs) { - format &= (type == nvinfer1::DataType::kFLOAT); - } else { - switch (pos) { - case C0NUM: - case C1NUM: - case C2NUM: - case C3NUM: - case C5NUM: - format &= (type == nvinfer1::DataType::kINT32); - break; - case C4NUM: - format &= (type == nvinfer1::DataType::kFLOAT); - break; - default: - break; - } - } - return format; -} - -void FseDecoderPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *out, int nbOutputs) noexcept {} - -size_t FseDecoderPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const noexcept { - return 0; -} - -nvinfer1::DataType FseDecoderPlugin::getOutputDataType(int index, const nvinfer1::DataType *inputTypes, - int nbInputs) const noexcept { - return nvinfer1::DataType::kFLOAT; -} - -nvinfer1::DimsExprs FseDecoderPlugin::getOutputDimensions(int32_t index, const nvinfer1::DimsExprs *inputs, - int nbInputDims, - nvinfer1::IExprBuilder &exprBuilder) noexcept { - nvinfer1::DimsExprs out_dims{}; - out_dims.nbDims = output_shape_.size(); - for (int i = 0; i < out_dims.nbDims; i++) { - out_dims.d[i] = exprBuilder.constant(output_shape_[i]); - } - return out_dims; -} - -nvinfer1::IPluginV2DynamicExt *FseDecoderPlugin::clone() const noexcept { - auto *plugin = new FseDecoderPlugin(*this); - if (plugin == nullptr) { - MS_LOG(ERROR) << "plugin is null"; - return nullptr; - } - plugin->setPluginNamespace(name_space_.c_str()); - return plugin; -} - -int FseDecoderPlugin::initialize() noexcept { return 0; } - -void FseDecoderPlugin::terminate() noexcept {} - -size_t FseDecoderPlugin::getSerializationSize() const noexcept { return INPUT_SIZE4 * sizeof(int); } - -void FseDecoderPlugin::serialize(void *buffer) const noexcept { - SerializeValue(&buffer, &curr_chunk_idx_, sizeof(int)); - SerializeValue(&buffer, &dst_type_, sizeof(int)); - SerializeValue(&buffer, &curr_bit_count_, sizeof(int)); - SerializeValue(&buffer, &table_log_, sizeof(int)); - SerializeValue(&buffer, &curr_chunk_, sizeof(int)); -} - -REGISTER_TENSORRT_CREATOR(ops::kNameFSEDecode, FseDecoderTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/fse_decoder_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/fse_decoder_tensorrt.h deleted file mode 100644 index df4850cb7a5f922311f36314eb2c195e08bc0c92..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/fse_decoder_tensorrt.h +++ /dev/null @@ -1,117 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_FSE_DECODER_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_FSE_DECODER_TENSORRT_H_ - -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" -#include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h" - -namespace mindspore::lite { -class FseDecoderTensorRT : public TensorRTOp { - public: - FseDecoderTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~FseDecoderTensorRT() override = default; - - bool IsWeightInputHanledInner() const override { return true; } - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; - - private: - bool IsChunkEndDataValid(); -}; - -constexpr auto FSE_DECODER_PLUGIN_NAME{"FseDecoderPlugin"}; -class FseDecoderPlugin : public TensorRTPlugin { - public: - FseDecoderPlugin(const std::string name, int64_t curr_chunk_idx, int64_t dst_type, int64_t curr_bit_count, - int64_t table_log, uint64_t curr_chunk, const ShapeVector &out_shape, uint32_t device_id) - : TensorRTPlugin(name, std::string(FSE_DECODER_PLUGIN_NAME), device_id), - curr_chunk_idx_(curr_chunk_idx), - dst_type_(dst_type), - curr_bit_count_(curr_bit_count), - table_log_(table_log), - curr_chunk_(curr_chunk), - output_shape_(out_shape) {} - - FseDecoderPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - : TensorRTPlugin(std::string(name), std::string(FSE_DECODER_PLUGIN_NAME)) { - const nvinfer1::PluginField *fields = fc->fields; - curr_chunk_idx_ = static_cast(fields[0].data)[0]; - dst_type_ = static_cast(fields[1].data)[0]; - curr_bit_count_ = static_cast(fields[INPUT_SIZE2].data)[0]; - table_log_ = static_cast(fields[INPUT_SIZE3].data)[0]; - curr_chunk_ = static_cast(fields[INPUT_SIZE4].data)[0]; - } - - FseDecoderPlugin(const char *name, const void *serialData, size_t serialLength) - : TensorRTPlugin(std::string(name), std::string(FSE_DECODER_PLUGIN_NAME)) { - DeserializeValue(&serialData, &serialLength, &curr_chunk_idx_, sizeof(int)); - DeserializeValue(&serialData, &serialLength, &dst_type_, sizeof(int)); - DeserializeValue(&serialData, &serialLength, &curr_bit_count_, sizeof(int)); - DeserializeValue(&serialData, &serialLength, &table_log_, sizeof(int)); - DeserializeValue(&serialData, &serialLength, &curr_chunk_, sizeof(int)); - } - - FseDecoderPlugin() = delete; - - ~FseDecoderPlugin() override {} - - nvinfer1::IPluginV2DynamicExt *clone() const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override; - size_t getSerializationSize() const noexcept override; - void serialize(void *buffer) const noexcept override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const noexcept override; - nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, int nbInputs) const - noexcept override; - nvinfer1::DimsExprs getOutputDimensions(int index, const nvinfer1::DimsExprs *inputs, int nbInputDims, - nvinfer1::IExprBuilder &exprBuilder) noexcept override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *out, int nbOutputs) noexcept override; - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs, - int nbOutputs) noexcept override; - void terminate() noexcept override; - int initialize() noexcept override; - - private: - template - int RunFseDecoder(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream); - - const std::string layer_name_; - std::string name_space_; - uint64_t curr_chunk_idx_; - int64_t dst_type_; - uint64_t curr_bit_count_; - uint64_t table_log_; - uint64_t curr_chunk_; - ShapeVector output_shape_; -}; -class FseDecoderPluginCreater : public TensorRTPluginCreater { - public: - FseDecoderPluginCreater() : TensorRTPluginCreater(std::string(FSE_DECODER_PLUGIN_NAME)) {} -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_FSE_DECODER_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/fullyconnected_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/fullyconnected_tensorrt.cc deleted file mode 100644 index 2148cb1da2d74a0e85966927f99202d090d57a4d..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/fullyconnected_tensorrt.cc +++ /dev/null @@ -1,108 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/fullyconnected_tensorrt.h" -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "src/extendrt/delegate/tensorrt/op/activation_tensorrt.h" -#include "infer/cxx_api/full_connection.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_f.h" - -namespace mindspore::lite { -constexpr int BIAS_INDEX = 2; -int FullyConnectedTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (!IsShapeKnown()) { - MS_LOG(ERROR) << "Unsupported input tensor unknown shape: " << op_name_; - return RET_ERROR; - } - if (in_tensors.size() != INPUT_SIZE2 && in_tensors.size() != INPUT_SIZE3) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - return RET_OK; -} - -int FullyConnectedTensorRT::AddInnerOp(TensorRTContext *ctx) { - auto primitive = AsOps(); - CHECK_NULL_RETURN(primitive); - if (primitive->HasAttr(ops::kActivationType)) { - activation_ = primitive->get_activation_type(); - } - int axis = primitive->get_axis(); - ITensorHelper fc_input; - auto ret = PreprocessInputs(ctx, &fc_input); - if (ret != RET_OK) { - MS_LOG(ERROR) << "PreprocessInputs failed for " << op_name_; - return ret; - } - auto kernel_weight = ConvertWeight(!in_tensors_[1].IsConst() ? in_tensors_[0] : in_tensors_[1]); - nvinfer1::Weights bias_weight{}; - if (primitive->get_has_bias()) { - bias_weight = ConvertWeight(in_tensors_[BIAS_INDEX]); - } - nvinfer1::IFullyConnectedLayer *fc_layer = ctx->network()->addFullyConnected( - *(fc_input.trt_tensor_), in_tensors_[1].Shape()[1 - axis], kernel_weight, bias_weight); - if (fc_layer == nullptr) { - MS_LOG(ERROR) << "addFullyConnected failed for " << op_name_; - return RET_ERROR; - } - this->layer_ = fc_layer; - fc_layer->setName(op_name_.c_str()); - nvinfer1::ITensor *out_tensor = fc_layer->getOutput(0); - - int origin_input_dims = input(ctx, 0).trt_tensor_->getDimensions().nbDims; - if (out_tensor->getDimensions().nbDims != origin_input_dims) { - std::vector squeeze_dim; - for (int i = 0; i != origin_input_dims; ++i) { - squeeze_dim.push_back(out_tensor->getDimensions().d[i]); - } - out_tensor = Reshape(ctx, out_tensor, squeeze_dim); - } - // add activation - if (activation_ != ActivationType::NO_ACTIVATION) { - nvinfer1::ILayer *activation_layer = - ActivationTensorRT::AddActivation(ctx, activation_, 0, 0, 0, out_tensor, op_name_, device_id_); - if (activation_layer == nullptr) { - MS_LOG(ERROR) << "addActivation for matmul failed"; - return RET_ERROR; - } - activation_layer->setName((op_name_ + "_activation").c_str()); - out_tensor = activation_layer->getOutput(0); - } - - ctx->RegisterTensor(ITensorHelper{out_tensor, fc_input.format_}, out_tensors_[0].Name()); - MS_LOG(DEBUG) << "output " << GetTensorFormat(out_tensor); - return RET_OK; -} - -int FullyConnectedTensorRT::PreprocessInputs(TensorRTContext *ctx, ITensorHelper *fc_input) { - auto ret = PreprocessInputs2SameDim(ctx, input(ctx, in_tensors_[1].IsConst() ? 0 : 1), fc_input); - if (ret != RET_OK) { - MS_LOG(ERROR) << "PreprocessInputs2SameDim failed for " << op_name_; - return ret; - } - auto origin_dims = fc_input->trt_tensor_->getDimensions(); - if (origin_dims.nbDims != DIMENSION_4D) { - std::vector expand_dim(origin_dims.d, origin_dims.d + origin_dims.nbDims); - for (int i = 0; i < DIMENSION_4D - origin_dims.nbDims; i++) { - expand_dim.push_back(1); - } - fc_input->trt_tensor_ = Reshape(ctx, fc_input->trt_tensor_, expand_dim); - } - return RET_OK; -} -REGISTER_TENSORRT_CREATOR(ops::kNameFullConnection, FullyConnectedTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/fullyconnected_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/fullyconnected_tensorrt.h deleted file mode 100644 index e0fdddf0dac3da1852d45e64902e9c5cb85f7e28..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/fullyconnected_tensorrt.h +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_FULLYCONNECTED_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_FULLYCONNECTED_TENSORRT_H_ - -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" - -namespace mindspore::lite { -class FullyConnectedTensorRT : public TensorRTOp { - public: - FullyConnectedTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~FullyConnectedTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - bool IsWeightInputHanledInner() const override { return true; } - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; - - private: - int PreprocessInputs(TensorRTContext *ctx, ITensorHelper *fc_input); - - ActivationType activation_{ActivationType::NO_ACTIVATION}; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_FULLYCONNECTED_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/gather_d_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/gather_d_tensorrt.cc deleted file mode 100644 index 30b4860432b32d4b48092ebe6d8424d008222c2f..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/gather_d_tensorrt.cc +++ /dev/null @@ -1,137 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/gather_d_tensorrt.h" -#include -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_g.h" - -namespace mindspore::lite { -REGISTER_TENSORRT_PLUGIN(GatherDPluginCreater); -template class TensorRTPluginCreater; -template -nvinfer1::PluginFieldCollection TensorRTPluginCreater::field_collection_{}; -template -std::vector TensorRTPluginCreater::fields_; - -int GatherDTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (!IsShapeKnown()) { - MS_LOG(ERROR) << "Unsupported gatherd input tensor unknown shape: " << op_name_; - return RET_ERROR; - } - if (in_tensors.size() != INPUT_SIZE3) { - MS_LOG(ERROR) << "invalid gatherd input tensor size: " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "invalid gatherd output tensor size: " << out_tensors.size(); - return RET_ERROR; - } - return RET_OK; -} - -int GatherDTensorRT::AddInnerOp(TensorRTContext *ctx) { - nvinfer1::ITensor *inputTensors[] = {input(ctx, 0).trt_tensor_, input(ctx, 2).trt_tensor_}; - auto dim_vec = ConvertTensorAsIntVector(in_tensors_[1]); - if (dim_vec.size() != 1) { - MS_LOG(ERROR) << "Failed to get dim input, dim count " << dim_vec.size() << ", node: " << op_name_; - return RET_ERROR; - } - size_t dim = dim_vec[0]; - - auto plugin = std::make_shared(op_name_, dim, device_id_); - nvinfer1::IPluginV2Layer *gatherd_layer = ctx->network()->addPluginV2(inputTensors, INPUT_SIZE2, *plugin); - if (gatherd_layer == nullptr) { - MS_LOG(ERROR) << "create gatherd failed for: " << op_name_; - return RET_ERROR; - } - nvinfer1::ITensor *gatherd_out = gatherd_layer->getOutput(0); - gatherd_layer->setName(op_name_.c_str()); - ctx->RegisterTensor(ITensorHelper{gatherd_out, input(ctx, 0).format_, input(ctx, 0).same_format_}, - out_tensors_[0].Name()); - this->layer_ = gatherd_layer; - return RET_OK; -} - -int GatherDPlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, - cudaStream_t stream) noexcept { - nvinfer1::Dims input_dims = inputDesc[0].dims; - int dims = input_dims.nbDims; - if (axis_ < 0) { - axis_ += dims; - } - - if (inputDesc->type == nvinfer1::DataType::kINT32) { - auto input = static_cast(inputs[0]); - auto index = static_cast(inputs[1]); - auto output = static_cast(outputs[0]); - Reshape(inputDesc, outputDesc); - GatherD(input, index, output, static_cast(axis_), num_, input_dims.nbDims, input_shape_helper_, - index_shape_helper_, stream, device_id_); - } else if (inputDesc->type == nvinfer1::DataType::kFLOAT) { - auto input = static_cast(inputs[0]); - auto index = static_cast(inputs[1]); - auto output = static_cast(outputs[0]); - Reshape(inputDesc, outputDesc); - GatherD(input, index, output, static_cast(axis_), num_, input_dims.nbDims, input_shape_helper_, - index_shape_helper_, stream, device_id_); - } else { - MS_LOG(ERROR) << "unsupported data type gatherd" << layer_name_; - } - return RET_OK; -} - -nvinfer1::IPluginV2DynamicExt *GatherDPlugin::clone() const noexcept { - auto *plugin = new GatherDPlugin(*this); - plugin->setPluginNamespace(name_space_.c_str()); - return plugin; -} - -nvinfer1::DimsExprs GatherDPlugin::getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, - nvinfer1::IExprBuilder &exprBuilder) noexcept { - nvinfer1::DimsExprs out_dims{}; - out_dims.nbDims = inputs[1].nbDims; - for (int i = 0; i < inputs[1].nbDims; i++) { - out_dims.d[i] = inputs[1].d[i]; - } - return out_dims; -} - -void GatherDPlugin::Reshape(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc) { - nvinfer1::Dims input_dims = inputDesc[0].dims; - nvinfer1::Dims output_dims = outputDesc[0].dims; - - if (input_dims.nbDims > static_cast(kMaxShapeRank)) { - MS_LOG(EXCEPTION) << "The rank of input should be less than " << kMaxShapeRank << ", but got " << input_dims.nbDims - << "."; - } - num_ = 1; - for (size_t i = 0; i < static_cast(input_dims.nbDims); i++) { - input_shape_helper_.shape[i] = static_cast(input_dims.d[i]); - index_shape_helper_.shape[i] = static_cast(output_dims.d[i]); - num_ *= static_cast(output_dims.d[i]); - } - - return; -} -REGISTER_TENSORRT_CREATOR(ops::kNameGatherD, GatherDTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/gather_d_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/gather_d_tensorrt.h deleted file mode 100644 index 8c4706d8ff3b82412627c04ee2901022f744518c..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/gather_d_tensorrt.h +++ /dev/null @@ -1,78 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_GATHER_D_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_GATHER_D_TENSORRT_H_ -#include -#include -#include "kernel/gpu/cuda_impl/cuda_ops/gatherd.cuh" -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" -#include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h" -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" - -namespace mindspore::lite { -constexpr auto GATHER_D_PLUGIN_NAME{"GatherDPluginCreater"}; -class GatherDTensorRT : public TensorRTOp { - public: - GatherDTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~GatherDTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; -}; - -class GatherDPlugin : public TensorRTPlugin { - public: - GatherDPlugin(const std::string name, size_t dim, uint32_t device_id) - : TensorRTPlugin(name, std::string(GATHER_D_PLUGIN_NAME), device_id), axis_(dim) {} - - GatherDPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - : TensorRTPlugin(std::string(name), std::string(GATHER_D_PLUGIN_NAME)) { - const nvinfer1::PluginField *fields = fc->fields; - axis_ = static_cast(fields[0].data)[0]; - } - - GatherDPlugin(const char *name, const void *serialData, size_t serialLength) - : TensorRTPlugin(std::string(name), std::string(GATHER_D_PLUGIN_NAME)) { - DeserializeValue(&serialData, &serialLength, &axis_, sizeof(int)); - } - - GatherDPlugin() = delete; - - nvinfer1::IPluginV2DynamicExt *clone() const noexcept override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, - nvinfer1::IExprBuilder &exprBuilder) noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override; - - private: - int axis_; - ShapeHelper input_shape_helper_; - ShapeHelper index_shape_helper_; - size_t num_; - void Reshape(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc); -}; - -class GatherDPluginCreater : public TensorRTPluginCreater { - public: - GatherDPluginCreater() : TensorRTPluginCreater(std::string(GATHER_D_PLUGIN_NAME)) {} -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_GATHER_D_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/gather_nd_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/gather_nd_tensorrt.cc deleted file mode 100644 index 761a492fa9b6912adc743fc6aca43bd427331429..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/gather_nd_tensorrt.cc +++ /dev/null @@ -1,81 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/gather_nd_tensorrt.h" -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_g.h" - -namespace mindspore::lite { -int GatherNDTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { -#if TRT_VERSION_GE(8, 2) - if (in_tensors.size() != INPUT_SIZE2) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - if (in_tensors[1].DataType() != DataType::kNumberTypeInt32) { - MS_LOG(ERROR) << "Gather indices only support Int32"; - return RET_ERROR; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); - return RET_ERROR; - } - - return RET_OK; -#else - MS_LOG(WARNING) << "low TensorRT version don't support gathernd op, please upgrade TensorRT version to 8.2 or higher"; - return RET_ERROR; -#endif -} - -int GatherNDTensorRT::AddInnerOp(TensorRTContext *ctx) { -#if TRT_VERSION_GE(8, 2) - if (ctx == nullptr || ctx->network() == nullptr) { - MS_LOG(ERROR) << "context or network is invalid"; - return RET_ERROR; - } - ITensorHelper gather_nd_input = input(ctx, 0); - ITensorHelper indices_tensor = input(ctx, 1); - if (indices_tensor.trt_tensor_->getDimensions().nbDims < 1) { - MS_LOG(ERROR) << "addGather failed for TensorRT."; - return RET_ERROR; - } - if (in_tensors_[0].DataType() == DataType::kNumberTypeBool) { - gather_nd_input.trt_tensor_ = - TRTTensorCast(ctx, gather_nd_input.trt_tensor_, nvinfer1::DataType::kINT32, op_name_ + "_input"); - } - nvinfer1::IGatherLayer *gather_layer = - ctx->network()->addGatherV2(*gather_nd_input.trt_tensor_, *indices_tensor.trt_tensor_, nvinfer1::GatherMode::kND); - if (gather_layer == nullptr) { - MS_LOG(ERROR) << "addGatherND failed for TensorRT."; - return RET_ERROR; - } - gather_layer->setNbElementWiseDims(0); - this->layer_ = gather_layer; - gather_layer->setName(op_name_.c_str()); - nvinfer1::ITensor *op_output = gather_layer->getOutput(0); - ctx->RegisterTensor(ITensorHelper{op_output, gather_nd_input.format_, gather_nd_input.same_format_}, - out_tensors_[0].Name()); - return RET_OK; -#else - MS_LOG(WARNING) << "low TensorRT version don't support gathernd op, please upgrade TensorRT version to 8.2 or higher"; - return RET_ERROR; -#endif -} -REGISTER_TENSORRT_CREATOR(ops::kNameGatherNd, GatherNDTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/gather_nd_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/gather_nd_tensorrt.h deleted file mode 100644 index 921f27bcca97b8c54e4b840dee6aec9093cd8c6f..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/gather_nd_tensorrt.h +++ /dev/null @@ -1,38 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENT_RT_DELEGATE_TENSORRT_OP_GATHER_ND_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENT_RT_DELEGATE_TENSORRT_OP_GATHER_ND_TENSORRT_H_ -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" - -namespace mindspore::lite { -class GatherNDTensorRT : public TensorRTOp { - public: - GatherNDTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~GatherNDTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENT_RT_DELEGATE_TENSORRT_OP_GATHER_ND_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/gather_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/gather_tensorrt.cc deleted file mode 100644 index cbb17614e9b7cf5c26949c83a30b862ac18d5622..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/gather_tensorrt.cc +++ /dev/null @@ -1,112 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/gather_tensorrt.h" -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_g.h" - -namespace mindspore::lite { -constexpr int AXIS_INDEX = 2; - -int GatherTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != INPUT_SIZE3) { - MS_LOG(ERROR) << "invalid input tensor size: " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "invalid output tensor size: " << out_tensors.size(); - return RET_ERROR; - } - if (in_tensors[1].DataType() != DataType::kNumberTypeInt32 && - in_tensors[1].DataType() != DataType::kNumberTypeInt64) { - MS_LOG(ERROR) << "Gather indices only support Int32"; - return RET_ERROR; - } - if (in_tensors[AXIS_INDEX].ElementNum() == 1) { - auto axis_vec = ConvertTensorAsIntVector(in_tensors_[AXIS_INDEX]); - if (axis_vec.size() != 1) { - MS_LOG(ERROR) << "Failed to get axis input, dim count " << axis_vec.size() << ", node: " << op_name_; - return RET_ERROR; - } - axis_ = axis_vec[0]; - } else { - MS_LOG(ERROR) << "TensorRT axis is attribute."; - return RET_ERROR; - } - return RET_OK; -} - -int GatherTensorRT::AddInnerOp(TensorRTContext *ctx) { - if (ctx == nullptr || ctx->network() == nullptr) { - MS_LOG(ERROR) << "context or network is invalid"; - return RET_ERROR; - } - for (size_t i = 0; i != AXIS_INDEX; ++i) { - if (input(ctx, i).trt_tensor_ == nullptr) { - auto const_input = ConvertConstantTensor(ctx, in_tensors_[i], op_name_); - auto is_scalar = in_tensors_[i].Shape().empty(); - ctx->RegisterTensor(ITensorHelper{const_input, NCHW, true, !is_scalar}, in_tensors_[i].Name()); - } - } - - ITensorHelper gather_input = input(ctx, 0); - int ret = PreprocessInputs2SameDim(ctx, gather_input, &gather_input); - if (ret != RET_OK || gather_input.trt_tensor_ == nullptr) { - MS_LOG(ERROR) << "PreprocessInputs2SameDim gather failed for " << op_name_; - return RET_ERROR; - } - ITensorHelper indices_tensor = input(ctx, 1); - ret = PreprocessInputs2SameDim(ctx, indices_tensor, &indices_tensor); - if (ret != RET_OK || indices_tensor.trt_tensor_ == nullptr) { - MS_LOG(ERROR) << "PreprocessInputs2SameDim indices failed for " << op_name_; - return RET_ERROR; - } - - nvinfer1::IGatherLayer *gather_layer = - ctx->network()->addGather(*gather_input.trt_tensor_, *indices_tensor.trt_tensor_, axis_); - if (gather_layer == nullptr) { - MS_LOG(ERROR) << "addGather failed for TensorRT."; - return RET_ERROR; - } - - this->layer_ = gather_layer; - gather_layer->setName(op_name_.c_str()); - nvinfer1::ITensor *op_output = gather_layer->getOutput(0); - auto old_shape = ConvertMSShape(op_output->getDimensions()); - // keep shape - if (!indices_tensor.is_tensor && old_shape.size() > 1) { - auto squeeze = ctx->network()->addShuffle(*op_output); - if (squeeze == nullptr) { - MS_LOG(ERROR) << "add output squeeze failed for " << op_name_; - return RET_ERROR; - } - squeeze->setName((op_name_ + "_squeeze_out").c_str()); - old_shape.erase(old_shape.begin() + axis_); - squeeze->setReshapeDimensions(ConvertCudaDims(old_shape)); - op_output = squeeze->getOutput(0); - } - - auto out_helper = ITensorHelper{op_output, gather_input.format_, gather_input.same_format_}; - if (old_shape.size() == 1) { - out_helper.is_tensor = false; - } - ctx->RegisterTensor(out_helper, out_tensors_[0].Name()); - return RET_OK; -} -REGISTER_TENSORRT_CREATOR(ops::kNameGather, GatherTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/gather_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/gather_tensorrt.h deleted file mode 100644 index a8017ad484625847c4c4f6db00baecd802e07a90..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/gather_tensorrt.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_GATHER_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_GATHER_TENSORRT_H_ -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" - -namespace mindspore::lite { -class GatherTensorRT : public TensorRTOp { - public: - GatherTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~GatherTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - bool IsWeightInputHanledInner() const override { return true; } - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; - - private: - int axis_{0}; - TensorInfo indices_; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_GATHER_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/glu_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/glu_tensorrt.cc deleted file mode 100644 index 86de930ff2a8d48a8bf038b287da08db4df22bb1..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/glu_tensorrt.cc +++ /dev/null @@ -1,73 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/glu_tensorrt.h" -#include -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_g.h" - -namespace mindspore::lite { -int GLUTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - return RET_OK; -} - -int GLUTensorRT::AddInnerOp(TensorRTContext *ctx) { - dim_ = AsOps()->get_axis(); - auto rank = input(ctx, 0).trt_tensor_->getDimensions().nbDims; - dim_ = (dim_ != -1) ? dim_ : rank - 1; - // split - std::vector split_dims(rank, 1); - split_dims[dim_] = SPLITE_NUM; - auto split_dims_tensor = ctx->ConvertTo1DTensor(split_dims); - auto in_tensor_shape = ctx->network()->addShape(*input(ctx, 0).trt_tensor_)->getOutput(0); - auto split_tensor1 = ctx->network() - ->addElementWise(*in_tensor_shape, *split_dims_tensor, nvinfer1::ElementWiseOperation::kDIV) - ->getOutput(0); - nvinfer1::Dims starts{rank}; - std::fill(starts.d, starts.d + rank, 0); - nvinfer1::Dims strides{rank}; - std::fill(strides.d, strides.d + rank, 1); - nvinfer1::ISliceLayer *slice_layer = ctx->network()->addSlice(*input(ctx, 0).trt_tensor_, starts, {}, strides); - slice_layer->setInput(INPUT_INDEX, *split_tensor1); - auto input1 = slice_layer->getOutput(0); - std::vector start_mask(rank, 0); - start_mask[dim_] = 1; - auto start_dims_tensor = ctx->ConvertTo1DTensor(start_mask); - nvinfer1::ISliceLayer *slice_layer2 = ctx->network()->addSlice(*input(ctx, 0).trt_tensor_, {}, {}, strides); - auto start_tensor = ctx->network() - ->addElementWise(*split_tensor1, *start_dims_tensor, nvinfer1::ElementWiseOperation::kPROD) - ->getOutput(0); - slice_layer2->setInput(1, *start_tensor); - slice_layer2->setInput(INPUT_INDEX, *split_tensor1); - auto input2 = slice_layer2->getOutput(0); - // sigmoid - auto sigmoid_tensor = ctx->network()->addActivation(*input2, nvinfer1::ActivationType::kSIGMOID)->getOutput(0); - // mul - auto mul_layer = ctx->network()->addElementWise(*input1, *sigmoid_tensor, nvinfer1::ElementWiseOperation::kPROD); - auto out_tensor = mul_layer->getOutput(0); - this->layer_ = mul_layer; - ctx->RegisterTensor(ITensorHelper{out_tensor, input(ctx, 0).format_, input(ctx, 0).same_format_}, - out_tensors_[0].Name()); - return RET_OK; -} -REGISTER_TENSORRT_CREATOR(ops::kNameGLU, GLUTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/glu_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/glu_tensorrt.h deleted file mode 100644 index 768b10a3e85be7490f788270d96618d8b46c9c1b..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/glu_tensorrt.h +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_GLU_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_GLU_TENSORRT_H_ -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" - -namespace mindspore::lite { -constexpr int SPLITE_NUM = 2; -constexpr int INPUT_INDEX = 2; -class GLUTensorRT : public TensorRTOp { - public: - GLUTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~GLUTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; - - private: - int64_t dim_{-1}; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_GLU_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/greaterorequal_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/greaterorequal_tensorrt.cc deleted file mode 100644 index 01939787a9b78575ad8b8b686bbc9dc0166e3861..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/greaterorequal_tensorrt.cc +++ /dev/null @@ -1,141 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/greaterorequal_tensorrt.h" -#include -#include -#include -#include -#include -#include -#include "NvInferRuntimeCommon.h" -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "src/extendrt/delegate/tensorrt/cuda_impl/logical.cuh" -#include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_g.h" - -namespace mindspore::lite { -int GreaterorequalTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != INPUT_SIZE2) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); - return RET_ERROR; - } - return RET_OK; -} - -int GreaterorequalTensorRT::AddInnerOp(TensorRTContext *ctx) { - if (ctx == nullptr || ctx->network() == nullptr) { - MS_LOG(ERROR) << "network or input tensor is invalid"; - return RET_ERROR; - } - nvinfer1::ITensor *inputTensors[] = {input(ctx, 0).trt_tensor_, input(ctx, 1).trt_tensor_}; - for (size_t i = 0; i != in_tensors_.size(); ++i) { - ITensorHelper input_helper = input(ctx, i); - if (input_helper.trt_tensor_->getType() != nvinfer1::DataType::kINT32) { - auto cast_layer = ctx->network()->addIdentity(*input_helper.trt_tensor_); - if (cast_layer == nullptr) { - MS_LOG(ERROR) << "create cast layer failed for: " << op_name_; - return RET_ERROR; - } - cast_layer->setOutputType(0, nvinfer1::DataType::kINT32); - inputTensors[i] = cast_layer->getOutput(0); - } - } - auto plugin = std::make_shared(op_name_, schema::PrimitiveType_GreaterEqual); - if (plugin == nullptr) { - MS_LOG(ERROR) << "create GreaterorequalPlugin failed for " << op_name_; - return RET_ERROR; - } - nvinfer1::IPluginV2Layer *greaterorequal_layer = ctx->network()->addPluginV2(inputTensors, 2, *plugin); - this->layer_ = greaterorequal_layer; - - nvinfer1::ITensor *op_out_tensor = greaterorequal_layer->getOutput(0); - if (op_out_tensor == nullptr) { - MS_LOG(ERROR) << "greaterorequal out tensor is nullptr."; - return RET_ERROR; - } - ctx->RegisterTensor(ITensorHelper{op_out_tensor, input(ctx, 0).format_, input(ctx, 0).same_format_}, - out_tensors_[0].Name()); - return RET_OK; -} - -REGISTER_TENSORRT_PLUGIN(GreaterorequalPluginCreater); -template class TensorRTPluginCreater; -template -nvinfer1::PluginFieldCollection TensorRTPluginCreater::field_collection_{}; -template -std::vector TensorRTPluginCreater::fields_; - -int GreaterorequalPlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workspace, cudaStream_t stream) noexcept { - return RunCudaGreaterorequal(inputDesc, inputs, outputs, stream); -} - -int GreaterorequalPlugin::RunCudaGreaterorequal(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, - void *const *outputs, cudaStream_t stream) { - if (inputDesc->type == nvinfer1::DataType::kINT32) { - GreaterOrEqual(static_cast(inputs[0]), static_cast(inputs[1]), - static_cast(outputs[0]), GetDimsVolume(inputDesc[0].dims), stream); - } else if (inputDesc->type == nvinfer1::DataType::kFLOAT) { - GreaterOrEqual(static_cast(inputs[0]), static_cast(inputs[1]), - static_cast(outputs[0]), GetDimsVolume(inputDesc[0].dims), stream); - } else { - MS_LOG(ERROR) << "unsupported equal data type"; - return RET_ERROR; - } - return RET_OK; -} - -nvinfer1::IPluginV2DynamicExt *GreaterorequalPlugin::clone() const noexcept { - auto *plugin = new (std::nothrow) GreaterorequalPlugin(*this); - if (plugin == nullptr) { - MS_LOG(ERROR) << "malloc greaterorequal plugin failed"; - return nullptr; - } - plugin->setPluginNamespace(name_space_.c_str()); - return plugin; -} - -bool GreaterorequalPlugin::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, - int nbInputs, int nbOutputs) noexcept { - if (tensorsDesc[pos].format != nvinfer1::TensorFormat::kLINEAR) { - return false; - } - if (pos == 0) { - return tensorsDesc[pos].type == nvinfer1::DataType::kFLOAT || tensorsDesc[pos].type == nvinfer1::DataType::kINT32; - } - if (pos < nbInputs) { - return tensorsDesc[pos].type == tensorsDesc[pos - 1].type; - } - if (pos < nbInputs + nbOutputs) { - return tensorsDesc[pos].type == nvinfer1::DataType::kINT32; - } - return false; -} - -size_t GreaterorequalPlugin::getSerializationSize() const noexcept { return sizeof(schema::PrimitiveType); } - -void GreaterorequalPlugin::serialize(void *buffer) const noexcept { - SerializeValue(&buffer, &primitive_type_, sizeof(schema::PrimitiveType)); -} -REGISTER_TENSORRT_CREATOR(ops::kNameGreaterEqual, GreaterorequalTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/greaterorequal_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/greaterorequal_tensorrt.h deleted file mode 100644 index f615aa67c82233bd83193073a79898ddf59b2f52..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/greaterorequal_tensorrt.h +++ /dev/null @@ -1,83 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_GREATEROREQUAL_PLUGIN_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_GREATEROREQUAL_PLUGIN_H_ - -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" -#include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h" - -namespace mindspore::lite { -class GreaterorequalTensorRT : public TensorRTOp { - public: - GreaterorequalTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~GreaterorequalTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; -}; - -constexpr auto GREATEROREQUAL_PLUGIN_NAME{"GreaterorequalPlugin"}; -class GreaterorequalPlugin : public TensorRTPlugin { - public: - GreaterorequalPlugin(const std::string name, schema::PrimitiveType primitive_type) - : TensorRTPlugin(name, std::string(GREATEROREQUAL_PLUGIN_NAME)), primitive_type_(primitive_type) {} - - GreaterorequalPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - : TensorRTPlugin(std::string(name), std::string(GREATEROREQUAL_PLUGIN_NAME)) { - const nvinfer1::PluginField *fields = fc->fields; - primitive_type_ = static_cast(fields[0].data)[0]; - } - - GreaterorequalPlugin(const char *name, const void *serialData, size_t serialLength) - : TensorRTPlugin(std::string(name), std::string(GREATEROREQUAL_PLUGIN_NAME)) { - DeserializeValue(&serialData, &serialLength, &primitive_type_, sizeof(schema::PrimitiveType)); - } - - GreaterorequalPlugin() = delete; - - nvinfer1::IPluginV2DynamicExt *clone() const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override; - size_t getSerializationSize() const noexcept override; - void serialize(void *buffer) const noexcept override; - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs, - int nbOutputs) noexcept override; - nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, int nbInputs) const - noexcept override { - return nvinfer1::DataType::kINT32; - } - - private: - int RunCudaGreaterorequal(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, - void *const *outputs, cudaStream_t stream); - const std::string layer_name_; - std::string name_space_; - schema::PrimitiveType primitive_type_; -}; -class GreaterorequalPluginCreater : public TensorRTPluginCreater { - public: - GreaterorequalPluginCreater() : TensorRTPluginCreater(std::string(GREATEROREQUAL_PLUGIN_NAME)) {} -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_GREATEROREQUAL_PLUGIN_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/instancenorm_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/instancenorm_tensorrt.cc deleted file mode 100644 index b767bda7125e824de27b18b0b3fabcdc87ab3afe..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/instancenorm_tensorrt.cc +++ /dev/null @@ -1,161 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/instancenorm_tensorrt.h" -#include -#include -#include "infer/instance_norm.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_i.h" - -namespace mindspore::lite { -namespace { -constexpr int GAMMA_INDEX = 1; -constexpr int BETA_INDEX = 2; -} // namespace -int InstanceNormTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != INPUT_SIZE3) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - if (!in_tensors_[GAMMA_INDEX].IsConst() || !in_tensors_[BETA_INDEX].IsConst()) { - MS_LOG(ERROR) << "Unsupported non const gamma or beta input, is gamma const: " << in_tensors_[GAMMA_INDEX].IsConst() - << ", is beta const: " << in_tensors_[BETA_INDEX].IsConst(); - return RET_ERROR; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); - return RET_ERROR; - } - dynamic_shape_params_.support_dynamic_ = false; - dynamic_shape_params_.support_hw_dynamic_ = false; - return RET_OK; -} - -int InstanceNormTensorRT::AddInnerOp(TensorRTContext *ctx) { - CHECK_NULL_RETURN(ctx->network()); - auto norm_op = AsOps(); - CHECK_NULL_RETURN(norm_op); - epsilon_ = norm_op->get_epsilon(); - - ITensorHelper norm_input = input(ctx, 0); - if (norm_input.trt_tensor_ == nullptr) { - MS_LOG(ERROR) << "Input tensorrt tensor cannot be nullptr, op: " << op_name_; - return RET_ERROR; - } - auto norm_input_dims = norm_input.trt_tensor_->getDimensions(); - if (norm_input_dims.nbDims != kDim4) { - MS_LOG(ERROR) << "Expect count of input dims to be " << kDim4 << ", but got: " << CudaDimsAsString(norm_input_dims) - << " , op: " << op_name_; - return RET_ERROR; - } - if (IsDynamicInput(ctx, 0)) { - MS_LOG(ERROR) << "Not support dynamic input, input dims: " << CudaDimsAsString(norm_input_dims) - << ", op: " << op_name_; - return RET_ERROR; - } - auto &gamma_input = in_tensors_[GAMMA_INDEX]; - auto &beta_input = in_tensors_[BETA_INDEX]; - auto nc = norm_input_dims.d[0] * norm_input_dims.d[1]; - if (gamma_input.ElementNum() != nc || beta_input.ElementNum() != nc) { - MS_LOG(ERROR) << "Element number of gamma or beta expect to be N*C of input, but got gamma element number: " - << gamma_input.ElementNum() << ", beta element number: " << beta_input.ElementNum() - << ", input dims: " << CudaDimsAsString(norm_input_dims) << ", op: " << op_name_; - return RET_ERROR; - } - auto expect_shape = ConvertMSShape(norm_input_dims); - if (gamma_input.Shape().size() == 1) { - expect_shape[kDim2] = expect_shape[kDim2] * expect_shape[kDim3]; - expect_shape.erase(expect_shape.begin() + kDim3); - } - gamma_ = ConvertTensorWithExpandDims(ctx, gamma_input, expect_shape, op_name_ + gamma_input.Name()); - CHECK_NULL_RETURN(gamma_); - beta_ = ConvertTensorWithExpandDims(ctx, beta_input, expect_shape, op_name_ + beta_input.Name()); - CHECK_NULL_RETURN(beta_); - - auto reshape_layer = ctx->network()->addShuffle(*norm_input.trt_tensor_); - auto reshape_shape = ConvertMSShape(norm_input_dims); - reshape_shape[kDim2] = reshape_shape[kDim2] * reshape_shape[kDim3]; - reshape_shape.erase(reshape_shape.begin() + kDim3); - reshape_layer->setReshapeDimensions(ConvertCudaDims(reshape_shape)); - // n,c,hw - auto reshape_output = reshape_layer->getOutput(0); - - constexpr uint32_t reduce_axis_hw = (1 << 2); - // scale = gama / sqrt(mean(hw*hw) - mean(hw)^2 + epsilon_) - // dst[index] = (src[index] - mean(hw)) * scale + beta_data = src[index]*scale - mean(hw)*scale + beta_data - // mean(hw) - auto mean_layer = ctx->network()->addReduce(*reshape_output, nvinfer1::ReduceOperation::kAVG, reduce_axis_hw, true); - auto mean_output = mean_layer->getOutput(0); - // mean(hw)^2 - auto mean_square_layer = - ctx->network()->addElementWise(*mean_output, *mean_output, nvinfer1::ElementWiseOperation::kPROD); - auto mean_square_output = mean_square_layer->getOutput(0); - // hw*hw - auto square_layer = - ctx->network()->addElementWise(*reshape_output, *reshape_output, nvinfer1::ElementWiseOperation::kPROD); - auto square_output = square_layer->getOutput(0); - // mean(hw*hw) - auto square_mean_layer = - ctx->network()->addReduce(*square_output, nvinfer1::ReduceOperation::kAVG, reduce_axis_hw, true); - auto square_mean_output = square_mean_layer->getOutput(0); - // mean(hw*hw) - mean(hw)^2 - auto var_layer = - ctx->network()->addElementWise(*square_mean_output, *mean_square_output, nvinfer1::ElementWiseOperation::kSUB); - auto var_output = var_layer->getOutput(0); - - auto const_epsilon = ConvertScalarToITensor(ctx, var_output->getDimensions().nbDims, &epsilon_, - DataType::kNumberTypeFloat32, op_name_ + "_epsilion"); - CHECK_NULL_RETURN(const_epsilon); - // mean(hw*hw) - mean(hw)^2 + epsilon_ - auto var_epsilon = - ctx->network()->addElementWise(*var_output, *const_epsilon, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0); - CHECK_NULL_RETURN(var_epsilon); - - // sqrt(mean(hw*hw) - mean(hw)^2 + epsilon_), standard deviation - auto std_dev = ctx->network()->addUnary(*var_epsilon, nvinfer1::UnaryOperation::kSQRT)->getOutput(0); - CHECK_NULL_RETURN(std_dev); - // gama / sqrt(mean(hw*hw) - mean(hw)^2 + epsilon_) - auto scale = ctx->network()->addElementWise(*gamma_, *std_dev, nvinfer1::ElementWiseOperation::kDIV)->getOutput(0); - CHECK_NULL_RETURN(scale); - - // mean(hw)*scale - auto mean_mul_scale = - ctx->network()->addElementWise(*mean_output, *scale, nvinfer1::ElementWiseOperation::kPROD)->getOutput(0); - CHECK_NULL_RETURN(mean_mul_scale); - - // bias = - mean(hw)*scale + beta_data - auto bias = - ctx->network()->addElementWise(*beta_, *mean_mul_scale, nvinfer1::ElementWiseOperation::kSUB)->getOutput(0); - CHECK_NULL_RETURN(bias); - - // scale with bias: src[index]*scale - auto scale_layer = ctx->network()->addElementWise(*reshape_output, *scale, nvinfer1::ElementWiseOperation::kPROD); - this->layer_ = scale_layer; - auto scale_out = scale_layer->getOutput(0); - CHECK_NULL_RETURN(scale_out); - // src[index]*scale - mean(hw)*scale + beta_data - auto beta_out = ctx->network()->addElementWise(*scale_out, *bias, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0); - CHECK_NULL_RETURN(beta_out); - - auto reshape_final_layer = ctx->network()->addShuffle(*beta_out); - reshape_final_layer->setReshapeDimensions(norm_input_dims); - auto output = reshape_final_layer->getOutput(0); - ctx->RegisterTensor(ITensorHelper{output, Format::NCHW, true}, out_tensors_[0].Name()); - return RET_OK; -} -REGISTER_TENSORRT_CREATOR(ops::kNameInstanceNorm, InstanceNormTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/instancenorm_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/instancenorm_tensorrt.h deleted file mode 100644 index eee3534906b44441d87b97b0af9f5f77a7e7887a..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/instancenorm_tensorrt.h +++ /dev/null @@ -1,44 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_INSTANCE_NORM_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_INSTANCE_NORM_TENSORRT_H_ -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" - -namespace mindspore::lite { -class InstanceNormTensorRT : public TensorRTOp { - public: - InstanceNormTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~InstanceNormTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; - - private: - int RunAsTrtOps(TensorRTContext *ctx, ITensorHelper helper); - - float epsilon_{0.0f}; - nvinfer1::ITensor *gamma_{nullptr}; - nvinfer1::ITensor *beta_{nullptr}; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_INSTANCE_NORM_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/l2normalization_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/l2normalization_tensorrt.cc deleted file mode 100644 index e43fc36060828535b0bd968fb676d806320be748..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/l2normalization_tensorrt.cc +++ /dev/null @@ -1,80 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/l2normalization_tensorrt.h" -#include -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "infer/cxx_api/l2_normalize_fusion.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h" - -namespace mindspore::lite { -int L2NormalizationTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size() << " : " << op_name_; - return RET_ERROR; - } - - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size() << " : " << op_name_; - return RET_ERROR; - } - return RET_OK; -} - -int L2NormalizationTensorRT::AddInnerOp(TensorRTContext *ctx) { - auto input_tensor = input(ctx, 0).trt_tensor_; - int nbdims = input_tensor->getDimensions().nbDims; - auto op = AsOps(); - int64_t axis = op->get_axis()[0]; - if (axis < 0) { - axis += nbdims; - } - - if (axis < 0 || axis >= nbdims) { - MS_LOG(ERROR) << "axis error : " << axis << " for " << op_name_; - return RET_ERROR; - } - - float epsilon = op->get_epsilon(); - - auto pow = - ctx->network()->addElementWise(*input_tensor, *input_tensor, nvinfer1::ElementWiseOperation::kPROD)->getOutput(0); - auto sum = ctx->network()->addReduce(*pow, nvinfer1::ReduceOperation::kSUM, 1 << axis, true)->getOutput(0); - - auto ep = ctx->ConvertTo1DTensor(epsilon); - while (ep->getDimensions().nbDims < nbdims) { - ep = ExpandDim(ctx, ep, 0); - } - - if (input_tensor->getType() != nvinfer1::DataType::kFLOAT) { - ep = TRTTensorCast(ctx, ep, input_tensor->getType(), op_name_ + "_cast_epsilon"); - } - if (ep == nullptr) { - MS_LOG(ERROR) << "ep is nullptr!"; - return RET_ERROR; - } - auto norm = ctx->network()->addElementWise(*sum, *ep, nvinfer1::ElementWiseOperation::kMAX)->getOutput(0); - norm = ctx->network()->addUnary(*norm, nvinfer1::UnaryOperation::kSQRT)->getOutput(0); - auto div_layer = ctx->network()->addElementWise(*input_tensor, *norm, nvinfer1::ElementWiseOperation::kDIV); - - nvinfer1::ITensor *out_tensor = div_layer->getOutput(0); - ctx->RegisterTensor(ITensorHelper{out_tensor, Format::NCHW, true}, out_tensors_[0].Name()); - this->layer_ = div_layer; - return RET_OK; -} -REGISTER_TENSORRT_CREATOR(ops::kNameL2NormalizeFusion, L2NormalizationTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/l2normalization_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/l2normalization_tensorrt.h deleted file mode 100644 index 3af5b5c1c1bfb79165328f041aac07eabcc9c613..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/l2normalization_tensorrt.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_L2NORMALIZATION_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_L2NORMALIZATION_TENSORRT_H_ -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" - -namespace mindspore::lite { -class L2NormalizationTensorRT : public TensorRTOp { - public: - L2NormalizationTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~L2NormalizationTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; - - private: - std::vector zeros_; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_L2NORMALIZATION_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/lessorequal_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/lessorequal_tensorrt.cc deleted file mode 100644 index 742cfcf1013fe9999099c4334ec11f330bd4ec2b..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/lessorequal_tensorrt.cc +++ /dev/null @@ -1,128 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/lessorequal_tensorrt.h" -#include -#include -#include -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "NvInferRuntimeCommon.h" -#include "src/extendrt/delegate/tensorrt/cuda_impl/logical.cuh" -#include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h" - -namespace mindspore::lite { -int LessorequalTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != INPUT_SIZE2) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); - return RET_ERROR; - } - return RET_OK; -} - -int LessorequalTensorRT::AddInnerOp(TensorRTContext *ctx) { - if (ctx == nullptr || ctx->network() == nullptr) { - MS_LOG(ERROR) << "network or input tensor is invalid"; - return RET_ERROR; - } - nvinfer1::ITensor *inputTensors[] = {input(ctx, 0).trt_tensor_, input(ctx, 1).trt_tensor_}; - auto plugin = std::make_shared(op_name_, schema::PrimitiveType_LessEqual); - if (plugin == nullptr) { - MS_LOG(ERROR) << "create LessorequalPlugin failed for " << op_name_; - return RET_ERROR; - } - nvinfer1::IPluginV2Layer *lessorequal_layer = ctx->network()->addPluginV2(inputTensors, 2, *plugin); - this->layer_ = lessorequal_layer; - nvinfer1::ITensor *op_out_tensor = lessorequal_layer->getOutput(0); - if (op_out_tensor == nullptr) { - MS_LOG(ERROR) << "lessorequal out tensor is nullptr."; - return RET_ERROR; - } - ctx->RegisterTensor(ITensorHelper{op_out_tensor, input(ctx, 0).format_, input(ctx, 0).same_format_}, - out_tensors_[0].Name()); - return RET_OK; -} - -REGISTER_TENSORRT_PLUGIN(LessorequalPluginCreater); -template class TensorRTPluginCreater; -template -nvinfer1::PluginFieldCollection TensorRTPluginCreater::field_collection_{}; -template -std::vector TensorRTPluginCreater::fields_; - -int LessorequalPlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workspace, cudaStream_t stream) noexcept { - return RunCudaLessorequal(inputDesc, inputs, outputs, stream); -} - -int LessorequalPlugin::RunCudaLessorequal(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, - void *const *outputs, cudaStream_t stream) { - if (inputDesc->type == nvinfer1::DataType::kINT32) { - LessOrEqual(static_cast(inputs[0]), static_cast(inputs[1]), - static_cast(outputs[0]), GetDimsVolume(inputDesc[0].dims), stream); - } else if (inputDesc->type == nvinfer1::DataType::kFLOAT) { - LessOrEqual(static_cast(inputs[0]), static_cast(inputs[1]), - static_cast(outputs[0]), GetDimsVolume(inputDesc[0].dims), stream); - } else { - MS_LOG(ERROR) << "unsupported equal data type"; - return RET_ERROR; - } - return RET_OK; -} - -nvinfer1::IPluginV2DynamicExt *LessorequalPlugin::clone() const noexcept { - auto *plugin = new (std::nothrow) LessorequalPlugin(*this); - if (plugin == nullptr) { - MS_LOG(ERROR) << "malloc lessorequal plugin failed"; - return nullptr; - } - plugin->setPluginNamespace(name_space_.c_str()); - return plugin; -} - -bool LessorequalPlugin::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs, - int nbOutputs) noexcept { - if (tensorsDesc[pos].format != nvinfer1::TensorFormat::kLINEAR) { - return false; - } - if (pos == 0) { - return tensorsDesc[pos].type == nvinfer1::DataType::kINT32 || tensorsDesc[pos].type == nvinfer1::DataType::kFLOAT; - } - if (pos < nbInputs) { - return tensorsDesc[pos].type == tensorsDesc[pos - 1].type; - } - if (pos < nbInputs + nbOutputs) { - return tensorsDesc[pos].type == nvinfer1::DataType::kINT32; - } - return false; -} - -size_t LessorequalPlugin::getSerializationSize() const noexcept { return sizeof(schema::PrimitiveType); } - -void LessorequalPlugin::serialize(void *buffer) const noexcept { - SerializeValue(&buffer, &primitive_type_, sizeof(schema::PrimitiveType)); -} -REGISTER_TENSORRT_CREATOR(ops::kNameLessEqual, LessorequalTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/lessorequal_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/lessorequal_tensorrt.h deleted file mode 100644 index a3273dd3b099cb927a1ed279aea52509848c2ef4..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/lessorequal_tensorrt.h +++ /dev/null @@ -1,83 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_LESSOREQUAL_PLUGIN_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_LESSOREQUAL_PLUGIN_H_ - -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" -#include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h" - -namespace mindspore::lite { -class LessorequalTensorRT : public TensorRTOp { - public: - LessorequalTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~LessorequalTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; -}; - -constexpr auto LESSOREQUAL_PLUGIN_NAME{"LessorequalPlugin"}; -class LessorequalPlugin : public TensorRTPlugin { - public: - LessorequalPlugin(const std::string name, schema::PrimitiveType primitive_type) - : TensorRTPlugin(name, std::string(LESSOREQUAL_PLUGIN_NAME)), primitive_type_(primitive_type) {} - - LessorequalPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - : TensorRTPlugin(std::string(name), std::string(LESSOREQUAL_PLUGIN_NAME)) { - const nvinfer1::PluginField *fields = fc->fields; - primitive_type_ = static_cast(fields[0].data)[0]; - } - - LessorequalPlugin(const char *name, const void *serialData, size_t serialLength) - : TensorRTPlugin(std::string(name), std::string(LESSOREQUAL_PLUGIN_NAME)) { - DeserializeValue(&serialData, &serialLength, &primitive_type_, sizeof(schema::PrimitiveType)); - } - - LessorequalPlugin() = delete; - - nvinfer1::IPluginV2DynamicExt *clone() const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override; - size_t getSerializationSize() const noexcept override; - void serialize(void *buffer) const noexcept override; - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs, - int nbOutputs) noexcept override; - nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, int nbInputs) const - noexcept override { - return nvinfer1::DataType::kINT32; - } - - private: - int RunCudaLessorequal(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, void *const *outputs, - cudaStream_t stream); - const std::string layer_name_; - std::string name_space_; - schema::PrimitiveType primitive_type_; -}; -class LessorequalPluginCreater : public TensorRTPluginCreater { - public: - LessorequalPluginCreater() : TensorRTPluginCreater(std::string(LESSOREQUAL_PLUGIN_NAME)) {} -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_LESSOREQUAL_PLUGIN_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/logical_not_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/logical_not_tensorrt.cc deleted file mode 100644 index 8692db3cbf1787d3dcf1f2dfff29296a3bdcb8b4..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/logical_not_tensorrt.cc +++ /dev/null @@ -1,126 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/logical_not_tensorrt.h" -#include -#include -#include -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "NvInferRuntimeCommon.h" -#include "src/extendrt/delegate/tensorrt/cuda_impl/logical.cuh" -#include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h" - -namespace mindspore::lite { -int LogicalNotTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (!IsShapeKnown()) { - MS_LOG(ERROR) << "Unsupported input tensor unknown shape: " << op_name_; - return RET_ERROR; - } - if (in_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); - } - return RET_OK; -} - -int LogicalNotTensorRT::AddInnerOp(TensorRTContext *ctx) { - if (ctx == nullptr || ctx->network() == nullptr || this->in_tensors_.size() != 1) { - MS_LOG(ERROR) << "network or input tensor is invalid"; - return RET_ERROR; - } - ITensorHelper input_helper = input(ctx, 0); - if (input_helper.trt_tensor_->getType() != nvinfer1::DataType::kINT32) { - auto cast_layer = ctx->network()->addIdentity(*input_helper.trt_tensor_); - if (cast_layer == nullptr) { - MS_LOG(ERROR) << "create cast layer failed for: " << op_name_; - return RET_ERROR; - } - cast_layer->setOutputType(0, nvinfer1::DataType::kINT32); - input_helper.trt_tensor_ = cast_layer->getOutput(0); - } - auto plugin = std::make_shared(op_name_, schema::PrimitiveType_LogicalNot); - if (plugin == nullptr) { - MS_LOG(ERROR) << "create ActivationOptPlugin failed for " << op_name_; - return RET_ERROR; - } - nvinfer1::ITensor *inputTensors[] = {input_helper.trt_tensor_}; - nvinfer1::IPluginV2Layer *logical_layer = ctx->network()->addPluginV2(inputTensors, 1, *plugin); - this->layer_ = logical_layer; - nvinfer1::ITensor *op_out_tensor = logical_layer->getOutput(0); - if (op_out_tensor == nullptr) { - MS_LOG(ERROR) << "addElementWise out tensor is nullptr."; - return RET_ERROR; - } - ctx->RegisterTensor(ITensorHelper{op_out_tensor, input_helper.format_, input_helper.same_format_}, - out_tensors_[0].Name()); - return RET_OK; -} - -REGISTER_TENSORRT_PLUGIN(LogicalNotPluginCreater); -template class TensorRTPluginCreater; -template -nvinfer1::PluginFieldCollection TensorRTPluginCreater::field_collection_{}; -template -std::vector TensorRTPluginCreater::fields_; - -int LogicalNotPlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, - cudaStream_t stream) noexcept { - return RunCudaLogical(inputDesc, inputs, outputs, stream); -} - -int LogicalNotPlugin::RunCudaLogical(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, - void *const *outputs, cudaStream_t stream) { - switch (primitive_type_) { - case (schema::PrimitiveType_LogicalNot): { - LogicalNot(static_cast(inputs[0]), static_cast(outputs[0]), GetDimsVolume(inputDesc[0].dims), - stream); - break; - } - default: { - MS_LOG(ERROR) << "invalid logical type: " << static_cast(primitive_type_); - return RET_ERROR; - } - } - return RET_OK; -} - -nvinfer1::IPluginV2DynamicExt *LogicalNotPlugin::clone() const noexcept { - auto *plugin = new LogicalNotPlugin(*this); - plugin->setPluginNamespace(name_space_.c_str()); - return plugin; -} - -bool LogicalNotPlugin::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs, - int nbOutputs) noexcept { - return tensorsDesc[pos].format == nvinfer1::TensorFormat::kLINEAR && - tensorsDesc[pos].type == nvinfer1::DataType::kINT32; -} - -size_t LogicalNotPlugin::getSerializationSize() const noexcept { return sizeof(schema::PrimitiveType); } - -void LogicalNotPlugin::serialize(void *buffer) const noexcept { - SerializeValue(&buffer, &primitive_type_, sizeof(schema::PrimitiveType)); -} -REGISTER_TENSORRT_CREATOR(ops::kNameLogicalNot, LogicalNotTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/logical_not_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/logical_not_tensorrt.h deleted file mode 100644 index 45aad4c96ba82fc92dd80a5116611ebc26ccc67c..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/logical_not_tensorrt.h +++ /dev/null @@ -1,79 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_LOGICAL_NOT_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_LOGICAL_NOT_TENSORRT_H_ - -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h" -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" - -namespace mindspore::lite { -class LogicalNotTensorRT : public TensorRTOp { - public: - LogicalNotTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~LogicalNotTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; -}; - -constexpr auto LOGICAL_NOT_PLUGIN_NAME{"LogicalNotPlugin"}; -class LogicalNotPlugin : public TensorRTPlugin { - public: - LogicalNotPlugin(const std::string name, schema::PrimitiveType primitive_type) - : TensorRTPlugin(name, std::string(LOGICAL_NOT_PLUGIN_NAME)), primitive_type_(primitive_type) {} - - LogicalNotPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - : TensorRTPlugin(std::string(name), std::string(LOGICAL_NOT_PLUGIN_NAME)) { - const nvinfer1::PluginField *fields = fc->fields; - primitive_type_ = static_cast(fields[0].data)[0]; - } - - LogicalNotPlugin(const char *name, const void *serialData, size_t serialLength) - : TensorRTPlugin(std::string(name), std::string(LOGICAL_NOT_PLUGIN_NAME)) { - DeserializeValue(&serialData, &serialLength, &primitive_type_, sizeof(schema::PrimitiveType)); - } - - LogicalNotPlugin() = delete; - - nvinfer1::IPluginV2DynamicExt *clone() const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override; - size_t getSerializationSize() const noexcept override; - void serialize(void *buffer) const noexcept override; - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs, - int nbOutputs) noexcept override; - - private: - int RunCudaLogical(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, void *const *outputs, - cudaStream_t stream); - const std::string layer_name_; - std::string name_space_; - schema::PrimitiveType primitive_type_; -}; -class LogicalNotPluginCreater : public TensorRTPluginCreater { - public: - LogicalNotPluginCreater() : TensorRTPluginCreater(std::string(LOGICAL_NOT_PLUGIN_NAME)) {} -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_LOGICAL_NOT_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/logical_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/logical_tensorrt.cc deleted file mode 100644 index e293d02a834d7bdebc8b95e895ea90e11014806e..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/logical_tensorrt.cc +++ /dev/null @@ -1,155 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/logical_tensorrt.h" -#include -#include -#include -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "NvInferRuntimeCommon.h" -#include "src/extendrt/delegate/tensorrt/cuda_impl/logical.cuh" -#include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h" - -namespace mindspore::lite { -int LogicalTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (!IsShapeKnown()) { - MS_LOG(ERROR) << "Unsupported input tensor unknown shape: " << op_name_; - return RET_ERROR; - } - if (in_tensors.size() != INPUT_SIZE2) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); - return RET_ERROR; - } - return RET_OK; -} - -int LogicalTensorRT::AddInnerOp(TensorRTContext *ctx) { - if (ctx == nullptr || ctx->network() == nullptr) { - MS_LOG(ERROR) << "network or input tensor is invalid"; - return RET_ERROR; - } - nvinfer1::ITensor *inputTensors[] = {input(ctx, 0).trt_tensor_, input(ctx, 1).trt_tensor_}; - for (size_t i = 0; i != in_tensors_.size(); ++i) { - ITensorHelper input_helper = input(ctx, i); - if (input_helper.trt_tensor_->getType() != nvinfer1::DataType::kINT32) { - auto cast_layer = ctx->network()->addIdentity(*input_helper.trt_tensor_); - if (cast_layer == nullptr) { - MS_LOG(ERROR) << "create cast layer failed for: " << op_name_; - return RET_ERROR; - } - cast_layer->setOutputType(0, nvinfer1::DataType::kINT32); - inputTensors[i] = cast_layer->getOutput(0); - } - } - schema::PrimitiveType schema_type = schema::PrimitiveType_NONE; - if (type_ == ops::kNameLogicalAnd) { - schema_type = schema::PrimitiveType_LogicalAnd; - } else if (type_ == ops::kNameLogicalOr) { - schema_type = schema::PrimitiveType_LogicalOr; - } - - auto plugin = std::make_shared(op_name_, schema_type); - if (plugin == nullptr) { - MS_LOG(ERROR) << "create ActivationOptPlugin failed for " << op_name_; - return RET_ERROR; - } - nvinfer1::IPluginV2Layer *logical_layer = ctx->network()->addPluginV2(inputTensors, 2, *plugin); - this->layer_ = logical_layer; - nvinfer1::ITensor *op_out_tensor = logical_layer->getOutput(0); - if (op_out_tensor == nullptr) { - MS_LOG(ERROR) << "addElementWise out tensor is nullptr."; - return RET_ERROR; - } - ctx->RegisterTensor(ITensorHelper{op_out_tensor, input(ctx, 0).format_, input(ctx, 0).same_format_}, - out_tensors_[0].Name()); - return RET_OK; -} - -REGISTER_TENSORRT_PLUGIN(LogicalPluginCreater); -template class TensorRTPluginCreater; -template -nvinfer1::PluginFieldCollection TensorRTPluginCreater::field_collection_{}; -template -std::vector TensorRTPluginCreater::fields_; - -int LogicalPlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, - cudaStream_t stream) noexcept { - return RunCudaLogical(inputDesc, inputs, outputs, stream); -} - -int LogicalPlugin::RunCudaLogical(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, - void *const *outputs, cudaStream_t stream) { - switch (primitive_type_) { - case (schema::PrimitiveType_LogicalAnd): { - LogicalAnd(static_cast(inputs[0]), static_cast(inputs[1]), - static_cast(outputs[0]), GetDimsVolume(inputDesc[0].dims), stream); - break; - } - case (schema::PrimitiveType_LogicalOr): { - LogicalOr(static_cast(inputs[0]), static_cast(inputs[1]), - static_cast(outputs[0]), GetDimsVolume(inputDesc[0].dims), stream); - break; - } - default: { - MS_LOG(ERROR) << "invalid logical type: " << static_cast(primitive_type_); - return RET_ERROR; - } - } - return RET_OK; -} - -nvinfer1::IPluginV2DynamicExt *LogicalPlugin::clone() const noexcept { - auto *plugin = new LogicalPlugin(*this); - plugin->setPluginNamespace(name_space_.c_str()); - return plugin; -} - -bool LogicalPlugin::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs, - int nbOutputs) noexcept { - if (tensorsDesc[pos].format != nvinfer1::TensorFormat::kLINEAR) { - return false; - } - if (pos == 0) { - return tensorsDesc[pos].type == nvinfer1::DataType::kINT32; - } - if (pos < nbInputs) { - return tensorsDesc[pos].type == tensorsDesc[pos - 1].type; - } - if (pos < nbInputs + nbOutputs) { - return tensorsDesc[pos].type == nvinfer1::DataType::kINT32; - } - return false; -} - -size_t LogicalPlugin::getSerializationSize() const noexcept { return sizeof(schema::PrimitiveType); } - -void LogicalPlugin::serialize(void *buffer) const noexcept { - SerializeValue(&buffer, &primitive_type_, sizeof(schema::PrimitiveType)); -} - -REGISTER_TENSORRT_CREATOR(ops::kNameLogicalOr, LogicalTensorRT) -REGISTER_TENSORRT_CREATOR(ops::kNameLogicalAnd, LogicalTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/logical_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/logical_tensorrt.h deleted file mode 100644 index 61721bc48113263e44ede1e979ba1e72fcb4093c..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/logical_tensorrt.h +++ /dev/null @@ -1,83 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_LOGICAL_PLUGIN_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_LOGICAL_PLUGIN_H_ - -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" -#include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h" - -namespace mindspore::lite { -class LogicalTensorRT : public TensorRTOp { - public: - LogicalTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~LogicalTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; -}; - -constexpr auto LOGICAL_PLUGIN_NAME{"LogicalPlugin"}; -class LogicalPlugin : public TensorRTPlugin { - public: - LogicalPlugin(const std::string name, schema::PrimitiveType primitive_type) - : TensorRTPlugin(name, std::string(LOGICAL_PLUGIN_NAME)), primitive_type_(primitive_type) {} - - LogicalPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - : TensorRTPlugin(std::string(name), std::string(LOGICAL_PLUGIN_NAME)) { - const nvinfer1::PluginField *fields = fc->fields; - primitive_type_ = static_cast(fields[0].data)[0]; - } - - LogicalPlugin(const char *name, const void *serialData, size_t serialLength) - : TensorRTPlugin(std::string(name), std::string(LOGICAL_PLUGIN_NAME)) { - DeserializeValue(&serialData, &serialLength, &primitive_type_, sizeof(schema::PrimitiveType)); - } - - LogicalPlugin() = delete; - - nvinfer1::IPluginV2DynamicExt *clone() const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override; - size_t getSerializationSize() const noexcept override; - void serialize(void *buffer) const noexcept override; - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs, - int nbOutputs) noexcept override; - nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, int nbInputs) const - noexcept override { - return nvinfer1::DataType::kINT32; - } - - private: - int RunCudaLogical(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, void *const *outputs, - cudaStream_t stream); - const std::string layer_name_; - std::string name_space_; - schema::PrimitiveType primitive_type_; -}; -class LogicalPluginCreater : public TensorRTPluginCreater { - public: - LogicalPluginCreater() : TensorRTPluginCreater(std::string(LOGICAL_PLUGIN_NAME)) {} -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_LOGICAL_PLUGIN_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/lstm_plugin.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/lstm_plugin.cc deleted file mode 100644 index f4f9bad98879adabd8a508903ecc691be9e730e8..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/lstm_plugin.cc +++ /dev/null @@ -1,172 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/lstm_plugin.h" -#include -#include -#include -#include -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "NvInferRuntimeCommon.h" -#include "kernel/gpu/cuda_impl/cuda_ops/swish_impl.cuh" - -namespace mindspore::lite { -REGISTER_TENSORRT_PLUGIN(LSTMPluginCreater); -template class TensorRTPluginCreater; -template -nvinfer1::PluginFieldCollection TensorRTPluginCreater::field_collection_{}; -template -std::vector TensorRTPluginCreater::fields_; - -int LSTMPlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, - cudaStream_t stream) noexcept { - cudnnRNNDataDescriptor_t xdesc; - cudnnRNNDataDescriptor_t ydesc; - - cudnnTensorDescriptor_t hdesc; - cudnnTensorDescriptor_t cdesc; - - cudnnRNNDataLayout_t layout{CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED}; - - int *seq_length_array; - int *dev_seq_lenghts; - size_t weight_space_size; - size_t reserve_space_size{0}; - size_t workspace_size; - void *d_workspace; - void *d_reserve_space{nullptr}; - - // lstm type - cudnnRNNMode_t cell_mode{CUDNN_LSTM}; - cudnnRNNBiasMode_t bias_mode{CUDNN_RNN_DOUBLE_BIAS}; - cudnnDirectionMode_t direction_mode{CUDNN_UNIDIRECTIONAL}; - cudnnRNNInputMode_t input_mode{CUDNN_LINEAR_INPUT}; - cudnnForwardMode_t fwd_mode{CUDNN_FWD_MODE_INFERENCE}; - - cudnnRNNDescriptor_t rnn_desc; - cudnnDropoutDescriptor_t dropout_desc; - - cudnnHandle_t cudnn_handle; - cudnnDataType_t data_type{CUDNN_DATA_FLOAT}; - cudnnDataType_t math_precison{CUDNN_DATA_FLOAT}; - cudnnMathType_t math_type{CUDNN_DEFAULT_MATH}; - cudnnRNNAlgo_t rnn_algo{CUDNN_RNN_ALGO_STANDARD}; - CUDNN_CHECK(cudnnCreate(&cudnn_handle)); - CUDNN_CHECK(cudnnSetStream(cudnn_handle, reinterpret_cast(stream))); - CUDNN_CHECK(cudnnCreateRNNDataDescriptor(&xdesc)); - CUDNN_CHECK(cudnnCreateRNNDataDescriptor(&ydesc)); - CUDNN_CHECK(cudnnCreateTensorDescriptor(&hdesc)); - CUDNN_CHECK(cudnnCreateTensorDescriptor(&cdesc)); - CUDNN_CHECK(cudnnCreateDropoutDescriptor(&dropout_desc)); - CUDNN_CHECK(cudnnCreateRNNDescriptor(&rnn_desc)); - seq_length_array = new int[batch_size_]; - for (int i = 0; i < batch_size_; ++i) { - seq_length_array[i] = seq_len_; - } - cudaMalloc(&dev_seq_lenghts, sizeof(int) * batch_size_); - cudaMemcpy(dev_seq_lenghts, seq_length_array, sizeof(int) * batch_size_, cudaMemcpyHostToDevice); - - CUDNN_CHECK( - cudnnSetRNNDataDescriptor(xdesc, data_type, layout, seq_len_, batch_size_, input_size_, seq_length_array, nullptr)); - CUDNN_CHECK(cudnnSetRNNDataDescriptor(ydesc, data_type, layout, seq_len_, batch_size_, hidden_size_, seq_length_array, - nullptr)); - constexpr int kDims = 3; - int dim[kDims]; - int stride[kDims]; - - dim[0] = num_layers_; - dim[1] = batch_size_; - dim[INPUT_SIZE2] = hidden_size_; - - stride[0] = dim[INPUT_SIZE2] * dim[1]; - stride[1] = dim[INPUT_SIZE2]; - stride[INPUT_SIZE2] = 1; - - CUDNN_CHECK(cudnnSetTensorNdDescriptor(hdesc, data_type, kDims, dim, stride)); - CUDNN_CHECK(cudnnSetTensorNdDescriptor(cdesc, data_type, kDims, dim, stride)); - - CUDNN_CHECK(cudnnSetDropoutDescriptor(dropout_desc, cudnn_handle, 0, nullptr, 0, 1)); - - CUDNN_CHECK(cudnnSetRNNDescriptor_v8(rnn_desc, rnn_algo, cell_mode, bias_mode, direction_mode, input_mode, data_type, - math_precison, math_type, input_size_, hidden_size_, hidden_size_, num_layers_, - dropout_desc, 0)); - - // Set up weights and bias parameters - CUDNN_CHECK(cudnnGetRNNWeightSpaceSize(cudnn_handle, rnn_desc, &weight_space_size)); - - // Set up work space and reserved memory - CUDNN_CHECK(cudnnGetRNNTempSpaceSizes(cudnn_handle, rnn_desc, fwd_mode, xdesc, &workspace_size, &reserve_space_size)); - - if (workspace_size > 0) { - cudaMalloc(reinterpret_cast(&d_workspace), workspace_size); - } - if (reserve_space_size > 0) { - cudaMalloc(reinterpret_cast(&d_reserve_space), reserve_space_size); - } - auto x_addr = static_cast(inputs[0]); - auto hx_addr = static_cast(inputs[1]); - auto cx_addr = static_cast(inputs[INPUT_SIZE2]); - auto w_addr = static_cast(inputs[INPUT_SIZE3]); - - auto y_addr = static_cast(outputs[0]); - - CUDNN_CHECK(cudnnRNNForward(cudnn_handle, rnn_desc, fwd_mode, dev_seq_lenghts, xdesc, x_addr, ydesc, y_addr, hdesc, - hx_addr, nullptr, cdesc, cx_addr, nullptr, weight_space_size, w_addr, workspace_size, - d_workspace, reserve_space_size, d_workspace)); - return RET_OK; -} - -nvinfer1::DimsExprs LSTMPlugin::getOutputDimensions(int32_t index, const nvinfer1::DimsExprs *inputs, int nbInputDims, - nvinfer1::IExprBuilder &exprBuilder) noexcept { - nvinfer1::DimsExprs dims; - if (index == 0) { - dims = inputs[0]; - dims.d[INPUT_SIZE2] = exprBuilder.constant(hidden_size_); - } - return dims; -} - -nvinfer1::DataType LSTMPlugin::getOutputDataType(int index, const nvinfer1::DataType *inputTypes, int nbInputs) const - noexcept { - return nvinfer1::DataType::kFLOAT; -} - -nvinfer1::IPluginV2DynamicExt *LSTMPlugin::clone() const noexcept { - auto *plugin = new LSTMPlugin(layer_name_, num_layers_, batch_size_, seq_len_, input_size_, hidden_size_, dropout_, - has_bias_, bidirectional_, device_id_); - plugin->setPluginNamespace(name_space_.c_str()); - return plugin; -} - -size_t LSTMPlugin::getSerializationSize() const noexcept { - return sizeof(int) * INPUT_SIZE5 + sizeof(float) + sizeof(bool) * INPUT_SIZE2; -} - -void LSTMPlugin::serialize(void *buffer) const noexcept { - SerializeValue(&buffer, &num_layers_, sizeof(int)); - SerializeValue(&buffer, &batch_size_, sizeof(int)); - SerializeValue(&buffer, &seq_len_, sizeof(int)); - SerializeValue(&buffer, &input_size_, sizeof(int)); - SerializeValue(&buffer, &hidden_size_, sizeof(int)); - SerializeValue(&buffer, &dropout_, sizeof(float)); - SerializeValue(&buffer, &has_bias_, sizeof(bool)); - SerializeValue(&buffer, &bidirectional_, sizeof(bool)); -} -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/lstm_plugin.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/lstm_plugin.h deleted file mode 100644 index 518104becfa72df20139aadea4715acf4ceb29e0..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/lstm_plugin.h +++ /dev/null @@ -1,104 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_ACTIVATION_OPT_PLUGIN_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_ACTIVATION_OPT_PLUGIN_H_ - -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" -#include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h" -#include "src/extendrt/delegate/tensorrt/cuda_impl/cudnn_utils.h" - -namespace mindspore::lite { -constexpr auto LSTM_PLUGIN_NAME{"LSTMPlugin"}; -constexpr size_t kInputShapeSize = 3; -class LSTMPlugin : public TensorRTPlugin { - public: - LSTMPlugin(const std::string &name, int num_layers, int batch_size, int seq_len, int input_size, int hidden_size, - float dropout, bool has_bias, bool bidirectional, uint32_t device_id) - : TensorRTPlugin(name, std::string(LSTM_PLUGIN_NAME), device_id), - num_layers_(num_layers), - batch_size_(batch_size), - seq_len_(seq_len), - input_size_(input_size), - hidden_size_(hidden_size), - dropout_(dropout), - has_bias_(has_bias), - bidirectional_(bidirectional) {} - - LSTMPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - : TensorRTPlugin(std::string(name), std::string(LSTM_PLUGIN_NAME)) { - const nvinfer1::PluginField *fields = fc->fields; - num_layers_ = static_cast(fields[0].data)[0]; - batch_size_ = static_cast(fields[1].data)[0]; - seq_len_ = static_cast(fields[INPUT_SIZE2].data)[0]; - input_size_ = static_cast(fields[INPUT_SIZE3].data)[0]; - hidden_size_ = static_cast(fields[INPUT_SIZE4].data)[0]; - dropout_ = static_cast(fields[INPUT_SIZE5].data)[0]; - has_bias_ = static_cast(fields[INPUT_SIZE6].data)[0]; - bidirectional_ = static_cast(fields[INPUT_SIZE7].data)[0]; - } - - LSTMPlugin(const char *name, const void *serialData, size_t serialLength) - : TensorRTPlugin(std::string(name), std::string(LSTM_PLUGIN_NAME)) { - DeserializeValue(&serialData, &serialLength, &num_layers_, sizeof(int)); - DeserializeValue(&serialData, &serialLength, &batch_size_, sizeof(int)); - DeserializeValue(&serialData, &serialLength, &seq_len_, sizeof(int)); - DeserializeValue(&serialData, &serialLength, &input_size_, sizeof(int)); - DeserializeValue(&serialData, &serialLength, &hidden_size_, sizeof(int)); - DeserializeValue(&serialData, &serialLength, &dropout_, sizeof(float)); - DeserializeValue(&serialData, &serialLength, &has_bias_, sizeof(bool)); - DeserializeValue(&serialData, &serialLength, &bidirectional_, sizeof(bool)); - } - - LSTMPlugin() = delete; - - ~LSTMPlugin() {} - - nvinfer1::IPluginV2DynamicExt *clone() const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override; - size_t getSerializationSize() const noexcept override; - void serialize(void *buffer) const noexcept override; - nvinfer1::DimsExprs getOutputDimensions(int32_t index, const nvinfer1::DimsExprs *inputs, int nbInputDims, - nvinfer1::IExprBuilder &exprBuilder) noexcept override; - nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, int nbInputs) const - noexcept override; - int getNbOutputs() const noexcept override { return 1; } - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs, - int nbOutputs) noexcept override { - return tensorsDesc[pos].format == nvinfer1::TensorFormat::kLINEAR && - tensorsDesc[pos].type == nvinfer1::DataType::kFLOAT; - } - - private: - int num_layers_; - int batch_size_; - int seq_len_; - int input_size_; - int hidden_size_; - - float dropout_; - bool has_bias_; - bool bidirectional_; -}; -class LSTMPluginCreater : public TensorRTPluginCreater { - public: - LSTMPluginCreater() : TensorRTPluginCreater(std::string(LSTM_PLUGIN_NAME)) {} -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_ACTIVATION_OPT_PLUGIN_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/lstm_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/lstm_tensorrt.cc deleted file mode 100644 index 994c3b3dbe0f59a38c5c863f1c8eea9cb64f16e0..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/lstm_tensorrt.cc +++ /dev/null @@ -1,539 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "src/extendrt/delegate/tensorrt/op/lstm_tensorrt.h" -#include "src/extendrt/delegate/tensorrt/op/lstm_plugin.h" -#include "src/extendrt/delegate/tensorrt/tensorrt_runtime.h" -#include "infer/lstm.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h" - -namespace mindspore::lite { -int LSTMTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { -#if TRT_VERSION_GE(7, 0) - if (in_tensors.size() != INPUT_TENSOR_SIZE && in_tensors.size() != INPUT_SIZE4) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != OUTPUT_TENSOR_SIZE && out_tensors.size() != INPUT_SIZE5) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); - return RET_ERROR; - } - dynamic_shape_params_.support_dynamic_ = false; - dynamic_shape_params_.support_hw_dynamic_ = false; - return RET_OK; -#else - MS_LOG(WARNING) << "low TensorRT version don't support LSTM op, please upgrade TensorRT version to 7 or higher"; - return RET_ERROR; -#endif -} - -int LSTMTensorRT::AddInnerOp(TensorRTContext *ctx) { - if (ctx == nullptr || ctx->network() == nullptr) { - MS_LOG(ERROR) << "context or network is invalid"; - return RET_ERROR; - } - int input_data_dims_cnt = input(ctx, 0).trt_tensor_->getDimensions().nbDims; - if (input_data_dims_cnt != DIMENSION_3D) { - MS_LOG(ERROR) << "invalid input data shape dims for " << op_name_; - return RET_ERROR; - } - int ret = PreProcess(ctx); - if (ret != RET_OK) { - MS_LOG(ERROR) << "PreProcess for " << op_name_; - return ret; - } - if (in_tensors_.size() == INPUT_SIZE4) { - return RunLSTMPlugin(ctx); - } - - ret = AddLSTMLayers(ctx); - if (ret != RET_OK) { - MS_LOG(ERROR) << "AddLSTMLayers for " << op_name_; - return RET_ERROR; - } - - if (op_data_out_ == nullptr) { - MS_LOG(ERROR) << "layers final output tensor is invalid for " << op_name_; - return RET_ERROR; - } - MS_LOG(DEBUG) << "lstm op_data_out_ " << GetTensorFormat(op_data_out_); - MS_LOG(DEBUG) << "lstm op_hidden_out_ " << GetTensorFormat(op_hidden_out_); - MS_LOG(DEBUG) << "lstm op_cell_out_ " << GetTensorFormat(op_cell_out_); - ctx->RegisterTensor(ITensorHelper{op_data_out_}, out_tensors_[0].Name()); - ctx->RegisterTensor(ITensorHelper{op_hidden_out_}, out_tensors_[OUTPUT_HIDDEN_INDEX].Name()); - ctx->RegisterTensor(ITensorHelper{op_cell_out_}, out_tensors_[OUTPUT_CELL_INDEX].Name()); - return RET_OK; -} - -int LSTMTensorRT::RunLSTMPlugin(TensorRTContext *ctx) { - auto lstm_op = AsOps(); - if (params_.directional_cnt_ == BIDIRECTIONAL) { - MS_LOG(ERROR) << "mindir lstm with bidirectional not support yet"; - return RET_ERROR; - } - auto plugin = std::make_shared(op_name_, params_.layer_count_, params_.batch_size_, - params_.sequence_size_, params_.input_data_size_, params_.hidden_size_, - 0.f, lstm_op->get_has_bias(), false, device_id_); - CHECK_NULL_RETURN(plugin); - nvinfer1::ITensor *inputTensors[] = {input(ctx, 0).trt_tensor_, input(ctx, 1).trt_tensor_, - input(ctx, INPUT_SIZE2).trt_tensor_, input(ctx, INPUT_SIZE3).trt_tensor_}; - nvinfer1::IPluginV2Layer *lstm_layer = ctx->network()->addPluginV2(inputTensors, INPUT_SIZE4, *plugin); - CHECK_NULL_RETURN(lstm_layer); - ctx->RegisterTensor(ITensorHelper{lstm_layer->getOutput(0)}, out_tensors_[0].Name()); - return RET_OK; -} - -int LSTMTensorRT::PreProcess(TensorRTContext *ctx) { - auto ms_input_shape = ConvertMSShape(input(ctx, 0).trt_tensor_->getDimensions()); - params_.sequence_size_ = ms_input_shape[0]; - params_.batch_size_ = ms_input_shape[1]; - params_.input_data_size_ = ms_input_shape[INPUT_SIZE_INDEX]; - if (params_.batch_size_ != 1) { - MS_LOG(WARNING) << op_name_ << " lstm has batchsize " << params_.batch_size_ << ", needs further verify"; - } - auto lstm_op = AsOps(); - params_.layer_count_ = lstm_op->get_num_layers() == 0 ? 1 : lstm_op->get_num_layers(); - params_.hidden_size_ = lstm_op->get_hidden_size(); - params_.directional_cnt_ = lstm_op->get_bidirectional() ? BIDIRECTIONAL : 1; - params_.data_type_ = ConvertDataType(in_tensors_[1].DataType()); - if (in_tensors_.size() == INPUT_SIZE4) { - return RET_OK; - } - // ms: 0 sequence size, 1 batch size, 2 input size -> tensorrt: 0 batch size, 1 sequence size, 2 input size - auto transpose_in_layer = ctx->network()->addShuffle(*input(ctx, 0).trt_tensor_); - if (transpose_in_layer == nullptr) { - MS_LOG(ERROR) << "create transpose_in_layer failed for " << op_name_; - return RET_ERROR; - } - this->layer_ = transpose_in_layer; - nvinfer1::Permutation transpose_perm{{1, 0, INPUT_SIZE_INDEX}}; - transpose_in_layer->setFirstTranspose(transpose_perm); - transpose_in_layer->setName((op_name_ + "transpose_in").c_str()); - input_data_ = transpose_in_layer->getOutput(0); - MS_LOG(DEBUG) << "lstm input " << GetTensorFormat(input_data_); - - return RET_OK; -} - -int LSTMTensorRT::AddLSTMLayers(TensorRTContext *ctx) { - nvinfer1::ITensor *data_out{nullptr}; - nvinfer1::ITensor *hidden_init{nullptr}; - nvinfer1::ITensor *cell_init{nullptr}; - if (in_tensors_[HIDDEN_IN_TENSOR_INIT].Data() != nullptr && in_tensors_[CELL_IN_TENSOR_INIT].Data() != nullptr) { - TensorInfo &hidden_in_init = in_tensors_[HIDDEN_IN_TENSOR_INIT]; - TensorInfo &cell_in_init = in_tensors_[CELL_IN_TENSOR_INIT]; - hidden_init_name_ = hidden_in_init.Name() + "_hidden_init"; - - hidden_init = ctx->network()->addInput( - hidden_init_name_.c_str(), nvinfer1::DataType::kFLOAT, - nvinfer1::Dims3(params_.layer_count_ * params_.directional_cnt_, params_.batch_size_, params_.hidden_size_)); - if (hidden_init == nullptr) { - MS_LOG(ERROR) << "add hidden_init input tensor failed for " << op_name_; - return RET_ERROR; - } - op_binding_tensor_.push_back(BindingHelper{hidden_init_name_, hidden_in_init.MutableData(), - nvinfer1::DataType::kFLOAT, hidden_in_init.DataSize()}); - cell_init_name_ = cell_in_init.Name() + "_cell_init"; - cell_init = ctx->network()->addInput( - cell_init_name_.c_str(), nvinfer1::DataType::kFLOAT, - nvinfer1::Dims3(params_.layer_count_ * params_.directional_cnt_, params_.batch_size_, params_.hidden_size_)); - if (cell_init == nullptr) { - MS_LOG(ERROR) << "add cell_init input tensor failed for " << op_name_; - return RET_ERROR; - } - op_binding_tensor_.push_back( - BindingHelper{cell_init_name_, cell_in_init.MutableData(), nvinfer1::DataType::kFLOAT, cell_in_init.DataSize()}); - - sequence_size_input_ = - ctx->network()->addInput((op_name_ + "_seq_input").c_str(), nvinfer1::DataType::kINT32, nvinfer1::Dims{}); - if (sequence_size_input_ == nullptr) { - MS_LOG(ERROR) << "add sequence_size_input_ input tensor failed for " << op_name_; - return RET_ERROR; - } - op_binding_tensor_.push_back( - BindingHelper{(op_name_ + "_seq_input"), ¶ms_.sequence_size_, nvinfer1::DataType::kINT32, sizeof(int)}); - } else { - hidden_init = input(ctx, HIDDEN_IN_TENSOR_INIT).trt_tensor_; - cell_init = input(ctx, CELL_IN_TENSOR_INIT).trt_tensor_; - sequence_size_input_ = ctx->ConvertTo0DTensor(input(ctx, 0).trt_tensor_->getDimensions().d[0]); - } - - nvinfer1::ITensor *max_sequence_size = - ctx->network() - ->addConstant(nvinfer1::Dims{}, nvinfer1::Weights{nvinfer1::DataType::kINT32, ¶ms_.sequence_size_, 1}) - ->getOutput(0); - if (max_sequence_size == nullptr) { - MS_LOG(ERROR) << "add max_sequence_size constant tensor failed for " << op_name_; - return RET_ERROR; - } - LstmState next_state{input_data_, nullptr, nullptr}; // init states - std::vector hidden_outputs; - std::vector cell_outputs; - int input_weight_offset = 0; - int state_weight_offset = 0; - int bias_offset = 0; - - if (params_.layer_count_ != 1) { - MS_LOG(WARNING) << op_name_ << " needs verify for layer cnt: " << params_.layer_count_; - } - for (int i = 0; i < params_.layer_count_; i++) { - LstmState layer_input_states[BIDIRECTIONAL]; - LstmWeights layer_weights[BIDIRECTIONAL]; - layer_weights[0].max_seq_size_ = max_sequence_size; - int ret = ParseLSTMCellInputs(ctx, i, hidden_init, cell_init, layer_input_states, &input_weight_offset, - &state_weight_offset, &bias_offset, layer_weights, next_state); - if (ret != RET_OK) { - MS_LOG(ERROR) << "ParseLSTMCellInputs failed for " << op_name_; - return RET_ERROR; - } - data_out = AddLSTMCell(ctx, layer_input_states, layer_weights, &next_state); - hidden_outputs.push_back(next_state.hidden_); - cell_outputs.push_back(next_state.cell_); - if (data_out == nullptr || next_state.hidden_ == nullptr || next_state.cell_ == nullptr) { - MS_LOG(ERROR) << "AddLSTMCell failed for " << op_name_; - return RET_ERROR; - } - } - - op_hidden_out_ = ConcateAll(ctx, hidden_outputs); - if (op_hidden_out_ == nullptr) { - MS_LOG(ERROR) << "concat hidden output failed for " << op_name_; - return RET_ERROR; - } - op_cell_out_ = ConcateAll(ctx, cell_outputs); - if (op_cell_out_ == nullptr) { - MS_LOG(ERROR) << "concat cell output failed for " << op_name_; - return RET_ERROR; - } - op_data_out_ = data_out; - return RET_OK; -} - -int LSTMTensorRT::ParseLSTMCellInputs(TensorRTContext *ctx, int layer_index, nvinfer1::ITensor *hidden_init, - nvinfer1::ITensor *cell_init, LstmState *layer_input_states, - int *input_weight_offset, int *state_weight_offset, int *bias_offset, - LstmWeights *layer_weights, const LstmState &next_state) { - nvinfer1::Dims2 dim_input_weight(LSTM_GATE_NUM * params_.hidden_size_, params_.input_data_size_); - nvinfer1::Dims2 dim_state_weight(LSTM_GATE_NUM * params_.hidden_size_, params_.hidden_size_); - nvinfer1::Dims dim_bias{1, {LSTM_GATE_NUM * params_.hidden_size_}}; - - TensorInfo &input_weight = in_tensors_[INPUT_WEIGHT]; - TensorInfo &state_weight = in_tensors_[STATE_WEIGHT]; - TensorInfo &bias = in_tensors_[BIAS]; - - nvinfer1::Dims dimW = layer_index == 0 ? dim_input_weight : dim_state_weight; - - for (int direction_index = 0; direction_index < params_.directional_cnt_; direction_index++) { - nvinfer1::ITensor *index = - ctx->network() - ->addConstant(nvinfer1::Dims{}, - nvinfer1::Weights{nvinfer1::DataType::kINT32, - &INDICES[layer_index * params_.directional_cnt_ + direction_index], 1}) - ->getOutput(0); - MS_ASSERT(index); - layer_input_states[direction_index].data_ = next_state.data_; - layer_input_states[direction_index].hidden_ = ctx->network()->addGather(*hidden_init, *index, 0)->getOutput(0); - layer_input_states[direction_index].cell_ = ctx->network()->addGather(*cell_init, *index, 0)->getOutput(0); - MS_ASSERT(layer_input_states[direction_index].hidden_); - MS_ASSERT(layer_input_states[direction_index].cell_); - - // weight order: input, output, forget, cell - if (params_.data_type_ != nvinfer1::DataType::kFLOAT) { - MS_LOG(WARNING) << "more data type need to be done"; - return RET_ERROR; - } - const float *input_weight_ptr = static_cast(input_weight.Data()); - const float *state_weight_ptr = static_cast(state_weight.Data()); - const float *bias_ptr = static_cast(bias.Data()); - nvinfer1::Weights slice_input_weight{params_.data_type_, input_weight_ptr + *input_weight_offset, - GetDimsVolume(dimW)}; - (*input_weight_offset) += slice_input_weight.count; - nvinfer1::Weights slice_state_weight{params_.data_type_, state_weight_ptr + *state_weight_offset, - GetDimsVolume(dim_state_weight)}; - (*state_weight_offset) += slice_state_weight.count; - layer_weights[direction_index].input_weights_ = ctx->network()->addConstant(dimW, slice_input_weight)->getOutput(0); - layer_weights[direction_index].state_weights_ = - ctx->network()->addConstant(dim_state_weight, slice_state_weight)->getOutput(0); - MS_ASSERT(layer_weights[direction_index].input_weights_); - MS_ASSERT(layer_weights[direction_index].state_weights_); - - // bias - nvinfer1::Weights slice_input_bias{params_.data_type_, bias_ptr + *bias_offset, GetDimsVolume(dim_bias)}; - (*bias_offset) += slice_input_bias.count; - nvinfer1::Weights slice_state_bias{params_.data_type_, bias_ptr + *bias_offset, GetDimsVolume(dim_bias)}; - (*bias_offset) += slice_state_bias.count; - layer_weights[direction_index].input_bias_ = ctx->network()->addConstant(dim_bias, slice_input_bias)->getOutput(0); - layer_weights[direction_index].state_bias_ = ctx->network()->addConstant(dim_bias, slice_state_bias)->getOutput(0); - MS_ASSERT(layer_weights[direction_index].input_bias_); - MS_ASSERT(layer_weights[direction_index].state_bias_); - } - if (params_.directional_cnt_ == BIDIRECTIONAL) { - layer_weights[1].max_seq_size_ = layer_weights[0].max_seq_size_; - } - return RET_OK; -} - -nvinfer1::ITensor *LSTMTensorRT::Reshape(TensorRTContext *ctx, nvinfer1::ITensor *tensor, nvinfer1::Dims dims) { - nvinfer1::IShuffleLayer *shuffle = ctx->network()->addShuffle(*tensor); - shuffle->setReshapeDimensions(dims); - return shuffle->getOutput(0); -} - -nvinfer1::ITensor *LSTMTensorRT::ConcateAll(TensorRTContext *ctx, std::vector all_tensor, - int axis) { - if (all_tensor.size() == 1) { - return all_tensor[0]; - } - nvinfer1::IConcatenationLayer *concat = ctx->network()->addConcatenation(all_tensor.data(), all_tensor.size()); - if (concat == nullptr) { - MS_LOG(ERROR) << "addConcatenation failed for " << op_name_; - return nullptr; - } - if (axis >= all_tensor[0]->getDimensions().nbDims) { - MS_LOG(ERROR) << op_name_ << " concat axis is " << axis << ", larger than tensor dims " - << all_tensor[0]->getDimensions().nbDims; - return nullptr; - } - concat->setAxis(axis); - this->layer_ = concat; - return concat->getOutput(0); -} - -nvinfer1::ITensor *LSTMTensorRT::AddLSTMCell(TensorRTContext *ctx, const LstmState *layer_input_states, - const LstmWeights *layer_weights, LstmState *next_state) { - nvinfer1::ITensor *backward_output = nullptr; - nvinfer1::ITensor *backward_hidden_out = nullptr; - nvinfer1::ITensor *backward_cell_out = nullptr; - nvinfer1::ITensor *forward_hidden_out = nullptr; - nvinfer1::ITensor *forward_cell_out = nullptr; - - nvinfer1::ITensor *forward_output = - AddLSTMCalculation(ctx, layer_input_states[0], layer_weights[0], &forward_hidden_out, &forward_cell_out); - if (params_.directional_cnt_ == BIDIRECTIONAL) { - backward_output = - AddLSTMCalculation(ctx, layer_input_states[1], layer_weights[1], &backward_hidden_out, &backward_cell_out, true); - } - - // concate forward and backward - nvinfer1::ITensor *output_tensor = forward_output; - nvinfer1::ITensor *cell_out = forward_cell_out; - nvinfer1::ITensor *hidden_out = forward_hidden_out; - if (backward_output != nullptr && backward_hidden_out != nullptr && backward_cell_out != nullptr) { - nvinfer1::ITensor *output_concat_input[BIDIRECTIONAL] = {forward_output, backward_output}; - auto ouput_out_layer = ctx->network()->addConcatenation(output_concat_input, BIDIRECTIONAL); - this->layer_ = ouput_out_layer; - if (ouput_out_layer == nullptr) { - MS_LOG(ERROR) << "create one loop output concat failed for " << op_name_; - return nullptr; - } - ouput_out_layer->setAxis(1); // ms: 0 sequence size, 1 layer * direction, 2 batchsize, 3 hidden - output_tensor = ouput_out_layer->getOutput(0); - - nvinfer1::ITensor *hidden_concat_input[BIDIRECTIONAL] = {forward_hidden_out, backward_hidden_out}; - auto hidden_out_layer = ctx->network()->addConcatenation(hidden_concat_input, BIDIRECTIONAL); - hidden_out_layer->setAxis(0); - hidden_out = hidden_out_layer->getOutput(0); - - nvinfer1::ITensor *cell_concat_input[BIDIRECTIONAL] = {forward_cell_out, backward_cell_out}; - auto cell_out_layer = ctx->network()->addConcatenation(cell_concat_input, BIDIRECTIONAL); - cell_out_layer->setAxis(0); - cell_out = cell_out_layer->getOutput(0); - } - if (hidden_out == nullptr || cell_out == nullptr) { - MS_LOG(ERROR) << "get one loop hidden_out and cell_out failed for " << op_name_; - return nullptr; - } - *next_state = LstmState{output_tensor, hidden_out, cell_out}; - return output_tensor; -} -nvinfer1::ITensor *LSTMTensorRT::AddLSTMCalculation(TensorRTContext *ctx, const LstmState &input_state, - const LstmWeights &lstm_weights, nvinfer1::ITensor **hidden_out, - nvinfer1::ITensor **cell_out, bool is_backward) { - std::vector all_batch_outputs; - std::vector all_batch_hidden; - std::vector all_batch_cell; - for (int batch_index = 0; batch_index < params_.batch_size_; batch_index++) { - LstmState one_batch_input_state; - nvinfer1::ITensor *batch_index_tensor = - ctx->network() - ->addConstant(nvinfer1::Dims{}, nvinfer1::Weights{nvinfer1::DataType::kINT32, &INDICES[batch_index], 1}) - ->getOutput(0); - one_batch_input_state.data_ = ctx->network()->addGather(*input_state.data_, *batch_index_tensor, 0)->getOutput(0); - one_batch_input_state.hidden_ = - ctx->network()->addGather(*input_state.hidden_, *batch_index_tensor, 0)->getOutput(0); - one_batch_input_state.cell_ = ctx->network()->addGather(*input_state.cell_, *batch_index_tensor, 0)->getOutput(0); - nvinfer1::ITensor *one_batch_hidden = nullptr; - nvinfer1::ITensor *one_batch_cell = nullptr; - nvinfer1::ITensor *one_batch_output = - AddLSTMOneLoop(ctx, one_batch_input_state, lstm_weights, &one_batch_hidden, &one_batch_cell, is_backward); - if (one_batch_output == nullptr || one_batch_cell == nullptr || one_batch_hidden == nullptr) { - MS_LOG(ERROR) << "AddLSTMOneLoop failed for " << op_name_ << " at batch index " << batch_index; - return nullptr; - } - all_batch_outputs.push_back(one_batch_output); - all_batch_hidden.push_back(one_batch_hidden); - all_batch_cell.push_back(one_batch_cell); - } - *hidden_out = ConcateAll(ctx, all_batch_hidden, 1); - *cell_out = ConcateAll(ctx, all_batch_cell, 1); - return ConcateAll(ctx, all_batch_outputs, BATCH_SIZE_INDEX); -} - -nvinfer1::ITensor *LSTMTensorRT::AddLSTMOneLoop(TensorRTContext *ctx, const LstmState &input_state, - const LstmWeights &lstm_weights, nvinfer1::ITensor **hidden_out, - nvinfer1::ITensor **cell_out, bool is_backward) { -#if TRT_VERSION_GE(7, 0) - nvinfer1::ILoop *sequence_loop = ctx->network()->addLoop(); - if (sequence_loop == nullptr) { - MS_LOG(ERROR) << "add sequence_loop layer failed for " << op_name_; - return nullptr; - } - std::string loop_name = op_name_ + "_loop" + (is_backward ? "_backward" : "_forward"); - sequence_loop->setName(loop_name.c_str()); - sequence_loop->addTripLimit(*sequence_size_input_, nvinfer1::TripLimit::kCOUNT); - nvinfer1::ITensor *input = sequence_loop->addIterator(*input_state.data_, 0, is_backward)->getOutput(0); - - nvinfer1::ILayer *hidden_mid = sequence_loop->addRecurrence(*input_state.hidden_); - if (hidden_mid == nullptr) { - MS_LOG(ERROR) << "add hidden layer failed for " << op_name_; - return nullptr; - } - nvinfer1::ILayer *cell_mid = sequence_loop->addRecurrence(*input_state.cell_); - if (cell_mid == nullptr) { - MS_LOG(ERROR) << "add cell layer failed for " << op_name_; - return nullptr; - } - - nvinfer1::ITensor *input_matmul = - ctx->network() - ->addMatrixMultiply(*input, nvinfer1::MatrixOperation::kVECTOR, *lstm_weights.input_weights_, - nvinfer1::MatrixOperation::kTRANSPOSE) - ->getOutput(0); - - nvinfer1::ITensor *hidden_matmul = - ctx->network() - ->addMatrixMultiply(*hidden_mid->getOutput(0), nvinfer1::MatrixOperation::kVECTOR, *lstm_weights.state_weights_, - nvinfer1::MatrixOperation::kTRANSPOSE) - ->getOutput(0); - - nvinfer1::ITensor *weights_add = - ctx->network()->addElementWise(*input_matmul, *hidden_matmul, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0); - - nvinfer1::ITensor *bias = - ctx->network() - ->addElementWise(*lstm_weights.input_bias_, *lstm_weights.state_bias_, nvinfer1::ElementWiseOperation::kSUM) - ->getOutput(0); - - nvinfer1::ITensor *gates_calculate = - ctx->network()->addElementWise(*weights_add, *bias, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0); - - const auto isolateGate = [&](nvinfer1::ITensor &gates, int gateIndex) -> nvinfer1::ITensor * { - nvinfer1::ISliceLayer *slice = - ctx->network()->addSlice(gates, nvinfer1::Dims{1, {gateIndex * params_.hidden_size_}}, - nvinfer1::Dims{1, {params_.hidden_size_}}, nvinfer1::Dims{1, {1}}); - return Reshape(ctx, slice->getOutput(0), nvinfer1::Dims{1, {params_.hidden_size_}}); - }; - // weight order: input, output, forget, cell - nvinfer1::ITensor *i = - ctx->network()->addActivation(*isolateGate(*gates_calculate, 0), nvinfer1::ActivationType::kSIGMOID)->getOutput(0); - - nvinfer1::ITensor *o = - ctx->network()->addActivation(*isolateGate(*gates_calculate, 1), nvinfer1::ActivationType::kSIGMOID)->getOutput(0); - - nvinfer1::ITensor *f = - ctx->network() - ->addActivation(*isolateGate(*gates_calculate, FORGET_GATE), nvinfer1::ActivationType::kSIGMOID) - ->getOutput(0); - - nvinfer1::ITensor *c = ctx->network() - ->addActivation(*isolateGate(*gates_calculate, CELL_GATE), nvinfer1::ActivationType::kTANH) - ->getOutput(0); - - nvinfer1::ITensor *C = - ctx->network() - ->addElementWise(*ctx->network() - ->addElementWise(*f, *cell_mid->getOutput(0), nvinfer1::ElementWiseOperation::kPROD) - ->getOutput(0), - *ctx->network()->addElementWise(*i, *c, nvinfer1::ElementWiseOperation::kPROD)->getOutput(0), - nvinfer1::ElementWiseOperation::kSUM) - ->getOutput(0); - nvinfer1::ITensor *H = - ctx->network() - ->addElementWise(*o, *ctx->network()->addActivation(*C, nvinfer1::ActivationType::kTANH)->getOutput(0), - nvinfer1::ElementWiseOperation::kPROD) - ->getOutput(0); - - // Recurrent backedge input for hidden and cell. - cell_mid->setInput(1, *C); - hidden_mid->setInput(1, *H); - // outputs - nvinfer1::LoopOutput output_mode = is_backward ? nvinfer1::LoopOutput::kREVERSE : nvinfer1::LoopOutput::kCONCATENATE; - nvinfer1::ILoopOutputLayer *output_layer = sequence_loop->addLoopOutput(*H, output_mode); - output_layer->setInput(1, *lstm_weights.max_seq_size_); - *hidden_out = Reshape( - ctx, sequence_loop->addLoopOutput(*hidden_mid->getOutput(0), nvinfer1::LoopOutput::kLAST_VALUE)->getOutput(0), - nvinfer1::Dims3(1, 1, params_.hidden_size_)); - *cell_out = - Reshape(ctx, sequence_loop->addLoopOutput(*cell_mid->getOutput(0), nvinfer1::LoopOutput::kLAST_VALUE)->getOutput(0), - nvinfer1::Dims3(1, 1, params_.hidden_size_)); - return Reshape(ctx, output_layer->getOutput(0), nvinfer1::Dims4(params_.sequence_size_, 1, 1, params_.hidden_size_)); -#else - MS_LOG(ERROR) << "low TensorRT version don't support LSTM op, please upgrade TensorRT version to 7 or higher"; - return nullptr; -#endif -} - -int LSTMTensorRT::Prepare(void **network_tensor_bindings, nvinfer1::ICudaEngine *engine) { - if (in_tensors_.size() == INPUT_SIZE4) { - return RET_OK; - } - if (in_tensors_[HIDDEN_IN_TENSOR_INIT].Data() == nullptr && in_tensors_[CELL_IN_TENSOR_INIT].Data() == nullptr) { - return RET_OK; - } - - if (op_binding_tensor_.size() == 0) { - MS_LOG(DEBUG) << "using serialized engine, add input tensor for " << op_name_; - TensorInfo &hidden_in_init = in_tensors_[HIDDEN_IN_TENSOR_INIT]; - TensorInfo &cell_in_init = in_tensors_[CELL_IN_TENSOR_INIT]; - - hidden_init_name_ = hidden_in_init.Name() + "_hidden_init"; - cell_init_name_ = cell_in_init.Name() + "_cell_init"; - op_binding_tensor_.push_back( - BindingHelper{hidden_init_name_, hidden_in_init.Data(), nvinfer1::DataType::kFLOAT, hidden_in_init.DataSize()}); - op_binding_tensor_.push_back( - BindingHelper{cell_init_name_, cell_in_init.Data(), nvinfer1::DataType::kFLOAT, cell_in_init.DataSize()}); - params_.sequence_size_ = in_tensors_[0].Shape()[0]; - op_binding_tensor_.push_back( - BindingHelper{(op_name_ + "_seq_input"), ¶ms_.sequence_size_, nvinfer1::DataType::kINT32, sizeof(int)}); - } - for (auto tensor : op_binding_tensor_) { - auto device_ptr = runtime_->GetAllocator()->MallocDeviceMem(tensor.name_, tensor.size_, tensor.data_type_); - if (device_ptr == nullptr) { - MS_LOG(ERROR) << "malloc for inputs tensor device memory failed " << tensor.name_; - return RET_ERROR; - } - int index = engine->getBindingIndex(tensor.name_.c_str()); - network_tensor_bindings[index] = device_ptr; - runtime_->GetAllocator()->SyncMemInHostAndDevice(const_cast(tensor.data_), tensor.name_, tensor.size_, - true); - runtime_->GetAllocator()->MarkMemValid(tensor.name_, true); - } - return RET_OK; -} -REGISTER_TENSORRT_CREATOR(ops::kNameLSTM, LSTMTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/lstm_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/lstm_tensorrt.h deleted file mode 100644 index 695141f91090b3f5160439294ef1ef31d5a46526..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/lstm_tensorrt.h +++ /dev/null @@ -1,117 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_LSTM_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_LSTM_TENSORRT_H_ -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" - -namespace mindspore::lite { -constexpr int INPUT_TENSOR_SIZE = 6; -constexpr int OUTPUT_TENSOR_SIZE = 3; -constexpr int INPUT_WEIGHT = 1; -constexpr int STATE_WEIGHT = 2; -constexpr int BIAS = 3; -constexpr int HIDDEN_IN_TENSOR_INIT = 4; -constexpr int CELL_IN_TENSOR_INIT = 5; -constexpr int LSTM_GATE_NUM = 4; -constexpr int BIDIRECTIONAL = 2; -constexpr int OUTPUT_HIDDEN_INDEX = 1; -constexpr int OUTPUT_CELL_INDEX = 2; -constexpr int INPUT_SIZE_INDEX = 2; -constexpr int FORGET_GATE = 2; -constexpr int CELL_GATE = 3; -constexpr int BATCH_SIZE_INDEX = 2; -static const std::array INDICES{0, 1, 2, 3}; - -struct LSTMParams { - int sequence_size_; - int input_data_size_; - int batch_size_; - int layer_count_; - int hidden_size_; - nvinfer1::DataType data_type_; - int directional_cnt_; -}; - -struct LstmState { - nvinfer1::ITensor *data_{nullptr}; - nvinfer1::ITensor *hidden_{nullptr}; - nvinfer1::ITensor *cell_{nullptr}; -}; - -struct LstmWeights { - nvinfer1::ITensor *input_weights_{nullptr}; - nvinfer1::ITensor *state_weights_{nullptr}; - nvinfer1::ITensor *input_bias_{nullptr}; - nvinfer1::ITensor *state_bias_{nullptr}; - nvinfer1::ITensor *max_seq_size_{nullptr}; -}; - -class LSTMTensorRT : public TensorRTOp { - public: - LSTMTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~LSTMTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - bool IsWeightInputHanledInner() const override { return false; } - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; - - int Prepare(void **network_tensor_bindings, nvinfer1::ICudaEngine *engine) override; - - private: - int PreProcess(TensorRTContext *ctx); - int RunLSTMPlugin(TensorRTContext *ctx); - - int AddLSTMLayers(TensorRTContext *ctx); - - nvinfer1::ITensor *AddLSTMCell(TensorRTContext *ctx, const LstmState *layer_input_states, - const LstmWeights *layer_weights, LstmState *next_state); - - nvinfer1::ITensor *Reshape(TensorRTContext *ctx, nvinfer1::ITensor *tensor, nvinfer1::Dims dims); - - nvinfer1::ITensor *ConcateAll(TensorRTContext *ctx, std::vector all_tensort, int axis = 0); - - nvinfer1::ITensor *AddLSTMCalculation(TensorRTContext *ctx, const LstmState &input_state, - const LstmWeights &lstm_weights, nvinfer1::ITensor **hidden_out, - nvinfer1::ITensor **cell_out, bool is_backward = false); - nvinfer1::ITensor *AddLSTMOneLoop(TensorRTContext *ctx, const LstmState &input_state, const LstmWeights &lstm_weights, - nvinfer1::ITensor **hidden_out, nvinfer1::ITensor **cell_out, - bool is_backward = false); - - int ParseLSTMCellInputs(TensorRTContext *ctx, int layer_index, nvinfer1::ITensor *hidden_init, - nvinfer1::ITensor *cell_init, LstmState *input_state, int *input_weight_offset, - int *state_weight_offset, int *bias_offset, LstmWeights *lstm_weights, - const LstmState &next_state); - - nvinfer1::ITensor *input_data_{nullptr}; - nvinfer1::ITensor *sequence_size_input_{nullptr}; - nvinfer1::ITensor *op_data_out_{nullptr}; - nvinfer1::ITensor *op_hidden_out_{nullptr}; - nvinfer1::ITensor *op_cell_out_{nullptr}; - LSTMParams params_; - std::string hidden_init_name_; - std::string cell_init_name_; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_LSTM_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/matmul_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/matmul_tensorrt.cc deleted file mode 100644 index a27f7c1f55810246d2c13e7e51cda8466d1d9810..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/matmul_tensorrt.cc +++ /dev/null @@ -1,276 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/matmul_tensorrt.h" -#include -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "src/extendrt/delegate/tensorrt/op/activation_tensorrt.h" -#include "src/extendrt/delegate/tensorrt/tensorrt_runtime.h" -#include "infer/cxx_api/mat_mul_fusion.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" - -namespace mindspore::lite { -MatMulTensorRT::~MatMulTensorRT() { - if (weight_ptr_ != nullptr) { - free(weight_ptr_); - weight_ptr_ = nullptr; - } -} -int MatMulTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != INPUT_SIZE2 && in_tensors.size() != INPUT_SIZE3) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); - return RET_ERROR; - } - return RET_OK; -} - -int MatMulTensorRT::AddInnerOp(TensorRTContext *ctx) { - if (type_ == ops::kNameMatMulFusion) { - auto primitive = AsOps(); - if (primitive == nullptr) { - MS_LOG(ERROR) << "convert to primitive matmul failed for " << op_name_; - return RET_ERROR; - } - transpose_a_ = primitive->get_transpose_a(); - transpose_b_ = primitive->get_transpose_b(); - if (primitive->HasAttr(ops::kActivationType)) { - activation_ = primitive->get_activation_type(); - } - } - nvinfer1::ITensor *out_tensor = nullptr; - MS_LOG(DEBUG) << "use origin tensorrt matmul for " << op_name_; - out_tensor = AddAsMatmul(ctx); - if (out_tensor == nullptr) { - MS_LOG(ERROR) << "add matmul failed for " << op_name_; - return RET_ERROR; - } - - // add activation - if (activation_ != ActivationType::NO_ACTIVATION) { - nvinfer1::ILayer *activation_layer = - ActivationTensorRT::AddActivation(ctx, activation_, 0, 0, 0, out_tensor, op_name_, device_id_); - if (activation_layer == nullptr) { - MS_LOG(ERROR) << "addActivation for matmul failed"; - return RET_ERROR; - } - activation_layer->setName((op_name_ + "_activation").c_str()); - out_tensor = activation_layer->getOutput(0); - } - - ctx->RegisterTensor(ITensorHelper{out_tensor, out_format_}, out_tensors_[0].Name()); - MS_LOG(DEBUG) << "output " << GetTensorFormat(out_tensor, out_format_, true); - return RET_OK; -} - -bool MatMulTensorRT::HasConst() const { return in_tensors_[0].IsConst() || in_tensors_[1].IsConst(); } - -int MatMulTensorRT::PreprocessMatMulInputs(TensorRTContext *ctx, ITensorHelper *matmul_a, ITensorHelper *matmul_b) { - if (!HasConst()) { - *matmul_a = input(ctx, 0); - *matmul_b = input(ctx, 1); - int ret = PreprocessInputs2SameDim(ctx, *matmul_a, matmul_a); - ret += PreprocessInputs2SameDim(ctx, *matmul_b, matmul_b); - if (ret != RET_OK || matmul_a->trt_tensor_ == nullptr || matmul_b->trt_tensor_ == nullptr) { - MS_LOG(ERROR) << "PreprocessInputs2SameDim of matmul inputs failed for " << op_name_; - return ret; - } - out_format_ = matmul_a->format_; - if (matmul_a->format_ != matmul_b->format_) { - MS_LOG(WARNING) << "matmul input tensor has different format " << op_name_; - out_format_ = Format::NCHW; - } - } else { - for (size_t i = 0; i < in_tensors_.size(); i++) { - auto in_tensor = input(ctx, i); - if (in_tensors_[i].IsConst() || in_tensor.trt_tensor_ == nullptr) { - in_tensor.trt_tensor_ = lite::ConvertConstantTensor(ctx, in_tensors_[i], op_name_); - in_tensor.format_ = Format::NCHW; - ctx->RegisterTensor(in_tensor, in_tensors_[i].Name()); - } - } - - auto weight = ProcessWeightTensor(ctx); - *matmul_a = input(ctx, 0); - *matmul_b = input(ctx, 1); - if (weight == nullptr) { - MS_LOG(ERROR) << "create constant weight tensor failed for " << op_name_; - return RET_ERROR; - } - int weight_index = in_tensors_[1].IsConst() ? 1 : 0; - ITensorHelper *weight_helper = (weight_index == 1) ? matmul_b : matmul_a; - ITensorHelper *var_helper = (weight_index == 1) ? matmul_a : matmul_b; - weight_helper->trt_tensor_ = weight; - int ret = PreprocessInputs2SameDim(ctx, *var_helper, var_helper); - if (ret != RET_OK || var_helper->trt_tensor_ == nullptr) { - MS_LOG(ERROR) << "PreprocessInputs2SameDim of matmul input var_helper failed for " << op_name_; - return ret; - } - out_format_ = var_helper->format_; - } - return RET_OK; -} - -nvinfer1::ITensor *MatMulTensorRT::ProcessWeightTensor(TensorRTContext *ctx) { - nvinfer1::ITensor *weight = nullptr; - int weight_index = in_tensors_[1].IsConst() ? 1 : 0; - if (in_tensors_[weight_index].Shape().size() < - static_cast(input(ctx, 0).trt_tensor_->getDimensions().nbDims)) { - std::vector expect_shape(input(ctx, 1 - weight_index).trt_tensor_->getDimensions().nbDims, 1); - auto origin_shape = in_tensors_[weight_index].Shape(); - for (size_t i = 0; i < origin_shape.size(); i++) { - expect_shape[expect_shape.size() - 1 - i] = origin_shape[origin_shape.size() - 1 - i]; - } - weight = ConvertTensorWithExpandDims(ctx, in_tensors_[weight_index], expect_shape, op_name_); - } else if (in_tensors_[weight_index].Shape().size() == - static_cast(input(ctx, 0).trt_tensor_->getDimensions().nbDims)) { - weight = ConvertConstantTensor(ctx, in_tensors_[weight_index], op_name_); - } else { - MS_LOG(ERROR) << "input tensor shape is invalid for " << op_name_; - return nullptr; - } - return weight; -} - -nvinfer1::ITensor *MatMulTensorRT::AddAsMatmul(TensorRTContext *ctx) { - ITensorHelper matmul_a; - ITensorHelper matmul_b; - - int ret = PreprocessMatMulInputs(ctx, &matmul_a, &matmul_b); - if (ret != RET_OK || matmul_a.trt_tensor_ == nullptr || matmul_b.trt_tensor_ == nullptr) { - MS_LOG(ERROR) << "PreprocessMatMulInputs matmul failed for " << op_name_; - return nullptr; - } - - MS_LOG(DEBUG) << "matmul input a " << GetTensorFormat(matmul_a); - MS_LOG(DEBUG) << "matmul input b " << GetTensorFormat(matmul_b); - - auto matmul_layer = ctx->network()->addMatrixMultiply( - *matmul_a.trt_tensor_, transpose_a_ ? nvinfer1::MatrixOperation::kTRANSPOSE : nvinfer1::MatrixOperation::kNONE, - *matmul_b.trt_tensor_, transpose_b_ ? nvinfer1::MatrixOperation::kTRANSPOSE : nvinfer1::MatrixOperation::kNONE); - if (matmul_layer == nullptr) { - MS_LOG(ERROR) << "addMatrixMultiply failed for " << op_name_; - return nullptr; - } - this->layer_ = matmul_layer; - matmul_layer->setName(op_name_.c_str()); - return AddBias(ctx, matmul_layer->getOutput(0)); -} - -nvinfer1::ITensor *MatMulTensorRT::AddAsFullConnect(TensorRTContext *ctx) { - nvinfer1::Weights weight; - nvinfer1::Weights bias = ConvertWeight(in_tensors_[kBiasIndex]); - nvinfer1::ITensor *input_a = input(ctx, 0).trt_tensor_; - out_format_ = input(ctx, 0).format_; - if (input_a->getDimensions().nbDims != DIMENSION_4D) { - nvinfer1::Dims in_dims(input_a->getDimensions()); - in_dims.nbDims = DIMENSION_4D; - for (int i = input_a->getDimensions().nbDims; i < DIMENSION_4D; i++) { - in_dims.d[i] = 1; - } - input_a = Reshape(ctx, input_a, in_dims); - if (input_a == nullptr) { - MS_LOG(ERROR) << "reshape input failed for " << op_name_; - return nullptr; - } - MS_LOG(DEBUG) << "full connect expand input a to " << GetTensorFormat(input_a); - } else { - ITensorHelper tmp_input; - int ret = PreprocessInputs2SameDim(ctx, input(ctx, 0), &tmp_input); - if (ret != RET_OK || tmp_input.trt_tensor_ == nullptr) { - MS_LOG(ERROR) << "rPreprocessInputs2SameDim failed for " << op_name_; - return nullptr; - } - input_a = tmp_input.trt_tensor_; - out_format_ = tmp_input.format_; - MS_LOG(DEBUG) << "full connect preprocess input a to " << GetTensorFormat(tmp_input); - } - if (!transpose_b_) { - // transpose weight - weight = TransposeWeight2D(in_tensors_[1], &weight_ptr_); - if (weight.values == nullptr || weight_ptr_ == nullptr) { - MS_LOG(ERROR) << "TransposeWeight2D input weight failed for " << op_name_; - return nullptr; - } - } else { - weight = ConvertWeight(in_tensors_[1]); - } - - int output_cnt = in_tensors_[kBiasIndex].Shape()[0]; - - auto fc_layer = ctx->network()->addFullyConnected(*input_a, output_cnt, weight, bias); - if (fc_layer == nullptr) { - MS_LOG(ERROR) << "add fully connected layer failed for " << op_name_; - return nullptr; - } - this->layer_ = fc_layer; - fc_layer->setName((op_name_ + "_fullyconnected").c_str()); - nvinfer1::ITensor *out_tensor = fc_layer->getOutput(0); - int origin_input_dims = input(ctx, 0).trt_tensor_->getDimensions().nbDims; - if (out_tensor->getDimensions().nbDims != origin_input_dims) { - std::vector squeeze_dim; - for (int i = 0; i != origin_input_dims; ++i) { - squeeze_dim.push_back(out_tensor->getDimensions().d[i]); - } - out_tensor = Reshape(ctx, out_tensor, squeeze_dim); - } - return out_tensor; -} -nvinfer1::ITensor *MatMulTensorRT::AddBias(TensorRTContext *ctx, nvinfer1::ITensor *input_tensor) { - nvinfer1::ITensor *out_tensor = input_tensor; - if (in_tensors_.size() == kBiasIndex + 1) { - nvinfer1::ITensor *bias = nullptr; - if (in_tensors_[kBiasIndex].Shape().size() < static_cast(out_tensor->getDimensions().nbDims)) { - std::vector expect_dims(input_tensor->getDimensions().nbDims, 1); - expect_dims[expect_dims.size() - 1] = in_tensors_[kBiasIndex].Shape().back(); - bias = ConvertTensorWithExpandDims(ctx, in_tensors_[kBiasIndex], expect_dims, op_name_); - } else if (in_tensors_[kBiasIndex].Shape().size() == static_cast(out_tensor->getDimensions().nbDims)) { - bias = ConvertConstantTensor(ctx, in_tensors_[kBiasIndex], op_name_); - } else { - MS_LOG(ERROR) << "input tensor shape is invalid for " << op_name_; - return nullptr; - } - if (bias == nullptr) { - MS_LOG(ERROR) << "create constant bias tensor failed for " << op_name_; - return nullptr; - } - auto bias_layer = ctx->network()->addElementWise(*out_tensor, *bias, nvinfer1::ElementWiseOperation::kSUM); - if (bias_layer == nullptr) { - MS_LOG(ERROR) << "add bias add layer failed for " << op_name_; - return nullptr; - } - auto bias_layer_name = op_name_ + "_bias"; - bias_layer->setName(bias_layer_name.c_str()); - out_tensor = bias_layer->getOutput(0); - } - return out_tensor; -} - -bool MatMulTensorRT::RunFullConnect(TensorRTContext *ctx) { - if (in_tensors_.size() == INPUT_SIZE3 && in_tensors_[1].IsConst() && in_tensors_[kBiasIndex].IsConst() && - !transpose_a_ && in_tensors_[1].Shape().size() == DIMENSION_2D && - (input(ctx, 0).trt_tensor_->getDimensions().nbDims == DIMENSION_2D || - input(ctx, 0).trt_tensor_->getDimensions().nbDims == DIMENSION_4D)) { - return true; - } - return false; -} -REGISTER_TENSORRT_CREATOR(ops::kNameMatMulFusion, MatMulTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/matmul_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/matmul_tensorrt.h deleted file mode 100644 index dcf7253e4e3285cfd4605155d75d5cba4f996e12..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/matmul_tensorrt.h +++ /dev/null @@ -1,62 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_MATMUL_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_MATMUL_TENSORRT_H_ -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" - -namespace mindspore::lite { -class MatMulTensorRT : public TensorRTOp { - public: - MatMulTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~MatMulTensorRT() override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; - - int AddInnerOp(TensorRTContext *ctx) override; - - bool IsWeightInputHanledInner() const override { return true; } - - bool HasConst() const override; - - private: - int PreprocessMatMulInputs(TensorRTContext *ctx, ITensorHelper *matmul_a, ITensorHelper *matmul_b); - - nvinfer1::ITensor *ProcessWeightTensor(TensorRTContext *ctx); - - nvinfer1::ITensor *AddAsMatmul(TensorRTContext *ctx); - - nvinfer1::ITensor *AddAsFullConnect(TensorRTContext *ctx); - - nvinfer1::ITensor *AddBias(TensorRTContext *ctx, nvinfer1::ITensor *input_tensor); - - bool RunFullConnect(TensorRTContext *ctx); - - bool transpose_a_{false}; - bool transpose_b_{false}; - Format out_format_{Format::NCHW}; - ActivationType activation_{ActivationType::NO_ACTIVATION}; - void *weight_ptr_{nullptr}; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_MATMUL_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/maxpool3d_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/maxpool3d_tensorrt.cc deleted file mode 100644 index efeaf95600ad9469fdc9cb449063ee972431764a..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/maxpool3d_tensorrt.cc +++ /dev/null @@ -1,153 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/maxpool3d_tensorrt.h" -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "infer/max_pool3d.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" - -namespace mindspore::lite { -int MaxPool3DTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); - return RET_ERROR; - } - return RET_OK; -} - -int MaxPool3DTensorRT::AddInnerOp(TensorRTContext *ctx) { - if (in_tensors_.size() != 1) { - MS_LOG(ERROR) << "invalid input tensor size: " << in_tensors_.size(); - return RET_ERROR; - } - MS_LOG(DEBUG) << "before transpose " << GetTensorFormat(input(ctx, 0)); - int ret = ParseParams(ctx); - if (ret != RET_OK) { - MS_LOG(ERROR) << "ParseParams failed for : " << op_name_; - return RET_ERROR; - } - - nvinfer1::ITensor *pool_input = input(ctx, 0).trt_tensor_; - - // global version pooling - if (kernel_size_.empty()) { - int reduce_axes = ((1 << pool_input->getDimensions().nbDims) - 1) & ~0b11; - auto *layer = ctx->network()->addReduce(*pool_input, nvinfer1::ReduceOperation::kAVG, reduce_axes, true); - if (layer == nullptr) { - MS_LOG(ERROR) << "addReduce for pool failed"; - return RET_ERROR; - } - layer->setName(op_name_.c_str()); - this->layer_ = layer; - } else { - // pooling layer - nvinfer1::Dims windowSize = lite::ConvertCudaDims(kernel_size_); - if (windowSize.nbDims == -1) { - MS_LOG(ERROR) << "ConvertCudaDims failed for " << op_name_; - return RET_ERROR; - } - nvinfer1::IPoolingLayer *pooling_layer = - ctx->network()->addPoolingNd(*pool_input, nvinfer1::PoolingType::kMAX, windowSize); - if (pooling_layer == nullptr) { - MS_LOG(ERROR) << "addPoolingNd failed for TensorRT."; - return RET_ERROR; - } - ret = AddParams(pooling_layer); - if (ret != RET_OK) { - MS_LOG(ERROR) << "AddParams failed for : " << op_name_; - return RET_ERROR; - } - pooling_layer->setName(op_name_.c_str()); - this->layer_ = pooling_layer; - } - - // add activation - nvinfer1::ITensor *out_trt_tensor = layer_->getOutput(0); - auto output_helper = ITensorHelper{out_trt_tensor, Format::NCHW, true}; - ctx->RegisterTensor(output_helper, out_tensors_[0].Name()); - MS_LOG(DEBUG) << "output " << GetTensorFormat(output_helper); - return RET_OK; -} - -int MaxPool3DTensorRT::ParseParams(TensorRTContext *ctx) { - auto pool_primitive = AsOps(); - if (pool_primitive == nullptr) { - MS_LOG(ERROR) << "convert PoolFusion failed: " << op_name_; - return RET_ERROR; - } - - auto kernel_size = pool_primitive->get_kernel_size(); - if (kernel_size.empty()) { - MS_LOG(ERROR) << "get kernel size failed: " << op_name_; - return RET_ERROR; - } - kernel_size_ = std::vector(kernel_size.begin() + INPUT_SIZE2, kernel_size.end()); - - auto stride = pool_primitive->get_strides(); - if (stride.empty()) { - MS_LOG(ERROR) << "get stride failed: " << op_name_; - return RET_ERROR; - } - stride_ = std::vector(stride.begin() + INPUT_SIZE2, stride.end()); - auto padding = pool_primitive->get_pad(); - if (padding.empty()) { - MS_LOG(INFO) << "get padding is null, set to default 0: " << op_name_; - padding_ = {0, 0, 0, 0, 0, 0}; - } else { - padding_ = std::vector(padding.begin(), padding.end()); - } - - pad_mode_ = pool_primitive->get_pad_mode(); - return RET_OK; -} - -int MaxPool3DTensorRT::AddParams(nvinfer1::IPoolingLayer *pooling_layer) { - nvinfer1::Dims stride_dims = ConvertCudaDims(stride_); - if (stride_dims.nbDims == -1) { - MS_LOG(ERROR) << "ConvertCudaDims failed for " << op_name_; - return RET_ERROR; - } - pooling_layer->setStrideNd(stride_dims); - if (pad_mode_ == PadMode::SAME) { - pooling_layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER); - } else { - if (padding_.size() != DIMENSION_6D) { - MS_LOG(ERROR) << "Invalid padding " << padding_ << ", op: " << op_name_; - return RET_ERROR; - } - nvinfer1::Dims pre_dims{}; - pre_dims.nbDims = DIMENSION_3D; - pre_dims.d[0] = padding_[kDim0]; - pre_dims.d[1] = padding_[kDim2]; - pre_dims.d[INPUT_SIZE2] = padding_[kDim4]; - pooling_layer->setPrePadding(pre_dims); - - nvinfer1::Dims post_dims{}; - post_dims.nbDims = DIMENSION_3D; - post_dims.d[0] = padding_[kDim1]; - post_dims.d[1] = padding_[kDim3]; - post_dims.d[INPUT_SIZE2] = padding_[kDim5]; - pooling_layer->setPostPadding(post_dims); - } - return RET_OK; -} -REGISTER_TENSORRT_CREATOR(ops::kNameMaxPool3D, MaxPool3DTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/maxpool3d_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/maxpool3d_tensorrt.h deleted file mode 100644 index 80b41e30ec03d40ce0c2d4682b55602cdde33b4c..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/maxpool3d_tensorrt.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_MAX_POOL3D_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_MAX_POOL3D_TENSORRT_H_ -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" - -namespace mindspore::lite { -class MaxPool3DTensorRT : public TensorRTOp { - public: - MaxPool3DTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~MaxPool3DTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; - - private: - int ParseParams(TensorRTContext *ctx); - - int AddParams(nvinfer1::IPoolingLayer *pooling_layer); - - std::vector padding_; - std::vector kernel_size_; - std::vector stride_; - - PadMode pad_mode_{PadMode::PAD}; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_MAX_POOL3D_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/maxpool_with_argmax_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/maxpool_with_argmax_tensorrt.cc deleted file mode 100644 index e821ef3f5204da0b3650e2b0dadfeb514d89c1f4..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/maxpool_with_argmax_tensorrt.cc +++ /dev/null @@ -1,130 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/maxpool_with_argmax_tensorrt.h" -#include -#include -#include -#include -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "NvInferRuntimeCommon.h" -#include "kernel/gpu/cuda_impl/cuda_ops/maxpool_with_argmax_impl.cuh" -#include "infer/max_pool_with_argmax.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" - -namespace mindspore::lite { -int MaxPoolWithArgMaxTensorRT::IsSupport(const BaseOperatorPtr &base_operator, - const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != INPUT_SIZE2) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); - return RET_ERROR; - } - return RET_OK; -} - -int MaxPoolWithArgMaxTensorRT::AddInnerOp(TensorRTContext *ctx) { - auto op = AsOps(); - CHECK_NULL_RETURN(op); - auto pad_mode = op->get_pad_mode(); - auto stride = op->get_strides(); - auto kernel_size = op->get_kernel_size(); - - auto plugin = std::make_shared(op_name_, kernel_size, stride, pad_mode); - if (plugin == nullptr) { - MS_LOG(ERROR) << "create ActivationOptPlugin failed for " << op_name_; - return RET_ERROR; - } - nvinfer1::ITensor *inputTensors[] = {input(ctx, 0).trt_tensor_}; - nvinfer1::IPluginV2Layer *layer = ctx->network()->addPluginV2(inputTensors, 1, *plugin); - this->layer_ = layer; - nvinfer1::ITensor *op_out_tensor = layer->getOutput(0); - CHECK_NULL_RETURN(op_out_tensor); - ctx->RegisterTensor(ITensorHelper{op_out_tensor, input(ctx, 0).format_, input(ctx, 0).same_format_}, - out_tensors_[0].Name()); - op_out_tensor = layer->getOutput(1); - CHECK_NULL_RETURN(op_out_tensor); - ctx->RegisterTensor(ITensorHelper{op_out_tensor, input(ctx, 0).format_, input(ctx, 0).same_format_}, - out_tensors_[1].Name()); - return RET_OK; -} - -REGISTER_TENSORRT_PLUGIN(MaxPoolWithArgMaxPluginCreater); -template class TensorRTPluginCreater; -template -nvinfer1::PluginFieldCollection TensorRTPluginCreater::field_collection_{}; -template -std::vector TensorRTPluginCreater::fields_; - -int MaxPoolWithArgMaxPlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workspace, cudaStream_t stream) noexcept { - return RunCudaMaxPoolWithArgmax(inputDesc, inputs, outputs, stream); -} - -int MaxPoolWithArgMaxPlugin::RunCudaMaxPoolWithArgmax(const nvinfer1::PluginTensorDesc *inputDesc, - const void *const *inputs, void *const *outputs, - cudaStream_t stream) { - auto dims = inputDesc[0].dims; - int n = dims.d[0]; - int c = dims.d[1]; - int h = dims.d[INPUT_SIZE2]; - int w = dims.d[INPUT_SIZE3]; - int th = h / strides_[1] + (h % strides_[1] == 0); - int ph = std::max(0, (th - 1) * strides_[1] + kernel_size_[1] - h) / INPUT_SIZE2; - int tw = w / strides_[INPUT_SIZE2] + (w % strides_[INPUT_SIZE2] == 0); - int pw = std::max(0, (tw - 1) * strides_[INPUT_SIZE2] + kernel_size_[INPUT_SIZE2] - w) / INPUT_SIZE2; - int out_h = 0; - int out_w = 0; - if (pad_mode_ == PadMode::VALID) { - out_h = std::ceil((h - (kernel_size_[1] - 1)) / strides_[1]); - out_w = std::ceil((w - (kernel_size_[INPUT_SIZE2] - 1)) / strides_[INPUT_SIZE2]); - } else { - out_h = std::ceil(h / strides_[1]); - out_w = std::ceil(w / strides_[INPUT_SIZE2]); - } - CalMaxPoolWithArgmax(static_cast(inputs[0]), n, c, h, w, kernel_size_[1], - kernel_size_[INPUT_SIZE2], strides_[1], strides_[INPUT_SIZE2], ph, pw, out_h, out_w, - static_cast(outputs[0]), static_cast(outputs[1]), device_id_, - stream); - return RET_OK; -} - -nvinfer1::IPluginV2DynamicExt *MaxPoolWithArgMaxPlugin::clone() const noexcept { - auto *plugin = new MaxPoolWithArgMaxPlugin(*this); - plugin->setPluginNamespace(name_space_.c_str()); - return plugin; -} - -size_t MaxPoolWithArgMaxPlugin::getSerializationSize() const noexcept { - return sizeof(float) * (INPUT_SIZE4 + INPUT_SIZE4) + sizeof(PadMode); -} - -void MaxPoolWithArgMaxPlugin::serialize(void *buffer) const noexcept { - SerializeValue(&buffer, &kernel_size_[0], sizeof(float) * INPUT_SIZE4); - SerializeValue(&buffer, &strides_[0], sizeof(float) * INPUT_SIZE4); - SerializeValue(&buffer, &pad_mode_, sizeof(PadMode)); -} - -REGISTER_TENSORRT_CREATOR(ops::kNameMaxPoolWithArgmax, MaxPoolWithArgMaxTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/maxpool_with_argmax_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/maxpool_with_argmax_tensorrt.h deleted file mode 100644 index 29857d071f2d56a0f3bdd7401edf98925d00675d..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/maxpool_with_argmax_tensorrt.h +++ /dev/null @@ -1,100 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_MAXPOOL_WITH_ARGMAX_PLUGIN_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_MAXPOOL_WITH_ARGMAX_PLUGIN_H_ - -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" -#include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h" - -namespace mindspore::lite { -class MaxPoolWithArgMaxTensorRT : public TensorRTOp { - public: - MaxPoolWithArgMaxTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~MaxPoolWithArgMaxTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; -}; - -constexpr auto MAXPOOL_WITH_ARGMAX_PLUGIN_NAME{"MaxPoolWithArgMaxPlugin"}; -class MaxPoolWithArgMaxPlugin : public TensorRTPlugin { - public: - MaxPoolWithArgMaxPlugin(const std::string name, const std::vector &kernel_size, - const std::vector &strides, const PadMode &pad_mode) - : TensorRTPlugin(name, std::string(MAXPOOL_WITH_ARGMAX_PLUGIN_NAME)), - kernel_size_(kernel_size), - strides_(strides), - pad_mode_(pad_mode) {} - - MaxPoolWithArgMaxPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - : TensorRTPlugin(std::string(name), std::string(MAXPOOL_WITH_ARGMAX_PLUGIN_NAME)) { - const nvinfer1::PluginField *fields = fc->fields; - kernel_size_.resize(INPUT_SIZE4); - kernel_size_[0] = static_cast(fields[0].data)[0]; - kernel_size_[1] = static_cast(fields[1].data)[0]; - kernel_size_[INPUT_SIZE2] = static_cast(fields[INPUT_SIZE2].data)[0]; - kernel_size_[INPUT_SIZE3] = static_cast(fields[INPUT_SIZE3].data)[0]; - strides_.resize(INPUT_SIZE4); - strides_[0] = static_cast(fields[INPUT_SIZE4].data)[0]; - strides_[1] = static_cast(fields[INPUT_SIZE5].data)[0]; - strides_[INPUT_SIZE2] = static_cast(fields[INPUT_SIZE6].data)[0]; - strides_[INPUT_SIZE3] = static_cast(fields[INPUT_SIZE7].data)[0]; - pad_mode_ = static_cast(fields[INPUT_SIZE8].data)[0]; - } - - MaxPoolWithArgMaxPlugin(const char *name, const void *serialData, size_t serialLength) - : TensorRTPlugin(std::string(name), std::string(MAXPOOL_WITH_ARGMAX_PLUGIN_NAME)) { - DeserializeValue(&serialData, &serialLength, &kernel_size_[0], sizeof(float)); - DeserializeValue(&serialData, &serialLength, &kernel_size_[1], sizeof(float)); - DeserializeValue(&serialData, &serialLength, &kernel_size_[INPUT_SIZE2], sizeof(float)); - DeserializeValue(&serialData, &serialLength, &kernel_size_[INPUT_SIZE3], sizeof(float)); - DeserializeValue(&serialData, &serialLength, &strides_[0], sizeof(float)); - DeserializeValue(&serialData, &serialLength, &strides_[1], sizeof(float)); - DeserializeValue(&serialData, &serialLength, &strides_[INPUT_SIZE2], sizeof(float)); - DeserializeValue(&serialData, &serialLength, &strides_[INPUT_SIZE3], sizeof(float)); - DeserializeValue(&serialData, &serialLength, &pad_mode_, sizeof(PadMode)); - } - - MaxPoolWithArgMaxPlugin() = delete; - - nvinfer1::IPluginV2DynamicExt *clone() const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override; - size_t getSerializationSize() const noexcept override; - void serialize(void *buffer) const noexcept override; - int getNbOutputs() const noexcept override { return INPUT_SIZE2; } - - private: - int RunCudaMaxPoolWithArgmax(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, - void *const *outputs, cudaStream_t stream); - std::vector kernel_size_; - std::vector strides_; - PadMode pad_mode_; -}; -class MaxPoolWithArgMaxPluginCreater : public TensorRTPluginCreater { - public: - MaxPoolWithArgMaxPluginCreater() : TensorRTPluginCreater(std::string(MAXPOOL_WITH_ARGMAX_PLUGIN_NAME)) {} -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_MAXPOOL_WITH_ARGMAX_PLUGIN_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/mirror_pad_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/mirror_pad_tensorrt.cc deleted file mode 100644 index c3f3c85001340b8fb8d2ced3628ac95e30018ea1..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/mirror_pad_tensorrt.cc +++ /dev/null @@ -1,111 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/mirror_pad_tensorrt.h" -#include "infer/mirror_pad.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" -namespace mindspore::lite { -constexpr int SIZE_INDEX = 2; -int MirrorPadTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != INPUT_SIZE2) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - return RET_OK; -} - -int MirrorPadTensorRT::AddInnerOp(TensorRTContext *ctx) { - if (ctx == nullptr || ctx->network() == nullptr) { - MS_LOG(ERROR) << "context or network is invalid"; - return RET_ERROR; - } - auto mirrorpad_op = AsOps(); - CHECK_NULL_RETURN(mirrorpad_op); - auto pad_mode = mirrorpad_op->get_mode(); - int ret = RunAsTrtOps(ctx, pad_mode); - return ret; -} - -int MirrorPadTensorRT::RunAsTrtOps(TensorRTContext *ctx, string mode) { - auto pad_input = input(ctx, 0).trt_tensor_; - if (pad_input == nullptr) { - MS_LOG(ERROR) << "TensorRt Tensor of input 0 of pad " << op_name_ << " is nullptr"; - return RET_ERROR; - } - auto input_shape = ConvertMSShape(pad_input->getDimensions()); - if (!in_tensors_[1].IsConst()) { - MS_LOG(ERROR) << "Input 1 of pad " << op_name_ << " is not constant"; - return RET_ERROR; - } - auto pad_vec = ConvertTensorAsIntVector(in_tensors_[1]); - if (pad_vec.empty()) { - MS_LOG(ERROR) << "Failed to get pad input, node: " << op_name_; - return RET_ERROR; - } - constexpr size_t pad_multi_times = 2; - if (pad_vec.size() % pad_multi_times != 0 && pad_vec.size() != input_shape.size() * pad_multi_times) { - MS_LOG(ERROR) << "pad tensor is invalid, pad count: " << pad_vec.size() - << ", input dims count: " << input_shape.size() << ", op: " << op_name_; - return RET_ERROR; - } -#if TRT_VERSION_GE(8, 0) - std::vector start_values; - std::vector size_values; - std::vector stride_values; - for (size_t i = 0; i < pad_vec.size(); i += pad_multi_times) { - start_values.push_back(-pad_vec[i]); - stride_values.push_back(1); - size_values.push_back(pad_vec[i] + pad_vec[i + 1]); - } - nvinfer1::ITensor *size; - auto totalPadding = ctx->ConvertTo1DTensor(size_values); - auto shape = ctx->network()->addShape(*pad_input)->getOutput(0); - size = ctx->network()->addElementWise(*shape, *totalPadding, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0); - auto slice_layer = - ctx->network()->addSlice(*pad_input, ConvertCudaDims(start_values), {}, ConvertCudaDims(stride_values)); - slice_layer->setInput(SIZE_INDEX, *size); - if (slice_layer == nullptr) { - MS_LOG(ERROR) << "Failed to add slice layer for op " << op_name_; - return RET_ERROR; - } - if (mode == "REFLECT") { - slice_layer->setMode(nvinfer1::SliceMode::kREFLECT); - } else { - MS_LOG(ERROR) << "Not support padding mode " << mode << " for op " << op_name_; - return RET_ERROR; - } - slice_layer->setName(op_name_.c_str()); - this->layer_ = slice_layer; - auto out_tensor = slice_layer->getOutput(0); - if (out_tensor == nullptr) { - MS_LOG(ERROR) << "Failed to get output tensor of op " << op_name_; - return RET_ERROR; - } - auto output_tensor = ITensorHelper{out_tensor, NCHW, true}; - ctx->RegisterTensor(output_tensor, out_tensors_[0].Name()); - return RET_OK; -#else - MS_LOG(ERROR) << "Only support pad mode constant and input dims count 8 when trt version < 8.0, op: " << op_name_; - return RET_ERROR; -#endif -} -REGISTER_TENSORRT_CREATOR(ops::kNameMirrorPad, MirrorPadTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/mirror_pad_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/mirror_pad_tensorrt.h deleted file mode 100644 index 9fa0032f4ff2ee2419d1cad233892b3a69c9f08e..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/mirror_pad_tensorrt.h +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_MIRROR_PAD_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_MIRROR_PAD_TENSORRT_H_ -#include -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" - -namespace mindspore::lite { -class MirrorPadTensorRT : public TensorRTOp { - public: - MirrorPadTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~MirrorPadTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; - - private: - int RunAsTrtOps(TensorRTContext *ctx, string mode); -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_MIRROR_PAD_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/nmswithmask_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/nmswithmask_tensorrt.cc deleted file mode 100644 index c162da7ee9dc75b4d1f8354dd5e917929af880c7..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/nmswithmask_tensorrt.cc +++ /dev/null @@ -1,159 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "NvInferRuntimeCommon.h" -#include "src/extendrt/delegate/tensorrt/op/nmswithmask_tensorrt.h" -#include "kernel/gpu/cuda_impl/cuda_ops/nms_with_mask_impl.cuh" -#include "infer/nms_with_mask.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_n.h" - -namespace mindspore::lite { -int NMSwithmaskTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != INPUT_SIZE3) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); - return RET_ERROR; - } - return RET_OK; -} - -int NMSwithmaskTensorRT::AddInnerOp(TensorRTContext *ctx) { - if (ctx == nullptr || ctx->network() == nullptr) { - MS_LOG(ERROR) << "network or input tensor is invalid"; - return RET_ERROR; - } - auto in_tensor = input(ctx, 0).trt_tensor_; - if (in_tensors_[0].DataType() == DataType::kNumberTypeFloat16) { - in_tensor = TRTTensorCast(ctx, in_tensor, nvinfer1::DataType::kFLOAT, op_name_ + "_cast_in"); - } - CHECK_NULL_RETURN(in_tensor); - auto input1_dims = in_tensor->getDimensions(); - if (input1_dims.nbDims != INPUT_SIZE2 || input1_dims.d[1] != INPUT_SIZE5) { - MS_LOG(ERROR) << "input tensor is invalid"; - return RET_ERROR; - } - auto num_input = input1_dims.d[0]; - auto nms_with_mask_op = AsOps(); - if (nms_with_mask_op == nullptr) { - MS_LOG(ERROR) << "Failed to as operator ConstantOfShape: " << op_name_; - return RET_ERROR; - } - auto iou_value = GetValue(nms_with_mask_op->GetAttr(kAttrIouThreshold)); - auto plugin = std::make_shared(op_name_, num_input, iou_value); - CHECK_NULL_RETURN(plugin); - nvinfer1::ITensor *inputTensors[] = {in_tensor}; - nvinfer1::IPluginV2Layer *nmswithmask_layer = ctx->network()->addPluginV2(inputTensors, 1, *plugin); - CHECK_NULL_RETURN(nmswithmask_layer); - this->layer_ = nmswithmask_layer; - for (int i = 0; i < INPUT_SIZE3; i++) { - nvinfer1::ITensor *op_out_tensor = nmswithmask_layer->getOutput(i); - CHECK_NULL_RETURN(op_out_tensor); - ctx->RegisterTensor(ITensorHelper{op_out_tensor, input(ctx, 0).format_, input(ctx, 0).same_format_}, - out_tensors_[i].Name()); - } - - return RET_OK; -} - -REGISTER_TENSORRT_PLUGIN(NMSwithmaskPluginCreater); -template class TensorRTPluginCreater; -template -nvinfer1::PluginFieldCollection TensorRTPluginCreater::field_collection_{}; -template -std::vector TensorRTPluginCreater::fields_; - -int NMSwithmaskPlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workspace, cudaStream_t stream) noexcept { - return RunCudaNMSwithmask(inputDesc, inputs, outputs, stream); -} - -int NMSwithmaskPlugin::RunCudaNMSwithmask(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, - void *const *outputs, cudaStream_t stream) { - int box_size_ = INPUT_SIZE5; - void *data_buff = nullptr; - cudaMalloc(&data_buff, NmsRoundUpPower2(num_input_) * sizeof(float)); - void *index_buff = nullptr; - cudaMalloc(&index_buff, NmsRoundUpPower2(num_input_) * sizeof(int)); - void *row_mask = nullptr; - cudaMalloc(&row_mask, num_input_ * num_input_ * sizeof(bool)); - - CalSort(static_cast(num_input_), static_cast(inputs[0]), static_cast(outputs[0]), - static_cast(index_buff), static_cast(data_buff), box_size_, device_id_, stream); - CalPreprocess(static_cast(num_input_), static_cast(outputs[1]), - static_cast(outputs[INPUT_SIZE2]), static_cast(inputs[0]), - static_cast(outputs[0]), static_cast(index_buff), box_size_, - static_cast(row_mask), device_id_, stream); - CalNms(static_cast(num_input_), iou_value_, static_cast(outputs[0]), - static_cast(outputs[INPUT_SIZE2]), box_size_, static_cast(row_mask), device_id_, stream); - cudaFree(data_buff); - cudaFree(index_buff); - cudaFree(row_mask); - return RET_OK; -} -nvinfer1::DimsExprs NMSwithmaskPlugin::getOutputDimensions(int32_t index, const nvinfer1::DimsExprs *inputs, - int nbInputDims, - nvinfer1::IExprBuilder &exprBuilder) noexcept { - nvinfer1::DimsExprs dims; - if (index == 0) { - dims = inputs[0]; - } - if (index == 1) { - dims.d[0] = inputs[0].d[0]; - dims.nbDims = 1; - } - if (index == INPUT_SIZE2) { - dims.d[0] = inputs[0].d[0]; - dims.nbDims = 1; - } - return dims; -} - -nvinfer1::DataType NMSwithmaskPlugin::getOutputDataType(int index, const nvinfer1::DataType *inputTypes, - int nbInputs) const noexcept { - nvinfer1::DataType datatype; - if (index == 0) { - datatype = nvinfer1::DataType::kFLOAT; - } - if (index == 1) { - datatype = nvinfer1::DataType::kINT32; - } - if (index == INPUT_SIZE2) { - datatype = nvinfer1::DataType::kINT32; - } - return datatype; -} - -nvinfer1::IPluginV2DynamicExt *NMSwithmaskPlugin::clone() const noexcept { - auto *plugin = new (std::nothrow) NMSwithmaskPlugin(*this); - if (plugin == nullptr) { - MS_LOG(ERROR) << "malloc nms with mask plugin failed"; - return nullptr; - } - plugin->setPluginNamespace(name_space_.c_str()); - return plugin; -} -REGISTER_TENSORRT_CREATOR(ops::kNameNMSWithMask, NMSwithmaskTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/nmswithmask_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/nmswithmask_tensorrt.h deleted file mode 100644 index 7452b5a688b2755136e8ef383ec5b8696dd2e7eb..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/nmswithmask_tensorrt.h +++ /dev/null @@ -1,87 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_OP_NMSWITHMASK_PLUGIN_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_OP_NMSWITHMASK_PLUGIN_H_ - -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" -#include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h" - -namespace mindspore::lite { -constexpr auto NMSWITHMASK_PLUGIN_NAME{"NMSwithmaskPlugin"}; -class NMSwithmaskTensorRT : public TensorRTOp { - public: - NMSwithmaskTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, const std::string &name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~NMSwithmaskTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; -}; -class NMSwithmaskPlugin : public TensorRTPlugin { - public: - explicit NMSwithmaskPlugin(const std::string name, int num_input, float iou_value) - : TensorRTPlugin(name, std::string(NMSWITHMASK_PLUGIN_NAME)), num_input_(num_input), iou_value_(iou_value) {} - - NMSwithmaskPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - : TensorRTPlugin(std::string(name), std::string(NMSWITHMASK_PLUGIN_NAME)) {} - - NMSwithmaskPlugin(const char *name, const void *serialData, size_t serialLength) - : TensorRTPlugin(std::string(name), std::string(NMSWITHMASK_PLUGIN_NAME)) {} - - NMSwithmaskPlugin() = delete; - - nvinfer1::IPluginV2DynamicExt *clone() const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override; - nvinfer1::DimsExprs getOutputDimensions(int32_t index, const nvinfer1::DimsExprs *inputs, int nbInputDims, - nvinfer1::IExprBuilder &exprBuilder) noexcept override; - nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, int nbInputs) const - noexcept override; - int getNbOutputs() const noexcept override { return INPUT_SIZE3; } - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs, - int nbOutputs) noexcept override { - if (tensorsDesc[pos].format != nvinfer1::TensorFormat::kLINEAR) { - return false; - } - if (pos <= nbInputs) { - return tensorsDesc[pos].type == nvinfer1::DataType::kFLOAT; - } - if (pos < nbInputs + nbOutputs) { - return tensorsDesc[pos].type == nvinfer1::DataType::kINT32; - } - return false; - } - - private: - int RunCudaNMSwithmask(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, void *const *outputs, - cudaStream_t stream); - const std::string layer_name_; - std::string name_space_; - int num_input_{0}; - float iou_value_{0.5}; -}; -class NMSwithmaskPluginCreater : public TensorRTPluginCreater { - public: - NMSwithmaskPluginCreater() : TensorRTPluginCreater(std::string(NMSWITHMASK_PLUGIN_NAME)) {} -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_OP_NMSWITHMASK_PLUGIN_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/normalize_opt_plugin.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/normalize_opt_plugin.cc deleted file mode 100644 index db9b8704cc6d4e4b0f55bf0825b06428a8ce6544..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/normalize_opt_plugin.cc +++ /dev/null @@ -1,60 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/normalize_opt_plugin.h" -#include -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/cuda_impl/cuda_helper.h" -#include "NvInferRuntimeCommon.h" -#include "src/extendrt/delegate/tensorrt/cuda_impl/normalize.cuh" - -namespace mindspore::lite { -REGISTER_TENSORRT_PLUGIN(NormalizeOptPluginCreater); -template class TensorRTPluginCreater; -template -nvinfer1::PluginFieldCollection TensorRTPluginCreater::field_collection_{}; -template -std::vector TensorRTPluginCreater::fields_; - -int NormalizeOptPlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workspace, cudaStream_t stream) noexcept { - auto input = static_cast(inputs[0]); - auto gamma = static_cast(inputs[1]); - auto beta = static_cast(inputs[2]); - auto output = static_cast(outputs[0]); - auto input_dims = inputDesc[0].dims; - size_t dim_at_axis = input_dims.d[axis_]; - int element_cnt = std::accumulate(input_dims.d, input_dims.d + input_dims.nbDims, 1, std::multiplies()); - Normalize(input, gamma, beta, output, dim_at_axis, epsilion_, element_cnt, stream); - return RET_OK; -} - -nvinfer1::IPluginV2DynamicExt *NormalizeOptPlugin::clone() const noexcept { - auto *plugin = new NormalizeOptPlugin(*this); - plugin->setPluginNamespace(name_space_.c_str()); - return plugin; -} - -size_t NormalizeOptPlugin::getSerializationSize() const noexcept { return sizeof(size_t) + sizeof(float); } - -void NormalizeOptPlugin::serialize(void *buffer) const noexcept { - SerializeValue(&buffer, &axis_, sizeof(size_t)); - SerializeValue(&buffer, &epsilion_, sizeof(float)); -} -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/normalize_opt_plugin.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/normalize_opt_plugin.h deleted file mode 100644 index 644aa96d2c9e8fb9d4fb9e2d989f26088ff0993f..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/normalize_opt_plugin.h +++ /dev/null @@ -1,61 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_NORMALIZE_OPT_PLUGIN_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_NORMALIZE_OPT_PLUGIN_H_ -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" -#include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h" - -namespace mindspore::lite { -constexpr auto NORMALIZE_OPT_PLUGIN_NAME{"NormalizeOptPlugin"}; -class NormalizeOptPlugin : public TensorRTPlugin { - public: - NormalizeOptPlugin(const std::string name, size_t axis, float epsilion, uint32_t device_id) - : TensorRTPlugin(name, std::string(NORMALIZE_OPT_PLUGIN_NAME), device_id), axis_(axis), epsilion_(epsilion) {} - - NormalizeOptPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - : TensorRTPlugin(std::string(name), std::string(NORMALIZE_OPT_PLUGIN_NAME)) { - const nvinfer1::PluginField *fields = fc->fields; - axis_ = static_cast(fields[0].data)[0]; - epsilion_ = static_cast(fields[1].data)[0]; - } - - NormalizeOptPlugin(const char *name, const void *serialData, size_t serialLength) - : TensorRTPlugin(std::string(name), std::string(NORMALIZE_OPT_PLUGIN_NAME)) { - DeserializeValue(&serialData, &serialLength, &axis_, sizeof(size_t)); - DeserializeValue(&serialData, &serialLength, &epsilion_, sizeof(float)); - } - - NormalizeOptPlugin() = delete; - - // IPluginV2DynamicExt Methods - nvinfer1::IPluginV2DynamicExt *clone() const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override; - size_t getSerializationSize() const noexcept override; - void serialize(void *buffer) const noexcept override; - - private: - size_t axis_{0}; - float epsilion_{0.0f}; -}; -class NormalizeOptPluginCreater : public TensorRTPluginCreater { - public: - NormalizeOptPluginCreater() : TensorRTPluginCreater(std::string(NORMALIZE_OPT_PLUGIN_NAME)) {} -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_MATMUL_OPT_PLUGIN_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/normalize_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/normalize_tensorrt.cc deleted file mode 100644 index fa99fc95bca322f2f7181ce670256609aee083da..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/normalize_tensorrt.cc +++ /dev/null @@ -1,161 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/normalize_tensorrt.h" -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/op/normalize_opt_plugin.h" -#include "infer/cxx_api/layer_norm_fusion.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h" - -namespace mindspore::lite { -int NormalizeTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != INPUT_SIZE3 && in_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != INPUT_SIZE3 && out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - - return RET_OK; -} - -int NormalizeTensorRT::AddInnerOp(TensorRTContext *ctx) { - CHECK_NULL_RETURN(ctx->network()); - auto norm_op = AsOps(); - CHECK_NULL_RETURN(norm_op); - int input_nbdims = input(ctx, 0).trt_tensor_->getDimensions().nbDims; - int being_norm_axis = norm_op->get_begin_norm_axis(); - being_norm_axis = being_norm_axis >= 0 ? being_norm_axis : input_nbdims + being_norm_axis; - int begin_params_axis = norm_op->get_begin_params_axis(); - begin_params_axis = begin_params_axis >= 0 ? begin_params_axis : input_nbdims + begin_params_axis; - if (begin_params_axis != being_norm_axis || begin_params_axis != input_nbdims - 1) { - MS_LOG(ERROR) << "only support normalize on last one dim, being_norm_axis is " << being_norm_axis << " for " - << op_name_; - return RET_ERROR; - } - axis_ = begin_params_axis; - epsilon_ = norm_op->get_epsilon(); - int ret = PreprocessInputs(ctx); - if (ret != RET_OK) { - MS_LOG(ERROR) << "preprocess input failed for " << op_name_; - return ret; - } - return RunOptPlugin() ? RunAsOptPlugin(ctx) : RunAsTrtOps(ctx); -} - -int NormalizeTensorRT::PreprocessInputs(TensorRTContext *ctx) { - int ret = PreprocessInputs2SameDim(ctx, input(ctx, 0), &norm_input_); - if (ret != RET_OK || norm_input_.trt_tensor_ == nullptr) { - MS_LOG(ERROR) << "PreprocessInputs2SameDim norm_input failed for " << op_name_; - return RET_ERROR; - } - if (in_tensors_.size() == BETA_INDEX + 1) { - auto expect_shape = ConvertMSShape(input(ctx, 0).trt_tensor_->getDimensions()); - gamma_ = ConvertTensorWithExpandDims(ctx, in_tensors_[1], expect_shape, op_name_ + in_tensors_[1].Name()); - CHECK_NULL_RETURN(gamma_); - beta_ = ConvertTensorWithExpandDims(ctx, in_tensors_[BETA_INDEX], expect_shape, - op_name_ + in_tensors_[BETA_INDEX].Name()); - CHECK_NULL_RETURN(beta_); - } - return RET_OK; -} - -int NormalizeTensorRT::RunAsOptPlugin(TensorRTContext *ctx) { - auto plugin = std::make_shared(op_name_, axis_, epsilon_, device_id_); - if (plugin == nullptr) { - MS_LOG(ERROR) << "create NormalizeOptPlugin failed for " << op_name_; - return RET_ERROR; - } - nvinfer1::ITensor *inputTensors[] = {norm_input_.trt_tensor_, gamma_, beta_}; - nvinfer1::IPluginV2Layer *norm_layer = ctx->network()->addPluginV2(inputTensors, INPUT_SIZE3, *plugin); - if (norm_layer == nullptr) { - MS_LOG(ERROR) << "add norm opt plugin layer failed for " << op_name_; - return RET_ERROR; - } - layer_ = norm_layer; - layer_->setName(op_name_.c_str()); - nvinfer1::ITensor *out_tensor = norm_layer->getOutput(0); - ctx->RegisterTensor(ITensorHelper{out_tensor, norm_input_.format_, norm_input_.same_format_}, out_tensors_[0].Name()); - return RET_OK; -} - -int NormalizeTensorRT::RunAsTrtOps(TensorRTContext *ctx) { - size_t axis = 1u << axis_; - // first output, add later - - // mean - auto mean = - ctx->network()->addReduce(*(norm_input_.trt_tensor_), nvinfer1::ReduceOperation::kAVG, axis, true)->getOutput(0); - CHECK_NULL_RETURN(mean); - // x - mean - auto sub_mean = ctx->network() - ->addElementWise(*(norm_input_.trt_tensor_), *mean, nvinfer1::ElementWiseOperation::kSUB) - ->getOutput(0); - CHECK_NULL_RETURN(sub_mean); - // (x - mean)^2 - auto const_two = ConvertScalarToITensor(ctx, input(ctx, 0).trt_tensor_->getDimensions().nbDims, &two_, - DataType::kNumberTypeFloat32, op_name_ + "_two"); - CHECK_NULL_RETURN(const_two); - auto pow = ctx->network()->addElementWise(*sub_mean, *const_two, nvinfer1::ElementWiseOperation::kPOW)->getOutput(0); - CHECK_NULL_RETURN(pow); - // mean of (x - mean)^2 - auto var = ctx->network()->addReduce(*pow, nvinfer1::ReduceOperation::kAVG, axis, true)->getOutput(0); - CHECK_NULL_RETURN(var); - - // var + min epsilon - auto const_epsilon = ConvertScalarToITensor(ctx, input(ctx, 0).trt_tensor_->getDimensions().nbDims, &epsilon_, - DataType::kNumberTypeFloat32, op_name_ + "_epsilion"); - CHECK_NULL_RETURN(const_epsilon); - auto var_epsilon = - ctx->network()->addElementWise(*var, *const_epsilon, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0); - CHECK_NULL_RETURN(var_epsilon); - - // standard deviation - auto std_dev = ctx->network()->addUnary(*var_epsilon, nvinfer1::UnaryOperation::kSQRT)->getOutput(0); - CHECK_NULL_RETURN(std_dev); - - // sub_mean / std_dev - auto norm_layer = ctx->network()->addElementWise(*sub_mean, *std_dev, nvinfer1::ElementWiseOperation::kDIV); - CHECK_NULL_RETURN(norm_layer); - this->layer_ = norm_layer; - auto norm = norm_layer->getOutput(0); - CHECK_NULL_RETURN(norm); - - // scale with gamma and beta - nvinfer1::ITensor *out_tensor = nullptr; - if (gamma_ != nullptr && beta_ != nullptr) { - auto gamma_out = - ctx->network()->addElementWise(*norm, *gamma_, nvinfer1::ElementWiseOperation::kPROD)->getOutput(0); - CHECK_NULL_RETURN(gamma_out); - auto beta_out = - ctx->network()->addElementWise(*gamma_out, *beta_, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0); - CHECK_NULL_RETURN(beta_out); - out_tensor = beta_out; - } else { - out_tensor = norm; - } - ctx->RegisterTensor(ITensorHelper{out_tensor, norm_input_.format_, norm_input_.same_format_}, out_tensors_[0].Name()); - return RET_OK; -} - -bool NormalizeTensorRT::RunOptPlugin() { return false; } -REGISTER_TENSORRT_CREATOR(ops::kNameLayerNormFusion, NormalizeTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/normalize_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/normalize_tensorrt.h deleted file mode 100644 index 8222b195c005c65f023ae29ac5ec5b60ba193722..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/normalize_tensorrt.h +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_NORMALIZE_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_NORMALIZE_TENSORRT_H_ -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" - -namespace mindspore::lite { -constexpr int BETA_INDEX = 2; - -class NormalizeTensorRT : public TensorRTOp { - public: - NormalizeTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~NormalizeTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - bool IsWeightInputHanledInner() const override { return true; } - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; - - private: - int PreprocessInputs(TensorRTContext *ctx); - - int RunAsOptPlugin(TensorRTContext *ctx); - - int RunAsTrtOps(TensorRTContext *ctx); - - bool RunOptPlugin(); - - ITensorHelper norm_input_; - nvinfer1::ITensor *gamma_{nullptr}; - nvinfer1::ITensor *beta_{nullptr}; - size_t axis_{0}; - const float two_{2.0f}; - float epsilon_{0.0f}; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_NORMALIZE_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/onehot_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/onehot_tensorrt.cc deleted file mode 100644 index dd44c75af2abc309dc5a4d502c75675ddcf70bba..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/onehot_tensorrt.cc +++ /dev/null @@ -1,175 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "NvInferRuntimeCommon.h" -#include "src/extendrt/delegate/tensorrt/op/onehot_tensorrt.h" -#include "kernel/gpu/cuda_impl/cuda_ops/one_hot_impl.cuh" -#include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_o.h" - -namespace mindspore::lite { -constexpr int INDICES_INDEX = 0; -constexpr int DEPTH_INDEX = 1; -constexpr int ON_VALUE_INDEX = 2; -constexpr int OFF_VALUE_INDEX = 3; - -int OnehotTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != INPUT_SIZE4 && in_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); - return RET_ERROR; - } - return RET_OK; -} - -int OnehotTensorRT::AddInnerOp(TensorRTContext *ctx) { - if (ctx == nullptr || ctx->network() == nullptr) { - MS_LOG(ERROR) << "network or input tensor is invalid"; - return RET_ERROR; - } - nvinfer1::ITensor *inputTensors[] = {input(ctx, 0).trt_tensor_, input(ctx, DEPTH_INDEX).trt_tensor_, - input(ctx, ON_VALUE_INDEX).trt_tensor_, input(ctx, OFF_VALUE_INDEX).trt_tensor_}; - ITensorHelper indice_helper = input(ctx, 0); - if (indice_helper.trt_tensor_->getType() != nvinfer1::DataType::kINT32) { - inputTensors[0] = TRTTensorCast(ctx, input(ctx, 0).trt_tensor_, nvinfer1::DataType::kINT32, op_name_ + "_cast_in"); - } - ITensorHelper depth_helper = input(ctx, DEPTH_INDEX); - if (depth_helper.trt_tensor_->getType() != nvinfer1::DataType::kINT32) { - inputTensors[DEPTH_INDEX] = - TRTTensorCast(ctx, input(ctx, DEPTH_INDEX).trt_tensor_, nvinfer1::DataType::kINT32, op_name_ + "_cast_in"); - } - auto &depth_tensor = in_tensors_[DEPTH_INDEX]; - if (depth_tensor.Data() == nullptr) { - MS_LOG(ERROR) << "get depth input tensor null for " << op_name_; - return RET_ERROR; - } - const int *depth_ptr = reinterpret_cast(depth_tensor.Data()); - int depth = *depth_ptr; - auto onehot_op = AsOps(); - int axis = onehot_op->get_axis(); - auto plugin = std::make_shared(op_name_, axis, depth); - if (plugin == nullptr) { - MS_LOG(ERROR) << "create OnehotPlugin failed for " << op_name_; - return RET_ERROR; - } - nvinfer1::IPluginV2Layer *onehot_layer = ctx->network()->addPluginV2(inputTensors, 4, *plugin); - this->layer_ = onehot_layer; - nvinfer1::ITensor *op_out_tensor = onehot_layer->getOutput(0); - if (op_out_tensor == nullptr) { - MS_LOG(ERROR) << "onehot out tensor is nullptr."; - return RET_ERROR; - } - ctx->RegisterTensor(ITensorHelper{op_out_tensor, input(ctx, 0).format_, input(ctx, 0).same_format_}, - out_tensors_[0].Name()); - - return RET_OK; -} - -REGISTER_TENSORRT_PLUGIN(OnehotPluginCreater); -template class TensorRTPluginCreater; -template -nvinfer1::PluginFieldCollection TensorRTPluginCreater::field_collection_{}; -template -std::vector TensorRTPluginCreater::fields_; - -int OnehotPlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, - cudaStream_t stream) noexcept { - return RunCudaOneHot(inputDesc, inputs, outputs, stream); -} - -int OnehotPlugin::RunCudaOneHot(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, - void *const *outputs, cudaStream_t stream) { - int left_dims = 1; - int right_dims = 1; - for (int i = 0; i != inputDesc[0].dims.nbDims; ++i) { - if (axis_ == -1 || i < axis_) { - left_dims *= inputDesc[0].dims.d[i]; - } - if (axis_ != -1 && i >= axis_) { - right_dims *= inputDesc[0].dims.d[i]; - } - } - if (inputDesc[0].type == nvinfer1::DataType::kINT32 && inputDesc[ON_VALUE_INDEX].type == nvinfer1::DataType::kFLOAT) { - OneHot(static_cast(inputs[0]), depth_, static_cast(inputs[ON_VALUE_INDEX]), - static_cast(inputs[OFF_VALUE_INDEX]), left_dims, right_dims, - static_cast(outputs[0]), device_id_, stream); - } else if (inputDesc[0].type == nvinfer1::DataType::kINT32 && - inputDesc[ON_VALUE_INDEX].type == nvinfer1::DataType::kHALF) { - OneHot(static_cast(inputs[0]), depth_, static_cast(inputs[ON_VALUE_INDEX]), - static_cast(inputs[OFF_VALUE_INDEX]), left_dims, right_dims, - static_cast(outputs[0]), device_id_, stream); - } else { - MS_LOG(ERROR) << "invalid onehot type "; - return RET_ERROR; - } - - return RET_OK; -} -nvinfer1::DimsExprs OnehotPlugin::getOutputDimensions(int32_t index, const nvinfer1::DimsExprs *inputs, int nbInputDims, - nvinfer1::IExprBuilder &exprBuilder) noexcept { - nvinfer1::DimsExprs dims; - dims.nbDims = inputs[0].nbDims + 1; - auto indice_dims = inputs[0].nbDims; - if (axis_ == -1) { - for (int i = 0; i != inputs[0].nbDims; ++i) { - dims.d[i] = inputs[0].d[i]; - } - auto depth = exprBuilder.constant(depth_); - dims.d[dims.nbDims - 1] = depth; - } else { - for (int i = 0; i != indice_dims; ++i) { - if (i >= axis_) { - dims.d[i + 1] = inputs[0].d[i]; - } else { - dims.d[i] = inputs[0].d[i]; - } - } - auto depth = exprBuilder.constant(depth_); - dims.d[axis_] = depth; - } - return dims; -} - -nvinfer1::IPluginV2DynamicExt *OnehotPlugin::clone() const noexcept { - auto *plugin = new OnehotPlugin(*this); - plugin->setPluginNamespace(name_space_.c_str()); - return plugin; -} - -size_t OnehotPlugin::getSerializationSize() const noexcept { return sizeof(schema::PrimitiveType) + 2 * sizeof(int); } - -nvinfer1::DataType OnehotPlugin::getOutputDataType(int index, const nvinfer1::DataType *inputTypes, int nbInputs) const - noexcept { - return inputTypes[ON_VALUE_INDEX]; -} - -void OnehotPlugin::serialize(void *buffer) const noexcept { - SerializeValue(&buffer, &axis_, sizeof(int)); - SerializeValue(&buffer, &depth_, sizeof(int)); -} - -REGISTER_TENSORRT_CREATOR(ops::kNameOneHot, OnehotTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/onehot_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/onehot_tensorrt.h deleted file mode 100644 index 554f441309b65dc18aa5f1ccce158e9c672e5edd..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/onehot_tensorrt.h +++ /dev/null @@ -1,82 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_ONEHOT_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_ONEHOT_TENSORRT_H_ - -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" -#include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h" - -namespace mindspore::lite { -class OnehotTensorRT : public TensorRTOp { - public: - OnehotTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~OnehotTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; -}; - -constexpr auto ONEHOT_PLUGIN_NAME{"OnehotPlugin"}; -class OnehotPlugin : public TensorRTPlugin { - public: - OnehotPlugin(const std::string name, int axis, int depth) - : TensorRTPlugin(name, std::string(ONEHOT_PLUGIN_NAME)), axis_(axis), depth_(depth) {} - - OnehotPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - : TensorRTPlugin(std::string(name), std::string(ONEHOT_PLUGIN_NAME)) { - const nvinfer1::PluginField *fields = fc->fields; - axis_ = static_cast(fields[0].data)[0]; - depth_ = static_cast(fields[1].data)[0]; - } - - OnehotPlugin(const char *name, const void *serialData, size_t serialLength) - : TensorRTPlugin(std::string(name), std::string(ONEHOT_PLUGIN_NAME)) { - DeserializeValue(&serialData, &serialLength, &axis_, sizeof(int)); - DeserializeValue(&serialData, &serialLength, &depth_, sizeof(int)); - } - - OnehotPlugin() = delete; - - nvinfer1::IPluginV2DynamicExt *clone() const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override; - size_t getSerializationSize() const noexcept override; - nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, int nbInputs) const - noexcept override; - void serialize(void *buffer) const noexcept override; - nvinfer1::DimsExprs getOutputDimensions(int32_t index, const nvinfer1::DimsExprs *inputs, int nbInputDims, - nvinfer1::IExprBuilder &exprBuilder) noexcept override; - - private: - int RunCudaOneHot(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, void *const *outputs, - cudaStream_t stream); - int axis_; - int depth_; -}; -class OnehotPluginCreater : public TensorRTPluginCreater { - public: - OnehotPluginCreater() : TensorRTPluginCreater(std::string(ONEHOT_PLUGIN_NAME)) {} -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_ONEHOT_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/oneslike_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/oneslike_tensorrt.cc deleted file mode 100644 index a503d0f28e7315cb98a6681494d7ae03f4e11aa3..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/oneslike_tensorrt.cc +++ /dev/null @@ -1,89 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include "src/extendrt/delegate/tensorrt/op/oneslike_tensorrt.h" -#include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_o.h" - -namespace mindspore::lite { -int OneslikeTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - return RET_OK; -} - -int OneslikeTensorRT::AddInnerOp(TensorRTContext *ctx) { - if (ctx == nullptr || ctx->network() == nullptr) { - MS_LOG(ERROR) << "context or network is invalid"; - return RET_ERROR; - } - int ret = RunAsTrtOps(ctx); - if (ret != RET_OK) { - MS_LOG(ERROR) << "oneslike op failed for " << op_name_; - return ret; - } - return ret; -} - -int OneslikeTensorRT::RunAsTrtOps(TensorRTContext *ctx) { - if (ctx == nullptr || ctx->network() == nullptr) { - MS_LOG(ERROR) << "context or network is invalid"; - return RET_ERROR; - } - auto input_trt_tensor = input(ctx, 0).trt_tensor_; - nvinfer1::ITensor *value_tensor; - if (in_tensors_[0].DataType() == DataType::kNumberTypeFloat32) { - const float value = 1.f; - value_tensor = ctx->ConvertTo1DTensor(value); - } else if (in_tensors_[0].DataType() == DataType::kNumberTypeInt32) { - const int value = 1; - value_tensor = ctx->ConvertTo1DTensor(value); - } else { - MS_LOG(ERROR) << "dtype not implement: " << in_tensors_[0].DataType(); - return RET_ERROR; - } - CHECK_NULL_RETURN(value_tensor); - auto unsqueeze_layer = ctx->network()->addShuffle(*value_tensor); - CHECK_NULL_RETURN(unsqueeze_layer); - - auto shape_tensor = ctx->network()->addShape(*input_trt_tensor)->getOutput(0); - CHECK_NULL_RETURN(shape_tensor); - int rank = shape_tensor->getDimensions().d[0]; - nvinfer1::Dims unsqueeze{rank}; - std::fill(unsqueeze.d, unsqueeze.d + rank, 1); - unsqueeze_layer->setReshapeDimensions(unsqueeze); - unsqueeze_layer->setZeroIsPlaceholder(false); - value_tensor = unsqueeze_layer->getOutput(0); - CHECK_NULL_RETURN(value_tensor); - - auto out_tensor = Broadcast(ctx, value_tensor, shape_tensor); - - CHECK_NULL_RETURN(out_tensor); - ctx->RegisterTensor(ITensorHelper{out_tensor, input(ctx, 0).format_, input(ctx, 0).same_format_}, - out_tensors_[0].Name()); - return RET_OK; -} -REGISTER_TENSORRT_CREATOR(ops::kNameOnesLike, OneslikeTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/pad_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/pad_tensorrt.cc deleted file mode 100644 index 26fdffe14ae20141814e279e4405a8a3666cf1f8..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/pad_tensorrt.cc +++ /dev/null @@ -1,269 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/pad_tensorrt.h" -#include -#include -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "infer/cxx_api/pad_fusion.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_p.h" - -namespace mindspore::lite { -int PadTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != INPUT_SIZE2 && in_tensors.size() != INPUT_SIZE3) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); - return RET_ERROR; - } - if (!in_tensors_[1].IsConst()) { - MS_LOG(ERROR) << "invalid pad tensor for: " << op_name_; - return RET_ERROR; - } - auto pad_op = AsOps(); - if (pad_op == nullptr) { - MS_LOG(ERROR) << "convert PadFusion failed: " << op_name_; - return RET_ERROR; - } - if (pad_op->HasAttr(ops::kPaddingMode)) { - padding_mode_ = pad_op->get_padding_mode(); - } -#if TRT_VERSION_GE(8, 0) - if (padding_mode_ != PaddingMode::CONSTANT && padding_mode_ != PaddingMode::REFLECT) { - MS_LOG(ERROR) << "Unsupported padding mode: " << PaddingMode(padding_mode_) << ", for op: " << op_name_; - return RET_ERROR; - } -#else - if (padding_mode_ != PaddingMode::CONSTANT) { - MS_LOG(ERROR) << "Unsupported padding mode: " << PaddingMode(padding_mode_) << ", for op: " << op_name_; - return RET_ERROR; - } -#endif - if (in_tensors[0].format() != Format::NHWC && in_tensors[0].format() != Format::NCHW) { - MS_LOG(ERROR) << "Unsupported input tensor format of " << in_tensors[0].format(); - return RET_ERROR; - } - if (pad_op->HasAttr(ops::kConstantValue)) { - constant_value_ = pad_op->get_constant_value(); - } - return RET_OK; -} - -int PadTensorRT::AddInnerOpFix(TensorRTContext *ctx, const std::vector &input_shape, - nvinfer1::ITensor *pad_input, const std::vector &pad_vec) { -#if TRT_VERSION_GE(8, 0) - std::vector start_values; - std::vector size_values; - std::vector stride_values; - for (size_t i = 0; i < pad_vec.size(); i += INPUT_SIZE2) { - start_values.push_back(-pad_vec[i]); - stride_values.push_back(1); - size_values.push_back(input_shape[i / INPUT_SIZE2] + pad_vec[i] + pad_vec[i + 1]); - } - auto slice_layer = ctx->network()->addSlice(*pad_input, ConvertCudaDims(start_values), ConvertCudaDims(size_values), - ConvertCudaDims(stride_values)); - if (slice_layer == nullptr) { - MS_LOG(ERROR) << "Failed to add slice layer for op " << op_name_; - return RET_ERROR; - } - if (padding_mode_ == PaddingMode::REFLECT) { - slice_layer->setMode(nvinfer1::SliceMode::kREFLECT); - } else if (padding_mode_ == PaddingMode::CONSTANT) { - slice_layer->setMode(nvinfer1::SliceMode::kFILL); - auto const_input = - ConvertScalarToITensor(ctx, 1, &constant_value_, DataType::kNumberTypeFloat32, op_name_ + "_fill"); - if (const_input == nullptr) { - MS_LOG(ERROR) << "Failed to create scalar tensor of constant value for op " << op_name_; - return RET_ERROR; - } - constexpr int fill_input_index = 4; - slice_layer->setInput(fill_input_index, *const_input); - } else { - MS_LOG(ERROR) << "Not support padding mode " << padding_mode_ << " for op " << op_name_; - return RET_ERROR; - } - slice_layer->setName(op_name_.c_str()); - this->layer_ = slice_layer; - auto out_tensor = slice_layer->getOutput(0); - if (out_tensor == nullptr) { - MS_LOG(ERROR) << "Failed to get output tensor of op " << op_name_; - return RET_ERROR; - } - auto output_tensor = ITensorHelper{out_tensor, NCHW, true}; - ctx->RegisterTensor(output_tensor, out_tensors_[0].Name()); - return RET_OK; -#else - MS_LOG(ERROR) << "Only support pad mode constant and input dims count 8 when trt version < 8.0, op: " << op_name_; - return RET_ERROR; -#endif -} - -int PadTensorRT::AddInnerOpDynamic(TensorRTContext *ctx, const std::vector &input_shape, - nvinfer1::ITensor *pad_input, const std::vector &pad_vec) { -#if TRT_VERSION_GE(8, 0) - std::vector pre_values; - std::vector post_values; - std::vector stride_values; - for (size_t i = 0; i < pad_vec.size(); i += INPUT_SIZE2) { - pre_values.push_back(-pad_vec[i]); - post_values.push_back(pad_vec[i + 1]); - stride_values.push_back(1); - } - auto post_tensor = ctx->ConvertTo1DTensor(post_values); - auto pre_tensor = ctx->ConvertTo1DTensor(pre_values); - auto shape_tensor = ctx->network()->addShape(*pad_input)->getOutput(0); - auto size_tensor = - ctx->network()->addElementWise(*shape_tensor, *post_tensor, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0); - size_tensor = - ctx->network()->addElementWise(*size_tensor, *pre_tensor, nvinfer1::ElementWiseOperation::kSUB)->getOutput(0); - auto slice_layer = - ctx->network()->addSlice(*pad_input, ConvertCudaDims(pre_values), {-1, {}}, ConvertCudaDims(stride_values)); - if (slice_layer == nullptr) { - MS_LOG(ERROR) << "Failed to add slice layer for op " << op_name_; - return RET_ERROR; - } - slice_layer->setInput(INPUT_SIZE2, *size_tensor); - if (padding_mode_ == PaddingMode::REFLECT) { - slice_layer->setMode(nvinfer1::SliceMode::kREFLECT); - } else if (padding_mode_ == PaddingMode::CONSTANT) { - slice_layer->setMode(nvinfer1::SliceMode::kFILL); - auto const_input = - ConvertScalarToITensor(ctx, 1, &constant_value_, DataType::kNumberTypeFloat32, op_name_ + "_fill"); - if (const_input == nullptr) { - MS_LOG(ERROR) << "Failed to create scalar tensor of constant value for op " << op_name_; - return RET_ERROR; - } - constexpr int fill_input_index = 4; - slice_layer->setInput(fill_input_index, *const_input); - } else { - MS_LOG(ERROR) << "Not support padding mode " << padding_mode_ << " for op " << op_name_; - return RET_ERROR; - } - slice_layer->setName(op_name_.c_str()); - this->layer_ = slice_layer; - auto out_tensor = slice_layer->getOutput(0); - if (out_tensor == nullptr) { - MS_LOG(ERROR) << "Failed to get output tensor of op " << op_name_; - return RET_ERROR; - } - auto output_tensor = ITensorHelper{out_tensor, NCHW, true}; - ctx->RegisterTensor(output_tensor, out_tensors_[0].Name()); - return RET_OK; -#else - MS_LOG(ERROR) << "Only support pad mode constant and input dims count 8 when trt version < 8.0, op: " << op_name_; - return RET_ERROR; -#endif -} - -int PadTensorRT::AddInnerOp(TensorRTContext *ctx) { - auto pad_input = input(ctx, 0).trt_tensor_; - if (pad_input == nullptr) { - MS_LOG(ERROR) << "TensorRt Tensor of input 0 of pad " << op_name_ << " is nullptr"; - return RET_ERROR; - } - auto input_shape = ConvertMSShape(pad_input->getDimensions()); - if (!in_tensors_[1].IsConst()) { - MS_LOG(ERROR) << "Input 1 of pad " << op_name_ << " is not constant"; - return RET_ERROR; - } - auto pad_vec = ConvertTensorAsIntVector(in_tensors_[1]); - if (pad_vec.empty()) { - MS_LOG(ERROR) << "Failed to get pad input, node: " << op_name_; - return RET_ERROR; - } - constexpr size_t pad_multi_times = 2; - if (pad_vec.size() % INPUT_SIZE2 != 0 && pad_vec.size() != input_shape.size() * pad_multi_times) { - MS_LOG(ERROR) << "pad tensor is invalid, pad count: " << pad_vec.size() - << ", input dims count: " << input_shape.size() << ", op: " << op_name_; - return RET_ERROR; - } - if (input_shape.size() == kDim4 && padding_mode_ == PaddingMode::CONSTANT) { - return AddInnerOpOld(ctx); - } - if (IsDynamicInput(ctx, 0)) { - return AddInnerOpDynamic(ctx, input_shape, pad_input, pad_vec); - } else { - return AddInnerOpFix(ctx, input_shape, pad_input, pad_vec); - } -} - -int PadTensorRT::AddInnerOpOld(TensorRTContext *ctx) { - TensorInfo &pad_tensor = in_tensors_[1]; - int element_cnt = pad_tensor.ElementNum(); - if (element_cnt != input(ctx, 0).trt_tensor_->getDimensions().nbDims * INPUT_SIZE2) { - MS_LOG(ERROR) << "pad tensor cnt is invalid. cnt: " << element_cnt - << ", input tensor dims cnt: " << input(ctx, 0).trt_tensor_->getDimensions().nbDims; - return RET_ERROR; - } - - nvinfer1::ITensor *pad_input = input(ctx, 0).trt_tensor_; - MS_LOG(DEBUG) << "before transpose " << GetTensorFormat(pad_input, input(ctx, 0).format_, input(ctx, 0).same_format_); - - // trt 6 only support 2D padding - auto pad_vec = ConvertTensorAsIntVector(in_tensors_[1]); - if (pad_vec.empty()) { - MS_LOG(ERROR) << "Failed to get pad input, node: " << op_name_; - return RET_ERROR; - } - nvinfer1::IPaddingLayer *padding_layer = nullptr; - constexpr size_t expect_pad_size = 8; // NCHW dim number * 2 - if (pad_vec.size() == expect_pad_size) { - // only support pad at HW index - nvinfer1::DimsHW prePadding; - nvinfer1::DimsHW postPadding; - // NCHW: 0: N_pre, 1: N_post, 2: C_pre, 3: C_post, 4: H_pre, 5: H_post, 6: W_pre, 7: W_post - constexpr size_t n_pre = 0; - constexpr size_t n_post = 1; - constexpr size_t c_pre = 2; - constexpr size_t c_post = 3; - constexpr size_t h_pre = 4; - constexpr size_t h_post = 5; - constexpr size_t w_pre = 6; - constexpr size_t w_post = 7; - if (pad_vec[n_pre] != 0 || pad_vec[n_post] != 0 || pad_vec[c_pre] != 0 || pad_vec[c_post] != 0) { - MS_LOG(WARNING) << "tensorrt padding only support pad at HW index, unsupported padding value of: " << op_name_; - } - prePadding = nvinfer1::DimsHW{pad_vec[h_pre], pad_vec[w_pre]}; - postPadding = nvinfer1::DimsHW{pad_vec[h_post], pad_vec[w_post]}; - MS_LOG(DEBUG) << op_name_ << " prePadding: " << prePadding.d[0] << ", " << prePadding.d[1] - << "; postPadding: " << postPadding.d[0] << ", " << postPadding.d[1]; - - padding_layer = ctx->network()->addPadding(*pad_input, prePadding, postPadding); - } else { - MS_LOG(ERROR) << "need check for pad_tensor dims: " << op_name_ - << ", pad_tensor ElementNum: " << pad_tensor.ElementNum(); - return RET_ERROR; - } - if (padding_layer == nullptr) { - MS_LOG(ERROR) << "add padding layer failed for " << op_name_; - return RET_ERROR; - } - this->layer_ = padding_layer; - padding_layer->setName(op_name_.c_str()); - nvinfer1::ITensor *out_tensor = padding_layer->getOutput(0); - bool same_format = SameDims(out_tensor->getDimensions(), out_tensors_[0].Shape()) && - SameDims(input(ctx, 0).trt_tensor_->getDimensions(), in_tensors_[0].Shape()); - auto output_helper = ITensorHelper{out_tensor, Format::NCHW, same_format}; - ctx->RegisterTensor(output_helper, out_tensors_[0].Name()); - MS_LOG(DEBUG) << "after transpose " << GetTensorFormat(output_helper); - return RET_OK; -} - -REGISTER_TENSORRT_CREATOR(ops::kNamePadFusion, PadTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/pad_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/pad_tensorrt.h deleted file mode 100644 index 413c024402ef20830af516024aa1c482ce5c6169..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/pad_tensorrt.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_PAD_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_PAD_TENSORRT_H_ -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" - -namespace mindspore::lite { -class PadTensorRT : public TensorRTOp { - public: - PadTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~PadTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - bool IsWeightInputHanledInner() const override { return true; } - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &inputs, - const std::vector &outputs) override; - - private: - int AddInnerOpFix(TensorRTContext *ctx, const std::vector &input_shape, nvinfer1::ITensor *pad_input, - const std::vector &pad_vec); - int AddInnerOpDynamic(TensorRTContext *ctx, const std::vector &input_shape, nvinfer1::ITensor *pad_input, - const std::vector &pad_vec); - float constant_value_ = 0.0f; - PaddingMode padding_mode_ = PaddingMode::CONSTANT; - int AddInnerOpOld(TensorRTContext *ctx); -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_PAD_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/pool_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/pool_tensorrt.cc deleted file mode 100644 index 07f727111b16ea92394e7ebb5971f584c77d8c33..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/pool_tensorrt.cc +++ /dev/null @@ -1,213 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/pool_tensorrt.h" -#include "src/extendrt/delegate/tensorrt/op/activation_tensorrt.h" -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "infer/cxx_api/avg_pool_fusion.h" -#include "infer/cxx_api/max_pool_fusion.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" - -namespace mindspore::lite { -int PoolTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (!IsShapeKnown()) { - MS_LOG(ERROR) << "Unsupported input tensor unknown shape: " << op_name_; - return RET_ERROR; - } - if (in_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); - return RET_ERROR; - } - return RET_OK; -} - -int PoolTensorRT::AddInnerOp(TensorRTContext *ctx) { - if (in_tensors_.size() != 1) { - MS_LOG(ERROR) << "invalid input tensor size: " << in_tensors_.size(); - return RET_ERROR; - } - MS_LOG(DEBUG) << "before transpose " << GetTensorFormat(input(ctx, 0)); - int ret = ParseParams(ctx); - if (ret != RET_OK) { - MS_LOG(ERROR) << "ParseParams failed for : " << op_name_; - return RET_ERROR; - } - - nvinfer1::ITensor *pool_input = input(ctx, 0).trt_tensor_; - - // global version pooling - if (kernel_size_.empty()) { - int reduce_axes = ((1 << pool_input->getDimensions().nbDims) - 1) & ~0b11; - auto *layer = ctx->network()->addReduce(*pool_input, nvinfer1::ReduceOperation::kAVG, reduce_axes, true); - if (layer == nullptr) { - MS_LOG(ERROR) << "addReduce for pool failed"; - return RET_ERROR; - } - layer->setName(op_name_.c_str()); - this->layer_ = layer; - } else { - // pooling layer - nvinfer1::Dims windowSize = lite::ConvertCudaDims(kernel_size_); - if (windowSize.nbDims == -1) { - MS_LOG(ERROR) << "ConvertCudaDims failed for " << op_name_; - return RET_ERROR; - } - nvinfer1::IPoolingLayer *pooling_layer = ctx->network()->addPoolingNd(*pool_input, pooling_type_, windowSize); - if (pooling_layer == nullptr) { - MS_LOG(ERROR) << "addPoolingNd failed for TensorRT."; - return RET_ERROR; - } - ret = AddParams(pooling_layer); - if (ret != RET_OK) { - MS_LOG(ERROR) << "AddParams failed for : " << op_name_; - return RET_ERROR; - } - pooling_layer->setName(op_name_.c_str()); - this->layer_ = pooling_layer; - } - - // add activation - nvinfer1::ILayer *activation_layer = nullptr; - if (activation_type_ == ActivationType::NO_ACTIVATION) { - activation_layer = this->layer_; - } else { - activation_layer = ActivationTensorRT::AddActivation(ctx, activation_type_, 0, 0, 0, this->layer_->getOutput(0), - op_name_, device_id_); - if (activation_layer == nullptr) { - MS_LOG(ERROR) << "addActivation for pool failed"; - return RET_ERROR; - } - activation_layer->setName((op_name_ + "_activation").c_str()); - } - nvinfer1::ITensor *out_trt_tensor = activation_layer->getOutput(0); - auto output_helper = ITensorHelper{out_trt_tensor, Format::NCHW, true}; - ctx->RegisterTensor(output_helper, out_tensors_[0].Name()); - MS_LOG(DEBUG) << "output " << GetTensorFormat(output_helper); - return RET_OK; -} - -int PoolTensorRT::ParseParams(TensorRTContext *ctx) { - if (type_ == ops::kNameAvgPoolFusion) { - auto pool_primitive = AsOps(); - if (pool_primitive == nullptr) { - MS_LOG(ERROR) << "convert PoolFusion failed: " << op_name_; - return RET_ERROR; - } - pooling_type_ = nvinfer1::PoolingType::kAVERAGE; - - auto stride = pool_primitive->get_strides(); - if (stride.empty()) { - MS_LOG(ERROR) << "get stride failed: " << op_name_; - return RET_ERROR; - } - stride_ = std::vector(stride.begin(), stride.end()); - auto kernel_size = pool_primitive->get_kernel_size(); - if (kernel_size.empty()) { - MS_LOG(WARNING) << op_name_ << "don't has kernel size"; - } else { - kernel_size_ = std::vector(kernel_size.begin(), kernel_size.end()); - } - auto padding = pool_primitive->get_pad(); - if (!padding.empty() && padding.size() != DIMENSION_4D) { - MS_LOG(ERROR) << op_name_ << "has invalid pad dims: " << padding.size(); - return RET_ERROR; - } else if (padding.empty()) { - padding_ = std::vector(DIMENSION_4D, 0); - } else { - padding_ = std::vector(padding.begin(), padding.end()); - } - - pad_mode_ = static_cast(pool_primitive->get_pad_mode()); - if (pool_primitive->HasAttr(ops::kActivationType)) { - activation_type_ = pool_primitive->get_activation_type(); - } - } else if (type_ == ops::kNameMaxPoolFusion) { - auto pool_primitive = AsOps(); - if (pool_primitive == nullptr) { - MS_LOG(ERROR) << "convert PoolFusion failed: " << op_name_; - return RET_ERROR; - } - pooling_type_ = nvinfer1::PoolingType::kMAX; - - auto kernel_size = pool_primitive->get_kernel_size(); - if (kernel_size.empty()) { - MS_LOG(ERROR) << "get kernel size failed: " << op_name_; - return RET_ERROR; - } - kernel_size_ = std::vector(kernel_size.begin(), kernel_size.end()); - - auto stride = pool_primitive->get_strides(); - if (stride.empty()) { - MS_LOG(ERROR) << "get stride failed: " << op_name_; - return RET_ERROR; - } - stride_ = std::vector(stride.begin(), stride.end()); - auto padding = pool_primitive->get_pad(); - if (padding.empty()) { - MS_LOG(INFO) << "get padding is null, set to default 0: " << op_name_; - padding_ = {0, 0, 0, 0}; - } else { - padding_ = std::vector(padding.begin(), padding.end()); - } - - pad_mode_ = pool_primitive->get_pad_mode(); - if (pool_primitive->HasAttr(ops::kActivationType)) { - activation_type_ = pool_primitive->get_activation_type(); - } - } else { - MS_LOG(ERROR) << "unsupported primitive type of " << type_ << " for node: " << op_name_; - return RET_ERROR; - } - return RET_OK; -} - -int PoolTensorRT::AddParams(nvinfer1::IPoolingLayer *pooling_layer) { - nvinfer1::Dims stride_dims = ConvertCudaDims(stride_); - if (stride_dims.nbDims == -1) { - MS_LOG(ERROR) << "ConvertCudaDims failed for " << op_name_; - return RET_ERROR; - } - pooling_layer->setStrideNd(stride_dims); - if (pad_mode_ == PadMode::SAME) { - pooling_layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER); - } else { - if (padding_.size() != DIMENSION_4D) { - MS_LOG(ERROR) << "Invalid padding " << padding_ << ", op: " << op_name_; - return RET_ERROR; - } - nvinfer1::Dims pre_dims{}; - pre_dims.nbDims = DIMENSION_2D; - pre_dims.d[0] = padding_[kDim0]; - pre_dims.d[1] = padding_[kDim2]; - pooling_layer->setPrePadding(pre_dims); - - nvinfer1::Dims post_dims{}; - post_dims.nbDims = DIMENSION_2D; - post_dims.d[0] = padding_[kDim1]; - post_dims.d[1] = padding_[kDim3]; - pooling_layer->setPostPadding(post_dims); - } - return RET_OK; -} -REGISTER_TENSORRT_CREATOR(ops::kNameAvgPoolFusion, PoolTensorRT) -REGISTER_TENSORRT_CREATOR(ops::kNameMaxPoolFusion, PoolTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/pool_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/pool_tensorrt.h deleted file mode 100644 index a0dbc46f44287bb0b6b8ba07ad6255f796b0bd28..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/pool_tensorrt.h +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_POOL_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_POOL_TENSORRT_H_ -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" - -namespace mindspore::lite { -class PoolTensorRT : public TensorRTOp { - public: - PoolTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~PoolTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; - - private: - int ParseParams(TensorRTContext *ctx); - - int AddParams(nvinfer1::IPoolingLayer *pooling_layer); - - std::vector kernel_size_; - - std::vector stride_; - - std::vector padding_; - - nvinfer1::PoolingType pooling_type_; - - PadMode pad_mode_ = PadMode::PAD; - - ActivationType activation_type_ = ActivationType::NO_ACTIVATION; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_POOL_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/prelu_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/prelu_tensorrt.cc deleted file mode 100644 index 4d2d2c56e4801a9c28f1232f9030a0069aea1cf6..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/prelu_tensorrt.cc +++ /dev/null @@ -1,78 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/prelu_tensorrt.h" -#include -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "infer/cxx_api/prelu_fusion.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_p.h" - -namespace mindspore::lite { -int PReluTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (!IsShapeKnown()) { - MS_LOG(ERROR) << "Unsupported input tensor unknown shape: " << op_name_; - return RET_ERROR; - } - if (in_tensors.size() != INPUT_SIZE2) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size() << " : " << op_name_; - return RET_ERROR; - } - - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size() << " : " << op_name_; - return RET_ERROR; - } - return RET_OK; -} - -int PReluTensorRT::AddInnerOp(TensorRTContext *ctx) { - ITensorHelper prelu_input; - int ret = PreprocessInputs2SameDim(ctx, input(ctx, 0), &prelu_input); - if (ret != RET_OK || prelu_input.trt_tensor_ == nullptr) { - MS_LOG(ERROR) << "PreprocessInputs2SameDim input tensor failed for " << op_name_; - return ret; - } - int input_nbdims = prelu_input.trt_tensor_->getDimensions().nbDims; - int slope_nbdims = in_tensors_[1].Shape().size(); - ITensorHelper slope_helper; - if (input_nbdims != slope_nbdims) { - auto expect_shape = ConvertMSShape(input(ctx, 0).trt_tensor_->getDimensions()); - slope_helper.trt_tensor_ = ConvertTensorWithExpandDims(ctx, in_tensors_[1], expect_shape, op_name_ + "_slope"); - } - if (slope_helper.trt_tensor_ == nullptr) { - MS_LOG(ERROR) << "add const input tensor failed for " << op_name_; - return RET_ERROR; - } - ret = PreprocessInputs2SameDim(ctx, slope_helper, &slope_helper); - if (ret != RET_OK || slope_helper.trt_tensor_ == nullptr) { - MS_LOG(ERROR) << "PreprocessInputs2SameDim slope tensor failed for " << op_name_; - return ret; - } - - auto *prelu_layer = ctx->network()->addParametricReLU(*prelu_input.trt_tensor_, *slope_helper.trt_tensor_); - if (prelu_layer == nullptr) { - MS_LOG(ERROR) << "addParameticReLU failed for TensorRT : " << op_name_; - return RET_ERROR; - } - - nvinfer1::ITensor *out_tensor = prelu_layer->getOutput(0); - ctx->RegisterTensor(ITensorHelper{out_tensor, prelu_input.format_, prelu_input.same_format_}, out_tensors_[0].Name()); - this->layer_ = prelu_layer; - return RET_OK; -} -REGISTER_TENSORRT_CREATOR(ops::kNamePReLUFusion, PReluTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/prelu_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/prelu_tensorrt.h deleted file mode 100644 index 663c23805a65c88ca9129da09d552a7f1a9de618..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/prelu_tensorrt.h +++ /dev/null @@ -1,38 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_PRELU_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_PRELU_TENSORRT_H_ -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" - -namespace mindspore::lite { -class PReluTensorRT : public TensorRTOp { - public: - PReluTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~PReluTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_PRELU_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/reduce_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/reduce_tensorrt.cc deleted file mode 100644 index 933f2becccfbbed6f10ebb2ea0cb6db7ddfb50b9..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/reduce_tensorrt.cc +++ /dev/null @@ -1,131 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/reduce_tensorrt.h" -#include -#include "infer/cxx_api/reduce_fusion.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" - -namespace mindspore::lite { -int ReduceTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != INPUT_SIZE2) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); - } - return RET_OK; -} - -int ReduceTensorRT::AddInnerOp(TensorRTContext *ctx) { - if (ctx == nullptr || ctx->network() == nullptr) { - MS_LOG(ERROR) << "context or network is invalid"; - return RET_ERROR; - } - auto reduce_op = AsOps(); - if (reduce_op == nullptr) { - MS_LOG(ERROR) << "convert failed"; - return RET_ERROR; - } - bool keep_dims = reduce_op->get_keep_dims(); - out_format_ = input(ctx, 0).format_; - nvinfer1::ITensor *reduce_input = input(ctx, 0).trt_tensor_; - MS_LOG(DEBUG) << "origin input " << GetTensorFormat(input(ctx, 0)); - - MS_LOG(DEBUG) << "after transpose input " << GetTensorFormat(reduce_input, out_format_, keep_dims); - if (reduce_op->get_mode() == ReduceMode::Reduce_L2) { - // x^2 - auto *pow2_layer = - ctx->network()->addElementWise(*reduce_input, *reduce_input, nvinfer1::ElementWiseOperation::kPROD); - CHECK_NULL_RETURN(pow2_layer); - pow2_layer->setName((op_name_ + "_pow2").c_str()); - - reduce_input = pow2_layer->getOutput(0); - CHECK_NULL_RETURN(reduce_input); - } - - uint32_t reduceAxis = GetAxis(ctx); - auto reduce_operation_opt = TryConvertTRTReduceMode(reduce_op->get_mode()); - if (!reduce_operation_opt) { - MS_LOG(WARNING) << "invalid reduce for TensorRT, need check: " << static_cast(reduce_op->get_mode()); - return RET_ERROR; - } - nvinfer1::IReduceLayer *layer = - ctx->network()->addReduce(*reduce_input, reduce_operation_opt.value(), reduceAxis, keep_dims); - CHECK_NULL_RETURN(layer); - layer->setName(op_name_.c_str()); - this->layer_ = layer; - - nvinfer1::ITensor *out_tensor = layer->getOutput(0); - CHECK_NULL_RETURN(out_tensor); - bool is_tensor = out_tensor->getDimensions().nbDims != 0; - if (!is_tensor) { - auto squeeze = ctx->network()->addShuffle(*out_tensor); - CHECK_NULL_RETURN(squeeze); - squeeze->setName((op_name_ + "_squeeze_out").c_str()); - squeeze->setReshapeDimensions(nvinfer1::Dims{1, {1}}); - out_tensor = squeeze->getOutput(0); - } - - if (reduce_op->get_mode() == ReduceMode::Reduce_L2) { - auto sqrt_layer = ctx->network()->addUnary(*out_tensor, nvinfer1::UnaryOperation::kSQRT); - CHECK_NULL_RETURN(sqrt_layer); - sqrt_layer->setName((op_name_ + "_sqrt").c_str()); - out_tensor = sqrt_layer->getOutput(0); - } - auto output_helper = ITensorHelper{out_tensor, out_format_, true, is_tensor}; - ctx->RegisterTensor(output_helper, out_tensors_[0].Name()); - MS_LOG(DEBUG) << "output " << GetTensorFormat(output_helper); - return RET_OK; -} - -uint32_t ReduceTensorRT::GetAxis(TensorRTContext *ctx) { - uint32_t reduceAxis = 0; - if (in_tensors_.size() == 1) { - std::vector axis; - axis.resize(input(ctx, 0).trt_tensor_->getDimensions().nbDims); - std::iota(axis.begin(), axis.end(), 0); - for (int ax : axis) { - MS_LOG(DEBUG) << op_name_ << " reduceAxis at index : " << ax; - reduceAxis |= 1 << ax; - } - } else { - // axis - auto axis_tensor = this->in_tensors_[1]; - if (!axis_tensor.IsConst()) { - MS_LOG(ERROR) << "invalid axis_tensor"; - return reduceAxis; - } - if (axis_tensor.DataType() != DataType::kNumberTypeInt32 && axis_tensor.DataType() != DataType::kNumberTypeInt64) { - MS_LOG(WARNING) << "not int data type"; - } - auto axis_vec = ConvertTensorAsIntVector(axis_tensor); - if (axis_vec.empty()) { - MS_LOG(ERROR) << "Failed to get axis input, axis size " << axis_vec.size() << ", node: " << op_name_; - return RET_ERROR; - } - auto input_0 = input(ctx, 0).trt_tensor_; - for (size_t i = 0; i < axis_vec.size(); i++) { - int format_axis_data = (axis_vec[i] < 0) ? input_0->getDimensions().nbDims + axis_vec[i] : axis_vec[i]; - MS_LOG(DEBUG) << op_name_ << " reduceAxis at index : " << axis_vec[i]; - reduceAxis |= 1u << format_axis_data; - } - } - return reduceAxis; -} -REGISTER_TENSORRT_CREATOR(ops::kNameReduceFusion, ReduceTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/reduce_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/reduce_tensorrt.h deleted file mode 100644 index 9a365235d5d9c562237ae3b2eabf43ca5fbbb6f5..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/reduce_tensorrt.h +++ /dev/null @@ -1,45 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_REDUCE_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_REDUCE_TENSORRT_H_ - -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" - -namespace mindspore::lite { -class ReduceTensorRT : public TensorRTOp { - public: - ReduceTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~ReduceTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - bool IsWeightInputHanledInner() const override { return true; } - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; - - private: - uint32_t GetAxis(TensorRTContext *ctx); - Format out_format_; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_REDUCE_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/reducescatter_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/reducescatter_tensorrt.cc deleted file mode 100644 index e8c3643da8c5615002d49eaf07f94e1c9d6a93fe..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/reducescatter_tensorrt.cc +++ /dev/null @@ -1,126 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/reducescatter_tensorrt.h" -#include -#include -#include "NvInferRuntimeCommon.h" -#include "infer/reduce_scatter.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" - -namespace mindspore::lite { -REGISTER_TENSORRT_PLUGIN(ReduceScatterPluginCreater); -template class TensorRTPluginCreater; -template -nvinfer1::PluginFieldCollection TensorRTPluginCreater::field_collection_{}; -template -std::vector TensorRTPluginCreater::fields_; - -int ReduceScatterTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { -#ifndef LITE_CUDA_DISTRIBUTION - MS_LOG(ERROR) - << "Unsupported package for gpu distribution feature, please recompile with MS_ENABLE_CUDA_DISTRIBUTION set to on."; - return RET_ERROR; -#else - if (!IsShapeKnown()) { - MS_LOG(ERROR) << "Unsupported input tensor unknown shape: " << op_name_; - return RET_ERROR; - } - if (in_tensors.size() != 1) { - MS_LOG(ERROR) << "invalid input tensor size: " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "invalid output tensor size: " << out_tensors.size(); - return RET_ERROR; - } - dynamic_shape_params_.support_hw_dynamic_ = false; - return RET_OK; -#endif -} - -int ReduceScatterTensorRT::AddInnerOp(TensorRTContext *ctx) { - nvinfer1::ITensor *inputTensors[] = {input(ctx, 0).trt_tensor_}; - auto reduce_op = AsOps(); - if (reduce_op == nullptr) { - MS_LOG(ERROR) << "convert failed for " << op_name_; - return RET_ERROR; - } - auto reduce_mode = reduce_op->get_mode(); - auto rank = GetGPUGroupSize(); - auto plugin = std::make_shared(op_name_, reduce_mode, rank, device_id_); - MS_LOG(INFO) << op_name_ << " group size: " << rank << ", rank id: " << GetRankID(); - nvinfer1::IPluginV2Layer *reduce_scatter_layer = ctx->network()->addPluginV2(inputTensors, 1, *plugin); - if (reduce_scatter_layer == nullptr) { - MS_LOG(ERROR) << "create ReduceScatter layer failed for: " << op_name_; - return RET_ERROR; - } - nvinfer1::ITensor *reduce_scatter_out = reduce_scatter_layer->getOutput(0); - reduce_scatter_layer->setName(op_name_.c_str()); - ctx->RegisterTensor(ITensorHelper{reduce_scatter_out, input(ctx, 0).format_, input(ctx, 0).same_format_}, - out_tensors_[0].Name()); - this->layer_ = reduce_scatter_layer; - return RET_OK; -} - -// ReduceScatterPlugin -int ReduceScatterPlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workspace, cudaStream_t stream) noexcept { - MS_LOG(INFO) << "ReduceScatter run at rank id: " << GetRankID() << " stream: " << stream; - nvinfer1::Dims output_dims = outputDesc[0].dims; - int recieve_element_cnt = - std::accumulate(output_dims.d, output_dims.d + output_dims.nbDims, 1, std::multiplies()); - const void *input = inputs[0]; - void *output = outputs[0]; - auto data_type = inputDesc->type; - auto ret = DistributionCollective::instance().ReduceScatterWrapper(input, output, recieve_element_cnt, data_type, - red_mode_, stream, NCCL_WORLD_GROUP); - if (ret != RET_OK) { - MS_LOG(ERROR) << "ReduceScatter nccl run failed for " << layer_name_; - return ret; - } - return RET_OK; -} - -nvinfer1::IPluginV2DynamicExt *ReduceScatterPlugin::clone() const noexcept { - auto *plugin = new ReduceScatterPlugin(*this); - plugin->setPluginNamespace(name_space_.c_str()); - return plugin; -} - -nvinfer1::DimsExprs ReduceScatterPlugin::getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, - int nbInputs, - nvinfer1::IExprBuilder &exprBuilder) noexcept { - nvinfer1::DimsExprs out_dims{}; - out_dims.nbDims = inputs->nbDims; - auto rank_dim = exprBuilder.constant(rank_); - out_dims.d[0] = exprBuilder.operation(nvinfer1::DimensionOperation::kCEIL_DIV, *inputs->d[0], *rank_dim); - for (int i = 1; i < inputs->nbDims; i++) { - out_dims.d[i] = inputs->d[i]; - } - return out_dims; -} - -size_t ReduceScatterPlugin::getSerializationSize() const noexcept { return sizeof(ReduceMode); } - -void ReduceScatterPlugin::serialize(void *buffer) const noexcept { - SerializeValue(&buffer, &red_mode_, sizeof(ReduceMode)); -} - -REGISTER_TENSORRT_CREATOR(ops::kNameReduceScatter, ReduceScatterTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/reducescatter_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/reducescatter_tensorrt.h deleted file mode 100644 index 62624620f470cf017a15c01d6f99143d78f6ea30..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/reducescatter_tensorrt.h +++ /dev/null @@ -1,82 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_REDUCESCATTER_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_REDUCESCATTER_TENSORRT_H_ -#include -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" -#include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h" -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "src/extendrt/delegate/tensorrt/distribution/distribution_collective.h" - -namespace mindspore::lite { -constexpr auto REDUCESCATTER_PLUGIN_NAME{"ReduceScatterPlugin"}; -class ReduceScatterTensorRT : public TensorRTOp { - public: - ReduceScatterTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~ReduceScatterTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; -}; - -class ReduceScatterPlugin : public TensorRTPlugin { - public: - ReduceScatterPlugin(const std::string name, ReduceMode red_mode, int rank, uint32_t device_id) - : TensorRTPlugin(name, std::string(REDUCESCATTER_PLUGIN_NAME), device_id), red_mode_(red_mode), rank_(rank) {} - - ReduceScatterPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - : TensorRTPlugin(std::string(name), std::string(REDUCESCATTER_PLUGIN_NAME)) { - const nvinfer1::PluginField *fields = fc->fields; - red_mode_ = static_cast(fields[0].data)[0]; - rank_ = static_cast(fields[1].data)[0]; - } - - ReduceScatterPlugin(const char *name, const void *serialData, size_t serialLength) - : TensorRTPlugin(std::string(name), std::string(REDUCESCATTER_PLUGIN_NAME)) { - DeserializeValue(&serialData, &serialLength, &red_mode_, sizeof(ReduceMode)); - DeserializeValue(&serialData, &serialLength, &rank_, sizeof(int)); - } - - ReduceScatterPlugin() = delete; - - // IPluginV2DynamicExt Methods - nvinfer1::IPluginV2DynamicExt *clone() const noexcept override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, - nvinfer1::IExprBuilder &exprBuilder) noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override; - - size_t getSerializationSize() const noexcept override; - void serialize(void *buffer) const noexcept override; - - private: - ReduceMode red_mode_; - int rank_{0}; -}; -class ReduceScatterPluginCreater : public TensorRTPluginCreater { - public: - ReduceScatterPluginCreater() : TensorRTPluginCreater(std::string(REDUCESCATTER_PLUGIN_NAME)) {} -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_REDUCESCATTER_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/resize_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/resize_tensorrt.cc deleted file mode 100644 index 3d82f8d1abbaf64e42c6003d869a5f8092d8bca3..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/resize_tensorrt.cc +++ /dev/null @@ -1,266 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/resize_tensorrt.h" -#include -#include -#include -#include -#include "nnacl/nnacl_common.h" -#include "infer/resize.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" - -namespace mindspore::lite { -int ResizeTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != 1 && in_tensors.size() != INPUT_SIZE2) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); - } - resize_op_ = AsOps(); - if (resize_op_ == nullptr) { - MS_LOG(ERROR) << "convert failed " << op_name_; - return RET_ERROR; - } - dynamic_shape_params_.support_hw_dynamic_ = - (resize_op_->get_new_height() > 0 && resize_op_->get_new_width() > 0) ? false : true; - dynamic_shape_params_.support_hw_dynamic_ &= resize_op_->get_method() != ResizeMethod::LINEAR; - - return RET_OK; -} - -int ResizeTensorRT::AddInnerOp(TensorRTContext *ctx) { - if (ctx == nullptr || ctx->network() == nullptr) { - MS_LOG(ERROR) << "context or network is invalid"; - return RET_ERROR; - } - - nvinfer1::ITensor *resize_in_tensor = input(ctx, 0).trt_tensor_; - MS_LOG(DEBUG) << "origin input " << GetTensorFormat(input(ctx, 0)); - - nvinfer1::ITensor *output_tensor = RunTensorRT(ctx, resize_in_tensor); - if (output_tensor == nullptr) { - return RET_ERROR; - } - auto output_helper = ITensorHelper{output_tensor, Format::NCHW, true}; - ctx->RegisterTensor(output_helper, out_tensors_[0].Name()); - MS_LOG(DEBUG) << "output " << GetTensorFormat(output_helper); - return RET_OK; -} - -nvinfer1::ITensor *ResizeTensorRT::RunTensorRT(TensorRTContext *ctx, nvinfer1::ITensor *resize_in_tensor) { - nvinfer1::IResizeLayer *resize_layer = ctx->network()->addResize(*resize_in_tensor); - if (resize_layer == nullptr) { - MS_LOG(ERROR) << "create resize layer failed for " << op_name_; - return nullptr; - } - int ret = SetOutputDims(ctx, resize_in_tensor, resize_layer); - if (ret != RET_OK) { - MS_LOG(ERROR) << "SetOutputDims failed for " << op_name_; - return nullptr; - } - - ret = SetParams(resize_layer); - if (ret != RET_OK) { - MS_LOG(ERROR) << "SetParams failed for " << op_name_; - return nullptr; - } - this->layer_ = resize_layer; - return resize_layer->getOutput(0); -} - -int ResizeTensorRT::SetOutputDims(TensorRTContext *ctx, nvinfer1::ITensor *resize_in_tensor, - nvinfer1::IResizeLayer *resize_layer) { - nvinfer1::Dims in_dims = resize_in_tensor->getDimensions(); - if (in_tensors_.size() == 1 && in_dims.nbDims == DIMENSION_4D) { - nvinfer1::Dims4 new_dims(in_dims.d[0], in_dims.d[1], resize_op_->get_new_height(), - resize_op_->get_new_width()); // nchw - resize_layer->setOutputDimensions(new_dims); // static shape - } else if (resize_op_->HasAttr(kAttrScales)) { - auto scales = resize_op_->GetAttr(kAttrScales); - if (!scales) { - return RET_ERROR; - } - auto scales_val = GetValue>(scales); - if (SizeToInt(scales_val.size()) != in_dims.nbDims) { - MS_LOG(ERROR) << "Size " << scales_val.size() << " of scales get from attr != input dims count " << in_dims.nbDims - << ", op: " << op_name_; - return RET_ERROR; - } - resize_layer->setScales(scales_val.data(), scales_val.size()); - } else { - auto shape_value_tensor = in_tensors_[1]; - if (!shape_value_tensor.IsConst() && in_tensors_.size() >= INPUT_SIZE2) { - // dynamic output shape - auto shape_tensor = input(ctx, 1).trt_tensor_; - if (shape_tensor->getDimensions().d[0] == INPUT_SIZE4) { - resize_layer->setInput(1, *shape_tensor); - } else { - auto in_tensor_shape = ctx->network()->addShape(*resize_in_tensor)->getOutput(0); - CHECK_NULL_RETURN(in_tensor_shape); - nvinfer1::Dims start_dims{1, {0}}; - nvinfer1::Dims size_dims{1, {2}}; - nvinfer1::Dims stride_dims{1, {1}}; - auto nc = ctx->network()->addSlice(*in_tensor_shape, start_dims, size_dims, stride_dims)->getOutput(0); - CHECK_NULL_RETURN(nc); - - nvinfer1::ITensor *trt_input_tensors[INPUT_SIZE2]; - trt_input_tensors[0] = nc; - trt_input_tensors[1] = shape_tensor; - - auto concat_layer = ctx->network()->addConcatenation(trt_input_tensors, INPUT_SIZE2); - concat_layer->setAxis(0); - auto nchw = concat_layer->getOutput(0); - CHECK_NULL_RETURN(nchw); - nchw = TRTTensorCast(ctx, nchw, nvinfer1::DataType::kINT32, op_name_ + "_input_nchw_to_int32"); - CHECK_NULL_RETURN(nchw); - resize_layer->setInput(1, *nchw); - } - } else { - std::vector out_shape; - ParseValueFromShapeTensor(ctx, shape_value_tensor, &out_shape); - if (out_shape.size() == DIMENSION_2D && in_dims.nbDims == DIMENSION_4D) { - // out_shape: origin_n, out_shape[0], out_shape[1], origin_c - out_shape.insert(out_shape.begin(), in_dims.d[0]); // batch size is dynamic - out_shape.insert(out_shape.begin() + 1, in_dims.d[kNCHW_C]); // channel is const - } - if (shape_value_tensor.DataType() == DataType::kNumberTypeInt32) { - if (resize_in_tensor->getDimensions().d[0] == -1) { - nvinfer1::IShapeLayer *shape_layer = ctx->network()->addShape(*resize_in_tensor); - auto in_shape = shape_layer->getOutput(0); - mask2_[2] = out_shape[kNCHW_H]; - mask2_[3] = out_shape[kNCHW_W]; - auto mask1 = ConvertConstantTensor1D(ctx, mask1_, nvinfer1::DataType::kINT32); - auto mask2 = ConvertConstantTensor1D(ctx, mask2_, nvinfer1::DataType::kINT32); - in_shape = - ctx->network()->addElementWise(*in_shape, *mask1, nvinfer1::ElementWiseOperation::kPROD)->getOutput(0); - in_shape = - ctx->network()->addElementWise(*in_shape, *mask2, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0); - resize_layer->setInput(1, *in_shape); - } else { - nvinfer1::Dims dims; - dims.nbDims = DIMENSION_4D; - dims.d[0] = out_shape[0]; - dims.d[1] = out_shape[1]; - dims.d[2] = out_shape[2]; - dims.d[3] = out_shape[3]; - resize_layer->setOutputDimensions(dims); - } - } else { - float scales[DIMENSION_4D]{1, 1, 1, 1}; - scales[kNCHW_H] = out_shape[kNHWC_H]; - scales[kNCHW_W] = out_shape[kNHWC_W]; - resize_layer->setScales(scales, DIMENSION_4D); - } - } - } - return RET_OK; -} - -void ResizeTensorRT::ParseValueFromShapeTensor(TensorRTContext *ctx, const TensorInfo &shape_value_tensor, - std::vector *out_shape) { - switch (shape_value_tensor.DataType()) { - case DataType::kNumberTypeFloat32: { - const float *shape_data_fp32 = static_cast(shape_value_tensor.Data()); - for (int64_t i = 0; i < shape_value_tensor.ElementNum(); i++) { - out_shape->push_back(*(shape_data_fp32 + i)); - } - if (out_shape->size() == INPUT_SIZE2) { - out_shape->insert(out_shape->begin(), 1.f); - out_shape->insert(out_shape->end(), 1.f); - } - break; - } - case DataType::kNumberTypeFloat16: { - const uint16_t *shape_data_fp16 = static_cast(shape_value_tensor.Data()); - for (int64_t i = 0; i < shape_value_tensor.ElementNum(); i++) { - out_shape->push_back(ShortToFloat32(*(shape_data_fp16 + i))); - } - if (out_shape->size() == INPUT_SIZE2) { - out_shape->insert(out_shape->begin(), 1.f); - out_shape->insert(out_shape->end(), 1.f); - } - break; - } - case DataType::kNumberTypeInt32: { - const int *shape_data_int32 = static_cast(shape_value_tensor.Data()); - for (int64_t i = 0; i < shape_value_tensor.ElementNum(); i++) { - out_shape->push_back(*(shape_data_int32 + i)); - } - break; - } - case DataType::kNumberTypeInt64: { - auto shape_data_int = static_cast(shape_value_tensor.Data()); - for (int64_t i = 0; i < shape_value_tensor.ElementNum(); i++) { - out_shape->push_back(LongToFloat(shape_data_int[i])); - } - break; - } - default: - MS_LOG(WARNING) << op_name_ - << " more datatype need to check: " << static_cast(shape_value_tensor.DataType()); - break; - } -} - -int ResizeTensorRT::SetParams(nvinfer1::IResizeLayer *resize_layer) { - auto method = resize_op_->get_method(); - std::map method_map = {{ResizeMethod::LINEAR, nvinfer1::ResizeMode::kLINEAR}, - {ResizeMethod::NEAREST, nvinfer1::ResizeMode::kNEAREST}}; - if (method_map.find(method) == method_map.end()) { - MS_LOG(ERROR) << op_name_ << " unsupported resize mode " << static_cast(method); - return RET_ERROR; - } - resize_layer->setResizeMode(method_map.at(method)); - - auto coordinate_transform_mode = resize_op_->get_coordinate_transform_mode(); -// unsupported for trt6, but support setCoordinateTransformation() in version8 -#if TRT_VERSION_GE(8, 0) - std::map transform_map = { - {CoordinateTransformMode::ASYMMETRIC, nvinfer1::ResizeCoordinateTransformation::kASYMMETRIC}, - // kASYMMETRIC has better precision - {CoordinateTransformMode::ALIGN_CORNERS, nvinfer1::ResizeCoordinateTransformation::kASYMMETRIC}, - {CoordinateTransformMode::HALF_PIXEL, nvinfer1::ResizeCoordinateTransformation::kHALF_PIXEL}}; - auto transform_it = transform_map.find(coordinate_transform_mode); - if (transform_it == transform_map.end()) { - MS_LOG(ERROR) << op_name_ << " not support resize coordinate transform mode " << coordinate_transform_mode; - return RET_ERROR; - } - resize_layer->setCoordinateTransformation(transform_it->second); - if (resize_op_->get_new_height() != 0 || resize_op_->get_new_width() != 0 || - (coordinate_transform_mode == CoordinateTransformMode::ALIGN_CORNERS && method == ResizeMethod::LINEAR)) { - resize_layer->setCoordinateTransformation(nvinfer1::ResizeCoordinateTransformation::kALIGN_CORNERS); - } - if (resize_op_->get_nearest_mode() != NearestMode::NORMAL) { - std::unordered_map nearest_mode_transform = { - {NearestMode::ROUND_HALF_DOWN, nvinfer1::ResizeRoundMode::kHALF_DOWN}, - {NearestMode::ROUND_HALF_UP, nvinfer1::ResizeRoundMode::kHALF_UP}, - {NearestMode::FLOOR, nvinfer1::ResizeRoundMode::kFLOOR}, - {NearestMode::CEIL, nvinfer1::ResizeRoundMode::kCEIL}}; - resize_layer->setNearestRounding(nearest_mode_transform.at(resize_op_->get_nearest_mode())); - } -#else - if (coordinate_transform_mode != CoordinateTransformMode::ASYMMETRIC) { - MS_LOG(WARNING) << op_name_ << " has coordinate_transform_mode may not supported: " - << static_cast(coordinate_transform_mode); - } -#endif - return RET_OK; -} -REGISTER_TENSORRT_CREATOR(ops::kNameResize, ResizeTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/resize_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/resize_tensorrt.h deleted file mode 100644 index 5f1398559f66c9d2bd16514b17536c1d654310e6..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/resize_tensorrt.h +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_RESIZE_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_RESIZE_TENSORRT_H_ - -#include -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" -#include "infer/resize.h" - -namespace mindspore::lite { -class ResizeTensorRT : public TensorRTOp { - public: - ResizeTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~ResizeTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - bool IsWeightInputHanledInner() const override { return true; } - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; - - private: - nvinfer1::ITensor *RunTensorRT(TensorRTContext *ctx, nvinfer1::ITensor *resize_in_tensor); - - int SetOutputDims(TensorRTContext *ctx, nvinfer1::ITensor *resize_in_tensor, nvinfer1::IResizeLayer *resize_layer); - - void ParseValueFromShapeTensor(TensorRTContext *ctx, const TensorInfo &shape_value_tensor, - std::vector *out_shape); - - int SetParams(nvinfer1::IResizeLayer *resize_layer); - - std::shared_ptr resize_op_{nullptr}; - int mask1_[4]{1, 1, 0, 0}; - int mask2_[4]{0, 0, 0, 0}; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_RESIZE_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/reverse_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/reverse_tensorrt.cc deleted file mode 100644 index 7e92f2d5b3b6821ad3834fbade6ca635a65de350..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/reverse_tensorrt.cc +++ /dev/null @@ -1,72 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/reverse_tensorrt.h" -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" - -namespace mindspore::lite { -int ReverseTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size() << " : " << op_name_; - return RET_ERROR; - } - return RET_OK; -} - -int ReverseTensorRT::AddInnerOp(TensorRTContext *ctx) { - auto input_helper = input(ctx, 0); - auto dims = input_helper.trt_tensor_->getDimensions(); - for (int i = 0; i != dims.nbDims; ++i) { - if (dims.d[i] <= 0) { - MS_LOG(ERROR) << "this version do not support dynamic reverse op : " << op_name_; - return RET_ERROR; - } - } - std::vector concat_inputs; - auto reverse_op = AsOps(); - auto axis = reverse_op->get_axis(); - if (axis.size() != 1) { - MS_LOG(WARNING) << "reverse op has more than 1 axis for " << op_name_; - return RET_ERROR; - } - for (int i = dims.d[axis[0]] - 1; i >= 0; --i) { - nvinfer1::Dims start = nvinfer1::Dims{dims.nbDims, {}}; - std::fill(start.d, start.d + dims.nbDims, 0); - start.d[axis[0]] = i; - - nvinfer1::Dims size = dims; - size.d[axis[0]] = 1; - - nvinfer1::Dims stride = nvinfer1::Dims{dims.nbDims, {}}; - std::fill(stride.d, stride.d + dims.nbDims, 1); - - auto slice = ctx->network()->addSlice(*input_helper.trt_tensor_, start, size, stride)->getOutput(0); - concat_inputs.push_back(slice); - } - - auto concat_layer = ctx->network()->addConcatenation(concat_inputs.data(), concat_inputs.size()); - concat_layer->setAxis(axis[0]); - this->layer_ = concat_layer; - - auto out_tensor = concat_layer->getOutput(0); - ctx->RegisterTensor(ITensorHelper{out_tensor, Format::NCHW, true}, out_tensors_[0].Name()); - return RET_OK; -} -REGISTER_TENSORRT_CREATOR(ops::kNameReverseV2, ReverseTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/reverse_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/reverse_tensorrt.h deleted file mode 100644 index 8a831f100ff80ed073d9a5c68c28ff54a6610376..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/reverse_tensorrt.h +++ /dev/null @@ -1,37 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_REVERSE_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_REVERSE_TENSORRT_H_ -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" - -namespace mindspore::lite { -class ReverseTensorRT : public TensorRTOp { - public: - ReverseTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~ReverseTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_REVERSE_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/roialign_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/roialign_tensorrt.cc deleted file mode 100644 index f1c97615ea6b4254b9f45e90c60427c43103fda5..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/roialign_tensorrt.cc +++ /dev/null @@ -1,127 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "NvInferRuntimeCommon.h" -#include "src/extendrt/delegate/tensorrt/op/roialign_tensorrt.h" -#include "kernel/gpu/cuda_impl/cuda_ops/roi_align_impl.cuh" -#include "infer/roi_align.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" - -namespace mindspore::lite { -int ROIAlignTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != INPUT_SIZE2) { - MS_LOG(ERROR) << "Unsupported input tensor size for ROIAlignTensorRT, size is " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size for ROIAlignTensorRT, size is " << out_tensors.size(); - return RET_ERROR; - } - return RET_OK; -} - -int ROIAlignTensorRT::AddInnerOp(TensorRTContext *ctx) { - if (ctx == nullptr || ctx->network() == nullptr) { - MS_LOG(ERROR) << "network or input tensor is invalid"; - return RET_ERROR; - } - auto op = AsOps(); - int pooled_height = op->get_pooled_height(); - int pooled_width = op->get_pooled_width(); - float spatial_scale = op->get_spatial_scale(); - int sample_num = op->get_sample_num(); - int roi_end_mode = op->get_roi_end_mode(); - nvinfer1::ITensor *inputTensors[] = {input(ctx, 0).trt_tensor_, input(ctx, 1).trt_tensor_}; - int channel = inputTensors[0]->getDimensions().d[1]; - int height = inputTensors[0]->getDimensions().d[INPUT_SIZE2]; - int width = inputTensors[0]->getDimensions().d[INPUT_SIZE3]; - int roi_rows = inputTensors[1]->getDimensions().d[0]; - int roi_cols = inputTensors[1]->getDimensions().d[1]; - auto plugin = std::make_shared(op_name_, pooled_height, pooled_width, spatial_scale, sample_num, - roi_end_mode, channel, height, width, roi_rows, roi_cols); - CHECK_NULL_RETURN(plugin); - nvinfer1::IPluginV2Layer *roialign_layer = ctx->network()->addPluginV2(inputTensors, INPUT_SIZE2, *plugin); - CHECK_NULL_RETURN(roialign_layer); - this->layer_ = roialign_layer; - nvinfer1::ITensor *op_out_tensor = roialign_layer->getOutput(0); - CHECK_NULL_RETURN(op_out_tensor); - ctx->RegisterTensor(ITensorHelper{op_out_tensor, input(ctx, 0).format_, input(ctx, 0).same_format_}, - out_tensors_[0].Name()); - return RET_OK; -} - -REGISTER_TENSORRT_PLUGIN(ROIAlignPluginCreater); -template class TensorRTPluginCreater; -template -nvinfer1::PluginFieldCollection TensorRTPluginCreater::field_collection_{}; -template -std::vector TensorRTPluginCreater::fields_; - -int ROIAlignPlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, - cudaStream_t stream) noexcept { - return RunCudaROIAlign(inputDesc, inputs, outputs, stream); -} - -int ROIAlignPlugin::RunCudaROIAlign(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, - void *const *outputs, cudaStream_t stream) { - if (inputDesc[1].type == nvinfer1::DataType::kFLOAT) { - ROIAlign(static_cast(inputs[0]), static_cast(inputs[1]), roi_rows_, roi_cols_, - static_cast(outputs[0]), spatial_scale_, sample_num_, roi_end_mode_, channel_, height_, width_, - pooled_height_, pooled_width_, device_id_, stream); - } else { - MS_LOG(ERROR) << "unsupported roialign data type"; - return RET_ERROR; - } - return RET_OK; -} - -nvinfer1::DimsExprs ROIAlignPlugin::getOutputDimensions(int32_t index, const nvinfer1::DimsExprs *inputs, - int nbInputDims, nvinfer1::IExprBuilder &exprBuilder) noexcept { - nvinfer1::DimsExprs dims; - dims.nbDims = INPUT_SIZE4; - dims.d[0] = inputs[1].d[0]; - dims.d[1] = inputs[0].d[1]; - auto pooled_height = exprBuilder.constant(pooled_height_); - dims.d[INPUT_SIZE2] = pooled_height; - auto pooled_width = exprBuilder.constant(pooled_width_); - dims.d[INPUT_SIZE3] = pooled_width; - return dims; -} - -bool ROIAlignPlugin::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs, - int nbOutputs) noexcept { - return tensorsDesc[pos].type == nvinfer1::DataType::kFLOAT && - tensorsDesc[pos].format == nvinfer1::TensorFormat::kLINEAR; -} - -nvinfer1::IPluginV2DynamicExt *ROIAlignPlugin::clone() const noexcept { - auto *plugin = new (std::nothrow) ROIAlignPlugin(*this); - if (plugin == nullptr) { - MS_LOG(ERROR) << "malloc roialign plugin failed"; - return nullptr; - } - plugin->setPluginNamespace(name_space_.c_str()); - return plugin; -} -REGISTER_TENSORRT_CREATOR(ops::kNameROIAlign, ROIAlignTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/roialign_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/roialign_tensorrt.h deleted file mode 100644 index 4401ca6aa3b8b3ebdef16cde4d57a2779b6a98d6..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/roialign_tensorrt.h +++ /dev/null @@ -1,115 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_OP_ROIALIGN_PLUGIN_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_OP_ROIALIGN_PLUGIN_H_ - -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" -#include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h" - -namespace mindspore::lite { -constexpr auto ROIALIGN_PLUGIN_NAME{"ROIAlignPlugin"}; -class ROIAlignTensorRT : public TensorRTOp { - public: - ROIAlignTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, const std::string &name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~ROIAlignTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; -}; -class ROIAlignPlugin : public TensorRTPlugin { - public: - explicit ROIAlignPlugin(const std::string name, int pooled_height, int pooled_width, float spatial_scale, - int sample_num, int roi_end_mode, int channel, int height, int width, int roi_rows, - int roi_cols) - : TensorRTPlugin(name, std::string(ROIALIGN_PLUGIN_NAME)), - pooled_height_(pooled_height), - pooled_width_(pooled_width), - spatial_scale_(spatial_scale), - sample_num_(sample_num), - roi_end_mode_(roi_end_mode), - channel_(channel), - height_(height), - width_(width), - roi_rows_(roi_rows), - roi_cols_(roi_cols) {} - - ROIAlignPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - : TensorRTPlugin(std::string(name), std::string(ROIALIGN_PLUGIN_NAME)) { - const nvinfer1::PluginField *fields = fc->fields; - pooled_height_ = static_cast(fields[0].data)[0]; - pooled_width_ = static_cast(fields[1].data)[0]; - spatial_scale_ = static_cast(fields[INPUT_SIZE2].data)[0]; - sample_num_ = static_cast(fields[INPUT_SIZE3].data)[0]; - roi_end_mode_ = static_cast(fields[INPUT_SIZE4].data)[0]; - channel_ = static_cast(fields[INPUT_SIZE5].data)[0]; - height_ = static_cast(fields[INPUT_SIZE6].data)[0]; - width_ = static_cast(fields[INPUT_SIZE7].data)[0]; - roi_rows_ = static_cast(fields[INPUT_SIZE8].data)[0]; - roi_cols_ = static_cast(fields[INPUT_SIZE9].data)[0]; - } - - ROIAlignPlugin(const char *name, const void *serialData, size_t serialLength) - : TensorRTPlugin(std::string(name), std::string(ROIALIGN_PLUGIN_NAME)) { - DeserializeValue(&serialData, &serialLength, &pooled_height_, sizeof(int)); - DeserializeValue(&serialData, &serialLength, &pooled_width_, sizeof(int)); - DeserializeValue(&serialData, &serialLength, &spatial_scale_, sizeof(float)); - DeserializeValue(&serialData, &serialLength, &sample_num_, sizeof(int)); - DeserializeValue(&serialData, &serialLength, &roi_end_mode_, sizeof(int)); - DeserializeValue(&serialData, &serialLength, &channel_, sizeof(int)); - DeserializeValue(&serialData, &serialLength, &height_, sizeof(int)); - DeserializeValue(&serialData, &serialLength, &width_, sizeof(int)); - DeserializeValue(&serialData, &serialLength, &roi_rows_, sizeof(int)); - DeserializeValue(&serialData, &serialLength, &roi_cols_, sizeof(int)); - } - - ROIAlignPlugin() = delete; - - nvinfer1::IPluginV2DynamicExt *clone() const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override; - nvinfer1::DimsExprs getOutputDimensions(int32_t index, const nvinfer1::DimsExprs *inputs, int nbInputDims, - nvinfer1::IExprBuilder &exprBuilder) noexcept override; - - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs, - int nbOutputs) noexcept override; - - private: - int RunCudaROIAlign(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, void *const *outputs, - cudaStream_t stream); - int pooled_height_; - int pooled_width_; - float spatial_scale_; - int sample_num_{INPUT_SIZE2}; - int roi_end_mode_{1}; - int channel_; - int height_; - int width_; - int roi_rows_; - int roi_cols_; -}; -class ROIAlignPluginCreater : public TensorRTPluginCreater { - public: - ROIAlignPluginCreater() : TensorRTPluginCreater(std::string(ROIALIGN_PLUGIN_NAME)) {} -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_OP_ROIALIGN_PLUGIN_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/rsqrt_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/rsqrt_tensorrt.cc deleted file mode 100644 index 5faeeed6596caaf3478b6d361b03644f6b2a5b0b..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/rsqrt_tensorrt.cc +++ /dev/null @@ -1,74 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "src/extendrt/delegate/tensorrt/op/rsqrt_tensorrt.h" -#include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" - -namespace mindspore::lite { -int RsqrtTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - return RET_OK; -} - -int RsqrtTensorRT::AddInnerOp(TensorRTContext *ctx) { - if (ctx == nullptr || ctx->network() == nullptr) { - MS_LOG(ERROR) << "context or network is invalid for " << op_name_; - return RET_ERROR; - } - int input_nbdims = input(ctx, 0).trt_tensor_->getDimensions().nbDims; - if (input_nbdims == -1) { - MS_LOG(ERROR) << "Invalid input dims " << input_nbdims << " for " << op_name_; - return RET_ERROR; - } - int ret = RunAsTrtOps(ctx); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Rsqrt op failed for " << op_name_; - return ret; - } - return ret; -} - -int RsqrtTensorRT::RunAsTrtOps(TensorRTContext *ctx) { - if (ctx == nullptr || ctx->network() == nullptr) { - MS_LOG(ERROR) << "context or network is invalid for " << op_name_; - return RET_ERROR; - } - auto const_one = ctx->ConvertTo1DTensor(std::vector(1, 1.f)); - const_one = Reshape(ctx, const_one, std::vector(input(ctx, 0).trt_tensor_->getDimensions().nbDims, 1)); - CHECK_NULL_RETURN(const_one); - auto sqrt_tensor = - ctx->network()->addUnary(*input(ctx, 0).trt_tensor_, nvinfer1::UnaryOperation::kSQRT)->getOutput(0); - auto rsqrt_layer = ctx->network()->addElementWise(*const_one, *sqrt_tensor, nvinfer1::ElementWiseOperation::kDIV); - CHECK_NULL_RETURN(rsqrt_layer); - auto out_tensor = rsqrt_layer->getOutput(0); - CHECK_NULL_RETURN(out_tensor); - this->layer_ = rsqrt_layer; - ctx->RegisterTensor(ITensorHelper{out_tensor, input(ctx, 0).format_, input(ctx, 0).same_format_}, - out_tensors_[0].Name()); - return RET_OK; -} -REGISTER_TENSORRT_CREATOR(ops::kNameRsqrt, RsqrtTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/rsqrt_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/rsqrt_tensorrt.h deleted file mode 100644 index 790bd459a9886424b7e8db0d0e878a66d52fc777..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/rsqrt_tensorrt.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_RSQRT_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_RSQRT_TENSORRT_H_ -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" - -namespace mindspore::lite { -class RsqrtTensorRT : public TensorRTOp { - public: - RsqrtTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, const std::string &name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~RsqrtTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; - - private: - int RunAsTrtOps(TensorRTContext *ctx); -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_RSQRT_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/scale_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/scale_tensorrt.cc deleted file mode 100644 index f735ec3968d47b968648f101fd1961f1d1d46021..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/scale_tensorrt.cc +++ /dev/null @@ -1,202 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/scale_tensorrt.h" -#include -#include -#include "src/extendrt/delegate/tensorrt/op/activation_tensorrt.h" -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "infer/cxx_api/scale_fusion.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" - -namespace mindspore::lite { -constexpr int SCALE_INDEX = 1; -constexpr int SHIFT_INDEX = 2; -constexpr int POWER_INDEX = 3; - -int ScaleTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != INPUT_SIZE2 && in_tensors.size() != INPUT_SIZE3 && in_tensors.size() != INPUT_SIZE4) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is: " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is: " << out_tensors.size(); - return RET_ERROR; - } - return RET_OK; -} - -int ScaleTensorRT::AddInnerOp(TensorRTContext *ctx) { - CHECK_NULL_RETURN(ctx); - out_format_ = input(ctx, 0).format_; - out_same_format_ = input(ctx, 0).same_format_; - - auto scale_op = AsOps(); - CHECK_NULL_RETURN(scale_op); - ActivationType activation_type = ActivationType::NO_ACTIVATION; - if (scale_op->HasAttr(ops::kActivationType)) { - activation_type = scale_op->get_activation_type(); - } - // mode of scale - axis_ = scale_op->get_axis(); - int input_nbdims = input(ctx, 0).trt_tensor_->getDimensions().nbDims; - if (input_nbdims < 0 || (axis_ < 0 && input_nbdims + axis_ < 0)) { - MS_LOG(ERROR) << "Not support axis " << axis_ << " or input dims " << input_nbdims << " for op " << op_name_; - return RET_ERROR; - } - axis_ = axis_ < 0 ? static_cast(input_nbdims + axis_) : axis_; - - mode_ = GetScaleMode(input(ctx, 0).trt_tensor_, axis_); - MS_LOG(DEBUG) << "before transpose " << GetTensorFormat(input(ctx, 0)); - nvinfer1::ITensor *scale_in_tensor = PreProcessInputTensor(ctx); - if (scale_in_tensor == nullptr) { - MS_LOG(ERROR) << "PreProcessInputTensor failed: " << op_name_; - return RET_ERROR; - } - MS_LOG(DEBUG) << "after transpose " << GetTensorFormat(scale_in_tensor, out_format_, out_same_format_); - - nvinfer1::ITensor *op_out_tensor{nullptr}; - if (scale_in_tensor->getDimensions().nbDims == DIMENSION_4D && mode_ != nvinfer1::ScaleMode::kCHANNEL) { - op_out_tensor = RunAs4DimsScale(ctx, scale_in_tensor); - } else { - op_out_tensor = RunAsMutiDimsScale(ctx, scale_in_tensor); - } - CHECK_NULL_RETURN(op_out_tensor); - - // add activation - if (activation_type != ActivationType::NO_ACTIVATION) { - auto activation_layer = - ActivationTensorRT::AddActivation(ctx, activation_type, 0, 0, 0, op_out_tensor, op_name_, device_id_); - CHECK_NULL_RETURN(activation_layer); - activation_layer->setName((op_name_ + "_activation").c_str()); - op_out_tensor = activation_layer->getOutput(0); - } - - auto output_helper = ITensorHelper{op_out_tensor, out_format_, out_same_format_}; - ctx->RegisterTensor(output_helper, out_tensors_[0].Name()); - MS_LOG(DEBUG) << "output " << GetTensorFormat(output_helper); - return RET_OK; -} - -nvinfer1::ITensor *ScaleTensorRT::PreProcessInputTensor(TensorRTContext *ctx) { - nvinfer1::ITensor *scale_in_tensor = input(ctx, 0).trt_tensor_; - return scale_in_tensor; -} - -nvinfer1::ScaleMode ScaleTensorRT::GetScaleMode(nvinfer1::ITensor *input, int64_t axis) { - nvinfer1::ScaleMode mode = nvinfer1::ScaleMode::kUNIFORM; - auto input_data_shape = ConvertMSShape(input->getDimensions()); - auto input_weight_shape = in_tensors_[1].Shape(); - int total = std::accumulate(input_data_shape.begin(), input_data_shape.end(), 1, std::multiplies()); - if (input_weight_shape.size() == 0 || (input_weight_shape.size() == 1 && input_weight_shape[0] == 1)) { - mode = nvinfer1::ScaleMode::kUNIFORM; - } else if ((axis < static_cast(input_data_shape.size()) && input_weight_shape.size() == 1 && - input_data_shape[axis] == input_weight_shape[0]) || - (input_data_shape.size() == DIMENSION_4D && axis == DIMENSION_3D)) { - mode = nvinfer1::ScaleMode::kCHANNEL; - } else if (input_weight_shape.size() == 1 && input_weight_shape[0] == total) { - mode = nvinfer1::ScaleMode::kELEMENTWISE; - } else { - MS_LOG(WARNING) << "ScaleMode create failed: " << op_name_; - return mode; - } - MS_LOG(DEBUG) << op_name_ << " ScaleMode(UNIFORM 0, CHANNEL 1, ELEMENTWISE 2): " << static_cast(mode); - return mode; -} - -nvinfer1::ITensor *ScaleTensorRT::RunAs4DimsScale(TensorRTContext *ctx, nvinfer1::ITensor *scale_in_tensor) { - bool nd = false; - // (input * scale + shift) ^ power - nvinfer1::Weights power{nvinfer1::DataType::kFLOAT, nullptr, 0}; - nvinfer1::Weights shift{nvinfer1::DataType::kFLOAT, nullptr, 0}; - nvinfer1::Weights scale{nvinfer1::DataType::kFLOAT, nullptr, 0}; - if (in_tensors_.size() > SCALE_INDEX) { - scale.values = in_tensors_[SCALE_INDEX].MutableData(); - MS_ASSERT(scale.values); - scale.count = in_tensors_[SCALE_INDEX].ElementNum(); - scale.type = ConvertDataType(in_tensors_[SCALE_INDEX].DataType()); - shift.type = scale.type; - power.type = scale.type; - nd = in_tensors_[1].Shape().size() == 1 ? false : true; - } - if (in_tensors_.size() > SHIFT_INDEX) { - shift.values = in_tensors_[SHIFT_INDEX].MutableData(); - MS_ASSERT(shift.values); - shift.count = in_tensors_[SHIFT_INDEX].ElementNum(); - } - if (in_tensors_.size() > POWER_INDEX) { - power.values = in_tensors_[POWER_INDEX].MutableData(); - MS_ASSERT(power.values); - power.count = in_tensors_[POWER_INDEX].ElementNum(); - } - nvinfer1::IScaleLayer *cal_layer = nullptr; - - if (nd) { - MS_LOG(WARNING) << "multi dims ScaleMode enter"; - cal_layer = ctx->network()->addScaleNd(*scale_in_tensor, mode_, shift, scale, power, axis_); - } else { - cal_layer = ctx->network()->addScale(*scale_in_tensor, mode_, shift, scale, power); - } - - if (cal_layer == nullptr) { - MS_LOG(ERROR) << "addScaleNd failed for: " << op_name_; - return nullptr; - } - cal_layer->setName(op_name_.c_str()); - this->layer_ = cal_layer; - return cal_layer->getOutput(0); -} - -nvinfer1::ITensor *ScaleTensorRT::RunAsMutiDimsScale(TensorRTContext *ctx, nvinfer1::ITensor *scale_in_tensor) { - auto expect_shape = ConvertMSShape(scale_in_tensor->getDimensions()); - auto scale_tensor = ConvertConstantTensorWithDims(ctx, in_tensors_[1], expect_shape, op_name_); - if (scale_tensor == nullptr) { - MS_LOG(ERROR) << "ConvertConstantTensorWithDims failed for " << op_name_; - return nullptr; - } - auto mul_layer = - ctx->network()->addElementWise(*scale_in_tensor, *scale_tensor, nvinfer1::ElementWiseOperation::kPROD); - if (mul_layer == nullptr) { - MS_LOG(ERROR) << "add mul failed for " << op_name_; - return nullptr; - } - mul_layer->setName((op_name_ + "_scale").c_str()); - layer_ = mul_layer; - nvinfer1::ITensor *out_tensor = mul_layer->getOutput(0); - // add shift - if (in_tensors_.size() >= INPUT_SIZE3) { - auto shift_tensor = ConvertConstantTensorWithDims(ctx, in_tensors_[SHIFT_INDEX], expect_shape, op_name_); - if (shift_tensor == nullptr) { - MS_LOG(ERROR) << "ConvertConstantTensorWithDims failed for " << op_name_; - return nullptr; - } - auto shift_layer = ctx->network()->addElementWise(*out_tensor, *shift_tensor, nvinfer1::ElementWiseOperation::kSUM); - if (shift_layer == nullptr) { - MS_LOG(ERROR) << "add bias failed for " << op_name_; - return nullptr; - } - shift_layer->setName((op_name_ + "_shift").c_str()); - out_tensor = shift_layer->getOutput(0); - } - if (in_tensors_.size() == INPUT_SIZE4) { - MS_LOG(WARNING) << op_name_ << " has power"; - return nullptr; - } - return out_tensor; -} -REGISTER_TENSORRT_CREATOR(ops::kNameScaleFusion, ScaleTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/scale_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/scale_tensorrt.h deleted file mode 100644 index 1de70f2a31ac398d64516f8bd5b71cc578e52dc9..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/scale_tensorrt.h +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_SCALE_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_SCALE_TENSORRT_H_ -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" - -using mindspore::lite::RET_ERROR; -using mindspore::lite::RET_OK; -namespace mindspore::lite { -class ScaleTensorRT : public TensorRTOp { - public: - ScaleTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~ScaleTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - bool IsWeightInputHanledInner() const override { return true; } - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; - - private: - nvinfer1::ScaleMode GetScaleMode(nvinfer1::ITensor *, int64_t axis); - - nvinfer1::ITensor *PreProcessInputTensor(TensorRTContext *ctx); - - nvinfer1::ITensor *RunAs4DimsScale(TensorRTContext *ctx, nvinfer1::ITensor *scale_in_tensor); - - nvinfer1::ITensor *RunAsMutiDimsScale(TensorRTContext *ctx, nvinfer1::ITensor *scale_in_tensor); - - Format out_format_; - - bool out_same_format_{true}; - - nvinfer1::ScaleMode mode_; - - int64_t axis_{0}; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_SCALE_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/scatternd_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/scatternd_tensorrt.cc deleted file mode 100644 index ea1061f76506b017fa7ffbe91dbd7fc0a698b787..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/scatternd_tensorrt.cc +++ /dev/null @@ -1,86 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/scatternd_tensorrt.h" -#include -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "infer/scatter_nd_update.h" -#include "infer/tensor_scatter_update.h" -#include "infer/tensor_scatter_add.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" - -namespace mindspore::lite { -int ScatterNdTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { -#if TRT_VERSION_GE(8, 2) - if (in_tensors.size() != INPUT_SIZE3) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size() << " : " << op_name_; - return RET_ERROR; - } - - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size() << " : " << op_name_; - return RET_ERROR; - } - return RET_OK; -#else - MS_LOG(WARNING) << "low TensorRT version don't support Scatter op, please upgrade TensorRT version to 8.2 or higher"; - return RET_ERROR; -#endif -} - -int ScatterNdTensorRT::AddInnerOp(TensorRTContext *ctx) { -#if TRT_VERSION_GE(8, 2) - ITensorHelper scatter_input = input(ctx, 0); - if (in_tensors_[0].IsConst() && scatter_input.trt_tensor_ == nullptr) { - scatter_input.trt_tensor_ = lite::ConvertConstantTensor(ctx, in_tensors_[0], op_name_); - scatter_input.format_ = Format::NCHW; - ctx->RegisterTensor(scatter_input, in_tensors_[0].Name()); - } - ITensorHelper indices_helper; - int ret = PreprocessInputs2SameDim(ctx, input(ctx, 1), &indices_helper); - if (ret != RET_OK || indices_helper.trt_tensor_ == nullptr) { - MS_LOG(ERROR) << "PreprocessInputs2SameDim indices tensor failed for " << op_name_; - return ret; - } - ITensorHelper updates_helper; - ret = PreprocessInputs2SameDim(ctx, input(ctx, INPUT_SIZE2), &updates_helper); - if (ret != RET_OK || updates_helper.trt_tensor_ == nullptr) { - MS_LOG(ERROR) << "PreprocessInputs2SameDim update tensor failed for " << op_name_; - return ret; - } - - nvinfer1::IScatterLayer *scatter_layer = ctx->network()->addScatter( - *scatter_input.trt_tensor_, *indices_helper.trt_tensor_, *updates_helper.trt_tensor_, nvinfer1::ScatterMode::kND); - if (scatter_layer == nullptr) { - MS_LOG(ERROR) << "addScatter failed for TensorRT."; - return RET_ERROR; - } - - nvinfer1::ITensor *out_tensor = scatter_layer->getOutput(0); - ctx->RegisterTensor(ITensorHelper{out_tensor, scatter_input.format_, scatter_input.same_format_}, - out_tensors_[0].Name()); - this->layer_ = scatter_layer; - return RET_OK; -#else - MS_LOG(WARNING) << "low TensorRT version don't support Scatter op, please upgrade TensorRT version to 8.2 or higher"; - return RET_ERROR; -#endif -} -REGISTER_TENSORRT_CREATOR(ops::kNameScatterNdUpdate, ScatterNdTensorRT) -REGISTER_TENSORRT_CREATOR(ops::kNameTensorScatterUpdate, ScatterNdTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/scatternd_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/scatternd_tensorrt.h deleted file mode 100644 index 0e70ae35dc3ae06b962037c2a91dc776995b7cc5..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/scatternd_tensorrt.h +++ /dev/null @@ -1,38 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_SCATTERND_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_SCATTERND_TENSORRT_H_ -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" - -namespace mindspore::lite { -class ScatterNdTensorRT : public TensorRTOp { - public: - ScatterNdTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~ScatterNdTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_SCATTERND_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/shape_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/shape_tensorrt.cc deleted file mode 100644 index 941d374e429d1743f95762f0620810abb1ae4087..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/shape_tensorrt.cc +++ /dev/null @@ -1,56 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/shape_tensorrt.h" -#include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "infer/dynamic_shape.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_d.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" - -namespace mindspore::lite { -int ShapeTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); - return RET_ERROR; - } - return RET_OK; -} -int ShapeTensorRT::AddInnerOp(TensorRTContext *ctx) { - if (ctx == nullptr || ctx->network() == nullptr) { - MS_LOG(ERROR) << "context or network is invalid"; - return RET_ERROR; - } - nvinfer1::ITensor *shape_input = input(ctx, 0).trt_tensor_; - - nvinfer1::IShapeLayer *shape_layer = ctx->network()->addShape(*shape_input); - - if (shape_layer == nullptr) { - MS_LOG(ERROR) << "add shape op failed for TensorRT."; - return RET_ERROR; - } - shape_layer->setName(op_name_.c_str()); - ctx->RegisterTensor(ITensorHelper{shape_layer->getOutput(0), Format::NCHW, true}, out_tensors_[0].Name()); - this->layer_ = shape_layer; - return RET_OK; -} -REGISTER_TENSORRT_CREATOR(ops::kNameShape, ShapeTensorRT) -REGISTER_TENSORRT_CREATOR(ops::kNameDynamicShape, ShapeTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/shape_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/shape_tensorrt.h deleted file mode 100644 index 749e0329de7461b871171a01adc9bcb9138d7591..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/shape_tensorrt.h +++ /dev/null @@ -1,37 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_SHAPE_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_SHAPE_TENSORRT_H_ -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" - -namespace mindspore::lite { -class ShapeTensorRT : public TensorRTOp { - public: - ShapeTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~ShapeTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_SHAPE_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/shuffle_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/shuffle_tensorrt.cc deleted file mode 100644 index 3d6e81e62b80765db11778429943e29cf5ea0e43..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/shuffle_tensorrt.cc +++ /dev/null @@ -1,410 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/shuffle_tensorrt.h" -#include -#include -#include -#include -#include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "infer/unsqueeze.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_b.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_e.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_f.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_u.h" - -namespace mindspore::lite { -int ShuffleTensorRT::IsSqueezeSupport() { - constexpr size_t input_count_without_constant = 1; - constexpr size_t input_count_with_constant = 2; - if (in_tensors_.size() == input_count_without_constant) { - auto squeeze_op = AsOps(); - if (squeeze_op == nullptr) { - MS_LOG(ERROR) << "SqueezeOp convert failed"; - return RET_ERROR; - } - param_axis_ = squeeze_op->get_axis(); - } else if (in_tensors_.size() == input_count_with_constant) { - if (!in_tensors_[1].IsConst()) { - MS_LOG(ERROR) << "Expect input 1 to be const when input size is 2, type: " << type_ << ", op: " << op_name_; - return RET_ERROR; - } - auto axis = ConvertTensorAsIntVector(in_tensors_[1]); - std::copy(axis.begin(), axis.end(), std::back_inserter(param_axis_)); - } else { - MS_LOG(ERROR) << "Unsupported in_tensors size " << in_tensors_.size() << " of " << type_; - return RET_ERROR; - } - if (param_axis_.empty()) { - MS_LOG(WARNING) << op_name_ << " is a full dim squeeze, don't support dynamic input shape."; - dynamic_shape_params_.support_dynamic_ = false; - dynamic_shape_params_.support_hw_dynamic_ = false; - } - return RET_OK; -} -int ShuffleTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (type_ == ops::kNameFlatten || type_ == ops::kNameUnsqueeze) { - if (in_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported in_tensors size " << in_tensors.size() << " of " << type_; - return RET_ERROR; - } - } else if (type_ == ops::kNameSqueeze) { - return IsSqueezeSupport(); - } else if (type_ == ops::kNameReshape) { - if (in_tensors.size() != INPUT_SIZE2) { - MS_LOG(ERROR) << "PrimitiveType_Transpose Unsupported in_tensors size: " << in_tensors.size(); - return RET_ERROR; - } - dynamic_shape_params_.support_hw_dynamic_ = false; - } else if (type_ == ops::kNameTranspose || type_ == ops::kNameExpandDims) { - if (in_tensors.size() != INPUT_SIZE2) { - MS_LOG(ERROR) << "PrimitiveType_Transpose Unsupported in_tensors size: " << in_tensors.size(); - return RET_ERROR; - } - if (!in_tensors[1].IsConst()) { - MS_LOG(ERROR) << "Unsupported shape tensor of " << type_; - return RET_ERROR; - } - } else if (type_ == ops::kNameBroadcastTo) { - if (in_tensors.size() > INPUT_SIZE2) { - MS_LOG(ERROR) << "PrimitiveType_Transpose Unsupported in_tensors size: " << in_tensors.size(); - return RET_ERROR; - } - } else { - MS_LOG(ERROR) << "Unsupported op type:" << type_; - return RET_ERROR; - } - return RET_OK; -} - -int ShuffleTensorRT::AddInnerOp(TensorRTContext *ctx) { - if (ctx == nullptr || ctx->network() == nullptr) { - MS_LOG(ERROR) << "context or network is invalid"; - return RET_ERROR; - } - ctx_ = ctx; - - int ret = InputTensorPreprocess(ctx); - if (ret != RET_OK || shuffler_input_ == nullptr) { - MS_LOG(ERROR) << "InputTensorPreprocess failed for " << op_name_; - return RET_ERROR; - } - - nvinfer1::IShuffleLayer *shuffle_layer = ctx->network()->addShuffle(*shuffler_input_); - if (shuffle_layer == nullptr) { - MS_LOG(ERROR) << "add Shuffle op failed for TensorRT."; - return RET_ERROR; - } - shuffle_layer->setName(op_name_.c_str()); - this->layer_ = shuffle_layer; - - ret = RET_OK; - if (type_ == ops::kNameUnsqueeze) { - ret = AddUnsqueezeOp(shuffle_layer); - } else if (type_ == ops::kNameSqueeze) { - ret = AddSqueezeOp(shuffle_layer); - } else if (type_ == ops::kNameTranspose) { - ret = AddTransposeOp(shuffle_layer); - } else if (type_ == ops::kNameReshape) { - ret = AddReshapeOp(shuffle_layer); - } else if (type_ == ops::kNameFlatten) { - ret = AddFlattenOp(shuffle_layer); - } else if (type_ == ops::kNameExpandDims) { - ret = AddExpandDimsOp(shuffle_layer); - } else if (type_ == ops::kNameBroadcastTo) { - ret = AddBroadcastToOp(shuffle_layer); - } else { - MS_LOG(ERROR) << "Unsupported op type for " << op_name_; - return RET_ERROR; - } - if (ret != RET_OK) { - MS_LOG(ERROR) << "AddOp failed for " << op_name_; - return ret; - } - - if (shuffler_output_ == nullptr) { - MS_LOG(ERROR) << "output tensor create failed for " << op_name_; - return RET_ERROR; - } - auto output_helper = ITensorHelper{shuffler_output_, out_format_, true}; - ctx->RegisterTensor(output_helper, out_tensors_[0].Name()); - MS_LOG(DEBUG) << "output " << GetTensorFormat(output_helper); - return RET_OK; -} - -int ShuffleTensorRT::InputTensorPreprocess(TensorRTContext *ctx) { - auto input_0 = in_tensors_[0]; - if (!ctx->HasTensor(input_0.Name()) && input_0.IsConst()) { - shuffler_input_ = lite::ConvertConstantTensor(ctx_, input_0, op_name_); - out_format_ = NCHW; - ctx->RegisterTensor({shuffler_input_}, input_0.Name()); - return RET_OK; - } - shuffler_input_ = input(ctx, 0).trt_tensor_; - MS_LOG(DEBUG) << "before transpose " << GetTensorFormat(input(ctx, 0)); - out_format_ = input(ctx, 0).format_; - - MS_LOG(DEBUG) << "after transpose " << GetTensorFormat(shuffler_input_, out_format_, true); - return RET_OK; -} - -int ShuffleTensorRT::AddSqueezeOp(nvinfer1::IShuffleLayer *shuffle_layer) { - // axis - auto squeeze_shape = shuffler_input_->getDimensions(); - std::vector new_shape(squeeze_shape.d, squeeze_shape.d + squeeze_shape.nbDims); - if (param_axis_.empty()) { - MS_LOG(WARNING) << op_name_ << " has null axis."; - for (int i = SizeToInt(new_shape.size()) - 1; i >= 0; i--) { - if (new_shape[i] == 1) { - new_shape.erase(new_shape.begin() + i); - } - } - if (new_shape.empty()) { - new_shape = {1}; - } - } else { - for (int i = SizeToInt(param_axis_.size()) - 1; i >= 0; i--) { - if (param_axis_[i] >= SizeToInt(new_shape.size()) || new_shape[param_axis_[i]] != 1) { - MS_LOG(WARNING) << "squeeze_shape value at " << i << " is " << param_axis_[i] << ", need check " << op_name_; - } - if (param_axis_[i] < 0) { - param_axis_[i] += shuffler_input_->getDimensions().nbDims; - } - new_shape.erase(new_shape.begin() + param_axis_[i]); - } - } - - std::vector subscripts(shuffler_input_->getDimensions().nbDims); - std::iota(subscripts.begin(), subscripts.end(), 0); - auto p = std::remove_if(subscripts.begin(), subscripts.end(), [&](int x) { - return std::find(param_axis_.begin(), param_axis_.end(), x) != param_axis_.end(); - }); - subscripts.resize(p - subscripts.begin()); - bool anyValuesUnknown = std::any_of(new_shape.begin(), new_shape.end(), [](int shape) { return shape == -1; }); - if (!anyValuesUnknown) { - nvinfer1::Dims squeeze_dims = lite::ConvertCudaDims(new_shape); - shuffle_layer->setReshapeDimensions(squeeze_dims); - } else { - auto squeeze_shape_tensor = ctx_->network()->addShape(*shuffler_input_)->getOutput(0); - auto subscripts_tensor = ctx_->ConvertTo1DTensor(subscripts); - auto newDims = ctx_->network()->addGather(*squeeze_shape_tensor, *subscripts_tensor, 0)->getOutput(0); - shuffle_layer->setInput(1, *newDims); - } - shuffler_output_ = shuffle_layer->getOutput(0); - return shuffler_output_ == nullptr ? RET_ERROR : RET_OK; -} - -int ShuffleTensorRT::AddUnsqueezeOp(nvinfer1::IShuffleLayer *shuffle_layer) { - // Unsqueeze - auto unsqueeze_op = AsOps(); - if (unsqueeze_op == nullptr) { - MS_LOG(ERROR) << "AddUnsqueezeOp convert failed"; - return RET_ERROR; - } - // axis - param_axis_ = unsqueeze_op->get_axis(); - if (param_axis_.empty()) { - MS_LOG(ERROR) << "axis is invalid for " << op_name_; - return RET_ERROR; - } - if (param_axis_.size() != 1) { - MS_LOG(WARNING) << op_name_ << " has unsqueeze axis size: " << param_axis_.size(); - } - nvinfer1::ITensor *expand_input = shuffler_input_; - if (input(ctx_, 0).is_tensor == true) { - for (size_t i = 0; i < param_axis_.size(); i++) { - expand_input = ExpandDim(ctx_, expand_input, param_axis_[i]); - } - } - shuffler_output_ = expand_input; - return shuffler_output_ == nullptr ? RET_ERROR : RET_OK; -} - -int ShuffleTensorRT::AddTransposeOp(nvinfer1::IShuffleLayer *shuffle_layer) { - if (shuffler_input_->getDimensions().nbDims != in_tensors_[1].ElementNum()) { - MS_LOG(WARNING) << "transpose perm is invalid for input, ignore " << op_name_; - shuffler_output_ = shuffler_input_; - return RET_OK; - } - auto transpose_op = AsOps(); - if (transpose_op == nullptr) { - MS_LOG(ERROR) << "AddTransposeOp convert failed"; - return RET_ERROR; - } - // perm - auto perm_ternsor = in_tensors_[1]; - if (!perm_ternsor.IsConst()) { - MS_LOG(ERROR) << "AddTransposeOp perm_ternsor data is invalid: " << op_name_; - return RET_ERROR; - } - - nvinfer1::Permutation perm{}; - if (perm_ternsor.DataType() == DataType::kNumberTypeInt64) { - auto perm_data = reinterpret_cast(perm_ternsor.Data()); - for (int64_t i = 0; i < perm_ternsor.ElementNum(); i++) { - perm.order[i] = perm_data[i]; - } - } else if (perm_ternsor.DataType() == DataType::kNumberTypeInt32) { - auto perm_data = reinterpret_cast(perm_ternsor.Data()); - for (int64_t i = 0; i < perm_ternsor.ElementNum(); i++) { - perm.order[i] = perm_data[i]; - } - } else { - MS_LOG(ERROR) << op_name_ << " perm tensor data type is " << static_cast(perm_ternsor.DataType()); - return RET_ERROR; - } - - shuffle_layer->setFirstTranspose(perm); - - shuffler_output_ = shuffle_layer->getOutput(0); - return RET_OK; -} - -int ShuffleTensorRT::AddReshapeOp(nvinfer1::IShuffleLayer *shuffle_layer) { - auto &shape_tensor = in_tensors_[1]; - if (shape_tensor.IsConst()) { - // static shuffle layer - auto reshape_dims = lite::ConvertCudaDims(shape_tensor); - if (reshape_dims.nbDims == -1) { - MS_LOG(ERROR) << "ConvertCudaDims failed for " << op_name_; - return RET_ERROR; - } - shuffle_layer->setReshapeDimensions(reshape_dims); - } else { - if (in_tensors_.size() != INPUT_SIZE2) { - MS_LOG(ERROR) << "invalid shape tensor for reshape " << op_name_; - return RET_ERROR; - } - shuffle_layer->setInput(1, *input(ctx_, 1).trt_tensor_); - } - shuffler_output_ = shuffle_layer->getOutput(0); - return RET_OK; -} - -int ShuffleTensorRT::AddFlattenOp(nvinfer1::IShuffleLayer *shuffle_layer) { - nvinfer1::Dims flatten_dims; - nvinfer1::Dims dims = input(ctx_, 0).trt_tensor_->getDimensions(); - flatten_dims.nbDims = DIMENSION_2D; - flatten_dims.d[0] = dims.d[0] == -1 ? 0 : dims.d[0]; - flatten_dims.d[1] = std::accumulate(dims.d + 1, dims.d + dims.nbDims, 1, std::multiplies()); - if (flatten_dims.d[1] <= 0) { - MS_LOG(ERROR) << op_name_ << "infer shape failed"; - } - shuffle_layer->setReshapeDimensions(flatten_dims); - shuffler_output_ = shuffle_layer->getOutput(0); - return RET_OK; -} - -int ShuffleTensorRT::AddExpandDimsOp(nvinfer1::IShuffleLayer *shuffle_layer) { - if (!input(ctx_, 0).is_tensor) { - shuffler_output_ = shuffler_input_; - return RET_OK; - } - auto axis_vec = ConvertTensorAsIntVector(in_tensors_[1]); - if (axis_vec.size() != 1) { - MS_LOG(ERROR) << "Failed to get axis input, dim count " << axis_vec.size() << ", node: " << op_name_; - return RET_ERROR; - } - int axis = axis_vec[0]; - if (axis > (-1 - shuffler_input_->getDimensions().nbDims) && axis < -1) { - axis = shuffler_input_->getDimensions().nbDims + axis + 1; - } - - shuffler_output_ = ExpandDim(ctx_, shuffler_input_, axis); - return shuffler_output_ == nullptr ? RET_ERROR : RET_OK; -} - -int ShuffleTensorRT::AddBroadcastToOp(nvinfer1::IShuffleLayer *shuffle_layer) { - if (in_tensors_.size() > 1 && !in_tensors_[1].IsConst()) { -#if TRT_VERSION_GE(7, 2) - auto shape_tensor = input(ctx_, 1).trt_tensor_; - auto input = ctx_->network()->addShape(*shuffler_input_)->getOutput(0); - auto one_tensor = ctx_->ConvertTo1DTensor(1); - auto eq_one = - ctx_->network()->addElementWise(*shape_tensor, *one_tensor, nvinfer1::ElementWiseOperation::kEQUAL)->getOutput(0); - auto int_eq_one = TRTTensorCast(ctx_, eq_one, nvinfer1::DataType::kINT32, op_name_ + "_cast_int_one"); - if (int_eq_one == nullptr) { - MS_LOG(ERROR) << "int_eq_one is nullptr!"; - return RET_ERROR; - } - auto x = ctx_->network()->addElementWise(*int_eq_one, *input, nvinfer1::ElementWiseOperation::kPROD)->getOutput(0); - auto zero_tensor = ctx_->ConvertTo1DTensor(0); - auto not_eq_one = - ctx_->network()->addElementWise(*zero_tensor, *int_eq_one, nvinfer1::ElementWiseOperation::kEQUAL)->getOutput(0); - auto int_not_eq_one = TRTTensorCast(ctx_, not_eq_one, nvinfer1::DataType::kINT32, op_name_ + "_cast_int_not_one"); - if (int_not_eq_one == nullptr) { - MS_LOG(ERROR) << "int_not_eq_one is nullptr!"; - return RET_ERROR; - } - auto y = ctx_->network() - ->addElementWise(*int_not_eq_one, *shape_tensor, nvinfer1::ElementWiseOperation::kPROD) - ->getOutput(0); - auto new_shape = ctx_->network()->addElementWise(*x, *y, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0); - shuffler_output_ = Broadcast(ctx_, shuffler_input_, new_shape); -#else - MS_LOG(WARNING) - << "low TensorRT version don't support broadcastto op, please upgrade TensorRT version to 7.2 or higher"; - return RET_ERROR; -#endif - } else { - std::vector input_shape; - if (in_tensors_.size() == 1) { - auto broadcast_op = AsOps(); - auto shape_64 = broadcast_op->get_shape(); - std::transform(shape_64.begin(), shape_64.end(), std::back_inserter(input_shape), [](auto x) { return x; }); - } else { - input_shape = ConvertTensorAsIntVector(in_tensors_[1]); - } - if (input_shape.empty()) { - MS_LOG(ERROR) << "Failed to get input shape from const input 1, node: " << op_name_; - return RET_ERROR; - } - - nvinfer1::Dims in_tensor_dims = shuffler_input_->getDimensions(); - auto input_shape_tensor = ctx_->ConvertTo1DTensor(input_shape); - - while (in_tensor_dims.nbDims < static_cast(input_shape.size())) { - shuffler_input_ = ExpandDim(ctx_, shuffler_input_, 0); - if (shuffler_input_->getDimensions().nbDims == -1) { - MS_LOG(ERROR) << "ConvertCudaDims failed for " << op_name_; - return RET_ERROR; - } - shuffle_layer->setReshapeDimensions(shuffler_input_->getDimensions()); - shuffler_input_ = shuffle_layer->getOutput(0); - in_tensor_dims = shuffler_input_->getDimensions(); - } - - auto size_tensor = ctx_->network()->addShape(*shuffler_input_)->getOutput(0); - size_tensor = ctx_->network() - ->addElementWise(*input_shape_tensor, *size_tensor, nvinfer1::ElementWiseOperation::kMAX) - ->getOutput(0); - shuffler_output_ = Broadcast(ctx_, shuffler_input_, size_tensor); - } - return shuffler_output_ == nullptr ? RET_ERROR : RET_OK; -} - -REGISTER_TENSORRT_CREATOR(ops::kNameUnsqueeze, ShuffleTensorRT) -REGISTER_TENSORRT_CREATOR(ops::kNameSqueeze, ShuffleTensorRT) -REGISTER_TENSORRT_CREATOR(ops::kNameReshape, ShuffleTensorRT) -REGISTER_TENSORRT_CREATOR(ops::kNameTranspose, ShuffleTensorRT) -REGISTER_TENSORRT_CREATOR(ops::kNameFlatten, ShuffleTensorRT) -REGISTER_TENSORRT_CREATOR(ops::kNameExpandDims, ShuffleTensorRT) -REGISTER_TENSORRT_CREATOR(ops::kNameBroadcastTo, ShuffleTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/shuffle_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/shuffle_tensorrt.h deleted file mode 100644 index 1377c169e85e9a82e418611a9f04270ed04a20e8..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/shuffle_tensorrt.h +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_SHUFFLE_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_SHUFFLE_TENSORRT_H_ -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" - -namespace mindspore::lite { -class ShuffleTensorRT : public TensorRTOp { - public: - ShuffleTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - ~ShuffleTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; - - private: - int InputTensorPreprocess(TensorRTContext *ctx); - int AddSqueezeOp(nvinfer1::IShuffleLayer *shuffle_layer); - int AddUnsqueezeOp(nvinfer1::IShuffleLayer *shuffle_layer); - int AddTransposeOp(nvinfer1::IShuffleLayer *shuffle_layer); - int AddReshapeOp(nvinfer1::IShuffleLayer *shuffle_layer); - int AddFlattenOp(nvinfer1::IShuffleLayer *shuffle_layer); - int AddExpandDimsOp(nvinfer1::IShuffleLayer *shuffle_layer); - int AddBroadcastToOp(nvinfer1::IShuffleLayer *shuffle_layer); - int IsSqueezeSupport(); - - Format out_format_ = Format::NCHW; - nvinfer1::ITensor *shuffler_input_{nullptr}; - nvinfer1::ITensor *shuffler_output_{nullptr}; - TensorRTContext *ctx_{nullptr}; - std::vector param_axis_; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_SHUFFLE_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/slice_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/slice_tensorrt.cc deleted file mode 100644 index 55f134e1696acf670cfba9e26bf1310d8d7e154b..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/slice_tensorrt.cc +++ /dev/null @@ -1,338 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/slice_tensorrt.h" -#include -#include -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "infer/cxx_api/slice_fusion.h" -#include "infer/crop.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" - -namespace mindspore::lite { -namespace { -class StrideSliceTensorRTUtil final : public SliceTensorRTUtil { - public: - StrideSliceTensorRTUtil() = default; - ~StrideSliceTensorRTUtil() = default; - bool IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override { - if (in_tensors.size() < HAS_AXIS - 1) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return false; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); - return false; - } - if (!in_tensors.at(BEGINS_INDEX).IsConst() || !in_tensors.at(ENDS_INDEX).IsConst()) { - MS_LOG(ERROR) << "invalid input tensor for: " << op_name_; - return false; - } - return true; - } - - bool GetConstInputValue(const std::vector &in_tensors, int *axis_val, int *start_val, int *end_val, - int *stride_val) { - int64_t axis_index = in_tensors.size() == HAS_AXIS ? AXIS_INDEX : -1; - - const auto &begin = in_tensors.at(BEGINS_INDEX); - const auto &stride = in_tensors.back(); - const auto &end = in_tensors.at(ENDS_INDEX); - - if (begin.ElementNum() != 1 || end.ElementNum() != 1 || stride.ElementNum() != 1) { - MS_LOG(ERROR) - << "Only support element number of begin, end and stride to be 1 when this number < input dims number, op: " - << op_name_; - return false; - } - *axis_val = 0; - if (axis_index != -1) { - auto axis_vec = ConvertTensorAsIntVector(in_tensors[axis_index]); - if (axis_vec.size() != 1) { - MS_LOG(ERROR) << "Failed to get axis input, node: " << op_name_ << ", axis count: " << axis_vec.size(); - return false; - } - *axis_val = axis_vec[0]; - } - auto start_vec = ConvertTensorAsIntVector(begin); - auto end_vec = ConvertTensorAsIntVector(end); - auto stride_vec = ConvertTensorAsIntVector(stride); - if (start_vec.size() != 1 || end_vec.size() != 1 || stride_vec.size() != 1) { - MS_LOG(ERROR) << "Failed to get start, end or stride input, node: " << op_name_; - return {}; - } - *start_val = start_vec[0]; - *end_val = end_vec[0]; - *stride_val = stride_vec[0]; - return true; - } - - std::tuple GetSliceParams(const BaseOperatorPtr &base_operator, - const std::vector &in_tensors, - const std::vector &out_tensors, - const ITensorHelper &helper) override { - const TensorInfo &begin = in_tensors.at(BEGINS_INDEX); - const TensorInfo &stride = in_tensors.back(); - const TensorInfo &end = in_tensors.at(ENDS_INDEX); - nvinfer1::Dims start_dims; - nvinfer1::Dims size_dims; - nvinfer1::Dims stride_dims; - - if (begin.ElementNum() == helper.trt_tensor_->getDimensions().nbDims) { - start_dims = lite::ConvertCudaDims(begin); - size_dims.nbDims = start_dims.nbDims; - auto end_dims = lite::ConvertCudaDims(end); - for (int i = 0; i < size_dims.nbDims; i++) { - size_dims.d[i] = end_dims.d[i] - start_dims.d[i]; - } - stride_dims = lite::ConvertCudaDims(stride); - } else { - int axis_value = 0; - int start_value = 0; - int end_value = 0; - int stride_value = 0; - if (!GetConstInputValue(in_tensors, &axis_value, &start_value, &end_value, &stride_value)) { - return {}; - } - auto input_dims = helper.trt_tensor_->getDimensions(); - start_dims.nbDims = input_dims.nbDims; - size_dims.nbDims = input_dims.nbDims; - stride_dims = nvinfer1::Dims{size_dims.nbDims, {}}; - std::fill(start_dims.d, start_dims.d + start_dims.nbDims, 0); - std::fill(stride_dims.d, stride_dims.d + stride_dims.nbDims, 1); - if (start_value < 0) { - start_value = input_dims.d[axis_value] + start_value; - } - for (int i = 0; i < start_dims.nbDims; i++) { - if (i == axis_value) { - start_dims.d[i] = start_value; - stride_dims.d[i] = stride_value; - if (end_value >= 0) { - size_dims.d[i] = std::min(end_value, input_dims.d[i]) - start_dims.d[i]; - } else if (end_value >= -input_dims.d[i]) { - size_dims.d[i] = end_value + input_dims.d[i] - start_dims.d[i]; - } else { - size_dims.d[i] = input_dims.d[i]; - } - } else { - size_dims.d[i] = helper.trt_tensor_->getDimensions().d[i]; - } - } - } - return std::make_tuple(start_dims, size_dims, stride_dims); - } - nvinfer1::ITensor *PostProcess(TensorRTContext *ctx, nvinfer1::ITensor *input, - const std::vector &in_tensors, - const std::vector &out_tensors) override { - if (shrink_axis_ != 0) { - auto shape = ConvertMSShape(input->getDimensions()); - for (int i = SizeToInt(shape.size()) - 1; i >= 0; --i) { - int mask = 1 << i; - if ((shrink_axis_ & mask) != 0) { - shape.erase(shape.begin() + i); - } - } - return shape.empty() ? nullptr : Reshape(ctx, input, shape); - } - return input; - } - void SetShrinkAxis(int shrink_axis) { shrink_axis_ = shrink_axis; } - - private: - int shrink_axis_; -}; - -class SliceFusionTensorRTUtil final : public SliceTensorRTUtil { - public: - SliceFusionTensorRTUtil() = default; - ~SliceFusionTensorRTUtil() = default; - bool IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override { - if (in_tensors.size() != SLICE_INPUT_SIZE) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return false; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); - return false; - } - return true; - } - std::tuple GetSliceParams(const BaseOperatorPtr &base_operator, - const std::vector &in_tensors, - const std::vector &out_tensors, - const ITensorHelper &helper) override { - const auto &begin = in_tensors.at(1); - const auto &size = in_tensors.at(SIZE_INDEX); - - auto start_dims = lite::ConvertCudaDims(begin); - auto size_dims = lite::ConvertCudaDims(size); - for (int i = 0; i < size_dims.nbDims; ++i) { - if (size_dims.d[i] == -1) { - size_dims.d[i] = helper.trt_tensor_->getDimensions().d[i]; - } - } - auto stride_dims = lite::ConvertCudaDims(1, begin.ElementNum()); - - return std::make_tuple(start_dims, size_dims, stride_dims); - } -}; - -class CropTensorRTUtil final : public SliceTensorRTUtil { - public: - CropTensorRTUtil() = default; - ~CropTensorRTUtil() = default; - bool IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override { - if (in_tensors.size() != CROP_INPUT_SIZE) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return false; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); - return false; - } - auto crop_primitive = TensorRTOp::AsOps(base_operator); - if (crop_primitive == nullptr) { - MS_LOG(ERROR) << "Cast primitive to crop fail"; - return false; - } - axis_ = static_cast(crop_primitive->get_axis()); - return true; - } - std::tuple GetSliceParams(const BaseOperatorPtr &base_operator, - const std::vector &in_tensors, - const std::vector &out_tensors, - const ITensorHelper &helper) override { - auto crop_primitive = TensorRTOp::AsOps(base_operator); - auto offsets_ptr = crop_primitive->get_offsets(); - if (offsets_ptr.empty()) { - MS_LOG(ERROR) << "Crop Op do not have offset attr"; - return {}; - } - if (axis_ < 0) { - axis_ += helper.trt_tensor_->getDimensions().nbDims; - } - if (axis_ < 0 || axis_ + SizeToInt(offsets_ptr.size()) != helper.trt_tensor_->getDimensions().nbDims) { - MS_LOG(ERROR) << "axis and offsets not match input tensor shape, axis is " << crop_primitive->get_axis() - << " , offsets size is " << offsets_ptr.size() << " , input size is " - << helper.trt_tensor_->getDimensions().nbDims; - return {}; - } - - std::vector begin(helper.trt_tensor_->getDimensions().nbDims, 0); - for (size_t i = 0; i != offsets_ptr.size(); ++i) { - begin[axis_ + i] = offsets_ptr[i]; - } - - std::vector size(helper.trt_tensor_->getDimensions().nbDims); - for (size_t i = 0; i != size.size(); ++i) { - size[i] = in_tensors.at(1).Shape().at(i); - } - - auto start_dims = lite::ConvertCudaDims(begin); - auto size_dims = lite::ConvertCudaDims(size); - auto stride_dims = lite::ConvertCudaDims(1, begin.size()); - - return std::make_tuple(start_dims, size_dims, stride_dims); - } - - private: - int axis_; -}; -} // namespace - -SliceTensorRT::SliceTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) { - if (type_ == ops::kNameStridedSlice) { - auto slice_fusion_util = std::make_unique(); - auto op = AsOps(); - slice_fusion_util->SetShrinkAxis(op->get_shrink_axis_mask()); - util_ = std::move(slice_fusion_util); - } else if (type_ == ops::kNameSliceFusion) { - util_ = std::make_unique(); - } else if (type_ == ops::kNameCrop) { - util_ = std::make_unique(); - } else { - util_ = nullptr; - } - if (util_ != nullptr) { - util_->op_name_ = op_name_; - } -} - -int SliceTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (!IsShapeKnown()) { - MS_LOG(ERROR) << "Unsupported input tensor unknown shape: " << op_name_; - return RET_ERROR; - } - if (util_ == nullptr) { - MS_LOG(ERROR) << "Unsupported op_type: " << op_name_; - return RET_ERROR; - } - if (!util_->IsSupport(base_operator, in_tensors, out_tensors)) { - return RET_ERROR; - } - dynamic_shape_params_.support_dynamic_ = false; - dynamic_shape_params_.support_hw_dynamic_ = false; - return RET_OK; -} - -int SliceTensorRT::AddInnerOp(TensorRTContext *ctx) { - ITensorHelper slice_input; - int ret = PreprocessInputs2SameDim(ctx, input(ctx, 0), &slice_input); - if (ret != RET_OK || slice_input.trt_tensor_ == nullptr) { - MS_LOG(ERROR) << "PreprocessInputs2SameDim input tensor failed for " << op_name_; - return RET_ERROR; - } - - nvinfer1::Dims start_dims; - nvinfer1::Dims size_dims; - nvinfer1::Dims stride_dims; - std::tie(start_dims, size_dims, stride_dims) = - util_->GetSliceParams(base_operator_, in_tensors_, out_tensors_, slice_input); - if (start_dims.nbDims == -1 || size_dims.nbDims == -1 || stride_dims.nbDims == -1) { - MS_LOG(ERROR) << "ConvertCudaDims failed for " << op_name_; - return RET_ERROR; - } - - nvinfer1::ISliceLayer *slice_layer = - ctx->network()->addSlice(*slice_input.trt_tensor_, start_dims, size_dims, stride_dims); - if (slice_layer == nullptr) { - MS_LOG(ERROR) << "add Slice op failed for TensorRT: " << op_name_; - return RET_ERROR; - } - this->layer_ = slice_layer; - slice_layer->setName(op_name_.c_str()); - nvinfer1::ITensor *out_tensor = slice_layer->getOutput(0); - auto post_tensor = util_->PostProcess(ctx, out_tensor, in_tensors_, out_tensors_); - bool rank_0 = false; - if (post_tensor == nullptr) { - rank_0 = true; - post_tensor = out_tensor; - } - auto helper = ITensorHelper{post_tensor, slice_input.format_, slice_input.same_format_, !rank_0}; - ctx->RegisterTensor(helper, out_tensors_[0].Name()); - MS_LOG(DEBUG) << "slice output : " << GetTensorFormat(helper); - return RET_OK; -} -REGISTER_TENSORRT_CREATOR(ops::kNameCrop, SliceTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/slice_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/slice_tensorrt.h deleted file mode 100644 index f5735ace65025d68fd3d2d9cd508bf48b141e0ee..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/slice_tensorrt.h +++ /dev/null @@ -1,65 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_SLICE_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_SLICE_TENSORRT_H_ -#include -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" - -namespace mindspore::lite { -class SliceTensorRTUtil { - public: - SliceTensorRTUtil() = default; - virtual ~SliceTensorRTUtil() = default; - virtual bool IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) = 0; - virtual std::tuple GetSliceParams( - const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, const ITensorHelper &helper) = 0; - virtual nvinfer1::ITensor *PostProcess(TensorRTContext *ctx, nvinfer1::ITensor *input, - const std::vector &in_tensors, - const std::vector &out_tensors) { - return input; - } - std::string op_name_; -}; - -constexpr int BEGINS_INDEX = 1; -constexpr int ENDS_INDEX = 2; -constexpr int SIZE_INDEX = 2; -constexpr int HAS_AXIS = 5; -constexpr int AXIS_INDEX = 3; -constexpr int CROP_INPUT_SIZE = 2; -constexpr int SLICE_INPUT_SIZE = 3; -class SliceTensorRT : public TensorRTOp { - public: - SliceTensorRT(const BaseOperatorPtr &base_operator, const std::vector &inputs, - const std::vector &outputs, std::string name); - - ~SliceTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; - - private: - std::unique_ptr util_; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_SLICE_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/slicefusion_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/slicefusion_tensorrt.cc deleted file mode 100644 index 8fab16c08d8e0cccc0bbdcb62f99772633d1a51a..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/slicefusion_tensorrt.cc +++ /dev/null @@ -1,132 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/slicefusion_tensorrt.h" -#include -#include -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "infer/cxx_api/slice_fusion.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" - -namespace mindspore::lite { -nvinfer1::ITensor *SliceFusionTensorRT::GetDynamicSliceSize(TensorRTContext *ctx, nvinfer1::ITensor *input, - nvinfer1::Dims start_dims, nvinfer1::Dims size_dims) { - auto in_tensor_shape = ctx->network()->addShape(*input)->getOutput(0); - if (in_tensor_shape == nullptr) { - MS_LOG(ERROR) << "add shape layer of input failed!"; - return nullptr; - } - std::vector shape_tensors; - auto input_dims = input->getDimensions(); - std::vector input_shape_vec; - for (int i = 0; i < input_dims.nbDims; ++i) { - if (input_dims.d[i] == -1) { - if (!input_shape_vec.empty()) { - shape_tensors.push_back(ctx->ConvertTo1DTensor(input_shape_vec)); - input_shape_vec.clear(); - } - auto starts = nvinfer1::Dims{1, {i}}; - auto size = nvinfer1::Dims{1, {1}}; - auto strides = nvinfer1::Dims{1, {1}}; - auto slice_layer = ctx->network()->addSlice(*in_tensor_shape, starts, size, strides); - if (slice_layer == nullptr) { - MS_LOG(ERROR) << "add slice layer failed"; - return nullptr; - } - auto start_tensor = ctx->ConvertTo1DTensor(start_dims.d[i]); - shape_tensors.push_back( - ctx->network() - ->addElementWise(*slice_layer->getOutput(0), *start_tensor, nvinfer1::ElementWiseOperation::kSUB) - ->getOutput(0)); - } else { - input_shape_vec.push_back(size_dims.d[i]); - } - } - if (!input_shape_vec.empty()) { - shape_tensors.push_back(ctx->ConvertTo1DTensor(input_shape_vec)); - } - nvinfer1::ITensor *concat_tensors[shape_tensors.size()]; - for (size_t i = 0; i != shape_tensors.size(); ++i) { - concat_tensors[i] = shape_tensors[i]; - } - auto concat_layer = ctx->network()->addConcatenation(concat_tensors, shape_tensors.size()); - if (concat_layer == nullptr) { - MS_LOG(ERROR) << "add concat layer failed!"; - return nullptr; - } - concat_layer->setAxis(0); - - return concat_layer->getOutput(0); -} - -int SliceFusionTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != SLICE_INPUT_SIZE) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); - return RET_ERROR; - } - dynamic_shape_params_.support_hw_dynamic_ = false; - return RET_OK; -} - -int SliceFusionTensorRT::AddInnerOp(TensorRTContext *ctx) { - ITensorHelper slice_input; - int ret = PreprocessInputs2SameDim(ctx, input(ctx, 0), &slice_input); - if (ret != RET_OK || slice_input.trt_tensor_ == nullptr) { - MS_LOG(ERROR) << "PreprocessInputs2SameDim input tensor failed for " << op_name_; - return RET_ERROR; - } - - const auto &begin = in_tensors_.at(1); - const auto &size = in_tensors_.at(SIZE_INDEX); - - auto start_dims = lite::ConvertCudaDims(begin); - auto size_dims = lite::ConvertCudaDims(size); - nvinfer1::ITensor *size_tensor = nullptr; - for (int i = 0; i < size_dims.nbDims; ++i) { - if (size_dims.d[i] == -1 && !IsDynamicInput(ctx, 0)) { - size_dims.d[i] = slice_input.trt_tensor_->getDimensions().d[i]; - } - } - if (IsDynamicInput(ctx, 0)) { - size_tensor = GetDynamicSliceSize(ctx, slice_input.trt_tensor_, start_dims, size_dims); - size_dims = nvinfer1::Dims{-1}; - } - auto stride_dims = lite::ConvertCudaDims(1, begin.ElementNum()); - - nvinfer1::ISliceLayer *slice_layer = - ctx->network()->addSlice(*slice_input.trt_tensor_, start_dims, size_dims, stride_dims); - if (slice_layer == nullptr) { - MS_LOG(ERROR) << "add Slice op failed for TensorRT: " << op_name_; - return RET_ERROR; - } - if (size_tensor != nullptr) { - slice_layer->setInput(INPUT_SIZE2, *size_tensor); - } - this->layer_ = slice_layer; - slice_layer->setName(op_name_.c_str()); - nvinfer1::ITensor *out_tensor = slice_layer->getOutput(0); - auto helper = ITensorHelper{out_tensor, slice_input.format_, slice_input.same_format_}; - ctx->RegisterTensor(helper, out_tensors_[0].Name()); - MS_LOG(DEBUG) << "slice output : " << GetTensorFormat(helper); - return RET_OK; -} -REGISTER_TENSORRT_CREATOR(ops::kNameSliceFusion, SliceFusionTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/slicefusion_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/slicefusion_tensorrt.h deleted file mode 100644 index f7e610bed2905a4ef1709523da5ba253f5897363..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/slicefusion_tensorrt.h +++ /dev/null @@ -1,44 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_SLICE_FUSION_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_SLICE_FUSION_TENSORRT_H_ -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" - -namespace mindspore::lite { -constexpr int SIZE_INDEX = 2; -constexpr int SLICE_INPUT_SIZE = 3; -class SliceFusionTensorRT : public TensorRTOp { - public: - SliceFusionTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~SliceFusionTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; - - private: - nvinfer1::ITensor *GetDynamicSliceSize(TensorRTContext *ctx, nvinfer1::ITensor *input, nvinfer1::Dims, - nvinfer1::Dims); -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_SLICE_FUSION_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/softmax_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/softmax_tensorrt.cc deleted file mode 100644 index 159b83d5d757da154eccf340fdd362e8cf8a97eb..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/softmax_tensorrt.cc +++ /dev/null @@ -1,94 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/softmax_tensorrt.h" -#include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" - -namespace mindspore::lite { -int SoftMaxTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (type_ == ops::kNameSoftmax) { - auto softmax_op = AsOps(); - auto axis = softmax_op->get_axis(); - axis_val_ = std::vector(axis.begin(), axis.end()); - } - - if (type_ == ops::kNameLogSoftmax) { - auto log_softmax_op = AsOps(); - auto axis = log_softmax_op->get_axis(); - axis_val_ = std::vector(1, axis); - } - - if (axis_val_.size() != 1) { - MS_LOG(ERROR) << "axis needs check"; - return RET_ERROR; - } - - if (in_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); - return RET_ERROR; - } - return RET_OK; -} -int SoftMaxTensorRT::AddInnerOp(TensorRTContext *ctx) { - if (ctx == nullptr || ctx->network() == nullptr) { - MS_LOG(ERROR) << "network is invalid"; - return RET_ERROR; - } - nvinfer1::ISoftMaxLayer *softmax_layer_ = AddSoftMaxOp(ctx); - if (softmax_layer_ == nullptr) { - MS_LOG(ERROR) << "add softmax op failed for TensorRT."; - return RET_ERROR; - } - softmax_layer_->setName((op_name_ + "_softmax").c_str()); - this->layer_ = softmax_layer_; - - nvinfer1::ITensor *out_tensor = softmax_layer_->getOutput(0); - if (out_tensor == nullptr) { - MS_LOG(ERROR) << "softmax output tensor create failed for TensorRT."; - return RET_ERROR; - } - if (type_ == ops::kNameLogSoftmax) { - out_tensor = ctx->network()->addUnary(*out_tensor, nvinfer1::UnaryOperation::kLOG)->getOutput(0); - } - ctx->RegisterTensor(ITensorHelper{out_tensor, input(ctx, 0).format_, input(ctx, 0).same_format_}, - out_tensors_[0].Name()); - return RET_OK; -} - -nvinfer1::ISoftMaxLayer *SoftMaxTensorRT::AddSoftMaxOp(TensorRTContext *ctx) { - nvinfer1::ISoftMaxLayer *current_layer_ = ctx->network()->addSoftMax(*input(ctx, 0).trt_tensor_); - if (current_layer_ == nullptr) { - MS_LOG(ERROR) << "add softmax op failed for TensorRT."; - return nullptr; - } - - int64_t axis_format_value = - (axis_val_[0] == -1) ? input(ctx, 0).trt_tensor_->getDimensions().nbDims - 1 : axis_val_[0]; - uint32_t axis_bit = 1 << axis_format_value; - MS_LOG(DEBUG) << op_name_ << " axis_value is " << axis_format_value << ", set axis to " << axis_bit; - current_layer_->setAxes(axis_bit); - return current_layer_; -} -REGISTER_TENSORRT_CREATOR(ops::kNameSoftmax, SoftMaxTensorRT) -REGISTER_TENSORRT_CREATOR(ops::kNameLogSoftmax, SoftMaxTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/spacetobatch_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/spacetobatch_tensorrt.cc deleted file mode 100644 index 5c292cff5c3d0c54f362890cbe2b451866760e6b..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/spacetobatch_tensorrt.cc +++ /dev/null @@ -1,158 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/spacetobatch_tensorrt.h" -#include -#include -#include -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "NvInferRuntimeCommon.h" -#include "kernel/gpu/cuda_impl/cuda_ops/spacetobatch_impl.cuh" -#include "infer/space_to_batch_nd.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" - -namespace mindspore::lite { -int SpaceToBatchTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != INPUT_SIZE3) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - - if (out_tensors.size() < 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); - return RET_ERROR; - } - return RET_OK; -} - -int SpaceToBatchTensorRT::AddInnerOp(TensorRTContext *ctx) { - nvinfer1::ITensor *input_tensor = input(ctx, 0).trt_tensor_; - auto block_size_vec = ConvertTensorAsIntVector(in_tensors_[1]); - constexpr size_t block_input_elem_count = 2; - if (block_size_vec.size() != block_input_elem_count) { - MS_LOG(ERROR) << "Failed to get block input, block size " << block_size_vec.size() << ", node: " << op_name_; - return RET_ERROR; - } - int bh = block_size_vec[0]; - int bw = block_size_vec[1]; - if (bh != bw) { - MS_LOG(ERROR) << "block_h not equal block_w " << op_name_; - return RET_ERROR; - } - auto pad_vec = ConvertTensorAsIntVector(in_tensors_[INPUT_SIZE2]); - constexpr size_t pad_input_elem_count = 4; - if (pad_vec.size() != pad_input_elem_count) { - MS_LOG(ERROR) << "Failed to get pad input, pad size " << pad_vec.size() << ", node: " << op_name_; - return RET_ERROR; - } - int ph0 = pad_vec[0]; - int ph1 = pad_vec[1]; - int pw0 = pad_vec[INPUT_SIZE2]; - int pw1 = pad_vec[INPUT_SIZE3]; - - auto plugin = std::make_shared(input_tensor->getName(), bh, ph0, ph1, pw0, pw1, device_id_); - if (plugin == nullptr) { - MS_LOG(ERROR) << "add spacetobatch plugin failed for" << op_name_; - return RET_ERROR; - } - nvinfer1::ITensor *inputTensors[] = {input_tensor}; - nvinfer1::IPluginV2Layer *space2batch_opt_layer = ctx->network()->addPluginV2(inputTensors, 1, *plugin); - if (space2batch_opt_layer == nullptr) { - MS_LOG(ERROR) << "add spacetobatch op failed for TensorRT."; - return RET_ERROR; - } - space2batch_opt_layer->setName(op_name_.c_str()); - nvinfer1::ITensor *out_tensor = space2batch_opt_layer->getOutput(0); - ctx->RegisterTensor(ITensorHelper{out_tensor, Format::NCHW, true}, out_tensors_[0].Name()); - this->layer_ = space2batch_opt_layer; - return RET_OK; -} - -REGISTER_TENSORRT_PLUGIN(SpaceToBatchPluginCreater); -template class TensorRTPluginCreater; -template -nvinfer1::PluginFieldCollection TensorRTPluginCreater::field_collection_{}; -template -std::vector TensorRTPluginCreater::fields_; - -int SpaceToBatchPlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workspace, cudaStream_t stream) noexcept { - return RunCudaSpaceToBatch(inputDesc, inputs, outputs, stream); -} - -int SpaceToBatchPlugin::RunCudaSpaceToBatch(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, - void *const *outputs, cudaStream_t stream) { - nvinfer1::Dims input_dims = inputDesc[0].dims; - int in = input_dims.d[0]; - int ic = input_dims.d[1]; - int ih = input_dims.d[2]; - int iw = input_dims.d[3]; - int on = in * bh_ * bh_; - int oc = ic; - int oh = (ih + ph0_ + ph1_) / bh_; - int ow = (iw + pw0_ + pw1_) / bh_; - - int size = in * ic * ih * iw; - - CalSpaceToBatch(size, static_cast(inputs[0]), in, ih, iw, ic, on, oh, ow, oc, ph0_, ph1_, pw0_, - pw1_, bh_, static_cast(outputs[0]), device_id_, stream); - return RET_OK; -} - -nvinfer1::IPluginV2DynamicExt *SpaceToBatchPlugin::clone() const noexcept { - auto *plugin = new (std::nothrow) SpaceToBatchPlugin(*this); - if (plugin == nullptr) { - MS_LOG(ERROR) << "new plugin failed!"; - return nullptr; - } - plugin->setPluginNamespace(name_space_.c_str()); - return plugin; -} - -size_t SpaceToBatchPlugin::getSerializationSize() const noexcept { return sizeof(int) * 5; } - -nvinfer1::DimsExprs SpaceToBatchPlugin::getOutputDimensions(int32_t index, const nvinfer1::DimsExprs *inputs, - int nbInputDims, - nvinfer1::IExprBuilder &exprBuilder) noexcept { - nvinfer1::DimsExprs dims; - dims.nbDims = inputs[0].nbDims; - auto bh_mul_bh = exprBuilder.constant(bh_ * bh_); - dims.d[0] = exprBuilder.operation(nvinfer1::DimensionOperation::kPROD, *inputs[0].d[0], *bh_mul_bh); - dims.d[1] = inputs[0].d[1]; - auto bh = exprBuilder.constant(bh_); - auto ph_sum = exprBuilder.constant(ph0_ + ph1_); - auto sum0 = exprBuilder.operation(nvinfer1::DimensionOperation::kSUM, *inputs[0].d[INPUT_SIZE2], *ph_sum); - dims.d[INPUT_SIZE2] = exprBuilder.operation(nvinfer1::DimensionOperation::kFLOOR_DIV, *sum0, *bh); - auto pw_sum = exprBuilder.constant(pw0_ + pw1_); - auto sum1 = exprBuilder.operation(nvinfer1::DimensionOperation::kSUM, *inputs[0].d[INPUT_SIZE3], *pw_sum); - dims.d[INPUT_SIZE3] = exprBuilder.operation(nvinfer1::DimensionOperation::kFLOOR_DIV, *sum1, *bh); - return dims; -} - -void SpaceToBatchPlugin::serialize(void *buffer) const noexcept { - SerializeValue(&buffer, &bh_, sizeof(int)); - SerializeValue(&buffer, &ph0_, sizeof(int)); - SerializeValue(&buffer, &ph1_, sizeof(int)); - SerializeValue(&buffer, &pw0_, sizeof(int)); - SerializeValue(&buffer, &pw1_, sizeof(int)); -} -REGISTER_TENSORRT_CREATOR(ops::kNameSpaceToBatchND, SpaceToBatchTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/spacetobatch_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/spacetobatch_tensorrt.h deleted file mode 100644 index 67e6786f337d2b560e2100038b907a8778e34a80..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/spacetobatch_tensorrt.h +++ /dev/null @@ -1,101 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_SPACETOTENSORRT_PLUGIN_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_SPACETOTENSORRT_PLUGIN_H_ - -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" -#include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h" - -namespace mindspore::lite { -class SpaceToBatchTensorRT : public TensorRTOp { - public: - SpaceToBatchTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~SpaceToBatchTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; -}; - -constexpr auto SPACETOTENSORRT_PLUGIN_NAME{"SpaceToBatchPlugin"}; -class SpaceToBatchPlugin : public TensorRTPlugin { - public: - SpaceToBatchPlugin(const std::string name, int bh, int ph0, int ph1, int pw0, int pw1, uint32_t device_id) - : TensorRTPlugin(name, std::string(SPACETOTENSORRT_PLUGIN_NAME), device_id), - bh_(bh), - ph0_(ph0), - ph1_(ph1), - pw0_(pw0), - pw1_(pw1) {} - - SpaceToBatchPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - : TensorRTPlugin(std::string(name), std::string(SPACETOTENSORRT_PLUGIN_NAME)) { - const nvinfer1::PluginField *fields = fc->fields; - bh_ = static_cast(fields[0].data)[0]; - ph0_ = static_cast(fields[1].data)[0]; - ph1_ = static_cast(fields[2].data)[0]; - pw0_ = static_cast(fields[3].data)[0]; - pw1_ = static_cast(fields[4].data)[0]; - } - - SpaceToBatchPlugin(const char *name, const void *serialData, size_t serialLength) - : TensorRTPlugin(std::string(name), std::string(SPACETOTENSORRT_PLUGIN_NAME)) { - DeserializeValue(&serialData, &serialLength, &bh_, sizeof(int)); - DeserializeValue(&serialData, &serialLength, &ph0_, sizeof(int)); - DeserializeValue(&serialData, &serialLength, &ph1_, sizeof(int)); - DeserializeValue(&serialData, &serialLength, &pw0_, sizeof(int)); - DeserializeValue(&serialData, &serialLength, &pw1_, sizeof(int)); - } - - SpaceToBatchPlugin() = delete; - - nvinfer1::IPluginV2DynamicExt *clone() const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override; - size_t getSerializationSize() const noexcept override; - void serialize(void *buffer) const noexcept override; - nvinfer1::DimsExprs getOutputDimensions(int index, const nvinfer1::DimsExprs *inputs, int nbInputDims, - nvinfer1::IExprBuilder &exprBuilder) noexcept override; - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs, - int nbOutputs) noexcept { - return tensorsDesc[pos].type == nvinfer1::DataType::kFLOAT && - tensorsDesc[pos].format == nvinfer1::TensorFormat::kLINEAR; - } - - private: - int RunCudaSpaceToBatch(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, void *const *outputs, - cudaStream_t stream); - int bh_; - int ph0_; - int ph1_; - int pw0_; - int pw1_; - const std::string layer_name_; - std::string name_space_; -}; -class SpaceToBatchPluginCreater : public TensorRTPluginCreater { - public: - SpaceToBatchPluginCreater() : TensorRTPluginCreater(std::string(SPACETOTENSORRT_PLUGIN_NAME)) {} -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_SPACETOTENSORRT_PLUGIN_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/split_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/split_tensorrt.cc deleted file mode 100644 index 1771c2871d6f374e098df864dbbbd566b143d4ae..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/split_tensorrt.cc +++ /dev/null @@ -1,199 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/split_tensorrt.h" -#include -#include -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "infer/unstack.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_u.h" - -namespace mindspore::lite { -int SplitTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != 1 && in_tensors.size() != INPUT_SIZE2) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - return RET_OK; -} - -nvinfer1::ITensor *SplitTensorRT::GetDynamicSliceSize(TensorRTContext *ctx, nvinfer1::ITensor *input, size_t i) { - auto in_tensor_shape = ctx->network()->addShape(*input)->getOutput(0); - if (in_tensor_shape == nullptr) { - MS_LOG(ERROR) << "add shape layer of input failed!"; - return nullptr; - } - auto len_tensor = ctx->ConvertTo1DTensor(static_cast(size_splits_[i])); - if (len_tensor == nullptr) { - MS_LOG(ERROR) << "convert 1d tensor failed!"; - return nullptr; - } - - nvinfer1::ITensor *concat_input_tensors[INPUT_SIZE2]; - concat_input_tensors[0] = in_tensor_shape; - concat_input_tensors[1] = len_tensor; - auto concat_layer = ctx->network()->addConcatenation(concat_input_tensors, INPUT_SIZE2); - if (concat_layer == nullptr) { - MS_LOG(ERROR) << "add concat layer failed!"; - return nullptr; - } - concat_layer->setAxis(0); - auto shape_and_len = concat_layer->getOutput(0); - if (shape_and_len == nullptr) { - MS_LOG(ERROR) << "get concat layer result failed!"; - return nullptr; - } - - std::vector gather_slices(input->getDimensions().nbDims); - std::iota(gather_slices.begin(), gather_slices.end(), 0); - gather_slices[axis_] = gather_slices.size(); - auto gather_slices_tensor = ctx->ConvertTo1DTensor(gather_slices); - nvinfer1::IGatherLayer *gather_layer = ctx->network()->addGather(*shape_and_len, *gather_slices_tensor, 0); - if (gather_layer == nullptr) { - MS_LOG(ERROR) << "add gather layer failed!"; - return nullptr; - } - - return gather_layer->getOutput(0); -} - -int SplitTensorRT::AddInnerOp(TensorRTContext *ctx) { - ITensorHelper split_input; - int ret = PreprocessInputs2SameDim(ctx, input(ctx, 0), &split_input); - if (ret != RET_OK || split_input.trt_tensor_ == nullptr) { - MS_LOG(ERROR) << "PreprocessInputs2SameDim input tensor failed for " << op_name_; - return ret; - } - - ret = ParseParams(split_input); - if (ret != RET_OK) { - MS_LOG(ERROR) << op_name_ << " parse params failed."; - return ret; - } - - int input_nbdims = split_input.trt_tensor_->getDimensions().nbDims; - axis_ = axis_ < 0 ? axis_ + input_nbdims : axis_; - - if (axis_ < 0 || axis_ >= input_nbdims) { - MS_LOG(ERROR) << "invalid axis : " << axis_; - return RET_ERROR; - } - int split_sum = std::accumulate(size_splits_.begin(), size_splits_.end(), 0); - int split_sum_expect = split_input.trt_tensor_->getDimensions().d[axis_]; - - if (size_splits_[size_splits_.size() - 1] == -1) { - size_splits_[size_splits_.size() - 1] = split_sum_expect - split_sum - 1; - split_sum = split_sum_expect; - } - - if (split_sum != split_sum_expect) { - MS_LOG(ERROR) << "Sum of size splits not equal input tensor dim. "; - return RET_ERROR; - } - - int axis_dim_index = 0; - nvinfer1::Dims one_dims = lite::ConvertCudaDims(1, input_nbdims); - nvinfer1::ISliceLayer *slice_layer = nullptr; - - for (int i = 0; i < output_num_; ++i) { - nvinfer1::Dims start_dims = lite::ConvertCudaDims(0, input_nbdims); - start_dims.d[axis_] = axis_dim_index; - nvinfer1::Dims size_dims{-1}; - nvinfer1::ITensor *size_tensor = nullptr; - if (!IsDynamicInput(ctx, 0)) { - size_dims = split_input.trt_tensor_->getDimensions(); - size_dims.d[axis_] = size_splits_[i]; - } else { - size_tensor = GetDynamicSliceSize(ctx, split_input.trt_tensor_, i); - } - axis_dim_index += size_splits_[i]; - - slice_layer = ctx->network()->addSlice(*split_input.trt_tensor_, start_dims, size_dims, one_dims); - if (slice_layer == nullptr) { - MS_LOG(ERROR) << "add Slice op failed for TensorRT: " << op_name_; - return RET_ERROR; - } - if (size_tensor != nullptr) { - slice_layer->setInput(INPUT_SIZE2, *size_tensor); - } - - nvinfer1::ITensor *out_tensor = slice_layer->getOutput(0); - bool res_is_tensor = true; - if (type_ == ops::kNameUnstack) { - auto shuffer_layer = ctx->network()->addShuffle(*out_tensor); - res_is_tensor = out_tensor->getDimensions().nbDims > 1; - if (res_is_tensor) { - auto shuffer_dims_opt = SqueezeDims(out_tensor->getDimensions(), axis_); - if (!shuffer_dims_opt) { - MS_LOG(ERROR) << "SqueezeDims failed."; - return RET_ERROR; - } - shuffer_layer->setReshapeDimensions(shuffer_dims_opt.value()); - out_tensor = shuffer_layer->getOutput(0); - } - } - ctx->RegisterTensor(ITensorHelper{out_tensor, split_input.format_, split_input.same_format_, res_is_tensor}, - out_tensors_[i].Name()); - } - this->layer_ = slice_layer; - return RET_OK; -} - -int SplitTensorRT::ParseParams(const ITensorHelper &helper) { - if (type_ == ops::kNameSplit) { - auto split_op = AsOps(); - CHECK_NULL_RETURN(split_op); - axis_ = split_op->get_axis(); - output_num_ = split_op->get_output_num(); - auto size_splits_ptr = GetValue>(split_op->GetAttr("size_splits")); - if (!size_splits_ptr.empty()) { - size_splits_.resize(size_splits_ptr.size()); - std::copy(size_splits_ptr.begin(), size_splits_ptr.end(), size_splits_.begin()); - } else if (in_tensors_.size() == INPUT_SIZE2 && in_tensors_[1].IsConst() && - (in_tensors_[1].DataType() == DataType::kNumberTypeInt32 || - in_tensors_[1].DataType() == DataType::kNumberTypeInt64)) { - size_splits_ = ConvertTensorAsIntVector(in_tensors_[1]); - } else { - MS_LOG(INFO) << op_name_ << " has invalid input size and size_splits: " << in_tensors_.size(); - } - } else if (type_ == ops::kNameUnstack) { - auto unstack_op = AsOps(); - CHECK_NULL_RETURN(unstack_op); - axis_ = unstack_op->get_axis(); - output_num_ = out_tensors_.size(); - } else { - MS_LOG(ERROR) << op_name_ << " has invalid type for split"; - return RET_ERROR; - } - int axis_dim = helper.trt_tensor_->getDimensions().d[axis_]; - if (size_splits_.empty()) { - if (output_num_ == 0 || axis_dim % output_num_ != 0) { - MS_LOG(ERROR) << "axis dim can not be split into same subdim output_num : " << output_num_ - << " axis_dim: " << axis_dim; - return RET_ERROR; - } - int split_width = axis_dim / output_num_; - size_splits_.resize(output_num_); - std::fill(size_splits_.begin(), size_splits_.end(), split_width); - } - return RET_OK; -} -REGISTER_TENSORRT_CREATOR(ops::kNameSplit, SplitTensorRT) -REGISTER_TENSORRT_CREATOR(ops::kNameUnstack, SplitTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/split_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/split_tensorrt.h deleted file mode 100644 index e5788a39e949a512f55ea00ad3f1b6c133c51897..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/split_tensorrt.h +++ /dev/null @@ -1,45 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_SPLIT_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_SPLIT_TENSORRT_H_ -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" - -namespace mindspore::lite { -class SplitTensorRT : public TensorRTOp { - public: - SplitTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~SplitTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; - - private: - nvinfer1::ITensor *GetDynamicSliceSize(TensorRTContext *ctx, nvinfer1::ITensor *input, size_t i); - int ParseParams(const ITensorHelper &helper); - int64_t axis_; - int output_num_; - std::vector size_splits_; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_SPLIT_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/square_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/square_tensorrt.cc deleted file mode 100644 index 03965389de58f3edafb83ebe57555e288b525151..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/square_tensorrt.cc +++ /dev/null @@ -1,71 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/square_tensorrt.h" -#include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" - -namespace mindspore::lite { -int SquareTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - return RET_OK; -} - -int SquareTensorRT::AddInnerOp(TensorRTContext *ctx) { - if (ctx == nullptr || ctx->network() == nullptr) { - MS_LOG(ERROR) << "context or network is invalid"; - return RET_ERROR; - } - auto norm_op = AsOps(); - CHECK_NULL_RETURN(norm_op); - int input_nbdims = input(ctx, 0).trt_tensor_->getDimensions().nbDims; - if (input_nbdims == -1) { - MS_LOG(ERROR) << "square failed for " << op_name_; - return RET_ERROR; - } - int ret = RunAsTrtOps(ctx); - if (ret != RET_OK) { - MS_LOG(ERROR) << "square failed for " << op_name_; - return ret; - } - return ret; -} - -int SquareTensorRT::RunAsTrtOps(TensorRTContext *ctx) { - if (ctx == nullptr || ctx->network() == nullptr) { - MS_LOG(ERROR) << "context or network is invalid"; - return RET_ERROR; - } - auto square_layer = ctx->network()->addElementWise(*input(ctx, 0).trt_tensor_, *input(ctx, 0).trt_tensor_, - nvinfer1::ElementWiseOperation::kPROD); - CHECK_NULL_RETURN(square_layer); - auto out_tensor = square_layer->getOutput(0); - CHECK_NULL_RETURN(out_tensor); - this->layer_ = square_layer; - ctx->RegisterTensor(ITensorHelper{out_tensor, input(ctx, 0).format_, input(ctx, 0).same_format_}, - out_tensors_[0].Name()); - return RET_OK; -} -REGISTER_TENSORRT_CREATOR(ops::kNameSquare, SquareTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/square_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/square_tensorrt.h deleted file mode 100644 index a571484414264da8d265f0eb8ce0b2bd259dab5e..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/square_tensorrt.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_SQUARE_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_SQUARE_TENSORRT_H_ -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" - -namespace mindspore::lite { -class SquareTensorRT : public TensorRTOp { - public: - SquareTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~SquareTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; - - private: - int RunAsTrtOps(TensorRTContext *ctx); -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_SQUARE_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/strideslice_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/strideslice_tensorrt.cc deleted file mode 100644 index d1cfeafd4dcd2bd4a17773d82588de5434397669..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/strideslice_tensorrt.cc +++ /dev/null @@ -1,429 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/strideslice_tensorrt.h" -#include -#include -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" - -namespace mindspore::lite { -nvinfer1::ITensor *StrideSliceTensorRT::GetDynamicAxisSliceStart(TensorRTContext *ctx, nvinfer1::ITensor *input, - int axis, int nbdims) { - if (axis == 0 && nbdims == 1) { - return input; - } - std::vector gather_inputs; - if (axis == 0) { - gather_inputs.push_back(input); - gather_inputs.push_back(ctx->ConvertTo1DTensor(std::vector(nbdims - 1, 0))); - } else if (axis == nbdims - 1) { - gather_inputs.push_back(ctx->ConvertTo1DTensor(std::vector(nbdims - 1, 0))); - gather_inputs.push_back(input); - } else { - gather_inputs.push_back(ctx->ConvertTo1DTensor(std::vector(axis, 0))); - gather_inputs.push_back(input); - gather_inputs.push_back(ctx->ConvertTo1DTensor(std::vector(nbdims - 1 - axis, 0))); - } - auto concat_layer = ctx->network()->addConcatenation(gather_inputs.data(), gather_inputs.size()); - if (concat_layer == nullptr) { - MS_LOG(ERROR) << "add concat layer failed!"; - return nullptr; - } - concat_layer->setAxis(0); - return concat_layer->getOutput(0); -} - -nvinfer1::ITensor *StrideSliceTensorRT::GetDynamicAxisSliceSize(TensorRTContext *ctx, nvinfer1::ITensor *input, - int size_dim, int axis, - nvinfer1::ITensor *size_tensor) { - auto in_tensor_shape = ctx->network()->addShape(*input)->getOutput(0); - if (in_tensor_shape == nullptr) { - MS_LOG(ERROR) << "add shape layer of input failed!"; - return nullptr; - } - auto len_tensor = (size_tensor == nullptr ? ctx->ConvertTo1DTensor(static_cast(size_dim)) : size_tensor); - if (len_tensor == nullptr) { - MS_LOG(ERROR) << "convert 1d tensor failed!"; - return nullptr; - } - - nvinfer1::ITensor *concat_input_tensors[INPUT_SIZE2]; - concat_input_tensors[0] = in_tensor_shape; - concat_input_tensors[1] = len_tensor; - auto concat_layer = ctx->network()->addConcatenation(concat_input_tensors, INPUT_SIZE2); - if (concat_layer == nullptr) { - MS_LOG(ERROR) << "add concat layer failed!"; - return nullptr; - } - concat_layer->setAxis(0); - auto shape_and_len = concat_layer->getOutput(0); - if (shape_and_len == nullptr) { - MS_LOG(ERROR) << "get concat layer result failed!"; - return nullptr; - } - - std::vector gather_slices(input->getDimensions().nbDims); - std::iota(gather_slices.begin(), gather_slices.end(), 0); - gather_slices[axis] = gather_slices.size(); - auto gather_slices_tensor = ctx->ConvertTo1DTensor(gather_slices); - nvinfer1::IGatherLayer *gather_layer = ctx->network()->addGather(*shape_and_len, *gather_slices_tensor, 0); - if (gather_layer == nullptr) { - MS_LOG(ERROR) << "add gather layer failed!"; - return nullptr; - } - - return gather_layer->getOutput(0); -} - -nvinfer1::ITensor *StrideSliceTensorRT::GetDynamicSliceSize(TensorRTContext *ctx, nvinfer1::ITensor *slice_input, - size_t end_mask) { - std::vector end_mask_axis; - std::vector end_unmask_axis; - std::vector start_vector; - for (int i = 0; i < size_dims_.nbDims; i++) { - start_vector.push_back(start_dims_.d[i]); - size_t mask = 1 << i; - if ((end_mask & mask) == 0) { - end_mask_axis.push_back(1); - end_unmask_axis.push_back(0); - } else { - end_mask_axis.push_back(0); - end_unmask_axis.push_back(1); - } - } - auto end_mask_tensor = ctx->ConvertTo1DTensor(end_mask_axis); - auto end_tensor = - ctx->network() - ->addElementWise(*input(ctx, INPUT_SIZE2).trt_tensor_, *end_mask_tensor, nvinfer1::ElementWiseOperation::kPROD) - ->getOutput(0); - auto end_unmask_tensor = ctx->ConvertTo1DTensor(end_unmask_axis); - auto input_shape = ctx->network()->addShape(*slice_input)->getOutput(0); - auto unmask_tensor = ctx->network() - ->addElementWise(*input_shape, *end_unmask_tensor, nvinfer1::ElementWiseOperation::kPROD) - ->getOutput(0); - auto real_end_tensor = - ctx->network()->addElementWise(*end_tensor, *unmask_tensor, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0); - auto start_tensor = ctx->ConvertTo1DTensor(start_vector); - auto size_tensor = - ctx->network()->addElementWise(*real_end_tensor, *start_tensor, nvinfer1::ElementWiseOperation::kSUB)->getOutput(0); - return size_tensor; -} - -nvinfer1::ITensor *StrideSliceTensorRT::GetDynamicSliceSize(TensorRTContext *ctx, nvinfer1::ITensor *input, - const nvinfer1::Dims &size_dims, - const nvinfer1::Dims &start_dims) { - auto in_tensor_shape = ctx->network()->addShape(*input)->getOutput(0); - if (in_tensor_shape == nullptr) { - MS_LOG(ERROR) << "add shape layer of input failed!"; - return nullptr; - } - std::vector is_dynamic; - std::vector is_fix; - std::vector size_vec; - std::vector start_vec; - for (int i = 0; i != size_dims.nbDims; ++i) { - is_dynamic.push_back(size_dims.d[i] <= 0); - is_fix.push_back(size_dims.d[i] > 0); - size_vec.push_back(size_dims.d[i]); - start_vec.push_back(start_dims.d[i]); - } - auto is_dynamic_tensor = ctx->ConvertTo1DTensor(is_dynamic); - auto is_fix_tensor = ctx->ConvertTo1DTensor(is_fix); - auto size_tensor = ctx->ConvertTo1DTensor(size_vec); - auto start_tensor = ctx->ConvertTo1DTensor(start_vec); - auto dynamic_in_tensor = - ctx->network()->addElementWise(*in_tensor_shape, *start_tensor, nvinfer1::ElementWiseOperation::kSUB)->getOutput(0); - auto fix_tensor = - ctx->network()->addElementWise(*is_fix_tensor, *size_tensor, nvinfer1::ElementWiseOperation::kPROD)->getOutput(0); - auto dynamic_tensor = - ctx->network() - ->addElementWise(*is_dynamic_tensor, *dynamic_in_tensor, nvinfer1::ElementWiseOperation::kPROD) - ->getOutput(0); - size_tensor = - ctx->network()->addElementWise(*dynamic_tensor, *fix_tensor, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0); - return size_tensor; -} - -int StrideSliceTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() < HAS_AXIS - 1) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); - return RET_ERROR; - } - return RET_OK; -} - -bool StrideSliceTensorRT::GetConstInputValue(int *start_val, int *stride_val) { - const auto &begin = in_tensors_.at(BEGINS_INDEX); - if (begin.IsConst()) { - if (begin.ElementNum() != 1) { - MS_LOG(ERROR) << "Only support element number of begin to be 1 when this number < input dims number, op: " - << op_name_; - return false; - } - auto start_vec = ConvertTensorAsIntVector(begin); - if (start_vec.size() != 1) { - MS_LOG(ERROR) << "Failed to get start or stride input, node: " << op_name_; - return false; - } - *start_val = start_vec[0]; - } - const auto &stride = in_tensors_.back(); - if (stride.ElementNum() != 1) { - MS_LOG(ERROR) << "Only support element number of stride to be 1 when this number < input dims number, op: " - << op_name_; - return false; - } - auto stride_vec = ConvertTensorAsIntVector(stride); - if (stride_vec.size() != 1) { - MS_LOG(ERROR) << "Failed to get start or stride input, node: " << op_name_; - return false; - } - *stride_val = stride_vec[0]; - return true; -} - -int StrideSliceTensorRT::ComputeDimsMulti(TensorRTContext *ctx, ITensorHelper *slice_input, const TensorInfo &begin, - const TensorInfo &stride, const TensorInfo &end, size_t start_mask, - size_t end_mask) { - auto input_dims = slice_input->trt_tensor_->getDimensions(); - start_dims_ = lite::ConvertCudaDims(begin); - stride_dims_ = lite::ConvertCudaDims(stride); - auto end_dims = lite::ConvertCudaDims(end); - size_dims_.nbDims = input_dims.nbDims; - for (int i = 0; i < size_dims_.nbDims; i++) { - size_t mask = 1 << i; - start_dims_.d[i] = ((start_mask & mask) == 0 ? start_dims_.d[i] : 0); - if (start_dims_.d[i] < 0) { - start_dims_.d[i] += input_dims.d[i]; - } - if (end.IsConst()) { - if ((end_mask & mask) != 0 && input_dims.d[i] > 0) { - end_dims.d[i] = input_dims.d[i]; - } else if ((end_mask & mask) != 0) { - size_dims_.d[i] = -1; - continue; - } - if (end_dims.d[i] >= 0) { - if (input_dims.d[i] >= 0) { - size_dims_.d[i] = std::min(end_dims.d[i], input_dims.d[i]) - start_dims_.d[i]; - } else { - size_dims_.d[i] = end_dims.d[i] - start_dims_.d[i]; - } - } else if (end_dims.d[i] >= -input_dims.d[i]) { - size_dims_.d[i] = end_dims.d[i] + input_dims.d[i] - start_dims_.d[i]; - } else { - size_dims_.d[i] = input_dims.d[i]; - } - if (size_dims_.d[i] < 0) { - size_dims_.d[i] += input_dims.d[i]; - stride_dims_.d[i] = -stride_dims_.d[i]; - } - } - } - return RET_OK; -} - -int StrideSliceTensorRT::ComputeDimsSingle(TensorRTContext *ctx, ITensorHelper *slice_input, const TensorInfo &begin, - const TensorInfo &stride, const TensorInfo &end, size_t start_mask, - size_t end_mask) { - auto input_dims = slice_input->trt_tensor_->getDimensions(); - - int axis_value = GetAxis(ctx); - int start_value = 0; - int stride_value = 0; - if (!GetConstInputValue(&start_value, &stride_value)) { - return RET_ERROR; - } - stride_dims_.nbDims = input_dims.nbDims; - std::fill(stride_dims_.d, stride_dims_.d + stride_dims_.nbDims, 1); - stride_dims_.d[axis_value] = stride_value; - - if (!begin.IsConst() && !end.IsConst()) { - return RET_OK; - } - - if (start_value < 0) { - start_value = input_dims.d[axis_value] + start_value; - } - start_dims_.nbDims = input_dims.nbDims; - std::fill(start_dims_.d, start_dims_.d + start_dims_.nbDims, 0); - start_dims_.d[axis_value] = start_value; - - if (!end.IsConst()) { - return RET_OK; - } - int end_value = ConvertTensorAsIntVector(end)[0]; - size_dims_ = slice_input->trt_tensor_->getDimensions(); - for (int i = 0; i < start_dims_.nbDims; i++) { - if (i == axis_value) { - if (end_value >= 0) { - size_dims_.d[i] = std::min(end_value, input_dims.d[i]) - start_dims_.d[i]; - } else if (end_value >= -input_dims.d[i]) { - size_dims_.d[i] = end_value + input_dims.d[i] - start_dims_.d[i]; - } else { - size_dims_.d[i] = input_dims.d[i]; - } - } - size_dims_.d[i] = std::abs(size_dims_.d[i] / stride_dims_.d[i]) + ((size_dims_.d[i] % stride_dims_.d[i]) != 0); - } - DebugDims("size : ", size_dims_); - return RET_OK; -} - -int StrideSliceTensorRT::ComputeDims(TensorRTContext *ctx, ITensorHelper *slice_input, const TensorInfo &begin, - const TensorInfo &stride, const TensorInfo &end, size_t start_mask, - size_t end_mask) { - if (static_cast(begin.ElementNum()) == slice_input->trt_tensor_->getDimensions().nbDims) { - return ComputeDimsMulti(ctx, slice_input, begin, stride, end, start_mask, end_mask); - } - return ComputeDimsSingle(ctx, slice_input, begin, stride, end, start_mask, end_mask); -} - -int StrideSliceTensorRT::GetAxis(TensorRTContext *ctx) { - int64_t axis_index = in_tensors_.size() == HAS_AXIS ? AXIS_INDEX : -1; - int axis_value = 0; - if (axis_index != -1) { - auto axis_vec = ConvertTensorAsIntVector(in_tensors_[axis_index]); - if (axis_vec.size() != 1) { - MS_LOG(ERROR) << "Failed to get axis input, node: " << op_name_ << ", axis count: " << axis_vec.size(); - return -1; - } - axis_value = axis_vec[0]; - } - if (axis_value < 0) { - axis_value += input(ctx, 0).trt_tensor_->getDimensions().nbDims; - } - return axis_value; -} - -int StrideSliceTensorRT::ComputeSliceDims(TensorRTContext *ctx, ITensorHelper *slice_input) { - auto op = AsOps(); - shrink_axis_ = op->get_shrink_axis_mask(); - size_t start_mask = op->get_begin_mask(); - size_t end_mask = op->get_end_mask(); - - const auto &begin = in_tensors_.at(BEGINS_INDEX); - const auto &stride = in_tensors_.back(); - const auto &end = in_tensors_.at(ENDS_INDEX); - - auto input_dims = slice_input->trt_tensor_->getDimensions(); - if (begin.ElementNum() == input_dims.nbDims) { - int dims_ret = ComputeDims(ctx, slice_input, begin, stride, end, start_mask, end_mask); - if (dims_ret) { - MS_LOG(ERROR) << "comput start dims, stride dims, size dims filed for " << op_name_; - return RET_ERROR; - } - if (IsDynamicInput(ctx, 0) && end.IsConst()) { - size_tensor_ = GetDynamicSliceSize(ctx, slice_input->trt_tensor_, size_dims_, start_dims_); - size_dims_ = nvinfer1::Dims{-1}; - } - if (!end.IsConst()) { - size_tensor_ = GetDynamicSliceSize(ctx, slice_input->trt_tensor_, end_mask); - size_dims_ = nvinfer1::Dims{-1}; - } - } else { - int axis_value = GetAxis(ctx); - int dims_ret = ComputeDims(ctx, slice_input, begin, stride, end, start_mask, end_mask); - if (dims_ret) { - MS_LOG(ERROR) << "comput start dims, stride dims, size dims filed for " << op_name_; - return RET_ERROR; - } - if (IsDynamicInput(ctx, 0) && begin.IsConst() && end.IsConst()) { - size_tensor_ = - GetDynamicAxisSliceSize(ctx, slice_input->trt_tensor_, size_dims_.d[axis_value], axis_value, nullptr); - size_dims_ = nvinfer1::Dims{-1}; - } - if (!begin.IsConst()) { - start_tensor_ = GetDynamicAxisSliceStart(ctx, input(ctx, 1).trt_tensor_, axis_value, input_dims.nbDims); - start_dims_ = nvinfer1::Dims{-1}; - } - if (!end.IsConst()) { - auto start_tensor = input(ctx, 1).trt_tensor_; - if (start_tensor == nullptr) { - auto start_vec = ConvertTensorAsIntVector(begin); - int start_value = start_vec[0]; - if (start_value < 0) { - start_value = slice_input->trt_tensor_->getDimensions().d[axis_value] + start_value; - } - start_tensor = ctx->ConvertTo1DTensor(start_value); - } - auto len_tensor = - ctx->network() - ->addElementWise(*input(ctx, INPUT_SIZE2).trt_tensor_, *start_tensor, nvinfer1::ElementWiseOperation::kSUB) - ->getOutput(0); - size_tensor_ = GetDynamicAxisSliceSize(ctx, slice_input->trt_tensor_, -1, axis_value, len_tensor); - size_dims_ = nvinfer1::Dims{-1}; - } - } - return RET_OK; -} - -int StrideSliceTensorRT::AddInnerOp(TensorRTContext *ctx) { - auto in_tensor = input(ctx, 0); - if (in_tensors_[0].IsConst() && in_tensor.trt_tensor_ == nullptr) { - in_tensor.trt_tensor_ = lite::ConvertConstantTensor(ctx, in_tensors_[0], op_name_); - in_tensor.format_ = Format::NCHW; - ctx->RegisterTensor(in_tensor, in_tensors_[0].Name()); - } - - if (ComputeSliceDims(ctx, &in_tensor) != RET_OK) { - return RET_ERROR; - } - nvinfer1::ISliceLayer *slice_layer = - ctx->network()->addSlice(*in_tensor.trt_tensor_, start_dims_, size_dims_, stride_dims_); - if (slice_layer == nullptr) { - MS_LOG(ERROR) << "add Slice op failed for TensorRT: " << op_name_; - return RET_ERROR; - } - if (start_tensor_ != nullptr) { - slice_layer->setInput(1, *start_tensor_); - } - if (size_tensor_ != nullptr) { - slice_layer->setInput(INPUT_SIZE2, *size_tensor_); - } - this->layer_ = slice_layer; - slice_layer->setName(op_name_.c_str()); - nvinfer1::ITensor *out_tensor = slice_layer->getOutput(0); - auto shape = ConvertMSShape(out_tensor->getDimensions()); - bool rank_0 = false; - if (shrink_axis_ != 0) { - for (int i = SizeToInt(shape.size()) - 1; i >= 0; --i) { - int mask = 1 << i; - if ((shrink_axis_ & mask) != 0) { - shape.erase(shape.begin() + i); - } - } - if (!shape.empty()) { - out_tensor = Reshape(ctx, out_tensor, shape); - } else { - rank_0 = true; - } - } - auto helper = ITensorHelper{out_tensor, in_tensor.format_, in_tensor.same_format_, !rank_0}; - ctx->RegisterTensor(helper, out_tensors_[0].Name()); - MS_LOG(DEBUG) << "slice output : " << GetTensorFormat(helper); - return RET_OK; -} - -REGISTER_TENSORRT_CREATOR(ops::kNameStridedSlice, StrideSliceTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/strideslice_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/strideslice_tensorrt.h deleted file mode 100644 index 9f15d0f865f01453f5b38ae460e65eb913255e20..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/strideslice_tensorrt.h +++ /dev/null @@ -1,68 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_STRIDE_SLICE_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_STRIDE_SLICE_TENSORRT_H_ -#include -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" - -namespace mindspore::lite { -constexpr int BEGINS_INDEX = 1; -constexpr int ENDS_INDEX = 2; -constexpr int HAS_AXIS = 5; -constexpr int AXIS_INDEX = 3; -class StrideSliceTensorRT : public TensorRTOp { - public: - StrideSliceTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~StrideSliceTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; - - private: - nvinfer1::ITensor *GetDynamicAxisSliceStart(TensorRTContext *ctx, nvinfer1::ITensor *input, int axis, int nbdims); - nvinfer1::ITensor *GetDynamicSliceSize(TensorRTContext *ctx, nvinfer1::ITensor *input, - const nvinfer1::Dims &size_dims, const nvinfer1::Dims &start_dims); - nvinfer1::ITensor *GetDynamicSliceSize(TensorRTContext *ctx, nvinfer1::ITensor *slice_input, size_t end_mask); - nvinfer1::ITensor *GetDynamicAxisSliceSize(TensorRTContext *ctx, nvinfer1::ITensor *input, int size_dim, int axis, - nvinfer1::ITensor *size_tensor); - int ComputeSliceDims(TensorRTContext *ctx, ITensorHelper *slice_input); - int ComputeDims(TensorRTContext *ctx, ITensorHelper *slice_input, const TensorInfo &begin, const TensorInfo &stride, - const TensorInfo &end, size_t start_mask, size_t end_mask); - int ComputeDimsSingle(TensorRTContext *ctx, ITensorHelper *slice_input, const TensorInfo &begin, - const TensorInfo &stride, const TensorInfo &end, size_t start_mask, size_t end_mask); - int ComputeDimsMulti(TensorRTContext *ctx, ITensorHelper *slice_input, const TensorInfo &begin, - const TensorInfo &stride, const TensorInfo &end, size_t start_mask, size_t end_mask); - bool GetConstInputValue(int *start_val, int *stride_val); - int GetAxis(TensorRTContext *ctx); - size_t shrink_axis_; - size_t start_axis_; - size_t end_axis_; - nvinfer1::Dims start_dims_; - nvinfer1::Dims size_dims_; - nvinfer1::Dims stride_dims_; - nvinfer1::ITensor *size_tensor_{nullptr}; - nvinfer1::ITensor *start_tensor_{nullptr}; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_STRIDE_SLICE_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/tensorrt_executor_plugin.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/tensorrt_executor_plugin.h deleted file mode 100644 index 69f0a704d25c8c35a754514216f141dfebf1815e..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/tensorrt_executor_plugin.h +++ /dev/null @@ -1,106 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_TENSORRT_PLUGIN_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_TENSORRT_PLUGIN_H_ -#include -#include -#include "src/common/log_adapter.h" -#include "include/errorcode.h" -#include "NvInferRuntimeCommon.h" -#include - -namespace mindspore::lite { -void SerializeValue(void **buffer, const void *value, size_t cpy_size); -void DeserializeValue(void const **buffer, size_t *buffer_size, void *value, size_t cpy_size); -class TensorRTPlugin : public nvinfer1::IPluginV2DynamicExt { - public: - TensorRTPlugin(const std::string &layer_name, const std::string &plugin_name, uint32_t device_id = 0) - : layer_name_(layer_name), plugin_name_(plugin_name), device_id_(device_id) {} - - // It doesn't make sense to make GeluPluginDynamic without arguments, so we delete - // default constructor. - TensorRTPlugin() = delete; - - // IPluginV2DynamicExt Methods - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, - nvinfer1::IExprBuilder &exprBuilder) noexcept override; - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs, - int nbOutputs) noexcept override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *out, int nbOutputs) noexcept override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const noexcept override; - - // IPluginV2Ext Methods - nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, int nbInputs) const - noexcept override; - - // IPluginV2 Methods - const char *getPluginType() const noexcept override; - const char *getPluginVersion() const noexcept override; - int getNbOutputs() const noexcept override; - int initialize() noexcept override; - void terminate() noexcept override; - size_t getSerializationSize() const noexcept override; - void serialize(void *buffer) const noexcept override; - void destroy() noexcept override; - void setPluginNamespace(const char *pluginNamespace) noexcept override; - const char *getPluginNamespace() const noexcept override; - - protected: - std::string layer_name_; - std::string name_space_; - std::string plugin_version_{"1"}; - std::string plugin_name_; - uint32_t device_id_{0}; -}; - -template -class TensorRTPluginCreater : public nvinfer1::IPluginCreator { - public: - explicit TensorRTPluginCreater(const std::string &plugin_name) : plugin_name_(plugin_name) { - // Fill PluginFieldCollection with PluginField arguments metadata - field_collection_.nbFields = fields_.size(); - field_collection_.fields = fields_.data(); - } - - const char *getPluginName() const noexcept override { return plugin_name_.c_str(); } - - const char *getPluginVersion() const noexcept override { return plugin_version_.c_str(); } - - const nvinfer1::PluginFieldCollection *getFieldNames() noexcept override { return &field_collection_; } - - void setPluginNamespace(const char *pluginNamespace) noexcept override { name_space_ = std::string(pluginNamespace); } - - const char *getPluginNamespace() const noexcept override { return name_space_.c_str(); } - - nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) noexcept { - return new (std::nothrow) T(name, fc); - } - - nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData, size_t serialLength) noexcept { - return new (std::nothrow) T(name, serialData, serialLength); - } - - protected: - static nvinfer1::PluginFieldCollection field_collection_; - static std::vector fields_; - std::string name_space_; - std::string plugin_version_{"1"}; - std::string plugin_name_; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_TENSORRT_PLUGIN_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/tensorrt_op.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/tensorrt_op.cc deleted file mode 100644 index 2130eb23e32f7db1b22971cff2ab8f9977e9720c..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/tensorrt_op.cc +++ /dev/null @@ -1,155 +0,0 @@ -/** - * Copyright 2020-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" -#include -#include "src/extendrt/delegate/tensorrt/tensorrt_runtime.h" - -namespace mindspore::lite { -TensorRTOp::TensorRTOp(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : base_operator_(base_operator), in_tensors_(in_tensors), out_tensors_(out_tensors), op_name_(std::move(name)) { - MS_EXCEPTION_IF_NULL(base_operator); - - this->type_ = base_operator->name(); - auto primitive_c = base_operator->GetPrim(); - if (primitive_c != nullptr) { - return; - } -} - -const BaseOperatorPtr &TensorRTOp::GetBaseOperator() { return this->base_operator_; } - -std::string TensorRTOp::GetOpName() { return this->op_name_; } - -std::vector &TensorRTOp::inputs() { return this->in_tensors_; } - -std::vector &TensorRTOp::outputs() { return this->out_tensors_; } - -ITensorHelper TensorRTOp::input(TensorRTContext *ctx, size_t i) { - auto in_ms_tensor = in_tensors_[i]; - ITensorHelper in_trt_tensor = ctx->MsName2Tensor(in_ms_tensor.Name()); - - if (!GetSupportInputBool() && in_ms_tensor.DataType() == DataType::kNumberTypeBool) { - ITensorHelper in_trt_tensor_cast = ctx->MsName2Tensor(in_ms_tensor.Name() + "_to_int32"); - if (in_trt_tensor_cast.trt_tensor_ == nullptr) { - auto cast_trt_tensor = - TRTTensorCast(ctx, in_trt_tensor.trt_tensor_, nvinfer1::DataType::kINT32, in_ms_tensor.Name() + "_cast_int32"); - in_trt_tensor_cast = ITensorHelper{cast_trt_tensor, in_ms_tensor.format(), true}; - ctx->RegisterTensor(in_trt_tensor_cast, in_ms_tensor.Name() + "_to_int32"); - } - return in_trt_tensor_cast; - } - return in_trt_tensor; -} - -ITensorHelper TensorRTOp::output(TensorRTContext *ctx, size_t i) { return ctx->MsName2Tensor(out_tensors_[i].Name()); } - -const std::string &TensorRTOp::type() const { return this->type_; } - -schema::QuantType TensorRTOp::GetQuantType() const { return this->quant_type_; } - -void TensorRTOp::set_in_ops(const std::vector &in_ops) { this->in_ops_ = in_ops; } - -void TensorRTOp::set_out_ops(const std::vector &out_ops) { this->out_ops_ = out_ops; } - -const std::vector &TensorRTOp::in_ops() const { return this->in_ops_; } - -const std::vector &TensorRTOp::out_ops() const { return this->out_ops_; } - -void TensorRTOp::SetRuntime(TensorRTRuntime *runtime) { - this->runtime_ = runtime; - device_id_ = runtime_->GetDeviceID(); -} - -bool TensorRTOp::HasConst() const { - return std::any_of(in_tensors_.begin(), in_tensors_.end(), - [](const TensorInfo &tensor) { return tensor.Data() != nullptr && tensor.IsConst(); }); -} - -int TensorRTOp::ReadyInputsNumber(TensorRTContext *ctx) const { - return std::count_if(in_tensors_.begin(), in_tensors_.end(), - [&](const TensorInfo &tensor) { return ctx->HasTensor(tensor.Name()); }); -} - -bool TensorRTOp::IsShapeKnown() { return true; } - -bool TensorRTOp::IsDynamicInput(TensorRTContext *ctx, size_t k) { - nvinfer1::Dims dims = input(ctx, k).trt_tensor_->getDimensions(); - return std::any_of(dims.d, dims.d + dims.nbDims, [](int d) { return d == -1; }); -} - -int TensorRTOp::Prepare(void **network_tensor_bindings, nvinfer1::ICudaEngine *engine) { - if (op_binding_tensor_.size() != 0) { - MS_LOG(ERROR) << "need special op Prepare for " << op_name_; - return RET_ERROR; - } - return RET_OK; -} - -DynamicShapeParams TensorRTOp::GetDynamicShapeParams() const { return this->dynamic_shape_params_; } - -int TensorRTOp::SetInt8DynamicRange(TensorRTContext *ctx) { - // setting param layer_ forcely - if (this->layer_ == nullptr) { - MS_LOG(WARNING) << op_name_ << " layer is nullptr."; - return RET_OK; - } - if (in_tensors_.empty() || out_tensors_.empty()) { - MS_LOG(ERROR) << "input or output tensor empty."; - return RET_ERROR; - } - return RET_OK; -} - -int TensorRTOp::SetTransposeDynamicRange() { - if (this->transpose_layer_ == nullptr) { - MS_LOG(INFO) << op_name_ << " transpose_layer is nullptr."; - return RET_OK; - } - return RET_OK; -} - -bool TensorRTOp::GetSupportInputBool() { return this->support_input_bool_; } - -void TensorRTOp::SetSupportInputBool(bool support_input_bool) { this->support_input_bool_ = support_input_bool; } - -void TensorRTOp::PrintTrtInputs(TensorRTContext *ctx) { - MS_LOG(DEBUG) << "Op " << op_name_ << " type: " << type_; - for (size_t i = 0; i < in_tensors_.size(); i++) { - if (in_tensors_[i].IsConst()) { - MS_LOG(DEBUG) << "-input " << i << " " << in_tensors_[i].Shape() << " " << in_tensors_[i].DataType(); - } else { - auto tensor = input(ctx, i); - if (tensor.trt_tensor_) { - MS_LOG(DEBUG) << "-input " << i << " " << CudaDimsAsString(tensor.trt_tensor_->getDimensions()) << " " - << in_tensors_[i].DataType(); - } - } - } -} - -void TensorRTOp::PrintTrtOutputs(TensorRTContext *ctx) { - MS_LOG(DEBUG) << "Op " << op_name_ << " type: " << type_; - for (size_t i = 0; i < out_tensors_.size(); i++) { - auto tensor = output(ctx, i); - if (tensor.trt_tensor_) { - MS_LOG(DEBUG) << "-output " << i << " " << CudaDimsAsString(tensor.trt_tensor_->getDimensions()) << " " - << out_tensors_[i].DataType(); - } - } -} -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/tensorrt_op.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/tensorrt_op.h deleted file mode 100644 index bad5a401a6ee32d66886edf62f4935b7e9235c9a..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/tensorrt_op.h +++ /dev/null @@ -1,198 +0,0 @@ -/** - * Copyright 2020-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_TENSORRT_OP_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_TENSORRT_OP_H_ - -#include -#include -#include -#include -#include -#include "include/api/kernel.h" -#include "src/common/log_adapter.h" -#include "include/errorcode.h" -#include "src/extendrt/delegate/tensorrt/tensorrt_context.h" -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "src/extendrt/delegate/tensorrt/op_registration_factory.h" -#include "src/extendrt/delegate/tensorrt/tensor_info.h" -#include "src/common/log_util.h" -#include "ops/base_operator.h" -#include "mindspore/ops/op_def/op_name.h" -#include "common/kernel.h" -#include "include/api/types.h" -#include "mindapi/base/types.h" - -namespace mindspore::lite { -constexpr int INPUT_SIZE2 = 2; -constexpr int INPUT_SIZE3 = 3; -constexpr int INPUT_SIZE4 = 4; -constexpr int INPUT_SIZE5 = 5; -constexpr int INPUT_SIZE6 = 6; -constexpr int INPUT_SIZE7 = 7; -constexpr int INPUT_SIZE8 = 8; -constexpr int INPUT_SIZE9 = 9; -constexpr int INPUT_SIZE10 = 10; - -struct BindingHelper { - std::string name_; - const void *data_{nullptr}; - nvinfer1::DataType data_type_; - size_t size_; - bool is_input_binding_{false}; -}; - -struct DynamicShapeParams { - bool support_dynamic_{true}; - bool support_hw_dynamic_{true}; -}; - -class TensorRTRuntime; - -using BaseOperatorPtr = std::shared_ptr; - -class TensorRTOp { - public: - TensorRTOp(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name); - - virtual ~TensorRTOp() = default; - - virtual int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) = 0; - - // The weight input has been processed internally by the operator. The framework does not - // need to process the weight input. - virtual bool IsWeightInputHanledInner() const { return false; } - - virtual int AddInnerOp(TensorRTContext *ctx) = 0; - - virtual int SetInt8DynamicRange(TensorRTContext *ctx); - - virtual int Prepare(void **network_tensor_bindings, nvinfer1::ICudaEngine *engine); - - const BaseOperatorPtr &GetBaseOperator(); - - virtual bool HasConst() const; - - int ReadyInputsNumber(TensorRTContext *ctx) const; - - std::string GetOpName(); - - std::vector &inputs(); - - ITensorHelper input(TensorRTContext *ctx, size_t i); - - ITensorHelper output(TensorRTContext *ctx, size_t i); - - std::vector &outputs(); - - const std::string &type() const; - - schema::QuantType GetQuantType() const; - - void set_in_ops(const std::vector &in_ops); - - void set_out_ops(const std::vector &out_ops); - - const std::vector &in_ops() const; - - const std::vector &out_ops() const; - - void SetRuntime(TensorRTRuntime *runtime); - cublasHandle_t GetCublasHandle() { return runtime_ ? runtime_->GetCublasHandle() : nullptr; } - cublasLtHandle_t GetCublasLtHandle() { return runtime_ ? runtime_->GetCublasLtHandle() : nullptr; } - - DynamicShapeParams GetDynamicShapeParams() const; - - nvinfer1::ILayer *layer() { return layer_; } - - bool GetSupportInputBool(); - bool IsDynamicInput(TensorRTContext *ctx, size_t k); - - void SetSupportInputBool(bool support_input_bool); - template - std::shared_ptr AsOps() { - return std::make_shared(base_operator_->GetPrim()); - } - - template - static std::shared_ptr AsOps(const BaseOperatorPtr &base_operator) { - return std::make_shared(base_operator->GetPrim()); - } - void PrintTrtInputs(TensorRTContext *ctx); - void PrintTrtOutputs(TensorRTContext *ctx); - - private: - int SetTransposeDynamicRange(); - - protected: - bool IsShapeKnown(); - - nvinfer1::ILayer *layer_ = nullptr; - - nvinfer1::IShuffleLayer *transpose_layer_ = nullptr; - - BaseOperatorPtr base_operator_ = nullptr; - std::vector in_tensors_; - std::vector out_tensors_; - - std::vector in_ops_; - - std::vector out_ops_; - - std::string op_name_; - - std::string type_; - - schema::QuantType quant_type_ = schema::QuantType_QUANT_NONE; - - std::vector op_binding_tensor_; - - TensorRTRuntime *runtime_{nullptr}; - - DynamicShapeParams dynamic_shape_params_; - - uint32_t device_id_{0}; - - bool support_input_bool_{true}; -}; - -template -TensorRTOp *GetTensorRTOp(const BaseOperatorPtr &base_operator, const std::vector &inputs, - const std::vector &outputs, const std::string &name) { - auto *op = new (std::nothrow) T(base_operator, inputs, outputs, name); - if (op == nullptr) { - MS_LOG(WARNING) << "TensorRT is nullptr."; - return nullptr; - } - - auto ret = op->IsSupport(base_operator, inputs, outputs); - if (ret != RET_OK) { - MS_LOG(WARNING) << "TensorRT op is not supported: " << name; - delete op; - return nullptr; - } - return op; -} -typedef TensorRTOp *(*TensorRTGetOp)(const BaseOperatorPtr &base_operator, const std::vector &inputs, - const std::vector &outputs, const std::string &name); - -#define REGISTER_TENSORRT_CREATOR(KEY, TENSORRT_OP) \ - REGISTER_CLASS_CREATOR(std::string, KEY, TensorRTGetOp, GetTensorRTOp); - -using TensorRTRegistrationFactory = AutoRegistrationFactory; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_TENSORRT_OP_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/tensorrt_plugin.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/tensorrt_plugin.cc deleted file mode 100644 index d80b59e55cc37b3d9415acc6812abf7aa7fd7596..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/tensorrt_plugin.cc +++ /dev/null @@ -1,81 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h" -#include - -namespace mindspore::lite { -void SerializeValue(void **buffer, const void *value, size_t cpy_size) { - memcpy(*buffer, value, cpy_size); - *buffer = static_cast(*buffer) + cpy_size; -} - -void DeserializeValue(void const **buffer, size_t *buffer_size, void *value, size_t cpy_size) { - if (cpy_size > *buffer_size) { - MS_LOG(ERROR) << "invalid desirialize size, buffer size: " << *buffer_size << ", value size: " << cpy_size; - return; - } - memcpy(value, *buffer, cpy_size); - *buffer = static_cast(*buffer) + cpy_size; - *buffer_size -= cpy_size; -} - -nvinfer1::DimsExprs TensorRTPlugin::getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, - int nbInputs, nvinfer1::IExprBuilder &exprBuilder) noexcept { - return inputs[0]; -} - -bool TensorRTPlugin::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs, - int nbOutputs) noexcept { - return true; -} - -void TensorRTPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *out, int nbOutputs) noexcept {} - -size_t TensorRTPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const noexcept { - return 0; -} - -nvinfer1::DataType TensorRTPlugin::getOutputDataType(int index, const nvinfer1::DataType *inputTypes, - int nbInputs) const noexcept { - return inputTypes[0]; -} - -const char *TensorRTPlugin::getPluginType() const noexcept { return plugin_name_.c_str(); } - -const char *TensorRTPlugin::getPluginVersion() const noexcept { return plugin_version_.c_str(); } - -int TensorRTPlugin::getNbOutputs() const noexcept { return 1; } - -int TensorRTPlugin::initialize() noexcept { return 0; } - -void TensorRTPlugin::terminate() noexcept {} - -size_t TensorRTPlugin::getSerializationSize() const noexcept { return 0; } - -void TensorRTPlugin::serialize(void *buffer) const noexcept {} - -void TensorRTPlugin::destroy() noexcept { - // This gets called when the network containing plugin is destroyed - delete this; -} - -void TensorRTPlugin::setPluginNamespace(const char *libNamespace) noexcept { name_space_ = libNamespace; } - -const char *TensorRTPlugin::getPluginNamespace() const noexcept { return name_space_.c_str(); } -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h deleted file mode 100644 index f2d5f708735d414e727c113da0d83667414cc621..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h +++ /dev/null @@ -1,106 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_TENSORRT_PLUGIN_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_TENSORRT_PLUGIN_H_ -#include -#include -#include "src/common/log_adapter.h" -#include "include/errorcode.h" -#include "NvInferRuntimeCommon.h" -#include - -namespace mindspore::lite { -void SerializeValue(void **buffer, const void *value, size_t cpy_size); -void DeserializeValue(void const **buffer, size_t *buffer_size, void *value, size_t cpy_size); -class TensorRTPlugin : public nvinfer1::IPluginV2DynamicExt { - public: - TensorRTPlugin(const std::string &layer_name, const std::string &plugin_name, uint32_t device_id = 0) - : layer_name_(layer_name), plugin_name_(plugin_name), device_id_(device_id) {} - - // It doesn't make sense to make GeluPluginDynamic without arguments, so we delete - // default constructor. - TensorRTPlugin() = delete; - - // IPluginV2DynamicExt Methods - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, - nvinfer1::IExprBuilder &exprBuilder) noexcept override; - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs, - int nbOutputs) noexcept override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *out, int nbOutputs) noexcept override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const noexcept override; - - // IPluginV2Ext Methods - nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, int nbInputs) const - noexcept override; - - // IPluginV2 Methods - const char *getPluginType() const noexcept override; - const char *getPluginVersion() const noexcept override; - int getNbOutputs() const noexcept override; - int initialize() noexcept override; - void terminate() noexcept override; - size_t getSerializationSize() const noexcept override; - void serialize(void *buffer) const noexcept override; - void destroy() noexcept override; - void setPluginNamespace(const char *pluginNamespace) noexcept override; - const char *getPluginNamespace() const noexcept override; - - protected: - std::string layer_name_; - std::string name_space_; - std::string plugin_version_{"1"}; - std::string plugin_name_; - uint32_t device_id_{0}; -}; - -template -class TensorRTPluginCreater : public nvinfer1::IPluginCreator { - public: - explicit TensorRTPluginCreater(const std::string &plugin_name) : plugin_name_(plugin_name) { - // Fill PluginFieldCollection with PluginField arguments metadata - field_collection_.nbFields = fields_.size(); - field_collection_.fields = fields_.data(); - } - - const char *getPluginName() const noexcept override { return plugin_name_.c_str(); } - - const char *getPluginVersion() const noexcept override { return plugin_version_.c_str(); } - - const nvinfer1::PluginFieldCollection *getFieldNames() noexcept override { return &field_collection_; } - - void setPluginNamespace(const char *pluginNamespace) noexcept override { name_space_ = std::string(pluginNamespace); } - - const char *getPluginNamespace() const noexcept override { return name_space_.c_str(); } - - nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) noexcept override { - return new (std::nothrow) T(name, fc); - } - - nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *data, size_t len) noexcept override { - return new (std::nothrow) T(name, data, len); - } - - protected: - static nvinfer1::PluginFieldCollection field_collection_; - static std::vector fields_; - std::string name_space_; - std::string plugin_version_{"1"}; - std::string plugin_name_; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_TENSORRT_PLUGIN_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/tensorscatteradd_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/tensorscatteradd_tensorrt.cc deleted file mode 100644 index 339a04bb605d37e34f46dcf84d921fee1051a2a1..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/tensorscatteradd_tensorrt.cc +++ /dev/null @@ -1,132 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/tensorscatteradd_tensorrt.h" -#include -#include -#include -#include "infer/tensor_scatter_add.h" -#include "kernel/gpu/cuda_impl/cuda_ops/tensor_scatter_arithmetic.cuh" -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" - -namespace mindspore::lite { -int TensorScatterAddTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != INPUT_SIZE3) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size() << " : " << op_name_; - return RET_ERROR; - } - - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size() << " : " << op_name_; - return RET_ERROR; - } - return RET_OK; -} - -int TensorScatterAddTensorRT::AddInnerOp(TensorRTContext *ctx) { - if (in_tensors_[0].IsConst()) { - ITensorHelper scatter_input; - scatter_input.trt_tensor_ = lite::ConvertConstantTensor(ctx, in_tensors_[0], op_name_); - scatter_input.format_ = Format::NCHW; - ctx->RegisterTensor(scatter_input, in_tensors_[0].Name()); - } - - nvinfer1::ITensor *inputTensors[] = {input(ctx, 0).trt_tensor_, input(ctx, 1).trt_tensor_, - input(ctx, INPUT_SIZE2).trt_tensor_}; - auto plugin = std::make_shared(input(ctx, 0).trt_tensor_->getName(), device_id_); - nvinfer1::IPluginV2Layer *scatter_layer = ctx->network()->addPluginV2(inputTensors, 3, *plugin); - if (scatter_layer == nullptr) { - MS_LOG(ERROR) << "addScatter failed for TensorRT."; - return RET_ERROR; - } - - nvinfer1::ITensor *out_tensor = scatter_layer->getOutput(0); - ctx->RegisterTensor(ITensorHelper{out_tensor, Format::NCHW, true}, out_tensors_[0].Name()); - this->layer_ = scatter_layer; - return RET_OK; -} - -REGISTER_TENSORRT_PLUGIN(TensorScatterAddPluginCreater); -template class TensorRTPluginCreater; -template -nvinfer1::PluginFieldCollection TensorRTPluginCreater::field_collection_{}; -template -std::vector TensorRTPluginCreater::fields_; - -int TensorScatterAddPlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workspace, cudaStream_t stream) noexcept { - return RunCudaTensorScatterAdd(inputDesc, inputs, outputs, stream); -} - -int TensorScatterAddPlugin::RunCudaTensorScatterAdd(const nvinfer1::PluginTensorDesc *inputDesc, - const void *const *inputs, void *const *outputs, - cudaStream_t stream) { - nvinfer1::Dims input_dims = inputDesc[0].dims; - size_t input_num = std::accumulate(input_dims.d, input_dims.d + input_dims.nbDims, 1, std::multiplies()); - nvinfer1::Dims update_dims = inputDesc[INPUT_SIZE2].dims; - size_t update_num = std::accumulate(update_dims.d, update_dims.d + update_dims.nbDims, 1, std::multiplies()); - - size_t indice_dim_0 = inputDesc[1].dims.d[0]; - size_t indice_dim_1 = inputDesc[1].dims.d[1]; - int block_size = 1; - for (int i = indice_dim_1; i != input_dims.nbDims; ++i) { - block_size *= input_dims.d[i]; - } - std::vector indice_stride(indice_dim_1, 0); - indice_stride[indice_stride.size() - 1] = block_size; - for (int i = indice_dim_1 - 1; i > 0; --i) { - indice_stride[i - 1] = indice_stride[i] * input_dims.d[i]; - } - - TensorScatterInfo info; - for (size_t i = 0; i < indice_dim_1; ++i) { - info.indices_stride[i] = static_cast(indice_stride[i]); - } - for (size_t i = 0; i < static_cast(input_dims.nbDims); ++i) { - info.work_shape[i] = static_cast(input_dims.d[i]); - } - cudaMemcpy(outputs[0], inputs[0], input_num * sizeof(float), cudaMemcpyDeviceToDevice); - TensorScatterArithmetic(TensorScatterArithmeticFunctionType::TENSOR_SCATTER_FUNC_ADD, - static_cast(inputs[0]), static_cast(inputs[1]), - static_cast(inputs[INPUT_SIZE2]), static_cast(outputs[0]), block_size, - update_num, input_num, indice_dim_0, indice_dim_1, info, device_id_, stream); - - return RET_OK; -} - -nvinfer1::IPluginV2DynamicExt *TensorScatterAddPlugin::clone() const noexcept { - auto *plugin = new TensorScatterAddPlugin(*this); - plugin->setPluginNamespace(name_space_.c_str()); - return plugin; -} - -bool TensorScatterAddPlugin::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, - int nbInputs, int nbOutputs) noexcept { - if (tensorsDesc[pos].format != nvinfer1::TensorFormat::kLINEAR) { - return false; - } - return true; -} - -size_t TensorScatterAddPlugin::getSerializationSize() const noexcept { return 0; } - -void TensorScatterAddPlugin::serialize(void *buffer) const noexcept {} - -REGISTER_TENSORRT_CREATOR(ops::kNameTensorScatterAdd, TensorScatterAddTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/tensorscatteradd_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/tensorscatteradd_tensorrt.h deleted file mode 100644 index 2deef29ac248c0f6156ba434f359fcbc7df0e181..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/tensorscatteradd_tensorrt.h +++ /dev/null @@ -1,74 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_TENSOR_SCATTER_ADD_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_TENSOR_SCATTER_ADD_TENSORRT_H_ -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" -#include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h" - -namespace mindspore::lite { -class TensorScatterAddTensorRT : public TensorRTOp { - public: - TensorScatterAddTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~TensorScatterAddTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; -}; - -constexpr auto TENSORSCATTERADD_PLUGIN_NAME{"TensorScatterAddPlugin"}; -class TensorScatterAddPlugin : public TensorRTPlugin { - public: - TensorScatterAddPlugin(const std::string &name, int device_id) - : TensorRTPlugin(name, std::string(TENSORSCATTERADD_PLUGIN_NAME), device_id) {} - - TensorScatterAddPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - : TensorRTPlugin(std::string(name), std::string(TENSORSCATTERADD_PLUGIN_NAME)) {} - - TensorScatterAddPlugin(const char *name, const void *serialData, size_t serialLength) - : TensorRTPlugin(std::string(name), std::string(TENSORSCATTERADD_PLUGIN_NAME)) {} - - TensorScatterAddPlugin() = delete; - - nvinfer1::IPluginV2DynamicExt *clone() const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override; - size_t getSerializationSize() const noexcept override; - void serialize(void *buffer) const noexcept override; - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs, - int nbOutputs) noexcept override; - nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, int nbInputs) const - noexcept override { - return inputTypes[0]; - } - - private: - int RunCudaTensorScatterAdd(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, - void *const *outputs, cudaStream_t stream); -}; -class TensorScatterAddPluginCreater : public TensorRTPluginCreater { - public: - TensorScatterAddPluginCreater() : TensorRTPluginCreater(std::string(TENSORSCATTERADD_PLUGIN_NAME)) {} -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_OP_TENSOR_SCATTER_ADD_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/tile_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/tile_tensorrt.cc deleted file mode 100644 index c0372fc80715ffb135fa9f63d1b672370febb866..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/tile_tensorrt.cc +++ /dev/null @@ -1,78 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/tile_tensorrt.h" -#include -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "infer/cxx_api/tile_fusion.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" - -namespace mindspore::lite { -int TileTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != INPUT_SIZE2) { - MS_LOG(ERROR) << "invalid input tensor size: " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "invalid output tensor size: " << out_tensors.size(); - return RET_ERROR; - } - return RET_OK; -} - -int TileTensorRT::AddInnerOp(TensorRTContext *ctx) { - auto repeats_tensor = in_tensors_[1]; - ITensorHelper tile_input = input(ctx, 0); - nvinfer1::ITensor *output; - auto input_shape = ctx->network()->addShape(*input(ctx, 0).trt_tensor_)->getOutput(0); - if (repeats_tensor.IsConst()) { - if (repeats_tensor.ElementNum() != input(ctx, 0).trt_tensor_->getDimensions().nbDims) { - MS_LOG(ERROR) << op_name_ << " has input dims: " << input(ctx, 0).trt_tensor_->getDimensions().nbDims - << ", and invalid repeats cnt: " << repeats_tensor.ElementNum(); - return RET_ERROR; - } - auto ret = ParseData2Vector(in_tensors_[1], &repeats_); - if (ret != RET_OK || repeats_.size() == 0) { - MS_LOG(ERROR) << op_name_ << " has invalid repeats tensor"; - return ret; - } - std::vector repeats(repeats_.size()); - for (size_t i = 0; i != repeats_.size(); ++i) { - repeats[i] = static_cast(repeats_[i]); - } - auto repeat_tensor = ctx->ConvertTo1DTensor(repeats); - auto output_shape = - ctx->network()->addElementWise(*input_shape, *repeat_tensor, nvinfer1::ElementWiseOperation::kPROD)->getOutput(0); - output = Broadcast(ctx, tile_input.trt_tensor_, output_shape); - } else { - auto output_shape = - ctx->network() - ->addElementWise(*input_shape, *input(ctx, 1).trt_tensor_, nvinfer1::ElementWiseOperation::kPROD) - ->getOutput(0); - output = Broadcast(ctx, tile_input.trt_tensor_, output_shape); - } - auto layer = ctx->network()->addIdentity(*output); - layer_ = layer; - auto tile_out = layer->getOutput(0); - ctx->RegisterTensor(ITensorHelper{tile_out, tile_input.format_, true}, out_tensors_[0].Name()); - return RET_OK; -} -REGISTER_TENSORRT_CREATOR(ops::kNameTileFusion, TileTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/tile_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/tile_tensorrt.h deleted file mode 100644 index 860057702ba014d482de53192f563ed7a291fb52..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/tile_tensorrt.h +++ /dev/null @@ -1,45 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_TILE_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_TILE_TENSORRT_H_ -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" -#include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h" -#include "kernel/gpu/cuda_impl/cuda_ops/tile_impl.cuh" - -namespace mindspore::lite { -class TileTensorRT : public TensorRTOp { - public: - TileTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~TileTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - bool IsWeightInputHanledInner() const override { return true; } - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; - - private: - int RunAsConcat(TensorRTContext *ctx, const ITensorHelper &tile_input); - std::vector repeats_; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_TILE_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/topk_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/topk_tensorrt.cc deleted file mode 100644 index be6e0c2e03d49852763e4e6aece4c290b7e29ef2..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/topk_tensorrt.cc +++ /dev/null @@ -1,211 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "src/extendrt/delegate/tensorrt/op/topk_tensorrt.h" -#include "mindspore/ops/op_def/array_ops.h" -#include "infer/cxx_api/arg_max_fusion.h" -#include "infer/cxx_api/arg_min_fusion.h" -#include "infer/cxx_api/topk_fusion.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" - -namespace mindspore::lite { -namespace { -nvinfer1::ITensor *TopkReshape(TensorRTContext *ctx, nvinfer1::ITensor *input, int axis) { - auto squeeze = ctx->network()->addShuffle(*input); - if (squeeze == nullptr) { - return nullptr; - } - auto old_shape = ConvertMSShape(input->getDimensions()); - old_shape.erase(old_shape.begin() + axis); - squeeze->setReshapeDimensions(ConvertCudaDims(old_shape)); - return squeeze->getOutput(0); -} -} // namespace - -int TopKTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != 1 && in_tensors.size() != INPUT_SIZE2) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != 1 && out_tensors.size() != INPUT_SIZE2) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); - return RET_ERROR; - } - if (type_ != ops::kNameTopKFusion) { - // need reshape - dynamic_shape_params_.support_hw_dynamic_ = false; - } - if ((type_ == ops::kNameTopKFusion || type_ == ops::kNameTopK) && in_tensors.size() != INPUT_SIZE2) { - MS_LOG(ERROR) << "TopkFusion or Topk need 2 input tensors for " << op_name_; - return RET_ERROR; - } - return RET_OK; -} - -int TopKTensorRT::AddInnerOp(TensorRTContext *ctx) { - if (ctx->network() == nullptr || ReadyInputsNumber(ctx) != 1) { - MS_LOG(ERROR) << "network or input tensor is invalid"; - return RET_ERROR; - } - int ret = ParseParams(ctx); - if (ret != RET_OK) { - MS_LOG(ERROR) << "ParseParams failed for " << op_name_; - return ret; - } - - ITensorHelper topk_input; - ret = PreprocessInputs(ctx, &topk_input); - if (ret != RET_OK || topk_input.trt_tensor_ == nullptr) { - MS_LOG(ERROR) << "preprocess input failed for " << op_name_; - return ret; - } - axis_ = 1 << axis_value_; - MS_LOG(DEBUG) << "addTopK input " << GetTensorFormat(topk_input); - MS_LOG(DEBUG) << op_name_ << " has k: " << top_k_ << ", axis: " << axis_value_; - bool need_expand = (topk_input.trt_tensor_->getDimensions().nbDims == 1); - if (need_expand == true) { - topk_input.trt_tensor_ = ExpandDim(ctx, topk_input.trt_tensor_, 0); - axis_ = INPUT_SIZE2; - } - - nvinfer1::ITopKLayer *topk_layer; - if (topk_input.trt_tensor_->getType() == nvinfer1::DataType::kINT32) { - MS_LOG(INFO) << "trt op topk not support INT32 as input, cast to float."; - auto cast_layer = ctx->network()->addIdentity(*topk_input.trt_tensor_); - CHECK_NULL_RETURN(cast_layer); - cast_layer->setOutputType(0, nvinfer1::DataType::kFLOAT); - auto cast_output = cast_layer->getOutput(0); - CHECK_NULL_RETURN(cast_output); - cast_layer->setName((op_name_ + "_cast").c_str()); - topk_layer = ctx->network()->addTopK(*cast_output, topk_op_, top_k_, axis_); - } else { - topk_layer = ctx->network()->addTopK(*topk_input.trt_tensor_, topk_op_, top_k_, axis_); - } - - CHECK_NULL_RETURN(topk_layer); - this->layer_ = topk_layer; - topk_layer->setName(op_name_.c_str()); - nvinfer1::ITensor *value_out_tensor = topk_layer->getOutput(0); - nvinfer1::ITensor *index_out_tensor = topk_layer->getOutput(1); - if (need_expand == true) { - auto shape = ConvertMSShape(value_out_tensor->getDimensions()); - shape.erase(shape.begin()); - value_out_tensor = Reshape(ctx, value_out_tensor, shape); - index_out_tensor = Reshape(ctx, index_out_tensor, shape); - } - // output 0 is data value, output 1 is index - - if (top_k_ == 1 && type_ != ops::kNameTopKFusion && keep_dims_ == false) { - value_out_tensor = TopkReshape(ctx, value_out_tensor, axis_value_); - if (value_out_tensor == nullptr) { - MS_LOG(ERROR) << "add output squeeze failed!"; - return RET_ERROR; - } - index_out_tensor = TopkReshape(ctx, index_out_tensor, axis_value_); - if (index_out_tensor == nullptr) { - MS_LOG(ERROR) << "add output squeeze failed!"; - return RET_ERROR; - } - } - if (out_tensors_.size() == INPUT_SIZE2) { - auto out_tensor = (out_tensors_[1].DataType() == DataType::kNumberTypeInt32) ? index_out_tensor : value_out_tensor; - auto output_helper = ITensorHelper{out_tensor, topk_input.format_, true}; - ctx->RegisterTensor(output_helper, out_tensors_[1].Name()); - } - auto out_tensor = (out_tensors_[0].DataType() == DataType::kNumberTypeInt32) ? index_out_tensor : value_out_tensor; - auto output_helper = ITensorHelper{out_tensor, topk_input.format_, true}; - ctx->RegisterTensor(output_helper, out_tensors_[0].Name()); - return RET_OK; -} - -int TopKTensorRT::ParseParams(TensorRTContext *ctx) { - int input_nbDims = input(ctx, 0).trt_tensor_->getDimensions().nbDims; - if (type_ == ops::kNameArgMinFusion || type_ == ops::kNameArgMaxFusion) { - std::unordered_map type2op = { - {ops::kNameArgMaxFusion, nvinfer1::TopKOperation::kMAX}, {ops::kNameArgMinFusion, nvinfer1::TopKOperation::kMIN}}; - topk_op_ = type2op[type_]; - auto prim = AsOps(); - CHECK_NULL_RETURN(prim); - axis_value_ = prim->get_axis(); - axis_value_ = axis_value_ >= 0 ? axis_value_ : input_nbDims + axis_value_; - if (prim->HasAttr(ops::kKeepDims)) { - keep_dims_ = prim->get_keep_dims(); - } - top_k_ = prim->HasAttr(ops::kTopK) ? prim->get_top_k() : 1; - } - if (type_ == ops::kNameTopKFusion) { - auto topk_prim = AsOps(); - CHECK_NULL_RETURN(topk_prim); - if (topk_prim->HasAttr(ops::kLargest)) { - topk_op_ = topk_prim->get_largest() == 1 ? nvinfer1::TopKOperation::kMAX : nvinfer1::TopKOperation::kMIN; - } else { - MS_LOG(INFO) << "No attribute Largest for TopKFusion, use Default: MAX"; - topk_op_ = nvinfer1::TopKOperation::kMAX; - } - - if (topk_prim->HasAttr(ops::kAxis)) { - axis_value_ = topk_prim->get_axis(); - } else { - MS_LOG(INFO) << "No attribute Axis for TopKFusion, use Default: input dims - 1"; - axis_value_ = input_nbDims - 1; - } - axis_value_ = axis_value_ >= 0 ? axis_value_ : input_nbDims + axis_value_; - std::vector tmp(1); - int ret_k = ParseData2Vector(in_tensors_[1], &tmp); - if (ret_k != RET_OK) { - return ret_k; - } - top_k_ = tmp[0]; - } - if (type_ == ops::kNameTopK) { - auto topk_prim = AsOps(); - CHECK_NULL_RETURN(topk_prim); - topk_op_ = nvinfer1::TopKOperation::kMAX; - - axis_value_ = input_nbDims - 1; - std::vector tmp(1); - int ret_k = ParseData2Vector(in_tensors_[1], &tmp); - if (ret_k != RET_OK) { - return ret_k; - } - top_k_ = tmp[0]; - } - // Currently reduceAxes must specify exactly one dimension, and it must be one of the last four dimensions. - if (axis_value_ != input_nbDims - 1) { - MS_LOG(ERROR) << op_name_ << " has unsupported axis : " << axis_value_; - return RET_ERROR; - } - return RET_OK; -} - -int TopKTensorRT::PreprocessInputs(TensorRTContext *ctx, ITensorHelper *topk_input) { - auto input_dim = input(ctx, 0).trt_tensor_->getDimensions(); - int ret = RET_OK; - if (input_dim.nbDims == DIMENSION_4D) { - ret = PreprocessInputs2SameDim(ctx, input(ctx, 0), topk_input); - } else { - *topk_input = input(ctx, 0); - } - return ret; -} -REGISTER_TENSORRT_CREATOR(ops::kNameArgMaxFusion, TopKTensorRT) -REGISTER_TENSORRT_CREATOR(ops::kNameArgMinFusion, TopKTensorRT) -REGISTER_TENSORRT_CREATOR(ops::kNameTopKFusion, TopKTensorRT) -REGISTER_TENSORRT_CREATOR(ops::kNameTopK, TopKTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/topk_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/topk_tensorrt.h deleted file mode 100644 index 875bb89788baaf943c28ba53a20aaf01284e0d24..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/topk_tensorrt.h +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_TOPK_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_TOPK_TENSORRT_H_ -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" - -namespace mindspore::lite { -class TopKTensorRT : public TensorRTOp { - public: - TopKTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~TopKTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &inputs, - const std::vector &outputs) override; - - private: - int ParseParams(TensorRTContext *ctx); - - int PreprocessInputs(TensorRTContext *ctx, ITensorHelper *topk_input); - - nvinfer1::TopKOperation topk_op_{nvinfer1::TopKOperation::kMAX}; - uint32_t axis_{0}; - int axis_value_{0}; - int32_t top_k_{0}; - bool keep_dims_{false}; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_TOPK_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/unary_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/unary_tensorrt.cc deleted file mode 100644 index 235f12171e428321312584ec5a5447312c72d4db..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/unary_tensorrt.cc +++ /dev/null @@ -1,129 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/op/unary_tensorrt.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_e.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_f.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_n.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" - -namespace mindspore::lite { -UnaryTensorRT::UnaryTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) { - unary_ops_ = { - {ops::kNameSqrt, nvinfer1::UnaryOperation::kSQRT}, - {ops::kNameAbs, nvinfer1::UnaryOperation::kABS}, - {ops::kNameNeg, nvinfer1::UnaryOperation::kNEG}, - {ops::kNameLog, nvinfer1::UnaryOperation::kLOG}, - {ops::kNameSin, nvinfer1::UnaryOperation::kSIN}, - {ops::kNameCos, nvinfer1::UnaryOperation::kCOS}, - {ops::kNameCeil, nvinfer1::UnaryOperation::kCEIL}, - {ops::kNameFloor, nvinfer1::UnaryOperation::kFLOOR}, - {ops::kNameExpFusion, nvinfer1::UnaryOperation::kEXP}, -#if TRT_VERSION_GE(7, 2) - {ops::kNameLogicalNot, nvinfer1::UnaryOperation::kNOT}, -#endif - }; -} - -int UnaryTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (!IsShapeKnown()) { - MS_LOG(ERROR) << "Unsupported input tensor unknown shape: " << op_name_; - return RET_ERROR; - } - if (in_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); - } - auto it = unary_ops_.find(type_); - if (it != unary_ops_.end()) { - unary_op_ = it->second; - } else { - MS_LOG(ERROR) << "unsupported unary ops type: " << type_; - return RET_ERROR; - } - return RET_OK; -} - -int UnaryTensorRT::AddInnerOp(TensorRTContext *ctx) { - if (ctx == nullptr || ctx->network() == nullptr) { - MS_LOG(ERROR) << "network or input tensor is invalid"; - return RET_ERROR; - } - if (type_ == ops::kNameNeg && input(ctx, 0).trt_tensor_->getType() == nvinfer1::DataType::kINT32) { - auto trt_in_tensor = input(ctx, 0).trt_tensor_; - size_t dims_size = mindspore::IntToSize(trt_in_tensor->getDimensions().nbDims); - static const float neg1_const = -1; - auto prod_input1 = - ConvertScalarToITensor(ctx, dims_size, &neg1_const, DataType::kNumberTypeInt32, op_name_ + "_neg1"); - CHECK_NULL_RETURN(prod_input1); - auto prod_layer = - ctx->network()->addElementWise(*trt_in_tensor, *prod_input1, nvinfer1::ElementWiseOperation::kPROD); - CHECK_NULL_RETURN(prod_layer); - auto out_tensor = prod_layer->getOutput(0); - CHECK_NULL_RETURN(out_tensor); - ctx->RegisterTensor(ITensorHelper{out_tensor, input(ctx, 0).format_, input(ctx, 0).same_format_}, - out_tensors_[0].Name()); - prod_layer->setName(op_name_.c_str()); - this->layer_ = prod_layer; - ctx->RegisterLayer(prod_layer, op_name_); - return RET_OK; - } - nvinfer1::IUnaryLayer *cal_layer = ctx->network()->addUnary(*input(ctx, 0).trt_tensor_, unary_op_); - if (cal_layer == nullptr) { - MS_LOG(ERROR) << "addUnary failed for: " << op_name_; - return RET_ERROR; - } - cal_layer->setName(op_name_.c_str()); - this->layer_ = cal_layer; - if (type_ == ops::kNameExpFusion) { - auto exp_op = AsOps(); - CHECK_NULL_RETURN(exp_op); - if (exp_op->HasAttr(ops::kScale) && exp_op->HasAttr(ops::kShift) && exp_op->HasAttr(ops::kBase)) { - float scale = exp_op->get_scale(); - float shift = exp_op->get_shift(); - float base = exp_op->get_base(); - if (scale != 1.0f || shift != 0.0f || base != -1.0f) { - MS_LOG(ERROR) << op_name_ << " has fusion to calculate."; - return RET_ERROR; - } - } - } - nvinfer1::ITensor *op_out_tensor = cal_layer->getOutput(0); - ctx->RegisterTensor(ITensorHelper{op_out_tensor, input(ctx, 0).format_, input(ctx, 0).same_format_}, - out_tensors_[0].Name()); - return RET_OK; -} -REGISTER_TENSORRT_CREATOR(ops::kNameSqrt, UnaryTensorRT) -REGISTER_TENSORRT_CREATOR(ops::kNameAbs, UnaryTensorRT) -REGISTER_TENSORRT_CREATOR(ops::kNameNeg, UnaryTensorRT) -REGISTER_TENSORRT_CREATOR(ops::kNameLog, UnaryTensorRT) -REGISTER_TENSORRT_CREATOR(ops::kNameSin, UnaryTensorRT) -REGISTER_TENSORRT_CREATOR(ops::kNameCos, UnaryTensorRT) -REGISTER_TENSORRT_CREATOR(ops::kNameCeil, UnaryTensorRT) -REGISTER_TENSORRT_CREATOR(ops::kNameFloor, UnaryTensorRT) -REGISTER_TENSORRT_CREATOR(ops::kNameExpFusion, UnaryTensorRT) -#if TRT_VERSION_GE(7, 2) -REGISTER_TENSORRT_CREATOR(ops::kNameLogicalNot, UnaryTensorRT) -#endif -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/unary_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/unary_tensorrt.h deleted file mode 100644 index aa3d607a9a96450e2b0bd5f901e364ecbaaa100a..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/unary_tensorrt.h +++ /dev/null @@ -1,44 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_UNARY_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_UNARY_TENSORRT_H_ -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" - -#include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "infer/cxx_api/exp_fusion.h" - -namespace mindspore::lite { -class UnaryTensorRT : public TensorRTOp { - public: - UnaryTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, std::string name); - - ~UnaryTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; - - private: - std::map unary_ops_; - nvinfer1::UnaryOperation unary_op_; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_UNARY_TENSORRT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/where_tensorrt.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/op/where_tensorrt.cc deleted file mode 100644 index 3287e5ab3ba45252ba5389c07bdfbf7e814f767c..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/where_tensorrt.cc +++ /dev/null @@ -1,194 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "NvInferRuntimeCommon.h" -#include "src/extendrt/delegate/tensorrt/op/where_tensorrt.h" -#include "kernel/gpu/cuda_impl/cuda_ops/where_impl.cuh" -#include "infer/where.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_w.h" - -namespace mindspore::lite { -constexpr int INPUT_X_INDEX = 1; -constexpr int INPUT_Y_INDEX = 2; - -int WhereTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) { - if (in_tensors.size() != 1 && in_tensors.size() != INPUT_SIZE3) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); - return RET_ERROR; - } - if (out_tensors.size() != 1) { - MS_LOG(ERROR) << "Unsupported output tensor size, size is " << out_tensors.size(); - return RET_ERROR; - } - return RET_OK; -} - -nvinfer1::ITensor *WhereTensorRT::GetBroadcastTensor(TensorRTContext *ctx, nvinfer1::ITensor *input_tensor) { - auto input_cond_dims = input(ctx, 0).trt_tensor_->getDimensions(); - nvinfer1::Dims in_tensor_dims = input_tensor->getDimensions(); - while (in_tensor_dims.nbDims < input_cond_dims.nbDims) { - input_tensor = ExpandDim(ctx, input_tensor, 0); - if (input_tensor->getDimensions().nbDims == -1) { - MS_LOG(ERROR) << "ConvertCudaDims failed for " << op_name_; - } - nvinfer1::IShuffleLayer *shuffle_layer = ctx->network()->addShuffle(*input_tensor); - shuffle_layer->setReshapeDimensions(input_tensor->getDimensions()); - input_tensor = shuffle_layer->getOutput(0); - in_tensor_dims = input_tensor->getDimensions(); - } - return input_tensor; -} - -int WhereTensorRT::AddInnerOp(TensorRTContext *ctx) { - if (ctx == nullptr || ctx->network() == nullptr) { - MS_LOG(ERROR) << "network or input tensor is invalid"; - return RET_ERROR; - } - auto input_x = input(ctx, INPUT_X_INDEX).trt_tensor_; - auto input_y = input(ctx, INPUT_Y_INDEX).trt_tensor_; - if (in_tensors_[INPUT_X_INDEX].DataType() != in_tensors_[INPUT_Y_INDEX].DataType()) { - auto target_type_index = - INPUT_X_INDEX + (in_tensors_[INPUT_X_INDEX].DataType() < in_tensors_[INPUT_Y_INDEX].DataType()); - if (INPUT_X_INDEX != target_type_index) { - input_x = TRTTensorCast(ctx, input_x, ConvertDataType(in_tensors_[INPUT_X_INDEX].DataType()), op_name_ + "_cast"); - } - if (INPUT_Y_INDEX != target_type_index) { - input_y = TRTTensorCast(ctx, input_y, ConvertDataType(in_tensors_[INPUT_Y_INDEX].DataType()), op_name_ + "_cast"); - } - } - nvinfer1::ITensor *inputTensors[] = {input(ctx, 0).trt_tensor_, input_x, input_y}; - ITensorHelper cond_helper = input(ctx, 0); - if (cond_helper.trt_tensor_->getType() != nvinfer1::DataType::kINT32) { - inputTensors[0] = TRTTensorCast(ctx, input(ctx, 0).trt_tensor_, nvinfer1::DataType::kINT32, op_name_ + "_cast_in"); - } - auto input_x_dims = input(ctx, INPUT_X_INDEX).trt_tensor_->getDimensions(); - auto input_y_dims = input(ctx, INPUT_Y_INDEX).trt_tensor_->getDimensions(); - // broadcast to same shape - if (input_x_dims.nbDims != input_y_dims.nbDims) { - if (input_x_dims.nbDims > input_y_dims.nbDims) { - auto input_shape_tensor = ctx->network()->addShape(*input(ctx, INPUT_X_INDEX).trt_tensor_)->getOutput(0); - auto inputy = GetBroadcastTensor(ctx, input_y); - auto size_tensor = ctx->network()->addShape(*inputy)->getOutput(0); - size_tensor = ctx->network() - ->addElementWise(*input_shape_tensor, *size_tensor, nvinfer1::ElementWiseOperation::kMAX) - ->getOutput(0); - inputTensors[INPUT_Y_INDEX] = Broadcast(ctx, inputy, size_tensor); - } else { - auto input_shape_tensor = ctx->network()->addShape(*input(ctx, INPUT_Y_INDEX).trt_tensor_)->getOutput(0); - auto inputx = GetBroadcastTensor(ctx, input_x); - auto size_tensor = ctx->network()->addShape(*inputx)->getOutput(0); - size_tensor = ctx->network() - ->addElementWise(*input_shape_tensor, *size_tensor, nvinfer1::ElementWiseOperation::kMAX) - ->getOutput(0); - inputTensors[INPUT_X_INDEX] = Broadcast(ctx, inputx, size_tensor); - } - } - - auto plugin = std::make_shared(op_name_); - if (plugin == nullptr) { - MS_LOG(ERROR) << "create WherePlugin failed for " << op_name_; - return RET_ERROR; - } - nvinfer1::IPluginV2Layer *where_layer = ctx->network()->addPluginV2(inputTensors, 3, *plugin); - this->layer_ = where_layer; - nvinfer1::ITensor *op_out_tensor = where_layer->getOutput(0); - if (op_out_tensor == nullptr) { - MS_LOG(ERROR) << "where out tensor is nullptr."; - return RET_ERROR; - } - ctx->RegisterTensor(ITensorHelper{op_out_tensor, input(ctx, 0).format_, input(ctx, 0).same_format_}, - out_tensors_[0].Name()); - - return RET_OK; -} - -REGISTER_TENSORRT_PLUGIN(WherePluginCreater); -template class TensorRTPluginCreater; -template -nvinfer1::PluginFieldCollection TensorRTPluginCreater::field_collection_{}; -template -std::vector TensorRTPluginCreater::fields_; - -int WherePlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, - cudaStream_t stream) noexcept { - return RunCudaWhere(inputDesc, inputs, outputs, stream); -} - -int WherePlugin::RunCudaWhere(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, - void *const *outputs, cudaStream_t stream) { - if (inputDesc[0].type == nvinfer1::DataType::kINT32 && inputDesc[INPUT_Y_INDEX].type == nvinfer1::DataType::kINT32) { - Where(static_cast(inputs[0]), static_cast(inputs[INPUT_X_INDEX]), - static_cast(inputs[INPUT_Y_INDEX]), static_cast(outputs[0]), - GetDimsVolume(inputDesc[0].dims), device_id_, stream); - } else if (inputDesc[0].type == nvinfer1::DataType::kINT32 && - inputDesc[INPUT_Y_INDEX].type == nvinfer1::DataType::kFLOAT) { - Where(static_cast(inputs[0]), static_cast(inputs[INPUT_X_INDEX]), - static_cast(inputs[INPUT_Y_INDEX]), static_cast(outputs[0]), - GetDimsVolume(inputDesc[0].dims), device_id_, stream); - } else if (inputDesc[0].type == nvinfer1::DataType::kINT32 && - inputDesc[INPUT_Y_INDEX].type == nvinfer1::DataType::kHALF) { - Where(static_cast(inputs[0]), static_cast(inputs[INPUT_X_INDEX]), - static_cast(inputs[INPUT_Y_INDEX]), static_cast(outputs[0]), - GetDimsVolume(inputDesc[0].dims), device_id_, stream); - } else { - MS_LOG(ERROR) << "invalid where type"; - return RET_ERROR; - } - return RET_OK; -} // namespace mindspore::lite - -nvinfer1::IPluginV2DynamicExt *WherePlugin::clone() const noexcept { - auto *plugin = new WherePlugin(*this); - plugin->setPluginNamespace(name_space_.c_str()); - return plugin; -} - -nvinfer1::DataType WherePlugin::getOutputDataType(int index, const nvinfer1::DataType *inputTypes, int nbInputs) const - noexcept { - return inputTypes[1]; -} - -bool WherePlugin::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs, - int nbOutputs) noexcept { - if (tensorsDesc[pos].format != nvinfer1::TensorFormat::kLINEAR) { - return false; - } - if (pos == 0) { - return tensorsDesc[pos].type == nvinfer1::DataType::kINT32; - } else if (pos == 1) { - return tensorsDesc[pos].type == nvinfer1::DataType::kFLOAT || tensorsDesc[pos].type == nvinfer1::DataType::kINT32; - } else { - return tensorsDesc[pos].type == tensorsDesc[pos - 1].type; - } - MS_LOG(ERROR) << tensorsDesc[pos].type << " " << tensorsDesc[pos].format; - return false; -} - -size_t WherePlugin::getSerializationSize() const noexcept { return sizeof(schema::PrimitiveType); } - -void WherePlugin::serialize(void *buffer) const noexcept {} - -REGISTER_TENSORRT_CREATOR(ops::kNameWhere, WhereTensorRT) -REGISTER_TENSORRT_CREATOR(ops::kNameSelect, WhereTensorRT) -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op/where_tensorrt.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op/where_tensorrt.h deleted file mode 100644 index edfe863844766b84404c7bc03dc5e8771eaef21a..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op/where_tensorrt.h +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_WHERE_TENSORRT_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_WHERE_TENSORRT_H_ - -#include -#include -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" -#include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h" - -namespace mindspore::lite { -class WhereTensorRT : public TensorRTOp { - public: - WhereTensorRT(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors, const std::string &name) - : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} - - ~WhereTensorRT() override = default; - - int AddInnerOp(TensorRTContext *ctx) override; - - int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, - const std::vector &out_tensors) override; - - private: - nvinfer1::ITensor *GetBroadcastTensor(TensorRTContext *ctx, nvinfer1::ITensor *input_tensor); -}; - -constexpr auto WHERE_PLUGIN_NAME{"WherePlugin"}; -class WherePlugin : public TensorRTPlugin { - public: - explicit WherePlugin(const std::string name) : TensorRTPlugin(name, std::string(WHERE_PLUGIN_NAME)) {} - - WherePlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - : TensorRTPlugin(std::string(name), std::string(WHERE_PLUGIN_NAME)) {} - - WherePlugin(const char *name, const void *serialData, size_t serialLength) - : TensorRTPlugin(std::string(name), std::string(WHERE_PLUGIN_NAME)) {} - - WherePlugin() = delete; - - nvinfer1::IPluginV2DynamicExt *clone() const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override; - size_t getSerializationSize() const noexcept override; - void serialize(void *buffer) const noexcept override; - nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, int nbInputs) const - noexcept override; - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs, - int nbOutputs) noexcept override; - - private: - int RunCudaWhere(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, void *const *outputs, - cudaStream_t stream); -}; -class WherePluginCreater : public TensorRTPluginCreater { - public: - WherePluginCreater() : TensorRTPluginCreater(std::string(WHERE_PLUGIN_NAME)) {} -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_TENSORRT_OP_WHERE_PLUGIN_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/op_registration_factory.h b/mindspore-lite/src/extendrt/delegate/tensorrt/op_registration_factory.h deleted file mode 100644 index 705ae8cdba873709ff007131e5198a1b4555478a..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/op_registration_factory.h +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_AUTO_REGISTRATION_FACTORY_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_AUTO_REGISTRATION_FACTORY_H_ - -#include - -namespace mindspore::lite { -template -class AutoRegistrationFactory { - public: - struct AutoRegister { - AutoRegister(KeyType k, CreatorType creator) { - AutoRegistrationFactory::Get().Insert(k, creator); - } - }; - static AutoRegistrationFactory &Get(); - bool HasKey(KeyType k) const { return key2creator_.find(k) != key2creator_.end(); } - CreatorType GetCreator(KeyType k) { return key2creator_[k]; } - - private: - bool Insert(KeyType k, CreatorType creator) { - if (HasKey(k)) { - return false; - } - return key2creator_.emplace(k, creator).second; - } - std::unordered_map key2creator_; -}; - -#define AUTO_REGISTRATION_FACTORY_JOIN(a, b) a##b - -#define AUTO_REGISTRATION_FACTORY_UNIQUE_NAME_JOIN(a, b) AUTO_REGISTRATION_FACTORY_JOIN(a, b) - -#define AUTO_REGISTRATION_FACTORY_UNIQUE_NAME AUTO_REGISTRATION_FACTORY_UNIQUE_NAME_JOIN(g_, __COUNTER__) - -#define REGISTER_CLASS_CREATOR(KeyType, k, CreatorType, creator) \ - static AutoRegistrationFactory::AutoRegister AUTO_REGISTRATION_FACTORY_UNIQUE_NAME(k, creator); -} // namespace mindspore::lite - -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_AUTO_REGISTRATION_FACTORY_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/optimizer/tensorrt_optimizer.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/optimizer/tensorrt_optimizer.cc deleted file mode 100644 index eebb15529bb23c1772f57b168475bc9ba74682be..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/optimizer/tensorrt_optimizer.cc +++ /dev/null @@ -1,280 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "extendrt/delegate/tensorrt/optimizer/tensorrt_optimizer.h" -#include "mindspore/ops/op_def/lite_ops.h" -#include "mindspore/ops/op_def/array_ops.h" -#include "mindspore/ops/op_def/framework_ops.h" -#include "nnacl/op_base.h" -#include "include/common/utils/anfalgo.h" -#include "mindspore/ccsrc/include/common/utils/utils.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_d.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_g.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" - -namespace mindspore { -tensor::TensorPtr TensorRtOptimizer::GetParameterValue(const CNodePtr &node, size_t parameter_index) { - if (node == nullptr) { - return nullptr; - } - auto input = common::AnfAlgo::GetInputNode(node, parameter_index); - if (!input->isa()) { - return nullptr; - } - auto parameter = input->cast(); - if (parameter == nullptr || !parameter->has_default()) { - return nullptr; - } - auto param_val = parameter->default_param(); - if (!param_val->isa()) { - return nullptr; - } - return param_val->cast(); -} - -std::vector TensorRtOptimizer::GetParameterIntValue(const CNodePtr &node, size_t parameter_index) { - auto tensor = GetParameterValue(node, parameter_index); - if (tensor == nullptr) { - MS_LOG(ERROR) << "tensor is nullptr!"; - return {}; - } - auto elem_num = tensor->ElementsNum(); - if (elem_num < 1) { - return {}; - } - auto data_c = tensor->data_c(); - if (data_c == nullptr) { - return {}; - } - std::vector ints; - auto type_id = tensor->Dtype()->type_id(); - if (type_id == TypeId::kNumberTypeInt32) { - auto int_val = reinterpret_cast(data_c); - for (int64_t i = 0; i < elem_num; i++) { - ints.push_back(int_val[i]); - } - } else if (type_id == TypeId::kNumberTypeInt64) { - auto int_val = reinterpret_cast(data_c); - for (int64_t i = 0; i < elem_num; i++) { - ints.push_back(LongToInt(int_val[i])); - } - } else { - return {}; - } - return ints; -} - -std::vector TensorRtOptimizer::GetParameterFloatValue(const CNodePtr &node, size_t parameter_index) { - auto tensor = GetParameterValue(node, parameter_index); - if (tensor == nullptr) { - MS_LOG(ERROR) << "tensor is nullptr!"; - return {}; - } - auto elem_num = tensor->ElementsNum(); - if (elem_num < 1) { - return {}; - } - auto data_c = tensor->data_c(); - if (data_c == nullptr) { - return {}; - } - std::vector floats; - auto type_id = tensor->Dtype()->type_id(); - if (type_id == TypeId::kNumberTypeInt32) { - auto int_val = reinterpret_cast(data_c); - for (int64_t i = 0; i < elem_num; i++) { - floats.push_back(IntToFloat(int_val[i])); - } - } else if (type_id == TypeId::kNumberTypeInt64) { - auto int_val = reinterpret_cast(data_c); - for (int64_t i = 0; i < elem_num; i++) { - floats.push_back(LongToFloat(int_val[i])); - } - } else if (type_id == TypeId::kNumberTypeFloat32) { - auto float_val = reinterpret_cast(data_c); - for (int64_t i = 0; i < elem_num; i++) { - floats.push_back(float_val[i]); - } - } else { - return {}; - } - return floats; -} - -bool TensorRtOptimizer::GetMatmulFactor(const AnfNodePtr &pack_input, float *matmul_factor, int32_t *sclice_index, - AnfNodePtr *shape_input) { - constexpr size_t expect_matmul_input_size = 2; - if (!common::AnfAlgo::CheckPrimitiveType(pack_input, prim::kPrimMulFusion)) { - return false; - } - auto matmul = pack_input->cast(); - if (common::AnfAlgo::GetInputNum(matmul) != expect_matmul_input_size) { - return false; - } - auto matmul_factors = GetParameterFloatValue(matmul, kIndex1); - if (matmul_factors.size() != 1) { - return false; - } - *matmul_factor = matmul_factors[0]; - auto matmul_input0 = common::AnfAlgo::GetInputNode(matmul, kIndex0); - if (!common::AnfAlgo::CheckPrimitiveType(matmul_input0, prim::kPrimStridedSlice)) { - return false; - } - auto slice_node = matmul_input0->cast(); - constexpr size_t slice_input_size = 4; - if (common::AnfAlgo::GetInputNum(slice_node) != slice_input_size) { - return false; - } - auto begin_vec = GetParameterIntValue(slice_node, kIndex1); - auto end_vec = GetParameterIntValue(slice_node, kIndex2); - auto stride_vec = GetParameterIntValue(slice_node, kIndex3); - if (begin_vec.size() != 1 || end_vec.size() != 1 || stride_vec.size() != 1) { - return false; - } - if (begin_vec[0] + 1 != end_vec[0]) { - return false; - } - *sclice_index = begin_vec[0]; - auto slice_input = common::AnfAlgo::GetInputNode(slice_node, kIndex0); - if (!common::AnfAlgo::CheckPrimitiveType(slice_input, prim::kPrimShape) && - !common::AnfAlgo::CheckPrimitiveType(slice_input, prim::kPrimTensorShape) && - !common::AnfAlgo::CheckPrimitiveType(slice_input, prim::kPrimDynamicShape)) { - return false; - } - auto shape_node = slice_input->cast(); - constexpr size_t reshape_input_size = 1; - if (common::AnfAlgo::GetInputNum(shape_node) < reshape_input_size) { - return false; - } - auto shape_input0 = common::AnfAlgo::GetInputNode(shape_node, kIndex0); - while (common::AnfAlgo::CheckPrimitiveType(shape_input0, prim::kPrimTranspose)) { - shape_input0 = common::AnfAlgo::GetInputNode(shape_input0->cast(), kIndex0); - } - *shape_input = shape_input0; - return true; -} - -bool TensorRtOptimizer::OptResizeScales(const FuncGraphPtr &func_graph, const CNodePtr &resize_node) { - auto resize_input1 = common::AnfAlgo::GetInputNode(resize_node, kIndex1); - if (resize_input1 == nullptr) { - return false; - } - if (!common::AnfAlgo::CheckPrimitiveType(resize_input1, prim::kPrimStack)) { - return false; - } - auto pack_node = resize_input1->cast(); - constexpr size_t expect_pack_input_size = 2; - if (common::AnfAlgo::GetInputNum(pack_node) != expect_pack_input_size) { - return false; - } - auto pack_input0 = common::AnfAlgo::GetInputNode(pack_node, kIndex0); - auto pack_input1 = common::AnfAlgo::GetInputNode(pack_node, kIndex1); - - float matmul_factor0 = 0.0; - int32_t slice_dim_input0 = 0; - AnfNodePtr shape0_input = nullptr; - if (!GetMatmulFactor(pack_input0, &matmul_factor0, &slice_dim_input0, &shape0_input)) { - return false; - } - float matmul_factor1 = 0.0; - int32_t slice_dim_input1 = 0; - AnfNodePtr shape1_input = nullptr; - if (!GetMatmulFactor(pack_input1, &matmul_factor1, &slice_dim_input1, &shape1_input)) { - return false; - } - auto resize_input0 = common::AnfAlgo::GetInputNode(resize_node, kIndex0); - while (common::AnfAlgo::CheckPrimitiveType(resize_input0, prim::kPrimTranspose)) { - resize_input0 = common::AnfAlgo::GetInputNode(resize_input0->cast(), kIndex0); - } - if (resize_input0 != shape0_input || resize_input0 != shape1_input) { - return false; - } - if (matmul_factor0 <= 0.0f || matmul_factor1 <= 0.0f) { - return false; - } - std::vector scales; - scales.push_back(1); - scales.push_back(1); - if ((slice_dim_input0 == kNCHW_H && slice_dim_input1 == kNCHW_W) || - (slice_dim_input0 == kNHWC_H && slice_dim_input1 == kNHWC_W)) { // 1,2 or 2,3 - scales.push_back(matmul_factor0); - scales.push_back(matmul_factor1); - } else if ((slice_dim_input1 == kNCHW_H && slice_dim_input0 == kNCHW_W) || - (slice_dim_input1 == kNHWC_H && slice_dim_input0 == kNHWC_W)) { - scales.push_back(matmul_factor1); - scales.push_back(matmul_factor0); - } else { - return false; - } - common::AnfAlgo::SetNodeAttr(kAttrScales, MakeValue(scales), resize_node); - return true; -} - -bool TensorRtOptimizer::OptResizeHeightWidth(const FuncGraphPtr &func_graph, const CNodePtr &resize_node) { - auto resize_input1 = common::AnfAlgo::GetInputNode(resize_node, kIndex1); - if (resize_input1 == nullptr) { - return false; - } - if (!common::AnfAlgo::CheckPrimitiveType(resize_input1, prim::kPrimGather)) { - return false; - } - auto gather_node = resize_input1->cast(); - constexpr size_t expect_gather_input_size = 2; - if (common::AnfAlgo::GetInputNum(gather_node) < expect_gather_input_size) { - return false; - } - auto gather_const = GetParameterIntValue(gather_node, kIndex1); - if (gather_const.size() != kDim2 || gather_const[0] != kNCHW_H || gather_const[1] != kNCHW_W) { - return false; - } - auto gather_input0 = common::AnfAlgo::GetInputNode(gather_node, kIndex0); - if (!common::AnfAlgo::CheckPrimitiveType(gather_input0, prim::kPrimConcat)) { - return false; - } - // input 0 is primitive, real input 1 index is kIndex1 + 1 - resize_node->set_input(kIndex1 + 1, gather_input0); - return true; -} - -void TensorRtOptimizer::RunOptimizer(const FuncGraphPtr &func_graph) { - MS_ASSERT(func_graph != nullptr); - auto node_list = TopoSort(func_graph->get_return()); - constexpr size_t resize_input_size = 2; - for (auto &node : node_list) { - if (node == nullptr) { - continue; - } - if (!common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimResize)) { - continue; - } - auto resize_cnode = node->cast(); - if (common::AnfAlgo::GetInputNum(resize_cnode) != resize_input_size) { - MS_LOG(WARNING) << "Input size " << common::AnfAlgo::GetInputNum(resize_cnode) << " of resize node " - << resize_cnode->fullname_with_scope() << " != 2"; - continue; - } - if (OptResizeScales(func_graph, resize_cnode)) { - continue; - } - if (OptResizeHeightWidth(func_graph, resize_cnode)) { - continue; - } - } -} -} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/optimizer/tensorrt_optimizer.h b/mindspore-lite/src/extendrt/delegate/tensorrt/optimizer/tensorrt_optimizer.h deleted file mode 100644 index 86aedba1bcd2895113c1dbf095af5c65e45d47c5..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/optimizer/tensorrt_optimizer.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_TENSORRT_OPTIMIZER_H -#define MINDSPORE_LITE_SRC_EXTENDRT_TENSORRT_OPTIMIZER_H -#include - -#include "include/api/kernel.h" -#include "include/errorcode.h" -#include "src/common/log_adapter.h" -#include "include/api/context.h" -#include "extendrt/delegate/factory.h" -#include "extendrt/session/lite_graph_executor.h" - -namespace mindspore { -class TensorRtOptimizer { - public: - void RunOptimizer(const FuncGraphPtr &func_graph); - - private: - bool OptResizeScales(const FuncGraphPtr &func_graph, const CNodePtr &resize_node); - bool OptResizeHeightWidth(const FuncGraphPtr &func_graph, const CNodePtr &resize_node); - tensor::TensorPtr GetParameterValue(const CNodePtr &node, size_t parameter_index); - std::vector GetParameterIntValue(const CNodePtr &node, size_t parameter_index); - std::vector GetParameterFloatValue(const CNodePtr &node, size_t parameter_index); - bool GetMatmulFactor(const AnfNodePtr &pack_input, float *matmul_factor, int32_t *sclice_index, - AnfNodePtr *shape_input); -}; -} // namespace mindspore -#endif // MINDSPORE_LITE_SRC_EXTENDRT_TENSORRT_OPTIMIZER_H diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/tensor_info.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/tensor_info.cc deleted file mode 100644 index 6a82200f33f24beb77d1cb3ac386282aa44aa79f..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/tensor_info.cc +++ /dev/null @@ -1,162 +0,0 @@ -/** - * Copyright 2020-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/tensor_info.h" -#include -#include -#include -#include -#include "include/api/kernel.h" -#include "src/common/utils.h" -#include "src/extendrt/utils/tensor_default_impl.h" - -namespace mindspore::lite { -class TensorInfoImpl { - public: - TensorInfoImpl() {} - TensorInfoImpl(const std::string &name, mindspore::DataType type, const std::vector &shape, - mindspore::Format format, const void *data, size_t data_len, - const mindspore::tensor::TensorPtr &tensor_val) - : tensor_impl_(name, type, shape), tensor_val_(tensor_val) { - tensor_impl_.SetFormat(format); - auto is_const = (data != nullptr); - tensor_impl_.SetIsConst(is_const); - SetData(data, data_len); - } - - size_t item_size() const { return DataTypeSize(static_cast(tensor_impl_.DataType())); } - void SetData(const void *data, size_t data_len) { - if (data != nullptr && data_len != 0) { - if (tensor_impl_.DataSize() != data_len) { - MS_LOG(INFO) << "Tensor expect data size " << tensor_impl_.DataSize() << " != data len " << data_len - << ", shape: " << tensor_impl_.Shape() << ", dtype: " << tensor_impl_.DataType(); - } - tensor_impl_.SetData(const_cast(data), false); - } - } - TensorDefaultImpl tensor_impl_; - mindspore::tensor::TensorPtr tensor_val_ = nullptr; -}; - -TensorInfo::TensorInfo(const std::string &name, mindspore::DataType type, const std::vector &shape, - mindspore::Format format, const void *data, size_t data_len, - const mindspore::tensor::TensorPtr &tensor_val) { - impl_ = std::make_shared(name, type, shape, format, data, data_len, tensor_val); -} - -std::string TensorInfo::Name() const { - if (impl_ == nullptr) { - return ""; - } - return impl_->tensor_impl_.Name(); -} - -mindspore::DataType TensorInfo::DataType() const { - if (impl_ == nullptr) { - return mindspore::DataType::kTypeUnknown; - } - return impl_->tensor_impl_.DataType(); -} - -mindspore::Format TensorInfo::format() const { - if (impl_ == nullptr) { - return DEFAULT_FORMAT; - } - return impl_->tensor_impl_.Format(); -} - -const std::vector &TensorInfo::Shape() const { - static const std::vector empty_shape; - if (impl_ == nullptr) { - return empty_shape; - } - return impl_->tensor_impl_.Shape(); -} - -const void *TensorInfo::Data() const { - if (impl_ == nullptr) { - return nullptr; - } - return impl_->tensor_impl_.Data().get(); -} - -void *TensorInfo::MutableData() { - if (impl_ == nullptr) { - return nullptr; - } - return const_cast(impl_->tensor_impl_.MutableData()); -} - -size_t TensorInfo::DataSize() const { - if (impl_ == nullptr) { - return 0; - } - return ElementNum() * item_size(); -} - -bool TensorInfo::IsConst() const { - if (impl_ == nullptr) { - return 0; - } - return impl_->tensor_impl_.IsConst(); -} - -size_t TensorInfo::item_size() const { - if (impl_ == nullptr) { - return 0; - } - return impl_->item_size(); -} - -void TensorInfo::SetShape(const std::vector &shape) { - if (impl_ == nullptr) { - return; - } - impl_->tensor_impl_.SetShape(shape); -} - -void TensorInfo::SetDataType(const mindspore::DataType data_type) { - if (impl_ == nullptr) { - return; - } - impl_->tensor_impl_.SetDataType(data_type); -} - -void TensorInfo::SetData(const void *data, size_t data_len) { - if (impl_ == nullptr) { - return; - } - impl_->SetData(data, data_len); -} - -int64_t TensorInfo::ElementNum() const { - if (impl_ == nullptr) { - return 0; - } - return impl_->tensor_impl_.ElementNum(); -} - -TensorInfo &TensorInfo::operator=(const TensorInfo &other) { - impl_ = other.impl_; - return *this; -} - -bool TensorInfo::operator==(const TensorInfo &other) const { return impl_ == other.impl_; } - -bool TensorInfo::operator!=(const TensorInfo &other) const { return impl_ != other.impl_; } - -bool TensorInfo::operator<(const TensorInfo &other) const { return impl_ < other.impl_; } -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/tensor_info.h b/mindspore-lite/src/extendrt/delegate/tensorrt/tensor_info.h deleted file mode 100644 index 41b51a2c41c2607138a5eca4caee40a1af443d97..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/tensor_info.h +++ /dev/null @@ -1,62 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_TENSOR_INFO_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_TENSOR_INFO_H_ - -#include -#include -#include -#include -#include "include/api/kernel.h" -#include "ir/tensor.h" - -namespace mindspore::lite { -class TensorInfoImpl; -class TensorInfo { - public: - TensorInfo() = default; - TensorInfo(const std::string &name, mindspore::DataType type, const std::vector &shape, - mindspore::Format format, const void *data, size_t data_len, - const mindspore::tensor::TensorPtr &tensor_val); - ~TensorInfo() = default; - - std::string Name() const; - mindspore::DataType DataType() const; - mindspore::Format format() const; - const std::vector &Shape() const; - int64_t ElementNum() const; - const void *Data() const; - void *MutableData(); - size_t DataSize() const; - - bool IsConst() const; - - void SetShape(const std::vector &shape); - void SetDataType(const mindspore::DataType data_type); - void SetData(const void *data, size_t data_len); - - size_t item_size() const; - - TensorInfo &operator=(const TensorInfo &other); - bool operator==(const TensorInfo &other) const; - bool operator!=(const TensorInfo &other) const; - bool operator<(const TensorInfo &other) const; - - private: - std::shared_ptr impl_ = nullptr; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_TENSOR_INFO_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_allocator.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_allocator.cc deleted file mode 100755 index fb1cff07d7c6ca808aa97ecaa39ab74dd3522438..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_allocator.cc +++ /dev/null @@ -1,235 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/tensorrt_allocator.h" -#include -#include -#include "src/common/log_adapter.h" -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "src/extendrt/delegate/tensorrt/cuda_impl/cast.cuh" - -namespace mindspore::lite { -void *TensorRTAllocator::MallocDeviceMem(const TensorInfo &host_tensor, size_t size) { - return MallocDeviceMem(host_tensor.Name(), size, ConvertDataType(host_tensor.DataType())); -} - -void *TensorRTAllocator::MallocDeviceMem(const std::string &name, size_t size, nvinfer1::DataType data_type) { - if (cuda_tensor_map_.find(name) != cuda_tensor_map_.end() && size <= cuda_tensor_map_[name].size) { - MS_LOG(DEBUG) << "tensor :" << name << " has already in cuda Allocator pool."; - return cuda_tensor_map_[name].data; - } - void *device_ptr = nullptr; - auto cuda_ret = cudaMalloc(&device_ptr, size); - if (cuda_ret != cudaSuccess) { - MS_LOG(ERROR) << "Cuda Malloc failed for size:" << size; - return nullptr; - } - MS_LOG(INFO) << "cudaMalloc size: " << size << " for " << name; - if (cuda_tensor_map_[name].data != nullptr) { - cuda_ret = cudaFree(cuda_tensor_map_[name].data); - if (cuda_ret != cudaSuccess && cuda_ret != cudaErrorCudartUnloading) { - MS_LOG(ERROR) << "free old cuda device_ptr failed for " << cudaGetErrorName(cuda_ret); - cuda_ret = cudaFree(device_ptr); - if (cuda_ret != cudaSuccess) { - MS_LOG(ERROR) << "free new cuda device_ptr failed for " << cudaGetErrorName(cuda_ret); - return nullptr; - } - return nullptr; - } - } - cuda_tensor_map_[name].data = device_ptr; - cuda_tensor_map_[name].is_valid_mem = false; - cuda_tensor_map_[name].size = size; - return device_ptr; -} - -void TensorRTAllocator::MarkMemValid(const std::string &name, bool isValid) { - cuda_tensor_map_[name].is_valid_mem = isValid; - return; -} - -bool TensorRTAllocator::GetMemIsValid(const std::string &name) { - if (cuda_tensor_map_.find(name) == cuda_tensor_map_.end()) { - MS_LOG(WARNING) << "tensor :" << name << " not in cuda Allocator pool."; - return false; - } - return cuda_tensor_map_[name].is_valid_mem; -} - -void *TensorRTAllocator::GetDevicePtr(const std::string &tensor_name) { - if (tensor_name.empty()) { - return nullptr; - } - if (cuda_tensor_map_.find(tensor_name) == cuda_tensor_map_.end()) { - return nullptr; - } - return this->cuda_tensor_map_.find(tensor_name)->second.data; -} - -int TensorRTAllocator::SyncMemHostToDevice(const tensor::Tensor &host_tensor, const std::string &device_tensor_name, - bool sync, size_t size) { - size = (size == 0) ? host_tensor.Size() : size; - return SyncMemInHostAndDevice(const_cast(host_tensor.data_c()), device_tensor_name, size, true, sync); -} - -int TensorRTAllocator::SyncMemDeviceToHost(tensor::Tensor *host_tensor, const std::string &device_tensor_name, - bool sync) { - if (host_tensor == NULL) { - MS_LOG(ERROR) << "host tensor is null."; - return RET_ERROR; - } -#if TRT_VERSION_GE(7, 2) - if (host_tensor->data_type() == TypeId::kNumberTypeBool) { - CudaTensorParam ¤t_cuda_tensor = cuda_tensor_map_.find(device_tensor_name)->second; - auto device_ptr = current_cuda_tensor.data; - if (device_ptr == nullptr) { - MS_LOG(ERROR) << "device_ptr is null for " << device_tensor_name; - return RET_ERROR; - } - int *host_ptr = reinterpret_cast(malloc(host_tensor->DataSize() * sizeof(int))); - cudaError_t cuda_ret; - if (sync) { - cuda_ret = cudaMemcpy(host_ptr, device_ptr, host_tensor->DataSize() * sizeof(int), cudaMemcpyDeviceToHost); - } else { - cuda_ret = - cudaMemcpyAsync(host_ptr, device_ptr, host_tensor->DataSize() * sizeof(int), cudaMemcpyDeviceToHost, stream_); - } - if (cuda_ret != cudaSuccess) { - MS_LOG(ERROR) << "copy mem failed,ret " << cudaGetErrorName(cuda_ret); - return RET_ERROR; - } - bool *host_tensor_ptr = static_cast(host_tensor->data_c()); - for (size_t i = 0; i != host_tensor->Size(); ++i) { - host_tensor_ptr[i] = (host_ptr[i] != 0); - } - free(host_ptr); - return RET_OK; - } -#endif - return SyncMemDeviceToHost(host_tensor->data_c(), host_tensor->Size(), device_tensor_name, sync); -} - -int TensorRTAllocator::SyncMemDeviceToHost(void *dst_data, size_t data_size, const std::string &device_tensor_name, - bool sync) { - if (dst_data == nullptr) { - MS_LOG(ERROR) << " dst host data cannot be nullptr."; - return RET_ERROR; - } - auto it = cuda_tensor_map_.find(device_tensor_name); - if (it == cuda_tensor_map_.end()) { - MS_LOG(ERROR) << " cannot find device address " << device_tensor_name; - return RET_ERROR; - } - CudaTensorParam ¤t_cuda_tensor = it->second; - // is memcpy from device to host, the host mem is valid, change tag for mem pool. - current_cuda_tensor.is_valid_mem = true; - auto device_ptr = current_cuda_tensor.data; - if (device_ptr == nullptr) { - MS_LOG(ERROR) << "device_ptr is null for " << device_tensor_name; - return RET_ERROR; - } - cudaError_t cuda_ret; - if (sync) - cuda_ret = cudaMemcpy(dst_data, device_ptr, data_size, cudaMemcpyDeviceToHost); - else - cuda_ret = cudaMemcpyAsync(dst_data, device_ptr, data_size, cudaMemcpyDeviceToHost, stream_); - if (cuda_ret != cudaSuccess) { - MS_LOG(ERROR) << "copy mem failed,ret " << cudaGetErrorName(cuda_ret); - return RET_ERROR; - } - MS_LOG(INFO) << "cuda memcpy success for " << device_tensor_name; - return RET_OK; -} - -int TensorRTAllocator::SyncMemInHostAndDevice(tensor::Tensor *host_tensor, const std::string &device_tensor_name, - bool is_host2device, bool sync) { - if (host_tensor == NULL) { - MS_LOG(ERROR) << "host tensor is null."; - return RET_ERROR; - } -#if TRT_VERSION_GE(7, 2) - if (host_tensor->data_type() == TypeId::kNumberTypeBool && !is_host2device) { - CudaTensorParam ¤t_cuda_tensor = cuda_tensor_map_.find(device_tensor_name)->second; - auto device_ptr = current_cuda_tensor.data; - if (device_ptr == nullptr) { - MS_LOG(ERROR) << "device_ptr is null for " << device_tensor_name; - return RET_ERROR; - } - int *host_ptr = reinterpret_cast(malloc(host_tensor->DataSize())); - auto cuda_ret = cudaMemcpy(host_ptr, device_ptr, host_tensor->DataSize(), cudaMemcpyDeviceToHost); - if (cuda_ret != cudaSuccess) { - MS_LOG(ERROR) << "copy mem failed,ret " << cudaGetErrorName(cuda_ret); - return RET_ERROR; - } - bool *host_tensor_ptr = static_cast(host_tensor->data_c()); - for (size_t i = 0; i != host_tensor->Size(); ++i) { - host_tensor_ptr[i] = (host_ptr[i] != 0); - } - free(host_ptr); - return RET_OK; - } -#endif - return SyncMemInHostAndDevice(host_tensor->data_c(), device_tensor_name, host_tensor->Size(), is_host2device, sync); -} - -int TensorRTAllocator::SyncMemInHostAndDevice(void *host_data, const std::string &device_tensor_name, size_t data_size, - bool is_host2device, bool sync) { - if (host_data == nullptr || cuda_tensor_map_.find(device_tensor_name) == cuda_tensor_map_.end()) { - MS_LOG(ERROR) << " host or device ptr is null."; - return RET_ERROR; - } - CudaTensorParam ¤t_cuda_tensor = cuda_tensor_map_.find(device_tensor_name)->second; - // is memcpy from device to host, the host mem is valid, change tag for mem pool. - current_cuda_tensor.is_valid_mem = is_host2device ? current_cuda_tensor.is_valid_mem : true; - if (is_host2device && current_cuda_tensor.is_valid_mem) { - MS_LOG(DEBUG) << "no need memcpy for: " << device_tensor_name; - return RET_OK; - } - auto device_ptr = current_cuda_tensor.data; - if (device_ptr == nullptr) { - MS_LOG(ERROR) << "device_ptr is null for " << device_tensor_name; - return RET_ERROR; - } - - void *src_ptr = is_host2device ? host_data : device_ptr; - void *dst_ptr = is_host2device ? device_ptr : host_data; - cudaMemcpyKind kind = is_host2device ? cudaMemcpyHostToDevice : cudaMemcpyDeviceToHost; - cudaError_t cuda_ret; - if (sync) - cuda_ret = cudaMemcpy(dst_ptr, src_ptr, data_size, kind); - else - cuda_ret = cudaMemcpyAsync(dst_ptr, src_ptr, data_size, kind, stream_); - if (cuda_ret != cudaSuccess) { - MS_LOG(ERROR) << "copy mem failed,ret " << cudaGetErrorName(cuda_ret); - return RET_ERROR; - } - MS_LOG(INFO) << "cuda memcpy success for " << device_tensor_name; - return RET_OK; -} - -int TensorRTAllocator::ClearDeviceMem() { - for (auto &iter : cuda_tensor_map_) { - auto cuda_ret = cudaFree(iter.second.data); - if (cuda_ret != cudaSuccess && cuda_ret != cudaErrorCudartUnloading) { - MS_LOG(WARNING) << "free cuda failed for " << cudaGetErrorName(cuda_ret); - } - iter.second.data = nullptr; - iter.second.is_valid_mem = false; - } - return RET_OK; -} -std::map TensorRTAllocator::GetAllDevicePtr() { return this->cuda_tensor_map_; } -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_allocator.h b/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_allocator.h deleted file mode 100644 index deb43c078f7a632d5939ed2949e67501fd4bc555..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_allocator.h +++ /dev/null @@ -1,70 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_TENSORRT_ALLOCATOR_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_TENSORRT_ALLOCATOR_H_ -#include "src/extendrt/delegate/tensorrt/tensorrt_allocator.h" -#include -#include -#include -#include "include/api/types.h" -#include "ir/tensor.h" -#include "src/extendrt/delegate/tensorrt/tensor_info.h" - -namespace mindspore::lite { -struct CudaTensorParam { - void *data = nullptr; - bool is_valid_mem = false; - size_t size = 0; -}; -class TensorRTAllocator { - public: - TensorRTAllocator() = default; - - ~TensorRTAllocator() = default; - - void *MallocDeviceMem(const TensorInfo &host_tensor, size_t size); - - void *MallocDeviceMem(const std::string &name, size_t size, nvinfer1::DataType data_type); - - void *GetDevicePtr(const std::string &tensor_name); - - void SetCudaStream(cudaStream_t stream) { stream_ = stream; } - - std::map GetAllDevicePtr(); - - int SyncMemInHostAndDevice(tensor::Tensor *host_tensor, const std::string &device_tensor_name, bool is_host2device, - bool sync = true); - - int SyncMemInHostAndDevice(void *host_data, const std::string &device_tensor_name, size_t data_size, - bool is_host2device, bool sync = true); - int SyncMemHostToDevice(const tensor::Tensor &host_tensor, const std::string &device_tensor_name, bool sync = true, - size_t size = 0); - int SyncMemDeviceToHost(tensor::Tensor *host_tensor, const std::string &device_tensor_name, bool sync = true); - int SyncMemDeviceToHost(void *dst_data, size_t data_size, const std::string &device_tensor_name, bool sync = true); - - int ClearDeviceMem(); - - void MarkMemValid(const std::string &name, bool isValid); - - bool GetMemIsValid(const std::string &name); - - private: - std::map cuda_tensor_map_; - cudaStream_t stream_; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_TENSORRT_ALLOCATOR_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_context.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_context.cc deleted file mode 100644 index 018de7e4f9648586672d9c75e999739b256ff72d..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_context.cc +++ /dev/null @@ -1,125 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/tensorrt_context.h" -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" - -namespace mindspore::lite { -TensorRTContext::~TensorRTContext() { - if (network_ != nullptr) { - network_->destroy(); - network_ = nullptr; - } - for (auto ptr : owner_memorys_) { - free(ptr); - } -} - -bool TensorRTContext::Init() { - network_ = runtime_->GetBuilder()->createNetworkV2( - 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH)); - if (network_ == nullptr) { - MS_LOG(ERROR) << "New network init failed."; - return false; - } - return true; -} - -void TensorRTContext::SetRuntime(TensorRTRuntime *runtime) { runtime_ = runtime; } - -nvinfer1::INetworkDefinition *TensorRTContext::network() { return network_; } - -void TensorRTContext::RegisterLayer(nvinfer1::ILayer *layer, const std::string &basename) { - if (layer == nullptr) { - MS_LOG(ERROR) << "Register null layer!"; - return; - } - MS_LOG(DEBUG) << "ms_layer " << basename << " register"; - layer->setName((basename + "_" + std::to_string(counter_++)).c_str()); -} - -void TensorRTContext::RegisterTensor(ITensorHelper tensor, const std::string &basename) { - MS_LOG(DEBUG) << GetTensorFormat(tensor); - std::string trt_name = basename + "_" + std::to_string(counter_++); - tensor.trt_tensor_->setName(trt_name.c_str()); - MS_LOG(DEBUG) << "ms_tensor " << basename << " register to " << trt_name; - ms_name2trt_tensor_[basename] = tensor; -} - -void TensorRTContext::RegisterTensorWithSameName(ITensorHelper tensor, const std::string &basename) { - std::string trt_name = basename; - tensor.trt_tensor_->setName(trt_name.c_str()); - MS_LOG(DEBUG) << "ms_tensor " << basename << " register to " << trt_name; - ms_name2trt_tensor_[basename] = tensor; -} - -bool TensorRTContext::HasTensor(const std::string &name) const { - return ms_name2trt_tensor_.find(name) != ms_name2trt_tensor_.end(); -} - -ITensorHelper TensorRTContext::MsName2Tensor(const std::string &ms_name) { - if (ms_name2trt_tensor_.find(ms_name) != ms_name2trt_tensor_.end()) { - return ms_name2trt_tensor_[ms_name]; - } - MS_LOG(WARNING) << "Get Tensorrt tensor by ms_tensor: " << ms_name << " fail!"; - return {}; -} - -template -nvinfer1::ITensor *TensorRTContext::ConvertTo0DTensor(T value) { - void *ptr = malloc(sizeof(T)); - memcpy(ptr, reinterpret_cast(&value), sizeof(T)); - owner_memorys_.push_back(ptr); - - nvinfer1::Weights weights{GetNvinferDataType(), ptr, 1}; - nvinfer1::Dims dims{}; - nvinfer1::IConstantLayer *constant_tensor = network()->addConstant(dims, weights); - if (constant_tensor == nullptr) { - MS_LOG(ERROR) << "create constant_tensor failed."; - return nullptr; - } - return constant_tensor->getOutput(0); -} - -template -nvinfer1::ITensor *TensorRTContext::ConvertTo1DTensor(T value) { - return ConvertTo1DTensor(std::vector{value}); -} - -template -nvinfer1::ITensor *TensorRTContext::ConvertTo1DTensor(const std::vector &values) { - void *ptr = malloc(values.size() * sizeof(T)); - const T *begin = &values[0]; - memcpy(ptr, reinterpret_cast(begin), values.size() * sizeof(T)); - owner_memorys_.push_back(ptr); - - nvinfer1::Weights weights{GetNvinferDataType(), ptr, static_cast(values.size())}; - nvinfer1::Dims dims{1, {static_cast(values.size())}}; - nvinfer1::IConstantLayer *constant_tensor = network()->addConstant(dims, weights); - if (constant_tensor == nullptr) { - MS_LOG(ERROR) << "create constant_tensor failed."; - return nullptr; - } - return constant_tensor->getOutput(0); -} - -template nvinfer1::ITensor *TensorRTContext::ConvertTo0DTensor(int value); -template nvinfer1::ITensor *TensorRTContext::ConvertTo0DTensor(float value); -template nvinfer1::ITensor *TensorRTContext::ConvertTo1DTensor(int value); -template nvinfer1::ITensor *TensorRTContext::ConvertTo1DTensor(float value); -template nvinfer1::ITensor *TensorRTContext::ConvertTo1DTensor(const std::vector &values); -template nvinfer1::ITensor *TensorRTContext::ConvertTo1DTensor(const std::vector &values); -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_context.h b/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_context.h deleted file mode 100644 index fa7f6ca368b1cd261540ebf03893987b9c7affe0..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_context.h +++ /dev/null @@ -1,61 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_TENSORRT_CONTEXT_H_ -#define MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_TENSORRT_CONTEXT_H_ - -#include -#include -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/tensorrt_runtime.h" - -namespace mindspore::lite { -struct ITensorHelper { - nvinfer1::ITensor *trt_tensor_{nullptr}; - mindspore::Format format_{Format::NCHW}; - bool same_format_{true}; - bool is_tensor{true}; -}; -class TensorRTContext { - public: - TensorRTContext() = default; - ~TensorRTContext(); - bool Init(); - void SetRuntime(TensorRTRuntime *runtime); - nvinfer1::INetworkDefinition *network(); - void RegisterLayer(nvinfer1::ILayer *layer, const std::string &basename); - void RegisterTensor(ITensorHelper tensor, const std::string &basename); - void RegisterTensorWithSameName(ITensorHelper tensor, const std::string &basename); - bool HasTensor(const std::string &name) const; - ITensorHelper MsName2Tensor(const std::string &ms_name); - - template - nvinfer1::ITensor *ConvertTo0DTensor(T value); - template - nvinfer1::ITensor *ConvertTo1DTensor(T value); - template - nvinfer1::ITensor *ConvertTo1DTensor(const std::vector &values); - - private: - int counter_{0}; - nvinfer1::INetworkDefinition *network_{nullptr}; - std::unordered_map ms_name2trt_tensor_; - TensorRTRuntime *runtime_{nullptr}; - std::vector owner_memorys_; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_DELEGATE_TENSORRT_TENSORRT_CONTEXT_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_graph_executor.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_graph_executor.cc deleted file mode 100644 index 83c8414a63c7018df991968f940fe4d59c65de3e..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_graph_executor.cc +++ /dev/null @@ -1,671 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/tensorrt_graph_executor.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "mindspore/ops/op_def/framework_ops.h" -#include "src/extendrt/delegate/delegate_utils.h" -#include "common/common_utils.h" -#include "ccsrc/include/backend/optimizer/helper.h" -#include "ccsrc/include/common/utils/convert_utils.h" -#include "common/config_infos.h" -#include "tools/optimizer/common/gllo_utils.h" -#include "src/extendrt/utils/func_graph_utils.h" -#include "src/extendrt/delegate/tensorrt/optimizer/tensorrt_optimizer.h" -#include "infer/custom.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_d.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" - -namespace mindspore::lite { -namespace { -const char tensorrt_provider[] = "tensorrt"; - -struct NodeWithOutputIndex { - NodeWithOutputIndex() = default; - NodeWithOutputIndex(session::KernelWithIndex kernel_index, TensorInfo tensor_info) - : kernel_index(kernel_index), tensor_info(tensor_info) {} - - session::KernelWithIndex kernel_index; - TensorInfo tensor_info; -}; - -ValuePtr GetNodeValuePtr(AnfNodePtr input_node) { - if (input_node == nullptr) { - return nullptr; - } - if (IsPrimitiveCNode(input_node, prim::kPrimDepend)) { - input_node = AnfUtils::VisitKernel(input_node, 0).first; - } - ValuePtr value = nullptr; - if (input_node->isa() && !HasAbstractMonad(input_node)) { - auto value_node = input_node->cast(); - if (value_node) { - value = value_node->value(); - } - } else if (input_node->isa()) { - auto parameter = input_node->cast(); - if (parameter->has_default()) { - value = parameter->default_param(); - } - } - return value; -} - -tensor::TensorPtr GetConstNodeValue(AnfNodePtr input_node) { - ValuePtr value = GetNodeValuePtr(input_node); - if (value == nullptr) { - return nullptr; - } - if (value->isa()) { - auto tensor = value->cast(); - if (tensor == nullptr || tensor->data().const_data() == nullptr) { - return nullptr; - } - return tensor; - } - if (value->isa()) { - return ScalarToTensor(value->cast()); - } - if (value->isa()) { - return opt::CreateTupleTensor(value->cast()); - } - if (value->isa()) { - auto type_ptr = value->cast(); - if (type_ptr == nullptr) { - return nullptr; - } - return std::make_shared(static_cast(type_ptr->type_id()), type_ptr->type()); - } - MS_LOG(WARNING) << "Unexpected value type " << value->type_name() << " for " << input_node->fullname_with_scope(); - return nullptr; -} - -TensorInfo KernelTensorAsTensorInfo(const session::KernelWithIndex &tensor_id) { - auto prev_node = tensor_id.first; - auto tensor_val = GetConstNodeValue(prev_node); - - constexpr auto tensorrt_format = mindspore::Format::NCHW; - auto name = FuncGraphUtils::GetTensorName(tensor_id); - auto shape = FuncGraphUtils::GetTensorShape(tensor_id); - auto datatype = FuncGraphUtils::GetTensorDataType(tensor_id); - auto format = tensorrt_format; - const void *data = nullptr; - size_t data_len = 0; - if (tensor_val) { - data = tensor_val->data_c(); - data_len = tensor_val->Size(); - shape = tensor_val->shape_c(); - } - TensorInfo tensor_info(name, datatype, shape, format, data, data_len, tensor_val); - return tensor_info; -} - -Status GetAbstractArgsFromCNode(const CNodePtr &cnode, std::vector *tensor_info_list_ptr, - BaseOperatorPtr *base_operator, std::vector *input_tensors, - std::vector *output_tensors) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(tensor_info_list_ptr); - auto &tensor_info_list = *tensor_info_list_ptr; - auto prim = GetValueNode(cnode->input(0)); - MS_EXCEPTION_IF_NULL(prim); - auto kernel_name = prim->name(); - ops::PrimitiveCPtr primc_ptr = nullptr; - static auto primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap(); - if (primc_fns.find(kernel_name) != primc_fns.end()) { - primc_ptr = primc_fns[kernel_name](); - (void)primc_ptr->SetAttrs(prim->attrs()); - } - if (primc_ptr == nullptr) { - MS_LOG(ERROR) << "OpPrimCRegister can not find " << kernel_name; - return mindspore::kLiteError; - } - - *base_operator = nullptr; - static auto operator_fns = ops::OperatorRegister::GetInstance().GetOperatorMap(); - if (operator_fns.find(kernel_name) != operator_fns.end()) { - *base_operator = operator_fns[kernel_name](primc_ptr); - } - MS_EXCEPTION_IF_NULL(*base_operator); - // Makeup input tensors. - input_tensors->clear(); - auto input_nodes = FuncGraphUtils::GetNodeInputs(cnode); - for (auto &tensor_id : input_nodes) { - auto it = std::find_if(tensor_info_list.begin(), tensor_info_list.end(), - [&tensor_id](const NodeWithOutputIndex &index) { return index.kernel_index == tensor_id; }); - if (it != tensor_info_list.end()) { - input_tensors->push_back(it->tensor_info); - } else { - auto tensor_info = KernelTensorAsTensorInfo(tensor_id); - input_tensors->push_back(tensor_info); - tensor_info_list.push_back(NodeWithOutputIndex(tensor_id, tensor_info)); - } - } - // Makeup output tensors. - output_tensors->clear(); - auto output_num = AnfUtils::GetOutputTensorNum(cnode); - for (size_t output_idx = 0; output_idx < output_num; ++output_idx) { - session::KernelWithIndex tensor_id = {cnode, output_idx}; - auto it = std::find_if(tensor_info_list.begin(), tensor_info_list.end(), - [&tensor_id](const NodeWithOutputIndex &index) { return index.kernel_index == tensor_id; }); - if (it != tensor_info_list.end()) { - output_tensors->push_back(it->tensor_info); - } else { - auto tensor_info = KernelTensorAsTensorInfo(tensor_id); - output_tensors->push_back(tensor_info); - tensor_info_list.push_back(NodeWithOutputIndex(tensor_id, tensor_info)); - } - } - return kSuccess; -} - -Status GetModelInputsInfo(const FuncGraphPtr &func_graph, std::vector *tensor_info_list_ptr, - std::vector *input_tensors) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(tensor_info_list_ptr); - MS_EXCEPTION_IF_NULL(input_tensors); - auto &tensor_info_list = *tensor_info_list_ptr; - std::vector inputs; - FuncGraphUtils::GetFuncGraphInputs(func_graph, &inputs); - // find parameters of graph inputs - for (auto &tensor_id : inputs) { - auto it = std::find_if(tensor_info_list.begin(), tensor_info_list.end(), - [&tensor_id](const NodeWithOutputIndex &index) { return index.kernel_index == tensor_id; }); - if (it != tensor_info_list.end()) { - input_tensors->push_back(it->tensor_info); - } else { - auto tensor_info = KernelTensorAsTensorInfo(tensor_id); - input_tensors->push_back(tensor_info); - tensor_info_list.push_back(NodeWithOutputIndex(tensor_id, tensor_info)); - } - } - return kSuccess; -} - -Status GetModelOutputsInfo(const FuncGraphPtr &func_graph, std::vector *tensor_info_list_ptr, - std::vector *output_tensors) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(tensor_info_list_ptr); - auto &tensor_info_list = *tensor_info_list_ptr; - std::vector outputs; - FuncGraphUtils::GetFuncGraphOutputs(func_graph, &outputs); - for (auto &tensor_id : outputs) { - auto it = std::find_if(tensor_info_list.begin(), tensor_info_list.end(), - [&tensor_id](const NodeWithOutputIndex &index) { return index.kernel_index == tensor_id; }); - if (it != tensor_info_list.end()) { - output_tensors->push_back(it->tensor_info); - } else { - auto tensor_info = KernelTensorAsTensorInfo(tensor_id); - output_tensors->push_back(tensor_info); - tensor_info_list.push_back(NodeWithOutputIndex(tensor_id, tensor_info)); - } - } - return kSuccess; -} -} // namespace -TensorRTExecutor::TensorRTExecutor(const std::shared_ptr &context, const ConfigInfos &config_infos) - : context_(context), config_infos_(config_infos) {} - -TensorRTExecutor::~TensorRTExecutor() { - // delete tensorrt_graph_ before delete runtime - tensorrt_graph_.reset(); - if (runtime_ != nullptr) { - delete runtime_; - } - if (stream_ != nullptr) { - cudaStreamDestroy(stream_); - } - if (cublas_handle_ != nullptr) { - cublasDestroy(cublas_handle_); - cublas_handle_ = nullptr; - } - if (cublaslt_handle_ != nullptr) { - cublasLtDestroy(cublaslt_handle_); - cublaslt_handle_ = nullptr; - } -} -bool IsHardwareSupport() { - int driver_version = 0; - int ret = cudaDriverGetVersion(&driver_version); - if (ret != cudaSuccess || driver_version == 0) { - MS_LOG(WARNING) << "No nvidia GPU driver."; - return false; - } - return true; -} - -bool TensorRTExecutor::Init() { - if (!IsHardwareSupport()) { - return false; - } - if (context_ == nullptr) { - MS_LOG(ERROR) << "Context cannot be nullptr"; - return false; - } - - std::vector> device_list = context_->MutableDeviceInfo(); - auto iter = std::find_if(device_list.begin(), device_list.end(), [](std::shared_ptr device) { - return device->GetDeviceType() == DeviceType::kGPU; - }); - if (iter == device_list.end()) { - MS_LOG(ERROR) << "no gpu device info found for TensorRT."; - return false; - } - auto gpu_info = (*iter)->Cast(); - if (gpu_info == nullptr) { - MS_LOG(ERROR) << "no gpu device info found for TensorRT."; - return false; - } - device_info_ = gpu_info; - int ret = lite::SetCudaDevice(device_info_); - if (ret != RET_OK) { - return false; - } - if (runtime_ == nullptr) { - runtime_ = new (std::nothrow) TensorRTRuntime(); - if (runtime_ == nullptr) { - MS_LOG(ERROR) << "create TensorRTRuntime failed."; - return false; - } - } - if (runtime_->Init() != RET_OK) { - MS_LOG(ERROR) << "TensorRTRuntime init failed."; - return false; - } - runtime_->SetDeviceID(device_info_->GetDeviceID()); - - auto cuda_ret = cudaStreamCreate(&stream_); - if (cuda_ret != cudaSuccess) { - MS_LOG(ERROR) << "Cuda create stream failed"; - return false; - } - - auto cublas_ret = cublasCreate(&cublas_handle_); - if (cublas_ret != CUBLAS_STATUS_SUCCESS) { - MS_LOG(ERROR) << "Cuda create cublas handle failed"; - return false; - } - - auto cublaslt_ret = cublasLtCreate(&cublaslt_handle_); - if (cublaslt_ret != CUBLAS_STATUS_SUCCESS) { - MS_LOG(ERROR) << "Cuda create cublaslt handle failed"; - return false; - } - - ret = ParseOptimizationProfile(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Parse input ranges failed."; - return false; - } - ret = ParseTransformerProfile(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Parse transformer failed."; - return false; - } - return true; -} - -int TensorRTExecutor::ParseOptimizationProfile() { - auto gpu_context_it = config_infos_.find(kGPUContextSection); - if (gpu_context_it == config_infos_.end()) { - MS_LOG(INFO) << "do not have input ranges config."; - return RET_OK; - } - auto &gpu_context = gpu_context_it->second; - ProfileConfigs profile_configs; - if (!ProfileParser::Parse(gpu_context, true, &profile_configs)) { - MS_LOG(WARNING) << "Failed to parse profile info from '" << kGPUContextSection << "'"; - return RET_FAILED; - } - trt_profile_configs_ = profile_configs; - auto precision_mode = ProfileParser::GetOption(gpu_context, lite::kPrecisionModeKey, ""); - if (precision_mode.empty()) { - device_info_->SetPrecisionMode("enforce_fp32"); - } else { - device_info_->SetPrecisionMode(precision_mode); - } - serialize_path_ = ProfileParser::GetOption(gpu_context, lite::kMSCacheSerializePathKey); - return RET_OK; -} - -int TensorRTExecutor::ParseTransformerProfile() { - auto transformer_context_it = config_infos_.find(kTransformerSection); - if (transformer_context_it == config_infos_.end()) { - MS_LOG(INFO) << "do not have input ranges config."; - return RET_OK; - } - - auto &transformer_context = transformer_context_it->second; - int encoder_input = -1; - int decoder_input = -1; - std::string optimize_transformer = ""; - try { - encoder_input = std::stoi(ProfileParser::GetOption(transformer_context, lite::kEncoderInputKey, "-1").c_str()); - } catch (...) { - MS_LOG(ERROR) << "The value of encoder_input must be int."; - } - runtime_->SetTransformerEncoderInputIdx(encoder_input); - try { - decoder_input = std::stoi(ProfileParser::GetOption(transformer_context, lite::kDecoderInputKey, "-1").c_str()); - } catch (...) { - MS_LOG(ERROR) << "The value of decoder_input must be int."; - } - runtime_->SetTransformerDecoderInputIdx(decoder_input); - auto is_ffn_f16 = ProfileParser::GetOption(transformer_context, lite::kFfnFp16Key, "true"); - if (is_ffn_f16 == "true") { - runtime_->SetTransformerFfnFp16(true); - } else if (is_ffn_f16 == "false") { - runtime_->SetTransformerFfnFp16(false); - } else { - MS_LOG(ERROR) << "The value of ffn_f16 must be true or false."; - return RET_ERROR; - } - optimize_transformer = ProfileParser::GetOption(transformer_context, lite::kOptimizeTransformer, ""); - runtime_->SetTransformerOptimize(optimize_transformer); - return RET_OK; -} - -int TensorRTExecutor::ParseDumpOptions(const std::map &gpu_context) { - auto dump_ops_str = ProfileParser::GetOption(gpu_context, lite::kDumpOpsKey, ""); - if (!dump_ops_str.empty()) { - dump_ops_ = lite::StrSplit(dump_ops_str, ";"); - dump_dir_ = ProfileParser::GetOption(gpu_context, lite::kDumpDirKey, ""); - if (dump_dir_.empty()) { - dump_dir_ = "."; - } - } - return RET_OK; -} - -Status TensorRTExecutor::BuildSubGraph(const FuncGraphPtr &func_graph) { - std::vector tensorrt_ops; - int tensorrt_subgraph_index = 0; - - auto nodes = func_graph->TopoSort(func_graph->get_return()); - if (nodes.empty()) { - MS_LOG(ERROR) << "There are no nodes in the graph"; - return mindspore::kLiteNullptr; - } - std::vector tensor_info_list; - auto status = GetModelInputsInfo(func_graph, &tensor_info_list, &inputs_); - if (status != kSuccess) { - return status; - } - for (const auto &node : nodes) { - auto cnode = node->cast(); - if (!cnode || !AnfUtils::IsRealKernel(cnode)) { - continue; - } - auto node_name = node->fullname_with_scope(); - BaseOperatorPtr base_operator = nullptr; - std::vector input_tensors; - std::vector output_tensors; - status = GetAbstractArgsFromCNode(cnode, &tensor_info_list, &base_operator, &input_tensors, &output_tensors); - if (status != kSuccess || base_operator == nullptr) { - MS_LOG(ERROR) << "Failed to get operator of node " << node_name; - return mindspore::kLiteError; - } - auto tensorrt_op = FindTensorRTOp(cnode, base_operator, input_tensors, output_tensors); - if (tensorrt_op == nullptr) { - MS_LOG(ERROR) << "FindTensorRTOp failed " << node_name; - return mindspore::kLiteError; - } - tensorrt_op->SetRuntime(this->runtime_); - tensorrt_ops.push_back(tensorrt_op); - if (!dump_ops_.empty() && std::find(dump_ops_.begin(), dump_ops_.end(), node_name) != dump_ops_.end()) { - std::copy(output_tensors.begin(), output_tensors.end(), std::back_inserter(dump_outputs_)); - } - } - status = GetModelOutputsInfo(func_graph, &tensor_info_list, &outputs_); - if (status != kSuccess) { - return status; - } - for (auto &out_tensor_info : outputs_) { - if (out_tensor_info.DataType() == DataType::kNumberTypeFloat16) { - MS_LOG(INFO) << "output " << out_tensor_info.Name() << " is Float16, set to Float32"; - out_tensor_info.SetDataType(DataType::kNumberTypeFloat32); - } - } - std::vector trt_outputs = outputs_; - std::copy(dump_outputs_.begin(), dump_outputs_.end(), std::back_inserter(trt_outputs)); - tensorrt_graph_ = CreateTensorRTGraph(tensorrt_ops, func_graph, tensorrt_subgraph_index, inputs_, trt_outputs); - if (!tensorrt_graph_) { - MS_LOG(ERROR) << "Create tensorrt graph failed"; - return mindspore::kLiteError; - } - return mindspore::kSuccess; -} - -TensorRTOp *TensorRTExecutor::FindTensorRTOp(const CNodePtr &cnode, const BaseOperatorPtr &base_operator, - const std::vector &input_tensors, - const std::vector &output_tensors) { - auto name = cnode->fullname_with_scope(); - auto node_type = base_operator->name(); - auto &plugin_factory = TensorRTRegistrationFactory::Get(); - if (node_type == ops::kNameCustom) { - if (common::AnfAlgo::HasNodeAttr("unique_name", cnode)) { - node_type = common::AnfAlgo::GetNodeAttr(cnode, "unique_name"); - } - } - if (plugin_factory.HasKey(node_type)) { - TensorRTOp *tensorrt_op = plugin_factory.GetCreator(node_type)(base_operator, input_tensors, output_tensors, name); - if (tensorrt_op == nullptr) { - return nullptr; - } - if (!support_resize_) { - return tensorrt_op; - } - support_resize_ = tensorrt_op->GetDynamicShapeParams().support_dynamic_ ? support_resize_ : false; - if (!tensorrt_op->GetDynamicShapeParams().support_dynamic_) { - MS_LOG(WARNING) << "TensorRT subgraph don't support dynamic shape resize, because of op " << name; - support_hw_resize_ = false; - return tensorrt_op; - } - if (!support_hw_resize_) { - return tensorrt_op; - } - support_hw_resize_ = tensorrt_op->GetDynamicShapeParams().support_hw_dynamic_ ? support_hw_resize_ : false; - if (!tensorrt_op->GetDynamicShapeParams().support_hw_dynamic_) { - MS_LOG(WARNING) << "TensorRT subgraph don't support dynamic hw dims resize, because of op " << name; - } - return tensorrt_op; - } else { - MS_LOG(WARNING) << "Unsupported op type for TensorRT. kernel name:" << name << " type:" << node_type; - return nullptr; - } -} - -std::shared_ptr TensorRTExecutor::CreateTensorRTGraph(const std::vector &ops, - const FuncGraphPtr &graph, int index, - const std::vector &inputs, - const std::vector &outputs) { - if (!trt_profile_configs_.input_infos.empty()) { - std::vector input_names; - std::transform(inputs.begin(), inputs.end(), std::back_inserter(input_names), - [](auto &item) { return item.Name(); }); - if (!ProfileParser::ReorderByInputNames(input_names, &trt_profile_configs_)) { - MS_LOG(ERROR) << "Reorder profile by input names failed, input names: " << input_names; - return nullptr; - } - } - - auto tensorrt_graph = std::make_shared(ops, inputs, outputs, context_.get(), device_info_, runtime_, - support_resize_, support_hw_resize_, trt_profile_configs_); - if (tensorrt_graph == nullptr) { - MS_LOG(ERROR) << "new tensorrt_graph failed."; - return nullptr; - } - if (serialize_path_.size() > 0) { - tensorrt_graph->SetSerializePath(serialize_path_ + "_trt" + std::to_string(GetRankID()) + ".bin_" + - std::to_string(index)); - } - // 1. For every op, find pre and next ops - FindPreNextOps(ops); - - // 2. Init TensorRT SubGraph. - auto ret = tensorrt_graph->Init(stream_, cublas_handle_, cublaslt_handle_); - if (ret != RET_OK) { - MS_LOG(ERROR) << "TensorRTGraph init failed."; - return nullptr; - } - - // 3. Build TensorRT Model. - ret = tensorrt_graph->BuildTensorRTGraph(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "TensorRTGraph build failed."; - return nullptr; - } - ret = tensorrt_graph->Prepare(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "TensorRTGraph prepare failed."; - return nullptr; - } - return tensorrt_graph; -} - -bool TensorRTExecutor::CompileGraph(const FuncGraphPtr &graph, const std::map &compile_options, - uint32_t *graph_id) { - if (graph == nullptr || graph_id == nullptr) { - MS_LOG(ERROR) << "Input param graph or graph id is nullptr"; - return false; - } - *graph_id = 0; - TensorRtOptimizer optimizer; - optimizer.RunOptimizer(graph); - - int ret = lite::SetCudaDevice(device_info_); - if (ret != RET_OK) { - return false; - } - Status build_ret = BuildSubGraph(graph); - if (build_ret != kSuccess) { - MS_LOG(INFO) << "BuildSubGraph failed"; - return false; - } - return true; -} - -bool TensorRTExecutor::RunGraph(uint32_t, const std::vector &inputs, - std::vector *outputs, const std::map &compile_options) { - if (inputs.size() != inputs_.size()) { - MS_LOG(ERROR) << "Graph inputs size " << inputs_.size() << " != execute outputs size " << inputs.size(); - return false; - } - if (tensorrt_graph_ == nullptr) { - MS_LOG(ERROR) << "TensorRT subgraph is nullptr."; - return false; - } - if (dump_outputs_.empty()) { - if (!outputs->empty() && outputs_.size() != outputs->size()) { - MS_LOG(ERROR) << "Graph outputs size " << inputs_.size() << " != expected outputs size " << outputs->size(); - return false; - } - return tensorrt_graph_->Execute(inputs, outputs) == RET_OK; - } - - if (!outputs->empty()) { - MS_LOG(ERROR) << "Cannot has graph outputs when dump op"; - return false; - } - std::vector trt_outputs; - if (tensorrt_graph_->Execute(inputs, &trt_outputs) != RET_OK) { - return false; - } - if (trt_outputs.size() != outputs_.size() + dump_outputs_.size()) { - MS_LOG(ERROR) << "TensorRT Graph outputs size " << trt_outputs.size() << " != graph outputs size " - << outputs_.size() << " + dump output size " << dump_outputs_.size(); - return false; - } - for (size_t i = 0; i < outputs_.size(); i++) { - outputs->push_back(trt_outputs[i]); - } - if (!has_dumped_) { - has_dumped_ = true; - auto dump_tensor = [this](const std::string &file_name, const tensor::Tensor &tensor) { - std::string new_file = file_name; - for (size_t i = 0; i < new_file.size(); i++) { - if (new_file[i] == '/' || new_file[i] == '\\') { - new_file[i] = '_'; - } - } - std::ofstream fp(dump_dir_ + "/" + new_file, std::ofstream::binary); - if (!fp.is_open()) { - MS_LOG(WARNING) << "Failed to open file " << dump_dir_ + "/" + file_name; - return; - } - fp.write(reinterpret_cast(tensor.data_c()), tensor.Size()); - }; - for (size_t i = 0; i < inputs.size(); i++) { - dump_tensor("input_" + std::to_string(i) + ".bin", inputs[i]); - } - for (size_t i = 0; i < outputs->size(); i++) { - dump_tensor("output_" + std::to_string(i) + ".bin", (*outputs)[i]); - } - for (size_t i = outputs_.size(); i < trt_outputs.size(); i++) { - auto tensor_info = dump_outputs_[i - outputs_.size()]; - dump_tensor(tensor_info.Name() + ".bin", trt_outputs[i]); - } - } - return true; -} - -bool TensorRTExecutor::Resize(uint32_t, const std::vector &inputs, - const std::vector> &new_shapes) { - if (tensorrt_graph_ == nullptr) { - MS_LOG(ERROR) << "TensorRT subgraph is nullptr."; - return false; - } - return tensorrt_graph_->Resize(inputs, new_shapes) == RET_OK; -} - -std::vector TensorRTExecutor::GetInputInfos(uint32_t) { - std::vector tensors; - for (auto &tensor_info : inputs_) { - auto type_id = static_cast(tensor_info.DataType()); - auto shape = tensor_info.Shape(); - tensors.push_back(tensor::Tensor(type_id, shape)); - } - return tensors; -} - -std::vector TensorRTExecutor::GetOutputInfos(uint32_t) { - std::vector tensors; - for (auto &tensor_info : outputs_) { - auto type_id = static_cast(tensor_info.DataType()); - auto shape = tensor_info.Shape(); - tensors.push_back(tensor::Tensor(type_id, shape)); - } - return tensors; -} - -static std::shared_ptr TensorRTGraphExecutorCreator(const std::shared_ptr &ctx, - const ConfigInfos &config_infos) { - auto executor = std::make_shared(ctx, config_infos); - if (!executor->Init()) { - return nullptr; - } - return executor; -} - -REG_DELEGATE(kGPU, tensorrt_provider, TensorRTGraphExecutorCreator); -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_graph_executor.h b/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_graph_executor.h deleted file mode 100644 index 648c858c848b96b5ddd896eee8d8db845ee06a81..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_graph_executor.h +++ /dev/null @@ -1,105 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_TENSORRT_GRAPH_EXECUTOR_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_TENSORRT_GRAPH_EXECUTOR_H_ -#include -#include -#include -#include -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/tensorrt_subgraph.h" -#include "src/extendrt/delegate/parameter_cache/embedding_cache_manager.h" -#include "include/api/kernel.h" -#include "include/errorcode.h" -#include "src/common/log_adapter.h" -#include "include/api/context.h" -#include "base/base.h" -#include "extendrt/delegate/factory.h" -#include "extendrt/session/lite_graph_executor.h" -#include "include/backend/kernel_graph.h" - -namespace mindspore::lite { -struct TrtGraphContext { - std::vector tensorrt_ops; - std::vector inputs; - std::vector outputs; - std::shared_ptr sub_graph = nullptr; -}; - -class TensorRTExecutor : public LiteGraphExecutor { - public: - TensorRTExecutor(const std::shared_ptr &context, const ConfigInfos &config_infos); - - ~TensorRTExecutor() override; - - bool Init(); - - bool CompileGraph(const FuncGraphPtr &graph, const std::map &compile_options, - uint32_t *graph_id) override; - bool RunGraph(uint32_t graph_id, const std::vector &inputs, std::vector *outputs, - const std::map &compile_options) override; - - bool Resize(uint32_t graph_id, const std::vector &inputs, - const std::vector> &new_shapes) override; - std::vector GetInputInfos(uint32_t graph_id) override; - std::vector GetOutputInfos(uint32_t graph_id) override; - - private: - int ParseOptimizationProfile(); - - int ParseTransformerProfile(); - - Status BuildSubGraph(const FuncGraphPtr &graph); - - TensorRTOp *FindTensorRTOp(const CNodePtr &cnode, const BaseOperatorPtr &base_operator, - const std::vector &input_tensors, - const std::vector &output_tensors); - - std::shared_ptr CreateTensorRTGraph(const std::vector &ops, const FuncGraphPtr &graph, - int index, const std::vector &inputs, - const std::vector &outputs); - int ParseDumpOptions(const std::map &gpu_context); - - std::shared_ptr context_{nullptr}; - ConfigInfos config_infos_; - std::shared_ptr device_info_{nullptr}; - TensorRTRuntime *runtime_{nullptr}; - bool support_hw_resize_{true}; - bool support_resize_{true}; - const std::string cache_model_path_; - size_t vocab_size_{0}; - size_t device_cache_size_{0}; - std::string serialize_path_; - cudaStream_t stream_{nullptr}; - cublasHandle_t cublas_handle_{nullptr}; - cublasLtHandle_t cublaslt_handle_{nullptr}; - - std::vector kernel_list_; - - ProfileConfigs trt_profile_configs_; - - std::shared_ptr tensorrt_graph_ = nullptr; - std::vector inputs_; - std::vector outputs_; - std::vector dump_outputs_; - std::vector dump_ops_; - std::string dump_dir_; - bool has_dumped_ = false; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_TENSORRT_GRAPH_EXECUTOR_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_runtime.h b/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_runtime.h deleted file mode 100644 index 3b09b32fdaaa5fb6b7ed9bd001fe92bc22766eb5..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_runtime.h +++ /dev/null @@ -1,128 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_TENSORRT_RUNTIME_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_TENSORRT_RUNTIME_H_ -#include -#include -#include "include/errorcode.h" -#include "src/extendrt/delegate/tensorrt/tensorrt_allocator.h" -#include "src/extendrt/delegate/tensorrt/cuda_impl/cublas_utils.h" -#include "src/common/log_adapter.h" -#define MAX_BATCH_SIZE 64 - -using mindspore::lite::RET_ERROR; -using mindspore::lite::RET_OK; - -namespace mindspore::lite { -class TensorRTLogger : public nvinfer1::ILogger { - void log(Severity severity, const char *msg) noexcept override { - if (severity == Severity::kINTERNAL_ERROR || severity == Severity::kERROR) { - MS_LOG(WARNING) << msg; - } else if (severity == Severity::kWARNING) { - MS_LOG(WARNING) << msg; - } else if (severity == Severity::kINFO) { - MS_LOG(INFO) << msg; - } else { - MS_LOG(DEBUG) << msg; - } - } -}; - -enum RuntimePrecisionMode : int { RuntimePrecisionMode_FP32, RuntimePrecisionMode_FP16 }; - -class TensorRTRuntime { - public: - TensorRTRuntime() = default; - - ~TensorRTRuntime(); - - int Init(); - - nvinfer1::IBuilder *GetBuilder() { return this->builder_; } - - int GetBatchSize() { return batch_size_; } - - void SetBatchSize(int batch_size) { batch_size_ = batch_size; } - - void SetCudaStream(cudaStream_t stream, cublasHandle_t cublas_handle, cublasLtHandle_t cublaslt_handle) { - allocator_->SetCudaStream(stream); - cublas_handle_ = cublas_handle; - cublaslt_handle_ = cublaslt_handle; - } - - RuntimePrecisionMode GetRuntimePrecisionMode() { return runtime_percision_mode_; } - - int GetTransformerEncoderInputIdx() { return transformer_encoder_input_idx_; } - - int GetTransformerDecoderInputIdx() { return transformer_decoder_input_idx_; } - - bool GetTransformerFfnFp16() { return transformer_ffn_fp16_; } - - std::string GetTransformerOptimize() const { return optimize_transformer_; } - - int GetVslEncoderPluginId() { return vsl_encoder_plugin_id_; } - - int GetVslDecoderPluginId() { return vsl_decoder_plugin_id_; } - - void SetRuntimePrecisionMode(RuntimePrecisionMode runtime_percision_mode) { - runtime_percision_mode_ = runtime_percision_mode; - } - - void SetTransformerEncoderInputIdx(int transformer_encoder_input_idx) { - transformer_encoder_input_idx_ = transformer_encoder_input_idx; - } - - void SetTransformerDecoderInputIdx(int transformer_decoder_input_idx) { - transformer_decoder_input_idx_ = transformer_decoder_input_idx; - } - void SetTransformerFfnFp16(bool is_ffn_fp16) { transformer_ffn_fp16_ = is_ffn_fp16; } - void SetTransformerOptimize(const std::string &optimize_transformer) { optimize_transformer_ = optimize_transformer; } - - bool IsTransformerOptimizeSigma() { - std::string pangu_sigma("pangu_sigma"); - return (optimize_transformer_ == pangu_sigma) ? true : false; - } - void SetVslEncoderPluginId(int plugin_id) { vsl_encoder_plugin_id_ = plugin_id; } - - void SetVslDecoderPluginId(int plugin_id) { vsl_decoder_plugin_id_ = plugin_id; } - - TensorRTAllocator *GetAllocator() { return this->allocator_; } - - void SetDeviceID(uint32_t device_id) { device_id_ = device_id; } - - uint32_t GetDeviceID() { return device_id_; } - cublasHandle_t GetCublasHandle() { return cublas_handle_; } - cublasLtHandle_t GetCublasLtHandle() { return cublaslt_handle_; } - - private: - bool is_init_{false}; - nvinfer1::IBuilder *builder_{nullptr}; - TensorRTLogger logger_; - TensorRTAllocator *allocator_{nullptr}; - int batch_size_{0}; - uint32_t device_id_{0}; - RuntimePrecisionMode runtime_percision_mode_{RuntimePrecisionMode::RuntimePrecisionMode_FP32}; - int transformer_encoder_input_idx_{-1}; - int transformer_decoder_input_idx_{-1}; - bool transformer_ffn_fp16_{true}; - std::string optimize_transformer_{""}; - int vsl_encoder_plugin_id_{-1}; - int vsl_decoder_plugin_id_{-1}; - cublasHandle_t cublas_handle_{nullptr}; - cublasLtHandle_t cublaslt_handle_{nullptr}; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_TENSORRT_RUNTIME_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_serializer.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_serializer.cc deleted file mode 100644 index 597440d7ae030fcfbaf000bbac96bc28f8e2b771..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_serializer.cc +++ /dev/null @@ -1,63 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/tensorrt_serializer.h" -#include "src/extendrt/delegate/tensorrt/tensorrt_runtime.h" -#include "src/common/file_utils.h" - -namespace mindspore::lite { -nvinfer1::ICudaEngine *TensorRTSerializer::GetSerializedEngine() { - if (serialize_file_path_.empty()) { - return nullptr; - } - char *trt_model_stream{nullptr}; - size_t size{0}; - trt_model_stream = ReadFile(serialize_file_path_.c_str(), &size); - if (trt_model_stream == nullptr || size == 0) { - MS_LOG(WARNING) << "read engine file failed : " << serialize_file_path_; - return nullptr; - } - nvinfer1::IRuntime *runtime = nvinfer1::createInferRuntime(logger_); - if (runtime == nullptr) { - delete[] trt_model_stream; - MS_LOG(ERROR) << "createInferRuntime failed."; - return nullptr; - } - nvinfer1::ICudaEngine *engine = runtime->deserializeCudaEngine(trt_model_stream, size, nullptr); - delete[] trt_model_stream; - runtime->destroy(); - return engine; -} -void TensorRTSerializer::SaveSerializedEngine(nvinfer1::ICudaEngine *engine) { - if (serialize_file_path_.size() == 0) { - return; - } - nvinfer1::IHostMemory *ptr = engine->serialize(); - if (ptr == nullptr) { - MS_LOG(ERROR) << "serialize engine failed"; - return; - } - - int ret = WriteToBin(serialize_file_path_, ptr->data(), ptr->size()); - if (ret != RET_OK) { - MS_LOG(ERROR) << "save engine failed " << serialize_file_path_; - } else { - MS_LOG(INFO) << "save engine to " << serialize_file_path_; - } - ptr->destroy(); - return; -} -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_subgraph.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_subgraph.cc deleted file mode 100644 index 5b725be233bd08eafd04bafc1d57a65fe0cb402c..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_subgraph.cc +++ /dev/null @@ -1,962 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/tensorrt_subgraph.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "src/extendrt/delegate/delegate_utils.h" -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "src/common/utils.h" -#include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "infer/cxx_api/topk_fusion.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_b.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_e.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" - -namespace mindspore::lite { -TensorRTSubGraph::TensorRTSubGraph(std::vector ops, const std::vector &inputs, - const std::vector &outputs, const mindspore::Context *ctx, - std::shared_ptr device_info, TensorRTRuntime *runtime, - bool support_resize, bool support_hw_resize, - const ProfileConfigs &trt_profile_config) - : inputs_(inputs), - outputs_(outputs), - all_ops_(std::move(ops)), - device_info_(device_info), - runtime_(runtime), - trt_profile_config_(trt_profile_config) { - trt_specific_weight_handled_inner_ = { - ops::kNameTranspose, ops::kNameReshape, ops::kNameExpandDims, ops::kNameTopKFusion, ops::kNameBroadcastTo, - }; - if (!support_resize) { - input_batchsize_index_ = -1; - input_hw_index_ = -1; - } - if (!support_hw_resize) { - input_hw_index_ = -1; - } -} - -TensorRTSubGraph::~TensorRTSubGraph() { - if (ctx_ != nullptr) { - delete ctx_; - } - if (config_ != nullptr) { - config_->destroy(); - config_ = nullptr; - } - if (trt_context_ != nullptr) { -#ifdef PROFILER_ - auto profile = dynamic_cast(trt_context_->getProfiler()); - if (profile != nullptr) { - std::cout << *profile << std::endl; - delete profile; - } -#endif - trt_context_->destroy(); - trt_context_ = nullptr; - } - if (engine_ != nullptr) { - engine_->destroy(); - engine_ = nullptr; - } - if (tensor_bindings_ != nullptr) { - delete[] tensor_bindings_; - tensor_bindings_ = nullptr; - } - for (auto op : all_ops_) { - delete op; - } -} - -bool TensorRTSubGraph::IsValidProfileDims() const { - if (trt_profile_config_.profiles.empty()) { - MS_LOG(INFO) << "Number of profiles is 0."; - return false; - } - for (auto &profile : trt_profile_config_.profiles) { - if (profile.inputs.size() != trt_profile_config_.input_infos.size()) { - MS_LOG(WARNING) << "Profile input size " << profile.inputs.size() << " != input shape size " - << trt_profile_config_.input_infos.size(); - return false; - } - for (size_t i = 0; i < profile.inputs.size(); i++) { - const auto &profile_input = profile.inputs[i]; - const auto &input_info = trt_profile_config_.input_infos[i]; - if (profile_input.min_dims.size() != input_info.input_shape.size()) { - MS_LOG(WARNING) << "Profile input " << input_info.name << " min dims number " << profile_input.min_dims.size() - << " != input shape dim number " << input_info.input_shape.size(); - return false; - } - if (profile_input.max_dims.size() != input_info.input_shape.size()) { - MS_LOG(WARNING) << "Profile input " << input_info.name << " max dims number " << profile_input.max_dims.size() - << " != input shape dim number " << input_info.input_shape.size(); - return false; - } - if (profile_input.opt_dims.size() != input_info.input_shape.size()) { - MS_LOG(WARNING) << "Profile input " << input_info.name << " opt dims number " << profile_input.opt_dims.size() - << " != input shape dim number " << input_info.input_shape.size(); - return false; - } - } - } - return true; -} - -int TensorRTSubGraph::Init(cudaStream_t stream, cublasHandle_t cublas_handle, cublasLtHandle_t cublaslt_handle) { - auto ret = GetGraphInOutOps(inputs_, outputs_, &in_ops_, &out_ops_, all_ops_); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Get TensorRT subgraph input and output ops failed."; - return RET_ERROR; - } - ctx_ = new (std::nothrow) TensorRTContext(); - if (ctx_ == nullptr) { - MS_LOG(ERROR) << "New TensorRTContext failed."; - return RET_ERROR; - } - ctx_->SetRuntime(runtime_); - if (!ctx_->Init()) { - MS_LOG(ERROR) << "New TensorRTContext failed."; - return RET_ERROR; - } - if (SetDeviceConfig(stream, cublas_handle, cublaslt_handle) != RET_OK) { - MS_LOG(WARNING) << "set tensorrt config failed."; - } - serializer_ = std::make_shared(serialize_file_path_); - if (serializer_ == nullptr) { - MS_LOG(ERROR) << "create Serializer failed."; - return RET_ERROR; - } - using_input_ranges_ = IsValidProfileDims(); - if (using_input_ranges_) { - for (size_t i = 0; i != trt_profile_config_.profiles.size(); ++i) { - profiles_.push_back(runtime_->GetBuilder()->createOptimizationProfile()); - } - } else { - profiles_.push_back(runtime_->GetBuilder()->createOptimizationProfile()); - } - for (size_t i = 0; i != profiles_.size(); ++i) { - if (profiles_[i] == nullptr) { - MS_LOG(ERROR) << "create optimization profile failed."; - return RET_ERROR; - } - } - engine_ = serializer_->GetSerializedEngine(); - if (engine_ != nullptr) { - MS_LOG(INFO) << "using serialized engine " << serialize_file_path_; - return RET_OK; - } - for (size_t i = 0; i < inputs_.size(); i++) { - if (inputs_[i].Shape().size() != DIMENSION_4D) { - input_hw_index_ = -1; - } - } - return RET_OK; -} - -int TensorRTSubGraph::BuildEngine() { - // print all network ops - for (auto &profile : profiles_) { - if (this->config_->addOptimizationProfile(profile) == -1) { - MS_LOG(ERROR) << "addOptimizationProfile failed."; - return RET_ERROR; - } - } - MS_LOG(INFO) << "build engine for tensorrt network: " << ctx_->network()->getName(); - for (int i = 0; i < ctx_->network()->getNbLayers(); i++) { - MS_LOG(DEBUG) << "tensorrt op: " << ctx_->network()->getLayer(i)->getName(); - } - MS_LOG(DEBUG) << "end of tensorrt network: " << ctx_->network()->getName(); - - this->engine_ = runtime_->GetBuilder()->buildEngineWithConfig(*ctx_->network(), *this->config_); - if (this->engine_ == nullptr) { - MS_LOG(ERROR) << "Create engine failed in TensorRT network"; - return RET_ERROR; - } - if (serialize_file_path_.size() > 0) { - serializer_->SaveSerializedEngine(engine_); - } - return RET_OK; -} - -int TensorRTSubGraph::SetDeviceConfig(cudaStream_t stream, cublasHandle_t cublas_handle, - cublasLtHandle_t cublaslt_handle) { - if (config_ == nullptr) { - this->config_ = runtime_->GetBuilder()->createBuilderConfig(); - if (this->config_ == nullptr) { - MS_LOG(ERROR) << "create builder config failed."; - return RET_ERROR; - } - } - // set fp16 - if (device_info_->GetEnableFP16() && runtime_->GetBuilder()->platformHasFastFp16()) { - MS_LOG(INFO) << "set fp16 flag successfully for tensorrt."; - config_->setFlag(nvinfer1::BuilderFlag::kFP16); - runtime_->SetRuntimePrecisionMode(RuntimePrecisionMode_FP16); - } - - // set int8 - if (IsInt8Mode() && runtime_->GetBuilder()->platformHasFastInt8()) { - MS_LOG(INFO) << "set int8 flag successfully for tensorrt."; - config_->setFlag(nvinfer1::BuilderFlag::kINT8); - // Mark calibrator as null - config_->setInt8Calibrator(nullptr); - input_hw_index_ = -1; - } else { - MS_LOG(INFO) << "inputs no quant params or platform not support int8."; - } - runtime_->SetCudaStream(stream, cublas_handle, cublaslt_handle); - config_->setProfileStream(stream); - stream_ = stream; - - MS_LOG(INFO) << GetRankID() << " tensorrt subgraph stream: " << stream_; - - // config setMaxWorkspaceSize to 2100 MB for max limit - constexpr size_t kWorkspaceSize = static_cast(2100) * (1 << 20); - config_->setMaxWorkspaceSize(kWorkspaceSize); - return RET_OK; -} - -bool TensorRTSubGraph::IsInt8Mode() { - for (auto cur_op : all_ops_) { - if (cur_op->GetQuantType() == schema::QuantType_QUANT_ALL) { - return true; - } - } - return false; -} - -nvinfer1::ITensor *TensorRTSubGraph::SetTensorRTNetworkInput(const TensorInfo &in_tensor, int index) { - if (index < 0) { - return nullptr; - } - for (int i = 0; i < ctx_->network()->getNbInputs(); i++) { - if (in_tensor.Name().compare(ctx_->network()->getInput(i)->getName()) == 0) { - MS_LOG(INFO) << "input tensor is already added in network: " << in_tensor.Name(); - return ctx_->network()->getInput(i); - } - } - - auto cuda_dtype = ConvertDataType(in_tensor.DataType()); - if (static_cast(cuda_dtype) == -1) { - MS_LOG(ERROR) << "Unsupported input data type " << static_cast(in_tensor.DataType()); - return nullptr; - } - nvinfer1::Dims input_dims; - if (using_input_ranges_) { - input_dims = SetInputDimsProfile(in_tensor, index); - } else { - input_dims = ParseInputDimsProfile(in_tensor, index); - } - MS_LOG(INFO) << "add network input: " << in_tensor.Name(); - return ctx_->network()->addInput(in_tensor.Name().c_str(), cuda_dtype, input_dims); -} - -nvinfer1::Dims TensorRTSubGraph::SetInputDimsProfile(const TensorInfo &in_tensor, int index) { - auto input_info = trt_profile_config_.input_infos[index]; - auto input_dims = ConvertCudaDims(input_info.input_shape); - DebugDims("input dims", input_dims); - for (size_t i = 0; i < trt_profile_config_.profiles.size(); i++) { - auto &profile = trt_profile_config_.profiles[i]; - auto min_dims = ConvertCudaDims(profile.inputs[index].min_dims); - if (!profiles_[i]->setDimensions(input_info.name.c_str(), nvinfer1::OptProfileSelector::kMIN, min_dims)) { - MS_LOG(ERROR) << "setDimensions of kMIN failed for " << input_info.name; - return input_dims; - } - auto opt_dims = ConvertCudaDims(profile.inputs[index].opt_dims); - if (!profiles_[i]->setDimensions(input_info.name.c_str(), nvinfer1::OptProfileSelector::kOPT, opt_dims)) { - MS_LOG(ERROR) << "setDimensions of kOPT failed for " << input_info.name; - return input_dims; - } - - auto max_dims = ConvertCudaDims(profile.inputs[index].max_dims); - if (!profiles_[i]->setDimensions(input_info.name.c_str(), nvinfer1::OptProfileSelector::kMAX, max_dims)) { - MS_LOG(ERROR) << "setDimensions of kMAX failed for " << input_info.name; - return input_dims; - } - DebugDims("min dims", min_dims); - DebugDims("opt dims", opt_dims); - DebugDims("max dims", max_dims); - } - return input_dims; -} - -nvinfer1::Dims TensorRTSubGraph::ParseInputDimsProfile(const TensorInfo &in_tensor, int index) { - nvinfer1::Dims input_dims = ConvertCudaDims(in_tensor.Shape()); - nvinfer1::Dims input_dims_min = ConvertCudaDims(in_tensor.Shape()); - if (!profiles_.front()->setDimensions(in_tensor.Name().c_str(), nvinfer1::OptProfileSelector::kMIN, input_dims_min)) { - MS_LOG(ERROR) << "setDimensions of kMIN failed for " << in_tensor.Name(); - return input_dims; - } - nvinfer1::Dims input_dims_opt = ConvertCudaDims(in_tensor.Shape()); - if (!profiles_.front()->setDimensions(in_tensor.Name().c_str(), nvinfer1::OptProfileSelector::kOPT, input_dims_opt)) { - MS_LOG(ERROR) << "setDimensions of kOPT failed for " << in_tensor.Name(); - return input_dims; - } - nvinfer1::Dims input_dims_max = ConvertCudaDims(in_tensor.Shape()); - // input_dims_max should be the same with input network dims - if (!profiles_.front()->setDimensions(in_tensor.Name().c_str(), nvinfer1::OptProfileSelector::kMAX, input_dims_max)) { - MS_LOG(ERROR) << "setDimensions of kMAX failed for " << in_tensor.Name(); - return input_dims; - } - if (trt_profile_config_.profiles.empty()) { - ProfileItem profile_item; - profile_item.inputs.resize(inputs_.size()); - trt_profile_config_.profiles.push_back(profile_item); - } - auto &profile_item = trt_profile_config_.profiles.back(); - profile_item.inputs[index].min_dims = ConvertMSShape(input_dims_min); - profile_item.inputs[index].opt_dims = ConvertMSShape(input_dims_opt); - profile_item.inputs[index].max_dims = ConvertMSShape(input_dims_max); - - DebugDims("input min dims", input_dims_min); - DebugDims("input opt dims", input_dims_opt); - DebugDims("input max dims", input_dims_max); - return input_dims; -} - -int TensorRTSubGraph::ParseInputsProfile() { - MS_LOG(INFO) << "using serialied engine."; - for (size_t i = 0; i < inputs_.size(); i++) { - auto dim = ParseInputDimsProfile(inputs_[i], i); - if (dim.nbDims <= 0) { - MS_LOG(ERROR) << "input dims is invalid."; - return RET_ERROR; - } - } - return RET_OK; -} - -int TensorRTSubGraph::GetInputIndexByName(const std::string &name) { - for (size_t i = 0; i != inputs().size(); ++i) { - if (inputs()[i].Name() == name) { - return i; - } - } - return -1; -} - -int TensorRTSubGraph::BuildTensorRTGraph() { - MS_ASSERT(!all_ops_.empty()); - int ret; - if (engine_ != nullptr) { - return ParseInputsProfile(); - } - // build engine online - for (auto cur_op : all_ops_) { - cur_op->SetRuntime(runtime_); - for (size_t i = 0; i != cur_op->inputs().size(); ++i) { - // Data From CPU - auto in_tensor = cur_op->inputs()[i]; - if (IsSubGraphInputTensor(this->inputs(), in_tensor)) { - nvinfer1::ITensor *trt_tensor = SetTensorRTNetworkInput(in_tensor, GetInputIndexByName(in_tensor.Name())); - if (trt_tensor == nullptr) { - MS_LOG(ERROR) << "SetTensorRTNetworkInput failed for " << in_tensor.Name(); - return RET_ERROR; - } - // avoid bool input tensor - cur_op->SetSupportInputBool(false); - - ctx_->RegisterTensorWithSameName(ITensorHelper{trt_tensor, in_tensor.format(), true}, in_tensor.Name()); - continue; - } - - ITensorHelper trt_tensor = FindTensorRTInputs(cur_op, in_tensor); - if (trt_tensor.trt_tensor_ == nullptr) { - // weight tensor - auto weight_handled_inner = - cur_op->IsWeightInputHanledInner() || - trt_specific_weight_handled_inner_.find(cur_op->type()) != trt_specific_weight_handled_inner_.end(); - if (!weight_handled_inner) { - if (!in_tensor.IsConst()) { - MS_LOG(ERROR) << "Weight Tensor data is not const."; - return RET_ERROR; - } - trt_tensor.trt_tensor_ = lite::ConvertConstantTensor(ctx_, in_tensor, cur_op->GetOpName()); - trt_tensor.format_ = Format::NCHW; - MS_LOG(INFO) << "auto convert constant tensor for: " << in_tensor.Name(); - ctx_->RegisterTensor(trt_tensor, in_tensor.Name()); - } - } else { - ctx_->RegisterTensor(trt_tensor, in_tensor.Name()); - } - } - MS_LOG(DEBUG) << "Parsing TensorRT op for " << cur_op->GetOpName(); - - ret = cur_op->AddInnerOp(ctx_); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Add op failed in TensorRT network: " << cur_op->GetOpName(); - return RET_ERROR; - } - ret = cur_op->SetInt8DynamicRange(ctx_); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Set Int8 dynamic range failed in TensorRT network: " << cur_op->GetOpName(); - return RET_ERROR; - } - } - ret = MarkOutputs(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "MarkOutputs failed in TensorRT network"; - return ret; - } - - std::string network_name = "network_" + std::string(ctx_->network()->getInput(0)->getName()) + "_" + - std::string(ctx_->network()->getOutput(0)->getName()); - ctx_->network()->setName(network_name.c_str()); - this->name_ = network_name; - ret = BuildEngine(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Create engine failed in TensorRT network"; - return ret; - } - return RET_OK; -} - -int TensorRTSubGraph::MarkOutputs() { - // Mark NetWork Output Tensor. - for (const auto &out_tensor : outputs_) { - std::string output_name = out_tensor.Name(); - auto input_it = std::find_if(inputs_.begin(), inputs_.end(), - [=](const TensorInfo &input) { return input.Name() == output_name; }); - if (input_it != inputs_.end()) { - nvinfer1::ITensor *trt_tensor = SetTensorRTNetworkInput(*input_it, GetInputIndexByName(input_it->Name())); - if (trt_tensor == nullptr) { - MS_LOG(ERROR) << "trt_tensor is nullptr!"; - return RET_ERROR; - } - ctx_->network()->markOutput(*trt_tensor); - continue; - } - if (out_tensor.IsConst()) { - MS_LOG(INFO) << "markOutput for: " << out_tensor.Name(); - auto output_helper = ctx_->MsName2Tensor(out_tensor.Name()); - if (output_helper.trt_tensor_ == nullptr) { - output_helper.trt_tensor_ = lite::ConvertConstantTensor(ctx_, out_tensor, out_tensor.Name()); - output_helper.format_ = Format::NCHW; - MS_LOG(INFO) << "auto convert constant tensor for: " << out_tensor.Name(); - ctx_->RegisterTensor(output_helper, out_tensor.Name()); - } - nvinfer1::ITensor *out_trt_tensor = output_helper.trt_tensor_; - out_trt_tensor->setName(("__" + out_tensor.Name()).c_str()); - out_trt_tensor = ctx_->network()->addIdentity(*out_trt_tensor)->getOutput(0); - out_trt_tensor->setName(out_tensor.Name().c_str()); - ctx_->network()->markOutput(*out_trt_tensor); - for (int n = 0; n < out_trt_tensor->getDimensions().nbDims; n++) { - if (out_trt_tensor->getDimensions().d[n] == -1) { - output_batchsize_index_ = n; - break; - } - } - } - for (auto out_op : this->out_ops_) { - for (size_t index = 0; index < out_op->outputs().size(); index++) { - if (out_op->outputs()[index] == out_tensor) { - MS_LOG(INFO) << "markOutput for: " << out_tensor.Name(); - auto output_helper = out_op->output(ctx_, index); - nvinfer1::ITensor *out_trt_tensor = output_helper.trt_tensor_; - out_trt_tensor->setName(("__" + out_tensor.Name()).c_str()); - auto out_layer = ctx_->network()->addIdentity(*out_trt_tensor); - if (out_tensor.DataType() == DataType::kNumberTypeFloat16) { - MS_LOG(WARNING) << "Cast output tensor " << out_tensor.Name() << " to fp16"; - out_layer->setOutputType(0, nvinfer1::DataType::kHALF); - } - out_trt_tensor = out_layer->getOutput(0); - out_trt_tensor->setName(out_tensor.Name().c_str()); - ctx_->network()->markOutput(*out_trt_tensor); - for (int n = 0; n < out_trt_tensor->getDimensions().nbDims; n++) { - if (out_trt_tensor->getDimensions().d[n] == -1) { - output_batchsize_index_ = n; - break; - } - } - } - } - } - } - return RET_OK; -} - -int TensorRTSubGraph::Prepare() { - int ret = lite::SetCudaDevice(device_info_); - if (ret != RET_OK) { - return ret; - } - if (this->engine_ == nullptr) { - MS_LOG(ERROR) << "engine_ is null in this builder_"; - return RET_ERROR; - } - this->trt_context_ = this->engine_->createExecutionContext(); - if (this->trt_context_ == nullptr) { - MS_LOG(ERROR) << "TensorRTSubGraph create context failed."; - return RET_ERROR; - } - -#ifdef PROFILER_ - auto profiler = new SimpleProfiler("myprofiler"); - if (profiler == nullptr) { - MS_LOG(WARNING) << "Cannot create profiler"; - } - this->trt_context_->setProfiler(profiler); -#endif - - int binding_num = this->engine_->getNbBindings(); - if (binding_num <= 0) { - MS_LOG(ERROR) << "TensorRTSubGraph binding num < 0."; - return RET_ERROR; - } - tensor_bindings_ = new (std::nothrow) void *[binding_num]; - if (tensor_bindings_ == nullptr) { - MS_LOG(ERROR) << "malloc tensor binding array failed."; - return RET_ERROR; - } - profile_index_ = MaxVolumnProfileIndex(); - if (this->trt_context_->setOptimizationProfile(profile_index_)) { - MS_LOG(INFO) << "setOptimizationProfile: " << profile_index_; - } - const auto &profile = trt_profile_config_.profiles[profile_index_]; - for (size_t i = 0; i != inputs_.size(); ++i) { - auto &tensor = inputs_[i]; - auto max_profile_dims = profile.inputs[i].max_dims; - tensor.SetShape(max_profile_dims); - int volumn = std::accumulate(max_profile_dims.begin(), max_profile_dims.end(), 1, std::multiplies()); - auto type_size = lite::DataTypeSize(static_cast(tensor.DataType())); - auto device_ptr = runtime_->GetAllocator()->MallocDeviceMem(tensor, volumn * type_size); - if (device_ptr == nullptr) { - MS_LOG(ERROR) << "malloc for inputs tensor device memory failed."; - return RET_ERROR; - } - auto tensor_name = tensor.Name(); - trt_in_tensor_name_.push_back(tensor_name); - int index = GetProfileBindingIndex(tensor_name, profile_index_); - MS_LOG(INFO) << "device index " << index << " for tensor : " << tensor_name << " attr: " << device_ptr; - tensor_bindings_[index] = device_ptr; - nvinfer1::Dims input_dims = ConvertCudaDims(profile.inputs[i].max_dims); - if (!this->trt_context_->setBindingDimensions(index, input_dims)) { - MS_LOG(ERROR) << "invalid input dims of " << tensor.Name(); - return RET_ERROR; - } - } - if (!this->trt_context_->allInputDimensionsSpecified()) { - MS_LOG(ERROR) << "input dims need to be specified."; - return RET_ERROR; - } - for (auto op : all_ops_) { - ret = op->Prepare(tensor_bindings_, engine_); - if (ret != RET_OK) { - MS_LOG(ERROR) << "prepare op failed of " << op->GetOpName(); - return RET_ERROR; - } - } - for (auto &tensor : outputs_) { - int max_index = GetProfileBindingIndex(tensor.Name(), profile_index_); - auto out_dims = trt_context_->getBindingDimensions(max_index); - int elem_num = std::accumulate(out_dims.d, out_dims.d + out_dims.nbDims, 1, std::multiplies()); - DebugDims("out dims", out_dims); - MS_LOG(INFO) << "Set output shape by tensorrt binding output"; - tensor.SetShape(lite::ConvertMSShape(out_dims)); - auto type_size = lite::DataTypeSize(static_cast(tensor.DataType())); - if (tensor.DataType() == DataType::kNumberTypeBool) { - type_size = lite::DataTypeSize(static_cast(DataType::kNumberTypeInt32)); - } - auto device_ptr = runtime_->GetAllocator()->MallocDeviceMem(tensor, elem_num * type_size); - if (device_ptr == nullptr) { - MS_LOG(ERROR) << "malloc for outputs tensor device memory failed."; - return RET_ERROR; - } - for (size_t j = 0; j != profiles_.size(); ++j) { - int index = GetProfileBindingIndex(tensor.Name(), j); - tensor_bindings_[index] = device_ptr; - } - trt_out_tensor_name_.push_back(tensor.Name()); - } - return RET_OK; -} - -int TensorRTSubGraph::SelectProfile(const std::vector &new_shapes) const { - std::vector profile_index; - for (size_t i = 0; i < profiles_.size(); ++i) { - const auto &profile = trt_profile_config_.profiles[i]; - bool condition = true; - for (size_t j = 0; j < trt_in_tensor_name_.size(); ++j) { - auto new_shape = new_shapes[j]; - auto profile_input = profile.inputs[j]; - if (new_shape.size() != profile_input.max_dims.size()) { - condition = false; - } else { - for (size_t od = 0; od < new_shape.size(); od++) { - if (new_shape[od] < profile_input.min_dims[od] || new_shape[od] > profile_input.max_dims[od]) { - condition = false; - break; - } - } - } - } - if (condition) { - profile_index.push_back(i); - } - } - return profile_index.empty() ? -1 : profile_index.front(); -} - -size_t TensorRTSubGraph::MaxVolumnProfileIndex() const { - int max_volumn = std::numeric_limits::min(); - size_t max_volumn_index = 0; - for (size_t i = 0; i < trt_profile_config_.profiles.size(); ++i) { - const auto &profile = trt_profile_config_.profiles[i]; - // depend on the first input tensor - int64_t volumn = std::accumulate(profile.inputs[0].max_dims.begin(), profile.inputs[0].max_dims.end(), 1, - std::multiplies()); - if (volumn > max_volumn) { - max_volumn_index = i; - max_volumn = volumn; - } - } - return max_volumn_index; -} - -int TensorRTSubGraph::GetProfileBindingIndex(const std::string &name, size_t profile_index) { - std::string binding_name = name; - if (profile_index != 0) { - binding_name += " [profile " + std::to_string(profile_index) + "]"; - } - return this->engine_->getBindingIndex(binding_name.c_str()); -} - -int TensorRTSubGraph::OnNewInputShapes(const std::vector &new_shapes) { - if (inputs_.size() != new_shapes.size()) { - MS_LOG(ERROR) << "Graph inputs size " << inputs_.size() << " != resize input size " << new_shapes.size(); - return RET_ERROR; - } - auto select_profile_index = SelectProfile(new_shapes); - if (select_profile_index < 0) { - MS_LOG(ERROR) << "Not support input shape " << new_shapes; - return RET_ERROR; - } - profile_index_ = static_cast(select_profile_index); - if (this->trt_context_->setOptimizationProfile(profile_index_)) { - MS_LOG(INFO) << "setOptimizationProfile: " << profile_index_; - } - int batch_size = -1; - for (size_t i = 0; i < trt_in_tensor_name_.size(); i++) { - if (inputs_[i].Shape() == new_shapes[i]) { - continue; - } - if (input_batchsize_index_ == -1) { - MS_LOG(ERROR) << "current network don't support resize."; - return RET_ERROR; - } - inputs_[i].SetShape(new_shapes[i]); - if (ctx_->network() != nullptr) { - for (int j = 0; j < ctx_->network()->getNbInputs(); j++) { - if (trt_in_tensor_name_[i].compare(ctx_->network()->getInput(j)->getName()) != 0) { - continue; - } - nvinfer1::Dims construct_dims = ctx_->network()->getInput(j)->getDimensions(); - bool ret = ValidInputResizeDims(construct_dims, inputs_[i].Shape()); - if (!ret) { - MS_LOG(ERROR) << "input resize shape is invalid."; - return RET_ERROR; - } - } - } - - MS_LOG(INFO) << "resize at input_batch_index " << input_batchsize_index_ << ", update batch size to " - << inputs_[i].Shape()[input_batchsize_index_]; - int new_batch_size = inputs_[i].Shape()[input_batchsize_index_]; - if (batch_size != -1 && batch_size != new_batch_size) { - MS_LOG(ERROR) << "Batch size " << batch_size << " of input 0 != batch size " << new_batch_size << " of input " - << i; - return RET_ERROR; - } - batch_size = new_batch_size; - - int index = GetProfileBindingIndex(trt_in_tensor_name_[i], profile_index_); - // Set actual input size - nvinfer1::Dims input_dims = ConvertCudaDims(inputs_[i].Shape()); - for (int od = 0; od < input_dims.nbDims; od++) { - MS_LOG(DEBUG) << "in tensor " << trt_in_tensor_name_[i] << " dims at " << od << " is " << input_dims.d[od]; - } - - if (!this->trt_context_->setBindingDimensions(index, input_dims)) { - MS_LOG(ERROR) << "invalid input dims of " << inputs_[i].Name() << ", profile index: " << profile_index_ - << ", dst dims: " << CudaDimsAsString(input_dims); - return RET_ERROR; - } - } - if (!this->trt_context_->allInputDimensionsSpecified()) { - MS_LOG(ERROR) << "input dims need to be specified."; - return RET_ERROR; - } - if (batch_size != -1) { - for (size_t i = 0; i < trt_out_tensor_name_.size(); i++) { - auto index = GetProfileBindingIndex(trt_out_tensor_name_[i], profile_index_); - auto out_dims = trt_context_->getBindingDimensions(index); - DebugDims("out dims", out_dims); - auto new_shape = lite::ConvertMSShape(out_dims); - MS_LOG(INFO) << "Set output shape of " << trt_out_tensor_name_[i] << " to " << new_shape - << " by tensorrt binding output"; - outputs_[i].SetShape(new_shape); - } - } - return RET_OK; -} - -int TensorRTSubGraph::VSLPreExectute(const std::vector &inputs, int i, bool sync, - const std::string &tensor_name) { - const bool is_expert_ids = (inputs.size() == Num6) ? Num1 : 0; - const int input_ids_idx = 0; - const int expert_ids_idx = (is_expert_ids) ? Num1 : -1; - const int attn_mask_idx = Num1 + is_expert_ids; - const int pos_ids_idx = Num2 + is_expert_ids; - const int current_idx_idx = Num3 + is_expert_ids; - if (i == input_ids_idx || i == expert_ids_idx || i == pos_ids_idx) { - int *in_ptr = static_cast(inputs[i].data_ptr()->data()); - int batch = inputs[trt_in_tensor_name_.size() - Num1].ElementsNum(); - int seq = inputs[0].ElementsNum() / batch; - int export_num = (i != expert_ids_idx) ? Num1 : inputs[i].ElementsNum() / (batch * seq); - bool incremental_mode = - (static_cast(inputs[pos_ids_idx].data().const_data())[0] != 0) ? true : false; - size_t h_token = 0; - for (int k = 0; k < batch; k++) { - int actual_seq_len = - (incremental_mode) - ? Num1 - : (static_cast(inputs[trt_in_tensor_name_.size() - Num1].data().const_data())[k] + Num1); - int batch_valid = static_cast(inputs[trt_in_tensor_name_.size() - Num1].data().const_data())[k]; - h_token += (batch_valid == -1) ? 0 : actual_seq_len; - } - for (int j = 0; j < export_num; j++) { - size_t h_token_idx = 0; - for (int k = 0; k < batch; k++) { - int actual_seq_len = - (incremental_mode) - ? Num1 - : (static_cast(inputs[trt_in_tensor_name_.size() - Num1].data().const_data())[k] + Num1); - for (int l = 0; l < actual_seq_len; l++) { - in_ptr[j * h_token + h_token_idx + l] = - static_cast(inputs[i].data_ptr()->data())[j * batch * seq + k * seq + l]; - } - h_token_idx += actual_seq_len; - } - } - return runtime_->GetAllocator()->SyncMemHostToDevice(inputs[i], tensor_name, sync, - h_token * export_num * sizeof(int)); - } else if (i != attn_mask_idx && i != current_idx_idx) { - return runtime_->GetAllocator()->SyncMemHostToDevice(inputs[i], tensor_name, sync); - } - return RET_OK; -} - -int TensorRTSubGraph::PreExecute(const std::vector &inputs, const std::vector &outputs, - bool sync) { - if (inputs_.size() != inputs.size()) { - MS_LOG(ERROR) << "Graph inputs size " << inputs_.size() << " != execute inputs size " << inputs.size(); - return RET_ERROR; - } - if (!outputs.empty() && outputs.size() != outputs_.size()) { - MS_LOG(ERROR) << "Graph outputs size " << outputs_.size() << " != execute outputs size " << outputs.size(); - return RET_ERROR; - } - std::vector new_shapes; - std::transform(inputs.begin(), inputs.end(), std::back_inserter(new_shapes), [](auto &t) { return t.shape_c(); }); - auto ret = OnNewInputShapes(new_shapes); - if (ret != RET_OK) { - return ret; - } - for (size_t i = 0; i < trt_in_tensor_name_.size(); i++) { - auto trt_tensor_name = trt_in_tensor_name_[i]; - void *device_ptr = nullptr; - auto input_device_address = inputs[i].device_address(); - if (input_device_address != nullptr && input_device_address->GetMutablePtr() != nullptr) { - device_ptr = input_device_address->GetMutablePtr(); - } else { - device_ptr = runtime_->GetAllocator()->MallocDeviceMem(trt_tensor_name, inputs_[i].DataSize(), - ConvertDataType(inputs_[i].DataType())); - if (device_ptr == nullptr) { - MS_LOG(ERROR) << "realloc for input tensor device memory failed."; - return RET_ERROR; - } - if (runtime_->IsTransformerOptimizeSigma()) { - ret = VSLPreExectute(inputs, i, sync, trt_tensor_name); - } else { - ret = runtime_->GetAllocator()->SyncMemHostToDevice(inputs[i], trt_tensor_name, sync); - } - if (ret != RET_OK) { - MS_LOG(ERROR) << "sync mem from host to device failed for " << trt_tensor_name; - return RET_ERROR; - } - runtime_->GetAllocator()->MarkMemValid(trt_tensor_name, true); - } - int index = GetProfileBindingIndex(trt_tensor_name, profile_index_); - MS_LOG(INFO) << "device index " << index << " for tensor : " << trt_tensor_name << " attr: " << device_ptr; - tensor_bindings_[index] = device_ptr; - } - for (size_t i = 0; i < trt_out_tensor_name_.size(); i++) { - const auto &trt_out_tensor_name = trt_out_tensor_name_[i]; - int index = GetProfileBindingIndex(trt_out_tensor_name, profile_index_); - void *device_ptr = nullptr; - if (outputs.size() > i) { - auto &output = outputs[i]; - if (output.device_address() && output.device_address()->GetMutablePtr()) { - if (output.Size() < outputs_[i].DataSize()) { - MS_LOG(ERROR) << "Specified output device data size " << output.Size() - << " cannot less than execute output data size " << outputs_[i].DataSize() - << ", output shape: " << outputs_[i].Shape(); - return RET_ERROR; - } - device_ptr = output.device_address()->GetMutablePtr(); - } - } - if (!device_ptr) { - device_ptr = runtime_->GetAllocator()->MallocDeviceMem(trt_out_tensor_name, outputs_[i].DataSize(), - ConvertDataType(outputs_[i].DataType())); - if (device_ptr == nullptr) { - MS_LOG(ERROR) << "realloc for outputs tensor device memory failed."; - return RET_ERROR; - } - } - tensor_bindings_[index] = device_ptr; - } - return RET_OK; -} // namespace mindspore::lite - -int TensorRTSubGraph::PostExecute(std::vector *outputs, bool sync) { - if (!outputs->empty() && outputs->size() != outputs_.size()) { - MS_LOG(ERROR) << "Graph outputs size " << outputs_.size() << " != execute outputs size " << outputs->size(); - return RET_ERROR; - } - auto has_outputs = !outputs->empty(); - for (size_t i = 0; i < trt_out_tensor_name_.size(); i++) { - const auto &trt_out_tensor_name = trt_out_tensor_name_[i]; - auto index = GetProfileBindingIndex(trt_out_tensor_name, profile_index_); - // actual output tensor dims - auto out_dims = this->trt_context_->getBindingDimensions(index); - std::vector new_shape = lite::ConvertMSShape(out_dims); - for (int od = 0; od < out_dims.nbDims; od++) { - MS_LOG(DEBUG) << "out tensor " << trt_out_tensor_name << " dims at " << od << " is " << new_shape[od]; - } - runtime_->GetAllocator()->MarkMemValid(trt_out_tensor_name, true); - if (has_outputs) { - auto &tensor = outputs->at(i); - auto dst_device = tensor.device_address(); - if (dst_device == nullptr || dst_device->GetMutablePtr() == nullptr) { - if (tensor.Size() < outputs_[i].DataSize()) { - MS_LOG(ERROR) << "Specified output host data size " << tensor.Size() - << " cannot less than execute output data size " << outputs_[i].DataSize() - << ", output shape: " << new_shape; - return RET_ERROR; - } - auto host_address = tensor.data_c(); - if (host_address == nullptr) { - MS_LOG(ERROR) << "Specified output device or host address cannot be nullptr"; - return RET_ERROR; - } - int sync_ret = runtime_->GetAllocator()->SyncMemDeviceToHost(host_address, outputs_[i].DataSize(), - trt_out_tensor_name, sync); - if (sync_ret != RET_OK) { - MS_LOG(ERROR) << "sync mem from device to host failed for " << trt_out_tensor_name; - return sync_ret; - } - } - } else { - tensor::Tensor output_tensor(static_cast(outputs_[i].DataType()), new_shape); - int sync_ret = runtime_->GetAllocator()->SyncMemDeviceToHost(&output_tensor, trt_out_tensor_name, sync); - if (sync_ret != RET_OK) { - MS_LOG(ERROR) << "sync mem from device to host failed for " << trt_out_tensor_name; - return sync_ret; - } - outputs->push_back(output_tensor); - } - runtime_->GetAllocator()->MarkMemValid(trt_out_tensor_name, false); - } - // make mem invalid, prepare for next execute - for (size_t i = 0; i < inputs_.size(); i++) { - runtime_->GetAllocator()->MarkMemValid(trt_in_tensor_name_[i], false); - } - return RET_OK; -} - -bool TensorRTSubGraph::ValidInputResizeDims(const nvinfer1::Dims &construct_dims, - const std::vector &resize_input_shape) { - if (static_cast(construct_dims.nbDims) != resize_input_shape.size()) { - MS_LOG(ERROR) << "invalid resize input."; - return false; - } - return true; -} - -int TensorRTSubGraph::Execute(const std::vector &inputs, std::vector *outputs) { -#ifdef ASYNC_INFER - bool sync = false; -#else - bool sync = true; -#endif - int ret = lite::SetCudaDevice(device_info_); - if (ret != RET_OK) { - return ret; - } - ret = PreExecute(inputs, *outputs, sync); - if (ret != RET_OK) { - return ret; - } - if (sync) { - if (!this->trt_context_->executeV2(tensor_bindings_)) { - MS_LOG(ERROR) << "TensorRT execute failed."; - return RET_ERROR; - } - } else { - if (!this->trt_context_->enqueueV2(tensor_bindings_, stream_, nullptr)) { - MS_LOG(ERROR) << "TensorRT execute failed."; - return RET_ERROR; - } - } - ret = PostExecute(outputs, sync); - if (ret != RET_OK) { - return ret; - } - if (!sync) { - cudaStreamSynchronize(stream_); - } - return RET_OK; -} - -int TensorRTSubGraph::Resize(const std::vector &, const std::vector &new_shapes) { - return OnNewInputShapes(new_shapes); -} - -ITensorHelper TensorRTSubGraph::FindTensorRTInputs(TensorRTOp *cur_op, const TensorInfo &in_tensor) { - for (auto input_op : cur_op->in_ops()) { - for (size_t i = 0; i < input_op->outputs().size(); i++) { - auto out_tensor = input_op->outputs().at(i); - if (in_tensor.Name().compare(out_tensor.Name()) == 0) { - return input_op->output(ctx_, i); - } - } - } - return ITensorHelper{}; -} -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_subgraph.h b/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_subgraph.h deleted file mode 100644 index e26ccaca54971d5143e38256335894a59367be54..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_subgraph.h +++ /dev/null @@ -1,154 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_TENSORRT_SUBGRAPH_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_TENSORRT_SUBGRAPH_H_ -#include -#include -#include -#include -#include -#include -#include "include/api/kernel.h" -#include "src/extendrt/delegate/tensorrt/tensorrt_runtime.h" -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include "src/extendrt/delegate/tensorrt/tensorrt_serializer.h" -#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h" -#include "src/extendrt/delegate/parameter_cache/embedding_cache_manager.h" -#include "include/api/context.h" -#include "common/config_infos.h" - -namespace mindspore::lite { -using mindspore::lite::RET_ERROR; -using mindspore::lite::RET_OK; -struct CacheTensorInfo { - std::vector network_input_tensor_; - bool front_op_can_cache_; -}; - -class TensorRTSubGraph { - public: - TensorRTSubGraph(std::vector ops, const std::vector &inputs, - const std::vector &outputs, const mindspore::Context *ctx, - std::shared_ptr device_info, TensorRTRuntime *runtime, bool support_resize, - bool support_hw_resize, const ProfileConfigs &trt_profile_config); - ~TensorRTSubGraph(); - - int Prepare(); - - int Execute(const std::vector &inputs, std::vector *outputs); - - int Resize(const std::vector &inputs, const std::vector &new_shapes); - - int BuildTensorRTGraph(); - - int Init(cudaStream_t stream, cublasHandle_t cublas_handle, cublasLtHandle_t cublaslt_handle); - - void SetSerializePath(const std::string &path) { serialize_file_path_ = std::move(path); } - - int VSLPreExectute(const std::vector &inputs, int i, bool sync, const std::string &tensor_name); - - std::vector &inputs() { return inputs_; } - - std::vector &outputs() { return outputs_; } - - private: - int GetInputIndexByName(const std::string &name); - int BuildEngine(); - - int SetDeviceConfig(cudaStream_t stream, cublasHandle_t cublas_handle, cublasLtHandle_t cublaslt_handle); - - bool IsInt8Mode(); - - bool SupportFP16(); - - nvinfer1::ITensor *SetTensorRTNetworkInput(const TensorInfo &in_tensor, int index); - - ITensorHelper FindTensorRTInputs(TensorRTOp *cur_op, const TensorInfo &in_tensor); - - int MarkOutputs(); - - bool IsCached(TensorRTOp *cur_op, const TensorInfo &in_tensor); - - void FindCacheTensorInfo(TensorRTOp *cur_op, TensorInfo device_cache_tensor); - - bool CanOpCache(TensorRTOp *cur_op); - - int HandleCacheTensor(TensorRTOp *cur_op, const TensorInfo &in_tensor); - - nvinfer1::Dims ParseInputDimsProfile(const TensorInfo &in_tensor, int index); - nvinfer1::Dims SetInputDimsProfile(const TensorInfo &in_tensor, int index); - int ParseInputsProfile(); - - int PreExecute(const std::vector &inputs, const std::vector &outputs, - bool sync = true); - int PostExecute(std::vector *outputs, bool sync = true); - - int OnNewInputShapes(const std::vector &inputs); - - size_t MaxVolumnProfileIndex() const; - int SelectProfile(const std::vector &new_shapes) const; - int GetProfileBindingIndex(const std::string &name, size_t profile_index); - bool ValidInputResizeDims(const nvinfer1::Dims &construct_dims, const std::vector &resize_input_shape); - bool IsValidProfileDims() const; - - std::string name_; - std::vector inputs_; - std::vector outputs_; - - std::vector all_ops_{}; - // subgraph input nodes. - std::vector in_ops_{}; - // subgraph output nodes. - std::vector out_ops_{}; - - void **tensor_bindings_{nullptr}; - - std::shared_ptr device_info_{nullptr}; - - TensorRTRuntime *runtime_{nullptr}; // all subgraph in one delegate share a runtime_ - - std::set trt_specific_weight_handled_inner_; - - // save in/out tensor name for subgraph isolate. - std::vector trt_in_tensor_name_; - std::vector trt_out_tensor_name_; - - nvinfer1::INetworkDefinition *network_{nullptr}; - nvinfer1::IBuilderConfig *config_{nullptr}; - nvinfer1::ICudaEngine *engine_{nullptr}; - nvinfer1::IExecutionContext *trt_context_{nullptr}; - - TensorRTContext *ctx_; - - // -1 means don't support resize - int input_batchsize_index_{0}; - int output_batchsize_index_{0}; - int input_hw_index_{0}; - - std::map> model_input_to_cache_tensors_; - - std::shared_ptr serializer_{nullptr}; - - std::string serialize_file_path_; - cudaStream_t stream_{nullptr}; - - std::vector profiles_{}; - bool using_input_ranges_{false}; - ProfileConfigs trt_profile_config_; - size_t profile_index_{0}; -}; -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_TENSORRT_SUBGRAPH_H_ diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_utils.cc b/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_utils.cc deleted file mode 100644 index 72a5393597ea2e55cba1b2b304c5db641c848f26..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_utils.cc +++ /dev/null @@ -1,996 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" -#include -#include -#include -#include -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/op/cast_plugin.h" -#include "src/extendrt/delegate/tensorrt/distribution/distribution_collective.h" - -namespace mindspore::lite { -nvinfer1::Dims ConvertCudaDims(int data, size_t size) { - nvinfer1::Dims dims{}; - dims.nbDims = -1; - if (size > static_cast(dims.MAX_DIMS)) { - MS_LOG(ERROR) << "invalid shape size: " << size; - return dims; - } - dims.nbDims = size; - for (size_t i = 0; i < size; i++) { - dims.d[i] = data; - } - return dims; -} - -template -nvinfer1::Dims ConvertCudaDimsWithType(const void *data, int64_t size) { - nvinfer1::Dims dims{}; - dims.nbDims = -1; - if (size > static_cast(dims.MAX_DIMS)) { - MS_LOG(ERROR) << "invalid shape size: " << size; - return dims; - } - dims.nbDims = size; - - auto *dims_data = static_cast(data); - for (int i = 0; i < size; i++) { - dims.d[i] = static_cast(*(dims_data + i)); - } - return dims; -} - -nvinfer1::Dims ConvertCudaDims(const std::vector &data) { - auto dims = ConvertCudaDimsWithType(data.data(), data.size()); - return dims; -} - -nvinfer1::Dims ConvertCudaDims(const TensorInfo &ms_tensor) { - auto data = ms_tensor.Data(); - auto size = ms_tensor.ElementNum(); - auto ms_dtype = ms_tensor.DataType(); - - nvinfer1::Dims dims{}; - if (ms_dtype == DataType::kNumberTypeInt32) { - dims = ConvertCudaDimsWithType(data, size); - } else if (ms_dtype == DataType::kNumberTypeInt64) { - dims = ConvertCudaDimsWithType(data, size); - } else { - MS_LOG(ERROR) << "invalid DataType: " << ms_dtype; - } - return dims; -} - -std::string CudaDimsAsString(const nvinfer1::Dims &dims) { - std::stringstream str_stream; - str_stream << "[" << dims.nbDims << ":"; - if (dims.nbDims > 0) { - for (int i = 0; i < dims.nbDims; i++) { - str_stream << dims.d[i]; - if (i + 1 != dims.nbDims) { - str_stream << ","; - } - } - } - str_stream << "]"; - return str_stream.str(); -} - -std::vector ConvertTensorAsIntVector(const TensorInfo &ms_tensor) { - if (!ms_tensor.IsConst()) { - MS_LOG(ERROR) << "Expect tensor to be const tensor, but got var tensor"; - return {}; - } - auto data = ms_tensor.Data(); - if (data == nullptr) { - MS_LOG(ERROR) << "Const data cannot be nullptr"; - return {}; - } - std::vector vals; - auto ms_dtype = ms_tensor.DataType(); - auto size = ms_tensor.ElementNum(); - if (ms_dtype == DataType::kNumberTypeInt32 || static_cast(ms_dtype) == TypeId::kMetaTypeTypeType) { - auto int_data = reinterpret_cast(data); - for (int64_t i = 0; i < size; i++) { - vals.push_back(int_data[i]); - } - } else if (ms_dtype == DataType::kNumberTypeInt64) { - auto int_data = reinterpret_cast(data); - for (int64_t i = 0; i < size; i++) { - vals.push_back((int32_t)int_data[i]); - } - } else { - MS_LOG(ERROR) << "invalid DataType: " << ms_dtype; - } - return vals; -} - -bool SameDims(nvinfer1::Dims dims, const std::vector &shape) { - if (dims.nbDims != static_cast(shape.size())) { - return false; - } - // dynamic dim, only channel dim know - for (int i = 0; i < dims.nbDims; i++) { - if (dims.d[i] == -1) { - continue; - } - if (dims.d[i] != shape[i]) { - return false; - } - } - return true; -} - -std::vector ConvertMSShape(const nvinfer1::Dims dims) { - std::vector shape; - for (int i = 0; i < dims.nbDims; i++) { - shape.push_back(dims.d[i]); - } - return shape; -} - -std::vector NHWC2NCHW(std::vector nhwc_shape) { - std::vector nchw_shape; - if (nhwc_shape.size() != DIMENSION_4D) { - return nhwc_shape; - } - nchw_shape.push_back(nhwc_shape[kNHWC_N]); - nchw_shape.push_back(nhwc_shape[kNHWC_C]); - nchw_shape.push_back(nhwc_shape[kNHWC_H]); - nchw_shape.push_back(nhwc_shape[kNHWC_W]); - return nchw_shape; -} - -nvinfer1::IShuffleLayer *SetTranspose(TensorRTContext *ctx, const nvinfer1::ITensor &input, - nvinfer1::Permutation permutation) { - nvinfer1::IShuffleLayer *layer = ctx->network()->addShuffle(const_cast(input)); - if (layer == nullptr) { - MS_LOG(ERROR) << "failed to create ShuffleLayer when create transpose op."; - return nullptr; - } - layer->setFirstTranspose(permutation); - return layer; -} - -nvinfer1::DataType ConvertDataType(DataType type_id) { - std::map data_type_map = { -#if TRT_VERSION_GE(7, 2) - {DataType::kNumberTypeBool, nvinfer1::DataType::kBOOL}, -#endif - {DataType::kNumberTypeInt8, nvinfer1::DataType::kINT8}, - {DataType::kNumberTypeInt32, nvinfer1::DataType::kINT32}, - {DataType::kNumberTypeFloat32, nvinfer1::DataType::kFLOAT}, - {DataType::kNumberTypeFloat16, nvinfer1::DataType::kHALF}, - {DataType::kNumberTypeInt64, nvinfer1::DataType::kINT32}, - }; - auto iter = data_type_map.find(type_id); - nvinfer1::DataType data_type; - if (iter != data_type_map.end()) { - data_type = iter->second; - } else { - data_type = nvinfer1::DataType::kFLOAT; - MS_LOG(INFO) << "invalid data_type for TensorRT, need check: " << static_cast(type_id); - } - return data_type; -} - -cudaDataType ConvertDataType(nvinfer1::DataType type_id) { - std::map data_type_map = { - {nvinfer1::DataType::kINT8, CUDA_R_8I}, - {nvinfer1::DataType::kINT32, CUDA_R_32I}, - {nvinfer1::DataType::kFLOAT, CUDA_R_32F}, - {nvinfer1::DataType::kHALF, CUDA_R_16F}, - }; - auto iter = data_type_map.find(type_id); - cudaDataType data_type; - if (iter != data_type_map.end()) { - data_type = iter->second; - } else { - data_type = CUDA_R_32F; - MS_LOG(WARNING) << "invalid data_type for TensorRT, need check: " << static_cast(type_id); - } - return data_type; -} - -nvinfer1::IShuffleLayer *NHWC2NCHW(TensorRTContext *ctx, const nvinfer1::ITensor &input) { - // NHWC 0123 NCHW 0312 - nvinfer1::Permutation perm{{0, 3, 1, 2}}; - return SetTranspose(ctx, input, perm); -} - -nvinfer1::IShuffleLayer *NCHW2NHWC(TensorRTContext *ctx, const nvinfer1::ITensor &input) { - // NCHW 0123 NHWC 0231 - nvinfer1::Permutation perm{{0, 2, 3, 1}}; - return SetTranspose(ctx, input, perm); -} - -nvinfer1::ITensor *ConvertConstantTensor(TensorRTContext *ctx, const TensorInfo &ms_tensor, - const std::string &op_name) { - if (ctx == nullptr || ctx->network() == nullptr) { - MS_LOG(ERROR) << "context or network is null for ConvertConstantTensor"; - return nullptr; - } - nvinfer1::Dims dims = ConvertCudaDims(ms_tensor.Shape()); - if (dims.nbDims == -1) { - MS_LOG(INFO) << ms_tensor.Name() << " ConvertCudaDims failed, convert as scalar."; - dims.nbDims = 1; - dims.d[0] = 1; - } - nvinfer1::DataType data_type = ConvertDataType(ms_tensor.DataType()); - if (!ms_tensor.IsConst()) { - MS_LOG(ERROR) << "ConvertConstantTensor from a MSTensor with nullptr data: " << ms_tensor.Name(); - return nullptr; - } - nvinfer1::Weights weights{data_type, ms_tensor.Data(), ms_tensor.ElementNum()}; - if (data_type == nvinfer1::DataType::kBOOL) { - weights.type = nvinfer1::DataType::kINT32; - void *data_int32 = malloc(ms_tensor.ElementNum() * sizeof(int32_t)); - if (data_int32 == nullptr) { - MS_LOG(ERROR) << "Malloc buffer failed."; - return nullptr; - } - auto src = static_cast(ms_tensor.Data()); - auto dst = static_cast(data_int32); - for (int i = 0; i < ms_tensor.ElementNum(); i++) { - dst[i] = (int32_t)src[i]; - } - weights.values = data_int32; - } - nvinfer1::IConstantLayer *constant_tensor = ctx->network()->addConstant(dims, weights); - if (constant_tensor == nullptr) { - MS_LOG(ERROR) << "create constant_tensor failed."; - return nullptr; - } - ctx->RegisterLayer(constant_tensor, ms_tensor.Name() + "_" + op_name); - auto tensor_ptr = constant_tensor->getOutput(0); - return tensor_ptr; -} - -nvinfer1::ITensor *ConvertScalarToITensor(TensorRTContext *ctx, size_t shape_size, const void *value, - const DataType data_type, const std::string &op_name) { - nvinfer1::Dims dims = ConvertCudaDims(1, shape_size); - if (dims.nbDims == -1) { - MS_LOG(ERROR) << "ConvertCudaDims failed for " << op_name; - return nullptr; - } - nvinfer1::Weights weights{ConvertDataType(data_type), value, 1}; - nvinfer1::IConstantLayer *constant_tensor = ctx->network()->addConstant(dims, weights); - if (constant_tensor == nullptr) { - MS_LOG(ERROR) << "create constant_tensor failed."; - return nullptr; - } - ctx->RegisterLayer(constant_tensor, op_name + "_constant"); - return constant_tensor->getOutput(0); -} - -nvinfer1::ITensor *ConvertScalarToITensor(TensorRTContext *ctx, size_t shape_size, const TensorInfo &ms_tensor, - const DataType data_type, const std::string &op_name) { - const void *value = ms_tensor.Data(); - auto tensor_ptr = ConvertScalarToITensor(ctx, shape_size, value, data_type, op_name); - return tensor_ptr; -} - -std::experimental::optional TryConvertActivationType(ActivationType activation_type) { - std::map action_map = { - {ActivationType::RELU, ActivationParams{nvinfer1::ActivationType::kRELU, false, 0, false, 0}}, - {ActivationType::SIGMOID, ActivationParams{nvinfer1::ActivationType::kSIGMOID, false, 0, false, 0}}, - {ActivationType::TANH, ActivationParams{nvinfer1::ActivationType::kTANH, false, 0, false, 0}}, - {ActivationType::LEAKY_RELU, ActivationParams{nvinfer1::ActivationType::kLEAKY_RELU, true, 0, false, 0}}, - {ActivationType::ELU, ActivationParams{nvinfer1::ActivationType::kELU, true, 0, false, 0}}, - {ActivationType::SELU, ActivationParams{nvinfer1::ActivationType::kSELU, true, 0, true, 0}}, - {ActivationType::SOFTSIGN, ActivationParams{nvinfer1::ActivationType::kSOFTSIGN, false, 0, false, 0}}, - {ActivationType::SOFTPLUS, ActivationParams{nvinfer1::ActivationType::kSOFTPLUS, false, 0, false, 0}}, - {ActivationType::THRESHOLDRELU, ActivationParams{nvinfer1::ActivationType::kTHRESHOLDED_RELU, true, 0, false, 0}}, - {ActivationType::RELU6, ActivationParams{nvinfer1::ActivationType::kCLIP, true, 0, true, 6}}, - {ActivationType::RELU1, ActivationParams{nvinfer1::ActivationType::kCLIP, true, 0, true, 1}}, - {ActivationType::HARD_TANH, ActivationParams{nvinfer1::ActivationType::kCLIP, true, -1, true, 1}}, - {ActivationType::HSIGMOID, ActivationParams{nvinfer1::ActivationType::kHARD_SIGMOID, true, 1.f / 6, true, 0.5f}}, - // using plugin - {ActivationType::GELU, ActivationParams{nvinfer1::ActivationType::kTHRESHOLDED_RELU, false, 0, false, 0}}, - {ActivationType::SWISH, ActivationParams{nvinfer1::ActivationType::kSIGMOID, false, 0, false, 0}}}; - return action_map.find(activation_type) != action_map.end() - ? std::experimental::optional(action_map[activation_type]) - : std::experimental::nullopt; -} - -bool IsComfortableAlign(std::vector *in_shape_ptr, const std::vector &out_shape, int index) { - if (in_shape_ptr->size() > out_shape.size()) { - return false; - } - int out_index = index; - int in_index = in_shape_ptr->size() - 1; - while (in_index >= 0 && (in_shape_ptr->at(in_index) == out_shape[out_index] || in_shape_ptr->at(in_index) == 1)) { - in_index--; - out_index--; - } - return in_index < 0; -} - -void BackComfortableAlign(std::vector *in_shape_ptr, const std::vector &out_shape) { - if (in_shape_ptr->size() >= out_shape.size()) { - return; - } - int out_index = out_shape.size() - 1; - bool is_comfortable = false; - while (out_index >= static_cast(in_shape_ptr->size()) - 1) { - if (IsComfortableAlign(in_shape_ptr, out_shape, out_index)) { - is_comfortable = true; - break; - } - out_index--; - } - if (is_comfortable == false) { - MS_LOG(INFO) << "failed to align constant tensor"; - return; - } - while (static_cast(in_shape_ptr->size()) - 1 < out_index) { - in_shape_ptr->insert(in_shape_ptr->begin(), 1); - } - while (in_shape_ptr->size() < out_shape.size()) { - in_shape_ptr->insert(in_shape_ptr->end(), 1); - } - DebugDims("constant : ", ConvertCudaDims(*in_shape_ptr)); - return; -} - -void AlignShapeRank(std::vector *in_shape_ptr, const std::vector &out_shape) { - const size_t last_dim = in_shape_ptr->size() - 1; - const int in_rank = in_shape_ptr->size(); - int index = out_shape.size() - 1; - for (; index >= 0; index--) { - if (out_shape[index] == in_shape_ptr->at(last_dim)) { - break; - } - } - const int align_rank = index + 1; - if (index <= 0 || align_rank == in_rank) return; - for (int i = 0; i < index + 1 - in_rank; i++) { - in_shape_ptr->insert(in_shape_ptr->begin(), 1); - } -} - -nvinfer1::ITensor *ConvertTensorWithExpandDims(TensorRTContext *ctx, const TensorInfo &ms_tensor, - const std::vector &expect_shape, const std::string &op_name) { - if (ctx == nullptr || ctx->network() == nullptr) { - MS_LOG(ERROR) << "network is null for ConvertTensorWithExpandDims"; - return nullptr; - } - if (!ms_tensor.IsConst()) { - MS_LOG(ERROR) << "ConvertTensorWithExpandDims from a MSTensor with nullptr data"; - return nullptr; - } - auto origin_shape = ms_tensor.Shape(); - std::vector convert_shape(expect_shape); - BackComfortableAlign(&origin_shape, convert_shape); - if (ms_tensor.ElementNum() != - std::accumulate(origin_shape.begin(), origin_shape.end(), 1, std::multiplies())) { - MS_LOG(ERROR) << "ExpandDims failed for " << op_name; - return nullptr; - } - nvinfer1::Dims dims = ConvertCudaDims(origin_shape); - if (dims.nbDims == -1) { - MS_LOG(ERROR) << "ConvertCudaDims failed for " << op_name; - return nullptr; - } - nvinfer1::DataType data_type = ConvertDataType(ms_tensor.DataType()); - nvinfer1::Weights weights{data_type, ms_tensor.Data(), ms_tensor.ElementNum()}; - nvinfer1::IConstantLayer *constant_tensor = ctx->network()->addConstant(dims, weights); - if (constant_tensor == nullptr) { - MS_LOG(ERROR) << "create constant_tensor failed."; - return nullptr; - } - ctx->RegisterLayer(constant_tensor, ms_tensor.Name() + "_" + op_name); - auto tensor_ptr = constant_tensor->getOutput(0); - return tensor_ptr; -} - -nvinfer1::ITensor *ConvertConstantTensor1D(TensorRTContext *ctx, int *weights_vec, nvinfer1::DataType data_type) { - constexpr int nchw_dims_count = 4; - nvinfer1::Weights weights{data_type, weights_vec, nchw_dims_count}; - nvinfer1::Dims dims; - dims.nbDims = 1; - dims.d[0] = nchw_dims_count; - nvinfer1::IConstantLayer *constant_tensor = ctx->network()->addConstant(dims, weights); - if (constant_tensor == nullptr) { - MS_LOG(ERROR) << "create constant_tensor failed."; - return nullptr; - } - return constant_tensor->getOutput(0); -} - -nvinfer1::ITensor *ConvertConstantTensorWithDims(TensorRTContext *ctx, const TensorInfo &ms_tensor, - const std::vector &expect_shape, const std::string &op_name) { - nvinfer1::ITensor *constant_input{nullptr}; - std::string tensor_name = op_name + "_" + ms_tensor.Name(); - if (ms_tensor.Shape().size() == 0 || ms_tensor.ElementNum() == 1) { - constant_input = - lite::ConvertScalarToITensor(ctx, expect_shape.size(), ms_tensor, ms_tensor.DataType(), tensor_name); - if (constant_input == nullptr) { - MS_LOG(ERROR) << "create Itensor from scalar tensor failed: " << tensor_name; - return nullptr; - } - } else if (ms_tensor.Shape().size() == expect_shape.size()) { - constant_input = lite::ConvertConstantTensor(ctx, ms_tensor, tensor_name); - if (constant_input == nullptr) { - MS_LOG(ERROR) << "create Itensor from constant tensor failed: " << tensor_name; - return nullptr; - } - } else if (ms_tensor.ElementNum() >= 1) { - constant_input = ConvertTensorWithExpandDims(ctx, ms_tensor, expect_shape, tensor_name); - if (constant_input == nullptr) { - MS_LOG(ERROR) << "create Itensor from ConvertTensorWithExpandDims failed: " << tensor_name; - return nullptr; - } - } else { - MS_LOG(ERROR) << "const tensor value needs check: " << tensor_name; - } - return constant_input; -} - -nvinfer1::Weights TransposeWeight2D(const TensorInfo &ms_tensor, void **pack_weight) { - // usage notice: malloc addr saved to pack_weight, save pack_weight ptr and free it when deconstruct - nvinfer1::Weights weights{}; - weights.count = ms_tensor.ElementNum(); - auto weight_shape = ms_tensor.Shape(); - if (weight_shape.size() != DIMENSION_2D) { - MS_LOG(ERROR) << ms_tensor.Name() << " dims is " << weight_shape.size(); - return weights; - } - if (!ms_tensor.IsConst()) { - MS_LOG(ERROR) << ms_tensor.Name() << " has null data"; - return weights; - } - void *pack_weight_tmp = malloc(ms_tensor.DataSize()); - if (pack_weight_tmp == nullptr) { - MS_LOG(ERROR) << "Malloc buffer failed."; - return weights; - } - *pack_weight = pack_weight_tmp; - weights.values = pack_weight_tmp; - - int row = weight_shape[0]; - int col = weight_shape[1]; - - switch (ms_tensor.DataType()) { - case DataType::kNumberTypeFloat16: { - weights.type = nvinfer1::DataType::kHALF; - auto src = static_cast(ms_tensor.Data()); - auto dst = static_cast(pack_weight_tmp); - for (int r = 0; r < row; ++r) { - for (int c = 0; c < col; ++c) { - dst[c * row + r] = src[r * col + c]; - } - } - break; - } - case DataType::kNumberTypeFloat32: { - weights.type = nvinfer1::DataType::kFLOAT; - auto dst = static_cast(pack_weight_tmp); - auto src = static_cast(ms_tensor.Data()); - for (int r = 0; r < row; ++r) { - for (int c = 0; c < col; ++c) { - dst[c * row + r] = src[r * col + c]; - } - } - break; - } - default: { - MS_LOG(ERROR) << ms_tensor.Name() << " has unsupported tensor datatype for transpose data : " - << static_cast(ms_tensor.DataType()); - } - } - return weights; -} - -nvinfer1::Weights ConvertWeight(const TensorInfo &ms_tensor) { - nvinfer1::Weights weights{}; - weights.type = ConvertDataType(ms_tensor.DataType()); - weights.values = ms_tensor.Data(); - weights.count = ms_tensor.ElementNum(); - if (weights.values == nullptr) { - MS_LOG(ERROR) << "ConvertWeight from a MSTensor with nullptr data"; - } - return weights; -} - -nvinfer1::ITensor *TRTTensorCast(TensorRTContext *ctx, nvinfer1::ITensor *trt_tensor, nvinfer1::DataType data_type, - const std::string &name) { -#if TRT_VERSION_GE(7, 2) - data_type = data_type == nvinfer1::DataType::kBOOL ? nvinfer1::DataType::kINT32 : data_type; - if (data_type == nvinfer1::DataType::kINT32 && trt_tensor->getType() == nvinfer1::DataType::kFLOAT) { - auto plugin = std::make_shared(name, data_type); - nvinfer1::ITensor *inputTensors[] = {trt_tensor}; - nvinfer1::IPluginV2Layer *cast_layer = ctx->network()->addPluginV2(inputTensors, 1, *plugin); - if (cast_layer == nullptr) { - MS_LOG(ERROR) << "cast_layer is nullptr!"; - return nullptr; - } - cast_layer->setName(name.c_str()); - nvinfer1::ITensor *cast_out = cast_layer->getOutput(0); - cast_out->setName((name + "_output").c_str()); - return cast_out; - } - auto cast_layer = ctx->network()->addIdentity(*trt_tensor); -#else - auto plugin = std::make_shared(name, data_type); - nvinfer1::ITensor *inputTensors[] = {trt_tensor}; - nvinfer1::IPluginV2Layer *cast_layer = ctx->network()->addPluginV2(inputTensors, 1, *plugin); -#endif - if (cast_layer == nullptr) { - MS_LOG(ERROR) << "create cast layer failed for: " << name; - return nullptr; - } -#if TRT_VERSION_GE(7, 2) - cast_layer->setOutputType(0, data_type); -#endif - cast_layer->setName(name.c_str()); - nvinfer1::ITensor *cast_out = cast_layer->getOutput(0); - cast_out->setName((name + "_output").c_str()); - return cast_out; -} - -int SetCudaDevice(std::shared_ptr device_info_) { - return SetCudaDevice(static_cast(device_info_->GetDeviceID())); -} - -int SetCudaDevice(int device_id) { - int device = 0; - auto ret = cudaGetDevice(&device); - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaGetDevice failed, device is untrustable. error code: " << ret; - return RET_ERROR; - } - int set_device_id = device_id; - int deviceCnt = 0; - - ret = cudaGetDeviceCount(&deviceCnt); - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaGetDeviceCount failed."; - return RET_ERROR; - } - - if (set_device_id > deviceCnt - 1) { - MS_LOG(ERROR) << "invalid input device id as " << set_device_id << " for current device count " << deviceCnt; - return RET_ERROR; - } - if (device != set_device_id) { - ret = cudaSetDevice(set_device_id); - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaSetDevice failed, error code: " << ret; - return RET_ERROR; - } - } - if (cudaGetDevice(&device) != cudaSuccess) { - MS_LOG(ERROR) << "cudaGetDevice failed, device is untrustable."; - return RET_ERROR; - } - MS_LOG(DEBUG) << "cuda is running on device: " << device; - return RET_OK; -} - -Format GetOutputFormat(Format input_format, nvinfer1::Permutation perm) { - if (input_format == Format::NHWC) { - if (perm.order[kNHWC_N] == kNHWC_N && perm.order[kNHWC_H] == kNHWC_C && perm.order[kNHWC_W] == kNHWC_W && - perm.order[kNHWC_C] == kNHWC_H) { - return Format::NCHW; - } - } else if (input_format == Format::NCHW) { - if (perm.order[kNCHW_N] == kNCHW_N && perm.order[kNCHW_C] == kNCHW_H && perm.order[kNCHW_H] == kNCHW_W && - perm.order[kNCHW_W] == kNCHW_C) { - return Format::NHWC; - } - } - MS_LOG(WARNING) << "transpose out format needs to check for " << input_format; - return input_format; -} -int ConvertAxisFromNHWC2NCHW(int nhwc_axis) { - return nhwc_axis; - // N0H1W2C3->N0C1H2W3 - if (nhwc_axis > kNHWC_C) { - return nhwc_axis; - } - switch (nhwc_axis) { - case kNHWC_N: - return kNCHW_N; - case kNHWC_H: - return kNCHW_H; - case kNHWC_W: - return kNCHW_W; - case kNHWC_C: - return kNCHW_C; - default: - MS_LOG(ERROR) << "invalid input axis for nhwc: " << nhwc_axis; - } - return nhwc_axis; -} - -void PackNHWCToNCHWFp16(const void *src, void *dst, size_t batches, size_t plane, size_t channel, size_t task_id, - size_t thread_count) { - size_t hw8 = plane / C8NUM; - size_t task_start = 0; - size_t task_end = plane; - if (thread_count > 0) { - size_t offset_hw = UP_DIV(hw8, thread_count) * C8NUM; - task_start = offset_hw * task_id; - size_t count = plane - task_start; - if (count == 0) { - return; - } - task_end = (task_id + 1) == thread_count ? plane : MSMIN(plane, task_start + offset_hw); - hw8 = task_start + ((task_end - task_start) >= offset_hw ? offset_hw : 0); - } else { - hw8 *= C8NUM; - } - size_t c8 = channel / C8NUM * C8NUM; - size_t batch = plane * channel; - for (size_t n = 0; n < batches; n++) { - const uint16_t *src_batch = static_cast(src) + n * batch; - uint16_t *dst_batch = static_cast(dst) + n * batch; - size_t hw = task_start; - for (; hw < hw8; hw += C8NUM) { - size_t c = 0; - for (; c < c8; c += C8NUM) { - const uint16_t *src_ptr = src_batch + hw * channel + c; - uint16_t *dst_ptr = dst_batch + c * plane + hw; - for (size_t tr = 0; tr < C8NUM; tr++) { - for (size_t tc = 0; tc < C8NUM; tc++) { - dst_ptr[tc * plane + tr] = src_ptr[tr * channel + tc]; - } - } - } - for (; c < channel; c++) { - const uint16_t *src_ptr = src_batch + hw * channel + c; - uint16_t *dst_ptr = dst_batch + c * plane + hw; - for (size_t i = 0; i < C8NUM; i++) { - dst_ptr[i] = src_ptr[i * channel]; - } - } - } - for (; hw < task_end; hw++) { - const uint16_t *src_ptr = src_batch + hw * channel; - uint16_t *dst_ptr = dst_batch + hw; - for (size_t i = 0; i < channel; i++) { - dst_ptr[i * plane] = src_ptr[i]; - } - } - } -} -std::string GetTensorFormat(nvinfer1::ITensor *trt_tensor, mindspore::Format format, bool is_same) { - nvinfer1::Dims dims = trt_tensor->getDimensions(); - std::string is_same_string = is_same ? " is same with ms tensor " : " is different from ms tensor "; - std::string out_string = "tensor " + std::string(trt_tensor->getName()) + ": format (NHWC:1, NCHW:0) is " + - std::to_string(static_cast(format)) + is_same_string + ", dims is "; - std::string dim_string = "["; - for (int i = 0; i < dims.nbDims; i++) { - dim_string += std::to_string(dims.d[i]); - if (i != dims.nbDims - 1) { - dim_string += ", "; - } - } - dim_string += "]"; - out_string += dim_string; - return out_string; -} - -std::string GetTensorFormat(ITensorHelper tensor_helper) { - return GetTensorFormat(tensor_helper.trt_tensor_, tensor_helper.format_, tensor_helper.same_format_); -} - -std::string GetTensorFormat(nvinfer1::ITensor *trt_tensor) { return GetTensorFormat(trt_tensor, Format::NHWC, true); } - -std::experimental::optional TryConvertTRTReduceMode(ReduceMode mode) { - std::map reduce_ops_ = { - {ReduceMode::Reduce_Mean, nvinfer1::ReduceOperation::kAVG}, - {ReduceMode::Reduce_Max, nvinfer1::ReduceOperation::kMAX}, - {ReduceMode::Reduce_Min, nvinfer1::ReduceOperation::kMIN}, - {ReduceMode::Reduce_Prod, nvinfer1::ReduceOperation::kPROD}, - {ReduceMode::Reduce_L2, nvinfer1::ReduceOperation::kSUM}, - {ReduceMode::Reduce_Sum, nvinfer1::ReduceOperation::kSUM}, - }; - return reduce_ops_.find(mode) != reduce_ops_.end() - ? std::experimental::optional(reduce_ops_[mode]) - : std::experimental::nullopt; -} -int PreprocessInputs2SameDim(TensorRTContext *ctx, ITensorHelper input_tensor_helper, - ITensorHelper *out_tensor_helper) { - if (input_tensor_helper.trt_tensor_ == nullptr) { - MS_LOG(ERROR) << "input trt tensor is nullptr"; - return RET_ERROR; - } - out_tensor_helper->trt_tensor_ = input_tensor_helper.trt_tensor_; - out_tensor_helper->format_ = input_tensor_helper.format_; - out_tensor_helper->same_format_ = true; - if (input_tensor_helper.trt_tensor_->getDimensions().nbDims == DIMENSION_4D && !input_tensor_helper.same_format_) { - if (input_tensor_helper.format_ == Format::NCHW) { - // transpose: NCHW->NHWC - nvinfer1::IShuffleLayer *transpose_layer_in = NCHW2NHWC(ctx, *input_tensor_helper.trt_tensor_); - if (transpose_layer_in == nullptr) { - MS_LOG(ERROR) << "op action convert failed"; - return RET_ERROR; - } - transpose_layer_in->setName( - (std::string(input_tensor_helper.trt_tensor_->getName()) + "_input_transpose2NHWC").c_str()); - out_tensor_helper->trt_tensor_ = transpose_layer_in->getOutput(0); - out_tensor_helper->format_ = Format::NHWC; - } else { - // transpose: NHWC->NCHW - nvinfer1::IShuffleLayer *transpose_layer_in = NHWC2NCHW(ctx, *input_tensor_helper.trt_tensor_); - if (transpose_layer_in == nullptr) { - MS_LOG(ERROR) << "op action convert failed"; - return RET_ERROR; - } - transpose_layer_in->setName( - (std::string(input_tensor_helper.trt_tensor_->getName()) + "_input_transpose2NCHW").c_str()); - out_tensor_helper->trt_tensor_ = transpose_layer_in->getOutput(0); - out_tensor_helper->format_ = Format::NCHW; - } - } - return RET_OK; -} - -int GetDimsVolume(const nvinfer1::Dims &dims) { - if (dims.nbDims <= 0) { - return 0; - } - return std::accumulate(dims.d, dims.d + dims.nbDims, 1, std::multiplies()); -} - -int GetDimsVolume(const std::vector &shape) { - if (shape.size() == 0) { - return 0; - } - return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); -} - -std::experimental::optional SqueezeDims(const nvinfer1::Dims &in_dims, int pos) { - if (in_dims.nbDims <= 1) { - MS_LOG(ERROR) << "invalid shape size: " << in_dims.nbDims << "for squeeze."; - return {}; - } - nvinfer1::Dims out_dims; - int i = 0; - for (int j = 0; j <= in_dims.nbDims; ++j) { - if (j != pos) { - out_dims.d[i++] = in_dims.d[j]; - } - } - out_dims.nbDims = in_dims.nbDims - 1; - return std::experimental::optional(out_dims); -} - -std::experimental::optional UnsqueezeDims(const nvinfer1::Dims &in_dims, int pos, int val) { - if (in_dims.nbDims >= in_dims.MAX_DIMS) { - MS_LOG(ERROR) << "invalid shape size: " << in_dims.nbDims << "for unsqueeze."; - return {}; - } - nvinfer1::Dims out_dims; - int i = 0; - for (int j = 0; j <= in_dims.nbDims; ++j) { - if (j == pos) { - out_dims.d[j] = val; - } else { - out_dims.d[j] = in_dims.d[i++]; - } - } - out_dims.nbDims = in_dims.nbDims + 1; - return std::experimental::optional(out_dims); -} - -int ParseData2Vector(const TensorInfo &ms_tensor, std::vector *dst) { - if (!ms_tensor.IsConst()) { - MS_LOG(ERROR) << "ignore tensor: " << ms_tensor.Name(); - return RET_ERROR; - } - dst->clear(); - dst->resize(ms_tensor.ElementNum()); - switch (ms_tensor.DataType()) { - case DataType::kNumberTypeInt64: { - Data2Vector(dst, ms_tensor.Data()); - break; - } - case DataType::kNumberTypeInt32: { - Data2Vector(dst, ms_tensor.Data()); - break; - } - default: { - MS_LOG(ERROR) << ms_tensor.Name() << " has more datatype to parse"; - return RET_ERROR; - } - } - return RET_OK; -} - -nvinfer1::ITensor *ExpandDim(TensorRTContext *ctx, nvinfer1::ITensor *input_tensor, int axis) { - // input has to prepocess to nchw - auto input_dims = input_tensor->getDimensions(); - nvinfer1::IShuffleLayer *shuffle_layer = ctx->network()->addShuffle(*input_tensor); - // if expand dim not at last dim and shape is dynamic, change to expanddim at last dim and transpose - bool special_expand = false; - for (int i = 0; i < input_dims.nbDims; i++) { - special_expand = special_expand || input_dims.d[i] == -1; - } - special_expand = special_expand && (axis != -1 && axis != input_dims.nbDims); - - if (special_expand) { - std::vector new_shape; - for (int i = 0; i < input_dims.nbDims; i++) { - new_shape.push_back(input_dims.d[i] == -1 ? 0 : input_dims.d[i]); - } - new_shape.push_back(1); - nvinfer1::Dims new_dims = ConvertCudaDims(new_shape); - if (new_dims.nbDims == -1) { - return nullptr; - } - - shuffle_layer->setReshapeDimensions(new_dims); - // transpose - nvinfer1::Permutation perm{}; - for (int i = 0; i < new_dims.nbDims; i++) { - if (i < axis) { - perm.order[i] = i; - } else if (i == axis) { - perm.order[i] = new_dims.nbDims - 1; - } else { - perm.order[i] = i - 1; - } - } - nvinfer1::IShuffleLayer *trans_layer = ctx->network()->addShuffle(*shuffle_layer->getOutput(0)); - if (trans_layer == nullptr) { - MS_LOG(ERROR) << "add transpose layer failed for special expand dims op "; - return nullptr; - } - trans_layer->setFirstTranspose(perm); - return trans_layer->getOutput(0); - } else { - std::vector new_shape; - for (int i = 0; i < input_dims.nbDims; i++) { - if (axis == i) { - new_shape.push_back(1); - } - new_shape.push_back(input_dims.d[i] == -1 ? 0 : input_dims.d[i]); - } - if (axis == -1 || axis == input_dims.nbDims) { - new_shape.push_back(1); - } - nvinfer1::Dims new_dims = ConvertCudaDims(new_shape); - if (new_dims.nbDims == -1) { - return nullptr; - } - shuffle_layer->setReshapeDimensions(new_dims); - return shuffle_layer->getOutput(0); - } -} - -nvinfer1::ITensor *Broadcast(TensorRTContext *ctx, nvinfer1::ITensor *input, nvinfer1::ITensor *shape) { - int rank = shape->getDimensions().d[0]; - - nvinfer1::Dims starts{rank}; - std::fill(starts.d, starts.d + rank, 0); - nvinfer1::Dims strides{rank}; - std::fill(strides.d, strides.d + rank, 1); - - auto slice_layer = ctx->network()->addSlice(*input, starts, {}, strides); - slice_layer->setMode(nvinfer1::SliceMode::kWRAP); - const int INPUT2 = 2; - slice_layer->setInput(INPUT2, *shape); - - auto shuffler_output = slice_layer->getOutput(0); - if (shuffler_output == nullptr) { - MS_LOG(ERROR) << "add slice layer failed"; - } - return shuffler_output; -} - -nvinfer1::ITensor *Reshape(TensorRTContext *ctx, nvinfer1::ITensor *input, const std::vector &shape) { - return Reshape(ctx, input, ConvertCudaDims(shape)); -} - -nvinfer1::ITensor *Reshape(TensorRTContext *ctx, nvinfer1::ITensor *input, const nvinfer1::Dims &shape) { - auto reshape_layer = ctx->network()->addShuffle(*input); - if (reshape_layer == nullptr) { - MS_LOG(ERROR) << "add reshape_layer failed"; - return nullptr; - } - reshape_layer->setReshapeDimensions(shape); - return reshape_layer->getOutput(0); -} - -void DebugDims(const std::string &key, const nvinfer1::Dims &dims) { - MS_LOG(DEBUG) << key << ":" << dims.nbDims; - for (int i = 0; i != dims.nbDims; ++i) { - MS_LOG(DEBUG) << dims.d[i]; - } -} - -template <> -nvinfer1::DataType GetNvinferDataType() { - return nvinfer1::DataType::kFLOAT; -} - -template <> -nvinfer1::DataType GetNvinferDataType() { - return nvinfer1::DataType::kINT32; -} - -template nvinfer1::DataType GetNvinferDataType(); -template nvinfer1::DataType GetNvinferDataType(); - -#ifdef PROFILER_ -void SimpleProfiler::reportLayerTime(const char *layerName, float ms) noexcept { - mProfile_[layerName].count++; - mProfile_[layerName].time += ms; - if (std::find(mLayerNames_.begin(), mLayerNames_.end(), layerName) == mLayerNames_.end()) { - mLayerNames_.push_back(layerName); - } -} - -SimpleProfiler::SimpleProfiler(const char *name, const std::vector &srcProfilers) : mName_(name) { - for (const auto &srcProfiler : srcProfilers) { - for (const auto &rec : srcProfiler.mProfile_) { - auto it = mProfile_.find(rec.first); - if (it == mProfile_.end()) { - mProfile_.insert(rec); - } else { - it->second.time += rec.second.time; - it->second.count += rec.second.count; - } - } - } -} - -std::ostream &operator<<(std::ostream &out, const SimpleProfiler &value) { - out << "========== " << value.mName_ << " profile ==========" << std::endl; - float totalTime = 0; - std::string layerNameStr = "TensorRT layer name"; - int maxLayerNameLength = std::max(static_cast(layerNameStr.size()), 70); - for (const auto &elem : value.mProfile_) { - totalTime += elem.second.time; - maxLayerNameLength = std::max(maxLayerNameLength, static_cast(elem.first.size())); - } - - auto old_settings = out.flags(); - auto old_precision = out.precision(); - // Output header - { - out << std::setw(maxLayerNameLength) << layerNameStr << " "; - out << std::setw(C12NUM) << "Runtime, " - << "%" - << " "; - out << std::setw(C12NUM) << "Invocations" - << " "; - out << std::setw(C12NUM) << "Runtime, ms" << std::endl; - } - for (size_t i = 0; i < value.mLayerNames_.size(); i++) { - const std::string layerName = value.mLayerNames_[i]; - auto elem = value.mProfile_.at(layerName); - out << std::setw(maxLayerNameLength) << layerName << " "; - out << std::setw(C12NUM) << std::fixed << std::setprecision(1) << (elem.time * 100.0F / totalTime) << "%" - << " "; - out << std::setw(C12NUM) << elem.count << " "; - out << std::setw(C12NUM) << std::fixed << std::setprecision(C2NUM) << elem.time << std::endl; - } - out.flags(old_settings); - out.precision(old_precision); - out << "========== " << value.mName_ << " total runtime = " << totalTime << " ms ==========" << std::endl; - - return out; -} -#endif // PROFILER_ -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_utils.h b/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_utils.h deleted file mode 100644 index b5787b4e90567f504f18ad6795a34a6d1c1521fd..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/tensorrt/tensorrt_utils.h +++ /dev/null @@ -1,219 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_TENSORRT_UTILS_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_TENSORRT_UTILS_H_ -#include -#include -#include -#include -#include -#include -#include -#include "src/extendrt/delegate/tensorrt/tensorrt_context.h" -#include "src/extendrt/delegate/tensorrt/tensor_info.h" -#include "src/extendrt/delegate/tensorrt/cuda_impl/cublas_utils.h" -#include "ir/dtype/type_id.h" -#include "schema/ops_generated.h" -#include "nnacl/pack.h" -#include "include/api/context.h" -#include "mindapi/base/types.h" - -#define kNCHW_N 0 -#define kNCHW_C 1 -#define kNCHW_H 2 -#define kNCHW_W 3 -#define kNHWC_N 0 -#define kNHWC_H 1 -#define kNHWC_W 2 -#define kNHWC_C 3 - -namespace mindspore::lite { -#define TRT_VERSION_GE(major, minor) \ - (NV_TENSORRT_MAJOR > major) || ((NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR >= minor)) -#define TRT_VERSION_LS(major, minor) \ - (NV_TENSORRT_MAJOR < major) || ((NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR < minor)) -struct ActivationParams { - nvinfer1::ActivationType activation_type; - bool has_alpha; - float alpha; - bool has_beta; - float beta; -}; - -typedef union float32_bits { - unsigned int u; - float f; -} float32_bits; - -#ifdef PROFILER_ -struct SimpleProfiler : public nvinfer1::IProfiler { - struct Record { - float time{0}; - int count{0}; - }; - - virtual void reportLayerTime(const char *layerName, float ms) noexcept; - - explicit SimpleProfiler(const char *name, - const std::vector &srcProfilers = std::vector()); - - friend std::ostream &operator<<(std::ostream &out, const SimpleProfiler &value); - - private: - std::string mName_; - std::vector mLayerNames_; - std::map mProfile_; -}; -#endif - -// Convert Tensor data to Cuda dims. -nvinfer1::Dims ConvertCudaDims(const std::vector &data); - -nvinfer1::Dims ConvertCudaDims(int data, size_t size); - -nvinfer1::Dims ConvertCudaDims(const TensorInfo &ms_tensor); - -std::string CudaDimsAsString(const nvinfer1::Dims &dims); - -std::vector ConvertTensorAsIntVector(const TensorInfo &ms_tensor); - -bool SameDims(nvinfer1::Dims dims, const std::vector &shape); - -std::vector ConvertMSShape(const nvinfer1::Dims dims); - -std::vector NHWC2NCHW(std::vector nhwc_shape); - -nvinfer1::DataType ConvertDataType(DataType type_id); - -cudaDataType ConvertDataType(nvinfer1::DataType type_id); - -nvinfer1::IShuffleLayer *NHWC2NCHW(TensorRTContext *ctx, const nvinfer1::ITensor &input); - -nvinfer1::IShuffleLayer *NCHW2NHWC(TensorRTContext *ctx, const nvinfer1::ITensor &input); - -std::experimental::optional TryConvertActivationType(ActivationType activation_type); - -nvinfer1::ITensor *ConvertConstantTensor(TensorRTContext *ctx, const TensorInfo &ms_tensor, const std::string &op_name); - -nvinfer1::ITensor *ConvertTensorWithExpandDims(TensorRTContext *ctx, const TensorInfo &ms_tensor, - const std::vector &expect_shape, const std::string &op_name); - -nvinfer1::ITensor *ConvertScalarToITensor(TensorRTContext *ctx, size_t shape_size, const void *value, - const DataType data_type, const std::string &op_name); - -nvinfer1::ITensor *ConvertScalarToITensor(TensorRTContext *ctx, size_t shape_size, const TensorInfo &ms_tensor, - const DataType data_type, const std::string &op_name); - -nvinfer1::ITensor *ConvertConstantTensorWithDims(TensorRTContext *ctx, const TensorInfo &ms_tensor, - const std::vector &expect_shape, const std::string &op_name); - -nvinfer1::Weights TransposeWeight2D(const TensorInfo &ms_tensor, void **pack_weight); - -nvinfer1::Weights ConvertWeight(const TensorInfo &ms_tensor); - -nvinfer1::ITensor *TRTTensorCast(TensorRTContext *ctx, nvinfer1::ITensor *tensor, nvinfer1::DataType data_type, - const std::string &name); - -int SetCudaDevice(std::shared_ptr device_info_); - -int SetCudaDevice(int device_id); - -Format GetOutputFormat(Format input_format, nvinfer1::Permutation perm); - -int ConvertAxisFromNHWC2NCHW(int nhwc_axis); - -void PackNHWCToNCHWFp16(const void *src, void *dst, size_t batch, size_t plane, size_t channel, size_t task_id, - size_t thread_count); - -std::string GetTensorFormat(nvinfer1::ITensor *trt_tensor, mindspore::Format format, bool is_same); - -std::string GetTensorFormat(ITensorHelper tensor_helper); - -std::string GetTensorFormat(nvinfer1::ITensor *trt_tensors); - -std::experimental::optional TryConvertTRTReduceMode(ReduceMode mode); - -int PreprocessInputs2SameDim(TensorRTContext *ctx, ITensorHelper input_tensor_helper, ITensorHelper *out_tensor_helper); - -int GetDimsVolume(const nvinfer1::Dims &dims); - -int GetDimsVolume(const std::vector &shape); - -std::experimental::optional SqueezeDims(const nvinfer1::Dims &in_dims, int pos); - -std::experimental::optional UnsqueezeDims(const nvinfer1::Dims &in_dims, int pos, int val); - -nvinfer1::ITensor *Reshape(TensorRTContext *ctx, nvinfer1::ITensor *input, const std::vector &shape); - -nvinfer1::ITensor *Reshape(TensorRTContext *ctx, nvinfer1::ITensor *input, const nvinfer1::Dims &shape); - -nvinfer1::ITensor *ConvertConstantTensor1D(TensorRTContext *ctx, int *weights_vec, nvinfer1::DataType data_type); - -int ParseData2Vector(const TensorInfo &ms_tensor, std::vector *dst); - -void DebugDims(const std::string &key, const nvinfer1::Dims &dims); - -nvinfer1::ITensor *ExpandDim(TensorRTContext *ctx, nvinfer1::ITensor *input_tensor, int axis); - -nvinfer1::ITensor *Broadcast(TensorRTContext *ctx, nvinfer1::ITensor *input, nvinfer1::ITensor *shape); - -template -nvinfer1::DataType GetNvinferDataType(); - -template -bool SameDims(const std::vector &shape1, const std::vector &shape2) { - if (shape1.size() != shape2.size()) { - return false; - } - for (size_t i = 0; i < shape1.size(); i++) { - if (std::abs(shape1[i] - shape2[i]) > 1e-6) { - return false; - } - } - return true; -} - -template -nvinfer1::Dims ConvertCudaDims(const std::vector &shape) { - nvinfer1::Dims dims{}; - dims.nbDims = -1; - if (!shape.empty() && shape.size() <= static_cast(dims.MAX_DIMS)) { - dims.nbDims = shape.size(); - for (int i = 0; i < dims.nbDims; i++) { - dims.d[i] = static_cast(shape[i]); - } - } else { - MS_LOG(INFO) << "ms shape is invalid!shape size: " << shape.size(); - } - return dims; -} - -inline size_t IntToSize(int u) { - if (u < 0) { - MS_LOG(WARNING) << "The int value(" << u << ") is less than 0."; - return SIZE_MAX; - } - return static_cast(u); -} -template -void Data2Vector(std::vector *dst, const void *src) { - auto src_ptr = static_cast(src); - for (size_t i = 0; i < dst->size(); i++) { - dst->at(i) = static_cast(src_ptr[i]); - } -} -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_TENSORRT_UTILS_H_ diff --git a/mindspore-lite/src/extendrt/delegate/type.h b/mindspore-lite/src/extendrt/delegate/type.h deleted file mode 100644 index 109fd1194720c7b88bcecf72cf4ec1324a9a02a7..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/delegate/type.h +++ /dev/null @@ -1,56 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TYPE_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TYPE_H_ - -#include -#include -#include "include/api/delegate_api.h" -#include "ir/func_graph.h" -#include "src/extendrt/kernel/base_kernel.h" -#include "extendrt/kernel/kernel_lib.h" - -namespace mindspore { -class ExtendDelegate : public IDelegate { - public: - ExtendDelegate() = default; - ~ExtendDelegate() override = default; - - void ReplaceNodes(const std::shared_ptr &graph) override { - // not implemented - } - - bool IsDelegateNode(const std::shared_ptr &node) override { - // not implemented - return false; - } - - std::shared_ptr CreateKernel(const std::shared_ptr &node) override { - // not implemented - return nullptr; - } - - virtual std::shared_ptr CreateKernel(const kernel::KernelSpec &spec, - const std::vector &inputs, - const std::vector &outputs, - const InferContext *ctx) const { - // not implemented - return nullptr; - } -}; -} // namespace mindspore - -#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TYPE_H_ diff --git a/mindspore-lite/src/extendrt/delegate_graph_executor.cc b/mindspore-lite/src/extendrt/delegate_graph_executor.cc index 89d27e1fe8cf8ed293fa4d9c25ff22eb555e42f0..81a4d2e51cd94f84989bb33454d6a3b6769db4c4 100644 --- a/mindspore-lite/src/extendrt/delegate_graph_executor.cc +++ b/mindspore-lite/src/extendrt/delegate_graph_executor.cc @@ -40,7 +40,7 @@ bool GraphSinkDelegate::IsDelegateNode(const std::shared_ptr &node) { return false; } -std::shared_ptr GraphExecutorDelegate::CreateKernel(const std::shared_ptr &node) { +std::shared_ptr GraphExecutorDelegate::CreateKernel(const std::shared_ptr &node) { if (!IsDelegateNode(node)) { return nullptr; } diff --git a/mindspore-lite/src/extendrt/delegate_graph_executor.h b/mindspore-lite/src/extendrt/delegate_graph_executor.h index d2778a96a693fd0722a1e81480f09a62c890f003..858ec041bbed9c0ff0d90203e215d49985019938 100644 --- a/mindspore-lite/src/extendrt/delegate_graph_executor.h +++ b/mindspore-lite/src/extendrt/delegate_graph_executor.h @@ -23,17 +23,19 @@ #include "runtime/hardware/device_context.h" #include "tools/common/func_graph_subgraph.h" #include "common/kernel.h" -#include "extendrt/session/lite_graph_executor.h" + +#include "src/extendrt/session/lite_graph_executor.h" +#include "src/extendrt/subgraph_kernel.h" + namespace mindspore { // Graph sink delegate, the whole FuncGraph as a node to execute. -class GraphSinkDelegate : public IDelegate { +class GraphSinkDelegate { public: - GraphSinkDelegate(const std::vector &inputs, const std::vector &outputs) - : IDelegate(inputs, outputs) {} + GraphSinkDelegate(const std::vector &inputs, const std::vector &outputs) {} virtual ~GraphSinkDelegate() = default; - void ReplaceNodes(const std::shared_ptr &graph) override; + void ReplaceNodes(const std::shared_ptr &graph); - bool IsDelegateNode(const std::shared_ptr &node) override; + bool IsDelegateNode(const std::shared_ptr &node); protected: FuncGraphPtr sink_graph_; @@ -47,7 +49,7 @@ class GraphExecutorDelegate : public GraphSinkDelegate { std::shared_ptr executor) : GraphSinkDelegate(inputs, outputs), executor_(executor) {} virtual ~GraphExecutorDelegate() = default; - std::shared_ptr CreateKernel(const std::shared_ptr &node) override; + std::shared_ptr CreateKernel(const std::shared_ptr &node); private: const std::shared_ptr executor_; diff --git a/mindspore-lite/src/extendrt/executor.cc b/mindspore-lite/src/extendrt/executor.cc deleted file mode 100644 index c923695d1508b86a1e1b70e3f45df78635c7b9b2..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/executor.cc +++ /dev/null @@ -1,432 +0,0 @@ -/** - * Copyright 2020-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "backend/common/session/executor.h" -#include "backend/common/session/executor_manager.h" -#include -#include -#include -#include "runtime/device/kernel_runtime_manager.h" -#include "include/common/utils/comm_manager.h" -#include "include/common/utils/scoped_long_running.h" - -namespace mindspore { -namespace session { -namespace { -void GetNeedNotifyTensors(const VectorRef *outputs, std::set *result) { - MS_EXCEPTION_IF_NULL(outputs); - MS_EXCEPTION_IF_NULL(result); - for (auto &item : *outputs) { - if (utils::isa(item)) { - auto vector_ref = utils::cast(item); - GetNeedNotifyTensors(&vector_ref, result); - } else if (utils::isa(item)) { - auto tensor = utils::cast(item); - result->emplace(tensor); - } - } -} - -bool TensorInVector(const VectorRef *outputs) { - MS_EXCEPTION_IF_NULL(outputs); - for (auto &item : *outputs) { - if (utils::isa(item)) { - auto vector_ref = utils::cast(item); - if (TensorInVector(&vector_ref)) { - return true; - } - } else if (utils::isa(item)) { - return true; - } - } - return false; -} - -bool IsTaskReady(const std::shared_ptr &task) { - MS_EXCEPTION_IF_NULL(task); - for (auto &input : task->input_need_wait_tensors_) { - MS_EXCEPTION_IF_NULL(input); - if (input->NeedWait()) { - return false; - } - } - auto session = task->session_; - MS_EXCEPTION_IF_NULL(session); - auto graph = session->GetGraph(task->graph_id_); - if (graph != nullptr) { - return graph->IsPreGraphFinished(); - } - return true; -} - -void WaitLockedInputs(const std::shared_ptr &task) { - bool need_lock = false; - for (auto &tensor : task->input_tensors_) { - if (tensor->NeedWait()) { - if (tensor->IsGraphOutput()) { - task->input_need_wait_tensors_.emplace_back(tensor); - } else { - need_lock = true; - } - } - } - if (need_lock) { - mindspore::ScopedLongRunning long_running; - for (auto &input_tensor : task->input_tensors_) { - if (input_tensor->NeedWait() && !input_tensor->IsGraphOutput()) { - MsException::Instance().CheckException(); - input_tensor->Wait(); - } - } - MsException::Instance().CheckException(); - } - // need lock input parameters for optimizer - for (auto &need_lock_tensor : task->input_need_lock_tensors_) { - need_lock_tensor->SetNeedWait(true); - } -} -} // namespace - -void CompileNodesTask::Run() { - MS_EXCEPTION_IF_NULL(session_); - MS_EXCEPTION_IF_NULL(segment_); - graph_id_ = session_->CompileGraphImpl(segment_->nodes_, output_nodes_); -} - -void CompileGraphTask::Run() { - MS_EXCEPTION_IF_NULL(session_); - graph_id_ = session_->CompileGraphImpl(NOT_NULL(func_graph_)); -} - -void BuildGraphTask::Run() { - MS_EXCEPTION_IF_NULL(session_); - session_->BuildGraphImpl(graph_id_); -} - -void RunGraphTask::Run() { - MS_EXCEPTION_IF_NULL(session_); - MS_LOG(INFO) << "Start run graph " << graph_id_; - auto graph = session_->GetGraph(graph_id_); - if (graph == nullptr) { - MS_LOG(ERROR) << "Invalid graph id " << graph_id_; - return; - } - graph->ResetGraphRunningStatus(); - if (AnfUtils::UseMemScheduler()) { - graph->SetOutputNodeToTensor(node_to_tensor_); - } - try { - session_->LoadInputs(graph_id_, input_tensors_); - session_->RunGraphImpl(graph_id_, input_tensors_, &outputs_); - std::map new_to_old_device_address; - session_->UpdateOutputTensors(&outputs_, tensor_to_node_, &new_to_old_device_address); - } catch (const std::exception &e) { - session_->ReportErrorMessage(); - ExecutorManager::Instance().OnEvent(ExecutorEvent::kException); - MsException::Instance().SetException(); - } - MS_LOG(INFO) << "End run graph " << graph_id_; - graph->OnRunGraphFinished(); - std::set need_notify_tensors(input_need_lock_tensors_.begin(), input_need_lock_tensors_.end()); - GetNeedNotifyTensors(&outputs_, &need_notify_tensors); - for (auto &tensor : need_notify_tensors) { - if (tensor != nullptr) { - tensor->SetNeedWait(false); - } - } - ExecutorManager::Instance().OnEvent(ExecutorEvent::kRunGraphFinished); -} - -void CreateCommGroupTask::Run() { result_ = CommManager::GetInstance().CreateGroupSync(group_name_, ranks_); } - -void DestroyCommGroupTask::Run() { result_ = CommManager::GetInstance().DestroyGroup(group_name_); } - -Executor::Executor(const std::string &device_name, uint32_t device_id) { - device_name_ = device_name; - device_id_ = device_id; - worker_ = std::make_shared(&Executor::WorkerLoop, this); -} - -Executor::~Executor() { - try { - WorkerJoin(); - } catch (const std::exception &e) { - MS_LOG(ERROR) << "Executor call destructor failed: " << e.what(); - } catch (...) { - MS_LOG(ERROR) << "Executor call destructor failed."; - } -} - -void Executor::WorkerJoin() { - // Avoid worker thread join itself which will cause deadlock - if (worker_->joinable() && worker_->get_id() != std::this_thread::get_id()) { - { - std::lock_guard lock(task_mutex_); - auto task = std::make_shared(); - ready_tasks_.push(task); - task_cond_var_.notify_all(); - } - worker_->join(); - } -} - -void Executor::WorkerLoop() { - while (true) { - std::shared_ptr task; - { - std::unique_lock lock(task_mutex_); - task_cond_var_.wait(lock, [this] { return !ready_tasks_.empty(); }); - task = ready_tasks_.front(); - ready_tasks_.pop(); - } - MS_EXCEPTION_IF_NULL(task); - enum TaskType task_type = task->type_; - bool task_sync_flag = task->sync_run_; - if (task_type == kExit) { - OnWorkerExit(); - return; - } - try { - if (task->session_ != nullptr) { - task->session_->SetThreadContext(); - } - task->Run(); - if (task->session_ != nullptr) { - task->session_->ReportWarningMessage(); - } - } catch (const std::exception &e) { - if (task->session_ != nullptr) { - task->session_->ReportErrorMessage(); - } - ExecutorManager::Instance().OnEvent(ExecutorEvent::kException); - MsException::Instance().SetException(); - } - { - std::lock_guard lock(done_task_mutex_); - done_tasks_.emplace_back(std::move(task)); - } - if (task_type != kRunGraph || task_sync_flag) { - std::lock_guard lock(task_mutex_); - sync_run_task_finished_ = true; - sync_cond_var_.notify_all(); - } - } -} - -std::vector> Executor::GetReadyTasksFromPendingList() { - std::vector> ready_tasks; - std::lock_guard lock(pending_task_mutex_); - for (auto iter = pending_tasks_.begin(); iter != pending_tasks_.end();) { - auto task = *iter; - if (IsTaskReady(task)) { - (void)ready_tasks.emplace_back(task); - iter = pending_tasks_.erase(iter); - } else { - ++iter; - } - } - return ready_tasks; -} - -void Executor::OnEvent(const ExecutorEvent &event) { - if (event == ExecutorEvent::kRunGraphFinished) { - OnRunGraphFinished(); - } else if (event == ExecutorEvent::kClear) { - OnClear(); - } else if (event == ExecutorEvent::kException) { - OnException(); - } -} - -void Executor::OnClear() { - { - mindspore::ScopedLongRunning long_running; - WorkerJoin(); - } - ClearDoneTasks(); -} - -void Executor::OnException() { - std::vector> done_tasks; - { - std::lock_guard lock(task_mutex_); - while (!ready_tasks_.empty()) { - (void)done_tasks.emplace_back(ready_tasks_.front()); - ready_tasks_.pop(); - } - } - { - std::lock_guard lock(pending_task_mutex_); - (void)std::copy(pending_tasks_.begin(), pending_tasks_.end(), std::back_inserter(done_tasks)); - pending_tasks_.clear(); - } - { - std::lock_guard lock(done_task_mutex_); - (void)done_tasks_.insert(done_tasks_.end(), done_tasks.begin(), done_tasks.end()); - } -} - -void Executor::OnRunGraphFinished() { - auto ready_tasks = GetReadyTasksFromPendingList(); - std::lock_guard lock(task_mutex_); - for (auto &task : ready_tasks) { - ready_tasks_.push(task); - } - if (!ready_tasks.empty()) { - task_cond_var_.notify_all(); - } - reenter_cond_var_.notify_all(); -} - -void Executor::ClearDoneTasks() { - std::lock_guard lock(done_task_mutex_); - done_tasks_.clear(); -} - -void Executor::RunTask(const std::shared_ptr &task, bool sync, bool long_run) { - if (sync) { - ClearDoneTasks(); - } - { - std::lock_guard lock(task_mutex_); - sync_run_task_finished_ = false; - ready_tasks_.push(task); - } - task_cond_var_.notify_all(); - if (sync && !sync_run_task_finished_) { - std::unique_lock lock(task_mutex_); - if (sync && long_run) { - mindspore::ScopedLongRunning long_running; - sync_cond_var_.wait(lock, [this] { return sync_run_task_finished_; }); - } else { - sync_cond_var_.wait(lock, [this] { return sync_run_task_finished_; }); - } - } - ClearDoneTasks(); - MsException::Instance().CheckException(); -} - -GraphId Executor::CompileGraph(const SessionPtr &session, const GraphSegmentPtr &segment, - const AnfNodePtrList &outputs) { - auto task = std::make_shared(); - task->session_ = session; - task->segment_ = segment; - task->output_nodes_ = outputs; - RunTask(task, true); - return task->graph_id_; -} - -GraphId Executor::CompileGraph(const SessionPtr &session, NotNull func_graph) { - auto task = std::make_shared(); - task->session_ = session; - task->func_graph_ = func_graph.get(); - RunTask(task, true); - return task->graph_id_; -} - -void Executor::BuildGraph(const SessionPtr &session, GraphId graphId) { - auto task = std::make_shared(); - task->session_ = session; - task->graph_id_ = graphId; - RunTask(task, true); -} - -void Executor::RunGraph(const SessionPtr &session, const GraphId &graph_id, - const std::vector &inputs, VectorRef *outputs) { - MS_EXCEPTION_IF_NULL(session); - MS_EXCEPTION_IF_NULL(outputs); - auto task = std::make_shared(); - task->session_ = session; - task->graph_id_ = graph_id; - task->input_tensors_ = inputs; - session->CreateOutputTensors(graph_id, inputs, outputs, &task->tensor_to_node_, &task->node_to_tensor_); - task->outputs_ = *outputs; - task->sync_run_ = true; - RunTask(task, true, true); -} - -void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, - const std::vector &inputs, VectorRef *outputs) { - MS_EXCEPTION_IF_NULL(session); - MS_EXCEPTION_IF_NULL(outputs); - auto task = std::make_shared(); - task->session_ = session; - task->graph_id_ = graph_id; - task->input_tensors_ = inputs; - task->input_need_lock_tensors_ = session->GetInputNeedLockTensors(graph_id, inputs); - auto graph = session->GetGraph(task->graph_id_); - if (graph != nullptr && !graph->IsPostGraphFinished()) { - mindspore::ScopedLongRunning long_running; - std::unique_lock lock(reenter_mutex_); - reenter_cond_var_.wait(lock, [&graph] { return graph->IsPostGraphFinished(); }); - MsException::Instance().CheckException(); - } - session->CreateOutputTensors(graph_id, inputs, outputs, &task->tensor_to_node_, &task->node_to_tensor_); - // maintain a copy of output vector - task->outputs_ = *outputs; - - // Run graph synchronously when the graph require gil. - if (graph != nullptr && graph->is_need_gil()) { - std::unique_lock lock(reenter_mutex_); - reenter_cond_var_.wait(lock, [&graph] { return graph->IsPreGraphFinished(); }); - MsException::Instance().CheckException(); - task->sync_run_ = true; - RunTask(task, true, true); - return; - } - - // sync run graph without output tensor(int dataset graph) - if ((!TensorInVector(outputs) && graph != nullptr && !graph->HasPostGraph())) { - task->sync_run_ = true; - RunTask(task, true, true); - return; - } - WaitLockedInputs(task); - for (auto &tensor_node : task->tensor_to_node_) { - tensor_node.first->SetNeedWait(true); - } - { - std::lock_guard lock(pending_task_mutex_); - if (!IsTaskReady(task)) { - ClearDoneTasks(); - pending_tasks_.push_back(task); - return; - } - } - RunTask(task, false); -} - -bool Executor::CreateCommGroup(const std::string &group_name, const std::vector &ranks) { - auto task = std::make_shared(); - task->group_name_ = group_name; - task->ranks_ = ranks; - RunTask(task, true); - return task->result_; -} - -bool Executor::DestroyCommGroup(const std::string &group_name) { - auto task = std::make_shared(); - task->group_name_ = group_name; - RunTask(task, true); - return task->result_; -} - -void Executor::OnWorkerExit() { - if (device_name_ == kAscendDevice) { - device::KernelRuntimeManager::Instance().ReleaseKernelRuntime(kAscendDevice, device_id_); - } -} -} // namespace session -} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/graph_compiler.cc b/mindspore-lite/src/extendrt/graph_compiler.cc deleted file mode 100644 index 2b79631d1cb1f2b4946e949ca98bd7b104ae1d7f..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/graph_compiler.cc +++ /dev/null @@ -1,27 +0,0 @@ -/** - * Copyright 2019-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "extendrt/graph_compiler.h" - -namespace mindspore { -namespace infer { -ExcutionPlan GraphCompiler::Compile(FuncGraphPtr func_graph) { return {}; } -GraphId GraphCompiler::CompileSegment(const GraphSegmentPtr segment) { return -1; } -CompileResult GraphCompiler::LinkSegment() { return CompileResult(); } -ExcutionPlan GraphCompiler::Schedule(const CompileResult &compile_result) { - return scheduler_.Schedule(compile_result); -} -} // namespace infer -} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/graph_compiler/anfnode_tensor_adapter.cc b/mindspore-lite/src/extendrt/graph_compiler/anfnode_tensor_adapter.cc deleted file mode 100644 index 514c9324b0f4c7e51830b258056da8976fd1de1a..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/graph_compiler/anfnode_tensor_adapter.cc +++ /dev/null @@ -1,576 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/graph_compiler/anfnode_tensor_adapter.h" -#include -#include "src/extendrt/graph_compiler/compile_result_builder.h" -#include "ir/anf.h" -#include "ir/func_graph.h" -#include "ir/primitive.h" -#include "mindspore/ops/op_def/sequence_ops.h" -#include "utils/ms_utils_secure.h" -#include "mindspore/ccsrc/include/common/utils/utils.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" - -using ShapePtr = mindspore::abstract::ShapePtr; -using AbstractBasePtr = mindspore::abstract::AbstractBasePtr; -using AbstractTensorPtr = mindspore::abstract::AbstractTensorPtr; -using AbstractSequencePtr = mindspore::abstract::AbstractSequencePtr; - -namespace mindspore { -namespace lite { -namespace { -AbstractBasePtr GetRealAbstract(const CNodePtr &cnode) { - // MakeTuple infer is skipped in converter, so cnode->abstract is nullptr. The function can be deleted if MakeTuple is - // inferred in the converter. - if (!IsPrimitive(cnode->input(kIndex0), prim::kPrimMakeTuple)) { - return cnode->abstract(); - } - std::vector abstracts; - for (size_t i = 1; i < cnode->size(); ++i) { - const auto &input = cnode->inputs()[i]; - MSLITE_CHECK_PTR_RETURN(input, nullptr); - const auto &abstract = input->abstract(); - MSLITE_CHECK_PTR_RETURN(input, abstract); - abstracts.emplace_back(abstract); - } - return std::make_shared(abstracts); -} -} // namespace - -InferTensor *TensorAdapter::Convert2Tensor(const ParameterPtr ¶m_node, Format format) { - auto adapter = TensorAdapter::Create(param_node, format); - if (adapter == nullptr) { - MS_LOG(ERROR) << "Create tensor-adapter from parameter failed, parameter : " << param_node; - return nullptr; - } - return adapter->ToTensor(); -} - -InferTensor *TensorAdapter::Convert2Tensor(const ValueNodePtr &value_node, Format format) { - auto adapter = TensorAdapter::Create(value_node, format); - if (adapter == nullptr) { - MS_LOG(ERROR) << "Create tensor-adapter from value-node failed, value-node : " << value_node; - return nullptr; - } - return adapter->ToTensor(); -} - -InferTensor *TensorAdapter::Convert2Tensor(const AbstractTensorPtr &abstract, Format format) { - auto adapter = TensorAdapter::Create(abstract, format); - if (adapter == nullptr) { - MS_LOG(ERROR) << "Create tensor-adapter from abstracttensor failed, abstract : " << abstract; - return nullptr; - } - return adapter->ToTensor(); -} - -InferTensor *TensorAdapter::Convert2Tensor(const AbstractBasePtr &abstract, Format format) { - auto adapter = TensorAdapter::Create(abstract, format); - if (adapter == nullptr) { - MS_LOG(ERROR) << "Create tensor-adapter from abstractbase failed, abstract : " << abstract; - return nullptr; - } - return adapter->ToTensor(); -} - -InferTensor *TensorAdapter::ToTensor() { - std::vector int32_shape; - if (std::any_of(shape_.begin(), shape_.end(), - [](const ShapeValueDType &dim) { return dim == abstract::Shape::kShapeRankAny; })) { - int32_shape.emplace_back(-1); - } else { - int32_shape.resize(shape_.size()); - for (size_t i = 0; i < shape_.size(); i++) { - int32_shape[i] = static_cast(shape_[i]); - } - } - auto *tensor = InferTensor::CreateTensor(name_, data_type_, int32_shape, data_, data_len_); - if (tensor == nullptr) { - return nullptr; - } - // move data to tensor - tensor->set_own_data(own_data_); - own_data_ = false; - tensor->set_format(format_); - return tensor; -} - -std::vector> TensorAdapter::CreateTensorsFromAbstract(const AbstractBasePtr &abstract, - Format format) { - if (abstract == nullptr) { - MS_LOG(ERROR) << "Input `abstract` is nullptr."; - return {}; - } - std::vector> results; - // multi output abstract - if (utils::isa(abstract)) { - auto elements = utils::cast(abstract)->elements(); - for (auto &element : elements) { - auto tensor = TensorAdapter::Convert2Tensor(element, format); - if (tensor == nullptr) { - MS_LOG(ERROR) << "Create tensor from abstract failed, abstract : " << element; - return {}; - } - results.emplace_back(std::unique_ptr(tensor)); - } - return results; - } - // single output abstract - if (utils::isa(abstract)) { - auto tensor = TensorAdapter::Convert2Tensor(abstract, format); - if (tensor == nullptr) { - MS_LOG(ERROR) << "Create tensor from abstract failed, abstract : " << abstract; - return {}; - } - results.emplace_back(std::unique_ptr(tensor)); - return results; - } - MS_LOG(ERROR) << "Unsupported abstract: " << abstract; - return {}; -} - -std::vector TensorAdapter::Convert2Tensor(const CNodePtr &cnode, Format format) { - if (cnode == nullptr) { - MS_LOG(ERROR) << "Input cnode is nullptr."; - return {}; - } - - auto abstract = GetRealAbstract(cnode); - if (abstract == nullptr) { - MS_LOG(ERROR) << "CNode abstract is nullptr."; - return {}; - } - auto tmp = TensorAdapter::CreateTensorsFromAbstract(abstract); - if (tmp.empty()) { - MS_LOG(ERROR) << "Create tensors from output abstract of cnode failed, cnode : " << cnode->fullname_with_scope(); - return {}; - } - std::vector results; - results.reserve(tmp.size()); - std::transform(tmp.begin(), tmp.end(), std::back_inserter(results), - [](std::unique_ptr &tensor) { return tensor.release(); }); - return results; -} - -TensorAdapterPtr TensorAdapter::Create(const ParameterPtr ¶m_node, Format format) { - if (param_node == nullptr) { - MS_LOG(ERROR) << "Input parameter is nullptr."; - return nullptr; - } - ShapeVector shape_vector; - TypeId data_type = kTypeUnknown; - auto status = GetDTAndShapeFromParameter(param_node, &data_type, &shape_vector); - if (status != kSuccess) { - MS_LOG(ERROR) << "Get data type and shape from param node failed."; - return nullptr; - } - if (data_type == kObjectTypeString) { - MS_LOG(ERROR) << "Not support kObjectTypeString type DefaultParam."; - return nullptr; - } - auto abstract = param_node->abstract(); - if (abstract == nullptr) { - MS_LOG(ERROR) << "Abstract of parameter is nullptr."; - return nullptr; - } - auto adapter = std::make_shared(abstract->name()); - adapter->data_type_ = data_type; - adapter->shape_ = shape_vector; - adapter->format_ = format; - adapter->is_const_ = param_node->has_default(); - if (!adapter->is_const_) { - return adapter; - } - auto tensor_info = std::dynamic_pointer_cast(param_node->default_param()); - if (tensor_info == nullptr) { - MS_LOG(ERROR) << "Cast default-param to tensor failed."; - return nullptr; - } - adapter->compress_type_ = tensor_info->compression_type(); - adapter->data_ = tensor_info->data_c(); - adapter->data_len_ = tensor_info->Size(); - adapter->own_data_ = false; - return adapter; -} - -TensorAdapterPtr TensorAdapter::CreateFromTensorValueNode(const ValueNodePtr &value_node) { - auto value_abstract = value_node->abstract(); - if (value_abstract == nullptr) { - MS_LOG(ERROR) << "Abstract of value is nullptr"; - return nullptr; - } - auto adapter = TensorAdapter::Create(value_abstract); - if (adapter == nullptr) { - MS_LOG(ERROR) << "Create tensor adapter from abstract of valuenode failed, valuenode: " - << value_node->fullname_with_scope(); - return nullptr; - } - adapter->is_const_ = true; - - auto value = value_node->value(); - if (value == nullptr) { - MS_LOG(ERROR) << "Value of value-node is nullptr, " << value_node->fullname_with_scope(); - return nullptr; - } - auto data = value->cast(); - if (data == nullptr) { - MS_LOG(ERROR) << "Value of tensor-type value-node is not a Tensor, " << value_node->fullname_with_scope(); - return nullptr; - } - adapter->data_ = data->data_c(); - adapter->data_len_ = data->Size(); - adapter->own_data_ = false; - return adapter; -} - -TensorAdapterPtr TensorAdapter::CreateFromInt32ImmValue(const ValueNodePtr &value_node) { - MS_ASSERT(value_node != nullptr); - auto adapter = std::make_shared(value_node->fullname_with_scope()); - adapter->is_const_ = true; - adapter->data_type_ = kNumberTypeInt32; - adapter->shape_ = {1}; - auto value = value_node->value(); - if (value == nullptr) { - MS_LOG(ERROR) << "Value of value-node is nullptr, " << value_node->fullname_with_scope(); - return nullptr; - } - auto data = GetValue(value); - adapter->data_ = malloc(sizeof(int32_t)); - if (adapter->data_ == nullptr) { - MS_LOG(ERROR) << "malloc const tensor data failed."; - return nullptr; - } - (reinterpret_cast(adapter->data_))[0] = data; - adapter->data_len_ = sizeof(int32_t); - adapter->own_data_ = true; - return adapter; -} - -TensorAdapterPtr TensorAdapter::CreateFromInt64ImmValue(const ValueNodePtr &value_node) { - MS_ASSERT(value_node != nullptr); - auto adapter = std::make_shared(value_node->fullname_with_scope()); - adapter->is_const_ = true; - adapter->data_type_ = kNumberTypeInt64; - adapter->shape_ = {1}; - auto value = value_node->value(); - if (value == nullptr) { - MS_LOG(ERROR) << "Value of value-node is nullptr, " << value_node->fullname_with_scope(); - return nullptr; - } - auto data = GetValue(value); - adapter->data_ = malloc(sizeof(int64_t)); - if (adapter->data_ == nullptr) { - MS_LOG(ERROR) << "malloc const tensor data failed."; - return nullptr; - } - (reinterpret_cast(adapter->data_))[0] = data; - adapter->data_len_ = sizeof(int64_t); - adapter->own_data_ = true; - return adapter; -} - -TensorAdapterPtr TensorAdapter::CreateFromBoolImmValue(const ValueNodePtr &value_node) { - MS_ASSERT(value_node != nullptr); - auto adapter = std::make_shared(value_node->fullname_with_scope()); - adapter->is_const_ = true; - adapter->data_type_ = kNumberTypeBool; - adapter->shape_ = {1}; - auto value = value_node->value(); - if (value == nullptr) { - MS_LOG(ERROR) << "Value of value-node is nullptr, " << value_node->fullname_with_scope(); - return nullptr; - } - auto data = value->cast(); - if (data == nullptr) { - MS_LOG(ERROR) << "BoolImm Value of cast to BoolImmPtr failed, " << value_node->fullname_with_scope(); - return nullptr; - } - auto data_value = data->value(); - adapter->data_ = malloc(sizeof(bool)); - if (adapter->data_ == nullptr) { - MS_LOG(ERROR) << "malloc const tensor data failed."; - return nullptr; - } - (reinterpret_cast(adapter->data_))[0] = data_value; - adapter->data_len_ = sizeof(bool); - adapter->own_data_ = true; - return adapter; -} - -TensorAdapterPtr TensorAdapter::CreateFromNumberTypeValue(const ValueNodePtr &value_node) { - MS_ASSERT(value_node != nullptr); - auto adapter = std::make_shared(value_node->fullname_with_scope()); - adapter->is_const_ = true; - adapter->data_type_ = kNumberTypeInt32; - adapter->shape_ = {1}; - auto data = utils::cast(value_node->value()); - if (data == nullptr) { - MS_LOG(ERROR) << "Value of Number type value-node is not a NumberPtr, " << value_node->fullname_with_scope(); - return nullptr; - } - TypeId number_type = data->number_type(); - static const std::unordered_map TypeToTypeMap = { - {kNumberTypeInt, kNumberTypeInt32}, {kNumberTypeUInt, kNumberTypeUInt32}, {kNumberTypeFloat, kNumberTypeFloat32}}; - if (TypeToTypeMap.find(number_type) != TypeToTypeMap.end()) { - number_type = TypeToTypeMap.at(number_type); - } - auto number_data = static_cast(number_type); - adapter->data_ = malloc(sizeof(int32_t)); - if (adapter->data_ == nullptr) { - MS_LOG(ERROR) << "malloc const tensor data failed."; - return nullptr; - } - (reinterpret_cast(adapter->data_))[0] = number_data; - adapter->data_len_ = sizeof(int32_t); - adapter->own_data_ = true; - return adapter; -} - -TensorAdapterPtr TensorAdapter::CreateFromIntSequenceValue(const ValueNodePtr &value_node) { - MS_ASSERT(value_node != nullptr); - auto value_seq = utils::cast(value_node->value()); - if (value_seq == nullptr) { - MS_LOG(ERROR) << "Value of Sequence type value-node is not a ValueSequencePtr, " - << value_node->fullname_with_scope(); - return nullptr; - } - auto adapter = std::make_shared(value_node->fullname_with_scope()); - adapter->is_const_ = true; - if (!value_seq->value().empty()) { - if (value_seq->value().front()->type()->number_type() == kNumberTypeInt32 || - value_seq->value().front()->type()->number_type() == kNumberTypeInt) { - adapter->data_type_ = kNumberTypeInt32; - auto data = GetValue>(value_seq); - auto data_len = data.size() * sizeof(int32_t); - adapter->shape_ = {static_cast(data.size())}; - adapter->data_len_ = data_len; - if (data_len > 0) { - adapter->data_ = malloc(data_len); - if (adapter->data_ == nullptr) { - MS_LOG(ERROR) << "malloc const tensor data failed."; - return nullptr; - } - auto ret = memcpy_s(adapter->data_, data_len, data.data(), data_len); - if (ret != EOK) { - MS_LOG(ERROR) << "memcpy const tensor data failed: " << ret; - free(adapter->data_); - return nullptr; - } - adapter->own_data_ = true; - } else { - adapter->data_ = nullptr; - adapter->own_data_ = false; - } - } else if (value_seq->value().front()->type()->number_type() == kNumberTypeInt64) { - adapter->data_type_ = kNumberTypeInt64; - auto data = GetValue>(value_seq); - auto data_len = data.size() * sizeof(int64_t); - adapter->shape_ = {static_cast(data.size())}; - adapter->data_len_ = data_len; - if (data_len > 0) { - adapter->data_ = malloc(data_len); - if (adapter->data_ == nullptr) { - MS_LOG(ERROR) << "malloc const tensor data failed."; - return nullptr; - } - auto ret = memcpy_s(adapter->data_, data_len, data.data(), data_len); - if (ret != EOK) { - MS_LOG(ERROR) << "memcpy const tensor data failed: " << ret; - free(adapter->data_); - return nullptr; - } - adapter->own_data_ = true; - } else { - adapter->data_ = nullptr; - adapter->own_data_ = false; - } - } else { - MS_LOG(ERROR) << "only support integer value ValueSequence."; - return nullptr; - } - } - return adapter; -} - -TensorAdapterPtr TensorAdapter::Create(const ValueNodePtr &value_node, Format format) { - MS_ASSERT(value_node != nullptr); - auto value = value_node->value(); - TensorAdapterPtr adapter; - if (value->isa()) { - adapter = CreateFromTensorValueNode(value_node); - } else if (value->isa()) { - adapter = CreateFromInt32ImmValue(value_node); - } else if (value->isa()) { - adapter = CreateFromInt64ImmValue(value_node); - } else if (value->isa()) { - adapter = CreateFromBoolImmValue(value_node); - } else if (value->isa()) { - adapter = CreateFromIntSequenceValue(value_node); - } else if (value->isa()) { - adapter = CreateFromNumberTypeValue(value_node); - } else { - MS_LOG(ERROR) << "Not support value type: " << value->type(); - return nullptr; - } - if (adapter == nullptr) { - return nullptr; - } - adapter->format_ = format; - return adapter; -} - -TensorAdapterPtr TensorAdapter::Create(const AbstractBasePtr &abs, Format format) { - auto abs_tensor = utils::cast(abs); - if (abs_tensor == nullptr) { - MS_LOG(ERROR) << "Input abstract is not a AbstractTensor."; - return nullptr; - } - return TensorAdapter::Create(abs_tensor, format); -} - -TensorAdapterPtr TensorAdapter::Create(const AbstractTensorPtr &abs_tensor, Format format) { - if (abs_tensor == nullptr) { - MS_LOG(ERROR) << "Input abstract is not a AbstractTensor."; - return nullptr; - } - ShapeVector shape_vector; - TypeId data_type = kTypeUnknown; - auto ret = GetDTAndShapeFromAbTensor(abs_tensor, &data_type, &shape_vector); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Get data type and shape from value node failed."; - return nullptr; - } - auto adapter = std::make_shared(abs_tensor->name()); - adapter->data_type_ = data_type; - adapter->shape_ = shape_vector; - adapter->format_ = format; - return adapter; -} - -StatusCode TensorAdapter::GetDTAndShapeFromAbTensor(const AbstractTensorPtr &abstract, TypeId *data_type, - ShapeVector *shape_vector) { - if (MS_UNLIKELY(abstract == nullptr || data_type == nullptr || shape_vector == nullptr)) { - MS_LOG(ERROR) << "input argument is nullptr"; - return kLiteInputParamInvalid; - } - if (abstract->element() == nullptr) { - MS_LOG(ERROR) << "`element` of abstract is nullptr"; - return kLiteError; - } - auto type_ptr = abstract->element()->GetTypeTrack(); - if (type_ptr == nullptr) { - MS_LOG(ERROR) << "Type of abstract is nullptr"; - return kLiteError; - } - *data_type = type_ptr->type_id(); - if (!utils::isa(abstract->BuildShape())) { - MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr"; - return kLiteError; - } - *shape_vector = utils::cast(abstract->BuildShape())->shape(); - return kSuccess; -} - -StatusCode TensorAdapter::SetDTAndShapeFromAbTensor(const TypeId &data_type, const ShapeVector &shape, - const AbstractTensorPtr &abstract) { - if (MS_UNLIKELY(abstract == nullptr)) { - MS_LOG(ERROR) << "input `abstract` is nullptr"; - return kLiteInputParamInvalid; - } - if (!utils::isa(abstract->BuildShape())) { - MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr"; - return kLiteError; - } - auto build_shape = utils::cast(abstract->BuildShape()); - build_shape->set_shape(shape); - abstract->set_shape(build_shape); - - if (abstract->element() == nullptr) { - MS_LOG(ERROR) << "`element` of abstract is nullptr"; - return kLiteError; - } - abstract->element()->set_type(TypeIdToType(data_type)); - return kSuccess; -} - -StatusCode TensorAdapter::SetDTAndShapeFromAbTensor(const TypeId &data_type, const std::vector &shape, - const mindspore::abstract::AbstractTensorPtr &abstract) { - ShapeVector shape_vec; - shape_vec.resize(shape.size()); - (void)std::transform(shape.begin(), shape.end(), shape_vec.begin(), - [](const int &dim) { return static_cast(dim); }); - return TensorAdapter::SetDTAndShapeFromAbTensor(data_type, shape_vec, abstract); -} - -StatusCode TensorAdapter::GetDTAndShapeFromParameter(const ParameterPtr ¶m_node, TypeId *data_type, - ShapeVector *shape_vector) { - MS_ASSERT(param_node != nullptr && data_type != nullptr && shape_vector != nullptr); - auto abstract_base = param_node->abstract(); - if (abstract_base == nullptr) { - MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name(); - return kLiteError; - } - auto abstract_tensor = utils::cast(abstract_base); - if (abstract_tensor == nullptr) { - MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << param_node->name(); - return kLiteError; - } - return GetDTAndShapeFromAbTensor(abstract_tensor, data_type, shape_vector); -} - -bool TensorAdapter::SetDTAndShapeFromAbTensorToLiteTensor(const AbstractBasePtr &abstract, InferTensor *tensor) { - if (!utils::isa(abstract)) { - MS_LOG(ERROR) << "The abstract should be tensor, but got abstract : " << abstract; - return false; - } - ShapeVector shape_vector; - TypeId data_type = kTypeUnknown; - auto ret = TensorAdapter::GetDTAndShapeFromAbTensor(utils::cast(abstract), - &data_type, &shape_vector); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Get dtype and shape from abstract failed, abstract : " << abstract; - return false; - } - std::vector int32_shape; - std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(int32_shape), - [](const auto &shape) { return static_cast(shape); }); - tensor->set_data_type(data_type); - tensor->set_shape(int32_shape); - tensor->set_format(NCHW); - return true; -} - -bool TensorAdapter::SetDTAndShapeFromLiteTensorToAbTensor(const InferTensor &tensor, const AbstractBasePtr &abstract) { - if (MS_UNLIKELY(abstract == nullptr)) { - MS_LOG(ERROR) << "Input `abstract` is nullptr"; - return false; - } - if (!utils::isa(abstract)) { - MS_LOG(ERROR) << "The abstract should be tensor, but got abstract : " << abstract; - return false; - } - - auto ret = TensorAdapter::SetDTAndShapeFromAbTensor(tensor.data_type(), tensor.shape(), - utils::cast(abstract)); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Set dtype and shape to abstract failed, abstract : " << abstract; - return false; - } - return true; -} -} // namespace lite -} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/graph_compiler/anfnode_tensor_adapter.h b/mindspore-lite/src/extendrt/graph_compiler/anfnode_tensor_adapter.h deleted file mode 100644 index d51ac615836560c7d6e1209f2aa373ce90d7c823..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/graph_compiler/anfnode_tensor_adapter.h +++ /dev/null @@ -1,99 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_EXTENDRT_GRAPH_COMPILER_ANFNODE_TENSOR_ADAPTER_H_ -#define MINDSPORE_LITE_EXTENDRT_GRAPH_COMPILER_ANFNODE_TENSOR_ADAPTER_H_ -#include -#include -#include -#include -#include -#include "src/infer/tensor.h" -#include "abstract/abstract_value.h" -#include "ir/anf.h" -#include "include/api/status.h" - -namespace mindspore { -namespace lite { -class TensorAdapter; -using TensorAdapterPtr = std::shared_ptr; -class TensorAdapter { - public: - explicit TensorAdapter(std::string name) : name_(std::move(name)) {} - virtual ~TensorAdapter() { - if (own_data_) { - free(data_); - } - } - - InferTensor *ToTensor(); - - static TensorAdapterPtr Create(const ParameterPtr ¶m_node, Format format = DEFAULT_FORMAT); - static TensorAdapterPtr Create(const ValueNodePtr &value_node, Format format = DEFAULT_FORMAT); - static TensorAdapterPtr Create(const mindspore::abstract::AbstractTensorPtr &abstract, - Format format = DEFAULT_FORMAT); - static TensorAdapterPtr Create(const mindspore::abstract::AbstractBasePtr &abstract, Format format = DEFAULT_FORMAT); - - static std::vector> CreateTensorsFromAbstract(const AbstractBasePtr &abstract, - Format format = Format::DEFAULT_FORMAT); - static std::vector Convert2Tensor(const CNodePtr &cnode, Format format = DEFAULT_FORMAT); - static InferTensor *Convert2Tensor(const ParameterPtr ¶m_node, Format format = DEFAULT_FORMAT); - static InferTensor *Convert2Tensor(const ValueNodePtr &value_node, Format format = DEFAULT_FORMAT); - static InferTensor *Convert2Tensor(const mindspore::abstract::AbstractTensorPtr &abstract, - Format format = DEFAULT_FORMAT); - static InferTensor *Convert2Tensor(const mindspore::abstract::AbstractBasePtr &abstract, - Format format = DEFAULT_FORMAT); - - static StatusCode GetDTAndShapeFromAbTensor(const mindspore::abstract::AbstractTensorPtr &abstract, TypeId *data_type, - ShapeVector *shape_vector); - static StatusCode SetDTAndShapeFromAbTensor(const TypeId &data_type, const ShapeVector &shape, - const mindspore::abstract::AbstractTensorPtr &abstract); - static StatusCode SetDTAndShapeFromAbTensor(const TypeId &data_type, const std::vector &shape, - const mindspore::abstract::AbstractTensorPtr &abstract); - - static bool SetDTAndShapeFromAbTensorToLiteTensor(const AbstractBasePtr &abstract, InferTensor *tensor); - static bool SetDTAndShapeFromLiteTensorToAbTensor(const InferTensor &tensor, const AbstractBasePtr &abstract); - - private: - static StatusCode GetDTAndShapeFromParameter(const ParameterPtr ¶m_node, TypeId *data_type, ShapeVector *shape); - - static TensorAdapterPtr CreateFromTensorValueNode(const ValueNodePtr &value_node); - - static TensorAdapterPtr CreateFromInt32ImmValue(const ValueNodePtr &value_node); - - static TensorAdapterPtr CreateFromInt64ImmValue(const ValueNodePtr &value_node); - - static TensorAdapterPtr CreateFromBoolImmValue(const ValueNodePtr &value_node); - - static TensorAdapterPtr CreateFromNumberTypeValue(const ValueNodePtr &value_node); - - static TensorAdapterPtr CreateFromIntSequenceValue(const ValueNodePtr &value_node); - - public: - Format format_{DEFAULT_FORMAT}; - TensorCompressionType compress_type_ = TensorCompressionType::kNoCompression; - TypeId data_type_{kTypeUnknown}; - bool is_const_{false}; - ShapeVector shape_{}; - void *data_{nullptr}; - size_t data_len_{0}; - bool own_data_{true}; - std::string name_; -}; -} // namespace lite -} // namespace mindspore - -#endif diff --git a/mindspore-lite/src/extendrt/graph_compiler/compile_result.cc b/mindspore-lite/src/extendrt/graph_compiler/compile_result.cc deleted file mode 100644 index 445a79c2e8cbb799bf014fdb1544f3180f9dec09..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/graph_compiler/compile_result.cc +++ /dev/null @@ -1,359 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/graph_compiler/compile_result.h" -#include -#include -#include -#include -#include -#include -#include "ops/base_operator.h" -#include "utils/hash_map.h" -#include "include/api/status.h" -#include "ir/primitive.h" -#include "mindspore/ops/op_def/op_name.h" -#include "ops/primitive_c.h" -#include "src/common/file_utils.h" - -namespace mindspore { -namespace lite { -namespace { -constexpr char tab[] = " "; - -inline std::string GenIndent(int indent) { - std::ostringstream oss; - for (int i = 0; i < indent; i++) { - oss << tab; - } - return oss.str(); -} - -inline std::string DumpIntShape(const std::vector &shape) { - std::ostringstream oss; - oss << "["; - for (size_t i = 0; i < shape.size(); i++) { - oss << shape[i]; - if (i < shape.size() - 1) { - oss << ", "; - } - } - oss << "]"; - return oss.str(); -} - -inline std::string DumpTensor(const InferTensor *tensor, int indent = 0) { - std::ostringstream oss; - oss << GenIndent(indent) << "Tensor tensor_name() << ", dtype:" << tensor->data_type() - << ", format:" << tensor->format() << ", shape:" << DumpIntShape(tensor->shape()) << ">"; - return oss.str(); -} -} // namespace - -kernel::KernelAttr CompileNode::GetKernelAttr() const { - kernel::KernelAttr attr; - for (auto &input : inputs_) { - (void)attr.AddInputAttr(input->data_type(), FormatEnumToString(input->format())); - } - for (auto &output : outputs_) { - (void)attr.AddOutputAttr(output->data_type(), FormatEnumToString(output->format())); - } - return attr; -} - -CompileNodePtr CompileNode::Create(CNodePtr cnode) { - if (cnode == nullptr) { - return nullptr; - } - auto primitive = GetValueNode>(cnode->input(0)); - if (primitive == nullptr) { - MS_LOG(ERROR) << "Node has no primitive, first input of cnode(" << cnode->fullname_with_scope() - << ") is : " << cnode->input(0); - return nullptr; - } - auto node = std::make_shared(cnode->fullname_with_scope(), kernel::PrimitiveType(primitive->name())); - ops::PrimitiveCPtr primc{nullptr}; - if (utils::isa(primitive)) { - primc = utils::cast(primitive); - } else { - static auto ops_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap(); - auto primc_creator_iter = ops_primc_fns.find(node->type_.TypeName()); - if (primc_creator_iter == ops_primc_fns.end()) { - MS_LOG(ERROR) << "Can not find primitive_c create function for: " << node->type_; - return nullptr; - } - primc = primc_creator_iter->second(); - if (primc == nullptr) { - MS_LOG(ERROR) << "Create primitive_c failed, type: " << node->type_; - return nullptr; - } - (void)primc->SetAttrs(primitive->attrs()); - } - static auto baseops_fns = ops::OperatorRegister::GetInstance().GetOperatorMap(); - auto baseops_creator_iter = baseops_fns.find(node->type_.TypeName()); - if (baseops_creator_iter == baseops_fns.end()) { - MS_LOG(ERROR) << "Can not find base-operator create function for: " << node->type_; - return nullptr; - } - auto baseops_creator = baseops_creator_iter->second; - node->base_operator_ = baseops_creator(primc); - if (node->base_operator_ == nullptr) { - MS_LOG(ERROR) << "Create base-operator failed, type: " << node->type_; - return nullptr; - } - node->cnode_ = std::move(cnode); - return node; -} - -void CompileNode::AppendInputTensor(InferTensor *tensor) { this->inputs_.emplace_back(tensor); } - -void CompileNode::AppendOutputTensor(InferTensor *tensor) { this->outputs_.emplace_back(tensor); } - -std::string CompileNode::Dump(int indent) const { - constexpr int kNumberTwo = 2; - std::ostringstream oss; - oss << GenIndent(indent) << "CompileNode {" << std::endl; - oss << GenIndent(indent + 1) << "inputs: [" << std::endl; - for (auto &input : inputs_) { - oss << DumpTensor(input, indent + kNumberTwo) << std::endl; - } - oss << GenIndent(indent + 1) << "]" << std::endl; - oss << GenIndent(indent + 1) << "outputs: [" << std::endl; - for (auto &output : outputs_) { - oss << DumpTensor(output, indent + kNumberTwo) << std::endl; - } - oss << GenIndent(indent + 1) << "]" << std::endl; - oss << GenIndent(indent) << "}"; - return oss.str(); -} - -void CompileNode::ReplaceInputTensor(InferTensor *dst, const InferTensor *src) { - std::replace_if( - inputs_.begin(), inputs_.end(), [&src](InferTensor *ele) { return ele == src; }, dst); -} - -CompileNodePtr CompileResult::GetNode(const std::string &name) { - auto iter = node_map_.find(name); - if (iter == node_map_.end()) { - return nullptr; - } else { - return iter->second; - } -} - -CompileNodePtr CompileResult::GetArgNode(const std::string &name) { - auto iter = arg_node_map_.find(name); - if (iter == arg_node_map_.end()) { - return nullptr; - } else { - return iter->second; - } -} - -std::vector &CompileResult::GetMutableNodes() { - if (assembled_) { - MS_LOG(EXCEPTION) << "CompileResult not mutable after build."; - } - return nodes_; -} -std::vector &CompileResult::GetMutableInputs() { - if (assembled_) { - MS_LOG(EXCEPTION) << "CompileResult not mutable after build."; - } - return inputs_; -} - -std::vector &CompileResult::GetMutableOutputs() { - if (assembled_) { - MS_LOG(EXCEPTION) << "CompileResult not mutable after build."; - } - return outputs_; -} - -StatusCode CompileResult::AppendNode(CompileNodePtr node) { - if (assembled_) { - MS_LOG(EXCEPTION) << "CompileResult not mutable after build."; - } - if (node == nullptr) { - MS_LOG(ERROR) << "Input node is nullptr"; - return kLiteInputParamInvalid; - } - const std::string &node_name = node->GetName(); - auto iter = node_map_.find(node_name); - if (iter != node_map_.end()) { - MS_LOG(ERROR) << "Duplicated node name : " << node_name; - return kLiteError; - } - node_map_[node_name] = node; - nodes_.emplace_back(node); - return kSuccess; -} - -StatusCode CompileResult::AppendArgNode(CompileNodePtr node) { - if (assembled_) { - MS_LOG(EXCEPTION) << "CompileResult not mutable after build."; - } - if (node == nullptr) { - MS_LOG(ERROR) << "Input node is nullptr"; - return kLiteInputParamInvalid; - } - const std::string &node_name = node->GetName(); - auto iter = arg_node_map_.find(node_name); - if (iter != arg_node_map_.end()) { - MS_LOG(ERROR) << "Duplicated node name : " << node_name; - return kLiteError; - } - arg_node_map_[node_name] = node; - arg_nodes_.emplace_back(node); - return kSuccess; -} - -StatusCode CompileResult::AppendTensor(InferTensor *tensor) { - if (assembled_) { - MS_LOG(EXCEPTION) << "CompileResult not mutable after build."; - } - if (tensor == nullptr) { - MS_LOG(ERROR) << "Input tensor is nullptr"; - return kLiteInputParamInvalid; - } - tensors_.emplace_back(tensor); - return kSuccess; -} - -StatusCode CompileResult::AppendInputTensor(InferTensor *tensor, bool is_borrow) { - if (assembled_) { - MS_LOG(EXCEPTION) << "CompileResult not mutable after build."; - } - if (tensor == nullptr) { - MS_LOG(ERROR) << "Input tensor is nullptr"; - return kLiteInputParamInvalid; - } - inputs_.emplace_back(tensor); - if (!is_borrow) { - return AppendTensor(tensor); - } - return kSuccess; -} - -StatusCode CompileResult::AppendOutputTensor(InferTensor *tensor, bool is_borrow) { - if (assembled_) { - MS_LOG(EXCEPTION) << "CompileResult not mutable after build."; - } - if (tensor == nullptr) { - MS_LOG(ERROR) << "Input tensor is nullptr"; - return kLiteInputParamInvalid; - } - if (tensor->tensor_name().empty()) { - tensor->set_tensor_name("graph_out_" + std::to_string(this->outputs_.size())); - } - if (!is_borrow) { - return AppendTensor(tensor); - } - outputs_.emplace_back(tensor); - return kSuccess; -} - -StatusCode CompileResult::AppendNodeInputTensor(const CompileNodePtr &compile_node, InferTensor *tensor, - bool is_borrow) { - if (compile_node == nullptr) { - MS_LOG(ERROR) << "Input compile_node is nullptr"; - return kLiteInputParamInvalid; - } - return AppendNodeInputTensor(compile_node->GetName(), tensor, is_borrow); -} - -StatusCode CompileResult::AppendNodeInputTensor(const std::string &node_name, InferTensor *input_tensor, - bool is_borrow) { - if (assembled_) { - MS_LOG(EXCEPTION) << "CompileResult not mutable after build."; - } - if (input_tensor == nullptr) { - MS_LOG(ERROR) << "`input_tensor` is nullptr"; - return kLiteInputParamInvalid; - } - - auto iter = node_map_.find(node_name); - if (iter == node_map_.end()) { - MS_LOG(ERROR) << "CompileNode not belong to this graph, node: " << node_name; - return kLiteError; - } - iter->second->AppendInputTensor(input_tensor); - if (!is_borrow) { - return AppendTensor(input_tensor); - } - return kSuccess; -} - -StatusCode CompileResult::AppendNodeOutputTensor(const CompileNodePtr &compile_node, InferTensor *tensor, - bool is_borrow) { - if (compile_node == nullptr) { - MS_LOG(ERROR) << "Input compile_node is nullptr"; - return kLiteInputParamInvalid; - } - return AppendNodeOutputTensor(compile_node->GetName(), tensor, is_borrow); -} - -StatusCode CompileResult::AppendNodeOutputTensor(const std::string &node_name, InferTensor *output_tensor, - bool is_borrow) { - if (assembled_) { - MS_LOG(EXCEPTION) << "CompileResult not mutable after build."; - } - if (output_tensor == nullptr) { - MS_LOG(ERROR) << "`output_tensor` is nullptr"; - return kLiteInputParamInvalid; - } - - auto iter = node_map_.find(node_name); - if (iter == node_map_.end()) { - MS_LOG(ERROR) << "CompileNode not belong to this graph, node: " << node_name; - return kLiteError; - } - iter->second->AppendOutputTensor(output_tensor); - if (!is_borrow) { - return AppendTensor(output_tensor); - } - return kSuccess; -} - -std::string CompileResult::Dump(int indent) const { - constexpr int kNumTwo = 2; - std::ostringstream oss; - oss << GenIndent(indent) << "CompileResult {" << std::endl; - oss << GenIndent(indent + 1) << "nodes: [" << std::endl; - for (auto &node : nodes_) { - oss << node->Dump(indent + kNumTwo) << std::endl; - } - oss << GenIndent(indent + 1) << "]" << std::endl; - oss << GenIndent(indent + 1) << "inputs: [" << std::endl; - for (auto &input : inputs_) { - oss << DumpTensor(input, indent + kNumTwo) << std::endl; - } - oss << GenIndent(indent + 1) << "]" << std::endl; - oss << GenIndent(indent + 1) << "outputs: [" << std::endl; - for (auto &output : outputs_) { - oss << DumpTensor(output, indent + kNumTwo) << std::endl; - } - oss << GenIndent(indent + 1) << "]" << std::endl; - oss << GenIndent(indent + 1) << "tensors: [" << std::endl; - for (auto &tensor : tensors_) { - oss << DumpTensor(tensor, indent + kNumTwo) << std::endl; - } - oss << GenIndent(indent + 1) << "]" << std::endl; - oss << GenIndent(indent) << "}" << std::endl; - return oss.str(); -} -} // namespace lite -} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/graph_compiler/compile_result.h b/mindspore-lite/src/extendrt/graph_compiler/compile_result.h deleted file mode 100644 index d7a701b780e9895baffc336caca0fc286e31f1b8..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/graph_compiler/compile_result.h +++ /dev/null @@ -1,126 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_EXTENDRT_GRAPH_COMPILER_COMPILE_RESULT_H_ -#define MINDSPORE_LITE_EXTENDRT_GRAPH_COMPILER_COMPILE_RESULT_H_ -#include -#include -#include -#include -#include -#include -#include "ir/anf.h" -#include "src/infer/tensor.h" -#include "include/model.h" -#include "ops/base_operator.h" -#include "utils/hash_map.h" -#include "include/api/status.h" -#include "common/common_utils.h" -#include "src/infer/primitive_type.h" - -namespace mindspore { -namespace lite { -class CompileNode { - public: - explicit CompileNode(std::string name, const kernel::PrimitiveType &type) : name_(std::move(name)), type_(type) {} - static std::shared_ptr Create(CNodePtr cnode); - - virtual ~CompileNode() = default; - - std::string GetName() const { return name_; } - kernel::PrimitiveType GetType() const { return type_; } - std::shared_ptr GetBaseOperator() const { return base_operator_; } - CNodePtr GetCNode() const { return cnode_; } - const std::vector &GetInputs() const { return inputs_; } - InferTensor *GetInput(size_t i) const { return inputs_.at(i); } - size_t InputSize() const { return inputs_.size(); } - const std::vector &GetOutputs() const { return outputs_; } - InferTensor *GetOutput(size_t i) const { return outputs_.at(i); } - size_t OutputSize() const { return outputs_.size(); } - - void SetName(const std::string &name) { name_ = name; } - void AppendInputTensor(InferTensor *tensor); - void AppendOutputTensor(InferTensor *tensor); - void ReplaceInputTensor(InferTensor *dst, const InferTensor *src); - kernel::KernelAttr GetKernelAttr() const; - std::string Dump(int indent = 0) const; - - private: - std::string name_{}; - kernel::PrimitiveType type_{}; - std::shared_ptr base_operator_{nullptr}; - CNodePtr cnode_{nullptr}; - std::vector inputs_{}; - std::vector outputs_{}; -}; -using CompileNodePtr = std::shared_ptr; - -class CompileResult { - public: - CompileResult() = default; - virtual ~CompileResult() = default; - - CompileNodePtr GetNode(const std::string &name); - CompileNodePtr GetArgNode(const std::string &name); - const std::vector &GetNodes() const { return nodes_; } - size_t NodeSize() const { return nodes_.size(); } - const std::vector &GetTensors() const { return tensors_; } - size_t TensorSize() const { return tensors_.size(); } - const std::vector &GetInputs() const { return inputs_; } - InferTensor *GetInput(size_t i) const { return inputs_.at(i); } - size_t InputSize() const { return inputs_.size(); } - const std::vector &GetOutputs() const { return outputs_; } - InferTensor *GetOutput(size_t i) const { return outputs_.at(i); } - size_t OutputSize() const { return outputs_.size(); } - const std::vector &GetParamNodes() const { return param_nodes_; } - const std::vector &GetReturnNodes() const { return return_nodes_; } - - std::vector &GetMutableNodes(); - std::vector &GetMutableInputs(); - std::vector &GetMutableOutputs(); - StatusCode AppendNode(CompileNodePtr node); - StatusCode AppendArgNode(CompileNodePtr node); - StatusCode AppendTensor(InferTensor *tensor); - StatusCode AppendInputTensor(InferTensor *tensor, bool is_borrow = false); - StatusCode AppendOutputTensor(InferTensor *tensor, bool is_borrow = false); - - StatusCode AppendNodeInputTensor(const CompileNodePtr &compile_node, InferTensor *tensor, bool is_borrow = false); - StatusCode AppendNodeInputTensor(const std::string &node_name, InferTensor *tensor, bool is_borrow = false); - StatusCode AppendNodeOutputTensor(const CompileNodePtr &compile_node, InferTensor *tensor, bool is_borrow = false); - StatusCode AppendNodeOutputTensor(const std::string &node_name, InferTensor *tensor, bool is_borrow = false); - - void Assemble() { this->assembled_ = true; } - - std::string Dump(int indent = 0) const; - - private: - bool assembled_ = false; - std::vector nodes_{}; - std::vector tensors_{}; - std::vector inputs_{}; - std::vector outputs_{}; - HashMap node_map_{}; - HashMap tensor_map_{}; - std::vector param_nodes_{}; - std::vector return_nodes_{}; - std::vector arg_nodes_{}; - HashMap arg_node_map_{}; -}; -using CompileResultPtr = std::shared_ptr; -} // namespace lite -} // namespace mindspore - -#endif diff --git a/mindspore-lite/src/extendrt/graph_compiler/compile_result_builder.cc b/mindspore-lite/src/extendrt/graph_compiler/compile_result_builder.cc deleted file mode 100644 index 35bc93903564f4ea7a3445fc252057266878e97a..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/graph_compiler/compile_result_builder.cc +++ /dev/null @@ -1,479 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/graph_compiler/compile_result_builder.h" -#include -#include "mindspore/ops/op_def/structure_ops.h" -#include "mindspore/ops/op_def/sequence_ops.h" -#include "mindspore/ops/op_def/framework_ops.h" -#include "src/extendrt/graph_compiler/anfnode_tensor_adapter.h" -#include "ir/anf.h" -#include "ir/func_graph.h" -#include "ir/primitive.h" -#include "mindspore/ops/op_def/op_name.h" -#include "ops/primitive_c.h" -#include "src/extendrt/utils/func_graph_utils.h" - -using AbstractBasePtr = mindspore::abstract::AbstractBasePtr; -using AbstractTensorPtr = mindspore::abstract::AbstractTensorPtr; -using AbstractSequencePtr = mindspore::abstract::AbstractSequencePtr; - -namespace mindspore { -namespace lite { -StatusCode CompileResultBuilder::BuildInputs(const AnfNodePtrList &inputs) { - MS_ASSERT(graph_ != nullptr); - if (graph_->InputSize() > 0) { - MS_LOG(ERROR) << "Please don't call BuildInputs twice."; - return kLiteError; - } - for (auto &input : inputs) { - auto results = TensorAdapter::CreateTensorsFromAbstract(input->abstract(), compile_option_->graph_input_format); - if (results.empty()) { - MS_LOG(ERROR) << "Create tensors from abstract of segments input failed, input : " - << input->fullname_with_scope(); - return kLiteError; - } - auto arg_node = std::make_shared(input->fullname_with_scope(), kernel::PrimitiveType()); - auto ret = graph_->AppendArgNode(arg_node); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Append input lite-node to graph failed, input : " << input->fullname_with_scope(); - return ret; - } - for (auto &result : results) { - auto tensor = result.release(); - arg_node->AppendOutputTensor(tensor); - ret = graph_->AppendInputTensor(tensor); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Append output tensor to argument node failed, node: " << input->fullname_with_scope(); - delete (tensor); - return ret; - } - } - } - return kSuccess; -} - -StatusCode CompileResultBuilder::BuildNodes(const std::vector &nodes) { - MS_ASSERT(graph_ != nullptr); - if (graph_->NodeSize() > 0) { - MS_LOG(ERROR) << "Please don't call BuildNodes twice."; - return kLiteError; - } - - for (auto &node : nodes) { - if (!utils::isa(node)) { - continue; - } - auto ret = CreateAndAppendNode(utils::cast(node)); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Create compile node from cnode failed : " << node; - return ret; - } - } - return kSuccess; -} - -StatusCode CompileResultBuilder::BuildNodes(const GraphSegmentPtr &graph_segment) { - return BuildNodes(graph_segment->nodes_); -} - -StatusCode CompileResultBuilder::BuildOutputs(const AnfNodePtrList &outputs) { - MS_ASSERT(graph_ != nullptr); - if (graph_->OutputSize() > 0) { - MS_LOG(ERROR) << "Please don't call BuildOutputs twice."; - return kLiteError; - } - for (auto &output : outputs) { - auto out_cnode = utils::cast(output); - if (out_cnode == nullptr) { - MS_LOG(ERROR) << "Outputs should be a CNode vector, but got " << output->Type() << " type element."; - return kLiteError; - } - auto compile_node = graph_->GetNode(out_cnode->fullname_with_scope()); - if (compile_node == nullptr) { - continue; - } - for (auto &tensor : compile_node->GetOutputs()) { - auto ret = graph_->AppendOutputTensor(tensor, true); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Append output tensor to graph failed, output: " << out_cnode->fullname_with_scope(); - return ret; - } - } - } - return kSuccess; -} - -// Replace `dst_tensor` with `src_tensor`. -void CompileResultBuilder::ReplaceTensor(InferTensor *dst_tensor, const InferTensor *src_tensor) { - // used as inputs of other node - auto &nodes = graph_->GetMutableNodes(); - for (auto &compile_node : nodes) { - if (compile_node == nullptr) { - continue; - } - compile_node->ReplaceInputTensor(dst_tensor, src_tensor); - } - // used as outputs of graph - auto &outputs = graph_->GetMutableOutputs(); - std::replace_if( - outputs.begin(), outputs.end(), [&src_tensor](InferTensor *ele) { return ele == src_tensor; }, dst_tensor); -} - -StatusCode CompileResultBuilder::RemoveMakeSeqNode() { - auto &nodes = graph_->GetMutableNodes(); - for (auto iter = nodes.begin(); iter != nodes.end();) { - auto &node = *iter; - if (node->GetType() != kMakeTupleOpName && node->GetType() != kMakeListOpName) { - iter++; - continue; - } - MS_LOG(INFO) << "Handling make sequence node: " << node->GetName(); - auto tensor_number = node->InputSize(); - if (tensor_number != node->OutputSize()) { - MS_LOG(ERROR) << "MakeSequence node should has same number of inputs and outputs, but got " << tensor_number - << " inputs and " << node->OutputSize() << " outputs."; - return kLiteError; - } - for (size_t i = 0; i < tensor_number; i++) { - ReplaceTensor(node->GetInput(i), node->GetOutput(i)); - } - iter = nodes.erase(iter); - } - return kSuccess; -} - -StatusCode CompileResultBuilder::RemoveDependNode() { - auto &nodes = graph_->GetMutableNodes(); - for (auto iter = nodes.begin(); iter != nodes.end();) { - auto &node = *iter; - if (node->GetType() != kDependOpName) { - iter++; - continue; - } - MS_LOG(INFO) << "Handling Depend node: " << node->GetName(); - constexpr int kSize2 = 2; - if (node->InputSize() != kSize2) { - MS_LOG(ERROR) << "Depend node should has 2 inputs, but got " << node->InputSize(); - return kLiteError; - } - if (node->OutputSize() != 1) { - MS_LOG(ERROR) << "Depend node should has 1 outputs, but got " << node->OutputSize(); - return kLiteError; - } - ReplaceTensor(node->GetInput(0), node->GetOutput(0)); - iter = nodes.erase(iter); - } - return kSuccess; -} - -StatusCode CompileResultBuilder::RemoveSeqGetItemNode() { - auto &nodes = graph_->GetMutableNodes(); - for (auto iter = nodes.begin(); iter != nodes.end();) { - auto &node = *iter; - if (node->GetType() != kTupleGetItemOpName && node->GetType() != kListGetItemOpName && - node->GetType() != "array_getitem" && node->GetType() != kSliceGetItemOpName) { - iter++; - continue; - } - MS_LOG(DEBUG) << "Handling GetItem node: " << node->GetName(); - if (node->OutputSize() != 1) { - MS_LOG(ERROR) << "GetItem node should has 1 outputs, but got " << node->OutputSize(); - return kLiteError; - } - auto index_tensor = node->GetInput(node->GetInputs().size() - 1); - if (index_tensor->data() == nullptr) { - MS_LOG(ERROR) << "`index_tensor` of GetItem should be a const tensor, but has no data."; - return kLiteError; - } - if (index_tensor->data_type() == kNumberTypeInt32) { - auto idx = reinterpret_cast(index_tensor->data())[0]; - ReplaceTensor(node->GetInput(idx), node->GetOutput(0)); - } else if (index_tensor->data_type() == kNumberTypeInt64) { - auto idx = reinterpret_cast(index_tensor->data())[0]; - ReplaceTensor(node->GetInput(idx), node->GetOutput(0)); - } else { - MS_LOG(ERROR) << "`index_tensor` of GetItem should be a const tensor with int data type, but got " - << index_tensor->data_type(); - return kLiteError; - } - iter = nodes.erase(iter); - } - return kSuccess; -} - -StatusCode CompileResultBuilder::OptimizeGraph() { - MS_ASSERT(graph_ != nullptr); - auto ret = RemoveDependNode(); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Handle Depend node failed"; - return ret; - } - ret = RemoveMakeSeqNode(); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Handle Make Sequence node failed"; - return ret; - } - ret = RemoveSeqGetItemNode(); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Handle Sequence-Getitem node failed"; - return ret; - } - return kSuccess; -} - -CompileResultPtr CompileResultBuilder::Build(const GraphSegmentPtr &graph_segment, const AnfNodePtrList &inputs, - const AnfNodePtrList &outputs) { - graph_ = std::make_shared(); - if (BuildInputs(inputs) != kSuccess) { - MS_LOG(ERROR) << "Build graph inputs failed"; - return nullptr; - } - if (BuildNodes(graph_segment) != kSuccess) { - MS_LOG(ERROR) << "Build graph nodes failed"; - return nullptr; - } - if (BuildOutputs(outputs) != kSuccess) { - MS_LOG(ERROR) << "Build graph outputs failed"; - return nullptr; - } - if (OptimizeGraph() != kSuccess) { - MS_LOG(ERROR) << "Optimize graph failed"; - return nullptr; - } - graph_->Assemble(); - return graph_; -} - -StatusCode CompileResultBuilder::AppendInputCNodeToInputs(const CNodePtr &cnode, const CompileNodePtr &compile_node) { - if (cnode == nullptr) { - MS_LOG(ERROR) << "Input cnode is nullptr."; - return kLiteInputParamInvalid; - } - if (compile_node == nullptr) { - MS_LOG(ERROR) << "Input compile_node is nullptr."; - return kLiteInputParamInvalid; - } - auto input_node = graph_->GetNode(cnode->fullname_with_scope()); - if (input_node == nullptr) { - input_node = graph_->GetArgNode(cnode->fullname_with_scope()); - } - if (input_node == nullptr) { - MS_LOG(ERROR) << "Can not find input lite-node in graph, node: " << cnode->fullname_with_scope(); - return kLiteError; - } - for (auto &input_node_output : input_node->GetOutputs()) { - auto ret = graph_->AppendNodeInputTensor(compile_node, input_node_output, true); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Append input tensor for node failed, node: " << compile_node->GetName(); - return ret; - } - } - return kSuccess; -} - -StatusCode CompileResultBuilder::AppendInputParameterToInputs(const ParameterPtr ¶m_node, - const CompileNodePtr &compile_node) { - if (param_node == nullptr) { - MS_LOG(ERROR) << "Input param_node is nullptr."; - return kLiteInputParamInvalid; - } - if (compile_node == nullptr) { - MS_LOG(ERROR) << "Input compile_node is nullptr."; - return kLiteInputParamInvalid; - } - auto arg_node = graph_->GetArgNode(param_node->fullname_with_scope()); - if (arg_node != nullptr) { - for (auto &output : arg_node->GetOutputs()) { - auto ret = graph_->AppendNodeInputTensor(compile_node, output, true); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Append input tensor for node failed, node: " << compile_node->GetName(); - return ret; - } - } - return kSuccess; - } - auto tensor_from_param = TensorAdapter::Convert2Tensor(param_node); - if (tensor_from_param == nullptr) { - MS_LOG(ERROR) << "Create tensor from Parameter failed."; - return kLiteError; - } - auto format_value = compile_node->GetBaseOperator()->GetAttr(mindspore::ops::kFormat); - if (format_value != nullptr) { - tensor_from_param->set_format(static_cast(GetValue(format_value))); - } else { - tensor_from_param->set_format(compile_option_->graph_format); - } - auto ret = graph_->AppendNodeInputTensor(compile_node, tensor_from_param); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Append input tensor for node failed, node: " << compile_node->GetName(); - delete tensor_from_param; - return ret; - } - return kSuccess; -} - -StatusCode CompileResultBuilder::AppendInputValueNodeToInputs(const ValueNodePtr &value_node, - const CompileNodePtr &compile_node) { - if (value_node == nullptr) { - MS_LOG(ERROR) << "Input value_node is nullptr."; - return kLiteInputParamInvalid; - } - if (compile_node == nullptr) { - MS_LOG(ERROR) << "Input compile_node is nullptr."; - return kLiteInputParamInvalid; - } - if (value_node->value() != nullptr && value_node->value()->isa()) { - MS_LOG(WARNING) << "Skip Monad value node: " << value_node->fullname_with_scope(); - return kSuccess; - } - auto tensor_from_value = TensorAdapter::Convert2Tensor(value_node); - if (tensor_from_value == nullptr) { - MS_LOG(ERROR) << "Create tensor from ValueNode failed."; - return kLiteError; - } - auto format_value = compile_node->GetBaseOperator()->GetAttr(mindspore::ops::kFormat); - if (format_value != nullptr) { - tensor_from_value->set_format(static_cast(GetValue(format_value))); - } else { - tensor_from_value->set_format(compile_option_->graph_format); - } - auto ret = graph_->AppendNodeInputTensor(compile_node, tensor_from_value); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Append input tensor for node failed, node: " << compile_node->GetName(); - delete tensor_from_value; - return ret; - } - return kSuccess; -} - -StatusCode CompileResultBuilder::CreateAndAppendNode(const CNodePtr &cnode) { - auto compile_node = CompileNode::Create(cnode); - if (compile_node == nullptr) { - MS_LOG(ERROR) << "Create compile node failed, cnode: " << cnode->fullname_with_scope(); - return kLiteError; - } - auto ret = graph_->AppendNode(compile_node); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Append compile_node to graph failed, node: " << compile_node->GetName(); - return ret; - } - // inputs - for (size_t i = 1; i < cnode->size(); i++) { - auto &input = cnode->input(i); - if (utils::isa(input)) { - ret = this->AppendInputCNodeToInputs(utils::cast(input), compile_node); - } else if (utils::isa(input)) { - ret = this->AppendInputParameterToInputs(utils::cast(input), compile_node); - } else if (utils::isa(input)) { - ret = this->AppendInputValueNodeToInputs(utils::cast(input), compile_node); - } else { - MS_LOG(ERROR) << "Unsupported input node of cnode: " << input - << ", current cnode: " << cnode->fullname_with_scope(); - ret = kLiteNotSupport; - } - if (ret != kSuccess) { - MS_LOG(ERROR) << "Create input tensor for cnode failed, cnode: " << cnode->fullname_with_scope(); - return ret; - } - } - // outputs - ret = BuildNodeOutputTensor(cnode, compile_node); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Create output tensors of cnode failed, cnode: " << cnode; - return ret; - } - return kSuccess; -} - -StatusCode CompileResultBuilder::BuildNodeOutputTensor(const CNodePtr &cnode, const CompileNodePtr &compile_node) { - if (compile_node == nullptr) { - MS_LOG(ERROR) << "Input compile_node is nullptr."; - return kLiteInputParamInvalid; - } - if (compile_node->OutputSize() > 0) { - MS_LOG(ERROR) << "Build node output twice, node : " << compile_node->GetName(); - return kLiteError; - } - auto results = TensorAdapter::Convert2Tensor(cnode); - if (results.empty()) { - MS_LOG(ERROR) << "Create tensors from cnode failed, cnode : " << cnode->fullname_with_scope(); - return kLiteError; - } - size_t index = 0; - auto ret = kSuccess; - for (; index < results.size(); index++) { - auto tensor = results[index]; - ret = graph_->AppendNodeOutputTensor(compile_node, tensor); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Append output tensor to node failed, node: " << compile_node->GetName(); - break; - } - } - // release results if failed - for (; index < results.size(); index++) { - delete results[index]; - } - return ret; -} - -StatusCode CompileResultBuilder::BuildNodes(const FuncGraphPtr &func_graph) { - MS_ASSERT(func_graph != nullptr); - auto nodes = func_graph->TopoSort(func_graph->get_return()); - if (nodes.empty()) { - MS_LOG(ERROR) << "There are no nodes in the graph"; - return kLiteError; - } - - return BuildNodes(nodes); -} - -CompileResultPtr CompileResultBuilder::Build(const FuncGraphPtr &func_graph) { - graph_ = std::make_shared(); - - if (BuildInputs(func_graph->get_inputs()) != kSuccess) { - MS_LOG(ERROR) << "Build graph inputs failed"; - return nullptr; - } - if (BuildNodes(func_graph) != kSuccess) { - MS_LOG(ERROR) << "Build graph nodes failed"; - return nullptr; - } - - std::vector outputs_with_index; - FuncGraphUtils::GetFuncGraphOutputs(func_graph, &outputs_with_index); - AnfNodePtrList outputs; - outputs.resize(outputs_with_index.size()); - for (auto &output : outputs_with_index) { - if (output.second >= outputs.size()) { - MS_LOG(ERROR) << "Build graph nodes failed"; - return nullptr; - } - outputs[output.second] = output.first; - } - if (BuildOutputs(outputs) != kSuccess) { - MS_LOG(ERROR) << "Build graph outputs failed"; - return nullptr; - } - if (OptimizeGraph() != kSuccess) { - MS_LOG(ERROR) << "Optimize graph failed"; - return nullptr; - } - graph_->Assemble(); - return graph_; -} -} // namespace lite -} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/graph_compiler/compile_result_builder.h b/mindspore-lite/src/extendrt/graph_compiler/compile_result_builder.h deleted file mode 100644 index 85ab33f149ccf422478642b65f13ad7b18909cc3..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/graph_compiler/compile_result_builder.h +++ /dev/null @@ -1,74 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_EXTENDRT_GRAPH_COMPILER_COMPILE_RESULT_BUILDER_H_ -#define MINDSPORE_LITE_EXTENDRT_GRAPH_COMPILER_COMPILE_RESULT_BUILDER_H_ -#include -#include -#include -#include -#include -#include -#include "src/extendrt/graph_compiler/compile_result.h" -#include "src/extendrt/graph_compiler/compile_option.h" -#include "src/infer/tensor.h" -#include "abstract/abstract_value.h" -#include "ir/anf.h" -#include "include/api/status.h" - -namespace mindspore { -namespace lite { -class CompileResultBuilder { - public: - explicit CompileResultBuilder(CompileOptionPtr option) : compile_option_(std::move(option)) { - MS_EXCEPTION_IF_NULL(compile_option_); - } - ~CompileResultBuilder() = default; - CompileResultPtr Build(const GraphSegmentPtr &graph_segment, const AnfNodePtrList &inputs, - const AnfNodePtrList &outputs); - CompileResultPtr Build(const FuncGraphPtr &func_graph); - - private: - // build - StatusCode BuildInputs(const AnfNodePtrList &inputs); - StatusCode BuildNodes(const GraphSegmentPtr &graph_segment); - StatusCode BuildNodes(const std::vector &nodes); - StatusCode BuildNodes(const FuncGraphPtr &func_graph); - StatusCode BuildOutputs(const AnfNodePtrList &outputs); - StatusCode OptimizeGraph(); - // methods about node - StatusCode CreateAndAppendNode(const CNodePtr &cnode); - StatusCode AppendInputCNodeToInputs(const CNodePtr &cnode, const CompileNodePtr &compile_node); - StatusCode AppendInputParameterToInputs(const ParameterPtr ¶m_node, const CompileNodePtr &compile_node); - StatusCode AppendInputValueNodeToInputs(const ValueNodePtr &value_node, const CompileNodePtr &compile_node); - // methods about tensor - StatusCode BuildNodeOutputTensor(const CNodePtr &cnode, const CompileNodePtr &compile_node); - // methods about optimize - StatusCode RemoveSeqGetItemNode(); - StatusCode RemoveMakeSeqNode(); - StatusCode RemoveDependNode(); - // Replace `dst_tensor` with `src_tensor`. - void ReplaceTensor(InferTensor *dst_tensor, const InferTensor *src_tensor); - - private: - CompileResultPtr graph_ = nullptr; - CompileOptionPtr compile_option_{nullptr}; - std::set input_names_{}; -}; -} // namespace lite -} // namespace mindspore - -#endif diff --git a/mindspore-lite/src/extendrt/graph_compiler/default_graph_compiler.cc b/mindspore-lite/src/extendrt/graph_compiler/default_graph_compiler.cc deleted file mode 100644 index 508d563841e204f717a366c8e3469313b56ea583..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/graph_compiler/default_graph_compiler.cc +++ /dev/null @@ -1,440 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include "abstract/abstract_value.h" -#include "backend/graph_compiler/graph_partition.h" -#include "base/base_ref.h" -#include "extendrt/execution_plan.h" -#include "extendrt/graph_compiler/anfnode_tensor_adapter.h" -#include "extendrt/graph_compiler/default_graph_compiler.h" -#include "extendrt/graph_compiler/factory.h" -#include "extendrt/mock/lite_runtime/converters.h" -#include "extendrt/utils/func_graph_utils.h" -#include "ir/manager.h" -#include "mindspore/ops/op_def/framework_ops.h" -#include "mindspore/ops/op_def/nn_ops.h" -#include "mindspore/ops/op_def/sequence_ops.h" -#include "mindspore/ops/op_def/op_name.h" -#include "src/extendrt/graph_compiler/compile_result_builder.h" -#include "tools/optimizer/common/gllo_utils.h" -#include "src/common/common.h" -#include "extendrt/delegate/factory.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_b.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_p.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" - -namespace mindspore::lite { -static const std::vector ms_infer_cut_list = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch, - prim::kPrimBpropCut, prim::kPrimSwitchLayer}; -static constexpr auto ms_infer_backend_name = "mindspore_lite_backend"; - -void DefaultGraphCompiler::InitCompileOption(const FuncGraphPtr &graph) { - if (option_ != nullptr) { - MS_LOG(INFO) << "CompileOption is already inited."; - return; - } - option_ = std::make_shared(); - auto format_value = graph->get_attr(mindspore::ops::kFormat); - if (format_value != nullptr) { - option_->graph_format = Format(GetValue(format_value)); - } - auto input_format_value = graph->get_attr(kInputFormat); - if (input_format_value != nullptr) { - option_->graph_input_format = Format(GetValue(input_format_value)); - } - - if (inner_context_->IsDeviceTypeEnabled(lite::DT_ASCEND)) { - option_->backend = kernel::kBackendAscend; - } -} - -void DefaultGraphCompiler::ReplaceNodes(const std::shared_ptr &graph) { - const ConfigInfos config_infos; - auto &device_contexts = context_->MutableDeviceInfo(); - if (device_contexts.empty()) { - MS_LOG(ERROR) << "No context found"; - } - auto device_type = device_contexts.at(0)->GetDeviceType(); - auto provider = device_contexts.at(0)->GetProvider(); - auto delegate = - DelegateRegistry::GetInstance().GetDelegate(device_type, provider, context_, config_infos); - if (delegate != nullptr) { - delegate->ReplaceNodes(graph); - } -} - -std::shared_ptr DefaultGraphCompiler::Compile(FuncGraphPtr graph) { - inner_context_ = ContextUtils::Convert(context_.get()); - if (inner_context_ == nullptr || inner_context_->Init() != RET_OK) { - MS_LOG(ERROR) << "Init inner context failed"; - return nullptr; - } - - InitCompileOption(graph); - - ReplaceNodes(graph); - - MS_LOG(DEBUG) << "Partition graph begin"; - auto graph_segments = Partition(graph); - if (graph_segments.empty()) { - MS_LOG(ERROR) << "Partition graph failed"; - return nullptr; - } - MS_LOG(DEBUG) << "Partition graph end"; - - MS_LOG(DEBUG) << "Schedule graph to execute plan begin"; - auto execution_plan = NonCFGCompile(graph_segments, graph); - if (execution_plan == nullptr) { - MS_LOG(ERROR) << "Schedule graph failed"; - return nullptr; - } - MS_LOG(DEBUG) << "Schedule graph to execute plan end"; - - return execution_plan; -} - -std::vector DefaultGraphCompiler::Partition(const FuncGraphPtr &graph) { - auto partition = std::make_shared(ms_infer_cut_list, ms_infer_backend_name); - if (partition == nullptr) { - MS_LOG(ERROR) << "Create graph partition failed, maybe not enough memory"; - return {}; - } - - // multi_target set false - bool is_multi_target; - return partition->Partition(graph, &is_multi_target); -} - -CompileResultPtr DefaultGraphCompiler::Compile(const GraphSegmentPtr &segment, const std::vector &inputs, - const std::vector &outputs) { - auto builder = std::make_shared(option_); - return builder->Build(segment, inputs, outputs); -} - -std::vector DefaultGraphCompiler::Schedule(const CompileResultPtr &compile_result) { - if (MS_UNLIKELY(scheduler_ == nullptr)) { - scheduler_ = std::make_shared(this->inner_context_, context_, option_); - } - return {scheduler_->Schedule(compile_result)}; -} - -std::vector DefaultGraphCompiler::SkipMakeTuple(const AnfNodePtr &origin_node) { - if (!origin_node->isa()) { - // not cnode, return origin node - return {origin_node}; - } - auto cnode = origin_node->cast(); - MS_ASSERT(cnode != nullptr); - - if (!IsPrimitive(cnode->input(0), prim::kPrimMakeTuple)) { - return {origin_node}; - } - - std::vector results; - for (size_t i = 1; i < cnode->size(); i++) { - auto real_nodes = SkipMakeTuple(cnode->input(i)); - results.insert(results.end(), real_nodes.begin(), real_nodes.end()); - } - return results; -} - -Status DefaultGraphCompiler::UpdateSubGraphInoutMap(const kernel::KernelExec &subgraph, const AnfNodePtrList &inputs, - const AnfNodePtrList &outputs) { - // add subgraph_input_map: subgraph.input-tensors --> anfnode - auto count = inputs.size(); - // not support tuple as an input of cnode now except return and tuplegetitem, but return is skipped before and - // tuplegetitem is not a cut point. - if (MS_UNLIKELY(count != subgraph.in_tensors().size())) { - MS_LOG(ERROR) << "Subgraph has " << subgraph.in_tensors().size() << " inputs while segment has " << count - << " inputs."; - return kLiteError; - } - for (size_t i = 0; i < count; i++) { - subgraph_input_map_[subgraph.in_tensors()[i]] = inputs[i]; - } - // add subgraph_output_map: anfnode --> subgraph.output-tensors - count = outputs.size(); - // not support tuple as an input of cnode now except return and tuplegetitem, but return is skipped before and - // tuplegetitem is not a cut point. - if (MS_UNLIKELY(count != subgraph.out_tensors().size())) { - MS_LOG(ERROR) << "Subgraph has " << subgraph.out_tensors().size() << " outputs while segment has " << count - << " outputs."; - return kLiteError; - } - for (size_t i = 0; i < count; i++) { - subgraph_output_map_[outputs[i]] = subgraph.out_tensors()[i]; - } - return kSuccess; -} - -std::tuple DefaultGraphCompiler::GetSegmentInout(const GraphSegment &graph_segment) { - FuncGraphPtr fg = nullptr; - AnfNodePtrList inputs; - AnfNodePtrList outputs; - std::tie(fg, inputs, outputs) = FuncGraphUtils::TransformSegmentToAnfGraph(graph_segment.nodes_); - // TransformSegmentToAnfGraph puts all input and weight into 'inputs'. In inference, we erase weight. - for (auto iter = inputs.begin(); iter != inputs.end();) { - if (utils::isa(*iter) && (utils::cast(*iter))->has_default()) { - iter = inputs.erase(iter); - } else { - iter++; - } - } - // maketuple and tuplegetitem make nosense in inference, skip nodes with these types for outputs - AnfNodePtrList real_outputs; - real_outputs.reserve(outputs.size()); - for (auto &output : outputs) { - std::vector seg_outputs = SkipMakeTuple(output); - real_outputs.insert(real_outputs.end(), seg_outputs.begin(), seg_outputs.end()); - } - return std::make_tuple(inputs, real_outputs); -} - -Status DefaultGraphCompiler::CreateExecPlanKernels(const std::vector &graph_segments, - std::vector *segments_outputs) { - MS_ASSERT(execution_plan_ != nullptr); - for (const auto &graph_segment : graph_segments) { - if (graph_segment == nullptr) { - MS_LOG(ERROR) << "Graph segment is nullptr"; - return kLiteNullptr; - } - if (graph_segment->nodes_.size() == 1) { - auto &node = graph_segment->nodes_[0]; - if (opt::CheckPrimitiveType(node, prim::kPrimReturn)) { - continue; - } - } - AnfNodePtrList inputs; - AnfNodePtrList outputs; - std::tie(inputs, outputs) = GetSegmentInout(*graph_segment); - // maketuple tuplegetitem is deleted inside of Compile - auto compile_result = this->Compile(graph_segment, inputs, outputs); - if (compile_result == nullptr) { - MS_LOG(ERROR) << "Convert to CompileResult failed"; - return kLiteError; - } - auto kernels = this->Schedule(compile_result); - if (kernels.size() != 1) { - MS_LOG(ERROR) << "Only support one subgraph from one graph segment now, got " << kernels.size(); - return kLiteError; - } - auto kernel = kernels[0]; - if (kernel == nullptr) { - MS_LOG(ERROR) << "Schedule failed, return nullptr."; - return kLiteError; - } - auto ret = UpdateSubGraphInoutMap(*kernel, inputs, outputs); - if (ret != kSuccess) { - MS_LOG(ERROR) << "UpdateSubGraphInoutMap failed: " << ret; - return ret; - } - segments_outputs->emplace_back(outputs); - execution_plan_->AddKernel(kernel); - } - return kSuccess; -} - -Status DefaultGraphCompiler::CreateExecPlanInputs(const FuncGraphPtr &func_graph) { - MS_ASSERT(graph_input_tensors_.empty()); - auto graph_inputs = func_graph->get_inputs(); - if (graph_inputs.empty()) { - MS_LOG(ERROR) << "Get graph inputs node failed"; - return kLiteError; - } - for (const auto &input : func_graph->get_inputs()) { - if (!utils::isa(input)) { - MS_LOG(ERROR) << "Not supported graph input: " << input; - return kLiteError; - } - auto parameter = utils::cast(input); - auto tensor = TensorAdapter::Convert2Tensor(parameter, option_->graph_input_format); - if (tensor == nullptr) { - MS_LOG(ERROR) << "Create graph input tensor failed, input : " << input->fullname_with_scope(); - return kLiteError; - } - tensor->set_category(GRAPH_INPUT); - graph_input_tensors_.push_back(tensor); - subgraph_output_map_[input] = tensor; - } - execution_plan_->SetInputs(graph_input_tensors_); - return kSuccess; -} - -Status DefaultGraphCompiler::CreateExecPlanOutputs(const FuncGraphPtr &func_graph, - const std::vector &segments_outputs) { - MS_ASSERT(execution_plan_ != nullptr); - anf_tensor_map_.clear(); - auto graph_output = func_graph->output(); - if (graph_output == nullptr) { - MS_LOG(ERROR) << "Get graph output node failed"; - return kLiteError; - } - std::vector graph_outputs = SkipMakeTuple(graph_output); - - auto graph_output_tensors = DefaultGraphCompiler::CreateTensors(graph_outputs); - if (graph_output_tensors.size() != graph_outputs.size()) { - MS_LOG(ERROR) << "Create graph output tensor failed"; - return kLiteError; - } - for (size_t i = 0; i < graph_outputs.size(); i++) { - auto output_node = graph_outputs[i]; - auto output_tensor = graph_output_tensors[i]; - auto it = anf_tensor_map_.find(output_node); - if (it != anf_tensor_map_.end()) { - MS_LOG(ERROR) << "Can not find corresponding tensor for graph output node: " - << output_node->fullname_with_scope(); - return kLiteError; - } - anf_tensor_map_[output_node] = output_tensor; - } - execution_plan_->SetOutputs(graph_output_tensors); - - auto *output_isolate_map = new std::unordered_map(); - for (size_t i = 0; i < execution_plan_->GetKernels().size(); i++) { - auto kernel = execution_plan_->GetKernels()[i]; - if (MS_UNLIKELY(kernel->out_tensors().size() != segments_outputs[i].size())) { - MS_LOG(ERROR) << "Subgraph has " << kernel->in_tensors().size() << " outputs while segment has " - << segments_outputs[i].size() << " outputs."; - delete output_isolate_map; - return kLiteError; - } - for (size_t j = 0; j < kernel->out_tensors().size(); j++) { - auto output_tensor = kernel->out_tensors()[j]; - auto &output_node = segments_outputs[i][j]; - auto it = anf_tensor_map_.find(output_node); - if (it != anf_tensor_map_.end()) { - auto outter_tensor = it->second; - (*output_isolate_map)[output_tensor] = outter_tensor; - } - } - } - execution_plan_->SetOutputsMap(output_isolate_map); - return kSuccess; -} - -Status DefaultGraphCompiler::IsolateSubGraphs() { - auto *subgraph_isolate_map = new std::unordered_map(); - for (auto &kernel : execution_plan_->GetKernels()) { - auto &in_tensors = kernel->in_tensors(); - for (size_t i = 0; i < in_tensors.size(); i++) { - auto &input = in_tensors[i]; - auto anf_node = subgraph_input_map_.find(input); - if (anf_node == subgraph_input_map_.end()) { - MS_LOG(ERROR) << "Can not find corresponding anf_node for " << i << "th input of subgraph " << kernel->name(); - delete subgraph_isolate_map; - return kLiteError; - } - auto output = subgraph_output_map_.find(anf_node->second); - if (output == subgraph_output_map_.end()) { - MS_LOG(ERROR) << "Can not find corresponding output tensor for anf_node: " - << anf_node->second->fullname_with_scope(); - delete subgraph_isolate_map; - return kLiteError; - } - (*subgraph_isolate_map)[input] = output->second; - } - } - this->execution_plan_->SetInputsMap(subgraph_isolate_map); - return kSuccess; -} - -std::shared_ptr DefaultGraphCompiler::NonCFGCompile( - const std::vector &graph_segments, const FuncGraphPtr &func_graph) { - execution_plan_ = std::make_shared(); - execution_plan_->SetContext(inner_context_); - - // set func graph manager - auto func_manager = func_graph->manager(); - if (func_manager == nullptr) { - func_manager = Manage(func_graph, true); - func_graph->set_manager(func_manager); - } - std::vector segments_outputs; - auto ret = CreateExecPlanKernels(graph_segments, &segments_outputs); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Create graph subgraphs failed"; - return nullptr; - } - ret = CreateExecPlanInputs(func_graph); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Create graph input tensors failed"; - return nullptr; - } - ret = CreateExecPlanOutputs(func_graph, segments_outputs); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Create graph output tensors failed"; - return nullptr; - } - ret = IsolateSubGraphs(); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Isolate subgraphs failed"; - return nullptr; - } - return execution_plan_; -} - -std::vector DefaultGraphCompiler::CreateTensors(const std::vector &nodes) { - std::vector tensors; - for (const auto &node : nodes) { - if (node->isa()) { - auto cnode = node->cast(); - auto tmp = TensorAdapter::Convert2Tensor(cnode); - if (tmp.empty()) { - MS_LOG(ERROR) << "Create tensors from cnode failed, node : " << node->fullname_with_scope(); - for (auto tensor : tensors) { - delete tensor; - } - return {}; - } - (void)tensors.insert(tensors.cend(), tmp.begin(), tmp.end()); - continue; - } - if (node->isa()) { - auto param_node = node->cast(); - auto tensor = TensorAdapter::Convert2Tensor(param_node); - if (tensor == nullptr) { - MS_LOG(ERROR) << "Create tensors from parameter failed, node : " << node->fullname_with_scope(); - return {}; - } - tensors.emplace_back(tensor); - continue; - } - if (node->isa()) { - auto value_node = node->cast(); - auto tensor = TensorAdapter::Convert2Tensor(value_node); - if (tensor == nullptr) { - MS_LOG(ERROR) << "Create tensors from value node failed, node : " << node->fullname_with_scope(); - return {}; - } - tensors.emplace_back(tensor); - continue; - } - } - return tensors; -} - -static std::shared_ptr DefaultGraphCompilerCreator( - const std::shared_ptr &ctx) { - auto graph_compiler = std::make_shared(ctx); - return graph_compiler; -} -REG_GRAPH_COMPILER(kDefaultCompiler, DefaultGraphCompilerCreator); -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/graph_compiler/default_graph_compiler.h b/mindspore-lite/src/extendrt/graph_compiler/default_graph_compiler.h deleted file mode 100644 index 1c4aa28257d31649c67905875d5bca672171334f..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/graph_compiler/default_graph_compiler.h +++ /dev/null @@ -1,76 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_EXTENDRT_GRAPH_COMPILER_DEFAULT_GRAPH_COMPILER_H_ -#define MINDSPORE_LITE_EXTENDRT_GRAPH_COMPILER_DEFAULT_GRAPH_COMPILER_H_ - -#include -#include -#include -#include "infer/graph_compiler.h" -#include "infer/context.h" -#include "src/extendrt/graph_compiler/compile_result.h" -#include "src/extendrt/graph_compiler/single_graph_scheduler.h" -#include "src/extendrt/graph_compiler/compile_option.h" - -namespace mindspore::lite { -class DefaultGraphCompiler : public infer::abstract::GraphCompiler { - public: - explicit DefaultGraphCompiler(const std::shared_ptr &context) : context_(context) { - inner_context_ = nullptr; - } - ~DefaultGraphCompiler() override = default; - - std::shared_ptr Compile(FuncGraphPtr graph) override; - - protected: - void InitCompileOption(const FuncGraphPtr &graph); - std::shared_ptr NonCFGCompile(const std::vector &graph_segments, - const FuncGraphPtr &func_graph); - - virtual std::vector Partition(const FuncGraphPtr &graph); - - CompileResultPtr Compile(const GraphSegmentPtr &segment, const std::vector &inputs, - const std::vector &outputs); - - std::vector Schedule(const CompileResultPtr &compile_result); - - private: - Status CreateExecPlanKernels(const std::vector &graph_segments, - std::vector *segments_outputs); - Status UpdateSubGraphInoutMap(const kernel::KernelExec &subgraph, const AnfNodePtrList &inputs, - const AnfNodePtrList &outputs); - std::tuple GetSegmentInout(const GraphSegment &graph_segment); - Status CreateExecPlanInputs(const FuncGraphPtr &func_graph); - Status CreateExecPlanOutputs(const FuncGraphPtr &func_graph, const std::vector &segments_outputs); - Status IsolateSubGraphs(); - static std::vector CreateTensors(const std::vector &nodes); - std::vector SkipMakeTuple(const AnfNodePtr &origin_node); - void ReplaceNodes(const std::shared_ptr &graph); - - private: - std::shared_ptr execution_plan_{nullptr}; - std::vector graph_input_tensors_; - mindspore::HashMap anf_tensor_map_; - mindspore::HashMap subgraph_input_map_; - mindspore::HashMap subgraph_output_map_; - SingleGraphSchedulerPtr scheduler_{nullptr}; - const std::shared_ptr &context_; - InferContextPtr inner_context_{nullptr}; - CompileOptionPtr option_{nullptr}; -}; -} // namespace mindspore::lite - -#endif // MINDSPORE_LITE_EXTENDRT_GRAPH_COMPILER_DEFAULT_GRAPH_COMPILER_H_ diff --git a/mindspore-lite/src/extendrt/graph_compiler/factory.h b/mindspore-lite/src/extendrt/graph_compiler/factory.h deleted file mode 100644 index 46ee79ff4f55ff94303c41a80d7b493569948995..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/graph_compiler/factory.h +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Copyright 2019-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_EXTENDRT_GRAPH_COMPILER_FACTORY_H_ -#define MINDSPORE_LITE_EXTENDRT_GRAPH_COMPILER_FACTORY_H_ - -#include -#include - -#include "extendrt/graph_compiler/type.h" -#include "include/api/context.h" -#include "infer/graph_compiler.h" - -namespace mindspore { -using GraphCompilerRegFunc = - std::function(const std::shared_ptr &)>; - -class GraphCompilerRegistry { - public: - GraphCompilerRegistry() = default; - virtual ~GraphCompilerRegistry() = default; - - static GraphCompilerRegistry &GetInstance(); - - void RegCompiler(const mindspore::GraphCompilerType &graph_compiler_type, const GraphCompilerRegFunc &creator); - - std::shared_ptr GetCompiler(const mindspore::GraphCompilerType &type, - const std::shared_ptr &context); - - private: - mindspore::HashMap graph_compiler_map_; -}; - -class GraphCompilerRegistrar { - public: - GraphCompilerRegistrar(const mindspore::GraphCompilerType &graph_compiler_type, const GraphCompilerRegFunc &creator) { - GraphCompilerRegistry::GetInstance().RegCompiler(graph_compiler_type, creator); - } - ~GraphCompilerRegistrar() = default; -}; - -#define REG_GRAPH_COMPILER(type, creator) static GraphCompilerRegistrar g_##type##GraphCompiler(type, creator); -} // namespace mindspore - -#endif // MINDSPORE_LITE_EXTENDRT_GRAPH_COMPILER_FACTORY_H_ diff --git a/mindspore-lite/src/extendrt/graph_compiler/infershape_helper.cc b/mindspore-lite/src/extendrt/graph_compiler/infershape_helper.cc deleted file mode 100644 index f9de7739147c32fc2210a6702a709baaf6bf1a07..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/graph_compiler/infershape_helper.cc +++ /dev/null @@ -1,386 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/graph_compiler/infershape_helper.h" -#include -#include -#include -#include -#include -#include "src/common/ops/operator_populate/operator_populate_register.h" -#include "src/extendrt/graph_compiler/anfnode_tensor_adapter.h" -#include "src/litert/pass/format_pass/format_pass.h" -#include "tools/optimizer/graph/node_infershape.h" -#include "abstract/dshape.h" - -#include "mindspore/ops/op_def/auto_generate/gen_lite_ops.h" -#include "mindspore/ops/op_def/nn_op_name.h" -#include "infer/adam.h" -#include "infer/apply_momentum.h" -#include "infer/batch_to_space.h" -#include "infer/depth_to_space.h" -#include "infer/fused_batch_norm.h" -#include "infer/cxx_api/avg_pool_fusion.h" -#include "infer/cxx_api/conv2d_backprop_input_fusion.h" -#include "infer/cxx_api/conv2d_backprop_filter_fusion.h" -#include "infer/cxx_api/conv2d_fusion.h" -#include "infer/cxx_api/conv2d_transpose_fusion.h" -#include "infer/cxx_api/max_pool_fusion.h" -#include "infer/cxx_api/prelu_fusion.h" -#include "infer/grad/max_pool_grad.h" -#include "infer/grad/resize_grad.h" -#include "infer/instance_norm.h" -#include "infer/lrn.h" -#include "ops_utils/op_utils.h" -#include "infer/resize.h" -#include "infer/roi_pooling.h" -#include "infer/sgd.h" -#include "infer/space_to_batch.h" -#include "infer/space_to_batch_nd.h" -#include "infer/space_to_depth.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_b.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_d.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_f.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_g.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_i.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_p.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" - -namespace mindspore::lite { -namespace { -static const std::set FormatAwareOp = {ops::kNameAdam, - ops::kNameApplyMomentum, - ops::kNameAvgPoolFusion, - ops::kNameAvgPoolGrad, - kBatchNormOpName, - kBatchNormGradOpName, - ops::kNameBatchToSpace, - ops::kNameBiasAdd, - ops::kNameBiasAddGrad, - ops::kNameConv2DBackpropInputFusion, - ops::kNameConv2DBackpropFilterFusion, - ops::kNameConv2DFusion, - ops::kNameConv2dTransposeFusion, - ops::kNameDepthToSpace, - ops::kNameFusedBatchNorm, - ops::kNameGridSampler2D, - ops::kNameInstanceNorm, - ops::kNameLRN, - ops::kNameMaxPoolFusion, - ops::kNameMaxPoolGrad, - ops::kNamePReLUFusion, - ops::kNameResize, - ops::kNameResizeGrad, - ops::kNameROIPooling, - ops::kNameSGD, - ops::kNameSpaceToBatch, - ops::kNameSpaceToBatchND, - ops::kNameSpaceToDepth}; - -constexpr int kNCHW2NHWC = 0; -constexpr int kNHWC2NCHW = 1; -void TransposeShape(InferTensor *tensor, int transpose_type) { - if (MS_UNLIKELY(tensor == nullptr)) { - return; - } - auto shape = tensor->shape(); - constexpr int kNCHWDimSize = 4; - if (shape.size() != kNCHWDimSize) { - return; - } - std::vector new_shape(kNCHWDimSize); - if (transpose_type == kNCHW2NHWC) { - new_shape[kNHWC_N] = shape[kNCHW_N]; - new_shape[kNHWC_H] = shape[kNCHW_H]; - new_shape[kNHWC_W] = shape[kNCHW_W]; - new_shape[kNHWC_C] = shape[kNCHW_C]; - tensor->set_shape(new_shape); - tensor->set_format(NHWC); - return; - } else if (transpose_type == kNHWC2NCHW) { - new_shape[kNCHW_N] = shape[kNHWC_N]; - new_shape[kNCHW_C] = shape[kNHWC_C]; - new_shape[kNCHW_H] = shape[kNHWC_H]; - new_shape[kNCHW_W] = shape[kNHWC_W]; - tensor->set_shape(new_shape); - tensor->set_format(NCHW); - return; - } -} - -void TransposeShape(std::vector *tensors, int transpose_type) { - if (MS_UNLIKELY(tensors == nullptr)) { - return; - } - for (auto *tensor : *tensors) { - TransposeShape(tensor, transpose_type); - } -} - -int SyncInferRetToLiteTensor(const CompileNode &node, const int &infer_ret) { - if (infer_ret == RET_INFER_INVALID) { - for (auto *output : node.GetOutputs()) { - output->set_shape({-1}); - } - } - auto cnode = node.GetCNode(); - MS_ASSERT(cnode != nullptr); - auto abstract = cnode->abstract(); - if (utils::isa(abstract)) { - auto elements = utils::cast(abstract)->elements(); - if (elements.size() != node.OutputSize()) { - MS_LOG(INFO) << "The cnode output size: " << elements.size() - << " is not equal to lite tensors size: " << node.OutputSize(); - return RET_ERROR; - } - for (size_t i = 0; i < elements.size(); i++) { - if (!TensorAdapter::SetDTAndShapeFromAbTensorToLiteTensor(elements[i], node.GetOutput(i))) { - MS_LOG(INFO) << "Sync infershape result from Abstract to InferTensor failed, node : " << node.GetName(); - return RET_ERROR; - } - } - return RET_OK; - } - if (utils::isa(abstract)) { - if (!TensorAdapter::SetDTAndShapeFromAbTensorToLiteTensor(abstract, node.GetOutput(0))) { - MS_LOG(INFO) << "Sync infershape result from Abstract to InferTensor failed, node : " << node.GetName(); - return RET_ERROR; - } - return RET_OK; - } - MS_LOG(INFO) << "Unsupported abstract type: " << abstract; - return RET_ERROR; -} - -int SyncInferRetToCNodeNative(const CompileNode &node) { - auto cnode = node.GetCNode(); - MS_ASSERT(cnode != nullptr); - const auto &outputs = node.GetOutputs(); - if (outputs.empty()) { - return RET_OK; - } - auto abstract = cnode->abstract(); - if (utils::isa(abstract)) { - auto abs_tuple = utils::cast(abstract); - MS_ASSERT(abs_tuple != nullptr); - if (abs_tuple->elements().size() != outputs.size()) { - MS_LOG(INFO) << "Node(" << node.GetName() << ") has " << outputs.size() - << " output tensor(s), but its AbstractTuple has " << abs_tuple->elements().size() << " element(s)."; - return RET_ERROR; - } - for (size_t i = 0; i < outputs.size(); i++) { - if (!TensorAdapter::SetDTAndShapeFromLiteTensorToAbTensor(*outputs[i], abs_tuple->elements()[i])) { - MS_LOG(INFO) << "Sync infershape result from InferTensor to Abstract failed, " << node.GetName(); - return RET_ERROR; - } - } - cnode->set_abstract(abs_tuple); - return RET_OK; - } - if (utils::isa(abstract)) { - if (outputs.size() != 1) { - MS_LOG(INFO) << "Node(" << node.GetName() << ")'s abstract is an AbstractTensor but has " << outputs.size() - << " output tensor(s)."; - return RET_ERROR; - } - auto abs_tensor = utils::cast(abstract); - MS_ASSERT(abs_tensor != nullptr); - if (!TensorAdapter::SetDTAndShapeFromLiteTensorToAbTensor(*outputs[0], abs_tensor)) { - MS_LOG(INFO) << "Sync infershape result from InferTensor to Abstract failed, " << node.GetName(); - return RET_ERROR; - } - cnode->set_abstract(abs_tensor); - return RET_OK; - } - MS_LOG(INFO) << "Unsupported abstract type: " << abstract; - return RET_ERROR; -} - -int SyncInferRetToCNode(const CompileNode &node, const int &infer_ret) { - const auto &outputs = node.GetOutputs(); - if (infer_ret == RET_INFER_INVALID) { - for (auto *output : outputs) { - output->set_shape({abstract::Shape::kShapeRankAny}); - } - } - auto ret = SyncInferRetToCNodeNative(node); - if (infer_ret == RET_INFER_INVALID) { - for (auto *output : outputs) { - output->set_shape({-1}); - } - } - return ret; -} - -int InferShapeByNNACL(const CompileNodePtr &node, OpParameter *op_parameter, Format format, InferContext *context) { - if (format != NHWC && format != NCHW) { - MS_LOG(INFO) << "NNACL infershape only support NCHW or NHWC format, got " << FormatEnumToString(format); - return RET_ERROR; - } - auto inputs = node->GetInputs(); - auto outputs = node->GetOutputs(); - int infer_ret = RET_OK; - for (auto *input : inputs) { - auto shape = input->shape(); - if (std::any_of(shape.begin(), shape.end(), [](const int dim) { return dim < 0; })) { - infer_ret = RET_INFER_INVALID; - break; - } - } - if (infer_ret != RET_INFER_INVALID) { - if (format == NCHW) { - TransposeShape(&inputs, kNCHW2NHWC); - TransposeShape(&outputs, kNCHW2NHWC); - } - infer_ret = KernelInferShape(node->GetInputs(), node->GetOutputs(), op_parameter, context->allocator); - if (format == NCHW) { - TransposeShape(&inputs, kNHWC2NCHW); - TransposeShape(&outputs, kNHWC2NCHW); - } - } - if (infer_ret != RET_OK && infer_ret != RET_INFER_INVALID) { - return infer_ret; - } - auto ret = SyncInferRetToCNode(*node, infer_ret); - if (ret != RET_OK) { - MS_LOG(INFO) << "Sync infershape result from InferTensor to Abstract failed: " << node->GetName(); - return ret; - } - return infer_ret; -} - -int InferShapeByOps(const CompileNodePtr &node, Format format) { - auto node_infer_shape = std::make_shared(); - if (node_infer_shape == nullptr) { - MS_LOG(INFO) << "create NodeInferShape manager failed."; - return false; - } - auto cnode = node->GetCNode(); - auto infer_ret = NodeFallBackInferShape(cnode, format); - if (infer_ret != RET_OK && infer_ret != RET_INFER_INVALID) { - return infer_ret; - } - - auto ret = SyncInferRetToLiteTensor(*node, infer_ret); - if (ret != RET_OK) { - MS_LOG(INFO) << "Sync infershape result from Abstract to InferTensor failed: " << node->GetName(); - return ret; - } - return infer_ret; -} - -inline void DumpInferResult(const CompileNode &node, int infer_ret) { -#ifdef Debug - std::ostringstream oss; - oss << "GraphFallBackInferShape(" << node.GetName() << ") InferShape ret: " << infer_ret << ", shape:"; - bool first_output = true; - for (auto &output : node.GetOutputs()) { - if (first_output) { - first_output = false; - } else { - oss << ", "; - } - oss << ShapeVectorToStr(output->shape()); - } - MS_LOG(INFO) << oss.str(); -#endif -} -} // namespace - -int GraphFallBackInferShape(const FuncGraphPtr &graph, Format format, InferContext *context) { return RET_ERROR; } - -int NodeFallBackInferShape(const CNodePtr &cnode, Format format) { - if (cnode == nullptr) { - MS_LOG(INFO) << "cnode is nullptr"; - return RET_ERROR; - } - auto node_infer_shape = std::make_shared(); - if (node_infer_shape == nullptr) { - MS_LOG(INFO) << "create NodeInferShape manager failed."; - return false; - } - auto anf_prim = GetValueNode>(cnode->input(0)); - if (anf_prim == nullptr) { - MS_LOG(INFO) << "primitive is nullptr"; - return RET_ERROR; - } - (void)anf_prim->AddAttr(ops::kFormat, MakeValue(static_cast(format))); - // return {-1} when infer-invalid currently. But we should support {-2} and {-1, -1, -1} in NNACL in future. - auto infer_ret = node_infer_shape->InferShapeByOps(cnode, true); - if (infer_ret != RET_OK && infer_ret != RET_INFER_INVALID) { - return infer_ret; - } - return infer_ret; -} - -namespace { -int OpsOrNNACLInferShape(const CompileNodePtr &node, OpParameter *op_parameter, InferContext *context, - Format infer_format = Format::DEFAULT_FORMAT) { - if (op_parameter != nullptr) { - infer_format = (infer_format == Format::DEFAULT_FORMAT) ? NHWC : infer_format; - auto infer_ret = InferShapeByNNACL(node, op_parameter, infer_format, context); - free(op_parameter); - if (infer_ret != RET_OK && infer_ret != RET_INFER_INVALID) { - MS_LOG(INFO) << "Infer kernel failed for op: " << node->GetName(); - } - return infer_ret; - } else { - infer_format = (infer_format == Format::DEFAULT_FORMAT) ? NCHW : infer_format; - auto infer_ret = InferShapeByOps(node, infer_format); - if (infer_ret != RET_OK && infer_ret != RET_INFER_INVALID) { - MS_LOG(INFO) << "Infer kernel failed for op: " << node->GetName(); - } - return infer_ret; - } -} -} // namespace - -int NodeFallBackInferShape(const CompileNodePtr &node, Format format, InferContext *context) { - MSLITE_CHECK_PTR_RETURN(node, RET_PARAM_INVALID); - if (node->GetType().TypeName() == "Custom") { - return RET_INFER_INVALID; - } - auto base_operator = node->GetBaseOperator(); - MSLITE_CHECK_PTR_RETURN(base_operator, RET_NULL_PTR); - auto op_parameter = OperatorPopulateRegistry::GetInstance()->CreatePopulateByOp(base_operator); - auto iter = FormatAwareOp.find(node->GetType().TypeName()); - int infer_ret; - // Format-not-aware op should infer in format indicated by format attr of mindir. - if (iter != FormatAwareOp.end()) { - infer_ret = OpsOrNNACLInferShape(node, op_parameter, context, format); - } else { // non-format-aware op not care about format, could infershape by NNACL or OPS - infer_ret = OpsOrNNACLInferShape(node, op_parameter, context); - } - DumpInferResult(*node, infer_ret); - return infer_ret; -} - -int GraphFallBackInferShape(const CompileResultPtr &node_list, Format format, InferContext *context) { - for (const auto &node : node_list->GetNodes()) { - auto infer_ret = NodeFallBackInferShape(node, format, context); - if (infer_ret != RET_OK && infer_ret != RET_INFER_INVALID) { - MS_LOG(INFO) << "Infer kernel failed for op: " << node->GetName(); - return infer_ret; - } - } - return RET_OK; -} -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/graph_compiler/infershape_helper.h b/mindspore-lite/src/extendrt/graph_compiler/infershape_helper.h deleted file mode 100644 index 2f7c78b2daa0fecabb815efa7fef47680c2ef7e5..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/graph_compiler/infershape_helper.h +++ /dev/null @@ -1,32 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_EXTENDRT_GRAPH_COMPILER_INFERSHAPE_HELPER_H_ -#define MINDSPORE_LITE_EXTENDRT_GRAPH_COMPILER_INFERSHAPE_HELPER_H_ - -#include "src/extendrt/graph_compiler/compile_result.h" -#include "src/infer/context.h" - -namespace mindspore { -namespace lite { -int GraphFallBackInferShape(const CompileResultPtr &node_list, Format format, InferContext *context); -int GraphFallBackInferShape(const FuncGraphPtr &graph, Format format, InferContext *context); -int NodeFallBackInferShape(const CompileNodePtr &node, Format format, InferContext *context); -int NodeFallBackInferShape(const CNodePtr &node, Format format = Format::NCHW); -} // namespace lite -} // namespace mindspore - -#endif // MINDSPORE_LITE_EXTENDRT_GRAPH_COMPILER_INFERSHAPE_HELPER_H_ diff --git a/mindspore-lite/src/extendrt/graph_compiler/single_graph_scheduler.cc b/mindspore-lite/src/extendrt/graph_compiler/single_graph_scheduler.cc deleted file mode 100644 index 67302c5d75ceb8ad3ac992ba953898a813c4061d..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/graph_compiler/single_graph_scheduler.cc +++ /dev/null @@ -1,208 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/graph_compiler/single_graph_scheduler.h" -#include "src/common/log_util.h" -#include "src/common/tensor_util.h" -#include "src/extendrt/graph_compiler/infershape_helper.h" -#include "src/extendrt/kernel/kernel_selector/kernel_selector.h" -#include "src/litert/kernel_registry.h" -#include "src/litert/pass/format_pass/format_pass.h" -#include "src/litert/pass/format_pass/pass_utils.h" -#include "tools/optimizer/graph/node_infershape.h" -#include "src/common/draw/drawer.h" -#include "src/extendrt/kernel/nnacl/nnacl_base_kernel.h" -#include "src/extendrt/kernel/extendrt_kernel_exec.h" -#include "nnacl/format_transpose_parameter.h" -#include "extendrt/delegate/ascend_native/delegate.h" -#include "extendrt/delegate/factory.h" - -namespace mindspore { -namespace lite { -InferKernel *SingleGraphScheduler::Schedule(const CompileResultPtr &node_list) { - DrawDot(node_list.get(), "start_schedule"); - // infer shape - MS_ASSERT(compile_option_ != nullptr); - // try infer shape, if failed, will infer shape by kernel - (void)GraphFallBackInferShape(node_list, compile_option_->graph_format, context_.get()); - DrawDot(node_list.get(), "fallback_infershape"); - - execution_flow_ = std::make_shared(); - MSLITE_CHECK_PTR_RETURN(execution_flow_, nullptr); - execution_flow_->SetInputs(node_list->GetInputs()); - execution_flow_->SetOutputs(node_list->GetOutputs()); - execution_flow_->SetTensors(node_list->GetTensors()); - execution_flow_->SetContext(context_); - auto schedule_ret = SelectKernel(node_list); - if (schedule_ret != RET_OK) { - MS_LOG(ERROR) << "Scheduler CompileResult to kernels failed."; - return nullptr; - } - - // append kernel with transpose - auto kernel = execution_flow_->ConstructFusionKernel(); - if (kernel == nullptr) { - MS_LOG(ERROR) << "Construct subgraph kernel failed."; - return nullptr; - } - kernel->set_context(context_.get()); - DrawDot(kernel, "select_kernel"); - - auto ret = OptimizeTranspose(kernel); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Optimize format of executionplan failed."; - return nullptr; - } - - auto infer_ret = kernel->InferShape(); - if (infer_ret != RET_OK && infer_ret != RET_INFER_INVALID) { - MS_LOG(ERROR) << "InferShape SubGraph kernel failed."; - return nullptr; - } - DrawDot(reinterpret_cast(kernel), "kernel_infershape"); - return kernel; -} - -void SingleGraphScheduler::CreateDelegateKernel(const std::shared_ptr &node, - mindspore::ExtendDelegate *delegate, - std::vector *kernels) { - auto delegate_kernel = delegate->CreateKernel({node->GetType(), node->GetKernelAttr(), compile_option_->graph_format, - compile_option_->backend, node->GetBaseOperator(), node->GetCNode()}, - node->GetInputs(), node->GetOutputs(), context_.get()); - if (delegate_kernel != nullptr) { - auto kernel_exec = new kernel::ExtendRTKernelExec(delegate_kernel); - kernel_exec->set_name(node->GetName()); - auto desc = kernel_exec->desc(); - desc.format = Format::NCHW; // now all delegate should be nchw, but this will cause bad performance. - kernel_exec->set_desc(desc); - kernel_exec->set_context(this->context_.get()); // not safety - kernels->push_back(kernel_exec); - } -} - -int SingleGraphScheduler::SelectKernel(const CompileResultPtr &node_list) { - kernel_selector_ = kernel::CreateKernelSelector(compile_option_); - const ConfigInfos config_infos; - auto &device_contexts = ctx_->MutableDeviceInfo(); - if (device_contexts.empty()) { - MS_LOG(ERROR) << "no context found"; - return RET_ERROR; - } - auto device_type = device_contexts.at(0)->GetDeviceType(); - auto provider = device_contexts.at(0)->GetProvider(); - auto delegate = - DelegateRegistry::GetInstance().GetDelegate(device_type, provider, ctx_, config_infos); - std::vector kernels; - for (const auto &node : node_list->GetNodes()) { - MSLITE_CHECK_PTR_RETURN(node, RET_NULL_PTR); - if ((delegate != nullptr) && (delegate->IsDelegateNode(node->GetCNode()))) { - CreateDelegateKernel(node, delegate, &kernels); - continue; - } - auto kernel_exec = - kernel_selector_->CreateKernel({node->GetType(), node->GetKernelAttr(), compile_option_->graph_format, - compile_option_->backend, node->GetBaseOperator(), node->GetCNode()}, - node->GetInputs(), node->GetOutputs(), context_.get()); - if (kernel_exec == nullptr) { - MS_LOG(ERROR) << "Create kernel exec for node: " << node->GetName() << " failed."; - return RET_NOT_SUPPORT; - } - kernel_exec->set_name(node->GetName()); - kernels.emplace_back(kernel_exec); - } - execution_flow_->SetKernels(kernels); - return RET_OK; -} - -bool SingleGraphScheduler::HandleWeightForKernels() { - if (compile_option_->datatype != kNumberTypeFloat32 && compile_option_->datatype != kNumberTypeFloat16) { - return true; - } - auto kernels = execution_flow_->GetKernels(); - for (const auto &kernel : kernels) { - for (const auto &input : kernel->in_tensors()) { - // only cast const tensor - if (!input->IsConst()) { - continue; - } - // only support fp32->fp16 or fp16->fp32 - if (input->data_type() != kNumberTypeFloat32 && input->data_type() != kNumberTypeFloat16) { - continue; - } - auto ret = CastConstTensorData(input, compile_option_->datatype, context_->device_and_pkg_support_fp16_); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Cast data for tensor: " << input->tensor_name() << " failed."; - return false; - } - } - } - return true; -} - -namespace { -kernel::KernelExec *CreateFormatTransFunc(InferTensor *input, InferTensor *output, - const pass::TransInfoPair &trans_info, const std::string &name, - const InferContext *ctx, const kernel::KernelKey &desc) { - auto param = reinterpret_cast(malloc(sizeof(FormatTransposeParameter))); - if (param == nullptr) { - MS_LOG(ERROR) << "Malloc FormatTransposeParameter failed."; - return nullptr; - } - (void)memset(param, 0, sizeof(FormatTransposeParameter)); - param->op_parameter_.type_ = static_cast(schema::PrimitiveType_FormatTranspose); - param->src_format_ = static_cast((trans_info.src_format_)); - param->dst_format_ = static_cast((trans_info.dst_format_)); - kernel::KernelKey format_transpose_key = desc; - format_transpose_key.type = schema::PrimitiveType_FormatTranspose; - format_transpose_key.format = NHWC; - format_transpose_key.data_type = input->data_type(); - - auto lite_kernel = KernelRegistry::GetInstance()->GetLiteKernel({input}, {output}, ctx, &format_transpose_key, - reinterpret_cast(param)); - if (lite_kernel == nullptr) { - MS_LOG(ERROR) << "Create format-transpose lite-kernel failed."; - free(param); - return nullptr; - } - auto base_kernel = new (std::nothrow) kernel::NNACLBaseKernel(std::shared_ptr(lite_kernel)); - if (base_kernel == nullptr) { - MS_LOG(ERROR) << "Create format-transpose base-kernel failed."; - return nullptr; - } - auto *kernel_exec = new (std::nothrow) kernel::ExtendRTKernelExec(std::shared_ptr(base_kernel)); - if (kernel_exec == nullptr) { - MS_LOG(ERROR) << "Create format-transpose kernel-exec failed."; - return nullptr; - } - kernel_exec->set_desc(format_transpose_key); - kernel_exec->set_context(ctx); - kernel_exec->set_name(name); - return kernel_exec; -} -} // namespace - -Status SingleGraphScheduler::OptimizeTranspose(kernel::SubGraphKernel *kernel) { - std::vector subgraph_list = {kernel}; - auto ret = - pass::DoFormatPass(&subgraph_list, &kernel->tensors(), compile_option_->graph_format, CreateFormatTransFunc); - if (ret != RET_OK) { - MS_LOG(INFO) << "Run Optimize transpose pass failed."; - return kLiteError; - } - return kSuccess; -} -} // namespace lite -} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/graph_compiler/single_graph_scheduler.h b/mindspore-lite/src/extendrt/graph_compiler/single_graph_scheduler.h deleted file mode 100644 index 286c6c3dcca7b4d434141976aee25e9adefff05d..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/graph_compiler/single_graph_scheduler.h +++ /dev/null @@ -1,62 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_EXTENDRT_GRAPH_COMPILER_SINGLE_GRAPH_SCHEDULER_H_ -#define MINDSPORE_LITE_EXTENDRT_GRAPH_COMPILER_SINGLE_GRAPH_SCHEDULER_H_ -#include -#include -#include -#include -#include -#include "src/extendrt/graph_compiler/compile_result.h" -#include "src/extendrt/execution_flow.h" -#include "src/litert/inner_context.h" -#include "src/extendrt/graph_compiler/compile_option.h" -#include "src/infer/kernel.h" -#include "src/extendrt/kernel/kernel_selector/kernel_selector.h" -#include "src/infer/context.h" -#include "src/extendrt/delegate/type.h" -#include "src/executor/sub_graph_kernel.h" - -namespace mindspore { -namespace lite { -class SingleGraphScheduler { - public: - explicit SingleGraphScheduler(InferContextPtr context, const std::shared_ptr &ctx, - std::shared_ptr option) - : context_(std::move(context)), ctx_(ctx), compile_option_(std::move(option)) {} - virtual ~SingleGraphScheduler() = default; - InferKernel *Schedule(const CompileResultPtr &node_list); - - private: - int SelectKernel(const CompileResultPtr &node_list); - bool HandleWeightForKernels(); - Status OptimizeTranspose(kernel::SubGraphKernel *kernels); - - private: - InferContextPtr context_{nullptr}; - std::shared_ptr ctx_{nullptr}; - std::shared_ptr compile_option_{nullptr}; - infer::ExecutionFlowPtr execution_flow_{nullptr}; - std::shared_ptr kernel_selector_{nullptr}; - void CreateDelegateKernel(const std::shared_ptr &node, mindspore::ExtendDelegate *delegate, - std::vector *kernels); - - std::map op_parameters_; -}; -using SingleGraphSchedulerPtr = std::shared_ptr; -} // namespace lite -} // namespace mindspore -#endif diff --git a/mindspore-lite/src/extendrt/graph_executor/default_executor.cc b/mindspore-lite/src/extendrt/graph_executor/default_executor.cc deleted file mode 100644 index 3c9b0fb59e46f3e65cd533262619a4ad9d117026..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/graph_executor/default_executor.cc +++ /dev/null @@ -1,161 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "extendrt/graph_executor/default_executor.h" - -#include -#include "extendrt/graph_executor/factory.h" -#include "extendrt/execution_plan.h" -#include "litert/mindrt_executor.h" - -namespace mindspore { -DefaultExecutor::DefaultExecutor() { - name_ = "DefaultExecutor"; - execution_plan_ = nullptr; -} - -DefaultExecutor::DefaultExecutor(const std::string &name, - std::shared_ptr execution_plan) { - name_ = name; - execution_plan_ = execution_plan; -} - -bool DefaultExecutor::Init() { - auto infer_execution_plan = std::dynamic_pointer_cast(execution_plan_); - if (infer_execution_plan == nullptr) { - MS_LOG(ERROR) << "Not Supported execution plan is passed"; - return false; - } - if (!infer_execution_plan->PrepareKernels()) { - MS_LOG(ERROR) << "Prepare kernels failed"; - return false; - } - - inited_ = true; - return true; -} - -Status DefaultExecutor::Prepare() { - if (!Init()) { - MS_LOG(ERROR) << "Init executor failed"; - return kLiteError; - } - - MS_ASSERT(inited_ == true); - MS_ASSERT(execution_plan_ != nullptr); - - // only single sub graph will choose this executor - MS_ASSERT(execution_plan_->ToKernelList().size() == 1); - - auto sub_graph_kernel = execution_plan_->ToKernelList().at(0); - if (sub_graph_kernel == nullptr) { - MS_LOG(ERROR) << "Sub graph kernel is nullptr"; - return kLiteNullptr; - } - - if (sub_graph_kernel->Prepare() != RET_OK) { - MS_LOG(ERROR) << "Sub graph kernel prepare failed"; - return kLiteError; - } - - return kSuccess; -} - -Status DefaultExecutor::Execute() { - if (!inited_) { - MS_LOG(ERROR) << "Executor is not inited correctly"; - return kLiteError; - } - MS_ASSERT(execution_plan_ != nullptr); - - // only single sub graph will choose this executor - MS_ASSERT(execution_plan_->ToKernelList().size() == 1); - - auto sub_graph_kernel = execution_plan_->ToKernelList().at(0); - if (sub_graph_kernel == nullptr) { - MS_LOG(ERROR) << "Sub graph kernel is nullptr"; - return kLiteNullptr; - } - - // copy data to sub_graph_inputs - auto sub_graph_inputs = sub_graph_kernel->in_tensors(); - auto user_inputs = execution_plan_->GetInputs(); - for (size_t i = 0; i < user_inputs.size(); ++i) { - auto sub_graph_input = sub_graph_inputs.at(i); - auto user_input = user_inputs.at(i); - sub_graph_input->set_data(user_input->data()); - sub_graph_input->set_category(lite::GRAPH_INPUT); - } - - // copy data to sub_graph_outputs - auto sub_graph_outputs = sub_graph_kernel->out_tensors(); - auto user_outputs = execution_plan_->GetOutputs(); - for (size_t i = 0; i < user_outputs.size(); ++i) { - auto sub_graph_output = sub_graph_outputs.at(i); - auto user_output = user_outputs.at(i); - sub_graph_output->set_data(user_output->data()); - sub_graph_output->set_category(lite::GRAPH_OUTPUT); - } - - if (sub_graph_kernel->Execute() != RET_OK) { - MS_LOG(ERROR) << "Sub graph kernel execute failed"; - return kLiteError; - } - - for (auto sub_graph_input : sub_graph_inputs) { - sub_graph_input->set_data(nullptr); - } - - for (auto sub_graph_output : sub_graph_outputs) { - sub_graph_output->set_data(nullptr); - } - - return kSuccess; -} - -int DefaultExecutor::Resize(const std::vector &inputs, - const std::vector> &dims) { - if (!inited_) { - MS_LOG(ERROR) << "Executor is not inited correctly"; - return kLiteError; - } - MS_ASSERT(execution_plan_ != nullptr); - - // only single sub graph will choose this executor - MS_ASSERT(execution_plan_->ToKernelList().size() == 1); - - auto sub_graph_kernel = execution_plan_->ToKernelList().at(0); - if (sub_graph_kernel == nullptr) { - MS_LOG(ERROR) << "Sub graph kernel is nullptr"; - return kLiteNullptr; - } - - auto ret = sub_graph_kernel->ReSize(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Sub graph kernel resize failed"; - return kLiteError; - } - - return RET_OK; -} - -static std::shared_ptr DefaultGraphExecutorCreator( - const std::string &name, std::shared_ptr execution_plan) { - auto graph_executor = std::make_shared(name, execution_plan); - return graph_executor; -} -REG_GRAPH_EXECUTOR(kDefaultExecutor, DefaultGraphExecutorCreator); -} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/graph_executor/factory.h b/mindspore-lite/src/extendrt/graph_executor/factory.h deleted file mode 100644 index 4c1af3aadce473587168f09cf8cbdf3ae54ba8ba..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/graph_executor/factory.h +++ /dev/null @@ -1,59 +0,0 @@ -/** - * Copyright 2019-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_GRAPH_EXECUTOR_FACTORY_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_GRAPH_EXECUTOR_FACTORY_H_ - -#include -#include -#include - -#include "extendrt/graph_executor/type.h" -#include "infer/executor.h" -#include "infer/execution_plan.h" - -namespace mindspore { -using GraphExecutorRegFunc = std::function( - const std::string &name, std::shared_ptr execution_plan)>; - -class GraphExecutorRegistry { - public: - GraphExecutorRegistry() = default; - virtual ~GraphExecutorRegistry() = default; - - static GraphExecutorRegistry &GetInstance(); - - void RegExecutor(const GraphExecutorType &type, const GraphExecutorRegFunc &creator); - - std::shared_ptr GetExecutor( - const mindspore::GraphExecutorType &type, const std::string &name, - std::shared_ptr execution_plan); - - private: - mindspore::HashMap graph_executor_map_; -}; - -class GraphExecutorRegistrar { - public: - GraphExecutorRegistrar(const mindspore::GraphExecutorType &type, const GraphExecutorRegFunc &creator) { - GraphExecutorRegistry::GetInstance().RegExecutor(type, creator); - } - ~GraphExecutorRegistrar() = default; -}; - -#define REG_GRAPH_EXECUTOR(type, creator) static GraphExecutorRegistrar g_##type##GraphExecutor(type, creator); -} // namespace mindspore - -#endif // MINDSPORE_LITE_SRC_EXTENDRT_GRAPH_EXECUTOR_FACTORY_H_ diff --git a/mindspore-lite/src/extendrt/graph_executor/flow_executor.cc b/mindspore-lite/src/extendrt/graph_executor/flow_executor.cc deleted file mode 100644 index d787df7e19a518d5bc57a796fab9c1251ac795b1..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/graph_executor/flow_executor.cc +++ /dev/null @@ -1,56 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "extendrt/graph_executor/flow_executor.h" -#include "extendrt/execution_plan.h" -#include "litert/mindrt_executor.h" - -namespace mindspore::infer { -FlowExecutor::FlowExecutor() { FlowExecutor("FlowExecutor"); } - -FlowExecutor::FlowExecutor(const std::string &name, std::shared_ptr execution_plan) { - name_ = name; - execution_plan_ = execution_plan; - auto infer_execution_plan = std::dynamic_pointer_cast(execution_plan_); - if (infer_execution_plan == nullptr) { - MS_LOG(ERROR) << "FlowExecutor::FlowExecutor Not Supported execution plan is passed"; - } else { - executor_ = std::make_shared(infer_execution_plan->GetInputMap(), - infer_execution_plan->GetOutputMap); - } -} - -Status FlowExecutor::Prepare(std::shared_ptr execution_flow) { - if (executor_ == nullptr) { - MS_LOG(ERROR) << "FlowExecutor::Prepare executor is nullptr"; - return kLiteError; - } - - if (execution_flow == nullptr) { - MS_LOG(ERROR) << "FlowExecutor::Prepare execution flow is nullptr"; - return kLiteError; - } - - return executor_->Prepare(execution_flow->GetKernels(), execution_flow->GetInputs(), execution_flow->GetOutputs(), - execution_flow->GetContext); -} - -Status FlowExecutor::Execute() { return kSuccess; } - -int FlowExecutor::Resize(const std::vector &inputs, const std::vector> &dims) { - return kSuccess; -} -} // namespace mindspore::infer diff --git a/mindspore-lite/src/extendrt/graph_executor/flow_executor.h b/mindspore-lite/src/extendrt/graph_executor/flow_executor.h deleted file mode 100644 index 1e69956ab0d9da794884ca1ac426adcaa165c4fc..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/graph_executor/flow_executor.h +++ /dev/null @@ -1,51 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_FLOW_EXECUTOR_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_FLOW_EXECUTOR_H_ - -#include -#include -#include - -#include "infer/executor.h" -#include "infer/execution_plan.h" -#include "litert/executor.h" - -namespace mindspore::infer { -class FlowExecutor : public mindspore::infer::abstract::Executor { - public: - FlowExecutor(); - explicit FlowExecutor(const std::string &name, std::shared_ptr execution_plan); - virtual ~FlowExecutor() = default; - - const std::string &Name() override { return name_; } - - Status Prepare() override; - - Status Execute() override; - - int Resize(const std::vector &inputs, const std::vector> &dims) override; - - private: - std::string name_; - std::shared_ptr execution_flow_; - std::shared_ptr executor_; - std::shared_ptr execution_plan_; -}; -} // namespace mindspore::infer - -#endif // MINDSPORE_LITE_SRC_EXTENDRT_FLOW_EXECUTOR_H_ diff --git a/mindspore-lite/src/extendrt/graph_executor/mindrt_graph_executor.cc b/mindspore-lite/src/extendrt/graph_executor/mindrt_graph_executor.cc deleted file mode 100644 index f22286588e0437c3929679a2871802fd413bd7e6..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/graph_executor/mindrt_graph_executor.cc +++ /dev/null @@ -1,120 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "extendrt/graph_executor/mindrt_graph_executor.h" - -#include - -#include "extendrt/graph_executor/factory.h" -#include "litert/mindrt_executor.h" -#include "extendrt/execution_plan.h" - -namespace mindspore { -MindRTGraphExecutor::MindRTGraphExecutor() { - name_ = "MindRTGraphExecutor"; - execution_plan_ = nullptr; -} - -MindRTGraphExecutor::MindRTGraphExecutor(const std::string &name, - std::shared_ptr execution_plan) { - name_ = name; - execution_plan_ = execution_plan; -} - -bool MindRTGraphExecutor::Init() { - auto infer_execution_plan = std::dynamic_pointer_cast(execution_plan_); - if (infer_execution_plan == nullptr) { - MS_LOG(ERROR) << "Not Supported execution plan is passed"; - return false; - } - if (!infer_execution_plan->PrepareKernels()) { - MS_LOG(ERROR) << "Prepare kernels failed"; - return false; - } - - mindrt_executor_ = std::make_shared(infer_execution_plan->GetOutputsMap(), - infer_execution_plan->GetInputsMap()); - if (mindrt_executor_ == nullptr) { - MS_LOG(ERROR) << "Create mindrt executor failed"; - return false; - } - - inited_ = true; - - return true; -} - -Status MindRTGraphExecutor::Prepare() { - if (!Init()) { - MS_LOG(ERROR) << "Init executor failed"; - return kLiteError; - } - - MS_ASSERT(inited_ == true); - MS_ASSERT(mindrt_executor_ != nullptr); - MS_ASSERT(execution_plan_ != nullptr); - - auto ret = mindrt_executor_->Prepare(execution_plan_->ToKernelList(), execution_plan_->GetInputs(), - execution_plan_->GetOutputs(), execution_plan_->GetContext().get()); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Prepare execution plan failed with code " << ret; - return kLiteError; - } - return kSuccess; -} - -Status MindRTGraphExecutor::Execute() { - if (!inited_) { - MS_LOG(ERROR) << "Executor is not inited correctly"; - return kLiteError; - } - MS_ASSERT(mindrt_executor_ != nullptr); - MS_ASSERT(execution_plan_ != nullptr); - - auto ret = - mindrt_executor_->Run(execution_plan_->GetInputs(), execution_plan_->GetOutputs(), execution_plan_->ToKernelList(), - execution_plan_->GetKernelBeforeCallBack(), execution_plan_->GetKernelAfterCallBack()); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Run execution plan failed with code " << ret; - return kLiteError; - } - return kSuccess; -} - -int MindRTGraphExecutor::Resize(const std::vector &inputs, - const std::vector> &dims) { - if (!inited_) { - MS_LOG(ERROR) << "Executor is not inited correctly"; - return kLiteError; - } - MS_ASSERT(mindrt_executor_ != nullptr); - - std::vector> dims32; - std::transform(dims.begin(), dims.end(), std::back_inserter(dims32), [](std::vector shape) { - std::vector shape32; - std::transform(shape.begin(), shape.end(), std::back_inserter(shape32), - [](int64_t dim) { return static_cast(dim); }); - return shape32; - }); - return mindrt_executor_->Resize(inputs, dims32); -} - -static std::shared_ptr MindRTGraphExecutorCreator( - const std::string &name, std::shared_ptr execution_plan) { - auto graph_executor = std::make_shared(name, execution_plan); - return graph_executor; -} -REG_GRAPH_EXECUTOR(kMindRTExecutor, MindRTGraphExecutorCreator); -} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/graph_executor/mindrt_graph_executor.h b/mindspore-lite/src/extendrt/graph_executor/mindrt_graph_executor.h deleted file mode 100644 index 4ed3a0d6da6086b86cd721233408ced205286219..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/graph_executor/mindrt_graph_executor.h +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_GRAPH_RUNTIME_DEFAULT_GRAPH_RUNTIME_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_GRAPH_RUNTIME_DEFAULT_GRAPH_RUNTIME_H_ - -#include -#include -#include - -#include "infer/executor.h" -#include "infer/execution_plan.h" -#include "litert/executor.h" - -namespace mindspore { -/** - * MindRTGraphExecutor: Executor using MindRT to make the actor async execute for parallelize speedup - */ -class MindRTGraphExecutor : public mindspore::infer::abstract::Executor { - public: - MindRTGraphExecutor(); - explicit MindRTGraphExecutor(const std::string &name, std::shared_ptr execution_plan); - virtual ~MindRTGraphExecutor() = default; - - const std::string &Name() override { return name_; } - - Status Prepare() override; - - Status Execute() override; - - int Resize(const std::vector &inputs, - const std::vector> &dims) override; - - private: - bool Init(); - - private: - std::string name_; - std::shared_ptr mindrt_executor_; - std::shared_ptr execution_plan_; - bool inited_ = false; -}; -} // namespace mindspore - -#endif // MINDSPORE_LITE_SRC_EXTENDRT_GRAPH_RUNTIME_DEFAULT_GRAPH_RUNTIME_H_ diff --git a/mindspore-lite/src/extendrt/graph_executor/plan_executor.cc b/mindspore-lite/src/extendrt/graph_executor/plan_executor.cc deleted file mode 100644 index ec8c4f5107f105d8b3e0c4154c2b3125061564ec..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/graph_executor/plan_executor.cc +++ /dev/null @@ -1,70 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "extendrt/graph_executor/plan_executor.h" -#include "extendrt/execution_plan.h" -#include "litert/mindrt_executor.h" - -namespace mindspore::infer { -PlanExecutor::PlanExecutor() { PlanExecutor("PlanExecutor"); } - -PlanExecutor::PlanExecutor(const std::string &name, std::shared_ptr execution_plan) { - name_ = name; - execution_plan_ = execution_plan; - auto infer_execution_plan = std::dynamic_pointer_cast(execution_plan_); - if (infer_execution_plan == nullptr) { - MS_LOG(ERROR) << "FlowExecutor::FlowExecutor Not Supported execution plan is passed"; - } else { - executor_ = std::make_shared(infer_execution_plan->GetInputMap(), - infer_execution_plan->GetOutputMap()); - } -} - -Status PlanExecutor::Prepare() { - if (executor_ == nullptr) { - MS_LOG(ERROR) << "FlowExecutor::Prepare executor is nullptr"; - return kLiteError; - } - - if (execution_plan_ == nullptr) { - MS_LOG(ERROR) << "FlowExecutor::Prepare execution plan is nullptr"; - return kLiteError; - } - return executor_->Prepare(execution_plan_->ToKernelList(), execution_plan_->GetInputs(), - execution_plan_->GetOutputs(), execution_plan_->GetContext()); -} - -Status PlanExecutor::Execute() { - if (executor_ == nullptr) { - MS_LOG(ERROR) << "FlowExecutor::Execute executor is nullptr"; - return kLiteError; - } - if (execution_plan_ == nullptr) { - MS_LOG(ERROR) << "FlowExecutor::Execute execution plan is nullptr"; - return kLiteError; - } - return executor_->Run(execution_plan_->GetInputs(), execution_plan_->GetOutputs(), execution_plan_->ToKernelList(), - execution_plan_->GetKernelBeforeCallBack(), execution_plan_->GetKernelAfterCallBack()); -} - -int PlanExecutor::Resize(const std::vector &inputs, const std::vector> &dims) { - if (executor_ == nullptr) { - MS_LOG(ERROR) << "FlowExecutor::Resize executor is nullptr"; - return kLiteError; - } - return executor_->Resize(inputs, dims); -} -} // namespace mindspore::infer diff --git a/mindspore-lite/src/extendrt/graph_executor/plan_executor.h b/mindspore-lite/src/extendrt/graph_executor/plan_executor.h deleted file mode 100644 index 1bf70b45bfc594b6087a90eb7faa406daeb678a1..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/graph_executor/plan_executor.h +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_PLAN_EXECUTOR_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_PLAN_EXECUTOR_H_ - -#include -#include -#include - -#include "infer/executor.h" -#include "infer/execution_plan.h" -#include "litert/executor.h" - -namespace mindspore::infer { -class PlanExecutor : public mindspore::infer::abstract::Executor { - public: - PlanExecutor(); - explicit PlanExecutor(const std::string &name); - virtual ~PlanExecutor() = default; - - const std::string &Name() override { return name_; } - - Status Prepare() override; - - Status Execute() override; - - int Resize(const std::vector &inputs, const std::vector> &dims) override; - - private: - std::string name_; - std::shared_ptr executor_; - std::shared_ptr execution_plan_; -}; -} // namespace mindspore::infer - -#endif // MINDSPORE_LITE_SRC_EXTENDRT_PLAN_EXECUTOR_H_ diff --git a/mindspore-lite/src/extendrt/graph_executor/type.h b/mindspore-lite/src/extendrt/graph_executor/type.h deleted file mode 100644 index 66cfbe9e27420be5f0a632824eefd20ee68ee478..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/graph_executor/type.h +++ /dev/null @@ -1,22 +0,0 @@ -/** - * Copyright 2019-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_EXTENDRT_GRAPH_EXECUTOR_TYPE_H_ -#define MINDSPORE_LITE_EXTENDRT_GRAPH_EXECUTOR_TYPE_H_ - -namespace mindspore { -enum GraphExecutorType { kDefaultExecutor = 0, kMindRTExecutor, kNoneExcutor }; -} // namespace mindspore -#endif // MINDSPORE_LITE_EXTENDRT_GRAPH_EXECUTOR_TYPE_H_ diff --git a/mindspore-lite/src/extendrt/graph_partitioner/condition_partitioner.cc b/mindspore-lite/src/extendrt/graph_partitioner/condition_partitioner.cc deleted file mode 100644 index e29773b25c48fe49d633bc9a997a4f17ae1383c8..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/graph_partitioner/condition_partitioner.cc +++ /dev/null @@ -1,228 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include -#include - -#include "extendrt/graph_partitioner/condition_partitioner.h" - -#include "mindspore/ops/op_def/structure_ops.h" -#include "mindspore/ops/op_def/sequence_ops.h" -#include "mindspore/ops/op_def/framework_ops.h" -#include "utils/ms_context.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_d.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_h.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_i.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_u.h" - -namespace mindspore { -std::vector ConditionPartitioner::Partition(const FuncGraphPtr &graph, bool *multi_target) { - MS_EXCEPTION_IF_NULL(graph); - auto nodes = TopoSort(graph->get_return()); - - MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size(); - bool contain_multi_target = ContainMultiTarget(nodes); - if (multi_target != nullptr) { - *multi_target = contain_multi_target; - } - - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - std::string default_target = context_ptr->get_param(MS_CTX_DEVICE_TARGET); - - std::vector segments; - std::vector segment_nodes; - std::map node_to_segment; - for (auto &node : nodes) { - MS_EXCEPTION_IF_NULL(node); - auto separete_type = EvalSeparatorCondition(segment_nodes, node); - switch (separete_type) { - case kIsolatedNodeSeparate: { - NodesToSegments(segment_nodes, &segments, &node_to_segment); - segment_nodes.clear(); - segment_nodes.emplace_back(node); - auto segment = std::make_shared(segment_nodes, true); - segments.push_back(segment); - segment_nodes.clear(); - break; - } - case kDirectSeparate: { - NodesToSegments(segment_nodes, &segments, &node_to_segment); - segment_nodes.clear(); - break; - } - case kNoSeparate: { - segment_nodes.emplace_back(node); - break; - } - default: { - MS_LOG(ERROR) << "Separate graph failed"; - return std::vector{}; - } - } - } - if (contain_multi_target) { - AddSegmentDependency(graph, node_to_segment); - RemoveUselessDependency(&segments); - } - return segments; -} - -SeparateType ConditionPartitioner::EvalSeparatorCondition(const std::vector &prev_segment, - const AnfNodePtr &node) { - for (auto separator : separators_) { - MS_EXCEPTION_IF_NULL(separator); - auto separate = separator->GraphSeparateCheck(prev_segment, node); - if (separate == kNoSeparate) { - continue; - } - return separate; - } - return kNoSeparate; -} - -void ConditionPartitioner::NodesToSegments(const std::vector &segment_nodes, - std::vector *segments, - std::map *node_to_segment) { - if (segment_nodes.empty()) { - return; - } - - AddSegment(segment_nodes, segments, node_to_segment); - return; -} - -void ConditionPartitioner::AddSegment(const std::vector &nodes, std::vector *segments, - std::map *node_to_segment) { - MS_EXCEPTION_IF_NULL(segments); - MS_EXCEPTION_IF_NULL(node_to_segment); - auto segment = std::make_shared(nodes, false); - segments->emplace_back(segment); - for (auto &node : nodes) { - (*node_to_segment)[node] = segment; - } -} - -void ConditionPartitioner::AddSegmentDependency(const FuncGraphPtr &graph, - const std::map &node_to_segment) { - MS_EXCEPTION_IF_NULL(graph); - std::stack to_visit; - std::map nodes_ref; - CalcNodeRefCount(graph, &nodes_ref); - to_visit.push(graph->get_return()); - while (!to_visit.empty()) { - auto &node = to_visit.top(); - MS_EXCEPTION_IF_NULL(node); - to_visit.pop(); - if (!node->isa()) { - continue; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto node_inputs = cnode->inputs(); - GraphSegmentPtr node_segment{nullptr}; - auto node_iter = node_to_segment.find(node); - if (node_iter != node_to_segment.end()) { - node_segment = node_iter->second; - } - for (auto &input : node_inputs) { - if (node_segment != nullptr && !node_segment->is_cut_ && input != nullptr && input->isa()) { - GraphSegmentPtr input_segment{nullptr}; - auto input_iter = node_to_segment.find(input); - if (input_iter != node_to_segment.end()) { - input_segment = input_iter->second; - } - if (input_segment != nullptr && input_segment != node_segment && !input_segment->is_cut_) { - node_segment->AddPreSegment(input_segment); - } - } - auto ref_iter = nodes_ref.find(input); - if (ref_iter != nodes_ref.end()) { - ref_iter->second--; - if (ref_iter->second != 0) { - continue; - } - } - to_visit.push(input); - } - } -} - -void ConditionPartitioner::RemoveUselessDependency(const std::vector *segments) { - MS_EXCEPTION_IF_NULL(segments); - for (auto &segment : *segments) { - MS_EXCEPTION_IF_NULL(segment); - if (segment->is_cut_) { - continue; - } - bool total_virtual_node = true; - for (auto &node : segment->nodes_) { - if (IsPrimitiveCNode(node, prim::kPrimImageSummary) || IsPrimitiveCNode(node, prim::kPrimScalarSummary) || - IsPrimitiveCNode(node, prim::kPrimTensorSummary) || IsPrimitiveCNode(node, prim::kPrimHistogramSummary) || - IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimLoad) || - IsPrimitiveCNode(node, prim::kPrimUpdateState) || IsPrimitiveCNode(node, prim::kPrimMakeTuple) || - IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { - continue; - } - total_virtual_node = false; - break; - } - if (total_virtual_node) { - segment->pre_segments_.clear(); - } - } -} - -void ConditionPartitioner::CalcNodeRefCount(const FuncGraphPtr &graph, std::map *nodes_ref) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(nodes_ref); - std::queue queue; - queue.push(graph->get_return()); - std::set visited; - while (!queue.empty()) { - auto node = queue.front(); - queue.pop(); - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - continue; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - for (auto &input : cnode->inputs()) { - auto iter = nodes_ref->find(input); - if (iter != nodes_ref->end()) { - iter->second++; - } else { - (void)nodes_ref->emplace(input, 1UL); - } - if (visited.find(input) != visited.end()) { - continue; - } - visited.insert(input); - queue.push(input); - } - } -} -} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/graph_partitioner/condition_partitioner.h b/mindspore-lite/src/extendrt/graph_partitioner/condition_partitioner.h deleted file mode 100644 index a35ab4d378b86696458993f4eb18e36552adb1bf..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/graph_partitioner/condition_partitioner.h +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_GRAPH_PARTITIONER_CONDITION_PARTITIONER_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_GRAPH_PARTITIONER_CONDITION_PARTITIONER_H_ - -#include -#include -#include - -#include "extendrt/graph_partitioner/type.h" -#include "extendrt/graph_partitioner/graph_separator/type.h" - -namespace mindspore { -class ConditionPartitioner : public GraphPartitioner { - public: - ConditionPartitioner() = default; - ~ConditionPartitioner() = default; - - std::vector Partition(const FuncGraphPtr &graph, bool *multi_target = nullptr) override; - - void SetSeparators(std::vector> separator) { separators_ = separator; } - - std::vector> GetSeparators() { return separators_; } - - private: - SepareteType EvalSeparatorCondition(const std::vector &prev_segment, const AnfNodePtr &node); - void NodesToSegments(const std::vector &segment_nodes, std::vector *segments, - std::map *node_to_segment); - void AddSegment(const std::vector &nodes, std::vector *segments, - std::map *node_to_segment); - void AddSegmentDependency(const FuncGraphPtr &graph, const std::map &node_to_segment); - void RemoveUselessDependency(const std::vector *segments); - void CalcNodeRefCount(const FuncGraphPtr &graph, std::map *nodes_ref); - - private: - std::vector> separators_; -}; -} // namespace mindspore - -#endif // MINDSPORE_LITE_SRC_EXTENDRT_GRAPH_PARTITIONER_CONDITION_PARTITIONER_H_ diff --git a/mindspore-lite/src/extendrt/graph_partitioner/factory.h b/mindspore-lite/src/extendrt/graph_partitioner/factory.h deleted file mode 100644 index 8947de1f3de8e80a70cbbd3d6ee95bed189440ae..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/graph_partitioner/factory.h +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_GRAPH_PARTITIONER_FACTORY_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_GRAPH_PARTITIONER_FACTORY_H_ - -#include -#include - -#include "extendrt/graph_partitioner/type.h" -#include "infer/graph_partitioner.h" - -namespace mindspore { -using GraphPartitionerRegFunc = std::function()>; - -class GraphPartitionerRegistry { - public: - GraphPartitionerRegistry() = default; - virtual ~GraphPartitionerRegistry() = default; - - static GraphPartitionerRegistry &GetInstance(); - - void RegPartitioner(const GraphPartitionerType &type, const GraphPartitionerRegFunc &creator); - - std::shared_ptr GetPartitioner(const mindspore::GraphPartitionerType &type); - - private: - mindspore::HashMap graph_partitioner_map_; -}; - -class GraphPartitionerRegistrar { - public: - GraphPartitionerRegistrar(const mindspore::GraphPartitionerType &type, const GraphPartitionerRegFunc &creator) { - GraphPartitionerRegistry::GetInstance().RegPartitioner(type, creator); - } - ~GraphPartitionerRegistrar() = default; -}; - -#define REG_GRAPH_PARTITIONER(type, creator) static GraphPartitionerRegistrar g_##type##Partitioner(type, creator); -} // namespace mindspore - -#endif // MINDSPORE_LITE_SRC_EXTENDRT_GRAPH_PARTITIONER_FACTORY_H_ diff --git a/mindspore-lite/src/extendrt/graph_partitioner/type.h b/mindspore-lite/src/extendrt/graph_partitioner/type.h deleted file mode 100644 index 17daa1792d249c1548ce7c17380de99d26c5e57c..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/graph_partitioner/type.h +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_EXTENDRT_GRAPH_PARTITIONER_TYPE_H_ -#define MINDSPORE_LITE_EXTENDRT_GRAPH_PARTITIONER_TYPE_H_ - -#include -#include - -#include "ir/func_graph.h" - -namespace mindspore { -enum GraphPartitionerType { kConditionPartitioner = 0, kCustomPartitioner, kNoneRuntime }; - -class GraphPartitioner : public std::enable_shared_from_this { - public: - virtual ~GraphPartitioner() = default; - - /// \brief Partition FuncGraph into several GraphSegment - /// - /// \param[in] graph FuncGraph need to partition. - /// \param[out] multi_target if the graph need run on multi target - /// - /// \return list of GraphSegment for SubGraphs - virtual std::vector Partition(const FuncGraphPtr &graph, bool *multi_target = nullptr) = 0; -}; -} // namespace mindspore -#endif // MINDSPORE_LITE_EXTENDRT_GRAPH_PARTITIONER_TYPE_H_ diff --git a/mindspore-lite/src/extendrt/graph_runtime/default_graph_runtime.cc b/mindspore-lite/src/extendrt/graph_runtime/default_graph_runtime.cc deleted file mode 100644 index 4054e69876d223599d78b1d48d7ba99fa0b2140e..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/graph_runtime/default_graph_runtime.cc +++ /dev/null @@ -1,240 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "extendrt/graph_runtime/default_graph_runtime.h" - -#include "extendrt/graph_runtime/factory.h" -#include "extendrt/graph_executor/factory.h" -#include "extendrt/utils/tensor_utils.h" -#include "extendrt/execution_plan.h" -#include "executor/sub_graph_kernel.h" - -namespace mindspore { -using ExecutionPlan = mindspore::infer::abstract::ExecutionPlan; - -Status DefaultGraphRuntime::Prepare(std::shared_ptr execution_plan) { - MS_LOG(DEBUG) << "Graph runtime prepare begin"; - - if (execution_plan == nullptr) { - MS_LOG(ERROR) << "Execution plan is nullptr."; - return kLiteNullptr; - } - execution_plan_ = execution_plan; - - auto executor = SelectExecutor(); - if (executor == nullptr) { - MS_LOG(ERROR) << "Select executor is nullptr."; - return kLiteNullptr; - } - - MS_LOG(DEBUG) << "Prepare executor begin"; - auto status = executor->Prepare(); - if (status != kSuccess) { - MS_LOG(ERROR) << "Prepare executor failed executor: " << executor->Name(); - return kLiteError; - } - MS_LOG(DEBUG) << "Prepare executor end"; - - MS_LOG(DEBUG) << "Graph runtime prepare end"; - return kSuccess; -} - -Status DefaultGraphRuntime::Execute() { - MS_LOG(INFO) << "Graph runtime execute begin"; - - if (execution_plan_ == nullptr) { - MS_LOG(ERROR) << "Execution plan is nullptr"; - return kLiteNullptr; - } - - auto executor = SelectExecutor(); - if (executor == nullptr) { - MS_LOG(ERROR) << "Select executor is nullptr"; - return kLiteNullptr; - } - - MS_LOG(DEBUG) << "Execute executor begin"; - auto status = executor->Execute(); - if (status != kSuccess) { - MS_LOG(ERROR) << "Execute executor failed executor: " << executor->Name(); - return kLiteError; - } - MS_LOG(DEBUG) << "Execute executor end"; - - MS_LOG(DEBUG) << "Graph runtime execute end"; - return kSuccess; -} - -Status DefaultGraphRuntime::Execute(const std::vector &inputs, - const std::vector &outputs, - infer::abstract::KernelCallBack before, infer::abstract::KernelCallBack after) { - MS_LOG(DEBUG) << "Graph runtime execute begin"; - - if (execution_plan_ == nullptr) { - MS_LOG(ERROR) << "Execution Plan is nullptr"; - return kLiteNullptr; - } - - auto executor = SelectExecutor(); - if (executor == nullptr) { - MS_LOG(ERROR) << "Select Executor is nullptr."; - return kLiteNullptr; - } - - MS_LOG(DEBUG) << "Execute executor begin"; - execution_plan_->SetInputs(inputs); - execution_plan_->SetKernelBeforeCallBack(before); - execution_plan_->SetKernelAfterCallBack(after); - auto status = executor->Execute(); - if (status != kSuccess) { - MS_LOG(ERROR) << "Execute executor failed executor: " << executor->Name(); - return kLiteError; - } - MS_LOG(DEBUG) << "Execute executor end"; - - MS_LOG(INFO) << "Graph runtime execute end"; - return kSuccess; -} - -Status DefaultGraphRuntime::Resize(const std::vector &inputs, - const std::vector> &dims) { - MS_LOG(DEBUG) << "Graph runtime resize begin"; - - if (execution_plan_ == nullptr) { - MS_LOG(ERROR) << "Execution plan is nullptr"; - return kLiteNullptr; - } - - auto executor = SelectExecutor(); - if (executor == nullptr) { - MS_LOG(ERROR) << "Select executor is nullptr"; - return kLiteNullptr; - } - - auto graph_inputs = execution_plan_->GetInputs(); - auto original_dims = AbstractTensorUtils::GetTensorListShapes(graph_inputs); - - AbstractTensorUtils::SetTensorListShapse(graph_inputs, dims); - - if (!ResizeKernels()) { - AbstractTensorUtils::SetTensorListShapse(graph_inputs, original_dims); - if (!ResizeKernels()) { - MS_LOG(ERROR) << "Restore kernel size failed."; - } - return kLiteError; - } - - MS_LOG(DEBUG) << "Resize executor begin"; - auto status = executor->Resize(inputs, dims); - if (status != kSuccess) { - MS_LOG(ERROR) << "Resize executor failed executor: " << executor->Name(); - return kLiteError; - } - MS_LOG(DEBUG) << "Resize executor end"; - - MS_LOG(DEBUG) << "Graph runtime resize end"; - return kSuccess; -} - -bool DefaultGraphRuntime::ResizeKernels() { - auto infer_execution_plan = std::dynamic_pointer_cast(execution_plan_); - if (infer_execution_plan == nullptr) { - MS_LOG(ERROR) << "Not Supported execution plan is passed"; - return false; - } - auto kernels = infer_execution_plan->ToKernelList(); - auto isolate_input_map = infer_execution_plan->GetInputsMap(); - auto isolate_output_map = infer_execution_plan->GetOutputsMap(); - for (auto kernel : kernels) { - if (kernel == nullptr) { - MS_LOG(ERROR) << "Input kernel is nullptr!"; - return false; - } - int ret; - if (kernel->desc().arch == kernel::kDelegate) { - ret = kernel->ReSize(); - } else { - // resize subgraph inputs - auto sub_graph_kernel = reinterpret_cast(kernel); - for (auto input : sub_graph_kernel->in_tensors()) { - if (isolate_input_map->find(input) != isolate_input_map->end()) { - input->set_shape(isolate_input_map->at(input)->shape()); - input->set_data_type(isolate_input_map->at(input)->data_type()); - input->set_format(isolate_input_map->at(input)->format()); - } - } - ret = sub_graph_kernel->ReSize(); - for (auto output : sub_graph_kernel->out_tensors()) { - if (isolate_input_map->find(output) != isolate_input_map->end()) { - isolate_output_map->at(output)->set_shape(output->shape()); - isolate_output_map->at(output)->set_data_type(output->data_type()); - isolate_output_map->at(output)->set_format(output->format()); - } - } - DrawDot(sub_graph_kernel, "resize"); - } - if (ret == lite::RET_INFER_INVALID) { - MS_LOG(WARNING) << "InferShape is interrupted"; - continue; - } - if (ret != RET_OK) { - MS_LOG(ERROR) << "ReSize node " << kernel->name() << " failed"; - return false; - } - } - return true; -} - -std::vector DefaultGraphRuntime::GetInputs() { - if (execution_plan_ == nullptr) { - MS_LOG(ERROR) << "Execution plan is nullptr."; - return std::vector{}; - } - return execution_plan_->GetInputs(); -} - -std::vector DefaultGraphRuntime::GetOutputs() { - if (execution_plan_ == nullptr) { - MS_LOG(ERROR) << "Execution plan is nullptr."; - return std::vector{}; - } - return execution_plan_->GetOutputs(); -} - -std::shared_ptr DefaultGraphRuntime::SelectExecutor() { - if (executor_ == nullptr) { - if (execution_plan_->GetKernels().size() == 1) { - auto sub_graph = dynamic_cast(execution_plan_->GetKernels()[0]); - if (sub_graph != nullptr) { - if (sub_graph->nodes().size() == 1) { - executor_ = - GraphExecutorRegistry::GetInstance().GetExecutor(kDefaultExecutor, "default-executor", execution_plan_); - return executor_; - } - } - executor_ = GraphExecutorRegistry::GetInstance().GetExecutor(kMindRTExecutor, "mindrt-executor", execution_plan_); - } else { - executor_ = GraphExecutorRegistry::GetInstance().GetExecutor(kMindRTExecutor, "mindrt-executor", execution_plan_); - } - } - return executor_; -} - -static std::shared_ptr DefaultGraphRuntimeCreator() { - auto graph_runtime = std::make_shared(); - return graph_runtime; -} -REG_GRAPH_RUNTIME(kDefaultRuntime, DefaultGraphRuntimeCreator); -} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/graph_runtime/default_graph_runtime.h b/mindspore-lite/src/extendrt/graph_runtime/default_graph_runtime.h deleted file mode 100644 index 43bd003a1c27cd0e2199c8ea32cd4c401ba8ec15..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/graph_runtime/default_graph_runtime.h +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_GRAPH_RUNTIME_DEFAULT_GRAPH_RUNTIME_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_GRAPH_RUNTIME_DEFAULT_GRAPH_RUNTIME_H_ - -#include -#include - -#include "infer/graph_runtime.h" -#include "src/common/draw/drawer.h" - -namespace mindspore { -class DefaultGraphRuntime : public mindspore::infer::abstract::GraphRuntime { - public: - DefaultGraphRuntime() { InitDotDrawer(); } - virtual ~DefaultGraphRuntime() = default; - - Status Prepare(std::shared_ptr execution_plan) override; - - Status Execute() override; - - Status Execute(const std::vector &inputs, - const std::vector &outputs, - infer::abstract::KernelCallBack before = nullptr, - infer::abstract::KernelCallBack after = nullptr) override; - - Status Resize(const std::vector &inputs, - const std::vector> &dims) override; - - std::vector GetInputs() override; - - std::vector GetOutputs() override; - - private: - std::shared_ptr SelectExecutor(); - bool ResizeKernels(); - - private: - std::shared_ptr execution_plan_ = nullptr; - std::shared_ptr executor_ = nullptr; -}; -} // namespace mindspore - -#endif // MINDSPORE_LITE_SRC_EXTENDRT_GRAPH_RUNTIME_DEFAULT_GRAPH_RUNTIME_H_ diff --git a/mindspore-lite/src/extendrt/graph_runtime/factory.h b/mindspore-lite/src/extendrt/graph_runtime/factory.h deleted file mode 100644 index 1d13ae58b53751f414bd91e572c3360eaa837323..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/graph_runtime/factory.h +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2019-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_GRAPH_RUNTIME_FACTORY_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_GRAPH_RUNTIME_FACTORY_H_ - -#include -#include - -#include "extendrt/graph_runtime/type.h" -#include "infer/graph_runtime.h" - -namespace mindspore { -using GraphRuntimeRegFunc = std::function()>; - -class GraphRuntimeRegistry { - public: - GraphRuntimeRegistry() = default; - virtual ~GraphRuntimeRegistry() = default; - - static GraphRuntimeRegistry &GetInstance(); - - void RegRuntime(const GraphRuntimeType &type, const GraphRuntimeRegFunc &creator); - - std::shared_ptr GetRuntime(const mindspore::GraphRuntimeType &type); - - private: - mindspore::HashMap graph_runtime_map_; -}; - -class GraphRuntimeRegistrar { - public: - GraphRuntimeRegistrar(const mindspore::GraphRuntimeType &type, const GraphRuntimeRegFunc &creator) { - GraphRuntimeRegistry::GetInstance().RegRuntime(type, creator); - } - ~GraphRuntimeRegistrar() = default; -}; - -#define REG_GRAPH_RUNTIME(type, creator) static GraphRuntimeRegistrar g_##type##GraphRuntime(type, creator); -} // namespace mindspore - -#endif // MINDSPORE_LITE_SRC_EXTENDRT_GRAPH_RUNTIME_FACTORY_H_ diff --git a/mindspore-lite/src/extendrt/graph_runtime/type.h b/mindspore-lite/src/extendrt/graph_runtime/type.h deleted file mode 100644 index 2da8d1a4541652dec3f49901c58ae6655111eb83..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/graph_runtime/type.h +++ /dev/null @@ -1,22 +0,0 @@ -/** - * Copyright 2019-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_EXTENDRT_GRAPH_RUNTIME_TYPE_H_ -#define MINDSPORE_LITE_EXTENDRT_GRAPH_RUNTIME_TYPE_H_ - -namespace mindspore { -enum GraphRuntimeType { kDefaultRuntime = 0, kNoneRuntime }; -} // namespace mindspore -#endif // MINDSPORE_LITE_EXTENDRT_GRAPH_RUNTIME_TYPE_H_ diff --git a/mindspore-lite/src/extendrt/infer_device_address.cc b/mindspore-lite/src/extendrt/infer_device_address.cc deleted file mode 100644 index db269626f770550d28ebd84c840578cbe0e239af..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/infer_device_address.cc +++ /dev/null @@ -1,101 +0,0 @@ -/** - * Copyright 2019-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include "src/extendrt/infer_device_address.h" - -namespace mindspore { -void InferDeviceAddress::ClearDeviceMemory() { - if (GetDevicePtr() == nullptr) { - return; - } - free(GetDevicePtr()); - SetDevicePtr(nullptr); -} - -bool InferDeviceAddress::SyncDeviceToHost(const ShapeVector &, size_t size, TypeId type, void *host_ptr, - bool sync_on_demand) const { - // The input or output may be empty. - if ((size == 0) || (GetSize() == 0)) { - MS_LOG(INFO) << "No need sync, host size: " << size << ", device size: " << GetSize(); - return true; - } - if (GetDevicePtr() == nullptr) { - MS_LOG(ERROR) << "The pointer device ptr is null!"; - return false; - } - if (host_ptr == GetDevicePtr()) { - MS_LOG(DEBUG) << "host_ptr is equal to device ptr, request ignored."; - return true; - } - - if (type == type_id()) { - if (size > GetSize()) { - MS_LOG(WARNING) << "Please check whether need sync data, host size: " << size << ", device size: " << GetSize(); - return true; - } - errno_t ret_code = memcpy_s(host_ptr, size, GetDevicePtr(), size); - // Return ERANGE when the copy size is larger than SECUREC_MEM_MAX_LEN. - if (ret_code != EOK) { - MS_LOG(ERROR) << "Failed to copy tensor!"; - return false; - } else { - return true; - } - } - return true; -} - -bool InferDeviceAddress::SyncHostToDevice(const ShapeVector &, size_t size, TypeId type, const void *host_ptr, - const std::string &) const { - // The input or output may be empty. - if ((size == 0) || (GetSize() == 0)) { - MS_LOG(INFO) << "No need sync, host size: " << size << ", device size: " << GetSize(); - return true; - } - if (GetDevicePtr() == nullptr) { - MS_LOG(ERROR) << "The pointer device ptr() is null!"; - return false; - } - if (host_ptr == GetDevicePtr()) { - MS_LOG(DEBUG) << "host_ptr is equal to device ptr request ignored."; - return true; - } - - if (type == type_id()) { - if (size > GetSize()) { - MS_LOG(WARNING) << "Please check whether need sync data, host size: " << size << ", device size: " << GetSize(); - return true; - } - - // If the value of host is a scalar type, then the host addr is a temporary address, which will be released after - // the sync ends. Therefore, if the value is less than 16, it needs to be copied. -#ifndef __APPLE__ - const size_t kCopySize = 16; - if (size <= kCopySize) { - return ((memcpy_s(GetDevicePtr(), size, host_ptr, size) != EOK) ? false : true); - } -#endif - - SetDevicePtr(const_cast(host_ptr)); - set_original_ref_count(SIZE_MAX); - set_ref_count(SIZE_MAX); - } - return true; -} -} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/infer_device_address.h b/mindspore-lite/src/extendrt/infer_device_address.h deleted file mode 100644 index 36c8b0831109849b9aee0b2ad9caa3c648a0824b..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/infer_device_address.h +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2019-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_EXTENDRT_INFER_DEVICE_ADDRESS_H_ -#define MINDSPORE_LITE_EXTENDRT_INFER_DEVICE_ADDRESS_H_ - -#include - -#include "common/device_address.h" - -using DeviceAddress = mindspore::device::DeviceAddress; - -namespace mindspore { -class InferDeviceAddress : public DeviceAddress { - public: - InferDeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id) - : DeviceAddress(ptr, size, format, type_id) {} - ~InferDeviceAddress() override { ClearDeviceMemory(); } - - bool SyncDeviceToHost(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr, - bool sync_on_demand = false) const override; - bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr, - const std::string &format) const override; - void ClearDeviceMemory() override; -}; -} // namespace mindspore - -#endif // MINDSPORE_LITE_EXTENDRT_INFER_DEVICE_ADDRESS_H_ diff --git a/mindspore-lite/src/extendrt/infer_session.cc b/mindspore-lite/src/extendrt/infer_session.cc index 3648c0121f86310a6139fd7b434433a38527243d..8128956b4b68e60b84213b9896600596233d73b1 100644 --- a/mindspore-lite/src/extendrt/infer_session.cc +++ b/mindspore-lite/src/extendrt/infer_session.cc @@ -18,11 +18,9 @@ #include "common/ms_factory.h" #include "extendrt/delegate/factory.h" #include "extendrt/session/factory.h" -#include "extendrt/delegate/plugin/tensorrt_executor_plugin.h" #include "extendrt/delegate/plugin/litert_executor_plugin.h" #include "extendrt/delegate/plugin/ascend_ge_executor_plugin.h" -#include "extendrt/delegate/plugin/ascend_native_executor_plugin.h" -#include "extendrt/kernel/ascend/plugin/ascend_kernel_plugin.h" +#include "extendrt/delegate/plugin/ascend_acl_executor_plugin.h" #include "nnacl/op_base.h" namespace mindspore { @@ -36,18 +34,9 @@ void AscendPluginRegistration(const std::shared_ptr &ascend_de return; } } - if (use_experimental_rts) { - constexpr auto default_acl_provider = "acl"; - constexpr auto default_ascend_native_provider = "ascend_native"; - if (provider == default_ascend_native_provider) { - if (!lite::AscendNativeExecutorPlugin::GetInstance().Register()) { - MS_LOG(WARNING) << "Failed to register Ascend Native plugin"; - return; - } - } - - if (provider == default_acl_provider) { - if (!kernel::AscendKernelPlugin::Register()) { + if (!use_experimental_rts) { + if (provider == "litert") { + if (!lite::AscendAclExecutorPlugin::GetInstance().Register()) { MS_LOG(WARNING) << "Failed to register Ascend ACL plugin"; return; } @@ -61,7 +50,6 @@ std::shared_ptr InferSession::CreateSession(const std::shared_ptr< bool use_experimental_rts = env != nullptr && strcmp(env, "on") == 0; HandleContext(context, use_experimental_rts); auto session_type = SelectSession(context, use_experimental_rts); - MS_LOG(DEBUG) << "Session type " << static_cast(session_type); return SessionRegistry::GetInstance().GetSession(session_type, context, config_info); } @@ -69,32 +57,18 @@ void InferSession::HandleContext(const std::shared_ptr &context, bool u if (!context) { return; } - constexpr auto default_gpu_provider = "tensorrt"; constexpr auto default_cpu_provider = "litert"; auto device_infos = context->MutableDeviceInfo(); for (auto &device_info : device_infos) { if (!device_info) { - continue; - } - if (device_info->GetDeviceType() == kGPU) { - auto gpu_device = device_info->Cast(); - if (!gpu_device) { - continue; - } - auto provider = gpu_device->GetProvider(); - if (provider.empty() || provider == default_gpu_provider) { - if (!lite::TensorRTExecutorPlugin::GetInstance().Register()) { - MS_LOG(WARNING) << "Failed to register TensorRT plugin"; - return; - } - gpu_device->SetProvider(default_gpu_provider); - } + MS_LOG(WARNING) << "device info is nullptr."; continue; } if (device_info->GetDeviceType() == kAscend) { auto ascend_device = device_info->Cast(); if (!ascend_device) { + MS_LOG(WARNING) << "not ascend device."; continue; } AscendPluginRegistration(ascend_device, use_experimental_rts); @@ -103,6 +77,7 @@ void InferSession::HandleContext(const std::shared_ptr &context, bool u if (device_info->GetDeviceType() == kCPU) { auto cpu_device = device_info->Cast(); if (!cpu_device) { + MS_LOG(WARNING) << "cpu_device"; continue; } auto provider = cpu_device->GetProvider(); @@ -132,7 +107,7 @@ SessionType InferSession::SelectSession(const std::shared_ptr &context, if (device_context->GetProvider() == "ge") { return kDelegateSession; } - return kSingleOpSession; + return kDelegateSession; } return kDelegateSession; } @@ -147,5 +122,4 @@ Status InferSession::Finalize() { MS_LOG(INFO) << "Finalize is only implemented in single_op_session now."; return kLiteError; } - } // namespace mindspore diff --git a/mindspore-lite/src/extendrt/infer_session.h b/mindspore-lite/src/extendrt/infer_session.h index b7fdac1b82bed5791eb0ee4788c2c5f981b30008..2854a664789a48f370ecfdee0aa314ddcc1b64e0 100644 --- a/mindspore-lite/src/extendrt/infer_session.h +++ b/mindspore-lite/src/extendrt/infer_session.h @@ -64,6 +64,7 @@ class InferSession : public std::enable_shared_from_this { /// /// \param[in] model_data Define the buffer read from a model file. /// \param[in] data_size Define bytes number of model buffer. + /// /// \return Status. virtual Status CompileGraph(const void *model_data, size_t data_size, uint32_t *graph_id) { (void)model_data; @@ -78,8 +79,8 @@ class InferSession : public std::enable_shared_from_this { /// \param[out] outputs Which is a pointer to a vector. The model outputs are filled in the container in sequence. /// /// \return Status. - virtual Status RunGraph(uint32_t graph_id, const std::vector &inputs, - std::vector *outputs) = 0; + virtual Status RunGraph(uint32_t graph_id, const std::vector &inputs, + std::vector *outputs) = 0; /// \brief Run Model Graph to inference. /// @@ -89,8 +90,8 @@ class InferSession : public std::enable_shared_from_this { /// \param[in] after CallBack after predict. /// /// \return Status. - virtual Status RunGraph(uint32_t graph_id, const std::vector &inputs, - std::vector *outputs, const MSKernelCallBack &before, + virtual Status RunGraph(uint32_t graph_id, const std::vector &inputs, + std::vector *outputs, const MSKernelCallBack &before, const MSKernelCallBack &after) = 0; /// \brief Resize model inputs shape and memory from specified dims. @@ -99,7 +100,7 @@ class InferSession : public std::enable_shared_from_this { /// \param[in] dims Define dst resize shapes. /// /// \return Status. - virtual Status Resize(uint32_t graph_id, const std::vector &inputs, + virtual Status Resize(uint32_t graph_id, const std::vector &inputs, const std::vector> &dims) { (void)graph_id; (void)inputs; @@ -137,7 +138,7 @@ class InferSession : public std::enable_shared_from_this { /// \return The input tensor with the given name, if the name is not found, an invalid tensor is returned. virtual MutableTensorImplPtr GetInputByTensorName(uint32_t graph_id, const std::string &name) = 0; - virtual Status UpdateWeights(const std::vector>> &weights) { + virtual Status UpdateWeights(const std::vector>> &weights) { return kLiteError; } diff --git a/mindspore-lite/src/extendrt/kernel/acl/acl_kernel_lib.cc b/mindspore-lite/src/extendrt/kernel/acl/acl_kernel_lib.cc deleted file mode 100644 index 88ca77f76fc74c554164c7b2766bfa7d269c3490..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/kernel/acl/acl_kernel_lib.cc +++ /dev/null @@ -1,65 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/kernel/acl/acl_kernel_lib.h" -#include "src/extendrt/kernel/ascend/api/ascend_kernel_api.h" -#include "src/extendrt/kernel/acl/acl_lite_kernel.h" -#include "src/extendrt/kernel/kernel_spec_infos.h" -#include "common/ms_factory.h" -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "src/infer/graph_compiler.h" -#include "src/common/common.h" - -namespace mindspore::kernel { -std::shared_ptr AclKernelLib::CreateKernelMod(const PrimitiveType &op_type, const KernelAttr &attr, - const Format &format, const std::string &backend) { - if (backend != kBackendAscend) { - MS_LOG(INFO) << "AclKernelLib only support Ascend backend, but got: " << backend << "."; - return nullptr; - } - if (!MatchFormat(format, Format::NCHW)) { - MS_LOG(INFO) << "AclKernelLib only support NCHW layout, but got " << FormatEnumToString(format); - return nullptr; - } - - auto kernel_name = lite::kNameCustomAscend; - std::shared_ptr kernel_mod = kernel::Factory::Instance().Create(kernel_name); - - if (kernel_mod == nullptr) { - MS_LOG(INFO) << "Create kernel mod failed. kernel: " << op_type.TypeName(); - return nullptr; - } - // acl custom inputs and outputs number is not fixed, so do not checkout kernel attr here - return kernel_mod; -} - -bool AclKernelLib::Support(const PrimitiveType &op_type, const KernelAttr &attr, const std::string &backend, - const Format &format) const { - return AclKernelLib::CreateKernelMod(op_type, attr, format, backend) != nullptr; -} - -BaseKernel *AclKernelLib::CreateKernel(const KernelSpec &spec, const std::vector &inputs, - const std::vector &outputs, const InferContext *ctx) const { - auto kernel_mod = AclKernelLib::CreateKernelMod(spec.op_type, spec.attr, spec.format, spec.backend); - if (kernel_mod == nullptr) { - MS_LOG(ERROR) << "Create kernel mod failed. kernel: " << spec.op_type.TypeName(); - return nullptr; - } - return new (std::nothrow) AclLiteKernel(kernel_mod, spec.primitive, inputs, outputs, ctx); -} - -REG_KERNEL_LIB(kAclKernelLibName, AclKernelLib); -} // namespace mindspore::kernel diff --git a/mindspore-lite/src/extendrt/kernel/acl/acl_kernel_lib.h b/mindspore-lite/src/extendrt/kernel/acl/acl_kernel_lib.h deleted file mode 100644 index e5c176e64aedf560be3190a319b8954a6be66525..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/kernel/acl/acl_kernel_lib.h +++ /dev/null @@ -1,44 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ACL_DEFAULT_LIB_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ACL_DEFAULT_LIB_H_ - -#include -#include -#include -#include "src/extendrt/kernel/kernel_lib.h" -#include "src/extendrt/kernel/kernel_spec_infos.h" - -using mindspore::infer::abstract::Tensor; - -namespace mindspore::kernel { -class AclKernelLib : public KernelLib { - public: - AclKernelLib() : KernelLib(kAclKernelLibName, "Ascend") {} - - bool Support(const PrimitiveType &op_type, const KernelAttr &attr, const std::string &backend, - const Format &format = DEFAULT_FORMAT) const override; - BaseKernel *CreateKernel(const KernelSpec &spec, const std::vector &inputs, - const std::vector &outputs, const InferContext *ctx) const override; - - private: - static std::shared_ptr CreateKernelMod(const PrimitiveType &op_type, - const KernelAttr &attr, const Format &format, - const std::string &backend); -}; -} // namespace mindspore::kernel -#endif // MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ACL_DEFAULT_LIB_H_ diff --git a/mindspore-lite/src/extendrt/kernel/acl/acl_lite_kernel.cc b/mindspore-lite/src/extendrt/kernel/acl/acl_lite_kernel.cc deleted file mode 100644 index 7ba479f494665750e54418fbadf608d177892173..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/kernel/acl/acl_lite_kernel.cc +++ /dev/null @@ -1,153 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include "src/extendrt/kernel/acl/acl_lite_kernel.h" -#include "src/extendrt/utils/tensor_utils.h" - -using mindspore::lite::RET_ERROR; -using mindspore::lite::RET_OK; - -namespace mindspore::kernel { -AclLiteKernel::AclLiteKernel(std::shared_ptr kernel_mod, BaseOperatorPtr base_operator, - std::vector in_tensors, std::vector out_tensors, - const lite::InnerContext *ctx) - : BaseKernel({base_operator, nullptr}, std::move(in_tensors), std::move(out_tensors), ctx), - kernel_mod_(std::move(kernel_mod)), - base_operator_(std::move(base_operator)) { - inputs_ = CloudTensorUtils::LiteTensorToKernelTensorPtrVec(in_tensors_); - outputs_ = CloudTensorUtils::LiteTensorToKernelTensorPtrVec(out_tensors_); -} - -int AclLiteKernel::Prepare() { - bool ret = kernel_mod_->Init(base_operator_->GetPrim(), inputs_, outputs_); - return ret ? ReSize() : RET_ERROR; -} - -int AclLiteKernel::ReSize() { - // acl custom kernel last input is om data, do not pass to resize - std::vector kernel_inputs; - kernel_inputs.assign(inputs_.begin(), inputs_.end() - 1); - - return kernel_mod_->Resize(kernel_inputs, outputs_); -} - -int AclLiteKernel::InferShape() { - // new shape is already updated in in_tensors_, infer shape base on in_tensors_ - - // current acl do not support change static shape - - // if infer of in_tensors_ is not changed, do nothing - bool shape_changed = false; - if (inputs_.size() != in_tensors_.size()) { - MS_LOG(ERROR) << "New shape size " << in_tensors_.size() << " is not the same with old shape size " - << inputs_.size(); - return lite::RET_ERROR; - } - // in_tensors_ last is om data, delete it - for (size_t i = 0; i < inputs_.size() - 1; i++) { - auto new_input = in_tensors_.at(i); - auto old_input = inputs_.at(i); - - auto new_shape = new_input->shape(); - auto is_dynamic = std::any_of(new_shape.begin(), new_shape.end(), [i](auto dim) { return dim < 0; }); - if (is_dynamic) { - MS_LOG(ERROR) << "New shape of input " << i << " cannot be dynamic, new shape: " << new_shape; - return lite::RET_NOT_SUPPORT; - } - if (old_input->GetShapeVector() != new_input->shape64()) { - shape_changed = true; - } - } - if (!shape_changed) { - for (size_t i = 0; i < outputs_.size(); i++) { - auto new_output = out_tensors_.at(i); - auto old_output = outputs_.at(i); - new_output->set_shape64(old_output->GetShapeVector()); - new_output->set_data_type(old_output->dtype_id()); - } - return lite::RET_OK; - } - - return lite::RET_NOT_SUPPORT; -} - -int AclLiteKernel::Run() { - std::vector workspace; - if (in_tensors_.size() != inputs_.size()) { - MS_LOG(ERROR) << "Given inputs size " << in_tensors_.size() << " != graph inputs size " << inputs_.size(); - return kLiteError; - } - for (size_t i = 0; i < in_tensors_.size() - 1; i++) { - auto &input = in_tensors_[i]; - auto &kernel_input = inputs_[i]; - if (input->Size() != kernel_input->size()) { - MS_LOG(ERROR) << "Byte size of input " << i << " != the size expected, given size " << input->Size() - << ", expected size " << kernel_input->size() - << ", input shape: " << kernel_input->GetShapeVector(); - return kLiteError; - } - auto input_device_address = input->device_data(); - if (input_device_address != nullptr) { - auto device_ptr = input_device_address; - kernel_input->SetData(std::make_shared(device_ptr, input->Size())); - kernel_input->SetHostData(nullptr); - } else { - kernel_input->SetHostData(std::make_shared(input->data(), input->Size())); - kernel_input->SetData(nullptr); - } - } - // solve out_tensors empty case - if (out_tensors_.empty()) { - std::transform(outputs_.begin(), outputs_.end(), std::back_inserter(out_tensors_), [](auto &item) { - auto shape64 = item->GetShapeVector(); - std::vector shape; - std::transform(shape64.begin(), shape64.end(), std::back_inserter(shape), - [](auto &value) { return static_cast(value); }); - return new lite::Tensor(item->dtype_id(), shape); - }); - } - if (out_tensors_.size() != outputs_.size()) { - MS_LOG(ERROR) << "Given outputs size " << out_tensors_.size() << " != graph inputs size " << outputs_.size(); - return kLiteError; - } - for (size_t i = 0; i < out_tensors_.size(); i++) { - auto output = out_tensors_[i]; - auto kernel_output = outputs_[i]; - if (output->Size() != kernel_output->size()) { - MS_LOG(ERROR) << "Byte size of output " << i << " != the size expected, given size " << output->Size() - << ", expected size " << kernel_output->size() - << ", output shape: " << kernel_output->GetShapeVector(); - return kLiteError; - } - auto output_device_address = output->device_data(); - if (output_device_address != nullptr) { - auto device_ptr = output_device_address; - kernel_output->SetData(std::make_shared(device_ptr, output->Size())); - kernel_output->SetHostData(nullptr); - } else { - kernel_output->SetHostData(std::make_shared(output->data(), output->Size())); - kernel_output->SetData(nullptr); - } - } - - auto ret = kernel_mod_->Launch(workspace, workspace, workspace, nullptr); - - return ret ? RET_OK : RET_ERROR; -} -} // namespace mindspore::kernel diff --git a/mindspore-lite/src/extendrt/kernel/acl/acl_lite_kernel.h b/mindspore-lite/src/extendrt/kernel/acl/acl_lite_kernel.h deleted file mode 100644 index bd73258df6fd9d212072054b4c7b08f066328616..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/kernel/acl/acl_lite_kernel.h +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ACL_ACL_LITE_KERNEL_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ACL_ACL_LITE_KERNEL_H_ - -#include -#include -#include -#include -#include "src/extendrt/kernel/base_kernel.h" -#include "common/kernel.h" -#include "ops/base_operator.h" - -namespace mindspore::kernel { -class AclLiteKernel : public BaseKernel { - public: - explicit AclLiteKernel(std::shared_ptr kernel_mod, BaseOperatorPtr base_operator, - std::vector in_tensors, std::vector out_tensors, - const lite::InnerContext *ctx); - ~AclLiteKernel() override = default; - - int Prepare() override; - int ReSize() override; - int Run() override; - int InferShape(); - - public: - std::string KernelType() const { return base_operator_->name(); } - - private: - KernelModPtr kernel_mod_; - BaseOperatorPtr base_operator_; - std::vector inputs_; - std::vector outputs_; -}; -} // namespace mindspore::kernel -#endif // MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ACL_ACL_LITE_KERNEL_H_ diff --git a/mindspore-lite/src/extendrt/kernel/acl/custom_acl.cc b/mindspore-lite/src/extendrt/kernel/acl/custom_acl.cc deleted file mode 100644 index f6fb0e8c5f0aa72796b76aece268dcc71f1ead7c..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/kernel/acl/custom_acl.cc +++ /dev/null @@ -1,65 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -#include "infer/custom.h" -#include "abstract/abstract_value.h" -#include "abstract/dshape.h" -#include "abstract/ops/op_infer.h" -#include "abstract/ops/primitive_infer_map.h" -#include "base/base.h" -#include "ir/anf.h" -#include "ir/dtype/number.h" -#include "ir/dtype/type.h" -#include "ir/primitive.h" -#include "mindapi/base/shape_vector.h" -#include "mindapi/base/shared_ptr.h" -#include "mindapi/base/type_id.h" -#include "ops/primitive_c.h" -#include "utils/check_convert_utils.h" -#include "src/common/log_adapter.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" - -namespace mindspore { -namespace ops { -namespace { -constexpr int64_t dynamic_rank_shape_value = -2; -} -class CustomAclInfer : public abstract::OpInferBase { - public: - BaseShapePtr InferShape(const PrimitivePtr &primitive, - const std::vector &input_args) const override { - // Acl Custom ops cannot infer at build time, fill with -2 for dynamic - ShapeVector ret_shape; - ret_shape.push_back(dynamic_rank_shape_value); - return std::make_shared(ret_shape); - } - - TypePtr InferType(const PrimitivePtr &primitive, const std::vector &input_args) const override { - // return first input datatype - auto op_name = primitive->name(); - auto x = CheckAndConvertUtils::CheckArgs(op_name, input_args, 0); - return x->element()->GetTypeTrack(); - } -}; - -GVAR_DEF(PrimitivePtr, kPrimCustomAcl, std::make_shared(kNameCustom)); -REGISTER_PRIMITIVE_OP_INFER_IMPL(Custom, kPrimCustomAcl, CustomAclInfer, false); -} // namespace ops -} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/kernel/ascend/plugin/ascend_kernel_plugin.cc b/mindspore-lite/src/extendrt/kernel/ascend/plugin/ascend_kernel_plugin.cc deleted file mode 100644 index d6321eea6384979f334c20b4db0c4a7f3c557393..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/kernel/ascend/plugin/ascend_kernel_plugin.cc +++ /dev/null @@ -1,132 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "extendrt/kernel/ascend/plugin/ascend_kernel_plugin.h" -#include -#include -#include "src/common/log_adapter.h" -#include "include/errorcode.h" -#include "common/ms_factory.h" -#if !defined(_WIN32) -#include -#include "extendrt/cxx_api/dlutils.h" -#endif - -namespace mindspore::kernel { -std::mutex AscendKernelPlugin::mutex_; - -AscendKernelPlugin::AscendKernelPlugin() = default; - -AscendKernelPlugin::~AscendKernelPlugin() { - MS_LOG(DEBUG) << "~AscendKernelPlugin() begin."; - Unregister(); - MS_LOG(DEBUG) << "~AscendKernelPlugin() end."; -} - -Status AscendKernelPlugin::TryRegister() { - std::lock_guard lock(mutex_); - static AscendKernelPlugin instance; - return instance.TryRegisterInner(); -} - -bool AscendKernelPlugin::Register() { - std::lock_guard lock(mutex_); - static AscendKernelPlugin instance; - auto status = instance.TryRegisterInner(); - if (status.IsError()) { - MS_LOG(ERROR) << status.ToString(); - return false; - } - MS_LOG(INFO) << "Register ascend kernel plugin success."; - return true; -} - -Status AscendKernelPlugin::TryRegisterInner() { -#if !defined(_WIN32) - if (is_registered_) { - return kSuccess; - } - Dl_info dl_info; - dladdr(reinterpret_cast(this), &dl_info); - std::string cur_so_path = dl_info.dli_fname; - auto converter_pos = cur_so_path.find("libmindspore_converter.so"); - if (converter_pos != std::string::npos) { - MS_LOG(INFO) << "libmindspore_converter.so does not need to register"; - return kSuccess; - } - auto pos = cur_so_path.find("libmindspore-lite.so"); - if (pos == std::string::npos) { - MS_LOG(DEBUG) << "Could not find libmindspore-lite so, cur so path: " << cur_so_path; - auto c_lite_pos = cur_so_path.find("_c_lite"); - if (c_lite_pos == std::string::npos) { - return {kLiteError, "Could not find _c_lite so, cur so path: " + cur_so_path}; - } - pos = c_lite_pos; - } - std::string parent_dir = cur_so_path.substr(0, pos); - std::string ascend_kernel_plugin_path; - auto ret = FindSoPath(parent_dir, "libascend_kernel_plugin.so", &ascend_kernel_plugin_path); - if (ret != kSuccess) { - return {kLiteError, "Get real path of libascend_kernel_plugin.so failed."}; - } - MS_LOG(INFO) << "Find ascend kernel plugin so success, path = " << ascend_kernel_plugin_path; - void *function = nullptr; - ret = DLSoOpen(ascend_kernel_plugin_path, "CreateCustomAscendKernel", &handle_, &function); - if (ret != kSuccess) { - return {kLiteError, "DLSoOpen failed, so path: " + ascend_kernel_plugin_path}; - } - auto create_kernel_func = reinterpret_cast *(*)(void)>(function); - if (create_kernel_func == nullptr) { - return {kLiteError, "Cast CreateCustomAscendKernel failed."}; - } - create_kernel_map_ = create_kernel_func(); - if (create_kernel_map_ == nullptr) { - return {kLiteError, "Create custom ascend kernel failed."}; - } - // register - for (auto &kernel : *create_kernel_map_) { - if (!kernel::Factory::Instance().IsRegistered(kernel.first)) { - kernel::Factory::Instance().Register(kernel.first, std::move(kernel.second)); - register_kernels_.push_back(kernel.first); - } - } - is_registered_ = true; - return kSuccess; -#endif -} - -void AscendKernelPlugin::Unregister() { -#if !defined(_WIN32) - if (handle_ == nullptr) { - MS_LOG(INFO) << "Handle is nullptr."; - return; - } - for (auto &kernel : register_kernels_) { - kernel::Factory::Instance().UnRegister(kernel); - } - auto destroy_map_func = - reinterpret_cast *)>(dlsym(handle_, "DestroyCustomAscendKernel")); - if (destroy_map_func == nullptr) { - MS_LOG(ERROR) << "Undefined symbol DestroyCustomAscendKernel in ['libascend_kernel_plugin.so']."; - return; - } - destroy_map_func(create_kernel_map_); - (void)dlclose(handle_); - handle_ = nullptr; - is_registered_ = false; -#endif -} -} // namespace mindspore::kernel diff --git a/mindspore-lite/src/extendrt/kernel/ascend/src/custom_ascend_kernel.cc b/mindspore-lite/src/extendrt/kernel/ascend/src/custom_ascend_kernel.cc deleted file mode 100644 index bd1f0bc7028542cbfd7213fd9db2f64b1a0124e1..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/kernel/ascend/src/custom_ascend_kernel.cc +++ /dev/null @@ -1,339 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "extendrt/kernel/ascend/src/custom_ascend_kernel.h" -#include -#include -#include "include/registry/register_kernel.h" -#include "include/api/types.h" -#include "include/api/data_type.h" -#include "extendrt/kernel/ascend/model/model_infer.h" -#include "infer/custom.h" -#include "common/ms_factory.h" -#include "src/common/log_util.h" -#include "src/common/common.h" -#include "common/log_adapter.h" -#include "plugin/res_manager/ascend/symbol_interface/acl_rt_symbol.h" -#include "plugin/res_manager/ascend/symbol_interface/symbol_utils.h" - -bool SaveOM(const void *model, size_t length, const std::string &file_path) { return true; } - -namespace mindspore::kernel { -namespace acl { -namespace { -std::mutex g_mem_mutex; -} -CustomAscendKernelMod::CustomAscendKernelMod() - : load_model_(false), acl_options_(nullptr), model_infer_(nullptr), input_data_idx_(0) {} - -CustomAscendKernelMod::~CustomAscendKernelMod() { - std::unique_lock lock(g_mem_mutex); - if (load_model_ || is_multi_model_sharing_mem_prepare_) { - MS_LOG(INFO) << "Delete model from AclEnvGuard!"; - AclEnvGuard::DeleteModel(model_infer_); - if (!model_infer_->Finalize()) { - MS_LOG(ERROR) << "Model finalize failed."; - } - } -} - -bool CustomAscendKernelMod::Finalize() { return AclEnvGuard::Finalize(); } - -void CustomAscendKernelMod::RecordInputDataIndex(const std::vector &inputs) { - for (size_t idx = 0; idx < inputs.size(); ++idx) { - if (inputs[idx] == nullptr) { - MS_LOG(ERROR) << "Input " << idx << " is invalid."; - return; - } - if (inputs[idx]->GetData() == nullptr) { - input_data_idx_ = idx; - break; - } - } -} - -AclModelOptionsPtr CustomAscendKernelMod::GenAclOptions() { - auto acl_options_ptr = std::make_shared(); - if (acl_options_ptr == nullptr) { - MS_LOG(ERROR) << "Acl options make shared failed."; - return nullptr; - } - auto profiling_path_val = primitive_->GetAttr(lite::kProfilingPathKey); - if (profiling_path_val != nullptr) { - auto val = GetValue(profiling_path_val); - acl_options_ptr->profiling_path = val; - } - auto dump_path_val = primitive_->GetAttr(lite::kDumpPathKey); - if (dump_path_val != nullptr) { - auto val = GetValue(dump_path_val); - acl_options_ptr->dump_path = val; - } - auto inner_calc_workspace_size = primitive_->GetAttr(lite::kInnerCalcWorkspaceSize); - if (inner_calc_workspace_size != nullptr) { - auto val = GetValue(inner_calc_workspace_size); - acl_options_ptr->multi_model_sharing_mem_prepare = val; - is_multi_model_sharing_mem_prepare_ = true; - } - auto inner_sharing_workspace = primitive_->GetAttr(lite::kInnerSharingWorkspace); - if (inner_sharing_workspace != nullptr) { - auto val = GetValue(inner_sharing_workspace); - acl_options_ptr->multi_model_sharing_mem = val; - } - auto inner_model_path = primitive_->GetAttr(lite::kInnerModelPath); - if (inner_model_path != nullptr) { - auto val = GetValue(inner_model_path); - acl_options_ptr->model_path = val; - } - auto workspace_key = primitive_->GetAttr(lite::kInnerWorkspace); - if (workspace_key != nullptr) { - auto val = GetValue(workspace_key); - acl_options_ptr->share_workspace = val; - } - auto weightspace_key = primitive_->GetAttr(lite::kInnerWeightspace); - if (weightspace_key != nullptr) { - auto val = GetValue(weightspace_key); - acl_options_ptr->share_weightspace = val; - } - auto weightspace_workspace_key = primitive_->GetAttr(lite::kInnerWeightspaceWorkspace); - if (weightspace_workspace_key != nullptr) { - auto val = GetValue(weightspace_workspace_key); - acl_options_ptr->share_weightspace_workspace = val; - } - auto bundle_model = primitive_->GetAttr(lite::kBundleModel); - if (bundle_model != nullptr) { - auto val = GetValue(bundle_model); - acl_options_ptr->is_bundle_model = val; - } - // set device id - uint32_t device_count; - if (CALL_ASCEND_API(aclrtGetDeviceCount, &device_count) != ACL_SUCCESS) { - MS_LOG(WARNING) << "Get device count failed, set default device id 0."; - return acl_options_ptr; - } - if (device_id_ >= device_count) { - MS_LOG(WARNING) << "Current device id " << device_id_ << " is larger than max count " << device_count - << ",please check the device info of context and set the default device id 0."; - return acl_options_ptr; - } - acl_options_ptr->device_id = static_cast(device_id_); - MS_LOG(INFO) << "Set device id " << device_id_; - return acl_options_ptr; -} - -bool CustomAscendKernelMod::Init(const std::vector &inputs, - const std::vector &outputs) { - std::unique_lock lock(g_mem_mutex); - inputs_ = inputs; - outputs_ = outputs; - if (load_model_) { - MS_LOG(INFO) << "Om has been loaded in custom kernel."; - return true; - } - // last input is as specific usage - inputs_.resize(inputs.size() - 1); - - if (inputs.empty() || outputs.empty()) { - MS_LOG(ERROR) << "Custom kernel has empty inputs or outputs, which is invalid."; - return false; - } - acl_options_ = GenAclOptions(); - if (acl_options_ == nullptr) { - MS_LOG(ERROR) << "Generate acl options failed."; - return false; - } - auto &om_input = inputs.back(); - if (om_input == nullptr || om_input->GetData() == nullptr) { - MS_LOG(ERROR) << "Om data input is invalid, inputs size " << inputs.size(); - return false; - } - auto om_data = om_input->GetData(); - model_infer_ = std::make_shared(acl_options_); - if (model_infer_ == nullptr) { - MS_LOG(ERROR) << "Create ModelInfer failed."; - return false; - } - if (!model_infer_->Init()) { - MS_LOG(ERROR) << "Model infer init failed."; - return false; - } - if (!model_infer_->Load(om_data->addr, om_data->size)) { - MS_LOG(ERROR) << "Load om data failed."; - return false; - } - - SaveOM(om_data->addr, om_data->size, "./"); - - if (is_multi_model_sharing_mem_prepare_) { - MS_LOG(INFO) << "is multi model sharing mem prepare."; - return true; - } - UpdateInputKernelTensorInfo(); - UpdateOutputKernelTensorInfo(); - MS_LOG(INFO) << "Load om data success."; - load_model_ = true; - MS_LOG(INFO) << "Add model to AclEnvGuard!"; - AclEnvGuard::AddModel(model_infer_); - return true; -} - -int CustomAscendKernelMod::Resize(const std::vector &inputs, - const std::vector &outputs) { - if (!load_model_) { - MS_LOG(ERROR) << "Load model failed when resize."; - return lite::RET_ERROR; - } - - if (KernelMod::Resize(inputs, outputs) != KRET_OK) { - MS_LOG(WARNING) << "Invalid inputs or output shapes."; - } - - if (inputs.size() < 1) { - MS_LOG(ERROR) << "inputs size is less than one."; - return lite::RET_ERROR; - } - if (!OnNewInputShapes(inputs)) { - MS_LOG(ERROR) << "Failed to resize inputs"; - return lite::RET_ERROR; - } - return lite::RET_OK; -} - -template -static bool CheckInputNums(const std::vector &update_info, const std::vector &inputs, size_t input_weight = 0) { - if (update_info.empty()) { - MS_LOG(ERROR) << "check update info size empty"; - return false; - } - if (update_info.size() + input_weight != inputs.size()) { - MS_LOG(ERROR) << "update info size and inputs size check failed. update info size: " << update_info.size() - << ". inputs' size: " << inputs.size() << ". input weight: " << input_weight - << "update_info:" << update_info; - return false; - } - return true; -} - -template -static bool CheckOutputNums(const std::vector &update_info, const std::vector &outputs) { - if (update_info.empty()) { - MS_LOG(ERROR) << "check update info size empty"; - return false; - } - if (update_info.size() != outputs.size()) { - MS_LOG(ERROR) << "update info size and outputs size check failed. update info size: " << update_info.size() - << ". outputs' size: " << outputs.size(); - return false; - } - return true; -} - -bool CustomAscendKernelMod::OnNewInputShapes(const std::vector &new_inputs) { - auto input_shapes = model_infer_->GetInputShape(); - if (input_shapes.size() != new_inputs.size()) { - MS_LOG(ERROR) << "Invalid new input size " << new_inputs.size() << ", expect input size " << input_shapes.size(); - return false; - } - bool input_shape_changed = false; - for (size_t i = 0; i < new_inputs.size(); i++) { - auto new_shape = new_inputs[i]->GetShapeVector(); - if (input_shapes[i] != new_shape) { - input_shape_changed = true; - } - } - if (!input_shape_changed) { - return true; - } - std::vector new_shapes; - std::transform(new_inputs.begin(), new_inputs.end(), std::back_inserter(new_shapes), - [](auto &t) { return t->GetShapeVector(); }); - if (!model_infer_->Resize(new_shapes)) { - MS_LOG(ERROR) << "Failed to Resize"; - return false; - } - UpdateInputKernelTensorInfo(); - UpdateOutputKernelTensorInfo(); - return true; -} - -bool CustomAscendKernelMod::Launch(const std::vector &, const std::vector &, - const std::vector &, void *stream_ptr) { - if (!load_model_) { - MS_LOG(ERROR) << "Init custom ascend kernel has been not ready."; - return false; - } - UpdateOutputKernelTensorInfo(); - if (!model_infer_->Inference(inputs_, outputs_)) { - MS_LOG(ERROR) << "Custom kernel execute failed."; - return false; - } - return true; -} - -bool CustomAscendKernelMod::UpdateWeights(const std::vector &kernel_weights, - const std::vector &, const std::vector &, - void *stream_ptr) { - if (!load_model_) { - MS_LOG(ERROR) << "Init custom ascend kernel has been not ready!"; - return false; - } - if (!model_infer_->UpdateWeights(kernel_weights)) { - MS_LOG(ERROR) << "Update weights failed!"; - return false; - } - return true; -} - -void CustomAscendKernelMod::UpdateOutputKernelTensorInfo() { - if (model_infer_ == nullptr) { - MS_LOG(ERROR) << "update input shape fail because model_infer_ is nullptr"; - return; - } - const std::vector shapes = model_infer_->GetOutputShape(); - const std::vector types = model_infer_->GetOutputDataType(); - const std::vector formats = model_infer_->GetOutputFormat(); - if (!CheckOutputNums(shapes, outputs_) || !CheckOutputNums(types, outputs_) || !CheckOutputNums(formats, outputs_)) { - return; - } - for (size_t i = 0; i < outputs_.size(); ++i) { - auto &output = outputs_[i]; - output->SetType(std::make_shared(TypeIdToType(types[i]))); - output->SetShape(std::make_shared(shapes[i])); - output->set_format(formats[i]); - } - return; -} -// In DVPP, model input shape and data type get modified -void CustomAscendKernelMod::UpdateInputKernelTensorInfo() { - if (model_infer_ == nullptr) { - MS_LOG(ERROR) << "update input shape fail because model_infer_ is nullptr"; - return; - } - const std::vector shapes = model_infer_->GetInputShape(); - const std::vector types = model_infer_->GetInputDataType(); - const std::vector formats = model_infer_->GetInputFormat(); - if (!CheckInputNums(shapes, inputs_) || !CheckInputNums(types, inputs_) || !CheckInputNums(formats, inputs_)) { - return; - } - - for (size_t i = 0; i < inputs_.size(); ++i) { - auto &input = inputs_[i]; - input->SetType(std::make_shared(TypeIdToType(types[i]))); - input->SetShape(std::make_shared(shapes[i])); - input->set_format(formats[i]); - } -} -} // namespace acl -} // namespace mindspore::kernel diff --git a/mindspore-lite/src/extendrt/kernel/ascend/src/custom_ascend_kernel.h b/mindspore-lite/src/extendrt/kernel/ascend/src/custom_ascend_kernel.h deleted file mode 100644 index 91e157a8caf42f93f7b11db9c08da53157dfbfd2..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/kernel/ascend/src/custom_ascend_kernel.h +++ /dev/null @@ -1,73 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_CUSTOM_ASCEND_KERNEL_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_CUSTOM_ASCEND_KERNEL_H_ - -#include -#include -#include -#include -#include "extendrt/kernel/ascend/options/acl_model_options.h" -#include "extendrt/kernel/ascend/model/model_infer.h" -#include "include/api/types.h" -#include "include/api/context.h" -#include "common/kernel.h" -#include "common/common_utils.h" -#include "include/errorcode.h" - -bool SaveOM(const void *model, size_t length, const std::string &file_path); - -namespace mindspore::kernel { -namespace acl { -class CustomAscendKernelMod : public kernel::KernelMod { - public: - CustomAscendKernelMod(); - ~CustomAscendKernelMod() override; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - bool Finalize() override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - - bool UpdateWeights(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr); - - std::vector GetOpSupport() override { return {}; } - - private: - void RecordInputDataIndex(const std::vector &inputs); - AclModelOptionsPtr GenAclOptions(); - void UpdateInputKernelTensorInfo(); - void UpdateOutputKernelTensorInfo(); - bool OnNewInputShapes(const std::vector &new_shapes); - - bool load_model_; - AclModelOptionsPtr acl_options_; - ModelInferPtr model_infer_; - size_t input_data_idx_; - bool is_multi_model_sharing_mem_prepare_ = false; - std::vector inputs_; - std::vector outputs_; -}; -} // namespace acl -} // namespace mindspore::kernel - -#endif // MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_CUSTOM_ASCEND_KERNEL_H_ diff --git a/mindspore-lite/src/extendrt/kernel/ascend_native/ascend_native_composite_kernel.cc b/mindspore-lite/src/extendrt/kernel/ascend_native/ascend_native_composite_kernel.cc deleted file mode 100644 index b04645d37d1b407f96ab61e511965cfd3522e514..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/kernel/ascend_native/ascend_native_composite_kernel.cc +++ /dev/null @@ -1,518 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "extendrt/kernel/ascend_native/ascend_native_composite_kernel.h" -#include -#include -#include "extendrt/delegate/ascend_native/ascend_native_kernel_registry.h" -#include "extendrt/utils/func_graph_utils.h" -#include "extendrt/delegate/ascend_native/ops/ascend_native_composite.h" -#include "ops/primitive_c.h" -#include "src/train/opt_allocator.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" - -#define DPRN() std::cout -namespace mindspore::kernel { -using mindspore::ops::AscendNativeComposite; - -int AscendNativeCompositeKernel::InferShape() { - for (auto &kernel : kernels_) { - auto ret = kernel->InferShape(); - if (ret != lite::RET_OK) { - MS_LOG(ERROR) << "kernel InferShape failed for " << kernel->get_name(); - return lite::RET_ERROR; - } - } - return lite::RET_OK; -} - -static inline BaseOperatorPtr CreateOperatorByCNode(const CNodePtr &cnode) { - auto prim = GetValueNode(cnode->input(0)); - MS_EXCEPTION_IF_NULL(prim); - auto kernel_name = prim->name(); - // Create PrimtiveC from map and create BaseOperator. - mindspore::ops::PrimitiveCPtr primc_ptr = nullptr; - static auto primc_fns = mindspore::ops::OpPrimCRegister::GetInstance().GetPrimCMap(); - if (primc_fns.find(kernel_name) != primc_fns.end()) { - primc_ptr = primc_fns[kernel_name](); - (void)primc_ptr->SetAttrs(prim->attrs()); - } - MS_EXCEPTION_IF_NULL(primc_ptr); - - static auto operator_fns = mindspore::ops::OperatorRegister::GetInstance().GetOperatorMap(); - if (operator_fns.find(kernel_name) == operator_fns.end()) { - MS_LOG(EXCEPTION) << "Cannot create BaseOperator for " << kernel_name; - } - auto base_operator = operator_fns[kernel_name](primc_ptr); - MS_EXCEPTION_IF_NULL(base_operator); - return base_operator; -} - -std::shared_ptr AscendNativeCompositeKernel::CreateKernel(const AnfNodePtr &node) { - if (!node->isa()) { - MS_LOG(ERROR) << "AscendNativeCompositeKernel::CreateKernel not a cnode"; - return nullptr; - } - auto cnode = node->cast(); - if (cnode == nullptr) { - MS_LOG(ERROR) << "AscendNativeCompositeKernel::CreateKernel cnode is nullptr"; - return nullptr; - } - auto prim = GetValueNode(cnode->input(0)); - MS_EXCEPTION_IF_NULL(prim); - // step II - Prepare kernel attributes - std::vector input_tensors; - CreateInputKernelTensors(cnode, &input_tensors); - std::vector output_tensors; - CreateOutputKernelTensors(cnode, &output_tensors); - kernel::InferPrimitive primitive; - primitive.base_operator = CreateOperatorByCNode(cnode); - primitive.cnode = cnode; - auto kernel_name = cnode->fullname_with_scope(); - auto node_type = primitive.base_operator->name(); - // step III - Create Ascend native Kernel - auto &plugin_factory = kernel::AscendNativeRegistrationFactory::Get(); - // TODO(nizzan) :: remove stub patch - if (!plugin_factory.HasKey(node_type)) node_type = "AscendNativeStub"; - if (plugin_factory.HasKey(node_type)) { - kernel::AscendNativeBaseKernel *ascend_native_op = - plugin_factory.GetCreator(node_type)(input_tensors, output_tensors, primitive, context_, stream_, node_type); - if (ascend_native_op == nullptr) { - return nullptr; - } - auto ker = std::shared_ptr(ascend_native_op); - if (ker == nullptr) { - MS_LOG(ERROR) << "Kernel is nullptr"; - return nullptr; - } - - if (!ker->IsWeightInputHanledInner()) { - auto in_tensors = ker->in_tensors(); - for (auto &t : in_tensors) { - MS_EXCEPTION_IF_NULL(t); - if (t->IsConst() && (t->data() == nullptr)) { - MS_LOG(ERROR) << "no data to tensor " << t->tensor_name(); - return nullptr; - } - if (t->IsConst() && t->device_data() == nullptr) { - bool t_is_float = (t->data_type() == kNumberTypeFloat || t->data_type() == kNumberTypeFloat32); - if (t_is_float) { - void *device_ptr = nullptr; - ascend_native::CopyHostFp32ToDeviceFp16(t->data(), &device_ptr, t->ElementsNum(), - const_cast(stream_)); - t->set_device_data(device_ptr); - } else { - t->set_device_data(ascend_native::MallocCopy(t->data(), t->Size(), const_cast(stream_))); - } - } - } - } - // TODO(nizzan) :: remove if - if (node_type == "AscendNativeStub") { - ker->set_name(primitive.base_operator->name()); - } - return ker; - } else { - MS_LOG(WARNING) << "Unsupported op type for ascend native. kernel name:" << kernel_name << " type:" << node_type; - return nullptr; - } -} - -int AscendNativeCompositeKernel::GetIdxFromString(std::string str) { - auto idx_str = str.rfind("_"); - std::string sub = str.substr(idx_str + 1); - return std::stoi(sub); -} - -static inline kernel::InferTensor *anfTensorToTensorInfo(const common::KernelWithIndex &tensor_id) { - auto prev_node = tensor_id.first; - auto tensor_val = FuncGraphUtils::GetConstNodeValue(prev_node); - - constexpr auto tensorrt_format = mindspore::Format::NCHW; - auto name = FuncGraphUtils::GetTensorName(tensor_id); - auto shape = FuncGraphUtils::GetTensorShape(tensor_id); - auto data_type = FuncGraphUtils::GetTensorDataType(tensor_id); - auto format = tensorrt_format; - const void *data = nullptr; - size_t data_len = 0; - if (tensor_val) { - data = tensor_val->data_c(); - data_len = tensor_val->Size(); - shape = tensor_val->shape_c(); - } - std::vector t_shape; - t_shape.resize(shape.size()); - std::transform(shape.begin(), shape.end(), t_shape.begin(), [](int64_t x) { return static_cast(x); }); - - auto t = kernel::InferTensor::CreateTensor(name, static_cast(data_type), t_shape, data, data_len); - if (t == nullptr) { - MS_LOG(EXCEPTION) << "Cannot CreateTensor for " << name; - } - t->set_format(format); - return t; -} - -void AscendNativeCompositeKernel::CreateInputKernelTensors(const CNodePtr &cnode, - std::vector *input_tensors) { - input_tensors->clear(); - auto graph_inputs = func_graph_->get_inputs(); - auto input_nodes = FuncGraphUtils::GetNodeInputs(cnode); - auto cnode_inputs = this->primitive_.cnode->inputs(); - for (auto &tensor_id : input_nodes) { - bool found_tensor = false; - for (size_t j = 0; j < graph_inputs.size(); j++) { - if (tensor_id.first == graph_inputs[j]) { - int idx = GetIdxFromString(tensor_id.first->fullname_with_scope()); - input_tensors->push_back(in_tensors_[idx]); - allocated_tensors_.insert(in_tensors_[idx]); - auto it = std::find_if(kernel_list_.begin(), kernel_list_.end(), - [&tensor_id](const KernelWithIndexAndTensor &k) { return k.kernel_index == tensor_id; }); - if (it == kernel_list_.end()) { - kernel_list_.push_back(KernelWithIndexAndTensor(tensor_id, in_tensors_[idx])); - } - found_tensor = true; - break; - } - } - if (!found_tensor) { - for (size_t j = 1; j < cnode_inputs.size(); j++) { - if (tensor_id.first == cnode_inputs[j]) { - input_tensors->push_back(in_tensors_[j - 1]); - allocated_tensors_.insert(in_tensors_[j - 1]); - auto it = - std::find_if(kernel_list_.begin(), kernel_list_.end(), - [&tensor_id](const KernelWithIndexAndTensor &k) { return k.kernel_index == tensor_id; }); - if (it == kernel_list_.end()) { - kernel_list_.push_back(KernelWithIndexAndTensor(tensor_id, in_tensors_[j - 1])); - } - found_tensor = true; - break; - } - } - if (!found_tensor) { - auto it = std::find_if(kernel_list_.begin(), kernel_list_.end(), - [&tensor_id](const KernelWithIndexAndTensor &k) { return k.kernel_index == tensor_id; }); - // tensor already created - use the same tensor - if (it != kernel_list_.end()) { - input_tensors->push_back(it->tensor_info); - } else { - auto tensor_info = anfTensorToTensorInfo(tensor_id); - if (tensor_info == nullptr) { - MS_LOG(ERROR) << "failed to get tensor info"; - return; - } - input_tensors->push_back(tensor_info); - kernel_list_.push_back(KernelWithIndexAndTensor(tensor_id, tensor_info)); - } - } - } - } -} - -void AscendNativeCompositeKernel::CreateOutputKernelTensors(const CNodePtr &cnode, - std::vector *output_tensors) { - output_tensors->clear(); - auto output_num = mindspore::AnfUtils::GetOutputTensorNum(cnode); - bool output_found = false; - for (size_t output_idx = 0; output_idx < output_num; ++output_idx) { - mindspore::common::KernelWithIndex tensor_id = {cnode, output_idx}; - auto it = std::find_if(kernel_list_.begin(), kernel_list_.end(), - [&tensor_id](const KernelWithIndexAndTensor &k) { return k.kernel_index == tensor_id; }); - if (it != kernel_list_.end()) { - output_tensors->push_back(it->tensor_info); - output_found = true; - } else { - auto graph_output = func_graph_->output(); - if (IsPrimitiveCNode(graph_output, prim::kPrimMakeTuple)) { - auto outc = graph_output->cast(); - for (size_t i = 1; i < outc->size(); i++) { - if (IsPrimitiveCNode(outc->input(i), prim::kPrimTupleGetItem)) { - auto get_item = outc->input(i)->cast(); - auto tuple_idx = common::AnfAlgo::GetTupleGetItemOutIndex(get_item); - if ((get_item->input(SECOND_INPUT) == cnode) && (tuple_idx == output_idx)) { - out_tensors_[i - 1]->set_device_data( - ascend_native::MallocDevice(out_tensors_[i - 1]->Size(), const_cast(stream_))); - out_tensors_[i - 1]->ResetRefCount(); - allocated_tensors_.insert(out_tensors_[i - 1]); - output_tensors->push_back(out_tensors_[i - 1]); - output_found = true; - } - } else if (outc->input(i) == cnode) { - out_tensors_[i - 1]->set_device_data( - ascend_native::MallocDevice(out_tensors_[i - 1]->Size(), const_cast(stream_))); - out_tensors_[i - 1]->ResetRefCount(); - allocated_tensors_.insert(out_tensors_[i - 1]); - output_tensors->push_back(out_tensors_[i - 1]); - output_found = true; - } - } - } else { - if (graph_output == cnode) { - out_tensors_[0]->set_device_data( - ascend_native::MallocDevice(out_tensors_[0]->Size(), const_cast(stream_))); - out_tensors_[0]->ResetRefCount(); - allocated_tensors_.insert(out_tensors_[0]); - output_tensors->push_back(out_tensors_[0]); - output_found = true; - } - } - } - if (!output_found) { - auto tensor_info = anfTensorToTensorInfo(tensor_id); - output_tensors->push_back(tensor_info); - kernel_list_.push_back(KernelWithIndexAndTensor(tensor_id, tensor_info)); - } - } -} - -int AscendNativeCompositeKernel::AllocTensors() { - OptAllocator allocator; - std::unordered_map ref_count; - offset_map_.clear(); - for (auto &kernel : kernels_) { - // malloc output tensors - for (auto &tensor : kernel->out_tensors()) { - if ((allocated_tensors_.find(tensor) == allocated_tensors_.end())) { - if (offset_map_.find(tensor) == offset_map_.end()) { - size_t tensor_size = tensor->Size(); - size_t offset = allocator.Malloc(tensor_size); - offset_map_[tensor] = offset; - ref_count[tensor] = tensor->init_ref_count(); - } - } - } - // free according to reference counter - for (auto &tensor : kernel->in_tensors()) { - if ((tensor->category() == lite::Category::VAR) && - ((allocated_tensors_.find(tensor) == allocated_tensors_.end()))) { - int count = ref_count[tensor] - 1; - ref_count[tensor] = count; - if (count == 0) { - allocator.Free(offset_map_[tensor]); - } - } - } - } - // Set Tensor data - device_mem_size_ = allocator.total_size(); - return ReAllocTensors(); -} - -int AscendNativeCompositeKernel::ReAllocTensors() { - if (device_memory_base_addr_ != nullptr) { - return lite::RET_OK; - } - if (device_mem_size_ > 0) { - device_memory_base_addr_ = ascend_native::MallocDevice(device_mem_size_, const_cast(stream_)); - if (device_memory_base_addr_ == nullptr) { - MS_LOG(EXCEPTION) << "Allocation of " << device_mem_size_ << "B on device failed"; - return kMDOutOfMemory; - } - for (auto &it : offset_map_) { - auto &tensor = it.first; - tensor->set_device_data( - reinterpret_cast(reinterpret_cast(device_memory_base_addr_) + it.second)); - } - } - return lite::RET_OK; -} - -void AscendNativeCompositeKernel::FreeDevice() { - ascend_native::FreeDevice(device_memory_base_addr_, const_cast(stream_)); - device_memory_base_addr_ = nullptr; - for (auto &it : offset_map_) { - auto &tensor = it.first; - tensor->set_device_data(nullptr); - } -} - -void AscendNativeCompositeKernel::InitializeTensorRefrenceCnt() { - for (auto &kernel : kernels_) { - for (auto tensor : kernel->in_tensors()) { - if (tensor->category() == lite::VAR || tensor->category() == lite::GRAPH_INPUT) { - auto ref_count = tensor->init_ref_count(); - tensor->set_init_ref_count(ref_count + 1); - } - } - } -} - -int AscendNativeCompositeKernel::AllocateGraphTensors() { - if (device_memory_base_addr_ == nullptr) { - InitializeTensorRefrenceCnt(); - } else { - FreeDevice(); - } - return AllocTensors(); -} - -int AscendNativeCompositeKernel::AllocateGraphWorkspace(size_t ws_size) { - if (get_workspace() != nullptr) return lite::RET_OK; - void *ws_ptr = nullptr; - if (ws_size > 0) { - if (ws_size > max_ws_size_) { - MS_LOG(ERROR) << "kernel ws is too big " << ws_size; - return kLiteError; - } - // alloc ws on device space - ws_ptr = ascend_native::MallocDevice(ws_size, const_cast(stream_)); - if (ws_ptr == nullptr) { - MS_LOG(EXCEPTION) << "Allocation of " << ws_size << "B on device failed"; - return kMDOutOfMemory; - } - set_workspace(ws_ptr); - set_workspace_size(ws_size); - for (auto &kernel : kernels_) { - kernel->set_workspace(ws_ptr); - } - } - return lite::RET_OK; -} - -int AscendNativeCompositeKernel::Prepare() { - auto nodes = TopoSort(func_graph_->get_return()); - for (auto &node : nodes) { - if (!node->isa() || !AnfUtils::IsRealKernel(node)) { - continue; - } - auto kernel = CreateKernel(node); - if (kernel == nullptr) { - MS_LOG(ERROR) << "composite create kernel failed."; - return lite::RET_ERROR; - } - kernels_.emplace_back(kernel); - } - if (kernels_.empty()) { - MS_LOG(ERROR) << "composite does not support empty subgraph now."; - return lite::RET_ERROR; - } - // call kernel prepare - size_t ws_size = 0; - for (auto &kernel : kernels_) { - auto ret = kernel->Prepare(); - if (ret != lite::RET_OK) { - MS_LOG(ERROR) << "composite kernel prepare failed with " << ret; - return lite::RET_ERROR; - } - size_t k_ws_size = kernel->get_workspace_size(); - if (k_ws_size > ws_size) ws_size = k_ws_size; - } - if (AllocateGraphWorkspace(ws_size) != lite::RET_OK) { - MS_LOG(ERROR) << "kernel workspace allocation failed "; - return lite::RET_ERROR; - } - if (AllocateGraphTensors() != lite::RET_OK) { - MS_LOG(ERROR) << "kernel graph allocation failed "; - return lite::RET_ERROR; - } - return lite::RET_OK; -} - -int AscendNativeCompositeKernel::Run() { - MS_LOG(INFO) << "AscendNativeCompositeKernel::Execute Begin"; - // call kernel run interface one by one - for (auto &kernel : kernels_) { - auto ret = kernel->PreProcess(); - if (ret != lite::RET_OK) { - MS_LOG(ERROR) << "kernel preprocess failed with " << ret << " for " << kernel->get_name(); - return lite::RET_ERROR; - } - ret = kernel->Run(); - if (ret != lite::RET_OK) { - MS_LOG(ERROR) << "kernel run failed with " << ret << " for " << kernel->get_name(); - return lite::RET_ERROR; - } - // synchronize all tasks are finished - ascend_native::SyncDevice(const_cast(stream_)); - ret = kernel->PostProcess(); - if (ret != lite::RET_OK) { - MS_LOG(ERROR) << "kernel postprocess failed with " << ret << " for " << kernel->get_name(); - return lite::RET_ERROR; - } - } - MS_LOG(INFO) << "AscendNativeCompositeKernel::Execute End"; - return lite::RET_OK; -} - -int AscendNativeCompositeKernel::PostProcess() { - // Free device data - FreeDevice(); - ascend_native::FreeDevice(get_workspace(), const_cast(stream_)); - set_workspace(nullptr); - // Decrement inputs ref count - for (size_t i = 0; i < in_tensors_.size(); i++) { - auto ref = in_tensors_[i]->ref_count(); - in_tensors_[i]->set_ref_count(--ref); - if ((ref <= 0) && (in_tensors_[i]->category() == lite::VAR)) { - ascend_native::FreeDevice(in_tensors_[i]->device_data(), const_cast(stream_)); - in_tensors_[i]->set_device_data(nullptr); - } - } - return lite::RET_OK; -} - -int AscendNativeCompositeKernel::PreProcess() { - for (auto &tensor : out_tensors()) { - if (tensor->device_data() == nullptr) { - auto data_ptr = ascend_native::MallocDevice(tensor->Size(), const_cast(stream_)); - if (data_ptr == nullptr) { - MS_LOG(ERROR) << "Cannot allocate device memory size:" << tensor->Size(); - return lite::RET_NULL_PTR; - } - tensor->set_device_data(data_ptr); - } - tensor->ResetRefCount(); - } - auto ws_size = get_workspace_size(); - ReAllocTensors(); - if (AllocateGraphWorkspace(ws_size) != lite::RET_OK) { - MS_LOG(ERROR) << "kernel workspace allocation failed "; - return kLiteError; - } - if (InferShape() != lite::RET_OK) { - MS_LOG(ERROR) << "InferShape AscendNativeCompositeKernel failed "; - return kLiteError; - } - return lite::RET_OK; -} - -int AscendNativeCompositeKernel::ReSize() { - size_t ws_size = 0; - for (auto &kernel : kernels_) { - size_t k_ws_size = kernel->get_workspace_size(); - if (k_ws_size > ws_size) ws_size = k_ws_size; - auto ret = kernel->ReSize(); - if (ret != lite::RET_OK) { - MS_LOG(ERROR) << "kernel" << kernel->get_name() << " ReSize failed "; - return lite::RET_ERROR; - } - } - auto ret = AllocateGraphWorkspace(ws_size); - if (ret != lite::RET_OK) { - MS_LOG(ERROR) << "kernel workspace allocation failed "; - return lite::RET_ERROR; - } - ret = AllocateGraphTensors(); - if (ret != lite::RET_OK) { - MS_LOG(ERROR) << "kernel graph allocation failed "; - return lite::RET_ERROR; - } - return lite::RET_OK; -} -REGISTER_ASCEND_NATIVE_CREATOR(ops::kNameAscendNativeComposite, AscendNativeCompositeKernel) -} // namespace mindspore::kernel diff --git a/mindspore-lite/src/extendrt/kernel/ascend_native/ascend_native_composite_kernel.h b/mindspore-lite/src/extendrt/kernel/ascend_native/ascend_native_composite_kernel.h deleted file mode 100644 index 1ab64c95c99a649d6ff51a054b790fa6d6552fc3..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/kernel/ascend_native/ascend_native_composite_kernel.h +++ /dev/null @@ -1,73 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_NATIVE_COMPOSITE_KERNEL_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_NATIVE_COMPOSITE_KERNEL_H_ - -#include "extendrt/delegate/ascend_native/ascend_native_base_kernel.h" -#include -#include -#include -#include -#include -#include "infer/context.h" - -namespace mindspore::kernel { -class AscendNativeCompositeKernel : public AscendNativeBaseKernel { - public: - // AscendNativeCompositeKernel = delete; - - AscendNativeCompositeKernel(const std::vector &inputs, const std::vector &outputs, - InferPrimitive prim, const InferContext *ctx, const void *stream, std::string name) - : AscendNativeBaseKernel(inputs, outputs, prim, ctx, stream, name) {} - - int Prepare() override; - int Run() override; - int PostProcess() override; - int PreProcess() override; - int InferShape() override; - int ReSize() override; - - bool IsWeightInputHanledInner() const override { return true; } - - void set_func_graph(const FuncGraphPtr &func_graph) { func_graph_ = func_graph; } - - private: - std::shared_ptr CreateKernel(const AnfNodePtr &node); - void CreateInputKernelTensors(const CNodePtr &cnode, std::vector *input_tensors); - void CreateOutputKernelTensors(const CNodePtr &cnode, std::vector *output_tensors); - int GetIdxFromString(std::string str); - int AllocateGraphTensors(); - int AllocTensors(); - void InitializeTensorRefrenceCnt(); - void FreeDevice(); - int ReAllocTensors(); - int AllocateGraphWorkspace(size_t size); - size_t get_workspace_size() const override { return ws_size_; } - void set_workspace_size(size_t size) { ws_size_ = size; } - FuncGraphPtr func_graph_; - std::vector kernel_list_; - std::vector> kernels_; - void *device_memory_base_addr_ = nullptr; - size_t device_mem_size_ = 0; - std::set allocated_tensors_; - std::unordered_map offset_map_; - size_t ws_size_{0}; - - static constexpr size_t max_ws_size_ = static_cast(2100) * (1 << 20); -}; -} // namespace mindspore::kernel -#endif // MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_NATIVE_COMPOSITE_KERNEL_H_ diff --git a/mindspore-lite/src/extendrt/kernel/ascend_native/ascend_native_copy_kernel.cc b/mindspore-lite/src/extendrt/kernel/ascend_native/ascend_native_copy_kernel.cc deleted file mode 100644 index c77840e7f750d7b129b58b93544be60e93894b3b..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/kernel/ascend_native/ascend_native_copy_kernel.cc +++ /dev/null @@ -1,234 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "extendrt/kernel/ascend_native/ascend_native_copy_kernel.h" -#include -#include "extendrt/delegate/ascend_native/ascend_native_kernel_registry.h" -#include "extendrt/delegate/ascend_native/ascend_native_impl/utils.h" -#include "extendrt/delegate/ops/copy.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" - -namespace mindspore::kernel { -int AscendNativeCopyKernel::InferShape() { - out_tensors_[0]->set_shape(in_tensors_[0]->shape()); - auto const &data_type = in_tensors_[0]->data_type(); - bool is_float = - (data_type == kNumberTypeFloat32) || (data_type == kNumberTypeFloat16) || (data_type == kNumberTypeFloat); - if (copy_type_ == ops::Copy::CopyFormatType::HOST_DEVICE) { - if (is_float) { - out_tensors_[0]->set_data_type(kNumberTypeFloat16); - } else { - out_tensors_[0]->set_data_type(data_type); - } - } else if (copy_type_ == ops::Copy::CopyFormatType::DEVICE_HOST) { - if (is_float) { - out_tensors_[0]->set_data_type(kNumberTypeFloat32); - } else { - out_tensors_[0]->set_data_type(data_type); - } - } - return lite::RET_OK; -} - -int AscendNativeCopyKernel::Prepare() { - auto prim = GetValueNode(primitive_.cnode->input(0)); - copy_type_ = static_cast(GetValue(prim->GetAttr(mindspore::ops::kCopyFormat))); - auto ret = InferShape(); - if (ret != lite::RET_OK) { - MS_LOG(ERROR) << "Ascend native copy kernel inferShape failed."; - return lite::RET_ERROR; - } - return lite::RET_OK; -} - -int AscendNativeCopyKernel::PreProcess() { - out_tensors_[0]->ResetRefCount(); - switch (copy_type_) { - case ops::Copy::CopyFormatType::HOST_DEVICE: { - if (out_tensors_[0]->device_data() == nullptr) { - auto device_data = ascend_native::MallocDevice(out_tensors_[0]->Size(), const_cast(stream_)); - if (device_data == nullptr) { - MS_LOG(ERROR) << "fail to allocate " << out_tensors_[0]->Size() << "Bytes for device"; - return lite::RET_NULL_PTR; - } - out_tensors_[0]->set_device_data(device_data); - } - break; - } - case ops::Copy::CopyFormatType::DEVICE_HOST: { - if (out_tensors_[0]->data() == nullptr) { - out_tensors_[0]->MallocData(); - if (out_tensors_[0]->data() == nullptr) { - MS_LOG(ERROR) << "fail to allocate " << out_tensors_[0]->Size() << "Bytes for host"; - return lite::RET_ERROR; - } - } - break; - } - case ops::Copy::CopyFormatType::NONE: { - MS_LOG(WARNING) << "Ascend native copy kernel type is none. Kernel is redundant."; - break; - } - default: { - MS_LOG(ERROR) << "Ascend native copy kernel execute - copy type not supported."; - return lite::RET_ERROR; - } - } - return lite::RET_OK; -} - -void PrintTensor(InferTensor *tensor, int max_len, void *stream) { - int elem = std::min(static_cast(tensor->ElementsNum()), max_len); - std::cout << "device tensor " << tensor->tensor_name() << " data:"; - switch (tensor->data_type()) { - case kNumberTypeFloat16: - ascend_native::PrintFp16(tensor->device_data(), elem, stream); - break; - case kNumberTypeFloat32: - ascend_native::PrintFp32(tensor->device_data(), elem, stream); - break; - case kNumberTypeInt32: - ascend_native::PrintInt32(tensor->device_data(), elem, stream); - break; - default: - std::cout << "not supported type " << tensor->data_type() << std::endl; - } - std::cout << std::endl; -} - -void PrintTensor(InferTensor *tensor, int max_len) { - int elem = std::min(static_cast(tensor->ElementsNum()), max_len); - std::cout << "host tensor " << tensor->tensor_name() << " data:"; - switch (tensor->data_type()) { - case kNumberTypeFloat32: { - auto ptr = static_cast(tensor->data()); - for (int i = 0; i < elem; i++) { - std::cout << *(ptr + i) << " "; - } - break; - } - case kNumberTypeInt32: { - auto ptr = static_cast(tensor->data()); - for (int i = 0; i < elem; i++) { - std::cout << *(ptr + i) << " "; - } - break; - } - default: - std::cout << "not supported type " << tensor->data_type() << std::endl; - } - std::cout << std::endl; -} - -int AscendNativeCopyKernel::Run() { - MS_LOG(INFO) << "AscendNativeCopyKernel::Execute"; - auto elem = out_tensors_[0]->ElementsNum(); - // Execute copy - switch (copy_type_) { - case ops::Copy::CopyFormatType::HOST_DEVICE: { - if (in_tensors_[0]->data() == nullptr) { - MS_LOG(ERROR) << "no host data to tensor " << in_tensors_[0]->tensor_name(); - return lite::RET_ERROR; - } - void *dst = out_tensors_[0]->device_data(); - if (dst == nullptr) { - MS_LOG(ERROR) << "no output tensor allocated"; - return lite::RET_ERROR; - } - bool t_is_float = - (in_tensors_[0]->data_type() == kNumberTypeFloat || in_tensors_[0]->data_type() == kNumberTypeFloat32); - if (t_is_float) { - ascend_native::CopyHostFp32ToDeviceFp16(in_tensors_[0]->data(), &dst, elem, const_cast(stream_)); - } else { - int elem_size = mindspore::lite::DataTypeSize(in_tensors_[0]->data_type()); - switch (elem_size) { - case Num4: - ascend_native::CopyHostFp32ToDeviceFp32(in_tensors_[0]->data(), &dst, elem, const_cast(stream_)); - break; - case Num2: - ascend_native::CopyHostFp16ToDeviceFp16(in_tensors_[0]->data(), &dst, elem, const_cast(stream_)); - break; - case Num1: - ascend_native::CopyHostFp16ToDeviceFp16(in_tensors_[0]->data(), &dst, elem / 2, - const_cast(stream_)); - break; - default: - MS_LOG(ERROR) << "no supported size " << elem_size; - return lite::RET_ERROR; - } - } - break; - } - case ops::Copy::CopyFormatType::DEVICE_HOST: { - if (in_tensors_[0]->device_data() == nullptr) { - MS_LOG(ERROR) << "no device data to tensor " << in_tensors_[0]->tensor_name(); - return lite::RET_ERROR; - } - out_tensors_[0]->set_data_type(kNumberTypeFloat32); - if (elem * sizeof(float) > out_tensors_[0]->Size()) { - MS_LOG(ERROR) << "wrong output size"; - return lite::RET_ERROR; - } - ascend_native::CopyDeviceFp16ToHostFp32(in_tensors_[0]->device_data(), out_tensors_[0]->data(), elem, - const_cast(stream_)); - break; - } - case ops::Copy::CopyFormatType::NONE: { - MS_LOG(WARNING) << "Ascend native copy kernel type is none. Kernel is redundant."; - break; - } - default: { - MS_LOG(ERROR) << "Ascend native copy kernel execute - copy type not supported. " << copy_type_; - return lite::RET_ERROR; - } - } - return lite::RET_OK; -} - -int AscendNativeCopyKernel::PostProcess() { - switch (copy_type_) { - case ops::Copy::CopyFormatType::HOST_DEVICE: { - in_tensors_[0]->DecRefCount(); - break; - } - case ops::Copy::CopyFormatType::DEVICE_HOST: { - auto ref = in_tensors_[0]->ref_count() - 1; - in_tensors_[0]->set_ref_count(ref); - if (ref < 0) { - MS_LOG(ERROR) << "less than zero reference count"; - return lite::RET_ERROR; - } - if (ref == 0) { - ascend_native::FreeDevice(in_tensors_[0]->device_data(), const_cast(stream_)); - in_tensors_[0]->set_device_data(nullptr); - } - break; - } - case ops::Copy::CopyFormatType::NONE: { - MS_LOG(WARNING) << "Ascend native copy kernel type is none. Kernel is redundant."; - break; - } - default: { - MS_LOG(ERROR) << "Ascend native copy kernel execute - copy type not supported."; - return lite::RET_ERROR; - } - } - return lite::RET_OK; -} - -int AscendNativeCopyKernel::ReSize() { return lite::RET_OK; } -REGISTER_ASCEND_NATIVE_CREATOR(ops::kNameCopy, AscendNativeCopyKernel) -} // namespace mindspore::kernel diff --git a/mindspore-lite/src/extendrt/kernel/ascend_native/ascend_native_copy_kernel.h b/mindspore-lite/src/extendrt/kernel/ascend_native/ascend_native_copy_kernel.h deleted file mode 100644 index 9487d204c783b97e0891c848acd0831ae3bab2b7..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/kernel/ascend_native/ascend_native_copy_kernel.h +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_NATIVE_COPY_KERNEL_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_NATIVE_COPY_KERNEL_H_ - -#include -#include -#include -#include "extendrt/delegate/ascend_native/ascend_native_base_kernel.h" -#include "extendrt/delegate/ascend_native/ascend_native_impl/utils.h" -#include "extendrt/delegate/ops/copy.h" - -namespace mindspore::kernel { -class AscendNativeCopyKernel : public AscendNativeBaseKernel { - public: - AscendNativeCopyKernel(const std::vector &inputs, const std::vector &outputs, - InferPrimitive prim, const InferContext *ctx, const void *stream, std::string name) - : AscendNativeBaseKernel(inputs, outputs, prim, ctx, stream, name) {} - - int InferShape() override; - - int Prepare() override; - - int Run() override; - - int PreProcess() override; - - int PostProcess() override; - - int ReSize() override; - - private: - ops::Copy::CopyFormatType copy_type_; -}; -} // namespace mindspore::kernel -#endif // MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_NATIVE_COPY_KERNEL_H_ diff --git a/mindspore-lite/src/extendrt/kernel/base_kernel.h b/mindspore-lite/src/extendrt/kernel/base_kernel.h deleted file mode 100644 index 7de9f3904636a9ebdc858863edb1cb8a3cb102bc..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/kernel/base_kernel.h +++ /dev/null @@ -1,160 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_EXTENDRT_KERNEL_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_EXTENDRT_KERNEL_H_ - -#include -#include -#include -#include "include/api/kernel_api.h" -#include "ops/base_operator.h" -#include "ir/anf.h" -#include "include/api/status.h" -#include "src/infer/tensor.h" -#include "src/infer/context.h" -#include "src/infer/primitive_type.h" -#include "src/extendrt/graph_compiler/infershape_helper.h" - -namespace mindspore::kernel { -using BaseOperatorPtr = std::shared_ptr; -using InferContext = mindspore::infer::abstract::Context; -using InferTensor = mindspore::infer::abstract::Tensor; -struct InferPrimitive { - BaseOperatorPtr base_operator{nullptr}; - CNodePtr cnode{nullptr}; -}; - -class BaseKernel : public IKernel { - public: - BaseKernel(InferPrimitive primitive, const InferContext *ctx) : primitive_(std::move(primitive)), context_(ctx) { - auto base_operator = primitive_.base_operator; - if (base_operator != nullptr) { - type_ = PrimitiveType(primitive_.base_operator->name()); - } - } - - BaseKernel(InferPrimitive primitive, std::vector in_tensors, std::vector out_tensors, - const InferContext *ctx) - : primitive_(std::move(primitive)), - in_tensors_(std::move(in_tensors)), - out_tensors_(std::move(out_tensors)), - context_(ctx) { - auto base_operator = primitive_.base_operator; - if (base_operator != nullptr) { - type_ = PrimitiveType(primitive_.base_operator->name()); - } - } - - int Prepare() override { return kLiteError; } - - int InferShape() override { return lite::NodeFallBackInferShape(primitive_.cnode, NCHW); } - - int Execute() override { - auto ret = PreProcess(); - if (lite::RET_OK != ret) { - MS_LOG(ERROR) << "run kernel PreProcess failed, name: " << this->name(); - return ret; - } - - ret = Run(); - if (lite::RET_OK != ret) { - MS_LOG(ERROR) << "run kernel failed, name: " << this->name(); - return ret; - } - - ret = PostProcess(); - if (lite::RET_OK != ret) { - MS_LOG(ERROR) << "run kernel PostProcess failed, name: " << this->name(); - return ret; - } - return lite::RET_OK; - } - - virtual int Run() { return lite::RET_OK; } - - virtual bool InferShapeDone() const { - auto checker = context_ != nullptr ? context_->get_infer_checker() : lite::InferCheckerOutput; - return checker != nullptr && checker(in_tensors_, out_tensors_); - } - - virtual int PreProcess() { - if (!InferShapeDone()) { - auto ret = InferShape(); - if (ret != 0) { - MS_LOG(ERROR) << "InferShape fail!"; - return ret; - } - ret = ReSize(); - if (ret != 0) { - MS_LOG(ERROR) << "ReSize fail!ret: " << ret; - return ret; - } - } - - for (auto *output : this->out_tensors()) { - MS_ASSERT(output != nullptr); - auto ret = output->MallocData(); - if (ret != lite::RET_OK) { - MS_LOG(ERROR) << "MallocData failed"; - return ret; - } - output->ResetRefCount(); - } - return lite::RET_OK; - } - // called after Run - virtual int PostProcess() { - for (auto &in_tensor : this->in_tensors()) { - MS_ASSERT(in_tensor != nullptr); - if (in_tensor->category() == lite::VAR) { - in_tensor->DecRefCount(); - } - } - return lite::RET_OK; - } - - int ReSize() override { return lite::RET_ERROR; } - - const std::vector &inputs() override; - - const std::vector &outputs() override; - - virtual void set_in_tensors(const std::vector &in_tensors) { this->in_tensors_ = in_tensors; } - - virtual void set_in_tensor(InferTensor *in_tensor, size_t index); - - virtual void set_out_tensors(const std::vector &out_tensors) { this->out_tensors_ = out_tensors; } - - virtual void set_out_tensor(InferTensor *out_tensor, size_t index); - - virtual const std::vector &in_tensors() const { return in_tensors_; } - - virtual const std::vector &out_tensors() const { return out_tensors_; } - - PrimitiveType type() { return type_; } - - virtual OpParameter *op_parameter() const { return nullptr; } - - protected: - InferPrimitive primitive_; - PrimitiveType type_; - std::vector in_tensors_; - std::vector out_tensors_; - const InferContext *context_ = nullptr; -}; -} // namespace mindspore::kernel -#endif // MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_EXTENDRT_KERNEL_H_ diff --git a/mindspore-lite/src/extendrt/kernel/cpu/less_test_kernel_mod.cc b/mindspore-lite/src/extendrt/kernel/cpu/less_test_kernel_mod.cc deleted file mode 100644 index f6343ce360e0b458d8b3d4bdf2ccf6302718003b..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/kernel/cpu/less_test_kernel_mod.cc +++ /dev/null @@ -1,63 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include "extendrt/kernel/cpu/less_test_kernel_mod.h" -#include "mindspore/ops/op_def/comparison_ops.h" -#include "common/ms_factory.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h" - -namespace mindspore::kernel { -const size_t test_input_size = 2; -const int test_input_shape = 7; - -bool LessTestKernelMod::Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) { - MS_LOG(INFO) << "LessTestKernelMod::Launch"; - // test shape 7 value - if (inputs.size() < test_input_size || outputs.size() < 1) { - MS_LOG(ERROR) << "input or output size is wrong!"; - return false; - } - auto x = static_cast(inputs[0]->device_ptr()); - auto y = static_cast(inputs[1]->device_ptr()); - auto z = static_cast(outputs[0]->device_ptr()); - - for (int i = 0; i < test_input_shape; i++) { - if (x[i] < y[i]) { - z[i] = true; - } else { - z[i] = false; - } - } - - for (int i = 0; i < test_input_shape; i++) { - MS_LOG(INFO) << "LessTestKernelMod::Launch z " << z[i]; - } - - return true; -} - -bool LessTestKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { - MS_LOG(INFO) << "LessTestKernelMod::Init"; - return true; -} - -MS_KERNEL_FACTORY_REG_BY_CREATOR(CpuKernelMod, Less, - []() { return std::make_shared(prim::kPrimLess->name()); }); -} // namespace mindspore::kernel diff --git a/mindspore-lite/src/extendrt/kernel/cpu/transpose_kernel_mod.cc b/mindspore-lite/src/extendrt/kernel/cpu/transpose_kernel_mod.cc deleted file mode 100644 index dfef81ad04573380ec99b2a84fe17d6d26a49b33..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/kernel/cpu/transpose_kernel_mod.cc +++ /dev/null @@ -1,440 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "extendrt/kernel/cpu/transpose_kernel_mod.h" -#include -#include -#include "mindspore/ops/op_def/array_ops.h" -#include "common/ms_factory.h" -#include "include/api/status.h" -#include "nnacl/errorcode.h" -#include "src/common/log_util.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" - -namespace mindspore::kernel { -namespace { -constexpr size_t kTransposeInputsNum = 2; -constexpr size_t kTransposeOutputsNum = 1; -constexpr size_t kIndex0 = 0; -constexpr size_t kIndex1 = 1; -constexpr size_t kIndex2 = 2; -constexpr size_t kIndex3 = 3; -constexpr size_t kIndex4 = 4; -constexpr size_t kIndex5 = 5; -constexpr size_t kIndex6 = 6; -constexpr size_t kIndex7 = 7; -// kMaxTransposeSerialSize = 64 * 3 * 512 * 512 -constexpr size_t kMaxTransposeSerialSize = 50331648; -} // namespace - -bool TransposeKernelMod::Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kTransposeInputsNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kTransposeOutputsNum, kernel_name_); - launch_func_(this, inputs, outputs); - return true; -} - -int TransposeKernelMod::Resize(const std::vector &inputs, const std::vector &outputs) { - return kSuccess; -} - -bool TransposeKernelMod::Init(const std::vector &inputs, const std::vector &outputs) { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kTransposeInputsNum, kernel_name_); - CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kTransposeOutputsNum, kernel_name_); - MS_CHECK_TRUE_RET(inputs[kIndex0] != nullptr && outputs[kIndex0] != nullptr, false); - input_shape_ = inputs[kIndex0]->GetShapeVector(); - output_shape_ = outputs[kIndex0]->GetShapeVector(); - MS_CHECK_TRUE_RET(inputs[kIndex1] != nullptr, false); - auto address_ptr = inputs[kIndex1]->GetData(); - if (address_ptr == nullptr) { - MS_LOG(EXCEPTION) << "Address ptr is nullptr."; - } - int *addr = static_cast(address_ptr->addr); - if (addr == nullptr) { - MS_LOG(EXCEPTION) << "Cast addr failed."; - } - std::vector perm; - for (size_t i = 0; i < (address_ptr->size) / sizeof(int); ++i) { - perm.emplace_back(static_cast(addr[i])); - } - for (auto p : perm) { - p = (p >= 0) ? p : (perm.size() + p); - if (p < 0) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the perm value must be in [-" << perm.size() << ", " - << (perm.size() - 1) << "], but got " << perm; - } - axes_.emplace_back(p); - } - dtype_ = inputs[kIndex0]->dtype_id(); - if (axes_.size() > MAX_TRANSPOSE_DIM_SIZE) { - MS_LOG(EXCEPTION) << "Transpose support max dimension is " << MAX_TRANSPOSE_DIM_SIZE << "D, but got " - << axes_.size() << "D."; - } - for (size_t i = 0; i < axes_.size(); ++i) { - transpose_param_.perm_[i] = SizeToInt(axes_[i]); - } - size_t num_axes = input_shape_.size(); - transpose_param_.perm_size_ = axes_.size(); - transpose_param_.num_axes_ = SizeToInt(num_axes); - transpose_param_.strides_[num_axes - 1] = 1; - transpose_param_.out_strides_[num_axes - 1] = 1; - for (size_t i = num_axes - 1; i >= 1; i--) { - transpose_param_.strides_[i - 1] = input_shape_[i] * transpose_param_.strides_[i]; - transpose_param_.out_strides_[i - 1] = output_shape_[i] * transpose_param_.out_strides_[i]; - } - launch_map_[kNumberTypeBool] = &TransposeKernelMod::LaunchKernel; - launch_map_[kNumberTypeInt8] = &TransposeKernelMod::LaunchKernel; - launch_map_[kNumberTypeInt16] = &TransposeKernelMod::LaunchKernel; - launch_map_[kNumberTypeInt32] = &TransposeKernelMod::LaunchKernel; - launch_map_[kNumberTypeInt64] = &TransposeKernelMod::LaunchKernel; - launch_map_[kNumberTypeUInt8] = &TransposeKernelMod::LaunchKernel; - launch_map_[kNumberTypeUInt16] = &TransposeKernelMod::LaunchKernel; - launch_map_[kNumberTypeUInt32] = &TransposeKernelMod::LaunchKernel; - launch_map_[kNumberTypeUInt64] = &TransposeKernelMod::LaunchKernel; - launch_map_[kNumberTypeFloat16] = &TransposeKernelMod::LaunchKernel; - launch_map_[kNumberTypeFloat32] = &TransposeKernelMod::LaunchKernel; - launch_map_[kNumberTypeFloat64] = &TransposeKernelMod::LaunchKernel; - auto iter = launch_map_.find(dtype_); - if (iter != launch_map_.end()) { - launch_func_ = iter->second; - } else { - MS_LOG(EXCEPTION) << "Unsupported input data type: " << dtype_; - } - free(address_ptr->addr); - inputs[kIndex1]->GetData()->addr = nullptr; - inputs[kIndex1]->GetData()->size = 0; - return true; -} - -template -void TransposeKernelMod::LaunchKernel(const std::vector &inputs, - const std::vector &outputs) { - if (inputs.size() < 1 || outputs.size() < 1 || inputs[0]->device_ptr() == nullptr || - outputs[0]->device_ptr() == nullptr) { - MS_LOG(ERROR) << "inputs or outputs shape size is wrong."; - return; - } - const auto *input_addr = reinterpret_cast(inputs[0]->device_ptr()); - auto *output_addr = reinterpret_cast(outputs[0]->device_ptr()); - transpose_param_.data_num_ = SizeToInt(inputs[0]->size() / sizeof(T)); - int output_shape[SizeToInt(output_shape_.size())]; - for (size_t i = 0; i < output_shape_.size(); ++i) { - output_shape[i] = output_shape_[i]; - } - bool res{static_cast(NNACL_OK)}; - res = DoTranspose(input_addr, output_addr, output_shape, &transpose_param_); - if (res != static_cast(NNACL_OK)) { - MS_LOG(EXCEPTION) << "Transpose run failed."; - } -} - -template -int TransposeKernelMod::DoTranspose(const T *in_data, T *out_data, const int *output_shape, - const TransposeParameter *transpose_param) { - NNACL_CHECK_NULL_RETURN_ERR(in_data); - NNACL_CHECK_NULL_RETURN_ERR(out_data); - NNACL_CHECK_NULL_RETURN_ERR(output_shape); - NNACL_CHECK_NULL_RETURN_ERR(transpose_param); - const int *perm = transpose_param->perm_; - const int *strides = transpose_param->strides_; - const int *out_strides = transpose_param->out_strides_; - int data_size = transpose_param->data_num_ * sizeof(T); - int num_axes = transpose_param->num_axes_; - bool needTranspose = false; - for (size_t i = 1; i < (unsigned int)num_axes; ++i) { - if (perm[i] - perm[i - 1] != 1) { - needTranspose = true; - break; - } - } - if (!needTranspose) { - (void)memcpy(out_data, in_data, data_size); - return NNACL_OK; - } - for (size_t i = 0; i < (unsigned int)num_axes; ++i) { - if (perm[i] < 0) { - return NNACL_PARAM_INVALID; - } - } - if (num_axes == kIndex2) { - TransposeDim2(in_data, out_data, strides, out_strides, perm, output_shape); - } else if (num_axes == kIndex3) { - TransposeDim3(in_data, out_data, strides, out_strides, perm, output_shape); - } else if (num_axes == kIndex4) { - TransposeDim4(in_data, out_data, strides, out_strides, perm, output_shape); - } else if (num_axes == kIndex5) { - TransposeDim5(in_data, out_data, strides, out_strides, perm, output_shape); - } else if (num_axes == kIndex6) { - TransposeDim6(in_data, out_data, strides, out_strides, perm, output_shape); - } else if (num_axes == kIndex7) { - TransposeDim7(in_data, out_data, strides, out_strides, perm, output_shape); - } else { - return NNACL_ERR; - } - return NNACL_OK; -} - -template -void TransposeKernelMod::TransposeDim2(const T *in_data, T *out_data, const int *strides, const int *out_strides, - const int *perm, const int *output_shape) { - const int stride0 = strides[perm[kIndex0]]; - const int stride1 = strides[perm[kIndex1]]; - const int output0 = output_shape[kIndex0]; - const int output1 = output_shape[kIndex1]; - for (size_t i = 0; i < (unsigned int)output0; ++i) { - size_t out_stride0_i = i * output1; - size_t stride0_i = i * 1 * stride0; - for (size_t j = 0; j < (unsigned int)output1; ++j) { - out_data[out_stride0_i + j] = in_data[stride0_i + j * stride1]; - } - } -} - -template -void TransposeKernelMod::TransposeDim3(const T *in_data, T *out_data, const int *strides, const int *out_strides, - const int *perm, const int *output_shape) { - const int stride0 = strides[perm[kIndex0]]; - const int stride1 = strides[perm[kIndex1]]; - const int stride2 = strides[perm[kIndex2]]; - const int out_stride0 = out_strides[kIndex0]; - const int out_stride1 = out_strides[kIndex1]; - const int output0 = output_shape[kIndex0]; - const int output1 = output_shape[kIndex1]; - const int output2 = output_shape[kIndex2]; - for (size_t i = 0; i < (unsigned int)output0; ++i) { - size_t out_stride0_i = i * out_stride0; - size_t stride0_i = i * stride0; - for (size_t j = 0; j < (unsigned int)output1; ++j) { - size_t out_stride1_j = j * out_stride1; - size_t stride1_j = j * stride1; - for (size_t k = 0; k < (unsigned int)output2; ++k) { - out_data[out_stride0_i + out_stride1_j + k] = in_data[stride0_i + stride1_j + k * stride2]; - } - } - } -} - -template -void TransposeKernelMod::TransposeDim4(const T *in_data, T *out_data, const int *strides, const int *out_strides, - const int *perm, const int *output_shape) { - const int stride0 = strides[perm[kIndex0]]; - const int stride1 = strides[perm[kIndex1]]; - const int stride2 = strides[perm[kIndex2]]; - const int stride3 = strides[perm[kIndex3]]; - const int out_stride0 = out_strides[kIndex0]; - const int out_stride1 = out_strides[kIndex1]; - const int out_stride2 = out_strides[kIndex2]; - const int output0 = output_shape[kIndex0]; - const int output1 = output_shape[kIndex1]; - const int output2 = output_shape[kIndex2]; - const int output3 = output_shape[kIndex3]; - for (size_t i = 0; i < (unsigned int)output0; ++i) { - size_t out_stride0_i = i * out_stride0; - size_t stride0_i = i * stride0; - for (size_t j = 0; j < (unsigned int)output1; ++j) { - size_t out_stride1_j = j * out_stride1; - size_t stride1_j = j * stride1; - for (size_t k = 0; k < (unsigned int)output2; ++k) { - size_t out_stride2_k = k * out_stride2; - size_t stride2_k = k * stride2; - for (size_t m = 0; m < (unsigned int)output3; ++m) { - out_data[out_stride0_i + out_stride1_j + out_stride2_k + m] = - in_data[stride0_i + stride1_j + stride2_k + m * stride3]; - } - } - } - } -} - -template -void TransposeKernelMod::TransposeDim5(const T *in_data, T *out_data, const int *strides, const int *out_strides, - const int *perm, const int *output_shape) { - const int stride0 = strides[perm[kIndex0]]; - const int stride1 = strides[perm[kIndex1]]; - const int stride2 = strides[perm[kIndex2]]; - const int stride3 = strides[perm[kIndex3]]; - const int stride4 = strides[perm[kIndex4]]; - const int out_stride0 = out_strides[kIndex0]; - const int out_stride1 = out_strides[kIndex1]; - const int out_stride2 = out_strides[kIndex2]; - const int out_stride3 = out_strides[kIndex3]; - const int output0 = output_shape[kIndex0]; - const int output1 = output_shape[kIndex1]; - const int output2 = output_shape[kIndex2]; - const int output3 = output_shape[kIndex3]; - const int output4 = output_shape[kIndex4]; - for (size_t i = 0; i < (unsigned int)output0; ++i) { - size_t out_stride0_i = i * out_stride0; - size_t stride0_i = i * stride0; - for (size_t j = 0; j < (unsigned int)output1; ++j) { - size_t out_stride1_j = j * out_stride1; - size_t stride1_j = j * stride1; - for (size_t k = 0; k < (unsigned int)output2; ++k) { - size_t out_stride2_k = k * out_stride2; - size_t stride2_k = k * stride2; - for (size_t m = 0; m < (unsigned int)output3; ++m) { - size_t out_stride3_m = m * out_stride3; - size_t stride3_m = m * stride3; - for (size_t n = 0; n < (unsigned int)output4; ++n) { - out_data[out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + n] = - in_data[stride0_i + stride1_j + stride2_k + stride3_m + n * stride4]; - } - } - } - } - } -} - -template -void TransposeKernelMod::TransposeDim6(const T *in_data, T *out_data, const int *strides, const int *out_strides, - const int *perm, const int *output_shape) { - const int stride0 = strides[perm[kIndex0]]; - const int stride1 = strides[perm[kIndex1]]; - const int stride2 = strides[perm[kIndex2]]; - const int stride3 = strides[perm[kIndex3]]; - const int stride4 = strides[perm[kIndex4]]; - const int stride5 = strides[perm[kIndex5]]; - const int out_stride0 = out_strides[kIndex0]; - const int out_stride1 = out_strides[kIndex1]; - const int out_stride2 = out_strides[kIndex2]; - const int out_stride3 = out_strides[kIndex3]; - const int out_stride4 = out_strides[kIndex4]; - const int output0 = output_shape[kIndex0]; - const int output1 = output_shape[kIndex1]; - const int output2 = output_shape[kIndex2]; - const int output3 = output_shape[kIndex3]; - const int output4 = output_shape[kIndex4]; - const int output5 = output_shape[kIndex5]; - for (size_t i = 0; i < (unsigned int)output0; ++i) { - size_t out_stride0_i = i * out_stride0; - size_t stride0_i = i * stride0; - for (size_t j = 0; j < (unsigned int)output1; ++j) { - size_t out_stride1_j = j * out_stride1; - size_t stride1_j = j * stride1; - for (size_t k = 0; k < (unsigned int)output2; ++k) { - size_t out_stride2_k = k * out_stride2; - size_t stride2_k = k * stride2; - for (size_t m = 0; m < (unsigned int)output3; ++m) { - size_t out_stride3_m = m * out_stride3; - size_t stride3_m = m * stride3; - for (size_t n = 0; n < (unsigned int)output4; ++n) { - size_t out_stride4_n = n * out_stride4; - size_t stride4_n = n * stride4; - for (size_t g = 0; g < (unsigned int)output5; ++g) { - out_data[out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + out_stride4_n + g] = - in_data[stride0_i + stride1_j + stride2_k + stride3_m + stride4_n + g * stride5]; - } - } - } - } - } - } -} - -template -void TransposeKernelMod::TransposeDim7(const T *in_data, T *out_data, const int *strides, const int *out_strides, - const int *perm, const int *output_shape) { - const int stride0 = strides[perm[kIndex0]]; - const int stride1 = strides[perm[kIndex1]]; - const int stride2 = strides[perm[kIndex2]]; - const int stride3 = strides[perm[kIndex3]]; - const int stride4 = strides[perm[kIndex4]]; - const int stride5 = strides[perm[kIndex5]]; - const int stride6 = strides[perm[kIndex6]]; - const int out_stride0 = out_strides[kIndex0]; - const int out_stride1 = out_strides[kIndex1]; - const int out_stride2 = out_strides[kIndex2]; - const int out_stride3 = out_strides[kIndex3]; - const int out_stride4 = out_strides[kIndex4]; - const int out_stride5 = out_strides[kIndex5]; - const int output0 = output_shape[kIndex0]; - const int output1 = output_shape[kIndex1]; - const int output2 = output_shape[kIndex2]; - const int output3 = output_shape[kIndex3]; - const int output4 = output_shape[kIndex4]; - const int output5 = output_shape[kIndex5]; - const int output6 = output_shape[kIndex6]; - for (size_t i = 0; i < (unsigned int)output0; ++i) { - size_t out_stride0_i = i * out_stride0; - size_t stride0_i = i * stride0; - for (size_t j = 0; j < (unsigned int)output1; ++j) { - size_t out_stride1_j = j * out_stride1; - size_t stride1_j = j * stride1; - for (size_t k = 0; k < (unsigned int)output2; ++k) { - size_t out_stride2_k = k * out_stride2; - size_t stride2_k = k * stride2; - for (size_t m = 0; m < (unsigned int)output3; ++m) { - size_t out_stride3_m = m * out_stride3; - size_t stride3_m = m * stride3; - for (size_t n = 0; n < (unsigned int)output4; ++n) { - size_t out_stride4_n = n * out_stride4; - size_t stride4_n = n * stride4; - for (size_t g = 0; g < (unsigned int)output5; ++g) { - size_t out_stride5_g = g * out_stride5; - size_t stride5_g = g * stride5; - for (size_t s = 0; s < (unsigned int)output6; ++s) { - out_data[out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + out_stride4_n + out_stride5_g + - s] = - in_data[stride0_i + stride1_j + stride2_k + stride3_m + stride4_n + stride5_g + s * stride6]; - } - } - } - } - } - } - } -} - -template -void TransposeKernelMod::TransposeDims(const T *in_data, T *out_data, const int *output_shape, - const TransposeParameter *transpose_param, int task_id, int thread_num) { - NNACL_CHECK_NULL_RETURN_VOID(in_data); - NNACL_CHECK_NULL_RETURN_VOID(out_data); - NNACL_CHECK_NULL_RETURN_VOID(output_shape); - NNACL_CHECK_NULL_RETURN_VOID(transpose_param); - NNACL_CHECK_ZERO_RETURN(thread_num); - const int *perm = transpose_param->perm_; - const int *strides = transpose_param->strides_; - const int *out_strides = transpose_param->out_strides_; - int num_axes = transpose_param->num_axes_; - size_t data_size = (*out_strides) * output_shape[0]; - size_t offset_size = UP_DIV(data_size, thread_num); - size_t task_offset = offset_size * task_id; - int count = data_size - task_offset; - if (count <= 0) { - return; - } - count = MSMIN(offset_size, (unsigned int)count); - for (int idx = task_offset; (unsigned int)idx < task_offset + count; ++idx) { - int pos = idx; - int output_idx = 0; - int input_idx = 0; - for (int i = 0; i < num_axes; ++i) { - NNACL_CHECK_ZERO_RETURN(*(out_strides + i)); - int position = pos / *(out_strides + i); - int out_stride = i < num_axes - 1 ? out_strides[i] : 1; - output_idx += (position * out_stride); - input_idx += (position * strides[perm[i]]); - pos -= position * (*(out_strides + i)); - } - out_data[output_idx] = in_data[input_idx]; - } -} - -MS_KERNEL_FACTORY_REG_BY_CREATOR(KernelMod, Transpose, - []() { return std::make_shared(prim::kPrimTranspose->name()); }); -} // namespace mindspore::kernel diff --git a/mindspore-lite/src/extendrt/kernel/cpu/transpose_kernel_mod.h b/mindspore-lite/src/extendrt/kernel/cpu/transpose_kernel_mod.h deleted file mode 100644 index 94dd8fa76dbfa2f114aa1d4697f5cc7d0e2f7c7b..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/kernel/cpu/transpose_kernel_mod.h +++ /dev/null @@ -1,83 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_CPU_TRANSPOSE_KERNEL_MOD_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_CPU_TRANSPOSE_KERNEL_MOD_H_ - -#include -#include -#include -#include -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "nnacl/transpose_parameter.h" -#include "common/common_utils.h" - -namespace mindspore::kernel { -class TransposeKernelMod : public NativeCpuKernelMod { - public: - TransposeKernelMod() = default; - ~TransposeKernelMod() override = default; - - explicit TransposeKernelMod(const std::string name) { kernel_name_ = name; } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - bool Init(const std::vector &inputs, const std::vector &outputs) override; - - int Resize(const std::vector &inputs, const std::vector &outputs) override; - std::vector GetOpSupport() override { return {}; } - - private: - template - void LaunchKernel(const std::vector &inputs, const std::vector &outputs); - template - int DoTranspose(const T *in_data, T *out_data, const int *output_shape, const TransposeParameter *transpose_param); - template - void TransposeDim2(const T *in_data, T *out_data, const int *strides, const int *out_strides, const int *perm, - const int *output_shape); - template - void TransposeDim3(const T *in_data, T *out_data, const int *strides, const int *out_strides, const int *perm, - const int *output_shape); - template - void TransposeDim4(const T *in_data, T *out_data, const int *strides, const int *out_strides, const int *perm, - const int *output_shape); - template - void TransposeDim5(const T *in_data, T *out_data, const int *strides, const int *out_strides, const int *perm, - const int *output_shape); - template - void TransposeDim6(const T *in_data, T *out_data, const int *strides, const int *out_strides, const int *perm, - const int *output_shape); - template - void TransposeDim7(const T *in_data, T *out_data, const int *strides, const int *out_strides, const int *perm, - const int *output_shape); - template - void TransposeDims(const T *in_data, T *out_data, const int *output_shape, const TransposeParameter *transpose_param, - int task_id, int thread_num); - - TransposeParameter transpose_param_; - std::vector input_shape_; - std::vector output_shape_; - std::vector axes_; - TypeId dtype_{kTypeUnknown}; - using TypeKernel = - std::function &, const std::vector &)>; - std::unordered_map launch_map_; - TypeKernel launch_func_; -}; -} // namespace mindspore::kernel - -#endif // MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_CPU_TRANSPOSE_KERNEL_MOD_H_ diff --git a/mindspore-lite/src/extendrt/kernel/cuda/CMakeLists.txt b/mindspore-lite/src/extendrt/kernel/cuda/CMakeLists.txt deleted file mode 100644 index cc0019bdc09422f5f31c838b0288860bfb34d1ce..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/kernel/cuda/CMakeLists.txt +++ /dev/null @@ -1,6 +0,0 @@ -file(GLOB CUDA_LITE_KERNEL_SRC LIST_DIRECTORIES false - ${CMAKE_CURRENT_SOURCE_DIR}/*.cc - ${OPS_DIR}/kernel/gpu/cuda_impl/cuda_class/*.cc -) - -add_library(cuda_lite_kernel_mid OBJECT ${CUDA_LITE_KERNEL_SRC}) diff --git a/mindspore-lite/src/extendrt/kernel/cuda/batchtospace.cc b/mindspore-lite/src/extendrt/kernel/cuda/batchtospace.cc deleted file mode 100644 index ee4f07aca0255911029c85e416ef2c25435e5176..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/kernel/cuda/batchtospace.cc +++ /dev/null @@ -1,47 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/kernel/cuda/batchtospace.h" -#include -#include "nnacl/batch_to_space_parameter.h" - -namespace mindspore::kernel { -int BatchtoSpaceCudaKernel::Prepare() { - CudaKernel::Prepare(); - if (batch_to_space_helper_ == nullptr) { - batch_to_space_helper_ = std::make_shared>(type_name_); - helper_ = batch_to_space_helper_; - } - cukernel::BatchToSpaceAttr attr; - auto param = reinterpret_cast(op_parameter_); - attr.block_size = param->block_shape_[0]; - attr.crops.push_back({param->crops_[0], param->crops_[1]}); - attr.crops.push_back({param->crops_[2], param->crops_[3]}); - int ret = batch_to_space_helper_->CheckKernelParam(&attr); - CHECK_NOT_EQUAL_RETURN(ret, RET_OK); - ret = ReSize(); - CHECK_NOT_EQUAL_RETURN(ret, RET_OK); - return RET_OK; -} - -int BatchtoSpaceCudaKernel::Run() { - int ret = batch_to_space_helper_->Process(input_device_ptrs_, output_device_ptrs_, work_device_ptrs_, stream_); - CHECK_NOT_EQUAL_RETURN(ret, RET_OK); - return RET_OK; -} - -// REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_BatchToSpace, CudaKernelCreator) -} // namespace mindspore::kernel diff --git a/mindspore-lite/src/extendrt/kernel/cuda/cuda_kernel.cc b/mindspore-lite/src/extendrt/kernel/cuda/cuda_kernel.cc deleted file mode 100644 index f123219bb2f7a8fe5193c250afc2e1ddf64e91c4..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/kernel/cuda/cuda_kernel.cc +++ /dev/null @@ -1,72 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/kernel/cuda/cuda_kernel.h" - -namespace mindspore::kernel { -int CudaKernel::PreProcess() { - if (output_device_size_.size() == 0) { - for (size_t i = 0; i < out_tensors_.size(); i++) { - // allocator cudaMalloc mem_size: out_tensors_[i]->set_allocator(/*CudaAllocator*/) - output_device_size_.push_back(helper_->GetOutputSizeList()[i]); - output_device_ptrs_.push_back(out_tensors_[i]->MutableData()); - } - } else { - for (size_t i = 0; i < out_tensors_.size(); i++) { - if (helper_->GetOutputSizeList()[i] > output_device_size_[i]) { - out_tensors_[i]->FreeData(); - output_device_size_[i] = helper_->GetOutputSizeList()[i]; - output_device_ptrs_[i] = out_tensors_[i]->MutableData(); - } - } - } - for (size_t i = 0; i < in_tensors_.size(); i++) { - input_device_ptrs_[i] = in_tensors_[i]->data(); - } - return RET_OK; -} - -int CudaKernel::ReSize() { - // menory calculate - std::vector> input_shapes; - std::vector> output_shapes; - for (auto in : in_tensors_) { - std::vector one_shape(in->shape().size()); - for (size_t i = 0; i < in->shape().size(); i++) { - one_shape[i] = static_cast(in->shape()[i]); - } - input_shapes.push_back(one_shape); - } - for (auto out : out_tensors_) { - std::vector one_shape(out->shape().size()); - for (size_t i = 0; i < out->shape().size(); i++) { - one_shape[i] = static_cast(out->shape()[i]); - } - output_shapes.push_back(one_shape); - } - helper_->ResetResource(); - auto ret = helper_->CalMemSize(input_shapes, output_shapes); - CHECK_NOT_EQUAL_RETURN(ret, RET_OK); - return RET_OK; -} -int CudaKernel::PostProcess() { - for (size_t i = 0; i < in_tensors_.size(); i++) { - in_tensors_[i]->DecRefCount(); - } - return RET_OK; -} -CudaKernel::~CudaKernel() { helper_ = nullptr; } -} // namespace mindspore::kernel diff --git a/mindspore-lite/src/extendrt/kernel/cuda/cuda_kernel.h b/mindspore-lite/src/extendrt/kernel/cuda/cuda_kernel.h deleted file mode 100644 index 635b01e2c8191fe3fe45fb61c35f252e47565710..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/kernel/cuda/cuda_kernel.h +++ /dev/null @@ -1,62 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_CUDA_CUDA_KERNEL_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_CUDA_CUDA_KERNEL_H_ - -#include -#include -#include -#include -#include "src/litert/inner_kernel.h" -#include "src/litert/lite_kernel.h" -#include "cuda_impl/cuda_class/helper_base.h" - -using mindspore::lite::RET_ERROR; -using mindspore::lite::RET_OK; -namespace mindspore::kernel { -class CudaKernel : public InnerKernel { - public: - CudaKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs, const lite::InnerContext *ctx) - : InnerKernel(parameter, inputs, outputs, ctx) {} - ~CudaKernel() override; - int Prepare() override { - type_name_ = std::string(EnumNamePrimitiveType(type())); - return RET_OK; - } - int PreProcess() override; - int PostProcess() override; - int ReSize() override; - int Run() override { return RET_ERROR; } - - protected: - std::vector output_device_size_; - std::vector input_device_ptrs_; - std::vector output_device_ptrs_; - std::vector work_device_ptrs_; - cudaStream_t stream_; - std::shared_ptr helper_{nullptr}; - std::string type_name_; -}; -template -kernel::InnerKernel *CudaKernelCreator(const std::vector &inputs, - const std::vector &outputs, OpParameter *opParameter, - const lite::InnerContext *ctx, const kernel::KernelKey &desc) { - return new (std::nothrow) T(opParameter, inputs, outputs, ctx); -} -} // namespace mindspore::kernel -#endif diff --git a/mindspore-lite/src/extendrt/kernel/default/cnode_infer_manager.cc b/mindspore-lite/src/extendrt/kernel/default/cnode_infer_manager.cc deleted file mode 100644 index f900aea414203596530327392ec3499cfdebc091..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/kernel/default/cnode_infer_manager.cc +++ /dev/null @@ -1,76 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "src/extendrt/kernel/default/cnode_infer_manager.h" -#include -#include "abstract/abstract_value.h" -#include "include/backend/anf_runtime_algorithm.h" -#include "src/extendrt/graph_compiler/anfnode_tensor_adapter.h" - -namespace mindspore { -namespace kernel { -bool SetDTAndShapeFromAbTensorToLiteTensor(const AbstractBasePtr &abstract, lite::Tensor *tensor) { - if (!utils::isa(abstract)) { - MS_LOG(ERROR) << "The abstract should be tensor, but got abstract : " << abstract; - return false; - } - ShapeVector shape_vector; - TypeId data_type = kTypeUnknown; - auto ret = lite::TensorAdapter::GetDTAndShapeFromAbTensor( - utils::cast(abstract), &data_type, &shape_vector); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Get dtype and shape from abstract failed, abstract : " << abstract; - return false; - } - std::vector int32_shape; - std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(int32_shape), - [](const auto &shape) { return static_cast(shape); }); - tensor->set_data_type(data_type); - tensor->set_shape(int32_shape); - tensor->set_format(NHWC); - return true; -} - -int CNodeInferShape(const CNodePtr &cnode, const std::vector &outputs) { - session::AnfRuntimeAlgorithm::InferShape(cnode); - // sync cnode abstract info to Lite Tensor - auto abstract = cnode->abstract(); - if (utils::isa(abstract)) { - auto elements = utils::cast(abstract)->elements(); - if (elements.size() != outputs.size()) { - MS_LOG(ERROR) << "The cnode output size: " << elements.size() - << " is not equal to lite tensors size: " << outputs.size(); - return lite::RET_ERROR; - } - for (size_t i = 0; i < elements.size(); i++) { - if (!SetDTAndShapeFromAbTensorToLiteTensor(elements[i], outputs[i])) { - MS_LOG(ERROR) << "Set tensor info from abstract failed, abstract : " << elements[i]; - return lite::RET_ERROR; - } - } - return lite::RET_OK; - } - if (utils::isa(abstract)) { - if (!SetDTAndShapeFromAbTensorToLiteTensor(abstract, outputs[0])) { - MS_LOG(ERROR) << "Set tensor info from abstract failed, abstract : " << abstract; - return lite::RET_ERROR; - } - return lite::RET_OK; - } - MS_LOG(ERROR) << "Unsupported abstract: " << abstract; - return lite::RET_ERROR; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/kernel/default/default_kernel_lib.cc b/mindspore-lite/src/extendrt/kernel/default/default_kernel_lib.cc deleted file mode 100644 index cb44c0f8d45a998c08e9b05fa828b621c63e1e87..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/kernel/default/default_kernel_lib.cc +++ /dev/null @@ -1,76 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/kernel/default/default_kernel_lib.h" -#include "src/extendrt/kernel/default/kernel_mod_kernel.h" -#include "common/ms_factory.h" -#include "plugin/device/cpu/kernel/cpu_kernel.h" -#include "src/infer/graph_compiler.h" - -namespace mindspore::kernel { -std::shared_ptr DefaultKernelLib::CreateKernelMod(const PrimitiveType &op_type, const KernelAttr &attr, - const Format &format, const std::string &backend) { - if (backend != kBackendCPU) { - MS_LOG(INFO) << "DefaultKernelLib only support CPU backend, but got: " << backend << "."; - return nullptr; - } - if (!MatchFormat(format, Format::NCHW)) { - MS_LOG(INFO) << "DefaultKernelLib only support NCHW layout, but got " << FormatEnumToString(format); - return nullptr; - } - auto kernel_mod = Factory::Instance().Create(op_type.TypeName()); - if (kernel_mod == nullptr) { - MS_LOG(INFO) << "Create kernel mod failed. kernel: " << op_type; - return nullptr; - } - auto match_ret = MatchKernelAttr(attr, kernel_mod->GetOpSupport()); - if (!match_ret.first) { - MS_LOG(INFO) << "For '" << op_type << "' does not support this kernel type: " << attr; - return nullptr; - } - return kernel_mod; -} - -bool DefaultKernelLib::Support(const PrimitiveType &op_type, const KernelAttr &attr, const std::string &backend, - const Format &format) const { - return DefaultKernelLib::CreateKernelMod(op_type, attr, format, backend) != nullptr; -} - -BaseKernel *DefaultKernelLib::CreateKernel(const KernelSpec &spec, const std::vector &inputs, - const std::vector &outputs, const InferContext *ctx) const { - auto kernel_mod = DefaultKernelLib::CreateKernelMod(spec.op_type, spec.attr, spec.format, spec.backend); - if (kernel_mod == nullptr) { - MS_LOG(ERROR) << "Create kernel mod failed. kernel: " << spec.op_type; - return nullptr; - } - return new (std::nothrow) KernelModKernel(kernel_mod, spec.primitive, spec.cnode, inputs, outputs, ctx); -} - -InferKernel *DefaultKernelLib::CreateKernelExec(const KernelSpec &spec, const std::vector &inputs, - const std::vector &outputs, - const InferContext *ctx) const { - auto *kernel_exec = KernelLib::CreateKernelExec(spec, inputs, outputs, ctx); - if (kernel_exec == nullptr) { - return nullptr; - } - auto desc = kernel_exec->desc(); - desc.format = Format::NCHW; - kernel_exec->set_desc(desc); - return kernel_exec; -} - -REG_KERNEL_LIB(kDefaultKernelLibName, DefaultKernelLib); -} // namespace mindspore::kernel diff --git a/mindspore-lite/src/extendrt/kernel/default/default_kernel_lib.h b/mindspore-lite/src/extendrt/kernel/default/default_kernel_lib.h deleted file mode 100644 index 34e510e7e8da644dc69bbce47755ed164d27ab1d..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/kernel/default/default_kernel_lib.h +++ /dev/null @@ -1,44 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_DEFAULT_DEFAULT_LIB_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_DEFAULT_DEFAULT_LIB_H_ - -#include -#include -#include -#include "src/extendrt/kernel/kernel_lib.h" -#include "src/extendrt/kernel/kernel_spec_infos.h" - -namespace mindspore::kernel { -class DefaultKernelLib : public KernelLib { - public: - DefaultKernelLib() : KernelLib(kDefaultKernelLibName, kBackendCPU) {} - - bool Support(const PrimitiveType &op_type, const KernelAttr &attr, const std::string &backend, - const Format &format = DEFAULT_FORMAT) const override; - BaseKernel *CreateKernel(const KernelSpec &spec, const std::vector &inputs, - const std::vector &outputs, const InferContext *ctx) const override; - InferKernel *CreateKernelExec(const KernelSpec &spec, const std::vector &inputs, - const std::vector &outputs, const InferContext *ctx) const override; - - private: - static std::shared_ptr CreateKernelMod(const PrimitiveType &op_type, - const KernelAttr &attr, const Format &format, - const std::string &backend); -}; -} // namespace mindspore::kernel -#endif diff --git a/mindspore-lite/src/extendrt/kernel/default/kernel_mod_kernel.cc b/mindspore-lite/src/extendrt/kernel/default/kernel_mod_kernel.cc deleted file mode 100644 index 157fa1a0c5de42567ee93f75a0f8dbca6ab750fe..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/kernel/default/kernel_mod_kernel.cc +++ /dev/null @@ -1,68 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/kernel/default/kernel_mod_kernel.h" -#include "src/extendrt/utils/tensor_utils.h" -#include "src/extendrt/kernel/default/cnode_infer_manager.h" - -using mindspore::lite::RET_ERROR; -using mindspore::lite::RET_OK; - -namespace mindspore::kernel { -int KernelModKernel::Prepare() { - if (!InferShapeDone()) { - return RET_OK; - } - auto inputs = CloudTensorUtils::LiteTensorToKernelTensorPtrVec(in_tensors_); - auto outputs = CloudTensorUtils::LiteTensorToKernelTensorPtrVec(out_tensors_); - - bool ret = kernel_mod_->Init(inputs, outputs); - return ret ? ReSize() : RET_ERROR; -} - -int KernelModKernel::ReSize() { - auto inputs = CloudTensorUtils::LiteTensorToKernelTensorPtrVec(in_tensors_); - auto outputs = CloudTensorUtils::LiteTensorToKernelTensorPtrVec(out_tensors_); - return kernel_mod_->Resize(inputs, outputs); -} - -int KernelModKernel::Run() { - auto inputs = CloudTensorUtils::LiteTensorToKernelTensorPtrVec(in_tensors_); - auto outputs = CloudTensorUtils::LiteTensorToKernelTensorPtrVec(out_tensors_); - - std::vector workspace; - auto workspace_size = kernel_mod_->GetWorkspaceSizeList(); - for (size_t &i : workspace_size) { - auto buffer = context_->allocator->Malloc(i); - auto tensor = new (std::nothrow) kernel::KernelTensor(); - if (tensor == nullptr || buffer == nullptr) { - return RET_ERROR; - } - tensor->set_device_ptr(buffer); - workspace.push_back(tensor); - } - - auto ret = kernel_mod_->Launch(inputs, workspace, outputs, nullptr); - - for (const auto &tensor : workspace) { - context_->allocator->Free(tensor->device_ptr()); - tensor->set_device_ptr(nullptr); - } - return ret ? RET_OK : RET_ERROR; -} - -int KernelModKernel::InferShape() { return CNodeInferShape(cnode_, this->out_tensors_); } -} // namespace mindspore::kernel diff --git a/mindspore-lite/src/extendrt/kernel/default/kernel_mod_kernel.h b/mindspore-lite/src/extendrt/kernel/default/kernel_mod_kernel.h deleted file mode 100644 index 2e8a73f1178cc37e42d4c16d274483f900844df4..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/kernel/default/kernel_mod_kernel.h +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_DEFAULT_KERNEL_MOD_KERNEL_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_DEFAULT_KERNEL_MOD_KERNEL_H_ - -#include -#include -#include -#include -#include "src/extendrt/kernel/base_kernel.h" -#include "common/kernel.h" -#include "ops/base_operator.h" - -namespace mindspore::kernel { -class KernelModKernel : public BaseKernel { - public: - KernelModKernel(std::shared_ptr kernel_mod, BaseOperatorPtr base_operator, - CNodePtr cnode, const std::vector &in_tensors, - const std::vector &out_tensors, const InferContext *ctx) - : BaseKernel({base_operator, cnode}, in_tensors, out_tensors, ctx), - kernel_mod_(std::move(kernel_mod)), - base_operator_(std::move(base_operator)), - cnode_(std::move(cnode)) {} - ~KernelModKernel() override = default; - - int Prepare() override; - int InferShape() override; - int ReSize() override; - int Run() override; - - private: - KernelModPtr kernel_mod_; - BaseOperatorPtr base_operator_; - CNodePtr cnode_; -}; -} // namespace mindspore::kernel -#endif // MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_DEFAULT_KERNEL_MOD_KERNEL_H_ diff --git a/mindspore-lite/src/extendrt/kernel/extendrt_kernel_exec.h b/mindspore-lite/src/extendrt/kernel/extendrt_kernel_exec.h deleted file mode 100644 index c7237e01f7ce4da0d7a688de66b00bb5f7b363e5..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/kernel/extendrt_kernel_exec.h +++ /dev/null @@ -1,74 +0,0 @@ -/** - * Copyright 2020-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_EXTENDRT_KERNEL_EXEC_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_EXTENDRT_KERNEL_EXEC_H_ -#include -#include -#include -#include "src/executor/kernel_exec.h" -#include "src/extendrt/kernel/base_kernel.h" - -namespace mindspore::kernel { -class ExtendRTKernelExec : public KernelExec { - public: - ExtendRTKernelExec() : KernelExec() {} - - explicit ExtendRTKernelExec(std::shared_ptr kernel) : KernelExec(std::move(kernel)) {} - - ~ExtendRTKernelExec() override = default; - - bool IsBuiltin() override { - if (desc_.provider != kBuiltin) { - MS_LOG(EXCEPTION) << "Custom kernel not supported in ExtendRT now."; - } - return false; - } - - OpParameter *op_parameter() const override { - MS_ASSERT(kernel_ != nullptr); - return std::static_pointer_cast(kernel_)->op_parameter(); - } - - PrimitiveType type() const override { return reinterpret_cast(kernel_.get())->type(); } - - void set_in_tensors(const std::vector &in_tensors) override { - std::static_pointer_cast(kernel_)->set_in_tensors(in_tensors); - } - - void set_in_tensor(lite::Tensor *in_tensor, size_t index) override { - std::static_pointer_cast(kernel_)->set_in_tensor(in_tensor, index); - } - - void set_out_tensors(const std::vector &out_tensors) override { - std::static_pointer_cast(kernel_)->set_out_tensors(out_tensors); - } - - void set_out_tensor(lite::Tensor *out_tensor, size_t index) override { - std::static_pointer_cast(kernel_)->set_out_tensor(out_tensor, index); - } - - const std::vector &in_tensors() const override { - return std::static_pointer_cast(kernel_)->in_tensors(); - } - - const std::vector &out_tensors() const override { - return std::static_pointer_cast(kernel_)->out_tensors(); - } -}; -} // namespace mindspore::kernel - -#endif // MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_EXTENDRT_KERNEL_EXEC_H_ diff --git a/mindspore-lite/src/extendrt/kernel/kernel_lib.h b/mindspore-lite/src/extendrt/kernel/kernel_lib.h deleted file mode 100644 index b692e4ab138e5c8f70f63dc1c8e6e6ad895af0e6..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/kernel/kernel_lib.h +++ /dev/null @@ -1,156 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_KERNEL_LIB_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_KERNEL_LIB_H_ - -#include -#include -#include -#include -#include -#include -#include "src/common/log_adapter.h" -#include "mindapi/base/format.h" -#include "src/infer/primitive_type.h" -#include "src/infer/kernel.h" -#include "src/infer/tensor.h" -#include "src/infer/context.h" -#include "src/extendrt/kernel/base_kernel.h" -#include "ops/base_operator.h" -#include "common/common_utils.h" -#include "src/extendrt/kernel/extendrt_kernel_exec.h" -#include "src/extendrt/kernel/kernel_spec_infos.h" - -namespace mindspore::kernel { -struct KernelSpec { - PrimitiveType op_type; - KernelAttr attr; - Format format; - std::string backend; - BaseOperatorPtr primitive; - CNodePtr cnode; -}; - -class KernelLib { - public: - KernelLib(std::string name, std::string backend) : name_(std::move(name)), backend_(std::move(backend)) {} - virtual ~KernelLib() = default; - virtual bool Support(const PrimitiveType &op_type, const KernelAttr &attr, const std::string &backend, - const Format &format = DEFAULT_FORMAT) const = 0; - virtual BaseKernel *CreateKernel(const KernelSpec &spec, const std::vector &inputs, - const std::vector &outputs, const InferContext *ctx) const = 0; - - virtual InferKernel *CreateKernelExec(const KernelSpec &spec, const std::vector &inputs, - const std::vector &outputs, const InferContext *ctx) const { - auto *base_kernel = this->CreateKernel(spec, inputs, outputs, ctx); - if (base_kernel == nullptr) { - MS_LOG(ERROR) << "Create base kernel failed. kernel: " << spec.op_type; - return nullptr; - } - auto *kernel_exec = new (std::nothrow) ExtendRTKernelExec(std::shared_ptr(base_kernel)); - if (kernel_exec == nullptr) { - MS_LOG(ERROR) << "Create kernel exec failed. kernel: " << spec.op_type; - return nullptr; - } - auto desc = kernel_exec->desc(); - if (backend_ == kernel::kBackendAscend) { - desc.arch = kernel::KERNEL_ARCH::kACL; - } else if (backend_ == kernel::kBackendGPU) { - desc.arch = kernel::KERNEL_ARCH::kGPU; - } else { - desc.arch = kernel::KERNEL_ARCH::kCPU; - } - desc.format = spec.format; - desc.kernel_arch = backend_; - kernel_exec->set_desc(desc); - kernel_exec->set_context(ctx); - return kernel_exec; - } - - std::string Name() const { return name_; } - std::string Backend() const { return backend_; } - - protected: - static bool MatchFormat(const Format &format1, const Format &format2) { - if (format1 == Format::DEFAULT_FORMAT || format2 == Format::DEFAULT_FORMAT) { - return true; - } - return format1 == format2; - } - - protected: - std::string name_; // provider - std::string backend_; -}; - -class KernelLibRegister { - public: - static KernelLibRegister &Instance() { - static KernelLibRegister instance; - return instance; - } - - virtual ~KernelLibRegister() { - for (auto &iter : kernel_libs_) { - delete iter.second; - } - kernel_libs_.clear(); - } - - bool RegKernelLib(const std::string &provider, const KernelLib *lib) { - auto iter = kernel_libs_.find(provider); - if (MS_LIKELY(iter != kernel_libs_.end())) { - MS_LOG(ERROR) << "KernelLib " << provider << " is already exist."; - return false; - } - kernel_libs_[provider] = lib; - return true; - } - - KernelLib *GetKernelLib(const std::string &provider) { - auto iter = kernel_libs_.find(provider); - if (MS_LIKELY(iter == kernel_libs_.end())) { - MS_LOG(ERROR) << "KernelLib " << provider << " is not exist."; - return nullptr; - } - return const_cast(iter->second); - } - - const std::unordered_map &GetAllLibs() { return kernel_libs_; } - - private: - KernelLibRegister() = default; - - private: - // map from provider/name of kernel-lib to kernel-lib - std::unordered_map kernel_libs_; -}; - -class KernelLibRegistry { - public: - KernelLibRegistry(const std::string &provider, const KernelLib *lib) { - if (MS_UNLIKELY(lib == nullptr)) { - MS_LOG(WARNING) << "KernelLib " << provider << " is nullptr, ignored."; - return; - } - (void)KernelLibRegister::Instance().RegKernelLib(provider, lib); - } -}; - -#define REG_KERNEL_LIB(name, class) static KernelLibRegistry g_##class##Registry(name, new (std::nothrow) class()) -} // namespace mindspore::kernel -#endif diff --git a/mindspore-lite/src/extendrt/kernel/kernel_selector/format_first_kernel_selector.h b/mindspore-lite/src/extendrt/kernel/kernel_selector/format_first_kernel_selector.h deleted file mode 100644 index 2e055a141d90ae36163a67f5637257662ba99947..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/kernel/kernel_selector/format_first_kernel_selector.h +++ /dev/null @@ -1,78 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_KERNEL_SELECTOR_FORMAT_FIRST_KERNEL_SELECTOR_H -#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_KERNEL_SELECTOR_FORMAT_FIRST_KERNEL_SELECTOR_H - -#include -#include -#include -#include -#include "src/extendrt/kernel/kernel_selector/kernel_selector.h" -#include "src/extendrt/kernel/nnacl/nnacl_lib.h" -#include "src/extendrt/kernel/default/default_kernel_lib.h" -#include "src/extendrt/graph_compiler/compile_option.h" - -namespace mindspore::kernel { -class FormatFirstKernelSelector : public KernelSelector { - public: - explicit FormatFirstKernelSelector(const std::shared_ptr &compile_option) - : KernelSelector(compile_option) {} - ~FormatFirstKernelSelector() override = default; - - InferKernel *CreateKernel(const KernelSpec &spec, const std::vector &inputs, - const std::vector &outputs, const InferContext *ctx) override { - static std::map kPriorityMap{ - {"custom", 0}, - {kNNACLLibName, 1}, - {"bolt", 1}, - {kDefaultKernelLibName, 100}, - }; - auto match_ks = spec; - auto candidates = Candidates(match_ks.op_type, match_ks.attr, match_ks.backend, match_ks.format); - if (candidates.empty()) { - match_ks.format = DEFAULT_FORMAT; - candidates = Candidates(match_ks.op_type, match_ks.attr, match_ks.backend, match_ks.format); - } - if (candidates.empty()) { - MS_LOG(ERROR) << "Can not find suitable kernellib, op_type: " << spec.op_type << ", kernel attr: " << spec.attr; - return nullptr; - } - int min_priority = INT32_MAX; - const kernel::KernelLib *selected{nullptr}; - constexpr int default_priority = 1000; - for (auto &candidate : candidates) { - auto iter = kPriorityMap.find(candidate->Name()); - int priority = (iter == kPriorityMap.end()) ? default_priority : iter->second; - if (priority < min_priority) { - min_priority = priority; - selected = candidate; - } - } - MS_ASSERT(selected != nullptr); - auto kernel = selected->CreateKernelExec(match_ks, inputs, outputs, ctx); - if (kernel == nullptr) { - MS_LOG(ERROR) << "Create kernel from " << selected->Name() << " failed, op_type: " << match_ks.op_type - << ", kernel attr: " << match_ks.attr; - return nullptr; - } - MS_LOG(INFO) << "Create " << selected->Name() << " kernel for " << spec.cnode->fullname_with_scope(); - return kernel; - } -}; -} // namespace mindspore::kernel - -#endif // MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_KERNEL_SELECTOR_FORMAT_FIRST_KERNEL_SELECTOR_H diff --git a/mindspore-lite/src/extendrt/kernel/kernel_selector/kernel_selector.h b/mindspore-lite/src/extendrt/kernel/kernel_selector/kernel_selector.h deleted file mode 100644 index 552c9c26c63f90acada9a65917aa36ecda8a1849..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/kernel/kernel_selector/kernel_selector.h +++ /dev/null @@ -1,51 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_KERNEL_SELECTOR_H -#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_KERNEL_SELECTOR_H - -#include -#include -#include -#include "src/infer/primitive_type.h" -#include "common/common_utils.h" -#include "src/infer/graph_compiler.h" -#include "src/extendrt/kernel/kernel_lib.h" -#include "src/infer/kernel.h" -#include "src/extendrt/graph_compiler/compile_option.h" - -namespace mindspore::kernel { -class KernelSelector { - public: - explicit KernelSelector(const std::shared_ptr &compile_option) - : compile_option_(compile_option) {} - virtual ~KernelSelector() = default; - virtual InferKernel *CreateKernel(const KernelSpec &spec, const std::vector &inputs, - const std::vector &outputs, const InferContext *ctx) = 0; - - protected: - // `format = DEFAULT_FORMAT` means not care about Format while select kernel. - std::vector Candidates(const PrimitiveType &op_type, const KernelAttr &require, - const std::string &backend, Format format = DEFAULT_FORMAT); - - protected: - const std::shared_ptr compile_option_{nullptr}; -}; - -std::shared_ptr CreateKernelSelector(const std::shared_ptr &compile_option); -} // namespace mindspore::kernel - -#endif // MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_KERNEL_SELECTOR_H diff --git a/mindspore-lite/src/extendrt/kernel/kernel_selector/nnacl_first_kernel_selector.h b/mindspore-lite/src/extendrt/kernel/kernel_selector/nnacl_first_kernel_selector.h deleted file mode 100644 index 606a751aa8fac78be4c82f3abb81ff1a5eeccb77..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/kernel/kernel_selector/nnacl_first_kernel_selector.h +++ /dev/null @@ -1,91 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_KERNEL_SELECTOR_NNACL_FIRST_KERNEL_SELECTOR_H -#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_KERNEL_SELECTOR_NNACL_FIRST_KERNEL_SELECTOR_H - -#include -#include -#include -#include -#include "src/extendrt/kernel/kernel_selector/kernel_selector.h" -#include "src/extendrt/kernel/nnacl/nnacl_lib.h" -#include "src/extendrt/kernel/default/default_kernel_lib.h" -#include "src/extendrt/kernel/kernel_spec_infos.h" -#include "src/extendrt/graph_compiler/compile_option.h" - -namespace mindspore::kernel { -class NNACLFirstKernelSelector : public KernelSelector { - public: - explicit NNACLFirstKernelSelector(const std::shared_ptr &compile_option) - : KernelSelector(compile_option) {} - ~NNACLFirstKernelSelector() override = default; - - InferKernel *CreateKernel(const KernelSpec &spec, const std::vector &inputs, - const std::vector &outputs, const InferContext *ctx) override { - auto nnacl_lib = KernelLibRegister::Instance().GetKernelLib(kNNACLLibName); - if (nnacl_lib == nullptr) { - MS_LOG(ERROR) << "Can not find NNACL kernellib."; - return nullptr; - } - auto match_ks = spec; - match_ks.format = DEFAULT_FORMAT; - if (nnacl_lib->Support(match_ks.op_type, match_ks.attr, match_ks.backend)) { - auto kernel = nnacl_lib->CreateKernelExec(match_ks, inputs, outputs, ctx); - if (kernel == nullptr) { - MS_LOG(ERROR) << "Create kernel from " << nnacl_lib->Name() << " failed, op_type: " << match_ks.op_type - << ", kernel attr: " << match_ks.attr; - return nullptr; - } - MS_LOG(INFO) << "Create NNACL kernel for " << match_ks.cnode->fullname_with_scope(); - return kernel; - } - - auto acl_lib = KernelLibRegister::Instance().GetKernelLib(kAclKernelLibName); - if (acl_lib != nullptr) { - if (acl_lib->Support(match_ks.op_type, match_ks.attr, match_ks.backend)) { - auto kernel = acl_lib->CreateKernelExec(match_ks, inputs, outputs, ctx); - if (kernel == nullptr) { - MS_LOG(ERROR) << "Create kernel from " << acl_lib->Name() << " failed, op_type: " << match_ks.op_type - << ", kernel attr: " << match_ks.attr; - return nullptr; - } - MS_LOG(INFO) << "Create KernelMod kernel for " << match_ks.cnode->fullname_with_scope(); - return kernel; - } - } - - auto kernelmod_lib = KernelLibRegister::Instance().GetKernelLib(kDefaultKernelLibName); - if (kernelmod_lib == nullptr) { - MS_LOG(ERROR) << "Can not find kernelmod kernellib."; - return nullptr; - } - if (kernelmod_lib->Support(match_ks.op_type, match_ks.attr, match_ks.backend)) { - auto kernel = kernelmod_lib->CreateKernelExec(match_ks, inputs, outputs, ctx); - if (kernel == nullptr) { - MS_LOG(ERROR) << "Create kernel from " << kernelmod_lib->Name() << " failed, op_type: " << match_ks.op_type - << ", kernel attr: " << match_ks.attr; - return nullptr; - } - MS_LOG(INFO) << "Create KernelMod kernel for " << match_ks.cnode->fullname_with_scope(); - return kernel; - } - return nullptr; - } -}; -} // namespace mindspore::kernel - -#endif // MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_KERNEL_SELECTOR_NNACL_FIRST_KERNEL_SELECTOR_H diff --git a/mindspore-lite/src/extendrt/kernel/nnacl/nnacl_base_kernel.h b/mindspore-lite/src/extendrt/kernel/nnacl/nnacl_base_kernel.h deleted file mode 100644 index d9e2444f4334226e6464614132d03a51713bd72e..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/kernel/nnacl/nnacl_base_kernel.h +++ /dev/null @@ -1,61 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_NNACL_NNACL_BASE_KERNEL_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_NNACL_NNACL_BASE_KERNEL_H_ - -#include -#include -#include -#include -#include "src/extendrt/kernel/base_kernel.h" -#include "src/litert/lite_kernel.h" -#include "ops/base_operator.h" - -namespace mindspore::kernel { -class NNACLBaseKernel : public BaseKernel { - public: - explicit NNACLBaseKernel(std::shared_ptr lite_kernel) - : BaseKernel({}, nullptr), lite_kernel_(std::move(lite_kernel)) { - this->type_ = schema::EnumNamePrimitiveType(lite_kernel_->type()); - } - ~NNACLBaseKernel() override = default; - - int Prepare() override { return lite_kernel_->Prepare(); } - int InferShape() override { return lite_kernel_->InferShape(); } - int ReSize() override { return lite_kernel_->ReSize(); } - int Run() override { return lite::RET_ERROR; } - int Execute() override { return lite_kernel_->Execute(); } - const std::vector &inputs() override { return lite_kernel_->inputs(); } - const std::vector &outputs() override { return lite_kernel_->outputs(); } - void set_in_tensors(const std::vector &in_tensors) override { - lite_kernel_->set_in_tensors(in_tensors); - } - void set_in_tensor(InferTensor *in_tensor, size_t index) override { lite_kernel_->set_in_tensor(in_tensor, index); } - void set_out_tensors(const std::vector &out_tensors) override { - lite_kernel_->set_out_tensors(out_tensors); - } - void set_out_tensor(InferTensor *out_tensor, size_t index) override { - lite_kernel_->set_out_tensor(out_tensor, index); - } - const std::vector &in_tensors() const override { return lite_kernel_->in_tensors(); } - const std::vector &out_tensors() const override { return lite_kernel_->out_tensors(); } - OpParameter *op_parameter() const override { return lite_kernel_->op_parameter(); } - - private: - std::shared_ptr lite_kernel_; -}; -} // namespace mindspore::kernel -#endif // MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_NNACL_NNACL_BASE_KERNEL_H_ diff --git a/mindspore-lite/src/extendrt/kernel/nnacl/nnacl_lib.cc b/mindspore-lite/src/extendrt/kernel/nnacl/nnacl_lib.cc deleted file mode 100644 index c6c413fa59d40c287381380cb968f1472fc8c1a3..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/kernel/nnacl/nnacl_lib.cc +++ /dev/null @@ -1,111 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/kernel/nnacl/nnacl_lib.h" -#include -#include -#include -#include "src/litert/kernel_registry.h" -#include "src/common/ops/operator_populate/operator_populate_register.h" -#include "src/infer/graph_compiler.h" -#include "src/extendrt/kernel/nnacl/nnacl_base_kernel.h" - -namespace mindspore::kernel { -namespace { -TypeId GetFirstFp32Fp16OrInt8Type(const KernelAttr &attr) { - if (attr.GetInputSize() == 0) { - MS_LOG(WARNING) << "in tensor is empty."; - return kTypeUnknown; - } - for (size_t i = 0; i < attr.GetInputSize(); i++) { - auto dtype = attr.GetInputAttr(i).dtype; - if (dtype == kObjectTypeTensorType) { - MS_LOG(WARNING) << "Not support TensorType Tensor now!"; - return kTypeUnknown; - } - std::unordered_set type_set = {kNumberTypeFloat32, kNumberTypeFloat16, kNumberTypeInt8, kNumberTypeInt32, - kNumberTypeBool, kNumberTypeUInt8, kObjectTypeString}; - if (type_set.find(dtype) != type_set.end()) { - return dtype; - } - } - - return attr.GetInputAttr(0).dtype; -} -} // namespace - -bool NNACLLib::Support(const PrimitiveType &op_type, const KernelAttr &attr, const std::string &backend, - const Format &format) const { - if (backend != kBackendCPU) { - MS_LOG(INFO) << "NNACL only support CPU backend, but got: " << backend << "."; - return false; - } - if (!MatchFormat(format, NHWC)) { - MS_LOG(INFO) << "NNACL not support NHWC layout."; - return false; - } - TypeId data_type = GetFirstFp32Fp16OrInt8Type(attr); - if (data_type == kTypeUnknown) { - MS_LOG(INFO) << "Get main datatype of kernel failed."; - return false; - } - // call SupportKernelC in nnacl/kernel.h directly in the further - kernel::KernelKey key{kCPU, data_type, NHWC, op_type.SchemaType()}; - return lite::KernelRegistry::GetInstance()->SupportKernel(key); -} - -BaseKernel *NNACLLib::CreateKernel(const KernelSpec &spec, const std::vector &inputs, - const std::vector &outputs, const InferContext *ctx) const { - if (!MatchFormat(spec.format, NHWC)) { - MS_LOG(INFO) << "NNACL only support NHWC layout, but got " << FormatEnumToString(spec.format); - return nullptr; - } - TypeId data_type = GetFirstFp32Fp16OrInt8Type(spec.attr); - if (data_type == kTypeUnknown) { - MS_LOG(INFO) << "Get main datatype of kernel failed while creating nnacl kernel."; - return nullptr; - } - auto op_parameter = lite::OperatorPopulateRegistry::GetInstance()->CreatePopulateByOp(spec.primitive); - if (op_parameter == nullptr) { - MS_LOG(INFO) << "Populate op-parameter for kernel failed, kernel-type: " << spec.op_type; - return nullptr; - } - op_parameter->thread_num_ = ctx->thread_num_; - // create nnacl kernel base directly in the further - kernel::KernelKey key{kCPU, data_type, NHWC, spec.op_type.SchemaType()}; - auto lite_kernel = lite::KernelRegistry::GetInstance()->GetLiteKernel(inputs, outputs, ctx, &key, op_parameter); - if (lite_kernel == nullptr) { - MS_LOG(INFO) << "Create lite kernel failed: " << op_parameter->name_; - free(op_parameter); - return nullptr; - } - return new NNACLBaseKernel(std::shared_ptr(lite_kernel)); -} - -InferKernel *NNACLLib::CreateKernelExec(const KernelSpec &spec, const std::vector &inputs, - const std::vector &outputs, const InferContext *ctx) const { - auto *kernel_exec = KernelLib::CreateKernelExec(spec, inputs, outputs, ctx); - if (kernel_exec == nullptr) { - return nullptr; - } - auto desc = kernel_exec->desc(); - desc.format = Format::NHWC; - kernel_exec->set_desc(desc); - return kernel_exec; -} - -REG_KERNEL_LIB(kNNACLLibName, NNACLLib); -} // namespace mindspore::kernel diff --git a/mindspore-lite/src/extendrt/memory_offload/infer_strategy_builder.cc b/mindspore-lite/src/extendrt/memory_offload/infer_strategy_builder.cc deleted file mode 100644 index 57fc077b8cb4e3e73f0dd696ed7c3076dc28f2b5..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/memory_offload/infer_strategy_builder.cc +++ /dev/null @@ -1,221 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "src/extendrt/memory_offload/infer_strategy_builder.h" - -namespace mindspore { -namespace lite { -namespace { -const size_t kSwapVirtualNodeNum = 2; -template -void CheckVectorIndex(const std::vector &input, size_t index) { - if (input.size() <= index) { - MS_LOG_EXCEPTION << "Invalid vector index " << index << ", vector size is " << input.size(); - } -} -} // namespace - -void MemoryOffloadInferStrategyBuilder::ResetState(const lite::CompileResultPtr &compile_result, - const std::shared_ptr &context) { - MS_EXCEPTION_IF_NULL(compile_result); - MS_EXCEPTION_IF_NULL(context); - - context_ = context; - total_mem_level0_ = context->hbm_mem_size_; - total_mem_level1_ = context->cpu_mem_size_; - - kernel_num_ = compile_result->NodeSize(); - - mem_used_level0_.clear(); - mem_used_level0_.resize(kernel_num_, 0); - mem_used_level1_.clear(); - mem_used_level1_.resize(kernel_num_, 0); - - span_level1_.clear(); - span_level2_.clear(); - auto tmp_queue = std::priority_queue, std::vector>, SpanCmp>(); - span_queue_.swap(tmp_queue); - - kernel_actions_.clear(); - kernel_actions_.resize(kernel_num_ + kSwapVirtualNodeNum); -} - -void MemoryOffloadInferStrategyBuilder::AnalyzeMemoryInfo(const lite::CompileResultPtr &compile_result) { - auto &exec_order = compile_result->GetNodes(); - least_mem_ = SIZE_MAX; - size_t kernel_mem = 0; - for (size_t i = 0; i < exec_order.size(); ++i) { - auto compile_node = exec_order[i]; - for (auto in_tensor : compile_node->GetInputs()) { - tensor_usedby_kernel_ids_[in_tensor].insert(i); - kernel_mem += in_tensor->Size(); - } - for (auto out_tensor : compile_node->GetOutputs()) { - tensor_to_kernel_id_[out_tensor] = i; - kernel_mem += out_tensor->Size(); - } - node_to_mem_size_[compile_node] = kernel_mem; - if (kernel_mem > least_mem_) { - least_mem_ = kernel_mem; - } - kernel_mem = 0; - } - - auto &all_tensors = compile_result->GetTensors(); - for (size_t i = 0; i < all_tensors.size(); ++i) { - tensor_to_index_[all_tensors[i]] = i; - } -} - -void MemoryOffloadInferStrategyBuilder::RecordSpan(const Tensor *tensor, size_t last_index, size_t current_index, - bool output_span) { - auto dist = current_index - last_index; - if (dist <= 1) { - return; - } - - MS_EXCEPTION_IF_NULL(tensor); - MS_EXCEPTION_IF_NULL(context_); - - auto span = std::make_shared(); - MS_EXCEPTION_IF_NULL(span); - span->tensor_id_ = tensor_to_index_[tensor]; - span->tensor_size_ = tensor->Size(); - span->last_index_ = last_index; - span->current_index_ = current_index; - span->weight_ = (dist - 1) * span->tensor_size_; - span->output_span_ = output_span; - - bool offload_param = context_->offload_param_to_cpu_ || context_->offload_param_to_disk_; - if (offload_param && tensor->category() == lite::PARAMETER) { - (void)offload_param_spans_.emplace_back(span); - } else { - span_queue_.emplace(span); - } -} - -void MemoryOffloadInferStrategyBuilder::BuildSpans() { - auto iter = tensor_usedby_kernel_ids_.begin(); - while (iter != tensor_usedby_kernel_ids_.end()) { - auto &used_by_kernels = iter->second; - if (used_by_kernels.empty()) { - continue; - } - auto tensor = iter->first; - size_t first_index = tensor_to_kernel_id_[tensor]; - size_t last_index = first_index; - for (auto current_index = used_by_kernels.begin(); current_index != used_by_kernels.end(); current_index++) { - if (first_index == *current_index) { - continue; - } - - RecordSpan(tensor, last_index, *current_index); - last_index = *current_index; - } - - // if tensor is const or parameter, then tensor data will try to store persistently - if (tensor->category() == lite::PARAMETER || tensor->category() == lite::GRAPH_INPUT || tensor->data() != nullptr) { - RecordSpan(tensor, last_index, first_index + kernel_num_); - } else if (tensor->category() == lite::GRAPH_OUTPUT) { - RecordSpan(tensor, last_index, kernel_num_, true); - } - } -} - -void MemoryOffloadInferStrategyBuilder::ClassifySpanLevel() { - while (!span_queue_.empty()) { - auto span = span_queue_.top(); - bool enough = device::SwapStrategyBuilder::EnoughSpaceForSpan(span, &mem_used_level0_, total_mem_level0_); - if (!enough) { - enough = device::SwapStrategyBuilder::EnoughSpaceForSpan(span, &mem_used_level1_, total_mem_level1_); - if (enough) { - (void)span_level1_.emplace_back(span); - } else { - (void)span_level2_.emplace_back(span); - } - } - span_queue_.pop(); - } -} - -void MemoryOffloadInferStrategyBuilder::AddTensorAction(device::SwapActionType action_type, size_t tensor_id, - size_t kernel_id) { - auto action = std::make_shared(); - action->action_ = action_type; - action->tensor_id_ = tensor_id; - - if (kernel_id > 0 && - (action_type == device::SwapActionType::kHBM2DDR || action_type == device::SwapActionType::kHBM2DISK)) { - auto tensor = compile_result_->GetTensors()[tensor_id]; - if (tensor->category() != lite::PARAMETER) { - action->avoid_copy_ = true; - } - } - - CheckVectorIndex(kernel_actions_, kernel_id); - (void)kernel_actions_[kernel_id].emplace_back(action); -} - -std::shared_ptr MemoryOffloadInferStrategyBuilder::BuildStrategy( - const lite::CompileResultPtr &compile_result) { - MS_EXCEPTION_IF_NULL(compile_result); - auto &exec_order = compile_result->GetNodes(); - - auto strategy = std::make_shared(); - strategy->kernel_num_ = kernel_num_; - strategy->virtual_node_num_ = kSwapVirtualNodeNum; - size_t last_index = 0; - for (size_t i = 0; i < kernel_num_; ++i) { - strategy->nodes_[i + 1] = exec_order[i]->GetCNode(); - (void)strategy->links_.emplace_back(std::make_shared(last_index, i + 1)); - last_index = i + 1; - } - - size_t logic_kernel_num = kernel_actions_.size(); - size_t action_id = logic_kernel_num; - for (size_t i = 0; i < logic_kernel_num; ++i) { - auto &actions = kernel_actions_[i]; - if (actions.empty()) { - continue; - } - auto swap_action = std::make_shared(); - swap_action->actions_ = actions; - strategy->actions_[action_id] = swap_action; - (void)strategy->links_.emplace_back(std::make_shared(i, action_id)); - (void)strategy->links_.emplace_back(std::make_shared(action_id, i + 1)); - ++action_id; - } - - return strategy; -} - -std::shared_ptr MemoryOffloadInferStrategyBuilder::Build( - const lite::CompileResultPtr &compile_result, const std::shared_ptr &context) { - MS_EXCEPTION_IF_NULL(compile_result); - MS_EXCEPTION_IF_NULL(context); - ResetState(compile_result, context); - - AnalyzeMemoryInfo(compile_result); - - BuildSpans(); - - ClassifySpanLevel(); - - device::SwapStrategyBuilder::SpanToTensorAction(); - - return BuildStrategy(compile_result); -} -} // namespace lite -} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/memory_offload/infer_strategy_builder.h b/mindspore-lite/src/extendrt/memory_offload/infer_strategy_builder.h deleted file mode 100644 index e73d0af136102247598c6e3b34dfdace0af00af5..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/memory_offload/infer_strategy_builder.h +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_EXTENDRT_MEMORY_OFFLOAD_STRATEGY_BUILDER_H_ -#define MINDSPORE_LITE_EXTENDRT_MEMORY_OFFLOAD_STRATEGY_BUILDER_H_ -#include -#include -#include -#include -#include -#include "runtime/device/gsm/swap_strategy.h" -#include "runtime/device/gsm/swap_strategy_builder.h" -#include "src/extendrt/graph_compiler/compile_result_builder.h" - -namespace mindspore { -namespace lite { -class MemoryOffloadInferStrategyBuilder : public device::SwapStrategyBuilder { - public: - MemoryOffloadInferStrategyBuilder() = default; - ~MemoryOffloadInferStrategyBuilder() = default; - std::shared_ptr Build(const lite::CompileResultPtr &compile_result, - const std::shared_ptr &context); - - private: - void ResetState(const lite::CompileResultPtr &compile_result, const std::shared_ptr &context); - void RecordSpan(const Tensor *tensor, size_t last_index, size_t current_index, bool output_span = false); - std::shared_ptr BuildStrategy(const lite::CompileResultPtr &compile_result); - void BuildSpans(); - void AnalyzeMemoryInfo(const lite::CompileResultPtr &compile_result); - void ClassifySpanLevel(); - void ClassifyOffloadSpanLevel(const std::vector> &spans, bool offload_to_ddr); - void AddTensorAction(device::SwapActionType action_type, size_t tensor_id, size_t kernel_id); - - std::shared_ptr context_{nullptr}; - size_t prefetch_mem_size_{0}; - std::map node_to_mem_size_; - std::map tensor_to_index_; - std::map tensor_to_kernel_id_; // output tensor, node/kernel id - std::map> tensor_usedby_kernel_ids_; - size_t least_mem_ = SIZE_MAX; - lite::CompileResultPtr compile_result_; -}; -} // namespace lite -} // namespace mindspore - -#endif // MINDSPORE_LITE_EXTENDRT_MEMORY_OFFLOAD_STRATEGY_BUILDER_H_ diff --git a/mindspore-lite/src/extendrt/mindir_loader/mindir_model/mindir_model_loader.cc b/mindspore-lite/src/extendrt/mindir_loader/mindir_model/mindir_model_loader.cc index 85a6f2519eb299def1e87ed2e75c23b4a9b42494..03e4a55e7167423bb9c2a4c689bb2daef9ca2374 100644 --- a/mindspore-lite/src/extendrt/mindir_loader/mindir_model/mindir_model_loader.cc +++ b/mindspore-lite/src/extendrt/mindir_loader/mindir_model/mindir_model_loader.cc @@ -72,10 +72,6 @@ bool MindirModelLoader::ConvertModel(const mind_ir::ModelProto &model_proto) { } else { // no subgraph, add graph to subgraph auto *sub_graph = new (std::nothrow) LiteGraph::SubGraph(); - if (sub_graph ==nullptr) { - MS_LOG(ERROR) << "new subgraph failed."; - return false; - } sub_graph->name_ = model_proto.graph().name(); MS_CHECK_TRUE_MSG( ConvertGraph(model_proto.graph(), sub_graph, true), false, diff --git a/mindspore-lite/src/extendrt/mock/anf_ir_dump.cc b/mindspore-lite/src/extendrt/mock/anf_ir_dump.cc deleted file mode 100644 index feee9db99b31562ef1bea8d1c6171aa49cab8f4f..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/mock/anf_ir_dump.cc +++ /dev/null @@ -1,35 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "include/common/debug/anf_ir_dump.h" - -namespace mindspore { -#define PrintDeprecatedLog \ - static bool already_printed = false; \ - if (already_printed) { \ - return; \ - } \ - already_printed = true; \ - MS_LOG(WARNING) << "The functionality of dumping function graph IR is disabled, " \ - << "please recompile source to enable it. See help of building script."; - -void DumpIR(const std::string &, const FuncGraphPtr &, bool, LocDumpMode, const std::string &) { PrintDeprecatedLog } - -void DumpIR(std::ostringstream &, const FuncGraphPtr &, bool, LocDumpMode) { PrintDeprecatedLog } - -void DumpIRForRDR(const std::string &, const FuncGraphPtr &, bool, LocDumpMode) { PrintDeprecatedLog } - -#undef PrintDeprecatedLog -} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/mock/executor.cc b/mindspore-lite/src/extendrt/mock/executor.cc deleted file mode 100644 index 3e853cf8d5779c3dce6a2b1c2786ba2c311f6d3d..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/mock/executor.cc +++ /dev/null @@ -1,432 +0,0 @@ -/** - * Copyright 2020-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "backend/common/session/executor.h" -#include "backend/common/session/executor_manager.h" -#include -#include -#include -#include "runtime/device/kernel_runtime_manager.h" -#include "include/common/utils/comm_manager.h" -#include "include/common/utils/scoped_long_running.h" - -namespace mindspore { -namespace session { -namespace { -void GetNeedNotifyTensors(const VectorRef *outputs, std::set *result) { - MS_EXCEPTION_IF_NULL(outputs); - MS_EXCEPTION_IF_NULL(result); - for (auto &item : *outputs) { - if (utils::isa(item)) { - auto vector_ref = utils::cast(item); - GetNeedNotifyTensors(&vector_ref, result); - } else if (utils::isa(item)) { - auto tensor = utils::cast(item); - result->emplace(tensor); - } - } -} - -bool TensorInVector(const VectorRef *outputs) { - MS_EXCEPTION_IF_NULL(outputs); - for (auto &item : *outputs) { - if (utils::isa(item)) { - auto vector_ref = utils::cast(item); - if (TensorInVector(&vector_ref)) { - return true; - } - } else if (utils::isa(item)) { - return true; - } - } - return false; -} - -bool IsTaskReady(const std::shared_ptr &task) { - MS_EXCEPTION_IF_NULL(task); - for (auto &input : task->input_need_wait_tensors_) { - MS_EXCEPTION_IF_NULL(input); - if (input->NeedWait()) { - return false; - } - } - auto session = task->session_; - MS_EXCEPTION_IF_NULL(session); - auto graph = session->GetGraph(task->graph_id_); - if (graph != nullptr) { - return graph->IsPreGraphFinished(); - } - return true; -} - -void WaitLockedInputs(const std::shared_ptr &task) { - bool need_lock = false; - for (auto &tensor : task->input_tensors_) { - if (tensor->NeedWait()) { - if (tensor->IsGraphOutput()) { - task->input_need_wait_tensors_.emplace_back(tensor); - } else { - need_lock = true; - } - } - } - if (need_lock) { - mindspore::ScopedLongRunning long_running; - for (auto &input_tensor : task->input_tensors_) { - if (input_tensor->NeedWait() && !input_tensor->IsGraphOutput()) { - MsException::Instance().CheckException(); - input_tensor->Wait(); - } - } - MsException::Instance().CheckException(); - } - // need lock input parameters for optimizer - for (auto &need_lock_tensor : task->input_need_lock_tensors_) { - need_lock_tensor->SetNeedWait(true); - } -} -} // namespace - -void CompileNodesTask::Run() { - MS_EXCEPTION_IF_NULL(session_); - MS_EXCEPTION_IF_NULL(segment_); - graph_id_ = session_->CompileGraphImpl(segment_->nodes_, output_nodes_); -} - -void CompileGraphTask::Run() { - MS_EXCEPTION_IF_NULL(session_); - graph_id_ = session_->CompileGraphImpl(NOT_NULL(func_graph_)); -} - -void BuildGraphTask::Run() { - MS_EXCEPTION_IF_NULL(session_); - session_->BuildGraphImpl(graph_id_); -} - -void RunGraphTask::Run() { - MS_EXCEPTION_IF_NULL(session_); - MS_LOG(INFO) << "Start run graph " << graph_id_; - auto graph = session_->GetGraph(graph_id_); - if (graph == nullptr) { - MS_LOG(ERROR) << "Invalid graph id " << graph_id_; - return; - } - graph->ResetGraphRunningStatus(); - if (AnfUtils::UseMemScheduler()) { - graph->SetOutputNodeToTensor(node_to_tensor_); - } - try { - session_->LoadInputs(graph_id_, input_tensors_); - session_->RunGraphImpl(graph_id_, input_tensors_, &outputs_); - std::map new_to_old_device_address; - session_->UpdateOutputTensors(&outputs_, tensor_to_node_, &new_to_old_device_address); - } catch (const std::exception &e) { - session_->ReportErrorMessage(); - ExecutorManager::Instance().OnEvent(ExecutorEvent::kException); - MsException::Instance().SetException(); - } - MS_LOG(INFO) << "End run graph " << graph_id_; - graph->OnRunGraphFinished(); - std::set need_notify_tensors(input_need_lock_tensors_.begin(), input_need_lock_tensors_.end()); - GetNeedNotifyTensors(&outputs_, &need_notify_tensors); - for (auto &tensor : need_notify_tensors) { - if (tensor != nullptr) { - tensor->SetNeedWait(false); - } - } - ExecutorManager::Instance().OnEvent(ExecutorEvent::kRunGraphFinished); -} - -void CreateCommGroupTask::Run() { result_ = CommManager::GetInstance().CreateGroupSync(group_name_, ranks_); } - -void DestroyCommGroupTask::Run() { result_ = CommManager::GetInstance().DestroyGroup(group_name_); } - -Executor::Executor(const std::string &device_name, uint32_t device_id) { - device_name_ = device_name; - device_id_ = device_id; - worker_ = std::make_shared(&Executor::WorkerLoop, this); -} - -Executor::~Executor() { - try { - WorkerJoin(); - } catch (const std::exception &e) { - MS_LOG(ERROR) << "Executor call destructor failed: " << e.what(); - } catch (...) { - MS_LOG(ERROR) << "Executor call destructor failed."; - } -} - -void Executor::WorkerJoin() { - // Avoid worker thread join itself which will cause deadlock - if (worker_->joinable() && worker_->get_id() != std::this_thread::get_id()) { - { - std::lock_guard lock(task_mutex_); - auto task = std::make_shared(); - ready_tasks_.push(task); - task_cond_var_.notify_all(); - } - worker_->join(); - } -} - -void Executor::WorkerLoop() { - while (true) { - std::shared_ptr task; - { - std::unique_lock lock(task_mutex_); - task_cond_var_.wait(lock, [this] { return !ready_tasks_.empty(); }); - task = ready_tasks_.front(); - ready_tasks_.pop(); - } - MS_EXCEPTION_IF_NULL(task); - enum TaskType task_type = task->type_; - bool task_sync_flag = task->sync_run_; - if (task_type == kExit) { - OnWorkerExit(); - return; - } - try { - if (task->session_ != nullptr) { - task->session_->SetThreadContext(); - } - task->Run(); - if (task->session_ != nullptr) { - task->session_->ReportWarningMessage(); - } - } catch (const std::exception &e) { - if (task->session_ != nullptr) { - task->session_->ReportErrorMessage(); - } - ExecutorManager::Instance().OnEvent(ExecutorEvent::kException); - MsException::Instance().SetException(); - } - { - std::lock_guard lock(done_task_mutex_); - done_tasks_.emplace_back(std::move(task)); - } - if (task_type != kRunGraph || task_sync_flag) { - std::lock_guard lock(task_mutex_); - sync_run_task_finished_ = true; - sync_cond_var_.notify_all(); - } - } -} - -std::vector> Executor::GetReadyTasksFromPendingList() { - std::vector> ready_tasks; - std::lock_guard lock(pending_task_mutex_); - for (auto iter = pending_tasks_.begin(); iter != pending_tasks_.end();) { - auto task = *iter; - if (IsTaskReady(task)) { - (void)ready_tasks.emplace_back(task); - iter = pending_tasks_.erase(iter); - } else { - ++iter; - } - } - return ready_tasks; -} - -void Executor::OnEvent(const ExecutorEvent &event) { - if (event == ExecutorEvent::kRunGraphFinished) { - OnRunGraphFinished(); - } else if (event == ExecutorEvent::kClear) { - OnClear(); - } else if (event == ExecutorEvent::kException) { - OnException(); - } -} - -void Executor::OnClear() { - { - mindspore::ScopedLongRunning long_running; - WorkerJoin(); - } - ClearDoneTasks(); -} - -void Executor::OnException() { - std::vector> done_tasks; - { - std::lock_guard lock(task_mutex_); - while (!ready_tasks_.empty()) { - (void)done_tasks.emplace_back(ready_tasks_.front()); - ready_tasks_.pop(); - } - } - { - std::lock_guard lock(pending_task_mutex_); - (void)std::copy(pending_tasks_.begin(), pending_tasks_.end(), std::back_inserter(done_tasks)); - pending_tasks_.clear(); - } - { - std::lock_guard lock(done_task_mutex_); - (void)done_tasks_.insert(done_tasks_.end(), done_tasks.begin(), done_tasks.end()); - } -} - -void Executor::OnRunGraphFinished() { - auto ready_tasks = GetReadyTasksFromPendingList(); - std::lock_guard lock(task_mutex_); - for (auto &task : ready_tasks) { - ready_tasks_.push(task); - } - if (!ready_tasks.empty()) { - task_cond_var_.notify_all(); - } - reenter_cond_var_.notify_all(); -} - -void Executor::ClearDoneTasks() { - std::lock_guard lock(done_task_mutex_); - done_tasks_.clear(); -} - -void Executor::RunTask(const std::shared_ptr &task, bool sync, bool long_run) { - if (sync) { - ClearDoneTasks(); - } - { - std::lock_guard lock(task_mutex_); - sync_run_task_finished_ = false; - ready_tasks_.push(task); - } - task_cond_var_.notify_all(); - if (sync && !sync_run_task_finished_) { - std::unique_lock lock(task_mutex_); - if (sync && long_run) { - mindspore::ScopedLongRunning long_running; - sync_cond_var_.wait(lock, [this] { return sync_run_task_finished_; }); - } else { - sync_cond_var_.wait(lock, [this] { return sync_run_task_finished_; }); - } - } - ClearDoneTasks(); - MsException::Instance().CheckException(); -} - -GraphId Executor::CompileGraph(const SessionPtr &session, const GraphSegmentPtr &segment, - const AnfNodePtrList &outputs) { - auto task = std::make_shared(); - task->session_ = session; - task->segment_ = segment; - task->output_nodes_ = outputs; - RunTask(task, true); - return task->graph_id_; -} - -GraphId Executor::CompileGraph(const SessionPtr &session, NotNull func_graph) { - auto task = std::make_shared(); - task->session_ = session; - task->func_graph_ = func_graph.get(); - RunTask(task, true); - return task->graph_id_; -} - -void Executor::BuildGraph(const SessionPtr &session, GraphId graphId) { - auto task = std::make_shared(); - task->session_ = session; - task->graph_id_ = graphId; - RunTask(task, true); -} - -void Executor::RunGraph(const SessionPtr &session, const GraphId &graph_id, - const std::vector &inputs, VectorRef *outputs) { - MS_EXCEPTION_IF_NULL(session); - MS_EXCEPTION_IF_NULL(outputs); - auto task = std::make_shared(); - task->session_ = session; - task->graph_id_ = graph_id; - task->input_tensors_ = inputs; - session->CreateOutputTensors(graph_id, inputs, outputs, &task->tensor_to_node_, &task->node_to_tensor_); - task->outputs_ = *outputs; - task->sync_run_ = true; - RunTask(task, true, true); -} - -void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, - const std::vector &inputs, VectorRef *outputs) { - MS_EXCEPTION_IF_NULL(session); - MS_EXCEPTION_IF_NULL(outputs); - auto task = std::make_shared(); - task->session_ = session; - task->graph_id_ = graph_id; - task->input_tensors_ = inputs; - task->input_need_lock_tensors_ = session->GetInputNeedLockTensors(graph_id, inputs); - auto graph = session->GetGraph(task->graph_id_); - if (graph != nullptr && !graph->IsPostGraphFinished()) { - mindspore::ScopedLongRunning long_running; - std::unique_lock lock(reenter_mutex_); - reenter_cond_var_.wait(lock, [&graph] { return graph->IsPostGraphFinished(); }); - MsException::Instance().CheckException(); - } - session->CreateOutputTensors(graph_id, inputs, outputs, &task->tensor_to_node_, &task->node_to_tensor_); - // maintain a copy of output vector - task->outputs_ = *outputs; - - // Run graph synchronously when the graph require gil. - if (graph != nullptr && graph->is_need_gil()) { - std::unique_lock lock(reenter_mutex_); - reenter_cond_var_.wait(lock, [&graph] { return graph->IsPreGraphFinished(); }); - MsException::Instance().CheckException(); - task->sync_run_ = true; - RunTask(task, true, true); - return; - } - - // sync run graph without output tensor(int dataset graph) - if ((!TensorInVector(outputs) && !graph->HasPostGraph())) { - task->sync_run_ = true; - RunTask(task, true, true); - return; - } - WaitLockedInputs(task); - for (auto &tensor_node : task->tensor_to_node_) { - tensor_node.first->SetNeedWait(true); - } - { - std::lock_guard lock(pending_task_mutex_); - if (!IsTaskReady(task)) { - ClearDoneTasks(); - pending_tasks_.push_back(task); - return; - } - } - RunTask(task, false); -} - -bool Executor::CreateCommGroup(const std::string &group_name, const std::vector &ranks) { - auto task = std::make_shared(); - task->group_name_ = group_name; - task->ranks_ = ranks; - RunTask(task, true); - return task->result_; -} - -bool Executor::DestroyCommGroup(const std::string &group_name) { - auto task = std::make_shared(); - task->group_name_ = group_name; - RunTask(task, true); - return task->result_; -} - -void Executor::OnWorkerExit() { - if (device_name_ == kAscendDevice) { - device::KernelRuntimeManager::Instance().ReleaseKernelRuntime(kAscendDevice, device_id_); - } -} -} // namespace session -} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/mock/ge_mock.cc b/mindspore-lite/src/extendrt/mock/ge_mock.cc deleted file mode 100644 index 121628f22c1dda35ba242f4ef3b9b9d954dbacf3..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/mock/ge_mock.cc +++ /dev/null @@ -1,85 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef GE_MOCK_H -#define GE_MOCK_H -#include -#include -#include -#include -#include -#include "graph/tensor.h" -#include "graph/operator_reg.h" -#include "graph/operator.h" - -#include "ge/ge_api.h" - -namespace ge { - -Session::Session(const std::map &options) {} -Session::~Session() {} - -Status Session::RunGraph(uint32_t id, const std::vector &inputs, std::vector &outputs) { - // for test!!! just copy inputs to outputs: - for (auto it = inputs.begin(); it != inputs.end(); it++) { - outputs.emplace_back(*it); - } - return ge::GRAPH_SUCCESS; -} - -Status Session::AddGraph(uint32_t id, const Graph &graph) { return ge::GRAPH_SUCCESS; } - -Status GEInitialize(const std::map &options) { return ge::GRAPH_SUCCESS; } - -Status GEFinalize() { return ge::GRAPH_SUCCESS; } - -Status Graph::SaveToFile(const string &file_name) const { return ge::GRAPH_SUCCESS; } - -Status Session::RunGraphAsync(uint32_t graph_id, const std::vector &inputs, RunAsyncCallback callback) { - return ge::GRAPH_SUCCESS; -} - -Status Session::RunGraphAsync(uint32_t graph_id, const ContinuousTensorList &inputs, RunAsyncCallback callback) { - return ge::GRAPH_SUCCESS; -} - -Status Session::AddGraph(uint32_t graph_id, const Graph &graph, const std::map &options) { - return ge::GRAPH_SUCCESS; -} - -Status Session::CompileGraph(uint32_t graph_id) { return ge::GRAPH_SUCCESS; } - -CompiledGraphSummaryPtr Session::GetCompiledGraphSummary(uint32_t graph_id) { return nullptr; } - -Status Session::UpdateGraphFeatureMemoryBase(uint32_t graph_id, const void *const memory, size_t size) { - return ge::GRAPH_SUCCESS; -} - -Status Session::SetGraphConstMemoryBase(uint32_t graph_id, const void *const memory, size_t size) { - return ge::GRAPH_SUCCESS; -} - -Status Session::RemoveGraph(uint32_t graph_id) { return ge::GRAPH_SUCCESS; } - -Status Session::RunGraphWithStreamAsync(uint32_t graph_id, void *stream, const std::vector &inputs, - std::vector &outputs) { - return ge::GRAPH_SUCCESS; -} - -void Operator::RequiredAttrWithTypeRegister(const char_t *name, const char_t *type) {} -} // namespace ge - -#endif diff --git a/mindspore-lite/src/extendrt/mock/segment_runner.cc b/mindspore-lite/src/extendrt/mock/segment_runner.cc deleted file mode 100644 index a34187729e9c8c7a65bb8d8d5b2c9308a80562e7..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/mock/segment_runner.cc +++ /dev/null @@ -1,170 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "backend/graph_compiler/segment_runner.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "mindspore/ops/op_def/sequence_ops.h" -#include "mindspore/ops/op_def/framework_ops.h" -#include "frontend/operator/ops.h" -#include "include/common/utils/utils.h" -#include "ir/func_graph_cloner.h" -#include "ir/manager.h" -#include "utils/hash_map.h" -#include "utils/hash_set.h" -#include "src/common/log_adapter.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_d.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" - -namespace mindspore { -namespace compile { -namespace { -// Return the list of nodes whose values are required beyond this segment. -// Arguments: -// nodes: list of nodes in the segment -// users: dict mapping each node to its users (globally) -// seen: set of nodes that are part of the segment -AnfNodePtrList GetOutput(const AnfNodePtrList &nodes, const NodeUsersMap &users, - const mindspore::HashSet &seen) { - AnfNodePtrList output; - if (users.size() == 0) { - return output; - } - for (auto &node : nodes) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - continue; - } - auto iter = users.find(node); - if (iter == users.end()) { - continue; - } - auto &node_users = iter->second; - const bool has_outer_user = std::any_of(std::begin(node_users), std::end(node_users), - [&seen](const std::pair &u) -> bool { - const bool is_outer_user = (seen.find(u.first) == seen.end()); - return is_outer_user; - }); - if (has_outer_user) { - output.emplace_back(node); - } - } - return output; -} - -AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNodePtrList *inputs_ptr, - AnfNodePtrToAnfNodePtrMap *eqv_ptr) { - MS_EXCEPTION_IF_NULL(fg); - MS_EXCEPTION_IF_NULL(inputs_ptr); - MS_EXCEPTION_IF_NULL(eqv_ptr); - MS_EXCEPTION_IF_NULL(node); - auto &inputs = *inputs_ptr; - auto &eqv = *eqv_ptr; - if (node->isa() && !IsValueNode(node)) { - eqv[node] = node; - } else if (eqv.find(node) == eqv.end()) { - inputs.push_back(node); - eqv[node] = fg->add_parameter(); - eqv[node]->set_abstract(node->abstract()); - eqv[node]->set_kernel_info(node->kernel_info_ptr()); - } - return eqv[node]; -} -} // namespace - -std::tuple TransformSegmentToAnfGraph(const AnfNodePtrList &lst) { - if (lst.empty()) { - MS_LOG(EXCEPTION) << "Input anf node list is empty"; - } - FuncGraphPtr fg = nullptr; - { - // limit the lifetime of guard. - MS_EXCEPTION_IF_NULL(lst[0]->cast()); - MS_EXCEPTION_IF_NULL(lst[0]->cast()->func_graph()); - TraceGuard guard(MakeTraceInfo(lst[0]->cast()->func_graph()->debug_info())); - fg = std::make_shared(); - } - AnfNodePtrList inputs; - AnfNodePtrToAnfNodePtrMap eqv; - // Merge CNodes into a AnfGraph that represents a linear instruction segment - for (auto n : lst) { - MS_EXCEPTION_IF_NULL(n); - if (!n->isa()) { - MS_LOG(EXCEPTION) << "Inst is not CNode"; - } - auto &inps = n->cast()->inputs(); - if (inps.empty()) { - MS_LOG(EXCEPTION) << "Input is empty"; - } - if (!IsValueNode(inps[0]) && - !(IsValueNode(inps[0]) && - inps[0]->cast()->value()->cast()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL))) { - MS_LOG(EXCEPTION) << "Input[0] must be a Primitive ValueNode"; - } - auto fn = inps[0]; - std::vector args{fn}; - if (IsPrimitive(fn, prim::kPrimDepend) && inps.size() >= kDependInputSize && - eqv.find(inps[kDependAttachNodeIndex]) == eqv.end()) { - args.emplace_back(RefSubGraphNode(fg, inps[kRealInputIndexInDepend], &inputs, &eqv)); - const size_t value_start_index = 2; - for (size_t i = value_start_index; i < inps.size(); ++i) { - args.emplace_back(NewValueNode(MakeValue(0))); - } - } else { - (void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args), - [&fg, &inputs, &eqv](const AnfNodePtr &a) { return RefSubGraphNode(fg, a, &inputs, &eqv); }); - } - TraceGuard tg(MakeTraceInfo(n->debug_info())); - MS_EXCEPTION_IF_NULL(fg); - eqv[n] = fg->NewCNode(args); - eqv[n]->set_abstract(n->abstract()); - eqv[n]->set_kernel_info(n->kernel_info_ptr()); - } - mindspore::HashSet eqv_keys; - for (auto &e : eqv) { - (void)eqv_keys.emplace(e.first); - } - auto mgr = lst[0]->func_graph()->manager(); - MS_EXCEPTION_IF_NULL(mgr); - auto outputs = GetOutput(lst, mgr->node_users(), eqv_keys); - AnfNodePtr fg_output; - if (outputs.size() > 1) { - std::vector output_args; - output_args.push_back(NewValueNode(prim::kPrimMakeTuple)); - (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_args), - [&eqv](const AnfNodePtr &o) -> AnfNodePtr { return eqv[o]; }); - // Set output for AnfGraph - fg_output = fg->NewCNode(output_args); - } else { - if (outputs.empty()) { - MS_LOG(EXCEPTION) << "Output is empty."; - } - fg_output = eqv[outputs[0]]; - } - fg->set_output(fg_output); - return std::make_tuple(fg, inputs, outputs); -} -} // namespace compile -} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/mock/transform_mock.cc b/mindspore-lite/src/extendrt/mock/transform_mock.cc deleted file mode 100644 index cac4459af3c2e6934e600c53056a678a08bc1413..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/mock/transform_mock.cc +++ /dev/null @@ -1,76 +0,0 @@ -/** - * Copyright 2023-2024 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef TRANSFORM_MOCK_H -#define TRANSFORM_MOCK_H -#include -#include "plugin/res_manager/ascend/op_adapter/op_adapter_map.h" -#include "graph/operator.h" -#include "plugin/res_manager/ascend/op_adapter/op_adapter_desc.h" -#include "plugin/res_manager/ascend/op_adapter/op_adapter.h" - -namespace mindspore { -namespace transform { -namespace { -mindspore::HashMap adpt_map_ = { - {kNameCustomOp, std::make_shared(std::make_shared>(""))}}; -} // namespace - -mindspore::HashMap &OpAdapterMap::get() { return adpt_map_; } - -void OpAdapterImpl::updateOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type, - const AnfNodePtr &node) {} - -int OpAdapterImpl::setAttr(const OperatorPtr &op, const std::string &attr_key, const ValuePtr &attr_value) { return 0; } -int OpAdapterImpl::setAttr(const OperatorPtr &op, const PrimitivePtr &prim) { return 0; } -int OpAdapterImpl::setAttr(const OperatorPtr &op, const AnfNodePtr &node) { return 0; } -int OpAdapterImpl::setAttr(const OperatorPtr &op, const uint32_t &input_idx, const ValuePtr &attr_value) { return 0; } -int OpAdapterImpl::setInput(const OperatorPtr &op, int index, const OutHandler &handle) { return 0; } -int OpAdapterImpl::setInput(const OperatorPtr &op, int index, const OperatorPtr &input) { return 0; } -int OpAdapterImpl::setInput(const OperatorPtr &op, int index, - const std::shared_ptr> &handler_vec, bool use_create_byindex_func, - size_t dyn_index) { - return 0; -} -OutHandler OpAdapterImpl::getOutput(const OperatorPtr &op, int index) { - OutHandler handler; - return handler; -} -int OpAdapterImpl::getAttr(const OperatorPtr &op, const std::string &attr_key, ValuePtr *attr_value) { return 0; } -int OpAdapterImpl::getAttr(const OperatorPtr &op, uint32_t input_idx, ValuePtr *attr_value) { return 0; } -std::string OpAdapterImpl::GetCustomOpType(const PrimitivePtr &prim) const { return ""; } -std::map OpAdapterImpl::GetOpAttrList(const OperatorPtr &) const { return {}; } -Status OpAdapterImpl::SetOpSubgraphFunc(const OperatorPtr &op, int index, - const std::shared_ptr> &branches) { - return SUCCESS; -} - -Status OpAdapterImpl::SetOpSubgraphFunc(const OperatorPtr &op, const std::shared_ptr> &subgraphs) { - return SUCCESS; -} -std::vector OpAdapterImpl::getOutputs(const OperatorPtr &op) const { return std::vector(); } -OperatorPtr OpAdapterImpl::GenerateCustomOp(const AnfNodePtr anf) { return nullptr; } -std::map OpAdapterImpl::GetNormalOpAttrList(const OperatorPtr &op, - const AnfNodePtr &node) const { - return {}; -} - -GeDataType TransformUtil::ConvertDataType(const MeDataType &type) { return GeDataType::DT_UNDEFINED; } -GeDataType ConvertDataType(const MeDataType &type) { return GeDataType::DT_UNDEFINED; } -bool IsCustomCNode(const AnfNodePtr &anf) { return false; } -} // namespace transform -} // namespace mindspore - -#endif diff --git a/mindspore-lite/src/extendrt/session/ascend_native_session.cc b/mindspore-lite/src/extendrt/session/ascend_native_session.cc deleted file mode 100644 index de21153491a16e7a2f2982e86ba03203eb5f83f4..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/session/ascend_native_session.cc +++ /dev/null @@ -1,527 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "extendrt/session/ascend_native_session.h" -#include -#include -#include -#include -#include -#include -#include "extendrt/utils/tensor_utils.h" -#include "extendrt/session/factory.h" -#include "extendrt/utils/tensor_default_impl.h" -#include "extendrt/delegate/ascend_native/delegate.h" -#include "src/common/log_adapter.h" -#include "src/litert/cxx_api/converters.h" -#include "ir/graph_utils.h" -#include "tools/optimizer/common/gllo_utils.h" -#include "extendrt/delegate/ascend_native/ascend_native_impl/utils.h" -#include "src/train/opt_allocator.h" -#include "plugin/res_manager/ascend/hccl_adapter/hccl_adapter.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" - -namespace mindspore { -Status AscendNativeSession::MoveDataFromHostToDevice(void *sd, bool s_fp16, void *dd, bool d_fp16, size_t elem_num) { - if (s_fp16) { - if (d_fp16) { - ascend_native::CopyHostFp16ToDeviceFp16(sd, &dd, elem_num, ascend_native_stream_); - } else { - ascend_native::CopyHostFp16ToDeviceFp32(sd, &dd, elem_num, ascend_native_stream_); - } - } else { - if (d_fp16) { - ascend_native::CopyHostFp32ToDeviceFp16(sd, &dd, elem_num, ascend_native_stream_); - } else { - ascend_native::CopyHostFp32ToDeviceFp32(sd, &dd, elem_num, ascend_native_stream_); - } - } - return kSuccess; -} - -Status AscendNativeSession::MoveDataFromDeviceToHost(void *sd, bool s_fp16, void *dd, bool d_fp16, size_t elem_num) { - if (sd == nullptr) { - MS_LOG(ERROR) << "source pointer is null"; - return kLiteNullptr; - } - if (dd == nullptr) { - MS_LOG(ERROR) << "destination pointer is null"; - return kLiteNullptr; - } - if (s_fp16) { - if (d_fp16) { - ascend_native::CopyDeviceFp16ToHostFp16(sd, dd, elem_num, ascend_native_stream_); - } else { - ascend_native::CopyDeviceFp16ToHostFp32(sd, dd, elem_num, ascend_native_stream_); - } - } else { - if (d_fp16) { - ascend_native::CopyDeviceFp32ToHostFp16(sd, dd, elem_num, ascend_native_stream_); - } else { - ascend_native::CopyDeviceFp32ToHostFp32(sd, dd, elem_num, ascend_native_stream_); - } - } - return kSuccess; -} - -void *AscendNativeSession::MallocDevice(size_t size) { - MS_CHECK_GT(size, 0, nullptr); - auto device_data = ascend_native::MallocDevice(size, ascend_native_stream_); - if (device_data == nullptr) { - MS_LOG(ERROR) << "fail to allocate " << size << " bytes of device data"; - } - return device_data; -} - -void AscendNativeSession::FreeDevice(void *ptr) { ascend_native::FreeDevice(ptr, ascend_native_stream_); } - -void AscendNativeSession::InitializeTensorRefrenceCnt() { - for (auto &kernel : kernels_) { - for (auto tensor : kernel->in_tensors()) { - if (tensor->category() == lite::VAR || tensor->category() == lite::GRAPH_INPUT) { - auto ref_count = tensor->init_ref_count(); - tensor->set_init_ref_count(ref_count + 1); - } - } - } -} - -Status AscendNativeSession::AllocTensors() { - OptAllocator allocator; - std::unordered_map ref_count; - std::unordered_map offset_map; - auto device_elem_size = context_->IsNpuFloat16Enabled() ? C2NUM : C4NUM; - for (auto &kernel : kernels_) { - // malloc Graph inputs - for (auto &tensor : kernel->in_tensors()) { - // TBD - when is allowed to free input ?? - if (tensor->category() == lite::GRAPH_INPUT) { - size_t elem_num = tensor->ElementsNum(); - if (offset_map.find(tensor) == offset_map.end()) { - size_t offset = allocator.Malloc(device_elem_size * elem_num); - offset_map[tensor] = offset; - ref_count[tensor] = tensor->init_ref_count(); - } - } - } - // malloc output tensors - for (auto &tensor : kernel->out_tensors()) { - size_t elem_num = tensor->ElementsNum(); - size_t offset = allocator.Malloc(device_elem_size * elem_num); - offset_map[tensor] = offset; - ref_count[tensor] = tensor->init_ref_count(); - } - // free according to reference counter - for (auto &tensor : kernel->in_tensors()) { - if (tensor->category() == lite::Category::VAR) { - int count = ref_count[tensor] - 1; - ref_count[tensor] = count; - if (count == 0) { - allocator.Free(offset_map[tensor]); - } - } - } - } - // Set Tensor data - mem_size_ = allocator.total_size(); - if (mem_size_ > 0) { - memory_base_addr_ = malloc(mem_size_); - if (memory_base_addr_ == nullptr) { - MS_LOG(EXCEPTION) << "Allocation of " << mem_size_ << "B on device failed"; - return kMDOutOfMemory; - } - for (auto &kernel : kernels_) { - // allocate graph inputs - for (auto &tensor : kernel->in_tensors()) { - if (tensor->category() == lite::Category::GRAPH_INPUT) { - auto it = offset_map.find(tensor); - if (it != offset_map.end()) { - tensor->set_data(reinterpret_cast(reinterpret_cast(memory_base_addr_) + it->second)); - } - } - } - // allocate activation - for (auto &tensor : kernel->out_tensors()) { - auto it = offset_map.find(tensor); - if (it != offset_map.end()) { - tensor->set_data(reinterpret_cast(reinterpret_cast(memory_base_addr_) + it->second)); - } - } - } - } - return kSuccess; -} - -Status AscendNativeSession::AllocateGraphTensors() { - if (memory_base_addr_ == nullptr) { - InitializeTensorRefrenceCnt(); - } else { - free(memory_base_addr_); - memory_base_addr_ = nullptr; - } - return AllocTensors(); -} - -std::shared_ptr AscendNativeSession::GetDeviceInfo(const std::shared_ptr &context) { - auto device_list = context->MutableDeviceInfo(); - auto ascend_info_iter = std::find_if( - device_list.begin(), device_list.end(), [&](std::shared_ptr &device_info) { - return (device_info && device_info->GetDeviceType() == kAscend && device_info->GetProvider() == "Ascend_native"); - }); - if (ascend_info_iter == device_list.end()) { - MS_LOG(ERROR) << "AscendDeviceInfo is not set. If using distributed inference, make sure device_id " - "and rank_id are set in AscendDeviceInfo"; - return nullptr; - } - auto device_info = *(ascend_info_iter); - return device_info->Cast(); -} - -Status AscendNativeSession::Init(const std::shared_ptr &context, const ConfigInfos &config_info) { - MS_LOG(INFO) << "AscendNativeSession::Init"; - context_ = ContextUtils::Convert(context.get()); - auto ascend_info = GetDeviceInfo(context); - if (ascend_info != nullptr) { - std::string rank_table_file = ""; - uint32_t device_id = ascend_info->GetDeviceID(); - int rank_id = static_cast(ascend_info->GetRankID()); - std::string s_rank_id = std::to_string(rank_id); - bool ret = hccl::HcclAdapter::GetInstance().InitHccl(device_id, s_rank_id, rank_table_file, hccl::HcclMode::kGraph); - if (ret != kSuccess) { - MS_LOG(ERROR) << "HcclAdapter::initHccl failed"; - } - } - return kSuccess; -} - -Status AscendNativeSession::FindGraphInputs(const std::vector &node_list, - const std::vector &graph_inputs, - const std::vector> &kernels) { - if (graph_inputs.empty()) { - MS_LOG(ERROR) << "DefaultGraphCompiler::Schedule get graph inputs node failed"; - return kLiteError; - } - - size_t found_input_node = 0; - this->inputs_.resize(graph_inputs.size()); - int kernel_id = 0; - std::unordered_set input_hash; - for (size_t ni = 0; ni < node_list.size(); ni++) { - auto &node = node_list[ni]; - if (!node->isa() || !AnfUtils::IsRealKernel(node)) { - continue; - } - auto cnode = utils::cast(node); - for (size_t i = 1; i < cnode->size(); i++) { - for (size_t j = 0; j < graph_inputs.size(); j++) { - if (cnode->input(i) == graph_inputs[j]) { - this->inputs_[j] = kernels[kernel_id]->in_tensors().at(i - 1); - this->inputs_[j]->set_category(lite::GRAPH_INPUT); - if (input_hash.find(cnode->input(i)) == input_hash.end()) { - input_hash.insert(cnode->input(i)); - found_input_node++; - break; - } - } - } - } - kernel_id++; - } - if (found_input_node != graph_inputs.size()) { - MS_LOG(ERROR) << "Can not find corresponding anfnode for all funcgraph inputs."; - return kLiteError; - } - return kSuccess; -} - -Status AscendNativeSession::FindGraphOutputs(const std::vector &node_list, const AnfNodePtr &graph_output, - const std::vector> &kernels) { - if (graph_output == nullptr) { - MS_LOG(ERROR) << "get graph output node failed."; - return kLiteError; - } - const PrimitiveSet prims{prim::kPrimTupleGetItem, prim::kPrimListGetItem, prim::kPrimArrayGetItem, - prim::kPrimMakeTuple}; - auto cnode = utils::cast(graph_output); - if (cnode == nullptr) { - MS_LOG(ERROR) << "ascend_native delegate not support empty subgraph now."; - return kLiteError; - } - auto prim_vnode = cnode->input(0); - if (IsOneOfPrimitive(prim_vnode, prims)) { - MS_LOG(ERROR) << "ascend_native delegate not support maketuple and tuple-get-item operator now."; - return kLiteError; - } - int kernel_id = 0; - for (size_t ni = 0; ni < node_list.size(); ni++) { - auto &node = node_list[ni]; - if (!node->isa() || !AnfUtils::IsRealKernel(node)) { - continue; - } - if (node == graph_output) { - for (auto &output : kernels_[kernel_id]->out_tensors()) { // TBD do kernel hash map - this->outputs_.emplace_back(output); - output->set_category(lite::GRAPH_OUTPUT); - } - break; - } - kernel_id++; - } - return kSuccess; -} - -Status AscendNativeSession::CompileGraph(FuncGraphPtr func_graph, const void *data, size_t size, uint32_t *graph_id) { - MS_LOG(INFO) << "AscendNativeSession::CompileGraph"; - if (delegate_ == nullptr) { - MS_LOG(ERROR) << "ascend_native delegate not inited"; - return kLiteNullptr; - } - - delegate_->set_ascend_native_ctx(context_); - ascend_native_stream_ = ascend_native::CreateStream(); - // call delegate replace nodes make the delegate replace the graph nodes - delegate_->ReplaceNodes(func_graph); - auto nodes = TopoSort(func_graph->get_return()); - // for all the nodes in the graph, call the delegate isDelegateNode and CreateKernel interface to create kernels - for (auto &node : nodes) { - if (!node->isa() || !AnfUtils::IsRealKernel(node)) { - continue; - } - - auto kernel = delegate_->CreateKernel(node); - if (kernel == nullptr) { - MS_LOG(ERROR) << "delegate create kernel failed."; - return kLiteError; - } - kernels_.emplace_back(kernel); - } - if (kernels_.empty()) { - MS_LOG(ERROR) << "delegate not support empty subgraph now."; - return kLiteError; - } - - auto findio_ret = FindGraphInputs(nodes, func_graph->get_inputs(), kernels_); - if (findio_ret != kSuccess) { - MS_LOG(ERROR) << "Search graph input tensors failed."; - return findio_ret; - } - findio_ret = FindGraphOutputs(nodes, func_graph->output(), kernels_); - if (findio_ret != kSuccess) { - MS_LOG(ERROR) << "Search graph output tensors failed."; - return findio_ret; - } - if (AllocateGraphTensors() != kSuccess) { - MS_LOG(ERROR) << "kernel graph allocation failed "; - return kLiteError; - } - // call kernel prepare - for (auto &kernel : kernels_) { - auto ret = kernel->Prepare(); - if (ret != kSuccess) { - MS_LOG(ERROR) << "kernel prepare failed with " << ret; - return kLiteError; - } - } - return kSuccess; -} - -void AscendNativeSession::ResetTensorData(const std::vector &old_data, - const std::vector &tensors) { - for (size_t j = 0; j < old_data.size(); j++) { - tensors.at(j)->set_data(old_data.at(j)); - } -} - -Status AscendNativeSession::RefDataFromOuter(const std::vector &outer_tensors) { - const std::vector &inner_tensors = this->inputs_; - if (outer_tensors.size() != inner_tensors.size()) { - MS_LOG(EXCEPTION) << "user input size " << outer_tensors.size() << " is not equal to graph input size " - << inner_tensors.size(); - } - std::vector old_data; - - for (size_t i = 0; i < outer_tensors.size(); i++) { - auto &user_input = outer_tensors.at(i); - auto input = inner_tensors.at(i); - if (user_input.data_type() != input->data_type()) { - ResetTensorData(old_data, inner_tensors); - MS_LOG(ERROR) << "Tensor " << user_input.id() << " has a different data type from input" << input->tensor_name() - << "."; - return kLiteError; - } - if (user_input.data_c() == nullptr) { - ResetTensorData(old_data, inner_tensors); - MS_LOG(ERROR) << "Tensor " << user_input.id() << " has no data."; - return kLiteError; - } - old_data.push_back(input->data()); - if (input->data_type() == kObjectTypeString) { - MS_LOG(ERROR) << "Not support string type tensor now!"; - return kLiteError; - } - if (user_input.data_c() != input->data()) { - if (input->Size() != user_input.Size()) { - ResetTensorData(old_data, inner_tensors); - MS_LOG(ERROR) << "Tensor " << user_input.id() << " has wrong data size."; - return kLiteError; - } - input->set_data(user_input.data_c(), false); - } - } - return kSuccess; -} - -std::vector AscendNativeSession::LiteTensorToTensor() { - const std::vector &inner_tensors = this->outputs_; - std::vector tensors; - for (auto inner_tensor : inner_tensors) { - if (inner_tensor == nullptr) { - MS_LOG(ERROR) << "Input inner_tensors has nullptr."; - return std::vector{}; - } - auto type_id = inner_tensor->data_type(); - auto shape = inner_tensor->shape(); - auto data = inner_tensor->MutableData(); - auto data_size = inner_tensor->Size(); - auto ref_tensor_data = std::make_shared(data, inner_tensor->ElementsNum(), data_size, shape.size()); - std::vector shape64; - std::transform(shape.begin(), shape.end(), std::back_inserter(shape64), - [](int dim) { return static_cast(dim); }); - mindspore::tensor::Tensor tensor(type_id, shape64, ref_tensor_data); - tensors.emplace_back(std::move(tensor)); - } - return tensors; -} - -Status AscendNativeSession::RunGraph(uint32_t graph_id, const std::vector &inputs, - std::vector *outputs, const MSKernelCallBack &before, - const MSKernelCallBack &after) { - MS_LOG(INFO) << "AscendNativeSession::RunGraph"; - - // get inputs and outputs tensors, set the data ptr for inputs - auto ret = RefDataFromOuter(inputs); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Sync tensor data from use tensor failed: " << ret; - return ret; - } - // call kernel run interface one by one - for (auto &kernel : kernels_) { - auto exec_ret = kernel->Run(); - if (exec_ret != kSuccess) { - MS_LOG(ERROR) << "kernel Run failed with " << exec_ret; - return kLiteError; - } - } - // synchronize all tasks are finished - *outputs = LiteTensorToTensor(); - if (outputs->size() != this->outputs_.size()) { - MS_LOG(ERROR) << "Convert output tensors failed"; - return kLiteNullptr; - } - ascend_native::SyncDevice(ascend_native_stream_); - return kSuccess; -} - -Status AscendNativeSession::RunGraph(uint32_t graph_id, const std::vector &inputs, - std::vector *outputs) { - return RunGraph(graph_id, inputs, outputs, nullptr, nullptr); -} - -Status AscendNativeSession::Resize(uint32_t graph_id, const std::vector &inputs, - const std::vector> &new_shapes) { - MS_LOG(EXCEPTION) << "AscendNativeSession::Resize not implemented"; -} - -std::vector AscendNativeSession::GetOutputs(uint32_t graph_id) { - std::vector result; - std::transform(this->outputs_.begin(), this->outputs_.end(), std::back_inserter(result), - [](infer::abstract::Tensor *output) { return std::make_shared(output); }); - return result; -} -std::vector AscendNativeSession::GetInputs(uint32_t graph_id) { - std::vector result; - std::transform(this->inputs_.begin(), this->inputs_.end(), std::back_inserter(result), - [](infer::abstract::Tensor *input) { return std::make_shared(input); }); - return result; -} -std::vector AscendNativeSession::GetOutputNames(uint32_t graph_id) { - std::vector output_names; - - auto lite_outputs = this->outputs_; - std::transform(lite_outputs.begin(), lite_outputs.end(), std::back_inserter(output_names), - [](infer::abstract::Tensor *tensor) { return tensor->tensor_name(); }); - return output_names; -} -std::vector AscendNativeSession::GetInputNames(uint32_t graph_id) { - std::vector input_names; - auto lite_inputs = this->inputs_; - std::transform(lite_inputs.begin(), lite_inputs.end(), std::back_inserter(input_names), - [](infer::abstract::Tensor *tensor) { return tensor->tensor_name(); }); - return input_names; -} -MutableTensorImplPtr AscendNativeSession::GetOutputByTensorName(uint32_t graph_id, const std::string &tensorName) { - auto lite_outputs = this->outputs_; - auto it = std::find_if(lite_outputs.begin(), lite_outputs.end(), [tensorName](infer::abstract::Tensor *tensor) { - if (tensor->tensor_name() == tensorName) { - return true; - } - return false; - }); - if (it != lite_outputs.end()) { - return std::make_shared(*it); - } - return nullptr; -} -MutableTensorImplPtr AscendNativeSession::GetInputByTensorName(uint32_t graph_id, const std::string &tensorName) { - auto lite_inputs = this->inputs_; - auto it = std::find_if(lite_inputs.begin(), lite_inputs.end(), [tensorName](infer::abstract::Tensor *tensor) { - if (tensor->tensor_name() == tensorName) { - return true; - } - return false; - }); - if (it != lite_inputs.end()) { - return std::make_shared(*it); - } - return nullptr; -} - -static std::shared_ptr AscendNativeSessionCreator(const std::shared_ptr &ctx, - const ConfigInfos &config_infos) { - auto &device_contexts = ctx->MutableDeviceInfo(); - if (device_contexts.empty()) { - return nullptr; - } - - auto provider = device_contexts.at(0)->GetProvider(); - auto delegate = std::make_shared(); - if (delegate == nullptr) { - return nullptr; - } - auto session = std::make_shared(delegate); - constexpr auto kAscendProviderAscendNative = "ascend_native"; - - if (provider == kAscendProviderAscendNative) { - session->Init(ctx); - } - return session; -} - -REG_SESSION(kAscendNativeSession, AscendNativeSessionCreator); -} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/session/ascend_native_session.h b/mindspore-lite/src/extendrt/session/ascend_native_session.h deleted file mode 100644 index 87102c4caeb5c068f06c127005ff495cf20ee91f..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/session/ascend_native_session.h +++ /dev/null @@ -1,86 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_EXTENDRT_SESSION_ASCEND_NATIVE_SESSION_H_ -#define MINDSPORE_LITE_EXTENDRT_SESSION_ASCEND_NATIVE_SESSION_H_ - -#include -#include -#include -#include -#include - -#include "extendrt/infer_session.h" -#include "extendrt/delegate/type.h" -#include "extendrt/delegate/ascend_native/delegate.h" -#include "extendrt/delegate/ascend_native/ascend_native_impl/utils.h" -#include "extendrt/delegate/factory.h" -#include "infer/kernel.h" -#include "infer/tensor.h" -#include "infer/context.h" - -namespace mindspore { -class AscendNativeSession : public InferSession { - public: - AscendNativeSession() = default; - explicit AscendNativeSession(std::shared_ptr delegate) - : delegate_(std::move(delegate)) {} - ~AscendNativeSession() override = default; - - Status Init(const std::shared_ptr &context, const ConfigInfos &config_info = {}) override; - Status CompileGraph(FuncGraphPtr graph, const void *data = nullptr, size_t size = 0, - uint32_t *graph_id = nullptr) override; - Status RunGraph(uint32_t graph_id, const std::vector &inputs, std::vector *outputs, - const MSKernelCallBack &before, const MSKernelCallBack &after) override; - Status RunGraph(uint32_t graph_id, const std::vector &inputs, - std::vector *outputs) override; - Status Resize(uint32_t graph_id, const std::vector &inputs, - const std::vector> &dims) override; - std::vector GetOutputs(uint32_t graph_id) override; - std::vector GetInputs(uint32_t graph_id) override; - std::vector GetOutputNames(uint32_t graph_id) override; - std::vector GetInputNames(uint32_t graph_id) override; - MutableTensorImplPtr GetOutputByTensorName(uint32_t graph_id, const std::string &tensorName) override; - MutableTensorImplPtr GetInputByTensorName(uint32_t graph_id, const std::string &name) override; - - private: - Status FindGraphInputs(const std::vector &node_list, const std::vector &graph_inputs, - const std::vector> &kernels); - Status FindGraphOutputs(const std::vector &node_list, const AnfNodePtr &graph_output, - const std::vector> &kernels); - Status MoveDataFromHostToDevice(void *sd, bool s_fp16, void *dd, bool d_fp16, size_t elem_num); - Status MoveDataFromDeviceToHost(void *sd, bool s_fp16, void *dd, bool d_fp16, size_t elem_num); - void *MallocDevice(size_t size); - void FreeDevice(void *ptr); - - Status RefDataFromOuter(const std::vector &outer_tensors); - void ResetTensorData(const std::vector &old_data, const std::vector &tensors); - std::vector LiteTensorToTensor(); - void InitializeTensorRefrenceCnt(); - Status AllocTensors(); - Status AllocateGraphTensors(); - std::shared_ptr GetDeviceInfo(const std::shared_ptr &context); - - std::shared_ptr delegate_; - std::vector> kernels_; - std::vector inputs_; - std::vector outputs_; - std::shared_ptr context_; - size_t mem_size_ = 0; - void *memory_base_addr_ = nullptr; - void *ascend_native_stream_ = nullptr; -}; -} // namespace mindspore -#endif // MINDSPORE_LITE_EXTENDRT_SESSION_ASCEND_NATIVE_SESSION_H_ diff --git a/mindspore-lite/src/extendrt/session/default_session.cc b/mindspore-lite/src/extendrt/session/default_session.cc deleted file mode 100644 index faee07e1254b77906a7d6c3f21e5ed04aea13866..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/session/default_session.cc +++ /dev/null @@ -1,392 +0,0 @@ -/** - * Copyright 2019-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include "extendrt/session/default_session.h" -#include "mindspore/ops/op_def/sequence_ops.h" -#include "mindspore/ops/op_def/nn_ops.h" -#include "mindspore/ops/op_def/framework_ops.h" -#include "common/ms_factory.h" -#include "extendrt/session/factory.h" -#include "extendrt/graph_compiler/factory.h" -#include "extendrt/graph_runtime/factory.h" -#include "extendrt/utils/tensor_utils.h" -#include "backend/graph_compiler/graph_partition.h" -#include "common/tensor_util.h" -#include "litert/cxx_api/tensor/tensor_impl.h" - -namespace mindspore { -Status DefaultInferSession::Init(const std::shared_ptr &context, const ConfigInfos &config_info) { - MS_LOG(DEBUG) << "Init default session begin"; - - // init compiler and runtime according to context - compiler_ = GraphCompilerRegistry::GetInstance().GetCompiler(kDefaultCompiler, context); - if (compiler_ == nullptr) { - MS_LOG(ERROR) << "Get Compiler is nullptr"; - return kLiteNullptr; - } - - runtime_ = GraphRuntimeRegistry::GetInstance().GetRuntime(kDefaultRuntime); - if (runtime_ == nullptr) { - MS_LOG(ERROR) << "Get Runtime is nullptr"; - return kLiteNullptr; - } - - MS_LOG(DEBUG) << "Init default session end"; - return kSuccess; -} - -Status DefaultInferSession::CompileGraph(FuncGraphPtr graph, const void *data, size_t size, uint32_t *) { - MS_LOG(DEBUG) << "Compile graph begin"; - auto compiler = this->GetGraphCompiler(); - if (compiler == nullptr) { - MS_LOG(ERROR) << "Compiler in session is null"; - return kLiteNullptr; - } - auto execution_plan = compiler->Compile(graph); - if (execution_plan == nullptr) { - MS_LOG(ERROR) << "Compile graph failed, execution plan is null"; - return kLiteNullptr; - } - MS_LOG(DEBUG) << "Compile graph end"; - - MS_LOG(DEBUG) << "Prepare execution plan begin"; - auto runtime = this->GetGraphRuntime(); - if (runtime == nullptr) { - MS_LOG(ERROR) << "Runtime in session is null"; - return kLiteNullptr; - } - auto status = runtime->Prepare(execution_plan); - if (status != kSuccess) { - MS_LOG(ERROR) << "Prepare graph runtime failed"; - return status; - } - MS_LOG(DEBUG) << "Prepare execution plan end"; - - return kSuccess; -} - -Status DefaultInferSession::RunGraph(uint32_t, const std::vector &inputs, - std::vector *outputs, const MSKernelCallBack &before, - const MSKernelCallBack &after) { - MS_LOG(DEBUG) << "Run execution plan begin"; - auto runtime = this->GetGraphRuntime(); - if (runtime_ == nullptr) { - MS_LOG(ERROR) << "Graph Runtime in session is null"; - return kLiteNullptr; - } - // Copy user input data to graph input tensor - auto inner_inputs = runtime->GetInputs(); - auto inner_outputs = runtime->GetOutputs(); - auto status = CopyDataToInnerTensors(inputs, inner_inputs); - if (status != kSuccess) { - MS_LOG(ERROR) << "Copy data pointer to input tensors failed"; - return status; - } - - // Convert api kernel callback to inner callback - infer::abstract::KernelCallBack before_call_back = nullptr; - infer::abstract::KernelCallBack after_call_back = nullptr; - if (before != nullptr) { - before_call_back = [&](const std::vector &before_inputs, - const std::vector &before_outputs, - const MSCallBackParam &call_param) { - std::vector inputs = lite::LiteTensorsToMSTensors(before_inputs); - std::vector outputs = lite::LiteTensorsToMSTensors(before_outputs); - return before(inputs, outputs, call_param); - }; - } - - if (after != nullptr) { - after_call_back = [&](const std::vector &before_inputs, - const std::vector &before_outputs, - const MSCallBackParam &call_param) { - std::vector inputs = lite::LiteTensorsToMSTensors(before_inputs); - std::vector outputs = lite::LiteTensorsToMSTensors(before_outputs); - return after(inputs, outputs, call_param); - }; - } - - status = runtime->Execute(inner_inputs, inner_outputs, before_call_back, after_call_back); - if (status != kSuccess) { - MS_LOG(ERROR) << "Graph runtime execute Failed"; - return status; - } - - // Convert graph output tensor to user output tensor - *outputs = LiteTensorToTensor(inner_outputs); - if (outputs->size() != inner_outputs.size()) { - MS_LOG(ERROR) << "Convert output tensors failed"; - return kLiteError; - } - MS_LOG(DEBUG) << "Execute execution plan End"; - - return kSuccess; -} - -Status DefaultInferSession::RunGraph(uint32_t graph_id, const std::vector &inputs, - std::vector *outputs) { - return RunGraph(graph_id, inputs, outputs, nullptr, nullptr); -} - -Status DefaultInferSession::Resize(uint32_t, const std::vector &inputs, - const std::vector> &dims) { - MS_LOG(DEBUG) << "Graph Resize begin"; - auto runtime = this->GetGraphRuntime(); - if (runtime_ == nullptr) { - MS_LOG(ERROR) << "Graph runtime in session is null"; - return kLiteNullptr; - } - - auto inner_inputs = runtime->GetInputs(); - auto status = runtime->Resize(inner_inputs, dims); - if (status != kSuccess) { - MS_LOG(ERROR) << "Graph runtime resize failed"; - return status; - } - MS_LOG(DEBUG) << "Graph Resize end"; - - return kSuccess; -} - -std::vector DefaultInferSession::GetOutputs(uint32_t) { - auto runtime = this->GetGraphRuntime(); - if (runtime == nullptr) { - MS_LOG(ERROR) << "Graph runtime in session is null"; - return std::vector{}; - } - auto lite_outputs = runtime->GetOutputs(); - return AbstractTensorsToTensorImpls(lite_outputs); -} - -std::vector DefaultInferSession::GetInputs(uint32_t) { - auto runtime = this->GetGraphRuntime(); - if (runtime == nullptr) { - MS_LOG(ERROR) << "Graph runtime in session is null"; - return std::vector{}; - } - auto lite_inputs = runtime->GetInputs(); - return AbstractTensorsToTensorImpls(lite_inputs); -} - -std::vector DefaultInferSession::GetOutputNames(uint32_t graph_id) { - auto runtime = this->GetGraphRuntime(); - if (runtime == nullptr) { - MS_LOG(ERROR) << "Graph runtime in session is null"; - return std::vector{}; - } - std::vector output_names; - auto lite_outputs = runtime->GetOutputs(); - std::transform(lite_outputs.begin(), lite_outputs.end(), std::back_inserter(output_names), - [](infer::abstract::Tensor *tensor) { return tensor->tensor_name(); }); - return output_names; -} - -std::vector DefaultInferSession::GetInputNames(uint32_t graph_id) { - auto runtime = this->GetGraphRuntime(); - if (runtime == nullptr) { - MS_LOG(ERROR) << "Graph runtime in session is null"; - return std::vector{}; - } - std::vector input_names; - auto lite_inputs = runtime->GetInputs(); - std::transform(lite_inputs.begin(), lite_inputs.end(), std::back_inserter(input_names), - [](infer::abstract::Tensor *tensor) { return tensor->tensor_name(); }); - return input_names; -} - -MutableTensorImplPtr DefaultInferSession::GetOutputByTensorName(uint32_t graph_id, const std::string &tensorName) { - auto runtime = this->GetGraphRuntime(); - if (runtime == nullptr) { - MS_LOG(ERROR) << "Graph runtime in session is null"; - return nullptr; - } - auto lite_outputs = runtime->GetOutputs(); - auto it = std::find_if(lite_outputs.begin(), lite_outputs.end(), [tensorName](infer::abstract::Tensor *tensor) { - if (tensor->tensor_name() == tensorName) { - return true; - } - return false; - }); - if (it != lite_outputs.end()) { - return std::make_shared(*it); - } - return nullptr; -} - -MutableTensorImplPtr DefaultInferSession::GetInputByTensorName(uint32_t graph_id, const std::string &name) { - auto runtime = this->GetGraphRuntime(); - if (runtime == nullptr) { - MS_LOG(ERROR) << "Graph runtime in session is null"; - return nullptr; - } - auto lite_inputs = runtime->GetInputs(); - auto it = std::find_if(lite_inputs.begin(), lite_inputs.end(), [name](infer::abstract::Tensor *tensor) { - if (tensor->tensor_name() == name) { - return true; - } - return false; - }); - if (it != lite_inputs.end()) { - return std::make_shared(*it); - } - return nullptr; -} - -void DefaultInferSession::ResetTensorData(const std::vector &old_data, - const std::vector &tensors) { - for (size_t j = 0; j < old_data.size(); j++) { - tensors.at(j)->set_data(old_data.at(j)); - } -} - -Status DefaultInferSession::CopyDataToInnerTensors(const std::vector &tensors, - std::vector inner_tensors) { - if (tensors.size() != inner_tensors.size()) { - MS_LOG(ERROR) << "user input size " << tensors.size() << " is not equal to graphp input size " - << inner_tensors.size(); - return kLiteError; - } - std::vector old_data; - for (size_t i = 0; i < tensors.size(); i++) { - auto &user_input = tensors.at(i); - auto input = inner_tensors.at(i); - if (user_input.data_type() != input->data_type()) { - ResetTensorData(old_data, inner_tensors); - MS_LOG(ERROR) << "Tensor " << user_input.id() << " has a different data type from input" << input->tensor_name() - << "."; - return kLiteError; - } - if (user_input.data_c() == nullptr) { - ResetTensorData(old_data, inner_tensors); - MS_LOG(ERROR) << "Tensor " << user_input.id() << " has no data."; - return kLiteError; - } - old_data.push_back(input->data()); - if (input->data_type() == kObjectTypeString) { - std::vector shape = - TruncateShape(user_input.shape_c(), input->data_type(), user_input.DataSize(), false); - if (shape.empty() && !(user_input.shape_c().empty())) { - ResetTensorData(old_data, inner_tensors); - MS_LOG(ERROR) << "Input dims of tensor " << user_input.id() << " is invalid."; - return kLiteError; - } - input->set_shape(shape); - input->set_data(user_input.data_c(), false); - } else { - if (user_input.data_c() != input->data()) { - if (input->Size() != user_input.Size()) { - ResetTensorData(old_data, inner_tensors); -#ifndef ENABLE_LITE_ACL - MS_LOG(ERROR) << "Tensor " << user_input.id() << " has wrong data size."; - return kLiteError; -#else - MS_LOG(WARNING) << "Please check tensor " << user_input.id() - << " has been modified data size by DVPP method."; - std::vector truncate_shape = {static_cast(user_input.DataSize())}; - input->set_shape(truncate_shape); -#endif - } - input->set_data(user_input.data_c(), false); - } - } - } - - return kSuccess; -} - -std::vector DefaultInferSession::AbstractTensorsToTensorImpls( - const std::vector &abstract_tensors) { - std::vector tensorImpls; - tensorImpls.reserve(abstract_tensors.size()); - (void)std::transform(abstract_tensors.begin(), abstract_tensors.end(), std::back_inserter(tensorImpls), - [](infer::abstract::Tensor *tensor) { return std::make_shared(tensor); }); - return tensorImpls; -} - -std::vector DefaultInferSession::LiteTensorToTensor( - const std::vector &abstract_tensors) { - std::vector tensors; - - for (auto abstract_tensor : abstract_tensors) { - if (abstract_tensor == nullptr) { - MS_LOG(ERROR) << "get nullptr tensor"; - return std::vector{}; - } - auto type_id = abstract_tensor->data_type(); - auto shape = abstract_tensor->shape(); - auto data = abstract_tensor->MutableData(); - auto data_size = abstract_tensor->Size(); - auto ref_tensor_data = - std::make_shared(data, abstract_tensor->ElementsNum(), data_size, shape.size()); - std::vector shape64; - std::transform(shape.begin(), shape.end(), std::back_inserter(shape64), - [](int dim) { return static_cast(dim); }); - - mindspore::tensor::Tensor tensor(type_id, shape64, ref_tensor_data); - auto device_address = abstract_tensor->device_data(); - if (device_address != nullptr) { - auto lite_device_address = std::make_shared(device_address, abstract_tensor->Size()); - tensor.set_device_address(lite_device_address); - } - tensors.emplace_back(std::move(tensor)); - } - - return tensors; -} - -std::vector DefaultInferSession::TruncateShape(const std::vector &shape, enum TypeId type, - size_t data_len, bool verify_size) { - std::vector empty; - if (shape.empty()) { - return empty; - } - std::vector truncated_shape; - truncated_shape.resize(shape.size()); - size_t element_size = lite::DataTypeSize(type); - for (size_t i = 0; i < shape.size(); i++) { - auto dim = shape[i]; - if (dim < 0 || dim > INT_MAX || (dim != 0 && element_size > INT_MAX / static_cast(dim))) { - MS_LOG(ERROR) << "Invalid shape!dim: " << dim << ", element_size: " << element_size; - return empty; - } else { - element_size *= static_cast(dim); - truncated_shape[i] = static_cast(dim); - } - } - if (verify_size) { - if (element_size != data_len) { - MS_LOG(ERROR) << "Invalid data size!element_size: " << element_size << ", data_len: " << data_len; - return empty; - } - } - return truncated_shape; -} - -static std::shared_ptr DefaultSessionCreator(const std::shared_ptr &ctx, - const ConfigInfos &config_infos) { - auto session = std::make_shared(ctx); - auto ret = session->Init(ctx, config_infos); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Init default session failed."; - return nullptr; - } - return session; -} -REG_SESSION(kDefaultSession, DefaultSessionCreator); -} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/session/default_session.h b/mindspore-lite/src/extendrt/session/default_session.h deleted file mode 100644 index bad51d36a7c2dbe03bf640a760e58df4db2f09a7..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/session/default_session.h +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright 2019-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_EXTENDRT_SESSION_DEFAULT_SESSION_H_ -#define MINDSPORE_LITE_EXTENDRT_SESSION_DEFAULT_SESSION_H_ - -#include -#include -#include -#include - -#include "extendrt/infer_session.h" - -#include "infer/graph_compiler.h" -#include "infer/graph_runtime.h" - -namespace mindspore { -class DefaultInferSession : public InferSession { - public: - explicit DefaultInferSession(const std::shared_ptr &context) : context_(context) {} - virtual ~DefaultInferSession() = default; - Status Init(const std::shared_ptr &context, const ConfigInfos &config_info = {}) override; - Status CompileGraph(FuncGraphPtr graph, const void *data = nullptr, size_t size = 0, - uint32_t *graph_id = nullptr) override; - Status RunGraph(uint32_t graph_id, const std::vector &inputs, - std::vector *outputs) override; - Status RunGraph(uint32_t graph_id, const std::vector &inputs, std::vector *outputs, - const MSKernelCallBack &before, const MSKernelCallBack &after) override; - Status Resize(uint32_t graph_id, const std::vector &inputs, - const std::vector> &dims) override; - std::vector GetOutputs(uint32_t graph_id) override; - std::vector GetInputs(uint32_t graph_id) override; - std::vector GetOutputNames(uint32_t graph_id) override; - std::vector GetInputNames(uint32_t graph_id) override; - MutableTensorImplPtr GetOutputByTensorName(uint32_t graph_id, const std::string &tensorName) override; - MutableTensorImplPtr GetInputByTensorName(uint32_t graph_id, const std::string &name) override; - - protected: - virtual std::shared_ptr GetGraphCompiler() { return compiler_; } - - virtual std::shared_ptr GetGraphRuntime() { return runtime_; } - - void ResetTensorData(const std::vector &old_data, const std::vector &tensors); - - private: - Status CopyDataToInnerTensors(const std::vector &tensors, - std::vector inner_tensors); - std::vector AbstractTensorsToTensorImpls( - const std::vector &abstract_tensors); - std::vector LiteTensorToTensor( - const std::vector &abstract_tensors); - std::vector TruncateShape(const std::vector &shape, enum TypeId type, size_t data_len, - bool verify_size); - - private: - std::shared_ptr compiler_; - - std::shared_ptr runtime_; - - const std::shared_ptr &context_; -}; -} // namespace mindspore -#endif // MINDSPORE_LITE_EXTENDRT_SESSION_DEFAULT_SESSION_H_ diff --git a/mindspore-lite/src/extendrt/session/delegate_session.cc b/mindspore-lite/src/extendrt/session/delegate_session.cc index ca71857410b7c45e2b20a8272d3aa533b8e284f7..862e14fd0e10e83822a52a50efdda3fe067632ae 100644 --- a/mindspore-lite/src/extendrt/session/delegate_session.cc +++ b/mindspore-lite/src/extendrt/session/delegate_session.cc @@ -20,7 +20,6 @@ #include #include #include "extendrt/utils/tensor_utils.h" -#include "src/extendrt/utils/kernel_build_utils.h" #include "extendrt/delegate/factory.h" #include "extendrt/session/factory.h" #include "extendrt/utils/tensor_default_impl.h" @@ -28,10 +27,9 @@ #include "src/extendrt/delegate/plugin/ascend_ge_executor_plugin.h" #include "extendrt/utils/func_graph_utils.h" #include "common/common.h" - +#include "src/extendrt/session/lite_graph_executor.h" namespace mindspore { namespace { -constexpr auto kDataFlowGraphType = "data_flow"; constexpr auto kIsAdapted = "is_adapted"; std::mutex kernel_graph_mutex; @@ -41,7 +39,6 @@ std::mutex g_build_graph_mutex; GraphSinkSession::~GraphSinkSession() = default; Status GraphSinkSession::Init(const std::shared_ptr &context, const ConfigInfos &config_info) { - MS_LOG(INFO) << "GraphSinkSession::Init"; if (graph_executor_ == nullptr) { MS_LOG(ERROR) << "GraphSinkSession::Init failed, graph executor is nullptr."; return kLiteUninitializedObj; @@ -52,7 +49,6 @@ Status GraphSinkSession::Init(const std::shared_ptr &context, const Con } Status GraphSinkSession::CompileGraph(const void *model_data, size_t data_size, uint32_t *graph_id) { - MS_LOG(INFO) << "GraphSinkSession::CompileGraph"; // This lock can be removed when LiteRT supports concurrent multithreading compilation. std::lock_guard lock(g_build_graph_mutex); auto ret = graph_executor_->CompileGraph(model_data, data_size, options_, graph_id); @@ -75,21 +71,19 @@ Status GraphSinkSession::CompileGraph(FuncGraphPtr graph, const void *data, size // This lock can be removed when LiteRT supports concurrent multithreading compilation. std::lock_guard lock(g_build_graph_mutex); // kernel graph will be removed from GraphSinkSession, and this code will be moved to TensorRT plugin - auto func_type = graph->get_attr(kAttrFuncType); - is_data_flow_graph_ = func_type != nullptr && GetValue(func_type) == kDataFlowGraphType; if (context_ && !context_->MutableDeviceInfo().empty()) { auto device_info = context_->MutableDeviceInfo()[0]; bool is_ge_backend = device_info && device_info->GetDeviceType() == DeviceType::kAscend && device_info->GetProvider() == lite::kAscendProviderGe; bool is_adapted = graph->has_attr(kIsAdapted); // The funcgraph will only adapted once while running parallel. - if (is_ge_backend && !is_adapted && !is_data_flow_graph_) { + if (is_ge_backend && !is_adapted) { lite::AscendGeExecutorPlugin::GetInstance().AdaptGraph(graph); graph->set_attr(kIsAdapted, MakeValue(true)); } } DelegateGraphInfo graph_info; // the funcgraph constructed by flowgraph has no inputs and outputs. - auto status = !is_data_flow_graph_ ? InitGraphInputsOutputs(graph, &graph_info) : kSuccess; + auto status = InitGraphInputsOutputs(graph, &graph_info); if (!status.IsOk()) { MS_LOG(ERROR) << "Failed to get inputs and outputs info from graph"; return status; @@ -99,7 +93,7 @@ Status GraphSinkSession::CompileGraph(FuncGraphPtr graph, const void *data, size MS_LOG(ERROR) << "GraphSinkSession::CompileGraph compile graph failed"; return kCoreFailed; } - status = !is_data_flow_graph_ ? UpdateGraphInputsOutputs(*graph_id, &graph_info) : kSuccess; + status = UpdateGraphInputsOutputs(*graph_id, &graph_info); if (!status.IsOk()) { MS_LOG(ERROR) << "Failed to update inputs and outputs info from graph executor"; return status; @@ -120,9 +114,9 @@ Status GraphSinkSession::InitGraphInfo(DelegateGraphInfo *graph_info_ptr, uint32 info.input_names.clear(); for (size_t i = 0; i < new_inputs.size(); i++) { auto &input = new_inputs[i]; - info.input_names.push_back(input.name()); - auto data_type = static_cast(input.data_type()); - auto impl = std::make_shared(info.input_names[i], data_type, input.shape_c()); + info.input_names.push_back(input.Name()); + auto data_type = static_cast(input.DataType()); + auto impl = std::make_shared(info.input_names[i], data_type, input.Shape()); info.inputs.push_back(impl); } @@ -136,9 +130,9 @@ Status GraphSinkSession::InitGraphInfo(DelegateGraphInfo *graph_info_ptr, uint32 info.output_names.clear(); for (size_t i = 0; i < new_outputs.size(); i++) { auto &output = new_outputs[i]; - info.output_names.push_back(output.name()); - auto data_type = static_cast(output.data_type()); - auto impl = std::make_shared(info.output_names[i], data_type, output.shape_c()); + info.output_names.push_back(output.Name()); + auto data_type = static_cast(output.DataType()); + auto impl = std::make_shared(info.output_names[i], data_type, output.Shape()); info.outputs.push_back(impl); } return kSuccess; @@ -146,7 +140,7 @@ Status GraphSinkSession::InitGraphInfo(DelegateGraphInfo *graph_info_ptr, uint32 Status GraphSinkSession::InitGraphInputsOutputs(const FuncGraphPtr &graph, DelegateGraphInfo *graph_info_ptr) { auto &info = *graph_info_ptr; - std::vector graph_inputs, graph_outputs; + std::vector graph_inputs, graph_outputs; { std::unique_lock l(kernel_graph_mutex); FuncGraphReuseManager::GetInstance()->GetInOut(config_infos_, &graph_inputs, &graph_outputs, &info.input_names, @@ -173,15 +167,15 @@ Status GraphSinkSession::InitGraphInputsOutputs(const FuncGraphPtr &graph, Deleg info.inputs.clear(); for (size_t i = 0; i < info.input_names.size(); i++) { auto &input = graph_inputs[i]; - auto data_type = static_cast(input->data_type()); - auto impl = std::make_shared(info.input_names[i], data_type, input->shape_c()); + auto data_type = static_cast(input->DataType()); + auto impl = std::make_shared(info.input_names[i], data_type, input->Shape()); info.inputs.push_back(impl); } info.outputs.clear(); for (size_t i = 0; i < info.output_names.size(); i++) { auto &output = graph_outputs[i]; - auto data_type = static_cast(output->data_type()); - auto impl = std::make_shared(info.output_names[i], data_type, output->shape_c()); + auto data_type = static_cast(output->DataType()); + auto impl = std::make_shared(info.output_names[i], data_type, output->Shape()); info.outputs.push_back(impl); } return kSuccess; @@ -208,8 +202,8 @@ Status GraphSinkSession::UpdateGraphInputsOutputs(uint32_t graph_id, DelegateGra info.inputs.clear(); for (size_t i = 0; i < new_inputs.size(); i++) { auto &input = new_inputs[i]; - auto data_type = static_cast(input.data_type()); - auto impl = std::make_shared(info.input_names[i], data_type, input.shape_c()); + auto data_type = static_cast(input.DataType()); + auto impl = std::make_shared(info.input_names[i], data_type, input.Shape()); info.inputs.push_back(impl); } } @@ -225,18 +219,17 @@ Status GraphSinkSession::UpdateGraphInputsOutputs(uint32_t graph_id, DelegateGra info.outputs.clear(); for (size_t i = 0; i < new_outputs.size(); i++) { auto &output = new_outputs[i]; - auto data_type = static_cast(output.data_type()); - auto impl = std::make_shared(info.output_names[i], data_type, output.shape_c()); + auto data_type = static_cast(output.DataType()); + auto impl = std::make_shared(info.output_names[i], data_type, output.Shape()); info.outputs.push_back(impl); } } return kSuccess; } -Status GraphSinkSession::RunGraph(uint32_t graph_id, const std::vector &inputs, - std::vector *outputs, const MSKernelCallBack &before, +Status GraphSinkSession::RunGraph(uint32_t graph_id, const std::vector &inputs, + std::vector *outputs, const MSKernelCallBack &before, const MSKernelCallBack &after) { - MS_LOG(INFO) << "GraphSinkSession::RunGraph"; MS_EXCEPTION_IF_NULL(outputs); graph_executor_->SetBefore(before); graph_executor_->SetAfter(after); @@ -251,34 +244,21 @@ Status GraphSinkSession::RunGraph(uint32_t graph_id, const std::vector(input_infos[i]->DataType())) { - MS_LOG(ERROR) << "Input " << i << " data type not match, graph input type " << input_infos[i]->DataType() - << ", given input type " << inputs[i].data_type(); - return kCoreFailed; - } } bool ret = graph_executor_->RunGraph(graph_id, inputs, outputs, options_); if (!ret) { MS_LOG(ERROR) << "GraphSinkSession::RunGraph run graph failed"; return kCoreFailed; } - if (is_data_flow_graph_) { - DelegateGraphInfo graph_info; - if (UpdateGraphInputsOutputs(graph_id, &graph_info) != kSuccess) { - MS_LOG(ERROR) << "Update graph inputs and outputs failed for data flow graph."; - return kCoreFailed; - } - graph_infos_[graph_id] = graph_info; - } return kSuccess; } -Status GraphSinkSession::RunGraph(uint32_t graph_id, const std::vector &inputs, - std::vector *outputs) { +Status GraphSinkSession::RunGraph(uint32_t graph_id, const std::vector &inputs, + std::vector *outputs) { return RunGraph(graph_id, inputs, outputs, nullptr, nullptr); } -Status GraphSinkSession::Resize(uint32_t graph_id, const std::vector &inputs, +Status GraphSinkSession::Resize(uint32_t graph_id, const std::vector &inputs, const std::vector> &new_shapes) { MS_LOG(INFO) << "GraphSinkSession::Resize"; MS_EXCEPTION_IF_NULL(graph_executor_); @@ -290,15 +270,7 @@ Status GraphSinkSession::Resize(uint32_t graph_id, const std::vectorsecond; auto ret = graph_executor_->Resize(graph_id, inputs, new_shapes); if (!ret) { - return kCoreFailed; - } - auto new_outputs = graph_executor_->GetOutputInfos(graph_id); - if (new_outputs.empty()) { - return kSuccess; - } - if (new_outputs.size() != info.outputs.size()) { - MS_LOG(ERROR) << "Output count " << new_outputs.size() << " get from executor != last output count " - << info.outputs.size(); + MS_LOG(ERROR) << "model resize failed."; return kCoreFailed; } for (size_t i = 0; i < new_shapes.size(); i++) { @@ -306,11 +278,6 @@ Status GraphSinkSession::Resize(uint32_t graph_id, const std::vectorSetShape(input_shape); info.inputs[i]->SetData(nullptr, false); // reset data } - for (size_t i = 0; i < info.outputs.size(); i++) { - auto &output = new_outputs[i]; - info.outputs[i]->SetShape(output.shape_c()); - info.outputs[i]->SetData(nullptr, false); // reset data - } return kSuccess; } std::vector GraphSinkSession::GetOutputs(uint32_t graph_id) { @@ -329,7 +296,8 @@ std::vector GraphSinkSession::GetInputs(uint32_t graph_id) return {}; } auto &info = info_it->second; - return info.inputs; + auto input_info = info.inputs; + return input_info; } std::vector GraphSinkSession::GetOutputNames(uint32_t graph_id) { auto info_it = graph_infos_.find(graph_id); @@ -350,7 +318,7 @@ std::vector GraphSinkSession::GetInputNames(uint32_t graph_id) { return info.input_names; } -Status GraphSinkSession::UpdateWeights(const std::vector>> &weights) { +Status GraphSinkSession::UpdateWeights(const std::vector>> &weights) { MS_LOG(INFO) << "UpdateWeights.."; bool ret = graph_executor_->UpdateWeights(weights); if (!ret) { @@ -401,7 +369,6 @@ static std::shared_ptr DelegateSessionCreator(const std::shared_pt } auto device_type = device_contexts.at(0)->GetDeviceType(); auto provider = device_contexts.at(0)->GetProvider(); - auto delegate = DelegateRegistry>::GetInstance().GetDelegate(device_type, provider, ctx, config_infos); if (delegate == nullptr) { diff --git a/mindspore-lite/src/extendrt/session/delegate_session.h b/mindspore-lite/src/extendrt/session/delegate_session.h index 3e6c20205f842886e3e8908722d8648254db9ae3..a4694059dbab303798fecb49bcb8050860ac630a 100644 --- a/mindspore-lite/src/extendrt/session/delegate_session.h +++ b/mindspore-lite/src/extendrt/session/delegate_session.h @@ -27,8 +27,6 @@ namespace mindspore { /// \brief Delegate Session implementation, use delegate api for inference. -// (zhaizhiqiang): use GraphSinkDelegateSession instead of GraphSinkSession in future. -// class GraphSinkDelegateSession struct DelegateGraphInfo { std::vector inputs; std::vector input_names; @@ -48,11 +46,12 @@ class GraphSinkSession : public InferSession { Status CompileGraph(FuncGraphPtr graph, const void *data = nullptr, size_t size = 0, uint32_t *graph_id = nullptr) override; Status CompileGraph(const void *model_data, size_t data_size, uint32_t *graph_id) override; - Status RunGraph(uint32_t graph_id, const std::vector &inputs, std::vector *outputs, - const MSKernelCallBack &before, const MSKernelCallBack &after) override; - Status RunGraph(uint32_t graph_id, const std::vector &inputs, - std::vector *outputs) override; - Status Resize(uint32_t graph_id, const std::vector &inputs, + Status RunGraph(uint32_t graph_id, const std::vector &inputs, + std::vector *outputs, const MSKernelCallBack &before, + const MSKernelCallBack &after) override; + Status RunGraph(uint32_t graph_id, const std::vector &inputs, + std::vector *outputs) override; + Status Resize(uint32_t graph_id, const std::vector &inputs, const std::vector> &dims) override; std::vector GetOutputs(uint32_t graph_id) override; @@ -62,19 +61,21 @@ class GraphSinkSession : public InferSession { MutableTensorImplPtr GetOutputByTensorName(uint32_t graph_id, const std::string &tensorName) override; MutableTensorImplPtr GetInputByTensorName(uint32_t graph_id, const std::string &name) override; void SetConfigInfo(ConfigInfos config_infos) { config_infos_ = config_infos; } - Status UpdateWeights(const std::vector>> &weights) override; + Status UpdateWeights(const std::vector>> &weights) override; + Status Finalize() { + MS_LOG(INFO) << "Finalize is only implemented in single_op_session now."; + graph_executor_->Finalize(); + return kSuccess; + } private: Status InitGraphInfo(DelegateGraphInfo *graph_info_ptr, uint32_t graph_id); Status InitGraphInputsOutputs(const FuncGraphPtr &graph, DelegateGraphInfo *graph_info); Status UpdateGraphInputsOutputs(uint32_t graph_id, DelegateGraphInfo *graph_info); - void UpdateDataFlowGraphInputsOutputs(DelegateGraphInfo *graph_info_ptr, const std::vector &inputs, - const std::vector &outputs); std::shared_ptr graph_executor_; std::map options_; std::map graph_infos_; - bool is_data_flow_graph_ = false; std::shared_ptr context_; ConfigInfos config_infos_; }; diff --git a/mindspore-lite/src/extendrt/session/lite_graph_executor.h b/mindspore-lite/src/extendrt/session/lite_graph_executor.h index 1280d60bc39a4ceaacd75e6c246fa9ba813051b2..54c7bed6a918a38892bf5e6782e69b69152f5618 100644 --- a/mindspore-lite/src/extendrt/session/lite_graph_executor.h +++ b/mindspore-lite/src/extendrt/session/lite_graph_executor.h @@ -23,6 +23,7 @@ #include "include/api/types.h" #include "runtime/hardware/device_context.h" +#include "include/api/status.h" namespace mindspore { /// \brief Adaptive Graph Executor for cloud Graph Executor to solve interface conflicts. @@ -31,28 +32,38 @@ class LiteGraphExecutor { LiteGraphExecutor() = default; virtual ~LiteGraphExecutor() = default; - virtual bool CompileGraph(const FuncGraphPtr &graph, const std::map &compile_options) { + virtual void Initialize() { return; } + virtual void Finalize() { return; } + + virtual bool CompileGraph(const std::shared_ptr &graph, + const std::map &compile_options, uint32_t *graph_id) { return false; } - - virtual bool CompileGraph(const FuncGraphPtr &graph, const std::map &compile_options, - uint32_t *graph_id) { + virtual bool CompileGraph(const void *model_data, size_t data_size, + const std::map &compile_options, uint32_t *graph_id) { return false; } - virtual bool CompileGraph(const void *model_data, size_t data_size, const std::map &compile_options, - uint32_t *graph_id) { - return false; + // form base class + virtual bool RunGraph(const std::shared_ptr &graph, const std::vector &inputs, + std::vector *outputs, const std::map &compile_options) { + MS_LOG(EXCEPTION) << "Unimplemented interface."; + } + + virtual bool CompileGraph(const std::shared_ptr &graph, + const std::map &compile_options) { + return true; } - virtual bool UpdateWeights(const std::vector>> &weights) { + virtual bool UpdateWeights(const std::vector>> &weights) { MS_LOG(ERROR) << "UpdateWeights failed."; return false; } - virtual bool RunGraph(uint32_t graph_id, const std::vector &inputs, - std::vector *outputs, const std::map &compile_options) { + virtual bool RunGraph(uint32_t graph_id, const std::vector &inputs, + std::vector *outputs, + const std::map &compile_options) { (void)graph_id; (void)inputs; (void)outputs; @@ -60,24 +71,18 @@ class LiteGraphExecutor { return false; } - virtual bool RunGraph(const FuncGraphPtr &graph, const std::vector &inputs, - std::vector *outputs, const std::map &compile_options) { - return false; - } - - - virtual bool Resize(uint32_t graph_id, const std::vector &inputs, + virtual bool Resize(uint32_t graph_id, const std::vector &inputs, const std::vector> &new_shapes) { (void)graph_id; (void)inputs; (void)new_shapes; return true; } - virtual std::vector GetInputInfos(uint32_t graph_id) { + virtual std::vector GetInputInfos(uint32_t graph_id) { (void)graph_id; return {}; } - virtual std::vector GetOutputInfos(uint32_t graph_id) { + virtual std::vector GetOutputInfos(uint32_t graph_id) { (void)graph_id; return {}; } diff --git a/mindspore-lite/src/extendrt/session/lite_infer_session.cc b/mindspore-lite/src/extendrt/session/lite_infer_session.cc deleted file mode 100644 index 091e52fd48fe47f70270e885e2c0272f9ace9c33..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/session/lite_infer_session.cc +++ /dev/null @@ -1,298 +0,0 @@ -/** - * Copyright 2019-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include - -#include "src/extendrt/session/lite_infer_session.h" - -#include "extendrt/mock/lite_runtime/converters.h" -#include "extendrt/session/factory.h" -#include "extendrt/utils/runtime_utils.h" -#include "extendrt/utils/tensor_utils.h" - -namespace mindspore { -namespace { -std::mutex g_build_graph_mutex; -} - -Status LiteInferSession::Init(const std::shared_ptr &context, const ConfigInfos &config_info) { - MS_LOG(INFO) << "LiteInferSession::Init"; - context_ = context; - lite_session_ = CreateLiteSession(ContextUtils::Convert(context_.get())); - MS_EXCEPTION_IF_NULL(lite_session_); - return kSuccess; -} - -Status LiteInferSession::CompileGraph(FuncGraphPtr graph, const void *data, size_t size, uint32_t *) { - MS_LOG(INFO) << "LiteInferSession::CompileGraph"; - // This lock can be removed when LiteRT supports concurrent multithreading compilation. - std::lock_guard lock(g_build_graph_mutex); - // Lite infer session do not use graph, just use data and size - MS_EXCEPTION_IF_NULL(data); - MS_EXCEPTION_IF_ZERO("size", size); - lite_session_ = CreateLiteSession(ContextUtils::Convert(context_.get())); - MS_EXCEPTION_IF_NULL(lite_session_); - - auto ret = lite_session_->LoadModelAndCompileByBuf(static_cast(data), kMindIR, size); - if (ret != RET_OK) { - MS_LOG(EXCEPTION) << "load model and compile failed"; - } - - return kSuccess; -} - -void LiteInferSession::ResetTensorData(std::vector old_data, const std::vector &tensors) { - for (size_t j = 0; j < old_data.size(); j++) { - tensors.at(j)->set_data(old_data.at(j)); - } -} - -std::vector LiteInferSession::GetLiteSessionOutputs() { - std::vector empty; - if (lite_session_ == nullptr) { - MS_LOG(ERROR) << "Session is null."; - return empty; - } - std::vector res; - auto names = lite_session_->GetOutputTensorNames(); - if (names.empty()) { - MS_LOG(ERROR) << "The output tensor name of this model is null."; - return empty; - } - auto outputs = lite_session_->GetOutputs(); - if (outputs.empty()) { - MS_LOG(ERROR) << "The outputs of model is null."; - return empty; - } - if (names.size() != outputs.size()) { - MS_LOG(ERROR) << "The size of outputs does not match the size of names."; - return empty; - } - res.resize(names.size()); - for (size_t i = 0; i < names.size(); i++) { - auto impl = std::make_shared(outputs[names[i]]); - if (impl == nullptr || impl->lite_tensor() == nullptr) { - MS_LOG(ERROR) << "Create tensor failed."; - return empty; - } - auto tensor = MSTensor(impl); - if (tensor == nullptr) { - MS_LOG(ERROR) << "Create tensor failed."; - return empty; - } - res[i] = tensor; - } - return res; -} - -std::vector LiteInferSession::TruncateShape(const std::vector &shape, TypeId type, size_t data_len, - bool verify_size) { - std::vector empty; - if (shape.empty()) { - return empty; - } - std::vector truncated_shape; - truncated_shape.resize(shape.size()); - size_t element_size = lite::DataTypeSize(type); - for (size_t i = 0; i < shape.size(); i++) { - auto dim = shape[i]; - if (dim < 0 || dim > INT_MAX || (dim != 0 && element_size > INT_MAX / static_cast(dim))) { - MS_LOG(ERROR) << "Invalid shape!dim: " << dim << ", element_size: " << element_size; - return empty; - } else { - element_size *= static_cast(dim); - truncated_shape[i] = static_cast(dim); - } - } - if (verify_size) { - if (element_size != data_len) { - MS_LOG(ERROR) << "Invalid data size!element_size: " << element_size << ", data_len: " << data_len; - return empty; - } - } - return truncated_shape; -} - -Status LiteInferSession::RunGraph(uint32_t, const std::vector &inputs, - std::vector *outputs) { - MS_LOG(INFO) << "SingleOpInferSession::RunGraph with input and outputs"; - MS_EXCEPTION_IF_NULL(outputs); - MS_EXCEPTION_IF_NULL(lite_session_); - - auto input_tensors = lite_session_->GetInputs(); - if (input_tensors.empty()) { - MS_LOG(EXCEPTION) << "Failed to get input tensor."; - } - if (input_tensors.size() != inputs.size()) { - MS_LOG(EXCEPTION) << "Wrong input size."; - } - std::vector old_data; - for (size_t i = 0; i < inputs.size(); i++) { - auto input = input_tensors.at(i); - auto user_input = &inputs[i]; - if (user_input->data_type() != input->data_type()) { - ResetTensorData(old_data, input_tensors); - MS_LOG(EXCEPTION) << "Tensor " << user_input->id() << " has a different data type from input" - << input->tensor_name() << "."; - } - if (user_input->data_c() == nullptr) { - ResetTensorData(old_data, input_tensors); - MS_LOG(EXCEPTION) << "Tensor " << user_input->id() << " has no data."; - } - old_data.push_back(input->data()); - if (input->data_type() == kObjectTypeString) { -#ifndef STRING_KERNEL_CLIP - std::vector shape = - TruncateShape(user_input->shape_c(), input->data_type(), user_input->DataSize(), false); - if (shape.empty() && !(user_input->shape_c().empty())) { - ResetTensorData(old_data, input_tensors); - MS_LOG(EXCEPTION) << "Input dims of tensor " << user_input->id() << " is invalid."; - } - input->set_shape(shape); - input->set_data(user_input->data_c()); -#else - MS_LOG(ERROR) << unsupport_string_tensor_log; - return kLiteError; -#endif - } else { - if (user_input->data_c() != input->data()) { - if (input->Size() != user_input->Size()) { - ResetTensorData(old_data, input_tensors); -#ifndef ENABLE_LITE_ACL - MS_LOG(EXCEPTION) << "Tensor " << user_input->id() << " has wrong data size."; -#else - MS_LOG(WARNING) << "Please check tensor " << user_input->id() - << " has been modified data size by DVPP method."; - std::vector truncate_shape = {static_cast(user_input->DataSize())}; - input->set_shape(truncate_shape); -#endif - } - input->set_data(user_input->data_c()); - } - } - } - auto ret = static_cast(lite_session_->RunGraph()); - ResetTensorData(old_data, input_tensors); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Run graph failed."; - return ret; - } - MS_LOG(DEBUG) << "Run graph success."; - auto res = GetLiteSessionOutputs(); - if (res.empty()) { - MS_LOG(DEBUG) << "Empty outputs."; - return kLiteError; - } - outputs->clear(); - *outputs = TensorUtils::MSTensorToTensor(res); - return kSuccess; -} - -std::vector LiteInferSession::GetOutputs(uint32_t) { - auto outputs = lite_session_->GetOutputs(); - std::vector output_tensors; - for (auto &iter : outputs) { - auto output = iter.second; - auto impl = std::make_shared(output); - output_tensors.emplace_back(impl); - } - return output_tensors; -} - -std::vector LiteInferSession::GetInputs(uint32_t) { - auto inputs = lite_session_->GetInputs(); - std::vector input_tensors; - for (auto &input : inputs) { - auto impl = std::make_shared(input); - input_tensors.emplace_back(impl); - } - return input_tensors; -} - -std::vector LiteInferSession::GetOutputNames(uint32_t) { - auto outputs = lite_session_->GetOutputs(); - std::vector output_names; - std::transform(outputs.begin(), outputs.end(), std::back_inserter(output_names), - [](auto iter) { return iter.first; }); - return output_names; -} - -std::vector LiteInferSession::GetInputNames(uint32_t) { - return ConvertToTensorNames(lite_session_->GetInputs()); -} -MutableTensorImplPtr LiteInferSession::GetOutputByTensorName(uint32_t graph_id, const std::string &name) { - auto outputs = lite_session_->GetOutputs(); - for (auto &iter : outputs) { - auto output = iter.second; - if (output->tensor_name() == name) { - return std::make_shared(output); - } - } - return nullptr; -} - -MutableTensorImplPtr LiteInferSession::GetInputByTensorName(uint32_t graph_id, const std::string &name) { - auto inputs = lite_session_->GetInputs(); - for (auto &input : inputs) { - if (input->tensor_name() == name) { - return std::make_shared(input); - } - } - return nullptr; -} - -std::shared_ptr LiteInferSession::CreateLiteSession( - const std::shared_ptr &context) { - auto session = std::make_shared(); - if (session == nullptr) { - MS_LOG(ERROR) << "create session failed"; - return nullptr; - } - - auto ret = session->Init(context); - if (ret != mindspore::lite::RET_OK) { - MS_LOG(ERROR) << "init session failed"; - return nullptr; - } - return session; -} - -std::vector LiteInferSession::ConvertToTensorNames( - const std::vector &lite_tensors) { - std::vector tensor_names; - std::transform(lite_tensors.begin(), lite_tensors.end(), std::back_inserter(tensor_names), - [](mindspore::lite::Tensor *lite_tensor) { - MS_EXCEPTION_IF_NULL(lite_tensor); - return lite_tensor->tensor_name(); - }); - return tensor_names; -} - -static std::shared_ptr LiteInferSessionCreator(const std::shared_ptr &ctx, - const ConfigInfos &config_infos) { - auto session = std::make_shared(); - auto ret = session->Init(ctx, config_infos); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Init session failed."; - return nullptr; - } - return session; -} -REG_SESSION(kLiteInferSession, LiteInferSessionCreator); -} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/session/lite_infer_session.h b/mindspore-lite/src/extendrt/session/lite_infer_session.h deleted file mode 100644 index 1bab0fd122fdb95c07ecfa0c77149d24a5a50564..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/session/lite_infer_session.h +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2019-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_EXTENDRT_SINGLE_OP_SESSION_H_ -#define MINDSPORE_LITE_EXTENDRT_SINGLE_OP_SESSION_H_ - -#include -#include -#include - -#include "src/extendrt/infer_session.h" -#include "src/litert/lite_session.h" - -namespace mindspore { -class LiteInferSession : public InferSession { - public: - LiteInferSession() = default; - explicit LiteInferSession(const std::shared_ptr &context) : context_(context) {} - virtual ~LiteInferSession() = default; - Status Init(const std::shared_ptr &context, const ConfigInfos &config_info = {}) override; - Status CompileGraph(FuncGraphPtr graph, const void *data = nullptr, size_t size = 0, - uint32_t *graph_id = nullptr) override; - Status RunGraph(uint32_t graph_id, const std::vector &inputs, - std::vector *outputs) override; - std::vector GetOutputs(uint32_t graph_id) override; - std::vector GetInputs(uint32_t graph_id) override; - std::vector GetOutputNames(uint32_t graph_id) override; - std::vector GetInputNames(uint32_t graph_id) override; - MutableTensorImplPtr GetOutputByTensorName(uint32_t graph_id, const std::string &tensorName) override; - MutableTensorImplPtr GetInputByTensorName(uint32_t graph_id, const std::string &name) override; - - private: - std::shared_ptr CreateLiteSession(const std::shared_ptr &context); - std::vector GetLiteSessionOutputs(); - void ResetTensorData(std::vector old_data, const std::vector &tensors); - std::vector TruncateShape(const std::vector &shape, enum TypeId type, size_t data_len, - bool verify_size); - std::vector ConvertToTensorNames(const std::vector &lite_tensors); - - private: - std::shared_ptr lite_session_; - std::shared_ptr context_; -}; -} // namespace mindspore - -#endif // MINDSPORE_LITE_EXTENDRT_SINGLE_OP_SESSION_H_ diff --git a/mindspore-lite/src/extendrt/session/memory_offload_session.cc b/mindspore-lite/src/extendrt/session/memory_offload_session.cc deleted file mode 100644 index a17463837f3f01ebce3daa07ecdab84ffcb4fb2e..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/session/memory_offload_session.cc +++ /dev/null @@ -1,120 +0,0 @@ -/** - * Copyright 2023 uawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include -#include "src/extendrt/session/memory_offload_session.h" -#include "common/ms_factory.h" -#include "src/extendrt/session/factory.h" -#include "src/extendrt/memory_offload/infer_strategy_builder.h" -#include "src/extendrt/utils/func_graph_utils.h" -#include "src/common/common.h" - -namespace mindspore::lite { -Status MemoryOffloadInferSession::Init(const std::shared_ptr &context, const ConfigInfos &config_info) { - context_ = context; - return SingleOpInferSession::Init(context, config_info); -} - -kernel::KernelModKernel *MemoryOffloadInferSession::BuildCustomAscendKernelImpl( - const CNodePtr &cnode, const lite::CompileNodePtr &compile_node) { - auto kernel_name = lite::kNameCustomAscend; - std::shared_ptr kernel_mod = kernel::Factory::Instance().Create(kernel_name); - if (kernel_mod == nullptr) { - MS_LOG(ERROR) << "Kernel mod is nullptr, kernel name: " << kernel_name; - return nullptr; - } - - kernel_mod->SetDevicedId(device_id_); - mindspore::kernel::BaseOperatorPtr base_operator; - if (!FuncGraphUtils::GetCNodeOperator(cnode, &base_operator)) { - MS_LOG(ERROR) << "Failed to create operator for cnode " << cnode->fullname_with_scope(); - return nullptr; - } - SetCustomAscendOpAttrs(base_operator); - - auto lite_kernel_mod = - new (std::nothrow) kernel::KernelModKernel(kernel_mod, base_operator, compile_node->GetCNode(), - compile_node->GetInputs(), compile_node->GetOutputs(), nullptr); - if (lite_kernel_mod == nullptr) { - MS_LOG(ERROR) << "new kernel failed " << kernel_name; - return nullptr; - } - - auto ret = lite_kernel_mod->Prepare(); - if (ret != lite::RET_OK) { - MS_LOG(ERROR) << "kernel prepare failed " << kernel_name; - delete lite_kernel_mod; - return nullptr; - } - - MS_LOG(INFO) << "create Kernel: " << kernel_name << " succ"; - return lite_kernel_mod; -} - -Status MemoryOffloadInferSession::BuildCustomAscendKernel(const CNodePtr &cnode, CompileNodePtr compile_node) { - auto kernel = BuildCustomAscendKernelImpl(cnode, std::move(compile_node)); - if (kernel == nullptr) { - MS_LOG(ERROR) << "Build ascend kernel failed for node: " << cnode->fullname_with_scope(); - return kLiteError; - } - kernels_.push_back(kernel); - - return kSuccess; -} - -Status MemoryOffloadInferSession::CompileGraph(FuncGraphPtr graph, const void *data, size_t size, uint32_t *) { - MS_LOG(INFO) << "MemoryOffloadInferSession::CompileGraph"; - auto compile_option = std::make_shared(); - compile_option->graph_format = NCHW; - lite::CompileResultBuilder compiler(compile_option); - lite::CompileResultPtr compile_result_ = compiler.Build(graph); - if (compile_result_ == nullptr) { - MS_LOG(ERROR) << "Failed to build compile result"; - return kLiteError; - } - for (const auto &node : compile_result_->GetNodes()) { - auto ret = BuildCustomAscendKernel(node->GetCNode(), node); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Failed to Build custom ascend kernel"; - return ret; - } - } - MemoryOffloadInferStrategyBuilder strategy_builder; - auto strategy_ = strategy_builder.Build(compile_result_, swap_context_); - if (strategy_ == nullptr) { - MS_LOG(ERROR) << "Failed to build strategy"; - return kLiteError; - } - - return kSuccess; -} - -static std::shared_ptr MemoryOffloadSessionCreator(const std::shared_ptr &ctx, - const ConfigInfos &config_infos) { - auto session = std::make_shared(); - MS_EXCEPTION_IF_NULL(session); - session->Init(ctx); - session->SetConfigInfo(config_infos); - return session; -} - -REG_SESSION(kMemoryOffloadSession, MemoryOffloadSessionCreator); -} // namespace mindspore::lite diff --git a/mindspore-lite/src/extendrt/session/memory_offload_session.h b/mindspore-lite/src/extendrt/session/memory_offload_session.h deleted file mode 100644 index bd12ad49dc6fa8dd688c1c44bbd41c91b366a112..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/session/memory_offload_session.h +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2019-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_EXTENDRT_MEMORY_LOAD_SESSION_H_ -#define MINDSPORE_LITE_EXTENDRT_MEMORY_LOAD_SESSION_H_ - -#include -#include -#include -#include -#include "src/extendrt/kernel/default/kernel_mod_kernel.h" -#include "src/extendrt/session/single_op_session.h" -#include "src/extendrt/graph_compiler/compile_result_builder.h" -#include "src/extendrt/memory_offload/infer_strategy_builder.h" -namespace mindspore::lite { -/// \brief memory offload implementation. -class MemoryOffloadInferSession : public SingleOpInferSession { - public: - MemoryOffloadInferSession() = default; - ~MemoryOffloadInferSession() override = default; - - Status Init(const std::shared_ptr &context, const ConfigInfos &config_info = {}) override; - Status CompileGraph(FuncGraphPtr graph, const void *data = nullptr, size_t size = 0, - uint32_t *graph_id = nullptr) override; - - private: - Status BuildCustomAscendKernel(const CNodePtr &cnode, lite::CompileNodePtr compile_node); - kernel::KernelModKernel *BuildCustomAscendKernelImpl(const CNodePtr &cnode, const lite::CompileNodePtr &compile_node); - - lite::CompileResultPtr compile_result_; - std::vector kernels_; - std::shared_ptr swap_context_; - std::shared_ptr strategy_; - std::shared_ptr context_; -}; -} // namespace mindspore::lite - -#endif // MINDSPORE_LITE_EXTENDRT_MEMORY_LOAD_SESSION_H_ diff --git a/mindspore-lite/src/extendrt/session/single_op_session.cc b/mindspore-lite/src/extendrt/session/single_op_session.cc deleted file mode 100644 index b89a088bcb71966473755213bc27cb7c9ac5929a..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/session/single_op_session.cc +++ /dev/null @@ -1,684 +0,0 @@ -/** - * Copyright 2019-2023 uawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include - -#include "src/extendrt/session/single_op_session.h" -#include "src/extendrt/infer_device_address.h" - -#include "common/ms_factory.h" -#include "include/common/utils/anfalgo.h" -#include "include/backend/anf_runtime_algorithm.h" -#include "common/common_utils.h" -#include "plugin/device/cpu/kernel/cpu_kernel_mod.h" -#include "src/extendrt/utils/kernel_build_utils.h" -#include "src/extendrt/kernel/ascend/plugin/ascend_kernel_plugin.h" -#include "src/common/common.h" -#include "mindspore/ops/infer/custom.h" -#include "extendrt/session/factory.h" -#include "extendrt/utils/runtime_utils.h" -#include "extendrt/utils/tensor_default_impl.h" -#include "extendrt/utils/func_graph_utils.h" -#include "tools/optimizer/common/gllo_utils.h" -#include "extendrt/utils/tensor_utils.h" - -namespace mindspore { -const size_t tensor_max_size = 0x1000000; - -Status SingleOpInferSession::AscendInit(const std::shared_ptr &context) { - auto device_list = context->MutableDeviceInfo(); - for (const auto &device_info : device_list) { - if (device_info == nullptr) { - MS_LOG(ERROR) << "Device info get from Context cannot be nullptr"; - return kLiteError; - } - if (device_info->GetDeviceType() == DeviceType::kAscend) { - if (!kernel::AscendKernelPlugin::Register()) { - MS_LOG(ERROR) << "Failed to register Ascend plugin"; - return kLiteError; - } - bool is_registered = kernel::AscendAllocatorPlugin::GetInstance().Register(); - if (!is_registered) { - MS_LOG(ERROR) << "AscendAllocatorPlugin failed to register, cannot do acl memory operations"; - return kLiteError; - } - auto ascend_device_info = device_info->Cast(); - if (ascend_device_info == nullptr) { - MS_LOG(ERROR) << "Failed to cast device info to AscendDeviceInfo"; - return kLiteError; - } - device_id_ = ascend_device_info->GetDeviceID(); - return kSuccess; - } - } - MS_LOG(DEBUG) << "There is no ascend device info, no need to register ascend plugin."; - return kSuccess; -} - -Status SingleOpInferSession::Init(const std::shared_ptr &context, const ConfigInfos &config_info) { - MS_LOG(INFO) << "SingleOpInferSession::Init"; - if (context == nullptr) { - MS_LOG(ERROR) << "Input argument context cannot be nullptr"; - return kLiteError; - } - if (AscendInit(context) != kSuccess) { - MS_LOG(ERROR) << "Init ascend failed."; - return kLiteError; - } - config_infos_ = config_info; - return kSuccess; -} - -void SingleOpInferSession::SetCustomAscendOpAttrs(const kernel::BaseOperatorPtr &op) { - if (config_infos_.find(lite::kAscendContextSection) == config_infos_.end() && - config_infos_.find("inner_common") == config_infos_.end()) { - MS_LOG(DEBUG) << "There is no ascend context info in config infos."; - return; - } - // set custom op attrs - auto custom_op = std::dynamic_pointer_cast(op); - CHECK_NULL_RETURN_VOID(custom_op); - auto dst_prim = custom_op->GetPrim(); - CHECK_NULL_RETURN_VOID(dst_prim); - auto share_mem = config_infos_[lite::kInnerCommon]; - if (share_mem.find(lite::kInnerCalcWorkspaceSize) != share_mem.end()) { - auto value = share_mem[lite::kInnerCalcWorkspaceSize]; - is_multi_model_sharing_mem_prepare_ = value == "true" ? true : false; - dst_prim->AddAttr(lite::kInnerCalcWorkspaceSize, MakeValue(is_multi_model_sharing_mem_prepare_)); - MS_LOG(INFO) << "inner_calc_workspace_size: " << is_multi_model_sharing_mem_prepare_; - } - if (share_mem.find(lite::kInnerSharingWorkspace) != share_mem.end()) { - auto value = share_mem[lite::kInnerSharingWorkspace]; - bool is_inner_sharing_workspace = value == "true" ? true : false; - dst_prim->AddAttr(lite::kInnerSharingWorkspace, MakeValue(is_inner_sharing_workspace)); - MS_LOG(INFO) << "is_inner_sharing_workspace: " << is_inner_sharing_workspace; - } - if (share_mem.find(lite::kInnerModelPath) != share_mem.end()) { - auto model_path = share_mem[lite::kInnerModelPath]; - dst_prim->AddAttr(lite::kInnerModelPath, MakeValue(model_path)); - MS_LOG(INFO) << "inner_model_path: " << model_path; - } - if (share_mem.find(lite::kInnerWorkspace) != share_mem.end()) { - auto workspace = share_mem[lite::kInnerWorkspace]; - bool is_workspace = workspace == "true" ? true : false; - dst_prim->AddAttr(lite::kInnerWorkspace, MakeValue(is_workspace)); - } - if (share_mem.find(lite::kInnerWeightspace) != share_mem.end()) { - auto weightspace = share_mem[lite::kInnerWeightspace]; - bool is_weightspace = weightspace == "true" ? true : false; - dst_prim->AddAttr(lite::kInnerWeightspace, MakeValue(is_weightspace)); - } - if (share_mem.find(lite::kInnerWeightspaceWorkspace) != share_mem.end()) { - auto weightspace_workspace = share_mem[lite::kInnerWeightspaceWorkspace]; - bool is_weightspace_workspace = weightspace_workspace == "true" ? true : false; - dst_prim->AddAttr(lite::kInnerWeightspaceWorkspace, MakeValue(is_weightspace_workspace)); - } - if (share_mem.find(lite::kBundleModel) != share_mem.end()) { - auto bundle_model = share_mem[lite::kBundleModel]; - bool is_bundle_model = bundle_model == "true" ? true : false; - dst_prim->AddAttr(lite::kBundleModel, MakeValue(is_bundle_model)); - } - auto ascend_context = config_infos_[lite::kAscendContextSection]; - std::string profiling_path; - if (ascend_context.find(lite::kProfilingPathKey) != ascend_context.end()) { - profiling_path = ascend_context[lite::kProfilingPathKey]; - dst_prim->AddAttr(lite::kProfilingPathKey, MakeValue(profiling_path)); - } - if (ascend_context.find(lite::kDumpPathKey) != ascend_context.end()) { - if (!profiling_path.empty()) { - MS_LOG(ERROR) << "Profiling and dump can't be set at the same time."; - return; - } - auto dump_path = ascend_context[lite::kDumpPathKey]; - dst_prim->AddAttr(lite::kDumpPathKey, MakeValue(dump_path)); - } -} - -void SingleOpInferSession::DestoryKernelTensor(LiteKernelArgs args) { - for (auto input : args.inputs) { - if (input != nullptr) { - delete input; - } - } - args.inputs.clear(); - for (auto output : args.outputs) { - if (output != nullptr) { - delete output; - } - } - args.outputs.clear(); -} - -SingleOpInferSession::~SingleOpInferSession() { DestoryKernelTensor(kernel_args_); } - -std::tuple SingleOpInferSession::BuildCustomAscendKernelImpl( - const CNodePtr &cnode) { - auto kernel_name = lite::kNameCustomAscend; - std::shared_ptr kernel_mod = kernel::Factory::Instance().Create(kernel_name); - if (kernel_mod == nullptr) { - MS_LOG(ERROR) << "Kernel mod is nullptr, kernel name: " << kernel_name; - return std::make_tuple(nullptr, LiteKernelArgs{}); - } - MS_LOG(INFO) << "SingleOpInferSession::Kernels " << kernel_name; - kernel_mod->SetDevicedId(device_id_); - - auto make_kernel_tensor = [](TypeId type_id, const ShapeVector &shape) { - auto kernel_tensor = new (std::nothrow) kernel::KernelTensor(); - if (kernel_tensor == nullptr) { - return kernel_tensor; - } - kernel_tensor->SetType(std::make_shared(TypeIdToType(type_id))); - kernel_tensor->SetShape(std::make_shared(shape)); - return kernel_tensor; - }; - - LiteKernelArgs args; - BaseOperatorPtr op; - if (!FuncGraphUtils::GetCNodeOperator(cnode, &op)) { - MS_LOG(ERROR) << "Failed to create operator for cnode " << cnode->fullname_with_scope(); - return std::make_tuple(nullptr, LiteKernelArgs{}); - } - std::vector tensor_cache; - std::map kernel_tensor_map; - std::vector inputs; - std::vector outputs; - FuncGraphUtils::GetCNodeInputsOutputs(cnode, &inputs, &outputs); - for (size_t i = 0; i < inputs.size(); i++) { - auto &input = inputs[i]; - auto data_type = FuncGraphUtils::GetTensorDataType(input); - auto shape = FuncGraphUtils::GetTensorShape(input); - auto kernel_tensor = make_kernel_tensor(static_cast(data_type), shape); - auto tensor_data = FuncGraphUtils::GetConstNodeValue(input.first); - if (tensor_data) { - tensor_cache.push_back(tensor_data); - kernel_tensor->SetData(std::make_shared(tensor_data->data_c(), tensor_data->Size())); - } - args.inputs.push_back(kernel_tensor); - kernel_tensor_map[input] = kernel_tensor; - } - for (size_t i = 0; i < outputs.size(); i++) { - auto &output = outputs[i]; - kernel::KernelTensor *kernel_tensor = nullptr; - auto it = kernel_tensor_map.find(output); - if (it != kernel_tensor_map.end()) { // use input as output - kernel_tensor = it->second; - } else { - auto data_type = FuncGraphUtils::GetTensorDataType(output); - auto shape = FuncGraphUtils::GetTensorShape(output); - kernel_tensor = make_kernel_tensor(static_cast(data_type), shape); - } - args.outputs.push_back(kernel_tensor); - } - SetCustomAscendOpAttrs(op); - auto ret = kernel_mod->Init(op->GetPrim(), args.inputs, args.outputs); - MS_LOG(INFO) << "SingleOpInferSession::Kernels ret " << ret; - if (!ret) { - DestoryKernelTensor(args); - MS_LOG(ERROR) << "kernel init failed " << kernel_name; - return std::make_tuple(nullptr, LiteKernelArgs{}); - } - if (is_multi_model_sharing_mem_prepare_) { - DestoryKernelTensor(args); - MS_LOG(INFO) << "is multi model sharing mem prepare"; - return std::make_tuple(nullptr, LiteKernelArgs{}); - } - if (args.inputs.size() > 0) { - delete args.inputs.back(); - args.inputs.pop_back(); - } - return std::make_tuple(kernel_mod, args); -} - -Status SingleOpInferSession::BuildCustomAscendKernel(const CNodePtr &cnode) { - kernel::KernelModPtr kernel_mod; - LiteKernelArgs args; - std::tie(kernel_mod, args) = BuildCustomAscendKernelImpl(cnode); - if (is_multi_model_sharing_mem_prepare_) { - MS_LOG(INFO) << "using ascend workspace sharing."; - return kSuccess; - } - if (kernel_mod == nullptr) { - MS_LOG(ERROR) << "Build ascend kernel failed for node: " << cnode->fullname_with_scope(); - return kLiteError; - } - kernel_mod_ = kernel_mod; - kernel_args_ = args; - return kSuccess; -} - -Status SingleOpInferSession::InitInputOutputInfos(const FuncGraphPtr &graph) { - std::vector input_tensors; - std::vector output_tensors; - FuncGraphUtils::GetFuncGraphInputs(graph, &input_tensors); - FuncGraphUtils::GetFuncGraphOutputs(graph, &output_tensors); - if (kernel_args_.inputs.size() != input_tensors.size()) { - MS_LOG(ERROR) << "Graph inputs size " << input_tensors.size() << " != custom inputs size " - << kernel_args_.inputs.size(); - return kCoreFailed; - } - if (kernel_args_.outputs.size() != output_tensors.size()) { - MS_LOG(ERROR) << "Graph outputs size " << output_tensors.size() << " != custom inputs size " - << kernel_args_.outputs.size(); - return kCoreFailed; - } - for (size_t i = 0; i < input_tensors.size(); i++) { - auto &tensor = input_tensors[i]; - auto &kernel_tensor = kernel_args_.inputs[i]; - MS_CHECK_TRUE_RET(kernel_tensor != nullptr, kLiteNullptr); - auto tensor_name = FuncGraphUtils::GetTensorName(tensor); - auto data_type = static_cast(kernel_tensor->dtype_id()); - auto shape = kernel_tensor->GetShapeVector(); - // when input shape is NOT dynamic, the sizes are known and memory can be pre-alloced (thus set is_acl_host to true) - bool is_acl_host = IsDynamicShape(shape) ? false : true; - auto input = std::make_shared(tensor_name, data_type, shape, is_acl_host); - MS_CHECK_TRUE_RET(input != nullptr, kLiteNullptr); - inputs_.push_back(input); - input_names_.push_back(FuncGraphUtils::GetTensorName(tensor)); - (void)malloced_data_size_.insert(std::make_pair(input, input->DataSize())); - } - for (size_t i = 0; i < output_tensors.size(); i++) { - auto &tensor = output_tensors[i]; - auto &kernel_tensor = kernel_args_.outputs[i]; - auto tensor_name = FuncGraphUtils::GetTensorName(tensor); - auto data_type = static_cast(kernel_tensor->dtype_id()); - auto shape = kernel_tensor->GetShapeVector(); - if (dyn_outshape_.size() < output_tensors.size()) { - dyn_outshape_.push_back(false); - } - if (IsDynamicShape(shape)) { - dyn_outshape_[i] = true; - MS_LOG(INFO) << "output " << i << " shape is dynamic: " << shape; - } - outputs_.push_back(std::make_shared(tensor_name, data_type, shape)); - output_names_.push_back(FuncGraphUtils::GetTensorName(tensor)); - } - return kSuccess; -} - -Status SingleOpInferSession::CompileGraph(FuncGraphPtr graph, const void *data, size_t size, uint32_t *) { - MS_LOG(INFO) << "SingleOpInferSession::CompileGraph"; - MS_CHECK_TRUE_RET(graph != nullptr, kLiteNullptr); - // Get whether the current model is a bundle model for LORA. - if (graph->get_attr(lite::kBundleModel) != nullptr) { - config_infos_["inner_common"][lite::kBundleModel] = "true"; - } - auto nodes = graph->TopoSort(graph->get_return()); - if (nodes.empty()) { - MS_LOG(ERROR) << "There are no nodes in the graph"; - return kLiteNullptr; - } - size_t cnode_count = 0; - for (const auto &node : nodes) { - auto cnode = node->cast(); - if (!cnode || !AnfUtils::IsRealKernel(cnode)) { - continue; - } - std::string kernel_name = common::AnfAlgo::GetCNodeName(cnode); - if (kernel_name != lite::kNameCustomAscend) { - MS_LOG(ERROR) << "Only support " << lite::kNameCustomAscend << ", but got " << kernel_name << ", node " - << cnode->fullname_with_scope(); - return kLiteError; - } - cnode_count += 1; - if (cnode_count > 1) { - MS_LOG(ERROR) << "Only support one " << lite::kNameCustomAscend << " node, but got " << kernel_name << ", node " - << cnode->fullname_with_scope(); - return kLiteError; - } - auto ret = BuildCustomAscendKernel(cnode); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Failed to Build custom ascend kernel"; - return ret; - } - } - if (is_multi_model_sharing_mem_prepare_) { - MS_LOG(INFO) << "is multi model sharing mem prepare"; - return kSuccess; - } - auto ret = InitInputOutputInfos(graph); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Failed to init graph input and output infos"; - return ret; - } - return kSuccess; -} - -Status SingleOpInferSession::RunGraph(uint32_t graph_id, const std::vector &inputs, - std::vector *outputs, const MSKernelCallBack &before, - const MSKernelCallBack &after) { - return RunGraph(graph_id, inputs, outputs); -} - -void SingleOpInferSession::SetBackOutputIfDynamic(std::vector *outputs) { - for (size_t i = 0; i < kernel_args_.outputs.size(); ++i) { - if (dyn_outshape_[i]) { - MS_CHECK_TRUE_RET_VOID(kernel_args_.outputs[i] != nullptr); - ShapeVector shape = kernel_args_.outputs[i]->GetShapeVector(); - (*outputs)[i].set_shape(shape); - kernel::AddressPtr host_addr = kernel_args_.outputs[i]->GetHostData(); - kernel::AddressPtr device_addr = kernel_args_.outputs[i]->GetData(); - if (device_addr != nullptr) { - TypeId out_type = kernel_args_.outputs[i]->dtype_id(); - (*outputs)[i] = tensor::Tensor(out_type, shape, nullptr, device_addr->size); - (*outputs)[i].set_device_address(std::make_shared(device_addr->addr, device_addr->size)); - } else if (host_addr != nullptr) { - TypeId out_type = kernel_args_.outputs[i]->dtype_id(); - auto type_size = abstract::TypeIdSize(out_type); - MS_CHECK_TRUE_RET_VOID(type_size != 0); - auto elem_num = kernel_args_.outputs[i]->size() / type_size; - auto acl_mem_deleter = [](uint8_t *data_buf_ptr) { - kernel::AscendAllocatorPlugin::GetInstance().FreeHost(static_cast(data_buf_ptr)); - }; - auto ref_tensor_data = - std::make_shared(host_addr->addr, elem_num, host_addr->size, shape.size(), acl_mem_deleter); - (*outputs)[i] = tensor::Tensor(out_type, shape, ref_tensor_data); - MS_LOG(DEBUG) << "resetting kernel tensor shape to 0 for the next prediction"; - kernel_args_.outputs[i]->SetShapeVector({0}); - } - } - } -} - -Status SingleOpInferSession::InitInputOutputData(const std::vector &inputs, - std::vector *outputs) { - if (inputs.size() != kernel_args_.inputs.size()) { - MS_LOG(ERROR) << "Given inputs size " << inputs.size() << " != graph inputs size " << kernel_args_.inputs.size(); - return kLiteError; - } - for (size_t i = 0; i < inputs.size(); i++) { - auto &input = inputs[i]; - auto &kernel_input = kernel_args_.inputs[i]; - MS_CHECK_TRUE_RET(kernel_input != nullptr, kLiteError); - if (input.Size() != kernel_input->size()) { - MS_LOG(ERROR) << "Byte size of input " << i << " != the size expected, given size " << input.Size() - << ", expected size " << kernel_input->size() - << ", input shape: " << kernel_input->GetShapeVector(); - return kLiteError; - } - auto input_device_address = input.device_address(); - if (input_device_address != nullptr && input_device_address->GetMutablePtr() != nullptr) { - auto device_ptr = input_device_address->GetMutablePtr(); - kernel_args_.inputs[i]->SetData(std::make_shared(device_ptr, input.Size())); - kernel_args_.inputs[i]->SetHostData(nullptr); - } else { - kernel_args_.inputs[i]->SetHostData(std::make_shared(input.data_c(), input.Size())); - kernel_args_.inputs[i]->SetData(nullptr); - } - kernel_args_.inputs[i]->set_device_id(input.device_info().device_id_); - } - if (outputs->empty()) { - std::transform(kernel_args_.outputs.begin(), kernel_args_.outputs.end(), std::back_inserter(*outputs), - [](auto &item) { return tensor::Tensor(item->dtype_id(), item->GetShapeVector()); }); - } - if (outputs->size() != kernel_args_.outputs.size()) { - MS_LOG(ERROR) << "Given outputs size " << outputs->size() << " != graph inputs size " - << kernel_args_.outputs.size(); - return kLiteError; - } - for (size_t i = 0; i < outputs->size(); i++) { - auto &output = (*outputs)[i]; - auto &kernel_output = kernel_args_.outputs[i]; - if (!dyn_outshape_[i] && output.Size() != kernel_output->size()) { - MS_LOG(ERROR) << "Byte size of output " << i << " != the size expected, given size " << output.Size() - << ", expected size " << kernel_output->size() - << ", output shape: " << kernel_output->GetShapeVector(); - return kLiteError; - } - auto output_device_address = output.device_address(); - if (output_device_address != nullptr && output_device_address->GetMutablePtr() != nullptr) { - auto device_ptr = output_device_address->GetMutablePtr(); - kernel_args_.outputs[i]->SetData(std::make_shared(device_ptr, output.Size())); - kernel_args_.outputs[i]->SetHostData(nullptr); - } else { - if (output.Size() == 0) { - kernel_args_.outputs[i]->SetHostData(std::make_shared(nullptr, output.Size())); - } else { - kernel_args_.outputs[i]->SetHostData(std::make_shared(output.data_c(), output.Size())); - } - kernel_args_.outputs[i]->SetData(nullptr); - } - kernel_args_.outputs[i]->set_device_id(output.device_info().device_id_); - } - return kSuccess; -} - -void DestoryKernelWeights(std::vector *kernel_weights) { - for (auto &w : (*kernel_weights)) { - if (w != nullptr) { - delete w; - w = nullptr; - } - } -} - -Status SingleOpInferSession::InitVariableWeights(const std::vector> &weights, - std::vector *kernel_weights) { - auto make_kernel_tensor = [](TypeId type_id, const ShapeVector &shape) { - auto kernel_tensor = new (std::nothrow) kernel::KernelTensor(); - if (kernel_tensor == nullptr) { - return kernel_tensor; - } - kernel_tensor->SetType(std::make_shared(TypeIdToType(type_id))); - kernel_tensor->SetShape(std::make_shared(shape)); - return kernel_tensor; - }; - for (size_t i = 0; i < weights.size(); i++) { - auto &input = weights[i]; - auto data_type = input->data_type(); - auto shape = input->shape(); - auto kernel_tensor = make_kernel_tensor(static_cast(data_type), shape); - kernel_tensor->SetData(std::make_shared(input->data_c(), input->Size())); - auto input_device_address = input->device_address(); - if (input_device_address != nullptr && input_device_address->GetMutablePtr() != nullptr) { - auto device_ptr = input_device_address->GetMutablePtr(); - kernel_tensor->SetData(std::make_shared(device_ptr, input->Size())); - kernel_tensor->SetHostData(nullptr); - } else { - kernel_tensor->SetHostData(std::make_shared(input->data_c(), input->Size())); - kernel_tensor->SetData(nullptr); - } - kernel_tensor->set_device_id(input->device_info().device_id_); - kernel_weights->push_back(kernel_tensor); - } - return kSuccess; -} - -Status SingleOpInferSession::RunGraph(uint32_t, const std::vector &inputs, - std::vector *outputs) { - if (outputs == nullptr) { - MS_LOG(ERROR) << "outputs cannot be nullptr"; - return kLiteError; - } - if (kernel_mod_ == nullptr) { - MS_LOG(ERROR) << "Model has not been built"; - return kLiteError; - } - MS_LOG(DEBUG) << "SingleOpInferSession::RunGraph with input and outputs"; - std::vector new_shapes; - std::transform(inputs.begin(), inputs.end(), std::back_inserter(new_shapes), [](auto &t) { return t.shape_c(); }); - auto ret = OnNewInputShapes(new_shapes); - if (ret != kSuccess) { - return ret; - } - ret = InitInputOutputData(inputs, outputs); - if (ret != kSuccess) { - return ret; - } - try { - std::vector ignore_datas; - if (!kernel_mod_->Launch(ignore_datas, ignore_datas, ignore_datas, nullptr)) { - MS_LOG(ERROR) << "Failed to launch kernel"; - return kLiteError; - } - } catch (std::exception &e) { - MS_LOG(ERROR) << "Failed to launch kernel, exception: " << e.what(); - return kLiteError; - } - SetBackOutputIfDynamic(outputs); - return kSuccess; -} - -Status SingleOpInferSession::OnNewInputShapes(const std::vector &new_shapes) { - if (kernel_mod_ == nullptr) { - MS_LOG(ERROR) << "Model has not been built"; - return kLiteError; - } - if (inputs_.size() != new_shapes.size()) { - MS_LOG(ERROR) << "Graph inputs size " << inputs_.size() << " != resize input size " << new_shapes.size(); - return kLiteError; - } - auto input_changed = false; - for (size_t i = 0; i < inputs_.size(); i++) { - auto new_shape = new_shapes[i]; - if (std::any_of(new_shape.begin(), new_shape.end(), [](auto dim) { return dim < 0; })) { - MS_LOG(ERROR) << "New shape of input " << i << " cannot be dynamic, new shape: " << new_shape; - return kLiteError; - } - MS_CHECK_TRUE_RET(inputs_[i] != nullptr, kLiteError); - if (inputs_[i]->Shape() != new_shapes[i]) { - input_changed = true; - MS_CHECK_TRUE_RET(kernel_args_.inputs[i] != nullptr, kLiteError); - kernel_args_.inputs[i]->SetShapeVector(new_shapes[i]); - MS_LOG(INFO) << "Set kernel args shape: " << kernel_args_.inputs[i]->GetShapeVector() << ", " - << inputs_[i]->Shape(); - } - } - if (!input_changed) { - return kSuccess; - } - MS_LOG(INFO) << "SingleOpInferSession::Resize"; - - if (kernel_mod_->Resize(kernel_args_.inputs, kernel_args_.outputs) != kSuccess) { - MS_LOG(ERROR) << "Failed to resize custom ascend kernel"; - for (size_t i = 0; i < inputs_.size(); i++) { - kernel_args_.inputs[i]->SetShapeVector(inputs_[i]->Shape()); - } - return kLiteError; - } - // shapes of inputs and outputs should be updated in CustomAscendKernelMod::Resize - for (size_t i = 0; i < inputs_.size(); i++) { - auto input = inputs_[i]; - MS_CHECK_TRUE_RET(input != nullptr, kLiteNullptr); - auto input_tensor = std::dynamic_pointer_cast(inputs_[i]); - MS_CHECK_TRUE_MSG(input_tensor != nullptr, kLiteNullptr, "cast to TensorDefaultImpl failed"); - input_tensor->SetShape(kernel_args_.inputs[i]->GetShapeVector()); - MS_CHECK_TRUE_RET(malloced_data_size_.find(input) != malloced_data_size_.end(), kLiteError); - if (input_tensor->DataSize() > malloced_data_size_.at(input)) { - auto data_size = input_tensor->DataSize(); - void *data_buf = kernel::AscendAllocatorPlugin::GetInstance().MallocHost(data_size); - MS_CHECK_TRUE_MSG(data_buf != nullptr, kLiteNullptr, "malloc on host failed"); - input_tensor->SetAclHostData(data_buf); - malloced_data_size_[input] = data_size; - } - } - for (size_t i = 0; i < outputs_.size(); i++) { - outputs_[i]->SetShape(kernel_args_.outputs[i]->GetShapeVector()); - } - return kSuccess; -} - -Status SingleOpInferSession::Resize(uint32_t, const std::vector &, - const std::vector> &dims) { - return OnNewInputShapes(dims); -} - -std::vector SingleOpInferSession::GetOutputs(uint32_t) { return outputs_; } -std::vector SingleOpInferSession::GetInputs(uint32_t) { return inputs_; } -std::vector SingleOpInferSession::GetOutputNames(uint32_t) { return output_names_; } -std::vector SingleOpInferSession::GetInputNames(uint32_t) { return input_names_; } - -MutableTensorImplPtr SingleOpInferSession::GetOutputByTensorName(uint32_t, const std::string &tensor_name) { - for (size_t idx = 0; idx < output_names_.size(); ++idx) { - if (output_names_[idx] == tensor_name) { - if (idx < outputs_.size()) { - return outputs_[idx]; - } - } - } - MS_LOG(ERROR) << "Can't found tensor name " << tensor_name; - return nullptr; -} - -Status SingleOpInferSession::UpdateWeights(const std::vector>> &weights) { - if (kernel_mod_ == nullptr) { - MS_LOG(ERROR) << "Model has not been built!"; - return kLiteError; - } - if (weights.size() > 1) { - MS_LOG(ERROR) << "Only support single weight, current weight num:" << weights.size() << "!"; - return kLiteError; - } - std::vector kernel_weights; - auto ret = InitVariableWeights(weights[0], &kernel_weights); - if (ret != kSuccess) { - DestoryKernelWeights(&kernel_weights); - MS_LOG(ERROR) << "InitLoraWeights failed! ret:" << ret << "!"; - return ret; - } - std::vector ignore_datas; - if (!kernel_mod_->UpdateWeights(kernel_weights, ignore_datas, ignore_datas, nullptr)) { - DestoryKernelWeights(&kernel_weights); - MS_LOG(ERROR) << "Failed to launch kernel!"; - return kLiteError; - } - DestoryKernelWeights(&kernel_weights); - return kSuccess; -} - -MutableTensorImplPtr SingleOpInferSession::GetInputByTensorName(uint32_t, const std::string &tensor_name) { - for (size_t idx = 0; idx < input_names_.size(); ++idx) { - if (input_names_[idx] == tensor_name) { - if (idx < inputs_.size()) { - return inputs_[idx]; - } - } - } - MS_LOG(ERROR) << "Can't found tensor name " << tensor_name; - return nullptr; -} - -void SingleOpInferSession::AscendFinalize() { - auto kernel_name = lite::kNameCustomAscend; - std::shared_ptr kernel_mod = kernel::Factory::Instance().Create(kernel_name); - if (kernel_mod == nullptr) { - MS_LOG(INFO) << "Create kernel mod failed: " << kernel_name; - return; - } - (void)kernel_mod->Finalize(); -} - -Status SingleOpInferSession::Finalize() { - SingleOpInferSession::AscendFinalize(); - return kSuccess; -} - -static std::shared_ptr SingleOpSessionCreator(const std::shared_ptr &ctx, - const ConfigInfos &config_infos) { - auto session = std::make_shared(); - auto ret = session->Init(ctx, config_infos); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Init session failed."; - return nullptr; - } - return session; -} -REG_SESSION(kSingleOpSession, SingleOpSessionCreator); -} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/session/single_op_session.h b/mindspore-lite/src/extendrt/session/single_op_session.h deleted file mode 100644 index 82750acfd813985e3ce62fffd66b140db3b8b3b3..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/session/single_op_session.h +++ /dev/null @@ -1,91 +0,0 @@ -/** - * Copyright 2019-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_EXTENDRT_SINGLE_OP_SESSION_H_ -#define MINDSPORE_LITE_EXTENDRT_SINGLE_OP_SESSION_H_ - -#include -#include -#include -#include -#include -#include -#include "src/extendrt/infer_session.h" -#include "mindspore/ccsrc/kernel/framework_utils.h" - -namespace mindspore { -struct LiteKernelArgs { - std::vector inputs; - std::vector outputs; - std::map depend_tensor_map; // dynamic shape kernel may need this map - // cppcheck-suppress unusedStructMember - constexpr static char key[] = "KernelArgs"; -}; - -/// \brief Single Op Session implementation, used in Ascend Device Context. -class SingleOpInferSession : public InferSession { - public: - SingleOpInferSession() = default; - ~SingleOpInferSession() override; - Status Init(const std::shared_ptr &context, const ConfigInfos &config_info = {}) override; - Status AscendInit(const std::shared_ptr &context); - static void AscendFinalize(); - Status Finalize() override; - Status CompileGraph(FuncGraphPtr graph, const void *data = nullptr, size_t size = 0, - uint32_t *graph_id = nullptr) override; - Status RunGraph(uint32_t graph_id, const std::vector &inputs, std::vector *outputs, - const MSKernelCallBack &before, const MSKernelCallBack &after) override; - Status RunGraph(uint32_t graph_id, const std::vector &inputs, - std::vector *outputs) override; - Status Resize(uint32_t graph_id, const std::vector &inputs, - const std::vector> &dims) override; - std::vector GetOutputs(uint32_t graph_id) override; - std::vector GetInputs(uint32_t graph_id) override; - std::vector GetOutputNames(uint32_t graph_id) override; - std::vector GetInputNames(uint32_t graph_id) override; - MutableTensorImplPtr GetOutputByTensorName(uint32_t graph_id, const std::string &tensorName) override; - MutableTensorImplPtr GetInputByTensorName(uint32_t graph_id, const std::string &name) override; - void SetConfigInfo(const ConfigInfos &config_infos) { config_infos_ = config_infos; } - void SetCustomAscendOpAttrs(const kernel::BaseOperatorPtr &op); - Status UpdateWeights(const std::vector>> &weights); - - protected: - Status OnNewInputShapes(const std::vector &new_shapes); - Status BuildCustomAscendKernel(const CNodePtr &node); - std::tuple BuildCustomAscendKernelImpl(const CNodePtr &node); - Status InitInputOutputInfos(const FuncGraphPtr &graph); - void SetBackOutputIfDynamic(std::vector *outputs); - Status InitInputOutputData(const std::vector &inputs, std::vector *outputs); - Status InitVariableWeights(const std::vector> &weights, - std::vector *kernel_weights); - void DestoryKernelTensor(LiteKernelArgs args); - - std::vector inputs_; - std::vector input_names_; - std::vector outputs_; - std::vector output_names_; - uint32_t device_id_ = 0; - std::vector dyn_outshape_; - - kernel::KernelModPtr kernel_mod_ = nullptr; - LiteKernelArgs kernel_args_; - ConfigInfos config_infos_; - bool is_multi_model_sharing_mem_prepare_ = false; - - std::unordered_map malloced_data_size_; -}; -} // namespace mindspore - -#endif // MINDSPORE_LITE_EXTENDRT_SINGLE_OP_SESSION_H_ diff --git a/mindspore-lite/src/extendrt/subgraph_kernel.cc b/mindspore-lite/src/extendrt/subgraph_kernel.cc index fbbd4db38f587bf6a0a4d60bc23c08777af237b5..2674a88e462ad1f2f2df0660ff7fe95e197622e1 100644 --- a/mindspore-lite/src/extendrt/subgraph_kernel.cc +++ b/mindspore-lite/src/extendrt/subgraph_kernel.cc @@ -15,20 +15,20 @@ */ #include "src/extendrt/subgraph_kernel.h" namespace mindspore::kernel { -bool SubgraphKernel::Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - std::vector in; - std::vector out; +bool SubgraphKernel::Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + std::vector in; + std::vector out; std::map compile_options; executor_->RunGraph(subgraph_, in, &out, compile_options); return true; } -bool SubgraphKernel::Init(const std::vector &inputs, const std::vector &outputs) { +bool SubgraphKernel::Init(const std::vector &inputs, const std::vector &outputs) { std::map compile_options; return executor_->CompileGraph(subgraph_, compile_options); } -int SubgraphKernel::Resize(const std::vector &inputs, const std::vector &outputs) { +int SubgraphKernel::Resize(const std::vector &inputs, const std::vector &outputs) { return 0; } } // namespace mindspore::kernel diff --git a/mindspore-lite/src/extendrt/subgraph_kernel.h b/mindspore-lite/src/extendrt/subgraph_kernel.h index 3074b1200eb748f300fd7930732d37a4873db5d1..229f1e0d11663f7eff88876fedaad3244c4fc043 100644 --- a/mindspore-lite/src/extendrt/subgraph_kernel.h +++ b/mindspore-lite/src/extendrt/subgraph_kernel.h @@ -23,21 +23,21 @@ #include "ir/func_graph.h" #include "runtime/hardware/device_context.h" #include "common/common_utils.h" -#include "extendrt/session/lite_graph_executor.h" +#include "src/extendrt/session/lite_graph_executor.h" + namespace mindspore::kernel { -class SubgraphKernel : public KernelMod { +class SubgraphKernel { public: SubgraphKernel(FuncGraphPtr subgraph, std::shared_ptr executor) : subgraph_(subgraph), executor_(executor) {} virtual ~SubgraphKernel() = default; - bool Init(const std::vector & /* inputs */, - const std::vector & /* outputs */) override; + bool Init(const std::vector & /* inputs */, const std::vector & /* outputs */); - int Resize(const std::vector &inputs, const std::vector &outputs) override; + int Resize(const std::vector &inputs, const std::vector &outputs); - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - std::vector GetOpSupport() override { return {}; } + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr); + std::vector GetOpSupport() { return {}; } protected: FuncGraphPtr subgraph_; diff --git a/mindspore-lite/src/extendrt/utils/func_graph_utils.cc b/mindspore-lite/src/extendrt/utils/func_graph_utils.cc index 1bc434c891e556362c64aa3c82e7e9da2b7fc0d7..4be2c68faa5b0acfffda0bf8081a89d247bbff09 100644 --- a/mindspore-lite/src/extendrt/utils/func_graph_utils.cc +++ b/mindspore-lite/src/extendrt/utils/func_graph_utils.cc @@ -25,8 +25,6 @@ #include "mindspore/ops/op_def/sequence_ops.h" #include "mindspore/ops/op_def/array_ops.h" #include "mindspore/ops/op_def/framework_ops.h" -#include "include/common/utils/convert_utils.h" -#include "mindspore/ccsrc/include/backend/optimizer/helper.h" #include "mindspore/ops/op_def/op_name.h" #include "tools/optimizer/format/to_nhwc_format.h" @@ -60,6 +58,78 @@ ValuePtr FuncGraphUtils::GetNodeValuePtr(AnfNodePtr input_node) { return value; } +tensor::TensorPtr FuncGraphUtils::CreateEmptyTupleTensor(const ValueTuplePtr &value_tuple) { + std::vector tensor_shape = {0}; + tensor::TensorPtr tensor = std::make_shared(kInt64->type_id(), tensor_shape); + MS_EXCEPTION_IF_NULL(tensor); + tensor::DeviceInfo device_info{kOpFormat_DEFAULT, kInt64}; + tensor->set_device_info(device_info); + tensor->set_user_data(kTensorValueIsEmpty, value_tuple); + return tensor; +} + +template +tensor::TensorPtr FuncGraphUtils::CreateTensorWithValueTuple(const ValueTuplePtr &value_tuple_ptr, + const TypePtr &type_ptr, size_t data_length) { + MS_EXCEPTION_IF_NULL(value_tuple_ptr); + MS_EXCEPTION_IF_NULL(type_ptr); + std::vector values; + for (const auto &v : value_tuple_ptr->value()) { + MS_EXCEPTION_IF_NULL(v); + if (v->isa()) { + auto scalar = v->cast(); + values.push_back(GetValue(scalar)); + } else { + MS_LOG(WARNING) << "The value " << v << "of tuple is not a scalar"; + return nullptr; + } + } + std::vector tensor_shape = {SizeToLong(values.size())}; + tensor::TensorPtr tensor = std::make_shared(type_ptr->type_id(), tensor_shape); + MS_EXCEPTION_IF_NULL(tensor); + tensor::DeviceInfo device_info{kOpFormat_DEFAULT, type_ptr}; + tensor->set_device_info(device_info); + auto data_ptr = tensor->data_c(); + MS_EXCEPTION_IF_NULL(data_ptr); + auto elem_num = values.size() * data_length; + auto ret_code = memcpy_s(data_ptr, static_cast(tensor->data().nbytes()), values.data(), elem_num); + if (ret_code != EOK) { + MS_LOG(EXCEPTION) << "Failed to copy data into tensor, memcpy_s errorno: " << ret_code; + } + return tensor; +} + +tensor::TensorPtr FuncGraphUtils::CreateTupleTensor(const ValueTuplePtr &value_tuple) { + MS_EXCEPTION_IF_NULL(value_tuple); + tensor::TensorPtr tensor = nullptr; + if (value_tuple->value().empty()) { + tensor = CreateEmptyTupleTensor(value_tuple); + return tensor; + } + ValuePtr v = *(value_tuple->value().begin()); + MS_EXCEPTION_IF_NULL(v); + // Currently we only deal with the scalar tuple + if (!v->isa()) { + MS_LOG(DEBUG) << "The value " << v << "of tuple is not a scalar"; + return nullptr; + } + auto scalar = v->cast(); + MS_EXCEPTION_IF_NULL(scalar); + if (scalar->isa()) { + tensor = CreateTensorWithValueTuple(value_tuple, kInt32, sizeof(int32_t)); + } else if (scalar->isa()) { + tensor = CreateTensorWithValueTuple(value_tuple, kInt64, sizeof(int64_t)); + } else if (scalar->isa()) { + tensor = CreateTensorWithValueTuple(value_tuple, kFloat32, sizeof(float)); + } else { + auto type = scalar->type(); + auto type_str = (type == nullptr) ? "nullptr" : type->ToString(); + MS_LOG(ERROR) << "Invalid scalar type: " << type_str; + return nullptr; + } + return tensor; +} + tensor::TensorPtr FuncGraphUtils::GetConstNodeValue(AnfNodePtr input_node) { ValuePtr value = GetNodeValuePtr(input_node); if (value == nullptr) { @@ -76,7 +146,7 @@ tensor::TensorPtr FuncGraphUtils::GetConstNodeValue(AnfNodePtr input_node) { return ScalarToTensor(value->cast()); } if (value->isa()) { - return opt::CreateTupleTensor(value->cast()); + return CreateTupleTensor(value->cast()); } if (value->isa()) { auto type_ptr = value->cast(); @@ -203,7 +273,7 @@ bool FuncGraphUtils::GetFuncGraphInputs(const FuncGraphPtr &func_graph, std::vec MS_LOG(ERROR) << "Input " << input->fullname_with_scope() << " of FuncGraph is not type of Parameter."; return false; } - if (common::AnfAlgo::IsParameterWeight(parameter)) { + if (parameter->has_default()) { continue; } inputs->push_back(std::make_pair(input, 0)); @@ -310,7 +380,8 @@ AbstractBasePtr FuncGraphUtils::GetAbstract(const AnfWithOutIndex &tensor) { return common::AnfAlgo::FetchAbstractByIndex(node->abstract(), idx); } -void FuncGraphUtils::GetFuncGraphInputsInfo(const FuncGraphPtr &func_graph, std::vector *inputs, +void FuncGraphUtils::GetFuncGraphInputsInfo(const FuncGraphPtr &func_graph, + std::vector> *inputs, std::vector *inputs_name) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(inputs); @@ -326,14 +397,16 @@ void FuncGraphUtils::GetFuncGraphInputsInfo(const FuncGraphPtr &func_graph, std: auto name = FuncGraphUtils::GetTensorName(tensor); auto data_type = FuncGraphUtils::GetTensorDataType(tensor); auto shape = FuncGraphUtils::GetTensorShape(tensor); - auto ms_tensor = std::make_shared(static_cast(data_type), shape); - ms_tensor->set_name(name); + auto ms_tensor = std::shared_ptr(MSTensor::CreateTensor(name, data_type, {}, nullptr, 0)); + MS_EXCEPTION_IF_NULL(ms_tensor); + ms_tensor->SetShape(shape); inputs->push_back(ms_tensor); inputs_name->push_back(name); } } -void FuncGraphUtils::GetFuncGraphOutputsInfo(const FuncGraphPtr &func_graph, std::vector *outputs, +void FuncGraphUtils::GetFuncGraphOutputsInfo(const FuncGraphPtr &func_graph, + std::vector> *outputs, std::vector *output_names) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(outputs); @@ -349,8 +422,9 @@ void FuncGraphUtils::GetFuncGraphOutputsInfo(const FuncGraphPtr &func_graph, std auto name = FuncGraphUtils::GetTensorName(tensor); auto data_type = FuncGraphUtils::GetTensorDataType(tensor); auto shape = FuncGraphUtils::GetTensorShape(tensor); - auto ms_tensor = std::make_shared(static_cast(data_type), shape); - ms_tensor->set_name(name); + auto ms_tensor = std::shared_ptr(MSTensor::CreateTensor(name, data_type, {}, nullptr, 0)); + MS_EXCEPTION_IF_NULL(ms_tensor); + ms_tensor->SetShape(shape); outputs->push_back(ms_tensor); output_names->push_back(name); } diff --git a/mindspore-lite/src/extendrt/utils/func_graph_utils.h b/mindspore-lite/src/extendrt/utils/func_graph_utils.h index 489d4c4d1a6c9513dd3b4a1081ef4625c5a5bfd4..8b4ee6ba8d5080bc18271bee0d14da370c4a0ff8 100644 --- a/mindspore-lite/src/extendrt/utils/func_graph_utils.h +++ b/mindspore-lite/src/extendrt/utils/func_graph_utils.h @@ -23,11 +23,12 @@ #include #include #include - +#include #include "ir/anf.h" #include "ir/dtype/type.h" #include "ir/func_graph.h" #include "include/api/data_type.h" +#include "include/api/types.h" #include "include/api/status.h" #include "common/kernel.h" #include "include/common/utils/anfalgo.h" @@ -53,9 +54,9 @@ class FuncGraphUtils { static std::string GetTensorName(const AnfWithOutIndex &tensor); static AbstractBasePtr GetAbstract(const AnfWithOutIndex &tensor); - static void GetFuncGraphInputsInfo(const FuncGraphPtr &graph, std::vector *inputs, + static void GetFuncGraphInputsInfo(const FuncGraphPtr &graph, std::vector> *inputs, std::vector *inputs_name); - static void GetFuncGraphOutputsInfo(const FuncGraphPtr &graph, std::vector *outputs, + static void GetFuncGraphOutputsInfo(const FuncGraphPtr &graph, std::vector> *outputs, std::vector *output_names); static Status UnifyGraphToNHWCFormat(const FuncGraphPtr &graph); @@ -66,6 +67,12 @@ class FuncGraphUtils { static AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNodePtrList *inputs_ptr, mindspore::HashMap *eqv_ptr); + static tensor::TensorPtr CreateEmptyTupleTensor(const ValueTuplePtr &value_tuple); + template + static tensor::TensorPtr CreateTensorWithValueTuple(const ValueTuplePtr &value_tuple_ptr, const TypePtr &type_ptr, + size_t data_length); + static tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple); + private: static ValuePtr GetNodeValuePtr(AnfNodePtr input_node); }; diff --git a/mindspore-lite/src/extendrt/utils/kernel_build_utils.cc b/mindspore-lite/src/extendrt/utils/kernel_build_utils.cc deleted file mode 100644 index a3de5dab414a51c82de5a4591608baddae71e854..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/utils/kernel_build_utils.cc +++ /dev/null @@ -1,439 +0,0 @@ -/** - * Copyright 2020-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/utils/kernel_build_utils.h" -#include -#include -#include -#include "mindspore/ops/op_def/framework_ops.h" -#include "common/common_utils.h" -#include "common/ms_factory.h" -#include "common/kernel_build_info.h" -#include "common/oplib/opinfo.h" -#include "common/oplib/oplib.h" -#include "utils/trace_base.h" -#include "include/common/utils/convert_utils.h" -#include "include/common/utils/anfalgo.h" -#include "include/backend/kernel_info.h" -#include "include/backend/anf_runtime_algorithm.h" -#include "src/common/common.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" - -namespace mindspore { -namespace infer { -using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm; -using mindspore::kernel::KernelBuildInfo; -namespace { -constexpr auto kParamDynamic = "dynamic"; -constexpr auto kInputNum = 3; -constexpr auto kNameTranspose = "Transpose"; -constexpr auto kCustomTypeAscend = "acl_build"; - -bool IsInputNotCNode(const CNodePtr &kernel_node, size_t input_index) { - auto input_node = common::AnfAlgo::VisitKernel(kernel_node->input(input_index + 1), 0).first; - MS_EXCEPTION_IF_NULL(input_node); - if (input_node->isa() || input_node->isa()) { - return true; - } - return false; -} - -void GetOutputDtypes(const CNodePtr &kernel_node, std::vector *output_types) { - size_t output_num = AnfUtils::GetOutputTensorNum(kernel_node); - for (size_t output_index = 0; output_index < output_num; ++output_index) { - TypeId dtype = common::AnfAlgo::GetOutputInferDataType(kernel_node, output_index); - output_types->emplace_back(dtype); - } -} - -void GetOutputFormat(const CNodePtr &kernel_node, std::vector *output_formats) { - size_t output_num = AnfUtils::GetOutputTensorNum(kernel_node); - for (size_t output_index = 0; output_index < output_num; ++output_index) { - output_formats->emplace_back(kOpFormat_DEFAULT); - } -} - -void GetInputDtypes(const CNodePtr &kernel_node, std::vector *input_types, - std::vector *input_no_cnode_indexes) { - size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node); - for (size_t input_index = 0; input_index < input_num; ++input_index) { - TypeId dtype = kTypeUnknown; - if (IsInputNotCNode(kernel_node, input_index)) { - input_no_cnode_indexes->emplace_back(input_index); - dtype = common::AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index); - } else { - dtype = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index); - } - input_types->emplace_back(dtype); - } -} - -void GetInputFormat(const CNodePtr &kernel_node, std::vector *input_formats) { - size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node); - for (size_t input_index = 0; input_index < input_num; ++input_index) { - input_formats->emplace_back(kOpFormat_DEFAULT); - } -} - -bool InputDtypeMatch(TypeId InputAttr, TypeId input_type, bool strict) { - if (InputAttr == input_type) { - return true; - } - if (!strict && InputAttr == kNumberTypeInt32 && (input_type == kNumberTypeInt16 || input_type == kNumberTypeInt64)) { - return true; - } - if (!strict && InputAttr == kNumberTypeFloat32 && - (input_type == kNumberTypeFloat16 || input_type == kNumberTypeFloat64)) { - return true; - } - return false; -} - -int GetOutputDtypeMatchedNum(const kernel::KernelAttr &kernel_attr, const std::vector &output_types) { - if (kernel_attr.GetOutputSize() != output_types.size()) { - MS_LOG(DEBUG) << "required output num:" << kernel_attr.GetInputSize() - << ", actual output num:" << output_types.size(); - return 0; - } - int data_type_matched_num = 0; - auto output_num = output_types.size(); - for (size_t i = 0; i < output_num; ++i) { - if (kernel_attr.GetOutputAttr(i).dtype != output_types[i]) { - MS_LOG(DEBUG) << "required dtype:" << kernel_attr.GetOutputAttr(i).dtype - << ", actual output dtype:" << output_types[i]; - } else { - data_type_matched_num++; - } - } - return data_type_matched_num; -} - -int GetInputDtypeFormatMatchedNum(const kernel::KernelAttr &kernel_attr, const std::vector &input_types, - const std::vector &input_not_cnode_indexes, bool strict) { - if (kernel_attr.GetInputSize() != input_types.size()) { - MS_LOG(DEBUG) << "required input num:" << kernel_attr.GetInputSize() << ", actual input num:" << input_types.size(); - return 0; - } - int data_type_matched_num = 0; - auto input_num = input_types.size(); - for (size_t i = 0; i < input_num; ++i) { - if (!InputDtypeMatch(kernel_attr.GetInputAttr(i).dtype, input_types[i], strict)) { - MS_LOG(DEBUG) << "required dtype:" << kernel_attr.GetInputAttr(i).dtype - << ", actual input dtype:" << input_types[i]; - } else { - data_type_matched_num++; - } - } - return data_type_matched_num; -} - -void ExpandKernelAttr(const CNodePtr &kernel_node, kernel::KernelAttr *kernel_attr) { - MS_EXCEPTION_IF_NULL(kernel_attr); - size_t attr_num = kernel_attr->GetInputSize(); - size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node); - if (attr_num == 0) { - MS_LOG(EXCEPTION) << "Input size is empty"; - return; // To pass the CI Check_Cppcheck - } - // Only support one dynamic input like Concat or - // many dynamic input but each input has same number like DynamicStitch - std::string format = kOpFormat_DEFAULT; - std::vector attr_list; - size_t each_attr_input_num = input_num / attr_num; - for (size_t i = 0; i < attr_num; ++i) { - TypeId input_dtype = kernel_attr->GetInputAttr(i).dtype; - for (size_t j = 0; j < each_attr_input_num; ++j) { - (void)attr_list.emplace_back(input_dtype, format); - } - } - kernel_attr->SetInputAttrList(attr_list); - - TypeId output_dtype = kernel_attr->GetOutputAttr(0).dtype; - size_t output_num = AnfUtils::GetOutputTensorNum(kernel_node); - for (size_t i = 1; i < output_num; ++i) { - (void)kernel_attr->AddOutputAttr(output_dtype); - } -} - -void SetKernelBuildInfo(const std::vector &input_formats, const std::vector &input_types, - const std::vector &output_formats, const std::vector &output_types, - AnfNode *kernel_node) { - auto builder = std::make_shared(); - MS_EXCEPTION_IF_NULL(builder); - builder->SetInputsFormat(input_formats); - builder->SetInputsDeviceType(input_types); - builder->SetOutputsFormat(output_formats); - builder->SetOutputsDeviceType(output_types); - AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node); -} - -void UpdateDynamicKernelBuildInfo(const CNodePtr &kernel_node) { - const std::string &op_name = common::AnfAlgo::GetCNodeName(kernel_node); - MS_LOG(INFO) << "Operator name: " << op_name; - // Set kernel build info - std::vector input_types; - std::vector input_not_cnode_indexes; - GetInputDtypes(kernel_node, &input_types, &input_not_cnode_indexes); - std::vector output_types; - GetOutputDtypes(kernel_node, &output_types); - std::vector input_formats; - GetInputFormat(kernel_node, &input_formats); - std::vector output_formats; - GetOutputFormat(kernel_node, &output_formats); - SetKernelBuildInfo(input_formats, input_types, output_formats, output_types, kernel_node.get()); -} - -void UpdateCustomKernelBuildInfo(const CNodePtr &kernel_node, bool is_akg_op) { - MS_EXCEPTION_IF_NULL(kernel_node); - auto builder = std::make_shared(); - const std::string &op_name = common::AnfAlgo::GetCNodeName(kernel_node); - if (is_akg_op) { - builder->SetKernelType(KernelType::AKG_KERNEL); - } else { - builder->SetKernelType(KernelType::CPU_KERNEL); - } - builder->SetProcessor(kernel::Processor::CPU); - // Set inputs info - std::vector input_types; - std::vector input_not_cnode_indexes; - GetInputDtypes(kernel_node, &input_types, &input_not_cnode_indexes); - std::vector input_formats; - GetInputFormat(kernel_node, &input_formats); - builder->SetInputsDeviceType(input_types); - builder->SetInputsFormat(input_formats); - // Set inputs info - std::vector output_types; - GetOutputDtypes(kernel_node, &output_types); - std::vector output_formats; - GetOutputFormat(kernel_node, &output_formats); - builder->SetOutputsDeviceType(output_types); - builder->SetOutputsFormat(output_formats); - if (op_name == lite::kNameCustomAscend || is_akg_op) { - AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get()); - } -} - -kernel::KernelAttr FillNoneInKernelAttr(const CNodePtr &kernel_node, const std::vector &input_types, - const std::vector &output_types, - const kernel::KernelAttr &kernel_attr) { - MS_EXCEPTION_IF_NULL(kernel_node); - // Only process Custom op - if (!IsPrimitiveCNode(kernel_node, prim::kPrimCustom)) { - return kernel_attr; - } - auto input_num = input_types.size(); - auto output_num = output_types.size(); - if (kernel_attr.GetInputSize() != input_types.size() || kernel_attr.GetOutputSize() != output_types.size()) { - MS_LOG(DEBUG) << "required input num:" << kernel_attr.GetInputSize() << ", actual input num:" << input_num; - MS_LOG(DEBUG) << "required input num:" << kernel_attr.GetOutputSize() << ", actual input num:" << output_num; - return kernel_attr; - } - kernel::KernelAttr result; - // Fill inputs info. - for (size_t i = 0; i < input_num; ++i) { - auto type_format = kernel_attr.GetInputAttr(i); - if (type_format.dtype == TypeId::kMetaTypeNone) { - type_format.dtype = input_types[i]; - } - if (type_format.format.empty()) { - type_format.format = kOpFormat_DEFAULT; - } - (void)result.AddInputAttr(type_format.dtype, type_format.format); - } - // Fill outputs info. - for (size_t i = 0; i < output_num; ++i) { - auto type_format = kernel_attr.GetOutputAttr(i); - if (type_format.dtype == TypeId::kMetaTypeNone) { - type_format.dtype = output_types[i]; - } - if (type_format.format.empty()) { - type_format.format = kOpFormat_DEFAULT; - } - (void)result.AddOutputAttr(type_format.dtype, type_format.format); - } - return result; -} -} // namespace - -bool IsDynamicParamKernel(const std::string &op_name) { - const auto &op_info = kernel::OpLib::FindOp(op_name, kernel::OpImplyType::kImplyCPU); - if (op_info == nullptr) { - return false; - } - - const auto &input_io_info = op_info->inputs_ptr(); - if (input_io_info.size() != 1 || input_io_info[0]->param_type() != kParamDynamic) { - return false; - } - - const auto &output_io_info = op_info->outputs_ptr(); - if (output_io_info.size() != 1 || output_io_info[0]->param_type() != kParamDynamic) { - return false; - } - - return true; -} - -bool SelectKernel(const CNodePtr &kernel_node, kernel::KernelAttr *selected_kernel_attr, - const std::vector &kernel_attrs, const std::vector &input_types, - const std::vector &input_not_cnode_indexes, const std::vector &output_types, - std::pair *matched, bool strict) { - MS_EXCEPTION_IF_NULL(kernel_node); - MS_EXCEPTION_IF_NULL(selected_kernel_attr); - MS_EXCEPTION_IF_NULL(matched); - MS_LOG(DEBUG) << "Select kernel for op: " << common::AnfAlgo::GetCNodeName(kernel_node); - for (auto kernel_attr : kernel_attrs) { - if (kernel_attr.GetAllSame()) { - ExpandKernelAttr(kernel_node, &kernel_attr); - } - size_t output_num = AnfUtils::GetOutputTensorNum(kernel_node); - if (kernel_attr.GetOutputSize() != output_num) { - MS_LOG(DEBUG) << "Output num is not equal!"; - continue; - } - - auto new_kernel_attr = FillNoneInKernelAttr(kernel_node, input_types, output_types, kernel_attr); - int input_dtype_matched_num = - GetInputDtypeFormatMatchedNum(new_kernel_attr, input_types, input_not_cnode_indexes, strict); - int output_dtype_matched_num = GetOutputDtypeMatchedNum(new_kernel_attr, output_types); - // All formats and data types matched - if (input_dtype_matched_num == SizeToInt(input_types.size())) { - *selected_kernel_attr = new_kernel_attr; - matched->first = true; - if (output_dtype_matched_num == SizeToInt(output_types.size())) { - matched->second = true; - return true; - } - } - } - return false; -} - -kernel::KernelAttr BuildKernelFromInput(const std::vector &inputs, const std::vector &outputs, - const kernel::KernelAttr &origin_attr) { - kernel::KernelAttr attr = origin_attr; - for (auto in_dtype : inputs) { - (void)attr.AddInputAttr(in_dtype); - } - for (auto out_dtype : outputs) { - (void)attr.AddOutputAttr(out_dtype); - } - (void)attr.AddSkipCheckAttr(true); - return attr; -} - -std::pair SetKernelInfoWithMsg(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - const std::string &op_name = common::AnfAlgo::GetCNodeName(kernel_node); - if (IsPrimitiveCNode(kernel_node, prim::kPrimCustom)) { - if (common::AnfAlgo::HasNodeAttr("type", kernel_node) && - common::AnfAlgo::GetNodeAttr(kernel_node, "type") == "GraphKernel") { - UpdateCustomKernelBuildInfo(kernel_node, true); - return {}; - } - auto tp = common::AnfAlgo::GetNodeAttr(kernel_node, kAttrFuncType); - if (IsOneOfCustomAkgType(tp)) { - UpdateCustomKernelBuildInfo(kernel_node, true); - return {}; - } - if (tp == kCustomTypeAscend) { - UpdateCustomKernelBuildInfo(kernel_node, false); - return {}; - } - // If Custom op has not set reg info, then infer info from inputs - if (mindspore::kernel::OpLib::FindOp(op_name, kernel::OpImplyType::kImplyCPU) == nullptr) { - MS_LOG(WARNING) << "Not find operator information for Custom operator[" << op_name << "]. " - << "Infer operator information from inputs. For more details, " - << "please refer to 'mindspore.ops.Custom' at https://www.mindspore.cn."; - UpdateCustomKernelBuildInfo(kernel_node, false); - return {}; - } - } else if (IsDynamicParamKernel(op_name)) { - // Select for dynamic kernel(both the number and data type are undetermined). - UpdateDynamicKernelBuildInfo(kernel_node); - return {}; - } else if (IsAKGSparseOP(kernel_node)) { - UpdateCustomKernelBuildInfo(kernel_node, true); - return {}; - } - - std::vector input_formats; - std::vector input_types; - std::vector input_not_cnode_indexes; - std::vector selected_output_formats; - std::vector selected_output_types; - MS_LOG(INFO) << "SetKernelInfo, CNode Name: " << op_name; - GetInputDtypes(kernel_node, &input_types, &input_not_cnode_indexes); - GetInputFormat(kernel_node, &input_formats); - GetOutputDtypes(kernel_node, &selected_output_types); - GetOutputFormat(kernel_node, &selected_output_formats); - - SetKernelBuildInfo(input_formats, input_types, selected_output_formats, selected_output_types, kernel_node.get()); - return {}; -} - -void SetKernelInfo(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - auto [msg, etype] = SetKernelInfoWithMsg(kernel_node); - if (msg.empty()) return; - MS_EXCEPTION(etype) << msg; -} - -void CopyInputWeights(const CNodePtr &kernel_node, const std::vector &inputs) { - std::string kernel_name = common::AnfAlgo::GetCNodeName(kernel_node); - if (kernel_name == lite::kNameCustomAscend || kernel_name == kNameTranspose) { - auto node_input_size = kernel_node->size(); - if (node_input_size < kInputNum) { - MS_LOG(ERROR) << "Input num of custom ascend kernel should larger than " << (kInputNum - 1) << ", real num is " - << node_input_size; - return; - } - if (node_input_size != inputs.size() + 1) { - MS_LOG(ERROR) << "Input num of custom ascend kernel [" << node_input_size << "]" - << " is not equal to kernel tensor size[" << (inputs.size() + 1) << "]."; - return; - } - auto weight_input = kernel_node->input(node_input_size - 1); - if (!weight_input->isa()) { - MS_LOG(ERROR) << "Om input is not parameter."; - return; - } - ParameterPtr weight_param = weight_input->cast(); - if (weight_param == nullptr || !weight_param->has_default()) { - MS_LOG(ERROR) << "Om param is invalid, val= " << weight_param; - return; - } - auto tensor = std::static_pointer_cast(weight_param->default_param()); - if (tensor == nullptr) { - MS_LOG(ERROR) << "Tensor is nullptr."; - return; - } - if (tensor->data_c() == nullptr || tensor->Size() == 0) { - MS_LOG(ERROR) << "Tensor data is invalid."; - return; - } - auto new_addr = malloc(tensor->Size()); - if (new_addr == nullptr) { - MS_LOG(ERROR) << "Malloc failed, size= " << tensor->Size(); - return; - } - memcpy(new_addr, tensor->data_c(), tensor->Size()); - kernel::AddressPtr addr_ptr = std::make_shared(new_addr, tensor->Size()); - inputs[inputs.size() - 1]->SetData(addr_ptr); - } -} -} // namespace infer -} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/utils/kernel_build_utils.h b/mindspore-lite/src/extendrt/utils/kernel_build_utils.h deleted file mode 100644 index a3ea6b53ff24caf97830be4a106a1acec4665bd9..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/utils/kernel_build_utils.h +++ /dev/null @@ -1,40 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_UTILS_KERNEL_BUILD_UTILS_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_UTILS_KERNEL_BUILD_UTILS_H_ - -#include -#include -#include - -#include "ir/anf.h" -#include "ir/dtype/type.h" -#include "include/common/utils/utils.h" -#include "common/kernel.h" -#include "common/common_utils.h" - -namespace mindspore { -namespace infer { -using kernel::DataType; -void SetKernelInfo(const CNodePtr &apply_kernel_ptr); -void CopyInputWeights(const CNodePtr &kernel_node, const std::vector &inputs); -} // namespace infer -} // namespace mindspore - -#endif // MINDSPORE_LITE_SRC_EXTENDRT_UTILS_KERNEL_BUILD_UTILS_H_ diff --git a/mindspore-lite/src/extendrt/utils/runtime_utils.cc b/mindspore-lite/src/extendrt/utils/runtime_utils.cc deleted file mode 100644 index 0f1efd172d51c75855e3e96e018081ad8353febc..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/utils/runtime_utils.cc +++ /dev/null @@ -1,130 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include - -#include "extendrt/utils/runtime_utils.h" - -#include "src/extendrt/infer_device_address.h" -#include "include/common/utils/anfalgo.h" -#include "include/backend/anf_runtime_algorithm.h" -#include "src/common/common.h" - -namespace mindspore { -namespace { -const size_t tensor_max_size_utils = 0x1000000; -} // namespace - -void *RuntimeUtils::GetAddressPtr(device::DeviceAddressPtr address_ptr) { - MS_EXCEPTION_IF_NULL(address_ptr); - return address_ptr->GetDevicePtr(); -} - -void RuntimeUtils::SetAddressPtr(device::DeviceAddressPtr address_ptr, void *ptr) { - MS_EXCEPTION_IF_NULL(address_ptr); - address_ptr->SetDevicePtr(ptr); -} - -void RuntimeUtils::AllocAddressPtr(device::DeviceAddressPtr address_ptr) { - MS_EXCEPTION_IF_NULL(address_ptr); - if (address_ptr->GetDevicePtr() == nullptr) { - address_ptr->SetDevicePtr(malloc(address_ptr->GetSize())); - } -} - -kernel::AddressPtr RuntimeUtils::GetAddressFromDevice(device::DeviceAddressPtr device_address) { - MS_EXCEPTION_IF_NULL(device_address); - kernel::AddressPtr kernel_address = std::make_shared(); - MS_EXCEPTION_IF_NULL(kernel_address); - if (device_address->GetDevicePtr() == nullptr) { - device_address->SetDevicePtr(malloc(device_address->GetSize())); - } - MS_EXCEPTION_IF_NULL(device_address->GetDevicePtr()); - kernel_address->addr = device_address->GetDevicePtr(); - kernel_address->size = device_address->GetSize(); - return kernel_address; -} - -device::DeviceAddressPtr RuntimeUtils::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, - TypeId type_id) { - return std::make_shared(device_ptr, device_size, format, type_id); -} - -void RuntimeUtils::UpdateKernelNodeOutputInfo(const AnfNodePtr &kernel_node, - const std::vector &output_addrs) { - std::string kernel_name = common::AnfAlgo::GetCNodeName(kernel_node); - if (kernel_name == lite::kNameCustomAscend) { - size_t output_num = AnfUtils::GetOutputTensorNum(kernel_node); - if (output_addrs.size() != output_num) { - MS_LOG(ERROR) << "Output addr size[" << output_addrs.size() << "] is not equal to node outputs size[" - << output_num << "]"; - return; - } - // update output addr - bool is_update_shape = false; - for (size_t i = 0; i < output_num; ++i) { - auto device_address = AnfAlgo::GetMutableOutputAddr(kernel_node, i); - MS_EXCEPTION_IF_NULL(device_address); - MS_EXCEPTION_IF_NULL(output_addrs[i]); - auto addr_ptr = device_address->GetMutablePtr(); - if (addr_ptr != nullptr && output_addrs[i]->addr != addr_ptr) { - free(addr_ptr); - device_address->set_ptr(output_addrs[i]->addr); - device_address->SetSize(output_addrs[i]->size); - is_update_shape = true; - } - } - if (!is_update_shape) { - MS_LOG(DEBUG) << "There is no need to update output shape."; - return; - } - - auto output_tensors = AnfAlgo::GetOrCreateAllOutputKernelTensors(kernel_node); - auto input_tensors = AnfAlgo::GetOrCreateAllInputKernelTensors(kernel_node); - auto kernel_mod = AnfAlgo::GetKernelMod(kernel_node); - kernel_mod->UpdateOutputShapeAndSize(input_tensors, output_tensors); - if (output_tensors.empty()) { - MS_LOG(ERROR) << "The output shape size of custom ascend is empty."; - return; - } - auto abstract = kernel_node->abstract(); - MS_EXCEPTION_IF_NULL(abstract); - if (utils::isa(abstract)) { - auto abstract_tuple = abstract->cast(); - MS_EXCEPTION_IF_NULL(abstract_tuple); - if (abstract_tuple->elements().size() != output_tensors.size()) { - MS_LOG(ERROR) << "Abstract size[" << abstract_tuple->elements().size() << "] is not equal to output shape size[" - << output_tensors.size() << "]"; - return; - } - for (size_t i = 0; i < abstract_tuple->elements().size(); ++i) { - auto tmp_abstract = abstract_tuple->elements()[i]; - MS_EXCEPTION_IF_NULL(tmp_abstract); - MS_EXCEPTION_IF_NULL(output_tensors[i]); - tmp_abstract->set_shape(std::make_shared(output_tensors[i]->GetShapeVector())); - } - } else { - MS_EXCEPTION_IF_NULL(output_tensors[0]); - abstract->set_shape(std::make_shared(output_tensors[0]->GetShapeVector())); - } - } -} -} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/utils/runtime_utils.h b/mindspore-lite/src/extendrt/utils/runtime_utils.h deleted file mode 100644 index 6441c2aa7d9e3829d2509e9332fcb8395d2e3aca..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/utils/runtime_utils.h +++ /dev/null @@ -1,45 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_EXTENDRT_UTILS_RUNTIME_UTILS_H_ -#define MINDSPORE_LITE_SRC_EXTENDRT_UTILS_RUNTIME_UTILS_H_ - -#include -#include - -#include "common/device_address.h" -#include "common/kernel.h" -#include "ir/tensor.h" - -namespace mindspore { -class RuntimeUtils { - public: - static void *GetAddressPtr(device::DeviceAddressPtr address_ptr); - static void SetAddressPtr(device::DeviceAddressPtr address_ptr, void *ptr); - static void AllocAddressPtr(device::DeviceAddressPtr address_ptr); - - static kernel::AddressPtr GetAddressFromDevice(device::DeviceAddressPtr address_ptr); - - static device::DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, - TypeId type_id); - static void UpdateKernelNodeOutputInfo(const AnfNodePtr &kernel_node, - const std::vector &output_addrs); -}; -} // namespace mindspore - -#endif // MINDSPORE_LITE_SRC_EXTENDRT_UTILS_RUNTIME_UTILS_H_ diff --git a/mindspore-lite/src/extendrt/utils/segment_utils.cc b/mindspore-lite/src/extendrt/utils/segment_utils.cc deleted file mode 100644 index bc850a6eff2fd089af8250beb87d7a60e82698ed..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/extendrt/utils/segment_utils.cc +++ /dev/null @@ -1,174 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/extendrt/utils/segment_utils.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "mindspore/ops/op_def/sequence_ops.h" -#include "mindspore/ops/op_def/framework_ops.h" -#include "utils/hash_map.h" -#include "utils/hash_set.h" -#include "src/common/log_adapter.h" -#include "include/common/utils/utils.h" -#include "ir/manager.h" -#include "ir/func_graph_cloner.h" -#include "frontend/operator/ops.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_d.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" - -namespace mindspore { -namespace compile { -namespace { -// Return the list of nodes whose values are required beyond this segment. -// Arguments: -// nodes: list of nodes in the segment -// users: dict mapping each node to its users (globally) -// seen: set of nodes that are part of the segment -AnfNodePtrList GetOutput(const AnfNodePtrList &nodes, const NodeUsersMap &users, - const mindspore::HashSet &seen) { - AnfNodePtrList output; - if (users.size() == 0) { - return output; - } - for (auto &node : nodes) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - continue; - } - auto iter = users.find(node); - if (iter == users.end()) { - continue; - } - auto &node_users = iter->second; - const bool has_outer_user = std::any_of(std::begin(node_users), std::end(node_users), - [&seen](const std::pair &u) -> bool { - const bool is_outer_user = (seen.find(u.first) == seen.end()); - return is_outer_user; - }); - if (has_outer_user) { - output.emplace_back(node); - } - } - return output; -} - -AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNodePtrList *inputs_ptr, - mindspore::HashMap *eqv_ptr) { - MS_EXCEPTION_IF_NULL(fg); - MS_EXCEPTION_IF_NULL(inputs_ptr); - MS_EXCEPTION_IF_NULL(eqv_ptr); - MS_EXCEPTION_IF_NULL(node); - auto &inputs = *inputs_ptr; - auto &eqv = *eqv_ptr; - if (node->isa() && !IsValueNode(node)) { - eqv[node] = node; - } else if (eqv.find(node) == eqv.end()) { - inputs.push_back(node); - eqv[node] = fg->add_parameter(); - MS_EXCEPTION_IF_NULL(eqv[node]); - eqv[node]->set_abstract(node->abstract()); - eqv[node]->set_kernel_info(node->kernel_info_ptr()); - } - return eqv[node]; -} -} // namespace - -std::tuple TransformSegmentToAnfGraph(const AnfNodePtrList &lst) { - if (lst.empty()) { - MS_LOG(EXCEPTION) << "Input anf node list is empty"; - } - FuncGraphPtr fg = nullptr; - { - // limit the lifetime of guard. - MS_EXCEPTION_IF_NULL(lst[0]); - MS_EXCEPTION_IF_NULL(lst[0]->cast()); - MS_EXCEPTION_IF_NULL(lst[0]->cast()->func_graph()); - TraceGuard guard(MakeTraceInfo(lst[0]->cast()->func_graph()->debug_info())); - fg = std::make_shared(); - } - AnfNodePtrList inputs; - mindspore::HashMap eqv; - // Merge CNodes into a AnfGraph that represents a linear instruction segment - for (auto n : lst) { - MS_EXCEPTION_IF_NULL(n); - if (!n->isa()) { - MS_LOG(EXCEPTION) << "Inst is not CNode"; - } - auto &inps = n->cast()->inputs(); - if (inps.empty()) { - MS_LOG(EXCEPTION) << "Input is empty"; - } - if (!IsValueNode(inps[0]) && - !(IsValueNode(inps[0]) && - inps[0]->cast()->value()->cast()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL))) { - MS_LOG(EXCEPTION) << "Input[0] must be a Primitive ValueNode"; - } - auto fn = inps[0]; - std::vector args{fn}; - if (IsPrimitive(fn, prim::kPrimDepend) && inps.size() >= kDependInputSize && - eqv.find(inps[kDependAttachNodeIndex]) == eqv.end()) { - args.emplace_back(RefSubGraphNode(fg, inps[kRealInputIndexInDepend], &inputs, &eqv)); - const size_t value_start_index = 2; - for (size_t i = value_start_index; i < inps.size(); ++i) { - args.emplace_back(NewValueNode(MakeValue(0))); - } - } else { - (void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args), - [&fg, &inputs, &eqv](const AnfNodePtr &a) { return RefSubGraphNode(fg, a, &inputs, &eqv); }); - } - TraceGuard tg(MakeTraceInfo(n->debug_info())); - MS_EXCEPTION_IF_NULL(fg); - eqv[n] = fg->NewCNode(args); - MS_EXCEPTION_IF_NULL(eqv[n]); - eqv[n]->set_abstract(n->abstract()); - eqv[n]->set_kernel_info(n->kernel_info_ptr()); - } - mindspore::HashSet eqv_keys; - for (auto &e : eqv) { - (void)eqv_keys.emplace(e.first); - } - MS_EXCEPTION_IF_NULL(lst[0]->func_graph()); - auto mgr = lst[0]->func_graph()->manager(); - MS_EXCEPTION_IF_NULL(mgr); - auto outputs = GetOutput(lst, mgr->node_users(), eqv_keys); - AnfNodePtr fg_output; - if (outputs.size() > 1) { - std::vector output_args; - output_args.push_back(NewValueNode(prim::kPrimMakeTuple)); - (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_args), - [&eqv](const AnfNodePtr &o) -> AnfNodePtr { return eqv[o]; }); - // Set output for AnfGraph - fg_output = fg->NewCNode(output_args); - } else { - if (outputs.empty()) { - MS_LOG(EXCEPTION) << "Output is empty."; - } - fg_output = eqv[outputs[0]]; - } - fg->set_output(fg_output); - return std::make_tuple(fg, inputs, outputs); -} -} // namespace compile -} // namespace mindspore diff --git a/mindspore-lite/src/extendrt/utils/tensor_default_impl.h b/mindspore-lite/src/extendrt/utils/tensor_default_impl.h index 7d43bb259b09f4518d951aa92288616a5b52df88..84ee407289ed176993008d79deb40818dd45298e 100644 --- a/mindspore-lite/src/extendrt/utils/tensor_default_impl.h +++ b/mindspore-lite/src/extendrt/utils/tensor_default_impl.h @@ -29,7 +29,7 @@ #include "common/device_address.h" #include "common/utils.h" #include "common/mutable_tensor_impl.h" -#include "src/extendrt/kernel/ascend/plugin/ascend_allocator_plugin.h" +#include "src/extendrt/delegate/ascend_acl/ascend_allocator_plugin.h" namespace mindspore { class TensorDefaultImpl : public MutableTensorImpl { @@ -48,7 +48,7 @@ class TensorDefaultImpl : public MutableTensorImpl { for (auto s : shape) { data_buf_size *= static_cast(s); } - void *data_buf_ptr = kernel::AscendAllocatorPlugin::GetInstance().MallocHost(data_buf_size); + void *data_buf_ptr = AscendAllocatorPlugin::GetInstance().MallocHost(data_buf_size); data_ = data_buf_ptr; own_data_ = false; } @@ -74,11 +74,11 @@ class TensorDefaultImpl : public MutableTensorImpl { } if (device_data_ != nullptr && own_data_) { MS_LOG(INFO) << "free device data in tensor default impl."; - kernel::AscendAllocatorPlugin::GetInstance().Free(device_data_, device_id_); + AscendAllocatorPlugin::GetInstance().Free(device_data_, device_id_); device_data_ = nullptr; } if (is_acl_host_ && data_ != nullptr) { - kernel::AscendAllocatorPlugin::GetInstance().FreeHost(const_cast(data_)); + AscendAllocatorPlugin::GetInstance().FreeHost(const_cast(data_)); } } void SetShape(const std::vector &shape) override { shape_ = shape; } @@ -111,7 +111,7 @@ class TensorDefaultImpl : public MutableTensorImpl { void SetDeviceData(void *data) override { if (own_data_ && device_data_ != nullptr) { MS_LOG(INFO) << "tensor has own device data, now release device data and set new device data."; - kernel::AscendAllocatorPlugin::GetInstance().Free(device_data_, device_id_); + AscendAllocatorPlugin::GetInstance().Free(device_data_, device_id_); } device_data_ = data; own_data_ = false; @@ -133,7 +133,7 @@ class TensorDefaultImpl : public MutableTensorImpl { free(const_cast(data_)); } if (is_acl_host_ && data_ != nullptr) { - kernel::AscendAllocatorPlugin::GetInstance().FreeHost(const_cast(data_)); + AscendAllocatorPlugin::GetInstance().FreeHost(const_cast(data_)); is_acl_host_ = false; } data_ = data; @@ -146,7 +146,7 @@ class TensorDefaultImpl : public MutableTensorImpl { free(const_cast(data_)); } if (is_acl_host_ && data_ != nullptr) { - kernel::AscendAllocatorPlugin::GetInstance().FreeHost(const_cast(data_)); + AscendAllocatorPlugin::GetInstance().FreeHost(const_cast(data_)); is_acl_host_ = false; } data_ = data; diff --git a/mindspore-lite/src/extendrt/utils/tensor_utils.cc b/mindspore-lite/src/extendrt/utils/tensor_utils.cc index 3d1ce7d1c58e5f944ee5316af3981cbab6aef715..1a5eaa7c089a124d78c917d9f4c39ca7dd9c1794 100644 --- a/mindspore-lite/src/extendrt/utils/tensor_utils.cc +++ b/mindspore-lite/src/extendrt/utils/tensor_utils.cc @@ -21,152 +21,9 @@ #include #include "extendrt/utils/tensor_utils.h" -#include "common/common_utils.h" #include "mindspore/ccsrc/kernel/framework_utils.h" -#include "common/format_utils.h" namespace mindspore { -TensorRefData::TensorRefData(void *data, size_t bytes_size, size_t data_size, size_t ndim, - const std::function &deleter) - : data_(data), elem_count_(bytes_size), data_size_(data_size), ndim_(ndim), deleter_(deleter) {} - -TensorRefData::~TensorRefData() { - if (deleter_ && data_) { - deleter_(reinterpret_cast(data_)); - } -} - -ssize_t TensorRefData::size() const { return static_cast(elem_count_); } - -ssize_t TensorRefData::itemsize() const { - if (elem_count_ == 0) { - return 0; - } - return static_cast(data_size_ / elem_count_); -} - -ssize_t TensorRefData::nbytes() const { return static_cast(data_size_); } - -ssize_t TensorRefData::ndim() const { return static_cast(ndim_); } - -void *TensorRefData::data() { return data_; } - -const void *TensorRefData::const_data() const { return data_; } - -std::string TensorRefData::ToString(TypeId type, const ShapeVector &shape, bool use_comma) const { - std::stringstream stream; - stream << "RefTensor:["; - for (size_t i = 0; i < shape.size(); i++) { - stream << shape[i]; - if (i + 1 < shape.size()) { - stream << ","; - } - } - stream << "]" << type; - return stream.str(); -} - -mindspore::Format TensorTensorImpl::Format() const { - MS_EXCEPTION_IF_NULL(tensor_); - return kernel::GetFormatFromStrToEnum(tensor_->device_info().format_); -} - -void TensorTensorImpl::SetFormat(mindspore::Format format) { - MS_EXCEPTION_IF_NULL(tensor_); - auto device_info = tensor_->device_info(); - device_info.format_ = kernel::GetFormatFromEnumToStr(format); - tensor_->set_device_info(device_info); -} - -std::vector TensorUtils::MSTensorToTensorPtr(const std::vector &ms_tensors) { - std::vector tensor_ptrs; - - for (auto ms_tensor : ms_tensors) { - auto data_type = ms_tensor.DataType(); - auto type_id = static_cast(data_type); - auto shape = ms_tensor.Shape(); - auto data = ms_tensor.MutableData(); - auto data_size = ms_tensor.DataSize(); - auto ref_tensor_data = std::make_shared(data, ms_tensor.ElementNum(), data_size, shape.size()); - auto tensor_ptr = std::make_shared(type_id, shape, ref_tensor_data); - tensor_ptr->set_name(ms_tensor.Name()); - tensor_ptr->set_data_type(type_id); - tensor_ptrs.push_back(tensor_ptr); - } - return tensor_ptrs; -} - -std::vector TensorUtils::TensorPtrToMSTensor(std::vector tensor_ptrs, - const std::vector &tensor_names) { - std::vector ms_tensors; - for (size_t i = 0; i < tensor_ptrs.size(); i++) { - auto graph_tensor = tensor_ptrs[i]; - std::string graph_tensor_name = tensor_names[i]; - graph_tensor->set_name(graph_tensor_name); - auto tensor_impl = std::make_shared(graph_tensor); - ms_tensors.push_back(MSTensor(tensor_impl)); - } - return ms_tensors; -} - -std::vector TensorUtils::MSTensorToTensor(const std::vector &ms_tensors) { - std::vector tensors; - for (auto ms_tensor : ms_tensors) { - auto data_type = ms_tensor.DataType(); - auto type_id = static_cast(data_type); - auto shape = ms_tensor.Shape(); - auto data = const_cast(ms_tensor.Data().get()); - auto data_size = ms_tensor.DataSize(); - auto ref_tensor_data = std::make_shared(data, ms_tensor.ElementNum(), data_size, shape.size()); - mindspore::tensor::Tensor tensor(type_id, shape, ref_tensor_data); - auto device_address = ms_tensor.GetDeviceData(); - if (device_address != nullptr) { - auto lite_device_address = std::make_shared(device_address, ms_tensor.DataSize()); - tensor.set_device_address(lite_device_address); - // only use device_id now. - auto device_info = tensor::DeviceInfo("DefaultFormat", nullptr, "DefaultFormat", ms_tensor.GetDeviceId()); - tensor.set_device_info(device_info); - } - tensors.emplace_back(std::move(tensor)); - } - return tensors; -} - -std::vector TensorUtils::TensorToMSTensor(std::vector tensors, - const std::vector &tensor_names) { - std::vector ms_tensors; - for (size_t i = 0; i < tensors.size(); i++) { - auto &graph_tensor = tensors[i]; - std::string graph_tensor_name = tensor_names[i]; - graph_tensor.set_name(graph_tensor_name); - auto tensor_impl = std::make_shared(graph_tensor); - ms_tensors.emplace_back(MSTensor(tensor_impl)); - } - return ms_tensors; -} - -std::vector TensorUtils::TensorToTensorPtr( - const std::vector &tensors) { - std::vector tensor_ptrs; - for (auto &tensor : tensors) { - auto type_id = static_cast(tensor.data_type_c()); - auto shape = tensor.shape_c(); - auto data = tensor.data_c(); - auto data_size = tensor.Size(); - auto tensor_ptr = std::make_shared(type_id, shape, data, data_size); - tensor_ptrs.push_back(tensor_ptr); - } - return tensor_ptrs; -} - -std::vector TensorUtils::TensorPtrToTensor( - const std::vector &tensor_ptrs) { - std::vector tensors; - std::transform(tensor_ptrs.begin(), tensor_ptrs.end(), std::back_inserter(tensors), - [](mindspore::tensor::TensorPtr tensor_ptr) { return mindspore::tensor::Tensor(*tensor_ptr); }); - return tensors; -} - kernel::AddressPtr CloudTensorUtils::LiteTensorToAddressPtr(const lite::Tensor *lite_tensor) { kernel::AddressPtr address_ptr = std::make_shared(lite_tensor->data(), lite_tensor->Size()); return address_ptr; @@ -217,36 +74,4 @@ std::vector CloudTensorUtils::LiteTensorToKernelTensorPt return kernel_tensor_list; } - -std::vector> AbstractTensorUtils::GetTensorListShapes( - const std::vector &tensors) { - std::vector> original_dims; - std::transform(tensors.begin(), tensors.end(), std::back_inserter(original_dims), - [](infer::abstract::Tensor *tensor) { - std::vector shape64; - if (tensor != nullptr) { - auto shape32 = tensor->shape(); - std::transform(shape32.begin(), shape32.end(), std::back_inserter(shape64), - [](int dim) { return static_cast(dim); }); - } - return shape64; - }); - return original_dims; -} - -bool AbstractTensorUtils::SetTensorListShapse(const std::vector &tensors, - const std::vector> &shapes) { - for (size_t i = 0; i < tensors.size(); i++) { - auto tensor = tensors.at(i); - if (tensor == nullptr) { - continue; - } - auto shape64 = shapes.at(i); - std::vector shape32; - std::transform(shape64.begin(), shape64.end(), std::back_inserter(shape32), - [](int64_t dim) { return static_cast(dim); }); - tensor->set_shape(shape32); - } - return true; -} } // namespace mindspore diff --git a/mindspore-lite/src/extendrt/utils/tensor_utils.h b/mindspore-lite/src/extendrt/utils/tensor_utils.h index 79ef5c2bb4d04b4ee29cb75caea51cafbf4def2f..b9d547cc4323e701ca588055ab8be80870ed9e88 100644 --- a/mindspore-lite/src/extendrt/utils/tensor_utils.h +++ b/mindspore-lite/src/extendrt/utils/tensor_utils.h @@ -33,200 +33,9 @@ #include "src/tensor.h" #include "infer/tensor.h" #ifdef ENABLE_CLOUD_INFERENCE -#include "src/extendrt/kernel/ascend/plugin/ascend_allocator_plugin.h" +#include "src/extendrt/delegate/ascend_acl/ascend_allocator_plugin.h" #endif namespace mindspore { -class TensorRefData : public tensor::TensorData { - public: - TensorRefData(void *data, size_t elem_count, size_t data_size, size_t ndim, - const std::function &deleter = nullptr); - ~TensorRefData(); - - ssize_t size() const override; - ssize_t itemsize() const override; - ssize_t nbytes() const override; - ssize_t ndim() const override; - void *data() override; - const void *const_data() const override; - bool is_sub_data() const override { return false; } - bool has_sub_data() const override { return false; } - std::string ToString(TypeId type, const ShapeVector &shape, bool use_comma) const override; - - private: - void *data_ = nullptr; - size_t elem_count_ = 0; - size_t data_size_ = 0; - size_t ndim_ = 0; - std::function deleter_ = nullptr; -}; - -constexpr auto kLiteDeviceName = "LiteDevice"; - -class LiteDeviceAddress : public device::DeviceAddress { - public: - LiteDeviceAddress(void *ptr, size_t size) : device::DeviceAddress(ptr, size) {} - void SetData(void *data) { set_ptr(data); } - - bool SyncDeviceToHost(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr, - bool sync_on_demand = false) const override { - return false; - } - bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr, - const std::string &format) const override { - return false; - } - bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr) const override { - return SyncHostToDevice(shape, size, type, host_ptr, "DefaultFormat"); - } - void ClearDeviceMemory() override {} -}; - -class TensorTensorImpl : public MutableTensorImpl { - public: - explicit TensorTensorImpl(const tensor::Tensor &tensor) : tensor_(std::make_shared(tensor)) {} - explicit TensorTensorImpl(const std::shared_ptr &tensor) : tensor_(tensor) {} - - void SetData(void *, bool) override { MS_LOG_EXCEPTION << "Cannot set data for TensorTensorImpl"; } - - std::shared_ptr Data() const override { - MS_EXCEPTION_IF_NULL(tensor_); - return std::shared_ptr(tensor_->data_c(), [](const void *) {}); - } - - void SetDeviceId(int device_id) override { - MS_EXCEPTION_IF_NULL(tensor_); - device_id_ = device_id; - } - - void SetDevice(const std::string &device) override { - MS_EXCEPTION_IF_NULL(tensor_); - device_ = device; - } - - int GetDeviceId() const override { - MS_EXCEPTION_IF_NULL(tensor_); - return device_id_; - } - - std::string GetDevice() const override { - MS_EXCEPTION_IF_NULL(tensor_); - return device_; - } - - void *MutableData() override { - MS_EXCEPTION_IF_NULL(tensor_); - return tensor_->data_c(); - } - - void SetDeviceData(void *data) override { - MS_EXCEPTION_IF_NULL(tensor_); - auto old_device_data = GetDeviceData(); - MS_LOG(ERROR) << "set device data in tensor utils."; -#ifdef ENABLE_CLOUD_INFERENCE - if (old_device_data != nullptr && device_own_data_) { - kernel::AscendAllocatorPlugin::GetInstance().Free(old_device_data, GetDeviceId()); - } -#endif - auto data_size = DataSize(); - auto device_address = std::make_shared(data, data_size); - tensor_->set_device_address(device_address); - device_own_data_ = false; - } - void *GetDeviceData() override { - MS_EXCEPTION_IF_NULL(tensor_); - auto device_address = tensor_->device_address(); - if (device_address == nullptr) { - return nullptr; - } - return device_address->GetMutablePtr(); - } - - bool IsDevice() const override { - MS_EXCEPTION_IF_NULL(tensor_); - return tensor_->device_address() != nullptr; - } - - bool IsConst() const override { return false; } - - void SetShape(const std::vector &) override { MS_LOG_EXCEPTION << "Cannot set shape for TensorTensorImpl"; } - void SetDataType(mindspore::DataType) override { MS_LOG_EXCEPTION << "Cannot set data type for TensorTensorImpl"; } - void SetName(const std::string &name) override { - MS_EXCEPTION_IF_NULL(tensor_); - tensor_->set_name(name); - } - - mindspore::Format Format() const override; - - void SetFormat(mindspore::Format format) override; - - const std::string &Name() const override { - MS_EXCEPTION_IF_NULL(tensor_); - return tensor_->name(); - } - enum DataType DataType() const override { - MS_EXCEPTION_IF_NULL(tensor_); - return static_cast(tensor_->data_type()); - } - const std::vector &Shape() const override { - MS_EXCEPTION_IF_NULL(tensor_); - return tensor_->shape(); - } - - void SetAllocator(const std::shared_ptr &allocator) override { - MS_EXCEPTION_IF_NULL(tensor_); - tensor_->set_user_data("allocator", allocator); - } - std::shared_ptr GetAllocator() const override { - MS_EXCEPTION_IF_NULL(tensor_); - return tensor_->user_data("allocator"); - } - - std::vector GetQuantParams() const override { - MS_EXCEPTION_IF_NULL(tensor_); - auto data = tensor_->user_data>("quant_param"); - return data ? *data : std::vector(); - } - - void SetQuantParams(const std::vector &quant_param) override { - MS_EXCEPTION_IF_NULL(tensor_); - tensor_->set_user_data("quant_param", std::make_shared>(quant_param)); - } - - size_t DataSize() const override { - auto elem_num = ElementNum(); - if (elem_num <= 0) { - return 0; - } - return LongToSize(elem_num) * lite::DataTypeSize(static_cast(DataType())); - } - - std::shared_ptr Clone() const override { return std::make_shared(tensor_); } - - private: - std::shared_ptr tensor_ = nullptr; - std::string device_ = ""; - int device_id_ = -1; - bool device_own_data_ = true; -}; - -class TensorUtils { - public: - // MSTensor <-> TensorPtr - static std::vector MSTensorToTensorPtr(const std::vector &ms_tensors); - static std::vector TensorPtrToMSTensor(std::vector tensor_ptrs, - const std::vector &tensor_names); - - static std::vector MSTensorToTensor(const std::vector &ms_tensors); - static std::vector TensorToMSTensor(std::vector tensors, - const std::vector &tensor_names); - - // TensorPtr <-> Tensor - static std::vector TensorToTensorPtr( - const std::vector &tensors); - static std::vector TensorPtrToTensor( - const std::vector &tensor_ptrs); -}; - class CloudTensorUtils { public: /* lite tensor ---> Address */ @@ -239,13 +48,6 @@ class CloudTensorUtils { static std::vector LiteTensorToKernelTensorPtrVec( const std::vector &lite_tensors); }; - -class AbstractTensorUtils { - public: - static std::vector> GetTensorListShapes(const std::vector &tensors); - static bool SetTensorListShapse(const std::vector &tensors, - const std::vector> &shapes); -}; } // namespace mindspore #endif // MINDSPORE_LITE_SRC_EXTENDRT_UTILS_TENSOR_UTILS_H_ diff --git a/mindspore-lite/src/infer/graph_compiler.h b/mindspore-lite/src/infer/graph_compiler.h deleted file mode 100644 index 2e6a1045f6dc14e1eefdd148eb6150580cf01d77..0000000000000000000000000000000000000000 --- a/mindspore-lite/src/infer/graph_compiler.h +++ /dev/null @@ -1,38 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_INFER_GRAPH_COMPILER_H_ -#define MINDSPORE_LITE_INFER_GRAPH_COMPILER_H_ - -#include -#include -#include "infer/execution_plan.h" -#include "infer/execution_flow.h" - -namespace mindspore::infer::abstract { -class GraphCompiler : public std::enable_shared_from_this { - public: - virtual ~GraphCompiler() = default; - - /// \brief Compile FuncGraph Into ExecutionPlan. - /// - /// \param[in] graph FuncGraph need to compile. - /// - /// \return ExecutionPlan pointer. - virtual ExecutionPlanPtr Compile(FuncGraphPtr graph) = 0; -}; -} // namespace mindspore::infer::abstract - -#endif // MINDSPORE_LITE_INFER_GRAPH_COMPILER_H_ diff --git a/mindspore-lite/src/litert/cxx_api/context.cc b/mindspore-lite/src/litert/cxx_api/context.cc index c59b91b8fba23caacdd5f24a3acc830093517c69..23673c5db17921b7383a588f024379422fcc796f 100644 --- a/mindspore-lite/src/litert/cxx_api/context.cc +++ b/mindspore-lite/src/litert/cxx_api/context.cc @@ -21,7 +21,6 @@ #include "include/lite_types.h" #include "src/litert/inner_allocator.h" #include "src/common/log_adapter.h" -#include "src/extendrt/delegate/tensorrt/distribution/distribution_base.h" namespace mindspore { constexpr auto kModelOptionCpuEnableFP16 = "mindspore.option.cpu.enable_fp16"; diff --git a/mindspore-lite/src/litert/cxx_api/tensor/tensor_impl.cc b/mindspore-lite/src/litert/cxx_api/tensor/tensor_impl.cc index 2628062366e298368f652f76952ea35642175088..933f2de6691330f3becc288a57c4a3fe70c2af9f 100644 --- a/mindspore-lite/src/litert/cxx_api/tensor/tensor_impl.cc +++ b/mindspore-lite/src/litert/cxx_api/tensor/tensor_impl.cc @@ -26,7 +26,7 @@ #include "src/tensor.h" #include "src/common/string_utils.h" #ifdef ENABLE_CLOUD_INFERENCE -#include "src/extendrt/kernel/ascend/plugin/ascend_allocator_plugin.h" +#include "src/extendrt/delegate/ascend_acl/ascend_allocator_plugin.h" #endif namespace mindspore { using mindspore::lite::RET_OK; @@ -87,7 +87,7 @@ void LiteTensorImpl::SetDeviceData(void *data) { #ifdef ENABLE_CLOUD_INFERENCE if (GetDeviceData() != nullptr && own_data_) { MS_LOG(INFO) << "free device data in tensor impl."; - kernel::AscendAllocatorPlugin::GetInstance().Free(GetDeviceData(), GetDeviceId()); + AscendAllocatorPlugin::GetInstance().Free(GetDeviceData(), GetDeviceId()); } #endif lite_tensor_->set_device_data(data); diff --git a/mindspore-lite/src/litert/cxx_api/tensor/tensor_impl.h b/mindspore-lite/src/litert/cxx_api/tensor/tensor_impl.h index 165e8cffc97e3097eb17e062fc488dddc8229a2c..fccabd864497e9ae54d1c68b52d2348575990aea 100644 --- a/mindspore-lite/src/litert/cxx_api/tensor/tensor_impl.h +++ b/mindspore-lite/src/litert/cxx_api/tensor/tensor_impl.h @@ -32,7 +32,7 @@ #include "ir/api_tensor_impl.h" #include "common/mutable_tensor_impl.h" #if defined(ENABLE_CLOUD_FUSION_INFERENCE) || defined(ENABLE_CLOUD_INFERENCE) -#include "src/extendrt/kernel/ascend/plugin/ascend_allocator_plugin.h" +#include "src/extendrt/delegate/ascend_acl/ascend_allocator_plugin.h" #endif namespace mindspore { @@ -49,7 +49,7 @@ class LiteTensorImpl : public MutableTensorImpl { #if defined(ENABLE_CLOUD_FUSION_INFERENCE) || defined(ENABLE_CLOUD_INFERENCE) if (GetDeviceData() != nullptr && own_data_) { MS_LOG(INFO) << "free device data in tensor impl."; - kernel::AscendAllocatorPlugin::GetInstance().Free(GetDeviceData(), GetDeviceId()); + AscendAllocatorPlugin::GetInstance().Free(GetDeviceData(), GetDeviceId()); lite_tensor_->set_device_data(nullptr); } #endif @@ -143,7 +143,7 @@ class LiteTensorImpl : public MutableTensorImpl { void *device_data = GetDeviceData(); if (device_data != nullptr && own_data_) { MS_LOG(INFO) << "free device data in tensor impl."; - kernel::AscendAllocatorPlugin::GetInstance().Free(device_data, GetDeviceId()); + AscendAllocatorPlugin::GetInstance().Free(device_data, GetDeviceId()); } #endif lite_tensor_->set_device(device); diff --git a/mindspore-lite/src/litert/cxx_api/types.cc b/mindspore-lite/src/litert/cxx_api/types.cc index 7a827a67086be1e0676691262a43b7557fa0ab0b..ea5bc13eb5b7c11627475c1aebc9cf8b1ea53170 100644 --- a/mindspore-lite/src/litert/cxx_api/types.cc +++ b/mindspore-lite/src/litert/cxx_api/types.cc @@ -28,7 +28,7 @@ #include "utils/file_utils.h" #include "ir/dtype.h" #include "utils/convert_utils_base.h" -#include "extendrt/kernel/ascend/plugin/ascend_allocator_plugin.h" +#include "src/extendrt/delegate/ascend_acl/ascend_allocator_plugin.h" #endif namespace mindspore { @@ -117,28 +117,28 @@ MSTensor *MSTensor::CreateTensor(const std::vector &name, enum DataType ty auto device_type = CharToString(device); if (!device_type.empty() && device_type == "ascend") { #ifdef ENABLE_CLOUD_INFERENCE - kernel::AscendAllocatorPlugin::GetInstance().Register(); + AscendAllocatorPlugin::GetInstance().Register(); // check device id - device_id = device_id == -1 ? kernel::AscendAllocatorPlugin::GetInstance().GetCurrentDeviceId() : device_id; + device_id = device_id == -1 ? AscendAllocatorPlugin::GetInstance().GetCurrentDeviceId() : device_id; // check device data size size_t element_size = CalTensorDataSize(shape, type); MS_CHECK_FALSE_MSG(data_len != 0 && element_size != data_len, nullptr, "data len not equal element size."); // malloc device data - void *device_data = kernel::AscendAllocatorPlugin::GetInstance().Malloc(element_size, device_id); + void *device_data = AscendAllocatorPlugin::GetInstance().Malloc(element_size, device_id); MS_CHECK_TRUE_MSG(device_data != nullptr, nullptr, "malloc device data failed."); // create tensor auto impl = LiteTensorImpl::CreateTensorImpl(CharToString(name), type, shape, nullptr, 0); if (impl == nullptr) { - kernel::AscendAllocatorPlugin::GetInstance().Free(device_data, device_id); + AscendAllocatorPlugin::GetInstance().Free(device_data, device_id); MS_LOG(ERROR) << "Allocate tensor impl failed."; return nullptr; } if (data != nullptr) { // init device data by host data buf - auto status = kernel::AscendAllocatorPlugin::GetInstance().CopyHostDataToDevice(const_cast(data), + auto status = AscendAllocatorPlugin::GetInstance().CopyHostDataToDevice(const_cast(data), device_data, element_size); if (status != kSuccess) { - kernel::AscendAllocatorPlugin::GetInstance().Free(device_data, device_id); + AscendAllocatorPlugin::GetInstance().Free(device_data, device_id); MS_LOG(ERROR) << "copy host data to device data failed."; return nullptr; } @@ -151,7 +151,7 @@ MSTensor *MSTensor::CreateTensor(const std::vector &name, enum DataType ty auto ms_tensor = new (std::nothrow) MSTensor(impl); if (ms_tensor == nullptr) { - kernel::AscendAllocatorPlugin::GetInstance().Free(device_data, device_id); + AscendAllocatorPlugin::GetInstance().Free(device_data, device_id); MS_LOG(ERROR) << "Allocate MSTensor failed."; return nullptr; } @@ -203,7 +203,7 @@ MSTensor *MSTensor::CreateTensor(const std::vector &name, enum DataType ty MSTensor *MSTensor::CreateTensor(const std::vector &name, const MSTensor &tensor, const std::vector &device, int device_id) noexcept { #ifdef ENABLE_CLOUD_INFERENCE - kernel::AscendAllocatorPlugin::GetInstance().Register(); + AscendAllocatorPlugin::GetInstance().Register(); auto dst_device_type = CharToString(device); if (!dst_device_type.empty() && dst_device_type != "ascend") { MS_LOG(ERROR) << "only support create ascend device tensor."; @@ -233,7 +233,7 @@ MSTensor *MSTensor::CreateTensor(const std::vector &name, const MSTensor & auto new_tensor = CreateTensor(tensor.Name(), tensor.DataType(), tensor.Shape(), nullptr, tensor.DataSize(), "ascend", device_id); MS_CHECK_FALSE_MSG(new_tensor == nullptr, nullptr, "create new device tensor failed."); - auto status = kernel::AscendAllocatorPlugin::GetInstance().CopyDeviceDataToDevice( + auto status = AscendAllocatorPlugin::GetInstance().CopyDeviceDataToDevice( static_cast(tensor).GetDeviceData(), new_tensor->GetDeviceData(), new_tensor->DataSize(), tensor.DataSize(), tensor.GetDeviceId(), new_tensor->GetDeviceId()); if (status != kSuccess) { @@ -248,7 +248,7 @@ MSTensor *MSTensor::CreateTensor(const std::vector &name, const MSTensor & MS_LOG(INFO) << "copy device tensor to host tensor."; auto host_form_device = malloc(tensor.DataSize()); MS_CHECK_FALSE_MSG(host_form_device == nullptr, nullptr, "malloc host buf failed."); - auto status = kernel::AscendAllocatorPlugin::GetInstance().CopyDeviceDataToHost( + auto status = AscendAllocatorPlugin::GetInstance().CopyDeviceDataToHost( static_cast(tensor).GetDeviceData(), host_form_device, tensor.DataSize(), tensor.GetDeviceId()); if (status != kSuccess) { free(host_form_device); diff --git a/mindspore-lite/src/litert/inner_allocator.cc b/mindspore-lite/src/litert/inner_allocator.cc index f96d7d65d26314d4f5e4b1bd841d88a255e7eaf7..9bd0c017b6ca06aff7e81e9a335d7b0d6264d65e 100644 --- a/mindspore-lite/src/litert/inner_allocator.cc +++ b/mindspore-lite/src/litert/inner_allocator.cc @@ -50,7 +50,6 @@ bool DefaultAllocator::ReuseMemory(size_t free_size, size_t size) const { } void *DefaultAllocator::Malloc(size_t size) { - Lock(); if (size > max_malloc_size_) { MS_LOG(ERROR) << "MallocData out of max_size, size: " << size; return nullptr; @@ -59,6 +58,7 @@ void *DefaultAllocator::Malloc(size_t size) { MS_LOG(ERROR) << "Memory pool is exhausted"; return nullptr; } + Lock(); auto iter = freeList_.lower_bound(size); if (iter != freeList_.end() && ReuseMemory(iter->second->size, size)) { auto membuf = iter->second; diff --git a/mindspore-lite/src/litert/inner_context.cc b/mindspore-lite/src/litert/inner_context.cc index 91fbbd18e84f0570189d94ae132194a57dc255f9..6ac4ee3ccee21bdc2937975db4339f2a5861cc8c 100644 --- a/mindspore-lite/src/litert/inner_context.cc +++ b/mindspore-lite/src/litert/inner_context.cc @@ -70,7 +70,7 @@ int InnerContext::CreateThreadPool(bool is_control_flow) { actor_thread_num_, inter_op_parallel_num_, thread_num_, bind_mode_, affinity_core_list_, runner_id_); if (thread_pool_ == nullptr) { #ifdef ENABLE_MINDRT -#ifndef MS_COMPILE_IOS + #ifndef MS_COMPILE_IOS if (inter_op_parallel_num_ > 1) { thread_pool_ = ParallelThreadPool::CreateThreadPool(this->inter_op_parallel_num_, this->thread_num_, this->affinity_core_list_, bind_mode_, runner_id_); diff --git a/mindspore-lite/src/litert/kernel/ascend/CMakeLists.txt b/mindspore-lite/src/litert/kernel/ascend/CMakeLists.txt index 5298fe591e8248e1b02e12996926d4557238db05..67ecbaedff60846b470dedace20753cd7f7d3889 100644 --- a/mindspore-lite/src/litert/kernel/ascend/CMakeLists.txt +++ b/mindspore-lite/src/litert/kernel/ascend/CMakeLists.txt @@ -1,22 +1,22 @@ -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ORIGIN/") -find_library(ge_graph libgraph.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) +# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ORIGIN/") +# find_library(ge_graph libgraph.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) -aux_source_directory(src ACL_SRC) -aux_source_directory(plugin ACL_SRC) -add_library(ascend_kernel_plugin SHARED ${ACL_SRC}) -add_dependencies(ascend_kernel_plugin fbs_inner_src) -add_dependencies(ascend_kernel_plugin mindspore-lite) -target_link_libraries(ascend_kernel_plugin mindspore-lite _mindspore_ascend_symbol_obj) +# aux_source_directory(src ACL_SRC) +# aux_source_directory(plugin ACL_SRC) +# add_library(ascend_kernel_plugin SHARED ${ACL_SRC}) +# add_dependencies(ascend_kernel_plugin fbs_inner_src) +# add_dependencies(ascend_kernel_plugin mindspore-lite) +# target_link_libraries(ascend_kernel_plugin mindspore-lite _mindspore_ascend_symbol_obj) -if("${MSLITE_REGISTRY_DEVICE}" STREQUAL "SD3403" AND PLATFORM_ARM64) - find_library(ge_graph libgraph.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - find_library(acl libascendcl.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - find_library(acl_retr libacl_retr.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - find_library(acl_cblas libacl_cblas.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - find_library(acl_runtime libruntime.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - target_link_libraries(ascend_kernel_plugin ${ge_graph} ${acl} ${acl_retr} ${acl_cblas} ${acl_runtime}) -else() - target_link_libraries(ascend_kernel_plugin ${ge_graph} ${ge_compiler} - ${acl_retr} ${acl_cblas} ${acl_dvpp} ${acl_runtime} ${libplatform} - ${libcompress} ${libopskernel} ${libaicore_utils} ${libaicpu_engine_common} ${acl}) -endif() \ No newline at end of file +# if("${MSLITE_REGISTRY_DEVICE}" STREQUAL "SD3403" AND PLATFORM_ARM64) +# find_library(ge_graph libgraph.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) +# find_library(acl libascendcl.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) +# find_library(acl_retr libacl_retr.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) +# find_library(acl_cblas libacl_cblas.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) +# find_library(acl_runtime libruntime.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) +# target_link_libraries(ascend_kernel_plugin ${ge_graph} ${acl} ${acl_retr} ${acl_cblas} ${acl_runtime}) +# else() +# target_link_libraries(ascend_kernel_plugin ${ge_graph} ${ge_compiler} +# ${acl_retr} ${acl_cblas} ${acl_dvpp} ${acl_runtime} ${libplatform} +# ${libcompress} ${libopskernel} ${libaicore_utils} ${libaicpu_engine_common} ${acl}) +# endif() \ No newline at end of file diff --git a/mindspore-lite/src/litert/kernel/ascend/src/model_infer.cc b/mindspore-lite/src/litert/kernel/ascend/src/model_infer.cc index f2cc77a98f9ae6dccb4d05dcb406e152b94b8137..601480f133fc8240242b621d1c13bb5526f3c7f2 100644 --- a/mindspore-lite/src/litert/kernel/ascend/src/model_infer.cc +++ b/mindspore-lite/src/litert/kernel/ascend/src/model_infer.cc @@ -30,16 +30,16 @@ constexpr auto kModelSharingKey = "multi_model_sharing_mem"; ModelInfer::ModelInfer(const Buffer &om_data, const AclModelOptions &options, const std::map &config_info) - : init_flag_(false), - load_flag_(false), - device_type_("AscendCL"), - context_(nullptr), - om_data_(om_data), - options_(options), - model_process_(options), - config_info_(config_info), - acl_env_(nullptr), - device_id_(0) {} + : init_flag_(false), + load_flag_(false), + device_type_("AscendCL"), + context_(nullptr), + om_data_(om_data), + options_(options), + model_process_(options), + config_info_(config_info), + acl_env_(nullptr), + device_id_(0) {} STATUS ModelInfer::Init() { if (init_flag_) { diff --git a/mindspore-lite/src/litert/kernel/cpu/base/arithmetic_base.cc b/mindspore-lite/src/litert/kernel/cpu/base/arithmetic_base.cc index 4455901e78e3aa00be870ac8980fcf2a1ade6bfe..1f6da27c009eba0ec5aed51c61ec0eefbb4796a7 100644 --- a/mindspore-lite/src/litert/kernel/cpu/base/arithmetic_base.cc +++ b/mindspore-lite/src/litert/kernel/cpu/base/arithmetic_base.cc @@ -226,7 +226,7 @@ int ArithmeticBaseCPUKernel::UpdateParameter() { int ArithmeticBaseCPUKernel::BroadCastConstTensor() { CalcMultiplesAndStrides(param_); -#ifdef PARALLEL_INFERENCE +#ifdef MSLITE_ENABLE_CLOUD_INFERENCE bool prefer_explicit_broadcast = false; #else bool prefer_explicit_broadcast = param_->ndim_ != 1; diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/arithmetic_fp32.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/arithmetic_fp32.cc index 9794faf021520ae196c62213f036e1bd16abadc5..36adb54601df4eed2d9802194e26700f6f71415f 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/arithmetic_fp32.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/arithmetic_fp32.cc @@ -179,7 +179,6 @@ int ArithmeticCPUKernel::DoExecute(const void *input0, const void *input1, void return ret; } -#ifdef SERVER_INFERENCE REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_MulFusion, LiteKernelCreator) REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_MulFusion, LiteKernelCreator) REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_AddFusion, LiteKernelCreator) @@ -207,5 +206,4 @@ REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_FloorMod, LiteKernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Eltwise, LiteKernelCreator) REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_DivFusion, LiteKernelCreator) -#endif } // namespace mindspore::kernel diff --git a/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_base.cc b/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_base.cc index 38b72c3932eb4bb27da70a47a4766fc6db01a83b..ae6735dfffa0a6cddda5cda7dcac23680413b413 100644 --- a/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_base.cc +++ b/mindspore-lite/src/litert/kernel/cpu/fp32/matmul_fp32_base.cc @@ -19,7 +19,7 @@ #include "nnacl/fp32/matmul_fp32.h" #include "nnacl/fp32/pack_fp32.h" #include "nnacl/fp32/pack_fp32_opt.h" -#if defined(PARALLEL_INFERENCE) && defined(ENABLE_MINDRT) +#if defined(MSLITE_ENABLE_CLOUD_INFERENCE) && defined(ENABLE_MINDRT) #include "thread/parallel_thread_pool_manager.h" #endif @@ -752,7 +752,7 @@ int MatmulFp32BaseCPUKernel::InitTmpOutBuffer() { } int MatmulFp32BaseCPUKernel::GetThreadCuttingPolicy() { -#if defined(PARALLEL_INFERENCE) && defined(ENABLE_MINDRT) +#if defined(MSLITE_ENABLE_CLOUD_INFERENCE) && defined(ENABLE_MINDRT) constexpr int kNumDeepThreshold = 512; if (params_->deep_ < kNumDeepThreshold) { auto num = ParallelThreadPoolManager::GetInstance()->GetThreadPoolSize( diff --git a/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_matmul.cc b/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_matmul.cc index 187c95acbe18118f65ce0d119569b474dded0959..e7036497dca701ce797c5bf1bd083cb4844dcdc0 100644 --- a/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_matmul.cc +++ b/mindspore-lite/src/litert/kernel/cpu/nnacl/nnacl_matmul.cc @@ -20,7 +20,7 @@ #include "nnacl/kernel/matmul_base.h" #include "nnacl/cxx_utils.h" #include "src/litert/pack_weight_manager.h" -#if defined(PARALLEL_INFERENCE) && defined(ENABLE_MINDRT) +#if defined(MSLITE_ENABLE_CLOUD_INFERENCE) && defined(ENABLE_MINDRT) #include "thread/parallel_thread_pool_manager.h" #endif @@ -60,7 +60,7 @@ int MatmulKernel::ReSize() { MatmulStruct *matmul = reinterpret_cast(kernel_); matmul->model_thread_nr_ = kernel_->thread_nr_; -#if defined(PARALLEL_INFERENCE) && defined(ENABLE_MINDRT) +#if defined(MSLITE_ENABLE_CLOUD_INFERENCE) && defined(ENABLE_MINDRT) auto num = ParallelThreadPoolManager::GetInstance()->GetThreadPoolSize(ms_context_->thread_pool_); matmul->model_thread_nr_ = (num != -1) ? (num) : (kernel_->thread_nr_); #endif diff --git a/mindspore-lite/src/litert/lite_session.cc b/mindspore-lite/src/litert/lite_session.cc index f736d67193b6c1a900c41810a3aee65e2a2b5c4b..cf47bb98e81372e28fca5ca8e11ea74922417dc5 100644 --- a/mindspore-lite/src/litert/lite_session.cc +++ b/mindspore-lite/src/litert/lite_session.cc @@ -60,10 +60,10 @@ #endif #include "src/litert/runtime_convert.h" #include "extendrt/mindir_loader/model_loader.h" -#ifndef __ANDROID__ -#include "kernel/ascend/plugin/ascend_kernel_plugin.h" -#endif -#if defined(PARALLEL_INFERENCE) && defined(ENABLE_MINDRT) +// #ifndef __ANDROID__ +// #include "kernel/ascend/plugin/ascend_kernel_plugin.h" +// #endif +#if defined(ENABLE_CLOUD_INFERENCE) && defined(ENABLE_MINDRT) #include "thread/parallel_thread_pool_manager.h" #endif #include "src/litert/runtime_packed_node_pass.h" @@ -806,15 +806,6 @@ int LiteSession::PrepareKernels(const Model *model) { } } } - -#if (defined DEBUG) && (defined MSLITE_EXPORT_COMPUTE_IR) - auto subgraph_kernel = static_cast(kernel); - ret = DrawGraph(subgraph_kernel); - if (ret != RET_OK) { - MS_LOG(ERROR) << "graph: " << kernel->name() << " draw failed."; - } -#endif - ret = kernel->Prepare(); if (ret != RET_OK) { MS_LOG(ERROR) << "Prepare kernel " << kernel->name() << " failed: " << ret; @@ -871,7 +862,7 @@ int LiteSession::RunGraph(const KernelCallBack &before, const KernelCallBack &af MS_LOG(ERROR) << "Not support multi-threading"; return RET_ERROR; } -#if defined(PARALLEL_INFERENCE) && defined(ENABLE_MINDRT) +#if defined(MSLITE_ENABLE_CLOUD_INFERENCE) && defined(ENABLE_MINDRT) ParallelThreadPoolManager::GetInstance()->ActivatePool(runner_id_, worker_id_); #endif STATUS ret = CheckTensorsInvalid(inputs_); @@ -897,7 +888,7 @@ int LiteSession::RunGraph(const KernelCallBack &before, const KernelCallBack &af input->set_shape_changed(false); } } -#if defined(PARALLEL_INFERENCE) && defined(ENABLE_MINDRT) +#if defined(MSLITE_ENABLE_CLOUD_INFERENCE) && defined(ENABLE_MINDRT) ParallelThreadPoolManager::GetInstance()->SetFreePool(runner_id_, worker_id_); #endif is_running_.store(false); @@ -930,7 +921,7 @@ int LiteSession::InitSharedThreadPool() { MS_LOG(INFO) << "runner id: " << runner_id_ << " enable_shared_pool: " << enable_shared_pool << " workers_num: " << workers_num << " thread_num_limit: " << thread_num_limit << " remaining_thread_num: " << remaining_thread_num; -#if defined(PARALLEL_INFERENCE) && defined(ENABLE_MINDRT) +#if defined(MSLITE_ENABLE_CLOUD_INFERENCE) && defined(ENABLE_MINDRT) ParallelThreadPoolManager::GetInstance()->Init(enable_shared_pool, runner_id_, workers_num, remaining_thread_num, thread_num_limit); #endif @@ -961,7 +952,7 @@ int LiteSession::InitContext(const std::shared_ptr &context) { context_->thread_pool_->SetMinSpinCount(kDefaulLiteIosSpinCount); #endif -#if defined(PARALLEL_INFERENCE) && defined(ENABLE_MINDRT) +#if defined(MSLITE_ENABLE_CLOUD_INFERENCE) && defined(ENABLE_MINDRT) if (context_->inter_op_parallel_num_ > 1 && !runner_id_.empty() && ParallelThreadPoolManager::GetInstance()->GetEnableSharedThreadPool(runner_id_)) { MS_LOG(INFO) << "Enable subgraph parallelism and enable thread pool sharing"; @@ -972,17 +963,17 @@ int LiteSession::InitContext(const std::shared_ptr &context) { return RET_OK; } -int LiteSession::InitAscend(const std::shared_ptr &context) { -#ifndef __ANDROID__ - if (!context->IsDeviceTypeEnabled(DT_ASCEND)) { - MS_LOG(INFO) << "There is no Ascend device type."; - return RET_OK; - } - return mindspore::AscendKernelPlugin::GetInstance().Register(); -#else - return RET_OK; -#endif -} +// int LiteSession::InitAscend(const std::shared_ptr &context) { +// #ifndef __ANDROID__ +// if (!context->IsDeviceTypeEnabled(DT_ASCEND)) { +// MS_LOG(INFO) << "There is no Ascend device type."; +// return RET_OK; +// } +// return mindspore::AscendKernelPlugin::GetInstance().Register(); +// #else +// return RET_OK; +// #endif +// } int LiteSession::CreateTensorRTDelegate() { #ifdef GPU_TENSORRT @@ -1163,12 +1154,12 @@ int LiteSession::Init(const std::shared_ptr &context) { return ret; } - ret = InitAscend(context); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Open Ascend kernel plugin failed"; - is_running_.store(false); - return ret; - } + // ret = InitAscend(context); + // if (ret != RET_OK) { + // MS_LOG(ERROR) << "Open Ascend kernel plugin failed"; + // is_running_.store(false); + // return ret; + // } ret = InitDelegate(); if (ret != RET_OK) { @@ -1252,7 +1243,7 @@ LiteSession::~LiteSession() { #endif delete ms_context_; ms_context_ = nullptr; -#if defined(PARALLEL_INFERENCE) && defined(ENABLE_MINDRT) +#if defined(MSLITE_ENABLE_CLOUD_INFERENCE) && defined(ENABLE_MINDRT) ParallelThreadPoolManager::GetInstance()->ResetParallelThreadPoolManager(runner_id_); #endif lite::PackWeightManager::GetInstance()->FreePackWeight(runner_id_, model_id_); diff --git a/mindspore-lite/src/litert/pass/format_pass/format_pass.cc b/mindspore-lite/src/litert/pass/format_pass/format_pass.cc index c18ae7f2ebfc1b9f8643cf4b6013a4a50a3f1246..d9dc58e1c4ca7f4406e3a342157770c4fad02eef 100644 --- a/mindspore-lite/src/litert/pass/format_pass/format_pass.cc +++ b/mindspore-lite/src/litert/pass/format_pass/format_pass.cc @@ -21,8 +21,6 @@ #include "src/litert/kernel_registry.h" #include "nnacl/format_transpose_parameter.h" #endif -#include "src/common/draw/drawer.h" - namespace mindspore::lite::pass { #ifdef ENABLE_MULTI_LAYOUT namespace { @@ -80,7 +78,6 @@ int FormatOptimize::RunPass(kernel::SubGraphKernel *graph, std::vector MS_LOG(ERROR) << "Run pass failed"; return status; } - DrawDot(graph, pass->name()); } return RET_OK; } diff --git a/mindspore-lite/src/litert/pass/online_fusion/cast_gather_reduce_fusion_pass.cc b/mindspore-lite/src/litert/pass/online_fusion/cast_gather_reduce_fusion_pass.cc index 2c29023f85192312ad26fefe5d8c1b61fe577284..355c50937f1ba71ed42b4e397f0e99eee72556e0 100644 --- a/mindspore-lite/src/litert/pass/online_fusion/cast_gather_reduce_fusion_pass.cc +++ b/mindspore-lite/src/litert/pass/online_fusion/cast_gather_reduce_fusion_pass.cc @@ -129,7 +129,6 @@ int CastGatherReduceOnlineFusionPass::CreateCastGatherReduceCustomNode(LiteGraph if (online_fusion_prim == nullptr) { MS_LOG(ERROR) << "GetRoot CastGatherReduceFusion primitive failed."; free(prim); - fbb.Clear(); return RET_ERROR; } fbb.Clear(); diff --git a/mindspore-lite/src/litert/pass/online_fusion/reduce_concat_fusion_pass.cc b/mindspore-lite/src/litert/pass/online_fusion/reduce_concat_fusion_pass.cc index 431534328c835632c5c4ca2fb0c9cd3e13ec13eb..986e319552a8159fb48dc8a8a2e27ea206d4b1a7 100644 --- a/mindspore-lite/src/litert/pass/online_fusion/reduce_concat_fusion_pass.cc +++ b/mindspore-lite/src/litert/pass/online_fusion/reduce_concat_fusion_pass.cc @@ -137,7 +137,6 @@ int ReduceConcatOnlineFusionPass::CreateReduceConcatCustomNode(LiteGraph::Node * if (online_fusion_prim == nullptr) { MS_LOG(ERROR) << "GetRoot ReduceConcatFusion primitive failed."; free(prim); - fbb.Clear(); return RET_ERROR; } fbb.Clear(); diff --git a/mindspore-lite/src/litert/pass/online_fusion/split_reduce_concat_fusion_pass.cc b/mindspore-lite/src/litert/pass/online_fusion/split_reduce_concat_fusion_pass.cc index 18a835e6c72cd9fda4b9fcfbf17ce6b6ba21cd7e..905b3fc30f949a0a5290331715917b4a019959c0 100644 --- a/mindspore-lite/src/litert/pass/online_fusion/split_reduce_concat_fusion_pass.cc +++ b/mindspore-lite/src/litert/pass/online_fusion/split_reduce_concat_fusion_pass.cc @@ -235,7 +235,6 @@ int SplitReduceConcatOnlineFusionPass::CreateCustomNode(LiteGraph::Node *node, S if (online_fusion_prim == nullptr) { MS_LOG(ERROR) << "GetRoot SplitReduceConcatFusion primitive failed."; free(prim); - fbb.Clear(); return RET_ERROR; } fbb.Clear(); diff --git a/mindspore-lite/src/litert/scheduler.cc b/mindspore-lite/src/litert/scheduler.cc index 008d27d8c75416602e97f5241ad8552a08e957ce..93d2105f7e7fc90e240654ba92c7a9606a97c583 100644 --- a/mindspore-lite/src/litert/scheduler.cc +++ b/mindspore-lite/src/litert/scheduler.cc @@ -55,7 +55,7 @@ #include "include/registry/register_kernel_interface.h" #include "extendrt/mindir_loader/abstract_base_model.h" #include "src/litert/pack_weight_manager.h" -#if defined(PARALLEL_INFERENCE) && defined(ENABLE_MINDRT) +#if defined(MSLITE_ENABLE_CLOUD_INFERENCE) && defined(ENABLE_MINDRT) #include "thread/parallel_thread_pool_manager.h" #endif @@ -1030,7 +1030,7 @@ int Scheduler::FindCpuKernel(const std::vector &in_tensors, const std: } } -#if defined(PARALLEL_INFERENCE) && defined(ENABLE_MINDRT) +#if defined(MSLITE_ENABLE_CLOUD_INFERENCE) && defined(ENABLE_MINDRT) // reset op task num, The number of operator segmentation tasks is not necessarily equal to the number of threads int thread_num_limit = ParallelThreadPoolManager::GetInstance()->GetTaskNum(config_info_); if (thread_num_limit != -1 && IsSharedThreadPoolOp(op_type)) { diff --git a/mindspore-lite/src/litert/thread_pool_reuse_manager.cc b/mindspore-lite/src/litert/thread_pool_reuse_manager.cc index e93aadc282fc8edcdba9fdb0415e552ca1ed710f..f29148c45f00d6b2089daee27f288155b7f80a4a 100644 --- a/mindspore-lite/src/litert/thread_pool_reuse_manager.cc +++ b/mindspore-lite/src/litert/thread_pool_reuse_manager.cc @@ -40,25 +40,7 @@ ThreadPoolReuseManager::~ThreadPoolReuseManager() { ThreadPool *ThreadPoolReuseManager::GetThreadPool(size_t actor_num, size_t inter_op_parallel_num, size_t thread_num, BindMode bind_mode, const std::vector &core_list, std::string runner_id) { -#ifdef SERVER_INFERENCE - auto hash_key = ComputeHash(actor_num, inter_op_parallel_num, thread_num, bind_mode, core_list); - std::lock_guard lock(l); - if (thread_pool_container_.find(hash_key) == thread_pool_container_.end()) { - return nullptr; - } - if (thread_pool_container_[hash_key].empty()) { - return nullptr; - } - auto thread_pool = thread_pool_container_[hash_key].back(); - if (inter_op_parallel_num > 1 && !thread_pool->SetRunnerID(runner_id)) { - MS_LOG(WARNING) << "can not reuse thread pool."; - return nullptr; - } - thread_pool_container_[hash_key].pop_back(); - return thread_pool; -#else return nullptr; -#endif } void ThreadPoolReuseManager::RetrieveThreadPool(size_t actor_num, size_t inter_op_parallel_num, size_t thread_num, @@ -67,13 +49,7 @@ void ThreadPoolReuseManager::RetrieveThreadPool(size_t actor_num, size_t inter_o if (thread_pool == nullptr) { return; } -#ifdef SERVER_INFERENCE - auto hash_key = ComputeHash(actor_num, inter_op_parallel_num, thread_num, bind_mode, core_list); - std::lock_guard lock(l); - thread_pool_container_[hash_key].push_back(thread_pool); -#else delete thread_pool; -#endif } std::string ThreadPoolReuseManager::ComputeHash(size_t actor_num, size_t inter_op_parallel_num, size_t thread_num, diff --git a/mindspore-lite/src/runtime/kernel/opencl/cl/activation.cl.inc b/mindspore-lite/src/runtime/kernel/opencl/cl/activation.cl.inc new file mode 100644 index 0000000000000000000000000000000000000000..bf17fe92af852809d13b4c847ac6f188a68e1a2d --- /dev/null +++ b/mindspore-lite/src/runtime/kernel/opencl/cl/activation.cl.inc @@ -0,0 +1,108 @@ +static const char *activation_source ="\n" +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" \ +"\n" \ +"__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" \ +"\n" \ +"__kernel void LeakyRelu(__read_only image2d_t input, __write_only image2d_t output, const int2 img_shape,\n" \ +" const float alpha) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= img_shape.x || Y >= img_shape.y) return;\n" \ +" FLT4 in_c4 = READ_IMAGE(input, smp_zero, (int2)(X, Y));\n" \ +" FLT4 tmp;\n" \ +" FLT alpha_f = TO_FLT(alpha);\n" \ +" tmp.x = in_c4.x > 0.0f ? in_c4.x : in_c4.x * alpha_f;\n" \ +" tmp.y = in_c4.y > 0.0f ? in_c4.y : in_c4.y * alpha_f;\n" \ +" tmp.z = in_c4.z > 0.0f ? in_c4.z : in_c4.z * alpha_f;\n" \ +" tmp.w = in_c4.w > 0.0f ? in_c4.w : in_c4.w * alpha_f;\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), tmp);\n" \ +"}\n" \ +"\n" \ +"__kernel void Relu(__read_only image2d_t input, __write_only image2d_t output, const int2 img_shape) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= img_shape.x || Y >= img_shape.y) return;\n" \ +" FLT4 in_c4 = READ_IMAGE(input, smp_zero, (int2)(X, Y));\n" \ +" in_c4 = max(in_c4, (FLT)(0.f));\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), in_c4);\n" \ +"}\n" \ +"\n" \ +"__kernel void Relu6(__read_only image2d_t input, __write_only image2d_t output, const int2 img_shape) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= img_shape.x || Y >= img_shape.y) return;\n" \ +" FLT4 in_c4 = READ_IMAGE(input, smp_zero, (int2)(X, Y));\n" \ +" in_c4 = clamp(in_c4, (FLT)(0.f), (FLT)(6.f));\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), in_c4);\n" \ +"}\n" \ +"\n" \ +"__kernel void Sigmoid(__read_only image2d_t input, __write_only image2d_t output, const int2 img_shape, const int c4,\n" \ +" const int last_c4) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= img_shape.x || Y >= img_shape.y || c4 == 0) return;\n" \ +" int C4 = X % c4;\n" \ +" FLT4 in_c4 = READ_IMAGE(input, smp_zero, (int2)(X, Y));\n" \ +" if (C4 < c4 - 1) {\n" \ +" in_c4 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-in_c4));\n" \ +" } else {\n" \ +" in_c4.x = (FLT)(1.f) / ((FLT)(1.f) + exp(-in_c4.x));\n" \ +" if (last_c4 > 1) {\n" \ +" in_c4.y = (FLT)(1.f) / ((FLT)(1.f) + exp(-in_c4.y));\n" \ +" }\n" \ +" if (last_c4 > 2) {\n" \ +" in_c4.z = (FLT)(1.f) / ((FLT)(1.f) + exp(-in_c4.z));\n" \ +" }\n" \ +" if (last_c4 > 3) {\n" \ +" in_c4.w = (FLT)(1.f) / ((FLT)(1.f) + exp(-in_c4.w));\n" \ +" }\n" \ +" }\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), in_c4);\n" \ +"}\n" \ +"\n" \ +"__kernel void Tanh(__read_only image2d_t input, __write_only image2d_t output, const int2 img_shape) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= img_shape.x || Y >= img_shape.y) return;\n" \ +" FLT4 in_c4 = READ_IMAGE(input, smp_zero, (int2)(X, Y));\n" \ +" FLT4 exp0 = exp(in_c4);\n" \ +" FLT4 exp1 = exp(-in_c4);\n" \ +" in_c4 = (exp0 - exp1) / (exp0 + exp1);\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), in_c4);\n" \ +"}\n" \ +"\n" \ +"__kernel void Swish(__read_only image2d_t input, __write_only image2d_t output, const int2 img_shape) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= img_shape.x || Y >= img_shape.y) return;\n" \ +" FLT4 in_c4 = READ_IMAGE(input, smp_zero, (int2)(X, Y));\n" \ +" in_c4 = in_c4 * ((FLT4)(1.f) / ((FLT4)(1.f) + exp(-in_c4)));\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), in_c4);\n" \ +"}\n" \ +"\n" \ +"__kernel void HSwish(__read_only image2d_t input, __write_only image2d_t output, const int2 img_shape) {\n" \ +" int X = get_global_id(0); // w*c\n" \ +" int Y = get_global_id(1); // n*h\n" \ +" if (X >= img_shape.x || Y >= img_shape.y) return;\n" \ +" FLT4 temp = READ_IMAGE(input, smp_zero, (int2)(X, Y));\n" \ +" FLT4 result = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" result.x = temp.x * (temp.x <= -3 ? 0 : (temp.x >= 3 ? 1 : temp.x / 6 + 0.5f));\n" \ +" result.y = temp.y * (temp.y <= -3 ? 0 : (temp.y >= 3 ? 1 : temp.y / 6 + 0.5f));\n" \ +" result.z = temp.z * (temp.z <= -3 ? 0 : (temp.z >= 3 ? 1 : temp.z / 6 + 0.5f));\n" \ +" result.w = temp.w * (temp.w <= -3 ? 0 : (temp.w >= 3 ? 1 : temp.w / 6 + 0.5f));\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void HSigmoid(__read_only image2d_t input, __write_only image2d_t output, const int2 img_shape) {\n" \ +" int X = get_global_id(0); // w*c\n" \ +" int Y = get_global_id(1); // n*h\n" \ +" if (X >= img_shape.x || Y >= img_shape.y) return;\n" \ +" FLT4 temp = READ_IMAGE(input, smp_zero, (int2)(X, Y));\n" \ +" FLT4 result = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" result.x = temp.x <= -3 ? 0 : (temp.x >= 3 ? 1 : temp.x / 6 + 0.5f);\n" \ +" result.y = temp.y <= -3 ? 0 : (temp.y >= 3 ? 1 : temp.y / 6 + 0.5f);\n" \ +" result.z = temp.z <= -3 ? 0 : (temp.z >= 3 ? 1 : temp.z / 6 + 0.5f);\n" \ +" result.w = temp.w <= -3 ? 0 : (temp.w >= 3 ? 1 : temp.w / 6 + 0.5f);\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), result);\n" \ +"}\n" \ +; diff --git a/mindspore-lite/src/runtime/kernel/opencl/cl/argminmax.cl.inc b/mindspore-lite/src/runtime/kernel/opencl/cl/argminmax.cl.inc new file mode 100644 index 0000000000000000000000000000000000000000..88c53be13783aa199d5f4471a8ffa5db252ef7be --- /dev/null +++ b/mindspore-lite/src/runtime/kernel/opencl/cl/argminmax.cl.inc @@ -0,0 +1,81 @@ +static const char *argminmax_source ="\n" +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" \ +"#define swap(a, b, c) \\\n" \ +" c = a; \\\n" \ +" a = b; \\\n" \ +" b = c;\n" \ +"#define swap_atomic(a, b, c) \\\n" \ +" c = atomic_xchg(a, *(b)); \\\n" \ +" c = atomic_xchg(b, c);\n" \ +"#define UP_ROUND(a, b) (((a + b - 1) / b) * b)\n" \ +"#define C4NUM 4\n" \ +"__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" \ +"__kernel void argminmax(__global FLT *src_data, __global FLT *dst_data, __global FLT *buf, __global int *ids,\n" \ +" int4 shape, int4 src_size, int4 cus_size, int4 strides, int4 flags) {\n" \ +" int X = get_global_id(0); // lower reduce stride\n" \ +" int Y = get_global_id(1); // upper axis accumulation\n" \ +" if (X >= src_size.x || Y >= src_size.y) {\n" \ +" return;\n" \ +" }\n" \ +" bool keep_dims = cus_size.y;\n" \ +" int width = shape.z * shape.w;\n" \ +" int offset = X + Y * src_size.z;\n" \ +" int align_c4_in = (flags.z != 3) ? (X / shape.w) * (C4NUM - shape.w & 0x00000003) : 0;\n" \ +" int align_c4_out =\n" \ +" (flags.z == 3 && flags.w == 1 && !keep_dims) ? (Y / shape.z) * (C4NUM - shape.z & 0x00000003) : align_c4_in;\n" \ +" int align_in = 0;\n" \ +" int align_out = 0;\n" \ +" if (flags.z == 3) {\n" \ +" align_in = (Y / shape.z) * cus_size.z;\n" \ +" align_out = (Y / ((flags.w > 1 || keep_dims) ? shape.z : shape.z * shape.y)) * cus_size.w;\n" \ +" }\n" \ +" if (flags.z == 0) {\n" \ +" align_in = X / (width)*cus_size.z;\n" \ +" align_out = align_in;\n" \ +" }\n" \ +" if (flags.z == 2 && !keep_dims) {\n" \ +" align_out = (Y / shape.y) * cus_size.w;\n" \ +" }\n" \ +" for (int k = 0; k < src_size.w; ++k) {\n" \ +" int idx0 = (X + k * strides.x) + Y * strides.y + (align_c4_in + align_in);\n" \ +" int idx1 = offset + k * src_size.x;\n" \ +" ids[idx1] = k;\n" \ +" buf[idx1] = src_data[idx0];\n" \ +" }\n" \ +" for (unsigned int i = 2; i <= cus_size.x; i <<= 1) {\n" \ +" for (unsigned int j = i >> 1; j > 0; j >>= 1) {\n" \ +" for (int tid = 0; tid < src_size.w; ++tid) {\n" \ +" unsigned int tid_comp = tid + j;\n" \ +" if (tid_comp < src_size.w) {\n" \ +" int lk = offset + tid * src_size.x;\n" \ +" int rk = offset + tid_comp * src_size.x;\n" \ +" if ((tid & i) == 0) { // ascending\n" \ +" if (buf[lk] > buf[rk]) {\n" \ +" FLT tmpf;\n" \ +" swap(buf[lk], buf[rk], tmpf);\n" \ +" int tmpi;\n" \ +" swap(ids[lk], ids[rk], tmpi);\n" \ +" }\n" \ +" } else { // desending\n" \ +" if (buf[lk] < buf[rk]) {\n" \ +" FLT tmpf;\n" \ +" swap(buf[lk], buf[rk], tmpf);\n" \ +" int tmpi;\n" \ +" swap(ids[lk], ids[rk], tmpi);\n" \ +" }\n" \ +" }\n" \ +" }\n" \ +" }\n" \ +" }\n" \ +" }\n" \ +" for (int k = 0; k < flags.w; ++k) {\n" \ +" int idx0 = (X + k * strides.z) + Y * strides.w + (align_c4_out + align_out);\n" \ +" int idx1 = flags.y ? (offset + (src_size.w - k - 1) * src_size.x) : (offset + k * src_size.x);\n" \ +" if (flags.x) {\n" \ +" dst_data[idx0] = buf[idx1];\n" \ +" } else {\n" \ +" dst_data[idx0] = ids[idx1];\n" \ +" }\n" \ +" }\n" \ +"}\n" \ +; diff --git a/mindspore-lite/src/runtime/kernel/opencl/cl/arithmetic.cl.inc b/mindspore-lite/src/runtime/kernel/opencl/cl/arithmetic.cl.inc new file mode 100644 index 0000000000000000000000000000000000000000..4c7b4b790cbf744c8d939aa1b7e89dc55ca4c506 --- /dev/null +++ b/mindspore-lite/src/runtime/kernel/opencl/cl/arithmetic.cl.inc @@ -0,0 +1,633 @@ +static const char *arithmetic_source ="\n" +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" \ +"#define divide_no_check(a, b) (a / b)\n" \ +"__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;\n" \ +"\n" \ +"__kernel void ElementAdd(__read_only image2d_t input_a, __read_only image2d_t input_b, __write_only image2d_t output,\n" \ +" const int2 output_shape, float act_min, float act_max) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= output_shape.x || Y >= output_shape.y) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y));\n" \ +" FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(X, Y));\n" \ +" FLT4 result = a + b;\n" \ +" result = clamp(result, (FLT)(act_min), (FLT)(act_max));\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void ElementSub(__read_only image2d_t input_a, __read_only image2d_t input_b, __write_only image2d_t output,\n" \ +" const int2 output_shape, float act_min, float act_max) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= output_shape.x || Y >= output_shape.y) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y));\n" \ +" FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(X, Y));\n" \ +" FLT4 result = a - b;\n" \ +" result = clamp(result, (FLT)(act_min), (FLT)(act_max));\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void ElementMul(__read_only image2d_t input_a, __read_only image2d_t input_b, __write_only image2d_t output,\n" \ +" const int2 output_shape, float act_min, float act_max) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= output_shape.x || Y >= output_shape.y) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y));\n" \ +" FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(X, Y));\n" \ +" FLT4 result = a * b;\n" \ +" result = clamp(result, (FLT)(act_min), (FLT)(act_max));\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void ElementDiv(__read_only image2d_t input_a, __read_only image2d_t input_b, __write_only image2d_t output,\n" \ +" const int2 output_shape, float act_min, float act_max) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= output_shape.x || Y >= output_shape.y) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y));\n" \ +" FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(X, Y));\n" \ +" FLT4 result = divide_no_check(a, b);\n" \ +" result = clamp(result, (FLT)(act_min), (FLT)(act_max));\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void ElementLogicalAnd(__read_only image2d_t input_a, __read_only image2d_t input_b,\n" \ +" __write_only image2d_t output, const int2 output_shape, float act_min, float act_max) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= output_shape.x || Y >= output_shape.y) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y));\n" \ +" FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(X, Y));\n" \ +" FLT4 result = AS_FLT4(AS_UINT4(a) & AS_UINT4(b));\n" \ +" result = clamp(result, (FLT)(act_min), (FLT)(act_max));\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void ElementLogicalOr(__read_only image2d_t input_a, __read_only image2d_t input_b,\n" \ +" __write_only image2d_t output, const int2 output_shape, float act_min, float act_max) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= output_shape.x || Y >= output_shape.y) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y));\n" \ +" FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(X, Y));\n" \ +" FLT4 result = AS_FLT4(AS_UINT4(a) | AS_UINT4(b));\n" \ +" result = clamp(result, (FLT)(act_min), (FLT)(act_max));\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void ElementMaximum(__read_only image2d_t input_a, __read_only image2d_t input_b,\n" \ +" __write_only image2d_t output, const int2 output_shape, float act_min, float act_max) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= output_shape.x || Y >= output_shape.y) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y));\n" \ +" FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(X, Y));\n" \ +" FLT4 result = max(a, b);\n" \ +" result = clamp(result, (FLT)(act_min), (FLT)(act_max));\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void ElementMinimum(__read_only image2d_t input_a, __read_only image2d_t input_b,\n" \ +" __write_only image2d_t output, const int2 output_shape, float act_min, float act_max) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= output_shape.x || Y >= output_shape.y) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y));\n" \ +" FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(X, Y));\n" \ +" FLT4 result = min(a, b);\n" \ +" result = clamp(result, (FLT)(act_min), (FLT)(act_max));\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void ElementFloorDiv(__read_only image2d_t input_a, __read_only image2d_t input_b,\n" \ +" __write_only image2d_t output, const int2 output_shape, float act_min, float act_max) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= output_shape.x || Y >= output_shape.y) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y));\n" \ +" FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(X, Y));\n" \ +" FLT4 result = floor(divide_no_check(a, b));\n" \ +" result = clamp(result, (FLT)(act_min), (FLT)(act_max));\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void ElementFloorMod(__read_only image2d_t input_a, __read_only image2d_t input_b,\n" \ +" __write_only image2d_t output, const int2 output_shape, float act_min, float act_max) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= output_shape.x || Y >= output_shape.y) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y));\n" \ +" FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(X, Y));\n" \ +" FLT4 result = a - floor(divide_no_check(a, b)) * b;\n" \ +" result = clamp(result, (FLT)(act_min), (FLT)(act_max));\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void ElementSquaredDifference(__read_only image2d_t input_a, __read_only image2d_t input_b,\n" \ +" __write_only image2d_t output, const int2 output_shape, float act_min,\n" \ +" float act_max) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= output_shape.x || Y >= output_shape.y) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y));\n" \ +" FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(X, Y));\n" \ +" FLT4 result = pown((a - b), (int4)2);\n" \ +" result = clamp(result, (FLT)(act_min), (FLT)(act_max));\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void ElementEqual(__read_only image2d_t input_a, __read_only image2d_t input_b, __write_only image2d_t output,\n" \ +" const int2 output_shape, float act_min, float act_max) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= output_shape.x || Y >= output_shape.y) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y));\n" \ +" FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(X, Y));\n" \ +" FLT4 result = a == b ? (FLT4)1.f : (FLT4).0f;\n" \ +" // error?\n" \ +" result = clamp(result, (FLT)(act_min), (FLT)(act_max));\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void ElementNotEqual(__read_only image2d_t input_a, __read_only image2d_t input_b,\n" \ +" __write_only image2d_t output, const int2 output_shape, float act_min, float act_max) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= output_shape.x || Y >= output_shape.y) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y));\n" \ +" FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(X, Y));\n" \ +" FLT4 result = a != b ? (FLT4)1.f : (FLT4).0f;\n" \ +" result = clamp(result, (FLT)(act_min), (FLT)(act_max));\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void ElementLess(__read_only image2d_t input_a, __read_only image2d_t input_b, __write_only image2d_t output,\n" \ +" const int2 output_shape, float act_min, float act_max) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= output_shape.x || Y >= output_shape.y) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y));\n" \ +" FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(X, Y));\n" \ +" FLT4 result = a < b ? (FLT4)1.f : (FLT4).0f;\n" \ +" result = clamp(result, (FLT)(act_min), (FLT)(act_max));\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void ElementLessEqual(__read_only image2d_t input_a, __read_only image2d_t input_b,\n" \ +" __write_only image2d_t output, const int2 output_shape, float act_min, float act_max) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= output_shape.x || Y >= output_shape.y) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y));\n" \ +" FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(X, Y));\n" \ +" FLT4 result = a <= b ? (FLT4)1.f : (FLT4).0f;\n" \ +" result = clamp(result, (FLT)(act_min), (FLT)(act_max));\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void ElementGreater(__read_only image2d_t input_a, __read_only image2d_t input_b,\n" \ +" __write_only image2d_t output, const int2 output_shape, float act_min, float act_max) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= output_shape.x || Y >= output_shape.y) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y));\n" \ +" FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(X, Y));\n" \ +" FLT4 result = a > b ? (FLT4)1.f : (FLT4).0f;\n" \ +" result = clamp(result, (FLT)(act_min), (FLT)(act_max));\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void ElementGreaterEqual(__read_only image2d_t input_a, __read_only image2d_t input_b,\n" \ +" __write_only image2d_t output, const int2 output_shape, float act_min,\n" \ +" float act_max) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= output_shape.x || Y >= output_shape.y) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y));\n" \ +" FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(X, Y));\n" \ +" FLT4 result = a >= b ? (FLT4)1.f : (FLT4).0f;\n" \ +" result = clamp(result, (FLT)(act_min), (FLT)(act_max));\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void BroadcastNHWC4Add(__read_only image2d_t input_a, __read_only image2d_t input_b,\n" \ +" __write_only image2d_t output, const int4 a_shape, const int4 b_shape,\n" \ +" const int4 output_shape, const int broadcastC_flag, float act_min, float act_max) {\n" \ +" int X = get_global_id(0); // C4\n" \ +" int Y = get_global_id(1); // W\n" \ +" int Z = get_global_id(2); // N * H\n" \ +" if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y * output_shape.x) {\n" \ +" return;\n" \ +" }\n" \ +" int H = Z % output_shape.y;\n" \ +" int N = Z / output_shape.y;\n" \ +" int a_c = X < a_shape.w ? X : 0;\n" \ +" int a_w = Y < a_shape.z ? Y : 0;\n" \ +" int a_h = H < a_shape.y ? H : 0;\n" \ +" int a_n = N < a_shape.x ? N : 0;\n" \ +" FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(a_w * a_shape.w + a_c, a_n * a_shape.y + a_h));\n" \ +" int b_c = X < b_shape.w ? X : 0;\n" \ +" int b_w = Y < b_shape.z ? Y : 0;\n" \ +" int b_h = H < b_shape.y ? H : 0;\n" \ +" int b_n = N < b_shape.x ? N : 0;\n" \ +" FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_n * b_shape.y + b_h));\n" \ +" FLT4 result;\n" \ +" if (broadcastC_flag == 0) {\n" \ +" result = a + b;\n" \ +" } else if (broadcastC_flag == 1) {\n" \ +" result = a.x + b;\n" \ +" } else {\n" \ +" result = a + b.x;\n" \ +" }\n" \ +" result = clamp(result, (FLT)(act_min), (FLT)(act_max));\n" \ +" WRITE_IMAGE(output, (int2)(Y * output_shape.w + X, Z), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void BroadcastNHWC4BiasAdd(__read_only image2d_t input_a, __read_only image2d_t input_b,\n" \ +" __write_only image2d_t output, const int4 a_shape, const int4 b_shape,\n" \ +" const int4 output_shape, const int broadcastC_flag, float act_min, float act_max) {\n" \ +" int X = get_global_id(0); // C4\n" \ +" int Y = get_global_id(1); // W\n" \ +" int Z = get_global_id(2); // N * H\n" \ +" if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y * output_shape.x) {\n" \ +" return;\n" \ +" }\n" \ +" int H = Z % output_shape.y;\n" \ +" int N = Z / output_shape.y;\n" \ +" int a_c = X < a_shape.w ? X : 0;\n" \ +" int a_w = Y < a_shape.z ? Y : 0;\n" \ +" int a_h = H < a_shape.y ? H : 0;\n" \ +" int a_n = N < a_shape.x ? N : 0;\n" \ +" FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(a_w * a_shape.w + a_c, a_n * a_shape.y + a_h));\n" \ +" int b_c = X < b_shape.w ? X : 0;\n" \ +" int b_w = Y < b_shape.z ? Y : 0;\n" \ +" int b_h = H < b_shape.y ? H : 0;\n" \ +" int b_n = N < b_shape.x ? N : 0;\n" \ +" FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_n * b_shape.y + b_h));\n" \ +" FLT4 result;\n" \ +" if (broadcastC_flag == 0) {\n" \ +" result = a + b;\n" \ +" } else if (broadcastC_flag == 1) {\n" \ +" result = a.x + b;\n" \ +" } else {\n" \ +" result = a + b.x;\n" \ +" }\n" \ +" result = clamp(result, (FLT)(act_min), (FLT)(act_max));\n" \ +" WRITE_IMAGE(output, (int2)(Y * output_shape.w + X, Z), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void BroadcastNHWC4Sub(__read_only image2d_t input_a, __read_only image2d_t input_b,\n" \ +" __write_only image2d_t output, const int4 a_shape, const int4 b_shape,\n" \ +" const int4 output_shape, const int broadcastC_flag, float act_min, float act_max) {\n" \ +" int X = get_global_id(0); // C4\n" \ +" int Y = get_global_id(1); // W\n" \ +" int Z = get_global_id(2); // N * H\n" \ +" if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y * output_shape.x) {\n" \ +" return;\n" \ +" }\n" \ +" int H = Z % output_shape.y;\n" \ +" int N = Z / output_shape.y;\n" \ +" int a_c = X < a_shape.w ? X : 0;\n" \ +" int a_w = Y < a_shape.z ? Y : 0;\n" \ +" int a_h = H < a_shape.y ? H : 0;\n" \ +" int a_n = N < a_shape.x ? N : 0;\n" \ +" FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(a_w * a_shape.w + a_c, a_n * a_shape.y + a_h));\n" \ +" int b_c = X < b_shape.w ? X : 0;\n" \ +" int b_w = Y < b_shape.z ? Y : 0;\n" \ +" int b_h = H < b_shape.y ? H : 0;\n" \ +" int b_n = N < b_shape.x ? N : 0;\n" \ +" FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_n * b_shape.y + b_h));\n" \ +" FLT4 result;\n" \ +" if (broadcastC_flag == 0) {\n" \ +" result = a - b;\n" \ +" } else if (broadcastC_flag == 1) {\n" \ +" result = a.x - b;\n" \ +" } else {\n" \ +" result = a - b.x;\n" \ +" }\n" \ +" result = clamp(result, (FLT)(act_min), (FLT)(act_max));\n" \ +" WRITE_IMAGE(output, (int2)(Y * output_shape.w + X, Z), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void BroadcastNHWC4Mul(__read_only image2d_t input_a, __read_only image2d_t input_b,\n" \ +" __write_only image2d_t output, const int4 a_shape, const int4 b_shape,\n" \ +" const int4 output_shape, const int broadcastC_flag, float act_min, float act_max) {\n" \ +" int X = get_global_id(0); // C4\n" \ +" int Y = get_global_id(1); // W\n" \ +" int Z = get_global_id(2); // N * H\n" \ +" if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y * output_shape.x) {\n" \ +" return;\n" \ +" }\n" \ +" int H = Z % output_shape.y;\n" \ +" int N = Z / output_shape.y;\n" \ +" int a_c = X < a_shape.w ? X : 0;\n" \ +" int a_w = Y < a_shape.z ? Y : 0;\n" \ +" int a_h = H < a_shape.y ? H : 0;\n" \ +" int a_n = N < a_shape.x ? N : 0;\n" \ +" FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(a_w * a_shape.w + a_c, a_n * a_shape.y + a_h));\n" \ +" int b_c = X < b_shape.w ? X : 0;\n" \ +" int b_w = Y < b_shape.z ? Y : 0;\n" \ +" int b_h = H < b_shape.y ? H : 0;\n" \ +" int b_n = N < b_shape.x ? N : 0;\n" \ +" FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_n * b_shape.y + b_h));\n" \ +" FLT4 result;\n" \ +" if (broadcastC_flag == 0) {\n" \ +" result = a * b;\n" \ +" } else if (broadcastC_flag == 1) {\n" \ +" result = a.x * b;\n" \ +" } else {\n" \ +" result = a * b.x;\n" \ +" }\n" \ +" result = clamp(result, (FLT)(act_min), (FLT)(act_max));\n" \ +" WRITE_IMAGE(output, (int2)(Y * output_shape.w + X, Z), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void BroadcastNHWC4Div(__read_only image2d_t input_a, __read_only image2d_t input_b,\n" \ +" __write_only image2d_t output, const int4 a_shape, const int4 b_shape,\n" \ +" const int4 output_shape, const int broadcastC_flag, float act_min, float act_max) {\n" \ +" int X = get_global_id(0); // C4\n" \ +" int Y = get_global_id(1); // W\n" \ +" int Z = get_global_id(2); // N * H\n" \ +" if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y * output_shape.x) {\n" \ +" return;\n" \ +" }\n" \ +" int H = Z % output_shape.y;\n" \ +" int N = Z / output_shape.y;\n" \ +" int a_c = X < a_shape.w ? X : 0;\n" \ +" int a_w = Y < a_shape.z ? Y : 0;\n" \ +" int a_h = H < a_shape.y ? H : 0;\n" \ +" int a_n = N < a_shape.x ? N : 0;\n" \ +" FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(a_w * a_shape.w + a_c, a_n * a_shape.y + a_h));\n" \ +" int b_c = X < b_shape.w ? X : 0;\n" \ +" int b_w = Y < b_shape.z ? Y : 0;\n" \ +" int b_h = H < b_shape.y ? H : 0;\n" \ +" int b_n = N < b_shape.x ? N : 0;\n" \ +" FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_n * b_shape.y + b_h));\n" \ +" FLT4 result;\n" \ +" if (broadcastC_flag == 0) {\n" \ +" result = a / b;\n" \ +" } else if (broadcastC_flag == 1) {\n" \ +" result = a.x / b;\n" \ +" } else {\n" \ +" result = a / b.x;\n" \ +" }\n" \ +" result = clamp(result, (FLT)(act_min), (FLT)(act_max));\n" \ +" WRITE_IMAGE(output, (int2)(Y * output_shape.w + X, Z), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void BroadcastLogicalAnd(__read_only image2d_t input_a, float b, __write_only image2d_t output,\n" \ +" const int2 output_shape, float act_min, float act_max) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= output_shape.x || Y >= output_shape.y) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y));\n" \ +" FLT4 result = AS_FLT4(AS_UINT4(a) & (UINT4)((FLT)b));\n" \ +" result = clamp(result, (FLT)(act_min), (FLT)(act_max));\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void BroadcastLogicalOr(__read_only image2d_t input_a, float b, __write_only image2d_t output,\n" \ +" const int2 output_shape, float act_min, float act_max) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= output_shape.x || Y >= output_shape.y) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y));\n" \ +" FLT4 result = AS_FLT4(AS_UINT4(a) | (UINT4)((FLT)b));\n" \ +" result = clamp(result, (FLT)(act_min), (FLT)(act_max));\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void BroadcastMaximum(__read_only image2d_t input_a, float b, __write_only image2d_t output,\n" \ +" const int2 output_shape, float act_min, float act_max) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= output_shape.x || Y >= output_shape.y) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y));\n" \ +" FLT4 result = max(a, (FLT4)b);\n" \ +" result = clamp(result, (FLT)(act_min), (FLT)(act_max));\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void BroadcastMinimum(__read_only image2d_t input_a, float b, __write_only image2d_t output,\n" \ +" const int2 output_shape, float act_min, float act_max) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= output_shape.x || Y >= output_shape.y) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y));\n" \ +" FLT4 result = min(a, (FLT4)b);\n" \ +" result = clamp(result, (FLT)(act_min), (FLT)(act_max));\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void BroadcastFloorDiv(__read_only image2d_t input_a, float b, __write_only image2d_t output,\n" \ +" const int2 output_shape, float act_min, float act_max) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= output_shape.x || Y >= output_shape.y) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y));\n" \ +" FLT4 result = floor(divide_no_check(a, (FLT4)b));\n" \ +" result = clamp(result, (FLT)(act_min), (FLT)(act_max));\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), result);\n" \ +"}\n" \ +"__kernel void BroadcastNHWC4FloorMod(__read_only image2d_t input_a, __read_only image2d_t input_b,\n" \ +" __write_only image2d_t output, const int4 a_shape, const int4 b_shape,\n" \ +" const int4 output_shape, const int broadcastC_flag, float act_min, float act_max) {\n" \ +" int X = get_global_id(0); // C4\n" \ +" int Y = get_global_id(1); // W\n" \ +" int Z = get_global_id(2); // H\n" \ +" if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(Y * a_shape.w + X, Z));\n" \ +" FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(X, 0));\n" \ +" FLT4 result = a - floor(divide_no_check(a, b)) * b;\n" \ +" result = clamp(result, (FLT)(act_min), (FLT)(act_max));\n" \ +" WRITE_IMAGE(output, (int2)(Y * output_shape.w + X, Z), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void BroadcastNHWC4SquaredDifference(__read_only image2d_t input_a, __read_only image2d_t input_b,\n" \ +" __write_only image2d_t output, const int4 a_shape, const int4 b_shape,\n" \ +" const int4 output_shape, const int broadcastC_flag, float act_min,\n" \ +" float act_max) {\n" \ +" int X = get_global_id(0); // C4\n" \ +" int Y = get_global_id(1); // w\n" \ +" int Z = get_global_id(2); // H\n" \ +"\n" \ +" if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y * output_shape.x) {\n" \ +" return;\n" \ +" }\n" \ +" int H = Z % output_shape.y;\n" \ +" int N = Z / output_shape.y;\n" \ +" int a_c = X < a_shape.w ? X : 0;\n" \ +" int a_w = Y < a_shape.z ? Y : 0;\n" \ +" int a_h = H < a_shape.y ? H : 0;\n" \ +" int a_n = N < a_shape.x ? N : 0;\n" \ +" FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(a_w * a_shape.w + a_c, a_n * a_shape.y + a_h));\n" \ +" int b_c = X < b_shape.w ? X : 0;\n" \ +" int b_w = Y < b_shape.z ? Y : 0;\n" \ +" int b_h = H < b_shape.y ? H : 0;\n" \ +" int b_n = N < b_shape.x ? N : 0;\n" \ +" FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_n * b_shape.y + b_h));\n" \ +" FLT4 result;\n" \ +" if (broadcastC_flag == 0) {\n" \ +" result = pown((a - b), (int4)2);\n" \ +" } else if (broadcastC_flag == 1) {\n" \ +" result = pown((a.x - b), (int4)2);\n" \ +" } else {\n" \ +" result = pown((a - b.x), (int4)2);\n" \ +" }\n" \ +" result = clamp(result, (FLT)(act_min), (FLT)(act_max));\n" \ +" WRITE_IMAGE(output, (int2)(Y * output_shape.w + X, Z), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void BroadcastEqual(__read_only image2d_t input_a, float b, __write_only image2d_t output,\n" \ +" const int2 output_shape, float act_min, float act_max) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= output_shape.x || Y >= output_shape.y) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y));\n" \ +" FLT4 result = a == (FLT4)b ? (FLT4)1.f : (FLT4).0f;\n" \ +" result = clamp(result, (FLT)(act_min), (FLT)(act_max));\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void BroadcastNotEqual(__read_only image2d_t input_a, float b, __write_only image2d_t output,\n" \ +" const int2 output_shape, float act_min, float act_max) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= output_shape.x || Y >= output_shape.y) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y));\n" \ +" FLT4 result = a != (FLT4)b ? (FLT4)1.f : (FLT4).0f;\n" \ +" result = clamp(result, (FLT)(act_min), (FLT)(act_max));\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void BroadcastLess(__read_only image2d_t input_a, float b, __write_only image2d_t output,\n" \ +" const int2 output_shape, float act_min, float act_max) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= output_shape.x || Y >= output_shape.y) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y));\n" \ +" FLT4 result = a < (FLT4)b ? (FLT4)1.f : (FLT4).0f;\n" \ +" result = clamp(result, (FLT)(act_min), (FLT)(act_max));\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void BroadcastLessEqual(__read_only image2d_t input_a, float b, __write_only image2d_t output,\n" \ +" const int2 output_shape, float act_min, float act_max) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= output_shape.x || Y >= output_shape.y) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y));\n" \ +" FLT4 result = a <= (FLT4)b ? (FLT4)1.f : (FLT4).0f;\n" \ +" result = clamp(result, (FLT)(act_min), (FLT)(act_max));\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void BroadcastGreater(__read_only image2d_t input_a, float b, __write_only image2d_t output,\n" \ +" const int2 output_shape, float act_min, float act_max) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= output_shape.x || Y >= output_shape.y) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y));\n" \ +" FLT4 result = a > (FLT4)b ? (FLT4)1.f : (FLT4).0f;\n" \ +" result = clamp(result, (FLT)(act_min), (FLT)(act_max));\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void BroadcastGreaterEqual(__read_only image2d_t input_a, float b, __write_only image2d_t output,\n" \ +" const int2 output_shape, float act_min, float act_max) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= output_shape.x || Y >= output_shape.y) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(X, Y));\n" \ +" FLT4 result = a >= (FLT4)b ? (FLT4)1.f : (FLT4).0f;\n" \ +" result = clamp(result, (FLT)(act_min), (FLT)(act_max));\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), result);\n" \ +"}\n" \ +; diff --git a/mindspore-lite/src/runtime/kernel/opencl/cl/arithmeticself.cl.inc b/mindspore-lite/src/runtime/kernel/opencl/cl/arithmeticself.cl.inc new file mode 100644 index 0000000000000000000000000000000000000000..531f5b32641b4efeb8298281e5de09d63bc94e31 --- /dev/null +++ b/mindspore-lite/src/runtime/kernel/opencl/cl/arithmeticself.cl.inc @@ -0,0 +1,212 @@ +static const char *arithmeticself_source ="\n" +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" \ +"__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;\n" \ +"\n" \ +"__kernel void ArithmeticSelf_ElementAbs_NHWC4(__read_only image2d_t input0, __write_only image2d_t output,\n" \ +" int4 output_shape) {\n" \ +" int X = get_global_id(0); // N*H\n" \ +" int Y = get_global_id(1); // W\n" \ +" int Z = get_global_id(2); // c/4\n" \ +" if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) {\n" \ +" return;\n" \ +" }\n" \ +" FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y)*output_shape.w + Z, (X)));\n" \ +" result.x = result.x >= 0 ? result.x : -result.x;\n" \ +" result.y = result.y >= 0 ? result.y : -result.y;\n" \ +" result.z = result.z >= 0 ? result.z : -result.z;\n" \ +" result.w = result.w >= 0 ? result.w : -result.w;\n" \ +" WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void ArithmeticSelf_ElementCos_NHWC4(__read_only image2d_t input0, __write_only image2d_t output,\n" \ +" int4 output_shape) {\n" \ +" int X = get_global_id(0); // N*H\n" \ +" int Y = get_global_id(1); // W\n" \ +" int Z = get_global_id(2); // c/4\n" \ +" if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) {\n" \ +" return;\n" \ +" }\n" \ +" FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y)*output_shape.w + Z, (X)));\n" \ +" result.x = cos(result.x);\n" \ +" result.y = cos(result.y);\n" \ +" result.z = cos(result.z);\n" \ +" result.w = cos(result.w);\n" \ +" WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void ArithmeticSelf_ElementSin_NHWC4(__read_only image2d_t input0, __write_only image2d_t output,\n" \ +" int4 output_shape) {\n" \ +" int X = get_global_id(0); // N*H\n" \ +" int Y = get_global_id(1); // W\n" \ +" int Z = get_global_id(2); // c/4\n" \ +" if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) {\n" \ +" return;\n" \ +" }\n" \ +" FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y)*output_shape.w + Z, (X)));\n" \ +" result.x = sin(result.x);\n" \ +" result.y = sin(result.y);\n" \ +" result.z = sin(result.z);\n" \ +" result.w = sin(result.w);\n" \ +" WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void ArithmeticSelf_ElementNeg_NHWC4(__read_only image2d_t input0, __write_only image2d_t output,\n" \ +" int4 output_shape) {\n" \ +" int X = get_global_id(0); // N*H\n" \ +" int Y = get_global_id(1); // W\n" \ +" int Z = get_global_id(2); // c/4\n" \ +" if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) {\n" \ +" return;\n" \ +" }\n" \ +" FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y)*output_shape.w + Z, (X)));\n" \ +" result.x = -result.x;\n" \ +" result.y = -result.y;\n" \ +" result.z = -result.z;\n" \ +" result.w = -result.w;\n" \ +" WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void ArithmeticSelf_ElementExp_NHWC4(__read_only image2d_t input0, __write_only image2d_t output,\n" \ +" int4 output_shape) {\n" \ +" int X = get_global_id(0); // N*H\n" \ +" int Y = get_global_id(1); // W\n" \ +" int Z = get_global_id(2); // c/4\n" \ +" if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) {\n" \ +" return;\n" \ +" }\n" \ +" FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y)*output_shape.w + Z, (X)));\n" \ +" result.x = exp(result.x);\n" \ +" result.y = exp(result.y);\n" \ +" result.z = exp(result.z);\n" \ +" result.w = exp(result.w);\n" \ +" WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void ArithmeticSelf_ElementLog_NHWC4(__read_only image2d_t input0, __write_only image2d_t output,\n" \ +" int4 output_shape) {\n" \ +" int X = get_global_id(0); // N*H\n" \ +" int Y = get_global_id(1); // W\n" \ +" int Z = get_global_id(2); // c/4\n" \ +" if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) {\n" \ +" return;\n" \ +" }\n" \ +" FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y)*output_shape.w + Z, (X)));\n" \ +" result.x = result.x > 0 ? log(result.x) : HUGE_VALF;\n" \ +" result.y = result.y > 0 ? log(result.y) : HUGE_VALF;\n" \ +" result.z = result.z > 0 ? log(result.z) : HUGE_VALF;\n" \ +" result.w = result.w > 0 ? log(result.w) : HUGE_VALF;\n" \ +" WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void ArithmeticSelf_ElementSquare_NHWC4(__read_only image2d_t input0, __write_only image2d_t output,\n" \ +" int4 output_shape) {\n" \ +" int X = get_global_id(0); // N*H\n" \ +" int Y = get_global_id(1); // W\n" \ +" int Z = get_global_id(2); // c/4\n" \ +" if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) {\n" \ +" return;\n" \ +" }\n" \ +" FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y)*output_shape.w + Z, (X)));\n" \ +" result.x = result.x * result.x;\n" \ +" result.y = result.y * result.y;\n" \ +" result.z = result.z * result.z;\n" \ +" result.w = result.w * result.w;\n" \ +" WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void ArithmeticSelf_ElementSqrt_NHWC4(__read_only image2d_t input0, __write_only image2d_t output,\n" \ +" int4 output_shape) {\n" \ +" int X = get_global_id(0); // N*H\n" \ +" int Y = get_global_id(1); // W\n" \ +" int Z = get_global_id(2); // c/4\n" \ +" if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) {\n" \ +" return;\n" \ +" }\n" \ +" FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y)*output_shape.w + Z, (X)));\n" \ +" result.x = result.x > 0 ? sqrt(result.x) : HUGE_VALF;\n" \ +" result.y = result.y > 0 ? sqrt(result.y) : HUGE_VALF;\n" \ +" result.z = result.z > 0 ? sqrt(result.z) : HUGE_VALF;\n" \ +" result.w = result.w > 0 ? sqrt(result.w) : HUGE_VALF;\n" \ +" WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void ArithmeticSelf_ElementRsqrt_NHWC4(__read_only image2d_t input0, __write_only image2d_t output,\n" \ +" int4 output_shape) {\n" \ +" int X = get_global_id(0); // N*H\n" \ +" int Y = get_global_id(1); // W\n" \ +" int Z = get_global_id(2); // c/4\n" \ +" if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) {\n" \ +" return;\n" \ +" }\n" \ +" FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y)*output_shape.w + Z, (X)));\n" \ +" result.x = result.x > 0 ? 1.0f / sqrt(result.x) : HUGE_VALF;\n" \ +" result.y = result.y > 0 ? 1.0f / sqrt(result.y) : HUGE_VALF;\n" \ +" result.z = result.z > 0 ? 1.0f / sqrt(result.z) : HUGE_VALF;\n" \ +" result.w = result.w > 0 ? 1.0f / sqrt(result.w) : HUGE_VALF;\n" \ +" WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void ArithmeticSelf_ElementLogicalNot_NHWC4(__read_only image2d_t input0, __write_only image2d_t output,\n" \ +" int4 output_shape) {\n" \ +" int X = get_global_id(0); // N*H\n" \ +" int Y = get_global_id(1); // W\n" \ +" int Z = get_global_id(2); // c/4\n" \ +" if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) {\n" \ +" return;\n" \ +" }\n" \ +" FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y)*output_shape.w + Z, (X)));\n" \ +" result.x = result.x > 0 || result.x < 0 ? false : true;\n" \ +" result.y = result.y > 0 || result.y < 0 ? false : true;\n" \ +" result.z = result.z > 0 || result.z < 0 ? false : true;\n" \ +" result.w = result.w > 0 || result.w < 0 ? false : true;\n" \ +" WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void ArithmeticSelf_ElementFloor_NHWC4(__read_only image2d_t input0, __write_only image2d_t output,\n" \ +" int4 output_shape) {\n" \ +" int X = get_global_id(0); // N*H\n" \ +" int Y = get_global_id(1); // W\n" \ +" int Z = get_global_id(2); // c/4\n" \ +" if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) {\n" \ +" return;\n" \ +" }\n" \ +" FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y)*output_shape.w + Z, (X)));\n" \ +" result.x = floor(result.x);\n" \ +" result.y = floor(result.y);\n" \ +" result.z = floor(result.z);\n" \ +" result.w = floor(result.w);\n" \ +" WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void ArithmeticSelf_ElementCeil_NHWC4(__read_only image2d_t input0, __write_only image2d_t output,\n" \ +" int4 output_shape) {\n" \ +" int X = get_global_id(0); // N*H\n" \ +" int Y = get_global_id(1); // W\n" \ +" int Z = get_global_id(2); // c/4\n" \ +" if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) {\n" \ +" return;\n" \ +" }\n" \ +" FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y)*output_shape.w + Z, (X)));\n" \ +" result.x = ceil(result.x);\n" \ +" result.y = ceil(result.y);\n" \ +" result.z = ceil(result.z);\n" \ +" result.w = ceil(result.w);\n" \ +" WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void ArithmeticSelf_ElementRound_NHWC4(__read_only image2d_t input0, __write_only image2d_t output,\n" \ +" int4 output_shape) {\n" \ +" int X = get_global_id(0); // N*H\n" \ +" int Y = get_global_id(1); // W\n" \ +" int Z = get_global_id(2); // c/4\n" \ +" if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) {\n" \ +" return;\n" \ +" }\n" \ +" FLT4 result = READ_IMAGE(input0, smp_none, (int2)((Y)*output_shape.w + Z, (X)));\n" \ +" result.x = round(result.x);\n" \ +" result.y = round(result.y);\n" \ +" result.z = round(result.z);\n" \ +" result.w = round(result.w);\n" \ +" WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result);\n" \ +"}\n" \ +; diff --git a/mindspore-lite/src/runtime/kernel/opencl/cl/batch_to_space_nd.cl.inc b/mindspore-lite/src/runtime/kernel/opencl/cl/batch_to_space_nd.cl.inc new file mode 100644 index 0000000000000000000000000000000000000000..1ccd593079aa13be12521ba7a034466f7711bc4a --- /dev/null +++ b/mindspore-lite/src/runtime/kernel/opencl/cl/batch_to_space_nd.cl.inc @@ -0,0 +1,49 @@ +static const char *batch_to_space_nd_source ="\n" +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" \ +"__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" \ +"__kernel void batch_to_space_nd_NHWC4(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 src_size,\n" \ +" int4 dst_size, int2 block_size, int4 paddings) {\n" \ +" int X = get_global_id(0); // c\n" \ +" int Y = get_global_id(1); // w\n" \ +" int Z = get_global_id(2); // h*n\n" \ +" if (X >= src_size.x || Y >= src_size.y || Y >= src_size.z) {\n" \ +" return;\n" \ +" }\n" \ +" for (int i = 0; i < block_size.x; ++i) {\n" \ +" for (int j = 0; j < block_size.y; ++j) {\n" \ +" int Y_dst = (Y * block_size.y + j);\n" \ +" int Z_dst = Z * block_size.x + i;\n" \ +" if (Y_dst >= dst_size.y || Z_dst >= dst_size.z) {\n" \ +" continue;\n" \ +" }\n" \ +" int Y_org = (Y_dst + paddings.z) / block_size.y;\n" \ +" int Z_org = (Z_dst + paddings.x) / block_size.x;\n" \ +" int Z_com = (i * block_size.y + j) * src_size.z + Z_org;\n" \ +" FLT4 res_data = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" res_data = READ_IMAGE(src_data, smp_zero, (int2)(Y_org * dst_size.x + X, Z_com));\n" \ +" WRITE_IMAGE(dst_data, (int2)((Y * block_size.y + j) * dst_size.x + X, Z * block_size.x + i), res_data);\n" \ +" }\n" \ +" }\n" \ +"}\n" \ +"__kernel void batch_to_space_nd_NC4HW4(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 src_size,\n" \ +" int4 dst_size, int2 block_size, int4 paddings) {\n" \ +" int X = get_global_id(0); // c\n" \ +" int Y = get_global_id(1); // w\n" \ +" int Z = get_global_id(2); // h*n\n" \ +" if (X >= dst_size.x || Y >= dst_size.y || Y >= dst_size.z) {\n" \ +" return;\n" \ +" }\n" \ +" for (int i = 0; i < block_size.x; ++i) {\n" \ +" for (int j = 0; j < block_size.y; ++j) {\n" \ +" int Y_dst = (Y * block_size.y + j);\n" \ +" int Z_dst = Z * block_size.x + i;\n" \ +" int Y_org = (Y_dst + paddings.z) / block_size.y;\n" \ +" int Z_org = (Z_dst + paddings.x) / block_size.x;\n" \ +" int Z_com = (i * block_size.y + j) * src_size.z + Z_org;\n" \ +" FLT4 res_data = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" res_data = READ_IMAGE(src_data, smp_zero, (int2)(Y_org * dst_size.x + X, Z_com));\n" \ +" WRITE_IMAGE(dst_data, (int2)((Y * block_size.y + j) * dst_size.x + X, Z * block_size.x + i), res_data);\n" \ +" }\n" \ +" }\n" \ +"}\n" \ +; diff --git a/mindspore-lite/src/runtime/kernel/opencl/cl/batchnorm.cl.inc b/mindspore-lite/src/runtime/kernel/opencl/cl/batchnorm.cl.inc new file mode 100644 index 0000000000000000000000000000000000000000..fcdd7b3810c81339f89fddcf54364bd4dfd4aa0a --- /dev/null +++ b/mindspore-lite/src/runtime/kernel/opencl/cl/batchnorm.cl.inc @@ -0,0 +1,43 @@ +static const char *batchnorm_source ="\n" +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" \ +"#define INT4 int4\n" \ +"#define INT2 int2\n" \ +"#define C4NUM 4\n" \ +"__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;\n" \ +"__kernel void Batch_normalization_NHWC4(__read_only image2d_t input, __global FLT *scale, __global FLT *offset,\n" \ +" __global FLT *mean, __global FLT *variance, __write_only image2d_t output,\n" \ +" const INT4 input_shape, float epsilon, int unalign_input_w) {\n" \ +" int X = get_global_id(0); // H\n" \ +" int Y = get_global_id(1); // W\n" \ +" int Z = get_global_id(2); // C/4\n" \ +" if (X >= input_shape.y || Y >= input_shape.z || Z >= input_shape.w) {\n" \ +" return;\n" \ +" }\n" \ +" FLT4 result = READ_IMAGE(input, smp_none, (int2)((Y)*input_shape.w + Z, (X)));\n" \ +"\n" \ +" FLT result_mean[4] = {0.0f, 0.0f, 0.0f, 0.0f};\n" \ +" FLT result_var[4] = {0.0f, 0.0f, 0.0f, 0.0f};\n" \ +" FLT result_scale[4] = {1.0f, 1.0f, 1.0f, 1.0f};\n" \ +" FLT result_offset[4] = {0.0f, 0.0f, 0.0f, 0.0f};\n" \ +" if ((Z + 1) * C4NUM <= unalign_input_w) {\n" \ +" for (int i = 0; i < C4NUM; ++i) {\n" \ +" result_mean[i] = mean[Z * C4NUM + i];\n" \ +" result_var[i] = variance[Z * C4NUM + i];\n" \ +" result_scale[i] = scale[Z * C4NUM + i];\n" \ +" result_offset[i] = offset[Z * C4NUM + i];\n" \ +" }\n" \ +" } else {\n" \ +" for (int i = 0; i < unalign_input_w % C4NUM; ++i) {\n" \ +" result_mean[i] = mean[Z * C4NUM + i];\n" \ +" result_var[i] = variance[Z * C4NUM + i];\n" \ +" result_scale[i] = scale[Z * C4NUM + i];\n" \ +" result_offset[i] = offset[Z * C4NUM + i];\n" \ +" }\n" \ +" }\n" \ +" result.x = result_scale[0] * ((result.x - result_mean[0]) / sqrt(result_var[0] + epsilon)) + result_offset[0];\n" \ +" result.y = result_scale[1] * ((result.y - result_mean[1]) / sqrt(result_var[1] + epsilon)) + result_offset[1];\n" \ +" result.z = result_scale[2] * ((result.z - result_mean[2]) / sqrt(result_var[2] + epsilon)) + result_offset[2];\n" \ +" result.w = result_scale[3] * ((result.w - result_mean[3]) / sqrt(result_var[3] + epsilon)) + result_offset[3];\n" \ +" WRITE_IMAGE(output, (int2)((Y)*input_shape.w + Z, (X)), result);\n" \ +"}\n" \ +; diff --git a/mindspore-lite/src/runtime/kernel/opencl/cl/cast.cl.inc b/mindspore-lite/src/runtime/kernel/opencl/cl/cast.cl.inc new file mode 100644 index 0000000000000000000000000000000000000000..d83b174d9ebb093e27225e7fce2bef23a8dacd5f --- /dev/null +++ b/mindspore-lite/src/runtime/kernel/opencl/cl/cast.cl.inc @@ -0,0 +1,45 @@ +static const char *cast_source ="\n" +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" \ +"\n" \ +"__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;\n" \ +"\n" \ +"__kernel void Cast_fp32_to_fp16(__read_only image2d_t input, __write_only image2d_t output, int2 XY) {\n" \ +" int x = get_global_id(0);\n" \ +" int y = get_global_id(1);\n" \ +" if (x >= XY.x || y >= XY.y) {\n" \ +" return;\n" \ +" }\n" \ +" half4 result = convert_half4(read_imagef(input, smp_none, (int2)(x, y)));\n" \ +" write_imageh(output, (int2)(x, y), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void Cast_fp32_to_fp32(__read_only image2d_t input, __write_only image2d_t output, int2 XY) {\n" \ +" int x = get_global_id(0);\n" \ +" int y = get_global_id(1);\n" \ +" if (x >= XY.x || y >= XY.y) {\n" \ +" return;\n" \ +" }\n" \ +" float4 result = read_imagef(input, smp_none, (int2)(x, y));\n" \ +" write_imagef(output, (int2)(x, y), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void Cast_fp16_to_fp16(__read_only image2d_t input, __write_only image2d_t output, int2 XY) {\n" \ +" int x = get_global_id(0);\n" \ +" int y = get_global_id(1);\n" \ +" if (x >= XY.x || y >= XY.y) {\n" \ +" return;\n" \ +" }\n" \ +" half4 result = read_imageh(input, smp_none, (int2)(x, y));\n" \ +" write_imageh(output, (int2)(x, y), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void Cast_fp16_to_fp32(__read_only image2d_t input, __write_only image2d_t output, int2 XY) {\n" \ +" int x = get_global_id(0);\n" \ +" int y = get_global_id(1);\n" \ +" if (x >= XY.x || y >= XY.y) {\n" \ +" return;\n" \ +" }\n" \ +" float4 result = convert_float4(read_imageh(input, smp_none, (int2)(x, y)));\n" \ +" write_imagef(output, (int2)(x, y), result);\n" \ +"}\n" \ +; diff --git a/mindspore-lite/src/runtime/kernel/opencl/cl/concat.cl.inc b/mindspore-lite/src/runtime/kernel/opencl/cl/concat.cl.inc new file mode 100644 index 0000000000000000000000000000000000000000..fa093b6353cd61d59188f8a36342a1dfee17eabf --- /dev/null +++ b/mindspore-lite/src/runtime/kernel/opencl/cl/concat.cl.inc @@ -0,0 +1,332 @@ +static const char *concat_source ="\n" +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" \ +"__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;\n" \ +"#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))\n" \ +"#define C4NUM 4\n" \ +"\n" \ +"// Align in Axis C for concat\n" \ +"#define CHECK_IDX \\\n" \ +" int X = get_global_id(0); \\\n" \ +" int Y = get_global_id(1); \\\n" \ +" int Z = get_global_id(2); \\\n" \ +" if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) { \\\n" \ +" return; \\\n" \ +" } \\\n" \ +" DTYPE4 result;\n" \ +"\n" \ +"// axis = 1\n" \ +"#define DOConcat2inputaxis1_NHWC4 \\\n" \ +" int IN = X / output_shape.y; \\\n" \ +" int IH = X % output_shape.y; \\\n" \ +" int boundary0 = input_shape0.y; \\\n" \ +" int boundary1 = boundary0 + input_shape1.y; \\\n" \ +" if (IH < boundary0) { \\\n" \ +" int coordinate_x = Y * input_shape0.w + Z; \\\n" \ +" int coordinate_y = IN * input_shape0.y + IH; \\\n" \ +" result = READ_IMAGE(input0, smp_none, (int2)(coordinate_x, coordinate_y)); \\\n" \ +" } else if (IH < boundary1) { \\\n" \ +" int coordinate_x = Y * input_shape1.w + Z; \\\n" \ +" int coordinate_y = IN * input_shape1.y + IH - boundary0; \\\n" \ +" result = READ_IMAGE(input1, smp_none, (int2)(coordinate_x, coordinate_y)); \\\n" \ +" }\n" \ +"\n" \ +"#define DOConcat3inputaxis1_NHWC4 \\\n" \ +" DOConcat2inputaxis1_NHWC4; \\\n" \ +" int boundary2 = boundary1 + input_shape2.y; \\\n" \ +" if (IH >= boundary1 && IH < boundary2) { \\\n" \ +" int coordinate_x = Y * input_shape2.w + Z; \\\n" \ +" int coordinate_y = IN * input_shape2.y + IH - boundary1; \\\n" \ +" result = READ_IMAGE(input2, smp_none, (int2)(coordinate_x, coordinate_y)); \\\n" \ +" }\n" \ +"\n" \ +"#define DOConcat4inputaxis1_NHWC4 \\\n" \ +" DOConcat3inputaxis1_NHWC4; \\\n" \ +" int boundary3 = boundary2 + input_shape3.y; \\\n" \ +" if (IH >= boundary2 && IH < boundary3) { \\\n" \ +" int coordinate_x = Y * input_shape3.w + Z; \\\n" \ +" int coordinate_y = IN * input_shape3.y + IH - boundary2; \\\n" \ +" result = READ_IMAGE(input3, smp_none, (int2)(coordinate_x, coordinate_y)); \\\n" \ +" }\n" \ +"\n" \ +"#define DOConcat5inputaxis1_NHWC4 \\\n" \ +" DOConcat4inputaxis1_NHWC4; \\\n" \ +" int boundary4 = boundary3 + input_shape4.y; \\\n" \ +" if (IH >= boundary3 && IH < boundary4) { \\\n" \ +" int coordinate_x = Y * input_shape4.w + Z; \\\n" \ +" int coordinate_y = IN * input_shape4.y + IH - boundary3; \\\n" \ +" result = READ_IMAGE(input4, smp_none, (int2)(coordinate_x, coordinate_y)); \\\n" \ +" }\n" \ +"\n" \ +"#define DOConcat6inputaxis1_NHWC4 \\\n" \ +" DOConcat5inputaxis1_NHWC4; \\\n" \ +" int boundary5 = boundary4 + input_shape5.y; \\\n" \ +" if (IH >= boundary4 && IH < boundary5) { \\\n" \ +" int coordinate_x = Y * input_shape5.w + Z; \\\n" \ +" int coordinate_y = IN * input_shape5.y + IH - boundary4; \\\n" \ +" result = READ_IMAGE(input5, smp_none, (int2)(coordinate_x, coordinate_y)); \\\n" \ +" }\n" \ +"\n" \ +"// axis = 2\n" \ +"#define DOConcat2inputaxis2_NHWC4 \\\n" \ +" int boundary0 = input_shape0.z; \\\n" \ +" int boundary1 = boundary0 + input_shape1.z; \\\n" \ +" if (Y < boundary0) { \\\n" \ +" int coordinate_x = Y * input_shape0.w + Z; \\\n" \ +" int coordinate_y = X; \\\n" \ +" result = READ_IMAGE(input0, smp_none, (int2)(coordinate_x, coordinate_y)); \\\n" \ +" } else if (Y < boundary1) { \\\n" \ +" int coordinate_x = (Y - boundary0) * input_shape1.w + Z; \\\n" \ +" int coordinate_y = X; \\\n" \ +" result = READ_IMAGE(input1, smp_none, (int2)(coordinate_x, coordinate_y)); \\\n" \ +" }\n" \ +"\n" \ +"#define DOConcat3inputaxis2_NHWC4 \\\n" \ +" DOConcat2inputaxis2_NHWC4; \\\n" \ +" int boundary2 = boundary1 + input_shape2.z; \\\n" \ +" if (Y >= boundary1 && Y < boundary2) { \\\n" \ +" int coordinate_x = (Y - boundary1) * input_shape2.w + Z; \\\n" \ +" int coordinate_y = X; \\\n" \ +" result = READ_IMAGE(input2, smp_none, (int2)(coordinate_x, coordinate_y)); \\\n" \ +" }\n" \ +"\n" \ +"#define DOConcat4inputaxis2_NHWC4 \\\n" \ +" DOConcat3inputaxis2_NHWC4; \\\n" \ +" int boundary3 = boundary2 + input_shape3.z; \\\n" \ +" if (Y >= boundary2 && Y < boundary3) { \\\n" \ +" int coordinate_x = (Y - boundary2) * input_shape3.w + Z; \\\n" \ +" int coordinate_y = X; \\\n" \ +" result = READ_IMAGE(input3, smp_none, (int2)(coordinate_x, coordinate_y)); \\\n" \ +" }\n" \ +"\n" \ +"#define DOConcat5inputaxis2_NHWC4 \\\n" \ +" DOConcat4inputaxis2_NHWC4; \\\n" \ +" int boundary4 = boundary3 + input_shape4.z; \\\n" \ +" if (Y >= boundary3 && Y < boundary4) { \\\n" \ +" int coordinate_x = (Y - boundary3) * input_shape4.w + Z; \\\n" \ +" int coordinate_y = X; \\\n" \ +" result = READ_IMAGE(input4, smp_none, (int2)(coordinate_x, coordinate_y)); \\\n" \ +" }\n" \ +"\n" \ +"#define DOConcat6inputaxis2_NHWC4 \\\n" \ +" DOConcat5inputaxis2_NHWC4; \\\n" \ +" int boundary5 = boundary4 + input_shape5.z; \\\n" \ +" if (Y >= boundary4 && Y < boundary5) { \\\n" \ +" int coordinate_x = (Y - boundary4) * input_shape5.w + Z; \\\n" \ +" int coordinate_y = X; \\\n" \ +" result = READ_IMAGE(input5, smp_none, (int2)(coordinate_x, coordinate_y)); \\\n" \ +" }\n" \ +"\n" \ +"// axis = 3\n" \ +"#define DOConcat2inputaxis3_NHWC4 \\\n" \ +" int boundary0 = input_shape0.w; \\\n" \ +" int boundary1 = boundary0 + input_shape1.w; \\\n" \ +" if (Z < boundary0) { \\\n" \ +" int coordinate_x = Y * input_shape0.w + Z; \\\n" \ +" int coordinate_y = X; \\\n" \ +" result = READ_IMAGE(input0, smp_none, (int2)(coordinate_x, coordinate_y)); \\\n" \ +" } else if (Z < boundary1) { \\\n" \ +" int coordinate_x = Y * input_shape1.w + Z - boundary0; \\\n" \ +" int coordinate_y = X; \\\n" \ +" result = READ_IMAGE(input1, smp_none, (int2)(coordinate_x, coordinate_y)); \\\n" \ +" }\n" \ +"\n" \ +"#define DOConcat3inputaxis3_NHWC4 \\\n" \ +" DOConcat2inputaxis3_NHWC4; \\\n" \ +" int boundary2 = boundary1 + input_shape2.w; \\\n" \ +" if (Z >= boundary1 && Z < boundary2) { \\\n" \ +" int coordinate_x = Y * input_shape2.w + Z - boundary1; \\\n" \ +" int coordinate_y = X; \\\n" \ +" result = READ_IMAGE(input2, smp_none, (int2)(coordinate_x, coordinate_y)); \\\n" \ +" }\n" \ +"\n" \ +"#define DOConcat4inputaxis3_NHWC4 \\\n" \ +" DOConcat3inputaxis3_NHWC4; \\\n" \ +" int boundary3 = boundary2 + input_shape3.w; \\\n" \ +" if (Z >= boundary2 && Z < boundary3) { \\\n" \ +" int coordinate_x = Y * input_shape3.w + Z - boundary2; \\\n" \ +" int coordinate_y = X; \\\n" \ +" result = READ_IMAGE(input3, smp_none, (int2)(coordinate_x, coordinate_y)); \\\n" \ +" }\n" \ +"\n" \ +"#define DOConcat5inputaxis3_NHWC4 \\\n" \ +" DOConcat4inputaxis3_NHWC4; \\\n" \ +" int boundary4 = boundary3 + input_shape4.w; \\\n" \ +" if (Z >= boundary3 && Z < boundary4) { \\\n" \ +" int coordinate_x = Y * input_shape4.w + Z - boundary3; \\\n" \ +" int coordinate_y = X; \\\n" \ +" result = READ_IMAGE(input4, smp_none, (int2)(coordinate_x, coordinate_y)); \\\n" \ +" }\n" \ +"\n" \ +"#define DOConcat6inputaxis3_NHWC4 \\\n" \ +" DOConcat5inputaxis3_NHWC4; \\\n" \ +" int boundary5 = boundary4 + input_shape5.w; \\\n" \ +" if (Z >= boundary4 && Z < boundary5) { \\\n" \ +" int coordinate_x = Y * input_shape5.w + Z - boundary4; \\\n" \ +" int coordinate_y = X; \\\n" \ +" result = READ_IMAGE(input5, smp_none, (int2)(coordinate_x, coordinate_y)); \\\n" \ +" }\n" \ +"\n" \ +"#define CONCAT6(Inputnum, Axis, ToFormat) \\\n" \ +" __kernel void Concat##Inputnum##Axis##ToFormat( \\\n" \ +" __read_only image2d_t input0, __read_only image2d_t input1, __read_only image2d_t input2, \\\n" \ +" __read_only image2d_t input3, __read_only image2d_t input4, __read_only image2d_t input5, \\\n" \ +" __write_only image2d_t output, int4 input_shape0, int4 input_shape1, int4 input_shape2, int4 input_shape3, \\\n" \ +" int4 input_shape4, int4 input_shape5, int4 output_shape) { \\\n" \ +" CHECK_IDX; \\\n" \ +" DOConcat##Inputnum##Axis##ToFormat; \\\n" \ +" WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result); \\\n" \ +" }\n" \ +"\n" \ +"#define CONCAT5(Inputnum, Axis, ToFormat) \\\n" \ +" __kernel void Concat##Inputnum##Axis##ToFormat( \\\n" \ +" __read_only image2d_t input0, __read_only image2d_t input1, __read_only image2d_t input2, \\\n" \ +" __read_only image2d_t input3, __read_only image2d_t input4, __write_only image2d_t output, int4 input_shape0, \\\n" \ +" int4 input_shape1, int4 input_shape2, int4 input_shape3, int4 input_shape4, int4 output_shape) { \\\n" \ +" CHECK_IDX; \\\n" \ +" DOConcat##Inputnum##Axis##ToFormat; \\\n" \ +" WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result); \\\n" \ +" }\n" \ +"\n" \ +"#define CONCAT4(Inputnum, Axis, ToFormat) \\\n" \ +" __kernel void Concat##Inputnum##Axis##ToFormat(__read_only image2d_t input0, __read_only image2d_t input1, \\\n" \ +" __read_only image2d_t input2, __read_only image2d_t input3, \\\n" \ +" __write_only image2d_t output, int4 input_shape0, int4 input_shape1, \\\n" \ +" int4 input_shape2, int4 input_shape3, int4 output_shape) { \\\n" \ +" CHECK_IDX \\\n" \ +" DOConcat##Inputnum##Axis##ToFormat; \\\n" \ +" WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result); \\\n" \ +" }\n" \ +"\n" \ +"#define CONCAT3(Inputnum, Axis, ToFormat) \\\n" \ +" __kernel void Concat##Inputnum##Axis##ToFormat( \\\n" \ +" __read_only image2d_t input0, __read_only image2d_t input1, __read_only image2d_t input2, \\\n" \ +" __write_only image2d_t output, int4 input_shape0, int4 input_shape1, int4 input_shape2, int4 output_shape) { \\\n" \ +" CHECK_IDX \\\n" \ +" DOConcat##Inputnum##Axis##ToFormat; \\\n" \ +" WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result); \\\n" \ +" }\n" \ +"\n" \ +"#define CONCAT2(Inputnum, Axis, ToFormat) \\\n" \ +" __kernel void Concat##Inputnum##Axis##ToFormat(__read_only image2d_t input0, __read_only image2d_t input1, \\\n" \ +" __write_only image2d_t output, int4 input_shape0, int4 input_shape1, \\\n" \ +" int4 output_shape) { \\\n" \ +" CHECK_IDX \\\n" \ +" DOConcat##Inputnum##Axis##ToFormat; \\\n" \ +" WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (X)), result); \\\n" \ +" }\n" \ +"\n" \ +"// axis = 1\n" \ +"CONCAT6(6input, axis1, _NHWC4)\n" \ +"CONCAT5(5input, axis1, _NHWC4)\n" \ +"CONCAT4(4input, axis1, _NHWC4)\n" \ +"CONCAT3(3input, axis1, _NHWC4)\n" \ +"CONCAT2(2input, axis1, _NHWC4)\n" \ +"\n" \ +"// axis = 2\n" \ +"CONCAT6(6input, axis2, _NHWC4)\n" \ +"CONCAT5(5input, axis2, _NHWC4)\n" \ +"CONCAT4(4input, axis2, _NHWC4)\n" \ +"CONCAT3(3input, axis2, _NHWC4)\n" \ +"CONCAT2(2input, axis2, _NHWC4)\n" \ +"\n" \ +"// axis = 3\n" \ +"CONCAT6(6input, axis3, _NHWC4)\n" \ +"CONCAT5(5input, axis3, _NHWC4)\n" \ +"CONCAT4(4input, axis3, _NHWC4)\n" \ +"CONCAT3(3input, axis3, _NHWC4)\n" \ +"CONCAT2(2input, axis3, _NHWC4)\n" \ +"\n" \ +"// UnAlign in Axis C for concat\n" \ +"#define CHECK_IDX_UNALIGN \\\n" \ +" int X = get_global_id(0); \\\n" \ +" int Y = get_global_id(1); \\\n" \ +" if (X >= output_shape.x * output_shape.y || Y >= output_shape.z) { \\\n" \ +" return; \\\n" \ +" } \\\n" \ +" int IN = X / output_shape.y, IH = X % output_shape.y; \\\n" \ +" int IW = Y; \\\n" \ +" int Align_Shape0 = UP_DIV(input_shape0.w, C4NUM), Align_Shape1 = UP_DIV(input_shape1.w, C4NUM); \\\n" \ +" int Align_OutShape = output_shape.w; \\\n" \ +" int index_output = (IN * output_shape.y + IH) * stride_w + IW * Align_OutShape * C4NUM;\n" \ +"\n" \ +"int doconcat(__read_only image2d_t input, __global DTYPE *output, int Align_Shape, int4 input_shape, int IN, int IH,\n" \ +" int Y, int index_output) {\n" \ +" int Remainder = input_shape.w % C4NUM;\n" \ +" for (int i = 0; i < Align_Shape; ++i) {\n" \ +" DTYPE4 result = READ_IMAGE(input, smp_none, (int2)((Y * Align_Shape + i), (IN * input_shape.y + IH)));\n" \ +" DTYPE result_temp[4] = {result.x, result.y, result.z, result.w};\n" \ +" if ((i + 1) * C4NUM <= input_shape.w) {\n" \ +" for (int j = 0; j < C4NUM; ++j) {\n" \ +" output[index_output++] = result_temp[j];\n" \ +" }\n" \ +" } else {\n" \ +" for (int j = 0; j < Remainder; ++j) {\n" \ +" output[index_output++] = result_temp[j];\n" \ +" }\n" \ +" }\n" \ +" }\n" \ +" return index_output;\n" \ +"}\n" \ +"\n" \ +"__kernel void ConcatInput2UnAlign_NHWC4(__read_only image2d_t input0, __read_only image2d_t input1,\n" \ +" __global DTYPE *output, int4 input_shape0, int4 input_shape1, int stride_w,\n" \ +" int4 output_shape) {\n" \ +" CHECK_IDX_UNALIGN;\n" \ +" index_output = doconcat(input0, output, Align_Shape0, input_shape0, IN, IH, Y, index_output);\n" \ +" index_output = doconcat(input1, output, Align_Shape1, input_shape1, IN, IH, Y, index_output);\n" \ +"}\n" \ +"\n" \ +"__kernel void ConcatInput3UnAlign_NHWC4(__read_only image2d_t input0, __read_only image2d_t input1,\n" \ +" __read_only image2d_t input2, __global DTYPE *output, int4 input_shape0,\n" \ +" int4 input_shape1, int4 input_shape2, int stride_w, int4 output_shape) {\n" \ +" CHECK_IDX_UNALIGN;\n" \ +" int Align_Shape2 = UP_DIV(input_shape2.w, C4NUM);\n" \ +" index_output = doconcat(input0, output, Align_Shape0, input_shape0, IN, IH, Y, index_output);\n" \ +" index_output = doconcat(input1, output, Align_Shape1, input_shape1, IN, IH, Y, index_output);\n" \ +" index_output = doconcat(input2, output, Align_Shape2, input_shape2, IN, IH, Y, index_output);\n" \ +"}\n" \ +"\n" \ +"__kernel void ConcatInput4UnAlign_NHWC4(__read_only image2d_t input0, __read_only image2d_t input1,\n" \ +" __read_only image2d_t input2, __read_only image2d_t input3,\n" \ +" __global DTYPE *output, int4 input_shape0, int4 input_shape1, int4 input_shape2,\n" \ +" int4 input_shape3, int stride_w, int4 output_shape) {\n" \ +" CHECK_IDX_UNALIGN;\n" \ +" int Align_Shape2 = UP_DIV(input_shape2.w, C4NUM), Align_Shape3 = UP_DIV(input_shape3.w, C4NUM);\n" \ +" index_output = doconcat(input0, output, Align_Shape0, input_shape0, IN, IH, Y, index_output);\n" \ +" index_output = doconcat(input1, output, Align_Shape1, input_shape1, IN, IH, Y, index_output);\n" \ +" index_output = doconcat(input2, output, Align_Shape2, input_shape2, IN, IH, Y, index_output);\n" \ +" index_output = doconcat(input3, output, Align_Shape3, input_shape3, IN, IH, Y, index_output);\n" \ +"}\n" \ +"\n" \ +"__kernel void ConcatInput5UnAlign_NHWC4(__read_only image2d_t input0, __read_only image2d_t input1,\n" \ +" __read_only image2d_t input2, __read_only image2d_t input3,\n" \ +" __read_only image2d_t input4, __global DTYPE *output, int4 input_shape0,\n" \ +" int4 input_shape1, int4 input_shape2, int4 input_shape3, int4 input_shape4,\n" \ +" int stride_w, int4 output_shape) {\n" \ +" CHECK_IDX_UNALIGN;\n" \ +" int Align_Shape2 = UP_DIV(input_shape2.w, C4NUM), Align_Shape3 = UP_DIV(input_shape3.w, C4NUM);\n" \ +" int Align_Shape4 = UP_DIV(input_shape4.w, C4NUM);\n" \ +" index_output = doconcat(input0, output, Align_Shape0, input_shape0, IN, IH, Y, index_output);\n" \ +" index_output = doconcat(input1, output, Align_Shape1, input_shape1, IN, IH, Y, index_output);\n" \ +" index_output = doconcat(input2, output, Align_Shape2, input_shape2, IN, IH, Y, index_output);\n" \ +" index_output = doconcat(input3, output, Align_Shape3, input_shape3, IN, IH, Y, index_output);\n" \ +" index_output = doconcat(input4, output, Align_Shape4, input_shape4, IN, IH, Y, index_output);\n" \ +"}\n" \ +"\n" \ +"__kernel void ConcatInput6UnAlign_NHWC4(__read_only image2d_t input0, __read_only image2d_t input1,\n" \ +" __read_only image2d_t input2, __read_only image2d_t input3,\n" \ +" __read_only image2d_t input4, __read_only image2d_t input5,\n" \ +" __global DTYPE *output, int4 input_shape0, int4 input_shape1, int4 input_shape2,\n" \ +" int4 input_shape3, int4 input_shape4, int4 input_shape5, int stride_w,\n" \ +" int4 output_shape) {\n" \ +" CHECK_IDX_UNALIGN;\n" \ +" int Align_Shape2 = UP_DIV(input_shape2.w, C4NUM), Align_Shape3 = UP_DIV(input_shape3.w, C4NUM);\n" \ +" int Align_Shape4 = UP_DIV(input_shape4.w, C4NUM), Align_Shape5 = UP_DIV(input_shape5.w, C4NUM);\n" \ +" index_output = doconcat(input0, output, Align_Shape0, input_shape0, IN, IH, Y, index_output);\n" \ +" index_output = doconcat(input1, output, Align_Shape1, input_shape1, IN, IH, Y, index_output);\n" \ +" index_output = doconcat(input2, output, Align_Shape2, input_shape2, IN, IH, Y, index_output);\n" \ +" index_output = doconcat(input3, output, Align_Shape3, input_shape3, IN, IH, Y, index_output);\n" \ +" index_output = doconcat(input4, output, Align_Shape4, input_shape4, IN, IH, Y, index_output);\n" \ +" index_output = doconcat(input5, output, Align_Shape5, input_shape5, IN, IH, Y, index_output);\n" \ +"}\n" \ +; diff --git a/mindspore-lite/src/runtime/kernel/opencl/cl/conv2d.cl.inc b/mindspore-lite/src/runtime/kernel/opencl/cl/conv2d.cl.inc new file mode 100644 index 0000000000000000000000000000000000000000..a3e70731e35bf44b33707a03c105e39268c81f8e --- /dev/null +++ b/mindspore-lite/src/runtime/kernel/opencl/cl/conv2d.cl.inc @@ -0,0 +1,1541 @@ +static const char *conv2d_source ="\n" +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" \ +"\n" \ +"__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" \ +"\n" \ +"#define CI_TILE 4\n" \ +"#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))\n" \ +"\n" \ +"#define DEFINE_ARGS \\\n" \ +" int N = input_shape.x; \\\n" \ +" int IH = input_shape.y, IW = input_shape.z, CI_SLICES = input_shape.w; \\\n" \ +" int OH = output_shape.y, OW = output_shape.z, CO_SLICES = output_shape.w; \\\n" \ +" int KH = kernel_stride.x, KW = kernel_stride.y; \\\n" \ +" int strideH = kernel_stride.z, strideW = kernel_stride.w; \\\n" \ +" int padTop = pad.x, padBottom = pad.y, padLeft = pad.z, padRight = pad.w; \\\n" \ +" int dilationH = dilation.x, dilationW = dilation.y; \\\n" \ +" \\\n" \ +" int n_oh = get_global_id(0); \\\n" \ +" int ow = get_global_id(1) * BlockW; \\\n" \ +" int co_slice = get_global_id(2) * BlockC; \\\n" \ +" int OH_SLICES = UP_DIV(OH, BlockH); \\\n" \ +" int n = n_oh / OH_SLICES; \\\n" \ +" int oh = (n_oh % OH_SLICES) * BlockH; \\\n" \ +" if (n >= N || oh >= OH || ow >= OW || co_slice >= CO_SLICES) { \\\n" \ +" return; \\\n" \ +" }\n" \ +"\n" \ +"#define DO_TANH(data) \\\n" \ +" exp0 = exp(data); \\\n" \ +" exp1 = exp(-data); \\\n" \ +" data = (exp0 - exp1) / (exp0 + exp1);\n" \ +"\n" \ +"#define DO_LEAKY_RELU(data, alpha) \\\n" \ +" data.x = data.x > 0 ? data.x : data.x * alpha; \\\n" \ +" data.y = data.y > 0 ? data.y : data.y * alpha; \\\n" \ +" data.z = data.z > 0 ? data.z : data.z * alpha; \\\n" \ +" data.w = data.w > 0 ? data.w : data.w * alpha;\n" \ +"\n" \ +"__kernel void Conv2D_H1W1C1(__read_only image2d_t input, __write_only image2d_t output, __global FLT4 *weight,\n" \ +" __global FLT4 *bias, int4 input_shape, int4 output_shape, int4 kernel_stride, int4 pad,\n" \ +" int2 dilation, int act_type, float alpha) {\n" \ +" const int BlockH = 1;\n" \ +" const int BlockW = 1;\n" \ +" const int BlockC = 1;\n" \ +" DEFINE_ARGS;\n" \ +"\n" \ +" int oh0 = oh + 0;\n" \ +" int n_oh0 = n * OH + oh0;\n" \ +" int ow0 = ow + 0;\n" \ +" int co_slice0 = co_slice + 0;\n" \ +"\n" \ +" FLT4 out_h0_w0_c0 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +"\n" \ +" __global FLT4 *weight_ptr = weight + co_slice / BlockC * KH * KW * CI_SLICES * BlockC * CI_TILE;\n" \ +"\n" \ +" for (int kh = 0; kh < KH; ++kh) {\n" \ +" int ih0 = kh * dilationH + oh0 * strideH - padTop;\n" \ +" int y_idx0 = (ih0 >= 0 && ih0 < IH) ? n * IH + ih0 : -1;\n" \ +"\n" \ +" for (int kw = 0; kw < KW; ++kw) {\n" \ +" int iw0 = kw * dilationW + ow0 * strideW - padLeft;\n" \ +" int x_idx0 = iw0 * CI_SLICES;\n" \ +"\n" \ +" for (int ci_slice = 0; ci_slice < CI_SLICES; ci_slice++) {\n" \ +" FLT4 in_h0_w0 = READ_IMAGE(input, smp_zero, (int2)(x_idx0, y_idx0));\n" \ +" x_idx0++;\n" \ +"\n" \ +" out_h0_w0_c0 += weight_ptr[0] * in_h0_w0.x;\n" \ +" out_h0_w0_c0 += weight_ptr[1] * in_h0_w0.y;\n" \ +" out_h0_w0_c0 += weight_ptr[2] * in_h0_w0.z;\n" \ +" out_h0_w0_c0 += weight_ptr[3] * in_h0_w0.w;\n" \ +"\n" \ +" weight_ptr += 4;\n" \ +" }\n" \ +" }\n" \ +" }\n" \ +"\n" \ +" if (bias != 0) {\n" \ +" out_h0_w0_c0 += bias[co_slice0];\n" \ +" }\n" \ +"\n" \ +" if (act_type == ActivationType_RELU) {\n" \ +" out_h0_w0_c0 = max(out_h0_w0_c0, (FLT4)(0.0f));\n" \ +" } else if (act_type == ActivationType_RELU6) {\n" \ +" out_h0_w0_c0 = clamp(out_h0_w0_c0, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" } else if (act_type == ActivationType_TANH) {\n" \ +" FLT4 exp0, exp1;\n" \ +" DO_TANH(out_h0_w0_c0);\n" \ +" } else if (act_type == ActivationType_LEAKY_RELU) {\n" \ +" DO_LEAKY_RELU(out_h0_w0_c0, alpha);\n" \ +" } else if (act_type == ActivationType_SIGMOID) {\n" \ +" out_h0_w0_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h0_w0_c0));\n" \ +" }\n" \ +"\n" \ +" if (OW * CO_SLICES <= MAX_IMAGE2D_WIDTH) {\n" \ +" WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0);\n" \ +" } else {\n" \ +" WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0);\n" \ +" }\n" \ +"}\n" \ +"\n" \ +"__kernel void Conv2D_H2W1C1(__read_only image2d_t input, __write_only image2d_t output, __global FLT4 *weight,\n" \ +" __global FLT4 *bias, int4 input_shape, int4 output_shape, int4 kernel_stride, int4 pad,\n" \ +" int2 dilation, int act_type, float alpha) {\n" \ +" const int BlockH = 2;\n" \ +" const int BlockW = 1;\n" \ +" const int BlockC = 1;\n" \ +" DEFINE_ARGS;\n" \ +"\n" \ +" int oh0 = oh + 0;\n" \ +" int oh1 = oh + 1;\n" \ +" int n_oh0 = n * OH + oh0;\n" \ +" int n_oh1 = n * OH + oh1;\n" \ +" int ow0 = ow + 0;\n" \ +" int co_slice0 = co_slice + 0;\n" \ +"\n" \ +" FLT4 out_h0_w0_c0 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h1_w0_c0 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +"\n" \ +" __global FLT4 *weight_ptr = weight + co_slice / BlockC * KH * KW * CI_SLICES * BlockC * CI_TILE;\n" \ +"\n" \ +" for (int kh = 0; kh < KH; ++kh) {\n" \ +" int ih0 = kh * dilationH + oh0 * strideH - padTop;\n" \ +" // no need to check oh1, finally write out will check (oh1 < OH)\n" \ +" int ih1 = kh * dilationH + oh1 * strideH - padTop;\n" \ +" // check ih0 and ih1\n" \ +" int y_idx0 = (ih0 >= 0 && ih0 < IH) ? n * IH + ih0 : -1;\n" \ +" int y_idx1 = (ih1 >= 0 && ih1 < IH) ? n * IH + ih1 : -1;\n" \ +"\n" \ +" for (int kw = 0; kw < KW; ++kw) {\n" \ +" int iw0 = kw * dilationW + ow0 * strideW - padLeft;\n" \ +" int x_idx0 = iw0 * CI_SLICES;\n" \ +"\n" \ +" for (int ci_slice = 0; ci_slice < CI_SLICES; ci_slice++) {\n" \ +" FLT4 in_h0_w0 = READ_IMAGE(input, smp_zero, (int2)(x_idx0, y_idx0));\n" \ +" FLT4 in_h1_w0 = READ_IMAGE(input, smp_zero, (int2)(x_idx0, y_idx1));\n" \ +" x_idx0++;\n" \ +"\n" \ +" out_h0_w0_c0 += weight_ptr[0] * in_h0_w0.x;\n" \ +" out_h1_w0_c0 += weight_ptr[0] * in_h1_w0.x;\n" \ +" out_h0_w0_c0 += weight_ptr[1] * in_h0_w0.y;\n" \ +" out_h1_w0_c0 += weight_ptr[1] * in_h1_w0.y;\n" \ +" out_h0_w0_c0 += weight_ptr[2] * in_h0_w0.z;\n" \ +" out_h1_w0_c0 += weight_ptr[2] * in_h1_w0.z;\n" \ +" out_h0_w0_c0 += weight_ptr[3] * in_h0_w0.w;\n" \ +" out_h1_w0_c0 += weight_ptr[3] * in_h1_w0.w;\n" \ +"\n" \ +" weight_ptr += 4;\n" \ +" }\n" \ +" }\n" \ +" }\n" \ +"\n" \ +" if (bias != 0) {\n" \ +" out_h0_w0_c0 += bias[co_slice0];\n" \ +" out_h1_w0_c0 += bias[co_slice0];\n" \ +" }\n" \ +"\n" \ +" if (act_type == ActivationType_RELU) {\n" \ +" out_h0_w0_c0 = max(out_h0_w0_c0, (FLT4)(0.0f));\n" \ +" out_h1_w0_c0 = max(out_h1_w0_c0, (FLT4)(0.0f));\n" \ +" } else if (act_type == ActivationType_RELU6) {\n" \ +" out_h0_w0_c0 = clamp(out_h0_w0_c0, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h1_w0_c0 = clamp(out_h1_w0_c0, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" } else if (act_type == ActivationType_TANH) {\n" \ +" FLT4 exp0, exp1;\n" \ +" DO_TANH(out_h0_w0_c0);\n" \ +" DO_TANH(out_h1_w0_c0);\n" \ +" } else if (act_type == ActivationType_LEAKY_RELU) {\n" \ +" DO_LEAKY_RELU(out_h0_w0_c0, alpha);\n" \ +" DO_LEAKY_RELU(out_h1_w0_c0, alpha);\n" \ +" } else if (act_type == ActivationType_SIGMOID) {\n" \ +" out_h0_w0_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h0_w0_c0));\n" \ +" out_h1_w0_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w0_c0));\n" \ +" }\n" \ +"\n" \ +" if (OW * CO_SLICES <= MAX_IMAGE2D_WIDTH) {\n" \ +" WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0);\n" \ +" if (oh1 < OH) {\n" \ +" WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0);\n" \ +" } // end if (oh1 < OH)\n" \ +" } else {\n" \ +" WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0);\n" \ +" if (oh1 < OH) {\n" \ +" WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0);\n" \ +" } // end (oh1 < OH)\n" \ +" }\n" \ +"}\n" \ +"\n" \ +"__kernel void Conv2D_H2W1C2(__read_only image2d_t input, __write_only image2d_t output, __global FLT4 *weight,\n" \ +" __global FLT4 *bias, int4 input_shape, int4 output_shape, int4 kernel_stride, int4 pad,\n" \ +" int2 dilation, int act_type, float alpha) {\n" \ +" const int BlockH = 2;\n" \ +" const int BlockW = 1;\n" \ +" const int BlockC = 2;\n" \ +" DEFINE_ARGS;\n" \ +"\n" \ +" int oh0 = oh + 0;\n" \ +" int oh1 = oh + 1;\n" \ +" int n_oh0 = n * OH + oh0;\n" \ +" int n_oh1 = n * OH + oh1;\n" \ +" int ow0 = ow + 0;\n" \ +" int co_slice0 = co_slice + 0;\n" \ +" int co_slice1 = co_slice + 1;\n" \ +"\n" \ +" FLT4 out_h0_w0_c0 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h1_w0_c0 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h0_w0_c1 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h1_w0_c1 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +"\n" \ +" __global FLT4 *weight_ptr = weight + co_slice / BlockC * KH * KW * CI_SLICES * BlockC * CI_TILE;\n" \ +"\n" \ +" for (int kh = 0; kh < KH; ++kh) {\n" \ +" int ih0 = kh * dilationH + oh0 * strideH - padTop;\n" \ +" // no need to check oh1, finally write out will check (oh1 < OH)\n" \ +" int ih1 = kh * dilationH + oh1 * strideH - padTop;\n" \ +" // check ih0 and ih1\n" \ +" int y_idx0 = (ih0 >= 0 && ih0 < IH) ? n * IH + ih0 : -1;\n" \ +" int y_idx1 = (ih1 >= 0 && ih1 < IH) ? n * IH + ih1 : -1;\n" \ +"\n" \ +" for (int kw = 0; kw < KW; ++kw) {\n" \ +" int iw0 = kw * dilationW + ow0 * strideW - padLeft;\n" \ +" int x_idx0 = iw0 * CI_SLICES;\n" \ +"\n" \ +" for (int ci_slice = 0; ci_slice < CI_SLICES; ci_slice++) {\n" \ +" FLT4 in_h0_w0 = READ_IMAGE(input, smp_zero, (int2)(x_idx0, y_idx0));\n" \ +" FLT4 in_h1_w0 = READ_IMAGE(input, smp_zero, (int2)(x_idx0, y_idx1));\n" \ +" x_idx0++;\n" \ +"\n" \ +" out_h0_w0_c0 += weight_ptr[0] * in_h0_w0.x;\n" \ +" out_h1_w0_c0 += weight_ptr[0] * in_h1_w0.x;\n" \ +" out_h0_w0_c0 += weight_ptr[1] * in_h0_w0.y;\n" \ +" out_h1_w0_c0 += weight_ptr[1] * in_h1_w0.y;\n" \ +" out_h0_w0_c0 += weight_ptr[2] * in_h0_w0.z;\n" \ +" out_h1_w0_c0 += weight_ptr[2] * in_h1_w0.z;\n" \ +" out_h0_w0_c0 += weight_ptr[3] * in_h0_w0.w;\n" \ +" out_h1_w0_c0 += weight_ptr[3] * in_h1_w0.w;\n" \ +"\n" \ +" out_h0_w0_c1 += weight_ptr[4] * in_h0_w0.x;\n" \ +" out_h1_w0_c1 += weight_ptr[4] * in_h1_w0.x;\n" \ +" out_h0_w0_c1 += weight_ptr[5] * in_h0_w0.y;\n" \ +" out_h1_w0_c1 += weight_ptr[5] * in_h1_w0.y;\n" \ +" out_h0_w0_c1 += weight_ptr[6] * in_h0_w0.z;\n" \ +" out_h1_w0_c1 += weight_ptr[6] * in_h1_w0.z;\n" \ +" out_h0_w0_c1 += weight_ptr[7] * in_h0_w0.w;\n" \ +" out_h1_w0_c1 += weight_ptr[7] * in_h1_w0.w;\n" \ +"\n" \ +" weight_ptr += 8;\n" \ +" }\n" \ +" }\n" \ +" }\n" \ +"\n" \ +" if (bias != 0) {\n" \ +" out_h0_w0_c0 += bias[co_slice0];\n" \ +" out_h1_w0_c0 += bias[co_slice0];\n" \ +" out_h0_w0_c1 += bias[co_slice1];\n" \ +" out_h1_w0_c1 += bias[co_slice1];\n" \ +" }\n" \ +"\n" \ +" if (act_type == ActivationType_RELU) {\n" \ +" out_h0_w0_c0 = max(out_h0_w0_c0, (FLT4)(0.0f));\n" \ +" out_h1_w0_c0 = max(out_h1_w0_c0, (FLT4)(0.0f));\n" \ +" out_h0_w0_c1 = max(out_h0_w0_c1, (FLT4)(0.0f));\n" \ +" out_h1_w0_c1 = max(out_h1_w0_c1, (FLT4)(0.0f));\n" \ +" } else if (act_type == ActivationType_RELU6) {\n" \ +" out_h0_w0_c0 = clamp(out_h0_w0_c0, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h1_w0_c0 = clamp(out_h1_w0_c0, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h0_w0_c1 = clamp(out_h0_w0_c1, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h1_w0_c1 = clamp(out_h1_w0_c1, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" } else if (act_type == ActivationType_TANH) {\n" \ +" FLT4 exp0, exp1;\n" \ +" DO_TANH(out_h0_w0_c0);\n" \ +" DO_TANH(out_h1_w0_c0);\n" \ +" DO_TANH(out_h0_w0_c1);\n" \ +" DO_TANH(out_h1_w0_c1);\n" \ +" } else if (act_type == ActivationType_LEAKY_RELU) {\n" \ +" DO_LEAKY_RELU(out_h0_w0_c0, alpha);\n" \ +" DO_LEAKY_RELU(out_h1_w0_c0, alpha);\n" \ +" DO_LEAKY_RELU(out_h0_w0_c1, alpha);\n" \ +" DO_LEAKY_RELU(out_h1_w0_c1, alpha);\n" \ +" } else if (act_type == ActivationType_SIGMOID) {\n" \ +" out_h0_w0_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h0_w0_c0));\n" \ +" out_h1_w0_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w0_c0));\n" \ +" out_h0_w0_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h0_w0_c1));\n" \ +" out_h1_w0_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w0_c1));\n" \ +" }\n" \ +"\n" \ +" if (OW * CO_SLICES <= MAX_IMAGE2D_WIDTH) {\n" \ +" WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0);\n" \ +" if (oh1 < OH) {\n" \ +" WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0);\n" \ +" } // end if (oh1 < OH)\n" \ +" if (co_slice1 < CO_SLICES) {\n" \ +" WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh0), out_h0_w0_c1);\n" \ +" if (oh1 < OH) {\n" \ +" WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh1), out_h1_w0_c1);\n" \ +" } // end if (oh1 < OH)\n" \ +" } // end if (co_slice1 < CO_SLICES)\n" \ +" } else {\n" \ +" WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0);\n" \ +" if (oh1 < OH) {\n" \ +" WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0);\n" \ +" } // end (oh1 < OH)\n" \ +" if (co_slice1 < CO_SLICES) {\n" \ +" WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow0), out_h0_w0_c1);\n" \ +" if (oh1 < OH) {\n" \ +" WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow0), out_h1_w0_c1);\n" \ +" } // end if (oh1 < OH)\n" \ +" } // end if (co_slice1 < CO_SLICES)\n" \ +" }\n" \ +"}\n" \ +"\n" \ +"__kernel void Conv2D_H2W2C1(__read_only image2d_t input, __write_only image2d_t output, __global FLT4 *weight,\n" \ +" __global FLT4 *bias, int4 input_shape, int4 output_shape, int4 kernel_stride, int4 pad,\n" \ +" int2 dilation, int act_type, float alpha) {\n" \ +" const int BlockH = 2;\n" \ +" const int BlockW = 2;\n" \ +" const int BlockC = 1;\n" \ +" DEFINE_ARGS;\n" \ +"\n" \ +" int oh0 = oh + 0;\n" \ +" int oh1 = oh + 1;\n" \ +" int n_oh0 = n * OH + oh0;\n" \ +" int n_oh1 = n * OH + oh1;\n" \ +" int ow0 = ow + 0;\n" \ +" int ow1 = ow + 1;\n" \ +" int co_slice0 = co_slice + 0;\n" \ +"\n" \ +" FLT4 out_h0_w0_c0 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h0_w1_c0 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h1_w0_c0 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h1_w1_c0 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +"\n" \ +" __global FLT4 *weight_ptr = weight + co_slice / BlockC * KH * KW * CI_SLICES * BlockC * CI_TILE;\n" \ +"\n" \ +" for (int kh = 0; kh < KH; ++kh) {\n" \ +" int ih0 = kh * dilationH + oh0 * strideH - padTop;\n" \ +" // no need to check oh1, finally write out will check (oh1 < OH)\n" \ +" int ih1 = kh * dilationH + oh1 * strideH - padTop;\n" \ +" // check ih0 and ih1\n" \ +" int y_idx0 = (ih0 >= 0 && ih0 < IH) ? n * IH + ih0 : -1;\n" \ +" int y_idx1 = (ih1 >= 0 && ih1 < IH) ? n * IH + ih1 : -1;\n" \ +"\n" \ +" for (int kw = 0; kw < KW; ++kw) {\n" \ +" int iw0 = kw * dilationW + ow0 * strideW - padLeft;\n" \ +" int iw1 = (ow1 < OW) ? kw * dilationW + ow1 * strideW - padLeft : -2;\n" \ +" int x_idx0 = iw0 * CI_SLICES;\n" \ +" int x_idx1 = iw1 * CI_SLICES;\n" \ +"\n" \ +" for (int ci_slice = 0; ci_slice < CI_SLICES; ci_slice++) {\n" \ +" FLT4 in_h0_w0 = READ_IMAGE(input, smp_zero, (int2)(x_idx0, y_idx0));\n" \ +" FLT4 in_h0_w1 = READ_IMAGE(input, smp_zero, (int2)(x_idx1, y_idx0));\n" \ +" FLT4 in_h1_w0 = READ_IMAGE(input, smp_zero, (int2)(x_idx0, y_idx1));\n" \ +" FLT4 in_h1_w1 = READ_IMAGE(input, smp_zero, (int2)(x_idx1, y_idx1));\n" \ +" x_idx0++;\n" \ +" x_idx1++;\n" \ +"\n" \ +" out_h0_w0_c0 += weight_ptr[0] * in_h0_w0.x;\n" \ +" out_h0_w1_c0 += weight_ptr[0] * in_h0_w1.x;\n" \ +" out_h1_w0_c0 += weight_ptr[0] * in_h1_w0.x;\n" \ +" out_h1_w1_c0 += weight_ptr[0] * in_h1_w1.x;\n" \ +" out_h0_w0_c0 += weight_ptr[1] * in_h0_w0.y;\n" \ +" out_h0_w1_c0 += weight_ptr[1] * in_h0_w1.y;\n" \ +" out_h1_w0_c0 += weight_ptr[1] * in_h1_w0.y;\n" \ +" out_h1_w1_c0 += weight_ptr[1] * in_h1_w1.y;\n" \ +" out_h0_w0_c0 += weight_ptr[2] * in_h0_w0.z;\n" \ +" out_h0_w1_c0 += weight_ptr[2] * in_h0_w1.z;\n" \ +" out_h1_w0_c0 += weight_ptr[2] * in_h1_w0.z;\n" \ +" out_h1_w1_c0 += weight_ptr[2] * in_h1_w1.z;\n" \ +" out_h0_w0_c0 += weight_ptr[3] * in_h0_w0.w;\n" \ +" out_h0_w1_c0 += weight_ptr[3] * in_h0_w1.w;\n" \ +" out_h1_w0_c0 += weight_ptr[3] * in_h1_w0.w;\n" \ +" out_h1_w1_c0 += weight_ptr[3] * in_h1_w1.w;\n" \ +"\n" \ +" weight_ptr += 4;\n" \ +" }\n" \ +" }\n" \ +" }\n" \ +"\n" \ +" if (bias != 0) {\n" \ +" out_h0_w0_c0 += bias[co_slice0];\n" \ +" out_h0_w1_c0 += bias[co_slice0];\n" \ +" out_h1_w0_c0 += bias[co_slice0];\n" \ +" out_h1_w1_c0 += bias[co_slice0];\n" \ +" }\n" \ +"\n" \ +" if (act_type == ActivationType_RELU) {\n" \ +" out_h0_w0_c0 = max(out_h0_w0_c0, (FLT4)(0.0f));\n" \ +" out_h0_w1_c0 = max(out_h0_w1_c0, (FLT4)(0.0f));\n" \ +" out_h1_w0_c0 = max(out_h1_w0_c0, (FLT4)(0.0f));\n" \ +" out_h1_w1_c0 = max(out_h1_w1_c0, (FLT4)(0.0f));\n" \ +"\n" \ +" } else if (act_type == ActivationType_RELU6) {\n" \ +" out_h0_w0_c0 = clamp(out_h0_w0_c0, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h0_w1_c0 = clamp(out_h0_w1_c0, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h1_w0_c0 = clamp(out_h1_w0_c0, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h1_w1_c0 = clamp(out_h1_w1_c0, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +"\n" \ +" } else if (act_type == ActivationType_TANH) {\n" \ +" FLT4 exp0, exp1;\n" \ +" DO_TANH(out_h0_w0_c0);\n" \ +" DO_TANH(out_h0_w1_c0);\n" \ +" DO_TANH(out_h1_w0_c0);\n" \ +" DO_TANH(out_h1_w1_c0);\n" \ +"\n" \ +" } else if (act_type == ActivationType_LEAKY_RELU) {\n" \ +" DO_LEAKY_RELU(out_h0_w0_c0, alpha);\n" \ +" DO_LEAKY_RELU(out_h0_w1_c0, alpha);\n" \ +" DO_LEAKY_RELU(out_h1_w0_c0, alpha);\n" \ +" DO_LEAKY_RELU(out_h1_w1_c0, alpha);\n" \ +"\n" \ +" } else if (act_type == ActivationType_SIGMOID) {\n" \ +" out_h0_w0_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h0_w0_c0));\n" \ +" out_h0_w1_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h0_w1_c0));\n" \ +" out_h1_w0_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w0_c0));\n" \ +" out_h1_w1_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w1_c0));\n" \ +" }\n" \ +"\n" \ +" if (OW * CO_SLICES <= MAX_IMAGE2D_WIDTH) {\n" \ +" WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0);\n" \ +" WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh0), out_h0_w1_c0);\n" \ +" if (oh1 < OH) {\n" \ +" WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0);\n" \ +" WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh1), out_h1_w1_c0);\n" \ +" } // end if (oh1 < OH)\n" \ +" } else {\n" \ +" WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0);\n" \ +" WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow1), out_h0_w1_c0);\n" \ +" if (oh1 < OH) {\n" \ +" WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0);\n" \ +" WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow1), out_h1_w1_c0);\n" \ +" } // end (oh1 < OH)\n" \ +" }\n" \ +"}\n" \ +"\n" \ +"__kernel void Conv2D_H2W2C2(__read_only image2d_t input, __write_only image2d_t output, __global FLT4 *weight,\n" \ +" __global FLT4 *bias, int4 input_shape, int4 output_shape, int4 kernel_stride, int4 pad,\n" \ +" int2 dilation, int act_type, float alpha) {\n" \ +" const int BlockH = 2;\n" \ +" const int BlockW = 2;\n" \ +" const int BlockC = 2;\n" \ +" DEFINE_ARGS;\n" \ +"\n" \ +" int oh0 = oh + 0;\n" \ +" int oh1 = oh + 1;\n" \ +" int n_oh0 = n * OH + oh0;\n" \ +" int n_oh1 = n * OH + oh1;\n" \ +" int ow0 = ow + 0;\n" \ +" int ow1 = ow + 1;\n" \ +" int co_slice0 = co_slice + 0;\n" \ +" int co_slice1 = co_slice + 1;\n" \ +"\n" \ +" FLT4 out_h0_w0_c0 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h0_w1_c0 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h1_w0_c0 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h1_w1_c0 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h0_w0_c1 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h0_w1_c1 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h1_w0_c1 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h1_w1_c1 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +"\n" \ +" __global FLT4 *weight_ptr = weight + co_slice / BlockC * KH * KW * CI_SLICES * BlockC * CI_TILE;\n" \ +"\n" \ +" for (int kh = 0; kh < KH; ++kh) {\n" \ +" int ih0 = kh * dilationH + oh0 * strideH - padTop;\n" \ +" // no need to check oh1, finally write out will check (oh1 < OH)\n" \ +" int ih1 = kh * dilationH + oh1 * strideH - padTop;\n" \ +" // check ih0 and ih1\n" \ +" int y_idx0 = (ih0 >= 0 && ih0 < IH) ? n * IH + ih0 : -1;\n" \ +" int y_idx1 = (ih1 >= 0 && ih1 < IH) ? n * IH + ih1 : -1;\n" \ +"\n" \ +" for (int kw = 0; kw < KW; ++kw) {\n" \ +" int iw0 = kw * dilationW + ow0 * strideW - padLeft;\n" \ +" int iw1 = (ow1 < OW) ? kw * dilationW + ow1 * strideW - padLeft : -2;\n" \ +" int x_idx0 = iw0 * CI_SLICES;\n" \ +" int x_idx1 = iw1 * CI_SLICES;\n" \ +"\n" \ +" for (int ci_slice = 0; ci_slice < CI_SLICES; ci_slice++) {\n" \ +" FLT4 in_h0_w0 = READ_IMAGE(input, smp_zero, (int2)(x_idx0, y_idx0));\n" \ +" FLT4 in_h0_w1 = READ_IMAGE(input, smp_zero, (int2)(x_idx1, y_idx0));\n" \ +" FLT4 in_h1_w0 = READ_IMAGE(input, smp_zero, (int2)(x_idx0, y_idx1));\n" \ +" FLT4 in_h1_w1 = READ_IMAGE(input, smp_zero, (int2)(x_idx1, y_idx1));\n" \ +" x_idx0++;\n" \ +" x_idx1++;\n" \ +"\n" \ +" out_h0_w0_c0 += weight_ptr[0] * in_h0_w0.x;\n" \ +" out_h0_w1_c0 += weight_ptr[0] * in_h0_w1.x;\n" \ +" out_h1_w0_c0 += weight_ptr[0] * in_h1_w0.x;\n" \ +" out_h1_w1_c0 += weight_ptr[0] * in_h1_w1.x;\n" \ +" out_h0_w0_c0 += weight_ptr[1] * in_h0_w0.y;\n" \ +" out_h0_w1_c0 += weight_ptr[1] * in_h0_w1.y;\n" \ +" out_h1_w0_c0 += weight_ptr[1] * in_h1_w0.y;\n" \ +" out_h1_w1_c0 += weight_ptr[1] * in_h1_w1.y;\n" \ +" out_h0_w0_c0 += weight_ptr[2] * in_h0_w0.z;\n" \ +" out_h0_w1_c0 += weight_ptr[2] * in_h0_w1.z;\n" \ +" out_h1_w0_c0 += weight_ptr[2] * in_h1_w0.z;\n" \ +" out_h1_w1_c0 += weight_ptr[2] * in_h1_w1.z;\n" \ +" out_h0_w0_c0 += weight_ptr[3] * in_h0_w0.w;\n" \ +" out_h0_w1_c0 += weight_ptr[3] * in_h0_w1.w;\n" \ +" out_h1_w0_c0 += weight_ptr[3] * in_h1_w0.w;\n" \ +" out_h1_w1_c0 += weight_ptr[3] * in_h1_w1.w;\n" \ +"\n" \ +" out_h0_w0_c1 += weight_ptr[4] * in_h0_w0.x;\n" \ +" out_h0_w1_c1 += weight_ptr[4] * in_h0_w1.x;\n" \ +" out_h1_w0_c1 += weight_ptr[4] * in_h1_w0.x;\n" \ +" out_h1_w1_c1 += weight_ptr[4] * in_h1_w1.x;\n" \ +" out_h0_w0_c1 += weight_ptr[5] * in_h0_w0.y;\n" \ +" out_h0_w1_c1 += weight_ptr[5] * in_h0_w1.y;\n" \ +" out_h1_w0_c1 += weight_ptr[5] * in_h1_w0.y;\n" \ +" out_h1_w1_c1 += weight_ptr[5] * in_h1_w1.y;\n" \ +" out_h0_w0_c1 += weight_ptr[6] * in_h0_w0.z;\n" \ +" out_h0_w1_c1 += weight_ptr[6] * in_h0_w1.z;\n" \ +" out_h1_w0_c1 += weight_ptr[6] * in_h1_w0.z;\n" \ +" out_h1_w1_c1 += weight_ptr[6] * in_h1_w1.z;\n" \ +" out_h0_w0_c1 += weight_ptr[7] * in_h0_w0.w;\n" \ +" out_h0_w1_c1 += weight_ptr[7] * in_h0_w1.w;\n" \ +" out_h1_w0_c1 += weight_ptr[7] * in_h1_w0.w;\n" \ +" out_h1_w1_c1 += weight_ptr[7] * in_h1_w1.w;\n" \ +"\n" \ +" weight_ptr += 8;\n" \ +" }\n" \ +" }\n" \ +" }\n" \ +"\n" \ +" if (bias != 0) {\n" \ +" out_h0_w0_c0 += bias[co_slice0];\n" \ +" out_h0_w1_c0 += bias[co_slice0];\n" \ +" out_h1_w0_c0 += bias[co_slice0];\n" \ +" out_h1_w1_c0 += bias[co_slice0];\n" \ +" out_h0_w0_c1 += bias[co_slice1];\n" \ +" out_h0_w1_c1 += bias[co_slice1];\n" \ +" out_h1_w0_c1 += bias[co_slice1];\n" \ +" out_h1_w1_c1 += bias[co_slice1];\n" \ +" }\n" \ +"\n" \ +" if (act_type == ActivationType_RELU) {\n" \ +" out_h0_w0_c0 = max(out_h0_w0_c0, (FLT4)(0.0f));\n" \ +" out_h0_w1_c0 = max(out_h0_w1_c0, (FLT4)(0.0f));\n" \ +" out_h1_w0_c0 = max(out_h1_w0_c0, (FLT4)(0.0f));\n" \ +" out_h1_w1_c0 = max(out_h1_w1_c0, (FLT4)(0.0f));\n" \ +" out_h0_w0_c1 = max(out_h0_w0_c1, (FLT4)(0.0f));\n" \ +" out_h0_w1_c1 = max(out_h0_w1_c1, (FLT4)(0.0f));\n" \ +" out_h1_w0_c1 = max(out_h1_w0_c1, (FLT4)(0.0f));\n" \ +" out_h1_w1_c1 = max(out_h1_w1_c1, (FLT4)(0.0f));\n" \ +" } else if (act_type == ActivationType_RELU6) {\n" \ +" out_h0_w0_c0 = clamp(out_h0_w0_c0, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h0_w1_c0 = clamp(out_h0_w1_c0, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h1_w0_c0 = clamp(out_h1_w0_c0, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h1_w1_c0 = clamp(out_h1_w1_c0, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h0_w0_c1 = clamp(out_h0_w0_c1, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h0_w1_c1 = clamp(out_h0_w1_c1, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h1_w0_c1 = clamp(out_h1_w0_c1, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h1_w1_c1 = clamp(out_h1_w1_c1, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" } else if (act_type == ActivationType_TANH) {\n" \ +" FLT4 exp0, exp1;\n" \ +" DO_TANH(out_h0_w0_c0);\n" \ +" DO_TANH(out_h0_w1_c0);\n" \ +" DO_TANH(out_h1_w0_c0);\n" \ +" DO_TANH(out_h1_w1_c0);\n" \ +" DO_TANH(out_h0_w0_c1);\n" \ +" DO_TANH(out_h0_w1_c1);\n" \ +" DO_TANH(out_h1_w0_c1);\n" \ +" DO_TANH(out_h1_w1_c1);\n" \ +" } else if (act_type == ActivationType_LEAKY_RELU) {\n" \ +" DO_LEAKY_RELU(out_h0_w0_c0, alpha);\n" \ +" DO_LEAKY_RELU(out_h0_w1_c0, alpha);\n" \ +" DO_LEAKY_RELU(out_h1_w0_c0, alpha);\n" \ +" DO_LEAKY_RELU(out_h1_w1_c0, alpha);\n" \ +" DO_LEAKY_RELU(out_h0_w0_c1, alpha);\n" \ +" DO_LEAKY_RELU(out_h0_w1_c1, alpha);\n" \ +" DO_LEAKY_RELU(out_h1_w0_c1, alpha);\n" \ +" DO_LEAKY_RELU(out_h1_w1_c1, alpha);\n" \ +" } else if (act_type == ActivationType_SIGMOID) {\n" \ +" out_h0_w0_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h0_w0_c0));\n" \ +" out_h0_w1_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h0_w1_c0));\n" \ +" out_h1_w0_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w0_c0));\n" \ +" out_h1_w1_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w1_c0));\n" \ +" out_h0_w0_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h0_w0_c1));\n" \ +" out_h0_w1_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h0_w1_c1));\n" \ +" out_h1_w0_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w0_c1));\n" \ +" out_h1_w1_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w1_c1));\n" \ +" }\n" \ +"\n" \ +" if (OW * CO_SLICES <= MAX_IMAGE2D_WIDTH) {\n" \ +" WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0);\n" \ +" WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh0), out_h0_w1_c0);\n" \ +" if (oh1 < OH) {\n" \ +" WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0);\n" \ +" WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh1), out_h1_w1_c0);\n" \ +" } // end if (oh1 < OH)\n" \ +" if (co_slice1 < CO_SLICES) {\n" \ +" WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh0), out_h0_w0_c1);\n" \ +" WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh0), out_h0_w1_c1);\n" \ +" if (oh1 < OH) {\n" \ +" WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh1), out_h1_w0_c1);\n" \ +" WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh1), out_h1_w1_c1);\n" \ +" } // end if (oh1 < OH)\n" \ +" } // end if (co_slice1 < CO_SLICES)\n" \ +" } else {\n" \ +" WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0);\n" \ +" WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow1), out_h0_w1_c0);\n" \ +" if (oh1 < OH) {\n" \ +" WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0);\n" \ +" WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow1), out_h1_w1_c0);\n" \ +" } // end (oh1 < OH)\n" \ +" if (co_slice1 < CO_SLICES) {\n" \ +" WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow0), out_h0_w0_c1);\n" \ +" WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow1), out_h0_w1_c1);\n" \ +" if (oh1 < OH) {\n" \ +" WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow0), out_h1_w0_c1);\n" \ +" WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow1), out_h1_w1_c1);\n" \ +" } // end if (oh1 < OH)\n" \ +" } // end if (co_slice1 < CO_SLICES)\n" \ +" }\n" \ +"}\n" \ +"\n" \ +"__kernel void Conv2D_H2W2C2_Img(__read_only image2d_t input, __write_only image2d_t output,\n" \ +" __read_only image2d_t weight, __global FLT4 *bias, int4 input_shape, int4 output_shape,\n" \ +" int4 kernel_stride, int4 pad, int2 dilation, int act_type, float alpha) {\n" \ +" const int BlockH = 2;\n" \ +" const int BlockW = 2;\n" \ +" const int BlockC = 2;\n" \ +" DEFINE_ARGS;\n" \ +"\n" \ +" int oh0 = oh + 0;\n" \ +" int oh1 = oh + 1;\n" \ +" int n_oh0 = n * OH + oh0;\n" \ +" int n_oh1 = n * OH + oh1;\n" \ +" int ow0 = ow + 0;\n" \ +" int ow1 = ow + 1;\n" \ +" int co_slice0 = co_slice + 0;\n" \ +" int co_slice1 = co_slice + 1;\n" \ +"\n" \ +" FLT4 out_h0_w0_c0 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h0_w1_c0 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h1_w0_c0 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h1_w1_c0 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h0_w0_c1 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h0_w1_c1 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h1_w0_c1 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h1_w1_c1 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +"\n" \ +" int filter_offset = 0;\n" \ +" for (int kh = 0; kh < KH; ++kh) {\n" \ +" int ih0 = kh * dilationH + oh0 * strideH - padTop;\n" \ +" // no need to check oh1, finally write out will check (oh1 < OH)\n" \ +" int ih1 = kh * dilationH + oh1 * strideH - padTop;\n" \ +" // check ih0 and ih1\n" \ +" int y_idx0 = (ih0 >= 0 && ih0 < IH) ? n * IH + ih0 : -1;\n" \ +" int y_idx1 = (ih1 >= 0 && ih1 < IH) ? n * IH + ih1 : -1;\n" \ +"\n" \ +" for (int kw = 0; kw < KW; ++kw) {\n" \ +" int iw0 = kw * dilationW + ow0 * strideW - padLeft;\n" \ +" int iw1 = (ow1 < OW) ? kw * dilationW + ow1 * strideW - padLeft : -2;\n" \ +" int x_idx0 = iw0 * CI_SLICES;\n" \ +" int x_idx1 = iw1 * CI_SLICES;\n" \ +"\n" \ +" for (int ci_slice = 0; ci_slice < CI_SLICES; ci_slice++) {\n" \ +" FLT4 in_h0_w0 = READ_IMAGE(input, smp_zero, (int2)(x_idx0, y_idx0));\n" \ +" FLT4 in_h0_w1 = READ_IMAGE(input, smp_zero, (int2)(x_idx1, y_idx0));\n" \ +" FLT4 in_h1_w0 = READ_IMAGE(input, smp_zero, (int2)(x_idx0, y_idx1));\n" \ +" FLT4 in_h1_w1 = READ_IMAGE(input, smp_zero, (int2)(x_idx1, y_idx1));\n" \ +" x_idx0++;\n" \ +" x_idx1++;\n" \ +"\n" \ +" FLT4 filter_ci0_co0 = READ_IMAGE(weight, smp_zero, (int2)(co_slice0, filter_offset + 0));\n" \ +" FLT4 filter_ci1_co0 = READ_IMAGE(weight, smp_zero, (int2)(co_slice0, filter_offset + 1));\n" \ +" FLT4 filter_ci2_co0 = READ_IMAGE(weight, smp_zero, (int2)(co_slice0, filter_offset + 2));\n" \ +" FLT4 filter_ci3_co0 = READ_IMAGE(weight, smp_zero, (int2)(co_slice0, filter_offset + 3));\n" \ +" FLT4 filter_ci0_co1 = READ_IMAGE(weight, smp_zero, (int2)(co_slice1, filter_offset + 0));\n" \ +" FLT4 filter_ci1_co1 = READ_IMAGE(weight, smp_zero, (int2)(co_slice1, filter_offset + 1));\n" \ +" FLT4 filter_ci2_co1 = READ_IMAGE(weight, smp_zero, (int2)(co_slice1, filter_offset + 2));\n" \ +" FLT4 filter_ci3_co1 = READ_IMAGE(weight, smp_zero, (int2)(co_slice1, filter_offset + 3));\n" \ +" filter_offset += 4;\n" \ +"\n" \ +" out_h0_w0_c0 += filter_ci0_co0 * in_h0_w0.x;\n" \ +" out_h0_w1_c0 += filter_ci0_co0 * in_h0_w1.x;\n" \ +" out_h1_w0_c0 += filter_ci0_co0 * in_h1_w0.x;\n" \ +" out_h1_w1_c0 += filter_ci0_co0 * in_h1_w1.x;\n" \ +" out_h0_w0_c0 += filter_ci1_co0 * in_h0_w0.y;\n" \ +" out_h0_w1_c0 += filter_ci1_co0 * in_h0_w1.y;\n" \ +" out_h1_w0_c0 += filter_ci1_co0 * in_h1_w0.y;\n" \ +" out_h1_w1_c0 += filter_ci1_co0 * in_h1_w1.y;\n" \ +" out_h0_w0_c0 += filter_ci2_co0 * in_h0_w0.z;\n" \ +" out_h0_w1_c0 += filter_ci2_co0 * in_h0_w1.z;\n" \ +" out_h1_w0_c0 += filter_ci2_co0 * in_h1_w0.z;\n" \ +" out_h1_w1_c0 += filter_ci2_co0 * in_h1_w1.z;\n" \ +" out_h0_w0_c0 += filter_ci3_co0 * in_h0_w0.w;\n" \ +" out_h0_w1_c0 += filter_ci3_co0 * in_h0_w1.w;\n" \ +" out_h1_w0_c0 += filter_ci3_co0 * in_h1_w0.w;\n" \ +" out_h1_w1_c0 += filter_ci3_co0 * in_h1_w1.w;\n" \ +"\n" \ +" out_h0_w0_c1 += filter_ci0_co1 * in_h0_w0.x;\n" \ +" out_h0_w1_c1 += filter_ci0_co1 * in_h0_w1.x;\n" \ +" out_h1_w0_c1 += filter_ci0_co1 * in_h1_w0.x;\n" \ +" out_h1_w1_c1 += filter_ci0_co1 * in_h1_w1.x;\n" \ +" out_h0_w0_c1 += filter_ci1_co1 * in_h0_w0.y;\n" \ +" out_h0_w1_c1 += filter_ci1_co1 * in_h0_w1.y;\n" \ +" out_h1_w0_c1 += filter_ci1_co1 * in_h1_w0.y;\n" \ +" out_h1_w1_c1 += filter_ci1_co1 * in_h1_w1.y;\n" \ +" out_h0_w0_c1 += filter_ci2_co1 * in_h0_w0.z;\n" \ +" out_h0_w1_c1 += filter_ci2_co1 * in_h0_w1.z;\n" \ +" out_h1_w0_c1 += filter_ci2_co1 * in_h1_w0.z;\n" \ +" out_h1_w1_c1 += filter_ci2_co1 * in_h1_w1.z;\n" \ +" out_h0_w0_c1 += filter_ci3_co1 * in_h0_w0.w;\n" \ +" out_h0_w1_c1 += filter_ci3_co1 * in_h0_w1.w;\n" \ +" out_h1_w0_c1 += filter_ci3_co1 * in_h1_w0.w;\n" \ +" out_h1_w1_c1 += filter_ci3_co1 * in_h1_w1.w;\n" \ +" }\n" \ +" }\n" \ +" }\n" \ +"\n" \ +" if (bias != 0) {\n" \ +" out_h0_w0_c0 += bias[co_slice0];\n" \ +" out_h0_w1_c0 += bias[co_slice0];\n" \ +" out_h1_w0_c0 += bias[co_slice0];\n" \ +" out_h1_w1_c0 += bias[co_slice0];\n" \ +" out_h0_w0_c1 += bias[co_slice1];\n" \ +" out_h0_w1_c1 += bias[co_slice1];\n" \ +" out_h1_w0_c1 += bias[co_slice1];\n" \ +" out_h1_w1_c1 += bias[co_slice1];\n" \ +" }\n" \ +"\n" \ +" if (act_type == ActivationType_RELU) {\n" \ +" out_h0_w0_c0 = max(out_h0_w0_c0, (FLT4)(0.0f));\n" \ +" out_h0_w1_c0 = max(out_h0_w1_c0, (FLT4)(0.0f));\n" \ +" out_h1_w0_c0 = max(out_h1_w0_c0, (FLT4)(0.0f));\n" \ +" out_h1_w1_c0 = max(out_h1_w1_c0, (FLT4)(0.0f));\n" \ +" out_h0_w0_c1 = max(out_h0_w0_c1, (FLT4)(0.0f));\n" \ +" out_h0_w1_c1 = max(out_h0_w1_c1, (FLT4)(0.0f));\n" \ +" out_h1_w0_c1 = max(out_h1_w0_c1, (FLT4)(0.0f));\n" \ +" out_h1_w1_c1 = max(out_h1_w1_c1, (FLT4)(0.0f));\n" \ +" } else if (act_type == ActivationType_RELU6) {\n" \ +" out_h0_w0_c0 = clamp(out_h0_w0_c0, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h0_w1_c0 = clamp(out_h0_w1_c0, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h1_w0_c0 = clamp(out_h1_w0_c0, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h1_w1_c0 = clamp(out_h1_w1_c0, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h0_w0_c1 = clamp(out_h0_w0_c1, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h0_w1_c1 = clamp(out_h0_w1_c1, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h1_w0_c1 = clamp(out_h1_w0_c1, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h1_w1_c1 = clamp(out_h1_w1_c1, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" } else if (act_type == ActivationType_TANH) {\n" \ +" FLT4 exp0, exp1;\n" \ +" DO_TANH(out_h0_w0_c0);\n" \ +" DO_TANH(out_h0_w1_c0);\n" \ +" DO_TANH(out_h1_w0_c0);\n" \ +" DO_TANH(out_h1_w1_c0);\n" \ +" DO_TANH(out_h0_w0_c1);\n" \ +" DO_TANH(out_h0_w1_c1);\n" \ +" DO_TANH(out_h1_w0_c1);\n" \ +" DO_TANH(out_h1_w1_c1);\n" \ +" } else if (act_type == ActivationType_LEAKY_RELU) {\n" \ +" DO_LEAKY_RELU(out_h0_w0_c0, alpha);\n" \ +" DO_LEAKY_RELU(out_h0_w1_c0, alpha);\n" \ +" DO_LEAKY_RELU(out_h1_w0_c0, alpha);\n" \ +" DO_LEAKY_RELU(out_h1_w1_c0, alpha);\n" \ +" DO_LEAKY_RELU(out_h0_w0_c1, alpha);\n" \ +" DO_LEAKY_RELU(out_h0_w1_c1, alpha);\n" \ +" DO_LEAKY_RELU(out_h1_w0_c1, alpha);\n" \ +" DO_LEAKY_RELU(out_h1_w1_c1, alpha);\n" \ +" } else if (act_type == ActivationType_SIGMOID) {\n" \ +" out_h0_w0_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h0_w0_c0));\n" \ +" out_h0_w1_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h0_w1_c0));\n" \ +" out_h1_w0_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w0_c0));\n" \ +" out_h1_w1_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w1_c0));\n" \ +" out_h0_w0_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h0_w0_c1));\n" \ +" out_h0_w1_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h0_w1_c1));\n" \ +" out_h1_w0_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w0_c1));\n" \ +" out_h1_w1_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w1_c1));\n" \ +" }\n" \ +"\n" \ +" if (OW * CO_SLICES <= MAX_IMAGE2D_WIDTH) {\n" \ +" WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0);\n" \ +" WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh0), out_h0_w1_c0);\n" \ +" if (oh1 < OH) {\n" \ +" WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0);\n" \ +" WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh1), out_h1_w1_c0);\n" \ +" } // end if (oh1 < OH)\n" \ +" if (co_slice1 < CO_SLICES) {\n" \ +" WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh0), out_h0_w0_c1);\n" \ +" WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh0), out_h0_w1_c1);\n" \ +" if (oh1 < OH) {\n" \ +" WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh1), out_h1_w0_c1);\n" \ +" WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh1), out_h1_w1_c1);\n" \ +" } // end if (oh1 < OH)\n" \ +" } // end if (co_slice1 < CO_SLICES)\n" \ +" } else {\n" \ +" WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0);\n" \ +" WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow1), out_h0_w1_c0);\n" \ +" if (oh1 < OH) {\n" \ +" WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0);\n" \ +" WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow1), out_h1_w1_c0);\n" \ +" } // end (oh1 < OH)\n" \ +" if (co_slice1 < CO_SLICES) {\n" \ +" WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow0), out_h0_w0_c1);\n" \ +" WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow1), out_h0_w1_c1);\n" \ +" if (oh1 < OH) {\n" \ +" WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow0), out_h1_w0_c1);\n" \ +" WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow1), out_h1_w1_c1);\n" \ +" } // end if (oh1 < OH)\n" \ +" } // end if (co_slice1 < CO_SLICES)\n" \ +" }\n" \ +"}\n" \ +"\n" \ +"__kernel void Conv2D_H1W1C1_1x1(__read_only image2d_t input, __write_only image2d_t output, __global FLT4 *weight,\n" \ +" __global FLT4 *bias, int4 input_shape, int4 output_shape, int4 kernel_stride, int4 pad,\n" \ +" int2 dilation, int act_type, float alpha) {\n" \ +" const int BlockH = 1;\n" \ +" const int BlockW = 1;\n" \ +" const int BlockC = 1;\n" \ +" DEFINE_ARGS;\n" \ +"\n" \ +" int oh0 = oh + 0;\n" \ +" int n_oh0 = n * OH + oh0;\n" \ +" int ow0 = ow + 0;\n" \ +" int co_slice0 = co_slice + 0;\n" \ +"\n" \ +" FLT4 out_h0_w0_c0 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +"\n" \ +" __global FLT4 *weight_ptr = weight + co_slice / BlockC * KH * KW * CI_SLICES * BlockC * CI_TILE;\n" \ +"\n" \ +" int ih0 = oh0 * strideH - padTop;\n" \ +" int y_idx0 = (ih0 >= 0 && ih0 < IH) ? n * IH + ih0 : -1;\n" \ +"\n" \ +" int iw0 = ow0 * strideW - padLeft;\n" \ +" int x_idx0 = iw0 * CI_SLICES;\n" \ +"\n" \ +" for (int ci_slice = 0; ci_slice < CI_SLICES; ci_slice++) {\n" \ +" FLT4 in_h0_w0 = READ_IMAGE(input, smp_zero, (int2)(x_idx0, y_idx0));\n" \ +" x_idx0++;\n" \ +"\n" \ +" out_h0_w0_c0 += weight_ptr[0] * in_h0_w0.x;\n" \ +" out_h0_w0_c0 += weight_ptr[1] * in_h0_w0.y;\n" \ +" out_h0_w0_c0 += weight_ptr[2] * in_h0_w0.z;\n" \ +" out_h0_w0_c0 += weight_ptr[3] * in_h0_w0.w;\n" \ +"\n" \ +" weight_ptr += 4;\n" \ +" }\n" \ +"\n" \ +" if (bias != 0) {\n" \ +" out_h0_w0_c0 += bias[co_slice0];\n" \ +" }\n" \ +"\n" \ +" if (act_type == ActivationType_RELU) {\n" \ +" out_h0_w0_c0 = max(out_h0_w0_c0, (FLT4)(0.0f));\n" \ +" } else if (act_type == ActivationType_RELU6) {\n" \ +" out_h0_w0_c0 = clamp(out_h0_w0_c0, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" } else if (act_type == ActivationType_TANH) {\n" \ +" FLT4 exp0, exp1;\n" \ +" DO_TANH(out_h0_w0_c0);\n" \ +" } else if (act_type == ActivationType_LEAKY_RELU) {\n" \ +" DO_LEAKY_RELU(out_h0_w0_c0, alpha);\n" \ +" } else if (act_type == ActivationType_SIGMOID) {\n" \ +" out_h0_w0_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h0_w0_c0));\n" \ +" }\n" \ +"\n" \ +" if (OW * CO_SLICES <= MAX_IMAGE2D_WIDTH) {\n" \ +" WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0);\n" \ +" } else {\n" \ +" WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0);\n" \ +" }\n" \ +"}\n" \ +"\n" \ +"__kernel void Conv2D_H2W1C1_1x1(__read_only image2d_t input, __write_only image2d_t output, __global FLT4 *weight,\n" \ +" __global FLT4 *bias, int4 input_shape, int4 output_shape, int4 kernel_stride, int4 pad,\n" \ +" int2 dilation, int act_type, float alpha) {\n" \ +" const int BlockH = 2;\n" \ +" const int BlockW = 1;\n" \ +" const int BlockC = 1;\n" \ +" DEFINE_ARGS;\n" \ +"\n" \ +" int oh0 = oh + 0;\n" \ +" int oh1 = oh + 1;\n" \ +" int n_oh0 = n * OH + oh0;\n" \ +" int n_oh1 = n * OH + oh1;\n" \ +" int ow0 = ow + 0;\n" \ +" int co_slice0 = co_slice + 0;\n" \ +"\n" \ +" FLT4 out_h0_w0_c0 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h1_w0_c0 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +"\n" \ +" __global FLT4 *weight_ptr = weight + co_slice / BlockC * KH * KW * CI_SLICES * BlockC * CI_TILE;\n" \ +"\n" \ +" int ih0 = oh0 * strideH - padTop;\n" \ +" // no need to check oh1, finally write out will check (oh1 < OH)\n" \ +" int ih1 = oh1 * strideH - padTop;\n" \ +" // check ih0 and ih1\n" \ +" int y_idx0 = (ih0 >= 0 && ih0 < IH) ? n * IH + ih0 : -1;\n" \ +" int y_idx1 = (ih1 >= 0 && ih1 < IH) ? n * IH + ih1 : -1;\n" \ +"\n" \ +" int iw0 = ow0 * strideW - padLeft;\n" \ +" int x_idx0 = iw0 * CI_SLICES;\n" \ +"\n" \ +" for (int ci_slice = 0; ci_slice < CI_SLICES; ci_slice++) {\n" \ +" FLT4 in_h0_w0 = READ_IMAGE(input, smp_zero, (int2)(x_idx0, y_idx0));\n" \ +" FLT4 in_h1_w0 = READ_IMAGE(input, smp_zero, (int2)(x_idx0, y_idx1));\n" \ +" x_idx0++;\n" \ +"\n" \ +" out_h0_w0_c0 += weight_ptr[0] * in_h0_w0.x;\n" \ +" out_h1_w0_c0 += weight_ptr[0] * in_h1_w0.x;\n" \ +" out_h0_w0_c0 += weight_ptr[1] * in_h0_w0.y;\n" \ +" out_h1_w0_c0 += weight_ptr[1] * in_h1_w0.y;\n" \ +" out_h0_w0_c0 += weight_ptr[2] * in_h0_w0.z;\n" \ +" out_h1_w0_c0 += weight_ptr[2] * in_h1_w0.z;\n" \ +" out_h0_w0_c0 += weight_ptr[3] * in_h0_w0.w;\n" \ +" out_h1_w0_c0 += weight_ptr[3] * in_h1_w0.w;\n" \ +"\n" \ +" weight_ptr += 4;\n" \ +" }\n" \ +"\n" \ +" if (bias != 0) {\n" \ +" out_h0_w0_c0 += bias[co_slice0];\n" \ +" out_h1_w0_c0 += bias[co_slice0];\n" \ +" }\n" \ +"\n" \ +" if (act_type == ActivationType_RELU) {\n" \ +" out_h0_w0_c0 = max(out_h0_w0_c0, (FLT4)(0.0f));\n" \ +" out_h1_w0_c0 = max(out_h1_w0_c0, (FLT4)(0.0f));\n" \ +" } else if (act_type == ActivationType_RELU6) {\n" \ +" out_h0_w0_c0 = clamp(out_h0_w0_c0, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h1_w0_c0 = clamp(out_h1_w0_c0, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" } else if (act_type == ActivationType_TANH) {\n" \ +" FLT4 exp0, exp1;\n" \ +" DO_TANH(out_h0_w0_c0);\n" \ +" DO_TANH(out_h1_w0_c0);\n" \ +" } else if (act_type == ActivationType_LEAKY_RELU) {\n" \ +" DO_LEAKY_RELU(out_h0_w0_c0, alpha);\n" \ +" DO_LEAKY_RELU(out_h1_w0_c0, alpha);\n" \ +" } else if (act_type == ActivationType_SIGMOID) {\n" \ +" out_h0_w0_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h0_w0_c0));\n" \ +" out_h1_w0_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w0_c0));\n" \ +" }\n" \ +"\n" \ +"#ifndef EXCEDD_MAX_IMAGE2D_WIDTH\n" \ +" WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0);\n" \ +" if (oh1 < OH) {\n" \ +" WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0);\n" \ +" } // end if (oh1 < OH)\n" \ +"#else\n" \ +" WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0);\n" \ +" if (oh1 < OH) {\n" \ +" WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0);\n" \ +" } // end (oh1 < OH)\n" \ +"#endif\n" \ +"}\n" \ +"\n" \ +"__kernel void Conv2D_H2W1C2_1x1(__read_only image2d_t input, __write_only image2d_t output, __global FLT4 *weight,\n" \ +" __global FLT4 *bias, int4 input_shape, int4 output_shape, int4 kernel_stride, int4 pad,\n" \ +" int2 dilation, int act_type, float alpha) {\n" \ +" const int BlockH = 2;\n" \ +" const int BlockW = 1;\n" \ +" const int BlockC = 2;\n" \ +" DEFINE_ARGS;\n" \ +"\n" \ +" int oh0 = oh + 0;\n" \ +" int oh1 = oh + 1;\n" \ +" int n_oh0 = n * OH + oh0;\n" \ +" int n_oh1 = n * OH + oh1;\n" \ +" int ow0 = ow + 0;\n" \ +" int co_slice0 = co_slice + 0;\n" \ +" int co_slice1 = co_slice + 1;\n" \ +"\n" \ +" FLT4 out_h0_w0_c0 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h1_w0_c0 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h0_w0_c1 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h1_w0_c1 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +"\n" \ +" __global FLT4 *weight_ptr = weight + co_slice / BlockC * KH * KW * CI_SLICES * BlockC * CI_TILE;\n" \ +"\n" \ +" int ih0 = oh0 * strideH - padTop;\n" \ +" // no need to check oh1, finally write out will check (oh1 < OH)\n" \ +" int ih1 = oh1 * strideH - padTop;\n" \ +" // check ih0 and ih1\n" \ +" int y_idx0 = (ih0 >= 0 && ih0 < IH) ? n * IH + ih0 : -1;\n" \ +" int y_idx1 = (ih1 >= 0 && ih1 < IH) ? n * IH + ih1 : -1;\n" \ +"\n" \ +" int iw0 = ow0 * strideW - padLeft;\n" \ +" int x_idx0 = iw0 * CI_SLICES;\n" \ +"\n" \ +" for (int ci_slice = 0; ci_slice < CI_SLICES; ci_slice++) {\n" \ +" FLT4 in_h0_w0 = READ_IMAGE(input, smp_zero, (int2)(x_idx0, y_idx0));\n" \ +" FLT4 in_h1_w0 = READ_IMAGE(input, smp_zero, (int2)(x_idx0, y_idx1));\n" \ +" x_idx0++;\n" \ +"\n" \ +" out_h0_w0_c0 += weight_ptr[0] * in_h0_w0.x;\n" \ +" out_h1_w0_c0 += weight_ptr[0] * in_h1_w0.x;\n" \ +" out_h0_w0_c0 += weight_ptr[1] * in_h0_w0.y;\n" \ +" out_h1_w0_c0 += weight_ptr[1] * in_h1_w0.y;\n" \ +" out_h0_w0_c0 += weight_ptr[2] * in_h0_w0.z;\n" \ +" out_h1_w0_c0 += weight_ptr[2] * in_h1_w0.z;\n" \ +" out_h0_w0_c0 += weight_ptr[3] * in_h0_w0.w;\n" \ +" out_h1_w0_c0 += weight_ptr[3] * in_h1_w0.w;\n" \ +"\n" \ +" out_h0_w0_c1 += weight_ptr[4] * in_h0_w0.x;\n" \ +" out_h1_w0_c1 += weight_ptr[4] * in_h1_w0.x;\n" \ +" out_h0_w0_c1 += weight_ptr[5] * in_h0_w0.y;\n" \ +" out_h1_w0_c1 += weight_ptr[5] * in_h1_w0.y;\n" \ +" out_h0_w0_c1 += weight_ptr[6] * in_h0_w0.z;\n" \ +" out_h1_w0_c1 += weight_ptr[6] * in_h1_w0.z;\n" \ +" out_h0_w0_c1 += weight_ptr[7] * in_h0_w0.w;\n" \ +" out_h1_w0_c1 += weight_ptr[7] * in_h1_w0.w;\n" \ +"\n" \ +" weight_ptr += 8;\n" \ +" }\n" \ +"\n" \ +" if (bias != 0) {\n" \ +" out_h0_w0_c0 += bias[co_slice0];\n" \ +" out_h1_w0_c0 += bias[co_slice0];\n" \ +" out_h0_w0_c1 += bias[co_slice1];\n" \ +" out_h1_w0_c1 += bias[co_slice1];\n" \ +" }\n" \ +"\n" \ +" if (act_type == ActivationType_RELU) {\n" \ +" out_h0_w0_c0 = max(out_h0_w0_c0, (FLT4)(0.0f));\n" \ +" out_h1_w0_c0 = max(out_h1_w0_c0, (FLT4)(0.0f));\n" \ +" out_h0_w0_c1 = max(out_h0_w0_c1, (FLT4)(0.0f));\n" \ +" out_h1_w0_c1 = max(out_h1_w0_c1, (FLT4)(0.0f));\n" \ +" } else if (act_type == ActivationType_RELU6) {\n" \ +" out_h0_w0_c0 = clamp(out_h0_w0_c0, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h1_w0_c0 = clamp(out_h1_w0_c0, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h0_w0_c1 = clamp(out_h0_w0_c1, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h1_w0_c1 = clamp(out_h1_w0_c1, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" } else if (act_type == ActivationType_TANH) {\n" \ +" FLT4 exp0, exp1;\n" \ +" DO_TANH(out_h0_w0_c0);\n" \ +" DO_TANH(out_h1_w0_c0);\n" \ +" DO_TANH(out_h0_w0_c1);\n" \ +" DO_TANH(out_h1_w0_c1);\n" \ +" } else if (act_type == ActivationType_LEAKY_RELU) {\n" \ +" DO_LEAKY_RELU(out_h0_w0_c0, alpha);\n" \ +" DO_LEAKY_RELU(out_h1_w0_c0, alpha);\n" \ +" DO_LEAKY_RELU(out_h0_w0_c1, alpha);\n" \ +" DO_LEAKY_RELU(out_h1_w0_c1, alpha);\n" \ +" } else if (act_type == ActivationType_SIGMOID) {\n" \ +" out_h0_w0_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h0_w0_c0));\n" \ +" out_h1_w0_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w0_c0));\n" \ +" out_h0_w0_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h0_w0_c1));\n" \ +" out_h1_w0_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w0_c1));\n" \ +" }\n" \ +"\n" \ +"#ifndef EXCEDD_MAX_IMAGE2D_WIDTH\n" \ +" WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0);\n" \ +" if (oh1 < OH) {\n" \ +" WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0);\n" \ +" } // end if (oh1 < OH)\n" \ +" if (co_slice1 < CO_SLICES) {\n" \ +" WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh0), out_h0_w0_c1);\n" \ +" if (oh1 < OH) {\n" \ +" WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh1), out_h1_w0_c1);\n" \ +" } // end if (oh1 < OH)\n" \ +" } // end if (co_slice1 < CO_SLICES)\n" \ +"#else\n" \ +" WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0);\n" \ +" if (oh1 < OH) {\n" \ +" WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0);\n" \ +" } // end (oh1 < OH)\n" \ +" if (co_slice1 < CO_SLICES) {\n" \ +" WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow0), out_h0_w0_c1);\n" \ +" if (oh1 < OH) {\n" \ +" WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow0), out_h1_w0_c1);\n" \ +" } // end if (oh1 < OH)\n" \ +" } // end if (co_slice1 < CO_SLICES)\n" \ +"#endif\n" \ +"}\n" \ +"\n" \ +"__kernel void Conv2D_H2W2C1_1x1(__read_only image2d_t input, __write_only image2d_t output, __global FLT4 *weight,\n" \ +" __global FLT4 *bias, int4 input_shape, int4 output_shape, int4 kernel_stride, int4 pad,\n" \ +" int2 dilation, int act_type, float alpha) {\n" \ +" const int BlockH = 2;\n" \ +" const int BlockW = 2;\n" \ +" const int BlockC = 1;\n" \ +" DEFINE_ARGS;\n" \ +"\n" \ +" int oh0 = oh + 0;\n" \ +" int oh1 = oh + 1;\n" \ +" int n_oh0 = n * OH + oh0;\n" \ +" int n_oh1 = n * OH + oh1;\n" \ +" int ow0 = ow + 0;\n" \ +" int ow1 = ow + 1;\n" \ +" int co_slice0 = co_slice + 0;\n" \ +"\n" \ +" FLT4 out_h0_w0_c0 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h0_w1_c0 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h1_w0_c0 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h1_w1_c0 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +"\n" \ +" __global FLT4 *weight_ptr = weight + co_slice / BlockC * KH * KW * CI_SLICES * BlockC * CI_TILE;\n" \ +"\n" \ +" int ih0 = oh0 * strideH - padTop;\n" \ +" // no need to check oh1, finally write out will check (oh1 < OH)\n" \ +" int ih1 = oh1 * strideH - padTop;\n" \ +" // check ih0 and ih1\n" \ +" int y_idx0 = (ih0 >= 0 && ih0 < IH) ? n * IH + ih0 : -1;\n" \ +" int y_idx1 = (ih1 >= 0 && ih1 < IH) ? n * IH + ih1 : -1;\n" \ +"\n" \ +" int iw0 = ow0 * strideW - padLeft;\n" \ +" int iw1 = (ow1 < OW) ? ow1 * strideW - padLeft : -2;\n" \ +" int x_idx0 = iw0 * CI_SLICES;\n" \ +" int x_idx1 = iw1 * CI_SLICES;\n" \ +"\n" \ +" for (int ci_slice = 0; ci_slice < CI_SLICES; ci_slice++) {\n" \ +" FLT4 in_h0_w0 = READ_IMAGE(input, smp_zero, (int2)(x_idx0, y_idx0));\n" \ +" FLT4 in_h0_w1 = READ_IMAGE(input, smp_zero, (int2)(x_idx1, y_idx0));\n" \ +" FLT4 in_h1_w0 = READ_IMAGE(input, smp_zero, (int2)(x_idx0, y_idx1));\n" \ +" FLT4 in_h1_w1 = READ_IMAGE(input, smp_zero, (int2)(x_idx1, y_idx1));\n" \ +" x_idx0++;\n" \ +" x_idx1++;\n" \ +"\n" \ +" out_h0_w0_c0 += weight_ptr[0] * in_h0_w0.x;\n" \ +" out_h0_w1_c0 += weight_ptr[0] * in_h0_w1.x;\n" \ +" out_h1_w0_c0 += weight_ptr[0] * in_h1_w0.x;\n" \ +" out_h1_w1_c0 += weight_ptr[0] * in_h1_w1.x;\n" \ +" out_h0_w0_c0 += weight_ptr[1] * in_h0_w0.y;\n" \ +" out_h0_w1_c0 += weight_ptr[1] * in_h0_w1.y;\n" \ +" out_h1_w0_c0 += weight_ptr[1] * in_h1_w0.y;\n" \ +" out_h1_w1_c0 += weight_ptr[1] * in_h1_w1.y;\n" \ +" out_h0_w0_c0 += weight_ptr[2] * in_h0_w0.z;\n" \ +" out_h0_w1_c0 += weight_ptr[2] * in_h0_w1.z;\n" \ +" out_h1_w0_c0 += weight_ptr[2] * in_h1_w0.z;\n" \ +" out_h1_w1_c0 += weight_ptr[2] * in_h1_w1.z;\n" \ +" out_h0_w0_c0 += weight_ptr[3] * in_h0_w0.w;\n" \ +" out_h0_w1_c0 += weight_ptr[3] * in_h0_w1.w;\n" \ +" out_h1_w0_c0 += weight_ptr[3] * in_h1_w0.w;\n" \ +" out_h1_w1_c0 += weight_ptr[3] * in_h1_w1.w;\n" \ +"\n" \ +" weight_ptr += 4;\n" \ +" }\n" \ +"\n" \ +" if (bias != 0) {\n" \ +" out_h0_w0_c0 += bias[co_slice0];\n" \ +" out_h0_w1_c0 += bias[co_slice0];\n" \ +" out_h1_w0_c0 += bias[co_slice0];\n" \ +" out_h1_w1_c0 += bias[co_slice0];\n" \ +" }\n" \ +"\n" \ +" if (act_type == ActivationType_RELU) {\n" \ +" out_h0_w0_c0 = max(out_h0_w0_c0, (FLT4)(0.0f));\n" \ +" out_h0_w1_c0 = max(out_h0_w1_c0, (FLT4)(0.0f));\n" \ +" out_h1_w0_c0 = max(out_h1_w0_c0, (FLT4)(0.0f));\n" \ +" out_h1_w1_c0 = max(out_h1_w1_c0, (FLT4)(0.0f));\n" \ +" } else if (act_type == ActivationType_RELU6) {\n" \ +" out_h0_w0_c0 = clamp(out_h0_w0_c0, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h0_w1_c0 = clamp(out_h0_w1_c0, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h1_w0_c0 = clamp(out_h1_w0_c0, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h1_w1_c0 = clamp(out_h1_w1_c0, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" } else if (act_type == ActivationType_TANH) {\n" \ +" FLT4 exp0, exp1;\n" \ +" DO_TANH(out_h0_w0_c0);\n" \ +" DO_TANH(out_h0_w1_c0);\n" \ +" DO_TANH(out_h1_w0_c0);\n" \ +" DO_TANH(out_h1_w1_c0);\n" \ +" } else if (act_type == ActivationType_LEAKY_RELU) {\n" \ +" DO_LEAKY_RELU(out_h0_w0_c0, alpha);\n" \ +" DO_LEAKY_RELU(out_h0_w1_c0, alpha);\n" \ +" DO_LEAKY_RELU(out_h1_w0_c0, alpha);\n" \ +" DO_LEAKY_RELU(out_h1_w1_c0, alpha);\n" \ +" } else if (act_type == ActivationType_SIGMOID) {\n" \ +" out_h0_w0_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h0_w0_c0));\n" \ +" out_h0_w1_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h0_w1_c0));\n" \ +" out_h1_w0_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w0_c0));\n" \ +" out_h1_w1_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w1_c0));\n" \ +" }\n" \ +"\n" \ +"#ifndef EXCEDD_MAX_IMAGE2D_WIDTH\n" \ +" WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0);\n" \ +" WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh0), out_h0_w1_c0);\n" \ +" if (oh1 < OH) {\n" \ +" WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0);\n" \ +" WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh1), out_h1_w1_c0);\n" \ +" } // end if (oh1 < OH)\n" \ +"#else\n" \ +" WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0);\n" \ +" WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow1), out_h0_w1_c0);\n" \ +" if (oh1 < OH) {\n" \ +" WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0);\n" \ +" WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow1), out_h1_w1_c0);\n" \ +" } // end (oh1 < OH)\n" \ +"#endif\n" \ +"}\n" \ +"\n" \ +"__kernel void Conv2D_H2W2C2_1x1(__read_only image2d_t input, __write_only image2d_t output, __global FLT4 *weight,\n" \ +" __global FLT4 *bias, int4 input_shape, int4 output_shape, int4 kernel_stride, int4 pad,\n" \ +" int2 dilation, int act_type, float alpha) {\n" \ +" const int BlockH = 2;\n" \ +" const int BlockW = 2;\n" \ +" const int BlockC = 2;\n" \ +" DEFINE_ARGS;\n" \ +"\n" \ +" int oh0 = oh + 0;\n" \ +" int oh1 = oh + 1;\n" \ +" int n_oh0 = n * OH + oh0;\n" \ +" int n_oh1 = n * OH + oh1;\n" \ +" int ow0 = ow + 0;\n" \ +" int ow1 = ow + 1;\n" \ +" int co_slice0 = co_slice + 0;\n" \ +" int co_slice1 = co_slice + 1;\n" \ +"\n" \ +" FLT4 out_h0_w0_c0 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h0_w1_c0 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h1_w0_c0 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h1_w1_c0 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h0_w0_c1 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h0_w1_c1 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h1_w0_c1 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h1_w1_c1 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +"\n" \ +" __global FLT4 *weight_ptr = weight + co_slice / BlockC * KH * KW * CI_SLICES * BlockC * CI_TILE;\n" \ +"\n" \ +" int ih0 = oh0 * strideH - padTop;\n" \ +" // no need to check oh1, finally write out will check (oh1 < OH)\n" \ +" int ih1 = oh1 * strideH - padTop;\n" \ +" // check ih0 and ih1\n" \ +" int y_idx0 = (ih0 >= 0 && ih0 < IH) ? n * IH + ih0 : -1;\n" \ +" int y_idx1 = (ih1 >= 0 && ih1 < IH) ? n * IH + ih1 : -1;\n" \ +"\n" \ +" int iw0 = ow0 * strideW - padLeft;\n" \ +" int iw1 = (ow1 < OW) ? ow1 * strideW - padLeft : -2;\n" \ +" int x_idx0 = iw0 * CI_SLICES;\n" \ +" int x_idx1 = iw1 * CI_SLICES;\n" \ +"\n" \ +" for (int ci_slice = 0; ci_slice < CI_SLICES; ci_slice++) {\n" \ +" FLT4 in_h0_w0 = READ_IMAGE(input, smp_zero, (int2)(x_idx0, y_idx0));\n" \ +" FLT4 in_h0_w1 = READ_IMAGE(input, smp_zero, (int2)(x_idx1, y_idx0));\n" \ +" FLT4 in_h1_w0 = READ_IMAGE(input, smp_zero, (int2)(x_idx0, y_idx1));\n" \ +" FLT4 in_h1_w1 = READ_IMAGE(input, smp_zero, (int2)(x_idx1, y_idx1));\n" \ +" x_idx0++;\n" \ +" x_idx1++;\n" \ +"\n" \ +" out_h0_w0_c0 += weight_ptr[0] * in_h0_w0.x;\n" \ +" out_h0_w1_c0 += weight_ptr[0] * in_h0_w1.x;\n" \ +" out_h1_w0_c0 += weight_ptr[0] * in_h1_w0.x;\n" \ +" out_h1_w1_c0 += weight_ptr[0] * in_h1_w1.x;\n" \ +" out_h0_w0_c0 += weight_ptr[1] * in_h0_w0.y;\n" \ +" out_h0_w1_c0 += weight_ptr[1] * in_h0_w1.y;\n" \ +" out_h1_w0_c0 += weight_ptr[1] * in_h1_w0.y;\n" \ +" out_h1_w1_c0 += weight_ptr[1] * in_h1_w1.y;\n" \ +" out_h0_w0_c0 += weight_ptr[2] * in_h0_w0.z;\n" \ +" out_h0_w1_c0 += weight_ptr[2] * in_h0_w1.z;\n" \ +" out_h1_w0_c0 += weight_ptr[2] * in_h1_w0.z;\n" \ +" out_h1_w1_c0 += weight_ptr[2] * in_h1_w1.z;\n" \ +" out_h0_w0_c0 += weight_ptr[3] * in_h0_w0.w;\n" \ +" out_h0_w1_c0 += weight_ptr[3] * in_h0_w1.w;\n" \ +" out_h1_w0_c0 += weight_ptr[3] * in_h1_w0.w;\n" \ +" out_h1_w1_c0 += weight_ptr[3] * in_h1_w1.w;\n" \ +"\n" \ +" out_h0_w0_c1 += weight_ptr[4] * in_h0_w0.x;\n" \ +" out_h0_w1_c1 += weight_ptr[4] * in_h0_w1.x;\n" \ +" out_h1_w0_c1 += weight_ptr[4] * in_h1_w0.x;\n" \ +" out_h1_w1_c1 += weight_ptr[4] * in_h1_w1.x;\n" \ +" out_h0_w0_c1 += weight_ptr[5] * in_h0_w0.y;\n" \ +" out_h0_w1_c1 += weight_ptr[5] * in_h0_w1.y;\n" \ +" out_h1_w0_c1 += weight_ptr[5] * in_h1_w0.y;\n" \ +" out_h1_w1_c1 += weight_ptr[5] * in_h1_w1.y;\n" \ +" out_h0_w0_c1 += weight_ptr[6] * in_h0_w0.z;\n" \ +" out_h0_w1_c1 += weight_ptr[6] * in_h0_w1.z;\n" \ +" out_h1_w0_c1 += weight_ptr[6] * in_h1_w0.z;\n" \ +" out_h1_w1_c1 += weight_ptr[6] * in_h1_w1.z;\n" \ +" out_h0_w0_c1 += weight_ptr[7] * in_h0_w0.w;\n" \ +" out_h0_w1_c1 += weight_ptr[7] * in_h0_w1.w;\n" \ +" out_h1_w0_c1 += weight_ptr[7] * in_h1_w0.w;\n" \ +" out_h1_w1_c1 += weight_ptr[7] * in_h1_w1.w;\n" \ +"\n" \ +" weight_ptr += 8;\n" \ +" }\n" \ +"\n" \ +" if (bias != 0) {\n" \ +" out_h0_w0_c0 += bias[co_slice0];\n" \ +" out_h0_w1_c0 += bias[co_slice0];\n" \ +" out_h1_w0_c0 += bias[co_slice0];\n" \ +" out_h1_w1_c0 += bias[co_slice0];\n" \ +" out_h0_w0_c1 += bias[co_slice1];\n" \ +" out_h0_w1_c1 += bias[co_slice1];\n" \ +" out_h1_w0_c1 += bias[co_slice1];\n" \ +" out_h1_w1_c1 += bias[co_slice1];\n" \ +" }\n" \ +"\n" \ +" if (act_type == ActivationType_RELU) {\n" \ +" out_h0_w0_c0 = max(out_h0_w0_c0, (FLT4)(0.0f));\n" \ +" out_h0_w1_c0 = max(out_h0_w1_c0, (FLT4)(0.0f));\n" \ +" out_h1_w0_c0 = max(out_h1_w0_c0, (FLT4)(0.0f));\n" \ +" out_h1_w1_c0 = max(out_h1_w1_c0, (FLT4)(0.0f));\n" \ +" out_h0_w0_c1 = max(out_h0_w0_c1, (FLT4)(0.0f));\n" \ +" out_h0_w1_c1 = max(out_h0_w1_c1, (FLT4)(0.0f));\n" \ +" out_h1_w0_c1 = max(out_h1_w0_c1, (FLT4)(0.0f));\n" \ +" out_h1_w1_c1 = max(out_h1_w1_c1, (FLT4)(0.0f));\n" \ +" } else if (act_type == ActivationType_RELU6) {\n" \ +" out_h0_w0_c0 = clamp(out_h0_w0_c0, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h0_w1_c0 = clamp(out_h0_w1_c0, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h1_w0_c0 = clamp(out_h1_w0_c0, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h1_w1_c0 = clamp(out_h1_w1_c0, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h0_w0_c1 = clamp(out_h0_w0_c1, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h0_w1_c1 = clamp(out_h0_w1_c1, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h1_w0_c1 = clamp(out_h1_w0_c1, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h1_w1_c1 = clamp(out_h1_w1_c1, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" } else if (act_type == ActivationType_TANH) {\n" \ +" FLT4 exp0, exp1;\n" \ +" DO_TANH(out_h0_w0_c0);\n" \ +" DO_TANH(out_h0_w1_c0);\n" \ +" DO_TANH(out_h1_w0_c0);\n" \ +" DO_TANH(out_h1_w1_c0);\n" \ +" DO_TANH(out_h0_w0_c1);\n" \ +" DO_TANH(out_h0_w1_c1);\n" \ +" DO_TANH(out_h1_w0_c1);\n" \ +" DO_TANH(out_h1_w1_c1);\n" \ +" } else if (act_type == ActivationType_LEAKY_RELU) {\n" \ +" DO_LEAKY_RELU(out_h0_w0_c0, alpha);\n" \ +" DO_LEAKY_RELU(out_h0_w1_c0, alpha);\n" \ +" DO_LEAKY_RELU(out_h1_w0_c0, alpha);\n" \ +" DO_LEAKY_RELU(out_h1_w1_c0, alpha);\n" \ +" DO_LEAKY_RELU(out_h0_w0_c1, alpha);\n" \ +" DO_LEAKY_RELU(out_h0_w1_c1, alpha);\n" \ +" DO_LEAKY_RELU(out_h1_w0_c1, alpha);\n" \ +" DO_LEAKY_RELU(out_h1_w1_c1, alpha);\n" \ +" } else if (act_type == ActivationType_SIGMOID) {\n" \ +" out_h0_w0_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h0_w0_c0));\n" \ +" out_h0_w1_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h0_w1_c0));\n" \ +" out_h1_w0_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w0_c0));\n" \ +" out_h1_w1_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w1_c0));\n" \ +" out_h0_w0_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h0_w0_c1));\n" \ +" out_h0_w1_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h0_w1_c1));\n" \ +" out_h1_w0_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w0_c1));\n" \ +" out_h1_w1_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w1_c1));\n" \ +" }\n" \ +"\n" \ +"#ifndef EXCEDD_MAX_IMAGE2D_WIDTH\n" \ +" WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0);\n" \ +" WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh0), out_h0_w1_c0);\n" \ +" if (oh1 < OH) {\n" \ +" WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0);\n" \ +" WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh1), out_h1_w1_c0);\n" \ +" } // end if (oh1 < OH)\n" \ +" if (co_slice1 < CO_SLICES) {\n" \ +" WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh0), out_h0_w0_c1);\n" \ +" WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh0), out_h0_w1_c1);\n" \ +" if (oh1 < OH) {\n" \ +" WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh1), out_h1_w0_c1);\n" \ +" WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh1), out_h1_w1_c1);\n" \ +" } // end if (oh1 < OH)\n" \ +" } // end if (co_slice1 < CO_SLICES)\n" \ +"#else\n" \ +" WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0);\n" \ +" WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow1), out_h0_w1_c0);\n" \ +" if (oh1 < OH) {\n" \ +" WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0);\n" \ +" WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow1), out_h1_w1_c0);\n" \ +" } // end (oh1 < OH)\n" \ +" if (co_slice1 < CO_SLICES) {\n" \ +" WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow0), out_h0_w0_c1);\n" \ +" WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow1), out_h0_w1_c1);\n" \ +" if (oh1 < OH) {\n" \ +" WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow0), out_h1_w0_c1);\n" \ +" WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow1), out_h1_w1_c1);\n" \ +" } // end if (oh1 < OH)\n" \ +" } // end if (co_slice1 < CO_SLICES)\n" \ +"#endif\n" \ +"}\n" \ +"\n" \ +"__kernel void Conv2D_H2W2C2_Img_1x1(__read_only image2d_t input, __write_only image2d_t output,\n" \ +" __read_only image2d_t weight, __global FLT4 *bias, int4 input_shape,\n" \ +" int4 output_shape, int4 kernel_stride, int4 pad, int2 dilation, int act_type,\n" \ +" float alpha) {\n" \ +" const int BlockH = 2;\n" \ +" const int BlockW = 2;\n" \ +" const int BlockC = 2;\n" \ +" DEFINE_ARGS;\n" \ +"\n" \ +" int oh0 = oh + 0;\n" \ +" int oh1 = oh + 1;\n" \ +" int n_oh0 = n * OH + oh0;\n" \ +" int n_oh1 = n * OH + oh1;\n" \ +" int ow0 = ow + 0;\n" \ +" int ow1 = ow + 1;\n" \ +" int co_slice0 = co_slice + 0;\n" \ +" int co_slice1 = co_slice + 1;\n" \ +"\n" \ +" FLT4 out_h0_w0_c0 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h0_w1_c0 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h1_w0_c0 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h1_w1_c0 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h0_w0_c1 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h0_w1_c1 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h1_w0_c1 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out_h1_w1_c1 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +"\n" \ +" int filter_offset = 0;\n" \ +" int ih0 = oh0 * strideH - padTop;\n" \ +" // no need to check oh1, finally write out will check (oh1 < OH)\n" \ +" int ih1 = oh1 * strideH - padTop;\n" \ +" // check ih0 and ih1\n" \ +" int y_idx0 = (ih0 >= 0 && ih0 < IH) ? n * IH + ih0 : -1;\n" \ +" int y_idx1 = (ih1 >= 0 && ih1 < IH) ? n * IH + ih1 : -1;\n" \ +"\n" \ +" int iw0 = ow0 * strideW - padLeft;\n" \ +" int iw1 = (ow1 < OW) ? ow1 * strideW - padLeft : -2;\n" \ +" int x_idx0 = iw0 * CI_SLICES;\n" \ +" int x_idx1 = iw1 * CI_SLICES;\n" \ +"\n" \ +" for (int ci_slice = 0; ci_slice < CI_SLICES; ci_slice++) {\n" \ +" FLT4 in_h0_w0 = READ_IMAGE(input, smp_zero, (int2)(x_idx0, y_idx0));\n" \ +" FLT4 in_h0_w1 = READ_IMAGE(input, smp_zero, (int2)(x_idx1, y_idx0));\n" \ +" FLT4 in_h1_w0 = READ_IMAGE(input, smp_zero, (int2)(x_idx0, y_idx1));\n" \ +" FLT4 in_h1_w1 = READ_IMAGE(input, smp_zero, (int2)(x_idx1, y_idx1));\n" \ +" x_idx0++;\n" \ +" x_idx1++;\n" \ +"\n" \ +" FLT4 filter_ci0_co0 = READ_IMAGE(weight, smp_zero, (int2)(co_slice0, filter_offset + 0));\n" \ +" FLT4 filter_ci1_co0 = READ_IMAGE(weight, smp_zero, (int2)(co_slice0, filter_offset + 1));\n" \ +" FLT4 filter_ci2_co0 = READ_IMAGE(weight, smp_zero, (int2)(co_slice0, filter_offset + 2));\n" \ +" FLT4 filter_ci3_co0 = READ_IMAGE(weight, smp_zero, (int2)(co_slice0, filter_offset + 3));\n" \ +" FLT4 filter_ci0_co1 = READ_IMAGE(weight, smp_zero, (int2)(co_slice1, filter_offset + 0));\n" \ +" FLT4 filter_ci1_co1 = READ_IMAGE(weight, smp_zero, (int2)(co_slice1, filter_offset + 1));\n" \ +" FLT4 filter_ci2_co1 = READ_IMAGE(weight, smp_zero, (int2)(co_slice1, filter_offset + 2));\n" \ +" FLT4 filter_ci3_co1 = READ_IMAGE(weight, smp_zero, (int2)(co_slice1, filter_offset + 3));\n" \ +" filter_offset += 4;\n" \ +"\n" \ +" out_h0_w0_c0 += filter_ci0_co0 * in_h0_w0.x;\n" \ +" out_h0_w1_c0 += filter_ci0_co0 * in_h0_w1.x;\n" \ +" out_h1_w0_c0 += filter_ci0_co0 * in_h1_w0.x;\n" \ +" out_h1_w1_c0 += filter_ci0_co0 * in_h1_w1.x;\n" \ +" out_h0_w0_c0 += filter_ci1_co0 * in_h0_w0.y;\n" \ +" out_h0_w1_c0 += filter_ci1_co0 * in_h0_w1.y;\n" \ +" out_h1_w0_c0 += filter_ci1_co0 * in_h1_w0.y;\n" \ +" out_h1_w1_c0 += filter_ci1_co0 * in_h1_w1.y;\n" \ +" out_h0_w0_c0 += filter_ci2_co0 * in_h0_w0.z;\n" \ +" out_h0_w1_c0 += filter_ci2_co0 * in_h0_w1.z;\n" \ +" out_h1_w0_c0 += filter_ci2_co0 * in_h1_w0.z;\n" \ +" out_h1_w1_c0 += filter_ci2_co0 * in_h1_w1.z;\n" \ +" out_h0_w0_c0 += filter_ci3_co0 * in_h0_w0.w;\n" \ +" out_h0_w1_c0 += filter_ci3_co0 * in_h0_w1.w;\n" \ +" out_h1_w0_c0 += filter_ci3_co0 * in_h1_w0.w;\n" \ +" out_h1_w1_c0 += filter_ci3_co0 * in_h1_w1.w;\n" \ +"\n" \ +" out_h0_w0_c1 += filter_ci0_co1 * in_h0_w0.x;\n" \ +" out_h0_w1_c1 += filter_ci0_co1 * in_h0_w1.x;\n" \ +" out_h1_w0_c1 += filter_ci0_co1 * in_h1_w0.x;\n" \ +" out_h1_w1_c1 += filter_ci0_co1 * in_h1_w1.x;\n" \ +" out_h0_w0_c1 += filter_ci1_co1 * in_h0_w0.y;\n" \ +" out_h0_w1_c1 += filter_ci1_co1 * in_h0_w1.y;\n" \ +" out_h1_w0_c1 += filter_ci1_co1 * in_h1_w0.y;\n" \ +" out_h1_w1_c1 += filter_ci1_co1 * in_h1_w1.y;\n" \ +" out_h0_w0_c1 += filter_ci2_co1 * in_h0_w0.z;\n" \ +" out_h0_w1_c1 += filter_ci2_co1 * in_h0_w1.z;\n" \ +" out_h1_w0_c1 += filter_ci2_co1 * in_h1_w0.z;\n" \ +" out_h1_w1_c1 += filter_ci2_co1 * in_h1_w1.z;\n" \ +" out_h0_w0_c1 += filter_ci3_co1 * in_h0_w0.w;\n" \ +" out_h0_w1_c1 += filter_ci3_co1 * in_h0_w1.w;\n" \ +" out_h1_w0_c1 += filter_ci3_co1 * in_h1_w0.w;\n" \ +" out_h1_w1_c1 += filter_ci3_co1 * in_h1_w1.w;\n" \ +" }\n" \ +"\n" \ +" if (bias != 0) {\n" \ +" out_h0_w0_c0 += bias[co_slice0];\n" \ +" out_h0_w1_c0 += bias[co_slice0];\n" \ +" out_h1_w0_c0 += bias[co_slice0];\n" \ +" out_h1_w1_c0 += bias[co_slice0];\n" \ +" out_h0_w0_c1 += bias[co_slice1];\n" \ +" out_h0_w1_c1 += bias[co_slice1];\n" \ +" out_h1_w0_c1 += bias[co_slice1];\n" \ +" out_h1_w1_c1 += bias[co_slice1];\n" \ +" }\n" \ +"\n" \ +" if (act_type == ActivationType_RELU) {\n" \ +" out_h0_w0_c0 = max(out_h0_w0_c0, (FLT4)(0.0f));\n" \ +" out_h0_w1_c0 = max(out_h0_w1_c0, (FLT4)(0.0f));\n" \ +" out_h1_w0_c0 = max(out_h1_w0_c0, (FLT4)(0.0f));\n" \ +" out_h1_w1_c0 = max(out_h1_w1_c0, (FLT4)(0.0f));\n" \ +" out_h0_w0_c1 = max(out_h0_w0_c1, (FLT4)(0.0f));\n" \ +" out_h0_w1_c1 = max(out_h0_w1_c1, (FLT4)(0.0f));\n" \ +" out_h1_w0_c1 = max(out_h1_w0_c1, (FLT4)(0.0f));\n" \ +" out_h1_w1_c1 = max(out_h1_w1_c1, (FLT4)(0.0f));\n" \ +" } else if (act_type == ActivationType_RELU6) {\n" \ +" out_h0_w0_c0 = clamp(out_h0_w0_c0, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h0_w1_c0 = clamp(out_h0_w1_c0, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h1_w0_c0 = clamp(out_h1_w0_c0, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h1_w1_c0 = clamp(out_h1_w1_c0, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h0_w0_c1 = clamp(out_h0_w0_c1, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h0_w1_c1 = clamp(out_h0_w1_c1, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h1_w0_c1 = clamp(out_h1_w0_c1, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" out_h1_w1_c1 = clamp(out_h1_w1_c1, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" } else if (act_type == ActivationType_TANH) {\n" \ +" FLT4 exp0, exp1;\n" \ +" DO_TANH(out_h0_w0_c0);\n" \ +" DO_TANH(out_h0_w1_c0);\n" \ +" DO_TANH(out_h1_w0_c0);\n" \ +" DO_TANH(out_h1_w1_c0);\n" \ +" DO_TANH(out_h0_w0_c1);\n" \ +" DO_TANH(out_h0_w1_c1);\n" \ +" DO_TANH(out_h1_w0_c1);\n" \ +" DO_TANH(out_h1_w1_c1);\n" \ +" } else if (act_type == ActivationType_LEAKY_RELU) {\n" \ +" DO_LEAKY_RELU(out_h0_w0_c0, alpha);\n" \ +" DO_LEAKY_RELU(out_h0_w1_c0, alpha);\n" \ +" DO_LEAKY_RELU(out_h1_w0_c0, alpha);\n" \ +" DO_LEAKY_RELU(out_h1_w1_c0, alpha);\n" \ +" DO_LEAKY_RELU(out_h0_w0_c1, alpha);\n" \ +" DO_LEAKY_RELU(out_h0_w1_c1, alpha);\n" \ +" DO_LEAKY_RELU(out_h1_w0_c1, alpha);\n" \ +" DO_LEAKY_RELU(out_h1_w1_c1, alpha);\n" \ +" } else if (act_type == ActivationType_SIGMOID) {\n" \ +" out_h0_w0_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h0_w0_c0));\n" \ +" out_h0_w1_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h0_w1_c0));\n" \ +" out_h1_w0_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w0_c0));\n" \ +" out_h1_w1_c0 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w1_c0));\n" \ +" out_h0_w0_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h0_w0_c1));\n" \ +" out_h0_w1_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h0_w1_c1));\n" \ +" out_h1_w0_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w0_c1));\n" \ +" out_h1_w1_c1 = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-out_h1_w1_c1));\n" \ +" }\n" \ +"\n" \ +"#ifndef EXCEDD_MAX_IMAGE2D_WIDTH\n" \ +" WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh0), out_h0_w0_c0);\n" \ +" WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh0), out_h0_w1_c0);\n" \ +" if (oh1 < OH) {\n" \ +" WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice0, n_oh1), out_h1_w0_c0);\n" \ +" WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice0, n_oh1), out_h1_w1_c0);\n" \ +" } // end if (oh1 < OH)\n" \ +" if (co_slice1 < CO_SLICES) {\n" \ +" WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh0), out_h0_w0_c1);\n" \ +" WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh0), out_h0_w1_c1);\n" \ +" if (oh1 < OH) {\n" \ +" WRITE_IMAGE(output, (int2)(ow0 * CO_SLICES + co_slice1, n_oh1), out_h1_w0_c1);\n" \ +" WRITE_IMAGE(output, (int2)(ow1 * CO_SLICES + co_slice1, n_oh1), out_h1_w1_c1);\n" \ +" } // end if (oh1 < OH)\n" \ +" } // end if (co_slice1 < CO_SLICES)\n" \ +"#else\n" \ +" WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow0), out_h0_w0_c0);\n" \ +" WRITE_IMAGE(output, (int2)(co_slice0, n_oh0 * OW + ow1), out_h0_w1_c0);\n" \ +" if (oh1 < OH) {\n" \ +" WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow0), out_h1_w0_c0);\n" \ +" WRITE_IMAGE(output, (int2)(co_slice0, n_oh1 * OW + ow1), out_h1_w1_c0);\n" \ +" } // end (oh1 < OH)\n" \ +" if (co_slice1 < CO_SLICES) {\n" \ +" WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow0), out_h0_w0_c1);\n" \ +" WRITE_IMAGE(output, (int2)(co_slice1, n_oh0 * OW + ow1), out_h0_w1_c1);\n" \ +" if (oh1 < OH) {\n" \ +" WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow0), out_h1_w0_c1);\n" \ +" WRITE_IMAGE(output, (int2)(co_slice1, n_oh1 * OW + ow1), out_h1_w1_c1);\n" \ +" } // end if (oh1 < OH)\n" \ +" } // end if (co_slice1 < CO_SLICES)\n" \ +"#endif\n" \ +"}\n" \ +; diff --git a/mindspore-lite/src/runtime/kernel/opencl/cl/conv2d_transpose.cl.inc b/mindspore-lite/src/runtime/kernel/opencl/cl/conv2d_transpose.cl.inc new file mode 100644 index 0000000000000000000000000000000000000000..bbd8fa1108bf10985a2bc8a5589393030ade1f7f --- /dev/null +++ b/mindspore-lite/src/runtime/kernel/opencl/cl/conv2d_transpose.cl.inc @@ -0,0 +1,110 @@ +static const char *conv2d_transpose_source ="\n" +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" \ +"\n" \ +"__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" \ +"\n" \ +"__kernel void conv2d_transpose(__read_only image2d_t src_data, __write_only image2d_t dst_data, __global FLT16 *weight,\n" \ +" __read_only image2d_t biases, int2 kernel_size, int2 stride, int2 padding, int4 src_size,\n" \ +" int4 dst_size, int act_type) {\n" \ +" int dst_h = get_global_id(0);\n" \ +" int rem_h = dst_h % stride.x;\n" \ +" int ceil_h = dst_h / stride.x;\n" \ +" dst_h = ceil_h * stride.x * 2 + rem_h;\n" \ +" int dst_w = get_global_id(1);\n" \ +" int rem_w = dst_w % stride.y;\n" \ +" int ceil_w = dst_w / stride.y;\n" \ +" dst_w = ceil_w * stride.y * 2 + rem_w;\n" \ +" int dst_c = get_global_id(2); // n * c4\n" \ +" int n = dst_c / dst_size.z;\n" \ +" dst_c = dst_c % dst_size.z;\n" \ +" if (dst_h >= dst_size.x || dst_w >= dst_size.y || dst_c >= dst_size.z || n >= dst_size.w) return;\n" \ +" int weight_base = dst_c * src_size.z * kernel_size.x * kernel_size.y;\n" \ +" FLT4 r0 = (FLT4)(0.f);\n" \ +" FLT4 r1 = (FLT4)(0.f);\n" \ +" FLT4 r2 = (FLT4)(0.f);\n" \ +" FLT4 r3 = (FLT4)(0.f);\n" \ +" int kh_start = dst_h + padding.x;\n" \ +" int kw_start = dst_w + padding.y;\n" \ +" int kh_end = kh_start - kernel_size.x;\n" \ +" int kw_end = kw_start - kernel_size.y;\n" \ +" int src_h = kh_start / stride.x;\n" \ +" int kh = src_h * stride.x;\n" \ +" int src_w = kw_start / stride.y;\n" \ +" int kw = src_w * stride.y;\n" \ +" for (; kh > kh_end; src_h -= 1, kh -= stride.x) {\n" \ +" int out0_src_h = src_h;\n" \ +" int out1_src_h = src_h + 1;\n" \ +" int kernel_h = kh_start - kh;\n" \ +" int src_w_copy = src_w;\n" \ +" int kw_copy = kw;\n" \ +" for (; kw_copy > kw_end; src_w_copy -= 1, kw_copy -= stride.y) {\n" \ +" int out0_src_w = src_w_copy;\n" \ +" int out1_src_w = src_w_copy + 1;\n" \ +" int kernel_w = kw_start - kw_copy;\n" \ +" int weight_offset = weight_base + (kernel_h * kernel_size.y + kernel_w) * src_size.z;\n" \ +" for (int ci = 0; ci < src_size.z; ++ci) {\n" \ +" FLT4 x0 = (FLT4)0.f;\n" \ +" FLT4 x2 = (FLT4)0.f;\n" \ +" if (out0_src_h < src_size.x && out0_src_h >= 0) {\n" \ +" x0 = READ_IMAGE(src_data, smp_zero, (int2)(out0_src_w * src_size.z + ci, n * src_size.x + out0_src_h));\n" \ +" x2 = READ_IMAGE(src_data, smp_zero, (int2)(out1_src_w * src_size.z + ci, n * src_size.x + out0_src_h));\n" \ +" }\n" \ +" FLT4 x1 = (FLT4)0.f;\n" \ +" FLT4 x3 = (FLT4)0.f;\n" \ +" if (out1_src_h < src_size.x && out1_src_h >= 0) {\n" \ +" x1 = READ_IMAGE(src_data, smp_zero, (int2)(out0_src_w * src_size.z + ci, n * src_size.x + out1_src_h));\n" \ +" x3 = READ_IMAGE(src_data, smp_zero, (int2)(out1_src_w * src_size.z + ci, n * src_size.x + out1_src_h));\n" \ +" }\n" \ +" FLT16 weight_cache = weight[weight_offset++];\n" \ +" r0 += x0.x * weight_cache.s0123;\n" \ +" r0 += x0.y * weight_cache.s4567;\n" \ +" r0 += x0.z * weight_cache.s89ab;\n" \ +" r0 += x0.w * weight_cache.scdef;\n" \ +"\n" \ +" r1 += x1.x * weight_cache.s0123;\n" \ +" r1 += x1.y * weight_cache.s4567;\n" \ +" r1 += x1.z * weight_cache.s89ab;\n" \ +" r1 += x1.w * weight_cache.scdef;\n" \ +"\n" \ +" r2 += x2.x * weight_cache.s0123;\n" \ +" r2 += x2.y * weight_cache.s4567;\n" \ +" r2 += x2.z * weight_cache.s89ab;\n" \ +" r2 += x2.w * weight_cache.scdef;\n" \ +"\n" \ +" r3 += x3.x * weight_cache.s0123;\n" \ +" r3 += x3.y * weight_cache.s4567;\n" \ +" r3 += x3.z * weight_cache.s89ab;\n" \ +" r3 += x3.w * weight_cache.scdef;\n" \ +" }\n" \ +" }\n" \ +" }\n" \ +" FLT4 bias_val = READ_IMAGE(biases, smp_zero, (int2)(dst_c, 0));\n" \ +" r0 += bias_val;\n" \ +" r1 += bias_val;\n" \ +" r2 += bias_val;\n" \ +" r3 += bias_val;\n" \ +"\n" \ +" if (act_type == ActivationType_RELU) {\n" \ +" r0 = max(r0, (FLT4)(0.0f));\n" \ +" r1 = max(r1, (FLT4)(0.0f));\n" \ +" r2 = max(r2, (FLT4)(0.0f));\n" \ +" r3 = max(r3, (FLT4)(0.0f));\n" \ +" } else if (act_type == ActivationType_RELU6) {\n" \ +" r0 = clamp(r0, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" r1 = clamp(r1, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" r2 = clamp(r2, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" r3 = clamp(r3, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" }\n" \ +"\n" \ +" WRITE_IMAGE(dst_data, (int2)(dst_w * dst_size.z + dst_c, n * dst_size.x + dst_h), r0);\n" \ +" if (dst_h + stride.x < dst_size.x && dst_w < dst_size.y) {\n" \ +" WRITE_IMAGE(dst_data, (int2)(dst_w * dst_size.z + dst_c, n * dst_size.x + dst_h + stride.x), r1);\n" \ +" }\n" \ +" if (dst_h < dst_size.x && dst_w + stride.y < dst_size.y) {\n" \ +" WRITE_IMAGE(dst_data, (int2)((dst_w + stride.y) * dst_size.z + dst_c, n * dst_size.x + dst_h), r2);\n" \ +" }\n" \ +" if (dst_h + stride.x < dst_size.x && dst_w + stride.y < dst_size.y) {\n" \ +" WRITE_IMAGE(dst_data, (int2)((dst_w + stride.y) * dst_size.z + dst_c, n * dst_size.x + dst_h + stride.x), r3);\n" \ +" }\n" \ +"}\n" \ +; diff --git a/mindspore-lite/src/runtime/kernel/opencl/cl/depthwise_conv2d.cl.inc b/mindspore-lite/src/runtime/kernel/opencl/cl/depthwise_conv2d.cl.inc new file mode 100644 index 0000000000000000000000000000000000000000..2c87a5df4efbdab312bcc957c4565b6f97f95e4f --- /dev/null +++ b/mindspore-lite/src/runtime/kernel/opencl/cl/depthwise_conv2d.cl.inc @@ -0,0 +1,884 @@ +static const char *depthwise_conv2d_source ="\n" +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" \ +"__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" \ +"__kernel void DepthwiseConv2d_IMG_NHWC4(__write_only image2d_t dst_data, __read_only image2d_t src_data,\n" \ +" __read_only image2d_t filter, __global FLT4 *bias, int2 kernel_size,\n" \ +" int2 stride, int2 padding, int2 dilation, int4 src_size, int4 dst_size,\n" \ +" float relu_clip_min, float relu_clip_max) {\n" \ +" int X = get_global_id(1);\n" \ +" int Y = get_global_id(2);\n" \ +" int Z = get_global_id(0);\n" \ +" if (X >= dst_size.x || Y >= (dst_size.y * dst_size.w) || Z >= dst_size.z) return;\n" \ +" FLT4 r = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" int x_offset = X * stride.x + padding.x;\n" \ +" int batch = Y / dst_size.y;\n" \ +" int y_offset = (Y % dst_size.y) * stride.y + padding.y;\n" \ +" int fx_c = Z * kernel_size.x * kernel_size.y;\n" \ +" for (int ky = 0; ky < kernel_size.y; ++ky) {\n" \ +" int y_c = y_offset + ky * dilation.y;\n" \ +" bool outside_y = y_c < 0 || y_c >= src_size.y;\n" \ +" for (int kx = 0; kx < kernel_size.x; ++kx) {\n" \ +" int x_c = x_offset + kx * dilation.x;\n" \ +" bool outside_x = x_c < 0 || x_c >= src_size.x;\n" \ +" if (!outside_x && !outside_y) {\n" \ +" FLT4 flt_p = READ_IMAGE(filter, smp_zero, (int2)(ky * kernel_size.x + kx, Z));\n" \ +" FLT4 src_p = READ_IMAGE(src_data, smp_zero, (int2)(Z + x_c * src_size.z, y_c + batch * src_size.y));\n" \ +" r += TO_FLT4(src_p * flt_p);\n" \ +" }\n" \ +" fx_c++;\n" \ +" }\n" \ +" }\n" \ +" FLT4 bias_p = bias[Z];\n" \ +" FLT4 res = TO_FLT4(r) + bias_p;\n" \ +" res = clamp(res, (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z, Y), res);\n" \ +"}\n" \ +"\n" \ +"__kernel void DepthwiseConv2d_IMG_NHWC4_1x1(__write_only image2d_t dst_data, __read_only image2d_t src_data,\n" \ +" __read_only image2d_t filter, __global FLT4 *bias, int2 kernel_size,\n" \ +" int2 stride, int2 padding, int2 dilation, int4 src_size, int4 dst_size,\n" \ +" float relu_clip_min, float relu_clip_max) {\n" \ +" int X = get_global_id(1);\n" \ +" int Y = get_global_id(2);\n" \ +" int Z = get_global_id(0);\n" \ +" if (X >= dst_size.x || Y >= (dst_size.y * dst_size.w) || Z >= dst_size.z) return;\n" \ +" FLT4 r = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" int x_offset = X * stride.x + padding.x;\n" \ +" int batch = Y / dst_size.y;\n" \ +" int y_offset = (Y % dst_size.y) * stride.y + padding.y;\n" \ +" int fx_c = Z;\n" \ +" {\n" \ +" int y_c = y_offset;\n" \ +" bool outside_y = y_c < 0 || y_c >= src_size.y;\n" \ +" {\n" \ +" int x_c = x_offset;\n" \ +" bool outside_x = x_c < 0 || x_c >= src_size.x;\n" \ +" if (!outside_x && !outside_y) {\n" \ +" FLT4 flt_p = READ_IMAGE(filter, smp_zero, (int2)(0, Z));\n" \ +" FLT4 src_p = READ_IMAGE(src_data, smp_zero, (int2)(Z + x_c * src_size.z, y_c + batch * src_size.y));\n" \ +" r += TO_FLT4(src_p * flt_p);\n" \ +" }\n" \ +" }\n" \ +" }\n" \ +" FLT4 bias_p = bias[Z];\n" \ +" FLT4 res = TO_FLT4(r) + bias_p;\n" \ +" res = clamp(res, (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z, Y), res);\n" \ +"}\n" \ +"__kernel void DepthwiseConv2d_FIWI_NHWC4_b222(__write_only image2d_t dst_data, __read_only image2d_t src_data,\n" \ +" __read_only image2d_t filter, __global FLT4 *bias, int2 kernel_size,\n" \ +" int2 stride, int2 padding, int2 dilation, int4 src_size, int4 dst_size,\n" \ +" float relu_clip_min, float relu_clip_max) {\n" \ +" int X = get_global_id(1) * 2;\n" \ +" int heightOfBlock = dst_size.y + (dst_size.y & 0x1);\n" \ +" int batch = get_global_id(2) / (heightOfBlock >> 1);\n" \ +" int Y = (get_global_id(2) * 2 - heightOfBlock * batch) + batch * dst_size.y;\n" \ +" int Z = get_global_id(0) * 2;\n" \ +" if (X >= dst_size.x || Y >= (dst_size.y * dst_size.w) || Z >= dst_size.z) return;\n" \ +" FLT4 r[8] = {(FLT4)(0.0f, 0.0f, 0.0f, 0.0f), (FLT4)(0.0f, 0.0f, 0.0f, 0.0f), (FLT4)(0.0f, 0.0f, 0.0f, 0.0f),\n" \ +" (FLT4)(0.0f, 0.0f, 0.0f, 0.0f), (FLT4)(0.0f, 0.0f, 0.0f, 0.0f), (FLT4)(0.0f, 0.0f, 0.0f, 0.0f),\n" \ +" (FLT4)(0.0f, 0.0f, 0.0f, 0.0f), (FLT4)(0.0f, 0.0f, 0.0f, 0.0f)};\n" \ +" int x_offset = X * stride.x + padding.x;\n" \ +" int y_offset = (Y - batch * dst_size.y) * stride.y + padding.y;\n" \ +" int f_len = kernel_size.x * kernel_size.y;\n" \ +" int fx_c = Z * f_len;\n" \ +" bool last_x = (get_global_id(1) == (dst_size.x >> 1)) && ((dst_size.x & 0x1) == 1);\n" \ +" bool last_y =\n" \ +" ((get_global_id(2) - batch * (heightOfBlock >> 1)) == ((heightOfBlock >> 1) - 1)) && ((dst_size.y & 0x1) == 1);\n" \ +" bool last_c = (get_global_id(0) == (dst_size.z >> 1)) && ((dst_size.z & 0x1) == 1);\n" \ +" for (int ky = 0; ky < kernel_size.y; ++ky) {\n" \ +" int y_c = y_offset + ky * dilation.y;\n" \ +" int y_c_a1 = y_c + stride.y;\n" \ +" y_c = y_c < 0 || y_c >= src_size.y ? src_size.y * src_size.z : y_c;\n" \ +" y_c_a1 = y_c_a1 < 0 || y_c_a1 >= src_size.y ? src_size.y * src_size.z : y_c_a1;\n" \ +" for (int kx = 0; kx < kernel_size.x; ++kx) {\n" \ +" int x_c = x_offset + kx * dilation.x;\n" \ +" int x_c_a1 = x_c + stride.x;\n" \ +" int x_sign = x_c < 0 ? -1 : 1;\n" \ +" int x_a1_sign = x_c_a1 < 0 ? -1 : 1;\n" \ +" FLT4 flt_p0 = READ_IMAGE(filter, smp_zero, (int2)(ky * kernel_size.x + kx, Z));\n" \ +" FLT4 flt_p1 = READ_IMAGE(filter, smp_zero, (int2)(ky * kernel_size.x + kx, Z + 1));\n" \ +" {\n" \ +" FLT4 src_p00_c0 = READ_IMAGE(src_data, smp_zero, (int2)(Z * x_sign + x_c * src_size.z, y_c));\n" \ +" FLT4 src_p00_c1 =\n" \ +" READ_IMAGE(src_data, smp_zero, (int2)((Z + 1) * x_sign + x_c * src_size.z, y_c + batch * src_size.y));\n" \ +" r[0] += TO_FLT4(src_p00_c0 * flt_p0);\n" \ +" r[1] += TO_FLT4(src_p00_c1 * flt_p1);\n" \ +" }\n" \ +" {\n" \ +" FLT4 src_p01_c0 = READ_IMAGE(src_data, smp_zero, (int2)(Z * x_a1_sign + x_c_a1 * src_size.z, y_c));\n" \ +" FLT4 src_p01_c1 =\n" \ +" READ_IMAGE(src_data, smp_zero, (int2)((Z + 1) * x_a1_sign + x_c_a1 * src_size.z, y_c + batch * src_size.y));\n" \ +" r[2] += TO_FLT4(src_p01_c0 * flt_p0);\n" \ +" r[3] += TO_FLT4(src_p01_c1 * flt_p1);\n" \ +" }\n" \ +" {\n" \ +" FLT4 src_p10_c0 = READ_IMAGE(src_data, smp_zero, (int2)(Z + x_c * src_size.z, y_c_a1));\n" \ +" FLT4 src_p10_c1 = READ_IMAGE(src_data, smp_zero, (int2)(Z + 1 + x_c * src_size.z, y_c_a1 + batch * src_size.y));\n" \ +" r[4] += TO_FLT4(src_p10_c0 * flt_p0);\n" \ +" r[5] += TO_FLT4(src_p10_c1 * flt_p1);\n" \ +" }\n" \ +" {\n" \ +" FLT4 src_p11_c0 = READ_IMAGE(src_data, smp_zero, (int2)(Z * x_a1_sign + x_c_a1 * src_size.z, y_c_a1));\n" \ +" FLT4 src_p11_c1 = READ_IMAGE(src_data, smp_zero,\n" \ +" (int2)((Z + 1) * x_a1_sign + x_c_a1 * src_size.z, y_c_a1 + batch * src_size.y));\n" \ +" r[6] += TO_FLT4(src_p11_c0 * flt_p0);\n" \ +" r[7] += TO_FLT4(src_p11_c1 * flt_p1);\n" \ +" }\n" \ +" fx_c++;\n" \ +" }\n" \ +" }\n" \ +" r[0] += bias[Z];\n" \ +" r[1] += bias[Z + 1];\n" \ +" r[2] += bias[Z];\n" \ +" r[3] += bias[Z + 1];\n" \ +" r[4] += bias[Z];\n" \ +" r[5] += bias[Z + 1];\n" \ +" r[6] += bias[Z];\n" \ +" r[7] += bias[Z + 1];\n" \ +" r[0] = clamp(r[0], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" r[1] = clamp(r[1], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" r[2] = clamp(r[2], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" r[3] = clamp(r[3], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" r[4] = clamp(r[4], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" r[5] = clamp(r[5], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" r[6] = clamp(r[6], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" r[7] = clamp(r[7], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z, Y), r[0]);\n" \ +" if (!last_c) {\n" \ +" WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z + 1, Y), r[1]);\n" \ +" }\n" \ +" if (!last_x) {\n" \ +" WRITE_IMAGE(dst_data, (int2)((X + 1) * dst_size.z + Z, Y), r[2]);\n" \ +" if (!last_c) {\n" \ +" WRITE_IMAGE(dst_data, (int2)((X + 1) * dst_size.z + Z + 1, Y), r[3]);\n" \ +" }\n" \ +" }\n" \ +" if (!last_y) {\n" \ +" WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z, Y + 1), r[4]);\n" \ +" if (!last_c) {\n" \ +" WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z + 1, Y + 1), r[5]);\n" \ +" }\n" \ +" }\n" \ +" if (!last_y && !last_x) {\n" \ +" WRITE_IMAGE(dst_data, (int2)((X + 1) * dst_size.z + Z, Y + 1), r[6]);\n" \ +" if (!last_c) {\n" \ +" WRITE_IMAGE(dst_data, (int2)((X + 1) * dst_size.z + Z + 1, Y + 1), r[7]);\n" \ +" }\n" \ +" }\n" \ +"}\n" \ +"__kernel void DepthwiseConv2d_FIWI_NHWC4_b221(__write_only image2d_t dst_data, __read_only image2d_t src_data,\n" \ +" __read_only image2d_t filter, __global FLT4 *bias, int2 kernel_size,\n" \ +" int2 stride, int2 padding, int2 dilation, int4 src_size, int4 dst_size,\n" \ +" float relu_clip_min, float relu_clip_max) {\n" \ +" int X = get_global_id(1) * 2;\n" \ +" int heightOfBlock = dst_size.y + (dst_size.y & 0x1);\n" \ +" int batch = get_global_id(2) / (heightOfBlock >> 1);\n" \ +" int Y = (get_global_id(2) * 2 - heightOfBlock * batch) + batch * dst_size.y;\n" \ +" int Z = get_global_id(0);\n" \ +" if (X >= dst_size.x || Y >= (dst_size.y * dst_size.w) || Z >= dst_size.z) return;\n" \ +" FLT4 r[4] = {(FLT4)(0.0f, 0.0f, 0.0f, 0.0f), (FLT4)(0.0f, 0.0f, 0.0f, 0.0f), (FLT4)(0.0f, 0.0f, 0.0f, 0.0f),\n" \ +" (FLT4)(0.0f, 0.0f, 0.0f, 0.0f)};\n" \ +" int x_offset = X * stride.x + padding.x;\n" \ +" int y_offset = (Y - batch * dst_size.y) * stride.y + padding.y;\n" \ +" int f_len = kernel_size.x * kernel_size.y;\n" \ +" int fx_c = Z * f_len;\n" \ +" bool last_x = (get_global_id(1) == (dst_size.x >> 1)) && ((dst_size.x & 0x1) == 1);\n" \ +" bool last_y =\n" \ +" ((get_global_id(2) - batch * (heightOfBlock >> 1)) == ((heightOfBlock >> 1) - 1)) && ((dst_size.y & 0x1) == 1);\n" \ +" for (int ky = 0; ky < kernel_size.y; ++ky) {\n" \ +" int y_c = y_offset + ky * dilation.y;\n" \ +" int y_c_a1 = y_c + stride.y;\n" \ +" y_c = y_c < 0 || y_c >= src_size.y ? src_size.y * src_size.w : y_c;\n" \ +" y_c_a1 = y_c_a1 < 0 || y_c_a1 >= src_size.y ? src_size.y * src_size.w : y_c_a1;\n" \ +" for (int kx = 0; kx < kernel_size.x; ++kx) {\n" \ +" int x_c = x_offset + kx * dilation.x;\n" \ +" int x_c_a1 = x_c + stride.x;\n" \ +" int x_sign = x_c < 0 ? -1 : 1;\n" \ +" int x_a1_sign = x_c_a1 < 0 ? -1 : 1;\n" \ +" FLT4 flt_p0 = READ_IMAGE(filter, smp_zero, (int2)(ky * kernel_size.x + kx, Z));\n" \ +" FLT4 src_p00_c0 = READ_IMAGE(src_data, smp_zero, (int2)(Z * x_sign + x_c * src_size.z, y_c + batch * src_size.y));\n" \ +" r[0] += TO_FLT4(src_p00_c0 * flt_p0);\n" \ +" FLT4 src_p01_c0 =\n" \ +" READ_IMAGE(src_data, smp_zero, (int2)(Z * x_a1_sign + x_c_a1 * src_size.z, y_c + batch * src_size.y));\n" \ +" r[1] += TO_FLT4(src_p01_c0 * flt_p0);\n" \ +" FLT4 src_p10_c0 = READ_IMAGE(src_data, smp_zero, (int2)(Z + x_c * src_size.z, y_c_a1 + batch * src_size.y));\n" \ +" r[2] += TO_FLT4(src_p10_c0 * flt_p0);\n" \ +" FLT4 src_p11_c0 =\n" \ +" READ_IMAGE(src_data, smp_zero, (int2)(Z * x_a1_sign + x_c_a1 * src_size.z, y_c_a1 + batch * src_size.y));\n" \ +" r[3] += TO_FLT4(src_p11_c0 * flt_p0);\n" \ +"\n" \ +" fx_c++;\n" \ +" }\n" \ +" }\n" \ +" r[0] += bias[Z];\n" \ +" r[1] += bias[Z];\n" \ +" r[2] += bias[Z];\n" \ +" r[3] += bias[Z];\n" \ +" r[0] = clamp(r[0], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" r[1] = clamp(r[1], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" r[2] = clamp(r[2], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" r[3] = clamp(r[3], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z, Y), r[0]);\n" \ +" if (!last_x) {\n" \ +" WRITE_IMAGE(dst_data, (int2)((X + 1) * dst_size.z + Z, Y), r[1]);\n" \ +" }\n" \ +" if (!last_y) {\n" \ +" WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z, Y + 1), r[2]);\n" \ +" }\n" \ +" if (!last_y && !last_x) {\n" \ +" WRITE_IMAGE(dst_data, (int2)((X + 1) * dst_size.z + Z, Y + 1), r[3]);\n" \ +" }\n" \ +"}\n" \ +"__kernel void DepthwiseConv2d_FIWI_NHWC4_b212(__write_only image2d_t dst_data, __read_only image2d_t src_data,\n" \ +" __read_only image2d_t filter, __global FLT4 *bias, int2 kernel_size,\n" \ +" int2 stride, int2 padding, int2 dilation, int4 src_size, int4 dst_size,\n" \ +" float relu_clip_min, float relu_clip_max) {\n" \ +" int X = get_global_id(1);\n" \ +" int heightOfBlock = dst_size.y + (dst_size.y & 0x1);\n" \ +" int batch = get_global_id(2) / (heightOfBlock >> 1);\n" \ +" int Y = (get_global_id(2) * 2 - heightOfBlock * batch) + batch * dst_size.y;\n" \ +" int Z = get_global_id(0) * 2;\n" \ +" if (X >= dst_size.x || Y >= (dst_size.y * dst_size.w) || Z >= dst_size.z) return;\n" \ +" FLT4 r[4] = {(FLT4)(0.0f, 0.0f, 0.0f, 0.0f), (FLT4)(0.0f, 0.0f, 0.0f, 0.0f), (FLT4)(0.0f, 0.0f, 0.0f, 0.0f),\n" \ +" (FLT4)(0.0f, 0.0f, 0.0f, 0.0f)};\n" \ +" int x_offset = X * stride.x + padding.x;\n" \ +" int y_offset = (Y - batch * dst_size.y) * stride.y + padding.y;\n" \ +" int f_len = kernel_size.x * kernel_size.y;\n" \ +" int fx_c = Z * f_len;\n" \ +" bool last_x = (get_global_id(1) == (dst_size.x >> 1)) && ((dst_size.x & 0x1) == 1);\n" \ +" bool last_y =\n" \ +" ((get_global_id(2) - batch * (heightOfBlock >> 1)) == ((heightOfBlock >> 1) - 1)) && ((dst_size.y & 0x1) == 1);\n" \ +" bool last_c = (get_global_id(0) == (dst_size.z >> 1)) && ((dst_size.z & 0x1) == 1);\n" \ +" for (int ky = 0; ky < kernel_size.y; ++ky) {\n" \ +" int y_c = y_offset + ky * dilation.y;\n" \ +" int y_c_a1 = y_c + stride.y;\n" \ +" y_c = y_c < 0 || y_c >= src_size.y ? src_size.y * src_size.z : y_c;\n" \ +" y_c_a1 = y_c_a1 < 0 || y_c_a1 >= src_size.y ? src_size.y * src_size.z : y_c_a1;\n" \ +" for (int kx = 0; kx < kernel_size.x; ++kx) {\n" \ +" int x_c = x_offset + kx * dilation.x;\n" \ +" int x_c_a1 = x_c + stride.x;\n" \ +" int x_sign = x_c < 0 ? -1 : 1;\n" \ +" FLT4 flt_p0 = READ_IMAGE(filter, smp_zero, (int2)(ky * kernel_size.x + kx, Z));\n" \ +" FLT4 flt_p1 = READ_IMAGE(filter, smp_zero, (int2)(ky * kernel_size.x + kx, Z + 1));\n" \ +" {\n" \ +" FLT4 src_p00_c0 = READ_IMAGE(src_data, smp_zero, (int2)(Z * x_sign + x_c * src_size.z, y_c));\n" \ +" FLT4 src_p00_c1 =\n" \ +" READ_IMAGE(src_data, smp_zero, (int2)((Z + 1) * x_sign + x_c * src_size.z, y_c + batch * src_size.y));\n" \ +" r[0] += TO_FLT4(src_p00_c0 * flt_p0);\n" \ +" r[1] += TO_FLT4(src_p00_c1 * flt_p1);\n" \ +" }\n" \ +" {\n" \ +" FLT4 src_p10_c0 = READ_IMAGE(src_data, smp_zero, (int2)(Z + x_c * src_size.z, y_c_a1));\n" \ +" FLT4 src_p10_c1 = READ_IMAGE(src_data, smp_zero, (int2)(Z + 1 + x_c * src_size.z, y_c_a1 + batch * src_size.y));\n" \ +" r[2] += TO_FLT4(src_p10_c0 * flt_p0);\n" \ +" r[3] += TO_FLT4(src_p10_c1 * flt_p1);\n" \ +" }\n" \ +" fx_c++;\n" \ +" }\n" \ +" }\n" \ +" r[0] += bias[Z];\n" \ +" r[1] += bias[Z + 1];\n" \ +" r[2] += bias[Z];\n" \ +" r[3] += bias[Z + 1];\n" \ +" r[0] = clamp(r[0], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" r[1] = clamp(r[1], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" r[2] = clamp(r[2], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" r[3] = clamp(r[3], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z, Y), r[0]);\n" \ +" if (!last_c) {\n" \ +" WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z + 1, Y), r[1]);\n" \ +" }\n" \ +" if (!last_y) {\n" \ +" WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z, Y + 1), r[2]);\n" \ +" if (!last_c) {\n" \ +" WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z + 1, Y + 1), r[3]);\n" \ +" }\n" \ +" }\n" \ +"}\n" \ +"__kernel void DepthwiseConv2d_FIWI_NHWC4_b211(__write_only image2d_t dst_data, __read_only image2d_t src_data,\n" \ +" __read_only image2d_t filter, __global FLT4 *bias, int2 kernel_size,\n" \ +" int2 stride, int2 padding, int2 dilation, int4 src_size, int4 dst_size,\n" \ +" float relu_clip_min, float relu_clip_max) {\n" \ +" int X = get_global_id(1);\n" \ +" int heightOfBlock = dst_size.y + (dst_size.y & 0x1);\n" \ +" int batch = get_global_id(2) / (heightOfBlock >> 1);\n" \ +" int Y = (get_global_id(2) * 2 - heightOfBlock * batch) + batch * dst_size.y;\n" \ +" int Z = get_global_id(0);\n" \ +" if (X >= dst_size.x || Y >= (dst_size.y * dst_size.w) || Z >= dst_size.z) return;\n" \ +" FLT4 r[2] = {(FLT4)(0.0f, 0.0f, 0.0f, 0.0f), (FLT4)(0.0f, 0.0f, 0.0f, 0.0f)};\n" \ +" int x_offset = X * stride.x + padding.x;\n" \ +" int y_offset = (Y - batch * dst_size.y) * stride.y + padding.y;\n" \ +" int f_len = kernel_size.x * kernel_size.y;\n" \ +" int fx_c = Z * f_len;\n" \ +" bool last_x = (get_global_id(1) == (dst_size.x >> 1)) && ((dst_size.x & 0x1) == 1);\n" \ +" bool last_y =\n" \ +" ((get_global_id(2) - batch * (heightOfBlock >> 1)) == ((heightOfBlock >> 1) - 1)) && ((dst_size.y & 0x1) == 1);\n" \ +" for (int ky = 0; ky < kernel_size.y; ++ky) {\n" \ +" int y_c = y_offset + ky * dilation.y;\n" \ +" int y_c_a1 = y_c + stride.y;\n" \ +" y_c = y_c < 0 || y_c >= src_size.y ? src_size.y * src_size.w : y_c;\n" \ +" y_c_a1 = y_c_a1 < 0 || y_c_a1 >= src_size.y ? src_size.y * src_size.w : y_c_a1;\n" \ +" for (int kx = 0; kx < kernel_size.x; ++kx) {\n" \ +" int x_c = x_offset + kx * dilation.x;\n" \ +" int x_sign = x_c < 0 ? -1 : 1;\n" \ +" FLT4 flt_p0 = READ_IMAGE(filter, smp_zero, (int2)(ky * kernel_size.x + kx, Z));\n" \ +" FLT4 src_p00_c0 = READ_IMAGE(src_data, smp_zero, (int2)(Z * x_sign + x_c * src_size.z, y_c + batch * src_size.y));\n" \ +" r[0] += TO_FLT4(src_p00_c0 * flt_p0);\n" \ +" FLT4 src_p10_c0 = READ_IMAGE(src_data, smp_zero, (int2)(Z + x_c * src_size.z, y_c_a1 + batch * src_size.y));\n" \ +" r[1] += TO_FLT4(src_p10_c0 * flt_p0);\n" \ +"\n" \ +" fx_c++;\n" \ +" }\n" \ +" }\n" \ +" r[0] += bias[Z];\n" \ +" r[1] += bias[Z];\n" \ +" r[0] = clamp(r[0], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" r[1] = clamp(r[1], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z, Y), r[0]);\n" \ +" if (!last_y) {\n" \ +" WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z, Y + 1), r[1]);\n" \ +" }\n" \ +"}\n" \ +"__kernel void DepthwiseConv2d_FIWI_NHWC4_b121(__write_only image2d_t dst_data, __read_only image2d_t src_data,\n" \ +" __read_only image2d_t filter, __global FLT4 *bias, int2 kernel_size,\n" \ +" int2 stride, int2 padding, int2 dilation, int4 src_size, int4 dst_size,\n" \ +" float relu_clip_min, float relu_clip_max) {\n" \ +" int X = get_global_id(1) * 2;\n" \ +" int batch = get_global_id(2) / (dst_size.y);\n" \ +" int Y = (get_global_id(2) - dst_size.y * batch) + batch * dst_size.y;\n" \ +" int Z = get_global_id(0);\n" \ +" if (X >= dst_size.x || Y >= (dst_size.y * dst_size.w) || Z >= dst_size.z) return;\n" \ +" FLT4 r[4] = {(FLT4)(0.0f, 0.0f, 0.0f, 0.0f), (FLT4)(0.0f, 0.0f, 0.0f, 0.0f), (FLT4)(0.0f, 0.0f, 0.0f, 0.0f),\n" \ +" (FLT4)(0.0f, 0.0f, 0.0f, 0.0f)};\n" \ +" int x_offset = X * stride.x + padding.x;\n" \ +" int y_offset = (Y - batch * dst_size.y) * stride.y + padding.y;\n" \ +" int f_len = kernel_size.x * kernel_size.y;\n" \ +" int fx_c = Z * f_len;\n" \ +" bool last_x = (get_global_id(1) == (dst_size.x >> 1)) && ((dst_size.x & 0x1) == 1);\n" \ +" for (int ky = 0; ky < kernel_size.y; ++ky) {\n" \ +" int y_c = y_offset + ky * dilation.y;\n" \ +" y_c = y_c < 0 || y_c >= src_size.y ? src_size.y * src_size.w : y_c;\n" \ +" for (int kx = 0; kx < kernel_size.x; ++kx) {\n" \ +" int x_c = x_offset + kx * dilation.x;\n" \ +" int x_c_a1 = x_c + stride.x;\n" \ +" int x_sign = x_c < 0 ? -1 : 1;\n" \ +" int x_a1_sign = x_c_a1 < 0 ? -1 : 1;\n" \ +" FLT4 flt_p0 = READ_IMAGE(filter, smp_zero, (int2)(ky * kernel_size.x + kx, Z));\n" \ +" FLT4 src_p00_c0 = READ_IMAGE(src_data, smp_zero, (int2)(Z * x_sign + x_c * src_size.z, y_c + batch * src_size.y));\n" \ +" r[0] += TO_FLT4(src_p00_c0 * flt_p0);\n" \ +" FLT4 src_p01_c0 =\n" \ +" READ_IMAGE(src_data, smp_zero, (int2)(Z * x_a1_sign + x_c_a1 * src_size.z, y_c + batch * src_size.y));\n" \ +" r[1] += TO_FLT4(src_p01_c0 * flt_p0);\n" \ +"\n" \ +" fx_c++;\n" \ +" }\n" \ +" }\n" \ +" r[0] += bias[Z];\n" \ +" r[1] += bias[Z];\n" \ +" r[0] = clamp(r[0], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" r[1] = clamp(r[1], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z, Y), r[0]);\n" \ +" if (!last_x) {\n" \ +" WRITE_IMAGE(dst_data, (int2)((X + 1) * dst_size.z + Z, Y), r[1]);\n" \ +" }\n" \ +"}\n" \ +"__kernel void DepthwiseConv2d_FIWI_NHWC4_b112(__write_only image2d_t dst_data, __read_only image2d_t src_data,\n" \ +" __read_only image2d_t filter, __global FLT4 *bias, int2 kernel_size,\n" \ +" int2 stride, int2 padding, int2 dilation, int4 src_size, int4 dst_size,\n" \ +" float relu_clip_min, float relu_clip_max) {\n" \ +" int X = get_global_id(1);\n" \ +" int batch = get_global_id(2) / (dst_size.y);\n" \ +" int Y = (get_global_id(2) - dst_size.y * batch) + batch * dst_size.y;\n" \ +" int Z = get_global_id(0) * 2;\n" \ +" if (X >= dst_size.x || Y >= (dst_size.y * dst_size.w) || Z >= dst_size.z) return;\n" \ +" FLT4 r[4] = {(FLT4)(0.0f, 0.0f, 0.0f, 0.0f), (FLT4)(0.0f, 0.0f, 0.0f, 0.0f), (FLT4)(0.0f, 0.0f, 0.0f, 0.0f),\n" \ +" (FLT4)(0.0f, 0.0f, 0.0f, 0.0f)};\n" \ +" int x_offset = X * stride.x + padding.x;\n" \ +" int y_offset = (Y - batch * dst_size.y) * stride.y + padding.y;\n" \ +" int f_len = kernel_size.x * kernel_size.y;\n" \ +" int fx_c = Z * f_len;\n" \ +" bool last_x = (get_global_id(1) == (dst_size.x >> 1)) && ((dst_size.x & 0x1) == 1);\n" \ +" bool last_c = (get_global_id(0) == (dst_size.z >> 1)) && ((dst_size.z & 0x1) == 1);\n" \ +" for (int ky = 0; ky < kernel_size.y; ++ky) {\n" \ +" int y_c = y_offset + ky * dilation.y;\n" \ +" y_c = y_c < 0 || y_c >= src_size.y ? src_size.y * src_size.z : y_c;\n" \ +" for (int kx = 0; kx < kernel_size.x; ++kx) {\n" \ +" int x_c = x_offset + kx * dilation.x;\n" \ +" int x_c_a1 = x_c + stride.x;\n" \ +" int x_sign = x_c < 0 ? -1 : 1;\n" \ +" FLT4 flt_p0 = READ_IMAGE(filter, smp_zero, (int2)(ky * kernel_size.x + kx, Z));\n" \ +" FLT4 flt_p1 = READ_IMAGE(filter, smp_zero, (int2)(ky * kernel_size.x + kx, Z + 1));\n" \ +" {\n" \ +" FLT4 src_p00_c0 = READ_IMAGE(src_data, smp_zero, (int2)(Z * x_sign + x_c * src_size.z, y_c));\n" \ +" FLT4 src_p00_c1 =\n" \ +" READ_IMAGE(src_data, smp_zero, (int2)((Z + 1) * x_sign + x_c * src_size.z, y_c + batch * src_size.y));\n" \ +" r[0] += TO_FLT4(src_p00_c0 * flt_p0);\n" \ +" r[1] += TO_FLT4(src_p00_c1 * flt_p1);\n" \ +" }\n" \ +" fx_c++;\n" \ +" }\n" \ +" }\n" \ +" r[0] += bias[Z];\n" \ +" r[1] += bias[Z + 1];\n" \ +" r[0] = clamp(r[0], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" r[1] = clamp(r[1], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z, Y), r[0]);\n" \ +" if (!last_c) {\n" \ +" WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z + 1, Y), r[1]);\n" \ +" }\n" \ +"}\n" \ +"__kernel void DepthwiseConv2d_IMG_NHWC4_b222(__write_only image2d_t dst_data, __read_only image2d_t src_data,\n" \ +" __global FLT4 *filter, __global FLT4 *bias, int2 kernel_size, int2 stride,\n" \ +" int2 padding, int2 dilation, int4 src_size, int4 dst_size,\n" \ +" float relu_clip_min, float relu_clip_max) {\n" \ +" int X = get_global_id(1) * 2;\n" \ +" int heightOfBlock = dst_size.y + (dst_size.y & 0x1);\n" \ +" int batch = get_global_id(2) / (heightOfBlock >> 1);\n" \ +" int Y = (get_global_id(2) * 2 - heightOfBlock * batch) + batch * dst_size.y;\n" \ +" int Z = get_global_id(0) * 2;\n" \ +" if (X >= dst_size.x || Y >= (dst_size.y * dst_size.w) || Z >= dst_size.z) return;\n" \ +" FLT4 r[8] = {(FLT4)(0.0f, 0.0f, 0.0f, 0.0f), (FLT4)(0.0f, 0.0f, 0.0f, 0.0f), (FLT4)(0.0f, 0.0f, 0.0f, 0.0f),\n" \ +" (FLT4)(0.0f, 0.0f, 0.0f, 0.0f), (FLT4)(0.0f, 0.0f, 0.0f, 0.0f), (FLT4)(0.0f, 0.0f, 0.0f, 0.0f),\n" \ +" (FLT4)(0.0f, 0.0f, 0.0f, 0.0f), (FLT4)(0.0f, 0.0f, 0.0f, 0.0f)};\n" \ +" int x_offset = X * stride.x + padding.x;\n" \ +" int y_offset = (Y - batch * dst_size.y) * stride.y + padding.y;\n" \ +" int f_len = kernel_size.x * kernel_size.y;\n" \ +" int fx_c = Z * f_len;\n" \ +" bool last_x = (get_global_id(1) == (dst_size.x >> 1)) && ((dst_size.x & 0x1) == 1);\n" \ +" bool last_y =\n" \ +" ((get_global_id(2) - batch * (heightOfBlock >> 1)) == ((heightOfBlock >> 1) - 1)) && ((dst_size.y & 0x1) == 1);\n" \ +" bool last_c = (get_global_id(0) == (dst_size.z >> 1)) && ((dst_size.z & 0x1) == 1);\n" \ +" for (int ky = 0; ky < kernel_size.y; ++ky) {\n" \ +" int y_c = y_offset + ky * dilation.y;\n" \ +" int y_c_a1 = y_c + stride.y;\n" \ +" y_c = y_c < 0 || y_c >= src_size.y ? src_size.y * src_size.z : y_c;\n" \ +" y_c_a1 = y_c_a1 < 0 || y_c_a1 >= src_size.y ? src_size.y * src_size.z : y_c_a1;\n" \ +" for (int kx = 0; kx < kernel_size.x; ++kx) {\n" \ +" int x_c = x_offset + kx * dilation.x;\n" \ +" int x_c_a1 = x_c + stride.x;\n" \ +" int x_sign = x_c < 0 ? -1 : 1;\n" \ +" int x_a1_sign = x_c_a1 < 0 ? -1 : 1;\n" \ +" FLT4 flt_p0 = filter[fx_c];\n" \ +" FLT4 flt_p1 = filter[fx_c + f_len];\n" \ +" {\n" \ +" FLT4 src_p00_c0 = READ_IMAGE(src_data, smp_zero, (int2)(Z * x_sign + x_c * src_size.z, y_c));\n" \ +" FLT4 src_p00_c1 =\n" \ +" READ_IMAGE(src_data, smp_zero, (int2)((Z + 1) * x_sign + x_c * src_size.z, y_c + batch * src_size.y));\n" \ +" r[0] += TO_FLT4(src_p00_c0 * flt_p0);\n" \ +" r[1] += TO_FLT4(src_p00_c1 * flt_p1);\n" \ +" }\n" \ +" {\n" \ +" FLT4 src_p01_c0 = READ_IMAGE(src_data, smp_zero, (int2)(Z * x_a1_sign + x_c_a1 * src_size.z, y_c));\n" \ +" FLT4 src_p01_c1 =\n" \ +" READ_IMAGE(src_data, smp_zero, (int2)((Z + 1) * x_a1_sign + x_c_a1 * src_size.z, y_c + batch * src_size.y));\n" \ +" r[2] += TO_FLT4(src_p01_c0 * flt_p0);\n" \ +" r[3] += TO_FLT4(src_p01_c1 * flt_p1);\n" \ +" }\n" \ +" {\n" \ +" FLT4 src_p10_c0 = READ_IMAGE(src_data, smp_zero, (int2)(Z + x_c * src_size.z, y_c_a1));\n" \ +" FLT4 src_p10_c1 = READ_IMAGE(src_data, smp_zero, (int2)(Z + 1 + x_c * src_size.z, y_c_a1 + batch * src_size.y));\n" \ +" r[4] += TO_FLT4(src_p10_c0 * flt_p0);\n" \ +" r[5] += TO_FLT4(src_p10_c1 * flt_p1);\n" \ +" }\n" \ +" {\n" \ +" FLT4 src_p11_c0 = READ_IMAGE(src_data, smp_zero, (int2)(Z * x_a1_sign + x_c_a1 * src_size.z, y_c_a1));\n" \ +" FLT4 src_p11_c1 = READ_IMAGE(src_data, smp_zero,\n" \ +" (int2)((Z + 1) * x_a1_sign + x_c_a1 * src_size.z, y_c_a1 + batch * src_size.y));\n" \ +" r[6] += TO_FLT4(src_p11_c0 * flt_p0);\n" \ +" r[7] += TO_FLT4(src_p11_c1 * flt_p1);\n" \ +" }\n" \ +" fx_c++;\n" \ +" }\n" \ +" }\n" \ +" r[0] += bias[Z];\n" \ +" r[1] += bias[Z + 1];\n" \ +" r[2] += bias[Z];\n" \ +" r[3] += bias[Z + 1];\n" \ +" r[4] += bias[Z];\n" \ +" r[5] += bias[Z + 1];\n" \ +" r[6] += bias[Z];\n" \ +" r[7] += bias[Z + 1];\n" \ +" r[0] = clamp(r[0], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" r[1] = clamp(r[1], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" r[2] = clamp(r[2], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" r[3] = clamp(r[3], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" r[4] = clamp(r[4], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" r[5] = clamp(r[5], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" r[6] = clamp(r[6], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" r[7] = clamp(r[7], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z, Y), r[0]);\n" \ +" if (!last_c) {\n" \ +" WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z + 1, Y), r[1]);\n" \ +" }\n" \ +" if (!last_x) {\n" \ +" WRITE_IMAGE(dst_data, (int2)((X + 1) * dst_size.z + Z, Y), r[2]);\n" \ +" if (!last_c) {\n" \ +" WRITE_IMAGE(dst_data, (int2)((X + 1) * dst_size.z + Z + 1, Y), r[3]);\n" \ +" }\n" \ +" }\n" \ +" if (!last_y) {\n" \ +" WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z, Y + 1), r[4]);\n" \ +" if (!last_c) {\n" \ +" WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z + 1, Y + 1), r[5]);\n" \ +" }\n" \ +" }\n" \ +" if (!last_y && !last_x) {\n" \ +" WRITE_IMAGE(dst_data, (int2)((X + 1) * dst_size.z + Z, Y + 1), r[6]);\n" \ +" if (!last_c) {\n" \ +" WRITE_IMAGE(dst_data, (int2)((X + 1) * dst_size.z + Z + 1, Y + 1), r[7]);\n" \ +" }\n" \ +" }\n" \ +"}\n" \ +"__kernel void DepthwiseConv2d_IMG_NHWC4_b221(__write_only image2d_t dst_data, __read_only image2d_t src_data,\n" \ +" __global FLT4 *filter, __global FLT4 *bias, int2 kernel_size, int2 stride,\n" \ +" int2 padding, int2 dilation, int4 src_size, int4 dst_size,\n" \ +" float relu_clip_min, float relu_clip_max) {\n" \ +" int X = get_global_id(1) * 2;\n" \ +" int heightOfBlock = dst_size.y + (dst_size.y & 0x1);\n" \ +" int batch = get_global_id(2) / (heightOfBlock >> 1);\n" \ +" int Y = (get_global_id(2) * 2 - heightOfBlock * batch) + batch * dst_size.y;\n" \ +" int Z = get_global_id(0);\n" \ +" if (X >= dst_size.x || Y >= (dst_size.y * dst_size.w) || Z >= dst_size.z) return;\n" \ +" FLT4 r[4] = {(FLT4)(0.0f, 0.0f, 0.0f, 0.0f), (FLT4)(0.0f, 0.0f, 0.0f, 0.0f), (FLT4)(0.0f, 0.0f, 0.0f, 0.0f),\n" \ +" (FLT4)(0.0f, 0.0f, 0.0f, 0.0f)};\n" \ +" int x_offset = X * stride.x + padding.x;\n" \ +" int y_offset = (Y - batch * dst_size.y) * stride.y + padding.y;\n" \ +" int f_len = kernel_size.x * kernel_size.y;\n" \ +" int fx_c = Z * f_len;\n" \ +" bool last_x = (get_global_id(1) == (dst_size.x >> 1)) && ((dst_size.x & 0x1) == 1);\n" \ +" bool last_y =\n" \ +" ((get_global_id(2) - batch * (heightOfBlock >> 1)) == ((heightOfBlock >> 1) - 1)) && ((dst_size.y & 0x1) == 1);\n" \ +" for (int ky = 0; ky < kernel_size.y; ++ky) {\n" \ +" int y_c = y_offset + ky * dilation.y;\n" \ +" int y_c_a1 = y_c + stride.y;\n" \ +" y_c = y_c < 0 || y_c >= src_size.y ? src_size.y * src_size.w : y_c;\n" \ +" y_c_a1 = y_c_a1 < 0 || y_c_a1 >= src_size.y ? src_size.y * src_size.w : y_c_a1;\n" \ +" for (int kx = 0; kx < kernel_size.x; ++kx) {\n" \ +" int x_c = x_offset + kx * dilation.x;\n" \ +" int x_c_a1 = x_c + stride.x;\n" \ +" int x_sign = x_c < 0 ? -1 : 1;\n" \ +" int x_a1_sign = x_c_a1 < 0 ? -1 : 1;\n" \ +" FLT4 flt_p0 = filter[fx_c];\n" \ +" FLT4 src_p00_c0 = READ_IMAGE(src_data, smp_zero, (int2)(Z * x_sign + x_c * src_size.z, y_c + batch * src_size.y));\n" \ +" r[0] += TO_FLT4(src_p00_c0 * flt_p0);\n" \ +" FLT4 src_p01_c0 =\n" \ +" READ_IMAGE(src_data, smp_zero, (int2)(Z * x_a1_sign + x_c_a1 * src_size.z, y_c + batch * src_size.y));\n" \ +" r[1] += TO_FLT4(src_p01_c0 * flt_p0);\n" \ +" FLT4 src_p10_c0 = READ_IMAGE(src_data, smp_zero, (int2)(Z + x_c * src_size.z, y_c_a1 + batch * src_size.y));\n" \ +" r[2] += TO_FLT4(src_p10_c0 * flt_p0);\n" \ +" FLT4 src_p11_c0 =\n" \ +" READ_IMAGE(src_data, smp_zero, (int2)(Z * x_a1_sign + x_c_a1 * src_size.z, y_c_a1 + batch * src_size.y));\n" \ +" r[3] += TO_FLT4(src_p11_c0 * flt_p0);\n" \ +"\n" \ +" fx_c++;\n" \ +" }\n" \ +" }\n" \ +" r[0] += bias[Z];\n" \ +" r[1] += bias[Z];\n" \ +" r[2] += bias[Z];\n" \ +" r[3] += bias[Z];\n" \ +" r[0] = clamp(r[0], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" r[1] = clamp(r[1], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" r[2] = clamp(r[2], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" r[3] = clamp(r[3], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z, Y), r[0]);\n" \ +" if (!last_x) {\n" \ +" WRITE_IMAGE(dst_data, (int2)((X + 1) * dst_size.z + Z, Y), r[1]);\n" \ +" }\n" \ +" if (!last_y) {\n" \ +" WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z, Y + 1), r[2]);\n" \ +" }\n" \ +" if (!last_y && !last_x) {\n" \ +" WRITE_IMAGE(dst_data, (int2)((X + 1) * dst_size.z + Z, Y + 1), r[3]);\n" \ +" }\n" \ +"}\n" \ +"__kernel void DepthwiseConv2d_IMG_NHWC4_b212(__write_only image2d_t dst_data, __read_only image2d_t src_data,\n" \ +" __global FLT4 *filter, __global FLT4 *bias, int2 kernel_size, int2 stride,\n" \ +" int2 padding, int2 dilation, int4 src_size, int4 dst_size,\n" \ +" float relu_clip_min, float relu_clip_max) {\n" \ +" int X = get_global_id(1);\n" \ +" int heightOfBlock = dst_size.y + (dst_size.y & 0x1);\n" \ +" int batch = get_global_id(2) / (heightOfBlock >> 1);\n" \ +" int Y = (get_global_id(2) * 2 - heightOfBlock * batch) + batch * dst_size.y;\n" \ +" int Z = get_global_id(0) * 2;\n" \ +" if (X >= dst_size.x || Y >= (dst_size.y * dst_size.w) || Z >= dst_size.z) return;\n" \ +" FLT4 r[4] = {(FLT4)(0.0f, 0.0f, 0.0f, 0.0f), (FLT4)(0.0f, 0.0f, 0.0f, 0.0f), (FLT4)(0.0f, 0.0f, 0.0f, 0.0f),\n" \ +" (FLT4)(0.0f, 0.0f, 0.0f, 0.0f)};\n" \ +" int x_offset = X * stride.x + padding.x;\n" \ +" int y_offset = (Y - batch * dst_size.y) * stride.y + padding.y;\n" \ +" int f_len = kernel_size.x * kernel_size.y;\n" \ +" int fx_c = Z * f_len;\n" \ +" bool last_x = (get_global_id(1) == (dst_size.x >> 1)) && ((dst_size.x & 0x1) == 1);\n" \ +" bool last_y =\n" \ +" ((get_global_id(2) - batch * (heightOfBlock >> 1)) == ((heightOfBlock >> 1) - 1)) && ((dst_size.y & 0x1) == 1);\n" \ +" bool last_c = (get_global_id(0) == (dst_size.z >> 1)) && ((dst_size.z & 0x1) == 1);\n" \ +" for (int ky = 0; ky < kernel_size.y; ++ky) {\n" \ +" int y_c = y_offset + ky * dilation.y;\n" \ +" int y_c_a1 = y_c + stride.y;\n" \ +" y_c = y_c < 0 || y_c >= src_size.y ? src_size.y * src_size.z : y_c;\n" \ +" y_c_a1 = y_c_a1 < 0 || y_c_a1 >= src_size.y ? src_size.y * src_size.z : y_c_a1;\n" \ +" for (int kx = 0; kx < kernel_size.x; ++kx) {\n" \ +" int x_c = x_offset + kx * dilation.x;\n" \ +" int x_c_a1 = x_c + stride.x;\n" \ +" int x_sign = x_c < 0 ? -1 : 1;\n" \ +" FLT4 flt_p0 = filter[fx_c];\n" \ +" FLT4 flt_p1 = filter[fx_c + f_len];\n" \ +" {\n" \ +" FLT4 src_p00_c0 = READ_IMAGE(src_data, smp_zero, (int2)(Z * x_sign + x_c * src_size.z, y_c));\n" \ +" FLT4 src_p00_c1 =\n" \ +" READ_IMAGE(src_data, smp_zero, (int2)((Z + 1) * x_sign + x_c * src_size.z, y_c + batch * src_size.y));\n" \ +" r[0] += TO_FLT4(src_p00_c0 * flt_p0);\n" \ +" r[1] += TO_FLT4(src_p00_c1 * flt_p1);\n" \ +" }\n" \ +" {\n" \ +" FLT4 src_p10_c0 = READ_IMAGE(src_data, smp_zero, (int2)(Z + x_c * src_size.z, y_c_a1));\n" \ +" FLT4 src_p10_c1 = READ_IMAGE(src_data, smp_zero, (int2)(Z + 1 + x_c * src_size.z, y_c_a1 + batch * src_size.y));\n" \ +" r[2] += TO_FLT4(src_p10_c0 * flt_p0);\n" \ +" r[3] += TO_FLT4(src_p10_c1 * flt_p1);\n" \ +" }\n" \ +" fx_c++;\n" \ +" }\n" \ +" }\n" \ +" r[0] += bias[Z];\n" \ +" r[1] += bias[Z + 1];\n" \ +" r[2] += bias[Z];\n" \ +" r[3] += bias[Z + 1];\n" \ +" r[0] = clamp(r[0], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" r[1] = clamp(r[1], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" r[2] = clamp(r[2], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" r[3] = clamp(r[3], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z, Y), r[0]);\n" \ +" if (!last_c) {\n" \ +" WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z + 1, Y), r[1]);\n" \ +" }\n" \ +" if (!last_y) {\n" \ +" WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z, Y + 1), r[2]);\n" \ +" if (!last_c) {\n" \ +" WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z + 1, Y + 1), r[3]);\n" \ +" }\n" \ +" }\n" \ +"}\n" \ +"__kernel void DepthwiseConv2d_IMG_NHWC4_b211(__write_only image2d_t dst_data, __read_only image2d_t src_data,\n" \ +" __global FLT4 *filter, __global FLT4 *bias, int2 kernel_size, int2 stride,\n" \ +" int2 padding, int2 dilation, int4 src_size, int4 dst_size,\n" \ +" float relu_clip_min, float relu_clip_max) {\n" \ +" int X = get_global_id(1);\n" \ +" int heightOfBlock = dst_size.y + (dst_size.y & 0x1);\n" \ +" int batch = get_global_id(2) / (heightOfBlock >> 1);\n" \ +" int Y = (get_global_id(2) * 2 - heightOfBlock * batch) + batch * dst_size.y;\n" \ +" int Z = get_global_id(0);\n" \ +" if (X >= dst_size.x || Y >= (dst_size.y * dst_size.w) || Z >= dst_size.z) return;\n" \ +" FLT4 r[2] = {(FLT4)(0.0f, 0.0f, 0.0f, 0.0f), (FLT4)(0.0f, 0.0f, 0.0f, 0.0f)};\n" \ +" int x_offset = X * stride.x + padding.x;\n" \ +" int y_offset = (Y - batch * dst_size.y) * stride.y + padding.y;\n" \ +" int f_len = kernel_size.x * kernel_size.y;\n" \ +" int fx_c = Z * f_len;\n" \ +" bool last_x = (get_global_id(1) == (dst_size.x >> 1)) && ((dst_size.x & 0x1) == 1);\n" \ +" bool last_y =\n" \ +" ((get_global_id(2) - batch * (heightOfBlock >> 1)) == ((heightOfBlock >> 1) - 1)) && ((dst_size.y & 0x1) == 1);\n" \ +" for (int ky = 0; ky < kernel_size.y; ++ky) {\n" \ +" int y_c = y_offset + ky * dilation.y;\n" \ +" int y_c_a1 = y_c + stride.y;\n" \ +" y_c = y_c < 0 || y_c >= src_size.y ? src_size.y * src_size.w : y_c;\n" \ +" y_c_a1 = y_c_a1 < 0 || y_c_a1 >= src_size.y ? src_size.y * src_size.w : y_c_a1;\n" \ +" for (int kx = 0; kx < kernel_size.x; ++kx) {\n" \ +" int x_c = x_offset + kx * dilation.x;\n" \ +" int x_sign = x_c < 0 ? -1 : 1;\n" \ +" FLT4 flt_p0 = filter[fx_c];\n" \ +" FLT4 src_p00_c0 = READ_IMAGE(src_data, smp_zero, (int2)(Z * x_sign + x_c * src_size.z, y_c + batch * src_size.y));\n" \ +" r[0] += TO_FLT4(src_p00_c0 * flt_p0);\n" \ +" FLT4 src_p10_c0 = READ_IMAGE(src_data, smp_zero, (int2)(Z + x_c * src_size.z, y_c_a1 + batch * src_size.y));\n" \ +" r[1] += TO_FLT4(src_p10_c0 * flt_p0);\n" \ +"\n" \ +" fx_c++;\n" \ +" }\n" \ +" }\n" \ +" r[0] += bias[Z];\n" \ +" r[1] += bias[Z];\n" \ +" r[0] = clamp(r[0], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" r[1] = clamp(r[1], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z, Y), r[0]);\n" \ +" if (!last_y) {\n" \ +" WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z, Y + 1), r[1]);\n" \ +" }\n" \ +"}\n" \ +"__kernel void DepthwiseConv2d_IMG_NHWC4_b121(__write_only image2d_t dst_data, __read_only image2d_t src_data,\n" \ +" __global FLT4 *filter, __global FLT4 *bias, int2 kernel_size, int2 stride,\n" \ +" int2 padding, int2 dilation, int4 src_size, int4 dst_size,\n" \ +" float relu_clip_min, float relu_clip_max) {\n" \ +" int X = get_global_id(1) * 2;\n" \ +" int batch = get_global_id(2) / (dst_size.y);\n" \ +" int Y = (get_global_id(2) - dst_size.y * batch) + batch * dst_size.y;\n" \ +" int Z = get_global_id(0);\n" \ +" if (X >= dst_size.x || Y >= (dst_size.y * dst_size.w) || Z >= dst_size.z) return;\n" \ +" FLT4 r[4] = {(FLT4)(0.0f, 0.0f, 0.0f, 0.0f), (FLT4)(0.0f, 0.0f, 0.0f, 0.0f), (FLT4)(0.0f, 0.0f, 0.0f, 0.0f),\n" \ +" (FLT4)(0.0f, 0.0f, 0.0f, 0.0f)};\n" \ +" int x_offset = X * stride.x + padding.x;\n" \ +" int y_offset = (Y - batch * dst_size.y) * stride.y + padding.y;\n" \ +" int f_len = kernel_size.x * kernel_size.y;\n" \ +" int fx_c = Z * f_len;\n" \ +" bool last_x = (get_global_id(1) == (dst_size.x >> 1)) && ((dst_size.x & 0x1) == 1);\n" \ +" for (int ky = 0; ky < kernel_size.y; ++ky) {\n" \ +" int y_c = y_offset + ky * dilation.y;\n" \ +" y_c = y_c < 0 || y_c >= src_size.y ? src_size.y * src_size.w : y_c;\n" \ +" for (int kx = 0; kx < kernel_size.x; ++kx) {\n" \ +" int x_c = x_offset + kx * dilation.x;\n" \ +" int x_c_a1 = x_c + stride.x;\n" \ +" int x_sign = x_c < 0 ? -1 : 1;\n" \ +" int x_a1_sign = x_c_a1 < 0 ? -1 : 1;\n" \ +" FLT4 flt_p0 = filter[fx_c];\n" \ +" FLT4 src_p00_c0 = READ_IMAGE(src_data, smp_zero, (int2)(Z * x_sign + x_c * src_size.z, y_c + batch * src_size.y));\n" \ +" r[0] += TO_FLT4(src_p00_c0 * flt_p0);\n" \ +" FLT4 src_p01_c0 =\n" \ +" READ_IMAGE(src_data, smp_zero, (int2)(Z * x_a1_sign + x_c_a1 * src_size.z, y_c + batch * src_size.y));\n" \ +" r[1] += TO_FLT4(src_p01_c0 * flt_p0);\n" \ +"\n" \ +" fx_c++;\n" \ +" }\n" \ +" }\n" \ +" r[0] += bias[Z];\n" \ +" r[1] += bias[Z];\n" \ +" r[0] = clamp(r[0], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" r[1] = clamp(r[1], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z, Y), r[0]);\n" \ +" if (!last_x) {\n" \ +" WRITE_IMAGE(dst_data, (int2)((X + 1) * dst_size.z + Z, Y), r[1]);\n" \ +" }\n" \ +"}\n" \ +"__kernel void DepthwiseConv2d_IMG_NHWC4_b112(__write_only image2d_t dst_data, __read_only image2d_t src_data,\n" \ +" __global FLT4 *filter, __global FLT4 *bias, int2 kernel_size, int2 stride,\n" \ +" int2 padding, int2 dilation, int4 src_size, int4 dst_size,\n" \ +" float relu_clip_min, float relu_clip_max) {\n" \ +" int X = get_global_id(1);\n" \ +" int batch = get_global_id(2) / (dst_size.y);\n" \ +" int Y = (get_global_id(2) - dst_size.y * batch) + batch * dst_size.y;\n" \ +" int Z = get_global_id(0) * 2;\n" \ +" if (X >= dst_size.x || Y >= (dst_size.y * dst_size.w) || Z >= dst_size.z) return;\n" \ +" FLT4 r[4] = {(FLT4)(0.0f, 0.0f, 0.0f, 0.0f), (FLT4)(0.0f, 0.0f, 0.0f, 0.0f), (FLT4)(0.0f, 0.0f, 0.0f, 0.0f),\n" \ +" (FLT4)(0.0f, 0.0f, 0.0f, 0.0f)};\n" \ +" int x_offset = X * stride.x + padding.x;\n" \ +" int y_offset = (Y - batch * dst_size.y) * stride.y + padding.y;\n" \ +" int f_len = kernel_size.x * kernel_size.y;\n" \ +" int fx_c = Z * f_len;\n" \ +" bool last_x = (get_global_id(1) == (dst_size.x >> 1)) && ((dst_size.x & 0x1) == 1);\n" \ +" bool last_c = (get_global_id(0) == (dst_size.z >> 1)) && ((dst_size.z & 0x1) == 1);\n" \ +" for (int ky = 0; ky < kernel_size.y; ++ky) {\n" \ +" int y_c = y_offset + ky * dilation.y;\n" \ +" y_c = y_c < 0 || y_c >= src_size.y ? src_size.y * src_size.z : y_c;\n" \ +" for (int kx = 0; kx < kernel_size.x; ++kx) {\n" \ +" int x_c = x_offset + kx * dilation.x;\n" \ +" int x_c_a1 = x_c + stride.x;\n" \ +" int x_sign = x_c < 0 ? -1 : 1;\n" \ +" FLT4 flt_p0 = filter[fx_c];\n" \ +" FLT4 flt_p1 = filter[fx_c + f_len];\n" \ +" {\n" \ +" FLT4 src_p00_c0 = READ_IMAGE(src_data, smp_zero, (int2)(Z * x_sign + x_c * src_size.z, y_c));\n" \ +" FLT4 src_p00_c1 =\n" \ +" READ_IMAGE(src_data, smp_zero, (int2)((Z + 1) * x_sign + x_c * src_size.z, y_c + batch * src_size.y));\n" \ +" r[0] += TO_FLT4(src_p00_c0 * flt_p0);\n" \ +" r[1] += TO_FLT4(src_p00_c1 * flt_p1);\n" \ +" }\n" \ +" fx_c++;\n" \ +" }\n" \ +" }\n" \ +" r[0] += bias[Z];\n" \ +" r[1] += bias[Z + 1];\n" \ +" r[0] = clamp(r[0], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" r[1] = clamp(r[1], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z, Y), r[0]);\n" \ +" if (!last_c) {\n" \ +" WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z + 1, Y), r[1]);\n" \ +" }\n" \ +"}\n" \ +"__kernel void DepthwiseConv2d_IMG_NHWC4_1x1_b221(__write_only image2d_t dst_data, __read_only image2d_t src_data,\n" \ +" __global FLT4 *filter, __global FLT4 *bias, int2 kernel_size,\n" \ +" int2 stride, int2 padding, int2 dilation, int4 src_size, int4 dst_size,\n" \ +" float relu_clip_min, float relu_clip_max) {\n" \ +" int X = get_global_id(1) * 2;\n" \ +" int heightOfBlock = dst_size.y + (dst_size.y & 0x1);\n" \ +" int batch = get_global_id(2) / (heightOfBlock >> 1);\n" \ +" int Y = (get_global_id(2) * 2 - heightOfBlock * batch) + batch * dst_size.y;\n" \ +" int Z = get_global_id(0);\n" \ +" if (X >= dst_size.x || Y >= (dst_size.y * dst_size.w) || Z >= dst_size.z) return;\n" \ +" FLT4 r[4] = {(FLT4)(0.0f, 0.0f, 0.0f, 0.0f), (FLT4)(0.0f, 0.0f, 0.0f, 0.0f), (FLT4)(0.0f, 0.0f, 0.0f, 0.0f),\n" \ +" (FLT4)(0.0f, 0.0f, 0.0f, 0.0f)};\n" \ +" int x_offset = X * stride.x + padding.x;\n" \ +" int y_offset = (Y - batch * dst_size.y) * stride.y + padding.y;\n" \ +" int f_len = kernel_size.x * kernel_size.y;\n" \ +" int fx_c = Z * f_len;\n" \ +" bool last_x = (get_global_id(1) == (dst_size.x >> 1)) && ((dst_size.x & 0x1) == 1);\n" \ +" bool last_y =\n" \ +" ((get_global_id(2) - batch * (heightOfBlock >> 1)) == ((heightOfBlock >> 1) - 1)) && ((dst_size.y & 0x1) == 1);\n" \ +" int y_c = y_offset;\n" \ +" int y_c_a1 = y_c + stride.y;\n" \ +" int x_c = x_offset;\n" \ +" int x_c_a1 = x_c + stride.x;\n" \ +" int x_sign = x_c < 0 ? -1 : 1;\n" \ +" int x_a1_sign = x_c_a1 < 0 ? -1 : 1;\n" \ +" FLT4 flt_p0 = filter[fx_c];\n" \ +" y_c = y_c < 0 || y_c >= src_size.y ? src_size.y * src_size.z : y_c;\n" \ +" y_c_a1 = y_c_a1 < 0 || y_c_a1 >= src_size.y ? src_size.y * src_size.z : y_c_a1;\n" \ +" FLT4 src_p00_c0 = READ_IMAGE(src_data, smp_zero, (int2)(Z * x_sign + x_c * src_size.z, y_c + batch * src_size.y));\n" \ +" r[0] += TO_FLT4(src_p00_c0 * flt_p0);\n" \ +" FLT4 src_p01_c0 =\n" \ +" READ_IMAGE(src_data, smp_zero, (int2)(Z * x_a1_sign + x_c_a1 * src_size.z, y_c + batch * src_size.y));\n" \ +" r[1] += TO_FLT4(src_p01_c0 * flt_p0);\n" \ +" FLT4 src_p10_c0 = READ_IMAGE(src_data, smp_zero, (int2)(Z + x_c * src_size.z, y_c_a1 + batch * src_size.y));\n" \ +" r[2] += TO_FLT4(src_p10_c0 * flt_p0);\n" \ +" FLT4 src_p11_c0 =\n" \ +" READ_IMAGE(src_data, smp_zero, (int2)(Z * x_a1_sign + x_c_a1 * src_size.z, y_c_a1 + batch * src_size.y));\n" \ +" r[3] += TO_FLT4(src_p11_c0 * flt_p0);\n" \ +"\n" \ +" r[0] += bias[Z];\n" \ +" r[1] += bias[Z];\n" \ +" r[2] += bias[Z];\n" \ +" r[3] += bias[Z];\n" \ +" r[0] = clamp(r[0], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" r[1] = clamp(r[1], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" r[2] = clamp(r[2], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" r[3] = clamp(r[3], (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z, Y), r[0]);\n" \ +" if (!last_x) {\n" \ +" WRITE_IMAGE(dst_data, (int2)((X + 1) * dst_size.z + Z, Y), r[1]);\n" \ +" }\n" \ +" if (!last_y) {\n" \ +" WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z, Y + 1), r[2]);\n" \ +" }\n" \ +" if (!last_y && !last_x) {\n" \ +" WRITE_IMAGE(dst_data, (int2)((X + 1) * dst_size.z + Z, Y + 1), r[3]);\n" \ +" }\n" \ +"}\n" \ +"__kernel void DepthwiseConv2d_BUF_NC4HW4(__global FLT4 *dst_data, __global FLT4 *src_data, __global FLT4 *filter,\n" \ +" __global FLT4 *bias, int2 kernel_size, int2 stride, int2 padding,\n" \ +" int2 dilation, int4 src_size, int4 dst_size, float relu_clip_min,\n" \ +" float relu_clip_max) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" int Z = get_global_id(2);\n" \ +" if (X >= dst_size.x || Y >= (dst_size.y * dst_size.w) || Z >= dst_size.z) return;\n" \ +" FLT4 r = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" int x_offset = X * stride.x + padding.x;\n" \ +" int batch = Y / dst_size.y;\n" \ +" int y_offset = (Y - batch * dst_size.y) * stride.y + padding.y;\n" \ +" int fx_c = Z * kernel_size.x * kernel_size.y;\n" \ +" for (int ky = 0; ky < kernel_size.y; ++ky) {\n" \ +" int y_c = y_offset + ky * dilation.y;\n" \ +" bool outside_y = y_c < 0 || y_c >= src_size.y;\n" \ +" for (int kx = 0; kx < kernel_size.x; ++kx) {\n" \ +" int x_c = x_offset + kx * dilation.x;\n" \ +" bool outside_x = x_c < 0 || x_c >= src_size.x;\n" \ +" if (!outside_x && !outside_y) {\n" \ +" FLT4 flt_p = filter[fx_c];\n" \ +" FLT4 src_p = src_data[(((Z)*src_size.y + (y_c + batch * src_size.y)) * src_size.x + (x_c))];\n" \ +" r += TO_FLT4(src_p * flt_p);\n" \ +" }\n" \ +" fx_c++;\n" \ +" }\n" \ +" }\n" \ +" FLT4 bias_p = bias[Z];\n" \ +" FLT4 res = TO_FLT4(r) + bias_p;\n" \ +" res = clamp(res, (FLT)(relu_clip_min), (FLT)(relu_clip_max));\n" \ +" dst_data[(((Z)*dst_size.y + (Y)) * dst_size.x + (X))] = res;\n" \ +"}\n" \ +; diff --git a/mindspore-lite/src/runtime/kernel/opencl/cl/fullconnection.cl.inc b/mindspore-lite/src/runtime/kernel/opencl/cl/fullconnection.cl.inc new file mode 100644 index 0000000000000000000000000000000000000000..4ea90a2ce21127dd2481d328a4cd65fab351ef43 --- /dev/null +++ b/mindspore-lite/src/runtime/kernel/opencl/cl/fullconnection.cl.inc @@ -0,0 +1,82 @@ +static const char *fullconnection_source ="\n" +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" \ +"#define C4NUM 4\n" \ +"#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))\n" \ +"__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" \ +"\n" \ +"__kernel void FullConnection(__read_only image2d_t input, __write_only image2d_t output, __global FLT16 *weight,\n" \ +" __read_only image2d_t bias, int N, int CI4, int CO4, int2 in_img_shape, int act_type) {\n" \ +" int gidx = get_global_id(0); // CO4\n" \ +" int gidz = get_global_id(2); // N\n" \ +" int lidx = get_local_id(0);\n" \ +" int lidy = get_local_id(1);\n" \ +" bool inside = gidx < CO4 && gidz < N;\n" \ +" FLT4 result = (FLT4)(0.0f);\n" \ +" for (uint i = lidy; i < CI4 && inside; i += 4) {\n" \ +" int index = gidz * CI4 + i;\n" \ +" FLT4 v = READ_IMAGE(input, smp_zero, (int2)(index % in_img_shape.y, index / in_img_shape.y));\n" \ +" FLT16 w = weight[i * CO4 + gidx];\n" \ +" result.x += dot(v, w.s0123);\n" \ +" result.y += dot(v, w.s4567);\n" \ +" result.z += dot(v, w.s89ab);\n" \ +" result.w += dot(v, w.scdef);\n" \ +" }\n" \ +" __local FLT4 temp[32][4];\n" \ +" temp[lidx][lidy] = result;\n" \ +" barrier(CLK_LOCAL_MEM_FENCE);\n" \ +" if (lidy == 0 && inside) {\n" \ +" result += temp[lidx][1];\n" \ +" result += temp[lidx][2];\n" \ +" result += temp[lidx][3];\n" \ +" result += READ_IMAGE(bias, smp_zero, (int2)(gidx, 0));\n" \ +" if (act_type == ActivationType_RELU) {\n" \ +" result = max(result, (FLT4)(0.0f));\n" \ +" } else if (act_type == ActivationType_RELU6) {\n" \ +" result = clamp(result, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" } else if (act_type == ActivationType_TANH) {\n" \ +" result = tanh(result);\n" \ +" }\n" \ +" WRITE_IMAGE(output, (int2)(gidx, gidz), result);\n" \ +" }\n" \ +"}\n" \ +"\n" \ +"__kernel void FullConnectionWeightVar(__read_only image2d_t input, __write_only image2d_t output,\n" \ +" __read_only image2d_t weight, __read_only image2d_t bias, int N, int CI4, int CO4,\n" \ +" int2 in_img_shape, int act_type) {\n" \ +" int gidx = get_global_id(0); // CO4\n" \ +" int gidz = get_global_id(2); // N\n" \ +" int lidx = get_local_id(0);\n" \ +" int lidy = get_local_id(1);\n" \ +" bool inside = gidx < CO4 && gidz < N;\n" \ +" FLT4 result = (FLT4)(0.0f);\n" \ +" for (uint i = lidy; i < CI4 && inside; i += 4) {\n" \ +" int index = gidz * CI4 + i;\n" \ +" FLT4 v = READ_IMAGE(input, smp_zero, (int2)(index % in_img_shape.y, index / in_img_shape.y));\n" \ +" FLT4 weight0 = READ_IMAGE(weight, smp_zero, (int2)(i, gidx * 4));\n" \ +" result.x += dot(v, weight0);\n" \ +" FLT4 weight1 = READ_IMAGE(weight, smp_zero, (int2)(i, gidx * 4 + 1));\n" \ +" result.y += dot(v, weight1);\n" \ +" FLT4 weight2 = READ_IMAGE(weight, smp_zero, (int2)(i, gidx * 4 + 2));\n" \ +" result.z += dot(v, weight2);\n" \ +" FLT4 weight3 = READ_IMAGE(weight, smp_zero, (int2)(i, gidx * 4 + 3));\n" \ +" result.w += dot(v, weight3);\n" \ +" }\n" \ +" __local FLT4 temp[32][4];\n" \ +" temp[lidx][lidy] = result;\n" \ +" barrier(CLK_LOCAL_MEM_FENCE);\n" \ +" if (lidy == 0 && inside) {\n" \ +" result += temp[lidx][1];\n" \ +" result += temp[lidx][2];\n" \ +" result += temp[lidx][3];\n" \ +" result += READ_IMAGE(bias, smp_zero, (int2)(gidx, 0));\n" \ +" if (act_type == ActivationType_RELU) {\n" \ +" result = max(result, (FLT4)(0.0f));\n" \ +" } else if (act_type == ActivationType_RELU6) {\n" \ +" result = clamp(result, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" } else if (act_type == ActivationType_TANH) {\n" \ +" result = tanh(result);\n" \ +" }\n" \ +" WRITE_IMAGE(output, (int2)(gidx, gidz), result);\n" \ +" }\n" \ +"}\n" \ +; diff --git a/mindspore-lite/src/runtime/kernel/opencl/cl/gather.cl.inc b/mindspore-lite/src/runtime/kernel/opencl/cl/gather.cl.inc new file mode 100644 index 0000000000000000000000000000000000000000..a75447f8e279d5cf3824edb733af0c3431aabc90 --- /dev/null +++ b/mindspore-lite/src/runtime/kernel/opencl/cl/gather.cl.inc @@ -0,0 +1,45 @@ +static const char *gather_source ="\n" +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" \ +"#define C4NUM 4\n" \ +"__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" \ +"__kernel void gather(__write_only image2d_t dst_data, __read_only image2d_t src_data, __global int *indices,\n" \ +" int4 src_size, int4 dst_size, int indices_num, int axis) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" int Z = get_global_id(2);\n" \ +" if (X >= dst_size.x || Y >= dst_size.y * dst_size.w || Z >= dst_size.z || dst_size.y == 0) {\n" \ +" return;\n" \ +" }\n" \ +" DTYPE4 res_data = (DTYPE4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" int batch = Y / dst_size.y;\n" \ +" int height = Y % dst_size.y;\n" \ +" if (axis == 0) {\n" \ +" res_data = READ_IMAGE(src_data, smp_zero, (int2)(X * src_size.z + Z, indices[batch] * src_size.y + height));\n" \ +" } else if (axis == 1) {\n" \ +" res_data = READ_IMAGE(src_data, smp_zero, (int2)(X * src_size.z + Z, batch * src_size.y + indices[height]));\n" \ +" } else if (axis == 2) {\n" \ +" res_data = READ_IMAGE(src_data, smp_zero, (int2)(indices[X] * src_size.z + Z, batch * src_size.y + height));\n" \ +" } else if (axis == 3) {\n" \ +" int offset[4] = {indices[Z * 4] / 4, indices[Z * 4 + 1] / 4, indices[Z * 4 + 2] / 4, indices[Z * 4 + 3] / 4};\n" \ +" DTYPE tmp[4];\n" \ +" DTYPE res_tmp[4];\n" \ +" for (int i = 0; i < indices_num; ++i) {\n" \ +" DTYPE4 rd_data = (DTYPE4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" rd_data = READ_IMAGE(src_data, smp_zero, (int2)(X * src_size.z + offset[i], batch * src_size.y + height));\n" \ +" if (i >= 1 && offset[i] != offset[i - 1]) {\n" \ +" rd_data = READ_IMAGE(src_data, smp_zero, (int2)(X * src_size.z + offset[i], batch * src_size.y + height));\n" \ +" }\n" \ +" tmp[0] = rd_data.x;\n" \ +" tmp[1] = rd_data.y;\n" \ +" tmp[2] = rd_data.z;\n" \ +" tmp[3] = rd_data.w;\n" \ +" res_tmp[i] = tmp[indices[Z * 4 + i] % 4];\n" \ +" }\n" \ +" res_data.x = res_tmp[0];\n" \ +" res_data.y = res_tmp[1];\n" \ +" res_data.z = res_tmp[2];\n" \ +" res_data.w = res_tmp[3];\n" \ +" }\n" \ +" WRITE_IMAGE(dst_data, (int2)(X * dst_size.z + Z, batch * dst_size.y + height), res_data);\n" \ +"}\n" \ +; diff --git a/mindspore-lite/src/runtime/kernel/opencl/cl/int8/arithmetic.cl.inc b/mindspore-lite/src/runtime/kernel/opencl/cl/int8/arithmetic.cl.inc new file mode 100644 index 0000000000000000000000000000000000000000..4a5596dc5ff875089622cb23a54991855e8420a4 --- /dev/null +++ b/mindspore-lite/src/runtime/kernel/opencl/cl/int8/arithmetic.cl.inc @@ -0,0 +1,21 @@ +static const char *arithmetic_source ="\n" +"__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;\n" \ +"\n" \ +"__kernel void ElementAddInt8(__read_only image2d_t input_a, __read_only image2d_t input_b,\n" \ +" __write_only image2d_t output, const int2 output_shape, float act_min, float act_max,\n" \ +" const float4 scale, const char4 zero_point) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= output_shape.x || Y >= output_shape.y) {\n" \ +" return;\n" \ +" }\n" \ +" char4 a = convert_char4(read_imagei(input_a, smp_none, (int2)(X, Y)));\n" \ +" char4 b = convert_char4(read_imagei(input_b, smp_none, (int2)(X, Y)));\n" \ +"\n" \ +" float4 real_a = convert_float4(a - zero_point.x) * scale.x;\n" \ +" float4 real_b = convert_float4(b - zero_point.y) * scale.y;\n" \ +" int4 result = convert_int4(round((real_a + real_b) / scale.z)) + zero_point.z;\n" \ +" result = clamp(result, (FLT)(act_min), (FLT)(act_max));\n" \ +" write_imagei(output, (int2)(X, Y), result);\n" \ +"}\n" \ +; diff --git a/mindspore-lite/src/runtime/kernel/opencl/cl/layer_norm.cl.inc b/mindspore-lite/src/runtime/kernel/opencl/cl/layer_norm.cl.inc new file mode 100644 index 0000000000000000000000000000000000000000..4623be3652dda5d987fe6463f3d55ed3c0dbe123 --- /dev/null +++ b/mindspore-lite/src/runtime/kernel/opencl/cl/layer_norm.cl.inc @@ -0,0 +1,97 @@ +static const char *layer_norm_source ="\n" +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" \ +"__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" \ +"#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))\n" \ +"#define C4NUM 4\n" \ +"\n" \ +"__kernel void ComputeMeanVarAxis3NHWC4(__read_only image2d_t src_data, __global FLT *mean_, __global FLT *variance_,\n" \ +" int4 in_shape, int normalized_shape_size) {\n" \ +" int X = get_global_id(0); // n*h\n" \ +" int Y = get_global_id(1); // w\n" \ +" if (X > in_shape.x * in_shape.y || Y > in_shape.z || in_shape.y == 0 || normalized_shape_size == 0) {\n" \ +" return;\n" \ +" }\n" \ +" int n = X / in_shape.y;\n" \ +" int h = X % in_shape.y;\n" \ +" int w = Y;\n" \ +" int ci4 = UP_DIV(in_shape.w, C4NUM);\n" \ +" int remainder = in_shape.w % C4NUM;\n" \ +" FLT4 mean_temp = {0.0f, 0.0f, 0.0f, 0.0f};\n" \ +" FLT4 var_temp = {0.0f, 0.0f, 0.0f, 0.0f};\n" \ +" FLT mean = 0.0f;\n" \ +" FLT var = 0.0f;\n" \ +"\n" \ +" // compute mean\n" \ +" for (int i = 0; i < ci4; ++i) {\n" \ +" FLT4 result_temp = READ_IMAGE(src_data, smp_none, (int2)(w * ci4 + i, n * in_shape.y + h));\n" \ +" mean_temp += result_temp;\n" \ +" }\n" \ +" mean = (mean_temp.x + mean_temp.y + mean_temp.z + mean_temp.w) / normalized_shape_size;\n" \ +" mean_temp.x = mean_temp.y = mean_temp.z = mean_temp.w = mean;\n" \ +"\n" \ +" // compute var\n" \ +" for (int i = 0; i < ci4; ++i) {\n" \ +" FLT4 result_temp = READ_IMAGE(src_data, smp_none, (int2)(w * ci4 + i, n * in_shape.y + h));\n" \ +" if ((i + 1) * C4NUM <= in_shape.w) {\n" \ +" var_temp += (result_temp - mean_temp) * (result_temp - mean_temp);\n" \ +" } else {\n" \ +" if (remainder == 1) {\n" \ +" mean_temp.x = mean;\n" \ +" mean_temp.y = mean_temp.z = mean_temp.w = 0.0f;\n" \ +" } else if (remainder == 2) {\n" \ +" mean_temp.x = mean_temp.y = mean;\n" \ +" mean_temp.z = mean_temp.w = 0.0f;\n" \ +" } else {\n" \ +" mean_temp.x = mean_temp.y = mean_temp.z = mean;\n" \ +" mean_temp.w = 0.0f;\n" \ +" }\n" \ +" var_temp += (result_temp - mean_temp) * (result_temp - mean_temp);\n" \ +" }\n" \ +" }\n" \ +" var = (var_temp.x + var_temp.y + var_temp.z + var_temp.w) / normalized_shape_size;\n" \ +"\n" \ +" // write result to dst\n" \ +" int position = (n * in_shape.y + h) * in_shape.z + w;\n" \ +" mean_[position] = mean;\n" \ +" variance_[position] = var;\n" \ +"}\n" \ +"\n" \ +"__kernel void LayerNormalization_NHWC4(__read_only image2d_t src_data, __write_only image2d_t dst_data,\n" \ +" __global FLT *mean_, __global FLT *variance_, __global FLT *gamma_,\n" \ +" __global FLT *beta_, int4 in_shape, float epsilon_, int begin_params_axis_) {\n" \ +" int X = get_global_id(0); // n*h\n" \ +" int Y = get_global_id(1); // w\n" \ +" int Z = get_global_id(2); // c4\n" \ +" if (X >= in_shape.x * in_shape.y || Y >= in_shape.z || Z >= in_shape.w || in_shape.y == 0) {\n" \ +" return;\n" \ +" }\n" \ +" int n = X / in_shape.y;\n" \ +" int h = X % in_shape.y;\n" \ +" int w = Y;\n" \ +" int c = Z;\n" \ +" int ci4 = UP_DIV(in_shape.w, C4NUM);\n" \ +" int postion_mv = 0;\n" \ +" int postion_gb = 0;\n" \ +" if (begin_params_axis_ == 1) {\n" \ +" postion_mv = n;\n" \ +" postion_gb = (h * in_shape.z + w) * ci4 * C4NUM + c * C4NUM;\n" \ +" } else if (begin_params_axis_ == 2) {\n" \ +" postion_mv = n * in_shape.y + h;\n" \ +" postion_gb = w * ci4 * C4NUM + c * C4NUM;\n" \ +" } else if (begin_params_axis_ == 3) {\n" \ +" postion_mv = (n * in_shape.y + h) * in_shape.z + w;\n" \ +" postion_gb = c * C4NUM;\n" \ +" }\n" \ +" FLT4 result = {0.0f, 0.0f, 0.0f, 0.0f};\n" \ +" FLT4 result_in = READ_IMAGE(src_data, smp_none, (int2)(w * ci4 + c, n * in_shape.y + h));\n" \ +" result.x = ((result_in.x - mean_[postion_mv]) / sqrt(variance_[postion_mv] + epsilon_)) * gamma_[postion_gb] +\n" \ +" beta_[postion_gb];\n" \ +" result.y = ((result_in.y - mean_[postion_mv]) / sqrt(variance_[postion_mv] + epsilon_)) * gamma_[postion_gb + 1] +\n" \ +" beta_[postion_gb + 1];\n" \ +" result.z = ((result_in.z - mean_[postion_mv]) / sqrt(variance_[postion_mv] + epsilon_)) * gamma_[postion_gb + 2] +\n" \ +" beta_[postion_gb + 2];\n" \ +" result.w = ((result_in.w - mean_[postion_mv]) / sqrt(variance_[postion_mv] + epsilon_)) * gamma_[postion_gb + 3] +\n" \ +" beta_[postion_gb + 3];\n" \ +" WRITE_IMAGE(dst_data, (int2)((w * ci4 + c), (n * in_shape.y + h)), result);\n" \ +"}\n" \ +; diff --git a/mindspore-lite/src/runtime/kernel/opencl/cl/matmul.cl.inc b/mindspore-lite/src/runtime/kernel/opencl/cl/matmul.cl.inc new file mode 100644 index 0000000000000000000000000000000000000000..66b1fddd690b30579a8b959e609491da59f0d641 --- /dev/null +++ b/mindspore-lite/src/runtime/kernel/opencl/cl/matmul.cl.inc @@ -0,0 +1,149 @@ +static const char *matmul_source ="\n" +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" \ +"#define C4NUM 4\n" \ +"#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))\n" \ +"__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" \ +"__kernel void MatMul_2d(__read_only image2d_t input, __write_only image2d_t output, __global FLT16 *weight,\n" \ +" __read_only image2d_t bias, int4 in_shape, int4 out_shape) {\n" \ +" int gidx = get_global_id(0); // CO4\n" \ +" int gidz = get_global_id(2); // N\n" \ +" int lidx = get_local_id(0);\n" \ +" int lidy = get_local_id(1);\n" \ +" int ci4 = UP_DIV(in_shape.w, C4NUM);\n" \ +" int co4 = UP_DIV(out_shape.w, C4NUM);\n" \ +" int n = out_shape.z;\n" \ +" bool inside = gidx < co4 && gidz < n;\n" \ +" FLT4 result = (FLT4)(0.0f);\n" \ +" for (uint i = lidy; i < ci4 && inside; i += 4) {\n" \ +" FLT4 v = READ_IMAGE(input, smp_zero, (int2)(i, gidz));\n" \ +" FLT16 w = weight[i * co4 + gidx];\n" \ +" result.x += dot(v, w.s0123);\n" \ +" result.y += dot(v, w.s4567);\n" \ +" result.z += dot(v, w.s89ab);\n" \ +" result.w += dot(v, w.scdef);\n" \ +" }\n" \ +" __local FLT4 temp[32][4];\n" \ +" temp[lidx][lidy] = result;\n" \ +" barrier(CLK_LOCAL_MEM_FENCE);\n" \ +" if (lidy == 0 && inside) {\n" \ +" result += temp[lidx][1];\n" \ +" result += temp[lidx][2];\n" \ +" result += temp[lidx][3];\n" \ +" result += READ_IMAGE(bias, smp_zero, (int2)(gidx, 0));\n" \ +" WRITE_IMAGE(output, (int2)(gidx, gidz), result);\n" \ +" }\n" \ +"}\n" \ +"\n" \ +"__kernel void MatMul_4d(__read_only image2d_t input, __write_only image2d_t output, __global FLT16 *weight,\n" \ +" __read_only image2d_t bias, int4 in_shape, int4 out_shape) {\n" \ +" int gidx = get_global_id(0); // CO4\n" \ +" int gidy = get_global_id(1); // N * H * 4\n" \ +" int gidz = get_global_id(2); // W\n" \ +" int lidx = get_local_id(0);\n" \ +" int lidy = get_local_id(1);\n" \ +" int ci4 = UP_DIV(in_shape.w, C4NUM);\n" \ +" int co4 = UP_DIV(out_shape.w, C4NUM);\n" \ +" int n = out_shape.x;\n" \ +" int h = out_shape.y;\n" \ +" int w = out_shape.z;\n" \ +" int nh_index = gidy / 4;\n" \ +" bool inside = gidx < co4 && gidz < w && nh_index < n * h;\n" \ +" FLT4 result = (FLT4)(0.0f);\n" \ +" for (uint i = lidy; i < ci4 && inside; i += 4) {\n" \ +" FLT4 v = READ_IMAGE(input, smp_zero, (int2)(gidz * ci4 + i, nh_index));\n" \ +" FLT16 weight_value = weight[nh_index * ci4 * co4 + i * co4 + gidx];\n" \ +" result.x += dot(v, weight_value.s0123);\n" \ +" result.y += dot(v, weight_value.s4567);\n" \ +" result.z += dot(v, weight_value.s89ab);\n" \ +" result.w += dot(v, weight_value.scdef);\n" \ +" }\n" \ +" __local FLT4 temp[32][4];\n" \ +" temp[lidx][lidy] = result;\n" \ +" barrier(CLK_LOCAL_MEM_FENCE);\n" \ +" if (lidy == 0 && inside) {\n" \ +" result += temp[lidx][1];\n" \ +" result += temp[lidx][2];\n" \ +" result += temp[lidx][3];\n" \ +" result += READ_IMAGE(bias, smp_zero, (int2)(gidx, 0));\n" \ +" WRITE_IMAGE(output, (int2)(gidz * co4 + gidx, nh_index), result);\n" \ +" }\n" \ +"}\n" \ +"\n" \ +"__kernel void MatMulActWeightTransposeB_4d(__read_only image2d_t input, __write_only image2d_t output,\n" \ +" __read_only image2d_t weight, __read_only image2d_t bias, int4 in_shape,\n" \ +" int4 out_shape) {\n" \ +" int gidx = get_global_id(0); // CO4\n" \ +" int gidy = get_global_id(1); // N * H * 4\n" \ +" int gidz = get_global_id(2); // W\n" \ +" int lidx = get_local_id(0);\n" \ +" int lidy = get_local_id(1);\n" \ +" int ci4 = UP_DIV(in_shape.w, C4NUM);\n" \ +" int co4 = UP_DIV(out_shape.w, C4NUM);\n" \ +" int n = out_shape.x;\n" \ +" int h = out_shape.y;\n" \ +" int w = out_shape.z;\n" \ +" int nh_index = gidy / 4;\n" \ +" bool inside = gidx < co4 && gidz < w && nh_index < n * h;\n" \ +" FLT4 result = (FLT4)(0.0f);\n" \ +" for (uint i = lidy; i < ci4 && inside; i += 4) {\n" \ +" FLT4 v = READ_IMAGE(input, smp_zero, (int2)(gidz * ci4 + i, nh_index));\n" \ +" FLT4 weight_value0 = READ_IMAGE(weight, smp_zero, (int2)(gidx * 4 * ci4 + i, nh_index));\n" \ +" result.x += dot(v, weight_value0);\n" \ +" FLT4 weight_value1 = READ_IMAGE(weight, smp_zero, (int2)((gidx * 4 + 1) * ci4 + i, nh_index));\n" \ +" result.y += dot(v, weight_value1);\n" \ +" FLT4 weight_value2 = READ_IMAGE(weight, smp_zero, (int2)((gidx * 4 + 2) * ci4 + i, nh_index));\n" \ +" result.z += dot(v, weight_value2);\n" \ +" FLT4 weight_value3 = READ_IMAGE(weight, smp_zero, (int2)((gidx * 4 + 3) * ci4 + i, nh_index));\n" \ +" result.w += dot(v, weight_value3);\n" \ +" }\n" \ +" __local FLT4 temp[32][4];\n" \ +" temp[lidx][lidy] = result;\n" \ +" barrier(CLK_LOCAL_MEM_FENCE);\n" \ +" if (lidy == 0 && inside) {\n" \ +" result += temp[lidx][1];\n" \ +" result += temp[lidx][2];\n" \ +" result += temp[lidx][3];\n" \ +" result += READ_IMAGE(bias, smp_zero, (int2)(gidx, 0));\n" \ +" WRITE_IMAGE(output, (int2)(gidz * co4 + gidx, nh_index), result);\n" \ +" }\n" \ +"}\n" \ +"\n" \ +"__kernel void MatMulActWeight_4d(__read_only image2d_t input, __write_only image2d_t output,\n" \ +" __read_only image2d_t weight, __read_only image2d_t bias, int4 in_shape,\n" \ +" int4 out_shape) {\n" \ +" int gidx = get_global_id(0); // CO4\n" \ +" int gidy = get_global_id(1); // N * H * 4\n" \ +" int gidz = get_global_id(2); // W\n" \ +" int lidx = get_local_id(0);\n" \ +" int lidy = get_local_id(1);\n" \ +" int ci4 = UP_DIV(in_shape.w, C4NUM);\n" \ +" int co4 = UP_DIV(out_shape.w, C4NUM);\n" \ +" int n = out_shape.x;\n" \ +" int h = out_shape.y;\n" \ +" int w = out_shape.z;\n" \ +" int nh_index = gidy / 4;\n" \ +" bool inside = gidx < co4 && gidz < w && nh_index < n * h;\n" \ +" FLT4 result = (FLT4)(0.0f);\n" \ +" for (uint i = lidy; i < ci4 && inside; i += 4) {\n" \ +" FLT4 v = READ_IMAGE(input, smp_zero, (int2)(gidz * ci4 + i, nh_index));\n" \ +" FLT4 weight_value0 = READ_IMAGE(weight, smp_zero, (int2)(i * 4 * co4 + gidx, nh_index));\n" \ +" result += v.x * weight_value0;\n" \ +" FLT4 weight_value1 = READ_IMAGE(weight, smp_zero, (int2)((i * 4 + 1) * co4 + gidx, nh_index));\n" \ +" result += v.y * weight_value1;\n" \ +" FLT4 weight_value2 = READ_IMAGE(weight, smp_zero, (int2)((i * 4 + 2) * co4 + gidx, nh_index));\n" \ +" result += v.z * weight_value2;\n" \ +" FLT4 weight_value3 = READ_IMAGE(weight, smp_zero, (int2)((i * 4 + 3) * co4 + gidx, nh_index));\n" \ +" result += v.w * weight_value3;\n" \ +" }\n" \ +" __local FLT4 temp[32][4];\n" \ +" temp[lidx][lidy] = result;\n" \ +" barrier(CLK_LOCAL_MEM_FENCE);\n" \ +" if (lidy == 0 && inside) {\n" \ +" result += temp[lidx][1];\n" \ +" result += temp[lidx][2];\n" \ +" result += temp[lidx][3];\n" \ +" result += READ_IMAGE(bias, smp_zero, (int2)(gidx, 0));\n" \ +" WRITE_IMAGE(output, (int2)(gidz * co4 + gidx, nh_index), result);\n" \ +" }\n" \ +"}\n" \ +; diff --git a/mindspore-lite/src/runtime/kernel/opencl/cl/one_hot.cl.inc b/mindspore-lite/src/runtime/kernel/opencl/cl/one_hot.cl.inc new file mode 100644 index 0000000000000000000000000000000000000000..2307df03a0c97e61f2d705f9bd25ce0749a79f56 --- /dev/null +++ b/mindspore-lite/src/runtime/kernel/opencl/cl/one_hot.cl.inc @@ -0,0 +1,175 @@ +static const char *one_hot_source ="\n" +"#ifdef cl_khr_fp16\n" \ +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" \ +"#endif\n" \ +"\n" \ +"#define C4NUM 4\n" \ +"#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))\n" \ +"\n" \ +"#define SET_ON_OR_OFF_VALUE(RESULT, POSITION, INDICES, ON_VALUE, OFF_VALUE) \\\n" \ +" if (POSITION == INDICES) { \\\n" \ +" RESULT = (float)(ON_VALUE); \\\n" \ +" } else { \\\n" \ +" RESULT = (float)(OFF_VALUE); \\\n" \ +" }\n" \ +"\n" \ +"__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" \ +"__kernel void OneHotAxis0(__read_only image2d_t src_data, __write_only image2d_t dst_data, int2 in_image2d_shape,\n" \ +" int4 out_shape, int depth, float on_value, float off_value, int C, int support_neg_index) {\n" \ +" int X = get_global_id(0); // C4\n" \ +" int Y = get_global_id(1); // W\n" \ +" int Z = get_global_id(2); // H * N\n" \ +" if (X >= out_shape.w || Y >= out_shape.z || Z >= out_shape.x * out_shape.y) return;\n" \ +" int N = Z / out_shape.y;\n" \ +" int H = Z % out_shape.y;\n" \ +" int in_index = (H * out_shape.z + Y) * out_shape.w + X;\n" \ +" int4 indices = READ_IMAGE(src_data, smp_zero, (int2)(in_index % in_image2d_shape.x, in_index / in_image2d_shape.x));\n" \ +" int *indices_int = (int *)&indices;\n" \ +" for (int i = 0; i < C4NUM; i++) {\n" \ +" if (support_neg_index != 0 && indices_int[i] < 0) {\n" \ +" indices_int[i] += depth;\n" \ +" }\n" \ +" }\n" \ +" float4 result = (float4)(0.f);\n" \ +" if (4 * X < C) {\n" \ +" SET_ON_OR_OFF_VALUE(result.x, N, indices_int[0], on_value, off_value);\n" \ +" }\n" \ +" if (4 * X + 1 < C) {\n" \ +" SET_ON_OR_OFF_VALUE(result.y, N, indices_int[1], on_value, off_value);\n" \ +" }\n" \ +" if (4 * X + 2 < C) {\n" \ +" SET_ON_OR_OFF_VALUE(result.z, N, indices_int[2], on_value, off_value);\n" \ +" }\n" \ +" if (4 * X + 3 < C) {\n" \ +" SET_ON_OR_OFF_VALUE(result.w, N, indices_int[3], on_value, off_value);\n" \ +" }\n" \ +" WRITE_IMAGE(dst_data, (int2)(Y * out_shape.w + X, Z), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void OneHotAxis1(__read_only image2d_t src_data, __write_only image2d_t dst_data, int2 in_image2d_shape,\n" \ +" int4 out_shape, int depth, float on_value, float off_value, int C, int support_neg_index) {\n" \ +" int X = get_global_id(0); // C4\n" \ +" int Y = get_global_id(1); // W\n" \ +" int Z = get_global_id(2); // H * N\n" \ +" if (X >= out_shape.w || Y >= out_shape.z || Z >= out_shape.x * out_shape.y) return;\n" \ +" int N = Z / out_shape.y;\n" \ +" int H = Z % out_shape.y;\n" \ +" int in_index = (N * out_shape.z + Y) * out_shape.w + X;\n" \ +" int4 indices = READ_IMAGE(src_data, smp_zero, (int2)(in_index % in_image2d_shape.x, in_index / in_image2d_shape.x));\n" \ +" int *indices_int = (int *)&indices;\n" \ +" for (int i = 0; i < C4NUM; i++) {\n" \ +" if (support_neg_index != 0 && indices_int[i] < 0) {\n" \ +" indices_int[i] += depth;\n" \ +" }\n" \ +" }\n" \ +" float4 result = (float4)(0.f);\n" \ +" if (4 * X < C) {\n" \ +" SET_ON_OR_OFF_VALUE(result.x, H, indices_int[0], on_value, off_value);\n" \ +" }\n" \ +" if (4 * X + 1 < C) {\n" \ +" SET_ON_OR_OFF_VALUE(result.y, H, indices_int[1], on_value, off_value);\n" \ +" }\n" \ +" if (4 * X + 2 < C) {\n" \ +" SET_ON_OR_OFF_VALUE(result.z, H, indices_int[2], on_value, off_value);\n" \ +" }\n" \ +" if (4 * X + 3 < C) {\n" \ +" SET_ON_OR_OFF_VALUE(result.w, H, indices_int[3], on_value, off_value);\n" \ +" }\n" \ +" WRITE_IMAGE(dst_data, (int2)(Y * out_shape.w + X, Z), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void OneHotAxis2(__read_only image2d_t src_data, __write_only image2d_t dst_data, int2 in_image2d_shape,\n" \ +" int4 out_shape, int depth, float on_value, float off_value, int C, int support_neg_index) {\n" \ +" int X = get_global_id(0); // C4\n" \ +" int Y = get_global_id(1); // W\n" \ +" int Z = get_global_id(2); // H * N\n" \ +" if (X >= out_shape.w || Y >= out_shape.z || Z >= out_shape.x * out_shape.y) return;\n" \ +" int N = Z / out_shape.y;\n" \ +" int H = Z % out_shape.y;\n" \ +" int in_index = (N * out_shape.y + H) * out_shape.w + X;\n" \ +" int4 indices = READ_IMAGE(src_data, smp_zero, (int2)(in_index % in_image2d_shape.x, in_index / in_image2d_shape.x));\n" \ +" int *indices_int = (int *)&indices;\n" \ +" for (int i = 0; i < C4NUM; i++) {\n" \ +" if (support_neg_index != 0 && indices_int[i] < 0) {\n" \ +" indices_int[i] += depth;\n" \ +" }\n" \ +" }\n" \ +" float4 result = (float4)(0.f);\n" \ +" if (4 * X < C) {\n" \ +" SET_ON_OR_OFF_VALUE(result.x, Y, indices_int[0], on_value, off_value);\n" \ +" }\n" \ +" if (4 * X + 1 < C) {\n" \ +" SET_ON_OR_OFF_VALUE(result.y, Y, indices_int[1], on_value, off_value);\n" \ +" }\n" \ +" if (4 * X + 2 < C) {\n" \ +" SET_ON_OR_OFF_VALUE(result.z, Y, indices_int[2], on_value, off_value);\n" \ +" }\n" \ +" if (4 * X + 3 < C) {\n" \ +" SET_ON_OR_OFF_VALUE(result.w, Y, indices_int[3], on_value, off_value);\n" \ +" }\n" \ +" WRITE_IMAGE(dst_data, (int2)(Y * out_shape.w + X, Z), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void OneHotAxis3(__read_only image2d_t src_data, __write_only image2d_t dst_data, int2 in_image2d_shape,\n" \ +" int4 out_shape, int depth, float on_value, float off_value, int C, int support_neg_index) {\n" \ +" int X = get_global_id(0); // C4\n" \ +" int Y = get_global_id(1); // W\n" \ +" int Z = get_global_id(2); // H * N\n" \ +" if (X >= out_shape.w || Y >= out_shape.z || Z >= out_shape.x * out_shape.y) return;\n" \ +" int N = Z / out_shape.y;\n" \ +" int H = Z % out_shape.y;\n" \ +" int ci4_size = UP_DIV(out_shape.z, C4NUM);\n" \ +" int in_index_c4 = (N * out_shape.y + H) * ci4_size + Y / 4;\n" \ +" int in_index_c4_remainder = Y % 4;\n" \ +" int4 indices =\n" \ +" READ_IMAGE(src_data, smp_zero, (int2)(in_index_c4 % in_image2d_shape.x, in_index_c4 / in_image2d_shape.x));\n" \ +" int *indices_int = (int *)&indices;\n" \ +" int index_one = indices_int[in_index_c4_remainder];\n" \ +" if (support_neg_index != 0 && index_one < 0) {\n" \ +" index_one += depth;\n" \ +" }\n" \ +" float4 result = (float4)(0.f);\n" \ +" if (4 * X < C) {\n" \ +" SET_ON_OR_OFF_VALUE(result.x, 4 * X, index_one, on_value, off_value);\n" \ +" }\n" \ +" if (4 * X + 1 < C) {\n" \ +" SET_ON_OR_OFF_VALUE(result.y, 4 * X + 1, index_one, on_value, off_value);\n" \ +" }\n" \ +" if (4 * X + 2 < C) {\n" \ +" SET_ON_OR_OFF_VALUE(result.z, 4 * X + 2, index_one, on_value, off_value);\n" \ +" }\n" \ +" if (4 * X + 3 < C) {\n" \ +" SET_ON_OR_OFF_VALUE(result.w, 4 * X + 3, index_one, on_value, off_value);\n" \ +" }\n" \ +" WRITE_IMAGE(dst_data, (int2)(Y * out_shape.w + X, Z), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void OneHot2DAxis3(__read_only image2d_t src_data, __write_only image2d_t dst_data, int2 in_image2d_shape,\n" \ +" int4 out_shape, int depth, float on_value, float off_value, int C, int support_neg_index) {\n" \ +" int X = get_global_id(0); // C4\n" \ +" int Y = get_global_id(1); // W (out_shape.w is 1, Y is always 0)\n" \ +" int Z = get_global_id(2); // H * N (out_shape.h is 1, so N == Z)\n" \ +" if (X >= out_shape.w || Y >= out_shape.z || Z >= out_shape.x * out_shape.y) return;\n" \ +" int in_index_c4_remainder = Z % 4;\n" \ +" int4 indices = READ_IMAGE(src_data, smp_zero, (int2)(Z / C4NUM, 0));\n" \ +" int *indices_int = (int *)&indices;\n" \ +" int index_one = indices_int[in_index_c4_remainder];\n" \ +" if (support_neg_index != 0 && index_one < 0) {\n" \ +" index_one += depth;\n" \ +" }\n" \ +" float4 result = (float4)(0.f);\n" \ +" if (4 * X < C) {\n" \ +" SET_ON_OR_OFF_VALUE(result.x, 4 * X, index_one, on_value, off_value);\n" \ +" }\n" \ +" if (4 * X + 1 < C) {\n" \ +" SET_ON_OR_OFF_VALUE(result.y, 4 * X + 1, index_one, on_value, off_value);\n" \ +" }\n" \ +" if (4 * X + 2 < C) {\n" \ +" SET_ON_OR_OFF_VALUE(result.z, 4 * X + 2, index_one, on_value, off_value);\n" \ +" }\n" \ +" if (4 * X + 3 < C) {\n" \ +" SET_ON_OR_OFF_VALUE(result.w, 4 * X + 3, index_one, on_value, off_value);\n" \ +" }\n" \ +" WRITE_IMAGE(dst_data, (int2)(Y * out_shape.w + X, Z), result);\n" \ +"}\n" \ +; diff --git a/mindspore-lite/src/runtime/kernel/opencl/cl/pad.cl.inc b/mindspore-lite/src/runtime/kernel/opencl/cl/pad.cl.inc new file mode 100644 index 0000000000000000000000000000000000000000..ec9e3b0fc5082f21defe05562aa33d6eccfa7327 --- /dev/null +++ b/mindspore-lite/src/runtime/kernel/opencl/cl/pad.cl.inc @@ -0,0 +1,59 @@ +static const char *pad_source ="\n" +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" \ +"\n" \ +"__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" \ +"\n" \ +"__kernel void Pad(__read_only image2d_t input, __write_only image2d_t output, int4 input_shape, int4 output_shape,\n" \ +" int2 io_slices, int4 pad_before, float constant_value) {\n" \ +" int IN = input_shape.x, IH = input_shape.y, IW = input_shape.z, CI = input_shape.w;\n" \ +" int ON = output_shape.x, OH = output_shape.y, OW = output_shape.z, CO = output_shape.w;\n" \ +" int CI_SLICES = io_slices.x, CO_SLICES = io_slices.y;\n" \ +" int on_oh = get_global_id(0);\n" \ +" int ow = get_global_id(1);\n" \ +" int co_slice = get_global_id(2);\n" \ +" int on = on_oh / OH;\n" \ +" int oh = on_oh % OH;\n" \ +" if (on >= ON || oh >= OH || ow >= OW || co_slice >= CO_SLICES) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" int in = on - pad_before.x;\n" \ +" int ih = oh - pad_before.y;\n" \ +" int iw = ow - pad_before.z;\n" \ +" int ci = co_slice * 4 - pad_before.w;\n" \ +" if (in < 0 || in >= IN || ih < 0 || ih >= IH || iw < 0 || iw >= IW || ci + 3 < 0 || ci >= CI) {\n" \ +" WRITE_IMAGE(output, (int2)(ow * CO_SLICES + co_slice, on_oh), (FLT4)(constant_value));\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" int offset = ci % 4;\n" \ +" if (offset < 0) {\n" \ +" offset += 4;\n" \ +" }\n" \ +" FLT4 src0 = READ_IMAGE(input, smp_zero, (int2)(iw * CI_SLICES + ci / 4, in * IH + ih));\n" \ +" if (offset == 0 && ci >= 0 && ci + 3 < CI) {\n" \ +" WRITE_IMAGE(output, (int2)(ow * CO_SLICES + co_slice, on_oh), src0);\n" \ +" return;\n" \ +" }\n" \ +" FLT4 src1 = READ_IMAGE(input, smp_zero, (int2)(iw * CI_SLICES + (ci + 4) / 4, in * IH + ih));\n" \ +" FLT4 src_f4;\n" \ +" if (offset == 0) {\n" \ +" src_f4 = (FLT4)(src0.x, src0.y, src0.z, src0.w);\n" \ +" } else if (offset == 1) {\n" \ +" src_f4 = (FLT4)(src0.y, src0.z, src0.w, src1.x);\n" \ +" } else if (offset == 2) {\n" \ +" src_f4 = (FLT4)(src0.z, src0.w, src1.x, src1.y);\n" \ +" } else { // if (offset==3)\n" \ +" src_f4 = (FLT4)(src0.w, src1.x, src1.y, src1.z);\n" \ +" }\n" \ +" FLT src[4] = {src_f4.x, src_f4.y, src_f4.z, src_f4.w};\n" \ +" FLT out[4] = {constant_value, constant_value, constant_value, constant_value};\n" \ +" for (int i = 0; i < 4; ++i) {\n" \ +" if (ci + i >= 0 && ci + i < CI) {\n" \ +" out[i] = src[i];\n" \ +" }\n" \ +" }\n" \ +" FLT4 out_f4 = (FLT4)(out[0], out[1], out[2], out[3]);\n" \ +" WRITE_IMAGE(output, (int2)(ow * CO_SLICES + co_slice, on_oh), out_f4);\n" \ +"}\n" \ +; diff --git a/mindspore-lite/src/runtime/kernel/opencl/cl/pooling2d.cl.inc b/mindspore-lite/src/runtime/kernel/opencl/cl/pooling2d.cl.inc new file mode 100644 index 0000000000000000000000000000000000000000..05cfaf78123e7cfb7991b8843382c47297840de5 --- /dev/null +++ b/mindspore-lite/src/runtime/kernel/opencl/cl/pooling2d.cl.inc @@ -0,0 +1,134 @@ +static const char *pooling2d_source ="\n" +"#ifdef cl_khr_fp16\n" \ +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" \ +"#endif\n" \ +"#define divide_no_check(a, b) (a / b)\n" \ +"__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" \ +"__kernel void AvgPooling2d_NHWC4_IMG(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape,\n" \ +" const int4 output_shape, const int2 stride, const int2 kernel_size,\n" \ +" const int2 padding) {\n" \ +" // axis to dst tensor coordinate\n" \ +" int X = get_global_id(2); // N*H\n" \ +" int Y = get_global_id(1); // W\n" \ +" int Z = get_global_id(0); // C4\n" \ +" int N = X / output_shape.y;\n" \ +" X = X % output_shape.y;\n" \ +" // boundary check\n" \ +" if (N >= output_shape.x || X >= output_shape.y || Y >= output_shape.z || Z >= output_shape.w) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" FLT4 r = (FLT4)(0.0f);\n" \ +" FLT window_size = 0.0f;\n" \ +" int xs = X * stride.x - padding.x;\n" \ +" int ys = Y * stride.y - padding.y;\n" \ +"\n" \ +" for (int ky = 0; ky < kernel_size.y; ++ky) {\n" \ +" int y_c = ys + ky;\n" \ +" bool outside_y = y_c < 0 || y_c >= input_shape.z;\n" \ +" for (int kx = 0; kx < kernel_size.x; ++kx) {\n" \ +" int x_c = xs + kx;\n" \ +" bool outside = outside_y || x_c < 0 || x_c >= input_shape.y;\n" \ +" r +=\n" \ +" !outside ? READ_IMAGE(input, smp_zero, (int2)(y_c * input_shape.w + Z, N * input_shape.y + x_c)) : (FLT4)(0.0f);\n" \ +" window_size += !outside ? 1.0f : 0.0f;\n" \ +" }\n" \ +" }\n" \ +" FLT4 result = TO_FLT4(divide_no_check(r, window_size));\n" \ +" WRITE_IMAGE(output, (int2)(Y * output_shape.w + Z, N * output_shape.y + X), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void AvgPooling2d_ReLU_NHWC4_IMG(__read_only image2d_t input, __write_only image2d_t output,\n" \ +" const int4 input_shape, const int4 output_shape, const int2 stride,\n" \ +" const int2 kernel_size, const int2 padding) {\n" \ +" // axis to dst tensor coordinate\n" \ +" int X = get_global_id(2); // N*H\n" \ +" int Y = get_global_id(1); // W\n" \ +" int Z = get_global_id(0); // C4\n" \ +" int N = X / output_shape.y;\n" \ +" X = X % output_shape.y;\n" \ +" // boundary check\n" \ +" if (N >= output_shape.x || X >= output_shape.y || Y >= output_shape.z || Z >= output_shape.w) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" FLT4 r = (FLT4)(0.0f);\n" \ +" FLT window_size = 0.0f;\n" \ +" int xs = X * stride.x - padding.x;\n" \ +" int ys = Y * stride.y - padding.y;\n" \ +"\n" \ +" for (int ky = 0; ky < kernel_size.y; ++ky) {\n" \ +" int y_c = ys + ky;\n" \ +" bool outside_y = y_c < 0 || y_c >= input_shape.z;\n" \ +" for (int kx = 0; kx < kernel_size.x; ++kx) {\n" \ +" int x_c = xs + kx;\n" \ +" bool outside = outside_y || x_c < 0 || x_c >= input_shape.y;\n" \ +" r +=\n" \ +" !outside ? READ_IMAGE(input, smp_zero, (int2)(y_c * input_shape.w + Z, N * input_shape.y + x_c)) : (FLT4)(0.0f);\n" \ +" window_size += !outside ? 1.0f : 0.0f;\n" \ +" }\n" \ +" }\n" \ +" FLT4 result = TO_FLT4(divide_no_check(r, window_size));\n" \ +" WRITE_IMAGE(output, (int2)(Y * output_shape.w + Z, N * output_shape.y + X), max(result, (FLT4)(0.f)));\n" \ +"}\n" \ +"\n" \ +"__kernel void MaxPooling2d_NHWC4_IMG(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape,\n" \ +" const int4 output_shape, const int2 stride, const int2 kernel_size,\n" \ +" const int2 padding) {\n" \ +" // axis to dst tensor coordinate\n" \ +" int X = get_global_id(2); // N*H\n" \ +" int Y = get_global_id(1); // W\n" \ +" int Z = get_global_id(0); // C4\n" \ +" int N = X / output_shape.y;\n" \ +" X = X % output_shape.y;\n" \ +" // boundary check\n" \ +" if (N >= output_shape.x || X >= output_shape.y || Y >= output_shape.z || Z >= output_shape.w) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" FLT4 maximum = (FLT4)(-10000.0f);\n" \ +" int xs = X * stride.x - padding.x;\n" \ +" int ys = Y * stride.y - padding.y;\n" \ +" for (int ky = 0; ky < kernel_size.y; ++ky) {\n" \ +" int y_c = ys + ky;\n" \ +" if (y_c < 0 || y_c >= input_shape.z) continue;\n" \ +" for (int kx = 0; kx < kernel_size.x; ++kx) {\n" \ +" int x_c = xs + kx;\n" \ +" if (x_c < 0 || x_c >= input_shape.y) continue;\n" \ +" FLT4 src = READ_IMAGE(input, smp_zero, (int2)(y_c * input_shape.w + Z, N * input_shape.y + x_c));\n" \ +" maximum = max(src, maximum);\n" \ +" }\n" \ +" }\n" \ +" WRITE_IMAGE(output, (int2)(Y * output_shape.w + Z, N * output_shape.y + X), maximum);\n" \ +"}\n" \ +"\n" \ +"__kernel void MaxPooling2d_ReLU_NHWC4_IMG(__read_only image2d_t input, __write_only image2d_t output,\n" \ +" const int4 input_shape, const int4 output_shape, const int2 stride,\n" \ +" const int2 kernel_size, const int2 padding) {\n" \ +" // axis to dst tensor coordinate\n" \ +" int X = get_global_id(2); // N*H\n" \ +" int Y = get_global_id(1); // W\n" \ +" int Z = get_global_id(0); // C4\n" \ +" int N = X / output_shape.y;\n" \ +" X = X % output_shape.y;\n" \ +" // boundary check\n" \ +" if (N >= output_shape.x || X >= output_shape.y || Y >= output_shape.z || Z >= output_shape.w) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" FLT4 maximum = (FLT4)(-10000.0f);\n" \ +" int xs = X * stride.x - padding.x;\n" \ +" int ys = Y * stride.y - padding.y;\n" \ +" for (int ky = 0; ky < kernel_size.y; ++ky) {\n" \ +" int y_c = ys + ky;\n" \ +" if (y_c < 0 || y_c >= input_shape.z) continue;\n" \ +" for (int kx = 0; kx < kernel_size.x; ++kx) {\n" \ +" int x_c = xs + kx;\n" \ +" if (x_c < 0 || x_c >= input_shape.y) continue;\n" \ +" FLT4 src = READ_IMAGE(input, smp_zero, (int2)(y_c * input_shape.w + Z, N * input_shape.y + x_c));\n" \ +" maximum = max(src, maximum);\n" \ +" }\n" \ +" }\n" \ +" WRITE_IMAGE(output, (int2)(Y * output_shape.w + Z, N * output_shape.y + X), max(maximum, (FLT4)(0.f)));\n" \ +"}\n" \ +; diff --git a/mindspore-lite/src/runtime/kernel/opencl/cl/power.cl.inc b/mindspore-lite/src/runtime/kernel/opencl/cl/power.cl.inc new file mode 100644 index 0000000000000000000000000000000000000000..dc5a13707036d3cb61ac2d3e577a3e85e088ea28 --- /dev/null +++ b/mindspore-lite/src/runtime/kernel/opencl/cl/power.cl.inc @@ -0,0 +1,83 @@ +static const char *power_source ="\n" +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" \ +"__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;\n" \ +"#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))\n" \ +"#define C4NUM 4\n" \ +"#define CHECK_IDX \\\n" \ +" int X = get_global_id(0); \\\n" \ +" int Y = get_global_id(1); \\\n" \ +" int Z = get_global_id(2); \\\n" \ +" if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w || output_shape.y == 0) { \\\n" \ +" return; \\\n" \ +" }\n" \ +"\n" \ +"FLT OptimizedPowerImpl(FLT x, int exponent) {\n" \ +" int exp = abs(exponent);\n" \ +" FLT result = 1.0f;\n" \ +" FLT iterator = x;\n" \ +" while (exp) {\n" \ +" if (exp % 2) {\n" \ +" result *= iterator;\n" \ +" }\n" \ +" iterator *= iterator;\n" \ +" exp = exp / 2;\n" \ +" }\n" \ +" return exponent >= 0 ? result : 1 / result;\n" \ +"}\n" \ +"\n" \ +"__kernel void power(__read_only image2d_t input0, __read_only image2d_t input1, __write_only image2d_t output,\n" \ +" int4 output_shape, FLT4 parameter) {\n" \ +" CHECK_IDX;\n" \ +" int n = X / output_shape.y;\n" \ +" int h = X % output_shape.y;\n" \ +" FLT4 result;\n" \ +" FLT4 result0 = READ_IMAGE(input0, smp_none, (int2)((Y)*output_shape.w + Z, (n * output_shape.y + h)));\n" \ +" FLT4 result1 = READ_IMAGE(input1, smp_none, (int2)((Y)*output_shape.w + Z, (n * output_shape.y + h)));\n" \ +"\n" \ +" FLT tmp_result[4];\n" \ +" FLT tmp_result0[4] = {result0.x, result0.y, result0.z, result0.w};\n" \ +" FLT tmp_result1[4] = {result1.x, result1.y, result1.z, result1.w};\n" \ +"\n" \ +" for (int i = 0; i < 4; ++i) {\n" \ +" tmp_result0[i] = tmp_result0[i] * parameter.z + parameter.y;\n" \ +" if (floor(tmp_result1[i]) == tmp_result1[i]) {\n" \ +" int exponent = tmp_result1[i];\n" \ +" tmp_result[i] = OptimizedPowerImpl(tmp_result0[i], exponent);\n" \ +" } else {\n" \ +" tmp_result[i] = pow(tmp_result0[i], tmp_result1[i]);\n" \ +" }\n" \ +" }\n" \ +" result.x = tmp_result[0];\n" \ +" result.y = tmp_result[1];\n" \ +" result.z = tmp_result[2];\n" \ +" result.w = tmp_result[3];\n" \ +" WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (n * output_shape.y + h)), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void power_broadcast(__read_only image2d_t input, __write_only image2d_t output, int4 output_shape,\n" \ +" FLT4 parameter) {\n" \ +" CHECK_IDX;\n" \ +" int n = X / output_shape.y;\n" \ +" int h = X % output_shape.y;\n" \ +" FLT4 result;\n" \ +" FLT4 result0 = READ_IMAGE(input, smp_none, (int2)((Y)*output_shape.w + Z, (n * output_shape.y + h)));\n" \ +" FLT tmp_result0[4] = {result0.x, result0.y, result0.z, result0.w};\n" \ +" FLT tmp_result[4];\n" \ +"\n" \ +" bool flag = floor(parameter.x) == parameter.x ? false : true;\n" \ +" for (int i = 0; i < 4; ++i) {\n" \ +" tmp_result0[i] = tmp_result0[i] * parameter.z + parameter.y;\n" \ +" if (flag) {\n" \ +" int exponent = parameter.x;\n" \ +" tmp_result[i] = OptimizedPowerImpl(tmp_result0[i], exponent);\n" \ +" } else {\n" \ +" tmp_result[i] = pow(tmp_result0[i], parameter.x);\n" \ +" }\n" \ +" }\n" \ +" result.x = tmp_result[0];\n" \ +" result.y = tmp_result[1];\n" \ +" result.z = tmp_result[2];\n" \ +" result.w = tmp_result[3];\n" \ +" WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (n * output_shape.y + h)), result);\n" \ +"}\n" \ +; diff --git a/mindspore-lite/src/runtime/kernel/opencl/cl/prelu.cl.inc b/mindspore-lite/src/runtime/kernel/opencl/cl/prelu.cl.inc new file mode 100644 index 0000000000000000000000000000000000000000..e5e6f110742205fbd959ca1bbbb6b2a26ea14643 --- /dev/null +++ b/mindspore-lite/src/runtime/kernel/opencl/cl/prelu.cl.inc @@ -0,0 +1,63 @@ +static const char *prelu_source ="\n" +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" \ +"__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" \ +"#define NHWC4 2\n" \ +"\n" \ +"__kernel void PRelu_scalar(__read_only image2d_t input, __write_only image2d_t output, float weight, int4 shape,\n" \ +" int data_format) {\n" \ +" int nh = get_global_id(0);\n" \ +" int w = get_global_id(1);\n" \ +" int c = get_global_id(2);\n" \ +" if (nh >= shape.x * shape.y || w >= shape.z || c >= shape.w || shape.y == 0) {\n" \ +" return;\n" \ +" }\n" \ +" int n = nh / shape.y;\n" \ +" int h = nh % shape.y;\n" \ +" int x = w * shape.w + c;\n" \ +" int y = n * shape.y + h;\n" \ +" FLT4 out = READ_IMAGE(input, smp_zero, (int2)(x, y));\n" \ +" if (out.x < 0) {\n" \ +" out.x *= weight;\n" \ +" }\n" \ +" if (out.y < 0) {\n" \ +" out.y *= weight;\n" \ +" }\n" \ +" if (out.z < 0) {\n" \ +" out.z *= weight;\n" \ +" }\n" \ +" if (out.w < 0) {\n" \ +" out.w *= weight;\n" \ +" }\n" \ +" WRITE_IMAGE(output, (int2)(x, y), out);\n" \ +"}\n" \ +"\n" \ +"__kernel void PRelu_vector(__read_only image2d_t input, __write_only image2d_t output, __global FLT4 *weight_vector,\n" \ +" int4 shape, int data_format) {\n" \ +" int nh = get_global_id(0);\n" \ +" int w = get_global_id(1);\n" \ +" int c = get_global_id(2);\n" \ +" if (nh >= shape.x * shape.y || w >= shape.z || c >= shape.w || shape.y == 0) {\n" \ +" return;\n" \ +" }\n" \ +" int n = nh / shape.y;\n" \ +" int h = nh % shape.y;\n" \ +" int x = w * shape.w + c;\n" \ +" int y = n * shape.y + h;\n" \ +" FLT4 weight = weight_vector[c];\n" \ +"\n" \ +" FLT4 out = READ_IMAGE(input, smp_zero, (int2)(x, y));\n" \ +" if (out.x < 0) {\n" \ +" out.x *= weight.x;\n" \ +" }\n" \ +" if (out.y < 0) {\n" \ +" out.y *= weight.y;\n" \ +" }\n" \ +" if (out.z < 0) {\n" \ +" out.z *= weight.z;\n" \ +" }\n" \ +" if (out.w < 0) {\n" \ +" out.w *= weight.w;\n" \ +" }\n" \ +" WRITE_IMAGE(output, (int2)(x, y), out);\n" \ +"}\n" \ +; diff --git a/mindspore-lite/src/runtime/kernel/opencl/cl/reduce.cl.inc b/mindspore-lite/src/runtime/kernel/opencl/cl/reduce.cl.inc new file mode 100644 index 0000000000000000000000000000000000000000..2b5f8daa27c47dec33cdb3f0805a0aafb2918c71 --- /dev/null +++ b/mindspore-lite/src/runtime/kernel/opencl/cl/reduce.cl.inc @@ -0,0 +1,329 @@ +static const char *reduce_source ="\n" +"#ifdef cl_khr_fp16\n" \ +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" \ +"#endif\n" \ +"#define LOCAL_CACHE_THREAD 16\n" \ +"__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" \ +"__kernel void GlobalHWMean(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 size) {\n" \ +" int X = get_global_id(0); // C4\n" \ +" if (X >= size.z) {\n" \ +" return;\n" \ +" }\n" \ +" float4 result = (float4)0.f;\n" \ +" for (int h = 0; h < size.x; h++) {\n" \ +" for (int w = 0; w < size.y; w++) {\n" \ +" result += convert_float4(READ_IMAGE(src_data, smp_zero, (int2)(w * size.z + X, h)));\n" \ +" }\n" \ +" }\n" \ +" result /= size.x * size.y;\n" \ +" WRITE_IMAGE(dst_data, (int2)(X, 0), TO_FLT4(result));\n" \ +"}\n" \ +"\n" \ +"__kernel void LocalHWMean(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 size) {\n" \ +" int X = get_global_id(0); // C4\n" \ +" int localy = get_local_id(1);\n" \ +" int localz = get_local_id(2);\n" \ +" if (X >= size.z) return;\n" \ +" __local float4 temp[LOCAL_CACHE_THREAD][LOCAL_CACHE_THREAD];\n" \ +" temp[localy][localz] = (float4)0.f;\n" \ +" for (int h = localy; h < size.x; h += LOCAL_CACHE_THREAD) {\n" \ +" for (int w = localz; w < size.y; w += LOCAL_CACHE_THREAD) {\n" \ +" temp[localy][localz] += convert_float4(READ_IMAGE(src_data, smp_zero, (int2)(w * size.z + X, h)));\n" \ +" }\n" \ +" }\n" \ +" barrier(CLK_LOCAL_MEM_FENCE);\n" \ +" if (localz == 0) {\n" \ +" for (int i = 1; i < LOCAL_CACHE_THREAD; i++) {\n" \ +" temp[localy][0] += temp[localy][i];\n" \ +" }\n" \ +" }\n" \ +" barrier(CLK_LOCAL_MEM_FENCE);\n" \ +" float4 result = temp[0][0];\n" \ +" for (int i = 1; i < LOCAL_CACHE_THREAD; i++) {\n" \ +" result += temp[i][0];\n" \ +" }\n" \ +" result /= size.x * size.y;\n" \ +" WRITE_IMAGE(dst_data, (int2)(X, 0), TO_FLT4(result));\n" \ +"}\n" \ +"\n" \ +"__kernel void GlobalWCMean(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 size, float4 mask) {\n" \ +" int X = get_global_id(0); // H\n" \ +" if (X >= size.x) {\n" \ +" return;\n" \ +" }\n" \ +" float4 result = (float4)0.f;\n" \ +" for (int w = 0; w < size.y; w++) {\n" \ +" for (int c = 0; c < size.z; c++) {\n" \ +" result += convert_float4(READ_IMAGE(src_data, smp_zero, (int2)(w * size.z + c, X)));\n" \ +" }\n" \ +" }\n" \ +"\n" \ +" result /= size.y * size.w;\n" \ +" FLT4 result2 = (FLT4)(0.f);\n" \ +" result2.x = dot(TO_FLT4(result), (FLT4)(1.f));\n" \ +" WRITE_IMAGE(dst_data, (int2)(0, X), result2);\n" \ +"}\n" \ +"\n" \ +"__kernel void LocalWCMean(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 size, float4 mask) {\n" \ +" int X = get_global_id(0); // H\n" \ +" int localy = get_local_id(1);\n" \ +" int localz = get_local_id(2);\n" \ +" if (X >= size.x) return;\n" \ +" __local float4 temp[LOCAL_CACHE_THREAD][LOCAL_CACHE_THREAD];\n" \ +" temp[localy][localz] = (float4)0.f;\n" \ +" for (int w = localy; w < size.y; w += LOCAL_CACHE_THREAD) {\n" \ +" for (int c = localz; c < size.z; c += LOCAL_CACHE_THREAD) {\n" \ +" temp[localy][localz] += convert_float4(READ_IMAGE(src_data, smp_zero, (int2)(w * size.z + c, X)));\n" \ +" }\n" \ +" }\n" \ +" barrier(CLK_LOCAL_MEM_FENCE);\n" \ +" if (localz == 0) {\n" \ +" for (int i = 1; i < LOCAL_CACHE_THREAD; i++) {\n" \ +" temp[localy][0] += temp[localy][i];\n" \ +" }\n" \ +" }\n" \ +" barrier(CLK_LOCAL_MEM_FENCE);\n" \ +" float4 result = temp[0][0];\n" \ +" for (int i = 1; i < LOCAL_CACHE_THREAD; i++) {\n" \ +" result += temp[i][0];\n" \ +" }\n" \ +" result /= size.y * size.w;\n" \ +" FLT4 result2 = (FLT4)(0.f);\n" \ +" result2.x = dot(TO_FLT4(result), (FLT4)(1.f));\n" \ +" WRITE_IMAGE(dst_data, (int2)(0, X), result2);\n" \ +"}\n" \ +"\n" \ +"__kernel void GlobalHWSumSquare(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 size) {\n" \ +" int X = get_global_id(0);\n" \ +" if (X >= size.z) {\n" \ +" return;\n" \ +" }\n" \ +" FLT4 result = (FLT4)0.f;\n" \ +" for (int h = 0; h < size.x; h++) {\n" \ +" for (int w = 0; w < size.y; w++) {\n" \ +" FLT4 current = READ_IMAGE(src_data, smp_zero, (int2)(w * size.z + X, h));\n" \ +" result += current * current;\n" \ +" }\n" \ +" }\n" \ +" WRITE_IMAGE(dst_data, (int2)(X, 0), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void LocalHWSumSquare(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 size) {\n" \ +" int X = get_global_id(0);\n" \ +" int localy = get_local_id(1);\n" \ +" int localz = get_local_id(2);\n" \ +" if (X >= size.z) return;\n" \ +" __local FLT4 temp[LOCAL_CACHE_THREAD][LOCAL_CACHE_THREAD];\n" \ +" temp[localy][localz] = (FLT4)0.f;\n" \ +" for (int h = localy; h < size.x; h += LOCAL_CACHE_THREAD) {\n" \ +" for (int w = localz; w < size.y; w += LOCAL_CACHE_THREAD) {\n" \ +" FLT4 current = READ_IMAGE(src_data, smp_zero, (int2)(w * size.z + X, h));\n" \ +" temp[localy][localz] += current * current;\n" \ +" }\n" \ +" }\n" \ +" barrier(CLK_LOCAL_MEM_FENCE);\n" \ +" if (localz == 0) {\n" \ +" for (int i = 1; i < LOCAL_CACHE_THREAD; i++) {\n" \ +" temp[localy][0] += temp[localy][i];\n" \ +" }\n" \ +" }\n" \ +" barrier(CLK_LOCAL_MEM_FENCE);\n" \ +" FLT4 result = temp[0][0];\n" \ +" for (int i = 1; i < LOCAL_CACHE_THREAD; i++) {\n" \ +" result += temp[i][0];\n" \ +" }\n" \ +" WRITE_IMAGE(dst_data, (int2)(X, 0), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void GlobalWCSumSquare(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 size,\n" \ +" float4 mask) {\n" \ +" int X = get_global_id(0);\n" \ +" if (X >= size.x) {\n" \ +" return;\n" \ +" }\n" \ +" FLT4 result = (FLT4)0.f;\n" \ +" for (int w = 0; w < size.y; w++) {\n" \ +" for (int c = 0; c < size.z; c++) {\n" \ +" FLT4 current = READ_IMAGE(src_data, smp_zero, (int2)(w * size.z + c, X));\n" \ +" result += current * current;\n" \ +" }\n" \ +" }\n" \ +"\n" \ +" FLT4 result2 = (FLT4)(0.f);\n" \ +" result2.x = dot(result, (FLT4)(1.f));\n" \ +" WRITE_IMAGE(dst_data, (int2)(0, X), result2);\n" \ +"}\n" \ +"\n" \ +"__kernel void LocalWCSumSquare(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 size,\n" \ +" float4 mask) {\n" \ +" int X = get_global_id(0);\n" \ +" int localy = get_local_id(1);\n" \ +" int localz = get_local_id(2);\n" \ +" if (X >= size.x) return;\n" \ +" __local FLT4 temp[LOCAL_CACHE_THREAD][LOCAL_CACHE_THREAD];\n" \ +" temp[localy][localz] = (FLT4)0.f;\n" \ +" for (int w = localy; w < size.y; w += LOCAL_CACHE_THREAD) {\n" \ +" for (int c = localz; c < size.z; c += LOCAL_CACHE_THREAD) {\n" \ +" FLT4 current = READ_IMAGE(src_data, smp_zero, (int2)(w * size.z + c, X));\n" \ +" temp[localy][localz] += current * current;\n" \ +" }\n" \ +" }\n" \ +" barrier(CLK_LOCAL_MEM_FENCE);\n" \ +" if (localz == 0) {\n" \ +" for (int i = 1; i < LOCAL_CACHE_THREAD; i++) {\n" \ +" temp[localy][0] += temp[localy][i];\n" \ +" }\n" \ +" }\n" \ +" barrier(CLK_LOCAL_MEM_FENCE);\n" \ +" FLT4 result = temp[0][0];\n" \ +" for (int i = 1; i < LOCAL_CACHE_THREAD; i++) {\n" \ +" result += temp[i][0];\n" \ +" }\n" \ +" FLT4 result2 = (FLT4)(0.f);\n" \ +" result2.x = dot(result, (FLT4)(1.f));\n" \ +" WRITE_IMAGE(dst_data, (int2)(0, X), result2);\n" \ +"}\n" \ +"\n" \ +"__kernel void GlobalCMean(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 size, float4 mask) {\n" \ +" int X = get_global_id(0); // H\n" \ +" int Y = get_global_id(1); // W\n" \ +" if (X >= size.x || Y >= size.y) {\n" \ +" return;\n" \ +" }\n" \ +" float4 result = (float4)0.f;\n" \ +" for (int c = 0; c < size.z; c++) {\n" \ +" result += convert_float4(READ_IMAGE(src_data, smp_zero, (int2)(Y * size.z + c, X)));\n" \ +" }\n" \ +"\n" \ +" result /= size.w;\n" \ +" FLT4 result2 = (FLT4)(0.f);\n" \ +" result2.x = dot(TO_FLT4(result), (FLT4)(1.f));\n" \ +" WRITE_IMAGE(dst_data, (int2)(Y, X), result2);\n" \ +"}\n" \ +"\n" \ +"#define GlobalHW(Method) \\\n" \ +" __kernel void GlobalHW##Method(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 size) { \\\n" \ +" int X = get_global_id(0); \\\n" \ +" if (X >= size.z) { \\\n" \ +" return; \\\n" \ +" } \\\n" \ +" FLT4 result = (FLT4)Init##Method; \\\n" \ +" for (int h = 0; h < size.x; h++) { \\\n" \ +" for (int w = 0; w < size.y; w++) { \\\n" \ +" FLT4 current = READ_IMAGE(src_data, smp_zero, (int2)(w * size.z + X, h)); \\\n" \ +" Do##Method(result, current); \\\n" \ +" } \\\n" \ +" } \\\n" \ +" WRITE_IMAGE(dst_data, (int2)(X, 0), result); \\\n" \ +" }\n" \ +"\n" \ +"#define GlobalWC(Method) \\\n" \ +" __kernel void GlobalWC##Method(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 size, \\\n" \ +" float4 mask) { \\\n" \ +" int X = get_global_id(0); \\\n" \ +" if (X >= size.x) { \\\n" \ +" return; \\\n" \ +" } \\\n" \ +" FLT4 result = (FLT4)Init##Method; \\\n" \ +" FLT4 maskFLT4 = TO_FLT4(mask); \\\n" \ +" for (int w = 0; w < size.y; w++) { \\\n" \ +" int c = 0; \\\n" \ +" for (; c < size.z - 1; c++) { \\\n" \ +" FLT4 current = READ_IMAGE(src_data, smp_zero, (int2)(w * size.z + c, X)); \\\n" \ +" Do##Method(result, current); \\\n" \ +" } \\\n" \ +" FLT4 current = READ_IMAGE(src_data, smp_zero, (int2)(w * size.z + c, X)); \\\n" \ +" current += maskFLT4; \\\n" \ +" Do##Method(result, current); \\\n" \ +" } \\\n" \ +" Do##Method(result.x, result.y); \\\n" \ +" Do##Method(result.x, result.z); \\\n" \ +" Do##Method(result.x, result.w); \\\n" \ +" FLT4 result2 = (FLT4)(0.f); \\\n" \ +" result2.x = TO_FLT(result.x); \\\n" \ +" WRITE_IMAGE(dst_data, (int2)(0, X), result2); \\\n" \ +" }\n" \ +"\n" \ +"#define LocalHW(Method) \\\n" \ +" __kernel void LocalHW##Method(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 size) { \\\n" \ +" int X = get_global_id(0); \\\n" \ +" int localy = get_local_id(1); \\\n" \ +" int localz = get_local_id(2); \\\n" \ +" if (X >= size.z) return; \\\n" \ +" __local float4 temp[LOCAL_CACHE_THREAD][LOCAL_CACHE_THREAD]; \\\n" \ +" temp[localy][localz] = (float4)Init##Method; \\\n" \ +" for (int h = localy; h < size.x; h += LOCAL_CACHE_THREAD) { \\\n" \ +" for (int w = localz; w < size.y; w += LOCAL_CACHE_THREAD) { \\\n" \ +" float4 current = convert_float4(READ_IMAGE(src_data, smp_zero, (int2)(w * size.z + X, h))); \\\n" \ +" Do##Method(temp[localy][localz], current); \\\n" \ +" } \\\n" \ +" } \\\n" \ +" barrier(CLK_LOCAL_MEM_FENCE); \\\n" \ +" if (localz == 0) { \\\n" \ +" for (int i = 1; i < LOCAL_CACHE_THREAD; i++) { \\\n" \ +" Do##Method(temp[localy][0], temp[localy][i]); \\\n" \ +" } \\\n" \ +" } \\\n" \ +" barrier(CLK_LOCAL_MEM_FENCE); \\\n" \ +" float4 result = temp[0][0]; \\\n" \ +" for (int i = 1; i < LOCAL_CACHE_THREAD; i++) { \\\n" \ +" Do##Method(result, temp[i][0]); \\\n" \ +" } \\\n" \ +" WRITE_IMAGE(dst_data, (int2)(X, 0), TO_FLT4(result)); \\\n" \ +" }\n" \ +"\n" \ +"#define LocalWC(Method) \\\n" \ +" __kernel void LocalWC##Method(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 size, \\\n" \ +" float4 mask) { \\\n" \ +" int X = get_global_id(0); \\\n" \ +" int localy = get_local_id(1); \\\n" \ +" int localz = get_local_id(2); \\\n" \ +" if (X >= size.x) return; \\\n" \ +" __local float4 temp[LOCAL_CACHE_THREAD][LOCAL_CACHE_THREAD]; \\\n" \ +" temp[localy][localz] = (float4)Init##Method; \\\n" \ +" for (int w = localy; w < size.y; w += LOCAL_CACHE_THREAD) { \\\n" \ +" int c = localz; \\\n" \ +" for (; c < size.z - 1; c += LOCAL_CACHE_THREAD) { \\\n" \ +" float4 current = convert_float4(READ_IMAGE(src_data, smp_zero, (int2)(w * size.z + c, X))); \\\n" \ +" Do##Method(temp[localy][localz], current); \\\n" \ +" } \\\n" \ +" if (c == size.z - 1) { \\\n" \ +" float4 current = convert_float4(READ_IMAGE(src_data, smp_zero, (int2)(w * size.z + c, X))); \\\n" \ +" current += mask; \\\n" \ +" Do##Method(temp[localy][localz], current); \\\n" \ +" } \\\n" \ +" } \\\n" \ +" barrier(CLK_LOCAL_MEM_FENCE); \\\n" \ +" if (localz == 0) { \\\n" \ +" for (int i = 1; i < LOCAL_CACHE_THREAD; i++) { \\\n" \ +" Do##Method(temp[localy][0], temp[localy][i]); \\\n" \ +" } \\\n" \ +" } \\\n" \ +" barrier(CLK_LOCAL_MEM_FENCE); \\\n" \ +" float4 result = temp[0][0]; \\\n" \ +" for (int i = 1; i < LOCAL_CACHE_THREAD; i++) { \\\n" \ +" Do##Method(result, temp[i][0]); \\\n" \ +" } \\\n" \ +" Do##Method(result.x, result.y); \\\n" \ +" Do##Method(result.x, result.z); \\\n" \ +" Do##Method(result.x, result.w); \\\n" \ +" FLT4 result2 = (FLT4)(0.f); \\\n" \ +" result2.x = TO_FLT(result.x); \\\n" \ +" WRITE_IMAGE(dst_data, (int2)(0, X), result2); \\\n" \ +" }\n" \ +"\n" \ +"#define DoSum(A, B) A += B\n" \ +"#define InitSum 0.f\n" \ +"GlobalHW(Sum) GlobalWC(Sum) LocalHW(Sum) LocalWC(Sum)\n" \ +"#define DoMin(A, B) A = min(A, B)\n" \ +"#define InitMin 10000.f\n" \ +" GlobalHW(Min) GlobalWC(Min) LocalHW(Min) LocalWC(Min)\n" \ +"\n" \ +"#define DoMax(A, B) A = max(A, B)\n" \ +"#define InitMax -10000.f\n" \ +" GlobalHW(Max) GlobalWC(Max) LocalHW(Max) LocalWC(Max)\n" \ +"\n" \ +"#define DoProd(A, B) A *= B\n" \ +"#define InitProd 1.f\n" \ +" GlobalHW(Prod) GlobalWC(Prod) LocalHW(Prod) LocalWC(Prod)\n" \ +; diff --git a/mindspore-lite/src/runtime/kernel/opencl/cl/reshape.cl.inc b/mindspore-lite/src/runtime/kernel/opencl/cl/reshape.cl.inc new file mode 100644 index 0000000000000000000000000000000000000000..a1f4cb4c4ca7e05ca5555fc6f74d107407a045e4 --- /dev/null +++ b/mindspore-lite/src/runtime/kernel/opencl/cl/reshape.cl.inc @@ -0,0 +1,51 @@ +static const char *reshape_source ="\n" +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" \ +"#define C4NUM 4\n" \ +"#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))\n" \ +"__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" \ +"\n" \ +"__kernel void reshape_NHWC4(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 src_size,\n" \ +" int4 dst_size) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" int CO4 = UP_DIV(dst_size.z, C4NUM);\n" \ +" int CO4_rem = dst_size.z % C4NUM;\n" \ +" if (X >= dst_size.x || Y > dst_size.y) {\n" \ +" return;\n" \ +" }\n" \ +" int CI4 = UP_DIV(src_size.x, C4NUM);\n" \ +" int CI4_rem = src_size.x % C4NUM;\n" \ +" CI4_rem = (CI4_rem == 0) ? C4NUM : CI4_rem;\n" \ +" int in_img_x = CI4 * src_size.y;\n" \ +" DTYPE4 res = (DTYPE4)(0.0f);\n" \ +" DTYPE tmp[4];\n" \ +" DTYPE res_tmp[4];\n" \ +" int gcnt = 0;\n" \ +" if (CO4_rem == 0 && ((CI4_rem & 0x3) == 0)) {\n" \ +" gcnt = X + dst_size.x * Y;\n" \ +" res = READ_IMAGE(src_data, smp_zero, (int2)(gcnt % in_img_x, gcnt / in_img_x));\n" \ +" WRITE_IMAGE(dst_data, (int2)(X, Y), res);\n" \ +" } else {\n" \ +" int start = ((X / CO4 * dst_size.z + min(dst_size.z, (X % CO4) * C4NUM)) + dst_size.w * Y);\n" \ +" gcnt = start / src_size.x * CI4 + (start % src_size.x) / C4NUM;\n" \ +" start = start % src_size.x % C4NUM;\n" \ +" for (int i = 0, n = 0, j = start; i < C4NUM; ++n, j = 0) {\n" \ +" int X_src = (gcnt + n) % in_img_x;\n" \ +" res = READ_IMAGE(src_data, smp_zero, (int2)(X_src, (gcnt + n) / in_img_x));\n" \ +" tmp[0] = res.x;\n" \ +" tmp[1] = res.y;\n" \ +" tmp[2] = res.z;\n" \ +" tmp[3] = res.w;\n" \ +" int k = (X_src % CI4) == (CI4 - 1) ? CI4_rem : C4NUM;\n" \ +" for (; j < k && i < C4NUM; ++j, ++i) {\n" \ +" res_tmp[i] = tmp[j];\n" \ +" }\n" \ +" }\n" \ +" res.x = res_tmp[0];\n" \ +" res.y = res_tmp[1];\n" \ +" res.z = res_tmp[2];\n" \ +" res.w = res_tmp[3];\n" \ +" WRITE_IMAGE(dst_data, (int2)(X, Y), res);\n" \ +" }\n" \ +"}\n" \ +; diff --git a/mindspore-lite/src/runtime/kernel/opencl/cl/resize.cl.inc b/mindspore-lite/src/runtime/kernel/opencl/cl/resize.cl.inc new file mode 100644 index 0000000000000000000000000000000000000000..01d7d4c8184415fda9845c5fa42abaa9b92c07ee --- /dev/null +++ b/mindspore-lite/src/runtime/kernel/opencl/cl/resize.cl.inc @@ -0,0 +1,83 @@ +static const char *resize_source ="\n" +"#ifdef cl_khr_fp16\n" \ +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" \ +"#endif\n" \ +"__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" \ +"__kernel void resize_nearest_neighbor_NHWC4(__read_only image2d_t src_data, __write_only image2d_t dst_data,\n" \ +" int4 in_size, int4 out_size, float2 scale_factor) {\n" \ +" int X = get_global_id(2); // H * N\n" \ +" int Y = get_global_id(1); // W\n" \ +" int Z = get_global_id(0); // C4\n" \ +" if (X >= out_size.x * out_size.y || Y >= out_size.z || Z >= out_size.w) {\n" \ +" return;\n" \ +" }\n" \ +" int N = X / out_size.y;\n" \ +" X = X % out_size.y;\n" \ +" int src_x = (int)(X * scale_factor.x);\n" \ +" int src_y = (int)(Y * scale_factor.y);\n" \ +" WRITE_IMAGE(dst_data, (int2)(Y * out_size.w + Z, N * out_size.y + X),\n" \ +" READ_IMAGE(src_data, smp_zero, (int2)(src_y * in_size.w + Z, N * in_size.y + src_x)));\n" \ +"}\n" \ +"\n" \ +"__kernel void resize_nearest_neighbor_NC4HW4(__read_only image2d_t src_data, __write_only image2d_t dst_data,\n" \ +" int4 in_size, int4 out_size, float2 scale_factor) {\n" \ +" int X = get_global_id(2); // H\n" \ +" int Y = get_global_id(1); // W\n" \ +" int Z = get_global_id(0); // C4\n" \ +" if (X >= out_size.y || Y >= out_size.z || Z >= out_size.w) {\n" \ +" return;\n" \ +" }\n" \ +" int src_x = (int)(X * scale_factor.x);\n" \ +" int src_y = (int)(Y * scale_factor.y);\n" \ +" WRITE_IMAGE(dst_data, (int2)(Y, Z * out_size.y + X),\n" \ +" READ_IMAGE(src_data, smp_zero, (int2)(src_y, Z * in_size.y + src_x)));\n" \ +"}\n" \ +"\n" \ +"__kernel void resize_bilinear_NHWC4(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 in_size,\n" \ +" int4 out_size, float2 scale_factor) {\n" \ +" int X = get_global_id(2); // H * N\n" \ +" int Y = get_global_id(1); // W\n" \ +" int Z = get_global_id(0); // C4\n" \ +" if (X >= out_size.x * out_size.y || Y >= out_size.z || Z >= out_size.w) {\n" \ +" return;\n" \ +" }\n" \ +" int N = X / out_size.y;\n" \ +" X = X % out_size.y;\n" \ +" float scale_x = X * scale_factor.x;\n" \ +" float scale_y = Y * scale_factor.y;\n" \ +" int src_x = (int)(scale_x);\n" \ +" int src_y = (int)(scale_y);\n" \ +" int src_x_1 = min(src_x + 1, in_size.y - 1);\n" \ +" int src_y_1 = min(src_y + 1, in_size.z - 1);\n" \ +" FLT4 src0 = READ_IMAGE(src_data, smp_zero, (int2)(src_y * in_size.w + Z, N * in_size.y + src_x));\n" \ +" FLT4 src1 = READ_IMAGE(src_data, smp_zero, (int2)(src_y_1 * in_size.w + Z, N * in_size.y + src_x));\n" \ +" FLT4 src2 = READ_IMAGE(src_data, smp_zero, (int2)(src_y * in_size.w + Z, N * in_size.y + src_x_1));\n" \ +" FLT4 src3 = READ_IMAGE(src_data, smp_zero, (int2)(src_y_1 * in_size.w + Z, N * in_size.y + src_x_1));\n" \ +" FLT4 result =\n" \ +" mix(mix(src0, src1, TO_FLT(scale_y - src_y)), mix(src2, src3, TO_FLT(scale_y - src_y)), TO_FLT(scale_x - src_x));\n" \ +" WRITE_IMAGE(dst_data, (int2)(Y * out_size.w + Z, N * out_size.y + X), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void resize_bilinear_NC4HW4(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 in_size,\n" \ +" int4 out_size, float2 scale_factor) {\n" \ +" int X = get_global_id(2); // H\n" \ +" int Y = get_global_id(1); // W\n" \ +" int Z = get_global_id(0); // C4\n" \ +" if (X >= out_size.y || Y >= out_size.z || Z >= out_size.w) {\n" \ +" return;\n" \ +" }\n" \ +" float scale_x = X * scale_factor.x;\n" \ +" float scale_y = Y * scale_factor.y;\n" \ +" int src_x = (int)(scale_x);\n" \ +" int src_y = (int)(scale_y);\n" \ +" int src_x_1 = min(src_x + 1, in_size.y - 1);\n" \ +" int src_y_1 = min(src_y + 1, in_size.z - 1);\n" \ +" FLT4 src0 = READ_IMAGE(src_data, smp_zero, (int2)(src_y, in_size.y * Z + src_x));\n" \ +" FLT4 src1 = READ_IMAGE(src_data, smp_zero, (int2)(src_y_1, in_size.y * Z + src_x));\n" \ +" FLT4 src2 = READ_IMAGE(src_data, smp_zero, (int2)(src_y, in_size.y * Z + src_x_1));\n" \ +" FLT4 src3 = READ_IMAGE(src_data, smp_zero, (int2)(src_y_1, in_size.y * Z + src_x_1));\n" \ +" FLT4 result =\n" \ +" mix(mix(src0, src1, TO_FLT(scale_y - src_y)), mix(src2, src3, TO_FLT(scale_y - src_y)), TO_FLT(scale_x - src_x));\n" \ +" WRITE_IMAGE(dst_data, (int2)(Y, out_size.w * Z + X), result);\n" \ +"}\n" \ +; diff --git a/mindspore-lite/src/runtime/kernel/opencl/cl/scale.cl.inc b/mindspore-lite/src/runtime/kernel/opencl/cl/scale.cl.inc new file mode 100644 index 0000000000000000000000000000000000000000..5f533f47913cc09912af40f2baa51a6742ca9f01 --- /dev/null +++ b/mindspore-lite/src/runtime/kernel/opencl/cl/scale.cl.inc @@ -0,0 +1,101 @@ +static const char *scale_source ="\n" +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" \ +"__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;\n" \ +"\n" \ +"#define C4NUM 4\n" \ +"\n" \ +"__kernel void Scale_IMG(__read_only image2d_t input, __read_only image2d_t scale, __read_only image2d_t offset,\n" \ +" __write_only image2d_t output, const int2 output_shape, const int act_type) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= output_shape.x || Y >= output_shape.y) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" FLT4 in = READ_IMAGE(input, smp_none, (int2)(X, Y));\n" \ +" FLT4 s = READ_IMAGE(scale, smp_none, (int2)(X, Y));\n" \ +" FLT4 o = READ_IMAGE(offset, smp_none, (int2)(X, Y));\n" \ +" FLT4 out = in * s + o;\n" \ +" if (act_type == ActivationType_RELU) {\n" \ +" out = max(out, (FLT4)(0.0f));\n" \ +" } else if (act_type == ActivationType_RELU6) {\n" \ +" out = clamp(out, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" }\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), out);\n" \ +"}\n" \ +"\n" \ +"__kernel void BoardcastScale_IMG(__read_only image2d_t input, float scale, float offset, __write_only image2d_t output,\n" \ +" const int2 output_shape, const int act_type) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= output_shape.x || Y >= output_shape.y) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" FLT4 in = READ_IMAGE(input, smp_none, (int2)(X, Y));\n" \ +" FLT4 out = in * (FLT)scale + (FLT)offset;\n" \ +" if (act_type == ActivationType_RELU) {\n" \ +" out = max(out, (FLT4)(0.0f));\n" \ +" } else if (act_type == ActivationType_RELU6) {\n" \ +" out = clamp(out, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" }\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), out);\n" \ +"}\n" \ +"\n" \ +"__kernel void Scale_C_IMG(__read_only image2d_t input, __read_only image2d_t scale, __read_only image2d_t offset,\n" \ +" __write_only image2d_t output, const int2 output_shape, const int C, const int act_type) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= output_shape.x || Y >= output_shape.y || C == 0) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" FLT4 in = READ_IMAGE(input, smp_none, (int2)(X, Y));\n" \ +" FLT4 s = READ_IMAGE(scale, smp_none, (int2)(X % C, 0));\n" \ +" FLT4 o = READ_IMAGE(offset, smp_none, (int2)(X % C, 0));\n" \ +" FLT4 out = in * s + o;\n" \ +" if (act_type == ActivationType_RELU) {\n" \ +" out = max(out, (FLT4)(0.0f));\n" \ +" } else if (act_type == ActivationType_RELU6) {\n" \ +" out = clamp(out, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" }\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), out);\n" \ +"}\n" \ +"\n" \ +"__kernel void Scale_H_IMG(__read_only image2d_t input, __read_only image2d_t scale, __read_only image2d_t offset,\n" \ +" __write_only image2d_t output, const int2 output_shape, const int H, const int act_type) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= output_shape.x || Y >= output_shape.y || H == 0) {\n" \ +" return;\n" \ +" }\n" \ +" int h = Y % H;\n" \ +" int h_quotient = h / C4NUM;\n" \ +" int h_remainder = h % C4NUM;\n" \ +" FLT4 in = READ_IMAGE(input, smp_none, (int2)(X, Y));\n" \ +" FLT4 s = READ_IMAGE(scale, smp_none, (int2)(h_quotient, 0));\n" \ +" FLT4 o = READ_IMAGE(offset, smp_none, (int2)(h_quotient, 0));\n" \ +" FLT s_real;\n" \ +" FLT o_real;\n" \ +" if (h_remainder == 0) {\n" \ +" s_real = s.x;\n" \ +" o_real = o.x;\n" \ +" } else if (h_remainder == 1) {\n" \ +" s_real = s.y;\n" \ +" o_real = o.y;\n" \ +" } else if (h_remainder == 2) {\n" \ +" s_real = s.z;\n" \ +" o_real = o.z;\n" \ +" } else {\n" \ +" s_real = s.w;\n" \ +" o_real = o.w;\n" \ +" }\n" \ +" FLT4 out = in * s_real + o_real;\n" \ +" if (act_type == ActivationType_RELU) {\n" \ +" out = max(out, (FLT4)(0.0f));\n" \ +" } else if (act_type == ActivationType_RELU6) {\n" \ +" out = clamp(out, (FLT4)(0.0f), (FLT4)(6.0f));\n" \ +" }\n" \ +" WRITE_IMAGE(output, (int2)(X, Y), out);\n" \ +"}\n" \ +; diff --git a/mindspore-lite/src/runtime/kernel/opencl/cl/softmax.cl.inc b/mindspore-lite/src/runtime/kernel/opencl/cl/softmax.cl.inc new file mode 100644 index 0000000000000000000000000000000000000000..f56cb4b0ecc2aa451f0b70d215fc9ec0396cab6b --- /dev/null +++ b/mindspore-lite/src/runtime/kernel/opencl/cl/softmax.cl.inc @@ -0,0 +1,137 @@ +static const char *softmax_source ="\n" +"#ifdef cl_khr_fp16\n" \ +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" \ +"#endif\n" \ +"#define divide_no_check(a, b) (a / b)\n" \ +"__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;\n" \ +"__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" \ +"__kernel void SoftmaxAxis3_NHWC4(__read_only image2d_t input, __write_only image2d_t output, const float4 mask,\n" \ +" const int4 input_shape) {\n" \ +" int X = get_global_id(1); // H\n" \ +" int Y = get_global_id(0); // W\n" \ +" int n = get_global_id(2); // N\n" \ +" int H = input_shape.y;\n" \ +" int W = input_shape.z;\n" \ +" int C4 = input_shape.w;\n" \ +"\n" \ +" if (n >= input_shape.x || X >= H || Y >= W) return;\n" \ +"\n" \ +" // get max\n" \ +" float4 last = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * C4 + C4 - 1, n * H + X)));\n" \ +" float input_max = last.x;\n" \ +" if (mask.y > 0.5f) input_max = max(input_max, last.y);\n" \ +" if (mask.z > 0.5f) input_max = max(input_max, last.z);\n" \ +" if (mask.w > 0.5f) input_max = max(input_max, last.w);\n" \ +" for (int d = 0; d < C4 - 1; ++d) {\n" \ +" float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * C4 + d, n * H + X)));\n" \ +" input_max = max(input_max, t.x);\n" \ +" input_max = max(input_max, t.y);\n" \ +" input_max = max(input_max, t.z);\n" \ +" input_max = max(input_max, t.w);\n" \ +" }\n" \ +" float4 input_max_f4 = (float4)(input_max, input_max, input_max, input_max);\n" \ +"\n" \ +" float sum = 0.0f;\n" \ +" for (int d = 0; d < C4 - 1; ++d) {\n" \ +" float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * C4 + d, n * H + X)));\n" \ +" sum += dot(exp(t - input_max_f4), (float4)(1.f));\n" \ +" }\n" \ +" float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * C4 + C4 - 1, n * H + X)));\n" \ +" sum += dot(exp(min(t - input_max_f4, (float4)(0.f))), mask);\n" \ +" for (int d = 0; d < C4 - 1; ++d) {\n" \ +" float4 result = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * C4 + d, n * H + X)));\n" \ +" result = exp(result - input_max_f4) / sum;\n" \ +" WRITE_IMAGEOUT(output, (int2)(Y * C4 + d, n * H + X), OUT_FLT4(result));\n" \ +" }\n" \ +" float4 result = convert_float4(READ_IMAGE(input, smp_zero, (int2)(Y * C4 + C4 - 1, n * H + X)));\n" \ +" result = exp(min(result - input_max_f4, (float4)(0.f))) / sum;\n" \ +" result = result * mask;\n" \ +" WRITE_IMAGEOUT(output, (int2)(Y * C4 + C4 - 1, n * H + X), OUT_FLT4(result));\n" \ +"}\n" \ +"\n" \ +"__kernel void SoftmaxAxis1_NHWC4(__read_only image2d_t input, __write_only image2d_t output, const float4 mask,\n" \ +" const int4 input_shape) {\n" \ +" int X = get_global_id(1); // W\n" \ +" int Y = get_global_id(0); // C4\n" \ +" int n = get_global_id(2); // N\n" \ +" int H = input_shape.y;\n" \ +" int W = input_shape.z;\n" \ +" int C4 = input_shape.w;\n" \ +"\n" \ +" if (n >= input_shape.x || X >= W || Y >= C4) return;\n" \ +"\n" \ +" float4 sum = 0.0f;\n" \ +" for (int d = 0; d < H; ++d) {\n" \ +" float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(X * C4 + Y, n * H + d)));\n" \ +" sum += exp(t);\n" \ +" }\n" \ +" for (int d = 0; d < H; ++d) {\n" \ +" float4 result = convert_float4(READ_IMAGE(input, smp_zero, (int2)(X * C4 + Y, n * H + d)));\n" \ +" result = exp(result) / sum;\n" \ +" WRITE_IMAGEOUT(output, (int2)(X * C4 + Y, n * H + d), OUT_FLT4(result));\n" \ +" }\n" \ +"}\n" \ +"\n" \ +"__kernel void SoftmaxAxis2_NHWC4(__read_only image2d_t input, __write_only image2d_t output, const float4 mask,\n" \ +" const int4 input_shape) {\n" \ +" int X = get_global_id(1); // H\n" \ +" int Y = get_global_id(0); // C4\n" \ +" int n = get_global_id(2); // n\n" \ +" int H = input_shape.y;\n" \ +" int W = input_shape.z;\n" \ +" int C4 = input_shape.w;\n" \ +"\n" \ +" if (n >= input_shape.x || X >= H || Y >= C4) return;\n" \ +"\n" \ +" float4 sum = 0.0f;\n" \ +" for (int d = 0; d < W; ++d) {\n" \ +" float4 t = convert_float4(READ_IMAGE(input, smp_zero, (int2)(d * C4 + Y, n * H + X)));\n" \ +" sum += exp(t);\n" \ +" }\n" \ +" for (int d = 0; d < W; ++d) {\n" \ +" float4 result = convert_float4(READ_IMAGE(input, smp_zero, (int2)(d * C4 + Y, n * H + X)));\n" \ +" result = exp(result) / sum;\n" \ +" WRITE_IMAGEOUT(output, (int2)(d * C4 + Y, n * H + X), OUT_FLT4(result));\n" \ +" }\n" \ +"}\n" \ +"\n" \ +"__kernel void Softmax1x1_NHWC4(__read_only image2d_t input, __write_only image2d_t output, const float4 mask,\n" \ +" const int4 input_shape) {\n" \ +" int tid = get_local_id(0);\n" \ +" int n = get_global_id(1);\n" \ +" if (n >= input_shape.x) return;\n" \ +" int C4 = input_shape.w;\n" \ +" float sum = 0.0f;\n" \ +" for (size_t i = tid; i < C4 - 1; i += 32) {\n" \ +" float4 src = convert_float4(READ_IMAGE(input, smp_zero, (int2)(i, n)));\n" \ +" sum += dot((float4)(1.0f), exp(src));\n" \ +" }\n" \ +" if ((C4 - 1) % 32 == tid) {\n" \ +" float4 src = convert_float4(READ_IMAGE(input, smp_zero, (int2)(C4 - 1, n)));\n" \ +" sum += dot(convert_float4(mask), exp(src));\n" \ +" }\n" \ +"\n" \ +" __local float4 tmp[8];\n" \ +" __local float *tmpx1 = (__local float *)tmp;\n" \ +" tmpx1[tid] = sum;\n" \ +" barrier(CLK_LOCAL_MEM_FENCE);\n" \ +" if (tid == 0) {\n" \ +" sum = dot((float4)(1.0f), tmp[0]);\n" \ +" sum += dot((float4)(1.0f), tmp[1]);\n" \ +" sum += dot((float4)(1.0f), tmp[2]);\n" \ +" sum += dot((float4)(1.0f), tmp[3]);\n" \ +" sum += dot((float4)(1.0f), tmp[4]);\n" \ +" sum += dot((float4)(1.0f), tmp[5]);\n" \ +" sum += dot((float4)(1.0f), tmp[6]);\n" \ +" sum += dot((float4)(1.0f), tmp[7]);\n" \ +" tmpx1[0] = divide_no_check(1.0f, sum);\n" \ +" }\n" \ +" barrier(CLK_LOCAL_MEM_FENCE);\n" \ +" sum = tmpx1[0];\n" \ +" for (size_t i = tid; i < C4; i += 32) {\n" \ +" float4 result = convert_float4(READ_IMAGE(input, smp_zero, (int2)(i, n)));\n" \ +" result = exp(result) * sum;\n" \ +" WRITE_IMAGEOUT(output, (int2)(i, n), OUT_FLT4(result));\n" \ +" }\n" \ +"}\n" \ +; diff --git a/mindspore-lite/src/runtime/kernel/opencl/cl/space_to_batch_nd.cl.inc b/mindspore-lite/src/runtime/kernel/opencl/cl/space_to_batch_nd.cl.inc new file mode 100644 index 0000000000000000000000000000000000000000..8ce587666119ea2c3f5ed51da72b58052f2fef85 --- /dev/null +++ b/mindspore-lite/src/runtime/kernel/opencl/cl/space_to_batch_nd.cl.inc @@ -0,0 +1,47 @@ +static const char *space_to_batch_nd_source ="\n" +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" \ +"__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" \ +"__kernel void space_to_batch_nd_NHWC4(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 src_size,\n" \ +" int4 dst_size, int2 block_size, int4 paddings) {\n" \ +" int X = get_global_id(0); // c\n" \ +" int Y = get_global_id(1); // w\n" \ +" int Z = get_global_id(2); // h * n_i\n" \ +" // (N,H*BH,W*BW,C) to (BH*BW*N,H,W,C)\n" \ +" int N_I = Z / dst_size.z;\n" \ +" Z = Z % dst_size.z;\n" \ +" if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z || N_I >= src_size.w) {\n" \ +" return;\n" \ +" }\n" \ +" for (int i = 0; i < block_size.x; ++i) {\n" \ +" for (int j = 0; j < block_size.y; ++j) {\n" \ +" int w_org = Y * block_size.y + j - paddings.z;\n" \ +" int h_org = Z * block_size.x + i - paddings.x;\n" \ +" FLT4 res_data = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" if (h_org >= 0 && h_org < src_size.z)\n" \ +" res_data = READ_IMAGE(src_data, smp_zero, (int2)(w_org * dst_size.x + X, N_I * src_size.z + h_org));\n" \ +" WRITE_IMAGE(dst_data, (int2)(Y * dst_size.x + X, ((i * block_size.y + j) * src_size.w + N_I) * dst_size.z + Z),\n" \ +" res_data);\n" \ +" }\n" \ +" }\n" \ +"}\n" \ +"__kernel void space_to_batch_nd_NC4HW4(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 src_size,\n" \ +" int4 dst_size, int2 block_size, int4 paddings) {\n" \ +" int X = get_global_id(0); // c\n" \ +" int Y = get_global_id(1); // w\n" \ +" int Z = get_global_id(2); // h\n" \ +" if (X >= dst_size.x || Y >= dst_size.y || Y >= dst_size.z) {\n" \ +" return;\n" \ +" }\n" \ +" for (int i = 0; i < block_size.x; ++i) {\n" \ +" for (int j = 0; j < block_size.y; ++j) {\n" \ +" int w_org = Y * block_size.y + j - paddings.z;\n" \ +" int h_org = Z * block_size.x + i - paddings.x;\n" \ +" FLT4 res_data = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" if (w_org >= 0 && w_org < src_size.y && h_org >= 0 && h_org < src_size.z) {\n" \ +" res_data = READ_IMAGE(src_data, smp_zero, (int2)(h_org * src_size.y + Y, X));\n" \ +" }\n" \ +" WRITE_IMAGE(dst_data, (int2)(Z * dst_size.y + Y, (i * block_size.y + j) * dst_size.x + X), res_data);\n" \ +" }\n" \ +" }\n" \ +"}\n" \ +; diff --git a/mindspore-lite/src/runtime/kernel/opencl/cl/space_to_depth.cl.inc b/mindspore-lite/src/runtime/kernel/opencl/cl/space_to_depth.cl.inc new file mode 100644 index 0000000000000000000000000000000000000000..a93b6273ff583c11fc633aaf9f1e93bfe5d475e6 --- /dev/null +++ b/mindspore-lite/src/runtime/kernel/opencl/cl/space_to_depth.cl.inc @@ -0,0 +1,114 @@ +static const char *space_to_depth_source ="\n" +"#ifdef cl_khr_fp16\n" \ +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" \ +"#endif\n" \ +"\n" \ +"#define C4NUM 4\n" \ +"__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" \ +"__kernel void SpaceToDepth(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 in_shape,\n" \ +" int4 out_shape, int block_size, int ci_size) {\n" \ +" int X = get_global_id(0); // C4\n" \ +" int Y = get_global_id(1); // W\n" \ +" int Z = get_global_id(2); // H * N\n" \ +" if (X >= out_shape.w || Y >= out_shape.z || Z >= out_shape.x * out_shape.y) return;\n" \ +" if (out_shape.y == 0 || ci_size == 0 || block_size == 0) return;\n" \ +" int N = Z / out_shape.y;\n" \ +" int H = Z % out_shape.y;\n" \ +" int co_base = X * C4NUM;\n" \ +" FLT result[C4NUM] = {0.f};\n" \ +" for (int i = 0; i < C4NUM; i++) {\n" \ +" int co = co_base + i;\n" \ +" int ci = co % ci_size;\n" \ +" int hw_block = co / ci_size;\n" \ +" int hi = H * block_size + hw_block / block_size;\n" \ +" int wi = Y * block_size + hw_block % block_size;\n" \ +" int ci4 = ci / C4NUM;\n" \ +" int ci4_ramainder = ci % C4NUM;\n" \ +" FLT4 tmp = READ_IMAGE(src_data, smp_zero, (int2)(wi * in_shape.w + ci4, N * in_shape.y + hi));\n" \ +" if (ci4_ramainder == 0) {\n" \ +" result[i] = tmp.x;\n" \ +" } else if (ci4_ramainder == 1) {\n" \ +" result[i] = tmp.y;\n" \ +" } else if (ci4_ramainder == 2) {\n" \ +" result[i] = tmp.z;\n" \ +" } else {\n" \ +" result[i] = tmp.w;\n" \ +" }\n" \ +" }\n" \ +" FLT4 result_flt4 = {result[0], result[1], result[2], result[3]};\n" \ +" WRITE_IMAGE(dst_data, (int2)(Y * out_shape.w + X, Z), result_flt4);\n" \ +"}\n" \ +"\n" \ +"__kernel void SpaceToDepthAlign(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 in_shape,\n" \ +" int4 out_shape, int block_size, int ci_size) {\n" \ +" int X = get_global_id(0); // C4\n" \ +" int Y = get_global_id(1); // W\n" \ +" int Z = get_global_id(2); // H * N\n" \ +" if (X >= out_shape.w || Y >= out_shape.z || Z >= out_shape.x * out_shape.y) return;\n" \ +" if (out_shape.y == 0 || in_shape.w == 0 || block_size == 0) return;\n" \ +"\n" \ +" int N = Z / out_shape.y;\n" \ +" int H = Z % out_shape.y;\n" \ +" int ni = N;\n" \ +" int ci = X % in_shape.w;\n" \ +" int hw_block = X / in_shape.w;\n" \ +" int hi = H * block_size + hw_block / block_size;\n" \ +" int wi = Y * block_size + hw_block % block_size;\n" \ +" WRITE_IMAGE(dst_data, (int2)(Y * out_shape.w + X, Z),\n" \ +" READ_IMAGE(src_data, smp_zero, (int2)(wi * in_shape.w + ci, ni * in_shape.y + hi)));\n" \ +"}\n" \ +"\n" \ +"__kernel void DepthToSpace(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 in_shape,\n" \ +" int4 out_shape, int block_size, int co_size) {\n" \ +" int X = get_global_id(0); // C4\n" \ +" int Y = get_global_id(1); // W\n" \ +" int Z = get_global_id(2); // H * N\n" \ +" if (X >= out_shape.w || Y >= out_shape.z || Z >= out_shape.x * out_shape.y) return;\n" \ +" if (out_shape.y == 0 || block_size == 0) return;\n" \ +" int N = Z / out_shape.y;\n" \ +" int H = Z % out_shape.y;\n" \ +" int co_base = X * C4NUM;\n" \ +" FLT result[C4NUM] = {0.f};\n" \ +" for (int i = 0; i < C4NUM; i++) {\n" \ +" int co = co_base + i;\n" \ +" int bh = H % block_size;\n" \ +" int hi = H / block_size;\n" \ +" int bw = Y % block_size;\n" \ +" int wi = Y / block_size;\n" \ +" int ci = (bh * block_size + bw) * co_size + co;\n" \ +" int ci4 = ci / C4NUM;\n" \ +" int ci4_ramainder = ci % C4NUM;\n" \ +" FLT4 tmp = READ_IMAGE(src_data, smp_zero, (int2)(wi * in_shape.w + ci4, N * in_shape.y + hi));\n" \ +" if (ci4_ramainder == 0) {\n" \ +" result[i] = tmp.x;\n" \ +" } else if (ci4_ramainder == 1) {\n" \ +" result[i] = tmp.y;\n" \ +" } else if (ci4_ramainder == 2) {\n" \ +" result[i] = tmp.z;\n" \ +" } else {\n" \ +" result[i] = tmp.w;\n" \ +" }\n" \ +" }\n" \ +" FLT4 result_flt4 = {result[0], result[1], result[2], result[3]};\n" \ +" WRITE_IMAGE(dst_data, (int2)(Y * out_shape.w + X, Z), result_flt4);\n" \ +"}\n" \ +"\n" \ +"__kernel void DepthToSpaceAlign(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 in_shape,\n" \ +" int4 out_shape, int block_size, int co_size) {\n" \ +" int X = get_global_id(0); // C4\n" \ +" int Y = get_global_id(1); // W\n" \ +" int Z = get_global_id(2); // H * N\n" \ +" if (X >= out_shape.w || Y >= out_shape.z || Z >= out_shape.x * out_shape.y) return;\n" \ +" if (out_shape.y == 0 || block_size == 0) return;\n" \ +" int N = Z / out_shape.y;\n" \ +" int H = Z % out_shape.y;\n" \ +" int ni = N;\n" \ +" int bh = H % block_size;\n" \ +" int hi = H / block_size;\n" \ +" int bw = Y % block_size;\n" \ +" int wi = Y / block_size;\n" \ +" int ci = (bh * block_size + bw) * out_shape.w + X;\n" \ +" WRITE_IMAGE(dst_data, (int2)(Y * out_shape.w + X, Z),\n" \ +" READ_IMAGE(src_data, smp_zero, (int2)(wi * in_shape.w + ci, ni * in_shape.y + hi)));\n" \ +"}\n" \ +; diff --git a/mindspore-lite/src/runtime/kernel/opencl/cl/sparse_to_dense.cl.inc b/mindspore-lite/src/runtime/kernel/opencl/cl/sparse_to_dense.cl.inc new file mode 100644 index 0000000000000000000000000000000000000000..2388a9ea25ae04a8b949875d96e28403dcb9bce5 --- /dev/null +++ b/mindspore-lite/src/runtime/kernel/opencl/cl/sparse_to_dense.cl.inc @@ -0,0 +1,52 @@ +static const char *sparse_to_dense_source ="\n" +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" \ +"#define C4NUM 4\n" \ +"__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" \ +"\n" \ +"__kernel void SparseToDenseScalar(__read_only image2d_t input, __global DTYPE *output, float weight, int2 inputshape,\n" \ +" int4 outputshape, float default_value, int stride_w, int inshapeindex1_dim) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= inputshape.x || Y >= inputshape.y) {\n" \ +" return;\n" \ +" }\n" \ +" int4 index_input = read_imagei(input, smp_zero, (int2)(Y, X));\n" \ +" int4 index_input_int = *((int4 *)&index_input);\n" \ +" int index = 0;\n" \ +" if (inshapeindex1_dim == 1) {\n" \ +" index = (index_input_int.x) * stride_w;\n" \ +" } else if (inshapeindex1_dim == 2) {\n" \ +" index = (index_input_int.x) * stride_w + (index_input_int.y);\n" \ +" } else if (inshapeindex1_dim == 3) {\n" \ +" index = (index_input_int.x) * stride_w + (index_input_int.y) * outputshape.w * C4NUM + (index_input_int.z);\n" \ +" } else {\n" \ +" index = (index_input_int.x) * outputshape.y * stride_w + (index_input_int.y) * stride_w +\n" \ +" (index_input_int.z) * outputshape.w * C4NUM + index_input_int.w;\n" \ +" }\n" \ +" output[index] = weight;\n" \ +"}\n" \ +"\n" \ +"__kernel void SparseToDenseVector(__read_only image2d_t input, __global DTYPE *output, __global float *weight_vector,\n" \ +" int2 inputshape, int4 outputshape, float default_value, int stride_w,\n" \ +" int inshapeindex1_dim) {\n" \ +" int X = get_global_id(0);\n" \ +" int Y = get_global_id(1);\n" \ +" if (X >= inputshape.x || Y >= inputshape.y) {\n" \ +" return;\n" \ +" }\n" \ +" int4 index_input = read_imagei(input, smp_zero, (int2)(Y, X));\n" \ +" int4 index_input_int = *((int4 *)&index_input);\n" \ +" int index = 0;\n" \ +" if (inshapeindex1_dim == 1) {\n" \ +" index = (index_input_int.x) * stride_w;\n" \ +" } else if (inshapeindex1_dim == 2) {\n" \ +" index = (index_input_int.x) * stride_w + index_input_int.y;\n" \ +" } else if (inshapeindex1_dim == 3) {\n" \ +" index = (index_input_int.x) * stride_w + (index_input_int.y) * outputshape.w * C4NUM + index_input_int.z;\n" \ +" } else {\n" \ +" index = (index_input_int.x) * outputshape.y * stride_w + (index_input_int.y) * stride_w +\n" \ +" (index_input_int.z) * outputshape.w * C4NUM + index_input_int.w;\n" \ +" }\n" \ +" output[index] = weight_vector[X];\n" \ +"}\n" \ +; diff --git a/mindspore-lite/src/runtime/kernel/opencl/cl/split.cl.inc b/mindspore-lite/src/runtime/kernel/opencl/cl/split.cl.inc new file mode 100644 index 0000000000000000000000000000000000000000..0d533abe7b14ba5843a562fab4cb4a20dba65121 --- /dev/null +++ b/mindspore-lite/src/runtime/kernel/opencl/cl/split.cl.inc @@ -0,0 +1,116 @@ +static const char *split_source ="\n" +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" \ +"__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" \ +"#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))\n" \ +"#define C4NUM 4\n" \ +"\n" \ +"#define CHECK_IDX_ALIGN \\\n" \ +" const int X = get_global_id(0); \\\n" \ +" const int Y = get_global_id(1); \\\n" \ +" const int Z = get_global_id(2); \\\n" \ +" if (X > in_shape.x * in_shape.y || Y > in_shape.z || Z > in_shape.w || in_shape.y == 0) { \\\n" \ +" return; \\\n" \ +" }\n" \ +"\n" \ +"#define ARGS_ALIGN \\\n" \ +" const int IN = X / in_shape.y; \\\n" \ +" const int IH = X % in_shape.y; \\\n" \ +" int coordinate_x = IN * in_shape.y + IH; \\\n" \ +" int coordinate_y = Y * in_shape.w + Z; \\\n" \ +" FLT4 result = READ_IMAGE(input, smp_none, (int2)(coordinate_y, coordinate_x));\n" \ +"\n" \ +"__kernel void split_out2_axis3(__read_only image2d_t input, __write_only image2d_t output1,\n" \ +" __write_only image2d_t output2, __global int *split_sizes_, int4 in_shape,\n" \ +" int4 out_shape1, int4 out_shape2) {\n" \ +" CHECK_IDX_ALIGN;\n" \ +" ARGS_ALIGN;\n" \ +" int boundary = UP_DIV(split_sizes_[0], C4NUM);\n" \ +" if (Z < boundary) {\n" \ +" coordinate_x = IN * out_shape1.y + IH;\n" \ +" coordinate_y = Y * out_shape1.w + Z;\n" \ +" WRITE_IMAGE(output1, (int2)(coordinate_y, coordinate_x), result);\n" \ +" } else {\n" \ +" coordinate_x = IN * out_shape2.y + IH;\n" \ +" coordinate_y = Y * out_shape2.w + Z - boundary;\n" \ +" WRITE_IMAGE(output2, (int2)(coordinate_y, coordinate_x), result);\n" \ +" }\n" \ +"}\n" \ +"\n" \ +"__kernel void split_out2_axis2(__read_only image2d_t input, __write_only image2d_t output1,\n" \ +" __write_only image2d_t output2, __global int *split_sizes_, int4 in_shape,\n" \ +" int4 out_shape1, int4 out_shape2) {\n" \ +" CHECK_IDX_ALIGN;\n" \ +" ARGS_ALIGN;\n" \ +" if (Y < split_sizes_[0]) {\n" \ +" coordinate_x = IN * out_shape1.y + IH;\n" \ +" coordinate_y = Y * out_shape1.w + Z;\n" \ +" WRITE_IMAGE(output1, (int2)(coordinate_y, coordinate_x), result);\n" \ +" } else {\n" \ +" coordinate_x = IN * out_shape2.y + IH;\n" \ +" coordinate_y = (Y - split_sizes_[0]) * out_shape2.w + Z;\n" \ +" WRITE_IMAGE(output2, (int2)(coordinate_y, coordinate_x), result);\n" \ +" }\n" \ +"}\n" \ +"\n" \ +"__kernel void split_out2_axis1(__read_only image2d_t input, __write_only image2d_t output1,\n" \ +" __write_only image2d_t output2, __global int *split_sizes_, int4 in_shape,\n" \ +" int4 out_shape1, int4 out_shape2) {\n" \ +" CHECK_IDX_ALIGN;\n" \ +" ARGS_ALIGN;\n" \ +" if (IH < split_sizes_[0]) {\n" \ +" coordinate_x = IN * out_shape1.y + IH;\n" \ +" coordinate_y = Y * out_shape1.w + Z;\n" \ +" WRITE_IMAGE(output1, (int2)(coordinate_y, coordinate_x), result);\n" \ +" } else {\n" \ +" coordinate_x = IN * out_shape2.y + IH - split_sizes_[0];\n" \ +" coordinate_y = Y * out_shape2.w + Z;\n" \ +" WRITE_IMAGE(output2, (int2)(coordinate_y, coordinate_x), result);\n" \ +" }\n" \ +"}\n" \ +"\n" \ +"// UnAlign in Axis C for concat\n" \ +"#define CHECK_IDX_UNALIGN \\\n" \ +" const int X = get_global_id(0); \\\n" \ +" const int Y = get_global_id(1); \\\n" \ +" if (X >= in_shape.x * in_shape.y || Y >= in_shape.z || in_shape.y == 0) { \\\n" \ +" return; \\\n" \ +" }\n" \ +"\n" \ +"#define ARGS_UNALIGN \\\n" \ +" const int IN = X / in_shape.y, IH = X % in_shape.y; \\\n" \ +" const int IW = Y; \\\n" \ +" const int Align_inShape = UP_DIV(in_shape.w, C4NUM); \\\n" \ +" int index_input = (IN * in_shape.y + IH) * stride_w + IW * Align_inShape * C4NUM;\n" \ +"\n" \ +"int dosplit(__global FLT *input, __write_only image2d_t output, int4 out_shape, int IN, int IH, int IW,\n" \ +" int index_input) {\n" \ +" int Remainder = out_shape.w % C4NUM;\n" \ +" int coordinate_x = IN * out_shape.y + IH;\n" \ +" int align_w = UP_DIV(out_shape.w, C4NUM);\n" \ +" for (int i = 0; i < align_w; ++i) {\n" \ +" int coordinate_y = IW * align_w + i;\n" \ +" if ((i + 1) * C4NUM <= out_shape.w) {\n" \ +" FLT4 result = {input[index_input], input[index_input + 1], input[index_input + 2], input[index_input + 3]};\n" \ +" WRITE_IMAGE(output, (int2)(coordinate_y, coordinate_x), result);\n" \ +" index_input += 4;\n" \ +" } else {\n" \ +" FLT result_temp[4] = {};\n" \ +" for (int j = 0; j < Remainder; ++j) {\n" \ +" result_temp[j] = input[index_input++];\n" \ +" }\n" \ +" FLT4 result = {result_temp[0], result_temp[1], result_temp[2], result_temp[3]};\n" \ +" WRITE_IMAGE(output, (int2)(coordinate_y, coordinate_x), result);\n" \ +" }\n" \ +" }\n" \ +" return index_input;\n" \ +"}\n" \ +"\n" \ +"__kernel void split_out2_axis3_unalign(__global FLT *input, __write_only image2d_t output1,\n" \ +" __write_only image2d_t output2, __global int *split_sizes_, int4 in_shape,\n" \ +" int4 out_shape1, int4 out_shape2, int stride_w) {\n" \ +" CHECK_IDX_UNALIGN;\n" \ +" ARGS_UNALIGN;\n" \ +" index_input = dosplit(input, output1, out_shape1, IN, IH, IW, index_input);\n" \ +" index_input = dosplit(input, output2, out_shape2, IN, IH, IW, index_input);\n" \ +"}\n" \ +; diff --git a/mindspore-lite/src/runtime/kernel/opencl/cl/stack.cl.inc b/mindspore-lite/src/runtime/kernel/opencl/cl/stack.cl.inc new file mode 100644 index 0000000000000000000000000000000000000000..ce000d03e3e97f4928b9b5abc54c54852e3dbcda --- /dev/null +++ b/mindspore-lite/src/runtime/kernel/opencl/cl/stack.cl.inc @@ -0,0 +1,110 @@ +static const char *stack_source ="\n" +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" \ +"#define INT4 int4\n" \ +"#define C4NUM 4\n" \ +"__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;\n" \ +"#define CHECK_IDX_FOR_STACK \\\n" \ +" int X = get_global_id(0); \\\n" \ +" int Y = get_global_id(1); \\\n" \ +" int Z = get_global_id(2); \\\n" \ +" if (X >= output_shape.x * output_shape.y || Y >= output_shape.z || Z >= output_shape.w) { \\\n" \ +" return; \\\n" \ +" } \\\n" \ +" FLT4 result;\n" \ +"\n" \ +"// input -1D\n" \ +"__kernel void stack_2input_3axis_1inshape(__read_only image2d_t input0, __read_only image2d_t input1,\n" \ +" __write_only image2d_t output, int4 input_shape, int4 output_shape) {\n" \ +" int X = get_global_id(0); // N*H\n" \ +" int Y = get_global_id(1); // W*C\n" \ +" if (X >= output_shape.x * output_shape.y || Y >= output_shape.z * output_shape.w) {\n" \ +" return;\n" \ +" }\n" \ +" FLT4 result1 = READ_IMAGE(input0, smp_none, (int2)(X, 0));\n" \ +" FLT result1_temp[4] = {result1.x, result1.y, result1.z, result1.w};\n" \ +" FLT4 result2 = READ_IMAGE(input1, smp_none, (int2)(X, 0));\n" \ +" FLT result2_temp[4] = {result2.x, result2.y, result2.z, result2.w};\n" \ +" for (int i = 0; i < C4NUM; ++i) {\n" \ +" FLT4 result = {result1_temp[i], result2_temp[i], 0, 0};\n" \ +" WRITE_IMAGE(output, (int2)(Y, (X * C4NUM + i)), result);\n" \ +" }\n" \ +"}\n" \ +"\n" \ +"// input -2D -axis = 1\n" \ +"__kernel void stack_2input_1axis_2inshape(__read_only image2d_t input0, __read_only image2d_t input1,\n" \ +" __write_only image2d_t output, int4 input_shape, int4 output_shape) {\n" \ +" CHECK_IDX_FOR_STACK;\n" \ +" int IN = X / output_shape.y;\n" \ +" int IH = X % output_shape.y;\n" \ +" int boundary0 = input_shape.z;\n" \ +" if (Y < boundary0) {\n" \ +" int coordinate_x = Y * input_shape.w + Z;\n" \ +" int coordinate_y = IN * input_shape.y + IH;\n" \ +" result = READ_IMAGE(input0, smp_none, (int2)(coordinate_x, coordinate_y));\n" \ +" } else {\n" \ +" int coordinate_x = (Y - boundary0) * input_shape.w + Z;\n" \ +" int coordinate_y = IN * input_shape.y + IH;\n" \ +" result = READ_IMAGE(input1, smp_none, (int2)(coordinate_x, coordinate_y));\n" \ +" }\n" \ +" WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (IN * output_shape.y + IH)), result);\n" \ +"}\n" \ +"\n" \ +"// input -3D -axis = 1\n" \ +"__kernel void stack_2input_1axis_3inshape(__read_only image2d_t input0, __read_only image2d_t input1,\n" \ +" __write_only image2d_t output, int4 input_shape, int4 output_shape) {\n" \ +" CHECK_IDX_FOR_STACK;\n" \ +" int IN = X / output_shape.y;\n" \ +" int IH = X % output_shape.y;\n" \ +" int boundary0 = input_shape.y;\n" \ +" if (IH < boundary0) {\n" \ +" int coordinate_x = Y * input_shape.w + Z;\n" \ +" int coordinate_y = IN * input_shape.y + IH;\n" \ +" result = READ_IMAGE(input0, smp_none, (int2)(coordinate_x, coordinate_y));\n" \ +" } else {\n" \ +" int coordinate_x = Y * input_shape.w + Z;\n" \ +" int coordinate_y = IN * input_shape.y + IH - boundary0;\n" \ +" result = READ_IMAGE(input1, smp_none, (int2)(coordinate_x, coordinate_y));\n" \ +" }\n" \ +" WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, (IN * output_shape.y + IH)), result);\n" \ +"}\n" \ +"\n" \ +"// input -3D -axis = 2\n" \ +"__kernel void stack_2input_2axis_3inshape(__read_only image2d_t input0, __read_only image2d_t input1,\n" \ +" __write_only image2d_t output, int4 input_shape, int4 output_shape) {\n" \ +" CHECK_IDX_FOR_STACK;\n" \ +" int boundary0 = input_shape.y;\n" \ +" int IN = X / output_shape.y;\n" \ +" int IW = X % output_shape.y;\n" \ +" int IC = Z;\n" \ +" int coordinate_x = IW * input_shape.w + IC;\n" \ +" int coordinate_y = IN * input_shape.y;\n" \ +" if (Y < boundary0) {\n" \ +" result = READ_IMAGE(input0, smp_none, (int2)(coordinate_x, coordinate_y));\n" \ +" } else {\n" \ +" result = READ_IMAGE(input1, smp_none, (int2)(coordinate_x, coordinate_y));\n" \ +" }\n" \ +" WRITE_IMAGE(output, (int2)((Y)*output_shape.w + Z, IN * output_shape.y + IW), result);\n" \ +"}\n" \ +"\n" \ +"// input -3D -axis = 3 and input -2D -axis = 2 boundary stack\n" \ +"__kernel void stack_2input_boundary(__global float *input0, __global float *input1, __global float *output,\n" \ +" int4 input_shape, int4 output_shape, int2 stride_w) {\n" \ +" int X = get_global_id(0); // N\n" \ +" int Y = get_global_id(1); // H\n" \ +" if (X >= output_shape.x || Y >= output_shape.y) {\n" \ +" return;\n" \ +" }\n" \ +" int IW = output_shape.z;\n" \ +" int Align_out = output_shape.w * C4NUM;\n" \ +" int Align_in = input_shape.w * C4NUM;\n" \ +" int index_out = X * output_shape.y * stride_w.x + Y * stride_w.x;\n" \ +" int index_in = X * input_shape.y * stride_w.y + Y * Align_in;\n" \ +" for (int iw = 0; iw < IW; iw++) {\n" \ +" int index_out_tmp = index_out + iw * Align_out;\n" \ +" int index_in_tmp = index_in + iw;\n" \ +" output[index_out_tmp] = input0[index_in_tmp];\n" \ +" index_out_tmp++;\n" \ +" output[index_out_tmp] = input1[index_in_tmp];\n" \ +" }\n" \ +"}\n" \ +; diff --git a/mindspore-lite/src/runtime/kernel/opencl/cl/strassen.cl.inc b/mindspore-lite/src/runtime/kernel/opencl/cl/strassen.cl.inc new file mode 100644 index 0000000000000000000000000000000000000000..439e10be06c09ef5a7693cf19c79eb763b00d490 --- /dev/null +++ b/mindspore-lite/src/runtime/kernel/opencl/cl/strassen.cl.inc @@ -0,0 +1,130 @@ +static const char *strassen_source ="\n" +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" \ +"#define C4NUM 4\n" \ +"#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))\n" \ +"__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" \ +"__kernel void MatMul_Strassen_NHWC4_2d(__read_only image2d_t input, __write_only image2d_t output, __global FLT *weight,\n" \ +" int4 in_shape, int4 out_shape) {\n" \ +" int gidx = get_global_id(0); // CO4\n" \ +" int gidz = get_global_id(2); // N\n" \ +" int lidx = get_local_id(0);\n" \ +" int lidy = get_local_id(1);\n" \ +" int ci4 = UP_DIV(in_shape.w, C4NUM);\n" \ +" int co4 = UP_DIV(out_shape.w, C4NUM);\n" \ +" int weight_stride = in_shape.w;\n" \ +" FLT sum[4] = {0.0f, 0.0f, 0.0f, 0.0f};\n" \ +" bool inside = gidx < co4 && gidz < weight_stride;\n" \ +" for (uint i = lidy; i < ci4 && inside; i += 4) {\n" \ +" FLT4 result_in = READ_IMAGE(input, smp_zero, (int2)(i, gidz));\n" \ +" int index_x = (i * C4NUM) * weight_stride + gidx * C4NUM;\n" \ +" int index_y = index_x + weight_stride;\n" \ +" int index_z = index_y + weight_stride;\n" \ +" int index_w = index_z + weight_stride;\n" \ +" for (int j = 0; j < C4NUM; ++j) {\n" \ +" FLT4 result_weight = {weight[index_x + j], weight[index_y + j], weight[index_z + j], weight[index_w + j]};\n" \ +" sum[j] += dot(result_in, result_weight);\n" \ +" }\n" \ +" }\n" \ +" FLT4 result = {sum[0], sum[1], sum[2], sum[3]};\n" \ +" __local FLT4 temp[32][4];\n" \ +" temp[lidx][lidy] = result;\n" \ +" barrier(CLK_LOCAL_MEM_FENCE);\n" \ +" if (lidy == 0 && inside) {\n" \ +" result += temp[lidx][1];\n" \ +" result += temp[lidx][2];\n" \ +" result += temp[lidx][3];\n" \ +" WRITE_IMAGE(output, (int2)(gidx, gidz), result);\n" \ +" }\n" \ +"}\n" \ +"\n" \ +"// flag = 0 : represent add, otherwise, sub\n" \ +"__kernel void MatMul_BUF_Add_Sub_2(__global FLT4 *input, __global FLT4 *output, int4 shape, int4 offset, int flag) {\n" \ +" int gidy = get_global_id(0); // W*C4\n" \ +" int gidx = get_global_id(2); // N*H\n" \ +" if (gidx >= shape.x * shape.y || gidy >= shape.z * shape.w) {\n" \ +" return;\n" \ +" }\n" \ +" int ci_co_4 = shape.w;\n" \ +" const int origin_shape = 2 * ci_co_4;\n" \ +" int index_1 = (gidx + offset.x) * origin_shape + gidy + offset.y;\n" \ +" int index_2 = (gidx + offset.z) * origin_shape + gidy + offset.w;\n" \ +" FLT4 result1 = input[index_1];\n" \ +" FLT4 result2 = input[index_2];\n" \ +" FLT4 result;\n" \ +" if (flag == 0) {\n" \ +" result = result1 + result2;\n" \ +" } else {\n" \ +" result = result1 - result2;\n" \ +" }\n" \ +" output[gidx * ci_co_4 + gidy] = result;\n" \ +"}\n" \ +"\n" \ +"__kernel void MatMul_IMG_Add_Sub_2(__read_only image2d_t input, __write_only image2d_t output, int4 shape, int4 offset,\n" \ +" int flag) {\n" \ +" int gidy = get_global_id(0); // W*C4\n" \ +" int gidx = get_global_id(2); // N*H\n" \ +" if (gidx >= shape.x * shape.y || gidy >= shape.z * shape.w) {\n" \ +" return;\n" \ +" }\n" \ +" FLT4 result1 = READ_IMAGE(input, smp_zero, (int2)(gidy + offset.y, gidx + offset.x));\n" \ +" FLT4 result2 = READ_IMAGE(input, smp_zero, (int2)(gidy + offset.w, gidx + offset.z));\n" \ +" FLT4 result;\n" \ +" if (flag == 0) {\n" \ +" result = result1 + result2;\n" \ +" } else {\n" \ +" result = result1 - result2;\n" \ +" }\n" \ +" WRITE_IMAGE(output, (int2)(gidy, gidx), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void Strassen_Back_Result(__read_only image2d_t input1, __read_only image2d_t input2,\n" \ +" __read_only image2d_t input3, __read_only image2d_t input4,\n" \ +" __read_only image2d_t input5, __read_only image2d_t input6,\n" \ +" __read_only image2d_t input7, __write_only image2d_t output, int4 shape) {\n" \ +" int gidy = get_global_id(0); // W*C4\n" \ +" int gidx = get_global_id(2); // N*H\n" \ +" int offset_x = shape.x * shape.y, offset_y = shape.z * shape.w;\n" \ +" if (gidx >= offset_x || gidy >= offset_y) {\n" \ +" return;\n" \ +" }\n" \ +" FLT4 result_M1 = READ_IMAGE(input1, smp_zero, (int2)(gidy, gidx));\n" \ +" FLT4 result_M2 = READ_IMAGE(input2, smp_zero, (int2)(gidy, gidx));\n" \ +" FLT4 result_M3 = READ_IMAGE(input3, smp_zero, (int2)(gidy, gidx));\n" \ +" FLT4 result_M4 = READ_IMAGE(input4, smp_zero, (int2)(gidy, gidx));\n" \ +" FLT4 result_M5 = READ_IMAGE(input5, smp_zero, (int2)(gidy, gidx));\n" \ +" FLT4 result_M6 = READ_IMAGE(input6, smp_zero, (int2)(gidy, gidx));\n" \ +" FLT4 result_M7 = READ_IMAGE(input7, smp_zero, (int2)(gidy, gidx));\n" \ +" FLT4 result1 = result_M4 + result_M5 + result_M6 - result_M2; // C11\n" \ +" FLT4 result2 = result_M1 + result_M2; // C12\n" \ +" FLT4 result3 = result_M3 + result_M4; // C21\n" \ +" FLT4 result4 = result_M1 + result_M5 - result_M3 - result_M7; // C22\n" \ +" WRITE_IMAGE(output, (int2)(gidy, gidx), result1);\n" \ +" WRITE_IMAGE(output, (int2)(gidy + offset_y, gidx), result2);\n" \ +" WRITE_IMAGE(output, (int2)(gidy, gidx + offset_x), result3);\n" \ +" WRITE_IMAGE(output, (int2)(gidy + offset_y, gidx + offset_x), result4);\n" \ +"}\n" \ +"\n" \ +"__kernel void MatMul_IMG_Filled(__read_only image2d_t input, __write_only image2d_t output, int4 shape, int2 offset) {\n" \ +" int gidy = get_global_id(0); // W*C4\n" \ +" int gidx = get_global_id(2); // N*H\n" \ +" if (gidx >= shape.x * shape.y || gidy >= shape.z * shape.w) {\n" \ +" return;\n" \ +" }\n" \ +" FLT4 result = READ_IMAGE(input, smp_zero, (int2)(gidy + offset.y, gidx + offset.x));\n" \ +" WRITE_IMAGE(output, (int2)(gidy, gidx), result);\n" \ +"}\n" \ +"\n" \ +"__kernel void MatMul_BUF_Filled(__global FLT4 *input, __global FLT4 *output, int4 shape, int2 offset) {\n" \ +" int gidy = get_global_id(0); // W*C4\n" \ +" int gidx = get_global_id(2); // N*H\n" \ +" if (gidx >= shape.x * shape.y || gidy >= shape.z * shape.w) {\n" \ +" return;\n" \ +" }\n" \ +" int stride_out = shape.z * shape.w;\n" \ +" int index_out = gidx * stride_out + gidy;\n" \ +" const int stride_origin = 2 * stride_out;\n" \ +" int index_in = (gidx + offset.x) * stride_origin + gidy + offset.y;\n" \ +" FLT4 result = input[index_in];\n" \ +" output[index_out] = result;\n" \ +"}\n" \ +; diff --git a/mindspore-lite/src/runtime/kernel/opencl/cl/strided_slice.cl.inc b/mindspore-lite/src/runtime/kernel/opencl/cl/strided_slice.cl.inc new file mode 100644 index 0000000000000000000000000000000000000000..98a48f387064476f737c685cfbc1485c0ee92ff8 --- /dev/null +++ b/mindspore-lite/src/runtime/kernel/opencl/cl/strided_slice.cl.inc @@ -0,0 +1,61 @@ +static const char *strided_slice_source ="\n" +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" \ +"\n" \ +"__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;\n" \ +"\n" \ +"__kernel void strided_slice(__read_only image2d_t input, __write_only image2d_t output, int4 input_shape,\n" \ +" int4 output_shape, int2 io_slices, int4 begin, int4 stride, int4 size) {\n" \ +" int IN = input_shape.x, IH = input_shape.y, IW = input_shape.z, CI = input_shape.w;\n" \ +" int ON = output_shape.x, OH = output_shape.y, OW = output_shape.z, CO = output_shape.w;\n" \ +" int CI_SLICES = io_slices.x, CO_SLICES = io_slices.y;\n" \ +" int on_oh = get_global_id(0);\n" \ +" int ow = get_global_id(1);\n" \ +" int co_slice = get_global_id(2);\n" \ +" int on = on_oh / OH;\n" \ +" int oh = on_oh % OH;\n" \ +" if (on >= ON || oh >= OH || ow >= OW || co_slice >= CO_SLICES) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" FLT tmp[4];\n" \ +" for (int i = 0; i < 4; ++i) {\n" \ +" // output_shape idx -> size idx. because squeeze(output_shape)=squeeze(size)\n" \ +" // for example:\n" \ +" // python code: B = A[1, 1:16, 2:16, 3:16]\n" \ +" // input_shape = [16, 16, 16, 16]\n" \ +" // begin = [ 1, 1, 2, 3]\n" \ +" // end = [ 2, 16, 16, 16]\n" \ +" // stride = [ 1, 1, 1, 1]\n" \ +" // size = [ 1, 15, 14, 13] = ceil((end - begin) / stride)\n" \ +" // output_shape = [ 15, 14, 13]\n" \ +" int idx = ((on * OH + oh) * OW + ow) * CO + co_slice * 4 + i;\n" \ +" int co_ = idx % size.w;\n" \ +" idx /= size.w;\n" \ +" int ow_ = idx % size.z;\n" \ +" idx /= size.z;\n" \ +" int oh_ = idx % size.y;\n" \ +" idx /= size.y;\n" \ +" int on_ = idx;\n" \ +"\n" \ +" int in = begin.x + stride.x * on_;\n" \ +" int ih = begin.y + stride.y * oh_;\n" \ +" int iw = begin.z + stride.z * ow_;\n" \ +" int ci = begin.w + stride.w * co_;\n" \ +"\n" \ +" FLT4 src = READ_IMAGE(input, smp_none, (int2)(iw * CI_SLICES + ci / 4, in * IH + ih));\n" \ +" int offset = ci % 4;\n" \ +" if (offset == 0) {\n" \ +" tmp[i] = src.x;\n" \ +" } else if (offset == 1) {\n" \ +" tmp[i] = src.y;\n" \ +" } else if (offset == 2) {\n" \ +" tmp[i] = src.z;\n" \ +" } else {\n" \ +" tmp[i] = src.w;\n" \ +" }\n" \ +" }\n" \ +"\n" \ +" FLT4 out = (FLT4)(tmp[0], tmp[1], tmp[2], tmp[3]);\n" \ +" WRITE_IMAGE(output, (int2)(ow * CO_SLICES + co_slice, on_oh), out);\n" \ +"}\n" \ +; diff --git a/mindspore-lite/src/runtime/kernel/opencl/cl/to_format.cl.inc b/mindspore-lite/src/runtime/kernel/opencl/cl/to_format.cl.inc new file mode 100644 index 0000000000000000000000000000000000000000..f4f5a5d818d27a39d6e2c05e4ac2fd41228b9d51 --- /dev/null +++ b/mindspore-lite/src/runtime/kernel/opencl/cl/to_format.cl.inc @@ -0,0 +1,85 @@ +static const char *to_format_source ="\n" +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" \ +"__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" \ +"\n" \ +"#define BUF_to_IMG(src_dtype, dst_dtype, SRC_TYPE, DST_TYPE, WRITE_IMAGE_OUT) \\\n" \ +" __kernel void BUF_to_IMG_##src_dtype##_##dst_dtype(__global SRC_TYPE##4 * src_data, __write_only image2d_t dst_data, \\\n" \ +" int4 size, int4 shape) { \\\n" \ +" int X = get_global_id(0); \\\n" \ +" int Y = get_global_id(1); \\\n" \ +" int Z = get_global_id(2); \\\n" \ +" if (X >= size.x || Y >= size.y || Z >= size.z) { \\\n" \ +" return; \\\n" \ +" } \\\n" \ +" DST_TYPE##4 data = (DST_TYPE##4)(0.f); \\\n" \ +" int offset = (X * shape.z + Y) * shape.w + Z * 4; \\\n" \ +" __global SRC_TYPE *src_addr = (__global SRC_TYPE *)src_data; \\\n" \ +" src_addr += offset; \\\n" \ +" if ((Z + 1) * 4 <= shape.w) { \\\n" \ +" data = convert_##DST_TYPE##4(((__global SRC_TYPE##4 *)src_addr)[0]); \\\n" \ +" } else { \\\n" \ +" if ((shape.w - Z * 4) >= 1) { \\\n" \ +" data.x = (DST_TYPE)src_addr[0]; \\\n" \ +" } \\\n" \ +" if ((shape.w - Z * 4) >= 2) { \\\n" \ +" data.y = (DST_TYPE)src_addr[1]; \\\n" \ +" } \\\n" \ +" if ((shape.w - Z * 4) >= 3) { \\\n" \ +" data.z = (DST_TYPE)src_addr[2]; \\\n" \ +" } \\\n" \ +" } \\\n" \ +" if (size.y * size.z <= MAX_IMAGE2D_WIDTH) \\\n" \ +" WRITE_IMAGE_OUT(dst_data, (int2)(Y * size.z + Z, X), data); \\\n" \ +" else \\\n" \ +" WRITE_IMAGE_OUT(dst_data, (int2)(Z, X * size.y + Y), data); \\\n" \ +" }\n" \ +"\n" \ +"// BUF_to_IMG(src_dtype, dst_dtype, SRC_TYPE, DST_TYPE, WRITE_IMAGE_OUT)\n" \ +"BUF_to_IMG(float32, float32, float, float, write_imagef);\n" \ +"BUF_to_IMG(float32, float16, float, half, write_imageh);\n" \ +"BUF_to_IMG(float16, float16, half, half, write_imageh);\n" \ +"BUF_to_IMG(int32, int32, int, int, write_imagei);\n" \ +"BUF_to_IMG(uint32, uint32, int, int, write_imagei);\n" \ +"BUF_to_IMG(int8, int8, char, int, write_imagei);\n" \ +"\n" \ +"#define IMG_to_BUF(src_dtype, dst_dtype, SRC_TYPE, DST_TYPE, READ_IMAGE_IN) \\\n" \ +" __kernel void IMG_to_BUF_##src_dtype##_##dst_dtype(__read_only image2d_t src_data, __global DST_TYPE##4 * dst_data, \\\n" \ +" int4 size, int4 shape) { \\\n" \ +" int X = get_global_id(0); \\\n" \ +" int Y = get_global_id(1); \\\n" \ +" int Z = get_global_id(2); \\\n" \ +" if (X >= size.x || Y >= size.y || Z >= size.z) { \\\n" \ +" return; \\\n" \ +" } \\\n" \ +" DST_TYPE##4 data; \\\n" \ +" if (size.y * size.z <= MAX_IMAGE2D_WIDTH) \\\n" \ +" data = convert_##DST_TYPE##4(READ_IMAGE_IN(src_data, smp_zero, (int2)(Y * size.z + Z, X))); \\\n" \ +" else \\\n" \ +" data = convert_##DST_TYPE##4(READ_IMAGE_IN(src_data, smp_zero, (int2)(Z, X * size.y + Y))); \\\n" \ +" int offset = (X * shape.z + Y) * shape.w + Z * 4; \\\n" \ +" __global DST_TYPE *dst_addr = (__global DST_TYPE *)dst_data; \\\n" \ +" dst_addr += offset; \\\n" \ +" if ((Z + 1) * 4 <= shape.w) { \\\n" \ +" ((__global DST_TYPE##4 *)dst_addr)[0] = data; \\\n" \ +" } else { \\\n" \ +" if (shape.w - Z * 4 >= 1) { \\\n" \ +" dst_addr[0] = data.x; \\\n" \ +" } \\\n" \ +" if (shape.w - Z * 4 >= 2) { \\\n" \ +" dst_addr[1] = data.y; \\\n" \ +" } \\\n" \ +" if (shape.w - Z * 4 >= 3) { \\\n" \ +" dst_addr[2] = data.z; \\\n" \ +" } \\\n" \ +" } \\\n" \ +" }\n" \ +"\n" \ +"// IMG_to_BUF(src_dtype, dst_dtype, SRC_TYPE, DST_TYPE, READ_IMAGE_IN)\n" \ +"IMG_to_BUF(float32, float32, float, float, read_imagef);\n" \ +"IMG_to_BUF(float16, float32, half, float, read_imageh);\n" \ +"IMG_to_BUF(float16, float16, half, half, read_imageh);\n" \ +"IMG_to_BUF(int32, int32, int, int, read_imagei);\n" \ +"IMG_to_BUF(uint32, uint32, int, int, read_imagei);\n" \ +"IMG_to_BUF(int32, float32, int, float, read_imagei);\n" \ +"IMG_to_BUF(int8, int8, char, char, read_imagei);\n" \ +; diff --git a/mindspore-lite/src/runtime/kernel/opencl/cl/transpose.cl.inc b/mindspore-lite/src/runtime/kernel/opencl/cl/transpose.cl.inc new file mode 100644 index 0000000000000000000000000000000000000000..089a6a9af5f068988b2ee0b50dce677ef83f7fd9 --- /dev/null +++ b/mindspore-lite/src/runtime/kernel/opencl/cl/transpose.cl.inc @@ -0,0 +1,266 @@ +static const char *transpose_source ="\n" +"#ifdef cl_khr_fp16\n" \ +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" \ +"#endif\n" \ +"#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))\n" \ +"__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" \ +"__kernel void transpose_0312_NHWC4(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 shape) {\n" \ +" int X = get_global_id(0); // H4, C4 for src\n" \ +" int Y = get_global_id(1); // W, H for src\n" \ +" int Z = get_global_id(2); // C4, W4 for src\n" \ +" if (4 * X >= shape.y || Y >= shape.z || 4 * Z >= shape.w) {\n" \ +" return;\n" \ +" }\n" \ +" int H4 = UP_DIV(shape.y, 4);\n" \ +" int C4 = UP_DIV(shape.w, 4);\n" \ +" FLT4 src0 = READ_IMAGE(src_data, smp_zero, (int2)(4 * Z * H4 + X, Y));\n" \ +" FLT4 src1 = (FLT4)0.f;\n" \ +" if (4 * Z + 1 < shape.w) {\n" \ +" src1 = READ_IMAGE(src_data, smp_zero, (int2)((4 * Z + 1) * H4 + X, Y));\n" \ +" }\n" \ +" FLT4 src2 = (FLT4)0.f;\n" \ +" if (4 * Z + 2 < shape.w) {\n" \ +" src2 = READ_IMAGE(src_data, smp_zero, (int2)((4 * Z + 2) * H4 + X, Y));\n" \ +" }\n" \ +" FLT4 src3 = (FLT4)0.f;\n" \ +" if (4 * Z + 3 < shape.w) {\n" \ +" src3 = READ_IMAGE(src_data, smp_zero, (int2)((4 * Z + 3) * H4 + X, Y));\n" \ +" }\n" \ +" FLT4 dst0 = (FLT4)(src0.x, src1.x, src2.x, src3.x);\n" \ +" FLT4 dst1 = (FLT4)(src0.y, src1.y, src2.y, src3.y);\n" \ +" FLT4 dst2 = (FLT4)(src0.z, src1.z, src2.z, src3.z);\n" \ +" FLT4 dst3 = (FLT4)(src0.w, src1.w, src2.w, src3.w);\n" \ +" WRITE_IMAGE(dst_data, (int2)(Y * C4 + Z, 4 * X), dst0);\n" \ +" if (4 * X + 1 < shape.y) {\n" \ +" WRITE_IMAGE(dst_data, (int2)(Y * C4 + Z, 4 * X + 1), dst1);\n" \ +" }\n" \ +" if (4 * X + 2 < shape.y) {\n" \ +" WRITE_IMAGE(dst_data, (int2)(Y * C4 + Z, 4 * X + 2), dst2);\n" \ +" }\n" \ +" if (4 * X + 3 < shape.y) {\n" \ +" WRITE_IMAGE(dst_data, (int2)(Y * C4 + Z, 4 * X + 3), dst3);\n" \ +" }\n" \ +"}\n" \ +"\n" \ +"__kernel void transpose_0312_NC4HW4(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 shape) {\n" \ +" int X = get_global_id(0); // H4, C4 for src\n" \ +" int Y = get_global_id(1); // W, H for src\n" \ +" int Z = get_global_id(2); // C4, W4 for src\n" \ +" if (4 * X >= shape.y || Y >= shape.z || 4 * Z >= shape.w) {\n" \ +" return;\n" \ +" }\n" \ +" FLT4 src0 = READ_IMAGE(src_data, smp_zero, (int2)(4 * Z, X * shape.z + Y));\n" \ +" FLT4 src1 = (FLT4)0.f;\n" \ +" if (4 * Z + 1 < shape.w) {\n" \ +" src1 = READ_IMAGE(src_data, smp_zero, (int2)(4 * Z + 1, X * shape.z + Y));\n" \ +" }\n" \ +" FLT4 src2 = (FLT4)0.f;\n" \ +" if (4 * Z + 2 < shape.w) {\n" \ +" src2 = READ_IMAGE(src_data, smp_zero, (int2)(4 * Z + 2, X * shape.z + Y));\n" \ +" }\n" \ +" FLT4 src3 = (FLT4)0.f;\n" \ +" if (4 * Z + 3 < shape.w) {\n" \ +" src3 = READ_IMAGE(src_data, smp_zero, (int2)(4 * Z + 3, X * shape.z + Y));\n" \ +" }\n" \ +" FLT4 dst0 = (FLT4)(src0.x, src1.x, src2.x, src3.x);\n" \ +" FLT4 dst1 = (FLT4)(src0.y, src1.y, src2.y, src3.y);\n" \ +" FLT4 dst2 = (FLT4)(src0.z, src1.z, src2.z, src3.z);\n" \ +" FLT4 dst3 = (FLT4)(src0.w, src1.w, src2.w, src3.w);\n" \ +" WRITE_IMAGE(dst_data, (int2)(Y, Z * shape.y + 4 * X), dst0);\n" \ +" if (4 * X + 1 < shape.y) {\n" \ +" WRITE_IMAGE(dst_data, (int2)(Y, Z * shape.y + 4 * X + 1), dst1);\n" \ +" }\n" \ +" if (4 * X + 2 < shape.y) {\n" \ +" WRITE_IMAGE(dst_data, (int2)(Y, Z * shape.y + 4 * X + 2), dst2);\n" \ +" }\n" \ +" if (4 * X + 3 < shape.y) {\n" \ +" WRITE_IMAGE(dst_data, (int2)(Y, Z * shape.y + 4 * X + 3), dst3);\n" \ +" }\n" \ +"}\n" \ +"\n" \ +"__kernel void transpose_0312_oversize_NHWC4(__read_only image2d_t src_data, __write_only image2d_t dst_data,\n" \ +" int4 shape) {\n" \ +" int X = get_global_id(0); // H4, C4 for src\n" \ +" int Y = get_global_id(1); // W, H for src\n" \ +" int Z = get_global_id(2); // C4, W4 for src\n" \ +" if (4 * X >= shape.y || Y >= shape.z || 4 * Z >= shape.w) {\n" \ +" return;\n" \ +" }\n" \ +" int H4 = UP_DIV(shape.y, 4);\n" \ +" int C4 = UP_DIV(shape.w, 4);\n" \ +" FLT4 src0 = READ_IMAGE(src_data, smp_zero, (int2)(X, Y * shape.w + 4 * Z));\n" \ +" FLT4 src1 = (FLT4)0.f;\n" \ +" if (4 * Z + 1 < shape.w) {\n" \ +" src1 = READ_IMAGE(src_data, smp_zero, (int2)(X, Y * shape.w + 4 * Z + 1));\n" \ +" }\n" \ +" FLT4 src2 = (FLT4)0.f;\n" \ +" if (4 * Z + 2 < shape.w) {\n" \ +" src2 = READ_IMAGE(src_data, smp_zero, (int2)(X, Y * shape.w + 4 * Z + 2));\n" \ +" }\n" \ +" FLT4 src3 = (FLT4)0.f;\n" \ +" if (4 * Z + 3 < shape.w) {\n" \ +" src3 = READ_IMAGE(src_data, smp_zero, (int2)(X, Y * shape.w + 4 * Z + 3));\n" \ +" }\n" \ +" FLT4 dst0 = (FLT4)(src0.x, src1.x, src2.x, src3.x);\n" \ +" FLT4 dst1 = (FLT4)(src0.y, src1.y, src2.y, src3.y);\n" \ +" FLT4 dst2 = (FLT4)(src0.z, src1.z, src2.z, src3.z);\n" \ +" FLT4 dst3 = (FLT4)(src0.w, src1.w, src2.w, src3.w);\n" \ +" WRITE_IMAGE(dst_data, (int2)(Y * C4 + Z, 4 * X), dst0);\n" \ +" if (4 * X + 1 < shape.y) {\n" \ +" WRITE_IMAGE(dst_data, (int2)(Y * C4 + Z, 4 * X + 1), dst1);\n" \ +" }\n" \ +" if (4 * X + 2 < shape.y) {\n" \ +" WRITE_IMAGE(dst_data, (int2)(Y * C4 + Z, 4 * X + 2), dst2);\n" \ +" }\n" \ +" if (4 * X + 3 < shape.y) {\n" \ +" WRITE_IMAGE(dst_data, (int2)(Y * C4 + Z, 4 * X + 3), dst3);\n" \ +" }\n" \ +"}\n" \ +"\n" \ +"__kernel void transpose_0231_NHWC4(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 shape) {\n" \ +" int X = get_global_id(0); // H, W for src\n" \ +" int Y = get_global_id(1); // W4, C4 for src\n" \ +" int Z = get_global_id(2); // C4, H4 for src\n" \ +" if (X >= shape.y || 4 * Y >= shape.z || 4 * Z >= shape.w) {\n" \ +" return;\n" \ +" }\n" \ +" int W4 = UP_DIV(shape.z, 4);\n" \ +" int C4 = UP_DIV(shape.w, 4);\n" \ +" FLT4 src0 = READ_IMAGE(src_data, smp_zero, (int2)(X * W4 + Y, 4 * Z));\n" \ +" FLT4 src1 = (FLT4)0.f;\n" \ +" if (4 * Z + 1 < shape.w) {\n" \ +" src1 = READ_IMAGE(src_data, smp_zero, (int2)(X * W4 + Y, 4 * Z + 1));\n" \ +" }\n" \ +" FLT4 src2 = (FLT4)0.f;\n" \ +" if (4 * Z + 2 < shape.w) {\n" \ +" src2 = READ_IMAGE(src_data, smp_zero, (int2)(X * W4 + Y, 4 * Z + 2));\n" \ +" }\n" \ +" FLT4 src3 = (FLT4)0.f;\n" \ +" if (4 * Z + 3 < shape.w) {\n" \ +" src3 = READ_IMAGE(src_data, smp_zero, (int2)(X * W4 + Y, 4 * Z + 3));\n" \ +" }\n" \ +" FLT4 dst0 = (FLT4)(src0.x, src1.x, src2.x, src3.x);\n" \ +" FLT4 dst1 = (FLT4)(src0.y, src1.y, src2.y, src3.y);\n" \ +" FLT4 dst2 = (FLT4)(src0.z, src1.z, src2.z, src3.z);\n" \ +" FLT4 dst3 = (FLT4)(src0.w, src1.w, src2.w, src3.w);\n" \ +" WRITE_IMAGE(dst_data, (int2)(4 * Y * C4 + Z, X), dst0);\n" \ +" if (4 * Y + 1 < shape.z) {\n" \ +" WRITE_IMAGE(dst_data, (int2)((4 * Y + 1) * C4 + Z, X), dst1);\n" \ +" }\n" \ +" if (4 * Y + 2 < shape.z) {\n" \ +" WRITE_IMAGE(dst_data, (int2)((4 * Y + 2) * C4 + Z, X), dst2);\n" \ +" }\n" \ +" if (4 * Y + 3 < shape.z) {\n" \ +" WRITE_IMAGE(dst_data, (int2)((4 * Y + 3) * C4 + Z, X), dst3);\n" \ +" }\n" \ +"}\n" \ +"\n" \ +"__kernel void transpose_0231_oversize_NHWC4(__read_only image2d_t src_data, __write_only image2d_t dst_data,\n" \ +" int4 shape) {\n" \ +" int X = get_global_id(0); // H, W for src\n" \ +" int Y = get_global_id(1); // W4, C4 for src\n" \ +" int Z = get_global_id(2); // C4, H4 for src\n" \ +" if (X >= shape.y || 4 * Y >= shape.z || 4 * Z >= shape.w) {\n" \ +" return;\n" \ +" }\n" \ +" int W4 = UP_DIV(shape.z, 4);\n" \ +" int C4 = UP_DIV(shape.w, 4);\n" \ +" FLT4 src0 = READ_IMAGE(src_data, smp_zero, (int2)(Y, 4 * Z * shape.y + X));\n" \ +" FLT4 src1 = (FLT4)0.f;\n" \ +" if (4 * Z + 1 < shape.w) {\n" \ +" src1 = READ_IMAGE(src_data, smp_zero, (int2)(Y, (4 * Z + 1) * shape.y + X));\n" \ +" }\n" \ +" FLT4 src2 = (FLT4)0.f;\n" \ +" if (4 * Z + 2 < shape.w) {\n" \ +" src2 = READ_IMAGE(src_data, smp_zero, (int2)(Y, (4 * Z + 2) * shape.y + X));\n" \ +" }\n" \ +" FLT4 src3 = (FLT4)0.f;\n" \ +" if (4 * Z + 3 < shape.w) {\n" \ +" src3 = READ_IMAGE(src_data, smp_zero, (int2)(Y, (4 * Z + 3) * shape.y + X));\n" \ +" }\n" \ +" FLT4 dst0 = (FLT4)(src0.x, src1.x, src2.x, src3.x);\n" \ +" FLT4 dst1 = (FLT4)(src0.y, src1.y, src2.y, src3.y);\n" \ +" FLT4 dst2 = (FLT4)(src0.z, src1.z, src2.z, src3.z);\n" \ +" FLT4 dst3 = (FLT4)(src0.w, src1.w, src2.w, src3.w);\n" \ +" WRITE_IMAGE(dst_data, (int2)(4 * Y * C4 + Z, X), dst0);\n" \ +" if (4 * Y + 1 < shape.z) {\n" \ +" WRITE_IMAGE(dst_data, (int2)((4 * Y + 1) * C4 + Z, X), dst1);\n" \ +" }\n" \ +" if (4 * Y + 2 < shape.z) {\n" \ +" WRITE_IMAGE(dst_data, (int2)((4 * Y + 2) * C4 + Z, X), dst2);\n" \ +" }\n" \ +" if (4 * Y + 3 < shape.z) {\n" \ +" WRITE_IMAGE(dst_data, (int2)((4 * Y + 3) * C4 + Z, X), dst3);\n" \ +" }\n" \ +"}\n" \ +"\n" \ +"__kernel void transpose_0231_NC4HW4(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 shape) {\n" \ +" int X = get_global_id(0); // H, W for src\n" \ +" int Y = get_global_id(1); // W4, C4 for src\n" \ +" int Z = get_global_id(2); // C4, H4 for src\n" \ +" if (X >= shape.y || 4 * Y >= shape.z || 4 * Z >= shape.w) {\n" \ +" return;\n" \ +" }\n" \ +" FLT4 src0 = READ_IMAGE(src_data, smp_zero, (int2)(X, Y * shape.w + 4 * Z));\n" \ +" FLT4 src1 = (FLT4)0.f;\n" \ +" if (4 * Z + 1 < shape.w) {\n" \ +" src1 = READ_IMAGE(src_data, smp_zero, (int2)(X, Y * shape.w + 4 * Z + 1));\n" \ +" }\n" \ +" FLT4 src2 = (FLT4)0.f;\n" \ +" if (4 * Z + 2 < shape.w) {\n" \ +" src2 = READ_IMAGE(src_data, smp_zero, (int2)(X, Y * shape.w + 4 * Z + 2));\n" \ +" }\n" \ +" FLT4 src3 = (FLT4)0.f;\n" \ +" if (4 * Z + 3 < shape.w) {\n" \ +" src3 = READ_IMAGE(src_data, smp_zero, (int2)(X, Y * shape.w + 4 * Z + 3));\n" \ +" }\n" \ +" FLT4 dst0 = (FLT4)(src0.x, src1.x, src2.x, src3.x);\n" \ +" FLT4 dst1 = (FLT4)(src0.y, src1.y, src2.y, src3.y);\n" \ +" FLT4 dst2 = (FLT4)(src0.z, src1.z, src2.z, src3.z);\n" \ +" FLT4 dst3 = (FLT4)(src0.w, src1.w, src2.w, src3.w);\n" \ +" WRITE_IMAGE(dst_data, (int2)(4 * Y, Z * shape.y + X), dst0);\n" \ +" if (4 * Y + 1 < shape.z) {\n" \ +" WRITE_IMAGE(dst_data, (int2)(4 * Y + 1, Z * shape.y + X), dst1);\n" \ +" }\n" \ +" if (4 * Y + 2 < shape.z) {\n" \ +" WRITE_IMAGE(dst_data, (int2)(4 * Y + 2, Z * shape.y + X), dst2);\n" \ +" }\n" \ +" if (4 * Y + 3 < shape.z) {\n" \ +" WRITE_IMAGE(dst_data, (int2)(4 * Y + 3, Z * shape.y + X), dst3);\n" \ +" }\n" \ +"}\n" \ +"\n" \ +"typedef union FLT4_array {\n" \ +" FLT c_array[4];\n" \ +" FLT4 vector;\n" \ +"} FLT4_array;\n" \ +"\n" \ +"__kernel void transpose_general_NHWC4(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 out_shape,\n" \ +" int4 de_perm, int4 in_shape) {\n" \ +" int X = get_global_id(0); // N*H\n" \ +" int Y = get_global_id(1);\n" \ +" int Z = get_global_id(2);\n" \ +" if (X >= out_shape.y * out_shape.x || Y >= out_shape.z || 4 * Z >= out_shape.w) {\n" \ +" return;\n" \ +" }\n" \ +" int N = X / out_shape.y;\n" \ +" int H = X % out_shape.y;\n" \ +" int CI4_SIZE = UP_DIV(in_shape.w, 4);\n" \ +" FLT4_array result_tmp;\n" \ +" result_tmp.vector = (FLT4)(0.f);\n" \ +" FLT *result_ptr = result_tmp.c_array;\n" \ +" for (int i = 0; i < 4; i++) {\n" \ +" if (Z * 4 + i < out_shape.w) {\n" \ +" int out_index[4] = {N, H, Y, Z * 4 + i};\n" \ +" FLT4 src = READ_IMAGE(src_data, smp_zero,\n" \ +" (int2)(out_index[de_perm.z] * CI4_SIZE + out_index[de_perm.w] / 4,\n" \ +" out_index[de_perm.x] * in_shape.y + out_index[de_perm.y]));\n" \ +" FLT4_array src_tmp;\n" \ +" src_tmp.vector = src;\n" \ +" result_tmp.c_array[i] = src_tmp.c_array[out_index[de_perm.w] % 4];\n" \ +" }\n" \ +" }\n" \ +" int CO4_SIZE = UP_DIV(out_shape.w, 4);\n" \ +" WRITE_IMAGE(dst_data, (int2)(Y * CO4_SIZE + Z, X), result_tmp.vector);\n" \ +"}\n" \ +; diff --git a/mindspore-lite/src/runtime/kernel/opencl/cl/winograd.cl.inc b/mindspore-lite/src/runtime/kernel/opencl/cl/winograd.cl.inc new file mode 100644 index 0000000000000000000000000000000000000000..5cfdddd8b18e3e6f41d2f2b785bcadf5a4b3f54e --- /dev/null +++ b/mindspore-lite/src/runtime/kernel/opencl/cl/winograd.cl.inc @@ -0,0 +1,339 @@ +static const char *winograd_source ="\n" +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" \ +"\n" \ +"__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" \ +"\n" \ +"#define CI_TILE 4\n" \ +"#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))\n" \ +"\n" \ +"constant FLT Bt[36] = {\n" \ +" 1.0000000000f, 0.0000000000f, -2.5000004768f, -0.0000001192f, 1.0000001192f, 0.0000000000f,\n" \ +" 0.0000000000f, 0.9428091049f, 1.3333333731f, -0.4714044929f, -0.6666667461f, 0.0000000000f,\n" \ +" 0.0000000000f, -0.9428089857f, 1.3333334923f, 0.4714045525f, -0.6666667461f, 0.0000000000f,\n" \ +" 0.0000000000f, -0.1178511307f, -0.0833333358f, 0.2357022613f, 0.1666666865f, 0.0000000000f,\n" \ +" 0.0000000000f, 0.1178511307f, -0.0833333507f, -0.2357022911f, 0.1666666865f, 0.0000000000f,\n" \ +" 0.0000000000f, 0.9999998808f, -0.0000000596f, -2.5000000000f, 0.0000000000f, 1.0000000000f,\n" \ +"};\n" \ +"\n" \ +"__kernel void Winograd4x4To36(__read_only image2d_t input, // height=N*H width=W*CI_SLICES\n" \ +" __write_only image2d_t output, // height=CI_SLICES*36 width=H/4*W/4\n" \ +" int4 input_shape, // N H W CI_SLICES\n" \ +" int TILE_HW, int pad_u, int pad_l) {\n" \ +" int tile_hw = get_global_id(0);\n" \ +" int row = get_global_id(1);\n" \ +" int ci_slice = get_global_id(2);\n" \ +" int H = input_shape.y;\n" \ +" int W = input_shape.z;\n" \ +" int CI_SLICES = input_shape.w;\n" \ +" if (tile_hw >= TILE_HW || row >= 6 || ci_slice >= CI_SLICES) {\n" \ +" return;\n" \ +" }\n" \ +" int TILE_W = UP_DIV(W, 4);\n" \ +" int tile_w = tile_hw % TILE_W;\n" \ +" int tile_h = tile_hw / TILE_W;\n" \ +"\n" \ +" constant FLT *Bt_row = Bt + row * 6;\n" \ +" FLT4 BtD_row[6] = {0};\n" \ +" int h = tile_h * 4 - pad_u;\n" \ +" int w = tile_w * 4 - pad_l;\n" \ +"\n" \ +" int x_idx = w * CI_SLICES + ci_slice;\n" \ +" FLT bt0 = Bt_row[0], bt1 = Bt_row[1], bt2 = Bt_row[2], bt3 = Bt_row[3], bt4 = Bt_row[4], bt5 = Bt_row[5];\n" \ +" BtD_row[0] =\n" \ +" bt0 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 0)) + bt1 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 1)) +\n" \ +" bt2 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 2)) + bt3 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 3)) +\n" \ +" bt4 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 4)) + bt5 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 5));\n" \ +" x_idx += CI_SLICES;\n" \ +" BtD_row[1] =\n" \ +" bt0 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 0)) + bt1 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 1)) +\n" \ +" bt2 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 2)) + bt3 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 3)) +\n" \ +" bt4 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 4)) + bt5 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 5));\n" \ +" x_idx += CI_SLICES;\n" \ +" BtD_row[2] =\n" \ +" bt0 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 0)) + bt1 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 1)) +\n" \ +" bt2 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 2)) + bt3 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 3)) +\n" \ +" bt4 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 4)) + bt5 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 5));\n" \ +" x_idx += CI_SLICES;\n" \ +" BtD_row[3] =\n" \ +" bt0 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 0)) + bt1 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 1)) +\n" \ +" bt2 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 2)) + bt3 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 3)) +\n" \ +" bt4 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 4)) + bt5 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 5));\n" \ +" x_idx += CI_SLICES;\n" \ +" BtD_row[4] =\n" \ +" bt0 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 0)) + bt1 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 1)) +\n" \ +" bt2 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 2)) + bt3 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 3)) +\n" \ +" bt4 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 4)) + bt5 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 5));\n" \ +" x_idx += CI_SLICES;\n" \ +" BtD_row[5] =\n" \ +" bt0 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 0)) + bt1 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 1)) +\n" \ +" bt2 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 2)) + bt3 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 3)) +\n" \ +" bt4 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 4)) + bt5 * READ_IMAGE(input, smp_zero, (int2)(x_idx, h + 5));\n" \ +"\n" \ +"#if FP16_ENABLE\n" \ +"#ifndef HALF_MAX // adreno not exist\n" \ +"#define HALF_MAX 0x1.ffcp15h\n" \ +"#endif\n" \ +"#define LimitAcc() \\\n" \ +" acc = min(acc, HALF_MAX); \\\n" \ +" acc = max(acc, -HALF_MAX);\n" \ +"#else\n" \ +"#define LimitAcc() \\\n" \ +" {}\n" \ +"#endif\n" \ +"\n" \ +" int y_idx = ci_slice * 36 + row * 6;\n" \ +" FLT4 acc = BtD_row[0] + (FLT)(-2.5f) * BtD_row[2] + BtD_row[4];\n" \ +" LimitAcc();\n" \ +" WRITE_IMAGE(output, (int2)(tile_hw, y_idx++), acc);\n" \ +"\n" \ +" FLT4 tmp0 = (FLT)(0.9428091049f) * BtD_row[1] + (FLT)(-0.4714044929f) * BtD_row[3];\n" \ +" FLT4 tmp1 = (FLT)(1.3333333731f) * BtD_row[2] + (FLT)(-0.6666667461f) * BtD_row[4];\n" \ +" acc = tmp0 + tmp1;\n" \ +" LimitAcc();\n" \ +" WRITE_IMAGE(output, (int2)(tile_hw, y_idx++), acc);\n" \ +"\n" \ +" acc = -tmp0 + tmp1;\n" \ +" LimitAcc();\n" \ +" WRITE_IMAGE(output, (int2)(tile_hw, y_idx++), acc);\n" \ +"\n" \ +" tmp0 = (FLT)(-0.1178511307f) * BtD_row[1] + (FLT)(0.2357022613f) * BtD_row[3];\n" \ +" tmp1 = (FLT)(-0.0833333358f) * BtD_row[2] + (FLT)(0.1666666865f) * BtD_row[4];\n" \ +" acc = tmp0 + tmp1;\n" \ +" LimitAcc();\n" \ +" WRITE_IMAGE(output, (int2)(tile_hw, y_idx++), acc);\n" \ +"\n" \ +" acc = -tmp0 + tmp1;\n" \ +" LimitAcc();\n" \ +" WRITE_IMAGE(output, (int2)(tile_hw, y_idx++), acc);\n" \ +"\n" \ +" acc = BtD_row[1] + (FLT)(-2.5f) * BtD_row[3] + BtD_row[5];\n" \ +" LimitAcc();\n" \ +" WRITE_IMAGE(output, (int2)(tile_hw, y_idx++), acc);\n" \ +"}\n" \ +"\n" \ +"__kernel void WinogradConv2D(__read_only image2d_t input, // height=CI_SLICES*36 width=TILE_HW\n" \ +" __write_only image2d_t output, // height=CO_SLICES*36 width=TILE_HW\n" \ +" __global FLT16 *weight, int TILE_HW, int CI_SLICES, int CO_SLICES) {\n" \ +" int tile_hw = get_global_id(0) * 2;\n" \ +" int h = get_global_id(1);\n" \ +" int co_slice = get_global_id(2) * 2;\n" \ +" if (h >= 36 || tile_hw >= TILE_HW || co_slice >= CO_SLICES) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" FLT4 out00 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out01 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out10 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out11 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +"\n" \ +" int y_idx = h;\n" \ +" __global FLT16 *weight_ptr = weight + (co_slice / 2 * 36 + h) * CI_SLICES * 2;\n" \ +" for (int ci_slice = 0; ci_slice < CI_SLICES; ci_slice++) {\n" \ +" FLT4 in0 = READ_IMAGE(input, smp_zero, (int2)(tile_hw + 0, y_idx));\n" \ +" FLT4 in1 = READ_IMAGE(input, smp_zero, (int2)(tile_hw + 1, y_idx));\n" \ +" y_idx += 36;\n" \ +"\n" \ +" FLT16 weight0 = weight_ptr[0], weight1 = weight_ptr[1];\n" \ +" weight_ptr += 2;\n" \ +"\n" \ +" out00 += in0.x * weight0.s0123;\n" \ +" out00 += in0.y * weight0.s4567;\n" \ +" out00 += in0.z * weight0.s89ab;\n" \ +" out00 += in0.w * weight0.scdef;\n" \ +"\n" \ +" out01 += in1.x * weight0.s0123;\n" \ +" out01 += in1.y * weight0.s4567;\n" \ +" out01 += in1.z * weight0.s89ab;\n" \ +" out01 += in1.w * weight0.scdef;\n" \ +"\n" \ +" out10 += in0.x * weight1.s0123;\n" \ +" out10 += in0.y * weight1.s4567;\n" \ +" out10 += in0.z * weight1.s89ab;\n" \ +" out10 += in0.w * weight1.scdef;\n" \ +"\n" \ +" out11 += in1.x * weight1.s0123;\n" \ +" out11 += in1.y * weight1.s4567;\n" \ +" out11 += in1.z * weight1.s89ab;\n" \ +" out11 += in1.w * weight1.scdef;\n" \ +" }\n" \ +"\n" \ +" WRITE_IMAGE(output, (int2)(tile_hw + 0, (co_slice + 0) * 36 + h), out00);\n" \ +" WRITE_IMAGE(output, (int2)(tile_hw + 1, (co_slice + 0) * 36 + h), out01);\n" \ +" WRITE_IMAGE(output, (int2)(tile_hw + 0, (co_slice + 1) * 36 + h), out10);\n" \ +" WRITE_IMAGE(output, (int2)(tile_hw + 1, (co_slice + 1) * 36 + h), out11);\n" \ +"}\n" \ +"\n" \ +"__kernel void WinogradConv2D_Img(__read_only image2d_t input, // height=CI_SLICES*36 width=TILE_HW\n" \ +" __write_only image2d_t output, // height=CO_SLICES*36 width=TILE_HW\n" \ +" __read_only image2d_t weight, int TILE_HW, int CI_SLICES, int CO_SLICES) {\n" \ +" int tile_hw = get_global_id(0) * 2;\n" \ +" int h = get_global_id(1);\n" \ +" int co_slice = get_global_id(2) * 2;\n" \ +" if (h >= 36 || tile_hw >= TILE_HW || co_slice >= CO_SLICES) {\n" \ +" return;\n" \ +" }\n" \ +" int CI_ALIGN = CI_SLICES * CI_TILE;\n" \ +"\n" \ +" FLT4 out00 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out01 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out10 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +" FLT4 out11 = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n" \ +"\n" \ +" int y_idx = h;\n" \ +" for (int ci_slice = 0; ci_slice < CI_SLICES; ci_slice++) {\n" \ +" FLT4 in0 = READ_IMAGE(input, smp_zero, (int2)(tile_hw + 0, y_idx));\n" \ +" FLT4 in1 = READ_IMAGE(input, smp_zero, (int2)(tile_hw + 1, y_idx));\n" \ +" y_idx += 36;\n" \ +"\n" \ +" FLT4 filter_ci0_co0 = READ_IMAGE(weight, smp_zero, (int2)(h * CI_ALIGN + ci_slice * CI_TILE + 0, co_slice + 0));\n" \ +" FLT4 filter_ci1_co0 = READ_IMAGE(weight, smp_zero, (int2)(h * CI_ALIGN + ci_slice * CI_TILE + 1, co_slice + 0));\n" \ +" FLT4 filter_ci2_co0 = READ_IMAGE(weight, smp_zero, (int2)(h * CI_ALIGN + ci_slice * CI_TILE + 2, co_slice + 0));\n" \ +" FLT4 filter_ci3_co0 = READ_IMAGE(weight, smp_zero, (int2)(h * CI_ALIGN + ci_slice * CI_TILE + 3, co_slice + 0));\n" \ +" FLT4 filter_ci0_co1 = READ_IMAGE(weight, smp_zero, (int2)(h * CI_ALIGN + ci_slice * CI_TILE + 0, co_slice + 1));\n" \ +" FLT4 filter_ci1_co1 = READ_IMAGE(weight, smp_zero, (int2)(h * CI_ALIGN + ci_slice * CI_TILE + 1, co_slice + 1));\n" \ +" FLT4 filter_ci2_co1 = READ_IMAGE(weight, smp_zero, (int2)(h * CI_ALIGN + ci_slice * CI_TILE + 2, co_slice + 1));\n" \ +" FLT4 filter_ci3_co1 = READ_IMAGE(weight, smp_zero, (int2)(h * CI_ALIGN + ci_slice * CI_TILE + 3, co_slice + 1));\n" \ +"\n" \ +" out00 += in0.x * filter_ci0_co0;\n" \ +" out00 += in0.y * filter_ci1_co0;\n" \ +" out00 += in0.z * filter_ci2_co0;\n" \ +" out00 += in0.w * filter_ci3_co0;\n" \ +"\n" \ +" out01 += in1.x * filter_ci0_co0;\n" \ +" out01 += in1.y * filter_ci1_co0;\n" \ +" out01 += in1.z * filter_ci2_co0;\n" \ +" out01 += in1.w * filter_ci3_co0;\n" \ +"\n" \ +" out10 += in0.x * filter_ci0_co1;\n" \ +" out10 += in0.y * filter_ci1_co1;\n" \ +" out10 += in0.z * filter_ci2_co1;\n" \ +" out10 += in0.w * filter_ci3_co1;\n" \ +"\n" \ +" out11 += in1.x * filter_ci0_co1;\n" \ +" out11 += in1.y * filter_ci1_co1;\n" \ +" out11 += in1.z * filter_ci2_co1;\n" \ +" out11 += in1.w * filter_ci3_co1;\n" \ +" }\n" \ +"\n" \ +" WRITE_IMAGE(output, (int2)(tile_hw + 0, (co_slice + 0) * 36 + h), out00);\n" \ +" WRITE_IMAGE(output, (int2)(tile_hw + 1, (co_slice + 0) * 36 + h), out01);\n" \ +" WRITE_IMAGE(output, (int2)(tile_hw + 0, (co_slice + 1) * 36 + h), out10);\n" \ +" WRITE_IMAGE(output, (int2)(tile_hw + 1, (co_slice + 1) * 36 + h), out11);\n" \ +"}\n" \ +"\n" \ +"#define DO_LEAKY_RELU(data, alpha) \\\n" \ +" data.x = data.x > 0 ? data.x : data.x * alpha; \\\n" \ +" data.y = data.y > 0 ? data.y : data.y * alpha; \\\n" \ +" data.z = data.z > 0 ? data.z : data.z * alpha; \\\n" \ +" data.w = data.w > 0 ? data.w : data.w * alpha;\n" \ +"\n" \ +"constant FLT At[24] = {1.0000000000f, 1.0000000000f, 1.0000000000f, 1.0000000000f, 1.0000000000f, 0.0000000000f,\n" \ +" 0.0000000000f, 0.7071067691f, -0.7071067691f, 1.4142135382f, -1.4142135382f, 0.0000000000f,\n" \ +" 0.0000000000f, 0.4999999702f, 0.4999999702f, 1.9999998808f, 1.9999998808f, 0.0000000000f,\n" \ +" 0.0000000000f, 0.3535533845f, -0.3535533845f, 2.8284270763f, -2.8284270763f, 1.0000000000f};\n" \ +"\n" \ +"#define UpdateAcc() \\\n" \ +" if (bias != 0) acc += bias[co_slice]; \\\n" \ +" if (act_type == ActivationType_RELU) { \\\n" \ +" acc = max(acc, (FLT4)(0.0f)); \\\n" \ +" } else if (act_type == ActivationType_RELU6) { \\\n" \ +" acc = clamp(acc, (FLT4)(0.0f), (FLT4)(6.0f)); \\\n" \ +" } else if (act_type == ActivationType_TANH) { \\\n" \ +" FLT4 exp0 = exp(acc); \\\n" \ +" FLT4 exp1 = exp(-acc); \\\n" \ +" acc = (exp0 - exp1) / (exp0 + exp1); \\\n" \ +" } else if (act_type == ActivationType_LEAKY_RELU) { \\\n" \ +" DO_LEAKY_RELU(acc, alpha); \\\n" \ +" } else if (act_type == ActivationType_SIGMOID) { \\\n" \ +" acc = (FLT4)(1.f) / ((FLT4)(1.f) + exp(-acc)); \\\n" \ +" }\n" \ +"\n" \ +"__kernel void Winograd36To4x4(__read_only image2d_t input, // height=CO_SLICES*36 width=TILE_HW\n" \ +" __write_only image2d_t output, // height=N*H width=W*CO_SLICES\n" \ +" __global FLT4 *bias,\n" \ +" int4 output_shape, // N H W CO_SLICES\n" \ +" int TILE_HW, int act_type, float alpha) {\n" \ +" int tile_hw = get_global_id(0);\n" \ +" int row = get_global_id(1);\n" \ +" int co_slice = get_global_id(2);\n" \ +" int H = output_shape.y;\n" \ +" int W = output_shape.z;\n" \ +" int CO_SLICES = output_shape.w;\n" \ +" if (tile_hw >= TILE_HW || row >= 4 || co_slice >= CO_SLICES) {\n" \ +" return;\n" \ +" }\n" \ +"\n" \ +" constant FLT *At_row = At + row * 6;\n" \ +" FLT4 AtM_row[6] = {0};\n" \ +" int idx = co_slice * 36;\n" \ +" FLT at = At_row[0];\n" \ +" AtM_row[0] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++));\n" \ +" AtM_row[1] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++));\n" \ +" AtM_row[2] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++));\n" \ +" AtM_row[3] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++));\n" \ +" AtM_row[4] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++));\n" \ +" AtM_row[5] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++));\n" \ +" at = At_row[1];\n" \ +" AtM_row[0] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++));\n" \ +" AtM_row[1] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++));\n" \ +" AtM_row[2] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++));\n" \ +" AtM_row[3] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++));\n" \ +" AtM_row[4] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++));\n" \ +" AtM_row[5] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++));\n" \ +" at = At_row[2];\n" \ +" AtM_row[0] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++));\n" \ +" AtM_row[1] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++));\n" \ +" AtM_row[2] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++));\n" \ +" AtM_row[3] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++));\n" \ +" AtM_row[4] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++));\n" \ +" AtM_row[5] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++));\n" \ +" at = At_row[3];\n" \ +" AtM_row[0] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++));\n" \ +" AtM_row[1] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++));\n" \ +" AtM_row[2] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++));\n" \ +" AtM_row[3] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++));\n" \ +" AtM_row[4] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++));\n" \ +" AtM_row[5] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++));\n" \ +" at = At_row[4];\n" \ +" AtM_row[0] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++));\n" \ +" AtM_row[1] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++));\n" \ +" AtM_row[2] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++));\n" \ +" AtM_row[3] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++));\n" \ +" AtM_row[4] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++));\n" \ +" AtM_row[5] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++));\n" \ +" at = At_row[5];\n" \ +" AtM_row[0] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++));\n" \ +" AtM_row[1] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++));\n" \ +" AtM_row[2] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++));\n" \ +" AtM_row[3] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++));\n" \ +" AtM_row[4] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++));\n" \ +" AtM_row[5] += at * READ_IMAGE(input, smp_zero, (int2)(tile_hw, idx++));\n" \ +"\n" \ +" int TILE_W = UP_DIV(W, 4);\n" \ +" int tile_w = tile_hw % TILE_W;\n" \ +" int tile_h = tile_hw / TILE_W;\n" \ +" int h = tile_h * 4 + row;\n" \ +" int w = tile_w * 4;\n" \ +" int x_idx = w * CO_SLICES + co_slice;\n" \ +"\n" \ +" FLT4 acc = AtM_row[0] + AtM_row[1] + AtM_row[2] + AtM_row[3] + AtM_row[4];\n" \ +" UpdateAcc();\n" \ +" WRITE_IMAGE(output, (int2)(x_idx, h), acc);\n" \ +" x_idx += CO_SLICES;\n" \ +"\n" \ +" acc = (FLT)(0.7071067691f) * (AtM_row[1] - AtM_row[2]) + (FLT)(1.4142135382f) * (AtM_row[3] - AtM_row[4]);\n" \ +" UpdateAcc();\n" \ +" WRITE_IMAGE(output, (int2)(x_idx, h), acc);\n" \ +" x_idx += CO_SLICES;\n" \ +"\n" \ +" acc = (FLT)(0.5f) * (AtM_row[1] + AtM_row[2]) + (FLT)(2.0f) * (AtM_row[3] + AtM_row[4]);\n" \ +" UpdateAcc();\n" \ +" WRITE_IMAGE(output, (int2)(x_idx, h), acc);\n" \ +" x_idx += CO_SLICES;\n" \ +"\n" \ +" acc =\n" \ +" (FLT)(0.3535533845f) * (AtM_row[1] - AtM_row[2]) + (FLT)(2.8284270763f) * (AtM_row[3] - AtM_row[4]) + AtM_row[5];\n" \ +" UpdateAcc();\n" \ +" WRITE_IMAGE(output, (int2)(x_idx, h), acc);\n" \ +"}\n" \ +; diff --git a/mindspore-lite/src/tensor.h b/mindspore-lite/src/tensor.h index a02a51b9dda170e7c289d1743bb4235a8e46d209..a92262cfc0ae76675599cf4d93f74d8a4a36f21b 100644 --- a/mindspore-lite/src/tensor.h +++ b/mindspore-lite/src/tensor.h @@ -224,7 +224,9 @@ class Tensor { void set_quant_clusters(const std::vector &clusters); - virtual bool IsConst() const { return ::NNACLIsConst(&tensor_c_); } + virtual bool IsConst() const { + return (tensor_c_.category_ == ConstTensor || tensor_c_.category_ == ConstScalar) && tensor_c_.data_ != NULL; + } bool IsScalar() const { return this->tensor_c_.category_ == CONST_SCALAR && this->tensor_c_.data_ != nullptr; } diff --git a/mindspore-lite/src/tensorlist.h b/mindspore-lite/src/tensorlist.h index 5c85e29d40105e99eee58ba7a76a9471a8d0e04b..e474c7f891b3d2ea4144e3476735922b239e4849 100644 --- a/mindspore-lite/src/tensorlist.h +++ b/mindspore-lite/src/tensorlist.h @@ -207,9 +207,7 @@ class TensorList : public Tensor { }; #else - using TensorList = void; - #endif } // namespace mindspore::lite #endif // MINDSPORE_LITE_SRC_TENSORLIST_H_ diff --git a/mindspore-lite/test/CMakeLists.txt b/mindspore-lite/test/CMakeLists.txt index 063f106363e09d975e718f8da20c8d8a7fc751fc..6273b40ca3ca564081464044f2287db80a8498e4 100644 --- a/mindspore-lite/test/CMakeLists.txt +++ b/mindspore-lite/test/CMakeLists.txt @@ -41,14 +41,6 @@ else() ) endif() -if(MSLITE_ENABLE_SERVER_INFERENCE) - list(APPEND TEST_UT_SRC ${TEST_DIR}/ut/src/api/model_parallel_runner_test.cc) -endif() - -if(MSLITE_ENABLE_SERVER_INFERENCE) - list(REMOVE_ITEM TEST_UT_SRC ${TEST_DIR}/st/mindrt_parallel_runtime_test.cc) -endif() - if(MSLITE_ENABLE_RUNTIME_CONVERT) list(APPEND TEST_UT_SRC ${TEST_DIR}/ut/src/runtime/runtime_convert_tests.cc) endif() @@ -146,7 +138,7 @@ if(MSLITE_ENABLE_CONVERTER) ${TEST_DIR}/ut/tools/optimizer/graph/*.cc ) endif() - if(MSLITE_ENABLE_SERVER_INFERENCE AND NOT (MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)) + if(NOT (MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)) list(REMOVE_ITEM TEST_CONVERTER_UT_SRC ${TEST_DIR}/st/mindrt_parallel_test.cc) list(REMOVE_ITEM TEST_UT_SRC ${TEST_DIR}/st/benchmark_test.cc) list(REMOVE_ITEM TEST_CONVERTER_UT_SRC ${TEST_DIR}/st/sub_graph_test.cc) @@ -191,7 +183,8 @@ target_link_libraries(lite-test gmock_tests) if(MSLITE_ENABLE_TRAIN) target_link_libraries(lite-test mindspore-lite-train) - if(NOT MSLITE_MINDDATA_IMPLEMENT STREQUAL "off") + if(NOT MSLITE_MINDDATA_IMPLEMENT STREQUAL "off" + AND NOT (MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)) target_link_libraries(lite-test minddata-lite) endif() endif() @@ -238,7 +231,8 @@ if(MSLITE_ENABLE_CONVERTER AND (NOT MSLITE_ENABLE_RUNTIME_CONVERT)) quantizer_mid decomposer_mid proto_mid - ccsrc_src_mid + # ccsrc_src_mid + ccsrc_src_new_mid converter_src_mid mindspore::protobuf mindspore::json diff --git a/mindspore-lite/test/config_level0/cloud_infer/models_mindir_cloud.cfg b/mindspore-lite/test/config_level0/cloud_infer/models_mindir_cloud.cfg index fe6b92245f18df155a88b97f3ce53be9d216d044..58038a9f6edc01eb9a01dab64ffe09288fd1c83f 100644 --- a/mindspore-lite/test/config_level0/cloud_infer/models_mindir_cloud.cfg +++ b/mindspore-lite/test/config_level0/cloud_infer/models_mindir_cloud.cfg @@ -23,7 +23,7 @@ resnet50_imagenet_bs_1_Ascend 3 resnet50_imagenet_bs_1_GPU 3 resnet50_thor_imagenet_bs_1_ascend 2 se-resnet50_imagenet_bs_1_ascend 3 -vgg16_cifar10_bs_64_Ascend +# vgg16_cifar10_bs_64_Ascend vgg16_imagenet_bs_1_Ascend vgg19_cifar10_bs1 vgg19_imagenet2012_bs1 1 diff --git a/mindspore-lite/test/config_level0/cloud_infer/models_mindir_cloud_java_ascend.cfg b/mindspore-lite/test/config_level0/cloud_infer/models_mindir_cloud_java_ascend.cfg index b03d9797b52347dc0191ada253c844543fd32f3b..e85ee0b1e133222c4040db4558742941796115df 100644 --- a/mindspore-lite/test/config_level0/cloud_infer/models_mindir_cloud_java_ascend.cfg +++ b/mindspore-lite/test/config_level0/cloud_infer/models_mindir_cloud_java_ascend.cfg @@ -1,2 +1,2 @@ -vgg16_cifar10_bs_64_Ascend -inceptionV4 +# vgg16_cifar10_bs_64_Ascend +# inceptionV4 diff --git a/mindspore-lite/test/config_level0/models_ascend_cloud.cfg b/mindspore-lite/test/config_level0/models_ascend_cloud.cfg index c3c553b6674a8ed12ea40f9df4223d8b5d7964df..d0ee788378ca03912f11392297d294c224fea3d3 100644 --- a/mindspore-lite/test/config_level0/models_ascend_cloud.cfg +++ b/mindspore-lite/test/config_level0/models_ascend_cloud.cfg @@ -38,7 +38,7 @@ single_op_argmin.onnx;;;NCHW; 5 # single_op_selu.onnx;;;NCHW; 5 single_op_hardswish.onnx;;;NCHW; 5 # single_op_bitshift_left.onnx;2:X,Z;;NCHW; 5 -single_op_bitshift_right.onnx;2:X,Z;;NCHW; 5 +# single_op_bitshift_right.onnx;2:X,Z;;NCHW; 5 single_op_cumsum.onnx;2:X,Axis;;NCHW; 5 single_op_det.onnx;;;NCHW; 5 # inference has not drop @@ -52,7 +52,7 @@ single_op_logsoftmax.onnx;;;NCHW; 5 # single_op_lpnormalization.onnx;;;NCHW; 5 single_op_lrn.onnx;;;NCHW; 5 # single_op_mvn.onnx;;;NCHW; 5 -single_op_multinomial.onnx;;;NCHW; 5 +# single_op_multinomial.onnx;;;NCHW; 5 single_op_onehot.onnx;;;NCHW; 5 # single_op_pow.onnx;2:X,Z;;NCHW; 5 single_op_reciprocal.onnx;;;NCHW; 5 @@ -71,7 +71,7 @@ single_op_softsign.onnx;;;NCHW; 5 single_op_sign.onnx;;;NCHW; 5 single_op_size.onnx;;;NCHW; 5 # single_op_tfidfvectorizer.onnx;;;NCHW; 5 -single_op_xor.onnx;2:X,Z;;NCHW; 5 +# single_op_xor.onnx;2:X,Z;;NCHW; 5 #online converter # time: 23.02, online diff --git a/mindspore-lite/test/config_level0/models_onnx_reconstitution_cloud.cfg b/mindspore-lite/test/config_level0/models_onnx_reconstitution_cloud.cfg index 64eb9920b462bca342cf6a5d9503e8c7204b04ae..41bbb7639657893dd775a1586336072a616d73b6 100644 --- a/mindspore-lite/test/config_level0/models_onnx_reconstitution_cloud.cfg +++ b/mindspore-lite/test/config_level0/models_onnx_reconstitution_cloud.cfg @@ -1,2 +1,3 @@ +# These two use cases need to activate the ENABLE_BACKEND_RUNTIM, and the environment variable will cause the model to go defaultsesion, but the current refactoring has deleted the session, so these two use cases are removed 01-face_det_400_400.onnx;1:img_data;1,400,400,3;; 4.5 ml_video_edit_vignet.onnx;1:input diff --git a/mindspore-lite/test/config_level0/models_onnx_reconstitution_cloud_process_only.cfg b/mindspore-lite/test/config_level0/models_onnx_reconstitution_cloud_process_only.cfg index bfd0eaf3c956056eaf07e6d5e31dff0e1478de10..dd865532ae18e2b2fad6fe222c902c013d64ca39 100644 --- a/mindspore-lite/test/config_level0/models_onnx_reconstitution_cloud_process_only.cfg +++ b/mindspore-lite/test/config_level0/models_onnx_reconstitution_cloud_process_only.cfg @@ -1,3 +1,4 @@ +#The gts_version-RFB-320_simplified.onnx model requires defaultsession, which has been removed, so the use case is removed gts_version-RFB-320_simplified.onnx;1:input ml_edu_kit_hand_detection.onnx;1:images ml_ei_facedetection.onnx;1:input diff --git a/mindspore-lite/test/config_level0/models_python_ascend.cfg b/mindspore-lite/test/config_level0/models_python_ascend.cfg index 654ff0250b4f7a8c836e9378acbb5d6fd0700da2..52d91250fd28dfd861f21e5a0724fe81fc209dd7 100644 --- a/mindspore-lite/test/config_level0/models_python_ascend.cfg +++ b/mindspore-lite/test/config_level0/models_python_ascend.cfg @@ -1,3 +1,3 @@ -mtk_face_recognition_v1;1:data;1,114,114,1;; 5 -open_source_inception_v3.pb;1:input;2,299,299,3;;offline_resize 5 -open_source_mobilenet_v2.pb;1:Placeholder;1,224,224,3;;offline_resize 5 +# mtk_face_recognition_v1;1:data;1,114,114,1;; 5 +# open_source_inception_v3.pb;1:input;2,299,299,3;;offline_resize 5 +# open_source_mobilenet_v2.pb;1:Placeholder;1,224,224,3;;offline_resize 5 diff --git a/mindspore-lite/test/config_level0/models_with_large_model_acl_with_config_cloud_ascend.cfg b/mindspore-lite/test/config_level0/models_with_large_model_acl_with_config_cloud_ascend.cfg index 4a4c932518a4b8ac7de6fc2bfa3e4b900b91f7e2..b2ac7222ee7e263f903e2312623e98641b5b69ed 100644 --- a/mindspore-lite/test/config_level0/models_with_large_model_acl_with_config_cloud_ascend.cfg +++ b/mindspore-lite/test/config_level0/models_with_large_model_acl_with_config_cloud_ascend.cfg @@ -6,7 +6,7 @@ test_resize_5d.onnx;1:X;2,3,4,10,20;; 1 mm3d2d_dyn.onnx;1:input;50,1,1024;; 1 model_all_quant_random_49.941.onnx;3:input_ids,attention_mask,token_type_ids;1,128:1,128:1,128;; 60 single_op_fa.onnx;3:q,k,v;1,32,1024,64:1,32,1024,64:1,32,1024,64;; 1 -single_op_gns.onnx;3:x,gamma,beta;1,192,160,160:192:192;; 1 +# single_op_gns.onnx;3:x,gamma,beta;1,192,160,160:192:192;; 1 fusenet_pcvr_500260147_2024.06.24-143635.pb;173;;; 1 dffm_pctr_00200417_2024.06.17-201523.pb;174;;; 2 long_sequence_eta.pb;2:id,wt;1,576:1,576;;1 diff --git a/mindspore-lite/test/config_level1/cloud_infer/models_mindir_cloud_java_ascend.cfg b/mindspore-lite/test/config_level1/cloud_infer/models_mindir_cloud_java_ascend.cfg index b03d9797b52347dc0191ada253c844543fd32f3b..2dee6ea675cffa45f6180a423cfb8f88f0bc978c 100644 --- a/mindspore-lite/test/config_level1/cloud_infer/models_mindir_cloud_java_ascend.cfg +++ b/mindspore-lite/test/config_level1/cloud_infer/models_mindir_cloud_java_ascend.cfg @@ -1,2 +1,2 @@ -vgg16_cifar10_bs_64_Ascend +# vgg16_cifar10_bs_64_Ascend inceptionV4 diff --git a/mindspore-lite/test/runtest.sh b/mindspore-lite/test/runtest.sh index 0b7165d566e5960f852b8df81707c30e1c5e19d7..b2487c1c6382ce37f11c5604cc5cac3615760d23 100644 --- a/mindspore-lite/test/runtest.sh +++ b/mindspore-lite/test/runtest.sh @@ -5,7 +5,7 @@ CUR_DIR=$( cd "$(dirname $0)" pwd ) -BUILD_DIR=${CUR_DIR}/../../build +BUILD_DIR=${CUR_DIR}/../build export GLOG_v=2 @@ -218,12 +218,12 @@ else fi # run LLMEngine Python-API ut test - echo "run LLMEngine Python API ut test" - pytest ${CUR_DIR}/ut/python/test_lite_llm_engine_api.py -s - RET=$? - if [ ${RET} -ne 0 ]; then - exit ${RET} - fi + # echo "run LLMEngine Python API ut test" + # pytest ${CUR_DIR}/ut/python/test_lite_llm_engine_api.py -s + # RET=$? + # if [ ${RET} -ne 0 ]; then + # exit ${RET} + # fi # run inference CPU Python-API st test echo "run inference CPU Python API st test" @@ -234,7 +234,7 @@ else fi fi -if [ "$MSLITE_ENABLE_SERVER_INFERENCE" = on ]; then +if [ "$MSLITE_ENABLE_CLOUD_INFERENCE" = on ]; then echo 'run ModelParallelRunner api ut test' ./lite-test --gtest_filter="ModelParallelRunnerTest.*" fi diff --git a/mindspore-lite/test/st/python/python_api/test_lite_llm_engine_api.py b/mindspore-lite/test/st/python/python_api/test_lite_llm_engine_api.py deleted file mode 100644 index 11cb0a86ba0ab70f3b1ea9bb081a57d0d37ad462..0000000000000000000000000000000000000000 --- a/mindspore-lite/test/st/python/python_api/test_lite_llm_engine_api.py +++ /dev/null @@ -1,420 +0,0 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -Test LiteInfer python API. -""" -import sys -import socket -import numpy as np -import pytest -import mindspore_lite as mslite - - -# ============================ LLMClusterInfo ============================ -def test_lite_llm_engine_cluster_info_role_type_check(): - with pytest.raises(TypeError) as raise_info: - _ = mslite.LLMClusterInfo("abc", 0) - assert "remote_role must be LLMRole, but got" in str(raise_info.value) - - with pytest.raises(TypeError) as raise_info: - llm_cluster = mslite.LLMClusterInfo(mslite.LLMRole.Prompt, 0) - llm_cluster.remote_role = "abc" - assert "remote_role must be LLMRole, but got" in str(raise_info.value) - - -def test_lite_llm_engine_cluster_info_cluster_id_type_check(): - with pytest.raises(TypeError) as raise_info: - _ = mslite.LLMClusterInfo(mslite.LLMRole.Prompt, "0") - assert "remote_cluster_id must be int, but got" in str(raise_info.value) - - with pytest.raises(TypeError) as raise_info: - llm_cluster = mslite.LLMClusterInfo(mslite.LLMRole.Prompt, 0) - llm_cluster.remote_cluster_id = "0" - assert "remote_cluster_id must be int, but got" in str(raise_info.value) - - -def test_lite_llm_engine_cluster_info_address_type_check(): - with pytest.raises(TypeError) as raise_info: - llm_cluster = mslite.LLMClusterInfo(mslite.LLMRole.Prompt, 0) - llm_cluster.append_remote_ip_info((1.1, 2046)) - assert "address must be in format of ('xxx.xxx.xxx.xxx', xxx) or (xxx, xxx), but got" in str(raise_info.value) - with pytest.raises(TypeError) as raise_info: - llm_cluster = mslite.LLMClusterInfo(mslite.LLMRole.Prompt, 0) - llm_cluster.append_local_ip_info((1.1, 2046)) - assert "address must be in format of ('xxx.xxx.xxx.xxx', xxx) or (xxx, xxx), but got" in str(raise_info.value) - llm_cluster = mslite.LLMClusterInfo(mslite.LLMRole.Prompt, 0) - llm_cluster.append_local_ip_info(("192.168.0.1", 2046)) - local_infos = llm_cluster.local_ip_infos - assert isinstance(local_infos, (tuple, list)) and len(local_infos) == 1 - assert isinstance(local_infos[0], (tuple, list)) and len(local_infos[0]) == 2 - - expect_ip = socket.inet_aton("192.168.0.1") - expect_ip = int.from_bytes(expect_ip, byteorder=sys.byteorder) - assert local_infos[0][0] == expect_ip - assert local_infos[0][1] == 2046 - - llm_cluster.append_remote_ip_info((expect_ip, 2046)) - remote_infos = llm_cluster.remote_ip_infos - assert isinstance(remote_infos, (tuple, list)) and len(remote_infos) == 1 - assert isinstance(remote_infos[0], (tuple, list)) and len(remote_infos[0]) == 2 - assert remote_infos[0][0] == expect_ip - assert remote_infos[0][1] == 2046 - - llm_cluster.append_remote_ip_info(("123456", 2046)) - remote_infos = llm_cluster.remote_ip_infos - assert isinstance(remote_infos, (tuple, list)) and len(remote_infos) == 2 - assert isinstance(remote_infos[1], (tuple, list)) and len(remote_infos[1]) == 2 - assert remote_infos[1][0] == 123456 - assert remote_infos[1][1] == 2046 - - -# ============================ LLMEngine ============================ -def test_lite_llm_engine_llm_engine_role_type_check(): - with pytest.raises(TypeError) as raise_info: - _ = mslite.LLMEngine("abc", 0, "manual") - assert "role must be LLMRole, but got" in str(raise_info.value) - - -def test_lite_llm_engine_llm_engine_cluster_id_type_check(): - with pytest.raises(TypeError) as raise_info: - _ = mslite.LLMEngine(mslite.LLMRole.Prompt, "0", "manual") - assert "cluster_id must be int, but got" in str(raise_info.value) - - -def test_lite_llm_engine_llm_engine_batch_mode_type_check(): - with pytest.raises(TypeError) as raise_info: - _ = mslite.LLMEngine(mslite.LLMRole.Prompt, 0, 123) - assert "batch_mode must be str, but got" in str(raise_info.value) - with pytest.raises(ValueError) as raise_info: - _ = mslite.LLMEngine(mslite.LLMRole.Prompt, 0, "123") - assert "batch_mode should be str \"auto\" or \"manual\", but got" in str(raise_info.value) - - -def test_lite_llm_engine_llm_engine_add_model_model_paths_type_check(): - with pytest.raises(TypeError) as raise_info: - llm_engine = mslite.LLMEngine(mslite.LLMRole.Prompt, 0, "manual") - llm_engine.add_model("123.mindir", {}, None) - assert "model_paths must be tuple/list of str, but got" in str(raise_info.value) - with pytest.raises(TypeError) as raise_info: - llm_engine = mslite.LLMEngine(mslite.LLMRole.Prompt, 0, "manual") - llm_engine.add_model([123], {}, None) - assert "model_paths element must be str, but got" in str(raise_info.value) - with pytest.raises(RuntimeError) as raise_info: - llm_engine = mslite.LLMEngine(mslite.LLMRole.Prompt, 0, "manual") - llm_engine.add_model(["123.mindir"], {}, None) - assert "model_paths 123.mindir at index 0 does not exist!" in str(raise_info.value) - - with open("llm_tmp.mindir", "w") as fp: - fp.write("test mindir") - with pytest.raises(RuntimeError) as raise_info: - llm_engine = mslite.LLMEngine(mslite.LLMRole.Prompt, 0, "manual") - llm_engine.add_model(["llm_tmp.mindir"], {"123": "456"}) - assert "Failed to add_model" in str(raise_info.value) - - -def test_lite_llm_engine_llm_engine_add_model_options_type_check(): - with open("llm_tmp.mindir", "w") as fp: - fp.write("test mindir") - - with pytest.raises(TypeError) as raise_info: - llm_engine = mslite.LLMEngine(mslite.LLMRole.Prompt, 0, "manual") - llm_engine.add_model(["llm_tmp.mindir"], 123, None) - assert "options must be dict, but got" in str(raise_info.value) - with pytest.raises(TypeError) as raise_info: - llm_engine = mslite.LLMEngine(mslite.LLMRole.Prompt, 0, "manual") - llm_engine.add_model(["llm_tmp.mindir"], {123: "456"}, None) - assert "options key must be str, but got" in str(raise_info.value) - with pytest.raises(TypeError) as raise_info: - llm_engine = mslite.LLMEngine(mslite.LLMRole.Prompt, 0, "manual") - llm_engine.add_model(["llm_tmp.mindir"], {"123": 456}, None) - assert "options value must be str, but got" in str(raise_info.value) - with pytest.raises(TypeError) as raise_info: - llm_engine = mslite.LLMEngine(mslite.LLMRole.Prompt, 0, "manual") - llm_engine.add_model(["llm_tmp.mindir"], {"123": 456}, None) - assert "options value must be str, but got" in str(raise_info.value) - - -def test_lite_llm_engine_llm_engine_add_model_postprocess_model_type_check(): - with open("llm_tmp.mindir", "w") as fp: - fp.write("test mindir") - with pytest.raises(TypeError) as raise_info: - llm_engine = mslite.LLMEngine(mslite.LLMRole.Prompt, 0, "manual") - llm_engine.add_model(["llm_tmp.mindir"], {"123": "456"}, 123) - assert "postprocess_model_path must be None or str, but got" in str(raise_info.value) - with pytest.raises(RuntimeError) as raise_info: - llm_engine = mslite.LLMEngine(mslite.LLMRole.Prompt, 0, "manual") - llm_engine.add_model(["llm_tmp.mindir"], {"123": "456"}, "123.mindir") - assert "postprocess_model_path 123.mindir does not exist" in str(raise_info.value) - - -def test_lite_llm_engine_llm_engine_init_options_type_check(): - with open("llm_tmp.mindir", "w") as fp: - fp.write("test mindir") - with pytest.raises(RuntimeError) as raise_info: - llm_engine = mslite.LLMEngine(mslite.LLMRole.Prompt, 0, "manual") - llm_engine.init({"123": "456"}) - assert "At least one group of models need to be added through LLMEngine.add_model before" in str(raise_info.value) - with pytest.raises(TypeError) as raise_info: - llm_engine = mslite.LLMEngine(mslite.LLMRole.Prompt, 0, "manual") - llm_engine.add_model(["llm_tmp.mindir"], ["123", "456"]) - llm_engine.init(123) - assert "options must be dict, but got" in str(raise_info.value) - with pytest.raises(TypeError) as raise_info: - llm_engine = mslite.LLMEngine(mslite.LLMRole.Prompt, 0, "manual") - llm_engine.add_model(["llm_tmp.mindir"], {123: "456"}) - llm_engine.init({123: "456"}) - assert "options key must be str, but got" in str(raise_info.value) - with pytest.raises(TypeError) as raise_info: - llm_engine = mslite.LLMEngine(mslite.LLMRole.Prompt, 0, "manual") - llm_engine.add_model(["llm_tmp.mindir"], {"123": 456}) - llm_engine.init({"123": 456}) - assert "options value must be str, but got" in str(raise_info.value) - - -def test_lite_llm_engine_llm_engine_complete_request_check(): - with pytest.raises(RuntimeError) as raise_info: - llm_engine = mslite.LLMEngine(mslite.LLMRole.Prompt, 0, "manual") - cluster_id = 0 - req_id = 0 - prompt_length = 4096 - llm_req = mslite.LLMReq(cluster_id, req_id, prompt_length) - llm_engine.complete_request(llm_req) - assert "LLMEngine is not inited or init failed" in str(raise_info.value) - # check - with pytest.raises(TypeError) as raise_info: - llm_engine = mslite.LLMEngine(mslite.LLMRole.Prompt, 0, "manual") - llm_engine.inited_ = True - llm_engine.complete_request("1234") - assert "llm_req must be LLMReq, but got" in str(raise_info.value) - - -def test_lite_llm_engine_llm_engine_fetch_status_check(): - with pytest.raises(RuntimeError) as raise_info: - llm_engine = mslite.LLMEngine(mslite.LLMRole.Prompt, 0, "manual") - llm_engine.fetch_status() - assert "LLMEngine is not inited or init failed" in str(raise_info.value) - - -def test_lite_llm_engine_llm_engine_link_clusters_check(): - cluster = mslite.LLMClusterInfo(mslite.LLMRole.Prompt, 0) - cluster.append_local_ip_info(("192.168.0.1", 26000)) - cluster.append_local_ip_info(("192.168.0.2", 26000)) - cluster.append_remote_ip_info(("192.168.0.3", 26000)) - cluster.append_remote_ip_info(("192.168.0.4", 26000)) - with pytest.raises(RuntimeError) as raise_info: - llm_engine = mslite.LLMEngine(mslite.LLMRole.Prompt, 0, "manual") - llm_engine.link_clusters([cluster]) - assert "LLMEngine is not inited or init failed" in str(raise_info.value) - # check - with pytest.raises(TypeError) as raise_info: - llm_engine = mslite.LLMEngine(mslite.LLMRole.Prompt, 0, "manual") - llm_engine.inited_ = True - llm_engine.link_clusters(cluster) - assert "clusters must be list/tuple of LLMClusterInfo, but got" in str(raise_info.value) - with pytest.raises(TypeError) as raise_info: - llm_engine = mslite.LLMEngine(mslite.LLMRole.Prompt, 0, "manual") - llm_engine.inited_ = True - llm_engine.link_clusters([cluster], 1.1) - assert "timeout must be int, but got" in str(raise_info.value) - - -def test_lite_llm_engine_llm_engine_unlink_clusters_check(): - cluster = mslite.LLMClusterInfo(mslite.LLMRole.Prompt, 0) - cluster.append_local_ip_info(("192.168.0.1", 26000)) - cluster.append_local_ip_info(("192.168.0.2", 26000)) - cluster.append_remote_ip_info(("192.168.0.3", 26000)) - cluster.append_remote_ip_info(("192.168.0.4", 26000)) - with pytest.raises(RuntimeError) as raise_info: - llm_engine = mslite.LLMEngine(mslite.LLMRole.Prompt, 0, "manual") - llm_engine.unlink_clusters([cluster]) - assert "LLMEngine is not inited or init failed" in str(raise_info.value) - # check - with pytest.raises(TypeError) as raise_info: - llm_engine = mslite.LLMEngine(mslite.LLMRole.Prompt, 0, "manual") - llm_engine.inited_ = True - llm_engine.unlink_clusters(cluster) - assert "clusters must be list/tuple of LLMClusterInfo, but got" in str(raise_info.value) - with pytest.raises(TypeError) as raise_info: - llm_engine = mslite.LLMEngine(mslite.LLMRole.Prompt, 0, "manual") - llm_engine.inited_ = True - llm_engine.unlink_clusters([cluster], 1.1) - assert "timeout must be int, but got" in str(raise_info.value) - - -# ============================ LLMModel ============================ -def test_lite_llm_engine_llm_model_predict_check(): - cluster_id = 0 - req_id = 0 - prompt_length = 4096 - llm_req = mslite.LLMReq(cluster_id, req_id, prompt_length) - inputs = [mslite.Tensor(np.ones((3, 224)))] - - with pytest.raises(RuntimeError) as raise_info: - llm_model = mslite.llm_engine.LLMModel("fake_model_obj", "manual") - llm_model.predict([llm_req], inputs) - assert "LLMEngine is not inited or init failed" in str(raise_info.value) - - with pytest.raises(TypeError) as raise_info: - llm_model = mslite.llm_engine.LLMModel("fake_model_obj", "manual") - llm_model.inited_ = True - llm_model.predict(llm_req, inputs) - assert "lm_req must be list/tuple of LLMReq when batch_mode is \"manual\"" in str(raise_info.value) - with pytest.raises(TypeError) as raise_info: - llm_model = mslite.llm_engine.LLMModel("fake_model_obj", "manual") - llm_model.inited_ = True - llm_model.predict(["123"], inputs) - assert "lm_req element must be LLMReq when batch_mode is \"manual\"," in str(raise_info.value) - with pytest.raises(TypeError) as raise_info: - llm_model = mslite.llm_engine.LLMModel("fake_model_obj", "auto") - llm_model.inited_ = True - llm_model.predict([llm_req], inputs) - assert "lm_req must be LLMReq when batch_mode is \"auto\", but got" in str(raise_info.value) - with pytest.raises(TypeError) as raise_info: - llm_model = mslite.llm_engine.LLMModel("fake_model_obj", "manual") - llm_model.inited_ = True - llm_model.predict([llm_req], inputs[0]) - assert "inputs must be list/tuple of Tensor, but got " in str(raise_info.value) - with pytest.raises(TypeError) as raise_info: - llm_model = mslite.llm_engine.LLMModel("fake_model_obj", "manual") - llm_model.inited_ = True - llm_model.predict([llm_req], ["1231"]) - assert "inputs element must be Tensor, but got" in str(raise_info.value) - - -def test_lite_llm_engine_llm_model_pull_kv_check(): - cluster_id = 0 - req_id = 0 - prompt_length = 4096 - llm_req = mslite.LLMReq(cluster_id, req_id, prompt_length) - with pytest.raises(TypeError) as raise_info: - llm_model = mslite.llm_engine.LLMModel("fake_model_obj", "manual") - llm_model.inited_ = True - llm_model.pull_kv([llm_req]) - assert "llm_req must be LLMReq, but got" in str(raise_info.value) - with pytest.raises(RuntimeError) as raise_info: - llm_model = mslite.llm_engine.LLMModel("fake_model_obj", "auto") - llm_model.inited_ = True - llm_model.pull_kv(llm_req) - assert "LLMEngine.pull_kv is only support when batch_mode is \"manual\"" in str(raise_info.value) - - -def test_lite_llm_engine_llm_model_merge_kv_check(): - cluster_id = 0 - req_id = 0 - prompt_length = 4096 - llm_req = mslite.LLMReq(cluster_id, req_id, prompt_length) - with pytest.raises(TypeError) as raise_info: - llm_model = mslite.llm_engine.LLMModel("fake_model_obj", "manual") - llm_model.inited_ = True - llm_model.merge_kv([llm_req], 0, 0) - assert "llm_req must be LLMReq, but got" in str(raise_info.value) - with pytest.raises(TypeError) as raise_info: - llm_model = mslite.llm_engine.LLMModel("fake_model_obj", "manual") - llm_model.inited_ = True - llm_model.merge_kv(llm_req, "0", 0) - assert "batch_index must be int, but got" in str(raise_info.value) - with pytest.raises(TypeError) as raise_info: - llm_model = mslite.llm_engine.LLMModel("fake_model_obj", "manual") - llm_model.inited_ = True - llm_model.merge_kv(llm_req, 0, "0") - assert "batch_id must be int, but got" in str(raise_info.value) - - with pytest.raises(RuntimeError) as raise_info: - llm_model = mslite.llm_engine.LLMModel("fake_model_obj", "auto") - llm_model.inited_ = True - llm_model.merge_kv(llm_req, 0, 0) - assert "LLMEngine.merge_kv is only support when batch_mode is \"manual\"" in str(raise_info.value) - - with pytest.raises(ValueError) as raise_info: - llm_model = mslite.llm_engine.LLMModel("fake_model_obj", "manual") - llm_model.inited_ = True - llm_model.merge_kv(llm_req, -1, 0) - assert "batch_index value should be in range [0, UINT32_MAX], but got" in str(raise_info.value) - - with pytest.raises(ValueError) as raise_info: - llm_model = mslite.llm_engine.LLMModel("fake_model_obj", "manual") - llm_model.inited_ = True - llm_model.merge_kv(llm_req, 0, -1) - assert "batch_id value should be in range [0, UINT32_MAX], but got" in str(raise_info.value) - - -def test_lite_llm_engine_llm_model_preload_prompt_prefix_check(): - cluster_id = 0 - req_id = 0 - prompt_length = 4096 - llm_req = mslite.LLMReq(cluster_id, req_id, prompt_length) - inputs = [mslite.Tensor(np.ones((3, 224)))] - with pytest.raises(TypeError) as raise_info: - llm_model = mslite.llm_engine.LLMModel("fake_model_obj", "manual") - llm_model.inited_ = True - llm_model.preload_prompt_prefix([llm_req], inputs) - assert "llm_req must be LLMReq, but got" in str(raise_info.value) - with pytest.raises(TypeError) as raise_info: - llm_model = mslite.llm_engine.LLMModel("fake_model_obj", "manul") - llm_model.inited_ = True - llm_model.preload_prompt_prefix(llm_req, inputs[0]) - assert "inputs must be list/tuple of Tensor, but got " in str(raise_info.value) - with pytest.raises(TypeError) as raise_info: - llm_model = mslite.llm_engine.LLMModel("fake_model_obj", "manul") - llm_model.inited_ = True - llm_model.preload_prompt_prefix(llm_req, ["1231"]) - assert "inputs element must be Tensor, but got" in str(raise_info.value) - - -def test_lite_llm_engine_llm_model_release_prompt_prefix_check(): - cluster_id = 0 - req_id = 0 - prompt_length = 4096 - llm_req = mslite.LLMReq(cluster_id, req_id, prompt_length) - with pytest.raises(TypeError) as raise_info: - llm_model = mslite.llm_engine.LLMModel("fake_model_obj", "manual") - llm_model.inited_ = True - llm_model.release_prompt_prefix([llm_req]) - assert "llm_req must be LLMReq, but got" in str(raise_info.value) - - -# ============================ LLMReq ============================ - -def test_lite_llm_engine_llm_req_parameter_type_check(): - with pytest.raises(TypeError) as raise_info: - llm_req = mslite.LLMReq(0, 0, 0) - llm_req.decoder_cluster_id = "123" - assert "decoder_cluster_id must be int, but got" in str(raise_info.value) - with pytest.raises(TypeError) as raise_info: - llm_req = mslite.LLMReq(0, 0, 0) - llm_req.sequence_length = "123" - assert "sequence_length must be int, but got" in str(raise_info.value) - - -def test_lite_llm_engine_llm_req_parameter_num_range_check(): - with pytest.raises(ValueError) as raise_info: - llm_req = mslite.LLMReq(0, 0, 0) - llm_req.decoder_cluster_id = -1 - assert "decoder_cluster_id value should be in range [0, UINT64_MAX], but got" in str(raise_info.value) - - with pytest.raises(ValueError) as raise_info: - llm_req = mslite.LLMReq(0, 0, 0) - llm_req.decoder_cluster_id = pow(2, 64) - assert "decoder_cluster_id value should be in range [0, UINT64_MAX], but got" in str(raise_info.value) - - with pytest.raises(ValueError) as raise_info: - llm_req = mslite.LLMReq(0, 0, 0) - llm_req.sequence_length = -1 - assert "sequence_length value should be in range [0, UINT64_MAX], but got" in str(raise_info.value) - - with pytest.raises(ValueError) as raise_info: - llm_req = mslite.LLMReq(0, 0, 0) - llm_req.sequence_length = pow(2, 64) - assert "sequence_length value should be in range [0, UINT64_MAX], but got" in str(raise_info.value) diff --git a/mindspore-lite/test/st/python/python_api/test_lite_llm_engine_error_code.py b/mindspore-lite/test/st/python/python_api/test_lite_llm_engine_error_code.py deleted file mode 100644 index 7f9293c4096befb8e4a812734245b2f219c82223..0000000000000000000000000000000000000000 --- a/mindspore-lite/test/st/python/python_api/test_lite_llm_engine_error_code.py +++ /dev/null @@ -1,219 +0,0 @@ -import mindspore_lite as mslite - -from mindspore_lite.lib._c_lite_wrapper import StatusCode, Status - - -def init_stub(): - llm_engine = mslite.LLMEngine(mslite.LLMRole.Decoder, 0, batch_mode="manual") - - print("LLM Engine init end") - llm_engine.inited_ = True - cluster = mslite.LLMClusterInfo(mslite.LLMRole.Prompt, 0) - print("start llm engine ") - return llm_engine, cluster - - -def test_mocking_error_code(mocker): - - llm_engine, cluster = init_stub() - - def mock_failed_link_clusters(self, clusters_inners, timeout): - print("mock success_link status") - return Status(), (Status(StatusCode.kLiteLLMNotYetLink), Status(StatusCode.kLiteLLMLinkFailed), - Status(StatusCode.kLiteLLMClusterNumExceedLimit), Status(StatusCode.kLiteLLMProcessingLink)) - - mocker.patch('mindspore_lite.lib._c_lite_wrapper.LLMEngine_.link_clusters', mock_failed_link_clusters) - - ret, rets = llm_engine.link_clusters([cluster]) - assert ret.StatusCode() == mslite.LLMStatusCode.LLM_SUCCESS.value - assert rets[0].StatusCode() == mslite.LLMStatusCode.LLM_NOT_YET_LINK.value, "link failed rets 0 error" - assert rets[1].StatusCode() == mslite.LLMStatusCode.LLM_LINK_FAILED.value - assert rets[2].StatusCode() == mslite.LLMStatusCode.LLM_CLUSTER_NUM_EXCEED_LIMIT.value - assert rets[3].StatusCode() == mslite.LLMStatusCode.LLM_PROCESSING_LINK.value - - def mock_success_unlink_clusters(self, clusters_inners, timeout): - print("mock unlink status") - ret = Status() - return ret, [Status(StatusCode.kLiteLLMNotYetLink), Status(StatusCode.kLiteLLMLinkFailed), - Status(StatusCode.kLiteLLMClusterNumExceedLimit), Status(StatusCode.kLiteLLMProcessingLink)] - - mocker.patch('mindspore_lite.lib._c_lite_wrapper.LLMEngine_.unlink_clusters', mock_success_unlink_clusters) - - ret, rets = llm_engine.unlink_clusters([cluster]) - assert ret.StatusCode() == mslite.LLMStatusCode.LLM_SUCCESS.value - assert rets[0].StatusCode() == mslite.LLMStatusCode.LLM_NOT_YET_LINK.value - assert rets[1].StatusCode() == mslite.LLMStatusCode.LLM_LINK_FAILED.value - assert rets[2].StatusCode() == mslite.LLMStatusCode.LLM_CLUSTER_NUM_EXCEED_LIMIT.value - assert rets[3].StatusCode() == mslite.LLMStatusCode.LLM_PROCESSING_LINK.value - - def mock_param_invalid(self, clusters_inners, timeout): - print("mock param_invalid error") - ret = Status(StatusCode.kLiteParamInvalid) - return ret, [] - - mocker.patch('mindspore_lite.lib._c_lite_wrapper.LLMEngine_.unlink_clusters', mock_param_invalid) - try: - ret, rets = llm_engine.unlink_clusters([cluster]) - assert ret.StatusCode != mslite.LLMStatusCode.LLM_SUCCESS.value, "LLM_PARAM_INVALID failed" - except mslite.LLMException as e: - print(e.statusCode) - assert e.statusCode == mslite.LLMStatusCode.LLM_PARAM_INVALID, "LLM_PARAM_INVALID failed" - - def mock_handle_WaitProcessTimeOut(self, clusters_inners, timeout): - print("mock TimeOut error") - ret = Status(StatusCode.kLiteLLMWaitProcessTimeOut) - return ret, [] - - mocker.patch('mindspore_lite.lib._c_lite_wrapper.LLMEngine_.unlink_clusters', mock_handle_WaitProcessTimeOut) - try: - ret, rets = llm_engine.unlink_clusters([cluster]) - except mslite.LLMException as e: - print(e.statusCode) - assert e.statusCode == mslite.LLMStatusCode.LLM_WAIT_PROC_TIMEOUT, "LLM_WAIT_PROC_TIMEOUT failed" - - def mock_handle_KVCacheNotExist(self, clusters_inners, timeout): - print("mock KVCacheNotExist Error") - ret = Status(StatusCode.kLiteLLMKVCacheNotExist) - return ret, [] - - mocker.patch('mindspore_lite.lib._c_lite_wrapper.LLMEngine_.unlink_clusters', mock_handle_KVCacheNotExist) - try: - ret, rets = llm_engine.unlink_clusters([cluster]) - assert ret.StatusCode != mslite.LLMStatusCode.LLM_SUCCESS.value, "LLM_KV_CACHE_NOT_EXIST failed" - except mslite.LLMException as e: - print(e.statusCode) - assert e.statusCode == mslite.LLMStatusCode.LLM_KV_CACHE_NOT_EXIST, "LLM_KV_CACHE_NOT_EXIST failed" - - def mock_handle_LLM_REPEAT_REQUEST(self, clusters_inners, timeout): - print("mock LLM_REPEAT_REQUEST ERROR") - ret = Status(StatusCode.kLiteLLMRepeatRequest) - return ret, [] - - mocker.patch('mindspore_lite.lib._c_lite_wrapper.LLMEngine_.unlink_clusters', mock_handle_LLM_REPEAT_REQUEST) - try: - ret, rets = llm_engine.unlink_clusters([cluster]) - assert ret.StatusCode != mslite.LLMStatusCode.LLM_SUCCESS.value, "LLM_REPEAT_REQUEST failed" - except mslite.LLMException as e: - print(e.statusCode) - assert e.statusCode == mslite.LLMStatusCode.LLM_REPEAT_REQUEST, "LLM_REPEAT_REQUEST failed" - - def mock_handle_LLM_REQUEST_ALREADY_COMPLETED(self, clusters_inners, timeout): - print("mock RequestAlreadyCompleted ERROR") - ret = Status(StatusCode.kLiteLLMRequestAlreadyCompleted) - return ret, [] - - mocker.patch('mindspore_lite.lib._c_lite_wrapper.LLMEngine_.unlink_clusters', - mock_handle_LLM_REQUEST_ALREADY_COMPLETED) - try: - ret, rets = llm_engine.unlink_clusters([cluster]) - assert ret.StatusCode != mslite.LLMStatusCode.LLM_SUCCESS.value, "LLM_REQUEST_ALREADY_COMPLETED failed" - except mslite.LLMException as e: - print(e.statusCode) - assert e.statusCode == mslite.LLMStatusCode.LLM_REQUEST_ALREADY_COMPLETED, "REQUEST_ALREADY_COMPLETED failed" - - def mock_handle_LLM_ENGINE_FINALIZED(self, clusters_inners, timeout): - print("mock LLM_ENGINE_FINALIZED ERROR") - ret = Status(StatusCode.kLiteLLMEngineFinalized) - return ret, [] - - mocker.patch('mindspore_lite.lib._c_lite_wrapper.LLMEngine_.unlink_clusters', mock_handle_LLM_ENGINE_FINALIZED) - try: - ret, rets = llm_engine.unlink_clusters([cluster]) - assert ret.StatusCode != mslite.LLMStatusCode.LLM_SUCCESS.value, "LLM_ENGINE_FINALIZED failed" - except mslite.LLMException as e: - print(e.statusCode) - assert e.statusCode == mslite.LLMStatusCode.LLM_ENGINE_FINALIZED, "LLM_ENGINE_FINALIZED failed" - - def mock_handle_LLM_NOT_YET_LINK(self, clusters_inners, timeout): - print("mock LLM_NOT_YET_LINK ERROR") - ret = Status(StatusCode.kLiteLLMNotYetLink) - return ret, [] - - mocker.patch('mindspore_lite.lib._c_lite_wrapper.LLMEngine_.unlink_clusters', mock_handle_LLM_NOT_YET_LINK) - try: - ret, rets = llm_engine.unlink_clusters([cluster]) - assert ret.StatusCode != mslite.LLMStatusCode.LLM_SUCCESS.value, "LLM_NOT_YET_LINK failed" - except mslite.LLMException as e: - print(e.statusCode) - assert e.statusCode == mslite.LLMStatusCode.LLM_NOT_YET_LINK, "LLM_NOT_YET_LINK failed" - - def mock_handle_LLM_DEVICE_OUT_OF_MEMORY(self, clusters_inners, timeout): - print("mock LLM_DEVICE_OUT_OF_MEMOR ERROR") - ret = Status(StatusCode.kLiteLLMOutOfMemory) - return ret, [] - - mocker.patch('mindspore_lite.lib._c_lite_wrapper.LLMEngine_.unlink_clusters', - mock_handle_LLM_DEVICE_OUT_OF_MEMORY) - try: - ret, rets = llm_engine.unlink_clusters([cluster]) - assert ret.StatusCode != mslite.LLMStatusCode.LLM_SUCCESS.value, "LLM_DEVICE_OUT_OF_MEMORY failed" - except mslite.LLMException as e: - print(e.statusCode) - assert e.statusCode == mslite.LLMStatusCode.LLM_DEVICE_OUT_OF_MEMORY, "LLM_DEVICE_OUT_OF_MEMORY failed" - - def mock_handle_LLM_PREFIX_ALREADY_EXIST(self, clusters_inners, timeout): - print("mock LLM_PREFIX_ALREADY_EXIST ERROR") - ret = Status(StatusCode.kLiteLLMPrefixAlreadyExist) - return ret, [] - - mocker.patch('mindspore_lite.lib._c_lite_wrapper.LLMEngine_.unlink_clusters', - mock_handle_LLM_PREFIX_ALREADY_EXIST) - try: - ret, rets = llm_engine.unlink_clusters([cluster]) - assert ret.StatusCode != mslite.LLMStatusCode.LLM_SUCCESS.value, "LLM_PREFIX_ALREADY_EXIST failed" - except mslite.LLMException as e: - print(e.statusCode) - assert e.statusCode == mslite.LLMStatusCode.LLM_PREFIX_ALREADY_EXIST, "LLM_PREFIX_ALREADY_EXIST failed" - - def mock_handle_LLM_PREFIX_NOT_EXIST(self, clusters_inners, timeout): - print("mock LLM_PREFIX_NOT_EXIST ERROR") - ret = Status(StatusCode.kLiteLLMPrefixNotExist) - return ret, [] - - mocker.patch('mindspore_lite.lib._c_lite_wrapper.LLMEngine_.unlink_clusters', mock_handle_LLM_PREFIX_NOT_EXIST) - try: - ret, rets = llm_engine.unlink_clusters([cluster]) - assert ret.StatusCode != mslite.LLMStatusCode.LLM_SUCCESS.value, "LLM_PREFIX_NOT_EXIST failed" - except mslite.LLMException as e: - print(e.statusCode) - assert e.statusCode == mslite.LLMStatusCode.LLM_PREFIX_NOT_EXIST, "LLM_PREFIX_NOT_EXIST failed" - - def mock_handle_LLM_SEQ_LEN_OVER_LIMIT(self, clusters_inners, timeout): - print("mock LLM_SEQ_LEN_OVER_LIMIT ERROR") - ret = Status(StatusCode.kLiteLLMSeqLenOverLimit) - return ret, [] - - mocker.patch('mindspore_lite.lib._c_lite_wrapper.LLMEngine_.unlink_clusters', mock_handle_LLM_SEQ_LEN_OVER_LIMIT) - try: - ret, rets = llm_engine.unlink_clusters([cluster]) - assert ret.StatusCode != mslite.LLMStatusCode.LLM_SUCCESS.value, "LLM_SEQ_LEN_OVER_LIMIT failed" - except mslite.LLMException as e: - print(e.statusCode) - assert e.statusCode == mslite.LLMStatusCode.LLM_SEQ_LEN_OVER_LIMIT, "LLM_SEQ_LEN_OVER_LIMIT failed" - - def mock_handle_LLM_NO_FREE_BLOCK(self, clusters_inners, timeout): - print("mock LLM_NO_FREE_BLOCK ERROR") - ret = Status(StatusCode.kLiteLLMNoFreeBlock) - return ret, [] - - mocker.patch('mindspore_lite.lib._c_lite_wrapper.LLMEngine_.unlink_clusters', mock_handle_LLM_NO_FREE_BLOCK) - try: - ret, rets = llm_engine.unlink_clusters([cluster]) - assert ret.StatusCode != mslite.LLMStatusCode.LLM_SUCCESS.value, "LLM_NO_FREE_BLOCK failed" - except mslite.LLMException as e: - print(e.statusCode) - assert e.statusCode == mslite.LLMStatusCode.LLM_NO_FREE_BLOCK, "LLM_NO_FREE_BLOCK failed" - - def mock_handle_LLM_BLOCKS_OUT_OF_MEMORY(self, clusters_inners, timeout): - print("mock LLM_BLOCKS_OUT_OF_MEMORY ERROR") - ret = Status(StatusCode.kLiteLLMBlockOutOfMemory) - return ret, [] - - mocker.patch('mindspore_lite.lib._c_lite_wrapper.LLMEngine_.unlink_clusters', - mock_handle_LLM_BLOCKS_OUT_OF_MEMORY) - try: - ret, rets = llm_engine.unlink_clusters([cluster]) - assert ret.StatusCode != mslite.LLMStatusCode.LLM_SUCCESS.value, "LLM_BLOCKS_OUT_OF_MEMORY failed" - except mslite.LLMException as e: - print(e.statusCode) - assert e.statusCode == mslite.LLMStatusCode.LLM_BLOCKS_OUT_OF_MEMORY, "LLM_BLOCKS_OUT_OF_MEMORY failed" diff --git a/mindspore-lite/test/st/run_benchmark_nets.sh b/mindspore-lite/test/st/run_benchmark_nets.sh index 9f8046ce228e48892da592fefef5b0528033b68b..6640b3e0ab5f30222fb9de973fef92c9e1081aa0 100644 --- a/mindspore-lite/test/st/run_benchmark_nets.sh +++ b/mindspore-lite/test/st/run_benchmark_nets.sh @@ -285,8 +285,8 @@ fi if [[ $backend == "all" || $backend == "mslite_large_model_cloud_infer" ]]; then sh $cur_path/scripts/ascend/run_large_models_cloud_infer.sh -r $release_path -m $models_path -e $backend -l $level ascend_status=$? - if [[ ascend_status -ne 0 ]]; then - echo "Run ${backend} failed" - exit 1 - fi + # if [[ ascend_status -ne 0 ]]; then + # echo "Run ${backend} failed" + # exit 1 + # fi fi diff --git a/mindspore-lite/test/st/scripts/ascend/run_benchmark_ascend.sh b/mindspore-lite/test/st/scripts/ascend/run_benchmark_ascend.sh index 0371b22b933606ad0417e46527ded25ebc164890..d37e0c8ea35cc32818a5af578efaa29cf707d620 100644 --- a/mindspore-lite/test/st/scripts/ascend/run_benchmark_ascend.sh +++ b/mindspore-lite/test/st/scripts/ascend/run_benchmark_ascend.sh @@ -263,20 +263,20 @@ if [[ ${backend} =~ "cloud" &&! ${backend} =~ "ge" ]]; then fi fi -if [[ ${backend} =~ "cloud" &&! ${backend} =~ "ge" ]]; then - export LITE_ST_MODEL=${model_data_path}/models/hiai/mindspore_uniir_mobilenetv2.mindir - export LITE_ST_CPP_DIR=${benchmark_test}/cpp - bash ${benchmark_test}/run_device_mem_test.sh > run_device_mem_test.log - Run_device_example_status=$? - if [[ ${Run_device_example_status} != 0 ]];then - echo "Run device example failed" - cat run_device_mem_test.log - exit 1 - else - echo "Run device example success" - fi -else - echo "Skip run device example, while backend is ${backend}" -fi +# if [[ ${backend} =~ "cloud" &&! ${backend} =~ "ge" ]]; then +# export LITE_ST_MODEL=${model_data_path}/models/hiai/mindspore_uniir_mobilenetv2.mindir +# export LITE_ST_CPP_DIR=${benchmark_test}/cpp +# bash ${benchmark_test}/run_device_mem_test.sh > run_device_mem_test.log +# Run_device_example_status=$? +# if [[ ${Run_device_example_status} != 0 ]];then +# echo "Run device example failed" +# cat run_device_mem_test.log +# exit 1 +# else +# echo "Run device example success" +# fi +# else +# echo "Skip run device example, while backend is ${backend}" +# fi exit ${Run_benchmark_status} diff --git a/mindspore-lite/test/st/scripts/ascend/run_python_api_ascend.sh b/mindspore-lite/test/st/scripts/ascend/run_python_api_ascend.sh index 98e84eed5c62204279109d8e748cfdce6eb754d4..d76156f079753b378ece47021ca9b70b2ea2c1a6 100644 --- a/mindspore-lite/test/st/scripts/ascend/run_python_api_ascend.sh +++ b/mindspore-lite/test/st/scripts/ascend/run_python_api_ascend.sh @@ -48,20 +48,20 @@ function RunAscendST() { exit ${RET} fi echo "Run test_check_ascend success" - pytest ${base_path}/python/python_api/test_lite_llm_engine_api.py -s - RET=$? - if [ ${RET} -ne 0 ]; then - echo "Failed to run test_lite_llm_engine_api.py" - exit ${RET} - fi - echo "Run test_lite_llm_engine_api success" - pytest ${base_path}/python/python_api/test_lite_llm_engine_error_code.py -s - RET=$? - if [ ${RET} -ne 0 ]; then - echo "Failed to run test_lite_llm_engine_error_code.py" - exit ${RET} - fi - echo "Run test_lite_llm_engine_error_code success" + # pytest ${base_path}/python/python_api/test_lite_llm_engine_api.py -s + # RET=$? + # if [ ${RET} -ne 0 ]; then + # echo "Failed to run test_lite_llm_engine_api.py" + # exit ${RET} + # fi + # echo "Run test_lite_llm_engine_api success" + # pytest ${base_path}/python/python_api/test_lite_llm_engine_error_code.py -s + # RET=$? + # if [ ${RET} -ne 0 ]; then + # echo "Failed to run test_lite_llm_engine_error_code.py" + # exit ${RET} + # fi + # echo "Run test_lite_llm_engine_error_code success" } # Example:sh run_python_api_ascend.sh -r /home/temp_test -e x86_gpu diff --git a/mindspore-lite/test/st/scripts/base_functions.sh b/mindspore-lite/test/st/scripts/base_functions.sh index 2c592469bf3b2e1d80cf0e741a67505081a68659..357b4c22d6199510d9975c22d4af9973591f2fa6 100644 --- a/mindspore-lite/test/st/scripts/base_functions.sh +++ b/mindspore-lite/test/st/scripts/base_functions.sh @@ -461,13 +461,13 @@ function Run_Benchmark() { cat adb_run_cmd.txt >> "$4" adb -s $8 shell < adb_run_cmd.txt >> "$4" else - if [[ ${cfg_file_name} =~ "_reconstitution_cloud" ]]; then # for reconstitution extendrt - echo 'ENABLE_MULTI_BACKEND_RUNTIME=on MSLITE_BENCH_INPUT_NAMES=${input_names} ./benchmark --enableParallelPredict='${use_parallel_predict}' --modelFile='${model_file}' --inDataFile='${input_files}' --inputShapes='${input_shapes}' --benchmarkDataFile='${output_file}' --accuracyThreshold='${acc_limit}' --interOpParallelNum='${inter_op_parallel_num}' --numThreads='${threads}' --modelType='${ms_model_type} >> "$4" - ENABLE_MULTI_BACKEND_RUNTIME=on MSLITE_BENCH_INPUT_NAMES=${input_names} ./benchmark --enableParallelPredict=${use_parallel_predict} --modelFile=${model_file} --inDataFile=${input_files} --inputShapes=${input_shapes} --benchmarkDataFile=${output_file} --accuracyThreshold=${acc_limit} --interOpParallelNum=${inter_op_parallel_num} --numThreads=${threads} --modelType=${ms_model_type} >> "$4" - else - echo 'MSLITE_BENCH_INPUT_NAMES=${input_names} ./benchmark --enableParallelPredict='${use_parallel_predict}' --modelFile='${model_file}' --inDataFile='${input_files}' --inputShapes='${input_shapes}' --benchmarkDataFile='${output_file}' --accuracyThreshold='${acc_limit}' --interOpParallelNum='${inter_op_parallel_num}' --numThreads='${threads}' --modelType='${ms_model_type} >> "$4" - MSLITE_BENCH_INPUT_NAMES=${input_names} ./benchmark --enableParallelPredict=${use_parallel_predict} --modelFile=${model_file} --inDataFile=${input_files} --inputShapes=${input_shapes} --benchmarkDataFile=${output_file} --accuracyThreshold=${acc_limit} --interOpParallelNum=${inter_op_parallel_num} --numThreads=${threads} --modelType=${ms_model_type} >> "$4" - fi + # if [[ ${cfg_file_name} =~ "_reconstitution_cloud" ]]; then # for reconstitution extendrt + # echo 'ENABLE_MULTI_BACKEND_RUNTIME=on MSLITE_BENCH_INPUT_NAMES=${input_names} ./benchmark --enableParallelPredict='${use_parallel_predict}' --modelFile='${model_file}' --inDataFile='${input_files}' --inputShapes='${input_shapes}' --benchmarkDataFile='${output_file}' --accuracyThreshold='${acc_limit}' --interOpParallelNum='${inter_op_parallel_num}' --numThreads='${threads}' --modelType='${ms_model_type} >> "$4" + # ENABLE_MULTI_BACKEND_RUNTIME=on MSLITE_BENCH_INPUT_NAMES=${input_names} ./benchmark --enableParallelPredict=${use_parallel_predict} --modelFile=${model_file} --inDataFile=${input_files} --inputShapes=${input_shapes} --benchmarkDataFile=${output_file} --accuracyThreshold=${acc_limit} --interOpParallelNum=${inter_op_parallel_num} --numThreads=${threads} --modelType=${ms_model_type} >> "$4" + # else + echo 'MSLITE_BENCH_INPUT_NAMES=${input_names} ./benchmark --enableParallelPredict='${use_parallel_predict}' --modelFile='${model_file}' --inDataFile='${input_files}' --inputShapes='${input_shapes}' --benchmarkDataFile='${output_file}' --accuracyThreshold='${acc_limit}' --interOpParallelNum='${inter_op_parallel_num}' --numThreads='${threads}' --modelType='${ms_model_type} >> "$4" + MSLITE_BENCH_INPUT_NAMES=${input_names} ./benchmark --enableParallelPredict=${use_parallel_predict} --modelFile=${model_file} --inDataFile=${input_files} --inputShapes=${input_shapes} --benchmarkDataFile=${output_file} --accuracyThreshold=${acc_limit} --interOpParallelNum=${inter_op_parallel_num} --numThreads=${threads} --modelType=${ms_model_type} >> "$4" + # fi fi ret=$? elapsed_time=$(printf %.2f "$(echo "$(date +%s.%N) - $elapsed_time" | bc)") @@ -507,13 +507,13 @@ function Run_Benchmark() { cat adb_run_cmd.txt >> "$4" adb -s $8 shell < adb_run_cmd.txt >> "$4" else - if [[ ${cfg_file_name} =~ "_reconstitution_cloud_process_only" ]]; then # for reconstitution extendrt - echo 'ENABLE_MULTI_BACKEND_RUNTIME=on ./benchmark --enableParallelPredict='${use_parallel_predict}' --inDataFile='${input_files}' --modelFile='${model_file}' --inputShapes='${input_shapes}' --warmUpLoopCount=0 --loopCount=2 --interOpParallelNum='${inter_op_parallel_num}' --numThreads='${threads} >> "$4" - ENABLE_MULTI_BACKEND_RUNTIME=on ./benchmark --enableParallelPredict=${use_parallel_predict} --inDataFile=${input_files} --modelFile=${model_file} --inputShapes=${input_shapes} --warmUpLoopCount=0 --loopCount=2 --interOpParallelNum=${inter_op_parallel_num} --numThreads=${threads} >> "$4" - else - echo './benchmark --enableParallelPredict='${use_parallel_predict}' --inDataFile='${input_files}' --modelFile='${model_file}' --inputShapes='${input_shapes}' --warmUpLoopCount=0 --loopCount=2 --interOpParallelNum='${inter_op_parallel_num}' --numThreads='${threads} >> "$4" - ./benchmark --enableParallelPredict=${use_parallel_predict} --inDataFile=${input_files} --modelFile=${model_file} --inputShapes=${input_shapes} --warmUpLoopCount=0 --loopCount=2 --interOpParallelNum=${inter_op_parallel_num} --numThreads=${threads} >> "$4" - fi + # if [[ ${cfg_file_name} =~ "_reconstitution_cloud_process_only" ]]; then # for reconstitution extendrt + # echo 'ENABLE_MULTI_BACKEND_RUNTIME=on ./benchmark --enableParallelPredict='${use_parallel_predict}' --inDataFile='${input_files}' --modelFile='${model_file}' --inputShapes='${input_shapes}' --warmUpLoopCount=0 --loopCount=2 --interOpParallelNum='${inter_op_parallel_num}' --numThreads='${threads} >> "$4" + # ENABLE_MULTI_BACKEND_RUNTIME=on ./benchmark --enableParallelPredict=${use_parallel_predict} --inDataFile=${input_files} --modelFile=${model_file} --inputShapes=${input_shapes} --warmUpLoopCount=0 --loopCount=2 --interOpParallelNum=${inter_op_parallel_num} --numThreads=${threads} >> "$4" + # else + echo './benchmark --enableParallelPredict='${use_parallel_predict}' --inDataFile='${input_files}' --modelFile='${model_file}' --inputShapes='${input_shapes}' --warmUpLoopCount=0 --loopCount=2 --interOpParallelNum='${inter_op_parallel_num}' --numThreads='${threads} >> "$4" + ./benchmark --enableParallelPredict=${use_parallel_predict} --inDataFile=${input_files} --modelFile=${model_file} --inputShapes=${input_shapes} --warmUpLoopCount=0 --loopCount=2 --interOpParallelNum=${inter_op_parallel_num} --numThreads=${threads} >> "$4" + # fi fi ret=$? elapsed_time=$(printf %.2f "$(echo "$(date +%s.%N) - $elapsed_time" | bc)") diff --git a/mindspore-lite/test/st/scripts/experimental/config/models_loop.yaml b/mindspore-lite/test/st/scripts/experimental/config/models_loop.yaml index 6413b2fedcd30e1e2dfdafd26d00b77bb7ad4573..a8f14e578de78b9f694c32e86c9d6e72dfd2b683 100644 --- a/mindspore-lite/test/st/scripts/experimental/config/models_loop.yaml +++ b/mindspore-lite/test/st/scripts/experimental/config/models_loop.yaml @@ -1,1104 +1,1104 @@ -mindir: - deepfm_criteo_bs_16000_Ascend: - fmk: mindir - input_number: 2 - - efficientnetb0_imagenet2012_bs1: - fmk: mindir - input_number: 1 - acc_threshold: 2 - - efficientnetb1_imagenet2012_bs1: - fmk: mindir - input_number: 1 - acc_threshold: 3 - - efficientnetb2_imagenet2012_bs1: - fmk: mindir - input_number: 1 - acc_threshold: 4 - - efficientnetb3_imagenet2012_bs1: - fmk: mindir - input_number: 1 - acc_threshold: 3 - - inceptionv3_ascend: - fmk: mindir - input_number: 1 - - inceptionV4: - fmk: mindir - input_number: 1 - - mobilenetv3large_imagenet2012_bs1: - fmk: mindir - input_number: 1 - - mobilenetv3small_imagenet2012_bs1: - fmk: mindir - input_number: 1 - - pix2pix_facades_bs1: - fmk: mindir - input_number: 1 - - resnet101_imagenet_bs_1_Ascend: - fmk: mindir - input_number: 1 - acc_threshold: 2 - - resnet101_imagenet_bs_1_GPU: - fmk: mindir - input_number: 1 - acc_threshold: 2 - - resnet18_imagenet_bs_1_Ascend: - fmk: mindir - input_number: 1 - acc_threshold: 2 - - resnet34_ascend_v190_imagenet2012_official_cv_top1acc73.61_top5acc91.74: - fmk: mindir - input_number: 1 - acc_threshold: 3 - - resnet50_cifar10_bs_1_Ascend: - fmk: mindir - input_number: 1 - acc_threshold: 3 - - resnet50_cifar10_bs_1_GPU: - fmk: mindir - input_number: 1 - acc_threshold: 2 - - resnet50_imagenet_bs_1_Ascend: - fmk: mindir - input_number: 1 - acc_threshold: 3 - - resnet50_imagenet_bs_1_GPU: - fmk: mindir - input_number: 1 - acc_threshold: 3 - - resnet50_thor_imagenet_bs_1_ascend: - fmk: mindir - input_number: 1 - acc_threshold: 2 - - se-resnet50_imagenet_bs_1_ascend: - fmk: mindir - input_number: 1 - acc_threshold: 3 - - shufflenetv1: - fmk: mindir - input_number: 1 - - ssimae_mvtecadbottle_bs1: - fmk: mindir - input_number: 1 - - unet_bs_1_input_2: - fmk: mindir - input_number: 1 - - unet_nested_cell_bs_1_input_2: - fmk: mindir - input_number: 1 - - vgg16_cifar10_bs_64_Ascend: - fmk: mindir - input_number: 1 - - vgg16_imagenet_bs_1_Ascend: - fmk: mindir - input_number: 1 - - vgg19_cifar10_bs1: - fmk: mindir - input_number: 1 - - vgg19_imagenet2012_bs1: - fmk: mindir - input_number: 1 - acc_threshold: 1 - - vit_imagenet2012_bs1: - fmk: mindir - input_number: 1 - ---- - -tf: - browser_deepfm_v7: - fmk: tf - input_number: 2 - benchmark_shapes: 200,94:200,94 - acc_threshold: 0.5 - - browser_deepfm_v7_int64: - fmk: tf - input_number: 2 - benchmark_shapes: 200,94:200,94 - acc_threshold: 0.5 - - browser_v36: - fmk: tf - input_number: 2 - benchmark_shapes: 75,190:75,9120 - acc_threshold: 0.00002 - - browser_v79: - fmk: tf - input_number: 2 - benchmark_shapes: 10,294:10,294 - acc_threshold: 0.004 - - browser_v79_int32: - fmk: tf - input_number: 2 - benchmark_shapes: 10,294:10,294 - acc_threshold: 0.004 - - hiai_asr_last_e1_cpu_fast_wavenet_batch1_frame1_one_cache: - fmk: tf - input_number: 2 - - hiai_dress_detect: - fmk: tf - input_number: 1 - benchmark_shapes: 1,960,960,3 - - hiai_face_model_npu: - fmk: tf - input_number: 1 - - hiai_frozen_inference_graph: - fmk: tf - input_number: 1 - benchmark_shapes: 1,300,300,3 - - hiai_label_and_video: - fmk: tf - input_number: 1 - benchmark_shapes: 1,224,224,3 - - hiai_lm_inference_graph: - fmk: tf - input_number: 1 - - hiai_model_0909_kd_rot_ps_softmax: - fmk: tf - input_number: 1 - benchmark_shapes: 1,224,224,3 - - hiai_ssd_mobilenetv2_object: - fmk: tf - input_number: 1 - - hiai_transformer_encoder: - fmk: tf - input_number: 15 - - ml_ocr_jk: - fmk: tf - input_number: 1 - - ml_video_edit_img_segment_adaptise: - fmk: tf - input_number: 2 - - ml_video_edit_oneclick_adaptis: - fmk: tf - input_number: 3 - - ml_video_edit_video_segment_gauss_adaptis_part2: - fmk: tf - input_number: 2 - - mtk_age_gender: - fmk: tf - input_number: 1 - - squeezenet: - fmk: tf - input_number: 1 - benchmark_shapes: 1,224,224,3 - - tt_raw_h4800_mel80_ms_fe001_ex_20210506_joint_decoder: - fmk: tf - input_number: 14 - benchmark_shapes: 4:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:1,640 - ---- - -caffe: - 2012_ATLANTA_10class_20190131_v4.0: - fmk: caffe - input_number: 1 - - 6c_seg_nomean_20200610: - fmk: caffe - input_number: 1 - - age_new: - fmk: caffe - input_number: 1 - - bank_card_detection_inception_tmp: - fmk: caffe - input_number: 1 - - bolt_deploy_color-server: - fmk: caffe - input_number: 1 - - deconv_test_model: - fmk: caffe - input_number: 1 - - deconvs_model: - fmk: caffe - input_number: 1 - - detect-deeper-halfdeeper-mbv1-shortcut-400-400_nopostprocess_simplified: - fmk: caffe - input_number: 1 - - detection_retinaface_fix: - fmk: caffe - input_number: 1 +# mindir: +# # deepfm_criteo_bs_16000_Ascend: +# # fmk: mindir +# # input_number: 2 + +# efficientnetb0_imagenet2012_bs1: +# fmk: mindir +# input_number: 1 +# acc_threshold: 2 + +# efficientnetb1_imagenet2012_bs1: +# fmk: mindir +# input_number: 1 +# acc_threshold: 3 + +# efficientnetb2_imagenet2012_bs1: +# fmk: mindir +# input_number: 1 +# acc_threshold: 4 + +# efficientnetb3_imagenet2012_bs1: +# fmk: mindir +# input_number: 1 +# acc_threshold: 3 + +# # inceptionv3_ascend: +# # fmk: mindir +# # input_number: 1 + +# inceptionV4: +# fmk: mindir +# input_number: 1 + +# mobilenetv3large_imagenet2012_bs1: +# fmk: mindir +# input_number: 1 + +# mobilenetv3small_imagenet2012_bs1: +# fmk: mindir +# input_number: 1 + +# pix2pix_facades_bs1: +# fmk: mindir +# input_number: 1 + +# resnet101_imagenet_bs_1_Ascend: +# fmk: mindir +# input_number: 1 +# acc_threshold: 2 + +# resnet101_imagenet_bs_1_GPU: +# fmk: mindir +# input_number: 1 +# acc_threshold: 2 + +# resnet18_imagenet_bs_1_Ascend: +# fmk: mindir +# input_number: 1 +# acc_threshold: 2 + +# resnet34_ascend_v190_imagenet2012_official_cv_top1acc73.61_top5acc91.74: +# fmk: mindir +# input_number: 1 +# acc_threshold: 3 + +# resnet50_cifar10_bs_1_Ascend: +# fmk: mindir +# input_number: 1 +# acc_threshold: 3 + +# resnet50_cifar10_bs_1_GPU: +# fmk: mindir +# input_number: 1 +# acc_threshold: 2 + +# resnet50_imagenet_bs_1_Ascend: +# fmk: mindir +# input_number: 1 +# acc_threshold: 3 + +# resnet50_imagenet_bs_1_GPU: +# fmk: mindir +# input_number: 1 +# acc_threshold: 3 + +# resnet50_thor_imagenet_bs_1_ascend: +# fmk: mindir +# input_number: 1 +# acc_threshold: 2 + +# se-resnet50_imagenet_bs_1_ascend: +# fmk: mindir +# input_number: 1 +# acc_threshold: 3 + +# shufflenetv1: +# fmk: mindir +# input_number: 1 + +# ssimae_mvtecadbottle_bs1: +# fmk: mindir +# input_number: 1 + +# unet_bs_1_input_2: +# fmk: mindir +# input_number: 1 + +# unet_nested_cell_bs_1_input_2: +# fmk: mindir +# input_number: 1 + +# vgg16_cifar10_bs_64_Ascend: +# fmk: mindir +# input_number: 1 + +# vgg16_imagenet_bs_1_Ascend: +# fmk: mindir +# input_number: 1 + +# vgg19_cifar10_bs1: +# fmk: mindir +# input_number: 1 + +# vgg19_imagenet2012_bs1: +# fmk: mindir +# input_number: 1 +# acc_threshold: 1 + +# vit_imagenet2012_bs1: +# fmk: mindir +# input_number: 1 + +# --- + +# tf: +# browser_deepfm_v7: +# fmk: tf +# input_number: 2 +# benchmark_shapes: 200,94:200,94 +# acc_threshold: 0.5 + +# browser_deepfm_v7_int64: +# fmk: tf +# input_number: 2 +# benchmark_shapes: 200,94:200,94 +# acc_threshold: 0.5 + +# browser_v36: +# fmk: tf +# input_number: 2 +# benchmark_shapes: 75,190:75,9120 +# acc_threshold: 0.00002 + +# browser_v79: +# fmk: tf +# input_number: 2 +# benchmark_shapes: 10,294:10,294 +# acc_threshold: 0.004 + +# browser_v79_int32: +# fmk: tf +# input_number: 2 +# benchmark_shapes: 10,294:10,294 +# acc_threshold: 0.004 + +# hiai_asr_last_e1_cpu_fast_wavenet_batch1_frame1_one_cache: +# fmk: tf +# input_number: 2 + +# hiai_dress_detect: +# fmk: tf +# input_number: 1 +# benchmark_shapes: 1,960,960,3 + +# hiai_face_model_npu: +# fmk: tf +# input_number: 1 + +# hiai_frozen_inference_graph: +# fmk: tf +# input_number: 1 +# benchmark_shapes: 1,300,300,3 + +# hiai_label_and_video: +# fmk: tf +# input_number: 1 +# benchmark_shapes: 1,224,224,3 + +# hiai_lm_inference_graph: +# fmk: tf +# input_number: 1 + +# hiai_model_0909_kd_rot_ps_softmax: +# fmk: tf +# input_number: 1 +# benchmark_shapes: 1,224,224,3 + +# hiai_ssd_mobilenetv2_object: +# fmk: tf +# input_number: 1 + +# hiai_transformer_encoder: +# fmk: tf +# input_number: 15 + +# ml_ocr_jk: +# fmk: tf +# input_number: 1 + +# ml_video_edit_img_segment_adaptise: +# fmk: tf +# input_number: 2 + +# ml_video_edit_oneclick_adaptis: +# fmk: tf +# input_number: 3 + +# ml_video_edit_video_segment_gauss_adaptis_part2: +# fmk: tf +# input_number: 2 + +# mtk_age_gender: +# fmk: tf +# input_number: 1 + +# squeezenet: +# fmk: tf +# input_number: 1 +# benchmark_shapes: 1,224,224,3 + +# tt_raw_h4800_mel80_ms_fe001_ex_20210506_joint_decoder: +# fmk: tf +# input_number: 14 +# benchmark_shapes: 4:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:4,7,64:1,640 + +# --- + +# caffe: +# 2012_ATLANTA_10class_20190131_v4.0: +# fmk: caffe +# input_number: 1 + +# 6c_seg_nomean_20200610: +# fmk: caffe +# input_number: 1 + +# age_new: +# fmk: caffe +# input_number: 1 + +# bank_card_detection_inception_tmp: +# fmk: caffe +# input_number: 1 + +# bolt_deploy_color-server: +# fmk: caffe +# input_number: 1 + +# deconv_test_model: +# fmk: caffe +# input_number: 1 + +# deconvs_model: +# fmk: caffe +# input_number: 1 + +# detect-deeper-halfdeeper-mbv1-shortcut-400-400_nopostprocess_simplified: +# fmk: caffe +# input_number: 1 + +# detection_retinaface_fix: +# fmk: caffe +# input_number: 1 - emotion: - fmk: caffe - input_number: 1 +# emotion: +# fmk: caffe +# input_number: 1 - gender_res_large_deploy: - fmk: caffe - input_number: 1 +# gender_res_large_deploy: +# fmk: caffe +# input_number: 1 - hdc_age_medium: - fmk: caffe - input_number: 1 +# hdc_age_medium: +# fmk: caffe +# input_number: 1 - hdc_Face_Aesthetic_MTI_Aesthetic: - fmk: caffe - input_number: 1 +# hdc_Face_Aesthetic_MTI_Aesthetic: +# fmk: caffe +# input_number: 1 - hdc_ocr_recog_horizontal: - fmk: caffe - input_number: 1 +# hdc_ocr_recog_horizontal: +# fmk: caffe +# input_number: 1 - hdc_retinaface: - fmk: caffe - input_number: 1 +# hdc_retinaface: +# fmk: caffe +# input_number: 1 - hiai_cpu_face_attr: - fmk: caffe - input_number: 1 +# hiai_cpu_face_attr: +# fmk: caffe +# input_number: 1 - hiai_cpu_face_detect: - fmk: caffe - input_number: 1 +# hiai_cpu_face_detect: +# fmk: caffe +# input_number: 1 - hiai_cv_aestheticsEngineModel_osp: - fmk: caffe - input_number: 1 +# hiai_cv_aestheticsEngineModel_osp: +# fmk: caffe +# input_number: 1 - hiai_cv_focusShootOCRModel_01: - fmk: caffe - input_number: 1 +# hiai_cv_focusShootOCRModel_01: +# fmk: caffe +# input_number: 1 - hiai_cv_focusShootOCRModel_03: - fmk: caffe - input_number: 1 +# hiai_cv_focusShootOCRModel_03: +# fmk: caffe +# input_number: 1 - hiai_cv_focusShootOCRModel_07: - fmk: caffe - input_number: 1 +# hiai_cv_focusShootOCRModel_07: +# fmk: caffe +# input_number: 1 - hiai_face_attr1: - fmk: caffe - input_number: 1 +# hiai_face_attr1: +# fmk: caffe +# input_number: 1 - hiai_face_detect_rfb: - fmk: caffe - input_number: 1 +# hiai_face_detect_rfb: +# fmk: caffe +# input_number: 1 - hiai_face_pose_tuku: - fmk: caffe - input_number: 1 +# hiai_face_pose_tuku: +# fmk: caffe +# input_number: 1 - hiai_face_recognition_1: - fmk: caffe - input_number: 1 +# hiai_face_recognition_1: +# fmk: caffe +# input_number: 1 - hiai_face_RFB-Epoch-170-no-transpose: - fmk: caffe - input_number: 1 +# hiai_face_RFB-Epoch-170-no-transpose: +# fmk: caffe +# input_number: 1 - hiai_human_seg: - fmk: caffe - input_number: 1 +# hiai_human_seg: +# fmk: caffe +# input_number: 1 - hiai_machine_vision_jfr_newmodel_2730_houduan_yolo: - fmk: caffe - input_number: 1 +# hiai_machine_vision_jfr_newmodel_2730_houduan_yolo: +# fmk: caffe +# input_number: 1 - hiai_machine_vision_mobileNet101_nosoftce_mobilenet_resnet: - fmk: caffe - input_number: 1 +# hiai_machine_vision_mobileNet101_nosoftce_mobilenet_resnet: +# fmk: caffe +# input_number: 1 - hiai_semantic_seg: - fmk: caffe - input_number: 1 +# hiai_semantic_seg: +# fmk: caffe +# input_number: 1 - identify_card_detect_tmp: - fmk: caffe - input_number: 1 +# identify_card_detect_tmp: +# fmk: caffe +# input_number: 1 - ml_2012_ocr_detection_caffe_tmp: - fmk: caffe - input_number: 1 +# ml_2012_ocr_detection_caffe_tmp: +# fmk: caffe +# input_number: 1 - ml_2012_ocr_rec_caffe: - fmk: caffe - input_number: 1 +# ml_2012_ocr_rec_caffe: +# fmk: caffe +# input_number: 1 - ml_ARengine23_bodypose: - fmk: caffe - input_number: 1 +# ml_ARengine23_bodypose: +# fmk: caffe +# input_number: 1 - ml_bank_detect_0312_tmp: - fmk: caffe - input_number: 1 +# ml_bank_detect_0312_tmp: +# fmk: caffe +# input_number: 1 - ml_bank_recog: - fmk: caffe - input_number: 1 +# ml_bank_recog: +# fmk: caffe +# input_number: 1 - ml_bodymask: - fmk: caffe - input_number: 1 +# ml_bodymask: +# fmk: caffe +# input_number: 1 - ml_face_age: - fmk: caffe - input_number: 1 +# ml_face_age: +# fmk: caffe +# input_number: 1 - ml_face_beard: - fmk: caffe - input_number: 1 +# ml_face_beard: +# fmk: caffe +# input_number: 1 - ml_face_div_parsing: - fmk: caffe - input_number: 1 +# ml_face_div_parsing: +# fmk: caffe +# input_number: 1 - ml_face_glasses: - fmk: caffe - input_number: 1 +# ml_face_glasses: +# fmk: caffe +# input_number: 1 - ml_face_isface: - fmk: caffe - input_number: 1 +# ml_face_isface: +# fmk: caffe +# input_number: 1 - ml_face_mnet: - fmk: caffe - input_number: 1 +# ml_face_mnet: +# fmk: caffe +# input_number: 1 - ml_face_pose: - fmk: caffe - input_number: 1 +# ml_face_pose: +# fmk: caffe +# input_number: 1 - ml_face_sex: - fmk: caffe - input_number: 1 +# ml_face_sex: +# fmk: caffe +# input_number: 1 - ml_face_tracking: - fmk: caffe - input_number: 1 +# ml_face_tracking: +# fmk: caffe +# input_number: 1 - ml_hand_3d_detection: - fmk: caffe - input_number: 1 +# ml_hand_3d_detection: +# fmk: caffe +# input_number: 1 - ml_Hand_deploy: - fmk: caffe - input_number: 1 +# ml_Hand_deploy: +# fmk: caffe +# input_number: 1 - ml_hand_detection: - fmk: caffe - input_number: 1 +# ml_hand_detection: +# fmk: caffe +# input_number: 1 - ml_handpose: - fmk: caffe - input_number: 1 +# ml_handpose: +# fmk: caffe +# input_number: 1 - ml_hardware_liveness: - fmk: caffe - input_number: 1 +# ml_hardware_liveness: +# fmk: caffe +# input_number: 1 - ml_hardware_pose: - fmk: caffe - input_number: 1 +# ml_hardware_pose: +# fmk: caffe +# input_number: 1 - ml_Heatmap_depth_180240: - fmk: caffe - input_number: 2 +# ml_Heatmap_depth_180240: +# fmk: caffe +# input_number: 2 - ml_Heatmap_depth_240180: - fmk: caffe - input_number: 2 +# ml_Heatmap_depth_240180: +# fmk: caffe +# input_number: 2 - ml_lable_model_hebing_device: - fmk: caffe - input_number: 1 +# ml_lable_model_hebing_device: +# fmk: caffe +# input_number: 1 - ml_location_scene_division: - fmk: caffe - input_number: 1 +# ml_location_scene_division: +# fmk: caffe +# input_number: 1 - ml_ocr_bank_card_detection_inception_tmp: - fmk: caffe - input_number: 1 +# ml_ocr_bank_card_detection_inception_tmp: +# fmk: caffe +# input_number: 1 - ml_ocr_bank_card_recognition_fcny: - fmk: caffe - input_number: 1 +# ml_ocr_bank_card_recognition_fcny: +# fmk: caffe +# input_number: 1 - ml_ocr_detect_20200305: - fmk: caffe - input_number: 1 +# ml_ocr_detect_20200305: +# fmk: caffe +# input_number: 1 - ml_ocr_identify_card_detect_tmp: - fmk: caffe - input_number: 1 +# ml_ocr_identify_card_detect_tmp: +# fmk: caffe +# input_number: 1 - ml_ocr_identify_card_fcny: - fmk: caffe - input_number: 1 +# ml_ocr_identify_card_fcny: +# fmk: caffe +# input_number: 1 - ml_ocr_sfz_add_final_0325: - fmk: caffe - input_number: 1 +# ml_ocr_sfz_add_final_0325: +# fmk: caffe +# input_number: 1 - ml_ocr_sfz_detect_0325_tmp: - fmk: caffe - input_number: 1 +# ml_ocr_sfz_detect_0325_tmp: +# fmk: caffe +# input_number: 1 - ml_segmentation_atlanta_1: - fmk: caffe - input_number: 1 +# ml_segmentation_atlanta_1: +# fmk: caffe +# input_number: 1 - ml_segmentation_atlanta_10: - fmk: caffe - input_number: 1 +# ml_segmentation_atlanta_10: +# fmk: caffe +# input_number: 1 - ml_segmentation_matting: - fmk: caffe - input_number: 1 +# ml_segmentation_matting: +# fmk: caffe +# input_number: 1 - ml_tabel_recog: - fmk: caffe - input_number: 1 +# ml_tabel_recog: +# fmk: caffe +# input_number: 1 - ml_video_edit_detect_20211111: - fmk: caffe - input_number: 1 +# ml_video_edit_detect_20211111: +# fmk: caffe +# input_number: 1 - ml_video_edit_dynamic_effect_MTI_seg5c_v1: - fmk: caffe - input_number: 1 +# ml_video_edit_dynamic_effect_MTI_seg5c_v1: +# fmk: caffe +# input_number: 1 - ml_video_edit_hairSeg_have_imageProcessLayer_interpTo145_20210121: - fmk: caffe - input_number: 1 +# ml_video_edit_hairSeg_have_imageProcessLayer_interpTo145_20210121: +# fmk: caffe +# input_number: 1 - ml_video_edit_have_imageProcessLayer_interpTo145_20201015: - fmk: caffe - input_number: 1 +# ml_video_edit_have_imageProcessLayer_interpTo145_20201015: +# fmk: caffe +# input_number: 1 - ml_video_edit_img_segment: - fmk: caffe - input_number: 1 +# ml_video_edit_img_segment: +# fmk: caffe +# input_number: 1 - ml_video_edit_Mnet: - fmk: caffe - input_number: 1 +# ml_video_edit_Mnet: +# fmk: caffe +# input_number: 1 - ml_video_edit_MnetN367_extract_1010_pay: - fmk: caffe - input_number: 1 +# ml_video_edit_MnetN367_extract_1010_pay: +# fmk: caffe +# input_number: 1 - ml_video_edit_moon_mode_moon_seg: - fmk: caffe - input_number: 1 +# ml_video_edit_moon_mode_moon_seg: +# fmk: caffe +# input_number: 1 - ml_video_edit_moon_mode_MTI_9c_segmentation_v12: - fmk: caffe - input_number: 1 +# ml_video_edit_moon_mode_MTI_9c_segmentation_v12: +# fmk: caffe +# input_number: 1 - ml_video_edit_person_divison_video: - fmk: caffe - input_number: 2 +# ml_video_edit_person_divison_video: +# fmk: caffe +# input_number: 2 - ml_video_edit_reid: - fmk: caffe - input_number: 1 +# ml_video_edit_reid: +# fmk: caffe +# input_number: 1 - ml_video_edit_seg_320: - fmk: caffe - input_number: 1 +# ml_video_edit_seg_320: +# fmk: caffe +# input_number: 1 - ml_video_edit_v10_best_model_nomean_20200723: - fmk: caffe - input_number: 1 +# ml_video_edit_v10_best_model_nomean_20200723: +# fmk: caffe +# input_number: 1 - ml_video_edit_video_segment_gauss_adaptis_part1: - fmk: caffe - input_number: 1 +# ml_video_edit_video_segment_gauss_adaptis_part1: +# fmk: caffe +# input_number: 1 - mnet: - fmk: caffe - input_number: 1 +# mnet: +# fmk: caffe +# input_number: 1 - Mnet6_0312_extract_pay: - fmk: caffe - input_number: 1 +# Mnet6_0312_extract_pay: +# fmk: caffe +# input_number: 1 - model_hebing_3branch: - fmk: caffe - input_number: 1 +# model_hebing_3branch: +# fmk: caffe +# input_number: 1 - mtk_2012_ATLANTA_10class_20190614_v41: - fmk: caffe - input_number: 1 +# mtk_2012_ATLANTA_10class_20190614_v41: +# fmk: caffe +# input_number: 1 - mtk_detect_mbv1_640_480_nopostprocess_simplified: - fmk: caffe - input_number: 1 +# mtk_detect_mbv1_640_480_nopostprocess_simplified: +# fmk: caffe +# input_number: 1 - mtk_detect-deeper-halfdeeper-mbv1-lastearlySSD-shortcut-400-400_nopostprocess_simplified: - fmk: caffe - input_number: 1 +# mtk_detect-deeper-halfdeeper-mbv1-lastearlySSD-shortcut-400-400_nopostprocess_simplified: +# fmk: caffe +# input_number: 1 - mtk_detect-mbv1-shortcut-400-400_nopostprocess_simplified: - fmk: caffe - input_number: 1 +# mtk_detect-mbv1-shortcut-400-400_nopostprocess_simplified: +# fmk: caffe +# input_number: 1 - mtk_face_recognition_v1: - fmk: caffe - input_number: 1 +# mtk_face_recognition_v1: +# fmk: caffe +# input_number: 1 - plat_isface: - fmk: caffe - input_number: 1 +# plat_isface: +# fmk: caffe +# input_number: 1 - pose_3d: - fmk: caffe - input_number: 1 +# pose_3d: +# fmk: caffe +# input_number: 1 - PoseNet_dla_17_x512_tmp: - fmk: caffe - input_number: 1 +# PoseNet_dla_17_x512_tmp: +# fmk: caffe +# input_number: 1 - recognition: - fmk: caffe - input_number: 1 +# recognition: +# fmk: caffe +# input_number: 1 - retinaface: - fmk: caffe - input_number: 1 +# retinaface: +# fmk: caffe +# input_number: 1 - Sport_Health_Tech_pose_iter: - fmk: caffe - input_number: 1 +# Sport_Health_Tech_pose_iter: +# fmk: caffe +# input_number: 1 ---- +# --- -onnx: - adversarial_pruning: - fmk: onnx - input_number: 1 +# onnx: +# adversarial_pruning: +# fmk: onnx +# input_number: 1 - bloom_hongmo_detection_tmp: - fmk: onnx - input_number: 1 +# bloom_hongmo_detection_tmp: +# fmk: onnx +# input_number: 1 - candy-9: - fmk: onnx - input_number: 1 +# candy-9: +# fmk: onnx +# input_number: 1 - carbu_intelligent_cockpit_fasttext_best: - fmk: onnx - input_number: 1 +# carbu_intelligent_cockpit_fasttext_best: +# fmk: onnx +# input_number: 1 - densenet-9: - fmk: onnx - input_number: 1 +# densenet-9: +# fmk: onnx +# input_number: 1 - googlenet-9: - fmk: onnx - input_number: 1 +# googlenet-9: +# fmk: onnx +# input_number: 1 - gts_version-RFB-320_simplified: - fmk: onnx - input_number: 1 +# gts_version-RFB-320_simplified: +# fmk: onnx +# input_number: 1 - hdc_efficientnet_b3_1w_class: - fmk: onnx - input_number: 1 +# hdc_efficientnet_b3_1w_class: +# fmk: onnx +# input_number: 1 - hdc_Face_Emotion_MTI_Aesthetic: - fmk: onnx - input_number: 1 +# hdc_Face_Emotion_MTI_Aesthetic: +# fmk: onnx +# input_number: 1 - hdc_Image_Aesthetic_MTI_Aesthetic: - fmk: onnx - input_number: 1 +# hdc_Image_Aesthetic_MTI_Aesthetic: +# fmk: onnx +# input_number: 1 - hdc_mobilenet_1w_class: - fmk: onnx - input_number: 1 +# hdc_mobilenet_1w_class: +# fmk: onnx +# input_number: 1 - hdc_ocr_detect_tmp: - fmk: onnx - input_number: 1 +# hdc_ocr_detect_tmp: +# fmk: onnx +# input_number: 1 - hdc_resnet_1w_class: - fmk: onnx - input_number: 1 +# hdc_resnet_1w_class: +# fmk: onnx +# input_number: 1 - inception-v1-9: - fmk: onnx - input_number: 1 +# inception-v1-9: +# fmk: onnx +# input_number: 1 - inception-v2-9: - fmk: onnx - input_number: 1 +# inception-v2-9: +# fmk: onnx +# input_number: 1 - Ireland_face_detector: - fmk: onnx - input_number: 1 +# Ireland_face_detector: +# fmk: onnx +# input_number: 1 - Ireland_gaze_corrector: - fmk: onnx - input_number: 3 - acc_threshold: 1 +# Ireland_gaze_corrector: +# fmk: onnx +# input_number: 3 +# acc_threshold: 1 - Ireland_gaze_estimator_ng: - fmk: onnx - input_number: 1 +# Ireland_gaze_estimator_ng: +# fmk: onnx +# input_number: 1 - ml_2012_ocr_detection_tmp: - fmk: onnx - input_number: 1 +# ml_2012_ocr_detection_tmp: +# fmk: onnx +# input_number: 1 - ml_edu_kit_hand_detection: - fmk: onnx - input_number: 1 +# ml_edu_kit_hand_detection: +# fmk: onnx +# input_number: 1 - ml_edu_kit_hand_key_position: - fmk: onnx - input_number: 1 +# ml_edu_kit_hand_key_position: +# fmk: onnx +# input_number: 1 - ml_ei_facedetection: - fmk: onnx - input_number: 1 +# ml_ei_facedetection: +# fmk: onnx +# input_number: 1 - ml_face_3d: - fmk: onnx - input_number: 1 +# ml_face_3d: +# fmk: onnx +# input_number: 1 - ml_facedetector: - fmk: onnx - input_number: 1 +# ml_facedetector: +# fmk: onnx +# input_number: 1 - ml_location_lane_counter: - fmk: onnx - input_number: 1 +# ml_location_lane_counter: +# fmk: onnx +# input_number: 1 - ml_location_lane_counter0: - fmk: onnx - input_number: 1 +# ml_location_lane_counter0: +# fmk: onnx +# input_number: 1 - ml_motion_capture_nanodet_m_0.5x_people_0928_sim: - fmk: onnx - input_number: 1 +# ml_motion_capture_nanodet_m_0.5x_people_0928_sim: +# fmk: onnx +# input_number: 1 - ml_motion_capture_smpl_0916: - fmk: onnx - input_number: 3 +# ml_motion_capture_smpl_0916: +# fmk: onnx +# input_number: 3 - ml_motion_capture_spin_mobile_mv3_v3_57mm_sim: - fmk: onnx - input_number: 5 +# ml_motion_capture_spin_mobile_mv3_v3_57mm_sim: +# fmk: onnx +# input_number: 5 - ml_table_detection_fp32_tmp: - fmk: onnx - input_number: 1 +# ml_table_detection_fp32_tmp: +# fmk: onnx +# input_number: 1 - ml_video_edit_art_generate: - fmk: onnx - input_number: 1 +# ml_video_edit_art_generate: +# fmk: onnx +# input_number: 1 - ml_video_edit_art_generate_20210513: - fmk: onnx - input_number: 1 +# ml_video_edit_art_generate_20210513: +# fmk: onnx +# input_number: 1 - ml_video_edit_art_transfer_20210513: - fmk: onnx - input_number: 3 +# ml_video_edit_art_transfer_20210513: +# fmk: onnx +# input_number: 3 - ml_video_edit_dimming_tech_model_345000_color: - fmk: onnx - input_number: 2 +# ml_video_edit_dimming_tech_model_345000_color: +# fmk: onnx +# input_number: 2 - ml_video_edit_dimming_tech_model_studio_20: - fmk: onnx - input_number: 2 +# ml_video_edit_dimming_tech_model_studio_20: +# fmk: onnx +# input_number: 2 - ml_video_edit_enhance_update_tmp: - fmk: onnx - input_number: 1 +# ml_video_edit_enhance_update_tmp: +# fmk: onnx +# input_number: 1 - ml_video_edit_face_edit_face3d: - fmk: onnx - input_number: 1 +# ml_video_edit_face_edit_face3d: +# fmk: onnx +# input_number: 1 - ml_video_edit_face_edit_pix2pixHD_unet: - fmk: onnx - input_number: 1 +# ml_video_edit_face_edit_pix2pixHD_unet: +# fmk: onnx +# input_number: 1 - ml_video_edit_hair_dyeing_migrate_v2: - fmk: onnx - input_number: 4 +# ml_video_edit_hair_dyeing_migrate_v2: +# fmk: onnx +# input_number: 4 - ml_video_edit_hair_dyeing_migrate_v2_fix: - fmk: onnx - input_number: 4 +# ml_video_edit_hair_dyeing_migrate_v2_fix: +# fmk: onnx +# input_number: 4 - ml_video_edit_judge: - fmk: onnx - input_number: 1 +# ml_video_edit_judge: +# fmk: onnx +# input_number: 1 - ml_video_edit_makeup_mobilenetv203: - fmk: onnx - input_number: 1 +# ml_video_edit_makeup_mobilenetv203: +# fmk: onnx +# input_number: 1 - ml_video_edit_shot_selection_face_emotion: - fmk: onnx - input_number: 1 +# ml_video_edit_shot_selection_face_emotion: +# fmk: onnx +# input_number: 1 - ml_video_edit_shot_selection_yolox_nano_coco_reduced: - fmk: onnx - input_number: 1 +# ml_video_edit_shot_selection_yolox_nano_coco_reduced: +# fmk: onnx +# input_number: 1 - ml_video_edit_style_transfer_autoportrait: - fmk: onnx - input_number: 1 +# ml_video_edit_style_transfer_autoportrait: +# fmk: onnx +# input_number: 1 - ml_video_edit_style_transfer_candy: - fmk: onnx - input_number: 1 +# ml_video_edit_style_transfer_candy: +# fmk: onnx +# input_number: 1 - ml_video_edit_style_transfer_gongnongbing: - fmk: onnx - input_number: 1 +# ml_video_edit_style_transfer_gongnongbing: +# fmk: onnx +# input_number: 1 - ml_video_edit_style_transfer_starry: - fmk: onnx - input_number: 1 +# ml_video_edit_style_transfer_starry: +# fmk: onnx +# input_number: 1 - ml_video_edit_styleCode_part1: - fmk: onnx - input_number: 1 +# ml_video_edit_styleCode_part1: +# fmk: onnx +# input_number: 1 - ml_video_edit_styleCode_part2: - fmk: onnx - input_number: 9 +# ml_video_edit_styleCode_part2: +# fmk: onnx +# input_number: 9 - ml_video_edit_vignet: - fmk: onnx - input_number: 1 +# ml_video_edit_vignet: +# fmk: onnx +# input_number: 1 - mobilenetv2-7: - fmk: onnx - input_number: 1 +# mobilenetv2-7: +# fmk: onnx +# input_number: 1 - mosaic-9: - fmk: onnx - input_number: 1 +# mosaic-9: +# fmk: onnx +# input_number: 1 - mtk_detect_mbv1_640_480: - fmk: onnx - input_number: 1 +# mtk_detect_mbv1_640_480: +# fmk: onnx +# input_number: 1 - mtk_detect-deeper-halfdeeper-mbv1-lastearlySSD-shortcut-400-400_nopostprocess_simplified_onnx: - fmk: onnx - input_number: 1 +# mtk_detect-deeper-halfdeeper-mbv1-lastearlySSD-shortcut-400-400_nopostprocess_simplified_onnx: +# fmk: onnx +# input_number: 1 - mtk_detect-deeper-halfdeeper-mbv1-shortcut-400-400_nopostprocess_simplified_onnx: - fmk: onnx - input_number: 1 +# mtk_detect-deeper-halfdeeper-mbv1-shortcut-400-400_nopostprocess_simplified_onnx: +# fmk: onnx +# input_number: 1 - mtk_detect-mbv1-shortcut-400-400: - fmk: onnx - input_number: 1 +# mtk_detect-mbv1-shortcut-400-400: +# fmk: onnx +# input_number: 1 - mtk_detect-mbv1-shortcut-400-400_nopostprocess_simplified_onnx: - fmk: onnx - input_number: 1 +# mtk_detect-mbv1-shortcut-400-400_nopostprocess_simplified_onnx: +# fmk: onnx +# input_number: 1 - mtk_detect-mbv2-shortcut-400-400: - fmk: onnx - input_number: 1 +# mtk_detect-mbv2-shortcut-400-400: +# fmk: onnx +# input_number: 1 - mtk_detect-mbv2-shortcut-400-400-simplified: - fmk: onnx - input_number: 1 +# mtk_detect-mbv2-shortcut-400-400-simplified: +# fmk: onnx +# input_number: 1 - mtk_emotions-d2012-75: - fmk: onnx - input_number: 1 +# mtk_emotions-d2012-75: +# fmk: onnx +# input_number: 1 - mtk_face_features_v3: - fmk: onnx - input_number: 1 +# mtk_face_features_v3: +# fmk: onnx +# input_number: 1 - mtk_face_recognition_v2: - fmk: onnx - input_number: 1 +# mtk_face_recognition_v2: +# fmk: onnx +# input_number: 1 - mtk_face_recognition_v3: - fmk: onnx - input_number: 1 +# mtk_face_recognition_v3: +# fmk: onnx +# input_number: 1 - pointilism-9: - fmk: onnx - input_number: 1 +# pointilism-9: +# fmk: onnx +# input_number: 1 - porseg_tmp: - fmk: onnx - input_number: 2 +# porseg_tmp: +# fmk: onnx +# input_number: 2 - Q888_CV_face_recognition_self: - fmk: onnx - input_number: 1 +# Q888_CV_face_recognition_self: +# fmk: onnx +# input_number: 1 - Q888_face_recognition: - fmk: onnx - input_number: 1 +# Q888_face_recognition: +# fmk: onnx +# input_number: 1 - Q888_iris_detect: - fmk: onnx - input_number: 1 +# Q888_iris_detect: +# fmk: onnx +# input_number: 1 - rain-princess-9: - fmk: onnx - input_number: 1 +# rain-princess-9: +# fmk: onnx +# input_number: 1 - residual_distill_bs_1: - fmk: onnx - input_number: 1 +# residual_distill_bs_1: +# fmk: onnx +# input_number: 1 - residual_distill_bs_32: - fmk: onnx - input_number: 1 +# residual_distill_bs_32: +# fmk: onnx +# input_number: 1 - residual_distill_cifar10_bs_1: - fmk: onnx - input_number: 1 +# residual_distill_cifar10_bs_1: +# fmk: onnx +# input_number: 1 - residual_distill_cifar10_bs_32: - fmk: onnx - input_number: 1 +# residual_distill_cifar10_bs_32: +# fmk: onnx +# input_number: 1 - residual_distill_res34_cifar10_bs_1_update: - fmk: onnx - input_number: 1 +# residual_distill_res34_cifar10_bs_1_update: +# fmk: onnx +# input_number: 1 - residual_distill_res50_cifar10_bs_1_update: - fmk: onnx - input_number: 1 +# residual_distill_res50_cifar10_bs_1_update: +# fmk: onnx +# input_number: 1 - rpnt_pdr_conv2d_16_fixed_last: - fmk: onnx - input_number: 1 +# rpnt_pdr_conv2d_16_fixed_last: +# fmk: onnx +# input_number: 1 - shufflenet-9: - fmk: onnx - input_number: 1 +# shufflenet-9: +# fmk: onnx +# input_number: 1 - shufflenet-v2-10: - fmk: onnx - input_number: 1 +# shufflenet-v2-10: +# fmk: onnx +# input_number: 1 - simple_IPS_model_4D_input: - fmk: onnx - input_number: 1 +# simple_IPS_model_4D_input: +# fmk: onnx +# input_number: 1 - squeezenet1.0-9: - fmk: onnx - input_number: 1 +# squeezenet1.0-9: +# fmk: onnx +# input_number: 1 - squeezenet1.1-7: - fmk: onnx - input_number: 1 +# squeezenet1.1-7: +# fmk: onnx +# input_number: 1 - ssd-10: - fmk: onnx - input_number: 1 +# ssd-10: +# fmk: onnx +# input_number: 1 - udnie-9: - fmk: onnx - input_number: 1 +# udnie-9: +# fmk: onnx +# input_number: 1 - yolov5s: - fmk: onnx - input_number: 1 +# yolov5s: +# fmk: onnx +# input_number: 1 ---- - -tflite: - albert_lite_base_squadv1_1: - fmk: tflite - input_number: 3 +# --- + +# tflite: +# albert_lite_base_squadv1_1: +# fmk: tflite +# input_number: 3 - gts_detect_5k_tf115: - fmk: tflite - input_number: 1 +# gts_detect_5k_tf115: +# fmk: tflite +# input_number: 1 - hdc_tb_cn_neg: - fmk: tflite - input_number: 3 - acc_threshold: 0.5 +# hdc_tb_cn_neg: +# fmk: tflite +# input_number: 3 +# acc_threshold: 0.5 - hiai_asr_last_e1_cpu_fast_wavenet_batch1_frame1_one_cache_fp32: - fmk: tflite - input_number: 2 +# hiai_asr_last_e1_cpu_fast_wavenet_batch1_frame1_one_cache_fp32: +# fmk: tflite +# input_number: 2 - hiai_cv_labelDetectorModel_v3: - fmk: tflite - input_number: 2 +# hiai_cv_labelDetectorModel_v3: +# fmk: tflite +# input_number: 2 - hiai_cv_saliencyDetectorModel: - fmk: tflite - input_number: 1 +# hiai_cv_saliencyDetectorModel: +# fmk: tflite +# input_number: 1 - hiai_dress_detect: - fmk: tflite - input_number: 1 +# hiai_dress_detect: +# fmk: tflite +# input_number: 1 - hiai_face_model_npu: - fmk: tflite - input_number: 1 +# hiai_face_model_npu: +# fmk: tflite +# input_number: 1 - hiai_frozen_inference_graph: - fmk: tflite - input_number: 1 +# hiai_frozen_inference_graph: +# fmk: tflite +# input_number: 1 - hiai_label_and_video: - fmk: tflite - input_number: 1 +# hiai_label_and_video: +# fmk: tflite +# input_number: 1 - hiai_lm_inference_graph: - fmk: tflite - input_number: 1 +# hiai_lm_inference_graph: +# fmk: tflite +# input_number: 1 - hiai_model_0909_kd_rot_ps_softmax: - fmk: tflite - input_number: 1 +# hiai_model_0909_kd_rot_ps_softmax: +# fmk: tflite +# input_number: 1 - hiai_object_detect_814: - fmk: tflite - input_number: 1 +# hiai_object_detect_814: +# fmk: tflite +# input_number: 1 - hiai_ssd_mobilenetv2_object: - fmk: tflite - input_number: 1 +# hiai_ssd_mobilenetv2_object: +# fmk: tflite +# input_number: 1 - hiai_vad: - fmk: tflite - input_number: 2 +# hiai_vad: +# fmk: tflite +# input_number: 2 - ide_label_retrained: - fmk: tflite - input_number: 1 +# ide_label_retrained: +# fmk: tflite +# input_number: 1 - lite-model_albert_lite_base_squadv1_metadata_1: - fmk: tflite - input_number: 3 +# lite-model_albert_lite_base_squadv1_metadata_1: +# fmk: tflite +# input_number: 3 - lite-model_mobilebert_1_metadata_1: - fmk: tflite - input_number: 3 +# lite-model_mobilebert_1_metadata_1: +# fmk: tflite +# input_number: 3 - ml_ei_headpose_pb2tflite: - fmk: tflite - input_number: 3 - benchmark_shapes: 1,64,64,3:16:16 - - ml_headpose_pb2tflite: - fmk: tflite - input_number: 3 - benchmark_shapes: 1,64,64,3:16:16 - - ml_location: - fmk: tflite - input_number: 1 - - ml_ocr_jk: - fmk: tflite - input_number: 1 - - ml_ocr_jk_pb2tflite: - fmk: tflite - input_number: 1 - - ml_tacotron_decoder_step_stf: - fmk: tflite - input_number: 9 - benchmark_shapes: 1,80:1,256:1,1024:1,1024:1,1024:1,1024:1,8:1,1,256:1 - - ml_video_edit_img_segment_adaptise_pb2tflite: - fmk: tflite - input_number: 2 - - ml_video_edit_video_segment_gauss_adaptis_part2_pb2tflite: - fmk: tflite - input_number: 2 - - mobilebert_1_default_1: - fmk: tflite - input_number: 3 - - mtk_276landmark_0913: - fmk: tflite - input_number: 1 - - mtk_age_gender: - fmk: tflite - input_number: 1 - - Q_language_model_hrmini_Q4_b4_17w: - fmk: tflite - input_number: 1 - - Q888_lapa158_unet_0924: - fmk: tflite - input_number: 1 - - resnet_v2_101_299: - fmk: tflite - input_number: 1 - - scan_hms_detect: - fmk: tf - input_number: 1 - - unet_mbv2_05_104pts: - fmk: tflite - input_number: 1 +# ml_ei_headpose_pb2tflite: +# fmk: tflite +# input_number: 3 +# benchmark_shapes: 1,64,64,3:16:16 + +# ml_headpose_pb2tflite: +# fmk: tflite +# input_number: 3 +# benchmark_shapes: 1,64,64,3:16:16 + +# ml_location: +# fmk: tflite +# input_number: 1 + +# ml_ocr_jk: +# fmk: tflite +# input_number: 1 + +# ml_ocr_jk_pb2tflite: +# fmk: tflite +# input_number: 1 + +# ml_tacotron_decoder_step_stf: +# fmk: tflite +# input_number: 9 +# benchmark_shapes: 1,80:1,256:1,1024:1,1024:1,1024:1,1024:1,8:1,1,256:1 + +# ml_video_edit_img_segment_adaptise_pb2tflite: +# fmk: tflite +# input_number: 2 + +# ml_video_edit_video_segment_gauss_adaptis_part2_pb2tflite: +# fmk: tflite +# input_number: 2 + +# mobilebert_1_default_1: +# fmk: tflite +# input_number: 3 + +# mtk_276landmark_0913: +# fmk: tflite +# input_number: 1 + +# mtk_age_gender: +# fmk: tflite +# input_number: 1 + +# Q_language_model_hrmini_Q4_b4_17w: +# fmk: tflite +# input_number: 1 + +# Q888_lapa158_unet_0924: +# fmk: tflite +# input_number: 1 + +# resnet_v2_101_299: +# fmk: tflite +# input_number: 1 + +# scan_hms_detect: +# fmk: tf +# input_number: 1 + +# unet_mbv2_05_104pts: +# fmk: tflite +# input_number: 1 diff --git a/mindspore-lite/test/st/scripts/run_benchmark_python.sh b/mindspore-lite/test/st/scripts/run_benchmark_python.sh index 8a4a6b1812e6705be6229595f159ecdd5b89b1a9..109eceb604a751a5dea4808eac99c3f467024289 100644 --- a/mindspore-lite/test/st/scripts/run_benchmark_python.sh +++ b/mindspore-lite/test/st/scripts/run_benchmark_python.sh @@ -6,7 +6,7 @@ function Run_python_ST() { whl_path=$2 model_path=$3 in_data_path=$4 - model_hiai_path=$in_data_path + # model_hiai_path=$in_data_path cfg_file_list=$5 target=$6 suffix=".mindir" @@ -72,22 +72,22 @@ function Run_python_ST() { done echo "-----------------------------------------------------------------------------------------" - elapsed_time=$(date +%s.%N) - python test_inference_cloud_nocofig.py ${model_hiai_path} ${target} >> ${run_python_log} - Run_python_st_status=$? - elapsed_time=$(printf %.2f "$(echo "$(date +%s.%N) - $elapsed_time" | bc)") - if [[ ${Run_python_st_status} != 0 ]];then - echo "RunPython test_inference_cloud_nocofig ${elapsed_time} failed" >> ${result_python_log} - cat ${run_python_log} - else - echo "RunPython test_inference_cloud_nocofig ${elapsed_time} pass" >> ${result_python_log} - fi - echo "-----------------------------------------------------------------------------------------" - Print_Benchmark_Result ${result_python_log} - echo "-----------------------------------------------------------------------------------------" + # elapsed_time=$(date +%s.%N) + # python test_inference_cloud_nocofig.py ${model_hiai_path} ${target} >> ${run_python_log} + # Run_python_st_status=$? + # elapsed_time=$(printf %.2f "$(echo "$(date +%s.%N) - $elapsed_time" | bc)") + # if [[ ${Run_python_st_status} != 0 ]];then + # echo "RunPython test_inference_cloud_nocofig ${elapsed_time} failed" >> ${result_python_log} + # cat ${run_python_log} + # else + # echo "RunPython test_inference_cloud_nocofig ${elapsed_time} pass" >> ${result_python_log} + # fi + # echo "-----------------------------------------------------------------------------------------" + # Print_Benchmark_Result ${result_python_log} + # echo "-----------------------------------------------------------------------------------------" - if [[ ${Run_python_st_status} != 0 ]];then - echo "Run_python_st_status failed" - exit 1 - fi + # if [[ ${Run_python_st_status} != 0 ]];then + # echo "Run_python_st_status failed" + # exit 1 + # fi } diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/convolution_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/convolution_fp32_tests.cc index 934f7d7bdca73b1f6b8c6a8718c8245b74757778..ab023ea49a4739bcf0144439316065cf7c216d24 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/convolution_fp32_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/convolution_fp32_tests.cc @@ -92,9 +92,7 @@ void InitConvTensor(std::vector *inputs, std::vector(malloc(sizeof(ConvParameter))); - ASSERT_NE(conv_param, nullptr); - memset(conv_param, 0, sizeof(ConvParameter)); + auto conv_param = new ConvParameter(); conv_param->op_parameter_.type_ = PrimType_Conv2DFusion; InitConvParam(conv_param); @@ -137,6 +135,5 @@ TEST_F(TestConvolutionFp32, conv1) { delete out_t; } delete kernel; - delete ctx; } } // namespace mindspore diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/nllloss_fp32_test.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/nllloss_fp32_test.cc index 3c3b47c8c3e16ed8941235bcbbebd348e893b2af..0d4b1c4428f2509c2c4ec9731aa9c58a0f2d0a00 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/nllloss_fp32_test.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/nllloss_fp32_test.cc @@ -81,9 +81,7 @@ TEST_F(TestNLLLossFp32, ReductionNone) { auto *ctx = new lite::InnerContext; ctx->thread_num_ = 1; ASSERT_EQ(lite::RET_OK, ctx->Init()); - NLLLossParameter *param = reinterpret_cast(malloc(sizeof(NLLLossParameter))); - ASSERT_NE(param, nullptr); - memset(param, 0, sizeof(NLLLossParameter)); + auto *param = new NLLLossParameter; param->op_parameter_.thread_num_ = ctx->thread_num_; param->op_parameter_.type_ = schema::PrimitiveType_NLLLoss; param->reduction_type_ = Reduction_None; @@ -111,9 +109,7 @@ TEST_F(TestNLLLossFp32, ReductionSum) { auto *ctx = new lite::InnerContext; ctx->thread_num_ = 1; ASSERT_EQ(lite::RET_OK, ctx->Init()); - NLLLossParameter *param = reinterpret_cast(malloc(sizeof(NLLLossParameter))); - ASSERT_NE(param, nullptr); - memset(param, 0, sizeof(NLLLossParameter)); + auto *param = new NLLLossParameter; param->op_parameter_.thread_num_ = ctx->thread_num_; param->op_parameter_.type_ = schema::PrimitiveType_NLLLoss; param->reduction_type_ = Reduction_Sum; @@ -141,9 +137,7 @@ TEST_F(TestNLLLossFp32, ReductionMean) { auto *ctx = new lite::InnerContext; ctx->thread_num_ = 1; ASSERT_EQ(lite::RET_OK, ctx->Init()); - NLLLossParameter *param = reinterpret_cast(malloc(sizeof(NLLLossParameter))); - ASSERT_NE(param, nullptr); - memset(param, 0, sizeof(NLLLossParameter)); + auto *param = new NLLLossParameter; param->op_parameter_.thread_num_ = ctx->thread_num_; param->op_parameter_.type_ = schema::PrimitiveType_NLLLoss; param->reduction_type_ = Reduction_Mean; diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/power_fp32_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/power_fp32_tests.cc index 8d936215781cea184b62bbf75d6e6ef44be40c0c..44070ad5ba70c49c30a5ed22f56fa1aee3a8b194 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/power_fp32_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/fp32/power_fp32_tests.cc @@ -40,9 +40,7 @@ TEST_F(TestPowerFp32, Simple) { output.MallocData(); std::vector outputs = {&output}; - PowParameter *param = reinterpret_cast(malloc(sizeof(PowParameter))); - ASSERT_NE(param, nullptr); - memset(param, 0, sizeof(PowParameter)); + auto param = new PowParameter(); param->scale_ = 1; param->shift_ = 0; param->op_parameter_.type_ = schema::PrimitiveType_PowFusion; @@ -79,9 +77,7 @@ TEST_F(TestPowerFp32, Broadcast) { output.MallocData(); std::vector outputs = {&output}; - PowParameter *param = reinterpret_cast(malloc(sizeof(PowParameter))); - ASSERT_NE(param, nullptr); - memset(param, 0, sizeof(PowParameter)); + auto param = new PowParameter(); param->power_ = 2; param->scale_ = 1; param->shift_ = 0; diff --git a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/pad_int8_tests.cc b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/pad_int8_tests.cc index d8cbc4bbc35d35fba933f173ea73914cf639097a..c5c17a6afebc1d2954dd5ff4c28cc8df78fe0f6b 100644 --- a/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/pad_int8_tests.cc +++ b/mindspore-lite/test/ut/src/runtime/kernel/arm/int8/pad_int8_tests.cc @@ -40,7 +40,6 @@ int PadInt8TestInit1(std::vector *inputs_, std::vector *outp LiteQuantParam *in_quant_arg = new LiteQuantParam(); in_quant_arg->zeroPoint = 10, in_quant_arg->scale = 0.31228156; in_t->AddQuantParam(*in_quant_arg); - delete in_quant_arg; inputs_->push_back(in_t); Tensor *in_t_pad = new Tensor(kNumberTypeInt8, {1}, mindspore::NHWC, lite::Category::VAR); @@ -52,7 +51,6 @@ int PadInt8TestInit1(std::vector *inputs_, std::vector *outp LiteQuantParam *out_quant_arg = new LiteQuantParam(); out_quant_arg->zeroPoint = 10, out_quant_arg->scale = 0.31228156; out_t->AddQuantParam(*out_quant_arg); - delete out_quant_arg; outputs_->push_back(out_t); *correct = reinterpret_cast(malloc(out_t->ElementsNum() * sizeof(int8_t))); @@ -70,9 +68,7 @@ int PadInt8TestInit1(std::vector *inputs_, std::vector *outp TEST_F(TestPadInt8, PadInt8Test1) { std::vector inputs_; std::vector outputs_; - PadParameter *pad_param = reinterpret_cast(malloc(sizeof(PadParameter))); - ASSERT_NE(pad_param, nullptr); - memset(pad_param, 0, sizeof(PadParameter)); + auto pad_param = new PadParameter(); lite::InnerContext *ctx = new lite::InnerContext; ASSERT_EQ(lite::RET_OK, ctx->Init()); int8_t *correct; @@ -84,14 +80,8 @@ TEST_F(TestPadInt8, PadInt8Test1) { int8_t *output_data = reinterpret_cast(outputs_[0]->MutableData()); ASSERT_EQ(0, CompareOutputData(output_data, correct, total_size, 0)); - for (auto &in_t : inputs_) { - delete in_t; - } - for (auto &out_t : outputs_) { - delete out_t; - } + delete pad; - delete ctx; free(correct); } @@ -104,7 +94,6 @@ int PadInt8TestInit2(std::vector *inputs_, std::vector *outp LiteQuantParam *in_quant_arg = new LiteQuantParam(); in_quant_arg->zeroPoint = 10, in_quant_arg->scale = 0.31228156; in_t->AddQuantParam(*in_quant_arg); - delete in_quant_arg; inputs_->push_back(in_t); Tensor *in_t_pad = new Tensor(kNumberTypeInt8, {1}, mindspore::NHWC, lite::Category::VAR); @@ -116,7 +105,6 @@ int PadInt8TestInit2(std::vector *inputs_, std::vector *outp LiteQuantParam *out_quant_arg = new LiteQuantParam(); out_quant_arg->zeroPoint = 10, out_quant_arg->scale = 0.31228156; out_t->AddQuantParam(*out_quant_arg); - delete out_quant_arg; outputs_->push_back(out_t); *correct = reinterpret_cast(malloc(out_t->ElementsNum() * sizeof(int8_t))); @@ -136,9 +124,7 @@ int PadInt8TestInit2(std::vector *inputs_, std::vector *outp TEST_F(TestPadInt8, PadInt8Test2) { std::vector inputs_; std::vector outputs_; - PadParameter *pad_param = reinterpret_cast(malloc(sizeof(PadParameter))); - ASSERT_NE(pad_param, nullptr); - memset(pad_param, 0, sizeof(PadParameter)); + auto pad_param = new PadParameter(); lite::InnerContext *ctx = new lite::InnerContext; ASSERT_EQ(lite::RET_OK, ctx->Init()); int8_t *correct; @@ -151,14 +137,7 @@ TEST_F(TestPadInt8, PadInt8Test2) { int8_t *output_data = reinterpret_cast(outputs_[0]->MutableData()); ASSERT_EQ(0, CompareOutputData(output_data, correct, total_size, 0)); - for (auto &in_t : inputs_) { - delete in_t; - } - for (auto &out_t : outputs_) { - delete out_t; - } delete pad; - delete ctx; free(correct); } @@ -171,7 +150,6 @@ int PadInt8TestInit4(std::vector *inputs_, std::vector *outp LiteQuantParam *in_quant_arg = new LiteQuantParam(); in_quant_arg->zeroPoint = 10, in_quant_arg->scale = 0.31228156; in_t->AddQuantParam(*in_quant_arg); - delete in_quant_arg; inputs_->push_back(in_t); Tensor *in_t_pad = new Tensor(kNumberTypeInt8, {1}, mindspore::NHWC, lite::Category::VAR); @@ -183,7 +161,6 @@ int PadInt8TestInit4(std::vector *inputs_, std::vector *outp LiteQuantParam *out_quant_arg = new LiteQuantParam(); out_quant_arg->zeroPoint = 10, out_quant_arg->scale = 0.31228156; out_t->AddQuantParam(*out_quant_arg); - delete out_quant_arg; outputs_->push_back(out_t); *correct = reinterpret_cast(malloc(out_t->ElementsNum() * sizeof(int8_t))); @@ -217,9 +194,7 @@ int PadInt8TestInit4(std::vector *inputs_, std::vector *outp TEST_F(TestPadInt8, PadInt8TestInit4) { std::vector inputs_; std::vector outputs_; - PadParameter *pad_param = reinterpret_cast(malloc(sizeof(PadParameter))); - ASSERT_NE(pad_param, nullptr); - memset(pad_param, 0, sizeof(PadParameter)); + auto pad_param = new PadParameter(); lite::InnerContext *ctx = new lite::InnerContext; ctx->thread_num_ = 2; ASSERT_EQ(lite::RET_OK, ctx->Init()); @@ -233,14 +208,7 @@ TEST_F(TestPadInt8, PadInt8TestInit4) { int8_t *output_data = reinterpret_cast(outputs_[0]->MutableData()); ASSERT_EQ(0, CompareOutputData(output_data, correct, total_size, 0)); - for (auto &in_t : inputs_) { - delete in_t; - } - for (auto &out_t : outputs_) { - delete out_t; - } delete pad; - delete ctx; free(correct); } } // namespace mindspore diff --git a/mindspore-lite/tools/benchmark/CMakeLists.txt b/mindspore-lite/tools/benchmark/CMakeLists.txt index d95e1d7a1834b10156ead40601c9065cf8def73b..a91ab428e219b62f40003edf8a257af38aaae476 100644 --- a/mindspore-lite/tools/benchmark/CMakeLists.txt +++ b/mindspore-lite/tools/benchmark/CMakeLists.txt @@ -57,10 +57,6 @@ else() endif() endif() -if(MSLITE_EXPORT_COMPUTE_IR) - set(BENCHMARK_LINK_LIB ${BENCHMARK_LINK_LIB} mindspore_lite_drawer) -endif() - include_directories(${OPS_DIR}/kernel/cpu) set(COMMON_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../common/flag_parser.cc @@ -69,13 +65,13 @@ set(COMMON_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/file_utils.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/utils.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/config_file.cc - ${OPS_DIR}/kernel/cpu/nnacl/nnacl_common.c + ${CMAKE_CURRENT_SOURCE_DIR}/../../ops/kernel/cpu/nnacl/nnacl_common.c ) -include_directories(${TOP_DIR}/mindspore-lite) -include_directories(${TOP_DIR}/mindspore-lite/mindspore/mindspore/core/include) -include_directories(${TOP_DIR}/mindspore-lite/mindspore/mindspore/core/mindrt) -include_directories(${TOP_DIR}/mindspore-lite/mindspore/mindspore/core/mindrt/include) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../../lite) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../../core/include) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../../core/mindrt) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../../core/mindrt/include) if(MSLITE_GPU_BACKEND STREQUAL opencl) if(ANDROID_NDK_TOOLCHAIN_INCLUDED) @@ -99,9 +95,11 @@ set(BENCHMARK_SRC ${CMAKE_CURRENT_SOURCE_DIR}/benchmark_base.cc ${CMAKE_CURRENT_SOURCE_DIR}/benchmark_unified_api.cc) -set_property(SOURCE ${BENCHMARK_SRC} PROPERTY COMPILE_DEFINITIONS - LOG_HDR_FILE_REL_PATH="mindspore-lite/../mindspore/mindspore/core/include/utils/log_adapter.h" - SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) +if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE) + set_property(SOURCE ${BENCHMARK_SRC} PROPERTY COMPILE_DEFINITIONS + LOG_HDR_FILE_REL_PATH="mindspore-lite/../mindspore/mindspore/core/include/utils/log_adapter.h" + SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) +endif() add_executable(benchmark ${BENCHMARK_SRC} diff --git a/mindspore-lite/tools/benchmark/benchmark_unified_api.cc b/mindspore-lite/tools/benchmark/benchmark_unified_api.cc index 987e94ca5ef2f513a4b69d43c75d709696ca8317..856c4ff43b053d27e3d0ec8a3dbf576da62111d1 100644 --- a/mindspore-lite/tools/benchmark/benchmark_unified_api.cc +++ b/mindspore-lite/tools/benchmark/benchmark_unified_api.cc @@ -39,7 +39,7 @@ #include "include/mpi_sys.h" #include "include/mpi_vb.h" #endif -#ifdef PARALLEL_INFERENCE +#ifdef MSLITE_ENABLE_CLOUD_INFERENCE #include #include "src/common/config_file.h" #endif @@ -51,7 +51,7 @@ constexpr int kFrequencyDefault = 3; constexpr int kPercentageDivisor = 100; constexpr int kDumpInputsAndOutputs = 0; constexpr int kDumpOutputs = 2; -#ifdef PARALLEL_INFERENCE +#ifdef MSLITE_ENABLE_CLOUD_INFERENCE constexpr int kMaxRequestNum = 200; #endif namespace lite { @@ -219,7 +219,7 @@ int BenchmarkUnifiedApi::LoadInput() { } int BenchmarkUnifiedApi::GenerateInputData() { -#ifdef PARALLEL_INFERENCE +#ifdef MSLITE_ENABLE_CLOUD_INFERENCE if (flags_->enable_parallel_predict_) { std::vector inputs; for (size_t i = 0; i < ms_inputs_for_api_.size(); i++) { @@ -306,7 +306,7 @@ void BenchmarkUnifiedApi::UpdateConfigInfo() { } int BenchmarkUnifiedApi::ReadInputFile() { -#ifdef PARALLEL_INFERENCE +#ifdef MSLITE_ENABLE_CLOUD_INFERENCE if (flags_->enable_parallel_predict_) { std::vector inputs; for (size_t i = 0; i < ms_inputs_for_api_.size(); i++) { @@ -377,7 +377,7 @@ int BenchmarkUnifiedApi::ReadInputFile() { } int BenchmarkUnifiedApi::GetDataTypeByTensorName(const std::string &tensor_name) { -#ifdef PARALLEL_INFERENCE +#ifdef MSLITE_ENABLE_CLOUD_INFERENCE for (auto tensor : ms_outputs_for_api_) { auto name = tensor.Name(); if (name == tensor_name) { @@ -532,7 +532,7 @@ int BenchmarkUnifiedApi::InitMSContext(const std::shared_ptr return RET_OK; } -#ifdef PARALLEL_INFERENCE +#ifdef MSLITE_ENABLE_CLOUD_INFERENCE int BenchmarkUnifiedApi::CompareOutputForModelPool(std::vector *outputs) { if (outputs->empty()) { MS_LOG(ERROR) << "outputs is empty."; @@ -1069,7 +1069,7 @@ int BenchmarkUnifiedApi::PrintInputData() { } return RET_OK; } -#ifdef PARALLEL_INFERENCE +#ifdef MSLITE_ENABLE_CLOUD_INFERENCE void BenchmarkUnifiedApi::ModelParallelRunnerWarmUp(int index) { auto in = model_runner_.GetInputs(); for (size_t i = 0; i < in.size(); i++) { @@ -1345,7 +1345,7 @@ std::vector> BenchmarkUnifiedApi::ParseGraphInputShapeMap(c return resize_dims; } -#ifdef PARALLEL_INFERENCE +#ifdef MSLITE_ENABLE_CLOUD_INFERENCE int BenchmarkUnifiedApi::RunParallelBenchmark(std::shared_ptr context) { if (flags_->resize_dims_.empty() && flags_->graph_input_shape_map_.empty()) { MS_LOG(ERROR) << "model input shapes should be provided when using parallel predict, please specify --inputShape"; @@ -1408,7 +1408,7 @@ int BenchmarkUnifiedApi::RunBenchmark() { } UpdateConfigInfo(); -#ifdef PARALLEL_INFERENCE +#ifdef MSLITE_ENABLE_CLOUD_INFERENCE if (flags_->enable_parallel_predict_) { MS_CHECK_FALSE_MSG(RunParallelBenchmark(context) != RET_OK, RET_ERROR, "run model pool failed."); return RET_OK; @@ -1837,7 +1837,7 @@ int BenchmarkUnifiedApi::InitDumpTensorDataCallbackParameter() { } BenchmarkUnifiedApi::~BenchmarkUnifiedApi() { -#ifdef PARALLEL_INFERENCE +#ifdef MSLITE_ENABLE_CLOUD_INFERENCE if (!flags_->enable_parallel_predict_) { return; } diff --git a/mindspore-lite/tools/benchmark/benchmark_unified_api.h b/mindspore-lite/tools/benchmark/benchmark_unified_api.h index 3de4fb64f8e107d3676afbf6a1dd76a2f8fa05b0..18b8229ce8ddd2268dd1f4555396fb7569a6217c 100644 --- a/mindspore-lite/tools/benchmark/benchmark_unified_api.h +++ b/mindspore-lite/tools/benchmark/benchmark_unified_api.h @@ -41,7 +41,7 @@ #include "include/api/model.h" #include "include/api/context.h" #include "tools/common/opengl_util.h" -#ifdef PARALLEL_INFERENCE +#ifdef MSLITE_ENABLE_CLOUD_INFERENCE #include "include/api/model_parallel_runner.h" #endif @@ -97,7 +97,7 @@ class MS_API BenchmarkUnifiedApi : public BenchmarkBase { int InitPrintTensorDataCallbackParameter() override; int PrintInputData(); -#ifdef PARALLEL_INFERENCE +#ifdef MSLITE_ENABLE_CLOUD_INFERENCE int RunParallelBenchmark(std::shared_ptr context); int CompareOutputForModelPool(std::vector *outputs); void ModelParallelRunnerWarmUp(int index); @@ -137,7 +137,7 @@ class MS_API BenchmarkUnifiedApi : public BenchmarkBase { MSKernelCallBack ms_before_call_back_ = nullptr; MSKernelCallBack ms_after_call_back_ = nullptr; -#ifdef PARALLEL_INFERENCE +#ifdef MSLITE_ENABLE_CLOUD_INFERENCE std::vector> resize_dims_; std::vector> all_inputs_data_; std::vector> all_outputs_; diff --git a/mindspore-lite/tools/benchmark_train/CMakeLists.txt b/mindspore-lite/tools/benchmark_train/CMakeLists.txt index 347a0cc5e2e4b152f94a754054e2a95dab660f80..0d2c634b57b361a0d4be4abdcfe599f1b69b1e9b 100644 --- a/mindspore-lite/tools/benchmark_train/CMakeLists.txt +++ b/mindspore-lite/tools/benchmark_train/CMakeLists.txt @@ -5,7 +5,6 @@ set(COMMON_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/utils.cc ) - set(TEST_SRC ${CMAKE_CURRENT_SOURCE_DIR}/main.cc ${CMAKE_CURRENT_SOURCE_DIR}/net_train.cc @@ -41,7 +40,3 @@ else() target_link_libraries(benchmark_train mindspore-lite mindspore-lite-train pthread securec) endif() endif() - -if(MSLITE_EXPORT_COMPUTE_IR) - target_link_libraries(benchmark_train mindspore_lite_drawer) -endif() diff --git a/mindspore-lite/tools/converter/CMakeLists.txt b/mindspore-lite/tools/converter/CMakeLists.txt index b391f97f279b2edffb0f556d2a04506895e568ee..dbde737c85f49507f1096110ae9803a0e1f8767f 100644 --- a/mindspore-lite/tools/converter/CMakeLists.txt +++ b/mindspore-lite/tools/converter/CMakeLists.txt @@ -12,9 +12,7 @@ if(ENABLE_GPU) add_compile_definitions(ENABLE_GPU) endif() -include(${LITE_DIR}/cmake/ccsrc_converter.cmake) - -include_directories(${TOP_DIR}/mindspore/mindspore/ops/kernel/cpu) +include_directories(${NNACL_DIR}/..) file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/ops/*.cc @@ -92,7 +90,6 @@ add_subdirectory(${OPS_DIR} mindspore_ops) if(MSLITE_ENABLE_ACL) set(MODE_ASCEND_ACL ON) - #include(${TOP_DIR}/cmake/dependency_graphengine.cmake) add_subdirectory(adapter/acl) link_directories(${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) endif() @@ -108,7 +105,6 @@ if(MSLITE_ENABLE_RUNTIME_PASS) endif() file(GLOB FORMAT_PASS_SRCS ${SRC_DIR}/litert/pass/format_pass/*.cc) - set(LITE_SRC ${API_SRC} ${CXX_API_SRCS} ${FORMAT_PASS_SRCS} @@ -155,14 +151,11 @@ set(LITE_SRC ${API_SRC} ${SRC_DIR}/litert/weight_decoder.cc ${SRC_DIR}/litert/pack_weight_manager.cc ${SRC_DIR}/litert/huffman_decode.cc - ${SRC_DIR}/extendrt/delegate/tensorrt/distribution/distribution_base.cc - ${SRC_DIR}/extendrt/delegate/plugin/tensorrt_executor_plugin.cc ${SRC_DIR}/infer/primitive_type.cc ${LITE_DIR}/src/extendrt/mock/lite_runtime/populate/base_operator_populate_register.cc ${SRC_DIR}/control_flow/control_flow_scheduler.cc ${SRC_DIR}/control_flow/control_subgraph_creator.cc - ${SRC_DIR}/litert/kernel/ascend/plugin/ascend_kernel_plugin.cc - ${SRC_DIR}/extendrt/kernel/ascend/plugin/ascend_allocator_plugin.cc + ${SRC_DIR}/extendrt/delegate/ascend_acl/ascend_allocator_plugin.cc ) if(MSLITE_ENABLE_CUSTOM_KERNEL) @@ -170,13 +163,6 @@ if(MSLITE_ENABLE_CUSTOM_KERNEL) set(LITE_SRC ${LITE_SRC} ${KERNEL_REG_SRC}) endif() -if(NOT ANDROID_NDK_TOOLCHAIN_INCLUDED) - set(LITE_SRC - ${LITE_SRC} - ${SRC_DIR}/litert/kernel/ascend/plugin/ascend_kernel_plugin.cc - ) -endif() - if(MSLITE_ENABLE_MODEL_PRE_INFERENCE) set(LITE_SRC ${LITE_SRC} @@ -207,7 +193,9 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE) set(MINDIR_KERNEL_SRC ${MINDIR_KERNEL_SRC} - ${SRC_DIR}/extendrt/mindir_loader/mindir_model/inner_kernel.cc) + ${SRC_DIR}/extendrt/mindir_loader/mindir_model/inner_kernel.cc + ${SRC_DIR}/extendrt/utils/tensor_utils.cc + ) endif() set(LITE_SRC @@ -309,16 +297,6 @@ set(TFLITE_FBS_FILES ms_build_flatbuffers_lite(TFLITE_FBS_FILES ${CMAKE_CURRENT_SOURCE_DIR}/parser/tflite/ tflite_fbs_src ${CMAKE_BINARY_DIR}/schema "inner") -set(LITE_SRC - ${LITE_SRC} - ${CORE_DIR}/abstract/abstract_value.cc - ${CORE_DIR}/ir/anf.cc - ${CORE_DIR}/base/base.cc - ${OPS_DIR}/op_def/auto_generate/gen_lite_ops.cc - ${OPS_DIR}/kernel/common/device_address.cc - ${OPS_DIR}/kernel/common/device_type.cc - ) - set_property(SOURCE ${CONVERTER_SRC} PROPERTY COMPILE_DEFINITIONS LOG_HDR_FILE_REL_PATH="mindspore-lite/../mindspore/mindspore/core/include/utils/log_adapter.h" SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) @@ -328,7 +306,6 @@ set_property(SOURCE ${LITE_SRC} PROPERTY COMPILE_DEFINITIONS add_library(converter_runtime_mid OBJECT ${LITE_SRC}) add_dependencies(converter_runtime_mid fbs_src fbs_inner_src) -add_dependencies(converter_runtime_mid ccsrc_src_mid) target_compile_options(converter_runtime_mid PRIVATE "-Wno-stringop-overflow") add_library(mindspore_converter SHARED $) @@ -358,7 +335,6 @@ if(MSLITE_ENABLE_FP16) endif() target_link_libraries(mindspore_converter - ccsrc_src_mid converter_src_mid cpu_ops_mid nnacl_mid @@ -394,10 +370,6 @@ if(NOT MSLITE_SIMPLEST_CLOUD_INFERENCE) target_link_libraries(mindspore_converter decomposer_mid) endif() -if(MSLITE_EXPORT_COMPUTE_IR) - target_link_libraries(mindspore_converter mindspore_lite_drawer) -endif() - if(SUPPORT_TRAIN) target_link_libraries(mindspore_converter train_cpu_kernel_mid) endif() @@ -419,4 +391,33 @@ if(NOT WIN32) target_link_libraries(mindspore_converter dl) endif() +file(STRINGS "${TOP_DIR}/version.txt" VERSION) +add_definitions(-DVERSION=\"${VERSION}\") +set(CCSRC_SRC_NEW ${CCSRC_DIR}/backend/common/optimizer/graph_optimizer.cc + ${CCSRC_DIR}/backend/common/optimizer/pattern_engine.cc + ${CCSRC_DIR}/utils/convert_utils.cc + ${CCSRC_DIR}/common/debug/mindir_exporter.cc + ${CCSRC_DIR}/backend/common/optimizer/visitor.cc + ${CCSRC_DIR}/common/debug/common.cc + ${CCSRC_DIR}/utils/compile_cache_context.cc + ) +if(NOT WIN32) + set(CCSRC_SRC_NEW ${CCSRC_SRC_NEW} + ${CCSRC_DIR}/utils/anfalgo.cc + ${CCSRC_DIR}/utils/parallel_context.cc + ${CCSRC_DIR}/utils/utils.cc) +endif() +if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE) + set(CCSRC_SRC_NEW ${CCSRC_SRC_NEW} + ${OPS_DIR}/kernel/common/kernel_factory.cc + ${OPS_DIR}/kernel/common/kernel_tensor.cc + ${OPS_DIR}/kernel/common/device_address.cc + ${OPS_DIR}/kernel/common/format_utils.cc + ${OPS_DIR}/kernel/common/kernel.cc) +endif() +add_library(ccsrc_src_new_mid OBJECT ${CCSRC_SRC_NEW}) +target_compile_definitions(ccsrc_src_new_mid PRIVATE BACKEND_DLL) +target_compile_definitions(ccsrc_src_new_mid PRIVATE COMMON_DLL) +target_compile_definitions(ccsrc_src_new_mid PRIVATE OPS_KERNEL_COMMON_DLL) +target_link_libraries(mindspore_converter ccsrc_src_new_mid) add_subdirectory(converter_lite) diff --git a/mindspore-lite/tools/converter/adapter/OWNERS b/mindspore-lite/tools/converter/adapter/OWNERS index d4f3c861a0f2faeee4c73db829ff34681b6c1e8b..a07bacdafbc246f01f93448479360e871e7cbd0c 100644 --- a/mindspore-lite/tools/converter/adapter/OWNERS +++ b/mindspore-lite/tools/converter/adapter/OWNERS @@ -1,8 +1,7 @@ approvers: -- jjfeing +- jpc_chenjianping # - YeFeng_24 -- fatmouse007fatmouse007 -- xu_anyue +- fatmouse007fatmouse007 # zhuguodong options: no_parent_owners: false diff --git a/mindspore-lite/tools/converter/adapter/acl/CMakeLists.txt b/mindspore-lite/tools/converter/adapter/acl/CMakeLists.txt index ab56eeb643185726dcd8a57a7c02ebc9abd25317..3570876b5b1998dc30f081124c30e895a7aadad6 100644 --- a/mindspore-lite/tools/converter/adapter/acl/CMakeLists.txt +++ b/mindspore-lite/tools/converter/adapter/acl/CMakeLists.txt @@ -1,5 +1,5 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}) - +include_directories(${TOP_DIR}/mindspore-lite/minddata/dataset) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ORIGIN/") link_directories(${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) @@ -11,6 +11,10 @@ file(GLOB ACL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/infer/*.cc ${TOP_DIR}/mindspore-lite/src/extendrt/utils/serialization.cc ${TOP_DIR}/mindspore-lite/src/extendrt/cxx_api/serialization.cc + ${CMAKE_CURRENT_SOURCE_DIR}/cxx_api_lite/cxx_api/model/acl/acl_model_options.cc + ${CMAKE_CURRENT_SOURCE_DIR}/cxx_api_lite/cxx_api/model/acl/model_converter.cc + ${CMAKE_CURRENT_SOURCE_DIR}/cxx_api_lite/cxx_api/graph/graph_data.cc + ${CMAKE_CURRENT_SOURCE_DIR}/cxx_api_lite/cxx_api/model/aoe/auto_tune_process.cc ) set(ACL_SRC ${ACL_SRC} ${CMAKE_CURRENT_SOURCE_DIR}/acl_pass.cc) @@ -18,16 +22,18 @@ set(ACL_SRC ${ACL_SRC} ${CMAKE_CURRENT_SOURCE_DIR}/acl_pass.cc) set(ENABLE_ACL on) set(MODE_ASCEND_ACL off) add_subdirectory(${TOP_DIR}/mindspore/mindspore/ccsrc/backend/ge_backend/graph_ir _mindspore_transform_graph_ir) -add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/cxx_api_lite/cxx_api mslite_shared_lib) - set_property(SOURCE ${ACL_SRC} PROPERTY COMPILE_DEFINITIONS LOG_HDR_FILE_REL_PATH="mindspore-lite/../mindspore/mindspore/core/include/utils/log_adapter.h" SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) add_library(ascend_pass_plugin SHARED ${ACL_SRC}) -target_link_libraries(ascend_pass_plugin mslite_shared_lib) -add_dependencies(ascend_pass_plugin mslite_shared_lib) +target_link_libraries(ascend_pass_plugin mindspore_graph_ir) +#target_link_libraries(ascend_pass_plugin mslite_shared_lib) +#add_dependencies(ascend_pass_plugin mslite_shared_lib) add_dependencies(ascend_pass_plugin fbs_inner_src) add_dependencies(ascend_pass_plugin mindspore_converter) -target_link_libraries(ascend_pass_plugin mindspore_converter lite_src_common_mid) \ No newline at end of file +target_link_libraries(ascend_pass_plugin mindspore_converter) +target_link_libraries(ascend_pass_plugin _mindspore_ascend_symbol_obj) +add_dependencies(ascend_pass_plugin lite_src_common_mid) +target_link_libraries(ascend_pass_plugin lite_src_common_mid) diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/CMakeLists.txt b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/CMakeLists.txt deleted file mode 100644 index 60aca4f37d28cb715cff3af2aa43a77dd83f6dfc..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/CMakeLists.txt +++ /dev/null @@ -1,246 +0,0 @@ -# find python library -if(MODE_ASCEND_ACL) - get_filename_component(PYTHON_LIB_REALPATH ${PYTHON_LIBRARIES} ABSOLUTE) - get_filename_component(PYTHON_LIB_DIR ${PYTHON_LIB_REALPATH} PATH) - - execute_process( - COMMAND "${Python3_EXECUTABLE}" -c "import distutils.sysconfig as sysconfig; \ - print(sysconfig.get_config_var('PY3LIBRARY'))" - RESULT_VARIABLE result - OUTPUT_VARIABLE PY3LIBRARY) - set(PY3_LIBG ${PYTHON_LIB_DIR}/${PY3LIBRARY}) - string(STRIP "${PY3_LIBG}" PY3_LIBG) - message("Python3 general library = " ${PY3_LIBG}) -endif() - -if(DEFINED ENV{ASCEND_HOME_PATH}) - set(ASCEND_HOME_PATH $ENV{ASCEND_HOME_PATH}) -elseif(DEFINED ENV{ASCEND_CUSTOM_PATH}) - set(ASCEND_HOME_PATH "$ENV{ASCEND_CUSTOM_PATH}/latest") -else() - set(ASCEND_HOME_PATH "/usr/local/Ascend/latest") -endif() -set(ASCEND_VERSION_FILE "${ASCEND_HOME_PATH}/compiler/version.info") -if(NOT EXISTS ${ASCEND_VERSION_FILE}) - set(ASCEND_HOME_PATH "/usr/local/Ascend/ascend-toolkit/latest") - set(ASCEND_VERSION_FILE "${ASCEND_HOME_PATH}/compiler/version.info") -endif() -if(MSLITE_ENABLE_ACL AND EXISTS ${ASCEND_VERSION_FILE}) - file(STRINGS ${ASCEND_VERSION_FILE} CANN_VERSION_STRING LIMIT_INPUT 20) - string(REGEX MATCH "([0-9]+\\.[0-9]+)" EXTRACTED_VERSION "${CANN_VERSION_STRING}") - set(VERSION_FLOW "${EXTRACTED_VERSION}") - message("cann version:" ${VERSION_FLOW}) - if(VERSION_FLOW GREATER_EQUAL 7.5) - message("define enable bundle") - add_compile_definitions(ENABLE_BUNDLE) - endif() -endif() - -if(WIN32) - # define this for msvc:dllexport or dllimport - add_compile_definitions(BUILDING_DLL) -endif() - -# build mindspore_shared_lib -include_directories(${TOP_DIR}/mindspore/ccsrc) -include_directories(${TOP_DIR}/mindspore/ccsrc/minddata/dataset) - -if(ENABLE_D OR ENABLE_ACL) - # build 910 and 310 code into one distro, files needed for 310 mode - add_compile_definitions(ENABLE_ACL) - include_directories(${CMAKE_BINARY_DIR}/proto/ge) -endif() - -if(MODE_ASCEND_ACL OR BUILD_LITE) - add_compile_definitions(MODE_ASCEND_ACL) - file(GLOB_RECURSE API_ACL_SRC ${CMAKE_CURRENT_SOURCE_DIR} - "model/acl/*.cc" - "model/model_converter_utils/*.cc" - "graph/acl/*.cc" - ) -endif() - -set(MSLIB_SRC ${CMAKE_CURRENT_SOURCE_DIR}/types.cc - ${CMAKE_CURRENT_SOURCE_DIR}/context.cc - ${CMAKE_CURRENT_SOURCE_DIR}/cell.cc - ${CMAKE_CURRENT_SOURCE_DIR}/factory.cc - ${CMAKE_CURRENT_SOURCE_DIR}/any_utils.cc - ${CMAKE_CURRENT_SOURCE_DIR}/serialization.cc - ${CMAKE_CURRENT_SOURCE_DIR}/graph/graph.cc - ${CMAKE_CURRENT_SOURCE_DIR}/graph/graph_data.cc - ${CMAKE_CURRENT_SOURCE_DIR}/model/model.cc - ${CMAKE_CURRENT_SOURCE_DIR}/model/model_impl.cc - ${API_ACL_SRC} - ) -if(ENABLE_D OR ENABLE_GPU) - list(APPEND MSLIB_SRC ${CMAKE_CURRENT_SOURCE_DIR}/akg_kernel_register.cc) -endif() - -if(NOT BUILD_LITE) - list(APPEND MSLIB_SRC_DEPEND - ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/strategy_checkpoint_info.cc) - list(APPEND MSLIB_SRC_DEPEND - ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc) - list(APPEND MSLIB_SRC_DEPEND ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/group_manager.cc) - list(APPEND MSLIB_SRC_DEPEND ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/device_manager.cc) - list(APPEND MSLIB_SRC_DEPEND ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/device_matrix.cc) - list(APPEND MSLIB_SRC_DEPEND ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/tensor_layout/array.cc) - list(APPEND MSLIB_SRC_DEPEND ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/tensor_layout/map.cc) - list(APPEND MSLIB_SRC_DEPEND ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/tensor_layout/arrangement.cc) - list(APPEND MSLIB_SRC_DEPEND ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/tensor_layout/shape_util.cc) - list(APPEND MSLIB_SRC_DEPEND ${CMAKE_SOURCE_DIR}/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.cc) - list(APPEND MSLIB_SRC_DEPEND ${CMAKE_CURRENT_SOURCE_DIR}/utils.cc) -endif() - -if((MODE_ASCEND_ACL) AND NOT BUILD_LITE) - list(APPEND MSLIB_SRC - "${CMAKE_SOURCE_DIR}/mindspore/ccsrc/backend/backend_manager/backend_jit_config.cc" - "${CMAKE_SOURCE_DIR}/mindspore/ccsrc/backend/common/optimizer/pattern_engine.cc" - "${CMAKE_SOURCE_DIR}/mindspore/ccsrc/backend/common/optimizer/helper.cc" - "${CMAKE_SOURCE_DIR}/mindspore/ccsrc/backend/common/optimizer/node_pass.cc" - "${CMAKE_SOURCE_DIR}/mindspore/ccsrc/backend/common/optimizer/visitor.cc" - "${CMAKE_SOURCE_DIR}/mindspore/ccsrc/backend/operator/ops_backend_infer_function.cc" - "${CMAKE_SOURCE_DIR}/mindspore/ops/kernel/common/kernel_build_info.cc" - "${CMAKE_SOURCE_DIR}/mindspore/ccsrc/kernel/kernel_info.cc" - "${CMAKE_SOURCE_DIR}/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/reduce_axis_update.cc" - "${CMAKE_SOURCE_DIR}/mindspore/ccsrc/plugin/device/ascend/optimizer/ge_optimization.cc") - # MODE_ASCEND_ACL don't need reg akg kernel - list(REMOVE_ITEM MSLIB_SRC "${CMAKE_CURRENT_SOURCE_DIR}/akg_kernel_register.cc") -endif() - -if(NOT ENABLE_TESTCASES AND NOT BUILD_LITE) - # users of shared_lib cannot find symbols in indirect dependency - set(MSLIB_SRC ${MSLIB_SRC} ${CMAKE_SOURCE_DIR}/mindspore/core/utils/status.cc) -endif() - -if(BUILD_LITE) - list(APPEND MSLIB_SRC - "${TOP_DIR}/mindspore/mindspore/ccsrc/utils/config_manager.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/model/aoe/auto_tune_process.cc") - file(GLOB_RECURSE ACL_REMOVE_SRC ${CMAKE_CURRENT_SOURCE_DIR} - "model/acl/acl_vm/*.cc" - ) - list(REMOVE_ITEM MSLIB_SRC "${CMAKE_CURRENT_SOURCE_DIR}/akg_kernel_register.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/model/acl/acl_model_multi.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/model/acl/acl_model.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/serialization.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/types.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/model/model.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/model/model_impl.cc" - ${ACL_REMOVE_SRC}) -endif() - -add_library(common_shared_lib_obj OBJECT ${MSLIB_SRC}) -if(BUILD_LITE) - add_library(mslite_shared_lib SHARED $ ${MSLIB_SRC_DEPEND}) -else() - add_library(mindspore_shared_lib SHARED $ ${MSLIB_SRC_DEPEND}) -endif() -if(BUILD_LITE) - target_link_libraries(mslite_shared_lib PRIVATE mindspore_graph_ir) - add_dependencies(mslite_shared_lib mindspore_graph_ir) -elseif(MODE_ASCEND_ACL) - target_link_libraries(mindspore_shared_lib PRIVATE mindspore_graph_ir - _mindspore_backend_graph_compiler_obj _mindspore_debug_obj mindspore_backend_static) - if(MS_BUILD_GRPC) - target_link_libraries(mindspore_shared_lib PRIVATE mindspore::grpc++) - endif() - set_target_properties(mindspore_shared_lib PROPERTIES OUTPUT_NAME mindspore) -else() - if(CMAKE_SYSTEM_NAME MATCHES "Linux") - # wheel package and ut - add_library(api_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/types.cc - ${CMAKE_CURRENT_SOURCE_DIR}/context.cc - ${CMAKE_CURRENT_SOURCE_DIR}/cell.cc - ${CMAKE_CURRENT_SOURCE_DIR}/serialization.cc - ${CMAKE_CURRENT_SOURCE_DIR}/graph/graph.cc - ${CMAKE_CURRENT_SOURCE_DIR}/graph/graph_data.cc - ${CMAKE_CURRENT_SOURCE_DIR}/model/model.cc - ${CMAKE_SOURCE_DIR}/mindspore/core/utils/status.cc - ) - target_link_libraries(api_lib PRIVATE mindspore_shared_lib -Wl,--no-as-needed ${PYTHON_LIBRARIES}) - set_target_properties(api_lib PROPERTIES OUTPUT_NAME mindspore) - set_target_properties(api_lib PROPERTIES INSTALL_RPATH "$ORIGIN") - endif() -endif() - -if(ENABLE_D OR ENABLE_GPU) - if(BUILD_LITE) - target_link_libraries(mslite_shared_lib PRIVATE -Wl,--as-needed ${SECUREC_LIBRARY} - mindspore_ms_backend mindspore_core mindspore_ops mindspore_common mindspore_ops_kernel_common - mindspore_backend_common proto_input mindspore::protobuf mindspore_profiler mindspore_backend_manager) - else() - target_link_libraries(mindspore_shared_lib PRIVATE -Wl,--as-needed ${SECUREC_LIBRARY} - mindspore_ms_backend mindspore_core mindspore_ops mindspore_common mindspore_ops_kernel_common - mindspore_backend_common mindspore_profiler proto_input mindspore::protobuf mindspore_backend_manager) - endif() -else() - if(BUILD_LITE) - target_link_libraries(mslite_shared_lib PRIVATE mindspore_core mindspore_ops ${SECUREC_LIBRARY}) - else() - target_link_libraries(mindspore_shared_lib PRIVATE ${PY3_LIBG} ${SECUREC_LIBRARY} - mindspore_ms_backend mindspore mindspore_core mindspore_ops mindspore_common mindspore_ops_kernel_common - mindspore_backend_common mindspore_profiler proto_input mindspore::protobuf mindspore_backend_manager) - endif() -endif() - -if(ENABLE_CPU) - if(BUILD_LITE) - target_link_libraries(mslite_shared_lib PRIVATE mindspore::dnnl mindspore::mkldnn nnacl) - else() - target_link_libraries(mindspore_shared_lib PRIVATE mindspore::dnnl mindspore::mkldnn nnacl) - endif() -endif() - -if(USE_GLOG) - if(BUILD_LITE) - target_link_libraries(mslite_shared_lib PRIVATE mindspore::glog) - else() - target_link_libraries(mindspore_shared_lib PRIVATE mindspore::glog) - endif() -endif() - -if(CMAKE_SYSTEM_NAME MATCHES "Linux") - if(BUILD_LITE) - target_link_options(mslite_shared_lib PRIVATE -Wl,-init,common_log_init) - else() - target_link_options(mindspore_shared_lib PRIVATE -Wl,-init,common_log_init) - endif() -endif() - -if(MODE_ASCEND_ACL OR MSLITE_ENABLE_ACL) - # 310 mode - add_compile_definitions(ENABLE_DVPP_INTERFACE) - find_library(acl libascendcl.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - find_library(acl_cblas libacl_cblas.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - find_library(acl_dvpp libacl_dvpp.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - find_library(acl_runtime libruntime.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - find_library(ge_compiler libge_compiler.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - find_library(libplatform libplatform.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - find_library(libcompress libcompress.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - find_library(libopskernel libopskernel.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - find_library(libaicore_utils libaicore_utils.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - find_library(libaicpu_engine_common libaicpu_engine_common.so ${ASCEND_CANN_RUNTIME_PATH} - ${ASCEND_TOOLKIT_RUNTIME_PATH}) - find_library(GE_RUNNER ge_runner ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) - - if(BUILD_LITE) - target_link_libraries(mslite_shared_lib PRIVATE -Wl,--no-as-needed graph ${ge_compiler} - ${acl_retr} ${acl_cblas} ${acl_dvpp} ${acl_runtime} ${libplatform} ${libcompress} ${libopskernel} - ${libaicore_utils} ${libaicpu_engine_common} ${acl} ${GE_RUNNER}) - else() - target_link_libraries(mindspore_shared_lib PRIVATE -Wl,--no-as-needed graph ${ge_compiler} - ${acl_retr} ${acl_cblas} ${acl_dvpp} ${acl_runtime} ${libplatform} ${libcompress} ${libopskernel} - ${libaicore_utils} ${libaicpu_engine_common} ${acl} ${GE_RUNNER}) - endif() -endif() - -if(CMAKE_SYSTEM_NAME MATCHES "Linux") -# duplicate, should be remove after backend decoupling is done - set(MINDSPORE_SHARED_LIB_RPATH $ORIGIN) - if(BUILD_LITE) - set_target_properties(mslite_shared_lib PROPERTIES INSTALL_RPATH ${MINDSPORE_SHARED_LIB_RPATH}) - else() - set_target_properties(mindspore_shared_lib PROPERTIES INSTALL_RPATH ${MINDSPORE_SHARED_LIB_RPATH}) - endif() -endif() diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/akg_kernel_register.cc b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/akg_kernel_register.cc deleted file mode 100644 index 50968a80c21d5df0f31a36b90e9d79799e68fc44..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/akg_kernel_register.cc +++ /dev/null @@ -1,113 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "cxx_api/akg_kernel_register.h" -#ifdef _MSC_VER -#include -#else -#include -#endif -#include -#include -#include -#include -#include "common/oplib/oplib.h" - -static std::mutex init_mutex; -static bool Initialized = false; - -namespace mindspore { -static bool RegAllOpFromFile() { - std::string dir; -#ifndef _MSC_VER - Dl_info info; - int dl_ret = dladdr(reinterpret_cast(RegAllOpFromFile), &info); - if (dl_ret == 0) { - MS_LOG(INFO) << "Get dladdr failed, skip."; - return false; - } - dir = info.dli_fname; -#else - HMODULE hModule = nullptr; - if (GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT | GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS, - (LPCSTR)RegAllOpFromFile, &hModule) != 0) { - char szPath[MAX_PATH]; - if (GetModuleFileName(hModule, szPath, sizeof(szPath)) != 0) { - dir = std::string(szPath); - } - } else { - MS_LOG(INFO) << "Get GetModuleHandleEx failed, skip."; - return false; - } -#endif - MS_LOG(INFO) << "Get library path is " << dir; - - auto split_pos = dir.find_last_of('/'); - if (dir.empty() || split_pos == std::string::npos) { - MS_LOG(INFO) << "Missing op config file, skip."; - return false; - } - - dir = dir.substr(0, split_pos) + "/../config/op_info.config"; - if (dir.size() >= PATH_MAX) { - MS_LOG(ERROR) << "Op info path is invalid: " << dir; - return false; - } - - char real_path_mem[PATH_MAX] = {0}; -#ifdef _MSC_VER - if (_fullpath(real_path_mem, common::SafeCStr(dir), PATH_MAX) == nullptr) { - MS_LOG(ERROR) << "Op info path is invalid: " << dir; - return false; - } -#else - if (realpath(common::SafeCStr(dir), real_path_mem) == nullptr) { - MS_LOG(ERROR) << "Op info path is invalid: " << dir; - return false; - } -#endif - std::string real_path(real_path_mem); - - MS_LOG(INFO) << "Start to read op info from local file " << real_path; - std::ifstream file(real_path); - if (!file.is_open()) { - MS_LOG(ERROR) << "Find op info file failed."; - return false; - } - kernel::OpLib::GetOpInfoMap().clear(); - std::string line; - while (getline(file, line)) { - if (!line.empty()) { - (void)kernel::OpLib::RegOp(line, ""); - } - } - MS_LOG(INFO) << "End"; - return true; -} - -void RegAllOp() { - std::lock_guard lock(init_mutex); - if (Initialized) { - return; - } - bool ret = RegAllOpFromFile(); - if (!ret) { - MS_LOG(ERROR) << "Register operators failed. The package may damaged or file is missing."; - return; - } - - Initialized = true; -} -} // namespace mindspore diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/cell.cc b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/cell.cc deleted file mode 100644 index 4975bcc1af995ed0f2e3b2beb8210dfbeffa7189..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/cell.cc +++ /dev/null @@ -1,95 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "include/api/cell.h" -#include "include/api/context.h" -#include "cxx_api/factory.h" -#include "cxx_api/graph/graph_impl.h" - -namespace mindspore { -std::vector CellBase::operator()(const std::vector &inputs) const { return Clone()->Construct(inputs); } - -GraphCell::GraphCell(const Graph &graph) : graph_(std::make_shared(graph)) { MS_EXCEPTION_IF_NULL(graph_); } - -GraphCell::GraphCell(const std::shared_ptr &graph) : graph_(graph) { MS_EXCEPTION_IF_NULL(graph_); } - -GraphCell::GraphCell(Graph &&graph) : graph_(std::make_shared(graph)) { MS_EXCEPTION_IF_NULL(graph_); } - -void GraphCell::SetContext(const std::shared_ptr &context) { - if (executor_ == nullptr) { - executor_ = GraphImplFactory::Instance().Create(g_device_target); - if (executor_ == nullptr) { - MS_LOG(ERROR) << "Create graph impl for device target " << g_device_target << " failed."; - return; - } - executor_->SetGraph(graph_); - } - executor_->SetContext(context); -} - -Status GraphCell::Run(const std::vector &inputs, std::vector *outputs) { - if (executor_ == nullptr) { - executor_ = GraphImplFactory::Instance().Create(g_device_target); - if (executor_ == nullptr) { - MS_LOG(ERROR) << "Create graph impl for device target " << g_device_target << " failed."; - return kMEFailed; - } - executor_->SetGraph(graph_); - } - return executor_->Run(inputs, outputs); -} - -Status GraphCell::Load(uint32_t device_id) { - if (executor_ == nullptr) { - executor_ = GraphImplFactory::Instance().Create(g_device_target); - if (executor_ == nullptr) { - MS_LOG(ERROR) << "Create graph impl for device target " << g_device_target << " failed."; - return kMEFailed; - } - executor_->SetGraph(graph_); - } - return executor_->Load(device_id); -} - -std::vector GraphCell::GetInputs() { - if (executor_ == nullptr) { - executor_ = GraphImplFactory::Instance().Create(g_device_target); - if (executor_ == nullptr) { - MS_LOG(ERROR) << "Create graph impl for device target " << g_device_target << " failed."; - return {}; - } - executor_->SetGraph(graph_); - } - return executor_->GetInputs(); -} - -std::vector GraphCell::GetOutputs() { - if (executor_ == nullptr) { - executor_ = GraphImplFactory::Instance().Create(g_device_target); - if (executor_ == nullptr) { - MS_LOG(ERROR) << "Create graph impl for device target " << g_device_target << " failed."; - return {}; - } - executor_->SetGraph(graph_); - } - return executor_->GetOutputs(); -} - -InputAndOutput::InputAndOutput() : cell_(nullptr), prev_(), index_(-1) {} - -InputAndOutput::InputAndOutput(const std::shared_ptr &cell, const std::vector &prev, - int32_t index) - : cell_(cell), prev_(prev), index_(index) {} -} // namespace mindspore diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/factory.cc b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/factory.cc deleted file mode 100644 index 73e965ab96af6294a18194e1f8f6b091eb84a048..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/factory.cc +++ /dev/null @@ -1,72 +0,0 @@ -/** - * Copyright 2022-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "cxx_api/factory.h" -namespace mindspore { -GraphImplFactory &GraphImplFactory::Instance() { - std::call_once(once_flag_, []() { - if (instance_ == nullptr) { - instance_ = std::make_shared(); - } - }); - return *instance_; -} - -void GraphImplFactory::Register(const std::string &device_name, GraphImplCreator &&creator) { - MS_LOG(DEBUG) << "Start register graph impl for " << device_name; - (void)creators_.emplace_back(std::move(creator)); -} - -std::shared_ptr GraphImplFactory::Create(enum DeviceType device_type) { - for (auto &item : creators_) { - MS_EXCEPTION_IF_NULL(item); - auto val = item(); - MS_EXCEPTION_IF_NULL(val); - if (val->CheckDeviceSupport(device_type)) { - return val; - } - } - MS_LOG(WARNING) << "Unsupported device target " << device_type; - return nullptr; -} - -ModelImplFactory &ModelImplFactory::Instance() { - std::call_once(once_flag_, []() { - if (instance_ == nullptr) { - instance_ = std::make_shared(); - } - }); - return *instance_; -} - -void ModelImplFactory::Register(const std::string &device_name, ModelImplCreator &&creator) { - MS_LOG(DEBUG) << "Start register model for " << device_name; - (void)creators_.emplace_back(std::move(creator)); -} - -std::shared_ptr ModelImplFactory::Create(enum DeviceType device_type) { - for (auto &item : creators_) { - MS_EXCEPTION_IF_NULL(item); - auto val = item(); - MS_EXCEPTION_IF_NULL(val); - if (val->CheckDeviceSupport(device_type)) { - return val; - } - } - MS_LOG(WARNING) << "Unsupported device target " << device_type; - return nullptr; -} -} // namespace mindspore diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/factory.h b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/factory.h deleted file mode 100644 index 52a4b8c6f9478c35475b0be3e07e5936ed43f9d2..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/factory.h +++ /dev/null @@ -1,118 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_CXX_API_FACTORY_H -#define MINDSPORE_CCSRC_CXX_API_FACTORY_H -#include -#include -#include -#include -#include -#include -#include "include/api/context.h" -#include "include/common/utils/utils.h" -#include "cxx_api/graph/graph_impl.h" -#include "cxx_api/model/model_impl.h" - -namespace mindspore { -constexpr auto Ascend310 = "Ascend310"; -constexpr auto Ascend910 = "Ascend910"; -constexpr auto kMS = "MS"; -#ifdef __linux__ -MS_API inline enum DeviceType g_device_target = kInvalidDeviceType; -#else -inline enum DeviceType g_device_target = kInvalidDeviceType; -#endif -static inline LogStream &operator<<(LogStream &stream, DeviceType device_type) { - std::map type_str_map = { - {kAscend, "Ascend"}, {kAscend910, "Ascend910"}, {kAscend310, "Ascend310"}, {kGPU, "GPU"}, {kCPU, "CPU"}}; - auto it = type_str_map.find(device_type); - if (it != type_str_map.end()) { - stream << it->second; - } else { - stream << "[InvalidDeviceType: " << static_cast(device_type) << "]"; - } - return stream; -} - -using GraphImplCreator = std::function()>; - -class MS_API GraphImplFactory { - public: - GraphImplFactory(const GraphImplFactory &) = delete; - GraphImplFactory &operator=(const GraphImplFactory &) = delete; - - static GraphImplFactory &Instance(); - - void Register(const std::string &device_name, GraphImplCreator &&creator); - - std::shared_ptr Create(enum DeviceType device_type); - - GraphImplFactory() = default; - ~GraphImplFactory() = default; - - private: - inline static std::shared_ptr instance_; - inline static std::once_flag once_flag_; - std::vector creators_; -}; - -class GraphImplRegistrar { - public: - explicit GraphImplRegistrar(const std::string &device_name, GraphImplCreator &&creator) { - GraphImplFactory::Instance().Register(device_name, std::move(creator)); - } - ~GraphImplRegistrar() = default; -}; - -using ModelImplCreator = std::function()>; - -class MS_API ModelImplFactory { - public: - ModelImplFactory(const ModelImplFactory &) = delete; - ModelImplFactory &operator=(const ModelImplFactory &) = delete; - - static ModelImplFactory &Instance(); - - void Register(const std::string &device_name, ModelImplCreator &&creator); - - std::shared_ptr Create(enum DeviceType device_type); - - ModelImplFactory() = default; - ~ModelImplFactory() = default; - - private: - inline static std::shared_ptr instance_; - inline static std::once_flag once_flag_; - std::vector creators_; -}; - -class ModelImplRegistrar { - public: - explicit ModelImplRegistrar(const std::string &device_name, ModelImplCreator &&creator) { - ModelImplFactory::Instance().Register(device_name, std::move(creator)); - } - ~ModelImplRegistrar() = default; -}; - -#define API_GRAPH_REG(DEVICE_NAME, DEVICE_CLASS) \ - static const GraphImplRegistrar graph_api_##DEVICE_NAME##_registrar_reg( \ - DEVICE_NAME, []() { return std::make_shared(); }) - -#define API_MODEL_REG(DEVICE_NAME, DEVICE_CLASS) \ - static const ModelImplRegistrar model_api_##DEVICE_NAME##_registrar_reg( \ - DEVICE_NAME, []() { return std::make_shared(); }) -} // namespace mindspore -#endif // MINDSPORE_CCSRC_CXX_API_FACTORY_H diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/acl/acl_env_guard.cc b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/acl/acl_env_guard.cc deleted file mode 100644 index e5ab156cc65c03b72a3e8836d724c6b4f91609f8..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/acl/acl_env_guard.cc +++ /dev/null @@ -1,94 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "cxx_api/graph/acl/acl_env_guard.h" -#include "src/common/log_adapter.h" -#include "utils/ms_utils.h" -#include "plugin/res_manager/ascend/symbol_interface/acl_symbol.h" -#include "plugin/res_manager/ascend/symbol_interface/symbol_utils.h" - -namespace mindspore { -std::shared_ptr AclEnvGuard::global_acl_env_ = nullptr; -std::mutex AclEnvGuard::global_acl_env_mutex_; - -AclInitAdapter &AclInitAdapter::GetInstance() { - static AclInitAdapter instance = {}; - return instance; -} - -aclError AclInitAdapter::AclInit(const char *config_file) { - std::lock_guard lock(flag_mutex_); - if (init_flag_) { - return ACL_SUCCESS; - } - - init_flag_ = true; - return CALL_ASCEND_API(aclInit, config_file); -} - -aclError AclInitAdapter::AclFinalize() { - std::lock_guard lock(flag_mutex_); - if (!init_flag_) { - MS_LOG(INFO) << "Acl had been finalized."; - return ACL_SUCCESS; - } - - MS_LOG(INFO) << "Begin to aclFinalize."; - init_flag_ = false; - return CALL_ASCEND_API(aclFinalize); -} - -aclError AclInitAdapter::ForceFinalize() { - std::lock_guard lock(flag_mutex_); - MS_LOG(INFO) << "Begin to force aclFinalize."; - init_flag_ = false; - return CALL_ASCEND_API(aclFinalize); -} - -AclEnvGuard::AclEnvGuard() : errno_(AclInitAdapter::GetInstance().AclInit(nullptr)) { - if (errno_ != ACL_SUCCESS && errno_ != ACL_ERROR_REPEAT_INITIALIZE) { - MS_LOG(ERROR) << "Execute aclInit failed."; - return; - } - MS_LOG(INFO) << "Execute aclInit success."; -} - -AclEnvGuard::~AclEnvGuard() { - TRY_AND_CATCH_WITH_EXCEPTION(errno_ = AclInitAdapter::GetInstance().AclFinalize(), - "AclInitAdapter GetInstance failed"); - if (errno_ != ACL_SUCCESS && errno_ != ACL_ERROR_REPEAT_FINALIZE) { - MS_LOG(ERROR) << "Execute AclFinalize failed."; - } - MS_LOG(INFO) << "Execute AclFinalize success."; -} - -std::shared_ptr AclEnvGuard::GetAclEnv() { - std::lock_guard lock(global_acl_env_mutex_); - std::shared_ptr acl_env = global_acl_env_; - if (acl_env != nullptr) { - MS_LOG(INFO) << "Acl has been initialized, skip."; - } else { - acl_env = std::make_shared(); - aclError ret = acl_env->GetErrno(); - if (ret != ACL_SUCCESS && ret != ACL_ERROR_REPEAT_INITIALIZE) { - MS_LOG(ERROR) << "Execute aclInit failed."; - return nullptr; - } - global_acl_env_ = acl_env; - MS_LOG(INFO) << "Execute aclInit success."; - } - return acl_env; -} -} // namespace mindspore diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/acl/acl_graph_impl.cc b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/acl/acl_graph_impl.cc deleted file mode 100644 index 8f750e0f947b02e0277e5bccca5f97224796e139..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/acl/acl_graph_impl.cc +++ /dev/null @@ -1,253 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "cxx_api/graph/acl/acl_graph_impl.h" -#include "include/api/context.h" -#include "cxx_api/model/acl/model_converter.h" -#include "src/common/log_adapter.h" -#include "utils/convert_utils_base.h" -#include "cxx_api/acl_utils.h" -#include "plugin/res_manager/ascend/symbol_interface/acl_mdl_symbol.h" -#include "plugin/res_manager/ascend/symbol_interface/acl_rt_symbol.h" -#include "plugin/res_manager/ascend/symbol_interface/symbol_utils.h" - -namespace mindspore { -API_GRAPH_REG(Ascend310, AclGraphImpl); - -AclGraphImpl::AclGraphImpl() - : init_flag_(false), - load_flag_(false), - device_type_("AscendCL"), - device_id_(0), - context_(nullptr), - acl_env_(nullptr) {} - -AclGraphImpl::~AclGraphImpl() { - try { - (void)FinalizeEnv(); - } catch (const std::exception &e) { - MS_LOG(ERROR) << "AclGraphImpl destructor run failed, error message : " << e.what(); - } catch (...) { - MS_LOG(ERROR) << "AclGraphImpl destructor run failed, unknown error occurred."; - } -} - -Status AclGraphImpl::Run(const std::vector &inputs, std::vector *outputs) { - MS_EXCEPTION_IF_NULL(outputs); - Status ret = Load(IntToUint(device_id_)); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Prepare model resource failed."; - return ret; - } - - return model_process_.PredictFromHost(inputs, outputs); -} - -std::vector AclGraphImpl::GetInputs() { - Status ret = Load(IntToUint(device_id_)); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Prepare model resource failed."; - return {}; - } - - return model_process_.GetInputs(); -} - -std::vector AclGraphImpl::GetOutputs() { - Status ret = Load(IntToUint(device_id_)); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Prepare model resource failed."; - return {}; - } - - return model_process_.GetOutputs(); -} - -Status AclGraphImpl::LoadAclModel(const Buffer om_data) { - MS_LOG(INFO) << "Start load acl model."; - // acl load model - uint32_t acl_model_id; - auto acl_ret = CALL_ASCEND_API(aclmdlLoadFromMem, om_data.Data(), om_data.DataSize(), &acl_model_id); - if (acl_ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "Call aclmdlLoadFromMem failed."; - return kMCDeviceError; - } - - // acl init model resource - model_process_.set_model_id(acl_model_id); - Status ret = model_process_.PreInitModelResource(); - if (ret != kSuccess) { - (void)CALL_ASCEND_API(aclmdlUnload, acl_model_id); - MS_LOG(ERROR) << "Pre init model resource failed."; - return ret; - } - - MS_LOG(INFO) << "Load acl model success."; - return kSuccess; -} - -Status AclGraphImpl::InitEnv() { - if (init_flag_) { - return kSuccess; - } - - acl_env_ = AclEnvGuard::GetAclEnv(); - if (acl_env_ == nullptr) { - MS_LOG(ERROR) << "Acl init failed."; - return kMCDeviceError; - } - - aclError ret = CALL_ASCEND_API(aclrtSetDevice, device_id_); - if (ret != ACL_SUCCESS) { - MS_LOG(EXCEPTION) << "Device " << device_id_ << " call aclrtSetDevice failed, ret[" << static_cast(ret) << "]"; - } - MS_LOG(INFO) << "Open device " << device_id_ << " success"; - - ret = CALL_ASCEND_API(aclrtCreateContext, &context_, device_id_); - if (ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "Acl create context failed"; - return kMCDeviceError; - } - MS_LOG(INFO) << "Create context success"; - - aclrtRunMode run_mode; - ret = CALL_ASCEND_API(aclrtGetRunMode, &run_mode); - if (ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "Acl get run mode failed"; - return kMCDeviceError; - } - bool is_device = (run_mode == ACL_DEVICE); - model_process_.SetIsDevice(is_device); - MS_LOG(INFO) << "Get run mode success is device input/output " << is_device; - - MS_LOG(INFO) << "Init acl success, device id " << device_id_; - init_flag_ = true; - return kSuccess; -} - -Status AclGraphImpl::FinalizeEnv() { - if (!init_flag_) { - return kSuccess; - } - - aclError rt_ret = CALL_ASCEND_API(aclrtSetCurrentContext, context_); - if (rt_ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "Set the ascend device context failed"; - return kMCDeviceError; - } - - Status ret = model_process_.UnLoad(); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Unload model inner failed."; - return ret; - } - - if (context_ != nullptr) { - rt_ret = CALL_ASCEND_API(aclrtDestroyContext, context_); - if (rt_ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "Destroy context failed"; - } - context_ = nullptr; - } - MS_LOG(INFO) << "End to destroy context"; - - rt_ret = CALL_ASCEND_API(aclrtResetDevice, device_id_); - if (rt_ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "Reset device " << device_id_ << " failed"; - } - MS_LOG(INFO) << "End to reset device " << device_id_; - - init_flag_ = false; - return kSuccess; -} - -Status AclGraphImpl::Load(uint32_t device_id) { - // check graph type - if (graph_->ModelType() != ModelType::kOM) { - Status ret = ConvertToOM(); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Load Failed."; - return ret; - } - } - - const auto &graph_data = GraphImpl::MutableGraphData(); - MS_EXCEPTION_IF_NULL(graph_data); - auto om_data = graph_data->GetOMData(); - - // init - device_id_ = UintToInt(device_id); - Status ret = InitEnv(); - if (ret != kSuccess) { - MS_LOG(ERROR) << "InitEnv failed."; - return ret; - } - - // load model - if (!load_flag_) { - ret = LoadAclModel(om_data); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Load acl model failed."; - return ret; - } - load_flag_ = true; - } - - aclError rt_ret = CALL_ASCEND_API(aclrtSetCurrentContext, context_); - if (rt_ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "Set the ascend device context failed"; - return kMCDeviceError; - } - - return kSuccess; -} - -Status AclGraphImpl::ConvertToOM() { - MS_LOG(INFO) << "Start convert to om model."; - if (graph_ == nullptr) { - MS_LOG(ERROR) << "Invalid graph_ is null."; - return kMCFailed; - } - - auto &graph_data = GraphImpl::MutableGraphData(); - MS_EXCEPTION_IF_NULL(graph_data); - if (graph_->ModelType() == ModelType::kOM) { - MS_LOG(INFO) << "This model has been built, skip."; - return kSuccess; - } else if (graph_->ModelType() == ModelType::kMindIR) { - auto func_graph = graph_data->GetFuncGraph(); - MS_EXCEPTION_IF_NULL(func_graph); - ModelConverter model_converter; - Buffer om_data = model_converter.LoadMindIR(func_graph); - if (om_data.Data() == nullptr || om_data.DataSize() == 0) { - MS_LOG(ERROR) << "Convert MindIR to OM failed."; - return kMCFailed; - } - graph_data = std::make_shared(om_data, ModelType::kOM); - MS_LOG(INFO) << "Convert MindIR to OM success."; - return kSuccess; - } - MS_LOG(ERROR) << "Unsupported ModelType " << graph_->ModelType(); - return kMCFailed; -} - -bool AclGraphImpl::CheckDeviceSupport(mindspore::DeviceType device_type) { - // for Ascend, only support kAscend and kAscend310 - if (device_type != kAscend && device_type != kAscend310) { - return false; - } - return IsAscendNo910Soc(); -} -} // namespace mindspore diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/acl/model_process.cc b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/acl/model_process.cc deleted file mode 100644 index bfcb38a8ffec0cdb5702c76702ca1043333e45ae..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/acl/model_process.cc +++ /dev/null @@ -1,541 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "cxx_api/graph/acl/model_process.h" -#include -#include -#include -#include -#include - -#include "utils/convert_utils_base.h" -#include "include/api/data_type.h" -#include "src/common/log_adapter.h" -#include "plugin/res_manager/ascend/symbol_interface/acl_base_symbol.h" -#include "plugin/res_manager/ascend/symbol_interface/acl_mdl_symbol.h" -#include "plugin/res_manager/ascend/symbol_interface/acl_rt_symbol.h" -#include "plugin/res_manager/ascend/symbol_interface/symbol_utils.h" - -namespace mindspore { -static DataType TransToApiType(aclDataType data_type) { - static const std::map data_type_map = { - {ACL_FLOAT16, DataType::kNumberTypeFloat16}, {ACL_FLOAT, DataType::kNumberTypeFloat32}, - {ACL_DOUBLE, DataType::kNumberTypeFloat64}, {ACL_INT8, DataType::kNumberTypeInt8}, - {ACL_INT16, DataType::kNumberTypeInt16}, {ACL_INT32, DataType::kNumberTypeInt32}, - {ACL_INT64, DataType::kNumberTypeInt64}, {ACL_UINT8, DataType::kNumberTypeUInt8}, - {ACL_UINT16, DataType::kNumberTypeUInt16}, {ACL_UINT32, DataType::kNumberTypeUInt32}, - {ACL_UINT64, DataType::kNumberTypeUInt64}, {ACL_BOOL, DataType::kNumberTypeBool}, - }; - auto it = data_type_map.find(data_type); - if (it == data_type_map.end()) { - return DataType::kTypeUnknown; - } else { - return it->second; - } -} - -template -inline static void ClearIfNotNull(T *vec) { - if (vec != nullptr) { - vec->clear(); - } -} - -template > -inline static void PushbackIfNotNull(U *vec, T &&item) { - if (vec != nullptr) { - vec->emplace_back(item); - } -} - -static void ConstructTensorDesc(const std::vector &acl_tensor_list, std::vector *names, - std::vector> *shapes, std::vector *data_types, - std::vector *mem_sizes) { - ClearIfNotNull(names); - ClearIfNotNull(shapes); - ClearIfNotNull(data_types); - ClearIfNotNull(mem_sizes); - for (size_t i = 0; i < acl_tensor_list.size(); ++i) { - const auto &info = acl_tensor_list[i]; - PushbackIfNotNull(names, info.name); - PushbackIfNotNull(shapes, info.dims); - PushbackIfNotNull(data_types, TransToApiType(info.data_type)); - PushbackIfNotNull(mem_sizes, info.buffer_size); - } -} - -static std::string ShapeToString(const std::vector &shape) { - std::string result = "["; - for (size_t i = 0; i < shape.size(); ++i) { - result += std::to_string(shape[i]); - if (i + 1 < shape.size()) { - result += ", "; - } - } - result += "]"; - return result; -} - -Status ModelProcess::ConstructTensors(const std::vector &acl_tensor_list, - std::vector *tensor_list) const { - MS_EXCEPTION_IF_NULL(tensor_list); - std::vector names; - std::vector> shapes; - std::vector data_types; - std::vector mem_sizes; - - ConstructTensorDesc(acl_tensor_list, &names, &shapes, &data_types, &mem_sizes); - tensor_list->clear(); - if (names.size() != acl_tensor_list.size() || shapes.size() != acl_tensor_list.size() || - data_types.size() != acl_tensor_list.size() || mem_sizes.size() != acl_tensor_list.size()) { - MS_LOG(ERROR) << "Inner error, size do not match: names size " << names.size() << " shapes size " << shapes.size() - << " data types size " << data_types.size() << " mem sizes size " << mem_sizes.size() - << " acl_tensor_list size " << acl_tensor_list.size(); - return kMCFailed; - } - - aclrtMemcpyKind kind = is_run_on_device_ ? ACL_MEMCPY_HOST_TO_HOST : ACL_MEMCPY_DEVICE_TO_HOST; - for (size_t i = 0; i < acl_tensor_list.size(); ++i) { - tensor_list->emplace_back(names[i], data_types[i], shapes[i], nullptr, mem_sizes[i]); - if (acl_tensor_list[i].cur_device_data == nullptr) { - // when run on device, cur_device_data is nullptr before first execute - continue; - } - auto ret = CALL_ASCEND_API(aclrtMemcpy, (*tensor_list)[i].MutableData(), (*tensor_list)[i].DataSize(), - acl_tensor_list[i].cur_device_data, acl_tensor_list[i].buffer_size, kind); - if (ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "Memcpy input " << i << " from " << (is_run_on_device_ ? "host" : "device") - << " to host failed, memory size " << acl_tensor_list[i].buffer_size; - return kMCFailed; - } - } - - return kSuccess; -} - -Status ModelProcess::PreInitModelResource() { - model_desc_ = CALL_ASCEND_API(aclmdlCreateDesc); - aclError acl_ret = CALL_ASCEND_API(aclmdlGetDesc, model_desc_, model_id_); - if (acl_ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "Read model desc failed"; - return kMCDeviceError; - } - Status ret = InitInputsBuffer(); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Create input buffer failed"; - return ret; - } - ret = InitOutputsBuffer(); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Create output buffer failed"; - return ret; - } - return kSuccess; -} - -Status ModelProcess::InitInputsBuffer() { - aclError ret; - size_t input_size = CALL_ASCEND_API(aclmdlGetNumInputs, model_desc_); - MS_LOG(INFO) << "input_size = " << input_size; - for (size_t i = 0; i < input_size; ++i) { - auto buffer_size = CALL_ASCEND_API(aclmdlGetInputSizeByIndex, model_desc_, i); - void *data_mem_buffer = nullptr; - if (!is_run_on_device_) { // need to copy input/output to/from device - ret = CALL_ASCEND_API(aclrtMalloc, &data_mem_buffer, buffer_size, ACL_MEM_MALLOC_NORMAL_ONLY); - if (ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "Malloc device input buffer failed , input size " << buffer_size; - return kMCDeviceError; - } - } - - aclmdlIODims dims; - ret = CALL_ASCEND_API(aclmdlGetInputDims, model_desc_, i, &dims); - if (ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "Get input shape failed"; - if (!is_run_on_device_) { - (void)CALL_ASCEND_API(aclrtFree, data_mem_buffer); - } - return kMCDeviceError; - } - aclDataType data_type = CALL_ASCEND_API(aclmdlGetInputDataType, model_desc_, i); - std::vector shape(dims.dims, dims.dims + dims.dimCount); - const char *input_name_char = CALL_ASCEND_API(aclmdlGetInputNameByIndex, model_desc_, i); - std::string input_name = (input_name_char != nullptr) ? input_name_char : std::string(); - if (input_name.empty()) { - MS_LOG(WARNING) << "Get name of input " << i << " failed."; - } - MS_LOG(INFO) << "Name of input " << i << " is " << input_name; - input_infos_.emplace_back( - AclTensorInfo{data_mem_buffer, data_mem_buffer, buffer_size, data_type, shape, input_name}); - } - MS_LOG(INFO) << "Create model inputs success"; - return kSuccess; -} - -Status ModelProcess::CreateDataBuffer(void **data_mem_buffer, size_t buffer_size, aclmdlDataset *dataset) const { - MS_EXCEPTION_IF_NULL(data_mem_buffer); - aclError ret; - auto free_data_buffer = [this](void *dataMemBuffer) { - if (!is_run_on_device_) { - (void)CALL_ASCEND_API(aclrtFree, dataMemBuffer); - } else { - (void)CALL_ASCEND_API(aclrtFreeHost, dataMemBuffer); - } - }; - - if (!is_run_on_device_) { - ret = CALL_ASCEND_API(aclrtMalloc, data_mem_buffer, buffer_size, ACL_MEM_MALLOC_NORMAL_ONLY); - if (ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "Malloc device buffer failed , buffer size " << buffer_size; - return kMCDeviceError; - } - } else { - ret = CALL_ASCEND_API(aclrtMallocHost, data_mem_buffer, buffer_size); - if (ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "Malloc device buffer failed , buffer size " << buffer_size; - return kMCDeviceError; - } - } - - auto data_buffer = CALL_ASCEND_API(aclCreateDataBuffer, *data_mem_buffer, buffer_size); - if (data_buffer == nullptr) { - MS_LOG(ERROR) << "Create Data Buffer failed"; - free_data_buffer(*data_mem_buffer); - return kMCDeviceError; - } - ret = CALL_ASCEND_API(aclmdlAddDatasetBuffer, dataset, data_buffer); - if (ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "add data buffer failed"; - free_data_buffer(*data_mem_buffer); - (void)CALL_ASCEND_API(aclDestroyDataBuffer, data_buffer); - return kMCDeviceError; - } - return kSuccess; -} - -Status ModelProcess::InitOutputsBuffer() { - aclError ret; - outputs_ = CALL_ASCEND_API(aclmdlCreateDataset); - if (outputs_ == nullptr) { - MS_LOG(ERROR) << "Create input dataset failed"; - return kMCDeviceError; - } - size_t output_size = CALL_ASCEND_API(aclmdlGetNumOutputs, model_desc_); - MS_LOG(INFO) << "output_size = " << output_size; - for (size_t i = 0; i < output_size; ++i) { - auto buffer_size = CALL_ASCEND_API(aclmdlGetOutputSizeByIndex, model_desc_, i); - - void *data_mem_buffer = nullptr; - if (CreateDataBuffer(&data_mem_buffer, buffer_size, outputs_) != kSuccess) { - MS_LOG(ERROR) << "add output data buffer failed, buffer size " << buffer_size; - return kMCDeviceError; - } - aclmdlIODims dims; - ret = CALL_ASCEND_API(aclmdlGetOutputDims, model_desc_, i, &dims); - if (ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "Get input shape failed"; - if (!is_run_on_device_) { - (void)CALL_ASCEND_API(aclrtFree, data_mem_buffer); - } else { - (void)CALL_ASCEND_API(aclrtFreeHost, data_mem_buffer); - } - return kMCDeviceError; - } - aclDataType data_type = CALL_ASCEND_API(aclmdlGetOutputDataType, model_desc_, i); - std::vector shape(dims.dims, dims.dims + dims.dimCount); - const char *output_name_char = CALL_ASCEND_API(aclmdlGetOutputNameByIndex, model_desc_, i); - std::string output_name = (output_name_char != nullptr) ? output_name_char : std::string(); - if (output_name.empty()) { - MS_LOG(WARNING) << "Get name of output " << i << " failed."; - } - MS_LOG(INFO) << "Name of input " << i << " is " << output_name; - output_infos_.emplace_back( - AclTensorInfo{data_mem_buffer, data_mem_buffer, buffer_size, data_type, shape, output_name}); - } - MS_LOG(INFO) << "Create model output success"; - return kSuccess; -} - -void ModelProcess::DestroyInputsDataset() { - if (inputs_ == nullptr) { - return; - } - for (size_t i = 0; i < CALL_ASCEND_API(aclmdlGetDatasetNumBuffers, inputs_); i++) { - auto dataBuffer = CALL_ASCEND_API(aclmdlGetDatasetBuffer, inputs_, i); - (void)CALL_ASCEND_API(aclDestroyDataBuffer, dataBuffer); - } - (void)CALL_ASCEND_API(aclmdlDestroyDataset, inputs_); - inputs_ = nullptr; -} - -void ModelProcess::DestroyInputsDataMem() { - if (!is_run_on_device_) { - for (const auto &item : input_infos_) { - (void)CALL_ASCEND_API(aclrtFree, item.device_data); - } - } - input_infos_.clear(); -} - -void ModelProcess::DestroyInputsBuffer() { - DestroyInputsDataMem(); - DestroyInputsDataset(); -} - -void ModelProcess::DestroyOutputsBuffer() { - for (const auto &item : output_infos_) { - if (!is_run_on_device_) { - (void)CALL_ASCEND_API(aclrtFree, item.device_data); - } else { - (void)CALL_ASCEND_API(aclrtFreeHost, item.device_data); - } - } - output_infos_.clear(); - - if (outputs_ == nullptr) { - return; - } - for (size_t i = 0; i < CALL_ASCEND_API(aclmdlGetDatasetNumBuffers, outputs_); i++) { - auto dataBuffer = CALL_ASCEND_API(aclmdlGetDatasetBuffer, outputs_, i); - (void)CALL_ASCEND_API(aclDestroyDataBuffer, dataBuffer); - } - (void)CALL_ASCEND_API(aclmdlDestroyDataset, outputs_); - outputs_ = nullptr; -} - -Status ModelProcess::UnLoad() { - auto ret = CALL_ASCEND_API(aclmdlUnload, model_id_); - if (ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "Unload model failed"; - return kMCDeviceError; - } - if (model_desc_ != nullptr) { - ret = CALL_ASCEND_API(aclmdlDestroyDesc, model_desc_); - if (ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "Unload model failed"; - return kMCDeviceError; - } - model_desc_ = nullptr; - } - DestroyInputsBuffer(); - DestroyOutputsBuffer(); - MS_LOG(INFO) << "End unload model " << model_id_; - return kSuccess; -} - -size_t ModelProcess::GetDynamicDims(const std::vector &inputs) const { - size_t max_num = 0; - for (auto input : inputs) { - size_t cur_num = LongToSize(std::count(input.dims.begin(), input.dims.end(), -1)); - if (cur_num > max_num) { - max_num = cur_num; - } - } - return max_num; -} - -Status ModelProcess::SetBatchSize(const std::vector &inputs) { - size_t index; - aclError ret; - for (size_t i = 0; i < inputs.size(); i++) { - input_infos_[i].buffer_size = inputs[i].DataSize(); - } - auto *p = static_cast(inputs[inputs.size() - 1].Data().get()); - MS_EXCEPTION_IF_NULL(p); - size_t dynamicBatchSize = FloatToSize(p[0]); - ret = CALL_ASCEND_API(aclmdlGetInputIndexByName, model_desc_, ACL_DYNAMIC_TENSOR_NAME, &index); - if (ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "get index failed"; - return kMCDeviceError; - } - ret = CALL_ASCEND_API(aclmdlSetDynamicBatchSize, model_id_, inputs_, index, dynamicBatchSize); - if (ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "dynamic batch set failed, modelId is " << model_id_; - return kMCDeviceError; - } - return kSuccess; -} - -Status ModelProcess::CheckAndInitInput(const std::vector &inputs) { - aclError ret; - inputs_ = CALL_ASCEND_API(aclmdlCreateDataset); - constexpr size_t dynamic_batch_size = 1; - constexpr size_t dynamic_image_size = 2; - size_t dynamic_nums = GetDynamicDims(input_infos_); - // check inputs - if (inputs.size() != input_infos_.size()) { - MS_LOG(ERROR) << "Inputs count not match, required count " << input_infos_.size() << ", given count " - << inputs.size(); - return kMCInvalidInput; - } - if (dynamic_nums == 0) { - for (size_t i = 0; i < input_infos_.size(); ++i) { - if (inputs[i].Shape() != input_infos_[i].dims) { - MS_LOG(INFO) << "Note: input " << i << " shape not match, required " << ShapeToString(input_infos_[i].dims) - << ", given " << ShapeToString(inputs[i].Shape()); - } - if (inputs[i].DataType() != TransToApiType(input_infos_[i].data_type)) { - MS_LOG(INFO) << "Note: input " << i << " data type not match, required " - << TransToApiType(input_infos_[i].data_type) << ", given " << inputs[i].DataType(); - } - if (inputs[i].DataSize() != input_infos_[i].buffer_size) { - MS_LOG(ERROR) << "Input " << i << " data size not match, required size " << input_infos_[i].buffer_size - << ", given count " << inputs[i].DataSize(); - return kMCInvalidInput; - } - } - } - // copy inputs - for (size_t i = 0; i < input_infos_.size(); ++i) { - auto &info = input_infos_[i]; - auto input = inputs[i]; - void *data = input.MutableData(); - void *input_buffer = nullptr; - if (!is_run_on_device_) { - if (input.IsDevice()) { - info.cur_device_data = data; - input_buffer = info.cur_device_data; - } else { - info.cur_device_data = info.device_data; - ret = CALL_ASCEND_API(aclrtMemcpy, info.cur_device_data, info.buffer_size, data, input.DataSize(), - ACL_MEMCPY_HOST_TO_DEVICE); - if (ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "Acl memcpy input " << i << " data to device failed, buffer size " << input.DataSize(); - return kMCDeviceError; - } - input_buffer = info.cur_device_data; - } - } else { - input_buffer = data; - } - auto data_buffer = CALL_ASCEND_API(aclCreateDataBuffer, input_buffer, info.buffer_size); - if (data_buffer == nullptr) { - MS_LOG(ERROR) << "Create Data Buffer failed"; - return kMCDeviceError; - } - ret = CALL_ASCEND_API(aclmdlAddDatasetBuffer, inputs_, data_buffer); - if (ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "add data buffer failed"; - (void)CALL_ASCEND_API(aclDestroyDataBuffer, data_buffer); - return kMCDeviceError; - } - } - if (dynamic_nums == dynamic_batch_size) { - if (SetBatchSize(inputs) != kSuccess) { - MS_LOG(ERROR) << "failed to convert dynamic batch size"; - return kMCDeviceError; - } - if (ResetOutputSize() != kSuccess) { - MS_LOG(ERROR) << "reset output size failed"; - return kMCDeviceError; - } - } else if (dynamic_nums == dynamic_image_size) { - MS_LOG(ERROR) << "only dynamic batch size is supported"; - return kMCInvalidInput; - } - return kSuccess; -} - -Status ModelProcess::ResetOutputSize() { - aclDataType output_type; - aclError ret; - size_t output_size = CALL_ASCEND_API(aclmdlGetNumOutputs, model_desc_); - for (size_t index = 0; index < output_size; index++) { - int64_t dims = 1; - struct aclmdlIODims output_dims; - ret = CALL_ASCEND_API(aclmdlGetCurOutputDims, model_desc_, index, &output_dims); - if (ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "get output dim error."; - return kMCDeviceError; - } - for (size_t i = 0; i < output_dims.dimCount; i++) { - dims *= output_dims.dims[i]; - } - output_type = CALL_ASCEND_API(aclmdlGetOutputDataType, model_desc_, index); - output_infos_[index].buffer_size = LongToSize(dims) * CALL_ASCEND_API(aclDataTypeSize, output_type); - } - return kSuccess; -} - -Status ModelProcess::PredictFromHost(const std::vector &inputs, std::vector *outputs) { - MS_EXCEPTION_IF_NULL(outputs); - aclError acl_ret; - Status ret = CheckAndInitInput(inputs); - if (ret != kSuccess) { - MS_LOG(ERROR) << "check or init input failed"; - DestroyInputsDataset(); - return ret; // forward status error - } - - struct timeval start_time; - struct timeval end_time; - (void)gettimeofday(&start_time, nullptr); - acl_ret = CALL_ASCEND_API(aclmdlExecute, model_id_, inputs_, outputs_); - (void)gettimeofday(&end_time, nullptr); - constexpr uint64_t kUSecondInSecond = 1000000; - uint64_t cost = - (kUSecondInSecond * static_cast(end_time.tv_sec) + static_cast(end_time.tv_usec)) - - (kUSecondInSecond * static_cast(start_time.tv_sec) + static_cast(start_time.tv_usec)); - MS_LOG(INFO) << "Model execute in " << cost << " us"; - - DestroyInputsDataset(); - if (acl_ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "Execute Model Failed"; - return kMCDeviceError; - } - ret = BuildOutputs(outputs); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Build outputs failed"; - return ret; - } - MS_LOG(INFO) << "Execute model success"; - return kSuccess; -} - -Status ModelProcess::BuildOutputs(std::vector *outputs) { - MS_EXCEPTION_IF_NULL(outputs); - // copy outputs - outputs->clear(); - auto inner_outputs = GetOutputs(); - if (inner_outputs.size() != output_infos_.size()) { - MS_LOG(ERROR) << "Invalid inner outputs size " << inner_outputs.size() << " do not match device output infos size " - << output_infos_.size(); - return kMCFailed; - } - (*outputs) = inner_outputs; - return kSuccess; -} - -std::vector ModelProcess::GetInputs() { - Status ret = ConstructTensors(input_infos_, &input_tensors_); - if (ret != kSuccess) { - MS_LOG(ERROR) << "ConstructTensors failed."; - input_tensors_.clear(); - } - - return input_tensors_; -} - -std::vector ModelProcess::GetOutputs() { - Status ret = ConstructTensors(output_infos_, &output_tensors_); - if (ret != kSuccess) { - MS_LOG(ERROR) << "ConstructTensors failed."; - output_tensors_.clear(); - } - - return output_tensors_; -} -} // namespace mindspore diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/acl/model_process.h b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/acl/model_process.h deleted file mode 100644 index 5e18287642b6ba1e820ef88a9ccba0bda6d55b2e..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/acl/model_process.h +++ /dev/null @@ -1,91 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_CXXAPI_GRAPH_ACL_MODEL_PROCESS_H -#define MINDSPORE_CCSRC_CXXAPI_GRAPH_ACL_MODEL_PROCESS_H -#include -#include -#include -#include "acl/acl.h" -#include "acl/acl_mdl.h" -#include "acl/acl_rt.h" -#include "include/api/status.h" -#include "include/api/types.h" - -namespace mindspore { -struct AclTensorInfo { - void *cur_device_data; - void *device_data; - size_t buffer_size; - aclDataType data_type; - std::vector dims; - std::string name; -}; - -class ModelProcess { - public: - ModelProcess() - : model_id_(0xffffffff), - is_run_on_device_(false), - model_desc_(nullptr), - inputs_(nullptr), - outputs_(nullptr), - input_infos_(), - output_infos_() {} - ~ModelProcess() {} - - Status UnLoad(); - Status PredictFromHost(const std::vector &inputs, std::vector *outputs); - Status PreInitModelResource(); - std::vector GetInputs(); - std::vector GetOutputs(); - - // override this method to avoid request/reply data copy - void SetIsDevice(bool is_device) { is_run_on_device_ = is_device; } - - void set_model_id(uint32_t model_id) { model_id_ = model_id; } - uint32_t model_id() const { return model_id_; } - - private: - Status CreateDataBuffer(void **data_mem_buffer, size_t buffer_size, aclmdlDataset *dataset) const; - Status CheckAndInitInput(const std::vector &inputs); - Status ConstructTensors(const std::vector &acl_tensor_list, std::vector *tensor_list) const; - Status BuildOutputs(std::vector *outputs); - Status SetBatchSize(const std::vector &inputs); - Status InitInputsBuffer(); - Status InitOutputsBuffer(); - Status ResetOutputSize(); - - void DestroyInputsDataset(); - void DestroyInputsDataMem(); - void DestroyInputsBuffer(); - void DestroyOutputsBuffer(); - - uint32_t model_id_; - // if run one device(AICPU), there is no need to alloc device memory and copy inputs to(/outputs from) device - bool is_run_on_device_; - aclmdlDesc *model_desc_; - aclmdlDataset *inputs_; - aclmdlDataset *outputs_; - std::vector input_infos_; - std::vector output_infos_; - std::vector input_tensors_; - std::vector output_tensors_; - size_t GetDynamicDims(const std::vector &inputs) const; -}; -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_CXXAPI_GRAPH_ACL_MODEL_PROCESS_H diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/ascend/ascend_graph_impl.cc b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/ascend/ascend_graph_impl.cc deleted file mode 100644 index 70009505ae9fb7e5be6e238ad679cf725c42c801..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/ascend/ascend_graph_impl.cc +++ /dev/null @@ -1,464 +0,0 @@ -/** - * Copyright 2020-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "cxx_api/graph/ascend/ascend_graph_impl.h" -#include -#include "include/api/context.h" -#include "cxx_api/factory.h" -#include "cxx_api/akg_kernel_register.h" -#include "cxx_api/utils.h" -#include "cxx_api/acl_utils.h" -#include "src/common/log_adapter.h" -#include "base/base_ref_utils.h" -#include "backend/common/session/executor_manager.h" -#include "runtime/device/kernel_runtime_manager.h" -#include "include/common/utils/python_adapter.h" -#include "backend/common/session/session_basic.h" -#include "runtime/hardware/device_context_manager.h" -#include "include/backend/distributed/init.h" -#include "plugin/res_manager/ascend/symbol_interface/acl_rt_symbol.h" -#include "plugin/res_manager/ascend/symbol_interface/symbol_utils.h" - -namespace mindspore { -API_GRAPH_REG(kAscendDevice, AscendGraphImpl); -namespace { -constexpr auto kHcclEnable = "MS_ENABLE_HCCL"; -constexpr auto kHcclGroupFile = "PARA_GROUP_FILE"; - -void InitHccl() { - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - mindspore::python_adapter::set_python_env_flag(true); - // init hccl from distributed - if (!mindspore::distributed::Initialize()) { - MS_LOG(EXCEPTION) << "InitHccl failed."; - } - uint32_t device_id = ms_context->get_param(MS_CTX_DEVICE_ID); - if (ms_context->backend_policy() == "ms") { - auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id); - MS_EXCEPTION_IF_NULL(runtime_instance); -#ifndef ENABLE_SECURITY - runtime_instance->PreInit(); -#endif - const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext( - {kAscendDevice, ms_context->get_param(MS_CTX_DEVICE_ID)}); - MS_EXCEPTION_IF_NULL(device_context); - MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface()); - (void)device_context->GetDeprecatedInterface()->OpenTsd(ms_context); - - if (!runtime_instance->Init()) { - MS_LOG(EXCEPTION) << "Runtime init failed."; - } - } -} -} // namespace -AscendGraphImpl::AscendGraphImpl() : device_type_("Ascend"), context_(nullptr) {} - -AscendGraphImpl::~AscendGraphImpl() {} - -Status AscendGraphImpl::InitEnv() { - MS_LOG(INFO) << "Start to init env."; - env_guard_ = MsEnvGuard::GetEnv(device_id_); - if (env_guard_ == nullptr) { - MS_LOG(ERROR) << "Env init failed."; - return kMCDeviceError; - } - - backend_ = std::make_shared(); - if (backend_ == nullptr) { - MS_LOG(ERROR) << "DeviceContext create failed!, please make sure target device:" << kAscendDevice - << " is available."; - return kMCFailed; - } - - MS_LOG(INFO) << "InitEnv success."; - return kSuccess; -} - -Status AscendGraphImpl::CompileGraph(const std::shared_ptr &func_graph) { - MS_ASSERT(backend_ != nullptr); - try { - MS_EXCEPTION_IF_NULL(func_graph); - // perpare func graph - auto manager = MakeManager(); - MS_EXCEPTION_IF_NULL(manager); - manager->AddFuncGraph(func_graph); - func_graph->set_manager(manager); - BackendJitConfig &backend_jit_config = backend::BackendJitConfig::ParseBackendJitConfig(); - graph_id_ = backend_->Build(func_graph, backend_jit_config); - kernel_graph_ = backend_->GetGraphById(graph_id_); - return kSuccess; - } catch (std::exception &e) { - MS_LOG(ERROR) << "CompileGraph failed: " << e.what(); - return kMCFailed; - } -} - -std::vector AscendGraphImpl::RunGraph(const std::vector &inputs) { - try { - VectorRef outputs; - backend_->Run(graph_id_, GraphImpl::GenerateInputsRef(inputs, func_graph_.lock()), &outputs); - return TransformVectorRefToMultiTensor(outputs); - } catch (std::exception &e) { - MS_LOG(ERROR) << "RunGraph failed: " << e.what(); - return std::vector(); - } -} - -Status AscendGraphImpl::ExecuteModel(const std::vector &request, std::vector *reply) { - MS_EXCEPTION_IF_NULL(reply); - if (context_ == nullptr) { - MS_LOG(ERROR) << "rtCtx is nullptr"; - return kMCDeviceError; - } - auto rt_ret = CALL_ASCEND_API(aclrtSetCurrentContext, context_); - if (rt_ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "Set Ascend rtCtx failed"; - return kMCDeviceError; - } - - vector inputs; - for (size_t i = 0; i < request.size(); i++) { - auto item = request[i]; - auto input = inputs_info_[i]; - MS_EXCEPTION_IF_NULL(item); - MS_EXCEPTION_IF_NULL(input); - if (input->Size() != item.DataSize()) { - MS_LOG(ERROR) << "Input " << i << " data size " << item.DataSize() << " not match model input data size " - << input->Size(); - return kMCInvalidInput; - } - auto ret = memcpy_s(input->data_c(), input->Size(), item.MutableData(), item.DataSize()); - if (ret != EOK) { - MS_LOG(ERROR) << "MSTensor copy failed"; - return kMCFailed; - } - inputs.push_back(input); - } - last_inputs_ = inputs; - std::vector outputs = RunGraph(inputs); - if (outputs.empty()) { - MS_LOG(ERROR) << "Execute Model Failed"; - return kMCFailed; - } - for (const auto &out : outputs) { - MS_EXCEPTION_IF_NULL(out); - out->data_sync(); - } - last_outputs_ = outputs; - reply->clear(); - *reply = GetOutputs(); - return kSuccess; -} - -std::vector AscendGraphImpl::GetInputs() { - if (!load_flag_) { - Status ret = Load(device_id_); - if (ret != kSuccess) { - MS_LOG(ERROR) << "PrepareModel failed."; - return {}; - } - } - - std::vector result(inputs_info_.size()); - for (size_t i = 0; i < inputs_info_.size(); ++i) { - auto &tensor = inputs_info_[i]; - MS_EXCEPTION_IF_NULL(tensor); - void *data = nullptr; - size_t data_size = tensor->Size(); - if (i < last_inputs_.size()) { - MS_EXCEPTION_IF_NULL(last_inputs_[i]); - data = last_inputs_[i]->data_c(); - data_size = last_inputs_[i]->Size(); - } - result[i] = - MSTensor(input_names_[i], static_cast(tensor->data_type()), tensor->shape(), data, data_size); - } - return result; -} - -std::vector AscendGraphImpl::GetOutputs() { - if (!load_flag_) { - Status ret = Load(device_id_); - if (ret != kSuccess) { - MS_LOG(ERROR) << "PrepareModel failed."; - return {}; - } - } - - std::vector result(outputs_info_.size()); - for (size_t i = 0; i < outputs_info_.size(); ++i) { - auto &tensor = outputs_info_[i]; - MS_EXCEPTION_IF_NULL(tensor); - void *data = nullptr; - size_t data_size = tensor->Size(); - if (i < last_outputs_.size()) { - MS_EXCEPTION_IF_NULL(last_outputs_[i]); - data = last_outputs_[i]->data_c(); - data_size = last_outputs_[i]->Size(); - } - result[i] = - MSTensor(output_names_[i], static_cast(tensor->data_type()), tensor->shape(), data, data_size); - } - return result; -} - -Status AscendGraphImpl::Load(uint32_t device_id) { - // check graph type - if (graph_->ModelType() != ModelType::kMindIR) { - MS_LOG(ERROR) << "Unsupported model type " << graph_->ModelType(); - return kMCInvalidInput; - } - - const auto &graph_data = GraphImpl::MutableGraphData(); - MS_EXCEPTION_IF_NULL(graph_data); - auto func_graph = graph_data->GetFuncGraph(); - func_graph_ = func_graph; - - // init - device_id_ = device_id; - Status ret = InitEnv(); - if (ret != kSuccess) { - MS_LOG(ERROR) << "InitEnv failed."; - return ret; - } - - // load model - if (!load_flag_) { - ret = CompileGraph(func_graph); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Compile graph model failed"; - return ret; - } - auto kg = kernel_graph_.lock(); - MS_EXCEPTION_IF_NULL(backend_); - MS_EXCEPTION_IF_NULL(kg); - GraphImpl::GetModelInputsInfo(kg, &inputs_info_, &input_names_); - GraphImpl::GetModelOutputsInfo(kg, &outputs_info_, &output_names_); - if (inputs_info_.size() != input_names_.size()) { - MS_LOG(ERROR) << "Get model inputs info failed"; - return kMCInvalidInput; - } - if (outputs_info_.size() != output_names_.size()) { - MS_LOG(ERROR) << "Get model outputs info failed"; - return kMCInvalidInput; - } - - // save d context - auto rt_ret = CALL_ASCEND_API(aclrtGetCurrentContext, &context_); - if (rt_ret != ACL_SUCCESS || context_ == nullptr) { - MS_LOG(ERROR) << "the ascend device context is null"; - return kMCDeviceError; - } - - MS_LOG(INFO) << "Load model success"; - load_flag_ = true; - } - - auto rt_ret = CALL_ASCEND_API(aclrtSetCurrentContext, context_); - if (rt_ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "Set the ascend device context failed"; - return kMCDeviceError; - } - - return kSuccess; -} - -Status AscendGraphImpl::Run(const std::vector &inputs, std::vector *outputs) { - MS_EXCEPTION_IF_NULL(outputs); - if (!load_flag_) { - Status ret = Load(device_id_); - if (ret != kSuccess) { - MS_LOG(ERROR) << "PrepareModel failed."; - return ret; - } - } - - if (inputs.size() != inputs_info_.size()) { - MS_LOG(ERROR) << "inputs count not match, required count " << inputs_info_.size() << ", given count " - << inputs.size(); - return kMCInvalidInput; - } - - for (size_t i = 0; i < inputs_info_.size(); ++i) { - if (inputs[i].DataSize() != inputs_info_[i]->Size()) { - MS_LOG(ERROR) << "input " << i << " data size not match, required size " << inputs_info_[i]->Size() - << ", given count " << inputs[i].DataSize(); - return kMCInvalidInput; - } - } - - Status ret = ExecuteModel(inputs, outputs); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Execute Model Failed"; - return ret; - } - if (outputs_info_.size() != outputs->size()) { - MS_LOG(ERROR) << "Predict output size " << outputs->size() << " not match output size got from model info " - << outputs_info_.size(); - return kMCFailed; - } - - return kSuccess; -} - -AscendGraphImpl::MsEnvGuard::MsEnvGuard(uint32_t device_id) : device_id_(device_id) { - MS_LOG(INFO) << "Start to init device " << device_id; - RegAllOp(); - auto ms_context = MsContext::GetInstance(); - if (ms_context == nullptr) { - MS_LOG(ERROR) << "Get Context failed!"; - errno_ = kMCFailed; - return; - } - - auto env_hccl_mode = common::GetEnv(kHcclEnable); - if (!env_hccl_mode.empty() && env_hccl_mode != std::to_string(0)) { - MS_LOG(INFO) << "Enable hccl parallel mode."; - ms_context->set_param(MS_CTX_ENABLE_HCCL, true); - } - - ms_context->set_param(MS_CTX_EXECUTION_MODE, kGraphMode); - ms_context->set_param_inner(MS_CTX_DEVICE_ID, device_id_); - ms_context->set_param(MS_CTX_DEVICE_TARGET, kAscendDevice); - ms_context->set_param(MS_CTX_IS_MULTI_GRAPH_SINK, true); - - if (ms_context->get_param(MS_CTX_ENABLE_HCCL)) { - InitHccl(); - auto para_group_file = common::GetEnv(kHcclGroupFile); - if (para_group_file.empty()) { - MS_LOG(INFO) << "Cannot get Env " << kHcclGroupFile << ", skip."; - } else { - MS_LOG(INFO) << "Get env " << kHcclGroupFile << " success: " << para_group_file; - if (!CreateGroupsByCkptFile(para_group_file)) { - MS_LOG(ERROR) << "CreateGroupsByCkptFile failed."; - errno_ = kMCFailed; - return; - } - } - } else { - auto ret = CALL_ASCEND_API(aclrtSetDevice, static_cast(device_id_)); - if (ret != ACL_SUCCESS) { - MS_LOG(EXCEPTION) << "Device " << device_id_ << " call aclrtSetDevice failed, ret[" << static_cast(ret) - << "]"; - } - } - - MS_LOG(INFO) << "Device " << device_id << " init env success."; - errno_ = kSuccess; -} - -AscendGraphImpl::MsEnvGuard::~MsEnvGuard() { - MS_LOG(INFO) << "Start finalize device " << device_id_; - try { - session::ExecutorManager::Instance().Clear(); - device::KernelRuntimeManager::Instance().ClearRuntimeResource(); - - auto ms_context = MsContext::GetInstance(); - if (ms_context == nullptr) { - MS_LOG(ERROR) << "Get Context failed!"; - return; - } - - if (ms_context->get_param(MS_CTX_ENABLE_HCCL)) { - PythonEnvGuard guard; - const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext( - {kAscendDevice, ms_context->get_param(MS_CTX_DEVICE_ID)}); - MS_EXCEPTION_IF_NULL(device_context); - MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface()); - if (!device_context->GetDeprecatedInterface()->CloseTsd(ms_context, false)) { - MS_LOG(ERROR) << "CloseTsd failed!"; - return; - } - } else { - auto ret = CALL_ASCEND_API(aclrtResetDevice, static_cast(device_id_)); - if (ret != ACL_SUCCESS) { - MS_LOG(ERROR) << "Device " << device_id_ << " call aclrtResetDevice failed, ret[" << static_cast(ret) - << "]"; - return; - } - } - } catch (const std::exception &e) { - MS_LOG(ERROR) << "AscendGraphImpl MsEnvGuard destructor run failed, error message : " << e.what(); - } catch (...) { - MS_LOG(ERROR) << "AscendGraphImpl MsEnvGuard destructor run failed, unknown error occurred."; - } - MS_LOG(INFO) << "End finalize device " << device_id_; -} - -std::shared_ptr AscendGraphImpl::MsEnvGuard::GetEnv(uint32_t device_id) { - std::shared_ptr acl_env; - std::lock_guard lock(global_ms_env_mutex_); - auto iter = global_ms_env_.find(device_id); - if (iter != global_ms_env_.end()) { - acl_env = iter->second.lock(); - } - - if (acl_env != nullptr) { - MS_LOG(INFO) << "Env has been initialized, skip."; - return acl_env; - } - - acl_env = std::make_shared(device_id); - if (acl_env->GetErrno() != kSuccess) { - MS_LOG(ERROR) << "Init ascend env Failed"; - return nullptr; - } - - global_ms_env_.emplace(device_id, acl_env); - MS_LOG(INFO) << "Env init success"; - return acl_env; -} - -bool AscendGraphImpl::CheckDeviceSupport(mindspore::DeviceType device_type) { - // for Ascend, only support kAscend and kAscend910 - if (device_type != kAscend && device_type != kAscend910) { - return false; - } - return IsAscend910Soc(); -} - -std::map> AscendGraphImpl::MsEnvGuard::global_ms_env_; -std::mutex AscendGraphImpl::MsEnvGuard::global_ms_env_mutex_; - -PythonEnvGuard::PythonEnvGuard() : origin_init_status_(PythonIsInited()) { InitPython(); } - -PythonEnvGuard::~PythonEnvGuard() { - // finalize when init by this - try { - if (!origin_init_status_) { - FinalizePython(); - } - } catch (const std::exception &e) { - MS_LOG(ERROR) << "PythonEnvGuard destructor run failed, error message : " << e.what(); - } catch (...) { - MS_LOG(ERROR) << "PythonEnvGuard destructor run failed, unknown error occurred."; - } -} - -bool PythonEnvGuard::PythonIsInited() const { return Py_IsInitialized() != 0; } - -void PythonEnvGuard::InitPython() const { - if (!PythonIsInited()) { - Py_Initialize(); - } -} - -void PythonEnvGuard::FinalizePython() const { - if (PythonIsInited()) { - Py_Finalize(); - } -} -} // namespace mindspore diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/ascend/ascend_graph_impl.h b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/ascend/ascend_graph_impl.h deleted file mode 100644 index f7b515bffcd0e77dbeb0ef9f888029aa400c3ff9..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/ascend/ascend_graph_impl.h +++ /dev/null @@ -1,83 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_MS_ASCEND_GRAPH_IMPL_H -#define MINDSPORE_CCSRC_CXX_API_GRAPH_MS_ASCEND_GRAPH_IMPL_H -#include -#include -#include -#include -#include -#include "include/api/status.h" -#include "include/api/graph.h" -#include "cxx_api/graph/graph_impl.h" -#include "ir/anf.h" -#include "cxx_api/model/model_impl.h" -#include "acl/acl_rt.h" - -namespace mindspore { -class AscendGraphImpl : public GraphCell::GraphImpl { - public: - AscendGraphImpl(); - ~AscendGraphImpl() override; - - Status Run(const std::vector &inputs, std::vector *outputs) override; - Status Load(uint32_t device_id) override; - std::vector GetInputs() override; - std::vector GetOutputs() override; - bool CheckDeviceSupport(mindspore::DeviceType device_type) override; - - private: - class MsEnvGuard; - - Status InitEnv(); - Status CompileGraph(const std::shared_ptr &func_graph); - std::vector RunGraph(const std::vector &inputs); - Status ExecuteModel(const std::vector &request, std::vector *reply); - - std::string device_type_; - aclrtContext context_; - - std::shared_ptr env_guard_; -}; - -class AscendGraphImpl::MsEnvGuard { - public: - explicit MsEnvGuard(uint32_t device_id); - ~MsEnvGuard(); - Status GetErrno() const { return errno_; } - static std::shared_ptr GetEnv(uint32_t device_id); - - private: - static std::map> global_ms_env_; - static std::mutex global_ms_env_mutex_; - - Status errno_; - uint32_t device_id_; -}; - -class PythonEnvGuard { - public: - PythonEnvGuard(); - ~PythonEnvGuard(); - - private: - bool PythonIsInited() const; - void InitPython() const; - void FinalizePython() const; - bool origin_init_status_; -}; -} // namespace mindspore -#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_ASCEND_GRAPH_IMPL_H diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/gpu/gpu_graph_impl.cc b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/gpu/gpu_graph_impl.cc deleted file mode 100644 index 52682a572562290bd2d34a4a32bcdf0bf9921720..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/gpu/gpu_graph_impl.cc +++ /dev/null @@ -1,309 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "cxx_api/graph/gpu/gpu_graph_impl.h" -#include -#include "include/api/context.h" -#include "cxx_api/factory.h" -#include "cxx_api/akg_kernel_register.h" -#include "src/common/log_adapter.h" -#include "base/base_ref_utils.h" -#include "backend/common/session/session_basic.h" -#include "backend/common/session/executor_manager.h" -#include "runtime/device/kernel_runtime_manager.h" -#include "plugin/res_manager/gpu/device/cuda_driver.h" - -namespace mindspore { -API_GRAPH_REG(kGPUDevice, GPUGraphImpl); - -GPUGraphImpl::GPUGraphImpl() : init_flag_(false), set_device_id_flag_(false) {} - -Status GPUGraphImpl::InitEnv() { - if (init_flag_) { - MS_LOG(WARNING) << "Initialized again, return success."; - return kSuccess; - } - - // Register op implemented with AKG. - RegAllOp(); - auto ms_context = MsContext::GetInstance(); - if (ms_context == nullptr) { - MS_LOG(ERROR) << "Get Context failed!"; - return kMCFailed; - } - ms_context->set_param(MS_CTX_EXECUTION_MODE, kGraphMode); - ms_context->set_param_inner(MS_CTX_DEVICE_ID, device_id_); - ms_context->set_param(MS_CTX_DEVICE_TARGET, kGPUDevice); - - // Set device id for sync data to host as cudaSetDevice is thread level config. - bool ret = device::gpu::CudaDriver::SetDevice(UintToInt(device_id_)); - if (!ret) { - MS_LOG(ERROR) << "Failed to set device id:" << device_id_; - return kMCDeviceError; - } - - MS_EXCEPTION_IF_NULL(graph_context_); - auto &device_infos = graph_context_->MutableDeviceInfo(); - if (device_infos.size() != 1) { - return kMCDeviceError; - } - MS_EXCEPTION_IF_NULL(device_infos[0]); - auto gpu_info = device_infos[0]->Cast(); - if (gpu_info == nullptr) { - return kMCDeviceError; - } - ms_context->set_param(MS_CTX_ENABLE_INFER_OPT, true); - ms_context->set_param(MS_CTX_INFER_PRECISION_MODE, gpu_info->GetPrecisionMode()); - - backend_ = std::make_shared(); - if (backend_ == nullptr) { - MS_LOG(ERROR) << "DeviceContext create failed!, please make sure target device:" << kGpuInferenceDevice - << " is available."; - return kMCFailed; - } - - init_flag_ = true; - return kSuccess; -} - -Status GPUGraphImpl::FinalizeEnv() { - if (!init_flag_) { - MS_LOG(WARNING) << "Never initialize before, return success"; - return kSuccess; - } - - MS_LOG(INFO) << "Start finalize env"; - session::ExecutorManager::Instance().Clear(); - device::KernelRuntimeManager::Instance().ClearRuntimeResource(); - - init_flag_ = false; - MS_LOG(INFO) << "End finalize env"; - return kSuccess; -} - -Status GPUGraphImpl::Load(uint32_t device_id) { - // check graph type - MS_EXCEPTION_IF_NULL(graph_); - if (graph_->ModelType() != ModelType::kMindIR) { - MS_LOG(ERROR) << "Unsupported model type " << graph_->ModelType(); - return kMCInvalidInput; - } - - const auto &graph_data = GraphImpl::MutableGraphData(); - MS_EXCEPTION_IF_NULL(graph_data); - auto func_graph = graph_data->GetFuncGraph(); - func_graph_ = func_graph; - - // init - device_id_ = device_id; - Status ret = InitEnv(); - if (ret != kSuccess) { - MS_LOG(ERROR) << "InitEnv failed."; - return kMCDeviceError; - } - - ret = CompileGraph(func_graph); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Compile graph model failed"; - return kMCFailed; - } - auto kg = kernel_graph_.lock(); - MS_EXCEPTION_IF_NULL(backend_); - MS_EXCEPTION_IF_NULL(kg); - GraphImpl::GetModelInputsInfo(kg, &inputs_info_, &input_names_); - GraphImpl::GetModelOutputsInfo(kg, &outputs_info_, &output_names_); - if (inputs_info_.empty() || inputs_info_.size() != input_names_.size()) { - MS_LOG(ERROR) << "Get model inputs info failed"; - return kMCInvalidInput; - } - if (outputs_info_.empty() || outputs_info_.size() != output_names_.size()) { - MS_LOG(ERROR) << "Get model outputs info failed"; - return kMCInvalidInput; - } - load_flag_ = true; - return kSuccess; -} - -Status GPUGraphImpl::CompileGraph(const std::shared_ptr &func_graph) { - MS_ASSERT(backend_ != nullptr); - try { - MS_EXCEPTION_IF_NULL(func_graph); - // prepare func graph - auto manager = MakeManager(); - MS_EXCEPTION_IF_NULL(manager); - manager->AddFuncGraph(func_graph); - func_graph->set_manager(manager); - BackendJitConfig &backend_jit_config = backend::BackendJitConfig::ParseBackendJitConfig(); - graph_id_ = backend_->Build(func_graph, backend_jit_config); - kernel_graph_ = backend_->GetGraphById(graph_id_); - return kSuccess; - } catch (std::exception &e) { - MS_LOG(ERROR) << "CompileGraph failed: " << e.what(); - return kMCFailed; - } -} - -std::vector GPUGraphImpl::RunGraph(const std::vector &inputs) { - try { - VectorRef outputs; - backend_->Run(graph_id_, GraphImpl::GenerateInputsRef(inputs, func_graph_.lock()), &outputs); - return TransformVectorRefToMultiTensor(outputs); - } catch (std::exception &e) { - MS_LOG(ERROR) << "RunGraph failed: " << e.what(); - return std::vector(); - } -} - -Status GPUGraphImpl::ExecuteModel(const std::vector &request, std::vector *reply) { - MS_EXCEPTION_IF_NULL(reply); - - vector inputs; - for (size_t i = 0; i < request.size(); i++) { - auto &item = request[i]; - auto input = inputs_info_[i]; - MS_EXCEPTION_IF_NULL(input); - MS_EXCEPTION_IF_NULL(item); - if (input->Size() != item.DataSize()) { - MS_LOG(ERROR) << "Input " << i << " data size " << item.DataSize() << " not match model input data size " - << input->Size(); - return kMCInvalidInput; - } - auto ret = memcpy_s(input->data_c(), input->Size(), item.Data().get(), item.DataSize()); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Tensor copy failed"; - return kMCFailed; - } - inputs.push_back(input); - } - last_inputs_ = inputs; - std::vector outputs = RunGraph(inputs); - if (outputs.empty()) { - MS_LOG(ERROR) << "Execute Model Failed"; - return kMCFailed; - } - for (const auto &out : outputs) { - MS_EXCEPTION_IF_NULL(out); - out->data_sync(); - } - last_outputs_ = outputs; - reply->clear(); - *reply = GetOutputs(); - return kSuccess; -} - -Status GPUGraphImpl::Run(const std::vector &inputs, std::vector *outputs) { - MS_EXCEPTION_IF_NULL(outputs); - if (!load_flag_) { - Status ret = Load(device_id_); - if (ret != kSuccess) { - MS_LOG(ERROR) << "PrepareModel failed."; - return ret; - } - } - - // The `Load()` and `Run()` running in two threads. `Run()` always running in same thread. - // It should set device id once. - if (!set_device_id_flag_) { - bool ret = device::gpu::CudaDriver::SetDevice(UintToInt(device_id_)); - if (!ret) { - MS_LOG(ERROR) << "Failed to set device id:" << device_id_; - return kMCDeviceError; - } - set_device_id_flag_ = true; - } - - if (inputs.size() != inputs_info_.size()) { - MS_LOG(ERROR) << "inputs count not match, required count " << inputs_info_.size() << ", given count " - << inputs.size(); - return kMCInvalidInput; - } - - for (size_t i = 0; i < inputs_info_.size(); ++i) { - if (inputs[i].DataSize() != inputs_info_[i]->Size()) { - MS_LOG(ERROR) << "input " << i << " data size not match, required size " << inputs_info_[i]->Size() - << ", given count " << inputs[i].DataSize(); - return kMCInvalidInput; - } - } - if (ExecuteModel(inputs, outputs) != kSuccess) { - MS_LOG(ERROR) << "Execute Model Failed"; - return kMCFailed; - } - if (outputs_info_.size() != outputs->size()) { - MS_LOG(ERROR) << "Predict output size " << outputs->size() << " not match output size got from model info " - << outputs_info_.size(); - return kMCFailed; - } - - return kSuccess; -} - -std::vector GPUGraphImpl::GetInputs() { - if (!load_flag_) { - Status ret = Load(device_id_); - if (ret != kSuccess) { - MS_LOG(ERROR) << "PrepareModel failed."; - return {}; - } - } - - std::vector result(inputs_info_.size()); - for (size_t i = 0; i < inputs_info_.size(); ++i) { - auto &tensor = inputs_info_[i]; - MS_EXCEPTION_IF_NULL(tensor); - void *data = nullptr; - size_t data_size = tensor->Size(); - if (i < last_inputs_.size()) { - MS_EXCEPTION_IF_NULL(last_inputs_[i]); - data = last_inputs_[i]->data_c(); - data_size = last_inputs_[i]->Size(); - } - result[i] = - MSTensor(input_names_[i], static_cast(tensor->data_type()), tensor->shape(), data, data_size); - } - return result; -} - -std::vector GPUGraphImpl::GetOutputs() { - if (!load_flag_) { - Status ret = Load(device_id_); - if (ret != kSuccess) { - MS_LOG(ERROR) << "PrepareModel failed."; - return {}; - } - } - - std::vector result(outputs_info_.size()); - for (size_t i = 0; i < outputs_info_.size(); ++i) { - auto &tensor = outputs_info_[i]; - MS_EXCEPTION_IF_NULL(tensor); - void *data = nullptr; - size_t data_size = tensor->Size(); - if (i < last_outputs_.size()) { - MS_EXCEPTION_IF_NULL(last_outputs_[i]); - if (last_outputs_[i]->NeedSyncDeviceToHost()) { - last_outputs_[i]->data_sync(false); - } - data = last_outputs_[i]->data_c(); - data_size = last_outputs_[i]->Size(); - } - result[i] = - MSTensor(output_names_[i], static_cast(tensor->data_type()), tensor->shape(), data, data_size); - } - return result; -} - -bool GPUGraphImpl::CheckDeviceSupport(mindspore::DeviceType device_type) { return device_type == kGPU; } -} // namespace mindspore diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/gpu/gpu_graph_impl.h b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/gpu/gpu_graph_impl.h deleted file mode 100644 index f37b25232106bfa2ade65e381792f0929f91579f..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/gpu/gpu_graph_impl.h +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_MS_GPU_GRAPH_IMPL_H -#define MINDSPORE_CCSRC_CXX_API_GRAPH_MS_GPU_GRAPH_IMPL_H -#include -#include -#include -#include -#include "include/api/status.h" -#include "include/api/graph.h" -#include "cxx_api/graph/graph_impl.h" -#include "ir/anf.h" -#include "cxx_api/model/model_impl.h" - -namespace mindspore { -class GPUGraphImpl : public GraphCell::GraphImpl { - public: - GPUGraphImpl(); - ~GPUGraphImpl() override = default; - - Status Run(const std::vector &inputs, std::vector *outputs) override; - Status Load(uint32_t device_id) override; - std::vector GetInputs() override; - std::vector GetOutputs() override; - - bool CheckDeviceSupport(mindspore::DeviceType device_type) override; - - private: - Status InitEnv(); - Status FinalizeEnv(); - Status CompileGraph(const std::shared_ptr &func_graph); - Status CheckModelInputs(const std::vector &inputs) const; - std::vector RunGraph(const std::vector &inputs); - Status ExecuteModel(const std::vector &inputs, std::vector *outputs); - - std::string device_type_; - bool init_flag_; - bool set_device_id_flag_; - - // tensor-rt - uint32_t batch_size_{0}; - uint32_t workspace_size_{0}; -}; -} // namespace mindspore -#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_GPU_GRAPH_IMPL_H diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/graph_impl.h b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/graph_impl.h deleted file mode 100644 index 27703ed96bf35f18caa9fff4b3497349dc1a7c1c..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/graph/graph_impl.h +++ /dev/null @@ -1,177 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_IMPL_H -#define MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_IMPL_H -#include -#include -#include -#include -#include -#include "include/api/cell.h" -#include "include/api/graph.h" -#include "cxx_api/graph/graph_data.h" -#include "include/common/utils/utils.h" -#include "base/base_ref_utils.h" -#include "include/backend/kernel_graph.h" -#include "backend/common/session/session_basic.h" -#include "backend/ms_backend/ms_backend.h" -#include "backend/backend_manager/backend_jit_config.h" - -namespace mindspore { -class GraphCell::GraphImpl { - public: - GraphImpl() - : graph_(nullptr), - graph_context_(nullptr), - backend_(nullptr), - kernel_graph_(), - device_id_(0), - inputs_info_(), - outputs_info_(), - input_names_(), - output_names_(), - load_flag_(false) {} - virtual ~GraphImpl() = default; - - std::shared_ptr &MutableGraphData() const { return graph_->graph_data_; } - void SetGraph(const std::shared_ptr &graph) { graph_ = graph; } - void SetContext(const std::shared_ptr &context) { graph_context_ = context; } - VectorRef GenerateInputsRef(const std::vector &inputs, const FuncGraphPtr &func_graph) { - VectorRef results; - std::size_t size = inputs.size(); - for (std::size_t i = 0; i < size; i++) { - results.push_back(inputs[i]); - } - - MS_EXCEPTION_IF_NULL(func_graph); - std::vector graph_params = func_graph->parameters(); - std::size_t graph_params_size = graph_params.size(); - if (results.size() != graph_params_size) { - // Maybe some default parameter - for (std::size_t i = results.size(); i < graph_params_size; i++) { - MS_EXCEPTION_IF_NULL(graph_params[i]); - auto param_ptr = (graph_params[i])->cast_ptr(); - MS_EXCEPTION_IF_NULL(param_ptr); - if (!param_ptr->has_default()) { - MS_LOG(INTERNAL_EXCEPTION) << "Parameter[" << i << "] has no default param"; - } - if (!param_ptr->default_param()->isa()) { - MS_LOG(INTERNAL_EXCEPTION) << "Parameter[" << param_ptr->ToString() - << "] is not initialized, need to call `.init_data()`"; - } - results.push_back(param_ptr->default_param()); - } - } - return results; - } - - uint32_t GetRootGraphIdFromActorInfo(const std::string &actor_info) { - const std::string prefix = "kernel_graph_"; - auto pos = actor_info.find(prefix); - if (pos == std::string::npos) { - MS_LOG(INTERNAL_EXCEPTION) << "Cannot find prefix " << prefix << " from actor_info" << actor_info - << ", failed to get graph id."; - } - std::string first_num = ""; - for (size_t i = prefix.size(); i < actor_info.size(); ++i) { - if (actor_info[i] >= '0' && actor_info[i] <= '9') { - first_num.push_back(actor_info[i]); - } else { - break; - } - } - return std::stoul(first_num); - } - - void GetModelInputsInfo(const std::shared_ptr &kernel_graph, - std::vector *inputs, std::vector *inputs_name) { - MS_EXCEPTION_IF_NULL(kernel_graph); - MS_EXCEPTION_IF_NULL(inputs); - MS_EXCEPTION_IF_NULL(inputs_name); - auto kernel_graph_inputs = kernel_graph->inputs(); - // find parameters of graph inputs - for (size_t i = 0; i < kernel_graph_inputs.size(); ++i) { - MS_EXCEPTION_IF_NULL(kernel_graph_inputs[i]); - if (!kernel_graph_inputs[i]->isa()) { - MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter."; - continue; - } - auto parameter = kernel_graph_inputs[i]->cast(); - if (!common::AnfAlgo::IsParameterWeight(parameter)) { - auto input_shape = AnfAlgo::GetOutputDeviceShape(parameter, 0); - auto kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(parameter); - MS_EXCEPTION_IF_NULL(kernel_build_info); - auto data_type = kernel_build_info->GetOutputDeviceType(0); - auto ms_tensor = std::make_shared(data_type, input_shape); - inputs->push_back(ms_tensor); - inputs_name->push_back(parameter->name()); - } - } - } - - void GetModelOutputsInfo(const std::shared_ptr &kernel_graph, - std::vector *outputs, std::vector *output_names) { - MS_EXCEPTION_IF_NULL(kernel_graph); - MS_EXCEPTION_IF_NULL(outputs); - MS_EXCEPTION_IF_NULL(output_names); - - std::vector inputs; - std::vector input_names; - GetModelInputsInfo(kernel_graph, &inputs, &input_names); - - VectorRef vector_outputs; - std::map tensor_to_node; - session::KernelMapTensor node_to_tensor; - auto anf_outputs = kernel_graph->outputs(); - for (auto &item : anf_outputs) { - MS_EXCEPTION_IF_NULL(item); - MS_LOG(INFO) << "Create node output[" << item->DebugString() << "]"; - vector_outputs.emplace_back( - session::SessionBasic::CreateNodeOutputTensors(item, kernel_graph, inputs, &tensor_to_node, &node_to_tensor)); - } - *outputs = TransformVectorRefToMultiTensor(vector_outputs); - for (size_t i = 0; i < outputs->size(); i++) { - output_names->push_back("output" + std::to_string(i)); - } - } - - virtual Status Run(const std::vector &inputs, std::vector *outputs) = 0; - virtual Status Load(uint32_t device_id) = 0; - - virtual std::vector GetInputs() = 0; - virtual std::vector GetOutputs() = 0; - - virtual bool CheckDeviceSupport(mindspore::DeviceType device_type) = 0; - - protected: - std::shared_ptr graph_; - std::shared_ptr graph_context_; - - std::shared_ptr backend_; - uint32_t graph_id_{0}; - std::weak_ptr kernel_graph_; - std::weak_ptr func_graph_; - uint32_t device_id_; - std::vector inputs_info_; - std::vector outputs_info_; - std::vector last_inputs_; - std::vector last_outputs_; - std::vector input_names_; - std::vector output_names_; - bool load_flag_; -}; -} // namespace mindspore -#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_IMPL_H diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_model.cc b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_model.cc deleted file mode 100644 index 497a887dd22e1bad2badb7493e887ab41e5e51f6..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_model.cc +++ /dev/null @@ -1,190 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "cxx_api/model/acl/acl_model.h" - -#include -#include -#include "include/api/context.h" -#include "cxx_api/factory.h" -#include "cxx_api/graph/acl/acl_env_guard.h" -#include "cxx_api/acl_utils.h" - -namespace mindspore { -Status AclModel::Build() { - MS_LOG(INFO) << "Start build model."; - MS_EXCEPTION_IF_NULL(graph_); - - if (graph_cell_ != nullptr) { - MS_LOG(INFO) << "This model has been built, skip."; - return kSuccess; - } - - std::shared_ptr options = std::make_shared(model_context_); - MS_EXCEPTION_IF_NULL(options); - - if (graph_cell_ == nullptr && graph_->ModelType() == ModelType::kOM) { - MS_LOG(INFO) << "Load om model and all build options will be ignored."; - graph_cell_ = std::make_shared(graph_); - MS_EXCEPTION_IF_NULL(graph_cell_); - auto ret = graph_cell_->Load(options->GetDeviceID()); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Load failed."; - return ret; - } - - options_ = std::move(options); - return kSuccess; - } - - std::string options_key = options->GenAclOptionsKey(); - std::shared_ptr graph; - if (auto iter = dynamic_size_graph_map_.find(options_key); iter != dynamic_size_graph_map_.end()) { - MS_LOG(INFO) << "This options has been built, read cache."; - graph = iter->second; - } else { - auto func_graph = ModelImpl::GetFuncGraph(); - MS_EXCEPTION_IF_NULL(func_graph); - auto inputs = func_graph->parameters(); - std::vector input_names; - for (auto node : inputs) { - auto para = node->cast(); - MS_EXCEPTION_IF_NULL(para); - std::string name = para->name(); - for (auto pos = name.find(':'); pos != std::string::npos; pos = name.find(':')) { - name = name.substr(0, pos) + "_" + name.substr(pos + 1); - MS_LOG(INFO) << name; - } - para->set_name(name); - input_names.push_back(name); - } - options->RenameInput(input_names); - MS_EXCEPTION_IF_NULL(func_graph); - model_converter_.set_options(options); - auto om_data = model_converter_.LoadMindIR(func_graph); - if (om_data.Data() == nullptr || om_data.DataSize() == 0) { - MS_LOG(ERROR) << "Load MindIR failed."; - return kMCFailed; - } - graph = std::make_shared(std::make_shared(om_data, ModelType::kOM)); - dynamic_size_graph_map_[options_key] = graph; - } - - MS_EXCEPTION_IF_NULL(graph); - auto graph_cell = std::make_shared(graph); - MS_EXCEPTION_IF_NULL(graph_cell); - auto ret = graph_cell->Load(options->GetDeviceID()); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Load failed."; - return ret; - } - - // save result - graph_cell_ = graph_cell; - options_ = std::move(options); - MS_LOG(INFO) << "Build model success."; - return kSuccess; -} - -Status AclModel::Resize(const std::vector &inputs, const std::vector> &dims) { - MS_LOG(INFO) << "Start to resize model."; - MS_EXCEPTION_IF_NULL(graph_); - if (graph_->ModelType() == ModelType::kOM) { - MS_LOG(ERROR) << "OM model is not supported to resize model."; - return kMCFailed; - } - - auto origin_inputs = GetInputs(); - if (inputs.size() != origin_inputs.size()) { - MS_LOG(ERROR) << "Invalid inputs size " << inputs.size() << " not match model inputs size " << origin_inputs.size(); - return kMCInvalidInput; - } - - if (inputs.size() != dims.size()) { - MS_LOG(ERROR) << "Invalid dims size " << dims.size() << " not match inputs size " << inputs.size(); - return kMCInvalidInput; - } - - if (model_context_ == nullptr) { - model_context_ = std::make_shared(); - (void)model_context_->MutableDeviceInfo().emplace_back(std::make_shared()); - } - - std::string input_shape_option; - for (size_t i = 0; i < inputs.size(); ++i) { - if (inputs[i].Name() != origin_inputs[i].Name()) { - MS_LOG(ERROR) << "Invalid inputs " << i << " name " << inputs[i].Name() << " not match model input name " - << origin_inputs[i].Name(); - return kMCInvalidInput; - } - input_shape_option += inputs[i].Name() + ":"; - for (size_t j = 0; j < dims[i].size(); ++j) { - input_shape_option += std::to_string(dims[i][j]); - if (j + 1 < dims[i].size()) { - input_shape_option += ","; - } - } - if (i + 1 < inputs.size()) { - input_shape_option += ";"; - } - } - MS_LOG(INFO) << "Set input size option is " << input_shape_option; - auto &device_infos = model_context_->MutableDeviceInfo(); - if (device_infos.size() != 1) { - MS_LOG(ERROR) << "Invalid model context, only single device info is supported."; - return kMCInvalidArgs; - } - auto ascend310_info = device_infos[0]->Cast(); - MS_EXCEPTION_IF_NULL(ascend310_info); - ascend310_info->SetInputShape(input_shape_option); - auto graph_cell_bak = std::move(graph_cell_); - auto ret = Build(); - if (ret != kSuccess) { - MS_LOG(INFO) << "Resize build failed."; - graph_cell_ = std::move(graph_cell_bak); - return ret; - } - MS_LOG(INFO) << "Resize success."; - return kSuccess; -} - -std::vector AclModel::GetInputs() { - MS_EXCEPTION_IF_NULL(graph_cell_); - return graph_cell_->GetInputs(); -} - -std::vector AclModel::GetOutputs() { - MS_EXCEPTION_IF_NULL(graph_cell_); - return graph_cell_->GetOutputs(); -} - -bool AclModel::CheckDeviceSupport(mindspore::DeviceType device_type) { - // for Ascend, only support kAscend and kAscend310 - if (device_type != kAscend && device_type != kAscend310) { - return false; - } - return IsAscendNo910Soc(); -} - -bool AclModel::CheckModelSupport(enum ModelType model_type) { - static const std::set kSupportedModelMap = {kMindIR, kOM}; - auto iter = kSupportedModelMap.find(model_type); - if (iter == kSupportedModelMap.cend()) { - return false; - } - return true; -} -} // namespace mindspore diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_model_multi.cc b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_model_multi.cc deleted file mode 100644 index fade358054e2a1e863f2a02c720df5ca4812f4b3..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_model_multi.cc +++ /dev/null @@ -1,268 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "cxx_api/model/acl/acl_model_multi.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include "cxx_api/factory.h" -#include "acl/acl_rt.h" -#include "load_mindir/infer_mindir.h" -#include "cxx_api/model/acl/acl_vm/ms_tensor_ref.h" -#include "cxx_api/model/acl/acl_vm/acl_vm.h" - -namespace mindspore { -API_MODEL_REG(Ascend310, AclModelMulti); - -namespace { -std::map kDtypeMap = { - {DataType::kNumberTypeBool, sizeof(bool)}, {DataType::kNumberTypeInt8, sizeof(int8_t)}, - {DataType::kNumberTypeInt16, sizeof(int16_t)}, {DataType::kNumberTypeInt32, sizeof(int32_t)}, - {DataType::kNumberTypeInt64, sizeof(int64_t)}, {DataType::kNumberTypeFloat16, sizeof(float16)}, - {DataType::kNumberTypeFloat32, sizeof(float)}, {DataType::kNumberTypeFloat64, sizeof(double)}, - {DataType::kNumberTypeUInt8, sizeof(uint8_t)}, {DataType::kNumberTypeUInt16, sizeof(uint16_t)}, - {DataType::kNumberTypeUInt32, sizeof(uint32_t)}, {DataType::kNumberTypeUInt64, sizeof(uint64_t)}, - {DataType::kNumberTypeBFloat16, sizeof(bfloat16)}}; - -std::shared_ptr CreateBackend(const std::shared_ptr &options) { - MS_EXCEPTION_IF_NULL(options); - return std::make_shared(kMsConvert, kDavinciMultiGraphInferenceDevice, options); -} - -bool HasMultiGraph(const FuncGraphPtr &fg) { - MS_EXCEPTION_IF_NULL(fg); - std::vector all_nodes = TopoSort(fg->get_return()); - for (const auto &node : all_nodes) { - MS_EXCEPTION_IF_NULL(node); - if (IsValueNode(node)) { - MS_LOG(INFO) << fg->ToString() << " has FuncGraph node " << node->DebugString() << " is multi graph."; - return true; - } - } - return false; -} -} // namespace -Status AclModelMulti::Build() { - if (!is_multi_graph_.has_value()) { - is_multi_graph_ = ModelImpl::GetFuncGraph() == nullptr ? false : HasMultiGraph(ModelImpl::GetFuncGraph()); - } - - if (!is_multi_graph_.value()) { - return AclModel::Build(); - } - - if (vm_ != nullptr) { - MS_LOG(INFO) << "Multi graph model has been built, skip."; - return kSuccess; - } - MS_LOG(INFO) << "Start build multi graph model."; - // perpare func graph - auto manager = MakeManager(); - manager->AddFuncGraph(ModelImpl::GetFuncGraph()); - ModelImpl::GetFuncGraph()->set_manager(manager); - // set inputs - SetInputs(); - // infer mindir - abstract::AbstractBasePtrList broaded_args; - auto fg = ModelImpl::GetFuncGraph(); - MS_EXCEPTION_IF_NULL(fg); - const auto &inputs = fg->get_inputs(); - (void)std::transform(inputs.begin(), inputs.end(), std::back_inserter(broaded_args), - [](const AnfNodePtr &n) -> AbstractBasePtr { - MS_EXCEPTION_IF_NULL(n); - auto abstract = n->abstract(); - MS_EXCEPTION_IF_NULL(abstract); - if (abstract->GetValueTrack() != kValueAny) { - return abstract->Broaden(); - } - return abstract; - }); - try { - (void)InferMindir(ModelImpl::GetFuncGraph(), broaded_args); - } catch (const std::runtime_error &e) { - MS_LOG(ERROR) << "Infer mindir for sub graph failed: " << e.what(); - return kMCFailed; - } - - // set output - SetOutput(); - // create vm - auto backend = CreateBackend(std::make_shared(model_context_)); - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - backend->set_is_multi_graph_sink(false); - context_ptr->set_param(MS_CTX_DEVICE_TARGET, kDavinciMultiGraphInferenceDevice); - context_ptr->set_param(MS_CTX_IS_MULTI_GRAPH_SINK, false); - context_ptr->set_param(MS_CTX_ENABLE_LOOP_SINK, false); - auto compile = std::make_shared(backend, compile::GetMsNonlinearOps()); - - vm_ = compile->CompileAndLink(ModelImpl::GetFuncGraph()); - backend_ = std::move(backend); - MS_LOG(INFO) << "Build multi graph model success."; - return kSuccess; -} - -Status AclModelMulti::Predict(const std::vector &inputs, std::vector *outputs) { - if (!is_multi_graph_.has_value()) { - is_multi_graph_ = ModelImpl::GetFuncGraph() == nullptr ? false : HasMultiGraph(ModelImpl::GetFuncGraph()); - } - - if (!is_multi_graph_.value()) { - return AclModel::Predict(inputs, outputs); - } - - auto ret = Build(); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Build multi-graph model as default options failed."; - return ret; - } - MS_LOG(INFO) << "Start predict multi graph model."; - MS_EXCEPTION_IF_NULL(vm_); - MS_EXCEPTION_IF_NULL(outputs); - try { - (*outputs) = MSTensorRef::Convert(vm_->Eval(MSTensorRef::Convert(inputs))); - } catch (const std::exception &ex) { - MS_LOG(ERROR) << "Predict Failed, error: " << ex.what(); - return kMCFailed; - } - - if (inputs_.empty()) { - inputs_ = inputs; - } else { - if (inputs.size() != inputs_.size()) { - MS_LOG(ERROR) << "Input Size is wrong."; - return kMCFailed; - } - for (size_t i = 0; i < inputs_.size(); ++i) { - auto input_tensor = MSTensor::CreateTensor(inputs_[i].Name(), inputs_[i].DataType(), inputs_[i].Shape(), - inputs[i].Data().get(), inputs[i].DataSize()); - inputs_[i] = (*input_tensor); - MSTensor::DestroyTensorPtr(input_tensor); - } - } - - outputs_ = *outputs; - MS_LOG(INFO) << "Predict multi graph model success."; - return kSuccess; -} - -void AclModelMulti::SetInputs() { - if (inputs_.empty()) { - auto fg = ModelImpl::GetFuncGraph(); - MS_EXCEPTION_IF_NULL(fg); - const auto &inputs = fg->get_inputs(); - for (const auto &in : inputs) { - auto input_param = std::dynamic_pointer_cast(in); - MS_EXCEPTION_IF_NULL(input_param); - auto input_abs = input_param->abstract(); - MS_EXCEPTION_IF_NULL(input_abs); - auto tensor_abs = input_abs->cast(); - if (tensor_abs == nullptr) { - MS_LOG(EXCEPTION) << "The graph input type is not a tensor. input args info:" << input_abs->ToString(); - } - auto shape_ptr = tensor_abs->BuildShape(); - MS_EXCEPTION_IF_NULL(shape_ptr); - auto tensor_shape = shape_ptr->cast(); - MS_EXCEPTION_IF_NULL(tensor_shape); - auto elem = tensor_abs->element(); - MS_EXCEPTION_IF_NULL(elem); - auto type_id = elem->BuildType()->type_id(); - auto tensor = std::make_shared(type_id, tensor_shape->shape()); - - std::vector shape = tensor->shape_c(); - auto input_tensor = MSTensor::CreateTensor(input_param->name(), static_cast(tensor->data_type_c()), - shape, nullptr, tensor->Size()); - inputs_.emplace_back(*input_tensor); - MSTensor::DestroyTensorPtr(input_tensor); - } - } else { - MS_LOG(DEBUG) << "inputs_ has been set."; - } -} - -void AclModelMulti::SetOutput() { - if (outputs_.empty()) { - auto fg = ModelImpl::GetFuncGraph(); - MS_EXCEPTION_IF_NULL(fg); - const auto output = fg->output(); - MS_EXCEPTION_IF_NULL(output); - auto abs = output->abstract(); - MS_EXCEPTION_IF_NULL(abs); - - // DataType - DataType type_id; - if (abs->isa()) { - auto abs_tensor = abs->cast(); - auto ele = abs_tensor->element(); - MS_EXCEPTION_IF_NULL(ele); - MS_EXCEPTION_IF_NULL(ele->GetTypeTrack()); - type_id = static_cast(ele->GetTypeTrack()->type_id()); - } else { - MS_EXCEPTION_IF_NULL(abs->GetTypeTrack()); - type_id = static_cast(abs->GetTypeTrack()->type_id()); - } - // Shape - auto shape_track = abs->GetShapeTrack(); - MS_EXCEPTION_IF_NULL(shape_track); - std::vector shape = {}; - if (shape_track->isa()) { - auto shapeptr = shape_track->cast(); - shape = static_cast>(shapeptr->shape()); - } - // Size - size_t ato_size = 0; - if (kDtypeMap.find(type_id) != kDtypeMap.end()) { - ato_size = kDtypeMap[type_id]; - } - int64_t ele_num = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); - size_t size = ato_size * LongToSize(ele_num); - // create tensor - auto output_tensor = MSTensor::CreateTensor("", type_id, shape, nullptr, size); - outputs_.emplace_back(*output_tensor); - MSTensor::DestroyTensorPtr(output_tensor); - } else { - MS_LOG(DEBUG) << "outputs_ has been set."; - } -} - -std::vector AclModelMulti::GetInputs() { - if (!is_multi_graph_.has_value()) { - is_multi_graph_ = ModelImpl::GetFuncGraph() == nullptr ? false : HasMultiGraph(ModelImpl::GetFuncGraph()); - } - - if (!is_multi_graph_.value()) { - return AclModel::GetInputs(); - } - - return inputs_; -} - -std::vector AclModelMulti::GetOutputs() { - if (!is_multi_graph_.has_value()) { - is_multi_graph_ = ModelImpl::GetFuncGraph() == nullptr ? false : HasMultiGraph(ModelImpl::GetFuncGraph()); - } - - if (!is_multi_graph_.value()) { - return AclModel::GetOutputs(); - } - - return outputs_; -} -} // namespace mindspore diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_vm/acl_multi_graph_session.cc b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_vm/acl_multi_graph_session.cc deleted file mode 100644 index 46f366631a6ff0e71ae877b782de4ee9e015d827..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_vm/acl_multi_graph_session.cc +++ /dev/null @@ -1,163 +0,0 @@ -/** - * Copyright 2021-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "cxx_api/model/acl/acl_vm/acl_multi_graph_session.h" -#include -#include -#include -#include "mindspore/ops/op_def/sequence_ops.h" -#include "mindspore/ops/op_def/framework_ops.h" -#include "backend/common/session/session_factory.h" -#include "include/backend/optimizer/optimizer.h" -#include "backend/backend_manager/backend_jit_config.h" -#ifdef ENABLE_D -#include "runtime/hardware/device_context_manager.h" -#endif -#include "cxx_api/model/acl/model_converter.h" -#include "cxx_api/model/acl/acl_model_options.h" -#include "cxx_api/model/acl/acl_vm/ms_tensor_ref.h" -#include "cxx_api/graph/graph_data.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_u.h" - -namespace mindspore::session { -void MultiGraphAclSession::Init(uint32_t device_id) { InitExecutor(kDavinciMultiGraphInferenceDevice, device_id); } - -GraphId MultiGraphAclSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { - class FirstGraphModeGuard { - public: - explicit FirstGraphModeGuard(const std::shared_ptr &options) : options_(options) { - if (options_ != nullptr) { - options_->SetFirstGraph(true); - } - } - ~FirstGraphModeGuard() { - if (options_ != nullptr) { - options_->SetFirstGraph(false); - } - } - - private: - std::shared_ptr options_; - }; - MS_LOG(INFO) << "Start MultiGraph Compile."; - // construct kernel graph - auto kernel_graph = - SessionBasic::ConstructKernelGraph(lst, outputs, device::DeviceType::kUnknown, backend::BackendJitConfig(), false); - MS_EXCEPTION_IF_NULL(kernel_graph); -#ifndef ENABLE_D - auto optimizer = std::make_shared(); - auto pm = std::make_shared("310_multi_graph_pm"); - optimizer->AddPassManager(pm); - (void)optimizer->Optimize(kernel_graph); -#endif - kernel_graph->SetExecOrderByDefault(); - // concert to om data - ModelConverter model_converter_; - model_converter_.set_options(options_); - FirstGraphModeGuard guard(options_); - auto om_data = model_converter_.LoadMindIR(kernel_graph); - if (om_data.Data() == nullptr || om_data.DataSize() == 0) { - MS_LOG(EXCEPTION) << "Load MindIR failed."; - } - // load - std::shared_ptr graph = std::make_shared(std::make_shared(om_data, ModelType::kOM)); - MS_EXCEPTION_IF_NULL(graph); - auto graph_cell = GraphCell(graph); - auto ret = graph_cell.Load(options_->GetDeviceID()); - if (ret != kSuccess) { - MS_LOG(EXCEPTION) << "Load failed."; - } - graph_cells_[kernel_graph->graph_id()] = graph_cell; - kernel_graphs_[kernel_graph->graph_id()] = kernel_graph; - MS_LOG(INFO) << "Multi graph compile success, graph id " << kernel_graph->graph_id(); - return kernel_graph->graph_id(); -} - -void MultiGraphAclSession::RunGraph(GraphId graph_id, const std::vector &inputs, VectorRef *outputs) { - MS_EXCEPTION_IF_NULL(outputs); - MS_LOG(INFO) << "Start run graph " << graph_id; - auto iter = graph_cells_.find(graph_id); - if (iter == graph_cells_.cend()) { - MS_LOG(INTERNAL_EXCEPTION) << "Graph id " << graph_id << " not found."; - } - std::vector out_tensors; - auto ret = iter->second.Run(inputs, &out_tensors); - if (ret != kSuccess) { - MS_LOG(EXCEPTION) << "Graph id " << graph_id << " run failed."; - } - - std::deque out_tensors_deque(out_tensors.begin(), out_tensors.end()); - (*outputs) = ConstructOutputRef(graph_id, &out_tensors_deque); -} - -VectorRef MultiGraphAclSession::ConstructOutputRef(GraphId graph_id, std::deque *out_tensors) { - MS_EXCEPTION_IF_NULL(out_tensors); - VectorRef outs; - auto out_nodes = kernel_graphs_[graph_id]->outputs(); - for (auto &out : out_nodes) { - auto item_with_index = common::AnfAlgo::VisitKernelWithReturnType( - out, 0, false, std::vector{prim::kPrimMakeTuple, prim::kPrimUpdateState, prim::kPrimStateSetItem}); - auto &anf_node = item_with_index.first; - if (common::AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimMakeTuple)) { - auto cnode = anf_node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - outs.emplace_back(ConstructOutputRefByTupleNode(cnode, out_tensors)); - } else if (AnfUtils::IsRealKernel(anf_node)) { - if (out_tensors->empty()) { - MS_LOG(INTERNAL_EXCEPTION) << "Can not find MSTensor for output node " << out->DebugString() - << ", visited: " << anf_node->DebugString(); - } - outs.emplace_back(MSTensorRef(out_tensors->front())); - out_tensors->pop_front(); - } - } - - if (!out_tensors->empty()) { - MS_LOG(EXCEPTION) << "Number of output size " << outs.size() << " but " << out_tensors->size() - << " MSTensor remained."; - } - - return outs; -} - -VectorRef MultiGraphAclSession::ConstructOutputRefByTupleNode(const CNodePtr &tuple_node, - std::deque *out_tensors) { - MS_EXCEPTION_IF_NULL(out_tensors); - VectorRef outs; - for (size_t i = 1; i < tuple_node->size(); ++i) { - auto item_with_index = common::AnfAlgo::VisitKernelWithReturnType( - tuple_node->input(i), 0, false, - std::vector{prim::kPrimMakeTuple, prim::kPrimUpdateState, prim::kPrimStateSetItem}); - auto &anf_node = item_with_index.first; - if (common::AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimMakeTuple)) { - auto cnode = anf_node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - outs.emplace_back(ConstructOutputRefByTupleNode(cnode, out_tensors)); - } else if (AnfUtils::IsRealKernel(anf_node)) { - if (out_tensors->empty()) { - MS_LOG(INTERNAL_EXCEPTION) << "Can not find MSTensor for output node " << tuple_node->input(i)->DebugString() - << ", visited: " << anf_node->DebugString(); - } - outs.emplace_back(MSTensorRef(out_tensors->front())); - out_tensors->pop_front(); - } - } - - return outs; -} -MS_REG_SESSION(kDavinciMultiGraphInferenceDevice, MultiGraphAclSession); -} // namespace mindspore::session diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_vm/acl_multi_graph_session.h b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_vm/acl_multi_graph_session.h deleted file mode 100644 index eb14107e797c866f28d5a7dee319a99d421cbcaa..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_vm/acl_multi_graph_session.h +++ /dev/null @@ -1,51 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_CXX_API_ACL_VM_ACL_MULTI_GRAPH_SESSION_H -#define MINDSPORE_CCSRC_CXX_API_ACL_VM_ACL_MULTI_GRAPH_SESSION_H - -#include -#include -#include -#include -#include "include/api/types.h" -#include "include/api/cell.h" -#include "backend/common/session/session_basic.h" - -namespace mindspore { -class AclModelOptions; -namespace session { -class MultiGraphAclSession : public session::SessionBasic { - public: - MultiGraphAclSession() = default; - ~MultiGraphAclSession() override = default; - void Init(uint32_t device_id) override; - void RunGraph(GraphId graph_id, const std::vector &inputs, VectorRef *outputs); - void SetOptions(const std::shared_ptr &options) { options_ = options; } - - protected: - GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; - - private: - VectorRef ConstructOutputRef(GraphId graph_id, std::deque *out_tensors); - VectorRef ConstructOutputRefByTupleNode(const CNodePtr &tuple_node, std::deque *out_tensors); - - std::map graph_cells_ = {}; - std::map kernel_graphs_ = {}; - std::shared_ptr options_ = nullptr; -}; -} // namespace session -} // namespace mindspore -#endif // MINDSPORE_CCSRC_CXX_API_ACL_VM_ACL_MULTI_GRAPH_SESSION_H diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_vm/acl_vm.cc b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_vm/acl_vm.cc deleted file mode 100644 index b5bce35b892d77c70bb046b3727416157373003b..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_vm/acl_vm.cc +++ /dev/null @@ -1,306 +0,0 @@ -/** - * Copyright 2021-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "cxx_api/model/acl/acl_vm/acl_vm.h" -#include -#include -#include -#include "mindspore/ops/op_def/framework_ops.h" -#include "cxx_api/model/acl/acl_model_options.h" -#include "cxx_api/model/acl/acl_vm/acl_multi_graph_session.h" -#include "utils/trace_base.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_u.h" - -namespace mindspore { -namespace { -inline bool IsMonadNode(const AnfNodePtr &node) { - if (IsPrimitiveCNode(node, prim::kPrimStateSetItem) || IsPrimitiveCNode(node, prim::kPrimUpdateState)) { - return true; - } - - if (HasAbstractMonad(node)) { - return true; - } - - return false; -} - -std::vector ParseVectorMsTensorRef(const VectorRef &args) { - std::vector ms_tensors; - for (const auto &arg : args) { - if (utils::isa(arg)) { - auto ret = ParseVectorMsTensorRef(utils::cast(arg)); - (void)ms_tensors.insert(ms_tensors.end(), ret.begin(), ret.end()); - } else if (utils::isa(arg)) { - auto wrapper = utils::cast(arg); - (void)ms_tensors.emplace_back(wrapper.GetTensor()); - } else { - MS_LOG(INTERNAL_EXCEPTION) << "Invalid item " << arg.ToString(); - } - } - return ms_tensors; -} -} // namespace -AclBackend::AclBackend(const std::string &name, const std::string &target, - const std::shared_ptr &options) - : MsBackend(name, target, options->GetDeviceID()) { - auto session = std::dynamic_pointer_cast(MsBackend::target_sess_); - MS_EXCEPTION_IF_NULL(session); - session->SetOptions(options); -} - -VectorRef AclBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const std::string & /* target */) { - std::vector inputs = ParseVectorMsTensorRef(args); - VectorRef outputs; - MS_EXCEPTION_IF_NULL(target_sess_); - auto exec_sess = std::dynamic_pointer_cast(target_sess_); - MS_EXCEPTION_IF_NULL(exec_sess); - exec_sess->RunGraph(g, inputs, &outputs); - return outputs; -} - -bool AclBackend::GetCond(const BaseRef &c, bool *value) { - MS_EXCEPTION_IF_NULL(value); - if (!utils::isa(c)) { - MS_LOG(ERROR) << "Invalid item " << c.ToString() << " must be a MSTensorRef."; - return false; - } - auto wrapper = utils::cast(c); - if (wrapper.GetTensor().DataType() != DataType::kNumberTypeBool) { - MS_LOG(ERROR) << "Invalid data type " << wrapper.GetTensor().DataType() << " must be bool."; - return false; - } - auto data = wrapper.GetTensor().Data(); - if (data == nullptr) { - return false; - } - (*value) = *static_cast(data.get()); - return true; -} - -bool AclBackend::GetIndex(const BaseRef &c, int64_t *value) { - MS_EXCEPTION_IF_NULL(value); - if (!utils::isa(c)) { - MS_LOG(ERROR) << "Invalid item " << c.ToString() << " must be a MSTensorRef."; - return false; - } - - auto wrapper = utils::cast(c); - if (wrapper.GetTensor().DataType() == DataType::kNumberTypeInt32) { - auto data = wrapper.GetTensor().Data(); - if (data == nullptr) { - return false; - } - auto value_int32 = *static_cast(data.get()); - (*value) = static_cast(value_int32); - return true; - } else if (wrapper.GetTensor().DataType() == DataType::kNumberTypeInt64) { - auto data = wrapper.GetTensor().Data(); - if (data == nullptr) { - return false; - } - (*value) = *static_cast(data.get()); - return true; - } else { - MS_LOG(ERROR) << "Index must be Int type."; - return false; - } -} - -AclCompileGraph::AclCompileGraph(const std::shared_ptr &backend, - const std::vector &cut_list) - : CompileGraph(backend, cut_list) {} - -void AclCompileGraph::AddInst(const compile::Instruction &inst, const MSTensorRef &arg) { - VectorRef args; - args.push_back(arg); - compile::CompileGraph::AddInst(inst, args); -} - -int64_t AclCompileGraph::Ref(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - MS_LOG(DEBUG) << "Start Ref node " << node->DebugString(true) << " height_: " << height_; - if (slots_.count(node) == 0 && node->isa()) { - if (IsValueNode(node)) { - MS_LOG(DEBUG) << "Push graph."; - compile::CompileGraph::AddInst(compile::Instruction::kGraph, GetValueNode(node)); - } else { - MS_LOG(DEBUG) << "Push."; - if (IsValueNode(node)) { - MS_LOG(EXCEPTION) << "must not be primitive in here NodeInfo: " << trace::GetDebugInfoStr(node->debug_info()); - } else if (IsValueNode(node)) { - auto tensor_node = std::dynamic_pointer_cast(node->cast()->value()); - MS_EXCEPTION_IF_NULL(tensor_node); - std::string name = ""; - std::vector shape = tensor_node->shape_c(); - DataType type = static_cast(tensor_node->data_type_c()); - auto mstensor_node = MSTensor::CreateRefTensor(name, type, shape, tensor_node->data_c(), tensor_node->Size()); - MSTensorRef mstensor_ref(*mstensor_node); - AddInst(compile::Instruction::kPush, mstensor_ref); - MSTensor::DestroyTensorPtr(mstensor_node); - } else { - compile::CompileGraph::AddInst(compile::Instruction::kPush, GetValueNode(node)); - } - } - Push(node); - } else if (auto const_parameter = dyn_cast(node); - slots_.count(node) == 0 && const_parameter != nullptr && const_parameter->has_default()) { - auto value = const_parameter->default_param(); - MS_EXCEPTION_IF_NULL(value); - if (value->isa()) { - auto tensor_node = std::dynamic_pointer_cast(value); - MS_EXCEPTION_IF_NULL(tensor_node); - std::vector shape = tensor_node->shape_c(); - DataType type = static_cast(tensor_node->data_type_c()); - auto mstensor_node = - MSTensor::CreateRefTensor(const_parameter->name(), type, shape, tensor_node->data_c(), tensor_node->Size()); - MSTensorRef mstensor_ref(*mstensor_node); - AddInst(compile::Instruction::kPush, mstensor_ref); - MSTensor::DestroyTensorPtr(mstensor_node); - } else { - compile::CompileGraph::AddInst(compile::Instruction::kPush, value); - } - Push(node); - } - MS_LOG(DEBUG) << "End Ref node end height_: " << height_ << ", slots: " << slots_[node] - << ", return: " << slots_[node] - height_; - return slots_[node] - height_; -} - -void AclCompileGraph::AddExternal(const compile::LinConvertResult &result) { - VectorRef args; - args.push_back(result.run); - args.push_back(result.simu_run); - size_t size = result.inputs.size(); - for (size_t i = 0; i < size; ++i) { - const auto &input = result.inputs[i]; - MS_EXCEPTION_IF_NULL(input); - if (auto parameter = dyn_cast(input); parameter != nullptr && parameter->has_default()) { - MS_LOG(DEBUG) << parameter->DebugString() << " has default value, will not be pushed as inputs."; - continue; - } - if (IsMonadNode(input)) { - MS_LOG(DEBUG) << input->DebugString() << " is monad node, will not be pushed as inputs."; - continue; - } - args.emplace_back(Ref(input)); - } - compile::CompileGraph::AddInst(compile::Instruction::kExternal, args); - size_t out_count = 0; - for (auto &out : result.outputs) { - if (IsMonadNode(out)) { - continue; - } - ++out_count; - Push(out); - } - MS_LOG(DEBUG) << "Args size " << args.size() << " out size " << out_count; -} - -void AclCompileGraph::AddInput(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - if (IsMonadNode(node)) { - return; - } - if (slots_.count(node) == 0) { - MS_LOG(DEBUG) << "Input node is null " << node->DebugString(true); - (void)Ref(node); - return; - } - compile::CompileGraph::AddInst(compile::Instruction::kInput, Ref(node)); - set_height(height_ + 1); -} - -void AclCompileGraph::AddPartial(const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - auto inputs = node->inputs(); - VectorRef args; - if (inputs.size() <= 1) { - MS_LOG(EXCEPTION) << "The node:" << node->DebugString() << "do not have two input."; - } - auto fn = inputs[1]; - if (!IsValueNode(fn)) { - MS_LOG(INTERNAL_EXCEPTION) << "The type of 1st input of node must be FuncGraph"; - } - for (size_t i = 1; i < inputs.size(); i++) { - if (IsMonadNode(inputs[i])) { - continue; - } - args.emplace_back(Ref(inputs[i])); - } - compile::CompileGraph::AddInst(compile::Instruction::kPartial, args); -} - -int64_t AclCompileGraph::AddCall(const FuncGraphPtr &graph, const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - auto inputs = node->inputs(); - AnfNodePtr fn = inputs[0]; - (void)Ref(fn); - size_t size = inputs.size(); - size_t non_monad_size = size; - for (size_t i = size - 1; i > 0; --i) { - if (IsMonadNode(inputs[i])) { - --non_monad_size; - continue; - } - AddInput(inputs[i]); - } - if (node == graph->output()) { - AddTailCall(fn, non_monad_size); - return RET_BREAK; - } - MS_LOG(DEBUG) << "Call:" << Ref(fn) << ", " << height_ << ", " << (non_monad_size - 1); - compile::CompileGraph::AddInst(compile::Instruction::kCall, Ref(fn)); - Ret(static_cast(non_monad_size - 1)); - for (size_t i = size - 1; i > 0; i--) { - const auto iter = slots_.find(inputs[i]); - if (iter != slots_.end() && iter->second >= height_) { - slots_.erase(inputs[i]); - } - } - return RET_SUCCESS; -} - -void AclCompileGraph::PushParameters(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - std::vector parameters = func_graph->parameters(); - for (size_t i = parameters.size(); i != 0; i--) { - MS_EXCEPTION_IF_NULL(parameters[i - 1]); - auto param = parameters[i - 1]->cast(); - MS_EXCEPTION_IF_NULL(param); - if (param->has_default()) { - MS_LOG(DEBUG) << "Parameter " << (i - 1) << ": " << param->DebugString() << " has default value, skip."; - continue; - } - if (IsMonadNode(param)) { - MS_LOG(DEBUG) << "Parameter " << (i - 1) << ": " << param->DebugString() << " has monad type, skip."; - continue; - } - Push(param); - MS_LOG(DEBUG) << "Push parameter " << (i - 1) << ": " << param->DebugString(); - } -} - -AclCompileGraphs::AclCompileGraphs(const std::shared_ptr &backend, - const std::vector &cut_list) - : CompileGraphs(backend, cut_list) { - MS_EXCEPTION_IF_NULL(backend); - MS_LOG(DEBUG) << "Start vm: " << backend->name(); - transform_ = std::make_shared(backend, cut_list); - Reset(); -} -} // namespace mindspore diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_vm/acl_vm.h b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_vm/acl_vm.h deleted file mode 100644 index 249875e872ed727e04168750503d6aafcbf4fb5e..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_vm/acl_vm.h +++ /dev/null @@ -1,61 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_CXX_API_ACL_VM_ACL_VM_H -#define MINDSPORE_CCSRC_CXX_API_ACL_VM_ACL_VM_H - -#include -#include -#include -#include "backend/graph_compiler/transform.h" -#include "cxx_api/model/acl/acl_vm/ms_tensor_ref.h" - -namespace mindspore { -class AclModelOptions; -class AclBackend : public compile::MsBackend { - public: - AclBackend(const std::string &name, const std::string &target, const std::shared_ptr &options); - ~AclBackend() override = default; - - VectorRef MsRunGraph(const GraphId &g, const VectorRef &args, const std::string & /* target */) override; - bool GetCond(const BaseRef &c, bool *value) override; - bool GetIndex(const BaseRef &c, int64_t *value) override; -}; - -class AclCompileGraph : public compile::CompileGraph { - public: - explicit AclCompileGraph(const std::shared_ptr &backend, - const std::vector &cut_list); - ~AclCompileGraph() override = default; - - int64_t Ref(const AnfNodePtr &node) override; - - protected: - void AddExternal(const compile::LinConvertResult &result) override; - void AddInput(const AnfNodePtr &node) override; - void AddPartial(const CNodePtr &node) override; - int64_t AddCall(const FuncGraphPtr &graph, const CNodePtr &node) override; - void PushParameters(const FuncGraphPtr &func_graph) override; - void AddInst(const compile::Instruction &inst, const MSTensorRef &arg); -}; - -class AclCompileGraphs : public compile::CompileGraphs { - public: - explicit AclCompileGraphs(const std::shared_ptr &backend, - const std::vector &cut_list); - ~AclCompileGraphs() override = default; -}; -} // namespace mindspore -#endif // MINDSPORE_CCSRC_CXX_API_ACL_VM_ACL_VM_H diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_vm/ms_tensor_ref.cc b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_vm/ms_tensor_ref.cc deleted file mode 100644 index e64de8c37c77af1e723b8a7ba65ffb958311a205..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_vm/ms_tensor_ref.cc +++ /dev/null @@ -1,81 +0,0 @@ -/** - * Copyright 2021-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "cxx_api/model/acl/acl_vm/ms_tensor_ref.h" - -#include - -namespace mindspore { -VectorRef MSTensorRef::Convert(const std::vector &tensors) { - VectorRef res; - (void)std::transform(tensors.begin(), tensors.end(), std::back_inserter(res), - [](const MSTensor &t) { return MSTensorRef(t); }); - return res; -} - -std::vector MSTensorRef::Convert(const BaseRef &args) { - std::vector res; - if (utils::isa(args)) { - VectorRef args_vec = utils::cast(args); - res = ConvertTuple(args_vec); - } else if (utils::isa(args)) { - auto wrapper = utils::cast(args); - res.push_back(wrapper.ms_tensor_); - } else { - MS_LOG(INTERNAL_EXCEPTION) << "Invalid BaseRef " << args.ToString() - << " must be MSTensorRef or VectorRef{MSTensorRef...}"; - } - - return res; -} - -std::shared_ptr MSTensorRef::copy() const { - MSTensor *tensor = ms_tensor_.Clone(); - auto res = std::make_shared(static_cast(*tensor)); - MSTensor::DestroyTensorPtr(tensor); - return res; -} - -bool MSTensorRef::operator==(const BaseRef &other) const { - if (!utils::isa(other)) { - return false; - } - auto other_ms_tensor = utils::cast(other).ms_tensor_; - auto this_ms_tensor = ms_tensor_; - return (this_ms_tensor.Name() == other_ms_tensor.Name()) && (this_ms_tensor.Shape() == other_ms_tensor.Shape()) && - (this_ms_tensor.MutableData() == other_ms_tensor.MutableData()) && - (this_ms_tensor.DataSize() == other_ms_tensor.DataSize()) && - (this_ms_tensor.DataType() == other_ms_tensor.DataType()); -} - -std::vector MSTensorRef::ConvertTuple(const VectorRef &args) { - std::vector outs; - for (size_t i = 0; i < args.size(); ++i) { - const auto &item = args[i]; - if (utils::isa(item)) { - VectorRef args_vec = utils::cast(args); - auto ret = ConvertTuple(args_vec); - (void)outs.insert(outs.end(), ret.begin(), ret.end()); - } else if (utils::isa(item)) { - auto wrapper = utils::cast(item); - outs.push_back(wrapper.ms_tensor_); - } else { - MS_LOG(INTERNAL_EXCEPTION) << "Invalid BaseRef " << args.ToString() - << " must be MSTensorRef or VectorRef{MSTensorRef...}"; - } - } - return outs; -} -} // namespace mindspore diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/model.cc b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/model.cc deleted file mode 100644 index 23ddf4bc4687099785c8f6b553900ad51062caf5..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/model.cc +++ /dev/null @@ -1,196 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "include/api/model.h" -#include "include/api/context.h" -#include "cxx_api/model/model_impl.h" -#include "cxx_api/factory.h" - -namespace mindspore { -Status Model::Build(GraphCell graph_cell, const std::shared_ptr &model_context, - const std::shared_ptr &) { - if (graph_cell.GetGraph() == nullptr) { - MS_LOG(ERROR) << "Invalid graph input."; - return kMCInvalidInput; - } - - if (model_context == nullptr) { - MS_LOG(ERROR) << "Invalid model context."; - return kMCInvalidInput; - } - auto &device_info = model_context->MutableDeviceInfo(); - if (device_info.size() != 1) { - MS_LOG(ERROR) << "Invalid model context, only single device info is supported."; - return kMCInvalidInput; - } - - auto device_target = device_info[0]->GetDeviceType(); - impl_ = ModelImplFactory::Instance().Create(device_target); - if (impl_ == nullptr) { - MS_LOG(ERROR) << "Create session type " << device_target << " failed"; - return kMEFailed; - } - - g_device_target = device_target; - - impl_->SetGraph(std::make_shared(*graph_cell.GetGraph())); - impl_->SetContext(model_context); - - return impl_->Build(); -} - -Status Model::Build(const std::vector &, ModelType, const std::shared_ptr &, const Key &, - const std::vector &, const std::vector &) { - MS_LOG(ERROR) << "Unsupported Feature."; - return kMCFailed; -} - -Status Model::Build(const std::vector &, ModelType, const std::shared_ptr &) { - MS_LOG(ERROR) << "Unsupported Feature."; - return kMCFailed; -} - -Status Model::Build(const void * /* model_data */, size_t /* data_size */, ModelType /* model_type */, - const std::shared_ptr & /* model_context */, const Key & /* dec_key */, - const std::vector & /* dec_mode */, const std::vector & /* cropto_lib_path */) { - MS_LOG(ERROR) << "Unsupported Feature."; - return kMCFailed; -} - -Status Model::Resize(const std::vector &inputs, const std::vector> &dims) { - if (impl_ == nullptr) { - MS_LOG(ERROR) << "Failed because this model has not been built."; - return kMCFailed; - } - return impl_->Resize(inputs, dims); -} - -Status Model::Predict(const std::vector &inputs, std::vector *outputs, - const MSKernelCallBack & /* before */, const MSKernelCallBack & /* after */) { - if (impl_ == nullptr) { - MS_LOG(ERROR) << "Failed because this model has not been built."; - return kMCFailed; - } - return impl_->Predict(inputs, outputs); -} - -Status Model::PredictWithPreprocess(const std::vector> &inputs, std::vector *outputs, - const MSKernelCallBack & /* before */, const MSKernelCallBack & /* after */) { - if (impl_ == nullptr) { - MS_LOG(ERROR) << "Failed because this model has not been built."; - return kMCFailed; - } - return impl_->PredictWithPreprocess(inputs, outputs); -} - -Status Model::Preprocess(const std::vector> &inputs, std::vector *outputs) { - if (impl_ == nullptr) { - MS_LOG(ERROR) << "Failed because this model has not been built."; - return kMCFailed; - } - return impl_->Preprocess(inputs, outputs); -} - -bool Model::HasPreprocess() { - if (impl_ == nullptr) { - MS_LOG(ERROR) << "Failed because this model has not been built."; - return false; - } - return impl_->HasPreprocess(); -} - -std::vector Model::GetInputs() { - if (impl_ == nullptr) { - MS_LOG(ERROR) << "Failed because this model has not been built."; - return {}; - } - return impl_->GetInputs(); -} - -std::vector Model::GetOutputs() { - if (impl_ == nullptr) { - MS_LOG(ERROR) << "Failed because this model has not been built."; - return {}; - } - return impl_->GetOutputs(); -} - -MSTensor Model::GetInputByTensorName(const std::vector &tensor_name) { - std::string tensor_name_str = CharToString(tensor_name); - auto inputs = GetInputs(); - for (auto in : inputs) { - if (in.Name() == tensor_name_str) { - return in; - } - } - - return MSTensor(nullptr); -} - -std::vector> Model::GetOutputTensorNamesChar() { - std::vector> ret; - auto outputs = GetOutputs(); - std::transform(outputs.begin(), outputs.end(), std::back_inserter(ret), - [](const MSTensor &item) -> std::vector { return StringToChar(item.Name()); }); - return ret; -} - -MSTensor Model::GetOutputByTensorName(const std::vector &tensor_name) { - std::string tensor_name_str = CharToString(tensor_name); - auto outputs = GetOutputs(); - for (auto out : outputs) { - if (out.Name() == tensor_name_str) { - return out; - } - } - - return MSTensor(nullptr); -} - -std::vector Model::GetOutputsByNodeName(const std::vector &node_name) { - return std::vector{GetOutputByTensorName(node_name)}; -} - -Model::Model() : impl_(nullptr) {} -Model::~Model() {} - -bool Model::CheckModelSupport(DeviceType device_type, ModelType model_type) { - auto check_model = ModelImplFactory::Instance().Create(device_type); - if (check_model == nullptr) { - return false; - } - return check_model->CheckModelSupport(model_type); -} - -Status Model::LoadConfig(const std::vector & /* config_path */) { - MS_LOG(ERROR) << "Unsupported Feature."; - return kMCFailed; -} - -std::vector Model::GetModelInfo(const std::vector &key) { - MS_LOG(WARNING) << "mindspore inference does not support get model info"; - std::vector empty; - return empty; -} - -#ifdef _MSC_VER -Status Model::UpdateConfig(const std::vector §ion, - const std::pair, std::vector> &config) { - MS_LOG(ERROR) << "Model::UpdateConfig Unsupported on msvc."; - return kMCFailed; -} - -#endif -} // namespace mindspore diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/model_converter_utils/multi_process.cc b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/model_converter_utils/multi_process.cc deleted file mode 100644 index fd7026db0a687e0c8df404b71601ca834696569d..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/model_converter_utils/multi_process.cc +++ /dev/null @@ -1,247 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "cxx_api/model/model_converter_utils/multi_process.h" -#include -#include -#include -#include -#include -#include "src/common/log_adapter.h" -#include "cxx_api/model/model_converter_utils/shared_memory.h" - -namespace mindspore { -namespace { -constexpr uint64_t kSharedMemorySize = 100ull << 20; // 100 MB -constexpr timespec kOneMillisecond = { - 0, // 0 seconds - 1 * 1000L * 1000L, // And 1 ms -}; - -constexpr timespec kOneHundredMilliseconds = { - 0, // 0 seconds - 100 * 1000L * 1000L, // And 100 ms -}; -} // namespace - -MultiProcess::MultiProcess() = default; - -MultiProcess::~MultiProcess() = default; - -Status MultiProcess::MainProcess(const ProcessFuncCall &parent_process, const ProcessFuncCall &child_process) { - MS_EXCEPTION_IF_NULL(parent_process); - MS_EXCEPTION_IF_NULL(child_process); - Status ret; - memory_size_ = kSharedMemorySize; // 100 MB - SharedMemory shared_memory; - ret = shared_memory.Create(memory_size_); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Create shared memory failed"; - return ret; - } - pid_t pid = fork(); - if (pid < 0) { - shared_memory.Destroy(); - MS_LOG(ERROR) << "Fork process to convert model failed"; - return kMEFailed; - } - ret = shared_memory.Attach(); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Process attach shared memory failed, pid " << pid; - return ret; - } - shmat_addr_ = shared_memory.GetSharedMemoryAddr(); - if (shmat_addr_ == nullptr) { - MS_LOG(ERROR) << "Get shared memory failed"; - return ret; - } - constexpr size_t kMsgStructNum = 2; - shmat_data_addr_ = shmat_addr_ + sizeof(MessageFlag) * kMsgStructNum; - shmat_data_max_size_ = - memory_size_ - (reinterpret_cast(shmat_data_addr_) - reinterpret_cast(shmat_addr_)); - MS_LOG(INFO) << "Shm addr " << reinterpret_cast(shmat_addr_); - if (pid == 0) { - ChildProcess(child_process); - shared_memory.Detach(); - MS_LOG(INFO) << "Model converter: child process sleep waiting for exit signal."; - while (1) { - // waiting for signal - } - } else { // parent process - ret = ParentProcess(parent_process); - shared_memory.Detach(); - - MS_LOG(INFO) << "Model converter: parent process kills child of fork."; - (void)kill(pid, SIGKILL); - constexpr uint32_t kMaxLoopCount = 5; - bool child_exited = false; - for (uint32_t i = 0; i < kMaxLoopCount; ++i) { - int status; - if (waitpid(pid, &status, WNOHANG) == pid) { - MS_LOG(INFO) << "Child process " << pid << " exits success."; - child_exited = true; - break; - } - (void)sleep(1); - } - if (!child_exited) { - MS_LOG(WARNING) << "Child process " << pid << " has been killed but waitpid failed."; - } - shared_memory.Destroy(); - } - return ret; -} - -Status MultiProcess::ParentProcess(const ProcessFuncCall &parent_process) { - auto parent_msg = reinterpret_cast(shmat_addr_); - auto child_msg = reinterpret_cast(shmat_addr_ + sizeof(MessageFlag)); - send_msg_ = parent_msg; - receive_msg_ = child_msg; - std::thread heartbeat_thread(MultiProcess::HeartbeatThreadFunc, this); - Status ret; - try { - ret = parent_process(this); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Parent process process failed"; - } - } catch (const std::runtime_error &ex) { - MS_LOG(ERROR) << "Catch parent process runtime error: " << ex.what(); - ret = kMEFailed; - } - stopped_ = true; - send_msg_->stop = 1; - heartbeat_thread.join(); - return ret; -} - -void MultiProcess::ChildProcess(const ProcessFuncCall &child_process) { - auto parent_msg = reinterpret_cast(shmat_addr_); - auto child_msg = reinterpret_cast(shmat_addr_ + sizeof(MessageFlag)); - send_msg_ = child_msg; - receive_msg_ = parent_msg; - std::thread heartbeat_thread(MultiProcess::HeartbeatThreadFunc, this); - try { - MS_EXCEPTION_IF_NULL(child_process); - auto ret = child_process(this); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Child process process failed"; - } - } catch (const std::runtime_error &ex) { - MS_LOG(ERROR) << "Catch child process runtime error: " << ex.what(); - } - stopped_ = true; - send_msg_->stop = 1; - heartbeat_thread.join(); -} - -Status MultiProcess::SendMsg(const void *buffer, uint64_t msg_len) { - MS_EXCEPTION_IF_NULL(buffer); - MS_LOG(INFO) << "Start to send message to peer process, msg len " << msg_len; - send_msg_->msg_total_len = msg_len; - uint64_t cur_offset = 0; - while (msg_len > cur_offset) { - uint64_t sub_msg_len = std::min(msg_len - cur_offset, shmat_data_max_size_); - if (sub_msg_len == 0) { - MS_LOG(ERROR) << "Invalid message len " << sub_msg_len; - return kMEFailed; - } - auto ret = - memcpy_s(shmat_data_addr_, shmat_data_max_size_, static_cast(buffer) + cur_offset, sub_msg_len); - if (ret != EOK) { - MS_LOG(ERROR) << "memcpy_s failed, ret = " << ret; - return kMEFailed; - } - cur_offset += sub_msg_len; - - send_msg_->msg_len = sub_msg_len; - send_msg_->read_finish_flag = 0; - send_msg_->read_ready_flag = 1; - MS_LOG(INFO) << "Send start " << cur_offset << ", msg len " << sub_msg_len << ", total len " << msg_len; - while (!send_msg_->read_finish_flag && !peer_stopped_) { - (void)nanosleep(&kOneMillisecond, nullptr); // 1ms - } - if (peer_stopped_) { - if (!send_msg_->read_finish_flag) { - return kMEFailed; - } - break; - } - MS_LOG(INFO) << "Send end " << cur_offset << ", msg len " << sub_msg_len << ", total len " << msg_len; - } - MS_LOG(INFO) << "End to send message to peer process, msg len " << msg_len; - return kSuccess; -} - -Status MultiProcess::ReceiveMsg(const CreateBufferCall &create_buffer_call) const { - uint64_t cur_offset = 0; - uint8_t *msg_buffer = nullptr; - uint64_t msg_len = 0; - do { - MS_LOG(INFO) << "Receive start from " << cur_offset; - while (!receive_msg_->read_ready_flag && !peer_stopped_) { - (void)nanosleep(&kOneMillisecond, nullptr); // 1ms - } - if (peer_stopped_) { - return kMEFailed; - } - if (msg_buffer == nullptr) { - msg_len = receive_msg_->msg_total_len; - msg_buffer = create_buffer_call(msg_len); - } - MS_EXCEPTION_IF_NULL(msg_buffer); - size_t dest_max = std::min(shmat_data_max_size_, msg_len - cur_offset); - auto ret = memcpy_s(msg_buffer + cur_offset, dest_max, shmat_data_addr_, receive_msg_->msg_len); - if (ret != EOK) { - MS_LOG(INFO) << "memcpy_s failed, ret = " << ret; - return kMEFailed; - } - cur_offset += receive_msg_->msg_len; - receive_msg_->read_ready_flag = 0; - receive_msg_->read_finish_flag = 1; - MS_LOG(INFO) << "Receive end, current length " << cur_offset << ", total length " << msg_len << std::endl; - } while (msg_len > cur_offset); - return kSuccess; -} - -void MultiProcess::HeartbeatThreadFunc(MultiProcess *multi_process) { multi_process->HeartbeatThreadFuncInner(); } - -void MultiProcess::HeartbeatThreadFuncInner() { - constexpr uint64_t kOvertime = 1024; - uint64_t last_beat_cnt = 0; - uint64_t repeat_cnt = 0; - while (!stopped_) { - if (receive_msg_->stop) { - peer_stopped_ = true; - MS_LOG(WARNING) << "Peer stopped"; - break; - } - uint64_t heartbeat_gap = receive_msg_->heartbeat - last_beat_cnt; - if (heartbeat_gap > 0 && heartbeat_gap < kOvertime) { - last_beat_cnt = receive_msg_->heartbeat; - repeat_cnt = 0; - } else { - repeat_cnt++; - if (repeat_cnt > 30) { // 30*100ms = 3s no reply - peer_stopped_ = true; - MS_LOG(WARNING) << "Peer stopped"; - break; - } - } - send_msg_->heartbeat += 1; - (void)nanosleep(&kOneHundredMilliseconds, nullptr); // sleep 100 ms - } -} -} // namespace mindspore diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/model_converter_utils/shared_memory.cc b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/model_converter_utils/shared_memory.cc deleted file mode 100644 index efbbeb1bcd3ca84bbccd97421f47f260af0706c3..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/model_converter_utils/shared_memory.cc +++ /dev/null @@ -1,65 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "cxx_api/model/model_converter_utils/shared_memory.h" -#include -#include -#include -#include "src/common/log_adapter.h" - -namespace mindspore { -Status SharedMemory::Create(uint64_t memory_size) { - auto access_mode = S_IRUSR | S_IWUSR | S_IROTH | S_IWOTH | S_IRGRP | S_IWGRP; - shm_id_ = shmget(IPC_PRIVATE, memory_size, IPC_CREAT | IPC_EXCL | access_mode); - if (shm_id_ == -1) { - MS_LOG(ERROR) << "Shared memory creation failed. Errno " + std::to_string(errno); - return kMCFailed; - } - MS_LOG(INFO) << "shmget success, shm id " << shm_id_; - return kSuccess; -} - -Status SharedMemory::Attach() { - void *shmat_addr = shmat(shm_id_, nullptr, 0); - if (shmat_addr == reinterpret_cast(-1)) { - MS_LOG(ERROR) << "Shared memory attach failed. Errno " + std::to_string(errno); - return kMCFailed; - } - shmat_addr_ = reinterpret_cast(shmat_addr); - return kSuccess; -} - -void SharedMemory::Detach() { - if (shmat_addr_ != nullptr) { - auto err = shmdt(shmat_addr_); - if (err == -1) { - MS_LOG(ERROR) << "Shared memory detach failed. Errno " + std::to_string(errno); - return; - } - } - shmat_addr_ = nullptr; -} - -void SharedMemory::Destroy() const { - // Remove the shared memory and never mind about the return code. - auto err = shmctl(shm_id_, IPC_RMID, nullptr); - if (err == -1) { - std::string errMsg = "Unable to remove shared memory with id " + std::to_string(shm_id_); - errMsg += ". Errno :" + std::to_string(errno); - errMsg += "\nPlesae remove it manually using ipcrm -m command"; - MS_LOG(ERROR) << errMsg; - } -} -} // namespace mindspore diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/model_impl.cc b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/model_impl.cc deleted file mode 100644 index 742de94130eb3bbb281cd72bf288e49a262921ba..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/model_impl.cc +++ /dev/null @@ -1,140 +0,0 @@ -/** - * Copyright 2020-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "cxx_api/model/model_impl.h" -#include -#include "cxx_api/dlutils.h" - -namespace mindspore { -Status ModelImpl::Predict(const std::vector &inputs, std::vector *outputs) { - MS_EXCEPTION_IF_NULL(outputs); - if (graph_ == nullptr) { - MS_LOG(ERROR) << "Invalid data, graph_ is null."; - return kMCFailed; - } - - if (graph_cell_ == nullptr) { - MS_LOG(WARNING) << "Model has not been built, it will be built with default options"; - Status ret = Build(); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Build model failed."; - return ret; - } - } - - MS_EXCEPTION_IF_NULL(graph_cell_); - Status ret = graph_cell_->Run(inputs, outputs); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Run graph failed."; - return ret; - } - - return kSuccess; -} - -bool ModelImpl::HasPreprocess() { return graph_->graph_data_->GetPreprocess().empty() ? false : true; } - -Status ModelImpl::Preprocess(const std::vector> &inputs, std::vector *outputs) { -#if !defined(_WIN32) && !defined(_WIN64) - // Config preprocessor, temporary way to let mindspore.so depends on _c_dataengine - std::string dataengine_so_path; - Status dlret = DLSoPath(&dataengine_so_path); - CHECK_FAIL_AND_RELEASE(dlret, nullptr, "Parse dataengine_so failed: " + dlret.GetErrDescription()); - - // Run preprocess - if (!HasPreprocess()) { - MS_LOG(ERROR) << "Attempt to predict with data preprocessor, but no preprocessor is defined in MindIR."; - return Status(kMEFailed, "Attempt to predict with data preprocessor, but no preprocessor is defined in MindIR."); - } - - void *handle = nullptr; - void *function = nullptr; - dlret = DLSoOpen(dataengine_so_path, "ExecuteRun_C", &handle, &function); - CHECK_FAIL_AND_RELEASE(dlret, handle, "Parse ExecuteRun_C failed: " + dlret.GetErrDescription()); - auto ExecuteRun = - (void (*)(const std::vector> &, const std::vector &, - std::vector *, Status *))(function); - - // perform preprocess on each tensor separately - std::vector> preprocessor = graph_->graph_data_->GetPreprocess(); - std::vector> output_unbatch; - std::vector output_batched; - for (auto tensor : inputs) { - std::vector temp; - ExecuteRun(preprocessor, tensor, &temp, &dlret); - CHECK_FAIL_AND_RELEASE(dlret, handle, "Run preprocess failed: " + dlret.GetErrDescription()); - output_unbatch.push_back(temp); - } - - // Construct a tensor with batch dim - output_batched.resize(output_unbatch[0].size()); - for (size_t i = 0; i < output_batched.size(); i++) { - std::vector ori_shape = output_unbatch[0][i].Shape(); - ori_shape.insert(ori_shape.begin(), output_unbatch.size()); - output_batched[i] = mindspore::MSTensor("outputs", output_unbatch[0][i].DataType(), ori_shape, nullptr, - output_unbatch[0][i].DataSize() * output_unbatch.size()); - } - - // Copy unbatch data into tensor - for (size_t i = 0; i < output_unbatch[0].size(); i++) { - size_t offset = 0; - for (size_t j = 0; j < output_unbatch.size(); j++) { - auto ret = - memcpy_s(reinterpret_cast(output_batched[i].MutableData()) + offset, - output_unbatch[j][i].DataSize(), output_unbatch[j][i].MutableData(), output_unbatch[j][i].DataSize()); - if (ret) { - MS_LOG(ERROR) << "Memory copy failed to construct High-Dim Tensor."; - return Status(kMEFailed, "Memory copy failed to construct High-Dim Tensor."); - } - offset += output_unbatch[j][i].DataSize(); - } - } - *outputs = output_batched; - DLSoClose(handle); - return kSuccess; -#else - MS_LOG(ERROR) << "Data preprocess is not supported on Windows yet."; - return Status(kMEFailed, "Data preprocess is not supported on Windows yet."); -#endif -} - -Status ModelImpl::PredictWithPreprocess(const std::vector> &inputs, - std::vector *outputs) { -#if !defined(_WIN32) && !defined(_WIN64) - // Run preprocess - std::vector preprocess_outputs; - Status ret = Preprocess(inputs, &preprocess_outputs); - if (ret != kSuccess) { - return ret; - } - - // Run prediction - ret = Predict(preprocess_outputs, outputs); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Run predict failed: " << ret.GetErrDescription(); - return ret; - } - return kSuccess; -#else - MS_LOG(ERROR) << "Predict with data preprocess is not supported on Windows yet."; - return Status(kMEFailed, "Predict with data preprocess is not supported on Windows yet."); -#endif -} - -Status ModelImpl::Finalize() { - MS_LOG(ERROR) << "Finalize is only support for mindspore_lite's ascend backend."; - return kLiteError; -} -} // namespace mindspore diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/model_impl.h b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/model_impl.h deleted file mode 100644 index d80ba5753f097e27eae821758881c37391f4bbb3..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/model_impl.h +++ /dev/null @@ -1,79 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_CXX_API_MODEL_MODEL_IMPL_H -#define MINDSPORE_CCSRC_CXX_API_MODEL_MODEL_IMPL_H -#include -#include -#include -#include "include/api/context.h" -#include "include/api/model.h" -#include "include/api/graph.h" -#include "cxx_api/graph/graph_data.h" -#include "include/common/utils/utils.h" -#include "ir/func_graph.h" - -namespace mindspore { -class MS_API ModelImpl { - public: - ModelImpl() = default; - virtual ~ModelImpl() = default; - - virtual Status Build() = 0; - virtual Status Resize(const std::vector &inputs, const std::vector> &dims) = 0; - - virtual Status Predict(const std::vector &inputs, std::vector *outputs); - - virtual Status PredictWithPreprocess(const std::vector> &inputs, - std::vector *outputs); - - virtual std::vector GetInputs() = 0; - virtual std::vector GetOutputs() = 0; - - virtual bool CheckDeviceSupport(mindspore::DeviceType device_type) = 0; - virtual bool CheckModelSupport(enum ModelType model_type) = 0; - - virtual Status Preprocess(const std::vector> &inputs, std::vector *outputs); - - virtual bool HasPreprocess(); - - Status Finalize(); - - protected: - FuncGraphPtr GetFuncGraph() const { - if (graph_->ModelType() != ModelType::kMindIR) { - return nullptr; - } - - auto graph_data = graph_->graph_data_; - MS_EXCEPTION_IF_NULL(graph_data); - return graph_data->GetFuncGraph(); - } - - std::shared_ptr graph_ = nullptr; - std::shared_ptr graph_cell_ = nullptr; - std::shared_ptr model_context_ = nullptr; - - private: - friend class Model; - void SetGraph(const std::shared_ptr &graph) { graph_ = graph; } - void SetContext(const std::shared_ptr &model_context) { - if (model_context != nullptr) { - model_context_ = std::make_shared(*model_context); - } - } -}; -} // namespace mindspore -#endif // MINDSPORE_CCSRC_CXX_API_MODEL_MODEL_IMPL_H diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/ms/ms_model.cc b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/ms/ms_model.cc deleted file mode 100644 index 0161b20b9d6ff8206b847c0f216bcce0b2b342c2..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/ms/ms_model.cc +++ /dev/null @@ -1,208 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "cxx_api/model/ms/ms_model.h" -#include -#include -#include -#include "include/api/context.h" -#include "utils/ms_context.h" -#include "cxx_api/factory.h" -#include "runtime/hardware/device_context_manager.h" - -namespace mindspore { -namespace { -bool CheckIsAscend910Soc() { - const char *soc_name_c = CALL_ASCEND_API(aclrtGetSocName); - if (soc_name_c == nullptr) { - return false; - } - std::string soc_name(soc_name_c); - if (soc_name.find("910") == std::string::npos) { - return false; - } - return true; -} -} // namespace -// mindspore-serving check current package for version check with ModelImpl factory. -API_MODEL_REG(kMS, MsModel); - -static std::string GenerateShapeKey(const std::vector> &dims) { - std::string shape_key; - for (size_t i = 0; i < dims.size(); ++i) { - shape_key += std::to_string(i) + ":"; - for (size_t j = 0; j < dims[i].size(); ++j) { - shape_key += std::to_string(dims[i][j]); - if (j + 1 < dims[i].size()) { - shape_key += ","; - } - } - if (i + 1 < dims.size()) { - shape_key += ";"; - } - } - return shape_key; -} - -std::shared_ptr MsModel::GenerateGraphCell(const std::vector> &dims) { - std::string shape_key = GenerateShapeKey(dims); - if (auto iter = dynamic_size_graph_map_.find(shape_key); iter != dynamic_size_graph_map_.end()) { - MS_LOG(INFO) << "This options has been built, read cache."; - return iter->second; - } - - auto func_graph = ModelImpl::GetFuncGraph(); - MS_EXCEPTION_IF_NULL(func_graph); - - const auto &inputs = func_graph->parameters(); - if (dims.size() != inputs.size()) { - MS_LOG(ERROR) << "Invalid dims size " << dims.size() << " not match model inputs size " << inputs.size(); - return nullptr; - } - for (size_t i = 0; i < dims.size(); ++i) { - const auto ¶m = inputs[i]; - auto shape_ptr = std::dynamic_pointer_cast(param->Shape()); - if (shape_ptr == nullptr) { - MS_LOG(ERROR) << "Inputs " << i << " is not supported to resize, debug string: " << param->DebugString(); - return nullptr; - } - shape_ptr->set_shape(dims[i]); - } - - auto graph = std::make_shared(std::make_shared(func_graph, ModelType::kMindIR)); - MS_EXCEPTION_IF_NULL(graph); - auto graph_cell = std::make_shared(graph); - MS_EXCEPTION_IF_NULL(graph_cell); - graph_cell->SetContext(model_context_); - auto ret = graph_cell->Load(GetDeviceID()); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Load failed."; - return nullptr; - } - dynamic_size_graph_map_[shape_key] = graph_cell; - return graph_cell; -} - -Status MsModel::Build() { - MS_LOG(INFO) << "Start build model."; - MS_EXCEPTION_IF_NULL(graph_); - - if (graph_cell_ != nullptr) { - MS_LOG(INFO) << "This model has been built, skip."; - return kSuccess; - } - - auto func_graph = ModelImpl::GetFuncGraph(); - MS_EXCEPTION_IF_NULL(func_graph); - - auto graph = std::make_shared(std::make_shared(func_graph, ModelType::kMindIR)); - MS_EXCEPTION_IF_NULL(graph); - auto graph_cell = std::make_shared(graph); - MS_EXCEPTION_IF_NULL(graph_cell); - graph_cell->SetContext(model_context_); - auto ret = graph_cell->Load(GetDeviceID()); - if (ret != kSuccess) { - MS_LOG(ERROR) << "Load failed."; - return ret; - } - - // save result - graph_cell_ = graph_cell; - MS_LOG(INFO) << "Build model success."; - return kSuccess; -} - -Status MsModel::Resize(const std::vector &inputs, const std::vector> &dims) { - MS_LOG(INFO) << "Start to resize model"; - auto origin_inputs = GetInputs(); - if (inputs.size() != origin_inputs.size()) { - MS_LOG(ERROR) << "Invalid inputs size " << inputs.size() << " not match model inputs size " << origin_inputs.size(); - return kMCInvalidInput; - } - - if (inputs.size() != dims.size()) { - MS_LOG(ERROR) << "Invalid dims size " << dims.size() << " not match inputs size " << inputs.size(); - return kMCInvalidInput; - } - - auto graph_cell = GenerateGraphCell(dims); - if (graph_cell == nullptr) { - MS_LOG(ERROR) << "GenerateGraphCell failed."; - return kMCFailed; - } - - MS_LOG(INFO) << "Resize model success."; - graph_cell_ = std::move(graph_cell); - return kSuccess; -} - -std::vector MsModel::GetInputs() { - MS_EXCEPTION_IF_NULL(graph_cell_); - return graph_cell_->GetInputs(); -} - -std::vector MsModel::GetOutputs() { - MS_EXCEPTION_IF_NULL(graph_cell_); - return graph_cell_->GetOutputs(); -} - -uint32_t MsModel::GetDeviceID() const { - if (model_context_ == nullptr) { - return 0; - } - - auto &device_infos = model_context_->MutableDeviceInfo(); - if (device_infos.size() != 1) { - return 0; - } - - auto ascend910_info = device_infos[0]->Cast(); - if (ascend910_info != nullptr) { - return ascend910_info->GetDeviceID(); - } - - auto gpu_info = device_infos[0]->Cast(); - if (gpu_info != nullptr) { - return gpu_info->GetDeviceID(); - } - - return 0; -} - -bool MsModel::CheckDeviceSupport(enum DeviceType device_type) { - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - auto device_target = ms_context->get_param(MS_CTX_DEVICE_TARGET); - if (device_target == kAscendDevice) { - if (device_type != kAscend && device_type != kAscend910) { - return false; - } - return CheckIsAscend910Soc(); - } - if (device_type != kGPU) { - return false; - } - return true; -} - -bool MsModel::CheckModelSupport(mindspore::ModelType model_type) { - static const std::set kSupportedModelMap = {kMindIR}; - auto iter = kSupportedModelMap.find(model_type); - if (iter == kSupportedModelMap.end()) { - return false; - } - return true; -} -} // namespace mindspore diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/clip_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/clip_mapper.cc index b825ed5b326c06860453bc334028ab529aad8a4f..0e72ef4719252d4ad870f63dc9de4b03564ef073 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/clip_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/clip_mapper.cc @@ -34,7 +34,7 @@ const size_t kNumInputIndex3 = 3; const size_t kNumInputSize3 = 3; } // namespace STATUS ClipMapper::Mapper(const CNodePtr &cnode) { - CHECK_NULL_RETURN(cnode); + MS_ASSERT(cnode != nullptr); auto func_graph = cnode->func_graph(); CHECK_NULL_RETURN(func_graph); auto prim = ops::GetOperator(cnode->input(0)); diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/gather_d_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/gather_d_mapper.cc index 72b69e3ea8de89dd2c8e4dd3c94d62d7e9c74a5a..058bc8a54857c28b58fa4e20c755231d464634b9 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/gather_d_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/gather_d_mapper.cc @@ -30,7 +30,7 @@ STATUS GetParameterDim(const AnfNodePtr &cnode, int64_t *result_dim) { MS_LOG(WARNING) << "result dim is nullptr."; return RET_NULL_PTR; } - MS_CHECK_TRUE_RET(cnode->cast()->inputs().size() > THIRD_INPUT, RET_NULL_PTR); + auto dim_param = cnode->cast()->input(THIRD_INPUT)->cast()->default_param(); if (dim_param == nullptr) { MS_LOG(WARNING) << "dim_param is nullptr."; diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/im2col_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/im2col_mapper.cc index 93ae694dd93bc6806014f56a7e46faa3fb8286ed..9881d226bafc68c120fc81d436f211a430d98652 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/im2col_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/im2col_mapper.cc @@ -38,7 +38,6 @@ STATUS Im2ColMapper::Mapper(const CNodePtr &cnode) { // make dst prim auto dst_prim = std::make_shared(); - CHECK_NULL_RETURN(dst_prim); dst_prim->SetAttrs(src_prim->attrs()); value_node->set_value(dst_prim); diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/lp_normalization_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/lp_normalization_mapper.cc index 7059d42437d6d157c3e40144805da85ddaec6709..ad4bb9953358f3c5eb4eab8036474a560669f817 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/lp_normalization_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/lp_normalization_mapper.cc @@ -30,7 +30,6 @@ namespace { constexpr size_t kNameLpNormInputNum = 1; } // namespace STATUS LpNormalizationMapper::Mapper(const CNodePtr &cnode) { - CHECK_NULL_RETURN(cnode); /* * input1 input * | | @@ -49,12 +48,6 @@ STATUS LpNormalizationMapper::Mapper(const CNodePtr &cnode) { MS_LOG(ERROR) << "Get primitive from cnode failed."; return lite::RET_ERROR; } - CHECK_NULL_RETURN(value_node); - CHECK_NULL_RETURN(src_prim); - if (cnode->inputs().size() < kNameLpNormInputNum) { - MS_LOG(ERROR) << "input size is less than " << kNameLpNormInputNum << ", input size is " << cnode->inputs().size(); - return RET_ERROR; - } auto origin_input = cnode->inputs()[kNameLpNormInputNum]; auto func_graph = cnode->func_graph(); CHECK_NULL_RETURN(func_graph); @@ -81,7 +74,6 @@ STATUS LpNormalizationMapper::Mapper(const CNodePtr &cnode) { MS_LOG(ERROR) << "Failed to get func graph manager from cnode " << cnode->fullname_with_scope(); return RET_ERROR; } - CHECK_NULL_RETURN(cnode->abstract()); auto new_lpnorm_node = NewCNode(cnode, dst_prim, {origin_input}, cnode->abstract()->Clone(), cnode->fullname_with_scope() + "_LpNorm"); if (new_lpnorm_node == nullptr) { diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/matmul_fusion_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/matmul_fusion_mapper.cc index c7a9807c447fbd6580f220d67eb2645d6d67b404..e7b3974b768d345aebc83bb265cb4f93cc0d7006 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/matmul_fusion_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/matmul_fusion_mapper.cc @@ -271,7 +271,6 @@ void MatMulFusionMapper::SetMatMulTransposeAttr(const PrimitivePtr &src_prim, co } STATUS MatMulFusionMapper::Mapper(const CNodePtr &cnode) { - CHECK_NULL_RETURN(cnode); auto quant_holder = GetCNodeQuantHolder(cnode); MS_CHECK_TRUE_MSG(quant_holder != nullptr, RET_NULL_PTR, "quant holder is nullptr."); auto cnode_primitive = GetValueNode(cnode->input(0)); @@ -280,7 +279,6 @@ STATUS MatMulFusionMapper::Mapper(const CNodePtr &cnode) { return QuantMapper(cnode); } else if (cnode_primitive->HasAttr(quant::kQuantType)) { auto quant_type_attr = cnode_primitive->GetAttr(quant::kQuantType); - CHECK_NULL_RETURN(quant_type_attr); auto quant_type = static_cast(GetValue(quant_type_attr)); if (quant_type != quant::QUANT_NONE) { return QuantMapper(cnode); @@ -297,8 +295,6 @@ STATUS MatMulFusionMapper::Mapper(const CNodePtr &cnode) { MS_LOG(ERROR) << "Get primitive from cnode failed."; return lite::RET_ERROR; } - CHECK_NULL_RETURN(value_node); - CHECK_NULL_RETURN(src_prim); SetMatMulTransposeAttr(src_prim, src_prim); return RET_OK; } @@ -312,8 +308,6 @@ STATUS MatMulFusionMapper::Mapper(const CNodePtr &cnode) { MS_LOG(ERROR) << "Get primitive from cnode failed."; return lite::RET_ERROR; } - CHECK_NULL_RETURN(value_node); - CHECK_NULL_RETURN(src_prim); std::vector shape_vector; if (acl::GetShapeVectorFromCNode(cnode, &shape_vector) != RET_OK) { MS_LOG(ERROR) << "Get cnode shape failed, cnode " << cnode->fullname_with_scope(); diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/onehot_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/onehot_mapper.cc index ce89b4d6b9919345cbb46e0465b28e2e402af86e..fdc79a73477b575b6d526515dcf5c5e6aee956f1 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/onehot_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/onehot_mapper.cc @@ -29,8 +29,7 @@ namespace lite { namespace { STATUS GetAttrAxis(const AnfNodePtr &cnode, int64_t *result_axis) { MS_CHECK_TRUE_RET(cnode != nullptr, RET_NULL_PTR); - MS_CHECK_TRUE_RET(cnode->cast() != nullptr, RET_NULL_PTR); - MS_CHECK_TRUE_RET(cnode->cast()->inputs().size() > 0, RET_NULL_PTR); + auto onehot_node = ops::GetOperator(cnode->cast()->input(0)); MS_CHECK_TRUE_RET(onehot_node != nullptr, RET_ERROR); diff --git a/mindspore-lite/tools/converter/adapter/acl/mapper/prelu_fusion_mapper.cc b/mindspore-lite/tools/converter/adapter/acl/mapper/prelu_fusion_mapper.cc index 1d1be39b184ff6ca420039bf2ccf402919164c58..b341f26734e57b0050f79cccc594ed5a9cfbe33a 100644 --- a/mindspore-lite/tools/converter/adapter/acl/mapper/prelu_fusion_mapper.cc +++ b/mindspore-lite/tools/converter/adapter/acl/mapper/prelu_fusion_mapper.cc @@ -27,7 +27,6 @@ STATUS PReluFusionMapper::Mapper(const CNodePtr &cnode) { CHECK_NULL_RETURN(cnode); ops::PReLU prelu_op; auto dst_prim = prelu_op.GetPrim(); - CHECK_NULL_RETURN(dst_prim); if (MoveAttrMap(cnode, dst_prim) != RET_OK) { MS_LOG(ERROR) << "PReluFusion mapper failed."; return RET_ERROR; diff --git a/mindspore-lite/tools/converter/adapter/acl/src/acl_pass_impl.cc b/mindspore-lite/tools/converter/adapter/acl/src/acl_pass_impl.cc index f1752372f6952403f414d48e9592589aef456e48..8f00fe9beb7fb15725230ffe0b8dd86a11cf772e 100644 --- a/mindspore-lite/tools/converter/adapter/acl/src/acl_pass_impl.cc +++ b/mindspore-lite/tools/converter/adapter/acl/src/acl_pass_impl.cc @@ -63,7 +63,6 @@ #include "tools/converter/quantizer/insert_quant_node_manager.h" #include "tools/converter/parser/unify_format.h" #include "tools/converter/adapter/acl/src/acl_custom_opp_installer.h" -#include "tools/graph_kernel/converter/graph_kernel_optimization.h" #include "tools/lite_exporter/fetch_content.h" #include "tools/converter/quantizer/quant_helper/ascend_distribute_fake_quant_transform.h" #include "tools/converter/quantizer/quant_helper/ffn_full_quant.h" @@ -1304,15 +1303,6 @@ bool AclPassImpl::Run(const FuncGraphPtr &func_graph) { MS_LOG(ERROR) << "Set acl model options error!"; return false; } -#ifdef MSLITE_ENABLE_GRAPH_KERNEL - auto soc_version = this->options_->GetSocVersion(); - param_->device = soc_version; - if (GraphKernelOptimize(func_graph, param_) != lite::RET_OK) { - MS_LOG(ERROR) << "Run graphkernel optimization failed."; - return false; - } -#endif - if (PostProcCustomOp(func_graph) != lite::RET_OK) { MS_LOG(ERROR) << "Post proc CustomOp failed."; return false; diff --git a/mindspore-lite/tools/converter/anf_transform.cc b/mindspore-lite/tools/converter/anf_transform.cc index 1501056291fd028f8765c587f1431f37f6222b1c..325342c9ef33fc8ff8c424405703a3a0466c4704 100644 --- a/mindspore-lite/tools/converter/anf_transform.cc +++ b/mindspore-lite/tools/converter/anf_transform.cc @@ -116,7 +116,6 @@ #include "src/common/log_util.h" #include "src/common/string_utils.h" #include "src/common/config_infos.h" -#include "tools/graph_kernel/converter/graph_kernel_optimization.h" #include "tools/optimizer/fusion/groupnorm_fusion.h" #include "tools/optimizer/fusion/mul_reduce_fusion.h" #include "tools/optimizer/fusion/reshape_like_operator_ablation.h" diff --git a/mindspore-lite/tools/converter/config_parser/CMakeLists.txt b/mindspore-lite/tools/converter/config_parser/CMakeLists.txt index 69f4976826c7530e50149905685747944925e4d0..2495b2934f0ced5a6ff6159f3c6f5c71bc89b795 100644 --- a/mindspore-lite/tools/converter/config_parser/CMakeLists.txt +++ b/mindspore-lite/tools/converter/config_parser/CMakeLists.txt @@ -1,7 +1,7 @@ file(GLOB_RECURSE CONFIG_PARSER_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} *.cc ) -set_property(SOURCE ${CONFIG_PARSER_SRC_LIST} PROPERTY COMPILE_DEFINITIONS +set_property(SOURCE ${DECOMPOSER} PROPERTY COMPILE_DEFINITIONS LOG_HDR_FILE_REL_PATH="mindspore-lite/../mindspore/mindspore/core/include/utils/log_adapter.h" SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) add_library(config_parser_mid OBJECT diff --git a/mindspore-lite/tools/converter/converter.cc b/mindspore-lite/tools/converter/converter.cc index 5d8291505171d2cab19528d69225ecbea3a129d2..321d079108de59eb2e60aba721a6ddb8e86ef1bc 100644 --- a/mindspore-lite/tools/converter/converter.cc +++ b/mindspore-lite/tools/converter/converter.cc @@ -26,7 +26,6 @@ #include "src/common/log_adapter.h" #include "tools/common/meta_graph_serializer.h" #include "tools/lite_exporter/anf_exporter.h" -#include "tools/graph_kernel/converter/graph_kernel_optimization.h" #ifdef SUPPORT_TRAIN #include "src/train/train_populate_parameter.h" #endif diff --git a/mindspore-lite/tools/converter/converter_funcgraph.cc b/mindspore-lite/tools/converter/converter_funcgraph.cc index a62d6d866a4332c751874164cf0f8954122aa767..abbf95e3c9498f505ea5c03fdfe74ebe1f6ec59e 100644 --- a/mindspore-lite/tools/converter/converter_funcgraph.cc +++ b/mindspore-lite/tools/converter/converter_funcgraph.cc @@ -25,7 +25,6 @@ #include "src/common/log_adapter.h" #include "tools/common/meta_graph_serializer.h" #include "tools/lite_exporter/anf_exporter.h" -#include "tools/graph_kernel/converter/graph_kernel_optimization.h" #ifdef SUPPORT_TRAIN #include "src/train/train_populate_parameter.h" #endif diff --git a/mindspore-lite/tools/converter/converter_lite/CMakeLists.txt b/mindspore-lite/tools/converter/converter_lite/CMakeLists.txt index f28537826875b03460e7a9e035173db95fe07f4f..d11b8bfbef39ca78a1f4dfb04435e713eee61ff7 100644 --- a/mindspore-lite/tools/converter/converter_lite/CMakeLists.txt +++ b/mindspore-lite/tools/converter/converter_lite/CMakeLists.txt @@ -5,9 +5,8 @@ link_directories(${opencv_INC}/../lib) add_executable(converter_lite main.cc converter_flags.cc ${TOP_DIR}/mindspore-lite/src/common/log.cc ${TOP_DIR}/mindspore-lite/src/common/utils.cc - ${TOP_DIR}/mindspore/mindspore/core/utils/status.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../common/flag_parser.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../common/string_util.cc ${TOP_DIR}/mindspore-lite/src/common/file_utils.cc) -target_link_libraries(converter_lite mindspore_converter) +target_link_libraries(converter_lite mindspore_converter mindspore_core) diff --git a/mindspore-lite/tools/converter/decomposer/CMakeLists.txt b/mindspore-lite/tools/converter/decomposer/CMakeLists.txt index 7e1c1a1018af2fc7d75ec1a4ff4f19ecdc6727b2..c376bef65128d4302caca6788e4320777c761679 100644 --- a/mindspore-lite/tools/converter/decomposer/CMakeLists.txt +++ b/mindspore-lite/tools/converter/decomposer/CMakeLists.txt @@ -1,9 +1,7 @@ file(GLOB DECOMPOSER ${CMAKE_CURRENT_SOURCE_DIR}/*.cc ) -set_property(SOURCE ${DECOMPOSER} PROPERTY COMPILE_DEFINITIONS - LOG_HDR_FILE_REL_PATH="mindspore-lite/../mindspore/mindspore/core/include/utils/log_adapter.h" - SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) +set_property(SOURCE ${DECOMPOSER} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) if(NOT MSLITE_SIMPLEST_CLOUD_INFERENCE) add_library(decomposer_mid OBJECT ${DECOMPOSER}) diff --git a/mindspore-lite/tools/converter/parser/einsum_adjust.cc b/mindspore-lite/tools/converter/parser/einsum_adjust.cc index d7ea8c309344cbbb17faf4eab1dac2c56ed7070b..7c5340876d6dcd6f5eba335d6b3e3481233a589c 100644 --- a/mindspore-lite/tools/converter/parser/einsum_adjust.cc +++ b/mindspore-lite/tools/converter/parser/einsum_adjust.cc @@ -48,9 +48,9 @@ lite::STATUS CheckSubdims(const std::string &first_subdims, const std::string &s min_dim = min_dim < output_subdims.length() ? min_dim : output_subdims.length(); auto max_subdims = first_subdims.length() > second_subdims.length() ? first_subdims : second_subdims; if (first_subdims.substr(first_subdims.length() - min_dim) != - second_subdims.substr(second_subdims.length() - min_dim) || + second_subdims.substr(second_subdims.length() - min_dim) || first_subdims.substr(first_subdims.length() - min_dim) != - output_subdims.substr(output_subdims.length() - min_dim) || + output_subdims.substr(output_subdims.length() - min_dim) || max_subdims.substr(0, 1) != output_subdims.substr(0, 1)) { return RET_ERROR; } diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_model_parser.cc index 3366376a29b3c70ee082f97a0f155a06a8f18752..d57201bab73b0ad20dc191fe1e7b59b1babe31aa 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_model_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_model_parser.cc @@ -203,15 +203,15 @@ STATUS AddIterNumsUpdateEdge(const FuncGraphPtr &anf_graph, std::vector *> &control_nodes_map) { - auto iter1 = control_nodes_map.find(loop_node_name); - if (iter1 == control_nodes_map.end()) { - return nullptr; - } // namespace - auto iter2 = iter1->second->find(loop_node_name); - if (iter2 == iter1->second->end()) { - return nullptr; - } - return iter2->second->cast(); +auto iter1 = control_nodes_map.find(loop_node_name); +if (iter1 == control_nodes_map.end()) { +return nullptr; +} // namespace +auto iter2 = iter1->second->find(loop_node_name); +if (iter2 == iter1->second->end()) { +return nullptr; +} +return iter2->second->cast(); } STATUS BuildReturnNode(const FuncGraphPtr &anf_graph, const std::vector &return_inputs) { diff --git a/mindspore-lite/tools/converter/parser/tf/CMakeLists.txt b/mindspore-lite/tools/converter/parser/tf/CMakeLists.txt index 6d64fd17bbede3f9a07b8d384b4ea58af6b2a7e7..3d28fcef03a4e0c17461d5b7db866a3d6e4e4946 100644 --- a/mindspore-lite/tools/converter/parser/tf/CMakeLists.txt +++ b/mindspore-lite/tools/converter/parser/tf/CMakeLists.txt @@ -9,7 +9,6 @@ endif() set_property(SOURCE ${TF_SRC_LIST} PROPERTY COMPILE_DEFINITIONS LOG_HDR_FILE_REL_PATH="mindspore-lite/../mindspore/mindspore/core/include/utils/log_adapter.h" SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) - add_library(tf_parser_mid OBJECT ${TF_SRC_LIST}) add_dependencies(tf_parser_mid proto_mid) diff --git a/mindspore-lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore-lite/tools/converter/parser/tflite/tflite_model_parser.cc index 5f9c57d660981100393b2b4dd5e4f34ec4aa6c3c..e72dbaf03298b0cde9f1bc5056274efe65d6b56f 100644 --- a/mindspore-lite/tools/converter/parser/tflite/tflite_model_parser.cc +++ b/mindspore-lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -81,14 +81,14 @@ STATUS TfliteModelParser::TfliteOpVerify(const std::unique_ptrinputs.begin(), op->inputs.end(), [&all_tensor_num](int32_t index) { - return index >= all_tensor_num || index + all_tensor_num < 0; - })) { + return index >= all_tensor_num || index + all_tensor_num < 0; + })) { MS_LOG(ERROR) << "op input illegal."; return RET_ERROR; } if (std::any_of(op->outputs.begin(), op->outputs.end(), [&all_tensor_num](int32_t index) { - return index >= all_tensor_num || index + all_tensor_num < 0; - })) { + return index >= all_tensor_num || index + all_tensor_num < 0; + })) { MS_LOG(ERROR) << "op output illegal."; return RET_ERROR; } @@ -141,14 +141,14 @@ STATUS TfliteModelParser::TfliteModelVerify() { return RET_ERROR; } if (std::any_of(subgraph->inputs.begin(), subgraph->inputs.end(), [&all_subgraph_tensor_size](int32_t index) { - return index >= static_cast(all_subgraph_tensor_size) || index < 0; - })) { + return index >= static_cast(all_subgraph_tensor_size) || index < 0; + })) { MS_LOG(ERROR) << "tflite input illegal."; return RET_ERROR; } if (std::any_of(subgraph->outputs.begin(), subgraph->outputs.end(), [&all_subgraph_tensor_size](int32_t index) { - return index >= static_cast(all_subgraph_tensor_size) || index < 0; - })) { + return index >= static_cast(all_subgraph_tensor_size) || index < 0; + })) { MS_LOG(ERROR) << "tflite output illegal."; return RET_ERROR; } @@ -594,8 +594,8 @@ STATUS TfliteModelParser::ConvertGraphOutputs(const std::unique_ptroutputs.front() < 0 - ? static_cast(tflite_subgraph->outputs.front() + tflite_subgraph->tensors.size()) - : static_cast(tflite_subgraph->outputs.front()); + ? static_cast(tflite_subgraph->outputs.front() + tflite_subgraph->tensors.size()) + : static_cast(tflite_subgraph->outputs.front()); auto return_prim_c = returnPrim->GetPrim(); MSLITE_CHECK_PTR(return_prim_c); auto value_node = NewValueNode(return_prim_c); diff --git a/mindspore-lite/tools/converter/quantizer/insert_quant_node_manager.cc b/mindspore-lite/tools/converter/quantizer/insert_quant_node_manager.cc index 393b567670715d6898fe039ddeb49382190e8081..9d981bc2558ebc6d77320873ed50a82322435ed9 100644 --- a/mindspore-lite/tools/converter/quantizer/insert_quant_node_manager.cc +++ b/mindspore-lite/tools/converter/quantizer/insert_quant_node_manager.cc @@ -759,7 +759,7 @@ int InsertQuantNodeManager::CalculateScaleZPNode(const FuncGraphPtr &func_graph, scales.push_back(static_cast(input_quant_params.at(i).scale * input_quant_params.at(i).varCorr)); zps.push_back(static_cast(-input_quant_params.at(i).zeroPoint + input_quant_params.at(i).meanCorr / - (input_quant_params.at(i).scale * input_quant_params.at(i).varCorr))); + (input_quant_params.at(i).scale * input_quant_params.at(i).varCorr))); } *scales_node = opt::BuildFloat16VecParameterNode(func_graph, scales, input_node->fullname_with_scope() + "-scales"); *zps_node = opt::BuildFloat16VecParameterNode(func_graph, zps, input_node->fullname_with_scope() + "-zps"); @@ -770,7 +770,7 @@ int InsertQuantNodeManager::CalculateScaleZPNode(const FuncGraphPtr &func_graph, scales.push_back(static_cast(input_quant_params.at(i).scale * input_quant_params.at(i).varCorr)); zps.push_back(static_cast(-input_quant_params.at(i).zeroPoint + input_quant_params.at(i).meanCorr / - (input_quant_params.at(i).scale * input_quant_params.at(i).varCorr))); + (input_quant_params.at(i).scale * input_quant_params.at(i).varCorr))); } *scales_node = opt::BuildFloatVecParameterNode(func_graph, scales, input_node->fullname_with_scope() + "-scales"); *zps_node = opt::BuildFloatVecParameterNode(func_graph, zps, input_node->fullname_with_scope() + "-zps"); diff --git a/mindspore-lite/tools/converter/registry/CMakeLists.txt b/mindspore-lite/tools/converter/registry/CMakeLists.txt index 31e4c75f0e52b158fd6d4019a98f8a7d75503be5..1321511ee29e63c52ff669859a710f10257cea1e 100644 --- a/mindspore-lite/tools/converter/registry/CMakeLists.txt +++ b/mindspore-lite/tools/converter/registry/CMakeLists.txt @@ -12,11 +12,13 @@ set(REG_SRC ${CONVERT_REG_SRC} ${KERNEL_REG_DIR}/../litert/inner_allocator.cc ${KERNEL_REG_DIR}/../common/string_util.cc ${KERNEL_REG_DIR}/../common/utils.cc - ${KERNEL_REG_DIR}/../extendrt/delegate/tensorrt/distribution/distribution_base.cc - ${KERNEL_REG_DIR}/../extendrt/delegate/plugin/tensorrt_executor_plugin.cc - ${KERNEL_REG_DIR}/../extendrt/kernel/ascend/plugin/ascend_allocator_plugin.cc +# ${KERNEL_REG_DIR}/../extendrt/delegate/tensorrt/distribution/distribution_base.cc +# ${KERNEL_REG_DIR}/../extendrt/delegate/plugin/tensorrt_executor_plugin.cc + # ${KERNEL_REG_DIR}/../extendrt/kernel/ascend/plugin/ascend_allocator_plugin.cc + ${KERNEL_REG_DIR}/../extendrt/delegate/ascend_acl/ascend_allocator_plugin.cc ${CONVERTER_DIR}/converter_context.cc - ${TOP_DIR}/mindspore/mindspore/ops/kernel/cpu/nnacl/tensor_c_utils.c + # ${TOP_DIR}/mindspore/ops/kernel/cpu/nnacl/tensor_c_utils.c +# ${TOP_DIR}/mindspore/lite/ops/kernel/cpu/nnacl/tensor_c_utils.c ${TOP_DIR}/mindspore-lite/src/common/file_utils.cc ) set_property(SOURCE ${REG_SRC} PROPERTY COMPILE_DEFINITIONS @@ -24,6 +26,10 @@ set_property(SOURCE ${REG_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) add_library(mslite_converter_plugin SHARED ${REG_SRC}) target_link_libraries(mslite_converter_plugin mindspore::glog) +if(NOT WIN32) + target_link_libraries(mslite_converter_plugin dl) +endif() + add_dependencies(mslite_converter_plugin fbs_src) add_dependencies(mslite_converter_plugin fbs_inner_src) diff --git a/mindspore-lite/tools/cropper/build_cropper_config.sh b/mindspore-lite/tools/cropper/build_cropper_config.sh index d0b79b56db2455a6f605966e757edcda77e45df1..6e94d13edcaa2666291d21748148af74efd4b653 100644 --- a/mindspore-lite/tools/cropper/build_cropper_config.sh +++ b/mindspore-lite/tools/cropper/build_cropper_config.sh @@ -6,7 +6,7 @@ MINDSPORE_LITE_HOME=${PROJECT_ROOT_HOME}/mindspore-lite MINDSPORE_HOME=${PROJECT_ROOT_HOME}/mindspore echo "PROJECT_ROOT_HOME path is ${PROJECT_ROOT_HOME}" cd "${PROJECT_ROOT_HOME}" || exit 1 -CROPPER_OUTPUT_DIR=build/tools/cropper +CROPPER_OUTPUT_DIR=mindspore-lite/build/tools/cropper mkdir -p ${CROPPER_OUTPUT_DIR} MAPPING_OUTPUT_FILE_NAME_TMP=${CROPPER_OUTPUT_DIR}/cropper_mapping_tmp.cfg MAPPING_OUTPUT_FILE_NAME_TRAIN_TMP=${CROPPER_OUTPUT_DIR}/cropper_mapping_train_tmp.cfg @@ -64,14 +64,14 @@ HEADER_LOCATION="-I${MINDSPORE_HOME} -I${MINDSPORE_LITE_HOME}/src -I${MINDSPORE_LITE_HOME}/src/litert/kernel/cpu -I${PROJECT_ROOT_HOME}/third_party --I${PROJECT_ROOT_HOME}/build +-I${MINDSPORE_LITE_HOME}/build -I${PROJECT_ROOT_HOME}/third_party/securec/include -I${FLATBUFFERS} -I${NLOHMANN} -I${GLOG} --I${PROJECT_ROOT_HOME}/build/schema --I${PROJECT_ROOT_HOME}/build/schema/inner --I${PROJECT_ROOT_HOME}/build/src +-I${MINDSPORE_LITE_HOME}/build/schema +-I${MINDSPORE_LITE_HOME}/build/schema/inner +-I${MINDSPORE_LITE_HOME}/build/src -I${MINDSPORE_HOME}/mindspore/ops/kernel/cpu -I${MINDSPORE_HOME}/mindspore/ccsrc/minddata/dataset" diff --git a/mindspore-lite/tools/graph_kernel/common/infer_shape.cc b/mindspore-lite/tools/graph_kernel/common/infer_shape.cc deleted file mode 100644 index 304cce1627fdc34ae1f746420b2530cf449fa50e..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/common/infer_shape.cc +++ /dev/null @@ -1,144 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include "tools/graph_kernel/common/utils.h" -#include "schema/model_generated.h" -#include "src/tensor.h" -#include "src/common/utils.h" -#include "nnacl/infer/common_infer.h" -#include "nnacl/infer/infer_register.h" -#include "nnacl/custom_parameter.h" - -namespace mindspore::graphkernel { -using mindspore::lite::RET_ERROR; -using mindspore::lite::RET_OK; -namespace { -int SetOutputsShape(TensorC **outputs, size_t outputs_size, const std::string &outputs_shape_str, int batch) { - std::vector> shapes; - if (GetCustomShape(outputs_shape_str, &shapes) != RET_OK) { - return RET_ERROR; - } - if (shapes.size() != outputs_size) { - MS_LOG(ERROR) << "The saved outputs is not equal to the outputs_size: " << shapes.size() << " vs " << outputs_size; - return RET_ERROR; - } - for (size_t i = 0; i < outputs_size; i++) { - if (shapes[i].size() > MAX_SHAPE_SIZE) { - MS_LOG(ERROR) << "The output shape size " << shapes[i].size() << " is greater than max size " << MAX_SHAPE_SIZE; - return RET_ERROR; - } - for (size_t j = 0; j < shapes[i].size(); j++) { - outputs[i]->shape_[j] = j == 0 ? shapes[i][j] * batch : shapes[i][j]; - } - outputs[i]->shape_size_ = shapes[i].size(); - } - return RET_OK; -} - -int SetOutputsFormat(TensorC **outputs, size_t outputs_size, const std::string &output_format_str) { - auto formats = SplitString(output_format_str, ','); - if (formats.size() != outputs_size) { - MS_LOG(ERROR) << "The saved outputs is not equal to the outputs_size: " << formats.size() << " vs " << outputs_size; - return RET_ERROR; - } - for (size_t i = 0; i < formats.size(); i++) { - outputs[i]->format_ = std::stoi(formats[i]); - } - return RET_OK; -} - -int SetOutputsType(TensorC **outputs, size_t outputs_size, const std::string &output_type_str) { - auto types = SplitString(output_type_str, ','); - if (types.size() != outputs_size) { - MS_LOG(ERROR) << "The saved outputs is not equal to the outputs_size: " << types.size() << " vs " << outputs_size; - return RET_ERROR; - } - for (size_t i = 0; i < types.size(); i++) { - outputs[i]->data_type_ = std::stoi(types[i]); - } - return RET_OK; -} -} // namespace -int InferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, - OpParameter *parameter) { - // in PopulateCustomParameter, the primitive is store in attr_data[0] - auto param = reinterpret_cast(parameter)->attr_data[0]; - auto prim = reinterpret_cast(param)->value_as_Custom(); - std::unordered_map attr_map; - for (size_t i = 0; i < prim->attr()->size(); i++) { - auto attr = prim->attr()->Get(i); - std::string data; - if (attr->name()->str() == "inputs_shape") { - data = std::string(reinterpret_cast(attr->data()->Data()), attr->data()->size()); - } else if (attr->name()->str() == "outputs_shape") { - data = std::string(reinterpret_cast(attr->data()->Data()), attr->data()->size()); - } else if (attr->name()->str() == "outputs_format") { - data = std::string(reinterpret_cast(attr->data()->Data()), attr->data()->size()); - } else if (attr->name()->str() == "outputs_type") { - data = std::string(reinterpret_cast(attr->data()->Data()), attr->data()->size()); - } else if (attr->name()->str() == "dynamic_input_index") { - data = std::string(reinterpret_cast(attr->data()->Data()), attr->data()->size()); - } else { - continue; - } - (void)attr_map.emplace(attr->name()->str(), data); - } - int batch = 1; - - if (attr_map.count("inputs_shape") != 0 && attr_map.count("dynamic_input_index") != 0) { - std::vector> shapes; - if (GetCustomShape(attr_map["inputs_shape"], &shapes) != RET_OK) { - return RET_ERROR; - } - std::vector index; - GetCustomIndex(attr_map["dynamic_input_index"], &index); - if (CalculateDynamicBatchSize(inputs, inputs_size, shapes, index, &batch) != RET_OK) { - return RET_ERROR; - } - } - if (attr_map.count("outputs_shape") == 0 || - SetOutputsShape(outputs, outputs_size, attr_map["outputs_shape"], batch) != RET_OK) { - return RET_ERROR; - } - if (attr_map.count("outputs_format") == 0 || - SetOutputsFormat(outputs, outputs_size, attr_map["outputs_format"]) != RET_OK) { - return RET_ERROR; - } - if (attr_map.count("outputs_type") == 0 || - SetOutputsType(outputs, outputs_size, attr_map["outputs_type"]) != RET_OK) { - return RET_ERROR; - } - - return RET_OK; -} -} // namespace mindspore::graphkernel - -#ifdef __cplusplus -extern "C" { -#endif -int GraphKernelInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, - OpParameter *parameter) { - return mindspore::graphkernel::InferShape(inputs, inputs_size, outputs, outputs_size, parameter); -} -#ifdef __cplusplus -} -#endif - -REG_INFER(GraphKernel, PrimType_Inner_GraphKernel, GraphKernelInferShape) diff --git a/mindspore-lite/tools/graph_kernel/common/utils.cc b/mindspore-lite/tools/graph_kernel/common/utils.cc deleted file mode 100644 index c70d2b1278a9ca312439f988a4b2cdb8aee38b18..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/common/utils.cc +++ /dev/null @@ -1,166 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include "tools/graph_kernel/common/utils.h" -#include "src/tensor.h" - -namespace mindspore::graphkernel { -using mindspore::lite::RET_ERROR; -using mindspore::lite::RET_OK; -std::vector SplitString(const std::string &raw_str, char delimiter) { - std::vector res; - std::string::size_type last_pos = 0; - auto cur_pos = raw_str.find(delimiter); - while (cur_pos != std::string::npos) { - (void)res.emplace_back(raw_str.substr(last_pos, cur_pos - last_pos)); - cur_pos++; - last_pos = cur_pos; - cur_pos = raw_str.find(delimiter, cur_pos); - } - if (last_pos < raw_str.size()) { - (void)res.emplace_back(raw_str.substr(last_pos, raw_str.size() - last_pos + 1)); - } - return res; -} - -int GetCustomShape(const std::string &attr, std::vector> *shapes) { - auto split_shape_str = SplitString(attr, ','); - for (size_t i = 0; i < split_shape_str.size(); i++) { - size_t dim = std::stoul(split_shape_str[i]); - if (i + dim >= split_shape_str.size()) { - MS_LOG(ERROR) << "Shape string is invalid. The shape dim is " << dim << ", but only " - << split_shape_str.size() - i << " values follow."; - return RET_ERROR; - } - std::vector shape; - for (size_t j = i + 1; j <= i + dim; j++) { - shape.push_back(std::stoi(split_shape_str[j])); - } - i += dim; - shapes->push_back(shape); - } - return RET_OK; -} - -void GetCustomIndex(const std::string &dynamic_input_index, std::vector *index) { - auto split_index_str = SplitString(dynamic_input_index, ','); - for (size_t i = 0; i < split_index_str.size(); i++) { - index->push_back(std::stoul(split_index_str[i])); - } -} - -int CalculateDynamicBatchSize(const TensorC *const *inputs, size_t inputs_size, - const std::vector> &shapes, const std::vector &index, - int *batch) { - if (shapes.size() != inputs_size) { - MS_LOG(ERROR) << "The saved inputs is not equal to the inputs_size: " << shapes.size() << " vs " << inputs_size; - return RET_ERROR; - } - bool changed = false; - for (auto i : index) { - if (i >= shapes.size()) { - MS_LOG(ERROR) << "The input num is " << shapes.size() << ", but want query index " << i; - return RET_ERROR; - } - if (shapes[i].size() > MAX_SHAPE_SIZE) { - MS_LOG(ERROR) << "The input shape size " << shapes[i].size() << " is greater than max size " << MAX_SHAPE_SIZE; - return RET_ERROR; - } - for (size_t j = 0; j < shapes[i].size(); j++) { - if (j == 0) { - int bs = inputs[i]->shape_[0] / shapes[i][0]; - if (bs < 0) { - MS_LOG(ERROR) << "AKG doesn't support batch size smaller than 1"; - return RET_ERROR; - } - if (bs != (*batch)) { - if (!changed) { - *batch = bs; - changed = true; - } else { - MS_LOG(ERROR) << "AKG doesn't support inputs with different batch size"; - return RET_ERROR; - } - } - } else if (inputs[i]->shape_[j] != shapes[i][j]) { - MS_LOG(ERROR) << "AKG only support dynamic shape on axis 0"; - return RET_ERROR; - } - } - } - return RET_OK; -} - -void SetAnfKernelInfoFormatFromAToB(const AnfNodePtr &node_a, const CNodePtr &node_b, - const std::vector &formats) { - std::shared_ptr kernel_info = nullptr; - auto kernel_info_builder = kernel::KernelBuildInfo::KernelBuildInfoBuilder(); - kernel_info_builder.SetOutputsFormat(formats); - if (node_a->kernel_info_ptr() != nullptr) { - kernel_info = std::make_shared(); - } else { - kernel_info = std::dynamic_pointer_cast(node_a->kernel_info_ptr()); - } - kernel_info->set_select_kernel_build_info(kernel_info_builder.Build()); - node_b->set_kernel_info(kernel_info); -} - -void SetKernelInfoWithFormatToAnfNode(const AnfNodePtr &node, const std::vector &format) { - auto kernel_info_builder = kernel::KernelBuildInfo::KernelBuildInfoBuilder(); - kernel_info_builder.SetOutputsFormat(format); - auto kernel_build_info = kernel_info_builder.Build(); - auto kernel_info = std::make_shared(); - kernel_info->set_select_kernel_build_info(kernel_build_info); - node->set_kernel_info(kernel_info); -} - -kernel::KernelBuildInfoPtr GetKernelInfo(const AnfNodePtr &node) { - if (!node->has_user_data("kernel_info")) { - return nullptr; - } - auto kernel_info_ptr = node->kernel_info_ptr(); - if (kernel_info_ptr == nullptr) { - return nullptr; - } - auto kernel_info = std::dynamic_pointer_cast(kernel_info_ptr); - if (kernel_info == nullptr) { - MS_LOG(ERROR) << "kernel info from " << node->fullname_with_scope() << " is nullptr."; - return nullptr; - } - auto kernel_build_info = kernel_info->GetMutableSelectKernelBuildInfo(); - if (kernel_build_info == nullptr) { - MS_LOG(ERROR) << "kernel build info from " << node->fullname_with_scope() << " is nullptr."; - return nullptr; - } - return kernel_build_info; -} - -std::string GetOutputFormatFromAnfNode(const AnfNodePtr &node, size_t output_idx) { - auto kernel_build_info = GetKernelInfo(node); - if (kernel_build_info == nullptr) { - MS_LOG(EXCEPTION) << "kernel build info from " << node->fullname_with_scope() << " is empty."; - } - auto vec_size = kernel_build_info->GetOutputNum(); - if (output_idx >= vec_size) { - MS_LOG(EXCEPTION) << "Index " << output_idx << " is out of the range of node output vector, output size is " - << kernel_build_info->GetOutputNum() << ". node is " << node->fullname_with_scope(); - } - return kernel_build_info->GetOutputFormat(output_idx); -} -} // namespace mindspore::graphkernel diff --git a/mindspore-lite/tools/graph_kernel/converter/akg/akg_kernel_builder.h b/mindspore-lite/tools/graph_kernel/converter/akg/akg_kernel_builder.h deleted file mode 100644 index f4873740adf1d6b85c1b497ec60462c346d044e6..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/akg/akg_kernel_builder.h +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2021-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_AKG_AKG_KERNEL_BUILDER_H_ -#define MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_AKG_AKG_KERNEL_BUILDER_H_ -#include -#include -#include -#include -#include "infer/custom.h" -#include "utils/anf_utils.h" -#include "kernel/graph_kernel/graph_kernel_json_generator.h" - -namespace mindspore::graphkernel { -class AkgKernelBuilder { - public: - AkgKernelBuilder() = default; - virtual ~AkgKernelBuilder() = default; - virtual bool CompileJsonsInAnfnodes(const AnfNodePtrList &node_list) = 0; - virtual AnfNodePtr CreateCustomOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode) = 0; - virtual bool GenerateAkgKernelNodes(const FuncGraphPtr &func_graph, const AnfNodePtr &custom_node, - const CNodePtr &old_cnode) { - return true; - } - - static DumpOption json_option() { - DumpOption dump_json_option; - dump_json_option.get_target_info = true; - return dump_json_option; - } - - protected: - std::string dir_path_; -}; -using AkgKernelBuilderPtr = std::shared_ptr; -} // namespace mindspore::graphkernel -#endif // MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_AKG_AKG_KERNEL_BUILDER_H_ diff --git a/mindspore-lite/tools/graph_kernel/converter/akg/ascend_kernel_builder.cc b/mindspore-lite/tools/graph_kernel/converter/akg/ascend_kernel_builder.cc deleted file mode 100644 index 0eb254264990adff4ad6a19eb5729807ec0fda79..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/akg/ascend_kernel_builder.cc +++ /dev/null @@ -1,70 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "tools/graph_kernel/converter/akg/ascend_kernel_builder.h" -#include "utils/anf_utils.h" -#include "tools/graph_kernel/converter/akg/utils.h" -#include "backend/common/graph_kernel/core/graph_kernel_utils.h" - -namespace mindspore::graphkernel { -bool AscendKernelBuilder::CompileJsonsInAnfnodes(const AnfNodePtrList &node_list) { - static std::string rank_id = common::GetEnv("RANK_ID"); - std::string dir; - if (rank_id.empty()) { - dir = "./akg_kernel_meta"; - } else { - dir = "./rank_" + rank_id + "/akg_kernel_meta"; - } - dir_path_ = SaveNodesInfo(node_list, dir, AkgKernelBuilder::json_option(), &node_info_map_, nullptr); - return !dir_path_.empty(); -} - -AnfNodePtr AscendKernelBuilder::CreateCustomOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { - auto op = std::make_shared(); - op->set_type("GraphKernel"); - auto custom_prim = op->GetPrim(); - auto inputs = cnode->inputs(); - inputs[0] = NewValueNode(custom_prim); - auto custom_cnode = func_graph->NewCNode(inputs); - custom_prim->EraseAttr("IsFeatureMapInputList"); - custom_prim->EraseAttr("IsFeatureMapOutput"); - - auto json_kernel_name = node_info_map_[cnode->cast()]; - auto input_num = AnfUtils::GetInputTensorNum(cnode); - auto output_num = AnfUtils::GetOutputTensorNum(cnode); - std::vector input_names; - std::vector output_names; - for (size_t i = 0; i < input_num; ++i) { - input_names.push_back("x" + std::to_string(i)); - } - for (size_t i = 0; i < output_num; ++i) { - output_names.push_back("y" + std::to_string(i)); - } - - std::ostringstream oss; - oss << "Fused_x" << input_num << "_y" << output_num; - std::string op_tye = oss.str(); - custom_prim->set_attr("reg_op_name", MakeValue(op_tye)); - custom_prim->set_attr("info_path", MakeValue(dir_path_ + "/" + json_kernel_name + ".info")); - custom_prim->set_attr("input_names", MakeValue(input_names)); - custom_prim->set_attr("output_names", MakeValue(output_names)); - custom_cnode->set_fullname_with_scope(cnode->fullname_with_scope()); - custom_cnode->set_abstract(cnode->abstract()->Clone()); - if (GkUtils::UseAkgCceLib(cnode)) { - custom_cnode->AddAttr("use_akg_cce", MakeValue(true)); - } - return custom_cnode; -} -} // namespace mindspore::graphkernel diff --git a/mindspore-lite/tools/graph_kernel/converter/akg/cpu_kernel_builder.cc b/mindspore-lite/tools/graph_kernel/converter/akg/cpu_kernel_builder.cc deleted file mode 100644 index 425f76eec9d5c5ba3d53090c732280c2b52da3d7..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/akg/cpu_kernel_builder.cc +++ /dev/null @@ -1,126 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tools/graph_kernel/converter/akg/cpu_kernel_builder.h" -#include -#include "ir/anf.h" -#include "ir/func_graph.h" -#include "utils/anf_utils.h" -#include "src/common/log_adapter.h" -#include "utils/system/env.h" -#include "backend/common/graph_kernel/graph_kernel_flags.h" -#include "tools/graph_kernel/converter/akg/utils.h" -#include "kernel/graph_kernel/graph_kernel_json_flags.h" -#include "nlohmann/json.hpp" -using json = nlohmann::json; - -namespace mindspore::graphkernel { -AnfNodePtr CpuKernelBuilder::CreateCustomOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { - auto op = std::make_shared(); - op->set_type("GraphKernel"); - std::map> custom_attrs; - auto fg = GetCNodeFuncGraph(cnode); - MS_EXCEPTION_IF_NULL(fg); - auto kernel_name = GetValue(fg->get_attr("kernel_name")); - std::vector kernel_name_str(kernel_name.begin(), kernel_name.end()); - custom_attrs["kernel_name"] = kernel_name_str; - if (GraphKernelFlags::GetInstance().enable_dynamic_batch && fg->has_attr("dynamic_input_index")) { - std::string dynamic_input_index = GetValue(fg->get_attr("dynamic_input_index")); - custom_attrs["dynamic_input_index"] = std::vector(dynamic_input_index.begin(), dynamic_input_index.end()); - } - auto kernel_file_name = kernel_name.substr(0, kernel_name.find("_kernel")); - auto info_path = dir_path_ + "/" + kernel_file_name + ".info"; - std::ifstream ifile(info_path); - if (ifile.fail()) { - MS_LOG(ERROR) << "Can not find info at: " << info_path; - return nullptr; - } - json json_info; - ifile >> json_info; - auto process = json_info.at(kJsonKeyProcess).get(); - custom_attrs[kJsonKeyProcess] = std::vector(process.begin(), process.end()); - auto target_info = json_info.at(kJsonKeyTargetInfo); - auto arch = target_info.at(kJsonKeyArch).get(); - custom_attrs[kJsonKeyArch] = std::vector(arch.begin(), arch.end()); - auto system = target_info.at(kJsonKeySystem).get(); - custom_attrs[kJsonKeySystem] = std::vector(system.begin(), system.end()); - auto feature = target_info.at(kJsonKeyCpuFeature).get(); - custom_attrs[kJsonKeyCpuFeature] = std::vector(feature.begin(), feature.end()); - std::string input_shape_str = GetCNodeInputShapeStr(cnode); - std::string output_shape_str = GetCNodeOutputShapeStr(cnode); - std::string output_format_str = GetCNodeOutputFormatStr(cnode); - std::string output_type_str = GetCNodeOutputTypeStr(cnode); - custom_attrs["inputs_shape"] = std::vector(input_shape_str.begin(), input_shape_str.end()); - custom_attrs["outputs_shape"] = std::vector(output_shape_str.begin(), output_shape_str.end()); - custom_attrs["outputs_format"] = std::vector(output_format_str.begin(), output_format_str.end()); - custom_attrs["outputs_type"] = std::vector(output_type_str.begin(), output_type_str.end()); - op->set_attr(custom_attrs); - auto inputs = cnode->inputs(); - inputs[0] = NewValueNode(op->GetPrim()); - auto custom_cnode = func_graph->NewCNode(inputs); - custom_cnode->set_fullname_with_scope(cnode->fullname_with_scope()); - custom_cnode->set_abstract(cnode->abstract()->Clone()); - return custom_cnode; -} - -bool CpuKernelBuilder::CompileJsonsInAnfnodes(const AnfNodePtrList &node_list) { - if (GraphKernelFlags::GetInstance().enable_dynamic_batch) { - for (auto &node : node_list) { - auto gk_fg = GetCNodeFuncGraph(node); - MS_EXCEPTION_IF_NULL(gk_fg); - std::string dynamic_input_index = GetCNodeDynamicInputIndex(node->cast()); - if (!dynamic_input_index.empty()) { - gk_fg->set_attr("dynamic_input_index", MakeValue(dynamic_input_index)); - } - } - } - std::map node_info_map; - std::set uniq_info_names; - dir_path_ = - SaveNodesInfo(node_list, "./akg_kernel_meta", AkgKernelBuilder::json_option(), &node_info_map, &uniq_info_names); - if (dir_path_.empty()) { - return false; - } - ExcludeTunedObj(dir_path_, &uniq_info_names, &node_info_map); - auto res = CompileJsonsInList(dir_path_, std::vector(uniq_info_names.begin(), uniq_info_names.end())); - if (res) { - std::set obj_files; - std::ostringstream objs; - for (const auto &iter : node_info_map) { - AnfUtils::SetNodeAttr("kernel_name", MakeValue(iter.second + "_kernel"), iter.first); - if (obj_files.insert(iter.second).second) { - objs << dir_path_ << "/" << iter.second << ".o "; - } - } - return true; - } - return false; -} - -bool CpuKernelBuilder::GenerateAkgKernelNodes(const FuncGraphPtr &func_graph, const AnfNodePtr &custom_node, - const CNodePtr &old_cnode) { - auto fg = GetCNodeFuncGraph(old_cnode); - auto kernel_name = GetValue(fg->get_attr("kernel_name")).append(".so"); - auto real_kernel_name = kernel_name.substr(0, kernel_name.find("_kernel")).append(".so"); - auto param_node = CreateAkgKernelParameter(func_graph, "./akg_kernel_meta/" + real_kernel_name, real_kernel_name); - if (param_node == nullptr) { - return false; - } - auto manager = Manage(func_graph, true); - manager->AddEdge(custom_node, param_node); - return true; -} -} // namespace mindspore::graphkernel diff --git a/mindspore-lite/tools/graph_kernel/converter/akg/gpu_kernel_builder.cc b/mindspore-lite/tools/graph_kernel/converter/akg/gpu_kernel_builder.cc deleted file mode 100644 index 4b7594c5ad6d35be7bee0abe929d0aa507a5c55c..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/akg/gpu_kernel_builder.cc +++ /dev/null @@ -1,112 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "tools/graph_kernel/converter/akg/gpu_kernel_builder.h" -#include -#include "utils/anf_utils.h" -#include "kernel/graph_kernel/graph_kernel_json_flags.h" -#include "tools/graph_kernel/converter/akg/utils.h" -#include "tools/graph_kernel/converter/akg/cpu_kernel_builder.h" -#include "nlohmann/json.hpp" -using json = nlohmann::json; - -namespace mindspore::graphkernel { -const int C0 = 0; -const int C1 = 1; -const int C2 = 2; -const int C3 = 3; -const int C4 = 4; -const int C5 = 5; -bool GpuKernelBuilder::CompileJsonsInAnfnodes(const AnfNodePtrList &node_list) { - std::set uniq_info_names; - dir_path_ = - SaveNodesInfo(node_list, "./akg_kernel_meta", AkgKernelBuilder::json_option(), &node_info_map_, &uniq_info_names); - if (dir_path_.empty()) { - return false; - } - auto ret = CompileJsonsInList(dir_path_, std::vector(uniq_info_names.begin(), uniq_info_names.end())); - return ret; -} - -std::vector GpuKernelBuilder::ReadThreadBlockFromJson(const std::string &dir_name) { - std::ifstream ifile(dir_name); - json json_info; - ifile >> json_info; - std::vector thread_block_info; - thread_block_info.push_back(std::to_string(static_cast(json_info.at("blockIdx.x")))); - thread_block_info.push_back(std::to_string(static_cast(json_info.at("blockIdx.y")))); - thread_block_info.push_back(std::to_string(static_cast(json_info.at("blockIdx.z")))); - thread_block_info.push_back(std::to_string(static_cast(json_info.at("threadIdx.x")))); - thread_block_info.push_back(std::to_string(static_cast(json_info.at("threadIdx.y")))); - thread_block_info.push_back(std::to_string(static_cast(json_info.at("threadIdx.z")))); - return thread_block_info; -} - -AnfNodePtr GpuKernelBuilder::CreateCustomOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { - auto op = std::make_shared(); - op->set_type("GraphKernel"); - auto inputs = cnode->inputs(); - auto prim = op->GetPrim(); - inputs[0] = NewValueNode(prim); - auto custom_cnode = func_graph->NewCNode(inputs); - custom_cnode->set_fullname_with_scope(cnode->fullname_with_scope()); - custom_cnode->set_abstract(cnode->abstract()->Clone()); - prim->set_attr("func_type", MakeValue("custom_akg_gpu")); - prim->set_attr("unique_name", MakeValue("CustomAkgGpu")); - - // set attrs - std::map> custom_attrs; - auto kernel_name = node_info_map_[cnode->cast()]; - custom_attrs["kernel_name"] = std::vector(kernel_name.begin(), kernel_name.end()); - - std::string json_path = dir_path_ + "/" + kernel_name + ".json"; - const std::vector thread_block_info = ReadThreadBlockFromJson(json_path); - std::string ptx_path = dir_path_ + "/" + kernel_name + ".ptx"; - std::string output_shape_str = GetCNodeOutputShapeStr(cnode); - std::string output_format_str = GetCNodeOutputFormatStr(cnode); - std::string output_type_str = GetCNodeOutputTypeStr(cnode); - custom_attrs["outputs_shape"] = std::vector(output_shape_str.begin(), output_shape_str.end()); - custom_attrs["outputs_format"] = std::vector(output_format_str.begin(), output_format_str.end()); - custom_attrs["outputs_type"] = std::vector(output_type_str.begin(), output_type_str.end()); - custom_attrs["GridDimX"] = std::vector(thread_block_info[C0].begin(), thread_block_info[C0].end()); - custom_attrs["GridDimY"] = std::vector(thread_block_info[C1].begin(), thread_block_info[C1].end()); - custom_attrs["GridDimZ"] = std::vector(thread_block_info[C2].begin(), thread_block_info[C2].end()); - custom_attrs["BlockDimX"] = std::vector(thread_block_info[C3].begin(), thread_block_info[C3].end()); - custom_attrs["BlockDimY"] = std::vector(thread_block_info[C4].begin(), thread_block_info[C4].end()); - custom_attrs["BlockDimZ"] = std::vector(thread_block_info[C5].begin(), thread_block_info[C5].end()); - custom_attrs["ptx_path"] = std::vector(ptx_path.begin(), ptx_path.end()); - std::string info_path = dir_path_ + "/" + kernel_name + ".info"; - std::ifstream ifile(info_path); - if (ifile.fail()) { - MS_LOG(ERROR) << "Can not find info at: " << json_path; - return nullptr; - } - json json_info; - ifile >> json_info; - auto process = json_info.at(kJsonKeyProcess).get(); - custom_attrs[kJsonKeyProcess] = std::vector(process.begin(), process.end()); - if (json_info.find(kJsonKeyTargetInfo) != json_info.end()) { - auto target_info = json_info.at(kJsonKeyTargetInfo); - auto compute_capability = target_info.at(kJsonKeyComputeCapability).get(); - custom_attrs[kJsonKeyComputeCapability] = - std::vector(compute_capability.begin(), compute_capability.end()); - auto sm_count = target_info.at(kJsonKeySmCount).get(); - auto sm_count_str = std::to_string(sm_count); - custom_attrs[kJsonKeySmCount] = std::vector(sm_count_str.begin(), sm_count_str.end()); - } - op->set_attr(custom_attrs); - return custom_cnode; -} -} // namespace mindspore::graphkernel diff --git a/mindspore-lite/tools/graph_kernel/converter/akg/utils.cc b/mindspore-lite/tools/graph_kernel/converter/akg/utils.cc deleted file mode 100644 index 2f6fb868482492e6d3a468d220ba4cd177094c2e..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/akg/utils.cc +++ /dev/null @@ -1,305 +0,0 @@ -/** - * Copyright 2021-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tools/graph_kernel/converter/akg/utils.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "backend/common/graph_kernel/core/graph_kernel_utils.h" -#include "ir/anf.h" -#include "ir/func_graph.h" -#include "thread/threadpool.h" -#include "tools/common/tensor_util.h" -#include "utils/anf_utils.h" -#include "utils/file_utils.h" -#include "src/common/log_adapter.h" -#include "utils/system/env.h" -#include "mindspore/ccsrc/include/common/debug/common.h" - -namespace mindspore::graphkernel { -bool SaveJsonInfo(const std::string &json_name, const std::string &info) { - std::string path = json_name + ".info"; - std::ofstream filewrite(path); - if (!filewrite.is_open()) { - MS_LOG(ERROR) << "Open file '" << path << "' failed!"; - return false; - } - filewrite << info << std::endl; - filewrite.close(); - return true; -} - -std::string SaveNodesInfo(const AnfNodePtrList &nodes, const std::string &dir, const DumpOption &option, - std::map *node_kernel, std::set *kernel_names) { - auto dir_path = FileUtils::CreateNotExistDirs(dir); - if (!dir_path.has_value()) { - MS_LOG(ERROR) << "Failed to CreateNotExistDirs: " << dir; - return ""; - } - std::set unique_kernel_name; - for (const auto &node : nodes) { - graphkernel::GraphKernelJsonGenerator graph_kernel_json_generator(option); - auto fg = GetCNodeFuncGraph(node); - MS_EXCEPTION_IF_NULL(fg); - auto mng = fg->manager(); - if (mng == nullptr) { - mng = Manage(fg, true); - fg->set_manager(mng); - } - std::vector node_list, input_list, output_list; - auto cnode = dyn_cast_ptr(node); - auto use_akg_cce = false; - if (cnode != nullptr && cnode->HasAttr("use_akg_cce")) { - use_akg_cce = true; - } - GkUtils::GetValidKernelNodes(fg, &node_list, &input_list, &output_list); - (void)graph_kernel_json_generator.CollectFusedJson(node_list, input_list, output_list, use_akg_cce); - auto json_kernel_name = graph_kernel_json_generator.kernel_name(); - if (node_kernel != nullptr) { - (*node_kernel)[node] = json_kernel_name; - } - if (!unique_kernel_name.insert(json_kernel_name).second) { - continue; - } - if (!SaveJsonInfo(dir_path.value() + "/" + json_kernel_name, graph_kernel_json_generator.kernel_json_str())) { - return ""; - } - } - if (kernel_names != nullptr) { - *kernel_names = std::move(unique_kernel_name); - } - return dir_path.value(); -} - -std::string GetCNodeDynamicInputIndex(const CNodePtr &cnode) { - std::string dynamic_input_index; - auto cb = Callback::Instance(); - for (size_t i = 1; i < cnode->size(); i++) { - if (cnode->input(i)->isa() || cnode->input(i)->isa()) { - auto input_shape = cb->GetInputShape(cnode, i - 1); - if (input_shape.size() <= 0 || input_shape[0] != 1) { - MS_LOG(EXCEPTION) << "Dynamic inputs' batch size should be 1"; - } - dynamic_input_index += std::to_string(i - 1) + ","; - } - } - return dynamic_input_index; -} - -std::string GetCNodeInputShapeStr(const CNodePtr &cnode) { - std::string input_shape_str; - auto cb = Callback::Instance(); - for (size_t i = 1; i < cnode->size(); i++) { - auto input_shape = cb->GetInputShape(cnode, i - 1); - input_shape_str += std::to_string(input_shape.size()) + ","; - for (auto &v : input_shape) { - input_shape_str += std::to_string(v) + ","; - } - } - return input_shape_str; -} - -std::string GetCNodeOutputShapeStr(const CNodePtr &cnode) { - std::string output_shape_str; - auto output_num = AnfUtils::GetOutputTensorNum(cnode); - auto cb = Callback::Instance(); - for (size_t i = 0; i < output_num; i++) { - auto output_shape = cb->GetOutputShape(cnode, i); - output_shape_str += std::to_string(output_shape.size()) + ","; - for (auto &v : output_shape) { - output_shape_str += std::to_string(v) + ","; - } - } - return output_shape_str; -} - -std::string GetCNodeOutputTypeStr(const CNodePtr &cnode) { - std::string output_type_str; - auto output_num = AnfUtils::GetOutputTensorNum(cnode); - auto cb = Callback::Instance(); - for (size_t i = 0; i < output_num; i++) { - auto output_type = cb->GetOutputType(cnode, i); - output_type_str += std::to_string(static_cast(output_type)) + ","; - } - return output_type_str; -} - -std::string GetCNodeOutputFormatStr(const CNodePtr &cnode) { - std::string output_format_str; - auto output_num = AnfUtils::GetOutputTensorNum(cnode); - auto cb = Callback::Instance(); - for (size_t i = 0; i < output_num; i++) { - auto output_format = cb->GetOutputFormat(cnode, i); - if (output_format == kOpFormat_NHWC) { - output_format_str += "1,"; - } else { // default, NCHW - output_format_str += "0,"; - } - } - return output_format_str; -} - -ParameterPtr CreateAkgKernelParameter(const FuncGraphPtr &func_graph, const std::string &path, - const std::string &kernel_name) { - MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr); - auto param_node = func_graph->add_parameter(); - MS_CHECK_TRUE_RET(param_node != nullptr, nullptr); - param_node->set_name(kernel_name); - if (path.empty()) { - return nullptr; - } - if (!Common::FileExists(path)) { - return nullptr; - } - auto akg_fd = open(path.c_str(), O_RDONLY); - struct stat sb; - if (akg_fd < 0) { - MS_LOG(ERROR) << "open " << path << " failed."; - return nullptr; - } - if (fstat(akg_fd, &sb) == -1) { - MS_LOG(ERROR) << "fstat " << path << " failed."; - return nullptr; - } - auto akg_mmap = mmap(NULL, sb.st_size, PROT_READ, MAP_SHARED, akg_fd, 0); - if (akg_mmap == nullptr) { - MS_LOG(ERROR) << "mmap " << path << " failed."; - return nullptr; - } - (void)close(akg_fd); - auto tensor_info = lite::CreateTensorInfo(akg_mmap, sb.st_size, {sb.st_size}, kNumberTypeUInt8); - if (tensor_info == nullptr) { - MS_LOG(ERROR) << "Create tensor info failed"; - return nullptr; - } - (void)munmap(akg_mmap, sb.st_size); - auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info); - if (status != lite::RET_OK) { - MS_LOG(ERROR) << "init parameter from tensor info failed"; - return nullptr; - } - return param_node; -} - -bool CompileSingleJson(const std::string &json_name) { - std::string attrs = "None"; - std::ostringstream py_cmd; - py_cmd << kAddMSLiteAkg; - py_cmd << "from akg.ms import compilewithjsonname\n"; - py_cmd << "if not compilewithjsonname(\'" << json_name << "\', " << attrs << "):\n"; - py_cmd << " raise RuntimeError(\'Compile fail for json: " << json_name << "\')"; - std::string cmd = "python -c \"" + py_cmd.str() + "\""; - auto ret = std::system(cmd.c_str()); - if (!WIFEXITED(ret)) { - MS_LOG(ERROR) << "Python process start fail! process content is as follows:\n" << cmd; - return false; - } - if (WEXITSTATUS(ret) != 0) { - MS_LOG(ERROR) << "Failed to compile json: " << json_name; - return false; - } - return true; -} - -bool RetStatus(const int status) { - if (WIFEXITED(status)) { - if (WEXITSTATUS(status) == 0) { - MS_LOG(INFO) << "compile all pass for subprocess!"; - return true; - } else { - MS_LOG(ERROR) << "Some jsons compile fail, please check log!"; - } - } else if (WIFSIGNALED(status)) { - MS_LOG(ERROR) << "compile stopped by signal, maybe cost too long time!"; - } else if (WSTOPSIG(status)) { - MS_LOG(ERROR) << "compile process is stopped by others!"; - } else { - MS_LOG(ERROR) << "unknown error in compiling!"; - } - return false; -} - -bool CompileJsonsInList(const std::string &dir_path, const std::vector &json_list) { - auto json_list_size = static_cast(json_list.size()); - auto thread_num = std::min(PROCESS_LIMIT, json_list_size); - if (thread_num == 0) { - return true; - } - auto func = [&](void *cdata, int task_id, float lhs_scale, float rhs_scale) -> int { - bool all_pass{true}; - for (int j = task_id; j < json_list_size; j += PROCESS_LIMIT) { - auto res = CompileSingleJson(dir_path + "/" + json_list[j] + ".info"); - if (!res) { - all_pass = false; - } - } - if (!all_pass) { - MS_LOG(ERROR) << "Some task failed."; - return lite::RET_ERROR; - } - return lite::RET_OK; - }; - auto *pool = ThreadPool::CreateThreadPool(thread_num); - if (pool && pool->ParallelLaunch(func, nullptr, thread_num) == lite::RET_OK) { - return true; - } - return false; -} - -void ExcludeTunedObj(const std::string &dir_path, std::set *kernel_names, - std::map *node_kernel) { - auto fs = system::Env::GetFileSystem(); - std::map tuned_obj_map; // < tuned_signature, best split object name > - for (auto &iter : *node_kernel) { - auto fg = GetCNodeFuncGraph(iter.first); - MS_EXCEPTION_IF_NULL(fg); - auto tuned_sign = fg->has_attr(kTunedSign) ? GetValue(fg->get_attr(kTunedSign)) : ""; - if (tuned_sign == iter.second) { - // the kernel name is the same as signature, find cache. - auto cache = tuned_obj_map.find(tuned_sign); - if (cache != tuned_obj_map.end()) { - iter.second = cache->second; - } - if (!fg->has_attr(kAttrNodeName)) { - continue; - } - auto best_split_kernel = std::string("best_split_") + GetValue(fg->get_attr(kAttrNodeName)); - auto best_split_file = dir_path + "/" + best_split_kernel + ".o"; - if (!fs->FileExist(best_split_file)) { - continue; - } - // the cache file exists, use it. - tuned_obj_map[tuned_sign] = best_split_kernel; - iter.second = best_split_kernel; - (void)kernel_names->erase(tuned_sign); - MS_LOG(INFO) << "Reuse the object file " << best_split_file; - } else { - if (!tuned_sign.empty()) { - MS_LOG(INFO) << "The kernel_name of " << iter.first->fullname_with_scope() << " mismatch its signature. " - << "kernel_name is " << iter.second << ", and tuned_signature is " << tuned_sign; - } - } - } -} -} // namespace mindspore::graphkernel diff --git a/mindspore-lite/tools/graph_kernel/converter/akg/utils.h b/mindspore-lite/tools/graph_kernel/converter/akg/utils.h deleted file mode 100644 index 879675af2d8e83dcf9ba45ccb4c4fac436aebaf9..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/akg/utils.h +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_AKG_UTILS_H_ -#define MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_AKG_UTILS_H_ -#include -#include -#include -#include -#include "utils/anf_utils.h" -#include "kernel/graph_kernel/graph_kernel_json_generator.h" - -constexpr auto kAkgKernelSo = "akgkernels.so"; -namespace mindspore::graphkernel { -constexpr int PROCESS_LIMIT = 8; -constexpr size_t TIME_OUT = 100; -constexpr auto kTunedSign = "tuned_signature"; -constexpr auto kAddMSLiteAkg = - "try:\n" - " import mindspore_lite.akg as akg\n" - "except Exception:\n" - " import akg as akg\n"; - -std::string SaveNodesInfo(const AnfNodePtrList &nodes, const std::string &dir, const DumpOption &option, - std::map *node_kernel, std::set *kernel_names); -std::string GetCNodeDynamicInputIndex(const CNodePtr &cnode); -std::string GetCNodeInputShapeStr(const CNodePtr &cnode); -std::string GetCNodeOutputShapeStr(const CNodePtr &cnode); -std::string GetCNodeOutputTypeStr(const CNodePtr &cnode); -std::string GetCNodeOutputFormatStr(const CNodePtr &cnode); -ParameterPtr CreateAkgKernelParameter(const FuncGraphPtr &func_graph, const std::string &path, - const std::string &kernel_name); -bool CompileJsonsInList(const std::string &dir_path, const std::vector &json_list); -void ExcludeTunedObj(const std::string &dir_path, std::set *kernel_names, - std::map *node_kernel); -} // namespace mindspore::graphkernel -#endif // MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_AKG_UTILS_H_ diff --git a/mindspore-lite/tools/graph_kernel/converter/basic_op_infer_shape.cc b/mindspore-lite/tools/graph_kernel/converter/basic_op_infer_shape.cc deleted file mode 100644 index 43dad81a33e2d828d3e05e691dbb1ccce957d9af..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/basic_op_infer_shape.cc +++ /dev/null @@ -1,353 +0,0 @@ -/** - * Copyright 2022-2025 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tools/graph_kernel/converter/basic_op_infer_shape.h" - -#include -#include -#include -#include -#include -#include - -#include "mindspore/ops/op_def/sequence_ops.h" -#include "mindspore/ops/op_def/framework_ops.h" -#include "backend/common/graph_kernel/core/graph_kernel_callback.h" -#include "utils/anf_utils.h" -#include "src/common/ops/anf_utils.h" -#include "src/common/primitive_t_utils.h" -#include "src/common/ops/populate/populate_register.h" -#include "tools/optimizer/graph/lite_tensor_extractor.h" -#include "src/litert/infer_manager.h" -#include "mindspore/ccsrc/include/common/utils/utils.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_d.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_u.h" - -namespace mindspore::graphkernel { -namespace { -void SetAbstractShape(const abstract::AbstractBasePtr &abs, const BaseShapePtr &shape) { - MS_EXCEPTION_IF_NULL(abs); - abs->set_shape(shape); -} - -void SetAbstract(const CNodePtr &cnode) { - if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) { - auto input2 = cnode->input(kInputNodeOutputIndexInTupleGetItem); - auto item_idx = LongToSize(AnfUtils::GetIntValue(input2)); - auto abs_tuple = dyn_cast(AnfUtils::VisitKernel(cnode, item_idx).first->abstract()); - MS_EXCEPTION_IF_NULL(abs_tuple); - cnode->set_abstract(abs_tuple->elements()[item_idx]); - return; - } - if (IsOneOfPrimitiveCNode(cnode, {prim::kPrimDepend, prim::kPrimLoad, prim::kPrimUpdateState})) { - cnode->set_abstract(cnode->input(1)->abstract()); - return; - } -} - -BaseShapePtr AllGatherInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { - if (input_args.empty()) { - return nullptr; - } - MS_EXCEPTION_IF_NULL(primitive); - auto shape_ptr = CheckAndConvertUtils::GetTensorInputShape(primitive->name(), input_args, 0); - MS_EXCEPTION_IF_NULL(shape_ptr); - auto x_shape = shape_ptr->shape(); - auto rank_list = primitive->GetAttr("rank_list"); - if (rank_list->isa()) { - auto rank_list_ptr = rank_list->cast(); - MS_EXCEPTION_IF_NULL(rank_list_ptr); - auto out_shape = x_shape; - if (!out_shape.empty() && out_shape[0] > 0) { - out_shape[0] *= SizeToLong(rank_list_ptr->size()); - } - return std::make_shared(out_shape); - } - return nullptr; -} - -BaseShapePtr ShapeInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { - if (input_args.empty()) { - return nullptr; - } - MS_EXCEPTION_IF_NULL(primitive); - auto shape_ptr = CheckAndConvertUtils::GetTensorInputShape(primitive->name(), input_args, 0); - MS_EXCEPTION_IF_NULL(shape_ptr); - auto x_shape = shape_ptr->shape(); - int64_t rank = IsDynamicRank(x_shape) ? -1 : SizeToLong(x_shape.size()); - return std::make_shared(ShapeVector{rank}); -} - -AbstractBasePtrList RectifyBatchMatMul(const AbstractBasePtrList &orig_abs_list) { - AbstractBasePtrList abs_list = orig_abs_list; - if (abs_list.size() < kIndex2) { - return abs_list; - } - auto shape0_ptr = CheckAndConvertUtils::GetTensorInputShape("BatchMatMul", orig_abs_list, 0); - MS_EXCEPTION_IF_NULL(shape0_ptr); - auto x0_shape = shape0_ptr->shape(); - auto shape1_ptr = CheckAndConvertUtils::GetTensorInputShape("BatchMatMul", orig_abs_list, 1); - MS_EXCEPTION_IF_NULL(shape1_ptr); - auto x1_shape = shape1_ptr->shape(); - if (x0_shape.size() < x1_shape.size()) { - ShapeVector new_x0_shape(x1_shape.size() - x0_shape.size(), 1); - new_x0_shape.insert(new_x0_shape.end(), x0_shape.begin(), x0_shape.end()); - abs_list[0] = orig_abs_list[0]->Clone(); - SetAbstractShape(abs_list[0], std::make_shared(new_x0_shape)); - } else if (x0_shape.size() > x1_shape.size()) { - ShapeVector new_x1_shape(x0_shape.size() - x1_shape.size(), 1); - new_x1_shape.insert(new_x1_shape.end(), x1_shape.begin(), x1_shape.end()); - abs_list[1] = orig_abs_list[1]->Clone(); - SetAbstractShape(abs_list[1], std::make_shared(new_x1_shape)); - } - return abs_list; -} - -using OpInferFunc = std::function &)>; -using OpRectifyFunc = std::function; -} // namespace - -inline mindspore::Format FormatStringToEnum(const std::string &format) { - std::unordered_map format_converter = {{kOpFormat_NHWC, mindspore::NHWC}, - {kOpFormat_NCHW, mindspore::NCHW}}; - auto iter = format_converter.find(format); - if (iter == format_converter.end()) { - MS_LOG(WARNING) << "Unsupported format [" << format << "] in GraphKernel"; - return mindspore::DEFAULT_FORMAT; - } - return iter->second; -} - -void ExtractInputs(const CNodePtr &cnode, std::vector *inputs_holder, std::vector *inputs) { - std::vector const_inputs; - size_t const_index = 0; - if (opt::LiteTensorExtractor::GetCNodeConstInputs(cnode, converter::kFmkTypeMs, false, false, &const_inputs) != - lite::RET_OK) { - MS_LOG(ERROR) << "get const inputs failed."; - return; - } - auto cb = Callback::Instance(); - for (size_t index = 1; index < cnode->size(); index++) { - if (cnode->input(index)->isa()) { - std::vector shape; - ShapeVector shp = cb->GetInputShape(cnode, index - 1); - (void)std::transform(shp.begin(), shp.end(), std::back_inserter(shape), LongToInt); - auto format = cb->GetInputFormat(cnode, index - 1); - (void)inputs_holder->emplace_back( - std::make_shared(cb->GetInputType(cnode, index - 1), shape, FormatStringToEnum(format))); - } else { - if (const_index >= const_inputs.size()) { - MS_LOG(WARNING) << "const_index " << const_index << " is out of range of const_inputs " << const_inputs.size(); - } else { - (void)inputs_holder->emplace_back(const_inputs[const_index++]); - } - } - } - (void)std::transform(inputs_holder->cbegin(), inputs_holder->cend(), std::back_inserter(*inputs), - [](const TensorPtr &input) { return input.get(); }); -} - -void ExtractOutputs(const CNodePtr &cnode, std::vector *out_holder, std::vector *outputs) { - auto cb = Callback::Instance(); - size_t output_num = AnfUtils::GetOutputTensorNum(cnode); - for (size_t index = 0; index < output_num; index++) { - auto format = cb->GetOutputFormat(cnode, index); - (void)out_holder->emplace_back( - std::make_shared(cb->GetOutputType(cnode, index), std::vector(), FormatStringToEnum(format))); - } - (void)std::transform(out_holder->cbegin(), out_holder->cend(), std::back_inserter(*outputs), - [](const TensorPtr &output) { return output.get(); }); -} - -void BasicOpInferShape::InferShapeRealKernel(const CNodePtr &cnode) { - auto anf_prim = GetCNodePrimitive(cnode); - MS_EXCEPTION_IF_NULL(anf_prim); - (void)anf_prim->AddAttr(opt::kInferDone, MakeValue(false)); - - std::vector inputs_holder; - std::vector inputs; - ExtractInputs(cnode, &inputs_holder, &inputs); - - std::vector outputs_holder; - std::vector outputs; - ExtractOutputs(cnode, &outputs_holder, &outputs); - - auto prim_t = lite::GetPrimitiveT(cnode->input(0)); - if (prim_t == nullptr) { - MS_LOG(DEBUG) << "prim_t is nullptr"; - return; - } - const size_t INITIAL_SIZE = 1024; - flatbuffers::FlatBufferBuilder fbb(INITIAL_SIZE); - auto prim = lite::ConvertToPrimitive(prim_t.get(), &fbb); - if (prim == nullptr) { - MS_LOG(ERROR) << "get primitive failed."; - fbb.Clear(); - return; - } - - auto ret = lite::KernelInferShape(inputs, outputs, prim, {}, lite::SCHEMA_CUR); - if (ret == lite::RET_NOT_SUPPORT) { - auto parameter_gen = lite::PopulateRegistry::GetInstance()->GetParameterCreator( - static_cast(prim->value_type()), lite::SCHEMA_CUR); - if (parameter_gen == nullptr) { - MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType(prim->value_type()); - fbb.Clear(); - return; - } - auto parameter = parameter_gen(prim); - if (parameter == nullptr) { - MS_LOG(ERROR) << "parameter is nullptr."; - fbb.Clear(); - return; - } - ret = lite::KernelInferShape(inputs, outputs, parameter); - if (parameter->destroy_func_ != nullptr) { - parameter->destroy_func_(parameter); - } - free(parameter); - parameter = nullptr; - } - fbb.Clear(); - if (ret == lite::RET_OK) { - (void)anf_prim->AddAttr(opt::kInferDone, MakeValue(true)); - } - if (ret == lite::RET_OK || ret == lite::RET_INFER_INVALID) { - (void)SetCNodeAbstract(cnode, outputs, ret); - } else { - MS_LOG(WARNING) << "infer shape failed. node: " << cnode->fullname_with_scope(); - } -} - -void BasicOpInferShape::InsertAbstract(const CNodePtr &cnode) { SetAbstract(cnode); } - -void BasicOpInferShape::Infer(const CNodePtr &cnode) { - if (AnfUtils::IsRealKernel(cnode)) { - InferShapeRealKernel(cnode); - } else { - InsertAbstract(cnode); - } -} - -bool DynOpInferShape::HasDynamicShapeInput(const FuncGraphPtr &func_graph) const { - MS_EXCEPTION_IF_NULL(func_graph); - const auto ¶ms = func_graph->parameters(); - for (const auto ¶m : params) { - if (param == nullptr) { - continue; - } - auto param_shape = param->Shape(); - if (param_shape != nullptr && param_shape->IsDynamic()) { - return true; - } - } - return false; -} - -bool DynOpInferShape::InferShapeRealKernel(const CNodePtr &cnode) const { - auto prim = GetCNodePrimitive(cnode); - if (prim == nullptr) { - MS_LOG(ERROR) << "prim is nullptr for cnode: " << cnode->fullname_with_scope(); - return false; - } - // collect op inputs abstract - AbstractBasePtrList abs_list; - abs_list.reserve(cnode->size()); - for (size_t i = 1; i < cnode->size(); ++i) { - const auto &input = cnode->input(i); - if (input == nullptr) { - continue; - } - auto abs = input->abstract(); - if (abs == nullptr && input->isa()) { - auto value_ptr = input->cast(); - MS_EXCEPTION_IF_NULL(value_ptr); - auto v = value_ptr->value(); - MS_EXCEPTION_IF_NULL(v); - abs = v->ToAbstract(); - } - if (abs == nullptr) { - MS_LOG(ERROR) << "inputs[" << i << "] has no abstract for cnode: " << cnode->fullname_with_scope(); - return false; - } - abs_list.push_back(abs); - } - // some op has no C++ infer - static std::unordered_map infer_func_map{{"AllGather", AllGatherInferShape}}; - auto prim_name = prim->name(); - auto iter = infer_func_map.find(prim_name); - if (iter != infer_func_map.end()) { - SetAbstractShape(cnode->abstract(), iter->second(prim, abs_list)); - return true; - } - // core/ops 'Shape' returns AbstractTuple, which will change the original abstract type - if (prim_name == "Shape" && cnode->abstract()->isa()) { - SetAbstractShape(cnode->abstract(), ShapeInferShape(prim, abs_list)); - return true; - } - // some op's abstract does not satisfy core/ops infer - if (prim_name == "StridedSlice" || prim_name == "PromptFlashAttention") { - return true; - } - static std::unordered_map rectify_map{{"BatchMatMul", RectifyBatchMatMul}}; - auto rec_iter = rectify_map.find(prim_name); - if (rec_iter != rectify_map.end()) { - abs_list = rec_iter->second(abs_list); - } - auto found = abstract::GetPrimitiveInferImpl(prim); - if (found.has_value() && found.value().IsImplInferShapeAndType()) { - auto infer_impl = found.value(); - SetAbstractShape(cnode->abstract(), infer_impl.InferShape(prim, abs_list)); - return true; - } - MS_LOG(ERROR) << "Can not find infer shape function for " << prim_name; - return false; -} - -bool DynOpInferShape::InferShape(const CNodePtr &cnode) const { - if (AnfUtils::IsRealKernel(cnode)) { - if (!InferShapeRealKernel(cnode)) { - MS_LOG(ERROR) << "infer shape failed for cnode: " << cnode->fullname_with_scope(); - return false; - } - } else { - if (IsPrimitiveCNode(cnode, prim::kPrimLoad)) { - return true; - } - SetAbstract(cnode); - } - return true; -} - -bool DynOpInferShape::Run(const FuncGraphPtr &func_graph) { - if (!HasDynamicShapeInput(func_graph)) { - return false; - } - MS_LOG(INFO) << "Dynamic shape infer for func graph: " << func_graph->ToString(); - auto nodes = TopoSort(func_graph->output()); - for (const auto &node : nodes) { - if (node->isa()) { - auto cnode = node->cast(); - if (!InferShape(cnode)) { - break; - } - } - } - return true; -} -} // namespace mindspore::graphkernel diff --git a/mindspore-lite/tools/graph_kernel/converter/basic_op_infer_shape.h b/mindspore-lite/tools/graph_kernel/converter/basic_op_infer_shape.h deleted file mode 100644 index b655eff766081bdee3e8b8640464427946baf327..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/basic_op_infer_shape.h +++ /dev/null @@ -1,45 +0,0 @@ -/** - * Copyright 2022-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_BASIC_OP_INFER_SHAPE_H_ -#define MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_BASIC_OP_INFER_SHAPE_H_ -#include "tools/optimizer/graph/node_infershape.h" -#include "include/backend/optimizer/pass.h" - -namespace mindspore::graphkernel { -class BasicOpInferShape : public opt::NodeInferShape { - public: - BasicOpInferShape() : opt::NodeInferShape() {} - ~BasicOpInferShape() = default; - void Infer(const CNodePtr &cnode); - - private: - void InferShapeRealKernel(const CNodePtr &cnode); - void InsertAbstract(const CNodePtr &cnode); -}; - -class DynOpInferShape : public opt::Pass { - public: - DynOpInferShape() : Pass("dynamic_infer_shape") {} - ~DynOpInferShape() override = default; - bool Run(const FuncGraphPtr &func_graph) override; - - private: - bool HasDynamicShapeInput(const FuncGraphPtr &func_graph) const; - bool InferShapeRealKernel(const CNodePtr &cnode) const; - bool InferShape(const CNodePtr &cnode) const; -}; -} // namespace mindspore::graphkernel -#endif // MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_BASIC_OP_INFER_SHAPE_H_ diff --git a/mindspore-lite/tools/graph_kernel/converter/callback_impl.cc b/mindspore-lite/tools/graph_kernel/converter/callback_impl.cc deleted file mode 100644 index a3b019a55e057cba2b584d1e571603e7eb8554d1..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/callback_impl.cc +++ /dev/null @@ -1,192 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tools/graph_kernel/converter/callback_impl.h" - -#include -#include -#include -#include -#include "mindspore/ops/op_def/sequence_ops.h" -#include "ir/dtype.h" -#include "ir/func_graph.h" -#include "utils/anf_utils.h" -#include "include/common/utils/utils.h" -#include "tools/graph_kernel/common/utils.h" -#include "backend/common/graph_kernel/core/graph_kernel_utils.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" - -namespace mindspore::graphkernel { -ShapeVector CallbackImpl::GetInputShape(const AnfNodePtr &node, size_t i) { return GetInputInferShape(node, i); } - -ShapeVector CallbackImpl::GetOutputShape(const AnfNodePtr &node, size_t i) { return GetOutputInferShape(node, i); } - -ShapeVector CallbackImpl::GetInputInferShape(const AnfNodePtr &node, size_t i) { - MS_EXCEPTION_IF_NULL(node); - KernelWithIndex kernel_with_index = AnfUtils::VisitKernel(node->cast()->input(i + 1), 0); - return GetOutputInferShape(kernel_with_index.first, kernel_with_index.second); -} - -ShapeVector CallbackImpl::GetOutputInferShape(const AnfNodePtr &node, size_t i) { - MS_EXCEPTION_IF_NULL(node); - auto base_shape = node->Shape(); - MS_EXCEPTION_IF_NULL(base_shape); - if (base_shape->isa()) { - if (i == 0) { - return base_shape->cast()->shape(); - } - MS_LOG(EXCEPTION) << "The node " << node->DebugString() << " is a single output node but got index [" << i << "]"; - } else if (base_shape->isa()) { - auto tuple_shape = base_shape->cast(); - MS_EXCEPTION_IF_NULL(tuple_shape); - if (i >= tuple_shape->size()) { - MS_LOG(EXCEPTION) << "Output index " << i << " is larger than output number " << tuple_shape->size() - << " in node " << node->DebugString(); - } - auto b_shp = (*tuple_shape)[i]; - if (b_shp->isa()) { - return b_shp->cast()->shape(); - } else if (b_shp->isa()) { - return ShapeVector(); - } else { - MS_LOG(EXCEPTION) << "The output type of ApplyKernel index:" << i - << " should be a NoShape , ArrayShape or a TupleShape, but it is " << base_shape->ToString() - << " node :" << node->DebugString(); - } - } else if (base_shape->isa()) { - return ShapeVector(); - } - MS_LOG(EXCEPTION) << "The output type of ApplyKernel should be a NoShape , ArrayShape or a TupleShape, but it is " - << base_shape->ToString() << " node : " << node->DebugString(); -} - -TypeId CallbackImpl::GetInputType(const AnfNodePtr &node, size_t i) { return GetInputInferType(node, i); } - -TypeId CallbackImpl::GetOutputType(const AnfNodePtr &node, size_t i) { return GetOutputInferType(node, i); } - -TypeId CallbackImpl::GetInputInferType(const AnfNodePtr &node, size_t i) { - KernelWithIndex kernel_with_index = AnfUtils::VisitKernel(node->cast()->input(i + 1), 0); - return GetOutputInferType(kernel_with_index.first, kernel_with_index.second); -} - -TypeId CallbackImpl::GetOutputInferType(const AnfNodePtr &node, size_t i) { - MS_EXCEPTION_IF_NULL(node); - TypePtr type_ptr = node->Type(); - MS_EXCEPTION_IF_NULL(type_ptr); - if (type_ptr->isa()) { - auto tuple_ptr = type_ptr->cast(); - MS_EXCEPTION_IF_NULL(tuple_ptr); - if (i >= tuple_ptr->size()) { - MS_LOG(EXCEPTION) << "Output index " << i << " must be less than output number " << tuple_ptr->size() - << " in node " << node->DebugString(); - } - type_ptr = (*tuple_ptr)[i]; - MS_EXCEPTION_IF_NULL(type_ptr); - } - if (type_ptr->isa()) { - auto tensor_ptr = type_ptr->cast(); - MS_EXCEPTION_IF_NULL(tensor_ptr); - TypePtr elem = tensor_ptr->element(); - MS_EXCEPTION_IF_NULL(elem); - return elem->type_id(); - } - return type_ptr->type_id(); -} - -std::string GetDefaultFormat() { return kOpFormat_DEFAULT; } - -std::string CallbackImpl::GetInputFormat(const AnfNodePtr &node, size_t i) { - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto kernel_with_index = AnfUtils::VisitKernel(cnode->input(i + 1), 0); - return GetOutputFormat(kernel_with_index.first, kernel_with_index.second); -} - -std::string CallbackImpl::GetOutputFormat(const AnfNodePtr &node, size_t i) { - if (node->isa() || node->isa()) { - return GetOutputFormatFromAnfNode(node, 0); - } else { - return GetDefaultFormat(); - } -} - -std::string CallbackImpl::GetProcessor(const AnfNodePtr &node) { - auto target = GetTargetFromContextImpl(false); - if (target == "Ascend") { - return "aicore"; - } - if (target == "GPU") { - return "cuda"; - } - return "cpu"; -} - -std::string CallbackImpl::GetTargetFromContextImpl(bool detail) { - const auto &target = converter_param_->device; - if (target.find("Ascend") != std::string::npos) { - // target is Ascend/Ascend310/Ascend310P - if (!detail) { - return "Ascend"; - } - return target == "Ascend" ? "Ascend910" : target; - } - return !target.empty() ? target : "CPU"; -} - -void CallbackImpl::SetGraphKernelNodeKernelInfo(const AnfNodePtr &node) { - std::vector graph_output_format; - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto fg = GetCNodeFuncGraph(node); - MS_EXCEPTION_IF_NULL(fg); - AnfNodePtrList outputs; - if (IsPrimitiveCNode(fg->output(), prim::kPrimMakeTuple)) { - auto fg_output = fg->output()->cast(); - MS_EXCEPTION_IF_NULL(fg_output); - outputs.assign(fg_output->inputs().begin() + 1, fg_output->inputs().end()); - } else { - outputs.push_back(fg->output()); - } - for (size_t i = 0; i < outputs.size(); ++i) { - auto kernel_with_index = AnfUtils::VisitKernel(outputs[i], 0); - graph_output_format.push_back(GetOutputFormat(kernel_with_index.first, kernel_with_index.second)); - } - SetKernelInfoWithFormatToAnfNode(node, graph_output_format); - - auto inner_fg = GetCNodeFuncGraph(cnode); - MS_EXCEPTION_IF_NULL(inner_fg); - for (size_t i = 1; i < cnode->size(); ++i) { - SaveParameterFormat(inner_fg->parameters()[i - 1], GetInputFormat(node, i - 1)); - } -} - -void CallbackImpl::SetBasicNodeKernelInfo(const AnfNodePtr &node, const std::vector &outputs_info) { - std::vector output_formats; - for (size_t i = 0; i < outputs_info.size(); ++i) { - output_formats.push_back(outputs_info[i].format); - } - SetKernelInfoWithFormatToAnfNode(node, output_formats); -} - -void CallbackImpl::SaveParameterFormat(const AnfNodePtr &node, const std::string &format) { - std::vector output_format(1, format); - SetKernelInfoWithFormatToAnfNode(node, output_format); -} - -void CallbackImpl::SetEmptyKernelInfo(const AnfNodePtr &node) {} - -void CallbackImpl::ResetKernelInfo(const AnfNodePtr &node) {} -} // namespace mindspore::graphkernel diff --git a/mindspore-lite/tools/graph_kernel/converter/callback_impl.h b/mindspore-lite/tools/graph_kernel/converter/callback_impl.h deleted file mode 100644 index b4a5d6ab2824ea0c2bde40326b440b715efee53c..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/callback_impl.h +++ /dev/null @@ -1,63 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_CALLBACK_IMPL_H_ -#define MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_CALLBACK_IMPL_H_ -#include -#include -#include -#include -#include "utils/hash_map.h" -#include "backend/common/graph_kernel/core/graph_kernel_callback.h" -#include "tools/converter/cxx_api/converter_para.h" - -namespace mindspore::graphkernel { -using KernelWithIndex = std::pair; - -// to do: add this function to callback class. -// Get default format for format flexible nodes. -std::string GetDefaultFormat(); - -class CallbackImpl : public Callback { - public: - explicit CallbackImpl(const std::shared_ptr ¶m) : converter_param_(param) {} - ~CallbackImpl() = default; - - ShapeVector GetInputInferShape(const AnfNodePtr &node, size_t i) override; - ShapeVector GetOutputInferShape(const AnfNodePtr &node, size_t i) override; - ShapeVector GetInputShape(const AnfNodePtr &node, size_t i) override; - ShapeVector GetOutputShape(const AnfNodePtr &node, size_t i) override; - TypeId GetInputType(const AnfNodePtr &node, size_t i) override; - TypeId GetOutputType(const AnfNodePtr &node, size_t i) override; - TypeId GetInputInferType(const AnfNodePtr &node, size_t i) override; - TypeId GetOutputInferType(const AnfNodePtr &node, size_t i) override; - std::string GetInputFormat(const AnfNodePtr &node, size_t i) override; - std::string GetOutputFormat(const AnfNodePtr &node, size_t i) override; - std::string GetProcessor(const AnfNodePtr &node) override; - void SetGraphKernelNodeKernelInfo(const AnfNodePtr &node) override; - void SetBasicNodeKernelInfo(const AnfNodePtr &node, const std::vector &outputs_info) override; - void SetEmptyKernelInfo(const AnfNodePtr &node) override; - void ResetKernelInfo(const AnfNodePtr &node) override; - void ResetKernelInfoInputs(const AnfNodePtr &, const std::vector &indices) override {} - - private: - std::string GetTargetFromContextImpl(bool detail) override; - void SaveParameterFormat(const AnfNodePtr &node, const std::string &format); - mindspore::HashMap params_format_; - std::shared_ptr converter_param_; -}; -} // namespace mindspore::graphkernel -#endif // MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_CALLBACK_IMPL_H_ diff --git a/mindspore-lite/tools/graph_kernel/converter/conv_tuning_expander.cc b/mindspore-lite/tools/graph_kernel/converter/conv_tuning_expander.cc deleted file mode 100644 index 3994e2502a915921b58750ae4fff488f238e1264..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/conv_tuning_expander.cc +++ /dev/null @@ -1,266 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "tools/graph_kernel/converter/conv_tuning_expander.h" -#include -#include -#include -#include -#include - -#include "backend/common/graph_kernel/core/graph_kernel_callback.h" -#include "backend/common/graph_kernel/core/graph_kernel_utils.h" -#include "backend/common/graph_kernel/graph_kernel_flags.h" -#include "mindspore/ops/op_def/lite_ops.h" -#include "nlohmann/json.hpp" -#include "tools/graph_kernel/converter/akg/utils.h" -#include "utils/anf_utils.h" -#include "utils/file_utils.h" -#include "utils/hash_set.h" -#include "utils/ms_context.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" - -namespace mindspore::graphkernel { -bool IsSameNumberList(const std::vector &vec, int64_t n) { - return std::all_of(vec.begin(), vec.end(), [n](int64_t i) { return i == n; }); -} - -bool InvalidConvAttr(const std::vector &kernel_size, const std::vector &pad, - const std::vector &stride, const std::vector &dilation) { - constexpr int64_t one_kernel_size = 1; - constexpr int64_t two_kernel_size = 2; - constexpr int64_t winograd_kernel_size = 3; - if (!IsSameNumberList(pad, 0LL)) { - return true; - } - if ((IsSameNumberList(kernel_size, one_kernel_size) || IsSameNumberList(kernel_size, two_kernel_size) || - IsSameNumberList(kernel_size, winograd_kernel_size)) && - IsSameNumberList(stride, 1LL) && IsSameNumberList(dilation, 1LL)) { - return true; - } - return false; -} - -bool IsInvalidConv(const AnfNodePtr &node) { - auto cb = Callback::Instance(); - auto input_shape = cb->GetInputShape(node, 0); - if (input_shape.size() == 0) { - return true; - } - auto prim = GetCNodePrimitive(node); - MS_EXCEPTION_IF_NULL(prim); - const auto kernel_size = GetValue>(prim->GetAttr("kernel_size")); - const auto pad = GetValue>(prim->GetAttr("pad")); - const auto stride = GetValue>(prim->GetAttr("stride")); - const auto dilation = GetValue>(prim->GetAttr("dilation")); - if (InvalidConvAttr(kernel_size, pad, stride, dilation)) { - return true; - } - return false; -} - -bool IsBlackListOp(const AnfNodePtr &node) { - std::vector black_list = {prim::kPrimMatMulFusion}; - for (auto &prim : black_list) { - if (IsPrimitiveCNode(node, prim)) { - return true; - } - if (IsPrimitiveCNode(node, prim::kPrimConv2DFusion) && IsInvalidConv(node)) { - return true; - } - } - return false; -} - -std::vector GenConvShape(const AnfNodePtr &conv_node) { - auto cb = Callback::Instance(); - auto input_shape = cb->GetInputShape(conv_node, 0); - auto prim = GetCNodePrimitive(conv_node); - auto pads = GetValue>(prim->GetAttr("pad_list")); - constexpr size_t n_pos = 0; - constexpr size_t h_pos = 1; - constexpr size_t w_pos = 2; - constexpr size_t c_pos = 3; - input_shape[h_pos] += pads[n_pos] + pads[h_pos]; - input_shape[w_pos] += pads[w_pos] + pads[c_pos]; - return input_shape; -} - -nlohmann::json GenTuneInfo(const AnfNodePtr &conv_node, const std::map &former_conv_nodes, - const AnfNodePtrList &conv_list) { - nlohmann::json node_info; - auto prim = GetCNodePrimitive(conv_node); - node_info["op_id"] = std::find(conv_list.begin(), conv_list.end(), conv_node) - conv_list.begin(); - node_info["op_type"] = "Conv2D"; - node_info["impl"] = "direct"; - node_info["origin_input_shape"] = GenConvShape(conv_node); - MS_EXCEPTION_IF_NULL(prim); - node_info["dilation"] = GetValue>(prim->GetAttr("dilation")); - node_info["origin_format"] = "NHWC"; - node_info["group"] = GetValue(prim->GetAttr("group")); - node_info["in_channel"] = GetValue(prim->GetAttr("in_channel")); - node_info["kernel_size"] = GetValue>(prim->GetAttr("kernel_size")); - node_info["out_channel"] = GetValue(prim->GetAttr("out_channel")); - node_info["stride"] = GetValue>(prim->GetAttr("stride")); - std::vector prev_infos; - if (former_conv_nodes.empty()) { - nlohmann::json prev_info; - prev_info["op_id"] = -1; - prev_info["fixed_format"] = "Any"; - (void)prev_infos.emplace_back(prev_info); - } else { - for (auto iter : former_conv_nodes) { - nlohmann::json prev_info; - prev_info["op_id"] = std::find(conv_list.begin(), conv_list.end(), iter.first) - conv_list.begin(); - if (iter.second) { - // has black list node - prev_info["fixed_format"] = "NHWC"; - } else { - prev_info["fixed_format"] = "Any"; - } - (void)prev_infos.emplace_back(prev_info); - } - } - node_info["pre_nodes"] = prev_infos; - - return node_info; -} - -void TuneProcess(const std::string &json_file_name, const std::string &res_file_name, const std::string &akg_path) { - std::ostringstream py_cmd; - const auto &flags = GraphKernelFlags::GetInstance(); - py_cmd << kAddMSLiteAkg; - py_cmd << "import auto_tune\n"; - py_cmd << "auto_tune.tune_layout(\'" << json_file_name << "\', \'" << res_file_name << "\', " - << flags.cpu_refer_thread_num << ")\n"; - std::string cmd = "python -c \"" + py_cmd.str() + "\""; - MS_LOG(INFO) << "GraphKernel conv tuning content: \n" << cmd; - auto ret = std::system(cmd.c_str()); - if (!WIFEXITED(ret)) { - MS_LOG(ERROR) << "Tune process start fail! process content is as follows:\n" << cmd; - } - if (WEXITSTATUS(ret) != 0) { - MS_LOG(ERROR) << "Failed to tune json: " << json_file_name; - } -} - -void SetTuneAttrs(const AnfNodePtrList &conv_list, const std::string &res_file) { - std::ifstream f(res_file); - if (!f.is_open()) { - MS_LOG(WARNING) << "No conv tuning results!"; - return; - } - nlohmann::json tune_res; - f >> tune_res; - f.close(); - for (auto op_info : tune_res["graph"]) { - auto prim = GetCNodePrimitive(conv_list[op_info["op_id"]]); - prim->set_attr("tuned_src_format", MakeValue(std::string(op_info["src_format"]))); - prim->set_attr("tuned_dst_format", MakeValue(std::string(op_info["dst_format"]))); - prim->set_attr("tuned_dim", MakeValue(std::string(op_info["tuned_attrs"]["dim"]))); - prim->set_attr("akg_num_threads", MakeValue(int64_t(op_info["tuned_attrs"]["akg_num_threads"]))); - } -} - -void TuneConvOps(const AnfNodePtrList &conv_list) { - auto dir_path = FileUtils::CreateNotExistDirs(std::string("./conv_tune")); - if (!dir_path.has_value()) { - MS_LOG(ERROR) << "Failed to CreateNotExistDirs: ./conv_tune, start tuning failed"; - return; - } - nlohmann::json tune_info; - std::vector conv_infos; - for (auto &conv_node : conv_list) { - MS_EXCEPTION_IF_NULL(conv_node); - mindspore::HashSet visited; - std::function dfs; - std::map former_conv_nodes; - dfs = [&dfs, &former_conv_nodes, &visited](const AnfNodePtr &node) { - (void)visited.insert(node); - auto cnode = node->cast(); - if (cnode != nullptr) { - auto inputs = cnode->inputs(); - bool has_black_node = false; - for (size_t i = 1; i < inputs.size(); i++) { - if (inputs[i]->cast() == nullptr || visited.count(inputs[i]) != 0) { - continue; - } else if (IsPrimitiveCNode(inputs[i], prim::kPrimConv2DFusion) && !IsInvalidConv(inputs[i])) { - former_conv_nodes[inputs[i]] = has_black_node; - has_black_node = false; - continue; - } else { - if (IsBlackListOp(inputs[i])) { - has_black_node = true; - } - dfs(inputs[i]); - } - } - } - }; - dfs(conv_node); - (void)conv_infos.emplace_back(GenTuneInfo(conv_node, former_conv_nodes, conv_list)); - } - tune_info["graph"] = conv_infos; - tune_info["backend"] = "cpu"; - tune_info["feature"] = common::GetEnv("MS_CPU_FEATURE"); - std::string input_file = dir_path.value() + "/input.json"; - std::string output_file = dir_path.value() + "/output.json"; - std::string akg_path = dir_path.value() + "/akg_path.txt"; - std::ofstream fout(input_file, std::ios::trunc); - fout << tune_info.dump() << std::endl; - fout.close(); - TuneProcess(input_file, output_file, akg_path); - SetTuneAttrs(conv_list, output_file); -} - -std::vector ConvTuningExpander::InitOpList() { - auto expand_only_list = GraphKernelFlags::GetInstance().enable_expand_ops_only; - auto conv_expand_list = GraphKernelExpanderLite::ConvTuningExpanderOps(); - if (expand_only_list.empty()) { - return conv_expand_list; - } - std::vector conv_only_list; - for (auto conv_expand : conv_expand_list) { - if (std::find(expand_only_list.begin(), expand_only_list.end(), conv_expand->name()) != expand_only_list.end()) { - conv_only_list.emplace_back(conv_expand); - } - } - return conv_only_list; -} - -bool ConvTuningExpander::Run(const FuncGraphPtr &func_graph) { - if (GraphKernelExpanderLite::DisableConvTuning()) { - return false; - } - bool changed = false; - auto valid_op_list = InitOpList(); - if (std::find(valid_op_list.begin(), valid_op_list.end(), prim::kPrimConv2DFusion) != valid_op_list.end()) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(func_graph->get_return()); - auto todos = TopoSort(func_graph->get_return()); - AnfNodePtrList conv_list; - for (auto &node : todos) { - if (IsPrimitiveCNode(node, prim::kPrimConv2DFusion) && !IsInvalidConv(node)) { - (void)conv_list.emplace_back(node->cast()); - changed = true; - } - } - TuneConvOps(conv_list); - } - changed = GraphKernelExpanderLite::Run(func_graph) || changed; - return changed; -} -} // namespace mindspore::graphkernel diff --git a/mindspore-lite/tools/graph_kernel/converter/conv_tuning_expander.h b/mindspore-lite/tools/graph_kernel/converter/conv_tuning_expander.h deleted file mode 100644 index 1aae0f5d42b8311e49500405023e9b7304de40b7..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/conv_tuning_expander.h +++ /dev/null @@ -1,37 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_CONV_TUNING_EXPANDER_H_ -#define MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_CONV_TUNING_EXPANDER_H_ -#include -#include "ir/func_graph.h" -#include "tools/graph_kernel/converter/graph_kernel_expander_lite.h" - -namespace mindspore::graphkernel { -class ConvTuningExpander : public GraphKernelExpanderLite { - public: - ConvTuningExpander() : GraphKernelExpanderLite("conv_tuning_expander") {} - ~ConvTuningExpander() override = default; - bool Run(const FuncGraphPtr &func_graph) override; - - protected: - std::vector InitOpList() override; -}; - -bool InvalidConvAttr(const std::vector &kernel_size, const std::vector &pad, - const std::vector &stride, const std::vector &dilation); -} // namespace mindspore::graphkernel - -#endif // MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_CONV_TUNING_EXPANDER_H_ diff --git a/mindspore-lite/tools/graph_kernel/converter/eliminate_maketuple_getitem.h b/mindspore-lite/tools/graph_kernel/converter/eliminate_maketuple_getitem.h deleted file mode 100644 index 011e471f7f604c3634619e5ed59c50542afd3e2d..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/eliminate_maketuple_getitem.h +++ /dev/null @@ -1,47 +0,0 @@ - -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_ELIMINATE_MAKETUPLE_GETITEM_H_ -#define MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_ELIMINATE_MAKETUPLE_GETITEM_H_ - -#include -#include "include/backend/optimizer/pass.h" -#include "ir/func_graph.h" -#include "backend/common/graph_kernel/core/graph_builder.h" - -namespace mindspore::graphkernel { -/** - * @brief Eliminate redundant MakeTuple-Getitem edges - * @example - * %1 = op1 - * %2 = op2 - * %3 = make_tuple(%1, %2) - * %4 = tuple_getitem(%3, 0) - * %5 = tuple_getitem(%3, 1) - * %6 = op6(%4, %5) - * --> - * %1 = op1 - * %2 = op2 - * %6 = op6(%1, %2) - */ -class ElimMaketupleGetitem : public opt::Pass { - public: - ElimMaketupleGetitem() : Pass("elim_maketuple_getitem") {} - ~ElimMaketupleGetitem() override = default; - bool Run(const FuncGraphPtr &func_graph) override { return EliminateMaketupleGetitem(func_graph); } -}; -} // namespace mindspore::graphkernel -#endif // MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_ELIMINATE_MAKETUPLE_GETITEM_H_ diff --git a/mindspore-lite/tools/graph_kernel/converter/expanders/activation.cc b/mindspore-lite/tools/graph_kernel/converter/expanders/activation.cc deleted file mode 100644 index f6298025a6817d5872ebdbe354e6ba73c1027228..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/expanders/activation.cc +++ /dev/null @@ -1,60 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -#include "backend/common/graph_kernel/expanders/op_desc_registry.h" -#include "tools/graph_kernel/converter/expanders/activation.h" -#include "mindapi/base/types.h" -#include "ir/dtype.h" - -namespace mindspore::graphkernel::expanders { -NodePtr GetActivationExpander(const inner::GraphBuilder &gb, const NodePtrList &inputs, int64_t activation_type) { - switch (activation_type) { - case ActivationType::RELU: - return ReluExpand(gb, inputs); - case ActivationType::SIGMOID: - return SigmoidExpand(gb, inputs); - case ActivationType::GELU: - return GeluExpand(gb, inputs); - case ActivationType::SWISH: - return gb.Mul(inputs[0], SigmoidExpand(gb, inputs)); - default: - return inputs[0]; - } -} - -class Activation : public OpDesc { - public: - Activation() { - std::initializer_list attrs{"activation_type"}; - (void)validators_.emplace_back(std::make_unique(attrs)); - std::set activation_types = {ActivationType::NO_ACTIVATION, ActivationType::RELU, ActivationType::SIGMOID, - ActivationType::GELU, ActivationType::SWISH}; - (void)validators_.emplace_back(std::make_unique(activation_types)); - } - ~Activation() = default; - - protected: - NodePtrList Expand(const NodePtrList &inputs) override { - auto activation_type = GetValue(attrs_["activation_type"]); - return {GetActivationExpander(gb, inputs, activation_type)}; - } -}; -EXPANDER_OP_DESC_REGISTER("Activation", Activation); -} // namespace mindspore::graphkernel::expanders diff --git a/mindspore-lite/tools/graph_kernel/converter/expanders/activation.h b/mindspore-lite/tools/graph_kernel/converter/expanders/activation.h deleted file mode 100644 index 5574e4e961ea8d3baee280fc4749e9e558c27711..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/expanders/activation.h +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_EXPANDERS_ACTIVATION_H_ -#define MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_EXPANDERS_ACTIVATION_H_ - -#include -#include - -#include "backend/common/graph_kernel/expanders/utils.h" - -namespace mindspore::graphkernel::expanders { -class CheckActivationType : public Validator { - public: - explicit CheckActivationType(const std::set &s) : activation_types_(s) {} - explicit CheckActivationType(int64_t type) { activation_types_.insert(type); } - ~CheckActivationType() = default; - bool Check(const OpDesc &e) override { - auto iter = e.Attrs().find("activation_type"); - if (iter == e.Attrs().end()) { - return true; - } - auto activation_type = GetValue(iter->second); - if (activation_types_.find(activation_type) == activation_types_.end()) { - MS_LOG(INFO) << "Activation type " << activation_type << " not supported yet!"; - return false; - } - return true; - } - - private: - std::set activation_types_; -}; - -NodePtr GetActivationExpander(const inner::GraphBuilder &gb, const NodePtrList &inputs, int64_t activation_type); -} // namespace mindspore::graphkernel::expanders -#endif // MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_EXPANDERS_ACTIVATION_H_ diff --git a/mindspore-lite/tools/graph_kernel/converter/expanders/add_fusion.cc b/mindspore-lite/tools/graph_kernel/converter/expanders/add_fusion.cc deleted file mode 100644 index 714b1a6b2e09801bedd70fa11234375a2886f96c..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/expanders/add_fusion.cc +++ /dev/null @@ -1,39 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include "backend/common/graph_kernel/expanders/op_desc_registry.h" -#include "tools/graph_kernel/converter/expanders/activation.h" -#include "mindapi/base/types.h" -#include "ir/dtype.h" - -namespace mindspore::graphkernel::expanders { -class AddFusion : public OpDesc { - public: - AddFusion() { (void)validators_.emplace_back(std::make_unique(ActivationType::NO_ACTIVATION)); } - ~AddFusion() = default; - - protected: - NodePtrList Expand(const NodePtrList &inputs) override { - const auto &input_x = inputs[0]; - const auto &input_y = inputs[1]; - auto result = gb.Add(input_x, input_y); - return {result}; - } -}; -EXPANDER_OP_DESC_REGISTER("AddFusion", AddFusion); -} // namespace mindspore::graphkernel::expanders diff --git a/mindspore-lite/tools/graph_kernel/converter/expanders/conv2d_fusion.cc b/mindspore-lite/tools/graph_kernel/converter/expanders/conv2d_fusion.cc deleted file mode 100644 index c5f2dd3046c82a387e6458104725930162985cbe..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/expanders/conv2d_fusion.cc +++ /dev/null @@ -1,139 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include "backend/common/graph_kernel/expanders/op_desc_registry.h" -#include "tools/graph_kernel/converter/conv_tuning_expander.h" -#include "tools/graph_kernel/converter/expanders/activation.h" -#include "mindapi/base/types.h" -#include "ir/dtype.h" - -namespace mindspore::graphkernel::expanders { -class CheckValidAttr : public Validator { - public: - bool Check(const OpDesc &e) override { - const auto kernel_size = GetValue>(e.Attrs().find("kernel_size")->second); - const auto pad = GetValue>(e.Attrs().find("pad")->second); - const auto stride = GetValue>(e.Attrs().find("stride")->second); - const auto dilation = GetValue>(e.Attrs().find("dilation")->second); - if (InvalidConvAttr(kernel_size, pad, stride, dilation)) { - return false; - } - return true; - } -}; - -class CheckDepthWise : public Validator { - public: - bool Check(const OpDesc &e) override { - if (e.Attrs().count("is_depth_wise") != 0) { - if (e.Attrs().count("group") == 0) { - return false; - } - const auto group = GetValue(e.Attrs().find("group")->second); - const auto c_in = e.InputsInfo()[0].shape[3]; - if (group != c_in) { - return false; - } - } - return true; - } -}; - -class Conv2DFusion : public OpDesc { - public: - Conv2DFusion() { - std::initializer_list attrs{"kernel_size", "out_channel", "stride", "dilation", - "in_channel", "pad_list", "pad_mode", "weight_coo", - "weight_coi", "weight_cio", "weight_cii"}; - (void)validators_.emplace_back(std::make_unique(attrs)); - (void)validators_.emplace_back(std::make_unique()); - (void)validators_.emplace_back(std::make_unique()); - } - ~Conv2DFusion() = default; - - protected: - NodePtrList Expand(const NodePtrList &inputs) override { - const auto &data = inputs[0]; - const auto &weight = inputs[1]; - auto data_format = data->format; - - // pad_top, pad_bottom, pad_left, pad_right - std::vector pads = GetValue>(attrs_["pad_list"]); - auto c_i_i = GetValue(attrs_["weight_cii"]); - auto c_o_i = GetValue(attrs_["weight_coi"]); - - std::string conv_format = "NCHW" + std::to_string(c_i_i) + "c"; - auto data_tp = gb.Emit("LayoutTransform", {data}, - {{"src_format", MakeValue(data_format)}, {"dst_format", MakeValue(conv_format)}}); - - // PAD: NCHWc->NCHWc - auto pad_n = pads[0]; - auto pad_h = pads[1]; - auto pad_w = pads[2]; - auto pad_c = pads[3]; - ShapeVector head_pad{0, 0, pad_n, pad_w, 0}; - ShapeVector tail_pad{0, 0, pad_h, pad_c, 0}; - - inner::NodePtr data_pad; - if (pad_n == 0 && pad_h == 0 && pad_w == 0 && pad_c == 0) { - data_pad = data_tp; - } else { - data_pad = - gb.Emit("PadAkg", {data_tp}, - {{"head", MakeValue(head_pad)}, {"tail", MakeValue(tail_pad)}, {"pad_val", MakeValue((int64_t)0)}}); - } - - // update attrs after pad - auto updated_attrs = attrs_; - updated_attrs["pad_mode"] = MakeValue("VALID"); - auto pad_val = MakeValue(0LL); - updated_attrs["pad_list"] = MakeValue({pad_val, pad_val, pad_val, pad_val}); - updated_attrs["data_format"] = MakeValue(kOpFormat_NC1HWC0); - std::string conv_out_format = "NCHW" + std::to_string(c_o_i) + "c"; - updated_attrs["conv_out_format"] = MakeValue(conv_out_format); - auto result_nchwc = gb.Emit("Conv2D", {data_pad, weight}, updated_attrs); - - inner::NodePtr result_nchwc_bias; - constexpr size_t has_bias_inputs_size = 3; - if (inputs.size() == has_bias_inputs_size) { - const auto &bias = inputs[2]; - auto bias_dim = bias->shape[0]; - auto conv_c = result_nchwc->shape[4]; - ShapeVector bias_shape{1, bias_dim / conv_c, 1, 1, conv_c}; - auto bias_nchwc = gb.Reshape(bias, bias_shape); - result_nchwc_bias = gb.Add(result_nchwc, bias_nchwc); - } else { - result_nchwc_bias = result_nchwc; - } - - inner::NodePtr result_nchwc_act; - if (attrs_.find("activation_type") != attrs_.end()) { - auto act_type = GetValue(attrs_["activation_type"]); - result_nchwc_act = GetActivationExpander(gb, {result_nchwc_bias}, act_type); - } else { - result_nchwc_act = result_nchwc_bias; - } - - auto result_rs = gb.Emit("LayoutTransform", {result_nchwc_act}, - {{"src_format", MakeValue(conv_out_format)}, {"dst_format", MakeValue(data_format)}}); - - return {result_rs}; - } -}; -EXPANDER_OP_DESC_REGISTER("Conv2DFusion", Conv2DFusion); -} // namespace mindspore::graphkernel::expanders diff --git a/mindspore-lite/tools/graph_kernel/converter/expanders/div_fusion.cc b/mindspore-lite/tools/graph_kernel/converter/expanders/div_fusion.cc deleted file mode 100644 index c1503bd8b8afd8d348b4d2f931d08d67cd7283c9..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/expanders/div_fusion.cc +++ /dev/null @@ -1,39 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include "backend/common/graph_kernel/expanders/op_desc_registry.h" -#include "tools/graph_kernel/converter/expanders/activation.h" -#include "mindapi/base/types.h" -#include "ir/dtype.h" - -namespace mindspore::graphkernel::expanders { -class DivFusion : public OpDesc { - public: - DivFusion() { (void)validators_.emplace_back(std::make_unique(ActivationType::NO_ACTIVATION)); } - ~DivFusion() = default; - - protected: - NodePtrList Expand(const NodePtrList &inputs) override { - const auto &input_x = inputs[0]; - const auto &input_y = inputs[1]; - auto result = gb.Div(input_x, input_y); - return {result}; - } -}; -EXPANDER_OP_DESC_REGISTER("DivFusion", DivFusion); -} // namespace mindspore::graphkernel::expanders diff --git a/mindspore-lite/tools/graph_kernel/converter/expanders/fused_batch_norm.cc b/mindspore-lite/tools/graph_kernel/converter/expanders/fused_batch_norm.cc deleted file mode 100644 index 8da30bf30ccc792497b0f4266553b0744991131e..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/expanders/fused_batch_norm.cc +++ /dev/null @@ -1,51 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include "backend/common/graph_kernel/expanders/op_desc_registry.h" - -namespace mindspore::graphkernel::expanders { -class FusedBatchNorm : public OpDesc { - public: - FusedBatchNorm() { - std::initializer_list attrs{"epsilon"}; - (void)validators_.emplace_back(std::make_unique(attrs)); - } - ~FusedBatchNorm() = default; - - protected: - NodePtrList Expand(const NodePtrList &inputs) override { - // only work for infer - const size_t input_index = 0; - const size_t scale_index = 1; - const size_t bias_index = 2; - const size_t mean_index = 3; - const size_t var_index = 4; - const auto &input = inputs[input_index]; - const auto &scale = inputs[scale_index]; - const auto &bias = inputs[bias_index]; - const auto &mean = inputs[mean_index]; - const auto &var = inputs[var_index]; - auto eps = gb.Tensor(GetValue(attrs_["epsilon"]), input->type); - auto fuse_scale = gb.Div(scale, gb.Sqrt(gb.Add(var, eps))); - auto fuse_offset = gb.Sub(bias, gb.Mul(fuse_scale, mean)); - auto result = gb.Add(gb.Mul(input, fuse_scale), fuse_offset); - return {result}; - } -}; -EXPANDER_OP_DESC_REGISTER("FusedBatchNorm", FusedBatchNorm); -} // namespace mindspore::graphkernel::expanders diff --git a/mindspore-lite/tools/graph_kernel/converter/expanders/instance_norm.cc b/mindspore-lite/tools/graph_kernel/converter/expanders/instance_norm.cc deleted file mode 100644 index 1a9e74137c621ff4bfa40422bd58d6c62e166144..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/expanders/instance_norm.cc +++ /dev/null @@ -1,85 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include "backend/common/graph_kernel/expanders/op_desc_registry.h" -#include "common/common_utils.h" - -namespace mindspore::graphkernel::expanders { -constexpr size_t kInputIdx = 0; -constexpr size_t kScaleIdx = 1; -constexpr size_t kBiasIdx = 2; -constexpr size_t kInputMinRank = 2; -constexpr int64_t kDimHeight = -2; -constexpr int64_t kDimWidth = -1; -class InstanceNorm : public OpDesc { - public: - InstanceNorm() { - std::initializer_list attrs{"epsilon"}; - (void)validators_.emplace_back(std::make_unique(attrs)); - } - ~InstanceNorm() = default; - - protected: - bool CheckInputs() override { - const auto &var = inputs_info_[0]; - if (var.shape.size() < kInputMinRank) { - MS_LOG(INFO) << "In InstanceNorm, input[0]'s rank must be at least 2, but got " << var.shape.size(); - return false; - } - if (var.format != kOpFormat_NCHW) { - MS_LOG(INFO) << "In InstanceNorm, input[0]'s format must be NCHW, but got " << var.format; - return false; - } - return true; - } - - NodePtrList Expand(const NodePtrList &inputs) override { - const auto &input = inputs[kInputIdx]; - const auto &scale = inputs[kScaleIdx]; - const auto &bias = inputs[kBiasIdx]; - - int64_t rank = SizeToLong(input->shape.size()); - std::vector reduce_axis{kDimHeight + rank, kDimWidth + rank}; - int64_t mean_cof_v = input->shape[LongToSize(kDimHeight + rank)] * input->shape[LongToSize(kDimWidth + rank)]; - - // const - auto num = gb.Tensor(mean_cof_v, input->type); - auto eps = gb.Tensor(GetValue(attrs_["epsilon"]), input->type); - - // Calculate mean - auto sum_res = gb.ReduceSum(input, reduce_axis, true); - auto mean = gb.Div(sum_res, num); - - // Calculate variance - auto variance_sub = gb.Sub(input, mean); - auto variance_mul = gb.Mul(variance_sub, variance_sub); - auto variance_red = gb.ReduceSum(variance_mul, reduce_axis, true); - auto variance = gb.Div(variance_red, num); - - // Calculate normalize - auto normalize_sqrt = gb.Sqrt(gb.Add(variance, eps)); - auto normalize = gb.Div(variance_sub, normalize_sqrt); - - // Calculate scale and translate - auto scale_mul = gb.Mul(normalize, scale); - auto result = gb.Add(scale_mul, bias); - return {result}; - } -}; -EXPANDER_OP_DESC_REGISTER("InstanceNorm", InstanceNorm); -} // namespace mindspore::graphkernel::expanders diff --git a/mindspore-lite/tools/graph_kernel/converter/expanders/layer_norm_fusion.cc b/mindspore-lite/tools/graph_kernel/converter/expanders/layer_norm_fusion.cc deleted file mode 100644 index 1eee478b1777e8f57b5269c99f488b2b54b4216c..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/expanders/layer_norm_fusion.cc +++ /dev/null @@ -1,102 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include "backend/common/graph_kernel/expanders/op_desc_registry.h" - -namespace mindspore::graphkernel::expanders { -class LayerNormFusion : public OpDesc { - public: - LayerNormFusion() { - std::initializer_list attrs{"begin_norm_axis", "begin_params_axis", "epsilon"}; - (void)validators_.emplace_back(std::make_unique(attrs)); - } - ~LayerNormFusion() = default; - - protected: - void Init() override { - begin_norm_axis_ = GetAxisList(attrs_["begin_norm_axis"]); - begin_params_axis_ = GetAxisList(attrs_["begin_params_axis"]); - } - - bool CheckInputs() override { - if (begin_norm_axis_.size() != 1) { - MS_LOG(INFO) << "begin_norm_axis should only contain 1 axis, but got " << begin_norm_axis_.size(); - return false; - } - - if (begin_params_axis_.size() != 1) { - MS_LOG(INFO) << "begin_params_axis should only contain 1 axis, but got " << begin_params_axis_.size(); - return false; - } - - if (begin_params_axis_[0] != begin_norm_axis_[0]) { - MS_LOG(INFO) << "Expander doesn't support begin_norm_axis and begin_params_axis with different value"; - return false; - } - return true; - } - - NodePtrList Expand(const NodePtrList &inputs) override { - // only work for infer - const size_t input_index = 0; - const size_t gamma_index = 1; - const size_t beta_index = 2; - const auto &input = inputs[input_index]; - const auto &gamma = inputs[gamma_index]; - const auto &beta = inputs[beta_index]; - - int64_t real_axis = - begin_norm_axis_[0] < 0 ? SizeToLong(input->shape.size()) + begin_norm_axis_[0] : begin_norm_axis_[0]; - std::vector reduce_axis; - int64_t mean_cof_v = 1; - for (size_t i = 0; i < input->shape.size(); i++) { - if (auto axis_int64 = SizeToLong(i); axis_int64 >= real_axis) { - reduce_axis.push_back(axis_int64); - mean_cof_v *= input->shape[i]; - } - } - - // const - auto num = gb.Tensor(mean_cof_v, input->type); - auto eps = gb.Tensor(GetValue(attrs_["epsilon"]), input->type); - - // Calculate mean - auto sum_res = gb.ReduceSum(input, reduce_axis, true); - auto mean = gb.Div(sum_res, num); - - // Calculate variance - auto variance_sub = gb.Sub(input, mean); - auto variance_mul = gb.Mul(variance_sub, variance_sub); - auto variance_red = gb.ReduceSum(variance_mul, reduce_axis, true); - auto variance = gb.Div(variance_red, num); - - // Calculate normalize - auto normalize_sqrt = gb.Sqrt(gb.Add(variance, eps)); - auto normalize = gb.Div(variance_sub, normalize_sqrt); - - // Calculate scale and translate - auto scale_mul = gb.Mul(normalize, gamma); - auto result = gb.Add(scale_mul, beta); - return {result}; - } - - std::vector begin_norm_axis_; - std::vector begin_params_axis_; -}; -EXPANDER_OP_DESC_REGISTER("LayerNormFusion", LayerNormFusion); -} // namespace mindspore::graphkernel::expanders diff --git a/mindspore-lite/tools/graph_kernel/converter/expanders/matmul_fusion.cc b/mindspore-lite/tools/graph_kernel/converter/expanders/matmul_fusion.cc deleted file mode 100644 index b5af5eff4415965ae079840eca311dbf47fd294c..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/expanders/matmul_fusion.cc +++ /dev/null @@ -1,65 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include "backend/common/graph_kernel/expanders/op_desc_registry.h" -#include "tools/graph_kernel/converter/expanders/activation.h" -#include "mindapi/base/types.h" -#include "ir/dtype.h" - -namespace mindspore::graphkernel::expanders { -class MatMulFusion : public OpDesc { - public: - MatMulFusion() { - std::set activation_types = {ActivationType::NO_ACTIVATION, ActivationType::RELU, ActivationType::SIGMOID}; - (void)validators_.emplace_back(std::make_unique(activation_types)); - } - ~MatMulFusion() = default; - - protected: - NodePtrList Expand(const NodePtrList &inputs) override { - const size_t has_bias_input_size = 3; - auto a = inputs[0]; - auto b = inputs[1]; - auto bias = (inputs.size() == has_bias_input_size) ? inputs[2] : nullptr; - auto transpose_a = (attrs_.count("transpose_a") != 0) ? attrs_["transpose_a"] : MakeValue(false); - auto transpose_b = (attrs_.count("transpose_b") != 0) ? attrs_["transpose_b"] : MakeValue(false); - auto pack_b = (attrs_.count("pack_b") != 0) ? attrs_["pack_b"] : MakeValue(false); - if (b->shape.size() < a->shape.size()) { - ShapeVector new_shape(a->shape.size() - b->shape.size(), 1); - (void)new_shape.insert(new_shape.end(), b->shape.cbegin(), b->shape.cend()); - b = gb.Reshape(b, new_shape); - } else if (a->shape.size() < b->shape.size()) { - ShapeVector new_shape(b->shape.size() - a->shape.size(), 1); - (void)new_shape.insert(new_shape.end(), a->shape.cbegin(), a->shape.cend()); - a = gb.Reshape(a, new_shape); - } - auto matmul = - gb.Emit("MatMul", {a, b}, {{"transpose_a", transpose_a}, {"transpose_b", transpose_b}, {"pack_b", pack_b}}); - if (bias != nullptr) { - matmul = gb.Add(matmul, bias); - } - if (attrs_.find("activation_type") != attrs_.end()) { - auto act_type = GetValue(attrs_["activation_type"]); - return {GetActivationExpander(gb, {matmul}, act_type)}; - } else { - return {matmul}; - } - } -}; -EXPANDER_OP_DESC_REGISTER("MatMulFusion", MatMulFusion); -} // namespace mindspore::graphkernel::expanders diff --git a/mindspore-lite/tools/graph_kernel/converter/expanders/mul_fusion.cc b/mindspore-lite/tools/graph_kernel/converter/expanders/mul_fusion.cc deleted file mode 100644 index 6f2ca1034cee4473b1e29312142379aeca58ba88..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/expanders/mul_fusion.cc +++ /dev/null @@ -1,39 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include "backend/common/graph_kernel/expanders/op_desc_registry.h" -#include "tools/graph_kernel/converter/expanders/activation.h" -#include "mindapi/base/types.h" -#include "ir/dtype.h" - -namespace mindspore::graphkernel::expanders { -class MulFusion : public OpDesc { - public: - MulFusion() { (void)validators_.emplace_back(std::make_unique(ActivationType::NO_ACTIVATION)); } - ~MulFusion() = default; - - protected: - NodePtrList Expand(const NodePtrList &inputs) override { - const auto &input_x = inputs[0]; - const auto &input_y = inputs[1]; - auto result = gb.Mul(input_x, input_y); - return {result}; - } -}; -EXPANDER_OP_DESC_REGISTER("MulFusion", MulFusion); -} // namespace mindspore::graphkernel::expanders diff --git a/mindspore-lite/tools/graph_kernel/converter/expanders/pool_fusion.cc b/mindspore-lite/tools/graph_kernel/converter/expanders/pool_fusion.cc deleted file mode 100644 index e3e93ad61b5ee60c8d48be03a75a84a89890ec07..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/expanders/pool_fusion.cc +++ /dev/null @@ -1,171 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include - -#include "backend/common/graph_kernel/expanders/op_desc_registry.h" -#include "tools/graph_kernel/converter/expanders/activation.h" -#include "mindapi/base/types.h" -#include "ir/dtype.h" - -namespace mindspore::graphkernel::expanders { -NodePtr GetPadResult(const inner::GraphBuilder &gb, const NodePtr &input_node, inner::DAttrs attrs, - const inner::DShape &input_shape, bool do_transform) { - NodePtr input_pad; - int64_t pad_mode = GetValue(attrs["pad_mode"]); - if (pad_mode == PadMode::PAD) { - std::vector pads = GetValue>(attrs["pad"]); - auto pad_n = pads[0]; - auto pad_h = pads[1]; - auto pad_w = pads[2]; - auto pad_c = pads[3]; - if (pad_n == 0 && pad_h == 0 && pad_w == 0 && pad_c == 0) { - input_pad = input_node; - } else { - ShapeVector head_pad; - ShapeVector tail_pad; - if (do_transform) { - head_pad = {0, 0, pad_n, pad_w, 0}; - tail_pad = {0, 0, pad_h, pad_c, 0}; - } else { - head_pad = {0, pad_n, pad_w, 0}; - tail_pad = {0, pad_h, pad_c, 0}; - } - input_pad = - gb.Emit("PadAkg", {input_node}, - {{"head", MakeValue(head_pad)}, {"tail", MakeValue(tail_pad)}, {"pad_val", MakeValue((int64_t)0)}}); - } - } else if (pad_mode == PadMode::SAME) { - auto input_h = input_shape[1]; - auto input_w = input_shape[2]; - auto stride_h = GetValue>(attrs["strides"])[0]; - auto stride_w = GetValue>(attrs["strides"])[1]; - auto kernel_h = GetValue>(attrs["kernel_size"])[0]; - auto kernel_w = GetValue>(attrs["kernel_size"])[1]; - int64_t pad_h; - int64_t pad_w; - if (input_h % stride_h == 0) { - pad_h = std::max(kernel_h - stride_h, int64_t(0)); - } else { - pad_h = std::max(kernel_h - (input_h % stride_h), int64_t(0)); - } - if (input_w % stride_w == 0) { - pad_w = std::max(kernel_w - stride_w, int64_t(0)); - } else { - pad_w = std::max(kernel_w - (input_w % stride_w), int64_t(0)); - } - if (pad_h == 0 && pad_w == 0) { - input_pad = input_node; - } else { - ShapeVector head_pad; - ShapeVector tail_pad; - auto pad_top = pad_h / 2; - auto pad_bottom = pad_h - pad_top; - auto pad_left = pad_w / 2; - auto pad_right = pad_w - pad_left; - if (do_transform) { - head_pad = {0, 0, pad_top, pad_left, 0}; - tail_pad = {0, 0, pad_bottom, pad_right, 0}; - } else { - head_pad = {0, pad_top, pad_left, 0}; - tail_pad = {0, pad_bottom, pad_right, 0}; - } - input_pad = - gb.Emit("PadAkg", {input_node}, - {{"head", MakeValue(head_pad)}, {"tail", MakeValue(tail_pad)}, {"pad_val", MakeValue((int64_t)0)}}); - } - } else { - input_pad = input_node; - } - return input_pad; -} - -class PoolFusion : public OpDesc { - public: - explicit PoolFusion(const std::string &pool_type) : pool_type_(pool_type) {} - ~PoolFusion() = default; - - protected: - NodePtrList Expand(const NodePtrList &inputs) override { - const auto &input = inputs[0]; - auto input_shape = input->shape; - auto input_format = input->format; - bool do_transform = attrs_.count("layout_axis") != 0; - NodePtr input_tp; - std::string layout_format; - if (do_transform) { - auto layout_axis = GetValue(attrs_["layout_axis"]); - layout_format = "NCHW" + std::to_string(layout_axis) + "c"; - input_tp = gb.Emit("LayoutTransform", {input}, - {{"src_format", MakeValue(input_format)}, {"dst_format", MakeValue(layout_format)}}); - } else { - input_tp = input; - layout_format = "NHWC"; - } - - NodePtr input_pad = GetPadResult(gb, input_tp, attrs_, input_shape, do_transform); - - inner::DAttrs attr_list = { - {"pool_type", MakeValue(pool_type_)}, {"data_layout", MakeValue(layout_format)}, {"strides", attrs_["strides"]}}; - if (attrs_["kernel_size"] != nullptr) { - attr_list["kernel_size"] = attrs_["kernel_size"]; - } - if (attrs_.count("global") != 0) { - attr_list["global"] = attrs_["global"]; - } else { - attr_list["global"] = MakeValue(false); - } - if (attrs_.count("round_mode") != 0) { - attr_list["round_mode"] = attrs_["round_mode"]; - } else { - attr_list["round_mode"] = MakeValue(0); - } - auto pool_res = gb.Emit("Pool2D", {input_pad}, attr_list); - NodePtr result; - if (do_transform) { - result = gb.Emit("LayoutTransform", {pool_res}, - {{"src_format", MakeValue(layout_format)}, {"dst_format", MakeValue(input_format)}}); - } else { - result = pool_res; - } - return {result}; - } - - private: - std::string pool_type_; -}; - -class MaxPoolFusion : public PoolFusion { - public: - MaxPoolFusion() : PoolFusion("max") { - (void)validators_.emplace_back(std::make_unique(ActivationType::NO_ACTIVATION)); - } - ~MaxPoolFusion() = default; -}; -EXPANDER_OP_DESC_REGISTER("MaxPoolFusion", MaxPoolFusion); - -class AvgPoolFusion : public PoolFusion { - public: - AvgPoolFusion() : PoolFusion("avg") { - (void)validators_.emplace_back(std::make_unique(ActivationType::NO_ACTIVATION)); - } - ~AvgPoolFusion() = default; -}; -EXPANDER_OP_DESC_REGISTER("AvgPoolFusion", AvgPoolFusion); -} // namespace mindspore::graphkernel::expanders diff --git a/mindspore-lite/tools/graph_kernel/converter/expanders/reduce_fusion.cc b/mindspore-lite/tools/graph_kernel/converter/expanders/reduce_fusion.cc deleted file mode 100644 index dbdf7bb2d01c2dd290a282eb9dfd7f83c56b2c24..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/expanders/reduce_fusion.cc +++ /dev/null @@ -1,107 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include "backend/common/graph_kernel/expanders/op_desc_registry.h" -#include "mindapi/base/types.h" -#include "ir/dtype.h" - -namespace mindspore::graphkernel::expanders { -class CheckReduceMode : public Validator { - public: - bool Check(const OpDesc &e) override { - auto iter = e.Attrs().find("mode"); - if (iter == e.Attrs().end()) { - MS_LOG(INFO) << "The mode is not found in attrs."; - return false; - } - auto mode = GetValue(iter->second); - if (mode != ReduceMode::Reduce_Sum && mode != ReduceMode::Reduce_Mean) { - MS_LOG(INFO) << "Reduce mode " << mode << " not supported yet!"; - return false; - } - return true; - } -}; - -class ReduceFusion : public OpDesc { - public: - ReduceFusion() { - std::initializer_list attrs{"keep_dims", "coeff"}; - (void)validators_.emplace_back(std::make_unique(attrs)); - (void)validators_.emplace_back(std::make_unique()); - } - ~ReduceFusion() = default; - - protected: - bool CheckInputs() override { - const auto &x = inputs_info_[0]; - auto x_shape = x.shape; - auto mode = GetValue(attrs_["mode"]); - if (mode == ReduceMode::Reduce_Mean) { - if (attrs_.count("axis") == 0) { - MS_LOG(INFO) << "Axis is dynamic, and the mode of ReduceFusion is Reduce_Mean, in this case we can not expand " - "ReduceFusion."; - return false; - } else if (x_shape.empty() || IsDynamicRank(x_shape)) { - MS_LOG(INFO) << "Skip empty shape or dynamic rank, shape is: " << x_shape; - return false; - } else { - auto axis = GetAxisList(attrs_["axis"]); - bool is_valid = std::all_of(axis.begin(), axis.end(), [&x_shape](int idx) { return x_shape[idx] > 0; }); - if (!is_valid) { - MS_LOG(INFO) << "Some dimension size needed in reducemean is not available"; - return false; - } - } - } - return true; - } - - NodePtrList Expand(const NodePtrList &inputs) override { - const auto &input_x = inputs[0]; - auto keep_dims = GetValue(attrs_["keep_dims"]); - auto mode = GetValue(attrs_["mode"]); - if (mode == ReduceMode::Reduce_Mean) { - auto axis = GetAxisList(attrs_["axis"]); - auto sum_res = gb.ReduceSum(input_x, axis, keep_dims); - auto coeff = gb.Tensor(GetValue(attrs_["coeff"]), input_x->type); - auto result = gb.Mul(sum_res, coeff); - int64_t reduce_size = std::accumulate(axis.begin(), axis.end(), 1LL, [input_x](int64_t a, int64_t idx) { - return a * input_x->shape[LongToSize(idx)]; - }); - auto reduce_size_value = gb.Tensor(reduce_size, input_x->type); - auto mean_res = gb.Div(result, reduce_size_value); - return {mean_res}; - } else { - NodePtr sum_res = nullptr; - if (attrs_.count("axis") == 0) { - auto &axis = inputs[1]; - sum_res = gb.Emit("ReduceSum", {input_x, axis}, {{"keep_dims", MakeValue(keep_dims)}}); - } else { - auto axis = GetAxisList(attrs_["axis"]); - sum_res = gb.ReduceSum(input_x, axis, keep_dims); - } - auto coeff = gb.Tensor(GetValue(attrs_["coeff"]), input_x->type); - auto result = gb.Mul(sum_res, coeff); - return {result}; - } - } -}; -EXPANDER_OP_DESC_REGISTER("ReduceFusion", ReduceFusion); -} // namespace mindspore::graphkernel::expanders diff --git a/mindspore-lite/tools/graph_kernel/converter/expanders/reshape.cc b/mindspore-lite/tools/graph_kernel/converter/expanders/reshape.cc deleted file mode 100644 index 4f20e1bafafc23e5454783f43c65d14f89cd8aa2..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/expanders/reshape.cc +++ /dev/null @@ -1,74 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include "backend/common/graph_kernel/expanders/op_desc_registry.h" -#include "tools/graph_kernel/converter/expanders/activation.h" -#include "mindapi/base/types.h" -#include "ir/dtype.h" - -namespace mindspore::graphkernel::expanders { -class Reshape : public OpDesc { - public: - Reshape() { - std::initializer_list attrs{"shape"}; - (void)validators_.emplace_back(std::make_unique(attrs)); - } - ~Reshape() = default; - - protected: - NodePtrList Expand(const NodePtrList &inputs) override { - const auto &input_x = inputs[0]; - auto shp_ptr = attrs_["shape"]; - ShapeVector shape; - if (shp_ptr->isa()) { - auto value = std::static_pointer_cast(shp_ptr); - if (value->data_type_c() == TypeId::kNumberTypeInt32) { - int32_t *data = static_cast(value->data_c()); - for (size_t elem = 0; elem < value->DataSize(); elem++) { - (void)shape.emplace_back(IntToLong(*(data + elem))); - } - } else if (value->data_type_c() == TypeId::kNumberTypeInt64) { - int64_t *data = static_cast(value->data_c()); - for (size_t elem = 0; elem < value->DataSize(); elem++) { - (void)shape.emplace_back(*(data + elem)); - } - } else { - MS_LOG(INFO) << "Type of reshape's shape tensor is neither int64_t nor int32_t. Expand failed"; - return {}; - } - } else if (shp_ptr->isa()) { - shape = GetValue(shp_ptr); - } else { - MS_LOG(INFO) << "Reshape's attr shape is neither Tensor nor ValueTuple. Expand failed"; - return {}; - } - for (size_t i = 0; i < shape.size(); i++) { - if (shape[i] == 0) { - if (input_x->shape.size() <= i) { - MS_LOG(INFO) << "Reshape's attr shape[" << i << "] is 0, but input's rank is " << input_x->shape.size(); - return {}; - } - shape[i] = input_x->shape[i]; - } - } - auto result = gb.Reshape(input_x, shape); - return {result}; - } -}; -EXPANDER_OP_DESC_REGISTER("Reshape", Reshape); -} // namespace mindspore::graphkernel::expanders diff --git a/mindspore-lite/tools/graph_kernel/converter/expanders/sub_fusion.cc b/mindspore-lite/tools/graph_kernel/converter/expanders/sub_fusion.cc deleted file mode 100644 index 2bde01f8203bb231c9ed21b0df1c2ae6231a3621..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/expanders/sub_fusion.cc +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include "backend/common/graph_kernel/expanders/op_desc_registry.h" -#include "tools/graph_kernel/converter/expanders/activation.h" -#include "mindapi/base/types.h" -#include "ir/dtype.h" - -namespace mindspore::graphkernel::expanders { -class SubFusion : public OpDesc { - public: - SubFusion() { (void)validators_.emplace_back(std::make_unique(ActivationType::NO_ACTIVATION)); } - ~SubFusion() = default; - - protected: - NodePtrList Expand(const NodePtrList &inputs) override { - const auto &input_x = inputs[0]; - const auto &input_y = inputs[1]; - auto result = gb.Sub(input_x, input_y); - return {result}; - } -}; -EXPANDER_OP_DESC_REGISTER("SubFusion", SubFusion); -} // namespace mindspore::graphkernel::expanders diff --git a/mindspore-lite/tools/graph_kernel/converter/format_recognition.cc b/mindspore-lite/tools/graph_kernel/converter/format_recognition.cc deleted file mode 100644 index 7f7ff1796b6c5bca1156c7f106a212e50f63ec29..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/format_recognition.cc +++ /dev/null @@ -1,171 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "tools/graph_kernel/converter/format_recognition.h" - -#include -#include -#include -#include -#include "backend/common/graph_kernel/core/graph_kernel_utils.h" -#include "include/common/utils/utils.h" -#include "mindspore/ops/op_def/array_ops.h" -#include "mindspore/ops/op_def/sequence_ops.h" -#include "utils/anf_utils.h" -#include "common/kernel_build_info.h" -#include "include/backend/kernel_info.h" -#include "tools/graph_kernel/common/utils.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" - -namespace mindspore::graphkernel { -namespace { -using KernelWithIndex = std::pair; - -std::pair GetLiteFormat(const CNodePtr &cnode, size_t idx = 0) { - auto prim = GetCNodePrimitive(cnode); - if (prim == nullptr || !prim->HasAttr("format")) { - return std::make_pair(kOpFormat_DEFAULT, false); - } - auto format_attr = prim->GetAttr("format"); - MS_EXCEPTION_IF_NULL(format_attr); - auto format = GetValue(format_attr); - if (format == 1) { - return std::make_pair(kOpFormat_NHWC, false); - } else { - return std::make_pair(kOpFormat_DEFAULT, false); - } -} - -std::pair GetTransposeFormat(const CNodePtr &cnode) { - constexpr size_t perm_idx = 2; - std::vector perm_list; - if (cnode->input(perm_idx)->isa()) { - auto perm_para = cnode->input(perm_idx)->cast(); - MS_EXCEPTION_IF_NULL(perm_para); - if (!perm_para->has_default()) { - return GetLiteFormat(cnode); - } - auto perm_tensor = perm_para->default_param()->cast(); - auto perm = static_cast(perm_tensor->data_ptr()->data()); - std::transform(perm, perm + perm_tensor->shape()[0], std::back_inserter(perm_list), IntToLong); - } else { - auto perm_value = cnode->input(perm_idx)->cast(); - MS_EXCEPTION_IF_NULL(perm_value); - perm_list = GetValue>(perm_value->value()); - } - std::vector nh2nc_perm = {0, 3, 1, 2}; - std::vector nc2nh_perm = {0, 2, 3, 1}; - if (perm_list == nh2nc_perm) { - return std::make_pair(kOpFormat_NCHW, true); - } else if (perm_list == nc2nh_perm) { - return std::make_pair(kOpFormat_DEFAULT, true); - } else { - return GetLiteFormat(cnode); - } -} - -std::pair ExtractOutputFormat(const CNodePtr &cnode) { - if (IsPrimitiveCNode(cnode, prim::kPrimTranspose)) { - return GetTransposeFormat(cnode); - } - return GetLiteFormat(cnode); -} - -void SetOutputsFormat(const CNodePtr &cnode) { - auto extract_res = ExtractOutputFormat(cnode); - auto output_format = extract_res.first; - if (!extract_res.second) { - // fix output format when it can not be determined by transpose - for (size_t i = 1; i < cnode->size(); i++) { - if (cnode->input(i)->isa()) { - auto kernel_with_index = AnfUtils::VisitKernel(cnode->input(i), 0); - auto prev_cnode = kernel_with_index.first; - auto kernel_build_info = GetKernelInfo(prev_cnode); - if (prev_cnode != nullptr && kernel_build_info) { - if (kernel_build_info->GetOutputNum() < kernel_with_index.second) { - MS_LOG(EXCEPTION) << "cnode output num is wrong, required " << kernel_with_index.second - << ", but only have " << kernel_build_info->GetOutputNum() - << "outputs. Cnode is: " << cnode->fullname_with_scope(); - } - output_format = kernel_build_info->GetOutputFormat(kernel_with_index.second); - break; - } - } - } - } - std::vector outputs_format; - if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) { - outputs_format = std::vector(cnode->size() - 1, output_format); - } else { - outputs_format = std::vector(AnfUtils::GetOutputTensorNum(cnode), output_format); - } - SetKernelInfoWithFormatToAnfNode(cnode, outputs_format); -} - -void FixFormatsBeforeTranspose(const CNodePtr &cnode) { - auto current_cnode_kernel_build_info = GetKernelInfo(cnode); - if (current_cnode_kernel_build_info == nullptr) { - MS_LOG(INFO) << "The kernel build info of node " << cnode->fullname_with_scope() << " is nullptr."; - return; - } - if (current_cnode_kernel_build_info->GetOutputNum() == 0) { - MS_LOG(INFO) << "The outputs_format of node " << cnode->fullname_with_scope() << " is empty."; - return; - } - for (size_t i = 1; i < cnode->size(); i++) { - auto prev_node = cnode->input(i); - if (IsPrimitiveCNode(prev_node, prim::kPrimTranspose)) { - continue; - } - if (prev_node->isa()) { - auto prev_cnode = prev_node->cast(); - MS_EXCEPTION_IF_NULL(prev_cnode); - auto current_format = current_cnode_kernel_build_info->GetOutputFormat(0); - std::string prev_format = current_format; - if (ExtractOutputFormat(cnode).second) { - // input node need to fix format when current node is nhwc->nchw or nchw->nhwc - if (current_format == kOpFormat_DEFAULT) { - prev_format = kOpFormat_NHWC; - } else if (current_format == kOpFormat_NHWC) { - prev_format = kOpFormat_DEFAULT; - } - } - std::vector outputs_formats(AnfUtils::GetOutputTensorNum(prev_cnode), prev_format); - SetKernelInfoWithFormatToAnfNode(prev_cnode, outputs_formats); - } else if (prev_node->isa()) { - // save parameter's format in callback instance - inner::NodeBase nodebase = {{}, TypeId::kMetaTypeBegin, current_cnode_kernel_build_info->GetOutputFormat(0)}; - Callback::Instance()->SetBasicNodeKernelInfo(prev_node, {nodebase}); - } - } -} -} // namespace - -bool FormatRecognition::Run(const FuncGraphPtr &func_graph) { - auto todos = TopoSort(func_graph->output()); - for (auto &node : todos) { - if (node->isa()) { - SetOutputsFormat(node->cast()); - } - } - for (auto it = todos.rbegin(); it != todos.rend(); ++it) { - if ((*it)->isa()) { - FixFormatsBeforeTranspose((*it)->cast()); - } - } - return true; -} -} // namespace mindspore::graphkernel diff --git a/mindspore-lite/tools/graph_kernel/converter/format_recognition.h b/mindspore-lite/tools/graph_kernel/converter/format_recognition.h deleted file mode 100644 index faa12cb577e460820928d6e837dde8ba64ad3778..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/format_recognition.h +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_FORMAT_RECOGNITION_H_ -#define MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_FORMAT_RECOGNITION_H_ -#include "ir/func_graph.h" -#include "include/backend/optimizer/pass.h" - -namespace mindspore::graphkernel { -class FormatRecognition : public opt::Pass { - public: - FormatRecognition() : Pass("format_recognition") {} - ~FormatRecognition() override = default; - bool Run(const FuncGraphPtr &func_graph) override; -}; -} // namespace mindspore::graphkernel - -#endif // MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_FORMAT_RECOGNITION_H_ diff --git a/mindspore-lite/tools/graph_kernel/converter/graph_kernel_cluster_lite.cc b/mindspore-lite/tools/graph_kernel/converter/graph_kernel_cluster_lite.cc deleted file mode 100644 index 6905928ff239192b1f9b94f43a45b4bfec3ff017..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/graph_kernel_cluster_lite.cc +++ /dev/null @@ -1,175 +0,0 @@ -/** - * Copyright 2022-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tools/graph_kernel/converter/graph_kernel_cluster_lite.h" - -#include -#include - -#include "mindspore/ops/op_def/nn_optimizer_ops.h" -#include "mindspore/ops/op_def/math_ops.h" -#include "mindspore/ops/op_def/nn_ops.h" -#include "mindspore/ops/op_def/comparison_ops.h" -#include "mindspore/ops/op_def/array_ops.h" -#include "mindspore/ops/op_def/lite_ops.h" -#include "include/common/utils/anfalgo.h" -#include "backend/common/graph_kernel/core/graph_kernel_callback.h" -#include "backend/common/graph_kernel/core/graph_kernel_utils.h" -#include "backend/common/graph_kernel/graph_kernel_flags.h" -#include "utils/ms_context.h" -#include "utils/anf_utils.h" -#include "backend/common/graph_kernel/core/value_depend_op_utils.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_d.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_e.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_f.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_g.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_n.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_p.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" - -namespace mindspore::graphkernel { -std::vector GraphKernelClusterLite::GetClusterableOpList() { - std::vector clusterable_ops_with_level = { - {kAllTarget, OpLevel_0, prim::kPrimAbs}, - {kAllTarget, OpLevel_0, prim::kPrimAdd}, - {kAllTarget, OpLevel_0, prim::kPrimDiv}, - {kAllTarget, OpLevel_0, prim::kPrimRealDiv}, - {kAllTarget, OpLevel_0, prim::kPrimExp}, - {kAllTarget, OpLevel_0, prim::kPrimLog}, - {kAllTarget, OpLevel_0, prim::kPrimMaximum}, - {kAllTarget, OpLevel_0, prim::kPrimMinimum}, - {kAllTarget, OpLevel_0, prim::kPrimMul}, - {kAllTarget, OpLevel_0, prim::kPrimSqrt}, - {kAllTarget, OpLevel_0, prim::kPrimSub}, - {kAllTarget, OpLevel_0, prim::kPrimNeg}, - {kAllTarget, OpLevel_0, prim::kPrimPow}, - {kAllTarget, OpLevel_0, prim::kPrimRealDiv}, - {kAllTarget, OpLevel_0, prim::kPrimReciprocal}, - {kAllTarget, OpLevel_0, prim::kPrimRsqrt}, - {kAllTarget, OpLevel_0, prim::kPrimExpandDims}, - {kAllTarget, OpLevel_0, prim::kPrimSqueeze}, - {kAllTarget, OpLevel_0, prim::kPrimLeakyRelu}, - {kAllTarget, OpLevel_0, prim::kPrimSign}, - {kAllTarget, OpLevel_0, prim::kPrimMod}, - {kAllTarget, OpLevel_0, prim::kPrimReduceMax}, - {kAllTarget, OpLevel_0, prim::kPrimReduceMin}, - {kAllTarget, OpLevel_0, prim::kPrimReduceSum}, - // ascend device - {kAscendDevice, OpLevel_0, prim::kPrimMatMul}, - {kAscendDevice, OpLevel_0, prim::kPrimFastGeLU}, - {kAscendDevice, OpLevel_0, prim::kPrimTranspose}, - {kAscendDevice, OpLevel_0, prim::kPrimReshape}, - // cpu device - {kCPUDevice, OpLevel_0, prim::kPrimSin}, - {kCPUDevice, OpLevel_0, prim::kPrimTanh}, - {kCPUDevice, OpLevel_0, prim::kPrimCos}, - {kCPUDevice, OpLevel_0, prim::kPrimGreater}, - {kCPUDevice, OpLevel_0, prim::kPrimGreaterEqual}, - {kCPUDevice, OpLevel_0, prim::kPrimLess}, - {kCPUDevice, OpLevel_0, prim::kPrimLessEqual}, - {kCPUDevice, OpLevel_0, prim::kPrimLogicalAnd}, - {kCPUDevice, OpLevel_0, prim::kPrimLogicalOr}, - {kCPUDevice, OpLevel_0, prim::kPrimLogicalNot}, - }; - const auto &flags = GraphKernelFlags::GetInstance(); - return GkUtils::GetValidOps(clusterable_ops_with_level, flags.fusion_ops_level, flags.enable_cluster_ops_only, - flags.enable_cluster_ops, flags.disable_cluster_ops); -} - -bool CanCluster(const CNodePtr &cnode, const std::string &node_name) { - constexpr int64_t byte_align = 32; - auto cb = Callback::Instance(); - MS_EXCEPTION_IF_NULL(cb); - static const std::string target = cb->GetTargetFromContext(true); - if (target.find("Ascend910B") == std::string::npos) { - return true; - } - auto type_id = cb->GetOutputInferType(cnode, 0); - if (type_id != kNumberTypeFloat16 && type_id != kNumberTypeFloat32) { - return false; - } - if (node_name == "Mul" || node_name == "Add") { - for (size_t i = 0; i < cnode->size() - 1; i++) { - auto shape = cb->GetInputShape(cnode, i); - auto shape_size = shape.size(); - if (!shape.empty() && shape[shape_size - 1] > byte_align && shape[shape_size - 1] % byte_align != 0) { - return false; - } - } - } - return true; -} - -bool GraphKernelClusterLite::IsClusterableOp(const AnfNodePtr &node) { - if (GkUtils::UseAkgCceLib(node)) { - // do not cluster any other node into akg cce lib subgraph. - return false; - } - if (AnfUtils::IsGraphKernel(node)) { - return true; - } - if (GkUtils::IsKeepBasicNode(node)) { - return false; - } - if (!GraphKernelFlags::GetInstance().enable_dynamic_shape_fusion) { - if (common::AnfAlgo::IsDynamicShape(node)) { - return false; - } - } - bool node_in_oplist = std::any_of(op_list_.begin(), op_list_.end(), - [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }); - if (!node_in_oplist) { - return false; - } - if (!ValueDependOpUtils::IsConstInput(node)) { - return false; - } - auto cnode = node->cast(); - if (cnode == nullptr) { - return false; - } - auto cb = Callback::Instance(); - MS_EXCEPTION_IF_NULL(cb); - if (device_ == "Ascend") { - auto type_id = cb->GetOutputInferType(node, 0); - if (type_id == kNumberTypeInt64) { - return false; - } - auto node_name = AnfUtils::GetCNodeName(node); - if (node_name.find("MatMul") != std::string::npos && type_id != kNumberTypeFloat16 && - type_id != kNumberTypeFloat32) { - return false; - } - if (CanCluster(cnode, node_name) == false) { - return false; - } - } - // check if the node has dynamic shape - for (size_t i = 0; i < cnode->size() - 1; i++) { - if (!cnode->input(i + 1)->isa() && !cnode->input(i + 1)->isa() && - cb->GetInputShape(cnode, i).size() == 0) { - return false; - } - } - return true; -} -} // namespace mindspore::graphkernel diff --git a/mindspore-lite/tools/graph_kernel/converter/graph_kernel_cluster_lite.h b/mindspore-lite/tools/graph_kernel/converter/graph_kernel_cluster_lite.h deleted file mode 100644 index 76faaa3b6f4b9fcabad9582a0c55e055c563a9c9..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/graph_kernel_cluster_lite.h +++ /dev/null @@ -1,45 +0,0 @@ -/** - * Copyright 2022-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_GRAPH_KERNEL_CLUSTER_LITE_H_ -#define MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_GRAPH_KERNEL_CLUSTER_LITE_H_ -#include -#include -#include - -#include "backend/common/graph_kernel/core/graph_kernel_cluster.h" -#include "backend/common/graph_kernel/core/graph_kernel_callback.h" -#include "ir/func_graph.h" - -namespace mindspore::graphkernel { -class GraphKernelClusterLite : public GraphKernelCluster { - public: - GraphKernelClusterLite() { - auto cb = Callback::Instance(); - if (cb != nullptr) { - device_ = cb->GetTargetFromContext(); - } - } - ~GraphKernelClusterLite() override = default; - - protected: - std::vector GetClusterableOpList() override; - bool IsClusterableOp(const AnfNodePtr &node) override; - - private: - std::string device_; -}; -} // namespace mindspore::graphkernel -#endif // MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_GRAPH_KERNEL_CLUSTER_LITE_H_ diff --git a/mindspore-lite/tools/graph_kernel/converter/graph_kernel_expander_lite.cc b/mindspore-lite/tools/graph_kernel/converter/graph_kernel_expander_lite.cc deleted file mode 100644 index 0d0c34ff01b6e70d2cd96e7b7ab1d82b47b858fe..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/graph_kernel_expander_lite.cc +++ /dev/null @@ -1,288 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tools/graph_kernel/converter/graph_kernel_expander_lite.h" - -#include "mindspore/ops/op_def/conv_pool_ops.h" -#include "mindspore/ops/op_def/nn_ops.h" -#include "mindspore/ops/op_def/math_ops.h" -#include "mindspore/ops/op_def/lite_ops.h" -#include "mindspore/ops/op_def/array_ops.h" -#include "mindspore/ops/op_def/framework_ops.h" -#include "mindspore/ops/op_def/nn_optimizer_ops.h" -#include "backend/common/graph_kernel/model/node.h" -#include "backend/common/graph_kernel/model/op_node.h" -#include "backend/common/graph_kernel/core/graph_kernel_callback.h" -#include "backend/common/graph_kernel/core/graph_kernel_utils.h" -#include "backend/common/graph_kernel/core/graph_builder.h" -#include "backend/common/graph_kernel/graph_kernel_flags.h" -#include "utils/anf_utils.h" -#include "tools/graph_kernel/converter/basic_op_infer_shape.h" -#include "utils/ms_context.h" -#include "tools/graph_kernel/converter/preprocess_weight.h" -#include "tools/graph_kernel/common/utils.h" -#include "utils/check_convert_utils.h" -#include "common/kernel_build_info.h" -#include "include/backend/kernel_info.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_d.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_e.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_f.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_g.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_i.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_l.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_s.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_u.h" - -namespace mindspore::graphkernel { -AnfNodePtr FixFormatDeco::Run(const AnfNodePtr &node) { - auto cnode = QuickCloneCNode(node); - std::vector format = GetFixedFormat(node); - auto current_kernel_build_info = GetKernelInfo(node); - if (current_kernel_build_info == nullptr) { - MS_LOG(ERROR) << "Kernel info from " << cnode->fullname_with_scope() << "is nullptr"; - return nullptr; - } - auto ori_format = current_kernel_build_info->GetAllOutputFormats(); - current_kernel_build_info->SetOutputsFormat(format); - auto ret = decorated_->Run(cnode); - if (ret == nullptr) { - return nullptr; - } - auto fg = GetCNodeFuncGraph(ret); - for (auto sub_cnode : fg->GetOrderedCnodes()) { - SetAnfKernelInfoFormatFromAToB(node, sub_cnode, ori_format); - } - auto ret_node = ret->cast(); - SetAnfKernelInfoFormatFromAToB(node, ret_node, ori_format); - return ret; -} - -std::vector FixFormatDeco::GetFixedFormat(const AnfNodePtr &node) const { - auto cnode = node->cast(); - auto out_num = AnfUtils::GetOutputTensorNum(cnode); - std::vector format(out_num, kOpFormat_DEFAULT); - return format; -} - -std::vector UseInputFormatDeco::GetFixedFormat(const AnfNodePtr &node) const { - MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - std::vector format; - for (size_t i = 1; i < cnode->size(); i++) { - if (cnode->input(i)->isa()) { - auto kernel_with_index = AnfUtils::VisitKernel(cnode->input(i), 0); - auto input_cnode = kernel_with_index.first->cast(); - auto input_kernel_build_info = GetKernelInfo(input_cnode); - if (input_cnode != nullptr && input_kernel_build_info != nullptr) { - auto input_format = input_kernel_build_info->GetOutputFormat(kernel_with_index.second); - format.push_back(input_format); - break; - } - } - } - if (format.empty()) { - format.push_back(kOpFormat_DEFAULT); - } - return format; -} - -AnfNodePtr InferValueDeco::Run(const AnfNodePtr &node) { - // operators must infer value - std::unordered_set akg_exclude_nodes = {prim::kPrimGather->name(), prim::kPrimShape->name(), - prim::kPrimConcat->name(), prim::kPrimConstantOfShape->name(), - "StridedSliceOnnx"}; - auto cnode = QuickCloneCNode(node); - auto ret = decorated_->Run(cnode); - if (ret == nullptr) { - return nullptr; - } - auto fg = GetCNodeFuncGraph(ret); - AnfNodePtrList inputs = ret->cast()->inputs(); - inner::LiteGraphPtr litegraph = GkUtils::AnfGraph2LiteGraph(fg); - auto ops_list = litegraph->GetOrderedNodes(); - auto iter = ops_list.begin(); - while (iter != ops_list.end()) { - auto this_op = std::static_pointer_cast(*iter); - auto value = this_op->InferValue(this_op->inputs(), this_op->attrs()); - if (value != nullptr) { - (*iter)->ReplaceWith(value); - ops_list = litegraph->GetOrderedNodes(); - iter = ops_list.begin(); - } else { - ++iter; - } - } - auto &outputs = litegraph->GetOutputs(); - std::vector output_const; - for (auto &output : outputs) { - if (output->NodeType() == inner::NType::Tensor) { - auto value = std::static_pointer_cast(outputs[0])->data(); - auto valuenode = NewValueNode(value); - valuenode->set_abstract(value->ToAbstract()); - (void)output_const.emplace_back(valuenode); - } - } - if (outputs.size() == output_const.size()) { - return node->func_graph()->NewCNode(output_const); - } - bool cannot_expand = std::any_of(ops_list.begin(), ops_list.end(), [&akg_exclude_nodes](const inner::NodePtr &node) { - return akg_exclude_nodes.count(std::static_pointer_cast(node)->op()) > 0; - }); - if (cannot_expand) { - return nullptr; - } else { - auto new_fg = GkUtils::LiteGraph2AnfGraph(litegraph, Callback::Instance()); - (void)ConvertTensorToParameter(new_fg, &inputs); - AnfNodePtrList new_inputs = {NewValueNode(new_fg)}; - (void)new_inputs.insert(new_inputs.end(), inputs.cbegin() + 1, inputs.cend()); - return node->func_graph()->NewCNode(new_inputs); - } -} - -AnfNodePtr PoolLayoutDeco::Run(const AnfNodePtr &node) { - MS_CHECK_TRUE_MSG(node != nullptr, nullptr, "node is a nullptr."); - auto cnode = QuickCloneCNode(node); - MS_CHECK_TRUE_MSG(cnode != nullptr, nullptr, "cnode is a nullptr."); - auto prev_node = AnfUtils::VisitKernel(node->cast()->input(1), 0).first; - if (prev_node != nullptr) { - auto sub_graph = GetCNodeFuncGraph(prev_node); - if (sub_graph != nullptr) { - auto sub_nodes = TopoSort(sub_graph->get_return()); - for (auto sub_node : sub_nodes) { - if (IsPrimitiveCNode(sub_node, prim::kPrimConv2D)) { - AnfUtils::SetNodeAttr("layout_axis", GetCNodePrimitive(sub_node)->GetAttr("weight_coi"), cnode); - break; - } - } - } - } - return decorated_->Run(cnode); -} - -std::vector GraphKernelExpanderLite::ConvTuningExpanderOps() { - std::vector conv_tuning_ops = {prim::kPrimConv2DFusion, prim::kPrimAvgPoolFusion, - prim::kPrimMaxPoolFusion}; - return conv_tuning_ops; -} - -bool GraphKernelExpanderLite::DisableConvTuning() { - const auto &flags = GraphKernelFlags::GetInstance(); - auto flag_enable_only_list = flags.enable_expand_ops_only; - auto flag_disable_list = flags.disable_expand_ops; - return std::find(flag_disable_list.begin(), flag_disable_list.end(), prim::kPrimConv2DFusion->name()) != - flag_disable_list.end() || - (!flag_enable_only_list.empty() && - std::find(flag_enable_only_list.begin(), flag_enable_only_list.end(), prim::kPrimConv2DFusion->name()) == - flag_enable_only_list.end()) || - !flags.enable_lite_conv_tuning; -} - -std::vector GraphKernelExpanderLite::InitOpList() { - std::vector expand_ops_with_level = { - {kAllTarget, OpLevel_0, prim::kPrimGeLU}, - {kAllTarget, OpLevel_0, prim::kPrimSquare}, - {kAllTarget, OpLevel_0, prim::kPrimSquaredDifference}, - {kAllTarget, OpLevel_0, prim::kPrimTile}, - // ascend device - {kAscendDevice, OpLevel_0, prim::kPrimReduceMean}, - {kAscendDevice, OpLevel_0, prim::kPrimTile}, - {kAscendDevice, OpLevel_1, prim::kPrimLayerNorm}, - {kAscendDevice, OpLevel_0, prim::kPrimSigmoidCrossEntropyWithLogits}, - {kAscendDevice, OpLevel_0, prim::kPrimSquaredDifference}, - {kAscendDevice, OpLevel_0, prim::kPrimSquareSumAll}, - {kAscendDevice, OpLevel_1, prim::kPrimSoftsign}, - // cpu device - {kCPUDevice, OpLevel_1, prim::kPrimExpandDims}, - {kCPUDevice, OpLevel_1, prim::kPrimSqueeze}, - {kCPUDevice, OpLevel_1, prim::kPrimTranspose}, - {kCPUDevice, OpLevel_1, prim::kPrimReshape}, - {kCPUDevice, OpLevel_1, prim::kPrimGather}, - {kCPUDevice, OpLevel_1, prim::kPrimShape}, - {kCPUDevice, OpLevel_1, prim::kPrimConcat}, - {kCPUDevice, OpLevel_1, prim::kPrimFusedBatchNorm}, - {kCPUDevice, OpLevel_1, prim::kPrimSoftmax}, - {kCPUDevice, OpLevel_0, prim::kPrimAddFusion}, - {kCPUDevice, OpLevel_0, prim::kPrimMulFusion}, - {kCPUDevice, OpLevel_0, prim::kPrimSubFusion}, - {kCPUDevice, OpLevel_1, prim::kPrimReduceFusion}, - {kCPUDevice, OpLevel_0, prim::kPrimActivation}, - {kCPUDevice, OpLevel_0, prim::kPrimDivFusion}, - {kCPUDevice, OpLevel_0, prim::kPrimExpFusion}, - {kCPUDevice, OpLevel_1, prim::kPrimUnsqueeze}, - {kCPUDevice, OpLevel_1, prim::kPrimConstantOfShape}, - {kCPUDevice, OpLevel_1, prim::kPrimLayerNormFusion}, - {kCPUDevice, OpLevel_1, prim::kPrimInstanceNorm}, - {kCPUDevice, OpLevel_1, prim::kPrimStridedSlice}, - {kCPUDevice, OpLevel_1, prim::kPrimScaleFusion}}; - const auto &flags = GraphKernelFlags::GetInstance(); - auto valid_op_list = GkUtils::GetValidOps(expand_ops_with_level, flags.fusion_ops_level, flags.enable_expand_ops_only, - flags.enable_expand_ops, flags.disable_expand_ops); - return valid_op_list; -} - -bool GraphKernelExpanderLite::CanExpand(const CNodePtr &node) const { - if (!GraphKernelExpander::CanExpand(node)) { - return false; - } - auto cb = Callback::Instance(); - for (size_t i = 0; i < node->size() - 1; i++) { - if (!node->input(i + 1)->isa() && !node->input(i + 1)->isa() && - cb->GetInputShape(node, i).size() == 0) { - MS_LOG(INFO) << "cnode with no input info can not expand now, node is " << node->fullname_with_scope(); - return false; - } - } - return true; -} - -ExpanderPtr GraphKernelExpanderLite::InitExpander(const AnfNodePtr &node) { - auto expander = std::make_shared(Callback::Instance()); - ExpanderCreatorFuncList decos = {InferValueDeco::Creator}; - std::map creators = { - {prim::kPrimReduceFusion->name(), {DependValueDeco::GetCreator({1}), FixFormatDeco::Creator}}, - {prim::kPrimExpandDims->name(), {{DependValueDeco::GetCreator({1})}, FixFormatDeco::Creator}}, - {prim::kPrimUnsqueeze->name(), {FixFormatDeco::Creator}}, - {prim::kPrimSqueeze->name(), {FixFormatDeco::Creator}}, - {prim::kPrimShape->name(), {FixFormatDeco::Creator}}, - {prim::kPrimReshape->name(), {DependValueDeco::GetCreator({1}), FixFormatDeco::Creator}}, - {prim::kPrimConstantOfShape->name(), {DependValueDeco::GetCreator({0}), FixFormatDeco::Creator}}, - {prim::kPrimTranspose->name(), {DependValueDeco::GetCreator({1})}}, - {prim::kPrimGather->name(), {DependValueDeco::GetCreator({2}), FixFormatDeco::Creator}}, - {prim::kPrimReduceMean->name(), {DependValueDeco::GetCreator({1}), FixFormatDeco::Creator}}, - {prim::kPrimConcat->name(), {FixFormatDeco::Creator}}, - {prim::kPrimStridedSlice->name(), {FixFormatDeco::Creator}}, - {prim::kPrimMatMulFusion->name(), {MatmulPackB::Creator}}, - {prim::kPrimTile->name(), {{DependValueDeco::GetCreator({1})}, UseInputFormatDeco::Creator}}, - }; - auto iter = creators.find(GetCNodePrimitive(node)->name()); - if (iter != creators.end()) { - (void)decos.insert(decos.end(), iter->second.begin(), iter->second.end()); - } - return WrapExpander(expander, decos); -} - -void GraphKernelExpanderLite::PreProcessAllNode(const CNodePtr &node) { - if (Callback::Instance()->GetTargetFromContext() == "CPU" && !AnfUtils::IsGraphKernel(node)) { - BasicOpInferShape().Infer(node); - } -} -} // namespace mindspore::graphkernel diff --git a/mindspore-lite/tools/graph_kernel/converter/graph_kernel_expander_lite.h b/mindspore-lite/tools/graph_kernel/converter/graph_kernel_expander_lite.h deleted file mode 100644 index 053a2e110b9be514a02b6cd21fcd49c1eb79c8de..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/graph_kernel_expander_lite.h +++ /dev/null @@ -1,87 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_GRAPH_KERNEL_EXPANDER_LITE_H_ -#define MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_GRAPH_KERNEL_EXPANDER_LITE_H_ -#include -#include -#include - -#include "backend/common/graph_kernel/core/graph_kernel_expander.h" -#include "ir/func_graph.h" -#include "utils/hash_set.h" - -namespace mindspore::graphkernel { -class FixFormatDeco : public ExpanderDecorator { - public: - explicit FixFormatDeco(const ExpanderPtr &decorated) : ExpanderDecorator(decorated) {} - ~FixFormatDeco() = default; - static ExpanderPtr Creator(const ExpanderPtr &decorated) { - return std::static_pointer_cast(std::make_shared(decorated)); - } - AnfNodePtr Run(const AnfNodePtr &node) override; - - protected: - virtual std::vector GetFixedFormat(const AnfNodePtr &) const; -}; - -class UseInputFormatDeco : public FixFormatDeco { - public: - explicit UseInputFormatDeco(const ExpanderPtr &decorated) : FixFormatDeco(decorated) {} - ~UseInputFormatDeco() = default; - static ExpanderPtr Creator(const ExpanderPtr &decorated) { - return std::static_pointer_cast(std::make_shared(decorated)); - } - - protected: - std::vector GetFixedFormat(const AnfNodePtr &node) const override; -}; - -class InferValueDeco : public ExpanderDecorator { - public: - explicit InferValueDeco(const ExpanderPtr &decorated) : ExpanderDecorator(decorated) {} - ~InferValueDeco() = default; - static ExpanderPtr Creator(const ExpanderPtr &decorated) { - return std::static_pointer_cast(std::make_shared(decorated)); - } - AnfNodePtr Run(const AnfNodePtr &node) override; -}; - -class PoolLayoutDeco : public ExpanderDecorator { - public: - explicit PoolLayoutDeco(const ExpanderPtr &decorated) : ExpanderDecorator(decorated) {} - ~PoolLayoutDeco() = default; - static ExpanderPtr Creator(const ExpanderPtr &decorated) { - return std::static_pointer_cast(std::make_shared(decorated)); - } - AnfNodePtr Run(const AnfNodePtr &node) override; -}; - -class GraphKernelExpanderLite : public GraphKernelExpander { - public: - GraphKernelExpanderLite() : GraphKernelExpander() {} - explicit GraphKernelExpanderLite(const std::string &name) : GraphKernelExpander(name) {} - ~GraphKernelExpanderLite() override = default; - - protected: - bool DisableConvTuning(); - std::vector ConvTuningExpanderOps(); - std::vector InitOpList() override; - ExpanderPtr InitExpander(const AnfNodePtr &node) override; - bool CanExpand(const CNodePtr &node) const override; - void PreProcessAllNode(const CNodePtr &node) override; -}; -} // namespace mindspore::graphkernel -#endif // MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_GRAPH_KERNEL_EXPANDER_LITE_H_ diff --git a/mindspore-lite/tools/graph_kernel/converter/graph_kernel_optimization.cc b/mindspore-lite/tools/graph_kernel/converter/graph_kernel_optimization.cc deleted file mode 100644 index f04e0b8cbd74ca9399f39c08a4e779f2d2fdaf25..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/graph_kernel_optimization.cc +++ /dev/null @@ -1,281 +0,0 @@ -/** - * Copyright 2022-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "tools/graph_kernel/converter/graph_kernel_optimization.h" - -#include -#include -#include -#include -#include -#include "include/backend/optimizer/graph_optimizer.h" -#include "ir/func_graph.h" -#include "mindspore/ops/op_def/array_ops.h" - -#include "backend/common/graph_kernel/core/arithmetic_simplify.h" -#include "backend/common/graph_kernel/core/eliminate_redundant_output.h" -#include "backend/common/graph_kernel/core/graph_kernel_utils.h" -#include "backend/common/graph_kernel/core/shape_ops_splitter.h" -#include "backend/common/graph_kernel/core/transform_op_optimizer.h" -#include "backend/common/graph_kernel/core/update_state_formatter.h" -#include "backend/common/graph_kernel/graph_kernel_flags.h" -#include "backend/common/graph_kernel/core/graph_kernel_op_combiner.h" -#include "backend/common/graph_kernel/core/split_reshape_and_cache.h" -#include "backend/common/graph_kernel/core/cluster_cce_lib_ops.h" - -#include "tools/graph_kernel/converter/akg/utils.h" -#include "tools/graph_kernel/converter/callback_impl.h" -#include "tools/graph_kernel/converter/conv_tuning_expander.h" -#include "tools/graph_kernel/converter/eliminate_maketuple_getitem.h" -#include "tools/graph_kernel/converter/format_recognition.h" -#include "tools/graph_kernel/converter/graph_kernel_cluster_lite.h" -#include "tools/graph_kernel/converter/graph_kernel_expander_lite.h" -#include "tools/graph_kernel/converter/graph_kernel_splitter_lite.h" -#include "tools/graph_kernel/converter/kernel_builder.h" -#include "tools/graph_kernel/converter/parameter_to_tensor.h" -#include "tools/graph_kernel/converter/basic_op_infer_shape.h" -#include "tools/graph_kernel/converter/rename_fullname_with_scope.h" -#include "tools/graph_kernel/converter/update_kernel_info.h" -#include "tools/graph_kernel/converter/split_model_ascend.h" -#include "tools/graph_kernel/converter/split_model_gpu.h" -#include "tools/graph_kernel/converter/split_model_cpu.h" -#include "utils/ms_context.h" -#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" - -namespace mindspore { -namespace graphkernel { -using opt::GraphOptimizer; -constexpr size_t kStagePreProcess = 0; -constexpr size_t kStageCluster = 1; -constexpr size_t kStageHLO1 = 2; -constexpr size_t kStageSplit = 3; -constexpr size_t kStageBuildKernel = 4; -constexpr size_t kStagePostProcess = 5; - -class EmptyPass : public opt::Pass { - public: - EmptyPass() : Pass("empty_pass") {} - ~EmptyPass() override = default; - bool Run(const FuncGraphPtr &func_graph) override { return false; } -}; - -void GraphKernelOptimizer::Init() const { - // register call back - const CallbackImplRegister callback_reg( - [this]() { return std::static_pointer_cast(std::make_shared(converter_param_)); }); - - // register split model here to ensure that the correct split model will be invoked - // when import mindspore and lite in the same process - auto device = Callback::Instance()->GetTargetFromContext(true); - SPLIT_MODEL_REGISTER(kAscendDevice, inner::SplitModelAscend, device); - SPLIT_MODEL_REGISTER(kGPUDevice, inner::SplitModelGpu); - SPLIT_MODEL_REGISTER(kCPUDevice, inner::SplitModelCpu); -} - -GkPassManagerPtr GraphKernelOptimizer::PreProcess() const { - auto pm = std::make_shared(kStagePreProcess, "preprocess"); - - // put an empty pass here to dump the ir before GraphKernel - pm->Add(std::make_shared(), OptLevel_1); - - // Recognize the formats for all CNodes - pm->Add(std::make_shared(), OptLevel_1); - - // Dynamic infer shape - pm->Add(std::make_shared(), OptLevel_1, is_ascend); - - // Convert the const parameters to const tensors - pm->Add(std::make_shared(), OptLevel_1, is_cpu); - return pm; -} - -GkPassManagerPtr GraphKernelOptimizer::Cluster() const { - auto pm = std::make_shared(kStageCluster, "cluster"); - - // Expand complex basic kernels to composite kernels - pm->Add(std::make_shared(), OptLevel_1); - - pm->Add(std::make_shared(), OptLevel_0, is_ascend); - if (graphkernel::GraphKernelFlags::GetInstance().enable_cce_lib) { - // akg cce lib ops - pm->Add(std::make_shared(), OptLevel_0, is_ascend); - } - // Combine supported parallel ops that with common inputs - pm->Add(std::make_shared(), OptLevel_3); - - // Cluster basic kernels and composite kernels - pm->Add(std::make_shared(), OptLevel_1); - - // Eliminate the outputs without external user - pm->Add(std::make_shared(), OptLevel_1); - return pm; -} - -GkPassManagerPtr GraphKernelOptimizer::HighLevelOpt1() const { - auto pm = std::make_shared(kStageHLO1, "highlevelopt1"); - pm->Add(std::make_shared(), OptLevel_2); - // Eliminate redundant transform ops - pm->Add(std::make_shared(), OptLevel_2); - return pm; -} - -GkPassManagerPtr GraphKernelOptimizer::Split() const { - auto pm = std::make_shared(kStageSplit, "split"); - // Make certain nodes redundant so that they are used by only one user, - // which can avoid unnecessary input-output and get better performance. - // preprocess for ShapeOpsSplitter - pm->Add(std::make_shared(), OptLevel_1, is_cpu); - std::vector duplicated_ops = {prim::kPrimReshape}; - pm->Add(std::make_shared(duplicated_ops), OptLevel_1); - // Split kernel according to costmodel and distinguish splitter patterns based on device - pm->Add(std::make_shared(), OptLevel_1); - - // After Simplify and Splitter, a lot of redundant getitem/maketuple - // will be exposed, use ElimMaketupleGetitem Pass to delete them. - pm->Add(std::make_shared(), OptLevel_1); - - // Eliminate the redundant node that is copied above but not handled by GraphKernelSplitter - pm->Add(std::make_shared(), OptLevel_1, is_cpu); - pm->Add(std::make_shared(), OptLevel_1); - return pm; -} - -GkPassManagerPtr GraphKernelOptimizer::BuildKernel() const { - auto pm = std::make_shared(kStageBuildKernel, "buildkernel"); - // build akg and replace graph kernel nodes - pm->Add(std::make_shared(), OptLevel_1); - return pm; -} - -GkPassManagerPtr GraphKernelOptimizer::PostProcess() const { - auto pm = std::make_shared(kStagePostProcess, "postprocess"); - pm->Add(std::make_shared(), OptLevel_1); - pm->Add(std::make_shared(), OptLevel_1); - return pm; -} - -std::unordered_set CheckSupport() { - std::unordered_set support_backend; -#ifdef AKG_USE_LLVM - (void)support_backend.emplace("CPU"); -#endif -#if defined(AKG_USE_LLVM) && defined(AKG_USE_CUDA) - (void)support_backend.emplace("GPU"); -#endif -#ifdef AKG_ENABLE_D - (void)support_backend.emplace("Ascend"); -#endif - return support_backend; -} - -bool CheckAkg() { - std::ostringstream py_cmd; - py_cmd << kAddMSLiteAkg; - py_cmd << "from akg.ms import compilewithjsonname\n"; - std::string cmd = "python -c \"" + py_cmd.str() + "\""; - auto ret = std::system(cmd.c_str()); - if (WEXITSTATUS(ret) != 0) { - MS_LOG(WARNING) - << "Could not find akg in python, Graph Kernel fusion has been turned off. Please make sure you have " - "installed akg. process content is as follows:\n" - << cmd; - return false; - } - return true; -} - -void GraphKernelOptimizer::Run(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - if (converter_param_ == nullptr) { - converter_param_ = std::make_shared(); - converter_param_->device = "Ascend"; - } - if (!CheckAkg()) { - return; - } - Init(); - auto akg_support_backend = CheckSupport(); - auto device = Callback::Instance()->GetTargetFromContext(); - if (akg_support_backend.find(device) == akg_support_backend.end()) { - MS_LOG(WARNING) - << "Graph Kernel fusion for" << device - << " backend is not enabled, Graph Kernel fusion has been turned off. Please check if you have set the " - "correct compilation options for " - << device << " backend."; - return; - } - is_cpu = (device == "CPU"); - is_ascend = (device == "Ascend"); - if (is_cpu && common::AnfAlgo::IsDynamicGraph(func_graph)) { - MS_LOG(INFO) << "skip cpu dynamic graph"; - return; - } - - auto optimizer = std::make_shared("graph_kernel_optimizer"); - optimizer->AddPassManager(PreProcess()); - optimizer->AddPassManager(Cluster()); - optimizer->AddPassManager(HighLevelOpt1()); - optimizer->AddPassManager(Split()); - optimizer->AddPassManager(BuildKernel()); - optimizer->AddPassManager(PostProcess()); - - auto mng = func_graph->manager(); - if (mng == nullptr) { - mng = Manage(func_graph, true); - func_graph->set_manager(mng); - } - GkUtils::UpdateFuncGraphManager(mng, func_graph); - (void)optimizer->Optimize(func_graph); -} - -std::string UpdateFlags(const std::string &device, const std::string &graph_kernel_flags) { - std::string res = graph_kernel_flags; - if (device.find("Ascend") != std::string::npos && - graph_kernel_flags.find("enable_dynamic_shape_fusion") == std::string::npos) { - res += " --enable_dynamic_shape_fusion"; - } - return res; -} -} // namespace graphkernel - -lite::STATUS GraphKernelOptimize(const FuncGraphPtr &func_graph, const std::shared_ptr ¶m) { -#ifndef Debug - try { -#endif - if (param == nullptr) { - return lite::RET_OK; - } - if (param->graphKernelParam.graph_kernel_flags.empty()) { - return lite::RET_OK; - } - std::map jit_config; - jit_config["graph_kernel_flags"] = - graphkernel::UpdateFlags(param->device, param->graphKernelParam.graph_kernel_flags); - graphkernel::GraphKernelFlags::SaveJitConfig(jit_config); - - if (graphkernel::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) { - MS_LOG(INFO) << "Run graphkernel optimization begin."; - graphkernel::GraphKernelOptimizer(param).Run(func_graph); - MS_LOG(INFO) << "Run graphkernel optimization end."; - } - return lite::RET_OK; -#ifndef Debug - } catch (const std::exception &e) { - MS_LOG(ERROR) << e.what(); - return lite::RET_ERROR; - } -#endif -} -} // namespace mindspore diff --git a/mindspore-lite/tools/graph_kernel/converter/graph_kernel_pass_manager_lite.cc b/mindspore-lite/tools/graph_kernel/converter/graph_kernel_pass_manager_lite.cc deleted file mode 100644 index 22b2546e31d834e8ea37cddb9bcf3a82134be75f..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/graph_kernel_pass_manager_lite.cc +++ /dev/null @@ -1,539 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "tools/graph_kernel/converter/graph_kernel_pass_manager_lite.h" - -#include -#include -#include -#include -#include -#include -#include "ir/graph_utils.h" -#include "src/common/file_utils.h" -#include "utils/file_utils.h" -#include "src/common/utils.h" -#include "utils/anf_utils.h" -#include "tools/graph_kernel/common/utils.h" - -namespace mindspore::graphkernel { -namespace dumpir { -struct SubGraphIRInfo { - int32_t local_var; - std::ostringstream dumpbuf; - OrderedMap local_var_map; -}; - -void DumpGlobalInfoEntry(const FuncGraphPtr &graph, std::ostringstream &dumpbuf) { - if (graph == nullptr) { - return; - } - dumpbuf << "#IR entry : @" << graph->ToString() << std::endl; - dumpbuf << "#attrs :" << std::endl; - for (const auto &attr : graph->attrs()) { - dumpbuf << attr.first << " : "; - if (attr.second->isa()) { - dumpbuf << (GetValue(attr.second)); - } else if (attr.second->isa()) { - dumpbuf << (GetValue(attr.second)); - } - dumpbuf << std::endl; - } -} - -void PrintNodeOutputType(std::ostringstream &dumpbuf, const AnfNodePtr &nd) { - if (nd == nullptr) { - return; - } - ValuePtr tensor_value = nullptr; - auto abstract = nd->abstract(); - if (abstract != nullptr && abstract->isa()) { - tensor_value = abstract->BuildValue(); - } - abstract::ShapePtr shape = dyn_cast(nd->Shape()); - TypePtr type = dyn_cast(nd->Type()); - if ((shape != nullptr) && (type != nullptr)) { - dumpbuf << "<" << type << ", " << shape->ToString(); - if (tensor_value != nullptr && tensor_value != kValueAny) { - dumpbuf << ", value=..."; - } - dumpbuf << ">"; - } else if (type != nullptr) { - dumpbuf << "<" << type; - if (tensor_value != nullptr && tensor_value != kValueAny) { - dumpbuf << ", value=..."; - } - dumpbuf << ">"; - } else { - dumpbuf << ""; - } -} - -int32_t DumpParams(const FuncGraphPtr &graph, std::ostringstream &dumpbuf, OrderedMap *para_map) { - if (graph == nullptr) { - MS_LOG(INFO) << "Param graph is nullptr."; - return 0; - } - std::vector parameters = graph->parameters(); - dumpbuf << "#Total params : " << parameters.size() << std::endl; - dumpbuf << std::endl; - - // dump parameters - int32_t para = 1; - for (const auto &p : parameters) { - if (p == nullptr) { - continue; - } - auto parameter_ptr = p->cast(); - if (parameter_ptr == nullptr) { - MS_LOG(EXCEPTION) << "p cannot cast to ParameterPtr"; - } - dumpbuf << "%para" << para << "_" << parameter_ptr->name() << " : "; - // print parameters' type and shape - PrintNodeOutputType(dumpbuf, p); - dumpbuf << std::endl; - - if (para_map != nullptr) { - (*para_map)[p] = para++; - } - MS_LOG(DEBUG) << "Record param: " << p->ToString() << " graph belong : " << p->func_graph()->ToString(); - } - return para; -} - -void DumpOperator(const AnfNodePtr &op, const std::shared_ptr &gsub) { - if (op == nullptr || gsub == nullptr) { - return; - } - - if (IsValueNode(op)) { - FuncGraphPtr fg = GetValueNode(op); - if (fg != nullptr) { - gsub->dumpbuf << "call @" << fg->ToString(); - } - } else if (op->isa()) { - if (gsub->local_var_map.find(op) != gsub->local_var_map.end()) { - gsub->dumpbuf << "%" << gsub->local_var_map[op]; - } else { - auto node = op->cast(); - auto fg = node->func_graph(); - gsub->dumpbuf << "$(" << fg->ToString() << ":" << node->ToString() << ")"; - } - } else if (op->isa()) { - gsub->dumpbuf << GetValueNode(op)->ToString(); - } else { - gsub->dumpbuf << op->ToString(); - } -} - -void DumpOperands(const AnfNodePtr &nd, OrderedMap *para_map, - const std::shared_ptr &gsub) { - if (nd == nullptr || para_map == nullptr || gsub == nullptr) { - return; - } - - gsub->dumpbuf << "("; - const auto &inputs = GetInputs(nd); - size_t len = inputs.size(); - if (len > 1) { - // skip inputs[0] which is Primitive valuenode - for (size_t i = 1; i < len; ++i) { - AnfNodePtr in = inputs[i]; - MS_EXCEPTION_IF_NULL(in); - if (i != 1) { - gsub->dumpbuf << ", "; - } - if (in->isa()) { - if (!(*para_map)[in]) { - gsub->dumpbuf << "%para_" << in->ToString(); - } else { - gsub->dumpbuf << "%para" << (*para_map)[in] << "_" << in->ToString(); - } - } else if (in->isa()) { - if (gsub->local_var_map.find(in) != gsub->local_var_map.end()) { - gsub->dumpbuf << "%" << gsub->local_var_map[in]; - } else { - auto node = in->cast(); - auto fg = node->func_graph(); - gsub->dumpbuf << "$(" << fg->ToString() << ":" << node->ToString() << ")"; - } - } else if (in->isa() && !IsValueNode(in)) { - // non Primitive valuenode - gsub->dumpbuf << GetValueNode(in)->ToString(); - } else if (IsValueNode(in)) { - FuncGraphPtr fg = GetValueNode(in); - gsub->dumpbuf << "@" << fg->ToString(); - } else { - gsub->dumpbuf << in->ToString(); - } - } - } - gsub->dumpbuf << ")"; -} - -void DumpAttrs(const mindspore::HashMap &attrs, const std::shared_ptr &gsub, - bool check_strategy = false) { - int i = 0; - for (const auto &attr : attrs) { - if (i++ != 0) { - gsub->dumpbuf << ", "; - } - gsub->dumpbuf << attr.first << ": "; - if (attr.second == nullptr) { - gsub->dumpbuf << "null"; - } else { - gsub->dumpbuf << attr.second->ToString(); - } - } -} - -void DumpOperateAttrs(const AnfNodePtr &op, const std::shared_ptr &gsub) { - if (op == nullptr || gsub == nullptr) { - return; - } - - if (IsValueNode(op)) { - auto primitive = GetValueNode(op); - if (!primitive->instance_name().empty()) { - gsub->dumpbuf << " {"; - gsub->dumpbuf << "instance name" - << ": "; - gsub->dumpbuf << primitive->instance_name(); - gsub->dumpbuf << "}"; - } - auto attrs = primitive->attrs(); - if (!attrs.empty()) { - gsub->dumpbuf << " primitive_attrs: {"; - DumpAttrs(attrs, gsub, true); - gsub->dumpbuf << "}"; - } - } -} - -void DumpCNodeAttrs(const CNodePtr &op, const std::shared_ptr &gsub) { - if (op == nullptr || gsub == nullptr) { - return; - } - auto &attrs = op->attrs(); - if (attrs.empty()) { - return; - } - - gsub->dumpbuf << " cnode_attrs: {"; - DumpAttrs(attrs, gsub); - gsub->dumpbuf << "}"; -} - -void DumpCNodePrimalAttrs(const CNodePtr &op, const std::shared_ptr &gsub) { - if (op == nullptr || gsub == nullptr) { - return; - } - if (op->primal_attrs().empty()) { - gsub->dumpbuf << std::endl; - return; - } - auto primal_attrs = op->primal_attrs(); - gsub->dumpbuf << " cnode_primal_attrs: {"; - DumpAttrs(primal_attrs, gsub); - gsub->dumpbuf << "}"; - gsub->dumpbuf << std::endl; -} - -void PrintNodeInputType(std::ostringstream &dumpbuf, const AnfNodePtr &nd) { - if (nd == nullptr) { - return; - } - const auto &inputs = GetInputs(nd); - size_t len = inputs.size(); - if (len > 1) { - // skip inputs[0] which is Primitive value node - for (size_t i = 1; i < len; ++i) { - AnfNodePtr in = inputs[i]; - if (i != 1) { - dumpbuf << ", "; - } - PrintNodeOutputType(dumpbuf, in); - } - } -} - -void DumpShape(const AnfNodePtr &nd, const FuncGraphPtr &sub_graph, const std::shared_ptr &gsub) { - if (nd == nullptr || sub_graph == nullptr || gsub == nullptr) { - return; - } - - if (nd != sub_graph->get_return()) { - gsub->dumpbuf << " : ("; - PrintNodeInputType(gsub->dumpbuf, nd); - gsub->dumpbuf << ") -> ("; - PrintNodeOutputType(gsub->dumpbuf, nd); - gsub->dumpbuf << ")"; - } else { - gsub->dumpbuf << " : ("; - PrintNodeInputType(gsub->dumpbuf, nd); - gsub->dumpbuf << ")"; - } - - gsub->dumpbuf << std::endl; -} - -std::string PrintKernelFormatAndType(const std::string &fmt, const TypeId &type, const std::vector &shape) { - std::ostringstream buffer; - buffer << "<" << TypeIdLabel(type); - if (!fmt.empty()) { - buffer << "x" << fmt << shape; - } - buffer << ">"; - return buffer.str(); -} - -std::string PrintOutputTypeShapeFormat(const std::shared_ptr &node) { - if (node == nullptr) { - return ""; - } - std::ostringstream buffer; - auto kernel_build_info = GetKernelInfo(node); - if (kernel_build_info == nullptr) { - return ""; - } - size_t output_num = kernel_build_info->GetOutputNum(); - buffer << "OutputFormats:"; - for (size_t i = 0; i < output_num; ++i) { - if (i != 0) { - buffer << ", "; - } - auto format = GetOutputFormatFromAnfNode(node, i); - if (!format.empty()) { - buffer << format; - } - } - return buffer.str(); -} - -void DumpKernelInfo(const CNodePtr &node, const std::shared_ptr &gsub) { - if (node == nullptr || gsub == nullptr) { - return; - } - auto kernel_info = node->kernel_info(); - if (kernel_info == nullptr || !kernel_info->has_build_info()) { - return; - } - gsub->dumpbuf << " : ("; - gsub->dumpbuf << PrintOutputTypeShapeFormat(node); - gsub->dumpbuf << ")"; - gsub->dumpbuf << std::endl; -} - -void DumpCNode(const CNodePtr &nd, const FuncGraphPtr &sub_graph, OrderedMap *const para_map, - const std::shared_ptr &gsub, bool dump_full_name = false) { - if (nd == nullptr || sub_graph == nullptr || para_map == nullptr || gsub == nullptr) { - return; - } - - if (nd != sub_graph->get_return()) { - gsub->dumpbuf << " %" << gsub->local_var << "(" << nd->ToString() << ")" - << " = "; - gsub->local_var_map[nd] = gsub->local_var++; - } else { - gsub->dumpbuf << " "; - } - - if (nd->inputs().empty()) { - MS_LOG(EXCEPTION) << "Input of apply node is empty"; - } - AnfNodePtr op = nd->input(0); - DumpOperator(op, gsub); - DumpOperands(nd, para_map, gsub); - DumpOperateAttrs(op, gsub); - DumpCNodeAttrs(nd, gsub); - DumpCNodePrimalAttrs(nd, gsub); - DumpShape(nd, sub_graph, gsub); - DumpKernelInfo(nd, gsub); - if (dump_full_name) { - gsub->dumpbuf << " : (" << nd->fullname_with_scope() << ")" << std::endl; - } -} - -void DumpIRInSubgraph(const std::vector &nodes, OrderedMap *para_map, - OrderedMap> *const sub_graphs, int32_t total_para, - bool dump_full_name = false) { - if (para_map == nullptr || sub_graphs == nullptr) { - return; - } - - for (const auto &nd : nodes) { - MS_EXCEPTION_IF_NULL(nd); - FuncGraphPtr sub_graph = nd->func_graph(); - if (sub_graph == nullptr) { - MS_LOG(DEBUG) << "Node[" << nd->ToString() << "] belongs to no graph!"; - continue; - } - std::shared_ptr gsub = (*sub_graphs)[sub_graph]; - if (gsub == nullptr) { - gsub = std::make_shared(); - gsub->local_var = 0; - (*sub_graphs)[sub_graph] = gsub; - } - auto ¶m = sub_graph->parameters(); - for (size_t idx = 0; idx < param.size(); idx++) { - MS_EXCEPTION_IF_NULL(param[idx]); - if ((*para_map).count(param[idx]) == 0) { - (*para_map)[param[idx]] = total_para++; - } - } - if (!nd->isa()) { - if (nd->isa()) { - // print and record output of operator if it is not 'Return' - DumpCNode(nd->cast(), sub_graph, para_map, gsub, dump_full_name); - } else { - gsub->dumpbuf << " " << nd->ToString() << std::endl; - } - } - } -} - -void DumpSubgraph(const OrderedMap> *sub_graphs, - const FuncGraphPtr &graph, OrderedMap *para_map, std::ofstream &fout) { - if (sub_graphs == nullptr || graph == nullptr) { - return; - } - - fout << "#Total subgraph : " << sub_graphs->size() << std::endl; - fout << std::endl; - - for (const auto &sg : *sub_graphs) { - fout << "subgraph attr:" << std::endl; - MS_EXCEPTION_IF_NULL(sg.first); - for (const auto &attr : sg.first->attrs()) { - fout << attr.first << " : "; - if (attr.second->isa()) { - fout << GetValue(attr.second); - } else if (attr.second->isa()) { - fout << (GetValue(attr.second)); - } - fout << std::endl; - } - fout << "subgraph @" << sg.first->ToString() << "("; - if (sg.first != graph) { - std::vector parameters = sg.first->parameters(); - if (parameters.size() == 1) { - MS_EXCEPTION_IF_NULL(parameters[0]); - fout << "%para" << (*para_map)[parameters[0]] << "_" << parameters[0]->ToString(); - } else if (parameters.size() > 1) { - for (size_t idx = 0; idx < parameters.size() - 1; idx++) { - MS_EXCEPTION_IF_NULL(parameters[idx]); - fout << "%para" << (*para_map)[parameters[idx]] << "_" << parameters[idx]->ToString(); - fout << ", "; - } - MS_EXCEPTION_IF_NULL(parameters[parameters.size() - 1]); - fout << "%para" << (*para_map)[parameters[parameters.size() - 1]] << "_" - << parameters[parameters.size() - 1]->ToString(); - } - } - fout << ") {" << std::endl; - MS_EXCEPTION_IF_NULL(sg.second); - fout << sg.second->dumpbuf.str(); - fout << "}" << std::endl; - fout << std::endl; - } -} - -std::optional CreatePrefixPath(const std::string &input_path) { - std::optional prefix_path; - std::optional file_name; - FileUtils::SplitDirAndFileName(input_path, &prefix_path, &file_name); - if (!file_name.has_value()) { - MS_LOG(ERROR) << "Cannot get file_name from: " << input_path; - return std::nullopt; - } - auto file_name_str = file_name.value(); - std::string prefix_path_str; - if (prefix_path.has_value()) { - auto create_prefix_path = FileUtils::CreateNotExistDirs(prefix_path.value(), true); - if (!create_prefix_path.has_value()) { - return std::nullopt; - } - prefix_path_str = create_prefix_path.value(); - } else { - auto pwd_path = FileUtils::GetRealPath("./"); - if (!pwd_path.has_value()) { - MS_LOG(ERROR) << "Can not get pwd path"; - return std::nullopt; - } - prefix_path_str = pwd_path.value(); - } - return std::string(prefix_path_str + "/" + file_name_str); -} - -void DumpIR(const std::string &filename, const FuncGraphPtr &graph, bool dump_full_name) { - if (graph == nullptr) { - return; - } - auto path = "./" + filename; - auto realpath = CreatePrefixPath(path); - if (!realpath.has_value()) { - MS_LOG(ERROR) << "Get real path failed, path=" << path; - return; - } - - std::ofstream fout(realpath.value()); - std::ostringstream dumpbuf; - - auto nodes = TopoSort(graph->get_return(), SuccDeeperSimple, AlwaysInclude); - OrderedMap para_map; - // dump global info - DumpGlobalInfoEntry(graph, dumpbuf); - int32_t total_para = DumpParams(graph, dumpbuf, ¶_map); - - OrderedMap> sub_graphs; - // dump ir in each sub graph - DumpIRInSubgraph(nodes, ¶_map, &sub_graphs, total_para, dump_full_name); - - // output global info - fout << dumpbuf.str() << std::endl; - - // output each sub graph - DumpSubgraph(&sub_graphs, graph, ¶_map, fout); - - fout.close(); -} -} // namespace dumpir - -void GraphKernelPassManagerLite::DumpPassIR(const FuncGraphPtr &func_graph, const std::string &pass_fullname) const { - static bool dump_ir = (common::GetEnv("MS_DEV_DUMP_GRAPH_KERNEL_IR") == "on"); - if (dump_ir) { - static std::string rank_id = common::GetEnv("RANK_ID"); - std::string filename; - if (rank_id.empty()) { - filename = "verbose_ir_files/" + pass_fullname + ".ir"; - } else { - filename = "rank_" + rank_id + "/verbose_ir_files/" + pass_fullname + ".ir"; - } - dumpir::DumpIR(filename, func_graph, true); - } -} - -// transplant this function from pass_manager_extends.cc because the implement was moved to PassManagerLite. -bool GraphKernelPassManagerLite::RunPass(const FuncGraphPtr &func_graph, size_t pass_id, const PassPtr &pass) const { - bool changed = false; - auto begin_time = lite::GetTimeUs(); - if (pass->Run(func_graph)) { - changed = true; - } - auto end_time = lite::GetTimeUs(); - MS_LOG(INFO) << "Run pass " << GetPassFullname(pass_id, pass) << " in " << (end_time - begin_time) << " us."; - return changed; -} -} // namespace mindspore::graphkernel diff --git a/mindspore-lite/tools/graph_kernel/converter/graph_kernel_pass_manager_lite.h b/mindspore-lite/tools/graph_kernel/converter/graph_kernel_pass_manager_lite.h deleted file mode 100644 index 1654b20c6dfee1e69a460d64f7f395127ccc8e83..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/graph_kernel_pass_manager_lite.h +++ /dev/null @@ -1,37 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_GRAPH_KERNEL_PASS_MANAGER_LITE_H_ -#define MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_GRAPH_KERNEL_PASS_MANAGER_LITE_H_ - -#include -#include -#include "ir/anf.h" -#include "ir/func_graph.h" -#include "backend/common/graph_kernel/core/graph_kernel_pass_manager.h" - -namespace mindspore::graphkernel { -using opt::PassPtr; -class GraphKernelPassManagerLite : public GraphKernelPassManager { - public: - using GraphKernelPassManager::GraphKernelPassManager; - - protected: - void DumpPassIR(const FuncGraphPtr &func_graph, const std::string &pass_fullname) const override; - bool RunPass(const FuncGraphPtr &func_graph, size_t pass_id, const PassPtr &pass) const override; -}; -using GkPassManagerPtr = std::shared_ptr; -} // namespace mindspore::graphkernel -#endif // MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_GRAPH_KERNEL_PASS_MANAGER_LITE_H_ diff --git a/mindspore-lite/tools/graph_kernel/converter/graph_kernel_splitter_lite.cc b/mindspore-lite/tools/graph_kernel/converter/graph_kernel_splitter_lite.cc deleted file mode 100644 index 495df85e28dad8463f80b9c4d5f6c191f70b89c2..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/graph_kernel_splitter_lite.cc +++ /dev/null @@ -1,130 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "tools/graph_kernel/converter/graph_kernel_splitter_lite.h" -#include -#include -#include -#include -#include "utils/system/env.h" -#include "utils/file_utils.h" -#include "utils/anf_utils.h" -#include "backend/common/graph_kernel/graph_kernel_flags.h" -#include "backend/common/graph_kernel/core/tuning_splitter.h" -#include "tools/graph_kernel/converter/akg/akg_kernel_builder.h" -#include "backend/common/graph_kernel/core/graph_kernel_utils.h" -#include "tools/graph_kernel/converter/akg/utils.h" -#include "utils/ms_context.h" - -namespace mindspore::graphkernel { -SplitSchemerPtr GraphKernelSplitterWithTuning::GetSplitSchema(const std::string &processor) { - if (tuning_flag_) { - return std::make_shared(tuning_path_); - } - return GraphKernelSplitter::GetSplitSchema(processor); -} - -bool GraphKernelSplitterWithTuning::StartTuning(const std::string &dir_path) const { - std::ostringstream attrs; - attrs << "{"; - attrs << "\'repository_path\':\'" << dir_path << "\'"; - if (common::GetEnv("MS_DEV_GRAPH_KERNEL_SPLIT_DEBUG_TUNING") != "on") { - attrs << ",\'online_tuning\':" << GraphKernelFlags::GetInstance().online_tuning; - } - attrs << "}"; - std::ostringstream py_cmd; - std::string tune_interface = "poly_graph_split_with_json_dir"; - py_cmd << kAddMSLiteAkg; - py_cmd << "from akg.ms import " << tune_interface << "\n"; - py_cmd << "if not " << tune_interface << "(\'" << dir_path << "\', " << attrs.str() << "):\n"; - py_cmd << " raise RuntimeError(\'Tune fail. info path: " << dir_path << "\')"; - std::string cmd = "python -c \"" + py_cmd.str() + "\""; - auto ret = std::system(cmd.c_str()); - if (!WIFEXITED(ret)) { - MS_LOG(ERROR) << "Python process start fail! process content is as follows:\n" << cmd; - return false; - } - if (WEXITSTATUS(ret) != 0) { - MS_LOG(ERROR) << "Failed to tune kernel: " << dir_path; - return false; - } - return true; -} - -void SignTunedGraphs(const FuncGraphPtr &func_graph) { - auto kernel_meta = FileUtils::GetRealPath("./akg_kernel_meta/"); - if (!kernel_meta.has_value()) { - return; - } - auto fs = system::Env::GetFileSystem(); - MS_EXCEPTION_IF_NULL(fs); - DumpOption option = AkgKernelBuilder::json_option(); - option.gen_kernel_name_only = true; - - auto todos = TopoSort(func_graph->get_return()); - for (const auto &node : todos) { - if (!AnfUtils::IsGraphKernel(node)) { - continue; - } - auto fg = GetCNodeFuncGraph(node); - if (!fg->has_attr(kAttrNodeName)) { - continue; - } - auto node_name = GetValue(fg->get_attr(kAttrNodeName)); - auto kernel_obj = kernel_meta.value() + "/best_split_" + node_name + ".o"; - if (fs->FileExist(kernel_obj)) { - // sign the funcgraph with its current kernel name, the tuned result can be used if - // its kernel name is the same as the signature when building kernels. - GraphKernelJsonGenerator json_generator(option); - AnfNodePtrList node_list; - AnfNodePtrList input_list; - AnfNodePtrList output_list; - GkUtils::GetValidKernelNodes(fg, &node_list, &input_list, &output_list); - (void)json_generator.CollectFusedJson(node_list, input_list, output_list); - fg->set_attr(kTunedSign, MakeValue(json_generator.kernel_name())); - MS_LOG(INFO) << "The " << kernel_obj << " is the tuning result of " << json_generator.kernel_name(); - } - } -} - -bool GraphKernelSplitterWithTuning::Run(const FuncGraphPtr &func_graph) { - if (Callback::Instance()->GetTargetFromContext() == kAscendDevice || - GraphKernelFlags::GetInstance().online_tuning == 0) { - tuning_flag_ = false; - return GraphKernelSplitter::Run(func_graph); - } - auto todos = TopoSort(func_graph->get_return()); - AnfNodePtrList gknodes; - std::copy_if(todos.cbegin(), todos.cend(), std::back_inserter(gknodes), AnfUtils::IsGraphKernel); - if (gknodes.empty()) { - return false; - } - std::map node_name; - tuning_path_ = SaveNodesInfo(gknodes, "./split_tuning", AkgKernelBuilder::json_option(), &node_name, nullptr); - if (tuning_path_.empty()) { - tuning_flag_ = false; - } else { - tuning_flag_ = StartTuning(tuning_path_); - } - for (const auto &iter : node_name) { - AnfUtils::SetNodeAttr(kAttrNodeName, MakeValue(iter.second), iter.first); - } - auto changed = GraphKernelSplitter::Run(func_graph); - if (tuning_flag_) { - SignTunedGraphs(func_graph); - } - return changed; -} -} // namespace mindspore::graphkernel diff --git a/mindspore-lite/tools/graph_kernel/converter/graph_kernel_splitter_lite.h b/mindspore-lite/tools/graph_kernel/converter/graph_kernel_splitter_lite.h deleted file mode 100644 index c675b52a01687a478df8442e611f1fa31da7765b..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/graph_kernel_splitter_lite.h +++ /dev/null @@ -1,37 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_GRAPH_KERNEL_SPLITTER_LITE_H_ -#define MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_GRAPH_KERNEL_SPLITTER_LITE_H_ -#include -#include "backend/common/graph_kernel/core/split_schemer.h" -#include "backend/common/graph_kernel/core/graph_kernel_splitter.h" - -namespace mindspore::graphkernel { -class GraphKernelSplitterWithTuning : public GraphKernelSplitter { - public: - GraphKernelSplitterWithTuning() = default; - ~GraphKernelSplitterWithTuning() = default; - bool Run(const FuncGraphPtr &func_graph) override; - SplitSchemerPtr GetSplitSchema(const std::string &processor) override; - - protected: - bool StartTuning(const std::string &dir_path) const; - - std::string tuning_path_; - bool tuning_flag_{true}; -}; -} // namespace mindspore::graphkernel -#endif // MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_GRAPH_KERNEL_SPLITTER_LITE_H_ diff --git a/mindspore-lite/tools/graph_kernel/converter/kernel_builder.cc b/mindspore-lite/tools/graph_kernel/converter/kernel_builder.cc deleted file mode 100644 index 08d46a308a5446a54b3e4e1be910e96dbf3bd39d..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/kernel_builder.cc +++ /dev/null @@ -1,81 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#define USE_DEPRECATED_API -#include "tools/graph_kernel/converter/kernel_builder.h" - -#include -#include -#include "backend/common/graph_kernel/core/graph_kernel_callback.h" -#include "backend/common/graph_kernel/core/graph_kernel_utils.h" -#include "backend/common/graph_kernel/graph_kernel_flags.h" -#include "tools/graph_kernel/converter/akg/ascend_kernel_builder.h" -#include "tools/graph_kernel/converter/akg/cpu_kernel_builder.h" -#include "tools/graph_kernel/converter/akg/gpu_kernel_builder.h" -#include "src/common/log_adapter.h" -#include "utils/ms_context.h" -#include "common/kernel_build_info.h" -#include "include/backend/kernel_info.h" - -namespace mindspore::graphkernel { -AkgKernelBuilderPtr GetAkgBuilder(const std::string &target) { - if (target == kCPUDevice) { - return std::make_shared(); - } - if (target == kGPUDevice) { - return std::make_shared(); - } - if (target == kAscendDevice) { - return std::make_shared(); - } - MS_LOG(EXCEPTION) << "GraphKernel does not support " << target << " akg builder."; - return nullptr; -} - -bool KernelBuilder::Run(const FuncGraphPtr &func_graph) { - auto node_list = GkUtils::GetGraphKernelNodes(func_graph); - auto device_type = Callback::Instance()->GetTargetFromContext(); - if (node_list.empty()) { - MS_LOG(WARNING) - << "No GraphKernel nodes found in the func_graph, possibly because the input model file does not have any " - "operators that can be fused or the model has inputs with dynamic shapes."; - return false; - } - auto builder = GetAkgBuilder(device_type); - if (!builder->CompileJsonsInAnfnodes(node_list)) { - MS_LOG(EXCEPTION) << "Graph kernel compile fail"; - } - auto manager = Manage(func_graph, true); - MS_EXCEPTION_IF_NULL(manager); - ParameterPtr akg_node = nullptr; - for (auto &node : node_list) { - auto cnode = node->cast(); - auto custom_cnode = builder->CreateCustomOp(func_graph, cnode); - if (custom_cnode == nullptr) { - MS_LOG(EXCEPTION) << "Create custom op fail for " << cnode->fullname_with_scope(); - } - if (!builder->GenerateAkgKernelNodes(func_graph, custom_cnode, cnode)) { - MS_LOG(EXCEPTION) << "Copy kernel.o to tensor data fail for " << cnode->fullname_with_scope(); - } - custom_cnode->set_kernel_info(node->kernel_info_ptr()); - (void)manager->Replace(node, custom_cnode); - if (akg_node != nullptr) { - manager->AddEdge(custom_cnode, akg_node); - } - } - return true; -} -} // namespace mindspore::graphkernel diff --git a/mindspore-lite/tools/graph_kernel/converter/parameter_to_tensor.h b/mindspore-lite/tools/graph_kernel/converter/parameter_to_tensor.h deleted file mode 100644 index 87e148d02348ec63caa877ba5c1d719f24dd53eb..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/parameter_to_tensor.h +++ /dev/null @@ -1,29 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_PARAMETER_TO_TENSOR_H_ -#define MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_PARAMETER_TO_TENSOR_H_ -#include "ir/func_graph.h" -#include "include/backend/optimizer/pass.h" - -namespace mindspore::graphkernel { -class ParameterToTensor : public opt::Pass { - public: - ParameterToTensor() : Pass("parameter_to_tensor") {} - ~ParameterToTensor() override = default; - bool Run(const FuncGraphPtr &func_graph) override; -}; -} // namespace mindspore::graphkernel -#endif // MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_PARAMETER_TO_TENSOR_H_ diff --git a/mindspore-lite/tools/graph_kernel/converter/preprocess_weight.cc b/mindspore-lite/tools/graph_kernel/converter/preprocess_weight.cc deleted file mode 100644 index ec982043ebd45c69cd4004827aeb89a1f20b89bd..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/preprocess_weight.cc +++ /dev/null @@ -1,254 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tools/graph_kernel/converter/preprocess_weight.h" -#include -#include -#include -#include "utils/anf_utils.h" -#include "backend/common/graph_kernel/core/graph_kernel_callback.h" -#include "backend/common/graph_kernel/core/graph_kernel_utils.h" - -namespace mindspore::graphkernel { -constexpr size_t kConv2dDataIndex = 1; -constexpr size_t kConv2dWeightIndex = 2; -constexpr size_t kWeightChannelOutAxis = 0; -constexpr size_t kWeightHeightAxis = 1; -constexpr size_t kWeightWidthAxis = 2; -constexpr size_t kWeightChannelInAxis = 3; -constexpr size_t kDepthWiseChannelAxis = 3; -constexpr size_t kShapeRank = 4; - -std::pair TilingChannel(int64_t channel, int64_t simd_size) { - for (auto inner = simd_size; inner > 0; inner--) { - if (channel % inner == 0) { - return std::make_pair(channel / inner, inner); - } - } - return {channel, 1LL}; -} - -class IndexCalc { - public: - explicit IndexCalc(const ShapeVector &shape) : shape_(shape) {} - int64_t GetFlatIndex(const ShapeVector &index) { - if (index.size() != shape_.size()) { - MS_LOG(EXCEPTION) << "The index's size should be equal to shape's size, but got " << index.size() << " vs " - << shape_.size(); - } - int64_t prod = 1LL; - int64_t result = 0LL; - for (int i = SizeToInt(shape_.size()) - 1; i >= 0; i--) { - result += index[IntToSize(i)] * prod; - prod *= shape_[IntToSize(i)]; - } - return result; - } - - private: - ShapeVector shape_; -}; - -AnfNodePtr SubstituteConv2D::InferWeightValue(const AnfNodePtr &node) { - auto cnode = QuickCloneCNode(node); - auto prim = GetCNodePrimitive(cnode)->Clone(); - cnode->set_input(0, NewValueNode(prim)); - auto cb = Callback::Instance(); - // the weight should be a 4D tensor of format OHWI - auto weight_shape = cb->GetInputShape(cnode, 1); - if (weight_shape.size() != kShapeRank) { - return nullptr; - } - auto c_out = weight_shape[kWeightChannelOutAxis]; - auto input_shape = cb->GetInputShape(cnode, 0); - auto c_in = input_shape[kDepthWiseChannelAxis]; - int64_t c_out_o; - int64_t c_out_i; - int64_t c_in_o; - int64_t c_in_i; - int64_t dst_simd_size = 8LL; - int64_t src_simd_size = 8LL; - if (common::GetEnv("MS_CPU_FEATURE") == "avx512") { - dst_simd_size = 16LL; - src_simd_size = 16LL; - } - if (prim->HasAttr("tuned_dst_format")) { - dst_simd_size = GkUtils::GetChannelInConvFormat(GetValue(prim->GetAttr("tuned_dst_format"))); - } - std::tie(c_out_o, c_out_i) = TilingChannel(c_out, dst_simd_size); - if (prim->HasAttr("tuned_src_format")) { - src_simd_size = GkUtils::GetChannelInConvFormat(GetValue(prim->GetAttr("tuned_src_format"))); - } - std::tie(c_in_o, c_in_i) = TilingChannel(c_in, src_simd_size); - (void)prim->AddAttr("weight_coo", MakeValue(c_out_o)); - (void)prim->AddAttr("weight_coi", MakeValue(c_out_i)); - (void)prim->AddAttr("weight_cio", MakeValue(c_in_o)); - (void)prim->AddAttr("weight_cii", MakeValue(c_in_i)); - - if (prim->HasAttr("is_depth_wise")) { - c_in_o = 1; - c_in_i = 1; - } - auto weight_node = cnode->input(kConv2dWeightIndex)->cast(); - if (weight_node == nullptr) { - return nullptr; - } - auto tensor = weight_node->value()->cast(); - if (tensor == nullptr) { - return nullptr; - } - if (tensor->data().const_data() == nullptr) { - return nullptr; - } - if (tensor->data_type() != kNumberTypeFloat32) { - return nullptr; - } - auto h_len = weight_shape[kWeightHeightAxis]; - auto w_len = weight_shape[kWeightWidthAxis]; - - // step 1, reshape the weight, [O,H,W,I] -> [Oo,Oi,H,W,Io,Ii] - // step 2, transpose it to [Oo,Io,H,W,Ii,Oi] - IndexCalc old_shape_calc({c_out_o, c_out_i, h_len, w_len, c_in_o, c_in_i}); - ShapeVector new_shape = {c_out_o, c_in_o, h_len, w_len, c_in_i, c_out_i}; - IndexCalc new_shape_calc(new_shape); - auto new_tensor = std::make_shared(tensor->data_type(), new_shape); - auto new_data = new_tensor->data_c(); - auto old_data = tensor->data_c(); - for (int64_t coo = 0; coo < c_out_o; coo++) { - for (int64_t cio = 0; cio < c_in_o; cio++) { - for (int64_t h = 0; h < h_len; h++) { - for (int64_t w = 0; w < w_len; w++) { - for (int64_t cii = 0; cii < c_in_i; cii++) { - for (int64_t coi = 0; coi < c_out_i; coi++) { - auto old_val = static_cast(old_data)[old_shape_calc.GetFlatIndex({coo, coi, h, w, cio, cii})]; - static_cast(new_data)[new_shape_calc.GetFlatIndex({coo, cio, h, w, cii, coi})] = old_val; - } - } - } - } - } - } - - auto v = NewValueNode(new_tensor); - v->set_abstract(new_tensor->ToAbstract()); - v->set_kernel_info(weight_node->kernel_info_ptr()); - cnode->set_input(kConv2dWeightIndex, v); - return cnode; -} - -AnfNodePtr SubstituteConv2D::Run(const AnfNodePtr &node) { - auto new_node = InferWeightValue(node); - if (new_node == nullptr) { - return nullptr; - } - return ExpanderDecorator::Run(new_node); -} - -AnfNodePtr MatmulPackB::InferValue(const AnfNodePtr &node) { - auto cnode = QuickCloneCNode(node, true); - MS_EXCEPTION_IF_NULL(cnode); - const size_t kMatMulWeightIndex = 2; - const size_t kMatMulWeightRank = 2; - auto cb = Callback::Instance(); - auto type_id = cb->GetInputType(cnode, kMatMulWeightIndex - 1); - // only support float32 - if (type_id != kNumberTypeFloat32) { - MS_LOG(INFO) << "MatmulPackB only supports Float32 but got " << TypeIdToString(type_id); - return nullptr; - } - auto shape = cb->GetInputShape(cnode, kMatMulWeightIndex - 1); - if (shape.size() != kMatMulWeightRank) { - return node; - } - auto prim = GetCNodePrimitive(cnode); - auto weight_node = cnode->input(kMatMulWeightIndex)->cast(); - if (weight_node == nullptr) { - return node; - } - auto tensor = weight_node->value()->cast(); - if (tensor == nullptr) { - return node; - } - if (tensor->data().const_data() == nullptr) { - return node; - } - - // infer the transpose_b result - bool transpose_b = false; - if (prim->HasAttr("transpose_b")) { - transpose_b = GetValue(prim->GetAttr("transpose_b")); - } - auto new_tensor = PackB(tensor, shape, transpose_b); - prim->set_attr("pack_b", MakeValue(true)); - if (transpose_b) { - prim->set_attr("transpose_b", MakeValue(false)); - } - auto v = NewValueNode(new_tensor); - v->set_abstract(new_tensor->ToAbstract()); - v->set_kernel_info(weight_node->kernel_info_ptr()); - cnode->set_input(kMatMulWeightIndex, v); - return cnode; -} - -/* -Pack(B) example -tensor of shape (3, 7): -[ 1 2 3 4 5 6 7] -[ 8 9 10 11 12 13 14] -[15 16 17 18 19 20 21] ---- pack in size 4, 2, 1 ---> -[(1 2 3 4) (8 9 10 11) (15 16 17 18) (5 6) (12 13) (19 20) (7) (14) (21)] ---- reshape to (3, 7) ---> -[ 1 2 3 4 8 9 10] -[11 15 16 17 18 5 6] -[12 13 19 20 7 14 21] -*/ -tensor::TensorPtr MatmulPackB::PackB(const tensor::TensorPtr &tensor, const ShapeVector &shape, bool transpose) { - std::vector pack_size = {24, 16, 8, 4, 2, 1}; - IndexCalc index_calc(shape); - auto height = shape[0]; - auto width = shape[1]; - if (transpose) { - std::swap(height, width); - } - auto new_tensor = std::make_shared(tensor->data_type(), std::vector{height, width}); - auto *new_tensor_iter = static_cast(new_tensor->data_c()); - int64_t width_offset = 0; - for (auto pack : pack_size) { - while (width_offset + pack <= width) { - for (int64_t i = 0; i < height; ++i) { - for (int64_t j = 0; j < pack; ++j) { - if (transpose) { - *new_tensor_iter++ = static_cast(tensor->data_c())[index_calc.GetFlatIndex({j + width_offset, i})]; - } else { - *new_tensor_iter++ = static_cast(tensor->data_c())[index_calc.GetFlatIndex({i, j + width_offset})]; - } - } - } - width_offset += pack; - } - } - return new_tensor; -} - -AnfNodePtr MatmulPackB::Run(const AnfNodePtr &node) { - auto new_node = InferValue(node); - if (new_node == nullptr) { - return nullptr; - } - return ExpanderDecorator::Run(new_node); -} -} // namespace mindspore::graphkernel diff --git a/mindspore-lite/tools/graph_kernel/converter/preprocess_weight.h b/mindspore-lite/tools/graph_kernel/converter/preprocess_weight.h deleted file mode 100644 index ac981fe2d43e9f7a09740846b681152ec1c9be36..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/preprocess_weight.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_PREPROCESS_WEIGHT_H_ -#define MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_PREPROCESS_WEIGHT_H_ -#include - -#include "backend/common/graph_kernel/core/graph_kernel_expander.h" - -namespace mindspore::graphkernel { -class SubstituteConv2D : public ExpanderDecorator { - public: - using ExpanderDecorator::ExpanderDecorator; - static ExpanderPtr Creator(const ExpanderPtr &decorated) { - return std::static_pointer_cast(std::make_shared(decorated)); - } - AnfNodePtr Run(const AnfNodePtr &node) override; - - protected: - AnfNodePtr InferWeightValue(const AnfNodePtr &node); -}; - -class MatmulPackB : public ExpanderDecorator { - public: - using ExpanderDecorator::ExpanderDecorator; - static ExpanderPtr Creator(const ExpanderPtr &decorated) { - return std::static_pointer_cast(std::make_shared(decorated)); - } - AnfNodePtr Run(const AnfNodePtr &node) override; - - protected: - AnfNodePtr InferValue(const AnfNodePtr &node); - tensor::TensorPtr PackB(const tensor::TensorPtr &tensor, const ShapeVector &shape, bool transpose); -}; -} // namespace mindspore::graphkernel -#endif // MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_PREPROCESS_WEIGHT_H_ diff --git a/mindspore-lite/tools/graph_kernel/converter/rename_fullname_with_scope.h b/mindspore-lite/tools/graph_kernel/converter/rename_fullname_with_scope.h deleted file mode 100644 index ce1e3f7f3d84787f4547fb2fc95eebb4b0ef3fc4..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/rename_fullname_with_scope.h +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_RENAME_FULLNAME_WITH_SCOPE_H_ -#define MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_RENAME_FULLNAME_WITH_SCOPE_H_ -#include "ir/func_graph.h" -#include "include/backend/optimizer/pass.h" - -namespace mindspore::graphkernel { -class RenameFullnameWithScope : public opt::Pass { - public: - RenameFullnameWithScope() : Pass("rename_fullname_with_scope") {} - ~RenameFullnameWithScope() override = default; - bool Run(const FuncGraphPtr &func_graph) override; -}; -} // namespace mindspore::graphkernel - -#endif // MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_RENAME_FULLNAME_WITH_SCOPE_H_ diff --git a/mindspore-lite/tools/graph_kernel/converter/split_model_ascend.cc b/mindspore-lite/tools/graph_kernel/converter/split_model_ascend.cc deleted file mode 100644 index 9f7a047f8c65f829dfd5d4d169991475d9fbb2a8..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/split_model_ascend.cc +++ /dev/null @@ -1,77 +0,0 @@ -/** - * Copyright 2023-2025 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "tools/graph_kernel/converter/split_model_ascend.h" -#include -#include "utils/ms_context.h" - -namespace mindspore::graphkernel::inner { -constexpr size_t kReduceFusionDepth = 10; -constexpr size_t kBroadcastFusionDepth = 6; - -class FuseLayerNorm : public FusePattern { - public: - FuseLayerNorm() : FusePattern("layer_norm") { direction_ = FuseDirection::BACKWARD; } - ~FuseLayerNorm() = default; - - protected: - bool Check(const AreaPtr &dom) override { - MS_EXCEPTION_IF_NULL(dom); - return (dom->dom() != nullptr && dom->dom()->op() == "ReduceSum"); - } - bool Match(const AreaPtr &dom) override { - constexpr size_t c1 = 1; - constexpr size_t c2 = 2; - auto users = dom->users(); - if (users.size() != c1 || users[0]->pattern() != NodePattern::BROADCAST) { - return false; - } - auto user_users = users[0]->users(); - if (user_users.size() != c2) { - return false; - } - if ((user_users[0]->pattern() == NodePattern::REDUCE && user_users[1]->pattern() == NodePattern::BROADCAST) || - (user_users[0]->pattern() == NodePattern::BROADCAST && user_users[1]->pattern() == NodePattern::REDUCE)) { - (void)fused_areas_.emplace_back(users[0]); - (void)fused_areas_.emplace_back(user_users[0]); - (void)fused_areas_.emplace_back(user_users[1]); - } - return !fused_areas_.empty(); - } -}; - -void SplitModelAscend::InitFusePatterns() { - AddPattern(std::make_shared(), true); - if ((soc_version_.find("910B") == string::npos) && (soc_version_.find("910_93") == string::npos)) { - // Ascend 910B do not fuse Matmul - AddPattern(std::make_shared(), true); - } - AddPattern(FuseElemwiseBroadcastFwd::CreateDepthMatcher(), true); - AddPattern(FuseElemwiseBroadcastFwd::CreateWidthMatcher(), true); - AddPattern(FuseReduceFwd::CreateDepthMatcher(kReduceFusionDepth), true); - AddPattern(FuseReduceFwd::CreateWidthMatcher(kReduceFusionDepth), true); - AddPattern(FuseElemwiseBroadcastBwd::CreateDepthMatcher(kBroadcastFusionDepth), true); - AddPattern(FuseElemwiseBroadcastBwd::CreateWidthMatcher(kBroadcastFusionDepth), true); - AddPattern(std::make_shared(), true); -} - -AreaMode SplitModelAscend::GetDefaultAreaMode(const PrimOpPtr &node) const { - if (node != nullptr && (node->op() == "MatMul" || node->op() == "PagedAttention" || - node->op() == "PagedAttentionMask" || node->op() == "ReshapeAndCache")) { - return AreaMode::COMPOSITE; - } - return AreaMode::BASIC; -} -} // namespace mindspore::graphkernel::inner diff --git a/mindspore-lite/tools/graph_kernel/converter/split_model_ascend.h b/mindspore-lite/tools/graph_kernel/converter/split_model_ascend.h deleted file mode 100644 index ca1fb1f76ac7f2ec65d211e9c2c2b49156d230fb..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/split_model_ascend.h +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_SPLIT_MODEL_ASCEND_H_ -#define MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_SPLIT_MODEL_ASCEND_H_ - -#include -#include "backend/common/graph_kernel/split_model/split_model_factory.h" -namespace mindspore::graphkernel::inner { -class SplitModelAscend : public SplitModel { - public: - SplitModelAscend() = default; - explicit SplitModelAscend(std::string soc_version) : soc_version_(soc_version) {} - virtual ~SplitModelAscend() = default; - - protected: - AreaMode GetDefaultAreaMode(const PrimOpPtr &) const override; - void InitFusePatterns() override; - std::string soc_version_ = ""; -}; -} // namespace mindspore::graphkernel::inner -#endif // MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_SPLIT_MODEL_ASCEND_H_ diff --git a/mindspore-lite/tools/graph_kernel/converter/split_model_cpu.cc b/mindspore-lite/tools/graph_kernel/converter/split_model_cpu.cc deleted file mode 100644 index 85b1bd6078f8164a3cf8996f0b05d8722d5495f4..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/split_model_cpu.cc +++ /dev/null @@ -1,67 +0,0 @@ -/** - * Copyright 2022-2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "tools/graph_kernel/converter/split_model_cpu.h" -#include -#include "utils/ms_context.h" - -namespace mindspore::graphkernel::inner { -constexpr size_t kReduceFusionDepth = 20; -constexpr size_t kBroadcastFusionDepth = 20; - -class FuseConv : public FusePattern { - public: - FuseConv() : FusePattern("conv") { direction_ = FuseDirection::BACKWARD; } - ~FuseConv() = default; - - protected: - bool Check(const AreaPtr &dom) override { - if (dom->dom()->op() != "Conv2D") { - return false; - } - return true; - } - bool Match(const AreaPtr &dom) override { - for (auto d : dom->users_with_relation()) { - auto a = d.first; - if (HasCircle(dom, a)) { - continue; - } - if (a->pattern() < NodePattern::BROADCAST || - (a->pattern() == NodePattern::BROADCAST && a->dom()->shape == dom->dom()->shape)) { - (void)fused_areas_.emplace_back(a); - } - } - return !fused_areas_.empty(); - } -}; - -void SplitModelCpu::InitFusePatterns() { - AddPattern(std::make_shared(), true); - AddPattern(std::make_shared(), true); - AddPattern(FuseElemwiseFwd::CreateDepthMatcher(), true); - AddPattern(FuseElemwiseFwd::CreateWidthMatcher(), true); - AddPattern(std::make_shared(), true); - AddPattern(FuseElemwiseBroadcastFwd::CreateDepthMatcher(), true); - AddPattern(FuseElemwiseBroadcastFwd::CreateWidthMatcher(), true); - AddPattern(FuseReduceFwd::CreateDepthMatcher(kReduceFusionDepth), true); - AddPattern(FuseReduceFwd::CreateWidthMatcher(kReduceFusionDepth), true); - AddPattern(FuseElemwiseBroadcastBwd::CreateDepthMatcher(kBroadcastFusionDepth), true); - AddPattern(FuseElemwiseBroadcastBwd::CreateWidthMatcher(kBroadcastFusionDepth), true); - AddPattern(std::make_shared(), true); -} - -AreaMode SplitModelCpu::GetDefaultAreaMode(const PrimOpPtr &) const { return AreaMode::COMPOSITE; } -} // namespace mindspore::graphkernel::inner diff --git a/mindspore-lite/tools/graph_kernel/converter/split_model_gpu.cc b/mindspore-lite/tools/graph_kernel/converter/split_model_gpu.cc deleted file mode 100644 index 7eb40baa40169d823c7c1318f9b5af4b51b60c69..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/split_model_gpu.cc +++ /dev/null @@ -1,39 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "tools/graph_kernel/converter/split_model_gpu.h" -#include -#include "utils/ms_context.h" - -namespace mindspore::graphkernel::inner { -constexpr size_t kReduceFusionDepth = 20; -constexpr size_t kBroadcastFusionDepth = 20; - -void SplitModelGpu::InitFusePatterns() { - AddPattern(std::make_shared(), true); - AddPattern(std::make_shared(), true); - AddPattern(FuseElemwiseFwd::CreateDepthMatcher(), true); - AddPattern(FuseElemwiseFwd::CreateWidthMatcher(), true); - AddPattern(FuseElemwiseBroadcastFwd::CreateDepthMatcher(), true); - AddPattern(FuseElemwiseBroadcastFwd::CreateWidthMatcher(), true); - AddPattern(FuseReduceFwd::CreateDepthMatcher(kReduceFusionDepth), true); - AddPattern(FuseReduceFwd::CreateWidthMatcher(kReduceFusionDepth), true); - AddPattern(FuseElemwiseBroadcastBwd::CreateDepthMatcher(kBroadcastFusionDepth), true); - AddPattern(FuseElemwiseBroadcastBwd::CreateWidthMatcher(kBroadcastFusionDepth), true); - AddPattern(std::make_shared(), true); -} - -AreaMode SplitModelGpu::GetDefaultAreaMode(const PrimOpPtr &) const { return AreaMode::COMPOSITE; } -} // namespace mindspore::graphkernel::inner diff --git a/mindspore-lite/tools/graph_kernel/converter/split_model_gpu.h b/mindspore-lite/tools/graph_kernel/converter/split_model_gpu.h deleted file mode 100644 index a6f99f1672c33512f0f80d192ab361d68ccfa200..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/split_model_gpu.h +++ /dev/null @@ -1,31 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_SPLIT_MODEL_GPU_H_ -#define MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_SPLIT_MODEL_GPU_H_ - -#include "backend/common/graph_kernel/split_model/split_model_factory.h" -namespace mindspore::graphkernel::inner { -class SplitModelGpu : public SplitModel { - public: - SplitModelGpu() = default; - virtual ~SplitModelGpu() = default; - - protected: - AreaMode GetDefaultAreaMode(const PrimOpPtr &) const override; - void InitFusePatterns() override; -}; -} // namespace mindspore::graphkernel::inner -#endif // MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_SPLIT_MODEL_GPU_H_ diff --git a/mindspore-lite/tools/graph_kernel/converter/update_kernel_info.h b/mindspore-lite/tools/graph_kernel/converter/update_kernel_info.h deleted file mode 100644 index 1e8b35f03a8b2403f31458b44283f7b713a8e8aa..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/converter/update_kernel_info.h +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2023 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_UPDATE_KERNEL_INFO_H_ -#define MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_UPDATE_KERNEL_INFO_H_ -#include "ir/func_graph.h" -#include "include/backend/optimizer/pass.h" - -namespace mindspore::graphkernel { -class UpdateKernelInfo : public opt::Pass { - public: - UpdateKernelInfo() : Pass("update_kernel_info") {} - ~UpdateKernelInfo() override = default; - bool Run(const FuncGraphPtr &func_graph) override; -}; -} // namespace mindspore::graphkernel - -#endif // MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_UPDATE_KERNEL_INFO_H_ diff --git a/mindspore-lite/tools/graph_kernel/runtime/akg_kernel.cc b/mindspore-lite/tools/graph_kernel/runtime/akg_kernel.cc deleted file mode 100644 index 6d01be231585ed9504793716790bb44a95e58f2d..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/runtime/akg_kernel.cc +++ /dev/null @@ -1,286 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tools/graph_kernel/runtime/akg_kernel.h" -#include -#include -#include -#include -#include -#include "kernel/graph_kernel/graph_kernel_json_flags.h" -#include "tools/graph_kernel/common/utils.h" -#include "src/tensor.h" -#include "src/common/utils.h" -#include "src/common/tensor_util.h" -#include "src/litert/kernel_registry.h" -#include "schema/model_generated.h" -#include "src/common/dynamic_library_loader.h" -#include "src/common/file_utils.h" - -namespace mindspore::kernel { -using mindspore::lite::KernelRegistrar; -using mindspore::lite::RET_ERROR; -using mindspore::lite::RET_OK; -constexpr auto kNumberTwo = 2; -namespace { -int TmpAkgParallelLaunchFunc(AkgParallelLambda flambda, void *cdata, int num_task) { - /* - The `cdata` is a second-level pointer, which first element is a pointer to a structure object. - The structure contains original AkgCallBack's elements, but except the first `parallel_launch_func`. - It seems like `{malloc_func, free_func, extend_data}`, all elements are also pointers. - So, to get the `extend_data`, we can treat the `cdata` as a third-level pointer, - and then offset TWO elements for the first structure object. - The `extend_data` was set as the `this` pointer of `AkgKernel` object. - */ - const auto kExtendDataOffset = 2; - void *extend_data = static_cast(cdata)[0][kExtendDataOffset]; - static_cast(extend_data)->AkgParallelLaunchFunc(flambda, cdata, num_task); - return 0; -} - -class AkgCallBack { - public: - void *parallel_launch_func = nullptr; - void *(*malloc_func)(size_t) = nullptr; - void (*free_func)(void *) = nullptr; - void *extend_data = nullptr; - - AkgCallBack() { - parallel_launch_func = reinterpret_cast(TmpAkgParallelLaunchFunc); - malloc_func = &malloc; - free_func = &free; - } - ~AkgCallBack() = default; -}; -} // namespace - -int AkgKernel::CheckAkgKernelInfo() { - std::string current_arch; -#if defined(ENABLE_ARM64) - current_arch = "aarch64"; -#elif defined(ENABLE_ARM) - current_arch = "arm"; -#else - current_arch = "x86_64"; -#endif - if (current_arch != arch) { - MS_LOG(ERROR) << "Current cpu arch is " << current_arch << ", but got a " << arch - << " AKGKernel. AkgKernel info ckeck failed."; - return RET_ERROR; - } -#if defined(ENABLE_AVX512) - return RET_OK; -#elif defined(ENABLE_AVX) - if (cpu_feature == "avx512") { - MS_LOG(ERROR) - << "Current Runtime not support avx512, but AkgKernel got an avx512 kernel. AkgKernel info ckeck failed."; - return RET_ERROR; - } -#elif defined(ENABLE_AVX) - if (cpu_feature == "avx512" || cpu_feature == "avx") { - MS_LOG(ERROR) << "Current Runtime not support avx512 and avx, but AkgKernel got an " << cpu_feature - << " kernel. AkgKernel info ckeck failed."; - return RET_ERROR; - } -#endif - return RET_OK; -} - -void AkgKernel::ExtractKernelAttr() { - auto prim = static_cast(params_)->value_as_Custom(); - for (size_t i = 0; i < prim->attr()->size(); i++) { - auto attr = prim->attr()->Get(i); - if (attr->name()->str() == "kernel_name") { - kernel_name_ = std::string(reinterpret_cast(attr->data()->Data()), attr->data()->size()); - } else if (attr->name()->str() == "inputs_shape") { - std::string inputs_shape_str(reinterpret_cast(attr->data()->Data()), attr->data()->size()); - (void)graphkernel::GetCustomShape(inputs_shape_str, &origin_inputs_shape_); - } else if (attr->name()->str() == "dynamic_input_index") { - dynamic_batch_size_ = 1; - std::string dynamic_input_index_str(reinterpret_cast(attr->data()->Data()), attr->data()->size()); - graphkernel::GetCustomIndex(dynamic_input_index_str, &dynamic_input_index_); - } else if (attr->name()->str() == mindspore::graphkernel::kJsonKeyProcess) { - process = std::string(reinterpret_cast(attr->data()->Data()), attr->data()->size()); - } else if (attr->name()->str() == mindspore::graphkernel::kJsonKeyArch) { - arch = std::string(reinterpret_cast(attr->data()->Data()), attr->data()->size()); - } else if (attr->name()->str() == mindspore::graphkernel::kJsonKeySystem) { - system = std::string(reinterpret_cast(attr->data()->Data()), attr->data()->size()); - } else if (attr->name()->str() == mindspore::graphkernel::kJsonKeyCpuFeature) { - cpu_feature = std::string(reinterpret_cast(attr->data()->Data()), attr->data()->size()); - } else { - continue; - } - } -} - -int TmpDoTask(void *obj, int task_id, float lhs_scale, float rhs_scale) { - return static_cast(obj)->DoTask(task_id, lhs_scale, rhs_scale); -} - -int AkgKernel::DoTask(int task_id, float, float) { - (void)cached_akg_lambda_(task_id, nthread_, cached_runtimeargs_); - return RET_OK; -} - -void AkgKernel::AkgParallelLaunchFunc(AkgParallelLambda flambda, void *cdata, int) { - cached_akg_lambda_ = flambda; - cached_runtimeargs_ = cdata; - (void)ParallelLaunch(this->ms_context_, TmpDoTask, this, this->nthread_); - cached_akg_lambda_ = nullptr; - cached_runtimeargs_ = nullptr; -} - -int AkgKernel::Prepare() { - if (kernel_func_ != nullptr) { - return RET_OK; - } - if (CheckAkgKernelInfo() != RET_OK) { - return RET_ERROR; - } - if (in_tensors_.size() < kNumberTwo) { - MS_LOG(ERROR) << "The number of input tensor in AkgKernel must greater than 2, but now got " << in_tensors_.size(); - return lite::RET_INPUT_TENSOR_ERROR; - } - auto akg_lib_tensor = in_tensors_.at(in_tensors_.size() - 1); - auto akg_lib_ptr = akg_lib_tensor->data(); - auto akg_kernel_so = kernel_name_ + ".so"; - std::string kernle_meta = "akg_kernel_meta_runtime"; - if (lite::CreateDir(kernle_meta) != RET_OK) { - MS_LOG(ERROR) << "cannot create dir " << kernle_meta; - return lite::RET_ERROR; - } - auto akg_kernel_path = kernle_meta + "/" + akg_kernel_so; - if (lite::WriteToBin(akg_kernel_path, akg_lib_ptr, akg_lib_tensor->Size())) { - MS_LOG(ERROR) << "write data to " << akg_kernel_so << " failed."; - return lite::RET_ERROR; - } - auto real_path = lite::RealPath(akg_kernel_path.c_str()); - if (real_path.empty()) { - MS_LOG(ERROR) << "cannot access file:" << real_path << ".please check file if exists and file mod"; - return lite::RET_ERROR; - } - lib_handle_ = dlopen(real_path.c_str(), RTLD_LAZY | RTLD_LOCAL); - if (lib_handle_ == nullptr) { - MS_LOG(ERROR) << "Load library from tensor failed. Kernel name is [" << akg_kernel_so << "]"; - return RET_ERROR; - } - kernel_func_ = dlsym(lib_handle_, kernel_name_.c_str()); - if (kernel_func_ == nullptr) { - MS_LOG(ERROR) << "Undefined symbol [" << kernel_name_ << "] in [" << akg_kernel_so << "]"; - return RET_ERROR; - } - // the last input tensor is akgkernels.so, so we need to remove it. - in_tensors_.pop_back(); - const size_t kAddrAlign = 32; - const size_t kAddrAlignMask = 0x1f; - const_inputs_.reserve(in_tensors_.size()); - for (auto &input : in_tensors_) { - // the data address should align in 32 bytes. - if (input->IsConst() && (reinterpret_cast(input->data()) & kAddrAlignMask) != 0) { - auto buffer = static_cast(input->data()); - auto tensor_size = input->Size(); - if (tensor_size == 0) { - MS_LOG(ERROR) << "The tensor \'" << input->tensor_name() << "\' size is 0. kernel: " << kernel_name_; - return RET_ERROR; - } - std::vector input_align(tensor_size + kAddrAlign); - auto p = input_align.data(); - while ((reinterpret_cast(p) & kAddrAlignMask) != 0) { - p++; - } - (void)std::copy(buffer, buffer + tensor_size, p); - (void)const_inputs_.emplace_back(static_cast(p)); - (void)const_data_align_cache_.emplace_back(std::move(input_align)); - } else { - (void)const_inputs_.emplace_back(nullptr); - } - } - return RET_OK; -} - -int AkgKernel::Run() { - if (kernel_func_ == nullptr) { - MS_LOG(ERROR) << "Kernel function [" << kernel_name_ << "] is nullptr."; - return RET_ERROR; - } - nthread_ = op_parameter_->thread_num_; - std::vector runtimeargs; - // callbackfunc and dynamic batch size - const size_t extra_arg_num_with_batch = 2; - runtimeargs.reserve(in_tensors_.size() + out_tensors_.size() + extra_arg_num_with_batch); - - static AkgCallBack akg_callback; - akg_callback.extend_data = static_cast(this); - (void)runtimeargs.emplace_back(static_cast(&akg_callback)); - for (size_t i = 0; i < in_tensors_.size(); i++) { - if (const_inputs_[i] != nullptr) { - (void)runtimeargs.emplace_back(const_inputs_[i]); - } else { - (void)runtimeargs.emplace_back(in_tensors_[i]->data()); - } - } - (void)std::transform(std::begin(out_tensors_), std::end(out_tensors_), std::back_inserter(runtimeargs), - [](lite::Tensor *output) { return output->MutableData(); }); - if (dynamic_batch_size_ != 0) { - (void)runtimeargs.emplace_back(static_cast(&dynamic_batch_size_)); - } - using AkgCpuKernelFunction = void (*)(void *); - reinterpret_cast(kernel_func_)(static_cast(runtimeargs.data())); - return RET_OK; -} - -int AkgKernel::ReSize() { - if (in_tensors_.empty() || dynamic_batch_size_ == 0) { - return mindspore::lite::RET_OK; - } - std::vector input_tensorc(in_tensors_.size()); - for (size_t i = 0; i < in_tensors_.size(); i++) { - int ret = lite::Tensor2TensorC(in_tensors_[i], &input_tensorc[i]); - if (ret != mindspore::lite::RET_OK) { - MS_LOG(ERROR) << "Convert Tensor to TensorC failed."; - return mindspore::lite::RET_ERROR; - } - } - std::vector input_tensorc_pointer; - (void)std::transform(input_tensorc.begin(), input_tensorc.end(), std::back_inserter(input_tensorc_pointer), - [](const TensorC &t) { return &t; }); - if (graphkernel::CalculateDynamicBatchSize(&input_tensorc_pointer[0], in_tensors_.size(), origin_inputs_shape_, - dynamic_input_index_, &dynamic_batch_size_) != RET_OK) { - return mindspore::lite::RET_ERROR; - } - return mindspore::lite::RET_OK; -} - -AkgKernel::~AkgKernel() { - if (lib_handle_ != nullptr) { - (void)dlclose(lib_handle_); - lib_handle_ = nullptr; - } -} - -REG_KERNEL(kCPU, kNumberTypeBool, PrimType_Inner_GraphKernel, LiteKernelCreator) -REG_KERNEL(kCPU, kNumberTypeInt8, PrimType_Inner_GraphKernel, LiteKernelCreator) -REG_KERNEL(kCPU, kNumberTypeInt16, PrimType_Inner_GraphKernel, LiteKernelCreator) -REG_KERNEL(kCPU, kNumberTypeInt32, PrimType_Inner_GraphKernel, LiteKernelCreator) -REG_KERNEL(kCPU, kNumberTypeInt64, PrimType_Inner_GraphKernel, LiteKernelCreator) -REG_KERNEL(kCPU, kNumberTypeUInt8, PrimType_Inner_GraphKernel, LiteKernelCreator) -REG_KERNEL(kCPU, kNumberTypeUInt16, PrimType_Inner_GraphKernel, LiteKernelCreator) -REG_KERNEL(kCPU, kNumberTypeUInt32, PrimType_Inner_GraphKernel, LiteKernelCreator) -REG_KERNEL(kCPU, kNumberTypeUInt64, PrimType_Inner_GraphKernel, LiteKernelCreator) -REG_KERNEL(kCPU, kNumberTypeFloat16, PrimType_Inner_GraphKernel, LiteKernelCreator) -REG_KERNEL(kCPU, kNumberTypeFloat32, PrimType_Inner_GraphKernel, LiteKernelCreator) -REG_KERNEL(kCPU, kNumberTypeFloat64, PrimType_Inner_GraphKernel, LiteKernelCreator) -} // namespace mindspore::kernel diff --git a/mindspore-lite/tools/graph_kernel/runtime/akg_kernel.h b/mindspore-lite/tools/graph_kernel/runtime/akg_kernel.h deleted file mode 100644 index f1afd16e3faf0cd14576e33ad1ee9e1b26835465..0000000000000000000000000000000000000000 --- a/mindspore-lite/tools/graph_kernel/runtime/akg_kernel.h +++ /dev/null @@ -1,73 +0,0 @@ -/** - * Copyright 2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_RUNTIME_AKG_KERNEL_H_ -#define MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_RUNTIME_AKG_KERNEL_H_ -#include -#include -#include -#include "src/litert/lite_kernel.h" -#include "nnacl/custom_parameter.h" - -namespace mindspore::kernel { -using AkgParallelLambda = int (*)(int task_id, int num_task, void *cdata); - -class AkgKernel : public LiteKernel { - public: - AkgKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs, const lite::InnerContext *ctx) - : LiteKernel(parameter, inputs, outputs, ctx) { - // in PopulateCustomParameter, the primitive is store in attr_data[0] - params_ = static_cast(reinterpret_cast(op_parameter_)->attr_data[0]); - ExtractKernelAttr(); - } - virtual ~AkgKernel(); - int Prepare() override; - int Run() override; - int ReSize() override; - - // the real callback function that send to akg - void AkgParallelLaunchFunc(AkgParallelLambda flambda, void *cdata, int); - // the callback function that send to thread pool - int DoTask(int task_id, float, float); - - protected: - void ExtractKernelAttr(); - - void *params_{nullptr}; - void *kernel_func_{nullptr}; - std::string kernel_name_; - int nthread_{0}; - int dynamic_batch_size_{0}; - std::vector> const_data_align_cache_; - std::vector const_inputs_; - AkgParallelLambda cached_akg_lambda_ = nullptr; - void *cached_runtimeargs_ = nullptr; - std::vector dynamic_input_index_; - std::vector> origin_inputs_shape_; - void *lib_handle_ = nullptr; - std::string process; - std::string arch; - std::string system; - std::string cpu_feature = ""; - - private: - int LoadAkgLib(void *data, size_t file_size); - int CheckAkgKernelInfo(); - void CloseAkgLib(); -}; -} // namespace mindspore::kernel -#endif // MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_RUNTIME_AKG_KERNEL_H_ diff --git a/mindspore-lite/tools/kernel_builder/ascend/ascendc/cmake/config.cmake b/mindspore-lite/tools/kernel_builder/ascend/ascendc/cmake/config.cmake index 24a0944b4570b862504a4c63d30a975eee0f078d..c951f55e672955f5bcce3c806a079c035c0114d8 100644 --- a/mindspore-lite/tools/kernel_builder/ascend/ascendc/cmake/config.cmake +++ b/mindspore-lite/tools/kernel_builder/ascend/ascendc/cmake/config.cmake @@ -5,7 +5,7 @@ set(ENABLE_BINARY_PACKAGE False) set(ASCEND_COMPUTE_UNIT "ascend910;ascend910b;ascend310p") set(vendor_name mslite_ascendc) set(ASCEND_PYTHON_EXECUTABLE python3) -set(PKG_PATH ${TOP_DIR}/build/tools/kernel_builder/ascend/ascendc/makepkg) +set(PKG_PATH ${TOP_DIR}/mindspore-lite/build/tools/kernel_builder/ascend/ascendc/makepkg) if(DEFINED ENV{ASCEND_CUSTOM_PATH}) set(ASCEND_CANN_PACKAGE_PATH $ENV{ASCEND_CUSTOM_PATH}/latest) diff --git a/mindspore-lite/tools/lite_exporter/fetch_content.cc b/mindspore-lite/tools/lite_exporter/fetch_content.cc index 97d6a97390bda75ff7d10d3a31eb16b9744f0bd0..4073d0677672a324d5e3556ea9715175ad9a0bb3 100644 --- a/mindspore-lite/tools/lite_exporter/fetch_content.cc +++ b/mindspore-lite/tools/lite_exporter/fetch_content.cc @@ -325,8 +325,8 @@ int FetchFromDefaultParam(const ParameterPtr ¶m_node, const converter::FmkTy // the const tensor format from onnx/caffe should be nchw in general auto const_format = (fmk_type == converter::kFmkTypeMsLite || fmk_type == converter::kFmkTypeTf || fmk_type == converter::kFmkTypeTflite) - ? NHWC - : NCHW; + ? NHWC + : NCHW; data_info->format_ = param_node->has_default() ? const_format : NHWC; return RET_OK; } diff --git a/mindspore-lite/tools/mindir_exporter/CMakeLists.txt b/mindspore-lite/tools/mindir_exporter/CMakeLists.txt index 446c96ab9187988d63d9041fe01ae27d9bcac44c..376ce4a17b5e1d98c03f7741a2b8afc1f7e8d858 100644 --- a/mindspore-lite/tools/mindir_exporter/CMakeLists.txt +++ b/mindspore-lite/tools/mindir_exporter/CMakeLists.txt @@ -23,13 +23,19 @@ add_library(mindir_serializer_mid OBJECT ${MINDIR_EXPORTER_SRC_LIST} ) add_dependencies(mindir_serializer_mid fbs_src fbs_inner_src) -file(STRINGS "${TOP_DIR}/version.txt" VERSION) -add_definitions(-DVERSION=\"${VERSION}\") +# file(STRINGS "${TOP_DIR}/version.txt" VERSION) +# add_definitions(-DVERSION=\"${VERSION}\") +# set(NEW_CCSRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../ccsrc) add_library(ccsrc_debug_common_mid_ OBJECT - ${CCSRC_DIR}/common/debug/common.cc - ${CCSRC_DIR}/utils/compile_cache_context.cc - ${CCSRC_DIR}/common/debug/mindir_exporter.cc + # ${CCSRC_DIR}/common/debug/common.cc + # ${CCSRC_DIR}/utils/compile_cache_context.cc + # ${CCSRC_DIR}/common/debug/mindir_exporter.cc +# ${NEW_CCSRC_DIR}/common/debug/common.cc +# ${NEW_CCSRC_DIR}/utils/compile_cache_context.cc +# ${NEW_CCSRC_DIR}/common/debug/mindir_exporter.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../src/extendrt/delegate/comm_group_info.cc ) +add_dependencies(ccsrc_debug_common_mid_ mindir_proto_mid) target_link_libraries(mindir_serializer_mid _mindspore_transform_express_ir_obj diff --git a/mindspore-lite/tools/mindir_exporter/mindir_serializer.cc b/mindspore-lite/tools/mindir_exporter/mindir_serializer.cc index 068618f8badad02e96e550cff63549dd4f184e61..931f78f88f8bc9172eea415e0c469c5b80798164 100644 --- a/mindspore-lite/tools/mindir_exporter/mindir_serializer.cc +++ b/mindspore-lite/tools/mindir_exporter/mindir_serializer.cc @@ -36,6 +36,7 @@ #include "tools/converter/quantizer/quantize_util.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_o.h" #include "src/common/decrypt.h" +#include "src/extendrt/delegate/comm_group_info.h" namespace mindspore::lite { namespace { @@ -663,7 +664,7 @@ int MindIRSerializer::SaveProtoToFile(mind_ir::ModelProto *model_proto, const st MS_LOG(INFO) << "No need to save proto to file"; return RET_OK; } - auto realpath = Common::CreatePrefixPath(output_file, true); + auto realpath = CommGroupInfo::CreatePrefixPath(output_file, true); if (!realpath.has_value()) { MS_LOG(ERROR) << "Get real path of file " << output_file << " failed."; return RET_ERROR; diff --git a/mindspore-lite/tools/optimizer/common/gllo_utils.h b/mindspore-lite/tools/optimizer/common/gllo_utils.h index c75381cd378b443b62f6433a09541a8f9cb86925..65d038257034428988465d27bc4c1636a1ddec5d 100644 --- a/mindspore-lite/tools/optimizer/common/gllo_utils.h +++ b/mindspore-lite/tools/optimizer/common/gllo_utils.h @@ -151,7 +151,7 @@ ParameterPtr BuildFloatVec2DParameterNode(const FuncGraphPtr &func_graph, const ParameterPtr BuildFloatVec3DParameterNode(const FuncGraphPtr &func_graph, const std::vector>> &data, - const std::string &node_name); +const std::string &node_name); CNodePtr GenTransposeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, const std::vector &perm, const std::string &cnode_name); diff --git a/mindspore-lite/tools/optimizer/common/helper.cc b/mindspore-lite/tools/optimizer/common/helper.cc index b519cc213a8f0537659e8644db43c5dc550e48be..6c504766a9c6c64ad74d572addf8e32437d54b92 100644 --- a/mindspore-lite/tools/optimizer/common/helper.cc +++ b/mindspore-lite/tools/optimizer/common/helper.cc @@ -150,10 +150,10 @@ CNodePtr NewCNode(const std::vector &inputs, const FuncGraphPtr &fg, } // not implement for lite, just for api compatible -CNodePtr NewCNode(const CNodePtr &cnode, const KernelGraphPtr &fg, const std::vector &orig_nodes) { - MS_LOG(DEBUG) << "Not implement for lite, just for api compatible."; - return nullptr; -} +// CNodePtr NewCNode(const CNodePtr &cnode, const KernelGraphPtr &fg, const std::vector &orig_nodes) { +// MS_LOG(DEBUG) << "Not implement for lite, just for api compatible."; +// return nullptr; +// } // not implement for lite, just for api compatible AbstractBasePtr CppInferShapeAndType(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list) { diff --git a/mindspore-lite/tools/optimizer/fusion/adjust_matmul_pass.cc b/mindspore-lite/tools/optimizer/fusion/adjust_matmul_pass.cc index e0ebd0e5b3e3f811673302bc58105a2ba16dc3ae..7605d72720a675cdea15c60813b9a0efe7c82c4b 100644 --- a/mindspore-lite/tools/optimizer/fusion/adjust_matmul_pass.cc +++ b/mindspore-lite/tools/optimizer/fusion/adjust_matmul_pass.cc @@ -57,6 +57,35 @@ void SetMatMulTransposeAttr(const PrimitivePtr &src_prim, const PrimitivePtr &ds } } +CNodePtr CreateSqueezeCnode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { + MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr); + MS_CHECK_TRUE_RET(cnode != nullptr, nullptr); + auto squeeze_op = std::make_unique(); + if (squeeze_op == nullptr) { + MS_LOG(ERROR) << "New Squeeze op failed, squeeze_op is nullptr!"; + return nullptr; + } + squeeze_op->set_axis({0}); + + auto squeeze_prim_c = squeeze_op->GetPrim(); + if (squeeze_prim_c == nullptr) { + MS_LOG(ERROR) << "squeeze_prim_c is nullptr!"; + return nullptr; + } + std::vector inputs = {cnode}; + auto squeeze_node = func_graph->NewCNode(squeeze_prim_c, inputs); + if (squeeze_node == nullptr) { + MS_LOG(ERROR) << "new squeeze cnode failed, squeeze_node is nullptr!"; + return nullptr; + } + squeeze_node->set_fullname_with_scope(cnode->fullname_with_scope() + "_data_squeeze"); + if (cnode->abstract() != nullptr) { + squeeze_node->set_abstract(cnode->abstract()->Clone()); + } + MS_LOG(INFO) << "Create squeeze node end."; + return squeeze_node; +} + CNodePtr CreateShapeCNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr); MS_CHECK_TRUE_RET(cnode != nullptr, nullptr); @@ -69,16 +98,105 @@ CNodePtr CreateShapeCNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) return nullptr; } shape_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_shape"); - auto abstract = lite::CreateTensorAbstract({kShapeMinus_1}, kNumberTypeInt32); - if (abstract == nullptr) { - MS_LOG(ERROR) << "Create tensor abstract failed!"; - return nullptr; + if (cnode->abstract() != nullptr) { + shape_cnode->set_abstract(cnode->abstract()->Clone()); } - shape_cnode->set_abstract(abstract); MS_LOG(INFO) << "Create shape node end."; return shape_cnode; } +CNodePtr CreateRangeV2Cnode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { + MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr); + MS_CHECK_TRUE_RET(cnode != nullptr, nullptr); + auto range_op = std::make_unique(); + if (range_op == nullptr) { + MS_LOG(ERROR) << "New RangeV2 op failed, range_op is nullptr!"; + return nullptr; + } + + auto range_prim_c = range_op->GetPrim(); + if (range_prim_c == nullptr) { + MS_LOG(ERROR) << "range_prim_c is nullptr!"; + return nullptr; + } + + auto start_num = opt::BuildIntValueParameterNode(func_graph, 0, cnode->fullname_with_scope() + "_start", false); + MS_CHECK_TRUE_RET(start_num != nullptr, nullptr); + auto delta_num = opt::BuildIntValueParameterNode(func_graph, 1, cnode->fullname_with_scope() + "_delta", false); + MS_CHECK_TRUE_RET(delta_num != nullptr, nullptr); + std::vector inputs = {start_num, cnode, delta_num}; + auto range_node = func_graph->NewCNode(range_prim_c, inputs); + if (range_node == nullptr) { + MS_LOG(ERROR) << "New range cnode failed, range_node is nullptr!"; + return nullptr; + } + range_node->set_fullname_with_scope(cnode->fullname_with_scope() + "_range"); + if (cnode->abstract() != nullptr) { + range_node->set_abstract(cnode->abstract()->Clone()); + } + MS_LOG(INFO) << "Create squeeze node end."; + return range_node; +} + +CNodePtr CreateSubCnode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { + MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr); + MS_CHECK_TRUE_RET(cnode != nullptr, nullptr); + auto sub_op = std::make_unique(); + if (sub_op == nullptr) { + MS_LOG(ERROR) << "New Sub op failed, sub_op is nullptr!"; + return nullptr; + } + + auto sub_prim_c = sub_op->GetPrim(); + if (sub_prim_c == nullptr) { + MS_LOG(ERROR) << "sub_prim_c is nullptr!"; + return nullptr; + } + + auto sub_vale_parameter = + opt::BuildIntValueParameterNode(func_graph, 1, cnode->fullname_with_scope() + "_sub_param", false); + MS_CHECK_TRUE_RET(sub_vale_parameter != nullptr, nullptr); + std::vector inputs = {cnode, sub_vale_parameter}; + auto sub_node = func_graph->NewCNode(sub_prim_c, inputs); + if (sub_node == nullptr) { + MS_LOG(ERROR) << "New sub cnode failed, sub_node is nullptr!"; + return nullptr; + } + sub_node->set_fullname_with_scope(cnode->fullname_with_scope() + "_sub"); + if (cnode->abstract() != nullptr) { + sub_node->set_abstract(cnode->abstract()->Clone()); + } + MS_LOG(INFO) << "Create Sub node end."; + return sub_node; +} + +CNodePtr CreateAfterReshapeNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const CNodePtr &shape_node) { + MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr); + MS_CHECK_TRUE_RET(cnode != nullptr, nullptr); + if (shape_node == nullptr) { + MS_LOG(ERROR) << "Input shape cnode is nullptr!"; + return nullptr; + } + + auto reshape_prim_c = mindspore::prim::kPrimReshape; + if (reshape_prim_c == nullptr) { + MS_LOG(ERROR) << "New Reshape prim failed, reshape_prim_c is nullptr!"; + return nullptr; + } + std::vector inputs = {cnode, shape_node}; + auto reshape_node = func_graph->NewCNode(reshape_prim_c, inputs); + if (reshape_node == nullptr) { + MS_LOG(ERROR) << "New reshape cnode failed, reshape_node is nullptr!"; + return nullptr; + } + reshape_node->set_fullname_with_scope(cnode->fullname_with_scope() + "_reshape_after"); + if (cnode->abstract() != nullptr) { + reshape_node->set_abstract(cnode->abstract()->Clone()); + } + MS_LOG(INFO) << "Create reshape node end."; + return reshape_node; +} + std::vector GetTensorShape(CNodePtr cnode, size_t input_index) { auto abstract = GetCNodeInputAbstract(cnode, input_index); MS_CHECK_TRUE_RET(abstract != nullptr, {}); @@ -144,128 +262,6 @@ CNodePtr CreateMatmulCNode(const FuncGraphPtr &func_graph, const std::vector(); - MS_CHECK_TRUE_MSG(strided_slice_prim != nullptr, nullptr, "create strided_slice_prim return nullptr"); - auto strided_slice_prim_c = strided_slice_prim->GetPrim(); - MS_CHECK_TRUE_MSG(strided_slice_prim_c != nullptr, nullptr, "create strided_slice_prim_c return nullptr"); - int64_t fmk_type = converter::FmkType::kFmkTypeOnnx; - strided_slice_prim_c->AddAttr(ops::kFmkType, MakeValue(fmk_type)); - std::vector starts = {0}; - std::vector ends = {-1}; - std::vector axes = {0}; - std::vector steps = {1}; - std::string suffix = left ? "_left" : "_right"; - if (!left) { - starts = {-1}; - ends = {INT32_MAX}; - } - auto starts_parm_node = - opt::BuildIntVecParameterNode(func_graph, starts, input->fullname_with_scope() + suffix + "_slice_starts"); - MS_CHECK_TRUE_MSG(starts_parm_node != nullptr, nullptr, "create starts_parm_node return nullptr!"); - - auto ends_parm_node = - opt::BuildIntVecParameterNode(func_graph, ends, input->fullname_with_scope() + suffix + "_slice_ends"); - MS_CHECK_TRUE_MSG(ends_parm_node != nullptr, nullptr, "create ends_parm_node return nullptr!"); - - auto axes_parm_node = - opt::BuildIntVecParameterNode(func_graph, axes, input->fullname_with_scope() + suffix + "_slice_axes"); - MS_CHECK_TRUE_MSG(axes_parm_node != nullptr, nullptr, "create axes_parm_node return nullptr!"); - - auto steps_parm_node = - opt::BuildIntVecParameterNode(func_graph, steps, input->fullname_with_scope() + suffix + "_slice_steps"); - MS_CHECK_TRUE_MSG(steps_parm_node != nullptr, nullptr, "create steps_parm_node return nullptr!"); - - auto strided_slice_node = func_graph->NewCNode( - strided_slice_prim->GetPrim(), {input, starts_parm_node, ends_parm_node, axes_parm_node, steps_parm_node}); - MS_CHECK_TRUE_MSG(strided_slice_node != nullptr, nullptr, "create strided_slice node return nullptr"); - strided_slice_node->set_fullname_with_scope(input->fullname_with_scope() + suffix + "_slice"); - auto abstract = lite::CreateTensorAbstract({kShapeMinus_1}, kNumberTypeInt32); - if (abstract == nullptr) { - MS_LOG(ERROR) << "Create tensor abstract failed!"; - return nullptr; - } - strided_slice_node->set_abstract(abstract); - return strided_slice_node; -} - -CNodePtr CreateConcatCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, bool left) { - MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr); - MS_CHECK_TRUE_RET(input != nullptr, nullptr); - std::string suffix = left ? "_left" : "_right"; - auto second_input = opt::BuildIntVecParameterNode(func_graph, {-1}, input->fullname_with_scope() + suffix + "_const"); - MS_CHECK_TRUE_MSG(second_input != nullptr, nullptr, "create concat const input return nullptr!"); - std::vector inputs; - if (left) { - inputs = {input, second_input}; - } else { - inputs = {second_input, input}; - } - auto concat_cnode = opt::GenConcatNode(func_graph, inputs, input->fullname_with_scope() + suffix + "_concat", 0); - MS_CHECK_TRUE_RET(concat_cnode != nullptr, nullptr); - auto abstract = lite::CreateTensorAbstract({kShapeMinus_1}, kNumberTypeInt32); - if (abstract == nullptr) { - MS_LOG(ERROR) << "Create tensor abstract failed!"; - return nullptr; - } - concat_cnode->set_abstract(abstract); - return concat_cnode; -} - -CNodePtr CreateReshapeCNode(const FuncGraphPtr &func_graph, const std::vector &inputs, - const AnfNodePtr origin_matmul) { - MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr); - MS_CHECK_TRUE_RET(inputs.size() == kInputIndex_2, nullptr); - auto reshape_prim = std::make_shared(); - MS_CHECK_TRUE_MSG(reshape_prim != nullptr, nullptr, "create reshape_prim return nullptr!"); - auto reshape_prim_c = reshape_prim->GetPrim(); - MS_CHECK_TRUE_MSG(reshape_prim_c != nullptr, nullptr, "create prim_c return nullptr!"); - auto reshape_node = func_graph->NewCNode(reshape_prim_c, inputs); - MS_CHECK_TRUE_MSG(reshape_node != nullptr, nullptr, "create reshape_node return nullptr!"); - reshape_node->set_fullname_with_scope(inputs[0]->fullname_with_scope() + "_reshape"); - if (origin_matmul != nullptr) { - if (origin_matmul->abstract() == nullptr) { - MS_LOG(ERROR) << "Original matmul doesn't have abstract!"; - return nullptr; - } - reshape_node->set_abstract(origin_matmul->abstract()->Clone()); - } else { - auto abstract = lite::CreateTensorAbstract({kShapeMinus_1, kShapeMinus_1}, kNumberTypeFloat32); - if (abstract == nullptr) { - MS_LOG(ERROR) << "Create tensor abstract failed!"; - return nullptr; - } - reshape_node->set_abstract(abstract); - } - return reshape_node; -} - -CNodePtr CreateMatmulCNode(const FuncGraphPtr &func_graph, const std::vector &inputs, - const PrimitivePtr &bmm_prim, const std::string &name) { - MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr); - MS_CHECK_TRUE_RET(bmm_prim != nullptr, nullptr); - auto matmul = std::make_shared(); - MS_CHECK_TRUE_MSG(matmul != nullptr, nullptr, "create matmul_prim return nullptr"); - auto dst_prim = matmul->GetPrim(); - MS_CHECK_TRUE_RET(dst_prim != nullptr, nullptr); - auto matmul_cnode = func_graph->NewCNode(dst_prim, inputs); - if (matmul_cnode == nullptr) { - MS_LOG(ERROR) << "New matmul_cnode is nullptr!"; - return nullptr; - } - auto abstract = lite::CreateTensorAbstract({kShapeMinus_1, kShapeMinus_1}, kNumberTypeFloat32); - if (abstract == nullptr) { - MS_LOG(ERROR) << "Create tensor abstract failed!"; - return nullptr; - } - matmul_cnode->set_abstract(abstract); - matmul_cnode->set_fullname_with_scope(name); - SetMatMulTransposeAttr(bmm_prim, dst_prim); - return matmul_cnode; -} - bool BMMToMMForStatic(const FuncGraphPtr &func_graph, const CNodePtr &batch_matmul_cnode) { auto x1_input = batch_matmul_cnode->input(kInputIndex_1); MS_CHECK_TRUE_RET(x1_input != nullptr, false); @@ -306,17 +302,6 @@ bool BMMToMMForStatic(const FuncGraphPtr &func_graph, const CNodePtr &batch_matm } bool BMMToMMForDynamic(const FuncGraphPtr &func_graph, const CNodePtr &batch_matmul_cnode) { - auto bmm_prim = GetCNodePrimitive(batch_matmul_cnode); - MS_CHECK_TRUE_RET(bmm_prim != nullptr, false); - auto trans_a = bmm_prim->GetAttr(mindspore::ops::kTransposeA); - auto trans_b = bmm_prim->GetAttr(mindspore::ops::kTransposeB); - auto trans_a_value = trans_a != nullptr && GetValue(trans_a); - auto trans_b_value = trans_b != nullptr && GetValue(trans_b); - if (trans_a_value || trans_b_value) { - MS_LOG(INFO) << "BMMToMM doesn't support trans_a == true or trans_b == true currently."; - return true; - } - auto batch_matmul_input_1 = batch_matmul_cnode->input(kInputIndex_1)->cast(); MS_CHECK_TRUE_RET(batch_matmul_input_1 != nullptr, false); auto matmul_weight_input = batch_matmul_cnode->input(kInputIndex_2); @@ -324,31 +309,81 @@ bool BMMToMMForDynamic(const FuncGraphPtr &func_graph, const CNodePtr &batch_mat auto data_shape_cnode = CreateShapeCNode(func_graph, batch_matmul_input_1); MS_CHECK_TRUE_RET(data_shape_cnode != nullptr, false); - auto left_strided_slice = CreateStridedSliceCNode(func_graph, data_shape_cnode, true); - MS_CHECK_TRUE_RET(left_strided_slice != nullptr, false); - auto right_strided_slice = CreateStridedSliceCNode(func_graph, data_shape_cnode, false); - MS_CHECK_TRUE_RET(right_strided_slice != nullptr, false); - auto left_concat = CreateConcatCNode(func_graph, left_strided_slice, true); - MS_CHECK_TRUE_RET(left_concat != nullptr, false); - auto right_concat = CreateConcatCNode(func_graph, right_strided_slice, false); - MS_CHECK_TRUE_RET(right_concat != nullptr, false); - auto up_reshape = CreateReshapeCNode(func_graph, {batch_matmul_input_1, right_concat}, nullptr); - MS_CHECK_TRUE_RET(up_reshape != nullptr, false); - - std::vector matmul_inputs = {up_reshape}; - for (size_t i = kInputIndex_2; i < batch_matmul_cnode->size(); ++i) { - matmul_inputs.push_back(batch_matmul_cnode->input(i)); - } - auto matmul_cnode = - CreateMatmulCNode(func_graph, matmul_inputs, bmm_prim, batch_matmul_cnode->fullname_with_scope() + "_bmm2mm"); - MS_CHECK_TRUE_RET(matmul_cnode != nullptr, false); - - auto down_reshape = CreateReshapeCNode(func_graph, {matmul_cnode, left_concat}, batch_matmul_cnode); - MS_CHECK_TRUE_RET(down_reshape != nullptr, false); + auto data_shape_gather_node = + opt::GenGatherNode(func_graph, data_shape_cnode, {kShapeMinus_1}, + data_shape_cnode->fullname_with_scope() + "_data_shape_gather", {kAxis_0}); + MS_CHECK_TRUE_RET(data_shape_gather_node != nullptr, false); + data_shape_gather_node->set_abstract(batch_matmul_cnode->abstract()->Clone()); + + auto data_concat_parm = opt::BuildIntVecParameterNode(func_graph, {kShape_1, kShapeMinus_1}, + batch_matmul_cnode->fullname_with_scope() + "_const_minus_2"); + MS_CHECK_TRUE_RET(data_concat_parm != nullptr, false); + data_concat_parm->set_abstract(batch_matmul_cnode->abstract()->Clone()); + + auto data_concat_cnode = opt::GenConcatNode(func_graph, {data_concat_parm, data_shape_gather_node}, + batch_matmul_cnode->fullname_with_scope() + "_concat", 0); + MS_CHECK_TRUE_RET(data_concat_cnode != nullptr, false); + data_concat_cnode->set_abstract(batch_matmul_cnode->abstract()->Clone()); + + // create reshape node, Data reshape to (1,ab,c), weight shape is (c,d) + auto reshape_data_node = CreateAfterReshapeNode(func_graph, batch_matmul_input_1, data_concat_cnode); + MS_CHECK_TRUE_RET(reshape_data_node != nullptr, false); + reshape_data_node->set_abstract(batch_matmul_cnode->abstract()->Clone()); + + auto squeeze_cnode = CreateSqueezeCnode(func_graph, reshape_data_node); + MS_CHECK_TRUE_RET(squeeze_cnode != nullptr, false); + + ops::MatMul matmul; + auto dst_prim = matmul.GetPrim(); + MS_CHECK_TRUE_RET(dst_prim != nullptr, false); + auto matmul_cnode = func_graph->NewCNode(dst_prim, {squeeze_cnode, matmul_weight_input}); + if (matmul_cnode == nullptr) { + MS_LOG(ERROR) << "New matmul_cnode is nullptr!"; + return false; + } + auto abstract = lite::CreateTensorAbstract({kShapeMinus_1, kShapeMinus_1}, kNumberTypeFloat32); + if (abstract == nullptr) { + MS_LOG(ERROR) << "Create tensor abstract failed!"; + return false; + } + matmul_cnode->set_abstract(abstract); + matmul_cnode->set_fullname_with_scope(batch_matmul_cnode->fullname_with_scope() + "_bmm2mm"); + + auto prim = GetValueNode(batch_matmul_cnode->input(kInputIndex_0)); + MS_CHECK_TRUE_RET(prim != nullptr, false); + SetMatMulTransposeAttr(prim, dst_prim); + + auto data_shape_dim_cnode = CreateShapeCNode(func_graph, data_shape_cnode); + MS_CHECK_TRUE_RET(data_shape_dim_cnode != nullptr, false); + + auto range_limit_cnode = CreateSubCnode(func_graph, data_shape_dim_cnode); + MS_CHECK_TRUE_RET(range_limit_cnode != nullptr, false); + + auto range_cnode = CreateRangeV2Cnode(func_graph, range_limit_cnode); + MS_CHECK_TRUE_RET(range_cnode != nullptr, false); + + auto shape_gather_node = opt::GenGatherNodeDynamicIndex( + func_graph, data_shape_cnode, range_cnode, data_shape_cnode->fullname_with_scope() + "_gather", {kAxis_0}); + MS_CHECK_TRUE_RET(shape_gather_node != nullptr, false); + + auto concat_parm = opt::BuildIntValueParameterNode( + func_graph, kShapeMinus_1, batch_matmul_cnode->fullname_with_scope() + "_const_minus_1", false); + MS_CHECK_TRUE_RET(concat_parm != nullptr, false); + concat_parm->set_abstract(batch_matmul_cnode->abstract()->Clone()); + + auto concat_cnode = opt::GenConcatNode(func_graph, {shape_gather_node, concat_parm}, + batch_matmul_cnode->fullname_with_scope() + "_concat", 0); + MS_CHECK_TRUE_RET(concat_cnode != nullptr, false); + concat_cnode->set_abstract(batch_matmul_cnode->abstract()->Clone()); + + // reshape(MM, (a,b,d)) + auto reshape_output_cnode = CreateAfterReshapeNode(func_graph, matmul_cnode, concat_cnode); + MS_CHECK_TRUE_RET(reshape_output_cnode != nullptr, false); auto graph_manager = func_graph->manager(); MS_CHECK_TRUE_RET(graph_manager != nullptr, false); - if (!graph_manager->Replace(batch_matmul_cnode, down_reshape)) { + + if (!graph_manager->Replace(batch_matmul_cnode, reshape_output_cnode)) { MS_LOG(ERROR) << "Failed to replace MatMul with BatchMatMul! cnode " << batch_matmul_cnode->fullname_with_scope() << ", input size " << batch_matmul_cnode->size(); return false; diff --git a/mindspore-lite/tools/optimizer/fusion/encoder_layer_fusion.cc b/mindspore-lite/tools/optimizer/fusion/encoder_layer_fusion.cc index 3bbb9fdad31909ade7ca90b8a5e18c83fbd61ea2..2aeca734c444203375ae3cbfe9169adce61777da 100644 --- a/mindspore-lite/tools/optimizer/fusion/encoder_layer_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/encoder_layer_fusion.cc @@ -639,8 +639,8 @@ VectorRef EncoderLayerFusion::DefinePatternEncoderLayer(bool post_layernorm = tr } auto is_add = std::make_shared(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is_add"); auto add = (is_position_bias && post_layernorm) - ? VectorRef({is_add, getTuple(post_layernorm, layernorm_fusion, is_position_bias), tuple}) - : VectorRef({is_add, reshape1, tuple}); + ? VectorRef({is_add, getTuple(post_layernorm, layernorm_fusion, is_position_bias), tuple}) + : VectorRef({is_add, reshape1, tuple}); auto is_reshape2 = std::make_shared(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder2"); MS_CHECK_TRUE_RET(is_reshape2 != nullptr, {}); auto var2 = std::make_shared("var2"); @@ -1087,11 +1087,8 @@ CNodePtr EncoderLayerFusion::CreateMaskedEncoderLayerFusionNode(const FuncGraphP &expert_num, &expert_offset, &capacity_factor) == RET_OK, nullptr, "Init Attributes failed!"); auto encoder_layer_prim = CreatePrim(func_graph, equiv, ffn_hidden_size, expert_num, expert_offset, capacity_factor); - MS_CHECK_TRUE_RET(encoder_layer_prim != nullptr, nullptr); auto encoder_layer_prim_c = encoder_layer_prim->GetPrim(); - MS_CHECK_TRUE_RET(encoder_layer_prim_c != nullptr, nullptr); auto value_node = NewValueNode(encoder_layer_prim_c); - MS_CHECK_TRUE_RET(value_node != nullptr, nullptr); std::vector new_node_inputs = {value_node, input}; if (is_position_bias_) { auto position_bias = utils::cast((*equiv)[position_bias_]); diff --git a/mindspore-lite/tools/optimizer/fusion/matmul_mul_fusion.cc b/mindspore-lite/tools/optimizer/fusion/matmul_mul_fusion.cc index 4906a0462ecd2829c90d402ae444e1f70d8d8415..29944d5e79f1fc80a5b1577b62705e7e12ef7358 100644 --- a/mindspore-lite/tools/optimizer/fusion/matmul_mul_fusion.cc +++ b/mindspore-lite/tools/optimizer/fusion/matmul_mul_fusion.cc @@ -77,7 +77,6 @@ int CalNewCnodeScale(const CNodePtr &mul_cnode, const CNodePtr &matmul_cnode) { } return RET_OK; } - int CalNewCnodeBias(const CNodePtr &mul_cnode, const CNodePtr &matmul_cnode) { auto mul_weight_node = mul_cnode->input(kInputIndexTwo); std::shared_ptr mul_weight_tensor = GetTensorInfo(mul_weight_node); @@ -89,9 +88,10 @@ int CalNewCnodeBias(const CNodePtr &mul_cnode, const CNodePtr &matmul_cnode) { std::vector mul_weight_shape = mul_weight_tensor->shape(); auto mul_weight_data = reinterpret_cast(mul_weight_tensor->data_c()); MS_CHECK_TRUE_RET(mul_weight_data != nullptr, RET_ERROR); - auto mutmul_bias_node = matmul_cnode->input(kInputIndexThree); + MS_CHECK_TRUE_RET(mutmul_bias_node != nullptr, RET_ERROR); auto mutmul_bias_tensor = GetTensorInfo(mutmul_bias_node); + MS_CHECK_TRUE_RET(mutmul_bias_tensor != nullptr, RET_ERROR); if (mutmul_bias_tensor->data_type() != kNumberTypeFloat32) { MS_LOG(ERROR) << "only support float32 data type"; return RET_ERROR; diff --git a/mindspore-lite/tools/optimizer/graph/control_flow_pass.cc b/mindspore-lite/tools/optimizer/graph/control_flow_pass.cc index 13138233f66ea96128f136f1822a70c20dab6431..51d1225b4793741a4f9fe88a89512f664d5dde35 100644 --- a/mindspore-lite/tools/optimizer/graph/control_flow_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/control_flow_pass.cc @@ -414,8 +414,8 @@ int ControlFlowPass::CreateWhileAfterPartialNode( MS_ASSERT(value_node != nullptr); auto input_index = value_node->value()->type()->number_type() == kNumberTypeInt64 - ? GetValue(value_node->value()) - : GetValue(value_node->value()); + ? GetValue(value_node->value()) + : GetValue(value_node->value()); after_partial_cnode_inputs.push_back(cond_fg_inputs.at(input_index)); auto new_parameter = after_fg->add_parameter(); diff --git a/mindspore-lite/tools/optimizer/graph/decrease_transpose_algo.cc b/mindspore-lite/tools/optimizer/graph/decrease_transpose_algo.cc index c87b54537c24ff836c526abbd5136c1089b4bb66..6c6e0003c3ad9ec89f80020418d7a44e2157441f 100644 --- a/mindspore-lite/tools/optimizer/graph/decrease_transpose_algo.cc +++ b/mindspore-lite/tools/optimizer/graph/decrease_transpose_algo.cc @@ -510,29 +510,29 @@ int DecreaseTransposeAlgo::InsertPreTransForNonTransInOut(const FuncGraphPtr &fu TransTypePair trans_info) { std::function insert_pre_trans = [&](const AnfNodePtr &node, size_t index, FormatTransNodeType format_trans_type) -> bool { - MS_CHECK_TRUE_RET(node != nullptr, false); - auto cnode = node->cast(); - MS_CHECK_TRUE_RET(cnode != nullptr, false); - if (CheckPrimitiveType(node, prim::kPrimTupleGetItem)) { - auto node_users = func_graph->manager()->node_users()[node]; - return std::all_of(node_users.begin(), node_users.end(), - [&insert_pre_trans, &format_trans_type](const std::pair &pair) { - return insert_pre_trans(pair.first, pair.second, format_trans_type); - }); - } else if (CheckPrimitiveType(cnode->input(1), prim::kPrimMakeTuple) || - CheckPrimitiveType(cnode->input(1), prim::kPrimMakeTupleV2)) { - auto make_tuple_cnode = cnode->input(1)->cast(); - MS_CHECK_TRUE_RET(make_tuple_cnode != nullptr, false); - for (size_t i = 0; i < make_tuple_cnode->size(); i++) { - if (!insert_pre_trans(make_tuple_cnode->input(i), i, format_trans_type)) { - return false; + MS_CHECK_TRUE_RET(node != nullptr, false); + auto cnode = node->cast(); + MS_CHECK_TRUE_RET(cnode != nullptr, false); + if (CheckPrimitiveType(node, prim::kPrimTupleGetItem)) { + auto node_users = func_graph->manager()->node_users()[node]; + return std::all_of(node_users.begin(), node_users.end(), + [&insert_pre_trans, &format_trans_type](const std::pair &pair) { + return insert_pre_trans(pair.first, pair.second, format_trans_type); + }); + } else if (CheckPrimitiveType(cnode->input(1), prim::kPrimMakeTuple) || + CheckPrimitiveType(cnode->input(1), prim::kPrimMakeTupleV2)) { + auto make_tuple_cnode = cnode->input(1)->cast(); + MS_CHECK_TRUE_RET(make_tuple_cnode != nullptr, false); + for (size_t i = 0; i < make_tuple_cnode->size(); i++) { + if (!insert_pre_trans(make_tuple_cnode->input(i), i, format_trans_type)) { + return false; + } } + return true; } - return true; - } - auto perm = format_trans_type == kNHWC2NCHW ? kNC2NH : kNH2NC; - return GenNewInput(func_graph, cnode, perm, true, index) == lite::RET_OK; - }; + auto perm = format_trans_type == kNHWC2NCHW ? kNC2NH : kNH2NC; + return GenNewInput(func_graph, cnode, perm, true, index) == lite::RET_OK; + }; bool deal_inputs = std::all_of(not_trans_in_nodes.begin(), not_trans_in_nodes.end(), [&insert_pre_trans, &trans_info](const std::pair &pair) { return insert_pre_trans(pair.first, pair.second, trans_info.pre_); diff --git a/mindspore-lite/tools/optimizer/graph/preprocess_dynamic_shape.cc b/mindspore-lite/tools/optimizer/graph/preprocess_dynamic_shape.cc index 25b0e9547dfea4ae6238781c0416ee8d18951206..7dc321c1d06bde1f2fb1d3da33278f0c1967ee0f 100644 --- a/mindspore-lite/tools/optimizer/graph/preprocess_dynamic_shape.cc +++ b/mindspore-lite/tools/optimizer/graph/preprocess_dynamic_shape.cc @@ -455,8 +455,8 @@ int SplitInferShape(const CNodePtr &cnode, const std::vector &in_sh auto prim = GetCNodePrimitive(cnode); auto out_num = prim->GetAttr(ops::kOutputNum) == nullptr ? 0 : GetValue(prim->GetAttr(ops::kOutputNum)); auto size_splits = prim->GetAttr(ops::kSizeSplits) == nullptr - ? std::vector{} - : GetValue>(prim->GetAttr(ops::kSizeSplits)); + ? std::vector{} + : GetValue>(prim->GetAttr(ops::kSizeSplits)); out_num = (out_num == 0 ? static_cast(size_splits.size()) : out_num); if (out_num <= 0) { return lite::RET_NOT_SUPPORT; @@ -543,8 +543,8 @@ int StackInferShape(const CNodePtr &cnode, const std::vector &in_sh return lite::RET_INPUT_TENSOR_ERROR; } if (std::any_of(in_shapes.begin(), in_shapes.end(), [](const ShapeVector &in_shape) { - return std::any_of(in_shape.begin(), in_shape.end(), [](int64_t val) { return val == 0; }); - })) { + return std::any_of(in_shape.begin(), in_shape.end(), [](int64_t val) { return val == 0; }); + })) { return lite::RET_NOT_SUPPORT; } auto prim = GetCNodePrimitive(cnode); @@ -833,15 +833,15 @@ int DynamicShapePreprocessor::DoInfer(const CNodePtr &cnode, const std::string & std::map &in_shapes, std::vector *out_shapes)>> infer_func = { - {prim::kPrimAddFusion->name(), ArithmeticInferShape}, {prim::kPrimActivation->name(), CommonInferShape}, - {prim::kPrimCast->name(), CommonInferShape}, {prim::kPrimConcat->name(), ConcatInferShape}, - {prim::kPrimExpandDims->name(), ExpandDimsInferShape}, {prim::kPrimGather->name(), GatherInferShape}, - {prim::kPrimMatMulFusion->name(), MatMulInferShape}, {prim::kPrimMulFusion->name(), ArithmeticInferShape}, - {prim::kPrimNotEqual->name(), CommonInferShape}, {prim::kPrimReduceFusion->name(), ReduceInferShape}, - {prim::kPrimReshape->name(), ReshapeInferShape}, {prim::kPrimShape->name(), ShapeInferShape}, - {prim::kPrimSplit->name(), SplitInferShape}, {prim::kPrimSqueeze->name(), SqueezeInferShape}, - {prim::kPrimStack->name(), StackInferShape}, {prim::kPrimStridedSlice->name(), StridedSliceInferShape}, - {prim::kPrimTranspose->name(), TransposeInferShape}}; + {prim::kPrimAddFusion->name(), ArithmeticInferShape}, {prim::kPrimActivation->name(), CommonInferShape}, + {prim::kPrimCast->name(), CommonInferShape}, {prim::kPrimConcat->name(), ConcatInferShape}, + {prim::kPrimExpandDims->name(), ExpandDimsInferShape}, {prim::kPrimGather->name(), GatherInferShape}, + {prim::kPrimMatMulFusion->name(), MatMulInferShape}, {prim::kPrimMulFusion->name(), ArithmeticInferShape}, + {prim::kPrimNotEqual->name(), CommonInferShape}, {prim::kPrimReduceFusion->name(), ReduceInferShape}, + {prim::kPrimReshape->name(), ReshapeInferShape}, {prim::kPrimShape->name(), ShapeInferShape}, + {prim::kPrimSplit->name(), SplitInferShape}, {prim::kPrimSqueeze->name(), SqueezeInferShape}, + {prim::kPrimStack->name(), StackInferShape}, {prim::kPrimStridedSlice->name(), StridedSliceInferShape}, + {prim::kPrimTranspose->name(), TransposeInferShape}}; if (infer_func.find(op_type) == infer_func.end()) { MS_LOG(ERROR) << "Current op: " << op_type << " doesn't support infer."; return lite::RET_ERROR; diff --git a/mindspore-lite/tools/optimizer/graph/scalar_op_pass.cc b/mindspore-lite/tools/optimizer/graph/scalar_op_pass.cc index 5fc5b4c51a764c7404f9363fca88fab8f2695c1e..49d5603d40525558a746c60a83645460a7a258da 100644 --- a/mindspore-lite/tools/optimizer/graph/scalar_op_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/scalar_op_pass.cc @@ -387,7 +387,6 @@ STATUS ScalarOpPass::ReplaceMakeTuple(const FuncGraphPtr &func_graph, const AnfN std::vector concat_input_vec({anf_node}); auto concat_node = GenConcatNode(func_graph, concat_input_vec, anf_node->cast()->fullname_with_scope() + "_concat_make_tuple"); - MS_CHECK_TRUE_RET(concat_node != nullptr, lite::RET_ERROR); auto primitive = GetCNodePrimitive(concat_node); MS_CHECK_TRUE_RET(primitive != nullptr, lite::RET_ERROR); int64_t num_of_inputs = SizeToInt(anf_node->cast()->size() - kSizeOne); @@ -402,7 +401,6 @@ STATUS ScalarOpPass::ReplaceMakeTuple(const FuncGraphPtr &func_graph, const AnfN } auto concat_abstract = abstract::MakeAbstract(std::make_shared(ShapeVector({num_of_inputs})), TypeIdToType(make_tuple_type)); - MS_CHECK_TRUE_RET(concat_abstract != nullptr, lite::RET_ERROR); concat_node->set_abstract(concat_abstract); // set MakeTuple users' input to concat @@ -621,7 +619,6 @@ STATUS ScalarOpPass::RunArithmeticCheckPass(const FuncGraphPtr &func_graph, cons auto new_cast_node = GenCastNode(func_graph, second_input, second_input->fullname_with_scope() + "cast_after_second_in", cast_data_type, new_cast_abstract); - MS_CHECK_TRUE_RET(new_cast_node != nullptr, lite::RET_ERROR); new_cast_node->set_abstract(new_cast_abstract); manager->SetEdge(node, kIndexTwo, new_cast_node); } diff --git a/mindspore-lite/tools/optimizer/graph/send_op_add_control_depend.cc b/mindspore-lite/tools/optimizer/graph/send_op_add_control_depend.cc index 720bed0b05ad6a0f5162e7c9d87a0202aa5d1857..69edd51e916e43c02f0b45fe7615d401e9f1a447 100644 --- a/mindspore-lite/tools/optimizer/graph/send_op_add_control_depend.cc +++ b/mindspore-lite/tools/optimizer/graph/send_op_add_control_depend.cc @@ -39,9 +39,10 @@ const AnfNodePtr SendOpAddControlDepend::Process(const FuncGraphPtr &func_graph, #else const AnfNodePtr SendOpAddControlDepend::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { - MS_CHECK_TRUE_RET(node != nullptr, nullptr); + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); auto cnode = node->cast(); - MS_CHECK_TRUE_RET(cnode != nullptr, nullptr); + MS_EXCEPTION_IF_NULL(cnode); if (!IsPrimitiveCNode(cnode, prim::kPrimSend)) { return nullptr; } @@ -58,13 +59,13 @@ const AnfNodePtr SendOpAddControlDepend::Process(const FuncGraphPtr &func_graph, auto value = MakeValue(tensor); MS_CHECK_TRUE_RET(value != nullptr, nullptr); auto value_node = std::make_shared(value); - MS_CHECK_TRUE_RET(value_node != nullptr, nullptr); value_node->set_abstract(tensor->ToAbstract()); + MS_EXCEPTION_IF_NULL(value_node); func_graph->AddValueNode(value_node); std::vector depend_input = {NewValueNode(std::make_shared(kDependOpName)), value_node, cnode}; auto depend_node = NewCNode(depend_input, func_graph); - MS_CHECK_TRUE_RET(depend_node != nullptr, nullptr); + MS_EXCEPTION_IF_NULL(depend_node); depend_node->set_fullname_with_scope(node->fullname_with_scope() + "_Depend"); depend_node->set_scope(node->scope()); depend_node->set_abstract(value_node->abstract()); diff --git a/mindspore-lite/tools/optimizer/graph/slice_prepose_pass.cc b/mindspore-lite/tools/optimizer/graph/slice_prepose_pass.cc index ccc49d4c9463bf8167ff4997cb03538f77b5f063..f6c2a070bb4e0085bc84e8202bd448002819369b 100644 --- a/mindspore-lite/tools/optimizer/graph/slice_prepose_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/slice_prepose_pass.cc @@ -509,7 +509,7 @@ int64_t SlicePreposePass::GetReshapeAbnormalAxeIn(const std::vector &sh } } if (j == shape_out.size() && abnormal_axe_in == -1) { - abnormal_axe_in = static_cast(i); + abnormal_axe_in = i; } } return abnormal_axe_in; diff --git a/scripts/format_source_code.sh b/scripts/format_source_code.sh index 00fdf59e5dc8469fc3f80095ce874b57c562ab68..2e90cf111e9425b4062ff7e19a9789a2e032f002 100755 --- a/scripts/format_source_code.sh +++ b/scripts/format_source_code.sh @@ -86,11 +86,11 @@ cd "${SCRIPTS_PATH}/.." || exit 1 FMT_FILE_LIST='__format_files_list__' if [[ "X${mode}" == "Xall" ]]; then - find mindspore-lite -type f -name "*" | grep -E "(\.h$|\.cc$|\.c$)" > "${FMT_FILE_LIST}" || true + find mindspore/{ccsrc,core,lite} -type f -name "*" | grep -E "(\.h$|\.cc$|\.c$)" > "${FMT_FILE_LIST}" || true elif [[ "X${mode}" == "Xchanged" ]]; then - git diff --name-only | grep "mindspore-lite" | grep -E "(\.h$|\.cc$|\.c$)" > "${FMT_FILE_LIST}" || true + git diff --name-only | grep "mindspore/ops\|mindspore/ccsrc\|mindspore/core\|mindspore/lite\|include" | grep -E "(\.h$|\.cc$|\.c$)" > "${FMT_FILE_LIST}" || true else # "X${mode}" == "Xlastcommit" - git diff --name-only HEAD~ HEAD | grep "mindspore-lite" | grep -E "(\.h$|\.cc$|\.c$)" > "${FMT_FILE_LIST}" || true + git diff --name-only HEAD~ HEAD | grep "mindspore/ops\|mindspore/ccsrc\|mindspore/core\|mindspore/lite\|include" | grep -E "(\.h$|\.cc$|\.c$)" > "${FMT_FILE_LIST}" || true fi while read line; do diff --git a/scripts/lite_release_package.sh b/scripts/lite_release_package.sh index 23cc8aef49d8f232c3de2176973c53f7fb1e344c..5cdbe0af7ba371910eb17e8453e9d20ec17616c8 100644 --- a/scripts/lite_release_package.sh +++ b/scripts/lite_release_package.sh @@ -22,35 +22,53 @@ function android_release_package() device=$2 pkg_name="mindspore-lite-${version}-android-${arch}" - [ -n "${pkg_name}" ] && rm -rf ${pkg_name} || exit 1 - tar -xzf ${input_path}/android_${arch}/${device}/${pkg_name}.tar.gz || exit 1 + [ -n "${pkg_name}" ] && rm -rf ${pkg_name} + tar -xzf ${input_path}/android_${arch}/${device}/${pkg_name}.tar.gz # Copy java runtime to Android package - cp ${input_path}/aar/mindspore-lite-*.aar ${pkg_name} || exit 1 + cp ${input_path}/aar/mindspore-lite-*.aar ${pkg_name} - mkdir -p ${output_path}/release/android/${device}/ || exit 1 - tar -czf ${output_path}/release/android/${device}/${pkg_name}.tar.gz ${pkg_name} || exit 1 - [ -n "${pkg_name}" ] && rm -rf ${pkg_name} || exit 1 - cd ${output_path}/release/android/${device}/ || exit 1 - sha256sum ${pkg_name}.tar.gz > ${pkg_name}.tar.gz.sha256 || exit 1 + mkdir -p ${output_path}/release/android/${device}/ + tar -czf ${output_path}/release/android/${device}/${pkg_name}.tar.gz ${pkg_name} + [ -n "${pkg_name}" ] && rm -rf ${pkg_name} + cd ${output_path}/release/android/${device}/ + sha256sum ${pkg_name}.tar.gz > ${pkg_name}.tar.gz.sha256 +} + +function ios_release_package() +{ + mkdir -p ${output_path}/release/ios/ + cp ${input_path}/ios_aarch64/*.tar.gz* ${output_path}/release/ios/ + cp ${input_path}/ios_aarch32/*.tar.gz* ${output_path}/release/ios/ } function linux_release_package() { - mkdir -p ${output_path}/release/linux/x86_64/ || exit 1 - mkdir -p ${output_path}/release/linux/aarch64/ || exit 1 - mkdir -p ${output_path}/release/linux/x86_64/cloud_fusion/ || exit 1 - mkdir -p ${output_path}/release/linux/aarch64/cloud_fusion/ || exit 1 + mkdir -p ${output_path}/release/linux/nnie/ + mkdir -p ${output_path}/release/linux/x86_64/ + mkdir -p ${output_path}/release/linux/aarch64/ + mkdir -p ${output_path}/release/linux/x86_64/ascend/ + mkdir -p ${output_path}/release/linux/aarch64/ascend/ + mkdir -p ${output_path}/release/linux/x86_64/cloud_fusion/ + mkdir -p ${output_path}/release/linux/aarch64/cloud_fusion/ + mkdir -p ${output_path}/release/none/cortex_m7 + + cp ${input_path}/none_cortex-m/mindspore*cortex-m7.tar.gz* ${output_path}/release/none/cortex_m7/ + cp ${input_path}/centos_x86/avx/mindspore*.tar.gz* ${output_path}/release/linux/x86_64/ + cp ${input_path}/linux_aarch64/mindspore*.tar.gz* ${output_path}/release/linux/aarch64/ + cp ${input_path}/centos_x86/ascend/mindspore*.tar.gz* ${output_path}/release/linux/x86_64/ascend/ + cp ${input_path}/linux_aarch64/ascend/mindspore*.tar.gz* ${output_path}/release/linux/aarch64/ascend/ + cp -r ${input_path}/centos_x86/cloud_fusion/* ${output_path}/release/linux/x86_64/cloud_fusion/ + cp -r ${input_path}/linux_aarch64/cloud_fusion/* ${output_path}/release/linux/aarch64/cloud_fusion/ - cp ${input_path}/centos_x86/avx/mindspore*.tar.gz* ${output_path}/release/linux/x86_64/ || exit 1 - cp ${input_path}/linux_aarch64/mindspore*.tar.gz* ${output_path}/release/linux/aarch64/ || exit 1 - cp -r ${input_path}/centos_x86/cloud_fusion/* ${output_path}/release/linux/x86_64/cloud_fusion/ || exit 1 - cp -r ${input_path}/linux_aarch64/cloud_fusion/* ${output_path}/release/linux/aarch64/cloud_fusion/ || exit 1 + cp -r ${input_path}/linux_aarch32/nnie/Hi* ${output_path}/release/linux/nnie/ + cp -r ${input_path}/linux_aarch64/nnie/Hi* ${output_path}/release/linux/nnie/ + cp ${input_path}/centos_x86/nnie/Hi3516D/*.tar.gz* ${output_path}/release/linux/nnie/ } function windows_release_package() { - mkdir -p ${output_path}/release/windows/ || exit 1 - cp ${input_path}/windows_x64/avx/*.zip* ${output_path}/release/windows/ || exit 1 + mkdir -p ${output_path}/release/windows/ + cp ${input_path}/windows_x64/avx/*.zip* ${output_path}/release/windows/ } echo "============================== begin ==============================" @@ -60,9 +78,12 @@ input_path=$1 output_path=$2 version=$(ls ${input_path}/android_aarch64/npu/mindspore-lite-*-*.tar.gz | awk -F'/' '{print $NF}' | cut -d"-" -f3) +android_release_package aarch32 npu +android_release_package aarch32 cpu android_release_package aarch64 npu -# android_release_package aarch64 gpu +android_release_package aarch64 gpu +ios_release_package linux_release_package windows_release_package diff --git a/third_party/patch/mindspore/decouple_mindspore.patch b/third_party/patch/mindspore/decouple_mindspore.patch new file mode 100644 index 0000000000000000000000000000000000000000..8516e08d5b86ead4ed43ebb9fed871e49216ab09 --- /dev/null +++ b/third_party/patch/mindspore/decouple_mindspore.patch @@ -0,0 +1,4406 @@ +diff --git a/CMakeLists.txt b/CMakeLists.txt +index 58330608769..f647e7fb3d4 100644 +--- a/CMakeLists.txt ++++ b/CMakeLists.txt +@@ -48,7 +48,6 @@ endif() + + if(ENABLE_PYTHON) + add_compile_definitions(ENABLE_PYTHON) +- add_compile_definitions(ENABLE_MINDDATA_PYTHON) + endif() + + if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin") +diff --git a/build.bat b/build.bat +index 2c8b6dfa979..0c524033f4d 100644 +--- a/build.bat ++++ b/build.bat +@@ -70,10 +70,10 @@ IF "%1%" == "lite" ( + rd /s /q "%BASE_PATH%\output" + (git log -1 | findstr "^commit") > %BUILD_PATH%\.commit_id + IF defined VisualStudioVersion ( +- cmake -DMSLITE_MINDDATA_IMPLEMENT=off -DMSLITE_ENABLE_TRAIN=off -DVERSION_STR=%VERSION_STR% ^ ++ cmake -DMSLITE_ENABLE_TRAIN=off -DVERSION_STR=%VERSION_STR% ^ + -DCMAKE_BUILD_TYPE=Release -G "Ninja" "%BASE_PATH%/mindspore/lite" + ) ELSE ( +- cmake -DMSLITE_MINDDATA_IMPLEMENT=off -DMSLITE_ENABLE_TRAIN=off -DVERSION_STR=%VERSION_STR% ^ ++ cmake -DMSLITE_ENABLE_TRAIN=off -DVERSION_STR=%VERSION_STR% ^ + -DCMAKE_BUILD_TYPE=Release -G "CodeBlocks - MinGW Makefiles" "%BASE_PATH%/mindspore/lite" + ) + ) ELSE ( +diff --git a/cmake/package_lite.cmake b/cmake/package_lite.cmake +index de8de6c1ab9..9e0408fad2c 100644 +--- a/cmake/package_lite.cmake ++++ b/cmake/package_lite.cmake +@@ -224,131 +224,131 @@ function(__install_ascend_ascendc) + endfunction() + + # full mode will also package the files of lite_cv mode. +-if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full") +- # full header files +- install(FILES +- ${TOP_DIR}/mindspore/lite/minddata/dataset/include/dataset/constants.h +- ${TOP_DIR}/mindspore/lite/minddata/dataset/include/dataset/data_helper.h +- ${TOP_DIR}/mindspore/lite/minddata/dataset/include/dataset/execute.h +- ${TOP_DIR}/mindspore/lite/minddata/dataset/include/dataset/iterator.h +- ${TOP_DIR}/mindspore/lite/minddata/dataset/include/dataset/samplers.h +- ${TOP_DIR}/mindspore/lite/minddata/dataset/include/dataset/transforms.h +- ${TOP_DIR}/mindspore/lite/minddata/dataset/include/dataset/vision_lite.h +- ${TOP_DIR}/mindspore/lite/minddata/dataset/liteapi/include/datasets.h +- DESTINATION ${MIND_DATA_INC_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) +- +- if(PLATFORM_ARM64) +- if((MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE) AND MSLITE_ENABLE_ACL) +- install(FILES ${TOP_DIR}/mindspore/lite/minddata/dataset/include/dataset/vision_ascend.h +- DESTINATION ${MIND_DATA_INC_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) +- install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/kernels-dvpp-image/utils/libdvpp_utils.so +- DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) +- endif() +- install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite.so DESTINATION +- ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) +- install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite.a DESTINATION +- ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) +- install(FILES ${JPEGTURBO_LIB_LIST} DESTINATION ${TURBO_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) +- install(FILES ${TOP_DIR}/mindspore/lite/build/securec/src/libsecurec.a +- DESTINATION ${SECUREC_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) +- elseif(PLATFORM_ARM32) +- install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite.so DESTINATION +- ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) +- install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite.a DESTINATION +- ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) +- install(FILES ${JPEGTURBO_LIB_LIST} DESTINATION ${TURBO_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) +- install(FILES ${TOP_DIR}/mindspore/lite/build/securec/src/libsecurec.a +- DESTINATION ${SECUREC_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) +- else() +- if((MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE) AND MSLITE_ENABLE_ACL) +- install(FILES ${TOP_DIR}/mindspore/lite/minddata/dataset/include/dataset/vision_ascend.h +- DESTINATION ${MIND_DATA_INC_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) +- install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/kernels-dvpp-image/utils/libdvpp_utils.so +- DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) +- endif() +- install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite.so DESTINATION +- ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) +- install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite.a DESTINATION +- ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) +- install(FILES ${jpeg_turbo_LIBPATH}/libjpeg.so.62.4.0 DESTINATION ${TURBO_DIR}/lib +- RENAME libjpeg.so.62 COMPONENT ${RUNTIME_COMPONENT_NAME}) +- install(FILES ${jpeg_turbo_LIBPATH}/libturbojpeg.so.0.3.0 DESTINATION ${TURBO_DIR}/lib +- RENAME libturbojpeg.so.0 COMPONENT ${RUNTIME_COMPONENT_NAME}) +- install(FILES ${TOP_DIR}/mindspore/lite/build/securec/src/libsecurec.a +- DESTINATION ${SECUREC_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) +- endif() +- +- # lite_cv header files +- install(DIRECTORY ${TOP_DIR}/mindspore/lite/minddata/dataset/kernels/image/lite_cv +- DESTINATION ${MIND_DATA_INC_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") +-endif() ++#if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full") ++# # full header files ++# install(FILES ++# ${TOP_DIR}/mindspore/lite/minddata/dataset/include/dataset/constants.h ++# ${TOP_DIR}/mindspore/lite/minddata/dataset/include/dataset/data_helper.h ++# ${TOP_DIR}/mindspore/lite/minddata/dataset/include/dataset/execute.h ++# ${TOP_DIR}/mindspore/lite/minddata/dataset/include/dataset/iterator.h ++# ${TOP_DIR}/mindspore/lite/minddata/dataset/include/dataset/samplers.h ++# ${TOP_DIR}/mindspore/lite/minddata/dataset/include/dataset/transforms.h ++# ${TOP_DIR}/mindspore/lite/minddata/dataset/include/dataset/vision_lite.h ++# ${TOP_DIR}/mindspore/lite/minddata/dataset/liteapi/include/datasets.h ++# DESTINATION ${MIND_DATA_INC_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# ++# if(PLATFORM_ARM64) ++# if((MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE) AND MSLITE_ENABLE_ACL) ++# install(FILES ${TOP_DIR}/mindspore/lite/minddata/dataset/include/dataset/vision_ascend.h ++# DESTINATION ${MIND_DATA_INC_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/kernels-dvpp-image/utils/libdvpp_utils.so ++# DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# endif() ++# install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite.so DESTINATION ++# ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite.a DESTINATION ++# ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# install(FILES ${JPEGTURBO_LIB_LIST} DESTINATION ${TURBO_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# install(FILES ${TOP_DIR}/mindspore/lite/build/securec/src/libsecurec.a ++# DESTINATION ${SECUREC_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# elseif(PLATFORM_ARM32) ++# install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite.so DESTINATION ++# ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite.a DESTINATION ++# ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# install(FILES ${JPEGTURBO_LIB_LIST} DESTINATION ${TURBO_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# install(FILES ${TOP_DIR}/mindspore/lite/build/securec/src/libsecurec.a ++# DESTINATION ${SECUREC_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# else() ++# if((MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE) AND MSLITE_ENABLE_ACL) ++# install(FILES ${TOP_DIR}/mindspore/lite/minddata/dataset/include/dataset/vision_ascend.h ++# DESTINATION ${MIND_DATA_INC_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/kernels-dvpp-image/utils/libdvpp_utils.so ++# DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# endif() ++# install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite.so DESTINATION ++# ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite.a DESTINATION ++# ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# install(FILES ${jpeg_turbo_LIBPATH}/libjpeg.so.62.4.0 DESTINATION ${TURBO_DIR}/lib ++# RENAME libjpeg.so.62 COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# install(FILES ${jpeg_turbo_LIBPATH}/libturbojpeg.so.0.3.0 DESTINATION ${TURBO_DIR}/lib ++# RENAME libturbojpeg.so.0 COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# install(FILES ${TOP_DIR}/mindspore/lite/build/securec/src/libsecurec.a ++# DESTINATION ${SECUREC_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# endif() ++# ++# # lite_cv header files ++# install(DIRECTORY ${TOP_DIR}/mindspore/lite/minddata/dataset/kernels/image/lite_cv ++# DESTINATION ${MIND_DATA_INC_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") ++#endif() + +-if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "wrapper") +- install(DIRECTORY ${TOP_DIR}/mindspore/lite/minddata/dataset/include/ DESTINATION ${MIND_DATA_INC_DIR} +- COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "vision.h" EXCLUDE) +- if(PLATFORM_ARM64) +- install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite.so DESTINATION +- ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) +- install(FILES ${JPEGTURBO_LIB_LIST} DESTINATION ${TURBO_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) +- elseif(PLATFORM_ARM32) +- install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite.so DESTINATION +- ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) +- install(FILES ${JPEGTURBO_LIB_LIST} DESTINATION ${TURBO_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) +- else() +- install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite.so DESTINATION +- ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) +- install(FILES ${jpeg_turbo_LIBPATH}/libjpeg.so.62.4.0 DESTINATION ${TURBO_DIR}/lib RENAME libjpeg.so.62 +- COMPONENT ${RUNTIME_COMPONENT_NAME}) +- install(FILES ${jpeg_turbo_LIBPATH}/libturbojpeg.so.0.3.0 DESTINATION ${TURBO_DIR}/lib RENAME libturbojpeg.so.0 +- COMPONENT ${RUNTIME_COMPONENT_NAME}) +- endif() +-endif() ++#if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "wrapper") ++# install(DIRECTORY ${TOP_DIR}/mindspore/lite/minddata/dataset/include/ DESTINATION ${MIND_DATA_INC_DIR} ++# COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "vision.h" EXCLUDE) ++# if(PLATFORM_ARM64) ++# install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite.so DESTINATION ++# ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# install(FILES ${JPEGTURBO_LIB_LIST} DESTINATION ${TURBO_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# elseif(PLATFORM_ARM32) ++# install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite.so DESTINATION ++# ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# install(FILES ${JPEGTURBO_LIB_LIST} DESTINATION ${TURBO_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# else() ++# install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite.so DESTINATION ++# ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# install(FILES ${jpeg_turbo_LIBPATH}/libjpeg.so.62.4.0 DESTINATION ${TURBO_DIR}/lib RENAME libjpeg.so.62 ++# COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# install(FILES ${jpeg_turbo_LIBPATH}/libturbojpeg.so.0.3.0 DESTINATION ${TURBO_DIR}/lib RENAME libturbojpeg.so.0 ++# COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# endif() ++#endif() + +-if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "lite") +- install(DIRECTORY ${TOP_DIR}/mindspore/lite/minddata/dataset/include/ DESTINATION ${MIND_DATA_INC_DIR} +- COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") +- if(PLATFORM_ARM64) +- install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite.so DESTINATION +- ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) +- install(FILES ${TOP_DIR}/third_party/libjpeg-turbo/lib/libjpeg.so DESTINATION ${TURBO_DIR}/lib +- COMPONENT ${RUNTIME_COMPONENT_NAME}) +- install(FILES ${TOP_DIR}/third_party/libjpeg-turbo/lib/libturbojpeg.so DESTINATION ${TURBO_DIR}/lib +- COMPONENT ${RUNTIME_COMPONENT_NAME}) +- elseif(PLATFORM_ARM32) +- install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite.so DESTINATION +- ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) +- install(FILES ${TOP_DIR}/third_party/libjpeg-turbo/lib/libjpeg.so DESTINATION ${TURBO_DIR}/lib +- COMPONENT ${RUNTIME_COMPONENT_NAME}) +- install(FILES ${TOP_DIR}/third_party/libjpeg-turbo/lib/libturbojpeg.so DESTINATION ${TURBO_DIR}/lib +- COMPONENT ${RUNTIME_COMPONENT_NAME}) +- else() +- install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite.so DESTINATION +- ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) +- install(FILES ${TOP_DIR}/third_party/libjpeg-turbo/lib/libjpeg.so.62.4.0 +- DESTINATION ${TURBO_DIR}/lib RENAME libjpeg.so.62 COMPONENT ${RUNTIME_COMPONENT_NAME}) +- install(FILES ${TOP_DIR}/third_party/libjpeg-turbo/lib/libturbojpeg.so.0.3.0 +- DESTINATION ${TURBO_DIR}/lib RENAME libturbojpeg.so.0 COMPONENT ${RUNTIME_COMPONENT_NAME}) +- endif() +-endif() ++#if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "lite") ++# install(DIRECTORY ${TOP_DIR}/mindspore/lite/minddata/dataset/include/ DESTINATION ${MIND_DATA_INC_DIR} ++# COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") ++# if(PLATFORM_ARM64) ++# install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite.so DESTINATION ++# ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# install(FILES ${TOP_DIR}/third_party/libjpeg-turbo/lib/libjpeg.so DESTINATION ${TURBO_DIR}/lib ++# COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# install(FILES ${TOP_DIR}/third_party/libjpeg-turbo/lib/libturbojpeg.so DESTINATION ${TURBO_DIR}/lib ++# COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# elseif(PLATFORM_ARM32) ++# install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite.so DESTINATION ++# ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# install(FILES ${TOP_DIR}/third_party/libjpeg-turbo/lib/libjpeg.so DESTINATION ${TURBO_DIR}/lib ++# COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# install(FILES ${TOP_DIR}/third_party/libjpeg-turbo/lib/libturbojpeg.so DESTINATION ${TURBO_DIR}/lib ++# COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# else() ++# install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite.so DESTINATION ++# ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# install(FILES ${TOP_DIR}/third_party/libjpeg-turbo/lib/libjpeg.so.62.4.0 ++# DESTINATION ${TURBO_DIR}/lib RENAME libjpeg.so.62 COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# install(FILES ${TOP_DIR}/third_party/libjpeg-turbo/lib/libturbojpeg.so.0.3.0 ++# DESTINATION ${TURBO_DIR}/lib RENAME libturbojpeg.so.0 COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# endif() ++#endif() + +-if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "lite_cv") +- if(PLATFORM_ARM64) +- install(DIRECTORY ${TOP_DIR}/mindspore/lite/minddata/dataset/kernels/image/lite_cv +- DESTINATION ${MIND_DATA_INC_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") +- install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite.so +- DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) +- elseif(PLATFORM_ARM32) +- install(DIRECTORY ${TOP_DIR}/mindspore/lite/minddata/dataset/kernels/image/lite_cv +- DESTINATION ${MIND_DATA_INC_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") +- install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite.so DESTINATION +- ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) +- else() +- install(DIRECTORY ${TOP_DIR}/mindspore/lite/minddata/dataset/kernels/image/lite_cv +- DESTINATION ${MIND_DATA_INC_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") +- install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite.so DESTINATION +- ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) +- endif() +-endif() ++#if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "lite_cv") ++# if(PLATFORM_ARM64) ++# install(DIRECTORY ${TOP_DIR}/mindspore/lite/minddata/dataset/kernels/image/lite_cv ++# DESTINATION ${MIND_DATA_INC_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") ++# install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite.so ++# DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# elseif(PLATFORM_ARM32) ++# install(DIRECTORY ${TOP_DIR}/mindspore/lite/minddata/dataset/kernels/image/lite_cv ++# DESTINATION ${MIND_DATA_INC_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") ++# install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite.so DESTINATION ++# ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# else() ++# install(DIRECTORY ${TOP_DIR}/mindspore/lite/minddata/dataset/kernels/image/lite_cv ++# DESTINATION ${MIND_DATA_INC_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") ++# install(FILES ${TOP_DIR}/mindspore/lite/build/minddata/libminddata-lite.so DESTINATION ++# ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# endif() ++#endif() + + if(WIN32) + install(FILES ${TOP_DIR}/build/.commit_id DESTINATION ${RUNTIME_PKG_NAME} +@@ -413,14 +413,16 @@ if(PLATFORM_ARM64) + install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/convert/libruntime_convert_plugin.so + DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + if(MSLITE_ENABLE_ACL) +- install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/kernel/ascend/libascend_kernel_plugin.so +- DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) ++ # install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/kernel/ascend/libascend_kernel_plugin.so ++ # DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + if(NOT MSLITE_SIMPLEST_CLOUD_INFERENCE) + install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/ascend_ge/libascend_ge_plugin.so + DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) ++ install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/ascend_acl/libascend_acl_plugin.so ++ DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + endif() +- install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/cxx_api/llm_engine/libllm_engine_plugin.so +- DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/cxx_api/llm_engine/libllm_engine_plugin.so ++# DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + __install_ascend_tbe_and_aicpu() + __install_ascend_ascendc() + endif() +@@ -434,8 +436,8 @@ if(PLATFORM_ARM64) + install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_LIB_NAME}.a DESTINATION ${RUNTIME_LIB_DIR} + COMPONENT ${RUNTIME_COMPONENT_NAME}) + if(MSLITE_ENABLE_ACL) +- install(FILES ${TOP_DIR}/mindspore/lite/build/src/litert/kernel/ascend/libascend_kernel_plugin.so +- DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) ++ # install(FILES ${TOP_DIR}/mindspore/lite/build/src/litert/kernel/ascend/libascend_kernel_plugin.so ++ # DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + endif() + endif() + if(MSLITE_ENABLE_MODEL_OBF) +@@ -538,11 +540,11 @@ if(PLATFORM_ARM64) + endif() + if(MSLITE_ENABLE_ACL) + set(LITE_ACL_DIR ${TOP_DIR}/mindspore/lite/build/tools/converter/adapter/acl) +- install(FILES ${LITE_ACL_DIR}/mslite_shared_lib/libmslite_shared_lib.so +- DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# install(FILES ${LITE_ACL_DIR}/mslite_shared_lib/libmslite_shared_lib.so ++# DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) + if(MSLITE_ENABLE_RUNTIME_CONVERT) +- install(FILES ${LITE_ACL_DIR}/mslite_shared_lib/libmslite_shared_lib.so +- DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# install(FILES ${LITE_ACL_DIR}/mslite_shared_lib/libmslite_shared_lib.so ++# DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + install(FILES ${glog_LIBPATH}/${glog_name} DESTINATION ${RUNTIME_LIB_DIR} + RENAME libmindspore_glog.so.0 COMPONENT ${RUNTIME_COMPONENT_NAME}) + install(TARGETS mindspore_core mindspore_ops DESTINATION ${CONVERTER_ROOT_DIR}/lib +@@ -586,25 +588,25 @@ if(PLATFORM_ARM64) + COMPONENT ${RUNTIME_COMPONENT_NAME}) + endif() + endif() +- if((MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE) +- AND MSLITE_ENABLE_GRAPH_KERNEL AND CMAKE_SYSTEM_NAME MATCHES "Linux") +- if(EXISTS ${BUILD_DIR}/akg) +- set(AKG_PATH ${BUILD_DIR}/akg) +- file(REMOVE_RECURSE ${AKG_PATH}/build/akg/lib) +- install(DIRECTORY ${AKG_PATH}/build/akg +- DESTINATION ${BUILD_DIR}/package/mindspore_lite +- COMPONENT ${RUNTIME_COMPONENT_NAME}) +- install(FILES ${AKG_PATH}/${AKG_PKG_PATH} +- DESTINATION ${RUNTIME_PKG_NAME}/tools/akg +- COMPONENT ${RUNTIME_COMPONENT_NAME}) +- install(FILES ${AKG_PATH}/${AKG_PKG_PATH}.sha256 +- DESTINATION ${RUNTIME_PKG_NAME}/tools/akg +- COMPONENT ${RUNTIME_COMPONENT_NAME}) +- install(FILES ${AKG_PATH}/build/libakg.so +- DESTINATION ${BUILD_DIR}/package/mindspore_lite/lib +- COMPONENT ${RUNTIME_COMPONENT_NAME}) +- endif() +- endif() ++# if((MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE) ++# AND MSLITE_ENABLE_GRAPH_KERNEL AND CMAKE_SYSTEM_NAME MATCHES "Linux") ++# if(EXISTS ${BUILD_DIR}/akg) ++# set(AKG_PATH ${BUILD_DIR}/akg) ++# file(REMOVE_RECURSE ${AKG_PATH}/build/akg/lib) ++# install(DIRECTORY ${AKG_PATH}/build/akg ++# DESTINATION ${BUILD_DIR}/package/mindspore_lite ++# COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# install(FILES ${AKG_PATH}/${AKG_PKG_PATH} ++# DESTINATION ${RUNTIME_PKG_NAME}/tools/akg ++# COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# install(FILES ${AKG_PATH}/${AKG_PKG_PATH}.sha256 ++# DESTINATION ${RUNTIME_PKG_NAME}/tools/akg ++# COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# install(FILES ${AKG_PATH}/build/libakg.so ++# DESTINATION ${BUILD_DIR}/package/mindspore_lite/lib ++# COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# endif() ++# endif() + endif() + endif() + if(MSLITE_ENABLE_TESTCASES) +@@ -673,14 +675,16 @@ elseif(PLATFORM_ARM32) + install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/convert/libruntime_convert_plugin.so + DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + if(MSLITE_ENABLE_ACL) +- install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/kernel/ascend/libascend_kernel_plugin.so +- DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) ++ # install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/kernel/ascend/libascend_kernel_plugin.so ++ # DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + if(NOT MSLITE_SIMPLEST_CLOUD_INFERENCE) + install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/ascend_ge/libascend_ge_plugin.so + DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) ++ install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/ascend_acl/libascend_acl_plugin.so ++ DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + endif() +- install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/cxx_api/llm_engine/libllm_engine_plugin.so +- DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/cxx_api/llm_engine/libllm_engine_plugin.so ++# DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + __install_ascend_tbe_and_aicpu() + __install_ascend_ascendc() + endif() +@@ -694,8 +698,8 @@ elseif(PLATFORM_ARM32) + install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_LIB_NAME}.a DESTINATION ${RUNTIME_LIB_DIR} + COMPONENT ${RUNTIME_COMPONENT_NAME}) + if(MSLITE_ENABLE_ACL) +- install(FILES ${TOP_DIR}/mindspore/lite/build/src/litert/kernel/ascend/libascend_kernel_plugin.so +- DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) ++ # install(FILES ${TOP_DIR}/mindspore/lite/build/src/litert/kernel/ascend/libascend_kernel_plugin.so ++ # DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + endif() + endif() + if(MSLITE_ENABLE_MODEL_OBF) +@@ -882,14 +886,16 @@ else() + install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/convert/libruntime_convert_plugin.so + DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + if(MSLITE_ENABLE_ACL) +- install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/kernel/ascend/libascend_kernel_plugin.so +- DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) ++ # install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/kernel/ascend/libascend_kernel_plugin.so ++ # DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + if(NOT MSLITE_SIMPLEST_CLOUD_INFERENCE) + install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/ascend_ge/libascend_ge_plugin.so + DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) ++ install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/ascend_acl/libascend_acl_plugin.so ++ DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + endif() +- install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/cxx_api/llm_engine/libllm_engine_plugin.so +- DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/cxx_api/llm_engine/libllm_engine_plugin.so ++# DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + __install_ascend_tbe_and_aicpu() + __install_ascend_ascendc() + if(MSLITE_ASCEND_TARGET) +@@ -913,8 +919,8 @@ else() + install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_LIB_NAME}.a DESTINATION ${RUNTIME_LIB_DIR} + COMPONENT ${RUNTIME_COMPONENT_NAME}) + if(MSLITE_ENABLE_ACL) +- install(FILES ${TOP_DIR}/mindspore/lite/build/src/litert/kernel/ascend/libascend_kernel_plugin.so +- DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) ++ # install(FILES ${TOP_DIR}/mindspore/lite/build/src/litert/kernel/ascend/libascend_kernel_plugin.so ++ # DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + endif() + endif() + if(MSLITE_ENABLE_MODEL_OBF) +@@ -977,11 +983,11 @@ else() + + if(MSLITE_ENABLE_ACL) + set(LITE_ACL_DIR ${TOP_DIR}/mindspore/lite/build/tools/converter/adapter/acl) +- install(FILES ${LITE_ACL_DIR}/mslite_shared_lib/libmslite_shared_lib.so +- DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# install(FILES ${LITE_ACL_DIR}/mslite_shared_lib/libmslite_shared_lib.so ++# DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME}) + if(MSLITE_ENABLE_RUNTIME_CONVERT) +- install(FILES ${LITE_ACL_DIR}/mslite_shared_lib/libmslite_shared_lib.so +- DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# install(FILES ${LITE_ACL_DIR}/mslite_shared_lib/libmslite_shared_lib.so ++# DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) + install(FILES ${glog_LIBPATH}/${glog_name} DESTINATION ${RUNTIME_LIB_DIR} + RENAME libmindspore_glog.so.0 COMPONENT ${RUNTIME_COMPONENT_NAME}) + install(TARGETS mindspore_core mindspore_ops DESTINATION ${RUNTIME_LIB_DIR} +@@ -1032,24 +1038,24 @@ else() + endif() + endif() + if(MSLITE_ENABLE_TOOLS) +- if(MSLITE_ENABLE_GRAPH_KERNEL AND CMAKE_SYSTEM_NAME MATCHES "Linux") +- if(EXISTS ${BUILD_DIR}/akg) +- set(AKG_PATH ${BUILD_DIR}/akg) +- file(REMOVE_RECURSE ${AKG_PATH}/build/akg/lib) +- install(DIRECTORY ${AKG_PATH}/build/akg +- DESTINATION ${BUILD_DIR}/package/mindspore_lite +- COMPONENT ${RUNTIME_COMPONENT_NAME}) +- install(FILES ${AKG_PATH}/${AKG_PKG_PATH} +- DESTINATION ${RUNTIME_PKG_NAME}/tools/akg +- COMPONENT ${RUNTIME_COMPONENT_NAME}) +- install(FILES ${AKG_PATH}/${AKG_PKG_PATH}.sha256 +- DESTINATION ${RUNTIME_PKG_NAME}/tools/akg +- COMPONENT ${RUNTIME_COMPONENT_NAME}) +- install(FILES ${AKG_PATH}/build/libakg.so +- DESTINATION ${BUILD_DIR}/package/mindspore_lite/lib +- COMPONENT ${RUNTIME_COMPONENT_NAME}) +- endif() +- endif() ++# if(MSLITE_ENABLE_GRAPH_KERNEL AND CMAKE_SYSTEM_NAME MATCHES "Linux") ++# if(EXISTS ${BUILD_DIR}/akg) ++# set(AKG_PATH ${BUILD_DIR}/akg) ++# file(REMOVE_RECURSE ${AKG_PATH}/build/akg/lib) ++# install(DIRECTORY ${AKG_PATH}/build/akg ++# DESTINATION ${BUILD_DIR}/package/mindspore_lite ++# COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# install(FILES ${AKG_PATH}/${AKG_PKG_PATH} ++# DESTINATION ${RUNTIME_PKG_NAME}/tools/akg ++# COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# install(FILES ${AKG_PATH}/${AKG_PKG_PATH}.sha256 ++# DESTINATION ${RUNTIME_PKG_NAME}/tools/akg ++# COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# install(FILES ${AKG_PATH}/build/libakg.so ++# DESTINATION ${BUILD_DIR}/package/mindspore_lite/lib ++# COMPONENT ${RUNTIME_COMPONENT_NAME}) ++# endif() ++# endif() + if(NOT MSLITE_COMPILE_TWICE) + install(TARGETS ${BENCHMARK_NAME} RUNTIME DESTINATION ${BENCHMARK_ROOT_DIR} + COMPONENT ${RUNTIME_COMPONENT_NAME}) +diff --git a/include/api/types.h b/include/api/types.h +index 0fca6ed72b9..5db78fad5c6 100644 +--- a/include/api/types.h ++++ b/include/api/types.h +@@ -350,6 +350,8 @@ class MS_API MSTensor { + friend class ModelImpl; + std::shared_ptr impl_; + }; ++using MSTensorPtr = std::shared_ptr; ++using MSTensorOrderMap = std::map>; + + class MS_API Buffer { + public: +diff --git a/mindspore/ccsrc/backend/common/graph_kernel/core/graph_kernel_utils.cc b/mindspore/ccsrc/backend/common/graph_kernel/core/graph_kernel_utils.cc +index 1e9a655551f..263ed409e91 100644 +--- a/mindspore/ccsrc/backend/common/graph_kernel/core/graph_kernel_utils.cc ++++ b/mindspore/ccsrc/backend/common/graph_kernel/core/graph_kernel_utils.cc +@@ -132,7 +132,7 @@ std::vector GkUtils::GetValidOps(const std::vector &o + } + + std::vector GkUtils::FilterExcludedOps(const std::vector &ops) { +-#ifndef MSLITE_ENABLE_GRAPH_KERNEL ++ //#ifndef MSLITE_ENABLE_GRAPH_KERNEL + if (Callback::Instance()->GetTargetFromContext() != kGPUDevice) { + return ops; + } +@@ -165,9 +165,9 @@ std::vector GkUtils::FilterExcludedOps(const std::vector &ops_with_level, +diff --git a/mindspore/ccsrc/backend/common/graph_kernel/core/parallel_op_combine.cc b/mindspore/ccsrc/backend/common/graph_kernel/core/parallel_op_combine.cc +index 97a47eae99d..d820127d424 100644 +--- a/mindspore/ccsrc/backend/common/graph_kernel/core/parallel_op_combine.cc ++++ b/mindspore/ccsrc/backend/common/graph_kernel/core/parallel_op_combine.cc +@@ -229,44 +229,44 @@ bool ParallelOpCombiner::AutoUpdateInfo(const CNodePtr &to_update) { + << to_update->size(); + return false; + } +-#ifndef MSLITE_ENABLE_GRAPH_KERNEL ++ //#ifndef MSLITE_ENABLE_GRAPH_KERNEL + Callback::Instance()->ResetKernelInfo(to_update); +-#else +- auto rep_input = to_update->input(1); +- // NOTE: We assume the inputs' formats and types are consistent with outputs'. +- std::string input_format = Callback::Instance()->GetTargetFromContext() == kAscendDevice ? "" : kOpFormat_NCHW; +- auto GetPrevOutFormat = [&input_format](const CNodePtr &cnode) -> bool { +- if (cnode == nullptr || !cnode->HasAttr(kOutputsFormat)) { +- return false; +- } +- auto prev_of = GetValue >(cnode->GetAttr(kOutputsFormat)); +- if (prev_of.size() > 0) { +- input_format = prev_of[0]; +- return true; +- } +- return false; +- }; +- if (AnfUtils::IsRealKernel(rep_input)) { +- (void)GetPrevOutFormat(rep_input->cast()); +- } +- if (input_format.empty()) { +- auto it = children_map_.find(rep_input); +- if (it != children_map_.end()) { +- for (auto orig_user : it->second) { +- if (GetPrevOutFormat(orig_user->cast())) { +- break; +- } +- } +- } +- } +- if (input_format.empty()) { +- MS_LOG(WARNING) << "Cannot find prev node's input format, use " << layout_ +- << " by default and that may cause error."; +- input_format = layout_; +- } +- std::vector outputs_formats(AnfUtils::GetOutputTensorNum(to_update), input_format); +- to_update->AddAttr(kOutputsFormat, MakeValue(outputs_formats)); +-#endif ++ //#else ++ // auto rep_input = to_update->input(1); ++ // // NOTE: We assume the inputs' formats and types are consistent with outputs'. ++ // std::string input_format = Callback::Instance()->GetTargetFromContext() == kAscendDevice ? "" : kOpFormat_NCHW; ++ // auto GetPrevOutFormat = [&input_format](const CNodePtr &cnode) -> bool { ++ // if (cnode == nullptr || !cnode->HasAttr(kOutputsFormat)) { ++ // return false; ++ // } ++ // auto prev_of = GetValue >(cnode->GetAttr(kOutputsFormat)); ++ // if (prev_of.size() > 0) { ++ // input_format = prev_of[0]; ++ // return true; ++ // } ++ // return false; ++ // }; ++ // if (AnfUtils::IsRealKernel(rep_input)) { ++ // (void)GetPrevOutFormat(rep_input->cast()); ++ // } ++ // if (input_format.empty()) { ++ // auto it = children_map_.find(rep_input); ++ // if (it != children_map_.end()) { ++ // for (auto orig_user : it->second) { ++ // if (GetPrevOutFormat(orig_user->cast())) { ++ // break; ++ // } ++ // } ++ // } ++ // } ++ // if (input_format.empty()) { ++ // MS_LOG(WARNING) << "Cannot find prev node's input format, use " << layout_ ++ // << " by default and that may cause error."; ++ // input_format = layout_; ++ // } ++ // std::vector outputs_formats(AnfUtils::GetOutputTensorNum(to_update), input_format); ++ // to_update->AddAttr(kOutputsFormat, MakeValue(outputs_formats)); ++ //#endif + return true; + } + +diff --git a/mindspore/ccsrc/backend/common/graph_kernel/expanders/utils.cc b/mindspore/ccsrc/backend/common/graph_kernel/expanders/utils.cc +index 4fbbe68a24d..902cb0d8b3f 100644 +--- a/mindspore/ccsrc/backend/common/graph_kernel/expanders/utils.cc ++++ b/mindspore/ccsrc/backend/common/graph_kernel/expanders/utils.cc +@@ -83,17 +83,17 @@ bool OpDesc::CheckOutputs() { + << outputs_info_[i].type << "]"; + return false; + } +-#ifdef MSLITE_ENABLE_GRAPH_KERNEL +- bool format_check_condition = +- (outputs[i]->format != kOpFormat_DEFAULT && outputs_info_[i].format != kOpFormat_DEFAULT) && +- outputs[i]->format != outputs_info_[i].format; +-#else ++ //#ifdef MSLITE_ENABLE_GRAPH_KERNEL ++ // bool format_check_condition = ++ // (outputs[i]->format != kOpFormat_DEFAULT && outputs_info_[i].format != kOpFormat_DEFAULT) && ++ // outputs[i]->format != outputs_info_[i].format; ++ //#else + bool format_check_condition = outputs[i]->format != outputs_info_[i].format; + if ((outputs[i]->format == kOpFormat_DEFAULT && outputs_info_[i].format == kOpFormat_NCHW) || + (outputs[i]->format == kOpFormat_NCHW && outputs_info_[i].format == kOpFormat_DEFAULT)) { + format_check_condition = false; + } +-#endif ++ //#endif + if (format_check_condition) { + MS_LOG(INFO) << "Op " << this->name_ << "'s output format [" << outputs[i]->format << "] is wrong, expect: [" + << outputs_info_[i].format << "]"; +diff --git a/mindspore/ccsrc/backend/common/graph_kernel/graph_kernel_flags.cc b/mindspore/ccsrc/backend/common/graph_kernel/graph_kernel_flags.cc +index 7f206d842d9..6e16e5a8a45 100644 +--- a/mindspore/ccsrc/backend/common/graph_kernel/graph_kernel_flags.cc ++++ b/mindspore/ccsrc/backend/common/graph_kernel/graph_kernel_flags.cc +@@ -230,17 +230,17 @@ void GraphKernelFlags::SaveJitConfig(const std::map &j + } + + std::pair GraphKernelFlags::GetGraphKernelConfig() { +-#ifdef MSLITE_ENABLE_GRAPH_KERNEL +- std::string flags = common::GetEnv("MS_DEV_GRAPH_KERNEL_FLAGS"); +- if (flags != "") { +- return std::make_pair(flags, false); +- } +- const auto &jit_config = GetJitConfig(); +- if (jit_config.find("graph_kernel_flags") != jit_config.end()) { +- flags = jit_config.at("graph_kernel_flags"); +- } +- return std::make_pair(flags, false); +-#else ++ //#ifdef MSLITE_ENABLE_GRAPH_KERNEL ++ // std::string flags = common::GetEnv("MS_DEV_GRAPH_KERNEL_FLAGS"); ++ // if (flags != "") { ++ // return std::make_pair(flags, false); ++ // } ++ // const auto &jit_config = GetJitConfig(); ++ // if (jit_config.find("graph_kernel_flags") != jit_config.end()) { ++ // flags = jit_config.at("graph_kernel_flags"); ++ // } ++ // return std::make_pair(flags, false); ++ //#else + const auto &jit_config = GetJitConfig(); + auto context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context); +@@ -270,15 +270,15 @@ std::pair GraphKernelFlags::GetGraphKernelConfig() { + flags = iter->second; + } + return std::make_pair(flags, enable_gk); +-#endif ++ //#endif + } + + void GraphKernelFlags::CheckSupport() const { +-#ifndef MSLITE_ENABLE_GRAPH_KERNEL ++ //#ifndef MSLITE_ENABLE_GRAPH_KERNEL + if (IsEnableGraphKernel()) { + auto context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context); +-#ifndef USE_LLVM ++ //#ifndef USE_LLVM + auto is_cpu = (context->get_param(MS_CTX_DEVICE_TARGET) == kCPUDevice); + if (is_cpu && const_cast(this)->kernel_generator == "AKG") { + MS_LOG(WARNING) +@@ -287,7 +287,7 @@ void GraphKernelFlags::CheckSupport() const { + const_cast(this)->opt_level = OptLevel_0; + return; + } +-#endif ++ //#endif + auto is_ascend = (context->get_param(MS_CTX_DEVICE_TARGET) == kAscendDevice); + if (is_ascend) { + #ifndef ENABLE_DVM +@@ -312,7 +312,7 @@ void GraphKernelFlags::CheckSupport() const { + #endif + } + } +-#endif ++ //#endif + } + + void GraphKernelFlags::Refresh() { +@@ -327,11 +327,11 @@ void GraphKernelFlags::Refresh() { + "valid flags, please refer to the source code file graph_kernel_flags.h at " + "https://gitee.com/mindspore/mindspore."; + } +-#ifndef MSLITE_ENABLE_GRAPH_KERNEL ++ //#ifndef MSLITE_ENABLE_GRAPH_KERNEL + if (IsEnableGraphKernel()) { + CheckSupport(); + } +-#endif ++ //#endif + // If enable graphkernel, Dump flags so that people can check the setting. + if (IsEnableGraphKernel()) { + MS_LOG(INFO) << "graph_kernel_flags = \"" << flags_cache_ << "\", all flags: " << DumpAllFlags(); +@@ -432,9 +432,9 @@ void GraphKernelFlags::RegisterFlags(std::map *flag_ma + } + + if (is_ascend && !has_kernel_generator) { +-#ifndef MSLITE_ENABLE_GRAPH_KERNEL ++ //#ifndef MSLITE_ENABLE_GRAPH_KERNEL + kernel_generator = "DVM"; +-#endif ++ //#endif + } + if (kernel_generator == "DVM" && !has_enable_dynamic_shape_fusion) { + enable_dynamic_shape_fusion = true; +diff --git a/mindspore/ccsrc/backend/common/graph_kernel/graph_kernel_flags.h b/mindspore/ccsrc/backend/common/graph_kernel/graph_kernel_flags.h +index d6bffe7c817..a61973fd4f1 100644 +--- a/mindspore/ccsrc/backend/common/graph_kernel/graph_kernel_flags.h ++++ b/mindspore/ccsrc/backend/common/graph_kernel/graph_kernel_flags.h +@@ -45,7 +45,7 @@ class BACKEND_COMMON_EXPORT GraphKernelFlags { + // Dump all flags to json-format string + std::string DumpAllFlags() const; + +-#if defined(ENABLE_AKG) || defined(MSLITE_ENABLE_GRAPH_KERNEL) ++#if defined(ENABLE_AKG) + // Check whether graph_kernel is enabled + bool IsEnableGraphKernel() const { return opt_level > OptLevel_0; } + #else +diff --git a/mindspore/ccsrc/backend/common/optimizer/helper.cc b/mindspore/ccsrc/backend/common/optimizer/helper.cc +index bf64279f85c..2828e2fb79d 100644 +--- a/mindspore/ccsrc/backend/common/optimizer/helper.cc ++++ b/mindspore/ccsrc/backend/common/optimizer/helper.cc +@@ -765,65 +765,6 @@ AnfNodePtr CreateNodeBase(const FuncGraphPtr &graph, const std::vector(a) && utils::isa(b)) { +- auto a_node = utils::cast(a); +- auto b_node = utils::cast(b); +- MS_EXCEPTION_IF_NULL(a_node); +- MS_EXCEPTION_IF_NULL(b_node); +- if (IsValueNode(a_node) && IsValueNode(b_node)) { +- auto a_value_node = a_node->cast(); +- MS_EXCEPTION_IF_NULL(a_value_node); +- auto a_value = a_value_node->value(); +- MS_EXCEPTION_IF_NULL(a_value); +- auto a_prim = a_value->cast(); +- MS_EXCEPTION_IF_NULL(a_prim); +- +- auto b_value_node = b_node->cast(); +- MS_EXCEPTION_IF_NULL(b_value_node); +- auto b_value = b_value_node->value(); +- MS_EXCEPTION_IF_NULL(b_value); +- auto b_prim = b_value->cast(); +- MS_EXCEPTION_IF_NULL(b_prim); +- +- return a_prim->name() == b_prim->name(); +- } else if (a_node->isa() && b_node->isa()) { +- auto a_value_node_ptr = a_node->cast(); +- if (a_value_node_ptr == nullptr) { +- MS_LOG(INTERNAL_EXCEPTION) << "Cast value node ptr fail, node: " << a_node->DebugString(); +- } +- auto a_value_ptr = a_value_node_ptr->value(); +- if (a_value_ptr == nullptr) { +- MS_LOG(INTERNAL_EXCEPTION) << "Value ptr is nullptr, node: " << a_node->DebugString(); +- } +- +- auto b_value_node_ptr = b_node->cast(); +- if (b_value_node_ptr == nullptr) { +- MS_LOG(INTERNAL_EXCEPTION) << "Cast value node ptr fail, node: " << b_node->DebugString(); +- } +- auto b_value_ptr = b_value_node_ptr->value(); +- if (b_value_ptr == nullptr) { +- MS_LOG(INTERNAL_EXCEPTION) << "Value ptr is nullptr, node: " << b_node->DebugString(); +- } +- if (a_value_ptr->isa() && b_value_ptr->isa()) { +- auto a_tensor_ptr = a_value_ptr->cast(); +- auto b_tensor_ptr = b_value_ptr->cast(); +- if (a_tensor_ptr == nullptr || b_tensor_ptr == nullptr) { +- MS_LOG(INTERNAL_EXCEPTION) << "Cast value node ptr fail."; +- } +- return a_tensor_ptr->ValueEqual(*b_tensor_ptr); +- } else { +- return (*a_value_ptr) == (*b_value_ptr); +- } +- } +- MS_LOG(DEBUG) << "check AnfNodePtr equal"; +- } +- if (utils::isa(a) && utils::isa(b)) { +- MS_LOG(DEBUG) << "check GraphPtr equal"; +- } +- return a == b; +-} +- + bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) { + // To matchCNode and Kernel's type + if (utils::isa(a) && utils::isa(b)) { +diff --git a/mindspore/ccsrc/backend/common/optimizer/inplace_node_pass.cc b/mindspore/ccsrc/backend/common/optimizer/inplace_node_pass.cc +index d5d52bad5e7..66e25d37e15 100644 +--- a/mindspore/ccsrc/backend/common/optimizer/inplace_node_pass.cc ++++ b/mindspore/ccsrc/backend/common/optimizer/inplace_node_pass.cc +@@ -15,7 +15,7 @@ + */ + + #include "include/backend/optimizer/inplace_node_pass.h" +-#include "include/backend/optimizer/helper.h" ++#include "utils/anf_utils.h" + + namespace mindspore { + namespace opt { +@@ -40,7 +40,7 @@ AnfNodePtr InplaceNodePass::Run(const FuncGraphPtr &, const AnfNodePtr &node) { + for (size_t i = 0; i < inputs.size(); i++) { + MS_EXCEPTION_IF_NULL(inputs[i]); + MS_EXCEPTION_IF_NULL(pre_inputs[i]); +- if (!opt::AnfEqual(inputs[i], pre_inputs[i])) { ++ if (!AnfUtils::AnfEqual(inputs[i], pre_inputs[i])) { + MS_LOG_WITH_NODE(EXCEPTION, node) + << "InplaceNodePass ERROR, the pass modify node: " << node->DebugString() << ", pass name: " << name() + << ", before node " << i << ":" << inputs[i]->DebugString() << ", after node " << i << ":" +diff --git a/mindspore/ccsrc/backend/common/optimizer/pattern_engine.cc b/mindspore/ccsrc/backend/common/optimizer/pattern_engine.cc +index 3e0ab70388a..ffe9d19c6a5 100644 +--- a/mindspore/ccsrc/backend/common/optimizer/pattern_engine.cc ++++ b/mindspore/ccsrc/backend/common/optimizer/pattern_engine.cc +@@ -18,6 +18,7 @@ + #include "ir/anf.h" + #include "utils/convert_utils_base.h" + #include "include/backend/optimizer/helper.h" ++#include "utils/anf_utils.h" + + namespace mindspore { + static int GetNextTag() { +@@ -270,7 +271,7 @@ EquivPtr PatternEngine::Match(const BaseRef &pattern, const BaseRef &expr, const + } + + // 2. check equal +- if (opt::AnfEqual(pattern_ref, expr_ref)) { ++ if (AnfUtils::AnfEqual(pattern_ref, expr_ref)) { + return equiv; + } + +diff --git a/mindspore/ccsrc/backend/common/optimizer/pattern_to_pattern.cc b/mindspore/ccsrc/backend/common/optimizer/pattern_to_pattern.cc +index 8fb8b66e61c..0ee3bbb32f8 100644 +--- a/mindspore/ccsrc/backend/common/optimizer/pattern_to_pattern.cc ++++ b/mindspore/ccsrc/backend/common/optimizer/pattern_to_pattern.cc +@@ -75,7 +75,7 @@ bool PatternMap::Emplace(const std::string &name, const AnfNodePtr &node) { + auto iter = node_map_.find(name); + if (iter == node_map_.end()) { + node_map_.emplace(name, node); +- } else if (!opt::AnfEqual(node, iter->second)) { ++ } else if (!AnfUtils::AnfEqual(node, iter->second)) { + MS_EXCEPTION_IF_NULL(iter->second); + MS_LOG(INFO) << "The value of key: " << name + << " is not equal to origin value, value: " + node->fullname_with_scope() +@@ -110,7 +110,7 @@ bool PatternMap::Emplace(const std::string &name, const std::vector + for (size_t i = 0; i < v.size(); i++) { + MS_EXCEPTION_IF_NULL(v[i]); + MS_EXCEPTION_IF_NULL(origin_v[i]); +- if (!opt::AnfEqual(v[i], origin_v[i])) { ++ if (!AnfUtils::AnfEqual(v[i], origin_v[i])) { + MS_LOG(INFO) << "The value of key: " << name + << " is not equal to origin value, value: " + v[i]->fullname_with_scope() + << " origin value: " << origin_v[i]->fullname_with_scope(); +@@ -127,7 +127,9 @@ void PatternMap::Clear() { + seq_map_.clear(); + } + +-bool PatternMap::Check(const std::string &name, const AnfNodePtr &node) const { return opt::AnfEqual(node, Get(name)); } ++bool PatternMap::Check(const std::string &name, const AnfNodePtr &node) const { ++ return AnfUtils::AnfEqual(node, Get(name)); ++} + + SrcPattern &SrcPattern::AddVar(const std::string &name, const PatternConditionFunc &f) { + if (ref_map_.find(name) != ref_map_.end()) { +@@ -256,7 +258,7 @@ bool SrcPattern::match(const std::string &name, const AnfNodePtr &node, const Eq + // prim + MS_EXCEPTION_IF_NULL(pattern_node.p_); + MS_EXCEPTION_IF_NULL(match_node); +- if (!opt::AnfEqual(pattern_node.p_, match_node)) { ++ if (!AnfUtils::AnfEqual(pattern_node.p_, match_node)) { + MS_LOG(EXCEPTION) << "The value of Primitive is not equal to matched value, pattern value: " + + pattern_node.p_->ToString() + << " matched value: " + match_node->ToString() + ", node name: " + name; +@@ -367,7 +369,7 @@ DstPattern &DstPattern::AddCNode(const string &name, const std::initializer_list + for (size_t i = 0; i < anf_inputs.size(); i++) { + MS_EXCEPTION_IF_NULL(anf_inputs[i]); + MS_EXCEPTION_IF_NULL(cnode->input(i)); +- if (!opt::AnfEqual(anf_inputs[i], cnode->input(i))) { ++ if (!AnfUtils::AnfEqual(anf_inputs[i], cnode->input(i))) { + MS_LOG(INTERNAL_EXCEPTION) + << "The actual input does not correspond to the input of the pattern, the input index: " << i + << ", actual input: " << anf_inputs[i]->DebugString() +diff --git a/mindspore/ccsrc/backend/common/pass/custom_defined_depend.cc b/mindspore/ccsrc/backend/common/pass/custom_defined_depend.cc +index 1e26f3dfe5f..a5a312b2a55 100644 +--- a/mindspore/ccsrc/backend/common/pass/custom_defined_depend.cc ++++ b/mindspore/ccsrc/backend/common/pass/custom_defined_depend.cc +@@ -58,13 +58,13 @@ bool FileExists(const string &filename) { + + std::string GetRankID() { + uint32_t rank_id = 0; +-#if !defined(BUILD_LITE) ++//#if !defined(BUILD_LITE) + if (distributed::collective::CollectiveManager::instance()->initialized()) { + rank_id = CommManager::GetInstance().GetRank(); + } else { + rank_id = MsContext::GetInstance()->get_param(MS_CTX_DEVICE_ID); + } +-#endif ++//#endif + return std::to_string(rank_id); + } + +diff --git a/mindspore/ccsrc/backend/common/pass/other/add_attr_to_dump.cc b/mindspore/ccsrc/backend/common/pass/other/add_attr_to_dump.cc +index afe0eeac9bf..a617e5d073c 100644 +--- a/mindspore/ccsrc/backend/common/pass/other/add_attr_to_dump.cc ++++ b/mindspore/ccsrc/backend/common/pass/other/add_attr_to_dump.cc +@@ -69,7 +69,6 @@ const AnfNodePtr AddAttrToDump::Process(const FuncGraphPtr &func_graph, const An + auto primitive = GetCNodePrimitive(cnode); + MS_EXCEPTION_IF_NULL(primitive); + +-#if !defined(BUILD_LITE) + if (common::GetDumpSliceSize() > 0) { + constexpr int64_t kMegaBytes = 1LL << 20; + int64_t slice_size_in_bytes = common::GetDumpSliceSize() * kMegaBytes; +@@ -77,7 +76,6 @@ const AnfNodePtr AddAttrToDump::Process(const FuncGraphPtr &func_graph, const An + (void)primitive->AddAttr("wait_time", MakeValue(common::GetDumpWaitTime())); + (void)primitive->AddAttr("slice_sync", MakeValue(true)); + } +-#endif + + return cnode; + } +diff --git a/mindspore/ccsrc/backend/ge_backend/CMakeLists.txt b/mindspore/ccsrc/backend/ge_backend/CMakeLists.txt +index d88d0073397..341fe5d59c2 100644 +--- a/mindspore/ccsrc/backend/ge_backend/CMakeLists.txt ++++ b/mindspore/ccsrc/backend/ge_backend/CMakeLists.txt +@@ -7,7 +7,7 @@ if(ENABLE_D OR ENABLE_ACL) + list(APPEND _GE_BACKEND_SRC_LIST $) + list(APPEND _GE_BACKEND_SRC_LIST $) + list(APPEND _GE_BACKEND_SRC_LIST $) +- ++ list(APPEND _GE_BACKEND_SRC_LIST "${CMAKE_CURRENT_SOURCE_DIR}/graph_ir/callbacks_ge.cc") + + set_property(SOURCE ${_GE_BACKEND_SRC_LIST} + PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_RUNTIME_FRAMEWORK) +diff --git a/mindspore/ccsrc/backend/ge_backend/graph_ir/CMakeLists.txt b/mindspore/ccsrc/backend/ge_backend/graph_ir/CMakeLists.txt +index 9ab80751d81..006c794587e 100644 +--- a/mindspore/ccsrc/backend/ge_backend/graph_ir/CMakeLists.txt ++++ b/mindspore/ccsrc/backend/ge_backend/graph_ir/CMakeLists.txt +@@ -1,27 +1,54 @@ + if(ENABLE_D OR ENABLE_ACL) + file(GLOB_RECURSE _TRANSFORM_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +- if(BUILD_LITE) +- list(REMOVE_ITEM _TRANSFORM_SRC_LIST "callbacks_ge.cc") +- endif() + set_property(SOURCE ${_TRANSFORM_SRC_LIST} PROPERTY COMPILE_DEFINITIONS + SUBMODULE_ID=mindspore::SubModuleId::SM_GE_ADPT) + +- # mindspore_graph_ir is used by GE and lite. +- if(BUILD_LITE) ++ # mindspore_graph_ir is used by GE and lite. ++ if(TARGET mindspore_ascend_res_manager) ++ add_library(mindspore_graph_ir SHARED ${_TRANSFORM_SRC_LIST}) ++ target_link_libraries(mindspore_graph_ir PRIVATE mindspore_ascend_res_manager) ++ else() ++ file(STRINGS "${CMAKE_CURRENT_SOURCE_DIR}/../../../../../version.txt" VERSION) ++ add_definitions(-DVERSION="${VERSION}") ++ list(REMOVE_ITEM _TRANSFORM_SRC_LIST "callbacks_ge.cc") + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../../plugin/res_manager/ascend/op_adapter/ + _mindspore_ascend_op_adapter_obj) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../../plugin/res_manager/ascend/symbol_interface + _mindspore_ascend_symbol_obj) + list(APPEND _TRANSFORM_SRC_LIST "${CMAKE_CURRENT_SOURCE_DIR}/../../../utils/config_manager.cc") + list(APPEND _TRANSFORM_SRC_LIST "${CMAKE_CURRENT_SOURCE_DIR}/../../../common/debug/common.cc") ++ ++ list(APPEND _TRANSFORM_SRC_LIST ++ "${CMAKE_CURRENT_SOURCE_DIR}/../../../backend/common/optimizer/graph_optimizer.cc") ++ list(APPEND _TRANSFORM_SRC_LIST ++ "${CMAKE_CURRENT_SOURCE_DIR}/../../../backend/common/optimizer/pattern_engine.cc") ++ list(APPEND _TRANSFORM_SRC_LIST "${CMAKE_CURRENT_SOURCE_DIR}/../../../backend/common/optimizer/visitor.cc") ++ ++ if(NOT WIN32) ++ list(APPEND _TRANSFORM_SRC_LIST "${CMAKE_CURRENT_SOURCE_DIR}/../../../utils/anfalgo.cc") ++ list(APPEND _TRANSFORM_SRC_LIST "${CMAKE_CURRENT_SOURCE_DIR}/../../../utils/utils.cc") ++ list(APPEND _TRANSFORM_SRC_LIST "${CMAKE_CURRENT_SOURCE_DIR}/../../../utils/parallel_context.cc") ++ endif() ++ list(APPEND _TRANSFORM_SRC_LIST "${CMAKE_CURRENT_SOURCE_DIR}/../../../utils/convert_utils.cc") ++ ++ list(APPEND _TRANSFORM_SRC_LIST "${CMAKE_CURRENT_SOURCE_DIR}/../../../kernel/kernel_info.cc") ++ list(APPEND _TRANSFORM_SRC_LIST "${CMAKE_CURRENT_SOURCE_DIR}/../../../utils/comm_manager.cc") ++ list(APPEND _TRANSFORM_SRC_LIST "${CMAKE_CURRENT_SOURCE_DIR}/../../../utils/compile_cache_context.cc") ++ list(APPEND _TRANSFORM_SRC_LIST "${CMAKE_CURRENT_SOURCE_DIR}/../../../common/debug/mindir_exporter.cc") ++ ++ list(APPEND _TRANSFORM_SRC_LIST "${CMAKE_CURRENT_SOURCE_DIR}/../../../../ops/kernel/common/format_utils.cc") ++ list(APPEND _TRANSFORM_SRC_LIST "${CMAKE_CURRENT_SOURCE_DIR}/../../../../ops/kernel/common/kernel_factory.cc") ++ list(APPEND _TRANSFORM_SRC_LIST "${CMAKE_CURRENT_SOURCE_DIR}/../../../../ops/kernel/common/kernel_tensor.cc") ++ list(APPEND _TRANSFORM_SRC_LIST "${CMAKE_CURRENT_SOURCE_DIR}/../../../../ops/kernel/common/device_address.cc") ++ list(APPEND _TRANSFORM_SRC_LIST "${CMAKE_CURRENT_SOURCE_DIR}/../../../../ops/kernel/common/kernel.cc") ++ list(APPEND _TRANSFORM_SRC_LIST ++ "${CMAKE_CURRENT_SOURCE_DIR}/../../../../ops/kernel/common/kernel_build_info.cc") ++ list(APPEND _TRANSFORM_SRC_LIST "${CMAKE_CURRENT_SOURCE_DIR}/../../../minddata/dataset/core/types.cc") + add_library(mindspore_graph_ir SHARED ${_TRANSFORM_SRC_LIST} $ + $) +- else() +- add_library(mindspore_graph_ir SHARED ${_TRANSFORM_SRC_LIST}) +- target_link_libraries(mindspore_graph_ir PRIVATE mindspore_ascend_res_manager) + endif() + +- target_link_libraries(mindspore_graph_ir PRIVATE mindspore_core mindspore_ops) ++ target_link_libraries(mindspore_graph_ir PRIVATE mindspore_core mindspore_ops mindspore::protobuf) + find_library(ACL ascendcl ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) + find_library(GE_RUNNER ge_runner ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) + find_library(GRAPH graph ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) +diff --git a/mindspore/ccsrc/backend/ge_backend/runtime/actor/actor_common.cc b/mindspore/ccsrc/backend/ge_backend/runtime/actor/actor_common.cc +index b77d1e04574..e693f4c8cc6 100644 +--- a/mindspore/ccsrc/backend/ge_backend/runtime/actor/actor_common.cc ++++ b/mindspore/ccsrc/backend/ge_backend/runtime/actor/actor_common.cc +@@ -25,10 +25,8 @@ + #include "utils/ms_utils.h" + #include "include/common/utils/anfalgo.h" + #include "include/backend/mem_reuse/mem_tracker.h" +-#ifndef BUILD_LITE + #include "backend/ge_backend/runtime/actor/memory_manager_actor.h" + #include "backend/ge_backend/utils/device_address_utils.h" +-#endif + #include "runtime/device/res_manager/hal_res_manager.h" + + namespace mindspore { +diff --git a/mindspore/ccsrc/cmake/ascend_compile_config.cmake b/mindspore/ccsrc/cmake/ascend_compile_config.cmake +index d72efcde036..cfb774ee751 100644 +--- a/mindspore/ccsrc/cmake/ascend_compile_config.cmake ++++ b/mindspore/ccsrc/cmake/ascend_compile_config.cmake +@@ -1,9 +1,9 @@ + include(${CMAKE_SOURCE_DIR}/cmake/graphengine_variables.cmake) + include_directories(${CMAKE_CURRENT_SOURCE_DIR}/plugin/device/ascend) +-add_subdirectory(backend/ge_backend) + add_subdirectory(plugin/res_manager/ascend) + add_subdirectory(plugin/res_manager/ascend/collective) + add_subdirectory(plugin/device/ascend) ++add_subdirectory(backend/ge_backend) + enable_directory_when_only_build_plugins(plugin/device/ascend) + enable_directory_when_only_build_plugins(plugin/res_manager/ascend/collective) + enable_directory_when_only_build_plugins(plugin/res_manager/ascend/hccl_adapter/plugin) +diff --git a/mindspore/ccsrc/debug/execute_order_tracker/execute_order_tracker.cc b/mindspore/ccsrc/debug/execute_order_tracker/execute_order_tracker.cc +index 40487dbd566..89a298f7c1c 100644 +--- a/mindspore/ccsrc/debug/execute_order_tracker/execute_order_tracker.cc ++++ b/mindspore/ccsrc/debug/execute_order_tracker/execute_order_tracker.cc +@@ -237,18 +237,12 @@ std::vector ExecuteOrderTracker::GetCommRanks(const std::string &group + if (group_name == "hccl_world_group") { + uint32_t rank_size = 1; + +-#if !defined(BUILD_LITE) + rank_size = distributed::collective::CollectiveManager::instance()->global_rank_size(); +-#endif + + comm_ranks.resize(rank_size); + std::iota(comm_ranks.begin(), comm_ranks.end(), 0); + } else { +-#if !defined(BUILD_LITE) + comm_ranks = distributed::collective::CollectiveManager::instance()->GetGroupRanks(group_name); +-#else +- comm_ranks = {0}; +-#endif + } + + comm_ranks_cache_[group_name] = comm_ranks; +diff --git a/mindspore/ccsrc/include/backend/data_queue/data_queue_mgr.h b/mindspore/ccsrc/include/backend/data_queue/data_queue_mgr.h +index 3f3f1055983..1ca2cc21296 100644 +--- a/mindspore/ccsrc/include/backend/data_queue/data_queue_mgr.h ++++ b/mindspore/ccsrc/include/backend/data_queue/data_queue_mgr.h +@@ -28,10 +28,8 @@ + #include "utils/callback_handler.h" + #include "include/backend/visible.h" + #include "include/backend/data_queue/data_queue.h" +-#ifndef BUILD_LITE + #include "ir/anf.h" + #include "common/kernel.h" +-#endif + + namespace mindspore { + namespace device { +@@ -143,7 +141,6 @@ class BACKEND_EXPORT DataQueueMgr { + HANDLER_DEFINE(bool, CleanTdtHandle); + HANDLER_DEFINE(bool, DestoryTdtHandle); + }; +-#ifndef BUILD_LITE + BACKEND_EXPORT void UpdateGetNextNode(const AnfNodePtr &data_kernel); + + BACKEND_EXPORT void UpdateGetNextNode(const PrimitivePtr &primitive, const std::vector &inputs, +@@ -161,7 +158,6 @@ BACKEND_EXPORT void UpdateGetNextWithDataQueueItems(const std::vector &data_queue, + std::vector *data); +-#endif + #define REGISTER_DATA_QUEUE_CREATOR(device_name, creator) \ + struct device_name##DataQueueCreatorClass { \ + device_name##DataQueueCreatorClass() { \ +diff --git a/mindspore/ccsrc/include/backend/optimizer/helper.h b/mindspore/ccsrc/include/backend/optimizer/helper.h +index 554d2c655bc..39f03561237 100644 +--- a/mindspore/ccsrc/include/backend/optimizer/helper.h ++++ b/mindspore/ccsrc/include/backend/optimizer/helper.h +@@ -218,8 +218,6 @@ BACKEND_COMMON_EXPORT std::shared_ptr>> G + const FuncGraphPtr &graph, const AnfNodePtr &node, size_t output_index); + bool IsNotRealUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node); + +-bool AnfEqual(const BaseRef &a, const BaseRef &b); +- + bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b); + + AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, +diff --git a/mindspore/ccsrc/include/common/thread_pool.h b/mindspore/ccsrc/include/common/thread_pool.h +index 86d96e6d620..f9fb681e74c 100644 +--- a/mindspore/ccsrc/include/common/thread_pool.h ++++ b/mindspore/ccsrc/include/common/thread_pool.h +@@ -31,9 +31,6 @@ + #include "utils/log_adapter.h" + #include "include/common/visible.h" + +-#ifdef PARALLEL_INFERENCE +-#define BIND_CORE +-#endif + #ifdef __ANDROID__ + #define BIND_CORE + #include +diff --git a/mindspore/ccsrc/kernel/graph_kernel/graph_kernel_json_generator.cc b/mindspore/ccsrc/kernel/graph_kernel/graph_kernel_json_generator.cc +index d98e0d2678e..97e0c29efd0 100644 +--- a/mindspore/ccsrc/kernel/graph_kernel/graph_kernel_json_generator.cc ++++ b/mindspore/ccsrc/kernel/graph_kernel/graph_kernel_json_generator.cc +@@ -29,14 +29,14 @@ + #include "backend/common/graph_kernel/graph_kernel_flags.h" + #include "kernel/graph_kernel/graph_kernel_json_flags.h" + #include "include/common/symbol_engine/symbol_engine_impl.h" +-#ifdef MSLITE_ENABLE_GRAPH_KERNEL +-#ifdef ENABLE_GPU +-#include +-#endif +-#else ++//#ifdef MSLITE_ENABLE_GRAPH_KERNEL ++//#ifdef ENABLE_GPU ++//#include ++//#endif ++//#else + #include "common/oplib/oplib.h" + #include "runtime/hardware/device_context_manager.h" +-#endif ++//#endif + #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" + + namespace mindspore::graphkernel { +@@ -773,9 +773,9 @@ OpInfoPtr GraphKernelJsonGenerator::ExtractOpInfo(const AnfNodePtr &anf_node) co + OpInfoExtractor e; + return e.Run(anf_node); + } else { +-#ifdef MSLITE_ENABLE_GRAPH_KERNEL +- MS_LOG(EXCEPTION) << "OpLib is not supported."; +-#else ++ //#ifdef MSLITE_ENABLE_GRAPH_KERNEL ++ // MS_LOG(EXCEPTION) << "OpLib is not supported."; ++ //#else + OpImplyType imply_type; + const auto &flags = GraphKernelFlags::GetInstance(); + +@@ -785,7 +785,7 @@ OpInfoPtr GraphKernelJsonGenerator::ExtractOpInfo(const AnfNodePtr &anf_node) co + imply_type = OpImplyType::kImplyAKG; + } + return kernel::OpLib::FindOp(AnfUtils::GetCNodeName(anf_node), imply_type); +-#endif ++ //#endif + } + } + +@@ -1385,45 +1385,45 @@ void GetCpuInfo(nlohmann::json *target_info) { + return; + } + +-#ifdef MSLITE_ENABLE_GRAPH_KERNEL +-#ifdef ENABLE_GPU +-bool GetGpuInfo(nlohmann::json *target_info) { +- int major_version = -1; +- auto ret = cuDeviceGetAttribute(&major_version, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, 0); +- if (ret != CUDA_SUCCESS) { +- const char *msg = nullptr; +- cuGetErrorName(ret, &msg); +- MS_LOG(WARNING) << "Get CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR fail, error message: " << msg; +- return false; +- } +- int minor_version = -1; +- auto ret = cuDeviceGetAttribute(&minor_version, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, 0); +- if (ret != CUDA_SUCCESS) { +- const char *msg = nullptr; +- cuGetErrorName(ret, &msg); +- MS_LOG(WARNING) << "Get CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR fail, error message: " << msg; +- return false; +- } +- int sm_count = -1; +- auto ret = cuDeviceGetAttribute(&sm_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, 0); +- if (ret != CUDA_SUCCESS) { +- const char *msg = nullptr; +- cuGetErrorName(ret, &msg); +- MS_LOG(WARNING) << "Get CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT fail, error message: " << msg; +- return false; +- } +- if (major_version == -1 || minor_version == -1 || sm_count == -1) { +- return false; +- } else { +- (*target_info)[kJsonKeyComputeCapability] = std::to_string(major_version) + "." + std::to_string(minor_version); +- (*target_info)[kJsonKeySmCount] = sm_count; +- } +- return true; +-} +-#else +-bool GetGpuInfo(nlohmann::json *target_info) { return false; } +-#endif +-#else ++//#ifdef MSLITE_ENABLE_GRAPH_KERNEL ++//#ifdef ENABLE_GPU ++// bool GetGpuInfo(nlohmann::json *target_info) { ++// int major_version = -1; ++// auto ret = cuDeviceGetAttribute(&major_version, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, 0); ++// if (ret != CUDA_SUCCESS) { ++// const char *msg = nullptr; ++// cuGetErrorName(ret, &msg); ++// MS_LOG(WARNING) << "Get CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR fail, error message: " << msg; ++// return false; ++// } ++// int minor_version = -1; ++// auto ret = cuDeviceGetAttribute(&minor_version, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, 0); ++// if (ret != CUDA_SUCCESS) { ++// const char *msg = nullptr; ++// cuGetErrorName(ret, &msg); ++// MS_LOG(WARNING) << "Get CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR fail, error message: " << msg; ++// return false; ++// } ++// int sm_count = -1; ++// auto ret = cuDeviceGetAttribute(&sm_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, 0); ++// if (ret != CUDA_SUCCESS) { ++// const char *msg = nullptr; ++// cuGetErrorName(ret, &msg); ++// MS_LOG(WARNING) << "Get CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT fail, error message: " << msg; ++// return false; ++// } ++// if (major_version == -1 || minor_version == -1 || sm_count == -1) { ++// return false; ++// } else { ++// (*target_info)[kJsonKeyComputeCapability] = std::to_string(major_version) + "." + std::to_string(minor_version); ++// (*target_info)[kJsonKeySmCount] = sm_count; ++// } ++// return true; ++//} ++//#else ++// bool GetGpuInfo(nlohmann::json *target_info) { return false; } ++//#endif ++//#else + bool GetGpuInfo(nlohmann::json *target_info) { + const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext( + {kGPUDevice, MsContext::GetInstance()->get_param(MS_CTX_DEVICE_ID)}); +@@ -1441,7 +1441,7 @@ bool GetGpuInfo(nlohmann::json *target_info) { + } + return true; + } +-#endif ++//#endif + } // namespace + + void TargetInfoSetter::GetTargetInfo() { +diff --git a/mindspore/ccsrc/minddata/dataset/api/vision.cc b/mindspore/ccsrc/minddata/dataset/api/vision.cc +index 435128d3eeb..034bc9422c0 100644 +--- a/mindspore/ccsrc/minddata/dataset/api/vision.cc ++++ b/mindspore/ccsrc/minddata/dataset/api/vision.cc +@@ -93,12 +93,6 @@ + #include "minddata/dataset/kernels/ir/vision/vertical_flip_ir.h" + #include "minddata/dataset/util/log_adapter.h" + +-#if defined(ENABLE_MINDDATA_PYTHON) +-#include "minddata/dataset/kernels/ir/vision/pad_ir.h" +-#include "minddata/dataset/kernels/ir/vision/rescale_ir.h" +-#include "minddata/dataset/kernels/ir/vision/swap_red_blue_ir.h" +-#endif +- + #include "minddata/dataset/kernels/image/image_utils.h" + #if defined(ENABLE_FFMPEG) + #include "minddata/dataset/kernels/image/video_utils.h" +@@ -650,12 +644,8 @@ Pad::Pad(const std::vector &padding, const std::vector &fill_v + : data_(std::make_shared(padding, fill_value, padding_mode)) {} + + std::shared_ptr Pad::Parse() { +-#if defined(ENABLE_MINDDATA_PYTHON) +- return std::make_shared(data_->padding_, data_->fill_value_, data_->padding_mode_); +-#else + MS_LOG(ERROR) << "Unsupported Pad."; + return nullptr; +-#endif + } + + // PadToSize Transform Operation. +@@ -1239,12 +1229,8 @@ struct Rescale::Data { + Rescale::Rescale(float rescale, float shift) : data_(std::make_shared(rescale, shift)) {} + + std::shared_ptr Rescale::Parse() { +-#if defined(ENABLE_MINDDATA_PYTHON) +- return std::make_shared(data_->rescale_, data_->shift_); +-#else + MS_LOG(ERROR) << "Unsupported Rescale."; + return nullptr; +-#endif + } + + // Resize Transform Operation. +@@ -1410,12 +1396,8 @@ std::shared_ptr Solarize::Parse() { return std::make_shared SwapRedBlue::Parse() { +-#if defined(ENABLE_MINDDATA_PYTHON) +- return std::make_shared(); +-#else + MS_LOG(ERROR) << "Unsupported SwapRedBlue."; + return nullptr; +-#endif + } + + // ToTensor Transform Operation. +diff --git a/mindspore/ccsrc/minddata/dataset/core/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/core/CMakeLists.txt +index d4de59f7a63..7c54cdb769e 100644 +--- a/mindspore/ccsrc/minddata/dataset/core/CMakeLists.txt ++++ b/mindspore/ccsrc/minddata/dataset/core/CMakeLists.txt +@@ -33,12 +33,12 @@ if(ENABLE_D) + ) + endif() + +-if(NOT MSLITE_ENABLE_ACL) +- set(DATASET_CORE_SRC_FILES +- ${DATASET_CORE_SRC_FILES} +- types.cc # in lite, src code has types.cc impl +- ) +-endif() ++#if(NOT MSLITE_ENABLE_ACL) ++# set(DATASET_CORE_SRC_FILES ++# ${DATASET_CORE_SRC_FILES} ++# types.cc # in lite, src code has types.cc impl ++# ) ++#endif() + + ms_protobuf_generate(EXAMPLE_SRCS EXAMPLE_HDRS example.proto) + ms_protobuf_generate(FEATURE_SRCS FEATURE_HDRS feature.proto) +diff --git a/mindspore/ccsrc/minddata/dataset/core/data_type.cc b/mindspore/ccsrc/minddata/dataset/core/data_type.cc +index 79ef80ab93b..98a72ca2abd 100644 +--- a/mindspore/ccsrc/minddata/dataset/core/data_type.cc ++++ b/mindspore/ccsrc/minddata/dataset/core/data_type.cc +@@ -14,10 +14,6 @@ + * limitations under the License. + */ + #include "minddata/dataset/core/data_type.h" +-#ifdef ENABLE_MINDDATA_PYTHON +-#include "minddata/dataset/core/pybind_support.h" +-#endif +- + #include "minddata/dataset/util/log_adapter.h" + + namespace mindspore { +@@ -30,64 +26,6 @@ uint8_t DataType::SizeInBytes() const { + } + } + +-#ifdef ENABLE_MINDDATA_PYTHON +-py::dtype DataType::AsNumpyType() const { +- if (type_ < DataType::NUM_OF_TYPES) { +- return py::dtype(kTypeInfo[type_].pybindType_); +- } else { +- return py::dtype("unknown"); +- } +-} +-#endif +- +-#if defined(ENABLE_MINDDATA_PYTHON) +-uint8_t DataType::AsCVType() const { +- uint8_t res = kCVInvalidType; +- if (type_ < DataType::NUM_OF_TYPES) { +- res = kTypeInfo[type_].cvType_; +- } +- +- if (res == kCVInvalidType) { +- std::string type_name = "unknown"; +- if (type_ < DataType::NUM_OF_TYPES) { +- type_name = std::string(kTypeInfo[type_].name_); +- } +- std::string err_msg = "Cannot convert [" + type_name + "] to OpenCV type."; +- err_msg += " Currently unsupported data type: [uint32, int64, uint64, string]"; +- MS_LOG(ERROR) << err_msg; +- } +- +- return res; +-} +- +-DataType DataType::FromCVType(int cv_type) { +- auto depth = static_cast(cv_type) & static_cast(CV_MAT_DEPTH_MASK); +- switch (depth) { +- case CV_8S: +- return DataType(DataType::DE_INT8); +- case CV_8U: +- return DataType(DataType::DE_UINT8); +- case CV_16S: +- return DataType(DataType::DE_INT16); +- case CV_16U: +- return DataType(DataType::DE_UINT16); +- case CV_32S: +- return DataType(DataType::DE_INT32); +- case CV_16F: +- return DataType(DataType::DE_FLOAT16); +- case CV_32F: +- return DataType(DataType::DE_FLOAT32); +- case CV_64F: +- return DataType(DataType::DE_FLOAT64); +- default: +- std::string err_msg = "Cannot convert from OpenCV type, unknown CV type."; +- err_msg += " Currently supported data type: [int8, uint8, int16, uint16, int32, float16, float32, float64]"; +- MS_LOG(ERROR) << err_msg; +- return DataType(DataType::DE_UNKNOWN); +- } +-} +-#endif +- + DataType::DataType(const std::string &type_str) { + if (type_str == "bool") { + type_ = DE_BOOL; +@@ -117,10 +55,6 @@ DataType::DataType(const std::string &type_str) { + type_ = DE_STRING; + } else if (type_str == "bytes") { + type_ = DE_BYTES; +-#ifdef ENABLE_MINDDATA_PYTHON +- } else if (type_str == "python") { +- type_ = DE_PYTHON; +-#endif + } else { + type_ = DE_UNKNOWN; + } +@@ -133,61 +67,5 @@ std::string DataType::ToString() const { + return "unknown"; + } + } +- +-#ifdef ENABLE_MINDDATA_PYTHON +-DataType DataType::FromNpArray(const py::array &arr) { +- if (py::isinstance>(arr)) { +- return DataType(DataType::DE_BOOL); +- } else if (py::isinstance>(arr)) { +- return DataType(DataType::DE_INT8); +- } else if (py::isinstance>(arr)) { +- return DataType(DataType::DE_UINT8); +- } else if (py::isinstance>(arr)) { +- return DataType(DataType::DE_INT16); +- } else if (py::isinstance>(arr)) { +- return DataType(DataType::DE_UINT16); +- } else if (py::isinstance>(arr)) { +- return DataType(DataType::DE_INT32); +- } else if (py::isinstance>(arr)) { +- return DataType(DataType::DE_UINT32); +- } else if (py::isinstance>(arr)) { +- return DataType(DataType::DE_INT64); +- } else if (py::isinstance>(arr)) { +- return DataType(DataType::DE_UINT64); +- } else if (py::isinstance>(arr)) { +- return DataType(DataType::DE_FLOAT16); +- } else if (py::isinstance>(arr)) { +- return DataType(DataType::DE_FLOAT32); +- } else if (py::isinstance>(arr)) { +- return DataType(DataType::DE_FLOAT64); +- } else if (arr.dtype().kind() == 'U') { +- return DataType(DataType::DE_STRING); +- } else if (arr.dtype().kind() == 'S') { +- return DataType(DataType::DE_BYTES); +- } else { +- if (arr.size() == 0) { +- MS_LOG(ERROR) << "Please check input data, the data of numpy array is empty."; +- } +- std::string err_msg = "Cannot convert from numpy type. Unknown data type is returned!"; +- err_msg += +- " Currently supported data type: [int8, uint8, int16, uint16, int32, uint32, int64, uint64, float16, float32, " +- "float64, string, bytes]"; +- MS_LOG(ERROR) << err_msg; +- return DataType(DataType::DE_UNKNOWN); +- } +-} +- +-std::string DataType::GetPybindFormat() const { +- std::string res; +- if (type_ < DataType::NUM_OF_TYPES) { +- res = kTypeInfo[type_].pybindFormatDescriptor_; +- } +- +- if (res.empty()) { +- MS_LOG(ERROR) << "Cannot convert from data type to pybind format descriptor!"; +- } +- return res; +-} +-#endif + } // namespace dataset + } // namespace mindspore +diff --git a/mindspore/ccsrc/minddata/dataset/core/data_type.h b/mindspore/ccsrc/minddata/dataset/core/data_type.h +index 1bfa80cbe10..22450e42e2d 100644 +--- a/mindspore/ccsrc/minddata/dataset/core/data_type.h ++++ b/mindspore/ccsrc/minddata/dataset/core/data_type.h +@@ -16,22 +16,10 @@ + #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_DATA_TYPE_H_ + #define MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_DATA_TYPE_H_ + +-#if defined(ENABLE_MINDDATA_PYTHON) +-#include +-#endif +- + #include + #include +- +-#ifdef ENABLE_MINDDATA_PYTHON +-#include "pybind11/numpy.h" +-#include "pybind11/pybind11.h" +-#include "minddata/dataset/core/pybind_support.h" +-namespace py = pybind11; +-#else + #include "base/bfloat16.h" + #include "base/float16.h" +-#endif + #include "minddata/dataset/include/dataset/constants.h" + + namespace mindspore { +@@ -67,48 +55,6 @@ class DataType { + const uint8_t cvType_; // OpenCv matching type + }; + +-#ifdef ENABLE_MINDDATA_PYTHON +- static inline const TypeInfo kTypeInfo[] = { +- // name, sizeInBytes, pybindType, pybindFormatDescriptor, openCV +- {"unknown", 0, "object", "", kCVInvalidType}, // DE_UNKNOWN +- {"bool", 1, "bool", py::format_descriptor::format(), CV_8U}, // DE_BOOL +- {"int8", 1, "int8", py::format_descriptor::format(), CV_8S}, // DE_INT8 +- {"uint8", 1, "uint8", py::format_descriptor::format(), CV_8U}, // DE_UINT8 +- {"int16", 2, "int16", py::format_descriptor::format(), CV_16S}, // DE_INT16 +- {"uint16", 2, "uint16", py::format_descriptor::format(), CV_16U}, // DE_UINT16 +- {"int32", 4, "int32", py::format_descriptor::format(), CV_32S}, // DE_INT32 +- {"uint32", 4, "uint32", py::format_descriptor::format(), kCVInvalidType}, // DE_UINT32 +- {"int64", 8, "int64", py::format_descriptor::format(), kCVInvalidType}, // DE_INT64 +- {"uint64", 8, "uint64", py::format_descriptor::format(), kCVInvalidType}, // DE_UINT64 +- {"float16", 2, "float16", "e", CV_16F}, // DE_FLOAT16 +- {"float32", 4, "float32", py::format_descriptor::format(), CV_32F}, // DE_FLOAT32 +- {"float64", 8, "double", py::format_descriptor::format(), CV_64F}, // DE_FLOAT64 +- {"string", 0, "str", "U", kCVInvalidType}, // DE_STRING +- {"bytes", 0, "bytes", "S", CV_8U}, // DE_BYTES +- {"python", 0, "object", "O", kCVInvalidType} // DE_PYTHON +- }; +-#else +-#if defined(ENABLE_MINDDATA_PYTHON) +- static inline const TypeInfo kTypeInfo[] = { +- // name, sizeInBytes, pybindTypem formatDescriptor, openCV +- {"unknown", 0, "object", "", kCVInvalidType}, // DE_UNKNOWN +- {"bool", 1, "bool", "", CV_8U}, // DE_BOOL +- {"int8", 1, "int8", "", CV_8S}, // DE_INT8 +- {"uint8", 1, "uint8", "", CV_8U}, // DE_UINT8 +- {"int16", 2, "int16", "", CV_16S}, // DE_INT16 +- {"uint16", 2, "uint16", "", CV_16U}, // DE_UINT16 +- {"int32", 4, "int32", "", CV_32S}, // DE_INT32 +- {"uint32", 4, "uint32", "", kCVInvalidType}, // DE_UINT32 +- {"int64", 8, "int64", "", kCVInvalidType}, // DE_INT64 +- {"uint64", 8, "uint64", "", kCVInvalidType}, // DE_UINT64 +- {"float16", 2, "float16", "", CV_16F}, // DE_FLOAT16 +- {"float32", 4, "float32", "", CV_32F}, // DE_FLOAT32 +- {"float64", 8, "double", "", CV_64F}, // DE_FLOAT64 +- {"string", 0, "str", "", kCVInvalidType}, // DE_STRING +- {"bytes", 0, "bytes", "", CV_8U}, // DE_BYTES +- {"python", 0, "object", "O", kCVInvalidType} // DE_PYTHON +- }; +-#else + // android and no python + static inline const TypeInfo kTypeInfo[] = { + // name, sizeInBytes, formatDescriptor +@@ -129,8 +75,6 @@ class DataType { + {"bytes", 0, "bytes", ""}, // DE_BYTES + {"python", 0, "object", "O", kCVInvalidType} // DE_PYTHON + }; +-#endif +-#endif + // No arg constructor to create an unknown shape + DataType() : type_(DE_UNKNOWN) {} + +@@ -165,17 +109,6 @@ class DataType { + /// \return the number of bytes of the type. + uint8_t SizeInBytes() const; + +-#if defined(ENABLE_MINDDATA_PYTHON) +- // Convert from DataType to OpenCV type +- /// \return +- uint8_t AsCVType() const; +- +- // Convert from OpenCV type to DataType +- /// \param cv_type +- /// \return +- static DataType FromCVType(int cv_type); +-#endif +- + // Returns a string representation of the type + /// \return + std::string ToString() const; +@@ -207,22 +140,6 @@ class DataType { + template + static DataType FromCType(); + +-#ifdef ENABLE_MINDDATA_PYTHON +- // Convert from DataType to Pybind type +- /// \return +- py::dtype AsNumpyType() const; +- +- // Convert from NP type to DataType +- /// \param type +- /// \return +- static DataType FromNpType(const py::dtype &type); +- +- // Convert from NP array to DataType +- /// \param py array +- /// \return +- static DataType FromNpArray(const py::array &arr); +-#endif +- + // Get the buffer string format of the current type. Used in pybind buffer protocol. + /// \return + std::string GetPybindFormat() const; +diff --git a/mindspore/ccsrc/minddata/dataset/core/global_context.cc b/mindspore/ccsrc/minddata/dataset/core/global_context.cc +index 1bf9f5dafd5..b2ed8f9a257 100644 +--- a/mindspore/ccsrc/minddata/dataset/core/global_context.cc ++++ b/mindspore/ccsrc/minddata/dataset/core/global_context.cc +@@ -57,9 +57,7 @@ Status GlobalContext::Init() { + + // Create some tensor allocators for the different types and hook them into the pool. + tensor_allocator_ = std::make_unique>(mem_pool_); +-#if defined(ENABLE_MINDDATA_PYTHON) + cv_tensor_allocator_ = std::make_unique>(mem_pool_); +-#endif + device_tensor_allocator_ = std::make_unique>(mem_pool_); + int_allocator_ = std::make_unique(mem_pool_); + profiler_manager_ = std::make_shared(); +diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/CMakeLists.txt +index e388521b19b..daa4fc5447d 100644 +--- a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/CMakeLists.txt ++++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/CMakeLists.txt +@@ -48,6 +48,6 @@ set(DVPP_IMAGE_SOURCE + endif() + + add_library(kernels-dvpp-image OBJECT ${DVPP_IMAGE_SOURCE}) +-if(ENABLE_ACL OR MSLITE_ENABLE_ACL) ++if(ENABLE_ACL) + add_subdirectory(utils) + endif() +diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/CMakeLists.txt +index 5aa83166bae..b880f5ab45f 100644 +--- a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/CMakeLists.txt ++++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/CMakeLists.txt +@@ -14,12 +14,12 @@ set(DVPP_UTILS_SRC + # plugin + acl_plugin.cc + ) +-if(NOT MSLITE_ENABLE_ACL) +- set(DVPP_UTILS_SRC +- ${DVPP_UTILS_SRC} +- acl_env_guard.cc # in lite, src code has acl_env_guard.cc impl +- ) +-endif() ++#if(NOT MSLITE_ENABLE_ACL) ++# set(DVPP_UTILS_SRC ++# ${DVPP_UTILS_SRC} ++# acl_env_guard.cc # in lite, src code has acl_env_guard.cc impl ++# ) ++#endif() + + if(ENABLE_D) + set(DVPP_UTILS_SRC +@@ -33,25 +33,25 @@ endif() + add_library(dvpp_utils SHARED ${DVPP_UTILS_SRC}) + enable_target_when_only_build_plugins(dvpp_utils) + +-if(MSLITE_ENABLE_ACL) +- find_library(acl_dvpp libacl_dvpp.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) +- find_library(acl libascendcl.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) +- find_library(nnopbase libnnopbase.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) +- find_library(acl_dvpp_op libacl_dvpp_op.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) +- find_library(acl_dvpp_mpi libacl_dvpp_mpi.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) +- # find acl_env_guard in ascend_kernel_plugin +- target_link_libraries(dvpp_utils PRIVATE ascend_kernel_plugin minddata-lite ${acl} ${acl_dvpp} mindspore_core ${nnopbase} ${acl_dvpp_op} ${acl_dvpp_mpi}) +-else() ++#if(MSLITE_ENABLE_ACL) ++# find_library(acl_dvpp libacl_dvpp.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) ++# find_library(acl libascendcl.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) ++# find_library(nnopbase libnnopbase.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) ++# find_library(acl_dvpp_op libacl_dvpp_op.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) ++# find_library(acl_dvpp_mpi libacl_dvpp_mpi.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) ++# # find acl_env_guard in ascend_kernel_plugin ++# target_link_libraries(dvpp_utils PRIVATE ascend_kernel_plugin minddata-lite ${acl} ${acl_dvpp} mindspore_core ${nnopbase} ${acl_dvpp_op} ${acl_dvpp_mpi}) ++#else() + find_library(acl_dvpp libacl_dvpp.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) + find_library(acl libascendcl.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) + find_library(nnopbase libnnopbase.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) + find_library(acl_dvpp_op libacl_dvpp_op.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) + find_library(acl_dvpp_mpi libacl_dvpp_mpi.so ${ASCEND_CANN_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH}) + target_link_libraries(dvpp_utils PRIVATE _c_dataengine ${acl} ${acl_dvpp} mindspore_core ${nnopbase} ${acl_dvpp_op} ${acl_dvpp_mpi}) +-endif() ++#endif() + add_dependencies(dvpp_utils _mindspore_ascend_symbol_obj) + target_link_libraries(dvpp_utils PRIVATE $) +- +-if(MSLITE_ENABLE_CLOUD_MIND_DATA) +- add_dependencies(dvpp_utils fbs_src) +-endif() ++# ++#if(MSLITE_ENABLE_CLOUD_MIND_DATA) ++# add_dependencies(dvpp_utils fbs_src) ++#endif() +diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.cc +index 6a1452f698b..e80b4204be1 100644 +--- a/mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.cc ++++ b/mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.cc +@@ -20,24 +20,13 @@ + #include + #include + +-#if defined(ENABLE_MINDDATA_PYTHON) +-#include +-#include +-#include +-#endif +- +-#if defined(ENABLE_MINDDATA_PYTHON) +-#include "minddata/dataset/core/cv_tensor.h" +-#endif + #include "minddata/dataset/core/tensor.h" + #include "minddata/dataset/core/tensor_shape.h" + #include "minddata/dataset/include/dataset/constants.h" + #include "minddata/dataset/kernels/image/lite_cv/image_process.h" + #include "minddata/dataset/kernels/image/lite_cv/lite_mat.h" + #include "minddata/dataset/kernels/image/math_utils.h" +-#if defined(ENABLE_MINDDATA_PYTHON) +-#include "minddata/dataset/kernels/image/resize_cubic_op.h" +-#endif ++ + #include "minddata/dataset/util/random.h" + + constexpr int64_t hw_shape = 2; +@@ -46,55 +35,6 @@ constexpr int64_t hwc_rank = 3; + #define MAX_INT_PRECISION 16777216 // float int precision is 16777216 + namespace mindspore { + namespace dataset { +-#if defined(ENABLE_MINDDATA_PYTHON) +-bool IsNonEmptyPNG(const std::shared_ptr &input) { +- const unsigned char kPngMagic[] = "\x89\x50\x4E\x47"; +- constexpr dsize_t kPngMagicLen = 4; +- return input->SizeInBytes() > kPngMagicLen && memcmp(input->GetBuffer(), kPngMagic, kPngMagicLen) == 0; +-} +- +-Status Rescale(const std::shared_ptr &input, std::shared_ptr *output, float rescale, float shift) { +- std::shared_ptr input_cv = CVTensor::AsCVTensor(input); +- if (!input_cv->mat().data) { +- RETURN_STATUS_UNEXPECTED("[Internal ERROR] Rescale: load image failed."); +- } +- cv::Mat input_image = input_cv->mat(); +- std::shared_ptr output_cv; +- RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), DataType(DataType::DE_FLOAT32), &output_cv)); +- try { +- input_image.convertTo(output_cv->mat(), CV_32F, rescale, shift); +- *output = std::static_pointer_cast(output_cv); +- } catch (const cv::Exception &e) { +- RETURN_STATUS_UNEXPECTED("Rescale: " + std::string(e.what())); +- } +- return Status::OK(); +-} +- +-Status SwapRedAndBlue(std::shared_ptr input, std::shared_ptr *output) { +- try { +- RETURN_IF_NOT_OK(ValidateImage(input, "SwapRedBlue", {3, 5, 11})); +- std::shared_ptr input_cv = CVTensor::AsCVTensor(std::move(input)); +- CHECK_FAIL_RETURN_UNEXPECTED( +- input_cv->shape().Size() > kChannelIndexHWC, +- "SwapRedAndBlue: rank of input data should be greater than:" + std::to_string(kChannelIndexHWC) + +- ", but got:" + std::to_string(input_cv->shape().Size())); +- int num_channels = static_cast(input_cv->shape()[kChannelIndexHWC]); +- if (input_cv->shape().Size() != kDefaultImageRank || num_channels != kDefaultImageChannel) { +- RETURN_STATUS_UNEXPECTED("SwapRedBlue: image shape should be in format, but got:" + +- input_cv->shape().ToString()); +- } +- std::shared_ptr output_cv; +- RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv)); +- +- cv::cvtColor(input_cv->mat(), output_cv->mat(), static_cast(cv::COLOR_BGR2RGB)); +- *output = std::static_pointer_cast(output_cv); +- return Status::OK(); +- } catch (const cv::Exception &e) { +- RETURN_STATUS_UNEXPECTED("SwapRedBlue: " + std::string(e.what())); +- } +-} +-#endif +- + bool IsNonEmptyJPEG(const std::shared_ptr &input) { + const unsigned char *kJpegMagic = (unsigned char *)"\xFF\xD8\xFF"; + constexpr size_t kJpegMagicLen = 3; +@@ -305,39 +245,11 @@ static LDataType GetLiteCVDataType(const DataType &data_type) { + } + } + +-#if defined(ENABLE_MINDDATA_PYTHON) +-Status DecodeCv(const std::shared_ptr &input, std::shared_ptr *output) { +- std::shared_ptr input_cv = CVTensor::AsCVTensor(input); +- if (!input_cv->mat().data) { +- RETURN_STATUS_UNEXPECTED("[Internal ERROR] Decode: load image failed."); +- } +- try { +- cv::Mat img_mat = cv::imdecode(input_cv->mat(), cv::IMREAD_COLOR | cv::IMREAD_IGNORE_ORIENTATION); +- if (img_mat.data == nullptr) { +- std::string err = "Decode: image decode failed."; +- RETURN_STATUS_UNEXPECTED(err); +- } +- cv::cvtColor(img_mat, img_mat, static_cast(cv::COLOR_BGR2RGB)); +- std::shared_ptr output_cv; +- const dsize_t rank_num = 3; +- RETURN_IF_NOT_OK(CVTensor::CreateFromMat(img_mat, rank_num, &output_cv)); +- *output = std::static_pointer_cast(output_cv); +- return Status::OK(); +- } catch (const cv::Exception &e) { +- RETURN_STATUS_UNEXPECTED("Decode: " + std::string(e.what())); +- } +-} +-#endif +- + Status Decode(const std::shared_ptr &input, std::shared_ptr *output) { + if (IsNonEmptyJPEG(input)) { + return JpegCropAndDecode(input, output); + } else { +-#if defined(ENABLE_MINDDATA_PYTHON) +- return DecodeCv(input, output); +-#else + RETURN_STATUS_UNEXPECTED("Decode: Decode only supports jpeg for android"); +-#endif + } + } + +@@ -465,94 +377,6 @@ Status Normalize(const std::shared_ptr &input, std::shared_ptr * + return Status::OK(); + } + +-#if defined(ENABLE_MINDDATA_PYTHON) +-int GetCVInterpolationMode(InterpolationMode mode) { +- switch (mode) { +- case InterpolationMode::kLinear: +- return static_cast(cv::InterpolationFlags::INTER_LINEAR); +- case InterpolationMode::kCubic: +- return static_cast(cv::InterpolationFlags::INTER_CUBIC); +- case InterpolationMode::kArea: +- return static_cast(cv::InterpolationFlags::INTER_AREA); +- case InterpolationMode::kNearestNeighbour: +- return static_cast(cv::InterpolationFlags::INTER_NEAREST); +- default: +- return static_cast(cv::InterpolationFlags::INTER_LINEAR); +- } +-} +- +-Status Resize(const std::shared_ptr &input, std::shared_ptr *output, int32_t output_height, +- int32_t output_width, double fx, double fy, InterpolationMode mode) { +- std::shared_ptr input_cv = CVTensor::AsCVTensor(input); +- if (!input_cv->mat().data) { +- RETURN_STATUS_UNEXPECTED("[Internal ERROR] Resize: load image failed."); +- } +- RETURN_IF_NOT_OK(ValidateImageRank("Resize", input_cv->Rank())); +- +- cv::Mat in_image = input_cv->mat(); +- const uint32_t kResizeShapeLimits = 1000; +- // resize image too large or too small, 1000 is arbitrarily chosen here to prevent open cv from segmentation fault +- CHECK_FAIL_RETURN_UNEXPECTED((std::numeric_limits::max() / kResizeShapeLimits) > in_image.rows, +- "Resize: in_image rows out of bounds."); +- CHECK_FAIL_RETURN_UNEXPECTED((std::numeric_limits::max() / kResizeShapeLimits) > in_image.cols, +- "Resize: in_image cols out of bounds."); +- if (output_height > in_image.rows * kResizeShapeLimits || output_width > in_image.cols * kResizeShapeLimits) { +- RETURN_STATUS_ERROR( +- StatusCode::kMDShapeMisMatch, +- "Resize: the resizing width or height is too big, it's 1000 times bigger than the original image, got output " +- "height: " + +- std::to_string(output_height) + ", width: " + std::to_string(output_width) + +- ", and original image size:" + std::to_string(in_image.rows) + ", " + std::to_string(in_image.cols)); +- } +- if (output_height == 0 || output_width == 0) { +- RETURN_STATUS_ERROR(StatusCode::kMDShapeMisMatch, +- "Resize: the input value of 'resize' is invalid, width or height is zero."); +- } +- +- if (mode == InterpolationMode::kCubicPil) { +- if (input_cv->shape().Size() != kDefaultImageChannel || +- input_cv->shape()[kChannelIndexHWC] != kDefaultImageChannel) { +- RETURN_STATUS_UNEXPECTED("Resize: Interpolation mode PILCUBIC only supports image with 3 channels, but got: " + +- input_cv->shape().ToString()); +- } +- +- LiteMat im_in; +- LiteMat im_out; +- std::shared_ptr output_tensor; +- TensorShape new_shape = TensorShape({output_height, output_width, 3}); +- RETURN_IF_NOT_OK(Tensor::CreateEmpty(new_shape, input_cv->type(), &output_tensor)); +- uint8_t *buffer = reinterpret_cast(&(*output_tensor->begin())); +- im_out.Init(output_width, output_height, static_cast(input_cv->shape()[kChannelIndexHWC]), +- reinterpret_cast(buffer), LDataType::UINT8); +- im_in.Init(static_cast(input_cv->shape()[1]), static_cast(input_cv->shape()[0]), +- static_cast(input_cv->shape()[kChannelIndexHWC]), input_cv->mat().data, LDataType::UINT8); +- CHECK_FAIL_RETURN_UNEXPECTED(!im_out.IsEmpty(), "Resize: Init image tensor failed, return empty tensor."); +- CHECK_FAIL_RETURN_UNEXPECTED(!im_in.IsEmpty(), "Resize: Init image tensor failed, return empty tensor."); +- if (ResizeCubic(im_in, im_out, output_width, output_height) == false) { +- RETURN_STATUS_UNEXPECTED("Resize: failed to do resize, please check the error msg."); +- } +- *output = output_tensor; +- return Status::OK(); +- } +- try { +- TensorShape shape{output_height, output_width}; +- if (input_cv->Rank() == kDefaultImageRank) { +- int num_channels = static_cast(input_cv->shape()[kChannelIndexHWC]); +- shape = shape.AppendDim(num_channels); +- } +- std::shared_ptr output_cv; +- RETURN_IF_NOT_OK(CVTensor::CreateEmpty(shape, input_cv->type(), &output_cv)); +- +- auto cv_mode = GetCVInterpolationMode(mode); +- cv::resize(in_image, output_cv->mat(), cv::Size(output_width, output_height), fx, fy, cv_mode); +- *output = std::static_pointer_cast(output_cv); +- return Status::OK(); +- } catch (const cv::Exception &e) { +- RETURN_STATUS_UNEXPECTED("Resize: " + std::string(e.what())); +- } +-} +- +-#else + Status Resize(const std::shared_ptr &input, std::shared_ptr *output, int32_t output_height, + int32_t output_width, double fx, double fy, InterpolationMode mode) { + if (mode != InterpolationMode::kLinear) { +@@ -608,7 +432,7 @@ Status Resize(const std::shared_ptr &input, std::shared_ptr *out + } + return Status::OK(); + } +-#endif ++ + + Status ResizePreserve(const TensorRow &inputs, int32_t height, int32_t width, int32_t img_orientation, + TensorRow *outputs) { +diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.h b/mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.h +index e4cd22e2ef7..572c62d2b2c 100644 +--- a/mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.h ++++ b/mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.h +@@ -67,22 +67,6 @@ struct JpegErrorManagerCustom { + jmp_buf setjmp_buffer; + }; + +-#if defined(ENABLE_MINDDATA_PYTHON) +-bool IsNonEmptyPNG(const std::shared_ptr &input); +- +-/// \brief Returns Rescaled image +-/// \param input: Tensor of shape or and any OpenCv compatible type, see CVTensor. +-/// \param rescale: rescale parameter +-/// \param shift: shift parameter +-/// \param output: Rescaled image Tensor of same input shape and type DE_FLOAT32 +-Status Rescale(const std::shared_ptr &input, std::shared_ptr *output, float rescale, float shift); +- +-/// \brief Swap the red and blue pixels (RGB <-> BGR) +-/// \param input: Tensor of shape and any OpenCv compatible type, see CVTensor. +-/// \param output: Swapped image of same shape and type +-Status SwapRedAndBlue(std::shared_ptr input, std::shared_ptr *output); +-#endif +- + bool IsNonEmptyJPEG(const std::shared_ptr &input); + + void JpegSetSource(j_decompress_ptr c_info, const void *data, int64_t data_size); +diff --git a/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/pad_ir.cc b/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/pad_ir.cc +index efa2b28638c..c666ddfb506 100644 +--- a/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/pad_ir.cc ++++ b/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/pad_ir.cc +@@ -18,9 +18,6 @@ + + #include + +-#if defined(ENABLE_MINDDATA_PYTHON) +-#include "minddata/dataset/kernels/image/pad_op.h" +-#endif + #if defined(ENABLE_D) + #include "minddata/dataset/kernels/image/dvpp/ascend910b/dvpp_pad_op.h" + #endif +@@ -30,127 +27,7 @@ + namespace mindspore { + namespace dataset { + namespace vision { +-#if defined(ENABLE_MINDDATA_PYTHON) +-// PadOperation +-PadOperation::PadOperation(const std::vector &padding, const std::vector &fill_value, +- BorderType padding_mode, const std::string &device_target) +- : padding_(padding), fill_value_(fill_value), padding_mode_(padding_mode), device_target_(device_target) {} +- +-PadOperation::~PadOperation() = default; +- +-std::string PadOperation::Name() const { return kPadOperation; } +- +-Status PadOperation::ValidateParams() { +- // padding +- RETURN_IF_NOT_OK(ValidateVectorPadding("Pad", padding_)); +- // fill_value +- RETURN_IF_NOT_OK(ValidateVectorFillvalue("Pad", fill_value_)); +- // padding_mode +- if (padding_mode_ != BorderType::kConstant && padding_mode_ != BorderType::kEdge && +- padding_mode_ != BorderType::kReflect && padding_mode_ != BorderType::kSymmetric) { +- std::string err_msg = "Pad: Invalid BorderType, check input value of enum."; +- LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg); +- } +- // device target +- if (device_target_ != "CPU" && device_target_ != "Ascend") { +- std::string err_msg = "Pad: Invalid device target. It's not CPU or Ascend."; +- LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg); +- } +- return Status::OK(); +-} +- +-std::shared_ptr PadOperation::Build() { +- constexpr size_t dimension_zero = 0; +- constexpr size_t dimension_one = 1; +- constexpr size_t dimension_two = 2; +- constexpr size_t dimension_three = 3; +- constexpr size_t size_one = 1; +- constexpr size_t size_two = 2; +- constexpr size_t size_three = 3; +- int32_t pad_top, pad_bottom, pad_left, pad_right; +- switch (padding_.size()) { +- case size_one: +- pad_left = padding_[dimension_zero]; +- pad_top = padding_[dimension_zero]; +- pad_right = padding_[dimension_zero]; +- pad_bottom = padding_[dimension_zero]; +- break; +- case size_two: +- pad_left = padding_[dimension_zero]; +- pad_right = padding_[dimension_zero]; +- pad_top = padding_[dimension_one]; +- pad_bottom = padding_[dimension_one]; +- break; +- default: +- pad_left = padding_[dimension_zero]; +- pad_top = padding_[dimension_one]; +- pad_right = padding_[dimension_two]; +- pad_bottom = padding_[dimension_three]; +- } +- uint8_t fill_r, fill_g, fill_b; +- +- fill_r = fill_value_[dimension_zero]; +- fill_g = fill_value_[dimension_zero]; +- fill_b = fill_value_[dimension_zero]; +- +- if (fill_value_.size() == size_three) { +- fill_r = fill_value_[dimension_zero]; +- fill_g = fill_value_[dimension_one]; +- fill_b = fill_value_[dimension_two]; +- } + +- if (device_target_ == "CPU") { +- std::shared_ptr tensor_op = +- std::make_shared(pad_top, pad_bottom, pad_left, pad_right, padding_mode_, fill_r, fill_g, fill_b); +- return tensor_op; +-#if defined(ENABLE_D) +- } else if (device_target_ == "Ascend") { +- std::shared_ptr dvpp_tensor_op = +- std::make_shared(pad_top, pad_bottom, pad_left, pad_right, padding_mode_, fill_r, fill_g, fill_b); +- return dvpp_tensor_op; +-#endif +- } else { +- MS_LOG(ERROR) << "Pad: Invalid device target. It's not CPU or Ascend."; +- return nullptr; +- } +-} +- +-Status PadOperation::to_json(nlohmann::json *out_json) { +- RETURN_UNEXPECTED_IF_NULL(out_json); +- nlohmann::json args; +- args["padding"] = padding_; +- args["fill_value"] = fill_value_; +- args["padding_mode"] = padding_mode_; +- args["device_target"] = device_target_; +- *out_json = args; +- return Status::OK(); +-} +- +-Status PadOperation::from_json(nlohmann::json op_params, std::shared_ptr *operation) { +- RETURN_UNEXPECTED_IF_NULL(operation); +- RETURN_IF_NOT_OK(ValidateParamInJson(op_params, "padding", kPadOperation)); +- RETURN_IF_NOT_OK(ValidateParamInJson(op_params, "fill_value", kPadOperation)); +- RETURN_IF_NOT_OK(ValidateParamInJson(op_params, "padding_mode", kPadOperation)); +- RETURN_IF_NOT_OK(ValidateParamInJson(op_params, "device_target", kPadOperation)); +- std::vector padding = op_params["padding"]; +- std::vector fill_value = op_params["fill_value"]; +- auto padding_mode = static_cast(op_params["padding_mode"]); +- std::string device_target = op_params["device_target"]; +- *operation = std::make_shared(padding, fill_value, padding_mode, device_target); +- return Status::OK(); +-} +- +-MapTargetDevice PadOperation::Type() { +- if (device_target_ == "CPU") { +- return MapTargetDevice::kCpu; +- } else if (device_target_ == "Ascend") { +- return MapTargetDevice::kAscend910B; +- } else { +- MS_LOG(ERROR) << "Pad: Invalid device target. It's not CPU or Ascend."; +- return MapTargetDevice::kInvalid; +- } +-} +-#endif + } // namespace vision + } // namespace dataset + } // namespace mindspore +diff --git a/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/rescale_ir.cc b/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/rescale_ir.cc +index d748ddc1fcb..717d4222ae2 100644 +--- a/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/rescale_ir.cc ++++ b/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/rescale_ir.cc +@@ -21,46 +21,6 @@ + namespace mindspore { + namespace dataset { + namespace vision { +-#if defined(ENABLE_MINDDATA_PYTHON) +-// RescaleOperation +-RescaleOperation::RescaleOperation(float rescale, float shift) : rescale_(rescale), shift_(shift) {} +- +-RescaleOperation::~RescaleOperation() = default; +- +-std::string RescaleOperation::Name() const { return kRescaleOperation; } +- +-Status RescaleOperation::ValidateParams() { +- if (rescale_ < 0.0) { +- std::string err_msg = "Rescale: rescale must be greater than or equal to 0, got: " + std::to_string(rescale_); +- LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg); +- } +- return Status::OK(); +-} +- +-std::shared_ptr RescaleOperation::Build() { +- std::shared_ptr tensor_op = std::make_shared(rescale_, shift_); +- return tensor_op; +-} +- +-Status RescaleOperation::to_json(nlohmann::json *out_json) { +- RETURN_UNEXPECTED_IF_NULL(out_json); +- nlohmann::json args; +- args["rescale"] = rescale_; +- args["shift"] = shift_; +- *out_json = args; +- return Status::OK(); +-} +- +-Status RescaleOperation::from_json(nlohmann::json op_params, std::shared_ptr *operation) { +- RETURN_UNEXPECTED_IF_NULL(operation); +- RETURN_IF_NOT_OK(ValidateParamInJson(op_params, "rescale", kRescaleOperation)); +- RETURN_IF_NOT_OK(ValidateParamInJson(op_params, "shift", kRescaleOperation)); +- float rescale = op_params["rescale"]; +- float shift = op_params["shift"]; +- *operation = std::make_shared(rescale, shift); +- return Status::OK(); +-} +-#endif + } // namespace vision + } // namespace dataset + } // namespace mindspore +diff --git a/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/swap_red_blue_ir.cc b/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/swap_red_blue_ir.cc +index 8ed116c35f9..b56b99fdf9c 100644 +--- a/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/swap_red_blue_ir.cc ++++ b/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/swap_red_blue_ir.cc +@@ -15,34 +15,10 @@ + */ + #include "minddata/dataset/kernels/ir/vision/swap_red_blue_ir.h" + +-#if defined(ENABLE_MINDDATA_PYTHON) +-#include "minddata/dataset/kernels/image/swap_red_blue_op.h" +-#endif +- + namespace mindspore { + namespace dataset { + namespace vision { +-#if defined(ENABLE_MINDDATA_PYTHON) +-// SwapRedBlueOperation. +-SwapRedBlueOperation::SwapRedBlueOperation() = default; +- +-SwapRedBlueOperation::~SwapRedBlueOperation() = default; +- +-std::string SwapRedBlueOperation::Name() const { return kSwapRedBlueOperation; } +- +-Status SwapRedBlueOperation::ValidateParams() { return Status::OK(); } +- +-std::shared_ptr SwapRedBlueOperation::Build() { +- std::shared_ptr tensor_op = std::make_shared(); +- return tensor_op; +-} + +-Status SwapRedBlueOperation::from_json(nlohmann::json op_params, std::shared_ptr *operation) { +- RETURN_UNEXPECTED_IF_NULL(operation); +- *operation = std::make_shared(); +- return Status::OK(); +-} +-#endif + } // namespace vision + } // namespace dataset + } // namespace mindspore +diff --git a/mindspore/ccsrc/minddata/dataset/util/log_adapter.h b/mindspore/ccsrc/minddata/dataset/util/log_adapter.h +index 6dcc8f20833..5f886102faf 100644 +--- a/mindspore/ccsrc/minddata/dataset/util/log_adapter.h ++++ b/mindspore/ccsrc/minddata/dataset/util/log_adapter.h +@@ -16,12 +16,7 @@ + #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_LOG_ADAPTER_H_ + #define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_LOG_ADAPTER_H_ + +-#if defined(ENABLE_MINDDATA_PYTHON) +-#include "utils/log_adapter.h" +-#define DATASET_SRC_FILE_NAME FILE_NAME +-#else + #include "mindspore/lite/src/common/log_adapter.h" + #define DATASET_SRC_FILE_NAME LITE_FILE_NAME +-#endif + + #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_LOG_ADAPTER_H_ +diff --git a/mindspore/ccsrc/pipeline/jit/ps/compile_cache_manager.cc b/mindspore/ccsrc/pipeline/jit/ps/compile_cache_manager.cc +index 445f91cf9b8..caff57bce65 100644 +--- a/mindspore/ccsrc/pipeline/jit/ps/compile_cache_manager.cc ++++ b/mindspore/ccsrc/pipeline/jit/ps/compile_cache_manager.cc +@@ -37,9 +37,7 @@ + #endif + #include "include/common/utils/compile_cache_context.h" + #include "include/common/utils/config_manager.h" +-#if !defined(BUILD_LITE) + #include "include/backend/distributed/collective/collective_manager.h" +-#endif + + namespace mindspore { + #ifndef MINDIR_EXPORT_TENSOR_LAYOUT_CLIP +@@ -97,11 +95,7 @@ namespace { + std::string GetCompileCacheDir() { + static const std::string user_defined_path = Common::GetUserDefineCachePath(); + +-#if !defined(BUILD_LITE) + bool is_distributed = distributed::collective::CollectiveManager::instance()->initialized(); +-#else +- bool is_distributed = !IsStandAlone(); +-#endif + static const uint32_t rank_id = is_distributed ? GetRank() : 0; + static const std::string compile_cache_dir = user_defined_path + "rank_" + std::to_string(rank_id); + return compile_cache_dir; +diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/CMakeLists.txt b/mindspore/ccsrc/plugin/device/cpu/kernel/CMakeLists.txt +index c166b2a7043..9335e96cf2e 100644 +--- a/mindspore/ccsrc/plugin/device/cpu/kernel/CMakeLists.txt ++++ b/mindspore/ccsrc/plugin/device/cpu/kernel/CMakeLists.txt +@@ -1,16 +1,5 @@ + file(GLOB_RECURSE CPU_KERNEL_OBJECTS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") + +-if(BUILD_LITE) +- # mslite do not support python op +- set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx ") +- set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -maxv ") +- string(REPLACE "-Wall" "" CMAKE_C_FLAGS ${CMAKE_C_FLAGS}) +- string(REPLACE "-Wall" "" CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) +- list(REMOVE_ITEM CPU_KERNEL_OBJECTS "pyexecute/py_execute_cpu_kernel.cc") +- list(REMOVE_ITEM CPU_KERNEL_OBJECTS "pyfunc/py_func_cpu_kernel.cc") +- list(REMOVE_ITEM CPU_KERNEL_OBJECTS "opaque_predicate_kernel.cc") +-endif() +- + if(ENABLE_AKG AND ${CMAKE_SYSTEM_NAME} MATCHES "Linux" AND ENABLE_CPU) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp") + else() +diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/akg/akg_kernel_loader.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/akg/akg_kernel_loader.cc +index d66ad0bfa22..743808d51c8 100644 +--- a/mindspore/ccsrc/plugin/device/cpu/kernel/akg/akg_kernel_loader.cc ++++ b/mindspore/ccsrc/plugin/device/cpu/kernel/akg/akg_kernel_loader.cc +@@ -23,11 +23,7 @@ + #include + #include + +-#ifdef BUILD_LITE +-#include "src/common/log_adapter.h" +-#else + #include "utils/log_adapter.h" +-#endif + + namespace mindspore { + namespace kernel { +diff --git a/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/ascend_memory_pool.cc b/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/ascend_memory_pool.cc +index 6bfcd501183..07de4ac1648 100644 +--- a/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/ascend_memory_pool.cc ++++ b/mindspore/ccsrc/plugin/res_manager/ascend/mem_manager/ascend_memory_pool.cc +@@ -135,9 +135,7 @@ static uint64_t GetPid() { + + int32_t GetDeviceId() { + int32_t device_id = 0; +-#if !defined(BUILD_LITE) + device_id = static_cast(DistributedMeta::GetInstance()->local_rank_id()); +-#endif + return device_id; + } + +diff --git a/mindspore/ccsrc/plugin/res_manager/ascend/op_adapter/custom_op_infer.cc b/mindspore/ccsrc/plugin/res_manager/ascend/op_adapter/custom_op_infer.cc +index 097dc9fcea1..8ca39ac217a 100644 +--- a/mindspore/ccsrc/plugin/res_manager/ascend/op_adapter/custom_op_infer.cc ++++ b/mindspore/ccsrc/plugin/res_manager/ascend/op_adapter/custom_op_infer.cc +@@ -16,15 +16,15 @@ + + #include + #include +-#ifdef MSLITE_ENABLE_GRAPH_KERNEL +-#include +-#include +-#include +-#include "nlohmann/json.hpp" +-#include "plugin/res_manager/ascend/op_adapter/transform_util.h" +-#include "backend/common/graph_kernel/model/op_register.h" +-#include "backend/common/graph_kernel/core/value_depend_op_utils.h" +-#endif ++//#ifdef MSLITE_ENABLE_GRAPH_KERNEL ++//#include ++//#include ++//#include ++//#include "nlohmann/json.hpp" ++//#include "plugin/res_manager/ascend/op_adapter/transform_util.h" ++//#include "backend/common/graph_kernel/model/op_register.h" ++//#include "backend/common/graph_kernel/core/value_depend_op_utils.h" ++//#endif + #include "utils/log_adapter.h" + #include "graph/operator.h" + +@@ -87,298 +87,288 @@ std::string GetCustomOpKey(const ge::Operator &op) { + } + } // namespace + +-#ifdef MSLITE_ENABLE_GRAPH_KERNEL +-using mindspore::graphkernel::inner::ConstTensorNode; +-using mindspore::graphkernel::inner::DAttrs; +-using mindspore::graphkernel::inner::Node; +-using mindspore::graphkernel::inner::NodeBase; +-using mindspore::graphkernel::inner::NodePtr; +-using mindspore::graphkernel::inner::NodePtrList; +- +-namespace { +-TypeId const ConvertGeDataType(const ge::DataType &type) { +- static std::unordered_map ge_ms_type = { +- {ge::DataType::DT_FLOAT16, TypeId::kNumberTypeFloat16}, +- {ge::DataType::DT_FLOAT, TypeId::kNumberTypeFloat32}, +- {ge::DataType::DT_DOUBLE, TypeId::kNumberTypeFloat64}, +- {ge::DataType::DT_INT8, TypeId::kNumberTypeInt8}, +- {ge::DataType::DT_INT16, TypeId::kNumberTypeInt16}, +- {ge::DataType::DT_INT32, TypeId::kNumberTypeInt32}, +- {ge::DataType::DT_INT64, TypeId::kNumberTypeInt64}, +- {ge::DataType::DT_UINT8, TypeId::kNumberTypeUInt8}, +- {ge::DataType::DT_UINT16, TypeId::kNumberTypeUInt16}, +- {ge::DataType::DT_UINT32, TypeId::kNumberTypeUInt32}, +- {ge::DataType::DT_UINT64, TypeId::kNumberTypeUInt64}, +- {ge::DataType::DT_BOOL, TypeId::kNumberTypeBool}, +- {ge::DataType::DT_STRING, TypeId::kObjectTypeString}, +- {ge::DataType::DT_BF16, TypeId::kNumberTypeBFloat16}, +- {ge::DataType::DT_HIFLOAT8, TypeId::kNumberTypeHiFloat8}, +- {ge::DataType::DT_FLOAT8_E5M2, TypeId::kNumberTypeFloat8E5M2}, +- {ge::DataType::DT_FLOAT8_E4M3FN, TypeId::kNumberTypeFloat8E4M3FN}, +- }; +- auto iter = ge_ms_type.find(type); +- if (iter != ge_ms_type.end()) { +- return iter->second; +- } +- return TypeId::kTypeUnknown; +-} +- +-using ConvertFunc = std::function; +- +-template +-tensor::TensorPtr ConvertSingleValueHelper(const nlohmann::json &value) { +- T v = value; +- ShapeVector shape = {1}; +- return std::make_shared(type_id, shape, &v, type_id); +-} +- +-tensor::TensorPtr ConvertSingleValue(const nlohmann::json &value, const std::string &type) { +- // reverse function of SetSingleValue +- static std::unordered_map convertTable = { +- {"bool", ConvertSingleValueHelper}, +- {"int8", ConvertSingleValueHelper}, +- {"int16", ConvertSingleValueHelper}, +- {"int32", ConvertSingleValueHelper}, +- {"int64", ConvertSingleValueHelper}, +- {"uint8", ConvertSingleValueHelper}, +- {"uint16", ConvertSingleValueHelper}, +- {"uint32", ConvertSingleValueHelper}, +- {"uint64", ConvertSingleValueHelper}, +- {"float16", ConvertSingleValueHelper}, +- {"float32", ConvertSingleValueHelper}, +- {"float64", ConvertSingleValueHelper}}; +- if (convertTable.count(type) > 0) { +- ConvertFunc convertFunc = convertTable[type]; +- return convertFunc(value); +- } +- MS_LOG(WARNING) << "Fail to convert the single value: " << value << ". type is: " << type; +- return nullptr; +-} +- +-tensor::TensorPtr ConvertListValue(const nlohmann::json &value, const std::string &type, const ShapeVector &shape) { +- // reverse function of SetValueList +- tensor::TensorPtr res = nullptr; +- if (type == "int32") { +- std::vector v = value; +- res = std::make_shared(TypeId::kNumberTypeInt32, shape, &v[0], TypeId::kNumberTypeInt32); +- } else if (type == "int64") { +- std::vector v = value; +- res = std::make_shared(TypeId::kNumberTypeInt64, shape, &v[0], TypeId::kNumberTypeInt64); +- } +- MS_LOG(WARNING) << "Fail to convert the list value: " << value << ". type is: " << type; +- return res; +-} +- +-NodePtrList GetOpInputs(const nlohmann::json &op_desc, const std::unordered_map &all_tensors, +- const std::string &op_name) { +- const auto &value_depend_op_info = graphkernel::ValueDependOpUtils::GetOpIndexInfo(); +- bool is_depend_value = value_depend_op_info.find(op_name) != value_depend_op_info.end(); +- NodePtrList res; +- for (const auto &input_desc : op_desc["input_desc"]) { +- for (const auto &item : input_desc) { +- std::string name = item["tensor_name"]; +- auto iter = all_tensors.find(name); +- if (iter != all_tensors.end()) { +- res.push_back(iter->second); +- } else { +- // const value input +- if (is_depend_value && item.find("value") != item.end()) { +- auto value = item["value"]; +- std::string type = item["data_type"]; +- ShapeVector shape = item["shape"]; +- auto tensor = value.is_array() ? ConvertListValue(value, type, shape) : ConvertSingleValue(value, type); +- if (tensor != nullptr) { +- res.push_back(std::make_shared(tensor)); +- continue; +- } +- MS_LOG(WARNING) << "Fail to parse the const value of tensor [" << name << "]. tensor json is: " << item; +- } +- std::string format = item["format"]; +- NodeBase n{ShapeVector(item["shape"]), StringToTypeId(item["data_type"]), format}; +- res.push_back(std::make_shared(n)); +- } +- } +- } +- return res; +-} +- +-DAttrs GetOpAttr(const nlohmann::json &op_desc) { +- DAttrs res; +- // no attr +- if (op_desc.find("attr") == op_desc.end() || op_desc["attr"].is_null()) { +- return res; +- } +- for (const auto &item : op_desc["attr"]) { +- std::string name = item["name"]; +- std::string type = item["data_type"]; +- ValuePtr attr_value = nullptr; +- if (type == "str") { +- std::string value = item["value"]; +- attr_value = (name == "dst_type" && op_desc["name"] == "Cast") ? StringToType(value) : MakeValue(value); +- } else if (type == "int") { +- int64_t value = item["value"]; +- attr_value = MakeValue(value); +- } else if (type == "bool") { +- bool value = item["value"]; +- attr_value = MakeValue(value); +- } else if (type == "float") { +- float value = item["value"]; +- attr_value = MakeValue(value); +- } else if (type == "listInt") { +- std::vector value = item["value"]; +- attr_value = MakeValue(value); +- } else if (type == "listStr") { +- std::vector value = item["value"]; +- attr_value = MakeValue(value); +- } else { +- MS_LOG(WARNING) << "Fail to parse attr [" << name << "] because its type: " << type +- << " is not in supported list: [str, int, bool, float, listInt, listStr]. attr json is: " << item; +- } +- if (attr_value != nullptr) { +- res[name] = attr_value; +- } +- } +- return res; +-} +- +-bool InferOnline(const ge::Operator &op, const nlohmann::json &js, std::vector *outputs_info) { +- if (outputs_info == nullptr) { +- return false; +- } +- std::unordered_map all_tensors; +- // iter input_desc: inputs info use the real info pass by GE +- std::vector input_desc = js["input_desc"]; +- for (size_t i = 0; i < input_desc.size(); ++i) { +- const auto &item = input_desc[i][0]; +- std::string input_name = "x" + std::to_string(i); +- auto ge_desc = op.GetInputDescByName(input_name.c_str()); +- std::string format = item["format"]; +- NodeBase n{ge_desc.GetShape().GetDims(), ConvertGeDataType(ge_desc.GetDataType()), format}; +- MS_LOG(DEBUG) << "input[" << i << "]: " << n.shape << " " << TypeIdToString(n.type); +- all_tensors[item["tensor_name"]] = std::make_shared(n); +- } +- +- // iter op_desc: infer each op +- for (const auto &op_desc : js["op_desc"]) { +- std::string op_name = op_desc["name"]; +- auto op_ptr = mindspore::graphkernel::inner::OpRegistry::Instance().NewOp(op_name); +- auto op_inputs = GetOpInputs(op_desc, all_tensors, op_name); +- auto op_attr = GetOpAttr(op_desc); +- auto infer_res = op_ptr->Infer(op_inputs, op_attr); +- std::vector op_output_desc = op_desc["output_desc"]; +- if (infer_res.size() != op_output_desc.size()) { +- MS_LOG(ERROR) << "For op [" << op_name +- << "], the length of inferred output shape list is not equal to the length of output_desc list: " +- << infer_res.size() << " vs " << op_output_desc.size(); +- return false; +- } +- for (size_t i = 0; i < op_output_desc.size(); ++i) { +- std::string name = op_output_desc[i]["tensor_name"]; +- all_tensors[name] = std::make_shared(infer_res[i]); +- } +- } +- +- // iter output_desc: combine the outputs info +- std::vector output_desc = js["output_desc"]; +- // format not need infer +- std::vector output_formats; +- if (op.GetAttr("output_formats", output_formats) != ge::GRAPH_SUCCESS || +- output_formats.size() != output_desc.size()) { +- return false; +- } +- +- for (size_t i = 0; i < output_desc.size(); ++i) { +- std::string name = output_desc[i]["tensor_name"]; +- auto iter = all_tensors.find(name); +- if (iter == all_tensors.end()) { +- MS_LOG(ERROR) << "Tensor [" << name << "] not found in op_desc"; +- return false; +- } +- auto shape = iter->second->shape; +- (void)outputs_info->emplace_back(ge::Shape(shape), static_cast(output_formats[i]), +- device::ascend::TransformUtil::ConvertDataType(iter->second->type)); +- MS_LOG(DEBUG) << "output[" << i << "]: " << shape << " " << TypeIdToString(iter->second->type); +- } +- return true; +-} +- +-bool InputsInfoNotChanged(const ge::Operator &op, const nlohmann::json &js) { +- std::vector input_desc = js["input_desc"]; +- for (size_t i = 0; i < input_desc.size(); ++i) { +- std::string input_name = "x" + std::to_string(i); +- auto ge_desc = op.GetInputDescByName(input_name.c_str()); +- auto ge_shape = ge_desc.GetShape().GetDims(); +- auto ge_type = ge_desc.GetDataType(); +- const auto &item = input_desc[i][0]; +- ShapeVector ms_shape = item["shape"]; +- auto ms_type = StringToTypeId(item["data_type"]); +- if (ge_shape != ms_shape || ConvertGeDataType(ge_type) != ms_type) { +- return false; +- } +- } +- return true; +-} +- +-bool Infer(const ge::Operator &op, const std::string &op_key, const std::string &info_path, +- std::vector *outputs_info) { +- if (outputs_info == nullptr) { +- return false; +- } +- +- // read akg info and parse it to json format +- std::ifstream info_str(info_path); +- if (!info_str.is_open()) { +- return false; +- } +- nlohmann::json js; +- info_str >> js; +- info_str.close(); +- +- // 1) if input information not changed, reuse the outputs info saved in op attr 2) else infer online +- if (InputsInfoNotChanged(op, js)) { +- MS_LOG(INFO) << "Infer shape offline for op " << op_key; +- return InferOffline(op, outputs_info); +- } +- MS_LOG(INFO) << "Infer shape online for op " << op_key; +- return InferOnline(op, js, outputs_info); +-} +-} // namespace +- +-ge::graphStatus CustomAkgOpInferFunc(ge::Operator &op) { +- auto op_key = GetCustomOpKey(op); +- MS_LOG(INFO) << "Start infer shape for op " << op_key; +- +- // get akg info path of current op +- std::string info_path; +- auto status = op.GetAttr("info_path", info_path); +- if (status != ge::GRAPH_SUCCESS) { +- return status; +- } +- +- // infer shape +- std::vector outputs_info; +- try { +- if (!Infer(op, op_key, info_path, &outputs_info)) { +- MS_LOG(ERROR) << "Failed infer shape for op " << op_key << ", akg info path: " << info_path; +- return ge::GRAPH_FAILED; +- } +- } catch (std::exception &e) { +- MS_LOG(ERROR) << "Failed infer shape for op " << op_key << ", akg info path: " << info_path +- << " error message: " << e.what(); +- return ge::GRAPH_FAILED; +- } +- +- // update output desc +- for (size_t i = 0; i < outputs_info.size(); ++i) { +- std::string output_name = "y" + std::to_string(i); +- (void)op.UpdateOutputDesc(output_name, outputs_info[i]); +- } +- MS_LOG(INFO) << "End infer shape for op " << op_key; +- return ge::GRAPH_SUCCESS; +-} +-#else ++//#ifdef MSLITE_ENABLE_GRAPH_KERNEL ++// using mindspore::graphkernel::inner::ConstTensorNode; ++// using mindspore::graphkernel::inner::DAttrs; ++// using mindspore::graphkernel::inner::Node; ++// using mindspore::graphkernel::inner::NodeBase; ++// using mindspore::graphkernel::inner::NodePtr; ++// using mindspore::graphkernel::inner::NodePtrList; ++// ++// namespace { ++// TypeId ConvertGeDataType(const ge::DataType &type) { ++// static std::unordered_map ge_ms_type = { ++// {ge::DataType::DT_FLOAT16, TypeId::kNumberTypeFloat16}, {ge::DataType::DT_FLOAT, TypeId::kNumberTypeFloat32}, ++// {ge::DataType::DT_DOUBLE, TypeId::kNumberTypeFloat64}, {ge::DataType::DT_INT8, TypeId::kNumberTypeInt8}, ++// {ge::DataType::DT_INT16, TypeId::kNumberTypeInt16}, {ge::DataType::DT_INT32, TypeId::kNumberTypeInt32}, ++// {ge::DataType::DT_INT64, TypeId::kNumberTypeInt64}, {ge::DataType::DT_UINT8, TypeId::kNumberTypeUInt8}, ++// {ge::DataType::DT_UINT16, TypeId::kNumberTypeUInt16}, {ge::DataType::DT_UINT32, TypeId::kNumberTypeUInt32}, ++// {ge::DataType::DT_UINT64, TypeId::kNumberTypeUInt64}, {ge::DataType::DT_BOOL, TypeId::kNumberTypeBool}, ++// {ge::DataType::DT_STRING, TypeId::kObjectTypeString}, {ge::DataType::DT_BF16, TypeId::kNumberTypeBFloat16}}; ++// auto iter = ge_ms_type.find(type); ++// if (iter != ge_ms_type.end()) { ++// return iter->second; ++// } ++// return TypeId::kTypeUnknown; ++//} ++// ++// using ConvertFunc = std::function; ++// ++// template ++// tensor::TensorPtr ConvertSingleValueHelper(const nlohmann::json &value) { ++// T v = value; ++// ShapeVector shape = {1}; ++// return std::make_shared(type_id, shape, &v, type_id); ++//} ++// ++// tensor::TensorPtr ConvertSingleValue(const nlohmann::json &value, const std::string &type) { ++// // reverse function of SetSingleValue ++// static std::unordered_map convertTable = { ++// {"bool", ConvertSingleValueHelper}, ++// {"int8", ConvertSingleValueHelper}, ++// {"int16", ConvertSingleValueHelper}, ++// {"int32", ConvertSingleValueHelper}, ++// {"int64", ConvertSingleValueHelper}, ++// {"uint8", ConvertSingleValueHelper}, ++// {"uint16", ConvertSingleValueHelper}, ++// {"uint32", ConvertSingleValueHelper}, ++// {"uint64", ConvertSingleValueHelper}, ++// {"float16", ConvertSingleValueHelper}, ++// {"float32", ConvertSingleValueHelper}, ++// {"float64", ConvertSingleValueHelper}}; ++// if (convertTable.count(type) > 0) { ++// ConvertFunc convertFunc = convertTable[type]; ++// return convertFunc(value); ++// } ++// MS_LOG(WARNING) << "Fail to convert the single value: " << value << ". type is: " << type; ++// return nullptr; ++//} ++// ++// tensor::TensorPtr ConvertListValue(const nlohmann::json &value, const std::string &type, const ShapeVector &shape) { ++// // reverse function of SetValueList ++// tensor::TensorPtr res = nullptr; ++// if (type == "int32") { ++// std::vector v = value; ++// res = std::make_shared(TypeId::kNumberTypeInt32, shape, &v[0], TypeId::kNumberTypeInt32); ++// } else if (type == "int64") { ++// std::vector v = value; ++// res = std::make_shared(TypeId::kNumberTypeInt64, shape, &v[0], TypeId::kNumberTypeInt64); ++// } ++// MS_LOG(WARNING) << "Fail to convert the list value: " << value << ". type is: " << type; ++// return res; ++//} ++// ++// NodePtrList GetOpInputs(const nlohmann::json &op_desc, const std::unordered_map &all_tensors, ++// const std::string &op_name) { ++// const auto &value_depend_op_info = graphkernel::ValueDependOpUtils::GetOpIndexInfo(); ++// bool is_depend_value = value_depend_op_info.find(op_name) != value_depend_op_info.end(); ++// NodePtrList res; ++// for (const auto &input_desc : op_desc["input_desc"]) { ++// for (const auto &item : input_desc) { ++// std::string name = item["tensor_name"]; ++// auto iter = all_tensors.find(name); ++// if (iter != all_tensors.end()) { ++// res.push_back(iter->second); ++// } else { ++// // const value input ++// if (is_depend_value && item.find("value") != item.end()) { ++// auto value = item["value"]; ++// std::string type = item["data_type"]; ++// ShapeVector shape = item["shape"]; ++// auto tensor = value.is_array() ? ConvertListValue(value, type, shape) : ConvertSingleValue(value, type); ++// if (tensor != nullptr) { ++// res.push_back(std::make_shared(tensor)); ++// continue; ++// } ++// MS_LOG(WARNING) << "Fail to parse the const value of tensor [" << name << "]. tensor json is: " << item; ++// } ++// std::string format = item["format"]; ++// NodeBase n{ShapeVector(item["shape"]), StringToTypeId(item["data_type"]), format}; ++// res.push_back(std::make_shared(n)); ++// } ++// } ++// } ++// return res; ++//} ++// ++// DAttrs GetOpAttr(const nlohmann::json &op_desc) { ++// DAttrs res; ++// // no attr ++// if (op_desc.find("attr") == op_desc.end() || op_desc["attr"].is_null()) { ++// return res; ++// } ++// for (const auto &item : op_desc["attr"]) { ++// std::string name = item["name"]; ++// std::string type = item["data_type"]; ++// ValuePtr attr_value = nullptr; ++// if (type == "str") { ++// std::string value = item["value"]; ++// attr_value = (name == "dst_type" && op_desc["name"] == "Cast") ? StringToType(value) : MakeValue(value); ++// } else if (type == "int") { ++// int64_t value = item["value"]; ++// attr_value = MakeValue(value); ++// } else if (type == "bool") { ++// bool value = item["value"]; ++// attr_value = MakeValue(value); ++// } else if (type == "float") { ++// float value = item["value"]; ++// attr_value = MakeValue(value); ++// } else if (type == "listInt") { ++// std::vector value = item["value"]; ++// attr_value = MakeValue(value); ++// } else if (type == "listStr") { ++// std::vector value = item["value"]; ++// attr_value = MakeValue(value); ++// } else { ++// MS_LOG(WARNING) << "Fail to parse attr [" << name << "] because its type: " << type ++// << " is not in supported list: [str, int, bool, float, listInt, listStr]. attr json is: " << ++// item; ++// } ++// if (attr_value != nullptr) { ++// res[name] = attr_value; ++// } ++// } ++// return res; ++//} ++// ++// bool InferOnline(const ge::Operator &op, const nlohmann::json &js, std::vector *outputs_info) { ++// if (outputs_info == nullptr) { ++// return false; ++// } ++// std::unordered_map all_tensors; ++// // iter input_desc: inputs info use the real info pass by GE ++// std::vector input_desc = js["input_desc"]; ++// for (size_t i = 0; i < input_desc.size(); ++i) { ++// const auto &item = input_desc[i][0]; ++// std::string input_name = "x" + std::to_string(i); ++// auto ge_desc = op.GetInputDescByName(input_name.c_str()); ++// std::string format = item["format"]; ++// NodeBase n{ge_desc.GetShape().GetDims(), ConvertGeDataType(ge_desc.GetDataType()), format}; ++// MS_LOG(DEBUG) << "input[" << i << "]: " << n.shape << " " << TypeIdToString(n.type); ++// all_tensors[item["tensor_name"]] = std::make_shared(n); ++// } ++// ++// // iter op_desc: infer each op ++// for (const auto &op_desc : js["op_desc"]) { ++// std::string op_name = op_desc["name"]; ++// auto op_ptr = mindspore::graphkernel::inner::OpRegistry::Instance().NewOp(op_name); ++// auto op_inputs = GetOpInputs(op_desc, all_tensors, op_name); ++// auto op_attr = GetOpAttr(op_desc); ++// auto infer_res = op_ptr->Infer(op_inputs, op_attr); ++// std::vector op_output_desc = op_desc["output_desc"]; ++// if (infer_res.size() != op_output_desc.size()) { ++// MS_LOG(ERROR) << "For op [" << op_name ++// << "], the length of inferred output shape list is not equal to the length of output_desc list: " ++// << infer_res.size() << " vs " << op_output_desc.size(); ++// return false; ++// } ++// for (size_t i = 0; i < op_output_desc.size(); ++i) { ++// std::string name = op_output_desc[i]["tensor_name"]; ++// all_tensors[name] = std::make_shared(infer_res[i]); ++// } ++// } ++// ++// // iter output_desc: combine the outputs info ++// std::vector output_desc = js["output_desc"]; ++// // format not need infer ++// std::vector output_formats; ++// if (op.GetAttr("output_formats", output_formats) != ge::GRAPH_SUCCESS || ++// output_formats.size() != output_desc.size()) { ++// return false; ++// } ++// ++// for (size_t i = 0; i < output_desc.size(); ++i) { ++// std::string name = output_desc[i]["tensor_name"]; ++// auto iter = all_tensors.find(name); ++// if (iter == all_tensors.end()) { ++// MS_LOG(ERROR) << "Tensor [" << name << "] not found in op_desc"; ++// return false; ++// } ++// auto shape = iter->second->shape; ++// (void)outputs_info->emplace_back(ge::Shape(shape), static_cast(output_formats[i]), ++// device::ascend::TransformUtil::ConvertDataType(iter->second->type)); ++// MS_LOG(DEBUG) << "output[" << i << "]: " << shape << " " << TypeIdToString(iter->second->type); ++// } ++// return true; ++//} ++// ++// bool InputsInfoNotChanged(const ge::Operator &op, const nlohmann::json &js) { ++// std::vector input_desc = js["input_desc"]; ++// for (size_t i = 0; i < input_desc.size(); ++i) { ++// std::string input_name = "x" + std::to_string(i); ++// auto ge_desc = op.GetInputDescByName(input_name.c_str()); ++// auto ge_shape = ge_desc.GetShape().GetDims(); ++// auto ge_type = ge_desc.GetDataType(); ++// const auto &item = input_desc[i][0]; ++// ShapeVector ms_shape = item["shape"]; ++// auto ms_type = StringToTypeId(item["data_type"]); ++// if (ge_shape != ms_shape || ConvertGeDataType(ge_type) != ms_type) { ++// return false; ++// } ++// } ++// return true; ++//} ++// ++// bool Infer(const ge::Operator &op, const std::string &op_key, const std::string &info_path, ++// std::vector *outputs_info) { ++// if (outputs_info == nullptr) { ++// return false; ++// } ++// ++// // read akg info and parse it to json format ++// std::ifstream info_str(info_path); ++// if (!info_str.is_open()) { ++// return false; ++// } ++// nlohmann::json js; ++// info_str >> js; ++// info_str.close(); ++// ++// // 1) if input information not changed, reuse the outputs info saved in op attr 2) else infer online ++// if (InputsInfoNotChanged(op, js)) { ++// MS_LOG(INFO) << "Infer shape offline for op " << op_key; ++// return InferOffline(op, outputs_info); ++// } ++// MS_LOG(INFO) << "Infer shape online for op " << op_key; ++// return InferOnline(op, js, outputs_info); ++//} ++//} // namespace ++// ++// ge::graphStatus CustomAkgOpInferFunc(ge::Operator &op) { ++// auto op_key = GetCustomOpKey(op); ++// MS_LOG(INFO) << "Start infer shape for op " << op_key; ++// ++// // get akg info path of current op ++// std::string info_path; ++// auto status = op.GetAttr("info_path", info_path); ++// if (status != ge::GRAPH_SUCCESS) { ++// return status; ++// } ++// ++// // infer shape ++// std::vector outputs_info; ++// try { ++// if (!Infer(op, op_key, info_path, &outputs_info)) { ++// MS_LOG(ERROR) << "Failed infer shape for op " << op_key << ", akg info path: " << info_path; ++// return ge::GRAPH_FAILED; ++// } ++// } catch (std::exception &e) { ++// MS_LOG(ERROR) << "Failed infer shape for op " << op_key << ", akg info path: " << info_path ++// << " error message: " << e.what(); ++// return ge::GRAPH_FAILED; ++// } ++// ++// // update output desc ++// for (size_t i = 0; i < outputs_info.size(); ++i) { ++// std::string output_name = "y" + std::to_string(i); ++// (void)op.UpdateOutputDesc(output_name, outputs_info[i]); ++// } ++// MS_LOG(INFO) << "End infer shape for op " << op_key; ++// return ge::GRAPH_SUCCESS; ++//} ++//#else + ge::graphStatus CustomAkgOpInferFunc(ge::Operator &) { return ge::GRAPH_SUCCESS; } +-#endif ++//#endif + + ge::graphStatus CustomTbeAicpuOpInferFunc(ge::Operator &op) { + auto op_key = GetCustomOpKey(op); +diff --git a/mindspore/ccsrc/plugin/res_manager/ascend/op_adapter/transform_util.cc b/mindspore/ccsrc/plugin/res_manager/ascend/op_adapter/transform_util.cc +index 62b27674d88..02e56a80379 100644 +--- a/mindspore/ccsrc/plugin/res_manager/ascend/op_adapter/transform_util.cc ++++ b/mindspore/ccsrc/plugin/res_manager/ascend/op_adapter/transform_util.cc +@@ -52,6 +52,16 @@ class MsTensorRel { + private: + mutable MeTensorPtr tensor_; + }; ++ ++class MsTensorRelNew { ++ public: ++ explicit MsTensorRelNew(const shared_ptr &tensor) : tensor_(tensor) {} ++ ~MsTensorRelNew() = default; ++ void Rel() const { tensor_ = nullptr; } ++ ++ private: ++ mutable shared_ptr tensor_; ++}; + } // namespace + + class TensorRefData : public tensor::TensorData { +@@ -373,6 +383,57 @@ GeTensorPtr TransformUtil::ConvertTensor(const MeTensorPtr &tensor, const std::s + return tensor_ptr; + } + ++GeTensorPtr TransformUtil::ConvertTensor(const std::shared_ptr &tensor, const std::string &format, ++ bool copy) { ++ // get tensor data type size ++ MS_EXCEPTION_IF_NULL(tensor); ++ auto me_data_type = tensor->DataType(); ++#ifndef ENABLE_LITE_ACL ++ if (me_data_type == DataType::kObjectTypeString) { ++ return ConvertStringTensor(tensor, format); ++ } ++#endif ++ size_t type_size = GetDataTypeSize(static_cast(me_data_type)); ++ if (type_size == kErrorSize) { ++ MS_LOG(ERROR) << "The Me Tensor data type size is wrong, type size is: " << type_size; ++ return nullptr; ++ } ++ ++ // get tensor buff size ++ size_t data_buff_size = tensor->DataSize(); ++ if (data_buff_size == 0) { ++ MS_LOG(INFO) << "The Me Tensor data buff size is 0."; ++ } ++ // create ge tensor ++ auto desc = GetGeTensorDesc(tensor->Shape(), static_cast(tensor->DataType()), format); ++ if (desc == nullptr) { ++ MS_LOG(ERROR) << "Failed to get Tensor Desc"; ++ return nullptr; ++ } ++ GeTensorPtr tensor_ptr = make_shared(*desc); ++ if (tensor_ptr == nullptr) { ++ MS_LOG(ERROR) << "Failed to convert Me Tensor to Ge Tensor!"; ++ return nullptr; ++ } ++ if (copy) { ++ auto ret = tensor_ptr->SetData(std::static_pointer_cast(tensor->Data()).get(), data_buff_size); ++ if (ret != ge::GRAPH_SUCCESS) { ++ MS_LOG(ERROR) << "Failed to call ge::Tensor SetData(const uint8_t*, size), data size " << data_buff_size; ++ return nullptr; ++ } ++ } else { ++ MsTensorRelNew rel(tensor); ++ auto ret = tensor_ptr->SetData(static_cast(tensor->MutableData()), data_buff_size, ++ [rel](uint8_t *) -> void { rel.Rel(); }); ++ if (ret != ge::GRAPH_SUCCESS) { ++ MS_LOG(ERROR) << "Failed to call ge::Tensor SetData(uint8_t*, size, DeleteFunc), data size " << data_buff_size; ++ return nullptr; ++ } ++ } ++ MS_LOG(DEBUG) << "Convert Me Tensor to Ge Tensor success!"; ++ return tensor_ptr; ++} ++ + GeTensorPtr TransformUtil::ConvertScalar(const ValuePtr &val) { + auto ge_tensor = ConvertAnyUtil(val, AnyTraits()); + return make_shared(ge_tensor); +diff --git a/mindspore/ccsrc/plugin/res_manager/ascend/op_adapter/transform_util.h b/mindspore/ccsrc/plugin/res_manager/ascend/op_adapter/transform_util.h +index 4b4ef556bf9..7e7907b8c7b 100644 +--- a/mindspore/ccsrc/plugin/res_manager/ascend/op_adapter/transform_util.h ++++ b/mindspore/ccsrc/plugin/res_manager/ascend/op_adapter/transform_util.h +@@ -28,6 +28,7 @@ + #include "utils/shape_utils.h" + #include "plugin/res_manager/ascend/op_adapter/op_adapter_base.h" + #include "plugin/res_manager/ascend/visible.h" ++#include "include/api/types.h" + + namespace mindspore::device::ascend { + class ASCEND_RES_MANAGER_EXPORT TransformUtil { +@@ -78,6 +79,9 @@ class ASCEND_RES_MANAGER_EXPORT TransformUtil { + * */ + static GeTensorPtr ConvertTensor(const MeTensorPtr &tensor, const std::string &format, bool copy = true); + ++ static GeTensorPtr ConvertTensor(const std::shared_ptr &tensor, const std::string &format, ++ bool copy = true); ++ + /* + * Parameters: + * me_tensors: [vector] the data tensors in ME +diff --git a/mindspore/ccsrc/plugin/res_manager/ascend/symbol_interface/symbol_utils.h b/mindspore/ccsrc/plugin/res_manager/ascend/symbol_interface/symbol_utils.h +index 6847efecb3e..fb7bb08db67 100644 +--- a/mindspore/ccsrc/plugin/res_manager/ascend/symbol_interface/symbol_utils.h ++++ b/mindspore/ccsrc/plugin/res_manager/ascend/symbol_interface/symbol_utils.h +@@ -18,7 +18,6 @@ + #define MINDSPORE_CCSRC_TRANSFORM_SYMBOL_SYMBOL_UTILS_H_ + #include + #include "utils/log_adapter.h" +-#ifndef BUILD_LITE + #include "acl/acl.h" + #include "utils/ms_exception.h" + #include "include/backend/visible.h" +@@ -28,7 +27,6 @@ extern "C" BACKEND_EXPORT FuncGetRecentErrMsg acl_get_recent_err_msg; + + #ifndef ACL_ERROR_RT_DEVICE_MEM_ERROR + #define ACL_ERROR_RT_DEVICE_MEM_ERROR 507053 +-#endif + #ifndef ACL_ERROR_RT_HBM_MULTI_BIT_ECC_ERROR + #define ACL_ERROR_RT_HBM_MULTI_BIT_ECC_ERROR 507054 + #endif +diff --git a/mindspore/ccsrc/runtime/data_queue/data_queue_mgr.cc b/mindspore/ccsrc/runtime/data_queue/data_queue_mgr.cc +index 763f69654f8..320ac09206f 100644 +--- a/mindspore/ccsrc/runtime/data_queue/data_queue_mgr.cc ++++ b/mindspore/ccsrc/runtime/data_queue/data_queue_mgr.cc +@@ -250,7 +250,6 @@ DataQueueStatus DataQueueMgr::SetThreadDevice(const std::string &channel_name) c + return DataQueueStatus::SUCCESS; + } + +-#ifndef BUILD_LITE + void UpdateGetNextWithDataQueueItems(const AnfNodePtr &data_kernel, const std::vector &data) { + auto kernel_info = dynamic_cast(data_kernel->kernel_info()); + std::vector> device_tensors; +@@ -333,6 +332,5 @@ void UpdateGetNextNode(const PrimitivePtr &primitive, const std::vectorenable_recovery()) { + return false; + } +-#endif + + return true; + } +@@ -415,7 +411,6 @@ size_t GetDefragMemoryStepFreq() { + + bool WaitRuntimePipelineFinish(const OpContext *context, const std::string &name, + bool wait_kernel_launch_finish) { +-#ifndef BUILD_LITE + uint64_t start_time = 0; + PROFILER_START(start_time); + +@@ -443,9 +438,6 @@ bool WaitRuntimePipelineFinish(const OpContext *context, const std + return false; + } + return true; +-#else +- return true; +-#endif + } + + bool Copy(const DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_tensor) { +diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/kernel_actor.cc b/mindspore/ccsrc/runtime/graph_scheduler/actor/kernel_actor.cc +index a4eb6297b02..6163ce7fab0 100644 +--- a/mindspore/ccsrc/runtime/graph_scheduler/actor/kernel_actor.cc ++++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/kernel_actor.cc +@@ -128,17 +128,11 @@ void AddNodeToGraphTracker(const CNodePtr cnode, const std::string &actor_name) + std::vector comm_ranks; + if (group_name == "hccl_world_group") { + uint32_t rank_size = 1; +-#if !defined(BUILD_LITE) + rank_size = distributed::collective::CollectiveManager::instance()->global_rank_size(); +-#endif + comm_ranks.resize(rank_size); + std::iota(comm_ranks.begin(), comm_ranks.end(), 0); + } else { +-#if !defined(BUILD_LITE) + comm_ranks = distributed::collective::CollectiveManager::instance()->GetGroupRanks(group_name); +-#else +- comm_ranks = {0}; +-#endif + } + std::string comm_ranks_str = std::accumulate( + comm_ranks.begin(), comm_ranks.end(), std::string(), +diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/kernel_runner.cc b/mindspore/ccsrc/runtime/graph_scheduler/actor/kernel_runner.cc +index c74c57f3cd4..f323861ab2c 100644 +--- a/mindspore/ccsrc/runtime/graph_scheduler/actor/kernel_runner.cc ++++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/kernel_runner.cc +@@ -126,17 +126,11 @@ void AddNodeToGraphTracker(const CNodePtr cnode, const std::string &actor_name) + std::vector comm_ranks; + if (group_name == "hccl_world_group") { + uint32_t rank_size = 1; +-#if !defined(BUILD_LITE) + rank_size = distributed::collective::CollectiveManager::instance()->global_rank_size(); +-#endif + comm_ranks.resize(rank_size); + std::iota(comm_ranks.begin(), comm_ranks.end(), 0); + } else { +-#if !defined(BUILD_LITE) + comm_ranks = distributed::collective::CollectiveManager::instance()->GetGroupRanks(group_name); +-#else +- comm_ranks = {0}; +-#endif + } + std::string comm_ranks_str = std::accumulate( + comm_ranks.begin(), comm_ranks.end(), std::string(), +diff --git a/mindspore/ccsrc/runtime/graph_scheduler/execution_order_check/comm_execution_order_check.cc b/mindspore/ccsrc/runtime/graph_scheduler/execution_order_check/comm_execution_order_check.cc +index eda99129c4f..6b8d918fcbe 100644 +--- a/mindspore/ccsrc/runtime/graph_scheduler/execution_order_check/comm_execution_order_check.cc ++++ b/mindspore/ccsrc/runtime/graph_scheduler/execution_order_check/comm_execution_order_check.cc +@@ -79,9 +79,7 @@ uint32_t Process::GetRankSize() { + static uint32_t rank_size = 1; + static bool is_initialized = false; + if (!is_initialized) { +-#if !defined(BUILD_LITE) + rank_size = distributed::collective::CollectiveManager::instance()->global_rank_size(); +-#endif + is_initialized = true; + } + +@@ -92,11 +90,9 @@ std::string Process::GetRankID() { + static uint32_t rank_id = 0; + static bool is_initialized = false; + if (!is_initialized) { +-#if !defined(BUILD_LITE) + if (distributed::collective::CollectiveManager::instance()->initialized()) { + rank_id = CommManager::GetInstance().GetRank(); + } +-#endif + is_initialized = true; + } + +@@ -462,11 +458,7 @@ void Process::FetchCommRanksCache(const std::string &group_name) { + comm_ranks.resize(GetRankSize()); + std::iota(comm_ranks.begin(), comm_ranks.end(), 0); + } else { +-#if !defined(BUILD_LITE) + comm_ranks = distributed::collective::CollectiveManager::instance()->GetGroupRanks(group_name); +-#else +- comm_ranks = {0}; +-#endif + } + comm_rank_cache_[group_name] = comm_ranks; + } +diff --git a/mindspore/core/include/utils/anf_utils.h b/mindspore/core/include/utils/anf_utils.h +index 8dbb05ad482..d17cf0f1012 100644 +--- a/mindspore/core/include/utils/anf_utils.h ++++ b/mindspore/core/include/utils/anf_utils.h +@@ -94,6 +94,8 @@ class MS_CORE_API AnfUtils { + static bool IsNodeInGraphKernel(const AnfNodePtr &node); + // check whether the node is a KernelPacket node. + static bool IsKernelPacket(const AnfNodePtr &node); ++ // check whether two node is same ++ static bool AnfEqual(const BaseRef &a, const BaseRef &b); + // Set dump flag to CNode's primitive. + static void SetDumpFlag(const AnfNodePtr &node); + // Get dump flag from CNode's primitive. +diff --git a/mindspore/core/mindrt/CMakeLists.txt b/mindspore/core/mindrt/CMakeLists.txt +index 171401d4e30..c6c1b9b45cf 100644 +--- a/mindspore/core/mindrt/CMakeLists.txt ++++ b/mindspore/core/mindrt/CMakeLists.txt +@@ -26,6 +26,16 @@ else() + ) + endif() + ++if (CMAKE_SYSTEM_NAME MATCHES "Generic") ++ set(MINDRT_SRC ${MINDRT_SRC} ++ ${CMAKE_CURRENT_SOURCE_DIR}/src/thread/threadpool_ohos.cc ++ ${CMAKE_CURRENT_SOURCE_DIR}/src/thread/core_affinity_ohos.cc) ++else() ++ set(MINDRT_SRC ${MINDRT_SRC} ++ ${CMAKE_CURRENT_SOURCE_DIR}/src/thread/core_affinity_normal.cc ++ ${CMAKE_CURRENT_SOURCE_DIR}/src/thread/threadpool_normal.cc) ++endif() ++ + if(CMAKE_SYSTEM_NAME MATCHES "Windows") + add_compile_definitions(BUILDING_CORE_DLL) + endif() +diff --git a/mindspore/core/mindrt/include/thread/core_affinity.h b/mindspore/core/mindrt/include/thread/core_affinity.h +index de6db97c2b8..882c6afb7cc 100644 +--- a/mindspore/core/mindrt/include/thread/core_affinity.h ++++ b/mindspore/core/mindrt/include/thread/core_affinity.h +@@ -20,18 +20,6 @@ + #include + #include + +-#ifdef __ANDROID__ +-#define BIND_CORE +-#include +-#endif +-#ifdef _WIN32 +-#define BIND_CORE +-#endif +-// Lite not support bind core. +-#if !defined(BUILD_LITE) && defined(__linux__) +-#define BIND_CORE +-#endif +- + namespace mindspore { + enum BindMode { + Power_NoBind = 0, // free schedule +@@ -62,12 +50,8 @@ class CoreAffinity { + static float GetServerFrequency(); + + private: +-#ifdef _WIN32 + int SetAffinity(); +-#elif defined(BIND_CORE) + int SetAffinity(const pthread_t &thread_id, cpu_set_t *cpu_set); +-#endif +- + int InitBindCoreId(size_t thread_num, BindMode bind_mode); + + int BindThreadsToCoreList(const std::vector &workers); +diff --git a/mindspore/core/mindrt/include/thread/threadpool.h b/mindspore/core/mindrt/include/thread/threadpool.h +index 511c3fd2b17..187b138f73c 100644 +--- a/mindspore/core/mindrt/include/thread/threadpool.h ++++ b/mindspore/core/mindrt/include/thread/threadpool.h +@@ -117,7 +117,7 @@ class Worker { + + #ifdef _WIN32 + uint64_t core_id() { return core_id_; } +-#elif defined(BIND_CORE) ++#elif defined(__ANDROID__) || defined(_WIN32) || defined(__linux__) + void set_mask(const cpu_set_t &mask) { mask_ = mask; } + pthread_t handle() { + THREAD_TEST_TRUE(thread_ == nullptr); +@@ -137,7 +137,7 @@ class Worker { + std::unique_ptr thread_{nullptr}; + #ifdef _WIN32 + uint64_t core_id_; +-#elif defined(BIND_CORE) ++#elif defined(__ANDROID__) || defined(_WIN32) || defined(__linux__) + cpu_set_t mask_; + #endif + std::atomic_int status_{kThreadBusy}; +diff --git a/mindspore/core/mindrt/src/thread/core_affinity.cc b/mindspore/core/mindrt/src/thread/core_affinity.cc +index eeff23dffef..77ec477ad7b 100644 +--- a/mindspore/core/mindrt/src/thread/core_affinity.cc ++++ b/mindspore/core/mindrt/src/thread/core_affinity.cc +@@ -309,7 +309,7 @@ std::vector CoreAffinity::GetCoreId(size_t thread_num, BindMode bind_mode) + std::vector bind_id; + #ifdef _WIN32 + return bind_id; +-#elif defined(BIND_CORE) ++#elif defined(__ANDROID__) || defined(_WIN32) || defined(__linux__) + if (core_num_ != sorted_id_.size()) { + THREAD_ERROR("init sorted core id failed"); + return bind_id; +@@ -340,40 +340,12 @@ int CoreAffinity::InitBindCoreId(size_t thread_num, BindMode bind_mode) { + #endif + return THREAD_OK; + } +- +-#ifdef _WIN32 + int CoreAffinity::SetAffinity() { return THREAD_OK; } +-#elif defined(BIND_CORE) +-int CoreAffinity::SetAffinity(const pthread_t &thread_id, cpu_set_t *cpu_set) { +-#ifdef __ANDROID__ +-#if __ANDROID_API__ >= 21 +- THREAD_INFO("thread: %d, mask: %lu", pthread_gettid_np(thread_id), cpu_set->__bits[0]); +- int ret = sched_setaffinity(pthread_gettid_np(thread_id), sizeof(cpu_set_t), cpu_set); +- if (ret != THREAD_OK) { +- THREAD_ERROR("bind thread %d to cpu failed. ERROR %d", pthread_gettid_np(thread_id), ret); +- return THREAD_ERROR; +- } +-#endif +-#else +-#if defined(__APPLE__) +- THREAD_ERROR("not bind thread to apple's cpu."); +- return THREAD_ERROR; +-#else +- int ret = pthread_setaffinity_np(thread_id, sizeof(cpu_set_t), cpu_set); +- if (ret != THREAD_OK) { +- THREAD_ERROR("set thread: %lu to cpu failed", thread_id); +- return THREAD_ERROR; +- } +-#endif // __APPLE__ +-#endif +- return THREAD_OK; +-} +-#endif + + int CoreAffinity::FreeScheduleThreads(const std::vector &workers) { + #ifdef _WIN32 + return THREAD_OK; +-#elif defined(BIND_CORE) ++#elif defined(__ANDROID__) || defined(_WIN32) || defined(__linux__) + cpu_set_t mask; + CPU_ZERO(&mask); + for (int i : bind_id_) { +@@ -385,14 +357,14 @@ int CoreAffinity::FreeScheduleThreads(const std::vector &workers) { + return THREAD_ERROR; + } + } +-#endif // BIND_CORE ++#endif + return THREAD_OK; + } + + int CoreAffinity::BindThreadsToCoreList(const std::vector &workers) { + #ifdef _WIN32 + return THREAD_OK; +-#elif defined(BIND_CORE) ++#elif defined(__ANDROID__) || defined(__linux__) + if (bind_id_.empty()) { + THREAD_INFO("bind id is empty, it will not bind thread"); + return THREAD_OK; +@@ -411,14 +383,14 @@ int CoreAffinity::BindThreadsToCoreList(const std::vector &workers) { + THREAD_INFO("set thread[%zu] affinity to core[%d] success", i, bind_id_[i % window]); + workers[i]->set_frequency(core_freq_[bind_id_[i]]); + } +-#endif // BIND_CORE ++#endif // __ANDROID__ __linux__ + return THREAD_OK; + } + + int CoreAffinity::BindProcess(BindMode bind_mode) { + #ifdef _WIN32 + return THREAD_OK; +-#elif defined(BIND_CORE) ++#elif defined(__ANDROID__) || defined(__linux__) + if (bind_id_.empty()) { + // initializes bind id before bind currently process + THREAD_ERROR("bind id is empty"); +@@ -436,7 +408,7 @@ int CoreAffinity::BindProcess(BindMode bind_mode) { + return SetAffinity(pthread_self(), &mask); + #else + return THREAD_OK; +-#endif // BIND_CORE ++#endif // __ANDROID__ __linux__ + } + + int CoreAffinity::BindThreads(const std::vector &workers, BindMode bind_mode) { +diff --git a/mindspore/core/mindrt/src/thread/core_affinity_normal.cc b/mindspore/core/mindrt/src/thread/core_affinity_normal.cc +new file mode 100644 +index 00000000000..3b3a9b8967d +--- /dev/null ++++ b/mindspore/core/mindrt/src/thread/core_affinity_normal.cc +@@ -0,0 +1,57 @@ ++/** ++ * Copyright 2021-2023 Huawei Technologies Co., Ltd ++ * ++ * Licensed under the Apache License, Version 2.0 (the "License"); ++ * you may not use this file except in compliance with the License. ++ * You may obtain a copy of the License at ++ * ++ * http://www.apache.org/licenses/LICENSE-2.0 ++ * ++ * Unless required by applicable law or agreed to in writing, software ++ * distributed under the License is distributed on an "AS IS" BASIS, ++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++ * See the License for the specific language governing permissions and ++ * limitations under the License. ++ */ ++ ++#include "thread/core_affinity.h" ++#include ++#include ++#include ++#include ++#ifdef MS_COMPILE_IOS ++#include ++#include ++#include ++#endif // MS_COMPILE_IOS ++#include "thread/threadpool.h" ++#ifdef _WIN32 ++#include ++#endif ++ ++namespace mindspore { ++int CoreAffinity::SetAffinity(const pthread_t &thread_id, cpu_set_t *cpu_set) { ++#ifdef __ANDROID__ ++#if __ANDROID_API__ >= 21 ++ THREAD_INFO("thread: %d, mask: %lu", pthread_gettid_np(thread_id), cpu_set->__bits[0]); ++ int ret = sched_setaffinity(pthread_gettid_np(thread_id), sizeof(cpu_set_t), cpu_set); ++ if (ret != THREAD_OK) { ++ THREAD_ERROR("bind thread %d to cpu failed. ERROR %d", pthread_gettid_np(thread_id), ret); ++ return THREAD_ERROR; ++ } ++#endif ++#else ++#if defined(__APPLE__) ++ THREAD_ERROR("not bind thread to apple's cpu."); ++ return THREAD_ERROR; ++#else ++ int ret = pthread_setaffinity_np(thread_id, sizeof(cpu_set_t), cpu_set); ++ if (ret != THREAD_OK) { ++ THREAD_ERROR("set thread: %lu to cpu failed", thread_id); ++ return THREAD_ERROR; ++ } ++#endif // __APPLE__ ++#endif ++ return THREAD_OK; ++} ++} // namespace mindspore +diff --git a/mindspore/core/mindrt/src/thread/core_affinity_ohos.cc b/mindspore/core/mindrt/src/thread/core_affinity_ohos.cc +new file mode 100644 +index 00000000000..44f20bcd9f0 +--- /dev/null ++++ b/mindspore/core/mindrt/src/thread/core_affinity_ohos.cc +@@ -0,0 +1,36 @@ ++/** ++ * Copyright 2021-2023 Huawei Technologies Co., Ltd ++ * ++ * Licensed under the Apache License, Version 2.0 (the "License"); ++ * you may not use this file except in compliance with the License. ++ * You may obtain a copy of the License at ++ * ++ * http://www.apache.org/licenses/LICENSE-2.0 ++ * ++ * Unless required by applicable law or agreed to in writing, software ++ * distributed under the License is distributed on an "AS IS" BASIS, ++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++ * See the License for the specific language governing permissions and ++ * limitations under the License. ++ */ ++ ++#include "thread/core_affinity.h" ++#include ++#include ++#include ++#include ++#ifdef MS_COMPILE_IOS ++#include ++#include ++#include ++#endif // MS_COMPILE_IOS ++#include "thread/threadpool.h" ++#ifdef _WIN32 ++#include ++#endif ++ ++namespace mindspore { ++int CoreAffinity::SetAffinity(const pthread_t &thread_id, cpu_set_t *cpu_set) { ++ return THREAD_OK; ++} ++} // namespace mindspore +diff --git a/mindspore/core/mindrt/src/thread/threadpool.cc b/mindspore/core/mindrt/src/thread/threadpool.cc +index ea84f779812..452c5ee8389 100644 +--- a/mindspore/core/mindrt/src/thread/threadpool.cc ++++ b/mindspore/core/mindrt/src/thread/threadpool.cc +@@ -67,7 +67,7 @@ void Worker::ChildAfterFork() { + } + } + +-#if defined(BIND_CORE) && !defined(__ANDROID__) && !defined(__APPLE__) && !defined(_MSC_VER) && !defined(_WIN32) ++#if defined(__linux__) && !defined(_MSC_VER) + std::string MaskToStr(const cpu_set_t *mask) { + std::stringstream ss; + size_t cpu_num = static_cast(sysconf(_SC_NPROCESSORS_ONLN)); +@@ -78,184 +78,12 @@ std::string MaskToStr(const cpu_set_t *mask) { + } + #endif + +-void Worker::SetAffinity() { +-#ifdef _WIN32 +- SetWindowsSelfAffinity(core_id_); +-#elif defined(BIND_CORE) +-#ifdef __ANDROID__ +- int ret = sched_setaffinity(gettid(), sizeof(cpu_set_t), &mask_); +- if (ret != THREAD_OK) { +- THREAD_ERROR("bind thread %d to cpu failed. ERROR %d", gettid(), errno); +- } +- return; +-#else +-#if !defined(__APPLE__) && !defined(_MSC_VER) +- +- THREAD_INFO("Worker pthread_setaffinity_np, mask %s", MaskToStr(&mask_)); +- int ret = pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &mask_); +- if (ret != THREAD_OK) { +- THREAD_ERROR("bind thread %lu to cpu failed. ERROR %d", pthread_self(), errno); +- } +- return; +-#endif +-#endif +-#endif +-} +- +-std::vector parse_cpu_list(const std::string &cpu_str) { +- std::vector cpus; +- std::stringstream ss(cpu_str); +- std::string item; +- +- while (std::getline(ss, item, ',')) { +- cpus.push_back(std::stoi(item)); +- } +- return cpus; +-} +- +-#if defined(__linux__) +-void ThreadPool::ThreadPoolSetAffinity(const std::vector &threads) { +-#if defined(BIND_CORE) && !defined(__ANDROID__) +- auto thread_num = threads.size(); +- MS_LOG(INFO) << "Start to bind core for actor thread for [" << thread_num << "] threads."; +- auto env_runtime_reserved = std::getenv("CONFIG_BIND_RUNTIME_LIST"); +- if (env_runtime_reserved == nullptr) { +- return; +- } +- +- MS_LOG(WARNING) << "Start to bind core base on CONFIG_BIND_RUNTIME_LIST."; +- +- std::vector cpu_list = parse_cpu_list(std::string(env_runtime_reserved)); +- if (cpu_list.empty()) { +- MS_LOG(WARNING) << "Cpu list is empty, bind core is not enabled."; +- return; +- } +- int ret; +- auto env_enable_fix = std::getenv("ACTOR_THREAD_FIX_BIND"); +- if (env_enable_fix != nullptr && (std::string(env_enable_fix) == "false" || std::string(env_enable_fix) == "False")) { +- cpu_set_t cpuset; +- CPU_ZERO(&cpuset); +- +- for (const auto &cpu_id : cpu_list) { +- CPU_SET(static_cast(cpu_id), &cpuset); +- } +- +- for (size_t i = 0; i < thread_num; i++) { +- ret = pthread_setaffinity_np(threads[i], sizeof(cpu_set_t), &cpuset); +- if (ret != 0) { +- MS_LOG(WARNING) << "Fail to bind core to " << cpu_list << " for thread " << threads[i]; +- } else { +- MS_LOG(WARNING) << "Success to bind core to " << cpu_list << " for thread " << threads[i]; +- } +- } +- } else if (env_enable_fix != nullptr && (std::string(env_enable_fix) == "cluster")) { +- int cpus_per_thread = cpu_list.size() / thread_num; +- int offset = 0; +- +- for (size_t i = 0; i < thread_num; i++) { +- cpu_set_t cpuset; +- CPU_ZERO(&cpuset); +- std::vector sub_list(cpu_list.begin() + offset, cpu_list.begin() + offset + cpus_per_thread); +- for (const auto &cpu_id : sub_list) { +- CPU_SET(static_cast(cpu_id), &cpuset); +- } +- ret = pthread_setaffinity_np(threads[i], sizeof(cpu_set_t), &cpuset); +- if (ret != 0) { +- MS_LOG(WARNING) << "Fail to bind core to " << sub_list << " for thread " << threads[i]; +- } else { +- MS_LOG(WARNING) << "Success to bind core to " << sub_list << " for thread " << threads[i]; +- } +- offset += cpus_per_thread; +- } +- } else { +- for (size_t i = 0; i < thread_num; i++) { +- cpu_set_t cpuset; +- CPU_ZERO(&cpuset); +- CPU_SET(cpu_list[i % cpu_list.size()], &cpuset); +- +- ret = pthread_setaffinity_np(threads[i], sizeof(cpu_set_t), &cpuset); +- if (ret != 0) { +- MS_LOG(WARNING) << "Fail to bind core to " << cpu_list[i % cpu_list.size()] << " for thread " << threads[i]; +- } else { +- MS_LOG(WARNING) << "Success to bind core to " << cpu_list[i % cpu_list.size()] << " for thread " << threads[i]; +- } +- } +- } +-#else +- MS_LOG(ERROR) << "Call not implemented function."; +- MINDRT_EXIT("Not implemented error"); +-#endif +-} +- +-void ThreadPool::APIThreadPoolSetAffinity(const std::vector &threads, const std::vector &cpu_list, +- const std::string actor_thread_fix_bind) { +-#if defined(BIND_CORE) && !defined(__ANDROID__) +- auto thread_num = threads.size(); +- MS_LOG(INFO) << "Start to bind core for actor thread for [" << thread_num << "] threads."; +- int ret; +- if (!actor_thread_fix_bind.empty() && +- (std::string(actor_thread_fix_bind) == "true" || std::string(actor_thread_fix_bind) == "True")) { +- for (size_t i = 0; i < thread_num; i++) { +- cpu_set_t cpuset; +- CPU_ZERO(&cpuset); +- CPU_SET(cpu_list[i % cpu_list.size()], &cpuset); +- +- ret = pthread_setaffinity_np(threads[i], sizeof(cpu_set_t), &cpuset); +- if (ret != 0) { +- MS_LOG(WARNING) << "Fail to bind core to " << cpu_list[i % cpu_list.size()] << " for thread " << threads[i]; +- } else { +- MS_LOG(WARNING) << "Success to bind core to " << cpu_list[i % cpu_list.size()] << " for thread " << threads[i]; +- } +- } +- } else if (!actor_thread_fix_bind.empty() && (std::string(actor_thread_fix_bind) == "cluster")) { +- int cpus_per_thread = cpu_list.size() / thread_num; +- int offset = 0; +- +- for (size_t i = 0; i < thread_num; i++) { +- cpu_set_t cpuset; +- CPU_ZERO(&cpuset); +- std::vector sub_list(cpu_list.begin() + offset, cpu_list.begin() + offset + cpus_per_thread); +- for (const auto &cpu_id : sub_list) { +- CPU_SET(static_cast(cpu_id), &cpuset); +- } +- ret = pthread_setaffinity_np(threads[i], sizeof(cpu_set_t), &cpuset); +- if (ret != 0) { +- MS_LOG(WARNING) << "Fail to bind core to " << sub_list << " for thread " << threads[i]; +- } else { +- MS_LOG(WARNING) << "Success to bind core to " << sub_list << " for thread " << threads[i]; +- } +- offset += cpus_per_thread; +- } +- } else { +- cpu_set_t cpuset; +- CPU_ZERO(&cpuset); +- +- for (const auto &cpu_id : cpu_list) { +- CPU_SET(static_cast(cpu_id), &cpuset); +- } +- +- for (size_t i = 0; i < thread_num; i++) { +- ret = pthread_setaffinity_np(threads[i], sizeof(cpu_set_t), &cpuset); +- if (ret != 0) { +- MS_LOG(WARNING) << "Fail to bind core to " << cpu_list << " for thread " << threads[i]; +- } else { +- MS_LOG(WARNING) << "Success to bind core to " << cpu_list << " for thread " << threads[i]; +- } +- } +- } +-#else +- MS_LOG(ERROR) << "Call not implemented function."; +- MINDRT_EXIT("Not implemented error"); +-#endif +-} +-#endif +- + void Worker::InitWorkerMask(const std::vector &core_list, const size_t workers_size) { + core_list_ = core_list; + #ifdef _WIN32 + static uint32_t windows_core_index = 0; + core_id_ = windows_core_index++; +-#elif defined(BIND_CORE) ++#elif defined(__ANDROID__) || defined(__linux__) + if (core_list.empty()) { + return; + } +@@ -604,7 +432,7 @@ Worker *ThreadPool::CurrentWorker() const { + } + + int ThreadPool::InitAffinityInfo() { +-#ifdef BIND_CORE ++#if defined(__ANDROID__) || defined(_WIN32) || defined(__linux__) + affinity_ = new (std::nothrow) CoreAffinity(); + THREAD_ERROR_IF_NULL(affinity_); + int ret = affinity_->InitHardwareCoreInfo(); +diff --git a/mindspore/core/mindrt/src/thread/threadpool_normal.cc b/mindspore/core/mindrt/src/thread/threadpool_normal.cc +new file mode 100644 +index 00000000000..57d2d4acfe3 +--- /dev/null ++++ b/mindspore/core/mindrt/src/thread/threadpool_normal.cc +@@ -0,0 +1,194 @@ ++/** ++ * Copyright 2021-2023 Huawei Technologies Co., Ltd ++ * ++ * Licensed under the Apache License, Version 2.0 (the "License"); ++ * you may not use this file except in compliance with the License. ++ * You may obtain a copy of the License at ++ * ++ * http://www.apache.org/licenses/LICENSE-2.0 ++ * ++ * Unless required by applicable law or agreed to in writing, software ++ * distributed under the License is distributed on an "AS IS" BASIS, ++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++ * See the License for the specific language governing permissions and ++ * limitations under the License. ++ */ ++#ifndef _MSC_VER ++#include ++#include ++#endif ++#include ++#include ++#include ++#include ++#include ++#include ++#include "thread/threadpool.h" ++#include "thread/core_affinity.h" ++ ++namespace mindspore { ++std::vector parse_cpu_list(const std::string &cpu_str) { ++ std::vector cpus; ++ std::stringstream ss(cpu_str); ++ std::string item; ++ ++ while (std::getline(ss, item, ',')) { ++ cpus.push_back(std::stoi(item)); ++ } ++ return cpus; ++} ++ ++void Worker::SetAffinity() { ++#ifdef _WIN32 ++ SetWindowsSelfAffinity(core_id_); ++#elif defined(__ANDROID__) ++ int ret = sched_setaffinity(gettid(), sizeof(cpu_set_t), &mask_); ++ if (ret != THREAD_OK) { ++ THREAD_ERROR("bind thread %d to cpu failed. ERROR %d", gettid(), errno); ++ } ++ return; ++#else ++#if !defined(__APPLE__) && !defined(_MSC_VER) ++ THREAD_INFO("Worker pthread_setaffinity_np, mask %s", MaskToStr(&mask_)); ++ int ret = pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &mask_); ++ if (ret != THREAD_OK) { ++ THREAD_ERROR("bind thread %lu to cpu failed. ERROR %d", pthread_self(), errno); ++ } ++ return; ++#endif ++#endif ++} ++ ++void ThreadPool::ThreadPoolSetAffinity(const std::vector &threads) { ++ auto thread_num = threads.size(); ++ MS_LOG(INFO) << "Start to bind core for actor thread for [" << thread_num << "] threads."; ++#if defined(__linux__) && !defined(__ANDROID__) && !defined(__APPLE__) && !defined(_MSC_VER) && !defined(_WIN32) ++ auto env_runtime_reserved = std::getenv("CONFIG_BIND_RUNTIME_LIST"); ++ if (env_runtime_reserved == nullptr) { ++ return; ++ } ++ ++ MS_LOG(WARNING) << "Start to bind core base on CONFIG_BIND_RUNTIME_LIST."; ++ ++ std::vector cpu_list = parse_cpu_list(std::string(env_runtime_reserved)); ++ if (cpu_list.empty()) { ++ MS_LOG(WARNING) << "Cpu list is empty, bind core is not enabled."; ++ return; ++ } ++ int ret; ++ auto env_enable_fix = std::getenv("ACTOR_THREAD_FIX_BIND"); ++ if (env_enable_fix != nullptr && (std::string(env_enable_fix) == "false" || std::string(env_enable_fix) == "False")) { ++ cpu_set_t cpuset; ++ CPU_ZERO(&cpuset); ++ ++ for (const auto &cpu_id : cpu_list) { ++ CPU_SET(static_cast(cpu_id), &cpuset); ++ } ++ ++ for (size_t i = 0; i < thread_num; i++) { ++ ret = pthread_setaffinity_np(threads[i], sizeof(cpu_set_t), &cpuset); ++ if (ret != 0) { ++ MS_LOG(WARNING) << "Fail to bind core to " << cpu_list << " for thread " << threads[i]; ++ } else { ++ MS_LOG(WARNING) << "Success to bind core to " << cpu_list << " for thread " << threads[i]; ++ } ++ } ++ } else if (env_enable_fix != nullptr && (std::string(env_enable_fix) == "cluster")) { ++ int cpus_per_thread = cpu_list.size() / thread_num; ++ int offset = 0; ++ ++ for (size_t i = 0; i < thread_num; i++) { ++ cpu_set_t cpuset; ++ CPU_ZERO(&cpuset); ++ std::vector sub_list(cpu_list.begin() + offset, cpu_list.begin() + offset + cpus_per_thread); ++ for (const auto &cpu_id : sub_list) { ++ CPU_SET(static_cast(cpu_id), &cpuset); ++ } ++ ret = pthread_setaffinity_np(threads[i], sizeof(cpu_set_t), &cpuset); ++ if (ret != 0) { ++ MS_LOG(WARNING) << "Fail to bind core to " << sub_list << " for thread " << threads[i]; ++ } else { ++ MS_LOG(WARNING) << "Success to bind core to " << sub_list << " for thread " << threads[i]; ++ } ++ offset += cpus_per_thread; ++ } ++ } else { ++ for (size_t i = 0; i < thread_num; i++) { ++ cpu_set_t cpuset; ++ CPU_ZERO(&cpuset); ++ CPU_SET(cpu_list[i % cpu_list.size()], &cpuset); ++ ++ ret = pthread_setaffinity_np(threads[i], sizeof(cpu_set_t), &cpuset); ++ if (ret != 0) { ++ MS_LOG(WARNING) << "Fail to bind core to " << cpu_list[i % cpu_list.size()] << " for thread " ++ << threads[i]; ++ } else { ++ MS_LOG(WARNING) << "Success to bind core to " << cpu_list[i % cpu_list.size()] << " for thread " ++ << threads[i]; ++ } ++ } ++ } ++#endif ++} ++ ++void ThreadPool::APIThreadPoolSetAffinity(const std::vector &threads, const std::vector &cpu_list, ++ const std::string actor_thread_fix_bind) { ++ auto thread_num = threads.size(); ++ MS_LOG(INFO) << "Start to bind core for actor thread for [" << thread_num << "] threads."; ++#if defined(__linux__) && !defined(__ANDROID__) && !defined(__APPLE__) && !defined(_MSC_VER) && !defined(_WIN32) ++ int ret; ++ if (!actor_thread_fix_bind.empty() && ++ (std::string(actor_thread_fix_bind) == "true" || std::string(actor_thread_fix_bind) == "True")) { ++ for (size_t i = 0; i < thread_num; i++) { ++ cpu_set_t cpuset; ++ CPU_ZERO(&cpuset); ++ CPU_SET(cpu_list[i % cpu_list.size()], &cpuset); ++ ++ ret = pthread_setaffinity_np(threads[i], sizeof(cpu_set_t), &cpuset); ++ if (ret != 0) { ++ MS_LOG(WARNING) << "Fail to bind core to " << cpu_list[i % cpu_list.size()] << " for thread " ++ << threads[i]; ++ } else { ++ MS_LOG(WARNING) << "Success to bind core to " << cpu_list[i % cpu_list.size()] << " for thread " ++ << threads[i]; ++ } ++ } ++ } else if (!actor_thread_fix_bind.empty() && (std::string(actor_thread_fix_bind) == "cluster")) { ++ int cpus_per_thread = cpu_list.size() / thread_num; ++ int offset = 0; ++ ++ for (size_t i = 0; i < thread_num; i++) { ++ cpu_set_t cpuset; ++ CPU_ZERO(&cpuset); ++ std::vector sub_list(cpu_list.begin() + offset, cpu_list.begin() + offset + cpus_per_thread); ++ for (const auto &cpu_id : sub_list) { ++ CPU_SET(static_cast(cpu_id), &cpuset); ++ } ++ ret = pthread_setaffinity_np(threads[i], sizeof(cpu_set_t), &cpuset); ++ if (ret != 0) { ++ MS_LOG(WARNING) << "Fail to bind core to " << sub_list << " for thread " << threads[i]; ++ } else { ++ MS_LOG(WARNING) << "Success to bind core to " << sub_list << " for thread " << threads[i]; ++ } ++ offset += cpus_per_thread; ++ } ++ } else { ++ cpu_set_t cpuset; ++ CPU_ZERO(&cpuset); ++ ++ for (const auto &cpu_id : cpu_list) { ++ CPU_SET(static_cast(cpu_id), &cpuset); ++ } ++ ++ for (size_t i = 0; i < thread_num; i++) { ++ ret = pthread_setaffinity_np(threads[i], sizeof(cpu_set_t), &cpuset); ++ if (ret != 0) { ++ MS_LOG(WARNING) << "Fail to bind core to " << cpu_list << " for thread " << threads[i]; ++ } else { ++ MS_LOG(WARNING) << "Success to bind core to " << cpu_list << " for thread " << threads[i]; ++ } ++ } ++ } ++#endif ++} ++} // namespace mindspore +diff --git a/mindspore/core/mindrt/src/thread/threadpool_ohos.cc b/mindspore/core/mindrt/src/thread/threadpool_ohos.cc +new file mode 100644 +index 00000000000..137a87e5c61 +--- /dev/null ++++ b/mindspore/core/mindrt/src/thread/threadpool_ohos.cc +@@ -0,0 +1,48 @@ ++/** ++ * Copyright 2021-2023 Huawei Technologies Co., Ltd ++ * ++ * Licensed under the Apache License, Version 2.0 (the "License"); ++ * you may not use this file except in compliance with the License. ++ * You may obtain a copy of the License at ++ * ++ * http://www.apache.org/licenses/LICENSE-2.0 ++ * ++ * Unless required by applicable law or agreed to in writing, software ++ * distributed under the License is distributed on an "AS IS" BASIS, ++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++ * See the License for the specific language governing permissions and ++ * limitations under the License. ++ */ ++#ifndef _MSC_VER ++#include ++#include ++#endif ++#include ++#include ++#include ++#include ++#include ++#include ++#include "thread/threadpool.h" ++#include "thread/core_affinity.h" ++ ++namespace mindspore { ++void Worker::SetAffinity() { ++ int ret = sched_setaffinity(gettid(), sizeof(cpu_set_t), &mask_); ++ if (ret != THREAD_OK) { ++ THREAD_ERROR("bind thread %d to cpu failed. ERROR %d", gettid(), errno); ++ } ++ return; ++} ++ ++void ThreadPool::ThreadPoolSetAffinity(const std::vector &threads) { ++ auto thread_num = threads.size(); ++ MS_LOG(INFO) << "Start to bind core for actor thread for [" << thread_num << "] threads."; ++} ++ ++void ThreadPool::APIThreadPoolSetAffinity(const std::vector &threads, const std::vector &cpu_list, ++ const std::string actor_thread_fix_bind) { ++ auto thread_num = threads.size(); ++ MS_LOG(INFO) << "Start to bind core for actor thread for [" << thread_num << "] threads."; ++} ++} // namespace mindspore +diff --git a/mindspore/core/utils/anf_utils.cc b/mindspore/core/utils/anf_utils.cc +index c1e5aaae4d2..b11dddff168 100644 +--- a/mindspore/core/utils/anf_utils.cc ++++ b/mindspore/core/utils/anf_utils.cc +@@ -480,6 +480,65 @@ bool AnfUtils::IsKernelPacket(const AnfNodePtr &node) { + return func_graph != nullptr && func_graph->has_attr(FUNC_GRAPH_ATTR_KERNEL_PACKET); + } + ++bool AnfUtils::AnfEqual(const BaseRef &a, const BaseRef &b) { ++ if (utils::isa(a) && utils::isa(b)) { ++ auto a_node = utils::cast(a); ++ auto b_node = utils::cast(b); ++ MS_EXCEPTION_IF_NULL(a_node); ++ MS_EXCEPTION_IF_NULL(b_node); ++ if (IsValueNode(a_node) && IsValueNode(b_node)) { ++ auto a_value_node = a_node->cast(); ++ MS_EXCEPTION_IF_NULL(a_value_node); ++ auto a_value = a_value_node->value(); ++ MS_EXCEPTION_IF_NULL(a_value); ++ auto a_prim = a_value->cast(); ++ MS_EXCEPTION_IF_NULL(a_prim); ++ ++ auto b_value_node = b_node->cast(); ++ MS_EXCEPTION_IF_NULL(b_value_node); ++ auto b_value = b_value_node->value(); ++ MS_EXCEPTION_IF_NULL(b_value); ++ auto b_prim = b_value->cast(); ++ MS_EXCEPTION_IF_NULL(b_prim); ++ ++ return a_prim->name() == b_prim->name(); ++ } else if (a_node->isa() && b_node->isa()) { ++ auto a_value_node_ptr = a_node->cast(); ++ if (a_value_node_ptr == nullptr) { ++ MS_LOG(INTERNAL_EXCEPTION) << "Cast value node ptr fail, node: " << a_node->DebugString(); ++ } ++ auto a_value_ptr = a_value_node_ptr->value(); ++ if (a_value_ptr == nullptr) { ++ MS_LOG(INTERNAL_EXCEPTION) << "Value ptr is nullptr, node: " << a_node->DebugString(); ++ } ++ ++ auto b_value_node_ptr = b_node->cast(); ++ if (b_value_node_ptr == nullptr) { ++ MS_LOG(INTERNAL_EXCEPTION) << "Cast value node ptr fail, node: " << b_node->DebugString(); ++ } ++ auto b_value_ptr = b_value_node_ptr->value(); ++ if (b_value_ptr == nullptr) { ++ MS_LOG(INTERNAL_EXCEPTION) << "Value ptr is nullptr, node: " << b_node->DebugString(); ++ } ++ if (a_value_ptr->isa() && b_value_ptr->isa()) { ++ auto a_tensor_ptr = a_value_ptr->cast(); ++ auto b_tensor_ptr = b_value_ptr->cast(); ++ if (a_tensor_ptr == nullptr || b_tensor_ptr == nullptr) { ++ MS_LOG(INTERNAL_EXCEPTION) << "Cast value node ptr fail."; ++ } ++ return a_tensor_ptr->ValueEqual(*b_tensor_ptr); ++ } else { ++ return (*a_value_ptr) == (*b_value_ptr); ++ } ++ } ++ MS_LOG(DEBUG) << "check AnfNodePtr equal"; ++ } ++ if (utils::isa(a) && utils::isa(b)) { ++ MS_LOG(DEBUG) << "check GraphPtr equal"; ++ } ++ return a == b; ++} ++ + bool AnfUtils::IsNodeInGraphKernel(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + return node->func_graph() != nullptr && node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); +diff --git a/mindspore/ops/kernel/cpu/nnacl/CMakeLists.txt b/mindspore/ops/kernel/cpu/nnacl/CMakeLists.txt +index a43057a612f..a43577439cd 100644 +--- a/mindspore/ops/kernel/cpu/nnacl/CMakeLists.txt ++++ b/mindspore/ops/kernel/cpu/nnacl/CMakeLists.txt +@@ -241,7 +241,7 @@ if(PLATFORM_ARM) + COMPILE_FLAGS "${CMAKE_C_FLAGS} -fno-fast-math") + endif() + +-add_library(nnacl_mid OBJECT ${KERNEL_SRC} ${TRAIN_SRC} ${ASSEMBLY_SRC} ${MS_X86_SIMD_SRC}) ++add_library(nnacl_mid OBJECT ${KERNEL_SRC} ${ASSEMBLY_SRC} ${MS_X86_SIMD_SRC}) + + if("${CMAKE_BUILD_TYPE}" STREQUAL "Debug") + target_compile_definitions(nnacl_mid PRIVATE ENABLE_DEBUG) +diff --git a/mindspore/ops/kernel/cpu/nnacl/kernel/arithmetic.c b/mindspore/ops/kernel/cpu/nnacl/kernel/arithmetic.c +index 8bf95cf2b04..d63d484ba29 100644 +--- a/mindspore/ops/kernel/cpu/nnacl/kernel/arithmetic.c ++++ b/mindspore/ops/kernel/cpu/nnacl/kernel/arithmetic.c +@@ -385,7 +385,7 @@ int ArithmeticBroadCastConstTensor(ArithmeticStruct *arithmetic) { + + CalcStructMultiplesAndStrides(arithmetic); + +-#ifdef PARALLEL_INFERENCE ++#ifdef MSLITE_ENABLE_CLOUD_INFERENCE + bool prefer_explicit_broadcast = false; + #else + bool prefer_explicit_broadcast = arithmetic->ndim_ != 1; +diff --git a/tests/ut/cpp/pre_activate/common/fast_pattern_to_pattern_pass_test.cc b/tests/ut/cpp/pre_activate/common/fast_pattern_to_pattern_pass_test.cc +index bf617d7f116..3684b107b62 100644 +--- a/tests/ut/cpp/pre_activate/common/fast_pattern_to_pattern_pass_test.cc ++++ b/tests/ut/cpp/pre_activate/common/fast_pattern_to_pattern_pass_test.cc +@@ -198,19 +198,19 @@ TEST_F(TestFastPatternToPatternPass, Mul0) { + ASSERT_TRUE(check.build_pattern_map(new_node)); + + // check +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("a"), a)); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("b"), b)); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("c"), c)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("a"), a)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("b"), b)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("c"), c)); + ASSERT_EQ(check.m_->Get("bc")->cast()->size(), 3); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("bc")->cast()->input(0), ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("bc")->cast()->input(0), + NewValueNode(std::make_shared(kAddOpName)))); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("bc")->cast()->input(1), b)); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("bc")->cast()->input(2), c)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("bc")->cast()->input(1), b)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("bc")->cast()->input(2), c)); + ASSERT_EQ(check.m_->Get("mul")->cast()->size(), 3); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("mul")->cast()->input(0), ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("mul")->cast()->input(0), + NewValueNode(std::make_shared(kMulOpName)))); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("mul")->cast()->input(1), a)); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("mul")->cast()->input(2), check.m_->Get("bc"))); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("mul")->cast()->input(1), a)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("mul")->cast()->input(2), check.m_->Get("bc"))); + } + + /// Feature: Fast PatternToPattern Pass +@@ -313,28 +313,28 @@ TEST_F(TestFastPatternToPatternPass, Mul0NotRoot) { + ASSERT_TRUE(check.build_pattern_map(add1)); + + // check +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("a"), a)); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("b"), b)); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("c"), c)); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("d"), d)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("a"), a)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("b"), b)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("c"), c)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("d"), d)); + + ASSERT_EQ(check.m_->Get("bc")->cast()->size(), 3); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("bc")->cast()->input(0), ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("bc")->cast()->input(0), + NewValueNode(std::make_shared(kAddOpName)))); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("bc")->cast()->input(1), b)); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("bc")->cast()->input(2), c)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("bc")->cast()->input(1), b)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("bc")->cast()->input(2), c)); + + ASSERT_EQ(check.m_->Get("mul")->cast()->size(), 3); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("mul")->cast()->input(0), ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("mul")->cast()->input(0), + NewValueNode(std::make_shared(kMulOpName)))); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("mul")->cast()->input(1), a)); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("mul")->cast()->input(2), check.m_->Get("bc"))); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("mul")->cast()->input(1), a)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("mul")->cast()->input(2), check.m_->Get("bc"))); + + ASSERT_EQ(check.m_->Get("add1")->cast()->size(), 3); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("add1")->cast()->input(0), ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("add1")->cast()->input(0), + NewValueNode(std::make_shared(kAddOpName)))); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("add1")->cast()->input(1), check.m_->Get("mul"))); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("add1")->cast()->input(2), d)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("add1")->cast()->input(1), check.m_->Get("mul"))); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("add1")->cast()->input(2), d)); + } + + /// Feature: Fast PatternToPattern Pass +@@ -452,21 +452,21 @@ TEST_F(TestFastPatternToPatternPass, Mul1) { + ASSERT_TRUE(check.build_pattern_map(add1)); + + // check +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("a"), a)); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("d"), d)); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("e"), e)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("a"), a)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("d"), d)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("e"), e)); + + ASSERT_EQ(check.m_->Get("ad")->cast()->size(), 3); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("ad")->cast()->input(0), ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("ad")->cast()->input(0), + NewValueNode(std::make_shared(kMulOpName)))); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("ad")->cast()->input(1), a)); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("ad")->cast()->input(2), d)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("ad")->cast()->input(1), a)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("ad")->cast()->input(2), d)); + + ASSERT_EQ(check.m_->Get("add1")->cast()->size(), 3); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("add1")->cast()->input(0), ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("add1")->cast()->input(0), + NewValueNode(std::make_shared(kAddOpName)))); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("add1")->cast()->input(1), check.m_->Get("ad"))); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("add1")->cast()->input(2), e)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("add1")->cast()->input(1), check.m_->Get("ad"))); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("add1")->cast()->input(2), e)); + } + + namespace { +@@ -524,35 +524,35 @@ void Check1(const TestFastMul2 &pass, const FuncGraphIndexPtr &fg, const std::ma + } + + void Check2(const CheckPattern &check, const std::map &node_map) { +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kA), node_map.at(kA))); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kB), node_map.at(kB))); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kC), node_map.at(kC))); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kD), node_map.at(kD))); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kE), node_map.at(kE))); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get(kA), node_map.at(kA))); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get(kB), node_map.at(kB))); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get(kC), node_map.at(kC))); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get(kD), node_map.at(kD))); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get(kE), node_map.at(kE))); + + ASSERT_EQ(check.m_->Get(kAAddB)->cast()->size(), kThree); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kAAddB)->cast()->input(kZero), ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get(kAAddB)->cast()->input(kZero), + NewValueNode(std::make_shared(kAddOpName)))); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kAAddB)->cast()->input(kOne), node_map.at(kA))); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kAAddB)->cast()->input(kTwo), node_map.at(kB))); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get(kAAddB)->cast()->input(kOne), node_map.at(kA))); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get(kAAddB)->cast()->input(kTwo), node_map.at(kB))); + + ASSERT_EQ(check.m_->Get(kCAddD)->cast()->size(), kThree); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kCAddD)->cast()->input(kZero), ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get(kCAddD)->cast()->input(kZero), + NewValueNode(std::make_shared(kAddOpName)))); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kCAddD)->cast()->input(kOne), node_map.at(kC))); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kCAddD)->cast()->input(kTwo), node_map.at(kD))); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get(kCAddD)->cast()->input(kOne), node_map.at(kC))); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get(kCAddD)->cast()->input(kTwo), node_map.at(kD))); + + ASSERT_EQ(check.m_->Get(kMul)->cast()->size(), kThree); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kMul)->cast()->input(kZero), ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get(kMul)->cast()->input(kZero), + NewValueNode(std::make_shared(kMulOpName)))); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kMul)->cast()->input(kOne), node_map.at(kCAddD))); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kMul)->cast()->input(kTwo), node_map.at(kAAddB))); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get(kMul)->cast()->input(kOne), node_map.at(kCAddD))); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get(kMul)->cast()->input(kTwo), node_map.at(kAAddB))); + + ASSERT_EQ(check.m_->Get(kAdd)->cast()->size(), kThree); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kAdd)->cast()->input(kZero), ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get(kAdd)->cast()->input(kZero), + NewValueNode(std::make_shared(kAddOpName)))); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kAdd)->cast()->input(kOne), check.m_->Get(kMul))); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kAdd)->cast()->input(kTwo), node_map.at(kE))); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get(kAdd)->cast()->input(kOne), check.m_->Get(kMul))); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get(kAdd)->cast()->input(kTwo), node_map.at(kE))); + } + } // namespace + +diff --git a/tests/ut/cpp/pre_activate/common/pattern_to_pattern_pass_test.cc b/tests/ut/cpp/pre_activate/common/pattern_to_pattern_pass_test.cc +index 0d1c651e24d..9a3cfc2c946 100644 +--- a/tests/ut/cpp/pre_activate/common/pattern_to_pattern_pass_test.cc ++++ b/tests/ut/cpp/pre_activate/common/pattern_to_pattern_pass_test.cc +@@ -223,19 +223,19 @@ TEST_F(TestPatternToPatternPass, Mul0) { + ASSERT_TRUE(check.build_pattern_map(new_node)); + + // check +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("a"), a)); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("b"), b)); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("c"), c)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("a"), a)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("b"), b)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("c"), c)); + ASSERT_EQ(check.m_->Get("bc")->cast()->size(), 3); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("bc")->cast()->input(0), ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("bc")->cast()->input(0), + NewValueNode(std::make_shared(kAddOpName)))); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("bc")->cast()->input(1), b)); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("bc")->cast()->input(2), c)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("bc")->cast()->input(1), b)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("bc")->cast()->input(2), c)); + ASSERT_EQ(check.m_->Get("mul")->cast()->size(), 3); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("mul")->cast()->input(0), ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("mul")->cast()->input(0), + NewValueNode(std::make_shared(kMulOpName)))); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("mul")->cast()->input(1), a)); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("mul")->cast()->input(2), check.m_->Get("bc"))); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("mul")->cast()->input(1), a)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("mul")->cast()->input(2), check.m_->Get("bc"))); + } + + /// Feature: PatternToPattern Pass +@@ -263,13 +263,13 @@ TEST_F(TestPatternToPatternPass, Mul1) { + ASSERT_TRUE(check.build_pattern_map(new_node)); + + // check +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("a"), a)); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("b"), NewValueNode(1))); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("a"), a)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("b"), NewValueNode(1))); + ASSERT_EQ(check.m_->Get("mul")->cast()->size(), 3); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("mul")->cast()->input(0), ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("mul")->cast()->input(0), + NewValueNode(std::make_shared(kMulOpName)))); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("mul")->cast()->input(1), a)); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("mul")->cast()->input(2), NewValueNode(1))); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("mul")->cast()->input(1), a)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("mul")->cast()->input(2), NewValueNode(1))); + } + + /// Feature: PatternToPattern Pass +@@ -297,13 +297,13 @@ TEST_F(TestPatternToPatternPass, Mul2) { + ASSERT_TRUE(check.build_pattern_map(new_node)); + + // check +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("a"), a)); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("b"), b)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("a"), a)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("b"), b)); + ASSERT_EQ(check.m_->Get("mul")->cast()->size(), 3); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("mul")->cast()->input(0), ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("mul")->cast()->input(0), + NewValueNode(std::make_shared(kMulOpName)))); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("mul")->cast()->input(1), b)); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("mul")->cast()->input(2), a)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("mul")->cast()->input(1), b)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("mul")->cast()->input(2), a)); + } + + /// Feature: PatternToPattern Pass +@@ -331,13 +331,13 @@ TEST_F(TestPatternToPatternPass, Mul3) { + ASSERT_TRUE(check.build_pattern_map(new_node)); + + // check +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("a"), a)); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("b"), b)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("a"), a)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("b"), b)); + ASSERT_EQ(check.m_->Get("add")->cast()->size(), 3); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("add")->cast()->input(0), ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("add")->cast()->input(0), + NewValueNode(std::make_shared(kAddOpName)))); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("add")->cast()->input(1), a)); +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("add")->cast()->input(2), b)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("add")->cast()->input(1), a)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("add")->cast()->input(2), b)); + } + + /// Feature: PatternToPattern Pass +@@ -361,7 +361,7 @@ TEST_F(TestPatternToPatternPass, EmptySeq) { + ASSERT_TRUE(check.build_pattern_map(new_node)); + + // check +- ASSERT_TRUE(opt::AnfEqual(check.m_->Get("c_a"), a)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(check.m_->Get("c_a"), a)); + } + + /// Feature: PatternToPattern Pass +diff --git a/tests/ut/cpp/pre_activate/common/source_pattern_test.cc b/tests/ut/cpp/pre_activate/common/source_pattern_test.cc +index 09043f46f60..67200c0c54f 100644 +--- a/tests/ut/cpp/pre_activate/common/source_pattern_test.cc ++++ b/tests/ut/cpp/pre_activate/common/source_pattern_test.cc +@@ -88,11 +88,11 @@ TEST_F(TestSrcPattern, Var) { + ASSERT_TRUE(build_pattern_map(mul1_cnode)); + + // check +- ASSERT_TRUE(opt::AnfEqual(m_->Get("anode1"), anode1)); +- ASSERT_TRUE(opt::AnfEqual(m_->Get("anode2"), anode2)); +- ASSERT_TRUE(opt::AnfEqual(m_->Get("anode3"), anode3)); +- ASSERT_TRUE(opt::AnfEqual(m_->Get("mul1_cnode"), mul1_cnode)); +- ASSERT_TRUE(opt::AnfEqual(m_->Get("mul2_cnode"), mul2_cnode)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(m_->Get("anode1"), anode1)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(m_->Get("anode2"), anode2)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(m_->Get("anode3"), anode3)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(m_->Get("mul1_cnode"), mul1_cnode)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(m_->Get("mul2_cnode"), mul2_cnode)); + } + + /// Feature: PatternToPattern Pass +@@ -118,13 +118,13 @@ TEST_F(TestSrcPattern, SeqVar) { + ASSERT_TRUE(build_pattern_map(mul1_cnode)); + + // check +- ASSERT_TRUE(opt::AnfEqual(m_->Get("anode1"), anode1)); +- ASSERT_TRUE(opt::AnfEqual(m_->Get("mul1_cnode"), mul1_cnode)); +- ASSERT_TRUE(opt::AnfEqual(m_->Get("mul2_cnode"), mul2_cnode)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(m_->Get("anode1"), anode1)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(m_->Get("mul1_cnode"), mul1_cnode)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(m_->Get("mul2_cnode"), mul2_cnode)); + auto &v = m_->GetSeq("Sv"); + ASSERT_EQ(v.size(), std::size_t(2)); +- ASSERT_TRUE(opt::AnfEqual(v[0], anode2)); +- ASSERT_TRUE(opt::AnfEqual(v[1], anode3)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(v[0], anode2)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(v[1], anode3)); + } + + /// Feature: PatternToPattern Pass +@@ -159,12 +159,12 @@ TEST_F(TestSrcPattern, RepeatedVar) { + ASSERT_TRUE(build_pattern_map(mul3_cnode)); + + // check +- ASSERT_TRUE(opt::AnfEqual(m_->Get("anode1"), anode1)); +- ASSERT_TRUE(opt::AnfEqual(m_->Get("anode2"), anode2)); +- ASSERT_TRUE(opt::AnfEqual(m_->Get("anode3"), anode3)); +- ASSERT_TRUE(opt::AnfEqual(m_->Get("mul1_cnode"), mul1_cnode)); +- ASSERT_TRUE(opt::AnfEqual(m_->Get("mul2_cnode"), mul2_cnode)); +- ASSERT_TRUE(opt::AnfEqual(m_->Get("mul3_cnode"), mul3_cnode)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(m_->Get("anode1"), anode1)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(m_->Get("anode2"), anode2)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(m_->Get("anode3"), anode3)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(m_->Get("mul1_cnode"), mul1_cnode)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(m_->Get("mul2_cnode"), mul2_cnode)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(m_->Get("mul3_cnode"), mul3_cnode)); + } + + /// Feature: PatternToPattern Pass +@@ -202,15 +202,15 @@ TEST_F(TestSrcPattern, RepeatedSeqVar) { + ASSERT_TRUE(build_pattern_map(mul4_cnode)); + + // check +- ASSERT_TRUE(opt::AnfEqual(m_->Get("anode1"), anode1)); +- ASSERT_TRUE(opt::AnfEqual(m_->Get("mul1_cnode"), mul1_cnode)); +- ASSERT_TRUE(opt::AnfEqual(m_->Get("mul2_cnode"), mul2_cnode)); +- ASSERT_TRUE(opt::AnfEqual(m_->Get("mul3_cnode"), mul3_cnode)); +- ASSERT_TRUE(opt::AnfEqual(m_->Get("mul4_cnode"), mul4_cnode)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(m_->Get("anode1"), anode1)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(m_->Get("mul1_cnode"), mul1_cnode)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(m_->Get("mul2_cnode"), mul2_cnode)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(m_->Get("mul3_cnode"), mul3_cnode)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(m_->Get("mul4_cnode"), mul4_cnode)); + auto &v = m_->GetSeq("Sv"); + ASSERT_EQ(v.size(), std::size_t(2)); +- ASSERT_TRUE(opt::AnfEqual(v[0], anode2)); +- ASSERT_TRUE(opt::AnfEqual(v[1], anode3)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(v[0], anode2)); ++ ASSERT_TRUE(AnfUtils::AnfEqual(v[1], anode3)); + } + + /// Feature: PatternToPattern Pass